mirror of https://github.com/tracel-ai/burn.git
doc: add a simple example
This commit is contained in:
parent
b306156cc2
commit
61b67b44ff
|
@ -72,7 +72,7 @@ where
|
|||
T: std::fmt::Debug + 'static,
|
||||
{
|
||||
fn backward_step(&self) {
|
||||
println!("backward node id={} order={}", self.id, self.order);
|
||||
// println!("backward node id={} order={}", self.id, self.order);
|
||||
self.ops.backward_step(&self.state)
|
||||
}
|
||||
fn backward_parents(&self) -> Vec<RecordedOpsParentRef> {
|
||||
|
|
|
@ -37,7 +37,7 @@ impl<Out> ForwardNode<Out> {
|
|||
|
||||
fn new(order: usize, state: ForwardNodeState<Out>, ops: ForwardRecordedOpsRef<Out>) -> Self {
|
||||
let id = nanoid::nanoid!();
|
||||
println!("Creating a new node with id={} and order={}", id, order);
|
||||
// println!("Creating a new node with id={} and order={}", id, order);
|
||||
Self {
|
||||
id,
|
||||
order,
|
||||
|
|
|
@ -110,8 +110,8 @@ macro_rules! random {
|
|||
let data = $crate::Data::sample_(shape, $distribution, $kind::default());
|
||||
|
||||
match $backend {
|
||||
Backend::Tch(device) => {
|
||||
$crate::tensor::backend::tch::TchTensor::from_data(data, device)
|
||||
$crate::backend::Backend::Tch(device) => {
|
||||
$crate::backend::tch::TchTensor::from_data(data, device)
|
||||
}
|
||||
}
|
||||
}};
|
||||
|
@ -152,7 +152,7 @@ macro_rules! random {
|
|||
$crate::random!(
|
||||
kind: $kind,
|
||||
shape: $shape,
|
||||
backend: $crate::tensor::backend::Backend::default()
|
||||
backend: $crate::backend::Backend::default()
|
||||
)
|
||||
}};
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
pub mod backend;
|
||||
|
||||
mod data;
|
||||
mod print;
|
||||
mod shape;
|
||||
mod tensor;
|
||||
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
use crate::Data;
|
||||
|
||||
impl<P: std::fmt::Debug, const D: usize> std::fmt::Display for Data<P, D> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str(format!("{:?}", &self.value).as_str())
|
||||
}
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
use burn::tensor::*;
|
||||
use burn_tensor::backend::autodiff::ADTensor;
|
||||
|
||||
fn main() {
|
||||
let x = random!([2, 3]);
|
||||
let y = random!([3, 2]);
|
||||
|
||||
println!("x: {}", x.to_data());
|
||||
println!("y: {}", y.to_data());
|
||||
|
||||
let x = ADTensor::from_tensor(x);
|
||||
let y = ADTensor::from_tensor(y);
|
||||
|
||||
let z = x.matmul(&y);
|
||||
|
||||
let grads = z.backward();
|
||||
|
||||
let x_grad = grads.wrt(&x).expect("x gradient defined");
|
||||
let y_grad = grads.wrt(&y).expect("y gradient defined");
|
||||
|
||||
println!("x_grad: {}", x_grad.to_data());
|
||||
println!("y_grad: {}", y_grad.to_data());
|
||||
}
|
Loading…
Reference in New Issue