mirror of https://github.com/tracel-ai/burn.git
Fix autodiff memory management graph cleaning (#1602)
This commit is contained in:
parent
0cbe9a927d
commit
07a61a1cec
|
@ -365,6 +365,7 @@ dependencies = [
|
||||||
"burn-tensor",
|
"burn-tensor",
|
||||||
"burn-tensor-testgen",
|
"burn-tensor-testgen",
|
||||||
"derive-new",
|
"derive-new",
|
||||||
|
"log",
|
||||||
"spin",
|
"spin",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,7 @@ burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.13.0", opt
|
||||||
|
|
||||||
derive-new = { workspace = true }
|
derive-new = { workspace = true }
|
||||||
spin = { workspace = true }
|
spin = { workspace = true }
|
||||||
|
log = { workspace = true }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
burn-tensor = { path = "../burn-tensor", version = "0.13.0", default-features = false, features = [
|
burn-tensor = { path = "../burn-tensor", version = "0.13.0", default-features = false, features = [
|
||||||
|
|
|
@ -61,6 +61,7 @@ impl GraphMemoryManagement {
|
||||||
|
|
||||||
for node_id in graph.into_iter() {
|
for node_id in graph.into_iter() {
|
||||||
func(&node_id);
|
func(&node_id);
|
||||||
|
self.graphs.remove(&GraphId::new(*node_id));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -258,6 +259,9 @@ mod tests {
|
||||||
assert!(node_ids.contains(&node_1));
|
assert!(node_ids.contains(&node_1));
|
||||||
assert!(node_ids.contains(&node_2));
|
assert!(node_ids.contains(&node_2));
|
||||||
|
|
||||||
|
assert_eq!(graph_mm.graphs.len(), 0);
|
||||||
|
assert_eq!(graph_mm.owned.len(), 0);
|
||||||
|
|
||||||
// Same but with free(node_2);
|
// Same but with free(node_2);
|
||||||
graph_mm.register(node_1.clone(), vec![]);
|
graph_mm.register(node_1.clone(), vec![]);
|
||||||
graph_mm.register(node_2.clone(), vec![*node_1]);
|
graph_mm.register(node_2.clone(), vec![*node_1]);
|
||||||
|
@ -267,5 +271,8 @@ mod tests {
|
||||||
|
|
||||||
assert!(node_ids.contains(&node_1));
|
assert!(node_ids.contains(&node_1));
|
||||||
assert!(node_ids.contains(&node_2));
|
assert!(node_ids.contains(&node_2));
|
||||||
|
|
||||||
|
assert_eq!(graph_mm.graphs.len(), 0);
|
||||||
|
assert_eq!(graph_mm.owned.len(), 0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ use crate::{
|
||||||
checkpoint::{base::Checkpointer, builder::CheckpointerBuilder},
|
checkpoint::{base::Checkpointer, builder::CheckpointerBuilder},
|
||||||
grads::Gradients,
|
grads::Gradients,
|
||||||
graph::{traversal::BreadthFirstSearch, StepBoxed},
|
graph::{traversal::BreadthFirstSearch, StepBoxed},
|
||||||
|
runtime::memory_management::GraphId,
|
||||||
tensor::NodeRefCount,
|
tensor::NodeRefCount,
|
||||||
NodeID,
|
NodeID,
|
||||||
};
|
};
|
||||||
|
@ -63,6 +64,10 @@ impl AutodiffServer {
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
BreadthFirstSearch.traverse(root, root_step, &mut self.steps, |id, step| {
|
BreadthFirstSearch.traverse(root, root_step, &mut self.steps, |id, step| {
|
||||||
|
// We consume that node for the tape, so we should remove it from the
|
||||||
|
// memory_management.
|
||||||
|
self.memory_management.free_graph(GraphId::new(id), |_| {});
|
||||||
|
|
||||||
let order = step.order();
|
let order = step.order();
|
||||||
if order == 0 {
|
if order == 0 {
|
||||||
return;
|
return;
|
||||||
|
|
Loading…
Reference in New Issue