From 8de89f4ffb19e6f371e68ee64ea66938530109d7 Mon Sep 17 00:00:00 2001 From: Georg Lehmann Date: Thu, 27 Mar 2025 21:20:39 +0100 Subject: [PATCH] aco/optimizer: add alu_opt_info helpers Reviewed-by: Rhys Perry Part-of: --- src/amd/compiler/aco_optimizer.cpp | 1146 ++++++++++++++++++++++++++++ 1 file changed, 1146 insertions(+) diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp index 7eb629efee4..3a33e036470 100644 --- a/src/amd/compiler/aco_optimizer.cpp +++ b/src/amd/compiler/aco_optimizer.cpp @@ -347,6 +347,1152 @@ struct opt_ctx { std::vector uses; }; /* Returns the operand type that instr_info records for operand idx of opcode. Multi-component 8-bit types are collapsed into a single component of the combined bit size with base type aco_base_type_none. */ +aco_type +get_canonical_operand_type(aco_opcode opcode, unsigned idx) +{ + aco_type type = instr_info.alu_opcode_infos[(int)opcode].op_types[idx]; + + if (type.bit_size == 8 && type.num_components > 1) { + /* Handling packed fp8/bf8 as non vector is easier. */ + type.bit_size *= type.num_components; + type.num_components = 1; + type.base_type = aco_base_type_none; + } + + return type; +} + /* Returns true when dpp_ctrl is one of dpp_row_sl(1..15), dpp_row_sr(1..15), dpp_wf_sl1, dpp_wf_sr1, dpp_row_bcast15 or dpp_row_bcast31, and false for every other control value. */ +bool +dpp16_ctrl_uses_bc(uint16_t dpp_ctrl) +{ + if (dpp_ctrl >= dpp_row_sl(1) && dpp_ctrl <= dpp_row_sl(15)) + return true; + if (dpp_ctrl >= dpp_row_sr(1) && dpp_ctrl <= dpp_row_sr(15)) + return true; + if (dpp_ctrl == dpp_wf_sl1 || dpp_ctrl == dpp_wf_sr1) + return true; + if (dpp_ctrl == dpp_row_bcast15 || dpp_ctrl == dpp_row_bcast31) + return true; + return false; +} + /* One ALU operand together with the modifiers applied to it: a sub-dword extract selection per component, neg/abs per component, f16_to_f32, dot_sext and the DPP state (dpp16/dpp8/bc/fi/dpp_ctrl). All flag members share storage with _modifiers through the anonymous union. NOTE(review): template arguments (e.g. on bitfield_array8 / bitfield_bool) appear to be missing from this text - confirm against the upstream file. */ +struct alu_opt_op { + Operand op; + SubdwordSel extract[2] = {SubdwordSel::dword, SubdwordSel::dword}; + union { + uint16_t _modifiers = 0; + bitfield_array8 neg; + bitfield_array8 abs; + bitfield_bool f16_to_f32; + bitfield_bool dot_sext; + bitfield_bool dpp16; + bitfield_bool dpp8; + bitfield_bool bc; + bitfield_bool fi; + }; + uint32_t dpp_ctrl = 0; + /* Assignment copies the whole object bytewise (sizeof(*this) bytes); the copy constructor below is implemented in terms of it. */ + alu_opt_op& operator=(const alu_opt_op& other) + { + memmove((void*)this, &other, sizeof(*this)); + + return *this; + } + + alu_opt_op() = default; + alu_opt_op(Operand _op) : op(_op) {}; + alu_opt_op(const alu_opt_op& other) { *this = other; } + /* Folds this operand's modifiers into its constant value for the given operand type. The operand must be a constant (asserted). For each component: applies the extract selection when the type is at most 4 bytes, optionally converts f16 to f32 (values whose low 15 bits are <= 0x3ff are first reduced to their sign bit unless ctx.fp_mode.denorm16_64 has fp_denorm_keep_in set), then applies abs by clearing the top bit and neg by flipping it, and packs the component into the result at bit_size * comp. */ + uint64_t 
constant_after_mods(opt_ctx& ctx, aco_type type) const + { + assert(this->op.isConstant()); + uint64_t res = 0; + for (unsigned comp = 0; comp < type.num_components; comp++) { + uint64_t part = this->op.constantValue64(); + /* 16bit negative int inline constants are sign extended, constantValue16 handles that. */ + if (this->op.bytes() == 2) + part = this->op.constantValue16(false) | (this->op.constantValue16(true) << 16); + + if (type.bytes() <= 4) { + SubdwordSel sel = this->extract[comp]; + part = part >> (sel.offset() * 8); + if (sel.size() < 4) { + part &= BITFIELD_MASK(sel.size() * 8); + part = sel.sign_extend() ? util_sign_extend(part, sel.size() * 8) : part; + } + } + + if (this->f16_to_f32) { + if (!(ctx.fp_mode.denorm16_64 & fp_denorm_keep_in)) { + uint32_t absv = part & 0x7fff; + if (absv <= 0x3ff) + part &= 0x8000; + } + part = fui(_mesa_half_to_float(part)); + } + + part &= BITFIELD64_MASK(type.bit_size - this->abs[comp]); + part ^= this->neg[comp] ? BITFIELD64_BIT(type.bit_size - 1) : 0; + res |= part << (type.bit_size * comp); + } + + return res; + } +}; + /* Editable description of one ALU instruction: its definitions, its operands (with their modifiers), opcode and format, an immediate, pass_flags, and the modifiers applied to defs[0]. NOTE(review): the template arguments of aco::small_vec appear to be missing from this text - confirm against the upstream file. */ +struct alu_opt_info { + aco::small_vec defs; + aco::small_vec operands; + aco_opcode opcode; + Format format; + uint32_t imm; + uint32_t pass_flags; /* exec id */ + + /* defs[0] modifiers */ + uint8_t omod; + bool clamp; + bool f32_to_f16; + SubdwordSel insert; + /* Swaps operands idx0 and idx1 and switches to the opcode returned by get_swapped_opcode(), provided that call does not return aco_opcode::num_opcodes. Returns whether the swap was performed; on failure nothing is modified. */ + bool try_swap_operands(unsigned idx0, unsigned idx1) + { + aco_opcode new_opcode = get_swapped_opcode(opcode, idx0, idx1); + if (new_opcode != aco_opcode::num_opcodes) { + opcode = new_opcode; + std::swap(operands[idx0], operands[idx1]); + return true; + } + return false; + } +}; + /* Returns true only for operand index 1 of v_writelane_b32 / v_readlane_b32 and their _e64 variants. optimize_constants() masks such constants with 0x3f. */ +bool +at_most_6lsb_used(aco_opcode op, unsigned idx) +{ + if (op == aco_opcode::v_writelane_b32 || op == aco_opcode::v_writelane_b32_e64 || + op == aco_opcode::v_readlane_b32 || op == aco_opcode::v_readlane_b32_e64) + return idx == 1; + return false; +} + /* Returns how many bytes of operand idx are treated as used, at most 4. The result is 4 when the canonical operand type reports 0 bytes, otherwise it is limited by the type's byte size. For operand 1 of v_lshlrev_b32 with a constant shift amount in operand 0 it is further limited to 2 (shift >= 16) or 1 (shift >= 24). ctx is not read by this function. */ +unsigned +bytes_used(opt_ctx& ctx, alu_opt_info& info, unsigned idx) +{ + unsigned used = 4; + aco_type 
type = get_canonical_operand_type(info.opcode, idx); + if (type.bytes() == 0) + return 4; + used = MIN2(used, type.bytes()); + if (info.opcode == aco_opcode::v_lshlrev_b32 && idx == 1 && info.operands[0].op.isConstant()) { + unsigned shift = info.operands[0].op.constantValue() & 0x1f; + if (shift >= 16) + used = MIN2(used, 2); + if (shift >= 24) + used = MIN2(used, 1); + } + return used; +} + /* Rewrites the constant operands of info. Constants whose type is wider than 4 bytes are skipped unchanged. For every other constant the modifiers are folded into the value via constant_after_mods(), then the value is expressed as an inline constant if possible (directly, through neg, through 16-bit sign extension, or for VOP3P through per-half extract/neg). Values that are not inline constants are packed into one 32-bit literal shared by all such operands; two 32-bit values that do not fit may instead be stored as two fp16 halves with f16_to_f32 set when the conversion is exact. Returns false when an operand type has an unsupported component count, has no constant_bits(), or the literal data does not fit; operands may already have been modified in that case. */ +bool +optimize_constants(opt_ctx& ctx, alu_opt_info& info) +{ + /* inline constants, pack literals */ + uint32_t literal = 0; + unsigned litbits_used = 0; + bool force_f2f32 = false; + for (unsigned i = 0; i < info.operands.size(); i++) { + auto& op_info = info.operands[i]; + assert(!op_info.op.isUndefined()); + if (!op_info.op.isConstant()) + continue; + + aco_type type = get_canonical_operand_type(info.opcode, i); + + if (type.num_components != 1 && type.num_components != 2) + return false; + if (!type.constant_bits()) + return false; + + if (type.bytes() > 4) + continue; + + /* remove modifiers on constants: apply extract, f2f32, abs, neg */ + assert(op_info.op.size() == 1); + uint32_t constant = op_info.constant_after_mods(ctx, type); + op_info.op = Operand(); + for (unsigned comp = 0; comp < type.num_components; comp++) { + op_info.extract[comp] = SubdwordSel(type.bit_size / 8, comp * type.bit_size / 8, false); + op_info.f16_to_f32 = false; + op_info.neg[comp] = false; + op_info.abs[comp] = false; + } + + if (at_most_6lsb_used(info.opcode, i)) + constant &= 0x3f; + + bool can_use_mods = can_use_input_modifiers(ctx.program->gfx_level, info.opcode, i); + + /* inline constants */ + if (type.num_components == 1) { + Operand new_op = + Operand::get_const(ctx.program->gfx_level, constant, type.constant_bits() / 8); + Operand neg_op = + Operand::get_const(ctx.program->gfx_level, BITFIELD_BIT(type.bit_size - 1) ^ constant, + type.constant_bits() / 8); + Operand sext_op = Operand::get_const(ctx.program->gfx_level, 0xffff0000 | constant, + type.constant_bits() / 8); + if 
(!new_op.isLiteral()) { + op_info.op = new_op; + } else if (can_use_mods && !neg_op.isLiteral()) { + op_info.op = neg_op; + op_info.neg[0] = true; + } else if (type.bit_size == 16 && !sext_op.isLiteral()) { + op_info.op = sext_op; + } + // TODO opsel? + } else if (info.format == Format::VOP3P) { + assert(!can_use_mods || type.constant_bits() == 16); + unsigned num_methods = (type.constant_bits() == 32 ? 5 : 1); + for (unsigned hi = 0; op_info.op.isUndefined() && hi < 2; hi++) { + for (unsigned negate = 0; + op_info.op.isUndefined() && (negate <= unsigned(can_use_mods)); negate++) { + for (unsigned method = 0; op_info.op.isUndefined() && method < num_methods; + method++) { + uint32_t candidate = ((constant >> (hi * 16)) & 0xffff) ^ (negate ? 0x8000 : 0); + switch (method) { + case 0: break; /* try directly as constant */ + case 1: candidate |= 0xffff0000; break; /* sign extend */ + case 2: candidate |= 0x3e220000; break; /* 0.5pi */ + case 3: candidate = (candidate << 16); break; /* high half */ + case 4: candidate = (candidate << 16) | 0xf983; break; /* high half, 0.5pi. */ + default: UNREACHABLE("impossible"); + } + Operand new_op = Operand::get_const(ctx.program->gfx_level, candidate, + type.constant_bits() / 8); + if (new_op.isLiteral()) + continue; + + for (unsigned opsel = 0; op_info.op.isUndefined() && opsel < 2; opsel++) { + uint16_t other = constant >> (!hi * 16); + uint16_t abs_mask = 0xffffu >> unsigned(can_use_mods); + if ((new_op.constantValue16(opsel) & abs_mask) != (other & abs_mask)) + continue; + op_info.op = new_op; + op_info.extract[hi] = method >= 3 ? SubdwordSel::uword1 : SubdwordSel::uword0; + op_info.extract[!hi] = opsel ? 
SubdwordSel::uword1 : SubdwordSel::uword0; + op_info.neg[hi] = negate; + op_info.neg[!hi] = new_op.constantValue16(opsel) ^ other; + } + } + } + } + } + + /* we found an inline constant */ + if (!op_info.op.isUndefined()) + continue; + + bool use_swizzle = type.num_components == 2 && info.format == Format::VOP3P; + bool try_neg = can_use_mods && (type.num_components == 1 || use_swizzle); + unsigned comp_bits = use_swizzle ? type.bit_size : type.bytes() * 8; + assert(comp_bits == 32 || comp_bits == 16); + uint32_t abs_mask = BITFIELD_MASK(comp_bits - try_neg); + for (unsigned comp = 0; comp <= unsigned(use_swizzle); comp++) { + uint32_t part = constant >> (comp * comp_bits) & BITFIELD_MASK(comp_bits); + + /* Try to re-use another literal, or part of it. */ + bool found_part = false; + for (unsigned litcomp = 0; litcomp < (litbits_used / comp_bits); litcomp++) { + uint32_t litpart = literal >> (litcomp * comp_bits) & BITFIELD_MASK(comp_bits); + if ((litpart & abs_mask) == (part & abs_mask)) { + op_info.neg[comp] = litpart ^ part; + op_info.extract[comp] = SubdwordSel(comp_bits / 8, litcomp * (comp_bits / 8), false); + found_part = true; + } + } + + if (found_part) + continue; + + /* If there isn't enough space for more literal data, try to use fp16 or return false. 
*/ + litbits_used = align(litbits_used, comp_bits); + if (litbits_used + comp_bits > 32) { + if (comp_bits == 32 && !force_f2f32) { + float f32s[] = {uif(literal), uif(constant)}; + literal = 0; + for (unsigned fltidx = 0; fltidx < 2; fltidx++) { + uint32_t fp16_val = _mesa_float_to_half(f32s[fltidx]); + bool is_denorm = (fp16_val & 0x7fff) != 0 && (fp16_val & 0x7fff) <= 0x3ff; + if (_mesa_half_to_float(fp16_val) != f32s[fltidx] || + (is_denorm && !(ctx.fp_mode.denorm16_64 & fp_denorm_keep_in))) + return false; + literal |= fp16_val << (fltidx * 16); + } + force_f2f32 = true; + op_info.extract[0] = SubdwordSel::uword1; + break; + } + return false; + } + + literal |= part << litbits_used; + op_info.extract[comp] = SubdwordSel(comp_bits / 8, litbits_used / 8, false); + litbits_used += comp_bits; + } + } + + for (auto& op_info : info.operands) { + if (!op_info.op.isUndefined()) + continue; + op_info.op = Operand::literal32(literal); + op_info.f16_to_f32 = force_f2f32; + } + + return true; +} + /* Returns the bitwise OR of the two Format values. */ +Format +format_combine(Format f1, Format f2) +{ + return (Format)((uint32_t)f1 | (uint32_t)f2); +} + /* Returns true when every bit set in f2 is also set in f1. */ +bool +format_is(Format f1, Format f2) +{ + return ((Format)((uint32_t)f1 & (uint32_t)f2)) == f2; +} + +/* Determine if this alu_opt_info can be represented by a valid ACO IR instruction. + * info is modified to not duplicate work when it's converted to an ACO IR instruction. + * If false is returned, info must no longer be used. + */ +bool +alu_opt_info_is_valid(opt_ctx& ctx, alu_opt_info& info) +{ + info.format = instr_info.format[(int)info.opcode]; + + /* remove dpp if possible, abort in some unsupported cases (bc with sgpr, constant.) 
*/ + for (auto& op_info : info.operands) { + if (!op_info.dpp16 && !op_info.dpp8) + continue; + if (op_info.op.isOfType(RegType::vgpr)) + continue; + /* bc=0: undefined if inactive read (lane disabled, but that's not expressed in SSA) + * if fi=1, bc only matters for a few dpp16 options + */ + if (op_info.bc && (!op_info.fi || (op_info.dpp16 && dpp16_ctrl_uses_bc(op_info.dpp_ctrl)))) + return false; + op_info.dpp16 = false; + op_info.dpp8 = false; + } + + /* if mul, push neg to constant, eliminate double negate */ + switch (info.opcode) { + case aco_opcode::v_mul_f64_e64: + case aco_opcode::v_mul_f64: + case aco_opcode::v_mul_f32: + case aco_opcode::v_mul_legacy_f32: + case aco_opcode::v_mul_f16: + case aco_opcode::v_mad_f32: + case aco_opcode::v_mad_legacy_f32: + case aco_opcode::v_mad_f16: + case aco_opcode::v_mad_legacy_f16: + case aco_opcode::v_fma_f64: + case aco_opcode::v_fma_f32: + case aco_opcode::v_fma_legacy_f32: + case aco_opcode::v_fma_f16: + case aco_opcode::v_fma_legacy_f16: + case aco_opcode::v_fma_mix_f32: + case aco_opcode::v_fma_mixlo_f16: + case aco_opcode::v_pk_mul_f16: + case aco_opcode::v_pk_fma_f16: + case aco_opcode::s_mul_f32: + case aco_opcode::s_mul_f16: + case aco_opcode::s_fmac_f32: + case aco_opcode::s_fmac_f16: + for (unsigned comp = 0; comp < 2; comp++) { + for (unsigned i = 0; i < 2; i++) { + if (info.operands[!i].op.isConstant() || info.operands[!i].neg[comp]) { + info.operands[!i].neg[comp] ^= info.operands[i].neg[comp]; + info.operands[i].neg[comp] = false; + } + } + } + break; + default: break; + } + + if (!optimize_constants(ctx, info)) + return false; + + /* check constant bus limit */ + bool is_salu = false; + switch (info.format) { + case Format::SOPC: + case Format::SOPK: + case Format::SOP1: + case Format::SOP2: + case Format::SOPP: is_salu = true; break; + default: break; + } + int constant_limit = is_salu ? INT_MAX : (ctx.program->gfx_level >= GFX10 ? 
2 : 1); + + switch (info.opcode) { + case aco_opcode::v_writelane_b32: + case aco_opcode::v_writelane_b32_e64: constant_limit = INT_MAX; break; + case aco_opcode::v_lshlrev_b64: + case aco_opcode::v_lshlrev_b64_e64: + case aco_opcode::v_lshrrev_b64: + case aco_opcode::v_ashrrev_i64: constant_limit = 1; break; + default: break; + } + + for (unsigned i = 0; i < info.operands.size(); i++) { + const Operand& op = info.operands[i].op; + if (!op.isLiteral() && !op.isOfType(RegType::sgpr)) + continue; + + constant_limit--; + for (unsigned j = 0; j < i; j++) { + const Operand& other = info.operands[j].op; + if (op == other) { + constant_limit++; + break; + } else if (op.isLiteral() && other.isLiteral()) { + return false; + } + } + } + + if (constant_limit < 0) + return false; + + /* apply extract. */ + if (info.opcode == aco_opcode::s_pack_ll_b32_b16) { + if (info.operands[0].extract[0].size() < 2 || info.operands[1].extract[0].size() < 2) + return false; + if (info.operands[0].extract[0].offset() == 2 && info.operands[1].extract[0].offset() == 2) { + info.opcode = aco_opcode::s_pack_hh_b32_b16; + } else if (info.operands[0].extract[0].offset() == 0 && + info.operands[1].extract[0].offset() == 2) { + info.opcode = aco_opcode::s_pack_lh_b32_b16; + } else if (info.operands[0].extract[0].offset() == 2 && + info.operands[1].extract[0].offset() == 0) { + if (ctx.program->gfx_level < GFX11) /* TODO try shifting constant */ + return false; + info.opcode = aco_opcode::s_pack_hl_b32_b16; + } + info.operands[0].extract[0] = SubdwordSel::dword; + info.operands[1].extract[0] = SubdwordSel::dword; + } + + for (unsigned i = 0; i < info.operands.size(); i++) { + aco_type type = get_canonical_operand_type(info.opcode, i); + if (type.bit_size == 16 && type.num_components == 2) { + for (unsigned comp = 0; comp < 2; comp++) { + SubdwordSel sel = info.operands[i].extract[comp]; + if (sel.size() < 2) + return false; + if (info.format != Format::VOP3P && sel.offset() != 2 * comp) + return 
false; + } + continue; + } + SubdwordSel sel = info.operands[i].extract[0]; + if (sel.size() == 4) { + continue; + } else if (info.operands[i].f16_to_f32 && sel.size() < 2) { + return false; + } else if (info.operands[i].f16_to_f32 && sel.size() == 2) { + continue; + } else if (sel.offset() == 0 && sel.size() >= bytes_used(ctx, info, i)) { + info.operands[i].extract[0] = SubdwordSel::dword; + } else if ((info.opcode == aco_opcode::v_cvt_f32_u32 || + info.opcode == aco_opcode::v_cvt_f32_i32) && + sel.size() == 1 && !sel.sign_extend()) { + switch (sel.offset()) { + case 0: info.opcode = aco_opcode::v_cvt_f32_ubyte0; break; + case 1: info.opcode = aco_opcode::v_cvt_f32_ubyte1; break; + case 2: info.opcode = aco_opcode::v_cvt_f32_ubyte2; break; + case 3: info.opcode = aco_opcode::v_cvt_f32_ubyte3; break; + default: UNREACHABLE("invalid SubdwordSel"); + } + info.operands[i].extract[0] = SubdwordSel::dword; + continue; + } else if (info.opcode == aco_opcode::v_mul_u32_u24 && ctx.program->gfx_level >= GFX10 && + sel.size() == 2 && !sel.sign_extend() && + !info.operands[!i].extract[0].sign_extend() && + info.operands[!i].extract[0].size() >= 2 && + (info.operands[!i].op.is16bit() || info.operands[!i].extract[0].size() == 2 || + (info.operands[!i].op.isConstant() && + info.operands[!i].op.constantValue() <= UINT16_MAX))) { + info.opcode = aco_opcode::v_mad_u32_u16; + info.format = Format::VOP3; + info.operands.push_back(alu_opt_op{}); + info.operands[2].op = Operand::c32(0); + continue; + } else if (i < 2 && ctx.program->gfx_level >= GFX8 && ctx.program->gfx_level < GFX11 && + (format_is(info.format, Format::VOPC) || format_is(info.format, Format::VOP2) || + format_is(info.format, Format::VOP1))) { + info.format = format_combine(info.format, Format::SDWA); + continue; + } else if (sel.size() == 2 && can_use_opsel(ctx.program->gfx_level, info.opcode, i)) { + continue; + } else if (info.opcode == aco_opcode::s_cvt_f32_f16 && sel.size() == 2 && sel.offset() == 2) { + 
info.opcode = aco_opcode::s_cvt_hi_f32_f16; + info.operands[i].extract[0] = SubdwordSel::dword; + continue; + } else { + return false; + } + } + + /* convert to v_fma_mix */ + bool uses_f2f32 = false; + for (auto& op_info : info.operands) + uses_f2f32 |= op_info.f16_to_f32; + + if (uses_f2f32 || info.f32_to_f16) { + if (ctx.program->gfx_level < GFX9) + return false; + + /* v_mad_mix* on GFX9 always flushes denormals for 16-bit inputs/outputs */ + if (ctx.program->gfx_level == GFX9 && ctx.fp_mode.denorm16_64) + return false; + + switch (info.opcode) { + case aco_opcode::v_add_f32: + info.operands.insert(info.operands.begin(), alu_opt_op{}); + info.operands[0].op = Operand::c32(0x3f800000); + break; + case aco_opcode::v_mul_f32: + info.operands.push_back(alu_opt_op{}); + info.operands[2].op = Operand::c32(0); + info.operands[2].neg[0] = true; + break; + case aco_opcode::v_fma_f32: + // TODO remove precise, not clear why unfusing fma would be valid + if (!ctx.program->dev.fused_mad_mix && info.defs[0].isPrecise()) + return false; + break; + case aco_opcode::v_mad_f32: + if (ctx.program->dev.fused_mad_mix && info.defs[0].isPrecise()) + return false; + break; + default: return false; + } + + info.opcode = info.f32_to_f16 ? 
aco_opcode::v_fma_mixlo_f16 : aco_opcode::v_fma_mix_f32; + info.format = Format::VOP3P; + } + + /* remove negate modifiers by converting to subtract */ + aco_opcode sub = aco_opcode::num_opcodes; + aco_opcode subrev = aco_opcode::num_opcodes; + switch (info.opcode) { + case aco_opcode::v_add_f32: + sub = aco_opcode::v_sub_f32; + subrev = aco_opcode::v_subrev_f32; + break; + case aco_opcode::v_add_f16: + sub = aco_opcode::v_sub_f16; + subrev = aco_opcode::v_subrev_f16; + break; + case aco_opcode::s_add_f32: sub = aco_opcode::s_sub_f32; break; + case aco_opcode::s_add_f16: sub = aco_opcode::s_sub_f16; break; + default: break; + } + + if (sub != aco_opcode::num_opcodes && (info.operands[0].neg[0] ^ info.operands[1].neg[0])) { + if (info.operands[1].neg[0]) { + info.opcode = sub; + } else if (subrev != aco_opcode::num_opcodes) { + info.opcode = subrev; + } else { + info.opcode = sub; + std::swap(info.operands[0], info.operands[1]); + } + + info.operands[0].neg[0] = false; + info.operands[1].neg[0] = false; + } + + /* convert to DPP */ + bool is_dpp = false; + for (unsigned i = 0; i < info.operands.size(); i++) { + if (info.operands[i].dpp16 || info.operands[i].dpp8) { + if (is_dpp || !info.try_swap_operands(0, i)) + return false; + + is_dpp = true; + if (info.operands[0].dpp16) + info.format = format_combine(info.format, Format::DPP16); + else if (info.operands[0].dpp8) + info.format = format_combine(info.format, Format::DPP8); + } + } + if (is_dpp && info.operands.size() > 2 && !info.operands[1].op.isOfType(RegType::vgpr) && + info.operands[2].op.isOfType(RegType::vgpr)) + info.try_swap_operands(1, 2); + if (is_dpp && info.operands.size() > 1 && !info.operands[1].op.isOfType(RegType::vgpr)) + return false; /* TODO: gfx11.5 */ + + /* dst SDWA */ + if (info.insert != SubdwordSel::dword) { + if (info.insert.offset() == 0 && info.insert.size() >= info.defs[0].bytes()) { + info.insert = SubdwordSel::dword; + } else if (info.defs[0].bytes() != 4 || + 
(!format_is(info.format, Format::VOP1) && !format_is(info.format, Format::VOP2))) { + return false; + } else { + info.format = format_combine(info.format, Format::SDWA); + } + } + + /* DPP and SDWA can't be used at the same time. */ + if (is_dpp && format_is(info.format, Format::SDWA)) + return false; + + bool is_dpp_or_sdwa = is_dpp || format_is(info.format, Format::SDWA); + + bitarray8 neg = 0; + bitarray8 abs = 0; + bitarray8 opsel = 0; + bitarray8 vmask = 0; + bitarray8 smask = 0; + bitarray8 cmask = 0; + bitarray8 lmask = 0; + + for (unsigned i = 0; i < info.operands.size(); i++) { + aco_type type = get_canonical_operand_type(info.opcode, i); + bool can_use_mods = can_use_input_modifiers(ctx.program->gfx_level, info.opcode, i); + const auto& op_info = info.operands[i]; + + if (!format_is(info.format, Format::VOP3P) && type.num_components == 2 && + (op_info.neg[0] != op_info.neg[1] || op_info.abs[0] != op_info.abs[1])) + return false; + + for (unsigned comp = 0; comp < type.num_components; comp++) { + if (!can_use_mods && (op_info.neg[comp] || op_info.abs[comp])) + return false; + abs[i] |= op_info.abs[comp]; + neg[i] |= op_info.neg[comp]; + } + opsel[i] = op_info.extract[0].offset(); + vmask[i] = op_info.op.isOfType(RegType::vgpr); + smask[i] = op_info.op.isOfType(RegType::sgpr); + cmask[i] = op_info.op.isConstant(); + lmask[i] = op_info.op.isLiteral(); + + /* lane masks must be sgpr */ + if (type.bit_size == 1 && !smask[i]) + return false; + + /* DPP/SDWA doesn't allow 64bit opcodes. */ + if (is_dpp_or_sdwa && info.operands[i].op.size() != 1 && type.bit_size != 1) + return false; + } + + /* DPP/SDWA doesn't allow 64bit opcodes. 
*/ + if (is_dpp_or_sdwa && !format_is(info.format, Format::VOPC) && info.defs[0].size() != 1) + return false; + + if (format_is(info.format, Format::VOP1) || format_is(info.format, Format::VOP2) || + format_is(info.format, Format::VOPC) || format_is(info.format, Format::VOP3)) { + bool needs_vop3 = false; + if (info.omod && format_is(info.format, Format::SDWA) && ctx.program->gfx_level < GFX9) + return false; + + if (info.omod && !format_is(info.format, Format::SDWA)) + needs_vop3 = true; + + if (info.clamp && format_is(info.format, Format::SDWA) && + format_is(info.format, Format::VOPC) && ctx.program->gfx_level >= GFX9) + return false; + + if ((info.clamp || (opsel & ~vmask)) && !format_is(info.format, Format::SDWA)) + needs_vop3 = true; + + if (!format_is(info.format, Format::SDWA) && !format_is(info.format, Format::DPP16) && + (abs || neg)) + needs_vop3 = true; + + if (((cmask | smask) & 0x3) && format_is(info.format, Format::SDWA) && + ctx.program->gfx_level == GFX8) + return false; + + aco_opcode mulk = aco_opcode::num_opcodes; + aco_opcode addk = aco_opcode::num_opcodes; + switch (info.opcode) { + case aco_opcode::v_s_exp_f16: + case aco_opcode::v_s_log_f16: + case aco_opcode::v_s_rcp_f16: + case aco_opcode::v_s_rsq_f16: + case aco_opcode::v_s_sqrt_f16: + /* These can't use inline constants on GFX12 but can use literals. We don't bother since + * they should be constant folded anyway. 
*/ + if (cmask) + return false; + FALLTHROUGH; + case aco_opcode::v_s_exp_f32: + case aco_opcode::v_s_log_f32: + case aco_opcode::v_s_rcp_f32: + case aco_opcode::v_s_rsq_f32: + case aco_opcode::v_s_sqrt_f32: + if (vmask) + return false; + break; + case aco_opcode::v_writelane_b32: + case aco_opcode::v_writelane_b32_e64: + if ((vmask & 0x3) || (~vmask & 0x4)) + return false; + if (is_dpp || format_is(info.format, Format::SDWA)) + return false; + break; + case aco_opcode::v_permlane16_b32: + case aco_opcode::v_permlanex16_b32: + case aco_opcode::v_permlane64_b32: + case aco_opcode::v_readfirstlane_b32: + case aco_opcode::v_readlane_b32: + case aco_opcode::v_readlane_b32_e64: + if ((~vmask & 0x1) || (vmask & 0x6)) + return false; + if (is_dpp || format_is(info.format, Format::SDWA)) + return false; + break; + case aco_opcode::v_mul_lo_u32: + case aco_opcode::v_mul_lo_i32: + case aco_opcode::v_mul_hi_u32: + case aco_opcode::v_mul_hi_i32: + if (is_dpp) + return false; + break; + case aco_opcode::v_fma_f32: + if (ctx.program->gfx_level >= GFX10) { + mulk = aco_opcode::v_fmamk_f32; + addk = aco_opcode::v_fmaak_f32; + } + break; + case aco_opcode::v_fma_f16: + case aco_opcode::v_fma_legacy_f16: + if (ctx.program->gfx_level >= GFX10) { + mulk = aco_opcode::v_fmamk_f16; + addk = aco_opcode::v_fmaak_f16; + } + break; + case aco_opcode::v_mad_f32: + mulk = aco_opcode::v_madmk_f32; + addk = aco_opcode::v_madak_f32; + break; + case aco_opcode::v_mad_f16: + case aco_opcode::v_mad_legacy_f16: + mulk = aco_opcode::v_madmk_f16; + addk = aco_opcode::v_madak_f16; + break; + default: + if ((smask[1] || cmask[1]) && !needs_vop3 && !format_is(info.format, Format::VOP3) && + !format_is(info.format, Format::SDWA)) { + if (is_dpp || !vmask[0] || !info.try_swap_operands(0, 1)) + needs_vop3 = true; + } + if (needs_vop3) + info.format = format_combine(info.format, Format::VOP3); + } + + if (addk != aco_opcode::num_opcodes && vmask && lmask && !needs_vop3 && + (vmask[2] || lmask[2]) && (!opsel 
|| ctx.program->gfx_level >= GFX11)) { + for (int i = 2; i >= 0; i--) { + if (lmask[i]) { + if (i == 0 || (i == 2 && !vmask[1])) + std::swap(info.operands[0], info.operands[1]); + if (i != 2) + std::swap(info.operands[1], info.operands[2]); + info.opcode = i == 2 ? addk : mulk; + info.format = Format::VOP2; + break; + } + } + } + + bool nolit = format_is(info.format, Format::SDWA) || is_dpp || + (format_is(info.format, Format::VOP3) && ctx.program->gfx_level < GFX10); + if (nolit && lmask) + return false; + if (is_dpp && format_is(info.format, Format::VOP3) && ctx.program->gfx_level < GFX11) + return false; + + /* Fix lane mask src/dst to vcc if the format requires it. */ + if (ctx.program->gfx_level < GFX11 && (is_dpp || format_is(info.format, Format::SDWA))) { + if (format_is(info.format, Format::VOP2)) { + if (info.operands.size() > 2) + info.operands[2].op.setPrecolored(vcc); + if (info.defs.size() > 1) + info.defs[1].setPrecolored(vcc); + } + if (format_is(info.format, Format::VOPC) && (is_dpp || ctx.program->gfx_level < GFX9) && + !info.defs[0].isFixed()) + info.defs[0].setPrecolored(vcc); + } + } else if (format_is(info.format, Format::VOP3P)) { + bool fmamix = + info.opcode == aco_opcode::v_fma_mix_f32 || info.opcode == aco_opcode::v_fma_mixlo_f16; + bool dot2_f32 = + info.opcode == aco_opcode::v_dot2_f32_f16 || info.opcode == aco_opcode::v_dot2_f32_bf16; + bool supports_dpp = (fmamix || dot2_f32) && ctx.program->gfx_level >= GFX11; + if ((abs && !fmamix) || (is_dpp && !supports_dpp) || info.omod) + return false; + if (lmask && (ctx.program->gfx_level < GFX10 || is_dpp)) + return false; + } else if (is_salu) { + if (vmask) + return false; + if (info.opcode == aco_opcode::s_fmac_f32) { + for (unsigned i = 0; i < 2; i++) { + if (lmask[i]) { + std::swap(info.operands[i], info.operands[2]); + info.opcode = aco_opcode::s_fmamk_f32; + break; + } + } + if (info.opcode == aco_opcode::s_fmac_f32 && cmask[2]) { + info.operands[2].op = 
Operand::literal32(info.operands[2].op.constantValue()); + lmask[2] = true; + info.opcode = aco_opcode::s_fmaak_f32; + } + } else if (info.opcode == aco_opcode::s_fmac_f16 && !smask[2]) { + return false; + } + } + + return true; +} + +/* Gather semantic information about an alu instruction and its operands from an ACO IR Instruction. + * + * Some callers expect that the alu_opt_info created by alu_opt_gather_info() or the instruction + * created by alu_opt_info_to_instr() does not have more uses of a temporary than the original + * instruction did. + */ /* Returns false without touching info when instr is neither VALU nor SALU, is a WMMA instruction or v_mqsad_u32_u8, or is VINTERP_INREG. Otherwise info is reset, filled from instr, and true is returned. Several opcodes are rewritten to a canonical form on the way (for example sub/subrev become add with a neg modifier, and the *ak/*mk literal forms become plain fma/mad/fmac). */ +bool +alu_opt_gather_info(opt_ctx& ctx, Instruction* instr, alu_opt_info& info) +{ + if (!instr->isVALU() && !instr->isSALU()) + return false; + + /* There is nothing to be gained from handling WMMA/mqsad here. */ + if (instr_info.classes[(int)instr->opcode] == instr_class::wmma || + instr->opcode == aco_opcode::v_mqsad_u32_u8) + return false; + + /* TODO handle when this is used for output modifiers. */ + if (instr->isVINTERP_INREG()) + return false; + + switch (instr->opcode) { + case aco_opcode::s_addk_i32: + case aco_opcode::s_cmovk_i32: + case aco_opcode::s_mulk_i32: + case aco_opcode::v_dot2c_f32_f16: + case aco_opcode::v_dot4c_i32_i8: + case aco_opcode::v_fmac_f32: + case aco_opcode::v_fmac_f16: + case aco_opcode::v_fmac_legacy_f32: + case aco_opcode::v_mac_f32: + case aco_opcode::v_mac_f16: + case aco_opcode::v_mac_legacy_f32: + case aco_opcode::v_pk_fmac_f16: UNREACHABLE("Only created by RA."); return false; + default: break; + } + + info = {}; + + info.opcode = instr->opcode; + info.pass_flags = instr->pass_flags; + + if (instr->isSALU()) + info.imm = instr->salu().imm; + + bitarray8 opsel = 0; + if (instr->isVALU()) { + info.omod = instr->valu().omod; + info.clamp = instr->valu().clamp; + opsel = instr->valu().opsel; + } + + if (instr->opcode == aco_opcode::v_permlane16_b32 || + instr->opcode == aco_opcode::v_permlanex16_b32) { + info.imm = opsel; + opsel = 0; + } + + if (instr->opcode == 
aco_opcode::v_fma_mix_f32 || instr->opcode == aco_opcode::v_fma_mixlo_f16) { + info.opcode = ctx.program->dev.fused_mad_mix ? aco_opcode::v_fma_f32 : aco_opcode::v_mad_f32; + info.f32_to_f16 = instr->opcode == aco_opcode::v_fma_mixlo_f16; + } + + if (instr->isSDWA()) + info.insert = instr->sdwa().dst_sel; + else + info.insert = SubdwordSel::dword; + + for (Definition& def : instr->definitions) + info.defs.push_back(def); + + for (unsigned i = 0; i < instr->operands.size(); i++) { + alu_opt_op op_info = {}; + op_info.op = instr->operands[i]; + if (instr->opcode == aco_opcode::v_fma_mix_f32 || + instr->opcode == aco_opcode::v_fma_mixlo_f16) { + op_info.neg[0] = instr->valu().neg[i]; + op_info.abs[0] = instr->valu().abs[i]; + if (instr->valu().opsel_hi[i]) { + op_info.f16_to_f32 = true; + if (instr->valu().opsel_lo[i]) + op_info.extract[0] = SubdwordSel::uword1; + } + } else if (instr->isVOP3P()) { + op_info.neg[0] = instr->valu().neg_lo[i]; + op_info.neg[1] = instr->valu().neg_hi[i]; + if (instr->valu().opsel_lo[i]) + op_info.extract[0] = SubdwordSel::uword1; + if (instr->valu().opsel_hi[i]) + op_info.extract[1] = SubdwordSel::uword1; + } else if (instr->isVALU() && i < 3) { + op_info.neg[0] = instr->valu().neg[i]; + op_info.neg[1] = instr->valu().neg[i]; + op_info.abs[0] = instr->valu().abs[i]; + op_info.abs[1] = instr->valu().abs[i]; + if (opsel[i]) + op_info.extract[0] = SubdwordSel::uword1; + op_info.extract[1] = SubdwordSel::uword1; + + if (i < 2 && instr->isSDWA()) + op_info.extract[0] = instr->sdwa().sel[i]; + } + + info.operands.push_back(op_info); + } + + if (instr->isDPP16()) { + info.operands[0].dpp16 = true; + info.operands[0].dpp_ctrl = instr->dpp16().dpp_ctrl; + info.operands[0].fi = instr->dpp16().fetch_inactive; + info.operands[0].bc = instr->dpp16().bound_ctrl; + assert(instr->dpp16().row_mask == 0xf && instr->dpp16().bank_mask == 0xf); + } else if (instr->isDPP8()) { + info.operands[0].dpp8 = true; + info.operands[0].dpp_ctrl = 
instr->dpp8().lane_sel; + info.operands[0].fi = instr->dpp8().fetch_inactive; + } + + switch (info.opcode) { + case aco_opcode::s_cvt_hi_f32_f16: + info.operands[0].extract[0] = SubdwordSel::uword1; + info.opcode = aco_opcode::s_cvt_f32_f16; + break; + case aco_opcode::s_pack_lh_b32_b16: + case aco_opcode::s_pack_hl_b32_b16: + case aco_opcode::s_pack_hh_b32_b16: + if (info.opcode != aco_opcode::s_pack_lh_b32_b16) + info.operands[0].extract[0] = SubdwordSel::uword1; + if (info.opcode != aco_opcode::s_pack_hl_b32_b16) + info.operands[1].extract[0] = SubdwordSel::uword1; + info.opcode = aco_opcode::s_pack_ll_b32_b16; + break; + case aco_opcode::v_sub_f32: + case aco_opcode::v_subrev_f32: + info.operands[info.opcode == aco_opcode::v_sub_f32].neg[0] ^= true; + info.opcode = aco_opcode::v_add_f32; + break; + case aco_opcode::v_sub_f16: + case aco_opcode::v_subrev_f16: + info.operands[info.opcode == aco_opcode::v_sub_f16].neg[0] ^= true; + info.opcode = aco_opcode::v_add_f16; + break; + case aco_opcode::s_sub_f32: + info.operands[1].neg[0] ^= true; + info.opcode = aco_opcode::s_add_f32; + break; + case aco_opcode::s_sub_f16: + info.operands[1].neg[0] ^= true; + info.opcode = aco_opcode::s_add_f16; + break; + case aco_opcode::v_dot4_i32_iu8: + case aco_opcode::v_dot8_i32_iu4: + for (unsigned i = 0; i < 2; i++) { + info.operands[i].dot_sext = info.operands[i].neg[0]; + info.operands[i].neg[0] = false; + } + break; + case aco_opcode::v_mad_f32: + if (ctx.fp_mode.denorm32) + break; + FALLTHROUGH; + case aco_opcode::v_fma_f32: + if (info.operands[2].op.constantEquals(0) && info.operands[2].neg[0]) { + info.operands.pop_back(); + info.opcode = aco_opcode::v_mul_f32; + } else { + for (unsigned i = 0; i < 2; i++) { + uint32_t one = info.operands[i].f16_to_f32 ? 
0x3c00 : 0x3f800000; + if (info.operands[i].op.constantEquals(one) && !info.operands[i].neg[0] && + info.operands[i].extract[0] == SubdwordSel::dword) { + info.operands.erase(info.operands.begin() + i); + info.opcode = aco_opcode::v_add_f32; + break; + } + } + } + break; + case aco_opcode::v_fmaak_f32: + case aco_opcode::v_fmamk_f32: + if (info.opcode == aco_opcode::v_fmamk_f32) + std::swap(info.operands[1], info.operands[2]); + info.opcode = aco_opcode::v_fma_f32; + break; + case aco_opcode::v_fmaak_f16: + case aco_opcode::v_fmamk_f16: + if (info.opcode == aco_opcode::v_fmamk_f16) + std::swap(info.operands[1], info.operands[2]); + info.opcode = aco_opcode::v_fma_f16; + break; + case aco_opcode::v_madak_f32: + case aco_opcode::v_madmk_f32: + if (info.opcode == aco_opcode::v_madmk_f32) + std::swap(info.operands[1], info.operands[2]); + info.opcode = aco_opcode::v_mad_f32; + break; + case aco_opcode::v_madak_f16: + case aco_opcode::v_madmk_f16: + if (info.opcode == aco_opcode::v_madmk_f16) + std::swap(info.operands[1], info.operands[2]); + info.opcode = + ctx.program->gfx_level == GFX8 ? 
aco_opcode::v_mad_legacy_f16 : aco_opcode::v_mad_f16; + break; + case aco_opcode::s_fmaak_f32: + case aco_opcode::s_fmamk_f32: + if (info.opcode == aco_opcode::s_fmamk_f32) + std::swap(info.operands[1], info.operands[2]); + info.opcode = aco_opcode::s_fmac_f32; + break; + case aco_opcode::v_subbrev_co_u32: + std::swap(info.operands[0], info.operands[1]); + info.opcode = aco_opcode::v_subb_co_u32; + break; + case aco_opcode::v_subrev_co_u32: + std::swap(info.operands[0], info.operands[1]); + info.opcode = aco_opcode::v_sub_co_u32; + break; + case aco_opcode::v_subrev_co_u32_e64: + std::swap(info.operands[0], info.operands[1]); + info.opcode = aco_opcode::v_sub_co_u32_e64; + break; + case aco_opcode::v_subrev_u32: + std::swap(info.operands[0], info.operands[1]); + info.opcode = aco_opcode::v_sub_u32; + break; + default: break; + } + + return true; +} + +/* Convert an alu_opt_info to an ACO IR instruction. + * alu_opt_info_is_valid must have been called and returned true before this. + * If old_instr is large enough for the new instruction, it's reused. + * Otherwise a new instruction is allocated. 
+ */ /* Parameters: ctx - its info[] table gets parent_instr pointed at the result for every definition temp; info - the description to materialize; old_instr - candidate for in-place reuse, may be null. Returns the instruction that was written: old_instr itself when it was reused, otherwise a fresh one from create_instruction (old_instr is not modified in that case). */ +Instruction* +alu_opt_info_to_instr(opt_ctx& ctx, alu_opt_info& info, Instruction* old_instr) +{ + Instruction* instr; /* Reuse requires old_instr to have at least as many definitions and operands as info, and a format whose get_instr_data_size is not smaller than the new format's. */ + if (old_instr && old_instr->definitions.size() >= info.defs.size() && + old_instr->operands.size() >= info.operands.size() && + get_instr_data_size(old_instr->format) >= get_instr_data_size(info.format)) { + instr = old_instr; /* Drop surplus operand and definition slots so the counts match info exactly. */ + while (instr->operands.size() > info.operands.size()) + instr->operands.pop_back(); + while (instr->definitions.size() > info.defs.size()) + instr->definitions.pop_back(); + instr->opcode = info.opcode; + instr->format = info.format; + /* Runs after opcode/format were overwritten above. Stale modifier bits of the reused instruction are cleared here because the operand loop below only ever sets opsel[i] to true and never resets it. */ + if (instr->isVALU()) { + instr->valu().abs = 0; + instr->valu().neg = 0; + instr->valu().opsel = 0; + instr->valu().opsel_hi = 0; + instr->valu().opsel_lo = 0; + } + } else { + instr = create_instruction(info.opcode, info.format, info.operands.size(), info.defs.size()); + } + + instr->pass_flags = info.pass_flags; + /* Copy the definitions and record instr as the parent of each defined temp in ctx.info. */ + for (unsigned i = 0; i < info.defs.size(); i++) { + instr->definitions[i] = info.defs[i]; + ctx.info[info.defs[i].tempId()].parent_instr = instr; + } + /* Copy each operand and encode its modifiers; the encoding depends on the opcode/format, checked in this order: v_fma_mix_f32 / v_fma_mixlo_f16, then VOP3P, then any other VALU. Operands of non-VALU instructions get no modifiers written. */ + for (unsigned i = 0; i < info.operands.size(); i++) { + instr->operands[i] = info.operands[i].op; + if (instr->opcode == aco_opcode::v_fma_mix_f32 || + instr->opcode == aco_opcode::v_fma_mixlo_f16) { /* Only component 0 of neg/abs is used; opsel_hi carries the f16_to_f32 flag and opsel_lo is assigned from extract[0].offset(). */ + instr->valu().neg[i] = info.operands[i].neg[0]; + instr->valu().abs[i] = info.operands[i].abs[0]; + instr->valu().opsel_hi[i] = info.operands[i].f16_to_f32; + instr->valu().opsel_lo[i] = info.operands[i].extract[0].offset(); + } else if (instr->isVOP3P()) { /* Per-component modifiers: index 0 feeds the _lo fields, index 1 the _hi fields. dot_sext is folded into neg_lo. */ + instr->valu().neg_lo[i] = info.operands[i].neg[0] || info.operands[i].dot_sext; + instr->valu().neg_hi[i] = info.operands[i].neg[1]; + instr->valu().opsel_lo[i] = info.operands[i].extract[0].offset(); + instr->valu().opsel_hi[i] = info.operands[i].extract[1].offset(); + } else if (instr->isVALU()) { + instr->valu().neg[i] = info.operands[i].neg[0]; + instr->valu().abs[i] = info.operands[i].abs[0]; /* For SDWA the first two operands get a full sub-dword selector; every other case can only flag a non-zero extract offset through opsel. */ + if (instr->isSDWA() && i < 2) { + SubdwordSel sel =
info.operands[i].extract[0]; /* The selector size is limited to the operand's own byte size. */ + unsigned size = MIN2(sel.size(), info.operands[i].op.bytes()); + instr->sdwa().sel[i] = SubdwordSel(size, sel.offset(), sel.sign_extend()); + } else if (info.operands[i].extract[0].offset()) { + instr->valu().opsel[i] = true; + } + } + } + /* Output modifiers are written for VALU instructions only. */ + if (instr->isVALU()) { + instr->valu().omod = info.omod; + instr->valu().clamp = info.clamp; + } + /* Format-specific fields; the branches are mutually exclusive. DPP state is taken from operand 0 only. */ + if (instr->isDPP16()) { + instr->dpp16().dpp_ctrl = info.operands[0].dpp_ctrl; + instr->dpp16().fetch_inactive = info.operands[0].fi; + instr->dpp16().bound_ctrl = info.operands[0].bc; /* row_mask and bank_mask are not part of alu_opt_op, so they are always written as 0xf. */ + instr->dpp16().row_mask = 0xf; + instr->dpp16().bank_mask = 0xf; + } else if (instr->isDPP8()) { + instr->dpp8().lane_sel = info.operands[0].dpp_ctrl; + instr->dpp8().fetch_inactive = info.operands[0].fi; + } else if (instr->isSDWA()) { + instr->sdwa().dst_sel = info.insert; /* For non-VOPC instructions whose first definition is not 4 bytes, dst_sel is replaced by a selector of the definition's size at offset 0 without sign extension; the assert expects info.insert to be that selector or dword. */ + if (!instr->isVOPC() && instr->definitions[0].bytes() != 4) { + instr->sdwa().dst_sel = SubdwordSel(instr->definitions[0].bytes(), 0, false); + assert(instr->sdwa().dst_sel == info.insert || info.insert == SubdwordSel::dword); + } + } else if (instr->opcode == aco_opcode::v_permlane16_b32 || + instr->opcode == aco_opcode::v_permlanex16_b32) { /* For these two opcodes info.imm is stored in the opsel field. */ + instr->valu().opsel = info.imm; + } + /* SALU instructions store info.imm in their imm field. */ + if (instr->isSALU()) + instr->salu().imm = info.imm; + + return instr; +} + bool can_use_VOP3(opt_ctx& ctx, const aco_ptr& instr) {