fix(book): add missing second parameter to CrosEntropyLoss constructor (#1301)

* fix(book): add missing second parameter to CrosEntropyLoss constructor

CrossEntropyLoss::new() expects two parameters, the pad_index and the device

* fix: fix missing closing parenthese
This commit is contained in:
Jakub 2024-02-15 15:46:41 +01:00 committed by GitHub
parent 00b6c7d136
commit 3592f3799a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 1 additions and 1 deletions

View File

@ -22,7 +22,7 @@ impl<B: Backend> Model<B> {
targets: Tensor<B, 1, Int>,
) -> ClassificationOutput<B> {
let output = self.forward(images);
let loss = CrossEntropyLoss::new(None).forward(output.clone(), targets.clone());
let loss = CrossEntropyLoss::new(None, &output.device()).forward(output.clone(), targets.clone());
ClassificationOutput::new(loss, output, targets)
}