diff --git a/src/amd/vulkan/radv_cmd_buffer.c b/src/amd/vulkan/radv_cmd_buffer.c index 6452b3e4180..805a5c570c5 100644 --- a/src/amd/vulkan/radv_cmd_buffer.c +++ b/src/amd/vulkan/radv_cmd_buffer.c @@ -9132,10 +9132,15 @@ radv_emit_shaders(struct radv_cmd_buffer *cmd_buffer) radv_emit_tess_eval_shader(device, cs, cs, tes, gs); break; } - case MESA_SHADER_GEOMETRY: - radv_emit_geometry_shader(device, cs, cs, cmd_buffer->state.shaders[MESA_SHADER_GEOMETRY], NULL, + case MESA_SHADER_GEOMETRY: { + struct radv_shader *es = cmd_buffer->state.shaders[MESA_SHADER_TESS_EVAL] + ? cmd_buffer->state.shaders[MESA_SHADER_TESS_EVAL] + : cmd_buffer->state.shaders[MESA_SHADER_VERTEX]; + + radv_emit_geometry_shader(device, cs, cs, cmd_buffer->state.shaders[MESA_SHADER_GEOMETRY], es, shader_obj->gs.copy_shader); break; + } case MESA_SHADER_FRAGMENT: radv_emit_fragment_shader(device, cs, cs, cmd_buffer->state.shaders[MESA_SHADER_FRAGMENT]); radv_emit_ps_inputs(device, cs, last_vgt_shader, cmd_buffer->state.shaders[MESA_SHADER_FRAGMENT]);