nir/group_loads: store our custom instr->index in an array

we'll put more stuff in the new structure

Acked-by: Alyssa Rosenzweig <alyssa@rosenzweig.io>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/36100>
This commit is contained in:
Marek Olšák
2025-06-27 12:46:22 -04:00
committed by Marge Bot
parent 821dc611c5
commit 7b9acabe89
+67 -40
View File
@@ -49,6 +49,12 @@
*/
#include "nir.h"
#include "util/u_dynarray.h"
typedef struct {
bool visited;
uint32_t instr_index;
} instr_info;
static nir_instr *
get_load_resource(nir_instr *instr)
@@ -136,8 +142,9 @@ is_part_of_group(nir_instr *instr, uint8_t current_indirection_level)
}
struct check_sources_state {
instr_info *infos;
nir_block *block;
uint32_t first_index;
uint32_t first_instr_index;
};
static bool
@@ -147,14 +154,15 @@ has_only_sources_less_than(nir_src *src, void *data)
/* true if nir_foreach_src should keep going */
return state->block != src->ssa->parent_instr->block ||
src->ssa->parent_instr->index < state->first_index;
state->infos[src->ssa->parent_instr->index].instr_index <
state->first_instr_index;
}
static void
group_loads(nir_instr *first, nir_instr *last)
group_loads(nir_instr *first, nir_instr *last, instr_info *infos)
{
assert(nir_is_grouped_load(first));
assert(nir_is_grouped_load(last));
assert(is_grouped_load(first));
assert(is_grouped_load(last));
/* Walk the instruction range between the first and last backward, and
* move those that have no uses within the range after the last one.
@@ -170,7 +178,8 @@ group_loads(nir_instr *first, nir_instr *last)
if (def) {
nir_foreach_use(use, def) {
if (nir_src_parent_instr(use)->block == instr->block &&
nir_src_parent_instr(use)->index <= last->index) {
infos[nir_src_parent_instr(use)->index].instr_index <=
infos[last->index].instr_index) {
all_uses_after_last = false;
break;
}
@@ -188,13 +197,15 @@ group_loads(nir_instr *first, nir_instr *last)
* indicate that it's after it.
*/
nir_instr_move(nir_after_instr(last), move_instr);
move_instr->index = last->index + 1;
infos[move_instr->index].instr_index =
infos[last->index].instr_index + 1;
}
}
struct check_sources_state state;
state.infos = infos;
state.block = first->block;
state.first_index = first->index;
state.first_instr_index = infos[first->index].instr_index;
/* Walk the instruction range between the first and last forward, and move
* those that have no sources within the range before the first one.
@@ -214,7 +225,8 @@ group_loads(nir_instr *first, nir_instr *last)
* to indicate that it's before it.
*/
nir_instr_move(nir_before_instr(first), move_instr);
move_instr->index = first->index - 1;
infos[move_instr->index].instr_index =
infos[first->index].instr_index - 1;
}
}
}
@@ -230,9 +242,9 @@ is_pseudo_inst(nir_instr *instr)
}
static void
set_instr_indices(nir_block *block)
set_instr_indices(nir_block *block, instr_info *infos)
{
/* Start with 1 because we'll move instruction before the first one
/* Start with 1 because we'll move instructions before the first one
* and will want to label it 0.
*/
unsigned counter = 1;
@@ -246,7 +258,7 @@ set_instr_indices(nir_block *block)
counter++;
/* Set each instruction's index within the block. */
instr->index = counter;
infos[instr->index].instr_index = counter;
/* Only count non-pseudo instructions. */
if (!is_pseudo_inst(instr))
@@ -257,15 +269,19 @@ set_instr_indices(nir_block *block)
}
static void
handle_load_range(nir_instr **first, nir_instr **last,
nir_instr *current, unsigned max_distance)
handle_load_range(nir_instr **first, nir_instr **last, nir_instr *current,
unsigned max_distance, instr_info *infos)
{
assert(!current || !*first || current->index >= (*first)->index);
assert(!current || !*first ||
infos[current->index].instr_index >=
infos[(*first)->index].instr_index);
if (*first && *last &&
(!current || current->index - (*first)->index > max_distance)) {
(!current ||
infos[current->index].instr_index -
infos[(*first)->index].instr_index > max_distance)) {
assert(*first != *last);
group_loads(*first, *last);
set_instr_indices((*first)->block);
group_loads(*first, *last, infos);
set_instr_indices((*first)->block, infos);
*first = NULL;
*last = NULL;
}
@@ -288,12 +304,13 @@ is_demote(nir_instr *instr)
}
struct indirection_state {
instr_info *infos;
nir_block *block;
unsigned indirections;
};
static unsigned
get_num_indirections(nir_instr *instr);
get_num_indirections(nir_instr *instr, instr_info *infos);
static bool
gather_indirections(nir_src *src, void *data)
@@ -303,7 +320,8 @@ gather_indirections(nir_src *src, void *data)
/* We only count indirections within the same block. */
if (instr->block == state->block) {
unsigned indirections = get_num_indirections(src->ssa->parent_instr);
unsigned indirections = get_num_indirections(src->ssa->parent_instr,
state->infos);
if (instr->type == nir_instr_type_tex || is_grouped_load(instr))
indirections++;
@@ -316,7 +334,7 @@ gather_indirections(nir_src *src, void *data)
/* Return the number of load indirections within the block. */
static unsigned
get_num_indirections(nir_instr *instr)
get_num_indirections(nir_instr *instr, instr_info *infos)
{
/* Don't traverse phis because we could end up in an infinite recursion
* if the phi points to the current block (such as a loop body).
@@ -324,35 +342,30 @@ get_num_indirections(nir_instr *instr)
if (instr->type == nir_instr_type_phi)
return 0;
if (instr->index != UINT32_MAX)
return instr->index; /* we've visited this instruction before */
if (infos[instr->index].visited)
return infos[instr->index].instr_index;
struct indirection_state state;
state.infos = infos;
state.block = instr->block;
state.indirections = 0;
nir_foreach_src(instr, gather_indirections, &state);
instr->index = state.indirections;
infos[instr->index].visited = true;
infos[instr->index].instr_index = state.indirections;
return state.indirections;
}
static void
process_block(nir_block *block, nir_load_grouping grouping,
unsigned max_distance)
unsigned max_distance, instr_info *infos)
{
int max_indirection = -1;
unsigned num_inst_per_level[256] = { 0 };
/* UINT32_MAX means the instruction has not been visited. Once
* an instruction has been visited and its indirection level has been
* determined, we'll store the indirection level in the index. The next
* instruction that visits it will use the index instead of recomputing
* the indirection level, which would result in an exponential time
* complexity.
*/
nir_foreach_instr(instr, block) {
instr->index = UINT32_MAX; /* unknown */
for (unsigned i = 0; i < block->end_ip + 1 - block->start_ip; i++) {
infos[block->start_ip + i].visited = false;
}
/* Count the number of load indirections for each load instruction
@@ -360,7 +373,7 @@ process_block(nir_block *block, nir_load_grouping grouping,
*/
nir_foreach_instr(instr, block) {
if (is_grouped_load(instr)) {
unsigned indirections = get_num_indirections(instr);
unsigned indirections = get_num_indirections(instr, infos);
/* pass_flags has only 8 bits */
indirections = MIN2(indirections, 255);
@@ -379,7 +392,7 @@ process_block(nir_block *block, nir_load_grouping grouping,
if (num_inst_per_level[level] <= 1)
continue;
set_instr_indices(block);
set_instr_indices(block, infos);
nir_instr *resource = NULL;
nir_instr *first_load = NULL, *last_load = NULL;
@@ -393,7 +406,7 @@ process_block(nir_block *block, nir_load_grouping grouping,
/* Don't group across terminate. */
if (is_demote(current)) {
/* Group unconditionally. */
handle_load_range(&first_load, &last_load, NULL, 0);
handle_load_range(&first_load, &last_load, NULL, 0, infos);
first_load = NULL;
last_load = NULL;
continue;
@@ -426,11 +439,12 @@ process_block(nir_block *block, nir_load_grouping grouping,
}
/* Group only if we exceeded the maximum distance. */
handle_load_range(&first_load, &last_load, current, max_distance);
handle_load_range(&first_load, &last_load, current, max_distance,
infos);
}
/* Group unconditionally. */
handle_load_range(&first_load, &last_load, NULL, 0);
handle_load_range(&first_load, &last_load, NULL, 0, infos);
}
}
@@ -441,14 +455,27 @@ bool
nir_group_loads(nir_shader *shader, nir_load_grouping grouping,
unsigned max_distance)
{
/* Temporary space for instruction info. */
struct util_dynarray infos_scratch;
util_dynarray_init(&infos_scratch, NULL);
nir_foreach_function_impl(impl, shader) {
nir_metadata_require(impl, nir_metadata_instr_index);
unsigned num_instr =
nir_impl_last_block(impl)->end_ip + 1; /* we might need 1 more */
instr_info *infos =
(instr_info*)util_dynarray_resize(&infos_scratch, instr_info,
num_instr);
nir_foreach_block(block, impl) {
process_block(block, grouping, max_distance);
process_block(block, grouping, max_distance, infos);
}
nir_progress(true, impl,
nir_metadata_control_flow | nir_metadata_loop_analysis);
}
util_dynarray_fini(&infos_scratch);
return true;
}