nir,agx: lower fmin/fmax in NIR
we want to elide flushes, doing so requires more sophisticated analysis than I'd like in the middle of isel. also, it should be done before forming preambles for efficiency (notice the uniform reduction here). let's do it with a NIR pass. total instructions in shared programs: 2768481 -> 2757832 (-0.38%) instructions in affected programs: 644084 -> 633435 (-1.65%) helped: 2242 HURT: 18 helped stats (abs) min: 1 max: 349 x̄: 4.77 x̃: 3 helped stats (rel) min: 0.01% max: 34.91% x̄: 3.19% x̃: 2.19% HURT stats (abs) min: 1 max: 19 x̄: 2.89 x̃: 1 HURT stats (rel) min: 0.24% max: 7.94% x̄: 1.27% x̃: 0.81% 95% mean confidence interval for instructions value: -5.20 -4.22 95% mean confidence interval for instructions %-change: -3.30% -3.01% Instructions are helped. total alu in shared programs: 2182880 -> 2172352 (-0.48%) alu in affected programs: 513166 -> 502638 (-2.05%) helped: 2235 HURT: 16 helped stats (abs) min: 1 max: 349 x̄: 4.73 x̃: 3 helped stats (rel) min: 0.02% max: 37.65% x̄: 3.70% x̃: 2.59% HURT stats (abs) min: 1 max: 19 x̄: 2.50 x̃: 1 HURT stats (rel) min: 0.33% max: 3.74% x̄: 1.04% x̃: 0.91% 95% mean confidence interval for alu value: -5.16 -4.20 95% mean confidence interval for alu %-change: -3.83% -3.49% Alu are helped. total fscib in shared programs: 2178643 -> 2168059 (-0.49%) fscib in affected programs: 514666 -> 504082 (-2.06%) helped: 2243 HURT: 17 helped stats (abs) min: 1 max: 349 x̄: 4.74 x̃: 3 helped stats (rel) min: 0.02% max: 37.65% x̄: 3.74% x̃: 2.59% HURT stats (abs) min: 1 max: 19 x̄: 2.65 x̃: 1 HURT stats (rel) min: 0.33% max: 14.71% x̄: 1.85% x̃: 0.93% 95% mean confidence interval for fscib value: -5.16 -4.20 95% mean confidence interval for fscib %-change: -3.87% -3.53% Fscib are helped. total bytes in shared programs: 18467348 -> 18403042 (-0.35%) bytes in affected programs: 4403648 -> 4339342 (-1.46%) helped: 2247 HURT: 20 helped stats (abs) min: 2 max: 2132 x̄: 28.73 x̃: 18 helped stats (rel) min: 0.01% max: 33.53% x̄: 2.80% x̃: 1.94% HURT stats (abs) min: 4 max: 72 x̄: 12.60 x̃: 6 HURT stats (rel) min: 0.23% max: 6.58% x̄: 1.06% x̃: 0.75% 95% mean confidence interval for bytes value: -31.29 -25.45 95% mean confidence interval for bytes %-change: -2.90% -2.64% Bytes are helped. total regs in shared programs: 864605 -> 864442 (-0.02%) regs in affected programs: 4692 -> 4529 (-3.47%) helped: 68 HURT: 48 helped stats (abs) min: 1 max: 54 x̄: 7.25 x̃: 3 helped stats (rel) min: 4.26% max: 43.20% x̄: 13.21% x̃: 10.53% HURT stats (abs) min: 1 max: 36 x̄: 6.88 x̃: 6 HURT stats (rel) min: 3.64% max: 91.67% x̄: 23.12% x̃: 24.00% 95% mean confidence interval for regs value: -3.60 0.79 95% mean confidence interval for regs %-change: -2.10% 5.75% Inconclusive result (value mean confidence interval includes 0). total uniforms in shared programs: 2120927 -> 2120911 (<.01%) uniforms in affected programs: 770 -> 754 (-2.08%) helped: 6 HURT: 0 helped stats (abs) min: 2 max: 4 x̄: 2.67 x̃: 2 helped stats (rel) min: 1.79% max: 2.70% x̄: 2.13% x̃: 1.96% 95% mean confidence interval for uniforms value: -3.75 -1.58 95% mean confidence interval for uniforms %-change: -2.50% -1.76% Uniforms are helped. total threads in shared programs: 27612224 -> 27613056 (<.01%) threads in affected programs: 7168 -> 8000 (11.61%) helped: 6 HURT: 3 helped stats (abs) min: 64 max: 192 x̄: 170.67 x̃: 192 helped stats (rel) min: 8.33% max: 23.08% x̄: 20.62% x̃: 23.08% HURT stats (abs) min: 64 max: 64 x̄: 64.00 x̃: 64 HURT stats (rel) min: 8.33% max: 9.09% x̄: 8.59% x̃: 8.33% 95% mean confidence interval for threads value: -3.17 188.06 95% mean confidence interval for threads %-change: -0.92% 22.69% Inconclusive result (value mean confidence interval includes 0). Signed-off-by: Alyssa Rosenzweig <alyssa@rosenzweig.io> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/31908>
This commit is contained in:
@@ -1756,20 +1756,17 @@ agx_fminmax_to(agx_builder *b, agx_index dst, agx_index s0, agx_index s1,
|
||||
assert(!nir_alu_instr_is_signed_zero_preserve(alu) &&
|
||||
"should've been lowered");
|
||||
|
||||
bool fmax = alu->op == nir_op_fmax;
|
||||
assert((alu->def.bit_size == 16) ==
|
||||
(alu->op == nir_op_fmin || alu->op == nir_op_fmax) &&
|
||||
"fp32 should be lowered");
|
||||
|
||||
bool fmax = alu->op == nir_op_fmax || alu->op == nir_op_fmax_agx;
|
||||
enum agx_fcond fcond = fmax ? AGX_FCOND_GTN : AGX_FCOND_LTN;
|
||||
|
||||
/* Calculate min/max with the appropriate hardware instruction */
|
||||
agx_index tmp = agx_fcmpsel(b, s0, s1, s0, s1, fcond);
|
||||
|
||||
/* G13 flushes fp32 denorms and preserves fp16 denorms. Since cmpsel
|
||||
* preserves denorms, we need to canonicalize for fp32. Canonicalizing fp16
|
||||
* would be harmless but wastes an instruction.
|
||||
/* Calculate min/max with the appropriate hardware instruction. This will not
|
||||
* handle denorms, but we were already lowered for that.
|
||||
*/
|
||||
if (alu->def.bit_size == 32)
|
||||
return agx_fadd_to(b, dst, tmp, agx_negzero());
|
||||
else
|
||||
return agx_mov_to(b, dst, tmp);
|
||||
return agx_fcmpsel_to(b, dst, s0, s1, s0, s1, fcond);
|
||||
}
|
||||
|
||||
static agx_instr *
|
||||
@@ -1893,6 +1890,8 @@ agx_emit_alu(agx_builder *b, nir_alu_instr *instr)
|
||||
|
||||
case nir_op_fmin:
|
||||
case nir_op_fmax:
|
||||
case nir_op_fmin_agx:
|
||||
case nir_op_fmax_agx:
|
||||
return agx_fminmax_to(b, dst, s0, s1, instr);
|
||||
|
||||
case nir_op_imin:
|
||||
@@ -3035,6 +3034,11 @@ agx_optimize_nir(nir_shader *nir, bool soft_fault, unsigned *preamble_size)
|
||||
} while (progress);
|
||||
}
|
||||
|
||||
/* Lower fmin/fmax before optimizing preambles so we can see across uniform
|
||||
* expressions.
|
||||
*/
|
||||
NIR_PASS(_, nir, agx_nir_lower_fminmax);
|
||||
|
||||
if (preamble_size && (!(agx_compiler_debug & AGX_DBG_NOPREAMBLE)))
|
||||
NIR_PASS(_, nir, agx_nir_opt_preamble, preamble_size);
|
||||
|
||||
@@ -3048,7 +3052,12 @@ agx_optimize_nir(nir_shader *nir, bool soft_fault, unsigned *preamble_size)
|
||||
NIR_PASS(_, nir, nir_opt_peephole_select, 64, false, true);
|
||||
NIR_PASS(_, nir, nir_lower_int64);
|
||||
|
||||
/* We need to lower fmin/fmax again after nir_opt_algebraic_late due to f2fmp
|
||||
* wackiness. This is usually a no-op but is required for correctness in
|
||||
* GLES.
|
||||
*/
|
||||
NIR_PASS(_, nir, nir_opt_algebraic_late);
|
||||
NIR_PASS(_, nir, agx_nir_lower_fminmax);
|
||||
|
||||
/* Fuse add/sub/multiplies/shifts after running opt_algebraic_late to fuse
|
||||
* isub but before shifts are lowered.
|
||||
|
||||
@@ -16,3 +16,4 @@ bool agx_nir_fence_images(struct nir_shader *shader);
|
||||
bool agx_nir_lower_layer(struct nir_shader *s);
|
||||
bool agx_nir_lower_clip_distance(struct nir_shader *s);
|
||||
bool agx_nir_lower_subgroups(struct nir_shader *s);
|
||||
bool agx_nir_lower_fminmax(struct nir_shader *s);
|
||||
|
||||
77
src/asahi/compiler/agx_nir_lower_fminmax.c
Normal file
77
src/asahi/compiler/agx_nir_lower_fminmax.c
Normal file
@@ -0,0 +1,77 @@
|
||||
/*
|
||||
* Copyright 2024 Valve Corporation
|
||||
* SPDX-License-Identifier: MIT
|
||||
*/
|
||||
|
||||
#include "compiler/nir/nir.h"
|
||||
#include "compiler/nir/nir_builder.h"
|
||||
#include "agx_nir.h"
|
||||
|
||||
/*
|
||||
* AGX generally flushes FP32 denorms. However, the min/max instructions do not
|
||||
* as they are implemented with cmpsel. We need to flush the results of fp32
|
||||
* min/max for correctness. Doing so naively will generate redundant flushes, so
|
||||
* this pass tries to be clever and elide flushes when possible.
|
||||
*
|
||||
* This pass is still pretty simple, it doesn't see through phis or bcsels yet.
|
||||
*/
|
||||
static bool
|
||||
could_be_denorm(nir_scalar s)
|
||||
{
|
||||
/* Constants can be denorms only if they are denorms. */
|
||||
if (nir_scalar_is_const(s)) {
|
||||
return fpclassify(nir_scalar_as_float(s)) == FP_SUBNORMAL;
|
||||
}
|
||||
|
||||
/* Floating-point instructions flush denormals, so ALU results can only be
|
||||
* denormal if they are not from a float instruction. Crucially fmin/fmax
|
||||
* flushes in NIR, so this pass handles chains of fmin/fmax properly.
|
||||
*/
|
||||
if (nir_scalar_is_alu(s)) {
|
||||
nir_op op = nir_scalar_alu_op(s);
|
||||
nir_alu_type T = nir_op_infos[op].output_type;
|
||||
|
||||
return nir_alu_type_get_base_type(T) != nir_type_float &&
|
||||
op != nir_op_fmin_agx && op != nir_op_fmax_agx;
|
||||
}
|
||||
|
||||
/* Otherwise, assume it could be denormal (say, loading from a buffer). */
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool
|
||||
lower(nir_builder *b, nir_alu_instr *alu, void *data)
|
||||
{
|
||||
if ((alu->op != nir_op_fmin && alu->op != nir_op_fmax) ||
|
||||
(alu->def.bit_size != 32))
|
||||
return false;
|
||||
|
||||
/* Lower the op, we'll fix up the denorms right after. */
|
||||
if (alu->op == nir_op_fmax)
|
||||
alu->op = nir_op_fmax_agx;
|
||||
else
|
||||
alu->op = nir_op_fmin_agx;
|
||||
|
||||
/* We need to canonicalize the result if the output could be a denorm. That
|
||||
* occurs only when one of the sources could be a denorm. Check each source.
|
||||
* Swizzles don't affect denormalness so we can grab the def directly.
|
||||
*/
|
||||
nir_scalar scalar = nir_get_scalar(&alu->def, 0);
|
||||
nir_scalar src0 = nir_scalar_chase_alu_src(scalar, 0);
|
||||
nir_scalar src1 = nir_scalar_chase_alu_src(scalar, 1);
|
||||
|
||||
if (could_be_denorm(src0) || could_be_denorm(src1)) {
|
||||
b->cursor = nir_after_instr(&alu->instr);
|
||||
nir_def *canonicalized = nir_fadd_imm(b, &alu->def, -0.0);
|
||||
nir_def_rewrite_uses_after(&alu->def, canonicalized,
|
||||
canonicalized->parent_instr);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool
|
||||
agx_nir_lower_fminmax(nir_shader *s)
|
||||
{
|
||||
return nir_shader_alu_pass(s, lower, nir_metadata_control_flow, NULL);
|
||||
}
|
||||
@@ -9,6 +9,7 @@ libasahi_agx_files = files(
|
||||
'agx_insert_waits.c',
|
||||
'agx_nir_lower_address.c',
|
||||
'agx_nir_lower_cull_distance.c',
|
||||
'agx_nir_lower_fminmax.c',
|
||||
'agx_nir_lower_frag_sidefx.c',
|
||||
'agx_nir_lower_sample_mask.c',
|
||||
'agx_nir_lower_discard_zs_emit.c',
|
||||
|
||||
@@ -1364,6 +1364,14 @@ binop_convert("interleave_agx", tuint32, tuint16, "", """
|
||||
be used as-is for Morton encoding.
|
||||
""")
|
||||
|
||||
# These are like fmin/fmax, but do not flush denorms on the output which is why
|
||||
# they're modeled as conversions. AGX flushes fp32 denorms but preserves fp16
|
||||
# denorms, so fp16 fmin/fmax work without lowering.
|
||||
binop_convert("fmin_agx", tuint32, tfloat32, _2src_commutative + associative,
|
||||
"(src0 < src1 || isnan(src1)) ? src0 : src1")
|
||||
binop_convert("fmax_agx", tuint32, tfloat32, _2src_commutative + associative,
|
||||
"(src0 > src1 || isnan(src1)) ? src0 : src1")
|
||||
|
||||
# NVIDIA PRMT
|
||||
opcode("prmt_nv", 0, tuint32, [0, 0, 0], [tuint32, tuint32, tuint32],
|
||||
False, "", """
|
||||
|
||||
Reference in New Issue
Block a user