diff --git a/src/asahi/compiler/agx_compile.c b/src/asahi/compiler/agx_compile.c
index e840b6bad22..6882ab9d0fb 100644
--- a/src/asahi/compiler/agx_compile.c
+++ b/src/asahi/compiler/agx_compile.c
@@ -2497,6 +2497,7 @@ agx_optimize_nir(nir_shader *nir, unsigned *preamble_size)
     * do it after fusing constant shifts. Constant folding will clean up.
     */
    NIR_PASS(_, nir, agx_nir_lower_algebraic_late);
+   NIR_PASS(_, nir, agx_nir_fuse_selects);
    NIR_PASS(_, nir, nir_opt_constant_folding);
    NIR_PASS(_, nir, nir_opt_combine_barriers, NULL, NULL);
 
diff --git a/src/asahi/compiler/agx_nir.h b/src/asahi/compiler/agx_nir.h
index c54b445c5c6..83c9ad87093 100644
--- a/src/asahi/compiler/agx_nir.h
+++ b/src/asahi/compiler/agx_nir.h
@@ -11,6 +11,7 @@ struct nir_shader;
 
 bool agx_nir_lower_interpolation(struct nir_shader *s);
 bool agx_nir_lower_algebraic_late(struct nir_shader *shader);
+bool agx_nir_fuse_selects(struct nir_shader *shader);
 bool agx_nir_fuse_algebraic_late(struct nir_shader *shader);
 bool agx_nir_fence_images(struct nir_shader *shader);
 bool agx_nir_lower_layer(struct nir_shader *s);
diff --git a/src/asahi/compiler/agx_nir_algebraic.py b/src/asahi/compiler/agx_nir_algebraic.py
index 97e22308c40..09025011506 100644
--- a/src/asahi/compiler/agx_nir_algebraic.py
+++ b/src/asahi/compiler/agx_nir_algebraic.py
@@ -83,17 +83,7 @@ lower_pack = [
              ('isub', 32, 'bits'))),
 ]
 
-# Rewriting bcsel(a || b, ...) in terms of bcsel(a, ...) and bcsel(b, ...) lets
-# our rules to fuse compare-and-select do a better job, assuming that a and b
-# are comparisons themselves.
-lower_selects = [
-    (('bcsel', ('ior(is_used_once)', a, b), c, d),
-     ('bcsel', a, c, ('bcsel', b, c, d))),
-
-    (('bcsel', ('iand(is_used_once)', a, b), c, d),
-     ('bcsel', a, ('bcsel', b, c, d), d)),
-]
-
+lower_selects = []
 for T, sizes, one in [('f', [16, 32], 1.0),
                       ('i', [8, 16, 32], 1),
                       ('b', [32], -1)]:
@@ -103,6 +93,20 @@ for T, sizes, one in [('f', [16, 32], 1.0),
             ((f'b2{T}{size}', 'a@1'), ('bcsel', a, one, 0)),
         ])
 
+# Rewriting bcsel(a || b, ...) in terms of bcsel(a, ...) and bcsel(b, ...) lets
+# our rules to fuse compare-and-select do a better job, assuming that a and b
+# are comparisons themselves.
+#
+# This needs to be a separate pass that runs after lower_selects, in order to
+# pick up patterns like b2f32(iand(...))
+opt_selects = [
+    (('bcsel', ('ior(is_used_once)', a, b), c, d),
+     ('bcsel', a, c, ('bcsel', b, c, d))),
+
+    (('bcsel', ('iand(is_used_once)', a, b), c, d),
+     ('bcsel', a, ('bcsel', b, c, d), d)),
+]
+
 fuse_extr = []
 for start in range(32):
     fuse_extr.extend([
@@ -192,9 +196,11 @@ def run():
     print(nir_algebraic.AlgebraicPass("agx_nir_lower_algebraic_late",
                                       lower_sm5_shift + lower_pack +
                                       lower_selects).render())
+    print(nir_algebraic.AlgebraicPass("agx_nir_fuse_selects",
+                                      opt_selects).render())
     print(nir_algebraic.AlgebraicPass("agx_nir_fuse_algebraic_late",
-                                      fuse_extr + fuse_ubfe + fuse_imad +
-                                      ixor_bcsel).render())
+                                      fuse_extr + fuse_ubfe +
+                                      fuse_imad + ixor_bcsel).render())
 
 if __name__ == '__main__':
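
Illustration (not part of the patch): the opt_selects entries are ordinary nir_algebraic term rewrites. Below is a minimal standalone Python sketch of the ior rule, using plain tuples as a stand-in for nir_algebraic's expression syntax; the function name and the encoding are illustrative assumptions, not Mesa APIs. It shows how a select on a disjunction becomes nested selects whose conditions are bare comparisons, which is the shape the backend's compare-and-select fusion wants.

# Standalone sketch, not Mesa code: mimics the first opt_selects rule,
#   ('bcsel', ('ior(is_used_once)', a, b), c, d)
#   -> ('bcsel', a, c, ('bcsel', b, c, d))
def lower_select_of_ior(expr):
    if isinstance(expr, tuple) and expr[0] == 'bcsel':
        cond, c, d = expr[1], expr[2], expr[3]
        if isinstance(cond, tuple) and cond[0] == 'ior':
            a, b = cond[1], cond[2]
            # (a || b) ? c : d  ==  a ? c : (b ? c : d)
            return ('bcsel', a, c, ('bcsel', b, c, d))
    return expr

before = ('bcsel', ('ior', ('flt', 'x', 'y'), ('flt', 'x', 'z')), 'c', 'd')
print(lower_select_of_ior(before))
# ('bcsel', ('flt', 'x', 'y'), 'c', ('bcsel', ('flt', 'x', 'z'), 'c', 'd'))
# Each resulting bcsel now tests a single comparison, so each one can be
# fused into a single compare-and-select instruction by the backend rules.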