Reviewed-by: Ian Romanick <ian.d.romanick@intel.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/37805>
129 lines
3.6 KiB
Lua
129 lines
3.6 KiB
Lua
-- Copyright © 2025 Intel Corporation
|
|
-- SPDX-License-Identifier: MIT
|
|
--
|
|
-- Encode and decode floating point types.
|
|
--
|
|
-- Just enough to get basic usage of the examples. If this get serious might
|
|
-- be a good idea to implement it using the existing Mesa routines and exposing
|
|
-- as part of executor API.
|
|
|
|
local M = {}
|
|
|
|
M.decode_float = function(bits, mantissa_bits, exponent_bits, bias)
|
|
local total_bits = 1 + exponent_bits + mantissa_bits
|
|
|
|
local sign_mask = 1 << (total_bits - 1)
|
|
local exponent_mask = ((1 << exponent_bits) - 1) << mantissa_bits
|
|
local mantissa_mask = (1 << mantissa_bits) - 1
|
|
|
|
local sign = (bits & sign_mask) ~= 0
|
|
local exponent = (bits & exponent_mask) >> mantissa_bits
|
|
local mantissa = bits & mantissa_mask
|
|
|
|
if exponent == 0 then
|
|
if mantissa == 0 then
|
|
return sign and "-0.0" or "0.0"
|
|
else
|
|
-- Subnormal number. They don't have implicit leading 1, so
|
|
-- the number corresponds to "0.mantissa * 2^(1-bias)".
|
|
local value = mantissa / (1 << mantissa_bits)
|
|
value = value * (2 ^ (1 - bias))
|
|
if sign then value = -value end
|
|
return string.format("%.17g", value)
|
|
end
|
|
elseif exponent == (1 << exponent_bits) - 1 then
|
|
if mantissa == 0 then
|
|
return sign and "-inf" or "inf"
|
|
else
|
|
return "nan"
|
|
end
|
|
else
|
|
-- Normal numbers have implicit leading 1, so the
|
|
-- number corresponds to "1.mantissa * 2^(exponent-bias)".
|
|
local value = 1.0 + (mantissa / (1 << mantissa_bits))
|
|
value = value * (2 ^ (exponent - bias))
|
|
if sign then value = -value end
|
|
return string.format("%.17g", value)
|
|
end
|
|
end
|
|
|
|
M.decode_f32 = function(bits)
|
|
return M.decode_float(bits, 23, 8, 127)
|
|
end
|
|
|
|
M.decode_f16 = function(bits)
|
|
return M.decode_float(bits, 10, 5, 15)
|
|
end
|
|
|
|
M.decode_bf16 = function(bits)
|
|
return M.decode_float(bits, 7, 8, 127)
|
|
end
|
|
|
|
M.encode_float = function(value_str, mantissa_bits, exponent_bits, bias)
|
|
local value = tonumber(value_str)
|
|
if not value then
|
|
return nil
|
|
end
|
|
|
|
local total_bits = 1 + exponent_bits + mantissa_bits
|
|
local max_exponent = (1 << exponent_bits) - 1
|
|
|
|
local sign_bit = value < 0 and (1 << (total_bits - 1)) or 0
|
|
local signed_inf = sign_bit | max_exponent << mantissa_bits
|
|
|
|
-- Handle various special cases first: signed zero, NaN, Inf/-Inf.
|
|
if value == 0.0 then
|
|
return sign_bit
|
|
elseif value ~= value then
|
|
return (1 << (total_bits - 1)) | (max_exponent << mantissa_bits) | 1
|
|
elseif value == math.huge or value == -math.huge then
|
|
return signed_inf
|
|
end
|
|
|
|
-- Do math with the absolute value from now on.
|
|
if sign_bit ~= 0 then value = -value end
|
|
|
|
local exponent = math.floor(math.log(value) / math.log(2))
|
|
|
|
local mantissa_value = value / (2 ^ exponent) - 1.0
|
|
|
|
exponent = exponent + bias
|
|
|
|
if exponent <= 0 then
|
|
-- Subnormal: no implicit leading 1, use minimum exponent
|
|
mantissa_value = value / (2 ^ (1 - bias))
|
|
exponent = 0
|
|
elseif exponent >= max_exponent then
|
|
-- Value too large to represent.
|
|
return signed_inf
|
|
end
|
|
|
|
local mantissa = math.floor(mantissa_value * (1 << mantissa_bits) + 0.5)
|
|
|
|
if mantissa >= (1 << mantissa_bits) then
|
|
-- Rounding caused mantissa to overflow, increment exponent.
|
|
mantissa = 0
|
|
exponent = exponent + 1
|
|
if exponent >= max_exponent then
|
|
-- Value too large to represent.
|
|
return signed_inf
|
|
end
|
|
end
|
|
|
|
return sign_bit | (exponent << mantissa_bits) | mantissa
|
|
end
|
|
|
|
M.encode_f32 = function(value_str)
|
|
return M.encode_float(value_str, 23, 8, 127)
|
|
end
|
|
|
|
M.encode_f16 = function(value_str)
|
|
return M.encode_float(value_str, 10, 5, 15)
|
|
end
|
|
|
|
M.encode_bf16 = function(value_str)
|
|
return M.encode_float(value_str, 7, 8, 127)
|
|
end
|
|
|
|
return M
|