diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c index 0e1a81aa7f9..9f78561943d 100644 --- a/src/compiler/spirv/spirv_to_nir.c +++ b/src/compiler/spirv/spirv_to_nir.c @@ -73,6 +73,7 @@ static const struct spirv_capabilities implemented_capabilities = { .ComputeDerivativeGroupQuadsKHR = true, .CooperativeMatrixKHR = true, .CooperativeMatrixConversionsNV = true, + .CooperativeMatrixReductionsNV = true, .CoreBuiltinsARM = true, .CullDistance = true, .DemoteToHelperInvocation = true, @@ -7007,6 +7008,7 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode, case SpvOpCooperativeMatrixMulAddKHR: case SpvOpCooperativeMatrixConvertNV: case SpvOpCooperativeMatrixTransposeNV: + case SpvOpCooperativeMatrixReduceNV: vtn_handle_cooperative_instruction(b, opcode, w, count); break; diff --git a/src/compiler/spirv/vtn_cmat.c b/src/compiler/spirv/vtn_cmat.c index 6cefd027139..1390269d5ff 100644 --- a/src/compiler/spirv/vtn_cmat.c +++ b/src/compiler/spirv/vtn_cmat.c @@ -220,6 +220,24 @@ vtn_handle_cooperative_instruction(struct vtn_builder *b, SpvOp opcode, break; } + case SpvOpCooperativeMatrixReduceNV: { + struct vtn_type *dst_type = vtn_get_type(b, w[1]); + nir_deref_instr *src = vtn_get_cmat_deref(b, w[3]); + + struct vtn_function *reduce_fn = vtn_value(b, w[5], vtn_value_type_function)->func; + + reduce_fn->referenced = true; + reduce_fn->nir_func->cmat_call = true; + nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_reduce_nv"); + nir_cmat_call_instr *call = nir_cmat_call_instr_create(b->nb.shader, nir_cmat_call_op_reduce, reduce_fn->nir_func); + call->params[0] = nir_src_for_ssa(&dst->def); + call->params[1] = nir_src_for_ssa(&src->def); + call->const_index[0] = w[4]; + nir_builder_instr_insert(&b->nb, &call->instr); + vtn_push_var_ssa(b, w[2], dst->var); + break; + } + default: UNREACHABLE("Unexpected opcode for cooperative matrix instruction"); }