From a8115221e596a8bed7a64799ccc03aa9ad225d92 Mon Sep 17 00:00:00 2001 From: Ian Romanick Date: Tue, 26 Mar 2024 16:14:33 -0700 Subject: [PATCH] nir: intel/brw: Change the order of sources for nir_dpas_intel It was by pure luck that all sources (and the result) of nir_dpas_intel had the same number of components. It is possible to support matrix sizes where the accumulator matrix and the result matrix are larger (e.g., 16x8 * 8x16 = 16x16). This breaks all of the assumptions of NIR's infrastructure for code generating intrinsics. Fix this by making the accumulator matrix be the first source. The accumulator and the result will always have the same dimensions (due to rules of matrix multiplication) and the same type (due to restrictions of the cooperative matrix extension). This forces them to have the same number of components. This doesn't fix all the potential problems. NIR expects that all 0-sized sources will have the same number of components. This just ensures that the result has the correct number of components. 
Fixes: 6b14da33ad3 ("intel/fs: nir: Add nir_intrinsic_dpas_intel") Reviewed-by: Jordan Justen Part-of: --- src/compiler/nir/nir_intrinsics.py | 10 +++++++--- src/intel/compiler/brw_fs_nir.cpp | 18 +++++++++--------- .../brw_nir_lower_cooperative_matrix.c | 2 +- 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/src/compiler/nir/nir_intrinsics.py b/src/compiler/nir/nir_intrinsics.py index 39f57c22de2..fdb5d9d51b1 100644 --- a/src/compiler/nir/nir_intrinsics.py +++ b/src/compiler/nir/nir_intrinsics.py @@ -2116,11 +2116,15 @@ system_value("leaf_procedural_intel", 1, bit_sizes=[1]) system_value("btd_shader_type_intel", 1) system_value("ray_query_global_intel", 1, bit_sizes=[64]) -# Source 0: A matrix (type specified by SRC_TYPE) -# Source 1: B matrix (type specified by SRC_TYPE) -# Source 2: Accumulator matrix (type specified by DEST_TYPE) +# Source 0: Accumulator matrix (type specified by DEST_TYPE) +# Source 1: A matrix (type specified by SRC_TYPE) +# Source 2: B matrix (type specified by SRC_TYPE) # # The matrix parameters are the slices owned by the invocation. +# +# The accumulator is source 0 because that is the source the intrinsic +# infrastructure in NIR uses to determine the number of components in the +# result. 
intrinsic("dpas_intel", dest_comp=0, src_comp=[0, 0, 0], indices=[DEST_TYPE, SRC_TYPE, SATURATE, CMAT_SIGNED_MASK, SYSTOLIC_DEPTH, REPEAT_COUNT], flags=[CAN_ELIMINATE]) diff --git a/src/intel/compiler/brw_fs_nir.cpp b/src/intel/compiler/brw_fs_nir.cpp index 6811d57d67b..593a27e1f00 100644 --- a/src/intel/compiler/brw_fs_nir.cpp +++ b/src/intel/compiler/brw_fs_nir.cpp @@ -4516,7 +4516,7 @@ fs_nir_emit_cs_intrinsic(nir_to_brw_state &ntb, brw_type_for_nir_type(devinfo, nir_intrinsic_src_type(instr)); dest = retype(dest, dest_type); - fs_reg src2 = retype(get_nir_src(ntb, instr->src[2]), dest_type); + fs_reg src0 = retype(get_nir_src(ntb, instr->src[0]), dest_type); const fs_reg dest_hf = dest; fs_builder bld8 = bld.exec_all().group(8, 0); @@ -4532,24 +4532,24 @@ fs_nir_emit_cs_intrinsic(nir_to_brw_state &ntb, !s.compiler->lower_dpas) { dest = bld8.vgrf(BRW_REGISTER_TYPE_F, rcount); - if (src2.file != ARF) { - const fs_reg src2_hf = src2; + if (src0.file != ARF) { + const fs_reg src0_hf = src0; - src2 = bld8.vgrf(BRW_REGISTER_TYPE_F, rcount); + src0 = bld8.vgrf(BRW_REGISTER_TYPE_F, rcount); for (unsigned i = 0; i < 4; i++) { - bld16.MOV(byte_offset(src2, REG_SIZE * i * 2), - byte_offset(src2_hf, REG_SIZE * i)); + bld16.MOV(byte_offset(src0, REG_SIZE * i * 2), + byte_offset(src0_hf, REG_SIZE * i)); } } else { - src2 = retype(src2, BRW_REGISTER_TYPE_F); + src0 = retype(src0, BRW_REGISTER_TYPE_F); } } bld8.DPAS(dest, - src2, + src0, + retype(get_nir_src(ntb, instr->src[2]), src_type), retype(get_nir_src(ntb, instr->src[1]), src_type), - retype(get_nir_src(ntb, instr->src[0]), src_type), sdepth, rcount) ->saturate = nir_intrinsic_saturate(instr); diff --git a/src/intel/compiler/brw_nir_lower_cooperative_matrix.c b/src/intel/compiler/brw_nir_lower_cooperative_matrix.c index f8743d89691..809aa7f456d 100644 --- a/src/intel/compiler/brw_nir_lower_cooperative_matrix.c +++ b/src/intel/compiler/brw_nir_lower_cooperative_matrix.c @@ -649,9 +649,9 @@ lower_cmat_instr(nir_builder 
*b, nir_instr *instr, void *_state) nir_def *result = nir_dpas_intel(b, packing_factor * glsl_base_type_get_bit_size(dst_desc.element_type), + nir_load_deref(b, accum_slice), nir_load_deref(b, A_slice), nir_load_deref(b, B_slice), - nir_load_deref(b, accum_slice), .dest_type = nir_get_nir_type_for_glsl_base_type(dst_desc.element_type), .src_type = nir_get_nir_type_for_glsl_base_type(src_desc.element_type), .saturate = nir_intrinsic_saturate(intrin),