diff --git a/crates/burn-train/src/checkpoint/strategy/metric.rs b/crates/burn-train/src/checkpoint/strategy/metric.rs index f2aa58efe..7a1cd6085 100644 --- a/crates/burn-train/src/checkpoint/strategy/metric.rs +++ b/crates/burn-train/src/checkpoint/strategy/metric.rs @@ -76,7 +76,7 @@ mod tests { }, TestBackend, }; - use std::sync::Arc; + use std::rc::Rc; use super::*; @@ -93,7 +93,7 @@ mod tests { store.register_logger_train(InMemoryMetricLogger::default()); // Register the loss metric. metrics.register_train_metric_numeric(LossMetric::::new()); - let store = Arc::new(EventStoreClient::new(store)); + let store = Rc::new(EventStoreClient::new(store)); let mut processor = MinimalEventProcessor::new(metrics, store.clone()); // Two points for the first epoch. Mean 0.75 diff --git a/crates/burn-train/src/learner/base.rs b/crates/burn-train/src/learner/base.rs index 0534b0f4d..bd6128681 100644 --- a/crates/burn-train/src/learner/base.rs +++ b/crates/burn-train/src/learner/base.rs @@ -8,6 +8,7 @@ use burn_core::module::Module; use burn_core::optim::Optimizer; use burn_core::tensor::backend::Backend; use burn_core::tensor::Device; +use std::rc::Rc; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; @@ -26,7 +27,7 @@ pub struct Learner { pub(crate) interrupter: TrainingInterrupter, pub(crate) early_stopping: Option>, pub(crate) event_processor: LC::EventProcessor, - pub(crate) event_store: Arc, + pub(crate) event_store: Rc, pub(crate) summary: Option, } diff --git a/crates/burn-train/src/learner/builder.rs b/crates/burn-train/src/learner/builder.rs index 394a842ab..cf8c5a92a 100644 --- a/crates/burn-train/src/learner/builder.rs +++ b/crates/burn-train/src/learner/builder.rs @@ -1,5 +1,5 @@ use std::collections::HashSet; -use std::sync::Arc; +use std::rc::Rc; use super::log::install_file_logger; use super::Learner; @@ -328,7 +328,7 @@ where )); } - let event_store = Arc::new(EventStoreClient::new(self.event_store)); + let event_store = Rc::new(EventStoreClient::new(self.event_store)); let event_processor = FullEventProcessor::new(self.metrics, renderer, event_store.clone()); let checkpointer = self.checkpointers.map(|(model, optim, scheduler)| { diff --git a/crates/burn-train/src/learner/early_stopping.rs b/crates/burn-train/src/learner/early_stopping.rs index c66ea9eed..db3dc478a 100644 --- a/crates/burn-train/src/learner/early_stopping.rs +++ b/crates/burn-train/src/learner/early_stopping.rs @@ -113,7 +113,7 @@ impl MetricEarlyStoppingStrategy { #[cfg(test)] mod tests { - use std::sync::Arc; + use std::rc::Rc; use crate::{ logger::InMemoryMetricLogger, @@ -197,7 +197,7 @@ mod tests { store.register_logger_train(InMemoryMetricLogger::default()); metrics.register_train_metric_numeric(LossMetric::::new()); - let store = Arc::new(EventStoreClient::new(store)); + let store = Rc::new(EventStoreClient::new(store)); let mut processor = MinimalEventProcessor::new(metrics, store.clone()); let mut epoch = 1; diff --git a/crates/burn-train/src/metric/processor/full.rs b/crates/burn-train/src/metric/processor/full.rs index b25870dfb..9f76588c8 100644 --- a/crates/burn-train/src/metric/processor/full.rs +++ b/crates/burn-train/src/metric/processor/full.rs @@ -1,7 +1,7 @@ use super::{Event, EventProcessor, Metrics}; use crate::metric::store::EventStoreClient; use crate::renderer::{MetricState, MetricsRenderer}; -use std::sync::Arc; +use std::rc::Rc; /// An [event processor](EventProcessor) that handles: /// - Computing and storing metrics in an [event store](crate::metric::store::EventStore). @@ -9,14 +9,14 @@ use std::sync::Arc; pub struct FullEventProcessor { metrics: Metrics, renderer: Box, - store: Arc, + store: Rc, } impl FullEventProcessor { pub(crate) fn new( metrics: Metrics, renderer: Box, - store: Arc, + store: Rc, ) -> Self { Self { metrics, diff --git a/crates/burn-train/src/metric/processor/minimal.rs b/crates/burn-train/src/metric/processor/minimal.rs index bb60713e4..e95d2e8b4 100644 --- a/crates/burn-train/src/metric/processor/minimal.rs +++ b/crates/burn-train/src/metric/processor/minimal.rs @@ -1,13 +1,13 @@ use super::{Event, EventProcessor, Metrics}; use crate::metric::store::EventStoreClient; -use std::sync::Arc; +use std::rc::Rc; /// An [event processor](EventProcessor) that handles: /// - Computing and storing metrics in an [event store](crate::metric::store::EventStore). #[derive(new)] pub(crate) struct MinimalEventProcessor { metrics: Metrics, - store: Arc, + store: Rc, } impl EventProcessor for MinimalEventProcessor {