<div align="center">
<img src="https://raw.githubusercontent.com/burn-rs/burn/main/assets/logo-burn-full.png" width="200px"/>

[![Current Crates.io Version](https://img.shields.io/crates/v/burn.svg)](https://crates.io/crates/burn)
[![Test Status](https://github.com/burn-rs/burn/actions/workflows/test.yml/badge.svg)](https://github.com/burn-rs/burn/actions/workflows/test.yml)
[![Documentation](https://docs.rs/burn/badge.svg)](https://docs.rs/burn)
[![Rust Version](https://img.shields.io/badge/Rust-1.65.0-blue)](https://releases.rs/docs/released/1.65.0)
[![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/burn-rs/burn/blob/master/LICENSE)

> This library aims to be a complete deep learning framework with extreme flexibility, written in Rust.
> The goal is to satisfy researchers as well as practitioners, making it easier to experiment with, train, and deploy your models.

<div align="left">
__Sections__

* [Features](#features)
* [Get Started](#get-started)
* [Examples](#examples)
* [Components](#components)
* [Backend](#backend)
* [Tensor](#tensor)
* [Module](#module)
* [Config](#config)
* [Learner](#learner)
* [License](#license)
## Features

* Flexible and intuitive custom neural network [module](#module) 🔥
* [Training](#learner) with full support for `metric`, `logging` and `checkpointing` 📈
* [Tensor](#tensor) crate with backends as plugins 🔧
  * [Tch](https://github.com/burn-rs/burn/tree/main/burn-tch) backend with CPU/GPU support 🚀
  * [NdArray](https://github.com/burn-rs/burn/tree/main/burn-ndarray) backend with fast compile time 👌
  * [Autodiff](https://github.com/burn-rs/burn/tree/main/burn-autodiff) backend making any backend differentiable 🌟
* [Dataset](https://github.com/burn-rs/burn/tree/main/burn-dataset) crate with multiple utilities and sources 📚
## Get Started

The best way to get started with `burn` is to clone the repo and play with the [examples](#examples).
It may also be a good idea to take a look at the main [components](#components) of `burn` to get a quick overview of its fundamental building blocks.
### Examples
* [MNIST](https://github.com/burn-rs/burn/tree/main/examples/mnist): train a model on CPU/GPU using different backends.
* [Text Classification](https://github.com/burn-rs/burn/tree/main/examples/text-classification): train a transformer encoder from scratch on GPU.
### Components
Knowing the main components will be of great help when you start playing with `burn`.
#### Backend
Almost everything is based on the `Backend` trait, which lets you run tensor operations with different implementations without having to change your code.
A backend does not necessarily have autodiff capabilities; the `ADBackend` trait is there to specify when autodiff is required.
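The relationship between the two traits can be sketched in plain Rust (a simplified analogue with hypothetical names, not burn's actual definitions): `ADBackend` extends `Backend`, so code that needs gradients only tightens its trait bound, and a decorator type supplies the extra capability for any backend.

```rust
// Simplified analogue of the trait hierarchy (hypothetical names, not burn's API):
// `ADBackend` is a supertrait-extension of `Backend`, so any code that needs
// autodiff just asks for the stronger bound.
trait Backend {
    fn name() -> &'static str;
}

trait ADBackend: Backend {
    fn backward() -> &'static str;
}

struct PlainBackend;
impl Backend for PlainBackend {
    fn name() -> &'static str { "plain" }
}

// A decorator that adds autodiff capabilities on top of any backend.
struct ADDecorator<B: Backend>(std::marker::PhantomData<B>);
impl<B: Backend> Backend for ADDecorator<B> {
    fn name() -> &'static str { B::name() }
}
impl<B: Backend> ADBackend for ADDecorator<B> {
    fn backward() -> &'static str { "gradients" }
}

// Inference-style code only needs `Backend`; training code needs `ADBackend`.
fn forward<B: Backend>() -> &'static str { B::name() }
fn train<B: ADBackend>() -> &'static str { B::backward() }

fn main() {
    println!("{}", forward::<PlainBackend>());            // compiles
    // train::<PlainBackend>();                           // would not compile
    println!("{}", train::<ADDecorator<PlainBackend>>()); // compiles
}
```

The same mechanism is why, in the Tensor example below, the gradient function compiles only for decorated backends.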
#### Tensor
The `Tensor` struct is at the core of the `burn` framework.
It takes two generic parameters: the `Backend` and the number of dimensions `D`.
Backpropagation is supported on any backend by making it auto-differentiable using a simple decorator.
```rust
use burn::tensor::backend::{ADBackend, Backend};
use burn::tensor::{Distribution, Tensor};
use burn_autodiff::ADBackendDecorator;
use burn_ndarray::NdArrayBackend;
use burn_tch::TchBackend;

fn simple_function<B: Backend>() -> Tensor<B, 2> {
    let x = Tensor::<B, 2>::random([3, 3], Distribution::Standard);
    let y = Tensor::<B, 2>::random([3, 3], Distribution::Standard);

    x.matmul(&y)
}

fn simple_function_grads<B: ADBackend>() -> B::Gradients {
    let z = simple_function::<B>();

    z.backward()
}

fn main() {
    let _z = simple_function::<NdArrayBackend<f32>>(); // Compiles
    let _z = simple_function::<TchBackend<f32>>(); // Compiles

    let _grads = simple_function_grads::<NdArrayBackend<f32>>(); // Doesn't compile
    let _grads = simple_function_grads::<TchBackend<f32>>(); // Doesn't compile

    type ADNdArrayBackend = ADBackendDecorator<NdArrayBackend<f32>>;
    type ADTchBackend = ADBackendDecorator<TchBackend<f32>>;

    let _grads = simple_function_grads::<ADNdArrayBackend>(); // Compiles
    let _grads = simple_function_grads::<ADTchBackend>(); // Compiles
}
```
#### Module
The `Module` derive lets you create your own neural network modules, similar to PyTorch.
```rust
use burn::nn;
use burn::module::{Param, Module};
use burn::tensor::backend::Backend;

#[derive(Module, Debug)]
struct MyModule<B: Backend> {
    my_param: Param<nn::Linear<B>>,
    repeat: usize,
}
```
Note that only the fields wrapped inside `Param` are updated during training; the other fields should implement `Clone`.
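The idea can be sketched with a simplified analogue in plain Rust (hypothetical types, not burn's actual machinery): an update pass visits only the `Param`-wrapped values and leaves plain fields untouched.

```rust
// Simplified analogue of the `Param` idea (hypothetical, not burn's API):
// the update pass visits only `Param`-wrapped values; other fields are
// plain state that training never touches.
struct Param<T>(T);

struct MyModule {
    my_param: Param<f32>, // updated during training
    repeat: usize,        // plain configuration, never visited
}

impl MyModule {
    // Apply an update function to every parameter of the module.
    fn update_params(&mut self, f: impl Fn(f32) -> f32) {
        self.my_param.0 = f(self.my_param.0);
        // `self.repeat` is deliberately not visited.
    }
}

fn main() {
    let mut module = MyModule { my_param: Param(1.0), repeat: 3 };
    module.update_params(|w| w * 0.5); // e.g. a mock gradient step
    assert_eq!(module.my_param.0, 0.5);
    assert_eq!(module.repeat, 3); // unchanged
}
```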
#### Config

The `Config` derive lets you define serializable and deserializable configurations or hyper-parameters for your [modules](#module) or any other components.
```rust
use burn::config::Config;

#[derive(Config)]
struct MyConfig {
    #[config(default = 1.0e-6)]
    pub epsilon: f64,
    pub dim: usize,
}
```
The derive also adds useful methods to your config.
```rust
fn main() {
    let config = MyConfig::new(100);
    println!("{}", config.epsilon); // 1.0e-6
    println!("{}", config.dim); // 100

    let config = MyConfig::new(100).with_epsilon(1.0e-8);
    println!("{}", config.epsilon); // 1.0e-8
}
```
#### Learner
The `Learner` is the main `struct` that lets you train a neural network with support for `logging`, `metric`, `checkpointing`, and more.
In order to create a learner, you must use the `LearnerBuilder`.
```rust
use burn::train::LearnerBuilder;
use burn::train::metric::{AccuracyMetric, LossMetric};

fn main() {
    let dataloader_train = ...;
    let dataloader_valid = ...;
    let model = ...;
    let optim = ...;

    let learner = LearnerBuilder::new("/tmp/artifact_dir")
        .metric_train_plot(AccuracyMetric::new())
        .metric_valid_plot(AccuracyMetric::new())
        .metric_train(LossMetric::new())
        .metric_valid(LossMetric::new())
        .with_file_checkpointer::<f32>(2)
        .num_epochs(10)
        .build(model, optim);

    let _model_trained = learner.fit(dataloader_train, dataloader_valid);
}
```
See the [MNIST example](https://github.com/burn-rs/burn/tree/main/examples/mnist) for real-world usage.
## License
Burn is distributed under the terms of both the MIT license and the Apache License (Version 2.0).
See [LICENSE-APACHE](./LICENSE-APACHE) and [LICENSE-MIT](./LICENSE-MIT) for details.
Opening a pull request is assumed to signal agreement with these licensing terms.