mirror of https://github.com/tracel-ai/burn.git
fix(book): add missing second parameter to CrosEntropyLoss constructor (#1301)
* fix(book): add missing second parameter to CrosEntropyLoss constructor CrossEntropyLoss::new() expects two parameters, the pad_index and the device * fix: fix missing closing parenthese
This commit is contained in:
parent
00b6c7d136
commit
3592f3799a
|
@ -22,7 +22,7 @@ impl<B: Backend> Model<B> {
|
|||
targets: Tensor<B, 1, Int>,
|
||||
) -> ClassificationOutput<B> {
|
||||
let output = self.forward(images);
|
||||
let loss = CrossEntropyLoss::new(None).forward(output.clone(), targets.clone());
|
||||
let loss = CrossEntropyLoss::new(None, &output.device()).forward(output.clone(), targets.clone());
|
||||
|
||||
ClassificationOutput::new(loss, output, targets)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue