ac/nir/ngg: support gs streamout
Port from radeonsi. Reviewed-by: Timur Kristóf <timur.kristof@gmail.com> Signed-off-by: Qiang Yu <yuq825@gmail.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/17654>
This commit is contained in:
@@ -144,7 +144,8 @@ ac_nir_lower_ngg_gs(nir_shader *shader,
|
||||
unsigned gs_out_vtx_bytes,
|
||||
unsigned gs_total_out_vtx_bytes,
|
||||
bool provoking_vtx_last,
|
||||
bool can_cull);
|
||||
bool can_cull,
|
||||
bool disable_streamout);
|
||||
|
||||
void
|
||||
ac_nir_lower_ngg_ms(nir_shader *shader,
|
||||
|
||||
@@ -112,6 +112,7 @@ typedef struct
|
||||
bool output_compile_time_known;
|
||||
bool provoking_vertex_last;
|
||||
bool can_cull;
|
||||
bool streamout_enabled;
|
||||
gs_output_info output_info[VARYING_SLOT_MAX];
|
||||
} lower_ngg_gs_state;
|
||||
|
||||
@@ -2572,6 +2573,110 @@ ngg_gs_cull_primitive(nir_builder *b, nir_ssa_def *tid_in_tg, nir_ssa_def *max_v
|
||||
return nir_load_var(b, primflag_var);
|
||||
}
|
||||
|
||||
static void
|
||||
ngg_gs_build_streamout(nir_builder *b, lower_ngg_gs_state *st)
|
||||
{
|
||||
nir_xfb_info *info = nir_gather_xfb_info_from_intrinsics(b->shader, NULL);
|
||||
if (unlikely(!info))
|
||||
return;
|
||||
|
||||
nir_ssa_def *tid_in_tg = nir_load_local_invocation_index(b);
|
||||
nir_ssa_def *max_vtxcnt = nir_load_workgroup_num_input_vertices_amd(b);
|
||||
nir_ssa_def *out_vtx_lds_addr = ngg_gs_out_vertex_addr(b, tid_in_tg, st);
|
||||
nir_ssa_def *prim_live[4] = {0};
|
||||
nir_ssa_def *gen_prim[4] = {0};
|
||||
nir_ssa_def *export_seq[4] = {0};
|
||||
nir_ssa_def *out_vtx_primflag[4] = {0};
|
||||
for (unsigned stream = 0; stream < 4; stream++) {
|
||||
if (!(info->streams_written & BITFIELD_BIT(stream)))
|
||||
continue;
|
||||
|
||||
out_vtx_primflag[stream] =
|
||||
ngg_gs_load_out_vtx_primflag(b, stream, tid_in_tg, out_vtx_lds_addr, max_vtxcnt, st);
|
||||
|
||||
/* Check bit 0 of primflag for primitive alive, it's set for every last
|
||||
* vertex of a primitive.
|
||||
*/
|
||||
prim_live[stream] = nir_i2b(b, nir_iand_imm(b, out_vtx_primflag[stream], 1));
|
||||
|
||||
unsigned scratch_stride = ALIGN(st->max_num_waves, 4);
|
||||
|
||||
/* We want to export primitives to streamout buffer in sequence,
|
||||
* but not all vertices are alive or mark end of a primitive, so
|
||||
* there're "holes". We don't need continous invocations to write
|
||||
* primitives to streamout buffer like final vertex export, so
|
||||
* just repack to get the sequence (export_seq) is enough, no need
|
||||
* to do compaction.
|
||||
*
|
||||
* Use separate scratch space for each stream to avoid barrier.
|
||||
* TODO: we may further reduce barriers by writing to all stream
|
||||
* LDS at once, then we only need one barrier instead of one each
|
||||
* stream..
|
||||
*/
|
||||
wg_repack_result rep =
|
||||
repack_invocations_in_workgroup(b, prim_live[stream],
|
||||
st->lds_addr_gs_scratch + stream * scratch_stride,
|
||||
st->max_num_waves, st->wave_size);
|
||||
|
||||
/* nir_intrinsic_set_vertex_and_primitive_count can also get primitive count of
|
||||
* current wave, but still need LDS to sum all wave's count to get workgroup count.
|
||||
* And we need repack to export primitive to streamout buffer anyway, so do here.
|
||||
*/
|
||||
gen_prim[stream] = rep.num_repacked_invocations;
|
||||
export_seq[stream] = rep.repacked_invocation_index;
|
||||
}
|
||||
|
||||
/* Workgroup barrier: wait for LDS scratch reads finish. */
|
||||
nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP,
|
||||
.memory_scope = NIR_SCOPE_WORKGROUP,
|
||||
.memory_semantics = NIR_MEMORY_ACQ_REL,
|
||||
.memory_modes = nir_var_mem_shared);
|
||||
|
||||
/* Get global buffer offset where this workgroup will stream out data to. */
|
||||
nir_ssa_def *emit_prim[4] = {0};
|
||||
nir_ssa_def *buffer_offsets[4] = {0};
|
||||
nir_ssa_def *so_buffer[4] = {0};
|
||||
nir_ssa_def *prim_stride[4] = {0};
|
||||
ngg_build_streamout_buffer_info(b, info, st->lds_addr_gs_scratch, tid_in_tg, gen_prim,
|
||||
prim_stride, so_buffer, buffer_offsets, emit_prim);
|
||||
|
||||
/* GS use packed location for vertex LDS storage. */
|
||||
int slot_to_register[NUM_TOTAL_VARYING_SLOTS];
|
||||
for (int i = 0; i < info->output_count; i++) {
|
||||
unsigned location = info->outputs[i].location;
|
||||
slot_to_register[location] =
|
||||
util_bitcount64(b->shader->info.outputs_written & BITFIELD64_MASK(location));
|
||||
}
|
||||
|
||||
for (unsigned stream = 0; stream < 4; stream++) {
|
||||
if (!(info->streams_written & BITFIELD_BIT(stream)))
|
||||
continue;
|
||||
|
||||
nir_ssa_def *can_emit = nir_ilt(b, export_seq[stream], emit_prim[stream]);
|
||||
nir_if *if_emit = nir_push_if(b, nir_iand(b, can_emit, prim_live[stream]));
|
||||
{
|
||||
/* Get streamout buffer vertex index for the first vertex of this primitive. */
|
||||
nir_ssa_def *vtx_buffer_idx =
|
||||
nir_imul_imm(b, export_seq[stream], st->num_vertices_per_primitive);
|
||||
|
||||
/* Get all vertices' lds address of this primitive. */
|
||||
nir_ssa_def *exported_vtx_lds_addr[3];
|
||||
ngg_gs_out_prim_all_vtxptr(b, tid_in_tg, out_vtx_lds_addr,
|
||||
out_vtx_primflag[stream], st,
|
||||
exported_vtx_lds_addr);
|
||||
|
||||
/* Write all vertices of this primitive to streamout buffer. */
|
||||
for (unsigned i = 0; i < st->num_vertices_per_primitive; i++) {
|
||||
ngg_build_streamout_vertex(b, info, stream, slot_to_register,
|
||||
so_buffer, buffer_offsets,
|
||||
nir_iadd_imm(b, vtx_buffer_idx, i),
|
||||
exported_vtx_lds_addr[i]);
|
||||
}
|
||||
}
|
||||
nir_pop_if(b, if_emit);
|
||||
}
|
||||
}
|
||||
|
||||
static void
|
||||
ngg_gs_finale(nir_builder *b, lower_ngg_gs_state *s)
|
||||
{
|
||||
@@ -2589,9 +2694,7 @@ ngg_gs_finale(nir_builder *b, lower_ngg_gs_state *s)
|
||||
nir_pop_if(b, if_wave_0);
|
||||
}
|
||||
|
||||
/* Workgroup barrier: wait for all GS threads to finish */
|
||||
nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
|
||||
.memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
|
||||
/* Workgroup barrier already emitted, we can assume all GS output stores are done by now. */
|
||||
|
||||
nir_ssa_def *out_vtx_primflag_0 = ngg_gs_load_out_vtx_primflag(b, 0, tid_in_tg, out_vtx_lds_addr, max_vtxcnt, s);
|
||||
|
||||
@@ -2654,7 +2757,8 @@ ac_nir_lower_ngg_gs(nir_shader *shader,
|
||||
unsigned gs_out_vtx_bytes,
|
||||
unsigned gs_total_out_vtx_bytes,
|
||||
bool provoking_vertex_last,
|
||||
bool can_cull)
|
||||
bool can_cull,
|
||||
bool disable_streamout)
|
||||
{
|
||||
nir_function_impl *impl = nir_shader_get_entrypoint(shader);
|
||||
assert(impl);
|
||||
@@ -2669,9 +2773,14 @@ ac_nir_lower_ngg_gs(nir_shader *shader,
|
||||
.lds_bytes_per_gs_out_vertex = gs_out_vtx_bytes + 4u,
|
||||
.provoking_vertex_last = provoking_vertex_last,
|
||||
.can_cull = can_cull,
|
||||
.streamout_enabled = shader->xfb_info && !disable_streamout,
|
||||
};
|
||||
|
||||
unsigned lds_scratch_bytes = DIV_ROUND_UP(state.max_num_waves, 4u) * 4u;
|
||||
unsigned lds_scratch_bytes = ALIGN(state.max_num_waves, 4u);
|
||||
/* streamout take 8 dwords for buffer offset and emit vertex per stream */
|
||||
if (state.streamout_enabled)
|
||||
lds_scratch_bytes = MAX2(lds_scratch_bytes, 32);
|
||||
|
||||
unsigned total_lds_bytes = state.lds_addr_gs_scratch + lds_scratch_bytes;
|
||||
shader->info.shared_size = total_lds_bytes;
|
||||
|
||||
@@ -2715,6 +2824,13 @@ ac_nir_lower_ngg_gs(nir_shader *shader,
|
||||
b->cursor = nir_after_cf_list(&if_gs_thread->then_list);
|
||||
nir_pop_if(b, if_gs_thread);
|
||||
|
||||
/* Workgroup barrier: wait for all GS threads to finish */
|
||||
nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
|
||||
.memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
|
||||
|
||||
if (state.streamout_enabled)
|
||||
ngg_gs_build_streamout(b, &state);
|
||||
|
||||
/* Lower the GS intrinsics */
|
||||
lower_ngg_gs_intrinsics(shader, &state);
|
||||
b->cursor = nir_after_cf_list(&impl->body);
|
||||
|
||||
@@ -1341,7 +1341,8 @@ void radv_lower_ngg(struct radv_device *device, struct radv_pipeline_stage *ngg_
|
||||
assert(info->is_ngg);
|
||||
NIR_PASS_V(nir, ac_nir_lower_ngg_gs, info->wave_size, info->workgroup_size,
|
||||
info->ngg_info.esgs_ring_size, info->gs.gsvs_vertex_size,
|
||||
info->ngg_info.ngg_emit_size * 4u, pl_key->vs.provoking_vtx_last, false);
|
||||
info->ngg_info.ngg_emit_size * 4u, pl_key->vs.provoking_vtx_last,
|
||||
false, true);
|
||||
} else if (nir->info.stage == MESA_SHADER_MESH) {
|
||||
bool scratch_ring = false;
|
||||
NIR_PASS_V(nir, ac_nir_lower_ngg_ms, &scratch_ring, info->wave_size, pl_key->has_multiview_view_index);
|
||||
|
||||
Reference in New Issue
Block a user