diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp
index 5cf3db281cb..1a3924b4c49 100644
--- a/src/amd/compiler/aco_instruction_selection.cpp
+++ b/src/amd/compiler/aco_instruction_selection.cpp
@@ -700,6 +700,20 @@ Temp get_alu_src(struct isel_context *ctx, nir_alu_src src, unsigned size=1)
    }
 }
 
+/* Queries nir_unsigned_upper_bound() for source "src_idx" of "instr", using
+ * the shader, range hash table and upper-bound config stored in "ctx".
+ *
+ * Only the component selected by swizzle[0] is examined. visit_alu_instr()
+ * compares the result against 0xffffff to decide whether a 24-bit multiply
+ * opcode can be used.
+ */
+uint32_t get_alu_src_ub(isel_context *ctx, nir_alu_instr *instr, int src_idx)
+{
+   nir_ssa_scalar scalar = {instr->src[src_idx].src.ssa,
+                            instr->src[src_idx].swizzle[0]};
+   return nir_unsigned_upper_bound(ctx->shader, ctx->range_ht, scalar, &ctx->ub_config);
+}
+
 Temp convert_pointer_to_64_bit(isel_context *ctx, Temp ptr)
 {
    if (ptr.size() == 2)
@@ -1656,7 +1670,14 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
    }
    case nir_op_imul: {
       if (dst.regClass() == v1) {
-         emit_vop3a_instruction(ctx, instr, aco_opcode::v_mul_lo_u32, dst);
+         uint32_t src0_ub = get_alu_src_ub(ctx, instr, 0);
+         uint32_t src1_ub = get_alu_src_ub(ctx, instr, 1);
+
+         if (src0_ub <= 0xffffff && src1_ub <= 0xffffff) {
+            emit_vop2_instruction(ctx, instr, aco_opcode::v_mul_u32_u24, dst, true);
+         } else {
+            emit_vop3a_instruction(ctx, instr, aco_opcode::v_mul_lo_u32, dst);
+         }
       } else if (dst.regClass() == s1) {
          emit_sop2_instruction(ctx, instr, aco_opcode::s_mul_i32, dst, false);
       } else {
@@ -1665,14 +1686,21 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       break;
    }
    case nir_op_umul_high: {
-      if (dst.regClass() == v1) {
-         emit_vop3a_instruction(ctx, instr, aco_opcode::v_mul_hi_u32, dst);
-      } else if (dst.regClass() == s1 && ctx->options->chip_class >= GFX9) {
+      if (dst.regClass() == s1 && ctx->options->chip_class >= GFX9) {
          emit_sop2_instruction(ctx, instr, aco_opcode::s_mul_hi_u32, dst, false);
-      } else if (dst.regClass() == s1) {
-         Temp tmp = bld.vop3(aco_opcode::v_mul_hi_u32, bld.def(v1), get_alu_src(ctx, instr->src[0]),
-                             as_vgpr(ctx, get_alu_src(ctx, instr->src[1])));
-         bld.pseudo(aco_opcode::p_as_uniform, Definition(dst), tmp);
+      } else if (dst.bytes() == 4) {
+         uint32_t src0_ub = get_alu_src_ub(ctx, instr, 0);
+         uint32_t src1_ub = get_alu_src_ub(ctx, instr, 1);
+
+         Temp tmp = dst.regClass() == s1 ? bld.tmp(v1) : dst;
+         if (src0_ub <= 0xffffff && src1_ub <= 0xffffff) {
+            emit_vop3a_instruction(ctx, instr, aco_opcode::v_mul_hi_u32_u24, tmp);
+         } else {
+            emit_vop3a_instruction(ctx, instr, aco_opcode::v_mul_hi_u32, tmp);
+         }
+
+         if (dst.regClass() == s1)
+            bld.pseudo(aco_opcode::p_as_uniform, Definition(dst), tmp);
       } else {
          isel_err(&instr->instr, "Unimplemented NIR instr bit size");
       }