spirv: explicitly lower derivatives to zero

To allow removal of the existing nir_builder lowering.

Acked-by: Marek Olšák <marek.olsak@amd.com>
Reviewed-by: Timothy Arceri <tarceri@itsqueeze.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/31137>
This commit is contained in:
Georg Lehmann
2024-09-12 09:04:13 +02:00
committed by Marge Bot
parent 721d23b8ff
commit ff4596ae61
+44 -24
View File
@@ -617,6 +617,49 @@ vtn_mediump_upconvert_value(struct vtn_builder *b, struct vtn_ssa_value *value)
}
}
static nir_def *
vtn_handle_deriv(struct vtn_builder *b, SpvOp opcode, nir_def *src)
{
/* SPV_NV_compute_shader_derivatives:
* In the GLCompute Execution Model:
* Selection of the four invocations is determined by the DerivativeGroup*NV
* execution mode that was specified for the entry point.
* If neither derivative group mode was specified, the derivatives return zero.
*/
if (b->nb.shader->info.stage == MESA_SHADER_COMPUTE &&
b->nb.shader->info.derivative_group == DERIVATIVE_GROUP_NONE) {
return nir_imm_zero(&b->nb, src->num_components, src->bit_size);
}
switch (opcode) {
case SpvOpDPdx:
return nir_ddx(&b->nb, src);
case SpvOpDPdxFine:
return nir_ddx_fine(&b->nb, src);
case SpvOpDPdxCoarse:
return nir_ddx_coarse(&b->nb, src);
case SpvOpDPdy:
return nir_ddy(&b->nb, src);
case SpvOpDPdyFine:
return nir_ddy_fine(&b->nb, src);
case SpvOpDPdyCoarse:
return nir_ddy_coarse(&b->nb, src);
case SpvOpFwidth:
return nir_fadd(&b->nb,
nir_fabs(&b->nb, nir_ddx(&b->nb, src)),
nir_fabs(&b->nb, nir_ddy(&b->nb, src)));
case SpvOpFwidthFine:
return nir_fadd(&b->nb,
nir_fabs(&b->nb, nir_ddx_fine(&b->nb, src)),
nir_fabs(&b->nb, nir_ddy_fine(&b->nb, src)));
case SpvOpFwidthCoarse:
return nir_fadd(&b->nb,
nir_fabs(&b->nb, nir_ddx_coarse(&b->nb, src)),
nir_fabs(&b->nb, nir_ddy_coarse(&b->nb, src)));
default: unreachable("Not a derivative opcode");
}
}
void
vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
const uint32_t *w, unsigned count)
@@ -721,38 +764,15 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
}
case SpvOpDPdx:
dest->def = nir_ddx(&b->nb, src[0]);
break;
case SpvOpDPdxFine:
dest->def = nir_ddx_fine(&b->nb, src[0]);
break;
case SpvOpDPdxCoarse:
dest->def = nir_ddx_coarse(&b->nb, src[0]);
break;
case SpvOpDPdy:
dest->def = nir_ddy(&b->nb, src[0]);
break;
case SpvOpDPdyFine:
dest->def = nir_ddy_fine(&b->nb, src[0]);
break;
case SpvOpDPdyCoarse:
dest->def = nir_ddy_coarse(&b->nb, src[0]);
break;
case SpvOpFwidth:
dest->def = nir_fadd(&b->nb,
nir_fabs(&b->nb, nir_ddx(&b->nb, src[0])),
nir_fabs(&b->nb, nir_ddy(&b->nb, src[0])));
break;
case SpvOpFwidthFine:
dest->def = nir_fadd(&b->nb,
nir_fabs(&b->nb, nir_ddx_fine(&b->nb, src[0])),
nir_fabs(&b->nb, nir_ddy_fine(&b->nb, src[0])));
break;
case SpvOpFwidthCoarse:
dest->def = nir_fadd(&b->nb,
nir_fabs(&b->nb, nir_ddx_coarse(&b->nb, src[0])),
nir_fabs(&b->nb, nir_ddy_coarse(&b->nb, src[0])));
dest->def = vtn_handle_deriv(b, opcode, src[0]);
break;
case SpvOpVectorTimesScalar: