feat: Make RetroForward public (#1905)

This commit is contained in:
phenylshima 2024-06-19 05:44:32 +09:00 committed by GitHub
parent 96468fc3c9
commit f8a7c54272
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 11 additions and 7 deletions

View File

@ -1,7 +1,9 @@
/// Checkpointer module
pub mod base;
pub(crate) mod builder;
pub(crate) mod retro_forward;
pub(crate) mod state;
/// RetroForward module
pub mod retro_forward;
/// BackwardStates module
pub mod state;
/// CheckpointStrategy module
pub mod strategy;

View File

@ -6,8 +6,9 @@ use super::state::{BackwardStates, State};
/// Definition of the forward function of a node, called during retropropagation only.
/// This is different from the normal forward function because it reads and writes from
/// the [InnerStates] map instead of having a clear function signature.
/// the [BackwardStates] map instead of having a clear function signature.
pub trait RetroForward: Debug + Send + 'static {
/// Applies the forward pass for retropropagation.
fn forward(&self, states: &mut BackwardStates, out_node: NodeID);
}

View File

@ -60,16 +60,16 @@ impl State {
}
#[derive(new, Default, Debug)]
/// Links [NodeID]s to their current [State]
/// Links [NodeID]s to their current state
pub struct BackwardStates {
map: HashMap<NodeID, State>,
}
impl BackwardStates {
/// Returns the output in the [State] of the given [NodeID],
/// Returns the output in the state of the given [NodeID],
/// and decrements the number of times this state is required.
/// This function always gives ownership of the output, but will clone it if needed for further uses.
pub(crate) fn get_state<T>(&mut self, node_id: &NodeID) -> T
pub fn get_state<T>(&mut self, node_id: &NodeID) -> T
where
T: Clone + Send + 'static,
{
@ -117,7 +117,8 @@ impl BackwardStates {
self.map.insert(node_id, state);
}
pub(crate) fn save<T>(&mut self, node_id: NodeID, saved_output: T)
/// Saves the output to the state of the given [NodeID].
pub fn save<T>(&mut self, node_id: NodeID, saved_output: T)
where
T: Clone + Send + 'static,
{