diff --git a/src/amd/vulkan/radv_shader_info.c b/src/amd/vulkan/radv_shader_info.c index f462280fb66..8c0f94c784b 100644 --- a/src/amd/vulkan/radv_shader_info.c +++ b/src/amd/vulkan/radv_shader_info.c @@ -688,10 +688,25 @@ gather_shader_info_tes(struct radv_device *device, const nir_shader *nir, struct } static void -radv_init_legacy_gs_ring_info(const struct radv_device *device, struct radv_shader_info *gs_info) +radv_get_legacy_gs_info(const struct radv_device *device, struct radv_shader_info *gs_info) { const struct radv_physical_device *pdev = radv_device_physical(device); - struct radv_legacy_gs_info *gs_ring_info = &gs_info->gs_ring_info; + struct radv_legacy_gs_info *out = &gs_info->gs_ring_info; + const unsigned esgs_vertex_stride = radv_compute_esgs_itemsize(device, gs_info->gs.num_linked_inputs); + ac_legacy_gs_subgroup_info info; + + ac_legacy_gs_compute_subgroup_info(gs_info->gs.input_prim, gs_info->gs.vertices_out, gs_info->gs.invocations, + esgs_vertex_stride, &info); + + const uint32_t lds_granularity = pdev->info.lds_encode_granularity; + const uint32_t total_lds_bytes = align(info.esgs_lds_size * 4, lds_granularity); + + out->gs_inst_prims_in_subgroup = info.gs_inst_prims_in_subgroup; + out->es_verts_per_subgroup = info.es_verts_per_subgroup; + out->gs_prims_per_subgroup = info.gs_prims_per_subgroup; + out->esgs_itemsize = esgs_vertex_stride / 4; + out->lds_size = total_lds_bytes / lds_granularity; + unsigned num_se = pdev->info.max_se; unsigned wave_size = 64; unsigned max_gs_waves = 32 * num_se; /* max 32 per SE on GCN */ @@ -704,9 +719,9 @@ radv_init_legacy_gs_ring_info(const struct radv_device *device, struct radv_shad unsigned max_size = ((unsigned)(63.999 * 1024 * 1024) & ~255) * num_se; /* Calculate the minimum size. */ - unsigned min_esgs_ring_size = align(gs_ring_info->esgs_itemsize * 4 * gs_vertex_reuse * wave_size, alignment); + unsigned min_esgs_ring_size = align(esgs_vertex_stride * gs_vertex_reuse * wave_size, alignment); /* These are recommended sizes, not minimum sizes. */ - unsigned esgs_ring_size = max_gs_waves * 2 * wave_size * gs_ring_info->esgs_itemsize * 4 * gs_info->gs.vertices_in; + unsigned esgs_ring_size = max_gs_waves * 2 * wave_size * esgs_vertex_stride * gs_info->gs.vertices_in; unsigned gsvs_ring_size = max_gs_waves * 2 * wave_size * gs_info->gs.max_gsvs_emit_size; min_esgs_ring_size = align(min_esgs_ring_size, alignment); @@ -714,110 +729,9 @@ radv_init_legacy_gs_ring_info(const struct radv_device *device, struct radv_shad gsvs_ring_size = align(gsvs_ring_size, alignment); if (pdev->info.gfx_level <= GFX8) - gs_ring_info->esgs_ring_size = CLAMP(esgs_ring_size, min_esgs_ring_size, max_size); + out->esgs_ring_size = CLAMP(esgs_ring_size, min_esgs_ring_size, max_size); - gs_ring_info->gsvs_ring_size = MIN2(gsvs_ring_size, max_size); -} - -static void -radv_get_legacy_gs_info(const struct radv_device *device, struct radv_shader_info *gs_info) -{ - const struct radv_physical_device *pdev = radv_device_physical(device); - struct radv_legacy_gs_info *out = &gs_info->gs_ring_info; - const unsigned gs_num_invocations = MAX2(gs_info->gs.invocations, 1); - const bool uses_adjacency = - gs_info->gs.input_prim == MESA_PRIM_LINES_ADJACENCY || gs_info->gs.input_prim == MESA_PRIM_TRIANGLES_ADJACENCY; - - /* All these are in dwords: */ - /* We can't allow using the whole LDS, because GS waves compete with - * other shader stages for LDS space. */ - const unsigned max_lds_size = 8 * 1024; - const unsigned esgs_itemsize = radv_compute_esgs_itemsize(device, gs_info->gs.num_linked_inputs) / 4; - unsigned esgs_lds_size; - - /* All these are per subgroup: */ - const unsigned max_out_prims = 32 * 1024; - const unsigned max_es_verts = 255; - const unsigned ideal_gs_prims = 64; - unsigned max_gs_prims, gs_prims; - unsigned min_es_verts, es_verts, worst_case_es_verts; - - if (uses_adjacency || gs_num_invocations > 1) - max_gs_prims = 127 / gs_num_invocations; - else - max_gs_prims = 255; - - /* MAX_PRIMS_PER_SUBGROUP = gs_prims * max_vert_out * gs_invocations. - * Make sure we don't go over the maximum value. - */ - if (gs_info->gs.vertices_out > 0) { - max_gs_prims = MIN2(max_gs_prims, max_out_prims / (gs_info->gs.vertices_out * gs_num_invocations)); - } - assert(max_gs_prims > 0); - - /* If the primitive has adjacency, halve the number of vertices - * that will be reused in multiple primitives. - */ - min_es_verts = gs_info->gs.vertices_in / (uses_adjacency ? 2 : 1); - - gs_prims = MIN2(ideal_gs_prims, max_gs_prims); - worst_case_es_verts = MIN2(min_es_verts * gs_prims, max_es_verts); - - /* Compute ESGS LDS size based on the worst case number of ES vertices - * needed to create the target number of GS prims per subgroup. - */ - esgs_lds_size = esgs_itemsize * worst_case_es_verts; - - /* If total LDS usage is too big, refactor partitions based on ratio - * of ESGS item sizes. - */ - if (esgs_lds_size > max_lds_size) { - /* Our target GS Prims Per Subgroup was too large. Calculate - * the maximum number of GS Prims Per Subgroup that will fit - * into LDS, capped by the maximum that the hardware can support. - */ - gs_prims = MIN2((max_lds_size / (esgs_itemsize * min_es_verts)), max_gs_prims); - assert(gs_prims > 0); - worst_case_es_verts = MIN2(min_es_verts * gs_prims, max_es_verts); - - esgs_lds_size = esgs_itemsize * worst_case_es_verts; - assert(esgs_lds_size <= max_lds_size); - } - - /* Now calculate remaining ESGS information. */ - if (esgs_lds_size) - es_verts = MIN2(esgs_lds_size / esgs_itemsize, max_es_verts); - else - es_verts = max_es_verts; - - /* Vertices for adjacency primitives are not always reused, so restore - * it for ES_VERTS_PER_SUBGRP. - */ - min_es_verts = gs_info->gs.vertices_in; - - /* For normal primitives, the VGT only checks if they are past the ES - * verts per subgroup after allocating a full GS primitive and if they - * are, kick off a new subgroup. But if those additional ES verts are - * unique (e.g. not reused) we need to make sure there is enough LDS - * space to account for those ES verts beyond ES_VERTS_PER_SUBGRP. - */ - es_verts -= min_es_verts - 1; - - const uint32_t es_verts_per_subgroup = es_verts; - const uint32_t gs_prims_per_subgroup = gs_prims; - const uint32_t gs_inst_prims_in_subgroup = gs_prims * gs_num_invocations; - const uint32_t max_prims_per_subgroup = gs_inst_prims_in_subgroup * gs_info->gs.vertices_out; - const uint32_t lds_granularity = pdev->info.lds_encode_granularity; - const uint32_t total_lds_bytes = align(esgs_lds_size * 4, lds_granularity); - - out->gs_inst_prims_in_subgroup = gs_inst_prims_in_subgroup; - out->es_verts_per_subgroup = es_verts_per_subgroup; - out->gs_prims_per_subgroup = gs_prims_per_subgroup; - out->esgs_itemsize = esgs_itemsize; - out->lds_size = total_lds_bytes / lds_granularity; - assert(max_prims_per_subgroup <= max_out_prims); - - radv_init_legacy_gs_ring_info(device, gs_info); + out->gsvs_ring_size = MIN2(gsvs_ring_size, max_size); } static void