Example/text classification (#123)

Nathaniel Simard 2022-12-02 17:42:49 -05:00 committed by GitHub
parent 7c38a980c1
commit eee90a5c9e
13 changed files with 578 additions and 3 deletions

View File

@ -42,9 +42,8 @@ It may also be a good idea to take a look at the main [components](#components) o
### Examples
For now there is only one example, but more to come 💪..
* [MNIST](https://github.com/burn-rs/burn/tree/main/examples/mnist): train a model on CPU/GPU using different backends.
* [Text Classification](https://github.com/burn-rs/burn/tree/main/examples/text-classification): train a transformer encoder from scratch on GPU.
### Components

View File

@ -83,6 +83,6 @@ fn parse_asm(ast: &syn::DeriveInput) -> ConfigType {
ConfigType::Struct(struct_data.fields.clone().into_iter().collect())
}
syn::Data::Enum(enum_data) => ConfigType::Enum(enum_data.clone()),
syn::Data::Union(_) => panic!("Only struct cna be derived"),
syn::Data::Union(_) => panic!("Only struct and enum can be derived"),
}
}
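For context, this derive is the same one exercised later in the commit by `ExperimentConfig` and the model config. A minimal sketch of the struct form, assuming the derive generates the `new`/`with_*`/`save` API the example binaries rely on (the struct and field names here are illustrative, not part of the commit):
```rust
use burn::config::Config;

// Illustrative config struct; the derive is expected to generate `new(hidden_size)`,
// `with_batch_size(..)` for the defaulted field, and `save(..)` as used in training.rs.
#[derive(Config)]
pub struct MyExperimentConfig {
    hidden_size: usize,
    #[config(default = 32)]
    batch_size: usize,
}

fn build_config() -> MyExperimentConfig {
    // Required fields go through `new`; defaulted ones can be overridden with `with_*`.
    MyExperimentConfig::new(256).with_batch_size(64)
}
```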

View File

@ -0,0 +1,23 @@
[package]
name = "text-classification"
version = "0.1.0"
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
license = "MIT/Apache-2.0"
edition = "2021"
publish = false
[features]
default = []
[dependencies]
# Burn
burn = { path = "../../burn" }
burn-autodiff = { path = "../../burn-autodiff" }
burn-tch = { path = "../../burn-tch" }
# Tokenizer
tokenizers = { version = "0.13", default-features = false, features = ["onig", "http"] }
# Utils
derive-new = "0.5"
serde = { version = "1.0", features = ["derive"] }

View File

@ -0,0 +1,12 @@
# Text Classification
The example can be run like so:
```bash
git clone https://github.com/burn-rs/burn.git
cd burn
# Use the --release flag to really speed up training.
export TORCH_CUDA_VERSION=cu113 # Set the CUDA version
cargo run --example text-classification-ag-news --release # Train on the AG News dataset
cargo run --example text-classification-db-pedia --release # Train on the DbPedia dataset
```
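Both commands above target the first CUDA device. A CPU-only run is not wired up in this commit; as a rough, untested sketch (assuming burn-tch's `TchDevice::Cpu` variant and an `f32` backend behave as expected), the AG News example's main could be adapted like so:
```rust
// Hypothetical CPU variant of examples/ag-news.rs: same config, but TchDevice::Cpu
// and an f32 element type instead of f16 (half precision mainly targets GPUs).
use burn::optim::{decay::WeightDecayConfig, momentum::MomentumConfig};
use text_classification::{training::ExperimentConfig, AgNewsDataset};

type Backend = burn_autodiff::ADBackendDecorator<burn_tch::TchBackend<f32>>;

fn main() {
    let config = ExperimentConfig::new(
        burn::nn::transformer::TransformerEncoderConfig::new(256, 512, 4, 4),
        burn::optim::SgdConfig::new()
            .with_learning_rate(5.0e-3)
            .with_momentum(Some(MomentumConfig::new().with_nesterov(true)))
            .with_weight_decay(Some(WeightDecayConfig::new(5e-4))),
    );

    text_classification::training::train::<Backend, AgNewsDataset>(
        burn_tch::TchDevice::Cpu,
        AgNewsDataset::train(),
        AgNewsDataset::test(),
        config,
        "/tmp/text-classification-ag-news",
    );
}
```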

View File

@ -0,0 +1,23 @@
use burn::optim::{decay::WeightDecayConfig, momentum::MomentumConfig};
use text_classification::{training::ExperimentConfig, AgNewsDataset};
type Backend = burn_autodiff::ADBackendDecorator<burn_tch::TchBackend<burn::tensor::f16>>;
fn main() {
let config = ExperimentConfig::new(
burn::nn::transformer::TransformerEncoderConfig::new(256, 512, 4, 4),
burn::optim::SgdConfig::new()
.with_learning_rate(5.0e-3)
.with_momentum(Some(MomentumConfig::new().with_nesterov(true)))
.with_weight_decay(Some(WeightDecayConfig::new(5e-4))),
);
text_classification::training::train::<Backend, AgNewsDataset>(
burn_tch::TchDevice::Cuda(0),
AgNewsDataset::train(),
AgNewsDataset::test(),
config,
"/tmp/text-classification-ag-news",
);
}

View File

@ -0,0 +1,22 @@
use burn::optim::{decay::WeightDecayConfig, momentum::MomentumConfig};
use text_classification::{training::ExperimentConfig, DbPediaDataset};
type Backend = burn_autodiff::ADBackendDecorator<burn_tch::TchBackend<burn::tensor::f16>>;
fn main() {
let config = ExperimentConfig::new(
burn::nn::transformer::TransformerEncoderConfig::new(256, 512, 4, 4),
burn::optim::SgdConfig::new()
.with_learning_rate(5.0e-3)
.with_momentum(Some(MomentumConfig::new().with_nesterov(true)))
.with_weight_decay(Some(WeightDecayConfig::new(5e-4))),
);
text_classification::training::train::<Backend, DbPediaDataset>(
burn_tch::TchDevice::Cuda(0),
DbPediaDataset::train(),
DbPediaDataset::test(),
config,
"/tmp/text-classification-db-pedia",
);
}

View File

@ -0,0 +1,86 @@
use super::{dataset::TextClassificationItem, tokenizer::Tokenizer};
use burn::{
data::dataloader::batcher::Batcher,
tensor::{backend::Backend, BoolTensor, Data, Shape, Tensor},
};
use std::sync::Arc;
#[derive(new)]
pub struct TextClassificationBatcher<B: Backend> {
tokenizer: Arc<dyn Tokenizer>,
num_classes: usize,
device: B::Device,
max_seq_length: usize,
}
#[derive(Debug, Clone, new)]
pub struct TextClassificationBatch<B: Backend> {
pub tokens: Tensor<B::IntegerBackend, 2>,
pub labels: Tensor<B, 2>,
pub mask_pad: BoolTensor<B, 2>,
}
impl<B: Backend> Batcher<TextClassificationItem, TextClassificationBatch<B>>
for TextClassificationBatcher<B>
{
fn batch(&self, items: Vec<TextClassificationItem>) -> TextClassificationBatch<B> {
let mut tokens_list = Vec::with_capacity(items.len());
let mut labels_list = Vec::with_capacity(items.len());
for item in items {
tokens_list.push(self.tokenizer.encode(&item.text));
labels_list.push(Tensor::one_hot(item.label, self.num_classes));
}
let (tokens, mask_pad) =
pad_tokens::<B>(self.tokenizer.pad_token(), tokens_list, self.max_seq_length);
TextClassificationBatch {
tokens: tokens.to_device(self.device).detach(),
labels: Tensor::cat(labels_list, 0).to_device(self.device).detach(),
mask_pad: mask_pad.to_device(self.device),
}
}
}
pub fn pad_tokens<B: Backend>(
pad_token: usize,
tokens_list: Vec<Vec<usize>>,
max_seq_length: usize,
) -> (Tensor<B::IntegerBackend, 2>, BoolTensor<B, 2>) {
let mut max_size = 0;
let batch_size = tokens_list.len();
for tokens in tokens_list.iter() {
if tokens.len() > max_size {
max_size = tokens.len();
}
if tokens.len() >= max_seq_length {
max_size = max_seq_length;
break;
}
}
let mut tensor = Tensor::zeros([batch_size, max_size]);
tensor = tensor.add_scalar(pad_token as i64);
for (index, tokens) in tokens_list.into_iter().enumerate() {
let mut seq_length = tokens.len();
let mut tokens = tokens;
if seq_length > max_seq_length {
seq_length = max_seq_lenght;
let _ = tokens.split_off(seq_length);
}
tensor = tensor.index_assign(
[index..index + 1, 0..tokens.len()],
&Tensor::from_data(Data::new(
tokens.into_iter().map(|e| e as i64).collect(),
Shape::new([1, seq_length]),
)),
);
}
let mask_pad = BoolTensor::from_int_backend(tensor.equal_scalar(pad_token as i64));
(tensor, mask_pad)
}
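To make the padding and masking behaviour above concrete, here is a framework-free sketch of what `pad_tokens` computes, using plain Vecs instead of burn tensors (illustrative only, not part of the commit):
```rust
// Sequences are truncated to max_seq_length, right-padded with the pad token up to the
// longest (possibly truncated) sequence in the batch, and the mask is true wherever the
// pad token id appears, mirroring `tensor.equal_scalar(pad_token)`.
fn pad_tokens_concept(
    pad_token: usize,
    tokens_list: Vec<Vec<usize>>,
    max_seq_length: usize,
) -> (Vec<Vec<usize>>, Vec<Vec<bool>>) {
    let max_size = tokens_list
        .iter()
        .map(|tokens| tokens.len().min(max_seq_length))
        .max()
        .unwrap_or(0);

    let mut padded = Vec::new();
    let mut mask_pad = Vec::new();

    for mut tokens in tokens_list {
        tokens.truncate(max_size);
        tokens.resize(max_size, pad_token);
        mask_pad.push(tokens.iter().map(|&token| token == pad_token).collect());
        padded.push(tokens);
    }

    (padded, mask_pad)
}

// pad_tokens_concept(0, vec![vec![5, 6, 7], vec![8]], 256)
// => ([[5, 6, 7], [8, 0, 0]], [[false, false, false], [false, true, true]])
```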

View File

@ -0,0 +1,150 @@
use burn::data::dataset::{
source::huggingface::downloader::HuggingfaceDatasetLoader, Dataset, InMemDataset,
};
#[derive(new, Clone, Debug)]
pub struct TextClassificationItem {
pub text: String,
pub label: usize,
}
pub trait TextClassificationDataset: Dataset<TextClassificationItem> {
fn num_classes() -> usize;
fn class_name(label: usize) -> String;
}
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct AgNewsItem {
pub text: String,
pub label: usize,
}
pub struct AgNewsDataset {
dataset: InMemDataset<AgNewsItem>,
}
impl Dataset<TextClassificationItem> for AgNewsDataset {
fn get(&self, index: usize) -> Option<TextClassificationItem> {
self.dataset
.get(index)
.map(|item| TextClassificationItem::new(item.text, item.label))
}
fn len(&self) -> usize {
self.dataset.len()
}
}
impl AgNewsDataset {
pub fn train() -> Self {
let dataset: InMemDataset<AgNewsItem> = HuggingfaceDatasetLoader::new("ag_news", "train")
.extract_string("text")
.extract_number("label")
.load_in_memory()
.unwrap();
Self { dataset }
}
pub fn test() -> Self {
let dataset: InMemDataset<AgNewsItem> = HuggingfaceDatasetLoader::new("ag_news", "test")
.extract_string("text")
.extract_number("label")
.load_in_memory()
.unwrap();
Self { dataset }
}
}
impl TextClassificationDataset for AgNewsDataset {
fn num_classes() -> usize {
4
}
fn class_name(label: usize) -> String {
match label {
0 => "World",
1 => "Sports",
2 => "Business",
3 => "Technology",
_ => panic!("invalid class"),
}
.to_string()
}
}
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct DbPediaItem {
pub title: String,
pub content: String,
pub label: usize,
}
pub struct DbPediaDataset {
dataset: InMemDataset<DbPediaItem>,
}
impl Dataset<TextClassificationItem> for DbPediaDataset {
fn get(&self, index: usize) -> Option<TextClassificationItem> {
self.dataset.get(index).map(|item| {
TextClassificationItem::new(
format!("Title: {} - Content: {}", item.title, item.content),
item.label,
)
})
}
fn len(&self) -> usize {
self.dataset.len()
}
}
impl DbPediaDataset {
pub fn train() -> Self {
let dataset: InMemDataset<DbPediaItem> =
HuggingfaceDatasetLoader::new("dbpedia_14", "train")
.extract_string("title")
.extract_string("content")
.extract_number("label")
.load_in_memory()
.unwrap();
Self { dataset }
}
pub fn test() -> Self {
let dataset: InMemDataset<DbPediaItem> =
HuggingfaceDatasetLoader::new("dbpedia_14", "test")
.extract_string("title")
.extract_string("content")
.extract_number("label")
.load_in_memory()
.unwrap();
Self { dataset }
}
}
impl TextClassificationDataset for DbPediaDataset {
fn num_classes() -> usize {
14
}
fn class_name(label: usize) -> String {
match label {
0 => "Company",
1 => "EducationalInstitution",
2 => "Artist",
3 => "Athlete",
4 => "OfficeHolder",
5 => "MeanOfTransportation",
6 => "Building",
7 => "NaturalPlace",
8 => "Village",
9 => "Animal",
10 => "Plant",
11 => "Album",
12 => "Film",
13 => "WrittenWork",
_ => panic!("invalid class"),
}
.to_string()
}
}
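Plugging another corpus into the pipeline only requires the two traits above. A minimal sketch with a hypothetical in-memory dataset (the struct, labels, and data source here are made up for illustration):
```rust
// Hypothetical two-class dataset backed by a plain Vec, exposing the same interface
// as AgNewsDataset and DbPediaDataset.
pub struct TinySentimentDataset {
    items: Vec<TextClassificationItem>,
}

impl Dataset<TextClassificationItem> for TinySentimentDataset {
    fn get(&self, index: usize) -> Option<TextClassificationItem> {
        self.items.get(index).cloned()
    }
    fn len(&self) -> usize {
        self.items.len()
    }
}

impl TextClassificationDataset for TinySentimentDataset {
    fn num_classes() -> usize {
        2
    }
    fn class_name(label: usize) -> String {
        match label {
            0 => "Negative",
            1 => "Positive",
            _ => panic!("invalid class"),
        }
        .to_string()
    }
}
```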

View File

@ -0,0 +1,7 @@
mod batcher;
mod dataset;
mod tokenizer;
pub use batcher::*;
pub use dataset::*;
pub use tokenizer::*;

View File

@ -0,0 +1,42 @@
pub trait Tokenizer: Send + Sync {
fn encode(&self, value: &str) -> Vec<usize>;
fn decode(&self, tokens: &[usize]) -> String;
fn vocab_size(&self) -> usize;
fn pad_token(&self) -> usize;
fn pad_token_value(&self) -> String {
self.decode(&[self.pad_token()])
}
}
pub struct BertCasedTokenizer {
tokenizer: tokenizers::Tokenizer,
}
impl Default for BertCasedTokenizer {
fn default() -> Self {
Self {
tokenizer: tokenizers::Tokenizer::from_pretrained("bert-base-cased", None).unwrap(),
}
}
}
impl Tokenizer for BertCasedTokenizer {
fn encode(&self, value: &str) -> Vec<usize> {
let tokens = self.tokenizer.encode(value, true).unwrap();
tokens.get_ids().iter().map(|t| *t as usize).collect()
}
fn decode(&self, tokens: &[usize]) -> String {
self.tokenizer
.decode(tokens.iter().map(|t| *t as u32).collect(), false)
.unwrap()
}
fn vocab_size(&self) -> usize {
self.tokenizer.get_vocab_size(true)
}
fn pad_token(&self) -> usize {
self.tokenizer.token_to_id("[PAD]").unwrap() as usize
}
}
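A quick round-trip sketch of the trait as the batcher uses it (`BertCasedTokenizer::default()` fetches the bert-base-cased vocabulary through the tokenizers crate, so this assumes network access on first run; the function below is illustrative, not part of the commit):
```rust
fn tokenizer_roundtrip() {
    let tokenizer = BertCasedTokenizer::default();
    // Encode to BERT wordpiece ids, then decode them back to (approximately) the input text.
    let tokens = tokenizer.encode("Burn is a deep learning framework.");
    let text = tokenizer.decode(&tokens);
    println!(
        "{} tokens, pad token id {}: {}",
        tokens.len(),
        tokenizer.pad_token(),
        text
    );
}
```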

View File

@ -0,0 +1,8 @@
#[macro_use]
extern crate derive_new;
mod data;
mod model;
pub mod training;
pub use data::{AgNewsDataset, DbPediaDataset};

View File

@ -0,0 +1,104 @@
use crate::data::TextClassificationBatch;
use burn::{
config::Config,
module::{Module, Param},
nn::{
transformer::{TransformerEncoder, TransformerEncoderConfig, TransformerEncoderInput},
Embedding, EmbeddingConfig, Linear, LinearConfig,
},
tensor::backend::{ADBackend, Backend},
tensor::{loss::cross_entropy_with_logits, Tensor},
train::{ClassificationOutput, TrainOutput, TrainStep, ValidStep},
};
#[derive(Config)]
pub struct TextClassificationModelConfig {
transformer: TransformerEncoderConfig,
n_classes: usize,
vocab_size: usize,
max_seq_length: usize,
}
#[derive(Module, Debug)]
pub struct TextClassificationModel<B: Backend> {
transformer: Param<TransformerEncoder<B>>,
embedding_token: Param<Embedding<B>>,
embedding_pos: Param<Embedding<B>>,
output: Param<Linear<B>>,
n_classes: usize,
max_seq_length: usize,
}
impl<B: Backend> TextClassificationModel<B> {
pub fn new(config: &TextClassificationModelConfig) -> Self {
let config_embedding_token =
EmbeddingConfig::new(config.vocab_size, config.transformer.d_model);
let config_embedding_pos =
EmbeddingConfig::new(config.max_seq_length, config.transformer.d_model);
let config_output = LinearConfig::new(config.transformer.d_model, config.n_classes);
let transformer = TransformerEncoder::new(&config.transformer);
let embedding_token = Embedding::new(&config_embedding_token);
let embedding_pos = Embedding::new(&config_embedding_pos);
let output = Linear::new(&config_output);
Self {
transformer: Param::new(transformer),
embedding_token: Param::new(embedding_token),
embedding_pos: Param::new(embedding_pos),
output: Param::new(output),
n_classes: config.n_classes,
max_seq_length: config.max_seq_length,
}
}
pub fn forward(&self, item: TextClassificationBatch<B>) -> ClassificationOutput<B> {
let [batch_size, seq_length] = item.tokens.dims();
let index_positions = Tensor::<B, 1>::arange_device(0..seq_length, item.tokens.device())
.reshape([1, seq_length])
.repeat(0, batch_size);
let embedding_positions = self.embedding_pos.forward(index_positions.detach());
let embedding_tokens = self.embedding_token.forward(item.tokens.detach());
let embedding = (embedding_positions + embedding_tokens) / 2;
let encoded = self
.transformer
.forward(TransformerEncoderInput::new(embedding).mask_pad(item.mask_pad));
let output = self.output.forward(encoded);
let output_classification = output
.index([0..batch_size, 0..1])
.reshape([batch_size, self.n_classes]);
let loss = cross_entropy_with_logits(&output_classification, &item.labels.clone().detach());
ClassificationOutput {
loss,
output: output_classification,
targets: item.labels,
}
}
}
impl<B: ADBackend> TrainStep<TextClassificationBatch<B>, ClassificationOutput<B>, B::Gradients>
for TextClassificationModel<B>
{
fn step(
&self,
item: TextClassificationBatch<B>,
) -> TrainOutput<ClassificationOutput<B>, B::Gradients> {
let item = self.forward(item);
let grads = item.loss.backward();
TrainOutput::new(grads, item)
}
}
impl<B: Backend> ValidStep<TextClassificationBatch<B>, ClassificationOutput<B>>
for TextClassificationModel<B>
{
fn step(&self, item: TextClassificationBatch<B>) -> ClassificationOutput<B> {
self.forward(item)
}
}
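For reference, building the model outside the Learner looks roughly like this; a sketch only, where the dimensions mirror the example binaries and the vocabulary size is an approximation of what `tokenizer.vocab_size()` would report for bert-base-cased:
```rust
// Sketch: training.rs below does the equivalent with values taken from the
// ExperimentConfig and the tokenizer.
fn build_model<B: Backend>() -> TextClassificationModel<B> {
    let config = TextClassificationModelConfig::new(
        TransformerEncoderConfig::new(256, 512, 4, 4), // same transformer settings as the examples
        4,      // n_classes, e.g. AG News
        28_996, // approximate bert-base-cased vocabulary size
        256,    // max_seq_length
    );
    TextClassificationModel::new(&config)
}
```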

View File

@ -0,0 +1,99 @@
use crate::{
data::{BertCasedTokenizer, TextClassificationBatcher, TextClassificationDataset, Tokenizer},
model::{TextClassificationModel, TextClassificationModelConfig},
};
use burn::{
config::Config,
data::dataloader::DataLoaderBuilder,
module::Module,
nn::transformer::TransformerEncoderConfig,
optim::{Sgd, SgdConfig},
tensor::backend::ADBackend,
train::{
metric::{AccuracyMetric, CUDAMetric, LossMetric},
LearnerBuilder,
},
};
use std::sync::Arc;
#[derive(Config)]
pub struct ExperimentConfig {
transformer: TransformerEncoderConfig,
optimizer: SgdConfig,
#[config(default = 256)]
max_seq_length: usize,
#[config(default = 32)]
batch_size: usize,
#[config(default = 10)]
num_epochs: usize,
}
pub fn train<B: ADBackend, D: TextClassificationDataset + 'static>(
device: B::Device,
dataset_train: D,
dataset_test: D,
config: ExperimentConfig,
artifact_dir: &str,
) {
let dataset_train = Arc::new(dataset_train);
let dataset_test = Arc::new(dataset_test);
let n_classes = D::num_classes();
let tokenizer = Arc::new(BertCasedTokenizer::default());
let batcher_train = Arc::new(TextClassificationBatcher::<B>::new(
tokenizer.clone(),
n_classes,
device,
config.max_seq_length,
));
let batcher_test = Arc::new(TextClassificationBatcher::<B::InnerBackend>::new(
tokenizer.clone(),
n_classes,
device,
config.max_seq_length,
));
let mut model = TextClassificationModel::new(&TextClassificationModelConfig::new(
config.transformer.clone(),
n_classes,
tokenizer.vocab_size(),
config.max_seq_length,
));
model.to_device(device);
model.detach();
let dataloader_train = DataLoaderBuilder::new(batcher_train)
.batch_size(config.batch_size)
.num_workers(8)
.shuffle(42)
.build(dataset_train);
let dataloader_test = DataLoaderBuilder::new(batcher_test)
.batch_size(config.batch_size)
.num_workers(8)
.build(dataset_test);
let optim = Sgd::new(&config.optimizer);
let learner = LearnerBuilder::new(artifact_dir)
.metric_train(CUDAMetric::new())
.metric_valid(CUDAMetric::new())
.metric_train(AccuracyMetric::new())
.metric_valid(AccuracyMetric::new())
.metric_train_plot(LossMetric::new())
.metric_valid_plot(LossMetric::new())
.with_file_checkpointer::<f32>(2)
.num_epochs(config.num_epochs)
.build(model, optim);
let model_trained = learner.fit(dataloader_train, dataloader_test);
config
.save(&format!("{}/config.json", artifact_dir))
.unwrap();
model_trained
.state()
.convert::<f32>()
.save(&format!("{}/model.json.gz", artifact_dir))
.unwrap();
}