diff --git a/src/compiler/nir/nir_opt_algebraic.py b/src/compiler/nir/nir_opt_algebraic.py index 73570d66adb..e7957f730c6 100644 --- a/src/compiler/nir/nir_opt_algebraic.py +++ b/src/compiler/nir/nir_opt_algebraic.py @@ -102,6 +102,49 @@ def lowered_sincos(c): def intBitsToFloat(i): return struct.unpack('!f', struct.pack('!I', i))[0] +# Takes a pattern as input and returns a list of patterns where each +# pattern has a different permutation of fneg/fabs(value) as the replacement +# for the key operands in replacements. +def add_fabs_fneg(pattern, replacements, commutative = True): + def to_list(pattern): + return [to_list(i) if isinstance(i, tuple) else i for i in pattern] + + def to_tuple(pattern): + return tuple(to_tuple(i) if isinstance(i, list) else i for i in pattern) + + def replace_varible(pattern, search, replace): + for i in range(len(pattern)): + if pattern[i] == search: + pattern[i] = replace + elif isinstance(pattern[i], list): + replace_varible(pattern[i], search, replace) + + if commutative: + perms = itertools.combinations_with_replacement(range(4), len(replacements)) + else: + perms = itertools.product(range(4), repeat=len(replacements)) + + result = [] + + for perm in perms: + curr = to_list(pattern) + + for i, (search, base) in enumerate(replacements.items()): + if perm[i] == 0: + replace = ['fneg', ['fabs', base]] + elif perm[i] == 1: + replace = ['fabs', base] + elif perm[i] == 2: + replace = ['fneg', base] + elif perm[i] == 3: + replace = base + + replace_varible(curr, search, replace) + + result.append(to_tuple(curr)) + return result + + optimizations = [ (('imul', a, '#b(is_pos_power_of_two)'), ('ishl', a, ('find_lsb', b)), '!options->lower_bitops'), @@ -274,21 +317,21 @@ optimizations = [ # Optimize open-coded fmulz. # (b==0.0 ? 0.0 : a) * (a==0.0 ? 0.0 : b) -> fmulz(a, b) - (('fmul@32(nsz)', ('bcsel', ignore_exact('feq', b, 0.0), 0.0, a), ('bcsel', ignore_exact('feq', a, 0.0), 0.0, b)), - ('fmulz', a, b), has_fmulz), - (('fmul@32(nsz)', a, ('bcsel', ignore_exact('feq', a, 0.0), 0.0, '#b(is_not_const_zero)')), - ('fmulz', a, b), has_fmulz), + *add_fabs_fneg((('fmul@32(nsz)', ('bcsel', ignore_exact('feq', b, 0.0), 0.0, 'ma'), ('bcsel', ignore_exact('feq', a, 0.0), 0.0, 'mb')), + ('fmulz', 'ma', 'mb'), has_fmulz), {'ma' : a, 'mb' : b}), + *add_fabs_fneg((('fmul@32(nsz)', 'ma', ('bcsel', ignore_exact('feq', a, 0.0), 0.0, '#b(is_not_const_zero)')), + ('fmulz', 'ma', b), has_fmulz), {'ma' : a}), # ffma(b==0.0 ? 0.0 : a, a==0.0 ? 0.0 : b, c) -> ffmaz(a, b, c) - (('ffma@32(nsz)', ('bcsel', ignore_exact('feq', b, 0.0), 0.0, a), ('bcsel', ignore_exact('feq', a, 0.0), 0.0, b), c), - ('ffmaz', a, b, c), has_fmulz), - (('ffma@32(nsz)', a, ('bcsel', ignore_exact('feq', a, 0.0), 0.0, '#b(is_not_const_zero)'), c), - ('ffmaz', a, b, c), has_fmulz), + *add_fabs_fneg((('ffma@32(nsz)', ('bcsel', ignore_exact('feq', b, 0.0), 0.0, 'ma'), ('bcsel', ignore_exact('feq', a, 0.0), 0.0, 'mb'), c), + ('ffmaz', 'ma', 'mb', c), has_fmulz), {'ma' : a, 'mb' : b}), + *add_fabs_fneg((('ffma@32(nsz)', 'ma', ('bcsel', ignore_exact('feq', a, 0.0), 0.0, '#b(is_not_const_zero)'), c), + ('ffmaz', 'ma', b, c), has_fmulz), {'ma' : a}), # b == 0.0 ? 1.0 : fexp2(fmul(a, b)) -> fexp2(fmulz(a, b)) - (('bcsel(nsz,nnan,ninf)', ignore_exact('feq', b, 0.0), 1.0, ('fexp2', ('fmul@32', a, b))), - ('fexp2', ('fmulz', a, b)), - has_fmulz), + *add_fabs_fneg((('bcsel(nsz,nnan,ninf)', ignore_exact('feq', b, 0.0), 1.0, ('fexp2', ('fmul@32', a, 'mb'))), + ('fexp2', ('fmulz', a, 'mb')), + has_fmulz), {'mb': b}), ] # Shorthand for the expansion of just the dot product part of the [iu]dp4a