mirror of https://github.com/tracel-ai/burn.git
Refactor/metric adaptor (#139)
This commit is contained in:
parent
567adfb93e
commit
248039da0a
|
@ -1,6 +1,12 @@
|
|||
name: test
|
||||
|
||||
on: [push, pull_request]
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
test-burn-dataset:
|
||||
|
|
|
@ -50,6 +50,6 @@ Therefore, creating the tape only requires a simple and efficent graph traversal
|
|||
|
||||
To run with CUDA set `TORCH_CUDA_VERSION=cu113`.
|
||||
|
||||
## Note
|
||||
## Notes
|
||||
|
||||
This crate can be use alone without the entire burn stack and with only selected backends for smaller binaries.
|
||||
|
|
|
@ -26,7 +26,7 @@ pub trait Optimizer: Send + Sync {
|
|||
|
||||
/// Register the optimizer state for a given parameter.
|
||||
///
|
||||
/// # Note
|
||||
/// # Notes
|
||||
///
|
||||
/// This should only be called by generated code.
|
||||
fn register_param_state<const D: usize>(
|
||||
|
@ -39,7 +39,7 @@ pub trait Optimizer: Send + Sync {
|
|||
|
||||
/// Load the optimizer state for a given parameter.
|
||||
///
|
||||
/// # Note
|
||||
/// # Notes
|
||||
///
|
||||
/// This should only be called by generated code.
|
||||
fn load_param_state<const D: usize>(
|
||||
|
|
|
@ -6,7 +6,7 @@ use crate::train::checkpoint::{AsyncCheckpointer, Checkpointer, FileCheckpointer
|
|||
use crate::train::logger::FileMetricLogger;
|
||||
use crate::train::metric::dashboard::cli::CLIDashboardRenderer;
|
||||
use crate::train::metric::dashboard::Dashboard;
|
||||
use crate::train::metric::{Metric, Numeric};
|
||||
use crate::train::metric::{Adaptor, Metric, Numeric};
|
||||
use crate::train::AsyncTrainerCallback;
|
||||
use burn_tensor::backend::ADBackend;
|
||||
use burn_tensor::Element;
|
||||
|
@ -53,13 +53,19 @@ where
|
|||
}
|
||||
|
||||
/// Register a training metric.
|
||||
pub fn metric_train<M: Metric<T> + 'static>(mut self, metric: M) -> Self {
|
||||
pub fn metric_train<M: Metric + 'static>(mut self, metric: M) -> Self
|
||||
where
|
||||
T: Adaptor<M::Input>,
|
||||
{
|
||||
self.dashboard.register_train(metric);
|
||||
self
|
||||
}
|
||||
|
||||
/// Register a validation metric.
|
||||
pub fn metric_valid<M: Metric<V> + 'static>(mut self, metric: M) -> Self {
|
||||
pub fn metric_valid<M: Metric + 'static>(mut self, metric: M) -> Self
|
||||
where
|
||||
V: Adaptor<M::Input>,
|
||||
{
|
||||
self.dashboard.register_valid(metric);
|
||||
self
|
||||
}
|
||||
|
@ -86,7 +92,10 @@ where
|
|||
/// Only [numeric](Numeric) metric can be displayed on a plot.
|
||||
/// If the same metric is also registered for the [validation split](Self::metric_valid_plot),
|
||||
/// the same graph will be used for both.
|
||||
pub fn metric_train_plot<M: Metric<T> + Numeric + 'static>(mut self, metric: M) -> Self {
|
||||
pub fn metric_train_plot<M: Metric + Numeric + 'static>(mut self, metric: M) -> Self
|
||||
where
|
||||
T: Adaptor<M::Input>,
|
||||
{
|
||||
self.dashboard.register_train_plot(metric);
|
||||
self
|
||||
}
|
||||
|
@ -98,7 +107,10 @@ where
|
|||
/// Only [numeric](Numeric) metric can be displayed on a plot.
|
||||
/// If the same metric is also registered for the [training split](Self::metric_train_plot),
|
||||
/// the same graph will be used for both.
|
||||
pub fn metric_valid_plot<M: Metric<V> + Numeric + 'static>(mut self, metric: M) -> Self {
|
||||
pub fn metric_valid_plot<M: Metric + Numeric + 'static>(mut self, metric: M) -> Self
|
||||
where
|
||||
V: Adaptor<M::Input>,
|
||||
{
|
||||
self.dashboard.register_valid_plot(metric);
|
||||
self
|
||||
}
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
use crate::tensor::backend::Backend;
|
||||
use crate::train::metric;
|
||||
use crate::train::metric::{AccuracyInput, Adaptor, LossInput};
|
||||
use burn_tensor::Tensor;
|
||||
|
||||
/// Simple classification output adapted for multiple metrics.
|
||||
#[derive(new)]
|
||||
pub struct ClassificationOutput<B: Backend> {
|
||||
pub loss: Tensor<B, 1>,
|
||||
|
@ -9,21 +10,14 @@ pub struct ClassificationOutput<B: Backend> {
|
|||
pub targets: Tensor<B::IntegerBackend, 1>,
|
||||
}
|
||||
|
||||
impl<B: Backend> metric::Metric<ClassificationOutput<B>> for metric::LossMetric {
|
||||
fn update(&mut self, item: &ClassificationOutput<B>) -> metric::MetricStateDyn {
|
||||
self.update(&item.loss)
|
||||
}
|
||||
fn clear(&mut self) {
|
||||
<metric::LossMetric as metric::Metric<Tensor<B, 1>>>::clear(self);
|
||||
impl<B: Backend> Adaptor<AccuracyInput<B>> for ClassificationOutput<B> {
|
||||
fn adapt(&self) -> AccuracyInput<B> {
|
||||
AccuracyInput::new(self.output.clone(), self.targets.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> metric::Metric<ClassificationOutput<B>> for metric::AccuracyMetric {
|
||||
fn update(&mut self, item: &ClassificationOutput<B>) -> metric::MetricStateDyn {
|
||||
self.update(&(item.output.clone(), item.targets.clone()))
|
||||
}
|
||||
|
||||
fn clear(&mut self) {
|
||||
<metric::AccuracyMetric as metric::Metric<(Tensor<B, 2>, Tensor<B::IntegerBackend, 1>)>>::clear(self);
|
||||
impl<B: Backend> Adaptor<LossInput<B>> for ClassificationOutput<B> {
|
||||
fn adapt(&self) -> LossInput<B> {
|
||||
LossInput::new(self.loss.clone())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
use crate::train::metric::MetricEntry;
|
||||
|
||||
use super::{AsyncLogger, FileLogger, Logger};
|
||||
use crate::train::metric::MetricState;
|
||||
use std::collections::HashMap;
|
||||
|
||||
pub trait MetricLogger: Send {
|
||||
fn log(&mut self, item: &dyn MetricState);
|
||||
fn log(&mut self, item: &MetricEntry);
|
||||
fn epoch(&mut self, epoch: usize);
|
||||
}
|
||||
|
||||
|
@ -24,11 +25,11 @@ impl FileMetricLogger {
|
|||
}
|
||||
|
||||
impl MetricLogger for FileMetricLogger {
|
||||
fn log(&mut self, item: &dyn MetricState) {
|
||||
let key = item.name();
|
||||
let value = item.serialize();
|
||||
fn log(&mut self, item: &MetricEntry) {
|
||||
let key = &item.name;
|
||||
let value = &item.serialize;
|
||||
|
||||
let logger = match self.loggers.get_mut(&key) {
|
||||
let logger = match self.loggers.get_mut(key) {
|
||||
Some(val) => val,
|
||||
None => {
|
||||
let directory = format!("{}/epoch-{}", self.directory, self.epoch);
|
||||
|
@ -39,11 +40,11 @@ impl MetricLogger for FileMetricLogger {
|
|||
let logger = AsyncLogger::new(Box::new(logger));
|
||||
|
||||
self.loggers.insert(key.clone(), Box::new(logger));
|
||||
self.loggers.get_mut(&key).unwrap()
|
||||
self.loggers.get_mut(key).unwrap()
|
||||
}
|
||||
};
|
||||
|
||||
logger.log(value);
|
||||
logger.log(value.clone());
|
||||
}
|
||||
|
||||
fn epoch(&mut self, epoch: usize) {
|
||||
|
|
|
@ -1,74 +1,60 @@
|
|||
use super::RunningMetricResult;
|
||||
use super::state::{FormatOptions, NumericMetricState};
|
||||
use super::MetricEntry;
|
||||
use crate::tensor::backend::Backend;
|
||||
use crate::tensor::Tensor;
|
||||
use crate::train::metric::{Metric, MetricStateDyn, Numeric};
|
||||
use crate::train::metric::{Metric, Numeric};
|
||||
|
||||
pub struct AccuracyMetric {
|
||||
current: f64,
|
||||
count: usize,
|
||||
total: usize,
|
||||
/// The accuracy metric.
|
||||
#[derive(Default)]
|
||||
pub struct AccuracyMetric<B: Backend> {
|
||||
state: NumericMetricState,
|
||||
_b: B,
|
||||
}
|
||||
|
||||
impl AccuracyMetric {
|
||||
/// The [accuracy metric](AccuracyMetric) input type.
|
||||
#[derive(new)]
|
||||
pub struct AccuracyInput<B: Backend> {
|
||||
outputs: Tensor<B, 2>,
|
||||
targets: Tensor<B::IntegerBackend, 1>,
|
||||
}
|
||||
|
||||
impl<B: Backend> AccuracyMetric<B> {
|
||||
/// Create the metric.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
count: 0,
|
||||
current: 0.0,
|
||||
total: 0,
|
||||
}
|
||||
Self::default()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AccuracyMetric {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
impl<B: Backend> Metric for AccuracyMetric<B> {
|
||||
type Input = AccuracyInput<B>;
|
||||
|
||||
impl Numeric for AccuracyMetric {
|
||||
fn value(&self) -> f64 {
|
||||
self.current * 100.0
|
||||
}
|
||||
}
|
||||
fn update(&mut self, input: &AccuracyInput<B>) -> MetricEntry {
|
||||
let [batch_size, _n_classes] = input.outputs.dims();
|
||||
|
||||
impl<B: Backend> Metric<(Tensor<B, 2>, Tensor<B::IntegerBackend, 1>)> for AccuracyMetric {
|
||||
fn update(&mut self, batch: &(Tensor<B, 2>, Tensor<B::IntegerBackend, 1>)) -> MetricStateDyn {
|
||||
let (outputs, targets) = batch;
|
||||
let count_current = outputs.dims()[0];
|
||||
|
||||
let targets = targets.to_device(B::Device::default());
|
||||
let outputs = outputs
|
||||
let targets = input.targets.to_device(B::Device::default());
|
||||
let outputs = input
|
||||
.outputs
|
||||
.argmax(1)
|
||||
.to_device(B::Device::default())
|
||||
.reshape([count_current]);
|
||||
.reshape([batch_size]);
|
||||
|
||||
let total_current = outputs.equal(&targets).to_int().sum().to_data().value[0] as usize;
|
||||
let accuracy = 100.0 * total_current as f64 / batch_size as f64;
|
||||
|
||||
self.count += count_current;
|
||||
self.total += total_current;
|
||||
self.current = total_current as f64 / count_current as f64;
|
||||
|
||||
let name = String::from("Accurracy");
|
||||
let running = self.total as f64 / self.count as f64;
|
||||
let raw_running = format!("{running}");
|
||||
let raw_current = format!("{}", self.current);
|
||||
let formatted = format!(
|
||||
"running {:.2} % current {:.2} %",
|
||||
100.0 * running,
|
||||
100.0 * self.current
|
||||
);
|
||||
|
||||
Box::new(RunningMetricResult {
|
||||
name,
|
||||
formatted,
|
||||
raw_running,
|
||||
raw_current,
|
||||
})
|
||||
self.state.update(
|
||||
accuracy,
|
||||
batch_size,
|
||||
FormatOptions::new("Accuracy").unit("%").precision(2),
|
||||
)
|
||||
}
|
||||
|
||||
fn clear(&mut self) {
|
||||
self.count = 0;
|
||||
self.total = 0;
|
||||
self.current = 0.0;
|
||||
self.state.reset()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Numeric for AccuracyMetric<B> {
|
||||
fn value(&self) -> f64 {
|
||||
self.state.value()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,38 +1,42 @@
|
|||
pub trait Metric<T>: Send + Sync {
|
||||
fn update(&mut self, item: &T) -> MetricStateDyn;
|
||||
/// Metric trait.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// Implementations should define their own input type only used by the metric.
|
||||
/// This is important since some conflict may happen when the model output is adapted for each
|
||||
/// metric's input type.
|
||||
pub trait Metric: Send + Sync {
|
||||
type Input;
|
||||
|
||||
/// Update the metric state and returns the current metric entry.
|
||||
fn update(&mut self, item: &Self::Input) -> MetricEntry;
|
||||
/// Clear the metric state.
|
||||
fn clear(&mut self);
|
||||
}
|
||||
|
||||
pub trait MetricState {
|
||||
fn name(&self) -> String;
|
||||
fn pretty(&self) -> String;
|
||||
fn serialize(&self) -> String;
|
||||
/// Adaptor are used to transform types so that they can be used by metrics.
|
||||
///
|
||||
/// This should be implemented by a model's output type for all [metric inputs](Metric::Input) that are
|
||||
/// registed with the [leaner buidler](burn::train::LearnerBuilder).
|
||||
pub trait Adaptor<T> {
|
||||
/// Adapt the type to be passed to a [metric](Metric).
|
||||
fn adapt(&self) -> T;
|
||||
}
|
||||
|
||||
/// Declare a metric to be numeric.
|
||||
///
|
||||
/// This is usefull to plot the values of a metric during training.
|
||||
pub trait Numeric {
|
||||
fn value(&self) -> f64;
|
||||
}
|
||||
|
||||
pub type MetricStateDyn = Box<dyn MetricState>;
|
||||
|
||||
/// Data type that contains the current state of a metric at a given time.
|
||||
#[derive(new)]
|
||||
pub struct RunningMetricResult {
|
||||
pub struct MetricEntry {
|
||||
/// The name of the metric.
|
||||
pub name: String,
|
||||
/// The string to be displayed.
|
||||
pub formatted: String,
|
||||
pub raw_running: String,
|
||||
pub raw_current: String,
|
||||
}
|
||||
|
||||
impl MetricState for RunningMetricResult {
|
||||
fn name(&self) -> String {
|
||||
self.name.clone()
|
||||
}
|
||||
|
||||
fn pretty(&self) -> String {
|
||||
self.formatted.clone()
|
||||
}
|
||||
|
||||
fn serialize(&self) -> String {
|
||||
self.raw_current.clone()
|
||||
}
|
||||
/// The string to be saved.
|
||||
pub serialize: String,
|
||||
}
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
use super::RunningMetricResult;
|
||||
use crate::train::metric::{Metric, MetricState};
|
||||
use super::Adaptor;
|
||||
use crate::train::metric::{Metric, MetricEntry};
|
||||
use nvml_wrapper::Nvml;
|
||||
|
||||
/// Track basic cuda infos.
|
||||
pub struct CUDAMetric {
|
||||
nvml: Nvml,
|
||||
}
|
||||
|
@ -20,8 +21,14 @@ impl Default for CUDAMetric {
|
|||
}
|
||||
}
|
||||
|
||||
impl<T> Metric<T> for CUDAMetric {
|
||||
fn update(&mut self, _item: &T) -> Box<dyn MetricState> {
|
||||
impl<T> Adaptor<()> for T {
|
||||
fn adapt(&self) {}
|
||||
}
|
||||
|
||||
impl Metric for CUDAMetric {
|
||||
type Input = ();
|
||||
|
||||
fn update(&mut self, _item: &()) -> MetricEntry {
|
||||
let name = String::from("Cuda");
|
||||
|
||||
let mut formatted = String::new();
|
||||
|
@ -44,12 +51,7 @@ impl<T> Metric<T> for CUDAMetric {
|
|||
formatted = format!("{formatted} - Usage {utilization_rate_formatted}");
|
||||
}
|
||||
|
||||
Box::new(RunningMetricResult {
|
||||
name,
|
||||
formatted,
|
||||
raw_running,
|
||||
raw_current: String::new(),
|
||||
})
|
||||
MetricEntry::new(name, formatted, raw_running)
|
||||
}
|
||||
|
||||
fn clear(&mut self) {}
|
||||
|
|
|
@ -2,7 +2,7 @@ use crate::{
|
|||
data::dataloader::Progress,
|
||||
train::{
|
||||
logger::MetricLogger,
|
||||
metric::{Metric, MetricStateDyn, Numeric},
|
||||
metric::{Adaptor, Metric, MetricEntry, Numeric},
|
||||
LearnerCallback, LearnerItem,
|
||||
},
|
||||
};
|
||||
|
@ -29,8 +29,8 @@ impl TrainingProgress {
|
|||
}
|
||||
|
||||
pub enum DashboardMetricState {
|
||||
Generic(MetricStateDyn),
|
||||
Numeric(MetricStateDyn, f64),
|
||||
Generic(MetricEntry),
|
||||
Numeric(MetricEntry, f64),
|
||||
}
|
||||
|
||||
pub trait DashboardRenderer: Send + Sync {
|
||||
|
@ -75,21 +75,33 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
pub fn register_train<M: Metric<T> + 'static>(&mut self, metric: M) {
|
||||
pub fn register_train<M: Metric + 'static>(&mut self, metric: M)
|
||||
where
|
||||
T: Adaptor<M::Input>,
|
||||
{
|
||||
self.metrics_train
|
||||
.push(Box::new(MetricWrapper::new(metric)));
|
||||
}
|
||||
|
||||
pub fn register_train_plot<M: Numeric + Metric<T> + 'static>(&mut self, metric: M) {
|
||||
pub fn register_train_plot<M: Numeric + Metric + 'static>(&mut self, metric: M)
|
||||
where
|
||||
T: Adaptor<M::Input>,
|
||||
{
|
||||
self.metrics_train_numeric
|
||||
.push(Box::new(MetricWrapper::new(metric)));
|
||||
}
|
||||
pub fn register_valid<M: Metric<V> + 'static>(&mut self, metric: M) {
|
||||
pub fn register_valid<M: Metric + 'static>(&mut self, metric: M)
|
||||
where
|
||||
V: Adaptor<M::Input>,
|
||||
{
|
||||
self.metrics_valid
|
||||
.push(Box::new(MetricWrapper::new(metric)));
|
||||
}
|
||||
|
||||
pub fn register_valid_plot<M: Numeric + Metric<V> + 'static>(&mut self, metric: M) {
|
||||
pub fn register_valid_plot<M: Numeric + Metric + 'static>(&mut self, metric: M)
|
||||
where
|
||||
V: Adaptor<M::Input>,
|
||||
{
|
||||
self.metrics_valid_numeric
|
||||
.push(Box::new(MetricWrapper::new(metric)));
|
||||
}
|
||||
|
@ -114,14 +126,14 @@ where
|
|||
fn on_train_item(&mut self, item: LearnerItem<T>) {
|
||||
for metric in self.metrics_train.iter_mut() {
|
||||
let state = metric.update(&item);
|
||||
self.logger_train.log(state.as_ref());
|
||||
self.logger_train.log(&state);
|
||||
|
||||
self.renderer
|
||||
.update_train(DashboardMetricState::Generic(state));
|
||||
}
|
||||
for metric in self.metrics_train_numeric.iter_mut() {
|
||||
let (state, value) = metric.update(&item);
|
||||
self.logger_train.log(state.as_ref());
|
||||
self.logger_train.log(&state);
|
||||
|
||||
self.renderer
|
||||
.update_train(DashboardMetricState::Numeric(state, value));
|
||||
|
@ -132,14 +144,14 @@ where
|
|||
fn on_valid_item(&mut self, item: LearnerItem<V>) {
|
||||
for metric in self.metrics_valid.iter_mut() {
|
||||
let state = metric.update(&item);
|
||||
self.logger_valid.log(state.as_ref());
|
||||
self.logger_valid.log(&state);
|
||||
|
||||
self.renderer
|
||||
.update_valid(DashboardMetricState::Generic(state));
|
||||
}
|
||||
for metric in self.metrics_valid_numeric.iter_mut() {
|
||||
let (state, value) = metric.update(&item);
|
||||
self.logger_valid.log(state.as_ref());
|
||||
self.logger_valid.log(&state);
|
||||
|
||||
self.renderer
|
||||
.update_valid(DashboardMetricState::Numeric(state, value));
|
||||
|
@ -169,12 +181,12 @@ where
|
|||
}
|
||||
|
||||
trait DashboardNumericMetric<T>: Send + Sync {
|
||||
fn update(&mut self, item: &LearnerItem<T>) -> (MetricStateDyn, f64);
|
||||
fn update(&mut self, item: &LearnerItem<T>) -> (MetricEntry, f64);
|
||||
fn clear(&mut self);
|
||||
}
|
||||
|
||||
trait DashboardMetric<T>: Send + Sync {
|
||||
fn update(&mut self, item: &LearnerItem<T>) -> MetricStateDyn;
|
||||
fn update(&mut self, item: &LearnerItem<T>) -> MetricEntry;
|
||||
fn clear(&mut self);
|
||||
}
|
||||
|
||||
|
@ -186,10 +198,11 @@ struct MetricWrapper<M> {
|
|||
impl<T, M> DashboardNumericMetric<T> for MetricWrapper<M>
|
||||
where
|
||||
T: 'static,
|
||||
M: Metric<T> + Numeric + 'static,
|
||||
M: Metric + Numeric + 'static,
|
||||
T: Adaptor<M::Input>,
|
||||
{
|
||||
fn update(&mut self, item: &LearnerItem<T>) -> (MetricStateDyn, f64) {
|
||||
let update = self.metric.update(&item.item);
|
||||
fn update(&mut self, item: &LearnerItem<T>) -> (MetricEntry, f64) {
|
||||
let update = self.metric.update(&item.item.adapt());
|
||||
let numeric = self.metric.value();
|
||||
|
||||
(update, numeric)
|
||||
|
@ -203,10 +216,11 @@ where
|
|||
impl<T, M> DashboardMetric<T> for MetricWrapper<M>
|
||||
where
|
||||
T: 'static,
|
||||
M: Metric<T> + 'static,
|
||||
M: Metric + 'static,
|
||||
T: Adaptor<M::Input>,
|
||||
{
|
||||
fn update(&mut self, item: &LearnerItem<T>) -> MetricStateDyn {
|
||||
self.metric.update(&item.item)
|
||||
fn update(&mut self, item: &LearnerItem<T>) -> MetricEntry {
|
||||
self.metric.update(&item.item.adapt())
|
||||
}
|
||||
|
||||
fn clear(&mut self) {
|
||||
|
|
|
@ -33,24 +33,24 @@ impl DashboardRenderer for CLIDashboardRenderer {
|
|||
fn update_train(&mut self, state: DashboardMetricState) {
|
||||
match state {
|
||||
DashboardMetricState::Generic(state) => {
|
||||
self.metric_train.insert(state.name(), state.pretty());
|
||||
self.metric_train.insert(state.name, state.formatted);
|
||||
}
|
||||
DashboardMetricState::Numeric(state, value) => {
|
||||
self.metric_train.insert(state.name(), state.pretty());
|
||||
let name = &state.name;
|
||||
self.metric_train.insert(name.clone(), state.formatted);
|
||||
|
||||
let name = state.name();
|
||||
if let Some(mut plot) = self.text_plot_in_both(&name) {
|
||||
if let Some(mut plot) = self.text_plot_in_both(name) {
|
||||
plot.update_train(value as f32);
|
||||
self.metric_both_plot.insert(name, plot);
|
||||
self.metric_both_plot.insert(name.clone(), plot);
|
||||
return;
|
||||
}
|
||||
|
||||
if let Some(plot) = self.metric_train_plot.get_mut(&name) {
|
||||
if let Some(plot) = self.metric_train_plot.get_mut(name) {
|
||||
plot.update_train(value as f32);
|
||||
} else {
|
||||
let mut plot = TextPlot::new();
|
||||
plot.update_train(value as f32);
|
||||
self.metric_train_plot.insert(state.name(), plot);
|
||||
self.metric_train_plot.insert(state.name, plot);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -59,24 +59,24 @@ impl DashboardRenderer for CLIDashboardRenderer {
|
|||
fn update_valid(&mut self, state: DashboardMetricState) {
|
||||
match state {
|
||||
DashboardMetricState::Generic(state) => {
|
||||
self.metric_valid.insert(state.name(), state.pretty());
|
||||
self.metric_valid.insert(state.name, state.formatted);
|
||||
}
|
||||
DashboardMetricState::Numeric(state, value) => {
|
||||
self.metric_valid.insert(state.name(), state.pretty());
|
||||
let name = &state.name;
|
||||
self.metric_valid.insert(name.clone(), state.formatted);
|
||||
|
||||
let name = state.name();
|
||||
if let Some(mut plot) = self.text_plot_in_both(&name) {
|
||||
if let Some(mut plot) = self.text_plot_in_both(name) {
|
||||
plot.update_valid(value as f32);
|
||||
self.metric_both_plot.insert(name, plot);
|
||||
self.metric_both_plot.insert(name.clone(), plot);
|
||||
return;
|
||||
}
|
||||
|
||||
if let Some(plot) = self.metric_valid_plot.get_mut(&name) {
|
||||
if let Some(plot) = self.metric_valid_plot.get_mut(name) {
|
||||
plot.update_valid(value as f32);
|
||||
} else {
|
||||
let mut plot = TextPlot::new();
|
||||
plot.update_valid(value as f32);
|
||||
self.metric_valid_plot.insert(state.name(), plot);
|
||||
self.metric_valid_plot.insert(state.name, plot);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
|
@ -1,62 +1,48 @@
|
|||
use super::RunningMetricResult;
|
||||
use super::state::FormatOptions;
|
||||
use super::state::NumericMetricState;
|
||||
use super::MetricEntry;
|
||||
use crate::tensor::backend::Backend;
|
||||
use crate::tensor::ElementConversion;
|
||||
use crate::tensor::Tensor;
|
||||
use crate::train::metric::{Metric, MetricState, Numeric};
|
||||
use crate::train::metric::{Metric, Numeric};
|
||||
|
||||
pub struct LossMetric {
|
||||
current: f64,
|
||||
count: usize,
|
||||
total: f64,
|
||||
/// The loss metric.
|
||||
#[derive(Default)]
|
||||
pub struct LossMetric<B: Backend> {
|
||||
state: NumericMetricState,
|
||||
_b: B,
|
||||
}
|
||||
|
||||
impl LossMetric {
|
||||
/// The [loss metric](LossMetric) input type.
|
||||
#[derive(new)]
|
||||
pub struct LossInput<B: Backend> {
|
||||
tensor: Tensor<B, 1>,
|
||||
}
|
||||
|
||||
impl<B: Backend> LossMetric<B> {
|
||||
/// Create the metric.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
count: 0,
|
||||
current: 0.0,
|
||||
total: 0.0,
|
||||
}
|
||||
Self::default()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for LossMetric {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
impl<B: Backend> Metric for LossMetric<B> {
|
||||
type Input = LossInput<B>;
|
||||
|
||||
impl Numeric for LossMetric {
|
||||
fn value(&self) -> f64 {
|
||||
self.current * 100.0
|
||||
}
|
||||
}
|
||||
fn update(&mut self, loss: &Self::Input) -> MetricEntry {
|
||||
let loss = f64::from_elem(loss.tensor.mean().into_data().value[0]);
|
||||
|
||||
impl<B: Backend> Metric<Tensor<B, 1>> for LossMetric {
|
||||
fn update(&mut self, loss: &Tensor<B, 1>) -> Box<dyn MetricState> {
|
||||
let loss = f64::from_elem(loss.to_data().value[0]);
|
||||
|
||||
self.count += 1;
|
||||
self.total += loss;
|
||||
self.current = loss;
|
||||
|
||||
let name = String::from("Loss");
|
||||
let running = self.total / self.count as f64;
|
||||
let raw_running = format!("{running}");
|
||||
let raw_current = format!("{}", self.current);
|
||||
let formatted = format!("running {:.3} current {:.3}", running, self.current);
|
||||
|
||||
Box::new(RunningMetricResult {
|
||||
name,
|
||||
formatted,
|
||||
raw_running,
|
||||
raw_current,
|
||||
})
|
||||
self.state
|
||||
.update(loss, 1, FormatOptions::new("Loss").precision(4))
|
||||
}
|
||||
|
||||
fn clear(&mut self) {
|
||||
self.count = 0;
|
||||
self.total = 0.0;
|
||||
self.current = 0.0;
|
||||
self.state.reset()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Numeric for LossMetric<B> {
|
||||
fn value(&self) -> f64 {
|
||||
self.state.value()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
pub mod dashboard;
|
||||
pub mod state;
|
||||
|
||||
mod acc;
|
||||
mod base;
|
||||
|
|
|
@ -0,0 +1,101 @@
|
|||
use super::{MetricEntry, Numeric};
|
||||
|
||||
/// Usefull utility to implement numeric [metrics](crate::train::metric::Metric).
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// The numeric metric store values inside floats.
|
||||
/// Even if some metric are integers, their mean are floats.
|
||||
pub struct NumericMetricState {
|
||||
sum: f64,
|
||||
count: usize,
|
||||
current: f64,
|
||||
}
|
||||
|
||||
/// Formatting options for the [numeric metric state](NumericMetricState).
|
||||
pub struct FormatOptions {
|
||||
name: String,
|
||||
unit: Option<String>,
|
||||
precision: Option<usize>,
|
||||
}
|
||||
|
||||
impl FormatOptions {
|
||||
/// Create the [formatting options](FormatOptions) with a name.
|
||||
pub fn new(name: &str) -> Self {
|
||||
Self {
|
||||
name: name.to_string(),
|
||||
unit: None,
|
||||
precision: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Specify the metric unit.
|
||||
pub fn unit(mut self, unit: &str) -> Self {
|
||||
self.unit = Some(unit.to_string());
|
||||
self
|
||||
}
|
||||
|
||||
/// Specify the floating point precision.
|
||||
pub fn precision(mut self, precision: usize) -> Self {
|
||||
self.precision = Some(precision);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl NumericMetricState {
|
||||
/// Create a new [numeric metric state](NumericMetricState).
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
sum: 0.0,
|
||||
count: 0,
|
||||
current: f64::NAN,
|
||||
}
|
||||
}
|
||||
|
||||
/// Reset the state.
|
||||
pub fn reset(&mut self) {
|
||||
self.sum = 0.0;
|
||||
self.count = 0;
|
||||
self.current = f64::NAN;
|
||||
}
|
||||
|
||||
/// Update the state.
|
||||
pub fn update(&mut self, value: f64, batch_size: usize, format: FormatOptions) -> MetricEntry {
|
||||
self.sum += value * batch_size as f64;
|
||||
self.count += batch_size;
|
||||
self.current = value;
|
||||
|
||||
let value_current = value;
|
||||
let value_running = self.sum / self.count as f64;
|
||||
let serialized = value_current.to_string();
|
||||
|
||||
let (formatted_current, formatted_running) = match format.precision {
|
||||
Some(precision) => (
|
||||
format!("{value_current:.0$}", precision),
|
||||
format!("{value_running:.0$}", precision),
|
||||
),
|
||||
None => (format!("{value_current}"), format!("{value_running}")),
|
||||
};
|
||||
|
||||
let formatted = match format.unit {
|
||||
Some(unit) => {
|
||||
format!("Running {formatted_running} {unit} - Current {formatted_current} {unit}")
|
||||
}
|
||||
None => format!("Running {formatted_running} - Current {formatted_current}"),
|
||||
};
|
||||
|
||||
MetricEntry::new(format.name, formatted, serialized)
|
||||
}
|
||||
}
|
||||
|
||||
impl Numeric for NumericMetricState {
|
||||
fn value(&self) -> f64 {
|
||||
self.current
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for NumericMetricState {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue