nir/range_analysis: use perform_analysis() in nir_analyze_range()

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/21381>
This commit is contained in:
Rhys Perry
2023-02-14 21:38:41 +00:00
committed by Marge Bot
parent 2b03db39b3
commit e99ba0b6d3
+170 -104
View File
@@ -126,18 +126,15 @@ is_not_zero(enum ssa_ranges r)
return r == gt_zero || r == lt_zero || r == ne_zero;
}
static void *
static uint32_t
pack_data(const struct ssa_result_range r)
{
return (void *)(uintptr_t)(r.range | r.is_integral << 8 | r.is_finite << 9 |
r.is_a_number << 10);
return r.range | r.is_integral << 8 | r.is_finite << 9 | r.is_a_number << 10;
}
static struct ssa_result_range
unpack_data(const void *p)
unpack_data(uint32_t v)
{
const uintptr_t v = (uintptr_t) p;
return (struct ssa_result_range){
.range = v & 0xff,
.is_integral = (v & 0x00100) != 0,
@@ -146,31 +143,6 @@ unpack_data(const void *p)
};
}
static void *
pack_key(const struct nir_alu_instr *instr, nir_alu_type type)
{
uintptr_t type_encoding;
uintptr_t ptr = (uintptr_t) instr;
/* The low 2 bits have to be zero or this whole scheme falls apart. */
assert((ptr & 0x3) == 0);
/* NIR is typeless in the sense that sequences of bits have whatever
* meaning is attached to them by the instruction that consumes them.
* However, the number of bits must match between producer and consumer.
* As a result, the number of bits does not need to be encoded here.
*/
switch (nir_alu_type_get_base_type(type)) {
case nir_type_int: type_encoding = 0; break;
case nir_type_uint: type_encoding = 1; break;
case nir_type_bool: type_encoding = 2; break;
case nir_type_float: type_encoding = 3; break;
default: unreachable("Invalid base type.");
}
return (void *)(ptr | type_encoding);
}
static nir_alu_type
nir_alu_src_type(const nir_alu_instr *instr, unsigned src)
{
@@ -319,7 +291,7 @@ analyze_constant(const struct nir_alu_instr *instr, unsigned src,
}
/**
* Short-hand name for use in the tables in analyze_expression. If this name
* Short-hand name for use in the tables in process_fp_query. If this name
* becomes a problem on some compiler, we can change it to _.
*/
#define _______ unknown
@@ -502,6 +474,53 @@ union_ranges(enum ssa_ranges a, enum ssa_ranges b)
#define ASSERT_UNION_OF_DISJOINT_MATCHES_UNKNOWN_2_SOURCE(t)
#endif /* !defined(NDEBUG) */
/**
 * Query record used by the floating-point range analysis.
 *
 * Instances are created by push_fp_query() and consumed by get_fp_key() and
 * process_fp_query().
 */
struct fp_query {
/* Must stay the first member: get_fp_key() and process_fp_query() cast a
 * struct analysis_query pointer to a struct fp_query pointer.
 */
struct analysis_query head;
/* ALU instruction whose source operand is being analyzed. */
const nir_alu_instr *instr;
/* Index into instr->src[] of the operand being analyzed. */
unsigned src;
/* Type the operand is interpreted as. push_fp_query() fills this from
 * nir_alu_src_type() when the caller passes nir_type_invalid.
 */
nir_alu_type use_type;
};
/**
 * Push a floating-point range query for operand \p src of \p alu.
 *
 * \param state  Analysis state that receives the new query.
 * \param alu    ALU instruction whose operand should be analyzed.
 * \param src    Index of the operand within \p alu.
 * \param type   Type the operand is used as, or nir_type_invalid to use the
 *               type reported by nir_alu_src_type(alu, src).
 */
static void
push_fp_query(struct analysis_state *state, const nir_alu_instr *alu, unsigned src, nir_alu_type type)
{
   struct fp_query *entry = push_analysis_query(state, sizeof(struct fp_query));

   entry->instr = alu;
   entry->src = src;

   /* nir_type_invalid acts as "derive the type from the instruction". */
   if (type == nir_type_invalid)
      entry->use_type = nir_alu_src_type(alu, src);
   else
      entry->use_type = type;
}
static uintptr_t
get_fp_key(struct analysis_query *q)
{
struct fp_query *fp_q = (struct fp_query *)q;
const nir_src *src = &fp_q->instr->src[fp_q->src].src;
if (!src->is_ssa || src->ssa->parent_instr->type != nir_instr_type_alu)
return 0;
uintptr_t type_encoding;
uintptr_t ptr = (uintptr_t)nir_instr_as_alu(src->ssa->parent_instr);
/* The low 2 bits have to be zero or this whole scheme falls apart. */
assert((ptr & 0x3) == 0);
/* NIR is typeless in the sense that sequences of bits have whatever
* meaning is attached to them by the instruction that consumes them.
* However, the number of bits must match between producer and consumer.
* As a result, the number of bits does not need to be encoded here.
*/
switch (nir_alu_type_get_base_type(fp_q->use_type)) {
case nir_type_int: type_encoding = 0; break;
case nir_type_uint: type_encoding = 1; break;
case nir_type_bool: type_encoding = 2; break;
case nir_type_float: type_encoding = 3; break;
default: unreachable("Invalid base type.");
}
return ptr | type_encoding;
}
/**
* Analyze an expression to determine the range of its result
*
@@ -511,21 +530,32 @@ union_ranges(enum ssa_ranges a, enum ssa_ranges b)
* This function implements this grammar as a recursive-descent parser. Some
* (but not all) of the grammar is listed in-line in the function.
*/
static struct ssa_result_range
analyze_expression(const nir_alu_instr *instr, unsigned src,
struct hash_table *ht, nir_alu_type use_type)
static void
process_fp_query(struct analysis_state *state, struct analysis_query *aq, uint32_t *result,
const uint32_t *src_res)
{
/* Ensure that the _Pragma("GCC unroll 7") above are correct. */
STATIC_ASSERT(last_range + 1 == 7);
if (!instr->src[src].src.is_ssa)
return (struct ssa_result_range){unknown, false, false, false};
struct fp_query q = *(struct fp_query *)aq;
const nir_alu_instr *instr = q.instr;
unsigned src = q.src;
nir_alu_type use_type = q.use_type;
if (nir_src_is_const(instr->src[src].src))
return analyze_constant(instr, src, use_type);
if (!instr->src[src].src.is_ssa) {
*result = pack_data((struct ssa_result_range){unknown, false, false, false});
return;
}
if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu)
return (struct ssa_result_range){unknown, false, false, false};
if (nir_src_is_const(instr->src[src].src)) {
*result = pack_data(analyze_constant(instr, src, use_type));
return;
}
if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu) {
*result = pack_data((struct ssa_result_range){unknown, false, false, false});
return;
}
const struct nir_alu_instr *const alu =
nir_instr_as_alu(instr->src[src].src.ssa->parent_instr);
@@ -544,13 +574,62 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
if (use_base_type != src_base_type &&
(use_base_type == nir_type_float ||
src_base_type == nir_type_float)) {
return (struct ssa_result_range){unknown, false, false, false};
*result = pack_data((struct ssa_result_range){unknown, false, false, false});
return;
}
}
struct hash_entry *he = _mesa_hash_table_search(ht, pack_key(alu, use_type));
if (he != NULL)
return unpack_data(he->data);
if (!aq->pushed_queries) {
switch (alu->op) {
case nir_op_bcsel:
push_fp_query(state, alu, 1, use_type);
push_fp_query(state, alu, 2, use_type);
return;
case nir_op_mov:
push_fp_query(state, alu, 0, use_type);
return;
case nir_op_i2f32:
case nir_op_u2f32:
case nir_op_fabs:
case nir_op_fexp2:
case nir_op_frcp:
case nir_op_fneg:
case nir_op_fsat:
case nir_op_fsign:
case nir_op_ffloor:
case nir_op_fceil:
case nir_op_ftrunc:
case nir_op_fdot2:
case nir_op_fdot3:
case nir_op_fdot4:
case nir_op_fdot8:
case nir_op_fdot16:
case nir_op_fdot2_replicated:
case nir_op_fdot3_replicated:
case nir_op_fdot4_replicated:
case nir_op_fdot8_replicated:
case nir_op_fdot16_replicated:
push_fp_query(state, alu, 0, nir_type_invalid);
return;
case nir_op_fadd:
case nir_op_fmax:
case nir_op_fmin:
case nir_op_fmul:
case nir_op_fmulz:
case nir_op_fpow:
push_fp_query(state, alu, 0, nir_type_invalid);
push_fp_query(state, alu, 1, nir_type_invalid);
return;
case nir_op_ffma:
case nir_op_flrp:
push_fp_query(state, alu, 0, nir_type_invalid);
push_fp_query(state, alu, 1, nir_type_invalid);
push_fp_query(state, alu, 2, nir_type_invalid);
return;
default:
break;
}
}
struct ssa_result_range r = {unknown, false, false, false};
@@ -666,10 +745,8 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
break;
case nir_op_bcsel: {
const struct ssa_result_range left =
analyze_expression(alu, 1, ht, use_type);
const struct ssa_result_range right =
analyze_expression(alu, 2, ht, use_type);
const struct ssa_result_range left = unpack_data(src_res[0]);
const struct ssa_result_range right = unpack_data(src_res[1]);
r.is_integral = left.is_integral && right.is_integral;
@@ -694,7 +771,7 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
case nir_op_i2f32:
case nir_op_u2f32:
r = analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
r = unpack_data(src_res[0]);
r.is_integral = true;
r.is_a_number = true;
@@ -706,7 +783,7 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
break;
case nir_op_fabs:
r = analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
r = unpack_data(src_res[0]);
switch (r.range) {
case unknown:
@@ -728,10 +805,8 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
break;
case nir_op_fadd: {
const struct ssa_result_range left =
analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
const struct ssa_result_range right =
analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1));
const struct ssa_result_range left = unpack_data(src_res[0]);
const struct ssa_result_range right = unpack_data(src_res[1]);
r.is_integral = left.is_integral && right.is_integral;
r.range = fadd_table[left.range][right.range];
@@ -755,7 +830,7 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
ge_zero, ge_zero, ge_zero, gt_zero, gt_zero, ge_zero, gt_zero
};
r = analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
r = unpack_data(src_res[0]);
ASSERT_UNION_OF_DISJOINT_MATCHES_UNKNOWN_1_SOURCE(table);
ASSERT_UNION_OF_EQ_AND_STRICT_INEQ_MATCHES_NONSTRICT_1_SOURCE(table);
@@ -770,10 +845,8 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
}
case nir_op_fmax: {
const struct ssa_result_range left =
analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
const struct ssa_result_range right =
analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1));
const struct ssa_result_range left = unpack_data(src_res[0]);
const struct ssa_result_range right = unpack_data(src_res[1]);
r.is_integral = left.is_integral && right.is_integral;
@@ -856,10 +929,8 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
}
case nir_op_fmin: {
const struct ssa_result_range left =
analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
const struct ssa_result_range right =
analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1));
const struct ssa_result_range left = unpack_data(src_res[0]);
const struct ssa_result_range right = unpack_data(src_res[1]);
r.is_integral = left.is_integral && right.is_integral;
@@ -943,10 +1014,8 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
case nir_op_fmul:
case nir_op_fmulz: {
const struct ssa_result_range left =
analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
const struct ssa_result_range right =
analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1));
const struct ssa_result_range left = unpack_data(src_res[0]);
const struct ssa_result_range right = unpack_data(src_res[1]);
r.is_integral = left.is_integral && right.is_integral;
@@ -981,7 +1050,7 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
case nir_op_frcp:
r = (struct ssa_result_range){
analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)).range,
unpack_data(src_res[0]).range,
false,
false, /* Various cases can result in NaN, so assume the worst. */
false /* " " " " " " " " " " */
@@ -989,18 +1058,16 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
break;
case nir_op_mov:
r = analyze_expression(alu, 0, ht, use_type);
r = unpack_data(src_res[0]);
break;
case nir_op_fneg:
r = analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
r = unpack_data(src_res[0]);
r.range = fneg_table[r.range];
break;
case nir_op_fsat: {
const struct ssa_result_range left =
analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
const struct ssa_result_range left = unpack_data(src_res[0]);
/* fsat(NaN) = 0. */
r.is_a_number = true;
@@ -1035,7 +1102,7 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
case nir_op_fsign:
r = (struct ssa_result_range){
analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)).range,
unpack_data(src_res[0]).range,
true,
true, /* fsign is -1, 0, or 1, even for NaN, so it must be a number. */
true /* fsign is -1, 0, or 1, even for NaN, so it must be finite. */
@@ -1048,8 +1115,7 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
break;
case nir_op_ffloor: {
const struct ssa_result_range left =
analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
const struct ssa_result_range left = unpack_data(src_res[0]);
r.is_integral = true;
@@ -1070,8 +1136,7 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
}
case nir_op_fceil: {
const struct ssa_result_range left =
analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
const struct ssa_result_range left = unpack_data(src_res[0]);
r.is_integral = true;
@@ -1092,8 +1157,7 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
}
case nir_op_ftrunc: {
const struct ssa_result_range left =
analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
const struct ssa_result_range left = unpack_data(src_res[0]);
r.is_integral = true;
@@ -1139,8 +1203,7 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
case nir_op_fdot4_replicated:
case nir_op_fdot8_replicated:
case nir_op_fdot16_replicated: {
const struct ssa_result_range left =
analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
const struct ssa_result_range left = unpack_data(src_res[0]);
/* If the two sources are the same SSA value, then the result is either
* NaN or some number >= 0. If one source is the negation of the other,
@@ -1211,10 +1274,8 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
/* eq_zero */ { ge_zero, gt_zero, gt_zero, eq_zero, ge_zero, ge_zero, gt_zero },
};
const struct ssa_result_range left =
analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
const struct ssa_result_range right =
analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1));
const struct ssa_result_range left = unpack_data(src_res[0]);
const struct ssa_result_range right = unpack_data(src_res[1]);
ASSERT_UNION_OF_DISJOINT_MATCHES_UNKNOWN_2_SOURCE(table);
ASSERT_UNION_OF_EQ_AND_STRICT_INEQ_MATCHES_NONSTRICT_2_SOURCE(table);
@@ -1230,12 +1291,9 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
}
case nir_op_ffma: {
const struct ssa_result_range first =
analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
const struct ssa_result_range second =
analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1));
const struct ssa_result_range third =
analyze_expression(alu, 2, ht, nir_alu_src_type(alu, 2));
const struct ssa_result_range first = unpack_data(src_res[0]);
const struct ssa_result_range second = unpack_data(src_res[1]);
const struct ssa_result_range third = unpack_data(src_res[2]);
r.is_integral = first.is_integral && second.is_integral &&
third.is_integral;
@@ -1261,12 +1319,9 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
}
case nir_op_flrp: {
const struct ssa_result_range first =
analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
const struct ssa_result_range second =
analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1));
const struct ssa_result_range third =
analyze_expression(alu, 2, ht, nir_alu_src_type(alu, 2));
const struct ssa_result_range first = unpack_data(src_res[0]);
const struct ssa_result_range second = unpack_data(src_res[1]);
const struct ssa_result_range third = unpack_data(src_res[2]);
r.is_integral = first.is_integral && second.is_integral &&
third.is_integral;
@@ -1296,18 +1351,29 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
/* Just like isfinite(), the is_finite flag implies the value is a number. */
assert((int) r.is_finite <= (int) r.is_a_number);
_mesa_hash_table_insert(ht, pack_key(alu, use_type), pack_data(r));
return r;
*result = pack_data(r);
}
#undef _______
struct ssa_result_range
nir_analyze_range(struct hash_table *range_ht,
const nir_alu_instr *instr, unsigned src)
const nir_alu_instr *alu, unsigned src)
{
return analyze_expression(instr, src, range_ht,
nir_alu_src_type(instr, src));
struct fp_query query_alloc[64];
uint32_t result_alloc[64];
struct analysis_state state;
state.range_ht = range_ht;
util_dynarray_init_from_stack(&state.query_stack, query_alloc, sizeof(query_alloc));
util_dynarray_init_from_stack(&state.result_stack, result_alloc, sizeof(result_alloc));
state.query_size = sizeof(struct fp_query);
state.get_key = &get_fp_key;
state.process_query = &process_fp_query;
push_fp_query(&state, alu, src, nir_type_invalid);
return unpack_data(perform_analysis(&state));
}
static uint32_t bitmask(uint32_t size) {