<div align="center">
<img src="https://raw.githubusercontent.com/burn-rs/burn/main/assets/logo-burn-full.png" width="200px"/>

[![Discord](https://img.shields.io/discord/1038839012602941528.svg?color=7289da&&logo=discord)](https://discord.gg/uPEBbYYDB6)
[![Test Status](https://github.com/burn-rs/burn/actions/workflows/test.yml/badge.svg)](https://github.com/burn-rs/burn/actions/workflows/test.yml)
[![Documentation](https://docs.rs/burn/badge.svg)](https://docs.rs/burn)
[![Current Crates.io Version](https://img.shields.io/crates/v/burn.svg)](https://crates.io/crates/burn)
[![Rust Version](https://img.shields.io/badge/Rust-1.65.0+-blue)](https://releases.rs/docs/1.65.0)
![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)

This library strives to be a comprehensive **deep learning framework** written in Rust, offering
exceptional flexibility. Our objective is to cater to both researchers and practitioners by
simplifying the process of experimenting, training, and deploying models.

<div align="left">

## Features

- Customizable, user-friendly neural network [module](#module) 🔥
- Comprehensive [training](#learner) tools, including `metrics`, `logging`, and `checkpointing` 📈
- Versatile [Tensor](#tensor) crate equipped with pluggable backends 🔧
- [Torch](https://github.com/burn-rs/burn/tree/main/burn-tch) backend, supporting both CPU and GPU 🚀
- [Ndarray](https://github.com/burn-rs/burn/tree/main/burn-ndarray) backend with
  [`no_std`](#support-for-no_std) compatibility, ensuring universal platform adaptability 👌
- [WebGPU](https://github.com/burn-rs/burn/tree/main/burn-wgpu) backend, offering cross-platform,
  browser-inclusive, GPU-based computations 🌐
- [Autodiff](https://github.com/burn-rs/burn/tree/main/burn-autodiff) backend that enables
  differentiability across all backends 🌟
- [Dataset](https://github.com/burn-rs/burn/tree/main/burn-dataset) crate containing a diverse range
  of utilities and sources 📚
- [Import](https://github.com/burn-rs/burn/tree/main/burn-import) crate that simplifies the
  integration of pretrained models 📦

## Supported Platforms

### [Burn-ndarray][1] Backend

| Option     | CPU | GPU | Linux | MacOS | Windows | Android | iOS | WASM |
| :--------- | :-: | :-: | :---: | :---: | :-----: | :-----: | :-: | :--: |
| Pure Rust  | Yes | No  | Yes   | Yes   | Yes     | Yes     | Yes | Yes  |
| Accelerate | Yes | No  | No    | Yes   | No      | No      | Yes | No   |
| Netlib     | Yes | No  | Yes   | Yes   | Yes     | No      | No  | No   |
| Openblas   | Yes | No  | Yes   | Yes   | Yes     | Yes     | Yes | No   |

### [Burn-tch][2] Backend

| Option | CPU | GPU | Linux | MacOS | Windows | Android | iOS | WASM |
| :----- | :-: | :-: | :---: | :---: | :-----: | :-----: | :-: | :--: |
| CPU    | Yes | No  | Yes   | Yes   | Yes     | Yes     | Yes | No   |
| CUDA   | No  | Yes | Yes   | No    | Yes     | No      | No  | No   |
| MPS    | No  | Yes | No    | Yes   | No      | No      | No  | No   |
| Vulkan | Yes | Yes | Yes   | Yes   | Yes     | Yes     | No  | No   |

### [Burn-wgpu][3] Backend

| Option    | CPU | GPU | Linux | MacOS | Windows | Android | iOS | WASM |
| :-------- | :-: | :-: | :---: | :---: | :-----: | :-----: | :-: | :--: |
| Metal     | No  | Yes | No    | Yes   | No      | No      | Yes | No   |
| Vulkan    | Yes | Yes | Yes   | Yes   | Yes     | Yes     | Yes | No   |
| OpenGL    | No  | Yes | Yes   | Yes   | Yes     | Yes     | Yes | No   |
| WebGpu    | No  | Yes | No    | No    | No      | No      | No  | Yes  |
| Dx11/Dx12 | No  | Yes | No    | No    | Yes     | No      | No  | No   |

[1]: https://github.com/burn-rs/burn/tree/main/burn-ndarray
[2]: https://github.com/burn-rs/burn/tree/main/burn-tch
[3]: https://github.com/burn-rs/burn/tree/main/burn-wgpu

## Get Started

The best way to get started with `burn` is to clone the repo and play with the
[examples](#examples). It may also be a good idea to take a look at the main
[components](#components) of `burn` to get a quick overview of the fundamental building blocks. If
you're interested in how the framework works, you can read our
[architecture document](https://github.com/burn-rs/burn/tree/main/ARCHITECTURE.md).

### Examples

- [MNIST](https://github.com/burn-rs/burn/tree/main/examples/mnist): train a model on CPU/GPU using
  different backends.
- [MNIST Inference Web](https://github.com/burn-rs/burn/tree/main/examples/mnist-inference-web):
  run a trained model in the browser for inference.
- [Text Classification](https://github.com/burn-rs/burn/tree/main/examples/text-classification):
  train a transformer encoder from scratch on GPU.
- [Text Generation](https://github.com/burn-rs/burn/tree/main/examples/text-generation): train an
  autoregressive transformer from scratch on GPU.

### Components

Understanding the key components and philosophy of `burn` can greatly help when beginning to work
with the framework.

#### Backend

Nearly everything in `burn` is based on the `Backend` trait, which enables you to run tensor
operations with different implementations without having to modify your code. While a backend may
not necessarily have autodiff capabilities, the `ADBackend` trait specifies when autodiff is needed.
This trait abstracts not only operations but also the tensor, device, and element types, giving each
backend the flexibility it needs. It's worth noting that the trait assumes eager mode, since `burn`
fully supports dynamic graphs. However, we may create another API to assist with integrating
graph-based backends, without requiring any changes to the user's code.
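
For instance, a function written once against the `Backend` trait runs unchanged on any backend. A
minimal sketch (the `scale` helper is our own illustration, not part of the API):

```rust
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;

// Generic over the backend: the same code can run on ndarray, tch, or wgpu.
fn scale<B: Backend>(tensor: Tensor<B, 2>, factor: f64) -> Tensor<B, 2> {
    tensor.mul_scalar(factor)
}
```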

#### Tensor

At the core of burn lies the `Tensor` struct, which encompasses multiple types of tensors, including
`Float`, `Int`, and `Bool`. The element types of these tensors are specified by the backend and are
usually designated as a generic argument (e.g., `NdArrayBackend<f32>`). Although the same struct is
used for all tensors, the available methods differ depending on the tensor kind. You can specify the
desired tensor kind by setting the third generic argument, which defaults to `Float`. The first
generic argument specifies the backend, while the second specifies the number of dimensions.

```rust
use burn::tensor::backend::Backend;
use burn::tensor::{Tensor, Int};

fn function<B: Backend>(tensor_float: Tensor<B, 2>) {
    let _tensor_bool = tensor_float.clone().equal_elem(2.0); // Tensor<B, 2, Bool>
    let _tensor_int = tensor_float.argmax(1); // Tensor<B, 2, Int>
}
```

As demonstrated in the previous example, nearly all operations take owned tensors as parameters,
which means that calling `clone()` explicitly is necessary when reusing the same tensor multiple
times. However, there's no need to worry: the tensor's data won't be copied, it will simply be
flagged as read-only when multiple tensors use the same allocated memory. This enables backends to
reuse tensor data when possible, similar to a copy-on-write pattern, while remaining completely
transparent to the user.
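
For example, a tensor that is needed twice only requires one explicit `clone()`, and no deep copy
takes place. A minimal sketch (the `square_plus_one` helper is our own illustration):

```rust
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;

fn square_plus_one<B: Backend>(x: Tensor<B, 2>) -> Tensor<B, 2> {
    // `x` is used twice, so one explicit clone is needed; the underlying
    // buffer is shared between the two handles rather than copied.
    x.clone().mul(x).add_scalar(1.0)
}
```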

#### Autodiff

The `Backend` trait is highly flexible, enabling backpropagation to be implemented using a simple
backend decorator, which makes any backend differentiable.

```rust
use burn::tensor::backend::Backend;
use burn::tensor::{Distribution, Tensor};
use burn_autodiff::ADBackendDecorator;
use burn_ndarray::NdArrayBackend;

fn linear<B: Backend>(x: Tensor<B, 2>, weight: Tensor<B, 2>, bias: Tensor<B, 2>) -> Tensor<B, 2> {
    x.matmul(weight) + bias
}

fn main() {
    type Backend = NdArrayBackend<f32>;

    let weight = Tensor::random([3, 3], Distribution::Standard);
    let bias = Tensor::zeros([1, 3]);
    let x = Tensor::random([3, 3], Distribution::Standard);

    let y = linear::<Backend>(x.clone(), weight.clone(), bias.clone());
    // y.backward() // Method backward doesn't exist

    let y = linear::<ADBackendDecorator<Backend>>(
        Tensor::from_inner(x),
        Tensor::from_inner(weight).require_grad(),
        Tensor::from_inner(bias).require_grad(),
    );
    let grads = y.backward(); // Method exists
}
```

#### Module

The `Module` derive allows you to create your own neural network modules, similar to PyTorch. The
derive macro only generates the necessary methods to essentially act as a parameter container for
your type; it makes no assumptions about how the forward pass is declared.

```rust
use burn::module::Module;
use burn::nn::{Dropout, Linear, GELU};
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;

#[derive(Module, Debug)]
pub struct PositionWiseFeedForward<B: Backend> {
    linear_inner: Linear<B>,
    linear_outer: Linear<B>,
    dropout: Dropout,
    gelu: GELU,
}

impl<B: Backend> PositionWiseFeedForward<B> {
    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        let x = self.linear_inner.forward(input);
        let x = self.gelu.forward(x);
        let x = self.dropout.forward(x);

        self.linear_outer.forward(x)
    }
}
```

Note that all fields declared in the struct must also implement the `Module` trait. The `Tensor`
struct doesn't implement `Module`, but `Param<Tensor<B, D>>` does.
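
As a rough sketch (the `BiasAdd` module below is our own illustration, not part of `burn`), a raw
tensor parameter would be declared like this:

```rust
use burn::module::{Module, Param};
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;

// A learnable bias stored directly as a parameter; `Param` makes the
// raw tensor field compatible with the `Module` derive.
#[derive(Module, Debug)]
pub struct BiasAdd<B: Backend> {
    bias: Param<Tensor<B, 1>>,
}
```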

#### Config

The `Config` derive lets you define serializable and deserializable configurations or
hyper-parameters for your [modules](#module) or any other components.

```rust
use burn::config::Config;

#[derive(Config)]
pub struct PositionWiseFeedForwardConfig {
    pub d_model: usize,
    pub d_ff: usize,
    #[config(default = 0.1)]
    pub dropout: f64,
}
```

The derive also adds useful methods to your config, similar to a builder pattern.

```rust
fn main() {
    let config = PositionWiseFeedForwardConfig::new(512, 2048);
    println!("{}", config.d_model); // 512
    println!("{}", config.d_ff); // 2048
    println!("{}", config.dropout); // 0.1

    let config = config.with_dropout(0.2);
    println!("{}", config.dropout); // 0.2
}
```

#### Learner

The `Learner` is the main `struct` that lets you train a neural network with support for `logging`,
`metrics`, `checkpointing`, and more. To create a learner, you must use the `LearnerBuilder`.

```rust
use burn::train::LearnerBuilder;
use burn::train::metric::{AccuracyMetric, LossMetric};
use burn::record::DefaultRecordSettings;

fn main() {
    let dataloader_train = ...;
    let dataloader_valid = ...;

    let model = ...;
    let optim = ...;

    let learner = LearnerBuilder::new("/tmp/artifact_dir")
        .metric_train_plot(AccuracyMetric::new())
        .metric_valid_plot(AccuracyMetric::new())
        .metric_train(LossMetric::new())
        .metric_valid(LossMetric::new())
        .with_file_checkpointer::<DefaultRecordSettings>(2)
        .num_epochs(10)
        .build(model, optim);

    let _model_trained = learner.fit(dataloader_train, dataloader_valid);
}
```

See this [example](https://github.com/burn-rs/burn/tree/main/examples/mnist) for real-world usage.

## Support for `no_std`

Burn, including its `burn-ndarray` backend, can work in a `no_std` environment for inference,
provided `alloc` is available. To accomplish this, simply turn off the default features in `burn`
and `burn-ndarray` (the minimum requirement for running inference). You can find a reference example
in [burn-no-std-tests](https://github.com/burn-rs/burn/tree/main/examples/burn-no-std-tests).
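
As a rough sketch, the dependency declarations might look like the following (the version number is
illustrative; check crates.io for the current release):

```toml
[dependencies]
# Default features are disabled to drop the `std` requirement (inference only).
burn = { version = "0.8", default-features = false }
burn-ndarray = { version = "0.8", default-features = false }
```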

The `burn-core` and `burn-tensor` crates also support `no_std` with `alloc`. These crates can be
added directly as dependencies if necessary, as they are reexported by the `burn` crate.

Please be aware that when using `no_std` mode, a random seed will be generated at build time if one
hasn't been set using the `Backend::seed` method. Also,
[spin::mutex::Mutex](https://docs.rs/spin/latest/spin/mutex/struct.Mutex.html) is used instead of
[std::sync::Mutex](https://doc.rust-lang.org/std/sync/struct.Mutex.html) in this mode.

## Contributing

Before contributing, please take a moment to review our
[code of conduct](https://github.com/burn-rs/burn/tree/main/CODE-OF-CONDUCT.md). It's also highly
recommended to read our
[architecture document](https://github.com/burn-rs/burn/tree/main/ARCHITECTURE.md), which explains
our architectural decisions. For more details, please see our [contributing guide](/CONTRIBUTING.md).

## Disclaimer

Burn is currently in active development, and there will be breaking changes. While any resulting
issues are likely to be easy to fix, there are no guarantees at this stage.

## Sponsors

You can sponsor the founder of Burn via his
[GitHub Sponsors profile](https://github.com/sponsors/nathanielsimard). The Burn-rs organization
doesn't yet have a fiscal entity, but other sponsorship methods might become available as the
project grows.

Thanks to all current sponsors 🙏.

<a href="https://github.com/smallstepman"><img src="https://github.com/smallstepman.png" width="60px" style="border-radius: 50%;" alt="smallstepman" /></a>

## License

Burn is distributed under the terms of both the MIT license and the Apache License (Version 2.0).
See [LICENSE-APACHE](./LICENSE-APACHE) and [LICENSE-MIT](./LICENSE-MIT) for details. Opening a pull
request is assumed to signal agreement with these licensing terms.