nir: add coopmat per element operations.
Cooperative matrices have per-element calls that take variable arguments coming from SPIR-V. These use the new call op enum. Reviewed-by: Georg Lehmann <dadschoorse@gmail.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/36992>
This commit is contained in:
@@ -913,6 +913,8 @@ int
|
||||
nir_cmat_call_op_params(nir_cmat_call_op op, nir_function *callee)
|
||||
{
|
||||
switch (op) {
|
||||
case nir_cmat_call_op_per_element_op:
|
||||
return callee->num_params;
|
||||
case nir_cmat_call_op_reduce:
|
||||
return 2;
|
||||
case nir_cmat_call_op_reduce_finish:
|
||||
|
||||
@@ -1859,6 +1859,11 @@ typedef enum {
|
||||
* reduce 2x2 dst, src0, src1, src2, src3.
|
||||
*/
|
||||
nir_cmat_call_op_reduce_2x2,
|
||||
/*
|
||||
* Cooperative matrix per-element operation call
|
||||
* per-element dst, row offset, col offset, src
|
||||
*/
|
||||
nir_cmat_call_op_per_element_op,
|
||||
} nir_cmat_call_op;
|
||||
|
||||
typedef struct nir_cmat_call_instr {
|
||||
|
||||
@@ -724,6 +724,50 @@ split_cmat_load_store(nir_builder *b,
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool
|
||||
split_cmat_call_per_element_op(nir_builder *b,
|
||||
nir_cmat_call_instr *call,
|
||||
struct split_info *info)
|
||||
{
|
||||
nir_instr *instr = &call->instr;
|
||||
struct split_mat *dst_split = find_call_split(info->split_mats, call, 0);
|
||||
struct split_mat *src_split = find_call_split(info->split_mats, call, 3);
|
||||
if (!dst_split)
|
||||
return false;
|
||||
|
||||
assert(src_split);
|
||||
int splits = dst_split->num_col_splits * dst_split->num_row_splits;
|
||||
if (splits <= 1)
|
||||
return false;
|
||||
|
||||
for (unsigned r = 0; r < dst_split->num_row_splits; r++) {
|
||||
for (unsigned c = 0; c < dst_split->num_col_splits; c++) {
|
||||
int idx = r * dst_split->num_col_splits + c;
|
||||
nir_deref_instr *dst_deref = recreate_derefs(b, &call->params[0], dst_split->split_vars[idx]);
|
||||
nir_deref_instr *src_deref = recreate_derefs(b, &call->params[3], src_split->split_vars[idx]);
|
||||
struct glsl_cmat_description cmat_desc = *glsl_get_cmat_description(src_split->split_vars[0]->type);
|
||||
nir_cmat_call_instr *new_call = nir_cmat_call_instr_create(b->shader, nir_cmat_call_op_per_element_op, call->callee);
|
||||
new_call->params[0] = nir_src_for_ssa(&dst_deref->def);
|
||||
new_call->params[1] = nir_src_for_ssa(nir_imm_int(b, cmat_desc.rows * r));
|
||||
new_call->params[2] = nir_src_for_ssa(nir_imm_int(b, cmat_desc.cols * c));
|
||||
new_call->params[3] = nir_src_for_ssa(&src_deref->def);
|
||||
|
||||
for (unsigned i = 4; i < call->num_params; i++) {
|
||||
if (nir_src_as_deref(call->params[i])) {
|
||||
struct split_mat *src1_split = find_call_split(info->split_mats, call, i);
|
||||
nir_deref_instr *src1_deref = src1_split ? recreate_derefs(b, &call->params[i], src1_split->split_vars[idx]) : nir_src_as_deref(call->params[i]);
|
||||
new_call->params[i] = src1_deref ? nir_src_for_ssa(&src1_deref->def) : call->params[i];
|
||||
} else
|
||||
new_call->params[i] = call->params[i];
|
||||
}
|
||||
b->cursor = nir_before_instr(instr);
|
||||
nir_builder_instr_insert(b, &new_call->instr);
|
||||
}
|
||||
}
|
||||
nir_instr_remove(instr);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool
|
||||
split_matrix_impl(nir_function_impl *impl, struct split_info *info)
|
||||
{
|
||||
@@ -787,6 +831,9 @@ split_matrix_impl(nir_function_impl *impl, struct split_info *info)
|
||||
case nir_cmat_call_op_reduce:
|
||||
progress |= split_cmat_call_reduce(&b, impl, cmat_call, info);
|
||||
break;
|
||||
case nir_cmat_call_op_per_element_op:
|
||||
progress |= split_cmat_call_per_element_op(&b, cmat_call, info);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -2077,6 +2077,8 @@ get_cmat_call_op_str(nir_cmat_call_op op)
|
||||
return "cmat_call_reduce_finish";
|
||||
case nir_cmat_call_op_reduce_2x2:
|
||||
return "cmat_call_reduce_2x2";
|
||||
case nir_cmat_call_op_per_element_op:
|
||||
return "cmat_call_per_element";
|
||||
}
|
||||
UNREACHABLE("Unknown cmat call op");
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user