doc: add a simple example

nathaniel 2022-07-26 10:04:13 -04:00
parent b306156cc2
commit 61b67b44ff
6 changed files with 36 additions and 5 deletions

View File

@@ -72,7 +72,7 @@ where
     T: std::fmt::Debug + 'static,
 {
     fn backward_step(&self) {
-        println!("backward node id={} order={}", self.id, self.order);
+        // println!("backward node id={} order={}", self.id, self.order);
         self.ops.backward_step(&self.state)
     }
     fn backward_parents(&self) -> Vec<RecordedOpsParentRef> {

View File

@@ -37,7 +37,7 @@ impl<Out> ForwardNode<Out> {
     fn new(order: usize, state: ForwardNodeState<Out>, ops: ForwardRecordedOpsRef<Out>) -> Self {
         let id = nanoid::nanoid!();
-        println!("Creating a new node with id={} and order={}", id, order);
+        // println!("Creating a new node with id={} and order={}", id, order);
         Self {
             id,
             order,

View File

@@ -110,8 +110,8 @@ macro_rules! random {
         let data = $crate::Data::sample_(shape, $distribution, $kind::default());
         match $backend {
-            Backend::Tch(device) => {
-                $crate::tensor::backend::tch::TchTensor::from_data(data, device)
+            $crate::backend::Backend::Tch(device) => {
+                $crate::backend::tch::TchTensor::from_data(data, device)
             }
         }
     }};
@@ -152,7 +152,7 @@ macro_rules! random {
         $crate::random!(
             kind: $kind,
             shape: $shape,
-            backend: $crate::tensor::backend::Backend::default()
+            backend: $crate::backend::Backend::default()
         )
     }};
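Both hunks above are the same mechanical rename: the backend types moved from $crate::tensor::backend to $crate::backend, and the macro's hard-coded paths follow. Call sites are unchanged; a short sketch of the call site, grounded in the new example file below (assuming the burn crates as dependencies):

use burn::tensor::*;

fn main() {
    // Short form: shape only. Per the second hunk, this forwards to the
    // long form with `backend: $crate::backend::Backend::default()`.
    let x = random!([2, 3]);
    println!("x: {}", x.to_data());
}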

View File

@@ -1,6 +1,7 @@
 pub mod backend;
 mod data;
+mod print;
 mod shape;
 mod tensor;

View File

@@ -0,0 +1,7 @@
+use crate::Data;
+
+impl<P: std::fmt::Debug, const D: usize> std::fmt::Display for Data<P, D> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str(format!("{:?}", &self.value).as_str())
+    }
+}
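This new file is why the example below can print tensor data with "{}": Data gains a Display impl that simply reuses its Debug output. A standalone sketch of the same pattern, using hypothetical names (Wrapper, value) so it runs without burn:

struct Wrapper {
    value: Vec<f32>,
}

impl std::fmt::Display for Wrapper {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same trick as Data<P, D> above: delegate Display to Debug.
        f.write_str(format!("{:?}", &self.value).as_str())
    }
}

fn main() {
    let w = Wrapper { value: vec![1.0, 2.0, 3.0] };
    println!("{}", w); // prints: [1.0, 2.0, 3.0]
}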

View File

@@ -0,0 +1,23 @@
+use burn::tensor::*;
+use burn_tensor::backend::autodiff::ADTensor;
+
+fn main() {
+    let x = random!([2, 3]);
+    let y = random!([3, 2]);
+
+    println!("x: {}", x.to_data());
+    println!("y: {}", y.to_data());
+
+    let x = ADTensor::from_tensor(x);
+    let y = ADTensor::from_tensor(y);
+
+    let z = x.matmul(&y);
+
+    let grads = z.backward();
+
+    let x_grad = grads.wrt(&x).expect("x gradient defined");
+    let y_grad = grads.wrt(&y).expect("y gradient defined");
+
+    println!("x_grad: {}", x_grad.to_data());
+    println!("y_grad: {}", y_grad.to_data());
+}
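For readers checking the gradients by hand: for z = x.matmul(&y), the standard rule gives x_grad = dz * y^T and y_grad = x^T * dz, where dz is the seed gradient of z. A minimal sketch of that rule with concrete 2x3 / 3x2 inputs; the values and the all-ones seed are illustrative assumptions, not burn internals:

fn main() {
    let x = [[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]]; // 2x3
    let y = [[1.0f32, 0.0], [0.0, 1.0], [1.0, 1.0]]; // 3x2
    let dz = [[1.0f32, 1.0], [1.0, 1.0]]; // 2x2 seed gradient of z = x * y

    // x_grad = dz * y^T, shape 2x3: x_grad[i][k] = sum_j dz[i][j] * y[k][j]
    let mut x_grad = [[0.0f32; 3]; 2];
    for i in 0..2 {
        for k in 0..3 {
            for j in 0..2 {
                x_grad[i][k] += dz[i][j] * y[k][j];
            }
        }
    }

    // y_grad = x^T * dz, shape 3x2: y_grad[k][j] = sum_i x[i][k] * dz[i][j]
    let mut y_grad = [[0.0f32; 2]; 3];
    for k in 0..3 {
        for j in 0..2 {
            for i in 0..2 {
                y_grad[k][j] += x[i][k] * dz[i][j];
            }
        }
    }

    println!("x_grad: {:?}", x_grad); // same shape as x
    println!("y_grad: {:?}", y_grad); // same shape as y
}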