mirror of https://github.com/tracel-ai/burn.git
docs(book-&-examples): modify book and examples with new `prelude` module (#1372)
This commit is contained in:
parent
57887e7a47
commit
330552afb4
|
@ -17,7 +17,7 @@ at `examples/guide/` [directory](https://github.com/tracel-ai/burn/tree/main/exa
|
|||
```rust , ignore
|
||||
use burn::{
|
||||
data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
|
||||
tensor::{backend::Backend, Data, ElementConversion, Int, Tensor},
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
pub struct MnistBatcher<B: Backend> {
|
||||
|
|
|
@ -30,14 +30,12 @@ Let us start by defining our model struct in a new file `src/model.rs`.
|
|||
|
||||
```rust , ignore
|
||||
use burn::{
|
||||
config::Config,
|
||||
module::Module,
|
||||
nn::{
|
||||
conv::{Conv2d, Conv2dConfig},
|
||||
pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig},
|
||||
Dropout, DropoutConfig, Linear, LinearConfig, Relu,
|
||||
},
|
||||
tensor::{backend::Backend, Tensor},
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
|
|
|
@ -11,7 +11,6 @@ to your types, allowing you to define default values with ease. Additionally, al
|
|||
serialized, reducing potential bugs when upgrading versions and improving reproducibility.
|
||||
|
||||
```rust , ignore
|
||||
#[derive(Config)]
|
||||
use burn::config::Config;
|
||||
|
||||
#[derive(Config)]
|
||||
|
|
|
@ -5,7 +5,6 @@ derive function only generates the necessary methods to essentially act as a par
|
|||
your type, it makes no assumptions about how the forward pass is declared.
|
||||
|
||||
```rust, ignore
|
||||
use burn::nn;
|
||||
use burn::module::Module;
|
||||
use burn::tensor::backend::Backend;
|
||||
|
||||
|
|
|
@ -168,6 +168,37 @@ While the previous example is somewhat trivial, the upcoming
|
|||
basic workflow section will walk you through a much more relevant example for
|
||||
deep learning applications.
|
||||
|
||||
## Using `prelude`
|
||||
|
||||
Burn comes with a variety of things in its core library.
|
||||
When creating a new model or using an existing one for inference,
|
||||
you may need to import every single component you used, which could be a little verbose.
|
||||
|
||||
To address it, a `prelude` module is provided, allowing you to easily import commonly used structs and macros as a group:
|
||||
|
||||
```rust, ignore
|
||||
use burn::prelude::*;
|
||||
```
|
||||
|
||||
which is equal to:
|
||||
|
||||
```rust, ignore
|
||||
use burn::{
|
||||
config::Config,
|
||||
module::Module,
|
||||
nn,
|
||||
tensor::{
|
||||
backend::Backend, Bool, Data, Device, ElementConversion, Float, Int, Shape, Tensor,
|
||||
},
|
||||
};
|
||||
```
|
||||
|
||||
<div class="warning">
|
||||
|
||||
For the sake of simplicity, the subsequent chapters of this book will all use this form of importing. However, this does not include the content in the [Building Blocks](./building-blocks) chapter, as explicit importing aids users in grasping the usage of particular structures and macros.
|
||||
|
||||
</div>
|
||||
|
||||
## Running examples
|
||||
|
||||
Many additional Burn examples available in the
|
||||
|
|
|
@ -48,9 +48,8 @@ something like this:
|
|||
|
||||
```rust
|
||||
use burn::{
|
||||
module::Module,
|
||||
nn::conv::{Conv2d, Conv2dConfig},
|
||||
tensor::{backend::Backend, Tensor},
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
|
|
|
@ -71,6 +71,8 @@ pub mod prelude {
|
|||
config::Config,
|
||||
module::Module,
|
||||
nn,
|
||||
tensor::{backend::Backend, Data, Device, ElementConversion, Tensor},
|
||||
tensor::{
|
||||
backend::Backend, Bool, Data, Device, ElementConversion, Float, Int, Shape, Tensor,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
#[cfg(feature = "tch-gpu")]
|
||||
mod tch_gpu {
|
||||
use burn::backend::libtorch::{LibTorch, LibTorchDevice};
|
||||
use burn::backend::Autodiff;
|
||||
use burn::optim::momentum::MomentumConfig;
|
||||
use burn::optim::SgdConfig;
|
||||
use burn::{
|
||||
backend::{
|
||||
libtorch::{LibTorch, LibTorchDevice},
|
||||
Autodiff,
|
||||
},
|
||||
optim::{momentum::MomentumConfig, SgdConfig},
|
||||
};
|
||||
use custom_image_dataset::training::{train, TrainingConfig};
|
||||
|
||||
pub fn run() {
|
||||
|
@ -25,10 +28,13 @@ mod tch_gpu {
|
|||
|
||||
#[cfg(feature = "wgpu")]
|
||||
mod wgpu {
|
||||
use burn::backend::wgpu::{Wgpu, WgpuDevice};
|
||||
use burn::backend::Autodiff;
|
||||
use burn::optim::momentum::MomentumConfig;
|
||||
use burn::optim::SgdConfig;
|
||||
use burn::{
|
||||
backend::{
|
||||
wgpu::{Wgpu, WgpuDevice},
|
||||
Autodiff,
|
||||
},
|
||||
optim::{momentum::MomentumConfig, SgdConfig},
|
||||
};
|
||||
use custom_image_dataset::training::{train, TrainingConfig};
|
||||
|
||||
pub fn run() {
|
||||
|
|
|
@ -3,7 +3,7 @@ use burn::{
|
|||
dataloader::batcher::Batcher,
|
||||
dataset::vision::{Annotation, ImageDatasetItem, PixelDepth},
|
||||
},
|
||||
tensor::{backend::Backend, Data, Device, ElementConversion, Int, Shape, Tensor},
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
// CIFAR-10 mean and std values
|
||||
|
|
|
@ -1,11 +1,10 @@
|
|||
use burn::{
|
||||
module::Module,
|
||||
nn::{
|
||||
conv::{Conv2d, Conv2dConfig},
|
||||
pool::{MaxPool2d, MaxPool2dConfig},
|
||||
Dropout, DropoutConfig, Linear, LinearConfig, PaddingConfig2d, Relu,
|
||||
},
|
||||
tensor::{backend::Backend, Device, Tensor},
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
/// Basic convolutional neural network with VGG-style blocks.
|
||||
|
|
|
@ -5,21 +5,16 @@ use crate::{
|
|||
dataset::CIFAR10Loader,
|
||||
model::Cnn,
|
||||
};
|
||||
use burn::data::{dataloader::DataLoaderBuilder, dataset::vision::ImageFolderDataset};
|
||||
use burn::train::{
|
||||
metric::{AccuracyMetric, LossMetric},
|
||||
ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,
|
||||
};
|
||||
use burn::{
|
||||
self,
|
||||
config::Config,
|
||||
module::Module,
|
||||
data::{dataloader::DataLoaderBuilder, dataset::vision::ImageFolderDataset},
|
||||
nn::loss::CrossEntropyLossConfig,
|
||||
optim::SgdConfig,
|
||||
prelude::*,
|
||||
record::CompactRecorder,
|
||||
tensor::{
|
||||
backend::{AutodiffBackend, Backend},
|
||||
Int, Tensor,
|
||||
tensor::backend::AutodiffBackend,
|
||||
train::{
|
||||
metric::{AccuracyMetric, LossMetric},
|
||||
ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,
|
||||
},
|
||||
};
|
||||
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
use burn::backend::wgpu::WgpuDevice;
|
||||
use burn::backend::{Autodiff, Wgpu};
|
||||
use burn::backend::{wgpu::WgpuDevice, Autodiff, Wgpu};
|
||||
|
||||
fn main() {
|
||||
custom_renderer::run::<Autodiff<Wgpu>>(WgpuDevice::default());
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
use burn::data::dataset::vision::MnistDataset;
|
||||
use burn::train::renderer::{MetricState, MetricsRenderer, TrainingProgress};
|
||||
use burn::train::LearnerBuilder;
|
||||
use burn::{
|
||||
config::Config, data::dataloader::DataLoaderBuilder, optim::AdamConfig,
|
||||
config::Config,
|
||||
data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},
|
||||
optim::AdamConfig,
|
||||
tensor::backend::AutodiffBackend,
|
||||
train::{
|
||||
renderer::{MetricState, MetricsRenderer, TrainingProgress},
|
||||
LearnerBuilder,
|
||||
},
|
||||
};
|
||||
use guide::{data::MnistBatcher, model::ModelConfig};
|
||||
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
use burn::backend::wgpu::WgpuDevice;
|
||||
use burn::backend::{Autodiff, Wgpu};
|
||||
use burn::backend::{wgpu::WgpuDevice, Autodiff, Wgpu};
|
||||
|
||||
fn main() {
|
||||
custom_training_loop::run::<Autodiff<Wgpu>>(WgpuDevice::default());
|
||||
|
|
|
@ -1,16 +1,12 @@
|
|||
use std::marker::PhantomData;
|
||||
|
||||
use burn::data::dataset::vision::MnistDataset;
|
||||
use burn::{
|
||||
config::Config,
|
||||
data::dataloader::DataLoaderBuilder,
|
||||
data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},
|
||||
module::AutodiffModule,
|
||||
nn::loss::CrossEntropyLoss,
|
||||
optim::{AdamConfig, GradientsParams, Optimizer},
|
||||
tensor::{
|
||||
backend::{AutodiffBackend, Backend},
|
||||
ElementConversion, Int, Tensor,
|
||||
},
|
||||
prelude::*,
|
||||
tensor::backend::AutodiffBackend,
|
||||
};
|
||||
use guide::{
|
||||
data::{MnistBatch, MnistBatcher},
|
||||
|
|
|
@ -1,15 +1,18 @@
|
|||
use crate::FloatTensor;
|
||||
|
||||
use super::{AutodiffBackend, Backend};
|
||||
use burn::backend::autodiff::{
|
||||
use burn::{
|
||||
backend::{
|
||||
autodiff::{
|
||||
checkpoint::{base::Checkpointer, strategy::CheckpointStrategy},
|
||||
grads::Gradients,
|
||||
ops::{broadcast_shape, Backward, Ops, OpsKind},
|
||||
Autodiff, NodeID,
|
||||
},
|
||||
wgpu::{compute::WgpuRuntime, FloatElement, GraphicsApi, IntElement, JitBackend},
|
||||
},
|
||||
tensor::Shape,
|
||||
};
|
||||
use burn::backend::wgpu::compute::WgpuRuntime;
|
||||
use burn::backend::wgpu::{FloatElement, GraphicsApi, IntElement, JitBackend};
|
||||
use burn::tensor::Shape;
|
||||
|
||||
impl<G: GraphicsApi, F: FloatElement, I: IntElement> AutodiffBackend
|
||||
for Autodiff<JitBackend<WgpuRuntime<G, F, I>>>
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
use crate::FloatTensor;
|
||||
|
||||
use super::Backend;
|
||||
use burn::backend::wgpu::{
|
||||
use burn::{
|
||||
backend::wgpu::{
|
||||
compute::{DynamicKernel, WgpuRuntime, WorkGroup},
|
||||
kernel::{
|
||||
build_info, into_contiguous, DynamicKernelSource, SourceTemplate, StaticKernelSource,
|
||||
|
@ -9,8 +10,9 @@ use burn::backend::wgpu::{
|
|||
kernel_wgsl,
|
||||
tensor::JitTensor,
|
||||
FloatElement, GraphicsApi, IntElement, JitBackend,
|
||||
},
|
||||
tensor::Shape,
|
||||
};
|
||||
use burn::tensor::Shape;
|
||||
use derive_new::new;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
use burn::backend::wgpu::AutoGraphicsApi;
|
||||
use burn::backend::{Autodiff, Wgpu};
|
||||
use burn::data::dataset::Dataset;
|
||||
use burn::optim::AdamConfig;
|
||||
use burn::{
|
||||
backend::{wgpu::AutoGraphicsApi, Autodiff, Wgpu},
|
||||
data::dataset::Dataset,
|
||||
optim::AdamConfig,
|
||||
};
|
||||
use guide::{model::ModelConfig, training::TrainingConfig};
|
||||
|
||||
fn main() {
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use burn::{
|
||||
data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
|
||||
tensor::{backend::Backend, Data, ElementConversion, Int, Tensor},
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
pub struct MnistBatcher<B: Backend> {
|
||||
|
|
|
@ -1,10 +1,8 @@
|
|||
use crate::{data::MnistBatcher, training::TrainingConfig};
|
||||
use burn::data::dataset::vision::MnistItem;
|
||||
use burn::{
|
||||
config::Config,
|
||||
data::dataloader::batcher::Batcher,
|
||||
data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
|
||||
prelude::*,
|
||||
record::{CompactRecorder, Recorder},
|
||||
tensor::backend::Backend,
|
||||
};
|
||||
|
||||
pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MnistItem) {
|
||||
|
|
|
@ -1,12 +1,10 @@
|
|||
use burn::{
|
||||
config::Config,
|
||||
module::Module,
|
||||
nn::{
|
||||
conv::{Conv2d, Conv2dConfig},
|
||||
pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig},
|
||||
Dropout, DropoutConfig, Linear, LinearConfig, Relu,
|
||||
},
|
||||
tensor::{backend::Backend, Tensor},
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
|
|
|
@ -2,22 +2,16 @@ use crate::{
|
|||
data::{MnistBatch, MnistBatcher},
|
||||
model::{Model, ModelConfig},
|
||||
};
|
||||
use burn::data::dataset::vision::MnistDataset;
|
||||
use burn::train::{
|
||||
metric::{AccuracyMetric, LossMetric},
|
||||
ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,
|
||||
};
|
||||
use burn::{
|
||||
self,
|
||||
config::Config,
|
||||
data::dataloader::DataLoaderBuilder,
|
||||
module::Module,
|
||||
data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},
|
||||
nn::loss::CrossEntropyLossConfig,
|
||||
optim::AdamConfig,
|
||||
prelude::*,
|
||||
record::CompactRecorder,
|
||||
tensor::{
|
||||
backend::{AutodiffBackend, Backend},
|
||||
Int, Tensor,
|
||||
tensor::backend::AutodiffBackend,
|
||||
train::{
|
||||
metric::{AccuracyMetric, LossMetric},
|
||||
ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,
|
||||
},
|
||||
};
|
||||
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
/// This build script generates the model code from the ONNX file and the labels from the text file.
|
||||
use std::env;
|
||||
use std::fs::File;
|
||||
use std::io::{BufRead, BufReader, Write};
|
||||
use std::path::Path;
|
||||
use std::{
|
||||
env,
|
||||
fs::File,
|
||||
io::{BufRead, BufReader, Write},
|
||||
path::Path,
|
||||
};
|
||||
|
||||
use burn_import::burn::graph::RecordType;
|
||||
use burn_import::onnx::ModelGen;
|
||||
use burn_import::{burn::graph::RecordType, onnx::ModelGen};
|
||||
|
||||
const LABEL_SOURCE_FILE: &str = "src/model/label.txt";
|
||||
const LABEL_DEST_FILE: &str = "model/label.rs";
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use burn::tensor::{backend::Backend, Tensor};
|
||||
use burn::prelude::*;
|
||||
|
||||
// Values are taken from the [ONNX SqueezeNet]
|
||||
// (https://github.com/onnx/models/tree/main/vision/classification/squeezenet#preprocessing)
|
||||
|
|
|
@ -8,10 +8,7 @@ use core::convert::Into;
|
|||
|
||||
use crate::model::{label::LABELS, normalizer::Normalizer, squeezenet::Model as SqueezenetModel};
|
||||
|
||||
use burn::{
|
||||
backend::NdArray,
|
||||
tensor::{activation::softmax, backend::Backend, Tensor},
|
||||
};
|
||||
use burn::{backend::NdArray, prelude::*, tensor::activation::softmax};
|
||||
|
||||
use burn_candle::Candle;
|
||||
use burn_wgpu::{compute::init_async, AutoGraphicsApi, Wgpu, WgpuDevice};
|
||||
|
|
|
@ -3,9 +3,8 @@
|
|||
// Originally copied from the burn/examples/mnist package
|
||||
|
||||
use burn::{
|
||||
module::Module,
|
||||
nn::{self, BatchNorm, PaddingConfig2d},
|
||||
tensor::{backend::Backend, Tensor},
|
||||
nn::{BatchNorm, PaddingConfig2d},
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
use crate::model::Model;
|
||||
use burn::module::Module;
|
||||
use burn::record::BinBytesRecorder;
|
||||
use burn::record::FullPrecisionSettings;
|
||||
use burn::record::Recorder;
|
||||
use burn::{
|
||||
module::Module,
|
||||
record::{BinBytesRecorder, FullPrecisionSettings, Recorder},
|
||||
};
|
||||
|
||||
#[cfg(feature = "wgpu")]
|
||||
use burn::backend::wgpu::{compute::init_async, AutoGraphicsApi, Wgpu, WgpuDevice};
|
||||
|
|
|
@ -5,8 +5,10 @@
|
|||
feature = "ndarray-blas-accelerate",
|
||||
))]
|
||||
mod ndarray {
|
||||
use burn::backend::ndarray::{NdArray, NdArrayDevice};
|
||||
use burn::backend::Autodiff;
|
||||
use burn::backend::{
|
||||
ndarray::{NdArray, NdArrayDevice},
|
||||
Autodiff,
|
||||
};
|
||||
use mnist::training;
|
||||
|
||||
pub fn run() {
|
||||
|
@ -17,8 +19,10 @@ mod ndarray {
|
|||
|
||||
#[cfg(feature = "tch-gpu")]
|
||||
mod tch_gpu {
|
||||
use burn::backend::libtorch::{LibTorch, LibTorchDevice};
|
||||
use burn::backend::Autodiff;
|
||||
use burn::backend::{
|
||||
libtorch::{LibTorch, LibTorchDevice},
|
||||
Autodiff,
|
||||
};
|
||||
use mnist::training;
|
||||
|
||||
pub fn run() {
|
||||
|
@ -33,8 +37,10 @@ mod tch_gpu {
|
|||
|
||||
#[cfg(feature = "wgpu")]
|
||||
mod wgpu {
|
||||
use burn::backend::wgpu::{Wgpu, WgpuDevice};
|
||||
use burn::backend::Autodiff;
|
||||
use burn::backend::{
|
||||
wgpu::{Wgpu, WgpuDevice},
|
||||
Autodiff,
|
||||
};
|
||||
use mnist::training;
|
||||
|
||||
pub fn run() {
|
||||
|
@ -45,8 +51,10 @@ mod wgpu {
|
|||
|
||||
#[cfg(feature = "tch-cpu")]
|
||||
mod tch_cpu {
|
||||
use burn::backend::libtorch::{LibTorch, LibTorchDevice};
|
||||
use burn::backend::Autodiff;
|
||||
use burn::backend::{
|
||||
libtorch::{LibTorch, LibTorchDevice},
|
||||
Autodiff,
|
||||
};
|
||||
use mnist::training;
|
||||
|
||||
pub fn run() {
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use burn::{
|
||||
data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
|
||||
tensor::{backend::Backend, Data, ElementConversion, Int, Tensor},
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
pub struct MnistBatcher<B: Backend> {
|
||||
|
|
|
@ -1,11 +1,8 @@
|
|||
use crate::data::MnistBatch;
|
||||
use burn::{
|
||||
module::Module,
|
||||
nn::{self, loss::CrossEntropyLossConfig, BatchNorm, PaddingConfig2d},
|
||||
tensor::{
|
||||
backend::{AutodiffBackend, Backend},
|
||||
Tensor,
|
||||
},
|
||||
nn::{loss::CrossEntropyLossConfig, BatchNorm, PaddingConfig2d},
|
||||
prelude::*,
|
||||
tensor::backend::AutodiffBackend,
|
||||
train::{ClassificationOutput, TrainOutput, TrainStep, ValidStep},
|
||||
};
|
||||
|
||||
|
|
|
@ -1,20 +1,17 @@
|
|||
use crate::data::MnistBatcher;
|
||||
use crate::model::Model;
|
||||
use crate::{data::MnistBatcher, model::Model};
|
||||
|
||||
use burn::module::Module;
|
||||
use burn::optim::decay::WeightDecayConfig;
|
||||
use burn::optim::AdamConfig;
|
||||
use burn::record::{CompactRecorder, NoStdTrainingRecorder};
|
||||
use burn::train::metric::store::{Aggregate, Direction, Split};
|
||||
use burn::train::metric::{CpuMemory, CpuTemperature, CpuUse};
|
||||
use burn::train::{MetricEarlyStoppingStrategy, StoppingCondition};
|
||||
use burn::{
|
||||
config::Config,
|
||||
data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},
|
||||
optim::{decay::WeightDecayConfig, AdamConfig},
|
||||
prelude::*,
|
||||
record::{CompactRecorder, NoStdTrainingRecorder},
|
||||
tensor::backend::AutodiffBackend,
|
||||
train::{
|
||||
metric::{AccuracyMetric, LossMetric},
|
||||
LearnerBuilder,
|
||||
metric::{
|
||||
store::{Aggregate, Direction, Split},
|
||||
AccuracyMetric, CpuMemory, CpuTemperature, CpuUse, LossMetric,
|
||||
},
|
||||
LearnerBuilder, MetricEarlyStoppingStrategy, StoppingCondition,
|
||||
},
|
||||
};
|
||||
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
use burn::tensor::backend::Backend;
|
||||
use burn::tensor::{Dim, Distribution, NamedDim, NamedTensor};
|
||||
use burn::tensor::{backend::Backend, Dim, Distribution, NamedDim, NamedTensor};
|
||||
|
||||
NamedDim!(Batch);
|
||||
NamedDim!(SeqLength);
|
||||
|
|
|
@ -37,8 +37,7 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"// Import packages\n",
|
||||
"use burn::tensor::Tensor;\n",
|
||||
"use burn::tensor::backend::Backend;\n",
|
||||
"use burn::prelude::*;\n",
|
||||
"use burn_ndarray::NdArray;\n",
|
||||
"\n",
|
||||
"// Type alias for the backend\n",
|
||||
|
@ -111,7 +110,11 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"vscode": {
|
||||
"languageId": "rust"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
use std::env::args;
|
||||
|
||||
use burn::backend::ndarray::NdArray;
|
||||
use burn::tensor::Tensor;
|
||||
|
||||
use burn::data::dataset::vision::MnistDataset;
|
||||
use burn::data::dataset::Dataset;
|
||||
use burn::{
|
||||
backend::ndarray::NdArray,
|
||||
data::dataset::{vision::MnistDataset, Dataset},
|
||||
tensor::Tensor,
|
||||
};
|
||||
|
||||
use onnx_inference::mnist::Model;
|
||||
|
||||
|
|
|
@ -3,8 +3,10 @@
|
|||
/// 2. Saves the model record to a file using the `NamedMpkFileRecorder`.
|
||||
use std::path::Path;
|
||||
|
||||
use burn::backend::NdArray;
|
||||
use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder};
|
||||
use burn::{
|
||||
backend::NdArray,
|
||||
record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder},
|
||||
};
|
||||
use burn_import::pytorch::PyTorchFileRecorder;
|
||||
|
||||
// Basic backend type (not used directly here).
|
||||
|
|
|
@ -1,20 +1,14 @@
|
|||
use std::env;
|
||||
use std::path::Path;
|
||||
|
||||
use burn::nn::conv::Conv2d;
|
||||
use burn::nn::conv::Conv2dConfig;
|
||||
use burn::nn::BatchNorm;
|
||||
use burn::nn::BatchNormConfig;
|
||||
use burn::nn::Linear;
|
||||
use burn::nn::LinearConfig;
|
||||
use burn::record::FullPrecisionSettings;
|
||||
use burn::record::NamedMpkFileRecorder;
|
||||
use burn::record::Recorder;
|
||||
use burn::tensor::activation::log_softmax;
|
||||
use burn::tensor::activation::relu;
|
||||
use burn::{
|
||||
module::Module,
|
||||
tensor::{backend::Backend, Tensor},
|
||||
nn::{
|
||||
conv::{Conv2d, Conv2dConfig},
|
||||
BatchNorm, BatchNormConfig, Linear, LinearConfig,
|
||||
},
|
||||
prelude::*,
|
||||
record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder},
|
||||
tensor::activation::{log_softmax, relu},
|
||||
};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
use std::env::args;
|
||||
use std::path::Path;
|
||||
|
||||
use burn::backend::ndarray::NdArray;
|
||||
use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder};
|
||||
use burn::tensor::Tensor;
|
||||
|
||||
use burn::data::dataset::vision::MnistDataset;
|
||||
use burn::data::dataset::Dataset;
|
||||
use burn::{
|
||||
backend::ndarray::NdArray,
|
||||
data::dataset::{vision::MnistDataset, Dataset},
|
||||
record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder},
|
||||
tensor::Tensor,
|
||||
};
|
||||
|
||||
use model::Model;
|
||||
|
||||
|
|
|
@ -5,8 +5,10 @@
|
|||
feature = "ndarray-blas-accelerate",
|
||||
))]
|
||||
mod ndarray {
|
||||
use burn::backend::ndarray::{NdArray, NdArrayDevice};
|
||||
use burn::backend::Autodiff;
|
||||
use burn::backend::{
|
||||
ndarray::{NdArray, NdArrayDevice},
|
||||
Autodiff,
|
||||
};
|
||||
use regression::training;
|
||||
|
||||
pub fn run() {
|
||||
|
@ -17,8 +19,10 @@ mod ndarray {
|
|||
|
||||
#[cfg(feature = "tch-gpu")]
|
||||
mod tch_gpu {
|
||||
use burn::backend::libtorch::{LibTorch, LibTorchDevice};
|
||||
use burn::backend::Autodiff;
|
||||
use burn::backend::{
|
||||
libtorch::{LibTorch, LibTorchDevice},
|
||||
Autodiff,
|
||||
};
|
||||
use regression::training;
|
||||
|
||||
pub fn run() {
|
||||
|
@ -33,8 +37,10 @@ mod tch_gpu {
|
|||
|
||||
#[cfg(feature = "wgpu")]
|
||||
mod wgpu {
|
||||
use burn::backend::wgpu::{Wgpu, WgpuDevice};
|
||||
use burn::backend::Autodiff;
|
||||
use burn::backend::{
|
||||
wgpu::{Wgpu, WgpuDevice},
|
||||
Autodiff,
|
||||
};
|
||||
use regression::training;
|
||||
|
||||
pub fn run() {
|
||||
|
@ -45,8 +51,10 @@ mod wgpu {
|
|||
|
||||
#[cfg(feature = "tch-cpu")]
|
||||
mod tch_cpu {
|
||||
use burn::backend::libtorch::{LibTorch, LibTorchDevice};
|
||||
use burn::backend::Autodiff;
|
||||
use burn::backend::{
|
||||
libtorch::{LibTorch, LibTorchDevice},
|
||||
Autodiff,
|
||||
};
|
||||
use regression::training;
|
||||
pub fn run() {
|
||||
let device = LibTorchDevice::Cpu;
|
||||
|
|
|
@ -1,8 +1,13 @@
|
|||
use burn::data::dataloader::batcher::Batcher;
|
||||
use burn::data::dataset::transform::{PartialDataset, ShuffledDataset};
|
||||
use burn::data::dataset::{Dataset, HuggingfaceDatasetLoader, SqliteDataset};
|
||||
use burn::tensor::backend::Backend;
|
||||
use burn::tensor::Tensor;
|
||||
use burn::{
|
||||
data::{
|
||||
dataloader::batcher::Batcher,
|
||||
dataset::{
|
||||
transform::{PartialDataset, ShuffledDataset},
|
||||
Dataset, HuggingfaceDatasetLoader, SqliteDataset,
|
||||
},
|
||||
},
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
|
||||
pub struct DiabetesItem {
|
||||
|
|
|
@ -1,14 +1,11 @@
|
|||
use crate::dataset::DiabetesBatch;
|
||||
use burn::config::Config;
|
||||
use burn::nn::loss::Reduction::Mean;
|
||||
use burn::nn::Relu;
|
||||
use burn::{
|
||||
module::Module,
|
||||
nn::{loss::MseLoss, Linear, LinearConfig},
|
||||
tensor::{
|
||||
backend::{AutodiffBackend, Backend},
|
||||
Tensor,
|
||||
nn::{
|
||||
loss::{MseLoss, Reduction::Mean},
|
||||
Linear, LinearConfig, Relu,
|
||||
},
|
||||
prelude::*,
|
||||
tensor::backend::AutodiffBackend,
|
||||
train::{RegressionOutput, TrainOutput, TrainStep, ValidStep},
|
||||
};
|
||||
|
||||
|
|
|
@ -1,16 +1,16 @@
|
|||
use crate::dataset::{DiabetesBatcher, DiabetesDataset};
|
||||
use crate::model::RegressionModelConfig;
|
||||
use burn::data::dataset::Dataset;
|
||||
use burn::module::Module;
|
||||
use burn::optim::SgdConfig;
|
||||
use burn::record::{CompactRecorder, NoStdTrainingRecorder};
|
||||
use burn::train::metric::store::{Aggregate, Direction, Split};
|
||||
use burn::train::{MetricEarlyStoppingStrategy, StoppingCondition};
|
||||
use burn::{
|
||||
config::Config,
|
||||
data::dataloader::DataLoaderBuilder,
|
||||
data::{dataloader::DataLoaderBuilder, dataset::Dataset},
|
||||
optim::SgdConfig,
|
||||
prelude::*,
|
||||
record::{CompactRecorder, NoStdTrainingRecorder},
|
||||
tensor::backend::AutodiffBackend,
|
||||
train::{metric::LossMetric, LearnerBuilder},
|
||||
train::{
|
||||
metric::store::{Aggregate, Direction, Split},
|
||||
metric::LossMetric,
|
||||
LearnerBuilder, MetricEarlyStoppingStrategy, StoppingCondition,
|
||||
},
|
||||
};
|
||||
|
||||
static ARTIFACT_DIR: &str = "/tmp/burn-example-regression";
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
use burn::nn::transformer::TransformerEncoderConfig;
|
||||
use burn::optim::{decay::WeightDecayConfig, AdamConfig};
|
||||
use burn::tensor::backend::AutodiffBackend;
|
||||
use burn::{
|
||||
nn::transformer::TransformerEncoderConfig,
|
||||
optim::{decay::WeightDecayConfig, AdamConfig},
|
||||
tensor::backend::AutodiffBackend,
|
||||
};
|
||||
|
||||
use text_classification::training::ExperimentConfig;
|
||||
use text_classification::AgNewsDataset;
|
||||
use text_classification::{training::ExperimentConfig, AgNewsDataset};
|
||||
|
||||
#[cfg(not(feature = "f16"))]
|
||||
#[allow(dead_code)]
|
||||
|
@ -35,8 +36,10 @@ pub fn launch<B: AutodiffBackend>(devices: Vec<B::Device>) {
|
|||
feature = "ndarray-blas-accelerate",
|
||||
))]
|
||||
mod ndarray {
|
||||
use burn::backend::ndarray::{NdArray, NdArrayDevice};
|
||||
use burn::backend::Autodiff;
|
||||
use burn::backend::{
|
||||
ndarray::{NdArray, NdArrayDevice},
|
||||
Autodiff,
|
||||
};
|
||||
|
||||
use crate::{launch, ElemType};
|
||||
|
||||
|
@ -47,8 +50,10 @@ mod ndarray {
|
|||
|
||||
#[cfg(feature = "tch-gpu")]
|
||||
mod tch_gpu {
|
||||
use burn::backend::libtorch::{LibTorch, LibTorchDevice};
|
||||
use burn::backend::Autodiff;
|
||||
use burn::backend::{
|
||||
libtorch::{LibTorch, LibTorchDevice},
|
||||
Autodiff,
|
||||
};
|
||||
|
||||
use crate::{launch, ElemType};
|
||||
|
||||
|
@ -64,8 +69,10 @@ mod tch_gpu {
|
|||
|
||||
#[cfg(feature = "tch-cpu")]
|
||||
mod tch_cpu {
|
||||
use burn::backend::libtorch::{LibTorch, LibTorchDevice};
|
||||
use burn::backend::Autodiff;
|
||||
use burn::backend::{
|
||||
libtorch::{LibTorch, LibTorchDevice},
|
||||
Autodiff,
|
||||
};
|
||||
|
||||
use crate::{launch, ElemType};
|
||||
|
||||
|
@ -77,8 +84,10 @@ mod tch_cpu {
|
|||
#[cfg(feature = "wgpu")]
|
||||
mod wgpu {
|
||||
use crate::{launch, ElemType};
|
||||
use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};
|
||||
use burn::backend::Autodiff;
|
||||
use burn::backend::{
|
||||
wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice},
|
||||
Autodiff,
|
||||
};
|
||||
|
||||
pub fn run() {
|
||||
launch::<Autodiff<Wgpu<AutoGraphicsApi, ElemType, i32>>>(vec![WgpuDevice::default()]);
|
||||
|
|
|
@ -34,8 +34,10 @@ pub fn launch<B: AutodiffBackend>(device: B::Device) {
|
|||
feature = "ndarray-blas-accelerate",
|
||||
))]
|
||||
mod ndarray {
|
||||
use burn::backend::ndarray::{NdArray, NdArrayDevice};
|
||||
use burn::backend::Autodiff;
|
||||
use burn::backend::{
|
||||
ndarray::{NdArray, NdArrayDevice},
|
||||
Autodiff,
|
||||
};
|
||||
|
||||
use crate::{launch, ElemType};
|
||||
|
||||
|
@ -46,8 +48,10 @@ mod ndarray {
|
|||
|
||||
#[cfg(feature = "tch-gpu")]
|
||||
mod tch_gpu {
|
||||
use burn::backend::libtorch::{LibTorch, LibTorchDevice};
|
||||
use burn::backend::Autodiff;
|
||||
use burn::backend::{
|
||||
libtorch::{LibTorch, LibTorchDevice},
|
||||
Autodiff,
|
||||
};
|
||||
|
||||
use crate::{launch, ElemType};
|
||||
|
||||
|
@ -63,8 +67,10 @@ mod tch_gpu {
|
|||
|
||||
#[cfg(feature = "tch-cpu")]
|
||||
mod tch_cpu {
|
||||
use burn::backend::tch::{LibTorch, LibTorchDevice};
|
||||
use burn::backend::Autodiff;
|
||||
use burn::backend::{
|
||||
tch::{LibTorch, LibTorchDevice},
|
||||
Autodiff,
|
||||
};
|
||||
|
||||
use crate::{launch, ElemType};
|
||||
|
||||
|
@ -75,8 +81,10 @@ mod tch_cpu {
|
|||
|
||||
#[cfg(feature = "wgpu")]
|
||||
mod wgpu {
|
||||
use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};
|
||||
use burn::backend::Autodiff;
|
||||
use burn::backend::{
|
||||
wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice},
|
||||
Autodiff,
|
||||
};
|
||||
|
||||
use crate::{launch, ElemType};
|
||||
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
use burn::nn::transformer::TransformerEncoderConfig;
|
||||
use burn::optim::{decay::WeightDecayConfig, AdamConfig};
|
||||
use burn::tensor::backend::AutodiffBackend;
|
||||
use burn::{
|
||||
nn::transformer::TransformerEncoderConfig,
|
||||
optim::{decay::WeightDecayConfig, AdamConfig},
|
||||
tensor::backend::AutodiffBackend,
|
||||
};
|
||||
|
||||
use text_classification::training::ExperimentConfig;
|
||||
use text_classification::DbPediaDataset;
|
||||
use text_classification::{training::ExperimentConfig, DbPediaDataset};
|
||||
|
||||
#[cfg(not(feature = "f16"))]
|
||||
#[allow(dead_code)]
|
||||
|
@ -34,8 +35,10 @@ pub fn launch<B: AutodiffBackend>(devices: Vec<B::Device>) {
|
|||
))]
|
||||
mod ndarray {
|
||||
use crate::{launch, ElemType};
|
||||
use burn::backend::ndarray::{NdArray, NdArrayDevice};
|
||||
use burn::backend::Autodiff;
|
||||
use burn::backend::{
|
||||
ndarray::{NdArray, NdArrayDevice},
|
||||
Autodiff,
|
||||
};
|
||||
|
||||
pub fn run() {
|
||||
launch::<Autodiff<NdArray<ElemType>>>(vec![NdArrayDevice::Cpu]);
|
||||
|
@ -44,8 +47,10 @@ mod ndarray {
|
|||
|
||||
#[cfg(feature = "tch-gpu")]
|
||||
mod tch_gpu {
|
||||
use burn::backend::libtorch::{LibTorch, LibTorchDevice};
|
||||
use burn::backend::Autodiff;
|
||||
use burn::backend::{
|
||||
libtorch::{LibTorch, LibTorchDevice},
|
||||
Autodiff,
|
||||
};
|
||||
|
||||
use crate::{launch, ElemType};
|
||||
|
||||
|
@ -61,8 +66,10 @@ mod tch_gpu {
|
|||
|
||||
#[cfg(feature = "tch-cpu")]
|
||||
mod tch_cpu {
|
||||
use burn::backend::libtorch::{LibTorch, LibTorchDevice};
|
||||
use burn::backend::Autodiff;
|
||||
use burn::backend::{
|
||||
libtorch::{LibTorch, LibTorchDevice},
|
||||
Autodiff,
|
||||
};
|
||||
|
||||
use crate::{launch, ElemType};
|
||||
|
||||
|
@ -73,8 +80,10 @@ mod tch_cpu {
|
|||
|
||||
#[cfg(feature = "wgpu")]
|
||||
mod wgpu {
|
||||
use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};
|
||||
use burn::backend::Autodiff;
|
||||
use burn::backend::{
|
||||
wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice},
|
||||
Autodiff,
|
||||
};
|
||||
|
||||
use crate::{launch, ElemType};
|
||||
|
||||
|
|
|
@ -11,11 +11,7 @@
|
|||
// generates a padding mask, and returns a batch object.
|
||||
|
||||
use super::{dataset::TextClassificationItem, tokenizer::Tokenizer};
|
||||
use burn::{
|
||||
data::dataloader::batcher::Batcher,
|
||||
nn::attention::generate_padding_mask,
|
||||
tensor::{backend::Backend, Bool, Data, ElementConversion, Int, Tensor},
|
||||
};
|
||||
use burn::{data::dataloader::batcher::Batcher, nn::attention::generate_padding_mask, prelude::*};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Struct for batching text classification items
|
||||
|
|
|
@ -10,10 +10,9 @@ use crate::{
|
|||
training::ExperimentConfig,
|
||||
};
|
||||
use burn::{
|
||||
config::Config,
|
||||
data::dataloader::batcher::Batcher,
|
||||
prelude::*,
|
||||
record::{CompactRecorder, Recorder},
|
||||
tensor::backend::Backend,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
|
||||
|
|
|
@ -5,15 +5,13 @@
|
|||
|
||||
use crate::data::{TextClassificationInferenceBatch, TextClassificationTrainingBatch};
|
||||
use burn::{
|
||||
config::Config,
|
||||
module::Module,
|
||||
nn::{
|
||||
loss::CrossEntropyLossConfig,
|
||||
transformer::{TransformerEncoder, TransformerEncoderConfig, TransformerEncoderInput},
|
||||
Embedding, EmbeddingConfig, Linear, LinearConfig,
|
||||
},
|
||||
tensor::backend::{AutodiffBackend, Backend},
|
||||
tensor::{activation::softmax, Tensor},
|
||||
prelude::*,
|
||||
tensor::{activation::softmax, backend::AutodiffBackend},
|
||||
train::{ClassificationOutput, TrainOutput, TrainStep, ValidStep},
|
||||
};
|
||||
|
||||
|
|
|
@ -10,12 +10,11 @@ use crate::{
|
|||
model::TextClassificationModelConfig,
|
||||
};
|
||||
use burn::{
|
||||
config::Config,
|
||||
data::{dataloader::DataLoaderBuilder, dataset::transform::SamplerDataset},
|
||||
lr_scheduler::noam::NoamLrSchedulerConfig,
|
||||
module::Module,
|
||||
nn::transformer::TransformerEncoderConfig,
|
||||
optim::AdamConfig,
|
||||
prelude::*,
|
||||
record::{CompactRecorder, Recorder},
|
||||
tensor::backend::AutodiffBackend,
|
||||
train::{
|
||||
|
|
|
@ -1,9 +1,5 @@
|
|||
use super::{dataset::TextGenerationItem, tokenizer::Tokenizer};
|
||||
use burn::{
|
||||
data::dataloader::batcher::Batcher,
|
||||
nn::attention::generate_padding_mask,
|
||||
tensor::{backend::Backend, Bool, Int, Tensor},
|
||||
};
|
||||
use burn::{data::dataloader::batcher::Batcher, nn::attention::generate_padding_mask, prelude::*};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(new)]
|
||||
|
|
|
@ -1,15 +1,13 @@
|
|||
use crate::data::TrainingTextGenerationBatch;
|
||||
use burn::{
|
||||
config::Config,
|
||||
module::Module,
|
||||
nn::{
|
||||
attention::generate_autoregressive_mask,
|
||||
loss::CrossEntropyLossConfig,
|
||||
transformer::{TransformerEncoder, TransformerEncoderConfig, TransformerEncoderInput},
|
||||
Embedding, EmbeddingConfig, Linear, LinearConfig,
|
||||
},
|
||||
tensor::backend::{AutodiffBackend, Backend},
|
||||
tensor::Tensor,
|
||||
prelude::*,
|
||||
tensor::backend::AutodiffBackend,
|
||||
train::{ClassificationOutput, TrainOutput, TrainStep, ValidStep},
|
||||
};
|
||||
|
||||
|
|
|
@ -2,14 +2,15 @@ use crate::{
|
|||
data::{Gpt2Tokenizer, TextGenerationBatcher, TextGenerationItem, Tokenizer},
|
||||
model::TextGenerationModelConfig,
|
||||
};
|
||||
use burn::data::dataset::transform::SamplerDataset;
|
||||
use burn::{
|
||||
config::Config,
|
||||
data::{dataloader::DataLoaderBuilder, dataset::Dataset},
|
||||
data::{
|
||||
dataloader::DataLoaderBuilder,
|
||||
dataset::{transform::SamplerDataset, Dataset},
|
||||
},
|
||||
lr_scheduler::noam::NoamLrSchedulerConfig,
|
||||
module::Module,
|
||||
nn::transformer::TransformerEncoderConfig,
|
||||
optim::AdamConfig,
|
||||
prelude::*,
|
||||
record::{CompactRecorder, DefaultRecorder, Recorder},
|
||||
tensor::backend::AutodiffBackend,
|
||||
train::{
|
||||
|
|
Loading…
Reference in New Issue