brw: Fix cmat conversion between bfloat16 and non-float32

The HW only supports converting BRW_TYPE_BF values to/from BRW_TYPE_F,
so intermediate conversion is needed.  Move the intermediate conversion
to the implementation of `@convert_cmat_intel` and simplify the
brw_nir_lower_cooperative_matrix pass.  This has two positive effects

- Fixes conversion between BF and integer type cooperative matrices,
  that was still using the old emit_alu1 approach instead of the new
  code for `@convert_cmat_intel`.

- Guarantee the intermediate conversion will result in a valid layout
  for conversions involved USE_B matrices.  If we instead used the
  intrinsic twice in brw_nir_lower_cooperative_matrix.c, a matrix with
  invalid layout would be visible at NIR level and we wouldn't be able
  to keep the current assertion for USE_B case.

Due to the configurations we have exposed, we still don't need to
write a more complex USE_B conversion -- they are all between same
size types (and, consequently, packing factors), so no shuffling of
data is needed to respect the USE_B layout.

Reviewed-by: Matt Turner <mattst88@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/36185>
This commit is contained in:
Caio Oliveira
2025-07-16 16:06:02 -07:00
committed by Marge Bot
parent 557ac588e4
commit 2dfd4dcbc5
2 changed files with 28 additions and 42 deletions
+15 -1
View File
@@ -4885,11 +4885,15 @@ brw_from_nir_emit_cs_intrinsic(nir_to_brw_state &ntb,
const unsigned elems = src_components * src_packing_factor;
brw_builder bldn = bld.exec_all();
const brw_reg src = retype(get_nir_src(ntb, instr->src[0], 0), src_type);
brw_reg src = retype(get_nir_src(ntb, instr->src[0], 0), src_type);
const brw_reg dst = retype(dest, dst_type);
assert(dst_cmat_desc.use == src_cmat_desc.use);
const bool needs_intermediate =
(src.type == BRW_TYPE_BF && dst.type != BRW_TYPE_F) ||
(dst.type == BRW_TYPE_BF && src.type != BRW_TYPE_F);
switch (src_cmat_desc.use) {
case GLSL_CMAT_USE_B:
assert(dst_element_bits == src_element_bits);
@@ -4898,6 +4902,16 @@ brw_from_nir_emit_cs_intrinsic(nir_to_brw_state &ntb,
case GLSL_CMAT_USE_A:
case GLSL_CMAT_USE_ACCUMULATOR: {
const unsigned width = bldn.dispatch_width();
if (needs_intermediate) {
brw_reg tmp = bldn.vgrf(BRW_TYPE_F, elems);
for (unsigned c = 0; c < elems; c++) {
bldn.MOV(suboffset(tmp, c * width),
suboffset(src, c * width));
}
src = tmp;
}
for (unsigned c = 0; c < elems; c++) {
bldn.MOV(suboffset(dst, c * width),
suboffset(src, c * width));