radv: use shared ac_legacy_gs_compute_subgroup_info

Reviewed-by: Samuel Pitoiset <samuel.pitoiset@gmail.com>
Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/35473>
This commit is contained in:
Marek Olšák
2025-05-31 08:24:42 -04:00
committed by Marge Bot
parent 8a1e357f71
commit d674e97d5c
+21 -107
View File
@@ -688,10 +688,25 @@ gather_shader_info_tes(struct radv_device *device, const nir_shader *nir, struct
}
static void
radv_init_legacy_gs_ring_info(const struct radv_device *device, struct radv_shader_info *gs_info)
radv_get_legacy_gs_info(const struct radv_device *device, struct radv_shader_info *gs_info)
{
const struct radv_physical_device *pdev = radv_device_physical(device);
struct radv_legacy_gs_info *gs_ring_info = &gs_info->gs_ring_info;
struct radv_legacy_gs_info *out = &gs_info->gs_ring_info;
const unsigned esgs_vertex_stride = radv_compute_esgs_itemsize(device, gs_info->gs.num_linked_inputs);
ac_legacy_gs_subgroup_info info;
ac_legacy_gs_compute_subgroup_info(gs_info->gs.input_prim, gs_info->gs.vertices_out, gs_info->gs.invocations,
esgs_vertex_stride, &info);
const uint32_t lds_granularity = pdev->info.lds_encode_granularity;
const uint32_t total_lds_bytes = align(info.esgs_lds_size * 4, lds_granularity);
out->gs_inst_prims_in_subgroup = info.gs_inst_prims_in_subgroup;
out->es_verts_per_subgroup = info.es_verts_per_subgroup;
out->gs_prims_per_subgroup = info.gs_prims_per_subgroup;
out->esgs_itemsize = esgs_vertex_stride / 4;
out->lds_size = total_lds_bytes / lds_granularity;
unsigned num_se = pdev->info.max_se;
unsigned wave_size = 64;
unsigned max_gs_waves = 32 * num_se; /* max 32 per SE on GCN */
@@ -704,9 +719,9 @@ radv_init_legacy_gs_ring_info(const struct radv_device *device, struct radv_shad
unsigned max_size = ((unsigned)(63.999 * 1024 * 1024) & ~255) * num_se;
/* Calculate the minimum size. */
unsigned min_esgs_ring_size = align(gs_ring_info->esgs_itemsize * 4 * gs_vertex_reuse * wave_size, alignment);
unsigned min_esgs_ring_size = align(esgs_vertex_stride * gs_vertex_reuse * wave_size, alignment);
/* These are recommended sizes, not minimum sizes. */
unsigned esgs_ring_size = max_gs_waves * 2 * wave_size * gs_ring_info->esgs_itemsize * 4 * gs_info->gs.vertices_in;
unsigned esgs_ring_size = max_gs_waves * 2 * wave_size * esgs_vertex_stride * gs_info->gs.vertices_in;
unsigned gsvs_ring_size = max_gs_waves * 2 * wave_size * gs_info->gs.max_gsvs_emit_size;
min_esgs_ring_size = align(min_esgs_ring_size, alignment);
@@ -714,110 +729,9 @@ radv_init_legacy_gs_ring_info(const struct radv_device *device, struct radv_shad
gsvs_ring_size = align(gsvs_ring_size, alignment);
if (pdev->info.gfx_level <= GFX8)
gs_ring_info->esgs_ring_size = CLAMP(esgs_ring_size, min_esgs_ring_size, max_size);
out->esgs_ring_size = CLAMP(esgs_ring_size, min_esgs_ring_size, max_size);
gs_ring_info->gsvs_ring_size = MIN2(gsvs_ring_size, max_size);
}
static void
radv_get_legacy_gs_info(const struct radv_device *device, struct radv_shader_info *gs_info)
{
const struct radv_physical_device *pdev = radv_device_physical(device);
struct radv_legacy_gs_info *out = &gs_info->gs_ring_info;
const unsigned gs_num_invocations = MAX2(gs_info->gs.invocations, 1);
const bool uses_adjacency =
gs_info->gs.input_prim == MESA_PRIM_LINES_ADJACENCY || gs_info->gs.input_prim == MESA_PRIM_TRIANGLES_ADJACENCY;
/* All these are in dwords: */
/* We can't allow using the whole LDS, because GS waves compete with
* other shader stages for LDS space. */
const unsigned max_lds_size = 8 * 1024;
const unsigned esgs_itemsize = radv_compute_esgs_itemsize(device, gs_info->gs.num_linked_inputs) / 4;
unsigned esgs_lds_size;
/* All these are per subgroup: */
const unsigned max_out_prims = 32 * 1024;
const unsigned max_es_verts = 255;
const unsigned ideal_gs_prims = 64;
unsigned max_gs_prims, gs_prims;
unsigned min_es_verts, es_verts, worst_case_es_verts;
if (uses_adjacency || gs_num_invocations > 1)
max_gs_prims = 127 / gs_num_invocations;
else
max_gs_prims = 255;
/* MAX_PRIMS_PER_SUBGROUP = gs_prims * max_vert_out * gs_invocations.
* Make sure we don't go over the maximum value.
*/
if (gs_info->gs.vertices_out > 0) {
max_gs_prims = MIN2(max_gs_prims, max_out_prims / (gs_info->gs.vertices_out * gs_num_invocations));
}
assert(max_gs_prims > 0);
/* If the primitive has adjacency, halve the number of vertices
* that will be reused in multiple primitives.
*/
min_es_verts = gs_info->gs.vertices_in / (uses_adjacency ? 2 : 1);
gs_prims = MIN2(ideal_gs_prims, max_gs_prims);
worst_case_es_verts = MIN2(min_es_verts * gs_prims, max_es_verts);
/* Compute ESGS LDS size based on the worst case number of ES vertices
* needed to create the target number of GS prims per subgroup.
*/
esgs_lds_size = esgs_itemsize * worst_case_es_verts;
/* If total LDS usage is too big, refactor partitions based on ratio
* of ESGS item sizes.
*/
if (esgs_lds_size > max_lds_size) {
/* Our target GS Prims Per Subgroup was too large. Calculate
* the maximum number of GS Prims Per Subgroup that will fit
* into LDS, capped by the maximum that the hardware can support.
*/
gs_prims = MIN2((max_lds_size / (esgs_itemsize * min_es_verts)), max_gs_prims);
assert(gs_prims > 0);
worst_case_es_verts = MIN2(min_es_verts * gs_prims, max_es_verts);
esgs_lds_size = esgs_itemsize * worst_case_es_verts;
assert(esgs_lds_size <= max_lds_size);
}
/* Now calculate remaining ESGS information. */
if (esgs_lds_size)
es_verts = MIN2(esgs_lds_size / esgs_itemsize, max_es_verts);
else
es_verts = max_es_verts;
/* Vertices for adjacency primitives are not always reused, so restore
* it for ES_VERTS_PER_SUBGRP.
*/
min_es_verts = gs_info->gs.vertices_in;
/* For normal primitives, the VGT only checks if they are past the ES
* verts per subgroup after allocating a full GS primitive and if they
* are, kick off a new subgroup. But if those additional ES verts are
* unique (e.g. not reused) we need to make sure there is enough LDS
* space to account for those ES verts beyond ES_VERTS_PER_SUBGRP.
*/
es_verts -= min_es_verts - 1;
const uint32_t es_verts_per_subgroup = es_verts;
const uint32_t gs_prims_per_subgroup = gs_prims;
const uint32_t gs_inst_prims_in_subgroup = gs_prims * gs_num_invocations;
const uint32_t max_prims_per_subgroup = gs_inst_prims_in_subgroup * gs_info->gs.vertices_out;
const uint32_t lds_granularity = pdev->info.lds_encode_granularity;
const uint32_t total_lds_bytes = align(esgs_lds_size * 4, lds_granularity);
out->gs_inst_prims_in_subgroup = gs_inst_prims_in_subgroup;
out->es_verts_per_subgroup = es_verts_per_subgroup;
out->gs_prims_per_subgroup = gs_prims_per_subgroup;
out->esgs_itemsize = esgs_itemsize;
out->lds_size = total_lds_bytes / lds_granularity;
assert(max_prims_per_subgroup <= max_out_prims);
radv_init_legacy_gs_ring_info(device, gs_info);
out->gsvs_ring_size = MIN2(gsvs_ring_size, max_size);
}
static void