vulkan/radix_sort: Stop force-unrolling loops
This is really bad for compile times in lavapipe. Compiling a shader with all loops unrolled can take 2 seconds. nir and llvm should be smart enough to unroll thos themselves if it's beneficial. Acked-by: Mike Blumenkrantz <michael.blumenkrantz@gmail.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/31426>
This commit is contained in:
committed by
Marge Bot
parent
c387699c7b
commit
d6244049a1
@@ -117,7 +117,7 @@ main()
|
||||
//
|
||||
const uint32_t row_idx = dwords_idx - info.dword_offset_min;
|
||||
|
||||
[[unroll]] for (uint32_t ii = 0; ii < RS_FILL_BLOCK_ROWS; ii++)
|
||||
for (uint32_t ii = 0; ii < RS_FILL_BLOCK_ROWS; ii++)
|
||||
{
|
||||
if (row_idx + (ii * RS_WORKGROUP_SIZE) < info.dword_offset_max_minus_min)
|
||||
{
|
||||
|
||||
@@ -182,7 +182,7 @@ rs_histogram_zero()
|
||||
{
|
||||
const uint32_t smem_offset = gl_SubgroupInvocationID;
|
||||
|
||||
[[unroll]] for (RS_SUBGROUP_UNIFORM uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_SUBGROUP_SIZE)
|
||||
for (RS_SUBGROUP_UNIFORM uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_SUBGROUP_SIZE)
|
||||
{
|
||||
smem.histogram[smem_offset + ii] = 0;
|
||||
}
|
||||
@@ -191,7 +191,7 @@ rs_histogram_zero()
|
||||
{
|
||||
const uint32_t smem_offset = gl_LocalInvocationID.x;
|
||||
|
||||
[[unroll]] for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_WORKGROUP_SIZE)
|
||||
for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_WORKGROUP_SIZE)
|
||||
{
|
||||
smem.histogram[smem_offset + ii] = 0;
|
||||
}
|
||||
@@ -225,7 +225,7 @@ rs_histogram_global_store(restrict buffer_rs_histograms rs_histograms)
|
||||
{
|
||||
const uint32_t smem_offset = gl_SubgroupInvocationID;
|
||||
|
||||
[[unroll]] for (RS_SUBGROUP_UNIFORM uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_SUBGROUP_SIZE)
|
||||
for (RS_SUBGROUP_UNIFORM uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_SUBGROUP_SIZE)
|
||||
{
|
||||
const uint32_t count = smem.histogram[smem_offset + ii];
|
||||
|
||||
@@ -236,7 +236,7 @@ rs_histogram_global_store(restrict buffer_rs_histograms rs_histograms)
|
||||
{
|
||||
const uint32_t smem_offset = gl_LocalInvocationID.x;
|
||||
|
||||
[[unroll]] for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_WORKGROUP_SIZE)
|
||||
for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_WORKGROUP_SIZE)
|
||||
{
|
||||
const uint32_t count = smem.histogram[smem_offset + ii];
|
||||
|
||||
@@ -323,7 +323,7 @@ main()
|
||||
//
|
||||
// Load keyvals
|
||||
//
|
||||
[[unroll]] for (RS_SUBGROUP_UNIFORM uint32_t ii = 0; ii < RS_HISTOGRAM_BLOCK_ROWS; ii++)
|
||||
for (RS_SUBGROUP_UNIFORM uint32_t ii = 0; ii < RS_HISTOGRAM_BLOCK_ROWS; ii++)
|
||||
{
|
||||
kv[ii] = rs_kv_in.extent[ii * RS_WORKGROUP_SIZE];
|
||||
}
|
||||
@@ -348,7 +348,7 @@ main()
|
||||
\
|
||||
rs_histogram_atomic_after_write(); \
|
||||
\
|
||||
[[unroll]] for (RS_SUBGROUP_UNIFORM uint32_t jj = 0; jj < RS_HISTOGRAM_BLOCK_ROWS; jj++) \
|
||||
for (RS_SUBGROUP_UNIFORM uint32_t jj = 0; jj < RS_HISTOGRAM_BLOCK_ROWS; jj++) \
|
||||
{ \
|
||||
const uint32_t digit = RS_KV_EXTRACT_DIGIT(kv[jj], pass_); \
|
||||
\
|
||||
@@ -382,7 +382,7 @@ main()
|
||||
push.devaddr_histograms, \
|
||||
rs_histogram_base); \
|
||||
\
|
||||
[[unroll]] for (RS_SUBGROUP_UNIFORM uint32_t jj = 0; jj < RS_HISTOGRAM_BLOCK_ROWS; jj++) \
|
||||
for (RS_SUBGROUP_UNIFORM uint32_t jj = 0; jj < RS_HISTOGRAM_BLOCK_ROWS; jj++) \
|
||||
{ \
|
||||
const uint32_t digit = RS_KV_EXTRACT_DIGIT(kv[jj], pass_); \
|
||||
\
|
||||
|
||||
@@ -98,7 +98,7 @@ rs_prefix(RS_PREFIX_ARGS)
|
||||
//
|
||||
// Downsweep 0
|
||||
//
|
||||
[[unroll]] for (uint32_t ii = 0; ii < RS_H_COMPONENTS; ii++)
|
||||
for (uint32_t ii = 0; ii < RS_H_COMPONENTS; ii++)
|
||||
{
|
||||
const uint32_t h = RS_PREFIX_LOAD(ii * RS_WORKGROUP_SIZE);
|
||||
|
||||
@@ -167,7 +167,7 @@ rs_prefix(RS_PREFIX_ARGS)
|
||||
//
|
||||
// Scan 0 and Downsweep 1
|
||||
//
|
||||
[[unroll]] for (uint32_t ii = 0; ii < RS_S0_PASSES; ii++) // 32 invocations
|
||||
for (uint32_t ii = 0; ii < RS_S0_PASSES; ii++) // 32 invocations
|
||||
{
|
||||
const uint32_t idx0 = (ii * RS_WORKGROUP_SIZE) + gl_LocalInvocationID.x;
|
||||
const uint32_t idx1 = (ii * RS_WORKGROUP_SUBGROUPS) + gl_SubgroupID;
|
||||
@@ -214,7 +214,7 @@ rs_prefix(RS_PREFIX_ARGS)
|
||||
}
|
||||
else
|
||||
{
|
||||
[[unroll]] for (uint32_t ii = 0; ii < RS_S0_PASSES; ii++) // 64 invocations
|
||||
for (uint32_t ii = 0; ii < RS_S0_PASSES; ii++) // 64 invocations
|
||||
{
|
||||
const uint32_t idx0 = (ii * RS_WORKGROUP_SIZE) + gl_LocalInvocationID.x;
|
||||
const uint32_t idx1 = (ii * RS_WORKGROUP_SUBGROUPS) + gl_SubgroupID;
|
||||
@@ -245,7 +245,7 @@ rs_prefix(RS_PREFIX_ARGS)
|
||||
}
|
||||
else
|
||||
{
|
||||
[[unroll]] for (uint32_t ii = 0; ii < RS_S1_PASSES; ii++) // 16 invocations
|
||||
for (uint32_t ii = 0; ii < RS_S1_PASSES; ii++) // 16 invocations
|
||||
{
|
||||
const uint32_t idx1 = (ii * RS_WORKGROUP_SIZE) + gl_LocalInvocationID.x;
|
||||
const uint32_t idx2 = (ii * RS_WORKGROUP_SUBGROUPS) + gl_SubgroupID;
|
||||
@@ -295,7 +295,7 @@ rs_prefix(RS_PREFIX_ARGS)
|
||||
}
|
||||
else if (RS_SUBGROUP_SIZE >= 16)
|
||||
{
|
||||
[[unroll]] for (uint32_t ii = 0; ii < RS_H_COMPONENTS; ii++)
|
||||
for (uint32_t ii = 0; ii < RS_H_COMPONENTS; ii++)
|
||||
{
|
||||
const uint32_t idx0 = (ii * RS_WORKGROUP_SUBGROUPS) + gl_SubgroupID;
|
||||
|
||||
@@ -311,7 +311,7 @@ rs_prefix(RS_PREFIX_ARGS)
|
||||
}
|
||||
else if (RS_SUBGROUP_SIZE == 8)
|
||||
{
|
||||
[[unroll]] for (uint32_t ii = 0; ii < RS_H_COMPONENTS; ii++)
|
||||
for (uint32_t ii = 0; ii < RS_H_COMPONENTS; ii++)
|
||||
{
|
||||
const uint32_t idx0 = (ii * RS_WORKGROUP_SUBGROUPS) + gl_SubgroupID;
|
||||
const uint32_t idx1 = idx0 / RS_SUBGROUP_SIZE;
|
||||
@@ -329,7 +329,7 @@ rs_prefix(RS_PREFIX_ARGS)
|
||||
}
|
||||
else if (RS_SUBGROUP_SIZE == 4)
|
||||
{
|
||||
[[unroll]] for (uint32_t ii = 0; ii < RS_H_COMPONENTS; ii++)
|
||||
for (uint32_t ii = 0; ii < RS_H_COMPONENTS; ii++)
|
||||
{
|
||||
const uint32_t idx0 = (ii * RS_WORKGROUP_SUBGROUPS) + gl_SubgroupID;
|
||||
const uint32_t idx1 = idx0 / RS_SUBGROUP_SIZE;
|
||||
|
||||
@@ -370,7 +370,7 @@ rs_histogram_zero()
|
||||
{
|
||||
const uint32_t smem_offset = RS_SMEM_HISTOGRAM_OFFSET + gl_SubgroupInvocationID;
|
||||
|
||||
[[unroll]] for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_SUBGROUP_SIZE)
|
||||
for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_SUBGROUP_SIZE)
|
||||
{
|
||||
smem.extent[smem_offset + ii] = 0;
|
||||
}
|
||||
@@ -379,7 +379,7 @@ rs_histogram_zero()
|
||||
{
|
||||
const uint32_t smem_offset = RS_SMEM_HISTOGRAM_OFFSET + gl_LocalInvocationID.x;
|
||||
|
||||
[[unroll]] for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_WORKGROUP_SIZE)
|
||||
for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_WORKGROUP_SIZE)
|
||||
{
|
||||
smem.extent[smem_offset + ii] = 0;
|
||||
}
|
||||
@@ -431,7 +431,7 @@ rs_histogram_rank(const RS_KEYVAL_TYPE kv[RS_SCATTER_BLOCK_ROWS],
|
||||
//----------------------------------------------------------------------
|
||||
#ifdef RS_SCATTER_ENABLE_NV_MATCH
|
||||
|
||||
[[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
|
||||
for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
|
||||
{
|
||||
//
|
||||
// NOTE(allanmac): Unfortunately there is no `match.any.sync.b8`
|
||||
@@ -453,7 +453,7 @@ rs_histogram_rank(const RS_KEYVAL_TYPE kv[RS_SCATTER_BLOCK_ROWS],
|
||||
//----------------------------------------------------------------------
|
||||
#elif !defined(RS_SCATTER_ENABLE_BROADCAST_MATCH)
|
||||
|
||||
[[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
|
||||
for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
|
||||
{
|
||||
const uint32_t digit = RS_KV_EXTRACT_DIGIT(kv[ii]);
|
||||
|
||||
@@ -467,7 +467,7 @@ rs_histogram_rank(const RS_KEYVAL_TYPE kv[RS_SCATTER_BLOCK_ROWS],
|
||||
match = ballot ^ mask;
|
||||
}
|
||||
|
||||
[[unroll]] for (int32_t bit = 1; bit < RS_RADIX_LOG2; bit++)
|
||||
for (int32_t bit = 1; bit < RS_RADIX_LOG2; bit++)
|
||||
{
|
||||
const bool is_one = RS_BIT_IS_ONE(digit, bit);
|
||||
const u32vec4 ballot = subgroupBallot(is_one);
|
||||
@@ -493,7 +493,7 @@ rs_histogram_rank(const RS_KEYVAL_TYPE kv[RS_SCATTER_BLOCK_ROWS],
|
||||
//
|
||||
if (RS_SUBGROUP_SIZE == 64)
|
||||
{
|
||||
[[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
|
||||
for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
|
||||
{
|
||||
const uint32_t digit = RS_KV_EXTRACT_DIGIT(kv[ii]);
|
||||
|
||||
@@ -505,7 +505,7 @@ rs_histogram_rank(const RS_KEYVAL_TYPE kv[RS_SCATTER_BLOCK_ROWS],
|
||||
}
|
||||
|
||||
// subgroup invocations 1-31
|
||||
[[unroll]] for (int32_t jj = 1; jj < 32; jj++)
|
||||
for (int32_t jj = 1; jj < 32; jj++)
|
||||
{
|
||||
match[0] |= (subgroupBroadcast(digit, jj) == digit) ? (1u << jj) : 0;
|
||||
}
|
||||
@@ -516,7 +516,7 @@ rs_histogram_rank(const RS_KEYVAL_TYPE kv[RS_SCATTER_BLOCK_ROWS],
|
||||
}
|
||||
|
||||
// subgroup invocations 33-63
|
||||
[[unroll]] for (int32_t jj = 1; jj < 32; jj++)
|
||||
for (int32_t jj = 1; jj < 32; jj++)
|
||||
{
|
||||
match[1] |= (subgroupBroadcast(digit, jj) == digit) ? (1u << jj) : 0;
|
||||
}
|
||||
@@ -526,7 +526,7 @@ rs_histogram_rank(const RS_KEYVAL_TYPE kv[RS_SCATTER_BLOCK_ROWS],
|
||||
bitCount(match.y & gl_SubgroupLeMask.y));
|
||||
}
|
||||
} else if (RS_SUBGROUP_SIZE <= 32) {
|
||||
[[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
|
||||
for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
|
||||
{
|
||||
const uint32_t digit = RS_KV_EXTRACT_DIGIT(kv[ii]);
|
||||
|
||||
@@ -534,7 +534,7 @@ rs_histogram_rank(const RS_KEYVAL_TYPE kv[RS_SCATTER_BLOCK_ROWS],
|
||||
uint32_t match = (subgroupBroadcast(digit, 0) == digit) ? (1u << 0) : 0;
|
||||
|
||||
// subgroup invocations 1-(RS_SUBGROUP_SIZE-1)
|
||||
[[unroll]] for (int32_t jj = 1; jj < RS_SUBGROUP_SIZE; jj++)
|
||||
for (int32_t jj = 1; jj < RS_SUBGROUP_SIZE; jj++)
|
||||
{
|
||||
match |= (subgroupBroadcast(digit, jj) == digit) ? (1u << jj) : 0;
|
||||
}
|
||||
@@ -553,7 +553,7 @@ rs_histogram_rank(const RS_KEYVAL_TYPE kv[RS_SCATTER_BLOCK_ROWS],
|
||||
{
|
||||
if (gl_SubgroupID == ii)
|
||||
{
|
||||
[[unroll]] for (uint32_t jj = 0; jj < RS_SCATTER_BLOCK_ROWS; jj++)
|
||||
for (uint32_t jj = 0; jj < RS_SCATTER_BLOCK_ROWS; jj++)
|
||||
{
|
||||
const uint32_t digit = RS_KV_EXTRACT_DIGIT(kv[jj]);
|
||||
const uint32_t prev = RS_HISTOGRAM_LOAD(digit);
|
||||
@@ -600,7 +600,7 @@ rs_first_prefix_store(restrict buffer_rs_partitions rs_partitions)
|
||||
const uint32_t smem_offset_h = RS_SMEM_HISTOGRAM_OFFSET + gl_SubgroupInvocationID;
|
||||
const uint32_t smem_offset_l = RS_SMEM_LOOKBACK_OFFSET + gl_SubgroupInvocationID;
|
||||
|
||||
[[unroll]] for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_SUBGROUP_SIZE)
|
||||
for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_SUBGROUP_SIZE)
|
||||
{
|
||||
const uint32_t exc = rs_histogram.extent[ii];
|
||||
const uint32_t red = smem.extent[smem_offset_h + ii];
|
||||
@@ -625,7 +625,7 @@ rs_first_prefix_store(restrict buffer_rs_partitions rs_partitions)
|
||||
const uint32_t smem_offset_h = RS_SMEM_HISTOGRAM_OFFSET + gl_LocalInvocationID.x;
|
||||
const uint32_t smem_offset_l = RS_SMEM_LOOKBACK_OFFSET + gl_LocalInvocationID.x;
|
||||
|
||||
[[unroll]] for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_WORKGROUP_SIZE)
|
||||
for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_WORKGROUP_SIZE)
|
||||
{
|
||||
const uint32_t exc = rs_histogram.extent[ii];
|
||||
const uint32_t red = smem.extent[smem_offset_h + ii];
|
||||
@@ -702,7 +702,7 @@ rs_reduction_store(restrict buffer_rs_partitions rs_partitions,
|
||||
//
|
||||
const uint32_t smem_offset = RS_SMEM_HISTOGRAM_OFFSET + gl_SubgroupInvocationID;
|
||||
|
||||
[[unroll]] for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_SUBGROUP_SIZE)
|
||||
for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_SUBGROUP_SIZE)
|
||||
{
|
||||
const uint32_t red = smem.extent[smem_offset + ii];
|
||||
|
||||
@@ -721,7 +721,7 @@ rs_reduction_store(restrict buffer_rs_partitions rs_partitions,
|
||||
//
|
||||
const uint32_t smem_offset = RS_SMEM_HISTOGRAM_OFFSET + gl_LocalInvocationID.x;
|
||||
|
||||
[[unroll]] for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_WORKGROUP_SIZE)
|
||||
for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_WORKGROUP_SIZE)
|
||||
{
|
||||
const uint32_t red = smem.extent[smem_offset + ii];
|
||||
|
||||
@@ -788,7 +788,7 @@ rs_lookback_store(restrict buffer_rs_partitions rs_partitions,
|
||||
//
|
||||
const uint32_t smem_offset = RS_SMEM_LOOKBACK_OFFSET + gl_SubgroupInvocationID;
|
||||
|
||||
[[unroll]] for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_SUBGROUP_SIZE)
|
||||
for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_SUBGROUP_SIZE)
|
||||
{
|
||||
uint32_t partition_base_prev = partition_base - RS_RADIX_SIZE;
|
||||
uint32_t exc = 0;
|
||||
@@ -844,7 +844,7 @@ rs_lookback_store(restrict buffer_rs_partitions rs_partitions,
|
||||
//
|
||||
const uint32_t smem_offset = RS_SMEM_LOOKBACK_OFFSET + gl_LocalInvocationID.x;
|
||||
|
||||
[[unroll]] for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_WORKGROUP_SIZE)
|
||||
for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_WORKGROUP_SIZE)
|
||||
{
|
||||
uint32_t partition_base_prev = partition_base - RS_RADIX_SIZE;
|
||||
uint32_t exc = 0;
|
||||
@@ -1019,7 +1019,7 @@ rs_lookback_skip_store(restrict buffer_rs_partitions rs_partitions,
|
||||
//
|
||||
const uint32_t smem_offset = RS_SMEM_LOOKBACK_OFFSET + gl_SubgroupInvocationID;
|
||||
|
||||
[[unroll]] for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_SUBGROUP_SIZE)
|
||||
for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_SUBGROUP_SIZE)
|
||||
{
|
||||
uint32_t partition_base_prev = partition_base - RS_RADIX_SIZE;
|
||||
uint32_t exc = 0;
|
||||
@@ -1064,7 +1064,7 @@ rs_lookback_skip_store(restrict buffer_rs_partitions rs_partitions,
|
||||
//
|
||||
const uint32_t smem_offset = RS_SMEM_LOOKBACK_OFFSET + gl_LocalInvocationID.x;
|
||||
|
||||
[[unroll]] for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_WORKGROUP_SIZE)
|
||||
for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_WORKGROUP_SIZE)
|
||||
{
|
||||
uint32_t partition_base_prev = partition_base - RS_RADIX_SIZE;
|
||||
uint32_t exc = 0;
|
||||
@@ -1197,7 +1197,7 @@ void
|
||||
rs_rank_to_local(const RS_KEYVAL_TYPE kv[RS_SCATTER_BLOCK_ROWS],
|
||||
inout uint32_t kr[RS_SCATTER_BLOCK_ROWS])
|
||||
{
|
||||
[[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
|
||||
for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
|
||||
{
|
||||
const uint32_t digit = RS_KV_EXTRACT_DIGIT(kv[ii]);
|
||||
const uint32_t exc = smem.extent[RS_SMEM_HISTOGRAM_OFFSET + digit];
|
||||
@@ -1225,7 +1225,7 @@ rs_rank_to_global(const RS_KEYVAL_TYPE kv[RS_SCATTER_BLOCK_ROWS],
|
||||
//
|
||||
readonly RS_BUFREF_DEFINE(buffer_rs_histogram, rs_histogram, push.devaddr_histograms);
|
||||
|
||||
[[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
|
||||
for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
|
||||
{
|
||||
const uint32_t digit = RS_KV_EXTRACT_DIGIT(kv[ii]);
|
||||
const uint32_t exc = rs_histogram.extent[digit];
|
||||
@@ -1242,12 +1242,12 @@ rs_reorder(inout RS_KEYVAL_TYPE kv[RS_SCATTER_BLOCK_ROWS], inout uint32_t kr[RS_
|
||||
{
|
||||
const uint32_t smem_base = RS_SMEM_REORDER_OFFSET + invocation_id();
|
||||
|
||||
[[unroll]] for (uint32_t ii = 0; ii < RS_KEYVAL_DWORDS; ii++)
|
||||
for (uint32_t ii = 0; ii < RS_KEYVAL_DWORDS; ii++)
|
||||
{
|
||||
//
|
||||
// Store keyval dword to sorted location
|
||||
//
|
||||
[[unroll]] for (uint32_t jj = 0; jj < RS_SCATTER_BLOCK_ROWS; jj++)
|
||||
for (uint32_t jj = 0; jj < RS_SCATTER_BLOCK_ROWS; jj++)
|
||||
{
|
||||
const uint32_t smem_idx = (RS_SMEM_REORDER_OFFSET - 1) + (kr[jj] >> 16);
|
||||
|
||||
@@ -1259,7 +1259,7 @@ rs_reorder(inout RS_KEYVAL_TYPE kv[RS_SCATTER_BLOCK_ROWS], inout uint32_t kr[RS_
|
||||
//
|
||||
// Load keyval dword from sorted location
|
||||
//
|
||||
[[unroll]] for (uint32_t jj = 0; jj < RS_SCATTER_BLOCK_ROWS; jj++)
|
||||
for (uint32_t jj = 0; jj < RS_SCATTER_BLOCK_ROWS; jj++)
|
||||
{
|
||||
RS_KV_DWORD(kv[jj], ii) = smem.extent[smem_base + jj * RS_WORKGROUP_SIZE];
|
||||
}
|
||||
@@ -1270,7 +1270,7 @@ rs_reorder(inout RS_KEYVAL_TYPE kv[RS_SCATTER_BLOCK_ROWS], inout uint32_t kr[RS_
|
||||
//
|
||||
// Store the digit-index to sorted location
|
||||
//
|
||||
[[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
|
||||
for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
|
||||
{
|
||||
const uint32_t smem_idx = (RS_SMEM_REORDER_OFFSET - 1) + (kr[ii] >> 16);
|
||||
|
||||
@@ -1282,7 +1282,7 @@ rs_reorder(inout RS_KEYVAL_TYPE kv[RS_SCATTER_BLOCK_ROWS], inout uint32_t kr[RS_
|
||||
//
|
||||
// Load kr[] from sorted location -- we only need the rank.
|
||||
//
|
||||
[[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
|
||||
for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
|
||||
{
|
||||
kr[ii] = smem.extent[smem_base + ii * RS_WORKGROUP_SIZE] & 0xFFFF;
|
||||
}
|
||||
@@ -1298,12 +1298,12 @@ rs_reorder_1(inout RS_KEYVAL_TYPE kv[RS_SCATTER_BLOCK_ROWS],
|
||||
{
|
||||
const uint32_t smem_base = RS_SMEM_REORDER_OFFSET + invocation_id();
|
||||
|
||||
[[unroll]] for (uint32_t ii = 0; ii < RS_KEYVAL_DWORDS; ii++)
|
||||
for (uint32_t ii = 0; ii < RS_KEYVAL_DWORDS; ii++)
|
||||
{
|
||||
//
|
||||
// Store keyval dword to sorted location
|
||||
//
|
||||
[[unroll]] for (uint32_t jj = 0; jj < RS_SCATTER_BLOCK_ROWS; jj++)
|
||||
for (uint32_t jj = 0; jj < RS_SCATTER_BLOCK_ROWS; jj++)
|
||||
{
|
||||
const uint32_t smem_idx = RS_SMEM_REORDER_OFFSET + kr[jj];
|
||||
|
||||
@@ -1315,7 +1315,7 @@ rs_reorder_1(inout RS_KEYVAL_TYPE kv[RS_SCATTER_BLOCK_ROWS],
|
||||
//
|
||||
// Load keyval dword from sorted location
|
||||
//
|
||||
[[unroll]] for (uint32_t jj = 0; jj < RS_SCATTER_BLOCK_ROWS; jj++)
|
||||
for (uint32_t jj = 0; jj < RS_SCATTER_BLOCK_ROWS; jj++)
|
||||
{
|
||||
RS_KV_DWORD(kv[jj], ii) = smem.extent[smem_base + jj * RS_WORKGROUP_SIZE];
|
||||
}
|
||||
@@ -1326,7 +1326,7 @@ rs_reorder_1(inout RS_KEYVAL_TYPE kv[RS_SCATTER_BLOCK_ROWS],
|
||||
//
|
||||
// Store the digit-index to sorted location
|
||||
//
|
||||
[[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
|
||||
for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
|
||||
{
|
||||
const uint32_t smem_idx = RS_SMEM_REORDER_OFFSET + kr[ii];
|
||||
|
||||
@@ -1338,7 +1338,7 @@ rs_reorder_1(inout RS_KEYVAL_TYPE kv[RS_SCATTER_BLOCK_ROWS],
|
||||
//
|
||||
// Load kr[] from sorted location -- we only need the rank.
|
||||
//
|
||||
[[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
|
||||
for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
|
||||
{
|
||||
kr[ii] = smem.extent[smem_base + ii * RS_WORKGROUP_SIZE];
|
||||
}
|
||||
@@ -1372,7 +1372,7 @@ rs_load(out RS_KEYVAL_TYPE kv[RS_SCATTER_BLOCK_ROWS])
|
||||
//
|
||||
// Load keyvals
|
||||
//
|
||||
[[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
|
||||
for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
|
||||
{
|
||||
kv[ii] = rs_kv_in.extent[ii * RS_SUBGROUP_SIZE];
|
||||
}
|
||||
@@ -1385,7 +1385,7 @@ void
|
||||
rs_local_to_global(const RS_KEYVAL_TYPE kv[RS_SCATTER_BLOCK_ROWS],
|
||||
inout uint32_t kr[RS_SCATTER_BLOCK_ROWS])
|
||||
{
|
||||
[[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
|
||||
for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
|
||||
{
|
||||
const uint32_t digit = RS_KV_EXTRACT_DIGIT(kv[ii]);
|
||||
const uint32_t exc = smem.extent[RS_SMEM_LOOKBACK_OFFSET + digit];
|
||||
@@ -1413,7 +1413,7 @@ rs_store(const RS_KEYVAL_TYPE kv[RS_SCATTER_BLOCK_ROWS], const uint32_t kr[RS_SC
|
||||
// FIXME(allanmac): Consider implementing an aligned writeout
|
||||
// strategy to avoid excess global memory transactions.
|
||||
//
|
||||
[[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
|
||||
for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
|
||||
{
|
||||
rs_kv_out.extent[kr[ii]] = kv[ii];
|
||||
}
|
||||
@@ -1513,7 +1513,7 @@ main()
|
||||
RS_DEVADDR_KEYVALS_OUT(push),
|
||||
gl_LocalInvocationID.x * 4);
|
||||
|
||||
[[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
|
||||
for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
|
||||
{
|
||||
rs_kv_out.extent[RS_GL_WORKGROUP_ID_X * RS_BLOCK_KEYVALS + ii * RS_WORKGROUP_SIZE] = kr[ii];
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user