diff --git a/Cargo.lock b/Cargo.lock index 9cc352006..06902b64f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -365,6 +365,7 @@ dependencies = [ "burn-tensor", "burn-tensor-testgen", "derive-new", + "log", "spin", ] diff --git a/crates/burn-autodiff/Cargo.toml b/crates/burn-autodiff/Cargo.toml index 095491164..5f46dbf64 100644 --- a/crates/burn-autodiff/Cargo.toml +++ b/crates/burn-autodiff/Cargo.toml @@ -22,6 +22,7 @@ burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.13.0", opt derive-new = { workspace = true } spin = { workspace = true } +log = { workspace = true } [dev-dependencies] burn-tensor = { path = "../burn-tensor", version = "0.13.0", default-features = false, features = [ diff --git a/crates/burn-autodiff/src/runtime/memory_management.rs b/crates/burn-autodiff/src/runtime/memory_management.rs index d0fb5c5c6..1485a4650 100644 --- a/crates/burn-autodiff/src/runtime/memory_management.rs +++ b/crates/burn-autodiff/src/runtime/memory_management.rs @@ -61,6 +61,7 @@ impl GraphMemoryManagement { for node_id in graph.into_iter() { func(&node_id); + self.graphs.remove(&GraphId::new(*node_id)); } } @@ -258,6 +259,9 @@ mod tests { assert!(node_ids.contains(&node_1)); assert!(node_ids.contains(&node_2)); + assert_eq!(graph_mm.graphs.len(), 0); + assert_eq!(graph_mm.owned.len(), 0); + // Same but with free(node_2); graph_mm.register(node_1.clone(), vec![]); graph_mm.register(node_2.clone(), vec![*node_1]); @@ -267,5 +271,8 @@ mod tests { assert!(node_ids.contains(&node_1)); assert!(node_ids.contains(&node_2)); + + assert_eq!(graph_mm.graphs.len(), 0); + assert_eq!(graph_mm.owned.len(), 0); } } diff --git a/crates/burn-autodiff/src/runtime/server.rs b/crates/burn-autodiff/src/runtime/server.rs index 17fc8c841..c4581ee26 100644 --- a/crates/burn-autodiff/src/runtime/server.rs +++ b/crates/burn-autodiff/src/runtime/server.rs @@ -3,6 +3,7 @@ use crate::{ checkpoint::{base::Checkpointer, builder::CheckpointerBuilder}, grads::Gradients, graph::{traversal::BreadthFirstSearch, StepBoxed}, + runtime::memory_management::GraphId, tensor::NodeRefCount, NodeID, }; @@ -63,6 +64,10 @@ impl AutodiffServer { .collect::>(); BreadthFirstSearch.traverse(root, root_step, &mut self.steps, |id, step| { + // We consume that node for the tape, so we should remove it from the + // memory_management. + self.memory_management.free_graph(GraphId::new(id), |_| {}); + let order = step.order(); if order == 0 { return;