Feat/recorder/custom device (#1165)

This commit is contained in:
Nathaniel Simard 2024-01-23 13:05:41 -05:00 committed by GitHub
parent e9d1656687
commit eaa4dc3207
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
35 changed files with 490 additions and 267 deletions

View File

@ -37,10 +37,10 @@ pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MNISTItem)
let config = TrainingConfig::load(format!("{artifact_dir}/config.json"))
.expect("Config should exist for the model");
let record = CompactRecorder::new()
.load(format!("{artifact_dir}/model").into())
.load(format!("{artifact_dir}/model").into(), &device)
.expect("Trained model should exist");
let model = config.model.init_with::<B>(record).to_device(&device);
let model = config.model.init_with::<B>(record);
let label = item.label;
let batcher = MNISTBatcher::new(device);

View File

@ -23,7 +23,7 @@ Now that you have a trained model saved to your disk, you can easily load it in
// Load model in full precision from MessagePack file
let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
model
.load_file(model_path, &recorder)
.load_file(model_path, &recorder, device)
.expect("Should be able to load the model weights from the provided file");
```
@ -96,7 +96,7 @@ Afterwards, the model can just as easily be loaded from the record saved on disk
```rust, ignore
// Load model record on the backend's default device
let record: ModelRecord<MyBackend> = NamedMpkFileRecorder::<FullPrecisionSettings>::new()
.load(model_path.into())
.load(model_path.into(), device)
.expect("Should be able to load the model weights from the provided file");
// Directly initialize a new model with the loaded record/weights
@ -133,7 +133,7 @@ static MODEL_BYTES: &[u8] = include_bytes!("path/to/model.bin");
// Load model binary record in full precision
let record = BinBytesRecorder::<FullPrecisionSettings>::default()
.load(MODEL_BYTES.to_vec())
.load(MODEL_BYTES.to_vec(), device)
.expect("Should be able to load model the model weights from bytes");
// Load that record with the model

View File

