Add configurable application logger to learner builder (#1774)

* refactor: add TracingSubscriberLogger trait and FileTracingSubscriberLogger struct

* Remove unused log module and renames, fmt

* Renamed tracing subscriber logger

* renamed to application logger installer

* book learner configuration update update

* fix typo

* unused import
This commit is contained in:
Jonathan Richard 2024-05-16 16:25:33 -04:00 committed by GitHub
parent 7ab2ba1809
commit 8de05e1419
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 89 additions and 64 deletions

View File

@ -31,6 +31,7 @@ The learner builder provides numerous options when it comes to configurations.
| Num Epochs | Set the number of epochs. |
| Devices | Set the devices to be used |
| Checkpoint | Restart training from a checkpoint |
| Application logging | Configure the application logging installer (default is writing to `experiment.log`) |
When the builder is configured at your liking, you can then move forward to build the learner. The
build method requires three inputs: the model, the optimizer and the learning rate scheduler. Note

View File

@ -0,0 +1,67 @@
use std::path::Path;
use tracing_core::{Level, LevelFilter};
use tracing_subscriber::filter::filter_fn;
use tracing_subscriber::prelude::*;
use tracing_subscriber::{registry, Layer};
/// This trait is used to install an application logger.
///
/// Implementors decide where application log records are routed (e.g. a file).
/// The learner builder invokes the installer once before training starts.
pub trait ApplicationLoggerInstaller {
    /// Install the application logger.
    ///
    /// Returns an `Err` with a human-readable message when installation fails
    /// (e.g. when a global logger is already registered).
    fn install(&self) -> Result<(), String>;
}
/// This struct is used to install a local file application logger to output logs to a given file path.
pub struct FileApplicationLoggerInstaller {
    // Destination file path for the log output (e.g. "<dir>/experiment.log").
    path: String,
}
impl FileApplicationLoggerInstaller {
    /// Create a new file application logger that writes to the given path.
    pub fn new(path: &str) -> Self {
        let path = String::from(path);
        Self { path }
    }
}
impl ApplicationLoggerInstaller for FileApplicationLoggerInstaller {
fn install(&self) -> Result<(), String> {
let path = Path::new(&self.path);
let writer = tracing_appender::rolling::never(
path.parent().unwrap_or_else(|| Path::new(".")),
path.file_name()
.unwrap_or_else(|| panic!("The path '{}' to point to a file.", self.path)),
);
let layer = tracing_subscriber::fmt::layer()
.with_ansi(false)
.with_writer(writer)
.with_filter(LevelFilter::INFO)
.with_filter(filter_fn(|m| {
if let Some(path) = m.module_path() {
// The wgpu crate is logging too much, so we skip `info` level.
if path.starts_with("wgpu") && *m.level() >= Level::INFO {
return false;
}
}
true
}));
if registry().with(layer).try_init().is_err() {
return Err("Failed to install the file logger.".to_string());
}
let hook = std::panic::take_hook();
let file_path: String = self.path.to_owned();
std::panic::set_hook(Box::new(move |info| {
log::error!("PANIC => {}", info.to_string());
eprintln!(
"=== PANIC ===\nA fatal error happened, you can check the experiment logs here => \
'{file_path}'\n============="
);
hook(info);
}));
Ok(())
}
}

View File

@ -1,7 +1,6 @@
use std::collections::HashSet;
use std::rc::Rc;
use super::log::install_file_logger;
use super::Learner;
use crate::checkpoint::{
AsyncCheckpointer, CheckpointingStrategy, ComposedCheckpointingStrategy, FileCheckpointer,
@ -15,7 +14,10 @@ use crate::metric::processor::{FullEventProcessor, Metrics};
use crate::metric::store::{Aggregate, Direction, EventStoreClient, LogEventStore, Split};
use crate::metric::{Adaptor, LossMetric, Metric};
use crate::renderer::{default_renderer, MetricsRenderer};
use crate::{LearnerCheckpointer, LearnerSummaryConfig};
use crate::{
ApplicationLoggerInstaller, FileApplicationLoggerInstaller, LearnerCheckpointer,
LearnerSummaryConfig,
};
use burn_core::lr_scheduler::LrScheduler;
use burn_core::module::AutodiffModule;
use burn_core::optim::Optimizer;
@ -50,7 +52,7 @@ where
metrics: Metrics<T, V>,
event_store: LogEventStore,
interrupter: TrainingInterrupter,
log_to_file: bool,
tracing_logger: Option<Box<dyn ApplicationLoggerInstaller>>,
num_loggers: usize,
checkpointer_strategy: Box<dyn CheckpointingStrategy>,
early_stopping: Option<Box<dyn EarlyStoppingStrategy>>,
@ -84,7 +86,9 @@ where
event_store: LogEventStore::default(),
renderer: None,
interrupter: TrainingInterrupter::new(),
log_to_file: true,
tracing_logger: Some(Box::new(FileApplicationLoggerInstaller::new(
format!("{}/experiment.log", directory).as_str(),
))),
num_loggers: 0,
checkpointer_strategy: Box::new(
ComposedCheckpointingStrategy::builder()
@ -233,8 +237,11 @@ where
/// By default, Rust logs are captured and written into
/// `experiment.log`. If disabled, standard Rust log handling
/// will apply.
pub fn log_to_file(mut self, enabled: bool) -> Self {
self.log_to_file = enabled;
pub fn with_application_logger(
mut self,
logger: Option<Box<dyn ApplicationLoggerInstaller>>,
) -> Self {
self.tracing_logger = logger;
self
}
@ -258,7 +265,7 @@ where
format!("{}/checkpoint", self.directory).as_str(),
"optim",
);
let checkpointer_scheduler = FileCheckpointer::new(
let checkpointer_scheduler: FileCheckpointer<FR> = FileCheckpointer::new(
recorder,
format!("{}/checkpoint", self.directory).as_str(),
"scheduler",
@ -309,8 +316,10 @@ where
O::Record: 'static,
S::Record: 'static,
{
if self.log_to_file {
self.init_logger();
if self.tracing_logger.is_some() {
if let Err(e) = self.tracing_logger.as_ref().unwrap().install() {
log::warn!("Failed to install the experiment logger: {}", e);
}
}
let renderer = self.renderer.unwrap_or_else(|| {
Box::new(default_renderer(self.interrupter.clone(), self.checkpoint))
@ -360,9 +369,4 @@ where
summary,
}
}
fn init_logger(&self) {
let file_path = format!("{}/experiment.log", self.directory);
install_file_logger(file_path.as_str());
}
}

View File

@ -1,47 +0,0 @@
use std::path::Path;
use tracing_core::{Level, LevelFilter};
use tracing_subscriber::filter::filter_fn;
use tracing_subscriber::prelude::*;
use tracing_subscriber::{registry, Layer};
/// If a global tracing subscriber is not already configured, set up logging to a file,
/// and add our custom panic hook.
pub(crate) fn install_file_logger(file_path: &str) {
    let path = Path::new(file_path);
    let writer = tracing_appender::rolling::never(
        // Fall back to the current directory when the path has no parent component.
        path.parent().unwrap_or_else(|| Path::new(".")),
        path.file_name()
            .unwrap_or_else(|| panic!("The path '{file_path}' to point to a file.")),
    );
    // Write INFO-and-above records to the file, with ANSI escape codes disabled.
    let layer = tracing_subscriber::fmt::layer()
        .with_ansi(false)
        .with_writer(writer)
        .with_filter(LevelFilter::INFO)
        .with_filter(filter_fn(|m| {
            if let Some(path) = m.module_path() {
                // The wgpu crate is logging too much, so we skip `info` level.
                if path.starts_with("wgpu") && *m.level() >= Level::INFO {
                    return false;
                }
            }
            true
        }));
    // Only install the panic hook when we actually became the global subscriber;
    // otherwise leave the existing hook untouched.
    if registry().with(layer).try_init().is_ok() {
        update_panic_hook(file_path);
    }
}
/// Wrap the current panic hook so panics are also recorded in the experiment
/// log file and the user is pointed at that file on stderr.
fn update_panic_hook(file_path: &str) {
    // Keep the previous hook so its behavior (e.g. backtrace printing) is preserved.
    let hook = std::panic::take_hook();
    let file_path = file_path.to_owned();
    std::panic::set_hook(Box::new(move |info| {
        log::error!("PANIC => {}", info.to_string());
        eprintln!(
            "=== PANIC ===\nA fatal error happened, you can check the experiment logs here => \
             '{file_path}'\n============="
        );
        // Delegate to the original hook last.
        hook(info);
    }));
}

View File

@ -1,3 +1,4 @@
mod application_logger;
mod base;
mod builder;
mod classification;
@ -8,8 +9,7 @@ mod step;
mod summary;
mod train_val;
pub(crate) mod log;
pub use application_logger::*;
pub use base::*;
pub use builder::*;
pub use classification::*;

View File

@ -76,7 +76,7 @@ pub fn run<B: AutodiffBackend>(device: B::Device) {
.devices(vec![device])
.num_epochs(config.num_epochs)
.renderer(CustomRenderer {})
.log_to_file(false);
.with_application_logger(None);
// can be used to interrupt training
let _interrupter = builder.interrupter();