diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp index f28aec9b73d..064b7d3213a 100644 --- a/src/amd/compiler/aco_instruction_selection.cpp +++ b/src/amd/compiler/aco_instruction_selection.cpp @@ -7964,8 +7964,9 @@ static void visit_cmat_muladd(isel_context* ctx, nir_intrinsic_instr* instr) { aco_opcode opcode = aco_opcode::num_opcodes; - unsigned signed_mask = 0; - bool clamp = false; + + bitarray8 neg_lo = nir_intrinsic_neg_lo_amd(instr); + bitarray8 neg_hi = nir_intrinsic_neg_hi_amd(instr); switch (instr->src[0].ssa->bit_size) { case 16: @@ -7974,12 +7975,14 @@ visit_cmat_muladd(isel_context* ctx, nir_intrinsic_instr* instr) case 16: opcode = aco_opcode::v_wmma_f16_16x16x16_f16; break; } break; - case 8: + case 8: { opcode = aco_opcode::v_wmma_i32_16x16x16_iu8; - signed_mask = nir_intrinsic_cmat_signed_mask(instr); - clamp = nir_intrinsic_saturate(instr); + unsigned signed_mask = nir_intrinsic_cmat_signed_mask(instr); + neg_lo[0] = signed_mask & NIR_CMAT_A_SIGNED; + neg_lo[1] = signed_mask & NIR_CMAT_B_SIGNED; break; } + } if (opcode == aco_opcode::num_opcodes) unreachable("visit_cmat_muladd: invalid bit size combination"); @@ -7992,9 +7995,9 @@ visit_cmat_muladd(isel_context* ctx, nir_intrinsic_instr* instr) Operand C(as_vgpr(ctx, get_ssa_temp(ctx, instr->src[2].ssa))); VALU_instruction& vop3p = bld.vop3p(opcode, Definition(dst), A, B, C, 0, 0x7)->valu(); - vop3p.neg_lo[0] = (signed_mask & 0x1) != 0; - vop3p.neg_lo[1] = (signed_mask & 0x2) != 0; - vop3p.clamp = clamp; + vop3p.neg_lo = neg_lo; + vop3p.neg_hi = neg_hi; + vop3p.clamp = nir_intrinsic_saturate(instr); emit_split_vector(ctx, dst, instr->def.num_components); } diff --git a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c index 109b9884b74..f3cdb195f84 100644 --- a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c +++ b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c @@ -419,13 +419,13 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev nir_def *A = radv_nir_load_cmat(&b, ¶ms, intr->src[1].ssa); nir_def *B = radv_nir_load_cmat(&b, ¶ms, intr->src[2].ssa); nir_def *C = radv_nir_load_cmat(&b, ¶ms, intr->src[3].ssa); - nir_def *ret; - ret = nir_cmat_muladd_amd(&b, A, B, C, .saturate = nir_intrinsic_saturate(intr), - .cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intr)); + nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr); - nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret, - nir_component_mask(ret->num_components)); + nir_def *ret = nir_cmat_muladd_amd(&b, A, B, C, .saturate = nir_intrinsic_saturate(intr), + .cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intr)); + + nir_store_deref(&b, dst_deref, ret, nir_component_mask(ret->num_components)); nir_instr_remove(instr); progress = true; break; diff --git a/src/compiler/nir/nir_intrinsics.py b/src/compiler/nir/nir_intrinsics.py index 8991ddb0c88..7b27cc10909 100644 --- a/src/compiler/nir/nir_intrinsics.py +++ b/src/compiler/nir/nir_intrinsics.py @@ -322,6 +322,8 @@ index("struct glsl_cmat_description", "cmat_desc") index("enum glsl_matrix_layout", "matrix_layout") index("nir_cmat_signed", "cmat_signed_mask") index("nir_op", "alu_op") +index("unsigned", "neg_lo_amd") +index("unsigned", "neg_hi_amd") # For Intel DPAS instrinsic. index("unsigned", "systolic_depth") @@ -1967,7 +1969,7 @@ intrinsic("strict_wqm_coord_amd", src_comp=[0], dest_comp=0, bit_sizes=[32], ind flags=[CAN_ELIMINATE]) intrinsic("cmat_muladd_amd", src_comp=[-1, -1, 0], dest_comp=0, bit_sizes=src2, - indices=[SATURATE, CMAT_SIGNED_MASK], flags=[CAN_ELIMINATE]) + indices=[SATURATE, NEG_LO_AMD, NEG_HI_AMD, CMAT_SIGNED_MASK], flags=[CAN_ELIMINATE]) # Get the debug log buffer descriptor. intrinsic("load_debug_log_desc_amd", bit_sizes=[32], dest_comp=4, flags=[CAN_ELIMINATE, CAN_REORDER])