nir: add support for cooperative matrix reduction operations.

This adds some new call operations to handle various parts of the
reductions.

cmat_reduce: the initial top-level operation coming from SPIR-V.
It is also used after lowering for row/col operations on single
hw-supported matrix sizes. On flex dimensions the SPIR-V operation is
lowered into multiple of these, but it can also be lowered into the
other operations.

cmat_reduce_finish:
after multiple reduction operations on a flexible dimension matrix,
there are often subsequent operations on the output matrices to
finish the operation.

cmat_reduce_2x2:
this takes 4 input matrices, and 1 dst to do a 2x2 reduction op.

Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/38389>
This commit is contained in:
Dave Airlie
2025-08-11 16:20:52 +10:00
committed by Marge Bot
parent 9385d94bc9
commit 438245404c
4 changed files with 258 additions and 31 deletions

View File

@@ -910,9 +910,14 @@ int
nir_cmat_call_op_params(nir_cmat_call_op op, nir_function *callee)
{
/* Returns the number of parameters a cmat call with the given op carries. */
switch (op) {
default:
/* Any op not listed below takes its parameter count from the callee. */
return callee->num_params;
case nir_cmat_call_op_reduce:
/* dst, src */
return 2;
case nir_cmat_call_op_reduce_finish:
/* dst, src0, src1 */
return 3;
case nir_cmat_call_op_reduce_2x2:
/* dst, src0, src1, src2, src3 */
return 5;
}
/* NOTE(review): every switch label above returns, so this looks unreachable as written. */
UNREACHABLE("Invalid cmat call op");
}
nir_cmat_call_instr *

View File

@@ -249,6 +249,12 @@ typedef enum {
NIR_CMAT_RESULT_SIGNED = 1u << 3,
} nir_cmat_signed;
/*
 * Bit flags selecting the kind of cooperative matrix reduction.
 * nir_cmat_call_reduce_flags() reads them back from const_index[0] of a
 * cmat call.
 */
