fix: backward order

This commit is contained in:
nathaniel 2022-07-21 07:39:47 -04:00
parent 2f7f65cea5
commit 230cd01ea1
4 changed files with 44 additions and 19 deletions

View File

@ -1,19 +1,36 @@
use super::NodeStateRef;
use crate::ops::{RecordedOpsParent, RecordedOpsParentRef, RecordedOpsRef};
use std::{collections::HashSet, ops::Add, rc::Rc};
use std::{collections::HashMap, ops::Add, rc::Rc};
#[derive(Debug)]
pub struct Node<Out> {
pub id: String,
pub order: usize,
pub state: NodeStateRef<Out>,
pub ops: RecordedOpsRef<Out>,
}
impl<Out> Node<Out> {
pub fn new(state: NodeStateRef<Out>, ops: RecordedOpsRef<Out>) -> Self {
let id = nanoid::nanoid!();
println!("Creating node {}", id);
Self { id, state, ops }
pub fn from_root(state: NodeStateRef<Out>, ops: RecordedOpsRef<Out>) -> Self {
let order = 0;
Self { order, state, ops }
}
pub fn from_unary<T>(
node: &Node<T>,
state: NodeStateRef<Out>,
ops: RecordedOpsRef<Out>,
) -> Self {
let order = node.order + 1;
Self { order, state, ops }
}
pub fn from_binary<Lhs, Rhs>(
lhs: &Node<Lhs>,
rhs: &Node<Rhs>,
state: NodeStateRef<Out>,
ops: RecordedOpsRef<Out>,
) -> Self {
let order = usize::max(lhs.order, rhs.order) + 1;
Self { order, state, ops }
}
}
@ -24,47 +41,55 @@ where
{
pub fn backward(&self) {
let grad = self.state.borrow().value().ones();
self.state.borrow_mut().update_grad(grad);
self.ops.backward_step(&self.state);
let mut nodes = HashMap::new();
let mut parents = self.ops.backward_parents();
let mut visited = HashSet::new();
loop {
match parents.pop() {
Some(node) => {
let id = node.id();
if visited.contains(&id) {
if id == 0 {
continue;
}
visited.insert(id);
node.backward_step();
if nodes.contains_key(&id) {
continue;
}
for parent in node.backward_parents() {
if !visited.contains(&parent.id()) {
if !nodes.contains_key(&parent.id()) {
parents.push(parent);
}
}
nodes.insert(id, node);
}
None => break,
}
}
for i in (0..self.order + 1).rev() {
if let Some(node) = nodes.get(&i) {
node.backward_step();
}
}
}
}
impl<T: std::fmt::Debug> RecordedOpsParent for Node<T> {
fn backward_step(&self) {
println!("backward node {}", self.id);
println!("backward node {}", self.order);
self.ops.backward_step(&self.state)
}
fn backward_parents(&self) -> Vec<RecordedOpsParentRef> {
self.ops.backward_parents()
}
fn id(&self) -> String {
self.id.clone()
fn id(&self) -> usize {
self.order
}
}

View File

@ -20,7 +20,7 @@ pub trait RecordedOps<T>: std::fmt::Debug {
}
pub trait RecordedOpsParent: std::fmt::Debug {
fn id(&self) -> String;
fn id(&self) -> usize;
fn backward_step(&self);
fn backward_parents(&self) -> Vec<RecordedOpsParentRef>;
}

View File

@ -121,7 +121,7 @@ macro_rules! execute_ops {
let ops = BinaryRecordedOps::new($lhs, $rhs, ops);
let ops = std::rc::Rc::new(ops);
let node = $crate::node::Node::new(state, ops);
let node = $crate::node::Node::from_binary(&$lhs, &$rhs, state, ops);
std::rc::Rc::new(node)
};
callback()
@ -139,7 +139,7 @@ macro_rules! execute_ops {
let ops = UnaryRecordedOps::new($input, ops);
let ops = std::rc::Rc::new(ops);
let node = $crate::node::Node::new(state, ops);
let node = $crate::node::Node::from_unary(&$input, state, ops);
std::rc::Rc::new(node)
};
callback()

View File

@ -41,7 +41,7 @@ where
let ops = InitRecordedOps::new();
let ops = Rc::new(ops);
let node = Rc::new(Node::new(state, ops));
let node = Rc::new(Node::from_root(state, ops));
Self { node, shape, kind }
}