Add book + fix some code (#671)

This commit is contained in:
Nathaniel Simard 2023-08-23 11:52:55 -04:00 committed by GitHub
parent b60d931771
commit 00d3d208b8
20 changed files with 495 additions and 28 deletions

View File

@ -1,9 +1,20 @@
- [Overview](./overview.md)
- [Why Burn?](./motivation.md)
- [Basic Workflow: From Training to Inference](./basic-workflow/README.md)
- [Model](./basic-workflow/model.md)
- [Data](./basic-workflow/data.md)
- [Training](./basic-workflow/training.md)
- [Backend](./basic-workflow/backend.md)
- [Inference](./basic-workflow/inference.md)
- [Conclusion](./basic-workflow/conclusion.md)
- [Building Blocks](./building-blocks/README.md)
- [Backend](./building-blocks/backend.md)
- [Tensor](./building-blocks/tensor.md)
- [Autodiff](./building-blocks/autodiff.md)
- [Module](./building-blocks/module.md)
- [Import ONNX Model]()
- [Advanced]()
- [Custom Training Loops]()
- [Custom Metric]()
- [Custom Kernels]()
- [WGPU]()

View File

@ -34,7 +34,7 @@ pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MNISTItem)
    TrainingConfig::load(&format!("{artifact_dir}/config.json")).expect("A config exists");
let record = CompactRecorder::new()
    .load(format!("{artifact_dir}/model").into())
-   .expect("Failed to save trained model");
+   .expect("Failed to load trained model");
let model = config.model.init_with::<B>(record).to_device(&device);
@ -53,4 +53,4 @@ Then we can fetch the record using the same recorder as we used during training.
Finally, we can init the model with the configuration and the record before sending it to the wanted device for inference.
For simplicity, we can use the same batcher used during the training to pass from a MNISTItem to a tensor.
By running the infer function, you should see the predictions of your model!

View File

(Binary image file changed; 52 KiB before and after.)

View File

@ -114,11 +114,8 @@ pub fn train<B: ADBackend>(artifact_dir: &str, config: TrainingConfig, device: B
let model_trained = learner.fit(dataloader_train, dataloader_test);

-CompactRecorder::new()
-    .record(
-        model_trained.into_record(),
-        format!("{artifact_dir}/model").into(),
-    )
+model_trained
+    .save_file(format!("{artifact_dir}/model"), &CompactRecorder::new())
    .expect("Failed to save trained model");
}
```

View File

@ -0,0 +1,7 @@
# Building Blocks
In this section, we'll guide you through the core elements that make up Burn: the key components that serve as the building blocks of the framework and your future projects.
As you explore Burn, you might notice that we occasionally draw comparisons to PyTorch.
We believe these comparisons can provide a smoother learning curve and help you grasp the nuances more effectively.

View File

@ -0,0 +1,83 @@
# Autodiff
Burn's tensors also support automatic differentiation, which is an essential part of any deep learning framework.
We introduced the `Backend` trait in the [Backend section](./backend.md), but Burn also has another trait for autodiff: `ADBackend`.
However, not all tensors support autodiff; you need a backend that implements both the `Backend` and `ADBackend` traits.
Fortunately, you can add autodiff capabilities to any backend using a backend decorator: `type MyAutodiffBackend = ADBackendDecorator<MyBackend>`.
This decorator implements both the `ADBackend` and `Backend` traits by maintaining a dynamic computational graph and delegating tensor operations to the inner backend.
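As a rough sketch (assuming the `burn-ndarray` and `burn-autodiff` crates; any backend works the same way), enabling autodiff is just a type alias:

```rust, ignore
use burn_autodiff::ADBackendDecorator;
use burn_ndarray::NdArrayBackend;

// A plain backend, suitable for inference-only workloads.
type MyBackend = NdArrayBackend<f32>;

// The same backend wrapped with autodiff support, suitable for training.
type MyAutodiffBackend = ADBackendDecorator<MyBackend>;
```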
The `ADBackend` trait adds new operations on float tensors that can't be called otherwise.
It also provides a new associated type, `B::Gradients`, where each calculated gradient resides.
```rust, ignore
fn calculate_gradients<B: ADBackend>(tensor: Tensor<B, 2>) -> B::Gradients {
    let mut gradients = tensor.clone().backward();

    let tensor_grad = tensor.grad(&gradients); // get the gradient, leaving it in the container
    let tensor_grad = tensor.grad_remove(&mut gradients); // pop the gradient, removing it from the container

    gradients
}
```
Note that the functions below are always available, even if the backend doesn't implement the `ADBackend` trait; in that case, they simply do nothing.
| Burn API | PyTorch Equivalent |
|--------------------------------------------------------|------------------------------------------------------|
| `tensor.detach()` | `tensor.detach()` |
| `tensor.require_grad()`                                 | `tensor.requires_grad_()`                             |
| `tensor.is_require_grad()`                              | `tensor.requires_grad`                                |
| `tensor.set_require_grad(require_grad)`                 | `tensor.requires_grad_(require_grad)`                 |
However, you're unlikely to make mistakes, since you can't call `backward` on a tensor whose backend doesn't implement `ADBackend`, and you can't retrieve the gradients of a tensor without an autodiff backend.
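As a small illustrative sketch (the function name is ours), here is how these flags are typically used:

```rust, ignore
use burn::tensor::{backend::ADBackend, Tensor};

fn flags_example<B: ADBackend>(tensor: Tensor<B, 2>) {
    // Mark the tensor as requiring gradients for the next backward pass.
    let tensor = tensor.require_grad();
    assert!(tensor.is_require_grad());

    // `detach` returns a tensor whose history is cut from the graph.
    let _detached = tensor.detach();
}
```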
## Difference with PyTorch
The way Burn handles gradients is different from PyTorch.
First, calling `backward` doesn't update a `grad` field on each parameter.
Instead, the backward pass returns all the calculated gradients in a container.
This approach offers numerous benefits, such as the ability to easily send gradients to other threads.
You can also retrieve the gradient for a specific parameter using the `grad` method on a tensor.
Since this method takes the gradients as input, it's hard to forget to call `backward` beforehand.
Note that sometimes, using `grad_remove` can improve performance by allowing inplace operations.
In PyTorch, when you don't need gradients for inference or validation, you typically have to scope your code with a context-manager block.
```python
# Inference mode
with torch.inference_mode():
    # your code
    ...

# Or no grad
with torch.no_grad():
    # your code
    ...
```
With Burn, inference code simply uses a backend without the `ADBackendDecorator`; during training, you can call `inner()` to obtain the inner (non-autodiff) tensor, which is useful for validation.
```rust, ignore
/// Use `B: ADBackend`
fn example_validation<B: ADBackend>(tensor: Tensor<B, 2>) {
    let inner_tensor: Tensor<B::InnerBackend, 2> = tensor.inner();
    let _ = inner_tensor + 5;
}

/// Use `B: Backend`
fn example_inference<B: Backend>(tensor: Tensor<B, 2>) {
    let _ = tensor + 5;
    // ...
}
```
**Gradients with Optimizers**
We've seen how gradients can be used with tensors, but the process is a bit different when working with optimizers from `burn-core`.
To work with the `Module` trait, a translation step is required to link tensor parameters with their gradients.
This step is necessary to easily support gradient accumulation and training on multiple devices, where each module can be forked and run on different devices in parallel.
We'll explore this topic further in the [Module](./module.md) section.
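As a rough preview, and only a sketch (we assume the signatures of `Optimizer::step` and `GradientsParams::from_grads` from `burn::optim`), the translation step looks like this:

```rust, ignore
use burn::module::ADModule;
use burn::optim::{GradientsParams, Optimizer};
use burn::tensor::{backend::ADBackend, Tensor};

/// One optimization step, not a full training loop.
fn optimizer_step<B: ADBackend, M: ADModule<B>, O: Optimizer<M, B>>(
    optimizer: &mut O,
    model: M,
    learning_rate: f64,
    loss: Tensor<B, 1>,
) -> M {
    // Compute the raw gradients for the whole graph.
    let gradients = loss.backward();

    // Translate them into gradients keyed by the module's parameter ids.
    let gradients = GradientsParams::from_grads(gradients, &model);

    // The optimizer consumes the module and returns the updated one.
    optimizer.step(learning_rate, model, gradients)
}
```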

View File

@ -0,0 +1,11 @@
# Backend
Nearly everything in Burn is based on the `Backend` trait, which enables you to run tensor operations using different implementations without having to modify your code.
Note that a backend doesn't necessarily provide autodiff capabilities; the `ADBackend` trait is there for when autodiff is required.
This trait abstracts not only operations but also the tensor, device, and element types, giving each backend the flexibility it needs.
It's worth noting that the trait assumes eager mode, since Burn fully supports dynamic graphs.
However, we may create another API to assist with integrating graph-based backends, without requiring any changes to the user's code.
Users are not expected to call the backend trait methods directly, as the trait is primarily designed with backend developers in mind rather than Burn users.
Instead, most Burn userland APIs are generic over backends.
This approach helps users discover the API more organically with proper autocomplete and documentation.
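For instance, a minimal sketch of what backend-generic user code looks like (the function and its logic are only illustrative):

```rust, ignore
use burn::tensor::{backend::Backend, Tensor};

/// Works with any backend (ndarray, tch, wgpu, ...) without modification.
fn mean_squared_error<B: Backend>(prediction: Tensor<B, 2>, target: Tensor<B, 2>) -> Tensor<B, 1> {
    // Element-wise squared difference, reduced to a single-element tensor.
    (prediction - target).powf(2.0).mean()
}
```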

View File

@ -0,0 +1,101 @@
# Module
The `Module` derive allows you to create your own neural network modules, similar to PyTorch.
The derive function only generates the necessary methods to essentially act as a parameter container for your type; it makes no assumptions about how the forward pass is declared.
```rust, ignore
use burn::module::Module;
use burn::nn::{Dropout, Linear, GELU};
use burn::tensor::{backend::Backend, Tensor};

#[derive(Module, Debug)]
pub struct PositionWiseFeedForward<B: Backend> {
    linear_inner: Linear<B>,
    linear_outer: Linear<B>,
    dropout: Dropout,
    gelu: GELU,
}

impl<B: Backend> PositionWiseFeedForward<B> {
    /// Normal method added to a struct.
    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        let x = self.linear_inner.forward(input);
        let x = self.gelu.forward(x);
        let x = self.dropout.forward(x);

        self.linear_outer.forward(x)
    }
}
```
Note that all fields declared in the struct must also implement the `Module` trait.
## Tensor
If you want to create your own module that contains tensors, and not just other modules defined with the `Module` derive, you need to decide how each tensor should be treated (a sketch combining the three cases follows this list):

- `Param<Tensor<B, D>>`:
  If you want the tensor to be included as a parameter of your module, wrap it in a `Param` struct.
  This creates an ID used to identify the parameter, which is essential when performing module optimization and when saving states such as optimizer and module checkpoints.
  Note that a module's record only contains parameters.
- `Param<Tensor<B, D>>` with `set_require_grad(false)`:
  Use this when the tensor should be included as a parameter, and therefore saved with the module's weights, but should not be updated by the optimizer.
- `Tensor<B, D>`:
  Use a plain tensor when it should act as a constant that can be recreated when instantiating the module.
  This can be useful when generating sinusoidal embeddings, for example.
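A minimal, illustrative sketch of a module combining the three cases (the field names and the init details are ours, not part of Burn's API):

```rust, ignore
use burn::module::{Module, Param};
use burn::tensor::{backend::Backend, Tensor};

#[derive(Module, Debug)]
pub struct MyModule<B: Backend> {
    // Trainable parameter: updated by the optimizer and saved in the record.
    weight: Param<Tensor<B, 2>>,
    // Saved with the module's weights but ignored by the optimizer,
    // assuming it was created with `set_require_grad(false)`.
    running_stat: Param<Tensor<B, 1>>,
    // Plain constant: recreated when instantiating the module, not part of the record.
    positional_encoding: Tensor<B, 2>,
}
```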
## Methods
These methods are available for all modules.
| Burn API | PyTorch Equivalent |
|--------------------------------------------------------|---------------------------------------------------------|
| `module.devices()` | N/A |
| `module.fork(device)` | Similar to `module.to(device).detach()` |
| `module.to_device(device)` | `module.to(device)` |
| `module.no_grad()`                                      | `module.requires_grad_(False)`                           |
| `module.num_params()` | N/A |
| `module.visit(visitor)` | N/A |
| `module.map(mapper)` | N/A |
| `module.into_record()` | Similar to `state_dict` |
| `module.load_record(record)` | Similar to `load_state_dict(state_dict)` |
| `module.save_file(file_path, recorder)` | N/A |
| `module.load_file(file_path, recorder)` | N/A |
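For example, here is a rough sketch of checkpointing a module with one of the file recorders (the path is arbitrary, and we assume `CompactRecorder` from `burn::record` together with the `std` feature):

```rust, ignore
use burn::module::Module;
use burn::record::CompactRecorder;
use burn::tensor::backend::Backend;

fn checkpoint<B: Backend, M: Module<B>>(model: M) -> M {
    let recorder = CompactRecorder::new();

    // The file extension is added automatically by the recorder.
    model
        .clone()
        .save_file("/tmp/my_model", &recorder)
        .expect("Failed to save the module");

    // Loading consumes an existing instance and returns it with the loaded record.
    model
        .load_file("/tmp/my_model", &recorder)
        .expect("Failed to load the module")
}
```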
Similar to the backend trait, there is also the `ADModule` trait to signify a module with autodiff support.
| Burn API | PyTorch Equivalent |
|--------------------------------------------------------|---------------------------------------------------------|
| `module.valid()` | `module.eval()` |
## Visitor & Mapper
As mentioned earlier, modules primarily function as parameter containers.
Therefore, we naturally offer several ways to perform functions on each parameter.
This is distinct from PyTorch, where extending module functionalities is not as straightforward.
The `map` and `visit` methods are quite similar but serve different purposes.
Mapping is used for potentially mutable operations where each parameter of a module can be updated to a new value.
In Burn, optimizers are essentially just sophisticated module mappers.
Visitors, on the other hand, are used when you don't intend to modify the module but need to retrieve specific information from it, such as the number of parameters or a list of devices in use.
You can implement your own mapper or visitor by implementing these simple traits:
```rust, ignore
/// Module visitor trait.
pub trait ModuleVisitor<B: Backend> {
    /// Visit a tensor in the module.
    fn visit<const D: usize>(&mut self, id: &ParamId, tensor: &Tensor<B, D>);
}

/// Module mapper trait.
pub trait ModuleMapper<B: Backend> {
    /// Map a tensor in the module.
    fn map<const D: usize>(&mut self, id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D>;
}
```
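For instance, here is a rough sketch of a visitor that counts parameter elements (the struct and helper function are ours; we assume `Module::visit` takes a mutable reference to the visitor and that `Shape::num_elements` is available):

```rust, ignore
use burn::module::{Module, ModuleVisitor, ParamId};
use burn::tensor::{backend::Backend, Tensor};

#[derive(Default)]
struct ParamCounter {
    count: usize,
}

impl<B: Backend> ModuleVisitor<B> for ParamCounter {
    fn visit<const D: usize>(&mut self, _id: &ParamId, tensor: &Tensor<B, D>) {
        // Accumulate the number of elements of every parameter tensor.
        self.count += tensor.shape().num_elements();
    }
}

fn count_params<B: Backend, M: Module<B>>(module: &M) -> usize {
    let mut counter = ParamCounter::default();
    module.visit(&mut counter);
    counter.count
}
```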

View File

@ -0,0 +1,169 @@
# Tensor
As previously explained in the [model section](../basic-workflow/model.md), the `Tensor` struct has three generic parameters: the backend, the number of dimensions (rank), and the kind.
```rust, ignore
Tensor<B, D> // Float tensor (default)
Tensor<B, D, Float> // Explicit float tensor
Tensor<B, D, Int> // Int tensor
Tensor<B, D, Bool> // Bool tensor
```
Note that the specific element types used for `Float`, `Int`, and `Bool` tensors are defined by backend implementations.
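As a quick illustration (a sketch with arbitrary values; we assume `from_floats` accepts an array literal), the kind only changes the third generic argument and the set of available methods:

```rust, ignore
use burn::tensor::{backend::Backend, Bool, Int, Tensor};

fn kinds_example<B: Backend>() {
    // Float tensor (the default kind).
    let float = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0]);

    // Int tensor obtained from the float one.
    let int: Tensor<B, 1, Int> = float.clone().int();

    // Bool tensor produced by an element-wise comparison.
    let mask: Tensor<B, 1, Bool> = float.greater_elem(1.5);
}
```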
## Operations
Almost all Burn operations take ownership of the input tensors.
Therefore, reusing a tensor multiple times requires cloning it.
Don't worry: the tensor's buffer isn't copied; only its reference count is increased.
This makes it possible to determine exactly how many times a tensor is used, which is very convenient for reusing tensor buffers and improving performance.
For that reason, we don't provide explicit inplace operations: if a tensor is used only once, inplace operations are applied automatically when available.
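For example (a sketch with arbitrary operations), reusing a tensor only requires an explicit `clone`, while the underlying buffer is shared:

```rust, ignore
use burn::tensor::{backend::Backend, Tensor};

fn ownership_example<B: Backend>(tensor: Tensor<B, 2>) -> Tensor<B, 2> {
    // `exp` consumes its input, so the first use needs an explicit clone.
    let exponential = tensor.clone().exp();

    // The last use can consume `tensor` directly, which lets the backend
    // reuse its buffer and apply inplace operations when possible.
    exponential / (tensor + 1.0)
}
```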
Normally with PyTorch, explicit inplace operations aren't supported during the backward pass, making them useful only for data preprocessing or inference-only model implementations.
With Burn, you can focus more on _what_ the model should do, rather than on _how_ to do it.
We take the responsibility of making your code run as fast as possible during training as well as inference.
The same principles apply to broadcasting; all operations support broadcasting unless specified otherwise.
Here, we provide a list of all supported operations along with their PyTorch equivalents.
Note that for the sake of simplicity, we ignore type signatures.
For more details, refer to the [full documentation](https://docs.rs/burn/latest/burn/tensor/struct.Tensor.html).
### Basic Operations
Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.
| Burn | PyTorch Equivalent |
|-------------------------------------------------------|---------------------------------------------------|
| `Tensor::empty(shape)` | `torch.empty(shape)` |
| `Tensor::empty_device(shape, device)` | `torch.empty(shape, device=device)` |
| `tensor.dims()` | `tensor.size()` |
| `tensor.shape()` | `tensor.shape` |
| `tensor.reshape(shape)` | `tensor.view(shape)` |
| `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` |
| `tensor.squeeze(dim)` | `tensor.squeeze(dim)` |
| `tensor.unsqueeze()` | `tensor.unsqueeze(0)` |
| `tensor.slice(ranges)` | `tensor[(*ranges,)]` |
| `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` |
| `tensor.device()` | `tensor.device` |
| `tensor.to_device(device)` | `tensor.to(device)` |
| `tensor.repeat(2, 4)` | `tensor.repeat([1, 1, 4])` |
| `tensor.equal(other)`                                  | `tensor == other`                                  |
| `Tensor::cat(tensors, dim)` | `torch.cat(tensors, dim)` |
| `tensor.into_data()` | N/A |
| `tensor.to_data()` | N/A |
| `Tensor::from_data(data)` | N/A |
| `Tensor::from_data_device(data, device)` | N/A |
| `tensor.into_primitive()` | N/A |
| `Tensor::from_primitive(primitive)` | N/A |
### Numeric Operations
Those operations are available for numeric tensor kinds: `Float` and `Int`.
| Burn | PyTorch Equivalent |
|----------------------------------------------------|---------------------------------------------------------|
| `tensor.into_scalar()` | `tensor.item()` (for single-element tensors) |
| `tensor + other` or `tensor.add(other)` | `tensor + other` |
| `tensor + scalar` or `tensor.add_scalar(scalar)` | `tensor + scalar` |
| `tensor - other` or `tensor.sub(other)` | `tensor - other` |
| `tensor - scalar` or `tensor.sub_scalar(scalar)` | `tensor - scalar` |
| `tensor / other` or `tensor.div(other)` | `tensor / other` |
| `tensor / scalar` or `tensor.div_scalar(scalar)` | `tensor / scalar` |
| `tensor * other` or `tensor.mul(other)` | `tensor * other` |
| `tensor * scalar` or `tensor.mul_scalar(scalar)` | `tensor * scalar` |
| `-tensor` or `tensor.neg()` | `-tensor` |
| `Tensor::zeros(shape)` | `torch.zeros(shape)` |
| `Tensor::zeros_device(shape, device)` | `torch.zeros(shape, device=device)` |
| `Tensor::ones(shape)` | `torch.ones(shape)` |
| `Tensor::ones_device(shape, device)` | `torch.ones(shape, device=device)` |
| `Tensor::full(shape, fill_value)` | `torch.full(shape, fill_value)` |
| `Tensor::full_device(shape, fill_value, device)` | `torch.full(shape, fill_value, device=device)` |
| `tensor.mean()` | `tensor.mean()` |
| `tensor.sum()` | `tensor.sum()` |
| `tensor.mean_dim(dim)` | `tensor.mean(dim)` |
| `tensor.sum_dim(dim)` | `tensor.sum(dim)` |
| `tensor.equal_elem(other)` | `tensor.eq(other)` |
| `tensor.greater(other)` | `tensor.gt(other)` |
| `tensor.greater_elem(scalar)` | `tensor.gt(scalar)` |
| `tensor.greater_equal(other)` | `tensor.ge(other)` |
| `tensor.greater_equal_elem(scalar)` | `tensor.ge(scalar)` |
| `tensor.lower(other)` | `tensor.lt(other)` |
| `tensor.lower_elem(scalar)` | `tensor.lt(scalar)` |
| `tensor.lower_equal(other)` | `tensor.le(other)` |
| `tensor.lower_equal_elem(scalar)` | `tensor.le(scalar)` |
| `tensor.mask_where(mask, value_tensor)` | `torch.where(mask, value_tensor, tensor)` |
| `tensor.mask_fill(mask, value)` | `tensor.masked_fill(mask, value)` |
| `tensor.gather(dim, indices)` | `torch.gather(tensor, dim, indices)` |
| `tensor.scatter(dim, indices, values)` | `tensor.scatter_add(dim, indices, values)` |
| `tensor.select(dim, indices)` | `tensor.index_select(dim, indices)` |
| `tensor.select_assign(dim, indices, values)` | N/A |
| `tensor.argmax(dim)` | `tensor.argmax(dim)` |
| `tensor.max()` | `tensor.max()` |
| `tensor.max_dim(dim)` | `tensor.max(dim)` |
| `tensor.max_dim_with_indices(dim)` | N/A |
| `tensor.argmin(dim)` | `tensor.argmin(dim)` |
| `tensor.min()` | `tensor.min()` |
| `tensor.min_dim(dim)` | `tensor.min(dim)` |
| `tensor.min_dim_with_indices(dim)` | N/A |
| `tensor.clamp(min, max)` | `torch.clamp(tensor, min=min, max=max)` |
| `tensor.clamp_min(min)` | `torch.clamp(tensor, min=min)` |
| `tensor.clamp_max(max)` | `torch.clamp(tensor, max=max)` |
| `tensor.abs()` | `torch.abs(tensor)` |
### Float Operations
Those operations are only available for `Float` tensors.
| Burn API | PyTorch Equivalent |
|--------------------------------------------------------|---------------------------------------------------|
| `tensor.exp()` | `tensor.exp()` |
| `tensor.log()` | `tensor.log()` |
| `tensor.log1p()` | `tensor.log1p()` |
| `tensor.erf()` | `tensor.erf()` |
| `tensor.powf(value)` | `tensor.pow(value)` |
| `tensor.sqrt()` | `tensor.sqrt()` |
| `tensor.cos()` | `tensor.cos()` |
| `tensor.sin()` | `tensor.sin()` |
| `tensor.tanh()` | `tensor.tanh()` |
| `Tensor::from_floats(floats)`                           | N/A                                                |
| `tensor.int()` | Similar to `tensor.to(torch.long)` |
| `tensor.zeros_like()` | `torch.zeros_like(tensor)` |
| `tensor.ones_like()` | `torch.ones_like(tensor)` |
| `tensor.random_like(distribution)`                      | `torch.rand_like(tensor)` (uniform only)           |
| `tensor.one_hot(index, num_classes)` | N/A |
| `tensor.transpose()` | `tensor.T` |
| `tensor.swap_dims(dim1, dim2)` | `tensor.transpose(dim1, dim2)` |
| `tensor.matmul(other)` | `tensor.matmul(other)` |
| `tensor.var(dim)` | `tensor.var(dim)` |
| `tensor.var_bias(dim)` | N/A |
| `tensor.var_mean(dim)` | N/A |
| `tensor.var_mean_bias(dim)` | N/A |
| `Tensor::random(shape, distribution)`                   | N/A                                                |
| `Tensor::random_device(shape, distribution, device)`    | N/A                                                |
| `tensor.to_full_precision()` | `tensor.to(torch.float)` |
| `Tensor::from_full_precision(tensor)`                   | N/A                                                |
### Int Operations
Those operations are only available for `Int` tensors.
| Burn API | PyTorch Equivalent |
|--------------------------------------------------------|---------------------------------------------------------|
| `Tensor::from_ints(ints)`                               | N/A                                                      |
| `tensor.float()` | Similar to `tensor.to(torch.float)` |
| `Tensor::arange(5..10)`                                 | `torch.arange(start=5, end=10)`                          |
| `Tensor::arange_device(5..10, device)`                  | `torch.arange(start=5, end=10, device=device)`           |
| `Tensor::arange_step(5..10, 2)`                         | `torch.arange(start=5, end=10, step=2)`                  |
| `Tensor::arange_step_device(5..10, 2, device)`          | `torch.arange(start=5, end=10, step=2, device=device)`   |
### Bool Operations
Those operations are only available for `Bool` tensors.
| Burn API | PyTorch Equivalent |
|--------------------------------------------------------|---------------------------------------------------------|
| `tensor.float()` | Similar to `tensor.to(torch.float)` |
| `tensor.int()` | Similar to `tensor.to(torch.long)` |
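To tie a few of these together, here is a small illustrative sketch (the function and its logic are ours) chaining operations from the tables above:

```rust, ignore
use burn::tensor::{backend::Backend, Tensor};

fn combined_example<B: Backend>(x: Tensor<B, 2>, weight: Tensor<B, 2>) -> Tensor<B, 1> {
    // Linear projection followed by a non-linearity.
    let activated = x.matmul(weight).tanh();

    // Zero out the negative entries using a boolean mask.
    let mask = activated.clone().lower_elem(0.0);
    let masked = activated.mask_fill(mask, 0.0);

    // Reduce to a single-element tensor.
    masked.mean()
}
```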

View File

@ -182,6 +182,52 @@ pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
/// Convert the module into a record containing the state.
fn into_record(self) -> Self::Record;
#[cfg(feature = "std")]
/// Save the module to a file using the provided [file recorder](crate::record::FileRecorder).
///
/// List of supported file recorders:
///
/// * [default](crate::record::DefaultFileRecorder)
/// * [bincode](crate::record::BinFileRecorder)
/// * [bincode compressed with gzip](crate::record::BinGzFileRecorder)
/// * [json pretty](crate::record::PrettyJsonFileRecorder)
/// * [json compressed with gzip](crate::record::JsonGzFileRecorder)
/// * [named mpk](crate::record::NamedMpkFileRecorder)
/// * [named mpk compressed with gzip](crate::record::NamedMpkGzFileRecorder)
///
/// ## Notes
///
/// The file extension is automatically added depending on the file recorder provided, you
/// don't have to specify it.
fn save_file<FR: crate::record::FileRecorder, PB: Into<std::path::PathBuf>>(
self,
file_path: PB,
recorder: &FR,
) -> Result<(), crate::record::RecorderError> {
let record = Self::into_record(self);
recorder.record(record, file_path.into())
}
#[cfg(feature = "std")]
/// Load the module from a file using the provided [file recorder](crate::record::FileRecorder).
///
/// The recorder should be the same as the one used to save the module, see
/// [save_file](Self::save_file).
///
/// ## Notes
///
/// The file extension is automatically added depending on the file recorder provided, you
/// don't have to specify it.
fn load_file<FR: crate::record::FileRecorder, PB: Into<std::path::PathBuf>>(
self,
file_path: PB,
recorder: &FR,
) -> Result<Self, crate::record::RecorderError> {
let record = recorder.load(file_path.into())?;
Ok(self.load_record(record))
}
}

/// Module visitor trait.

View File

@ -34,11 +34,11 @@ impl GradientsParams {
    }

    /// Remove the gradients for the given [parameter id](ParamId).
-   pub fn remove<B, const D: usize>(&self, id: &ParamId) -> Option<Tensor<B, D>>
+   pub fn remove<B, const D: usize>(&mut self, id: &ParamId) -> Option<Tensor<B, D>>
    where
        B: Backend,
    {
-       self.container.get(id)
+       self.container.remove(id)
    }

    /// Register a gradients tensor for the given [parameter id](ParamId).

View File

@ -46,6 +46,12 @@ pub struct NamedMpkGzFileRecorder<S: PrecisionSettings> {
_settings: PhantomData<S>,
}
/// File recorder using the [named msgpack](rmp_serde) format.
#[derive(new, Debug, Default, Clone)]
pub struct NamedMpkFileRecorder<S: PrecisionSettings> {
_settings: PhantomData<S>,
}
impl<S: PrecisionSettings> FileRecorder for BinGzFileRecorder<S> {
fn file_extension() -> &'static str {
"bin.gz"
@ -73,6 +79,12 @@ impl<S: PrecisionSettings> FileRecorder for NamedMpkGzFileRecorder<S> {
}
}
impl<S: PrecisionSettings> FileRecorder for NamedMpkFileRecorder<S> {
fn file_extension() -> &'static str {
"mpk"
}
}
macro_rules! str2reader {
(
$file:expr
@ -251,6 +263,34 @@ impl<S: PrecisionSettings> Recorder for NamedMpkGzFileRecorder<S> {
}
}
impl<S: PrecisionSettings> Recorder for NamedMpkFileRecorder<S> {
type Settings = S;
type RecordArgs = PathBuf;
type RecordOutput = ();
type LoadArgs = PathBuf;
fn save_item<I: Serialize>(
&self,
item: I,
mut file: Self::RecordArgs,
) -> Result<(), RecorderError> {
let mut writer = str2writer!(file)?;
rmp_serde::encode::write_named(&mut writer, &item)
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
Ok(())
}
fn load_item<I: DeserializeOwned>(&self, mut file: Self::LoadArgs) -> Result<I, RecorderError> {
let reader = str2reader!(file)?;
let state = rmp_serde::decode::from_read(reader)
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
Ok(state)
}
}
#[cfg(test)]
mod tests {
@ -289,6 +329,11 @@ mod tests {
test_can_save_and_load(NamedMpkGzFileRecorder::<FullPrecisionSettings>::default())
}
#[test]
fn test_can_save_and_load_mpk_format() {
test_can_save_and_load(NamedMpkFileRecorder::<FullPrecisionSettings>::default())
}
fn test_can_save_and_load<Recorder: FileRecorder>(recorder: Recorder) {
let model_before = create_model();
recorder

View File

@ -13,7 +13,7 @@ pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MNISTItem)
    TrainingConfig::load(&format!("{artifact_dir}/config.json")).expect("A config exists");
let record = CompactRecorder::new()
    .load(format!("{artifact_dir}/model").into())
-   .expect("Failed to save trained model");
+   .expect("Failed to load trained model");
let model = config.model.init_with::<B>(record).to_device(&device);

View File

@ -9,7 +9,7 @@ use burn::{
    module::Module,
    nn::loss::CrossEntropyLoss,
    optim::AdamConfig,
-   record::{CompactRecorder, Recorder},
+   record::CompactRecorder,
    tensor::{
        backend::{ADBackend, Backend},
        Int, Tensor,
@ -103,10 +103,7 @@ pub fn train<B: ADBackend>(artifact_dir: &str, config: TrainingConfig, device: B
let model_trained = learner.fit(dataloader_train, dataloader_test);

-CompactRecorder::new()
-    .record(
-        model_trained.into_record(),
-        format!("{artifact_dir}/model").into(),
-    )
+model_trained
+    .save_file(format!("{artifact_dir}/model"), &CompactRecorder::new())
    .expect("Failed to save trained model");
}

View File

@ -4,7 +4,7 @@ use crate::model::Model;
use burn::module::Module;
use burn::optim::decay::WeightDecayConfig;
use burn::optim::AdamConfig;
-use burn::record::{CompactRecorder, NoStdTrainingRecorder, Recorder};
+use burn::record::{CompactRecorder, NoStdTrainingRecorder};
use burn::{
    config::Config,
    data::{dataloader::DataLoaderBuilder, dataset::source::huggingface::MNISTDataset},
@ -72,10 +72,10 @@ pub fn run<B: ADBackend>(device: B::Device) {
.save(format!("{ARTIFACT_DIR}/config.json").as_str()) .save(format!("{ARTIFACT_DIR}/config.json").as_str())
.unwrap(); .unwrap();
NoStdTrainingRecorder::new() model_trained
.record( .save_file(
model_trained.into_record(), format!("{ARTIFACT_DIR}/model"),
format!("{ARTIFACT_DIR}/model").into(), &NoStdTrainingRecorder::new(),
) )
.expect("Failed to save trained model"); .expect("Failed to save trained model");
} }