diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp
index 0e7e9e6e968..79a2d65df3b 100644
--- a/src/amd/compiler/aco_optimizer.cpp
+++ b/src/amd/compiler/aco_optimizer.cpp
@@ -1869,90 +1869,19 @@ parse_insert(Instruction* instr)
    }
 }
 
-SubdwordSel
-apply_extract_twice(SubdwordSel first, Temp first_dst, SubdwordSel second, Temp second_dst)
-{
-   /* the outer offset must be within extracted range */
-   if (second.offset() >= first.size())
-      return SubdwordSel();
-
-   /* don't remove the sign-extension when increasing the size further */
-   if (second.size() > first.size() && first.sign_extend() &&
-       !(second.sign_extend() ||
-         (second.size() == first_dst.bytes() && second.size() == second_dst.bytes())))
-      return SubdwordSel();
-
-   unsigned size = std::min(first.size(), second.size());
-   unsigned offset = first.offset() + second.offset();
-   bool sign_extend = second.size() <= first.size() ? second.sign_extend() : first.sign_extend();
-   return SubdwordSel(size, offset, sign_extend);
-}
-
-bool
-can_apply_extract(opt_ctx& ctx, aco_ptr<Instruction>& instr, unsigned idx, ssa_info& info)
-{
-   Temp tmp = info.parent_instr->operands[0].getTemp();
-   SubdwordSel sel = parse_extract(info.parent_instr);
-
-   if (!sel) {
-      return false;
-   } else if (sel.size() == instr->operands[idx].bytes() && sel.size() == tmp.bytes() &&
-              tmp.type() == instr->operands[idx].regClass().type()) {
-      assert(tmp.type() != RegType::sgpr); /* No sub-dword SGPR regclasses */
-      return true;
-   } else if ((instr->opcode == aco_opcode::v_cvt_f32_u32 ||
-               instr->opcode == aco_opcode::v_cvt_f32_i32 ||
-               instr->opcode == aco_opcode::v_cvt_f32_ubyte0) &&
-              sel.size() == 1 && !sel.sign_extend() && !instr->usesModifiers()) {
-      return true;
-   } else if (instr->opcode == aco_opcode::v_lshlrev_b32 && instr->operands[0].isConstant() &&
-              sel.offset() == 0 && !instr->usesModifiers() &&
-              ((sel.size() == 2 && instr->operands[0].constantValue() >= 16u) ||
-               (sel.size() == 1 && instr->operands[0].constantValue() >= 24u))) {
-      return true;
-   } else if (instr->opcode == aco_opcode::v_mul_u32_u24 && ctx.program->gfx_level >= GFX10 &&
-              !instr->usesModifiers() && sel.size() == 2 && !sel.sign_extend() &&
-              (instr->operands[!idx].is16bit() ||
-               (instr->operands[!idx].isConstant() &&
-                instr->operands[!idx].constantValue() <= UINT16_MAX))) {
-      return true;
-   } else if (idx < 2 && can_use_SDWA(ctx.program->gfx_level, instr, true) &&
-              (tmp.type() == RegType::vgpr || ctx.program->gfx_level >= GFX9)) {
-      if (instr->isSDWA()) {
-         /* TODO: if we knew how many bytes this operand actually uses, we could have smaller
-          * second_dst parameter and apply more sign-extended sels.
-          */
-         return apply_extract_twice(sel, instr->operands[idx].getTemp(), instr->sdwa().sel[idx],
-                                    Temp(0, v1)) != SubdwordSel();
-      }
-      return true;
-   } else if (instr->isVALU() && sel.size() == 2 && !instr->valu().opsel[idx] &&
-              can_use_opsel(ctx.program->gfx_level, instr->opcode, idx)) {
-      return true;
-   } else if (instr->opcode == aco_opcode::s_pack_ll_b32_b16 && sel.size() == 2 &&
-              (idx == 1 || ctx.program->gfx_level >= GFX11 || !sel.offset())) {
-      return true;
-   } else if (sel.size() == 2 && ((instr->opcode == aco_opcode::s_pack_lh_b32_b16 && idx == 0) ||
-                                  (instr->opcode == aco_opcode::s_pack_hl_b32_b16 && idx == 1))) {
-      return true;
-   }
-
-   return false;
-}
-
 void
-check_sdwa_extract(opt_ctx& ctx, aco_ptr<Instruction>& instr)
+remove_operand_extract(opt_ctx& ctx, aco_ptr<Instruction>& instr)
 {
+   /* We checked these earlier in alu_propagate_temp_const */
+   if (instr->isSALU() || instr->isVALU())
+      return;
+
    for (unsigned i = 0; i < instr->operands.size(); i++) {
       Operand op = instr->operands[i];
       if (!op.isTemp())
          continue;
       ssa_info& info = ctx.info[op.tempId()];
-      if (info.is_extract() && (info.parent_instr->operands[0].getTemp().type() == RegType::vgpr ||
-                                op.getTemp().type() == RegType::sgpr)) {
-         if (!can_apply_extract(ctx, instr, i, info))
-            info.label &= ~label_extract;
-      }
+      info.label &= ~label_extract;
    }
 }
 
@@ -2238,6 +2167,17 @@ alu_propagate_temp_const(opt_ctx& ctx, aco_ptr<Instruction>& instr, bool uses_va
        instr->opcode == aco_opcode::v_min3_f32 || instr->opcode == aco_opcode::v_max3_f32 ||
        instr->opcode == aco_opcode::v_med3_f32);
 
