diff --git a/src/compiler/nir/nir_builder.h b/src/compiler/nir/nir_builder.h index 517f10d4e20..61e757023ad 100644 --- a/src/compiler/nir/nir_builder.h +++ b/src/compiler/nir/nir_builder.h @@ -81,6 +81,8 @@ nir_builder MUST_CHECK PRINTFLIKE(3, 4) typedef bool (*nir_instr_pass_cb)(struct nir_builder *, nir_instr *, void *); typedef bool (*nir_intrinsic_pass_cb)(struct nir_builder *, nir_intrinsic_instr *, void *); +typedef bool (*nir_alu_pass_cb)(struct nir_builder *, + nir_alu_instr *, void *); /** * Iterates over all the instructions in a NIR function and calls the given pass @@ -184,6 +186,39 @@ nir_shader_intrinsics_pass(nir_shader *shader, return progress; } +/* As above, but for ALU */ +static inline bool +nir_shader_alu_pass(nir_shader *shader, + nir_alu_pass_cb pass, + nir_metadata preserved, + void *cb_data) +{ + bool progress = false; + + nir_foreach_function_impl(impl, shader) { + bool func_progress = false; + nir_builder b = nir_builder_create(impl); + + nir_foreach_block_safe(block, impl) { + nir_foreach_instr_safe(instr, block) { + if (instr->type == nir_instr_type_alu) { + nir_alu_instr *intr = nir_instr_as_alu(instr); + func_progress |= pass(&b, intr, cb_data); + } + } + } + + if (func_progress) { + nir_metadata_preserve(impl, preserved); + progress = true; + } else { + nir_metadata_preserve(impl, nir_metadata_all); + } + } + + return progress; +} + void nir_builder_instr_insert(nir_builder *build, nir_instr *instr); void nir_builder_instr_insert_at_top(nir_builder *build, nir_instr *instr);