diff --git a/src/amd/vulkan/radv_cmd_buffer.c b/src/amd/vulkan/radv_cmd_buffer.c index 41ba301b444..b8a688a37da 100644 --- a/src/amd/vulkan/radv_cmd_buffer.c +++ b/src/amd/vulkan/radv_cmd_buffer.c @@ -8039,10 +8039,11 @@ ALWAYS_INLINE static void radv_cs_emit_indirect_mesh_draw_packet(struct radv_cmd_buffer *cmd_buffer, uint32_t draw_count, uint64_t count_va, uint32_t stride) { + const struct radv_shader *mesh_shader = cmd_buffer->state.shaders[MESA_SHADER_MESH]; struct radeon_cmdbuf *cs = cmd_buffer->cs; uint32_t base_reg = cmd_buffer->state.vtx_base_sgpr; bool predicating = cmd_buffer->state.predicating; - assert(base_reg); + assert(base_reg || (!cmd_buffer->state.uses_drawid && !mesh_shader->info.cs.uses_grid_size)); /* Reset draw state. */ cmd_buffer->state.last_first_instance = -1; @@ -8050,12 +8051,12 @@ radv_cs_emit_indirect_mesh_draw_packet(struct radv_cmd_buffer *cmd_buffer, uint3 cmd_buffer->state.last_drawid = -1; cmd_buffer->state.last_vertex_offset_valid = false; + uint32_t xyz_dim_enable = mesh_shader->info.cs.uses_grid_size; uint32_t xyz_dim_reg = (base_reg - SI_SH_REG_OFFSET) >> 2; - uint32_t draw_id_reg = (base_reg + 12 - SI_SH_REG_OFFSET) >> 2; + uint32_t draw_id_reg = xyz_dim_reg + (xyz_dim_enable ? 3 : 0); uint32_t draw_id_enable = !!cmd_buffer->state.uses_drawid; - uint32_t xyz_dim_enable = 1; /* TODO: disable XYZ_DIM when unneeded */ - uint32_t mode1_enable = 1; /* legacy fast launch mode */ + uint32_t mode1_enable = 1; /* legacy fast launch mode */ const bool sqtt_en = !!cmd_buffer->device->sqtt.bo; radeon_emit(cs, PKT3(PKT3_DISPATCH_MESH_INDIRECT_MULTI, 7, predicating) | PKT3_RESET_FILTER_CAM_S(1)); @@ -8145,6 +8146,7 @@ radv_cs_emit_dispatch_taskmesh_indirect_multi_ace_packet(struct radv_cmd_buffer ALWAYS_INLINE static void radv_cs_emit_dispatch_taskmesh_gfx_packet(struct radv_cmd_buffer *cmd_buffer) { + const struct radv_shader *mesh_shader = cmd_buffer->state.shaders[MESA_SHADER_MESH]; struct radeon_cmdbuf *cs = cmd_buffer->cs; bool predicating = cmd_buffer->state.predicating; @@ -8153,11 +8155,10 @@ radv_cs_emit_dispatch_taskmesh_gfx_packet(struct radv_cmd_buffer *cmd_buffer) assert(ring_entry_loc->sgpr_idx != -1); - uint32_t base_reg = cmd_buffer->state.vtx_base_sgpr; - uint32_t xyz_dim_reg = (base_reg - SI_SH_REG_OFFSET) >> 2; - uint32_t ring_entry_reg = ((base_reg + ring_entry_loc->sgpr_idx * 4) - SI_SH_REG_OFFSET) >> 2; - uint32_t xyz_dim_en = 1; /* TODO: disable XYZ_DIM when unneeded */ - uint32_t mode1_en = 1; /* legacy fast launch mode */ + uint32_t xyz_dim_reg = (cmd_buffer->state.vtx_base_sgpr - SI_SH_REG_OFFSET) >> 2; + uint32_t ring_entry_reg = ((mesh_shader->info.user_data_0 - SI_SH_REG_OFFSET) >> 2) + ring_entry_loc->sgpr_idx; + uint32_t xyz_dim_en = mesh_shader->info.cs.uses_grid_size; + uint32_t mode1_en = 1; /* legacy fast launch mode */ uint32_t linear_dispatch_en = cmd_buffer->state.shaders[MESA_SHADER_TASK]->info.cs.linear_taskmesh_dispatch; const bool sqtt_en = !!cmd_buffer->device->sqtt.bo; @@ -8226,14 +8227,20 @@ ALWAYS_INLINE static void radv_emit_userdata_mesh(struct radv_cmd_buffer *cmd_buffer, const uint32_t x, const uint32_t y, const uint32_t z) { struct radv_cmd_state *state = &cmd_buffer->state; + const struct radv_shader *mesh_shader = state->shaders[MESA_SHADER_MESH]; struct radeon_cmdbuf *cs = cmd_buffer->cs; const bool uses_drawid = state->uses_drawid; + const bool uses_grid_size = mesh_shader->info.cs.uses_grid_size; + + if (!uses_drawid && !uses_grid_size) + return; radeon_set_sh_reg_seq(cs, state->vtx_base_sgpr, state->vtx_emit_num); - radeon_emit(cs, x); - radeon_emit(cs, y); - radeon_emit(cs, z); - + if (uses_grid_size) { + radeon_emit(cs, x); + radeon_emit(cs, y); + radeon_emit(cs, z); + } if (uses_drawid) { radeon_emit(cs, 0); state->last_drawid = 0; @@ -8497,7 +8504,9 @@ radv_emit_indirect_mesh_draw_packets(struct radv_cmd_buffer *cmd_buffer, const s radeon_emit(cs, va >> 32); if (state->uses_drawid) { - radeon_set_sh_reg_seq(cs, state->vtx_base_sgpr + 12, 1); + const struct radv_shader *mesh_shader = state->shaders[MESA_SHADER_MESH]; + unsigned reg = state->vtx_base_sgpr + (mesh_shader->info.cs.uses_grid_size ? 12 : 0); + radeon_set_sh_reg_seq(cs, reg, 1); radeon_emit(cs, 0); } diff --git a/src/amd/vulkan/radv_shader_args.c b/src/amd/vulkan/radv_shader_args.c index 1a442db98dc..81fe6ce0729 100644 --- a/src/amd/vulkan/radv_shader_args.c +++ b/src/amd/vulkan/radv_shader_args.c @@ -242,7 +242,9 @@ declare_tes_input_vgprs(struct radv_shader_args *args) static void declare_ms_input_sgprs(const struct radv_shader_info *info, struct radv_shader_args *args) { - add_ud_arg(args, 3, AC_ARG_INT, &args->ac.num_work_groups, AC_UD_VS_BASE_VERTEX_START_INSTANCE); + if (info->cs.uses_grid_size) { + add_ud_arg(args, 3, AC_ARG_INT, &args->ac.num_work_groups, AC_UD_VS_BASE_VERTEX_START_INSTANCE); + } if (info->vs.needs_draw_id) { add_ud_arg(args, 1, AC_ARG_INT, &args->ac.draw_id, AC_UD_VS_BASE_VERTEX_START_INSTANCE); } diff --git a/src/amd/vulkan/radv_shader_info.c b/src/amd/vulkan/radv_shader_info.c index e8738f3dc8b..2c54d0bd221 100644 --- a/src/amd/vulkan/radv_shader_info.c +++ b/src/amd/vulkan/radv_shader_info.c @@ -1148,8 +1148,10 @@ radv_nir_shader_info_pass(struct radv_device *device, const struct nir_shader *n info->uses_invocation_id |= BITSET_TEST(nir->info.system_values_read, SYSTEM_VALUE_INVOCATION_ID); info->uses_prim_id |= BITSET_TEST(nir->info.system_values_read, SYSTEM_VALUE_PRIMITIVE_ID); - /* Used by compute and mesh shaders. */ - info->cs.uses_grid_size = BITSET_TEST(nir->info.system_values_read, SYSTEM_VALUE_NUM_WORKGROUPS); + /* Used by compute and mesh shaders. Mesh shaders must always declare this before GFX11. */ + info->cs.uses_grid_size = + BITSET_TEST(nir->info.system_values_read, SYSTEM_VALUE_NUM_WORKGROUPS) || + (nir->info.stage == MESA_SHADER_MESH && device->physical_device->rad_info.gfx_level < GFX11); info->cs.uses_local_invocation_idx = BITSET_TEST(nir->info.system_values_read, SYSTEM_VALUE_LOCAL_INVOCATION_INDEX) | BITSET_TEST(nir->info.system_values_read, SYSTEM_VALUE_SUBGROUP_ID) | BITSET_TEST(nir->info.system_values_read, SYSTEM_VALUE_NUM_SUBGROUPS);