diff --git a/src/asahi/compiler/agx_compile.c b/src/asahi/compiler/agx_compile.c index 402272cf804..0c3458a61f1 100644 --- a/src/asahi/compiler/agx_compile.c +++ b/src/asahi/compiler/agx_compile.c @@ -1756,20 +1756,17 @@ agx_fminmax_to(agx_builder *b, agx_index dst, agx_index s0, agx_index s1, assert(!nir_alu_instr_is_signed_zero_preserve(alu) && "should've been lowered"); - bool fmax = alu->op == nir_op_fmax; + assert((alu->def.bit_size == 16) == + (alu->op == nir_op_fmin || alu->op == nir_op_fmax) && + "fp32 should be lowered"); + + bool fmax = alu->op == nir_op_fmax || alu->op == nir_op_fmax_agx; enum agx_fcond fcond = fmax ? AGX_FCOND_GTN : AGX_FCOND_LTN; - /* Calculate min/max with the appropriate hardware instruction */ - agx_index tmp = agx_fcmpsel(b, s0, s1, s0, s1, fcond); - - /* G13 flushes fp32 denorms and preserves fp16 denorms. Since cmpsel - * preserves denorms, we need to canonicalize for fp32. Canonicalizing fp16 - * would be harmless but wastes an instruction. + /* Calculate min/max with the appropriate hardware instruction. This will not + * handle denorms, but we were already lowered for that. */ - if (alu->def.bit_size == 32) - return agx_fadd_to(b, dst, tmp, agx_negzero()); - else - return agx_mov_to(b, dst, tmp); + return agx_fcmpsel_to(b, dst, s0, s1, s0, s1, fcond); } static agx_instr * @@ -1893,6 +1890,8 @@ agx_emit_alu(agx_builder *b, nir_alu_instr *instr) case nir_op_fmin: case nir_op_fmax: + case nir_op_fmin_agx: + case nir_op_fmax_agx: return agx_fminmax_to(b, dst, s0, s1, instr); case nir_op_imin: @@ -3035,6 +3034,11 @@ agx_optimize_nir(nir_shader *nir, bool soft_fault, unsigned *preamble_size) } while (progress); } + /* Lower fmin/fmax before optimizing preambles so we can see across uniform + * expressions. 
+ */ + NIR_PASS(_, nir, agx_nir_lower_fminmax); + if (preamble_size && (!(agx_compiler_debug & AGX_DBG_NOPREAMBLE))) NIR_PASS(_, nir, agx_nir_opt_preamble, preamble_size); @@ -3048,7 +3052,12 @@ agx_optimize_nir(nir_shader *nir, bool soft_fault, unsigned *preamble_size) NIR_PASS(_, nir, nir_opt_peephole_select, 64, false, true); NIR_PASS(_, nir, nir_lower_int64); + /* We need to lower fmin/fmax again after nir_opt_algebraic_late due to f2fmp + * wackiness. This is usually a no-op but is required for correctness in + * GLES. + */ NIR_PASS(_, nir, nir_opt_algebraic_late); + NIR_PASS(_, nir, agx_nir_lower_fminmax); /* Fuse add/sub/multiplies/shifts after running opt_algebraic_late to fuse * isub but before shifts are lowered. diff --git a/src/asahi/compiler/agx_nir.h b/src/asahi/compiler/agx_nir.h index 3b4de56b120..d63a3bf6c86 100644 --- a/src/asahi/compiler/agx_nir.h +++ b/src/asahi/compiler/agx_nir.h @@ -16,3 +16,4 @@ bool agx_nir_fence_images(struct nir_shader *shader); bool agx_nir_lower_layer(struct nir_shader *s); bool agx_nir_lower_clip_distance(struct nir_shader *s); bool agx_nir_lower_subgroups(struct nir_shader *s); +bool agx_nir_lower_fminmax(struct nir_shader *s); diff --git a/src/asahi/compiler/agx_nir_lower_fminmax.c b/src/asahi/compiler/agx_nir_lower_fminmax.c new file mode 100644 index 00000000000..7acee8a97fd --- /dev/null +++ b/src/asahi/compiler/agx_nir_lower_fminmax.c @@ -0,0 +1,102 @@ +/* + * Copyright 2024 Valve Corporation + * SPDX-License-Identifier: MIT + */ + +#include "compiler/nir/nir.h" +#include "compiler/nir/nir_builder.h" +#include "agx_nir.h" + +/* + * AGX generally flushes FP32 denorms. However, the min/max instructions do not + * as they are implemented with cmpsel. We need to flush the results of fp32 + * min/max for correctness. Doing so naively will generate redundant flushes, so + * this pass tries to be clever and elide flushes when possible. + * + * This pass is still pretty simple: it doesn't see through phis or bcsels yet. 
+ */ + +/* + * Conservatively report whether the fp32 scalar s may hold a denormal. Returns + * false only when a denormal can be ruled out. Callers pass the sources of a + * 32-bit fmin/fmax, so constants are classified as fp32. + */ +static bool +could_be_denorm(nir_scalar s) +{ + /* Constants can be denorms only if they are denorms. nir_scalar_as_float() + * returns a double, in which every fp32 denorm is a normal number, so + * narrow back to fp32 before classifying. + */ + if (nir_scalar_is_const(s)) { + return fpclassify((float)nir_scalar_as_float(s)) == FP_SUBNORMAL; + } + + /* Floating-point instructions flush denormals, so ALU results can only be + * denormal if they are not from a float instruction. Crucially fmin/fmax + * flushes in NIR, so this pass handles chains of fmin/fmax properly. The + * already-lowered fmin_agx/fmax_agx have an integer output type, so exempt + * them explicitly: lower() only leaves their result uncanonicalized when + * neither source could be a denorm. + */ + if (nir_scalar_is_alu(s)) { + nir_op op = nir_scalar_alu_op(s); + nir_alu_type T = nir_op_infos[op].output_type; + + return nir_alu_type_get_base_type(T) != nir_type_float && + op != nir_op_fmin_agx && op != nir_op_fmax_agx; + } + + /* Otherwise, assume it could be denormal (say, loading from a buffer). */ + return true; +} + +/* + * nir_shader_alu_pass callback. Rewrites a 32-bit fmin/fmax in place to + * fmin_agx/fmax_agx and, if either source could be a denorm, appends an fadd + * that canonicalizes the result. Returns true iff the instruction was lowered. + */ +static bool +lower(nir_builder *b, nir_alu_instr *alu, void *data) +{ + if ((alu->op != nir_op_fmin && alu->op != nir_op_fmax) || + (alu->def.bit_size != 32)) + return false; + + /* Lower the op, we'll fix up the denorms right after. */ + if (alu->op == nir_op_fmax) + alu->op = nir_op_fmax_agx; + else + alu->op = nir_op_fmin_agx; + + /* We need to canonicalize the result if the output could be a denorm. That + * occurs only when one of the sources could be a denorm. Check each source. + * Swizzles don't affect denormalness so we can grab the def directly. 
 + */ + nir_scalar scalar = nir_get_scalar(&alu->def, 0); + nir_scalar src0 = nir_scalar_chase_alu_src(scalar, 0); + nir_scalar src1 = nir_scalar_chase_alu_src(scalar, 1); + + if (could_be_denorm(src0) || could_be_denorm(src1)) { + b->cursor = nir_after_instr(&alu->instr); + /* x + -0.0 is an arithmetic identity (it even keeps the sign of zero), + * but as a float instruction it flushes a denorm result. Only uses after + * the fadd are rewritten, so the fadd keeps reading the raw min/max. + */ + nir_def *canonicalized = nir_fadd_imm(b, &alu->def, -0.0); + nir_def_rewrite_uses_after(&alu->def, canonicalized, + canonicalized->parent_instr); + } + + return true; +} + +/* + * Entry point of the pass: runs lower() over the shader's ALU instructions via + * nir_shader_alu_pass() and returns its result. + */ +bool +agx_nir_lower_fminmax(nir_shader *s) +{ + return nir_shader_alu_pass(s, lower, nir_metadata_control_flow, NULL); +} diff --git a/src/asahi/compiler/meson.build b/src/asahi/compiler/meson.build index 7979dffda4d..db6cdbef0d1 100644 --- a/src/asahi/compiler/meson.build +++ b/src/asahi/compiler/meson.build @@ -9,6 +9,7 @@ libasahi_agx_files = files( 'agx_insert_waits.c', 'agx_nir_lower_address.c', 'agx_nir_lower_cull_distance.c', + 'agx_nir_lower_fminmax.c', 'agx_nir_lower_frag_sidefx.c', 'agx_nir_lower_sample_mask.c', 'agx_nir_lower_discard_zs_emit.c', diff --git a/src/compiler/nir/nir_opcodes.py b/src/compiler/nir/nir_opcodes.py index c6c0e59df28..372728c9f92 100644 --- a/src/compiler/nir/nir_opcodes.py +++ b/src/compiler/nir/nir_opcodes.py @@ -1364,6 +1364,14 @@ binop_convert("interleave_agx", tuint32, tuint16, "", """ be used as-is for Morton encoding. """) +# These are like fmin/fmax, but do not flush denorms on the output which is why +# they're modeled as conversions. AGX flushes fp32 denorms but preserves fp16 +# denorms, so fp16 fmin/fmax work without lowering. +binop_convert("fmin_agx", tuint32, tfloat32, _2src_commutative + associative, + "(src0 < src1 || isnan(src1)) ? src0 : src1") +binop_convert("fmax_agx", tuint32, tfloat32, _2src_commutative + associative, + "(src0 > src1 || isnan(src1)) ? src0 : src1") + # NVIDIA PRMT opcode("prmt_nv", 0, tuint32, [0, 0, 0], [tuint32, tuint32, tuint32], False, "", """