From 7b9acabe89286b98a5a4da3e672a1a1efc5f54d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marek=20Ol=C5=A1=C3=A1k?= Date: Fri, 27 Jun 2025 12:46:22 -0400 Subject: [PATCH] nir/group_loads: store our custom instr->index in an array we'll put more stuff in the new structure Acked-by: Alyssa Rosenzweig Part-of: --- src/compiler/nir/nir_group_loads.c | 107 ++++++++++++++++++----------- 1 file changed, 67 insertions(+), 40 deletions(-) diff --git a/src/compiler/nir/nir_group_loads.c b/src/compiler/nir/nir_group_loads.c index dc40d6c0f46..929589055a6 100644 --- a/src/compiler/nir/nir_group_loads.c +++ b/src/compiler/nir/nir_group_loads.c @@ -49,6 +49,12 @@ */ #include "nir.h" +#include "util/u_dynarray.h" + +typedef struct { + bool visited; + uint32_t instr_index; +} instr_info; static nir_instr * get_load_resource(nir_instr *instr) @@ -136,8 +142,9 @@ is_part_of_group(nir_instr *instr, uint8_t current_indirection_level) } struct check_sources_state { + instr_info *infos; nir_block *block; - uint32_t first_index; + uint32_t first_instr_index; }; static bool @@ -147,14 +154,15 @@ has_only_sources_less_than(nir_src *src, void *data) /* true if nir_foreach_src should keep going */ return state->block != src->ssa->parent_instr->block || - src->ssa->parent_instr->index < state->first_index; + state->infos[src->ssa->parent_instr->index].instr_index < + state->first_instr_index; } static void -group_loads(nir_instr *first, nir_instr *last) +group_loads(nir_instr *first, nir_instr *last, instr_info *infos) { - assert(nir_is_grouped_load(first)); - assert(nir_is_grouped_load(last)); + assert(is_grouped_load(first)); + assert(is_grouped_load(last)); /* Walk the instruction range between the first and last backward, and * move those that have no uses within the range after the last one. @@ -170,7 +178,8 @@ group_loads(nir_instr *first, nir_instr *last) if (def) { nir_foreach_use(use, def) { if (nir_src_parent_instr(use)->block == instr->block && - nir_src_parent_instr(use)->index <= last->index) { + infos[nir_src_parent_instr(use)->index].instr_index <= + infos[last->index].instr_index) { all_uses_after_last = false; break; } @@ -188,13 +197,15 @@ group_loads(nir_instr *first, nir_instr *last) * indicate that it's after it. */ nir_instr_move(nir_after_instr(last), move_instr); - move_instr->index = last->index + 1; + infos[move_instr->index].instr_index = + infos[last->index].instr_index + 1; } } struct check_sources_state state; + state.infos = infos; state.block = first->block; - state.first_index = first->index; + state.first_instr_index = infos[first->index].instr_index; /* Walk the instruction range between the first and last forward, and move * those that have no sources within the range before the first one. @@ -214,7 +225,8 @@ group_loads(nir_instr *first, nir_instr *last) * to indicate that it's before it. */ nir_instr_move(nir_before_instr(first), move_instr); - move_instr->index = first->index - 1; + infos[move_instr->index].instr_index = + infos[first->index].instr_index - 1; } } } @@ -230,9 +242,9 @@ is_pseudo_inst(nir_instr *instr) } static void -set_instr_indices(nir_block *block) +set_instr_indices(nir_block *block, instr_info *infos) { - /* Start with 1 because we'll move instruction before the first one + /* Start with 1 because we'll move instructions before the first one * and will want to label it 0. */ unsigned counter = 1; @@ -246,7 +258,7 @@ set_instr_indices(nir_block *block) counter++; /* Set each instruction's index within the block. */ - instr->index = counter; + infos[instr->index].instr_index = counter; /* Only count non-pseudo instructions. */ if (!is_pseudo_inst(instr)) @@ -257,15 +269,19 @@ set_instr_indices(nir_block *block) } static void -handle_load_range(nir_instr **first, nir_instr **last, - nir_instr *current, unsigned max_distance) +handle_load_range(nir_instr **first, nir_instr **last, nir_instr *current, + unsigned max_distance, instr_info *infos) { - assert(!current || !*first || current->index >= (*first)->index); + assert(!current || !*first || + infos[current->index].instr_index >= + infos[(*first)->index].instr_index); if (*first && *last && - (!current || current->index - (*first)->index > max_distance)) { + (!current || + infos[current->index].instr_index - + infos[(*first)->index].instr_index > max_distance)) { assert(*first != *last); - group_loads(*first, *last); - set_instr_indices((*first)->block); + group_loads(*first, *last, infos); + set_instr_indices((*first)->block, infos); *first = NULL; *last = NULL; } @@ -288,12 +304,13 @@ is_demote(nir_instr *instr) } struct indirection_state { + instr_info *infos; nir_block *block; unsigned indirections; }; static unsigned -get_num_indirections(nir_instr *instr); +get_num_indirections(nir_instr *instr, instr_info *infos); static bool gather_indirections(nir_src *src, void *data) @@ -303,7 +320,8 @@ gather_indirections(nir_src *src, void *data) /* We only count indirections within the same block. */ if (instr->block == state->block) { - unsigned indirections = get_num_indirections(src->ssa->parent_instr); + unsigned indirections = get_num_indirections(src->ssa->parent_instr, + state->infos); if (instr->type == nir_instr_type_tex || is_grouped_load(instr)) indirections++; @@ -316,7 +334,7 @@ gather_indirections(nir_src *src, void *data) /* Return the number of load indirections within the block. */ static unsigned -get_num_indirections(nir_instr *instr) +get_num_indirections(nir_instr *instr, instr_info *infos) { /* Don't traverse phis because we could end up in an infinite recursion * if the phi points to the current block (such as a loop body). @@ -324,35 +342,30 @@ get_num_indirections(nir_instr *instr) if (instr->type == nir_instr_type_phi) return 0; - if (instr->index != UINT32_MAX) - return instr->index; /* we've visited this instruction before */ + if (infos[instr->index].visited) + return infos[instr->index].instr_index; struct indirection_state state; + state.infos = infos; state.block = instr->block; state.indirections = 0; nir_foreach_src(instr, gather_indirections, &state); - instr->index = state.indirections; + infos[instr->index].visited = true; + infos[instr->index].instr_index = state.indirections; return state.indirections; } static void process_block(nir_block *block, nir_load_grouping grouping, - unsigned max_distance) + unsigned max_distance, instr_info *infos) { int max_indirection = -1; unsigned num_inst_per_level[256] = { 0 }; - /* UINT32_MAX means the instruction has not been visited. Once - * an instruction has been visited and its indirection level has been - * determined, we'll store the indirection level in the index. The next - * instruction that visits it will use the index instead of recomputing - * the indirection level, which would result in an exponetial time - * complexity. - */ - nir_foreach_instr(instr, block) { - instr->index = UINT32_MAX; /* unknown */ + for (unsigned i = 0; i < block->end_ip + 1 - block->start_ip; i++) { + infos[block->start_ip + i].visited = false; } /* Count the number of load indirections for each load instruction @@ -360,7 +373,7 @@ process_block(nir_block *block, nir_load_grouping grouping, */ nir_foreach_instr(instr, block) { if (is_grouped_load(instr)) { - unsigned indirections = get_num_indirections(instr); + unsigned indirections = get_num_indirections(instr, infos); /* pass_flags has only 8 bits */ indirections = MIN2(indirections, 255); @@ -379,7 +392,7 @@ process_block(nir_block *block, nir_load_grouping grouping, if (num_inst_per_level[level] <= 1) continue; - set_instr_indices(block); + set_instr_indices(block, infos); nir_instr *resource = NULL; nir_instr *first_load = NULL, *last_load = NULL; @@ -393,7 +406,7 @@ process_block(nir_block *block, nir_load_grouping grouping, /* Don't group across terminate. */ if (is_demote(current)) { /* Group unconditionally. */ - handle_load_range(&first_load, &last_load, NULL, 0); + handle_load_range(&first_load, &last_load, NULL, 0, infos); first_load = NULL; last_load = NULL; continue; @@ -426,11 +439,12 @@ process_block(nir_block *block, nir_load_grouping grouping, } /* Group only if we exceeded the maximum distance. */ - handle_load_range(&first_load, &last_load, current, max_distance); + handle_load_range(&first_load, &last_load, current, max_distance, + infos); } /* Group unconditionally. */ - handle_load_range(&first_load, &last_load, NULL, 0); + handle_load_range(&first_load, &last_load, NULL, 0, infos); } } @@ -441,14 +455,27 @@ bool nir_group_loads(nir_shader *shader, nir_load_grouping grouping, unsigned max_distance) { + /* Temporary space for instruction info. */ + struct util_dynarray infos_scratch; + util_dynarray_init(&infos_scratch, NULL); + nir_foreach_function_impl(impl, shader) { + nir_metadata_require(impl, nir_metadata_instr_index); + + unsigned num_instr = + nir_impl_last_block(impl)->end_ip + 1; /* we might need 1 more */ + instr_info *infos = + (instr_info*)util_dynarray_resize(&infos_scratch, instr_info, + num_instr); + nir_foreach_block(block, impl) { - process_block(block, grouping, max_distance); + process_block(block, grouping, max_distance, infos); } nir_progress(true, impl, nir_metadata_control_flow | nir_metadata_loop_analysis); } + util_dynarray_fini(&infos_scratch); return true; }