diff --git a/burn-book/src/SUMMARY.md b/burn-book/src/SUMMARY.md
index 3f4fcad94..6ee585637 100644
--- a/burn-book/src/SUMMARY.md
+++ b/burn-book/src/SUMMARY.md
@@ -25,6 +25,7 @@
   - [ONNX Model](./import/onnx-model.md)
   - [PyTorch Model](./import/pytorch-model.md)
 - [Models & Pre-Trained Weights](./models-and-pretrained-weights.md)
+- [Quantization (Beta)](./quantization.md)
 - [Advanced](./advanced/README.md)
   - [Backend Extension](./advanced/backend-extension/README.md)
     - [Custom WGPU Kernel](./advanced/backend-extension/custom-wgpu-kernel.md)
diff --git a/burn-book/src/quantization.md b/burn-book/src/quantization.md
new file mode 100644
index 000000000..f69f3f7db
--- /dev/null
+++ b/burn-book/src/quantization.md
@@ -0,0 +1,122 @@
+# Quantization (Beta)
+
+Quantization techniques perform computations and store tensors in lower precision data types like
+8-bit integer instead of floating point precision. There are multiple approaches to quantizing a
+deep learning model, categorized as:
+
+- Post-training quantization (PTQ)
+- Quantization aware training (QAT)
+
+In post-training quantization, the model is trained in floating point precision and later converted
+to the lower precision data type.
+
+There are two types of post-training quantization:
+
+1. Static quantization: quantizes the weights and activations of the model. Quantizing the
+   activations statically requires calibration with representative data (i.e., recording the
+   activation values to compute the optimal quantization parameters).
+1. Dynamic quantization: quantizes the weights ahead of time (like static quantization), but the
+   activations are quantized dynamically at runtime.
+
+Sometimes post-training quantization is not able to achieve acceptable task accuracy. This is where
+quantization aware training comes into play, as it models the effects of quantization during
+training. Quantization errors are thus modeled in the forward and backward passes using fake
+quantization modules, which helps the model learn representations that are more robust to the
+reduction in precision.
+
+<div class="warning">
+
+Quantization support in Burn is currently in active development.
+
+It supports the following modes on some backends:
+
+- Static per-tensor quantization to signed 8-bit integer (`i8`)
+
+No integer operations are currently supported, which means tensors are dequantized to perform the
+operations in floating point precision.
+
+</div>
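+
+To make this concrete, the example below sketches what per-tensor symmetric quantization to `i8`
+looks like outside of Burn: a scale is derived from the largest absolute value, the values are
+rounded to integers, and they are dequantized back to floating point before any computation. The
+helper functions and the `[-127, 127]` integer range are illustrative only, not part of Burn's API.
+
+```rust , ignore
+// Illustrative standalone sketch (not Burn's API): per-tensor symmetric `i8` quantization.
+fn quantize_symmetric_i8(values: &[f32]) -> (Vec<i8>, f32) {
+    // The scale maps the largest absolute value onto the integer bound 127.
+    let max_abs = values.iter().fold(0.0f32, |acc, v| acc.max(v.abs()));
+    let scale = max_abs / 127.0;
+    let quantized = values
+        .iter()
+        .map(|v| (v / scale).round().clamp(-127.0, 127.0) as i8)
+        .collect();
+    (quantized, scale)
+}
+
+fn dequantize_i8(values: &[i8], scale: f32) -> Vec<f32> {
+    values.iter().map(|&v| v as f32 * scale).collect()
+}
+
+fn main() {
+    let weights = [-1.8f32, -1.0, 0.0, 0.5];
+    let (quantized, scale) = quantize_symmetric_i8(&weights);
+    // Since no integer operations are supported yet, the values are dequantized
+    // (with a small rounding error) before being used in floating point computations.
+    let dequantized = dequantize_i8(&quantized, scale);
+    println!("scale = {scale}, quantized = {quantized:?}, dequantized = {dequantized:?}");
+}
+```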
+
+## Module Quantization
+
+Quantizing the weights of your model after training is quite simple. We have access to the weight
+tensors and can collect their statistics, such as the min and max value when using
+`MinMaxCalibration`, to compute the quantization parameters.
+
+```rust , ignore
+# use burn::quantization::{MinMaxCalibration, QuantizationScheme, QuantizationType, Quantizer};
+#
+// Quantization config
+let mut quantizer = Quantizer {
+    calibration: MinMaxCalibration {
+        scheme: QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8),
+    },
+};
+
+// Quantize the weights
+let model = model.quantize_weights(&mut quantizer);
+```
+
+> Given that all operations are currently performed in floating point precision, it might be wise to
+> dequantize the module parameters before inference. This allows us to save disk space by storing
+> the model in reduced precision while preserving the inference speed.
+>
+> This can easily be implemented with a `ModuleMapper`.
+>
+> ```rust, ignore
+> # use burn::module::{ModuleMapper, ParamId};
+> # use burn::tensor::{backend::Backend, Tensor};
+> #
+> /// Module mapper used to dequantize the model params being loaded.
+> pub struct Dequantize {}
+>
+> impl<B: Backend> ModuleMapper<B> for Dequantize {
+>     fn map_float<const D: usize>(
+>         &mut self,
+>         _id: &ParamId,
+>         tensor: Tensor<B, D>,
+>     ) -> Tensor<B, D> {
+>         tensor.dequantize()
+>     }
+> }
+>
+> // Load saved quantized model in floating point precision
+> model = model
+>     .load_file(file_path, recorder, &device)
+>     .expect("Should be able to load the quantized model weights")
+>     .map(&mut Dequantize {});
+> ```
+
+### Calibration
+
+Calibration is the step during quantization where the range of all floating-point tensors is
+computed. This is pretty straightforward for weights since the actual range is known at
+_quantization-time_ (weights are static), but activations require more attention.
+
+To compute the quantization parameters, Burn supports the following `Calibration` methods.
+
+| Method              | Description                                                                       |
+| :------------------ | :-------------------------------------------------------------------------------- |
+| `MinMaxCalibration` | Computes the quantization range mapping based on the running min and max values.  |
+
+### Quantization Scheme
+
+A quantization scheme defines the quantized type, quantization granularity and range mapping
+technique.
+
+Burn currently supports the following `QuantizationType` variants.
+
+| Type    | Description                        |
+| :------ | :--------------------------------- |
+| `QInt8` | 8-bit signed integer quantization. |
+
+Quantization parameters are defined based on the range of values to represent and can typically be
+calculated for the layer's entire weight tensor with per-tensor quantization or separately for each
+channel with per-channel quantization (commonly used with CNNs).
+
+Burn currently supports the following `QuantizationScheme` variants.
+
+| Variant              | Description                                                                                                     |
+| :------------------- | :--------------------------------------------------------------------------------------------------------------- |
+| `PerTensorAffine`    | Computes the quantization parameters for the whole tensor and applies an affine range mapping with zero point.   |
+| `PerTensorSymmetric` | Computes the quantization parameters for the whole tensor and applies a scale range mapping centered around 0. 
| diff --git a/crates/burn-autodiff/src/ops/qtensor.rs b/crates/burn-autodiff/src/ops/qtensor.rs index 7dbbf5730..ac5a1a391 100644 --- a/crates/burn-autodiff/src/ops/qtensor.rs +++ b/crates/burn-autodiff/src/ops/qtensor.rs @@ -1,12 +1,19 @@ use burn_tensor::{ backend::Backend, ops::{FloatTensor, QTensorOps, QuantizedTensor}, - Device, QuantizationStrategy, Shape, + Device, QuantizationStrategy, Shape, TensorData, }; use crate::{checkpoint::strategy::CheckpointStrategy, Autodiff}; impl QTensorOps for Autodiff { + fn q_from_data( + _data: TensorData, + _device: &Device, + ) -> QuantizedTensor { + todo!() + } + fn quantize( _tensor: FloatTensor, _strategy: &QuantizationStrategy, @@ -28,4 +35,18 @@ impl QTensorOps for Autodiff { fn q_device(tensor: &QuantizedTensor) -> Device { B::q_device(tensor) } + + fn q_reshape( + tensor: QuantizedTensor, + shape: Shape, + ) -> QuantizedTensor { + B::q_reshape(tensor, shape) + } + + async fn q_into_data( + tensor: QuantizedTensor, + strategy: QuantizationStrategy, + ) -> TensorData { + B::q_into_data(tensor, strategy).await + } } diff --git a/crates/burn-candle/src/ops/qtensor.rs b/crates/burn-candle/src/ops/qtensor.rs index 3de0d23ff..7d9784cef 100644 --- a/crates/burn-candle/src/ops/qtensor.rs +++ b/crates/burn-candle/src/ops/qtensor.rs @@ -1,7 +1,7 @@ use burn_tensor::{ backend::Backend, ops::{FloatTensor, QTensorOps, QuantizedTensor}, - Device, QuantizationStrategy, Shape, + DType, Device, QuantizationStrategy, Shape, TensorData, }; use crate::{ @@ -10,6 +10,13 @@ use crate::{ }; impl QTensorOps for Candle { + fn q_from_data( + data: TensorData, + device: &Device, + ) -> QuantizedTensor { + unimplemented!() // no i8 support + } + fn quantize( _tensor: FloatTensor, _strategy: &QuantizationStrategy, @@ -31,4 +38,18 @@ impl QTensorOps for Candle(tensor: &QuantizedTensor) -> Device { super::base::device(tensor) } + + fn q_reshape( + tensor: QuantizedTensor, + shape: Shape, + ) -> QuantizedTensor { + super::base::reshape(tensor, shape) + } + + async fn q_into_data( + tensor: QuantizedTensor, + strategy: QuantizationStrategy, + ) -> TensorData { + super::base::into_data(tensor) + } } diff --git a/crates/burn-core/src/lib.rs b/crates/burn-core/src/lib.rs index 4bfb8f900..42c7ef3a3 100644 --- a/crates/burn-core/src/lib.rs +++ b/crates/burn-core/src/lib.rs @@ -33,6 +33,9 @@ pub mod module; /// Neural network module. pub mod nn; +/// Quantization module. +pub mod quantization; + /// Module for the recorder. pub mod record; diff --git a/crates/burn-core/src/module/base.rs b/crates/burn-core/src/module/base.rs index c34c155e8..0c3895368 100644 --- a/crates/burn-core/src/module/base.rs +++ b/crates/burn-core/src/module/base.rs @@ -1,5 +1,6 @@ use super::ParamId; use crate::{ + quantization::{Calibration, Quantizer}, record::Record, tensor::backend::{AutodiffBackend, Backend}, }; @@ -202,6 +203,11 @@ pub trait Module: Clone + Send + core::fmt::Debug { Ok(self.load_record(record)) } + + /// Quantize the weights of the module. + fn quantize_weights(self, quantizer: &mut Quantizer) -> Self { + self.map(quantizer) + } } /// Module visitor trait. 
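The expected `scale` and `offset` values in the `MinMaxCalibration` tests added in `calibration.rs` below can be reproduced with the usual per-tensor formulas. The following standalone sketch is illustrative only: the `[-128, 127]` affine range and `[-127, 127]` symmetric range are assumptions that match the test expectations, not code taken from Burn's implementation.

```rust
fn main() {
    // Range observed by min-max calibration on [-1.8, -1.0, 0.0, 0.5].
    let (min, max) = (-1.8f32, 0.5f32);

    // Per-tensor affine: map [min, max] onto [-128, 127] with a zero point (offset).
    let affine_scale = (max - min) / 255.0;
    let offset = (-128.0 - (min / affine_scale).round()) as i32;
    println!("affine: scale = {affine_scale}, offset = {offset}"); // ~0.009019608, 72

    // Per-tensor symmetric: map the largest absolute value onto 127, no offset.
    let symmetric_scale = min.abs().max(max.abs()) / 127.0;
    println!("symmetric: scale = {symmetric_scale}"); // ~0.014173228
}
```
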
diff --git a/crates/burn-core/src/quantization/calibration.rs b/crates/burn-core/src/quantization/calibration.rs new file mode 100644 index 000000000..81e466d20 --- /dev/null +++ b/crates/burn-core/src/quantization/calibration.rs @@ -0,0 +1,80 @@ +use burn_tensor::{ + backend::Backend, AffineQuantization, ElementConversion, Quantization, QuantizationStrategy, + SymmetricQuantization, Tensor, +}; + +use super::{QuantizationScheme, QuantizationType}; + +/// Calibration method used to compute the quantization range mapping. +pub trait Calibration { + /// Configure the quantization strategy. + fn configure(&self, tensor: &Tensor) -> QuantizationStrategy; +} + +/// Computes the quantization range mapping based on the running min and max values. +pub struct MinMaxCalibration { + /// Quantization scheme to be used. + pub scheme: QuantizationScheme, +} + +impl Calibration for MinMaxCalibration { + fn configure(&self, tensor: &Tensor) -> QuantizationStrategy { + let min = tensor.clone().min().into_scalar().elem::(); + let max = tensor.clone().max().into_scalar().elem::(); + + match &self.scheme { + QuantizationScheme::PerTensorAffine(dtype) => match dtype { + QuantizationType::QInt8 => { + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::new(min, max)) + } + }, + QuantizationScheme::PerTensorSymmetric(dtype) => match dtype { + QuantizationType::QInt8 => QuantizationStrategy::PerTensorSymmetricInt8( + SymmetricQuantization::new(min, max), + ), + }, + } + } +} + +#[cfg(test)] +mod tests { + + use super::*; + use crate::TestBackend; + + #[test] + fn min_max_calibration_per_tensor_affine_int8() { + let device = ::Device::default(); + let tensor = Tensor::::from_floats([-1.8, -1.0, 0.0, 0.5], &device); + let calibration = MinMaxCalibration { + scheme: QuantizationScheme::PerTensorAffine(QuantizationType::QInt8), + }; + + let strategy = calibration.configure(&tensor); + + if let QuantizationStrategy::PerTensorAffineInt8(q) = strategy { + assert_eq!(q.scale, 0.009_019_608); + assert_eq!(q.offset, 72); + } else { + panic!("Wrong quantization strategy"); + } + } + + #[test] + fn min_max_calibration_per_tensor_symmetric_int8() { + let device = ::Device::default(); + let tensor = Tensor::::from_floats([-1.8, -1.0, 0.0, 0.5], &device); + let calibration = MinMaxCalibration { + scheme: QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8), + }; + + let strategy = calibration.configure(&tensor); + + if let QuantizationStrategy::PerTensorSymmetricInt8(q) = strategy { + assert_eq!(q.scale, 0.014_173_228); + } else { + panic!("Wrong quantization strategy"); + } + } +} diff --git a/crates/burn-core/src/quantization/mod.rs b/crates/burn-core/src/quantization/mod.rs new file mode 100644 index 000000000..38b3bb604 --- /dev/null +++ b/crates/burn-core/src/quantization/mod.rs @@ -0,0 +1,7 @@ +mod calibration; +mod quantize; +mod scheme; + +pub use calibration::*; +pub use quantize::*; +pub use scheme::*; diff --git a/crates/burn-core/src/quantization/quantize.rs b/crates/burn-core/src/quantization/quantize.rs new file mode 100644 index 000000000..040886a1a --- /dev/null +++ b/crates/burn-core/src/quantization/quantize.rs @@ -0,0 +1,18 @@ +use burn_tensor::{backend::Backend, Tensor}; + +use crate::module::{ModuleMapper, ParamId}; + +use super::Calibration; + +/// Describes how to quantize a module. +pub struct Quantizer { + /// The calibration method used in quantization. 
+ pub calibration: C, +} + +impl ModuleMapper for Quantizer { + fn map_float(&mut self, _id: &ParamId, tensor: Tensor) -> Tensor { + let strategy = self.calibration.configure(&tensor); + tensor.quantize(strategy) + } +} diff --git a/crates/burn-core/src/quantization/scheme.rs b/crates/burn-core/src/quantization/scheme.rs new file mode 100644 index 000000000..db11c22ae --- /dev/null +++ b/crates/burn-core/src/quantization/scheme.rs @@ -0,0 +1,17 @@ +/// Quantization data type. +pub enum QuantizationType { + /// 8-bit signed integer. + QInt8, +} + +/// Quantization scheme. +pub enum QuantizationScheme { + /// Per-tensor affine/asymmetric quantization. + PerTensorAffine(QuantizationType), + /// Per-tensor symmetric quantization. + PerTensorSymmetric(QuantizationType), + // /// Per-channel affine/asymmetric quantization. + // PerChannelAffine, + // /// Per-channel symmetric quantization. + // PerChannelSymmetric, +} diff --git a/crates/burn-core/src/record/tensor.rs b/crates/burn-core/src/record/tensor.rs index 904210359..ab6f448b7 100644 --- a/crates/burn-core/src/record/tensor.rs +++ b/crates/burn-core/src/record/tensor.rs @@ -1,7 +1,7 @@ use core::marker::PhantomData; use super::{PrecisionSettings, Record}; -use burn_tensor::{backend::Backend, Bool, Element, Int, Tensor, TensorData}; +use burn_tensor::{backend::Backend, Bool, DType, Element, Int, Tensor, TensorData}; use serde::{Deserialize, Serialize}; #[cfg(not(feature = "record-backward-compat"))] @@ -43,7 +43,12 @@ where e )) })?; - Ok(data.convert::()) + let data = if let DType::QFloat(_) = data.dtype { + data // do not convert quantized tensors + } else { + data.convert::() + }; + Ok(data) } } @@ -137,15 +142,25 @@ impl Record for Tensor { type Item = FloatTensorSerde; fn into_item(self) -> Self::Item { - FloatTensorSerde::new(self.into_data().convert::()) + let data = self.into_data(); + let data = if let DType::QFloat(_) = data.dtype { + data // do not convert quantized tensors + } else { + data.convert::() + }; + FloatTensorSerde::new(data) } fn from_item(item: Self::Item, device: &B::Device) -> Self { - Tensor::from_data(item.data.convert::(), device) + let data = if let DType::QFloat(_) = item.data.dtype { + item.data // do not convert quantized tensors + } else { + item.data.convert::() + }; + Tensor::from_data(data, device) } } -#[allow(deprecated)] impl Record for Tensor { type Item = IntTensorSerde; @@ -158,7 +173,6 @@ impl Record for Tensor { } } -#[allow(deprecated)] impl Record for Tensor { type Item = BoolTensorSerde; diff --git a/crates/burn-fusion/src/ops/qtensor.rs b/crates/burn-fusion/src/ops/qtensor.rs index b0f340aa9..087f826fd 100644 --- a/crates/burn-fusion/src/ops/qtensor.rs +++ b/crates/burn-fusion/src/ops/qtensor.rs @@ -1,12 +1,19 @@ use burn_tensor::{ backend::Backend, ops::{QTensorOps, QuantizedTensor}, - Device, QuantizationStrategy, Shape, + Device, QuantizationStrategy, Shape, TensorData, }; use crate::{client::FusionClient, Fusion, FusionBackend}; impl QTensorOps for Fusion { + fn q_from_data( + _data: TensorData, + _device: &Device, + ) -> QuantizedTensor { + unimplemented!() + } + fn quantize( _tensor: ::FloatTensorPrimitive, _strategy: &QuantizationStrategy, @@ -28,4 +35,18 @@ impl QTensorOps for Fusion { fn q_device(tensor: &QuantizedTensor) -> Device { tensor.client.device().clone() } + + fn q_reshape( + _tensor: QuantizedTensor, + _shape: Shape, + ) -> QuantizedTensor { + unimplemented!() + } + + async fn q_into_data( + _tensor: QuantizedTensor, + _strategy: QuantizationStrategy, + ) -> TensorData { 
+ unimplemented!() + } } diff --git a/crates/burn-jit/src/ops/qtensor.rs b/crates/burn-jit/src/ops/qtensor.rs index 76577c103..8c87c2d53 100644 --- a/crates/burn-jit/src/ops/qtensor.rs +++ b/crates/burn-jit/src/ops/qtensor.rs @@ -1,6 +1,6 @@ use burn_tensor::{ ops::{FloatTensor, QTensorOps, QuantizedTensor}, - Device, QuantizationStrategy, Shape, + Device, QuantizationStrategy, Shape, TensorData, }; use crate::{FloatElement, IntElement, JitBackend, JitRuntime}; @@ -11,6 +11,13 @@ where F: FloatElement, I: IntElement, { + fn q_from_data( + _data: TensorData, + _device: &Device, + ) -> QuantizedTensor { + todo!() + } + fn quantize( _tensor: FloatTensor, _strategy: &QuantizationStrategy, @@ -32,4 +39,18 @@ where fn q_device(tensor: &QuantizedTensor) -> Device { tensor.device.clone() } + + fn q_reshape( + tensor: QuantizedTensor, + shape: Shape, + ) -> QuantizedTensor { + super::reshape(tensor, shape) + } + + async fn q_into_data( + _tensor: QuantizedTensor, + _strategy: QuantizationStrategy, + ) -> TensorData { + unimplemented!() + } } diff --git a/crates/burn-ndarray/src/ops/qtensor.rs b/crates/burn-ndarray/src/ops/qtensor.rs index f551c5b49..77274c321 100644 --- a/crates/burn-ndarray/src/ops/qtensor.rs +++ b/crates/burn-ndarray/src/ops/qtensor.rs @@ -1,10 +1,12 @@ use burn_tensor::{ ops::{FloatTensor, QTensorOps, QuantizedTensor}, - Quantization, QuantizationStrategy, Shape, TensorData, + DType, Quantization, QuantizationStrategy, Shape, TensorData, }; use crate::{element::NdArrayElement, FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayTensor}; +use super::NdArrayOps; + fn into_data(tensor: NdArrayTensor) -> TensorData { let shape = tensor.shape(); let values = tensor.array.into_iter().collect(); @@ -12,6 +14,28 @@ fn into_data(tensor: NdArrayTensor) -> } impl QTensorOps for NdArray { + fn q_from_data( + data: TensorData, + _device: &NdArrayDevice, + ) -> QuantizedTensor { + match data.dtype { + DType::QFloat(strategy) => match strategy { + QuantizationStrategy::PerTensorAffineInt8(_) => { + let data = data.convert::(); + NdArrayTensor::::from_data(data) + } + QuantizationStrategy::PerTensorSymmetricInt8(_) => { + let data = data.convert::(); + NdArrayTensor::::from_data(data) + } + }, + _ => panic!( + "Invalid dtype (expected DType::QFloat, got {:?})", + data.dtype + ), + } + } + fn quantize( tensor: FloatTensor, strategy: &QuantizationStrategy, @@ -41,4 +65,20 @@ impl QTensorOps for NdArray { fn q_device(_tensor: &QuantizedTensor) -> NdArrayDevice { NdArrayDevice::Cpu } + + fn q_reshape( + tensor: QuantizedTensor, + shape: Shape, + ) -> QuantizedTensor { + NdArrayOps::reshape(tensor, shape) + } + + async fn q_into_data( + tensor: QuantizedTensor, + strategy: QuantizationStrategy, + ) -> TensorData { + let shape = tensor.shape(); + let values = tensor.array.into_iter().collect(); + TensorData::quantized(values, shape, strategy) + } } diff --git a/crates/burn-tch/src/ops/base.rs b/crates/burn-tch/src/ops/base.rs index 6a09a0239..e557c4ae9 100644 --- a/crates/burn-tch/src/ops/base.rs +++ b/crates/burn-tch/src/ops/base.rs @@ -1,4 +1,4 @@ -use burn_tensor::Shape; +use burn_tensor::{QuantizationStrategy, Shape}; use tch::Scalar; use crate::{LibTorchDevice, TchShape, TchTensor}; @@ -512,4 +512,30 @@ impl TchOps { ) -> TchTensor { TchTensor::new(tensor.tensor.argsort(dim as i64, descending)) } + + pub fn quantize( + tensor: TchTensor, + strategy: &QuantizationStrategy, + ) -> TchTensor { + let mut tensor = tensor; + // Quantize only works on Float Tensor + if tensor.tensor.kind() == 
tch::Kind::Half { + tensor.tensor = tensor.tensor.to_kind(tch::Kind::Float); + } + + match strategy { + QuantizationStrategy::PerTensorAffineInt8(ref q) => { + TchTensor::new(tensor.tensor.quantize_per_tensor( + q.scale.into(), + q.offset.into(), + tch::Kind::QInt8, + )) + } + QuantizationStrategy::PerTensorSymmetricInt8(ref q) => TchTensor::new( + tensor + .tensor + .quantize_per_tensor(q.scale.into(), 0, tch::Kind::QInt8), + ), + } + } } diff --git a/crates/burn-tch/src/ops/qtensor.rs b/crates/burn-tch/src/ops/qtensor.rs index 60d317576..73238e787 100644 --- a/crates/burn-tch/src/ops/qtensor.rs +++ b/crates/burn-tch/src/ops/qtensor.rs @@ -1,15 +1,63 @@ use burn_tensor::{ ops::{FloatTensor, QTensorOps, QuantizedTensor}, - QuantizationStrategy, Shape, + DType, Quantization, QuantizationStrategy, Shape, TensorData, }; -use crate::{LibTorch, LibTorchDevice, TchElement, TchTensor}; +use crate::{LibTorch, LibTorchDevice, TchElement, TchShape, TchTensor}; + +use super::TchOps; impl QTensorOps for LibTorch { + fn q_from_data( + data: TensorData, + device: &LibTorchDevice, + ) -> QuantizedTensor { + let shape_tch = TchShape::::from(data.shape.as_slice()); + let device = (*device).into(); + + // NOTE: tch-rs doesn't have `from_blob_quantized_*` APIs + // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/quantized/Quantizer.cpp#L322 + // So for now we have to load the dequantized values to quantize them back since the dequantization + // methods take the values provided when quantizing. + let tensor = match data.dtype { + DType::QFloat(strategy) => match strategy { + QuantizationStrategy::PerTensorAffineInt8(q) => { + let values = q.dequantize(&data.iter::().collect::>()); + let tensor = tch::Tensor::from_slice(&values).to(device); + TchOps::::quantize::( + TchTensor::new(tensor.reshape(shape_tch.dims)), + &strategy, + ) + .tensor + } + QuantizationStrategy::PerTensorSymmetricInt8(q) => { + let values = q.dequantize(&data.iter::().collect::>()); + let tensor = tch::Tensor::from_slice(&values).to(device); + TchOps::::quantize::( + TchTensor::new(tensor.reshape(shape_tch.dims)), + &strategy, + ) + .tensor + } + }, + _ => panic!( + "Invalid dtype (expected DType::QFloat, got {:?})", + data.dtype + ), + }; + TchTensor::new(tensor) + } + fn quantize( tensor: FloatTensor, strategy: &QuantizationStrategy, ) -> QuantizedTensor { + let mut tensor = tensor; + // Quantize only works on Float Tensor + if E::dtype() == DType::F16 { + tensor.tensor = tensor.tensor.to_kind(tch::Kind::Float); + } + match strategy { QuantizationStrategy::PerTensorAffineInt8(ref q) => { TchTensor::new(tensor.tensor.quantize_per_tensor( @@ -30,7 +78,7 @@ impl QTensorOps for LibTorch { tensor: QuantizedTensor, _strategy: &QuantizationStrategy, ) -> FloatTensor { - TchTensor::new(tensor.tensor.dequantize()) + TchTensor::new(tensor.tensor.dequantize().to_kind(E::KIND)) } fn q_shape(tensor: &QuantizedTensor) -> Shape { @@ -40,4 +88,23 @@ impl QTensorOps for LibTorch { fn q_device(tensor: &QuantizedTensor) -> LibTorchDevice { tensor.tensor.device().into() } + + fn q_reshape( + tensor: QuantizedTensor, + shape: Shape, + ) -> QuantizedTensor { + TchOps::reshape(tensor, shape) + } + + async fn q_into_data( + tensor: QuantizedTensor, + strategy: QuantizationStrategy, + ) -> TensorData { + let shape = Self::q_shape(&tensor); + let tensor = Self::q_reshape(tensor.clone(), Shape::new([shape.num_elements()])); + // To get the integer values we have to call `int_repr()` + let values: Result, tch::TchError> = 
tensor.tensor.int_repr().try_into(); + + TensorData::quantized(values.unwrap(), shape, strategy) + } } diff --git a/crates/burn-tensor/src/tensor/api/base.rs b/crates/burn-tensor/src/tensor/api/base.rs index 0d44aa4fc..0fd16c151 100644 --- a/crates/burn-tensor/src/tensor/api/base.rs +++ b/crates/burn-tensor/src/tensor/api/base.rs @@ -18,7 +18,7 @@ use crate::check::TensorCheck; use crate::tensor::api::chunk::chunk; use crate::tensor::api::narrow::narrow; use crate::{backend::Backend, check, Bool, Float, Int, Shape, TensorData, TensorKind}; -use crate::{Element, TensorPrimitive}; +use crate::{DType, Element, TensorPrimitive}; /// A tensor with a given backend, shape and data type. #[derive(new, Clone, Debug)] @@ -1697,7 +1697,15 @@ impl BasicOps for Float { tensor: Self::Primitive, shape: Shape, ) -> Self::Primitive { - TensorPrimitive::Float(B::float_reshape(tensor.tensor(), shape)) + match tensor { + TensorPrimitive::Float(tensor) => { + TensorPrimitive::Float(B::float_reshape(tensor, shape)) + } + TensorPrimitive::QFloat { tensor, strategy } => TensorPrimitive::QFloat { + tensor: B::q_reshape(tensor, shape), + strategy, + }, + } } fn transpose(tensor: Self::Primitive) -> Self::Primitive { @@ -1750,11 +1758,20 @@ impl BasicOps for Float { } async fn into_data_async(tensor: Self::Primitive) -> TensorData { - B::float_into_data(tensor.tensor()).await + match tensor { + TensorPrimitive::Float(tensor) => B::float_into_data(tensor).await, + TensorPrimitive::QFloat { tensor, strategy } => B::q_into_data(tensor, strategy).await, + } } fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive { - TensorPrimitive::Float(B::float_from_data(data, device)) + match data.dtype { + DType::QFloat(strategy) => TensorPrimitive::QFloat { + tensor: B::q_from_data(data, device), + strategy, + }, + _ => TensorPrimitive::Float(B::float_from_data(data, device)), + } } fn repeat( diff --git a/crates/burn-tensor/src/tensor/api/float.rs b/crates/burn-tensor/src/tensor/api/float.rs index 7da52eb1b..ef320d7b8 100644 --- a/crates/burn-tensor/src/tensor/api/float.rs +++ b/crates/burn-tensor/src/tensor/api/float.rs @@ -271,9 +271,9 @@ where match &self.primitive { TensorPrimitive::Float(tensor) => B::float_is_require_grad(tensor), TensorPrimitive::QFloat { - tensor: _, + tensor, strategy: _, - } => B::float_is_require_grad(&self.primitive.clone().tensor()), + } => B::q_is_require_grad(tensor), } } @@ -282,10 +282,16 @@ where /// /// This function does nothing when autodiff is not enabled. pub fn set_require_grad(self, require_grad: bool) -> Self { - Self::new(TensorPrimitive::Float(B::float_set_require_grad( - self.primitive.tensor(), - require_grad, - ))) + let primitive = match self.primitive { + TensorPrimitive::Float(tensor) => { + TensorPrimitive::Float(B::float_set_require_grad(tensor, require_grad)) + } + TensorPrimitive::QFloat { tensor, strategy } => TensorPrimitive::QFloat { + tensor: B::q_set_require_grad(tensor, require_grad), + strategy, + }, + }; + Self::new(primitive) } /// Applies the relu function to the tensor. diff --git a/crates/burn-tensor/src/tensor/data.rs b/crates/burn-tensor/src/tensor/data.rs index f3c8b3843..edf293d74 100644 --- a/crates/burn-tensor/src/tensor/data.rs +++ b/crates/burn-tensor/src/tensor/data.rs @@ -48,6 +48,15 @@ impl TensorData { Self::init(value, shape, E::dtype()) } + /// Creates a new quantized tensor data structure. 
+ pub fn quantized>>( + value: Vec, + shape: S, + strategy: QuantizationStrategy, + ) -> Self { + Self::init(value, shape, DType::QFloat(strategy)) + } + /// Initializes a new tensor data structure from the provided values. fn init>>(value: Vec, shape: S, dtype: DType) -> Self { Self { @@ -258,15 +267,15 @@ impl TensorData { "Only f32 data type can be quantized" ); match &quantization { - QuantizationStrategy::PerTensorAffineInt8(strategy) => TensorData::init( + QuantizationStrategy::PerTensorAffineInt8(strategy) => TensorData::quantized( strategy.quantize(self.as_slice().unwrap()), self.shape, - DType::QFloat(quantization), + quantization, ), - QuantizationStrategy::PerTensorSymmetricInt8(strategy) => TensorData::init( + QuantizationStrategy::PerTensorSymmetricInt8(strategy) => TensorData::quantized( strategy.quantize(self.as_slice().unwrap()), self.shape, - DType::QFloat(quantization), + quantization, ), } } diff --git a/crates/burn-tensor/src/tensor/ops/qtensor.rs b/crates/burn-tensor/src/tensor/ops/qtensor.rs index 9163f9df7..458c156f9 100644 --- a/crates/burn-tensor/src/tensor/ops/qtensor.rs +++ b/crates/burn-tensor/src/tensor/ops/qtensor.rs @@ -1,10 +1,24 @@ -use crate::{backend::Backend, Device, QuantizationStrategy, Shape}; +use core::future::Future; + +use crate::{backend::Backend, Device, QuantizationStrategy, Shape, TensorData}; use super::{FloatTensor, QuantizedTensor}; /// Quantized Tensor API for basic operations, see [tensor](crate::Tensor) /// for documentation on each function. pub trait QTensorOps { + /// Creates a new tensor from the data structure. + /// + /// # Arguments + /// + /// * `data` - The data structure. + /// * `device` - The device to create the tensor on. + /// + /// # Returns + /// + /// The tensor with the given data. + fn q_from_data(data: TensorData, device: &Device) -> QuantizedTensor; + /// Convert the tensor to a lower precision data type based on the quantization strategy. fn quantize( tensor: FloatTensor, @@ -38,4 +52,48 @@ pub trait QTensorOps { /// /// The device of the tensor. fn q_device(tensor: &QuantizedTensor) -> Device; + + /// Reshapes a tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to reshape. + /// * `shape` - The new shape of the tensor. + /// + /// # Returns + /// + /// The tensor with the new shape. + fn q_reshape( + tensor: QuantizedTensor, + shape: Shape, + ) -> QuantizedTensor; + + /// Converts the tensor to a data structure. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The data structure with the tensor's data. + fn q_into_data( + tensor: QuantizedTensor, + strategy: QuantizationStrategy, + ) -> impl Future + Send; + + /// Sets the `require_grad` flag of a tensor. + fn q_set_require_grad( + tensor: QuantizedTensor, + _require_grad: bool, + ) -> QuantizedTensor { + // Should only be overridden by autodiff backends. + tensor + } + + /// Returns the `require_grad` flag of a tensor. + fn q_is_require_grad(_tensor: &QuantizedTensor) -> bool { + // Should only be overridden by autodiff backends. + false + } } diff --git a/crates/burn-tensor/src/tensor/quantization_strategy.rs b/crates/burn-tensor/src/tensor/quantization_strategy.rs index ab029df17..95e1460a5 100644 --- a/crates/burn-tensor/src/tensor/quantization_strategy.rs +++ b/crates/burn-tensor/src/tensor/quantization_strategy.rs @@ -8,7 +8,7 @@ use burn_common::{iter_par, run_par}; use num_traits::{Float, PrimInt}; use serde::{Deserialize, Serialize}; -/// Quantization scheme/strategy. 
+/// Quantization strategy. #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)] pub enum QuantizationStrategy { /// Per-tensor `int8` affine/asymmetric quantization. diff --git a/crates/burn/src/lib.rs b/crates/burn/src/lib.rs index a85269c1c..9f498a1cd 100644 --- a/crates/burn/src/lib.rs +++ b/crates/burn/src/lib.rs @@ -47,6 +47,18 @@ //! - Autodiff: Backend decorator that brings backpropagation to any backend //! - Fusion: Backend decorator that brings kernel fusion to backends that support it //! +//! # Quantization (Beta) +//! +//! Quantization techniques perform computations and store tensors in lower precision data types like 8-bit integer +//! instead of floating point precision. There are multiple approaches to quantize a deep learning model. In most cases, +//! the model is trained in floating point precision and later converted to the lower precision data type. This is called +//! post-training quantization (PTQ). On the other hand, quantization aware training (QAT) models the effects of quantization +//! during training. Quantization errors are thus modeled in the forward and backward passes, which helps the model learn +//! representations that are more robust to the reduction in precision. +//! +//! Quantization support in Burn is currently in active development. It supports the following modes on some backends: +//! - Static per-tensor quantization to signed 8-bit integer (`i8`) +//! //! ## Feature Flags //! //! The following feature flags are available.