From eee90a5c9e1cbc83b8822e12b82b737428862969 Mon Sep 17 00:00:00 2001
From: Nathaniel Simard
Date: Fri, 2 Dec 2022 17:42:49 -0500
Subject: [PATCH] Example/text classification (#123)

---
 README.md                                          |   3 +-
 burn-derive/src/config/analyzer.rs                 |   2 +-
 examples/text-classification/Cargo.toml            |  23 +++
 examples/text-classification/README.md             |  12 ++
 .../examples/text-classification-ag-news.rs        |  22 +++
 .../examples/text-classification-db-pedia.rs       |  22 +++
 .../text-classification/src/data/batcher.rs        |  86 ++++++++++
 .../text-classification/src/data/dataset.rs        | 150 ++++++++++++++++++
 examples/text-classification/src/data/mod.rs       |   7 +
 .../text-classification/src/data/tokenizer.rs      |  42 +++++
 examples/text-classification/src/lib.rs            |   8 +
 examples/text-classification/src/model.rs          | 104 ++++++++++++
 examples/text-classification/src/training.rs       |  99 ++++++++++++
 13 files changed, 577 insertions(+), 3 deletions(-)
 create mode 100644 examples/text-classification/Cargo.toml
 create mode 100644 examples/text-classification/README.md
 create mode 100644 examples/text-classification/examples/text-classification-ag-news.rs
 create mode 100644 examples/text-classification/examples/text-classification-db-pedia.rs
 create mode 100644 examples/text-classification/src/data/batcher.rs
 create mode 100644 examples/text-classification/src/data/dataset.rs
 create mode 100644 examples/text-classification/src/data/mod.rs
 create mode 100644 examples/text-classification/src/data/tokenizer.rs
 create mode 100644 examples/text-classification/src/lib.rs
 create mode 100644 examples/text-classification/src/model.rs
 create mode 100644 examples/text-classification/src/training.rs

diff --git a/README.md b/README.md
index e5059ca3d..b48ec71ad 100644
--- a/README.md
+++ b/README.md
@@ -42,9 +42,8 @@ This may also be a good idea to take a look the main [components](#components) o
 ### Examples
 
-For now there is only one example, but more to come 💪..
-
 * [MNIST](https://github.com/burn-rs/burn/tree/main/examples/mnist) train a model on CPU/GPU using different backends.
+* [Text Classification](https://github.com/burn-rs/burn/tree/main/examples/text-classification) train a transformer encoder from scratch on GPU.
 
 ### Components
diff --git a/burn-derive/src/config/analyzer.rs b/burn-derive/src/config/analyzer.rs
index df9e5d7ff..da2e531e6 100644
--- a/burn-derive/src/config/analyzer.rs
+++ b/burn-derive/src/config/analyzer.rs
@@ -83,6 +83,6 @@ fn parse_asm(ast: &syn::DeriveInput) -> ConfigType {
             ConfigType::Struct(struct_data.fields.clone().into_iter().collect())
         }
         syn::Data::Enum(enum_data) => ConfigType::Enum(enum_data.clone()),
-        syn::Data::Union(_) => panic!("Only struct cna be derived"),
+        syn::Data::Union(_) => panic!("Only struct and enum can be derived"),
     }
 }
diff --git a/examples/text-classification/Cargo.toml b/examples/text-classification/Cargo.toml
new file mode 100644
index 000000000..d8b59ee6b
--- /dev/null
+++ b/examples/text-classification/Cargo.toml
@@ -0,0 +1,23 @@
+[package]
+name = "text-classification"
+version = "0.1.0"
+authors = ["nathanielsimard "]
+license = "MIT/Apache-2.0"
+edition = "2021"
+publish = false
+
+[features]
+default = []
+
+[dependencies]
+# Burn
+burn = { path = "../../burn" }
+burn-autodiff = { path = "../../burn-autodiff" }
+burn-tch = { path = "../../burn-tch" }
+
+# Tokenizer
+tokenizers = { version = "0.13", default-features = false, features = ["onig", "http"] }
+
+# Utils
+derive-new = "0.5"
+serde = { version = "1.0", features = ["derive"] }
diff --git a/examples/text-classification/README.md b/examples/text-classification/README.md
new file mode 100644
index 000000000..30847cd3e
--- /dev/null
+++ b/examples/text-classification/README.md
@@ -0,0 +1,12 @@
+# Text Classification
+
+The example can be run like so:
+
+```bash
+git clone https://github.com/burn-rs/burn.git
+cd burn
+# Use the --release flag to really speed up training.
+export TORCH_CUDA_VERSION=cu113 # Set the CUDA version
+cargo run --example text-classification-ag-news --release # Train on the ag news dataset
+cargo run --example text-classification-db-pedia --release # Train on the db pedia dataset
+```
diff --git a/examples/text-classification/examples/text-classification-ag-news.rs b/examples/text-classification/examples/text-classification-ag-news.rs
new file mode 100644
index 000000000..8d2857d8f
--- /dev/null
+++ b/examples/text-classification/examples/text-classification-ag-news.rs
@@ -0,0 +1,22 @@
+use burn::optim::{decay::WeightDecayConfig, momentum::MomentumConfig};
+use text_classification::{training::ExperimentConfig, AgNewsDataset};
+
+type Backend = burn_autodiff::ADBackendDecorator<burn_tch::TchBackend<burn::tensor::f16>>;
+
+fn main() {
+    let config = ExperimentConfig::new(
+        burn::nn::transformer::TransformerEncoderConfig::new(256, 512, 4, 4),
+        burn::optim::SgdConfig::new()
+            .with_learning_rate(5.0e-3)
+            .with_weight_decay(Some(WeightDecayConfig::new(5e-4)))
+            .with_momentum(Some(MomentumConfig::new().with_nesterov(true))),
+    );
+
+    text_classification::training::train::<Backend, AgNewsDataset>(
+        burn_tch::TchDevice::Cuda(0),
+        AgNewsDataset::train(),
+        AgNewsDataset::test(),
+        config,
+        "/tmp/text-classification-ag-news",
+    );
+}
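Both example binaries only set the two required fields of `ExperimentConfig` (the encoder and the optimizer); sequence length, batch size, and epoch count fall back to the defaults declared in `src/training.rs` below. A minimal sketch of overriding those defaults, assuming `#[derive(Config)]` generates the same kind of `with_*` setters it generates for `SgdConfig` above (the setter names are my inference from the field names, not confirmed by this patch):

```rust
use burn::nn::transformer::TransformerEncoderConfig;
use burn::optim::SgdConfig;
use text_classification::training::ExperimentConfig;

fn custom_config() -> ExperimentConfig {
    ExperimentConfig::new(
        TransformerEncoderConfig::new(256, 512, 4, 4),
        SgdConfig::new(),
    )
    // Shorter sequences and a larger batch, e.g. for a GPU with less memory.
    .with_max_seq_length(128)
    .with_batch_size(64)
    .with_num_epochs(5)
}
```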
diff --git a/examples/text-classification/examples/text-classification-db-pedia.rs b/examples/text-classification/examples/text-classification-db-pedia.rs
new file mode 100644
index 000000000..585c6006b
--- /dev/null
+++ b/examples/text-classification/examples/text-classification-db-pedia.rs
@@ -0,0 +1,22 @@
+use burn::optim::{decay::WeightDecayConfig, momentum::MomentumConfig};
+use text_classification::{training::ExperimentConfig, DbPediaDataset};
+
+type Backend = burn_autodiff::ADBackendDecorator<burn_tch::TchBackend<burn::tensor::f16>>;
+
+fn main() {
+    let config = ExperimentConfig::new(
+        burn::nn::transformer::TransformerEncoderConfig::new(256, 512, 4, 4),
+        burn::optim::SgdConfig::new()
+            .with_learning_rate(5.0e-3)
+            .with_momentum(Some(MomentumConfig::new().with_nesterov(true)))
+            .with_weight_decay(Some(WeightDecayConfig::new(5e-4))),
+    );
+
+    text_classification::training::train::<Backend, DbPediaDataset>(
+        burn_tch::TchDevice::Cuda(0),
+        DbPediaDataset::train(),
+        DbPediaDataset::test(),
+        config,
+        "/tmp/text-classification-db-pedia",
+    );
+}
diff --git a/examples/text-classification/src/data/batcher.rs b/examples/text-classification/src/data/batcher.rs
new file mode 100644
index 000000000..6e2f3eb5a
--- /dev/null
+++ b/examples/text-classification/src/data/batcher.rs
@@ -0,0 +1,86 @@
+use super::{dataset::TextClassificationItem, tokenizer::Tokenizer};
+use burn::{
+    data::dataloader::batcher::Batcher,
+    tensor::{backend::Backend, BoolTensor, Data, Shape, Tensor},
+};
+use std::sync::Arc;
+
+#[derive(new)]
+pub struct TextClassificationBatcher<B: Backend> {
+    tokenizer: Arc<dyn Tokenizer>,
+    num_classes: usize,
+    device: B::Device,
+    max_seq_length: usize,
+}
+
+#[derive(Debug, Clone, new)]
+pub struct TextClassificationBatch<B: Backend> {
+    pub tokens: Tensor<B::IntegerBackend, 2>,
+    pub labels: Tensor<B, 2>,
+    pub mask_pad: BoolTensor<B, 2>,
+}
+
+impl<B: Backend> Batcher<TextClassificationItem, TextClassificationBatch<B>>
+    for TextClassificationBatcher<B>
+{
+    fn batch(&self, items: Vec<TextClassificationItem>) -> TextClassificationBatch<B> {
+        let mut tokens_list = Vec::with_capacity(items.len());
+        let mut labels_list = Vec::with_capacity(items.len());
+
+        for item in items {
+            tokens_list.push(self.tokenizer.encode(&item.text));
+            labels_list.push(Tensor::one_hot(item.label, self.num_classes));
+        }
+
+        let (tokens, mask_pad) =
+            pad_tokens::<B>(self.tokenizer.pad_token(), tokens_list, self.max_seq_length);
+
+        TextClassificationBatch {
+            tokens: tokens.to_device(self.device).detach(),
+            labels: Tensor::cat(labels_list, 0).to_device(self.device).detach(),
+            mask_pad: mask_pad.to_device(self.device),
+        }
+    }
+}
+
+pub fn pad_tokens<B: Backend>(
+    pad_token: usize,
+    tokens_list: Vec<Vec<usize>>,
+    max_seq_length: usize,
+) -> (Tensor<B::IntegerBackend, 2>, BoolTensor<B, 2>) {
+    let mut max_size = 0;
+    let batch_size = tokens_list.len();
+
+    for tokens in tokens_list.iter() {
+        if tokens.len() > max_size {
+            max_size = tokens.len();
+        }
+        if tokens.len() >= max_seq_length {
+            max_size = max_seq_length;
+            break;
+        }
+    }
+
+    let mut tensor = Tensor::zeros([batch_size, max_size]);
+    tensor = tensor.add_scalar(pad_token as i64);
+
+    for (index, tokens) in tokens_list.into_iter().enumerate() {
+        let mut seq_length = tokens.len();
+        let mut tokens = tokens;
+        if seq_length > max_seq_length {
+            seq_length = max_seq_length;
+            let _ = tokens.split_off(seq_length);
+        }
+        tensor = tensor.index_assign(
+            [index..index + 1, 0..tokens.len()],
+            &Tensor::from_data(Data::new(
+                tokens.into_iter().map(|e| e as i64).collect(),
+                Shape::new([1, seq_length]),
+            )),
+        );
+    }
+
+    let mask_pad = BoolTensor::from_int_backend(tensor.equal_scalar(pad_token as i64));
+
+    (tensor, mask_pad)
+}
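The padding logic in `pad_tokens` is easier to see without the tensor plumbing. Below is a backend-free sketch of the same idea (my own illustration, not part of the patch): every sequence is truncated to `max_seq_length`, padded to the longest (possibly capped) length in the batch, and a boolean mask marks the pad positions. The patch derives the mask by comparing tensor values against the pad token, which is equivalent as long as no real token shares the pad id.

```rust
/// Backend-free sketch of `pad_tokens`: pad with `pad_token` up to the longest
/// sequence in the batch (capped at `max_seq_length`) and return a mask that
/// is `true` wherever a position holds padding.
fn pad_tokens_sketch(
    pad_token: usize,
    tokens_list: Vec<Vec<usize>>,
    max_seq_length: usize,
) -> (Vec<Vec<usize>>, Vec<Vec<bool>>) {
    let max_size = tokens_list
        .iter()
        .map(|t| t.len().min(max_seq_length))
        .max()
        .unwrap_or(0);

    let mut padded = Vec::with_capacity(tokens_list.len());
    let mut mask = Vec::with_capacity(tokens_list.len());

    for mut tokens in tokens_list {
        tokens.truncate(max_seq_length);
        // `true` marks a padded position, mirroring `mask_pad` above.
        mask.push((0..max_size).map(|i| i >= tokens.len()).collect());
        tokens.resize(max_size, pad_token);
        padded.push(tokens);
    }

    (padded, mask)
}

fn main() {
    let (padded, mask) = pad_tokens_sketch(0, vec![vec![5, 6], vec![7, 8, 9, 10]], 3);
    assert_eq!(padded, vec![vec![5, 6, 0], vec![7, 8, 9]]);
    assert_eq!(mask, vec![vec![false, false, true], vec![false, false, false]]);
}
```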
diff --git a/examples/text-classification/src/data/dataset.rs b/examples/text-classification/src/data/dataset.rs
new file mode 100644
index 000000000..bd18f9a1f
--- /dev/null
+++ b/examples/text-classification/src/data/dataset.rs
@@ -0,0 +1,150 @@
+use burn::data::dataset::{
+    source::huggingface::downloader::HuggingfaceDatasetLoader, Dataset, InMemDataset,
+};
+
+#[derive(new, Clone, Debug)]
+pub struct TextClassificationItem {
+    pub text: String,
+    pub label: usize,
+}
+
+pub trait TextClassificationDataset: Dataset<TextClassificationItem> {
+    fn num_classes() -> usize;
+    fn class_name(label: usize) -> String;
+}
+
+#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
+pub struct AgNewsItem {
+    pub text: String,
+    pub label: usize,
+}
+
+pub struct AgNewsDataset {
+    dataset: InMemDataset<AgNewsItem>,
+}
+
+impl Dataset<TextClassificationItem> for AgNewsDataset {
+    fn get(&self, index: usize) -> Option<TextClassificationItem> {
+        self.dataset
+            .get(index)
+            .map(|item| TextClassificationItem::new(item.text, item.label))
+    }
+
+    fn len(&self) -> usize {
+        self.dataset.len()
+    }
+}
+
+impl AgNewsDataset {
+    pub fn train() -> Self {
+        let dataset: InMemDataset<AgNewsItem> = HuggingfaceDatasetLoader::new("ag_news", "train")
+            .extract_string("text")
+            .extract_number("label")
+            .load_in_memory()
+            .unwrap();
+        Self { dataset }
+    }
+
+    pub fn test() -> Self {
+        let dataset: InMemDataset<AgNewsItem> = HuggingfaceDatasetLoader::new("ag_news", "test")
+            .extract_string("text")
+            .extract_number("label")
+            .load_in_memory()
+            .unwrap();
+        Self { dataset }
+    }
+}
+
+impl TextClassificationDataset for AgNewsDataset {
+    fn num_classes() -> usize {
+        4
+    }
+
+    fn class_name(label: usize) -> String {
+        match label {
+            0 => "World",
+            1 => "Sports",
+            2 => "Business",
+            3 => "Technology",
+            _ => panic!("invalid class"),
+        }
+        .to_string()
+    }
+}
+
+#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
+pub struct DbPediaItem {
+    pub title: String,
+    pub content: String,
+    pub label: usize,
+}
+
+pub struct DbPediaDataset {
+    dataset: InMemDataset<DbPediaItem>,
+}
+
+impl Dataset<TextClassificationItem> for DbPediaDataset {
+    fn get(&self, index: usize) -> Option<TextClassificationItem> {
+        self.dataset.get(index).map(|item| {
+            TextClassificationItem::new(
+                format!("Title: {} - Content: {}", item.title, item.content),
+                item.label,
+            )
+        })
+    }
+
+    fn len(&self) -> usize {
+        self.dataset.len()
+    }
+}
+
+impl DbPediaDataset {
+    pub fn train() -> Self {
+        let dataset: InMemDataset<DbPediaItem> =
+            HuggingfaceDatasetLoader::new("dbpedia_14", "train")
+                .extract_string("title")
+                .extract_string("content")
+                .extract_number("label")
+                .load_in_memory()
+                .unwrap();
+        Self { dataset }
+    }
+
+    pub fn test() -> Self {
+        let dataset: InMemDataset<DbPediaItem> =
+            HuggingfaceDatasetLoader::new("dbpedia_14", "test")
+                .extract_string("title")
+                .extract_string("content")
+                .extract_number("label")
+                .load_in_memory()
+                .unwrap();
+        Self { dataset }
+    }
+}
+
+impl TextClassificationDataset for DbPediaDataset {
+    fn num_classes() -> usize {
+        14
+    }
+
+    fn class_name(label: usize) -> String {
+        match label {
+            0 => "Company",
+            1 => "EducationalInstitution",
+            2 => "Artist",
+            3 => "Athlete",
+            4 => "OfficeHolder",
+            5 => "MeanOfTransportation",
+            6 => "Building",
+            7 => "NaturalPlace",
+            8 => "Village",
+            9 => "Animal",
+            10 => "Plant",
+            11 => "Album",
+            12 => "Film",
+            13 => "WrittenWork",
+            _ => panic!("invalid class"),
+        }
+        .to_string()
+    }
+}
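Adding a third dataset only requires implementing `Dataset<TextClassificationItem>` plus the two class-metadata methods. A minimal in-memory sketch, as it could be dropped alongside `AgNewsDataset` in `src/data/dataset.rs` (the struct name and toy labels are hypothetical; only the trait signatures come from the patch):

```rust
// A custom dataset backed by a plain Vec, usable with `training::train`.
pub struct ToyDataset {
    items: Vec<(String, usize)>,
}

impl Dataset<TextClassificationItem> for ToyDataset {
    fn get(&self, index: usize) -> Option<TextClassificationItem> {
        self.items
            .get(index)
            .map(|(text, label)| TextClassificationItem::new(text.clone(), *label))
    }

    fn len(&self) -> usize {
        self.items.len()
    }
}

impl TextClassificationDataset for ToyDataset {
    fn num_classes() -> usize {
        2
    }

    fn class_name(label: usize) -> String {
        match label {
            0 => "Negative",
            1 => "Positive",
            _ => panic!("invalid class"),
        }
        .to_string()
    }
}
```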
diff --git a/examples/text-classification/src/data/mod.rs b/examples/text-classification/src/data/mod.rs
new file mode 100644
index 000000000..8f6faeb3c
--- /dev/null
+++ b/examples/text-classification/src/data/mod.rs
@@ -0,0 +1,7 @@
+mod batcher;
+mod dataset;
+mod tokenizer;
+
+pub use batcher::*;
+pub use dataset::*;
+pub use tokenizer::*;
diff --git a/examples/text-classification/src/data/tokenizer.rs b/examples/text-classification/src/data/tokenizer.rs
new file mode 100644
index 000000000..6ad09e1f0
--- /dev/null
+++ b/examples/text-classification/src/data/tokenizer.rs
@@ -0,0 +1,42 @@
+pub trait Tokenizer: Send + Sync {
+    fn encode(&self, value: &str) -> Vec<usize>;
+    fn decode(&self, tokens: &[usize]) -> String;
+    fn vocab_size(&self) -> usize;
+    fn pad_token(&self) -> usize;
+    fn pad_token_value(&self) -> String {
+        self.decode(&[self.pad_token()])
+    }
+}
+
+pub struct BertCasedTokenizer {
+    tokenizer: tokenizers::Tokenizer,
+}
+
+impl Default for BertCasedTokenizer {
+    fn default() -> Self {
+        Self {
+            tokenizer: tokenizers::Tokenizer::from_pretrained("bert-base-cased", None).unwrap(),
+        }
+    }
+}
+
+impl Tokenizer for BertCasedTokenizer {
+    fn encode(&self, value: &str) -> Vec<usize> {
+        let tokens = self.tokenizer.encode(value, true).unwrap();
+        tokens.get_ids().iter().map(|t| *t as usize).collect()
+    }
+
+    fn decode(&self, tokens: &[usize]) -> String {
+        self.tokenizer
+            .decode(tokens.iter().map(|t| *t as u32).collect(), false)
+            .unwrap()
+    }
+
+    fn vocab_size(&self) -> usize {
+        self.tokenizer.get_vocab_size(true)
+    }
+
+    fn pad_token(&self) -> usize {
+        self.tokenizer.token_to_id("[PAD]").unwrap() as usize
+    }
+}
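The `Tokenizer` trait is what the batcher sees (`Arc<dyn Tokenizer>`), so `BertCasedTokenizer` can be swapped out, e.g. in tests that should not download the pretrained BERT vocabulary. A toy whitespace implementation of the same trait, sketched as it might sit in `src/data/tokenizer.rs` (the struct and its fixed word list are hypothetical; only the trait comes from the patch):

```rust
// Looks each whitespace-separated word up in a fixed vocabulary; unknown
// words are silently dropped. `Send + Sync` holds automatically because the
// only field is a Vec<String>.
struct WhitespaceTokenizer {
    vocab: Vec<String>, // index 0 is reserved for the pad token
}

impl Tokenizer for WhitespaceTokenizer {
    fn encode(&self, value: &str) -> Vec<usize> {
        value
            .split_whitespace()
            .filter_map(|word| self.vocab.iter().position(|v| v == word))
            .collect()
    }

    fn decode(&self, tokens: &[usize]) -> String {
        tokens
            .iter()
            .map(|t| self.vocab[*t].as_str())
            .collect::<Vec<_>>()
            .join(" ")
    }

    fn vocab_size(&self) -> usize {
        self.vocab.len()
    }

    fn pad_token(&self) -> usize {
        0
    }
}
```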
diff --git a/examples/text-classification/src/lib.rs b/examples/text-classification/src/lib.rs
new file mode 100644
index 000000000..b9cdcb505
--- /dev/null
+++ b/examples/text-classification/src/lib.rs
@@ -0,0 +1,8 @@
+#[macro_use]
+extern crate derive_new;
+
+mod data;
+mod model;
+
+pub mod training;
+pub use data::{AgNewsDataset, DbPediaDataset};
diff --git a/examples/text-classification/src/model.rs b/examples/text-classification/src/model.rs
new file mode 100644
index 000000000..f1b21c3a4
--- /dev/null
+++ b/examples/text-classification/src/model.rs
@@ -0,0 +1,104 @@
+use crate::data::TextClassificationBatch;
+use burn::{
+    config::Config,
+    module::{Module, Param},
+    nn::{
+        transformer::{TransformerEncoder, TransformerEncoderConfig, TransformerEncoderInput},
+        Embedding, EmbeddingConfig, Linear, LinearConfig,
+    },
+    tensor::backend::{ADBackend, Backend},
+    tensor::{loss::cross_entropy_with_logits, Tensor},
+    train::{ClassificationOutput, TrainOutput, TrainStep, ValidStep},
+};
+
+#[derive(Config)]
+pub struct TextClassificationModelConfig {
+    transformer: TransformerEncoderConfig,
+    n_classes: usize,
+    vocab_size: usize,
+    max_seq_length: usize,
+}
+
+#[derive(Module, Debug)]
+pub struct TextClassificationModel<B: Backend> {
+    transformer: Param<TransformerEncoder<B>>,
+    embedding_token: Param<Embedding<B>>,
+    embedding_pos: Param<Embedding<B>>,
+    output: Param<Linear<B>>,
+    n_classes: usize,
+    max_seq_length: usize,
+}
+
+impl<B: Backend> TextClassificationModel<B> {
+    pub fn new(config: &TextClassificationModelConfig) -> Self {
+        let config_embedding_token =
+            EmbeddingConfig::new(config.vocab_size, config.transformer.d_model);
+        let config_embedding_pos =
+            EmbeddingConfig::new(config.max_seq_length, config.transformer.d_model);
+        let config_output = LinearConfig::new(config.transformer.d_model, config.n_classes);
+
+        let transformer = TransformerEncoder::new(&config.transformer);
+        let embedding_token = Embedding::new(&config_embedding_token);
+        let embedding_pos = Embedding::new(&config_embedding_pos);
+        let output = Linear::new(&config_output);
+
+        Self {
+            transformer: Param::new(transformer),
+            embedding_token: Param::new(embedding_token),
+            embedding_pos: Param::new(embedding_pos),
+            output: Param::new(output),
+            n_classes: config.n_classes,
+            max_seq_length: config.max_seq_length,
+        }
+    }
+
+    pub fn forward(&self, item: TextClassificationBatch<B>) -> ClassificationOutput<B> {
+        let [batch_size, seq_length] = item.tokens.dims();
+
+        let index_positions = Tensor::<B::IntegerBackend, 1>::arange_device(0..seq_length, item.tokens.device())
+            .reshape([1, seq_length])
+            .repeat(0, batch_size);
+        let embedding_positions = self.embedding_pos.forward(index_positions.detach());
+        let embedding_tokens = self.embedding_token.forward(item.tokens.detach());
+        let embedding = (embedding_positions + embedding_tokens) / 2;
+
+        let encoded = self
+            .transformer
+            .forward(TransformerEncoderInput::new(embedding).mask_pad(item.mask_pad));
+        let output = self.output.forward(encoded);
+
+        let output_classification = output
+            .index([0..batch_size, 0..1])
+            .reshape([batch_size, self.n_classes]);
+
+        let loss = cross_entropy_with_logits(&output_classification, &item.labels.clone().detach());
+
+        ClassificationOutput {
+            loss,
+            output: output_classification,
+            targets: item.labels,
+        }
+    }
+}
+
+impl<B: ADBackend> TrainStep<TextClassificationBatch<B>, ClassificationOutput<B>, B::Gradients>
+    for TextClassificationModel<B>
+{
+    fn step(
+        &self,
+        item: TextClassificationBatch<B>,
+    ) -> TrainOutput<ClassificationOutput<B>, B::Gradients> {
+        let item = self.forward(item);
+        let grads = item.loss.backward();
+
+        TrainOutput::new(grads, item)
+    }
+}
+
+impl<B: Backend> ValidStep<TextClassificationBatch<B>, ClassificationOutput<B>>
+    for TextClassificationModel<B>
+{
+    fn step(&self, item: TextClassificationBatch<B>) -> ClassificationOutput<B> {
+        self.forward(item)
+    }
+}
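Note how the classifier reads only position 0 of the encoded sequence (`output.index([0..batch_size, 0..1])`), so the first token acts as a CLS-style summary of the whole text. The loss it feeds is standard softmax cross-entropy against the one-hot labels built in the batcher. A scalar sketch of that formula (my own illustration, not burn's implementation):

```rust
// Softmax cross-entropy for one example: -sum_i target_i * log(softmax(logits)_i),
// computed as sum_i target_i * (logsumexp(logits) - logits_i) and stabilized
// by subtracting the max logit before exponentiating.
fn cross_entropy_with_logits_sketch(logits: &[f64], one_hot_target: &[f64]) -> f64 {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let log_sum_exp = logits.iter().map(|l| (l - max).exp()).sum::<f64>().ln() + max;
    logits
        .iter()
        .zip(one_hot_target)
        .map(|(l, t)| t * (log_sum_exp - l))
        .sum()
}

fn main() {
    // Class 2 of 4 is the true label, matching `Tensor::one_hot(label, num_classes)`.
    let loss = cross_entropy_with_logits_sketch(&[0.1, -1.3, 2.2, 0.4], &[0.0, 0.0, 1.0, 0.0]);
    assert!(loss > 0.0);
    println!("loss = {loss:.4}");
}
```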
diff --git a/examples/text-classification/src/training.rs b/examples/text-classification/src/training.rs
new file mode 100644
index 000000000..2302fb742
--- /dev/null
+++ b/examples/text-classification/src/training.rs
@@ -0,0 +1,99 @@
+use crate::{
+    data::{BertCasedTokenizer, TextClassificationBatcher, TextClassificationDataset, Tokenizer},
+    model::{TextClassificationModel, TextClassificationModelConfig},
+};
+use burn::{
+    config::Config,
+    data::dataloader::DataLoaderBuilder,
+    module::Module,
+    nn::transformer::TransformerEncoderConfig,
+    optim::{Sgd, SgdConfig},
+    tensor::backend::ADBackend,
+    train::{
+        metric::{AccuracyMetric, CUDAMetric, LossMetric},
+        LearnerBuilder,
+    },
+};
+use std::sync::Arc;
+
+#[derive(Config)]
+pub struct ExperimentConfig {
+    transformer: TransformerEncoderConfig,
+    optimizer: SgdConfig,
+    #[config(default = 256)]
+    max_seq_length: usize,
+    #[config(default = 32)]
+    batch_size: usize,
+    #[config(default = 10)]
+    num_epochs: usize,
+}
+
+pub fn train<B: ADBackend, D: TextClassificationDataset + 'static>(
+    device: B::Device,
+    dataset_train: D,
+    dataset_test: D,
+    config: ExperimentConfig,
+    artifact_dir: &str,
+) {
+    let dataset_train = Arc::new(dataset_train);
+    let dataset_test = Arc::new(dataset_test);
+    let n_classes = D::num_classes();
+
+    let tokenizer = Arc::new(BertCasedTokenizer::default());
+    let batcher_train = Arc::new(TextClassificationBatcher::<B>::new(
+        tokenizer.clone(),
+        n_classes,
+        device,
+        config.max_seq_length,
+    ));
+    let batcher_test = Arc::new(TextClassificationBatcher::<B>::new(
+        tokenizer.clone(),
+        n_classes,
+        device,
+        config.max_seq_length,
+    ));
+
+    let mut model = TextClassificationModel::new(&TextClassificationModelConfig::new(
+        config.transformer.clone(),
+        n_classes,
+        tokenizer.vocab_size(),
+        config.max_seq_length,
+    ));
+    model.to_device(device);
+    model.detach();
+
+    let dataloader_train = DataLoaderBuilder::new(batcher_train)
+        .batch_size(config.batch_size)
+        .num_workers(8)
+        .shuffle(42)
+        .build(dataset_train);
+
+    let dataloader_test = DataLoaderBuilder::new(batcher_test)
+        .batch_size(config.batch_size)
+        .num_workers(8)
+        .build(dataset_test);
+
+    let optim = Sgd::new(&config.optimizer);
+
+    let learner = LearnerBuilder::new(artifact_dir)
+        .metric_train(CUDAMetric::new())
+        .metric_valid(CUDAMetric::new())
+        .metric_train(AccuracyMetric::new())
+        .metric_valid(AccuracyMetric::new())
+        .metric_train_plot(LossMetric::new())
+        .metric_valid_plot(LossMetric::new())
+        .with_file_checkpointer::<f32>(2)
+        .num_epochs(config.num_epochs)
+        .build(model, optim);
+
+    let model_trained = learner.fit(dataloader_train, dataloader_test);
+
+    config
+        .save(&format!("{}/config.json", artifact_dir))
+        .unwrap();
+    model_trained
+        .state()
+        .convert::<f32>()
+        .save(&format!("{}/model.json.gz", artifact_dir))
+        .unwrap();
+}
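Both example binaries request `TchDevice::Cuda(0)`. For a machine without CUDA, a hypothetical variant of the ag-news example could target the CPU instead; the only changes are the device and, since f16 mainly pays off on GPU, the element type of the backend. `TchDevice::Cpu`, the `f32` element type, and the output path are my assumptions here, not part of the patch:

```rust
use burn::optim::{decay::WeightDecayConfig, momentum::MomentumConfig};
use text_classification::{training::ExperimentConfig, AgNewsDataset};

// f32 on CPU; the GPU examples above use f16.
type Backend = burn_autodiff::ADBackendDecorator<burn_tch::TchBackend<f32>>;

fn main() {
    let config = ExperimentConfig::new(
        burn::nn::transformer::TransformerEncoderConfig::new(256, 512, 4, 4),
        burn::optim::SgdConfig::new()
            .with_learning_rate(5.0e-3)
            .with_momentum(Some(MomentumConfig::new().with_nesterov(true)))
            .with_weight_decay(Some(WeightDecayConfig::new(5e-4))),
    );

    text_classification::training::train::<Backend, AgNewsDataset>(
        burn_tch::TchDevice::Cpu,
        AgNewsDataset::train(),
        AgNewsDataset::test(),
        config,
        "/tmp/text-classification-ag-news-cpu",
    );
}
```

After `fit` returns, the artifact directory contains `config.json` and `model.json.gz` written by the code above, alongside whatever checkpoints the file checkpointer keeps.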