mirror of https://github.com/tracel-ai/burn.git
Add book + fix some code (#671)
This commit is contained in:
parent
b60d931771
commit
00d3d208b8
|
@ -1,9 +1,20 @@
|
|||
- [Overview](./overview.md)
|
||||
- [Why Burn?](./motivation.md)
|
||||
- [Guide](./guide/README.md)
|
||||
- [Model](./guide/model.md)
|
||||
- [Data](./guide/data.md)
|
||||
- [Training](./guide/training.md)
|
||||
- [Backend](./guide/backend.md)
|
||||
- [Inference](./guide/inference.md)
|
||||
- [Conclusion](./guide/conclusion.md)
|
||||
- [Basic Workflow: From Training to Inference](./basic-workflow/README.md)
|
||||
- [Model](./basic-workflow/model.md)
|
||||
- [Data](./basic-workflow/data.md)
|
||||
- [Training](./basic-workflow/training.md)
|
||||
- [Backend](./basic-workflow/backend.md)
|
||||
- [Inference](./basic-workflow/inference.md)
|
||||
- [Conclusion](./basic-workflow/conclusion.md)
|
||||
- [Building Blocks](./building-blocks/README.md)
|
||||
- [Backend](./building-blocks/backend.md)
|
||||
- [Tensor](./building-blocks/tensor.md)
|
||||
- [Autodiff](./building-blocks/autodiff.md)
|
||||
- [Module](./building-blocks/module.md)
|
||||
- [Import ONNX Model]()
|
||||
- [Advanced]()
|
||||
- [Custom Training Loops]()
|
||||
- [Custom Metric]()
|
||||
- [Custom Kernels]()
|
||||
- [WGPU]()
|
||||
|
|
|
@ -34,7 +34,7 @@ pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MNISTItem)
|
|||
TrainingConfig::load(&format!("{artifact_dir}/config.json")).expect("A config exists");
|
||||
let record = CompactRecorder::new()
|
||||
.load(format!("{artifact_dir}/model").into())
|
||||
.expect("Failed to save trained model");
|
||||
.expect("Failed to load trained model");
|
||||
|
||||
let model = config.model.init_with::<B>(record).to_device(&device);
|
||||
|
||||
|
@ -53,4 +53,4 @@ Then we can fetch the record using the same recorder as we used during training.
|
|||
Finally we can init the model with the configuration and the record before sending it to the wanted device for inference.
|
||||
For simplicity, we can reuse the batcher from training to convert an `MNISTItem` into a tensor.
|
||||
|
||||
By running the infer function, you should see the predictions of your model!
|
||||
By running the infer function, you should see the predictions of your model!
|
Before Width: | Height: | Size: 52 KiB After Width: | Height: | Size: 52 KiB |
|
@ -114,11 +114,8 @@ pub fn train<B: ADBackend>(artifact_dir: &str, config: TrainingConfig, device: B
|
|||
|
||||
let model_trained = learner.fit(dataloader_train, dataloader_test);
|
||||
|
||||
CompactRecorder::new()
|
||||
.record(
|
||||
model_trained.into_record(),
|
||||
format!("{artifact_dir}/model").into(),
|
||||
)
|
||||
model_trained
|
||||
.save_file(format!("{artifact_dir}/model"), &CompactRecorder::new())
|
||||
.expect("Failed to save trained model");
|
||||
}
|
||||
```
|
|
@ -0,0 +1,7 @@
|
|||
# Building Blocks
|
||||
|
||||
In this section, we'll guide you through the core elements that make up Burn.
|
||||
We'll walk you through the key components that serve as the building blocks of the framework and your future projects.
|
||||
|
||||
As you explore Burn, you might notice that we occasionally draw comparisons to PyTorch.
|
||||
We believe these comparisons can provide a smoother learning curve and help you grasp the nuances more effectively.
|
|
@ -0,0 +1,83 @@
|
|||
# Autodiff
|
||||
|
||||
Burn's tensor also supports autodifferentiation, which is an essential part of any deep learning framework.
|
||||
We introduced the `Backend` trait in the [previous section](./backend.md), but Burn also has another trait for autodiff: `ADBackend`.
|
||||
|
||||
However, not all tensors support autodifferentiation; you need a backend that implements both the `Backend` and `ADBackend` traits.
|
||||
Fortunately, you can add autodifferentiation capabilities to any backend using a backend decorator: `type MyAutodiffBackend = ADBackendDecorator<MyBackend>`.
|
||||
This decorator implements both the `ADBackend` and `Backend` traits by maintaining a dynamic computational graph and utilizing the inner backend to execute tensor operations.
|
||||
|
||||
The `ADBackend` trait adds new operations on float tensors that can't be called otherwise.
|
||||
It also provides a new associated type, `B::Gradients`, where each calculated gradient resides.
|
||||
|
||||
```rust, ignore
|
||||
fn calculate_gradients<B: ADBackend>(tensor: Tensor<B, 2>) -> B::Gradients {
|
||||
let mut gradients = tensor.clone().backward();
|
||||
|
||||
let tensor_grad = tensor.grad(&gradients); // get
|
||||
let tensor_grad = tensor.grad_remove(&mut gradients); // pop
|
||||
|
||||
gradients
|
||||
}
|
||||
```
|
||||
|
||||
Note that some functions will always be available even if the backend doesn't implement the `ADBackend` trait.
|
||||
In such cases, those functions will do nothing.
|
||||
|
||||
| Burn API | PyTorch Equivalent |
|
||||
|--------------------------------------------------------|------------------------------------------------------|
|
||||
| `tensor.detach()` | `tensor.detach()` |
|
||||
| `tensor.require_grad()` | `tensor.requires_grad_()` |
|
||||
| `tensor.is_require_grad()` | `tensor.requires_grad` |
|
||||
| `tensor.set_require_grad(require_grad)` | `tensor.requires_grad_(require_grad)` |
|
||||
|
||||
However, you're unlikely to make any mistakes since you can't call `backward` on a tensor that is on a backend that doesn't implement `ADBackend`.
|
||||
Additionally, you can't retrieve the gradient of a tensor without an autodiff backend.
|
||||
|
||||
## Difference with PyTorch
|
||||
|
||||
The way Burn handles gradients is different from PyTorch.
|
||||
First, calling `backward` does not update a `grad` field on each parameter.
|
||||
Instead, the backward pass returns all the calculated gradients in a container.
|
||||
This approach offers numerous benefits, such as the ability to easily send gradients to other threads.
|
||||
|
||||
You can also retrieve the gradient for a specific parameter using the `grad` method on a tensor.
|
||||
Since this method takes the gradients as input, it's hard to forget to call `backward` beforehand.
|
||||
Note that sometimes, using `grad_remove` can improve performance by allowing inplace operations.
|
||||
|
||||
In PyTorch, when you don't need gradients for inference or validation, you typically need to scope your code using a block.
|
||||
|
||||
```python
|
||||
# Inference mode
|
||||
with torch.inference_mode():
|
||||
# your code
|
||||
...
|
||||
|
||||
# Or no grad
|
||||
with torch.no_grad():
|
||||
# your code
|
||||
...
|
||||
```
|
||||
|
||||
With Burn, you don't need to wrap the backend with the `ADBackendDecorator` for inference, and you can call `inner()` to obtain the inner tensor, which is useful for validation.
|
||||
|
||||
```rust, ignore
|
||||
/// Use `B: ADBackend`
|
||||
fn example_validation<B: ADBackend>(tensor: Tensor<B, 2>) {
|
||||
let inner_tensor: Tensor<B::InnerBackend, 2> = tensor.inner();
|
||||
let _ = inner_tensor + 5;
|
||||
}
|
||||
|
||||
/// Use `B: Backend`
|
||||
fn example_inference<B: Backend>(tensor: Tensor<B, 2>) {
|
||||
let _ = tensor + 5;
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
**Gradients with Optimizers**
|
||||
|
||||
We've seen how gradients can be used with tensors, but the process is a bit different when working with optimizers from `burn-core`.
|
||||
To work with the `Module` trait, a translation step is required to link tensor parameters with their gradients.
|
||||
This step is necessary to easily support gradient accumulation and training on multiple devices, where each module can be forked and run on different devices in parallel.
|
||||
We'll explore this topic in more depth in the [Module](./module.md) section.
|
|
@ -0,0 +1,11 @@
|
|||
# Backend
|
||||
|
||||
Nearly everything in Burn is based on the `Backend` trait, which enables you to run tensor operations using different implementations without having to modify your code.
|
||||
While a backend may not necessarily have autodiff capabilities, the `ADBackend` trait specifies when autodiff is needed.
|
||||
This trait not only abstracts operations but also tensor, device, and element types, providing each backend the flexibility they need.
|
||||
It's worth noting that the trait assumes eager mode since Burn fully supports dynamic graphs.
|
||||
However, we may create another API to assist with integrating graph-based backends, without requiring any changes to the user's code.
|
||||
|
||||
Users are not expected to directly use the backend trait methods, as it is primarily designed with backend developers in mind rather than Burn users.
|
||||
Therefore, most Burn userland APIs are generic across backends.
|
||||
This approach helps users discover the API more organically with proper autocomplete and documentation.
|
|
@ -0,0 +1,101 @@
|
|||
# Module
|
||||
|
||||
The `Module` derive allows you to create your own neural network modules, similar to PyTorch.
|
||||
The derive macro only generates the methods needed for your type to act as a parameter container; it makes no assumptions about how the forward pass is declared.
|
||||
|
||||
```rust, ignore
|
||||
use burn::nn;
|
||||
use burn::module::Module;
|
||||
use burn::tensor::backend::Backend;
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct PositionWiseFeedForward<B: Backend> {
|
||||
linear_inner: Linear<B>,
|
||||
linear_outer: Linear<B>,
|
||||
dropout: Dropout,
|
||||
gelu: GELU,
|
||||
}
|
||||
|
||||
impl<B: Backend> PositionWiseFeedForward<B> {
|
||||
/// Normal method added to a struct.
|
||||
pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
|
||||
let x = self.linear_inner.forward(input);
|
||||
let x = self.gelu.forward(x);
|
||||
let x = self.dropout.forward(x);
|
||||
|
||||
self.linear_outer.forward(x)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Note that all fields declared in the struct must also implement the `Module` trait.
|
||||
|
||||
## Tensor
|
||||
|
||||
If you want to create your own module that contains tensors, and not just other modules defined with the `Module` derive, you need to be careful to achieve the behavior you want.
|
||||
|
||||
- `Param<Tensor<B, D>>`:
|
||||
If you want the tensor to be included as a parameter of your modules, you need to wrap the tensor in a `Param` struct.
|
||||
This will create an ID that will be used to identify this parameter.
|
||||
This is essential when performing module optimization and when saving states such as optimizer and module checkpoints.
|
||||
Note that a module's record only contains parameters.
|
||||
|
||||
- `Param<Tensor<B, D>>.set_require_grad(false)`:
|
||||
If you want the tensor to be included as a parameter of your modules, and therefore saved with the module's weights, but you don't want it to be updated by the optimizer.
|
||||
|
||||
- `Tensor<B, D>`:
|
||||
If you want the tensor to act as a constant that can be recreated when instantiating a module.
|
||||
This can be useful when generating sinusoidal embeddings, for example.
|
||||
|
||||
|
||||
## Methods
|
||||
|
||||
These methods are available for all modules.
|
||||
|
||||
| Burn API | PyTorch Equivalent |
|
||||
|--------------------------------------------------------|---------------------------------------------------------|
|
||||
| `module.devices()` | N/A |
|
||||
| `module.fork(device)` | Similar to `module.to(device).detach()` |
|
||||
| `module.to_device(device)` | `module.to(device)` |
|
||||
| `module.no_grad()` | `module.requires_grad_(False)` |
|
||||
| `module.num_params()` | N/A |
|
||||
| `module.visit(visitor)` | N/A |
|
||||
| `module.map(mapper)` | N/A |
|
||||
| `module.into_record()` | Similar to `state_dict` |
|
||||
| `module.load_record(record)` | Similar to `load_state_dict(state_dict)` |
|
||||
| `module.save_file(file_path, recorder)` | N/A |
|
||||
| `module.load_file(file_path, recorder)` | N/A |
|
||||
|
||||
|
||||
Similar to the backend trait, there is also the `ADModule` trait to signify a module with autodiff support.
|
||||
|
||||
| Burn API | PyTorch Equivalent |
|
||||
|--------------------------------------------------------|---------------------------------------------------------|
|
||||
| `module.valid()` | `module.eval()` |
|
||||
|
||||
## Visitor & Mapper
|
||||
|
||||
As mentioned earlier, modules primarily function as parameter containers.
|
||||
Therefore, we naturally offer several ways to perform functions on each parameter.
|
||||
This is distinct from PyTorch, where extending module functionalities is not as straightforward.
|
||||
|
||||
The `map` and `visit` methods are quite similar but serve different purposes.
|
||||
Mapping is used for potentially mutable operations where each parameter of a module can be updated to a new value.
|
||||
In Burn, optimizers are essentially just sophisticated module mappers.
|
||||
Visitors, on the other hand, are used when you don't intend to modify the module but need to retrieve specific information from it, such as the number of parameters or a list of devices in use.
|
||||
|
||||
You can implement your own mapper or visitor by implementing these simple traits:
|
||||
|
||||
```rust, ignore
|
||||
/// Module visitor trait.
|
||||
pub trait ModuleVisitor<B: Backend> {
|
||||
/// Visit a tensor in the module.
|
||||
fn visit<const D: usize>(&mut self, id: &ParamId, tensor: &Tensor<B, D>);
|
||||
}
|
||||
|
||||
/// Module mapper trait.
|
||||
pub trait ModuleMapper<B: Backend> {
|
||||
/// Map a tensor in the module.
|
||||
fn map<const D: usize>(&mut self, id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D>;
|
||||
}
|
||||
```
|
|
@ -0,0 +1,169 @@
|
|||
# Tensor
|
||||
|
||||
As previously explained in the [model section](../basic-workflow/model.md), the Tensor struct has 3 generic arguments: the backend, the number of dimensions (rank), and the kind.
|
||||
|
||||
```rust , ignore
|
||||
Tensor<B, D> // Float tensor (default)
|
||||
Tensor<B, D, Float> // Explicit float tensor
|
||||
Tensor<B, D, Int> // Int tensor
|
||||
Tensor<B, D, Bool> // Bool tensor
|
||||
```
|
||||
|
||||
Note that the specific element types used for `Float`, `Int`, and `Bool` tensors are defined by backend implementations.
|
||||
|
||||
## Operations
|
||||
|
||||
Almost all Burn operations take ownership of the input tensors.
|
||||
Therefore, reusing a tensor multiple times will necessitate cloning it.
|
||||
Don't worry, the tensor's buffer isn't copied; only its reference count is incremented.
|
||||
This makes it possible to determine exactly how many times a tensor is used, which is very convenient for reusing tensor buffers and improving performance.
|
||||
For that reason, we don't provide explicit inplace operations.
|
||||
If a tensor is used only one time, inplace operations will always be used when available.
|
||||
|
||||
Normally with PyTorch, explicit inplace operations aren't supported during the backward pass, making them useful only for data preprocessing or inference-only model implementations.
|
||||
With Burn, you can focus more on _what_ the model should do, rather than on _how_ to do it.
|
||||
We take the responsibility of making your code run as fast as possible during training as well as inference.
|
||||
The same principles apply to broadcasting; all operations support broadcasting unless specified otherwise.
|
||||
|
||||
Here, we provide a list of all supported operations along with their PyTorch equivalents.
|
||||
Note that for the sake of simplicity, we ignore type signatures.
|
||||
For more details, refer to the [full documentation](https://docs.rs/burn/latest/burn/tensor/struct.Tensor.html).
|
||||
|
||||
### Basic Operations
|
||||
|
||||
Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.
|
||||
|
||||
| Burn | PyTorch Equivalent |
|
||||
|-------------------------------------------------------|---------------------------------------------------|
|
||||
| `Tensor::empty(shape)` | `torch.empty(shape)` |
|
||||
| `Tensor::empty_device(shape, device)` | `torch.empty(shape, device=device)` |
|
||||
| `tensor.dims()` | `tensor.size()` |
|
||||
| `tensor.shape()` | `tensor.shape` |
|
||||
| `tensor.reshape(shape)` | `tensor.view(shape)` |
|
||||
| `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` |
|
||||
| `tensor.squeeze(dim)` | `tensor.squeeze(dim)` |
|
||||
| `tensor.unsqueeze()` | `tensor.unsqueeze(0)` |
|
||||
| `tensor.slice(ranges)` | `tensor[(*ranges,)]` |
|
||||
| `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` |
|
||||
| `tensor.device()` | `tensor.device` |
|
||||
| `tensor.to_device(device)` | `tensor.to(device)` |
|
||||
| `tensor.repeat(2, 4)` | `tensor.repeat([1, 1, 4])` |
|
||||
| `tensor.equal(other)` | `x == y` |
|
||||
| `Tensor::cat(tensors, dim)` | `torch.cat(tensors, dim)` |
|
||||
| `tensor.into_data()` | N/A |
|
||||
| `tensor.to_data()` | N/A |
|
||||
| `Tensor::from_data(data)` | N/A |
|
||||
| `Tensor::from_data_device(data, device)` | N/A |
|
||||
| `tensor.into_primitive()` | N/A |
|
||||
| `Tensor::from_primitive(primitive)` | N/A |
|
||||
|
||||
### Numeric Operations
|
||||
|
||||
Those operations are available for numeric tensor kinds: `Float` and `Int`.
|
||||
|
||||
| Burn | PyTorch Equivalent |
|
||||
|----------------------------------------------------|---------------------------------------------------------|
|
||||
| `tensor.into_scalar()` | `tensor.item()` (for single-element tensors) |
|
||||
| `tensor + other` or `tensor.add(other)` | `tensor + other` |
|
||||
| `tensor + scalar` or `tensor.add_scalar(scalar)` | `tensor + scalar` |
|
||||
| `tensor - other` or `tensor.sub(other)` | `tensor - other` |
|
||||
| `tensor - scalar` or `tensor.sub_scalar(scalar)` | `tensor - scalar` |
|
||||
| `tensor / other` or `tensor.div(other)` | `tensor / other` |
|
||||
| `tensor / scalar` or `tensor.div_scalar(scalar)` | `tensor / scalar` |
|
||||
| `tensor * other` or `tensor.mul(other)` | `tensor * other` |
|
||||
| `tensor * scalar` or `tensor.mul_scalar(scalar)` | `tensor * scalar` |
|
||||
| `-tensor` or `tensor.neg()` | `-tensor` |
|
||||
| `Tensor::zeros(shape)` | `torch.zeros(shape)` |
|
||||
| `Tensor::zeros_device(shape, device)` | `torch.zeros(shape, device=device)` |
|
||||
| `Tensor::ones(shape)` | `torch.ones(shape)` |
|
||||
| `Tensor::ones_device(shape, device)` | `torch.ones(shape, device=device)` |
|
||||
| `Tensor::full(shape, fill_value)` | `torch.full(shape, fill_value)` |
|
||||
| `Tensor::full_device(shape, fill_value, device)` | `torch.full(shape, fill_value, device=device)` |
|
||||
| `tensor.mean()` | `tensor.mean()` |
|
||||
| `tensor.sum()` | `tensor.sum()` |
|
||||
| `tensor.mean_dim(dim)` | `tensor.mean(dim)` |
|
||||
| `tensor.sum_dim(dim)` | `tensor.sum(dim)` |
|
||||
| `tensor.equal_elem(other)` | `tensor.eq(other)` |
|
||||
| `tensor.greater(other)` | `tensor.gt(other)` |
|
||||
| `tensor.greater_elem(scalar)` | `tensor.gt(scalar)` |
|
||||
| `tensor.greater_equal(other)` | `tensor.ge(other)` |
|
||||
| `tensor.greater_equal_elem(scalar)` | `tensor.ge(scalar)` |
|
||||
| `tensor.lower(other)` | `tensor.lt(other)` |
|
||||
| `tensor.lower_elem(scalar)` | `tensor.lt(scalar)` |
|
||||
| `tensor.lower_equal(other)` | `tensor.le(other)` |
|
||||
| `tensor.lower_equal_elem(scalar)` | `tensor.le(scalar)` |
|
||||
| `tensor.mask_where(mask, value_tensor)` | `torch.where(mask, value_tensor, tensor)` |
|
||||
| `tensor.mask_fill(mask, value)` | `tensor.masked_fill(mask, value)` |
|
||||
| `tensor.gather(dim, indices)` | `torch.gather(tensor, dim, indices)` |
|
||||
| `tensor.scatter(dim, indices, values)` | `tensor.scatter_add(dim, indices, values)` |
|
||||
| `tensor.select(dim, indices)` | `tensor.index_select(dim, indices)` |
|
||||
| `tensor.select_assign(dim, indices, values)` | N/A |
|
||||
| `tensor.argmax(dim)` | `tensor.argmax(dim)` |
|
||||
| `tensor.max()` | `tensor.max()` |
|
||||
| `tensor.max_dim(dim)` | `tensor.max(dim)` |
|
||||
| `tensor.max_dim_with_indices(dim)` | N/A |
|
||||
| `tensor.argmin(dim)` | `tensor.argmin(dim)` |
|
||||
| `tensor.min()` | `tensor.min()` |
|
||||
| `tensor.min_dim(dim)` | `tensor.min(dim)` |
|
||||
| `tensor.min_dim_with_indices(dim)` | N/A |
|
||||
| `tensor.clamp(min, max)` | `torch.clamp(tensor, min=min, max=max)` |
|
||||
| `tensor.clamp_min(min)` | `torch.clamp(tensor, min=min)` |
|
||||
| `tensor.clamp_max(max)` | `torch.clamp(tensor, max=max)` |
|
||||
| `tensor.abs()` | `torch.abs(tensor)` |
|
||||
|
||||
### Float Operations
|
||||
|
||||
Those operations are only available for `Float` tensors.
|
||||
|
||||
| Burn API | PyTorch Equivalent |
|
||||
|--------------------------------------------------------|---------------------------------------------------|
|
||||
| `tensor.exp()` | `tensor.exp()` |
|
||||
| `tensor.log()` | `tensor.log()` |
|
||||
| `tensor.log1p()` | `tensor.log1p()` |
|
||||
| `tensor.erf()` | `tensor.erf()` |
|
||||
| `tensor.powf(value)` | `tensor.pow(value)` |
|
||||
| `tensor.sqrt()` | `tensor.sqrt()` |
|
||||
| `tensor.cos()` | `tensor.cos()` |
|
||||
| `tensor.sin()` | `tensor.sin()` |
|
||||
| `tensor.tanh()` | `tensor.tanh()` |
|
||||
| `tensor.from_floats(floats)` | N/A |
|
||||
| `tensor.int()` | Similar to `tensor.to(torch.long)` |
|
||||
| `tensor.zeros_like()` | `torch.zeros_like(tensor)` |
|
||||
| `tensor.ones_like()` | `torch.ones_like(tensor)` |
|
||||
| `tensor.random_like(distribution)` | `torch.rand_like()` (uniform distribution only) |
|
||||
| `tensor.one_hot(index, num_classes)` | N/A |
|
||||
| `tensor.transpose()` | `tensor.T` |
|
||||
| `tensor.swap_dims(dim1, dim2)` | `tensor.transpose(dim1, dim2)` |
|
||||
| `tensor.matmul(other)` | `tensor.matmul(other)` |
|
||||
| `tensor.var(dim)` | `tensor.var(dim)` |
|
||||
| `tensor.var_bias(dim)` | N/A |
|
||||
| `tensor.var_mean(dim)` | N/A |
|
||||
| `tensor.var_mean_bias(dim)` | N/A |
|
||||
| `tensor.random(shape, distribution)` | N/A |
|
||||
| `tensor.random_device(shape, distribution, device)` | N/A |
|
||||
| `tensor.to_full_precision()` | `tensor.to(torch.float)` |
|
||||
| `tensor.from_full_precision(tensor)` | N/A |
|
||||
|
||||
|
||||
### Int Operations
|
||||
|
||||
Those operations are only available for `Int` tensors.
|
||||
|
||||
| Burn API | PyTorch Equivalent |
|
||||
|--------------------------------------------------------|---------------------------------------------------------|
|
||||
| `tensor.from_ints(ints)` | N/A |
|
||||
| `tensor.float()` | Similar to `tensor.to(torch.float)` |
|
||||
| `tensor.arange(5..10)` | `torch.arange(start=5, end=10)` |
|
||||
| `tensor.arange_device(5..10, device)` | `torch.arange(start=5, end=10, device=device)` |
|
||||
| `tensor.arange_step(5..10, 2)` | `torch.arange(start=5, end=10, step=2)` |
|
||||
| `tensor.arange_step_device(5..10, 2, device)` | `torch.arange(start=5, end=10, step=2, device=device)` |
|
||||
|
||||
|
||||
### Bool Operations
|
||||
|
||||
Those operations are only available for `Bool` tensors.
|
||||
|
||||
| Burn API | PyTorch Equivalent |
|
||||
|--------------------------------------------------------|---------------------------------------------------------|
|
||||
| `tensor.float()` | Similar to `tensor.to(torch.float)` |
|
||||
| `tensor.int()` | Similar to `tensor.to(torch.long)` |
|
|
@ -182,6 +182,52 @@ pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
|
|||
|
||||
/// Convert the module into a record containing the state.
|
||||
fn into_record(self) -> Self::Record;
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
/// Save the module to a file using the provided [file recorder](crate::record::FileRecorder).
|
||||
///
|
||||
/// List of supported file recorders:
|
||||
///
|
||||
/// * [default](crate::record::DefaultFileRecorder)
|
||||
/// * [bincode](crate::record::BinFileRecorder)
|
||||
/// * [bincode compressed with gzip](crate::record::BinGzFileRecorder)
|
||||
/// * [json pretty](crate::record::PrettyJsonFileRecorder)
|
||||
/// * [json compressed with gzip](crate::record::JsonGzFileRecorder)
|
||||
/// * [named mpk](crate::record::NamedMpkFileRecorder)
|
||||
/// * [named mpk compressed with gzip](crate::record::NamedMpkGzFileRecorder)
|
||||
///
|
||||
/// ## Notes
|
||||
///
|
||||
/// The file extension is automatically added depending on the file recorder provided, you
|
||||
/// don't have to specify it.
|
||||
fn save_file<FR: crate::record::FileRecorder, PB: Into<std::path::PathBuf>>(
|
||||
self,
|
||||
file_path: PB,
|
||||
recorder: &FR,
|
||||
) -> Result<(), crate::record::RecorderError> {
|
||||
let record = Self::into_record(self);
|
||||
recorder.record(record, file_path.into())
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
/// Load the module from a file using the provided [file recorder](crate::record::FileRecorder).
|
||||
///
|
||||
/// The recorder should be the same as the one used to save the module, see
|
||||
/// [save_file](Self::save_file).
|
||||
///
|
||||
/// ## Notes
|
||||
///
|
||||
/// The file extension is automatically added depending on the file recorder provided, you
|
||||
/// don't have to specify it.
|
||||
fn load_file<FR: crate::record::FileRecorder, PB: Into<std::path::PathBuf>>(
|
||||
self,
|
||||
file_path: PB,
|
||||
recorder: &FR,
|
||||
) -> Result<Self, crate::record::RecorderError> {
|
||||
let record = recorder.load(file_path.into())?;
|
||||
|
||||
Ok(self.load_record(record))
|
||||
}
|
||||
}
|
||||
|
||||
/// Module visitor trait.
|
||||
|
|
|
@ -34,11 +34,11 @@ impl GradientsParams {
|
|||
}
|
||||
|
||||
/// Remove the gradients for the given [parameter id](ParamId).
|
||||
pub fn remove<B, const D: usize>(&self, id: &ParamId) -> Option<Tensor<B, D>>
|
||||
pub fn remove<B, const D: usize>(&mut self, id: &ParamId) -> Option<Tensor<B, D>>
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
self.container.get(id)
|
||||
self.container.remove(id)
|
||||
}
|
||||
|
||||
/// Register a gradients tensor for the given [parameter id](ParamId).
|
||||
|
|
|
@ -46,6 +46,12 @@ pub struct NamedMpkGzFileRecorder<S: PrecisionSettings> {
|
|||
_settings: PhantomData<S>,
|
||||
}
|
||||
|
||||
/// File recorder using the [named msgpack](rmp_serde) format.
|
||||
#[derive(new, Debug, Default, Clone)]
|
||||
pub struct NamedMpkFileRecorder<S: PrecisionSettings> {
|
||||
_settings: PhantomData<S>,
|
||||
}
|
||||
|
||||
impl<S: PrecisionSettings> FileRecorder for BinGzFileRecorder<S> {
|
||||
fn file_extension() -> &'static str {
|
||||
"bin.gz"
|
||||
|
@ -73,6 +79,12 @@ impl<S: PrecisionSettings> FileRecorder for NamedMpkGzFileRecorder<S> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<S: PrecisionSettings> FileRecorder for NamedMpkFileRecorder<S> {
|
||||
fn file_extension() -> &'static str {
|
||||
"mpk"
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! str2reader {
|
||||
(
|
||||
$file:expr
|
||||
|
@ -251,6 +263,34 @@ impl<S: PrecisionSettings> Recorder for NamedMpkGzFileRecorder<S> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<S: PrecisionSettings> Recorder for NamedMpkFileRecorder<S> {
|
||||
type Settings = S;
|
||||
type RecordArgs = PathBuf;
|
||||
type RecordOutput = ();
|
||||
type LoadArgs = PathBuf;
|
||||
|
||||
fn save_item<I: Serialize>(
|
||||
&self,
|
||||
item: I,
|
||||
mut file: Self::RecordArgs,
|
||||
) -> Result<(), RecorderError> {
|
||||
let mut writer = str2writer!(file)?;
|
||||
|
||||
rmp_serde::encode::write_named(&mut writer, &item)
|
||||
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load_item<I: DeserializeOwned>(&self, mut file: Self::LoadArgs) -> Result<I, RecorderError> {
|
||||
let reader = str2reader!(file)?;
|
||||
let state = rmp_serde::decode::from_read(reader)
|
||||
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
|
||||
|
||||
Ok(state)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
|
@ -289,6 +329,11 @@ mod tests {
|
|||
test_can_save_and_load(NamedMpkGzFileRecorder::<FullPrecisionSettings>::default())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_can_save_and_load_mpk_format() {
|
||||
test_can_save_and_load(NamedMpkFileRecorder::<FullPrecisionSettings>::default())
|
||||
}
|
||||
|
||||
fn test_can_save_and_load<Recorder: FileRecorder>(recorder: Recorder) {
|
||||
let model_before = create_model();
|
||||
recorder
|
||||
|
|
|
@ -13,7 +13,7 @@ pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MNISTItem)
|
|||
TrainingConfig::load(&format!("{artifact_dir}/config.json")).expect("A config exists");
|
||||
let record = CompactRecorder::new()
|
||||
.load(format!("{artifact_dir}/model").into())
|
||||
.expect("Failed to save trained model");
|
||||
.expect("Failed to load trained model");
|
||||
|
||||
let model = config.model.init_with::<B>(record).to_device(&device);
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@ use burn::{
|
|||
module::Module,
|
||||
nn::loss::CrossEntropyLoss,
|
||||
optim::AdamConfig,
|
||||
record::{CompactRecorder, Recorder},
|
||||
record::CompactRecorder,
|
||||
tensor::{
|
||||
backend::{ADBackend, Backend},
|
||||
Int, Tensor,
|
||||
|
@ -103,10 +103,7 @@ pub fn train<B: ADBackend>(artifact_dir: &str, config: TrainingConfig, device: B
|
|||
|
||||
let model_trained = learner.fit(dataloader_train, dataloader_test);
|
||||
|
||||
CompactRecorder::new()
|
||||
.record(
|
||||
model_trained.into_record(),
|
||||
format!("{artifact_dir}/model").into(),
|
||||
)
|
||||
model_trained
|
||||
.save_file(format!("{artifact_dir}/model"), &CompactRecorder::new())
|
||||
.expect("Failed to save trained model");
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@ use crate::model::Model;
|
|||
use burn::module::Module;
|
||||
use burn::optim::decay::WeightDecayConfig;
|
||||
use burn::optim::AdamConfig;
|
||||
use burn::record::{CompactRecorder, NoStdTrainingRecorder, Recorder};
|
||||
use burn::record::{CompactRecorder, NoStdTrainingRecorder};
|
||||
use burn::{
|
||||
config::Config,
|
||||
data::{dataloader::DataLoaderBuilder, dataset::source::huggingface::MNISTDataset},
|
||||
|
@ -72,10 +72,10 @@ pub fn run<B: ADBackend>(device: B::Device) {
|
|||
.save(format!("{ARTIFACT_DIR}/config.json").as_str())
|
||||
.unwrap();
|
||||
|
||||
NoStdTrainingRecorder::new()
|
||||
.record(
|
||||
model_trained.into_record(),
|
||||
format!("{ARTIFACT_DIR}/model").into(),
|
||||
model_trained
|
||||
.save_file(
|
||||
format!("{ARTIFACT_DIR}/model"),
|
||||
&NoStdTrainingRecorder::new(),
|
||||
)
|
||||
.expect("Failed to save trained model");
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue