mirror of https://github.com/tracel-ai/burn.git

Doc/update readme (#241)

parent 608ee3f124
commit 04d72631d7

README.md (128)

__Sections__

* [Components](#components)
    * [Backend](#backend)
    * [Tensor](#tensor)
    * [Autodiff](#autodiff)
    * [Module](#module)
    * [Config](#config)
    * [Learner](#learner)

* [Training](#learner) with full support for `metric`, `logging` and `checkpointing` 📈
* [Tensor](#tensor) crate with backends as plugins 🔧
* [Tch](https://github.com/burn-rs/burn/tree/main/burn-tch) backend with CPU/GPU support 🚀
* [NdArray](https://github.com/burn-rs/burn/tree/main/burn-ndarray) backend with [`no_std`](#no_std-support) support, running on any platform 👌
* [Autodiff](https://github.com/burn-rs/burn/tree/main/burn-autodiff) backend making any backend differentiable 🌟
* [Dataset](https://github.com/burn-rs/burn/tree/main/burn-dataset) crate with multiple utilities and sources 📚

It may also be a good idea to take a look at the main [components](#components) of `burn`.

* [MNIST](https://github.com/burn-rs/burn/tree/main/examples/mnist) train a model on CPU/GPU using different backends.
* [MNIST Inference Web](https://github.com/burn-rs/burn/tree/main/examples/mnist-inference-web) run a trained model in the browser for inference.
* [Text Classification](https://github.com/burn-rs/burn/tree/main/examples/text-classification) train a transformer encoder from scratch on GPU.
* [Text Generation](https://github.com/burn-rs/burn/tree/main/examples/text-generation) train an autoregressive transformer from scratch on GPU.

### Components

Understanding the key components and philosophy of `burn` can greatly help when beginning to work with the framework.

#### Backend

Nearly everything in `burn` is based on the `Backend` trait, which enables you to run tensor operations using different implementations without having to modify your code.
While a backend may not necessarily have autodiff capabilities, the `ADBackend` trait specifies when autodiff is needed.
This trait not only abstracts operations but also tensor, device and element types, giving each backend the flexibility it needs.
It's worth noting that the trait assumes eager mode since `burn` fully supports dynamic graphs.
However, we may create another API to assist with integrating graph-based backends, without requiring any changes to the user's code.
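
To illustrate the idea, here is a minimal, self-contained sketch of the backend pattern; `SimpleBackend`, `CpuF32`, and `sum_pair` are hypothetical names made up for this example, and the real `Backend` trait is far richer (tensor primitives, devices, many operations):

```rust
// Hypothetical, simplified backend trait — not burn's actual `Backend`.
// It bundles an element type behind an associated item so that generic
// code compiles against any implementation without modification.
trait SimpleBackend {
    type Elem: Copy + std::ops::Add<Output = Self::Elem>;

    fn name() -> &'static str;

    // A default operation implemented in terms of the element type.
    fn add(lhs: Self::Elem, rhs: Self::Elem) -> Self::Elem {
        lhs + rhs
    }
}

// One concrete backend: plain f32 on the CPU.
struct CpuF32;

impl SimpleBackend for CpuF32 {
    type Elem = f32;

    fn name() -> &'static str {
        "cpu-f32"
    }
}

// Generic code runs unchanged on any backend implementation.
fn sum_pair<B: SimpleBackend>(a: B::Elem, b: B::Elem) -> B::Elem {
    B::add(a, b)
}

fn main() {
    println!("{}", CpuF32::name()); // cpu-f32
    println!("{}", sum_pair::<CpuF32>(1.0, 2.0)); // 3
}
```

Swapping in another implementor of the trait requires no change to `sum_pair`, which is the property the `Backend` trait provides at a much larger scale.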

#### Tensor

At the core of burn lies the `Tensor` struct, which encompasses multiple types of tensors, including `Float`, `Int`, and `Bool`.
The element types of these tensors are specified by the backend and are usually designated as a generic argument (e.g., `NdArrayBackend<f32>`).
Although the same struct is used for all tensors, the available methods differ depending on the tensor kind.
You can specify the desired tensor kind by setting the third generic argument, which defaults to `Float`.
The first generic argument specifies the backend, while the second specifies the number of dimensions.

```rust
use burn::tensor::backend::Backend;
use burn::tensor::{Tensor, Int};

fn function<B: Backend>(tensor_float: Tensor<B, 2>) {
    let _tensor_bool = tensor_float.clone().equal_elem(2.0); // Tensor<B, 2, Bool>
    let _tensor_int = tensor_float.argmax(1); // Tensor<B, 2, Int>
}
```

As demonstrated in the previous example, nearly all operations require owned tensors as parameters, which means that calling `Clone` explicitly is necessary when reusing the same tensor multiple times.
However, there's no need to worry: the tensor's data won't be copied; it will be flagged as read-only when multiple tensors use the same allocated memory.
This enables backends to reuse tensor data when possible, similar to a copy-on-write pattern, while remaining completely transparent to the user.

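The copy-on-write idea can be sketched with `Arc` from the standard library; this is an illustrative analogy only, not burn's actual implementation:

```rust
use std::sync::Arc;

// Illustrative copy-on-write buffer update, not burn's real mechanism.
// `Arc::make_mut` reuses the allocation when this `Arc` is the sole
// owner, and clones the underlying `Vec` only when it is shared.
fn scale_in_place(data: &mut Arc<Vec<f32>>, factor: f32) {
    let buf = Arc::make_mut(data);
    for x in buf.iter_mut() {
        *x *= factor;
    }
}

fn main() {
    let mut a = Arc::new(vec![1.0_f32, 2.0, 3.0]);
    let b = Arc::clone(&a); // shared: the next mutation must copy

    scale_in_place(&mut a, 2.0);

    println!("{:?}", a); // [2.0, 4.0, 6.0]
    println!("{:?}", b); // [1.0, 2.0, 3.0] — the shared copy is untouched
}
```

If `b` had been dropped before the call, `make_mut` would have mutated the original buffer with no copy at all, which mirrors how a backend can reuse tensor data when a tensor is uniquely owned.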
#### Autodiff

The `Backend` trait is highly flexible, enabling backpropagation to be implemented using a simple backend decorator, which makes any backend differentiable.

```rust
use burn::tensor::backend::{ADBackend, Backend};
use burn::tensor::{Distribution, Tensor};
use burn_autodiff::ADBackendDecorator;
use burn_ndarray::NdArrayBackend;

fn linear<B: Backend>(x: Tensor<B, 2>, weight: Tensor<B, 2>, bias: Tensor<B, 2>) -> Tensor<B, 2> {
    x.matmul(weight) + bias
}

fn main() {
    type Backend = NdArrayBackend<f32>;

    let weight = Tensor::random([3, 3], Distribution::Standard);
    let bias = Tensor::zeros([1, 3]);
    let x = Tensor::random([3, 3], Distribution::Standard);

    let y = linear::<Backend>(x.clone(), weight.clone(), bias.clone());
    // y.backward() // Method backward doesn't exist

    let y = linear::<ADBackendDecorator<Backend>>(
        Tensor::from_inner(x),
        Tensor::from_inner(weight).require_grad(),
        Tensor::from_inner(bias).require_grad(),
    );
    let _grads = y.backward(); // Method exists
}
```

#### Module

The `Module` derive allows you to create your own neural network modules, similar to PyTorch.
Note that the `Module` derive generates all the necessary methods to make your type essentially a parameter container.
It makes no assumptions about how the forward function is declared.

```rust
use burn::nn;
use burn::module::{Param, Module};
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;

#[derive(Module, Debug)]
pub struct PositionWiseFeedForward<B: Backend> {
    linear_inner: Param<nn::Linear<B>>,
    linear_outer: Param<nn::Linear<B>>,
    dropout: nn::Dropout,
    gelu: nn::GELU,
}

impl<B: Backend> PositionWiseFeedForward<B> {
    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        let x = self.linear_inner.forward(input);
        let x = self.gelu.forward(x);
        let x = self.dropout.forward(x);

        self.linear_outer.forward(x)
    }
}
```

Note that only the fields wrapped inside `Param` are updated during training, and the other fields should implement the `Clone` trait.

#### Config

The `Config` derive lets you define serializable and deserializable configurations.

```rust
use burn::config::Config;

#[derive(Config)]
pub struct PositionWiseFeedForwardConfig {
    pub d_model: usize,
    pub d_ff: usize,
    #[config(default = 0.1)]
    pub dropout: f64,
}
```

The derive also adds useful methods to your config, similar to a builder pattern.

```rust
fn main() {
    let config = PositionWiseFeedForwardConfig::new(512, 2048);
    println!("{}", config.d_model); // 512
    println!("{}", config.d_ff); // 2048
    println!("{}", config.dropout); // 0.1

    let config = config.with_dropout(0.2);
    println!("{}", config.dropout); // 0.2
}
```

See this [example](https://github.com/burn-rs/burn/tree/main/examples/mnist) for reference.

## no_std support

Burn supports `no_std` with `alloc` for the inference mode with the NdArray backend.
Simply disable the default features of the `burn` and `burn-ndarray` crates (the minimum required to run inference mode).
See the [burn-no-std-tests](https://github.com/burn-rs/burn/tree/main/examples/burn-no-std-tests) example as a reference implementation.

Additionally, the `burn-core` and `burn-tensor` crates support `no_std` with `alloc` if you need to include them directly as dependencies (the `burn` crate re-exports `burn-core` and `burn-tensor`).
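A dependency section following that advice might look like the following sketch; `<version>` is a placeholder, so pin it to the current release in a real project:

```toml
[dependencies]
# Default (std) features disabled — the minimum required to run
# inference mode under no_std with the NdArray backend.
burn = { version = "<version>", default-features = false }
burn-ndarray = { version = "<version>", default-features = false }
```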

Note that under `no_std` mode, a random seed is generated at build time if the seed is not initialized by the `Backend::seed` method.
Additionally, [spin::mutex::Mutex](https://docs.rs/spin/latest/spin/mutex/struct.Mutex.html) is used in place of [std::sync::Mutex](https://doc.rust-lang.org/std/sync/struct.Mutex.html) under `no_std` mode.

## License