spirv: Add support for SPV_KHR_integer_dot_product
v2 (Ivan): Add missing capability enum handling. v3 (idr): Properly handle cases where dest_size != 32. v4 (idr): Rewrite most of the error checking to use vtn_fail_if. Use nir_ssa_def with vtn_push_nir_ssa instead of vtn_ssa_value with vtn_push_ssa_value. All suggested by Jason. Massive rewrite of the handling of packed 4x8 saturating opcodes. Based on some observations made by Jason. v5 (idr): Remove some debugging cruft accidentally added in v4. Noticed by Jason. v6: Emit packed versions of vectored instructions when possible. Suggested by Jason. Reviewed-by: Jason Ekstrand <jason@jlekstrand.net> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/12142>
This commit is contained in:
@@ -4364,6 +4364,10 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
|
||||
case SpvCapabilityImageGatherExtended:
|
||||
case SpvCapabilityStorageImageExtendedFormats:
|
||||
case SpvCapabilityVector16:
|
||||
case SpvCapabilityDotProductKHR:
|
||||
case SpvCapabilityDotProductInputAllKHR:
|
||||
case SpvCapabilityDotProductInput4x8BitKHR:
|
||||
case SpvCapabilityDotProductInput4x8BitPackedKHR:
|
||||
break;
|
||||
|
||||
case SpvCapabilityLinkage:
|
||||
@@ -5650,6 +5654,15 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
|
||||
vtn_handle_alu(b, opcode, w, count);
|
||||
break;
|
||||
|
||||
case SpvOpSDotKHR:
|
||||
case SpvOpUDotKHR:
|
||||
case SpvOpSUDotKHR:
|
||||
case SpvOpSDotAccSatKHR:
|
||||
case SpvOpUDotAccSatKHR:
|
||||
case SpvOpSUDotAccSatKHR:
|
||||
vtn_handle_integer_dot(b, opcode, w, count);
|
||||
break;
|
||||
|
||||
case SpvOpBitcast:
|
||||
vtn_handle_bitcast(b, w, count);
|
||||
break;
|
||||
|
||||
@@ -765,6 +765,14 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
|
||||
break;
|
||||
}
|
||||
|
||||
case SpvOpSDotKHR:
|
||||
case SpvOpUDotKHR:
|
||||
case SpvOpSUDotKHR:
|
||||
case SpvOpSDotAccSatKHR:
|
||||
case SpvOpUDotAccSatKHR:
|
||||
case SpvOpSUDotAccSatKHR:
|
||||
unreachable("Should have called vtn_handle_integer_dot instead.");
|
||||
|
||||
default: {
|
||||
bool swap;
|
||||
bool exact;
|
||||
@@ -823,6 +831,270 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
|
||||
b->nb.exact = b->exact;
|
||||
}
|
||||
|
||||
void
|
||||
vtn_handle_integer_dot(struct vtn_builder *b, SpvOp opcode,
|
||||
const uint32_t *w, unsigned count)
|
||||
{
|
||||
struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
|
||||
const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
|
||||
const unsigned dest_size = glsl_get_bit_size(dest_type);
|
||||
|
||||
vtn_handle_no_contraction(b, dest_val);
|
||||
|
||||
/* Collect the various SSA sources.
|
||||
*
|
||||
* Due to the optional "Packed Vector Format" field, determine number of
|
||||
* inputs from the opcode. This differs from vtn_handle_alu.
|
||||
*/
|
||||
const unsigned num_inputs = (opcode == SpvOpSDotAccSatKHR ||
|
||||
opcode == SpvOpUDotAccSatKHR ||
|
||||
opcode == SpvOpSUDotAccSatKHR) ? 3 : 2;
|
||||
|
||||
vtn_assert(count >= num_inputs + 3);
|
||||
|
||||
struct vtn_ssa_value *vtn_src[3] = { NULL, };
|
||||
nir_ssa_def *src[3] = { NULL, };
|
||||
|
||||
for (unsigned i = 0; i < num_inputs; i++) {
|
||||
vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
|
||||
src[i] = vtn_src[i]->def;
|
||||
|
||||
vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
|
||||
}
|
||||
|
||||
/* For all of the opcodes *except* SpvOpSUDotKHR and SpvOpSUDotAccSatKHR,
|
||||
* the SPV_KHR_integer_dot_product spec says:
|
||||
*
|
||||
* _Vector 1_ and _Vector 2_ must have the same type.
|
||||
*
|
||||
* The practical requirement is the same bit-size and the same number of
|
||||
* components.
|
||||
*/
|
||||
vtn_fail_if(glsl_get_bit_size(vtn_src[0]->type) !=
|
||||
glsl_get_bit_size(vtn_src[1]->type) ||
|
||||
glsl_get_vector_elements(vtn_src[0]->type) !=
|
||||
glsl_get_vector_elements(vtn_src[1]->type),
|
||||
"Vector 1 and vector 2 source of opcode %s must have the same "
|
||||
"type",
|
||||
spirv_op_to_string(opcode));
|
||||
|
||||
if (num_inputs == 3) {
|
||||
/* The SPV_KHR_integer_dot_product spec says:
|
||||
*
|
||||
* The type of Accumulator must be the same as Result Type.
|
||||
*
|
||||
* The handling of SpvOpSDotAccSatKHR and friends with the packed 4x8
|
||||
* types (far below) assumes these types have the same size.
|
||||
*/
|
||||
vtn_fail_if(dest_type != vtn_src[2]->type,
|
||||
"Accumulator type must be the same as Result Type for "
|
||||
"opcode %s",
|
||||
spirv_op_to_string(opcode));
|
||||
}
|
||||
|
||||
if (glsl_type_is_vector(vtn_src[0]->type)) {
|
||||
/* FINISHME: Is this actually as good or better for platforms that don't
|
||||
* have the special instructions (i.e., one or both of has_dot_4x8 or
|
||||
* has_sudot_4x8 is false)?
|
||||
*/
|
||||
if (glsl_get_vector_elements(vtn_src[0]->type) == 4 &&
|
||||
glsl_get_bit_size(vtn_src[0]->type) == 8 &&
|
||||
glsl_get_bit_size(dest_type) <= 32) {
|
||||
src[0] = nir_pack_32_4x8(&b->nb, src[0]);
|
||||
src[1] = nir_pack_32_4x8(&b->nb, src[1]);
|
||||
}
|
||||
} else if (glsl_type_is_scalar(vtn_src[0]->type) &&
|
||||
glsl_type_is_32bit(vtn_src[0]->type)) {
|
||||
/* The SPV_KHR_integer_dot_product spec says:
|
||||
*
|
||||
* When _Vector 1_ and _Vector 2_ are scalar integer types, _Packed
|
||||
* Vector Format_ must be specified to select how the integers are to
|
||||
* be interpreted as vectors.
|
||||
*
|
||||
* The "Packed Vector Format" value follows the last input.
|
||||
*/
|
||||
vtn_assert(count == (num_inputs + 4));
|
||||
const SpvPackedVectorFormat pack_format = w[num_inputs + 3];
|
||||
vtn_fail_if(pack_format != SpvPackedVectorFormatPackedVectorFormat4x8BitKHR,
|
||||
"Unsupported vector packing format %d for opcode %s",
|
||||
pack_format, spirv_op_to_string(opcode));
|
||||
} else {
|
||||
vtn_fail_with_opcode("Invalid source types.", opcode);
|
||||
}
|
||||
|
||||
nir_ssa_def *dest = NULL;
|
||||
|
||||
if (src[0]->num_components > 1) {
|
||||
const nir_op s_conversion_op =
|
||||
nir_type_conversion_op(nir_type_int, nir_type_int | dest_size,
|
||||
nir_rounding_mode_undef);
|
||||
|
||||
const nir_op u_conversion_op =
|
||||
nir_type_conversion_op(nir_type_uint, nir_type_uint | dest_size,
|
||||
nir_rounding_mode_undef);
|
||||
|
||||
nir_op src0_conversion_op;
|
||||
nir_op src1_conversion_op;
|
||||
|
||||
switch (opcode) {
|
||||
case SpvOpSDotKHR:
|
||||
case SpvOpSDotAccSatKHR:
|
||||
src0_conversion_op = s_conversion_op;
|
||||
src1_conversion_op = s_conversion_op;
|
||||
break;
|
||||
|
||||
case SpvOpUDotKHR:
|
||||
case SpvOpUDotAccSatKHR:
|
||||
src0_conversion_op = u_conversion_op;
|
||||
src1_conversion_op = u_conversion_op;
|
||||
break;
|
||||
|
||||
case SpvOpSUDotKHR:
|
||||
case SpvOpSUDotAccSatKHR:
|
||||
src0_conversion_op = s_conversion_op;
|
||||
src1_conversion_op = u_conversion_op;
|
||||
break;
|
||||
|
||||
default:
|
||||
unreachable("Invalid opcode.");
|
||||
}
|
||||
|
||||
/* The SPV_KHR_integer_dot_product spec says:
|
||||
*
|
||||
* All components of the input vectors are sign-extended to the bit
|
||||
* width of the result's type. The sign-extended input vectors are
|
||||
* then multiplied component-wise and all components of the vector
|
||||
* resulting from the component-wise multiplication are added
|
||||
* together. The resulting value will equal the low-order N bits of
|
||||
* the correct result R, where N is the result width and R is
|
||||
* computed with enough precision to avoid overflow and underflow.
|
||||
*/
|
||||
const unsigned vector_components =
|
||||
glsl_get_vector_elements(vtn_src[0]->type);
|
||||
|
||||
for (unsigned i = 0; i < vector_components; i++) {
|
||||
nir_ssa_def *const src0 =
|
||||
nir_build_alu(&b->nb, src0_conversion_op,
|
||||
nir_channel(&b->nb, src[0], i), NULL, NULL, NULL);
|
||||
|
||||
nir_ssa_def *const src1 =
|
||||
nir_build_alu(&b->nb, src1_conversion_op,
|
||||
nir_channel(&b->nb, src[1], i), NULL, NULL, NULL);
|
||||
|
||||
nir_ssa_def *const mul_result = nir_imul(&b->nb, src0, src1);
|
||||
|
||||
dest = (i == 0) ? mul_result : nir_iadd(&b->nb, dest, mul_result);
|
||||
}
|
||||
|
||||
if (num_inputs == 3) {
|
||||
/* For SpvOpSDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
|
||||
*
|
||||
* Signed integer dot product of _Vector 1_ and _Vector 2_ and
|
||||
* signed saturating addition of the result with _Accumulator_.
|
||||
*
|
||||
* For SpvOpUDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
|
||||
*
|
||||
* Unsigned integer dot product of _Vector 1_ and _Vector 2_ and
|
||||
* unsigned saturating addition of the result with _Accumulator_.
|
||||
*
|
||||
* For SpvOpSUDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
|
||||
*
|
||||
* Mixed-signedness integer dot product of _Vector 1_ and _Vector
|
||||
* 2_ and signed saturating addition of the result with
|
||||
* _Accumulator_.
|
||||
*/
|
||||
dest = (opcode == SpvOpUDotAccSatKHR)
|
||||
? nir_uadd_sat(&b->nb, dest, src[2])
|
||||
: nir_iadd_sat(&b->nb, dest, src[2]);
|
||||
}
|
||||
} else {
|
||||
assert(src[0]->num_components == 1 && src[1]->num_components == 1);
|
||||
assert(src[0]->bit_size == 32 && src[1]->bit_size == 32);
|
||||
|
||||
nir_ssa_def *const zero = nir_imm_zero(&b->nb, 1, 32);
|
||||
bool is_signed;
|
||||
|
||||
switch (opcode) {
|
||||
case SpvOpSDotKHR:
|
||||
dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero);
|
||||
is_signed = true;
|
||||
break;
|
||||
|
||||
case SpvOpUDotKHR:
|
||||
dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero);
|
||||
is_signed = false;
|
||||
break;
|
||||
|
||||
case SpvOpSUDotKHR:
|
||||
dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero);
|
||||
is_signed = true;
|
||||
break;
|
||||
|
||||
case SpvOpSDotAccSatKHR:
|
||||
if (dest_size == 32)
|
||||
dest = nir_sdot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
|
||||
else
|
||||
dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero);
|
||||
|
||||
is_signed = true;
|
||||
break;
|
||||
|
||||
case SpvOpUDotAccSatKHR:
|
||||
if (dest_size == 32)
|
||||
dest = nir_udot_4x8_uadd_sat(&b->nb, src[0], src[1], src[2]);
|
||||
else
|
||||
dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero);
|
||||
|
||||
is_signed = false;
|
||||
break;
|
||||
|
||||
case SpvOpSUDotAccSatKHR:
|
||||
if (dest_size == 32)
|
||||
dest = nir_sudot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
|
||||
else
|
||||
dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero);
|
||||
|
||||
is_signed = true;
|
||||
break;
|
||||
|
||||
default:
|
||||
unreachable("Invalid opcode.");
|
||||
}
|
||||
|
||||
if (dest_size != 32) {
|
||||
/* When the accumulator is 32-bits, a NIR dot-product with saturate
|
||||
* is generated above. In all other cases a regular dot-product is
|
||||
* generated above, and separate addition with saturate is generated
|
||||
* here.
|
||||
*
|
||||
* The SPV_KHR_integer_dot_product spec says:
|
||||
*
|
||||
* If any of the multiplications or additions, with the exception
|
||||
* of the final accumulation, overflow or underflow, the result of
|
||||
* the instruction is undefined.
|
||||
*
|
||||
* Therefore it is safe to cast the dot-product result down to the
|
||||
* size of the accumulator before doing the addition. Since the
|
||||
* result of the dot-product cannot overflow 32-bits, this is also
|
||||
* safe to cast up.
|
||||
*/
|
||||
if (num_inputs == 3) {
|
||||
dest = is_signed
|
||||
? nir_iadd_sat(&b->nb, nir_i2i(&b->nb, dest, dest_size), src[2])
|
||||
: nir_uadd_sat(&b->nb, nir_u2u(&b->nb, dest, dest_size), src[2]);
|
||||
} else {
|
||||
dest = is_signed
|
||||
? nir_i2i(&b->nb, dest, dest_size)
|
||||
: nir_u2u(&b->nb, dest, dest_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
vtn_push_nir_ssa(b, w[2], dest);
|
||||
|
||||
b->nb.exact = b->exact;
|
||||
}
|
||||
|
||||
void
|
||||
vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w, unsigned count)
|
||||
{
|
||||
|
||||
@@ -919,6 +919,9 @@ nir_op vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
|
||||
void vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
|
||||
const uint32_t *w, unsigned count);
|
||||
|
||||
void vtn_handle_integer_dot(struct vtn_builder *b, SpvOp opcode,
|
||||
const uint32_t *w, unsigned count);
|
||||
|
||||
void vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w,
|
||||
unsigned count);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user