mirror of https://github.com/tracel-ai/burn.git
feat(learner): add RegressionOutput (#380)
This commit is contained in:
parent
0a205a3603
commit
105c259d44
|
@ -2,6 +2,7 @@ mod base;
|
|||
mod builder;
|
||||
mod classification;
|
||||
mod epoch;
|
||||
mod regression;
|
||||
mod step;
|
||||
mod train_val;
|
||||
|
||||
|
@ -11,6 +12,7 @@ pub use base::*;
|
|||
pub use builder::*;
|
||||
pub use classification::*;
|
||||
pub use epoch::*;
|
||||
pub use regression::*;
|
||||
pub use step::*;
|
||||
pub use train::*;
|
||||
pub use train_val::*;
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
use crate::metric::{Adaptor, LossInput};
|
||||
use burn_core::tensor::backend::Backend;
|
||||
use burn_core::tensor::Tensor;
|
||||
|
||||
/// Simple regression output adapted for multiple metrics.
|
||||
#[derive(new)]
|
||||
pub struct RegressionOutput<B: Backend> {
|
||||
pub loss: Tensor<B, 1>,
|
||||
pub output: Tensor<B, 2>,
|
||||
pub targets: Tensor<B, 2>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Adaptor<LossInput<B>> for RegressionOutput<B> {
|
||||
fn adapt(&self) -> LossInput<B> {
|
||||
LossInput::new(self.loss.clone())
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue