mirror of https://github.com/tracel-ai/burn.git
Doc/improve example (#64)
parent 0c4c657854
commit 5b4855317b

@@ -17,7 +17,7 @@ jobs:
       uses: actions-rs/toolchain@v1
       with:
         profile: minimal
-        toolchain: nightly
+        toolchain: beta
         components: rustfmt
         override: true

@@ -37,7 +37,7 @@ jobs:
       uses: actions-rs/toolchain@v1
       with:
         profile: minimal
-        toolchain: nightly
+        toolchain: beta
         components: rustfmt
         override: true

@@ -57,7 +57,7 @@ jobs:
       uses: actions-rs/toolchain@v1
       with:
         profile: minimal
-        toolchain: nightly
+        toolchain: beta
         components: rustfmt
         override: true

@@ -78,7 +78,7 @@ jobs:
      uses: actions-rs/toolchain@v1
      with:
        profile: minimal
-       toolchain: nightly
+       toolchain: beta
        components: rustfmt
        override: true

@@ -14,7 +14,7 @@ jobs:
      uses: actions-rs/toolchain@v1
      with:
        profile: minimal
-       toolchain: nightly
+       toolchain: beta
        components: rustfmt, clippy
        override: true

@@ -14,7 +14,7 @@ jobs:
      uses: actions-rs/toolchain@v1
      with:
        profile: minimal
-       toolchain: nightly
+       toolchain: beta
        components: rustfmt, clippy
        override: true

@@ -14,7 +14,7 @@ jobs:
      uses: actions-rs/toolchain@v1
      with:
        profile: minimal
-       toolchain: nightly
+       toolchain: beta
        components: rustfmt, clippy
        override: true

@@ -1,8 +1,8 @@
 [workspace]

 members = [
     "burn",
     "burn-derive",
     "burn-tensor",
     "burn-dataset",
+    "examples/*",
 ]

@@ -46,7 +46,7 @@ For now there is only one example, but more to come 💪.

 #### MNIST

-The [MNIST](https://github.com/burn-rs/burn/blob/main/burn/examples/mnist.rs) example is not just a small script that shows you how to train a basic model, but a quick one showing you how to:
+The [MNIST](https://github.com/burn-rs/burn/blob/main/examples/mnist) example is not just a small script that shows you how to train a basic model, but a quick one showing you how to:

 * Define your own custom [module](#module) (MLP).
 * Create the data pipeline from a raw dataset to a batched, fast, multi-threaded DataLoader.

@ -1,223 +0,0 @@
|
|||
use burn::config::Config;
|
||||
use burn::data::dataloader::batcher::Batcher;
|
||||
use burn::data::dataloader::DataLoaderBuilder;
|
||||
use burn::data::dataset::source::huggingface::{MNISTDataset, MNISTItem};
|
||||
use burn::module::{Forward, Module, Param};
|
||||
use burn::nn;
|
||||
use burn::optim::decay::WeightDecayConfig;
|
||||
use burn::optim::momentum::MomentumConfig;
|
||||
use burn::optim::{Sgd, SgdConfig};
|
||||
use burn::tensor::backend::{ADBackend, Backend};
|
||||
use burn::tensor::loss::cross_entropy_with_logits;
|
||||
use burn::tensor::{Data, Tensor};
|
||||
use burn::train::metric::{AccuracyMetric, LossMetric};
|
||||
use burn::train::{ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep};
|
||||
use std::sync::Arc;
|
||||
|
||||
static ARTIFACT_DIR: &str = "/tmp/mnist-test-2";
|
||||
|
||||
#[derive(Config)]
|
||||
struct MnistConfig {
|
||||
#[config(default = 6)]
|
||||
num_epochs: usize,
|
||||
#[config(default = 128)]
|
||||
batch_size: usize,
|
||||
#[config(default = 8)]
|
||||
num_workers: usize,
|
||||
#[config(default = 42)]
|
||||
seed: u64,
|
||||
optimizer: SgdConfig,
|
||||
mlp: MlpConfig,
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
struct Model<B: Backend> {
|
||||
mlp: Param<Mlp<B>>,
|
||||
input: Param<nn::Linear<B>>,
|
||||
output: Param<nn::Linear<B>>,
|
||||
}
|
||||
|
||||
#[derive(Config)]
|
||||
struct MlpConfig {
|
||||
#[config(default = 3)]
|
||||
num_layers: usize,
|
||||
#[config(default = 0.5)]
|
||||
dropout: f64,
|
||||
#[config(default = 256)]
|
||||
dim: usize,
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
struct Mlp<B: Backend> {
|
||||
linears: Param<Vec<nn::Linear<B>>>,
|
||||
dropout: nn::Dropout,
|
||||
activation: nn::ReLU,
|
||||
}
|
||||
|
||||
impl<B: Backend> Forward<Tensor<B, 2>, Tensor<B, 2>> for Mlp<B> {
|
||||
fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||
let mut x = input;
|
||||
|
||||
for linear in self.linears.iter() {
|
||||
x = linear.forward(x);
|
||||
x = self.dropout.forward(x);
|
||||
x = self.activation.forward(x);
|
||||
}
|
||||
|
||||
x
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Forward<Tensor<B, 2>, Tensor<B, 2>> for Model<B> {
|
||||
fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||
let mut x = input;
|
||||
|
||||
x = self.input.forward(x);
|
||||
x = self.mlp.forward(x);
|
||||
x = self.output.forward(x);
|
||||
|
||||
x
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Forward<MNISTBatch<B>, ClassificationOutput<B>> for Model<B> {
|
||||
fn forward(&self, item: MNISTBatch<B>) -> ClassificationOutput<B> {
|
||||
let targets = item.targets;
|
||||
let output = self.forward(item.images);
|
||||
let loss = cross_entropy_with_logits(&output, &targets);
|
||||
|
||||
ClassificationOutput {
|
||||
loss,
|
||||
output,
|
||||
targets,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: ADBackend> TrainStep<MNISTBatch<B>, ClassificationOutput<B>> for Model<B> {
|
||||
fn step(&self, item: MNISTBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
|
||||
let item = self.forward(item);
|
||||
TrainOutput::new(item.loss.backward(), item)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ValidStep<MNISTBatch<B>, ClassificationOutput<B>> for Model<B> {
|
||||
fn step(&self, item: MNISTBatch<B>) -> ClassificationOutput<B> {
|
||||
self.forward(item)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Mlp<B> {
|
||||
fn new(config: &MlpConfig) -> Self {
|
||||
let mut linears = Vec::with_capacity(config.num_layers);
|
||||
|
||||
for _ in 0..config.num_layers {
|
||||
let linear = nn::Linear::new(&nn::LinearConfig::new(config.dim, config.dim));
|
||||
linears.push(linear);
|
||||
}
|
||||
|
||||
Self {
|
||||
linears: Param::new(linears),
|
||||
dropout: nn::Dropout::new(&nn::DropoutConfig::new(0.3)),
|
||||
activation: nn::ReLU::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Model<B> {
|
||||
fn new(config: &MnistConfig, d_input: usize, num_classes: usize) -> Self {
|
||||
let mlp = Mlp::new(&config.mlp);
|
||||
let output = nn::Linear::new(&nn::LinearConfig::new(config.mlp.dim, num_classes));
|
||||
let input = nn::Linear::new(&nn::LinearConfig::new(d_input, config.mlp.dim));
|
||||
|
||||
Self {
|
||||
mlp: Param::new(mlp),
|
||||
output: Param::new(output),
|
||||
input: Param::new(input),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct MNISTBatcher<B: Backend> {
|
||||
device: B::Device,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct MNISTBatch<B: Backend> {
|
||||
images: Tensor<B, 2>,
|
||||
targets: Tensor<B, 2>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
|
||||
fn batch(&self, items: Vec<MNISTItem>) -> MNISTBatch<B> {
|
||||
let images = items
|
||||
.iter()
|
||||
.map(|item| Data::<f32, 2>::from(item.image))
|
||||
.map(|data| Tensor::<B, 2>::from_data(data.convert()))
|
||||
.map(|tensor| tensor.reshape([1, 784]))
|
||||
.map(|tensor| tensor / 255)
|
||||
.collect();
|
||||
|
||||
let targets = items
|
||||
.iter()
|
||||
.map(|item| Tensor::<B, 2>::one_hot(item.label, 10))
|
||||
.collect();
|
||||
|
||||
let images = Tensor::cat(images, 0).to_device(self.device).detach();
|
||||
let targets = Tensor::cat(targets, 0).to_device(self.device).detach();
|
||||
|
||||
MNISTBatch { images, targets }
|
||||
}
|
||||
}
|
||||
|
||||
fn run<B: ADBackend>(device: B::Device) {
|
||||
// Config
|
||||
let config_optimizer = SgdConfig::new()
|
||||
.with_learning_rate(2.5e-2)
|
||||
.with_weight_decay(Some(WeightDecayConfig::new(0.05)))
|
||||
.with_momentum(Some(MomentumConfig::new().with_nesterov(true)));
|
||||
let config_mlp = MlpConfig::new();
|
||||
let config = MnistConfig::new(config_optimizer, config_mlp);
|
||||
B::seed(config.seed);
|
||||
|
||||
// Data
|
||||
let batcher_train = Arc::new(MNISTBatcher::<B> { device });
|
||||
let batcher_valid = Arc::new(MNISTBatcher::<B::InnerBackend> { device });
|
||||
let dataloader_train = DataLoaderBuilder::new(batcher_train)
|
||||
.batch_size(config.batch_size)
|
||||
.shuffle(config.seed)
|
||||
.num_workers(config.num_workers)
|
||||
.build(Arc::new(MNISTDataset::train()));
|
||||
let dataloader_test = DataLoaderBuilder::new(batcher_valid)
|
||||
.batch_size(config.batch_size)
|
||||
.num_workers(config.num_workers)
|
||||
.build(Arc::new(MNISTDataset::test()));
|
||||
|
||||
// Model
|
||||
let optim = Sgd::new(&config.optimizer);
|
||||
let mut model = Model::new(&config, 784, 10);
|
||||
model.to_device(device);
|
||||
|
||||
let learner = LearnerBuilder::new(ARTIFACT_DIR)
|
||||
.metric_train_plot(AccuracyMetric::new())
|
||||
.metric_valid_plot(AccuracyMetric::new())
|
||||
.metric_train_plot(LossMetric::new())
|
||||
.metric_valid_plot(LossMetric::new())
|
||||
.with_file_checkpointer::<f32>(2)
|
||||
.num_epochs(config.num_epochs)
|
||||
.build(model, optim);
|
||||
|
||||
let _model_trained = learner.fit(dataloader_train, dataloader_test);
|
||||
|
||||
config
|
||||
.save(format!("{}/config.json", ARTIFACT_DIR).as_str())
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
fn main() {
|
||||
use burn::tensor::backend::{NdArrayADBackend, NdArrayDevice};
|
||||
|
||||
let device = NdArrayDevice::Cpu;
|
||||
run::<NdArrayADBackend<f32>>(device);
|
||||
println!("Done.");
|
||||
}
|
|
@@ -0,0 +1,13 @@
(new file: examples/mnist/Cargo.toml — path inferred from the `examples/*` workspace glob and the `../../burn` dependency path)

[package]
name = "mnist"
version = "0.1.0"
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
license = "MIT/Apache-2.0"
edition = "2021"
publish = false

[dependencies]
burn = { path = "../../burn" }

# Serialization
serde = { version = "1.0", features = ["derive"] }

@@ -0,0 +1,42 @@
(new file: examples/mnist/src/data.rs, per the module declarations in main.rs below)

use burn::{
    data::{dataloader::batcher::Batcher, dataset::source::huggingface::MNISTItem},
    tensor::{backend::Backend, Data, Tensor},
};

pub struct MNISTBatcher<B: Backend> {
    device: B::Device,
}

#[derive(Clone, Debug)]
pub struct MNISTBatch<B: Backend> {
    pub images: Tensor<B, 2>,
    pub targets: Tensor<B, 2>,
}

impl<B: Backend> MNISTBatcher<B> {
    pub fn new(device: B::Device) -> Self {
        Self { device }
    }
}

impl<B: Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
    fn batch(&self, items: Vec<MNISTItem>) -> MNISTBatch<B> {
        let images = items
            .iter()
            .map(|item| Data::<f32, 2>::from(item.image))
            .map(|data| Tensor::<B, 2>::from_data(data.convert()))
            .map(|tensor| tensor.reshape([1, 784]))
            .map(|tensor| tensor / 255)
            .collect();

        let targets = items
            .iter()
            .map(|item| Tensor::<B, 2>::one_hot(item.label, 10))
            .collect();

        let images = Tensor::cat(images, 0).to_device(self.device).detach();
        let targets = Tensor::cat(targets, 0).to_device(self.device).detach();

        MNISTBatch { images, targets }
    }
}

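Note: the batcher above flattens each 28×28 image into a 784-vector, scales pixels by 1/255, one-hot encodes the label over 10 classes, and concatenates rows with Tensor::cat; the detach() calls drop any autodiff history so batch construction never enters the training graph. A dependency-free sketch of that logic follows, with hypothetical Item and Batch types standing in for burn's MNISTItem and MNISTBatch (the field types are assumptions, not burn's definitions):

// Dependency-free sketch of the batching logic; `Item` and `Batch` are
// illustrative stand-ins, not burn types.
struct Item {
    image: [[f32; 28]; 28],
    label: usize,
}

struct Batch {
    images: Vec<f32>,  // len = batch_size * 784, rows of normalized pixels
    targets: Vec<f32>, // len = batch_size * 10, one-hot rows
}

fn batch(items: &[Item]) -> Batch {
    let mut images = Vec::with_capacity(items.len() * 784);
    let mut targets = Vec::with_capacity(items.len() * 10);

    for item in items {
        // reshape([1, 784]) + `/ 255`: flatten and normalize to [0, 1].
        images.extend(item.image.iter().flatten().map(|&p| p / 255.0));

        // one_hot(label, 10): a 10-wide row with a single 1.0.
        let mut row = vec![0.0f32; 10];
        row[item.label] = 1.0;
        targets.extend(row);
    }

    // Tensor::cat(..., 0) corresponds to this row-wise concatenation.
    Batch { images, targets }
}

fn main() {
    let item = Item { image: [[255.0; 28]; 28], label: 3 };
    let b = batch(&[item]);
    assert_eq!(b.images.len(), 784);
    assert_eq!(b.targets[3], 1.0);
    println!("ok");
}
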
@@ -0,0 +1,12 @@
(new file: examples/mnist/src/main.rs)

pub mod data;
pub mod model;

mod training;

fn main() {
    use burn::tensor::backend::{NdArrayADBackend, NdArrayDevice};

    let device = NdArrayDevice::Cpu;
    training::run::<NdArrayADBackend<f32>>(device);
    println!("Done.");
}

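Note: main.rs only chooses a backend type and a device; everything else is generic over ADBackend. A minimal sketch of that pattern under hypothetical traits (burn's real Backend also carries the tensor ops):

// Trimmed-down, hypothetical version of the backend-generic entry point.
trait Backend {
    type Device: Copy;
    fn name() -> &'static str;
}

#[derive(Clone, Copy)]
struct CpuDevice;

struct NdArrayStandIn;

impl Backend for NdArrayStandIn {
    type Device = CpuDevice;
    fn name() -> &'static str {
        "ndarray (stand-in)"
    }
}

fn run<B: Backend>(_device: B::Device) {
    // All training code is written once, against the trait.
    println!("training with backend: {}", B::name());
}

fn main() {
    // Swapping backends is a one-line change at the call site.
    run::<NdArrayStandIn>(CpuDevice);
}
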
@@ -0,0 +1,135 @@
(new file: examples/mnist/src/model.rs)

use crate::data::MNISTBatch;
use burn::{
    config::Config,
    module::{Forward, Module, Param},
    nn,
    optim::SgdConfig,
    tensor::{
        backend::{ADBackend, Backend},
        loss::cross_entropy_with_logits,
        Tensor,
    },
    train::{ClassificationOutput, TrainOutput, TrainStep, ValidStep},
};

#[derive(Config)]
pub struct MnistConfig {
    #[config(default = 6)]
    pub num_epochs: usize,
    #[config(default = 128)]
    pub batch_size: usize,
    #[config(default = 8)]
    pub num_workers: usize,
    #[config(default = 42)]
    pub seed: u64,
    pub optimizer: SgdConfig,
    pub mlp: MlpConfig,
}

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
    mlp: Param<Mlp<B>>,
    input: Param<nn::Linear<B>>,
    output: Param<nn::Linear<B>>,
}

#[derive(Config)]
pub struct MlpConfig {
    #[config(default = 3)]
    pub num_layers: usize,
    #[config(default = 0.5)]
    pub dropout: f64,
    #[config(default = 256)]
    pub dim: usize,
}

#[derive(Module, Debug)]
pub struct Mlp<B: Backend> {
    linears: Param<Vec<nn::Linear<B>>>,
    dropout: nn::Dropout,
    activation: nn::ReLU,
}

impl<B: Backend> Forward<Tensor<B, 2>, Tensor<B, 2>> for Mlp<B> {
    fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        let mut x = input;

        for linear in self.linears.iter() {
            x = linear.forward(x);
            x = self.dropout.forward(x);
            x = self.activation.forward(x);
        }

        x
    }
}

impl<B: Backend> Forward<Tensor<B, 2>, Tensor<B, 2>> for Model<B> {
    fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        let mut x = input;

        x = self.input.forward(x);
        x = self.mlp.forward(x);
        x = self.output.forward(x);

        x
    }
}

impl<B: Backend> Forward<MNISTBatch<B>, ClassificationOutput<B>> for Model<B> {
    fn forward(&self, item: MNISTBatch<B>) -> ClassificationOutput<B> {
        let targets = item.targets;
        let output = self.forward(item.images);
        let loss = cross_entropy_with_logits(&output, &targets);

        ClassificationOutput {
            loss,
            output,
            targets,
        }
    }
}

impl<B: ADBackend> TrainStep<MNISTBatch<B>, ClassificationOutput<B>> for Model<B> {
    fn step(&self, item: MNISTBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
        let item = self.forward(item);
        TrainOutput::new(item.loss.backward(), item)
    }
}

impl<B: Backend> ValidStep<MNISTBatch<B>, ClassificationOutput<B>> for Model<B> {
    fn step(&self, item: MNISTBatch<B>) -> ClassificationOutput<B> {
        self.forward(item)
    }
}

impl<B: Backend> Mlp<B> {
    pub fn new(config: &MlpConfig) -> Self {
        let mut linears = Vec::with_capacity(config.num_layers);

        for _ in 0..config.num_layers {
            let linear = nn::Linear::new(&nn::LinearConfig::new(config.dim, config.dim));
            linears.push(linear);
        }

        Self {
            linears: Param::new(linears),
            dropout: nn::Dropout::new(&nn::DropoutConfig::new(0.3)),
            activation: nn::ReLU::new(),
        }
    }
}

impl<B: Backend> Model<B> {
    pub fn new(config: &MnistConfig, d_input: usize, num_classes: usize) -> Self {
        let mlp = Mlp::new(&config.mlp);
        let output = nn::Linear::new(&nn::LinearConfig::new(config.mlp.dim, num_classes));
        let input = nn::Linear::new(&nn::LinearConfig::new(d_input, config.mlp.dim));

        Self {
            mlp: Param::new(mlp),
            output: Param::new(output),
            input: Param::new(input),
        }
    }
}

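Note: model.rs implements Forward several times for the same type with different (input, output) pairs: tensors in and out for plain inference, an MNISTBatch in and a ClassificationOutput out for training. The generic trait makes `forward` statically overloaded by argument type. A self-contained sketch of the pattern (Doubler is illustrative, not a burn type):

// One type implements `Forward` for several (input, output) pairs,
// so the compiler picks the impl from the argument type.
trait Forward<In, Out> {
    fn forward(&self, input: In) -> Out;
}

struct Doubler;

impl Forward<i32, i32> for Doubler {
    fn forward(&self, input: i32) -> i32 {
        input * 2
    }
}

impl Forward<Vec<i32>, Vec<i32>> for Doubler {
    fn forward(&self, input: Vec<i32>) -> Vec<i32> {
        input.into_iter().map(|x| x * 2).collect()
    }
}

fn main() {
    let model = Doubler;
    assert_eq!(model.forward(21), 42);
    assert_eq!(model.forward(vec![1, 2, 3]), vec![2, 4, 6]);
    println!("ok");
}
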
@@ -0,0 +1,60 @@
(new file: examples/mnist/src/training.rs)

use crate::data::MNISTBatcher;
use crate::model::{MlpConfig, MnistConfig, Model};
use burn::{
    config::Config,
    data::{dataloader::DataLoaderBuilder, dataset::source::huggingface::MNISTDataset},
    module::Module,
    optim::{decay::WeightDecayConfig, momentum::MomentumConfig, Sgd, SgdConfig},
    tensor::backend::ADBackend,
    train::{
        metric::{AccuracyMetric, LossMetric},
        LearnerBuilder,
    },
};
use std::sync::Arc;

static ARTIFACT_DIR: &str = "/tmp/mnist-test-2";

pub fn run<B: ADBackend>(device: B::Device) {
    // Config
    let config_optimizer = SgdConfig::new()
        .with_learning_rate(2.5e-2)
        .with_weight_decay(Some(WeightDecayConfig::new(0.05)))
        .with_momentum(Some(MomentumConfig::new().with_nesterov(true)));
    let config_mlp = MlpConfig::new();
    let config = MnistConfig::new(config_optimizer, config_mlp);
    B::seed(config.seed);

    // Data
    let batcher_train = Arc::new(MNISTBatcher::<B>::new(device));
    let batcher_valid = Arc::new(MNISTBatcher::<B::InnerBackend>::new(device));
    let dataloader_train = DataLoaderBuilder::new(batcher_train)
        .batch_size(config.batch_size)
        .shuffle(config.seed)
        .num_workers(config.num_workers)
        .build(Arc::new(MNISTDataset::train()));
    let dataloader_test = DataLoaderBuilder::new(batcher_valid)
        .batch_size(config.batch_size)
        .num_workers(config.num_workers)
        .build(Arc::new(MNISTDataset::test()));

    // Model
    let optim = Sgd::new(&config.optimizer);
    let mut model = Model::new(&config, 784, 10);
    model.to_device(device);

    let learner = LearnerBuilder::new(ARTIFACT_DIR)
        .metric_train_plot(AccuracyMetric::new())
        .metric_valid_plot(AccuracyMetric::new())
        .metric_train_plot(LossMetric::new())
        .metric_valid_plot(LossMetric::new())
        .with_file_checkpointer::<f32>(2)
        .num_epochs(config.num_epochs)
        .build(model, optim);

    let _model_trained = learner.fit(dataloader_train, dataloader_test);

    config
        .save(format!("{}/config.json", ARTIFACT_DIR).as_str())
        .unwrap();
}

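Note: run() builds the training batcher on B but the validation batcher on B::InnerBackend, so validation runs the same ops without gradient tracking. A sketch of that associated-type split, with hypothetical traits (in this commit, NdArrayADBackend plays the Autodiff role and its InnerBackend the plain one):

use std::marker::PhantomData;

// Hypothetical traits mirroring the ADBackend / InnerBackend split.
trait Backend {}

trait ADBackend: Backend {
    // The wrapped backend that runs the same ops without autodiff.
    type InnerBackend: Backend;
}

struct PlainBackend;
impl Backend for PlainBackend {}

struct Autodiff<B: Backend>(PhantomData<B>);
impl<B: Backend> Backend for Autodiff<B> {}
impl<B: Backend> ADBackend for Autodiff<B> {
    type InnerBackend = B;
}

fn build_batcher<B: Backend>() {
    // In the real code this constructs an MNISTBatcher<B>.
}

fn run<B: ADBackend>() {
    build_batcher::<B>(); // training: ops recorded for backprop
    build_batcher::<B::InnerBackend>(); // validation: no autodiff overhead
}

fn main() {
    run::<Autodiff<PlainBackend>>();
}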