diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c index d4e2537124d..a9e950dd4cc 100644 --- a/src/compiler/spirv/spirv_to_nir.c +++ b/src/compiler/spirv/spirv_to_nir.c @@ -40,6 +40,7 @@ #include "util/u_debug.h" #include "util/u_printf.h" #include "util/mesa-blake3.h" +#include "util/bfloat.h" #include @@ -2697,6 +2698,12 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode, const glsl_type *dst_type = val->type->type; const glsl_type *src_type = dst_type; + const bool bfloat_dst = glsl_type_is_bfloat_16(dst_type); + bool bfloat_src = bfloat_dst; + + if (bfloat_dst) + dst_type = glsl_float_type(); + unsigned num_components = glsl_get_vector_elements(val->type->type); vtn_assert(count <= 7); @@ -2704,10 +2711,14 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode, switch (opcode) { case SpvOpSConvert: case SpvOpFConvert: - case SpvOpUConvert: + case SpvOpUConvert: { /* We have a different source type in a conversion. */ src_type = vtn_get_value_type(b, w[4])->type; + bfloat_src = glsl_type_is_bfloat_16(src_type); + if (bfloat_src) + src_type = glsl_float_type(); break; + } default: break; }; @@ -2721,7 +2732,7 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode, */ assert(!exact); - unsigned bit_size = glsl_get_bit_size(val->type->type); + unsigned bit_size = glsl_get_bit_size(src_type); nir_const_value src[3][NIR_MAX_VEC_COMPONENTS]; for (unsigned i = 0; i < count - 4; i++) { @@ -2739,8 +2750,11 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode, num_components; unsigned j = swap ? 1 - i : i; - for (unsigned c = 0; c < src_comps; c++) + for (unsigned c = 0; c < src_comps; c++) { src[j][c] = src_val->constant->values[c]; + if (bfloat_src) + src[j][c].f32 = _mesa_bfloat16_bits_to_float(src[j][c].u16); + } } /* fix up fixed size sources */ @@ -2769,6 +2783,16 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode, nir_eval_const_opcode(op, val->constant->values, num_components, bit_size, srcs, b->shader->info.float_controls_execution_mode); + + if (bfloat_dst) { + for (int i = 0; i < num_components; i++) { + /* Ensure the pad bits are zeroed by fully assigning the value. */ + const uint16_t b = + _mesa_float_to_bfloat16_bits_rte(val->constant->values[i].f32); + val->constant->values[i] = (nir_const_value){ .u16 = b }; + } + } + break; } /* default */ }