mirror of https://github.com/tracel-ai/burn.git
feat: Make RetroForward public (#1905)
This commit is contained in:
parent
96468fc3c9
commit
f8a7c54272
|
@ -1,7 +1,9 @@
|
|||
/// Checkpointer module
|
||||
pub mod base;
|
||||
pub(crate) mod builder;
|
||||
pub(crate) mod retro_forward;
|
||||
pub(crate) mod state;
|
||||
/// RetroForward module
|
||||
pub mod retro_forward;
|
||||
/// BackwardStates module
|
||||
pub mod state;
|
||||
/// CheckpointStrategy module
|
||||
pub mod strategy;
|
||||
|
|
|
@ -6,8 +6,9 @@ use super::state::{BackwardStates, State};
|
|||
|
||||
/// Definition of the forward function of a node, called during retropropagation only.
|
||||
/// This is different from the normal forward function because it reads and writes from
|
||||
/// the [InnerStates] map instead of having a clear function signature.
|
||||
/// the [BackwardStates] map instead of having a clear function signature.
|
||||
pub trait RetroForward: Debug + Send + 'static {
|
||||
/// Applies the forward pass for retropropagation.
|
||||
fn forward(&self, states: &mut BackwardStates, out_node: NodeID);
|
||||
}
|
||||
|
||||
|
|
|
@ -60,16 +60,16 @@ impl State {
|
|||
}
|
||||
|
||||
#[derive(new, Default, Debug)]
|
||||
/// Links [NodeID]s to their current [State]
|
||||
/// Links [NodeID]s to their current state
|
||||
pub struct BackwardStates {
|
||||
map: HashMap<NodeID, State>,
|
||||
}
|
||||
|
||||
impl BackwardStates {
|
||||
/// Returns the output in the [State] of the given [NodeID],
|
||||
/// Returns the output in the state of the given [NodeID],
|
||||
/// and decrements the number of times this state is required.
|
||||
/// This function always gives ownership of the output, but will clone it if needed for further uses.
|
||||
pub(crate) fn get_state<T>(&mut self, node_id: &NodeID) -> T
|
||||
pub fn get_state<T>(&mut self, node_id: &NodeID) -> T
|
||||
where
|
||||
T: Clone + Send + 'static,
|
||||
{
|
||||
|
@ -117,7 +117,8 @@ impl BackwardStates {
|
|||
self.map.insert(node_id, state);
|
||||
}
|
||||
|
||||
pub(crate) fn save<T>(&mut self, node_id: NodeID, saved_output: T)
|
||||
/// Saves the output to the state of the given [NodeID].
|
||||
pub fn save<T>(&mut self, node_id: NodeID, saved_output: T)
|
||||
where
|
||||
T: Clone + Send + 'static,
|
||||
{
|
||||
|
|
Loading…
Reference in New Issue