Replaced `str` with `Path` (#1919)

* replaced str with Path

* minor change (Path to AsRef<Path>)

* fixed clippy lint
This commit is contained in:
Roy Varon 2024-06-30 02:17:59 +03:00 committed by GitHub
parent 98a58c867d
commit a7efc102b9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 90 additions and 66 deletions

View File

@ -1,3 +1,5 @@
use std::path::{Path, PathBuf};
use super::{Checkpointer, CheckpointerError};
use burn_core::{
record::{FileRecorder, Record},
@ -6,7 +8,7 @@ use burn_core::{
/// The file checkpointer.
pub struct FileCheckpointer<FR> {
directory: String,
directory: PathBuf,
name: String,
recorder: FR,
}
@ -19,17 +21,19 @@ impl<FR> FileCheckpointer<FR> {
/// * `recorder` - The file recorder.
/// * `directory` - The directory to save the checkpoints.
/// * `name` - The name of the checkpoint.
pub fn new(recorder: FR, directory: &str, name: &str) -> Self {
pub fn new(recorder: FR, directory: impl AsRef<Path>, name: &str) -> Self {
let directory = directory.as_ref();
std::fs::create_dir_all(directory).ok();
Self {
directory: directory.to_string(),
directory: directory.to_path_buf(),
name: name.to_string(),
recorder,
}
}
fn path_for_epoch(&self, epoch: usize) -> String {
format!("{}/{}-{}", self.directory, self.name, epoch)
fn path_for_epoch(&self, epoch: usize) -> PathBuf {
self.directory.join(format!("{}-{}", self.name, epoch))
}
}
@ -41,10 +45,10 @@ where
{
fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError> {
let file_path = self.path_for_epoch(epoch);
log::info!("Saving checkpoint {} to {}", epoch, file_path);
log::info!("Saving checkpoint {} to {}", epoch, file_path.display());
self.recorder
.record(record, file_path.into())
.record(record, file_path)
.map_err(CheckpointerError::RecorderError)?;
Ok(())
@ -52,17 +56,25 @@ where
fn restore(&self, epoch: usize, device: &B::Device) -> Result<R, CheckpointerError> {
let file_path = self.path_for_epoch(epoch);
log::info!("Restoring checkpoint {} from {}", epoch, file_path);
log::info!(
"Restoring checkpoint {} from {}",
epoch,
file_path.display()
);
let record = self
.recorder
.load(file_path.into(), device)
.load(file_path, device)
.map_err(CheckpointerError::RecorderError)?;
Ok(record)
}
fn delete(&self, epoch: usize) -> Result<(), CheckpointerError> {
let file_to_remove = format!("{}.{}", self.path_for_epoch(epoch), FR::file_extension(),);
let file_to_remove = format!(
"{}.{}",
self.path_for_epoch(epoch).display(),
FR::file_extension(),
);
if std::path::Path::new(&file_to_remove).exists() {
log::info!("Removing checkpoint {}", file_to_remove);

View File

@ -1,4 +1,4 @@
use std::path::Path;
use std::path::{Path, PathBuf};
use tracing_core::{Level, LevelFilter};
use tracing_subscriber::filter::filter_fn;
use tracing_subscriber::prelude::*;
@ -12,14 +12,14 @@ pub trait ApplicationLoggerInstaller {
/// This struct is used to install a local file application logger to output logs to a given file path.
pub struct FileApplicationLoggerInstaller {
path: String,
path: PathBuf,
}
impl FileApplicationLoggerInstaller {
/// Create a new file application logger.
pub fn new(path: &str) -> Self {
pub fn new(path: impl AsRef<Path>) -> Self {
Self {
path: path.to_string(),
path: path.as_ref().to_path_buf(),
}
}
}
@ -29,8 +29,9 @@ impl ApplicationLoggerInstaller for FileApplicationLoggerInstaller {
let path = Path::new(&self.path);
let writer = tracing_appender::rolling::never(
path.parent().unwrap_or_else(|| Path::new(".")),
path.file_name()
.unwrap_or_else(|| panic!("The path '{}' to point to a file.", self.path)),
path.file_name().unwrap_or_else(|| {
panic!("The path '{}' to point to a file.", self.path.display())
}),
);
let layer = tracing_subscriber::fmt::layer()
.with_ansi(false)
@ -51,13 +52,14 @@ impl ApplicationLoggerInstaller for FileApplicationLoggerInstaller {
}
let hook = std::panic::take_hook();
let file_path: String = self.path.to_owned();
let file_path = self.path.to_owned();
std::panic::set_hook(Box::new(move |info| {
log::error!("PANIC => {}", info.to_string());
eprintln!(
"=== PANIC ===\nA fatal error happened, you can check the experiment logs here => \
'{file_path}'\n============="
'{}'\n=============",
file_path.display()
);
hook(info);
}));

View File

@ -1,4 +1,5 @@
use std::collections::HashSet;
use std::path::{Path, PathBuf};
use std::rc::Rc;
use super::Learner;
@ -45,7 +46,7 @@ where
)>,
num_epochs: usize,
checkpoint: Option<usize>,
directory: String,
directory: PathBuf,
grad_accumulation: Option<usize>,
devices: Vec<B::Device>,
renderer: Option<Box<dyn MetricsRenderer + 'static>>,
@ -74,12 +75,14 @@ where
/// # Arguments
///
/// * `directory` - The directory to save the checkpoints.
pub fn new(directory: &str) -> Self {
pub fn new(directory: impl AsRef<Path>) -> Self {
let directory = directory.as_ref().to_path_buf();
let experiment_log_file = directory.join("experiment.log");
Self {
num_epochs: 1,
checkpoint: None,
checkpointers: None,
directory: directory.to_string(),
directory,
grad_accumulation: None,
devices: vec![B::Device::default()],
metrics: Metrics::default(),
@ -87,7 +90,7 @@ where
renderer: None,
interrupter: TrainingInterrupter::new(),
tracing_logger: Some(Box::new(FileApplicationLoggerInstaller::new(
format!("{}/experiment.log", directory).as_str(),
experiment_log_file,
))),
num_loggers: 0,
checkpointer_strategy: Box::new(
@ -256,21 +259,12 @@ where
M::Record: 'static,
S::Record: 'static,
{
let checkpointer_model = FileCheckpointer::new(
recorder.clone(),
format!("{}/checkpoint", self.directory).as_str(),
"model",
);
let checkpointer_optimizer = FileCheckpointer::new(
recorder.clone(),
format!("{}/checkpoint", self.directory).as_str(),
"optim",
);
let checkpointer_scheduler: FileCheckpointer<FR> = FileCheckpointer::new(
recorder,
format!("{}/checkpoint", self.directory).as_str(),
"scheduler",
);
let checkpoint_dir = self.directory.join("checkpoint");
let checkpointer_model = FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "model");
let checkpointer_optimizer =
FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "optim");
let checkpointer_scheduler: FileCheckpointer<FR> =
FileCheckpointer::new(recorder, &checkpoint_dir, "scheduler");
self.checkpointers = Some((
AsyncCheckpointer::new(checkpointer_model),
@ -325,17 +319,12 @@ where
let renderer = self.renderer.unwrap_or_else(|| {
Box::new(default_renderer(self.interrupter.clone(), self.checkpoint))
});
let directory = &self.directory;
if self.num_loggers == 0 {
self.event_store
.register_logger_train(FileMetricLogger::new(
format!("{directory}/train").as_str(),
));
.register_logger_train(FileMetricLogger::new(self.directory.join("train")));
self.event_store
.register_logger_valid(FileMetricLogger::new(
format!("{directory}/valid").as_str(),
));
.register_logger_valid(FileMetricLogger::new(self.directory.join("valid")));
}
let event_store = Rc::new(EventStoreClient::new(self.event_store));

View File

@ -1,5 +1,8 @@
use core::cmp::Ordering;
use std::{fmt::Display, path::Path};
use std::{
fmt::Display,
path::{Path, PathBuf},
};
use crate::{
logger::FileMetricLogger,
@ -73,16 +76,20 @@ impl LearnerSummary {
///
/// * `directory` - The directory containing the training artifacts (checkpoints and logs).
/// * `metrics` - The list of metrics to collect for the summary.
pub fn new<S: AsRef<str>>(directory: &str, metrics: &[S]) -> Result<Self, String> {
let directory_path = Path::new(directory);
if !directory_path.exists() {
return Err(format!("Artifact directory does not exist at: {directory}"));
pub fn new<S: AsRef<str>>(directory: impl AsRef<Path>, metrics: &[S]) -> Result<Self, String> {
let directory = directory.as_ref();
if !directory.exists() {
return Err(format!(
"Artifact directory does not exist at: {}",
directory.display()
));
}
let train_dir = directory_path.join("train");
let valid_dir = directory_path.join("valid");
let train_dir = directory.join("train");
let valid_dir = directory.join("valid");
if !train_dir.exists() & !valid_dir.exists() {
return Err(format!(
"No training or validation artifacts found at: {directory}"
"No training or validation artifacts found at: {}",
directory.display()
));
}
@ -219,7 +226,7 @@ impl Display for LearnerSummary {
}
pub(crate) struct LearnerSummaryConfig {
pub(crate) directory: String,
pub(crate) directory: PathBuf,
pub(crate) metrics: Vec<String>,
}

View File

@ -1,5 +1,5 @@
use super::Logger;
use std::{fs::File, io::Write};
use std::{fs::File, io::Write, path::Path};
/// File logger.
pub struct FileLogger {
@ -16,14 +16,21 @@ impl FileLogger {
/// # Returns
///
/// The file logger.
pub fn new(path: &str) -> Self {
pub fn new(path: impl AsRef<Path>) -> Self {
let path = path.as_ref();
let mut options = std::fs::File::options();
let file = options
.write(true)
.truncate(true)
.create(true)
.open(path)
.unwrap_or_else(|err| panic!("Should be able to create the new file '{path}': {err}"));
.unwrap_or_else(|err| {
panic!(
"Should be able to create the new file '{}': {}",
path.display(),
err
)
});
Self { file }
}

View File

@ -1,6 +1,10 @@
use super::{AsyncLogger, FileLogger, InMemoryLogger, Logger};
use crate::metric::{MetricEntry, NumericEntry};
use std::{collections::HashMap, fs};
use std::{
collections::HashMap,
fs,
path::{Path, PathBuf},
};
const EPOCH_PREFIX: &str = "epoch-";
@ -27,7 +31,7 @@ pub trait MetricLogger: Send {
/// The file metric logger.
pub struct FileMetricLogger {
loggers: HashMap<String, AsyncLogger<String>>,
directory: String,
directory: PathBuf,
epoch: usize,
}
@ -41,10 +45,10 @@ impl FileMetricLogger {
/// # Returns
///
/// The file metric logger.
pub fn new(directory: &str) -> Self {
pub fn new(directory: impl AsRef<Path>) -> Self {
Self {
loggers: HashMap::new(),
directory: directory.to_string(),
directory: directory.as_ref().to_path_buf(),
epoch: 1,
}
}
@ -76,15 +80,18 @@ impl FileMetricLogger {
max_epoch
}
fn epoch_directory(&self, epoch: usize) -> String {
format!("{}/{}{}", self.directory, EPOCH_PREFIX, epoch)
fn epoch_directory(&self, epoch: usize) -> PathBuf {
let name = format!("{}{}", EPOCH_PREFIX, epoch);
self.directory.join(name)
}
fn file_path(&self, name: &str, epoch: usize) -> String {
fn file_path(&self, name: &str, epoch: usize) -> PathBuf {
let directory = self.epoch_directory(epoch);
let name = name.replace(' ', "_");
format!("{directory}/{name}.log")
let name = format!("{name}.log");
directory.join(name)
}
fn create_directory(&self, epoch: usize) {
let directory = self.epoch_directory(epoch);
std::fs::create_dir_all(directory).ok();
@ -102,7 +109,7 @@ impl MetricLogger for FileMetricLogger {
self.create_directory(self.epoch);
let file_path = self.file_path(key, self.epoch);
let logger = FileLogger::new(&file_path);
let logger = FileLogger::new(file_path);
let logger = AsyncLogger::new(logger);
self.loggers.insert(key.clone(), logger);