diff --git a/burn-train/src/learner/mod.rs b/burn-train/src/learner/mod.rs index 2b737c8b4..5ef0a1c99 100644 --- a/burn-train/src/learner/mod.rs +++ b/burn-train/src/learner/mod.rs @@ -2,6 +2,7 @@ mod base; mod builder; mod classification; mod epoch; +mod regression; mod step; mod train_val; @@ -11,6 +12,7 @@ pub use base::*; pub use builder::*; pub use classification::*; pub use epoch::*; +pub use regression::*; pub use step::*; pub use train::*; pub use train_val::*; diff --git a/burn-train/src/learner/regression.rs b/burn-train/src/learner/regression.rs new file mode 100644 index 000000000..2838397bb --- /dev/null +++ b/burn-train/src/learner/regression.rs @@ -0,0 +1,17 @@ +use crate::metric::{Adaptor, LossInput}; +use burn_core::tensor::backend::Backend; +use burn_core::tensor::Tensor; + +/// Simple regression output adapted for multiple metrics. +#[derive(new)] +pub struct RegressionOutput { + pub loss: Tensor, + pub output: Tensor, + pub targets: Tensor, +} + +impl Adaptor> for RegressionOutput { + fn adapt(&self) -> LossInput { + LossInput::new(self.loss.clone()) + } +}