aco/optimizer: use new helpers for omod/clamp

Also resolves the old TODO about using omod for multiplication
with negative 0.5, 2.0 or 4.0.

Foz-DB Navi21:
Totals from 5680 (5.82% of 97591) affected shaders:
MaxWaves: 111976 -> 111974 (-0.00%)
Instrs: 12013419 -> 12003946 (-0.08%); split: -0.08%, +0.00%
CodeSize: 65379508 -> 65364884 (-0.02%); split: -0.04%, +0.02%
VGPRs: 375840 -> 375856 (+0.00%); split: -0.00%, +0.01%
Latency: 85804600 -> 85784850 (-0.02%); split: -0.03%, +0.01%
InvThroughput: 20705698 -> 20692571 (-0.06%); split: -0.07%, +0.00%
VClause: 269772 -> 269606 (-0.06%); split: -0.09%, +0.03%
SClause: 324997 -> 324934 (-0.02%); split: -0.03%, +0.01%
Copies: 963255 -> 963264 (+0.00%); split: -0.06%, +0.06%
Branches: 326691 -> 326688 (-0.00%); split: -0.00%, +0.00%
PreSGPRs: 345106 -> 345109 (+0.00%)
PreVGPRs: 317681 -> 317729 (+0.02%)
VALU: 8372681 -> 8363374 (-0.11%); split: -0.11%, +0.00%
SALU: 1456669 -> 1456589 (-0.01%); split: -0.01%, +0.01%

Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/38658>
This commit is contained in:
Georg Lehmann
2025-01-10 08:32:30 +01:00
committed by Marge Bot
parent 69b5767eee
commit b82339d99e
2 changed files with 70 additions and 220 deletions

View File

