radv: add support for coopmat2 flexible dimensions

This allows matrices whose dimensions are multiples of a base size.

Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/36544>
This commit is contained in:
Dave Airlie
2025-08-04 13:32:51 +10:00
committed by Marge Bot
parent 7a96a928a2
commit a55c036450
2 changed files with 90 additions and 5 deletions

View File

@@ -1371,6 +1371,7 @@ radv_physical_device_get_features(const struct radv_physical_device *pdev, struc
/* VK_NV_cooperative_matrix2 */
.cooperativeMatrixConversions = true,
.cooperativeMatrixFlexibleDimensions = true,
/* VK_KHR_video_encode_av1 */
.videoEncodeAV1 = true,
@@ -2023,6 +2024,9 @@ radv_get_physical_device_properties(struct radv_physical_device *pdev)
/* VK_KHR_maintenance9 */
.image2DViewOf3DSparse = pdev->info.gfx_level >= GFX8,
.defaultVertexAttributeValue = VK_DEFAULT_VERTEX_ATTRIBUTE_VALUE_ZERO_ZERO_ZERO_ZERO_KHR,
/* VK_NV_cooperative_matrix2 */
.cooperativeMatrixFlexibleDimensionsMaxDimension = 1024,
};
struct vk_properties *p = &pdev->vk.properties;
@@ -3076,10 +3080,84 @@ radv_GetPhysicalDeviceCooperativeMatrixPropertiesKHR(VkPhysicalDevice physicalDe
}
/*
 * RADV entry point that enumerates the cooperative-matrix "flexible dimensions" configurations
 * for a physical device.
 *
 * physicalDevice  handle from which the radv_physical_device (pdev) is obtained.
 * pPropertyCount  count pointer handed to the out-array helper.
 * pProperties     property array handed to the out-array helper.
 *
 * Returns vk_outarray_status(&out). Presumably this follows the usual Vulkan two-call
 * count/fill convention via the vk_outarray helpers -- confirm against their definition.
 *
 * Every appended entry uses an M/N/K granularity of 16 and VK_SCOPE_SUBGROUP_KHR.
 *
 * NOTE(review): this listing appears to be a diff rendered without +/- markers. The three-line
 * signature and the "*pPropertyCount = 0; return VK_SUCCESS;" stub look like the removed lines,
 * and the remainder looks like the added implementation -- confirm against the repository.
 */
