diff --git a/burn-book/src/SUMMARY.md b/burn-book/src/SUMMARY.md
index 3f4fcad94..6ee585637 100644
--- a/burn-book/src/SUMMARY.md
+++ b/burn-book/src/SUMMARY.md
@@ -25,6 +25,7 @@
- [ONNX Model](./import/onnx-model.md)
- [PyTorch Model](./import/pytorch-model.md)
- [Models & Pre-Trained Weights](./models-and-pretrained-weights.md)
+- [Quantization (Beta)](./quantization.md)
- [Advanced](./advanced/README.md)
- [Backend Extension](./advanced/backend-extension/README.md)
- [Custom WGPU Kernel](./advanced/backend-extension/custom-wgpu-kernel.md)
diff --git a/burn-book/src/quantization.md b/burn-book/src/quantization.md
new file mode 100644
index 000000000..f69f3f7db
--- /dev/null
+++ b/burn-book/src/quantization.md
@@ -0,0 +1,122 @@
+# Quantization (Beta)
+
+Quantization techniques perform computations and store tensors in lower precision data types, such
+as 8-bit integers, instead of the usual floating point precision. Approaches to quantizing a deep
+learning model fall into two categories:
+
+- Post-training quantization (PTQ)
+- Quantization aware training (QAT)
+
+In post-training quantization, the model is trained in floating point precision and later converted
+to the lower precision data type.
+
+There are two types of post-training quantization:
+
+1. Static quantization: quantizes both the weights and activations of the model. Statically
+   quantizing the activations requires calibration, i.e., running representative data through the
+   model and recording the activation values to compute the optimal quantization parameters.
+1. Dynamic quantization: quantizes the weights ahead of time (like static quantization), but the
+   activations are quantized dynamically at runtime.
+
+Sometimes post-training quantization is not able to achieve acceptable task accuracy. This is where
+quantization aware training comes into play, as it models the effects of quantization during
+training. Quantization errors are thus modeled in the forward and backward passes using fake
+quantization modules, which helps the model learn representations that are more robust to the
+reduction in precision.
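+
+As a rough illustration (QAT is not something Burn provides at this time), fake quantization
+amounts to a quantize-dequantize round trip in the forward pass, reusing the `quantize` and
+`dequantize` tensor methods described below. The `fake_quantize` helper is hypothetical:
+
+```rust, ignore
+# use burn::tensor::{backend::Backend, QuantizationStrategy, Tensor};
+#
+// Hypothetical fake quantization step: the quantize-dequantize round trip exposes
+// the rounding error to the forward pass. A real QAT implementation would also use
+// a straight-through estimator so gradients can flow through the rounding.
+fn fake_quantize<B: Backend, const D: usize>(
+    tensor: Tensor<B, D>,
+    strategy: QuantizationStrategy,
+) -> Tensor<B, D> {
+    tensor.quantize(strategy).dequantize()
+}
+```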
+
+<div class="warning">
+
+Quantization support in Burn is currently in active development.
+
+It supports the following modes on some backends:
+
+- Static per-tensor quantization to signed 8-bit integer (`i8`)
+
+No integer operations are currently supported, which means tensors are dequantized to perform the
+operations in floating point precision.
+
+</div>
+
+## Module Quantization
+
+Quantizing the weights of your model after training is quite simple. We have access to the weight
+tensors and can collect their statistics, such as the min and max value when using
+`MinMaxCalibration`, to compute the quantization parameters.
+
+```rust, ignore
+# use burn::quantization::{MinMaxCalibration, QuantizationScheme, QuantizationType, Quantizer};
+#
+// Quantization config
+let mut quantizer = Quantizer {
+ calibration: MinMaxCalibration {
+ scheme: QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8),
+ },
+};
+
+// Quantize the weights
+let model = model.quantize_weights(&mut quantizer);
+```
+
+> Given that all operations are currently performed in floating point precision, it might be wise to
+> dequantize the module parameters before inference. This allows us to save disk space by storing
+> the model in reduced precision while preserving the inference speed.
+>
+> This can easily be implemented with a `ModuleMapper`.
+>
+> ```rust, ignore
+> # use burn::module::{ModuleMapper, ParamId};
+> # use burn::tensor::{backend::Backend, Tensor};
+> #
+> /// Module mapper used to dequantize the model params being loaded.
+> pub struct Dequantize {}
+>
+> impl<B: Backend> ModuleMapper<B> for Dequantize {
+>     fn map_float<const D: usize>(
+>         &mut self,
+>         _id: &ParamId,
+>         tensor: Tensor<B, D>,
+>     ) -> Tensor<B, D> {
+>         tensor.dequantize()
+>     }
+> }
+>
+> // Load saved quantized model in floating point precision
+> model = model
+> .load_file(file_path, recorder, &device)
+> .expect("Should be able to load the quantized model weights")
+> .map(&mut Dequantize {});
+> ```
+
+### Calibration
+
+Calibration is the step during quantization where the range of all floating-point tensors is
+computed. This is pretty straightforward for weights since the actual range is known at
+_quantization-time_ (weights are static), but activations require more attention.
+
+To compute the quantization parameters, Burn supports the following `Calibration` methods.
+
+| Method | Description |
+| :------------------ | :------------------------------------------------------------------------------- |
+| `MinMaxCalibration` | Computes the quantization range mapping based on the running min and max values. |
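+
+For example, `MinMaxCalibration` can be used on its own to inspect the strategy it computes for a
+given tensor. A minimal sketch (the values mirror the crate's calibration unit tests):
+
+```rust, ignore
+# use burn::quantization::{Calibration, MinMaxCalibration, QuantizationScheme, QuantizationType};
+# use burn::tensor::Tensor;
+#
+let tensor = Tensor::<B, 1>::from_floats([-1.8, -1.0, 0.0, 0.5], &device);
+let calibration = MinMaxCalibration {
+    scheme: QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8),
+};
+
+// With min = -1.8 and max = 0.5, the symmetric scheme computes
+// scale = max(|min|, |max|) / 127 = 1.8 / 127 ≈ 0.0142 (no zero-point).
+let strategy = calibration.configure(&tensor);
+```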
+
+### Quantization Scheme
+
+A quantization scheme defines the quantized type, quantization granularity and range mapping
+technique.
+
+Burn currently supports the following `QuantizationType` variants.
+
+| Type | Description |
+| :------ | :--------------------------------- |
+| `QInt8` | 8-bit signed integer quantization. |
+
+Quantization parameters are defined based on the range of values to represent and can typically be
+calculated for the layer's entire weight tensor with per-tensor quantization or separately for each
+channel with per-channel quantization (commonly used with CNNs).
+
+Burn currently supports the following `QuantizationScheme` variants.
+
+| Variant | Description |
+| :------------------- | :------------------------------------------------------------------------------------------------------------- |
+| `PerTensorAffine` | Computes the quantization parameters for the whole tensor and applies an affine range mapping with zero point. |
+| `PerTensorSymmetric` | Computes the quantization parameters for the whole tensor and applies a scale range mapping centered around 0. |
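+
+As a concrete illustration, consider the tensor from the calibration example above, with
+`min = -1.8` and `max = 0.5`:
+
+- `PerTensorAffine` computes `scale = (max - min) / 255 ≈ 0.00902` and picks the zero-point so that
+  `min` maps to `-128`, giving `offset = 72`; the value `0.5` then maps to
+  `round(0.5 / 0.00902) + 72 = 127`.
+- `PerTensorSymmetric` computes `scale = max(|min|, |max|) / 127 ≈ 0.0142` with no zero-point; the
+  value `0.5` then maps to `round(0.5 / 0.0142) = 35`.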
diff --git a/crates/burn-autodiff/src/ops/qtensor.rs b/crates/burn-autodiff/src/ops/qtensor.rs
index 7dbbf5730..ac5a1a391 100644
--- a/crates/burn-autodiff/src/ops/qtensor.rs
+++ b/crates/burn-autodiff/src/ops/qtensor.rs
@@ -1,12 +1,19 @@
use burn_tensor::{
backend::Backend,
ops::{FloatTensor, QTensorOps, QuantizedTensor},
- Device, QuantizationStrategy, Shape,
+ Device, QuantizationStrategy, Shape, TensorData,
};
use crate::{checkpoint::strategy::CheckpointStrategy, Autodiff};
 impl<B: Backend, C: CheckpointStrategy> QTensorOps<Self> for Autodiff<B, C> {
+    fn q_from_data<const D: usize>(
+        _data: TensorData,
+        _device: &Device<Self>,
+    ) -> QuantizedTensor<Self, D> {
+        todo!()
+    }
+
     fn quantize<const D: usize>(
         _tensor: FloatTensor<Self, D>,
         _strategy: &QuantizationStrategy,
@@ -28,4 +35,18 @@ impl<B: Backend, C: CheckpointStrategy> QTensorOps<Self> for Autodiff<B, C> {
     fn q_device<const D: usize>(tensor: &QuantizedTensor<Self, D>) -> Device<Self> {
         B::q_device(tensor)
     }
+
+    fn q_reshape<const D1: usize, const D2: usize>(
+        tensor: QuantizedTensor<Self, D1>,
+        shape: Shape<D2>,
+    ) -> QuantizedTensor<Self, D2> {
+        B::q_reshape(tensor, shape)
+    }
+
+    async fn q_into_data<const D: usize>(
+        tensor: QuantizedTensor<Self, D>,
+        strategy: QuantizationStrategy,
+    ) -> TensorData {
+        B::q_into_data(tensor, strategy).await
+    }
 }
diff --git a/crates/burn-candle/src/ops/qtensor.rs b/crates/burn-candle/src/ops/qtensor.rs
index 3de0d23ff..7d9784cef 100644
--- a/crates/burn-candle/src/ops/qtensor.rs
+++ b/crates/burn-candle/src/ops/qtensor.rs
@@ -1,7 +1,7 @@
use burn_tensor::{
backend::Backend,
ops::{FloatTensor, QTensorOps, QuantizedTensor},
- Device, QuantizationStrategy, Shape,
+ DType, Device, QuantizationStrategy, Shape, TensorData,
};
use crate::{
@@ -10,6 +10,13 @@ use crate::{
};
 impl<F: FloatCandleElement, I: IntCandleElement> QTensorOps<Self> for Candle<F, I> {
+    fn q_from_data<const D: usize>(
+        _data: TensorData,
+        _device: &Device<Self>,
+    ) -> QuantizedTensor<Self, D> {
+        unimplemented!() // no i8 support
+    }
+
     fn quantize<const D: usize>(
         _tensor: FloatTensor<Self, D>,
         _strategy: &QuantizationStrategy,
@@ -31,4 +38,18 @@ impl<F: FloatCandleElement, I: IntCandleElement> QTensorOps<Self> for Candle<F, I>
     fn q_device<const D: usize>(tensor: &QuantizedTensor<Self, D>) -> Device<Self> {
         super::base::device(tensor)
     }
+
+    fn q_reshape<const D1: usize, const D2: usize>(
+        tensor: QuantizedTensor<Self, D1>,
+        shape: Shape<D2>,
+    ) -> QuantizedTensor<Self, D2> {
+        super::base::reshape(tensor, shape)
+    }
+
+    async fn q_into_data<const D: usize>(
+        tensor: QuantizedTensor<Self, D>,
+        _strategy: QuantizationStrategy,
+    ) -> TensorData {
+        super::base::into_data(tensor)
+    }
 }
diff --git a/crates/burn-core/src/lib.rs b/crates/burn-core/src/lib.rs
index 4bfb8f900..42c7ef3a3 100644
--- a/crates/burn-core/src/lib.rs
+++ b/crates/burn-core/src/lib.rs
@@ -33,6 +33,9 @@ pub mod module;
/// Neural network module.
pub mod nn;
+/// Quantization module.
+pub mod quantization;
+
/// Module for the recorder.
pub mod record;
diff --git a/crates/burn-core/src/module/base.rs b/crates/burn-core/src/module/base.rs
index c34c155e8..0c3895368 100644
--- a/crates/burn-core/src/module/base.rs
+++ b/crates/burn-core/src/module/base.rs
@@ -1,5 +1,6 @@
use super::ParamId;
use crate::{
+ quantization::{Calibration, Quantizer},
record::Record,
tensor::backend::{AutodiffBackend, Backend},
};
@@ -202,6 +203,11 @@ pub trait Module<B: Backend>: Clone + Send + core::fmt::Debug {
Ok(self.load_record(record))
}
+
+ /// Quantize the weights of the module.
+    fn quantize_weights<C: Calibration>(self, quantizer: &mut Quantizer<C>) -> Self {
+ self.map(quantizer)
+ }
}
/// Module visitor trait.
diff --git a/crates/burn-core/src/quantization/calibration.rs b/crates/burn-core/src/quantization/calibration.rs
new file mode 100644
index 000000000..81e466d20
--- /dev/null
+++ b/crates/burn-core/src/quantization/calibration.rs
@@ -0,0 +1,80 @@
+use burn_tensor::{
+ backend::Backend, AffineQuantization, ElementConversion, Quantization, QuantizationStrategy,
+ SymmetricQuantization, Tensor,
+};
+
+use super::{QuantizationScheme, QuantizationType};
+
+/// Calibration method used to compute the quantization range mapping.
+pub trait Calibration {
+ /// Configure the quantization strategy.
+    fn configure<B: Backend, const D: usize>(&self, tensor: &Tensor<B, D>) -> QuantizationStrategy;
+}
+
+/// Computes the quantization range mapping based on the running min and max values.
+pub struct MinMaxCalibration {
+ /// Quantization scheme to be used.
+ pub scheme: QuantizationScheme,
+}
+
+impl Calibration for MinMaxCalibration {
+    fn configure<B: Backend, const D: usize>(&self, tensor: &Tensor<B, D>) -> QuantizationStrategy {
+        let min = tensor.clone().min().into_scalar().elem::<f32>();
+        let max = tensor.clone().max().into_scalar().elem::<f32>();
+
+ match &self.scheme {
+ QuantizationScheme::PerTensorAffine(dtype) => match dtype {
+ QuantizationType::QInt8 => {
+ QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::new(min, max))
+ }
+ },
+ QuantizationScheme::PerTensorSymmetric(dtype) => match dtype {
+ QuantizationType::QInt8 => QuantizationStrategy::PerTensorSymmetricInt8(
+ SymmetricQuantization::new(min, max),
+ ),
+ },
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+
+ use super::*;
+ use crate::TestBackend;
+
+ #[test]
+ fn min_max_calibration_per_tensor_affine_int8() {
+        let device = <TestBackend as Backend>::Device::default();
+        let tensor = Tensor::<TestBackend, 1>::from_floats([-1.8, -1.0, 0.0, 0.5], &device);
+ let calibration = MinMaxCalibration {
+ scheme: QuantizationScheme::PerTensorAffine(QuantizationType::QInt8),
+ };
+
+ let strategy = calibration.configure(&tensor);
+
+ if let QuantizationStrategy::PerTensorAffineInt8(q) = strategy {
+ assert_eq!(q.scale, 0.009_019_608);
+ assert_eq!(q.offset, 72);
+ } else {
+ panic!("Wrong quantization strategy");
+ }
+ }
+
+ #[test]
+ fn min_max_calibration_per_tensor_symmetric_int8() {
+        let device = <TestBackend as Backend>::Device::default();
+        let tensor = Tensor::<TestBackend, 1>::from_floats([-1.8, -1.0, 0.0, 0.5], &device);
+ let calibration = MinMaxCalibration {
+ scheme: QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8),
+ };
+
+ let strategy = calibration.configure(&tensor);
+
+ if let QuantizationStrategy::PerTensorSymmetricInt8(q) = strategy {
+ assert_eq!(q.scale, 0.014_173_228);
+ } else {
+ panic!("Wrong quantization strategy");
+ }
+ }
+}
diff --git a/crates/burn-core/src/quantization/mod.rs b/crates/burn-core/src/quantization/mod.rs
new file mode 100644
index 000000000..38b3bb604
--- /dev/null
+++ b/crates/burn-core/src/quantization/mod.rs
@@ -0,0 +1,7 @@
+mod calibration;
+mod quantize;
+mod scheme;
+
+pub use calibration::*;
+pub use quantize::*;
+pub use scheme::*;
diff --git a/crates/burn-core/src/quantization/quantize.rs b/crates/burn-core/src/quantization/quantize.rs
new file mode 100644
index 000000000..040886a1a
--- /dev/null
+++ b/crates/burn-core/src/quantization/quantize.rs
@@ -0,0 +1,18 @@
+use burn_tensor::{backend::Backend, Tensor};
+
+use crate::module::{ModuleMapper, ParamId};
+
+use super::Calibration;
+
+/// Describes how to quantize a module.
+pub struct Quantizer<C: Calibration> {
+ /// The calibration method used in quantization.
+ pub calibration: C,
+}
+
+impl<B: Backend, C: Calibration> ModuleMapper<B> for Quantizer<C> {
+    fn map_float<const D: usize>(&mut self, _id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
+ let strategy = self.calibration.configure(&tensor);
+ tensor.quantize(strategy)
+ }
+}
diff --git a/crates/burn-core/src/quantization/scheme.rs b/crates/burn-core/src/quantization/scheme.rs
new file mode 100644
index 000000000..db11c22ae
--- /dev/null
+++ b/crates/burn-core/src/quantization/scheme.rs
@@ -0,0 +1,17 @@
+/// Quantization data type.
+pub enum QuantizationType {
+ /// 8-bit signed integer.
+ QInt8,
+}
+
+/// Quantization scheme.
+pub enum QuantizationScheme {
+ /// Per-tensor affine/asymmetric quantization.
+ PerTensorAffine(QuantizationType),
+ /// Per-tensor symmetric quantization.
+ PerTensorSymmetric(QuantizationType),
+ // /// Per-channel affine/asymmetric quantization.
+ // PerChannelAffine,
+ // /// Per-channel symmetric quantization.
+ // PerChannelSymmetric,
+}
diff --git a/crates/burn-core/src/record/tensor.rs b/crates/burn-core/src/record/tensor.rs
index 904210359..ab6f448b7 100644
--- a/crates/burn-core/src/record/tensor.rs
+++ b/crates/burn-core/src/record/tensor.rs
@@ -1,7 +1,7 @@
use core::marker::PhantomData;
use super::{PrecisionSettings, Record};
-use burn_tensor::{backend::Backend, Bool, Element, Int, Tensor, TensorData};
+use burn_tensor::{backend::Backend, Bool, DType, Element, Int, Tensor, TensorData};
use serde::{Deserialize, Serialize};
#[cfg(not(feature = "record-backward-compat"))]
@@ -43,7 +43,12 @@ where
e
))
})?;
-    Ok(data.convert::<E>())
+ let data = if let DType::QFloat(_) = data.dtype {
+ data // do not convert quantized tensors
+ } else {
+        data.convert::<E>()
+ };
+ Ok(data)
}
}
@@ -137,15 +142,25 @@ impl<B: Backend, const D: usize> Record<B> for Tensor<B, D> {
     type Item<S: PrecisionSettings> = FloatTensorSerde<S>;
 
     fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
-        FloatTensorSerde::new(self.into_data().convert::<S::FloatElem>())
+ let data = self.into_data();
+ let data = if let DType::QFloat(_) = data.dtype {
+ data // do not convert quantized tensors
+ } else {
+            data.convert::<S::FloatElem>()
+ };
+ FloatTensorSerde::new(data)
}
     fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
-        Tensor::from_data(item.data.convert::<B::FloatElem>(), device)
+ let data = if let DType::QFloat(_) = item.data.dtype {
+ item.data // do not convert quantized tensors
+ } else {
+            item.data.convert::<B::FloatElem>()
+ };
+ Tensor::from_data(data, device)
}
}
-#[allow(deprecated)]
 impl<B: Backend, const D: usize> Record<B> for Tensor<B, D, Int> {
     type Item<S: PrecisionSettings> = IntTensorSerde<S>;
@@ -158,7 +173,6 @@ impl<B: Backend, const D: usize> Record<B> for Tensor<B, D, Int> {
}
}
-#[allow(deprecated)]
 impl<B: Backend, const D: usize> Record<B> for Tensor<B, D, Bool> {
     type Item<S: PrecisionSettings> = BoolTensorSerde;
diff --git a/crates/burn-fusion/src/ops/qtensor.rs b/crates/burn-fusion/src/ops/qtensor.rs
index b0f340aa9..087f826fd 100644
--- a/crates/burn-fusion/src/ops/qtensor.rs
+++ b/crates/burn-fusion/src/ops/qtensor.rs
@@ -1,12 +1,19 @@
use burn_tensor::{
backend::Backend,
ops::{QTensorOps, QuantizedTensor},
- Device, QuantizationStrategy, Shape,
+ Device, QuantizationStrategy, Shape, TensorData,
};
use crate::{client::FusionClient, Fusion, FusionBackend};
 impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
+    fn q_from_data<const D: usize>(
+        _data: TensorData,
+        _device: &Device<Self>,
+    ) -> QuantizedTensor<Self, D> {
+        unimplemented!()
+    }
+
     fn quantize<const D: usize>(
         _tensor: <Self as Backend>::FloatTensorPrimitive<D>,
         _strategy: &QuantizationStrategy,
@@ -28,4 +35,18 @@ impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
     fn q_device<const D: usize>(tensor: &QuantizedTensor<Self, D>) -> Device<Self> {
         tensor.client.device().clone()
     }
+
+    fn q_reshape<const D1: usize, const D2: usize>(
+        _tensor: QuantizedTensor<Self, D1>,
+        _shape: Shape<D2>,
+    ) -> QuantizedTensor<Self, D2> {
+        unimplemented!()
+    }
+
+    async fn q_into_data<const D: usize>(
+        _tensor: QuantizedTensor<Self, D>,
+        _strategy: QuantizationStrategy,
+    ) -> TensorData {
+        unimplemented!()
+    }
 }
diff --git a/crates/burn-jit/src/ops/qtensor.rs b/crates/burn-jit/src/ops/qtensor.rs
index 76577c103..8c87c2d53 100644
--- a/crates/burn-jit/src/ops/qtensor.rs
+++ b/crates/burn-jit/src/ops/qtensor.rs
@@ -1,6 +1,6 @@
use burn_tensor::{
ops::{FloatTensor, QTensorOps, QuantizedTensor},
- Device, QuantizationStrategy, Shape,
+ Device, QuantizationStrategy, Shape, TensorData,
};
use crate::{FloatElement, IntElement, JitBackend, JitRuntime};
@@ -11,6 +11,13 @@ where
F: FloatElement,
I: IntElement,
{
+    fn q_from_data<const D: usize>(
+        _data: TensorData,
+        _device: &Device<Self>,
+    ) -> QuantizedTensor<Self, D> {
+        todo!()
+    }
+
     fn quantize<const D: usize>(
         _tensor: FloatTensor<Self, D>,
         _strategy: &QuantizationStrategy,
@@ -32,4 +39,18 @@ where
     fn q_device<const D: usize>(tensor: &QuantizedTensor<Self, D>) -> Device<Self> {
tensor.device.clone()
}
+
+    fn q_reshape<const D1: usize, const D2: usize>(
+        tensor: QuantizedTensor<Self, D1>,
+        shape: Shape<D2>,
+    ) -> QuantizedTensor<Self, D2> {
+        super::reshape(tensor, shape)
+    }
+
+    async fn q_into_data<const D: usize>(
+        _tensor: QuantizedTensor<Self, D>,
+        _strategy: QuantizationStrategy,
+    ) -> TensorData {
+        unimplemented!()
+    }
 }
diff --git a/crates/burn-ndarray/src/ops/qtensor.rs b/crates/burn-ndarray/src/ops/qtensor.rs
index f551c5b49..77274c321 100644
--- a/crates/burn-ndarray/src/ops/qtensor.rs
+++ b/crates/burn-ndarray/src/ops/qtensor.rs
@@ -1,10 +1,12 @@
use burn_tensor::{
ops::{FloatTensor, QTensorOps, QuantizedTensor},
- Quantization, QuantizationStrategy, Shape, TensorData,
+ DType, Quantization, QuantizationStrategy, Shape, TensorData,
};
use crate::{element::NdArrayElement, FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayTensor};
+use super::NdArrayOps;
+
 fn into_data<E: NdArrayElement, const D: usize>(tensor: NdArrayTensor<E, D>) -> TensorData {
let shape = tensor.shape();
let values = tensor.array.into_iter().collect();
@@ -12,6 +14,28 @@ fn into_data<E: NdArrayElement, const D: usize>(tensor: NdArrayTensor<E, D>) -> TensorData {
}
 impl<E: FloatNdArrayElement> QTensorOps<Self> for NdArray<E> {
+    fn q_from_data<const D: usize>(
+        data: TensorData,
+        _device: &NdArrayDevice,
+    ) -> QuantizedTensor<Self, D> {
+        match data.dtype {
+            DType::QFloat(strategy) => match strategy {
+                QuantizationStrategy::PerTensorAffineInt8(_) => {
+                    let data = data.convert::<i8>();
+                    NdArrayTensor::<i8, D>::from_data(data)
+                }
+                QuantizationStrategy::PerTensorSymmetricInt8(_) => {
+                    let data = data.convert::<i8>();
+                    NdArrayTensor::<i8, D>::from_data(data)
+                }
+            },
+            _ => panic!(
+                "Invalid dtype (expected DType::QFloat, got {:?})",
+                data.dtype
+            ),
+        }
+    }
+
fn quantize(
tensor: FloatTensor,
strategy: &QuantizationStrategy,
@@ -41,4 +65,20 @@ impl<E: FloatNdArrayElement> QTensorOps<Self> for NdArray<E> {
     fn q_device<const D: usize>(_tensor: &QuantizedTensor<Self, D>) -> NdArrayDevice {
NdArrayDevice::Cpu
}
+
+    fn q_reshape<const D1: usize, const D2: usize>(
+        tensor: QuantizedTensor<Self, D1>,
+        shape: Shape<D2>,
+    ) -> QuantizedTensor<Self, D2> {
+        NdArrayOps::reshape(tensor, shape)
+    }
+
+    async fn q_into_data<const D: usize>(
+        tensor: QuantizedTensor<Self, D>,
+        strategy: QuantizationStrategy,
+    ) -> TensorData {
+        let shape = tensor.shape();
+        let values = tensor.array.into_iter().collect();
+        TensorData::quantized(values, shape, strategy)
+    }
}
diff --git a/crates/burn-tch/src/ops/base.rs b/crates/burn-tch/src/ops/base.rs
index 6a09a0239..e557c4ae9 100644
--- a/crates/burn-tch/src/ops/base.rs
+++ b/crates/burn-tch/src/ops/base.rs
@@ -1,4 +1,4 @@
-use burn_tensor::Shape;
+use burn_tensor::{QuantizationStrategy, Shape};
use tch::Scalar;
use crate::{LibTorchDevice, TchShape, TchTensor};
@@ -512,4 +512,30 @@ impl<E: TchElement> TchOps<E> {
     ) -> TchTensor<i64, D> {
TchTensor::new(tensor.tensor.argsort(dim as i64, descending))
}
+
+    pub fn quantize<const D: usize>(
+        tensor: TchTensor<E, D>,
+        strategy: &QuantizationStrategy,
+    ) -> TchTensor<i8, D> {
+ let mut tensor = tensor;
+ // Quantize only works on Float Tensor
+ if tensor.tensor.kind() == tch::Kind::Half {
+ tensor.tensor = tensor.tensor.to_kind(tch::Kind::Float);
+ }
+
+ match strategy {
+ QuantizationStrategy::PerTensorAffineInt8(ref q) => {
+ TchTensor::new(tensor.tensor.quantize_per_tensor(
+ q.scale.into(),
+ q.offset.into(),
+ tch::Kind::QInt8,
+ ))
+ }
+ QuantizationStrategy::PerTensorSymmetricInt8(ref q) => TchTensor::new(
+ tensor
+ .tensor
+ .quantize_per_tensor(q.scale.into(), 0, tch::Kind::QInt8),
+ ),
+ }
+ }
}
diff --git a/crates/burn-tch/src/ops/qtensor.rs b/crates/burn-tch/src/ops/qtensor.rs
index 60d317576..73238e787 100644
--- a/crates/burn-tch/src/ops/qtensor.rs
+++ b/crates/burn-tch/src/ops/qtensor.rs
@@ -1,15 +1,63 @@
use burn_tensor::{
ops::{FloatTensor, QTensorOps, QuantizedTensor},
- QuantizationStrategy, Shape,
+ DType, Quantization, QuantizationStrategy, Shape, TensorData,
};
-use crate::{LibTorch, LibTorchDevice, TchElement, TchTensor};
+use crate::{LibTorch, LibTorchDevice, TchElement, TchShape, TchTensor};
+
+use super::TchOps;
 impl<E: TchElement> QTensorOps<Self> for LibTorch<E> {
+    fn q_from_data<const D: usize>(
+        data: TensorData,
+        device: &LibTorchDevice,
+    ) -> QuantizedTensor<Self, D> {
+        let shape_tch = TchShape::<D>::from(data.shape.as_slice());
+        let device = (*device).into();
+
+        // NOTE: tch-rs doesn't have `from_blob_quantized_*` APIs
+        // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/quantized/Quantizer.cpp#L322
+        // So for now we have to load the dequantized values to quantize them back, since the
+        // dequantization methods take the values provided when quantizing.
+        let tensor = match data.dtype {
+            DType::QFloat(strategy) => match strategy {
+                QuantizationStrategy::PerTensorAffineInt8(q) => {
+                    let values = q.dequantize(&data.iter::<i8>().collect::<Vec<_>>());
+                    let tensor = tch::Tensor::from_slice(&values).to(device);
+                    TchOps::<E>::quantize::<D>(
+                        TchTensor::new(tensor.reshape(shape_tch.dims)),
+                        &strategy,
+                    )
+                    .tensor
+                }
+                QuantizationStrategy::PerTensorSymmetricInt8(q) => {
+                    let values = q.dequantize(&data.iter::<i8>().collect::<Vec<_>>());
+                    let tensor = tch::Tensor::from_slice(&values).to(device);
+                    TchOps::<E>::quantize::<D>(
+                        TchTensor::new(tensor.reshape(shape_tch.dims)),
+                        &strategy,
+                    )
+                    .tensor
+                }
+ },
+ _ => panic!(
+ "Invalid dtype (expected DType::QFloat, got {:?})",
+ data.dtype
+ ),
+ };
+ TchTensor::new(tensor)
+ }
+
     fn quantize<const D: usize>(
         tensor: FloatTensor<Self, D>,
         strategy: &QuantizationStrategy,
     ) -> QuantizedTensor<Self, D> {
+ let mut tensor = tensor;
+ // Quantize only works on Float Tensor
+ if E::dtype() == DType::F16 {
+ tensor.tensor = tensor.tensor.to_kind(tch::Kind::Float);
+ }
+
match strategy {
QuantizationStrategy::PerTensorAffineInt8(ref q) => {
TchTensor::new(tensor.tensor.quantize_per_tensor(
@@ -30,7 +78,7 @@ impl<E: TchElement> QTensorOps<Self> for LibTorch<E> {
         tensor: QuantizedTensor<Self, D>,
         _strategy: &QuantizationStrategy,
     ) -> FloatTensor<Self, D> {
- TchTensor::new(tensor.tensor.dequantize())
+ TchTensor::new(tensor.tensor.dequantize().to_kind(E::KIND))
}
     fn q_shape<const D: usize>(tensor: &QuantizedTensor<Self, D>) -> Shape<D> {
@@ -40,4 +88,23 @@ impl<E: TchElement> QTensorOps<Self> for LibTorch<E> {
     fn q_device<const D: usize>(tensor: &QuantizedTensor<Self, D>) -> LibTorchDevice {
tensor.tensor.device().into()
}
+
+    fn q_reshape<const D1: usize, const D2: usize>(
+        tensor: QuantizedTensor<Self, D1>,
+        shape: Shape<D2>,
+    ) -> QuantizedTensor<Self, D2> {
+        TchOps::reshape(tensor, shape)
+    }
+
+    async fn q_into_data<const D: usize>(
+        tensor: QuantizedTensor<Self, D>,
+        strategy: QuantizationStrategy,
+    ) -> TensorData {
+        let shape = Self::q_shape(&tensor);
+        let tensor = Self::q_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
+        // To get the integer values we have to call `int_repr()`
+        let values: Result<Vec<i8>, tch::TchError> = tensor.tensor.int_repr().try_into();
+
+ TensorData::quantized(values.unwrap(), shape, strategy)
+ }
}
diff --git a/crates/burn-tensor/src/tensor/api/base.rs b/crates/burn-tensor/src/tensor/api/base.rs
index 0d44aa4fc..0fd16c151 100644
--- a/crates/burn-tensor/src/tensor/api/base.rs
+++ b/crates/burn-tensor/src/tensor/api/base.rs
@@ -18,7 +18,7 @@ use crate::check::TensorCheck;
use crate::tensor::api::chunk::chunk;
use crate::tensor::api::narrow::narrow;
use crate::{backend::Backend, check, Bool, Float, Int, Shape, TensorData, TensorKind};
-use crate::{Element, TensorPrimitive};
+use crate::{DType, Element, TensorPrimitive};
/// A tensor with a given backend, shape and data type.
#[derive(new, Clone, Debug)]
@@ -1697,7 +1697,15 @@ impl<B: Backend> BasicOps<B> for Float {
         tensor: Self::Primitive<D1>,
         shape: Shape<D2>,
     ) -> Self::Primitive<D2> {
- TensorPrimitive::Float(B::float_reshape(tensor.tensor(), shape))
+ match tensor {
+ TensorPrimitive::Float(tensor) => {
+ TensorPrimitive::Float(B::float_reshape(tensor, shape))
+ }
+ TensorPrimitive::QFloat { tensor, strategy } => TensorPrimitive::QFloat {
+ tensor: B::q_reshape(tensor, shape),
+ strategy,
+ },
+ }
}
     fn transpose<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<D> {
@@ -1750,11 +1758,20 @@ impl<B: Backend> BasicOps<B> for Float {
}
     async fn into_data_async<const D: usize>(tensor: Self::Primitive<D>) -> TensorData {
- B::float_into_data(tensor.tensor()).await
+ match tensor {
+ TensorPrimitive::Float(tensor) => B::float_into_data(tensor).await,
+ TensorPrimitive::QFloat { tensor, strategy } => B::q_into_data(tensor, strategy).await,
+ }
}
     fn from_data<const D: usize>(data: TensorData, device: &B::Device) -> Self::Primitive<D> {
- TensorPrimitive::Float(B::float_from_data(data, device))
+ match data.dtype {
+ DType::QFloat(strategy) => TensorPrimitive::QFloat {
+ tensor: B::q_from_data(data, device),
+ strategy,
+ },
+ _ => TensorPrimitive::Float(B::float_from_data(data, device)),
+ }
}
     fn repeat<const D: usize>(
diff --git a/crates/burn-tensor/src/tensor/api/float.rs b/crates/burn-tensor/src/tensor/api/float.rs
index 7da52eb1b..ef320d7b8 100644
--- a/crates/burn-tensor/src/tensor/api/float.rs
+++ b/crates/burn-tensor/src/tensor/api/float.rs
@@ -271,9 +271,9 @@ where
match &self.primitive {
TensorPrimitive::Float(tensor) => B::float_is_require_grad(tensor),
TensorPrimitive::QFloat {
- tensor: _,
+ tensor,
strategy: _,
- } => B::float_is_require_grad(&self.primitive.clone().tensor()),
+ } => B::q_is_require_grad(tensor),
}
}
@@ -282,10 +282,16 @@ where
///
/// This function does nothing when autodiff is not enabled.
pub fn set_require_grad(self, require_grad: bool) -> Self {
- Self::new(TensorPrimitive::Float(B::float_set_require_grad(
- self.primitive.tensor(),
- require_grad,
- )))
+ let primitive = match self.primitive {
+ TensorPrimitive::Float(tensor) => {
+ TensorPrimitive::Float(B::float_set_require_grad(tensor, require_grad))
+ }
+ TensorPrimitive::QFloat { tensor, strategy } => TensorPrimitive::QFloat {
+ tensor: B::q_set_require_grad(tensor, require_grad),
+ strategy,
+ },
+ };
+ Self::new(primitive)
}
/// Applies the relu function to the tensor.
diff --git a/crates/burn-tensor/src/tensor/data.rs b/crates/burn-tensor/src/tensor/data.rs
index f3c8b3843..edf293d74 100644
--- a/crates/burn-tensor/src/tensor/data.rs
+++ b/crates/burn-tensor/src/tensor/data.rs
@@ -48,6 +48,15 @@ impl TensorData {
Self::init(value, shape, E::dtype())
}
+ /// Creates a new quantized tensor data structure.
+    pub fn quantized<E: Element, S: Into<Vec<usize>>>(
+        value: Vec<E>,
+ shape: S,
+ strategy: QuantizationStrategy,
+ ) -> Self {
+ Self::init(value, shape, DType::QFloat(strategy))
+ }
+
/// Initializes a new tensor data structure from the provided values.
     fn init<E: Element, S: Into<Vec<usize>>>(value: Vec<E>, shape: S, dtype: DType) -> Self {
Self {
@@ -258,15 +267,15 @@ impl TensorData {
"Only f32 data type can be quantized"
);
match &quantization {
- QuantizationStrategy::PerTensorAffineInt8(strategy) => TensorData::init(
+ QuantizationStrategy::PerTensorAffineInt8(strategy) => TensorData::quantized(
strategy.quantize(self.as_slice().unwrap()),
self.shape,
- DType::QFloat(quantization),
+ quantization,
),
- QuantizationStrategy::PerTensorSymmetricInt8(strategy) => TensorData::init(
+ QuantizationStrategy::PerTensorSymmetricInt8(strategy) => TensorData::quantized(
strategy.quantize(self.as_slice().unwrap()),
self.shape,
- DType::QFloat(quantization),
+ quantization,
),
}
}
diff --git a/crates/burn-tensor/src/tensor/ops/qtensor.rs b/crates/burn-tensor/src/tensor/ops/qtensor.rs
index 9163f9df7..458c156f9 100644
--- a/crates/burn-tensor/src/tensor/ops/qtensor.rs
+++ b/crates/burn-tensor/src/tensor/ops/qtensor.rs
@@ -1,10 +1,24 @@
-use crate::{backend::Backend, Device, QuantizationStrategy, Shape};
+use core::future::Future;
+
+use crate::{backend::Backend, Device, QuantizationStrategy, Shape, TensorData};
use super::{FloatTensor, QuantizedTensor};
/// Quantized Tensor API for basic operations, see [tensor](crate::Tensor)
/// for documentation on each function.
 pub trait QTensorOps<B: Backend> {
+ /// Creates a new tensor from the data structure.
+ ///
+ /// # Arguments
+ ///
+ /// * `data` - The data structure.
+ /// * `device` - The device to create the tensor on.
+ ///
+ /// # Returns
+ ///
+ /// The tensor with the given data.
+    fn q_from_data<const D: usize>(data: TensorData, device: &Device<B>) -> QuantizedTensor<B, D>;
+
/// Convert the tensor to a lower precision data type based on the quantization strategy.
     fn quantize<const D: usize>(
         tensor: FloatTensor<B, D>,
@@ -38,4 +52,48 @@ pub trait QTensorOps<B: Backend> {
///
/// The device of the tensor.
     fn q_device<const D: usize>(tensor: &QuantizedTensor<B, D>) -> Device<B>;
+
+ /// Reshapes a tensor.
+ ///
+ /// # Arguments
+ ///
+ /// * `tensor` - The tensor to reshape.
+ /// * `shape` - The new shape of the tensor.
+ ///
+ /// # Returns
+ ///
+ /// The tensor with the new shape.
+    fn q_reshape<const D1: usize, const D2: usize>(
+        tensor: QuantizedTensor<B, D1>,
+        shape: Shape<D2>,
+    ) -> QuantizedTensor<B, D2>;
+
+ /// Converts the tensor to a data structure.
+ ///
+ /// # Arguments
+ ///
+ /// * `tensor` - The tensor.
+ ///
+ /// # Returns
+ ///
+ /// The data structure with the tensor's data.
+    fn q_into_data<const D: usize>(
+        tensor: QuantizedTensor<B, D>,
+        strategy: QuantizationStrategy,
+    ) -> impl Future<Output = TensorData> + Send;
+}