# Copyright (c) 2024 Graphcore Ltd. All rights reserved.
# Block floating point formats
# https://en.wikipedia.org/wiki/Block_floating_point
from dataclasses import dataclass
from typing import Callable, Iterable
import numpy as np
import numpy.typing as npt
from .decode import decode_float
from .round import RoundMode, encode_float, round_float
from .types import FormatInfo
[docs]
def decode_block(fi: BlockFormatInfo, block: Iterable[int]) -> Iterable[float]:
    """
    Decode a :paramref:`block` of integer codepoints in Block Format :paramref:`fi`

    The scale is encoded in the first value of :paramref:`block`,
    with the remaining values encoding the block elements.
    Each element is decoded in the element format `fi.etype` and multiplied
    by the scale, which is decoded in the scale format `fi.stype`.

    The size of the iterable is not checked against the format descriptor.

    This is a generator: values are decoded lazily, so any error (including
    the empty-block error below) surfaces when iteration starts, not when
    the function is called.

    Args:
      fi (BlockFormatInfo): Describes the block format
      block (Iterable[int]): Input block

    Returns:
      An iterator of floats representing the encoded values.

    Raises:
      ValueError: :paramref:`block` is empty, so it has no scale encoding.
    """
    it = iter(block)
    try:
        scale_encoding = next(it)
    except StopIteration:
        # A StopIteration escaping a generator body is converted into an
        # opaque RuntimeError (PEP 479); report the actual problem instead.
        raise ValueError(
            "decode_block: block is empty (no scale encoding)"
        ) from None

    scale = decode_float(fi.stype, scale_encoding).fval
    for val_encoding in it:
        yield scale * decode_float(fi.etype, val_encoding).fval

    # TODO: Assert length of block was k+1? Messy unless block is len()able
[docs]
def encode_block(
    fi: BlockFormatInfo,
    scale: float,
    vals: Iterable[float],
    round: RoundMode = RoundMode.TiesToEven,
) -> Iterable[int]:
    """
    Encode float :paramref:`vals` into block Format described by :paramref:`fi`

    The :paramref:`scale` is explicitly passed, and the :paramref:`vals` are
    assumed to already be multiplied by `1/scale`.
    That is, this is pure encoding, scaling is computed and applied elsewhere
    (see e.g. :func:`quantize_block`).

    The scale is checked for overflow in the scale format `fi.stype`, and a
    `ValueError` is raised immediately, when this function is called, if it
    is out of range. Elements are rounded with saturation enabled, so
    out-of-range elements are saturated.

    The encodings themselves are produced lazily as the returned iterator
    is consumed.

    Args:
      fi (BlockFormatInfo): Describes the target block format
      scale (float): Scale to be recorded in the block
      vals (Iterable[float]): Input block
      round (RoundMode): Rounding mode to use, defaults to `TiesToEven`

    Returns:
      An iterator of ints: the encoding of :paramref:`scale`, followed by
      one encoding per element of :paramref:`vals`.

    Raises:
      ValueError: The scale overflows the target scale encoding format.
    """
    # Validate here, outside any generator, so the error is raised at call
    # time as documented rather than deferred to the first `next()`.
    # NOTE: a NaN scale compares False both ways and so passes this check.
    if scale > fi.stype.max or scale < fi.stype.min:
        raise ValueError(f"Scaled {scale} out of range for {fi.stype}")

    sat = True  # Saturate elements if out of range

    def enc(ty: FormatInfo, x: float) -> int:
        return encode_float(ty, round_float(ty, x, round, sat))

    def encodings() -> Iterable[int]:
        # Scale first, then the elements, in input order.
        yield enc(fi.stype, scale)
        for val in vals:
            yield enc(fi.etype, val)

    return encodings()
# Type of a scale-computation callback such as compute_scale_amax:
# called as ``compute_scale(emax, vals)``, it returns the scale by which
# the block values are divided before encoding.
ComputeScaleCallable = Callable[
    [float, npt.ArrayLike],  # (emax, vals)
    float,  # scale
]
[docs]
def compute_scale_amax(emax: float, vals: npt.ArrayLike) -> float:
    """
    Compute a power-of-two scale factor from the absolute maximum of a block.

    The scale is chosen so that the largest exponent in the array
    `vals / scale` is :paramref:`emax`, i.e.
    `scale = 2**(floor(log2(max(abs(vals)))) - emax)`.
    The largest magnitude in `vals / scale` therefore lies in
    `[2**emax, 2**(emax+1))`, unless the scale was clipped.

    The exponent of the scale is clipped to the range [-127, 127], so the
    scale lies in `[2**-127, 2**127]`.

    If all values are zero (or :paramref:`vals` is empty), any scale would
    represent the block exactly, but returning the smallest possible scale,
    `2**-127`, means that quick checks on the magnitude to identify
    near-zero blocks will also find the all-zero blocks.

    Args:
      emax (float): Maximum exponent to appear in `vals / scale`
      vals (ArrayLike): Input block

    Returns:
      A float such that `vals / scale` has exponents less than or equal to
      `emax` (subject to clipping). If :paramref:`vals` contains NaN, the
      result is NaN.

    Note:
      If all vals are zero, or vals is empty, `2**-127` is returned.
    """
    # `initial=0.0` makes an empty block behave like an all-zero block
    # instead of raising on a zero-size reduction. Absolute values are >= 0,
    # so it never changes the result for non-empty input.
    amax = np.max(np.abs(vals), initial=0.0)
    if amax == 0.0:
        q_log2scale = -127.0
    else:
        q_log2scale = np.floor(np.log2(amax)) - emax
        q_log2scale = np.clip(q_log2scale, -127.0, 127.0)
    return 2.0**q_log2scale
[docs]
def quantize_block(
    fi: BlockFormatInfo,
    vals: npt.NDArray[np.float64],
    compute_scale: ComputeScaleCallable = compute_scale_amax,
    round: RoundMode = RoundMode.TiesToEven,
) -> npt.NDArray[np.float64]:
    """
    Encode and decode a block of float :paramref:`vals` in the
    block format described by :paramref:`fi`

    The scale is computed by :paramref:`compute_scale` from the element
    format's `emax` and the input values. The values are divided by the
    scale, encoded with :func:`encode_block`, and decoded again with
    :func:`decode_block`, so the result holds the input as represented
    in the block format.

    Args:
      fi (BlockFormatInfo): Describes the target block format
      vals (ArrayLike): Input block, converted to a float64 array.
        Presumably expected to be one-dimensional, since it is encoded
        element by element.
      compute_scale ((float, ArrayLike) -> float):
        Callable to compute the scale, defaults to :func:`compute_scale_amax`
      round (RoundMode): Rounding mode to use, defaults to `TiesToEven`

    Returns:
      An array of floats representing the quantized values.

    Raises:
      ValueError: The scale overflows the target scale encoding format.
    """
    # Converting up front accepts plain sequences, and performs the division
    # in float64 so a tiny scale cannot overflow a narrower input dtype
    # (e.g. float16).
    vals = np.asarray(vals, dtype=np.float64)
    q_scale = compute_scale(fi.etype.emax, vals)
    scaled_vals = vals / q_scale
    enc = encode_block(fi, q_scale, scaled_vals, round)
    return np.fromiter(decode_block(fi, enc), float)