typedef enum {
   /* Row reduction; may be combined with NIR_CMAT_REDUCE_COLUMN. */
   NIR_CMAT_REDUCE_ROW = 1u << 0,
   /* Column reduction; may be combined with NIR_CMAT_REDUCE_ROW. */
   NIR_CMAT_REDUCE_COLUMN = 1u << 1,
   /* 2x2 reduction. */
   NIR_CMAT_REDUCE_2X2 = 1u << 2,
} nir_cmat_reduce;
#define nir_const_value_to_array(arr, c, components, m) \
do { \
for (unsigned i = 0; i < components; ++i) \
@@ -1837,7 +1843,22 @@ typedef struct nir_call_instr {
#define NIR_CMAT_CALL_MAX_CONST_INDEX 1
typedef enum {
   nir_cmat_call_op_none,

   /*
    * Cooperative matrix row/column reduce.
    *    params:         (dst, src)
    *    const_index[0]: row/col reduce flags
    */
   nir_cmat_call_op_reduce,

   /*
    * Finishing step of a cooperative matrix reduce, used for split
    * flexible dimension matrices.
    *    params:         (dst, src0, src1)
    *    const_index[0]: row/col reduce flags
    */
   nir_cmat_call_op_reduce_finish,

   /*
    * Cooperative matrix 2x2 reduce.
    *    params: (dst, src0, src1, src2, src3)
    */
   nir_cmat_call_op_reduce_2x2,
} nir_cmat_call_op;
typedef struct nir_cmat_call_instr {
@@ -1852,6 +1873,11 @@ typedef struct nir_cmat_call_instr {
nir_src params[];
} nir_cmat_call_instr;
/* Return the reduce flags stored in const_index[0] of a cmat call. */
static inline nir_cmat_reduce
nir_cmat_call_reduce_flags(nir_cmat_call_instr *call)
{
   const nir_cmat_reduce flags = (nir_cmat_reduce)call->const_index[0];

   return flags;
}
#include "nir_intrinsics.h"
#define NIR_INTRINSIC_MAX_CONST_INDEX 8

View File

@@ -31,11 +31,21 @@ static struct split_mat *find_split(struct hash_table *split_mats,
return entry ? entry->data : NULL;
}
static struct nir_deref_instr *recreate_derefs(nir_builder *b, nir_intrinsic_instr *intr, int idx,
nir_variable *var)
/*
 * Look up the split information for the matrix variable behind parameter
 * `idx` of a cmat call. Returns NULL when that parameter is not a deref or
 * when its variable has no entry in `split_mats`.
 */
static struct split_mat *find_call_split(struct hash_table *split_mats,
nir_cmat_call_instr *call, int idx)
{
nir_deref_instr *deref = nir_src_as_deref(intr->src[idx]);
nir_deref_instr *deref = nir_src_as_deref(call->params[idx]);
if (!deref)
return NULL;
nir_variable *var = nir_deref_instr_get_variable(deref);
struct hash_entry *entry = _mesa_hash_table_search(split_mats, var);
return entry ? entry->data : NULL;
}
static struct nir_deref_instr *recreate_derefs(nir_builder *b, nir_src *src,
nir_variable *var)
{
nir_deref_instr *deref = nir_src_as_deref(*src);
nir_deref_path path;
nir_deref_path_init(&path, deref, NULL);
@@ -113,7 +123,7 @@ split_cmat_construct(nir_builder *b,
return false;
for (unsigned i = 0; i < splits; i++) {
nir_deref_instr *dst_deref = recreate_derefs(b, intr, 0, dst_split->split_vars[i]);
nir_deref_instr *dst_deref = recreate_derefs(b, &intr->src[0], dst_split->split_vars[i]);
b->cursor = nir_before_instr(instr);
nir_cmat_construct(b, &dst_deref->def, intr->src[1].ssa);
}
@@ -140,8 +150,8 @@ split_cmat_copy(nir_builder *b,
assert(src_split);
for (unsigned i = 0; i < splits; i++) {
nir_deref_instr *dst_deref = recreate_derefs(b, intr, 0, dst_split->split_vars[i]);
nir_deref_instr *src_deref = recreate_derefs(b, intr, 1, src_split->split_vars[i]);
nir_deref_instr *dst_deref = recreate_derefs(b, &intr->src[0], dst_split->split_vars[i]);
nir_deref_instr *src_deref = recreate_derefs(b, &intr->src[1], src_split->split_vars[i]);
b->cursor = nir_before_instr(instr);
nir_cmat_copy(b, &dst_deref->def, &src_deref->def);
}
@@ -211,8 +221,8 @@ split_cmat_insert(nir_builder *b,
nir_def *arr_idx = nir_udiv(b, intr->src[3].ssa, len);
nir_def *base_idx = nir_umod(b, intr->src[3].ssa, len);
for (unsigned i = 0; i < splits; i++) {
nir_deref_instr *dst_deref = recreate_derefs(b, intr, 0, dst_split->split_vars[i]);
nir_deref_instr *src_deref = recreate_derefs(b, intr, 2, src_split->split_vars[i]);
nir_deref_instr *dst_deref = recreate_derefs(b, &intr->src[0], dst_split->split_vars[i]);
nir_deref_instr *src_deref = recreate_derefs(b, &intr->src[2], src_split->split_vars[i]);
b->cursor = nir_before_instr(instr);
nir_def *new_def = nir_cmat_extract(b, nir_src_bit_size(intr->src[1]), &src_deref->def, base_idx);
@@ -247,7 +257,7 @@ split_cmat_extract(nir_builder *b,
nir_def *base_idx = nir_umod(b, intr->src[1].ssa, len);
nir_def *last_def = nir_undef(b, 1, intr->def.bit_size);
for (unsigned i = 0; i < splits; i++) {
nir_deref_instr *src_deref = recreate_derefs(b, intr, 0, src_split->split_vars[i]);
nir_deref_instr *src_deref = recreate_derefs(b, &intr->src[0], src_split->split_vars[i]);
b->cursor = nir_before_instr(instr);
nir_def *cond = nir_ieq_imm(b, arr_idx, i);
nir_def *new_def = nir_cmat_extract(b, intr->def.bit_size, &src_deref->def, base_idx);
@@ -274,8 +284,8 @@ split_cmat_convert(nir_builder *b,
unsigned splits = src_split->num_col_splits * src_split->num_row_splits;
for (unsigned i = 0; i < splits; i++) {
nir_deref_instr *dst_deref = recreate_derefs(b, intr, 0, dst_split->split_vars[i]);
nir_deref_instr *src_deref = recreate_derefs(b, intr, 1, src_split->split_vars[i]);
nir_deref_instr *dst_deref = recreate_derefs(b, &intr->src[0], dst_split->split_vars[i]);
nir_deref_instr *src_deref = recreate_derefs(b, &intr->src[1], src_split->split_vars[i]);
b->cursor = nir_before_instr(instr);
nir_cmat_convert(b, &dst_deref->def, &src_deref->def, .saturate = nir_intrinsic_saturate(intr),
.cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intr));
@@ -303,8 +313,8 @@ split_cmat_transpose(nir_builder *b,
for (unsigned c = 0; c < src_split->num_col_splits; c++) {
int in_idx = r * src_split->num_col_splits + c;
int out_idx = c * dst_split->num_col_splits + r;
nir_deref_instr *dst_deref = recreate_derefs(b, intr, 0, dst_split->split_vars[out_idx]);
nir_deref_instr *src_deref = recreate_derefs(b, intr, 1, src_split->split_vars[in_idx]);
nir_deref_instr *dst_deref = recreate_derefs(b, &intr->src[0], dst_split->split_vars[out_idx]);
nir_deref_instr *src_deref = recreate_derefs(b, &intr->src[1], src_split->split_vars[in_idx]);
b->cursor = nir_before_instr(instr);
nir_cmat_transpose(b, &dst_deref->def, &src_deref->def);
}
@@ -332,8 +342,8 @@ split_cmat_bitcast(nir_builder *b,
assert(src_split);
for (unsigned i = 0; i < splits; i++) {
nir_deref_instr *dst_deref = recreate_derefs(b, intr, 0, dst_split->split_vars[i]);
nir_deref_instr *src_deref = recreate_derefs(b, intr, 1, src_split->split_vars[i]);
nir_deref_instr *dst_deref = recreate_derefs(b, &intr->src[0], dst_split->split_vars[i]);
nir_deref_instr *src_deref = recreate_derefs(b, &intr->src[1], src_split->split_vars[i]);
b->cursor = nir_before_instr(instr);
nir_cmat_bitcast(b, &dst_deref->def, &src_deref->def);
}
@@ -362,9 +372,9 @@ split_cmat_binary_op(nir_builder *b,
assert(src1_split);
for (unsigned i = 0; i < splits; i++) {
nir_deref_instr *dst_deref = recreate_derefs(b, intr, 0, dst_split->split_vars[i]);
nir_deref_instr *src0_deref = recreate_derefs(b, intr, 1, src0_split->split_vars[i]);
nir_deref_instr *src1_deref = recreate_derefs(b, intr, 2, src1_split->split_vars[i]);
nir_deref_instr *dst_deref = recreate_derefs(b, &intr->src[0], dst_split->split_vars[i]);
nir_deref_instr *src0_deref = recreate_derefs(b, &intr->src[1], src0_split->split_vars[i]);
nir_deref_instr *src1_deref = recreate_derefs(b, &intr->src[2], src1_split->split_vars[i]);
b->cursor = nir_before_instr(instr);
nir_cmat_binary_op(b, &dst_deref->def, &src0_deref->def, &src1_deref->def,
.alu_op = nir_intrinsic_alu_op(intr));
@@ -392,8 +402,8 @@ split_cmat_unary_op(nir_builder *b,
assert(src_split);
for (unsigned i = 0; i < splits; i++) {
nir_deref_instr *dst_deref = recreate_derefs(b, intr, 0, dst_split->split_vars[i]);
nir_deref_instr *src_deref = recreate_derefs(b, intr, 1, src_split->split_vars[i]);
nir_deref_instr *dst_deref = recreate_derefs(b, &intr->src[0], dst_split->split_vars[i]);
nir_deref_instr *src_deref = recreate_derefs(b, &intr->src[1], src_split->split_vars[i]);
b->cursor = nir_before_instr(instr);
nir_cmat_unary_op(b, &dst_deref->def, &src_deref->def, .alu_op = nir_intrinsic_alu_op(intr));
}
@@ -420,8 +430,8 @@ split_cmat_scalar_op(nir_builder *b,
assert(src_split);
for (unsigned i = 0; i < splits; i++) {
nir_deref_instr *dst_deref = recreate_derefs(b, intr, 0, dst_split->split_vars[i]);
nir_deref_instr *src_deref = recreate_derefs(b, intr, 1, src_split->split_vars[i]);
nir_deref_instr *dst_deref = recreate_derefs(b, &intr->src[0], dst_split->split_vars[i]);
nir_deref_instr *src_deref = recreate_derefs(b, &intr->src[1], src_split->split_vars[i]);
b->cursor = nir_before_instr(instr);
nir_cmat_scalar_op(b, &dst_deref->def, &src_deref->def, intr->src[2].ssa,
.alu_op = nir_intrinsic_alu_op(intr));
@@ -472,14 +482,14 @@ split_cmat_muladd(nir_builder *b,
for (unsigned m = 0; m < m_splits; m++) {
for (unsigned n = 0; n < n_splits; n++) {
unsigned idx = m * n_splits + n;
nir_deref_instr *dst_deref = result_split ? recreate_derefs(b, intr, 0, result_split->split_vars[idx]) : nir_src_as_deref(intr->src[0]);
nir_deref_instr *c_deref = c_split ? recreate_derefs(b, intr, 3, c_split->split_vars[idx]) : nir_src_as_deref(intr->src[3]);
nir_deref_instr *dst_deref = result_split ? recreate_derefs(b, &intr->src[0], result_split->split_vars[idx]) : nir_src_as_deref(intr->src[0]);
nir_deref_instr *c_deref = c_split ? recreate_derefs(b, &intr->src[3], c_split->split_vars[idx]) : nir_src_as_deref(intr->src[3]);
for (unsigned k = 0; k < k_splits; k++) {
unsigned a_idx = m * k_splits + k;
unsigned b_idx = k * n_splits + n;
nir_deref_instr *a_deref = a_split ? recreate_derefs(b, intr, 1, a_split->split_vars[a_idx]) : nir_src_as_deref(intr->src[1]);
nir_deref_instr *b_deref = b_split ? recreate_derefs(b, intr, 2, b_split->split_vars[b_idx]) : nir_src_as_deref(intr->src[2]);
nir_deref_instr *a_deref = a_split ? recreate_derefs(b, &intr->src[1], a_split->split_vars[a_idx]) : nir_src_as_deref(intr->src[1]);
nir_deref_instr *b_deref = b_split ? recreate_derefs(b, &intr->src[2], b_split->split_vars[b_idx]) : nir_src_as_deref(intr->src[2]);
nir_deref_instr *k_dst_deref = k == k_splits - 1 ? dst_deref : c_deref;
b->cursor = nir_before_instr(instr);
@@ -494,6 +504,167 @@ split_cmat_muladd(nir_builder *b,
return true;
}
/*
 * Create a nir_cmat_call_op_reduce call and insert it through the builder.
 *
 * The new call reuses the callee of `call`, takes (dst, src0) as its
 * parameters and stores `reduce` in const_index[0].
 */
static void
call_reduce(nir_builder *b,
            nir_cmat_call_instr *call,
            nir_cmat_reduce reduce,
            nir_def *dst, nir_def *src0)
{
   nir_cmat_call_instr *reduce_call =
      nir_cmat_call_instr_create(b->shader, nir_cmat_call_op_reduce, call->callee);

   reduce_call->params[0] = nir_src_for_ssa(dst);
   reduce_call->params[1] = nir_src_for_ssa(src0);
   reduce_call->const_index[0] = reduce;

   nir_builder_instr_insert(b, &reduce_call->instr);
}
/*
 * Create a nir_cmat_call_op_reduce_finish call and insert it through the
 * builder.
 *
 * The new call reuses the callee of `call`, takes (dst, src0, src1) as its
 * parameters and stores `reduce` in const_index[0].
 */
static void
call_reduce_finish(nir_builder *b,
                   nir_cmat_call_instr *call,
                   nir_cmat_reduce reduce,
                   nir_def *dst, nir_def *src0, nir_def *src1)
{
   nir_def *const operands[3] = { dst, src0, src1 };
   nir_cmat_call_instr *finish_call =
      nir_cmat_call_instr_create(b->shader, nir_cmat_call_op_reduce_finish, call->callee);

   for (unsigned i = 0; i < 3; i++)
      finish_call->params[i] = nir_src_for_ssa(operands[i]);
   finish_call->const_index[0] = reduce;

   nir_builder_instr_insert(b, &finish_call->instr);
}
/*
 * Create a nir_cmat_call_op_reduce_2x2 call and insert it through the
 * builder.
 *
 * The new call reuses the callee of `call` and takes
 * (dst, src0, src1, src2, src3) as its parameters. No const index is
 * written here.
 */
static void
call_reduce_2x2(nir_builder *b,
                nir_cmat_call_instr *call,
                nir_def *dst,
                nir_def *src0, nir_def *src1,
                nir_def *src2, nir_def *src3)
{
   nir_def *const operands[5] = { dst, src0, src1, src2, src3 };
   nir_cmat_call_instr *reduce_call =
      nir_cmat_call_instr_create(b->shader, nir_cmat_call_op_reduce_2x2, call->callee);

   for (unsigned i = 0; i < 5; i++)
      reduce_call->params[i] = nir_src_for_ssa(operands[i]);

   nir_builder_instr_insert(b, &reduce_call->instr);
}
/*
 * Rewrite a nir_cmat_call_op_reduce call in terms of the split matrices
 * recorded in info->split_mats, then remove the original call.
 *
 * Row/column reduce (NIR_CMAT_REDUCE_ROW and/or NIR_CMAT_REDUCE_COLUMN):
 *   1. every source split is reduced on its own into a fresh local
 *      temporary (a single temporary when the source is not split),
 *   2. the partial results are folded into temporary 0 with
 *      reduce_finish calls,
 *   3. the selected temporary is copied into each destination split (or
 *      into the unsplit destination).
 *
 * 2x2 reduce (NIR_CMAT_REDUCE_2X2): each destination matrix is produced by
 * one reduce_2x2 call over a 2x2 group of neighbouring source splits.
 *
 * All new instructions are emitted immediately before the original call.
 * Always returns true.
 */
static bool
split_cmat_call_reduce(nir_builder *b,
                       nir_function_impl *impl,
                       nir_cmat_call_instr *call,
                       struct split_info *info)
{
   const unsigned row_col_mask = NIR_CMAT_REDUCE_ROW | NIR_CMAT_REDUCE_COLUMN;
   nir_instr *instr = &call->instr;
   nir_cmat_reduce reduce = nir_cmat_call_reduce_flags(call);
   struct split_mat *dst_split = find_call_split(info->split_mats, call, 0);
   struct split_mat *src_split = find_call_split(info->split_mats, call, 1);

   if (reduce & row_col_mask) {
      assert(!(reduce & ~row_col_mask));
      const bool reduce_both = (reduce & row_col_mask) == row_col_mask;

      /* for each source split - reduce it by itself. */
      unsigned src_splits = 1;
      if (src_split)
         src_splits = src_split->num_col_splits * src_split->num_row_splits;

      nir_deref_instr **temp_derefs = ralloc_array(NULL, nir_deref_instr *, src_splits);

      const struct glsl_type *temp_type =
         nir_deref_instr_get_variable(nir_src_as_deref(call->params[1]))->type;
      if (src_splits > 1)
         temp_type = src_split->split_vars[0]->type;

      /* Nothing has positioned the cursor yet in this function: place the
       * temporary derefs right before the call so they are emitted at a
       * known location ahead of their uses.
       */
      b->cursor = nir_before_instr(instr);
      for (unsigned i = 0; i < src_splits; i++) {
         nir_variable *temp_var = nir_local_variable_create(impl, temp_type,
                                                            "reduce_split_srcs");
         temp_derefs[i] = nir_build_deref_var(b, temp_var);
      }

      if (src_splits > 1) {
         /* reduce each individual src matrix */
         for (unsigned i = 0; i < src_splits; i++) {
            nir_deref_instr *src_deref = recreate_derefs(b, &call->params[1], src_split->split_vars[i]);
            /* The cursor is re-set after every recreate_derefs() call. */
            b->cursor = nir_before_instr(instr);
            call_reduce(b, call, reduce, &temp_derefs[i]->def, &src_deref->def);
         }

         /* Fold the partial results into temp_derefs[0]. */
         if (reduce_both) {
            for (unsigned i = 1; i < src_splits; i++) {
               nir_deref_instr *second_deref = temp_derefs[i];
               b->cursor = nir_before_instr(instr);
               call_reduce_finish(b, call, reduce, &temp_derefs[0]->def, &temp_derefs[0]->def, &second_deref->def);
            }
         } else if (reduce & NIR_CMAT_REDUCE_ROW) {
            /* NOTE(review): only the first row of splits (indices
             * 0..num_col_splits-1) is folded here - confirm this is
             * intended when the source has more than one row split.
             */
            for (unsigned i = 1; i < src_split->num_col_splits; i++) {
               nir_deref_instr *second_deref = temp_derefs[i];
               b->cursor = nir_before_instr(instr);
               call_reduce_finish(b, call, reduce, &temp_derefs[0]->def, &temp_derefs[0]->def, &second_deref->def);
            }
         } else if (reduce & NIR_CMAT_REDUCE_COLUMN) {
            /* Walk down the first column of splits. */
            for (unsigned i = 1; i < src_split->num_row_splits; i++) {
               nir_deref_instr *second_deref = temp_derefs[i * src_split->num_col_splits];
               b->cursor = nir_before_instr(instr);
               call_reduce_finish(b, call, reduce, &temp_derefs[0]->def, &temp_derefs[0]->def, &second_deref->def);
            }
         }
      } else {
         b->cursor = nir_before_instr(instr);
         call_reduce(b, call, reduce, &temp_derefs[0]->def, &nir_src_as_deref(call->params[1])->def);
      }

      /* at this point temp_derefs should contain all the split reduced src
       * matrices, now to store them
       */
      if (dst_split) {
         for (unsigned r = 0; r < dst_split->num_row_splits; r++) {
            for (unsigned c = 0; c < dst_split->num_col_splits; c++) {
               unsigned didx = r * dst_split->num_col_splits + c;
               /* Defaults to the fully reduced temporary; one of the
                * branches below always applies because at least one of the
                * row/column flags is set.
                */
               unsigned idx = 0;
               if (reduce_both)
                  idx = 0;
               else if (reduce & NIR_CMAT_REDUCE_ROW)
                  idx = r % (src_split ? src_split->num_row_splits : 1);
               else if (reduce & NIR_CMAT_REDUCE_COLUMN)
                  idx = c % (src_split ? src_split->num_col_splits : 1);

               nir_deref_instr *deref = recreate_derefs(b, &call->params[0], dst_split->split_vars[didx]);
               b->cursor = nir_before_instr(instr);
               nir_cmat_copy(b, &deref->def, &temp_derefs[idx]->def);
            }
         }
      } else {
         b->cursor = nir_before_instr(instr);
         nir_cmat_copy(b, call->params[0].ssa, &temp_derefs[0]->def);
      }
      ralloc_free(temp_derefs);
   } else if (reduce & NIR_CMAT_REDUCE_2X2) {
      assert(reduce == NIR_CMAT_REDUCE_2X2);
      /* dst can have target dimensions, but src must be at least twice as large */
      assert(src_split);

      unsigned rows = 1, cols = 1;
      if (dst_split) {
         rows = dst_split->num_row_splits;
         cols = dst_split->num_col_splits;
      }

      /* Every destination matrix consumes a 2x2 group of source splits, so
       * the source needs at least twice as many splits in each direction;
       * otherwise the split_vars[] indexing below would run out of bounds.
       */
      assert(src_split->num_row_splits >= 2 * rows);
      assert(src_split->num_col_splits >= 2 * cols);

      for (unsigned r = 0; r < rows; r++) {
         for (unsigned c = 0; c < cols; c++) {
            unsigned d_idx = c + r * cols;
            unsigned src_top_left_col = c * 2;
            unsigned src_top_left_row = r * 2;
            unsigned src_top_idx = src_top_left_col + src_top_left_row * src_split->num_col_splits;
            unsigned src_bottom_idx = src_top_left_col + (src_top_left_row + 1) * src_split->num_col_splits;

            nir_deref_instr *src0_deref = recreate_derefs(b, &call->params[1], src_split->split_vars[src_top_idx]);
            nir_deref_instr *src1_deref = recreate_derefs(b, &call->params[1], src_split->split_vars[src_top_idx + 1]);
            nir_deref_instr *src2_deref = recreate_derefs(b, &call->params[1], src_split->split_vars[src_bottom_idx]);
            nir_deref_instr *src3_deref = recreate_derefs(b, &call->params[1], src_split->split_vars[src_bottom_idx + 1]);
            nir_deref_instr *dst_deref = dst_split ? recreate_derefs(b, &call->params[0], dst_split->split_vars[d_idx]) : nir_src_as_deref(call->params[0]);

            b->cursor = nir_before_instr(instr);
            call_reduce_2x2(b, call, &dst_deref->def, &src0_deref->def, &src1_deref->def, &src2_deref->def, &src3_deref->def);
         }
      }
   }

   nir_instr_remove(instr);
   return true;
}
static bool
split_cmat_load_store(nir_builder *b,
nir_intrinsic_instr *intr,
@@ -510,7 +681,7 @@ split_cmat_load_store(nir_builder *b,
struct split_mat *split = entry->data;
unsigned splits = split->num_row_splits * split->num_col_splits;
for (unsigned i = 0; i < splits; i++) {
nir_deref_instr *new_deref = recreate_derefs(b, intr, !is_load, split->split_vars[i]);
nir_deref_instr *new_deref = recreate_derefs(b, &intr->src[!is_load], split->split_vars[i]);
nir_deref_instr *ptr_deref;
nir_def *stride = intr->src[2].ssa;
nir_def *ptr = intr->src[is_load].ssa;
@@ -610,6 +781,17 @@ split_matrix_impl(nir_function_impl *impl, struct split_info *info)
}
break;
}
case nir_instr_type_cmat_call: {
nir_cmat_call_instr *cmat_call = nir_instr_as_cmat_call(instr);
switch (cmat_call->op) {
case nir_cmat_call_op_reduce:
progress |= split_cmat_call_reduce(&b, impl, cmat_call, info);
break;
default:
break;
}
break;
}
default:
break;
}

View File

@@ -2051,6 +2051,20 @@ print_call_instr(nir_call_instr *instr, print_state *state)
}
}
/*
 * Return the mnemonic printed for a cmat call op.
 *
 * nir_cmat_call_op_none is given the generic "cmat_call" label so that
 * printing a call without a specific op does not hit UNREACHABLE.
 */
static const char *
get_cmat_call_op_str(nir_cmat_call_op op)
{
   switch (op) {
   case nir_cmat_call_op_none:
      return "cmat_call";
   case nir_cmat_call_op_reduce:
      return "cmat_call_reduce";
   case nir_cmat_call_op_reduce_finish:
      return "cmat_call_reduce_finish";
   case nir_cmat_call_op_reduce_2x2:
      return "cmat_call_reduce_2x2";
   }
   UNREACHABLE("Unknown cmat call op");
}
static void
print_cmat_call_instr(nir_cmat_call_instr *instr, print_state *state)
{
@@ -2058,7 +2072,7 @@ print_cmat_call_instr(nir_cmat_call_instr *instr, print_state *state)
print_no_dest_padding(state);
fprintf(fp, "cmat_call %s ", instr->callee->name);
fprintf(fp, "%s %s ", get_cmat_call_op_str(instr->op), instr->callee->name);
for (unsigned i = 0; i < instr->num_params; i++) {
if (i != 0)