Module weight quantization (#2000)

* Add q_into_data and q_reshape

* Fix tch quantize f16 and q_into_data

* Convert to actual dtype/kind in dequantize

* Add module quantization and q_from_data

* Fix clippy

* Add documentation

* Handle deserialize data conversion

* Fix typo

* Add calibration tests

* Fix clippy precision

* Add QTensorOps require_grad methods to avoid dequantizing

* Add Dequantize mapper docs

* Remove dead code
Guillaume Lagrange 2024-07-15 08:20:37 -04:00 committed by GitHub
parent a4123f6c2e
commit 3afff434bd
22 changed files with 618 additions and 31 deletions

View File

@@ -25,6 +25,7 @@
 - [ONNX Model](./import/onnx-model.md)
 - [PyTorch Model](./import/pytorch-model.md)
 - [Models & Pre-Trained Weights](./models-and-pretrained-weights.md)
+- [Quantization (Beta)](./quantization.md)
 - [Advanced](./advanced/README.md)
   - [Backend Extension](./advanced/backend-extension/README.md)
     - [Custom WGPU Kernel](./advanced/backend-extension/custom-wgpu-kernel.md)

View File

@@ -0,0 +1,122 @@
# Quantization (Beta)
Quantization techniques perform computations and store tensors in lower precision data types like
8-bit integer instead of floating point precision. There are multiple approaches to quantize a deep
learning model, categorized as:

- Post-training quantization (PTQ)
- Quantization aware training (QAT)

In post-training quantization, the model is trained in floating point precision and later converted
to the lower precision data type.

There are two types of post-training quantization:

1. Static quantization: quantizes the weights and activations of the model. Quantizing the
   activations statically requires data to be calibrated (i.e., recording the activation values to
   compute the optimal quantization parameters with representative data).
2. Dynamic quantization: quantizes the weights ahead of time (like static quantization) but the
   activations are quantized dynamically at runtime.
Sometimes post-training quantization is not able to achieve acceptable task accuracy. This is where
quantization aware training comes into play, as it models the effects of quantization during
training. Quantization errors are thus modeled in the forward and backward passes using fake
quantization modules, which helps the model learn representations that are more robust to the
reduction in precision.
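To make the idea concrete, fake quantization boils down to a quantize-dequantize round trip: values
keep flowing as floats, but they are snapped to the nearest representable 8-bit level so the
rounding error becomes visible during training. A minimal sketch in plain Rust (independent of
Burn's API), using the standard per-tensor affine formulas with example parameters:

```rust, ignore
/// Quantize `x` to int8 with the given affine parameters, then dequantize it back.
/// The result is still an `f32`, but it now carries the quantization error.
fn fake_quantize(x: f32, scale: f32, zero_point: i32) -> f32 {
    let q = ((x / scale).round() as i32 + zero_point).clamp(-128, 127);
    (q - zero_point) as f32 * scale
}

// Example parameters (normally produced by calibration).
let (scale, zero_point) = (0.009_019_608, 72);
let y = fake_quantize(0.123, scale, zero_point); // ~0.1263 instead of 0.123
```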
<div class="warning">
Quantization support in Burn is currently in active development.
It supports the following modes on some backends:
- Static per-tensor quantization to signed 8-bit integer (`i8`)
No integer operations are currently supported, which means tensors are dequantized to perform the
operations in floating point precision.
</div>
## Module Quantization
Quantizing the weights of your model after training is quite simple. We have access to the weight
tensors and can collect their statistics, such as the min and max value when using
`MinMaxCalibration`, to compute the quantization parameters.
```rust, ignore
# use burn::quantization::{MinMaxCalibration, QuantizationScheme, QuantizationType, Quantizer};
#
// Quantization config
let mut quantizer = Quantizer {
    calibration: MinMaxCalibration {
        scheme: QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8),
    },
};
// Quantize the weights
let model = model.quantize_weights(&mut quantizer);
```
> Given that all operations are currently performed in floating point precision, it might be wise to
> dequantize the module parameters before inference. This allows us to save disk space by storing
> the model in reduced precision while preserving the inference speed.
>
> This can easily be implemented with a `ModuleMapper`.
>
> ```rust, ignore
> # use burn::module::{ModuleMapper, ParamId};
> # use burn::tensor::{backend::Backend, Tensor};
> #
> /// Module mapper used to dequantize the model params being loaded.
> pub struct Dequantize {}
>
> impl<B: Backend> ModuleMapper<B> for Dequantize {
>     fn map_float<const D: usize>(
>         &mut self,
>         _id: &ParamId,
>         tensor: Tensor<B, D>,
>     ) -> Tensor<B, D> {
>         tensor.dequantize()
>     }
> }
>
> // Load saved quantized model in floating point precision
> model = model
>     .load_file(file_path, recorder, &device)
>     .expect("Should be able to load the quantized model weights")
>     .map(&mut Dequantize {});
> ```
### Calibration
Calibration is the step during quantization where the range of all floating-point tensors is
computed. This is pretty straightforward for weights since the actual range is known at
_quantization-time_ (weights are static), but activations require more attention.
To compute the quantization parameters, Burn supports the following `Calibration` methods.
| Method | Description |
| :------------------ | :------------------------------------------------------------------------------- |
| `MinMaxCalibration` | Computes the quantization range mapping based on the running min and max values. |
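The `Calibration` trait only has to turn a tensor into a `QuantizationStrategy`, so other
range-selection heuristics can be plugged in the same way as `MinMaxCalibration`. A rough sketch
(the `FixedRangeCalibration` name is hypothetical, and the imports assume the same re-exports used
by `MinMaxCalibration`):

```rust, ignore
# use burn::quantization::Calibration;
# use burn::tensor::{backend::Backend, AffineQuantization, QuantizationStrategy, Tensor};
#
/// Hypothetical calibration that always maps a fixed [-1.0, 1.0] range,
/// ignoring the observed tensor values.
pub struct FixedRangeCalibration;

impl Calibration for FixedRangeCalibration {
    fn configure<B: Backend, const D: usize>(&self, _tensor: &Tensor<B, D>) -> QuantizationStrategy {
        QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::new(-1.0, 1.0))
    }
}
```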
### Quantization Scheme
A quantization scheme defines the quantized type, quantization granularity and range mapping
technique.
Burn currently supports the following `QuantizationType` variants.
| Type | Description |
| :------ | :--------------------------------- |
| `QInt8` | 8-bit signed integer quantization. |
Quantization parameters are defined based on the range of values to represent and can typically be
calculated for the layer's entire weight tensor with per-tensor quantization or separately for each
channel with per-channel quantization (commonly used with CNNs).
Burn currently supports the following `QuantizationScheme` variants.
| Variant | Description |
| :------------------- | :------------------------------------------------------------------------------------------------------------- |
| `PerTensorAffine` | Computes the quantization parameters for the whole tensor and applies an affine range mapping with zero point. |
| `PerTensorSymmetric` | Computes the quantization parameters for the whole tensor and applies a scale range mapping centered around 0. |
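In practice, the difference between the two per-tensor variants is how the floating-point range
`[min, max]` is mapped onto `i8`. A rough sketch of the underlying arithmetic, given `min: f32` and
`max: f32` (not Burn's exact implementation; rounding details may differ):

```rust, ignore
// Affine/asymmetric: use the full [-128, 127] range, shifted by a zero point.
let scale_affine = (max - min) / 255.0;
let zero_point = (-128.0 - min / scale_affine).round() as i32;

// Symmetric: center the mapping around 0 so no zero point is needed,
// at the cost of one quantization bin ([-127, 127]).
let scale_symmetric = f32::max(min.abs(), max.abs()) / 127.0;
```

For example, with `min = -1.8` and `max = 0.5` (the values used in the calibration tests added by
this PR), this gives a scale of roughly `0.00902` and a zero point of `72` in the affine case, and
a scale of roughly `0.01417` in the symmetric case.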

View File

@@ -1,12 +1,19 @@
 use burn_tensor::{
     backend::Backend,
     ops::{FloatTensor, QTensorOps, QuantizedTensor},
-    Device, QuantizationStrategy, Shape,
+    Device, QuantizationStrategy, Shape, TensorData,
 };

 use crate::{checkpoint::strategy::CheckpointStrategy, Autodiff};

 impl<B: Backend, C: CheckpointStrategy> QTensorOps<Self> for Autodiff<B, C> {
+    fn q_from_data<const D: usize>(
+        _data: TensorData,
+        _device: &Device<Self>,
+    ) -> QuantizedTensor<Self, D> {
+        todo!()
+    }
+
     fn quantize<const D: usize>(
         _tensor: FloatTensor<Self, D>,
         _strategy: &QuantizationStrategy,
@@ -28,4 +35,18 @@ impl<B: Backend, C: CheckpointStrategy> QTensorOps<Self> for Autodiff<B, C> {
     fn q_device<const D: usize>(tensor: &QuantizedTensor<Self, D>) -> Device<Self> {
         B::q_device(tensor)
     }
+
+    fn q_reshape<const D1: usize, const D2: usize>(
+        tensor: QuantizedTensor<Self, D1>,
+        shape: Shape<D2>,
+    ) -> QuantizedTensor<Self, D2> {
+        B::q_reshape(tensor, shape)
+    }
+
+    async fn q_into_data<const D: usize>(
+        tensor: QuantizedTensor<Self, D>,
+        strategy: QuantizationStrategy,
+    ) -> TensorData {
+        B::q_into_data(tensor, strategy).await
+    }
 }

View File

@@ -1,7 +1,7 @@
 use burn_tensor::{
     backend::Backend,
     ops::{FloatTensor, QTensorOps, QuantizedTensor},
-    Device, QuantizationStrategy, Shape,
+    DType, Device, QuantizationStrategy, Shape, TensorData,
 };

 use crate::{
@@ -10,6 +10,13 @@ use crate::{
 };

 impl<F: FloatCandleElement, I: IntCandleElement> QTensorOps<Self> for Candle<F, I> {
+    fn q_from_data<const D: usize>(
+        data: TensorData,
+        device: &Device<Self>,
+    ) -> QuantizedTensor<Self, D> {
+        unimplemented!() // no i8 support
+    }
+
     fn quantize<const D: usize>(
         _tensor: FloatTensor<Self, D>,
         _strategy: &QuantizationStrategy,
@@ -31,4 +38,18 @@ impl<F: FloatCandleElement, I: IntCandleElement> QTensorOps<Self> for Candle<F,
     fn q_device<const D: usize>(tensor: &QuantizedTensor<Self, D>) -> Device<Self> {
         super::base::device(tensor)
     }
+
+    fn q_reshape<const D1: usize, const D2: usize>(
+        tensor: QuantizedTensor<Self, D1>,
+        shape: Shape<D2>,
+    ) -> QuantizedTensor<Self, D2> {
+        super::base::reshape(tensor, shape)
+    }
+
+    async fn q_into_data<const D: usize>(
+        tensor: QuantizedTensor<Self, D>,
+        strategy: QuantizationStrategy,
+    ) -> TensorData {
+        super::base::into_data(tensor)
+    }
 }

View File

@@ -33,6 +33,9 @@ pub mod module;
 /// Neural network module.
 pub mod nn;

+/// Quantization module.
+pub mod quantization;
+
 /// Module for the recorder.
 pub mod record;

View File

@@ -1,5 +1,6 @@
 use super::ParamId;
 use crate::{
+    quantization::{Calibration, Quantizer},
     record::Record,
     tensor::backend::{AutodiffBackend, Backend},
 };
@@ -202,6 +203,11 @@ pub trait Module<B: Backend>: Clone + Send + core::fmt::Debug {
         Ok(self.load_record(record))
     }
+
+    /// Quantize the weights of the module.
+    fn quantize_weights<C: Calibration>(self, quantizer: &mut Quantizer<C>) -> Self {
+        self.map(quantizer)
+    }
 }

 /// Module visitor trait.

View File

@@ -0,0 +1,80 @@
use burn_tensor::{
    backend::Backend, AffineQuantization, ElementConversion, Quantization, QuantizationStrategy,
    SymmetricQuantization, Tensor,
};

use super::{QuantizationScheme, QuantizationType};

/// Calibration method used to compute the quantization range mapping.
pub trait Calibration {
    /// Configure the quantization strategy.
    fn configure<B: Backend, const D: usize>(&self, tensor: &Tensor<B, D>) -> QuantizationStrategy;
}

/// Computes the quantization range mapping based on the running min and max values.
pub struct MinMaxCalibration {
    /// Quantization scheme to be used.
    pub scheme: QuantizationScheme,
}

impl Calibration for MinMaxCalibration {
    fn configure<B: Backend, const D: usize>(&self, tensor: &Tensor<B, D>) -> QuantizationStrategy {
        let min = tensor.clone().min().into_scalar().elem::<f32>();
        let max = tensor.clone().max().into_scalar().elem::<f32>();

        match &self.scheme {
            QuantizationScheme::PerTensorAffine(dtype) => match dtype {
                QuantizationType::QInt8 => {
                    QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::new(min, max))
                }
            },
            QuantizationScheme::PerTensorSymmetric(dtype) => match dtype {
                QuantizationType::QInt8 => QuantizationStrategy::PerTensorSymmetricInt8(
                    SymmetricQuantization::new(min, max),
                ),
            },
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    #[test]
    fn min_max_calibration_per_tensor_affine_int8() {
        let device = <TestBackend as Backend>::Device::default();
        let tensor = Tensor::<TestBackend, 1>::from_floats([-1.8, -1.0, 0.0, 0.5], &device);
        let calibration = MinMaxCalibration {
            scheme: QuantizationScheme::PerTensorAffine(QuantizationType::QInt8),
        };

        let strategy = calibration.configure(&tensor);

        if let QuantizationStrategy::PerTensorAffineInt8(q) = strategy {
            assert_eq!(q.scale, 0.009_019_608);
            assert_eq!(q.offset, 72);
        } else {
            panic!("Wrong quantization strategy");
        }
    }

    #[test]
    fn min_max_calibration_per_tensor_symmetric_int8() {
        let device = <TestBackend as Backend>::Device::default();
        let tensor = Tensor::<TestBackend, 1>::from_floats([-1.8, -1.0, 0.0, 0.5], &device);
        let calibration = MinMaxCalibration {
            scheme: QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8),
        };

        let strategy = calibration.configure(&tensor);

        if let QuantizationStrategy::PerTensorSymmetricInt8(q) = strategy {
            assert_eq!(q.scale, 0.014_173_228);
        } else {
            panic!("Wrong quantization strategy");
        }
    }
}

View File

@@ -0,0 +1,7 @@
mod calibration;
mod quantize;
mod scheme;

pub use calibration::*;
pub use quantize::*;
pub use scheme::*;

View File

@@ -0,0 +1,18 @@
use burn_tensor::{backend::Backend, Tensor};

use crate::module::{ModuleMapper, ParamId};

use super::Calibration;

/// Describes how to quantize a module.
pub struct Quantizer<C: Calibration> {
    /// The calibration method used in quantization.
    pub calibration: C,
}

impl<B: Backend, C: Calibration> ModuleMapper<B> for Quantizer<C> {
    fn map_float<const D: usize>(&mut self, _id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
        let strategy = self.calibration.configure(&tensor);
        tensor.quantize(strategy)
    }
}

View File

@@ -0,0 +1,17 @@
/// Quantization data type.
pub enum QuantizationType {
    /// 8-bit signed integer.
    QInt8,
}

/// Quantization scheme.
pub enum QuantizationScheme {
    /// Per-tensor affine/asymmetric quantization.
    PerTensorAffine(QuantizationType),
    /// Per-tensor symmetric quantization.
    PerTensorSymmetric(QuantizationType),
    // /// Per-channel affine/asymmetric quantization.
    // PerChannelAffine,
    // /// Per-channel symmetric quantization.
    // PerChannelSymmetric,
}

View File

@@ -1,7 +1,7 @@
 use core::marker::PhantomData;

 use super::{PrecisionSettings, Record};
-use burn_tensor::{backend::Backend, Bool, Element, Int, Tensor, TensorData};
+use burn_tensor::{backend::Backend, Bool, DType, Element, Int, Tensor, TensorData};
 use serde::{Deserialize, Serialize};

 #[cfg(not(feature = "record-backward-compat"))]
@@ -43,7 +43,12 @@
                 e
             ))
         })?;

-        Ok(data.convert::<E>())
+        let data = if let DType::QFloat(_) = data.dtype {
+            data // do not convert quantized tensors
+        } else {
+            data.convert::<E>()
+        };
+        Ok(data)
     }
 }
@@ -137,15 +142,25 @@ impl<B: Backend, const D: usize> Record<B> for Tensor<B, D> {
     type Item<S: PrecisionSettings> = FloatTensorSerde<S>;

     fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
-        FloatTensorSerde::new(self.into_data().convert::<S::FloatElem>())
+        let data = self.into_data();
+        let data = if let DType::QFloat(_) = data.dtype {
+            data // do not convert quantized tensors
+        } else {
+            data.convert::<S::FloatElem>()
+        };
+        FloatTensorSerde::new(data)
     }

     fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
-        Tensor::from_data(item.data.convert::<B::FloatElem>(), device)
+        let data = if let DType::QFloat(_) = item.data.dtype {
+            item.data // do not convert quantized tensors
+        } else {
+            item.data.convert::<B::FloatElem>()
+        };
+        Tensor::from_data(data, device)
     }
 }

-#[allow(deprecated)]
 impl<B: Backend, const D: usize> Record<B> for Tensor<B, D, Int> {
     type Item<S: PrecisionSettings> = IntTensorSerde<S>;
@@ -158,7 +173,6 @@ impl<B: Backend, const D: usize> Record<B> for Tensor<B, D, Int> {
     }
 }

-#[allow(deprecated)]
 impl<B: Backend, const D: usize> Record<B> for Tensor<B, D, Bool> {
     type Item<S: PrecisionSettings> = BoolTensorSerde;

View File

@@ -1,12 +1,19 @@
 use burn_tensor::{
     backend::Backend,
     ops::{QTensorOps, QuantizedTensor},
-    Device, QuantizationStrategy, Shape,
+    Device, QuantizationStrategy, Shape, TensorData,
 };

 use crate::{client::FusionClient, Fusion, FusionBackend};

 impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
+    fn q_from_data<const D: usize>(
+        _data: TensorData,
+        _device: &Device<Self>,
+    ) -> QuantizedTensor<Self, D> {
+        unimplemented!()
+    }
+
     fn quantize<const D: usize>(
         _tensor: <Self as Backend>::FloatTensorPrimitive<D>,
         _strategy: &QuantizationStrategy,
@@ -28,4 +35,18 @@ impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
     fn q_device<const D: usize>(tensor: &QuantizedTensor<Self, D>) -> Device<Self> {
         tensor.client.device().clone()
     }
+
+    fn q_reshape<const D1: usize, const D2: usize>(
+        _tensor: QuantizedTensor<Self, D1>,
+        _shape: Shape<D2>,
+    ) -> QuantizedTensor<Self, D2> {
+        unimplemented!()
+    }
+
+    async fn q_into_data<const D: usize>(
+        _tensor: QuantizedTensor<Self, D>,
+        _strategy: QuantizationStrategy,
+    ) -> TensorData {
+        unimplemented!()
+    }
 }

View File

@@ -1,6 +1,6 @@
 use burn_tensor::{
     ops::{FloatTensor, QTensorOps, QuantizedTensor},
-    Device, QuantizationStrategy, Shape,
+    Device, QuantizationStrategy, Shape, TensorData,
 };

 use crate::{FloatElement, IntElement, JitBackend, JitRuntime};
@@ -11,6 +11,13 @@ where
     F: FloatElement,
     I: IntElement,
 {
+    fn q_from_data<const D: usize>(
+        _data: TensorData,
+        _device: &Device<Self>,
+    ) -> QuantizedTensor<Self, D> {
+        todo!()
+    }
+
     fn quantize<const D: usize>(
         _tensor: FloatTensor<Self, D>,
         _strategy: &QuantizationStrategy,
@@ -32,4 +39,18 @@ where
     fn q_device<const D: usize>(tensor: &QuantizedTensor<Self, D>) -> Device<Self> {
         tensor.device.clone()
     }
+
+    fn q_reshape<const D1: usize, const D2: usize>(
+        tensor: QuantizedTensor<Self, D1>,
+        shape: Shape<D2>,
+    ) -> QuantizedTensor<Self, D2> {
+        super::reshape(tensor, shape)
+    }
+
+    async fn q_into_data<const D: usize>(
+        _tensor: QuantizedTensor<Self, D>,
+        _strategy: QuantizationStrategy,
+    ) -> TensorData {
+        unimplemented!()
+    }
 }

View File

@@ -1,10 +1,12 @@
 use burn_tensor::{
     ops::{FloatTensor, QTensorOps, QuantizedTensor},
-    Quantization, QuantizationStrategy, Shape, TensorData,
+    DType, Quantization, QuantizationStrategy, Shape, TensorData,
 };

 use crate::{element::NdArrayElement, FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayTensor};

+use super::NdArrayOps;
+
 fn into_data<E: NdArrayElement, const D: usize>(tensor: NdArrayTensor<E, D>) -> TensorData {
     let shape = tensor.shape();
     let values = tensor.array.into_iter().collect();
@@ -12,6 +14,28 @@ fn into_data<E: NdArrayElement, const D: usize>(tensor: NdArrayTensor<E, D>) ->
 }

 impl<E: FloatNdArrayElement> QTensorOps<Self> for NdArray<E> {
+    fn q_from_data<const D: usize>(
+        data: TensorData,
+        _device: &NdArrayDevice,
+    ) -> QuantizedTensor<Self, D> {
+        match data.dtype {
+            DType::QFloat(strategy) => match strategy {
+                QuantizationStrategy::PerTensorAffineInt8(_) => {
+                    let data = data.convert::<i8>();
+                    NdArrayTensor::<i8, D>::from_data(data)
+                }
+                QuantizationStrategy::PerTensorSymmetricInt8(_) => {
+                    let data = data.convert::<i8>();
+                    NdArrayTensor::<i8, D>::from_data(data)
+                }
+            },
+            _ => panic!(
+                "Invalid dtype (expected DType::QFloat, got {:?})",
+                data.dtype
+            ),
+        }
+    }
+
     fn quantize<const D: usize>(
         tensor: FloatTensor<Self, D>,
         strategy: &QuantizationStrategy,
@@ -41,4 +65,20 @@ impl<E: FloatNdArrayElement> QTensorOps<Self> for NdArray<E> {
     fn q_device<const D: usize>(_tensor: &QuantizedTensor<Self, D>) -> NdArrayDevice {
         NdArrayDevice::Cpu
     }
+
+    fn q_reshape<const D1: usize, const D2: usize>(
+        tensor: QuantizedTensor<Self, D1>,
+        shape: Shape<D2>,
+    ) -> QuantizedTensor<Self, D2> {
+        NdArrayOps::reshape(tensor, shape)
+    }
+
+    async fn q_into_data<const D: usize>(
+        tensor: QuantizedTensor<Self, D>,
+        strategy: QuantizationStrategy,
+    ) -> TensorData {
+        let shape = tensor.shape();
+        let values = tensor.array.into_iter().collect();
+        TensorData::quantized(values, shape, strategy)
+    }
 }

View File

@@ -1,4 +1,4 @@
-use burn_tensor::Shape;
+use burn_tensor::{QuantizationStrategy, Shape};
 use tch::Scalar;

 use crate::{LibTorchDevice, TchShape, TchTensor};
@@ -512,4 +512,30 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
     ) -> TchTensor<i64, D> {
         TchTensor::new(tensor.tensor.argsort(dim as i64, descending))
     }
+
+    pub fn quantize<const D: usize, I: tch::kind::Element>(
+        tensor: TchTensor<E, D>,
+        strategy: &QuantizationStrategy,
+    ) -> TchTensor<I, D> {
+        let mut tensor = tensor;
+        // Quantize only works on Float Tensor
+        if tensor.tensor.kind() == tch::Kind::Half {
+            tensor.tensor = tensor.tensor.to_kind(tch::Kind::Float);
+        }
+
+        match strategy {
+            QuantizationStrategy::PerTensorAffineInt8(ref q) => {
+                TchTensor::new(tensor.tensor.quantize_per_tensor(
+                    q.scale.into(),
+                    q.offset.into(),
+                    tch::Kind::QInt8,
+                ))
+            }
+            QuantizationStrategy::PerTensorSymmetricInt8(ref q) => TchTensor::new(
+                tensor
+                    .tensor
+                    .quantize_per_tensor(q.scale.into(), 0, tch::Kind::QInt8),
+            ),
+        }
+    }
 }

View File

@@ -1,15 +1,63 @@
 use burn_tensor::{
     ops::{FloatTensor, QTensorOps, QuantizedTensor},
-    QuantizationStrategy, Shape,
+    DType, Quantization, QuantizationStrategy, Shape, TensorData,
 };

-use crate::{LibTorch, LibTorchDevice, TchElement, TchTensor};
+use crate::{LibTorch, LibTorchDevice, TchElement, TchShape, TchTensor};
+
+use super::TchOps;

 impl<E: TchElement> QTensorOps<Self> for LibTorch<E> {
+    fn q_from_data<const D: usize>(
+        data: TensorData,
+        device: &LibTorchDevice,
+    ) -> QuantizedTensor<Self, D> {
+        let shape_tch = TchShape::<D>::from(data.shape.as_slice());
+        let device = (*device).into();
+        // NOTE: tch-rs doesn't have `from_blob_quantized_*` APIs
+        // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/quantized/Quantizer.cpp#L322
+        // So for now we have to load the dequantized values to quantize them back since the dequantization
+        // methods take the values provided when quantizing.
+        let tensor = match data.dtype {
+            DType::QFloat(strategy) => match strategy {
+                QuantizationStrategy::PerTensorAffineInt8(q) => {
+                    let values = q.dequantize(&data.iter::<i8>().collect::<Vec<_>>());
+                    let tensor = tch::Tensor::from_slice(&values).to(device);
+                    TchOps::<E>::quantize::<D, i8>(
+                        TchTensor::new(tensor.reshape(shape_tch.dims)),
+                        &strategy,
+                    )
+                    .tensor
+                }
+                QuantizationStrategy::PerTensorSymmetricInt8(q) => {
+                    let values = q.dequantize(&data.iter::<i8>().collect::<Vec<_>>());
+                    let tensor = tch::Tensor::from_slice(&values).to(device);
+                    TchOps::<E>::quantize::<D, i8>(
+                        TchTensor::new(tensor.reshape(shape_tch.dims)),
+                        &strategy,
+                    )
+                    .tensor
+                }
+            },
+            _ => panic!(
+                "Invalid dtype (expected DType::QFloat, got {:?})",
+                data.dtype
+            ),
+        };

+        TchTensor::new(tensor)
+    }
+
     fn quantize<const D: usize>(
         tensor: FloatTensor<Self, D>,
         strategy: &QuantizationStrategy,
     ) -> QuantizedTensor<Self, D> {
+        let mut tensor = tensor;
+        // Quantize only works on Float Tensor
+        if E::dtype() == DType::F16 {
+            tensor.tensor = tensor.tensor.to_kind(tch::Kind::Float);
+        }
+
         match strategy {
             QuantizationStrategy::PerTensorAffineInt8(ref q) => {
                 TchTensor::new(tensor.tensor.quantize_per_tensor(
@@ -30,7 +78,7 @@ impl<E: TchElement> QTensorOps<Self> for LibTorch<E> {
         tensor: QuantizedTensor<Self, D>,
         _strategy: &QuantizationStrategy,
     ) -> FloatTensor<Self, D> {
-        TchTensor::new(tensor.tensor.dequantize())
+        TchTensor::new(tensor.tensor.dequantize().to_kind(E::KIND))
     }

     fn q_shape<const D: usize>(tensor: &QuantizedTensor<Self, D>) -> Shape<D> {
@@ -40,4 +88,23 @@ impl<E: TchElement> QTensorOps<Self> for LibTorch<E> {
     fn q_device<const D: usize>(tensor: &QuantizedTensor<Self, D>) -> LibTorchDevice {
         tensor.tensor.device().into()
     }
+
+    fn q_reshape<const D1: usize, const D2: usize>(
+        tensor: QuantizedTensor<Self, D1>,
+        shape: Shape<D2>,
+    ) -> QuantizedTensor<Self, D2> {
+        TchOps::reshape(tensor, shape)
+    }
+
+    async fn q_into_data<const D: usize>(
+        tensor: QuantizedTensor<Self, D>,
+        strategy: QuantizationStrategy,
+    ) -> TensorData {
+        let shape = Self::q_shape(&tensor);
+        let tensor = Self::q_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
+        // To get the integer values we have to call `int_repr()`
+        let values: Result<Vec<i8>, tch::TchError> = tensor.tensor.int_repr().try_into();
+
+        TensorData::quantized(values.unwrap(), shape, strategy)
+    }
 }

View File

@@ -18,7 +18,7 @@ use crate::check::TensorCheck;
 use crate::tensor::api::chunk::chunk;
 use crate::tensor::api::narrow::narrow;
 use crate::{backend::Backend, check, Bool, Float, Int, Shape, TensorData, TensorKind};
-use crate::{Element, TensorPrimitive};
+use crate::{DType, Element, TensorPrimitive};

 /// A tensor with a given backend, shape and data type.
 #[derive(new, Clone, Debug)]
@@ -1697,7 +1697,15 @@ impl<B: Backend> BasicOps<B> for Float {
         tensor: Self::Primitive<D1>,
         shape: Shape<D2>,
     ) -> Self::Primitive<D2> {
-        TensorPrimitive::Float(B::float_reshape(tensor.tensor(), shape))
+        match tensor {
+            TensorPrimitive::Float(tensor) => {
+                TensorPrimitive::Float(B::float_reshape(tensor, shape))
+            }
+            TensorPrimitive::QFloat { tensor, strategy } => TensorPrimitive::QFloat {
+                tensor: B::q_reshape(tensor, shape),
+                strategy,
+            },
+        }
     }

     fn transpose<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<D> {
@@ -1750,11 +1758,20 @@ impl<B: Backend> BasicOps<B> for Float {
     }

     async fn into_data_async<const D: usize>(tensor: Self::Primitive<D>) -> TensorData {
-        B::float_into_data(tensor.tensor()).await
+        match tensor {
+            TensorPrimitive::Float(tensor) => B::float_into_data(tensor).await,
+            TensorPrimitive::QFloat { tensor, strategy } => B::q_into_data(tensor, strategy).await,
+        }
     }

     fn from_data<const D: usize>(data: TensorData, device: &B::Device) -> Self::Primitive<D> {
-        TensorPrimitive::Float(B::float_from_data(data, device))
+        match data.dtype {
+            DType::QFloat(strategy) => TensorPrimitive::QFloat {
+                tensor: B::q_from_data(data, device),
+                strategy,
+            },
+            _ => TensorPrimitive::Float(B::float_from_data(data, device)),
+        }
     }

     fn repeat<const D: usize>(

View File

@@ -271,9 +271,9 @@
         match &self.primitive {
             TensorPrimitive::Float(tensor) => B::float_is_require_grad(tensor),
             TensorPrimitive::QFloat {
-                tensor: _,
+                tensor,
                 strategy: _,
-            } => B::float_is_require_grad(&self.primitive.clone().tensor()),
+            } => B::q_is_require_grad(tensor),
         }
     }
@@ -282,10 +282,16 @@
     ///
     /// This function does nothing when autodiff is not enabled.
     pub fn set_require_grad(self, require_grad: bool) -> Self {
-        Self::new(TensorPrimitive::Float(B::float_set_require_grad(
-            self.primitive.tensor(),
-            require_grad,
-        )))
+        let primitive = match self.primitive {
+            TensorPrimitive::Float(tensor) => {
+                TensorPrimitive::Float(B::float_set_require_grad(tensor, require_grad))
+            }
+            TensorPrimitive::QFloat { tensor, strategy } => TensorPrimitive::QFloat {
+                tensor: B::q_set_require_grad(tensor, require_grad),
+                strategy,
+            },
+        };
+        Self::new(primitive)
     }

     /// Applies the relu function to the tensor.

View File

@@ -48,6 +48,15 @@ impl TensorData {
         Self::init(value, shape, E::dtype())
     }

+    /// Creates a new quantized tensor data structure.
+    pub fn quantized<E: Element, S: Into<Vec<usize>>>(
+        value: Vec<E>,
+        shape: S,
+        strategy: QuantizationStrategy,
+    ) -> Self {
+        Self::init(value, shape, DType::QFloat(strategy))
+    }
+
     /// Initializes a new tensor data structure from the provided values.
     fn init<E: Element, S: Into<Vec<usize>>>(value: Vec<E>, shape: S, dtype: DType) -> Self {
         Self {
@@ -258,15 +267,15 @@ impl TensorData {
             "Only f32 data type can be quantized"
         );
         match &quantization {
-            QuantizationStrategy::PerTensorAffineInt8(strategy) => TensorData::init(
+            QuantizationStrategy::PerTensorAffineInt8(strategy) => TensorData::quantized(
                 strategy.quantize(self.as_slice().unwrap()),
                 self.shape,
-                DType::QFloat(quantization),
+                quantization,
             ),
-            QuantizationStrategy::PerTensorSymmetricInt8(strategy) => TensorData::init(
+            QuantizationStrategy::PerTensorSymmetricInt8(strategy) => TensorData::quantized(
                 strategy.quantize(self.as_slice().unwrap()),
                 self.shape,
-                DType::QFloat(quantization),
+                quantization,
             ),
         }
     }

View File

@@ -1,10 +1,24 @@
-use crate::{backend::Backend, Device, QuantizationStrategy, Shape};
+use core::future::Future;
+
+use crate::{backend::Backend, Device, QuantizationStrategy, Shape, TensorData};

 use super::{FloatTensor, QuantizedTensor};

 /// Quantized Tensor API for basic operations, see [tensor](crate::Tensor)
 /// for documentation on each function.
 pub trait QTensorOps<B: Backend> {
+    /// Creates a new tensor from the data structure.
+    ///
+    /// # Arguments
+    ///
+    /// * `data` - The data structure.
+    /// * `device` - The device to create the tensor on.
+    ///
+    /// # Returns
+    ///
+    /// The tensor with the given data.
+    fn q_from_data<const D: usize>(data: TensorData, device: &Device<B>) -> QuantizedTensor<B, D>;
+
     /// Convert the tensor to a lower precision data type based on the quantization strategy.
     fn quantize<const D: usize>(
         tensor: FloatTensor<B, D>,
@@ -38,4 +52,48 @@ pub trait QTensorOps<B: Backend> {
     ///
     /// The device of the tensor.
     fn q_device<const D: usize>(tensor: &QuantizedTensor<B, D>) -> Device<B>;
+
+    /// Reshapes a tensor.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to reshape.
+    /// * `shape` - The new shape of the tensor.
+    ///
+    /// # Returns
+    ///
+    /// The tensor with the new shape.
+    fn q_reshape<const D1: usize, const D2: usize>(
+        tensor: QuantizedTensor<B, D1>,
+        shape: Shape<D2>,
+    ) -> QuantizedTensor<B, D2>;
+
+    /// Converts the tensor to a data structure.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor.
+    ///
+    /// # Returns
+    ///
+    /// The data structure with the tensor's data.
+    fn q_into_data<const D: usize>(
+        tensor: QuantizedTensor<B, D>,
+        strategy: QuantizationStrategy,
+    ) -> impl Future<Output = TensorData> + Send;
+
+    /// Sets the `require_grad` flag of a tensor.
+    fn q_set_require_grad<const D: usize>(
+        tensor: QuantizedTensor<B, D>,
+        _require_grad: bool,
+    ) -> QuantizedTensor<B, D> {
+        // Should only be overridden by autodiff backends.
+        tensor
+    }
+
+    /// Returns the `require_grad` flag of a tensor.
+    fn q_is_require_grad<const D: usize>(_tensor: &QuantizedTensor<B, D>) -> bool {
+        // Should only be overridden by autodiff backends.
+        false
+    }
 }

View File

@@ -8,7 +8,7 @@ use burn_common::{iter_par, run_par};
 use num_traits::{Float, PrimInt};
 use serde::{Deserialize, Serialize};

-/// Quantization scheme/strategy.
+/// Quantization strategy.
 #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)]
 pub enum QuantizationStrategy {
     /// Per-tensor `int8` affine/asymmetric quantization.

View File

@@ -47,6 +47,18 @@
 //! - Autodiff: Backend decorator that brings backpropagation to any backend
 //! - Fusion: Backend decorator that brings kernel fusion to backends that support it
 //!
+//! # Quantization (Beta)
+//!
+//! Quantization techniques perform computations and store tensors in lower precision data types like 8-bit integer
+//! instead of floating point precision. There are multiple approaches to quantize a deep learning model. In most cases,
+//! the model is trained in floating point precision and later converted to the lower precision data type. This is called
+//! post-training quantization (PTQ). On the other hand, quantization aware training (QAT) models the effects of quantization
+//! during training. Quantization errors are thus modeled in the forward and backward passes, which helps the model learn
+//! representations that are more robust to the reduction in precision.
+//!
+//! Quantization support in Burn is currently in active development. It supports the following modes on some backends:
+//! - Static per-tensor quantization to signed 8-bit integer (`i8`)
+//!
 //! ## Feature Flags
 //!
 //! The following feature flags are available.