diff --git a/src/compiler/nir/nir_lower_subgroups.c b/src/compiler/nir/nir_lower_subgroups.c index eca441fcf39..1e2e3f0eebf 100644 --- a/src/compiler/nir/nir_lower_subgroups.c +++ b/src/compiler/nir/nir_lower_subgroups.c @@ -292,6 +292,15 @@ lower_subgroups_filter(const nir_instr *instr, const void *_options) return instr->type == nir_instr_type_intrinsic; } +/* Build a bit_size-wide subgroup mask: an all-ones immediate of width bit_size shifted right (nir_ushr) by (bit_size - nir_load_subgroup_size(b)). For a subgroup size in 1..bit_size this leaves exactly the low subgroup-size bits set. It takes the size from nir_load_subgroup_size(b) instead of the options->subgroup_size constant used by the removed group_mask computation. NOTE(review): a size of 0 or one above bit_size depends on NIR shift-count semantics -- confirm. NOTE(review): the options parameter is not read in this body. */ static nir_ssa_def * +build_subgroup_mask(nir_builder *b, unsigned bit_size, + const nir_lower_subgroups_options *options) +{ + return nir_ushr(b, nir_imm_intN_t(b, ~0ull, bit_size), + nir_isub(b, nir_imm_int(b, bit_size), + nir_load_subgroup_size(b))); +} + static nir_ssa_def * lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options) { @@ -343,9 +352,6 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options) const unsigned bit_size = MAX2(options->ballot_bit_size, intrin->dest.ssa.bit_size); - assert(options->subgroup_size <= 64); - uint64_t group_mask = ~0ull >> (64 - options->subgroup_size); - nir_ssa_def *count = nir_load_subgroup_invocation(b); nir_ssa_def *val; switch (intrin->intrinsic) { @@ -354,11 +360,11 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options) break; case nir_intrinsic_load_subgroup_ge_mask: val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~0ull, bit_size), count), - nir_imm_intN_t(b, group_mask, bit_size)); + build_subgroup_mask(b, bit_size, options)); break; case nir_intrinsic_load_subgroup_gt_mask: val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count), - nir_imm_intN_t(b, group_mask, bit_size)); + build_subgroup_mask(b, bit_size, options)); break; case nir_intrinsic_load_subgroup_le_mask: val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count));