Add warmup logic when calculating eta (#923)

This commit is contained in:
Nathaniel Simard 2023-11-03 08:57:09 -04:00 committed by GitHub
parent 2ac348c604
commit dddc138757
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 107 additions and 22 deletions

View File

@ -1,22 +1,20 @@
use crate::renderer::TrainingProgress;
use super::TerminalFrame;
use crate::renderer::TrainingProgress;
use ratatui::{
prelude::{Alignment, Constraint, Direction, Layout, Rect},
style::{Color, Style, Stylize},
text::{Line, Span},
widgets::{Block, Borders, Gauge, Paragraph},
};
use std::time::Instant;
use std::time::{Duration, Instant};
/// Simple progress bar for the training.
///
/// We currently ignore the time taken for the validation part.
pub(crate) struct ProgressBarState {
progress_train: f64, // Progress for total training.
progress_train_for_eta: f64, // Progress considering the starting epoch.
progress_train: f64, // Progress for total training.
starting_epoch: usize,
started: Instant,
estimate: ProgressEstimate,
}
const MINUTE: u64 = 60;
@ -27,15 +25,14 @@ impl ProgressBarState {
pub fn new(checkpoint: Option<usize>) -> Self {
Self {
progress_train: 0.0,
progress_train_for_eta: 0.0,
started: Instant::now(),
estimate: ProgressEstimate::new(),
starting_epoch: checkpoint.unwrap_or(0),
}
}
/// Update the training progress.
pub(crate) fn update_train(&mut self, progress: &TrainingProgress) {
self.progress_train = calculate_progress(progress, 0);
self.progress_train_for_eta = calculate_progress(progress, self.starting_epoch);
self.progress_train = calculate_progress(progress, 0, 0);
self.estimate.update(progress, self.starting_epoch);
}
/// Update the validation progress.
@ -45,15 +42,11 @@ impl ProgressBarState {
/// Create a view for the current progress.
pub(crate) fn view(&self) -> ProgressBarView {
let eta = self.started.elapsed();
let total_estimated = (eta.as_secs() as f64) / self.progress_train_for_eta;
const NO_ETA: &str = "---";
let eta = if total_estimated.is_normal() {
let remaining = 1.0 - self.progress_train_for_eta;
let eta = (total_estimated * remaining) as u64;
format_eta(eta)
} else {
"---".to_string()
let eta = match self.estimate.secs() {
Some(eta) => format_eta(eta),
None => NO_ETA.to_string(),
};
ProgressBarView::new(self.progress_train, eta)
}
@ -105,15 +98,87 @@ impl ProgressBarView {
}
}
fn calculate_progress(progress: &TrainingProgress, starting_epoch: usize) -> f64 {
struct ProgressEstimate {
started: Instant,
started_after_warmup: Option<Instant>,
warmup_num_items: usize,
progress: f64,
}
impl ProgressEstimate {
fn new() -> Self {
Self {
started: Instant::now(),
started_after_warmup: None,
warmup_num_items: 0,
progress: 0.0,
}
}
fn secs(&self) -> Option<u64> {
let eta = match self.started_after_warmup {
Some(started) => started.elapsed(),
None => return None,
};
let total_estimated = (eta.as_secs() as f64) / self.progress;
if total_estimated.is_normal() {
let remaining = 1.0 - self.progress;
let eta = (total_estimated * remaining) as u64;
Some(eta)
} else {
None
}
}
fn update(&mut self, progress: &TrainingProgress, starting_epoch: usize) {
if self.started_after_warmup.is_some() {
self.progress = calculate_progress(progress, starting_epoch, self.warmup_num_items);
return;
}
const WARMUP_NUM_ITERATION: usize = 10;
// When the training has started since 30 seconds.
if self.started.elapsed() > Duration::from_secs(30) {
self.init(progress, starting_epoch);
return;
}
// When the training has started since at least 10 seconds and completed 10 iterations.
if progress.iteration >= WARMUP_NUM_ITERATION
&& self.started.elapsed() > Duration::from_secs(10)
{
self.init(progress, starting_epoch);
}
}
fn init(&mut self, progress: &TrainingProgress, starting_epoch: usize) {
let epoch = progress.epoch - starting_epoch;
let epoch_items = (epoch - 1) * progress.progress.items_total;
let iteration_items = progress.progress.items_processed;
self.warmup_num_items = epoch_items + iteration_items;
self.started_after_warmup = Some(Instant::now());
self.progress = calculate_progress(progress, starting_epoch, self.warmup_num_items);
}
}
fn calculate_progress(
progress: &TrainingProgress,
starting_epoch: usize,
ignore_num_items: usize,
) -> f64 {
let epoch_total = progress.epoch_total - starting_epoch;
let epoch = progress.epoch - starting_epoch;
let total_items = progress.progress.items_total * epoch_total;
let epoch_items = (epoch - 1) * progress.progress.items_total;
let iteration_items = progress.progress.items_processed as f64;
let iteration_items = progress.progress.items_processed;
let num_items = epoch_items + iteration_items - ignore_num_items;
(epoch_items as f64 + iteration_items) / total_items as f64
num_items as f64 / total_items as f64
}
fn format_eta(eta_secs: u64) -> String {
@ -171,9 +236,29 @@ mod tests {
};
let starting_epoch = 8;
let progress = calculate_progress(&progress, starting_epoch);
let progress = calculate_progress(&progress, starting_epoch, 0);
// Two epochs remaining while the first is half done.
assert_eq!(0.25, progress);
}
#[test]
fn calculate_progress_for_eta_with_warmup() {
let half = Progress {
items_processed: 110,
items_total: 1000,
};
let progress = TrainingProgress {
progress: half,
epoch: 9,
epoch_total: 10,
iteration: 500,
};
let starting_epoch = 8;
let progress = calculate_progress(&progress, starting_epoch, 10);
// Two epochs remaining while the first is half done.
assert_eq!(0.05, progress);
}
}