ac/nir/ngg: Add EXT_mesh_shader vertex/primitive count.

In EXT_mesh_shader the vertex and primitive counts are set using a
built-in SetMeshOutputsEXT function.

Signed-off-by: Timur Kristóf <timur.kristof@gmail.com>
Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/18367>
This commit is contained in:
Timur Kristóf
2022-02-28 14:24:17 +01:00
committed by Marge Bot
parent 448d09d44a
commit c7ff93a766
+113 -2
View File
@@ -172,6 +172,8 @@ typedef struct
nir_ssa_def *workgroup_index;
nir_variable *out_variables[VARYING_SLOT_MAX * 4];
nir_variable *primitive_count_var;
nir_variable *vertex_count_var;
/* True if the lowering needs to insert the layer output. */
bool insert_layer_output;
@@ -2816,6 +2818,25 @@ lower_ms_load_workgroup_index(nir_builder *b,
return s->workgroup_index;
}
static nir_ssa_def *
lower_ms_set_vertex_and_primitive_count(nir_builder *b,
nir_intrinsic_instr *intrin,
lower_ngg_ms_state *s)
{
/* If either the number of vertices or primitives is zero, set both of them to zero. */
nir_ssa_def *num_vtx = nir_read_first_invocation(b, intrin->src[0].ssa);
nir_ssa_def *num_prm = nir_read_first_invocation(b, intrin->src[1].ssa);
nir_ssa_def *zero = nir_imm_int(b, 0);
nir_ssa_def *is_either_zero = nir_ieq(b, nir_umin(b, num_vtx, num_prm), zero);
num_vtx = nir_bcsel(b, is_either_zero, zero, num_vtx);
num_prm = nir_bcsel(b, is_either_zero, zero, num_prm);
nir_store_var(b, s->vertex_count_var, num_vtx, 0x1);
nir_store_var(b, s->primitive_count_var, num_prm, 0x1);
return NIR_LOWER_INSTR_PROGRESS_REPLACE;
}
static nir_ssa_def *
update_ms_scoped_barrier(nir_builder *b,
nir_intrinsic_instr *intrin,
@@ -2861,6 +2882,8 @@ lower_ms_intrinsic(nir_builder *b, nir_instr *instr, void *state)
return update_ms_scoped_barrier(b, intrin, s);
case nir_intrinsic_load_workgroup_index:
return lower_ms_load_workgroup_index(b, intrin, s);
case nir_intrinsic_set_vertex_and_primitive_count:
return lower_ms_set_vertex_and_primitive_count(b, intrin, s);
default:
unreachable("Not a lowerable mesh shader intrinsic.");
}
@@ -2881,7 +2904,8 @@ filter_ms_intrinsic(const nir_instr *instr,
intrin->intrinsic == nir_intrinsic_store_per_primitive_output ||
intrin->intrinsic == nir_intrinsic_load_per_primitive_output ||
intrin->intrinsic == nir_intrinsic_scoped_barrier ||
intrin->intrinsic == nir_intrinsic_load_workgroup_index;
intrin->intrinsic == nir_intrinsic_load_workgroup_index ||
intrin->intrinsic == nir_intrinsic_set_vertex_and_primitive_count;
}
static void
@@ -3032,6 +3056,85 @@ set_nv_ms_final_output_counts(nir_builder *b,
*out_num_vtx = num_vtx;
}
static void
set_ms_final_output_counts(nir_builder *b,
lower_ngg_ms_state *s,
nir_ssa_def **out_num_prm,
nir_ssa_def **out_num_vtx)
{
/* The spec allows the numbers to be divergent, and in that case we need to
* use the values from the first invocation. Also the HW requires us to set
* both to 0 if either was 0.
*
* These are already done by the lowering.
*/
nir_ssa_def *num_prm = nir_load_var(b, s->primitive_count_var);
nir_ssa_def *num_vtx = nir_load_var(b, s->vertex_count_var);
if (s->hw_workgroup_size <= s->wave_size) {
/* Single-wave mesh shader workgroup. */
nir_alloc_vertices_and_primitives_amd(b, num_vtx, num_prm);
*out_num_prm = num_prm;
*out_num_vtx = num_vtx;
return;
}
/* Multi-wave mesh shader workgroup:
* We need to use LDS to distribute the correct values to the other waves.
*
* TODO:
* If we can prove that the values are workgroup-uniform, we can skip this
* and just use whatever the current wave has. However, NIR divergence analysis
* currently doesn't support this.
*/
nir_ssa_def *zero = nir_imm_int(b, 0);
nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, nir_load_subgroup_id(b), 0));
{
nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
{
nir_store_shared(b, nir_vec2(b, num_prm, num_vtx), zero,
.base = s->layout.lds.workgroup_info_addr + lds_ms_num_prims);
}
nir_pop_if(b, if_elected);
nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP,
.memory_scope = NIR_SCOPE_WORKGROUP,
.memory_semantics = NIR_MEMORY_ACQ_REL,
.memory_modes = nir_var_mem_shared);
nir_alloc_vertices_and_primitives_amd(b, num_vtx, num_prm);
}
nir_push_else(b, if_wave_0);
{
nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP,
.memory_scope = NIR_SCOPE_WORKGROUP,
.memory_semantics = NIR_MEMORY_ACQ_REL,
.memory_modes = nir_var_mem_shared);
nir_ssa_def *prm_vtx = NULL;
nir_ssa_def *dont_care_2x32 = nir_ssa_undef(b, 2, 32);
nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
{
prm_vtx = nir_load_shared(b, 2, 32, zero,
.base = s->layout.lds.workgroup_info_addr + lds_ms_num_prims);
}
nir_pop_if(b, if_elected);
prm_vtx = nir_if_phi(b, prm_vtx, dont_care_2x32);
num_prm = nir_read_first_invocation(b, nir_channel(b, prm_vtx, 0));
num_vtx = nir_read_first_invocation(b, nir_channel(b, prm_vtx, 1));
nir_store_var(b, s->primitive_count_var, num_prm, 0x1);
nir_store_var(b, s->vertex_count_var, num_vtx, 0x1);
}
nir_pop_if(b, if_wave_0);
*out_num_prm = nir_load_var(b, s->primitive_count_var);
*out_num_vtx = nir_load_var(b, s->vertex_count_var);
}
static void
emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
{
@@ -3045,7 +3148,10 @@ emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
nir_ssa_def *num_prm;
nir_ssa_def *num_vtx;
set_nv_ms_final_output_counts(b, s, &num_prm, &num_vtx);
if (b->shader->info.mesh.nv)
set_nv_ms_final_output_counts(b, s, &num_prm, &num_vtx);
else
set_ms_final_output_counts(b, s, &num_prm, &num_vtx);
nir_ssa_def *invocation_index = nir_load_local_invocation_index(b);
@@ -3403,6 +3509,11 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
nir_function_impl *impl = nir_shader_get_entrypoint(shader);
assert(impl);
state.vertex_count_var =
nir_local_variable_create(impl, glsl_uint_type(), "vertex_count_var");
state.primitive_count_var =
nir_local_variable_create(impl, glsl_uint_type(), "primitive_count_var");
nir_builder builder;
nir_builder *b = &builder; /* This is to avoid the & */
nir_builder_init(b, impl);