[Breaking] Make Tensor, Module, Optimizer !Sync + Refactor Autodiff (#1575)

Nathaniel Simard 2024-04-04 16:01:17 -04:00 committed by GitHub
parent ce898ff899
commit 1239d9bfa3
51 changed files with 1048 additions and 677 deletions

View File

@ -11,8 +11,9 @@ repository = "https://github.com/tracel-ai/burn/tree/main/burn-autodiff"
version.workspace = true
[features]
default = []
default = ["std"]
export_tests = ["burn-tensor-testgen"]
std = []
[dependencies]
burn-common = { path = "../burn-common", version = "0.13.0" }

View File

@ -1,7 +1,7 @@
use crate::{
checkpoint::strategy::{CheckpointStrategy, NoCheckpointing},
grads::Gradients,
graph::backward::backward,
runtime::AutodiffClient,
tensor::AutodiffTensor,
AutodiffBridge,
};
@ -53,7 +53,9 @@ impl<B: Backend, C: CheckpointStrategy> AutodiffBackend for Autodiff<B, C> {
type Gradients = Gradients;
fn backward<const D: usize>(tensor: AutodiffTensor<B, D>) -> Gradients {
backward(tensor)
let client = tensor.node.client.clone();
AutodiffClient::backward(&client, tensor)
}
fn grad<const D: usize>(
@ -83,7 +85,7 @@ impl<B: Backend, C: CheckpointStrategy> AutodiffBackend for Autodiff<B, C> {
grad: B::FloatTensorPrimitive<D>,
) {
grads.remove(tensor);
grads.register::<B, D>(tensor.node.clone(), grad);
grads.register::<B, D>(tensor.node.id, grad);
}
fn int_inner<const D: usize>(

View File

@ -39,7 +39,7 @@ where
_bridge: PhantomData<Bridge>,
}
#[derive(new, Debug)]
#[derive(new, Debug, Clone)]
struct RetroIntoTarget<B: Backend, Bridge: BackendBridge<B>, const D: usize> {
tensor_id: NodeID,
_backend: PhantomData<B>,
@ -84,9 +84,9 @@ where
_backend: PhantomData,
_bridge: PhantomData,
}
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
.prepare::<C>([tensor.node.clone()])
.memory_bound()
.retro_forward(RetroIntoTarget::<B, Bridge, D>::new(tensor.node.id.clone()))
.retro_forward(RetroIntoTarget::<B, Bridge, D>::new(tensor.node.id))
.parents([&tensor])
.stateless(Bridge::into_target(tensor.primitive, None))
}
@ -101,7 +101,7 @@ where
_bridge: PhantomData<Bridge>,
}
#[derive(new, Debug)]
#[derive(new, Debug, Clone)]
struct RetroFromTarget<B: Backend, Bridge: BackendBridge<B>, const D: usize> {
tensor_id: NodeID,
_backend: PhantomData<B>,
@ -146,9 +146,9 @@ where
_backend: PhantomData,
_bridge: PhantomData,
}
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
.prepare::<C>([tensor.node.clone()])
.memory_bound()
.retro_forward(RetroFromTarget::<B, Bridge, D>::new(tensor.node.id.clone()))
.retro_forward(RetroFromTarget::<B, Bridge, D>::new(tensor.node.id))
.parents([&tensor])
.stateless(Bridge::from_target(tensor.primitive, None))
}

View File

@ -1,22 +1,20 @@
use std::collections::HashMap;
use crate::graph::{NodeID, NodeRef};
use super::{
retro_forward::RetroForwards,
state::{BackwardStates, State},
};
use crate::graph::NodeID;
use std::collections::HashMap;
#[derive(new, Debug)]
/// Links a [NodeID] to the [NodeID]s of its parents in the autodiff graph
pub(crate) struct NodeTree {
map: HashMap<NodeID, NodeRef>,
map: HashMap<NodeID, Vec<NodeID>>,
}
impl NodeTree {
/// Gives the parents of the node in the autodiff graph
pub(crate) fn parents(&self, node_id: &NodeID) -> Option<Vec<NodeID>> {
self.map.get(node_id).map(|node| node.parents.clone())
self.map.get(node_id).cloned()
}
}
@ -33,14 +31,12 @@ impl Checkpointer {
/// or give their pre-computed tensors.
pub fn retrieve_node_output<T>(&mut self, node_id: NodeID) -> T
where
T: Clone + Send + Sync + 'static,
T: Clone + Send + 'static,
{
self.topological_sort(node_id.clone())
.into_iter()
.for_each(|node| {
self.retro_forwards
.execute_retro_forward(node, &mut self.backward_states)
});
self.topological_sort(node_id).into_iter().for_each(|node| {
self.retro_forwards
.execute_retro_forward(node, &mut self.backward_states)
});
self.backward_states.get_state::<T>(&node_id)
}
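
The NodeTree refactor above replaces the map's NodeRef values with plain parent-id lists, so checkpoint retrieval no longer touches Arc-backed node metadata. A minimal, self-contained sketch of the new shape, with a u64 stand-in for the crate's NodeID (illustrative only, not the crate's pub(crate) API):

use std::collections::HashMap;

type NodeID = u64; // stand-in for the crate's NodeID

struct NodeTree {
    // Each node id maps directly to the ids of its parents.
    map: HashMap<NodeID, Vec<NodeID>>,
}

impl NodeTree {
    fn parents(&self, node_id: &NodeID) -> Option<Vec<NodeID>> {
        self.map.get(node_id).cloned()
    }
}

fn main() {
    let mut map = HashMap::new();
    map.insert(3, vec![1, 2]); // node 3 was produced from nodes 1 and 2
    let tree = NodeTree { map };
    assert_eq!(tree.parents(&3), Some(vec![1, 2]));
    assert_eq!(tree.parents(&7), None); // unknown or untracked node
}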

View File

@ -1,11 +1,9 @@
use std::{any::Any, collections::HashMap, sync::Arc};
use burn_tensor::backend::Backend;
use crate::{
graph::{ComputingProperty, NodeID, NodeRef, NodeSteps},
graph::{ComputingProperty, NodeID, NodeSteps},
tensor::AutodiffTensor,
};
use burn_tensor::backend::Backend;
use std::{any::Any, collections::HashMap, sync::Arc};
use super::{
base::{Checkpointer, NodeTree},
@ -20,31 +18,34 @@ pub enum CheckpointingAction {
/// The node's already computed output should be saved
Computed {
/// The node
node_ref: NodeRef,
node_id: NodeID,
/// The node's output
state_content: Box<dyn Any + Send + Sync>,
state_content: Box<dyn Any + Send>,
},
/// The node should recompute itself when asked
Recompute {
/// The node
node_ref: NodeRef,
node_id: NodeID,
/// How the node should recompute itself
retro_forward: Arc<dyn RetroForward>,
},
}
// TODO: Remove this once a proper client/server separation is in place.
unsafe impl Send for CheckpointingAction {}
impl CheckpointingAction {
/// Utility function to access the id of the checkpointing action's node
pub fn id(&self) -> NodeID {
match self {
CheckpointingAction::Computed {
node_ref,
node_id: node_ref,
state_content: _,
} => node_ref.id.clone(),
} => *node_ref,
CheckpointingAction::Recompute {
node_ref,
node_id: node_ref,
retro_forward: _,
} => node_ref.id.clone(),
} => *node_ref,
}
}
}
@ -83,13 +84,13 @@ impl CheckpointerBuilder {
match &tensor.node.properties {
ComputingProperty::ComputeBound | ComputingProperty::Ambiguous => {
action_list.push(CheckpointingAction::Computed {
node_ref: tensor.node.clone(),
node_id: tensor.node.id,
state_content: Box::new(tensor.primitive.clone()),
})
}
ComputingProperty::MemoryBound { retro_forward } => {
action_list.push(CheckpointingAction::Recompute {
node_ref: tensor.node.clone(),
node_id: tensor.node.id,
retro_forward: retro_forward.clone(),
})
}
@ -105,10 +106,6 @@ impl CheckpointerBuilder {
}
}
pub(crate) fn len(&self) -> usize {
self.explicit_actions.len() + self.backup_actions.len()
}
pub(crate) fn build(self, graph: &NodeSteps) -> Checkpointer {
let node_tree = self.make_tree(graph);
let mut backward_states_map = HashMap::new();
@ -143,11 +140,11 @@ impl CheckpointerBuilder {
{
match action {
CheckpointingAction::Computed {
node_ref,
node_id: node_ref,
state_content: _,
} => stop_nodes.push(node_ref.id.clone()),
} => stop_nodes.push(*node_ref),
CheckpointingAction::Recompute {
node_ref: _,
node_id: _,
retro_forward: _,
} => {}
}
@ -165,10 +162,10 @@ impl CheckpointerBuilder {
for action in self.explicit_actions.iter() {
match action {
CheckpointingAction::Computed {
node_ref,
node_id: node_ref,
state_content: _,
} => {
let id = node_ref.id.clone();
let id = *node_ref;
match n_required_map.remove(&id) {
Some(n) => {
n_required_map.insert(id, n + 1);
@ -179,10 +176,10 @@ impl CheckpointerBuilder {
};
}
CheckpointingAction::Recompute {
node_ref,
node_id: node_ref,
retro_forward: _,
} => {
let id = node_ref.id.clone();
let id = *node_ref;
Self::update_n_required_of_parents(
id,
&mut n_required_map,
@ -229,13 +226,13 @@ impl CheckpointerBuilder {
match action {
CheckpointingAction::Computed {
node_ref: _,
node_id: _,
state_content,
} => {
self.checkpoint_compute(backward_states_map, node_id, state_content, n_required)
}
CheckpointingAction::Recompute {
node_ref: _,
node_id: _,
retro_forward,
} => self.checkpoint_lazy(
backward_states_map,
@ -251,7 +248,7 @@ impl CheckpointerBuilder {
fn make_tree(&self, graph: &NodeSteps) -> NodeTree {
let mut tree = HashMap::default();
for (id, step) in graph {
tree.insert(id.clone(), step.node());
tree.insert(*id, step.parents());
}
NodeTree::new(tree)
}
@ -267,7 +264,7 @@ impl CheckpointerBuilder {
n_required_map.insert(id, n + 1);
}
None => {
n_required_map.insert(id.clone(), 1);
n_required_map.insert(id, 1);
if !stop_nodes.contains(&id) {
if let Some(parents) = node_tree.parents(&id) {
for p in parents {
@ -288,7 +285,7 @@ impl CheckpointerBuilder {
&self,
backward_states_map: &mut HashMap<NodeID, State>,
node_id: NodeID,
state_content: Box<dyn Any + Send + Sync>,
state_content: Box<dyn Any + Send>,
n_required: usize,
) {
backward_states_map.insert(
@ -308,7 +305,7 @@ impl CheckpointerBuilder {
retro_forward: Arc<dyn RetroForward>,
n_required: usize,
) {
retro_forward_map.insert(node_id.clone(), retro_forward);
backward_states_map.insert(node_id.clone(), State::Recompute { n_required });
retro_forward_map.insert(node_id, retro_forward);
backward_states_map.insert(node_id, State::Recompute { n_required });
}
}

View File

@ -7,7 +7,7 @@ use super::state::{BackwardStates, State};
/// Definition of the forward function of a node, called during backpropagation only.
/// It differs from the normal forward function in that it reads its inputs from and
/// writes its output to the [BackwardStates] map instead of using an explicit function signature.
pub trait RetroForward: Debug + Send + Sync + 'static {
pub trait RetroForward: Debug + Send + 'static {
fn forward(&self, states: &mut BackwardStates, out_node: NodeID);
}
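
For context, a hedged sketch of what implementing RetroForward looks like now that the Sync bound is gone: a memory-bound node recomputes its output from checkpointed inputs during the backward pass. BackwardStates is simplified to a plain f32 map and RetroDouble is a hypothetical op, not one from the crate:

use std::collections::HashMap;
use std::fmt::Debug;

type NodeID = u64; // stand-in for the crate's NodeID

// Simplified stand-in: the real BackwardStates stores type-erased outputs.
#[derive(Default)]
struct BackwardStates {
    outputs: HashMap<NodeID, f32>,
}

trait RetroForward: Debug + Send + 'static {
    fn forward(&self, states: &mut BackwardStates, out_node: NodeID);
}

// Hypothetical op that recomputes `out = input * 2.0` from its checkpointed input.
#[derive(Debug)]
struct RetroDouble {
    input_id: NodeID,
}

impl RetroForward for RetroDouble {
    fn forward(&self, states: &mut BackwardStates, out_node: NodeID) {
        let input = states.outputs[&self.input_id];
        states.outputs.insert(out_node, input * 2.0);
    }
}

fn main() {
    let mut states = BackwardStates::default();
    states.outputs.insert(0, 21.0); // checkpointed input of node 0
    RetroDouble { input_id: 0 }.forward(&mut states, 1);
    assert_eq!(states.outputs[&1], 42.0);
}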
@ -31,7 +31,7 @@ impl RetroForwards {
{
// A retro forward is only ever executed once, since afterwards its state is computed
let retro_forward = self.map.remove(&node_id).unwrap();
retro_forward.forward(backward_states, node_id.clone());
retro_forward.forward(backward_states, node_id);
}
}
@ -48,7 +48,7 @@ macro_rules! retro_unary_scalar {
$name:ident,
$ops:expr
) => {
#[derive(new, Debug)]
#[derive(new, Debug, Clone)]
struct $name<B: Backend, const D: usize> {
lhs_id: NodeID,
rhs: FloatElem<B>,
@ -72,7 +72,7 @@ macro_rules! retro_unary {
$name:ident,
$ops:expr
) => {
#[derive(new, Debug)]
#[derive(new, Debug, Clone)]
struct $name<B: Backend, const D: usize> {
input_id: NodeID,
_backend: PhantomData<B>,
@ -95,7 +95,7 @@ macro_rules! retro_binary {
$name:ident,
$ops:expr
) => {
#[derive(new, Debug)]
#[derive(new, Debug, Clone)]
struct $name<B: Backend, const D: usize> {
lhs_id: NodeID,
rhs_id: NodeID,

View File

@ -3,7 +3,7 @@ use std::{any::Any, collections::HashMap};
use crate::graph::NodeID;
/// In order to accept arbitrary node outputs in the same hashmap, we need to upcast them to [Any].
pub(crate) type StateContent = Box<dyn Any + Send + Sync>;
pub(crate) type StateContent = Box<dyn Any + Send>;
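
A minimal sketch of the type-erasure trick this alias enables, assuming a u64 node id: heterogeneous node outputs share one map as Box<dyn Any + Send> and are downcast back to their concrete type on retrieval:

use std::any::Any;
use std::collections::HashMap;

type StateContent = Box<dyn Any + Send>;

fn main() {
    let mut map: HashMap<u64, StateContent> = HashMap::new();
    map.insert(0, Box::new(3.5f32)); // one node's output type
    map.insert(1, Box::new(vec![1u8, 2, 3])); // another node's output type

    // Retrieval takes the state out of the map and downcasts it.
    let state = map.remove(&0).unwrap();
    let value: f32 = *state.downcast::<f32>().expect("node 0 was stored as f32");
    assert_eq!(value, 3.5);
}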
#[derive(Debug)]
/// The state contained at one node. Encapsulates the node output if precomputed,
@ -71,7 +71,7 @@ impl BackwardStates {
/// This function always gives ownership of the output, but will clone it if needed for further uses.
pub(crate) fn get_state<T>(&mut self, node_id: &NodeID) -> T
where
T: Clone + Send + Sync + 'static,
T: Clone + Send + 'static,
{
// Fetch the state and decrement its number of required
let state = self.map.remove(node_id).unwrap();
@ -97,7 +97,7 @@ impl BackwardStates {
.unwrap()
.clone();
self.insert_state(node_id.clone(), new_stored_state);
self.insert_state(*node_id, new_stored_state);
downcasted
} else {
@ -119,7 +119,7 @@ impl BackwardStates {
pub(crate) fn save<T>(&mut self, node_id: NodeID, saved_output: T)
where
T: Clone + Send + Sync + 'static,
T: Clone + Send + 'static,
{
let n_required = self.get_state_ref(&node_id).unwrap().n_required();
self.insert_state(

View File

@ -3,6 +3,7 @@ use burn_tensor::{backend::Backend, container::TensorContainer, Tensor};
use crate::{
graph::{NodeRef, Requirement},
tensor::AutodiffTensor,
NodeID,
};
/// Gradient identifier.
@ -25,7 +26,7 @@ impl Gradients {
container: TensorContainer::new(),
};
gradients.register::<B, D>(
root_node,
root_node.id,
B::float_ones(B::float_shape(&root_tensor), &B::float_device(&root_tensor)),
);
gradients
@ -76,15 +77,15 @@ impl Gradients {
/// If the tensor already exists, add both tensors together before saving the result.
pub fn register<B: Backend, const D: usize>(
&mut self,
node: NodeRef,
node_id: NodeID,
value: TensorPrimitive<B, D>,
) {
if let Some(tensor_old) = self.container.remove::<B, D>(&node.id.value) {
if let Some(tensor_old) = self.container.remove::<B, D>(&node_id.value) {
self.container
.register(node.id.value, Tensor::from_primitive(value).add(tensor_old));
.register(node_id.value, Tensor::from_primitive(value).add(tensor_old));
} else {
self.container
.register::<B, D>(node.id.value, Tensor::from_primitive(value));
.register::<B, D>(node_id.value, Tensor::from_primitive(value));
}
}
}
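
The accumulation rule documented above, where registering under an id that already holds a gradient adds the two together, in a self-contained sketch with f32 standing in for tensors and a plain map for the crate's TensorContainer:

use std::collections::HashMap;

#[derive(Default)]
struct Gradients {
    container: HashMap<u64, f32>,
}

impl Gradients {
    fn register(&mut self, node_id: u64, value: f32) {
        if let Some(old) = self.container.remove(&node_id) {
            self.container.insert(node_id, value + old); // accumulate
        } else {
            self.container.insert(node_id, value);
        }
    }
}

fn main() {
    let mut grads = Gradients::default();
    grads.register(0, 1.5);
    grads.register(0, 2.5); // the same node contributes a second time
    assert_eq!(grads.container[&0], 4.0);
}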

View File

@ -1,49 +0,0 @@
use burn_tensor::backend::Backend;
use crate::{checkpoint::base::Checkpointer, grads::Gradients, tensor::AutodiffTensor};
use super::{traversal::BreadthFirstSearch, Graph, NodeRef, StepBoxed};
pub fn backward<B: Backend, const D: usize>(root: AutodiffTensor<B, D>) -> Gradients {
let grads = Gradients::new::<B, D>(root.node.clone(), root.primitive);
let checkpointer = root.graph.build_checkpointer();
let tape = build_tape(root.node, root.graph);
execute_steps(tape, grads, checkpointer)
}
fn build_tape(root: NodeRef, graph: Graph) -> Vec<Vec<StepBoxed>> {
let mut tape = (0..root.order)
.map(|_| Vec::with_capacity(1))
.collect::<Vec<_>>();
BreadthFirstSearch.traverse(root, graph, |node, step| {
if node.order == 0 {
return;
}
if let Some(steps) = tape.get_mut(node.order - 1) {
steps.push(step)
};
});
tape
}
fn execute_steps(
tape: Vec<Vec<StepBoxed>>,
mut grads: Gradients,
mut checkpointer: Checkpointer,
) -> Gradients {
tape.into_iter().rev().for_each(|steps| {
steps
.into_iter()
.for_each(|step| step.step(&mut grads, &mut checkpointer))
});
#[cfg(feature = "export_tests")]
// For checkpointing tests
assert!(checkpointer.is_empty());
grads
}

View File

@ -1,153 +1,17 @@
use spin::Mutex;
use std::{collections::HashMap, sync::Arc};
use crate::{
checkpoint::{base::Checkpointer, builder::CheckpointerBuilder},
grads::Gradients,
};
use super::{NodeID, NodeRef};
use super::NodeID;
use crate::{checkpoint::base::Checkpointer, grads::Gradients};
use std::collections::HashMap;
/// Backward step for reverse mode autodiff.
pub trait Step: Send + Sync + std::fmt::Debug {
pub trait Step: Send + std::fmt::Debug {
/// Executes the step and consumes it.
fn step(self: Box<Self>, grads: &mut Gradients, checkpointer: &mut Checkpointer);
/// The node associated with the step.
fn node(&self) -> NodeRef;
fn node(&self) -> NodeID;
/// The parents of the node associated with the step.
fn parents(&self) -> Vec<NodeID>;
/// The position of the node in the graph's topological order.
fn order(&self) -> usize;
}
pub type StepBoxed = Box<dyn Step>;
pub type NodeSteps = HashMap<NodeID, StepBoxed>;
/// Graph data structure.
///
/// The graph contains the [node steps](Step), which can be accessed by [node id](NodeID).
#[derive(Default, Clone, Debug)]
pub struct Graph {
steps: Arc<Mutex<NodeSteps>>,
checkpointing_actions: Arc<Mutex<CheckpointerBuilder>>,
}
impl Graph {
/// Create a new graph.
pub fn new() -> Self {
Self::default()
}
/// Get all the steps for the graph.
///
/// # Notes
///
/// This is an owned method, so the current graph will be freed. However, the steps can
/// be shared with other graphs, therefore they are going to be cleared.
///
/// This is useful, since the graph is supposed to be consumed only once for backprop, and
/// keeping all the tensors alive for multiple backward calls is a heavy waste of resources.
pub fn steps(self) -> NodeSteps {
let mut map_drain = HashMap::new();
self.execute_mut_steps(|map| {
std::mem::swap(&mut *map, &mut map_drain);
});
map_drain
}
/// # Notes
///
/// This is an owned method, so the current checkpointing actions will be freed.
pub fn take_checkpointing_actions(self) -> CheckpointerBuilder {
let mut actions = CheckpointerBuilder::default();
self.execute_mut_checkpointing_actions(|checkpointing_actions| {
std::mem::swap(&mut *checkpointing_actions, &mut actions);
});
actions
}
/// Register a new step into the graph.
pub fn register(self, id: &NodeID, ops: StepBoxed) -> Self {
self.execute_mut_steps(|map| {
map.insert(id.clone(), ops);
})
}
/// Merge two graphs.
pub fn merge(self, other: Self) -> Self {
if Arc::ptr_eq(&self.steps, &other.steps) {
return self;
}
self.merge_different(other)
}
fn execute_mut_steps<F: FnOnce(&mut NodeSteps)>(mut self, func: F) -> Self {
match Arc::get_mut(&mut self.steps) {
Some(mutex) => {
let map = mutex.get_mut();
func(map);
}
None => {
// Only lock when there are multiple references to the graph.
let mut map = self.steps.lock();
func(&mut map);
}
};
self
}
fn execute_mut_checkpointing_actions<F: FnOnce(&mut CheckpointerBuilder)>(
mut self,
func: F,
) -> Self {
match Arc::get_mut(&mut self.checkpointing_actions) {
Some(mutex) => {
let map = mutex.get_mut();
func(map);
}
None => {
// Only lock when there are multiple references to the graph.
let mut actions = self.checkpointing_actions.lock();
func(&mut actions);
}
};
self
}
fn merge_different(self, other: Self) -> Self {
let mut map2 = other.clone().steps();
let mut actions2 = other.take_checkpointing_actions();
self.execute_mut_steps(|map1| {
if map1.len() > map2.len() {
map1.extend(map2);
} else {
let mut map_drain = HashMap::new();
std::mem::swap(map1, &mut map_drain);
map2.extend(map_drain);
std::mem::swap(map1, &mut map2);
}
})
.execute_mut_checkpointing_actions(|actions1| {
if actions1.len() > actions2.len() {
actions1.extend(actions2);
} else {
let mut checkpointing_drain = CheckpointerBuilder::default();
std::mem::swap(actions1, &mut checkpointing_drain);
actions2.extend(checkpointing_drain);
std::mem::swap(actions1, &mut actions2);
}
})
}
pub(crate) fn build_checkpointer(&self) -> Checkpointer {
let mut guard = self.checkpointing_actions.lock();
let builder: CheckpointerBuilder = std::mem::take(&mut *guard);
builder.build(&self.steps.lock())
}
pub(crate) fn extend_checkpointer_builder(&self, checkpointing_actions: CheckpointerBuilder) {
self.checkpointing_actions
.lock()
.extend(checkpointing_actions);
}
}

View File

@ -2,7 +2,6 @@ mod base;
mod node;
mod requirement;
pub mod backward;
pub mod traversal;
pub use base::*;

View File

@ -2,6 +2,7 @@ use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use crate::checkpoint::retro_forward::RetroForward;
use crate::runtime::AutodiffClientImpl;
use super::Requirement;
@ -14,6 +15,15 @@ pub enum ComputingProperty {
Ambiguous, // Maybe autotune someday
}
/// Safety: this is sound only because [RetroForward] is only ever invoked on the autodiff server,
/// so the trait object is never used by multiple threads at the same time.
///
/// TODO: Find a way to avoid cloning the compute property, which would remove the need for the
/// Arc and let `dyn RetroForward` safely implement Send.
unsafe impl Send for ComputingProperty {}
/// The unsafe Sync impl is required because `Arc<T>: Send` requires `T: Send + Sync`, not just `T: Send`.
unsafe impl Sync for ComputingProperty {}
/// A node contains graph metadata and should be used wrapped in an Arc for cheap cloning.
#[derive(new, Debug)]
pub struct Node {
@ -22,6 +32,7 @@ pub struct Node {
pub id: NodeID,
pub requirement: Requirement,
pub properties: ComputingProperty,
pub client: AutodiffClientImpl,
}
pub type NodeRef = Arc<Node>;
@ -36,7 +47,7 @@ impl Node {
}
/// Unique identifier generated for each node.
#[derive(Clone, Hash, PartialEq, Eq, Debug)]
#[derive(Clone, Hash, PartialEq, Eq, Debug, Copy)]
pub struct NodeID {
/// The integer representation of the id
pub value: u64,

View File

@ -1,29 +1,28 @@
use std::collections::HashSet;
use std::collections::{HashMap, HashSet};
use super::{Graph, NodeRef, StepBoxed};
use crate::NodeID;
use super::StepBoxed;
/// Breadth-first search algorithm.
pub struct BreadthFirstSearch;
impl BreadthFirstSearch {
/// Traverse the graph of backward steps from a root node.
pub fn traverse<F: FnMut(NodeRef, StepBoxed)>(
pub fn traverse<F: FnMut(NodeID, StepBoxed)>(
&self,
root: NodeRef,
graph: Graph,
root_id: NodeID,
root_step: StepBoxed,
steps: &mut HashMap<NodeID, StepBoxed>,
mut callback: F,
) {
let mut visited = HashSet::with_capacity(root.order);
let mut parents = Vec::with_capacity(root.order);
let mut steps = graph.steps();
let root_step = steps.remove(&root.id).expect(
"Root node should have a step registered, did you forget to call \
`Tensor::register_grad` on the tensor where you need gradients?",
);
let root_order = root_step.order();
let mut visited = HashSet::with_capacity(root_order);
let mut parents = Vec::with_capacity(root_order);
visited.insert(root.id.clone());
parents.append(&mut root.parents.clone());
callback(root, root_step);
visited.insert(root_id);
parents.append(&mut root_step.parents());
callback(root_id, root_step);
while let Some(id) = parents.pop() {
let step = match steps.remove(&id) {
@ -31,21 +30,22 @@ impl BreadthFirstSearch {
None => continue,
};
let node = step.node();
let step_node = step.node();
let step_parents = step.parents();
if visited.contains(&node.id) {
if visited.contains(&step_node) {
continue;
}
visited.insert(node.id.clone());
visited.insert(step_node);
for id in node.parents.iter() {
for id in step_parents.iter() {
if !visited.contains(id) {
parents.push(id.clone());
parents.push(*id);
}
}
callback(node, step);
callback(step_node, step);
}
}
}
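
A self-contained sketch of the refactored traversal: steps are keyed by NodeID and each step reports its own parents, so no Graph handle or NodeRef is needed. Step is reduced here to a bare parents list (illustrative, not the crate's trait):

use std::collections::{HashMap, HashSet};

type NodeID = u64;

struct Step {
    parents: Vec<NodeID>,
}

fn traverse(root: NodeID, mut steps: HashMap<NodeID, Step>, mut callback: impl FnMut(NodeID)) {
    let root_step = steps.remove(&root).expect("root must have a registered step");
    let mut visited = HashSet::new();
    let mut pending = root_step.parents.clone();
    visited.insert(root);
    callback(root);

    while let Some(id) = pending.pop() {
        let step = match steps.remove(&id) {
            Some(step) => step,
            None => continue, // untracked node
        };
        if !visited.insert(id) {
            continue; // already visited through another path
        }
        for parent in step.parents {
            if !visited.contains(&parent) {
                pending.push(parent);
            }
        }
        callback(id);
    }
}

fn main() {
    let mut steps = HashMap::new();
    steps.insert(2, Step { parents: vec![0, 1] }); // c = a + b
    steps.insert(1, Step { parents: vec![] });
    steps.insert(0, Step { parents: vec![] });
    let mut order = Vec::new();
    traverse(2, steps, |id| order.push(id));
    assert_eq!(order, vec![2, 1, 0]); // root first, then its parents
}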

View File

@ -28,6 +28,8 @@ pub(crate) mod utils;
mod backend;
mod bridge;
pub(crate) mod runtime;
pub use backend::*;
pub use bridge::*;

View File

@ -40,9 +40,9 @@ impl<B: Backend, C: CheckpointStrategy> ActivationOps<Autodiff<B, C>> for Autodi
}
match Gelu::<D>
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
.prepare::<C>([tensor.node.clone()])
.memory_bound()
.retro_forward(RetroGelu::<B, D>::new(tensor.node.id.clone()))
.retro_forward(RetroGelu::<B, D>::new(tensor.node.id))
.parents([&tensor])
.stateful()
{
@ -77,9 +77,9 @@ impl<B: Backend, C: CheckpointStrategy> ActivationOps<Autodiff<B, C>> for Autodi
}
match Relu
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
.prepare::<C>([tensor.node.clone()])
.memory_bound()
.retro_forward(RetroRelu::<B, D>::new(tensor.node.id.clone()))
.retro_forward(RetroRelu::<B, D>::new(tensor.node.id))
.parents([&tensor])
.stateful()
{
@ -115,9 +115,9 @@ impl<B: Backend, C: CheckpointStrategy> ActivationOps<Autodiff<B, C>> for Autodi
}
match Sigmoid
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
.prepare::<C>([tensor.node.clone()])
.memory_bound()
.retro_forward(RetroSigmoid::<B, D>::new(tensor.node.id.clone()))
.retro_forward(RetroSigmoid::<B, D>::new(tensor.node.id))
.parents([&tensor])
.stateful()
{
@ -153,9 +153,9 @@ impl<B: Backend, C: CheckpointStrategy> ActivationOps<Autodiff<B, C>> for Autodi
}
match LogSigmoid::<D>
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
.prepare::<C>([tensor.node.clone()])
.memory_bound()
.retro_forward(RetroLogSigmoid::<B, D>::new(tensor.node.id.clone()))
.retro_forward(RetroLogSigmoid::<B, D>::new(tensor.node.id))
.parents([&tensor])
.stateful()
{

View File

@ -2,7 +2,7 @@ use super::{Ops, OpsPrep};
use crate::{
checkpoint::{base::Checkpointer, builder::CheckpointerBuilder, strategy::CheckpointStrategy},
grads::Gradients,
graph::{ComputingProperty, Graph, NodeRef, Requirement},
graph::{ComputingProperty, NodeRef, Requirement},
utils::duplicate,
};
use burn_tensor::backend::Backend;
@ -14,13 +14,13 @@ use burn_tensor::backend::Backend;
/// Concrete types implementing this trait should not have any state.
/// If a state is necessary during the backward pass,
/// it should be declared with the associated type 'State'.
pub trait Backward<B, const D: usize, const N: usize>: Send + Sync + std::fmt::Debug
pub trait Backward<B, const D: usize, const N: usize>: Send + std::fmt::Debug
where
Self: Sized + 'static,
B: Backend,
{
/// Associated type to compute the backward pass.
type State: Clone + Send + Sync + std::fmt::Debug + 'static;
type State: Clone + Send + std::fmt::Debug + 'static;
/// The backward pass.
fn backward(
@ -34,12 +34,10 @@ where
fn prepare<C: CheckpointStrategy>(
self,
nodes: [NodeRef; N],
graphs: [Graph; N],
) -> OpsPrep<Self, B, Self::State, C, D, N> {
let requirement = Requirement::from_nodes(&nodes);
OpsPrep::new(
nodes,
graphs,
requirement,
self,
ComputingProperty::Ambiguous, // If not specified we start with ambiguous
@ -65,12 +63,12 @@ pub fn binary<B, const D_OUT: usize, const D_LHS: usize, const D_RHS: usize, FLh
if let Some(node) = node_lhs {
let grad = func_lhs(grad_4lhs.unwrap());
grads.register::<B, D_LHS>(node, grad)
grads.register::<B, D_LHS>(node.id, grad)
}
if let Some(node) = node_rhs {
let grad = func_rhs(grad_4rhs.unwrap());
grads.register::<B, D_RHS>(node, grad)
grads.register::<B, D_RHS>(node.id, grad)
}
}
@ -89,7 +87,7 @@ pub fn unary<B, const D_OUT: usize, const D_IN: usize, F>(
if let Some(node) = parent_node {
let grad = func(grad);
grads.register::<B, D_IN>(node, grad)
grads.register::<B, D_IN>(node.id, grad)
}
}
@ -110,6 +108,6 @@ pub fn unary_different_backend<BIn, BOut, const D_OUT: usize, const D_IN: usize,
if let Some(node) = parent_node {
let grad = func(grad);
grads.register::<BIn, D_IN>(node, grad)
grads.register::<BIn, D_IN>(node.id, grad)
}
}
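
To make the Backward contract above concrete (stateless op types, with everything the backward pass needs carried in the associated State), a hedged and heavily simplified sketch; MulScalar and the f32 gradients map are illustrative stand-ins for the crate's generic tensor signature:

use std::collections::HashMap;
use std::fmt::Debug;

type NodeID = u64;
type Gradients = HashMap<NodeID, f32>;

trait Backward: Debug + Send + 'static {
    // Captured at forward time; the op type itself stays stateless.
    type State: Clone + Send + Debug + 'static;

    fn backward(self, grad_out: f32, state: Self::State, grads: &mut Gradients);
}

// Backward of `out = x * c`: d(out)/dx = c.
#[derive(Debug)]
struct MulScalar;

impl Backward for MulScalar {
    type State = (NodeID, f32); // (parent node id, scalar factor)

    fn backward(self, grad_out: f32, (parent, factor): Self::State, grads: &mut Gradients) {
        *grads.entry(parent).or_insert(0.0) += grad_out * factor;
    }
}

fn main() {
    let mut grads = Gradients::new();
    MulScalar.backward(1.0, (0, 3.0), &mut grads);
    assert_eq!(grads[&0], 3.0);
}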

View File

@ -7,7 +7,7 @@ use crate::{
strategy::CheckpointStrategy,
},
grads::Gradients,
graph::{ComputingProperty, Graph, NodeID, NodeRef, Requirement, Step},
graph::{ComputingProperty, NodeID, NodeRef, Requirement, Step},
tensor::AutodiffTensor,
};
use burn_tensor::{backend::Backend, Shape};
@ -19,7 +19,6 @@ use std::marker::PhantomData;
#[derive(new)]
pub struct OpsPrep<Backward, B, S, C, const D: usize, const N: usize, Mode = Init> {
nodes: [NodeRef; N],
graphs: [Graph; N],
requirement: Requirement,
backward: Backward,
compute_property: ComputingProperty,
@ -53,7 +52,6 @@ where
pub fn compute_bound(self) -> OpsPrep<BO, B, S, C, D, N, ComputePropertyDone> {
OpsPrep::new(
self.nodes,
self.graphs,
self.requirement,
self.backward,
ComputingProperty::ComputeBound,
@ -66,7 +64,6 @@ where
pub fn memory_bound(self) -> OpsPrep<BO, B, S, C, D, N, MemoryBound> {
OpsPrep::new(
self.nodes,
self.graphs,
self.requirement,
self.backward,
self.compute_property,
@ -88,7 +85,6 @@ where
) -> OpsPrep<BO, B, S, C, D, N, MemoryBoundRetroForward> {
OpsPrep::new(
self.nodes,
self.graphs,
self.requirement,
self.backward,
C::compute_property(retro_forward),
@ -117,7 +113,6 @@ where
OpsPrep::new(
self.nodes,
self.graphs,
self.requirement,
self.backward,
self.compute_property,
@ -146,7 +141,7 @@ where
impl<BO, B, S, C, const D: usize, const N: usize> OpsPrep<BO, B, S, C, D, N, ComputePropertyDone>
where
B: Backend,
S: Clone + Send + Sync + std::fmt::Debug + 'static,
S: Clone + Send + std::fmt::Debug + 'static,
BO: Backward<B, D, N, State = S>,
{
/// Prepare an operation that requires a state during the backward pass.
@ -154,7 +149,6 @@ where
match self.requirement.is_none() {
false => OpsKind::Tracked(OpsPrep::new(
self.nodes,
self.graphs,
self.requirement,
self.backward,
self.compute_property,
@ -162,7 +156,6 @@ where
)),
true => OpsKind::UnTracked(OpsPrep::new(
self.nodes,
self.graphs,
self.requirement,
self.backward,
self.compute_property,
@ -175,7 +168,7 @@ where
impl<BO, B, S, C, const D: usize, const N: usize> OpsPrep<BO, B, S, C, D, N, UnTracked>
where
B: Backend,
S: Clone + Send + Sync + std::fmt::Debug + 'static,
S: Clone + Send + std::fmt::Debug + 'static,
BO: Backward<B, D, N, State = S>,
{
/// Finish the preparation of an untracked operation and return the output tensor.
@ -183,24 +176,22 @@ where
let output = AutodiffTensor::from_parents(
output,
&self.nodes,
self.graphs.into_iter(),
self.requirement,
self.compute_property,
self.checkpointer_builder,
);
let parents = self.nodes.map(|node| node.clone_if_require_grad());
let ops = Ops::new(parents, output.node.clone(), ());
// We register the ops in the graph even if untracked; otherwise, memory-bound operations
// that have an untracked parent would not be able to retrieve its output
output.register_step(UntrackedOpsStep::new(ops))
output.register_step(UntrackedOpsStep::new(ops), self.checkpointer_builder)
}
}
impl<BO, B, S, C, const D: usize, const N: usize> OpsPrep<BO, B, S, C, D, N, Tracked>
where
B: Backend,
S: Clone + Send + Sync + std::fmt::Debug + 'static,
S: Clone + Send + std::fmt::Debug + 'static,
BO: Backward<B, D, N, State = S>,
{
/// Finish the preparation of a tracked operation and return the output tensor.
@ -212,15 +203,13 @@ where
let output = AutodiffTensor::from_parents(
output,
&self.nodes,
self.graphs.into_iter(),
self.requirement,
self.compute_property,
self.checkpointer_builder,
);
let parents = self.nodes.map(|node| node.clone_if_require_grad());
let ops = Ops::new(parents, output.node.clone(), state);
output.register_step(OpsStep::new(ops, self.backward))
output.register_step(OpsStep::new(ops, self.backward), self.checkpointer_builder)
}
/// Checkpoints the tensor
@ -228,7 +217,7 @@ where
self.checkpointer_builder
.checkpoint(tensor, ActionType::Explicit);
tensor.node.id.clone()
tensor.node.id
}
}
@ -257,7 +246,7 @@ struct OpsStep<B, T, SB, const D: usize, const N: usize>
where
B: Backend,
T: Backward<B, D, N, State = SB>,
SB: Clone + Send + Sync + std::fmt::Debug + 'static,
SB: Clone + Send + std::fmt::Debug + 'static,
{
ops: Ops<SB, N>,
backward: T,
@ -268,14 +257,22 @@ impl<B, T, SB, const D: usize, const N: usize> Step for OpsStep<B, T, SB, D, N>
where
B: Backend,
T: Backward<B, D, N, State = SB>,
SB: Clone + Send + Sync + std::fmt::Debug + 'static,
SB: Clone + Send + std::fmt::Debug + 'static,
{
fn step(self: Box<Self>, grads: &mut Gradients, checkpointer: &mut Checkpointer) {
self.backward.backward(self.ops, grads, checkpointer);
}
fn node(&self) -> NodeRef {
self.ops.node.clone()
fn node(&self) -> NodeID {
self.ops.node.id
}
fn parents(&self) -> Vec<NodeID> {
self.ops.node.parents.clone()
}
fn order(&self) -> usize {
self.ops.node.order
}
}
@ -289,8 +286,15 @@ impl<const N: usize> Step for UntrackedOpsStep<N> {
// Nothing to do
}
fn node(&self) -> NodeRef {
self.ops.node.clone()
fn node(&self) -> NodeID {
self.ops.node.id
}
fn parents(&self) -> Vec<NodeID> {
self.ops.node.parents.clone()
}
fn order(&self) -> usize {
self.ops.node.order
}
}

View File

@ -34,7 +34,7 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
}
match Embedding
.prepare::<C>([weights.node], [weights.graph])
.prepare::<C>([weights.node])
.compute_bound()
.stateful()
{
@ -85,13 +85,13 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
let backward = B::conv2d_backward(x, weight, bias, grad, options);
if let Some(node) = node_x {
grads.register::<B, 4>(node, backward.x_grad)
grads.register::<B, 4>(node.id, backward.x_grad)
}
if let Some(node) = node_weight {
grads.register::<B, 4>(node, backward.weights_grad)
grads.register::<B, 4>(node.id, backward.weights_grad)
}
if let Some(node) = node_bias {
grads.register::<B, 1>(node, backward.bias_grad.unwrap())
grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
}
}
}
@ -115,20 +115,17 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
let backward = B::conv2d_backward(x, weight, None, grad, options);
if let Some(node) = node_x {
grads.register::<B, 4>(node, backward.x_grad)
grads.register::<B, 4>(node.id, backward.x_grad)
}
if let Some(node) = node_weight {
grads.register::<B, 4>(node, backward.weights_grad)
grads.register::<B, 4>(node.id, backward.weights_grad)
}
}
}
match bias {
Some(bias) => match Conv2DWithBias
.prepare::<C>(
[x.node.clone(), weight.node.clone(), bias.node.clone()],
[x.graph.clone(), weight.graph.clone(), bias.graph.clone()],
)
.prepare::<C>([x.node.clone(), weight.node.clone(), bias.node.clone()])
.compute_bound()
.stateful()
{
@ -149,10 +146,7 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
)),
},
None => match Conv2DNoBias
.prepare::<C>(
[x.node.clone(), weight.node.clone()],
[x.graph.clone(), weight.graph.clone()],
)
.prepare::<C>([x.node.clone(), weight.node.clone()])
.compute_bound()
.stateful()
{
@ -203,13 +197,13 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
let backward = B::conv_transpose2d_backward(x, weight, bias, grad, options);
if let Some(node) = node_x {
grads.register::<B, 4>(node, backward.x_grad)
grads.register::<B, 4>(node.id, backward.x_grad)
}
if let Some(node) = node_weight {
grads.register::<B, 4>(node, backward.weights_grad)
grads.register::<B, 4>(node.id, backward.weights_grad)
}
if let Some(node) = node_bias {
grads.register::<B, 1>(node, backward.bias_grad.unwrap())
grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
}
}
}
@ -233,20 +227,17 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
let backward = B::conv_transpose2d_backward(x, weight, None, grad, options);
if let Some(node) = node_x {
grads.register::<B, 4>(node, backward.x_grad)
grads.register::<B, 4>(node.id, backward.x_grad)
}
if let Some(node) = node_weight {
grads.register::<B, 4>(node, backward.weights_grad)
grads.register::<B, 4>(node.id, backward.weights_grad)
}
}
}
match bias {
Some(bias) => match ConvTranspose2DWithBias
.prepare::<C>(
[x.node.clone(), weight.node.clone(), bias.node.clone()],
[x.graph.clone(), weight.graph.clone(), bias.graph.clone()],
)
.prepare::<C>([x.node.clone(), weight.node.clone(), bias.node.clone()])
.compute_bound()
.stateful()
{
@ -273,10 +264,7 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
)),
},
None => match ConvTranspose2DNoBias
.prepare::<C>(
[x.node.clone(), weight.node.clone()],
[x.graph.clone(), weight.graph.clone()],
)
.prepare::<C>([x.node.clone(), weight.node.clone()])
.compute_bound()
.stateful()
{
@ -330,13 +318,13 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
let backward = B::conv1d_backward(x, weight, bias, grad, options);
if let Some(node) = node_x {
grads.register::<B, 3>(node, backward.x_grad)
grads.register::<B, 3>(node.id, backward.x_grad)
}
if let Some(node) = node_weight {
grads.register::<B, 3>(node, backward.weights_grad)
grads.register::<B, 3>(node.id, backward.weights_grad)
}
if let Some(node) = node_bias {
grads.register::<B, 1>(node, backward.bias_grad.unwrap())
grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
}
}
}
@ -360,19 +348,16 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
let backward = B::conv1d_backward(x, weight, None, grad, options);
if let Some(node) = node_x {
grads.register::<B, 3>(node, backward.x_grad)
grads.register::<B, 3>(node.id, backward.x_grad)
}
if let Some(node) = node_weight {
grads.register::<B, 3>(node, backward.weights_grad)
grads.register::<B, 3>(node.id, backward.weights_grad)
}
}
}
match bias {
Some(bias) => match Conv1DWithBias
.prepare::<C>(
[x.node.clone(), weight.node.clone(), bias.node.clone()],
[x.graph.clone(), weight.graph.clone(), bias.graph.clone()],
)
.prepare::<C>([x.node.clone(), weight.node.clone(), bias.node.clone()])
.compute_bound()
.stateful()
{
@ -393,10 +378,7 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
)),
},
None => match Conv1DNoBias
.prepare::<C>(
[x.node.clone(), weight.node.clone()],
[x.graph.clone(), weight.graph.clone()],
)
.prepare::<C>([x.node.clone(), weight.node.clone()])
.compute_bound()
.stateful()
{
@ -446,13 +428,13 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
let backward = B::conv_transpose1d_backward(x, weight, bias, grad, options);
if let Some(node) = node_x {
grads.register::<B, 3>(node, backward.x_grad)
grads.register::<B, 3>(node.id, backward.x_grad)
}
if let Some(node) = node_weight {
grads.register::<B, 3>(node, backward.weights_grad)
grads.register::<B, 3>(node.id, backward.weights_grad)
}
if let Some(node) = node_bias {
grads.register::<B, 1>(node, backward.bias_grad.unwrap())
grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
}
}
}
@ -476,20 +458,17 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
let backward = B::conv_transpose1d_backward(x, weight, None, grad, options);
if let Some(node) = node_x {
grads.register::<B, 3>(node, backward.x_grad)
grads.register::<B, 3>(node.id, backward.x_grad)
}
if let Some(node) = node_weight {
grads.register::<B, 3>(node, backward.weights_grad)
grads.register::<B, 3>(node.id, backward.weights_grad)
}
}
}
match bias {
Some(bias) => match ConvTranspose1DWithBias
.prepare::<C>(
[x.node.clone(), weight.node.clone(), bias.node.clone()],
[x.graph.clone(), weight.graph.clone(), bias.graph.clone()],
)
.prepare::<C>([x.node.clone(), weight.node.clone(), bias.node.clone()])
.compute_bound()
.stateful()
{
@ -515,10 +494,7 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
)),
},
None => match ConvTranspose1DNoBias
.prepare::<C>(
[x.node.clone(), weight.node.clone()],
[x.graph.clone(), weight.graph.clone()],
)
.prepare::<C>([x.node.clone(), weight.node.clone()])
.compute_bound()
.stateful()
{
@ -588,13 +564,13 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
padding,
count_include_pad,
);
grads.register::<B, 3>(node, grad);
grads.register::<B, 3>(node.id, grad);
}
}
}
match AvgPool1D
.prepare::<C>([x.node.clone()], [x.graph.clone()])
.prepare::<C>([x.node.clone()])
.compute_bound()
.stateful()
{
@ -654,13 +630,13 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
padding,
count_include_pad,
);
grads.register::<B, 4>(node, grad);
grads.register::<B, 4>(node.id, grad);
}
}
}
match AvgPool2D
.prepare::<C>([x.node.clone()], [x.graph.clone()])
.prepare::<C>([x.node.clone()])
.compute_bound()
.stateful()
{
@ -706,7 +682,7 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
dilation: usize,
) -> AutodiffTensor<B, 3> {
match MaxPool1D
.prepare::<C>([x.node.clone()], [x.graph.clone()])
.prepare::<C>([x.node.clone()])
.compute_bound()
.stateful()
{
@ -744,7 +720,7 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
dilation: usize,
) -> MaxPool1dWithIndices<Self> {
match MaxPool1D
.prepare::<C>([x.node.clone()], [x.graph.clone()])
.prepare::<C>([x.node.clone()])
.compute_bound()
.stateful()
{
@ -806,7 +782,7 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
dilation: [usize; 2],
) -> AutodiffTensor<B, 4> {
match MaxPool2D
.prepare::<C>([x.node.clone()], [x.graph.clone()])
.prepare::<C>([x.node.clone()])
.compute_bound()
.stateful()
{
@ -844,7 +820,7 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
dilation: [usize; 2],
) -> MaxPool2dWithIndices<Self> {
match MaxPool2D
.prepare::<C>([x.node.clone()], [x.graph.clone()])
.prepare::<C>([x.node.clone()])
.compute_bound()
.stateful()
{
@ -908,13 +884,13 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
if let Some(node) = node_parent {
let grad = B::adaptive_avg_pool1d_backward(state, grad);
grads.register::<B, 3>(node, grad);
grads.register::<B, 3>(node.id, grad);
}
}
}
match AdaptiveAvgPool1D
.prepare::<C>([x.node.clone()], [x.graph.clone()])
.prepare::<C>([x.node.clone()])
.compute_bound()
.stateful()
{
@ -950,13 +926,13 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
if let Some(node) = node_parent {
let grad = B::adaptive_avg_pool2d_backward(state, grad);
grads.register::<B, 4>(node, grad);
grads.register::<B, 4>(node.id, grad);
}
}
}
match AdaptiveAvgPool2D
.prepare::<C>([x.node.clone()], [x.graph.clone()])
.prepare::<C>([x.node.clone()])
.compute_bound()
.stateful()
{
@ -1001,13 +977,13 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
if let Some(node) = node_parent {
let grad = B::interpolate_backward(state, grad, output_size, options);
grads.register::<B, 4>(node, grad);
grads.register::<B, 4>(node.id, grad);
}
}
}
match Interpolate
.prepare::<C>([x.node.clone()], [x.graph.clone()])
.prepare::<C>([x.node.clone()])
.compute_bound()
.stateful()
{
@ -1060,7 +1036,7 @@ impl<B: Backend> Backward<B, 3, 1> for MaxPool1D {
indices,
);
grads.register::<B, 3>(node, grad.x_grad);
grads.register::<B, 3>(node.id, grad.x_grad);
}
}
}
@ -1100,7 +1076,7 @@ impl<B: Backend> Backward<B, 4, 1> for MaxPool2D {
indices,
);
grads.register::<B, 4>(node, grad.x_grad);
grads.register::<B, 4>(node.id, grad.x_grad);
}
}
}

View File

@ -90,7 +90,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
match ToDevice
.prepare::<C>([tensor.node], [tensor.graph])
.prepare::<C>([tensor.node])
.compute_bound()
.stateful()
{
@ -137,15 +137,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
match Add
.prepare::<C>(
[lhs.node.clone(), rhs.node.clone()],
[lhs.graph.clone(), rhs.graph.clone()],
)
.prepare::<C>([lhs.node.clone(), rhs.node.clone()])
.memory_bound()
.retro_forward(RetroAdd::<B, D>::new(
lhs.node.id.clone(),
rhs.node.id.clone(),
))
.retro_forward(RetroAdd::<B, D>::new(lhs.node.id, rhs.node.id))
.parents([&lhs, &rhs])
.stateful()
{
@ -183,9 +177,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
AddScalar
.prepare::<C>([lhs.node.clone()], [lhs.graph.clone()])
.prepare::<C>([lhs.node.clone()])
.memory_bound()
.retro_forward(RetroAddScalar::<B, D>::new(lhs.node.id.clone(), rhs))
.retro_forward(RetroAddScalar::<B, D>::new(lhs.node.id, rhs))
.parents([&lhs])
.stateless(B::float_add_scalar(lhs.primitive, rhs))
}
@ -221,15 +215,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
match Sub
.prepare::<C>(
[lhs.node.clone(), rhs.node.clone()],
[lhs.graph.clone(), rhs.graph.clone()],
)
.prepare::<C>([lhs.node.clone(), rhs.node.clone()])
.memory_bound()
.retro_forward(RetroSub::<B, D>::new(
lhs.node.id.clone(),
rhs.node.id.clone(),
))
.retro_forward(RetroSub::<B, D>::new(lhs.node.id, rhs.node.id))
.parents([&lhs, &rhs])
.stateful()
{
@ -267,9 +255,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
SubScalar
.prepare::<C>([lhs.node.clone()], [lhs.graph.clone()])
.prepare::<C>([lhs.node.clone()])
.memory_bound()
.retro_forward(RetroSubScalar::<B, D>::new(lhs.node.id.clone(), rhs))
.retro_forward(RetroSubScalar::<B, D>::new(lhs.node.id, rhs))
.parents([&lhs])
.stateless(B::float_sub_scalar(lhs.primitive, rhs))
}
@ -317,15 +305,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
let broadcast = BinaryOpsBroadcast::new::<B>(&lhs.primitive, &rhs.primitive);
match Mul
.prepare::<C>(
[lhs.node.clone(), rhs.node.clone()],
[lhs.graph.clone(), rhs.graph.clone()],
)
.prepare::<C>([lhs.node.clone(), rhs.node.clone()])
.memory_bound()
.retro_forward(RetroMul::<B, D>::new(
lhs.node.id.clone(),
rhs.node.id.clone(),
))
.retro_forward(RetroMul::<B, D>::new(lhs.node.id, rhs.node.id))
.parents([&lhs, &rhs])
.stateful()
{
@ -367,9 +349,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
match MulScalar
.prepare::<C>([lhs.node.clone()], [lhs.graph.clone()])
.prepare::<C>([lhs.node.clone()])
.memory_bound()
.retro_forward(RetroMulScalar::<B, D>::new(lhs.node.id.clone(), rhs))
.retro_forward(RetroMulScalar::<B, D>::new(lhs.node.id, rhs))
.parents([&lhs])
.stateful()
{
@ -429,15 +411,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
let broadcast = BinaryOpsBroadcast::new::<B>(&lhs.primitive, &rhs.primitive);
match Div
.prepare::<C>(
[lhs.node.clone(), rhs.node.clone()],
[lhs.graph.clone(), rhs.graph.clone()],
)
.prepare::<C>([lhs.node.clone(), rhs.node.clone()])
.memory_bound()
.retro_forward(RetroDiv::<B, D>::new(
lhs.node.id.clone(),
rhs.node.id.clone(),
))
.retro_forward(RetroDiv::<B, D>::new(lhs.node.id, rhs.node.id))
.parents([&lhs, &rhs])
.stateful()
{
@ -480,9 +456,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
match DivScalar
.prepare::<C>([lhs.node.clone()], [lhs.graph.clone()])
.prepare::<C>([lhs.node.clone()])
.memory_bound()
.retro_forward(RetroDivScalar::<B, D>::new(lhs.node.id.clone(), rhs))
.retro_forward(RetroDivScalar::<B, D>::new(lhs.node.id, rhs))
.parents([&lhs])
.stateful()
{
@ -536,10 +512,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
let broadcast = BinaryOpsBroadcast::new::<B>(&lhs.primitive, &rhs.primitive);
match Matmul
.prepare::<C>(
[lhs.node.clone(), rhs.node.clone()],
[lhs.graph.clone(), rhs.graph.clone()],
)
.prepare::<C>([lhs.node.clone(), rhs.node.clone()])
.compute_bound()
.stateful()
{
@ -574,9 +547,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
}
Neg.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
Neg.prepare::<C>([tensor.node.clone()])
.memory_bound()
.retro_forward(RetroNeg::<B, D>::new(tensor.node.id.clone()))
.retro_forward(RetroNeg::<B, D>::new(tensor.node.id))
.parents([&tensor])
.stateless(B::float_neg(tensor.primitive))
}
@ -607,9 +580,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
match Recip
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
.prepare::<C>([tensor.node.clone()])
.memory_bound()
.retro_forward(RetroRecip::<B, D>::new(tensor.node.id.clone()))
.retro_forward(RetroRecip::<B, D>::new(tensor.node.id))
.parents([&tensor])
.stateful()
{
@ -663,13 +636,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
match SwapDim
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
.prepare::<C>([tensor.node.clone()])
.memory_bound()
.retro_forward(RetroSwapDims::<B, D>::new(
tensor.node.id.clone(),
dim1,
dim2,
))
.retro_forward(RetroSwapDims::<B, D>::new(tensor.node.id, dim1, dim2))
.parents([&tensor])
.stateful()
{
@ -728,9 +697,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
match PermuteDim
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
.prepare::<C>([tensor.node.clone()])
.memory_bound()
.retro_forward(RetroPermuteDims::<B, D>::new(tensor.node.id.clone(), axes))
.retro_forward(RetroPermuteDims::<B, D>::new(tensor.node.id, axes))
.parents([&tensor])
.stateful()
{
@ -779,12 +748,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
match FlipDim
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
.prepare::<C>([tensor.node.clone()])
.memory_bound()
.retro_forward(RetroFlipDims::<B, D>::new(
tensor.node.id.clone(),
axes.to_vec(),
))
.retro_forward(RetroFlipDims::<B, D>::new(tensor.node.id, axes.to_vec()))
.parents([&tensor])
.stateful()
{
@ -844,10 +810,10 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
match ReshapeDim
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
.prepare::<C>([tensor.node.clone()])
.memory_bound()
.retro_forward(RetroReshape::<B, D1, D2>::new(
tensor.node.id.clone(),
tensor.node.id,
shape.clone(),
))
.parents([&tensor])
@ -888,7 +854,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
match Gather
.prepare::<C>([tensor.node], [tensor.graph])
.prepare::<C>([tensor.node])
.compute_bound()
.stateful()
{
@ -945,7 +911,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
match Scatter
.prepare::<C>([tensor.node, value.node], [tensor.graph, value.graph])
.prepare::<C>([tensor.node, value.node])
.compute_bound()
.stateful()
{
@ -1010,10 +976,10 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
match Select
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
.prepare::<C>([tensor.node.clone()])
.memory_bound()
.retro_forward(RetroSelect::<B, D>::new(
tensor.node.id.clone(),
tensor.node.id,
dim,
indices.clone(),
))
@ -1090,16 +1056,13 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
match IndexSelectDimAssign::<D>
.prepare::<C>(
[tensor.node.clone(), value.node.clone()],
[tensor.graph.clone(), value.graph.clone()],
)
.prepare::<C>([tensor.node.clone(), value.node.clone()])
.memory_bound()
.retro_forward(RetroSelectAssign::<B, D>::new(
tensor.node.id.clone(),
tensor.node.id,
dim,
indices.clone(),
value.node.id.clone(),
value.node.id,
))
.parents([&tensor, &value])
.stateful()
@ -1164,12 +1127,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
match Index
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
.prepare::<C>([tensor.node.clone()])
.memory_bound()
.retro_forward(RetroSlice::<B, D1, D2>::new(
tensor.node.id.clone(),
ranges.clone(),
))
.retro_forward(RetroSlice::<B, D1, D2>::new(tensor.node.id, ranges.clone()))
.parents([&tensor])
.stateful()
{
@ -1236,15 +1196,12 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
match SliceAssign
.prepare::<C>(
[tensor.node.clone(), value.node.clone()],
[tensor.graph.clone(), value.graph.clone()],
)
.prepare::<C>([tensor.node.clone(), value.node.clone()])
.memory_bound()
.retro_forward(RetroSliceAssign::<B, D1, D2>::new(
tensor.node.id.clone(),
tensor.node.id,
ranges.clone(),
value.node.id.clone(),
value.node.id,
))
.parents([&tensor, &value])
.stateful()
@ -1306,7 +1263,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
match MaskWhere
.prepare::<C>([tensor.node, source.node], [tensor.graph, source.graph])
.prepare::<C>([tensor.node, source.node])
.compute_bound()
.stateful()
{
@ -1351,7 +1308,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
match MaskFill
.prepare::<C>([tensor.node], [tensor.graph])
.prepare::<C>([tensor.node])
.compute_bound()
.stateful()
{
@ -1489,11 +1446,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
}
match Mean
.prepare::<C>([tensor.node], [tensor.graph])
.compute_bound()
.stateful()
{
match Mean.prepare::<C>([tensor.node]).compute_bound().stateful() {
OpsKind::Tracked(prep) => prep.finish(
B::float_shape(&tensor.primitive),
B::float_mean(tensor.primitive),
@ -1526,11 +1479,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
}
match Sum
.prepare::<C>([tensor.node], [tensor.graph])
.compute_bound()
.stateful()
{
match Sum.prepare::<C>([tensor.node]).compute_bound().stateful() {
OpsKind::Tracked(prep) => prep.finish(
B::float_shape(&tensor.primitive),
B::float_sum(tensor.primitive),
@ -1569,7 +1518,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
match MeanDim
.prepare::<C>([tensor.node], [tensor.graph])
.prepare::<C>([tensor.node])
.compute_bound()
.stateful()
{
@ -1609,7 +1558,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
match SumDim
.prepare::<C>([tensor.node], [tensor.graph])
.prepare::<C>([tensor.node])
.compute_bound()
.stateful()
{
@ -1653,9 +1602,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
match Exp
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
.prepare::<C>([tensor.node.clone()])
.memory_bound()
.retro_forward(RetroExp::<B, D>::new(tensor.node.id.clone()))
.retro_forward(RetroExp::<B, D>::new(tensor.node.id))
.parents([&tensor])
.stateful()
{
@ -1691,9 +1640,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
match Log
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
.prepare::<C>([tensor.node.clone()])
.memory_bound()
.retro_forward(RetroLog::<B, D>::new(tensor.node.id.clone()))
.retro_forward(RetroLog::<B, D>::new(tensor.node.id))
.parents([&tensor])
.stateful()
{
@ -1731,9 +1680,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
match Log1P
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
.prepare::<C>([tensor.node.clone()])
.memory_bound()
.retro_forward(RetroLog1P::<B, D>::new(tensor.node.id.clone()))
.retro_forward(RetroLog1P::<B, D>::new(tensor.node.id))
.parents([&tensor])
.stateful()
{
@ -1789,9 +1738,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
match PowfScalar
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
.prepare::<C>([tensor.node.clone()])
.memory_bound()
.retro_forward(RetroPowfScalar::<B, D>::new(tensor.node.id.clone(), value))
.retro_forward(RetroPowfScalar::<B, D>::new(tensor.node.id, value))
.parents([&tensor])
.stateful()
{
@ -1828,9 +1777,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
match Sqrt
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
.prepare::<C>([tensor.node.clone()])
.memory_bound()
.retro_forward(RetroSqrt::<B, D>::new(tensor.node.id.clone()))
.retro_forward(RetroSqrt::<B, D>::new(tensor.node.id))
.parents([&tensor])
.stateful()
{
@ -1868,9 +1817,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
match Abs
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
.prepare::<C>([tensor.node.clone()])
.memory_bound()
.retro_forward(RetroAbs::<B, D>::new(tensor.node.id.clone()))
.retro_forward(RetroAbs::<B, D>::new(tensor.node.id))
.parents([&tensor])
.stateful()
{
@ -1907,9 +1856,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
match Cos
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
.prepare::<C>([tensor.node.clone()])
.memory_bound()
.retro_forward(RetroCos::<B, D>::new(tensor.node.id.clone()))
.retro_forward(RetroCos::<B, D>::new(tensor.node.id))
.parents([&tensor])
.stateful()
{
@ -1945,9 +1894,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
match Sin
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
.prepare::<C>([tensor.node.clone()])
.memory_bound()
.retro_forward(RetroSin::<B, D>::new(tensor.node.id.clone()))
.retro_forward(RetroSin::<B, D>::new(tensor.node.id))
.parents([&tensor])
.stateful()
{
@ -1987,9 +1936,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
match Tanh
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
.prepare::<C>([tensor.node.clone()])
.memory_bound()
.retro_forward(RetroTanh::<B, D>::new(tensor.node.id.clone()))
.retro_forward(RetroTanh::<B, D>::new(tensor.node.id))
.parents([&tensor])
.stateful()
{
@ -2029,9 +1978,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
match Erf
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
.prepare::<C>([tensor.node.clone()])
.memory_bound()
.retro_forward(RetroErf::<B, D>::new(tensor.node.id.clone()))
.retro_forward(RetroErf::<B, D>::new(tensor.node.id))
.parents([&tensor])
.stateful()
{
@ -2074,17 +2023,27 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
let mut ranges = ranges.clone();
ranges[self.dim] = current_index..dim_size + current_index;
current_index += dim_size;
grads.register::<B, D>(node, B::float_slice(grad.clone(), ranges));
grads.register::<B, D>(node.id, B::float_slice(grad.clone(), ranges));
});
}
fn node(&self) -> NodeRef {
self.output.clone()
fn node(&self) -> NodeID {
self.output.id
}
fn parents(&self) -> Vec<NodeID> {
self.nodes
.iter()
.filter_map(|node| node.clone())
.map(|node| node.id)
.collect()
}
fn order(&self) -> usize {
self.output.order
}
}
let mut nodes = Vec::with_capacity(tensors.len());
let mut graphs = Vec::with_capacity(tensors.len());
let mut primitives = Vec::with_capacity(tensors.len());
let mut dim_sizes = Vec::with_capacity(tensors.len());
@ -2092,7 +2051,6 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
dim_sizes.push(B::float_shape(&tensor.primitive).dims[dim]);
nodes.push(tensor.node);
primitives.push(tensor.primitive);
graphs.push(tensor.graph);
});
let requirement = Requirement::from_nodes(&nodes);
@ -2106,28 +2064,20 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
return AutodiffTensor::from_parents(
output,
&nodes,
graphs.into_iter(),
requirement,
cat_computing_property,
checkpointer_builder,
);
}
let output = AutodiffTensor::from_parents(
output,
&nodes,
graphs.into_iter(),
requirement,
cat_computing_property,
checkpointer_builder,
);
let output =
AutodiffTensor::from_parents(output, &nodes, requirement, cat_computing_property);
let nodes = nodes
.into_iter()
.map(|node| node.clone_if_require_grad())
.collect::<Vec<_>>();
let ops = CatStep::<B, D>::new(nodes, dim_sizes, output.node.clone(), dim);
output.register_step(ops)
output.register_step(ops, checkpointer_builder)
}
fn float_max_dim<const D: usize>(
@ -2135,7 +2085,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
dim: usize,
) -> FloatTensor<Self, D> {
match MaxMinDim
.prepare::<C>([tensor.node], [tensor.graph])
.prepare::<C>([tensor.node])
.compute_bound()
.stateful()
{
@ -2152,7 +2102,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
dim: usize,
) -> (FloatTensor<Self, D>, IntTensor<B, D>) {
match MaxMinDim
.prepare::<C>([tensor.node], [tensor.graph])
.prepare::<C>([tensor.node])
.compute_bound()
.stateful()
{
@ -2176,7 +2126,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
dim: usize,
) -> FloatTensor<Self, D> {
match MaxMinDim
.prepare::<C>([tensor.node], [tensor.graph])
.prepare::<C>([tensor.node])
.compute_bound()
.stateful()
{
@ -2193,7 +2143,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
dim: usize,
) -> (FloatTensor<Self, D>, IntTensor<B, D>) {
match MaxMinDim
.prepare::<C>([tensor.node], [tensor.graph])
.prepare::<C>([tensor.node])
.compute_bound()
.stateful()
{
@ -2283,15 +2233,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
let broadcast = BinaryOpsBroadcast::new::<B>(&lhs.primitive, &rhs.primitive);
match PowF
.prepare::<C>(
[lhs.node.clone(), rhs.node.clone()],
[lhs.graph.clone(), rhs.graph.clone()],
)
.prepare::<C>([lhs.node.clone(), rhs.node.clone()])
.memory_bound()
.retro_forward(RetroPowf::<B, D>::new(
lhs.node.id.clone(),
rhs.node.id.clone(),
))
.retro_forward(RetroPowf::<B, D>::new(lhs.node.id, rhs.node.id))
.parents([&lhs, &rhs])
.stateful()
{
@ -2329,9 +2273,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
}
Sign.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
Sign.prepare::<C>([tensor.node.clone()])
.memory_bound()
.retro_forward(RetroSign::<B, D>::new(tensor.node.id.clone()))
.retro_forward(RetroSign::<B, D>::new(tensor.node.id))
.parents([&tensor])
.stateless(B::float_sign(tensor.primitive))
}
@ -2394,12 +2338,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
match ExpandDim
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
.prepare::<C>([tensor.node.clone()])
.memory_bound()
.retro_forward(RetroExpand::<B, D1, D2>::new(
tensor.node.id.clone(),
shape.clone(),
))
.retro_forward(RetroExpand::<B, D1, D2>::new(tensor.node.id, shape.clone()))
.parents([&tensor])
.stateful()
{
@ -2417,7 +2358,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
descending: bool,
) -> FloatTensor<Self, D> {
match SortDim
.prepare::<C>([tensor.node], [tensor.graph])
.prepare::<C>([tensor.node])
.compute_bound()
.stateful()
{
@ -2439,7 +2380,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
descending: bool,
) -> (FloatTensor<Self, D>, IntTensor<B, D>) {
match SortDim
.prepare::<C>([tensor.node], [tensor.graph])
.prepare::<C>([tensor.node])
.compute_bound()
.stateful()
{

View File

@ -0,0 +1,23 @@
use crate::{
checkpoint::builder::CheckpointerBuilder,
grads::Gradients,
graph::StepBoxed,
tensor::{AutodiffTensor, NodeRefCount},
};
use burn_tensor::backend::Backend;
/// Client used to communicate with the autodiff server.
pub trait AutodiffClient: Send + Clone {
/// Register a new step.
fn register(&self, node_id: NodeRefCount, step: StepBoxed, actions: CheckpointerBuilder);
/// Run backpropagation starting from the given tensor.
fn backward<B: Backend, const D: usize>(&self, tensor: AutodiffTensor<B, D>) -> Gradients;
}
/// Client implementation in use.
#[cfg(feature = "std")]
pub type AutodiffClientImpl = super::mspc::ChannelClient;
/// Client implementation in use.
#[cfg(not(feature = "std"))]
pub type AutodiffClientImpl = super::mutex::MutexClient;
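The two implementations selected above satisfy the same two-method contract: record steps as the forward pass runs, then produce gradients on demand. A minimal sketch of that contract, with toy types standing in for `StepBoxed` and `Gradients` (illustrative names, not crate API):

trait Client: Send + Clone {
    fn register(&self, node_id: u64, step: &'static str);
    fn backward(&self, root: u64) -> Vec<f32>;
}

#[derive(Clone)]
struct ImmediateClient; // runs everything on the calling thread

impl Client for ImmediateClient {
    fn register(&self, node_id: u64, step: &'static str) {
        println!("node {node_id}: recorded step `{step}`");
    }
    fn backward(&self, root: u64) -> Vec<f32> {
        println!("running backward from node {root}");
        vec![1.0] // stand-in gradients
    }
}

fn main() {
    let client = ImmediateClient;
    client.register(0, "mul");
    assert_eq!(client.backward(0).len(), 1);
}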

View File

@ -0,0 +1,271 @@
use crate::{tensor::NodeRefCount, NodeID};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
/// Keeps track of the graphs created during autodiff, along with the reference count of each node.
///
/// When every node in a graph has only one remaining reference, the graph can be freed.
#[derive(Default, Debug)]
pub struct GraphMemoryManagement {
graphs: HashMap<GraphId, GraphState>,
owned: HashSet<GraphId>,
}
#[derive(new, Hash, PartialEq, Eq, Clone, Copy, Debug)]
pub struct GraphId {
node: NodeID,
}
#[derive(Debug)]
enum GraphState {
Merged(GraphId),
Owned(Vec<NodeRefCount>),
}
impl GraphMemoryManagement {
/// Register a new node along with its parents.
pub fn register(&mut self, node: NodeRefCount, parents: Vec<NodeID>) {
let node_id = *node.as_ref();
let graph_id = GraphId::new(node_id);
self.insert_owned_graph(graph_id, vec![node.clone()]);
if !parents.is_empty() {
let graph_ids = parents.into_iter().map(GraphId::new);
if let Some(parent_graph_id) = self.merge_graph(graph_ids) {
self.merge_graph([graph_id, parent_graph_id]);
}
}
}
/// Free the given graph, calling the given function for each deleted node.
pub fn free_graph<F>(&mut self, graph_id: GraphId, mut func: F)
where
F: FnMut(&NodeID),
{
self.owned.remove(&graph_id);
let graph = match self.graphs.remove(&graph_id) {
Some(graph) => graph,
None => return,
};
let graph = match graph {
GraphState::Merged(graph) => {
self.free_graph(graph, func);
return;
}
GraphState::Owned(graph) => graph,
};
for node_id in graph.into_iter() {
func(&node_id);
}
}
/// Find the graphs in which every node is an orphan.
///
/// The returned graphs can be safely freed.
pub fn find_orphan_graphs(&self) -> Vec<GraphId> {
self.owned
.iter()
.filter(|id| self.is_orphan(id))
.copied()
.collect()
}
fn is_orphan(&self, id: &GraphId) -> bool {
let graph = match self.graphs.get(id) {
Some(val) => val,
None => return false,
};
let nodes = match graph {
GraphState::Merged(_) => return false,
GraphState::Owned(nodes) => nodes,
};
for node in nodes {
if Arc::strong_count(node) > 1 {
return false;
}
}
true
}
fn insert_owned_graph(&mut self, graph_id: GraphId, nodes: Vec<NodeRefCount>) {
self.graphs.insert(graph_id, GraphState::Owned(nodes));
self.owned.insert(graph_id);
}
fn merge_graph<I: IntoIterator<Item = GraphId>>(&mut self, graph_ids: I) -> Option<GraphId> {
let graph_ids = graph_ids.into_iter();
let graph_ids = graph_ids.collect::<Vec<_>>();
let mut merged = HashSet::new();
let mut updated_nodes = Vec::new();
let mut updated_graph_id = None;
for id in graph_ids {
let graph_id = match self.find_owned_graph(id) {
Some(val) => val,
None => continue,
};
if updated_graph_id.is_none() {
updated_graph_id = Some(graph_id);
}
merged.insert(graph_id);
}
let updated_graph_id = match updated_graph_id {
Some(val) => val,
None => return None,
};
for id in merged {
let mut updated_state = GraphState::Merged(updated_graph_id);
let state = self.graphs.get_mut(&id).unwrap();
self.owned.remove(&id);
core::mem::swap(state, &mut updated_state);
if let GraphState::Owned(nodes) = updated_state {
updated_nodes.extend(nodes)
};
}
self.insert_owned_graph(updated_graph_id, updated_nodes);
Some(updated_graph_id)
}
fn find_owned_graph(&mut self, graph_id: GraphId) -> Option<GraphId> {
let graph = match self.graphs.get(&graph_id) {
Some(val) => val,
None => return None,
};
let merged_graph_id = match graph {
GraphState::Merged(graph_id) => graph_id,
GraphState::Owned(_) => return Some(graph_id),
};
self.find_owned_graph(*merged_graph_id)
}
}
impl core::fmt::Display for GraphMemoryManagement {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_fmt(format_args!(
"Graphs Memory Management with {} owned graphs and total of {} graphs\n",
self.owned.len(),
self.graphs.len()
))?;
for (id, state) in self.graphs.iter() {
f.write_fmt(format_args!("Graph {} => ", id.node.value))?;
match state {
GraphState::Merged(id) => f.write_fmt(format_args!("Merged {}", id.node.value))?,
GraphState::Owned(nodes) => {
f.write_str("Owned")?;
for node in nodes {
f.write_fmt(format_args!(" {}", node.value))?;
}
}
}
f.write_str("\n")?;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use super::*;
#[test]
fn test_graph_memory_management_connect_graphs() {
let mut graph_mm = GraphMemoryManagement::default();
let node_1 = Arc::new(NodeID::new());
let node_2 = Arc::new(NodeID::new());
let node_3 = Arc::new(NodeID::new());
let node_4 = Arc::new(NodeID::new());
let node_5 = Arc::new(NodeID::new());
graph_mm.register(node_1.clone(), vec![]);
graph_mm.register(node_2.clone(), vec![*node_1]);
assert_eq!(graph_mm.owned.len(), 1, "A single connected graph.");
graph_mm.register(node_3.clone(), vec![]);
graph_mm.register(node_4.clone(), vec![*node_3]);
assert_eq!(graph_mm.owned.len(), 2, "Two connected graphs.");
graph_mm.register(node_5.clone(), vec![*node_1, *node_3]);
assert_eq!(
graph_mm.owned.len(),
1,
"Two connected graphs are merged into one."
);
}
#[test]
fn test_graph_memory_management_find_orphans() {
let mut graph_mm = GraphMemoryManagement::default();
let node_1 = Arc::new(NodeID::new());
let node_2 = Arc::new(NodeID::new());
graph_mm.register(node_1.clone(), vec![]);
graph_mm.register(node_2.clone(), vec![*node_1]);
core::mem::drop(node_1);
assert_eq!(
graph_mm.find_orphan_graphs().len(),
0,
"Not all nodes are dropped"
);
core::mem::drop(node_2);
assert_eq!(
graph_mm.find_orphan_graphs().len(),
1,
"All nodes are dropped"
);
}
#[test]
fn test_graph_memory_management_free_graph_from_any_node() {
let mut graph_mm = GraphMemoryManagement::default();
// Create a graph and free(node_1)
let node_1 = Arc::new(NodeID::new());
let node_2 = Arc::new(NodeID::new());
graph_mm.register(node_1.clone(), vec![]);
graph_mm.register(node_2.clone(), vec![*node_1]);
let mut node_ids = Vec::new();
graph_mm.free_graph(GraphId::new(*node_1.as_ref()), |id| node_ids.push(*id));
assert!(node_ids.contains(&node_1));
assert!(node_ids.contains(&node_2));
// Same but with free(node_2);
graph_mm.register(node_1.clone(), vec![]);
graph_mm.register(node_2.clone(), vec![*node_1]);
let mut node_ids = Vec::new();
graph_mm.free_graph(GraphId::new(*node_2.as_ref()), |id| node_ids.push(*id));
assert!(node_ids.contains(&node_1));
assert!(node_ids.contains(&node_2));
}
}
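The merge logic above behaves like a union-find: each `GraphState::Merged` entry is a parent pointer, and `find_owned_graph` follows the chain to the owned root. A minimal standalone sketch of that lookup, with plain integers standing in for graph ids (illustrative, not crate API):

use std::collections::HashMap;

enum State {
    Merged(u32),     // parent pointer to the graph this one was merged into
    Owned(Vec<u32>), // root that owns the nodes
}

fn find_root(graphs: &HashMap<u32, State>, id: u32) -> Option<u32> {
    match graphs.get(&id)? {
        State::Merged(parent) => find_root(graphs, *parent),
        State::Owned(_) => Some(id),
    }
}

fn main() {
    let mut graphs = HashMap::new();
    graphs.insert(1, State::Owned(vec![1, 2, 3]));
    graphs.insert(2, State::Merged(1)); // graph 2 was merged into graph 1
    graphs.insert(3, State::Merged(2)); // graph 3 was merged into graph 2
    assert_eq!(find_root(&graphs, 3), Some(1));
}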

View File

@ -0,0 +1,11 @@
mod client;
mod memory_management;
mod server;
#[cfg(not(feature = "std"))]
pub mod mutex;
#[cfg(feature = "std")]
pub mod mspc;
pub use client::*;

View File

@ -0,0 +1,94 @@
use super::{server::AutodiffServer, AutodiffClient};
use crate::{
checkpoint::builder::CheckpointerBuilder,
grads::Gradients,
graph::StepBoxed,
tensor::{AutodiffTensor, NodeRefCount},
NodeID,
};
use burn_tensor::backend::Backend;
use std::sync::mpsc::Sender;
static INSTANCE: spin::Lazy<ChannelClient> = spin::Lazy::new(ChannelClient::init);
#[derive(Debug, Clone)]
pub struct ChannelClient {
sender: Sender<Message>,
}
enum Message {
Register {
node_id: NodeRefCount,
step: StepBoxed,
actions: CheckpointerBuilder,
},
Backward {
node_id: NodeID,
grads: Gradients,
callback: Sender<Gradients>,
},
}
impl ChannelClient {
pub(crate) fn new() -> Self {
INSTANCE.clone()
}
fn init() -> Self {
let (sender, receiver) = std::sync::mpsc::channel();
std::thread::spawn(move || {
let mut server = AutodiffServer::default();
for message in receiver.iter() {
match message {
Message::Register {
node_id,
step,
actions,
} => server.register(node_id, step, actions),
Message::Backward {
node_id,
grads,
callback,
} => {
let grads = server.backward(grads, node_id);
callback.send(grads).unwrap();
}
}
}
});
Self { sender }
}
}
impl AutodiffClient for ChannelClient {
fn register(&self, node_id: NodeRefCount, step: StepBoxed, actions: CheckpointerBuilder) {
self.sender
.send(Message::Register {
node_id,
step,
actions,
})
.unwrap()
}
fn backward<B: Backend, const D: usize>(&self, root: AutodiffTensor<B, D>) -> Gradients {
let node_id = root.node.id;
let grads = Gradients::new::<B, D>(root.node, root.primitive);
let (callback, receiver) = std::sync::mpsc::channel();
self.sender
.send(Message::Backward {
node_id,
grads,
callback,
})
.unwrap();
match receiver.recv() {
Ok(grads) => grads,
Err(err) => panic!("Error during backward {err:?}"),
}
}
}
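The `Backward` message above is a request/response round trip built from two channels: the request carries a one-shot `Sender` that the server thread replies on. The pattern in isolation (toy payloads, not crate API):

use std::sync::mpsc;
use std::thread;

fn main() {
    // Each request pairs a payload with a one-shot sender for the reply.
    let (req_tx, req_rx) = mpsc::channel::<(u64, mpsc::Sender<f32>)>();
    thread::spawn(move || {
        for (node, callback) in req_rx.iter() {
            callback.send(node as f32 * 2.0).unwrap(); // "compute" and reply
        }
    });
    let (reply_tx, reply_rx) = mpsc::channel();
    req_tx.send((21, reply_tx)).unwrap();
    assert_eq!(reply_rx.recv().unwrap(), 42.0);
}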

View File

@ -0,0 +1,49 @@
use super::{server::AutodiffServer, AutodiffClient};
use crate::{
checkpoint::builder::CheckpointerBuilder,
grads::Gradients,
graph::StepBoxed,
tensor::{AutodiffTensor, NodeRefCount},
};
use burn_tensor::backend::Backend;
#[derive(Clone, new)]
pub struct MutexClient;
impl core::fmt::Debug for MutexClient {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("MutexClient")
}
}
static SERVER: spin::Mutex<Option<AutodiffServer>> = spin::Mutex::new(None);
impl AutodiffClient for MutexClient {
fn register(&self, node_id: NodeRefCount, step: StepBoxed, actions: CheckpointerBuilder) {
let mut server = SERVER.lock();
if let Some(server) = server.as_mut() {
server.register(node_id, step, actions);
return;
}
let mut server_new = AutodiffServer::default();
server_new.register(node_id, step, actions);
*server = Some(server_new);
}
fn backward<B: Backend, const D: usize>(&self, root: AutodiffTensor<B, D>) -> Gradients {
let mut server = SERVER.lock();
let node_id = root.node.id;
let grads = Gradients::new::<B, D>(root.node, root.primitive);
if let Some(server) = server.as_mut() {
return server.backward(grads, node_id);
}
let mut server_new = AutodiffServer::default();
let gradients = server_new.backward(grads, node_id);
*server = Some(server_new);
gradients
}
}
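A mutex-guarded, lazily-initialized global is the no_std-friendly counterpart to the channel client: `spin::Mutex` needs no OS primitives. The same pattern in a minimal sketch (assumes the `spin` crate; names are illustrative):

static STATE: spin::Mutex<Option<u64>> = spin::Mutex::new(None);

fn bump() -> u64 {
    let mut guard = STATE.lock();
    let value = guard.get_or_insert(0); // create the state on first use
    *value += 1;
    *value
}

fn main() {
    assert_eq!(bump(), 1);
    assert_eq!(bump(), 2);
}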

View File

@ -0,0 +1,99 @@
use super::memory_management::GraphMemoryManagement;
use crate::{
checkpoint::{base::Checkpointer, builder::CheckpointerBuilder},
grads::Gradients,
graph::{traversal::BreadthFirstSearch, StepBoxed},
tensor::NodeRefCount,
NodeID,
};
use std::collections::HashMap;
#[derive(Default)]
pub struct AutodiffServer {
steps: HashMap<NodeID, StepBoxed>,
actions_builder: HashMap<NodeID, CheckpointerBuilder>,
memory_management: GraphMemoryManagement,
}
impl AutodiffServer {
pub fn register(&mut self, rc: NodeRefCount, step: StepBoxed, actions: CheckpointerBuilder) {
let parents = step.parents();
let node_id = *rc.as_ref();
self.memory_management.register(rc, parents);
self.steps.insert(node_id, step);
self.actions_builder.insert(node_id, actions);
}
pub fn backward(&mut self, grads: Gradients, node_id: NodeID) -> Gradients {
let step = self.steps.remove(&node_id).expect(
"Root node should have a step registered, did you forget to call \
`Tensor::register_grad` on the tensor where you need gradients?",
);
let builder = self.actions_builder.remove(&node_id).unwrap();
let (tape, builder) = self.build_tape(node_id, step, builder);
let checkpointer = builder.build(&self.steps);
let gradients = Self::execute_steps(tape, grads, checkpointer);
// Cleanup
let mut on_free_graph = |node_id: &NodeID| {
self.steps.remove(node_id);
self.actions_builder.remove(node_id);
};
for graph_id in self.memory_management.find_orphan_graphs() {
self.memory_management
.free_graph(graph_id, &mut on_free_graph);
}
gradients
}
fn build_tape(
&mut self,
root: NodeID,
root_step: StepBoxed,
mut builder: CheckpointerBuilder,
) -> (Vec<Vec<StepBoxed>>, CheckpointerBuilder) {
let mut tape = (0..root_step.order())
.map(|_| Vec::with_capacity(1))
.collect::<Vec<_>>();
BreadthFirstSearch.traverse(root, root_step, &mut self.steps, |id, step| {
let order = step.order();
if order == 0 {
return;
}
if let Some(steps) = tape.get_mut(order - 1) {
steps.push(step);
}
if let Some(node_builder) = self.actions_builder.remove(&id) {
builder.extend(node_builder);
}
});
(tape, builder)
}
fn execute_steps(
tape: Vec<Vec<StepBoxed>>,
mut grads: Gradients,
mut checkpointer: Checkpointer,
) -> Gradients {
tape.into_iter().rev().for_each(|steps| {
steps
.into_iter()
.for_each(|step| step.step(&mut grads, &mut checkpointer))
});
#[cfg(feature = "export_tests")]
// For checkpointing tests
assert!(checkpointer.is_empty());
grads
}
}
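`build_tape` groups steps by topological order and `execute_steps` then walks the groups from the highest order down, so every step runs only after all steps of higher order. The traversal shape in isolation (sketch, illustrative step names):

fn main() {
    // Steps tagged with their topological order (the root has the highest order).
    let steps: Vec<(usize, &str)> = vec![(1, "a"), (2, "b"), (2, "c"), (3, "d")];
    let max_order = 3;
    let mut tape: Vec<Vec<&str>> = (0..max_order).map(|_| Vec::new()).collect();
    for (order, name) in steps {
        tape[order - 1].push(name); // group by order, as build_tape does
    }
    for group in tape.into_iter().rev() {
        for name in group {
            println!("executing {name}"); // d first, then b and c, then a
        }
    }
}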

View File

@ -1,18 +1,22 @@
use burn_tensor::backend::Backend;
use std::sync::Arc;
use crate::{
checkpoint::{base::Checkpointer, builder::CheckpointerBuilder},
grads::Gradients,
graph::{ComputingProperty, Graph, Node, NodeID, NodeRef, Requirement, Step},
graph::{ComputingProperty, Node, NodeID, NodeRef, Requirement, Step},
runtime::{AutodiffClient, AutodiffClientImpl},
};
use burn_tensor::backend::Backend;
#[derive(Debug, Clone)]
pub struct AutodiffTensor<B: Backend, const D: usize> {
pub primitive: B::FloatTensorPrimitive<D>,
pub node: NodeRef,
pub graph: Graph,
pub rc: NodeRefCount,
}
pub type NodeRefCount = Arc<NodeID>;
#[derive(new, Debug)]
struct RootStep {
node: NodeRef,
@ -23,8 +27,16 @@ impl Step for RootStep {
// Nothing to do
}
fn node(&self) -> NodeRef {
self.node.clone()
fn node(&self) -> NodeID {
self.node.id
}
fn parents(&self) -> Vec<NodeID> {
self.node.parents.clone()
}
fn order(&self) -> usize {
self.node.order
}
}
@ -38,13 +50,14 @@ impl<B: Backend, const D: usize> AutodiffTensor<B, D> {
id,
Requirement::None,
ComputingProperty::Ambiguous,
AutodiffClientImpl::new(),
)
.into();
Self {
rc: Arc::new(node.id),
primitive,
node,
graph: Graph::new(),
}
}
@ -67,33 +80,26 @@ impl<B: Backend, const D: usize> AutodiffTensor<B, D> {
self.node = Node::new(
vec![],
0,
self.node.id.clone(),
self.node.id,
Requirement::Grad,
self.node.properties.clone(),
self.node.client.clone(),
)
.into();
let ops = RootStep::new(self.node.clone());
let step = RootStep::new(self.node.clone());
self.register_step(ops)
self.register_step(step, CheckpointerBuilder::default())
}
}
}
/// Create a tensor from its parents' information.
pub fn from_parents<I: Iterator<Item = Graph>>(
pub fn from_parents(
primitive: B::FloatTensorPrimitive<D>,
parent_nodes: &[NodeRef],
parent_graphs: I,
requirement: Requirement,
computing_properties: ComputingProperty,
checkpointer_builder: CheckpointerBuilder,
) -> Self {
let graph = parent_graphs
.reduce(|acc, graph| acc.merge(graph))
.unwrap_or_else(Graph::new);
graph.extend_checkpointer_builder(checkpointer_builder);
let order = parent_nodes
.iter()
.map(|node| node.order)
@ -101,25 +107,47 @@ impl<B: Backend, const D: usize> AutodiffTensor<B, D> {
.unwrap_or(0)
+ 1;
let client = parent_nodes
.first()
.map(|node| node.client.clone())
.unwrap_or_else(AutodiffClientImpl::new);
let node: NodeRef = Node::new(
parent_nodes.iter().map(|node| node.id.clone()).collect(),
parent_nodes.iter().map(|node| node.id).collect(),
order,
NodeID::new(),
requirement,
computing_properties,
client,
)
.into();
Self {
rc: Arc::new(node.id),
primitive,
node,
graph,
}
}
/// Register a step into a graph for that tensor.
pub fn register_step<O: Step + 'static>(mut self, ops: O) -> Self {
self.graph = self.graph.register(&self.node.id, Box::new(ops));
///
/// # Warning
///
/// This should be called only once per tensor.
pub fn register_step<S: Step + 'static>(
self,
step_that_created_the_tensor: S,
actions: CheckpointerBuilder,
) -> Self {
self.node.client.register(
self.rc.clone(),
Box::new(step_that_created_the_tensor),
actions,
);
self
}
pub fn into_primitive(self) -> B::FloatTensorPrimitive<D> {
self.primitive
}
}

View File

@ -24,10 +24,10 @@ where
Server: ComputeServer,
{
_handle: thread::JoinHandle<()>,
sender: mpsc::SyncSender<Message<Server>>,
sender: mpsc::Sender<Message<Server>>,
}
type Callback<Response> = mpsc::SyncSender<Response>;
type Callback<Response> = mpsc::Sender<Response>;
enum Message<Server>
where
@ -45,8 +45,8 @@ where
Server: ComputeServer + 'static,
{
/// Create a new mpsc compute channel.
pub fn new(mut server: Server, bound: usize) -> Self {
let (sender, receiver) = mpsc::sync_channel(bound);
pub fn new(mut server: Server) -> Self {
let (sender, receiver) = mpsc::channel();
let _handle = thread::spawn(move || {
while let Ok(message) = receiver.recv() {
@ -94,7 +94,7 @@ where
Server: ComputeServer + 'static,
{
fn read(&self, handle: &Handle<Server>) -> Reader<Vec<u8>> {
let (callback, response) = mpsc::sync_channel(1);
let (callback, response) = mpsc::channel();
self.state
.sender
@ -105,7 +105,7 @@ where
}
fn create(&self, data: &[u8]) -> Handle<Server> {
let (callback, response) = mpsc::sync_channel(1);
let (callback, response) = mpsc::channel();
self.state
.sender
@ -116,7 +116,7 @@ where
}
fn empty(&self, size: usize) -> Handle<Server> {
let (callback, response) = mpsc::sync_channel(1);
let (callback, response) = mpsc::channel();
self.state
.sender
@ -140,7 +140,7 @@ where
}
fn sync(&self) {
let (callback, response) = mpsc::sync_channel(1);
let (callback, response) = mpsc::channel();
self.state.sender.send(Message::Sync(callback)).unwrap();
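Switching from `sync_channel(bound)` to `channel()` trades a bounded sender that can block for an unbounded one that never does. The difference in isolation (sketch):

use std::sync::mpsc;

fn main() {
    let (tx, _rx) = mpsc::channel(); // unbounded: send never blocks
    for i in 0..1_000 {
        tx.send(i).unwrap();
    }

    let (sync_tx, _sync_rx) = mpsc::sync_channel(1); // bounded: send blocks when full
    sync_tx.send(1).unwrap();
    assert!(sync_tx.try_send(2).is_err()); // buffer full: try_send reports Full
}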

View File

@ -20,8 +20,10 @@ default = [
"burn-tch?/default",
"burn-tensor/default",
"burn-wgpu?/default",
"burn-autodiff?/default",
]
std = [
"burn-autodiff?/std",
"bincode/std",
"burn-candle?/std",
"burn-common/std",

View File

@ -18,10 +18,27 @@ pub trait DataLoaderIterator<O>: Iterator<Item = O> {
}
/// A data loader that can be used to iterate over a dataset.
pub trait DataLoader<O> {
pub trait DataLoader<O>: Send {
/// Returns a boxed [iterator](DataLoaderIterator) to iterate over the data loader.
fn iter<'a>(&'a self) -> Box<dyn DataLoaderIterator<O> + 'a>;
/// Returns the number of items (not the number of batches nor the number of iterations),
/// corresponding to the `items_total` of the progress returned by the iterator.
fn num_items(&self) -> usize;
}
/// An extension trait for [dataloader](DataLoader) that allows cloning it through dynamic dispatch.
///
/// Any dataloader that implements [Clone] gets this implementation automatically.
pub trait DynDataLoader<O>: DataLoader<O> {
/// Clones the dataloader, returning a boxed copy.
fn clone_dyn(&self) -> Box<dyn DynDataLoader<O>>;
}
impl<D, O> DynDataLoader<O> for D
where
D: DataLoader<O> + Clone + 'static,
{
fn clone_dyn(&self) -> Box<dyn DynDataLoader<O>> {
Box::new(self.clone())
}
}
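`clone_dyn` is the usual workaround for `Clone` not being object safe (`clone` returns `Self`, which has no known size behind a `dyn` pointer): cloning is routed through a boxed method with a blanket impl. The idiom in isolation (illustrative names):

trait Shape {
    fn area(&self) -> f64;
}

trait DynShape: Shape {
    fn clone_dyn(&self) -> Box<dyn DynShape>;
}

// Every concrete Shape that is Clone gets clone_dyn for free.
impl<T: Shape + Clone + 'static> DynShape for T {
    fn clone_dyn(&self) -> Box<dyn DynShape> {
        Box::new(self.clone())
    }
}

#[derive(Clone)]
struct Square(f64);

impl Shape for Square {
    fn area(&self) -> f64 {
        self.0 * self.0
    }
}

fn main() {
    let a: Box<dyn DynShape> = Box::new(Square(2.0));
    let b = a.clone_dyn(); // independent boxed copy
    assert_eq!(b.area(), 4.0);
}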

View File

@ -1,6 +1,6 @@
use super::{
batcher::Batcher, BatchStrategy, DataLoader, DataLoaderIterator, MultiThreadDataLoader,
Progress,
batcher::DynBatcher, BatchStrategy, DataLoader, DataLoaderIterator, DynDataLoader,
MultiThreadDataLoader, Progress,
};
use burn_dataset::{
transform::{PartialDataset, ShuffledDataset},
@ -13,8 +13,19 @@ use std::sync::Arc;
pub struct BatchDataLoader<I, O> {
strategy: Box<dyn BatchStrategy<I>>,
dataset: Arc<dyn Dataset<I>>,
batcher: Arc<dyn Batcher<I, O>>,
rng: Option<spin::Mutex<rand::rngs::StdRng>>,
batcher: Box<dyn DynBatcher<I, O>>,
rng: Option<Arc<spin::Mutex<rand::rngs::StdRng>>>,
}
impl<I, O> Clone for BatchDataLoader<I, O> {
fn clone(&self) -> Self {
Self {
strategy: self.strategy.clone_dyn(),
dataset: self.dataset.clone(),
batcher: self.batcher.clone_dyn(),
rng: self.rng.clone(),
}
}
}
impl<I, O> BatchDataLoader<I, O> {
@ -34,14 +45,14 @@ impl<I, O> BatchDataLoader<I, O> {
pub fn new(
strategy: Box<dyn BatchStrategy<I>>,
dataset: Arc<dyn Dataset<I>>,
batcher: Arc<dyn Batcher<I, O>>,
batcher: Box<dyn DynBatcher<I, O>>,
rng: Option<rand::rngs::StdRng>,
) -> Self {
Self {
strategy,
dataset,
batcher,
rng: rng.map(spin::Mutex::new),
rng: rng.map(|rng| Arc::new(spin::Mutex::new(rng))),
}
}
}
@ -51,13 +62,13 @@ struct BatchDataloaderIterator<I, O> {
current_index: usize,
strategy: Box<dyn BatchStrategy<I>>,
dataset: Arc<dyn Dataset<I>>,
batcher: Arc<dyn Batcher<I, O>>,
batcher: Box<dyn DynBatcher<I, O>>,
}
impl<I, O> BatchDataLoader<I, O>
where
I: Send + Sync + Clone + 'static,
O: Send + Sync + Clone + 'static,
O: Send + Clone + 'static,
{
/// Creates a new multi-threaded batch data loader.
///
@ -74,14 +85,13 @@ where
pub fn multi_thread(
strategy: Box<dyn BatchStrategy<I>>,
dataset: Arc<dyn Dataset<I>>,
batcher: Arc<dyn Batcher<I, O>>,
batcher: Box<dyn DynBatcher<I, O>>,
num_threads: usize,
mut rng: Option<rand::rngs::StdRng>,
) -> MultiThreadDataLoader<O> {
let datasets = PartialDataset::split(dataset, num_threads);
let mut dataloaders: Vec<Arc<dyn DataLoader<_> + Send + Sync>> =
Vec::with_capacity(num_threads);
let mut dataloaders = Vec::with_capacity(num_threads);
// Create more rngs from the first one, one for each new dataloader.
let rngs = (0..num_threads).map(|_| {
@ -90,17 +100,21 @@ where
});
for (dataset, rng) in datasets.into_iter().zip(rngs) {
let strategy = strategy.new_like();
let strategy = strategy.clone_dyn();
let dataloader =
BatchDataLoader::new(strategy, Arc::new(dataset), batcher.clone(), rng);
let dataloader = Arc::new(dataloader);
BatchDataLoader::new(strategy, Arc::new(dataset), batcher.clone_dyn(), rng);
let dataloader: Box<dyn DynDataLoader<_>> = Box::new(dataloader);
dataloaders.push(dataloader);
}
MultiThreadDataLoader::new(dataloaders)
}
}
impl<I: Send + Sync + Clone + 'static, O: Send + Sync> DataLoader<O> for BatchDataLoader<I, O> {
impl<I, O> DataLoader<O> for BatchDataLoader<I, O>
where
I: Send + Sync + Clone + 'static,
O: Send + 'static,
{
fn iter<'a>(&'a self) -> Box<dyn DataLoaderIterator<O> + 'a> {
// When starting a new iteration, we first check if the dataloader was created with an rng,
// implying that we should shuffle the dataset beforehand, while advancing the current
@ -117,9 +131,9 @@ impl<I: Send + Sync + Clone + 'static, O: Send + Sync> DataLoader<O> for BatchDa
None => self.dataset.clone(),
};
Box::new(BatchDataloaderIterator::new(
self.strategy.new_like(),
self.strategy.clone_dyn(),
dataset,
self.batcher.clone(),
self.batcher.clone_dyn(),
))
}
@ -143,7 +157,7 @@ impl<I, O> BatchDataloaderIterator<I, O> {
pub fn new(
strategy: Box<dyn BatchStrategy<I>>,
dataset: Arc<dyn Dataset<I>>,
batcher: Arc<dyn Batcher<I, O>>,
batcher: Box<dyn DynBatcher<I, O>>,
) -> Self {
BatchDataloaderIterator {
current_index: 0,
@ -192,7 +206,7 @@ mod tests {
#[test]
fn test_batch_dataloader() {
let batcher = Arc::new(TestBatcher::new());
let batcher = Box::new(TestBatcher::new());
let dataset = Arc::new(FakeDataset::<String>::new(27));
let dataloader = BatchDataLoader::new(
Box::new(FixBatchStrategy::new(5)),
@ -219,12 +233,12 @@ mod tests {
#[test]
fn test_multi_thread_batch_dataloader() {
let batcher = Arc::new(TestBatcher::new());
let batcher = Box::new(TestBatcher::new());
let dataset = Arc::new(FakeDataset::<String>::new(27));
let dataloader_single_thread = BatchDataLoader::new(
Box::new(FixBatchStrategy::new(5)),
dataset.clone(),
batcher.clone(),
batcher.clone_dyn(),
None,
);
let dataloader_multi_thread = BatchDataLoader::multi_thread(

View File

@ -1,5 +1,5 @@
/// A trait for batching items of type `I` into items of type `O`.
pub trait Batcher<I, O>: Send + Sync {
pub trait Batcher<I, O>: Send {
/// Batches the given items.
///
/// # Arguments
@ -12,9 +12,27 @@ pub trait Batcher<I, O>: Send + Sync {
fn batch(&self, items: Vec<I>) -> O;
}
/// An extension trait for [batcher](Batcher) that allows cloning it through dynamic dispatch.
///
/// Any batcher that implements [Clone] gets this implementation automatically.
pub trait DynBatcher<I, O>: Send + Batcher<I, O> {
/// Clones the batcher, returning a boxed copy.
fn clone_dyn(&self) -> Box<dyn DynBatcher<I, O>>;
}
impl<B, I, O> DynBatcher<I, O> for B
where
B: Batcher<I, O> + Clone + 'static,
{
fn clone_dyn(&self) -> Box<dyn DynBatcher<I, O>> {
Box::new(self.clone())
}
}
#[cfg(test)]
#[derive(new)]
#[derive(new, Clone)]
pub struct TestBatcher;
#[cfg(test)]
impl<I> Batcher<I, Vec<I>> for TestBatcher {
fn batch(&self, items: Vec<I>) -> Vec<I> {

View File

@ -1,4 +1,4 @@
use super::{batcher::Batcher, BatchDataLoader, BatchStrategy, DataLoader, FixBatchStrategy};
use super::{batcher::DynBatcher, BatchDataLoader, BatchStrategy, DataLoader, FixBatchStrategy};
use burn_dataset::Dataset;
use rand::{rngs::StdRng, SeedableRng};
use std::sync::Arc;
@ -6,7 +6,7 @@ use std::sync::Arc;
/// A builder for data loaders.
pub struct DataLoaderBuilder<I, O> {
strategy: Option<Box<dyn BatchStrategy<I>>>,
batcher: Arc<dyn Batcher<I, O>>,
batcher: Box<dyn DynBatcher<I, O>>,
num_threads: Option<usize>,
shuffle: Option<u64>,
}
@ -14,7 +14,7 @@ pub struct DataLoaderBuilder<I, O> {
impl<I, O> DataLoaderBuilder<I, O>
where
I: Send + Sync + Clone + std::fmt::Debug + 'static,
O: Send + Sync + Clone + std::fmt::Debug + 'static,
O: Send + Clone + std::fmt::Debug + 'static,
{
/// Creates a new data loader builder.
///
@ -27,10 +27,10 @@ where
/// The data loader builder.
pub fn new<B>(batcher: B) -> Self
where
B: Batcher<I, O> + 'static,
B: DynBatcher<I, O> + 'static,
{
Self {
batcher: Arc::new(batcher),
batcher: Box::new(batcher),
strategy: None,
num_threads: None,
shuffle: None,

View File

@ -1,12 +1,12 @@
use super::{DataLoader, DataLoaderIterator, Progress};
use std::sync::{mpsc, Arc};
use super::{DataLoader, DataLoaderIterator, DynDataLoader, Progress};
use std::sync::mpsc;
use std::thread;
const MAX_QUEUED_ITEMS: usize = 100;
/// A multi-threaded data loader that can be used to iterate over a dataset.
pub struct MultiThreadDataLoader<O> {
dataloaders: Vec<Arc<dyn DataLoader<O> + Send + Sync>>,
dataloaders: Vec<Box<dyn DynDataLoader<O>>>,
}
/// A message that can be sent between threads.
@ -36,7 +36,7 @@ impl<O> MultiThreadDataLoader<O> {
/// # Returns
///
/// The multi-threaded data loader.
pub fn new(dataloaders: Vec<Arc<dyn DataLoader<O> + Send + Sync>>) -> Self {
pub fn new(dataloaders: Vec<Box<dyn DynDataLoader<O>>>) -> Self {
Self { dataloaders }
}
}
@ -52,11 +52,10 @@ where
let handlers: Vec<_> = self
.dataloaders
.clone()
.into_iter()
.iter()
.enumerate()
.map(|(index, dataloader)| {
let dataloader_cloned = dataloader;
let dataloader_cloned = dataloader.clone_dyn();
let sender_cloned = sender.clone();
progresses.push(Progress::new(0, dataloader_cloned.num_items()));

View File

@ -1,5 +1,5 @@
/// A strategy to batch items.
pub trait BatchStrategy<I>: Send + Sync {
pub trait BatchStrategy<I>: Send {
/// Adds an item to the strategy.
///
/// # Arguments
@ -23,7 +23,7 @@ pub trait BatchStrategy<I>: Send + Sync {
/// # Returns
///
/// The new strategy.
fn new_like(&self) -> Box<dyn BatchStrategy<I>>;
fn clone_dyn(&self) -> Box<dyn BatchStrategy<I>>;
}
/// A strategy to batch items with a fixed batch size.
@ -50,7 +50,7 @@ impl<I> FixBatchStrategy<I> {
}
}
impl<I: Send + Sync + 'static> BatchStrategy<I> for FixBatchStrategy<I> {
impl<I: Send + 'static> BatchStrategy<I> for FixBatchStrategy<I> {
fn add(&mut self, item: I) {
self.items.push(item);
}
@ -70,7 +70,7 @@ impl<I: Send + Sync + 'static> BatchStrategy<I> for FixBatchStrategy<I> {
Some(items)
}
fn new_like(&self) -> Box<dyn BatchStrategy<I>> {
fn clone_dyn(&self) -> Box<dyn BatchStrategy<I>> {
Box::new(Self::new(self.batch_size))
}
}

View File

@ -80,7 +80,7 @@ macro_rules! module {
/// my_other_field: usize,
/// }
/// ```
pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
pub trait Module<B: Backend>: Clone + Send + core::fmt::Debug {
/// Type to save and load the module.
type Record: Record<B>;
@ -238,7 +238,7 @@ pub trait ModuleMapper<B: Backend> {
}
/// Module with auto-differentiation backend.
pub trait AutodiffModule<B: AutodiffBackend>: Module<B> + Send + Sync + core::fmt::Debug {
pub trait AutodiffModule<B: AutodiffBackend>: Module<B> + Send + core::fmt::Debug {
/// Inner module without auto-differentiation.
type InnerModule: Module<B::InnerBackend>;

View File

@ -1,7 +1,8 @@
use super::ParamId;
use alloc::boxed::Box;
use alloc::format;
use burn_common::stub::{RwLock, SyncOnceCell};
use burn_common::stub::RwLock;
use core::cell::OnceCell;
use core::ops::Deref;
/// Parameters are the fundamental building blocks of [modules](crate::module::Module) where they
@ -31,7 +32,7 @@ use core::ops::Deref;
/// ```
pub struct Param<T: Parameter> {
pub(crate) id: ParamId,
state: SyncOnceCell<T>,
state: OnceCell<T>,
/// The locking is only required because of `lazy_device` and `lazy_is_require_grad`.
///
/// Because of once cell, we have a guarantee that the initialization will only be called once,
@ -54,7 +55,7 @@ impl<T: Parameter> core::fmt::Debug for Param<T> {
}
/// Trait that defines what is necessary for a type to be a parameter.
pub trait Parameter: Clone + core::fmt::Debug + Send + Sync {
pub trait Parameter: Clone + core::fmt::Debug + Send {
/// The device type to be used.
type Device: Clone;
@ -70,7 +71,7 @@ pub trait Parameter: Clone + core::fmt::Debug + Send + Sync {
#[allow(clippy::type_complexity)]
struct Uninitialized<P: Parameter> {
init: Box<dyn Fn(&P::Device, bool) -> P + Send + Sync>,
init: Box<dyn Fn(&P::Device, bool) -> P + Send>,
device: P::Device,
is_require_grad: bool,
}
@ -87,7 +88,7 @@ impl<T: Parameter> Param<T> {
pub fn initialized(id: ParamId, value: T) -> Self {
Self {
id,
state: SyncOnceCell::initialized(value),
state: OnceCell::from(value),
initialization: None,
}
}
@ -95,11 +96,11 @@ impl<T: Parameter> Param<T> {
/// Create a new parameter that is not already initialized.
pub fn uninitialized<F>(id: ParamId, init: F, device: T::Device, is_require_grad: bool) -> Self
where
F: Fn(&T::Device, bool) -> T + Send + Sync + 'static,
F: Fn(&T::Device, bool) -> T + Send + 'static,
{
Self {
id,
state: SyncOnceCell::new(),
state: OnceCell::new(),
initialization: Some(RwLock::new(Some(Uninitialized {
init: Box::new(init),
device,
@ -149,7 +150,7 @@ impl<T: Parameter> Param<T> {
Self {
id,
state: SyncOnceCell::initialized(tensor),
state: OnceCell::from(tensor),
initialization: None,
}
}
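`core::cell::OnceCell` keeps the write-once guarantee but is not `Sync`, which is exactly the bound this commit drops from `Param`. Its behavior in brief (sketch):

use core::cell::OnceCell;

fn main() {
    let cell: OnceCell<i32> = OnceCell::new();
    assert!(cell.get().is_none());
    assert_eq!(*cell.get_or_init(|| 42), 42); // initializer runs once
    assert_eq!(*cell.get_or_init(|| 7), 42); // later initializers are ignored
}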

View File

@ -5,7 +5,7 @@ use core::fmt::Debug;
impl<T, B> Module<B> for Option<T>
where
T: Module<B> + Debug + Send + Sync + Clone,
T: Module<B> + Debug + Send + Clone,
B: Backend,
{
type Record = Option<T::Record>;
@ -48,7 +48,7 @@ where
impl<T, B> AutodiffModule<B> for Option<T>
where
T: AutodiffModule<B> + Debug + Send + Sync + Clone,
T: AutodiffModule<B> + Debug + Send + Clone,
B: AutodiffBackend,
{
type InnerModule = Option<T::InnerModule>;
@ -60,7 +60,7 @@ where
impl<T, B> Module<B> for Vec<T>
where
T: Module<B> + Debug + Send + Sync + Clone,
T: Module<B> + Debug + Send + Clone,
B: Backend,
{
type Record = Vec<T::Record>;
@ -116,7 +116,7 @@ where
impl<T, B> AutodiffModule<B> for Vec<T>
where
T: AutodiffModule<B> + Debug + Send + Sync + Clone,
T: AutodiffModule<B> + Debug + Send + Clone,
B: AutodiffBackend,
{
type InnerModule = Vec<T::InnerModule>;
@ -128,7 +128,7 @@ where
impl<const N: usize, T, B> Module<B> for [T; N]
where
T: Module<B> + Debug + Send + Sync + Clone + Copy,
T: Module<B> + Debug + Send + Clone + Copy,
T::Record: Debug,
B: Backend,
{
@ -185,7 +185,7 @@ where
impl<const N: usize, T, B> AutodiffModule<B> for [T; N]
where
T: AutodiffModule<B> + Debug + Send + Sync + Clone + Copy,
T: AutodiffModule<B> + Debug + Send + Clone + Copy,
T::InnerModule: Copy + Debug,
<T::InnerModule as Module<B::InnerBackend>>::Record: Debug,
<T as Module<B>>::Record: Debug,
@ -210,7 +210,7 @@ macro_rules! impl_module_tuple {
impl<B, $($l,)*> Module<B> for ($($l,)*)
where
B: Backend,
$($l: Module<B> + Debug + Send + Sync + Clone,)*
$($l: Module<B> + Debug + Send + Clone,)*
{
type Record = ($($l::Record),*);
@ -247,7 +247,7 @@ macro_rules! impl_module_tuple {
impl<B, $($l,)*> AutodiffModule<B> for ($($l,)*)
where
B: AutodiffBackend,
$($l: AutodiffModule<B> + Debug + Send + Sync + Clone,)*
$($l: AutodiffModule<B> + Debug + Send + Clone,)*
{
type InnerModule = ($($l::InnerModule,)*);

View File

@ -2,6 +2,7 @@ use super::ParamId;
use crate::module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor, Param};
use alloc::sync::Arc;
use alloc::vec::Vec;
use burn_common::stub::Mutex;
use burn_tensor::{
backend::{AutodiffBackend, Backend},
Tensor,
@ -10,7 +11,6 @@ use burn_tensor::{
#[cfg(feature = "std")]
mod threading {
pub(super) use std::collections::HashMap;
pub(super) use std::sync::{Mutex, RwLock};
pub(super) use std::thread::ThreadId;
#[inline(always)]
@ -21,7 +21,7 @@ mod threading {
#[cfg(not(feature = "std"))]
mod threading {
pub(super) use burn_common::stub::{Mutex, RwLock, ThreadId};
pub(super) use burn_common::stub::ThreadId;
pub(super) use hashbrown::HashMap;
#[inline(always)]
@ -42,20 +42,20 @@ use threading::*;
pub struct RunningState<V> {
id: ParamId,
values: Arc<Mutex<HashMap<ThreadId, V>>>,
value: Arc<RwLock<V>>,
value: Arc<Mutex<V>>,
}
impl<const D: usize, B: Backend> Module<B> for RunningState<Tensor<B, D>> {
type Record = Param<Tensor<B, D>>;
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
let tensor = self.value.read().unwrap();
let tensor = self.value.lock().unwrap();
visitor.visit_float(&self.id, &tensor)
}
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
let mut tensor = self.value.write().unwrap();
let mut tensor = self.value.lock().unwrap();
let tensor_out = mapper.map_float(&self.id, tensor.clone());
*tensor = tensor_out;
@ -66,13 +66,13 @@ impl<const D: usize, B: Backend> Module<B> for RunningState<Tensor<B, D>> {
fn into_record(self) -> Self::Record {
self.sync();
let tensor = self.value.read().unwrap();
let tensor = self.value.lock().unwrap();
Param::initialized(self.id, tensor.clone())
}
fn load_record(mut self, record: Self::Record) -> Self {
let mut tensor = self.value.write().unwrap();
let mut tensor = self.value.lock().unwrap();
*tensor = record.val().to_device(&tensor.device());
self.id = record.id;
@ -82,7 +82,7 @@ impl<const D: usize, B: Backend> Module<B> for RunningState<Tensor<B, D>> {
}
fn to_device(self, device: &<B as Backend>::Device) -> Self {
let mut tensor = self.value.write().unwrap();
let mut tensor = self.value.lock().unwrap();
let tensor_out = tensor.clone().to_device(device);
*tensor = tensor_out;
@ -99,7 +99,7 @@ impl<const D: usize, B: Backend> Module<B> for RunningState<Tensor<B, D>> {
&self,
mut devices: Vec<<B as Backend>::Device>,
) -> Vec<<B as Backend>::Device> {
let device = self.value.read().unwrap().device();
let device = self.value.lock().unwrap().device();
if !devices.contains(&device) {
devices.push(device)
@ -115,7 +115,7 @@ impl<const D: usize, B: Backend> RunningState<Tensor<B, D>> {
Self {
id: ParamId::new(),
values: Arc::new(Mutex::new(HashMap::new())),
value: Arc::new(RwLock::new(value)),
value: Arc::new(Mutex::new(value)),
}
}
@ -124,7 +124,7 @@ impl<const D: usize, B: Backend> RunningState<Tensor<B, D>> {
Self {
id,
values: Arc::new(Mutex::new(HashMap::new())),
value: Arc::new(RwLock::new(value)),
value: Arc::new(Mutex::new(value)),
}
}
@ -134,7 +134,7 @@ impl<const D: usize, B: Backend> RunningState<Tensor<B, D>> {
Self {
id: record.id,
values: Arc::new(Mutex::new(HashMap::new())),
value: Arc::new(RwLock::new(tensor)),
value: Arc::new(Mutex::new(tensor)),
}
}
@ -156,7 +156,7 @@ impl<const D: usize, B: Backend> RunningState<Tensor<B, D>> {
///
/// The current value might be outdated by one update.
pub fn value(&self) -> Tensor<B, D> {
let value = self.value.read().unwrap();
let value = self.value.lock().unwrap();
value.clone()
}
@ -174,7 +174,7 @@ impl<const D: usize, B: Backend> RunningState<Tensor<B, D>> {
self.update_value(&mut map);
}
let value = self.value.read().unwrap();
let value = self.value.lock().unwrap();
value.clone()
}
@ -204,7 +204,7 @@ impl<const D: usize, B: Backend> RunningState<Tensor<B, D>> {
if let Some(value) = value_updated {
let value = value.div_scalar(counter);
let mut value_old = self.value.write().unwrap();
let mut value_old = self.value.lock().unwrap();
*value_old = value;
}
}

View File

@ -5,7 +5,7 @@ use crate::tensor::backend::AutodiffBackend;
use crate::LearningRate;
/// General trait to optimize [module](AutodiffModule).
pub trait Optimizer<M, B>: Send + Sync
pub trait Optimizer<M, B>: Send
where
M: AutodiffModule<B>,
B: AutodiffBackend,

View File

@ -5,7 +5,7 @@ use super::PrecisionSettings;
use serde::{de::DeserializeOwned, Serialize};
/// Trait to define a family of types which can be recorded using any [settings](PrecisionSettings).
pub trait Record<B: Backend>: Send + Sync {
pub trait Record<B: Backend>: Send {
/// Type of the item that can be serialized and deserialized.
type Item<S: PrecisionSettings>: Serialize + DeserializeOwned;

View File

@ -15,7 +15,7 @@ pub struct Bool;
/// A type-level representation of the kind of a tensor.
pub trait TensorKind<B: Backend>: Clone + core::fmt::Debug {
/// The primitive type of the tensor.
type Primitive<const D: usize>: Clone + core::fmt::Debug + Sync + Send;
type Primitive<const D: usize>: Clone + core::fmt::Debug + Send;
/// The name of the tensor kind.
fn name() -> &'static str;

View File

@ -34,8 +34,8 @@ use super::BackendBridge;
///
/// ### Multi-Threaded
///
/// Backend tensor types are all `Clone` + `Sync` + `Send`, which allows them to be safely
/// shared between threads. It is recommended to wrap tensors with [Arc](alloc::sync::Arc),
/// Backend tensor types are all `Clone` + `Send`, which allows them to be safely
/// sent between threads. It is recommended to wrap tensors with [Arc](alloc::sync::Arc),
/// which avoids copying the tensor's buffer. Note that it is still possible to mutate and
/// reuse tensors' buffer without locking; see the next section on the Mutable API.
///
@ -72,17 +72,17 @@ pub trait Backend:
type FullPrecisionBridge: BackendBridge<Self> + 'static;
/// Tensor primitive to be used for all float operations.
type FloatTensorPrimitive<const D: usize>: Clone + Send + Sync + 'static + core::fmt::Debug;
type FloatTensorPrimitive<const D: usize>: Clone + Send + 'static + core::fmt::Debug;
/// Float element type.
type FloatElem: Element;
/// Tensor primitive to be used for all int operations.
type IntTensorPrimitive<const D: usize>: Clone + Send + Sync + 'static + core::fmt::Debug;
type IntTensorPrimitive<const D: usize>: Clone + Send + 'static + core::fmt::Debug;
/// Int element type.
type IntElem: Element;
/// Tensor primitive to be used for all bool operations.
type BoolTensorPrimitive<const D: usize>: Clone + Send + Sync + 'static + core::fmt::Debug;
type BoolTensorPrimitive<const D: usize>: Clone + Send + 'static + core::fmt::Debug;
/// If autodiff is enabled.
fn ad_enabled() -> bool {
@ -109,7 +109,7 @@ pub trait AutodiffBackend: Backend {
>;
/// Gradients type.
type Gradients: Send + Sync;
type Gradients: Send;
/// Backward pass.
///

View File

@ -12,7 +12,7 @@ use crate::{backend::Backend, Tensor};
/// Contains tensor of arbitrary dimension.
#[derive(Debug)]
pub struct TensorContainer<ID> {
tensors: HashMap<ID, Box<dyn Any + Send + Sync>>,
tensors: HashMap<ID, Box<dyn Any + Send>>,
}
impl<ID> Default for TensorContainer<ID>

View File

@ -24,8 +24,8 @@ use burn_core::tensor::backend::AutodiffBackend;
/// Struct to configure and create a [learner](Learner).
pub struct LearnerBuilder<B, T, V, M, O, S>
where
T: Send + Sync + 'static,
V: Send + Sync + 'static,
T: Send + 'static,
V: Send + 'static,
B: AutodiffBackend,
M: AutodiffModule<B>,
O: Optimizer<M, B>,
@ -58,8 +58,8 @@ where
impl<B, T, V, M, O, S> LearnerBuilder<B, T, V, M, O, S>
where
B: AutodiffBackend,
T: Send + Sync + 'static,
V: Send + Sync + 'static,
T: Send + 'static,
V: Send + 'static,
M: AutodiffModule<B> + core::fmt::Display + 'static,
O: Optimizer<M, B>,
S: LrScheduler<B>,

View File

@ -11,6 +11,7 @@ const MEAN: [f32; 3] = [0.4914, 0.48216, 0.44653];
const STD: [f32; 3] = [0.24703, 0.24349, 0.26159];
/// Normalizer for the CIFAR-10 dataset.
#[derive(Clone)]
pub struct Normalizer<B: Backend> {
pub mean: Tensor<B, 4>,
pub std: Tensor<B, 4>,
@ -36,6 +37,7 @@ impl<B: Backend> Normalizer<B> {
}
}
#[derive(Clone)]
pub struct ClassificationBatcher<B: Backend> {
normalizer: Normalizer<B>,
device: B::Device,

View File

@ -87,13 +87,13 @@ impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
// Register the gradient for each variable based on whether they are marked as
// `tracked`.
if let Some(node) = node_bias {
grads.register::<B, D>(node, grad_bias);
grads.register::<B, D>(node.id, grad_bias);
}
if let Some(node) = node_lhs {
grads.register::<B, D>(node, grad_lhs);
grads.register::<B, D>(node.id, grad_lhs);
}
if let Some(node) = node_rhs {
grads.register::<B, D>(node, grad_rhs);
grads.register::<B, D>(node.id, grad_rhs);
}
}
}
@ -102,10 +102,7 @@ impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
//
// Each node can be fetched with `ops.parents` in the same order as defined here.
match FusedMatmulAddReluBackward
.prepare::<C>(
[lhs.node.clone(), rhs.node.clone(), bias.node.clone()],
[lhs.graph.clone(), rhs.graph.clone(), bias.graph.clone()],
)
.prepare::<C>([lhs.node.clone(), rhs.node.clone(), bias.node.clone()])
// Marks the operation as compute bound, meaning it will save its
// state instead of recomputing itself during checkpointing
.compute_bound()

View File

@ -3,6 +3,7 @@ use burn::{
prelude::*,
};
#[derive(Clone)]
pub struct MnistBatcher<B: Backend> {
device: B::Device,
}

View File

@ -3,6 +3,7 @@ use burn::{
prelude::*,
};
#[derive(Clone, Debug)]
pub struct MnistBatcher<B: Backend> {
device: B::Device,
}

View File

@ -109,6 +109,7 @@ impl DiabetesDataset {
}
}
#[derive(Clone, Debug)]
pub struct DiabetesBatcher<B: Backend> {
device: B::Device,
}

View File

@ -15,7 +15,7 @@ use burn::{data::dataloader::batcher::Batcher, nn::attention::generate_padding_m
use std::sync::Arc;
/// Struct for batching text classification items
#[derive(new)]
#[derive(Clone, new)]
pub struct TextClassificationBatcher<B: Backend> {
tokenizer: Arc<dyn Tokenizer>, // Tokenizer for converting text to token IDs
device: B::Device, // Device on which to perform computation (e.g., CPU or CUDA device)

View File

@ -2,7 +2,7 @@ use super::{dataset::TextGenerationItem, tokenizer::Tokenizer};
use burn::{data::dataloader::batcher::Batcher, nn::attention::generate_padding_mask, prelude::*};
use std::sync::Arc;
#[derive(new)]
#[derive(Clone, new)]
pub struct TextGenerationBatcher {
tokenizer: Arc<dyn Tokenizer>,
max_seq_length: usize,