From 45a111c21c23be94f9297650fb8428fe2acf5641 Mon Sep 17 00:00:00 2001 From: Alyssa Rosenzweig Date: Tue, 25 Oct 2022 22:29:31 -0400 Subject: [PATCH] nir/opt_algebraic: Fuse c - a * b to FMA MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Algebraically it is clear that -(a * b) + c = (-a) * b + c = fma(-a, b, c) But this is not clear from the NIR ('fadd', ('fneg', ('fmul', a, b)), c) Add rules to handle this case specially. Note we don't necessarily want to solve this by pushing fneg into fmul, because the rule opt_algebraic (not the late part where FMA fusing happens) specifically pulls fneg out of fmul to push fneg up multiplication chains. Noticed in the big glmark2 "terrain" shader, which has a cycle count reduced by 22% on Mali-G57 thanks to having this pattern a ton and being FMA bound. BEFORE: 1249 inst, 16.015625 cycles, 16.015625 fma, ... 632 quadwords AFTER: 997 inst, 12.437500 cycles, .... 504 quadwords Results on the same shader on AGX are also quite dramatic: BEFORE: 1294 inst, 8600 bytes, 50 halfregs, ... AFTER: 1154 inst, 8040 bytes, 50 halfregs, ... Similar rules apply for fabs. v2: Use a loop over the bit sizes (suggested by Emma). shader-db on Valhall (open + small subset of closed), results on Bifrost are similar: total instructions in shared programs: 167975 -> 164970 (-1.79%) instructions in affected programs: 92642 -> 89637 (-3.24%) helped: 492 HURT: 25 helped stats (abs) min: 1.0 max: 252.0 x̄: 6.25 x̃: 3 helped stats (rel) min: 0.30% max: 20.18% x̄: 3.21% x̃: 2.91% HURT stats (abs) min: 1.0 max: 5.0 x̄: 2.80 x̃: 3 HURT stats (rel) min: 0.46% max: 9.09% x̄: 3.89% x̃: 3.37% 95% mean confidence interval for instructions value: -6.95 -4.68 95% mean confidence interval for instructions %-change: -3.08% -2.65% Instructions are helped. total cycles in shared programs: 10556.89 -> 10538.98 (-0.17%) cycles in affected programs: 265.56 -> 247.66 (-6.74%) helped: 88 HURT: 2 helped stats (abs) min: 0.015625 max: 3.578125 x̄: 0.20 x̃: 0 helped stats (rel) min: 0.65% max: 22.34% x̄: 5.65% x̃: 4.25% HURT stats (abs) min: 0.0625 max: 0.0625 x̄: 0.06 x̃: 0 HURT stats (rel) min: 8.33% max: 12.50% x̄: 10.42% x̃: 10.42% 95% mean confidence interval for cycles value: -0.28 -0.12 95% mean confidence interval for cycles %-change: -6.30% -4.30% Cycles are helped. total fma in shared programs: 1582.42 -> 1535.06 (-2.99%) fma in affected programs: 871.58 -> 824.22 (-5.43%) helped: 502 HURT: 9 helped stats (abs) min: 0.015625 max: 3.578125 x̄: 0.09 x̃: 0 helped stats (rel) min: 0.60% max: 25.00% x̄: 5.46% x̃: 4.82% HURT stats (abs) min: 0.015625 max: 0.0625 x̄: 0.03 x̃: 0 HURT stats (rel) min: 4.35% max: 12.50% x̄: 6.22% x̃: 4.35% 95% mean confidence interval for fma value: -0.11 -0.08 95% mean confidence interval for fma %-change: -5.58% -4.93% Fma are helped. total cvt in shared programs: 665.55 -> 665.95 (0.06%) cvt in affected programs: 61.72 -> 62.12 (0.66%) helped: 33 HURT: 43 helped stats (abs) min: 0.015625 max: 0.359375 x̄: 0.04 x̃: 0 helped stats (rel) min: 1.01% max: 25.00% x̄: 6.68% x̃: 4.35% HURT stats (abs) min: 0.015625 max: 0.109375 x̄: 0.04 x̃: 0 HURT stats (rel) min: 0.78% max: 38.46% x̄: 10.85% x̃: 6.90% 95% mean confidence interval for cvt value: -0.01 0.02 95% mean confidence interval for cvt %-change: 0.23% 6.24% Inconclusive result (value mean confidence interval includes 0). total quadwords in shared programs: 93376 -> 91736 (-1.76%) quadwords in affected programs: 25376 -> 23736 (-6.46%) helped: 169 HURT: 1 helped stats (abs) min: 8.0 max: 128.0 x̄: 9.75 x̃: 8 helped stats (rel) min: 1.52% max: 33.33% x̄: 8.35% x̃: 8.00% HURT stats (abs) min: 8.0 max: 8.0 x̄: 8.00 x̃: 8 HURT stats (rel) min: 25.00% max: 25.00% x̄: 25.00% x̃: 25.00% 95% mean confidence interval for quadwords value: -11.18 -8.11 95% mean confidence interval for quadwords %-change: -8.95% -7.36% Quadwords are helped. total threads in shared programs: 4697 -> 4701 (0.09%) threads in affected programs: 4 -> 8 (100.00%) helped: 4 HURT: 0 helped stats (abs) min: 1.0 max: 1.0 x̄: 1.00 x̃: 1 helped stats (rel) min: 100.00% max: 100.00% x̄: 100.00% x̃: 100.00% 95% mean confidence interval for threads value: 1.00 1.00 95% mean confidence interval for threads %-change: 100.00% 100.00% Threads are helped. Signed-off-by: Alyssa Rosenzweig Reviewed-by: Marek Olk Reviewed-by: Karol Herbst [v1] Part-of: --- src/compiler/nir/nir_opt_algebraic.py | 40 +++++++++++++++---- src/compiler/nir/nir_search_helpers.h | 6 ++- .../drivers/virgl/ci/traces-virgl-iris.yml | 6 +-- src/gallium/drivers/virgl/ci/traces-virgl.yml | 4 +- 4 files changed, 42 insertions(+), 14 deletions(-) diff --git a/src/compiler/nir/nir_opt_algebraic.py b/src/compiler/nir/nir_opt_algebraic.py index 953b67f7e2c..026f0ee5558 100644 --- a/src/compiler/nir/nir_opt_algebraic.py +++ b/src/compiler/nir/nir_opt_algebraic.py @@ -2638,15 +2638,39 @@ late_optimizations = [ # nir_lower_to_source_mods will collapse this, but its existence during the # optimization loop can prevent other optimizations. - (('fneg', ('fneg', a)), a), + (('fneg', ('fneg', a)), a) +] - # re-combine inexact mul+add to ffma. Do this before fsub so that a * b - c - # gets combined to fma(a, b, -c). - (('~fadd@16', ('fmul(is_only_used_by_fadd)', a, b), c), ('ffma', a, b, c), 'options->fuse_ffma16'), - (('~fadd@32', ('fmul(is_only_used_by_fadd)', a, b), c), ('ffma', a, b, c), 'options->fuse_ffma32'), - (('~fadd@64', ('fmul(is_only_used_by_fadd)', a, b), c), ('ffma', a, b, c), 'options->fuse_ffma64'), - (('~fadd@32', ('fmulz(is_only_used_by_fadd)', a, b), c), ('ffmaz', a, b, c), 'options->fuse_ffma32'), +# re-combine inexact mul+add to ffma. Do this before fsub so that a * b - c +# gets combined to fma(a, b, -c). +for sz, mulz in itertools.product([16, 32, 64], [False, True]): + # fmulz/ffmaz only for fp32 + if mulz and sz != 32: + continue + # Fuse the correct fmul. Only consider fmuls where the only users are fadd + # (or fneg/fabs which are assumed to be propagated away), as a heuristic to + # avoid fusing in cases where it's harmful. + fmul = ('fmulz' if mulz else 'fmul') + '(is_only_used_by_fadd)' + ffma = 'ffmaz' if mulz else 'ffma' + + fadd = '~fadd@{}'.format(sz) + option = 'options->fuse_ffma{}'.format(sz) + + late_optimizations.extend([ + ((fadd, (fmul, a, b), c), (ffma, a, b, c), option), + + ((fadd, ('fneg(is_only_used_by_fadd)', (fmul, a, b)), c), + (ffma, ('fneg', a), b, c), option), + + ((fadd, ('fabs(is_only_used_by_fadd)', (fmul, a, b)), c), + (ffma, ('fabs', a), ('fabs', b), c), option), + + ((fadd, ('fneg(is_only_used_by_fadd)', ('fabs', (fmul, a, b))), c), + (ffma, ('fneg', ('fabs', a)), ('fabs', b), c), option), + ]) + +late_optimizations.extend([ # Subtractions get lowered during optimization, so we need to recombine them (('fadd@8', a, ('fneg', 'b')), ('fsub', 'a', 'b'), 'options->has_fsub'), (('fadd@16', a, ('fneg', 'b')), ('fsub', 'a', 'b'), 'options->has_fsub'), @@ -2823,7 +2847,7 @@ late_optimizations = [ (('extract_i8', ('extract_u8', a, b), 0), ('extract_i8', a, b)), (('extract_u8', ('extract_i8', a, b), 0), ('extract_u8', a, b)), (('extract_u8', ('extract_u8', a, b), 0), ('extract_u8', a, b)), -] +]) # A few more extract cases we'd rather leave late for N in [16, 32]: diff --git a/src/compiler/nir/nir_search_helpers.h b/src/compiler/nir/nir_search_helpers.h index da1e9b506a8..5308fc41e40 100644 --- a/src/compiler/nir/nir_search_helpers.h +++ b/src/compiler/nir/nir_search_helpers.h @@ -433,8 +433,12 @@ is_only_used_by_fadd(const nir_alu_instr *instr) const nir_alu_instr *const user_alu = nir_instr_as_alu(user_instr); assert(instr != user_alu); - if (user_alu->op != nir_op_fadd) + if (user_alu->op == nir_op_fneg || user_alu->op == nir_op_fabs) { + if (!is_only_used_by_fadd(user_alu)) + return false; + } else if (user_alu->op != nir_op_fadd) { return false; + } } return true; diff --git a/src/gallium/drivers/virgl/ci/traces-virgl-iris.yml b/src/gallium/drivers/virgl/ci/traces-virgl-iris.yml index 7ad46c464ed..05aaf2c9541 100644 --- a/src/gallium/drivers/virgl/ci/traces-virgl-iris.yml +++ b/src/gallium/drivers/virgl/ci/traces-virgl-iris.yml @@ -30,7 +30,7 @@ traces: checksum: 32e8b627a33ad08d416dfdb804920371 0ad/0ad-v2.trace: gl-virgl: - checksum: bf22fd7c3fc8baa7b0e9345728626d5f + checksum: 638fa405f78a6631ba829a8fc98392a6 glmark2/buffer:update-fraction=0.5:update-dispersion=0.9:columns=200:update-method=map:interleave=false-v2.trace: gl-virgl: checksum: 040232e01e394a967dc3320bb9252870 @@ -42,7 +42,7 @@ traces: checksum: df21895268db3ab185ae5ffa5b2d7f37 glmark2/bump:bump-render=height-v2.trace: gl-virgl: - checksum: cd32f46925906c53fae747372a8f2ed8 + checksum: cceb2b8d4852b94709684b69c688638c glmark2/bump:bump-render=high-poly-v2.trace: gl-virgl: checksum: 11b7a4820b452934e6f12b57b8910a9a @@ -126,7 +126,7 @@ traces: label: [crash] gputest/pixmark-julia-fp32-v2.trace: gl-virgl: - checksum: 0aa3a82a5b849cb83436e52c4e3e95ac + checksum: fbf5e44a6f46684b84e5bb5ad6d36c67 gputest/pixmark-julia-fp64-v2.trace: gl-virgl: checksum: 1760aea00af985b8cd902128235b08f6 diff --git a/src/gallium/drivers/virgl/ci/traces-virgl.yml b/src/gallium/drivers/virgl/ci/traces-virgl.yml index 7da834d4f83..2be64e826d5 100644 --- a/src/gallium/drivers/virgl/ci/traces-virgl.yml +++ b/src/gallium/drivers/virgl/ci/traces-virgl.yml @@ -123,7 +123,7 @@ traces: label: [crash] gputest/pixmark-julia-fp32-v2.trace: gl-virgl: - checksum: 25f938c726c68c08a88193f28f7c4474 + checksum: 8b3584b1dd8f1d1bb63205564bd78e4e gputest/pixmark-julia-fp64-v2.trace: gl-virgl: checksum: 73ccaff82ea764057fb0f93f0024cf84 @@ -183,7 +183,7 @@ traces: checksum: f4af4067b37c00861fa5911e4c0a6629 supertuxkart/supertuxkart-mansion-egl-gles-v2.trace: gl-virgl: - checksum: 092e8ca38e58aaa83df2a9f0b7b8aee5 + checksum: cc7092975dd6c9064aa54cd7f18053b6 xonotic/xonotic-keybench-high-v2.trace: gl-virgl: checksum: f3b184bf8858a6ebccd09e7ca032197e