diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp index 2910a46b50e..d9d41bfd02d 100644 --- a/src/amd/compiler/aco_optimizer.cpp +++ b/src/amd/compiler/aco_optimizer.cpp @@ -1583,47 +1583,6 @@ pseudo_propagate_temp(opt_ctx& ctx, aco_ptr& instr, Temp temp, unsi return true; } -/* This expects the DPP modifier to be removed. */ -bool -can_apply_sgprs(opt_ctx& ctx, aco_ptr& instr) -{ - assert(instr->isVALU()); - if (instr->isSDWA() && ctx.program->gfx_level < GFX9) - return false; - return instr->opcode != aco_opcode::v_readfirstlane_b32 && - instr->opcode != aco_opcode::v_readlane_b32 && - instr->opcode != aco_opcode::v_readlane_b32_e64 && - instr->opcode != aco_opcode::v_writelane_b32 && - instr->opcode != aco_opcode::v_writelane_b32_e64 && - instr->opcode != aco_opcode::v_permlane16_b32 && - instr->opcode != aco_opcode::v_permlanex16_b32 && - instr->opcode != aco_opcode::v_permlane64_b32 && - instr->opcode != aco_opcode::v_interp_p1_f32 && - instr->opcode != aco_opcode::v_interp_p2_f32 && - instr->opcode != aco_opcode::v_interp_mov_f32 && - instr->opcode != aco_opcode::v_interp_p1ll_f16 && - instr->opcode != aco_opcode::v_interp_p1lv_f16 && - instr->opcode != aco_opcode::v_interp_p2_legacy_f16 && - instr->opcode != aco_opcode::v_interp_p2_f16 && - instr->opcode != aco_opcode::v_interp_p2_hi_f16 && - instr->opcode != aco_opcode::v_interp_p10_f32_inreg && - instr->opcode != aco_opcode::v_interp_p2_f32_inreg && - instr->opcode != aco_opcode::v_interp_p10_f16_f32_inreg && - instr->opcode != aco_opcode::v_interp_p2_f16_f32_inreg && - instr->opcode != aco_opcode::v_interp_p10_rtz_f16_f32_inreg && - instr->opcode != aco_opcode::v_interp_p2_rtz_f16_f32_inreg && - instr->opcode != aco_opcode::v_wmma_f32_16x16x16_f16 && - instr->opcode != aco_opcode::v_wmma_f32_16x16x16_bf16 && - instr->opcode != aco_opcode::v_wmma_f16_16x16x16_f16 && - instr->opcode != aco_opcode::v_wmma_bf16_16x16x16_bf16 && - instr->opcode != aco_opcode::v_wmma_i32_16x16x16_iu8 && - instr->opcode != aco_opcode::v_wmma_i32_16x16x16_iu4 && - instr->opcode != aco_opcode::v_wmma_f32_16x16x16_fp8_fp8 && - instr->opcode != aco_opcode::v_wmma_f32_16x16x16_fp8_bf8 && - instr->opcode != aco_opcode::v_wmma_f32_16x16x16_bf8_fp8 && - instr->opcode != aco_opcode::v_wmma_f32_16x16x16_bf8_bf8; -} - /* only covers special cases */ bool pseudo_can_accept_constant(const aco_ptr& instr, unsigned operand) @@ -2141,19 +2100,6 @@ does_fp_op_flush_denorms(opt_ctx& ctx, aco_opcode op) } } -bool -can_eliminate_fcanonicalize(opt_ctx& ctx, aco_ptr& instr, Temp tmp, unsigned idx) -{ - float_mode* fp = &ctx.fp_mode; - if (ctx.info[tmp.id()].is_canonicalized() || - (tmp.bytes() == 4 ? fp->denorm32 : fp->denorm16_64) == fp_denorm_keep) - return true; - - aco_opcode op = instr->opcode; - return can_use_input_modifiers(ctx.program->gfx_level, instr->opcode, idx) && - does_fp_op_flush_denorms(ctx, op); -} - bool can_eliminate_and_exec(opt_ctx& ctx, Temp tmp, unsigned pass_flags, bool allow_cselect = false) { @@ -2189,13 +2135,6 @@ can_eliminate_and_exec(opt_ctx& ctx, Temp tmp, unsigned pass_flags, bool allow_c } } -bool -is_copy_label(opt_ctx& ctx, aco_ptr& instr, ssa_info& info, unsigned idx) -{ - return info.is_temp() || - (info.is_fcanonicalize() && can_eliminate_fcanonicalize(ctx, instr, info.temp, idx)); -} - bool is_op_canonicalized(opt_ctx& ctx, Operand op) { @@ -2310,6 +2249,12 @@ parse_operand(opt_ctx& ctx, Temp tmp, alu_opt_op& op_info, aco_type& type) type.num_components = 1; type.bit_size = tmp.bytes() * 8; + if (info.is_extract()) { + op_info.extract[0] = parse_extract(info.parent_instr); + op_info.op = info.parent_instr->operands[0]; + return true; + } + if (info.is_constant_or_literal(type.bit_size)) { op_info.op = get_constant_op(ctx, info, type.bit_size); return true; @@ -2330,6 +2275,10 @@ bool combine_operand(opt_ctx& ctx, alu_opt_op& inner, const aco_type& inner_type, const alu_opt_op& outer, const aco_type& outer_type, bool flushes_denorms) { + /* Nothing to be gained by bothering with lane masks. */ + if (inner_type.bit_size <= 1) + return false; + if (inner.op.size() != outer.op.size()) return false; @@ -2397,7 +2346,7 @@ decrease_and_dce(opt_ctx& ctx, Temp tmp) } void -alu_propagate_temp_const(opt_ctx& ctx, aco_ptr& instr) +alu_propagate_temp_const(opt_ctx& ctx, aco_ptr& instr, bool uses_valid) { alu_opt_info info; if (!alu_opt_gather_info(ctx, instr.get(), info)) @@ -2414,45 +2363,89 @@ alu_propagate_temp_const(opt_ctx& ctx, aco_ptr& instr) instr->opcode == aco_opcode::v_min3_f32 || instr->opcode == aco_opcode::v_max3_f32 || instr->opcode == aco_opcode::v_med3_f32); + unsigned operand_mask = BITFIELD_MASK(info.operands.size()); + bool progress = false; alu_opt_info result_info; - for (unsigned i = 0; i < info.operands.size(); i++) { - while (info.operands[i].op.isTemp()) { - alu_opt_op outer; - aco_type outer_type; - if (!parse_operand(ctx, info.operands[i].op.getTemp(), outer, outer_type)) + while (operand_mask) { + uint32_t i = UINT32_MAX; + uint32_t op_uses = UINT32_MAX; + u_foreach_bit (candidate, operand_mask) { + if (!info.operands[candidate].op.isTemp()) { + operand_mask &= ~BITFIELD_BIT(candidate); + continue; + } + + if (!uses_valid) { + i = candidate; break; + } - /* Applying SGPRs to VOP1 doesn't increase code size and DCE is helped by doing it earlier, - * otherwise we apply SGPRs later. - */ - if (info.operands[i].op.isOfType(RegType::vgpr) && outer.op.isOfType(RegType::sgpr) && - !instr->isVOP1()) - break; - - alu_opt_op inner = info.operands[i]; - aco_type inner_type = get_canonical_operand_type(info.opcode, i); - if (inner.f16_to_f32) - inner_type.bit_size = 16; - bool flushes_denorms = inner_type.base_type == aco_base_type_float && !gfx8_min_max; - if (!combine_operand(ctx, inner, inner_type, outer, outer_type, flushes_denorms)) - break; - - alu_opt_info info_copy = info; - info_copy.operands[i] = inner; - if (!alu_opt_info_is_valid(ctx, info_copy)) - break; - - bool has_lit = std::any_of(info_copy.operands.begin(), info_copy.operands.end(), - [](const alu_opt_op& op) { return op.op.isLiteral(); }); - - if (!had_lit && has_lit) - break; - - result_info = info_copy; - info.operands[i] = inner; - progress = true; + unsigned new_uses = ctx.uses[info.operands[candidate].op.tempId()]; + if (new_uses >= op_uses) + continue; + i = candidate; + op_uses = new_uses; } + + if (i == UINT32_MAX) + break; + + alu_opt_op outer; + aco_type outer_type; + if (!parse_operand(ctx, info.operands[i].op.getTemp(), outer, outer_type)) { + operand_mask &= ~BITFIELD_BIT(i); + continue; + } + + /* Applying SGPRs to VOP1 doesn't increase code size and DCE is helped by doing it earlier, + * otherwise we apply SGPRs later. + */ + bool valu_new_sgpr = info.operands[i].op.isOfType(RegType::vgpr) && + outer.op.isOfType(RegType::sgpr) && !instr->isVOP1(); + if ((valu_new_sgpr || ctx.info[info.operands[i].op.tempId()].is_extract()) && !uses_valid) { + operand_mask &= ~BITFIELD_BIT(i); + continue; + } + + alu_opt_op inner = info.operands[i]; + aco_type inner_type = get_canonical_operand_type(info.opcode, i); + if (inner.f16_to_f32) + inner_type.bit_size = 16; + bool flushes_denorms = inner_type.base_type == aco_base_type_float && !gfx8_min_max; + if (!combine_operand(ctx, inner, inner_type, outer, outer_type, flushes_denorms)) { + operand_mask &= ~BITFIELD_BIT(i); + continue; + } + + alu_opt_info info_copy = info; + info_copy.operands[i] = inner; + if (!alu_opt_info_is_valid(ctx, info_copy)) { + operand_mask &= ~BITFIELD_BIT(i); + continue; + } + bool has_lit = std::any_of(info_copy.operands.begin(), info_copy.operands.end(), + [](const alu_opt_op& op) { return op.op.isLiteral(); }); + + if (!had_lit && has_lit) { + operand_mask &= ~BITFIELD_BIT(i); + continue; + } + + bool valu_removed_sgpr = info.operands[i].op.isOfType(RegType::sgpr) && + !inner.op.isOfType(RegType::sgpr) && instr->isVALU(); + if (valu_removed_sgpr && uses_valid) + operand_mask = BITFIELD_MASK(info.operands.size()); + + if (uses_valid) { + if (inner.op.isTemp()) + ctx.uses[inner.op.tempId()]++; + decrease_and_dce(ctx, info.operands[i].op.getTemp()); + } + + result_info = info_copy; + info.operands[i] = inner; + progress = true; } if (!progress) @@ -2618,7 +2611,7 @@ label_instruction(opt_ctx& ctx, aco_ptr& instr) /* SALU / VALU: propagate inline constants, temps, and imod */ if (instr->isSALU() || instr->isVALU()) { - alu_propagate_temp_const(ctx, instr); + alu_propagate_temp_const(ctx, instr, false); } /* if this instruction doesn't define anything, return */ @@ -3862,104 +3855,6 @@ combine_clamp(opt_ctx& ctx, aco_ptr& instr, aco_opcode min, aco_opc return false; } -void -apply_sgprs(opt_ctx& ctx, aco_ptr& instr) -{ - bool is_shift64 = instr->opcode == aco_opcode::v_lshlrev_b64_e64 || - instr->opcode == aco_opcode::v_lshlrev_b64 || - instr->opcode == aco_opcode::v_lshrrev_b64 || - instr->opcode == aco_opcode::v_ashrrev_i64; - - /* find candidates and create the set of sgprs already read */ - unsigned sgpr_ids[2] = {0, 0}; - uint32_t operand_mask = 0; - bool has_literal = false; - for (unsigned i = 0; i < instr->operands.size(); i++) { - if (instr->operands[i].isLiteral()) - has_literal = true; - if (!instr->operands[i].isTemp()) - continue; - if (instr->operands[i].getTemp().type() == RegType::sgpr) { - if (instr->operands[i].tempId() != sgpr_ids[0]) - sgpr_ids[!!sgpr_ids[0]] = instr->operands[i].tempId(); - } - ssa_info& info = ctx.info[instr->operands[i].tempId()]; - if (is_copy_label(ctx, instr, info, i) && info.temp.type() == RegType::sgpr) - operand_mask |= 1u << i; - if (info.is_extract() && info.parent_instr->operands[0].getTemp().type() == RegType::sgpr) - operand_mask |= 1u << i; - } - unsigned max_sgprs = 1; - if (ctx.program->gfx_level >= GFX10 && !is_shift64) - max_sgprs = 2; - if (has_literal) - max_sgprs--; - - unsigned num_sgprs = !!sgpr_ids[0] + !!sgpr_ids[1]; - - /* keep on applying sgprs until there is nothing left to be done */ - while (operand_mask) { - uint32_t sgpr_idx = 0; - uint32_t sgpr_info_id = 0; - uint32_t mask = operand_mask; - /* choose a sgpr */ - while (mask) { - unsigned i = u_bit_scan(&mask); - uint16_t uses = ctx.uses[instr->operands[i].tempId()]; - if (sgpr_info_id == 0 || uses < ctx.uses[sgpr_info_id]) { - sgpr_idx = i; - sgpr_info_id = instr->operands[i].tempId(); - } - } - operand_mask &= ~(1u << sgpr_idx); - - ssa_info& info = ctx.info[sgpr_info_id]; - - Temp sgpr = info.is_extract() ? info.parent_instr->operands[0].getTemp() : info.temp; - bool new_sgpr = sgpr.id() != sgpr_ids[0] && sgpr.id() != sgpr_ids[1]; - if (new_sgpr && num_sgprs >= max_sgprs) - continue; - - if (sgpr_idx == 0) - instr->format = withoutDPP(instr->format); - - if (sgpr_idx == 1 && instr->isDPP()) - continue; - - if (sgpr_idx == 0 || instr->isVOP3() || instr->isSDWA() || instr->isVOP3P() || - info.is_extract()) { - /* can_apply_extract() checks SGPR encoding restrictions */ - if (info.is_extract() && can_apply_extract(ctx, instr, sgpr_idx, info)) - apply_extract(ctx, instr, sgpr_idx, info); - else if (info.is_extract()) - continue; - instr->operands[sgpr_idx] = Operand(sgpr); - } else if (can_swap_operands(instr, &instr->opcode) && !instr->valu().opsel[sgpr_idx]) { - instr->operands[sgpr_idx] = instr->operands[0]; - instr->operands[0] = Operand(sgpr); - instr->valu().opsel[0].swap(instr->valu().opsel[sgpr_idx]); - /* swap bits using a 4-entry LUT */ - uint32_t swapped = (0x3120 >> (operand_mask & 0x3)) & 0xf; - operand_mask = (operand_mask & ~0x3) | swapped; - } else if (can_use_VOP3(ctx, instr) && !info.is_extract()) { - instr->format = asVOP3(instr->format); - instr->operands[sgpr_idx] = Operand(sgpr); - } else { - continue; - } - - if (new_sgpr) - sgpr_ids[num_sgprs++] = sgpr.id(); - ctx.uses[sgpr_info_id]--; - ctx.uses[sgpr.id()]++; - - /* TODO: handle when it's a VGPR */ - if ((ctx.info[sgpr.id()].label & (label_extract | label_temp)) && - ctx.info[sgpr.id()].temp.type() == RegType::sgpr) - operand_mask |= 1u << sgpr_idx; - } -} - bool interp_can_become_fma(opt_ctx& ctx, aco_ptr& instr) { @@ -4703,38 +4598,19 @@ combine_instruction(opt_ctx& ctx, aco_ptr& instr) if (instr->definitions.empty() || is_dead(ctx.uses, instr.get())) return; + for (const Definition& def : instr->definitions) { + ssa_info& info = ctx.info[def.tempId()]; + if (info.is_extract() && ctx.uses[def.tempId()] > 4) + info.label &= ~label_extract; + } + if (instr->isVALU() || instr->isSALU()) { /* Apply SDWA. Do this after label_instruction() so it can remove * label_extract if not all instructions can take SDWA. */ - for (unsigned i = 0; i < instr->operands.size(); i++) { - Operand& op = instr->operands[i]; - if (!op.isTemp()) - continue; - ssa_info& info = ctx.info[op.tempId()]; - if (!info.is_extract()) - continue; - /* if there are that many uses, there are likely better combinations */ - // TODO: delay applying extract to a point where we know better - if (ctx.uses[op.tempId()] > 4) { - info.label &= ~label_extract; - continue; - } - if (info.is_extract() && - (info.parent_instr->operands[0].getTemp().type() == RegType::vgpr || - instr->operands[i].getTemp().type() == RegType::sgpr) && - can_apply_extract(ctx, instr, i, info)) { - /* Increase use count of the extract's operand if the extract still has uses. */ - apply_extract(ctx, instr, i, info); - if (--ctx.uses[instr->operands[i].tempId()]) - ctx.uses[info.parent_instr->operands[0].tempId()]++; - instr->operands[i].setTemp(info.parent_instr->operands[0].getTemp()); - } - } + alu_propagate_temp_const(ctx, instr, true); } if (instr->isVALU()) { - if (can_apply_sgprs(ctx, instr)) - apply_sgprs(ctx, instr); combine_mad_mix(ctx, instr); while (apply_omod_clamp(ctx, instr) || combine_output_conversion(ctx, instr)) ; diff --git a/src/amd/compiler/tests/test_sdwa.cpp b/src/amd/compiler/tests/test_sdwa.cpp index d340714c632..3b110425aa8 100644 --- a/src/amd/compiler/tests/test_sdwa.cpp +++ b/src/amd/compiler/tests/test_sdwa.cpp @@ -689,7 +689,7 @@ BEGIN_TEST(optimize.sdwa.subdword_extract) Operand::c32(8), Operand::c32(0)), inputs[2])); - //! v1b: %res3 = v_or_b32 %a, %b dst_sel:ubyte0 dst_preserve src0_sel:uword0 src1_sel:ubyte2 + //! v1b: %res3 = v_or_b32 %a, %b dst_sel:ubyte0 dst_preserve src0_sel:uword0 src1_sel:uword1 //! p_unit_test 3, %res3 writeout(3, bld.vop2(aco_opcode::v_or_b32, bld.def(v1b), bld.pseudo(aco_opcode::p_extract, bld.def(v1b), a, Operand::c32(0), @@ -703,9 +703,7 @@ BEGIN_TEST(optimize.sdwa.subdword_extract) bld.pseudo(aco_opcode::p_extract, bld.def(v2b), a, Operand::c32(0), Operand::c32(8), Operand::c32(1)))); - /* TODO incremental conversion to sdwa loses information if zero extend is actually necessary */ - //! v2b: %tmp5 = p_extract %b, 1, 8, 1 - //! v2b: %res5 = v_or_b32 %a, %tmp5 dst_sel:uword0 dst_preserve src0_sel:sbyte0 src1_sel:uword0 + //! v2b: %res5 = v_or_b32 %a, %b dst_sel:uword0 dst_preserve src0_sel:sbyte0 src1_sel:sbyte1 //! p_unit_test 5, %res5 writeout(5, bld.vop2(aco_opcode::v_or_b32, bld.def(v2b), bld.pseudo(aco_opcode::p_extract, bld.def(v2b), a, Operand::c32(0), @@ -845,3 +843,41 @@ BEGIN_TEST(optimize.sdwa.extract_vector) finish_opt_test(); END_TEST + +BEGIN_TEST(optimizer.sdwa.lanemask_extract) + for (unsigned i = GFX10; i <= GFX11; i++) { + if (i == GFX10_3) + continue; + + //>> v1: %a:v[0], v1: %b:v[1], s1: %c:s[0] = p_startpgm + if (!setup_cs("v1 v1 s1", (amd_gfx_level)i, CHIP_UNKNOWN, "", 32)) + continue; + + Temp a = inputs[0]; + Temp b = inputs[1]; + Temp c = inputs[2]; + + //! s1: %mask0, s1: %_:scc = p_extract %c, 0, 16, 0 + //! v1: %res0 = v_cndmask_b32 %a, %b, %mask0 + //! p_unit_test 0, %res0 + Temp mask = ext_ushort(c, 0); + Temp bcsel = bld.vop2(aco_opcode::v_cndmask_b32, bld.def(v1), a, b, mask); + writeout(0, bcsel); + + //! s1: %mask1, s1: %_:scc = p_extract %c, 2, 8, 1 + //! v1: %res1 = v_cndmask_b32 %a, %b, %mask1 + //! p_unit_test 1, %res1 + mask = ext_sbyte(c, 2); + bcsel = bld.vop2(aco_opcode::v_cndmask_b32, bld.def(v1), a, b, mask); + writeout(1, bcsel); + + //! s1: %mask2, s1: %_:scc = p_extract %c, 3, 8, 0 + //! v1: %res2 = v_cndmask_b32 %a, %b, %mask2 + //! p_unit_test 2, %res2 + mask = ext_ubyte(c, 3); + bcsel = bld.vop2(aco_opcode::v_cndmask_b32, bld.def(v1), a, b, mask); + writeout(2, bcsel); + + finish_opt_test(); + } +END_TEST