mirror of https://github.com/tracel-ai/burn.git

[Breaking] Make Tensor, Module, Optimizer !Sync + Refactor Autodiff (#1575)

This commit is contained in:
parent ce898ff899
commit 1239d9bfa3
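The `!Sync` in the title means these types remain movable between threads (`Send`) but can no longer be shared behind `&` across threads. A minimal, self-contained sketch of that distinction, using `std::cell::Cell` as a hypothetical stand-in rather than Burn's actual `Tensor`:

```rust
use std::cell::Cell;

// Stand-in for a type that is Send but not Sync, mirroring the new bounds
// described in the commit title (illustrative only, not Burn's Tensor).
struct NotSync {
    inner: Cell<u32>, // Cell<T> is Send (for T: Send) but never Sync.
}

fn main() {
    let value = NotSync { inner: Cell::new(42) };

    // Moving the value into another thread still works because it is Send.
    let handle = std::thread::spawn(move || value.inner.get());
    assert_eq!(handle.join().unwrap(), 42);

    // Sharing `&value` across threads would no longer compile, because
    // `&NotSync` is only Send when `NotSync` is Sync.
}
```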
@@ -11,8 +11,9 @@ repository = "https://github.com/tracel-ai/burn/tree/main/burn-autodiff"
 version.workspace = true

 [features]
-default = []
+default = ["std"]
 export_tests = ["burn-tensor-testgen"]
+std = []

 [dependencies]
 burn-common = { path = "../burn-common", version = "0.13.0" }
@@ -1,7 +1,7 @@
 use crate::{
     checkpoint::strategy::{CheckpointStrategy, NoCheckpointing},
     grads::Gradients,
-    graph::backward::backward,
+    runtime::AutodiffClient,
     tensor::AutodiffTensor,
     AutodiffBridge,
 };
@@ -53,7 +53,9 @@ impl<B: Backend, C: CheckpointStrategy> AutodiffBackend for Autodiff<B, C> {
     type Gradients = Gradients;

     fn backward<const D: usize>(tensor: AutodiffTensor<B, D>) -> Gradients {
-        backward(tensor)
+        let client = tensor.node.client.clone();
+
+        AutodiffClient::backward(&client, tensor)
     }

     fn grad<const D: usize>(
@@ -83,7 +85,7 @@ impl<B: Backend, C: CheckpointStrategy> AutodiffBackend for Autodiff<B, C> {
         grad: B::FloatTensorPrimitive<D>,
     ) {
         grads.remove(tensor);
-        grads.register::<B, D>(tensor.node.clone(), grad);
+        grads.register::<B, D>(tensor.node.id, grad);
     }

     fn int_inner<const D: usize>(
@@ -39,7 +39,7 @@ where
     _bridge: PhantomData<Bridge>,
 }

-#[derive(new, Debug)]
+#[derive(new, Debug, Clone)]
 struct RetroIntoTarget<B: Backend, Bridge: BackendBridge<B>, const D: usize> {
     tensor_id: NodeID,
     _backend: PhantomData<B>,
@@ -84,9 +84,9 @@ where
         _backend: PhantomData,
         _bridge: PhantomData,
     }
-    .prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
+    .prepare::<C>([tensor.node.clone()])
     .memory_bound()
-    .retro_forward(RetroIntoTarget::<B, Bridge, D>::new(tensor.node.id.clone()))
+    .retro_forward(RetroIntoTarget::<B, Bridge, D>::new(tensor.node.id))
     .parents([&tensor])
     .stateless(Bridge::into_target(tensor.primitive, None))
 }
@@ -101,7 +101,7 @@ where
     _bridge: PhantomData<Bridge>,
 }

-#[derive(new, Debug)]
+#[derive(new, Debug, Clone)]
 struct RetroFromTarget<B: Backend, Bridge: BackendBridge<B>, const D: usize> {
     tensor_id: NodeID,
     _backend: PhantomData<B>,
@@ -146,9 +146,9 @@ where
         _backend: PhantomData,
         _bridge: PhantomData,
     }
-    .prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
+    .prepare::<C>([tensor.node.clone()])
     .memory_bound()
-    .retro_forward(RetroFromTarget::<B, Bridge, D>::new(tensor.node.id.clone()))
+    .retro_forward(RetroFromTarget::<B, Bridge, D>::new(tensor.node.id))
     .parents([&tensor])
     .stateless(Bridge::from_target(tensor.primitive, None))
 }
@@ -1,22 +1,20 @@
-use std::collections::HashMap;
-
-use crate::graph::{NodeID, NodeRef};
-
 use super::{
     retro_forward::RetroForwards,
     state::{BackwardStates, State},
 };
+use crate::graph::NodeID;
+use std::collections::HashMap;

 #[derive(new, Debug)]
 /// Links a [NodeID] to its autodiff graph [NodeRef]
 pub(crate) struct NodeTree {
-    map: HashMap<NodeID, NodeRef>,
+    map: HashMap<NodeID, Vec<NodeID>>,
 }

 impl NodeTree {
     /// Gives the parents of the node in the autodiff graph
     pub(crate) fn parents(&self, node_id: &NodeID) -> Option<Vec<NodeID>> {
-        self.map.get(node_id).map(|node| node.parents.clone())
+        self.map.get(node_id).cloned()
     }
 }
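With `NodeTree` now storing parent ids directly instead of full `NodeRef`s, a parent lookup is a plain map read plus a clone of the id vector. A small illustrative sketch of that access pattern, with a `u64`-backed stand-in id (not the real `NodeID` type):

```rust
use std::collections::HashMap;

// Hypothetical stand-in for the id type; the real NodeID wraps a u64.
type NodeId = u64;

struct NodeTree {
    map: HashMap<NodeId, Vec<NodeId>>,
}

impl NodeTree {
    // Equivalent of the new `parents`: `get(...).cloned()` returns an owned
    // Vec<NodeId> without touching any graph node structure.
    fn parents(&self, node_id: &NodeId) -> Option<Vec<NodeId>> {
        self.map.get(node_id).cloned()
    }
}

fn main() {
    let tree = NodeTree {
        map: HashMap::from([(3, vec![1, 2]), (1, vec![]), (2, vec![])]),
    };
    assert_eq!(tree.parents(&3), Some(vec![1, 2]));
    assert_eq!(tree.parents(&99), None);
}
```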
@@ -33,14 +31,12 @@ impl Checkpointer {
     /// or give their pre-computed tensors.
     pub fn retrieve_node_output<T>(&mut self, node_id: NodeID) -> T
     where
-        T: Clone + Send + Sync + 'static,
+        T: Clone + Send + 'static,
     {
-        self.topological_sort(node_id.clone())
-            .into_iter()
-            .for_each(|node| {
-                self.retro_forwards
-                    .execute_retro_forward(node, &mut self.backward_states)
-            });
+        self.topological_sort(node_id).into_iter().for_each(|node| {
+            self.retro_forwards
+                .execute_retro_forward(node, &mut self.backward_states)
+        });

         self.backward_states.get_state::<T>(&node_id)
     }
@@ -1,11 +1,9 @@
-use std::{any::Any, collections::HashMap, sync::Arc};
-
-use burn_tensor::backend::Backend;
-
 use crate::{
-    graph::{ComputingProperty, NodeID, NodeRef, NodeSteps},
+    graph::{ComputingProperty, NodeID, NodeSteps},
     tensor::AutodiffTensor,
 };
+use burn_tensor::backend::Backend;
+use std::{any::Any, collections::HashMap, sync::Arc};

 use super::{
     base::{Checkpointer, NodeTree},
@@ -20,31 +18,34 @@ pub enum CheckpointingAction {
     /// The node's already computed output should be saved
     Computed {
         /// The node
-        node_ref: NodeRef,
+        node_id: NodeID,
         /// The node's output
-        state_content: Box<dyn Any + Send + Sync>,
+        state_content: Box<dyn Any + Send>,
     },
     /// The node should recompute itself when asked
     Recompute {
         /// The node
-        node_ref: NodeRef,
+        node_id: NodeID,
         /// How the node should recompute itself
         retro_forward: Arc<dyn RetroForward>,
     },
 }

+// TODO: Remove that when proper client server.
+unsafe impl Send for CheckpointingAction {}
+
 impl CheckpointingAction {
     /// Utilitary function to access the id of the node of the checkpointing action
     pub fn id(&self) -> NodeID {
         match self {
             CheckpointingAction::Computed {
-                node_ref,
+                node_id: node_ref,
                 state_content: _,
-            } => node_ref.id.clone(),
+            } => *node_ref,
             CheckpointingAction::Recompute {
-                node_ref,
+                node_id: node_ref,
                 retro_forward: _,
-            } => node_ref.id.clone(),
+            } => *node_ref,
         }
     }
 }
@@ -83,13 +84,13 @@ impl CheckpointerBuilder {
         match &tensor.node.properties {
             ComputingProperty::ComputeBound | ComputingProperty::Ambiguous => {
                 action_list.push(CheckpointingAction::Computed {
-                    node_ref: tensor.node.clone(),
+                    node_id: tensor.node.id,
                     state_content: Box::new(tensor.primitive.clone()),
                 })
             }
             ComputingProperty::MemoryBound { retro_forward } => {
                 action_list.push(CheckpointingAction::Recompute {
-                    node_ref: tensor.node.clone(),
+                    node_id: tensor.node.id,
                     retro_forward: retro_forward.clone(),
                 })
             }
@@ -105,10 +106,6 @@ impl CheckpointerBuilder {
         }
     }

-    pub(crate) fn len(&self) -> usize {
-        self.explicit_actions.len() + self.backup_actions.len()
-    }
-
     pub(crate) fn build(self, graph: &NodeSteps) -> Checkpointer {
         let node_tree = self.make_tree(graph);
         let mut backward_states_map = HashMap::new();
@@ -143,11 +140,11 @@ impl CheckpointerBuilder {
         {
             match action {
                 CheckpointingAction::Computed {
-                    node_ref,
+                    node_id: node_ref,
                     state_content: _,
-                } => stop_nodes.push(node_ref.id.clone()),
+                } => stop_nodes.push(*node_ref),
                 CheckpointingAction::Recompute {
-                    node_ref: _,
+                    node_id: _,
                     retro_forward: _,
                 } => {}
             }
@@ -165,10 +162,10 @@ impl CheckpointerBuilder {
         for action in self.explicit_actions.iter() {
             match action {
                 CheckpointingAction::Computed {
-                    node_ref,
+                    node_id: node_ref,
                     state_content: _,
                 } => {
-                    let id = node_ref.id.clone();
+                    let id = *node_ref;
                     match n_required_map.remove(&id) {
                         Some(n) => {
                             n_required_map.insert(id, n + 1);
@@ -179,10 +176,10 @@ impl CheckpointerBuilder {
                     };
                 }
                 CheckpointingAction::Recompute {
-                    node_ref,
+                    node_id: node_ref,
                     retro_forward: _,
                 } => {
-                    let id = node_ref.id.clone();
+                    let id = *node_ref;
                     Self::update_n_required_of_parents(
                         id,
                         &mut n_required_map,
@@ -229,13 +226,13 @@ impl CheckpointerBuilder {

             match action {
                 CheckpointingAction::Computed {
-                    node_ref: _,
+                    node_id: _,
                     state_content,
                 } => {
                     self.checkpoint_compute(backward_states_map, node_id, state_content, n_required)
                 }
                 CheckpointingAction::Recompute {
-                    node_ref: _,
+                    node_id: _,
                     retro_forward,
                 } => self.checkpoint_lazy(
                     backward_states_map,
@@ -251,7 +248,7 @@ impl CheckpointerBuilder {
     fn make_tree(&self, graph: &NodeSteps) -> NodeTree {
         let mut tree = HashMap::default();
         for (id, step) in graph {
-            tree.insert(id.clone(), step.node());
+            tree.insert(*id, step.parents());
         }
         NodeTree::new(tree)
     }
@@ -267,7 +264,7 @@ impl CheckpointerBuilder {
                 n_required_map.insert(id, n + 1);
             }
             None => {
-                n_required_map.insert(id.clone(), 1);
+                n_required_map.insert(id, 1);
                 if !stop_nodes.contains(&id) {
                     if let Some(parents) = node_tree.parents(&id) {
                         for p in parents {
@@ -288,7 +285,7 @@ impl CheckpointerBuilder {
         &self,
         backward_states_map: &mut HashMap<NodeID, State>,
         node_id: NodeID,
-        state_content: Box<dyn Any + Send + Sync>,
+        state_content: Box<dyn Any + Send>,
         n_required: usize,
     ) {
         backward_states_map.insert(
@@ -308,7 +305,7 @@ impl CheckpointerBuilder {
         retro_forward: Arc<dyn RetroForward>,
         n_required: usize,
     ) {
-        retro_forward_map.insert(node_id.clone(), retro_forward);
-        backward_states_map.insert(node_id.clone(), State::Recompute { n_required });
+        retro_forward_map.insert(node_id, retro_forward);
+        backward_states_map.insert(node_id, State::Recompute { n_required });
     }
 }
@@ -7,7 +7,7 @@ use super::state::{BackwardStates, State};
 /// Definition of the forward function of a node, called during retropropagation only.
 /// This is different from the normal forward function because it reads and writes from
 /// the [InnerStates] map instead of having a clear function signature.
-pub trait RetroForward: Debug + Send + Sync + 'static {
+pub trait RetroForward: Debug + Send + 'static {
     fn forward(&self, states: &mut BackwardStates, out_node: NodeID);
 }
@@ -31,7 +31,7 @@ impl RetroForwards {
         {
             // Retro forwards are always used only once because afterwards their state is computed
             let retro_forward = self.map.remove(&node_id).unwrap();
-            retro_forward.forward(backward_states, node_id.clone());
+            retro_forward.forward(backward_states, node_id);
         }
     }
@@ -48,7 +48,7 @@ macro_rules! retro_unary_scalar {
         $name:ident,
         $ops:expr
     ) => {
-        #[derive(new, Debug)]
+        #[derive(new, Debug, Clone)]
         struct $name<B: Backend, const D: usize> {
             lhs_id: NodeID,
             rhs: FloatElem<B>,
@@ -72,7 +72,7 @@ macro_rules! retro_unary {
         $name:ident,
         $ops:expr
     ) => {
-        #[derive(new, Debug)]
+        #[derive(new, Debug, Clone)]
         struct $name<B: Backend, const D: usize> {
             input_id: NodeID,
             _backend: PhantomData<B>,
@@ -95,7 +95,7 @@ macro_rules! retro_binary {
         $name:ident,
         $ops:expr
     ) => {
-        #[derive(new, Debug)]
+        #[derive(new, Debug, Clone)]
        struct $name<B: Backend, const D: usize> {
            lhs_id: NodeID,
            rhs_id: NodeID,
@@ -3,7 +3,7 @@ use std::{any::Any, collections::HashMap};
 use crate::graph::NodeID;

 /// In order to accept arbitrary node output in the same hashmap, we need to upcast them to any.
-pub(crate) type StateContent = Box<dyn Any + Send + Sync>;
+pub(crate) type StateContent = Box<dyn Any + Send>;

 #[derive(Debug)]
 /// The state contained at one node. Encapsulates the node output if precomputed,
@@ -71,7 +71,7 @@ impl BackwardStates {
     /// This function always gives ownership of the output, but will clone it if needed for further uses.
     pub(crate) fn get_state<T>(&mut self, node_id: &NodeID) -> T
     where
-        T: Clone + Send + Sync + 'static,
+        T: Clone + Send + 'static,
     {
         // Fetch the state and decrement its number of required
         let state = self.map.remove(node_id).unwrap();
@@ -97,7 +97,7 @@ impl BackwardStates {
             .unwrap()
             .clone();

-        self.insert_state(node_id.clone(), new_stored_state);
+        self.insert_state(*node_id, new_stored_state);

         downcasted
     } else {
@@ -119,7 +119,7 @@ impl BackwardStates {

     pub(crate) fn save<T>(&mut self, node_id: NodeID, saved_output: T)
     where
-        T: Clone + Send + Sync + 'static,
+        T: Clone + Send + 'static,
     {
         let n_required = self.get_state_ref(&node_id).unwrap().n_required();
         self.insert_state(
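Dropping the `Sync` bound leaves checkpoint state stored as `Box<dyn Any + Send>` and recovered by downcast. A self-contained sketch of that type-erased storage pattern, with `u64` keys standing in for `NodeID` (illustrative only):

```rust
use std::any::Any;
use std::collections::HashMap;

// Type-erased node output, Send so it can be moved to the backward thread.
type StateContent = Box<dyn Any + Send>;

fn main() {
    let mut states: HashMap<u64, StateContent> = HashMap::new();

    // Heterogeneous outputs can live in the same map.
    states.insert(1, Box::new(vec![1.0f32, 2.0, 3.0]));
    states.insert(2, Box::new(42usize));

    // Retrieval takes ownership and downcasts back to the concrete type.
    let boxed = states.remove(&1).unwrap();
    let tensor_like: Vec<f32> = *boxed.downcast::<Vec<f32>>().unwrap();
    assert_eq!(tensor_like, vec![1.0, 2.0, 3.0]);
}
```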
@@ -3,6 +3,7 @@ use burn_tensor::{backend::Backend, container::TensorContainer, Tensor};
 use crate::{
     graph::{NodeRef, Requirement},
     tensor::AutodiffTensor,
+    NodeID,
 };

 /// Gradient identifier.
@@ -25,7 +26,7 @@ impl Gradients {
             container: TensorContainer::new(),
         };
         gradients.register::<B, D>(
-            root_node,
+            root_node.id,
             B::float_ones(B::float_shape(&root_tensor), &B::float_device(&root_tensor)),
         );
         gradients
@@ -76,15 +77,15 @@ impl Gradients {
     /// If the tensor already exists, add both tensors together before saving the result.
     pub fn register<B: Backend, const D: usize>(
         &mut self,
-        node: NodeRef,
+        node_id: NodeID,
         value: TensorPrimitive<B, D>,
     ) {
-        if let Some(tensor_old) = self.container.remove::<B, D>(&node.id.value) {
+        if let Some(tensor_old) = self.container.remove::<B, D>(&node_id.value) {
             self.container
-                .register(node.id.value, Tensor::from_primitive(value).add(tensor_old));
+                .register(node_id.value, Tensor::from_primitive(value).add(tensor_old));
         } else {
             self.container
-                .register::<B, D>(node.id.value, Tensor::from_primitive(value));
+                .register::<B, D>(node_id.value, Tensor::from_primitive(value));
         }
     }
 }
@@ -1,49 +0,0 @@
-use burn_tensor::backend::Backend;
-
-use crate::{checkpoint::base::Checkpointer, grads::Gradients, tensor::AutodiffTensor};
-
-use super::{traversal::BreadthFirstSearch, Graph, NodeRef, StepBoxed};
-
-pub fn backward<B: Backend, const D: usize>(root: AutodiffTensor<B, D>) -> Gradients {
-    let grads = Gradients::new::<B, D>(root.node.clone(), root.primitive);
-    let checkpointer = root.graph.build_checkpointer();
-    let tape = build_tape(root.node, root.graph);
-
-    execute_steps(tape, grads, checkpointer)
-}
-
-fn build_tape(root: NodeRef, graph: Graph) -> Vec<Vec<StepBoxed>> {
-    let mut tape = (0..root.order)
-        .map(|_| Vec::with_capacity(1))
-        .collect::<Vec<_>>();
-
-    BreadthFirstSearch.traverse(root, graph, |node, step| {
-        if node.order == 0 {
-            return;
-        }
-
-        if let Some(steps) = tape.get_mut(node.order - 1) {
-            steps.push(step)
-        };
-    });
-
-    tape
-}
-
-fn execute_steps(
-    tape: Vec<Vec<StepBoxed>>,
-    mut grads: Gradients,
-    mut checkpointer: Checkpointer,
-) -> Gradients {
-    tape.into_iter().rev().for_each(|steps| {
-        steps
-            .into_iter()
-            .for_each(|step| step.step(&mut grads, &mut checkpointer))
-    });
-
-    #[cfg(feature = "export_tests")]
-    // For checkpointing tests
-    assert!(checkpointer.is_empty());
-
-    grads
-}
@@ -1,153 +1,17 @@
-use spin::Mutex;
-use std::{collections::HashMap, sync::Arc};
-
-use crate::{
-    checkpoint::{base::Checkpointer, builder::CheckpointerBuilder},
-    grads::Gradients,
-};
-
-use super::{NodeID, NodeRef};
+use super::NodeID;
+use crate::{checkpoint::base::Checkpointer, grads::Gradients};
+use std::collections::HashMap;

 /// Backward step for reverse mode autodiff.
-pub trait Step: Send + Sync + std::fmt::Debug {
+pub trait Step: Send + std::fmt::Debug {
     /// Executes the step and consumes it.
     fn step(self: Box<Self>, grads: &mut Gradients, checkpointer: &mut Checkpointer);
     /// The node associated to the step.
-    fn node(&self) -> NodeRef;
+    fn node(&self) -> NodeID;
+    /// The parents of the node associated to the step.
+    fn parents(&self) -> Vec<NodeID>;
+    fn order(&self) -> usize;
 }

 pub type StepBoxed = Box<dyn Step>;
 pub type NodeSteps = HashMap<NodeID, StepBoxed>;
-
-/// Graph data structure.
-///
-/// The graph contains the [node steps](Step), which can be access by [node id](NodeID).
-#[derive(Default, Clone, Debug)]
-pub struct Graph {
-    steps: Arc<Mutex<NodeSteps>>,
-    checkpointing_actions: Arc<Mutex<CheckpointerBuilder>>,
-}
-
-impl Graph {
-    /// Create a new graph.
-    pub fn new() -> Self {
-        Self::default()
-    }
-
-    /// Get all the steps for the graph.
-    ///
-    /// # Notes
-    ///
-    /// This is a owned method, so the current graph will be freed. However, the steps can
-    /// be shared with other graphs, therefore they are going to be cleared.
-    ///
-    /// This is useful, since the graph is supposed to be consumed only once for backprop, and
-    /// keeping all the tensors alive for multiple backward call is a heavy waste of resources.
-    pub fn steps(self) -> NodeSteps {
-        let mut map_drain = HashMap::new();
-        self.execute_mut_steps(|map| {
-            std::mem::swap(&mut *map, &mut map_drain);
-        });
-        map_drain
-    }
-
-    /// # Notes
-    ///
-    /// This is a owned method, so the current checkpointing actions will be freed.
-    pub fn take_checkpointing_actions(self) -> CheckpointerBuilder {
-        let mut actions = CheckpointerBuilder::default();
-        self.execute_mut_checkpointing_actions(|checkpointing_actions| {
-            std::mem::swap(&mut *checkpointing_actions, &mut actions);
-        });
-        actions
-    }
-
-    /// Register a new step into the graph.
-    pub fn register(self, id: &NodeID, ops: StepBoxed) -> Self {
-        self.execute_mut_steps(|map| {
-            map.insert(id.clone(), ops);
-        })
-    }
-
-    /// Merge two graphs.
-    pub fn merge(self, other: Self) -> Self {
-        if Arc::ptr_eq(&self.steps, &other.steps) {
-            return self;
-        }
-
-        self.merge_different(other)
-    }
-
-    fn execute_mut_steps<F: FnOnce(&mut NodeSteps)>(mut self, func: F) -> Self {
-        match Arc::get_mut(&mut self.steps) {
-            Some(mutex) => {
-                let map = mutex.get_mut();
-                func(map);
-            }
-            None => {
-                // Only lock when there are multiple references to the graph.
-                let mut map = self.steps.lock();
-                func(&mut map);
-            }
-        };
-
-        self
-    }
-
-    fn execute_mut_checkpointing_actions<F: FnOnce(&mut CheckpointerBuilder)>(
-        mut self,
-        func: F,
-    ) -> Self {
-        match Arc::get_mut(&mut self.checkpointing_actions) {
-            Some(mutex) => {
-                let map = mutex.get_mut();
-                func(map);
-            }
-            None => {
-                // Only lock when there are multiple references to the graph.
-                let mut actions = self.checkpointing_actions.lock();
-                func(&mut actions);
-            }
-        };
-
-        self
-    }
-
-    fn merge_different(self, other: Self) -> Self {
-        let mut map2 = other.clone().steps();
-        let mut actions2 = other.take_checkpointing_actions();
-
-        self.execute_mut_steps(|map1| {
-            if map1.len() > map2.len() {
-                map1.extend(map2);
-            } else {
-                let mut map_drain = HashMap::new();
-                std::mem::swap(map1, &mut map_drain);
-                map2.extend(map_drain);
-                std::mem::swap(map1, &mut map2);
-            }
-        })
-        .execute_mut_checkpointing_actions(|actions1| {
-            if actions1.len() > actions2.len() {
-                actions1.extend(actions2);
-            } else {
-                let mut checkpointing_drain = CheckpointerBuilder::default();
-                std::mem::swap(actions1, &mut checkpointing_drain);
-                actions2.extend(checkpointing_drain);
-                std::mem::swap(actions1, &mut actions2);
-            }
-        })
-    }
-
-    pub(crate) fn build_checkpointer(&self) -> Checkpointer {
-        let mut guard = self.checkpointing_actions.lock();
-        let builder: CheckpointerBuilder = std::mem::take(&mut *guard);
-        builder.build(&self.steps.lock())
-    }
-
-    pub(crate) fn extend_checkpointer_builder(&self, checkpointing_actions: CheckpointerBuilder) {
-        self.checkpointing_actions
-            .lock()
-            .extend(checkpointing_actions);
-    }
-}
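The removed `Graph` methods above mutate shared state through an `Arc::get_mut` fast path, taking the mutex only when the graph is actually shared. A standalone sketch of that fast-path pattern, using `std::sync::Mutex` in place of `spin::Mutex` (illustrative only):

```rust
use std::sync::{Arc, Mutex};

// Mutate the shared list, locking only when the Arc is actually shared.
fn with_steps<F: FnOnce(&mut Vec<u64>)>(steps: &mut Arc<Mutex<Vec<u64>>>, func: F) {
    match Arc::get_mut(steps) {
        // Sole owner: no other clone exists, so the lock can be skipped entirely.
        Some(mutex) => func(mutex.get_mut().expect("lock poisoned")),
        // Shared: fall back to locking like any other writer would.
        None => {
            let mut guard = steps.lock().expect("lock poisoned");
            func(&mut *guard);
        }
    }
}

fn main() {
    let mut steps = Arc::new(Mutex::new(vec![1, 2]));
    with_steps(&mut steps, |s| s.push(3)); // sole owner: no locking needed

    let _shared = Arc::clone(&steps); // now there are two owners
    with_steps(&mut steps, |s| s.push(4)); // falls back to the lock

    assert_eq!(*steps.lock().unwrap(), vec![1, 2, 3, 4]);
}
```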
@@ -2,7 +2,6 @@ mod base;
 mod node;
 mod requirement;

-pub mod backward;
 pub mod traversal;

 pub use base::*;
@@ -2,6 +2,7 @@ use std::sync::atomic::{AtomicU64, Ordering};
 use std::sync::Arc;

 use crate::checkpoint::retro_forward::RetroForward;
+use crate::runtime::AutodiffClientImpl;

 use super::Requirement;

@@ -14,6 +15,15 @@ pub enum ComputingProperty {
     Ambiguous, // Maybe autotune someday
 }

+/// This is safe only because we only call RetroForward on the autodiff server.
+/// Therefore, the trait will never be used by multiple threads at the same time.
+///
+/// TODO: Find a way to avoid cloning the compute property, which will remove the need to add the
+/// Arc, which will make (dyn RetroForward) safely implement Send.
+unsafe impl Send for ComputingProperty {}
+/// unsafe Sync is required because Send is only implemented for Arc<Sync>, not Arc<Send>.
+unsafe impl Sync for ComputingProperty {}
+
 /// A node contains graph metadata and should be used wrapped in an Arc for cheap cloning.
 #[derive(new, Debug)]
 pub struct Node {
@@ -22,6 +32,7 @@ pub struct Node {
     pub id: NodeID,
     pub requirement: Requirement,
     pub properties: ComputingProperty,
+    pub client: AutodiffClientImpl,
 }
 pub type NodeRef = Arc<Node>;

@@ -36,7 +47,7 @@ impl Node {
 }

 /// Unique identifier generated for each node.
-#[derive(Clone, Hash, PartialEq, Eq, Debug)]
+#[derive(Clone, Hash, PartialEq, Eq, Debug, Copy)]
 pub struct NodeID {
     /// The integer representation of the id
     pub value: u64,
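Since `NodeID` now derives `Copy`, id values can be passed and stored without the many `.clone()` calls removed throughout this diff. A tiny illustrative sketch with a hypothetical `u64` newtype mirroring the new derive list:

```rust
use std::collections::HashMap;

// Hypothetical stand-in mirroring the new derives on NodeID.
#[derive(Clone, Hash, PartialEq, Eq, Debug, Copy)]
struct NodeId {
    value: u64,
}

fn main() {
    let id = NodeId { value: 7 };
    let mut required: HashMap<NodeId, usize> = HashMap::new();

    // `id` is copied into the map key; no `.clone()` needed, and `id` stays usable.
    required.insert(id, 1);
    assert_eq!(required.get(&id), Some(&1));
    println!("{:?} appears {} time(s)", id, required[&id]);
}
```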
@@ -1,29 +1,28 @@
-use std::collections::HashSet;
+use std::collections::{HashMap, HashSet};

-use super::{Graph, NodeRef, StepBoxed};
+use crate::NodeID;
+
+use super::StepBoxed;

 /// Breadth for search algorithm.
 pub struct BreadthFirstSearch;

 impl BreadthFirstSearch {
     /// Traverse the graph of backward steps from a root node.
-    pub fn traverse<F: FnMut(NodeRef, StepBoxed)>(
+    pub fn traverse<F: FnMut(NodeID, StepBoxed)>(
         &self,
-        root: NodeRef,
-        graph: Graph,
+        root_id: NodeID,
+        root_step: StepBoxed,
+        steps: &mut HashMap<NodeID, StepBoxed>,
         mut callback: F,
     ) {
-        let mut visited = HashSet::with_capacity(root.order);
-        let mut parents = Vec::with_capacity(root.order);
-        let mut steps = graph.steps();
-        let root_step = steps.remove(&root.id).expect(
-            "Root node should have a step registered, did you forget to call \
-             `Tensor::register_grad` on the tensor where you need gradients?",
-        );
+        let root_order = root_step.order();
+        let mut visited = HashSet::with_capacity(root_order);
+        let mut parents = Vec::with_capacity(root_order);

-        visited.insert(root.id.clone());
-        parents.append(&mut root.parents.clone());
-        callback(root, root_step);
+        visited.insert(root_id);
+        parents.append(&mut root_step.parents());
+        callback(root_id, root_step);

         while let Some(id) = parents.pop() {
             let step = match steps.remove(&id) {
@@ -31,21 +30,22 @@ impl BreadthFirstSearch {
                 None => continue,
             };

-            let node = step.node();
+            let step_node = step.node();
+            let step_parents = step.parents();

-            if visited.contains(&node.id) {
+            if visited.contains(&step_node) {
                 continue;
             }

-            visited.insert(node.id.clone());
+            visited.insert(step_node);

-            for id in node.parents.iter() {
+            for id in step_parents.iter() {
                 if !visited.contains(id) {
-                    parents.push(id.clone());
+                    parents.push(*id);
                 }
             }

-            callback(node, step);
+            callback(step_node, step);
         }
     }
 }
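After this change the traversal pulls parent ids from each step instead of dereferencing graph nodes. A self-contained sketch of the same visit-once walk over a plain id-to-parents map (stand-in types only, not the real `Step` objects):

```rust
use std::collections::{HashMap, HashSet};

// Visit `root` and every ancestor reachable through `parents`, once each.
fn traverse(root: u64, parents: &HashMap<u64, Vec<u64>>, mut callback: impl FnMut(u64)) {
    let mut visited = HashSet::new();
    let mut queue = parents.get(&root).cloned().unwrap_or_default();

    visited.insert(root);
    callback(root);

    while let Some(id) = queue.pop() {
        if !visited.insert(id) {
            continue; // already handled through another path
        }
        for parent in parents.get(&id).into_iter().flatten() {
            if !visited.contains(parent) {
                queue.push(*parent);
            }
        }
        callback(id);
    }
}

fn main() {
    // 4 <- 3 <- {1, 2}: a tiny DAG of backward dependencies.
    let parents = HashMap::from([(4, vec![3]), (3, vec![1, 2]), (1, vec![]), (2, vec![])]);
    let mut order = Vec::new();
    traverse(4, &parents, |id| order.push(id));
    assert_eq!(order.len(), 4); // every node visited exactly once
}
```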
@@ -28,6 +28,8 @@ pub(crate) mod utils;
 mod backend;
 mod bridge;

+pub(crate) mod runtime;
+
 pub use backend::*;
 pub use bridge::*;
@ -40,9 +40,9 @@ impl<B: Backend, C: CheckpointStrategy> ActivationOps<Autodiff<B, C>> for Autodi
|
|||
}
|
||||
|
||||
match Gelu::<D>
|
||||
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
|
||||
.prepare::<C>([tensor.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroGelu::<B, D>::new(tensor.node.id.clone()))
|
||||
.retro_forward(RetroGelu::<B, D>::new(tensor.node.id))
|
||||
.parents([&tensor])
|
||||
.stateful()
|
||||
{
|
||||
|
@ -77,9 +77,9 @@ impl<B: Backend, C: CheckpointStrategy> ActivationOps<Autodiff<B, C>> for Autodi
|
|||
}
|
||||
|
||||
match Relu
|
||||
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
|
||||
.prepare::<C>([tensor.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroRelu::<B, D>::new(tensor.node.id.clone()))
|
||||
.retro_forward(RetroRelu::<B, D>::new(tensor.node.id))
|
||||
.parents([&tensor])
|
||||
.stateful()
|
||||
{
|
||||
|
@ -115,9 +115,9 @@ impl<B: Backend, C: CheckpointStrategy> ActivationOps<Autodiff<B, C>> for Autodi
|
|||
}
|
||||
|
||||
match Sigmoid
|
||||
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
|
||||
.prepare::<C>([tensor.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroSigmoid::<B, D>::new(tensor.node.id.clone()))
|
||||
.retro_forward(RetroSigmoid::<B, D>::new(tensor.node.id))
|
||||
.parents([&tensor])
|
||||
.stateful()
|
||||
{
|
||||
|
@ -153,9 +153,9 @@ impl<B: Backend, C: CheckpointStrategy> ActivationOps<Autodiff<B, C>> for Autodi
|
|||
}
|
||||
|
||||
match LogSigmoid::<D>
|
||||
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
|
||||
.prepare::<C>([tensor.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroLogSigmoid::<B, D>::new(tensor.node.id.clone()))
|
||||
.retro_forward(RetroLogSigmoid::<B, D>::new(tensor.node.id))
|
||||
.parents([&tensor])
|
||||
.stateful()
|
||||
{
|
||||
|
|
|
@ -2,7 +2,7 @@ use super::{Ops, OpsPrep};
|
|||
use crate::{
|
||||
checkpoint::{base::Checkpointer, builder::CheckpointerBuilder, strategy::CheckpointStrategy},
|
||||
grads::Gradients,
|
||||
graph::{ComputingProperty, Graph, NodeRef, Requirement},
|
||||
graph::{ComputingProperty, NodeRef, Requirement},
|
||||
utils::duplicate,
|
||||
};
|
||||
use burn_tensor::backend::Backend;
|
||||
|
@ -14,13 +14,13 @@ use burn_tensor::backend::Backend;
|
|||
/// Concrete types implementing this trait should not have any state.
|
||||
/// If a state is necessary during the backward pass,
|
||||
/// they should be declared with the associated type 'State'.
|
||||
pub trait Backward<B, const D: usize, const N: usize>: Send + Sync + std::fmt::Debug
|
||||
pub trait Backward<B, const D: usize, const N: usize>: Send + std::fmt::Debug
|
||||
where
|
||||
Self: Sized + 'static,
|
||||
B: Backend,
|
||||
{
|
||||
/// Associated type to compute the backward pass.
|
||||
type State: Clone + Send + Sync + std::fmt::Debug + 'static;
|
||||
type State: Clone + Send + std::fmt::Debug + 'static;
|
||||
|
||||
/// The backward pass.
|
||||
fn backward(
|
||||
|
@ -34,12 +34,10 @@ where
|
|||
fn prepare<C: CheckpointStrategy>(
|
||||
self,
|
||||
nodes: [NodeRef; N],
|
||||
graphs: [Graph; N],
|
||||
) -> OpsPrep<Self, B, Self::State, C, D, N> {
|
||||
let requirement = Requirement::from_nodes(&nodes);
|
||||
OpsPrep::new(
|
||||
nodes,
|
||||
graphs,
|
||||
requirement,
|
||||
self,
|
||||
ComputingProperty::Ambiguous, // If not specified we start with ambiguous
|
||||
|
@ -65,12 +63,12 @@ pub fn binary<B, const D_OUT: usize, const D_LHS: usize, const D_RHS: usize, FLh
|
|||
|
||||
if let Some(node) = node_lhs {
|
||||
let grad = func_lhs(grad_4lhs.unwrap());
|
||||
grads.register::<B, D_LHS>(node, grad)
|
||||
grads.register::<B, D_LHS>(node.id, grad)
|
||||
}
|
||||
|
||||
if let Some(node) = node_rhs {
|
||||
let grad = func_rhs(grad_4rhs.unwrap());
|
||||
grads.register::<B, D_RHS>(node, grad)
|
||||
grads.register::<B, D_RHS>(node.id, grad)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -89,7 +87,7 @@ pub fn unary<B, const D_OUT: usize, const D_IN: usize, F>(
|
|||
|
||||
if let Some(node) = parent_node {
|
||||
let grad = func(grad);
|
||||
grads.register::<B, D_IN>(node, grad)
|
||||
grads.register::<B, D_IN>(node.id, grad)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -110,6 +108,6 @@ pub fn unary_different_backend<BIn, BOut, const D_OUT: usize, const D_IN: usize,
|
|||
|
||||
if let Some(node) = parent_node {
|
||||
let grad = func(grad);
|
||||
grads.register::<BIn, D_IN>(node, grad)
|
||||
grads.register::<BIn, D_IN>(node.id, grad)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,7 +7,7 @@ use crate::{
|
|||
strategy::CheckpointStrategy,
|
||||
},
|
||||
grads::Gradients,
|
||||
graph::{ComputingProperty, Graph, NodeID, NodeRef, Requirement, Step},
|
||||
graph::{ComputingProperty, NodeID, NodeRef, Requirement, Step},
|
||||
tensor::AutodiffTensor,
|
||||
};
|
||||
use burn_tensor::{backend::Backend, Shape};
|
||||
|
@ -19,7 +19,6 @@ use std::marker::PhantomData;
|
|||
#[derive(new)]
|
||||
pub struct OpsPrep<Backward, B, S, C, const D: usize, const N: usize, Mode = Init> {
|
||||
nodes: [NodeRef; N],
|
||||
graphs: [Graph; N],
|
||||
requirement: Requirement,
|
||||
backward: Backward,
|
||||
compute_property: ComputingProperty,
|
||||
|
@ -53,7 +52,6 @@ where
|
|||
pub fn compute_bound(self) -> OpsPrep<BO, B, S, C, D, N, ComputePropertyDone> {
|
||||
OpsPrep::new(
|
||||
self.nodes,
|
||||
self.graphs,
|
||||
self.requirement,
|
||||
self.backward,
|
||||
ComputingProperty::ComputeBound,
|
||||
|
@ -66,7 +64,6 @@ where
|
|||
pub fn memory_bound(self) -> OpsPrep<BO, B, S, C, D, N, MemoryBound> {
|
||||
OpsPrep::new(
|
||||
self.nodes,
|
||||
self.graphs,
|
||||
self.requirement,
|
||||
self.backward,
|
||||
self.compute_property,
|
||||
|
@ -88,7 +85,6 @@ where
|
|||
) -> OpsPrep<BO, B, S, C, D, N, MemoryBoundRetroForward> {
|
||||
OpsPrep::new(
|
||||
self.nodes,
|
||||
self.graphs,
|
||||
self.requirement,
|
||||
self.backward,
|
||||
C::compute_property(retro_forward),
|
||||
|
@ -117,7 +113,6 @@ where
|
|||
|
||||
OpsPrep::new(
|
||||
self.nodes,
|
||||
self.graphs,
|
||||
self.requirement,
|
||||
self.backward,
|
||||
self.compute_property,
|
||||
|
@ -146,7 +141,7 @@ where
|
|||
impl<BO, B, S, C, const D: usize, const N: usize> OpsPrep<BO, B, S, C, D, N, ComputePropertyDone>
|
||||
where
|
||||
B: Backend,
|
||||
S: Clone + Send + Sync + std::fmt::Debug + 'static,
|
||||
S: Clone + Send + std::fmt::Debug + 'static,
|
||||
BO: Backward<B, D, N, State = S>,
|
||||
{
|
||||
/// Prepare an operation that requires a state during the backward pass.
|
||||
|
@ -154,7 +149,6 @@ where
|
|||
match self.requirement.is_none() {
|
||||
false => OpsKind::Tracked(OpsPrep::new(
|
||||
self.nodes,
|
||||
self.graphs,
|
||||
self.requirement,
|
||||
self.backward,
|
||||
self.compute_property,
|
||||
|
@ -162,7 +156,6 @@ where
|
|||
)),
|
||||
true => OpsKind::UnTracked(OpsPrep::new(
|
||||
self.nodes,
|
||||
self.graphs,
|
||||
self.requirement,
|
||||
self.backward,
|
||||
self.compute_property,
|
||||
|
@ -175,7 +168,7 @@ where
|
|||
impl<BO, B, S, C, const D: usize, const N: usize> OpsPrep<BO, B, S, C, D, N, UnTracked>
|
||||
where
|
||||
B: Backend,
|
||||
S: Clone + Send + Sync + std::fmt::Debug + 'static,
|
||||
S: Clone + Send + std::fmt::Debug + 'static,
|
||||
BO: Backward<B, D, N, State = S>,
|
||||
{
|
||||
/// Finish the preparation of an untracked operation and returns the output tensor.
|
||||
|
@ -183,24 +176,22 @@ where
|
|||
let output = AutodiffTensor::from_parents(
|
||||
output,
|
||||
&self.nodes,
|
||||
self.graphs.into_iter(),
|
||||
self.requirement,
|
||||
self.compute_property,
|
||||
self.checkpointer_builder,
|
||||
);
|
||||
let parents = self.nodes.map(|node| node.clone_if_require_grad());
|
||||
let ops = Ops::new(parents, output.node.clone(), ());
|
||||
|
||||
// We register the ops in the graph even if untracked, otherwise memory bound operations
|
||||
// that have an untracked parent would not be able to retrieve it
|
||||
output.register_step(UntrackedOpsStep::new(ops))
|
||||
output.register_step(UntrackedOpsStep::new(ops), self.checkpointer_builder)
|
||||
}
|
||||
}
|
||||
|
||||
impl<BO, B, S, C, const D: usize, const N: usize> OpsPrep<BO, B, S, C, D, N, Tracked>
|
||||
where
|
||||
B: Backend,
|
||||
S: Clone + Send + Sync + std::fmt::Debug + 'static,
|
||||
S: Clone + Send + std::fmt::Debug + 'static,
|
||||
BO: Backward<B, D, N, State = S>,
|
||||
{
|
||||
/// Finish the preparation of a tracked operation and returns the output tensor.
|
||||
|
@ -212,15 +203,13 @@ where
|
|||
let output = AutodiffTensor::from_parents(
|
||||
output,
|
||||
&self.nodes,
|
||||
self.graphs.into_iter(),
|
||||
self.requirement,
|
||||
self.compute_property,
|
||||
self.checkpointer_builder,
|
||||
);
|
||||
let parents = self.nodes.map(|node| node.clone_if_require_grad());
|
||||
let ops = Ops::new(parents, output.node.clone(), state);
|
||||
|
||||
output.register_step(OpsStep::new(ops, self.backward))
|
||||
output.register_step(OpsStep::new(ops, self.backward), self.checkpointer_builder)
|
||||
}
|
||||
|
||||
/// Checkpoints the tensor
|
||||
|
@ -228,7 +217,7 @@ where
|
|||
self.checkpointer_builder
|
||||
.checkpoint(tensor, ActionType::Explicit);
|
||||
|
||||
tensor.node.id.clone()
|
||||
tensor.node.id
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -257,7 +246,7 @@ struct OpsStep<B, T, SB, const D: usize, const N: usize>
|
|||
where
|
||||
B: Backend,
|
||||
T: Backward<B, D, N, State = SB>,
|
||||
SB: Clone + Send + Sync + std::fmt::Debug + 'static,
|
||||
SB: Clone + Send + std::fmt::Debug + 'static,
|
||||
{
|
||||
ops: Ops<SB, N>,
|
||||
backward: T,
|
||||
|
@ -268,14 +257,22 @@ impl<B, T, SB, const D: usize, const N: usize> Step for OpsStep<B, T, SB, D, N>
|
|||
where
|
||||
B: Backend,
|
||||
T: Backward<B, D, N, State = SB>,
|
||||
SB: Clone + Send + Sync + std::fmt::Debug + 'static,
|
||||
SB: Clone + Send + std::fmt::Debug + 'static,
|
||||
{
|
||||
fn step(self: Box<Self>, grads: &mut Gradients, checkpointer: &mut Checkpointer) {
|
||||
self.backward.backward(self.ops, grads, checkpointer);
|
||||
}
|
||||
|
||||
fn node(&self) -> NodeRef {
|
||||
self.ops.node.clone()
|
||||
fn node(&self) -> NodeID {
|
||||
self.ops.node.id
|
||||
}
|
||||
|
||||
fn parents(&self) -> Vec<NodeID> {
|
||||
self.ops.node.parents.clone()
|
||||
}
|
||||
|
||||
fn order(&self) -> usize {
|
||||
self.ops.node.order
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -289,8 +286,15 @@ impl<const N: usize> Step for UntrackedOpsStep<N> {
|
|||
// Nothing to do
|
||||
}
|
||||
|
||||
fn node(&self) -> NodeRef {
|
||||
self.ops.node.clone()
|
||||
fn node(&self) -> NodeID {
|
||||
self.ops.node.id
|
||||
}
|
||||
|
||||
fn parents(&self) -> Vec<NodeID> {
|
||||
self.ops.node.parents.clone()
|
||||
}
|
||||
fn order(&self) -> usize {
|
||||
self.ops.node.order
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -34,7 +34,7 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
|
|||
}
|
||||
|
||||
match Embedding
|
||||
.prepare::<C>([weights.node], [weights.graph])
|
||||
.prepare::<C>([weights.node])
|
||||
.compute_bound()
|
||||
.stateful()
|
||||
{
|
||||
|
@ -85,13 +85,13 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
|
|||
let backward = B::conv2d_backward(x, weight, bias, grad, options);
|
||||
|
||||
if let Some(node) = node_x {
|
||||
grads.register::<B, 4>(node, backward.x_grad)
|
||||
grads.register::<B, 4>(node.id, backward.x_grad)
|
||||
}
|
||||
if let Some(node) = node_weight {
|
||||
grads.register::<B, 4>(node, backward.weights_grad)
|
||||
grads.register::<B, 4>(node.id, backward.weights_grad)
|
||||
}
|
||||
if let Some(node) = node_bias {
|
||||
grads.register::<B, 1>(node, backward.bias_grad.unwrap())
|
||||
grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -115,20 +115,17 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
|
|||
let backward = B::conv2d_backward(x, weight, None, grad, options);
|
||||
|
||||
if let Some(node) = node_x {
|
||||
grads.register::<B, 4>(node, backward.x_grad)
|
||||
grads.register::<B, 4>(node.id, backward.x_grad)
|
||||
}
|
||||
if let Some(node) = node_weight {
|
||||
grads.register::<B, 4>(node, backward.weights_grad)
|
||||
grads.register::<B, 4>(node.id, backward.weights_grad)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match bias {
|
||||
Some(bias) => match Conv2DWithBias
|
||||
.prepare::<C>(
|
||||
[x.node.clone(), weight.node.clone(), bias.node.clone()],
|
||||
[x.graph.clone(), weight.graph.clone(), bias.graph.clone()],
|
||||
)
|
||||
.prepare::<C>([x.node.clone(), weight.node.clone(), bias.node.clone()])
|
||||
.compute_bound()
|
||||
.stateful()
|
||||
{
|
||||
|
@ -149,10 +146,7 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
|
|||
)),
|
||||
},
|
||||
None => match Conv2DNoBias
|
||||
.prepare::<C>(
|
||||
[x.node.clone(), weight.node.clone()],
|
||||
[x.graph.clone(), weight.graph.clone()],
|
||||
)
|
||||
.prepare::<C>([x.node.clone(), weight.node.clone()])
|
||||
.compute_bound()
|
||||
.stateful()
|
||||
{
|
||||
|
@ -203,13 +197,13 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
|
|||
let backward = B::conv_transpose2d_backward(x, weight, bias, grad, options);
|
||||
|
||||
if let Some(node) = node_x {
|
||||
grads.register::<B, 4>(node, backward.x_grad)
|
||||
grads.register::<B, 4>(node.id, backward.x_grad)
|
||||
}
|
||||
if let Some(node) = node_weight {
|
||||
grads.register::<B, 4>(node, backward.weights_grad)
|
||||
grads.register::<B, 4>(node.id, backward.weights_grad)
|
||||
}
|
||||
if let Some(node) = node_bias {
|
||||
grads.register::<B, 1>(node, backward.bias_grad.unwrap())
|
||||
grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -233,20 +227,17 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
|
|||
let backward = B::conv_transpose2d_backward(x, weight, None, grad, options);
|
||||
|
||||
if let Some(node) = node_x {
|
||||
grads.register::<B, 4>(node, backward.x_grad)
|
||||
grads.register::<B, 4>(node.id, backward.x_grad)
|
||||
}
|
||||
if let Some(node) = node_weight {
|
||||
grads.register::<B, 4>(node, backward.weights_grad)
|
||||
grads.register::<B, 4>(node.id, backward.weights_grad)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match bias {
|
||||
Some(bias) => match ConvTranspose2DWithBias
|
||||
.prepare::<C>(
|
||||
[x.node.clone(), weight.node.clone(), bias.node.clone()],
|
||||
[x.graph.clone(), weight.graph.clone(), bias.graph.clone()],
|
||||
)
|
||||
.prepare::<C>([x.node.clone(), weight.node.clone(), bias.node.clone()])
|
||||
.compute_bound()
|
||||
.stateful()
|
||||
{
|
||||
|
@ -273,10 +264,7 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
|
|||
)),
|
||||
},
|
||||
None => match ConvTranspose2DNoBias
|
||||
.prepare::<C>(
|
||||
[x.node.clone(), weight.node.clone()],
|
||||
[x.graph.clone(), weight.graph.clone()],
|
||||
)
|
||||
.prepare::<C>([x.node.clone(), weight.node.clone()])
|
||||
.compute_bound()
|
||||
.stateful()
|
||||
{
|
||||
|
@ -330,13 +318,13 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
|
|||
let backward = B::conv1d_backward(x, weight, bias, grad, options);
|
||||
|
||||
if let Some(node) = node_x {
|
||||
grads.register::<B, 3>(node, backward.x_grad)
|
||||
grads.register::<B, 3>(node.id, backward.x_grad)
|
||||
}
|
||||
if let Some(node) = node_weight {
|
||||
grads.register::<B, 3>(node, backward.weights_grad)
|
||||
grads.register::<B, 3>(node.id, backward.weights_grad)
|
||||
}
|
||||
if let Some(node) = node_bias {
|
||||
grads.register::<B, 1>(node, backward.bias_grad.unwrap())
|
||||
grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -360,19 +348,16 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
|
|||
let backward = B::conv1d_backward(x, weight, None, grad, options);
|
||||
|
||||
if let Some(node) = node_x {
|
||||
grads.register::<B, 3>(node, backward.x_grad)
|
||||
grads.register::<B, 3>(node.id, backward.x_grad)
|
||||
}
|
||||
if let Some(node) = node_weight {
|
||||
grads.register::<B, 3>(node, backward.weights_grad)
|
||||
grads.register::<B, 3>(node.id, backward.weights_grad)
|
||||
}
|
||||
}
|
||||
}
|
||||
match bias {
|
||||
Some(bias) => match Conv1DWithBias
|
||||
.prepare::<C>(
|
||||
[x.node.clone(), weight.node.clone(), bias.node.clone()],
|
||||
[x.graph.clone(), weight.graph.clone(), bias.graph.clone()],
|
||||
)
|
||||
.prepare::<C>([x.node.clone(), weight.node.clone(), bias.node.clone()])
|
||||
.compute_bound()
|
||||
.stateful()
|
||||
{
|
||||
|
@ -393,10 +378,7 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
|
|||
)),
|
||||
},
|
||||
None => match Conv1DNoBias
|
||||
.prepare::<C>(
|
||||
[x.node.clone(), weight.node.clone()],
|
||||
[x.graph.clone(), weight.graph.clone()],
|
||||
)
|
||||
.prepare::<C>([x.node.clone(), weight.node.clone()])
|
||||
.compute_bound()
|
||||
.stateful()
|
||||
{
|
||||
|
@ -446,13 +428,13 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
|
|||
let backward = B::conv_transpose1d_backward(x, weight, bias, grad, options);
|
||||
|
||||
if let Some(node) = node_x {
|
||||
grads.register::<B, 3>(node, backward.x_grad)
|
||||
grads.register::<B, 3>(node.id, backward.x_grad)
|
||||
}
|
||||
if let Some(node) = node_weight {
|
||||
grads.register::<B, 3>(node, backward.weights_grad)
|
||||
grads.register::<B, 3>(node.id, backward.weights_grad)
|
||||
}
|
||||
if let Some(node) = node_bias {
|
||||
grads.register::<B, 1>(node, backward.bias_grad.unwrap())
|
||||
grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -476,20 +458,17 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
|
|||
let backward = B::conv_transpose1d_backward(x, weight, None, grad, options);
|
||||
|
||||
if let Some(node) = node_x {
|
||||
grads.register::<B, 3>(node, backward.x_grad)
|
||||
grads.register::<B, 3>(node.id, backward.x_grad)
|
||||
}
|
||||
if let Some(node) = node_weight {
|
||||
grads.register::<B, 3>(node, backward.weights_grad)
|
||||
grads.register::<B, 3>(node.id, backward.weights_grad)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match bias {
|
||||
Some(bias) => match ConvTranspose1DWithBias
|
||||
.prepare::<C>(
|
||||
[x.node.clone(), weight.node.clone(), bias.node.clone()],
|
||||
[x.graph.clone(), weight.graph.clone(), bias.graph.clone()],
|
||||
)
|
||||
.prepare::<C>([x.node.clone(), weight.node.clone(), bias.node.clone()])
|
||||
.compute_bound()
|
||||
.stateful()
|
||||
{
|
||||
|
@ -515,10 +494,7 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
|
|||
)),
|
||||
},
|
||||
None => match ConvTranspose1DNoBias
|
||||
.prepare::<C>(
|
||||
[x.node.clone(), weight.node.clone()],
|
||||
[x.graph.clone(), weight.graph.clone()],
|
||||
)
|
||||
.prepare::<C>([x.node.clone(), weight.node.clone()])
|
||||
.compute_bound()
|
||||
.stateful()
|
||||
{
|
||||
|
@ -588,13 +564,13 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
|
|||
padding,
|
||||
count_include_pad,
|
||||
);
|
||||
grads.register::<B, 3>(node, grad);
|
||||
grads.register::<B, 3>(node.id, grad);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match AvgPool1D
|
||||
.prepare::<C>([x.node.clone()], [x.graph.clone()])
|
||||
.prepare::<C>([x.node.clone()])
|
||||
.compute_bound()
|
||||
.stateful()
|
||||
{
|
||||
|
@ -654,13 +630,13 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
|
|||
padding,
|
||||
count_include_pad,
|
||||
);
|
||||
grads.register::<B, 4>(node, grad);
|
||||
grads.register::<B, 4>(node.id, grad);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match AvgPool2D
|
||||
.prepare::<C>([x.node.clone()], [x.graph.clone()])
|
||||
.prepare::<C>([x.node.clone()])
|
||||
.compute_bound()
|
||||
.stateful()
|
||||
{
|
||||
|
@ -706,7 +682,7 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
|
|||
dilation: usize,
|
||||
) -> AutodiffTensor<B, 3> {
|
||||
match MaxPool1D
|
||||
.prepare::<C>([x.node.clone()], [x.graph.clone()])
|
||||
.prepare::<C>([x.node.clone()])
|
||||
.compute_bound()
|
||||
.stateful()
|
||||
{
|
||||
|
@ -744,7 +720,7 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
|
|||
dilation: usize,
|
||||
) -> MaxPool1dWithIndices<Self> {
|
||||
match MaxPool1D
|
||||
.prepare::<C>([x.node.clone()], [x.graph.clone()])
|
||||
.prepare::<C>([x.node.clone()])
|
||||
.compute_bound()
|
||||
.stateful()
|
||||
{
|
||||
|
@ -806,7 +782,7 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
|
|||
dilation: [usize; 2],
|
||||
) -> AutodiffTensor<B, 4> {
|
||||
match MaxPool2D
|
||||
.prepare::<C>([x.node.clone()], [x.graph.clone()])
|
||||
.prepare::<C>([x.node.clone()])
|
||||
.compute_bound()
|
||||
.stateful()
|
||||
{
|
||||
|
@ -844,7 +820,7 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
|
|||
dilation: [usize; 2],
|
||||
) -> MaxPool2dWithIndices<Self> {
|
||||
match MaxPool2D
|
||||
.prepare::<C>([x.node.clone()], [x.graph.clone()])
|
||||
.prepare::<C>([x.node.clone()])
|
||||
.compute_bound()
|
||||
.stateful()
|
||||
{
|
||||
|
@ -908,13 +884,13 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
|
|||
|
||||
if let Some(node) = node_parent {
|
||||
let grad = B::adaptive_avg_pool1d_backward(state, grad);
|
||||
grads.register::<B, 3>(node, grad);
|
||||
grads.register::<B, 3>(node.id, grad);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match AdaptiveAvgPool1D
|
||||
.prepare::<C>([x.node.clone()], [x.graph.clone()])
|
||||
.prepare::<C>([x.node.clone()])
|
||||
.compute_bound()
|
||||
.stateful()
|
||||
{
|
||||
|
@ -950,13 +926,13 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
|
|||
|
||||
if let Some(node) = node_parent {
|
||||
let grad = B::adaptive_avg_pool2d_backward(state, grad);
|
||||
grads.register::<B, 4>(node, grad);
|
||||
grads.register::<B, 4>(node.id, grad);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match AdaptiveAvgPool2D
|
||||
.prepare::<C>([x.node.clone()], [x.graph.clone()])
|
||||
.prepare::<C>([x.node.clone()])
|
||||
.compute_bound()
|
||||
.stateful()
|
||||
{
|
||||
|
@ -1001,13 +977,13 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
|
|||
|
||||
if let Some(node) = node_parent {
|
||||
let grad = B::interpolate_backward(state, grad, output_size, options);
|
||||
grads.register::<B, 4>(node, grad);
|
||||
grads.register::<B, 4>(node.id, grad);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match Interpolate
|
||||
.prepare::<C>([x.node.clone()], [x.graph.clone()])
|
||||
.prepare::<C>([x.node.clone()])
|
||||
.compute_bound()
|
||||
.stateful()
|
||||
{
|
||||
|
@ -1060,7 +1036,7 @@ impl<B: Backend> Backward<B, 3, 1> for MaxPool1D {
|
|||
indices,
|
||||
);
|
||||
|
||||
grads.register::<B, 3>(node, grad.x_grad);
|
||||
grads.register::<B, 3>(node.id, grad.x_grad);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1100,7 +1076,7 @@ impl<B: Backend> Backward<B, 4, 1> for MaxPool2D {
|
|||
indices,
|
||||
);
|
||||
|
||||
grads.register::<B, 4>(node, grad.x_grad);
|
||||
grads.register::<B, 4>(node.id, grad.x_grad);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -90,7 +90,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
}
|
||||
|
||||
match ToDevice
|
||||
.prepare::<C>([tensor.node], [tensor.graph])
|
||||
.prepare::<C>([tensor.node])
|
||||
.compute_bound()
|
||||
.stateful()
|
||||
{
|
||||
|
@ -137,15 +137,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
}
|
||||
|
||||
match Add
|
||||
.prepare::<C>(
|
||||
[lhs.node.clone(), rhs.node.clone()],
|
||||
[lhs.graph.clone(), rhs.graph.clone()],
|
||||
)
|
||||
.prepare::<C>([lhs.node.clone(), rhs.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroAdd::<B, D>::new(
|
||||
lhs.node.id.clone(),
|
||||
rhs.node.id.clone(),
|
||||
))
|
||||
.retro_forward(RetroAdd::<B, D>::new(lhs.node.id, rhs.node.id))
|
||||
.parents([&lhs, &rhs])
|
||||
.stateful()
|
||||
{
|
||||
|
@ -183,9 +177,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
}
|
||||
|
||||
AddScalar
|
||||
.prepare::<C>([lhs.node.clone()], [lhs.graph.clone()])
|
||||
.prepare::<C>([lhs.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroAddScalar::<B, D>::new(lhs.node.id.clone(), rhs))
|
||||
.retro_forward(RetroAddScalar::<B, D>::new(lhs.node.id, rhs))
|
||||
.parents([&lhs])
|
||||
.stateless(B::float_add_scalar(lhs.primitive, rhs))
|
||||
}
|
||||
|
@ -221,15 +215,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
}
|
||||
|
||||
match Sub
|
||||
.prepare::<C>(
|
||||
[lhs.node.clone(), rhs.node.clone()],
|
||||
[lhs.graph.clone(), rhs.graph.clone()],
|
||||
)
|
||||
.prepare::<C>([lhs.node.clone(), rhs.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroSub::<B, D>::new(
|
||||
lhs.node.id.clone(),
|
||||
rhs.node.id.clone(),
|
||||
))
|
||||
.retro_forward(RetroSub::<B, D>::new(lhs.node.id, rhs.node.id))
|
||||
.parents([&lhs, &rhs])
|
||||
.stateful()
|
||||
{
|
||||
|
@ -267,9 +255,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
}
|
||||
|
||||
SubScalar
|
||||
.prepare::<C>([lhs.node.clone()], [lhs.graph.clone()])
|
||||
.prepare::<C>([lhs.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroSubScalar::<B, D>::new(lhs.node.id.clone(), rhs))
|
||||
.retro_forward(RetroSubScalar::<B, D>::new(lhs.node.id, rhs))
|
||||
.parents([&lhs])
|
||||
.stateless(B::float_sub_scalar(lhs.primitive, rhs))
|
||||
}
|
||||
|
@ -317,15 +305,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
let broadcast = BinaryOpsBroadcast::new::<B>(&lhs.primitive, &rhs.primitive);
|
||||
|
||||
match Mul
|
||||
.prepare::<C>(
|
||||
[lhs.node.clone(), rhs.node.clone()],
|
||||
[lhs.graph.clone(), rhs.graph.clone()],
|
||||
)
|
||||
.prepare::<C>([lhs.node.clone(), rhs.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroMul::<B, D>::new(
|
||||
lhs.node.id.clone(),
|
||||
rhs.node.id.clone(),
|
||||
))
|
||||
.retro_forward(RetroMul::<B, D>::new(lhs.node.id, rhs.node.id))
|
||||
.parents([&lhs, &rhs])
|
||||
.stateful()
|
||||
{
|
||||
|
@ -367,9 +349,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
}
|
||||
|
||||
match MulScalar
|
||||
.prepare::<C>([lhs.node.clone()], [lhs.graph.clone()])
|
||||
.prepare::<C>([lhs.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroMulScalar::<B, D>::new(lhs.node.id.clone(), rhs))
|
||||
.retro_forward(RetroMulScalar::<B, D>::new(lhs.node.id, rhs))
|
||||
.parents([&lhs])
|
||||
.stateful()
|
||||
{
|
||||
|
@ -429,15 +411,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
let broadcast = BinaryOpsBroadcast::new::<B>(&lhs.primitive, &rhs.primitive);
|
||||
|
||||
match Div
|
||||
.prepare::<C>(
|
||||
[lhs.node.clone(), rhs.node.clone()],
|
||||
[lhs.graph.clone(), rhs.graph.clone()],
|
||||
)
|
||||
.prepare::<C>([lhs.node.clone(), rhs.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroDiv::<B, D>::new(
|
||||
lhs.node.id.clone(),
|
||||
rhs.node.id.clone(),
|
||||
))
|
||||
.retro_forward(RetroDiv::<B, D>::new(lhs.node.id, rhs.node.id))
|
||||
.parents([&lhs, &rhs])
|
||||
.stateful()
|
||||
{
|
||||
|
@ -480,9 +456,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
}
|
||||
|
||||
match DivScalar
|
||||
.prepare::<C>([lhs.node.clone()], [lhs.graph.clone()])
|
||||
.prepare::<C>([lhs.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroDivScalar::<B, D>::new(lhs.node.id.clone(), rhs))
|
||||
.retro_forward(RetroDivScalar::<B, D>::new(lhs.node.id, rhs))
|
||||
.parents([&lhs])
|
||||
.stateful()
|
||||
{
|
||||
|
@ -536,10 +512,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
let broadcast = BinaryOpsBroadcast::new::<B>(&lhs.primitive, &rhs.primitive);
|
||||
|
||||
match Matmul
|
||||
.prepare::<C>(
|
||||
[lhs.node.clone(), rhs.node.clone()],
|
||||
[lhs.graph.clone(), rhs.graph.clone()],
|
||||
)
|
||||
.prepare::<C>([lhs.node.clone(), rhs.node.clone()])
|
||||
.compute_bound()
|
||||
.stateful()
|
||||
{
|
||||
|
@ -574,9 +547,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
}
|
||||
}
|
||||
|
||||
Neg.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
|
||||
Neg.prepare::<C>([tensor.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroNeg::<B, D>::new(tensor.node.id.clone()))
|
||||
.retro_forward(RetroNeg::<B, D>::new(tensor.node.id))
|
||||
.parents([&tensor])
|
||||
.stateless(B::float_neg(tensor.primitive))
|
||||
}
|
||||
|
@ -607,9 +580,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
}
|
||||
|
||||
match Recip
|
||||
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
|
||||
.prepare::<C>([tensor.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroRecip::<B, D>::new(tensor.node.id.clone()))
|
||||
.retro_forward(RetroRecip::<B, D>::new(tensor.node.id))
|
||||
.parents([&tensor])
|
||||
.stateful()
|
||||
{
|
||||
|
@ -663,13 +636,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
}
|
||||
|
||||
match SwapDim
|
||||
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
|
||||
.prepare::<C>([tensor.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroSwapDims::<B, D>::new(
|
||||
tensor.node.id.clone(),
|
||||
dim1,
|
||||
dim2,
|
||||
))
|
||||
.retro_forward(RetroSwapDims::<B, D>::new(tensor.node.id, dim1, dim2))
|
||||
.parents([&tensor])
|
||||
.stateful()
|
||||
{
|
||||
|
@ -728,9 +697,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
}
|
||||
|
||||
match PermuteDim
|
||||
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
|
||||
.prepare::<C>([tensor.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroPermuteDims::<B, D>::new(tensor.node.id.clone(), axes))
|
||||
.retro_forward(RetroPermuteDims::<B, D>::new(tensor.node.id, axes))
|
||||
.parents([&tensor])
|
||||
.stateful()
|
||||
{
|
||||
|
@ -779,12 +748,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
}
|
||||
|
||||
match FlipDim
|
||||
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
|
||||
.prepare::<C>([tensor.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroFlipDims::<B, D>::new(
|
||||
tensor.node.id.clone(),
|
||||
axes.to_vec(),
|
||||
))
|
||||
.retro_forward(RetroFlipDims::<B, D>::new(tensor.node.id, axes.to_vec()))
|
||||
.parents([&tensor])
|
||||
.stateful()
|
||||
{
|
||||
|
@ -844,10 +810,10 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
}
|
||||
|
||||
match ReshapeDim
|
||||
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
|
||||
.prepare::<C>([tensor.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroReshape::<B, D1, D2>::new(
|
||||
tensor.node.id.clone(),
|
||||
tensor.node.id,
|
||||
shape.clone(),
|
||||
))
|
||||
.parents([&tensor])
|
||||
|
@ -888,7 +854,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
}
|
||||
|
||||
match Gather
|
||||
.prepare::<C>([tensor.node], [tensor.graph])
|
||||
.prepare::<C>([tensor.node])
|
||||
.compute_bound()
|
||||
.stateful()
|
||||
{
|
||||
|
@ -945,7 +911,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
}
|
||||
|
||||
match Scatter
|
||||
.prepare::<C>([tensor.node, value.node], [tensor.graph, value.graph])
|
||||
.prepare::<C>([tensor.node, value.node])
|
||||
.compute_bound()
|
||||
.stateful()
|
||||
{
|
||||
|
@ -1010,10 +976,10 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
}
|
||||
|
||||
match Select
|
||||
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
|
||||
.prepare::<C>([tensor.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroSelect::<B, D>::new(
|
||||
tensor.node.id.clone(),
|
||||
tensor.node.id,
|
||||
dim,
|
||||
indices.clone(),
|
||||
))
|
||||
|
@ -1090,16 +1056,13 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
}
|
||||
|
||||
match IndexSelectDimAssign::<D>
|
||||
.prepare::<C>(
|
||||
[tensor.node.clone(), value.node.clone()],
|
||||
[tensor.graph.clone(), value.graph.clone()],
|
||||
)
|
||||
.prepare::<C>([tensor.node.clone(), value.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroSelectAssign::<B, D>::new(
|
||||
tensor.node.id.clone(),
|
||||
tensor.node.id,
|
||||
dim,
|
||||
indices.clone(),
|
||||
value.node.id.clone(),
|
||||
value.node.id,
|
||||
))
|
||||
.parents([&tensor, &value])
|
||||
.stateful()
|
||||
|
@ -1164,12 +1127,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
}
|
||||
|
||||
match Index
|
||||
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
|
||||
.prepare::<C>([tensor.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroSlice::<B, D1, D2>::new(
|
||||
tensor.node.id.clone(),
|
||||
ranges.clone(),
|
||||
))
|
||||
.retro_forward(RetroSlice::<B, D1, D2>::new(tensor.node.id, ranges.clone()))
|
||||
.parents([&tensor])
|
||||
.stateful()
|
||||
{
|
||||
|
@ -1236,15 +1196,12 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
}
|
||||
|
||||
match SliceAssign
|
||||
.prepare::<C>(
|
||||
[tensor.node.clone(), value.node.clone()],
|
||||
[tensor.graph.clone(), value.graph.clone()],
|
||||
)
|
||||
.prepare::<C>([tensor.node.clone(), value.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroSliceAssign::<B, D1, D2>::new(
|
||||
tensor.node.id.clone(),
|
||||
tensor.node.id,
|
||||
ranges.clone(),
|
||||
value.node.id.clone(),
|
||||
value.node.id,
|
||||
))
|
||||
.parents([&tensor, &value])
|
||||
.stateful()
|
||||
|
@ -1306,7 +1263,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
}
|
||||
|
||||
match MaskWhere
|
||||
.prepare::<C>([tensor.node, source.node], [tensor.graph, source.graph])
|
||||
.prepare::<C>([tensor.node, source.node])
|
||||
.compute_bound()
|
||||
.stateful()
|
||||
{
|
||||
|
@ -1351,7 +1308,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
}
|
||||
|
||||
match MaskFill
|
||||
.prepare::<C>([tensor.node], [tensor.graph])
|
||||
.prepare::<C>([tensor.node])
|
||||
.compute_bound()
|
||||
.stateful()
|
||||
{
|
||||
|
@ -1489,11 +1446,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
}
|
||||
}
|
||||
|
||||
match Mean
|
||||
.prepare::<C>([tensor.node], [tensor.graph])
|
||||
.compute_bound()
|
||||
.stateful()
|
||||
{
|
||||
match Mean.prepare::<C>([tensor.node]).compute_bound().stateful() {
|
||||
OpsKind::Tracked(prep) => prep.finish(
|
||||
B::float_shape(&tensor.primitive),
|
||||
B::float_mean(tensor.primitive),
|
||||
|
@ -1526,11 +1479,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
}
|
||||
}
|
||||
|
||||
match Sum
|
||||
.prepare::<C>([tensor.node], [tensor.graph])
|
||||
.compute_bound()
|
||||
.stateful()
|
||||
{
|
||||
match Sum.prepare::<C>([tensor.node]).compute_bound().stateful() {
|
||||
OpsKind::Tracked(prep) => prep.finish(
|
||||
B::float_shape(&tensor.primitive),
|
||||
B::float_sum(tensor.primitive),
|
||||
|
@ -1569,7 +1518,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
}
|
||||
|
||||
match MeanDim
|
||||
.prepare::<C>([tensor.node], [tensor.graph])
|
||||
.prepare::<C>([tensor.node])
|
||||
.compute_bound()
|
||||
.stateful()
|
||||
{
|
||||
|
@ -1609,7 +1558,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
}
|
||||
|
||||
match SumDim
|
||||
.prepare::<C>([tensor.node], [tensor.graph])
|
||||
.prepare::<C>([tensor.node])
|
||||
.compute_bound()
|
||||
.stateful()
|
||||
{
|
||||
|
@ -1653,9 +1602,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
}
|
||||
|
||||
match Exp
|
||||
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
|
||||
.prepare::<C>([tensor.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroExp::<B, D>::new(tensor.node.id.clone()))
|
||||
.retro_forward(RetroExp::<B, D>::new(tensor.node.id))
|
||||
.parents([&tensor])
|
||||
.stateful()
|
||||
{
|
||||
|
@ -1691,9 +1640,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
}
|
||||
|
||||
match Log
|
||||
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
|
||||
.prepare::<C>([tensor.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroLog::<B, D>::new(tensor.node.id.clone()))
|
||||
.retro_forward(RetroLog::<B, D>::new(tensor.node.id))
|
||||
.parents([&tensor])
|
||||
.stateful()
|
||||
{
|
||||
|
@ -1731,9 +1680,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
}
|
||||
|
||||
match Log1P
|
||||
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
|
||||
.prepare::<C>([tensor.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroLog1P::<B, D>::new(tensor.node.id.clone()))
|
||||
.retro_forward(RetroLog1P::<B, D>::new(tensor.node.id))
|
||||
.parents([&tensor])
|
||||
.stateful()
|
||||
{
|
||||
|
@ -1789,9 +1738,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
}
|
||||
|
||||
match PowfScalar
|
||||
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
|
||||
.prepare::<C>([tensor.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroPowfScalar::<B, D>::new(tensor.node.id.clone(), value))
|
||||
.retro_forward(RetroPowfScalar::<B, D>::new(tensor.node.id, value))
|
||||
.parents([&tensor])
|
||||
.stateful()
|
||||
{
|
||||
|
@ -1828,9 +1777,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
}
|
||||
|
||||
match Sqrt
|
||||
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
|
||||
.prepare::<C>([tensor.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroSqrt::<B, D>::new(tensor.node.id.clone()))
|
||||
.retro_forward(RetroSqrt::<B, D>::new(tensor.node.id))
|
||||
.parents([&tensor])
|
||||
.stateful()
|
||||
{
|
||||
|
@ -1868,9 +1817,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
}
|
||||
|
||||
match Abs
|
||||
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
|
||||
.prepare::<C>([tensor.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroAbs::<B, D>::new(tensor.node.id.clone()))
|
||||
.retro_forward(RetroAbs::<B, D>::new(tensor.node.id))
|
||||
.parents([&tensor])
|
||||
.stateful()
|
||||
{
|
||||
|
@ -1907,9 +1856,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
}
|
||||
|
||||
match Cos
|
||||
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
|
||||
.prepare::<C>([tensor.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroCos::<B, D>::new(tensor.node.id.clone()))
|
||||
.retro_forward(RetroCos::<B, D>::new(tensor.node.id))
|
||||
.parents([&tensor])
|
||||
.stateful()
|
||||
{
|
||||
|
@ -1945,9 +1894,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
}
|
||||
|
||||
match Sin
|
||||
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
|
||||
.prepare::<C>([tensor.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroSin::<B, D>::new(tensor.node.id.clone()))
|
||||
.retro_forward(RetroSin::<B, D>::new(tensor.node.id))
|
||||
.parents([&tensor])
|
||||
.stateful()
|
||||
{
|
||||
|
@ -1987,9 +1936,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
}
|
||||
|
||||
match Tanh
|
||||
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
|
||||
.prepare::<C>([tensor.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroTanh::<B, D>::new(tensor.node.id.clone()))
|
||||
.retro_forward(RetroTanh::<B, D>::new(tensor.node.id))
|
||||
.parents([&tensor])
|
||||
.stateful()
|
||||
{
|
||||
|
@ -2029,9 +1978,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
}
|
||||
|
||||
match Erf
|
||||
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
|
||||
.prepare::<C>([tensor.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroErf::<B, D>::new(tensor.node.id.clone()))
|
||||
.retro_forward(RetroErf::<B, D>::new(tensor.node.id))
|
||||
.parents([&tensor])
|
||||
.stateful()
|
||||
{
|
||||
|
@ -2074,17 +2023,27 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
let mut ranges = ranges.clone();
|
||||
ranges[self.dim] = current_index..dim_size + current_index;
|
||||
current_index += dim_size;
|
||||
grads.register::<B, D>(node, B::float_slice(grad.clone(), ranges));
|
||||
grads.register::<B, D>(node.id, B::float_slice(grad.clone(), ranges));
|
||||
});
|
||||
}
|
||||
|
||||
fn node(&self) -> NodeRef {
|
||||
self.output.clone()
|
||||
fn node(&self) -> NodeID {
|
||||
self.output.id
|
||||
}
|
||||
|
||||
fn parents(&self) -> Vec<NodeID> {
|
||||
self.nodes
|
||||
.iter()
|
||||
.filter_map(|node| node.clone())
|
||||
.map(|node| node.id)
|
||||
.collect()
|
||||
}
|
||||
fn order(&self) -> usize {
|
||||
self.output.order
|
||||
}
|
||||
}
|
||||
|
||||
let mut nodes = Vec::with_capacity(tensors.len());
|
||||
let mut graphs = Vec::with_capacity(tensors.len());
|
||||
let mut primitives = Vec::with_capacity(tensors.len());
|
||||
let mut dim_sizes = Vec::with_capacity(tensors.len());
|
||||
|
||||
|
@ -2092,7 +2051,6 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
dim_sizes.push(B::float_shape(&tensor.primitive).dims[dim]);
|
||||
nodes.push(tensor.node);
|
||||
primitives.push(tensor.primitive);
|
||||
graphs.push(tensor.graph);
|
||||
});
|
||||
|
||||
let requirement = Requirement::from_nodes(&nodes);
|
||||
|
@ -2106,28 +2064,20 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
return AutodiffTensor::from_parents(
|
||||
output,
|
||||
&nodes,
|
||||
graphs.into_iter(),
|
||||
requirement,
|
||||
cat_computing_property,
|
||||
checkpointer_builder,
|
||||
);
|
||||
}
|
||||
|
||||
let output = AutodiffTensor::from_parents(
|
||||
output,
|
||||
&nodes,
|
||||
graphs.into_iter(),
|
||||
requirement,
|
||||
cat_computing_property,
|
||||
checkpointer_builder,
|
||||
);
|
||||
let output =
|
||||
AutodiffTensor::from_parents(output, &nodes, requirement, cat_computing_property);
|
||||
let nodes = nodes
|
||||
.into_iter()
|
||||
.map(|node| node.clone_if_require_grad())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let ops = CatStep::<B, D>::new(nodes, dim_sizes, output.node.clone(), dim);
|
||||
output.register_step(ops)
|
||||
output.register_step(ops, checkpointer_builder)
|
||||
}
|
||||
|
||||
fn float_max_dim<const D: usize>(
|
||||
|
@ -2135,7 +2085,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
dim: usize,
|
||||
) -> FloatTensor<Self, D> {
|
||||
match MaxMinDim
|
||||
.prepare::<C>([tensor.node], [tensor.graph])
|
||||
.prepare::<C>([tensor.node])
|
||||
.compute_bound()
|
||||
.stateful()
|
||||
{
|
||||
|
@ -2152,7 +2102,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
dim: usize,
|
||||
) -> (FloatTensor<Self, D>, IntTensor<B, D>) {
|
||||
match MaxMinDim
|
||||
.prepare::<C>([tensor.node], [tensor.graph])
|
||||
.prepare::<C>([tensor.node])
|
||||
.compute_bound()
|
||||
.stateful()
|
||||
{
|
||||
|
@ -2176,7 +2126,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
dim: usize,
|
||||
) -> FloatTensor<Self, D> {
|
||||
match MaxMinDim
|
||||
.prepare::<C>([tensor.node], [tensor.graph])
|
||||
.prepare::<C>([tensor.node])
|
||||
.compute_bound()
|
||||
.stateful()
|
||||
{
|
||||
|
@ -2193,7 +2143,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
dim: usize,
|
||||
) -> (FloatTensor<Self, D>, IntTensor<B, D>) {
|
||||
match MaxMinDim
|
||||
.prepare::<C>([tensor.node], [tensor.graph])
|
||||
.prepare::<C>([tensor.node])
|
||||
.compute_bound()
|
||||
.stateful()
|
||||
{
|
||||
|
@ -2283,15 +2233,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
let broadcast = BinaryOpsBroadcast::new::<B>(&lhs.primitive, &rhs.primitive);
|
||||
|
||||
match PowF
|
||||
.prepare::<C>(
|
||||
[lhs.node.clone(), rhs.node.clone()],
|
||||
[lhs.graph.clone(), rhs.graph.clone()],
|
||||
)
|
||||
.prepare::<C>([lhs.node.clone(), rhs.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroPowf::<B, D>::new(
|
||||
lhs.node.id.clone(),
|
||||
rhs.node.id.clone(),
|
||||
))
|
||||
.retro_forward(RetroPowf::<B, D>::new(lhs.node.id, rhs.node.id))
|
||||
.parents([&lhs, &rhs])
|
||||
.stateful()
|
||||
{
|
||||
|
@ -2329,9 +2273,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
}
|
||||
}
|
||||
|
||||
Sign.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
|
||||
Sign.prepare::<C>([tensor.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroSign::<B, D>::new(tensor.node.id.clone()))
|
||||
.retro_forward(RetroSign::<B, D>::new(tensor.node.id))
|
||||
.parents([&tensor])
|
||||
.stateless(B::float_sign(tensor.primitive))
|
||||
}
|
||||
|
@ -2394,12 +2338,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
}
|
||||
|
||||
match ExpandDim
|
||||
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
|
||||
.prepare::<C>([tensor.node.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroExpand::<B, D1, D2>::new(
|
||||
tensor.node.id.clone(),
|
||||
shape.clone(),
|
||||
))
|
||||
.retro_forward(RetroExpand::<B, D1, D2>::new(tensor.node.id, shape.clone()))
|
||||
.parents([&tensor])
|
||||
.stateful()
|
||||
{
|
||||
|
@ -2417,7 +2358,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
descending: bool,
|
||||
) -> FloatTensor<Self, D> {
|
||||
match SortDim
|
||||
.prepare::<C>([tensor.node], [tensor.graph])
|
||||
.prepare::<C>([tensor.node])
|
||||
.compute_bound()
|
||||
.stateful()
|
||||
{
|
||||
|
@ -2439,7 +2380,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
descending: bool,
|
||||
) -> (FloatTensor<Self, D>, IntTensor<B, D>) {
|
||||
match SortDim
|
||||
.prepare::<C>([tensor.node], [tensor.graph])
|
||||
.prepare::<C>([tensor.node])
|
||||
.compute_bound()
|
||||
.stateful()
|
||||
{
|
||||
|
|
|
@ -0,0 +1,23 @@
use crate::{
    checkpoint::builder::CheckpointerBuilder,
    grads::Gradients,
    graph::StepBoxed,
    tensor::{AutodiffTensor, NodeRefCount},
};
use burn_tensor::backend::Backend;

/// Client used to communicate with the autodiff server.
pub trait AutodiffClient: Send + Clone {
    /// Register a new step.
    fn register(&self, node_id: NodeRefCount, step: StepBoxed, actions: CheckpointerBuilder);
    /// Call backpropagation from the given tensor.
    fn backward<B: Backend, const D: usize>(&self, tensor: AutodiffTensor<B, D>) -> Gradients;
}

/// Client implementation in use.
#[cfg(feature = "std")]
pub type AutodiffClientImpl = super::mspc::ChannelClient;

/// Client implementation in use.
#[cfg(not(feature = "std"))]
pub type AutodiffClientImpl = super::mutex::MutexClient;
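To make the trait's contract concrete, here is a minimal sketch (not part of this commit) of an in-place client that owns its own server. The name `LocalClient` and its field layout are assumptions; the method signatures mirror `AutodiffClient` and the `AutodiffServer` API added further down, and the sketch assumes it lives next to the real clients inside the `runtime` module and that `AutodiffServer` is `Send` (which the channel-based client relies on as well).

use super::{server::AutodiffServer, AutodiffClient};
use crate::{
    checkpoint::builder::CheckpointerBuilder,
    grads::Gradients,
    graph::StepBoxed,
    tensor::{AutodiffTensor, NodeRefCount},
};
use burn_tensor::backend::Backend;
use std::sync::Arc;

// Hypothetical client: every call locks a shared server and runs it inline.
#[derive(Clone, Default)]
pub struct LocalClient {
    server: Arc<spin::Mutex<AutodiffServer>>,
}

impl AutodiffClient for LocalClient {
    fn register(&self, node_id: NodeRefCount, step: StepBoxed, actions: CheckpointerBuilder) {
        self.server.lock().register(node_id, step, actions);
    }

    fn backward<B: Backend, const D: usize>(&self, tensor: AutodiffTensor<B, D>) -> Gradients {
        let node_id = tensor.node.id;
        let grads = Gradients::new::<B, D>(tensor.node, tensor.primitive);
        self.server.lock().backward(grads, node_id)
    }
}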
|
@ -0,0 +1,271 @@
|
|||
use crate::{tensor::NodeRefCount, NodeID};
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
/// Keeps track of the graphs created during autodiff, along with the reference count of each node.
|
||||
///
|
||||
/// When all nodes in a graph have only one reference, the graph can be freed.
|
||||
#[derive(Default, Debug)]
|
||||
pub struct GraphMemoryManagement {
|
||||
graphs: HashMap<GraphId, GraphState>,
|
||||
owned: HashSet<GraphId>,
|
||||
}
|
||||
|
||||
#[derive(new, Hash, PartialEq, Eq, Clone, Copy, Debug)]
|
||||
pub struct GraphId {
|
||||
node: NodeID,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum GraphState {
|
||||
Merged(GraphId),
|
||||
Owned(Vec<NodeRefCount>),
|
||||
}
|
||||
|
||||
impl GraphMemoryManagement {
|
||||
/// Register a new node with its parent.
|
||||
pub fn register(&mut self, node: NodeRefCount, parents: Vec<NodeID>) {
|
||||
let node_id = *node.as_ref();
|
||||
let graph_id = GraphId::new(node_id);
|
||||
|
||||
self.insert_owned_graph(graph_id, vec![node.clone()]);
|
||||
|
||||
if !parents.is_empty() {
|
||||
let graph_ids = parents.into_iter().map(GraphId::new);
|
||||
if let Some(parent_graph_id) = self.merge_graph(graph_ids) {
|
||||
self.merge_graph([graph_id, parent_graph_id]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Free the given graph, calling the given function for each deleted node.
|
||||
pub fn free_graph<F>(&mut self, graph_id: GraphId, mut func: F)
|
||||
where
|
||||
F: FnMut(&NodeID),
|
||||
{
|
||||
self.owned.remove(&graph_id);
|
||||
let graph = match self.graphs.remove(&graph_id) {
|
||||
Some(graph) => graph,
|
||||
None => return,
|
||||
};
|
||||
|
||||
let graph = match graph {
|
||||
GraphState::Merged(graph) => {
|
||||
self.free_graph(graph, func);
|
||||
return;
|
||||
}
|
||||
GraphState::Owned(graph) => graph,
|
||||
};
|
||||
|
||||
for node_id in graph.into_iter() {
|
||||
func(&node_id);
|
||||
}
|
||||
}
|
||||
|
||||
/// Find the graphs where all nodes are orphan.
|
||||
///
|
||||
/// The returned graphs can be safely freed.
|
||||
pub fn find_orphan_graphs(&self) -> Vec<GraphId> {
|
||||
self.owned
|
||||
.iter()
|
||||
.filter(|id| self.is_orphan(id))
|
||||
.copied()
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn is_orphan(&self, id: &GraphId) -> bool {
|
||||
let graph = match self.graphs.get(id) {
|
||||
Some(val) => val,
|
||||
None => return false,
|
||||
};
|
||||
|
||||
let nodes = match graph {
|
||||
GraphState::Merged(_) => return false,
|
||||
GraphState::Owned(nodes) => nodes,
|
||||
};
|
||||
|
||||
for node in nodes {
|
||||
if Arc::strong_count(node) > 1 {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
fn insert_owned_graph(&mut self, graph_id: GraphId, nodes: Vec<NodeRefCount>) {
|
||||
self.graphs.insert(graph_id, GraphState::Owned(nodes));
|
||||
self.owned.insert(graph_id);
|
||||
}
|
||||
|
||||
fn merge_graph<I: IntoIterator<Item = GraphId>>(&mut self, graph_ids: I) -> Option<GraphId> {
|
||||
let graph_ids = graph_ids.into_iter();
|
||||
let graph_ids = graph_ids.collect::<Vec<_>>();
|
||||
|
||||
let mut merged = HashSet::new();
|
||||
|
||||
let mut updated_nodes = Vec::new();
|
||||
let mut updated_graph_id = None;
|
||||
|
||||
for id in graph_ids {
|
||||
let graph_id = match self.find_owned_graph(id) {
|
||||
Some(val) => val,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
if updated_graph_id.is_none() {
|
||||
updated_graph_id = Some(graph_id);
|
||||
}
|
||||
|
||||
merged.insert(graph_id);
|
||||
}
|
||||
|
||||
let updated_graph_id = match updated_graph_id {
|
||||
Some(val) => val,
|
||||
None => return None,
|
||||
};
|
||||
|
||||
for id in merged {
|
||||
let mut updated_state = GraphState::Merged(updated_graph_id);
|
||||
let state = self.graphs.get_mut(&id).unwrap();
|
||||
self.owned.remove(&id);
|
||||
|
||||
core::mem::swap(state, &mut updated_state);
|
||||
|
||||
if let GraphState::Owned(nodes) = updated_state {
|
||||
updated_nodes.extend(nodes)
|
||||
};
|
||||
}
|
||||
|
||||
self.insert_owned_graph(updated_graph_id, updated_nodes);
|
||||
|
||||
Some(updated_graph_id)
|
||||
}
|
||||
|
||||
fn find_owned_graph(&mut self, graph_id: GraphId) -> Option<GraphId> {
|
||||
let graph = match self.graphs.get(&graph_id) {
|
||||
Some(val) => val,
|
||||
None => return None,
|
||||
};
|
||||
|
||||
let merged_graph_id = match graph {
|
||||
GraphState::Merged(graph_id) => graph_id,
|
||||
GraphState::Owned(_) => return Some(graph_id),
|
||||
};
|
||||
|
||||
self.find_owned_graph(*merged_graph_id)
|
||||
}
|
||||
}
|
||||
|
||||
impl core::fmt::Display for GraphMemoryManagement {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_fmt(format_args!(
|
||||
"Graphs Memory Management with {} owned graphs and total of {} graphs\n",
|
||||
self.owned.len(),
|
||||
self.graphs.len()
|
||||
))?;
|
||||
for (id, state) in self.graphs.iter() {
|
||||
f.write_fmt(format_args!("Graph {} => ", id.node.value))?;
|
||||
match state {
|
||||
GraphState::Merged(id) => f.write_fmt(format_args!("Merged {}", id.node.value))?,
|
||||
GraphState::Owned(nodes) => {
|
||||
f.write_str("Owned")?;
|
||||
for node in nodes {
|
||||
f.write_fmt(format_args!(" {}", node.value))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
f.write_str("\n")?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_graph_memory_management_connect_graphs() {
|
||||
let mut graph_mm = GraphMemoryManagement::default();
|
||||
|
||||
let node_1 = Arc::new(NodeID::new());
|
||||
let node_2 = Arc::new(NodeID::new());
|
||||
let node_3 = Arc::new(NodeID::new());
|
||||
let node_4 = Arc::new(NodeID::new());
|
||||
let node_5 = Arc::new(NodeID::new());
|
||||
|
||||
graph_mm.register(node_1.clone(), vec![]);
|
||||
graph_mm.register(node_2.clone(), vec![*node_1]);
|
||||
assert_eq!(graph_mm.owned.len(), 1, "A single connected graph.");
|
||||
|
||||
graph_mm.register(node_3.clone(), vec![]);
|
||||
graph_mm.register(node_4.clone(), vec![*node_3]);
|
||||
assert_eq!(graph_mm.owned.len(), 2, "Two connected graphs.");
|
||||
|
||||
graph_mm.register(node_5.clone(), vec![*node_1, *node_3]);
|
||||
assert_eq!(
|
||||
graph_mm.owned.len(),
|
||||
1,
|
||||
"Two connected graphs are merged into one."
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_graph_memory_management_find_orphans() {
|
||||
let mut graph_mm = GraphMemoryManagement::default();
|
||||
|
||||
let node_1 = Arc::new(NodeID::new());
|
||||
let node_2 = Arc::new(NodeID::new());
|
||||
|
||||
graph_mm.register(node_1.clone(), vec![]);
|
||||
graph_mm.register(node_2.clone(), vec![*node_1]);
|
||||
|
||||
core::mem::drop(node_1);
|
||||
assert_eq!(
|
||||
graph_mm.find_orphan_graphs().len(),
|
||||
0,
|
||||
"Not all nodes are dropped"
|
||||
);
|
||||
|
||||
core::mem::drop(node_2);
|
||||
assert_eq!(
|
||||
graph_mm.find_orphan_graphs().len(),
|
||||
1,
|
||||
"All nodes are dropped"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_graph_memory_management_free_graph_from_any_node() {
|
||||
let mut graph_mm = GraphMemoryManagement::default();
|
||||
|
||||
// Create a graph and free(node_1)
|
||||
let node_1 = Arc::new(NodeID::new());
|
||||
let node_2 = Arc::new(NodeID::new());
|
||||
|
||||
graph_mm.register(node_1.clone(), vec![]);
|
||||
graph_mm.register(node_2.clone(), vec![*node_1]);
|
||||
|
||||
let mut node_ids = Vec::new();
|
||||
graph_mm.free_graph(GraphId::new(*node_1.as_ref()), |id| node_ids.push(*id));
|
||||
|
||||
assert!(node_ids.contains(&node_1));
|
||||
assert!(node_ids.contains(&node_2));
|
||||
|
||||
// Same but with free(node_2);
|
||||
graph_mm.register(node_1.clone(), vec![]);
|
||||
graph_mm.register(node_2.clone(), vec![*node_1]);
|
||||
|
||||
let mut node_ids = Vec::new();
|
||||
graph_mm.free_graph(GraphId::new(*node_2.as_ref()), |id| node_ids.push(*id));
|
||||
|
||||
assert!(node_ids.contains(&node_1));
|
||||
assert!(node_ids.contains(&node_2));
|
||||
}
|
||||
}
|
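The freeing rule above hinges on `NodeRefCount` being an `Arc<NodeID>`: the manager holds one clone per node and a live tensor holds another, so a strong count above one means the graph must be kept. A small plain-Rust sketch of that invariant (the value is a stand-in, not a type from the crate):

use std::sync::Arc;

fn main() {
    // Stand-in for NodeRefCount = Arc<NodeID>.
    let manager_copy = Arc::new(42u64);
    let tensor_copy = manager_copy.clone();

    // A tensor still references the node: the graph is not an orphan.
    assert!(Arc::strong_count(&manager_copy) > 1);

    // Once the tensor is dropped, only the manager's clone remains and the
    // graph can be freed, which is exactly what `is_orphan` checks.
    drop(tensor_copy);
    assert_eq!(Arc::strong_count(&manager_copy), 1);
}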
|
@ -0,0 +1,11 @@
mod client;
mod memory_management;
mod server;

#[cfg(not(feature = "std"))]
pub mod mutex;

#[cfg(feature = "std")]
pub mod mspc;

pub use client::*;
@ -0,0 +1,94 @@
|
|||
use super::{server::AutodiffServer, AutodiffClient};
|
||||
use crate::{
|
||||
checkpoint::builder::CheckpointerBuilder,
|
||||
grads::Gradients,
|
||||
graph::StepBoxed,
|
||||
tensor::{AutodiffTensor, NodeRefCount},
|
||||
NodeID,
|
||||
};
|
||||
use burn_tensor::backend::Backend;
|
||||
use std::sync::mpsc::Sender;
|
||||
|
||||
static INSTANCE: spin::Lazy<ChannelClient> = spin::Lazy::new(ChannelClient::init);
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ChannelClient {
|
||||
sender: Sender<Message>,
|
||||
}
|
||||
|
||||
enum Message {
|
||||
Register {
|
||||
node_id: NodeRefCount,
|
||||
step: StepBoxed,
|
||||
actions: CheckpointerBuilder,
|
||||
},
|
||||
Backward {
|
||||
node_id: NodeID,
|
||||
grads: Gradients,
|
||||
callback: Sender<Gradients>,
|
||||
},
|
||||
}
|
||||
impl ChannelClient {
|
||||
pub(crate) fn new() -> Self {
|
||||
INSTANCE.clone()
|
||||
}
|
||||
|
||||
fn init() -> Self {
|
||||
let (sender, receiver) = std::sync::mpsc::channel();
|
||||
|
||||
std::thread::spawn(move || {
|
||||
let mut server = AutodiffServer::default();
|
||||
|
||||
for message in receiver.iter() {
|
||||
match message {
|
||||
Message::Register {
|
||||
node_id,
|
||||
step,
|
||||
actions,
|
||||
} => server.register(node_id, step, actions),
|
||||
Message::Backward {
|
||||
node_id,
|
||||
grads,
|
||||
callback,
|
||||
} => {
|
||||
let grads = server.backward(grads, node_id);
|
||||
callback.send(grads).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Self { sender }
|
||||
}
|
||||
}
|
||||
|
||||
impl AutodiffClient for ChannelClient {
|
||||
fn register(&self, node_id: NodeRefCount, step: StepBoxed, actions: CheckpointerBuilder) {
|
||||
self.sender
|
||||
.send(Message::Register {
|
||||
node_id,
|
||||
step,
|
||||
actions,
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn backward<B: Backend, const D: usize>(&self, root: AutodiffTensor<B, D>) -> Gradients {
|
||||
let node_id = root.node.id;
|
||||
let grads = Gradients::new::<B, D>(root.node, root.primitive);
|
||||
let (callback, receiver) = std::sync::mpsc::channel();
|
||||
|
||||
self.sender
|
||||
.send(Message::Backward {
|
||||
node_id,
|
||||
grads,
|
||||
callback,
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
match receiver.recv() {
|
||||
Ok(grads) => grads,
|
||||
Err(err) => panic!("Error during backward {err:?}"),
|
||||
}
|
||||
}
|
||||
}
|
|
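The `ChannelClient` above is a standard request/response pattern over `std::sync::mpsc`: the caller packs a callback `Sender` into the message and then blocks on the paired receiver. A self-contained sketch of that pattern, using made-up message names rather than the crate's types:

use std::sync::mpsc;
use std::thread;

enum Request {
    Add {
        a: i32,
        b: i32,
        callback: mpsc::Sender<i32>,
    },
}

fn main() {
    let (sender, receiver) = mpsc::channel::<Request>();

    // Worker thread: the analogue of the autodiff server loop.
    thread::spawn(move || {
        for message in receiver.iter() {
            match message {
                Request::Add { a, b, callback } => callback.send(a + b).unwrap(),
            }
        }
    });

    // Caller side: bundle a callback sender into the request, then block on the response.
    let (callback, response) = mpsc::channel();
    sender
        .send(Request::Add { a: 2, b: 3, callback })
        .unwrap();
    assert_eq!(response.recv().unwrap(), 5);
}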
@ -0,0 +1,49 @@
use super::{server::AutodiffServer, AutodiffClient};
use crate::{
    checkpoint::builder::CheckpointerBuilder,
    grads::Gradients,
    graph::StepBoxed,
    tensor::{AutodiffTensor, NodeRefCount},
};
use burn_tensor::backend::Backend;

#[derive(Clone, new)]
pub struct MutexClient;

impl core::fmt::Debug for MutexClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("MutexClient")
    }
}

static SERVER: spin::Mutex<Option<AutodiffServer>> = spin::Mutex::new(None);

impl AutodiffClient for MutexClient {
    fn register(&self, node_id: NodeRefCount, step: StepBoxed, actions: CheckpointerBuilder) {
        let mut server = SERVER.lock();

        if let Some(server) = server.as_mut() {
            server.register(node_id, step, actions);
            return;
        }

        let mut server_new = AutodiffServer::default();
        server_new.register(node_id, step, actions);
        *server = Some(server_new);
    }
    fn backward<B: Backend, const D: usize>(&self, root: AutodiffTensor<B, D>) -> Gradients {
        let mut server = SERVER.lock();
        let node_id = root.node.id.clone();
        let grads = Gradients::new::<B, D>(root.node, root.primitive);

        if let Some(server) = server.as_mut() {
            return server.backward(grads, node_id);
        }

        let mut server_new = AutodiffServer::default();
        let gradients = server_new.backward(grads, node_id);
        *server = Some(server_new);

        gradients
    }
}
|
@ -0,0 +1,99 @@
|
|||
use super::memory_management::GraphMemoryManagement;
|
||||
use crate::{
|
||||
checkpoint::{base::Checkpointer, builder::CheckpointerBuilder},
|
||||
grads::Gradients,
|
||||
graph::{traversal::BreadthFirstSearch, StepBoxed},
|
||||
tensor::NodeRefCount,
|
||||
NodeID,
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct AutodiffServer {
|
||||
steps: HashMap<NodeID, StepBoxed>,
|
||||
actions_builder: HashMap<NodeID, CheckpointerBuilder>,
|
||||
memory_management: GraphMemoryManagement,
|
||||
}
|
||||
|
||||
impl AutodiffServer {
|
||||
pub fn register(&mut self, rc: NodeRefCount, step: StepBoxed, actions: CheckpointerBuilder) {
|
||||
let parents = step.parents();
|
||||
let node_id = *rc.as_ref();
|
||||
|
||||
self.memory_management.register(rc, parents);
|
||||
|
||||
self.steps.insert(node_id, step);
|
||||
self.actions_builder.insert(node_id, actions);
|
||||
}
|
||||
|
||||
pub fn backward(&mut self, grads: Gradients, node_id: NodeID) -> Gradients {
|
||||
let step = self.steps.remove(&node_id).expect(
|
||||
"Root node should have a step registered, did you forget to call \
|
||||
`Tensor::register_grad` on the tensor where you need gradients?",
|
||||
);
|
||||
let builder = self.actions_builder.remove(&node_id).unwrap();
|
||||
|
||||
let (tape, builder) = self.build_tape(node_id, step, builder);
|
||||
let checkpointer = builder.build(&self.steps);
|
||||
|
||||
let gradients = Self::execute_steps(tape, grads, checkpointer);
|
||||
|
||||
// Cleanup
|
||||
let mut on_free_graph = |node_id: &NodeID| {
|
||||
self.steps.remove(node_id);
|
||||
self.actions_builder.remove(node_id);
|
||||
};
|
||||
|
||||
for graph_id in self.memory_management.find_orphan_graphs() {
|
||||
self.memory_management
|
||||
.free_graph(graph_id, &mut on_free_graph);
|
||||
}
|
||||
|
||||
gradients
|
||||
}
|
||||
|
||||
fn build_tape(
|
||||
&mut self,
|
||||
root: NodeID,
|
||||
root_step: StepBoxed,
|
||||
mut builder: CheckpointerBuilder,
|
||||
) -> (Vec<Vec<StepBoxed>>, CheckpointerBuilder) {
|
||||
let mut tape = (0..root_step.order())
|
||||
.map(|_| Vec::with_capacity(1))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
BreadthFirstSearch.traverse(root, root_step, &mut self.steps, |id, step| {
|
||||
let order = step.order();
|
||||
if order == 0 {
|
||||
return;
|
||||
}
|
||||
|
||||
if let Some(steps) = tape.get_mut(order - 1) {
|
||||
steps.push(step);
|
||||
}
|
||||
|
||||
if let Some(node_builder) = self.actions_builder.remove(&id) {
|
||||
builder.extend(node_builder);
|
||||
}
|
||||
});
|
||||
|
||||
(tape, builder)
|
||||
}
|
||||
|
||||
fn execute_steps(
|
||||
tape: Vec<Vec<StepBoxed>>,
|
||||
mut grads: Gradients,
|
||||
mut checkpointer: Checkpointer,
|
||||
) -> Gradients {
|
||||
tape.into_iter().rev().for_each(|steps| {
|
||||
steps
|
||||
.into_iter()
|
||||
.for_each(|step| step.step(&mut grads, &mut checkpointer))
|
||||
});
|
||||
|
||||
#[cfg(feature = "export_tests")]
|
||||
// For checkpointing tests
|
||||
assert!(checkpointer.is_empty());
|
||||
grads
|
||||
}
|
||||
}
|
|
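The tape built by `build_tape` is simply a vector of buckets indexed by `order - 1`; backward execution then walks it from the highest order down to 1 (order 0 marks leaves and is skipped). A toy, dependency-free sketch of that layout, with invented operation names:

fn main() {
    // (step name, topological order) pairs, as the breadth-first traversal would yield them.
    let steps = vec![("leaf", 0), ("mul", 1), ("add", 2), ("root", 3)];
    let max_order = 3;
    let mut tape: Vec<Vec<&str>> = (0..max_order).map(|_| Vec::new()).collect();

    for (name, order) in steps {
        if order == 0 {
            continue; // leaves have nothing to backpropagate
        }
        tape[order - 1].push(name);
    }

    // Reverse traversal mirrors `execute_steps`: the root first, then its parents, and so on.
    for bucket in tape.into_iter().rev() {
        for name in bucket {
            println!("backward step: {name}");
        }
    }
}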
@ -1,18 +1,22 @@
|
|||
use burn_tensor::backend::Backend;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
checkpoint::{base::Checkpointer, builder::CheckpointerBuilder},
|
||||
grads::Gradients,
|
||||
graph::{ComputingProperty, Graph, Node, NodeID, NodeRef, Requirement, Step},
|
||||
graph::{ComputingProperty, Node, NodeID, NodeRef, Requirement, Step},
|
||||
runtime::{AutodiffClient, AutodiffClientImpl},
|
||||
};
|
||||
use burn_tensor::backend::Backend;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AutodiffTensor<B: Backend, const D: usize> {
|
||||
pub primitive: B::FloatTensorPrimitive<D>,
|
||||
pub node: NodeRef,
|
||||
pub graph: Graph,
|
||||
pub rc: NodeRefCount,
|
||||
}
|
||||
|
||||
pub type NodeRefCount = Arc<NodeID>;
|
||||
|
||||
#[derive(new, Debug)]
|
||||
struct RootStep {
|
||||
node: NodeRef,
|
||||
|
@ -23,8 +27,16 @@ impl Step for RootStep {
|
|||
// Nothing to do
|
||||
}
|
||||
|
||||
fn node(&self) -> NodeRef {
|
||||
self.node.clone()
|
||||
fn node(&self) -> NodeID {
|
||||
self.node.id
|
||||
}
|
||||
|
||||
fn parents(&self) -> Vec<NodeID> {
|
||||
self.node.parents.clone()
|
||||
}
|
||||
|
||||
fn order(&self) -> usize {
|
||||
self.node.order
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -38,13 +50,14 @@ impl<B: Backend, const D: usize> AutodiffTensor<B, D> {
|
|||
id,
|
||||
Requirement::None,
|
||||
ComputingProperty::Ambiguous,
|
||||
AutodiffClientImpl::new(),
|
||||
)
|
||||
.into();
|
||||
|
||||
Self {
|
||||
rc: Arc::new(node.id),
|
||||
primitive,
|
||||
node,
|
||||
graph: Graph::new(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -67,33 +80,26 @@ impl<B: Backend, const D: usize> AutodiffTensor<B, D> {
|
|||
self.node = Node::new(
|
||||
vec![],
|
||||
0,
|
||||
self.node.id.clone(),
|
||||
self.node.id,
|
||||
Requirement::Grad,
|
||||
self.node.properties.clone(),
|
||||
self.node.client.clone(),
|
||||
)
|
||||
.into();
|
||||
let ops = RootStep::new(self.node.clone());
|
||||
let step = RootStep::new(self.node.clone());
|
||||
|
||||
self.register_step(ops)
|
||||
self.register_step(step, CheckpointerBuilder::default())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a tensor from parent infos.
|
||||
pub fn from_parents<I: Iterator<Item = Graph>>(
|
||||
pub fn from_parents(
|
||||
primitive: B::FloatTensorPrimitive<D>,
|
||||
parent_nodes: &[NodeRef],
|
||||
parent_graphs: I,
|
||||
requirement: Requirement,
|
||||
computing_properties: ComputingProperty,
|
||||
checkpointer_builder: CheckpointerBuilder,
|
||||
) -> Self {
|
||||
let graph = parent_graphs
|
||||
.reduce(|acc, graph| acc.merge(graph))
|
||||
.unwrap_or_else(Graph::new);
|
||||
|
||||
graph.extend_checkpointer_builder(checkpointer_builder);
|
||||
|
||||
let order = parent_nodes
|
||||
.iter()
|
||||
.map(|node| node.order)
|
||||
|
@ -101,25 +107,47 @@ impl<B: Backend, const D: usize> AutodiffTensor<B, D> {
|
|||
.unwrap_or(0)
|
||||
+ 1;
|
||||
|
||||
let client = parent_nodes
|
||||
.first()
|
||||
.map(|node| node.client.clone())
|
||||
.unwrap_or_else(AutodiffClientImpl::new);
|
||||
|
||||
let node: NodeRef = Node::new(
|
||||
parent_nodes.iter().map(|node| node.id.clone()).collect(),
|
||||
parent_nodes.iter().map(|node| node.id).collect(),
|
||||
order,
|
||||
NodeID::new(),
|
||||
requirement,
|
||||
computing_properties,
|
||||
client,
|
||||
)
|
||||
.into();
|
||||
|
||||
Self {
|
||||
rc: Arc::new(node.id),
|
||||
primitive,
|
||||
node,
|
||||
graph,
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a step into a graph for that tensor.
|
||||
pub fn register_step<O: Step + 'static>(mut self, ops: O) -> Self {
|
||||
self.graph = self.graph.register(&self.node.id, Box::new(ops));
|
||||
///
|
||||
/// # Warning
|
||||
///
|
||||
/// This should be called only once per tensor.
|
||||
pub fn register_step<S: Step + 'static>(
|
||||
self,
|
||||
step_that_created_the_tensor: S,
|
||||
actions: CheckpointerBuilder,
|
||||
) -> Self {
|
||||
self.node.client.register(
|
||||
self.rc.clone(),
|
||||
Box::new(step_that_created_the_tensor),
|
||||
actions,
|
||||
);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn into_primitive(self) -> B::FloatTensorPrimitive<D> {
|
||||
self.primitive
|
||||
}
|
||||
}
|
||||
|
|
|
@ -24,10 +24,10 @@ where
|
|||
Server: ComputeServer,
|
||||
{
|
||||
_handle: thread::JoinHandle<()>,
|
||||
sender: mpsc::SyncSender<Message<Server>>,
|
||||
sender: mpsc::Sender<Message<Server>>,
|
||||
}
|
||||
|
||||
type Callback<Response> = mpsc::SyncSender<Response>;
|
||||
type Callback<Response> = mpsc::Sender<Response>;
|
||||
|
||||
enum Message<Server>
|
||||
where
|
||||
|
@ -45,8 +45,8 @@ where
|
|||
Server: ComputeServer + 'static,
|
||||
{
|
||||
/// Create a new mpsc compute channel.
|
||||
pub fn new(mut server: Server, bound: usize) -> Self {
|
||||
let (sender, receiver) = mpsc::sync_channel(bound);
|
||||
pub fn new(mut server: Server) -> Self {
|
||||
let (sender, receiver) = mpsc::channel();
|
||||
|
||||
let _handle = thread::spawn(move || {
|
||||
while let Ok(message) = receiver.recv() {
|
||||
|
@ -94,7 +94,7 @@ where
|
|||
Server: ComputeServer + 'static,
|
||||
{
|
||||
fn read(&self, handle: &Handle<Server>) -> Reader<Vec<u8>> {
|
||||
let (callback, response) = mpsc::sync_channel(1);
|
||||
let (callback, response) = mpsc::channel();
|
||||
|
||||
self.state
|
||||
.sender
|
||||
|
@ -105,7 +105,7 @@ where
|
|||
}
|
||||
|
||||
fn create(&self, data: &[u8]) -> Handle<Server> {
|
||||
let (callback, response) = mpsc::sync_channel(1);
|
||||
let (callback, response) = mpsc::channel();
|
||||
|
||||
self.state
|
||||
.sender
|
||||
|
@ -116,7 +116,7 @@ where
|
|||
}
|
||||
|
||||
fn empty(&self, size: usize) -> Handle<Server> {
|
||||
let (callback, response) = mpsc::sync_channel(1);
|
||||
let (callback, response) = mpsc::channel();
|
||||
|
||||
self.state
|
||||
.sender
|
||||
|
@ -140,7 +140,7 @@ where
|
|||
}
|
||||
|
||||
fn sync(&self) {
|
||||
let (callback, response) = mpsc::sync_channel(1);
|
||||
let (callback, response) = mpsc::channel();
|
||||
|
||||
self.state.sender.send(Message::Sync(callback)).unwrap();
|
||||
|
||||
|
|
|
@ -20,8 +20,10 @@ default = [
|
|||
"burn-tch?/default",
|
||||
"burn-tensor/default",
|
||||
"burn-wgpu?/default",
|
||||
"burn-autodiff?/default",
|
||||
]
|
||||
std = [
|
||||
"burn-autodiff?/std",
|
||||
"bincode/std",
|
||||
"burn-candle?/std",
|
||||
"burn-common/std",
|
||||
|
|
|
@ -18,10 +18,27 @@ pub trait DataLoaderIterator<O>: Iterator<Item = O> {
|
|||
}
|
||||
|
||||
/// A data loader that can be used to iterate over a dataset.
|
||||
pub trait DataLoader<O> {
|
||||
pub trait DataLoader<O>: Send {
|
||||
/// Returns a boxed [iterator](DataLoaderIterator) to iterate over the data loader.
|
||||
fn iter<'a>(&'a self) -> Box<dyn DataLoaderIterator<O> + 'a>;
|
||||
/// The number of items (not the number of batches nor the number of iterations),
|
||||
/// corresponding to the items_total of the progress returned by the iterator.
|
||||
fn num_items(&self) -> usize;
|
||||
}
|
||||
|
||||
/// A super trait for [dataloader](DataLoader) that allows it to be cloned dynamically.
|
||||
///
|
||||
/// Any dataloader that implements [Clone] should also implement this automatically.
|
||||
pub trait DynDataLoader<O>: DataLoader<O> {
|
||||
/// Clone the dataloader and return a new boxed one.
|
||||
fn clone_dyn(&self) -> Box<dyn DynDataLoader<O>>;
|
||||
}
|
||||
|
||||
impl<D, O> DynDataLoader<O> for D
|
||||
where
|
||||
D: DataLoader<O> + Clone + 'static,
|
||||
{
|
||||
fn clone_dyn(&self) -> Box<dyn DynDataLoader<O>> {
|
||||
Box::new(self.clone())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use super::{
|
||||
batcher::Batcher, BatchStrategy, DataLoader, DataLoaderIterator, MultiThreadDataLoader,
|
||||
Progress,
|
||||
batcher::DynBatcher, BatchStrategy, DataLoader, DataLoaderIterator, DynDataLoader,
|
||||
MultiThreadDataLoader, Progress,
|
||||
};
|
||||
use burn_dataset::{
|
||||
transform::{PartialDataset, ShuffledDataset},
|
||||
|
@ -13,8 +13,19 @@ use std::sync::Arc;
|
|||
pub struct BatchDataLoader<I, O> {
|
||||
strategy: Box<dyn BatchStrategy<I>>,
|
||||
dataset: Arc<dyn Dataset<I>>,
|
||||
batcher: Arc<dyn Batcher<I, O>>,
|
||||
rng: Option<spin::Mutex<rand::rngs::StdRng>>,
|
||||
batcher: Box<dyn DynBatcher<I, O>>,
|
||||
rng: Option<Arc<spin::Mutex<rand::rngs::StdRng>>>,
|
||||
}
|
||||
|
||||
impl<I, O> Clone for BatchDataLoader<I, O> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
strategy: self.strategy.clone_dyn(),
|
||||
dataset: self.dataset.clone(),
|
||||
batcher: self.batcher.clone_dyn(),
|
||||
rng: self.rng.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<I, O> BatchDataLoader<I, O> {
|
||||
|
@ -34,14 +45,14 @@ impl<I, O> BatchDataLoader<I, O> {
|
|||
pub fn new(
|
||||
strategy: Box<dyn BatchStrategy<I>>,
|
||||
dataset: Arc<dyn Dataset<I>>,
|
||||
batcher: Arc<dyn Batcher<I, O>>,
|
||||
batcher: Box<dyn DynBatcher<I, O>>,
|
||||
rng: Option<rand::rngs::StdRng>,
|
||||
) -> Self {
|
||||
Self {
|
||||
strategy,
|
||||
dataset,
|
||||
batcher,
|
||||
rng: rng.map(spin::Mutex::new),
|
||||
rng: rng.map(|rng| Arc::new(spin::Mutex::new(rng))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -51,13 +62,13 @@ struct BatchDataloaderIterator<I, O> {
|
|||
current_index: usize,
|
||||
strategy: Box<dyn BatchStrategy<I>>,
|
||||
dataset: Arc<dyn Dataset<I>>,
|
||||
batcher: Arc<dyn Batcher<I, O>>,
|
||||
batcher: Box<dyn DynBatcher<I, O>>,
|
||||
}
|
||||
|
||||
impl<I, O> BatchDataLoader<I, O>
|
||||
where
|
||||
I: Send + Sync + Clone + 'static,
|
||||
O: Send + Sync + Clone + 'static,
|
||||
O: Send + Clone + 'static,
|
||||
{
|
||||
/// Creates a new multi-threaded batch data loader.
|
||||
///
|
||||
|
@ -74,14 +85,13 @@ where
|
|||
pub fn multi_thread(
|
||||
strategy: Box<dyn BatchStrategy<I>>,
|
||||
dataset: Arc<dyn Dataset<I>>,
|
||||
batcher: Arc<dyn Batcher<I, O>>,
|
||||
batcher: Box<dyn DynBatcher<I, O>>,
|
||||
num_threads: usize,
|
||||
mut rng: Option<rand::rngs::StdRng>,
|
||||
) -> MultiThreadDataLoader<O> {
|
||||
let datasets = PartialDataset::split(dataset, num_threads);
|
||||
|
||||
let mut dataloaders: Vec<Arc<dyn DataLoader<_> + Send + Sync>> =
|
||||
Vec::with_capacity(num_threads);
|
||||
let mut dataloaders = Vec::with_capacity(num_threads);
|
||||
|
||||
// Create more rngs from the first one, one for each new dataloader.
|
||||
let rngs = (0..num_threads).map(|_| {
|
||||
|
@ -90,17 +100,21 @@ where
|
|||
});
|
||||
|
||||
for (dataset, rng) in datasets.into_iter().zip(rngs) {
|
||||
let strategy = strategy.new_like();
|
||||
let strategy = strategy.clone_dyn();
|
||||
let dataloader =
|
||||
BatchDataLoader::new(strategy, Arc::new(dataset), batcher.clone(), rng);
|
||||
let dataloader = Arc::new(dataloader);
|
||||
BatchDataLoader::new(strategy, Arc::new(dataset), batcher.clone_dyn(), rng);
|
||||
let dataloader: Box<dyn DynDataLoader<_>> = Box::new(dataloader);
|
||||
dataloaders.push(dataloader);
|
||||
}
|
||||
MultiThreadDataLoader::new(dataloaders)
|
||||
}
|
||||
}
|
||||
|
||||
impl<I: Send + Sync + Clone + 'static, O: Send + Sync> DataLoader<O> for BatchDataLoader<I, O> {
|
||||
impl<I, O> DataLoader<O> for BatchDataLoader<I, O>
|
||||
where
|
||||
I: Send + Sync + Clone + 'static,
|
||||
O: Send + 'static,
|
||||
{
|
||||
fn iter<'a>(&'a self) -> Box<dyn DataLoaderIterator<O> + 'a> {
|
||||
// When starting a new iteration, we first check if the dataloader was created with an rng,
|
||||
// implying that we should shuffle the dataset beforehand, while advancing the current
|
||||
|
@ -117,9 +131,9 @@ impl<I: Send + Sync + Clone + 'static, O: Send + Sync> DataLoader<O> for BatchDa
|
|||
None => self.dataset.clone(),
|
||||
};
|
||||
Box::new(BatchDataloaderIterator::new(
|
||||
self.strategy.new_like(),
|
||||
self.strategy.clone_dyn(),
|
||||
dataset,
|
||||
self.batcher.clone(),
|
||||
self.batcher.clone_dyn(),
|
||||
))
|
||||
}
|
||||
|
||||
|
@ -143,7 +157,7 @@ impl<I, O> BatchDataloaderIterator<I, O> {
|
|||
pub fn new(
|
||||
strategy: Box<dyn BatchStrategy<I>>,
|
||||
dataset: Arc<dyn Dataset<I>>,
|
||||
batcher: Arc<dyn Batcher<I, O>>,
|
||||
batcher: Box<dyn DynBatcher<I, O>>,
|
||||
) -> Self {
|
||||
BatchDataloaderIterator {
|
||||
current_index: 0,
|
||||
|
@ -192,7 +206,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_batch_dataloader() {
|
||||
let batcher = Arc::new(TestBatcher::new());
|
||||
let batcher = Box::new(TestBatcher::new());
|
||||
let dataset = Arc::new(FakeDataset::<String>::new(27));
|
||||
let dataloader = BatchDataLoader::new(
|
||||
Box::new(FixBatchStrategy::new(5)),
|
||||
|
@ -219,12 +233,12 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_multi_thread_batch_dataloader() {
|
||||
let batcher = Arc::new(TestBatcher::new());
|
||||
let batcher = Box::new(TestBatcher::new());
|
||||
let dataset = Arc::new(FakeDataset::<String>::new(27));
|
||||
let dataloader_single_thread = BatchDataLoader::new(
|
||||
Box::new(FixBatchStrategy::new(5)),
|
||||
dataset.clone(),
|
||||
batcher.clone(),
|
||||
batcher.clone_dyn(),
|
||||
None,
|
||||
);
|
||||
let dataloader_multi_thread = BatchDataLoader::multi_thread(
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/// A trait for batching items of type `I` into items of type `O`.
|
||||
pub trait Batcher<I, O>: Send + Sync {
|
||||
pub trait Batcher<I, O>: Send {
|
||||
/// Batches the given items.
|
||||
///
|
||||
/// # Arguments
|
||||
|
@ -12,9 +12,27 @@ pub trait Batcher<I, O>: Send + Sync {
|
|||
fn batch(&self, items: Vec<I>) -> O;
|
||||
}
|
||||
|
||||
/// A super trait for [batcher](Batcher) that allows it to be cloned dynamically.
|
||||
///
|
||||
/// Any batcher that implements [Clone] should also implement this automatically.
|
||||
pub trait DynBatcher<I, O>: Send + Batcher<I, O> {
|
||||
/// Clone the batcher and return a new boxed one.
|
||||
fn clone_dyn(&self) -> Box<dyn DynBatcher<I, O>>;
|
||||
}
|
||||
|
||||
impl<B, I, O> DynBatcher<I, O> for B
|
||||
where
|
||||
B: Batcher<I, O> + Clone + 'static,
|
||||
{
|
||||
fn clone_dyn(&self) -> Box<dyn DynBatcher<I, O>> {
|
||||
Box::new(self.clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[derive(new)]
|
||||
#[derive(new, Clone)]
|
||||
pub struct TestBatcher;
|
||||
|
||||
#[cfg(test)]
|
||||
impl<I> Batcher<I, Vec<I>> for TestBatcher {
|
||||
fn batch(&self, items: Vec<I>) -> Vec<I> {
|
||||
|
|
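As a usage sketch (not taken from the repository), any batcher that derives `Clone` now gets `DynBatcher` for free through the blanket impl above. `IdentityBatcher`, the test name, and the import path are assumptions made for illustration:

use crate::data::dataloader::batcher::{Batcher, DynBatcher};

/// Hypothetical batcher, used only to illustrate the blanket impl.
#[derive(Clone)]
struct IdentityBatcher;

impl Batcher<i32, Vec<i32>> for IdentityBatcher {
    fn batch(&self, items: Vec<i32>) -> Vec<i32> {
        items
    }
}

#[test]
fn boxed_batcher_can_be_cloned() {
    // `Batcher + Clone + 'static` is enough to obtain `clone_dyn` on a boxed trait object.
    let boxed: Box<dyn DynBatcher<i32, Vec<i32>>> = Box::new(IdentityBatcher);
    let copy = boxed.clone_dyn();
    assert_eq!(copy.batch(vec![1, 2, 3]), vec![1, 2, 3]);
}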
|
@ -1,4 +1,4 @@
|
|||
use super::{batcher::Batcher, BatchDataLoader, BatchStrategy, DataLoader, FixBatchStrategy};
|
||||
use super::{batcher::DynBatcher, BatchDataLoader, BatchStrategy, DataLoader, FixBatchStrategy};
|
||||
use burn_dataset::Dataset;
|
||||
use rand::{rngs::StdRng, SeedableRng};
|
||||
use std::sync::Arc;
|
||||
|
@ -6,7 +6,7 @@ use std::sync::Arc;
|
|||
/// A builder for data loaders.
|
||||
pub struct DataLoaderBuilder<I, O> {
|
||||
strategy: Option<Box<dyn BatchStrategy<I>>>,
|
||||
batcher: Arc<dyn Batcher<I, O>>,
|
||||
batcher: Box<dyn DynBatcher<I, O>>,
|
||||
num_threads: Option<usize>,
|
||||
shuffle: Option<u64>,
|
||||
}
|
||||
|
@ -14,7 +14,7 @@ pub struct DataLoaderBuilder<I, O> {
|
|||
impl<I, O> DataLoaderBuilder<I, O>
|
||||
where
|
||||
I: Send + Sync + Clone + std::fmt::Debug + 'static,
|
||||
O: Send + Sync + Clone + std::fmt::Debug + 'static,
|
||||
O: Send + Clone + std::fmt::Debug + 'static,
|
||||
{
|
||||
/// Creates a new data loader builder.
|
||||
///
|
||||
|
@ -27,10 +27,10 @@ where
|
|||
/// The data loader builder.
|
||||
pub fn new<B>(batcher: B) -> Self
|
||||
where
|
||||
B: Batcher<I, O> + 'static,
|
||||
B: DynBatcher<I, O> + 'static,
|
||||
{
|
||||
Self {
|
||||
batcher: Arc::new(batcher),
|
||||
batcher: Box::new(batcher),
|
||||
strategy: None,
|
||||
num_threads: None,
|
||||
shuffle: None,
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
use super::{DataLoader, DataLoaderIterator, Progress};
|
||||
use std::sync::{mpsc, Arc};
|
||||
use super::{DataLoader, DataLoaderIterator, DynDataLoader, Progress};
|
||||
use std::sync::mpsc;
|
||||
use std::thread;
|
||||
|
||||
const MAX_QUEUED_ITEMS: usize = 100;
|
||||
|
||||
/// A multi-threaded data loader that can be used to iterate over a dataset.
|
||||
pub struct MultiThreadDataLoader<O> {
|
||||
dataloaders: Vec<Arc<dyn DataLoader<O> + Send + Sync>>,
|
||||
dataloaders: Vec<Box<dyn DynDataLoader<O>>>,
|
||||
}
|
||||
|
||||
/// A message that can be sent between threads.
|
||||
|
@ -36,7 +36,7 @@ impl<O> MultiThreadDataLoader<O> {
|
|||
/// # Returns
|
||||
///
|
||||
/// The multi-threaded data loader.
|
||||
pub fn new(dataloaders: Vec<Arc<dyn DataLoader<O> + Send + Sync>>) -> Self {
|
||||
pub fn new(dataloaders: Vec<Box<dyn DynDataLoader<O>>>) -> Self {
|
||||
Self { dataloaders }
|
||||
}
|
||||
}
|
||||
|
@ -52,11 +52,10 @@ where
|
|||
|
||||
let handlers: Vec<_> = self
|
||||
.dataloaders
|
||||
.clone()
|
||||
.into_iter()
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(index, dataloader)| {
|
||||
let dataloader_cloned = dataloader;
|
||||
let dataloader_cloned = dataloader.clone_dyn();
|
||||
let sender_cloned = sender.clone();
|
||||
progresses.push(Progress::new(0, dataloader_cloned.num_items()));
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/// A strategy to batch items.
|
||||
pub trait BatchStrategy<I>: Send + Sync {
|
||||
pub trait BatchStrategy<I>: Send {
|
||||
/// Adds an item to the strategy.
|
||||
///
|
||||
/// # Arguments
|
||||
|
@ -23,7 +23,7 @@ pub trait BatchStrategy<I>: Send + Sync {
|
|||
/// # Returns
|
||||
///
|
||||
/// The new strategy.
|
||||
fn new_like(&self) -> Box<dyn BatchStrategy<I>>;
|
||||
fn clone_dyn(&self) -> Box<dyn BatchStrategy<I>>;
|
||||
}
|
||||
|
||||
/// A strategy to batch items with a fixed batch size.
|
||||
|
@ -50,7 +50,7 @@ impl<I> FixBatchStrategy<I> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<I: Send + Sync + 'static> BatchStrategy<I> for FixBatchStrategy<I> {
|
||||
impl<I: Send + 'static> BatchStrategy<I> for FixBatchStrategy<I> {
|
||||
fn add(&mut self, item: I) {
|
||||
self.items.push(item);
|
||||
}
|
||||
|
@ -70,7 +70,7 @@ impl<I: Send + Sync + 'static> BatchStrategy<I> for FixBatchStrategy<I> {
|
|||
Some(items)
|
||||
}
|
||||
|
||||
fn new_like(&self) -> Box<dyn BatchStrategy<I>> {
|
||||
fn clone_dyn(&self) -> Box<dyn BatchStrategy<I>> {
|
||||
Box::new(Self::new(self.batch_size))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -80,7 +80,7 @@ macro_rules! module {
|
|||
/// my_other_field: usize,
|
||||
/// }
|
||||
/// ```
|
||||
pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
|
||||
pub trait Module<B: Backend>: Clone + Send + core::fmt::Debug {
|
||||
/// Type to save and load the module.
|
||||
type Record: Record<B>;
|
||||
|
||||
|
@ -238,7 +238,7 @@ pub trait ModuleMapper<B: Backend> {
|
|||
}
|
||||
|
||||
/// Module with auto-differentiation backend.
|
||||
pub trait AutodiffModule<B: AutodiffBackend>: Module<B> + Send + Sync + core::fmt::Debug {
|
||||
pub trait AutodiffModule<B: AutodiffBackend>: Module<B> + Send + core::fmt::Debug {
|
||||
/// Inner module without auto-differentiation.
|
||||
type InnerModule: Module<B::InnerBackend>;
|
||||
|
||||
|
|
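The practical consequence of dropping `Sync` is that a module (or anything holding one) can no longer be shared by reference across threads; it is still `Send + Clone`, so the idiomatic pattern is to move an owned clone into each thread. A dependency-free sketch with a stand-in type, not an actual burn module:

use std::cell::Cell;
use std::thread;

// Stand-in for a module: Send + Clone, but !Sync because of interior mutability.
#[derive(Clone)]
struct NotSyncState {
    counter: Cell<u32>,
}

fn main() {
    let state = NotSyncState { counter: Cell::new(0) };

    // `&state` can no longer be shared across threads; move an owned clone instead.
    let handles: Vec<_> = (0..2)
        .map(|_| {
            let local = state.clone();
            thread::spawn(move || {
                local.counter.set(local.counter.get() + 1);
                local.counter.get()
            })
        })
        .collect();

    for handle in handles {
        assert_eq!(handle.join().unwrap(), 1);
    }
}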
|
@ -1,7 +1,8 @@
|
|||
use super::ParamId;
|
||||
use alloc::boxed::Box;
|
||||
use alloc::format;
|
||||
use burn_common::stub::{RwLock, SyncOnceCell};
|
||||
use burn_common::stub::RwLock;
|
||||
use core::cell::OnceCell;
|
||||
use core::ops::Deref;
|
||||
|
||||
/// Parameters are the fundamental building blocks of [modules](crate::module::Module) where they
|
||||
|
@ -31,7 +32,7 @@ use core::ops::Deref;
|
|||
/// ```
|
||||
pub struct Param<T: Parameter> {
|
||||
pub(crate) id: ParamId,
|
||||
state: SyncOnceCell<T>,
|
||||
state: OnceCell<T>,
|
||||
/// The locking is only required because of `lazy_device` and `lazy_is_require_grad`.
|
||||
///
|
||||
/// Because of once cell, we have a guarantee that the initialization will only be called once,
|
||||
|
@ -54,7 +55,7 @@ impl<T: Parameter> core::fmt::Debug for Param<T> {
|
|||
}
|
||||
|
||||
/// Trait that defines what is necessary for a type to be a parameter.
|
||||
pub trait Parameter: Clone + core::fmt::Debug + Send + Sync {
|
||||
pub trait Parameter: Clone + core::fmt::Debug + Send {
|
||||
/// The device type to be used.
|
||||
type Device: Clone;
|
||||
|
||||
|
@ -70,7 +71,7 @@ pub trait Parameter: Clone + core::fmt::Debug + Send + Sync {
|
|||
|
||||
#[allow(clippy::type_complexity)]
|
||||
struct Uninitialized<P: Parameter> {
|
||||
init: Box<dyn Fn(&P::Device, bool) -> P + Send + Sync>,
|
||||
init: Box<dyn Fn(&P::Device, bool) -> P + Send>,
|
||||
device: P::Device,
|
||||
is_require_grad: bool,
|
||||
}
|
||||
|
@ -87,7 +88,7 @@ impl<T: Parameter> Param<T> {
|
|||
pub fn initialized(id: ParamId, value: T) -> Self {
|
||||
Self {
|
||||
id,
|
||||
state: SyncOnceCell::initialized(value),
|
||||
state: OnceCell::from(value),
|
||||
initialization: None,
|
||||
}
|
||||
}
|
||||
|
@ -95,11 +96,11 @@ impl<T: Parameter> Param<T> {
|
|||
/// Create a new parameter that is not already initialized.
|
||||
pub fn uninitialized<F>(id: ParamId, init: F, device: T::Device, is_require_grad: bool) -> Self
|
||||
where
|
||||
F: Fn(&T::Device, bool) -> T + Send + Sync + 'static,
|
||||
F: Fn(&T::Device, bool) -> T + Send + 'static,
|
||||
{
|
||||
Self {
|
||||
id,
|
||||
state: SyncOnceCell::new(),
|
||||
state: OnceCell::new(),
|
||||
initialization: Some(RwLock::new(Some(Uninitialized {
|
||||
init: Box::new(init),
|
||||
device,
|
||||
|
@ -149,7 +150,7 @@ impl<T: Parameter> Param<T> {
|
|||
|
||||
Self {
|
||||
id,
|
||||
state: SyncOnceCell::initialized(tensor),
|
||||
state: OnceCell::from(tensor),
|
||||
initialization: None,
|
||||
}
|
||||
}
|
||||
|
|
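`core::cell::OnceCell` keeps the same write-once guarantee as the removed `SyncOnceCell`, but without the `Sync` bound, which is what lets `Param` drop `Sync` as well. A minimal sketch of the pattern, independent of the crate:

use core::cell::OnceCell;

fn main() {
    let state: OnceCell<i32> = OnceCell::new();
    assert!(state.get().is_none());

    // The initializer runs at most once; later calls return the stored value.
    assert_eq!(*state.get_or_init(|| 42), 42);
    assert_eq!(*state.get_or_init(|| 7), 42);

    // `OnceCell::from` builds an already-initialized cell, as `Param::initialized` does.
    let ready = OnceCell::from(7);
    assert_eq!(ready.get(), Some(&7));
}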
|
@@ -5,7 +5,7 @@ use core::fmt::Debug;

 impl<T, B> Module<B> for Option<T>
 where
-    T: Module<B> + Debug + Send + Sync + Clone,
+    T: Module<B> + Debug + Send + Clone,
     B: Backend,
 {
     type Record = Option<T::Record>;

@@ -48,7 +48,7 @@ where

 impl<T, B> AutodiffModule<B> for Option<T>
 where
-    T: AutodiffModule<B> + Debug + Send + Sync + Clone,
+    T: AutodiffModule<B> + Debug + Send + Clone,
     B: AutodiffBackend,
 {
     type InnerModule = Option<T::InnerModule>;

@@ -60,7 +60,7 @@ where

 impl<T, B> Module<B> for Vec<T>
 where
-    T: Module<B> + Debug + Send + Sync + Clone,
+    T: Module<B> + Debug + Send + Clone,
     B: Backend,
 {
     type Record = Vec<T::Record>;

@@ -116,7 +116,7 @@ where

 impl<T, B> AutodiffModule<B> for Vec<T>
 where
-    T: AutodiffModule<B> + Debug + Send + Sync + Clone,
+    T: AutodiffModule<B> + Debug + Send + Clone,
     B: AutodiffBackend,
 {
     type InnerModule = Vec<T::InnerModule>;

@@ -128,7 +128,7 @@ where

 impl<const N: usize, T, B> Module<B> for [T; N]
 where
-    T: Module<B> + Debug + Send + Sync + Clone + Copy,
+    T: Module<B> + Debug + Send + Clone + Copy,
     T::Record: Debug,
     B: Backend,
 {

@@ -185,7 +185,7 @@ where

 impl<const N: usize, T, B> AutodiffModule<B> for [T; N]
 where
-    T: AutodiffModule<B> + Debug + Send + Sync + Clone + Copy,
+    T: AutodiffModule<B> + Debug + Send + Clone + Copy,
     T::InnerModule: Copy + Debug,
     <T::InnerModule as Module<B::InnerBackend>>::Record: Debug,
     <T as Module<B>>::Record: Debug,

@@ -210,7 +210,7 @@ macro_rules! impl_module_tuple {
         impl<B, $($l,)*> Module<B> for ($($l,)*)
         where
             B: Backend,
-            $($l: Module<B> + Debug + Send + Sync + Clone,)*
+            $($l: Module<B> + Debug + Send + Clone,)*
         {
             type Record = ($($l::Record),*);

@@ -247,7 +247,7 @@ macro_rules! impl_module_tuple {
         impl<B, $($l,)*> AutodiffModule<B> for ($($l,)*)
         where
             B: AutodiffBackend,
-            $($l: AutodiffModule<B> + Debug + Send + Sync + Clone,)*
+            $($l: AutodiffModule<B> + Debug + Send + Clone,)*
         {
             type InnerModule = ($($l::InnerModule,)*);
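A rough sketch of what the `Send`-only container impls allow in practice, using placeholder types rather than burn's `Module` trait: children that are `Send` but not `Sync` (here because of an interior `Cell`) can still be moved wholesale to a worker thread inside a `Vec`.

use std::cell::Cell;
use std::thread;

// Illustration only: a child "module" that is `Send` but not `Sync`
// because of the interior `Cell`, collected in a `Vec` like the impls above.
#[derive(Debug)]
struct Child {
    calls: Cell<u32>,
}

fn main() {
    let model = vec![Child { calls: Cell::new(0) }, Child { calls: Cell::new(0) }];

    // Moving the whole collection to a worker thread needs only `Send`,
    // which is exactly what the relaxed container impls now require.
    let handle = thread::spawn(move || {
        for child in &model {
            child.calls.set(child.calls.get() + 1);
        }
        model.iter().map(|c| c.calls.get()).sum::<u32>()
    });
    assert_eq!(handle.join().unwrap(), 2);
}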
@@ -2,6 +2,7 @@ use super::ParamId;
 use crate::module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor, Param};
 use alloc::sync::Arc;
 use alloc::vec::Vec;
+use burn_common::stub::Mutex;
 use burn_tensor::{
     backend::{AutodiffBackend, Backend},
     Tensor,

@@ -10,7 +11,6 @@ use burn_tensor::{
 #[cfg(feature = "std")]
 mod threading {
     pub(super) use std::collections::HashMap;
-    pub(super) use std::sync::{Mutex, RwLock};
     pub(super) use std::thread::ThreadId;

     #[inline(always)]

@@ -21,7 +21,7 @@ mod threading {

 #[cfg(not(feature = "std"))]
 mod threading {
-    pub(super) use burn_common::stub::{Mutex, RwLock, ThreadId};
+    pub(super) use burn_common::stub::ThreadId;
     pub(super) use hashbrown::HashMap;

     #[inline(always)]

@@ -42,20 +42,20 @@ use threading::*;
 pub struct RunningState<V> {
     id: ParamId,
     values: Arc<Mutex<HashMap<ThreadId, V>>>,
-    value: Arc<RwLock<V>>,
+    value: Arc<Mutex<V>>,
 }

 impl<const D: usize, B: Backend> Module<B> for RunningState<Tensor<B, D>> {
     type Record = Param<Tensor<B, D>>;

     fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
-        let tensor = self.value.read().unwrap();
+        let tensor = self.value.lock().unwrap();

         visitor.visit_float(&self.id, &tensor)
     }

     fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
-        let mut tensor = self.value.write().unwrap();
+        let mut tensor = self.value.lock().unwrap();
         let tensor_out = mapper.map_float(&self.id, tensor.clone());

         *tensor = tensor_out;

@@ -66,13 +66,13 @@ impl<const D: usize, B: Backend> Module<B> for RunningState<Tensor<B, D>> {

     fn into_record(self) -> Self::Record {
         self.sync();
-        let tensor = self.value.read().unwrap();
+        let tensor = self.value.lock().unwrap();

         Param::initialized(self.id, tensor.clone())
     }

     fn load_record(mut self, record: Self::Record) -> Self {
-        let mut tensor = self.value.write().unwrap();
+        let mut tensor = self.value.lock().unwrap();
         *tensor = record.val().to_device(&tensor.device());
         self.id = record.id;

@@ -82,7 +82,7 @@ impl<const D: usize, B: Backend> Module<B> for RunningState<Tensor<B, D>> {
     }

     fn to_device(self, device: &<B as Backend>::Device) -> Self {
-        let mut tensor = self.value.write().unwrap();
+        let mut tensor = self.value.lock().unwrap();
         let tensor_out = tensor.clone().to_device(device);

         *tensor = tensor_out;

@@ -99,7 +99,7 @@ impl<const D: usize, B: Backend> Module<B> for RunningState<Tensor<B, D>> {
         &self,
         mut devices: Vec<<B as Backend>::Device>,
     ) -> Vec<<B as Backend>::Device> {
-        let device = self.value.read().unwrap().device();
+        let device = self.value.lock().unwrap().device();

         if !devices.contains(&device) {
             devices.push(device)

@@ -115,7 +115,7 @@ impl<const D: usize, B: Backend> RunningState<Tensor<B, D>> {
         Self {
             id: ParamId::new(),
             values: Arc::new(Mutex::new(HashMap::new())),
-            value: Arc::new(RwLock::new(value)),
+            value: Arc::new(Mutex::new(value)),
         }
     }

@@ -124,7 +124,7 @@ impl<const D: usize, B: Backend> RunningState<Tensor<B, D>> {
         Self {
             id,
             values: Arc::new(Mutex::new(HashMap::new())),
-            value: Arc::new(RwLock::new(value)),
+            value: Arc::new(Mutex::new(value)),
         }
     }

@@ -134,7 +134,7 @@ impl<const D: usize, B: Backend> RunningState<Tensor<B, D>> {
         Self {
             id: record.id,
             values: Arc::new(Mutex::new(HashMap::new())),
-            value: Arc::new(RwLock::new(tensor)),
+            value: Arc::new(Mutex::new(tensor)),
         }
     }

@@ -156,7 +156,7 @@ impl<const D: usize, B: Backend> RunningState<Tensor<B, D>> {
     ///
     /// The current value might be outdated by one update.
     pub fn value(&self) -> Tensor<B, D> {
-        let value = self.value.read().unwrap();
+        let value = self.value.lock().unwrap();
         value.clone()
     }

@@ -174,7 +174,7 @@ impl<const D: usize, B: Backend> RunningState<Tensor<B, D>> {
             self.update_value(&mut map);
         }

-        let value = self.value.read().unwrap();
+        let value = self.value.lock().unwrap();
         value.clone()
     }

@@ -204,7 +204,7 @@ impl<const D: usize, B: Backend> RunningState<Tensor<B, D>> {

         if let Some(value) = value_updated {
             let value = value.div_scalar(counter);
-            let mut value_old = self.value.write().unwrap();
+            let mut value_old = self.value.lock().unwrap();
             *value_old = value;
         }
     }
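A minimal, self-contained sketch of the pattern after this change, with illustrative names rather than burn's `RunningState` API: a single `Mutex` guards the shared value, so readers such as `value()` take the same exclusive lock that updates take, instead of an `RwLock` read guard.

use std::sync::{Arc, Mutex};

// Illustrative stand-in for the running-value pattern above.
struct Running<T: Clone> {
    value: Arc<Mutex<T>>,
}

impl<T: Clone> Running<T> {
    fn new(value: T) -> Self {
        Self {
            value: Arc::new(Mutex::new(value)),
        }
    }

    // Readers now lock exclusively, mirroring `value()` using `lock()` above.
    fn value(&self) -> T {
        self.value.lock().unwrap().clone()
    }

    fn update(&self, value: T) {
        *self.value.lock().unwrap() = value;
    }
}

fn main() {
    let state = Running::new(0.0_f32);
    state.update(1.5);
    assert_eq!(state.value(), 1.5);
}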
@@ -5,7 +5,7 @@ use crate::tensor::backend::AutodiffBackend;
 use crate::LearningRate;

 /// General trait to optimize [module](AutodiffModule).
-pub trait Optimizer<M, B>: Send + Sync
+pub trait Optimizer<M, B>: Send
 where
     M: AutodiffModule<B>,
     B: AutodiffBackend,

@@ -5,7 +5,7 @@ use super::PrecisionSettings;
 use serde::{de::DeserializeOwned, Serialize};

 /// Trait to define a family of types which can be recorded using any [settings](PrecisionSettings).
-pub trait Record<B: Backend>: Send + Sync {
+pub trait Record<B: Backend>: Send {
     /// Type of the item that can be serialized and deserialized.
     type Item<S: PrecisionSettings>: Serialize + DeserializeOwned;

@@ -15,7 +15,7 @@ pub struct Bool;
 /// A type-level representation of the kind of a tensor.
 pub trait TensorKind<B: Backend>: Clone + core::fmt::Debug {
     /// The primitive type of the tensor.
-    type Primitive<const D: usize>: Clone + core::fmt::Debug + Sync + Send;
+    type Primitive<const D: usize>: Clone + core::fmt::Debug + Send;

     /// The name of the tensor kind.
     fn name() -> &'static str;
@@ -34,8 +34,8 @@ use super::BackendBridge;
 ///
 /// ### Multi-Threaded
 ///
-/// Backend tensor types are all `Clone` + `Sync` + `Send`, which allows them to be safely
-/// shared between threads. It is recommended to wrap tensors with [Arc](alloc::sync::Arc),
+/// Backend tensor types are all `Clone` + `Send`, which allows them to be safely
+/// sent between threads. It is recommended to wrap tensors with [Arc](alloc::sync::Arc),
 /// which avoids copying the tensor's buffer. Note that it is still possible to mutate and
 /// reuse tensors' buffer without locking; see the next section on the Mutable API.
 ///

@@ -72,17 +72,17 @@ pub trait Backend:
     type FullPrecisionBridge: BackendBridge<Self> + 'static;

     /// Tensor primitive to be used for all float operations.
-    type FloatTensorPrimitive<const D: usize>: Clone + Send + Sync + 'static + core::fmt::Debug;
+    type FloatTensorPrimitive<const D: usize>: Clone + Send + 'static + core::fmt::Debug;
     /// Float element type.
     type FloatElem: Element;

     /// Tensor primitive to be used for all int operations.
-    type IntTensorPrimitive<const D: usize>: Clone + Send + Sync + 'static + core::fmt::Debug;
+    type IntTensorPrimitive<const D: usize>: Clone + Send + 'static + core::fmt::Debug;
     /// Int element type.
     type IntElem: Element;

     /// Tensor primitive to be used for all bool operations.
-    type BoolTensorPrimitive<const D: usize>: Clone + Send + Sync + 'static + core::fmt::Debug;
+    type BoolTensorPrimitive<const D: usize>: Clone + Send + 'static + core::fmt::Debug;

     /// If autodiff is enabled.
     fn ad_enabled() -> bool {

@@ -109,7 +109,7 @@ pub trait AutodiffBackend: Backend {
     >;

     /// Gradients type.
-    type Gradients: Send + Sync;
+    type Gradients: Send;

     /// Backward pass.
     ///
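To illustrate the relaxed documentation above, a small standalone sketch (the `FakeTensor` type is a made-up stand-in, not a backend primitive): a handle that is `Clone + Send` can be cloned and moved to another thread, and wrapping the buffer in an `Arc` keeps the clone cheap, as the doc comment recommends.

use std::sync::Arc;
use std::thread;

// Illustration only: a stand-in "tensor" handle that is `Clone + Send`,
// mirroring the relaxed bounds on the backend primitive types above.
#[derive(Clone, Debug)]
struct FakeTensor {
    // Sharing the underlying buffer through `Arc` keeps clones cheap.
    buffer: Arc<Vec<f32>>,
}

fn main() {
    let tensor = FakeTensor {
        buffer: Arc::new(vec![1.0, 2.0, 3.0]),
    };
    let sent = tensor.clone();

    // `Send` is enough to move a handle into another thread; no shared
    // `&FakeTensor` ever crosses a thread boundary, so `Sync` is not needed.
    let handle = thread::spawn(move || sent.buffer.iter().sum::<f32>());
    assert_eq!(handle.join().unwrap(), 6.0);

    // The original handle is still usable on the spawning thread.
    assert_eq!(tensor.buffer.len(), 3);
}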
@@ -12,7 +12,7 @@ use crate::{backend::Backend, Tensor};
 /// Contains tensor of arbitrary dimension.
 #[derive(Debug)]
 pub struct TensorContainer<ID> {
-    tensors: HashMap<ID, Box<dyn Any + Send + Sync>>,
+    tensors: HashMap<ID, Box<dyn Any + Send>>,
 }

 impl<ID> Default for TensorContainer<ID>
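A rough sketch of the relaxed container shape, with illustrative names rather than burn's `TensorContainer`: values boxed as `dyn Any + Send` can still be registered and downcast, they just no longer promise shared access from multiple threads.

use std::any::Any;
use std::collections::HashMap;

// Sketch only: boxed values need `Any + Send`, not `Any + Send + Sync`.
struct Container<ID> {
    tensors: HashMap<ID, Box<dyn Any + Send>>,
}

impl<ID: std::hash::Hash + Eq> Container<ID> {
    fn new() -> Self {
        Self {
            tensors: HashMap::new(),
        }
    }

    fn register<T: Any + Send>(&mut self, id: ID, value: T) {
        self.tensors.insert(id, Box::new(value));
    }

    fn get<T: Any>(&self, id: &ID) -> Option<&T> {
        self.tensors.get(id).and_then(|boxed| boxed.downcast_ref::<T>())
    }
}

fn main() {
    let mut grads = Container::new();
    grads.register(0_u64, vec![0.5_f32, 1.5]);
    assert_eq!(grads.get::<Vec<f32>>(&0).unwrap()[1], 1.5);
}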
@@ -24,8 +24,8 @@ use burn_core::tensor::backend::AutodiffBackend;
 /// Struct to configure and create a [learner](Learner).
 pub struct LearnerBuilder<B, T, V, M, O, S>
 where
-    T: Send + Sync + 'static,
-    V: Send + Sync + 'static,
+    T: Send + 'static,
+    V: Send + 'static,
     B: AutodiffBackend,
     M: AutodiffModule<B>,
     O: Optimizer<M, B>,

@@ -58,8 +58,8 @@ where
 impl<B, T, V, M, O, S> LearnerBuilder<B, T, V, M, O, S>
 where
     B: AutodiffBackend,
-    T: Send + Sync + 'static,
-    V: Send + Sync + 'static,
+    T: Send + 'static,
+    V: Send + 'static,
     M: AutodiffModule<B> + core::fmt::Display + 'static,
     O: Optimizer<M, B>,
     S: LrScheduler<B>,

@@ -11,6 +11,7 @@ const MEAN: [f32; 3] = [0.4914, 0.48216, 0.44653];
 const STD: [f32; 3] = [0.24703, 0.24349, 0.26159];

 /// Normalizer for the CIFAR-10 dataset.
+#[derive(Clone)]
 pub struct Normalizer<B: Backend> {
     pub mean: Tensor<B, 4>,
     pub std: Tensor<B, 4>,

@@ -36,6 +37,7 @@ impl<B: Backend> Normalizer<B> {
     }
 }

+#[derive(Clone)]
 pub struct ClassificationBatcher<B: Backend> {
     normalizer: Normalizer<B>,
     device: B::Device,
@@ -87,13 +87,13 @@ impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
             // Register the gradient for each variable based on whether they are marked as
             // `tracked`.
             if let Some(node) = node_bias {
-                grads.register::<B, D>(node, grad_bias);
+                grads.register::<B, D>(node.id, grad_bias);
             }
             if let Some(node) = node_lhs {
-                grads.register::<B, D>(node, grad_lhs);
+                grads.register::<B, D>(node.id, grad_lhs);
             }
             if let Some(node) = node_rhs {
-                grads.register::<B, D>(node, grad_rhs);
+                grads.register::<B, D>(node.id, grad_rhs);
             }
         }
     }

@@ -102,10 +102,7 @@ impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
         //
         // Each node can be fetched with `ops.parents` in the same order as defined here.
         match FusedMatmulAddReluBackward
-            .prepare::<C>(
-                [lhs.node.clone(), rhs.node.clone(), bias.node.clone()],
-                [lhs.graph.clone(), rhs.graph.clone(), bias.graph.clone()],
-            )
+            .prepare::<C>([lhs.node.clone(), rhs.node.clone(), bias.node.clone()])
             // Marks the operation as compute bound, meaning it will save its
             // state instead of recomputing itself during checkpointing
             .compute_bound()
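A toy sketch of the registration pattern in the comment above, keyed by a plain node id to mirror the `node` to `node.id` change; the types are illustrative and are not burn's `Gradients` or `NodeID`.

use std::collections::HashMap;

// Stand-in node id type for illustration only.
type NodeId = u64;

fn register_if_tracked(grads: &mut HashMap<NodeId, Vec<f32>>, parent: Option<NodeId>, grad: Vec<f32>) {
    // Untracked parents (e.g. constants) are simply skipped.
    if let Some(id) = parent {
        grads.insert(id, grad);
    }
}

fn main() {
    let mut grads = HashMap::new();
    register_if_tracked(&mut grads, Some(1), vec![0.1, 0.2]); // tracked lhs
    register_if_tracked(&mut grads, Some(2), vec![0.3, 0.4]); // tracked rhs
    register_if_tracked(&mut grads, None, vec![0.5]); // untracked bias
    assert_eq!(grads.len(), 2);
}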
@@ -3,6 +3,7 @@ use burn::{
     prelude::*,
 };

+#[derive(Clone)]
 pub struct MnistBatcher<B: Backend> {
     device: B::Device,
 }

@@ -3,6 +3,7 @@ use burn::{
     prelude::*,
 };

+#[derive(Clone, Debug)]
 pub struct MnistBatcher<B: Backend> {
     device: B::Device,
 }

@@ -109,6 +109,7 @@ impl DiabetesDataset {
     }
 }

+#[derive(Clone, Debug)]
 pub struct DiabetesBatcher<B: Backend> {
     device: B::Device,
 }

@@ -15,7 +15,7 @@ use burn::{data::dataloader::batcher::Batcher, nn::attention::generate_padding_m
 use std::sync::Arc;

 /// Struct for batching text classification items
-#[derive(new)]
+#[derive(Clone, new)]
 pub struct TextClassificationBatcher<B: Backend> {
     tokenizer: Arc<dyn Tokenizer>, // Tokenizer for converting text to token IDs
     device: B::Device, // Device on which to perform computation (e.g., CPU or CUDA device)

@@ -2,7 +2,7 @@ use super::{dataset::TextGenerationItem, tokenizer::Tokenizer};
 use burn::{data::dataloader::batcher::Batcher, nn::attention::generate_padding_mask, prelude::*};
 use std::sync::Arc;

-#[derive(new)]
+#[derive(Clone, new)]
 pub struct TextGenerationBatcher {
     tokenizer: Arc<dyn Tokenizer>,
     max_seq_length: usize,
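The example batchers now derive `Clone`; presumably each dataloader worker gets its own copy rather than sharing one behind a `Sync` handle (an assumption drawn from the pattern of these hunks, not stated in the diff). A small sketch of why that stays cheap when the batcher holds an `Arc`, using a hypothetical `Tokenizer` trait rather than the one in the examples:

use std::sync::Arc;

// Hypothetical tokenizer trait for illustration; not the examples' actual trait.
trait Tokenizer: Send + Sync {
    fn encode(&self, text: &str) -> Vec<u32>;
}

struct WhitespaceTokenizer;

impl Tokenizer for WhitespaceTokenizer {
    fn encode(&self, text: &str) -> Vec<u32> {
        text.split_whitespace().map(|w| w.len() as u32).collect()
    }
}

// Cloning the batcher only bumps the `Arc` reference count; the tokenizer
// itself is never copied, so handing each worker its own batcher is cheap.
#[derive(Clone)]
struct Batcher {
    tokenizer: Arc<dyn Tokenizer>,
    max_seq_length: usize,
}

fn main() {
    let batcher = Batcher {
        tokenizer: Arc::new(WhitespaceTokenizer),
        max_seq_length: 128,
    };
    let per_worker = batcher.clone();
    assert_eq!(per_worker.tokenizer.encode("hello burn"), vec![5, 4]);
    assert_eq!(per_worker.max_seq_length, 128);
}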