spirv: Add support for SPV_KHR_integer_dot_product

v2 (Ivan): Add missing capability enum handling.

v3 (idr): Properly handle cases where dest_size != 32.

v4 (idr): Rewrite most of the error checking to use vtn_fail_if.  Use
nir_ssa_def with vtn_push_nir_ssa instead of vtn_ssa_value with
vtn_push_ssa_value.  All suggested by Jason.  Massive rewrite of the
handling of packed 4x8 saturating opcodes.  Based on some observations
made by Jason.

v5 (idr): Remove some debugging cruft accidentally added in v4.  Noticed
by Jason.

v6: Emit packed versions of vectored instructions when possible.
Suggested by Jason.

Reviewed-by: Jason Ekstrand <jason@jlekstrand.net>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/12142>
This commit is contained in:
Ian Romanick
2021-06-14 14:12:36 -07:00
committed by Marge Bot
parent 652d304ee9
commit fe956d0182
3 changed files with 288 additions and 0 deletions
+13
View File
@@ -4364,6 +4364,10 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
case SpvCapabilityImageGatherExtended:
case SpvCapabilityStorageImageExtendedFormats:
case SpvCapabilityVector16:
case SpvCapabilityDotProductKHR:
case SpvCapabilityDotProductInputAllKHR:
case SpvCapabilityDotProductInput4x8BitKHR:
case SpvCapabilityDotProductInput4x8BitPackedKHR:
break;
case SpvCapabilityLinkage:
@@ -5650,6 +5654,15 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
vtn_handle_alu(b, opcode, w, count);
break;
case SpvOpSDotKHR:
case SpvOpUDotKHR:
case SpvOpSUDotKHR:
case SpvOpSDotAccSatKHR:
case SpvOpUDotAccSatKHR:
case SpvOpSUDotAccSatKHR:
vtn_handle_integer_dot(b, opcode, w, count);
break;
case SpvOpBitcast:
vtn_handle_bitcast(b, w, count);
break;
+272
View File
@@ -765,6 +765,14 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
break;
}
case SpvOpSDotKHR:
case SpvOpUDotKHR:
case SpvOpSUDotKHR:
case SpvOpSDotAccSatKHR:
case SpvOpUDotAccSatKHR:
case SpvOpSUDotAccSatKHR:
unreachable("Should have called vtn_handle_integer_dot instead.");
default: {
bool swap;
bool exact;
@@ -823,6 +831,270 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
b->nb.exact = b->exact;
}
void
vtn_handle_integer_dot(struct vtn_builder *b, SpvOp opcode,
const uint32_t *w, unsigned count)
{
struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
const unsigned dest_size = glsl_get_bit_size(dest_type);
vtn_handle_no_contraction(b, dest_val);
/* Collect the various SSA sources.
*
* Due to the optional "Packed Vector Format" field, determine number of
* inputs from the opcode. This differs from vtn_handle_alu.
*/
const unsigned num_inputs = (opcode == SpvOpSDotAccSatKHR ||
opcode == SpvOpUDotAccSatKHR ||
opcode == SpvOpSUDotAccSatKHR) ? 3 : 2;
vtn_assert(count >= num_inputs + 3);
struct vtn_ssa_value *vtn_src[3] = { NULL, };
nir_ssa_def *src[3] = { NULL, };
for (unsigned i = 0; i < num_inputs; i++) {
vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
src[i] = vtn_src[i]->def;
vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
}
/* For all of the opcodes *except* SpvOpSUDotKHR and SpvOpSUDotAccSatKHR,
* the SPV_KHR_integer_dot_product spec says:
*
* _Vector 1_ and _Vector 2_ must have the same type.
*
* The practical requirement is the same bit-size and the same number of
* components.
*/
vtn_fail_if(glsl_get_bit_size(vtn_src[0]->type) !=
glsl_get_bit_size(vtn_src[1]->type) ||
glsl_get_vector_elements(vtn_src[0]->type) !=
glsl_get_vector_elements(vtn_src[1]->type),
"Vector 1 and vector 2 source of opcode %s must have the same "
"type",
spirv_op_to_string(opcode));
if (num_inputs == 3) {
/* The SPV_KHR_integer_dot_product spec says:
*
* The type of Accumulator must be the same as Result Type.
*
* The handling of SpvOpSDotAccSatKHR and friends with the packed 4x8
* types (far below) assumes these types have the same size.
*/
vtn_fail_if(dest_type != vtn_src[2]->type,
"Accumulator type must be the same as Result Type for "
"opcode %s",
spirv_op_to_string(opcode));
}
if (glsl_type_is_vector(vtn_src[0]->type)) {
/* FINISHME: Is this actually as good or better for platforms that don't
* have the special instructions (i.e., one or both of has_dot_4x8 or
* has_sudot_4x8 is false)?
*/
if (glsl_get_vector_elements(vtn_src[0]->type) == 4 &&
glsl_get_bit_size(vtn_src[0]->type) == 8 &&
glsl_get_bit_size(dest_type) <= 32) {
src[0] = nir_pack_32_4x8(&b->nb, src[0]);
src[1] = nir_pack_32_4x8(&b->nb, src[1]);
}
} else if (glsl_type_is_scalar(vtn_src[0]->type) &&
glsl_type_is_32bit(vtn_src[0]->type)) {
/* The SPV_KHR_integer_dot_product spec says:
*
* When _Vector 1_ and _Vector 2_ are scalar integer types, _Packed
* Vector Format_ must be specified to select how the integers are to
* be interpreted as vectors.
*
* The "Packed Vector Format" value follows the last input.
*/
vtn_assert(count == (num_inputs + 4));
const SpvPackedVectorFormat pack_format = w[num_inputs + 3];
vtn_fail_if(pack_format != SpvPackedVectorFormatPackedVectorFormat4x8BitKHR,
"Unsupported vector packing format %d for opcode %s",
pack_format, spirv_op_to_string(opcode));
} else {
vtn_fail_with_opcode("Invalid source types.", opcode);
}
nir_ssa_def *dest = NULL;
if (src[0]->num_components > 1) {
const nir_op s_conversion_op =
nir_type_conversion_op(nir_type_int, nir_type_int | dest_size,
nir_rounding_mode_undef);
const nir_op u_conversion_op =
nir_type_conversion_op(nir_type_uint, nir_type_uint | dest_size,
nir_rounding_mode_undef);
nir_op src0_conversion_op;
nir_op src1_conversion_op;
switch (opcode) {
case SpvOpSDotKHR:
case SpvOpSDotAccSatKHR:
src0_conversion_op = s_conversion_op;
src1_conversion_op = s_conversion_op;
break;
case SpvOpUDotKHR:
case SpvOpUDotAccSatKHR:
src0_conversion_op = u_conversion_op;
src1_conversion_op = u_conversion_op;
break;
case SpvOpSUDotKHR:
case SpvOpSUDotAccSatKHR:
src0_conversion_op = s_conversion_op;
src1_conversion_op = u_conversion_op;
break;
default:
unreachable("Invalid opcode.");
}
/* The SPV_KHR_integer_dot_product spec says:
*
* All components of the input vectors are sign-extended to the bit
* width of the result's type. The sign-extended input vectors are
* then multiplied component-wise and all components of the vector
* resulting from the component-wise multiplication are added
* together. The resulting value will equal the low-order N bits of
* the correct result R, where N is the result width and R is
* computed with enough precision to avoid overflow and underflow.
*/
const unsigned vector_components =
glsl_get_vector_elements(vtn_src[0]->type);
for (unsigned i = 0; i < vector_components; i++) {
nir_ssa_def *const src0 =
nir_build_alu(&b->nb, src0_conversion_op,
nir_channel(&b->nb, src[0], i), NULL, NULL, NULL);
nir_ssa_def *const src1 =
nir_build_alu(&b->nb, src1_conversion_op,
nir_channel(&b->nb, src[1], i), NULL, NULL, NULL);
nir_ssa_def *const mul_result = nir_imul(&b->nb, src0, src1);
dest = (i == 0) ? mul_result : nir_iadd(&b->nb, dest, mul_result);
}
if (num_inputs == 3) {
/* For SpvOpSDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
*
* Signed integer dot product of _Vector 1_ and _Vector 2_ and
* signed saturating addition of the result with _Accumulator_.
*
* For SpvOpUDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
*
* Unsigned integer dot product of _Vector 1_ and _Vector 2_ and
* unsigned saturating addition of the result with _Accumulator_.
*
* For SpvOpSUDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
*
* Mixed-signedness integer dot product of _Vector 1_ and _Vector
* 2_ and signed saturating addition of the result with
* _Accumulator_.
*/
dest = (opcode == SpvOpUDotAccSatKHR)
? nir_uadd_sat(&b->nb, dest, src[2])
: nir_iadd_sat(&b->nb, dest, src[2]);
}
} else {
assert(src[0]->num_components == 1 && src[1]->num_components == 1);
assert(src[0]->bit_size == 32 && src[1]->bit_size == 32);
nir_ssa_def *const zero = nir_imm_zero(&b->nb, 1, 32);
bool is_signed;
switch (opcode) {
case SpvOpSDotKHR:
dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero);
is_signed = true;
break;
case SpvOpUDotKHR:
dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero);
is_signed = false;
break;
case SpvOpSUDotKHR:
dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero);
is_signed = true;
break;
case SpvOpSDotAccSatKHR:
if (dest_size == 32)
dest = nir_sdot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
else
dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero);
is_signed = true;
break;
case SpvOpUDotAccSatKHR:
if (dest_size == 32)
dest = nir_udot_4x8_uadd_sat(&b->nb, src[0], src[1], src[2]);
else
dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero);
is_signed = false;
break;
case SpvOpSUDotAccSatKHR:
if (dest_size == 32)
dest = nir_sudot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
else
dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero);
is_signed = true;
break;
default:
unreachable("Invalid opcode.");
}
if (dest_size != 32) {
/* When the accumulator is 32-bits, a NIR dot-product with saturate
* is generated above. In all other cases a regular dot-product is
* generated above, and separate addition with saturate is generated
* here.
*
* The SPV_KHR_integer_dot_product spec says:
*
* If any of the multiplications or additions, with the exception
* of the final accumulation, overflow or underflow, the result of
* the instruction is undefined.
*
* Therefore it is safe to cast the dot-product result down to the
* size of the accumulator before doing the addition. Since the
* result of the dot-product cannot overflow 32-bits, this is also
* safe to cast up.
*/
if (num_inputs == 3) {
dest = is_signed
? nir_iadd_sat(&b->nb, nir_i2i(&b->nb, dest, dest_size), src[2])
: nir_uadd_sat(&b->nb, nir_u2u(&b->nb, dest, dest_size), src[2]);
} else {
dest = is_signed
? nir_i2i(&b->nb, dest, dest_size)
: nir_u2u(&b->nb, dest, dest_size);
}
}
}
vtn_push_nir_ssa(b, w[2], dest);
b->nb.exact = b->exact;
}
void
vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w, unsigned count)
{
+3
View File
@@ -919,6 +919,9 @@ nir_op vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
void vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
const uint32_t *w, unsigned count);
void vtn_handle_integer_dot(struct vtn_builder *b, SpvOp opcode,
const uint32_t *w, unsigned count);
void vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w,
unsigned count);