From c7ff93a766186f956a14cc05904ead52efc71d99 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Timur=20Krist=C3=B3f?=
Date: Mon, 28 Feb 2022 14:24:17 +0100
Subject: [PATCH] ac/nir/ngg: Add EXT_mesh_shader vertex/primitive count.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

In EXT_mesh_shader the vertex and primitive counts are set using
a built-in SetMeshOutputsEXT function.

Signed-off-by: Timur Kristóf
Reviewed-by: Rhys Perry
Part-of:
---
/*
 * Reviewer notes. These lines sit in the commentary area between the
 * three-dash separator and the diff, so they belong neither to the commit
 * message nor to any hunk.
 *
 * Layout: this copy had lost its line breaks. The text below is the same
 * content with one header / diff line per text line. The hunk bodies add up
 * to the counts in every hunk header and in the diffstat (113 insertions,
 * 2 deletions).
 *
 * NOTE(review): the From, Signed-off-by, Reviewed-by and Part-of fields carry
 * no address or URL in this copy; presumably angle-bracketed values were
 * stripped. Restore them from the upstream commit before applying.
 *
 * What the hunks do, in order:
 *  1. The state struct (used as "lower_ngg_ms_state *s" by the new functions)
 *     gains two nir_variable pointers: primitive_count_var, vertex_count_var.
 *  2. New lower_ms_set_vertex_and_primitive_count(): takes src[0] as the
 *     vertex count and src[1] as the primitive count, passes each through
 *     nir_read_first_invocation, replaces both with 0 when the unsigned
 *     minimum of the two is 0, stores them into the two variables above and
 *     returns NIR_LOWER_INSTR_PROGRESS_REPLACE.
 *  3. lower_ms_intrinsic() dispatches
 *     nir_intrinsic_set_vertex_and_primitive_count to that new function.
 *  4. filter_ms_intrinsic() additionally accepts that intrinsic.
 *  5. New set_ms_final_output_counts(): loads both variables. When
 *     hw_workgroup_size <= wave_size it emits
 *     nir_alloc_vertices_and_primitives_amd and returns the loaded values.
 *     Otherwise the subgroup with id 0 has its elected invocation store
 *     vec2(num_prm, num_vtx) to shared memory at
 *     workgroup_info_addr + lds_ms_num_prims, executes a workgroup-scope
 *     barrier and emits the alloc intrinsic; every other subgroup executes
 *     the matching barrier, has its elected invocation load the pair back,
 *     broadcasts both channels with nir_read_first_invocation and stores
 *     them into the variables. The out parameters are then re-loaded from
 *     the variables after the if/else.
 *  6. emit_ms_finale() calls set_nv_ms_final_output_counts() when
 *     b->shader->info.mesh.nv is set and the new function otherwise.
 *  7. ac_nir_lower_ngg_ms() creates the two variables as function-local
 *     uint variables named "vertex_count_var" and "primitive_count_var".
 */
 src/amd/common/ac_nir_lower_ngg.c | 115 +++++++++++++++++++++++++++++-
 1 file changed, 113 insertions(+), 2 deletions(-)

diff --git a/src/amd/common/ac_nir_lower_ngg.c b/src/amd/common/ac_nir_lower_ngg.c
index 5b9b95255a9..24a125c7c15 100644
--- a/src/amd/common/ac_nir_lower_ngg.c
+++ b/src/amd/common/ac_nir_lower_ngg.c
@@ -172,6 +172,8 @@ typedef struct
 
    nir_ssa_def *workgroup_index;
    nir_variable *out_variables[VARYING_SLOT_MAX * 4];
+   nir_variable *primitive_count_var;
+   nir_variable *vertex_count_var;
 
    /* True if the lowering needs to insert the layer output. */
    bool insert_layer_output;
@@ -2816,6 +2818,25 @@ lower_ms_load_workgroup_index(nir_builder *b,
    return s->workgroup_index;
 }
 
+static nir_ssa_def *
+lower_ms_set_vertex_and_primitive_count(nir_builder *b,
+                                        nir_intrinsic_instr *intrin,
+                                        lower_ngg_ms_state *s)
+{
+   /* If either the number of vertices or primitives is zero, set both of them to zero. */
+   nir_ssa_def *num_vtx = nir_read_first_invocation(b, intrin->src[0].ssa);
+   nir_ssa_def *num_prm = nir_read_first_invocation(b, intrin->src[1].ssa);
+   nir_ssa_def *zero = nir_imm_int(b, 0);
+   nir_ssa_def *is_either_zero = nir_ieq(b, nir_umin(b, num_vtx, num_prm), zero);
+   num_vtx = nir_bcsel(b, is_either_zero, zero, num_vtx);
+   num_prm = nir_bcsel(b, is_either_zero, zero, num_prm);
+
+   nir_store_var(b, s->vertex_count_var, num_vtx, 0x1);
+   nir_store_var(b, s->primitive_count_var, num_prm, 0x1);
+
+   return NIR_LOWER_INSTR_PROGRESS_REPLACE;
+}
+
 static nir_ssa_def *
 update_ms_scoped_barrier(nir_builder *b,
                          nir_intrinsic_instr *intrin,
@@ -2861,6 +2882,8 @@ lower_ms_intrinsic(nir_builder *b, nir_instr *instr, void *state)
       return update_ms_scoped_barrier(b, intrin, s);
    case nir_intrinsic_load_workgroup_index:
       return lower_ms_load_workgroup_index(b, intrin, s);
+   case nir_intrinsic_set_vertex_and_primitive_count:
+      return lower_ms_set_vertex_and_primitive_count(b, intrin, s);
    default:
       unreachable("Not a lowerable mesh shader intrinsic.");
    }
