diff --git a/src/glsl/nir/spirv_to_nir.c b/src/glsl/nir/spirv_to_nir.c index c8594085d5e..973ff7c6777 100644 --- a/src/glsl/nir/spirv_to_nir.c +++ b/src/glsl/nir/spirv_to_nir.c @@ -633,21 +633,44 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode, } static void -vtn_get_builtin_location(SpvBuiltIn builtin, int *location, +set_mode_system_value(nir_variable_mode *mode) +{ + assert(*mode == nir_var_system_value || *mode == nir_var_shader_in); + *mode = nir_var_system_value; +} + +static void +validate_per_vertex_mode(struct vtn_builder *b, nir_variable_mode mode) +{ + switch (b->shader->stage) { + case MESA_SHADER_VERTEX: + assert(mode == nir_var_shader_out); + break; + case MESA_SHADER_GEOMETRY: + assert(mode == nir_var_shader_out || mode == nir_var_shader_in); + break; + default: + assert(!"Invalid shader stage"); + } +} + +static void +vtn_get_builtin_location(struct vtn_builder *b, + SpvBuiltIn builtin, int *location, nir_variable_mode *mode) { switch (builtin) { case SpvBuiltInPosition: *location = VARYING_SLOT_POS; - *mode = nir_var_shader_out; + validate_per_vertex_mode(b, *mode); break; case SpvBuiltInPointSize: *location = VARYING_SLOT_PSIZ; - *mode = nir_var_shader_out; + validate_per_vertex_mode(b, *mode); break; case SpvBuiltInClipDistance: *location = VARYING_SLOT_CLIP_DIST0; /* XXX CLIP_DIST1? */ - *mode = nir_var_shader_in; + validate_per_vertex_mode(b, *mode); break; case SpvBuiltInCullDistance: /* XXX figure this out */ @@ -657,11 +680,11 @@ vtn_get_builtin_location(SpvBuiltIn builtin, int *location, * builtin keyword VertexIndex to indicate the non-zero-based value. */ *location = SYSTEM_VALUE_VERTEX_ID_ZERO_BASE; - *mode = nir_var_system_value; + set_mode_system_value(mode); break; case SpvBuiltInInstanceId: *location = SYSTEM_VALUE_INSTANCE_ID; - *mode = nir_var_system_value; + set_mode_system_value(mode); break; case SpvBuiltInPrimitiveId: *location = VARYING_SLOT_PRIMITIVE_ID; @@ -669,7 +692,7 @@ vtn_get_builtin_location(SpvBuiltIn builtin, int *location, break; case SpvBuiltInInvocationId: *location = SYSTEM_VALUE_INVOCATION_ID; - *mode = nir_var_system_value; + set_mode_system_value(mode); break; case SpvBuiltInLayer: *location = VARYING_SLOT_LAYER; @@ -682,35 +705,40 @@ vtn_get_builtin_location(SpvBuiltIn builtin, int *location, unreachable("no tessellation support"); case SpvBuiltInFragCoord: *location = VARYING_SLOT_POS; - *mode = nir_var_shader_in; + assert(b->shader->stage == MESA_SHADER_FRAGMENT); + assert(*mode == nir_var_shader_in); break; case SpvBuiltInPointCoord: *location = VARYING_SLOT_PNTC; - *mode = nir_var_shader_out; + assert(b->shader->stage == MESA_SHADER_FRAGMENT); + assert(*mode == nir_var_shader_in); break; case SpvBuiltInFrontFacing: *location = VARYING_SLOT_FACE; - *mode = nir_var_shader_out; + assert(b->shader->stage == MESA_SHADER_FRAGMENT); + assert(*mode == nir_var_shader_in); break; case SpvBuiltInSampleId: *location = SYSTEM_VALUE_SAMPLE_ID; - *mode = nir_var_shader_in; + set_mode_system_value(mode); break; case SpvBuiltInSamplePosition: *location = SYSTEM_VALUE_SAMPLE_POS; - *mode = nir_var_shader_in; + set_mode_system_value(mode); break; case SpvBuiltInSampleMask: *location = SYSTEM_VALUE_SAMPLE_MASK_IN; /* XXX out? */ - *mode = nir_var_shader_in; + set_mode_system_value(mode); break; case SpvBuiltInFragColor: *location = FRAG_RESULT_COLOR; - *mode = nir_var_shader_out; + assert(b->shader->stage == MESA_SHADER_FRAGMENT); + assert(*mode == nir_var_shader_out); break; case SpvBuiltInFragDepth: *location = FRAG_RESULT_DEPTH; - *mode = nir_var_shader_out; + assert(b->shader->stage == MESA_SHADER_FRAGMENT); + assert(*mode == nir_var_shader_out); break; case SpvBuiltInNumWorkgroups: case SpvBuiltInWorkgroupSize: @@ -723,11 +751,11 @@ vtn_get_builtin_location(SpvBuiltIn builtin, int *location, unreachable("unsupported builtin"); case SpvBuiltInWorkgroupId: *location = SYSTEM_VALUE_WORK_GROUP_ID; - *mode = nir_var_system_value; + set_mode_system_value(mode); break; case SpvBuiltInLocalInvocationId: *location = SYSTEM_VALUE_LOCAL_INVOCATION_ID; - *mode = nir_var_system_value; + set_mode_system_value(mode); break; case SpvBuiltInHelperInvocation: default: @@ -792,8 +820,8 @@ var_decoration_cb(struct vtn_builder *b, struct vtn_value *val, int member, case SpvDecorationBuiltIn: { SpvBuiltIn builtin = dec->literals[0]; - nir_variable_mode mode; - vtn_get_builtin_location(builtin, &var->data.location, &mode); + nir_variable_mode mode = var->data.mode; + vtn_get_builtin_location(b, builtin, &var->data.location, &mode); var->data.explicit_location = true; var->data.mode = mode; if (mode == nir_var_shader_in || mode == nir_var_system_value) @@ -842,7 +870,7 @@ get_builtin_variable(struct vtn_builder *b, if (!var) { int location; - vtn_get_builtin_location(builtin, &location, &mode); + vtn_get_builtin_location(b, builtin, &location, &mode); var = nir_variable_create(b->shader, mode, type, "builtin");