nir/lower_subgroups: add generic scan/reduce lowering
This is the lowering from NAK, fixed up for common code. The existing code is used for boolean scan/reduce. I make no guarantee that this works for subgroup sizes other than 32. Signed-off-by: Alyssa Rosenzweig <alyssa@rosenzweig.io> Reviewed-by: Daniel Schürmann <daniel@schuermann.dev> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/28993>
This commit is contained in:
committed by
Marge Bot
parent
8b070c36ec
commit
b9a0c8dc6d
@@ -5824,6 +5824,7 @@ typedef struct nir_lower_subgroups_options {
|
||||
bool lower_rotate_to_shuffle : 1;
|
||||
bool lower_ballot_bit_count_to_mbcnt_amd : 1;
|
||||
bool lower_inverse_ballot : 1;
|
||||
bool lower_reduce : 1;
|
||||
bool lower_boolean_reduce : 1;
|
||||
bool lower_boolean_shuffle : 1;
|
||||
} nir_lower_subgroups_options;
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
/*
|
||||
* Copyright © 2023 Collabora, Ltd.
|
||||
* Copyright © 2017 Intel Corporation
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a
|
||||
@@ -579,6 +580,167 @@ lower_boolean_reduce(nir_builder *b, nir_intrinsic_instr *intrin,
|
||||
return nir_inverse_ballot(b, 1, val);
|
||||
}
|
||||
|
||||
static nir_def *
|
||||
build_identity(nir_builder *b, unsigned bit_size, nir_op op)
|
||||
{
|
||||
nir_const_value ident_const = nir_alu_binop_identity(op, bit_size);
|
||||
return nir_build_imm(b, 1, bit_size, &ident_const);
|
||||
}
|
||||
|
||||
/* Implementation of scan/reduce that assumes a full subgroup */
|
||||
static nir_def *
|
||||
build_scan_full(nir_builder *b, nir_intrinsic_op op, nir_op red_op,
|
||||
nir_def *data, unsigned cluster_size)
|
||||
{
|
||||
switch (op) {
|
||||
case nir_intrinsic_exclusive_scan:
|
||||
case nir_intrinsic_inclusive_scan: {
|
||||
for (unsigned i = 1; i < cluster_size; i *= 2) {
|
||||
nir_def *idx = nir_load_subgroup_invocation(b);
|
||||
nir_def *has_buddy = nir_ige_imm(b, idx, i);
|
||||
|
||||
nir_def *buddy_data = nir_shuffle_up(b, data, nir_imm_int(b, i));
|
||||
nir_def *accum = nir_build_alu2(b, red_op, data, buddy_data);
|
||||
data = nir_bcsel(b, has_buddy, accum, data);
|
||||
}
|
||||
|
||||
if (op == nir_intrinsic_exclusive_scan) {
|
||||
/* For exclusive scans, we need to shift one more time and fill in the
|
||||
* bottom channel with identity.
|
||||
*/
|
||||
nir_def *idx = nir_load_subgroup_invocation(b);
|
||||
nir_def *has_buddy = nir_ige_imm(b, idx, 1);
|
||||
|
||||
nir_def *buddy_data = nir_shuffle_up(b, data, nir_imm_int(b, 1));
|
||||
nir_def *identity = build_identity(b, data->bit_size, red_op);
|
||||
data = nir_bcsel(b, has_buddy, buddy_data, identity);
|
||||
}
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
case nir_intrinsic_reduce: {
|
||||
for (unsigned i = 1; i < cluster_size; i *= 2) {
|
||||
nir_def *buddy_data = nir_shuffle_xor(b, data, nir_imm_int(b, i));
|
||||
data = nir_build_alu2(b, red_op, data, buddy_data);
|
||||
}
|
||||
return data;
|
||||
}
|
||||
|
||||
default:
|
||||
unreachable("Unsupported scan/reduce op");
|
||||
}
|
||||
}
|
||||
|
||||
/* Fully generic implementation of scan/reduce that takes a mask */
|
||||
static nir_def *
|
||||
build_scan_reduce(nir_builder *b, nir_intrinsic_op op, nir_op red_op,
|
||||
nir_def *data, nir_def *mask, unsigned max_mask_bits,
|
||||
unsigned subgroup_size)
|
||||
{
|
||||
nir_def *lt_mask = nir_load_subgroup_lt_mask(b, 1, subgroup_size);
|
||||
|
||||
/* Mask of all channels whose values we need to accumulate. Our own value
|
||||
* is already in accum, if inclusive, thanks to the initialization above.
|
||||
* We only need to consider lower indexed invocations.
|
||||
*/
|
||||
nir_def *remaining = nir_iand(b, mask, lt_mask);
|
||||
|
||||
for (unsigned i = 1; i < max_mask_bits; i *= 2) {
|
||||
/* At each step, our buddy channel is the first channel we have yet to
|
||||
* take into account in the accumulator.
|
||||
*/
|
||||
nir_def *has_buddy = nir_ine_imm(b, remaining, 0);
|
||||
nir_def *buddy = nir_ufind_msb(b, remaining);
|
||||
|
||||
/* Accumulate with our buddy channel, if any */
|
||||
nir_def *buddy_data = nir_shuffle(b, data, buddy);
|
||||
nir_def *accum = nir_build_alu2(b, red_op, data, buddy_data);
|
||||
data = nir_bcsel(b, has_buddy, accum, data);
|
||||
|
||||
/* We just took into account everything in our buddy's accumulator from
|
||||
* the previous step. The only things remaining are whatever channels
|
||||
* were remaining for our buddy.
|
||||
*/
|
||||
nir_def *buddy_remaining = nir_shuffle(b, remaining, buddy);
|
||||
remaining = nir_bcsel(b, has_buddy, buddy_remaining, nir_imm_int(b, 0));
|
||||
}
|
||||
|
||||
switch (op) {
|
||||
case nir_intrinsic_exclusive_scan: {
|
||||
/* For exclusive scans, we need to shift one more time and fill in the
|
||||
* bottom channel with identity.
|
||||
*
|
||||
* Some of this will get CSE'd with the first step but that's okay. The
|
||||
* code is cleaner this way.
|
||||
*/
|
||||
nir_def *lower = nir_iand(b, mask, lt_mask);
|
||||
nir_def *has_buddy = nir_ine_imm(b, lower, 0);
|
||||
nir_def *buddy = nir_ufind_msb(b, lower);
|
||||
|
||||
nir_def *buddy_data = nir_shuffle(b, data, buddy);
|
||||
nir_def *identity = build_identity(b, data->bit_size, red_op);
|
||||
return nir_bcsel(b, has_buddy, buddy_data, identity);
|
||||
}
|
||||
|
||||
case nir_intrinsic_inclusive_scan:
|
||||
return data;
|
||||
|
||||
case nir_intrinsic_reduce: {
|
||||
/* For reductions, we need to take the top value of the scan */
|
||||
nir_def *idx = nir_ufind_msb(b, mask);
|
||||
return nir_shuffle(b, data, idx);
|
||||
}
|
||||
|
||||
default:
|
||||
unreachable("Unsupported scan/reduce op");
|
||||
}
|
||||
}
|
||||
|
||||
static nir_def *
|
||||
lower_scan_reduce(nir_builder *b, nir_intrinsic_instr *intrin,
|
||||
unsigned subgroup_size)
|
||||
{
|
||||
const nir_op red_op = nir_intrinsic_reduction_op(intrin);
|
||||
|
||||
/* Grab the cluster size */
|
||||
unsigned cluster_size = subgroup_size;
|
||||
if (nir_intrinsic_has_cluster_size(intrin)) {
|
||||
cluster_size = nir_intrinsic_cluster_size(intrin);
|
||||
if (cluster_size == 0 || cluster_size > subgroup_size)
|
||||
cluster_size = subgroup_size;
|
||||
}
|
||||
|
||||
/* Check if all invocations are active. If so, we use the fast path. */
|
||||
nir_def *mask = nir_ballot(b, 1, subgroup_size, nir_imm_true(b));
|
||||
|
||||
nir_def *full, *partial;
|
||||
nir_push_if(b, nir_ieq_imm(b, mask, -1));
|
||||
{
|
||||
full = build_scan_full(b, intrin->intrinsic, red_op,
|
||||
intrin->src[0].ssa, cluster_size);
|
||||
}
|
||||
nir_push_else(b, NULL);
|
||||
{
|
||||
/* Mask according to the cluster size */
|
||||
if (cluster_size < subgroup_size) {
|
||||
nir_def *idx = nir_load_subgroup_invocation(b);
|
||||
nir_def *cluster = nir_iand_imm(b, idx, ~(uint64_t)(cluster_size - 1));
|
||||
|
||||
nir_def *cluster_mask = nir_imm_int(b, BITFIELD_MASK(cluster_size));
|
||||
cluster_mask = nir_ishl(b, cluster_mask, cluster);
|
||||
|
||||
mask = nir_iand(b, mask, cluster_mask);
|
||||
}
|
||||
|
||||
partial = build_scan_reduce(b, intrin->intrinsic, red_op,
|
||||
intrin->src[0].ssa, mask, cluster_size,
|
||||
subgroup_size);
|
||||
}
|
||||
nir_pop_if(b, NULL);
|
||||
return nir_if_phi(b, full, partial);
|
||||
}
|
||||
|
||||
static bool
|
||||
lower_subgroups_filter(const nir_instr *instr, const void *_options)
|
||||
{
|
||||
@@ -1048,16 +1210,22 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
|
||||
return intrin->src[0].ssa;
|
||||
if (options->lower_to_scalar && intrin->num_components > 1)
|
||||
return lower_subgroup_op_to_scalar(b, intrin);
|
||||
if (options->lower_boolean_reduce && intrin->def.bit_size == 1)
|
||||
if (intrin->def.bit_size == 1 &&
|
||||
(options->lower_boolean_reduce || options->lower_reduce))
|
||||
return lower_boolean_reduce(b, intrin, options);
|
||||
if (options->lower_reduce)
|
||||
return lower_scan_reduce(b, intrin, options->subgroup_size);
|
||||
return ret;
|
||||
}
|
||||
case nir_intrinsic_inclusive_scan:
|
||||
case nir_intrinsic_exclusive_scan:
|
||||
if (options->lower_to_scalar && intrin->num_components > 1)
|
||||
return lower_subgroup_op_to_scalar(b, intrin);
|
||||
if (options->lower_boolean_reduce && intrin->def.bit_size == 1)
|
||||
if (intrin->def.bit_size == 1 &&
|
||||
(options->lower_boolean_reduce || options->lower_reduce))
|
||||
return lower_boolean_reduce(b, intrin, options);
|
||||
if (options->lower_reduce)
|
||||
return lower_scan_reduce(b, intrin, options->subgroup_size);
|
||||
break;
|
||||
|
||||
case nir_intrinsic_rotate:
|
||||
|
||||
Reference in New Issue
Block a user