diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp index 7c56ca8f9be..8492fcac6c1 100644 --- a/src/amd/compiler/aco_optimizer.cpp +++ b/src/amd/compiler/aco_optimizer.cpp @@ -61,19 +61,8 @@ enum Label { label_canonicalized_fp16 = 1ull << 22, label_canonicalized_fp32 = 1ull << 23, label_canonicalized_fp64 = 1ull << 24, - - /* label_{omod2,omod4,omod5,clamp} are used for both 16 and - * 32-bit operations but this doesn't cause any issues because - * we check the definition register class. - */ - label_omod2 = 1ull << 33, - label_omod4 = 1ull << 34, - label_omod5 = 1ull << 35, - label_clamp = 1ull << 36, }; -static constexpr uint64_t instr_mod_labels = label_omod2 | label_omod4 | label_omod5 | label_clamp; - static constexpr uint64_t input_mod_labels = label_abs_fp16 | label_abs_fp32_64 | label_neg_fp16 | label_neg_fp32_64; @@ -99,8 +88,6 @@ canonicalized_label(unsigned bit_size) UNREACHABLE("unknown canonicalized size"); } -static_assert((instr_mod_labels & temp_labels) == 0, "labels cannot intersect"); -static_assert((instr_mod_labels & val_labels) == 0, "labels cannot intersect"); static_assert((temp_labels & val_labels) == 0, "labels cannot intersect"); struct ssa_info { @@ -108,7 +95,6 @@ struct ssa_info { union { uint64_t val; Temp temp; - Instruction* mod_instr; }; Instruction* parent_instr; @@ -116,19 +102,14 @@ struct ssa_info { void add_label(Label new_label) { - if (new_label & instr_mod_labels) { - label &= ~instr_mod_labels; - label &= ~(temp_labels | val_labels); /* instr, temp and val alias */ - } - if (new_label & temp_labels) { label &= ~temp_labels; - label &= ~(instr_mod_labels | val_labels); /* instr, temp and val alias */ + label &= ~val_labels; /* temp and val alias */ } if (new_label & val_labels) { label &= ~val_labels; - label &= ~(instr_mod_labels | temp_labels); /* instr, temp and val alias */ + label &= ~temp_labels; /* temp and val alias */ } label |= new_label; @@ -194,46 +175,6 @@ struct ssa_info { bool is_combined() { return label & label_combined_instr; } - void set_omod2(Instruction* mul) - { - if (label & temp_labels) - return; - add_label(label_omod2); - mod_instr = mul; - } - - bool is_omod2() { return label & label_omod2; } - - void set_omod4(Instruction* mul) - { - if (label & temp_labels) - return; - add_label(label_omod4); - mod_instr = mul; - } - - bool is_omod4() { return label & label_omod4; } - - void set_omod5(Instruction* mul) - { - if (label & temp_labels) - return; - add_label(label_omod5); - mod_instr = mul; - } - - bool is_omod5() { return label & label_omod5; } - - void set_clamp(Instruction* med3) - { - if (label & temp_labels) - return; - add_label(label_clamp); - mod_instr = med3; - } - - bool is_clamp() { return label & label_clamp; } - void set_uniform_bitwise() { add_label(label_uniform_bitwise); } bool is_uniform_bitwise() { return label & label_uniform_bitwise; } @@ -1675,34 +1616,6 @@ gather_canonicalized(opt_ctx& ctx, aco_ptr& instr) } } -bool -can_use_VOP3(opt_ctx& ctx, const aco_ptr& instr) -{ - if (instr->isVOP3()) - return true; - - if (instr->isVOP3P() || instr->isVINTERP_INREG()) - return false; - - if (instr->operands.size() && instr->operands[0].isLiteral() && ctx.program->gfx_level < GFX10) - return false; - - if (instr->isSDWA()) - return false; - - if (instr->isDPP() && ctx.program->gfx_level < GFX11) - return false; - - return instr->opcode != aco_opcode::v_madmk_f32 && instr->opcode != aco_opcode::v_madak_f32 && - instr->opcode != aco_opcode::v_madmk_f16 && instr->opcode != aco_opcode::v_madak_f16 && - instr->opcode != aco_opcode::v_fmamk_f32 && instr->opcode != aco_opcode::v_fmaak_f32 && - instr->opcode != aco_opcode::v_fmamk_f16 && instr->opcode != aco_opcode::v_fmaak_f16 && - instr->opcode != aco_opcode::v_permlane64_b32 && - instr->opcode != aco_opcode::v_readlane_b32 && - instr->opcode != aco_opcode::v_writelane_b32 && - instr->opcode != aco_opcode::v_readfirstlane_b32; -} - bool pseudo_propagate_temp(opt_ctx& ctx, aco_ptr& instr, Temp temp, unsigned index) { @@ -2417,7 +2330,7 @@ alu_propagate_temp_const(opt_ctx& ctx, aco_ptr& instr, bool uses_va instr.reset(alu_opt_info_to_instr(ctx, result_info, instr.release())); for (const Definition& def : instr->definitions) - ctx.info[def.tempId()].label &= instr_mod_labels | canonicalized_labels; + ctx.info[def.tempId()].label &= canonicalized_labels; } void @@ -2856,8 +2769,7 @@ label_instruction(opt_ctx& ctx, aco_ptr& instr) case aco_opcode::v_mul_f32: case aco_opcode::v_mul_legacy_f32: case aco_opcode::v_mul_f64: - case aco_opcode::v_mul_f64_e64: { /* omod */ - /* TODO: try to move the negate/abs modifier to the consumer instead */ + case aco_opcode::v_mul_f64_e64: { bool uses_mods = instr->usesModifiers(); bool fp16 = instr->opcode == aco_opcode::v_mul_f16; bool fp64 = @@ -2875,19 +2787,13 @@ label_instruction(opt_ctx& ctx, aco_ptr& instr) bool neg1 = constant == -1.0; VALU_instruction* valu = &instr->valu(); - if (valu->abs[!i] || valu->neg[!i] || valu->omod) + if (valu->abs[!i] || valu->neg[!i] || valu->omod || valu->clamp) continue; bool abs = valu->abs[i]; bool neg = neg1 ^ valu->neg[i]; Temp other = instr->operands[i].getTemp(); - if (valu->clamp) { - if (!abs && !neg && other.type() == RegType::vgpr) - ctx.info[other.id()].set_clamp(instr.get()); - continue; - } - if (abs && neg && other.type() == RegType::vgpr) ctx.info[instr->definitions[0].tempId()].set_neg_abs(other, bit_size); else if (abs && !neg && other.type() == RegType::vgpr) @@ -2900,37 +2806,17 @@ label_instruction(opt_ctx& ctx, aco_ptr& instr) else ctx.info[instr->definitions[0].tempId()].set_fcanonicalize(other, bit_size); } - } else if (uses_mods || (instr->definitions[0].isSZPreserve() && - instr->opcode != aco_opcode::v_mul_legacy_f32)) { - continue; /* omod uses a legacy multiplication. */ - } else if (instr->operands[!i].constantValue64() == 0u && + } else if (!uses_mods && instr->operands[!i].constantValue64() == 0u && ((!instr->definitions[0].isNaNPreserve() && - !instr->definitions[0].isInfPreserve()) || + !instr->definitions[0].isInfPreserve() && + !instr->definitions[0].isSZPreserve()) || instr->opcode == aco_opcode::v_mul_legacy_f32)) { ctx.info[instr->definitions[0].tempId()].set_constant(0u); - } else if (denorm_mode != fp_denorm_flush) { - /* omod has no effect if denormals are enabled. */ - continue; - } else if (constant == 2.0) { - ctx.info[instr->operands[i].tempId()].set_omod2(instr.get()); - } else if (constant == 4.0) { - ctx.info[instr->operands[i].tempId()].set_omod4(instr.get()); - } else if (constant == 0.5) { - ctx.info[instr->operands[i].tempId()].set_omod5(instr.get()); - } else { - continue; } break; } break; } - case aco_opcode::v_med3_f16: - case aco_opcode::v_med3_f32: { /* clamp */ - unsigned idx; - if (detect_clamp(instr.get(), &idx) && !instr->valu().abs && !instr->valu().neg) - ctx.info[instr->operands[idx].tempId()].set_clamp(instr.get()); - break; - } case aco_opcode::s_not_b32: case aco_opcode::s_not_b64: if (!instr->operands[0].isTemp()) { @@ -3480,98 +3366,44 @@ use_absdiff: return op_instr; } -bool -interp_can_become_fma(opt_ctx& ctx, aco_ptr& instr) +Instruction* +apply_clamp(opt_ctx& ctx, aco_ptr& instr, Instruction* parent) { - if (instr->opcode != aco_opcode::v_interp_p2_f32_inreg) - return false; + unsigned idx; + if (!detect_clamp(instr.get(), &idx)) + return nullptr; - instr->opcode = aco_opcode::v_fma_f32; - instr->format = Format::VOP3; - bool dpp_allowed = can_use_DPP(ctx.program->gfx_level, instr, false); - instr->opcode = aco_opcode::v_interp_p2_f32_inreg; - instr->format = Format::VINTERP_INREG; + aco_type type = instr_info.alu_opcode_infos[(int)instr->opcode].def_types[0]; - return dpp_allowed; -} + if (!ctx.info[parent->definitions[0].tempId()].is_canonicalized(type.bit_size) && + ctx.fp_mode.denorm32 != fp_denorm_keep) + return nullptr; -void -interp_p2_f32_inreg_to_fma_dpp(aco_ptr& instr) -{ - static_assert(sizeof(DPP16_instruction) == sizeof(VINTERP_inreg_instruction), - "Invalid instr cast."); - instr->format = asVOP3(Format::DPP16); - instr->opcode = aco_opcode::v_fma_f32; - instr->dpp16().dpp_ctrl = dpp_quad_perm(2, 2, 2, 2); - instr->dpp16().row_mask = 0xf; - instr->dpp16().bank_mask = 0xf; - instr->dpp16().bound_ctrl = 0; - instr->dpp16().fetch_inactive = 1; -} + aco_type parent_type = instr_info.alu_opcode_infos[(int)parent->opcode].def_types[0]; -/* apply omod / clamp modifiers if the def is used only once and the instruction can have modifiers */ -bool -apply_omod_clamp(opt_ctx& ctx, aco_ptr& instr) -{ - if (instr->definitions.empty() || ctx.uses[instr->definitions[0].tempId()] != 1 || - !instr_info.alu_opcode_infos[(int)instr->opcode].output_modifiers) - return false; + if (!instr_info.alu_opcode_infos[(int)parent->opcode].output_modifiers || + type.bit_size != parent_type.bit_size || parent_type.num_components != 1) + return nullptr; - bool can_vop3 = can_use_VOP3(ctx, instr); - bool is_mad_mix = - instr->opcode == aco_opcode::v_fma_mix_f32 || instr->opcode == aco_opcode::v_fma_mixlo_f16; - bool needs_vop3 = !instr->isSDWA() && !instr->isVINTERP_INREG() && !is_mad_mix; - if (needs_vop3 && !can_vop3) - return false; + alu_opt_info parent_info; + if (!alu_opt_gather_info(ctx, parent, parent_info)) + return nullptr; - if (instr_info.classes[(int)instr->opcode] == instr_class::valu_pseudo_scalar_trans) - return false; + if (parent_info.uses_insert()) + return nullptr; - /* SDWA omod is GFX9+. */ - bool can_use_omod = (can_vop3 || ctx.program->gfx_level >= GFX9) && !instr->isVOP3P() && - (!instr->isVINTERP_INREG() || interp_can_become_fma(ctx, instr)); + alu_opt_info info; + if (!alu_opt_gather_info(ctx, instr.get(), info)) + return nullptr; - ssa_info& def_info = ctx.info[instr->definitions[0].tempId()]; + if (!backpropagate_input_modifiers(ctx, parent_info, info.operands[idx], type)) + return nullptr; - uint64_t omod_labels = label_omod2 | label_omod4 | label_omod5; - if (!def_info.is_clamp() && !(can_use_omod && (def_info.label & omod_labels))) - return false; - /* if the omod/clamp instruction is dead, then the single user of this - * instruction is a different instruction */ - if (!ctx.uses[def_info.mod_instr->definitions[0].tempId()]) - return false; - - if (def_info.mod_instr->definitions[0].bytes() != instr->definitions[0].bytes()) - return false; - - /* MADs/FMAs are created later, so we don't have to update the original add */ - assert(!ctx.info[instr->definitions[0].tempId()].is_combined()); - - if (!def_info.is_clamp() && (instr->valu().clamp || instr->valu().omod)) - return false; - - if (needs_vop3) - instr->format = asVOP3(instr->format); - - if (!def_info.is_clamp() && instr->opcode == aco_opcode::v_interp_p2_f32_inreg) - interp_p2_f32_inreg_to_fma_dpp(instr); - - if (def_info.is_omod2()) - instr->valu().omod = 1; - else if (def_info.is_omod4()) - instr->valu().omod = 2; - else if (def_info.is_omod5()) - instr->valu().omod = 3; - else if (def_info.is_clamp()) - instr->valu().clamp = true; - - instr->definitions[0].swapTemp(def_info.mod_instr->definitions[0]); - ctx.info[instr->definitions[0].tempId()].label &= label_clamp; - ctx.uses[def_info.mod_instr->definitions[0].tempId()]--; - ctx.info[instr->definitions[0].tempId()].parent_instr = instr.get(); - ctx.info[def_info.mod_instr->definitions[0].tempId()].parent_instr = def_info.mod_instr; - - return true; + parent_info.clamp = true; + parent_info.defs[0].setTemp(info.defs[0].getTemp()); + if (!alu_opt_info_is_valid(ctx, parent_info)) + return nullptr; + return alu_opt_info_to_instr(ctx, parent_info, parent); } /* Combine an p_insert (or p_extract, in some cases) instruction with instr. @@ -3793,6 +3625,8 @@ apply_output_mul(opt_ctx& ctx, aco_ptr& instr, Instruction* parent) if (!op_info_get_constant(ctx, info.operands[cidx], type, &constant)) return nullptr; + unsigned omod = 0; + for (unsigned i = 0; i < type.num_components; i++) { double val = extract_float(constant, type.bit_size, i); if (val < 0.0) { @@ -3800,26 +3634,42 @@ apply_output_mul(opt_ctx& ctx, aco_ptr& instr, Instruction* parent) info.operands[!cidx].neg[i] ^= true; } - if (val != 1.0) + if (val == 1.0) + omod = 0; + else if (val == 2.0) + omod = 1; + else if (val == 4.0) + omod = 2; + else if (val == 0.5) + omod = 3; + else + return nullptr; + + if (omod && type.num_components != 1) return nullptr; } - if ((info.omod || info.clamp) && - !instr_info.alu_opcode_infos[(int)parent->opcode].output_modifiers) + if (omod && (info.omod || denorm_mode != fp_denorm_flush || + (info.opcode != aco_opcode::v_mul_legacy_f32 && info.defs[0].isSZPreserve()))) + return nullptr; + + omod |= info.omod; + + if ((omod || info.clamp) && !instr_info.alu_opcode_infos[(int)parent->opcode].output_modifiers) return nullptr; alu_opt_info parent_info; if (!alu_opt_gather_info(ctx, parent, parent_info)) return nullptr; - if (parent_info.uses_insert() || (info.omod && (parent_info.omod || parent_info.clamp))) + if (parent_info.uses_insert() || (omod && (parent_info.omod || parent_info.clamp))) return nullptr; if (!backpropagate_input_modifiers(ctx, parent_info, info.operands[!cidx], type)) return nullptr; parent_info.clamp |= info.clamp; - parent_info.omod |= info.omod; + parent_info.omod |= omod; parent_info.insert = info.insert; parent_info.defs[0].setTemp(info.defs[0].getTemp()); if (!alu_opt_info_is_valid(ctx, parent_info)) @@ -3843,10 +3693,13 @@ apply_output_impl(opt_ctx& ctx, aco_ptr& instr, Instruction* parent return apply_s_abs(ctx, instr, parent); else if (instr->opcode == aco_opcode::v_mul_f64 || instr->opcode == aco_opcode::v_mul_f64_e64 || instr->opcode == aco_opcode::v_mul_f32 || instr->opcode == aco_opcode::v_mul_f16 || - instr->opcode == aco_opcode::v_pk_mul_f16) + instr->opcode == aco_opcode::v_pk_mul_f16 || + instr->opcode == aco_opcode::v_mul_legacy_f32) return apply_output_mul(ctx, instr, parent); else if (instr->opcode == aco_opcode::v_cvt_f16_f32) return apply_f2f16(ctx, instr, parent); + else if (instr->opcode == aco_opcode::v_med3_f32 || instr->opcode == aco_opcode::v_med3_f16) + return apply_clamp(ctx, instr, parent); else UNREACHABLE("unhandled opcode"); @@ -3868,7 +3721,10 @@ apply_output(opt_ctx& ctx, aco_ptr& instr) case aco_opcode::v_mul_f32: case aco_opcode::v_mul_f16: case aco_opcode::v_pk_mul_f16: - case aco_opcode::v_cvt_f16_f32: break; + case aco_opcode::v_mul_legacy_f32: + case aco_opcode::v_cvt_f16_f32: + case aco_opcode::v_med3_f32: + case aco_opcode::v_med3_f16: break; default: return false; } @@ -3928,8 +3784,7 @@ apply_output(opt_ctx& ctx, aco_ptr& instr) for (Definition& def : new_instr->definitions) { ctx.info[def.tempId()].parent_instr = new_instr; - ctx.info[def.tempId()].label &= - instr_mod_labels | canonicalized_labels | label_combined_instr; + ctx.info[def.tempId()].label &= canonicalized_labels | label_combined_instr; } instr.reset(); @@ -4132,11 +3987,6 @@ combine_instruction(opt_ctx& ctx, aco_ptr& instr) alu_propagate_temp_const(ctx, instr, true); } - if (instr->isVALU()) { - while (apply_omod_clamp(ctx, instr)) - ; - } - if (instr->isDPP()) return; diff --git a/src/amd/compiler/tests/test_optimizer.cpp b/src/amd/compiler/tests/test_optimizer.cpp index 579da333361..8abea38fc68 100644 --- a/src/amd/compiler/tests/test_optimizer.cpp +++ b/src/amd/compiler/tests/test_optimizer.cpp @@ -1360,8 +1360,8 @@ BEGIN_TEST(optimize.mad_mix.fma.basic) writeout(1, fadd(fmul(a, b), f2f32(c16))); /* omod/clamp check */ - //! v1: %res2_mul = v_fma_mix_f32 lo(%a16), %b, neg(0) - //! v1: %res2 = v_add_f32 %res2_mul, %c *2 + //! v1: %res2_fma = v_fma_mix_f32 lo(%a16), %b, %c + //! v1: %res2 = v_mul_f32 2.0, %res2_fma //! p_unit_test 2, %res2 writeout(2, bld.vop2(aco_opcode::v_mul_f32, bld.def(v1), Operand::c32(0x40000000), fadd(fmul(f2f32(a16), b), c)));