ac/nir/ngg: support gs streamout

Port from radeonsi.

Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>
Signed-off-by: Qiang Yu <yuq825@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/17654>
This commit is contained in:
Qiang Yu
2022-06-30 16:10:53 +08:00
committed by Marge Bot
parent 3fe8f88124
commit 074f3216f2
3 changed files with 125 additions and 7 deletions
+2 -1
View File
@@ -144,7 +144,8 @@ ac_nir_lower_ngg_gs(nir_shader *shader,
unsigned gs_out_vtx_bytes,
unsigned gs_total_out_vtx_bytes,
bool provoking_vtx_last,
bool can_cull);
bool can_cull,
bool disable_streamout);
void
ac_nir_lower_ngg_ms(nir_shader *shader,
+121 -5
View File
@@ -112,6 +112,7 @@ typedef struct
bool output_compile_time_known;
bool provoking_vertex_last;
bool can_cull;
bool streamout_enabled;
gs_output_info output_info[VARYING_SLOT_MAX];
} lower_ngg_gs_state;
@@ -2572,6 +2573,110 @@ ngg_gs_cull_primitive(nir_builder *b, nir_ssa_def *tid_in_tg, nir_ssa_def *max_v
return nir_load_var(b, primflag_var);
}
static void
ngg_gs_build_streamout(nir_builder *b, lower_ngg_gs_state *st)
{
nir_xfb_info *info = nir_gather_xfb_info_from_intrinsics(b->shader, NULL);
if (unlikely(!info))
return;
nir_ssa_def *tid_in_tg = nir_load_local_invocation_index(b);
nir_ssa_def *max_vtxcnt = nir_load_workgroup_num_input_vertices_amd(b);
nir_ssa_def *out_vtx_lds_addr = ngg_gs_out_vertex_addr(b, tid_in_tg, st);
nir_ssa_def *prim_live[4] = {0};
nir_ssa_def *gen_prim[4] = {0};
nir_ssa_def *export_seq[4] = {0};
nir_ssa_def *out_vtx_primflag[4] = {0};
for (unsigned stream = 0; stream < 4; stream++) {
if (!(info->streams_written & BITFIELD_BIT(stream)))
continue;
out_vtx_primflag[stream] =
ngg_gs_load_out_vtx_primflag(b, stream, tid_in_tg, out_vtx_lds_addr, max_vtxcnt, st);
/* Check bit 0 of primflag for primitive alive, it's set for every last
* vertex of a primitive.
*/
prim_live[stream] = nir_i2b(b, nir_iand_imm(b, out_vtx_primflag[stream], 1));
unsigned scratch_stride = ALIGN(st->max_num_waves, 4);
/* We want to export primitives to streamout buffer in sequence,
* but not all vertices are alive or mark end of a primitive, so
* there're "holes". We don't need continous invocations to write
* primitives to streamout buffer like final vertex export, so
* just repack to get the sequence (export_seq) is enough, no need
* to do compaction.
*
* Use separate scratch space for each stream to avoid barrier.
* TODO: we may further reduce barriers by writing to all stream
* LDS at once, then we only need one barrier instead of one each
* stream..
*/
wg_repack_result rep =
repack_invocations_in_workgroup(b, prim_live[stream],
st->lds_addr_gs_scratch + stream * scratch_stride,
st->max_num_waves, st->wave_size);
/* nir_intrinsic_set_vertex_and_primitive_count can also get primitive count of
* current wave, but still need LDS to sum all wave's count to get workgroup count.
* And we need repack to export primitive to streamout buffer anyway, so do here.
*/
gen_prim[stream] = rep.num_repacked_invocations;
export_seq[stream] = rep.repacked_invocation_index;
}
/* Workgroup barrier: wait for LDS scratch reads finish. */
nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP,
.memory_scope = NIR_SCOPE_WORKGROUP,
.memory_semantics = NIR_MEMORY_ACQ_REL,
.memory_modes = nir_var_mem_shared);
/* Get global buffer offset where this workgroup will stream out data to. */
nir_ssa_def *emit_prim[4] = {0};
nir_ssa_def *buffer_offsets[4] = {0};
nir_ssa_def *so_buffer[4] = {0};
nir_ssa_def *prim_stride[4] = {0};
ngg_build_streamout_buffer_info(b, info, st->lds_addr_gs_scratch, tid_in_tg, gen_prim,
prim_stride, so_buffer, buffer_offsets, emit_prim);
/* GS use packed location for vertex LDS storage. */
int slot_to_register[NUM_TOTAL_VARYING_SLOTS];
for (int i = 0; i < info->output_count; i++) {
unsigned location = info->outputs[i].location;
slot_to_register[location] =
util_bitcount64(b->shader->info.outputs_written & BITFIELD64_MASK(location));
}
for (unsigned stream = 0; stream < 4; stream++) {
if (!(info->streams_written & BITFIELD_BIT(stream)))
continue;
nir_ssa_def *can_emit = nir_ilt(b, export_seq[stream], emit_prim[stream]);
nir_if *if_emit = nir_push_if(b, nir_iand(b, can_emit, prim_live[stream]));
{
/* Get streamout buffer vertex index for the first vertex of this primitive. */
nir_ssa_def *vtx_buffer_idx =
nir_imul_imm(b, export_seq[stream], st->num_vertices_per_primitive);
/* Get all vertices' lds address of this primitive. */
nir_ssa_def *exported_vtx_lds_addr[3];
ngg_gs_out_prim_all_vtxptr(b, tid_in_tg, out_vtx_lds_addr,
out_vtx_primflag[stream], st,
exported_vtx_lds_addr);
/* Write all vertices of this primitive to streamout buffer. */
for (unsigned i = 0; i < st->num_vertices_per_primitive; i++) {
ngg_build_streamout_vertex(b, info, stream, slot_to_register,
so_buffer, buffer_offsets,
nir_iadd_imm(b, vtx_buffer_idx, i),
exported_vtx_lds_addr[i]);
}
}
nir_pop_if(b, if_emit);
}
}
static void
ngg_gs_finale(nir_builder *b, lower_ngg_gs_state *s)
{
@@ -2589,9 +2694,7 @@ ngg_gs_finale(nir_builder *b, lower_ngg_gs_state *s)
nir_pop_if(b, if_wave_0);
}
/* Workgroup barrier: wait for all GS threads to finish */
nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
.memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
/* Workgroup barrier already emitted, we can assume all GS output stores are done by now. */
nir_ssa_def *out_vtx_primflag_0 = ngg_gs_load_out_vtx_primflag(b, 0, tid_in_tg, out_vtx_lds_addr, max_vtxcnt, s);
@@ -2654,7 +2757,8 @@ ac_nir_lower_ngg_gs(nir_shader *shader,
unsigned gs_out_vtx_bytes,
unsigned gs_total_out_vtx_bytes,
bool provoking_vertex_last,
bool can_cull)
bool can_cull,
bool disable_streamout)
{
nir_function_impl *impl = nir_shader_get_entrypoint(shader);
assert(impl);
@@ -2669,9 +2773,14 @@ ac_nir_lower_ngg_gs(nir_shader *shader,
.lds_bytes_per_gs_out_vertex = gs_out_vtx_bytes + 4u,
.provoking_vertex_last = provoking_vertex_last,
.can_cull = can_cull,
.streamout_enabled = shader->xfb_info && !disable_streamout,
};
unsigned lds_scratch_bytes = DIV_ROUND_UP(state.max_num_waves, 4u) * 4u;
unsigned lds_scratch_bytes = ALIGN(state.max_num_waves, 4u);
/* streamout take 8 dwords for buffer offset and emit vertex per stream */
if (state.streamout_enabled)
lds_scratch_bytes = MAX2(lds_scratch_bytes, 32);
unsigned total_lds_bytes = state.lds_addr_gs_scratch + lds_scratch_bytes;
shader->info.shared_size = total_lds_bytes;
@@ -2715,6 +2824,13 @@ ac_nir_lower_ngg_gs(nir_shader *shader,
b->cursor = nir_after_cf_list(&if_gs_thread->then_list);
nir_pop_if(b, if_gs_thread);
/* Workgroup barrier: wait for all GS threads to finish */
nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
.memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
if (state.streamout_enabled)
ngg_gs_build_streamout(b, &state);
/* Lower the GS intrinsics */
lower_ngg_gs_intrinsics(shader, &state);
b->cursor = nir_after_cf_list(&impl->body);
+2 -1
View File
@@ -1341,7 +1341,8 @@ void radv_lower_ngg(struct radv_device *device, struct radv_pipeline_stage *ngg_
assert(info->is_ngg);
NIR_PASS_V(nir, ac_nir_lower_ngg_gs, info->wave_size, info->workgroup_size,
info->ngg_info.esgs_ring_size, info->gs.gsvs_vertex_size,
info->ngg_info.ngg_emit_size * 4u, pl_key->vs.provoking_vtx_last, false);
info->ngg_info.ngg_emit_size * 4u, pl_key->vs.provoking_vtx_last,
false, true);
} else if (nir->info.stage == MESA_SHADER_MESH) {
bool scratch_ring = false;
NIR_PASS_V(nir, ac_nir_lower_ngg_ms, &scratch_ring, info->wave_size, pl_key->has_multiview_view_index);