From 9547efa6b0f45a365d67977ef64e61fffdb6be2c Mon Sep 17 00:00:00 2001 From: Yiwei Zhang Date: Wed, 5 Mar 2025 17:14:21 -0800 Subject: [PATCH] venus: prepare push template for ray tracing pipeline Signed-off-by: Yiwei Zhang Part-of: --- src/virtio/vulkan/vn_command_buffer.c | 30 +++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/src/virtio/vulkan/vn_command_buffer.c b/src/virtio/vulkan/vn_command_buffer.c index 046c8b8eae2..ba6d617d467 100644 --- a/src/virtio/vulkan/vn_command_buffer.c +++ b/src/virtio/vulkan/vn_command_buffer.c @@ -2425,13 +2425,35 @@ vn_CmdPushDescriptorSetWithTemplate2(VkCommandBuffer commandBuffer, templ, VK_NULL_HANDLE, pPushDescriptorSetWithTemplateInfo->pData, &update); + /* Per spec: + * + * If stageFlags specifies a subset of all stages corresponding to one or + * more pipeline bind points, the binding operation still affects all + * stages corresponding to the given pipeline bind point(s) as if the + * equivalent original version of this command had been called with the + * same parameters. + * + * So we just need to pick a single stage belonging to the pipeline type. + */ + VkShaderStageFlags stage_flags; + switch (templ->push.pipeline_bind_point) { + case VK_PIPELINE_BIND_POINT_GRAPHICS: + stage_flags = VK_SHADER_STAGE_ALL_GRAPHICS; + break; + case VK_PIPELINE_BIND_POINT_COMPUTE: + stage_flags = VK_SHADER_STAGE_COMPUTE_BIT; + break; + case VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR: + stage_flags = VK_SHADER_STAGE_RAYGEN_BIT_KHR; + break; + default: + unreachable("bad pipeline bind point in the template"); + break; + } const VkPushDescriptorSetInfo info = { .sType = VK_STRUCTURE_TYPE_PUSH_DESCRIPTOR_SET_INFO, .pNext = pPushDescriptorSetWithTemplateInfo->pNext, - .stageFlags = - templ->push.pipeline_bind_point == VK_PIPELINE_BIND_POINT_GRAPHICS - ? VK_SHADER_STAGE_ALL_GRAPHICS - : VK_SHADER_STAGE_COMPUTE_BIT, + .stageFlags = stage_flags, .layout = pPushDescriptorSetWithTemplateInfo->layout, .set = pPushDescriptorSetWithTemplateInfo->set, .descriptorWriteCount = update.write_count,