feat(learner): add RegressionOutput (#380)

This commit is contained in:
Yu Sun 2023-06-04 22:21:29 +08:00 committed by GitHub
parent 0a205a3603
commit 105c259d44
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 0 deletions

View File

@ -2,6 +2,7 @@ mod base;
mod builder;
mod classification;
mod epoch;
mod regression;
mod step;
mod train_val;
@ -11,6 +12,7 @@ pub use base::*;
pub use builder::*;
pub use classification::*;
pub use epoch::*;
pub use regression::*;
pub use step::*;
pub use train::*;
pub use train_val::*;

View File

@ -0,0 +1,17 @@
use crate::metric::{Adaptor, LossInput};
use burn_core::tensor::backend::Backend;
use burn_core::tensor::Tensor;
/// Simple regression output adapted for multiple metrics.
#[derive(new)]
pub struct RegressionOutput<B: Backend> {
pub loss: Tensor<B, 1>,
pub output: Tensor<B, 2>,
pub targets: Tensor<B, 2>,
}
impl<B: Backend> Adaptor<LossInput<B>> for RegressionOutput<B> {
fn adapt(&self) -> LossInput<B> {
LossInput::new(self.loss.clone())
}
}