mirror of https://github.com/tracel-ai/burn.git
feat: cross entropy loss (#130)
This commit is contained in:
parent 1a1d86dc3e
commit 3a9dfe6097
@@ -158,7 +158,8 @@ impl<P: std::fmt::Debug, const D: usize> Data<P, D>
 where
     P: Zeros + Default,
 {
-    pub fn zeros(shape: Shape<D>) -> Data<P, D> {
+    pub fn zeros<S: Into<Shape<D>>>(shape: S) -> Data<P, D> {
+        let shape = shape.into();
         let elem = P::default();
         let num_elements = shape.num_elements();
         let mut data = Vec::with_capacity(num_elements);
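The new bound lets callers pass anything convertible into a `Shape<D>` — in particular a plain `[usize; D]` array — instead of constructing the `Shape` first. A minimal standalone sketch of the pattern (plain Rust with hypothetical `Shape`/`zeros` stand-ins, not burn's actual types):

    // Stand-in for burn's Shape<D>, just to show the Into-based signature.
    struct Shape<const D: usize> {
        dims: [usize; D],
    }

    impl<const D: usize> From<[usize; D]> for Shape<D> {
        fn from(dims: [usize; D]) -> Self {
            Shape { dims }
        }
    }

    // Same trick as the new Data::zeros: accept any S that converts into Shape<D>.
    fn zeros<const D: usize, S: Into<Shape<D>>>(shape: S) -> Vec<f32> {
        let shape = shape.into();
        vec![0.0; shape.dims.iter().product()]
    }

    fn main() {
        let a = zeros(Shape { dims: [2, 3] }); // explicit Shape still works
        let b = zeros([2, 3]);                 // a plain array now works too
        assert_eq!(a.len(), b.len());
    }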
@@ -0,0 +1,90 @@
+use burn_tensor::{backend::Backend, loss::cross_entropy_with_logits, Tensor};
+
+/// Calculate the cross entropy loss from the input logits and the targets.
+pub struct CrossEntropyLoss<B: Backend> {
+    num_targets: usize,
+    pad_index: Option<usize>,
+    _b: B,
+}
+
+impl<B: Backend> CrossEntropyLoss<B> {
+    /// Create the criterion.
+    ///
+    /// # Notes
+    ///
+    /// The number of targets must be specified; this corresponds to the number of classes in a
+    /// classification task. A padding index can also be specified.
+    pub fn new(num_targets: usize, pad_index: Option<usize>) -> Self {
+        Self {
+            num_targets,
+            pad_index,
+            _b: B::default(),
+        }
+    }
+
+    /// Compute the criterion on the input tensor.
+    ///
+    /// # Shapes
+    ///
+    /// - logits: [batch_size, num_targets]
+    /// - targets: [batch_size]
+    pub fn forward(
+        &self,
+        logits: &Tensor<B, 2>,
+        targets: &Tensor<B::IntegerBackend, 1>,
+    ) -> Tensor<B, 1> {
+        let device = logits.device();
+        let [batch_size] = targets.dims();
+        let indexes = targets.to_data();
+
+        let mut targets_logits =
+            Tensor::<B, 2>::zeros_device([batch_size, self.num_targets], device);
+
+        for b in 0..batch_size {
+            let index = indexes.value[b] as usize;
+            if let Some(pad_index) = self.pad_index {
+                if index == pad_index {
+                    continue;
+                }
+            }
+
+            targets_logits = targets_logits.index_assign(
+                [b..b + 1, index..index + 1],
+                &Tensor::ones_device([1, 1], device),
+            );
+        }
+
+        cross_entropy_with_logits(logits, &targets_logits.detach())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::TestBackend;
+    use burn_tensor::{Data, Distribution};
+
+    #[test]
+    fn test_cross_entropy_loss() {
+        let [batch_size, num_targets] = [4, 5];
+        let logits = Tensor::<TestBackend, 2>::random(
+            [batch_size, num_targets],
+            Distribution::Normal(0., 1.0),
+        );
+        let targets =
+            Tensor::<<TestBackend as Backend>::IntegerBackend, 1>::from_data(Data::from([
+                2, 0, 4, 1_i64,
+            ]));
+        let targets_logits = Tensor::<TestBackend, 2>::from_data(Data::from([
+            [0.0, 0.0, 1.0, 0.0, 0.0],
+            [1.0, 0.0, 0.0, 0.0, 0.0],
+            [0.0, 0.0, 0.0, 0.0, 1.0],
+            [0.0, 1.0, 0.0, 0.0, 0.0],
+        ]));
+
+        let loss_1 = CrossEntropyLoss::new(5, None).forward(&logits, &targets);
+        let loss_2 = cross_entropy_with_logits(&logits, &targets_logits);
+
+        loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3);
+    }
+}
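A quick usage sketch of the new criterion outside the test, built only from the API shown above and reusing the test's imports; the `Some(0)` padding variant is an illustration of the documented skip behaviour, not taken from this commit:

    let logits = Tensor::<TestBackend, 2>::random([4, 5], Distribution::Normal(0., 1.0));
    let targets = Tensor::<<TestBackend as Backend>::IntegerBackend, 1>::from_data(
        Data::from([2, 0, 4, 1_i64]),
    );

    // No padding: every target contributes a one-hot row.
    let loss = CrossEntropyLoss::new(5, None).forward(&logits, &targets);

    // With a pad index: targets equal to 0 keep an all-zero row, so those
    // positions contribute nothing to the cross-entropy sum.
    let loss_padded = CrossEntropyLoss::new(5, Some(0)).forward(&logits, &targets);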
@@ -0,0 +1,3 @@
+mod cross_entropy;
+
+pub use cross_entropy::*;
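With the module wired into `nn` in the next hunk, downstream code imports the criterion at the flat path used throughout the rest of this commit:

    use burn::nn::loss::CrossEntropyLoss;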
@@ -1,5 +1,6 @@
 pub mod attention;
 pub mod cache;
+pub mod loss;
 pub mod transformer;

 mod dropout;
@@ -6,7 +6,7 @@ use burn_tensor::Tensor;
 pub struct ClassificationOutput<B: Backend> {
     pub loss: Tensor<B, 1>,
     pub output: Tensor<B, 2>,
-    pub targets: Tensor<B, 2>,
+    pub targets: Tensor<B::IntegerBackend, 1>,
 }

 impl<B: Backend> metric::Metric<ClassificationOutput<B>> for metric::LossMetric {
@@ -24,6 +24,6 @@ impl<B: Backend> metric::Metric<ClassificationOutput<B>> for metric::AccuracyMet
 }

     fn clear(&mut self) {
-        <metric::AccuracyMetric as metric::Metric<(Tensor<B, 2>, Tensor<B, 2>)>>::clear(self);
+        <metric::AccuracyMetric as metric::Metric<(Tensor<B, 2>, Tensor<B::IntegerBackend, 1>)>>::clear(self);
     }
 }
@@ -31,19 +31,18 @@ impl Numeric for AccuracyMetric {
     }
 }

-impl<B: Backend> Metric<(Tensor<B, 2>, Tensor<B, 2>)> for AccuracyMetric {
-    fn update(&mut self, batch: &(Tensor<B, 2>, Tensor<B, 2>)) -> MetricStateDyn {
+impl<B: Backend> Metric<(Tensor<B, 2>, Tensor<B::IntegerBackend, 1>)> for AccuracyMetric {
+    fn update(&mut self, batch: &(Tensor<B, 2>, Tensor<B::IntegerBackend, 1>)) -> MetricStateDyn {
         let (outputs, targets) = batch;
-        let logits_outputs = outputs.argmax(1).to_device(B::Device::default());
-        let logits_targets = targets.argmax(1).to_device(B::Device::default());
-        let count_current = logits_targets.shape().dims[0];
+        let count_current = outputs.dims()[0];
+        let targets = targets.to_device(B::Device::default());
+        let outputs = outputs
+            .argmax(1)
+            .to_device(B::Device::default())
+            .reshape([count_current]);

-        let total_current = logits_outputs
-            .equal(&logits_targets)
-            .to_int()
-            .sum()
-            .to_data()
-            .value[0] as usize;
+        let total_current = outputs.equal(&targets).to_int().sum().to_data().value[0] as usize;

         self.count += count_current;
         self.total += total_current;
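The metric now takes integer class targets directly and only argmaxes the model outputs, instead of argmax-ing a one-hot target tensor too. A backend-free sketch of the same count/compare logic on plain vectors (hypothetical helper, not part of burn):

    /// Fraction of positions where the predicted class index matches the target.
    fn accuracy(predictions: &[i64], targets: &[i64]) -> f64 {
        let count_current = targets.len();
        let total_current = predictions
            .iter()
            .zip(targets.iter())
            .filter(|(p, t)| p == t)
            .count();
        total_current as f64 / count_current as f64
    }

    // accuracy(&[2, 0, 4, 1], &[2, 0, 3, 1]) == 0.75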
@@ -10,7 +10,7 @@ pub struct MNISTBatcher<B: Backend> {
 #[derive(Clone, Debug)]
 pub struct MNISTBatch<B: Backend> {
     pub images: Tensor<B, 2>,
-    pub targets: Tensor<B, 2>,
+    pub targets: Tensor<B::IntegerBackend, 1>,
 }

 impl<B: Backend> MNISTBatcher<B> {
@@ -31,7 +31,7 @@ impl<B: Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {

         let targets = items
             .iter()
-            .map(|item| Tensor::<B, 2>::one_hot(item.label, 10))
+            .map(|item| Tensor::<B::IntegerBackend, 1>::from_data(Data::from([item.label as i64])))
             .collect();

         let images = Tensor::cat(images, 0).to_device(self.device).detach();
@@ -5,11 +5,10 @@ use crate::{
 use burn::{
     config::Config,
     module::{Module, Param},
-    nn,
+    nn::{self, loss::CrossEntropyLoss},
     optim::SgdConfig,
     tensor::{
         backend::{ADBackend, Backend},
-        loss::cross_entropy_with_logits,
         Tensor,
     },
     train::{ClassificationOutput, TrainOutput, TrainStep, ValidStep},
@@ -34,6 +33,7 @@ pub struct Model<B: Backend> {
     mlp: Param<Mlp<B>>,
     input: Param<nn::Linear<B>>,
     output: Param<nn::Linear<B>>,
+    num_classes: usize,
 }

 impl<B: Backend> Model<B> {
@@ -46,6 +46,7 @@ impl<B: Backend> Model<B> {
             mlp: Param::new(mlp),
             output: Param::new(output),
             input: Param::new(input),
+            num_classes,
         }
     }

@@ -62,7 +63,8 @@ impl<B: Backend> Model<B> {
     pub fn forward_classification(&self, item: MNISTBatch<B>) -> ClassificationOutput<B> {
         let targets = item.targets;
         let output = self.forward(item.images);
-        let loss = cross_entropy_with_logits(&output, &targets);
+        let loss = CrossEntropyLoss::new(self.num_classes, None);
+        let loss = loss.forward(&output, &targets);

         ClassificationOutput {
             loss,
@@ -2,14 +2,13 @@ use super::{dataset::TextClassificationItem, tokenizer::Tokenizer};
 use burn::{
     data::dataloader::batcher::Batcher,
     nn::attention::generate_padding_mask,
-    tensor::{backend::Backend, BoolTensor, Tensor},
+    tensor::{backend::Backend, BoolTensor, Data, Tensor},
 };
 use std::sync::Arc;

 #[derive(new)]
 pub struct TextClassificationBatcher<B: Backend> {
     tokenizer: Arc<dyn Tokenizer>,
-    num_classes: usize,
     device: B::Device,
     max_seq_lenght: usize,
 }
@@ -17,7 +16,7 @@ pub struct TextClassificationBatcher<B: Backend> {
 #[derive(Debug, Clone, new)]
 pub struct TextClassificationBatch<B: Backend> {
     pub tokens: Tensor<B::IntegerBackend, 2>,
-    pub labels: Tensor<B, 2>,
+    pub labels: Tensor<B::IntegerBackend, 1>,
     pub mask_pad: BoolTensor<B, 2>,
 }

@@ -30,7 +29,7 @@ impl<B: Backend> Batcher<TextClassificationItem, TextClassificationBatch<B>>

         for item in items {
             tokens_list.push(self.tokenizer.encode(&item.text));
-            labels_list.push(Tensor::one_hot(item.label, self.num_classes));
+            labels_list.push(Tensor::from_data(Data::from([item.label as i64])));
         }

         let mask = generate_padding_mask(
@@ -3,11 +3,12 @@ use burn::{
     config::Config,
     module::{Module, Param},
     nn::{
+        loss::CrossEntropyLoss,
         transformer::{TransformerEncoder, TransformerEncoderConfig, TransformerEncoderInput},
         Embedding, EmbeddingConfig, Linear, LinearConfig,
     },
     tensor::backend::{ADBackend, Backend},
-    tensor::{loss::cross_entropy_with_logits, Tensor},
+    tensor::Tensor,
     train::{ClassificationOutput, TrainOutput, TrainStep, ValidStep},
 };

@@ -77,7 +78,8 @@ impl<B: Backend> TextClassificationModel<B> {
             .index([0..batch_size, 0..1])
             .reshape([batch_size, self.n_classes]);

-        let loss = cross_entropy_with_logits(&output_classification, &labels.clone().detach());
+        let loss = CrossEntropyLoss::new(self.n_classes, None);
+        let loss = loss.forward(&output_classification, &labels);

         ClassificationOutput {
             loss,
@@ -42,13 +42,11 @@ pub fn train<B: ADBackend, D: TextClassificationDataset + 'static>(
     let tokenizer = Arc::new(BertCasedTokenizer::default());
     let batcher_train = Arc::new(TextClassificationBatcher::<B>::new(
         tokenizer.clone(),
-        n_classes,
         device,
         config.max_seq_length,
     ));
     let batcher_test = Arc::new(TextClassificationBatcher::<B::InnerBackend>::new(
         tokenizer.clone(),
-        n_classes,
         device,
         config.max_seq_length,
     ));
@@ -9,8 +9,6 @@ use std::sync::Arc;
 #[derive(new)]
 pub struct TextGenerationBatcher {
     tokenizer: Arc<dyn Tokenizer>,
-    vocab_size: usize,
-    pad_token: usize,
     max_seq_lenght: usize,
 }

@@ -23,7 +21,7 @@ pub struct TextGenerationBatch<B: Backend> {
 #[derive(Debug, Clone, new)]
 pub struct TrainingTextGenerationBatch<B: Backend> {
     pub tokens_inputs: Tensor<B::IntegerBackend, 2>,
-    pub targets: Tensor<B, 2>,
+    pub targets: Tensor<B::IntegerBackend, 2>,
     pub mask_pad: BoolTensor<B, 2>,
 }

@@ -60,21 +58,6 @@ impl<B: Backend> Batcher<TextGenerationItem, TrainingTextGenerationBatch<B>>
         let targets = item.tokens.index([0..batch_size, 1..seq_length]);
         let mask_pad = item.mask_pad.index([0..batch_size, 0..seq_length - 1]);

-        let seq_length = seq_length - 1;
-
-        let targets = targets
-            .reshape([batch_size * seq_length])
-            .to_data()
-            .value
-            .iter()
-            .map(|index| match *index as usize == self.pad_token {
-                true => Tensor::<B, 2>::zeros([1, self.vocab_size]),
-                false => Tensor::<B, 2>::one_hot(*index as usize, self.vocab_size),
-            })
-            .collect();
-
-        let targets = Tensor::cat(targets, 0);
-
         TrainingTextGenerationBatch::new(inputs, targets, mask_pad)
     }
 }
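The deleted block is the one-hot/pad handling that `CrossEntropyLoss` now performs internally (see the `pad_index` check in its `forward`). For reference, a plain-Rust sketch of what the removed construction computed (illustrative helper, no burn types):

    /// One-hot row per target index; pad positions keep an all-zero row.
    fn one_hot_rows(targets: &[usize], vocab_size: usize, pad_token: usize) -> Vec<Vec<f32>> {
        targets
            .iter()
            .map(|&index| {
                let mut row = vec![0.0; vocab_size];
                if index != pad_token {
                    row[index] = 1.0;
                }
                row
            })
            .collect()
    }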
@@ -4,11 +4,12 @@ use burn::{
     module::{Module, Param},
     nn::{
         attention::generate_autoregressive_mask,
+        loss::CrossEntropyLoss,
         transformer::{TransformerEncoder, TransformerEncoderConfig, TransformerEncoderInput},
         Embedding, EmbeddingConfig, Linear, LinearConfig,
     },
     tensor::backend::{ADBackend, Backend},
-    tensor::{loss::cross_entropy_with_logits, Tensor},
+    tensor::Tensor,
     train::{ClassificationOutput, TrainOutput, TrainStep, ValidStep},
 };

@@ -83,15 +84,16 @@ impl<B: Backend> TextClassificationModel<B> {
         );

         let output = self.output.forward(encoded);
-        let output_classification = output.reshape([batch_size * seq_length, self.vocab_size]);
-        let targets = item.targets.to_device(device).detach();
+        let output_flatten = output.reshape([batch_size * seq_length, self.vocab_size]);
+        let targets_flatten = item.targets.reshape([batch_size * seq_length]);

-        let loss = cross_entropy_with_logits(&output_classification, &targets);
+        let loss = CrossEntropyLoss::new(self.vocab_size, Some(self.pad_token));
+        let loss = loss.forward(&output_flatten, &targets_flatten);

         ClassificationOutput {
             loss,
-            output: output_classification,
-            targets,
+            output: output_flatten,
+            targets: targets_flatten,
         }
     }
 }
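Flattening turns every time step into an independent classification example: [batch_size, seq_length, vocab_size] logits become [batch_size * seq_length, vocab_size] rows, and the integer targets flatten to [batch_size * seq_length] to match, with pad positions then skipped by the loss's pad_index. A small index-arithmetic sketch (plain Rust, illustrative names):

    /// Row index of (batch b, step s) after flattening [batch, seq] to [batch * seq].
    fn flat_index(b: usize, s: usize, seq_length: usize) -> usize {
        b * seq_length + s
    }

    fn main() {
        let (batch_size, seq_length) = (2, 3);
        assert_eq!(flat_index(0, 0, seq_length), 0);
        assert_eq!(flat_index(1, 2, seq_length), 5);
        assert_eq!(batch_size * seq_length, 6); // 6 classification rows in total
    }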
@@ -42,14 +42,10 @@ pub fn train<B: ADBackend, D: Dataset<TextGenerationItem> + 'static>(
     let tokenizer = Arc::new(Gpt2Tokenizer::default());
     let batcher_train = Arc::new(TextGenerationBatcher::new(
         tokenizer.clone(),
-        tokenizer.vocab_size(),
-        tokenizer.pad_token(),
         config.max_seq_length,
     ));
     let batcher_test = Arc::new(TextGenerationBatcher::new(
         tokenizer.clone(),
-        tokenizer.vocab_size(),
-        tokenizer.pad_token(),
         config.max_seq_length,
     ));