Files
mesa/src/intel/executor/examples/mod/fp.lua

129 lines
3.6 KiB
Lua

-- Copyright © 2025 Intel Corporation
-- SPDX-License-Identifier: MIT
--
-- Encode and decode floating point types.
--
-- Just enough to get basic usage of the examples. If this get serious might
-- be a good idea to implement it using the existing Mesa routines and exposing
-- as part of executor API.
local M = {}
M.decode_float = function(bits, mantissa_bits, exponent_bits, bias)
local total_bits = 1 + exponent_bits + mantissa_bits
local sign_mask = 1 << (total_bits - 1)
local exponent_mask = ((1 << exponent_bits) - 1) << mantissa_bits
local mantissa_mask = (1 << mantissa_bits) - 1
local sign = (bits & sign_mask) ~= 0
local exponent = (bits & exponent_mask) >> mantissa_bits
local mantissa = bits & mantissa_mask
if exponent == 0 then
if mantissa == 0 then
return sign and "-0.0" or "0.0"
else
-- Subnormal number. They don't have implicit leading 1, so
-- the number corresponds to "0.mantissa * 2^(1-bias)".
local value = mantissa / (1 << mantissa_bits)
value = value * (2 ^ (1 - bias))
if sign then value = -value end
return string.format("%.17g", value)
end
elseif exponent == (1 << exponent_bits) - 1 then
if mantissa == 0 then
return sign and "-inf" or "inf"
else
return "nan"
end
else
-- Normal numbers have implicit leading 1, so the
-- number corresponds to "1.mantissa * 2^(exponent-bias)".
local value = 1.0 + (mantissa / (1 << mantissa_bits))
value = value * (2 ^ (exponent - bias))
if sign then value = -value end
return string.format("%.17g", value)
end
end
M.decode_f32 = function(bits)
return M.decode_float(bits, 23, 8, 127)
end
M.decode_f16 = function(bits)
return M.decode_float(bits, 10, 5, 15)
end
M.decode_bf16 = function(bits)
return M.decode_float(bits, 7, 8, 127)
end
M.encode_float = function(value_str, mantissa_bits, exponent_bits, bias)
local value = tonumber(value_str)
if not value then
return nil
end
local total_bits = 1 + exponent_bits + mantissa_bits
local max_exponent = (1 << exponent_bits) - 1
local sign_bit = value < 0 and (1 << (total_bits - 1)) or 0
local signed_inf = sign_bit | max_exponent << mantissa_bits
-- Handle various special cases first: signed zero, NaN, Inf/-Inf.
if value == 0.0 then
return sign_bit
elseif value ~= value then
return (1 << (total_bits - 1)) | (max_exponent << mantissa_bits) | 1
elseif value == math.huge or value == -math.huge then
return signed_inf
end
-- Do math with the absolute value from now on.
if sign_bit ~= 0 then value = -value end
local exponent = math.floor(math.log(value) / math.log(2))
local mantissa_value = value / (2 ^ exponent) - 1.0
exponent = exponent + bias
if exponent <= 0 then
-- Subnormal: no implicit leading 1, use minimum exponent
mantissa_value = value / (2 ^ (1 - bias))
exponent = 0
elseif exponent >= max_exponent then
-- Value too large to represent.
return signed_inf
end
local mantissa = math.floor(mantissa_value * (1 << mantissa_bits) + 0.5)
if mantissa >= (1 << mantissa_bits) then
-- Rounding caused mantissa to overflow, increment exponent.
mantissa = 0
exponent = exponent + 1
if exponent >= max_exponent then
-- Value too large to represent.
return signed_inf
end
end
return sign_bit | (exponent << mantissa_bits) | mantissa
end
M.encode_f32 = function(value_str)
return M.encode_float(value_str, 23, 8, 127)
end
M.encode_f16 = function(value_str)
return M.encode_float(value_str, 10, 5, 15)
end
M.encode_bf16 = function(value_str)
return M.encode_float(value_str, 7, 8, 127)
end
return M