Refactor/metric adaptor (#139)

This commit is contained in:
Nathaniel Simard 2022-12-26 16:30:25 -05:00 committed by GitHub
parent 567adfb93e
commit 248039da0a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 304 additions and 197 deletions

View File

@ -1,6 +1,12 @@
name: test
on: [push, pull_request]
on:
push:
branches:
- main
pull_request:
branches:
- main
jobs:
test-burn-dataset:

View File

@ -50,6 +50,6 @@ Therefore, creating the tape only requires a simple and efficent graph traversal
To run with CUDA set `TORCH_CUDA_VERSION=cu113`.
## Note
## Notes
This crate can be use alone without the entire burn stack and with only selected backends for smaller binaries.

View File

@ -26,7 +26,7 @@ pub trait Optimizer: Send + Sync {
/// Register the optimizer state for a given parameter.
///
/// # Note
/// # Notes
///
/// This should only be called by generated code.
fn register_param_state<const D: usize>(
@ -39,7 +39,7 @@ pub trait Optimizer: Send + Sync {
/// Load the optimizer state for a given parameter.
///
/// # Note
/// # Notes
///
/// This should only be called by generated code.
fn load_param_state<const D: usize>(

View File

@ -6,7 +6,7 @@ use crate::train::checkpoint::{AsyncCheckpointer, Checkpointer, FileCheckpointer
use crate::train::logger::FileMetricLogger;
use crate::train::metric::dashboard::cli::CLIDashboardRenderer;
use crate::train::metric::dashboard::Dashboard;
use crate::train::metric::{Metric, Numeric};
use crate::train::metric::{Adaptor, Metric, Numeric};
use crate::train::AsyncTrainerCallback;
use burn_tensor::backend::ADBackend;
use burn_tensor::Element;
@ -53,13 +53,19 @@ where
}
/// Register a training metric.
pub fn metric_train<M: Metric<T> + 'static>(mut self, metric: M) -> Self {
pub fn metric_train<M: Metric + 'static>(mut self, metric: M) -> Self
where
T: Adaptor<M::Input>,
{
self.dashboard.register_train(metric);
self
}
/// Register a validation metric.
pub fn metric_valid<M: Metric<V> + 'static>(mut self, metric: M) -> Self {
pub fn metric_valid<M: Metric + 'static>(mut self, metric: M) -> Self
where
V: Adaptor<M::Input>,
{
self.dashboard.register_valid(metric);
self
}
@ -86,7 +92,10 @@ where
/// Only [numeric](Numeric) metric can be displayed on a plot.
/// If the same metric is also registered for the [validation split](Self::metric_valid_plot),
/// the same graph will be used for both.
pub fn metric_train_plot<M: Metric<T> + Numeric + 'static>(mut self, metric: M) -> Self {
pub fn metric_train_plot<M: Metric + Numeric + 'static>(mut self, metric: M) -> Self
where
T: Adaptor<M::Input>,
{
self.dashboard.register_train_plot(metric);
self
}
@ -98,7 +107,10 @@ where
/// Only [numeric](Numeric) metric can be displayed on a plot.
/// If the same metric is also registered for the [training split](Self::metric_train_plot),
/// the same graph will be used for both.
pub fn metric_valid_plot<M: Metric<V> + Numeric + 'static>(mut self, metric: M) -> Self {
pub fn metric_valid_plot<M: Metric + Numeric + 'static>(mut self, metric: M) -> Self
where
V: Adaptor<M::Input>,
{
self.dashboard.register_valid_plot(metric);
self
}

View File

@ -1,7 +1,8 @@
use crate::tensor::backend::Backend;
use crate::train::metric;
use crate::train::metric::{AccuracyInput, Adaptor, LossInput};
use burn_tensor::Tensor;
/// Simple classification output adapted for multiple metrics.
#[derive(new)]
pub struct ClassificationOutput<B: Backend> {
pub loss: Tensor<B, 1>,
@ -9,21 +10,14 @@ pub struct ClassificationOutput<B: Backend> {
pub targets: Tensor<B::IntegerBackend, 1>,
}
impl<B: Backend> metric::Metric<ClassificationOutput<B>> for metric::LossMetric {
fn update(&mut self, item: &ClassificationOutput<B>) -> metric::MetricStateDyn {
self.update(&item.loss)
}
fn clear(&mut self) {
<metric::LossMetric as metric::Metric<Tensor<B, 1>>>::clear(self);
impl<B: Backend> Adaptor<AccuracyInput<B>> for ClassificationOutput<B> {
fn adapt(&self) -> AccuracyInput<B> {
AccuracyInput::new(self.output.clone(), self.targets.clone())
}
}
impl<B: Backend> metric::Metric<ClassificationOutput<B>> for metric::AccuracyMetric {
fn update(&mut self, item: &ClassificationOutput<B>) -> metric::MetricStateDyn {
self.update(&(item.output.clone(), item.targets.clone()))
}
fn clear(&mut self) {
<metric::AccuracyMetric as metric::Metric<(Tensor<B, 2>, Tensor<B::IntegerBackend, 1>)>>::clear(self);
impl<B: Backend> Adaptor<LossInput<B>> for ClassificationOutput<B> {
fn adapt(&self) -> LossInput<B> {
LossInput::new(self.loss.clone())
}
}

View File

@ -1,9 +1,10 @@
use crate::train::metric::MetricEntry;
use super::{AsyncLogger, FileLogger, Logger};
use crate::train::metric::MetricState;
use std::collections::HashMap;
pub trait MetricLogger: Send {
fn log(&mut self, item: &dyn MetricState);
fn log(&mut self, item: &MetricEntry);
fn epoch(&mut self, epoch: usize);
}
@ -24,11 +25,11 @@ impl FileMetricLogger {
}
impl MetricLogger for FileMetricLogger {
fn log(&mut self, item: &dyn MetricState) {
let key = item.name();
let value = item.serialize();
fn log(&mut self, item: &MetricEntry) {
let key = &item.name;
let value = &item.serialize;
let logger = match self.loggers.get_mut(&key) {
let logger = match self.loggers.get_mut(key) {
Some(val) => val,
None => {
let directory = format!("{}/epoch-{}", self.directory, self.epoch);
@ -39,11 +40,11 @@ impl MetricLogger for FileMetricLogger {
let logger = AsyncLogger::new(Box::new(logger));
self.loggers.insert(key.clone(), Box::new(logger));
self.loggers.get_mut(&key).unwrap()
self.loggers.get_mut(key).unwrap()
}
};
logger.log(value);
logger.log(value.clone());
}
fn epoch(&mut self, epoch: usize) {

View File

@ -1,74 +1,60 @@
use super::RunningMetricResult;
use super::state::{FormatOptions, NumericMetricState};
use super::MetricEntry;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
use crate::train::metric::{Metric, MetricStateDyn, Numeric};
use crate::train::metric::{Metric, Numeric};
pub struct AccuracyMetric {
current: f64,
count: usize,
total: usize,
/// The accuracy metric.
#[derive(Default)]
pub struct AccuracyMetric<B: Backend> {
state: NumericMetricState,
_b: B,
}
impl AccuracyMetric {
/// The [accuracy metric](AccuracyMetric) input type.
#[derive(new)]
pub struct AccuracyInput<B: Backend> {
outputs: Tensor<B, 2>,
targets: Tensor<B::IntegerBackend, 1>,
}
impl<B: Backend> AccuracyMetric<B> {
/// Create the metric.
pub fn new() -> Self {
Self {
count: 0,
current: 0.0,
total: 0,
}
Self::default()
}
}
impl Default for AccuracyMetric {
fn default() -> Self {
Self::new()
}
}
impl<B: Backend> Metric for AccuracyMetric<B> {
type Input = AccuracyInput<B>;
impl Numeric for AccuracyMetric {
fn value(&self) -> f64 {
self.current * 100.0
}
}
fn update(&mut self, input: &AccuracyInput<B>) -> MetricEntry {
let [batch_size, _n_classes] = input.outputs.dims();
impl<B: Backend> Metric<(Tensor<B, 2>, Tensor<B::IntegerBackend, 1>)> for AccuracyMetric {
fn update(&mut self, batch: &(Tensor<B, 2>, Tensor<B::IntegerBackend, 1>)) -> MetricStateDyn {
let (outputs, targets) = batch;
let count_current = outputs.dims()[0];
let targets = targets.to_device(B::Device::default());
let outputs = outputs
let targets = input.targets.to_device(B::Device::default());
let outputs = input
.outputs
.argmax(1)
.to_device(B::Device::default())
.reshape([count_current]);
.reshape([batch_size]);
let total_current = outputs.equal(&targets).to_int().sum().to_data().value[0] as usize;
let accuracy = 100.0 * total_current as f64 / batch_size as f64;
self.count += count_current;
self.total += total_current;
self.current = total_current as f64 / count_current as f64;
let name = String::from("Accurracy");
let running = self.total as f64 / self.count as f64;
let raw_running = format!("{running}");
let raw_current = format!("{}", self.current);
let formatted = format!(
"running {:.2} % current {:.2} %",
100.0 * running,
100.0 * self.current
);
Box::new(RunningMetricResult {
name,
formatted,
raw_running,
raw_current,
})
self.state.update(
accuracy,
batch_size,
FormatOptions::new("Accuracy").unit("%").precision(2),
)
}
fn clear(&mut self) {
self.count = 0;
self.total = 0;
self.current = 0.0;
self.state.reset()
}
}
impl<B: Backend> Numeric for AccuracyMetric<B> {
fn value(&self) -> f64 {
self.state.value()
}
}

View File

@ -1,38 +1,42 @@
pub trait Metric<T>: Send + Sync {
fn update(&mut self, item: &T) -> MetricStateDyn;
/// Metric trait.
///
/// # Notes
///
/// Implementations should define their own input type only used by the metric.
/// This is important since some conflict may happen when the model output is adapted for each
/// metric's input type.
pub trait Metric: Send + Sync {
type Input;
/// Update the metric state and returns the current metric entry.
fn update(&mut self, item: &Self::Input) -> MetricEntry;
/// Clear the metric state.
fn clear(&mut self);
}
pub trait MetricState {
fn name(&self) -> String;
fn pretty(&self) -> String;
fn serialize(&self) -> String;
/// Adaptor are used to transform types so that they can be used by metrics.
///
/// This should be implemented by a model's output type for all [metric inputs](Metric::Input) that are
/// registed with the [leaner buidler](burn::train::LearnerBuilder).
pub trait Adaptor<T> {
/// Adapt the type to be passed to a [metric](Metric).
fn adapt(&self) -> T;
}
/// Declare a metric to be numeric.
///
/// This is usefull to plot the values of a metric during training.
pub trait Numeric {
fn value(&self) -> f64;
}
pub type MetricStateDyn = Box<dyn MetricState>;
/// Data type that contains the current state of a metric at a given time.
#[derive(new)]
pub struct RunningMetricResult {
pub struct MetricEntry {
/// The name of the metric.
pub name: String,
/// The string to be displayed.
pub formatted: String,
pub raw_running: String,
pub raw_current: String,
}
impl MetricState for RunningMetricResult {
fn name(&self) -> String {
self.name.clone()
}
fn pretty(&self) -> String {
self.formatted.clone()
}
fn serialize(&self) -> String {
self.raw_current.clone()
}
/// The string to be saved.
pub serialize: String,
}

View File

@ -1,7 +1,8 @@
use super::RunningMetricResult;
use crate::train::metric::{Metric, MetricState};
use super::Adaptor;
use crate::train::metric::{Metric, MetricEntry};
use nvml_wrapper::Nvml;
/// Track basic cuda infos.
pub struct CUDAMetric {
nvml: Nvml,
}
@ -20,8 +21,14 @@ impl Default for CUDAMetric {
}
}
impl<T> Metric<T> for CUDAMetric {
fn update(&mut self, _item: &T) -> Box<dyn MetricState> {
impl<T> Adaptor<()> for T {
fn adapt(&self) {}
}
impl Metric for CUDAMetric {
type Input = ();
fn update(&mut self, _item: &()) -> MetricEntry {
let name = String::from("Cuda");
let mut formatted = String::new();
@ -44,12 +51,7 @@ impl<T> Metric<T> for CUDAMetric {
formatted = format!("{formatted} - Usage {utilization_rate_formatted}");
}
Box::new(RunningMetricResult {
name,
formatted,
raw_running,
raw_current: String::new(),
})
MetricEntry::new(name, formatted, raw_running)
}
fn clear(&mut self) {}

View File

@ -2,7 +2,7 @@ use crate::{
data::dataloader::Progress,
train::{
logger::MetricLogger,
metric::{Metric, MetricStateDyn, Numeric},
metric::{Adaptor, Metric, MetricEntry, Numeric},
LearnerCallback, LearnerItem,
},
};
@ -29,8 +29,8 @@ impl TrainingProgress {
}
pub enum DashboardMetricState {
Generic(MetricStateDyn),
Numeric(MetricStateDyn, f64),
Generic(MetricEntry),
Numeric(MetricEntry, f64),
}
pub trait DashboardRenderer: Send + Sync {
@ -75,21 +75,33 @@ where
}
}
pub fn register_train<M: Metric<T> + 'static>(&mut self, metric: M) {
pub fn register_train<M: Metric + 'static>(&mut self, metric: M)
where
T: Adaptor<M::Input>,
{
self.metrics_train
.push(Box::new(MetricWrapper::new(metric)));
}
pub fn register_train_plot<M: Numeric + Metric<T> + 'static>(&mut self, metric: M) {
pub fn register_train_plot<M: Numeric + Metric + 'static>(&mut self, metric: M)
where
T: Adaptor<M::Input>,
{
self.metrics_train_numeric
.push(Box::new(MetricWrapper::new(metric)));
}
pub fn register_valid<M: Metric<V> + 'static>(&mut self, metric: M) {
pub fn register_valid<M: Metric + 'static>(&mut self, metric: M)
where
V: Adaptor<M::Input>,
{
self.metrics_valid
.push(Box::new(MetricWrapper::new(metric)));
}
pub fn register_valid_plot<M: Numeric + Metric<V> + 'static>(&mut self, metric: M) {
pub fn register_valid_plot<M: Numeric + Metric + 'static>(&mut self, metric: M)
where
V: Adaptor<M::Input>,
{
self.metrics_valid_numeric
.push(Box::new(MetricWrapper::new(metric)));
}
@ -114,14 +126,14 @@ where
fn on_train_item(&mut self, item: LearnerItem<T>) {
for metric in self.metrics_train.iter_mut() {
let state = metric.update(&item);
self.logger_train.log(state.as_ref());
self.logger_train.log(&state);
self.renderer
.update_train(DashboardMetricState::Generic(state));
}
for metric in self.metrics_train_numeric.iter_mut() {
let (state, value) = metric.update(&item);
self.logger_train.log(state.as_ref());
self.logger_train.log(&state);
self.renderer
.update_train(DashboardMetricState::Numeric(state, value));
@ -132,14 +144,14 @@ where
fn on_valid_item(&mut self, item: LearnerItem<V>) {
for metric in self.metrics_valid.iter_mut() {
let state = metric.update(&item);
self.logger_valid.log(state.as_ref());
self.logger_valid.log(&state);
self.renderer
.update_valid(DashboardMetricState::Generic(state));
}
for metric in self.metrics_valid_numeric.iter_mut() {
let (state, value) = metric.update(&item);
self.logger_valid.log(state.as_ref());
self.logger_valid.log(&state);
self.renderer
.update_valid(DashboardMetricState::Numeric(state, value));
@ -169,12 +181,12 @@ where
}
trait DashboardNumericMetric<T>: Send + Sync {
fn update(&mut self, item: &LearnerItem<T>) -> (MetricStateDyn, f64);
fn update(&mut self, item: &LearnerItem<T>) -> (MetricEntry, f64);
fn clear(&mut self);
}
trait DashboardMetric<T>: Send + Sync {
fn update(&mut self, item: &LearnerItem<T>) -> MetricStateDyn;
fn update(&mut self, item: &LearnerItem<T>) -> MetricEntry;
fn clear(&mut self);
}
@ -186,10 +198,11 @@ struct MetricWrapper<M> {
impl<T, M> DashboardNumericMetric<T> for MetricWrapper<M>
where
T: 'static,
M: Metric<T> + Numeric + 'static,
M: Metric + Numeric + 'static,
T: Adaptor<M::Input>,
{
fn update(&mut self, item: &LearnerItem<T>) -> (MetricStateDyn, f64) {
let update = self.metric.update(&item.item);
fn update(&mut self, item: &LearnerItem<T>) -> (MetricEntry, f64) {
let update = self.metric.update(&item.item.adapt());
let numeric = self.metric.value();
(update, numeric)
@ -203,10 +216,11 @@ where
impl<T, M> DashboardMetric<T> for MetricWrapper<M>
where
T: 'static,
M: Metric<T> + 'static,
M: Metric + 'static,
T: Adaptor<M::Input>,
{
fn update(&mut self, item: &LearnerItem<T>) -> MetricStateDyn {
self.metric.update(&item.item)
fn update(&mut self, item: &LearnerItem<T>) -> MetricEntry {
self.metric.update(&item.item.adapt())
}
fn clear(&mut self) {

View File

@ -33,24 +33,24 @@ impl DashboardRenderer for CLIDashboardRenderer {
fn update_train(&mut self, state: DashboardMetricState) {
match state {
DashboardMetricState::Generic(state) => {
self.metric_train.insert(state.name(), state.pretty());
self.metric_train.insert(state.name, state.formatted);
}
DashboardMetricState::Numeric(state, value) => {
self.metric_train.insert(state.name(), state.pretty());
let name = &state.name;
self.metric_train.insert(name.clone(), state.formatted);
let name = state.name();
if let Some(mut plot) = self.text_plot_in_both(&name) {
if let Some(mut plot) = self.text_plot_in_both(name) {
plot.update_train(value as f32);
self.metric_both_plot.insert(name, plot);
self.metric_both_plot.insert(name.clone(), plot);
return;
}
if let Some(plot) = self.metric_train_plot.get_mut(&name) {
if let Some(plot) = self.metric_train_plot.get_mut(name) {
plot.update_train(value as f32);
} else {
let mut plot = TextPlot::new();
plot.update_train(value as f32);
self.metric_train_plot.insert(state.name(), plot);
self.metric_train_plot.insert(state.name, plot);
}
}
};
@ -59,24 +59,24 @@ impl DashboardRenderer for CLIDashboardRenderer {
fn update_valid(&mut self, state: DashboardMetricState) {
match state {
DashboardMetricState::Generic(state) => {
self.metric_valid.insert(state.name(), state.pretty());
self.metric_valid.insert(state.name, state.formatted);
}
DashboardMetricState::Numeric(state, value) => {
self.metric_valid.insert(state.name(), state.pretty());
let name = &state.name;
self.metric_valid.insert(name.clone(), state.formatted);
let name = state.name();
if let Some(mut plot) = self.text_plot_in_both(&name) {
if let Some(mut plot) = self.text_plot_in_both(name) {
plot.update_valid(value as f32);
self.metric_both_plot.insert(name, plot);
self.metric_both_plot.insert(name.clone(), plot);
return;
}
if let Some(plot) = self.metric_valid_plot.get_mut(&name) {
if let Some(plot) = self.metric_valid_plot.get_mut(name) {
plot.update_valid(value as f32);
} else {
let mut plot = TextPlot::new();
plot.update_valid(value as f32);
self.metric_valid_plot.insert(state.name(), plot);
self.metric_valid_plot.insert(state.name, plot);
}
}
};

View File

@ -1,62 +1,48 @@
use super::RunningMetricResult;
use super::state::FormatOptions;
use super::state::NumericMetricState;
use super::MetricEntry;
use crate::tensor::backend::Backend;
use crate::tensor::ElementConversion;
use crate::tensor::Tensor;
use crate::train::metric::{Metric, MetricState, Numeric};
use crate::train::metric::{Metric, Numeric};
pub struct LossMetric {
current: f64,
count: usize,
total: f64,
/// The loss metric.
#[derive(Default)]
pub struct LossMetric<B: Backend> {
state: NumericMetricState,
_b: B,
}
impl LossMetric {
/// The [loss metric](LossMetric) input type.
#[derive(new)]
pub struct LossInput<B: Backend> {
tensor: Tensor<B, 1>,
}
impl<B: Backend> LossMetric<B> {
/// Create the metric.
pub fn new() -> Self {
Self {
count: 0,
current: 0.0,
total: 0.0,
}
Self::default()
}
}
impl Default for LossMetric {
fn default() -> Self {
Self::new()
}
}
impl<B: Backend> Metric for LossMetric<B> {
type Input = LossInput<B>;
impl Numeric for LossMetric {
fn value(&self) -> f64 {
self.current * 100.0
}
}
fn update(&mut self, loss: &Self::Input) -> MetricEntry {
let loss = f64::from_elem(loss.tensor.mean().into_data().value[0]);
impl<B: Backend> Metric<Tensor<B, 1>> for LossMetric {
fn update(&mut self, loss: &Tensor<B, 1>) -> Box<dyn MetricState> {
let loss = f64::from_elem(loss.to_data().value[0]);
self.count += 1;
self.total += loss;
self.current = loss;
let name = String::from("Loss");
let running = self.total / self.count as f64;
let raw_running = format!("{running}");
let raw_current = format!("{}", self.current);
let formatted = format!("running {:.3} current {:.3}", running, self.current);
Box::new(RunningMetricResult {
name,
formatted,
raw_running,
raw_current,
})
self.state
.update(loss, 1, FormatOptions::new("Loss").precision(4))
}
fn clear(&mut self) {
self.count = 0;
self.total = 0.0;
self.current = 0.0;
self.state.reset()
}
}
impl<B: Backend> Numeric for LossMetric<B> {
fn value(&self) -> f64 {
self.state.value()
}
}

View File

@ -1,4 +1,5 @@
pub mod dashboard;
pub mod state;
mod acc;
mod base;

View File

@ -0,0 +1,101 @@
use super::{MetricEntry, Numeric};
/// Usefull utility to implement numeric [metrics](crate::train::metric::Metric).
///
/// # Notes
///
/// The numeric metric store values inside floats.
/// Even if some metric are integers, their mean are floats.
pub struct NumericMetricState {
sum: f64,
count: usize,
current: f64,
}
/// Formatting options for the [numeric metric state](NumericMetricState).
pub struct FormatOptions {
name: String,
unit: Option<String>,
precision: Option<usize>,
}
impl FormatOptions {
/// Create the [formatting options](FormatOptions) with a name.
pub fn new(name: &str) -> Self {
Self {
name: name.to_string(),
unit: None,
precision: None,
}
}
/// Specify the metric unit.
pub fn unit(mut self, unit: &str) -> Self {
self.unit = Some(unit.to_string());
self
}
/// Specify the floating point precision.
pub fn precision(mut self, precision: usize) -> Self {
self.precision = Some(precision);
self
}
}
impl NumericMetricState {
/// Create a new [numeric metric state](NumericMetricState).
pub fn new() -> Self {
Self {
sum: 0.0,
count: 0,
current: f64::NAN,
}
}
/// Reset the state.
pub fn reset(&mut self) {
self.sum = 0.0;
self.count = 0;
self.current = f64::NAN;
}
/// Update the state.
pub fn update(&mut self, value: f64, batch_size: usize, format: FormatOptions) -> MetricEntry {
self.sum += value * batch_size as f64;
self.count += batch_size;
self.current = value;
let value_current = value;
let value_running = self.sum / self.count as f64;
let serialized = value_current.to_string();
let (formatted_current, formatted_running) = match format.precision {
Some(precision) => (
format!("{value_current:.0$}", precision),
format!("{value_running:.0$}", precision),
),
None => (format!("{value_current}"), format!("{value_running}")),
};
let formatted = match format.unit {
Some(unit) => {
format!("Running {formatted_running} {unit} - Current {formatted_current} {unit}")
}
None => format!("Running {formatted_running} - Current {formatted_current}"),
};
MetricEntry::new(format.name, formatted, serialized)
}
}
impl Numeric for NumericMetricState {
fn value(&self) -> f64 {
self.current
}
}
impl Default for NumericMetricState {
fn default() -> Self {
Self::new()
}
}