mirror of https://github.com/tracel-ai/burn.git
Module weight quantization (#2000)
* Add q_into_data and q_reshape
* Fix tch quantize f16 and q_into_data
* Convert to actual dtype/kind in dequantize
* Add module quantization and q_from_data
* Fix clippy
* Add documentation
* Handle deserialize data conversion
* Fix typo
* Add calibration tests
* Fix clippy precision
* Add QTensorOps require_grad methods to avoid dequantizing
* Add Dequantize mapper docs
* Remove dead code
This commit is contained in:
parent a4123f6c2e
commit 3afff434bd
@@ -25,6 +25,7 @@
- [ONNX Model](./import/onnx-model.md)
- [PyTorch Model](./import/pytorch-model.md)
- [Models & Pre-Trained Weights](./models-and-pretrained-weights.md)
- [Quantization (Beta)](./quantization.md)
- [Advanced](./advanced/README.md)
- [Backend Extension](./advanced/backend-extension/README.md)
- [Custom WGPU Kernel](./advanced/backend-extension/custom-wgpu-kernel.md)
@@ -0,0 +1,122 @@
# Quantization (Beta)

Quantization techniques perform computations and store tensors in lower precision data types like
8-bit integer instead of floating point precision. There are multiple approaches to quantize a deep
learning model, categorized as:

- Post-training quantization (PTQ)
- Quantization aware training (QAT)

In post-training quantization, the model is trained in floating point precision and later converted
to the lower precision data type.

There are two types of post-training quantization:

1. Static quantization: quantizes the weights and activations of the model. Quantizing the
   activations statically requires calibration (i.e., recording the activation values on
   representative data to compute the optimal quantization parameters).
2. Dynamic quantization: quantizes the weights ahead of time (like static quantization) but the
   activations are quantized dynamically at runtime.

Sometimes post-training quantization is not able to achieve acceptable task accuracy. This is where
quantization aware training comes into play, as it models the effects of quantization during
training. Quantization errors are thus modeled in the forward and backward passes using fake
quantization modules, which helps the model learn representations that are more robust to the
reduction in precision.

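To give a concrete picture of what a fake quantization module does, here is a minimal,
framework-agnostic sketch (not Burn's API): a value is quantized to 8-bit and immediately
dequantized, so the rest of the forward pass observes the rounding error introduced by the reduced
precision. The symmetric scale used below is just an assumed example value.

```rust
/// Minimal illustration of fake quantization (not Burn's implementation):
/// quantize to int8 and immediately dequantize so downstream computations
/// see the rounding error introduced by the reduced precision.
fn fake_quantize(x: f32, scale: f32) -> f32 {
    // Symmetric per-tensor int8 mapping.
    let q = (x / scale).round().clamp(-127.0, 127.0);
    q * scale
}

fn main() {
    let scale = 1.8 / 127.0; // assumed scale, e.g. max absolute value / 127
    for x in [-1.8f32, -1.0, 0.0, 0.5] {
        println!("{x} -> {}", fake_quantize(x, scale));
    }
}
```
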
<div class="warning">

Quantization support in Burn is currently in active development.

It supports the following modes on some backends:

- Static per-tensor quantization to signed 8-bit integer (`i8`)

No integer operations are currently supported, which means tensors are dequantized to perform the
operations in floating point precision.

</div>

## Module Quantization

Quantizing the weights of your model after training is quite simple. We have access to the weight
tensors and can collect their statistics, such as the min and max value when using
`MinMaxCalibration`, to compute the quantization parameters.

```rust, ignore
# use burn::quantization::{MinMaxCalibration, QuantizationScheme, QuantizationType, Quantizer};
#
// Quantization config
let mut quantizer = Quantizer {
    calibration: MinMaxCalibration {
        scheme: QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8),
    },
};

// Quantize the weights
let model = model.quantize_weights(&mut quantizer);
```

> Given that all operations are currently performed in floating point precision, it might be wise
> to dequantize the module parameters before inference. This allows us to save disk space by
> storing the model in reduced precision while preserving the inference speed.
>
> This can easily be implemented with a `ModuleMapper`.
>
> ```rust, ignore
> # use burn::module::{ModuleMapper, ParamId};
> # use burn::tensor::{backend::Backend, Tensor};
> #
> /// Module mapper used to dequantize the model params being loaded.
> pub struct Dequantize {}
>
> impl<B: Backend> ModuleMapper<B> for Dequantize {
>     fn map_float<const D: usize>(
>         &mut self,
>         _id: &ParamId,
>         tensor: Tensor<B, D>,
>     ) -> Tensor<B, D> {
>         tensor.dequantize()
>     }
> }
>
> // Load saved quantized model in floating point precision
> model = model
>     .load_file(file_path, recorder, &device)
>     .expect("Should be able to load the quantized model weights")
>     .map(&mut Dequantize {});
> ```

### Calibration

Calibration is the step during quantization where the range of all floating-point tensors is
computed. This is pretty straightforward for weights since the actual range is known at
_quantization-time_ (weights are static), but activations require more attention.

To compute the quantization parameters, Burn supports the following `Calibration` methods.

| Method              | Description                                                                        |
| :------------------ | :--------------------------------------------------------------------------------- |
| `MinMaxCalibration` | Computes the quantization range mapping based on the running min and max values.  |

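To make the calibration step concrete, the standalone sketch below computes per-tensor `i8`
quantization parameters from an observed min/max range, using the same example values as the
calibration tests added in this commit. The formulas (symmetric scale from the largest absolute
value, affine scale and zero point from the full range) follow a common int8 convention and are
shown only as an illustration, not as a copy of Burn's internals.

```rust
// Illustration of computing per-tensor int8 quantization parameters from the
// observed min/max range (what `MinMaxCalibration` records for a weight tensor).
fn main() {
    let values = [-1.8f32, -1.0, 0.0, 0.5];
    let min = values.iter().cloned().fold(f32::INFINITY, f32::min);
    let max = values.iter().cloned().fold(f32::NEG_INFINITY, f32::max);

    // Symmetric: the zero point is fixed at 0 and the scale covers the largest magnitude.
    let symmetric_scale = min.abs().max(max.abs()) / 127.0;

    // Affine: the scale covers the full [min, max] range and a zero point shifts the
    // mapping so that 0.0 lands exactly on an integer value.
    let affine_scale = (max - min) / 255.0;
    let zero_point = (-128.0 - (min / affine_scale)).round() as i8;

    println!("symmetric scale: {symmetric_scale}"); // ~0.0142
    println!("affine scale: {affine_scale}, zero point: {zero_point}"); // ~0.0090, 72
}
```
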
### Quantization Scheme

A quantization scheme defines the quantized type, quantization granularity and range mapping
technique.

Burn currently supports the following `QuantizationType` variants.

| Type    | Description                        |
| :------ | :--------------------------------- |
| `QInt8` | 8-bit signed integer quantization. |

Quantization parameters are defined based on the range of values to represent and can typically be
calculated for the layer's entire weight tensor with per-tensor quantization or separately for each
channel with per-channel quantization (commonly used with CNNs).

Burn currently supports the following `QuantizationScheme` variants.

| Variant              | Description                                                                                                      |
| :------------------- | :--------------------------------------------------------------------------------------------------------------- |
| `PerTensorAffine`    | Computes the quantization parameters for the whole tensor and applies an affine range mapping with zero point.  |
| `PerTensorSymmetric` | Computes the quantization parameters for the whole tensor and applies a scale range mapping centered around 0.  |

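To illustrate the difference between the two range mappings, here is a small standalone sketch
(an assumed convention for illustration, not Burn's actual `QuantizationStrategy` code) that
quantizes and dequantizes a value with per-tensor affine and symmetric parameters like the ones
computed above.

```rust
// Simplified per-tensor int8 range mappings (illustration only).
fn affine_quantize(x: f32, scale: f32, zero_point: i8) -> i8 {
    ((x / scale).round() as i32 + zero_point as i32).clamp(-128, 127) as i8
}

fn affine_dequantize(q: i8, scale: f32, zero_point: i8) -> f32 {
    (q as i32 - zero_point as i32) as f32 * scale
}

fn symmetric_quantize(x: f32, scale: f32) -> i8 {
    // The symmetric mapping is the affine mapping with the zero point fixed at 0.
    (x / scale).round().clamp(-127.0, 127.0) as i8
}

fn symmetric_dequantize(q: i8, scale: f32) -> f32 {
    q as f32 * scale
}

fn main() {
    let (scale, zero_point) = (0.009_019_608, 72);
    let q = affine_quantize(0.5, scale, zero_point);
    println!("affine: 0.5 -> {q} -> {}", affine_dequantize(q, scale, zero_point));

    let scale = 0.014_173_228;
    let q = symmetric_quantize(0.5, scale);
    println!("symmetric: 0.5 -> {q} -> {}", symmetric_dequantize(q, scale));
}
```
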
@@ -1,12 +1,19 @@
use burn_tensor::{
    backend::Backend,
    ops::{FloatTensor, QTensorOps, QuantizedTensor},
    Device, QuantizationStrategy, Shape,
    Device, QuantizationStrategy, Shape, TensorData,
};

use crate::{checkpoint::strategy::CheckpointStrategy, Autodiff};

impl<B: Backend, C: CheckpointStrategy> QTensorOps<Self> for Autodiff<B, C> {
    fn q_from_data<const D: usize>(
        _data: TensorData,
        _device: &Device<Self>,
    ) -> QuantizedTensor<Self, D> {
        todo!()
    }

    fn quantize<const D: usize>(
        _tensor: FloatTensor<Self, D>,
        _strategy: &QuantizationStrategy,

@@ -28,4 +35,18 @@ impl<B: Backend, C: CheckpointStrategy> QTensorOps<Self> for Autodiff<B, C> {
    fn q_device<const D: usize>(tensor: &QuantizedTensor<Self, D>) -> Device<Self> {
        B::q_device(tensor)
    }

    fn q_reshape<const D1: usize, const D2: usize>(
        tensor: QuantizedTensor<Self, D1>,
        shape: Shape<D2>,
    ) -> QuantizedTensor<Self, D2> {
        B::q_reshape(tensor, shape)
    }

    async fn q_into_data<const D: usize>(
        tensor: QuantizedTensor<Self, D>,
        strategy: QuantizationStrategy,
    ) -> TensorData {
        B::q_into_data(tensor, strategy).await
    }
}

@@ -1,7 +1,7 @@
use burn_tensor::{
    backend::Backend,
    ops::{FloatTensor, QTensorOps, QuantizedTensor},
    Device, QuantizationStrategy, Shape,
    DType, Device, QuantizationStrategy, Shape, TensorData,
};

use crate::{

@@ -10,6 +10,13 @@ use crate::{
};

impl<F: FloatCandleElement, I: IntCandleElement> QTensorOps<Self> for Candle<F, I> {
    fn q_from_data<const D: usize>(
        data: TensorData,
        device: &Device<Self>,
    ) -> QuantizedTensor<Self, D> {
        unimplemented!() // no i8 support
    }

    fn quantize<const D: usize>(
        _tensor: FloatTensor<Self, D>,
        _strategy: &QuantizationStrategy,

@@ -31,4 +38,18 @@ impl<F: FloatCandleElement, I: IntCandleElement> QTensorOps<Self> for Candle<F,
    fn q_device<const D: usize>(tensor: &QuantizedTensor<Self, D>) -> Device<Self> {
        super::base::device(tensor)
    }

    fn q_reshape<const D1: usize, const D2: usize>(
        tensor: QuantizedTensor<Self, D1>,
        shape: Shape<D2>,
    ) -> QuantizedTensor<Self, D2> {
        super::base::reshape(tensor, shape)
    }

    async fn q_into_data<const D: usize>(
        tensor: QuantizedTensor<Self, D>,
        strategy: QuantizationStrategy,
    ) -> TensorData {
        super::base::into_data(tensor)
    }
}

@@ -33,6 +33,9 @@ pub mod module;
/// Neural network module.
pub mod nn;

/// Quantization module.
pub mod quantization;

/// Module for the recorder.
pub mod record;

@@ -1,5 +1,6 @@
use super::ParamId;
use crate::{
    quantization::{Calibration, Quantizer},
    record::Record,
    tensor::backend::{AutodiffBackend, Backend},
};

@@ -202,6 +203,11 @@ pub trait Module<B: Backend>: Clone + Send + core::fmt::Debug {

        Ok(self.load_record(record))
    }

    /// Quantize the weights of the module.
    fn quantize_weights<C: Calibration>(self, quantizer: &mut Quantizer<C>) -> Self {
        self.map(quantizer)
    }
}

/// Module visitor trait.

@@ -0,0 +1,80 @@
use burn_tensor::{
    backend::Backend, AffineQuantization, ElementConversion, Quantization, QuantizationStrategy,
    SymmetricQuantization, Tensor,
};

use super::{QuantizationScheme, QuantizationType};

/// Calibration method used to compute the quantization range mapping.
pub trait Calibration {
    /// Configure the quantization strategy.
    fn configure<B: Backend, const D: usize>(&self, tensor: &Tensor<B, D>) -> QuantizationStrategy;
}

/// Computes the quantization range mapping based on the running min and max values.
pub struct MinMaxCalibration {
    /// Quantization scheme to be used.
    pub scheme: QuantizationScheme,
}

impl Calibration for MinMaxCalibration {
    fn configure<B: Backend, const D: usize>(&self, tensor: &Tensor<B, D>) -> QuantizationStrategy {
        let min = tensor.clone().min().into_scalar().elem::<f32>();
        let max = tensor.clone().max().into_scalar().elem::<f32>();

        match &self.scheme {
            QuantizationScheme::PerTensorAffine(dtype) => match dtype {
                QuantizationType::QInt8 => {
                    QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::new(min, max))
                }
            },
            QuantizationScheme::PerTensorSymmetric(dtype) => match dtype {
                QuantizationType::QInt8 => QuantizationStrategy::PerTensorSymmetricInt8(
                    SymmetricQuantization::new(min, max),
                ),
            },
        }
    }
}

#[cfg(test)]
mod tests {

    use super::*;
    use crate::TestBackend;

    #[test]
    fn min_max_calibration_per_tensor_affine_int8() {
        let device = <TestBackend as Backend>::Device::default();
        let tensor = Tensor::<TestBackend, 1>::from_floats([-1.8, -1.0, 0.0, 0.5], &device);
        let calibration = MinMaxCalibration {
            scheme: QuantizationScheme::PerTensorAffine(QuantizationType::QInt8),
        };

        let strategy = calibration.configure(&tensor);

        if let QuantizationStrategy::PerTensorAffineInt8(q) = strategy {
            assert_eq!(q.scale, 0.009_019_608);
            assert_eq!(q.offset, 72);
        } else {
            panic!("Wrong quantization strategy");
        }
    }

    #[test]
    fn min_max_calibration_per_tensor_symmetric_int8() {
        let device = <TestBackend as Backend>::Device::default();
        let tensor = Tensor::<TestBackend, 1>::from_floats([-1.8, -1.0, 0.0, 0.5], &device);
        let calibration = MinMaxCalibration {
            scheme: QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8),
        };

        let strategy = calibration.configure(&tensor);

        if let QuantizationStrategy::PerTensorSymmetricInt8(q) = strategy {
            assert_eq!(q.scale, 0.014_173_228);
        } else {
            panic!("Wrong quantization strategy");
        }
    }
}

@@ -0,0 +1,7 @@
mod calibration;
mod quantize;
mod scheme;

pub use calibration::*;
pub use quantize::*;
pub use scheme::*;

@@ -0,0 +1,18 @@
use burn_tensor::{backend::Backend, Tensor};

use crate::module::{ModuleMapper, ParamId};

use super::Calibration;

/// Describes how to quantize a module.
pub struct Quantizer<C: Calibration> {
    /// The calibration method used in quantization.
    pub calibration: C,
}

impl<B: Backend, C: Calibration> ModuleMapper<B> for Quantizer<C> {
    fn map_float<const D: usize>(&mut self, _id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
        let strategy = self.calibration.configure(&tensor);
        tensor.quantize(strategy)
    }
}

@@ -0,0 +1,17 @@
/// Quantization data type.
pub enum QuantizationType {
    /// 8-bit signed integer.
    QInt8,
}

/// Quantization scheme.
pub enum QuantizationScheme {
    /// Per-tensor affine/asymmetric quantization.
    PerTensorAffine(QuantizationType),
    /// Per-tensor symmetric quantization.
    PerTensorSymmetric(QuantizationType),
    // /// Per-channel affine/asymmetric quantization.
    // PerChannelAffine,
    // /// Per-channel symmetric quantization.
    // PerChannelSymmetric,
}

@@ -1,7 +1,7 @@
use core::marker::PhantomData;

use super::{PrecisionSettings, Record};
use burn_tensor::{backend::Backend, Bool, Element, Int, Tensor, TensorData};
use burn_tensor::{backend::Backend, Bool, DType, Element, Int, Tensor, TensorData};
use serde::{Deserialize, Serialize};

#[cfg(not(feature = "record-backward-compat"))]

@@ -43,7 +43,12 @@ where
            e
        ))
    })?;
    Ok(data.convert::<E>())
    let data = if let DType::QFloat(_) = data.dtype {
        data // do not convert quantized tensors
    } else {
        data.convert::<E>()
    };
    Ok(data)
}
}

@@ -137,15 +142,25 @@ impl<B: Backend, const D: usize> Record<B> for Tensor<B, D> {
    type Item<S: PrecisionSettings> = FloatTensorSerde<S>;

    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
        FloatTensorSerde::new(self.into_data().convert::<S::FloatElem>())
        let data = self.into_data();
        let data = if let DType::QFloat(_) = data.dtype {
            data // do not convert quantized tensors
        } else {
            data.convert::<S::FloatElem>()
        };
        FloatTensorSerde::new(data)
    }

    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
        Tensor::from_data(item.data.convert::<B::FloatElem>(), device)
        let data = if let DType::QFloat(_) = item.data.dtype {
            item.data // do not convert quantized tensors
        } else {
            item.data.convert::<B::FloatElem>()
        };
        Tensor::from_data(data, device)
    }
}

#[allow(deprecated)]
impl<B: Backend, const D: usize> Record<B> for Tensor<B, D, Int> {
    type Item<S: PrecisionSettings> = IntTensorSerde<S>;

@@ -158,7 +173,6 @@ impl<B: Backend, const D: usize> Record<B> for Tensor<B, D, Int> {
    }
}

#[allow(deprecated)]
impl<B: Backend, const D: usize> Record<B> for Tensor<B, D, Bool> {
    type Item<S: PrecisionSettings> = BoolTensorSerde;

@@ -1,12 +1,19 @@
use burn_tensor::{
    backend::Backend,
    ops::{QTensorOps, QuantizedTensor},
    Device, QuantizationStrategy, Shape,
    Device, QuantizationStrategy, Shape, TensorData,
};

use crate::{client::FusionClient, Fusion, FusionBackend};

impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
    fn q_from_data<const D: usize>(
        _data: TensorData,
        _device: &Device<Self>,
    ) -> QuantizedTensor<Self, D> {
        unimplemented!()
    }

    fn quantize<const D: usize>(
        _tensor: <Self as Backend>::FloatTensorPrimitive<D>,
        _strategy: &QuantizationStrategy,

@@ -28,4 +35,18 @@ impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
    fn q_device<const D: usize>(tensor: &QuantizedTensor<Self, D>) -> Device<Self> {
        tensor.client.device().clone()
    }

    fn q_reshape<const D1: usize, const D2: usize>(
        _tensor: QuantizedTensor<Self, D1>,
        _shape: Shape<D2>,
    ) -> QuantizedTensor<Self, D2> {
        unimplemented!()
    }

    async fn q_into_data<const D: usize>(
        _tensor: QuantizedTensor<Self, D>,
        _strategy: QuantizationStrategy,
    ) -> TensorData {
        unimplemented!()
    }
}

@@ -1,6 +1,6 @@
use burn_tensor::{
    ops::{FloatTensor, QTensorOps, QuantizedTensor},
    Device, QuantizationStrategy, Shape,
    Device, QuantizationStrategy, Shape, TensorData,
};

use crate::{FloatElement, IntElement, JitBackend, JitRuntime};

@@ -11,6 +11,13 @@ where
    F: FloatElement,
    I: IntElement,
{
    fn q_from_data<const D: usize>(
        _data: TensorData,
        _device: &Device<Self>,
    ) -> QuantizedTensor<Self, D> {
        todo!()
    }

    fn quantize<const D: usize>(
        _tensor: FloatTensor<Self, D>,
        _strategy: &QuantizationStrategy,

@@ -32,4 +39,18 @@ where
    fn q_device<const D: usize>(tensor: &QuantizedTensor<Self, D>) -> Device<Self> {
        tensor.device.clone()
    }

    fn q_reshape<const D1: usize, const D2: usize>(
        tensor: QuantizedTensor<Self, D1>,
        shape: Shape<D2>,
    ) -> QuantizedTensor<Self, D2> {
        super::reshape(tensor, shape)
    }

    async fn q_into_data<const D: usize>(
        _tensor: QuantizedTensor<Self, D>,
        _strategy: QuantizationStrategy,
    ) -> TensorData {
        unimplemented!()
    }
}

@@ -1,10 +1,12 @@
use burn_tensor::{
    ops::{FloatTensor, QTensorOps, QuantizedTensor},
    Quantization, QuantizationStrategy, Shape, TensorData,
    DType, Quantization, QuantizationStrategy, Shape, TensorData,
};

use crate::{element::NdArrayElement, FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayTensor};

use super::NdArrayOps;

fn into_data<E: NdArrayElement, const D: usize>(tensor: NdArrayTensor<E, D>) -> TensorData {
    let shape = tensor.shape();
    let values = tensor.array.into_iter().collect();

@@ -12,6 +14,28 @@ fn into_data<E: NdArrayElement, const D: usize>(tensor: NdArrayTensor<E, D>) ->
}

impl<E: FloatNdArrayElement> QTensorOps<Self> for NdArray<E> {
    fn q_from_data<const D: usize>(
        data: TensorData,
        _device: &NdArrayDevice,
    ) -> QuantizedTensor<Self, D> {
        match data.dtype {
            DType::QFloat(strategy) => match strategy {
                QuantizationStrategy::PerTensorAffineInt8(_) => {
                    let data = data.convert::<i8>();
                    NdArrayTensor::<i8, D>::from_data(data)
                }
                QuantizationStrategy::PerTensorSymmetricInt8(_) => {
                    let data = data.convert::<i8>();
                    NdArrayTensor::<i8, D>::from_data(data)
                }
            },
            _ => panic!(
                "Invalid dtype (expected DType::QFloat, got {:?})",
                data.dtype
            ),
        }
    }

    fn quantize<const D: usize>(
        tensor: FloatTensor<Self, D>,
        strategy: &QuantizationStrategy,

@@ -41,4 +65,20 @@ impl<E: FloatNdArrayElement> QTensorOps<Self> for NdArray<E> {
    fn q_device<const D: usize>(_tensor: &QuantizedTensor<Self, D>) -> NdArrayDevice {
        NdArrayDevice::Cpu
    }

    fn q_reshape<const D1: usize, const D2: usize>(
        tensor: QuantizedTensor<Self, D1>,
        shape: Shape<D2>,
    ) -> QuantizedTensor<Self, D2> {
        NdArrayOps::reshape(tensor, shape)
    }

    async fn q_into_data<const D: usize>(
        tensor: QuantizedTensor<Self, D>,
        strategy: QuantizationStrategy,
    ) -> TensorData {
        let shape = tensor.shape();
        let values = tensor.array.into_iter().collect();
        TensorData::quantized(values, shape, strategy)
    }
}

@@ -1,4 +1,4 @@
use burn_tensor::Shape;
use burn_tensor::{QuantizationStrategy, Shape};
use tch::Scalar;

use crate::{LibTorchDevice, TchShape, TchTensor};

@@ -512,4 +512,30 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
    ) -> TchTensor<i64, D> {
        TchTensor::new(tensor.tensor.argsort(dim as i64, descending))
    }

    pub fn quantize<const D: usize, I: tch::kind::Element>(
        tensor: TchTensor<E, D>,
        strategy: &QuantizationStrategy,
    ) -> TchTensor<I, D> {
        let mut tensor = tensor;
        // Quantize only works on Float Tensor
        if tensor.tensor.kind() == tch::Kind::Half {
            tensor.tensor = tensor.tensor.to_kind(tch::Kind::Float);
        }

        match strategy {
            QuantizationStrategy::PerTensorAffineInt8(ref q) => {
                TchTensor::new(tensor.tensor.quantize_per_tensor(
                    q.scale.into(),
                    q.offset.into(),
                    tch::Kind::QInt8,
                ))
            }
            QuantizationStrategy::PerTensorSymmetricInt8(ref q) => TchTensor::new(
                tensor
                    .tensor
                    .quantize_per_tensor(q.scale.into(), 0, tch::Kind::QInt8),
            ),
        }
    }
}

@@ -1,15 +1,63 @@
use burn_tensor::{
    ops::{FloatTensor, QTensorOps, QuantizedTensor},
    QuantizationStrategy, Shape,
    DType, Quantization, QuantizationStrategy, Shape, TensorData,
};

use crate::{LibTorch, LibTorchDevice, TchElement, TchTensor};
use crate::{LibTorch, LibTorchDevice, TchElement, TchShape, TchTensor};

use super::TchOps;

impl<E: TchElement> QTensorOps<Self> for LibTorch<E> {
    fn q_from_data<const D: usize>(
        data: TensorData,
        device: &LibTorchDevice,
    ) -> QuantizedTensor<Self, D> {
        let shape_tch = TchShape::<D>::from(data.shape.as_slice());
        let device = (*device).into();

        // NOTE: tch-rs doesn't have `from_blob_quantized_*` APIs
        // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/quantized/Quantizer.cpp#L322
        // So for now we have to load the dequantized values to quantize them back since the dequantization
        // methods take the values provided when quantizing.
        let tensor = match data.dtype {
            DType::QFloat(strategy) => match strategy {
                QuantizationStrategy::PerTensorAffineInt8(q) => {
                    let values = q.dequantize(&data.iter::<i8>().collect::<Vec<_>>());
                    let tensor = tch::Tensor::from_slice(&values).to(device);
                    TchOps::<E>::quantize::<D, i8>(
                        TchTensor::new(tensor.reshape(shape_tch.dims)),
                        &strategy,
                    )
                    .tensor
                }
                QuantizationStrategy::PerTensorSymmetricInt8(q) => {
                    let values = q.dequantize(&data.iter::<i8>().collect::<Vec<_>>());
                    let tensor = tch::Tensor::from_slice(&values).to(device);
                    TchOps::<E>::quantize::<D, i8>(
                        TchTensor::new(tensor.reshape(shape_tch.dims)),
                        &strategy,
                    )
                    .tensor
                }
            },
            _ => panic!(
                "Invalid dtype (expected DType::QFloat, got {:?})",
                data.dtype
            ),
        };
        TchTensor::new(tensor)
    }

    fn quantize<const D: usize>(
        tensor: FloatTensor<Self, D>,
        strategy: &QuantizationStrategy,
    ) -> QuantizedTensor<Self, D> {
        let mut tensor = tensor;
        // Quantize only works on Float Tensor
        if E::dtype() == DType::F16 {
            tensor.tensor = tensor.tensor.to_kind(tch::Kind::Float);
        }

        match strategy {
            QuantizationStrategy::PerTensorAffineInt8(ref q) => {
                TchTensor::new(tensor.tensor.quantize_per_tensor(

@@ -30,7 +78,7 @@ impl<E: TchElement> QTensorOps<Self> for LibTorch<E> {
        tensor: QuantizedTensor<Self, D>,
        _strategy: &QuantizationStrategy,
    ) -> FloatTensor<Self, D> {
        TchTensor::new(tensor.tensor.dequantize())
        TchTensor::new(tensor.tensor.dequantize().to_kind(E::KIND))
    }

    fn q_shape<const D: usize>(tensor: &QuantizedTensor<Self, D>) -> Shape<D> {

@@ -40,4 +88,23 @@ impl<E: TchElement> QTensorOps<Self> for LibTorch<E> {
    fn q_device<const D: usize>(tensor: &QuantizedTensor<Self, D>) -> LibTorchDevice {
        tensor.tensor.device().into()
    }

    fn q_reshape<const D1: usize, const D2: usize>(
        tensor: QuantizedTensor<Self, D1>,
        shape: Shape<D2>,
    ) -> QuantizedTensor<Self, D2> {
        TchOps::reshape(tensor, shape)
    }

    async fn q_into_data<const D: usize>(
        tensor: QuantizedTensor<Self, D>,
        strategy: QuantizationStrategy,
    ) -> TensorData {
        let shape = Self::q_shape(&tensor);
        let tensor = Self::q_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
        // To get the integer values we have to call `int_repr()`
        let values: Result<Vec<i8>, tch::TchError> = tensor.tensor.int_repr().try_into();

        TensorData::quantized(values.unwrap(), shape, strategy)
    }
}

@@ -18,7 +18,7 @@ use crate::check::TensorCheck;
use crate::tensor::api::chunk::chunk;
use crate::tensor::api::narrow::narrow;
use crate::{backend::Backend, check, Bool, Float, Int, Shape, TensorData, TensorKind};
use crate::{Element, TensorPrimitive};
use crate::{DType, Element, TensorPrimitive};

/// A tensor with a given backend, shape and data type.
#[derive(new, Clone, Debug)]

@@ -1697,7 +1697,15 @@ impl<B: Backend> BasicOps<B> for Float {
        tensor: Self::Primitive<D1>,
        shape: Shape<D2>,
    ) -> Self::Primitive<D2> {
        TensorPrimitive::Float(B::float_reshape(tensor.tensor(), shape))
        match tensor {
            TensorPrimitive::Float(tensor) => {
                TensorPrimitive::Float(B::float_reshape(tensor, shape))
            }
            TensorPrimitive::QFloat { tensor, strategy } => TensorPrimitive::QFloat {
                tensor: B::q_reshape(tensor, shape),
                strategy,
            },
        }
    }

    fn transpose<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<D> {

@@ -1750,11 +1758,20 @@ impl<B: Backend> BasicOps<B> for Float {
    }

    async fn into_data_async<const D: usize>(tensor: Self::Primitive<D>) -> TensorData {
        B::float_into_data(tensor.tensor()).await
        match tensor {
            TensorPrimitive::Float(tensor) => B::float_into_data(tensor).await,
            TensorPrimitive::QFloat { tensor, strategy } => B::q_into_data(tensor, strategy).await,
        }
    }

    fn from_data<const D: usize>(data: TensorData, device: &B::Device) -> Self::Primitive<D> {
        TensorPrimitive::Float(B::float_from_data(data, device))
        match data.dtype {
            DType::QFloat(strategy) => TensorPrimitive::QFloat {
                tensor: B::q_from_data(data, device),
                strategy,
            },
            _ => TensorPrimitive::Float(B::float_from_data(data, device)),
        }
    }

    fn repeat<const D: usize>(

@@ -271,9 +271,9 @@ where
        match &self.primitive {
            TensorPrimitive::Float(tensor) => B::float_is_require_grad(tensor),
            TensorPrimitive::QFloat {
                tensor: _,
                tensor,
                strategy: _,
            } => B::float_is_require_grad(&self.primitive.clone().tensor()),
            } => B::q_is_require_grad(tensor),
        }
    }

@@ -282,10 +282,16 @@ where
    ///
    /// This function does nothing when autodiff is not enabled.
    pub fn set_require_grad(self, require_grad: bool) -> Self {
        Self::new(TensorPrimitive::Float(B::float_set_require_grad(
            self.primitive.tensor(),
            require_grad,
        )))
        let primitive = match self.primitive {
            TensorPrimitive::Float(tensor) => {
                TensorPrimitive::Float(B::float_set_require_grad(tensor, require_grad))
            }
            TensorPrimitive::QFloat { tensor, strategy } => TensorPrimitive::QFloat {
                tensor: B::q_set_require_grad(tensor, require_grad),
                strategy,
            },
        };
        Self::new(primitive)
    }

    /// Applies the relu function to the tensor.

@@ -48,6 +48,15 @@ impl TensorData {
        Self::init(value, shape, E::dtype())
    }

    /// Creates a new quantized tensor data structure.
    pub fn quantized<E: Element, S: Into<Vec<usize>>>(
        value: Vec<E>,
        shape: S,
        strategy: QuantizationStrategy,
    ) -> Self {
        Self::init(value, shape, DType::QFloat(strategy))
    }

    /// Initializes a new tensor data structure from the provided values.
    fn init<E: Element, S: Into<Vec<usize>>>(value: Vec<E>, shape: S, dtype: DType) -> Self {
        Self {

@@ -258,15 +267,15 @@ impl TensorData {
            "Only f32 data type can be quantized"
        );
        match &quantization {
            QuantizationStrategy::PerTensorAffineInt8(strategy) => TensorData::init(
            QuantizationStrategy::PerTensorAffineInt8(strategy) => TensorData::quantized(
                strategy.quantize(self.as_slice().unwrap()),
                self.shape,
                DType::QFloat(quantization),
                quantization,
            ),
            QuantizationStrategy::PerTensorSymmetricInt8(strategy) => TensorData::init(
            QuantizationStrategy::PerTensorSymmetricInt8(strategy) => TensorData::quantized(
                strategy.quantize(self.as_slice().unwrap()),
                self.shape,
                DType::QFloat(quantization),
                quantization,
            ),
        }
    }

@@ -1,10 +1,24 @@
use crate::{backend::Backend, Device, QuantizationStrategy, Shape};
use core::future::Future;

use crate::{backend::Backend, Device, QuantizationStrategy, Shape, TensorData};

use super::{FloatTensor, QuantizedTensor};

/// Quantized Tensor API for basic operations, see [tensor](crate::Tensor)
/// for documentation on each function.
pub trait QTensorOps<B: Backend> {
    /// Creates a new tensor from the data structure.
    ///
    /// # Arguments
    ///
    /// * `data` - The data structure.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// The tensor with the given data.
    fn q_from_data<const D: usize>(data: TensorData, device: &Device<B>) -> QuantizedTensor<B, D>;

    /// Convert the tensor to a lower precision data type based on the quantization strategy.
    fn quantize<const D: usize>(
        tensor: FloatTensor<B, D>,

@@ -38,4 +52,48 @@ pub trait QTensorOps<B: Backend> {
    ///
    /// The device of the tensor.
    fn q_device<const D: usize>(tensor: &QuantizedTensor<B, D>) -> Device<B>;

    /// Reshapes a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reshape.
    /// * `shape` - The new shape of the tensor.
    ///
    /// # Returns
    ///
    /// The tensor with the new shape.
    fn q_reshape<const D1: usize, const D2: usize>(
        tensor: QuantizedTensor<B, D1>,
        shape: Shape<D2>,
    ) -> QuantizedTensor<B, D2>;

    /// Converts the tensor to a data structure.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The data structure with the tensor's data.
    fn q_into_data<const D: usize>(
        tensor: QuantizedTensor<B, D>,
        strategy: QuantizationStrategy,
    ) -> impl Future<Output = TensorData> + Send;

    /// Sets the `require_grad` flag of a tensor.
    fn q_set_require_grad<const D: usize>(
        tensor: QuantizedTensor<B, D>,
        _require_grad: bool,
    ) -> QuantizedTensor<B, D> {
        // Should only be overridden by autodiff backends.
        tensor
    }

    /// Returns the `require_grad` flag of a tensor.
    fn q_is_require_grad<const D: usize>(_tensor: &QuantizedTensor<B, D>) -> bool {
        // Should only be overridden by autodiff backends.
        false
    }
}

@@ -8,7 +8,7 @@ use burn_common::{iter_par, run_par};
use num_traits::{Float, PrimInt};
use serde::{Deserialize, Serialize};

/// Quantization scheme/strategy.
/// Quantization strategy.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationStrategy {
    /// Per-tensor `int8` affine/asymmetric quantization.

@@ -47,6 +47,18 @@
//! - Autodiff: Backend decorator that brings backpropagation to any backend
//! - Fusion: Backend decorator that brings kernel fusion to backends that support it
//!
//! # Quantization (Beta)
//!
//! Quantization techniques perform computations and store tensors in lower precision data types like 8-bit integer
//! instead of floating point precision. There are multiple approaches to quantize a deep learning model. In most cases,
//! the model is trained in floating point precision and later converted to the lower precision data type. This is called
//! post-training quantization (PTQ). On the other hand, quantization aware training (QAT) models the effects of quantization
//! during training. Quantization errors are thus modeled in the forward and backward passes, which helps the model learn
//! representations that are more robust to the reduction in precision.
//!
//! Quantization support in Burn is currently in active development. It supports the following modes on some backends:
//! - Static per-tensor quantization to signed 8-bit integer (`i8`)
//!
//! ## Feature Flags
//!
//! The following feature flags are available.