aco/optimizer: apply sgprs/extract with new helpers

Foz-DB Navi21:
Totals from 387 (0.49% of 79789) affected shaders:
MaxWaves: 7332 -> 7324 (-0.11%)
Instrs: 3156365 -> 3155691 (-0.02%); split: -0.02%, +0.00%
CodeSize: 17013948 -> 17014456 (+0.00%); split: -0.01%, +0.01%
VGPRs: 24768 -> 24776 (+0.03%)
Latency: 28569179 -> 28568183 (-0.00%); split: -0.00%, +0.00%
InvThroughput: 6530832 -> 6530566 (-0.00%); split: -0.00%, +0.00%
VClause: 90988 -> 90989 (+0.00%); split: -0.00%, +0.00%
Copies: 269074 -> 269060 (-0.01%); split: -0.01%, +0.01%
PreSGPRs: 22503 -> 22499 (-0.02%)
PreVGPRs: 22928 -> 22935 (+0.03%)
VALU: 2100245 -> 2099560 (-0.03%)

Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/35272>
This commit is contained in:
Georg Lehmann
2024-10-25 12:00:45 +02:00
committed by Marge Bot
parent 58163f65f0
commit 26da5cf8d9
2 changed files with 137 additions and 225 deletions

View File

@@ -1583,47 +1583,6 @@ pseudo_propagate_temp(opt_ctx& ctx, aco_ptr<Instruction>& instr, Temp temp, unsi
return true;
}
/* Tells whether SGPR operands may be applied to the given VALU instruction.
 * The caller is expected to have removed the DPP modifier beforehand.
 */
bool
can_apply_sgprs(opt_ctx& ctx, aco_ptr<Instruction>& instr)
{
   assert(instr->isVALU());

   /* SDWA instructions are rejected on anything older than GFX9. */
   const bool sdwa_before_gfx9 = instr->isSDWA() && ctx.program->gfx_level < GFX9;
   if (sdwa_before_gfx9)
      return false;

   /* Opcodes for which applying SGPRs is never permitted. */
   static const aco_opcode excluded_opcodes[] = {
      /* lane access */
      aco_opcode::v_readfirstlane_b32,
      aco_opcode::v_readlane_b32,
      aco_opcode::v_readlane_b32_e64,
      aco_opcode::v_writelane_b32,
      aco_opcode::v_writelane_b32_e64,
      aco_opcode::v_permlane16_b32,
      aco_opcode::v_permlanex16_b32,
      aco_opcode::v_permlane64_b32,
      /* interpolation */
      aco_opcode::v_interp_p1_f32,
      aco_opcode::v_interp_p2_f32,
      aco_opcode::v_interp_mov_f32,
      aco_opcode::v_interp_p1ll_f16,
      aco_opcode::v_interp_p1lv_f16,
      aco_opcode::v_interp_p2_legacy_f16,
      aco_opcode::v_interp_p2_f16,
      aco_opcode::v_interp_p2_hi_f16,
      aco_opcode::v_interp_p10_f32_inreg,
      aco_opcode::v_interp_p2_f32_inreg,
      aco_opcode::v_interp_p10_f16_f32_inreg,
      aco_opcode::v_interp_p2_f16_f32_inreg,
      aco_opcode::v_interp_p10_rtz_f16_f32_inreg,
      aco_opcode::v_interp_p2_rtz_f16_f32_inreg,
      /* WMMA */
      aco_opcode::v_wmma_f32_16x16x16_f16,
      aco_opcode::v_wmma_f32_16x16x16_bf16,
      aco_opcode::v_wmma_f16_16x16x16_f16,
      aco_opcode::v_wmma_bf16_16x16x16_bf16,
      aco_opcode::v_wmma_i32_16x16x16_iu8,
      aco_opcode::v_wmma_i32_16x16x16_iu4,
      aco_opcode::v_wmma_f32_16x16x16_fp8_fp8,
      aco_opcode::v_wmma_f32_16x16x16_fp8_bf8,
      aco_opcode::v_wmma_f32_16x16x16_bf8_fp8,
      aco_opcode::v_wmma_f32_16x16x16_bf8_bf8,
   };

   const aco_opcode opcode = instr->opcode;
   for (aco_opcode excluded : excluded_opcodes) {
      if (opcode == excluded)
         return false;
   }
   return true;
}
/* only covers special cases */
bool
pseudo_can_accept_constant(const aco_ptr<Instruction>& instr, unsigned operand)
@@ -2141,19 +2100,6 @@ does_fp_op_flush_denorms(opt_ctx& ctx, aco_opcode op)
}
}
/* Decides whether an fcanonicalize producing `tmp` can be skipped when `tmp`
 * is consumed as operand `idx` of `instr`.
 *
 * This is the case when `tmp` is already labelled canonicalized, when the
 * float mode keeps denormals for the size of `tmp`, or when the consuming
 * opcode both accepts input modifiers on that operand and flushes denormals.
 */