+   bool remove_extract = !uses_valid;
+   /* GFX8: Don't remove label_extract if we can't apply the extract to
+    * neg/abs instructions because we'll likely combine it into another valu. */
+   if (instr->opcode == aco_opcode::v_mul_f16) {
+      for (Operand op : instr->operands)
+         remove_extract &= !op.constantEquals(0x3c00) && !op.constantEquals(0xbc00);
+   } else if (instr->opcode == aco_opcode::v_mul_f32) {
+      for (Operand op : instr->operands)
+         remove_extract &= !op.constantEquals(0x3f800000) && !op.constantEquals(0xbf800000);
+   }
+
    unsigned operand_mask = BITFIELD_MASK(info.operands.size());
 
    bool progress = false;
@@ -2278,7 +2218,7 @@ alu_propagate_temp_const(opt_ctx& ctx, aco_ptr<Instruction>& instr, bool uses_va
        */
       bool valu_new_sgpr = info.operands[i].op.isOfType(RegType::vgpr) &&
                            outer.op.isOfType(RegType::sgpr) && !instr->isVOP1();
-      if ((valu_new_sgpr || ctx.info[info.operands[i].op.tempId()].is_extract()) && !uses_valid) {
+      if (valu_new_sgpr && !uses_valid) {
          operand_mask &= ~BITFIELD_BIT(i);
          continue;
       }
@@ -2289,6 +2229,8 @@ alu_propagate_temp_const(opt_ctx& ctx, aco_ptr<Instruction>& instr, bool uses_va
          inner_type.bit_size = 16;
       bool flushes_denorms = inner_type.base_type == aco_base_type_float && !gfx8_min_max;
       if (!combine_operand(ctx, inner, inner_type, outer, outer_type, flushes_denorms)) {
+         if (remove_extract)
+            ctx.info[info.operands[i].op.tempId()].label &= ~label_extract;
          operand_mask &= ~BITFIELD_BIT(i);
          continue;
       }
@@ -2296,13 +2238,16 @@ alu_propagate_temp_const(opt_ctx& ctx, aco_ptr<Instruction>& instr, bool uses_va
       alu_opt_info info_copy = info;
       info_copy.operands[i] = inner;
       if (!alu_opt_info_is_valid(ctx, info_copy)) {
+         if (remove_extract)
+            ctx.info[info.operands[i].op.tempId()].label &= ~label_extract;
          operand_mask &= ~BITFIELD_BIT(i);
          continue;
       }
 
       bool has_lit = std::any_of(info_copy.operands.begin(), info_copy.operands.end(),
                                  [](const alu_opt_op& op) { return op.op.isLiteral(); });
-      if (!had_lit && has_lit) {
+      if ((!had_lit && has_lit) ||
+          (ctx.info[info.operands[i].op.tempId()].is_extract() && !uses_valid)) {
          operand_mask &= ~BITFIELD_BIT(i);
          continue;
       }
@@ -2546,7 +2491,7 @@ label_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
 
    /* if this instruction doesn't define anything, return */
    if (instr->definitions.empty()) {
-      check_sdwa_extract(ctx, instr);
+      remove_operand_extract(ctx, instr);
       return;
    }
 
@@ -2958,10 +2903,7 @@ label_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
    default: break;
    }
 
-   /* Don't remove label_extract if we can't apply the extract to
-    * neg/abs instructions because we'll likely combine it into another valu. */
-   if (!(ctx.info[instr->definitions[0].tempId()].label & (label_neg | label_abs)))
-      check_sdwa_extract(ctx, instr);
+   remove_operand_extract(ctx, instr);
 
    /* Set parent_instr for all SSA definitions. */
    for (const Definition& def : instr->definitions)
diff --git a/src/amd/compiler/tests/test_sdwa.cpp b/src/amd/compiler/tests/test_sdwa.cpp
index 3b110425aa8..91ac8ea08ac 100644
--- a/src/amd/compiler/tests/test_sdwa.cpp
+++ b/src/amd/compiler/tests/test_sdwa.cpp
@@ -316,6 +316,28 @@ BEGIN_TEST(optimize.sdwa.extract_modifiers)
    }
 END_TEST
 
+BEGIN_TEST(optimize.sdwa.extract_modifiers_fp16)
+   for (unsigned i = GFX8; i <= GFX10; i++) {
+      //>> v2b: %a:v[0][0:16], v1: %b:v[1] = p_startpgm
+      if (!setup_cs("v2b v1", (amd_gfx_level)i))
+         continue;
+
+      Temp hi = bld.pseudo(aco_opcode::p_extract_vector, bld.def(v2b), inputs[1], Operand::c32(1));
+
+      //! v2b: %res0 = v_mul_f16 -%b, %a dst_sel:uword0 dst_preserve src0_sel:uword1 src1_sel:uword0
+      //! p_unit_test 0, %res0
+      Temp fneg_hi = fneg(hi);
+      writeout(0, fmul(fneg_hi, inputs[0]));
+
+      //! v2b: %res1 = v_mul_f16 |%b|, %a dst_sel:uword0 dst_preserve src0_sel:uword1 src1_sel:uword0
+      //! p_unit_test 1, %res1
+      Temp fabs_hi = fabs(hi);
+      writeout(1, fmul(fabs_hi, inputs[0]));
+
+      finish_opt_test();
+   }
+END_TEST
+
 BEGIN_TEST(optimize.sdwa.extract.sgpr)
    for (unsigned i = GFX8; i <= GFX10; i++) {
       //>> v1: %a:v[0], v1: %b:v[1], s1: %c:s[0], s1: %d:s[1] = p_startpgm