@ -1,9 +1,11 @@
use burn_tensor::backend::Backend;
use crate::{record::Record, LearningRate};
/// Learning rate scheduler defines how the learning rate will evolve during training.
pub trait LrScheduler: Send + Sync {
pub trait LrScheduler<B: Backend>: Send + Sync {
/// Scheduler associative type to be used when saving and loading the state.
type Record: Record;
type Record: Record<B>;
/// Perform the scheduler step, potentially updating its state, and returning the effective
/// learning rate.

View File

@ -1,3 +1,5 @@
use burn_tensor::backend::Backend;
use super::LrScheduler;
use crate::LearningRate;
@ -17,7 +19,7 @@ impl From<LearningRate> for ConstantLr {
}
}
impl LrScheduler for ConstantLr {
impl<B: Backend> LrScheduler<B> for ConstantLr {
type Record = ();
fn step(&mut self) -> LearningRate {
@ -31,7 +33,7 @@ impl LrScheduler for ConstantLr {
}
}
impl LrScheduler for LearningRate {
impl<B: Backend> LrScheduler<B> for LearningRate {
type Record = ();
fn step(&mut self) -> LearningRate {

View File

@ -1,3 +1,5 @@
use burn_tensor::backend::Backend;
use crate as burn;
use super::LrScheduler;
@ -37,7 +39,7 @@ impl NoamLrSchedulerConfig {
}
}
impl LrScheduler for NoamLrScheduler {
impl<B: Backend> LrScheduler<B> for NoamLrScheduler {
type Record = usize;
fn step(&mut self) -> LearningRate {
@ -61,6 +63,8 @@ impl LrScheduler for NoamLrScheduler {
#[cfg(test)]
mod tests {
use crate::TestBackend;
use super::*;
#[test]
@ -72,7 +76,7 @@ mod tests {
let mut lr_current = 0.0;
for _ in 0..warmup_steps {
let lr = scheduler.step();
let lr = LrScheduler::<TestBackend>::step(&mut scheduler);
assert!(
lr > lr_current,
"Learning rate should increase before the warmup_steps is reached."
@ -81,7 +85,7 @@ mod tests {
}
for _ in 0..warmup_steps {
let lr = scheduler.step();
let lr = LrScheduler::<TestBackend>::step(&mut scheduler);
assert!(
lr < lr_current,
"Learning rate should decrease after the warmup_steps is reached."

View File

@ -82,7 +82,7 @@ macro_rules! module {
/// ```
pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
/// Type to save and load the module.
type Record: Record;
type Record: Record<B>;
/// Return all the devices found in the underneath module tree added to the given vector
/// without duplicates.
@ -164,11 +164,15 @@ pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
///
/// The file extension is automatically added depending on the file recorder provided, you
/// don't have to specify it.
fn save_file<FR: crate::record::FileRecorder, PB: Into<std::path::PathBuf>>(
fn save_file<FR, PB>(
self,
file_path: PB,
recorder: &FR,
) -> Result<(), crate::record::RecorderError> {
) -> Result<(), crate::record::RecorderError>
where
FR: crate::record::FileRecorder<B>,
PB: Into<std::path::PathBuf>,
{
let record = Self::into_record(self);
recorder.record(record, file_path.into())
}
@ -183,12 +187,17 @@ pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
///
/// The file extension is automatically added depending on the file recorder provided, you
/// don't have to specify it.
fn load_file<FR: crate::record::FileRecorder, PB: Into<std::path::PathBuf>>(
fn load_file<FR, PB>(
self,
file_path: PB,
recorder: &FR,
) -> Result<Self, crate::record::RecorderError> {
let record = recorder.load(file_path.into())?;
device: &B::Device,
) -> Result<Self, crate::record::RecorderError>
where
FR: crate::record::FileRecorder<B>,
PB: Into<std::path::PathBuf>,
{
let record = recorder.load(file_path.into(), device)?;
Ok(self.load_record(record))
}

View File

@ -34,14 +34,14 @@ impl<'de> serde::Deserialize<'de> for ConstantRecord {
}
}
impl Record for ConstantRecord {
impl<B: Backend> Record<B> for ConstantRecord {
type Item<S: PrecisionSettings> = ConstantRecord;
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
self
}
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, _device: &B::Device) -> Self {
item
}
}
@ -213,7 +213,7 @@ mod tests {
use core::marker::PhantomData;
use burn_tensor::backend::Backend;
use burn_tensor::Tensor;
use burn_tensor::{Device, Tensor};
use crate::TestBackend;
use crate::{
@ -226,21 +226,31 @@ mod tests {
#[test]
fn tensor_load_record_setting() {
let tensor = Tensor::<TestAutodiffBackend, 2>::ones([3, 3], &Default::default());
let device: &Device<TestAutodiffBackend> = &Default::default();
let tensor = Tensor::<TestAutodiffBackend, 2>::ones([3, 3], device);
let byte_recorder = BinBytesRecorder::<FullPrecisionSettings>::default();
let bytes = byte_recorder
.record(tensor.clone().into_record(), ())
.unwrap();
let bytes = Recorder::<TestAutodiffBackend>::record(
&byte_recorder,
tensor.clone().into_record(),
(),
)
.unwrap();
let no_grad_is_require_grad = tensor
.clone()
.no_grad()
.load_record(byte_recorder.load(bytes.clone()).unwrap())
.load_record(
Recorder::<TestAutodiffBackend>::load(&byte_recorder, bytes.clone(), device)
.unwrap(),
)
.is_require_grad();
let with_default_is_require_grad = tensor
.load_record(byte_recorder.load(bytes).unwrap())
.load_record(
Recorder::<TestAutodiffBackend>::load(&byte_recorder, bytes.clone(), device)
.unwrap(),
)
.is_require_grad();
assert!(!no_grad_is_require_grad);

View File

@ -228,7 +228,8 @@ mod tests {
#[test]
fn test_load_record_setting() {
let tensor = Tensor::<TestAutodiffBackend, 2>::ones([3, 3], &Default::default());
let device = Default::default();
let tensor = Tensor::<TestAutodiffBackend, 2>::ones([3, 3], &device);
let byte_recorder = BinBytesRecorder::<FullPrecisionSettings>::default();
let bytes = byte_recorder
@ -237,12 +238,12 @@ mod tests {
let no_grad_is_require_grad = Param::from(tensor.clone())
.no_grad()
.load_record(byte_recorder.load(bytes.clone()).unwrap())
.load_record(byte_recorder.load(bytes.clone(), &device).unwrap())
.value
.is_require_grad();
let with_default_is_require_grad = Param::from(tensor)
.load_record(byte_recorder.load(bytes).unwrap())
.load_record(byte_recorder.load(bytes, &device).unwrap())
.value
.is_require_grad();

View File

@ -11,7 +11,7 @@ where
B: AutodiffBackend,
{
/// Optimizer associative type to be used when saving and loading the state.
type Record: Record;
type Record: Record<B>;
/// Perform the optimizer step using the given learning rate and gradients.
/// The updated module is returned.

View File

@ -18,7 +18,7 @@ where
B: AutodiffBackend,
{
optim: O,
records: HashMap<ParamId, AdaptorRecord<O, B::InnerBackend>>,
records: HashMap<ParamId, AdaptorRecord<O, B>>,
module: PhantomData<M>,
grad_clipping: Option<GradientClipping>,
}
@ -71,7 +71,7 @@ where
M: AutodiffModule<B>,
O: SimpleOptimizer<B::InnerBackend>,
{
type Record = HashMap<ParamId, AdaptorRecord<O, B::InnerBackend>>;
type Record = HashMap<ParamId, AdaptorRecord<O, B>>;
fn step(&mut self, lr: LearningRate, module: M, mut grads: GradientsParams) -> M {
let mut mapper = SimpleOptimizerMapper::<M, B, O>::new(
@ -102,7 +102,7 @@ where
O: SimpleOptimizer<B::InnerBackend>,
{
optimizer: &'a O,
records: &'a mut HashMap<ParamId, AdaptorRecord<O, B::InnerBackend>>,
records: &'a mut HashMap<ParamId, AdaptorRecord<O, B>>,
grads: &'a mut GradientsParams,
lr: LearningRate,
phantom: PhantomData<M>,

View File

@ -11,7 +11,7 @@ where
B: Backend,
{
/// The state of the optimizer. It also implements [record](Record), so that it can be saved.
type State<const D: usize>: Record + Clone + 'static;
type State<const D: usize>: Record<B> + Clone + 'static;
/// The optimizer step is performed for one tensor at a time with its gradient and state.
///

View File

@ -3,29 +3,37 @@ use crate::{
optim::SimpleOptimizer,
record::{PrecisionSettings, Record},
};
use burn_tensor::backend::Backend;
use burn_tensor::backend::AutodiffBackend;
use serde::{Deserialize, Serialize};
/// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record.
///
/// Records are versioned for backward compatibility, so old records can be loaded.
pub enum AdaptorRecord<O: SimpleOptimizer<B>, B: Backend> {
pub enum AdaptorRecord<O, B>
where
O: SimpleOptimizer<B::InnerBackend>,
B: AutodiffBackend,
{
/// Version 1.
V1(AdaptorRecordV1<O, B>),
V1(AdaptorRecordV1<O, B::InnerBackend>),
}
/// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record item.
#[derive(Serialize, Deserialize)]
#[serde(bound = "")]
pub enum AdaptorRecordItem<O: SimpleOptimizer<B>, B: Backend, S: PrecisionSettings> {
pub enum AdaptorRecordItem<
O: SimpleOptimizer<B::InnerBackend>,
B: AutodiffBackend,
S: PrecisionSettings,
> {
/// Version 1.
V1(AdaptorRecordItemV1<O, B, S>),
V1(AdaptorRecordItemV1<O, B::InnerBackend, S>),
}
impl<O, B> Record for AdaptorRecord<O, B>
impl<O, B> Record<B> for AdaptorRecord<O, B>
where
O: SimpleOptimizer<B>,
B: Backend,
O: SimpleOptimizer<B::InnerBackend>,
B: AutodiffBackend,
{
type Item<S: PrecisionSettings> = AdaptorRecordItem<O, B, S>;
@ -35,17 +43,17 @@ where
}
}
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
match item {
AdaptorRecordItem::V1(item) => Self::V1(AdaptorRecordV1::from_item(item)),
AdaptorRecordItem::V1(item) => Self::V1(AdaptorRecordV1::from_item(item, device)),
}
}
}
impl<O, B> Clone for AdaptorRecord<O, B>
where
O: SimpleOptimizer<B>,
B: Backend,
O: SimpleOptimizer<B::InnerBackend>,
B: AutodiffBackend,
{
fn clone(&self) -> Self {
match self {
@ -56,8 +64,8 @@ where
impl<O, B> AdaptorRecord<O, B>
where
O: SimpleOptimizer<B>,
B: Backend,
O: SimpleOptimizer<B::InnerBackend>,
B: AutodiffBackend,
{
/// Converts the record into the optimizer state.
///

View File

@ -53,28 +53,28 @@ impl<O: SimpleOptimizer<B>, B: Backend> Clone for AdaptorRecordV1<O, B> {
#[serde(bound = "")]
pub enum AdaptorRecordItemV1<O: SimpleOptimizer<B>, B: Backend, S: PrecisionSettings> {
/// Rank 1.
Rank1(<O::State<1> as Record>::Item<S>),
Rank1(<O::State<1> as Record<B>>::Item<S>),
/// Rank 2.
Rank2(<O::State<2> as Record>::Item<S>),
Rank2(<O::State<2> as Record<B>>::Item<S>),
/// Rank 3.
Rank3(<O::State<3> as Record>::Item<S>),
Rank3(<O::State<3> as Record<B>>::Item<S>),
/// Rank 4.
Rank4(<O::State<4> as Record>::Item<S>),
Rank4(<O::State<4> as Record<B>>::Item<S>),
/// Rank 5.
Rank5(<O::State<5> as Record>::Item<S>),
Rank5(<O::State<5> as Record<B>>::Item<S>),
/// Rank 6.
Rank6(<O::State<6> as Record>::Item<S>),
Rank6(<O::State<6> as Record<B>>::Item<S>),
/// Rank 7.
Rank7(<O::State<7> as Record>::Item<S>),
Rank7(<O::State<7> as Record<B>>::Item<S>),
/// Rank 8.
Rank8(<O::State<8> as Record>::Item<S>),
Rank8(<O::State<8> as Record<B>>::Item<S>),
}
impl<O, B> AdaptorRecordV1<O, B>
@ -134,7 +134,7 @@ where
}
}
impl<O, B> Record for AdaptorRecordV1<O, B>
impl<O, B> Record<B> for AdaptorRecordV1<O, B>
where
O: SimpleOptimizer<B>,
B: Backend,
@ -154,31 +154,31 @@ where
}
}
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
match item {
AdaptorRecordItemV1::Rank1(item) => {
AdaptorRecordV1::Rank1(<O::State<1> as Record>::from_item(item))
AdaptorRecordV1::Rank1(<O::State<1> as Record<B>>::from_item(item, device))
}
AdaptorRecordItemV1::Rank2(item) => {
AdaptorRecordV1::Rank2(<O::State<2> as Record>::from_item(item))
AdaptorRecordV1::Rank2(<O::State<2> as Record<B>>::from_item(item, device))
}
AdaptorRecordItemV1::Rank3(item) => {
AdaptorRecordV1::Rank3(<O::State<3> as Record>::from_item(item))
AdaptorRecordV1::Rank3(<O::State<3> as Record<B>>::from_item(item, device))
}
AdaptorRecordItemV1::Rank4(item) => {
AdaptorRecordV1::Rank4(<O::State<4> as Record>::from_item(item))
AdaptorRecordV1::Rank4(<O::State<4> as Record<B>>::from_item(item, device))
}
AdaptorRecordItemV1::Rank5(item) => {
AdaptorRecordV1::Rank5(<O::State<5> as Record>::from_item(item))
AdaptorRecordV1::Rank5(<O::State<5> as Record<B>>::from_item(item, device))
}
AdaptorRecordItemV1::Rank6(item) => {
AdaptorRecordV1::Rank6(<O::State<6> as Record>::from_item(item))
AdaptorRecordV1::Rank6(<O::State<6> as Record<B>>::from_item(item, device))
}
AdaptorRecordItemV1::Rank7(item) => {
AdaptorRecordV1::Rank7(<O::State<7> as Record>::from_item(item))
AdaptorRecordV1::Rank7(<O::State<7> as Record<B>>::from_item(item, device))
}
AdaptorRecordItemV1::Rank8(item) => {
AdaptorRecordV1::Rank8(<O::State<8> as Record>::from_item(item))
AdaptorRecordV1::Rank8(<O::State<8> as Record<B>>::from_item(item, device))
}
}
}

View File

@ -1,10 +1,11 @@
pub use burn_derive::Record;
use burn_tensor::backend::Backend;
use super::PrecisionSettings;
use serde::{de::DeserializeOwned, Serialize};
/// Trait to define a family of types which can be recorded using any [settings](PrecisionSettings).
pub trait Record: Send + Sync {
pub trait Record<B: Backend>: Send + Sync {
/// Type of the item that can be serialized and deserialized.
type Item<S: PrecisionSettings>: Serialize + DeserializeOwned;
@ -12,5 +13,5 @@ pub trait Record: Send + Sync {
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S>;
/// Convert the given item into a record.
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self;
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self;
}

View File

@ -1,4 +1,5 @@
use super::{bin_config, PrecisionSettings, Recorder, RecorderError};
use burn_tensor::backend::Backend;
use core::marker::PhantomData;
use flate2::{read::GzDecoder, write::GzEncoder, Compression};
use serde::{de::DeserializeOwned, Serialize};
@ -6,8 +7,8 @@ use std::io::{BufReader, BufWriter};
use std::{fs::File, path::PathBuf};
/// Recorder trait specialized to save and load data to and from files.
pub trait FileRecorder:
Recorder<RecordArgs = PathBuf, RecordOutput = (), LoadArgs = PathBuf>
pub trait FileRecorder<B: Backend>:
Recorder<B, RecordArgs = PathBuf, RecordOutput = (), LoadArgs = PathBuf>
{
/// File extension of the format used by the recorder.
fn file_extension() -> &'static str;
@ -52,34 +53,34 @@ pub struct NamedMpkFileRecorder<S: PrecisionSettings> {
_settings: PhantomData<S>,
}
impl<S: PrecisionSettings> FileRecorder for BinGzFileRecorder<S> {
impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for BinGzFileRecorder<S> {
fn file_extension() -> &'static str {
"bin.gz"
}
}
impl<S: PrecisionSettings> FileRecorder for BinFileRecorder<S> {
impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for BinFileRecorder<S> {
fn file_extension() -> &'static str {
"bin"
}
}
impl<S: PrecisionSettings> FileRecorder for JsonGzFileRecorder<S> {
impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for JsonGzFileRecorder<S> {
fn file_extension() -> &'static str {
"json.gz"
}
}
impl<S: PrecisionSettings> FileRecorder for PrettyJsonFileRecorder<S> {
impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for PrettyJsonFileRecorder<S> {
fn file_extension() -> &'static str {
"json"
}
}
impl<S: PrecisionSettings> FileRecorder for NamedMpkGzFileRecorder<S> {
impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for NamedMpkGzFileRecorder<S> {
fn file_extension() -> &'static str {
"mpk.gz"
}
}
impl<S: PrecisionSettings> FileRecorder for NamedMpkFileRecorder<S> {
impl<S: PrecisionSettings, B: Backend> FileRecorder<B> for NamedMpkFileRecorder<S> {
fn file_extension() -> &'static str {
"mpk"
}
@ -89,7 +90,7 @@ macro_rules! str2reader {
(
$file:expr
) => {{
$file.set_extension(<Self as FileRecorder>::file_extension());
$file.set_extension(<Self as FileRecorder<B>>::file_extension());
let path = $file.as_path();
File::open(path)
@ -105,7 +106,7 @@ macro_rules! str2writer {
(
$file:expr
) => {{
$file.set_extension(<Self as FileRecorder>::file_extension());
$file.set_extension(<Self as FileRecorder<B>>::file_extension());
let path = $file.as_path();
if path.exists() {
@ -122,7 +123,7 @@ macro_rules! str2writer {
}};
}
impl<S: PrecisionSettings> Recorder for BinGzFileRecorder<S> {
impl<S: PrecisionSettings, B: Backend> Recorder<B> for BinGzFileRecorder<S> {
type Settings = S;
type RecordArgs = PathBuf;
type RecordOutput = ();
@ -153,7 +154,7 @@ impl<S: PrecisionSettings> Recorder for BinGzFileRecorder<S> {
}
}
impl<S: PrecisionSettings> Recorder for BinFileRecorder<S> {
impl<S: PrecisionSettings, B: Backend> Recorder<B> for BinFileRecorder<S> {
type Settings = S;
type RecordArgs = PathBuf;
type RecordOutput = ();
@ -179,7 +180,7 @@ impl<S: PrecisionSettings> Recorder for BinFileRecorder<S> {
}
}
impl<S: PrecisionSettings> Recorder for JsonGzFileRecorder<S> {
impl<S: PrecisionSettings, B: Backend> Recorder<B> for JsonGzFileRecorder<S> {
type Settings = S;
type RecordArgs = PathBuf;
type RecordOutput = ();
@ -208,7 +209,7 @@ impl<S: PrecisionSettings> Recorder for JsonGzFileRecorder<S> {
}
}
impl<S: PrecisionSettings> Recorder for PrettyJsonFileRecorder<S> {
impl<S: PrecisionSettings, B: Backend> Recorder<B> for PrettyJsonFileRecorder<S> {
type Settings = S;
type RecordArgs = PathBuf;
type RecordOutput = ();
@ -234,7 +235,7 @@ impl<S: PrecisionSettings> Recorder for PrettyJsonFileRecorder<S> {
}
}
impl<S: PrecisionSettings> Recorder for NamedMpkGzFileRecorder<S> {
impl<S: PrecisionSettings, B: Backend> Recorder<B> for NamedMpkGzFileRecorder<S> {
type Settings = S;
type RecordArgs = PathBuf;
type RecordOutput = ();
@ -263,7 +264,7 @@ impl<S: PrecisionSettings> Recorder for NamedMpkGzFileRecorder<S> {
}
}
impl<S: PrecisionSettings> Recorder for NamedMpkFileRecorder<S> {
impl<S: PrecisionSettings, B: Backend> Recorder<B> for NamedMpkFileRecorder<S> {
type Settings = S;
type RecordArgs = PathBuf;
type RecordOutput = ();
@ -346,14 +347,18 @@ mod tests {
test_can_save_and_load(NamedMpkFileRecorder::<FullPrecisionSettings>::default())
}
fn test_can_save_and_load<Recorder: FileRecorder>(recorder: Recorder) {
fn test_can_save_and_load<Recorder>(recorder: Recorder)
where
Recorder: FileRecorder<TestBackend>,
{
let device = Default::default();
let model_before = create_model(&device);
recorder
.record(model_before.clone().into_record(), file_path())
.unwrap();
let model_after = create_model(&device).load_record(recorder.load(file_path()).unwrap());
let model_after =
create_model(&device).load_record(recorder.load(file_path(), &device).unwrap());
let byte_recorder = BinBytesRecorder::<FullPrecisionSettings>::default();
let model_bytes_before = byte_recorder

View File

@ -1,5 +1,6 @@
use super::{bin_config, PrecisionSettings, Recorder, RecorderError};
use alloc::vec::Vec;
use burn_tensor::backend::Backend;
use serde::{de::DeserializeOwned, Serialize};
/// Recorder trait specialized to save and load data to and from bytes.
@ -8,8 +9,8 @@ use serde::{de::DeserializeOwned, Serialize};
///
/// This is especially useful in no_std environment where weights are stored directly in
/// compiled binaries.
pub trait BytesRecorder:
Recorder<RecordArgs = (), RecordOutput = Vec<u8>, LoadArgs = Vec<u8>>
pub trait BytesRecorder<B: Backend>:
Recorder<B, RecordArgs = (), RecordOutput = Vec<u8>, LoadArgs = Vec<u8>>
{
}
@ -19,9 +20,9 @@ pub struct BinBytesRecorder<S: PrecisionSettings> {
_settings: core::marker::PhantomData<S>,
}
impl<S: PrecisionSettings> BytesRecorder for BinBytesRecorder<S> {}
impl<S: PrecisionSettings, B: Backend> BytesRecorder<B> for BinBytesRecorder<S> {}
impl<S: PrecisionSettings> Recorder for BinBytesRecorder<S> {
impl<S: PrecisionSettings, B: Backend> Recorder<B> for BinBytesRecorder<S> {
type Settings = S;
type RecordArgs = ();
type RecordOutput = Vec<u8>;
@ -48,10 +49,10 @@ pub struct NamedMpkBytesRecorder<S: PrecisionSettings> {
}
#[cfg(feature = "std")]
impl<S: PrecisionSettings> BytesRecorder for NamedMpkBytesRecorder<S> {}
impl<S: PrecisionSettings, B: Backend> BytesRecorder<B> for NamedMpkBytesRecorder<S> {}
#[cfg(feature = "std")]
impl<S: PrecisionSettings> Recorder for NamedMpkBytesRecorder<S> {
impl<S: PrecisionSettings, B: Backend> Recorder<B> for NamedMpkBytesRecorder<S> {
type Settings = S;
type RecordArgs = ();
type RecordOutput = Vec<u8>;
@ -87,14 +88,17 @@ mod tests {
test_can_save_and_load(NamedMpkBytesRecorder::<FullPrecisionSettings>::default())
}
fn test_can_save_and_load<Recorder: BytesRecorder>(recorder: Recorder) {
fn test_can_save_and_load<Recorder>(recorder: Recorder)
where
Recorder: BytesRecorder<TestBackend>,
{
let device = Default::default();
let model1 = create_model::<TestBackend>(&device);
let model2 = create_model::<TestBackend>(&device);
let bytes1 = recorder.record(model1.into_record(), ()).unwrap();
let bytes2 = recorder.record(model2.clone().into_record(), ()).unwrap();
let model2_after = model2.load_record(recorder.load(bytes1.clone()).unwrap());
let model2_after = model2.load_record(recorder.load(bytes1.clone(), &device).unwrap());
let bytes2_after = recorder.record(model2_after.into_record(), ()).unwrap();
assert_ne!(bytes1, bytes2);

View File

@ -16,55 +16,76 @@ use crate::module::{Param, ParamId};
use burn_tensor::{DataSerialize, Element};
use hashbrown::HashMap;
impl Record for () {
impl<B> Record<B> for ()
where
B: Backend,
{
type Item<S: PrecisionSettings> = ();
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {}
fn from_item<S: PrecisionSettings>(_item: Self::Item<S>) -> Self {}
fn from_item<S: PrecisionSettings>(_item: Self::Item<S>, _device: &B::Device) -> Self {}
}
impl<T: Record> Record for Vec<T> {
impl<T, B> Record<B> for Vec<T>
where
T: Record<B>,
B: Backend,
{
type Item<S: PrecisionSettings> = Vec<T::Item<S>>;
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
self.into_iter().map(Record::into_item).collect()
}
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
item.into_iter().map(Record::from_item).collect()
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
item.into_iter()
.map(|i| Record::from_item(i, device))
.collect()
}
}
impl<T: Record> Record for Option<T> {
impl<T, B> Record<B> for Option<T>
where
T: Record<B>,
B: Backend,
{
type Item<S: PrecisionSettings> = Option<T::Item<S>>;
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
self.map(Record::into_item)
}
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
item.map(Record::from_item)
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
item.map(|i| Record::from_item(i, device))
}
}
impl<const N: usize, T: Record + core::fmt::Debug> Record for [T; N] {
impl<const N: usize, T, B> Record<B> for [T; N]
where
T: Record<B> + core::fmt::Debug,
B: Backend,
{
type Item<S: PrecisionSettings> = Vec<T::Item<S>>;
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
self.map(Record::into_item).into_iter().collect()
}
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
item.into_iter()
.map(Record::from_item)
.map(|i| Record::from_item(i, device))
.collect::<Vec<_>>()
.try_into()
.unwrap_or_else(|_| panic!("An arrar of size {N}"))
}
}
impl<T: Record> Record for HashMap<ParamId, T> {
impl<T, B> Record<B> for HashMap<ParamId, T>
where
T: Record<B>,
B: Backend,
{
type Item<S: PrecisionSettings> = HashMap<String, T::Item<S>>;
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
@ -75,23 +96,27 @@ impl<T: Record> Record for HashMap<ParamId, T> {
items
}
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
let mut record = HashMap::with_capacity(item.len());
item.into_iter().for_each(|(id, item)| {
record.insert(ParamId::from(id), T::from_item(item));
record.insert(ParamId::from(id), T::from_item(item, device));
});
record
}
}
impl<E: Element> Record for DataSerialize<E> {
impl<E, B> Record<B> for DataSerialize<E>
where
E: Element,
B: Backend,
{
type Item<S: PrecisionSettings> = DataSerialize<S::FloatElem>;
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
self.convert()
}
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, _device: &B::Device) -> Self {
item.convert()
}
}
@ -103,57 +128,72 @@ pub struct ParamSerde<T> {
param: T,
}
impl<B: Backend, const D: usize> Record for Param<Tensor<B, D>> {
impl<B, const D: usize> Record<B> for Param<Tensor<B, D>>
where
B: Backend,
{
type Item<S: PrecisionSettings> = ParamSerde<FloatTensorSerde<S>>;
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
ParamSerde::new(self.id.into_string(), self.value.into_item())
}
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
Param::new(
ParamId::from(item.id),
Tensor::from_item(item.param).require_grad(), // Same behavior as when we create a new
// Param from a tensor.
Tensor::from_item(item.param, device).require_grad(), // Same behavior as when we create a new
// Param from a tensor.
)
}
}
impl<B: Backend, const D: usize> Record for Param<Tensor<B, D, Int>> {
impl<B, const D: usize> Record<B> for Param<Tensor<B, D, Int>>
where
B: Backend,
{
type Item<S: PrecisionSettings> = ParamSerde<IntTensorSerde<S>>;
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
ParamSerde::new(self.id.into_string(), self.value.into_item())
}
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
Param::new(ParamId::from(item.id), Tensor::from_item(item.param))
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
Param::new(
ParamId::from(item.id),
Tensor::from_item(item.param, device),
)
}
}
impl<B: Backend, const D: usize> Record for Param<Tensor<B, D, Bool>> {
impl<B, const D: usize> Record<B> for Param<Tensor<B, D, Bool>>
where
B: Backend,
{
type Item<S: PrecisionSettings> = ParamSerde<BoolTensorSerde>;
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
ParamSerde::new(self.id.into_string(), self.value.into_item::<S>())
}
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
Param::new(ParamId::from(item.id), Tensor::from_item::<S>(item.param))
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
Param::new(
ParamId::from(item.id),
Tensor::from_item::<S>(item.param, device),
)
}
}
// Type that can be serialized as is without any conversion.
macro_rules! primitive {
($type:ty) => {
impl Record for $type {
impl<B: Backend> Record<B> for $type {
type Item<S: PrecisionSettings> = $type;
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
self
}
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, _device: &B::Device) -> Self {
item
}
}

View File

@ -1,7 +1,9 @@
use core::any::type_name;
use core::marker::PhantomData;
use alloc::format;
use alloc::string::{String, ToString};
use burn_tensor::backend::Backend;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use super::{BinBytesRecorder, FullPrecisionSettings, PrecisionSettings, Record};
@ -13,7 +15,9 @@ use super::{
};
/// Record any item implementing [Serialize](Serialize) and [DeserializeOwned](DeserializeOwned).
pub trait Recorder: Send + Sync + core::default::Default + core::fmt::Debug + Clone {
pub trait Recorder<B: Backend>:
Send + Sync + core::default::Default + core::fmt::Debug + Clone
{
/// Type of the settings used by the recorder.
type Settings: PrecisionSettings;
@ -36,11 +40,14 @@ pub trait Recorder: Send + Sync + core::default::Default + core::fmt::Debug + Cl
/// # Returns
///
/// The output of the recording.
fn record<R: Record>(
fn record<R>(
&self,
record: R,
args: Self::RecordArgs,
) -> Result<Self::RecordOutput, RecorderError> {
) -> Result<Self::RecordOutput, RecorderError>
where
R: Record<B>,
{
let item = record.into_item::<Self::Settings>();
let item = BurnRecord::new::<Self>(item);
@ -48,12 +55,15 @@ pub trait Recorder: Send + Sync + core::default::Default + core::fmt::Debug + Cl
}
/// Load an item from the given arguments.
fn load<R: Record>(&self, args: Self::LoadArgs) -> Result<R, RecorderError> {
let item: BurnRecord<R::Item<Self::Settings>> =
fn load<R>(&self, args: Self::LoadArgs, device: &B::Device) -> Result<R, RecorderError>
where
R: Record<B>,
{
let item: BurnRecord<R::Item<Self::Settings>, B> =
self.load_item(args.clone()).map_err(|err| {
if let Ok(record) = self.load_item::<BurnRecordNoItem>(args.clone()) {
let mut message = "Unable to load record.".to_string();
let metadata = recorder_metadata::<Self>();
let metadata = recorder_metadata::<Self, B>();
if metadata.float != record.metadata.float {
message += format!(
"\nMetadata has a different float type: Actual {:?}, Expected {:?}",
@ -91,7 +101,7 @@ pub trait Recorder: Send + Sync + core::default::Default + core::fmt::Debug + Cl
err
})?;
Ok(R::from_item(item.item))
Ok(R::from_item(item.item, device))
}
/// Saves an item.
@ -123,10 +133,16 @@ pub trait Recorder: Send + Sync + core::default::Default + core::fmt::Debug + Cl
/// # Returns
///
/// The loaded item.
fn load_item<I: DeserializeOwned>(&self, args: Self::LoadArgs) -> Result<I, RecorderError>;
fn load_item<I>(&self, args: Self::LoadArgs) -> Result<I, RecorderError>
where
I: DeserializeOwned;
}
fn recorder_metadata<R: Recorder>() -> BurnMetadata {
fn recorder_metadata<R, B>() -> BurnMetadata
where
R: Recorder<B>,
B: Backend,
{
BurnMetadata::new(
type_name::<<R::Settings as PrecisionSettings>::FloatElem>().to_string(),
type_name::<<R::Settings as PrecisionSettings>::IntElem>().to_string(),
@ -181,15 +197,17 @@ pub struct BurnMetadata {
/// Record that can be saved by a [Recorder](Recorder).
#[derive(Serialize, Deserialize, Debug)]
pub struct BurnRecord<I> {
pub struct BurnRecord<I, B: Backend> {
/// Metadata of the record.
pub metadata: BurnMetadata,
/// Item to record.
pub item: I,
_b: PhantomData<B>,
}
impl<I> BurnRecord<I> {
impl<I, B: Backend> BurnRecord<I, B> {
/// Creates a new record.
///
/// # Arguments
@ -199,10 +217,14 @@ impl<I> BurnRecord<I> {
/// # Returns
///
/// The new record.
pub fn new<R: Recorder>(item: I) -> Self {
let metadata = recorder_metadata::<R>();
pub fn new<R: Recorder<B>>(item: I) -> Self {
let metadata = recorder_metadata::<R, B>();
Self { metadata, item }
Self {
metadata,
item,
_b: PhantomData,
}
}
}
@ -254,8 +276,10 @@ pub type DebugRecordSettings = PrettyJsonFileRecorder<FullPrecisionSettings>;
mod tests {
static FILE_PATH: &str = "/tmp/burn_test_record";
use crate::TestBackend;
use super::*;
use burn_tensor::ElementConversion;
use burn_tensor::{Device, ElementConversion};
#[test]
#[should_panic]
@ -265,7 +289,11 @@ mod tests {
value: S::FloatElem,
}
impl<D: PrecisionSettings> Record for Item<D> {
impl<D, B> Record<B> for Item<D>
where
D: PrecisionSettings,
B: Backend,
{
type Item<S: PrecisionSettings> = Item<S>;
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
@ -274,7 +302,7 @@ mod tests {
}
}
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, _device: &B::Device) -> Self {
Item {
value: item.value.elem(),
}
@ -282,15 +310,19 @@ mod tests {
}
let item = Item::<FullPrecisionSettings>::new(16.elem());
let device: Device<TestBackend> = Default::default();
// Serialize in f32.
let recorder = DefaultFileRecorder::<FullPrecisionSettings>::new();
recorder.record(item, FILE_PATH.into()).unwrap();
Recorder::<TestBackend>::record(&recorder, item, FILE_PATH.into()).unwrap();
// Can't deserialize f32 into f16.
let recorder = DefaultFileRecorder::<HalfPrecisionSettings>::new();
recorder
.load::<Item<FullPrecisionSettings>>(FILE_PATH.into())
.unwrap();
Recorder::<TestBackend>::load::<Item<FullPrecisionSettings>>(
&recorder,
FILE_PATH.into(),
&device,
)
.unwrap();
}
}

View File

@ -85,7 +85,7 @@ impl<'de> Deserialize<'de> for BoolTensorSerde {
// --- RECORD IMPLEMENTATIONS --- //
impl<B: Backend, const D: usize> Record for Tensor<B, D> {
impl<B: Backend, const D: usize> Record<B> for Tensor<B, D> {
type Item<S: PrecisionSettings> = FloatTensorSerde<S>;
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
@ -96,12 +96,12 @@ impl<B: Backend, const D: usize> Record for Tensor<B, D> {
FloatTensorSerde::new(self.into_data().convert().serialize())
}
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
Tensor::from_data(item.data.convert::<B::FloatElem>(), &B::Device::default())
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
Tensor::from_data(item.data.convert::<B::FloatElem>(), device)
}
}
impl<B: Backend, const D: usize> Record for Tensor<B, D, Int> {
impl<B: Backend, const D: usize> Record<B> for Tensor<B, D, Int> {
type Item<S: PrecisionSettings> = IntTensorSerde<S>;
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
@ -112,12 +112,12 @@ impl<B: Backend, const D: usize> Record for Tensor<B, D, Int> {
IntTensorSerde::new(self.into_data().convert().serialize())
}
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
Tensor::from_data(item.data.convert(), &B::Device::default())
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
Tensor::from_data(item.data.convert(), device)
}
}
impl<B: Backend, const D: usize> Record for Tensor<B, D, Bool> {
impl<B: Backend, const D: usize> Record<B> for Tensor<B, D, Bool> {
type Item<S: PrecisionSettings> = BoolTensorSerde;
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
@ -128,7 +128,7 @@ impl<B: Backend, const D: usize> Record for Tensor<B, D, Bool> {
BoolTensorSerde::new(self.into_data().serialize())
}
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
Tensor::from_data(item.data, &B::Device::default())
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
Tensor::from_data(item.data, device)
}
}

View File

@ -192,7 +192,7 @@ mod tests {
fn deserialize_with_new_optional_field<R>(name: &str, recorder: R) -> Result<(), RecorderError>
where
R: FileRecorder,
R: FileRecorder<TestBackend>,
{
let device = Default::default();
let file_path: PathBuf = file_path(format!("deserialize_with_new_optional_field-{name}"));
@ -206,7 +206,8 @@ mod tests {
recorder
.record(model.into_record(), file_path.clone())
.unwrap();
let result = recorder.load::<ModelNewOptionalFieldRecord<TestBackend>>(file_path.clone());
let result =
recorder.load::<ModelNewOptionalFieldRecord<TestBackend>>(file_path.clone(), &device);
std::fs::remove_file(file_path).ok();
result?;
@ -218,7 +219,7 @@ mod tests {
recorder: R,
) -> Result<(), RecorderError>
where
R: FileRecorder,
R: FileRecorder<TestBackend>,
{
let device = Default::default();
let file_path: PathBuf =
@ -234,7 +235,7 @@ mod tests {
recorder
.record(model.into_record(), file_path.clone())
.unwrap();
let result = recorder.load::<ModelRecord<TestBackend>>(file_path.clone());
let result = recorder.load::<ModelRecord<TestBackend>>(file_path.clone(), &device);
std::fs::remove_file(file_path).ok();
result?;
@ -243,7 +244,7 @@ mod tests {
fn deserialize_with_new_constant_field<R>(name: &str, recorder: R) -> Result<(), RecorderError>
where
R: FileRecorder,
R: FileRecorder<TestBackend>,
{
let device = Default::default();
let file_path: PathBuf = file_path(format!("deserialize_with_new_constant_field-{name}"));
@ -257,7 +258,8 @@ mod tests {
recorder
.record(model.into_record(), file_path.clone())
.unwrap();
let result = recorder.load::<ModelNewConstantFieldRecord<TestBackend>>(file_path.clone());
let result =
recorder.load::<ModelNewConstantFieldRecord<TestBackend>>(file_path.clone(), &device);
std::fs::remove_file(file_path).ok();
result?;
@ -269,7 +271,7 @@ mod tests {
recorder: R,
) -> Result<(), RecorderError>
where
R: FileRecorder,
R: FileRecorder<TestBackend>,
{
let device = Default::default();
let file_path: PathBuf =
@ -285,7 +287,7 @@ mod tests {
recorder
.record(model.into_record(), file_path.clone())
.unwrap();
let result = recorder.load::<ModelRecord<TestBackend>>(file_path.clone());
let result = recorder.load::<ModelRecord<TestBackend>>(file_path.clone(), &device);
std::fs::remove_file(file_path).ok();
result?;
@ -294,7 +296,7 @@ mod tests {
fn deserialize_with_new_field_order<R>(name: &str, recorder: R) -> Result<(), RecorderError>
where
R: FileRecorder,
R: FileRecorder<TestBackend>,
{
let device = Default::default();
let file_path: PathBuf = file_path(format!("deserialize_with_new_field_order-{name}"));
@ -309,7 +311,8 @@ mod tests {
.record(model.into_record(), file_path.clone())
.unwrap();
let result = recorder.load::<ModelNewFieldOrdersRecord<TestBackend>>(file_path.clone());
let result =
recorder.load::<ModelNewFieldOrdersRecord<TestBackend>>(file_path.clone(), &device);
std::fs::remove_file(file_path).ok();
result?;

View File

@ -22,12 +22,19 @@ struct RecordDeriveCodegen {
name_item: Ident,
gen: StructRecordItemCodegen,
generics: Generics,
has_backend: bool,
}
impl RecordDeriveCodegen {
pub(crate) fn from_ast(ast: &syn::DeriveInput) -> Self {
let name_record = ast.ident.clone();
let name_item = Ident::new(format!("{}Item", name_record).as_str(), name_record.span());
let has_backend = ast
.generics
.type_params()
.map(|param| param.ident == "B")
.reduce(|accum, is_backend| is_backend || accum)
.unwrap_or(false);
Self {
name_record,
@ -39,6 +46,7 @@ impl RecordDeriveCodegen {
.collect(),
),
generics: ast.generics.clone(),
has_backend,
}
}
@ -51,7 +59,8 @@ impl RecordDeriveCodegen {
generics.params.push(param);
}
self.gen.gen_item_type(&self.name_item, &generics)
self.gen
.gen_item_type(&self.name_item, &generics, self.has_backend)
}
/// Generate the implementation for the Record trait.
@ -61,12 +70,18 @@ impl RecordDeriveCodegen {
let (_, ty_generics_item, _) = item_generics.split_for_impl();
let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
let impl_generics = if let Some(impl_generic) = self.impl_generics() {
impl_generic
} else {
quote! { #impl_generics }
};
let name_item = &self.name_item;
let into_item_fn = self.gen.gen_into_item(name_item);
let from_item_fn = self.gen.gen_from_item();
quote! {
impl #impl_generics burn::record::Record for #name #ty_generics #where_clause {
impl #impl_generics burn::record::Record<B> for #name #ty_generics #where_clause {
type Item<S: burn::record::PrecisionSettings> = #name_item #ty_generics_item;
#into_item_fn
@ -76,6 +91,20 @@ impl RecordDeriveCodegen {
}
}
fn impl_generics(&self) -> Option<TokenStream> {
if self.has_backend {
return None;
}
let param: syn::TypeParam = parse_quote! { B: burn::tensor::backend::Backend };
let mut generics = self.generics.clone();
generics.params.push(syn::GenericParam::Type(param));
let (impl_generics, _ty_generics, _where_clause) = generics.split_for_impl();
Some(quote! {#impl_generics})
}
fn record_item_generics(&self) -> Generics {
let param: syn::Generics = parse_quote! { <S: burn::record::PrecisionSettings >};
let mut generics = self.generics.clone();
@ -83,6 +112,11 @@ impl RecordDeriveCodegen {
generics.params.push(param);
}
if !self.has_backend {
let param: syn::TypeParam = parse_quote! { B: burn::tensor::backend::Backend };
generics.params.push(syn::GenericParam::Type(param));
}
generics
}
}

View File

@ -4,7 +4,12 @@ use syn::Generics;
/// Basic trait to be implemented for record generation.
pub(crate) trait RecordItemCodegen {
/// Generate the record item type (i.e a struct)
fn gen_item_type(&self, item_name: &Ident, generics: &Generics) -> TokenStream;
fn gen_item_type(
&self,
item_name: &Ident,
generics: &Generics,
has_backend: bool,
) -> TokenStream;
/// Generate the into_item function.
fn gen_into_item(&self, item_name: &Ident) -> TokenStream;
/// Generate the from item function.

View File

@ -1,7 +1,7 @@
use crate::shared::field::FieldTypeAnalyzer;
use proc_macro2::{Ident, TokenStream};
use quote::quote;
use syn::Generics;
use syn::{parse_quote, Generics};
use super::codegen::RecordItemCodegen;
@ -11,7 +11,12 @@ pub(crate) struct StructRecordItemCodegen {
}
impl RecordItemCodegen for StructRecordItemCodegen {
fn gen_item_type(&self, item_name: &Ident, generics: &Generics) -> TokenStream {
fn gen_item_type(
&self,
item_name: &Ident,
generics: &Generics,
has_backend: bool,
) -> TokenStream {
let mut fields = quote! {};
let mut bounds = quote! {};
@ -21,15 +26,25 @@ impl RecordItemCodegen for StructRecordItemCodegen {
fields.extend(quote! {
/// Field to be serialized.
pub #name: <#ty as burn::record::Record>::Item<S>,
pub #name: <#ty as burn::record::Record<B>>::Item<S>,
});
bounds.extend(quote! {
<#ty as burn::record::Record>::Item<S>: burn::serde::Serialize + burn::serde::de::DeserializeOwned,
<#ty as burn::record::Record<B>>::Item<S>: burn::serde::Serialize + burn::serde::de::DeserializeOwned,
});
}
let bound = bounds.to_string();
let (generics, _, generics_where) = generics.split_for_impl();
let (generics, generics_where) = if !has_backend {
let mut generics = generics.clone();
let param: syn::TypeParam = parse_quote! { B: burn::tensor::backend::Backend };
generics.params.push(syn::GenericParam::Type(param));
let (generics, _, generics_where) = generics.split_for_impl();
(quote! { #generics }, quote! { #generics_where })
} else {
let (generics, _, generics_where) = generics.split_for_impl();
(quote! { #generics }, quote! { #generics_where })
};
quote! {
/// The record item type for the module.
@ -49,7 +64,7 @@ impl RecordItemCodegen for StructRecordItemCodegen {
let name = &field.field.ident;
body_into_item.extend(quote! {
#name: burn::record::Record::into_item::<S>(self.#name),
#name: burn::record::Record::<B>::into_item::<S>(self.#name),
});
}
@ -69,12 +84,12 @@ impl RecordItemCodegen for StructRecordItemCodegen {
let name = &field.field.ident;
body_from_item.extend(quote! {
#name: burn::record::Record::from_item::<S>(item.#name),
#name: burn::record::Record::<B>::from_item::<S>(item.#name, device),
});
}
quote! {
fn from_item<S: burn::record::PrecisionSettings>(item: Self::Item<S>) -> Self {
fn from_item<S: burn::record::PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
Self {
#body_from_item
}

View File

@ -49,6 +49,9 @@ pub struct BurnGraph<PS: PrecisionSettings> {
graph_output_types: Vec<Type>,
}
// The backend used for recording.
type Backend = burn_ndarray::NdArray;
impl<PS: PrecisionSettings> BurnGraph<PS> {
/// Register a new operation node into the graph.
///
@ -96,14 +99,16 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
match record_type {
RecordType::PrettyJson => {
PrettyJsonFileRecorder::<PS>::new()
.save_item(
BurnRecord::new::<PrettyJsonFileRecorder<PS>>(StructMap(
BurnGraphState::new(&self.nodes),
)),
out_file.clone(),
)
.unwrap();
let recorder = PrettyJsonFileRecorder::<PS>::new();
Recorder::<Backend>::save_item(
&recorder,
BurnRecord::<_, Backend>::new::<PrettyJsonFileRecorder<PS>>(StructMap(
BurnGraphState::new(&self.nodes),
)),
out_file.clone(),
)
.unwrap();
assert!(
!embed_states,
@ -116,14 +121,16 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
);
}
RecordType::NamedMpkGz => {
NamedMpkGzFileRecorder::<PS>::new()
.save_item(
BurnRecord::new::<NamedMpkGzFileRecorder<PS>>(StructMap(
BurnGraphState::new(&self.nodes),
)),
out_file.clone(),
)
.unwrap();
let recorder = NamedMpkGzFileRecorder::<PS>::new();
Recorder::<Backend>::save_item(
&recorder,
BurnRecord::<_, Backend>::new::<NamedMpkGzFileRecorder<PS>>(StructMap(
BurnGraphState::new(&self.nodes),
)),
out_file.clone(),
)
.unwrap();
assert!(
!embed_states,
@ -136,14 +143,16 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
}
RecordType::NamedMpk => {
NamedMpkFileRecorder::<PS>::new()
.save_item(
BurnRecord::new::<NamedMpkGzFileRecorder<PS>>(StructMap(
BurnGraphState::new(&self.nodes),
)),
out_file.clone(),
)
.unwrap();
let recorder = NamedMpkFileRecorder::<PS>::new();
Recorder::<Backend>::save_item(
&recorder,
BurnRecord::<_, Backend>::new::<NamedMpkGzFileRecorder<PS>>(StructMap(
BurnGraphState::new(&self.nodes),
)),
out_file.clone(),
)
.unwrap();
assert!(
!embed_states,
@ -157,14 +166,16 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
}
RecordType::Bincode => {
BinFileRecorder::<PS>::new()
.save_item(
BurnRecord::new::<BinFileRecorder<PS>>(StructTuple(BurnGraphState::new(
&self.nodes,
))),
out_file.clone(),
)
.unwrap();
let recorder = BinFileRecorder::<PS>::new();
Recorder::<Backend>::save_item(
&recorder,
BurnRecord::<_, Backend>::new::<BinFileRecorder<PS>>(StructTuple(
BurnGraphState::new(&self.nodes),
)),
out_file.clone(),
)
.unwrap();
if embed_states {
self.register_record_embed(out_file);
@ -349,14 +360,14 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
_blank_!();
impl<B: Backend> Default for Model<B> {
fn default() -> Self {
Self::from_file(#file)
Self::from_file(#file, &Default::default())
}
}
_blank_!();
impl<B: Backend> Model<B> {
pub fn from_file(file: &str) -> Self {
pub fn from_file(file: &str, device: &B::Device) -> Self {
let record = #recorder_ty::new()
.load(file.into())
.load(file.into(), device)
.expect("Record file to exist.");
Self::new_with(record)
}
@ -373,7 +384,7 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
self.imports.register("burn::record::BinBytesRecorder");
let mut file = file;
file.set_extension(BinFileRecorder::<PS>::file_extension());
file.set_extension(<BinFileRecorder<PS> as FileRecorder<Backend>>::file_extension());
let file = file.to_str().unwrap();
self.default = Some(quote! {
_blank_!();
@ -381,14 +392,14 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
_blank_!();
impl<B: Backend> Default for Model<B> {
fn default() -> Self {
Self::from_embedded()
Self::from_embedded(&Default::default())
}
}
_blank_!();
impl<B: Backend> Model<B> {
pub fn from_embedded() -> Self {
pub fn from_embedded(device: &B::Device) -> Self {
let record = BinBytesRecorder::<#precision_ty>::default()
.load(EMBEDDED_STATES.to_vec())
.load(EMBEDDED_STATES.to_vec(), device)
.expect("Failed to decode state");
Self::new_with(record)

View File

@ -1,26 +1,35 @@
use super::{Checkpointer, CheckpointerError};
use burn_core::record::Record;
use burn_core::{record::Record, tensor::backend::Backend};
use std::sync::mpsc;
enum Message<R> {
Restore(usize, mpsc::SyncSender<Result<R, CheckpointerError>>),
enum Message<R, B: Backend> {
Restore(
usize,
B::Device,
mpsc::SyncSender<Result<R, CheckpointerError>>,
),
Save(usize, R),
Delete(usize),
End,
}
#[derive(new)]
struct CheckpointerThread<C, R> {
struct CheckpointerThread<C, R, B: Backend> {
checkpointer: C,
receiver: mpsc::Receiver<Message<R>>,
receiver: mpsc::Receiver<Message<R, B>>,
}
impl<C: Checkpointer<R>, R: Record> CheckpointerThread<C, R> {
impl<C, R, B> CheckpointerThread<C, R, B>
where
C: Checkpointer<R, B>,
R: Record<B>,
B: Backend,
{
fn run(self) {
for item in self.receiver.iter() {
match item {
Message::Restore(epoch, callback) => {
let record = self.checkpointer.restore(epoch);
Message::Restore(epoch, device, callback) => {
let record = self.checkpointer.restore(epoch, &device);
callback
.send(record)
.expect("Can send response through callback channel.");
@ -42,12 +51,16 @@ impl<C: Checkpointer<R>, R: Record> CheckpointerThread<C, R> {
}
/// Async checkpointer.
pub struct AsyncCheckpointer<Record> {
sender: mpsc::SyncSender<Message<Record>>,
pub struct AsyncCheckpointer<Record, B: Backend> {
sender: mpsc::SyncSender<Message<Record, B>>,
handler: Option<std::thread::JoinHandle<()>>,
}
impl<R: Record + 'static> AsyncCheckpointer<R> {
impl<R, B> AsyncCheckpointer<R, B>
where
R: Record<B> + 'static,
B: Backend,
{
/// Create a new async checkpointer.
///
/// # Arguments
@ -59,7 +72,7 @@ impl<R: Record + 'static> AsyncCheckpointer<R> {
/// The async checkpointer.
pub fn new<C>(checkpointer: C) -> Self
where
C: Checkpointer<R> + Send + 'static,
C: Checkpointer<R, B> + Send + 'static,
{
// Only on checkpoint can be done in advance.
let (sender, receiver) = mpsc::sync_channel(0);
@ -70,9 +83,10 @@ impl<R: Record + 'static> AsyncCheckpointer<R> {
}
}
impl<R> Checkpointer<R> for AsyncCheckpointer<R>
impl<R, B> Checkpointer<R, B> for AsyncCheckpointer<R, B>
where
R: Record + 'static,
R: Record<B> + 'static,
B: Backend,
{
fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError> {
self.sender
@ -82,10 +96,10 @@ where
Ok(())
}
fn restore(&self, epoch: usize) -> Result<R, CheckpointerError> {
fn restore(&self, epoch: usize, device: &B::Device) -> Result<R, CheckpointerError> {
let (sender, receiver) = mpsc::sync_channel(1);
self.sender
.send(Message::Restore(epoch, sender))
.send(Message::Restore(epoch, device.clone(), sender))
.map_err(|e| CheckpointerError::Unknown(e.to_string()))?;
if let Ok(record) = receiver.recv() {
@ -104,7 +118,10 @@ where
}
}
impl<E> Drop for AsyncCheckpointer<E> {
impl<E, B> Drop for AsyncCheckpointer<E, B>
where
B: Backend,
{
fn drop(&mut self) {
self.sender
.send(Message::End)

View File

@ -1,4 +1,7 @@
use burn_core::record::{Record, RecorderError};
use burn_core::{
record::{Record, RecorderError},
tensor::backend::Backend,
};
/// The error type for checkpointer.
#[derive(Debug)]
@ -14,7 +17,11 @@ pub enum CheckpointerError {
}
/// The trait for checkpointer.
pub trait Checkpointer<R: Record> {
pub trait Checkpointer<R, B>
where
R: Record<B>,
B: Backend,
{
/// Save the record.
///
/// # Arguments
@ -31,9 +38,10 @@ pub trait Checkpointer<R: Record> {
/// # Arguments
///
/// * `epoch` - The epoch.
/// * `device` - The device used to restore the record.
///
/// # Returns
///
/// The record.
fn restore(&self, epoch: usize) -> Result<R, CheckpointerError>;
fn restore(&self, epoch: usize, device: &B::Device) -> Result<R, CheckpointerError>;
}

View File

@ -1,5 +1,8 @@
use super::{Checkpointer, CheckpointerError};
use burn_core::record::{FileRecorder, Record};
use burn_core::{
record::{FileRecorder, Record},
tensor::backend::Backend,
};
/// The file checkpointer.
pub struct FileCheckpointer<FR> {
@ -30,10 +33,11 @@ impl<FR> FileCheckpointer<FR> {
}
}
impl<FR, R> Checkpointer<R> for FileCheckpointer<FR>
impl<FR, R, B> Checkpointer<R, B> for FileCheckpointer<FR>
where
R: Record,
FR: FileRecorder,
R: Record<B>,
FR: FileRecorder<B>,
B: Backend,
{
fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError> {
let file_path = self.path_for_epoch(epoch);
@ -46,12 +50,12 @@ where
Ok(())
}
fn restore(&self, epoch: usize) -> Result<R, CheckpointerError> {
fn restore(&self, epoch: usize, device: &B::Device) -> Result<R, CheckpointerError> {
let file_path = self.path_for_epoch(epoch);
log::info!("Restoring checkpoint {} from {}", epoch, file_path);
let record = self
.recorder
.load(file_path.into())
.load(file_path.into(), device)
.map_err(CheckpointerError::RecorderError)?;
Ok(record)

View File

@ -15,19 +15,26 @@ pub trait LearnerComponents {
/// The backend in used for the training.
type Backend: AutodiffBackend;
/// The learning rate scheduler used for the training.
type LrScheduler: LrScheduler;
type LrScheduler: LrScheduler<Self::Backend>;
/// The model to train.
type Model: AutodiffModule<Self::Backend> + core::fmt::Display + 'static;
/// The optimizer used for the training.
type Optimizer: Optimizer<Self::Model, Self::Backend>;
/// The checkpointer used for the model.
type CheckpointerModel: Checkpointer<<Self::Model as Module<Self::Backend>>::Record>;
type CheckpointerModel: Checkpointer<
<Self::Model as Module<Self::Backend>>::Record,
Self::Backend,
>;
/// The checkpointer used for the optimizer.
type CheckpointerOptimizer: Checkpointer<
<Self::Optimizer as Optimizer<Self::Model, Self::Backend>>::Record,
Self::Backend,
>;
/// The checkpointer used for the scheduler.
type CheckpointerLrScheduler: Checkpointer<<Self::LrScheduler as LrScheduler>::Record>;
type CheckpointerLrScheduler: Checkpointer<
<Self::LrScheduler as LrScheduler<Self::Backend>>::Record,
Self::Backend,
>;
type EventProcessor: EventProcessor + 'static;
/// The strategy to save and delete checkpoints.
type CheckpointerStrategy: CheckpointingStrategy;
@ -50,12 +57,12 @@ impl<B, LR, M, O, CM, CO, CS, EP, S> LearnerComponents
for LearnerComponentsMarker<B, LR, M, O, CM, CO, CS, EP, S>
where
B: AutodiffBackend,
LR: LrScheduler,
LR: LrScheduler<B>,
M: AutodiffModule<B> + core::fmt::Display + 'static,
O: Optimizer<M, B>,
CM: Checkpointer<M::Record>,
CO: Checkpointer<O::Record>,
CS: Checkpointer<LR::Record>,
CM: Checkpointer<M::Record, B>,
CO: Checkpointer<O::Record, B>,
CS: Checkpointer<LR::Record, B>,
EP: EventProcessor + 'static,
S: CheckpointingStrategy,
{

View File

@ -6,6 +6,7 @@ use burn_core::lr_scheduler::LrScheduler;
use burn_core::module::Module;
use burn_core::optim::Optimizer;
use burn_core::tensor::backend::Backend;
use burn_core::tensor::Device;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
@ -79,23 +80,24 @@ impl<LC: LearnerComponents> LearnerCheckpointer<LC> {
model: LC::Model,
optim: LC::Optimizer,
scheduler: LC::LrScheduler,
device: &Device<LC::Backend>,
epoch: usize,
) -> (LC::Model, LC::Optimizer, LC::LrScheduler) {
let record = self
.model
.restore(epoch)
.restore(epoch, device)
.expect("Can load model checkpoint.");
let model = model.load_record(record);
let record = self
.optim
.restore(epoch)
.restore(epoch, device)
.expect("Can load optimizer checkpoint.");
let optim = optim.load_record(record);
let record = self
.lr_scheduler
.restore(epoch)
.restore(epoch, device)
.expect("Can load learning rate scheduler checkpoint.");
let scheduler = scheduler.load_record(record);

View File

@ -29,16 +29,16 @@ where
B: AutodiffBackend,
M: AutodiffModule<B>,
O: Optimizer<M, B>,
S: LrScheduler,
S: LrScheduler<B>,
{
// Not that complex and very convenient when the traits are
// already constrained correctly. Extracting in another type
// would be more complex.
#[allow(clippy::type_complexity)]
checkpointers: Option<(
AsyncCheckpointer<M::Record>,
AsyncCheckpointer<O::Record>,
AsyncCheckpointer<S::Record>,
AsyncCheckpointer<M::Record, B>,
AsyncCheckpointer<O::Record, B>,
AsyncCheckpointer<S::Record, B>,
)>,
num_epochs: usize,
checkpoint: Option<usize>,
@ -62,7 +62,7 @@ where
V: Send + Sync + 'static,
M: AutodiffModule<B> + core::fmt::Display + 'static,
O: Optimizer<M, B>,
S: LrScheduler,
S: LrScheduler<B>,
{
/// Creates a new learner builder.
///
@ -235,7 +235,8 @@ where
/// [model](AutodiffModule) and the [scheduler](LrScheduler) to different files.
pub fn with_file_checkpointer<FR>(mut self, recorder: FR) -> Self
where
FR: FileRecorder + 'static,
FR: FileRecorder<B> + 'static,
FR: FileRecorder<B::InnerBackend> + 'static,
O::Record: 'static,
M::Record: 'static,
S::Record: 'static,
@ -281,9 +282,9 @@ where
S,
M,
O,
AsyncCheckpointer<M::Record>,
AsyncCheckpointer<O::Record>,
AsyncCheckpointer<S::Record>,
AsyncCheckpointer<M::Record, B>,
AsyncCheckpointer<O::Record, B>,
AsyncCheckpointer<S::Record, B>,
FullEventProcessor<T, V>,
Box<dyn CheckpointingStrategy>,
>,

View File

@ -135,6 +135,7 @@ impl<LC: LearnerComponents> Learner<LC> {
self.model,
self.optim,
self.lr_scheduler,
&Default::default(), // Load the checkpoint on the default device.
checkpoint,
);
}

View File

@ -3,7 +3,6 @@ use burn::data::dataset::source::huggingface::MNISTItem;
use burn::{
config::Config,
data::dataloader::batcher::Batcher,
module::Module,
record::{CompactRecorder, Recorder},
tensor::backend::Backend,
};
@ -12,10 +11,10 @@ pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MNISTItem)
let config = TrainingConfig::load(format!("{artifact_dir}/config.json"))
.expect("Config should exist for the model");
let record = CompactRecorder::new()
.load(format!("{artifact_dir}/model").into())
.load(format!("{artifact_dir}/model").into(), &device)
.expect("Trained model should exist");
let model = config.model.init_with::<B>(record).to_device(&device);
let model = config.model.init_with::<B>(record);
let label = item.label;
let batcher = MNISTBatcher::new(device);

View File

@ -133,7 +133,7 @@ impl<B: Backend> Model<B> {
/// Constructor
pub fn new(device: &B::Device) -> Self {
Self {
model: SqueezenetModel::from_embedded(),
model: SqueezenetModel::from_embedded(device),
normalizer: Normalizer::new(device),
}
}

View File

@ -22,7 +22,7 @@ pub async fn build_and_load_model() -> Model<Backend> {
let model: Model<Backend> = Model::new(&Default::default());
let record = BinBytesRecorder::<FullPrecisionSettings>::default()
.load(STATE_ENCODED.to_vec())
.load(STATE_ENCODED.to_vec(), &Default::default())
.expect("Failed to decode state");
model.load_record(record)

View File

@ -12,7 +12,6 @@ use crate::{
use burn::{
config::Config,
data::dataloader::batcher::Batcher,
module::Module,
record::{CompactRecorder, Recorder},
tensor::backend::Backend,
};
@ -44,7 +43,7 @@ pub fn infer<B: Backend, D: TextClassificationDataset + 'static>(
// Load pre-trained model weights
println!("Loading weights ...");
let record = CompactRecorder::new()
.load(format!("{artifact_dir}/model").into())
.load(format!("{artifact_dir}/model").into(), &device)
.expect("Trained model weights");
// Create model using loaded weights
@ -55,8 +54,7 @@ pub fn infer<B: Backend, D: TextClassificationDataset + 'static>(
tokenizer.vocab_size(),
config.max_seq_length,
)
.init_with::<B>(record) // Initialize model with loaded weights
.to_device(&device); // Move model to computation device
.init_with::<B>(record); // Initialize model with loaded weights
// Run inference on the given text samples
println!("Running inference ...");