diff --git a/src/amd/common/ac_nir_lower_ngg.c b/src/amd/common/ac_nir_lower_ngg.c index 7ba1978b5a8..1fb68bdb32d 100644 --- a/src/amd/common/ac_nir_lower_ngg.c +++ b/src/amd/common/ac_nir_lower_ngg.c @@ -107,6 +107,8 @@ typedef struct unsigned prim_vtx_indices_addr; unsigned numprims_lds_addr; unsigned wave_size; + unsigned api_workgroup_size; + unsigned hw_workgroup_size; struct { /* Bitmask of components used: 4 bits per slot, 1 bit per component. */ @@ -2457,12 +2459,12 @@ emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s) } static void -handle_smaller_ms_api_workgroup(nir_function_impl *impl, - nir_builder *b, - unsigned api_workgroup_size, - unsigned hw_workgroup_size, +handle_smaller_ms_api_workgroup(nir_builder *b, lower_ngg_ms_state *s) { + if (s->api_workgroup_size >= s->hw_workgroup_size) + return; + /* Handle barriers manually when the API workgroup * size is less than the HW workgroup size. * @@ -2478,19 +2480,19 @@ handle_smaller_ms_api_workgroup(nir_function_impl *impl, * all. In this case, we emit code that consumes every * barrier on the extra waves. */ - assert(hw_workgroup_size % s->wave_size == 0); - bool scan_barriers = ALIGN(api_workgroup_size, s->wave_size) < hw_workgroup_size; - bool can_shrink_barriers = api_workgroup_size <= s->wave_size; + assert(s->hw_workgroup_size % s->wave_size == 0); + bool scan_barriers = ALIGN(s->api_workgroup_size, s->wave_size) < s->hw_workgroup_size; + bool can_shrink_barriers = s->api_workgroup_size <= s->wave_size; bool need_additional_barriers = scan_barriers && !can_shrink_barriers; unsigned api_waves_in_flight_addr = s->numprims_lds_addr + 12; - unsigned num_api_waves = DIV_ROUND_UP(api_workgroup_size, s->wave_size); + unsigned num_api_waves = DIV_ROUND_UP(s->api_workgroup_size, s->wave_size); /* Scan the shader for workgroup barriers. */ if (scan_barriers) { bool has_any_workgroup_barriers = false; - nir_foreach_block(block, impl) { + nir_foreach_block(block, b->impl) { nir_foreach_instr_safe(instr, block) { if (instr->type != nir_instr_type_intrinsic) continue; @@ -2521,8 +2523,8 @@ handle_smaller_ms_api_workgroup(nir_function_impl *impl, /* Extract the full control flow of the shader. */ nir_cf_list extracted; - nir_cf_extract(&extracted, nir_before_cf_list(&impl->body), nir_after_cf_list(&impl->body)); - b->cursor = nir_before_cf_list(&impl->body); + nir_cf_extract(&extracted, nir_before_cf_list(&b->impl->body), nir_after_cf_list(&b->impl->body)); + b->cursor = nir_before_cf_list(&b->impl->body); /* Wrap the shader in an if to ensure that only the necessary amount of lanes run it. */ nir_ssa_def *invocation_index = nir_load_local_invocation_index(b); @@ -2542,7 +2544,7 @@ handle_smaller_ms_api_workgroup(nir_function_impl *impl, .memory_modes = nir_var_shader_out | nir_var_mem_shared); } - nir_ssa_def *has_api_ms_invocation = nir_ult(b, invocation_index, nir_imm_int(b, api_workgroup_size)); + nir_ssa_def *has_api_ms_invocation = nir_ult(b, invocation_index, nir_imm_int(b, s->api_workgroup_size)); nir_if *if_has_api_ms_invocation = nir_push_if(b, has_api_ms_invocation); { nir_cf_reinsert(&extracted, b->cursor); @@ -2638,19 +2640,6 @@ ac_nir_lower_ngg_ms(nir_shader *shader, shader->info.shared_size = prim_vtx_indices_addr + prim_vtx_indices_size; - lower_ngg_ms_state state = { - .wave_size = wave_size, - .per_vertex_outputs = per_vertex_outputs, - .per_primitive_outputs = per_primitive_outputs, - .num_per_vertex_outputs = num_per_vertex_outputs, - .num_per_primitive_outputs = num_per_primitive_outputs, - .vertices_per_prim = vertices_per_prim, - .vertex_attr_lds_addr = vertex_attr_lds_addr, - .prim_attr_lds_addr = prim_attr_lds_addr, - .prim_vtx_indices_addr = prim_vtx_indices_addr, - .numprims_lds_addr = numprims_lds_addr, - }; - /* The workgroup size that is specified by the API shader may be different * from the size of the workgroup that actually runs on the HW, due to the * limitations of NGG: max 0/1 vertex and 0/1 primitive per lane is allowed. @@ -2665,6 +2654,21 @@ ac_nir_lower_ngg_ms(nir_shader *shader, unsigned hw_workgroup_size = ALIGN(MAX3(api_workgroup_size, max_primitives, max_vertices), wave_size); + lower_ngg_ms_state state = { + .wave_size = wave_size, + .per_vertex_outputs = per_vertex_outputs, + .per_primitive_outputs = per_primitive_outputs, + .num_per_vertex_outputs = num_per_vertex_outputs, + .num_per_primitive_outputs = num_per_primitive_outputs, + .vertices_per_prim = vertices_per_prim, + .vertex_attr_lds_addr = vertex_attr_lds_addr, + .prim_attr_lds_addr = prim_attr_lds_addr, + .prim_vtx_indices_addr = prim_vtx_indices_addr, + .numprims_lds_addr = numprims_lds_addr, + .api_workgroup_size = api_workgroup_size, + .hw_workgroup_size = hw_workgroup_size, + }; + nir_function_impl *impl = nir_shader_get_entrypoint(shader); assert(impl); @@ -2673,9 +2677,7 @@ ac_nir_lower_ngg_ms(nir_shader *shader, nir_builder_init(b, impl); b->cursor = nir_before_cf_list(&impl->body); - if (api_workgroup_size < hw_workgroup_size) { - handle_smaller_ms_api_workgroup(impl, b, api_workgroup_size, hw_workgroup_size, &state); - } + handle_smaller_ms_api_workgroup(b, &state); lower_ms_intrinsics(shader, &state); emit_ms_finale(b, &state);