From a55c03645062fba6bad9316f70cb9bd71c8a39d7 Mon Sep 17 00:00:00 2001 From: Dave Airlie Date: Mon, 4 Aug 2025 13:32:51 +1000 Subject: [PATCH] radv: add support for coopmat2 flexible dimensions This allows matrices whose dimensions are multiples of a base size. Reviewed-by: Georg Lehmann Part-of: --- src/amd/vulkan/radv_physical_device.c | 88 +++++++++++++++++++++++++-- src/amd/vulkan/radv_shader.c | 7 +++ 2 files changed, 90 insertions(+), 5 deletions(-) diff --git a/src/amd/vulkan/radv_physical_device.c b/src/amd/vulkan/radv_physical_device.c index 300ee571ab2..f0329f271ad 100644 --- a/src/amd/vulkan/radv_physical_device.c +++ b/src/amd/vulkan/radv_physical_device.c @@ -1371,6 +1371,7 @@ radv_physical_device_get_features(const struct radv_physical_device *pdev, struc /* VK_NV_cooperative_matrix2 */ .cooperativeMatrixConversions = true, + .cooperativeMatrixFlexibleDimensions = true, /* VK_KHR_video_encode_av1 */ .videoEncodeAV1 = true, @@ -2023,6 +2024,9 @@ radv_get_physical_device_properties(struct radv_physical_device *pdev) /* VK_KHR_maintenance9 */ .image2DViewOf3DSparse = pdev->info.gfx_level >= GFX8, .defaultVertexAttributeValue = VK_DEFAULT_VERTEX_ATTRIBUTE_VALUE_ZERO_ZERO_ZERO_ZERO_KHR, + + /* VK_NV_cooperative_matrix2 */ + .cooperativeMatrixFlexibleDimensionsMaxDimension = 1024, }; struct vk_properties *p = &pdev->vk.properties; @@ -3076,10 +3080,84 @@ radv_GetPhysicalDeviceCooperativeMatrixPropertiesKHR(VkPhysicalDevice physicalDe } VKAPI_ATTR VkResult VKAPI_CALL -radv_GetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV( - VkPhysicalDevice physicalDevice, uint32_t *pPropertyCount, - VkCooperativeMatrixFlexibleDimensionsPropertiesNV *pProperties) +radv_GetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(VkPhysicalDevice physicalDevice, uint32_t *pPropertyCount, + VkCooperativeMatrixFlexibleDimensionsPropertiesNV *pProperties) { - *pPropertyCount = 0; - return VK_SUCCESS; + VK_FROM_HANDLE(radv_physical_device, pdev, physicalDevice); + 
VK_OUTARRAY_MAKE_TYPED(VkCooperativeMatrixFlexibleDimensionsPropertiesNV, out, pProperties, pPropertyCount); + + if (pdev->info.gfx_level >= GFX12) { + for (unsigned e5m2_a = 0; e5m2_a < 2; e5m2_a++) { + for (unsigned e5m2_b = 0; e5m2_b < 2; e5m2_b++) { + VkComponentTypeKHR a_type = e5m2_a ? VK_COMPONENT_TYPE_FLOAT8_E5M2_EXT : VK_COMPONENT_TYPE_FLOAT8_E4M3_EXT; + VkComponentTypeKHR b_type = e5m2_b ? VK_COMPONENT_TYPE_FLOAT8_E5M2_EXT : VK_COMPONENT_TYPE_FLOAT8_E4M3_EXT; + + vk_outarray_append_typed(VkCooperativeMatrixFlexibleDimensionsPropertiesNV, &out, p) + { + *p = (struct VkCooperativeMatrixFlexibleDimensionsPropertiesNV){ + .sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_FLEXIBLE_DIMENSIONS_PROPERTIES_NV, + .MGranularity = 16, + .NGranularity = 16, + .KGranularity = 16, + .AType = a_type, + .BType = b_type, + .CType = VK_COMPONENT_TYPE_FLOAT32_KHR, + .ResultType = VK_COMPONENT_TYPE_FLOAT32_KHR, + .saturatingAccumulation = false, + .scope = VK_SCOPE_SUBGROUP_KHR}; + } + } + } + } + + for (unsigned bfloat = 0; bfloat < 2; bfloat++) { + for (unsigned fp32 = 0; fp32 < 2; fp32++) { + VkComponentTypeKHR ab_type = bfloat ? VK_COMPONENT_TYPE_BFLOAT16_KHR : VK_COMPONENT_TYPE_FLOAT16_KHR; + VkComponentTypeKHR cd_type = fp32 ? VK_COMPONENT_TYPE_FLOAT32_KHR : ab_type; + + if (pdev->info.gfx_level < GFX12 && bfloat) + continue; /* BF16 isn't working precisely on GFX11. 
*/ + + vk_outarray_append_typed(VkCooperativeMatrixFlexibleDimensionsPropertiesNV, &out, p) + { + *p = (struct VkCooperativeMatrixFlexibleDimensionsPropertiesNV){.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_FLEXIBLE_DIMENSIONS_PROPERTIES_NV, + .MGranularity = 16, + .NGranularity = 16, + .KGranularity = 16, + .AType = ab_type, + .BType = ab_type, + .CType = cd_type, + .ResultType = cd_type, + .saturatingAccumulation = false, + .scope = VK_SCOPE_SUBGROUP_KHR}; + } + } + } + + for (unsigned asigned = 0; asigned < 2; asigned++) { + for (unsigned bsigned = 0; bsigned < 2; bsigned++) { + for (unsigned csigned = 0; csigned < 2; csigned++) { + for (unsigned saturate = 0; saturate < 2; saturate++) { + if (!csigned && saturate) + continue; /* The HW only supports signed acc. */ + vk_outarray_append_typed(VkCooperativeMatrixFlexibleDimensionsPropertiesNV, &out, p) + { + *p = (struct VkCooperativeMatrixFlexibleDimensionsPropertiesNV){ + .sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_FLEXIBLE_DIMENSIONS_PROPERTIES_NV, + .MGranularity = 16, + .NGranularity = 16, + .KGranularity = 16, + .AType = asigned ? VK_COMPONENT_TYPE_SINT8_KHR : VK_COMPONENT_TYPE_UINT8_KHR, + .BType = bsigned ? VK_COMPONENT_TYPE_SINT8_KHR : VK_COMPONENT_TYPE_UINT8_KHR, + .CType = csigned ? VK_COMPONENT_TYPE_SINT32_KHR : VK_COMPONENT_TYPE_UINT32_KHR, + .ResultType = csigned ? 
VK_COMPONENT_TYPE_SINT32_KHR : VK_COMPONENT_TYPE_UINT32_KHR, + .saturatingAccumulation = saturate, + .scope = VK_SCOPE_SUBGROUP_KHR}; + } + } + } + } + } + + return vk_outarray_status(&out); } diff --git a/src/amd/vulkan/radv_shader.c b/src/amd/vulkan/radv_shader.c index 235aab920b5..5bec95031f1 100644 --- a/src/amd/vulkan/radv_shader.c +++ b/src/amd/vulkan/radv_shader.c @@ -457,6 +457,13 @@ radv_shader_spirv_to_nir(struct radv_device *device, const struct radv_shader_st */ NIR_PASS(_, nir, nir_lower_variable_initializers, ~0); + progress = false; + NIR_PASS(progress, nir, nir_lower_cooperative_matrix_flexible_dimensions, 16, 16, 16); + if (progress) { + NIR_PASS(_, nir, nir_opt_deref); + NIR_PASS(_, nir, nir_opt_dce); + NIR_PASS(_, nir, nir_remove_dead_variables, nir_var_function_temp | nir_var_shader_temp, NULL); + } NIR_PASS(_, nir, radv_nir_lower_cooperative_matrix, pdev->info.gfx_level, subgroup_size); /* Split member structs. We do this before lower_io_to_temporaries so that