diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp index b78f135fc20..7eb629efee4 100644 --- a/src/amd/compiler/aco_optimizer.cpp +++ b/src/amd/compiler/aco_optimizer.cpp @@ -1273,6 +1273,20 @@ detect_clamp(Instruction* instr, unsigned* clamped_idx) } } +void +decrease_and_dce(opt_ctx& ctx, Temp tmp) +{ + assert(ctx.uses[tmp.id()]); + ctx.uses[tmp.id()]--; + Instruction* instr = ctx.info[tmp.id()].parent_instr; + if (is_dead(ctx.uses, instr)) { + for (const Operand& op : instr->operands) { + if (op.isTemp()) + decrease_and_dce(ctx, op.getTemp()); + } + } +} + void label_instruction(opt_ctx& ctx, aco_ptr& instr) { @@ -1961,24 +1975,6 @@ original_temp_id(opt_ctx& ctx, Temp tmp) return tmp.id(); } -void -decrease_op_uses_if_dead(opt_ctx& ctx, Instruction* instr) -{ - if (is_dead(ctx.uses, instr)) { - for (const Operand& op : instr->operands) { - if (op.isTemp()) - ctx.uses[op.tempId()]--; - } - } -} - -void -decrease_uses(opt_ctx& ctx, Instruction* instr) -{ - ctx.uses[instr->definitions[0].tempId()]--; - decrease_op_uses_if_dead(ctx, instr); -} - Operand copy_operand(opt_ctx& ctx, Operand op) { @@ -2265,7 +2261,7 @@ combine_xor_not(opt_ctx& ctx, aco_ptr& instr) instr->opcode = aco_opcode::v_xnor_b32; instr->operands[i] = copy_operand(ctx, op_instr->operands[0]); - decrease_uses(ctx, op_instr); + decrease_and_dce(ctx, op_instr->definitions[0].getTemp()); if (instr->operands[0].isOfType(RegType::vgpr)) std::swap(instr->operands[0], instr->operands[1]); if (!instr->operands[1].isOfType(RegType::vgpr)) @@ -2477,7 +2473,7 @@ combine_salu_lshl_add(opt_ctx& ctx, aco_ptr& instr) instr->operands[1] = instr->operands[!i]; instr->operands[0] = copy_operand(ctx, op2_instr->operands[0]); - decrease_uses(ctx, op2_instr); + decrease_and_dce(ctx, op2_instr->definitions[0].getTemp()); ctx.info[instr->definitions[0].tempId()].label = 0; instr->opcode = std::array{ @@ -3191,7 +3187,7 @@ combine_v_andor_not(opt_ctx& ctx, aco_ptr& instr) new_instr->definitions[0] = instr->definitions[0]; new_instr->pass_flags = instr->pass_flags; instr.reset(new_instr); - decrease_uses(ctx, op_instr); + decrease_and_dce(ctx, op_instr->definitions[0].getTemp()); ctx.info[instr->definitions[0].tempId()].label = 0; ctx.info[instr->definitions[0].tempId()].parent_instr = instr.get(); return true; @@ -3471,7 +3467,7 @@ combine_vop3p(opt_ctx& ctx, aco_ptr& instr) fma->pass_flags = instr->pass_flags; instr = std::move(fma); ctx.info[instr->definitions[0].tempId()].parent_instr = instr.get(); - decrease_uses(ctx, mul_instr); + decrease_and_dce(ctx, mul_instr->definitions[0].getTemp()); return; } }