@@ -2881,7 +2904,8 @@ filter_ms_intrinsic(const nir_instr *instr,
           intrin->intrinsic == nir_intrinsic_store_per_primitive_output ||
           intrin->intrinsic == nir_intrinsic_load_per_primitive_output ||
           intrin->intrinsic == nir_intrinsic_scoped_barrier ||
-          intrin->intrinsic == nir_intrinsic_load_workgroup_index;
+          intrin->intrinsic == nir_intrinsic_load_workgroup_index ||
+          intrin->intrinsic == nir_intrinsic_set_vertex_and_primitive_count;
 }
 
 static void
@@ -3032,6 +3056,85 @@ set_nv_ms_final_output_counts(nir_builder *b,
    *out_num_vtx = num_vtx;
 }
 
+static void
+set_ms_final_output_counts(nir_builder *b,
+                           lower_ngg_ms_state *s,
+                           nir_ssa_def **out_num_prm,
+                           nir_ssa_def **out_num_vtx)
+{
+   /* The spec allows the numbers to be divergent, and in that case we need to
+    * use the values from the first invocation. Also the HW requires us to set
+    * both to 0 if either was 0.
+    *
+    * These are already done by the lowering.
+    */
+   nir_ssa_def *num_prm = nir_load_var(b, s->primitive_count_var);
+   nir_ssa_def *num_vtx = nir_load_var(b, s->vertex_count_var);
+
+   if (s->hw_workgroup_size <= s->wave_size) {
+      /* Single-wave mesh shader workgroup. */
+      nir_alloc_vertices_and_primitives_amd(b, num_vtx, num_prm);
+      *out_num_prm = num_prm;
+      *out_num_vtx = num_vtx;
+      return;
+   }
+
+   /* Multi-wave mesh shader workgroup:
+    * We need to use LDS to distribute the correct values to the other waves.
+    *
+    * TODO:
+    * If we can prove that the values are workgroup-uniform, we can skip this
+    * and just use whatever the current wave has. However, NIR divergence analysis
+    * currently doesn't support this.
+    */
+
+   nir_ssa_def *zero = nir_imm_int(b, 0);
+
+   nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, nir_load_subgroup_id(b), 0));
+   {
+      nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
+      {
+         nir_store_shared(b, nir_vec2(b, num_prm, num_vtx), zero,
+                          .base = s->layout.lds.workgroup_info_addr + lds_ms_num_prims);
+      }
+      nir_pop_if(b, if_elected);
+
+      nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP,
+                         .memory_scope = NIR_SCOPE_WORKGROUP,
+                         .memory_semantics = NIR_MEMORY_ACQ_REL,
+                         .memory_modes = nir_var_mem_shared);
+
+      nir_alloc_vertices_and_primitives_amd(b, num_vtx, num_prm);
+   }
+   nir_push_else(b, if_wave_0);
+   {
+      nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP,
+                         .memory_scope = NIR_SCOPE_WORKGROUP,
+                         .memory_semantics = NIR_MEMORY_ACQ_REL,
+                         .memory_modes = nir_var_mem_shared);
+
+      nir_ssa_def *prm_vtx = NULL;
+      nir_ssa_def *dont_care_2x32 = nir_ssa_undef(b, 2, 32);
+      nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
+      {
+         prm_vtx = nir_load_shared(b, 2, 32, zero,
+                                   .base = s->layout.lds.workgroup_info_addr + lds_ms_num_prims);
+      }
+      nir_pop_if(b, if_elected);
+
+      prm_vtx = nir_if_phi(b, prm_vtx, dont_care_2x32);
+      num_prm = nir_read_first_invocation(b, nir_channel(b, prm_vtx, 0));
+      num_vtx = nir_read_first_invocation(b, nir_channel(b, prm_vtx, 1));
+
+      nir_store_var(b, s->primitive_count_var, num_prm, 0x1);
+      nir_store_var(b, s->vertex_count_var, num_vtx, 0x1);
+   }
+   nir_pop_if(b, if_wave_0);
+
+   *out_num_prm = nir_load_var(b, s->primitive_count_var);
+   *out_num_vtx = nir_load_var(b, s->vertex_count_var);
+}
+
 static void
 emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
 {
@@ -3045,7 +3148,10 @@ emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
    nir_ssa_def *num_prm;
    nir_ssa_def *num_vtx;
 
-   set_nv_ms_final_output_counts(b, s, &num_prm, &num_vtx);
+   if (b->shader->info.mesh.nv)
+      set_nv_ms_final_output_counts(b, s, &num_prm, &num_vtx);
+   else
+      set_ms_final_output_counts(b, s, &num_prm, &num_vtx);
 
    nir_ssa_def *invocation_index = nir_load_local_invocation_index(b);
 
@@ -3403,6 +3509,11 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
    assert(impl);
 
+   state.vertex_count_var =
+      nir_local_variable_create(impl, glsl_uint_type(), "vertex_count_var");
+   state.primitive_count_var =
+      nir_local_variable_create(impl, glsl_uint_type(), "primitive_count_var");
+
    nir_builder builder;
    nir_builder *b = &builder; /* This is to avoid the & */
    nir_builder_init(b, impl);