diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp index 3a33e036470..3eadcd8d041 100644 --- a/src/amd/compiler/aco_optimizer.cpp +++ b/src/amd/compiler/aco_optimizer.cpp @@ -34,11 +34,8 @@ namespace { struct mad_info { aco_ptr add_instr; uint32_t mul_temp_id; - uint16_t literal_mask; - uint16_t fp16_mask; - mad_info(aco_ptr instr, uint32_t id) - : add_instr(std::move(instr)), mul_temp_id(id), literal_mask(0), fp16_mask(0) + mad_info(aco_ptr instr, uint32_t id) : add_instr(std::move(instr)), mul_temp_id(id) {} }; @@ -1150,7 +1147,8 @@ alu_opt_info_is_valid(opt_ctx& ctx, alu_opt_info& info) if (info.opcode == aco_opcode::s_fmac_f32) { for (unsigned i = 0; i < 2; i++) { if (lmask[i]) { - std::swap(info.operands[i], info.operands[2]); + std::swap(info.operands[i], info.operands[1]); + std::swap(info.operands[1], info.operands[2]); info.opcode = aco_opcode::s_fmamk_f32; break; } @@ -5536,101 +5534,6 @@ select_instruction(opt_ctx& ctx, aco_ptr& instr) ctx.info[instr->definitions[0].tempId()].parent_instr = instr.get(); mad_info = NULL; } - /* check literals */ - else if (!instr->isDPP() && !instr->isVOP3P() && instr->opcode != aco_opcode::v_fma_f64 && - instr->opcode != aco_opcode::v_mad_legacy_f32 && - instr->opcode != aco_opcode::v_fma_legacy_f32) { - /* FMA can only take literals on GFX10+ */ - if ((instr->opcode == aco_opcode::v_fma_f32 || instr->opcode == aco_opcode::v_fma_f16) && - ctx.program->gfx_level < GFX10) - return; - /* There are no v_fmaak_legacy_f16/v_fmamk_legacy_f16 and on chips where VOP3 can take - * literals (GFX10+), these instructions don't exist. - */ - if (instr->opcode == aco_opcode::v_fma_legacy_f16) - return; - - uint32_t literal_mask = 0; - uint32_t fp16_mask = 0; - uint32_t sgpr_mask = 0; - uint32_t vgpr_mask = 0; - uint32_t literal_uses = UINT32_MAX; - uint32_t literal_value = 0; - - /* Iterate in reverse to prefer v_madak/v_fmaak. */ - for (int i = 2; i >= 0; i--) { - Operand& op = instr->operands[i]; - if (!op.isTemp()) - continue; - if (ctx.info[op.tempId()].is_literal(get_operand_type(instr, i).constant_bits())) { - uint32_t new_literal = ctx.info[op.tempId()].val; - float value = uif(new_literal); - uint16_t fp16_val = _mesa_float_to_half(value); - bool is_denorm = (fp16_val & 0x7fff) != 0 && (fp16_val & 0x7fff) <= 0x3ff; - if (_mesa_half_to_float(fp16_val) == value && - (!is_denorm || (ctx.fp_mode.denorm16_64 & fp_denorm_keep_in))) - fp16_mask |= 1 << i; - - if (!literal_mask || literal_value == new_literal) { - literal_value = new_literal; - literal_uses = MIN2(literal_uses, ctx.uses[op.tempId()]); - literal_mask |= 1 << i; - continue; - } - } - sgpr_mask |= op.isOfType(RegType::sgpr) << i; - vgpr_mask |= op.isOfType(RegType::vgpr) << i; - } - - /* The constant bus limitations before GFX10 disallows SGPRs. */ - if (sgpr_mask && ctx.program->gfx_level < GFX10) - literal_mask = 0; - - /* Encoding needs a vgpr. */ - if (!vgpr_mask) - literal_mask = 0; - - /* v_madmk/v_fmamk needs a vgpr in the third source. */ - if (!(literal_mask & 0b100) && !(vgpr_mask & 0b100)) - literal_mask = 0; - - /* opsel with GFX11+ is the only modifier supported by fmamk/fmaak*/ - if (instr->valu().abs || instr->valu().neg || instr->valu().omod || instr->valu().clamp || - (instr->valu().opsel && ctx.program->gfx_level < GFX11)) - literal_mask = 0; - - if (instr->valu().opsel & ~vgpr_mask) - literal_mask = 0; - - /* We can't use three unique fp16 literals */ - if (fp16_mask == 0b111) - fp16_mask = 0b11; - - if ((instr->opcode == aco_opcode::v_fma_f32 || - (instr->opcode == aco_opcode::v_mad_f32 && !instr->definitions[0].isPrecise())) && - !instr->valu().omod && ctx.program->gfx_level >= GFX10 && - util_bitcount(fp16_mask) > std::max(util_bitcount(literal_mask), 1)) { - assert(ctx.program->dev.fused_mad_mix); - u_foreach_bit (i, fp16_mask) - ctx.uses[instr->operands[i].tempId()]--; - mad_info->fp16_mask = fp16_mask; - return; - } - - /* Limit the number of literals to apply to not increase the code - * size too much, but always apply literals for v_mad->v_madak - * because both instructions are 64-bit and this doesn't increase - * code size. - * TODO: try to apply the literals earlier to lower the number of - * uses below threshold - */ - if (literal_mask && (literal_uses < threshold || (literal_mask & 0b100))) { - u_foreach_bit (i, literal_mask) - ctx.uses[instr->operands[i].tempId()]--; - mad_info->literal_mask = literal_mask; - return; - } - } } /* Mark SCC needed, so the uniform boolean transformation won't swap the definitions @@ -5781,82 +5684,80 @@ select_instruction(opt_ctx& ctx, aco_ptr& instr) ctx.info[instr->definitions[0].tempId()].parent_instr = instr.get(); } - if (instr->isSDWA() || (instr->isVOP3() && ctx.program->gfx_level < GFX10) || - (instr->isVOP3P() && ctx.program->gfx_level < GFX10)) - return; /* some encodings can't ever take literals */ - - /* we do not apply the literals yet as we don't know if it is profitable */ - Operand current_literal(s1); - - unsigned literal_id = 0; - unsigned literal_uses = UINT32_MAX; - Operand literal(s1); - unsigned num_operands = 1; - if (instr->isSALU() || (ctx.program->gfx_level >= GFX10 && - (can_use_VOP3(ctx, instr) || instr->isVOP3P()) && !instr->isDPP())) - num_operands = instr->operands.size(); - /* catch VOP2 with a 3rd SGPR operand (e.g. v_cndmask_b32, v_addc_co_u32) */ - else if (instr->isVALU() && instr->operands.size() >= 3) + /* Check operands for whether we can apply constants or literals. */ + if (std::none_of(instr->operands.begin(), instr->operands.end(), + [&](const Operand& op) + { + if (!op.isTemp() || op.isFixed()) + return false; + auto& temp_info = ctx.info[op.tempId()]; + return temp_info.is_constant_or_literal(op.size() * 32); + })) return; - unsigned sgpr_ids[2] = {0, 0}; - bool is_literal_sgpr = false; - uint32_t mask = 0; - - /* choose a literal to apply */ - for (unsigned i = 0; i < num_operands; i++) { - Operand op = instr->operands[i]; - unsigned bits = get_operand_type(instr, i).constant_bits(); - - if (instr->isVALU() && op.isTemp() && op.getTemp().type() == RegType::sgpr && - op.tempId() != sgpr_ids[0]) - sgpr_ids[!!sgpr_ids[0]] = op.tempId(); - - if (op.isLiteral()) { - current_literal = op; - continue; - } else if (!op.isTemp() || !ctx.info[op.tempId()].is_literal(bits)) { - continue; - } - - if (!alu_can_accept_constant(instr, i)) - continue; - - if (ctx.uses[op.tempId()] < literal_uses) { - is_literal_sgpr = op.getTemp().type() == RegType::sgpr; - mask = 0; - literal = Operand::c32(ctx.info[op.tempId()].val); - literal_uses = ctx.uses[op.tempId()]; - literal_id = op.tempId(); - } - - mask |= (op.tempId() == literal_id) << i; - } - - /* don't go over the constant bus limit */ - bool is_shift64 = instr->opcode == aco_opcode::v_lshlrev_b64_e64 || - instr->opcode == aco_opcode::v_lshlrev_b64 || - instr->opcode == aco_opcode::v_lshrrev_b64 || - instr->opcode == aco_opcode::v_ashrrev_i64; - unsigned const_bus_limit = instr->isVALU() ? 1 : UINT32_MAX; - if (ctx.program->gfx_level >= GFX10 && !is_shift64) - const_bus_limit = 2; - - unsigned num_sgprs = !!sgpr_ids[0] + !!sgpr_ids[1]; - if (num_sgprs == const_bus_limit && !is_literal_sgpr) + alu_opt_info input_info; + if (!alu_opt_gather_info(ctx, instr.get(), input_info)) return; - if (literal_id && literal_uses < threshold && - (current_literal.isUndefined() || - (current_literal.size() == literal.size() && - current_literal.constantValue() == literal.constantValue()))) { - /* mark the literal to be applied */ - while (mask) { - unsigned i = u_bit_scan(&mask); - if (instr->operands[i].isTemp() && instr->operands[i].tempId() == literal_id) - ctx.uses[instr->operands[i].tempId()]--; + unsigned literal_mask = 0; + for (unsigned i = 0; i < input_info.operands.size(); i++) { + Operand op = input_info.operands[i].op; + if (!op.isTemp() || op.isFixed()) + continue; + auto& temp_info = ctx.info[op.tempId()]; + if (temp_info.is_constant_or_literal(op.size() * 32)) + literal_mask |= BITFIELD_BIT(i); + } + + alu_opt_info lit_info; + bool force_create = false; + unsigned lit_uses = threshold; + for (unsigned sub_mask = (~literal_mask + 1) & literal_mask; sub_mask; + sub_mask = ((sub_mask | ~literal_mask) + 1) & literal_mask) { + alu_opt_info candidate = input_info; + unsigned candidate_uses = UINT32_MAX; + u_foreach_bit (i, sub_mask) { + uint32_t tmpid = candidate.operands[i].op.tempId(); + candidate.operands[i].op = Operand::literal32(ctx.info[tmpid].val); + candidate_uses = MIN2(candidate_uses, ctx.uses[tmpid]); + } + if (!alu_opt_info_is_valid(ctx, candidate)) + continue; + + switch (candidate.opcode) { + case aco_opcode::v_fmaak_f32: + case aco_opcode::v_fmaak_f16: + case aco_opcode::v_madak_f32: + case aco_opcode::v_madak_f16: + /* This instruction won't be able to use fmac, so fmaak doesn't regress code size. */ + force_create = true; + break; + default: break; + } + + if (!force_create && util_bitcount(sub_mask) <= 1 && candidate_uses >= lit_uses) + continue; + lit_info = candidate; + lit_uses = candidate_uses; + + if (util_bitcount(sub_mask) > 1) { + force_create = true; + break; } } + if (!lit_info.operands.size()) + return; + + for (const auto& op_info : lit_info.operands) { + if (op_info.op.isTemp()) + ctx.uses[op_info.op.tempId()]++; + } + for (Operand op : instr->operands) { + if (op.isTemp()) + decrease_and_dce(ctx, op.getTemp()); + } + if (force_create || lit_uses == 1) + instr.reset(alu_opt_info_to_instr(ctx, lit_info, instr.release())); } static aco_opcode @@ -6027,78 +5928,24 @@ apply_literals(opt_ctx& ctx, aco_ptr& instr) if (!instr) return; - /* apply literals on MAD */ - if (!instr->definitions.empty() && ctx.info[instr->definitions[0].tempId()].is_mad()) { - mad_info* info = &ctx.mad_infos[ctx.info[instr->definitions[0].tempId()].val]; - const bool madak = (info->literal_mask & 0b100); - bool has_dead_literal = false; - u_foreach_bit (i, info->literal_mask | info->fp16_mask) - has_dead_literal |= ctx.uses[instr->operands[i].tempId()] == 0; - - if (has_dead_literal && info->fp16_mask) { - instr->format = Format::VOP3P; - instr->opcode = aco_opcode::v_fma_mix_f32; - - uint32_t literal = 0; - bool second = false; - u_foreach_bit (i, info->fp16_mask) { - float value = uif(ctx.info[instr->operands[i].tempId()].val); - literal |= _mesa_float_to_half(value) << (second * 16); - instr->valu().opsel_lo[i] = second; - instr->valu().opsel_hi[i] = true; - second = true; - } - - for (unsigned i = 0; i < 3; i++) { - if (info->fp16_mask & (1 << i)) - instr->operands[i] = Operand::literal32(literal); - } - - ctx.instructions.emplace_back(std::move(instr)); - return; - } - - if (has_dead_literal || madak) { - aco_opcode new_op = madak ? aco_opcode::v_madak_f32 : aco_opcode::v_madmk_f32; - if (instr->opcode == aco_opcode::v_fma_f32) - new_op = madak ? aco_opcode::v_fmaak_f32 : aco_opcode::v_fmamk_f32; - else if (instr->opcode == aco_opcode::v_mad_f16 || - instr->opcode == aco_opcode::v_mad_legacy_f16) - new_op = madak ? aco_opcode::v_madak_f16 : aco_opcode::v_madmk_f16; - else if (instr->opcode == aco_opcode::v_fma_f16) - new_op = madak ? aco_opcode::v_fmaak_f16 : aco_opcode::v_fmamk_f16; - - uint32_t literal = ctx.info[instr->operands[ffs(info->literal_mask) - 1].tempId()].val; - instr->format = Format::VOP2; - instr->opcode = new_op; - for (unsigned i = 0; i < 3; i++) { - if (info->literal_mask & (1 << i)) - instr->operands[i] = Operand::literal32(literal); - } - if (madak) { /* add literal -> madak */ - if (!instr->operands[1].isOfType(RegType::vgpr)) - instr->valu().swapOperands(0, 1); - } else { /* mul literal -> madmk */ - if (!(info->literal_mask & 0b10)) - instr->valu().swapOperands(0, 1); - instr->valu().swapOperands(1, 2); - } - ctx.instructions.emplace_back(std::move(instr)); - return; - } - } - - /* apply literals on other SALU/VALU */ + /* apply literals on SALU/VALU */ if (instr->isSALU() || instr->isVALU()) { - for (unsigned i = 0; i < instr->operands.size(); i++) { - Operand op = instr->operands[i]; - unsigned bits = get_operand_type(instr, i).constant_bits(); - if (op.isTemp() && ctx.info[op.tempId()].is_literal(bits) && ctx.uses[op.tempId()] == 0) { - Operand literal = Operand::literal32(ctx.info[op.tempId()].val); - instr->format = withoutDPP(instr->format); - if (instr->isVALU() && i > 0 && instr->format != Format::VOP3P) - instr->format = asVOP3(instr->format); - instr->operands[i] = literal; + for (const Operand& op : instr->operands) { + if (op.isTemp() && ctx.info[op.tempId()].is_literal(op.size() * 32) && + ctx.uses[op.tempId()] == 0) { + alu_opt_info info; + if (!alu_opt_gather_info(ctx, instr.get(), info)) + UNREACHABLE("We already check that we can apply lit"); + + for (auto& op_info : info.operands) { + if (op_info.op == op) + op_info.op = Operand::literal32(ctx.info[op.tempId()].val); + } + + if (!alu_opt_info_is_valid(ctx, info)) + UNREACHABLE("We already check that we can apply lit"); + instr.reset(alu_opt_info_to_instr(ctx, info, instr.release())); + break; } } } diff --git a/src/amd/compiler/tests/test_optimizer.cpp b/src/amd/compiler/tests/test_optimizer.cpp index 5d3340bca16..e82220dd51d 100644 --- a/src/amd/compiler/tests/test_optimizer.cpp +++ b/src/amd/compiler/tests/test_optimizer.cpp @@ -1701,14 +1701,14 @@ BEGIN_TEST(optimize.fmamix_two_literals) /* v_fma_mix_f32 is a fused mul/add, so it can't be used for precise separate mul/add. */ //~gfx10! v1: (precise)%res3 = v_madak_f32 %a, %c15, 0x40400000 - //~gfx10_3! v1: (precise)%res3_tmp = v_mul_f32 %a, 0x3fc00000 - //~gfx10_3! v1: %res3 = v_add_f32 %res3_tmp, 0x40400000 + //~gfx10_3! v1: (precise)%res3_tmp = v_mul_f32 0x3fc00000, %a + //~gfx10_3! v1: %res3 = v_add_f32 0x40400000, %res3_tmp //! p_unit_test 3, %res3 writeout(3, fadd(bld.precise().vop2(aco_opcode::v_mul_f32, bld.def(v1), a, c15), c30)); //~gfx10! v1: (precise)%res4 = v_madak_f32 %1, %c16, 0x40400000 - //~gfx10_3! v1: %res4_tmp = v_mul_f32 %a, 0x3fc00000 - //~gfx10_3! v1: (precise)%res4 = v_add_f32 %res4_tmp, 0x40400000 + //~gfx10_3! v1: %res4_tmp = v_mul_f32 0x3fc00000, %a + //~gfx10_3! v1: (precise)%res4 = v_add_f32 0x40400000, %res4_tmp //! p_unit_test 4, %res4 writeout(4, bld.precise().vop2(aco_opcode::v_add_f32, bld.def(v1), fmul(a, c15), c30)); @@ -1743,9 +1743,9 @@ BEGIN_TEST(optimize.fmamix_two_literals) writeout(7, fma(c15, c30, c45)); /* Modifiers must be preserved. */ - //! v1: %res8 = v_fma_mix_f32 -%a, lo(0x44804200), hi(0x44804200) + //! v1: %res8 = v_fma_mix_f32 |%a|, lo(0x44804200), hi(0x44804200) //! p_unit_test 8, %res8 - writeout(8, fma(fneg(a), c30, c45)); + writeout(8, fma(fabs(a), c30, c45)); //! v1: %res9 = v_fma_mix_f32 lo(0x44804200), |%a|, hi(0x44804200) //! p_unit_test 9, %res9 @@ -2026,7 +2026,7 @@ BEGIN_TEST(optimizer.trans_inline_constant) //! p_unit_test 5, %res5 writeout(5, bld.vop1(aco_opcode::v_rcp_f32, bld.def(v1), bld.copy(bld.def(s1), Operand::c32(0x3c00)))); - //! v2b: %res6 = v_rcp_f16 0x3f800000 + //! v2b: %res6 = v_rcp_f16 0 //! p_unit_test 6, %res6 writeout(6, bld.vop1(aco_opcode::v_rcp_f16, bld.def(v2b), bld.copy(bld.def(s1), Operand::c32(0x3f800000))));