mirror of https://github.com/tracel-ai/burn.git

Example/text classification (#123)

parent 7c38a980c1
commit eee90a5c9e
README.md
@@ -42,9 +42,8 @@ This may also be a good idea to take a look the main [components](#components) o
 ### Examples
 
-For now there is only one example, but more to come 💪..
-
 * [MNIST](https://github.com/burn-rs/burn/tree/main/examples/mnist) train a model on CPU/GPU using different backends.
+* [Text Classification](https://github.com/burn-rs/burn/tree/main/examples/text-classification) train a transformer encoder from scratch on GPU.
 
 ### Components
 
burn-derive/src/config.rs
@@ -83,6 +83,6 @@ fn parse_asm(ast: &syn::DeriveInput) -> ConfigType {
             ConfigType::Struct(struct_data.fields.clone().into_iter().collect())
         }
         syn::Data::Enum(enum_data) => ConfigType::Enum(enum_data.clone()),
-        syn::Data::Union(_) => panic!("Only struct cna be derived"),
+        syn::Data::Union(_) => panic!("Only struct and enum can be derived"),
     }
 }
examples/text-classification/Cargo.toml (new file)
@@ -0,0 +1,23 @@
[package]
name = "text-classification"
version = "0.1.0"
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
license = "MIT/Apache-2.0"
edition = "2021"
publish = false

[features]
default = []

[dependencies]
# Burn
burn = { path = "../../burn" }
burn-autodiff = { path = "../../burn-autodiff" }
burn-tch = { path = "../../burn-tch" }

# Tokenizer
tokenizers = { version = "0.13", default-features = false, features = ["onig", "http"] }

# Utils
derive-new = "0.5"
serde = { version = "1.0", features = ["derive"] }
examples/text-classification/README.md (new file)
@@ -0,0 +1,12 @@
# Text Classification

The example can be run like so:

```bash
git clone https://github.com/burn-rs/burn.git
cd burn
# Use the --release flag to really speed up training.
export TORCH_CUDA_VERSION=cu113 # Set the CUDA version
cargo run --example text-classification-ag-news --release # Train on the AG News dataset
cargo run --example text-classification-db-pedia --release # Train on the DBPedia dataset
```
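If you already have a checkout, the same binary can be selected by package instead; a minor variation, assuming the repository's standard cargo workspace layout:

```bash
# Equivalent invocation, selecting the example crate explicitly by package name.
cargo run -p text-classification --example text-classification-ag-news --release
```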
examples/text-classification/examples/text-classification-ag-news.rs (new file)
@@ -0,0 +1,23 @@
use burn::optim::{decay::WeightDecayConfig, momentum::MomentumConfig};
use text_classification::{training::ExperimentConfig, AgNewsDataset};

type Backend = burn_autodiff::ADBackendDecorator<burn_tch::TchBackend<burn::tensor::f16>>;

fn main() {
    let config = ExperimentConfig::new(
        burn::nn::transformer::TransformerEncoderConfig::new(256, 512, 4, 4),
        burn::optim::SgdConfig::new()
            .with_learning_rate(5.0e-3)
            .with_momentum(Some(MomentumConfig::new().with_nesterov(true)))
            .with_weight_decay(Some(WeightDecayConfig::new(5e-4))),
    );

    text_classification::training::train::<Backend, AgNewsDataset>(
        burn_tch::TchDevice::Cuda(0),
        AgNewsDataset::train(),
        AgNewsDataset::test(),
        config,
        "/tmp/text-classification-ag-news",
    );
}
examples/text-classification/examples/text-classification-db-pedia.rs (new file)
@@ -0,0 +1,22 @@
use burn::optim::{decay::WeightDecayConfig, momentum::MomentumConfig};
use text_classification::{training::ExperimentConfig, DbPediaDataset};

type Backend = burn_autodiff::ADBackendDecorator<burn_tch::TchBackend<burn::tensor::f16>>;

fn main() {
    let config = ExperimentConfig::new(
        burn::nn::transformer::TransformerEncoderConfig::new(256, 512, 4, 4),
        burn::optim::SgdConfig::new()
            .with_learning_rate(5.0e-3)
            .with_momentum(Some(MomentumConfig::new().with_nesterov(true)))
            .with_weight_decay(Some(WeightDecayConfig::new(5e-4))),
    );

    text_classification::training::train::<Backend, DbPediaDataset>(
        burn_tch::TchDevice::Cuda(0),
        DbPediaDataset::train(),
        DbPediaDataset::test(),
        config,
        "/tmp/text-classification-db-pedia",
    );
}
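Both entry points pin the device to `TchDevice::Cuda(0)` and `f16` precision. Below is a minimal sketch of a CPU variant, assuming this `burn-tch` version exposes a `TchDevice::Cpu` variant and accepts `f32` as the element type (both are assumptions; check against your checkout):

```rust
use burn::optim::{decay::WeightDecayConfig, momentum::MomentumConfig};
use text_classification::{training::ExperimentConfig, AgNewsDataset};

// Assumed: f32 is a valid element type for TchBackend on CPU.
type Backend = burn_autodiff::ADBackendDecorator<burn_tch::TchBackend<f32>>;

fn main() {
    let config = ExperimentConfig::new(
        burn::nn::transformer::TransformerEncoderConfig::new(256, 512, 4, 4),
        burn::optim::SgdConfig::new()
            .with_learning_rate(5.0e-3)
            .with_momentum(Some(MomentumConfig::new().with_nesterov(true)))
            .with_weight_decay(Some(WeightDecayConfig::new(5e-4))),
    );

    text_classification::training::train::<Backend, AgNewsDataset>(
        burn_tch::TchDevice::Cpu, // assumed variant; the shipped examples use Cuda(0)
        AgNewsDataset::train(),
        AgNewsDataset::test(),
        config,
        "/tmp/text-classification-ag-news-cpu",
    );
}
```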
examples/text-classification/src/data/batcher.rs (new file)
@@ -0,0 +1,86 @@
use super::{dataset::TextClassificationItem, tokenizer::Tokenizer};
use burn::{
    data::dataloader::batcher::Batcher,
    tensor::{backend::Backend, BoolTensor, Data, Shape, Tensor},
};
use std::sync::Arc;

#[derive(new)]
pub struct TextClassificationBatcher<B: Backend> {
    tokenizer: Arc<dyn Tokenizer>,
    num_classes: usize,
    device: B::Device,
    max_seq_length: usize,
}

#[derive(Debug, Clone, new)]
pub struct TextClassificationBatch<B: Backend> {
    pub tokens: Tensor<B::IntegerBackend, 2>,
    pub labels: Tensor<B, 2>,
    pub mask_pad: BoolTensor<B, 2>,
}

impl<B: Backend> Batcher<TextClassificationItem, TextClassificationBatch<B>>
    for TextClassificationBatcher<B>
{
    fn batch(&self, items: Vec<TextClassificationItem>) -> TextClassificationBatch<B> {
        let mut tokens_list = Vec::with_capacity(items.len());
        let mut labels_list = Vec::with_capacity(items.len());

        for item in items {
            tokens_list.push(self.tokenizer.encode(&item.text));
            labels_list.push(Tensor::one_hot(item.label, self.num_classes));
        }

        let (tokens, mask_pad) =
            pad_tokens::<B>(self.tokenizer.pad_token(), tokens_list, self.max_seq_length);

        TextClassificationBatch {
            tokens: tokens.to_device(self.device).detach(),
            labels: Tensor::cat(labels_list, 0).to_device(self.device).detach(),
            mask_pad: mask_pad.to_device(self.device),
        }
    }
}

pub fn pad_tokens<B: Backend>(
    pad_token: usize,
    tokens_list: Vec<Vec<usize>>,
    max_seq_length: usize,
) -> (Tensor<B::IntegerBackend, 2>, BoolTensor<B, 2>) {
    let mut max_size = 0;
    let batch_size = tokens_list.len();

    // The padded width is the longest sequence in the batch, capped at max_seq_length.
    for tokens in tokens_list.iter() {
        if tokens.len() > max_size {
            max_size = tokens.len();
        }
        if tokens.len() >= max_seq_length {
            max_size = max_seq_length;
            break;
        }
    }

    // Start from a tensor filled with the padding token.
    let mut tensor = Tensor::zeros([batch_size, max_size]);
    tensor = tensor.add_scalar(pad_token as i64);

    for (index, tokens) in tokens_list.into_iter().enumerate() {
        let mut seq_length = tokens.len();
        let mut tokens = tokens;
        // Truncate sequences longer than max_seq_length.
        if seq_length > max_seq_length {
            seq_length = max_seq_length;
            let _ = tokens.split_off(seq_length);
        }
        tensor = tensor.index_assign(
            [index..index + 1, 0..tokens.len()],
            &Tensor::from_data(Data::new(
                tokens.into_iter().map(|e| e as i64).collect(),
                Shape::new([1, seq_length]),
            )),
        );
    }

    // Positions still equal to the padding token are masked out for attention.
    let mask_pad = BoolTensor::from_int_backend(tensor.equal_scalar(pad_token as i64));

    (tensor, mask_pad)
}
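The shape rules in `pad_tokens` are easier to see without tensors. Here is a self-contained, plain-Rust mirror of the same padding and truncation logic (illustration only; `pad_plain` is a hypothetical helper, not part of the example). It shares the original's caveat: any real token equal to `pad_token` is also masked.

```rust
// Plain-Rust mirror of pad_tokens: pad with `pad_token` up to the longest
// sequence in the batch, capped at `max_seq_length`; the mask marks padding.
fn pad_plain(
    pad_token: usize,
    tokens_list: Vec<Vec<usize>>,
    max_seq_length: usize,
) -> (Vec<Vec<usize>>, Vec<Vec<bool>>) {
    // Same width rule as the tensor version: longest sequence, capped.
    let width = tokens_list
        .iter()
        .map(|t| t.len())
        .max()
        .unwrap_or(0)
        .min(max_seq_length);

    let padded: Vec<Vec<usize>> = tokens_list
        .iter()
        .map(|t| {
            // Truncate to the width, then pad the remainder with pad_token.
            let mut row: Vec<usize> = t.iter().copied().take(width).collect();
            row.resize(width, pad_token);
            row
        })
        .collect();

    // The mask is true wherever the slot holds the padding token.
    let mask: Vec<Vec<bool>> = padded
        .iter()
        .map(|row| row.iter().map(|&t| t == pad_token).collect())
        .collect();

    (padded, mask)
}

fn main() {
    let (padded, mask) = pad_plain(0, vec![vec![5, 6, 7], vec![8]], 8);
    assert_eq!(padded, vec![vec![5, 6, 7], vec![8, 0, 0]]);
    assert_eq!(mask[1], vec![false, true, true]);
}
```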
examples/text-classification/src/data/dataset.rs (new file)
@@ -0,0 +1,150 @@
use burn::data::dataset::{
    source::huggingface::downloader::HuggingfaceDatasetLoader, Dataset, InMemDataset,
};

#[derive(new, Clone, Debug)]
pub struct TextClassificationItem {
    pub text: String,
    pub label: usize,
}

pub trait TextClassificationDataset: Dataset<TextClassificationItem> {
    fn num_classes() -> usize;
    fn class_name(label: usize) -> String;
}

#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct AgNewsItem {
    pub text: String,
    pub label: usize,
}

pub struct AgNewsDataset {
    dataset: InMemDataset<AgNewsItem>,
}

impl Dataset<TextClassificationItem> for AgNewsDataset {
    fn get(&self, index: usize) -> Option<TextClassificationItem> {
        self.dataset
            .get(index)
            .map(|item| TextClassificationItem::new(item.text, item.label))
    }

    fn len(&self) -> usize {
        self.dataset.len()
    }
}

impl AgNewsDataset {
    pub fn train() -> Self {
        let dataset: InMemDataset<AgNewsItem> = HuggingfaceDatasetLoader::new("ag_news", "train")
            .extract_string("text")
            .extract_number("label")
            .load_in_memory()
            .unwrap();
        Self { dataset }
    }

    pub fn test() -> Self {
        let dataset: InMemDataset<AgNewsItem> = HuggingfaceDatasetLoader::new("ag_news", "test")
            .extract_string("text")
            .extract_number("label")
            .load_in_memory()
            .unwrap();
        Self { dataset }
    }
}

impl TextClassificationDataset for AgNewsDataset {
    fn num_classes() -> usize {
        4
    }

    fn class_name(label: usize) -> String {
        match label {
            0 => "World",
            1 => "Sports",
            2 => "Business",
            3 => "Technology",
            _ => panic!("invalid class"),
        }
        .to_string()
    }
}

#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct DbPediaItem {
    pub title: String,
    pub content: String,
    pub label: usize,
}

pub struct DbPediaDataset {
    dataset: InMemDataset<DbPediaItem>,
}

impl Dataset<TextClassificationItem> for DbPediaDataset {
    fn get(&self, index: usize) -> Option<TextClassificationItem> {
        self.dataset.get(index).map(|item| {
            TextClassificationItem::new(
                format!("Title: {} - Content: {}", item.title, item.content),
                item.label,
            )
        })
    }

    fn len(&self) -> usize {
        self.dataset.len()
    }
}

impl DbPediaDataset {
    pub fn train() -> Self {
        let dataset: InMemDataset<DbPediaItem> =
            HuggingfaceDatasetLoader::new("dbpedia_14", "train")
                .extract_string("title")
                .extract_string("content")
                .extract_number("label")
                .load_in_memory()
                .unwrap();
        Self { dataset }
    }

    pub fn test() -> Self {
        let dataset: InMemDataset<DbPediaItem> =
            HuggingfaceDatasetLoader::new("dbpedia_14", "test")
                .extract_string("title")
                .extract_string("content")
                .extract_number("label")
                .load_in_memory()
                .unwrap();
        Self { dataset }
    }
}

impl TextClassificationDataset for DbPediaDataset {
    fn num_classes() -> usize {
        14
    }

    fn class_name(label: usize) -> String {
        match label {
            0 => "Company",
            1 => "EducationalInstitution",
            2 => "Artist",
            3 => "Athlete",
            4 => "OfficeHolder",
            5 => "MeanOfTransportation",
            6 => "Building",
            7 => "NaturalPlace",
            8 => "Village",
            9 => "Animal",
            10 => "Plant",
            11 => "Album",
            12 => "Film",
            13 => "WrittenWork",
            _ => panic!("invalid class"),
        }
        .to_string()
    }
}
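Both datasets above follow the same adapter pattern: wrap an `InMemDataset` of raw rows and map each row into a `TextClassificationItem` inside `get`. Here is a miniature, dependency-free sketch of that pattern, with local stand-ins for burn's `Dataset` and `InMemDataset` (all names in the sketch are hypothetical):

```rust
// Miniature version of the adapter pattern used by AgNewsDataset/DbPediaDataset.
trait Dataset<I> {
    fn get(&self, index: usize) -> Option<I>;
    fn len(&self) -> usize;
}

// Stand-in for burn's InMemDataset: rows held in a Vec.
struct InMem<I>(Vec<I>);

impl<I: Clone> Dataset<I> for InMem<I> {
    fn get(&self, index: usize) -> Option<I> {
        self.0.get(index).cloned()
    }
    fn len(&self) -> usize {
        self.0.len()
    }
}

#[derive(Clone)]
struct RawRow {
    text: String,
    label: usize,
}

struct TextDataset {
    dataset: InMem<RawRow>,
}

// `get` maps the raw row into the item type the batcher consumes,
// exactly as the real adapters map into TextClassificationItem.
impl Dataset<(String, usize)> for TextDataset {
    fn get(&self, index: usize) -> Option<(String, usize)> {
        self.dataset.get(index).map(|row| (row.text, row.label))
    }
    fn len(&self) -> usize {
        self.dataset.len()
    }
}

fn main() {
    let ds = TextDataset {
        dataset: InMem(vec![RawRow { text: "hello".into(), label: 1 }]),
    };
    assert_eq!(ds.get(0), Some(("hello".into(), 1)));
}
```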
examples/text-classification/src/data/mod.rs (new file)
@@ -0,0 +1,7 @@
mod batcher;
mod dataset;
mod tokenizer;

pub use batcher::*;
pub use dataset::*;
pub use tokenizer::*;
examples/text-classification/src/data/tokenizer.rs (new file)
@@ -0,0 +1,42 @@
pub trait Tokenizer: Send + Sync {
    fn encode(&self, value: &str) -> Vec<usize>;
    fn decode(&self, tokens: &[usize]) -> String;
    fn vocab_size(&self) -> usize;
    fn pad_token(&self) -> usize;
    fn pad_token_value(&self) -> String {
        self.decode(&[self.pad_token()])
    }
}

pub struct BertCasedTokenizer {
    tokenizer: tokenizers::Tokenizer,
}

impl Default for BertCasedTokenizer {
    fn default() -> Self {
        Self {
            tokenizer: tokenizers::Tokenizer::from_pretrained("bert-base-cased", None).unwrap(),
        }
    }
}

impl Tokenizer for BertCasedTokenizer {
    fn encode(&self, value: &str) -> Vec<usize> {
        let tokens = self.tokenizer.encode(value, true).unwrap();
        tokens.get_ids().iter().map(|t| *t as usize).collect()
    }

    fn decode(&self, tokens: &[usize]) -> String {
        self.tokenizer
            .decode(tokens.iter().map(|t| *t as u32).collect(), false)
            .unwrap()
    }

    fn vocab_size(&self) -> usize {
        self.tokenizer.get_vocab_size(true)
    }

    fn pad_token(&self) -> usize {
        self.tokenizer.token_to_id("[PAD]").unwrap() as usize
    }
}
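The `Tokenizer` trait is the seam that makes `BertCasedTokenizer` swappable. Below is a toy whitespace tokenizer honoring the same contract, the kind of stand-in you might use in tests (the trait is restated without its provided method so the sketch compiles on its own; purely illustrative):

```rust
// Toy tokenizer implementing the same contract as BertCasedTokenizer.
// The trait is restated here so the sketch is self-contained.
pub trait Tokenizer: Send + Sync {
    fn encode(&self, value: &str) -> Vec<usize>;
    fn decode(&self, tokens: &[usize]) -> String;
    fn vocab_size(&self) -> usize;
    fn pad_token(&self) -> usize;
}

struct WhitespaceTokenizer {
    vocab: Vec<String>, // index 0 is reserved for [PAD]
}

impl Tokenizer for WhitespaceTokenizer {
    fn encode(&self, value: &str) -> Vec<usize> {
        // Unknown words collapse to the pad id in this toy version.
        value
            .split_whitespace()
            .map(|w| self.vocab.iter().position(|v| v == w).unwrap_or(0))
            .collect()
    }
    fn decode(&self, tokens: &[usize]) -> String {
        tokens
            .iter()
            .map(|&t| self.vocab[t].as_str())
            .collect::<Vec<_>>()
            .join(" ")
    }
    fn vocab_size(&self) -> usize {
        self.vocab.len()
    }
    fn pad_token(&self) -> usize {
        0
    }
}

fn main() {
    let tok = WhitespaceTokenizer {
        vocab: vec!["[PAD]".into(), "hello".into(), "world".into()],
    };
    assert_eq!(tok.encode("hello world"), vec![1, 2]);
    assert_eq!(tok.decode(&[1, 2]), "hello world");
}
```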
examples/text-classification/src/lib.rs (new file)
@@ -0,0 +1,8 @@
#[macro_use]
extern crate derive_new;

mod data;
mod model;

pub mod training;
pub use data::{AgNewsDataset, DbPediaDataset};
examples/text-classification/src/model.rs (new file)
@@ -0,0 +1,104 @@
use crate::data::TextClassificationBatch;
use burn::{
    config::Config,
    module::{Module, Param},
    nn::{
        transformer::{TransformerEncoder, TransformerEncoderConfig, TransformerEncoderInput},
        Embedding, EmbeddingConfig, Linear, LinearConfig,
    },
    tensor::backend::{ADBackend, Backend},
    tensor::{loss::cross_entropy_with_logits, Tensor},
    train::{ClassificationOutput, TrainOutput, TrainStep, ValidStep},
};

#[derive(Config)]
pub struct TextClassificationModelConfig {
    transformer: TransformerEncoderConfig,
    n_classes: usize,
    vocab_size: usize,
    max_seq_length: usize,
}

#[derive(Module, Debug)]
pub struct TextClassificationModel<B: Backend> {
    transformer: Param<TransformerEncoder<B>>,
    embedding_token: Param<Embedding<B>>,
    embedding_pos: Param<Embedding<B>>,
    output: Param<Linear<B>>,
    n_classes: usize,
    max_seq_length: usize,
}

impl<B: Backend> TextClassificationModel<B> {
    pub fn new(config: &TextClassificationModelConfig) -> Self {
        let config_embedding_token =
            EmbeddingConfig::new(config.vocab_size, config.transformer.d_model);
        let config_embedding_pos =
            EmbeddingConfig::new(config.max_seq_length, config.transformer.d_model);
        let config_output = LinearConfig::new(config.transformer.d_model, config.n_classes);

        let transformer = TransformerEncoder::new(&config.transformer);
        let embedding_token = Embedding::new(&config_embedding_token);
        let embedding_pos = Embedding::new(&config_embedding_pos);
        let output = Linear::new(&config_output);

        Self {
            transformer: Param::new(transformer),
            embedding_token: Param::new(embedding_token),
            embedding_pos: Param::new(embedding_pos),
            output: Param::new(output),
            n_classes: config.n_classes,
            max_seq_length: config.max_seq_length,
        }
    }

    pub fn forward(&self, item: TextClassificationBatch<B>) -> ClassificationOutput<B> {
        let [batch_size, seq_length] = item.tokens.dims();

        let index_positions = Tensor::<B, 1>::arange_device(0..seq_length, item.tokens.device())
            .reshape([1, seq_length])
            .repeat(0, batch_size);
        let embedding_positions = self.embedding_pos.forward(index_positions.detach());
        let embedding_tokens = self.embedding_token.forward(item.tokens.detach());
        let embedding = (embedding_positions + embedding_tokens) / 2;

        let encoded = self
            .transformer
            .forward(TransformerEncoderInput::new(embedding).mask_pad(item.mask_pad));
        let output = self.output.forward(encoded);

        let output_classification = output
            .index([0..batch_size, 0..1])
            .reshape([batch_size, self.n_classes]);

        let loss = cross_entropy_with_logits(&output_classification, &item.labels.clone().detach());

        ClassificationOutput {
            loss,
            output: output_classification,
            targets: item.labels,
        }
    }
}

impl<B: ADBackend> TrainStep<TextClassificationBatch<B>, ClassificationOutput<B>, B::Gradients>
    for TextClassificationModel<B>
{
    fn step(
        &self,
        item: TextClassificationBatch<B>,
    ) -> TrainOutput<ClassificationOutput<B>, B::Gradients> {
        let item = self.forward(item);
        let grads = item.loss.backward();

        TrainOutput::new(grads, item)
    }
}

impl<B: Backend> ValidStep<TextClassificationBatch<B>, ClassificationOutput<B>>
    for TextClassificationModel<B>
{
    fn step(&self, item: TextClassificationBatch<B>) -> ClassificationOutput<B> {
        self.forward(item)
    }
}
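The position input built in `forward` (`arange_device` → `reshape` → `repeat`) is just the row `[0, 1, ..., seq_length - 1]` duplicated once per batch entry. A plain-vector illustration of that intermediate:

```rust
// What `index_positions` contains for batch_size = 2, seq_length = 4.
fn main() {
    let (batch_size, seq_length) = (2usize, 4usize);
    let row: Vec<i64> = (0..seq_length as i64).collect(); // arange(0..seq_length)
    let index_positions: Vec<Vec<i64>> = vec![row; batch_size]; // repeat along dim 0
    assert_eq!(index_positions, vec![vec![0, 1, 2, 3], vec![0, 1, 2, 3]]);
}
```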
examples/text-classification/src/training.rs (new file)
@@ -0,0 +1,99 @@
use crate::{
    data::{BertCasedTokenizer, TextClassificationBatcher, TextClassificationDataset, Tokenizer},
    model::{TextClassificationModel, TextClassificationModelConfig},
};
use burn::{
    config::Config,
    data::dataloader::DataLoaderBuilder,
    module::Module,
    nn::transformer::TransformerEncoderConfig,
    optim::{Sgd, SgdConfig},
    tensor::backend::ADBackend,
    train::{
        metric::{AccuracyMetric, CUDAMetric, LossMetric},
        LearnerBuilder,
    },
};
use std::sync::Arc;

#[derive(Config)]
pub struct ExperimentConfig {
    transformer: TransformerEncoderConfig,
    optimizer: SgdConfig,
    #[config(default = 256)]
    max_seq_length: usize,
    #[config(default = 32)]
    batch_size: usize,
    #[config(default = 10)]
    num_epochs: usize,
}

pub fn train<B: ADBackend, D: TextClassificationDataset + 'static>(
    device: B::Device,
    dataset_train: D,
    dataset_test: D,
    config: ExperimentConfig,
    artifact_dir: &str,
) {
    let dataset_train = Arc::new(dataset_train);
    let dataset_test = Arc::new(dataset_test);
    let n_classes = D::num_classes();

    let tokenizer = Arc::new(BertCasedTokenizer::default());
    let batcher_train = Arc::new(TextClassificationBatcher::<B>::new(
        tokenizer.clone(),
        n_classes,
        device,
        config.max_seq_length,
    ));
    let batcher_test = Arc::new(TextClassificationBatcher::<B::InnerBackend>::new(
        tokenizer.clone(),
        n_classes,
        device,
        config.max_seq_length,
    ));

    let mut model = TextClassificationModel::new(&TextClassificationModelConfig::new(
        config.transformer.clone(),
        n_classes,
        tokenizer.vocab_size(),
        config.max_seq_length,
    ));
    model.to_device(device);
    model.detach();

    let dataloader_train = DataLoaderBuilder::new(batcher_train)
        .batch_size(config.batch_size)
        .num_workers(8)
        .shuffle(42)
        .build(dataset_train);

    let dataloader_test = DataLoaderBuilder::new(batcher_test)
        .batch_size(config.batch_size)
        .num_workers(8)
        .build(dataset_test);

    let optim = Sgd::new(&config.optimizer);

    let learner = LearnerBuilder::new(artifact_dir)
        .metric_train(CUDAMetric::new())
        .metric_valid(CUDAMetric::new())
        .metric_train(AccuracyMetric::new())
        .metric_valid(AccuracyMetric::new())
        .metric_train_plot(LossMetric::new())
        .metric_valid_plot(LossMetric::new())
        .with_file_checkpointer::<f32>(2)
        .num_epochs(config.num_epochs)
        .build(model, optim);

    let model_trained = learner.fit(dataloader_train, dataloader_test);

    config
        .save(&format!("{}/config.json", artifact_dir))
        .unwrap();
    model_trained
        .state()
        .convert::<f32>()
        .save(&format!("{}/model.json.gz", artifact_dir))
        .unwrap();
}