mirror of https://github.com/tracel-ai/burn.git
Add warmup logic when calculating eta (#923)
This commit is contained in:
parent
2ac348c604
commit
dddc138757
|
@ -1,22 +1,20 @@
|
|||
use crate::renderer::TrainingProgress;
|
||||
|
||||
use super::TerminalFrame;
|
||||
use crate::renderer::TrainingProgress;
|
||||
use ratatui::{
|
||||
prelude::{Alignment, Constraint, Direction, Layout, Rect},
|
||||
style::{Color, Style, Stylize},
|
||||
text::{Line, Span},
|
||||
widgets::{Block, Borders, Gauge, Paragraph},
|
||||
};
|
||||
use std::time::Instant;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
/// Simple progress bar for the training.
|
||||
///
|
||||
/// We currently ignore the time taken for the validation part.
|
||||
pub(crate) struct ProgressBarState {
|
||||
progress_train: f64, // Progress for total training.
|
||||
progress_train_for_eta: f64, // Progress considering the starting epoch.
|
||||
progress_train: f64, // Progress for total training.
|
||||
starting_epoch: usize,
|
||||
started: Instant,
|
||||
estimate: ProgressEstimate,
|
||||
}
|
||||
|
||||
const MINUTE: u64 = 60;
|
||||
|
@ -27,15 +25,14 @@ impl ProgressBarState {
|
|||
pub fn new(checkpoint: Option<usize>) -> Self {
|
||||
Self {
|
||||
progress_train: 0.0,
|
||||
progress_train_for_eta: 0.0,
|
||||
started: Instant::now(),
|
||||
estimate: ProgressEstimate::new(),
|
||||
starting_epoch: checkpoint.unwrap_or(0),
|
||||
}
|
||||
}
|
||||
/// Update the training progress.
|
||||
pub(crate) fn update_train(&mut self, progress: &TrainingProgress) {
|
||||
self.progress_train = calculate_progress(progress, 0);
|
||||
self.progress_train_for_eta = calculate_progress(progress, self.starting_epoch);
|
||||
self.progress_train = calculate_progress(progress, 0, 0);
|
||||
self.estimate.update(progress, self.starting_epoch);
|
||||
}
|
||||
|
||||
/// Update the validation progress.
|
||||
|
@ -45,15 +42,11 @@ impl ProgressBarState {
|
|||
|
||||
/// Create a view for the current progress.
|
||||
pub(crate) fn view(&self) -> ProgressBarView {
|
||||
let eta = self.started.elapsed();
|
||||
let total_estimated = (eta.as_secs() as f64) / self.progress_train_for_eta;
|
||||
const NO_ETA: &str = "---";
|
||||
|
||||
let eta = if total_estimated.is_normal() {
|
||||
let remaining = 1.0 - self.progress_train_for_eta;
|
||||
let eta = (total_estimated * remaining) as u64;
|
||||
format_eta(eta)
|
||||
} else {
|
||||
"---".to_string()
|
||||
let eta = match self.estimate.secs() {
|
||||
Some(eta) => format_eta(eta),
|
||||
None => NO_ETA.to_string(),
|
||||
};
|
||||
ProgressBarView::new(self.progress_train, eta)
|
||||
}
|
||||
|
@ -105,15 +98,87 @@ impl ProgressBarView {
|
|||
}
|
||||
}
|
||||
|
||||
fn calculate_progress(progress: &TrainingProgress, starting_epoch: usize) -> f64 {
|
||||
struct ProgressEstimate {
|
||||
started: Instant,
|
||||
started_after_warmup: Option<Instant>,
|
||||
warmup_num_items: usize,
|
||||
progress: f64,
|
||||
}
|
||||
|
||||
impl ProgressEstimate {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
started: Instant::now(),
|
||||
started_after_warmup: None,
|
||||
warmup_num_items: 0,
|
||||
progress: 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
fn secs(&self) -> Option<u64> {
|
||||
let eta = match self.started_after_warmup {
|
||||
Some(started) => started.elapsed(),
|
||||
None => return None,
|
||||
};
|
||||
|
||||
let total_estimated = (eta.as_secs() as f64) / self.progress;
|
||||
|
||||
if total_estimated.is_normal() {
|
||||
let remaining = 1.0 - self.progress;
|
||||
let eta = (total_estimated * remaining) as u64;
|
||||
Some(eta)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn update(&mut self, progress: &TrainingProgress, starting_epoch: usize) {
|
||||
if self.started_after_warmup.is_some() {
|
||||
self.progress = calculate_progress(progress, starting_epoch, self.warmup_num_items);
|
||||
return;
|
||||
}
|
||||
|
||||
const WARMUP_NUM_ITERATION: usize = 10;
|
||||
|
||||
// When the training has started since 30 seconds.
|
||||
if self.started.elapsed() > Duration::from_secs(30) {
|
||||
self.init(progress, starting_epoch);
|
||||
return;
|
||||
}
|
||||
|
||||
// When the training has started since at least 10 seconds and completed 10 iterations.
|
||||
if progress.iteration >= WARMUP_NUM_ITERATION
|
||||
&& self.started.elapsed() > Duration::from_secs(10)
|
||||
{
|
||||
self.init(progress, starting_epoch);
|
||||
}
|
||||
}
|
||||
|
||||
fn init(&mut self, progress: &TrainingProgress, starting_epoch: usize) {
|
||||
let epoch = progress.epoch - starting_epoch;
|
||||
let epoch_items = (epoch - 1) * progress.progress.items_total;
|
||||
let iteration_items = progress.progress.items_processed;
|
||||
|
||||
self.warmup_num_items = epoch_items + iteration_items;
|
||||
self.started_after_warmup = Some(Instant::now());
|
||||
self.progress = calculate_progress(progress, starting_epoch, self.warmup_num_items);
|
||||
}
|
||||
}
|
||||
|
||||
fn calculate_progress(
|
||||
progress: &TrainingProgress,
|
||||
starting_epoch: usize,
|
||||
ignore_num_items: usize,
|
||||
) -> f64 {
|
||||
let epoch_total = progress.epoch_total - starting_epoch;
|
||||
let epoch = progress.epoch - starting_epoch;
|
||||
|
||||
let total_items = progress.progress.items_total * epoch_total;
|
||||
let epoch_items = (epoch - 1) * progress.progress.items_total;
|
||||
let iteration_items = progress.progress.items_processed as f64;
|
||||
let iteration_items = progress.progress.items_processed;
|
||||
let num_items = epoch_items + iteration_items - ignore_num_items;
|
||||
|
||||
(epoch_items as f64 + iteration_items) / total_items as f64
|
||||
num_items as f64 / total_items as f64
|
||||
}
|
||||
|
||||
fn format_eta(eta_secs: u64) -> String {
|
||||
|
@ -171,9 +236,29 @@ mod tests {
|
|||
};
|
||||
|
||||
let starting_epoch = 8;
|
||||
let progress = calculate_progress(&progress, starting_epoch);
|
||||
let progress = calculate_progress(&progress, starting_epoch, 0);
|
||||
|
||||
// Two epochs remaining while the first is half done.
|
||||
assert_eq!(0.25, progress);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn calculate_progress_for_eta_with_warmup() {
|
||||
let half = Progress {
|
||||
items_processed: 110,
|
||||
items_total: 1000,
|
||||
};
|
||||
let progress = TrainingProgress {
|
||||
progress: half,
|
||||
epoch: 9,
|
||||
epoch_total: 10,
|
||||
iteration: 500,
|
||||
};
|
||||
|
||||
let starting_epoch = 8;
|
||||
let progress = calculate_progress(&progress, starting_epoch, 10);
|
||||
|
||||
// Two epochs remaining while the first is half done.
|
||||
assert_eq!(0.05, progress);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue