From b9a1bcd3a1813ef5d7451cf06dcca6ab75f73acf Mon Sep 17 00:00:00 2001 From: Gert Wollny Date: Wed, 17 Sep 2025 11:34:29 +0200 Subject: [PATCH] r600/sfn: replace hand coded comparison opts with opt_algebraic With that we can easily add a restriction to the not + flt -> fge optimization to handle NaNs like it was done before. Fixes: 51d8ca2dff0 ("r600/sfn: optimize comparison results") v2: use SPDX license identifier (austriancoder) Signed-off-by: Gert Wollny Part-of: --- src/gallium/drivers/r600/meson.build | 13 +- src/gallium/drivers/r600/sfn/sfn_nir.cpp | 3 +- src/gallium/drivers/r600/sfn/sfn_nir.h | 7 +- .../drivers/r600/sfn/sfn_nir_algebraic.py | 73 +++++++++++ .../drivers/r600/sfn/sfn_nir_lower_alu.cpp | 119 ------------------ .../drivers/r600/sfn/sfn_nir_lower_alu.h | 3 - 6 files changed, 91 insertions(+), 127 deletions(-) create mode 100644 src/gallium/drivers/r600/sfn/sfn_nir_algebraic.py diff --git a/src/gallium/drivers/r600/meson.build b/src/gallium/drivers/r600/meson.build index c759e46a0bc..5b8a9c5b882 100644 --- a/src/gallium/drivers/r600/meson.build +++ b/src/gallium/drivers/r600/meson.build @@ -144,6 +144,17 @@ egd_tables_h = custom_target( capture : true, ) +sfn_nir_algebraic_c = custom_target( + 'sfn_nir_algebraic_c', + input : 'sfn/sfn_nir_algebraic.py', + output : 'sfn_nir_algebraic.c', + command : [ + prog_python, '@INPUT@', '-p', dir_compiler_nir, + ], + capture : true, + depend_files : nir_algebraic_depends, +) + r600_c_args = [] r600_cpp_args = [] @@ -154,7 +165,7 @@ endif libr600 = static_library( 'r600', - [files_r600, egd_tables_h, sha1_h], + [files_r600, egd_tables_h, sfn_nir_algebraic_c, sha1_h], c_args : [r600_c_args, '-Wstrict-overflow=0'], cpp_args: r600_cpp_args, gnu_symbol_visibility : 'hidden', diff --git a/src/gallium/drivers/r600/sfn/sfn_nir.cpp b/src/gallium/drivers/r600/sfn/sfn_nir.cpp index 180d3453604..08e067ffa6e 100644 --- a/src/gallium/drivers/r600/sfn/sfn_nir.cpp +++ b/src/gallium/drivers/r600/sfn/sfn_nir.cpp @@ -653,7 
+653,7 @@ optimize_once(nir_shader *shader) NIR_PASS(progress, shader, nir_opt_dce); NIR_PASS(progress, shader, nir_opt_undef); NIR_PASS(progress, shader, nir_opt_loop_unroll); - NIR_PASS(progress, shader, r600_nir_opt_compare_results); + NIR_PASS(progress, shader, r600_sfn_lower_alu); return progress; } @@ -879,6 +879,7 @@ r600_lower_and_optimize_nir(nir_shader *sh, bool late_algebraic_progress; do { late_algebraic_progress = false; + NIR_PASS(late_algebraic_progress, sh, r600_sfn_lower_alu); NIR_PASS(late_algebraic_progress, sh, nir_opt_algebraic_late); NIR_PASS(late_algebraic_progress, sh, nir_opt_constant_folding); NIR_PASS(late_algebraic_progress, sh, nir_copy_prop); diff --git a/src/gallium/drivers/r600/sfn/sfn_nir.h b/src/gallium/drivers/r600/sfn/sfn_nir.h index ed78d65a0a7..ef7251eded0 100644 --- a/src/gallium/drivers/r600/sfn/sfn_nir.h +++ b/src/gallium/drivers/r600/sfn/sfn_nir.h @@ -7,9 +7,6 @@ #ifndef SFN_NIR_H #define SFN_NIR_H -#include "gallium/include/pipe/p_state.h" - -#include "amd_family.h" #include "nir.h" #include "nir_builder.h" @@ -119,8 +116,12 @@ r600_lower_and_optimize_nir(nir_shader *sh, void r600_finalize_nir_common(nir_shader *nir, enum amd_gfx_level gfx_level); +bool +r600_sfn_lower_alu(nir_shader *shader); + #ifdef __cplusplus } #endif + #endif // SFN_NIR_H diff --git a/src/gallium/drivers/r600/sfn/sfn_nir_algebraic.py b/src/gallium/drivers/r600/sfn/sfn_nir_algebraic.py new file mode 100644 index 00000000000..efb3f16c0d5 --- /dev/null +++ b/src/gallium/drivers/r600/sfn/sfn_nir_algebraic.py @@ -0,0 +1,73 @@ +# +# Copyright 2025 Collabora Ltd. 
+# SPDX-License-Identifier: MIT +# + +import argparse +import sys + +lower_alu = [ + + + # this partially duplicates stuff from nir_opt_algebraic, + # but without the 'is_used_once' because on r600 the sequence + # c = comp(a, b) + # d = inot(c) + # requires two instruction groups whereas + # c = comp(a, b) + # d = comp_inv(a, b) + # can be put into one instruction group, that is, if c is not used + # we reduced the code by one instruction and potentially one instruction + # group, if c is used, then we still may need one instruction group less. + + (('inot', ('flt', 'a(is_a_number)', 'b(is_a_number)')), ('fge', 'a', 'b')), + (('inot', ('fge', 'a(is_a_number)', 'b(is_a_number)')), ('flt', 'a', 'b')), + + (('inot', ('fneu', 'a', 'b')), ('feq', 'a', 'b')), + (('inot', ('feq', 'a', 'b')), ('fneu', 'a', 'b')), + + (('inot', ('ilt', 'a', 'b')), ('ige', 'a', 'b')), + (('inot', ('ige', 'a', 'b')), ('ilt', 'a', 'b')), + (('inot', ('ult', 'a', 'b')), ('uge', 'a', 'b')), + (('inot', ('uge', 'a', 'b')), ('ult', 'a', 'b')), + (('inot', ('ieq', 'a', 'b')), ('ine', 'a', 'b')), + (('inot', ('ine', 'a', 'b')), ('ieq', 'a', 'b')), + + (('b2f32', ('fge', 'a@32', 'b@32')), ('sge', 'a', 'b')), + (('b2f32', ('flt', 'a@32', 'b@32')), ('slt', 'a', 'b')), + (('b2f32', ('feq', 'a@32', 'b@32')), ('seq', 'a', 'b')), + (('b2f32', ('fneu', 'a@32', 'b@32')), ('sne', 'a', 'b')), + + (('flt', ('fadd', 'a', 'b'), 0.0), ('flt', 'a', ('fneg', 'b'))), + (('flt', 0.0, ('fadd', 'a', 'b')), ('flt', ('fneg', 'b'), 'a')), + + (('slt', ('fadd', 'a', 'b'), 0.0), ('slt', 'a', ('fneg', 'b'))), + (('slt', 0.0, ('fadd', 'a', 'b')), ('slt', ('fneg', 'b'), 'a')), + + (('sge', ('fadd', 'a', 'b'), 0.0), ('sge', 'a', ('fneg', 'b'))), + (('sge', 0.0, ('fadd', 'a', 'b')), ('sge', ('fneg', 'b'), 'a')), + + (('seq', ('fadd', 'a', 'b'), 0.0), ('seq', 'a', ('fneg', 'b'))), + (('sne', ('fadd', 'a', 'b'), 0.0), ('sne', 'a', ('fneg', 'b'))), +] + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('-p', 
'--import-path', required=True) + args = parser.parse_args() + sys.path.insert(0, args.import_path) + run() + + +def run(): + import nir_algebraic # pylint: disable=import-error + + print('#include "nir_search_helpers.h"') + print('#include "sfn/sfn_nir.h"') + + print(nir_algebraic.AlgebraicPass("r600_sfn_lower_alu", + lower_alu).render()) + +if __name__ == '__main__': + main() diff --git a/src/gallium/drivers/r600/sfn/sfn_nir_lower_alu.cpp b/src/gallium/drivers/r600/sfn/sfn_nir_lower_alu.cpp index 7ee9b77227f..5e121875feb 100644 --- a/src/gallium/drivers/r600/sfn/sfn_nir_lower_alu.cpp +++ b/src/gallium/drivers/r600/sfn/sfn_nir_lower_alu.cpp @@ -148,119 +148,6 @@ nir_def *FixKcacheIndirectRead::lower(nir_instr *instr) return result; } -class OptNotFromComparison : public NirLowerInstruction { -private: - bool filter(const nir_instr *instr) const override; - nir_def *lower(nir_instr *instr) override; -}; - -bool -OptNotFromComparison::filter(const nir_instr *instr) const -{ - if (instr->type != nir_instr_type_alu) - return false; - - auto alu = nir_instr_as_alu(instr); - if (alu->src[0].src.ssa->parent_instr->type != nir_instr_type_alu) - return false; - - auto p = nir_def_as_alu(alu->src[0].src.ssa); - - switch (alu->op) { - case nir_op_inot: - switch (p->op) { - case nir_op_flt: - case nir_op_fge: - case nir_op_feq: - case nir_op_fneu: - case nir_op_ilt: - case nir_op_ult: - case nir_op_ige: - case nir_op_uge: - case nir_op_ieq: - case nir_op_ine: - return true; - default: - return false; - } - case nir_op_b2f32: - if (p->src[0].src.ssa->bit_size != 32) - return false; - switch (p->op) { - case nir_op_fge: - case nir_op_flt: - case nir_op_feq: - case nir_op_fneu: - return true; - default: - return false; - } - default: - return false; - } - - return true; -} - -nir_def * -OptNotFromComparison::lower(nir_instr *instr) -{ - auto alu = nir_instr_as_alu(instr); - - auto p = nir_def_as_alu(alu->src[0].src.ssa); - - auto src0 = nir_channel(b, p->src[0].src.ssa, 
p->src[0].swizzle[0]); - auto src1 = nir_channel(b, p->src[1].src.ssa, p->src[1].swizzle[0]); - - switch (alu->op) { - case nir_op_inot: - - switch (p->op) { - case nir_op_flt: - return nir_fge(b, src0, src1); - case nir_op_fge: - return nir_flt(b, src0, src1); - case nir_op_feq: - return nir_fneu(b, src0, src1); - case nir_op_fneu: - return nir_feq(b, src0, src1); - - case nir_op_ilt: - return nir_ige(b, src0, src1); - case nir_op_ult: - return nir_uge(b, src0, src1); - - case nir_op_ige: - return nir_ilt(b, src0, src1); - case nir_op_uge: - return nir_ult(b, src0, src1); - - case nir_op_ieq: - return nir_ine(b, src0, src1); - case nir_op_ine: - return nir_ieq(b, src0, src1); - default: - return 0; - } - case nir_op_b2f32: - switch (p->op) { - case nir_op_fge: - return nir_sge(b, src0, src1); - case nir_op_flt: - return nir_slt(b, src0, src1); - case nir_op_feq: - return nir_seq(b, src0, src1); - case nir_op_fneu: - return nir_sne(b, src0, src1); - default: - return 0; - } - default: - return 0; - } - return 0; -} - } // namespace r600 bool @@ -281,9 +168,3 @@ r600_nir_fix_kcache_indirect_access(nir_shader *shader) return shader->info.num_ubos > 14 ? r600::FixKcacheIndirectRead().run(shader) : false; } - -bool -r600_nir_opt_compare_results(nir_shader *shader) -{ - return r600::OptNotFromComparison().run(shader); -} diff --git a/src/gallium/drivers/r600/sfn/sfn_nir_lower_alu.h b/src/gallium/drivers/r600/sfn/sfn_nir_lower_alu.h index b6e5f858fb6..06df6242396 100644 --- a/src/gallium/drivers/r600/sfn/sfn_nir_lower_alu.h +++ b/src/gallium/drivers/r600/sfn/sfn_nir_lower_alu.h @@ -13,9 +13,6 @@ bool r600_nir_lower_pack_unpack_2x16(nir_shader *shader); -bool -r600_nir_opt_compare_results(nir_shader *shader); - bool r600_nir_lower_trigen(nir_shader *shader, enum amd_gfx_level gfx_level);