From 90e1b128903cabfe4fcfb5ae52cf46d5ddbf1189 Mon Sep 17 00:00:00 2001 From: Caio Oliveira Date: Wed, 19 Mar 2025 09:12:38 -0700 Subject: [PATCH] spirv: Add bfloat16 support to SpecConstantOp Handle bfloat16 by converting sources to float, performing the operation, and converting result back to bfloat16 if needed. This is done because not all ALU ops have a `bf` version in NIR. Reviewed-by: Rohan Garg Reviewed-by: Ian Romanick Reviewed-by: Georg Lehmann Part-of: --- src/compiler/spirv/spirv_to_nir.c | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c index d4e2537124d..a9e950dd4cc 100644 --- a/src/compiler/spirv/spirv_to_nir.c +++ b/src/compiler/spirv/spirv_to_nir.c @@ -40,6 +40,7 @@ #include "util/u_debug.h" #include "util/u_printf.h" #include "util/mesa-blake3.h" +#include "util/bfloat.h" #include @@ -2697,6 +2698,12 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode, const glsl_type *dst_type = val->type->type; const glsl_type *src_type = dst_type; + const bool bfloat_dst = glsl_type_is_bfloat_16(dst_type); + bool bfloat_src = bfloat_dst; + + if (bfloat_dst) + dst_type = glsl_float_type(); + unsigned num_components = glsl_get_vector_elements(val->type->type); vtn_assert(count <= 7); @@ -2704,10 +2711,14 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode, switch (opcode) { case SpvOpSConvert: case SpvOpFConvert: - case SpvOpUConvert: + case SpvOpUConvert: { /* We have a different source type in a conversion. */ src_type = vtn_get_value_type(b, w[4])->type; + bfloat_src = glsl_type_is_bfloat_16(src_type); + if (bfloat_src) + src_type = glsl_float_type(); break; + } default: break; }; @@ -2721,7 +2732,7 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode, */ assert(!exact); - unsigned bit_size = glsl_get_bit_size(val->type->type); + unsigned bit_size = glsl_get_bit_size(src_type); nir_const_value src[3][NIR_MAX_VEC_COMPONENTS]; for (unsigned i = 0; i < count - 4; i++) { @@ -2739,8 +2750,11 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode, num_components; unsigned j = swap ? 1 - i : i; - for (unsigned c = 0; c < src_comps; c++) + for (unsigned c = 0; c < src_comps; c++) { src[j][c] = src_val->constant->values[c]; + if (bfloat_src) + src[j][c].f32 = _mesa_bfloat16_bits_to_float(src[j][c].u16); + } } /* fix up fixed size sources */ @@ -2769,6 +2783,16 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode, nir_eval_const_opcode(op, val->constant->values, num_components, bit_size, srcs, b->shader->info.float_controls_execution_mode); + + if (bfloat_dst) { + for (int i = 0; i < num_components; i++) { + /* Ensure the pad bits are zeroed by fully assigning the value. */ + const uint16_t b = + _mesa_float_to_bfloat16_bits_rte(val->constant->values[i].f32); + val->constant->values[i] = (nir_const_value){ .u16 = b }; + } + } + break; } /* default */ }