nir/loop_analyze: check min compatibility with comparison

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Acked-by: Timothy Arceri <tarceri@itsqueeze.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/26225>
This commit is contained in:
Rhys Perry
2023-11-14 20:26:44 +00:00
committed by Marge Bot
parent b6c2a5d48d
commit 9591c36666
2 changed files with 182 additions and 4 deletions
+38 -4
View File
@@ -671,15 +671,49 @@ guess_loop_limit(loop_info_state *state, nir_const_value *limit_val,
return false;
}
static nir_op invert_comparison_if_needed(nir_op alu_op, bool invert);
/* Returns whether "limit_op(a, b) alu_op c" is equivalent to "(a alu_op c) || (b alu_op c)". */
static bool
try_find_limit_of_alu(nir_scalar limit, nir_const_value *limit_val,
nir_loop_terminator *terminator, loop_info_state *state)
is_min_compatible(nir_op limit_op, nir_op alu_op, bool limit_rhs, bool invert_cond)
{
switch (limit_op) {
case nir_op_imin:
case nir_op_fmin:
break;
default:
return false;
}
if (nir_op_infos[limit_op].input_types[0] != nir_op_infos[alu_op].input_types[0])
return false;
/* Comparisons we can split are:
* - min(a, b) < c
* - c >= min(a, b)
*/
switch (invert_comparison_if_needed(alu_op, invert_cond)) {
case nir_op_ilt:
case nir_op_flt:
return !limit_rhs;
case nir_op_ige:
case nir_op_fge:
return limit_rhs;
default:
return false;
}
}
static bool
try_find_limit_of_alu(nir_scalar limit, nir_const_value *limit_val, nir_op alu_op,
bool invert_cond, nir_loop_terminator *terminator,
loop_info_state *state)
{
if (!nir_scalar_is_alu(limit))
return false;
nir_op limit_op = nir_scalar_alu_op(limit);
if (limit_op == nir_op_imin || limit_op == nir_op_fmin) {
if (is_min_compatible(limit_op, alu_op, !terminator->induction_rhs, invert_cond)) {
for (unsigned i = 0; i < 2; i++) {
nir_scalar src = nir_scalar_chase_alu_src(limit, i);
if (nir_scalar_is_const(src)) {
@@ -1308,7 +1342,7 @@ find_trip_count(loop_info_state *state, unsigned execution_mode,
} else {
trip_count_known = false;
if (!try_find_limit_of_alu(limit, &limit_val, terminator, state)) {
if (!try_find_limit_of_alu(limit, &limit_val, alu_op, invert_cond, terminator, state)) {
/* Guess loop limit based on array access */
if (!guess_loop_limit(state, &limit_val, basic_ind)) {
terminator->exact_trip_count_unknown = true;
@@ -286,6 +286,28 @@ COMPARE_REVERSE(ishl)
INOT_COMPARE(ilt_rev)
INOT_COMPARE(ine)
#define CMP_MIN(cmp, min) \
static nir_def *nir_##cmp##_##min(nir_builder *b, nir_def *counter, nir_def *limit) \
{ \
nir_def *unk = nir_load_vertex_id(b); \
return nir_##cmp(b, counter, nir_##min(b, limit, unk)); \
}
#define CMP_MIN_REV(cmp, min) \
static nir_def *nir_##cmp##_##min##_rev(nir_builder *b, nir_def *counter, nir_def *limit) \
{ \
nir_def *unk = nir_load_vertex_id(b); \
return nir_##cmp(b, nir_##min(b, limit, unk), counter); \
}
CMP_MIN(ige, imin)
CMP_MIN_REV(ige, imin)
CMP_MIN(ige, fmin)
CMP_MIN(uge, imin)
CMP_MIN(ilt, imin)
CMP_MIN_REV(ilt, imin)
INOT_COMPARE(ilt_imin_rev)
#define KNOWN_COUNT_TEST(_init_value, _cond_value, _incr_value, cond, incr, count) \
TEST_F(nir_loop_analyze_test, incr ## _ ## cond ## _known_count_ ## count) \
{ \
@@ -320,6 +342,40 @@ INOT_COMPARE(ine)
} \
}
#define INEXACT_COUNT_TEST(_init_value, _cond_value, _incr_value, cond, incr, count) \
TEST_F(nir_loop_analyze_test, incr ## _ ## cond ## _inexact_count_ ## count) \
{ \
nir_loop *loop = \
loop_builder(&b, {.init_value = _init_value, \
.cond_value = _cond_value, \
.incr_value = _incr_value, \
.cond_instr = nir_ ## cond, \
.incr_instr = nir_ ## incr}); \
\
nir_validate_shader(b.shader, "input"); \
\
nir_loop_analyze_impl(b.impl, nir_var_all, false); \
\
ASSERT_NE((void *)0, loop->info); \
EXPECT_NE((void *)0, loop->info->limiting_terminator); \
EXPECT_EQ(count, loop->info->max_trip_count); \
EXPECT_FALSE(loop->info->exact_trip_count_known); \
\
EXPECT_EQ(2, loop->info->num_induction_vars); \
ASSERT_NE((void *)0, loop->info->induction_vars); \
\
const nir_loop_induction_variable *const ivars = \
loop->info->induction_vars; \
\
for (unsigned i = 0; i < loop->info->num_induction_vars; i++) { \
EXPECT_NE((void *)0, ivars[i].def); \
ASSERT_NE((void *)0, ivars[i].init_src); \
EXPECT_TRUE(nir_src_is_const(*ivars[i].init_src)); \
ASSERT_NE((void *)0, ivars[i].update_src); \
EXPECT_TRUE(nir_src_is_const(ivars[i].update_src->src)); \
} \
}
#define UNKNOWN_COUNT_TEST(_init_value, _cond_value, _incr_value, cond, incr) \
TEST_F(nir_loop_analyze_test, incr ## _ ## cond ## _unknown_count) \
{ \
@@ -567,6 +623,16 @@ KNOWN_COUNT_TEST_INVERT(0x00000000, 0x00000001, 0x00000006, ige, iadd, 5)
*/
KNOWN_COUNT_TEST(0x0000000a, 0x00000005, 0xffffffff, inot_ilt_rev, iadd, 5)
/* int i = 10;
* while (true) {
* if (!(imin(vertex_id, 5) < i))
* break;
*
* i += -1;
* }
*/
UNKNOWN_COUNT_TEST(0x0000000a, 0x00000005, 0xffffffff, inot_ilt_imin_rev, iadd)
/* uint i = 0;
* while (true) {
* if (i != 0)
@@ -1459,3 +1525,81 @@ KNOWN_COUNT_TEST_INVERT(0x0000007f, 0x00000003, 0x00000001, ilt, imul, 16)
* }
*/
KNOWN_COUNT_TEST_INVERT(0xffff7fff, 0x0000000f, 0x34cce9b0, ige, imul, 4)
/* int i = 0;
* while (true) {
* if (i >= imin(vertex_id, 4))
* break;
*
* i++;
* }
*/
INEXACT_COUNT_TEST(0x00000000, 0x00000004, 0x00000001, ige_imin, iadd, 4)
/* This fmin is the wrong type to be useful.
*
* int i = 0;
* while (true) {
* if (i >= fmin(vertex_id, 4))
* break;
*
* i++;
* }
*/
UNKNOWN_COUNT_TEST(0x00000000, 0x00000004, 0x00000001, ige_fmin, iadd)
/* The comparison is unsigned, so this isn't safe if vertex_id is negative.
*
* uint i = 0;
* while (true) {
* if (i >= imin(vertex_id, 4))
* break;
*
* i++;
* }
*/
UNKNOWN_COUNT_TEST(0x00000000, 0x00000004, 0x00000001, uge_imin, iadd)
/* int i = 8;
* while (true) {
* if (4 >= i)
* break;
*
* i += -1;
* }
*/
KNOWN_COUNT_TEST(0x00000008, 0x00000004, 0xffffffff, ige_rev, iadd, 4)
/* int i = 8;
* while (true) {
* if (i < 4)
* break;
*
* i += -1;
* }
*/
KNOWN_COUNT_TEST(0x00000008, 0x00000004, 0xffffffff, ilt, iadd, 5)
/* This imin can increase the iteration count, not limit it.
*
* int i = 8;
* while (true) {
* if (imin(vertex_id, 4) >= i)
* break;
*
* i += -1;
* }
*/
UNKNOWN_COUNT_TEST(0x00000008, 0x00000004, 0xffffffff, ige_imin_rev, iadd)
/* This imin can increase the iteration count, not limit it.
*
* int i = 8;
* while (true) {
* if (i < imin(vertex_id, 4))
* break;
*
* i += -1;
* }
*/
UNKNOWN_COUNT_TEST(0x00000008, 0x00000004, 0xffffffff, ilt_imin, iadd)