This commit is contained in:
Nathaniel Simard 2023-09-04 20:55:31 -04:00 committed by GitHub
parent b58af4a4a3
commit c484999d54
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 3 additions and 2 deletions

View File

@ -1,4 +1,5 @@
// #![allow(clippy::single_range_in_vec_init)]
#![allow(clippy::single_range_in_vec_init)]
use crate as burn;
use crate::{config::Config, module::Module};
@ -94,7 +95,7 @@ impl<B: Backend> BinaryCrossEntropyLoss<B> {
match &self.weights {
Some(weights) => {
let loss = loss * weights.clone().select(0, Tensor::from_ints([0]));
let loss = loss * weights.clone().slice([0..1]);
let weights = weights.clone().gather(0, targets);
loss.neg() / weights
}

Binary file not shown.