aco: propagate swizzles when optimizing packed clamp & fma

Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/6680>
This commit is contained in:
Daniel Schürmann
2021-01-07 15:07:09 +01:00
committed by Marge Bot
parent 6ecbccfb23
commit 412291ddef
+34 -11
View File
@@ -2726,6 +2726,28 @@ bool combine_add_lshl(opt_ctx& ctx, aco_ptr<Instruction>& instr)
return false;
}
void propagate_swizzles(VOP3P_instruction* instr, uint8_t opsel_lo, uint8_t opsel_hi)
{
   /* Push a swizzle that is applied to an instruction's result down onto the
    * instruction's own operand modifiers, e.g.:
    *    result = a.xy + b.xx  ->  result.yx = a.yx + b.xx
    *
    * opsel_lo / opsel_hi are single-bit selectors describing the swizzle on
    * the result. If opsel_lo is 1, the instruction's low-half opsel and neg
    * modifiers are replaced by its previous high-half ones. If opsel_hi is 0,
    * the high-half modifiers are replaced by the previous low-half ones.
    * Otherwise the respective half is left untouched. */
   assert((opsel_lo & 1) == opsel_lo);
   assert((opsel_hi & 1) == opsel_hi);

   const bool lo_takes_hi = opsel_lo == 1;
   const bool hi_takes_lo = opsel_hi == 0;

   /* Snapshot the current modifiers first: both halves may be rewritten and
    * each rewrite must see the state from before this call. */
   const uint8_t saved_opsel_lo = instr->opsel_lo;
   const uint8_t saved_opsel_hi = instr->opsel_hi;
   bool saved_neg_lo[3];
   bool saved_neg_hi[3];
   for (unsigned idx = 0; idx < 3; idx++) {
      saved_neg_lo[idx] = instr->neg_lo[idx];
      saved_neg_hi[idx] = instr->neg_hi[idx];
   }

   if (lo_takes_hi)
      instr->opsel_lo = saved_opsel_hi;
   if (hi_takes_lo)
      instr->opsel_hi = saved_opsel_lo;

   for (unsigned idx = 0; idx < 3; idx++) {
      if (lo_takes_hi)
         instr->neg_lo[idx] = saved_neg_hi[idx];
      if (hi_takes_lo)
         instr->neg_hi[idx] = saved_neg_lo[idx];
   }
}
void combine_vop3p(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
{
VOP3P_instruction* vop3p = static_cast<VOP3P_instruction*>(instr.get());
@@ -2734,15 +2756,14 @@ void combine_vop3p(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
if (instr->opcode == aco_opcode::v_pk_mul_f16 &&
instr->operands[1].constantEquals(0x3C00) &&
vop3p->clamp &&
vop3p->opsel_lo == 0x0 &&
vop3p->opsel_hi == 0x1 &&
instr->operands[0].isTemp() &&
ctx.uses[instr->operands[0].tempId()] == 1) {
ssa_info& info = ctx.info[instr->operands[0].tempId()];
if (info.is_vop3p() && instr_info.can_use_output_modifiers[(int)info.instr->opcode]) {
Instruction* candidate = ctx.info[instr->operands[0].tempId()].instr;
static_cast<VOP3P_instruction*>(candidate)->clamp = true;
VOP3P_instruction* candidate = static_cast<VOP3P_instruction*>(ctx.info[instr->operands[0].tempId()].instr);
candidate->clamp = true;
propagate_swizzles(candidate, vop3p->opsel_lo, vop3p->opsel_hi);
std::swap(instr->definitions[0], candidate->definitions[0]);
ctx.info[candidate->definitions[0].tempId()].instr = candidate;
ctx.uses[instr->definitions[0].tempId()]--;
@@ -2794,6 +2815,7 @@ void combine_vop3p(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
Instruction* mul_instr = nullptr;
unsigned add_op_idx = 0;
uint8_t opsel_lo = 0, opsel_hi = 0;
uint32_t uses = UINT32_MAX;
/* find the 'best' mul instruction to combine with the add */
@@ -2809,16 +2831,14 @@ void combine_vop3p(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
!check_vop3_operands(ctx, 3, op))
continue;
/* opsel of mul needs to be .xy */
if (static_cast<VOP3P_instruction*>(instr.get())->opsel_lo & (1 << i) ||
!(static_cast<VOP3P_instruction*>(instr.get())->opsel_hi & (1 << i)))
continue;
/* no clamp allowed between mul and add */
if (static_cast<VOP3P_instruction*>(info.instr)->clamp)
continue;
mul_instr = info.instr;
add_op_idx = 1 - i;
opsel_lo = (vop3p->opsel_lo >> i) & 1;
opsel_hi = (vop3p->opsel_hi >> i) & 1;
uses = ctx.uses[instr->operands[i].tempId()];
}
@@ -2845,11 +2865,14 @@ void combine_vop3p(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
fma->neg_hi[i] = mul->neg_hi[i];
}
fma->operands[2] = op[2];
fma->clamp = vop3p->clamp;
fma->opsel_lo = mul->opsel_lo;
fma->opsel_hi = mul->opsel_hi;
propagate_swizzles(fma.get(), opsel_lo, opsel_hi);
fma->opsel_lo |= (vop3p->opsel_lo << (2 - add_op_idx)) & 0x4;
fma->opsel_hi |= (vop3p->opsel_hi << (2 - add_op_idx)) & 0x4;
fma->neg_lo[2] = vop3p->neg_lo[add_op_idx];
fma->neg_hi[2] = vop3p->neg_hi[add_op_idx];
fma->clamp = vop3p->clamp;
fma->opsel_lo = mul->opsel_lo | ((vop3p->opsel_lo << (2 - add_op_idx)) & 0x4);
fma->opsel_hi = mul->opsel_hi | ((vop3p->opsel_hi << (2 - add_op_idx)) & 0x4);
fma->neg_lo[1] = fma->neg_lo[1] ^ vop3p->neg_lo[1 - add_op_idx];
fma->neg_hi[1] = fma->neg_hi[1] ^ vop3p->neg_hi[1 - add_op_idx];
fma->definitions[0] = instr->definitions[0];