diff --git a/crates/burn-train/src/checkpoint/file.rs b/crates/burn-train/src/checkpoint/file.rs index 9dbdda809..be3caa9e2 100644 --- a/crates/burn-train/src/checkpoint/file.rs +++ b/crates/burn-train/src/checkpoint/file.rs @@ -1,3 +1,5 @@ +use std::path::{Path, PathBuf}; + use super::{Checkpointer, CheckpointerError}; use burn_core::{ record::{FileRecorder, Record}, @@ -6,7 +8,7 @@ use burn_core::{ /// The file checkpointer. pub struct FileCheckpointer { - directory: String, + directory: PathBuf, name: String, recorder: FR, } @@ -19,17 +21,19 @@ impl FileCheckpointer { /// * `recorder` - The file recorder. /// * `directory` - The directory to save the checkpoints. /// * `name` - The name of the checkpoint. - pub fn new(recorder: FR, directory: &str, name: &str) -> Self { + pub fn new(recorder: FR, directory: impl AsRef, name: &str) -> Self { + let directory = directory.as_ref(); std::fs::create_dir_all(directory).ok(); Self { - directory: directory.to_string(), + directory: directory.to_path_buf(), name: name.to_string(), recorder, } } - fn path_for_epoch(&self, epoch: usize) -> String { - format!("{}/{}-{}", self.directory, self.name, epoch) + + fn path_for_epoch(&self, epoch: usize) -> PathBuf { + self.directory.join(format!("{}-{}", self.name, epoch)) } } @@ -41,10 +45,10 @@ where { fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError> { let file_path = self.path_for_epoch(epoch); - log::info!("Saving checkpoint {} to {}", epoch, file_path); + log::info!("Saving checkpoint {} to {}", epoch, file_path.display()); self.recorder - .record(record, file_path.into()) + .record(record, file_path) .map_err(CheckpointerError::RecorderError)?; Ok(()) @@ -52,17 +56,25 @@ where fn restore(&self, epoch: usize, device: &B::Device) -> Result { let file_path = self.path_for_epoch(epoch); - log::info!("Restoring checkpoint {} from {}", epoch, file_path); + log::info!( + "Restoring checkpoint {} from {}", + epoch, + file_path.display() + ); let record = self .recorder - .load(file_path.into(), device) + .load(file_path, device) .map_err(CheckpointerError::RecorderError)?; Ok(record) } fn delete(&self, epoch: usize) -> Result<(), CheckpointerError> { - let file_to_remove = format!("{}.{}", self.path_for_epoch(epoch), FR::file_extension(),); + let file_to_remove = format!( + "{}.{}", + self.path_for_epoch(epoch).display(), + FR::file_extension(), + ); if std::path::Path::new(&file_to_remove).exists() { log::info!("Removing checkpoint {}", file_to_remove); diff --git a/crates/burn-train/src/learner/application_logger.rs b/crates/burn-train/src/learner/application_logger.rs index 793ac6ada..face9f04e 100644 --- a/crates/burn-train/src/learner/application_logger.rs +++ b/crates/burn-train/src/learner/application_logger.rs @@ -1,4 +1,4 @@ -use std::path::Path; +use std::path::{Path, PathBuf}; use tracing_core::{Level, LevelFilter}; use tracing_subscriber::filter::filter_fn; use tracing_subscriber::prelude::*; @@ -12,14 +12,14 @@ pub trait ApplicationLoggerInstaller { /// This struct is used to install a local file application logger to output logs to a given file path. pub struct FileApplicationLoggerInstaller { - path: String, + path: PathBuf, } impl FileApplicationLoggerInstaller { /// Create a new file application logger. - pub fn new(path: &str) -> Self { + pub fn new(path: impl AsRef) -> Self { Self { - path: path.to_string(), + path: path.as_ref().to_path_buf(), } } } @@ -29,8 +29,9 @@ impl ApplicationLoggerInstaller for FileApplicationLoggerInstaller { let path = Path::new(&self.path); let writer = tracing_appender::rolling::never( path.parent().unwrap_or_else(|| Path::new(".")), - path.file_name() - .unwrap_or_else(|| panic!("The path '{}' to point to a file.", self.path)), + path.file_name().unwrap_or_else(|| { + panic!("The path '{}' to point to a file.", self.path.display()) + }), ); let layer = tracing_subscriber::fmt::layer() .with_ansi(false) @@ -51,13 +52,14 @@ impl ApplicationLoggerInstaller for FileApplicationLoggerInstaller { } let hook = std::panic::take_hook(); - let file_path: String = self.path.to_owned(); + let file_path = self.path.to_owned(); std::panic::set_hook(Box::new(move |info| { log::error!("PANIC => {}", info.to_string()); eprintln!( "=== PANIC ===\nA fatal error happened, you can check the experiment logs here => \ - '{file_path}'\n=============" + '{}'\n=============", + file_path.display() ); hook(info); })); diff --git a/crates/burn-train/src/learner/builder.rs b/crates/burn-train/src/learner/builder.rs index b0a5e9e5d..1b58f0764 100644 --- a/crates/burn-train/src/learner/builder.rs +++ b/crates/burn-train/src/learner/builder.rs @@ -1,4 +1,5 @@ use std::collections::HashSet; +use std::path::{Path, PathBuf}; use std::rc::Rc; use super::Learner; @@ -45,7 +46,7 @@ where )>, num_epochs: usize, checkpoint: Option, - directory: String, + directory: PathBuf, grad_accumulation: Option, devices: Vec, renderer: Option>, @@ -74,12 +75,14 @@ where /// # Arguments /// /// * `directory` - The directory to save the checkpoints. - pub fn new(directory: &str) -> Self { + pub fn new(directory: impl AsRef) -> Self { + let directory = directory.as_ref().to_path_buf(); + let experiment_log_file = directory.join("experiment.log"); Self { num_epochs: 1, checkpoint: None, checkpointers: None, - directory: directory.to_string(), + directory, grad_accumulation: None, devices: vec![B::Device::default()], metrics: Metrics::default(), @@ -87,7 +90,7 @@ where renderer: None, interrupter: TrainingInterrupter::new(), tracing_logger: Some(Box::new(FileApplicationLoggerInstaller::new( - format!("{}/experiment.log", directory).as_str(), + experiment_log_file, ))), num_loggers: 0, checkpointer_strategy: Box::new( @@ -256,21 +259,12 @@ where M::Record: 'static, S::Record: 'static, { - let checkpointer_model = FileCheckpointer::new( - recorder.clone(), - format!("{}/checkpoint", self.directory).as_str(), - "model", - ); - let checkpointer_optimizer = FileCheckpointer::new( - recorder.clone(), - format!("{}/checkpoint", self.directory).as_str(), - "optim", - ); - let checkpointer_scheduler: FileCheckpointer = FileCheckpointer::new( - recorder, - format!("{}/checkpoint", self.directory).as_str(), - "scheduler", - ); + let checkpoint_dir = self.directory.join("checkpoint"); + let checkpointer_model = FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "model"); + let checkpointer_optimizer = + FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "optim"); + let checkpointer_scheduler: FileCheckpointer = + FileCheckpointer::new(recorder, &checkpoint_dir, "scheduler"); self.checkpointers = Some(( AsyncCheckpointer::new(checkpointer_model), @@ -325,17 +319,12 @@ where let renderer = self.renderer.unwrap_or_else(|| { Box::new(default_renderer(self.interrupter.clone(), self.checkpoint)) }); - let directory = &self.directory; if self.num_loggers == 0 { self.event_store - .register_logger_train(FileMetricLogger::new( - format!("{directory}/train").as_str(), - )); + .register_logger_train(FileMetricLogger::new(self.directory.join("train"))); self.event_store - .register_logger_valid(FileMetricLogger::new( - format!("{directory}/valid").as_str(), - )); + .register_logger_valid(FileMetricLogger::new(self.directory.join("valid"))); } let event_store = Rc::new(EventStoreClient::new(self.event_store)); diff --git a/crates/burn-train/src/learner/summary.rs b/crates/burn-train/src/learner/summary.rs index 4e058e302..5d208ba2d 100644 --- a/crates/burn-train/src/learner/summary.rs +++ b/crates/burn-train/src/learner/summary.rs @@ -1,5 +1,8 @@ use core::cmp::Ordering; -use std::{fmt::Display, path::Path}; +use std::{ + fmt::Display, + path::{Path, PathBuf}, +}; use crate::{ logger::FileMetricLogger, @@ -73,16 +76,20 @@ impl LearnerSummary { /// /// * `directory` - The directory containing the training artifacts (checkpoints and logs). /// * `metrics` - The list of metrics to collect for the summary. - pub fn new>(directory: &str, metrics: &[S]) -> Result { - let directory_path = Path::new(directory); - if !directory_path.exists() { - return Err(format!("Artifact directory does not exist at: {directory}")); + pub fn new>(directory: impl AsRef, metrics: &[S]) -> Result { + let directory = directory.as_ref(); + if !directory.exists() { + return Err(format!( + "Artifact directory does not exist at: {}", + directory.display() + )); } - let train_dir = directory_path.join("train"); - let valid_dir = directory_path.join("valid"); + let train_dir = directory.join("train"); + let valid_dir = directory.join("valid"); if !train_dir.exists() & !valid_dir.exists() { return Err(format!( - "No training or validation artifacts found at: {directory}" + "No training or validation artifacts found at: {}", + directory.display() )); } @@ -219,7 +226,7 @@ impl Display for LearnerSummary { } pub(crate) struct LearnerSummaryConfig { - pub(crate) directory: String, + pub(crate) directory: PathBuf, pub(crate) metrics: Vec, } diff --git a/crates/burn-train/src/logger/file.rs b/crates/burn-train/src/logger/file.rs index 79c23b462..c852c2167 100644 --- a/crates/burn-train/src/logger/file.rs +++ b/crates/burn-train/src/logger/file.rs @@ -1,5 +1,5 @@ use super::Logger; -use std::{fs::File, io::Write}; +use std::{fs::File, io::Write, path::Path}; /// File logger. pub struct FileLogger { @@ -16,14 +16,21 @@ impl FileLogger { /// # Returns /// /// The file logger. - pub fn new(path: &str) -> Self { + pub fn new(path: impl AsRef) -> Self { + let path = path.as_ref(); let mut options = std::fs::File::options(); let file = options .write(true) .truncate(true) .create(true) .open(path) - .unwrap_or_else(|err| panic!("Should be able to create the new file '{path}': {err}")); + .unwrap_or_else(|err| { + panic!( + "Should be able to create the new file '{}': {}", + path.display(), + err + ) + }); Self { file } } diff --git a/crates/burn-train/src/logger/metric.rs b/crates/burn-train/src/logger/metric.rs index 244485cbd..9424e90c1 100644 --- a/crates/burn-train/src/logger/metric.rs +++ b/crates/burn-train/src/logger/metric.rs @@ -1,6 +1,10 @@ use super::{AsyncLogger, FileLogger, InMemoryLogger, Logger}; use crate::metric::{MetricEntry, NumericEntry}; -use std::{collections::HashMap, fs}; +use std::{ + collections::HashMap, + fs, + path::{Path, PathBuf}, +}; const EPOCH_PREFIX: &str = "epoch-"; @@ -27,7 +31,7 @@ pub trait MetricLogger: Send { /// The file metric logger. pub struct FileMetricLogger { loggers: HashMap>, - directory: String, + directory: PathBuf, epoch: usize, } @@ -41,10 +45,10 @@ impl FileMetricLogger { /// # Returns /// /// The file metric logger. - pub fn new(directory: &str) -> Self { + pub fn new(directory: impl AsRef) -> Self { Self { loggers: HashMap::new(), - directory: directory.to_string(), + directory: directory.as_ref().to_path_buf(), epoch: 1, } } @@ -76,15 +80,18 @@ impl FileMetricLogger { max_epoch } - fn epoch_directory(&self, epoch: usize) -> String { - format!("{}/{}{}", self.directory, EPOCH_PREFIX, epoch) + fn epoch_directory(&self, epoch: usize) -> PathBuf { + let name = format!("{}{}", EPOCH_PREFIX, epoch); + self.directory.join(name) } - fn file_path(&self, name: &str, epoch: usize) -> String { + + fn file_path(&self, name: &str, epoch: usize) -> PathBuf { let directory = self.epoch_directory(epoch); let name = name.replace(' ', "_"); - - format!("{directory}/{name}.log") + let name = format!("{name}.log"); + directory.join(name) } + fn create_directory(&self, epoch: usize) { let directory = self.epoch_directory(epoch); std::fs::create_dir_all(directory).ok(); @@ -102,7 +109,7 @@ impl MetricLogger for FileMetricLogger { self.create_directory(self.epoch); let file_path = self.file_path(key, self.epoch); - let logger = FileLogger::new(&file_path); + let logger = FileLogger::new(file_path); let logger = AsyncLogger::new(logger); self.loggers.insert(key.clone(), logger);