VKAPI_ATTR VkResult VKAPI_CALL
radv_GetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(
VkPhysicalDevice physicalDevice, uint32_t *pPropertyCount,
VkCooperativeMatrixFlexibleDimensionsPropertiesNV *pProperties)
radv_GetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(VkPhysicalDevice physicalDevice, uint32_t *pPropertyCount,
VkCooperativeMatrixFlexibleDimensionsPropertiesNV *pProperties)
{
*pPropertyCount = 0;
return VK_SUCCESS;
/* Obtain pdev from the handle and wrap the caller's count/array pair in a typed out-array. */
VK_FROM_HANDLE(radv_physical_device, pdev, physicalDevice);
VK_OUTARRAY_MAKE_TYPED(VkCooperativeMatrixFlexibleDimensionsPropertiesNV, out, pProperties, pPropertyCount);
/* FP8 inputs: all four E4M3/E5M2 combinations for A and B, with FP32 C/Result. GFX12 and newer only. */
if (pdev->info.gfx_level >= GFX12) {
for (unsigned e5m2_a = 0; e5m2_a < 2; e5m2_a++) {
for (unsigned e5m2_b = 0; e5m2_b < 2; e5m2_b++) {
VkComponentTypeKHR a_type = e5m2_a ? VK_COMPONENT_TYPE_FLOAT8_E5M2_EXT : VK_COMPONENT_TYPE_FLOAT8_E4M3_EXT;
VkComponentTypeKHR b_type = e5m2_b ? VK_COMPONENT_TYPE_FLOAT8_E5M2_EXT : VK_COMPONENT_TYPE_FLOAT8_E4M3_EXT;
vk_outarray_append_typed(VkCooperativeMatrixFlexibleDimensionsPropertiesNV, &out, p)
{
*p = (struct VkCooperativeMatrixFlexibleDimensionsPropertiesNV){
.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_FLEXIBLE_DIMENSIONS_PROPERTIES_NV,
.MGranularity = 16,
.NGranularity = 16,
.KGranularity = 16,
.AType = a_type,
.BType = b_type,
.CType = VK_COMPONENT_TYPE_FLOAT32_KHR,
.ResultType = VK_COMPONENT_TYPE_FLOAT32_KHR,
.saturatingAccumulation = false,
.scope = VK_SCOPE_SUBGROUP_KHR};
}
}
}
}
/* 16-bit float inputs: A and B share one type (FP16 or BF16); C/Result is either FP32 or that same type. */
for (unsigned bfloat = 0; bfloat < 2; bfloat++) {
for (unsigned fp32 = 0; fp32 < 2; fp32++) {
VkComponentTypeKHR ab_type = bfloat ? VK_COMPONENT_TYPE_BFLOAT16_KHR : VK_COMPONENT_TYPE_FLOAT16_KHR;
VkComponentTypeKHR cd_type = fp32 ? VK_COMPONENT_TYPE_FLOAT32_KHR : ab_type;
if (pdev->info.gfx_level < GFX12 && bfloat)
continue; /* BF16 isn't working precisely on GFX11. */
vk_outarray_append_typed(VkCooperativeMatrixFlexibleDimensionsPropertiesNV, &out, p)
{
*p = (struct VkCooperativeMatrixFlexibleDimensionsPropertiesNV){.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_FLEXIBLE_DIMENSIONS_PROPERTIES_NV,
.MGranularity = 16,
.NGranularity = 16,
.KGranularity = 16,
.AType = ab_type,
.BType = ab_type,
.CType = cd_type,
.ResultType = cd_type,
.saturatingAccumulation = false,
.scope = VK_SCOPE_SUBGROUP_KHR};
}
}
}
/* 8-bit integer inputs with 32-bit integer C/Result, iterating every signedness choice for A, B and C. */
for (unsigned asigned = 0; asigned < 2; asigned++) {
for (unsigned bsigned = 0; bsigned < 2; bsigned++) {
for (unsigned csigned = 0; csigned < 2; csigned++) {
for (unsigned saturate = 0; saturate < 2; saturate++) {
/* Saturating accumulation is only reported together with a signed C/Result type. */
if (!csigned && saturate)
continue; /* The HW only supports signed acc. */
vk_outarray_append_typed(VkCooperativeMatrixFlexibleDimensionsPropertiesNV, &out, p)
{
*p = (struct VkCooperativeMatrixFlexibleDimensionsPropertiesNV){
.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_FLEXIBLE_DIMENSIONS_PROPERTIES_NV,
.MGranularity = 16,
.NGranularity = 16,
.KGranularity = 16,
.AType = asigned ? VK_COMPONENT_TYPE_SINT8_KHR : VK_COMPONENT_TYPE_UINT8_KHR,
.BType = bsigned ? VK_COMPONENT_TYPE_SINT8_KHR : VK_COMPONENT_TYPE_UINT8_KHR,
.CType = csigned ? VK_COMPONENT_TYPE_SINT32_KHR : VK_COMPONENT_TYPE_UINT32_KHR,
.ResultType = csigned ? VK_COMPONENT_TYPE_SINT32_KHR : VK_COMPONENT_TYPE_UINT32_KHR,
.saturatingAccumulation = saturate,
.scope = VK_SCOPE_SUBGROUP_KHR};
}
}
}
}
}
/* Report the out-array's final status to the caller. */
return vk_outarray_status(&out);
}

View File

@@ -457,6 +457,13 @@ radv_shader_spirv_to_nir(struct radv_device *device, const struct radv_shader_st
*/
NIR_PASS(_, nir, nir_lower_variable_initializers, ~0);
progress = false;
NIR_PASS(progress, nir, nir_lower_cooperative_matrix_flexible_dimensions, 16, 16, 16);
if (progress) {
NIR_PASS(_, nir, nir_opt_deref);
NIR_PASS(_, nir, nir_opt_dce);
NIR_PASS(_, nir, nir_remove_dead_variables, nir_var_function_temp | nir_var_shader_temp, NULL);
}
NIR_PASS(_, nir, radv_nir_lower_cooperative_matrix, pdev->info.gfx_level, subgroup_size);
/* Split member structs. We do this before lower_io_to_temporaries so that