nir/spirv: Add real support for outer products
This commit is contained in:
@@ -187,10 +187,6 @@ vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
|
||||
dest->ssa = vtn_ssa_transpose(b, src0);
|
||||
break;
|
||||
|
||||
case SpvOpOuterProduct:
|
||||
dest->ssa = matrix_multiply(b, src0, vtn_ssa_transpose(b, src1));
|
||||
break;
|
||||
|
||||
case SpvOpMatrixTimesScalar:
|
||||
if (src0->transposed) {
|
||||
dest->ssa = vtn_ssa_transpose(b, mat_times_scalar(b, src0->transposed,
|
||||
@@ -292,6 +288,14 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
|
||||
case SpvOpSMod: op = nir_op_umod; break; /* FIXME? */
|
||||
case SpvOpFMod: op = nir_op_fmod; break;
|
||||
|
||||
case SpvOpOuterProduct: {
|
||||
for (unsigned i = 0; i < src[1]->num_components; i++) {
|
||||
val->ssa->elems[i]->def =
|
||||
nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], i));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
case SpvOpDot:
|
||||
assert(src[0]->num_components == src[1]->num_components);
|
||||
switch (src[0]->num_components) {
|
||||
|
||||
Reference in New Issue
Block a user