From 2da7b4bd0a097d948464844f1da2ba60e664e331 Mon Sep 17 00:00:00 2001
From: Georg Lehmann
Date: Sat, 6 Sep 2025 16:59:24 +0200
Subject: [PATCH] radv/nir/lower_cmat: add shuffle_xor_imm helper

Reviewed-by: Rhys Perry
Part-of:
---
 .../nir/radv_nir_lower_cooperative_matrix.c   | 43 +++++++++++++------
 1 file changed, 29 insertions(+), 14 deletions(-)

diff --git a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c
index ca5d146df1a..c4517502ff8 100644
--- a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c
+++ b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c
@@ -165,6 +165,27 @@ radv_get_base_row(nir_builder *b, struct glsl_cmat_description desc, const lower
    return base_row;
 }
 
+static nir_def *
+shuffle_xor_imm(nir_builder *b, nir_def *data, unsigned imm)
+{
+   assert(imm < 64);
+   if (imm == 32) {
+      /* v_permlane64_b32 */
+      return nir_rotate(b, data, nir_imm_int(b, 32), .cluster_size = 64);
+   } else if (imm < 32) {
+      /* All of these map to a single DPP/v_permlanex16 instruction */
+      return nir_masked_swizzle_amd(b, data, .swizzle_mask = 0x1f | (imm << 10), .fetch_inactive = 1);
+   } else {
+      /* There isn't a single instruction that can do this, but for cooperative matrix
+       * opcodes all invocations must be active.
+       * So we can split the operation into the two previous cases instead of
+       * having to use full width shuffle.
+       */
+      data = shuffle_xor_imm(b, data, 32);
+      return shuffle_xor_imm(b, data, imm & ~32);
+   }
+}
+
 static bool
 lower_cmat_length(nir_builder *b, nir_intrinsic_instr *intr, const lower_cmat_params *params)
 {
@@ -454,7 +475,7 @@ convert_use(nir_builder *b, nir_def *src, enum glsl_cmat_use src_use, enum glsl_
          nir_def *low_lanes = nir_inverse_ballot_imm(b, UINT32_MAX, 64);
          for (int i = 0; i < num_comps; i++) {
             nir_def *comp = components[i];
-            nir_def *half_swap = nir_rotate(b, comp, nir_imm_int(b, 32), .cluster_size = 64);
+            nir_def *half_swap = shuffle_xor_imm(b, comp, 32);
 
             tmp[i * 2] = nir_bcsel(b, low_lanes, comp, half_swap);
             tmp[i * 2 + 1] = nir_bcsel(b, low_lanes, half_swap, comp);
@@ -465,8 +486,7 @@ convert_use(nir_builder *b, nir_def *src, enum glsl_cmat_use src_use, enum glsl_
          nir_def *low_lanes = nir_inverse_ballot_imm(b, 0xffff0000ffffull, params->wave_size);
 
          for (int i = 0; i < num_comps; i++) {
-            unsigned swap16 = 0x1f | (0x10 << 10);
-            nir_def *half_swap = nir_masked_swizzle_amd(b, components[i], .swizzle_mask = swap16, .fetch_inactive = 1);
+            nir_def *half_swap = shuffle_xor_imm(b, components[i], 16);
             tmp[i * 2] = nir_bcsel(b, low_lanes, components[i], half_swap);
             tmp[i * 2 + 1] = nir_bcsel(b, low_lanes, half_swap, components[i]);
          }
@@ -490,7 +510,7 @@ convert_use(nir_builder *b, nir_def *src, enum glsl_cmat_use src_use, enum glsl_
          nir_def *second_perm = nir_ior_imm(b, first_perm, 0x02020202);
          for (int i = 0; i < num_comps; i++) {
             nir_def *comp = components[i];
-            nir_def *half_swap = nir_rotate(b, comp, nir_imm_int(b, 32), .cluster_size = 64);
+            nir_def *half_swap = shuffle_xor_imm(b, comp, 32);
 
             tmp[i * 2] = nir_byte_perm_amd(b, half_swap, comp, first_perm);
             tmp[i * 2 + 1] = nir_byte_perm_amd(b, half_swap, comp, second_perm);
@@ -504,8 +524,7 @@ convert_use(nir_builder *b, nir_def *src, enum glsl_cmat_use src_use, enum glsl_
          nir_def *second_perm = nir_ior_imm(b, first_perm, 0x02020202);
          for (int i = 0; i < num_comps; i++) {
             nir_def *comp = components[i];
-            unsigned swap16 = 0x1f | (0x10 << 10);
-            nir_def *half_swap = nir_masked_swizzle_amd(b, comp, .swizzle_mask = swap16, .fetch_inactive = 1);
+            nir_def *half_swap = shuffle_xor_imm(b, comp, 16);
             tmp[i * 2] = nir_byte_perm_amd(b, half_swap, comp, first_perm);
             tmp[i * 2 + 1] = nir_byte_perm_amd(b, half_swap, comp, second_perm);
          }
@@ -577,11 +596,8 @@ convert_use(nir_builder *b, nir_def *src, enum glsl_cmat_use src_use, enum glsl_
          nir_def *comp0 = components[pos0];
          nir_def *comp1 = components[pos1];
 
-         nir_def *comp0x =
-            nir_masked_swizzle_amd(b, comp0, .swizzle_mask = 0x1f | (x_mask << 10), .fetch_inactive = 1);
-         nir_def *comp1x =
-            nir_masked_swizzle_amd(b, comp1, .swizzle_mask = 0x1f | (x_mask << 10), .fetch_inactive = 1);
-
+         nir_def *comp0x = shuffle_xor_imm(b, comp0, x_mask);
+         nir_def *comp1x = shuffle_xor_imm(b, comp1, x_mask);
          components[pos0] = nir_bcsel(b, even, comp0, comp1x);
          components[pos1] = nir_bcsel(b, odd, comp1, comp0x);
       }
@@ -595,8 +611,7 @@ convert_use(nir_builder *b, nir_def *src, enum glsl_cmat_use src_use, enum glsl_
       nir_def *cond = nir_inverse_ballot_imm(b, 0xf0f0f0f00f0f0f0f, params->wave_size);
       for (unsigned i = 0; i < num_comps; i++) {
          nir_def *comp = components[i];
-         nir_def *compx = nir_rotate(b, comp, nir_imm_int(b, 32));
-         compx = nir_masked_swizzle_amd(b, compx, .swizzle_mask = 0x1f | (0x4 << 10), .fetch_inactive = 1);
+         nir_def *compx = shuffle_xor_imm(b, comp, 0x24);
          components[i] = nir_bcsel(b, cond, comp, compx);
       }
    }
@@ -604,7 +619,7 @@ convert_use(nir_builder *b, nir_def *src, enum glsl_cmat_use src_use, enum glsl_
       nir_def *cond = nir_inverse_ballot_imm(b, 0xff0000ffff0000ff, params->wave_size);
       for (unsigned i = 0; i < num_comps; i++) {
          nir_def *comp = components[i];
-         nir_def *compx = nir_masked_swizzle_amd(b, comp, .swizzle_mask = 0x1f | (0x18 << 10), .fetch_inactive = 1);
+         nir_def *compx = shuffle_xor_imm(b, comp, 0x18);
          components[i] = nir_bcsel(b, cond, comp, compx);
       }
    }
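As a quick sanity check of the lane arithmetic the helper relies on (not part of the patch, just a plain-C sketch with no NIR involved): within a 64-lane cluster, XOR-ing the lane index with 32 yields the same permutation as rotating by 32, and an XOR by any immediate of 32 or more decomposes into an XOR by 32 followed by an XOR by (imm & ~32), which is what the recursive else branch emits.

/* Standalone check: xor-by-32 inside a 64-lane cluster is the same lane
 * permutation as rotate-by-32, and xor by imm >= 32 splits into xor-by-32
 * followed by xor-by-(imm & ~32). */
#include <assert.h>
#include <stdio.h>

int main(void)
{
   for (unsigned lane = 0; lane < 64; lane++) {
      /* imm == 32: equivalent to nir_rotate with .cluster_size = 64 */
      assert((lane ^ 32) == ((lane + 32) % 64));

      for (unsigned imm = 32; imm < 64; imm++) {
         /* imm > 32: split into the two single-instruction cases */
         assert((lane ^ imm) == ((lane ^ 32) ^ (imm & ~32u)));
      }
   }
   printf("shuffle_xor_imm lane decomposition holds for all imm < 64\n");
   return 0;
}

Compiled with any C compiler, the asserts pass for every lane and immediate, which is why splitting the imm > 32 case into two cheaper shuffles is safe when all invocations are active.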