spirv: add support for cooperative matrix reduction operation

Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/38389>
This commit is contained in:
Dave Airlie
2025-08-11 16:24:13 +10:00
committed by Marge Bot
parent 438245404c
commit a4a0d28ea6
2 changed files with 20 additions and 0 deletions

View File

@@ -73,6 +73,7 @@ static const struct spirv_capabilities implemented_capabilities = {
.ComputeDerivativeGroupQuadsKHR = true,
.CooperativeMatrixKHR = true,
.CooperativeMatrixConversionsNV = true,
.CooperativeMatrixReductionsNV = true,
.CoreBuiltinsARM = true,
.CullDistance = true,
.DemoteToHelperInvocation = true,
@@ -7007,6 +7008,7 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
case SpvOpCooperativeMatrixMulAddKHR:
case SpvOpCooperativeMatrixConvertNV:
case SpvOpCooperativeMatrixTransposeNV:
case SpvOpCooperativeMatrixReduceNV:
vtn_handle_cooperative_instruction(b, opcode, w, count);
break;

View File

@@ -220,6 +220,24 @@ vtn_handle_cooperative_instruction(struct vtn_builder *b, SpvOp opcode,
break;
}
case SpvOpCooperativeMatrixReduceNV: {
struct vtn_type *dst_type = vtn_get_type(b, w[1]);
nir_deref_instr *src = vtn_get_cmat_deref(b, w[3]);
struct vtn_function *reduce_fn = vtn_value(b, w[5], vtn_value_type_function)->func;
reduce_fn->referenced = true;
reduce_fn->nir_func->cmat_call = true;
nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_reduce_nv");
nir_cmat_call_instr *call = nir_cmat_call_instr_create(b->nb.shader, nir_cmat_call_op_reduce, reduce_fn->nir_func);
call->params[0] = nir_src_for_ssa(&dst->def);
call->params[1] = nir_src_for_ssa(&src->def);
call->const_index[0] = w[4];
nir_builder_instr_insert(&b->nb, &call->instr);
vtn_push_var_ssa(b, w[2], dst->var);
break;
}
default:
UNREACHABLE("Unexpected opcode for cooperative matrix instruction");
}