ac/nir/ngg: Add EXT_mesh_shader vertex/primitive count.
In EXT_mesh_shader the vertex and primitive counts are set using a built-in SetMeshOutputsEXT function. Signed-off-by: Timur Kristóf <timur.kristof@gmail.com> Reviewed-by: Rhys Perry <pendingchaos02@gmail.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/18367>
This commit is contained in:
@@ -172,6 +172,8 @@ typedef struct
|
||||
|
||||
nir_ssa_def *workgroup_index;
|
||||
nir_variable *out_variables[VARYING_SLOT_MAX * 4];
|
||||
nir_variable *primitive_count_var;
|
||||
nir_variable *vertex_count_var;
|
||||
|
||||
/* True if the lowering needs to insert the layer output. */
|
||||
bool insert_layer_output;
|
||||
@@ -2816,6 +2818,25 @@ lower_ms_load_workgroup_index(nir_builder *b,
|
||||
return s->workgroup_index;
|
||||
}
|
||||
|
||||
static nir_ssa_def *
|
||||
lower_ms_set_vertex_and_primitive_count(nir_builder *b,
|
||||
nir_intrinsic_instr *intrin,
|
||||
lower_ngg_ms_state *s)
|
||||
{
|
||||
/* If either the number of vertices or primitives is zero, set both of them to zero. */
|
||||
nir_ssa_def *num_vtx = nir_read_first_invocation(b, intrin->src[0].ssa);
|
||||
nir_ssa_def *num_prm = nir_read_first_invocation(b, intrin->src[1].ssa);
|
||||
nir_ssa_def *zero = nir_imm_int(b, 0);
|
||||
nir_ssa_def *is_either_zero = nir_ieq(b, nir_umin(b, num_vtx, num_prm), zero);
|
||||
num_vtx = nir_bcsel(b, is_either_zero, zero, num_vtx);
|
||||
num_prm = nir_bcsel(b, is_either_zero, zero, num_prm);
|
||||
|
||||
nir_store_var(b, s->vertex_count_var, num_vtx, 0x1);
|
||||
nir_store_var(b, s->primitive_count_var, num_prm, 0x1);
|
||||
|
||||
return NIR_LOWER_INSTR_PROGRESS_REPLACE;
|
||||
}
|
||||
|
||||
static nir_ssa_def *
|
||||
update_ms_scoped_barrier(nir_builder *b,
|
||||
nir_intrinsic_instr *intrin,
|
||||
@@ -2861,6 +2882,8 @@ lower_ms_intrinsic(nir_builder *b, nir_instr *instr, void *state)
|
||||
return update_ms_scoped_barrier(b, intrin, s);
|
||||
case nir_intrinsic_load_workgroup_index:
|
||||
return lower_ms_load_workgroup_index(b, intrin, s);
|
||||
case nir_intrinsic_set_vertex_and_primitive_count:
|
||||
return lower_ms_set_vertex_and_primitive_count(b, intrin, s);
|
||||
default:
|
||||
unreachable("Not a lowerable mesh shader intrinsic.");
|
||||
}
|
||||
@@ -2881,7 +2904,8 @@ filter_ms_intrinsic(const nir_instr *instr,
|
||||
intrin->intrinsic == nir_intrinsic_store_per_primitive_output ||
|
||||
intrin->intrinsic == nir_intrinsic_load_per_primitive_output ||
|
||||
intrin->intrinsic == nir_intrinsic_scoped_barrier ||
|
||||
intrin->intrinsic == nir_intrinsic_load_workgroup_index;
|
||||
intrin->intrinsic == nir_intrinsic_load_workgroup_index ||
|
||||
intrin->intrinsic == nir_intrinsic_set_vertex_and_primitive_count;
|
||||
}
|
||||
|
||||
static void
|
||||
@@ -3032,6 +3056,85 @@ set_nv_ms_final_output_counts(nir_builder *b,
|
||||
*out_num_vtx = num_vtx;
|
||||
}
|
||||
|
||||
static void
|
||||
set_ms_final_output_counts(nir_builder *b,
|
||||
lower_ngg_ms_state *s,
|
||||
nir_ssa_def **out_num_prm,
|
||||
nir_ssa_def **out_num_vtx)
|
||||
{
|
||||
/* The spec allows the numbers to be divergent, and in that case we need to
|
||||
* use the values from the first invocation. Also the HW requires us to set
|
||||
* both to 0 if either was 0.
|
||||
*
|
||||
* These are already done by the lowering.
|
||||
*/
|
||||
nir_ssa_def *num_prm = nir_load_var(b, s->primitive_count_var);
|
||||
nir_ssa_def *num_vtx = nir_load_var(b, s->vertex_count_var);
|
||||
|
||||
if (s->hw_workgroup_size <= s->wave_size) {
|
||||
/* Single-wave mesh shader workgroup. */
|
||||
nir_alloc_vertices_and_primitives_amd(b, num_vtx, num_prm);
|
||||
*out_num_prm = num_prm;
|
||||
*out_num_vtx = num_vtx;
|
||||
return;
|
||||
}
|
||||
|
||||
/* Multi-wave mesh shader workgroup:
|
||||
* We need to use LDS to distribute the correct values to the other waves.
|
||||
*
|
||||
* TODO:
|
||||
* If we can prove that the values are workgroup-uniform, we can skip this
|
||||
* and just use whatever the current wave has. However, NIR divergence analysis
|
||||
* currently doesn't support this.
|
||||
*/
|
||||
|
||||
nir_ssa_def *zero = nir_imm_int(b, 0);
|
||||
|
||||
nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, nir_load_subgroup_id(b), 0));
|
||||
{
|
||||
nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
|
||||
{
|
||||
nir_store_shared(b, nir_vec2(b, num_prm, num_vtx), zero,
|
||||
.base = s->layout.lds.workgroup_info_addr + lds_ms_num_prims);
|
||||
}
|
||||
nir_pop_if(b, if_elected);
|
||||
|
||||
nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP,
|
||||
.memory_scope = NIR_SCOPE_WORKGROUP,
|
||||
.memory_semantics = NIR_MEMORY_ACQ_REL,
|
||||
.memory_modes = nir_var_mem_shared);
|
||||
|
||||
nir_alloc_vertices_and_primitives_amd(b, num_vtx, num_prm);
|
||||
}
|
||||
nir_push_else(b, if_wave_0);
|
||||
{
|
||||
nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP,
|
||||
.memory_scope = NIR_SCOPE_WORKGROUP,
|
||||
.memory_semantics = NIR_MEMORY_ACQ_REL,
|
||||
.memory_modes = nir_var_mem_shared);
|
||||
|
||||
nir_ssa_def *prm_vtx = NULL;
|
||||
nir_ssa_def *dont_care_2x32 = nir_ssa_undef(b, 2, 32);
|
||||
nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
|
||||
{
|
||||
prm_vtx = nir_load_shared(b, 2, 32, zero,
|
||||
.base = s->layout.lds.workgroup_info_addr + lds_ms_num_prims);
|
||||
}
|
||||
nir_pop_if(b, if_elected);
|
||||
|
||||
prm_vtx = nir_if_phi(b, prm_vtx, dont_care_2x32);
|
||||
num_prm = nir_read_first_invocation(b, nir_channel(b, prm_vtx, 0));
|
||||
num_vtx = nir_read_first_invocation(b, nir_channel(b, prm_vtx, 1));
|
||||
|
||||
nir_store_var(b, s->primitive_count_var, num_prm, 0x1);
|
||||
nir_store_var(b, s->vertex_count_var, num_vtx, 0x1);
|
||||
}
|
||||
nir_pop_if(b, if_wave_0);
|
||||
|
||||
*out_num_prm = nir_load_var(b, s->primitive_count_var);
|
||||
*out_num_vtx = nir_load_var(b, s->vertex_count_var);
|
||||
}
|
||||
|
||||
static void
|
||||
emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
|
||||
{
|
||||
@@ -3045,7 +3148,10 @@ emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
|
||||
nir_ssa_def *num_prm;
|
||||
nir_ssa_def *num_vtx;
|
||||
|
||||
set_nv_ms_final_output_counts(b, s, &num_prm, &num_vtx);
|
||||
if (b->shader->info.mesh.nv)
|
||||
set_nv_ms_final_output_counts(b, s, &num_prm, &num_vtx);
|
||||
else
|
||||
set_ms_final_output_counts(b, s, &num_prm, &num_vtx);
|
||||
|
||||
nir_ssa_def *invocation_index = nir_load_local_invocation_index(b);
|
||||
|
||||
@@ -3403,6 +3509,11 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
|
||||
nir_function_impl *impl = nir_shader_get_entrypoint(shader);
|
||||
assert(impl);
|
||||
|
||||
state.vertex_count_var =
|
||||
nir_local_variable_create(impl, glsl_uint_type(), "vertex_count_var");
|
||||
state.primitive_count_var =
|
||||
nir_local_variable_create(impl, glsl_uint_type(), "primitive_count_var");
|
||||
|
||||
nir_builder builder;
|
||||
nir_builder *b = &builder; /* This is to avoid the & */
|
||||
nir_builder_init(b, impl);
|
||||
|
||||
Reference in New Issue
Block a user