docs(book-&-examples): modify book and examples with new `prelude` module (#1372)

Yu Sun 2024-02-29 02:25:25 +08:00 committed by GitHub
parent 57887e7a47
commit 330552afb4
52 changed files with 299 additions and 258 deletions

View File

@ -17,7 +17,7 @@ at `examples/guide/` [directory](https://github.com/tracel-ai/burn/tree/main/exa
```rust , ignore
use burn::{
data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
tensor::{backend::Backend, Data, ElementConversion, Int, Tensor},
prelude::*,
};
pub struct MnistBatcher<B: Backend> {

View File

@ -30,14 +30,12 @@ Let us start by defining our model struct in a new file `src/model.rs`.
```rust , ignore
use burn::{
config::Config,
module::Module,
nn::{
conv::{Conv2d, Conv2dConfig},
pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig},
Dropout, DropoutConfig, Linear, LinearConfig, Relu,
},
tensor::{backend::Backend, Tensor},
prelude::*,
};
#[derive(Module, Debug)]

View File

@ -11,7 +11,6 @@ to your types, allowing you to define default values with ease. Additionally, al
serialized, reducing potential bugs when upgrading versions and improving reproducibility.
```rust , ignore
use burn::config::Config;
#[derive(Config)]

View File

@ -5,7 +5,6 @@ derive function only generates the necessary methods to essentially act as a par
your type, it makes no assumptions about how the forward pass is declared.
```rust, ignore
use burn::nn;
use burn::module::Module;
use burn::tensor::backend::Backend;

View File

@ -168,6 +168,37 @@ While the previous example is somewhat trivial, the upcoming
basic workflow section will walk you through a much more relevant example for
deep learning applications.
## Using `prelude`
Burn's core library exposes a large number of commonly used structs, traits, and macros.
When creating a new model or running an existing one for inference,
you would otherwise have to import every single component you use, which quickly becomes verbose.
To address this, Burn provides a `prelude` module that lets you import the most commonly used structs and macros as a group:
```rust, ignore
use burn::prelude::*;
```
which is equivalent to:
```rust, ignore
use burn::{
config::Config,
module::Module,
nn,
tensor::{
backend::Backend, Bool, Data, Device, ElementConversion, Float, Int, Shape, Tensor,
},
};
```
<div class="warning">
For the sake of simplicity, the subsequent chapters of this book all use this form of importing. The [Building Blocks](./building-blocks) chapter is the exception: it keeps explicit imports, since seeing exactly where each structure and macro comes from helps readers understand how to use them.
</div>
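With the prelude in place, a module definition needs only a single import. The following is a minimal sketch, not part of this commit; the `Mlp` type and its layers are purely illustrative:
```rust, ignore
use burn::prelude::*;

// `Module`, `Backend`, `Tensor`, and the `nn` re-export all come from the prelude.
#[derive(Module, Debug)]
pub struct Mlp<B: Backend> {
    linear: nn::Linear<B>,
    activation: nn::Relu,
}

impl<B: Backend> Mlp<B> {
    /// Applies the linear layer followed by a ReLU activation.
    pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        self.activation.forward(self.linear.forward(input))
    }
}
```
Construction still goes through the usual configs (for example `nn::LinearConfig::new(784, 10).init(&device)`), which are likewise reachable through the `nn` re-export.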
## Running examples
Many additional Burn examples available in the

View File

@ -48,9 +48,8 @@ something like this:
```rust
use burn::{
module::Module,
nn::conv::{Conv2d, Conv2dConfig},
tensor::{backend::Backend, Tensor},
prelude::*,
};
#[derive(Module, Debug)]

View File

@ -10,8 +10,8 @@ serialize/deserialize models. By default, we use the `NamedMpkFileRecorder` whic
// Save model in MessagePack format with full precision
let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
model
.save_file(model_path, &recorder)
.expect("Should be able to save the model");
```
Note that the file extension is automatically handled by the recorder depending on the one you
@ -23,8 +23,8 @@ Now that you have a trained model saved to your disk, you can easily load it in
// Load model in full precision from MessagePack file
let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
model
.load_file(model_path, &recorder, device)
.expect("Should be able to load the model weights from the provided file");
```
**Note:** models can be saved in different output formats, just make sure you are using the correct
@ -117,8 +117,8 @@ a model as part of your runtime application, first save the model to a binary fi
// Save model in binary format with full precision
let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
model
.save_file(model_path, &recorder)
.expect("Should be able to save the model");
```
Then, in your final application, include the model and use the `BinBytesRecorder` to load it.
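The matching loading code sits below this hunk and is not shown here. As a rough sketch of that step (the `Model`/`ModelConfig` types and the `model.bin` path are hypothetical stand-ins for your own):
```rust, ignore
use burn::{
    module::Module,
    record::{BinBytesRecorder, FullPrecisionSettings, Recorder},
    tensor::backend::Backend,
};

// Embed the record produced by `BinFileRecorder` directly into the binary.
static MODEL_BYTES: &[u8] = include_bytes!("../model.bin");

pub fn load_model<B: Backend>(device: &B::Device) -> Model<B> {
    let record = BinBytesRecorder::<FullPrecisionSettings>::default()
        .load(MODEL_BYTES.to_vec(), device)
        .expect("Should be able to load the model record from bytes");

    // `Model` / `ModelConfig` stand in for your own model types.
    ModelConfig::new().init(device).load_record(record)
}
```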

View File

@ -71,6 +71,8 @@ pub mod prelude {
config::Config,
module::Module,
nn,
tensor::{backend::Backend, Data, Device, ElementConversion, Tensor},
tensor::{
backend::Backend, Bool, Data, Device, ElementConversion, Float, Int, Shape, Tensor,
},
};
}

View File

@ -1,9 +1,12 @@
#[cfg(feature = "tch-gpu")]
mod tch_gpu {
use burn::backend::libtorch::{LibTorch, LibTorchDevice};
use burn::backend::Autodiff;
use burn::optim::momentum::MomentumConfig;
use burn::optim::SgdConfig;
use burn::{
backend::{
libtorch::{LibTorch, LibTorchDevice},
Autodiff,
},
optim::{momentum::MomentumConfig, SgdConfig},
};
use custom_image_dataset::training::{train, TrainingConfig};
pub fn run() {
@ -25,10 +28,13 @@ mod tch_gpu {
#[cfg(feature = "wgpu")]
mod wgpu {
use burn::backend::wgpu::{Wgpu, WgpuDevice};
use burn::backend::Autodiff;
use burn::optim::momentum::MomentumConfig;
use burn::optim::SgdConfig;
use burn::{
backend::{
wgpu::{Wgpu, WgpuDevice},
Autodiff,
},
optim::{momentum::MomentumConfig, SgdConfig},
};
use custom_image_dataset::training::{train, TrainingConfig};
pub fn run() {

View File

@ -3,7 +3,7 @@ use burn::{
dataloader::batcher::Batcher,
dataset::vision::{Annotation, ImageDatasetItem, PixelDepth},
},
tensor::{backend::Backend, Data, Device, ElementConversion, Int, Shape, Tensor},
prelude::*,
};
// CIFAR-10 mean and std values

View File

@ -1,11 +1,10 @@
use burn::{
module::Module,
nn::{
conv::{Conv2d, Conv2dConfig},
pool::{MaxPool2d, MaxPool2dConfig},
Dropout, DropoutConfig, Linear, LinearConfig, PaddingConfig2d, Relu,
},
tensor::{backend::Backend, Device, Tensor},
prelude::*,
};
/// Basic convolutional neural network with VGG-style blocks.
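As a rough illustration only (the `ConvBlock` type below is assumed, not taken from this example), one such VGG-style block built from the prelude imports could look like:
```rust
use burn::prelude::*;

#[derive(Module, Debug)]
pub struct ConvBlock<B: Backend> {
    conv: nn::conv::Conv2d<B>,
    activation: nn::Relu,
    pool: nn::pool::MaxPool2d,
}

impl<B: Backend> ConvBlock<B> {
    /// Convolution, ReLU, then max pooling over an NCHW image tensor.
    pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
        let x = self.conv.forward(input);
        let x = self.activation.forward(x);
        self.pool.forward(x)
    }
}
```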

View File

@ -5,21 +5,16 @@ use crate::{
dataset::CIFAR10Loader,
model::Cnn,
};
use burn::data::{dataloader::DataLoaderBuilder, dataset::vision::ImageFolderDataset};
use burn::train::{
metric::{AccuracyMetric, LossMetric},
ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,
};
use burn::{
self,
config::Config,
module::Module,
data::{dataloader::DataLoaderBuilder, dataset::vision::ImageFolderDataset},
nn::loss::CrossEntropyLossConfig,
optim::SgdConfig,
prelude::*,
record::CompactRecorder,
tensor::{
backend::{AutodiffBackend, Backend},
Int, Tensor,
tensor::backend::AutodiffBackend,
train::{
metric::{AccuracyMetric, LossMetric},
ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,
},
};

View File

@ -1,5 +1,4 @@
use burn::backend::wgpu::WgpuDevice;
use burn::backend::{Autodiff, Wgpu};
use burn::backend::{wgpu::WgpuDevice, Autodiff, Wgpu};
fn main() {
custom_renderer::run::<Autodiff<Wgpu>>(WgpuDevice::default());

View File

@ -1,9 +1,12 @@
use burn::data::dataset::vision::MnistDataset;
use burn::train::renderer::{MetricState, MetricsRenderer, TrainingProgress};
use burn::train::LearnerBuilder;
use burn::{
config::Config, data::dataloader::DataLoaderBuilder, optim::AdamConfig,
config::Config,
data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},
optim::AdamConfig,
tensor::backend::AutodiffBackend,
train::{
renderer::{MetricState, MetricsRenderer, TrainingProgress},
LearnerBuilder,
},
};
use guide::{data::MnistBatcher, model::ModelConfig};

View File

@ -1,5 +1,4 @@
use burn::backend::wgpu::WgpuDevice;
use burn::backend::{Autodiff, Wgpu};
use burn::backend::{wgpu::WgpuDevice, Autodiff, Wgpu};
fn main() {
custom_training_loop::run::<Autodiff<Wgpu>>(WgpuDevice::default());

View File

@ -1,16 +1,12 @@
use std::marker::PhantomData;
use burn::data::dataset::vision::MnistDataset;
use burn::{
config::Config,
data::dataloader::DataLoaderBuilder,
data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},
module::AutodiffModule,
nn::loss::CrossEntropyLoss,
optim::{AdamConfig, GradientsParams, Optimizer},
tensor::{
backend::{AutodiffBackend, Backend},
ElementConversion, Int, Tensor,
},
prelude::*,
tensor::backend::AutodiffBackend,
};
use guide::{
data::{MnistBatch, MnistBatcher},

View File

@ -1,15 +1,18 @@
use crate::FloatTensor;
use super::{AutodiffBackend, Backend};
use burn::backend::autodiff::{
checkpoint::{base::Checkpointer, strategy::CheckpointStrategy},
grads::Gradients,
ops::{broadcast_shape, Backward, Ops, OpsKind},
Autodiff, NodeID,
use burn::{
backend::{
autodiff::{
checkpoint::{base::Checkpointer, strategy::CheckpointStrategy},
grads::Gradients,
ops::{broadcast_shape, Backward, Ops, OpsKind},
Autodiff, NodeID,
},
wgpu::{compute::WgpuRuntime, FloatElement, GraphicsApi, IntElement, JitBackend},
},
tensor::Shape,
};
use burn::backend::wgpu::compute::WgpuRuntime;
use burn::backend::wgpu::{FloatElement, GraphicsApi, IntElement, JitBackend};
use burn::tensor::Shape;
impl<G: GraphicsApi, F: FloatElement, I: IntElement> AutodiffBackend
for Autodiff<JitBackend<WgpuRuntime<G, F, I>>>

View File

@ -1,16 +1,18 @@
use crate::FloatTensor;
use super::Backend;
use burn::backend::wgpu::{
compute::{DynamicKernel, WgpuRuntime, WorkGroup},
kernel::{
build_info, into_contiguous, DynamicKernelSource, SourceTemplate, StaticKernelSource,
use burn::{
backend::wgpu::{
compute::{DynamicKernel, WgpuRuntime, WorkGroup},
kernel::{
build_info, into_contiguous, DynamicKernelSource, SourceTemplate, StaticKernelSource,
},
kernel_wgsl,
tensor::JitTensor,
FloatElement, GraphicsApi, IntElement, JitBackend,
},
kernel_wgsl,
tensor::JitTensor,
FloatElement, GraphicsApi, IntElement, JitBackend,
tensor::Shape,
};
use burn::tensor::Shape;
use derive_new::new;
use std::marker::PhantomData;

View File

@ -1,7 +1,8 @@
use burn::backend::wgpu::AutoGraphicsApi;
use burn::backend::{Autodiff, Wgpu};
use burn::data::dataset::Dataset;
use burn::optim::AdamConfig;
use burn::{
backend::{wgpu::AutoGraphicsApi, Autodiff, Wgpu},
data::dataset::Dataset,
optim::AdamConfig,
};
use guide::{model::ModelConfig, training::TrainingConfig};
fn main() {

View File

@ -1,6 +1,6 @@
use burn::{
data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
tensor::{backend::Backend, Data, ElementConversion, Int, Tensor},
prelude::*,
};
pub struct MnistBatcher<B: Backend> {

View File

@ -1,10 +1,8 @@
use crate::{data::MnistBatcher, training::TrainingConfig};
use burn::data::dataset::vision::MnistItem;
use burn::{
config::Config,
data::dataloader::batcher::Batcher,
data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
prelude::*,
record::{CompactRecorder, Recorder},
tensor::backend::Backend,
};
pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MnistItem) {

View File

@ -1,12 +1,10 @@
use burn::{
config::Config,
module::Module,
nn::{
conv::{Conv2d, Conv2dConfig},
pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig},
Dropout, DropoutConfig, Linear, LinearConfig, Relu,
},
tensor::{backend::Backend, Tensor},
prelude::*,
};
#[derive(Module, Debug)]

View File

@ -2,22 +2,16 @@ use crate::{
data::{MnistBatch, MnistBatcher},
model::{Model, ModelConfig},
};
use burn::data::dataset::vision::MnistDataset;
use burn::train::{
metric::{AccuracyMetric, LossMetric},
ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,
};
use burn::{
self,
config::Config,
data::dataloader::DataLoaderBuilder,
module::Module,
data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},
nn::loss::CrossEntropyLossConfig,
optim::AdamConfig,
prelude::*,
record::CompactRecorder,
tensor::{
backend::{AutodiffBackend, Backend},
Int, Tensor,
tensor::backend::AutodiffBackend,
train::{
metric::{AccuracyMetric, LossMetric},
ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,
},
};

View File

@ -1,11 +1,12 @@
/// This build script generates the model code from the ONNX file and the labels from the text file.
use std::env;
use std::fs::File;
use std::io::{BufRead, BufReader, Write};
use std::path::Path;
use std::{
env,
fs::File,
io::{BufRead, BufReader, Write},
path::Path,
};
use burn_import::burn::graph::RecordType;
use burn_import::onnx::ModelGen;
use burn_import::{burn::graph::RecordType, onnx::ModelGen};
const LABEL_SOURCE_FILE: &str = "src/model/label.txt";
const LABEL_DEST_FILE: &str = "model/label.rs";
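For context, a build script like this typically drives `burn-import` code generation along the lines sketched below; the ONNX path and record type are assumptions, not taken from this diff:
```rust
use burn_import::{burn::graph::RecordType, onnx::ModelGen};

fn main() {
    // Generate Rust model code (and a weights record) from the ONNX graph at build time.
    ModelGen::new()
        .input("src/model/mnist.onnx")
        .out_dir("model/")
        .record_type(RecordType::NamedMpk)
        .run_from_script();
}
```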

View File

@ -1,4 +1,4 @@
use burn::tensor::{backend::Backend, Tensor};
use burn::prelude::*;
// Values are taken from the [ONNX SqueezeNet]
// (https://github.com/onnx/models/tree/main/vision/classification/squeezenet#preprocessing)

View File

@ -8,10 +8,7 @@ use core::convert::Into;
use crate::model::{label::LABELS, normalizer::Normalizer, squeezenet::Model as SqueezenetModel};
use burn::{
backend::NdArray,
tensor::{activation::softmax, backend::Backend, Tensor},
};
use burn::{backend::NdArray, prelude::*, tensor::activation::softmax};
use burn_candle::Candle;
use burn_wgpu::{compute::init_async, AutoGraphicsApi, Wgpu, WgpuDevice};

View File

@ -3,9 +3,8 @@
// Originally copied from the burn/examples/mnist package
use burn::{
module::Module,
nn::{self, BatchNorm, PaddingConfig2d},
tensor::{backend::Backend, Tensor},
nn::{BatchNorm, PaddingConfig2d},
prelude::*,
};
#[derive(Module, Debug)]

View File

@ -1,8 +1,8 @@
use crate::model::Model;
use burn::module::Module;
use burn::record::BinBytesRecorder;
use burn::record::FullPrecisionSettings;
use burn::record::Recorder;
use burn::{
module::Module,
record::{BinBytesRecorder, FullPrecisionSettings, Recorder},
};
#[cfg(feature = "wgpu")]
use burn::backend::wgpu::{compute::init_async, AutoGraphicsApi, Wgpu, WgpuDevice};

View File

@ -5,8 +5,10 @@
feature = "ndarray-blas-accelerate",
))]
mod ndarray {
use burn::backend::ndarray::{NdArray, NdArrayDevice};
use burn::backend::Autodiff;
use burn::backend::{
ndarray::{NdArray, NdArrayDevice},
Autodiff,
};
use mnist::training;
pub fn run() {
@ -17,8 +19,10 @@ mod ndarray {
#[cfg(feature = "tch-gpu")]
mod tch_gpu {
use burn::backend::libtorch::{LibTorch, LibTorchDevice};
use burn::backend::Autodiff;
use burn::backend::{
libtorch::{LibTorch, LibTorchDevice},
Autodiff,
};
use mnist::training;
pub fn run() {
@ -33,8 +37,10 @@ mod tch_gpu {
#[cfg(feature = "wgpu")]
mod wgpu {
use burn::backend::wgpu::{Wgpu, WgpuDevice};
use burn::backend::Autodiff;
use burn::backend::{
wgpu::{Wgpu, WgpuDevice},
Autodiff,
};
use mnist::training;
pub fn run() {
@ -45,8 +51,10 @@ mod wgpu {
#[cfg(feature = "tch-cpu")]
mod tch_cpu {
use burn::backend::libtorch::{LibTorch, LibTorchDevice};
use burn::backend::Autodiff;
use burn::backend::{
libtorch::{LibTorch, LibTorchDevice},
Autodiff,
};
use mnist::training;
pub fn run() {

View File

@ -1,6 +1,6 @@
use burn::{
data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
tensor::{backend::Backend, Data, ElementConversion, Int, Tensor},
prelude::*,
};
pub struct MnistBatcher<B: Backend> {

View File

@ -1,11 +1,8 @@
use crate::data::MnistBatch;
use burn::{
module::Module,
nn::{self, loss::CrossEntropyLossConfig, BatchNorm, PaddingConfig2d},
tensor::{
backend::{AutodiffBackend, Backend},
Tensor,
},
nn::{loss::CrossEntropyLossConfig, BatchNorm, PaddingConfig2d},
prelude::*,
tensor::backend::AutodiffBackend,
train::{ClassificationOutput, TrainOutput, TrainStep, ValidStep},
};

View File

@ -1,20 +1,17 @@
use crate::data::MnistBatcher;
use crate::model::Model;
use crate::{data::MnistBatcher, model::Model};
use burn::module::Module;
use burn::optim::decay::WeightDecayConfig;
use burn::optim::AdamConfig;
use burn::record::{CompactRecorder, NoStdTrainingRecorder};
use burn::train::metric::store::{Aggregate, Direction, Split};
use burn::train::metric::{CpuMemory, CpuTemperature, CpuUse};
use burn::train::{MetricEarlyStoppingStrategy, StoppingCondition};
use burn::{
config::Config,
data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},
optim::{decay::WeightDecayConfig, AdamConfig},
prelude::*,
record::{CompactRecorder, NoStdTrainingRecorder},
tensor::backend::AutodiffBackend,
train::{
metric::{AccuracyMetric, LossMetric},
LearnerBuilder,
metric::{
store::{Aggregate, Direction, Split},
AccuracyMetric, CpuMemory, CpuTemperature, CpuUse, LossMetric,
},
LearnerBuilder, MetricEarlyStoppingStrategy, StoppingCondition,
},
};

View File

@ -1,5 +1,4 @@
use burn::tensor::backend::Backend;
use burn::tensor::{Dim, Distribution, NamedDim, NamedTensor};
use burn::tensor::{backend::Backend, Dim, Distribution, NamedDim, NamedTensor};
NamedDim!(Batch);
NamedDim!(SeqLength);

View File

@ -37,8 +37,7 @@
"outputs": [],
"source": [
"// Import packages\n",
"use burn::tensor::Tensor;\n",
"use burn::tensor::backend::Backend;\n",
"use burn::prelude::*;\n",
"use burn_ndarray::NdArray;\n",
"\n",
"// Type alias for the backend\n",
@ -111,7 +110,11 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"vscode": {
"languageId": "rust"
}
},
"outputs": [],
"source": []
}

View File

@ -1,10 +1,10 @@
use std::env::args;
use burn::backend::ndarray::NdArray;
use burn::tensor::Tensor;
use burn::data::dataset::vision::MnistDataset;
use burn::data::dataset::Dataset;
use burn::{
backend::ndarray::NdArray,
data::dataset::{vision::MnistDataset, Dataset},
tensor::Tensor,
};
use onnx_inference::mnist::Model;

View File

@ -3,8 +3,10 @@
/// 2. Saves the model record to a file using the `NamedMpkFileRecorder`.
use std::path::Path;
use burn::backend::NdArray;
use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder};
use burn::{
backend::NdArray,
record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder},
};
use burn_import::pytorch::PyTorchFileRecorder;
// Basic backend type (not used directly here).

View File

@ -1,20 +1,14 @@
use std::env;
use std::path::Path;
use burn::nn::conv::Conv2d;
use burn::nn::conv::Conv2dConfig;
use burn::nn::BatchNorm;
use burn::nn::BatchNormConfig;
use burn::nn::Linear;
use burn::nn::LinearConfig;
use burn::record::FullPrecisionSettings;
use burn::record::NamedMpkFileRecorder;
use burn::record::Recorder;
use burn::tensor::activation::log_softmax;
use burn::tensor::activation::relu;
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
nn::{
conv::{Conv2d, Conv2dConfig},
BatchNorm, BatchNormConfig, Linear, LinearConfig,
},
prelude::*,
record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder},
tensor::activation::{log_softmax, relu},
};
#[derive(Module, Debug)]

View File

@ -1,12 +1,12 @@
use std::env::args;
use std::path::Path;
use burn::backend::ndarray::NdArray;
use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder};
use burn::tensor::Tensor;
use burn::data::dataset::vision::MnistDataset;
use burn::data::dataset::Dataset;
use burn::{
backend::ndarray::NdArray,
data::dataset::{vision::MnistDataset, Dataset},
record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder},
tensor::Tensor,
};
use model::Model;

View File

@ -5,8 +5,10 @@
feature = "ndarray-blas-accelerate",
))]
mod ndarray {
use burn::backend::ndarray::{NdArray, NdArrayDevice};
use burn::backend::Autodiff;
use burn::backend::{
ndarray::{NdArray, NdArrayDevice},
Autodiff,
};
use regression::training;
pub fn run() {
@ -17,8 +19,10 @@ mod ndarray {
#[cfg(feature = "tch-gpu")]
mod tch_gpu {
use burn::backend::libtorch::{LibTorch, LibTorchDevice};
use burn::backend::Autodiff;
use burn::backend::{
libtorch::{LibTorch, LibTorchDevice},
Autodiff,
};
use regression::training;
pub fn run() {
@ -33,8 +37,10 @@ mod tch_gpu {
#[cfg(feature = "wgpu")]
mod wgpu {
use burn::backend::wgpu::{Wgpu, WgpuDevice};
use burn::backend::Autodiff;
use burn::backend::{
wgpu::{Wgpu, WgpuDevice},
Autodiff,
};
use regression::training;
pub fn run() {
@ -45,8 +51,10 @@ mod wgpu {
#[cfg(feature = "tch-cpu")]
mod tch_cpu {
use burn::backend::libtorch::{LibTorch, LibTorchDevice};
use burn::backend::Autodiff;
use burn::backend::{
libtorch::{LibTorch, LibTorchDevice},
Autodiff,
};
use regression::training;
pub fn run() {
let device = LibTorchDevice::Cpu;

View File

@ -1,8 +1,13 @@
use burn::data::dataloader::batcher::Batcher;
use burn::data::dataset::transform::{PartialDataset, ShuffledDataset};
use burn::data::dataset::{Dataset, HuggingfaceDatasetLoader, SqliteDataset};
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;
use burn::{
data::{
dataloader::batcher::Batcher,
dataset::{
transform::{PartialDataset, ShuffledDataset},
Dataset, HuggingfaceDatasetLoader, SqliteDataset,
},
},
prelude::*,
};
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct DiabetesItem {

View File

@ -1,14 +1,11 @@
use crate::dataset::DiabetesBatch;
use burn::config::Config;
use burn::nn::loss::Reduction::Mean;
use burn::nn::Relu;
use burn::{
module::Module,
nn::{loss::MseLoss, Linear, LinearConfig},
tensor::{
backend::{AutodiffBackend, Backend},
Tensor,
nn::{
loss::{MseLoss, Reduction::Mean},
Linear, LinearConfig, Relu,
},
prelude::*,
tensor::backend::AutodiffBackend,
train::{RegressionOutput, TrainOutput, TrainStep, ValidStep},
};

View File

@ -1,16 +1,16 @@
use crate::dataset::{DiabetesBatcher, DiabetesDataset};
use crate::model::RegressionModelConfig;
use burn::data::dataset::Dataset;
use burn::module::Module;
use burn::optim::SgdConfig;
use burn::record::{CompactRecorder, NoStdTrainingRecorder};
use burn::train::metric::store::{Aggregate, Direction, Split};
use burn::train::{MetricEarlyStoppingStrategy, StoppingCondition};
use burn::{
config::Config,
data::dataloader::DataLoaderBuilder,
data::{dataloader::DataLoaderBuilder, dataset::Dataset},
optim::SgdConfig,
prelude::*,
record::{CompactRecorder, NoStdTrainingRecorder},
tensor::backend::AutodiffBackend,
train::{metric::LossMetric, LearnerBuilder},
train::{
metric::store::{Aggregate, Direction, Split},
metric::LossMetric,
LearnerBuilder, MetricEarlyStoppingStrategy, StoppingCondition,
},
};
static ARTIFACT_DIR: &str = "/tmp/burn-example-regression";

View File

@ -1,9 +1,10 @@
use burn::nn::transformer::TransformerEncoderConfig;
use burn::optim::{decay::WeightDecayConfig, AdamConfig};
use burn::tensor::backend::AutodiffBackend;
use burn::{
nn::transformer::TransformerEncoderConfig,
optim::{decay::WeightDecayConfig, AdamConfig},
tensor::backend::AutodiffBackend,
};
use text_classification::training::ExperimentConfig;
use text_classification::AgNewsDataset;
use text_classification::{training::ExperimentConfig, AgNewsDataset};
#[cfg(not(feature = "f16"))]
#[allow(dead_code)]
@ -35,8 +36,10 @@ pub fn launch<B: AutodiffBackend>(devices: Vec<B::Device>) {
feature = "ndarray-blas-accelerate",
))]
mod ndarray {
use burn::backend::ndarray::{NdArray, NdArrayDevice};
use burn::backend::Autodiff;
use burn::backend::{
ndarray::{NdArray, NdArrayDevice},
Autodiff,
};
use crate::{launch, ElemType};
@ -47,8 +50,10 @@ mod ndarray {
#[cfg(feature = "tch-gpu")]
mod tch_gpu {
use burn::backend::libtorch::{LibTorch, LibTorchDevice};
use burn::backend::Autodiff;
use burn::backend::{
libtorch::{LibTorch, LibTorchDevice},
Autodiff,
};
use crate::{launch, ElemType};
@ -64,8 +69,10 @@ mod tch_gpu {
#[cfg(feature = "tch-cpu")]
mod tch_cpu {
use burn::backend::libtorch::{LibTorch, LibTorchDevice};
use burn::backend::Autodiff;
use burn::backend::{
libtorch::{LibTorch, LibTorchDevice},
Autodiff,
};
use crate::{launch, ElemType};
@ -77,8 +84,10 @@ mod tch_cpu {
#[cfg(feature = "wgpu")]
mod wgpu {
use crate::{launch, ElemType};
use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};
use burn::backend::Autodiff;
use burn::backend::{
wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice},
Autodiff,
};
pub fn run() {
launch::<Autodiff<Wgpu<AutoGraphicsApi, ElemType, i32>>>(vec![WgpuDevice::default()]);

View File

@ -34,8 +34,10 @@ pub fn launch<B: AutodiffBackend>(device: B::Device) {
feature = "ndarray-blas-accelerate",
))]
mod ndarray {
use burn::backend::ndarray::{NdArray, NdArrayDevice};
use burn::backend::Autodiff;
use burn::backend::{
ndarray::{NdArray, NdArrayDevice},
Autodiff,
};
use crate::{launch, ElemType};
@ -46,8 +48,10 @@ mod ndarray {
#[cfg(feature = "tch-gpu")]
mod tch_gpu {
use burn::backend::libtorch::{LibTorch, LibTorchDevice};
use burn::backend::Autodiff;
use burn::backend::{
libtorch::{LibTorch, LibTorchDevice},
Autodiff,
};
use crate::{launch, ElemType};
@ -63,8 +67,10 @@ mod tch_gpu {
#[cfg(feature = "tch-cpu")]
mod tch_cpu {
use burn::backend::tch::{LibTorch, LibTorchDevice};
use burn::backend::Autodiff;
use burn::backend::{
tch::{LibTorch, LibTorchDevice},
Autodiff,
};
use crate::{launch, ElemType};
@ -75,8 +81,10 @@ mod tch_cpu {
#[cfg(feature = "wgpu")]
mod wgpu {
use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};
use burn::backend::Autodiff;
use burn::backend::{
wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice},
Autodiff,
};
use crate::{launch, ElemType};

View File

@ -1,9 +1,10 @@
use burn::nn::transformer::TransformerEncoderConfig;
use burn::optim::{decay::WeightDecayConfig, AdamConfig};
use burn::tensor::backend::AutodiffBackend;
use burn::{
nn::transformer::TransformerEncoderConfig,
optim::{decay::WeightDecayConfig, AdamConfig},
tensor::backend::AutodiffBackend,
};
use text_classification::training::ExperimentConfig;
use text_classification::DbPediaDataset;
use text_classification::{training::ExperimentConfig, DbPediaDataset};
#[cfg(not(feature = "f16"))]
#[allow(dead_code)]
@ -34,8 +35,10 @@ pub fn launch<B: AutodiffBackend>(devices: Vec<B::Device>) {
))]
mod ndarray {
use crate::{launch, ElemType};
use burn::backend::ndarray::{NdArray, NdArrayDevice};
use burn::backend::Autodiff;
use burn::backend::{
ndarray::{NdArray, NdArrayDevice},
Autodiff,
};
pub fn run() {
launch::<Autodiff<NdArray<ElemType>>>(vec![NdArrayDevice::Cpu]);
@ -44,8 +47,10 @@ mod ndarray {
#[cfg(feature = "tch-gpu")]
mod tch_gpu {
use burn::backend::libtorch::{LibTorch, LibTorchDevice};
use burn::backend::Autodiff;
use burn::backend::{
libtorch::{LibTorch, LibTorchDevice},
Autodiff,
};
use crate::{launch, ElemType};
@ -61,8 +66,10 @@ mod tch_gpu {
#[cfg(feature = "tch-cpu")]
mod tch_cpu {
use burn::backend::libtorch::{LibTorch, LibTorchDevice};
use burn::backend::Autodiff;
use burn::backend::{
libtorch::{LibTorch, LibTorchDevice},
Autodiff,
};
use crate::{launch, ElemType};
@ -73,8 +80,10 @@ mod tch_cpu {
#[cfg(feature = "wgpu")]
mod wgpu {
use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};
use burn::backend::Autodiff;
use burn::backend::{
wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice},
Autodiff,
};
use crate::{launch, ElemType};

View File

@ -11,11 +11,7 @@
// generates a padding mask, and returns a batch object.
use super::{dataset::TextClassificationItem, tokenizer::Tokenizer};
use burn::{
data::dataloader::batcher::Batcher,
nn::attention::generate_padding_mask,
tensor::{backend::Backend, Bool, Data, ElementConversion, Int, Tensor},
};
use burn::{data::dataloader::batcher::Batcher, nn::attention::generate_padding_mask, prelude::*};
use std::sync::Arc;
/// Struct for batching text classification items

View File

@ -10,10 +10,9 @@ use crate::{
training::ExperimentConfig,
};
use burn::{
config::Config,
data::dataloader::batcher::Batcher,
prelude::*,
record::{CompactRecorder, Recorder},
tensor::backend::Backend,
};
use std::sync::Arc;

View File

@ -5,15 +5,13 @@
use crate::data::{TextClassificationInferenceBatch, TextClassificationTrainingBatch};
use burn::{
config::Config,
module::Module,
nn::{
loss::CrossEntropyLossConfig,
transformer::{TransformerEncoder, TransformerEncoderConfig, TransformerEncoderInput},
Embedding, EmbeddingConfig, Linear, LinearConfig,
},
tensor::backend::{AutodiffBackend, Backend},
tensor::{activation::softmax, Tensor},
prelude::*,
tensor::{activation::softmax, backend::AutodiffBackend},
train::{ClassificationOutput, TrainOutput, TrainStep, ValidStep},
};

View File

@ -10,12 +10,11 @@ use crate::{
model::TextClassificationModelConfig,
};
use burn::{
config::Config,
data::{dataloader::DataLoaderBuilder, dataset::transform::SamplerDataset},
lr_scheduler::noam::NoamLrSchedulerConfig,
module::Module,
nn::transformer::TransformerEncoderConfig,
optim::AdamConfig,
prelude::*,
record::{CompactRecorder, Recorder},
tensor::backend::AutodiffBackend,
train::{

View File

@ -1,9 +1,5 @@
use super::{dataset::TextGenerationItem, tokenizer::Tokenizer};
use burn::{
data::dataloader::batcher::Batcher,
nn::attention::generate_padding_mask,
tensor::{backend::Backend, Bool, Int, Tensor},
};
use burn::{data::dataloader::batcher::Batcher, nn::attention::generate_padding_mask, prelude::*};
use std::sync::Arc;
#[derive(new)]

View File

@ -1,15 +1,13 @@
use crate::data::TrainingTextGenerationBatch;
use burn::{
config::Config,
module::Module,
nn::{
attention::generate_autoregressive_mask,
loss::CrossEntropyLossConfig,
transformer::{TransformerEncoder, TransformerEncoderConfig, TransformerEncoderInput},
Embedding, EmbeddingConfig, Linear, LinearConfig,
},
tensor::backend::{AutodiffBackend, Backend},
tensor::Tensor,
prelude::*,
tensor::backend::AutodiffBackend,
train::{ClassificationOutput, TrainOutput, TrainStep, ValidStep},
};

View File

@ -2,14 +2,15 @@ use crate::{
data::{Gpt2Tokenizer, TextGenerationBatcher, TextGenerationItem, Tokenizer},
model::TextGenerationModelConfig,
};
use burn::data::dataset::transform::SamplerDataset;
use burn::{
config::Config,
data::{dataloader::DataLoaderBuilder, dataset::Dataset},
data::{
dataloader::DataLoaderBuilder,
dataset::{transform::SamplerDataset, Dataset},
},
lr_scheduler::noam::NoamLrSchedulerConfig,
module::Module,
nn::transformer::TransformerEncoderConfig,
optim::AdamConfig,
prelude::*,
record::{CompactRecorder, DefaultRecorder, Recorder},
tensor::backend::AutodiffBackend,
train::{