bool
can_eliminate_fcanonicalize(opt_ctx& ctx, aco_ptr<Instruction>& instr, Temp tmp, unsigned idx)
{
   if (ctx.info[tmp.id()].is_canonicalized())
      return true;

   /* 4-byte temporaries use the 32-bit denorm mode, everything else the 16/64-bit one. */
   const float_mode& mode = ctx.fp_mode;
   if ((tmp.bytes() == 4 ? mode.denorm32 : mode.denorm16_64) == fp_denorm_keep)
      return true;

   if (!can_use_input_modifiers(ctx.program->gfx_level, instr->opcode, idx))
      return false;

   return does_fp_op_flush_denorms(ctx, instr->opcode);
}
bool
can_eliminate_and_exec(opt_ctx& ctx, Temp tmp, unsigned pass_flags, bool allow_cselect = false)
{
@@ -2189,13 +2135,6 @@ can_eliminate_and_exec(opt_ctx& ctx, Temp tmp, unsigned pass_flags, bool allow_c
}
}
/* Returns true if `info` describes a plain copy as seen from operand `idx` of
 * `instr`: either a temp label, or an fcanonicalize label whose
 * canonicalization can be eliminated for this use.
 */
bool
is_copy_label(opt_ctx& ctx, aco_ptr<Instruction>& instr, ssa_info& info, unsigned idx)
{
   if (info.is_temp())
      return true;

   if (!info.is_fcanonicalize())
      return false;

   return can_eliminate_fcanonicalize(ctx, instr, info.temp, idx);
}
bool
is_op_canonicalized(opt_ctx& ctx, Operand op)
{
@@ -2310,6 +2249,12 @@ parse_operand(opt_ctx& ctx, Temp tmp, alu_opt_op& op_info, aco_type& type)
type.num_components = 1;
type.bit_size = tmp.bytes() * 8;
if (info.is_extract()) {
op_info.extract[0] = parse_extract(info.parent_instr);
op_info.op = info.parent_instr->operands[0];
return true;
}
if (info.is_constant_or_literal(type.bit_size)) {
op_info.op = get_constant_op(ctx, info, type.bit_size);
return true;
@@ -2330,6 +2275,10 @@ bool
combine_operand(opt_ctx& ctx, alu_opt_op& inner, const aco_type& inner_type,
const alu_opt_op& outer, const aco_type& outer_type, bool flushes_denorms)
{
/* Nothing to be gained by bothering with lane masks. */
if (inner_type.bit_size <= 1)
return false;
if (inner.op.size() != outer.op.size())
return false;
@@ -2397,7 +2346,7 @@ decrease_and_dce(opt_ctx& ctx, Temp tmp)
}
void
alu_propagate_temp_const(opt_ctx& ctx, aco_ptr<Instruction>& instr)
alu_propagate_temp_const(opt_ctx& ctx, aco_ptr<Instruction>& instr, bool uses_valid)
{
alu_opt_info info;
if (!alu_opt_gather_info(ctx, instr.get(), info))
@@ -2414,45 +2363,89 @@ alu_propagate_temp_const(opt_ctx& ctx, aco_ptr<Instruction>& instr)
instr->opcode == aco_opcode::v_min3_f32 || instr->opcode == aco_opcode::v_max3_f32 ||
instr->opcode == aco_opcode::v_med3_f32);
unsigned operand_mask = BITFIELD_MASK(info.operands.size());
bool progress = false;
alu_opt_info result_info;
for (unsigned i = 0; i < info.operands.size(); i++) {
while (info.operands[i].op.isTemp()) {
alu_opt_op outer;
aco_type outer_type;
if (!parse_operand(ctx, info.operands[i].op.getTemp(), outer, outer_type))
while (operand_mask) {
uint32_t i = UINT32_MAX;
uint32_t op_uses = UINT32_MAX;
u_foreach_bit (candidate, operand_mask) {
if (!info.operands[candidate].op.isTemp()) {
operand_mask &= ~BITFIELD_BIT(candidate);
continue;
}
if (!uses_valid) {
i = candidate;
break;
}
/* Applying SGPRs to VOP1 doesn't increase code size and DCE is helped by doing it earlier,
* otherwise we apply SGPRs later.
*/
if (info.operands[i].op.isOfType(RegType::vgpr) && outer.op.isOfType(RegType::sgpr) &&
!instr->isVOP1())
break;
alu_opt_op inner = info.operands[i];
aco_type inner_type = get_canonical_operand_type(info.opcode, i);
if (inner.f16_to_f32)
inner_type.bit_size = 16;
bool flushes_denorms = inner_type.base_type == aco_base_type_float && !gfx8_min_max;
if (!combine_operand(ctx, inner, inner_type, outer, outer_type, flushes_denorms))
break;
alu_opt_info info_copy = info;
info_copy.operands[i] = inner;
if (!alu_opt_info_is_valid(ctx, info_copy))
break;
bool has_lit = std::any_of(info_copy.operands.begin(), info_copy.operands.end(),
[](const alu_opt_op& op) { return op.op.isLiteral(); });
if (!had_lit && has_lit)
break;
result_info = info_copy;
info.operands[i] = inner;
progress = true;
unsigned new_uses = ctx.uses[info.operands[candidate].op.tempId()];
if (new_uses >= op_uses)
continue;
i = candidate;
op_uses = new_uses;
}
if (i == UINT32_MAX)
break;
alu_opt_op outer;
aco_type outer_type;
if (!parse_operand(ctx, info.operands[i].op.getTemp(), outer, outer_type)) {
operand_mask &= ~BITFIELD_BIT(i);
continue;
}
/* Applying SGPRs to VOP1 doesn't increase code size and DCE is helped by doing it earlier,
* otherwise we apply SGPRs later.
*/
bool valu_new_sgpr = info.operands[i].op.isOfType(RegType::vgpr) &&
outer.op.isOfType(RegType::sgpr) && !instr->isVOP1();
if ((valu_new_sgpr || ctx.info[info.operands[i].op.tempId()].is_extract()) && !uses_valid) {
operand_mask &= ~BITFIELD_BIT(i);
continue;
}
alu_opt_op inner = info.operands[i];
aco_type inner_type = get_canonical_operand_type(info.opcode, i);
if (inner.f16_to_f32)
inner_type.bit_size = 16;
bool flushes_denorms = inner_type.base_type == aco_base_type_float && !gfx8_min_max;
if (!combine_operand(ctx, inner, inner_type, outer, outer_type, flushes_denorms)) {
operand_mask &= ~BITFIELD_BIT(i);
continue;
}
alu_opt_info info_copy = info;
info_copy.operands[i] = inner;
if (!alu_opt_info_is_valid(ctx, info_copy)) {
operand_mask &= ~BITFIELD_BIT(i);
continue;
}
bool has_lit = std::any_of(info_copy.operands.begin(), info_copy.operands.end(),
[](const alu_opt_op& op) { return op.op.isLiteral(); });
if (!had_lit && has_lit) {
operand_mask &= ~BITFIELD_BIT(i);
continue;
}
bool valu_removed_sgpr = info.operands[i].op.isOfType(RegType::sgpr) &&
!inner.op.isOfType(RegType::sgpr) && instr->isVALU();
if (valu_removed_sgpr && uses_valid)
operand_mask = BITFIELD_MASK(info.operands.size());
if (uses_valid) {
if (inner.op.isTemp())
ctx.uses[inner.op.tempId()]++;
decrease_and_dce(ctx, info.operands[i].op.getTemp());
}
result_info = info_copy;
info.operands[i] = inner;
progress = true;
}
if (!progress)
@@ -2618,7 +2611,7 @@ label_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
/* SALU / VALU: propagate inline constants, temps, and imod */
if (instr->isSALU() || instr->isVALU()) {
alu_propagate_temp_const(ctx, instr);
alu_propagate_temp_const(ctx, instr, false);
}
/* if this instruction doesn't define anything, return */
@@ -3862,104 +3855,6 @@ combine_clamp(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode min, aco_opc
return false;
}
/* Rewrites temp operands of `instr` to read an SGPR directly when the operand's
 * ssa_info says it is a copy (see is_copy_label()) or a p_extract of an SGPR.
 *
 * The number of distinct SGPRs the instruction may read is limited by
 * `max_sgprs` (computed below); a candidate that would add a new SGPR beyond
 * that limit is skipped. Candidates are processed in order of increasing use
 * count, and ctx.uses is updated for every operand that gets replaced.
 */
void
apply_sgprs(opt_ctx& ctx, aco_ptr<Instruction>& instr)
{
/* 64-bit shifts are excluded from the GFX10+ two-SGPR allowance below. */
bool is_shift64 = instr->opcode == aco_opcode::v_lshlrev_b64_e64 ||
instr->opcode == aco_opcode::v_lshlrev_b64 ||
instr->opcode == aco_opcode::v_lshrrev_b64 ||
instr->opcode == aco_opcode::v_ashrrev_i64;
/* find candidates and create the set of sgprs already read */
unsigned sgpr_ids[2] = {0, 0};
uint32_t operand_mask = 0;
bool has_literal = false;
for (unsigned i = 0; i < instr->operands.size(); i++) {
if (instr->operands[i].isLiteral())
has_literal = true;
if (!instr->operands[i].isTemp())
continue;
/* Record up to two distinct SGPR ids already read; id 0 marks an empty slot. */
if (instr->operands[i].getTemp().type() == RegType::sgpr) {
if (instr->operands[i].tempId() != sgpr_ids[0])
sgpr_ids[!!sgpr_ids[0]] = instr->operands[i].tempId();
}
ssa_info& info = ctx.info[instr->operands[i].tempId()];
/* Candidate: operand is a copy of an SGPR ... */
if (is_copy_label(ctx, instr, info, i) && info.temp.type() == RegType::sgpr)
operand_mask |= 1u << i;
/* ... or an extract whose source is an SGPR. */
if (info.is_extract() && info.parent_instr->operands[0].getTemp().type() == RegType::sgpr)
operand_mask |= 1u << i;
}
/* One SGPR by default, two on GFX10+ (except 64-bit shifts); a literal
 * operand takes away one of the available slots.
 */
unsigned max_sgprs = 1;
if (ctx.program->gfx_level >= GFX10 && !is_shift64)
max_sgprs = 2;
if (has_literal)
max_sgprs--;
unsigned num_sgprs = !!sgpr_ids[0] + !!sgpr_ids[1];
/* keep on applying sgprs until there is nothing left to be done */
while (operand_mask) {
uint32_t sgpr_idx = 0;
uint32_t sgpr_info_id = 0;
uint32_t mask = operand_mask;
/* choose a sgpr */
/* The candidate with the fewest uses wins; on ties the lowest operand index is kept. */
while (mask) {
unsigned i = u_bit_scan(&mask);
uint16_t uses = ctx.uses[instr->operands[i].tempId()];
if (sgpr_info_id == 0 || uses < ctx.uses[sgpr_info_id]) {
sgpr_idx = i;
sgpr_info_id = instr->operands[i].tempId();
}
}
operand_mask &= ~(1u << sgpr_idx);
ssa_info& info = ctx.info[sgpr_info_id];
/* For an extract the SGPR is the extract's source, otherwise the copied temp. */
Temp sgpr = info.is_extract() ? info.parent_instr->operands[0].getTemp() : info.temp;
bool new_sgpr = sgpr.id() != sgpr_ids[0] && sgpr.id() != sgpr_ids[1];
if (new_sgpr && num_sgprs >= max_sgprs)
continue;
/* Operand 0 is about to become an SGPR, so the DPP format bit is dropped.
 * NOTE(review): this happens before the extract check below, which may
 * still `continue` without applying anything - confirm that is harmless.
 */
if (sgpr_idx == 0)
instr->format = withoutDPP(instr->format);
if (sgpr_idx == 1 && instr->isDPP())
continue;
if (sgpr_idx == 0 || instr->isVOP3() || instr->isSDWA() || instr->isVOP3P() ||
info.is_extract()) {
/* can_apply_extract() checks SGPR encoding restrictions */
if (info.is_extract() && can_apply_extract(ctx, instr, sgpr_idx, info))
apply_extract(ctx, instr, sgpr_idx, info);
else if (info.is_extract())
continue;
instr->operands[sgpr_idx] = Operand(sgpr);
} else if (can_swap_operands(instr, &instr->opcode) && !instr->valu().opsel[sgpr_idx]) {
/* Move the SGPR into operand 0 by swapping it with the current operand 0. */
instr->operands[sgpr_idx] = instr->operands[0];
instr->operands[0] = Operand(sgpr);
instr->valu().opsel[0].swap(instr->valu().opsel[sgpr_idx]);
/* swap bits using a 4-entry LUT */
/* NOTE(review): the shift amount is `operand_mask & 0x3`, i.e. 0..3 bits,
 * while 0x3120 reads like a table of four nibbles, which would need a
 * shift of 4 * index - confirm the intended behavior.
 */
uint32_t swapped = (0x3120 >> (operand_mask & 0x3)) & 0xf;
operand_mask = (operand_mask & ~0x3) | swapped;
} else if (can_use_VOP3(ctx, instr) && !info.is_extract()) {
/* Last resort: switch to the VOP3 format and replace the operand in place. */
instr->format = asVOP3(instr->format);
instr->operands[sgpr_idx] = Operand(sgpr);
} else {
continue;
}
if (new_sgpr)
sgpr_ids[num_sgprs++] = sgpr.id();
/* One use moves from the replaced temp to the SGPR that is now read. */
ctx.uses[sgpr_info_id]--;
ctx.uses[sgpr.id()]++;
/* TODO: handle when it's a VGPR */
/* If the new SGPR is itself labelled as an extract/temp of SGPR type,
 * queue this operand again so the next iteration can look through it too.
 */
if ((ctx.info[sgpr.id()].label & (label_extract | label_temp)) &&
ctx.info[sgpr.id()].temp.type() == RegType::sgpr)
operand_mask |= 1u << sgpr_idx;
}
}
bool
interp_can_become_fma(opt_ctx& ctx, aco_ptr<Instruction>& instr)
{
@@ -4703,38 +4598,19 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
if (instr->definitions.empty() || is_dead(ctx.uses, instr.get()))
return;
for (const Definition& def : instr->definitions) {
ssa_info& info = ctx.info[def.tempId()];
if (info.is_extract() && ctx.uses[def.tempId()] > 4)
info.label &= ~label_extract;
}
if (instr->isVALU() || instr->isSALU()) {
/* Apply SDWA. Do this after label_instruction() so it can remove
* label_extract if not all instructions can take SDWA. */
for (unsigned i = 0; i < instr->operands.size(); i++) {
Operand& op = instr->operands[i];
if (!op.isTemp())
continue;
ssa_info& info = ctx.info[op.tempId()];
if (!info.is_extract())
continue;
/* if there are that many uses, there are likely better combinations */
// TODO: delay applying extract to a point where we know better
if (ctx.uses[op.tempId()] > 4) {
info.label &= ~label_extract;
continue;
}
if (info.is_extract() &&
(info.parent_instr->operands[0].getTemp().type() == RegType::vgpr ||
instr->operands[i].getTemp().type() == RegType::sgpr) &&
can_apply_extract(ctx, instr, i, info)) {
/* Increase use count of the extract's operand if the extract still has uses. */
apply_extract(ctx, instr, i, info);
if (--ctx.uses[instr->operands[i].tempId()])
ctx.uses[info.parent_instr->operands[0].tempId()]++;
instr->operands[i].setTemp(info.parent_instr->operands[0].getTemp());
}
}
alu_propagate_temp_const(ctx, instr, true);
}
if (instr->isVALU()) {
if (can_apply_sgprs(ctx, instr))
apply_sgprs(ctx, instr);
combine_mad_mix(ctx, instr);
while (apply_omod_clamp(ctx, instr) || combine_output_conversion(ctx, instr))
;

View File

@@ -689,7 +689,7 @@ BEGIN_TEST(optimize.sdwa.subdword_extract)
Operand::c32(8), Operand::c32(0)),
inputs[2]));
//! v1b: %res3 = v_or_b32 %a, %b dst_sel:ubyte0 dst_preserve src0_sel:uword0 src1_sel:ubyte2
//! v1b: %res3 = v_or_b32 %a, %b dst_sel:ubyte0 dst_preserve src0_sel:uword0 src1_sel:uword1
//! p_unit_test 3, %res3
writeout(3, bld.vop2(aco_opcode::v_or_b32, bld.def(v1b),
bld.pseudo(aco_opcode::p_extract, bld.def(v1b), a, Operand::c32(0),
@@ -703,9 +703,7 @@ BEGIN_TEST(optimize.sdwa.subdword_extract)
bld.pseudo(aco_opcode::p_extract, bld.def(v2b), a, Operand::c32(0),
Operand::c32(8), Operand::c32(1))));
/* TODO incremental conversion to sdwa loses information if zero extend is actually necessary */
//! v2b: %tmp5 = p_extract %b, 1, 8, 1
//! v2b: %res5 = v_or_b32 %a, %tmp5 dst_sel:uword0 dst_preserve src0_sel:sbyte0 src1_sel:uword0
//! v2b: %res5 = v_or_b32 %a, %b dst_sel:uword0 dst_preserve src0_sel:sbyte0 src1_sel:sbyte1
//! p_unit_test 5, %res5
writeout(5, bld.vop2(aco_opcode::v_or_b32, bld.def(v2b),
bld.pseudo(aco_opcode::p_extract, bld.def(v2b), a, Operand::c32(0),
@@ -845,3 +843,41 @@ BEGIN_TEST(optimize.sdwa.extract_vector)
finish_opt_test();
END_TEST
BEGIN_TEST(optimizer.sdwa.lanemask_extract)
/* Checks that a p_extract feeding the lane-mask operand (operand 2) of
 * v_cndmask_b32 is left alone: in every expected output below the p_extract
 * is still present and v_cndmask_b32 reads its result.
 *
 * NOTE(review): the other SDWA tests visible in this file are named
 * "optimize.sdwa.*" while this one uses the "optimizer." prefix - confirm
 * whether the difference is intended.
 */
for (unsigned i = GFX10; i <= GFX11; i++) {
/* GFX10_3 is skipped, so only GFX10 and GFX11 are exercised. */
if (i == GFX10_3)
continue;
//>> v1: %a:v[0], v1: %b:v[1], s1: %c:s[0] = p_startpgm
/* The trailing 32 is presumably the wave size - TODO confirm against setup_cs(). */
if (!setup_cs("v1 v1 s1", (amd_gfx_level)i, CHIP_UNKNOWN, "", 32))
continue;
Temp a = inputs[0];
Temp b = inputs[1];
Temp c = inputs[2];
/* Case 0: mask built with ext_ushort(c, 0). */
//! s1: %mask0, s1: %_:scc = p_extract %c, 0, 16, 0
//! v1: %res0 = v_cndmask_b32 %a, %b, %mask0
//! p_unit_test 0, %res0
Temp mask = ext_ushort(c, 0);
Temp bcsel = bld.vop2(aco_opcode::v_cndmask_b32, bld.def(v1), a, b, mask);
writeout(0, bcsel);
/* Case 1: mask built with ext_sbyte(c, 2). */
//! s1: %mask1, s1: %_:scc = p_extract %c, 2, 8, 1
//! v1: %res1 = v_cndmask_b32 %a, %b, %mask1
//! p_unit_test 1, %res1
mask = ext_sbyte(c, 2);
bcsel = bld.vop2(aco_opcode::v_cndmask_b32, bld.def(v1), a, b, mask);
writeout(1, bcsel);
/* Case 2: mask built with ext_ubyte(c, 3). */
//! s1: %mask2, s1: %_:scc = p_extract %c, 3, 8, 0
//! v1: %res2 = v_cndmask_b32 %a, %b, %mask2
//! p_unit_test 2, %res2
mask = ext_ubyte(c, 3);
bcsel = bld.vop2(aco_opcode::v_cndmask_b32, bld.def(v1), a, b, mask);
writeout(2, bcsel);
finish_opt_test();
}
END_TEST