@@ -61,19 +61,8 @@ enum Label {
label_canonicalized_fp16 = 1ull << 22,
label_canonicalized_fp32 = 1ull << 23,
label_canonicalized_fp64 = 1ull << 24,
/* label_{omod2,omod4,omod5,clamp} are used for both 16 and
* 32-bit operations but this doesn't cause any issues because
* we check the definition register class.
*/
label_omod2 = 1ull << 33,
label_omod4 = 1ull << 34,
label_omod5 = 1ull << 35,
label_clamp = 1ull << 36,
};
static constexpr uint64_t instr_mod_labels = label_omod2 | label_omod4 | label_omod5 | label_clamp;
static constexpr uint64_t input_mod_labels =
label_abs_fp16 | label_abs_fp32_64 | label_neg_fp16 | label_neg_fp32_64;
@@ -99,8 +88,6 @@ canonicalized_label(unsigned bit_size)
UNREACHABLE("unknown canonicalized size");
}
static_assert((instr_mod_labels & temp_labels) == 0, "labels cannot intersect");
static_assert((instr_mod_labels & val_labels) == 0, "labels cannot intersect");
static_assert((temp_labels & val_labels) == 0, "labels cannot intersect");
struct ssa_info {
@@ -108,7 +95,6 @@ struct ssa_info {
union {
uint64_t val;
Temp temp;
Instruction* mod_instr;
};
Instruction* parent_instr;
@@ -116,19 +102,14 @@ struct ssa_info {
void add_label(Label new_label)
{
if (new_label & instr_mod_labels) {
label &= ~instr_mod_labels;
label &= ~(temp_labels | val_labels); /* instr, temp and val alias */
}
if (new_label & temp_labels) {
label &= ~temp_labels;
label &= ~(instr_mod_labels | val_labels); /* instr, temp and val alias */
label &= ~val_labels; /* temp and val alias */
}
if (new_label & val_labels) {
label &= ~val_labels;
label &= ~(instr_mod_labels | temp_labels); /* instr, temp and val alias */
label &= ~temp_labels; /* temp and val alias */
}
label |= new_label;
@@ -194,46 +175,6 @@ struct ssa_info {
bool is_combined() { return label & label_combined_instr; }
void set_omod2(Instruction* mul)
{
if (label & temp_labels)
return;
add_label(label_omod2);
mod_instr = mul;
}
bool is_omod2() { return label & label_omod2; }
void set_omod4(Instruction* mul)
{
if (label & temp_labels)
return;
add_label(label_omod4);
mod_instr = mul;
}
bool is_omod4() { return label & label_omod4; }
void set_omod5(Instruction* mul)
{
if (label & temp_labels)
return;
add_label(label_omod5);
mod_instr = mul;
}
bool is_omod5() { return label & label_omod5; }
void set_clamp(Instruction* med3)
{
if (label & temp_labels)
return;
add_label(label_clamp);
mod_instr = med3;
}
bool is_clamp() { return label & label_clamp; }
void set_uniform_bitwise() { add_label(label_uniform_bitwise); }
bool is_uniform_bitwise() { return label & label_uniform_bitwise; }
@@ -1675,34 +1616,6 @@ gather_canonicalized(opt_ctx& ctx, aco_ptr<Instruction>& instr)
}
}
bool
can_use_VOP3(opt_ctx& ctx, const aco_ptr<Instruction>& instr)
{
if (instr->isVOP3())
return true;
if (instr->isVOP3P() || instr->isVINTERP_INREG())
return false;
if (instr->operands.size() && instr->operands[0].isLiteral() && ctx.program->gfx_level < GFX10)
return false;
if (instr->isSDWA())
return false;
if (instr->isDPP() && ctx.program->gfx_level < GFX11)
return false;
return instr->opcode != aco_opcode::v_madmk_f32 && instr->opcode != aco_opcode::v_madak_f32 &&
instr->opcode != aco_opcode::v_madmk_f16 && instr->opcode != aco_opcode::v_madak_f16 &&
instr->opcode != aco_opcode::v_fmamk_f32 && instr->opcode != aco_opcode::v_fmaak_f32 &&
instr->opcode != aco_opcode::v_fmamk_f16 && instr->opcode != aco_opcode::v_fmaak_f16 &&
instr->opcode != aco_opcode::v_permlane64_b32 &&
instr->opcode != aco_opcode::v_readlane_b32 &&
instr->opcode != aco_opcode::v_writelane_b32 &&
instr->opcode != aco_opcode::v_readfirstlane_b32;
}
bool
pseudo_propagate_temp(opt_ctx& ctx, aco_ptr<Instruction>& instr, Temp temp, unsigned index)
{
@@ -2417,7 +2330,7 @@ alu_propagate_temp_const(opt_ctx& ctx, aco_ptr<Instruction>& instr, bool uses_va
instr.reset(alu_opt_info_to_instr(ctx, result_info, instr.release()));
for (const Definition& def : instr->definitions)
ctx.info[def.tempId()].label &= instr_mod_labels | canonicalized_labels;
ctx.info[def.tempId()].label &= canonicalized_labels;
}
void
@@ -2856,8 +2769,7 @@ label_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
case aco_opcode::v_mul_f32:
case aco_opcode::v_mul_legacy_f32:
case aco_opcode::v_mul_f64:
case aco_opcode::v_mul_f64_e64: { /* omod */
/* TODO: try to move the negate/abs modifier to the consumer instead */
case aco_opcode::v_mul_f64_e64: {
bool uses_mods = instr->usesModifiers();
bool fp16 = instr->opcode == aco_opcode::v_mul_f16;
bool fp64 =
@@ -2875,19 +2787,13 @@ label_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
bool neg1 = constant == -1.0;
VALU_instruction* valu = &instr->valu();
if (valu->abs[!i] || valu->neg[!i] || valu->omod)
if (valu->abs[!i] || valu->neg[!i] || valu->omod || valu->clamp)
continue;
bool abs = valu->abs[i];
bool neg = neg1 ^ valu->neg[i];
Temp other = instr->operands[i].getTemp();
if (valu->clamp) {
if (!abs && !neg && other.type() == RegType::vgpr)
ctx.info[other.id()].set_clamp(instr.get());
continue;
}
if (abs && neg && other.type() == RegType::vgpr)
ctx.info[instr->definitions[0].tempId()].set_neg_abs(other, bit_size);
else if (abs && !neg && other.type() == RegType::vgpr)
@@ -2900,37 +2806,17 @@ label_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
else
ctx.info[instr->definitions[0].tempId()].set_fcanonicalize(other, bit_size);
}
} else if (uses_mods || (instr->definitions[0].isSZPreserve() &&
instr->opcode != aco_opcode::v_mul_legacy_f32)) {
continue; /* omod uses a legacy multiplication. */
} else if (instr->operands[!i].constantValue64() == 0u &&
} else if (!uses_mods && instr->operands[!i].constantValue64() == 0u &&
((!instr->definitions[0].isNaNPreserve() &&
!instr->definitions[0].isInfPreserve()) ||
!instr->definitions[0].isInfPreserve() &&
!instr->definitions[0].isSZPreserve()) ||
instr->opcode == aco_opcode::v_mul_legacy_f32)) {
ctx.info[instr->definitions[0].tempId()].set_constant(0u);
} else if (denorm_mode != fp_denorm_flush) {
/* omod has no effect if denormals are enabled. */
continue;
} else if (constant == 2.0) {
ctx.info[instr->operands[i].tempId()].set_omod2(instr.get());
} else if (constant == 4.0) {
ctx.info[instr->operands[i].tempId()].set_omod4(instr.get());
} else if (constant == 0.5) {
ctx.info[instr->operands[i].tempId()].set_omod5(instr.get());
} else {
continue;
}
break;
}
break;
}
case aco_opcode::v_med3_f16:
case aco_opcode::v_med3_f32: { /* clamp */
unsigned idx;
if (detect_clamp(instr.get(), &idx) && !instr->valu().abs && !instr->valu().neg)
ctx.info[instr->operands[idx].tempId()].set_clamp(instr.get());
break;
}
case aco_opcode::s_not_b32:
case aco_opcode::s_not_b64:
if (!instr->operands[0].isTemp()) {
@@ -3480,98 +3366,44 @@ use_absdiff:
return op_instr;
}
bool
interp_can_become_fma(opt_ctx& ctx, aco_ptr<Instruction>& instr)
Instruction*
apply_clamp(opt_ctx& ctx, aco_ptr<Instruction>& instr, Instruction* parent)
{
if (instr->opcode != aco_opcode::v_interp_p2_f32_inreg)
return false;
unsigned idx;
if (!detect_clamp(instr.get(), &idx))
return nullptr;
instr->opcode = aco_opcode::v_fma_f32;
instr->format = Format::VOP3;
bool dpp_allowed = can_use_DPP(ctx.program->gfx_level, instr, false);
instr->opcode = aco_opcode::v_interp_p2_f32_inreg;
instr->format = Format::VINTERP_INREG;
aco_type type = instr_info.alu_opcode_infos[(int)instr->opcode].def_types[0];
return dpp_allowed;
}
if (!ctx.info[parent->definitions[0].tempId()].is_canonicalized(type.bit_size) &&
ctx.fp_mode.denorm32 != fp_denorm_keep)
return nullptr;
void
interp_p2_f32_inreg_to_fma_dpp(aco_ptr<Instruction>& instr)
{
static_assert(sizeof(DPP16_instruction) == sizeof(VINTERP_inreg_instruction),
"Invalid instr cast.");
instr->format = asVOP3(Format::DPP16);
instr->opcode = aco_opcode::v_fma_f32;
instr->dpp16().dpp_ctrl = dpp_quad_perm(2, 2, 2, 2);
instr->dpp16().row_mask = 0xf;
instr->dpp16().bank_mask = 0xf;
instr->dpp16().bound_ctrl = 0;
instr->dpp16().fetch_inactive = 1;
}
aco_type parent_type = instr_info.alu_opcode_infos[(int)parent->opcode].def_types[0];
/* apply omod / clamp modifiers if the def is used only once and the instruction can have modifiers */
bool
apply_omod_clamp(opt_ctx& ctx, aco_ptr<Instruction>& instr)
{
if (instr->definitions.empty() || ctx.uses[instr->definitions[0].tempId()] != 1 ||
!instr_info.alu_opcode_infos[(int)instr->opcode].output_modifiers)
return false;
if (!instr_info.alu_opcode_infos[(int)parent->opcode].output_modifiers ||
type.bit_size != parent_type.bit_size || parent_type.num_components != 1)
return nullptr;
bool can_vop3 = can_use_VOP3(ctx, instr);
bool is_mad_mix =
instr->opcode == aco_opcode::v_fma_mix_f32 || instr->opcode == aco_opcode::v_fma_mixlo_f16;
bool needs_vop3 = !instr->isSDWA() && !instr->isVINTERP_INREG() && !is_mad_mix;
if (needs_vop3 && !can_vop3)
return false;
alu_opt_info parent_info;
if (!alu_opt_gather_info(ctx, parent, parent_info))
return nullptr;
if (instr_info.classes[(int)instr->opcode] == instr_class::valu_pseudo_scalar_trans)
return false;
if (parent_info.uses_insert())
return nullptr;
/* SDWA omod is GFX9+. */
bool can_use_omod = (can_vop3 || ctx.program->gfx_level >= GFX9) && !instr->isVOP3P() &&
(!instr->isVINTERP_INREG() || interp_can_become_fma(ctx, instr));
alu_opt_info info;
if (!alu_opt_gather_info(ctx, instr.get(), info))
return nullptr;
ssa_info& def_info = ctx.info[instr->definitions[0].tempId()];
if (!backpropagate_input_modifiers(ctx, parent_info, info.operands[idx], type))
return nullptr;
uint64_t omod_labels = label_omod2 | label_omod4 | label_omod5;
if (!def_info.is_clamp() && !(can_use_omod && (def_info.label & omod_labels)))
return false;
/* if the omod/clamp instruction is dead, then the single user of this
* instruction is a different instruction */
if (!ctx.uses[def_info.mod_instr->definitions[0].tempId()])
return false;
if (def_info.mod_instr->definitions[0].bytes() != instr->definitions[0].bytes())
return false;
/* MADs/FMAs are created later, so we don't have to update the original add */
assert(!ctx.info[instr->definitions[0].tempId()].is_combined());
if (!def_info.is_clamp() && (instr->valu().clamp || instr->valu().omod))
return false;
if (needs_vop3)
instr->format = asVOP3(instr->format);
if (!def_info.is_clamp() && instr->opcode == aco_opcode::v_interp_p2_f32_inreg)
interp_p2_f32_inreg_to_fma_dpp(instr);
if (def_info.is_omod2())
instr->valu().omod = 1;
else if (def_info.is_omod4())
instr->valu().omod = 2;
else if (def_info.is_omod5())
instr->valu().omod = 3;
else if (def_info.is_clamp())
instr->valu().clamp = true;
instr->definitions[0].swapTemp(def_info.mod_instr->definitions[0]);
ctx.info[instr->definitions[0].tempId()].label &= label_clamp;
ctx.uses[def_info.mod_instr->definitions[0].tempId()]--;
ctx.info[instr->definitions[0].tempId()].parent_instr = instr.get();
ctx.info[def_info.mod_instr->definitions[0].tempId()].parent_instr = def_info.mod_instr;
return true;
parent_info.clamp = true;
parent_info.defs[0].setTemp(info.defs[0].getTemp());
if (!alu_opt_info_is_valid(ctx, parent_info))
return nullptr;
return alu_opt_info_to_instr(ctx, parent_info, parent);
}
/* Combine an p_insert (or p_extract, in some cases) instruction with instr.
@@ -3793,6 +3625,8 @@ apply_output_mul(opt_ctx& ctx, aco_ptr<Instruction>& instr, Instruction* parent)
if (!op_info_get_constant(ctx, info.operands[cidx], type, &constant))
return nullptr;
unsigned omod = 0;
for (unsigned i = 0; i < type.num_components; i++) {
double val = extract_float(constant, type.bit_size, i);
if (val < 0.0) {
@@ -3800,26 +3634,42 @@ apply_output_mul(opt_ctx& ctx, aco_ptr<Instruction>& instr, Instruction* parent)
info.operands[!cidx].neg[i] ^= true;
}
if (val != 1.0)
if (val == 1.0)
omod = 0;
else if (val == 2.0)
omod = 1;
else if (val == 4.0)
omod = 2;
else if (val == 0.5)
omod = 3;
else
return nullptr;
if (omod && type.num_components != 1)
return nullptr;
}
if ((info.omod || info.clamp) &&
!instr_info.alu_opcode_infos[(int)parent->opcode].output_modifiers)
if (omod && (info.omod || denorm_mode != fp_denorm_flush ||
(info.opcode != aco_opcode::v_mul_legacy_f32 && info.defs[0].isSZPreserve())))
return nullptr;
omod |= info.omod;
if ((omod || info.clamp) && !instr_info.alu_opcode_infos[(int)parent->opcode].output_modifiers)
return nullptr;
alu_opt_info parent_info;
if (!alu_opt_gather_info(ctx, parent, parent_info))
return nullptr;
if (parent_info.uses_insert() || (info.omod && (parent_info.omod || parent_info.clamp)))
if (parent_info.uses_insert() || (omod && (parent_info.omod || parent_info.clamp)))
return nullptr;
if (!backpropagate_input_modifiers(ctx, parent_info, info.operands[!cidx], type))
return nullptr;
parent_info.clamp |= info.clamp;
parent_info.omod |= info.omod;
parent_info.omod |= omod;
parent_info.insert = info.insert;
parent_info.defs[0].setTemp(info.defs[0].getTemp());
if (!alu_opt_info_is_valid(ctx, parent_info))
@@ -3843,10 +3693,13 @@ apply_output_impl(opt_ctx& ctx, aco_ptr<Instruction>& instr, Instruction* parent
return apply_s_abs(ctx, instr, parent);
else if (instr->opcode == aco_opcode::v_mul_f64 || instr->opcode == aco_opcode::v_mul_f64_e64 ||
instr->opcode == aco_opcode::v_mul_f32 || instr->opcode == aco_opcode::v_mul_f16 ||
instr->opcode == aco_opcode::v_pk_mul_f16)
instr->opcode == aco_opcode::v_pk_mul_f16 ||
instr->opcode == aco_opcode::v_mul_legacy_f32)
return apply_output_mul(ctx, instr, parent);
else if (instr->opcode == aco_opcode::v_cvt_f16_f32)
return apply_f2f16(ctx, instr, parent);
else if (instr->opcode == aco_opcode::v_med3_f32 || instr->opcode == aco_opcode::v_med3_f16)
return apply_clamp(ctx, instr, parent);
else
UNREACHABLE("unhandled opcode");
@@ -3868,7 +3721,10 @@ apply_output(opt_ctx& ctx, aco_ptr<Instruction>& instr)
case aco_opcode::v_mul_f32:
case aco_opcode::v_mul_f16:
case aco_opcode::v_pk_mul_f16:
case aco_opcode::v_cvt_f16_f32: break;
case aco_opcode::v_mul_legacy_f32:
case aco_opcode::v_cvt_f16_f32:
case aco_opcode::v_med3_f32:
case aco_opcode::v_med3_f16: break;
default: return false;
}
@@ -3928,8 +3784,7 @@ apply_output(opt_ctx& ctx, aco_ptr<Instruction>& instr)
for (Definition& def : new_instr->definitions) {
ctx.info[def.tempId()].parent_instr = new_instr;
ctx.info[def.tempId()].label &=
instr_mod_labels | canonicalized_labels | label_combined_instr;
ctx.info[def.tempId()].label &= canonicalized_labels | label_combined_instr;
}
instr.reset();
@@ -4132,11 +3987,6 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
alu_propagate_temp_const(ctx, instr, true);
}
if (instr->isVALU()) {
while (apply_omod_clamp(ctx, instr))
;
}
if (instr->isDPP())
return;

View File

@@ -1360,8 +1360,8 @@ BEGIN_TEST(optimize.mad_mix.fma.basic)
writeout(1, fadd(fmul(a, b), f2f32(c16)));
/* omod/clamp check */
//! v1: %res2_mul = v_fma_mix_f32 lo(%a16), %b, neg(0)
//! v1: %res2 = v_add_f32 %res2_mul, %c *2
//! v1: %res2_fma = v_fma_mix_f32 lo(%a16), %b, %c
//! v1: %res2 = v_mul_f32 2.0, %res2_fma
//! p_unit_test 2, %res2
writeout(2, bld.vop2(aco_opcode::v_mul_f32, bld.def(v1), Operand::c32(0x40000000),
fadd(fmul(f2f32(a16), b), c)));