# Source code for quantize.quantize

"""
Implementation of quantization functions for converting between floating point and int4 values.
"""

import numpy as np
from typing import List, Union, Tuple


def quantize_to_int4(
    values: Union[List[float], np.ndarray], scale_method: str = "minmax"
) -> Tuple[np.ndarray, float, float]:
    """
    Quantize floating point values to int4 values (4-bit integers).

    Int4 values range from -8 to 7 (16 distinct values). Quantization is
    symmetric: the returned zero_point is always 0 and each value is mapped to
    ``clip(round(value / scale), -8, 7)``.

    Args:
        values: List or array of floating point values to quantize.
        scale_method: Method to determine the scaling factor:
            - 'minmax': the smallest scale at which both the minimum and the
              maximum of ``values`` fit the int4 range (max -> at most 7,
              min -> at least -8).
            - 'absmax': the largest absolute value maps to 7.

    Returns:
        Tuple of (quantized_values, scale, zero_point)
        - quantized_values: numpy array of int4 values (stored as int8), same
          shape as ``values``
        - scale: scaling factor used for quantization (1.0 when every value is
          zero or the input is empty)
        - zero_point: zero point offset (always 0 for this symmetric scheme)

    Raises:
        ValueError: If ``scale_method`` is not 'minmax' or 'absmax'.
    """
    values = np.asarray(values, dtype=np.float32)

    if scale_method not in ("minmax", "absmax"):
        raise ValueError(f"Unknown scale_method: {scale_method}")

    zero_point = 0  # symmetric quantization: no offset is applied

    if values.size == 0:
        # min()/max() have no identity on empty arrays; there is nothing to scale.
        return values.astype(np.int8), 1.0, zero_point

    if scale_method == "minmax":
        # With zero_point fixed at 0, the affine formula (max - min) / 15 only
        # fits data centred on zero; any other range was clipped or overflowed.
        # Pick the scale that maps max to <= 7 and min to >= -8 instead.
        data_min = values.min()
        data_max = values.max()
        scale = max(data_max / 7, -data_min / 8)
    else:  # "absmax"
        # Map the absolute max value to 7, the largest positive int4 value.
        scale = np.max(np.abs(values)) / 7

    # Avoid division by zero when every value is 0.
    if scale == 0:
        scale = 1.0

    # Clip while still floating point so out-of-range values saturate at the
    # int4 bounds instead of wrapping around during the int8 cast.
    quantized = np.clip(np.round(values / scale), -8, 7).astype(np.int8)
    return quantized, scale, zero_point


def dequantize_from_int4(
    quantized_values: np.ndarray, scale: float, zero_point: float = 0
) -> np.ndarray:
    """
    Dequantize int4 values back to floating point.

    Computes ``(quantized_values - zero_point) * scale``. With the default
    zero_point of 0 this is simply ``quantized_values * scale``.

    Args:
        quantized_values: Array (or list) of quantized int4 values (stored as int8).
        scale: Scaling factor used during quantization.
        zero_point: Zero point offset used during quantization (0 for
            symmetric quantization).

    Returns:
        Array of dequantized floating point values (float32 unless ``scale``
        promotes the result). The input array is not modified.
    """
    # astype() always copies, so the in-place subtraction below never touches
    # the caller's array.
    dequantized = np.asarray(quantized_values).astype(np.float32)
    # The zero_point parameter was previously accepted but ignored. Subtract it
    # in place so the float32 dtype is kept; subtracting 0 leaves values as-is.
    dequantized -= zero_point
    return dequantized * scale


def pack_int4_to_int8(int4_values: np.ndarray) -> np.ndarray:
    """
    Pack two int4 values into each int8 value to save memory.

    The input is flattened (C order) first. If it has an odd number of
    elements, a single 0 is appended so it splits evenly into pairs.

    Args:
        int4_values: Array (or list) of int4 values in [-8, 7] (typically
            stored as int8). Only the low 4 bits of each value are kept.

    Returns:
        1-D int8 array of length ceil(n / 2). Element i holds input value 2*i
        in its lower 4 bits and input value 2*i + 1 in its upper 4 bits.
    """
    # Flatten up front: len() only measures axis 0, and np.pad with
    # pad_width (0, 1) pads *every* axis of an N-D array, so padding the
    # original array would insert spurious elements.
    flat = np.asarray(int4_values).ravel()
    if flat.size % 2 != 0:
        flat = np.pad(flat, (0, 1), "constant")

    pairs = flat.reshape(-1, 2)
    # First value of each pair goes in the lower 4 bits, second in the upper 4
    # bits; the int8 cast keeps the resulting bit pattern (0..255 wraps).
    packed = (pairs[:, 0] & 0xF) | ((pairs[:, 1] & 0xF) << 4)
    return packed.astype(np.int8)


def unpack_int8_to_int4(packed_values: np.ndarray) -> np.ndarray:
    """
    Unpack int8 values back into int4 values.

    Each packed byte yields two signed int4 values: its lower 4 bits first,
    then its upper 4 bits.

    Args:
        packed_values: Array of packed int8 values.

    Returns:
        int8 array of unpacked int4 values in [-8, 7], twice the length of
        the input.
    """

    def _sign_extend(nibbles):
        # Nibbles 8..15 are the two's-complement encodings of -8..-1.
        return np.where(nibbles > 7, nibbles - 16, nibbles)

    low_nibbles = _sign_extend(packed_values & 0xF)
    high_nibbles = _sign_extend((packed_values >> 4) & 0xF)

    # Interleave: even slots take the low nibble, odd slots the high nibble.
    result = np.empty(2 * len(packed_values), dtype=np.int8)
    result[::2], result[1::2] = low_nibbles, high_nibbles
    return result


# Example usage function
def example():
    """
    Example demonstrating the quantization process.

    Quantizes a small fixed array to int4, packs it two values per byte,
    unpacks it, and dequantizes the unpacked values, printing each stage and
    the resulting quantization error.

    Returns:
        Tuple of (float_values, quantized, dequantized).
    """
    # Example floating point values
    float_values = np.array([0.1, 0.5, -1.3, 2.7, 3.9, -0.8, 5.2, -4.7])
    print("Original values:", float_values)

    # Quantize to int4
    quantized, scale, zero_point = quantize_to_int4(float_values)
    print("Quantized values (int4):", quantized)
    print("Scale factor:", scale)

    # Pack int4 values into int8 for storage efficiency
    packed = pack_int4_to_int8(quantized)
    print("Packed values (int8):", packed)
    print("Memory usage reduced by 50%")

    # Unpack back to int4, trimming the 0 that packing appends for odd lengths
    unpacked = unpack_int8_to_int4(packed)[: len(float_values)]
    print("Unpacked values (int4):", unpacked)

    # Dequantize the values that came back out of packed storage, so the
    # example exercises the full round trip, and pass along the zero_point
    # that quantize_to_int4 returned instead of silently dropping it.
    dequantized = dequantize_from_int4(unpacked, scale, zero_point)
    print("Dequantized values:", dequantized)

    # Calculate error
    error = float_values - dequantized
    print("Quantization error:", error)
    print("Mean absolute error:", np.mean(np.abs(error)))

    return float_values, quantized, dequantized


if __name__ == "__main__":
    # Run the demonstration only when executed as a script, not on import.
    example()