From f528597906ac0dd8f4fef5743ad10c1c8ce9937b Mon Sep 17 00:00:00 2001 From: Rhys Perry Date: Thu, 17 Oct 2024 16:19:17 +0100 Subject: [PATCH] aco: check for SDWA before applying extract to lshl/cvt_f32 Signed-off-by: Rhys Perry Reviewed-by: Georg Lehmann Part-of: --- src/amd/compiler/aco_optimizer.cpp | 13 +++-- src/amd/compiler/tests/helpers.cpp | 29 +++++++++-- src/amd/compiler/tests/helpers.h | 2 + src/amd/compiler/tests/test_sdwa.cpp | 77 ++++++++++++++++++++++++++++ 4 files changed, 110 insertions(+), 11 deletions(-) diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp index c0526d4ff1b..de056aea09b 100644 --- a/src/amd/compiler/aco_optimizer.cpp +++ b/src/amd/compiler/aco_optimizer.cpp @@ -1031,10 +1031,10 @@ can_apply_extract(opt_ctx& ctx, aco_ptr& instr, unsigned idx, ssa_i return true; } else if ((instr->opcode == aco_opcode::v_cvt_f32_u32 || instr->opcode == aco_opcode::v_cvt_f32_i32) && - sel.size() == 1 && !sel.sign_extend()) { + sel.size() == 1 && !sel.sign_extend() && !instr->usesModifiers()) { return true; } else if (instr->opcode == aco_opcode::v_lshlrev_b32 && instr->operands[0].isConstant() && - sel.offset() == 0 && + sel.offset() == 0 && !instr->usesModifiers() && ((sel.size() == 2 && instr->operands[0].constantValue() >= 16u) || (sel.size() == 1 && instr->operands[0].constantValue() >= 24u))) { return true; @@ -1055,9 +1055,8 @@ can_apply_extract(opt_ctx& ctx, aco_ptr& instr, unsigned idx, ssa_i } else if (instr->opcode == aco_opcode::s_pack_ll_b32_b16 && sel.size() == 2 && (idx == 1 || ctx.program->gfx_level >= GFX11 || !sel.offset())) { return true; - } else if (sel.size() == 2 && - ((instr->opcode == aco_opcode::s_pack_lh_b32_b16 && idx == 0) || - (instr->opcode == aco_opcode::s_pack_hl_b32_b16 && idx == 1))) { + } else if (sel.size() == 2 && ((instr->opcode == aco_opcode::s_pack_lh_b32_b16 && idx == 0) || + (instr->opcode == aco_opcode::s_pack_hl_b32_b16 && idx == 1))) { return true; } else if (instr->opcode == aco_opcode::p_extract) { SubdwordSel instrSel = parse_extract(instr.get()); @@ -1095,7 +1094,7 @@ apply_extract(opt_ctx& ctx, aco_ptr& instr, unsigned idx, ssa_info& /* full dword selection */ } else if ((instr->opcode == aco_opcode::v_cvt_f32_u32 || instr->opcode == aco_opcode::v_cvt_f32_i32) && - sel.size() == 1 && !sel.sign_extend()) { + sel.size() == 1 && !sel.sign_extend() && !instr->usesModifiers()) { switch (sel.offset()) { case 0: instr->opcode = aco_opcode::v_cvt_f32_ubyte0; break; case 1: instr->opcode = aco_opcode::v_cvt_f32_ubyte1; break; @@ -1103,7 +1102,7 @@ apply_extract(opt_ctx& ctx, aco_ptr& instr, unsigned idx, ssa_info& case 3: instr->opcode = aco_opcode::v_cvt_f32_ubyte3; break; } } else if (instr->opcode == aco_opcode::v_lshlrev_b32 && instr->operands[0].isConstant() && - sel.offset() == 0 && + sel.offset() == 0 && !instr->usesModifiers() && ((sel.size() == 2 && instr->operands[0].constantValue() >= 16u) || (sel.size() == 1 && instr->operands[0].constantValue() >= 24u))) { /* The undesirable upper bits are already shifted out. */ diff --git a/src/amd/compiler/tests/helpers.cpp b/src/amd/compiler/tests/helpers.cpp index 866c1368804..d9275a6a1a2 100644 --- a/src/amd/compiler/tests/helpers.cpp +++ b/src/amd/compiler/tests/helpers.cpp @@ -528,18 +528,39 @@ fmax(Temp src0, Temp src1, Builder b) return b.vop2(aco_opcode::v_max_f32, b.def(v1), src0, src1); } +static Temp +extract(Temp src, unsigned idx, unsigned size, bool sign_extend, Builder b) +{ + if (src.type() == RegType::sgpr) + return b.pseudo(aco_opcode::p_extract, b.def(src.regClass()), bld.def(s1, scc), src, + Operand::c32(idx), Operand::c32(size), Operand::c32(sign_extend)); + else + return b.pseudo(aco_opcode::p_extract, b.def(src.regClass()), src, Operand::c32(idx), + Operand::c32(size), Operand::c32(sign_extend)); +} + Temp ext_ushort(Temp src, unsigned idx, Builder b) { - return b.pseudo(aco_opcode::p_extract, b.def(src.regClass()), src, Operand::c32(idx), - Operand::c32(16u), Operand::c32(false)); + return extract(src, idx, 16, false, b); +} + +Temp +ext_sshort(Temp src, unsigned idx, Builder b) +{ + return extract(src, idx, 16, true, b); } Temp ext_ubyte(Temp src, unsigned idx, Builder b) { - return b.pseudo(aco_opcode::p_extract, b.def(src.regClass()), src, Operand::c32(idx), - Operand::c32(8u), Operand::c32(false)); + return extract(src, idx, 8, false, b); +} + +Temp +ext_sbyte(Temp src, unsigned idx, Builder b) +{ + return extract(src, idx, 8, true, b); } void diff --git a/src/amd/compiler/tests/helpers.h b/src/amd/compiler/tests/helpers.h index 7a7db78105d..a955764ce6f 100644 --- a/src/amd/compiler/tests/helpers.h +++ b/src/amd/compiler/tests/helpers.h @@ -98,7 +98,9 @@ aco::Temp fsat(aco::Temp src, aco::Builder b = bld); aco::Temp fmin(aco::Temp src0, aco::Temp src1, aco::Builder b = bld); aco::Temp fmax(aco::Temp src0, aco::Temp src1, aco::Builder b = bld); aco::Temp ext_ushort(aco::Temp src, unsigned idx, aco::Builder b = bld); +aco::Temp ext_sshort(aco::Temp src, unsigned idx, aco::Builder b = bld); aco::Temp ext_ubyte(aco::Temp src, unsigned idx, aco::Builder b = bld); +aco::Temp ext_sbyte(aco::Temp src, unsigned idx, aco::Builder b = bld); void emit_divergent_if_else(aco::Program* prog, aco::Builder& b, aco::Operand cond, std::function then, std::function els); diff --git a/src/amd/compiler/tests/test_sdwa.cpp b/src/amd/compiler/tests/test_sdwa.cpp index 0b53736b5e9..0891d11ed32 100644 --- a/src/amd/compiler/tests/test_sdwa.cpp +++ b/src/amd/compiler/tests/test_sdwa.cpp @@ -557,3 +557,80 @@ BEGIN_TEST(optimize.sdwa.insert_modifiers) finish_opt_test(); } END_TEST + +BEGIN_TEST(optimize.sdwa.special_case_valu) + //>> v1: %a, s1: %b = p_startpgm + if (!setup_cs("v1 s1", GFX10_3)) + return; + + Temp a = inputs[0]; + Temp b = inputs[1]; + Temp b_vgpr = bld.copy(bld.def(v1), b); + + //! v1: %res0 = v_cvt_f32_ubyte0 %a + //! p_unit_test 0, %res0 + writeout(0, bld.vop1(aco_opcode::v_cvt_f32_u32, bld.def(v1), ext_ubyte(a, 0))); + + //! v1: %res1 = v_cvt_f32_ubyte1 %a + //! p_unit_test 1, %res1 + writeout(1, bld.vop1(aco_opcode::v_cvt_f32_u32, bld.def(v1), ext_ubyte(a, 1))); + + //! v1: %res2 = v_cvt_f32_ubyte2 %a + //! p_unit_test 2, %res2 + writeout(2, bld.vop1(aco_opcode::v_cvt_f32_u32, bld.def(v1), ext_ubyte(a, 2))); + + //! v1: %res3 = v_cvt_f32_ubyte3 %a + //! p_unit_test 3, %res3 + writeout(3, bld.vop1(aco_opcode::v_cvt_f32_u32, bld.def(v1), ext_ubyte(a, 3))); + + //! v1: %res4 = v_cvt_f32_u32 %a dst_sel:dword src0_sel:sbyte3 + //! p_unit_test 4, %res4 + writeout(4, bld.vop1(aco_opcode::v_cvt_f32_u32, bld.def(v1), ext_sbyte(a, 3))); + + //! v1: %res5 = v_cvt_f32_u32 %a dst_sel:dword src0_sel:uword1 + //! p_unit_test 5, %res5 + writeout(5, bld.vop1(aco_opcode::v_cvt_f32_u32, bld.def(v1), ext_ushort(a, 1))); + + //! v1: %res6_tmp = p_extract %b, 2, 8, 0 + //! v1: %res6 = v_cvt_f32_u32 %res6_tmp dst_sel:dword src0_sel:sword1 + //! p_unit_test 6, %res6 + writeout(6, + bld.vop1(aco_opcode::v_cvt_f32_u32, bld.def(v1), ext_sshort(ext_ubyte(b_vgpr, 2), 1))); + + //! v1: %res7 = v_lshlrev_b32 16, %a + //! p_unit_test 7, %res7 + writeout(7, + bld.vop2(aco_opcode::v_lshlrev_b32, bld.def(v1), Operand::c32(16), ext_ushort(a, 0))); + + //! v1: %res8 = v_lshlrev_b32 24, %a + //! p_unit_test 8, %res8 + writeout(8, bld.vop2(aco_opcode::v_lshlrev_b32, bld.def(v1), Operand::c32(24), ext_ubyte(a, 0))); + + //! v1: %res9 = v_lshlrev_b32 16, %a + //! p_unit_test 9, %res9 + writeout(9, + bld.vop2(aco_opcode::v_lshlrev_b32, bld.def(v1), Operand::c32(16), ext_sshort(a, 0))); + + //! v1: %res10 = v_lshlrev_b32 24, %a + //! p_unit_test 10, %res10 + writeout(10, + bld.vop2(aco_opcode::v_lshlrev_b32, bld.def(v1), Operand::c32(24), ext_sbyte(a, 0))); + + //! v1: %res11 = v_lshlrev_b32 16, %a dst_sel:dword src0_sel:dword src1_sel:uword1 + //! p_unit_test 11, %res11 + writeout(11, + bld.vop2(aco_opcode::v_lshlrev_b32, bld.def(v1), Operand::c32(16), ext_ushort(a, 1))); + + //! v1: %res12 = v_lshlrev_b32 24, %a dst_sel:dword src0_sel:dword src1_sel:ubyte1 + //! p_unit_test 12, %res12 + writeout(12, + bld.vop2(aco_opcode::v_lshlrev_b32, bld.def(v1), Operand::c32(24), ext_ubyte(a, 1))); + + //! v1: %res13_tmp = p_extract %b, 0, 16, 1 + //! v1: %res13 = v_lshlrev_b32 16, %res13_tmp dst_sel:dword src0_sel:dword src1_sel:ubyte2 + //! p_unit_test 13, %res13 + writeout(13, bld.vop2(aco_opcode::v_lshlrev_b32, bld.def(v1), Operand::c32(16), + ext_ubyte(ext_sshort(b_vgpr, 0), 2))); + + finish_opt_test(); +END_TEST