Fix autodiff memory management graph cleaning (#1602)

This commit is contained in:
Nathaniel Simard 2024-04-11 16:21:00 -04:00 committed by GitHub
parent 0cbe9a927d
commit 07a61a1cec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 14 additions and 0 deletions

1
Cargo.lock generated
View File

@ -365,6 +365,7 @@ dependencies = [
"burn-tensor",
"burn-tensor-testgen",
"derive-new",
"log",
"spin",
]

View File

@ -22,6 +22,7 @@ burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.13.0", opt
derive-new = { workspace = true }
spin = { workspace = true }
log = { workspace = true }
[dev-dependencies]
burn-tensor = { path = "../burn-tensor", version = "0.13.0", default-features = false, features = [

View File

@ -61,6 +61,7 @@ impl GraphMemoryManagement {
for node_id in graph.into_iter() {
func(&node_id);
self.graphs.remove(&GraphId::new(*node_id));
}
}
@ -258,6 +259,9 @@ mod tests {
assert!(node_ids.contains(&node_1));
assert!(node_ids.contains(&node_2));
assert_eq!(graph_mm.graphs.len(), 0);
assert_eq!(graph_mm.owned.len(), 0);
// Same but with free(node_2);
graph_mm.register(node_1.clone(), vec![]);
graph_mm.register(node_2.clone(), vec![*node_1]);
@ -267,5 +271,8 @@ mod tests {
assert!(node_ids.contains(&node_1));
assert!(node_ids.contains(&node_2));
assert_eq!(graph_mm.graphs.len(), 0);
assert_eq!(graph_mm.owned.len(), 0);
}
}

View File

@ -3,6 +3,7 @@ use crate::{
checkpoint::{base::Checkpointer, builder::CheckpointerBuilder},
grads::Gradients,
graph::{traversal::BreadthFirstSearch, StepBoxed},
runtime::memory_management::GraphId,
tensor::NodeRefCount,
NodeID,
};
@ -63,6 +64,10 @@ impl AutodiffServer {
.collect::<Vec<_>>();
BreadthFirstSearch.traverse(root, root_step, &mut self.steps, |id, step| {
// This node is consumed to build the tape, so the graph registered under its
// id must also be freed from memory management.
self.memory_management.free_graph(GraphId::new(id), |_| {});
let order = step.order();
if order == 0 {
return;