diff --git a/src/amd/vulkan/radv_shader.c b/src/amd/vulkan/radv_shader.c index 46a14b7a2ac..074d8300579 100644 --- a/src/amd/vulkan/radv_shader.c +++ b/src/amd/vulkan/radv_shader.c @@ -1569,7 +1569,8 @@ radv_get_max_waves(struct radv_device *device, unsigned max_simd_waves; unsigned lds_per_wave = 0; - max_simd_waves = device->physical_device->rad_info.max_wave64_per_simd; + max_simd_waves = device->physical_device->rad_info.max_wave64_per_simd * + (64 / wave_size); if (stage == MESA_SHADER_FRAGMENT) { lds_per_wave = conf->lds_size * lds_increment + @@ -1582,7 +1583,7 @@ radv_get_max_waves(struct radv_device *device, DIV_ROUND_UP(max_workgroup_size, wave_size); } - if (conf->num_sgprs) { + if (conf->num_sgprs && chip_class < GFX10) { unsigned sgprs = align(conf->num_sgprs, chip_class >= GFX8 ? 16 : 8); max_simd_waves = MIN2(max_simd_waves, @@ -1591,12 +1592,12 @@ radv_get_max_waves(struct radv_device *device, } if (conf->num_vgprs) { + unsigned physical_vgprs = device->physical_device->rad_info.num_physical_wave64_vgprs_per_simd * + (64 / wave_size); unsigned vgprs = align(conf->num_vgprs, wave_size == 32 ? 8 : 4); if (chip_class >= GFX10_3) vgprs = align(vgprs, wave_size == 32 ? 16 : 8); - max_simd_waves = - MIN2(max_simd_waves, - device->physical_device->rad_info.num_physical_wave64_vgprs_per_simd / vgprs); + max_simd_waves = MIN2(max_simd_waves, physical_vgprs / vgprs); } unsigned simd_per_workgroup = device->physical_device->rad_info.num_simd_per_compute_unit; @@ -1607,7 +1608,7 @@ radv_get_max_waves(struct radv_device *device, if (lds_per_wave) max_simd_waves = MIN2(max_simd_waves, max_lds_per_simd / lds_per_wave); - return max_simd_waves; + return chip_class >= GFX10 ? max_simd_waves * (wave_size / 32) : max_simd_waves; } VkResult