feat: cross entropy loss (#130)

This commit is contained in:
Nathaniel Simard 2022-12-25 10:10:22 -05:00 committed by GitHub
parent 1a1d86dc3e
commit 3a9dfe6097
14 changed files with 131 additions and 55 deletions

View File

@ -158,7 +158,8 @@ impl<P: std::fmt::Debug, const D: usize> Data<P, D>
where
P: Zeros + Default,
{
pub fn zeros(shape: Shape<D>) -> Data<P, D> {
pub fn zeros<S: Into<Shape<D>>>(shape: S) -> Data<P, D> {
let shape = shape.into();
let elem = P::default();
let num_elements = shape.num_elements();
let mut data = Vec::with_capacity(num_elements);

View File

@ -0,0 +1,90 @@
use burn_tensor::{backend::Backend, loss::cross_entropy_with_logits, Tensor};
/// Calculate the cross entropy loss from the input logits and the targets.
pub struct CrossEntropyLoss<B: Backend> {
num_targets: usize,
pad_index: Option<usize>,
_b: B,
}
impl<B: Backend> CrossEntropyLoss<B> {
/// Create the criterion.
///
/// # Notes
///
/// The number of targets must be specified; it corresponds to the number of classes in a
/// classification task. A padding index can also be specified.
pub fn new(num_targets: usize, pad_index: Option<usize>) -> Self {
Self {
num_targets,
pad_index,
_b: B::default(),
}
}
/// Compute the criterion on the input tensor.
///
/// # Shapes
///
/// - logits: [batch_size, num_targets]
/// - targets: [batch_size]
pub fn forward(
&self,
logits: &Tensor<B, 2>,
targets: &Tensor<B::IntegerBackend, 1>,
) -> Tensor<B, 1> {
let device = logits.device();
let [batch_size] = targets.dims();
let indexes = targets.to_data();
let mut targets_logits =
Tensor::<B, 2>::zeros_device([batch_size, self.num_targets], device);
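// Build the one-hot target distribution row by row; positions whose target equals
// pad_index are skipped and left as all zeros.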
for b in 0..batch_size {
let index = indexes.value[b] as usize;
if let Some(pad_index) = self.pad_index {
if index == pad_index {
continue;
}
}
targets_logits = targets_logits.index_assign(
[b..b + 1, index..index + 1],
&Tensor::ones_device([1, 1], device),
);
}
cross_entropy_with_logits(logits, &targets_logits.detach())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use burn_tensor::{Data, Distribution};
#[test]
fn test_cross_entropy_loss() {
let [batch_size, num_targets] = [4, 5];
let logits = Tensor::<TestBackend, 2>::random(
[batch_size, num_targets],
Distribution::Normal(0., 1.0),
);
let targets =
Tensor::<<TestBackend as Backend>::IntegerBackend, 1>::from_data(Data::from([
2, 0, 4, 1_i64,
]));
let targets_logits = Tensor::<TestBackend, 2>::from_data(Data::from([
[0.0, 0.0, 1.0, 0.0, 0.0],
[1.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 1.0],
[0.0, 1.0, 0.0, 0.0, 0.0],
]));
let loss_1 = CrossEntropyLoss::new(5, None).forward(&logits, &targets);
let loss_2 = cross_entropy_with_logits(&logits, &targets_logits);
loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3);
}
}
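
As a usage note, the padding behaviour can be checked the same way as the test above: a target equal to pad_index leaves its row of the one-hot matrix at zero, so the criterion matches cross_entropy_with_logits against a target distribution with that row zeroed. A minimal sketch, assuming the same TestBackend setup and imports as the test above (shapes and values are hypothetical):

let logits = Tensor::<TestBackend, 2>::random([3, 4], Distribution::Normal(0., 1.0));
let targets = Tensor::<<TestBackend as Backend>::IntegerBackend, 1>::from_data(Data::from([2, 0, 3_i64]));
// Row 1 corresponds to the padded target (pad_index = 0) and stays all zeros.
let targets_logits = Tensor::<TestBackend, 2>::from_data(Data::from([
    [0.0, 0.0, 1.0, 0.0],
    [0.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 1.0],
]));
let loss_pad = CrossEntropyLoss::new(4, Some(0)).forward(&logits, &targets);
let loss_ref = cross_entropy_with_logits(&logits, &targets_logits);
loss_pad.into_data().assert_approx_eq(&loss_ref.into_data(), 3);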

burn/src/nn/loss/mod.rs Normal file
View File

@ -0,0 +1,3 @@
mod cross_entropy;
pub use cross_entropy::*;

View File

@ -1,5 +1,6 @@
pub mod attention;
pub mod cache;
pub mod loss;
pub mod transformer;
mod dropout;

View File

@ -6,7 +6,7 @@ use burn_tensor::Tensor;
pub struct ClassificationOutput<B: Backend> {
pub loss: Tensor<B, 1>,
pub output: Tensor<B, 2>,
pub targets: Tensor<B, 2>,
pub targets: Tensor<B::IntegerBackend, 1>,
}
impl<B: Backend> metric::Metric<ClassificationOutput<B>> for metric::LossMetric {
@ -24,6 +24,6 @@ impl<B: Backend> metric::Metric<ClassificationOutput<B>> for metric::AccuracyMet
}
fn clear(&mut self) {
<metric::AccuracyMetric as metric::Metric<(Tensor<B, 2>, Tensor<B, 2>)>>::clear(self);
<metric::AccuracyMetric as metric::Metric<(Tensor<B, 2>, Tensor<B::IntegerBackend, 1>)>>::clear(self);
}
}

View File

@ -31,19 +31,18 @@ impl Numeric for AccuracyMetric {
}
}
impl<B: Backend> Metric<(Tensor<B, 2>, Tensor<B, 2>)> for AccuracyMetric {
fn update(&mut self, batch: &(Tensor<B, 2>, Tensor<B, 2>)) -> MetricStateDyn {
impl<B: Backend> Metric<(Tensor<B, 2>, Tensor<B::IntegerBackend, 1>)> for AccuracyMetric {
fn update(&mut self, batch: &(Tensor<B, 2>, Tensor<B::IntegerBackend, 1>)) -> MetricStateDyn {
let (outputs, targets) = batch;
let logits_outputs = outputs.argmax(1).to_device(B::Device::default());
let logits_targets = targets.argmax(1).to_device(B::Device::default());
let count_current = logits_targets.shape().dims[0];
let count_current = outputs.dims()[0];
let total_current = logits_outputs
.equal(&logits_targets)
.to_int()
.sum()
.to_data()
.value[0] as usize;
let targets = targets.to_device(B::Device::default());
let outputs = outputs
.argmax(1)
.to_device(B::Device::default())
.reshape([count_current]);
let total_current = outputs.equal(&targets).to_int().sum().to_data().value[0] as usize;
self.count += count_current;
self.total += total_current;

View File

@ -10,7 +10,7 @@ pub struct MNISTBatcher<B: Backend> {
#[derive(Clone, Debug)]
pub struct MNISTBatch<B: Backend> {
pub images: Tensor<B, 2>,
pub targets: Tensor<B, 2>,
pub targets: Tensor<B::IntegerBackend, 1>,
}
impl<B: Backend> MNISTBatcher<B> {
@ -31,7 +31,7 @@ impl<B: Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
let targets = items
.iter()
.map(|item| Tensor::<B, 2>::one_hot(item.label, 10))
.map(|item| Tensor::<B::IntegerBackend, 1>::from_data(Data::from([item.label as i64])))
.collect();
let images = Tensor::cat(images, 0).to_device(self.device).detach();

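The batcher now emits raw class indices instead of one-hot rows. For a hypothetical label of 7 with 10 classes, the two encodings compare as follows (a sketch for illustration only, generic over the same B: Backend, not part of the diff):

// Old encoding: a one-hot row of shape [1, 10] with 1.0 at position 7.
let target_old = Tensor::<B, 2>::one_hot(7, 10);
// New encoding: the class index itself, an integer tensor of shape [1].
let target_new = Tensor::<B::IntegerBackend, 1>::from_data(Data::from([7_i64]));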
View File

@ -5,11 +5,10 @@ use crate::{
use burn::{
config::Config,
module::{Module, Param},
nn,
nn::{self, loss::CrossEntropyLoss},
optim::SgdConfig,
tensor::{
backend::{ADBackend, Backend},
loss::cross_entropy_with_logits,
Tensor,
},
train::{ClassificationOutput, TrainOutput, TrainStep, ValidStep},
@ -34,6 +33,7 @@ pub struct Model<B: Backend> {
mlp: Param<Mlp<B>>,
input: Param<nn::Linear<B>>,
output: Param<nn::Linear<B>>,
num_classes: usize,
}
impl<B: Backend> Model<B> {
@ -46,6 +46,7 @@ impl<B: Backend> Model<B> {
mlp: Param::new(mlp),
output: Param::new(output),
input: Param::new(input),
num_classes,
}
}
@ -62,7 +63,8 @@ impl<B: Backend> Model<B> {
pub fn forward_classification(&self, item: MNISTBatch<B>) -> ClassificationOutput<B> {
let targets = item.targets;
let output = self.forward(item.images);
let loss = cross_entropy_with_logits(&output, &targets);
let loss = CrossEntropyLoss::new(self.num_classes, None);
let loss = loss.forward(&output, &targets);
ClassificationOutput {
loss,

View File

@ -2,14 +2,13 @@ use super::{dataset::TextClassificationItem, tokenizer::Tokenizer};
use burn::{
data::dataloader::batcher::Batcher,
nn::attention::generate_padding_mask,
tensor::{backend::Backend, BoolTensor, Tensor},
tensor::{backend::Backend, BoolTensor, Data, Tensor},
};
use std::sync::Arc;
#[derive(new)]
pub struct TextClassificationBatcher<B: Backend> {
tokenizer: Arc<dyn Tokenizer>,
num_classes: usize,
device: B::Device,
max_seq_lenght: usize,
}
@ -17,7 +16,7 @@ pub struct TextClassificationBatcher<B: Backend> {
#[derive(Debug, Clone, new)]
pub struct TextClassificationBatch<B: Backend> {
pub tokens: Tensor<B::IntegerBackend, 2>,
pub labels: Tensor<B, 2>,
pub labels: Tensor<B::IntegerBackend, 1>,
pub mask_pad: BoolTensor<B, 2>,
}
@ -30,7 +29,7 @@ impl<B: Backend> Batcher<TextClassificationItem, TextClassificationBatch<B>>
for item in items {
tokens_list.push(self.tokenizer.encode(&item.text));
labels_list.push(Tensor::one_hot(item.label, self.num_classes));
labels_list.push(Tensor::from_data(Data::from([item.label as i64])));
}
let mask = generate_padding_mask(

View File

@ -3,11 +3,12 @@ use burn::{
config::Config,
module::{Module, Param},
nn::{
loss::CrossEntropyLoss,
transformer::{TransformerEncoder, TransformerEncoderConfig, TransformerEncoderInput},
Embedding, EmbeddingConfig, Linear, LinearConfig,
},
tensor::backend::{ADBackend, Backend},
tensor::{loss::cross_entropy_with_logits, Tensor},
tensor::Tensor,
train::{ClassificationOutput, TrainOutput, TrainStep, ValidStep},
};
@ -77,7 +78,8 @@ impl<B: Backend> TextClassificationModel<B> {
.index([0..batch_size, 0..1])
.reshape([batch_size, self.n_classes]);
let loss = cross_entropy_with_logits(&output_classification, &labels.clone().detach());
let loss = CrossEntropyLoss::new(self.n_classes, None);
let loss = loss.forward(&output_classification, &labels);
ClassificationOutput {
loss,

View File

@ -42,13 +42,11 @@ pub fn train<B: ADBackend, D: TextClassificationDataset + 'static>(
let tokenizer = Arc::new(BertCasedTokenizer::default());
let batcher_train = Arc::new(TextClassificationBatcher::<B>::new(
tokenizer.clone(),
n_classes,
device,
config.max_seq_length,
));
let batcher_test = Arc::new(TextClassificationBatcher::<B::InnerBackend>::new(
tokenizer.clone(),
n_classes,
device,
config.max_seq_length,
));

View File

@ -9,8 +9,6 @@ use std::sync::Arc;
#[derive(new)]
pub struct TextGenerationBatcher {
tokenizer: Arc<dyn Tokenizer>,
vocab_size: usize,
pad_token: usize,
max_seq_lenght: usize,
}
@ -23,7 +21,7 @@ pub struct TextGenerationBatch<B: Backend> {
#[derive(Debug, Clone, new)]
pub struct TrainingTextGenerationBatch<B: Backend> {
pub tokens_inputs: Tensor<B::IntegerBackend, 2>,
pub targets: Tensor<B, 2>,
pub targets: Tensor<B::IntegerBackend, 2>,
pub mask_pad: BoolTensor<B, 2>,
}
@ -60,21 +58,6 @@ impl<B: Backend> Batcher<TextGenerationItem, TrainingTextGenerationBatch<B>>
let targets = item.tokens.index([0..batch_size, 1..seq_length]);
let mask_pad = item.mask_pad.index([0..batch_size, 0..seq_length - 1]);
let seq_length = seq_length - 1;
let targets = targets
.reshape([batch_size * seq_length])
.to_data()
.value
.iter()
.map(|index| match *index as usize == self.pad_token {
true => Tensor::<B, 2>::zeros([1, self.vocab_size]),
false => Tensor::<B, 2>::one_hot(*index as usize, self.vocab_size),
})
.collect();
let targets = Tensor::cat(targets, 0);
TrainingTextGenerationBatch::new(inputs, targets, mask_pad)
}
}

View File

@ -4,11 +4,12 @@ use burn::{
module::{Module, Param},
nn::{
attention::generate_autoregressive_mask,
loss::CrossEntropyLoss,
transformer::{TransformerEncoder, TransformerEncoderConfig, TransformerEncoderInput},
Embedding, EmbeddingConfig, Linear, LinearConfig,
},
tensor::backend::{ADBackend, Backend},
tensor::{loss::cross_entropy_with_logits, Tensor},
tensor::Tensor,
train::{ClassificationOutput, TrainOutput, TrainStep, ValidStep},
};
@ -83,15 +84,16 @@ impl<B: Backend> TextClassificationModel<B> {
);
let output = self.output.forward(encoded);
let output_classification = output.reshape([batch_size * seq_length, self.vocab_size]);
let targets = item.targets.to_device(device).detach();
let output_flatten = output.reshape([batch_size * seq_length, self.vocab_size]);
let targets_flatten = item.targets.reshape([batch_size * seq_length]);
let loss = cross_entropy_with_logits(&output_classification, &targets);
let loss = CrossEntropyLoss::new(self.vocab_size, Some(self.pad_token));
let loss = loss.forward(&output_flatten, &targets_flatten);
ClassificationOutput {
loss,
output: output_classification,
targets,
output: output_flatten,
targets: targets_flatten,
}
}
}
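
The criterion is now applied over flattened positions: logits are reshaped from [batch_size, seq_length, vocab_size] to [batch_size * seq_length, vocab_size], the integer targets from [batch_size, seq_length] to [batch_size * seq_length], and positions holding pad_token contribute an all-zero target row. A shape sketch with hypothetical sizes (batch_size = 2, seq_length = 3, vocab_size = 100), reusing the calls from the diff above:

// Flatten logits and integer targets so every token position becomes one classification sample.
let output_flatten = output.reshape([2 * 3, 100]);     // [6, 100]
let targets_flatten = item.targets.reshape([2 * 3]);   // [6]
// Positions equal to pad_token are skipped when the one-hot targets are built inside the criterion.
let loss = CrossEntropyLoss::new(100, Some(self.pad_token));
let loss = loss.forward(&output_flatten, &targets_flatten);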

View File

@ -42,14 +42,10 @@ pub fn train<B: ADBackend, D: Dataset<TextGenerationItem> + 'static>(
let tokenizer = Arc::new(Gpt2Tokenizer::default());
let batcher_train = Arc::new(TextGenerationBatcher::new(
tokenizer.clone(),
tokenizer.vocab_size(),
tokenizer.pad_token(),
config.max_seq_length,
));
let batcher_test = Arc::new(TextGenerationBatcher::new(
tokenizer.clone(),
tokenizer.vocab_size(),
tokenizer.pad_token(),
config.max_seq_length,
));