diff --git a/src/nouveau/compiler/nak_nir_split_64bit_conversions.c b/src/nouveau/compiler/nak_nir_split_64bit_conversions.c index 4d40f8e82ea..57db8533d3e 100644 --- a/src/nouveau/compiler/nak_nir_split_64bit_conversions.c +++ b/src/nouveau/compiler/nak_nir_split_64bit_conversions.c @@ -55,6 +55,7 @@ split_64bit_conversion(nir_builder *b, nir_instr *instr, UNUSED void *_data) nir_alu_type dst_full_type = nir_op_infos[alu->op].output_type; assert(nir_alu_type_get_type_size(dst_full_type) == dst_bit_size); nir_alu_type dst_type = nir_alu_type_get_base_type(dst_full_type); + const nir_rounding_mode rounding_mode = op_rounding_mode(alu->op); /* We can't cross the 64-bit boundary in one conversion */ if ((src_bit_size <= 32 && dst_bit_size <= 32) || @@ -87,10 +88,95 @@ split_64bit_conversion(nir_builder *b, nir_instr *instr, UNUSED void *_data) b->cursor = nir_before_instr(&alu->instr); nir_def *src = nir_ssa_for_alu_src(b, alu, 0); - nir_def *tmp = nir_type_convert(b, src, src_type, tmp_type, - nir_rounding_mode_undef); + nir_def *tmp; + if (src_full_type == nir_type_float64 && dst_full_type == nir_type_float16) { + /* For fp64->fp16 conversions, we need to be careful with the first + * conversion or else rounding might not accumulate properly. + */ + assert(tmp_type == nir_type_float32); + if (rounding_mode == nir_rounding_mode_rtne || + rounding_mode == nir_rounding_mode_undef) { + nir_def *src_lo = nir_unpack_64_2x32_split_x(b, src); + nir_def *src_hi = nir_unpack_64_2x32_split_y(b, src); + + /* RTNE is tricky to get right through a double conversion. To work + * around this, we do a little fixup of the fp64 value first. 
+ * + * For a 64-bit float, the mantissa bits are as follows: + * + * HHHHHHHHHHHLTFFFFFFFFF FFFDDDDDDDDDDDDDDDDDDDDDDDDDDDDD + * | | + * +------- bottom 32 bits -------+ + * + * Where: + * - D are only used for fp64 + * - T and F are used for fp64 and fp32 + * - H and L are used for fp64, fp32, and fp16 + * - L denotes the low bit of the fp16 mantissa + * - T is the tie bit + * + * The RTNE tie-breaking rules for fp64 -> fp16 can then be described + * as follows: + * + * - If any F or D bit is non-zero: + * - If T == 1, round up + * - If T == 0, round down + * - If all F and D bits are zero: + * - If T == 0, it's already fp16, do nothing + * - If T != 0 and L == 0, round down + * - If T != 0 and L != 0, round up + * + * What's important here is that the only way the F or D bits fit + * into the algorithm is whether any are non-zero or all are zero. So we + * will get the same result if we take all of the bits in the low + * dword, OR them together, and then OR that into the low F bits of + * the high dword. The result of "all F and D bits are zero" will be + * the same. We can also zero the low dword without affecting the + * final result. Doing this accomplishes two useful things: + * + * 1. The resulting fp64 value is exactly representable as fp32 so + * we don't have to care about the rounding of the fp64 -> fp32 + * conversion. + * + * 2. The fp32 -> fp16 conversion will round exactly the same as a + * full fp64 -> fp16 conversion on the original data since it now + * takes all of the D bits into account as well as the F bits. + * + * It's also correct for NaN/INF since those are delineated by the + * entire mantissa being either zero or non-zero. For denorms, + * anything that might be a denorm in fp32 or fp64 will have a + * sufficiently negative exponent that it will flush to zero when + * converted to fp16, regardless of what we do here. + * + * There are many operations we could choose for combining the low + * dword bits for ORing into the high dword. 
We choose umin because + * it nicely translates to a single fixed-latency instruction on + * everything except Volta. + */ + src_hi = nir_ior(b, src_hi, nir_umin_imm(b, src_lo, 1)); + src_lo = nir_imm_int(b, 0); + + tmp = nir_f2f32(b, nir_pack_64_2x32_split(b, src_lo, src_hi)); + } else { + /* For round-up, round-down, and round-towards-zero, the rounding + * accumulates properly as long as we use the same rounding mode for + * both operations. + */ + tmp = nir_convert_alu_types(b, 32, src, + .src_type = nir_type_float64, + .dest_type = tmp_type, + .rounding_mode = rounding_mode, + .saturate = false); + } + } else { + /* This is an up-convert or a convert to integer, in which case we + * always round towards zero. + */ + tmp = nir_type_convert(b, src, src_type, tmp_type, + nir_rounding_mode_undef); + } nir_def *res = nir_type_convert(b, tmp, tmp_type, dst_full_type, - op_rounding_mode(alu->op)); + rounding_mode); nir_def_replace(&alu->def, res); return true;