From ff4596ae616fccfe9cfb1c1dd995464c5c7eefab Mon Sep 17 00:00:00 2001 From: Georg Lehmann Date: Thu, 12 Sep 2024 09:04:13 +0200 Subject: [PATCH] spirv: explicitly lower derivatives to zero MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit To allow removal of the existing nir_builder lowering. Acked-by: Marek Olšák Reviewed-by: Timothy Arceri Part-of: --- src/compiler/spirv/vtn_alu.c | 68 +++++++++++++++++++++++------------- 1 file changed, 44 insertions(+), 24 deletions(-) diff --git a/src/compiler/spirv/vtn_alu.c b/src/compiler/spirv/vtn_alu.c index fa8135e749c..94c41c710ca 100644 --- a/src/compiler/spirv/vtn_alu.c +++ b/src/compiler/spirv/vtn_alu.c @@ -617,6 +617,49 @@ vtn_mediump_upconvert_value(struct vtn_builder *b, struct vtn_ssa_value *value) } } +static nir_def * +vtn_handle_deriv(struct vtn_builder *b, SpvOp opcode, nir_def *src) +{ + /* SPV_NV_compute_shader_derivatives: + * In the GLCompute Execution Model: + * Selection of the four invocations is determined by the DerivativeGroup*NV + * execution mode that was specified for the entry point. + * If neither derivative group mode was specified, the derivatives return zero. + */ + if (b->nb.shader->info.stage == MESA_SHADER_COMPUTE && + b->nb.shader->info.derivative_group == DERIVATIVE_GROUP_NONE) { + return nir_imm_zero(&b->nb, src->num_components, src->bit_size); + } + + switch (opcode) { + case SpvOpDPdx: + return nir_ddx(&b->nb, src); + case SpvOpDPdxFine: + return nir_ddx_fine(&b->nb, src); + case SpvOpDPdxCoarse: + return nir_ddx_coarse(&b->nb, src); + case SpvOpDPdy: + return nir_ddy(&b->nb, src); + case SpvOpDPdyFine: + return nir_ddy_fine(&b->nb, src); + case SpvOpDPdyCoarse: + return nir_ddy_coarse(&b->nb, src); + case SpvOpFwidth: + return nir_fadd(&b->nb, + nir_fabs(&b->nb, nir_ddx(&b->nb, src)), + nir_fabs(&b->nb, nir_ddy(&b->nb, src))); + case SpvOpFwidthFine: + return nir_fadd(&b->nb, + nir_fabs(&b->nb, nir_ddx_fine(&b->nb, src)), + nir_fabs(&b->nb, nir_ddy_fine(&b->nb, src))); + case SpvOpFwidthCoarse: + return nir_fadd(&b->nb, + nir_fabs(&b->nb, nir_ddx_coarse(&b->nb, src)), + nir_fabs(&b->nb, nir_ddy_coarse(&b->nb, src))); + default: unreachable("Not a derivative opcode"); + } +} + void vtn_handle_alu(struct vtn_builder *b, SpvOp opcode, const uint32_t *w, unsigned count) @@ -721,38 +764,15 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode, } case SpvOpDPdx: - dest->def = nir_ddx(&b->nb, src[0]); - break; case SpvOpDPdxFine: - dest->def = nir_ddx_fine(&b->nb, src[0]); - break; case SpvOpDPdxCoarse: - dest->def = nir_ddx_coarse(&b->nb, src[0]); - break; case SpvOpDPdy: - dest->def = nir_ddy(&b->nb, src[0]); - break; case SpvOpDPdyFine: - dest->def = nir_ddy_fine(&b->nb, src[0]); - break; case SpvOpDPdyCoarse: - dest->def = nir_ddy_coarse(&b->nb, src[0]); - break; - case SpvOpFwidth: - dest->def = nir_fadd(&b->nb, - nir_fabs(&b->nb, nir_ddx(&b->nb, src[0])), - nir_fabs(&b->nb, nir_ddy(&b->nb, src[0]))); - break; case SpvOpFwidthFine: - dest->def = nir_fadd(&b->nb, - nir_fabs(&b->nb, nir_ddx_fine(&b->nb, src[0])), - nir_fabs(&b->nb, nir_ddy_fine(&b->nb, src[0]))); - break; case SpvOpFwidthCoarse: - dest->def = nir_fadd(&b->nb, - nir_fabs(&b->nb, nir_ddx_coarse(&b->nb, src[0])), - nir_fabs(&b->nb, nir_ddy_coarse(&b->nb, src[0]))); + dest->def = vtn_handle_deriv(b, opcode, src[0]); break; case SpvOpVectorTimesScalar: