From cabbaab0c4a28f0a0e4c83a57e51e82d8c116aa5 Mon Sep 17 00:00:00 2001 From: Nathaniel Simard Date: Tue, 21 Nov 2023 16:27:28 -0500 Subject: [PATCH] Fix/constant tensors (#984) * Generalize autodiff tensor * Can have int const module * Update example * Support no-std with burn-import * Fix typos * Fix alloc problems * Revert burn-import changes * Fix examples * Support Int and Bool Params * Fix * Add comment --- burn-autodiff/src/backend.rs | 24 +++ burn-core/src/module/base.rs | 114 ++++++------- burn-core/src/module/param/constant.rs | 79 ++++++--- burn-core/src/module/param/primitive.rs | 48 +++++- burn-core/src/module/param/running.rs | 32 +++- burn-core/src/module/param/tensor.rs | 157 +++++++++++++++++- burn-core/src/module/param/visitor.rs | 10 +- burn-core/src/optim/grad_accum.rs | 2 +- burn-core/src/optim/simple/adaptor.rs | 2 +- burn-core/src/optim/visitor.rs | 4 +- burn-core/src/record/primitive.rs | 28 ++++ burn-core/tests/derive_module.rs | 7 +- burn-derive/src/module/base.rs | 7 + burn-derive/src/module/codegen.rs | 3 + burn-derive/src/module/codegen_struct.rs | 54 +++++- burn-tensor/src/tensor/api/autodiff.rs | 134 +++++++++++++++ burn-tensor/src/tensor/api/float.rs | 48 +----- burn-tensor/src/tensor/api/kind.rs | 2 +- burn-tensor/src/tensor/api/mod.rs | 2 + burn-tensor/src/tensor/backend/base.rs | 51 ++++++ .../examples/ag-news-train.rs | 14 +- .../examples/db-pedia-train.rs | 12 +- examples/text-classification/src/model.rs | 4 +- examples/text-classification/src/training.rs | 20 ++- examples/text-generation/src/model.rs | 2 +- 25 files changed, 679 insertions(+), 181 deletions(-) create mode 100644 burn-tensor/src/tensor/api/autodiff.rs diff --git a/burn-autodiff/src/backend.rs b/burn-autodiff/src/backend.rs index e0039ae29..794c6a796 100644 --- a/burn-autodiff/src/backend.rs +++ b/burn-autodiff/src/backend.rs @@ -79,4 +79,28 @@ impl AutodiffBackend for Autodiff { grads.remove(tensor); grads.register::(tensor.node.clone(), grad); } + + fn int_inner( + tensor: burn_tensor::ops::IntTensor, + ) -> burn_tensor::ops::IntTensor { + tensor + } + + fn bool_inner( + tensor: burn_tensor::ops::BoolTensor, + ) -> burn_tensor::ops::BoolTensor { + tensor + } + + fn int_from_inner( + tensor: burn_tensor::ops::IntTensor, + ) -> burn_tensor::ops::IntTensor { + tensor + } + + fn bool_from_inner( + tensor: burn_tensor::ops::BoolTensor, + ) -> burn_tensor::ops::BoolTensor { + tensor + } } diff --git a/burn-core/src/module/base.rs b/burn-core/src/module/base.rs index a54f00624..c01b9b194 100644 --- a/burn-core/src/module/base.rs +++ b/burn-core/src/module/base.rs @@ -1,12 +1,15 @@ -use alloc::vec::Vec; - use super::ParamId; use crate::{ record::Record, tensor::backend::{AutodiffBackend, Backend}, }; +use alloc::vec::Vec; pub use burn_derive::Module; -use burn_tensor::Tensor; +use burn_tensor::{Bool, Int, Tensor}; + +/// Type alias to `Vec` which supports `no_std` environements, but automatically using +/// the `alloc` crate. +pub type Devices = Vec<::Device>; // At the moment, our plan is to continue experimenting with the macro internally and monitor its development. // We may consider making it public in the future. @@ -14,7 +17,11 @@ macro_rules! module { (map=$module:ident, ops=$item:expr) => {{ struct Mapper; impl ModuleMapper for Mapper { - fn map(&mut self, _id: &ParamId, tensor: Tensor) -> Tensor { + fn map_float( + &mut self, + _id: &ParamId, + tensor: Tensor, + ) -> Tensor { let func = $item; func(tensor) } @@ -22,30 +29,13 @@ macro_rules! 
module { let mut mapper = Mapper; $module.map(&mut mapper) }}; - (map=$module:ident, ops=$item:expr, capture={$capture:ident: $ty:ty}) => {{ - struct Mapper<'a, B: Backend> { - capture: &'a $ty, - backend: core::marker::PhantomData, - } - impl<'a, B: Backend> ModuleMapper for Mapper<'a, B> { - fn map(&mut self, _id: &ParamId, tensor: Tensor) -> Tensor { - let func = $item; - func(tensor, self.capture) - } - } - let mut mapper = Mapper { - capture: $capture, - backend: core::marker::PhantomData, - }; - $module.map(&mut mapper) - }}; - (visit=$module:ident, ops=$item:expr, state=$state_ty:ty, init=$init:expr) => {{ + (visit_float=$module:ident, ops=$item:expr, state=$state_ty:ty, init=$init:expr) => {{ struct Visitor<'a, B: Backend> { state: &'a mut $state_ty, backend: core::marker::PhantomData, } impl<'a, B: Backend> ModuleVisitor for Visitor<'a, B> { - fn visit(&mut self, _id: &ParamId, tensor: &Tensor) { + fn visit_float(&mut self, _id: &ParamId, tensor: &Tensor) { let func = $item; func(tensor, &mut self.state) } @@ -94,20 +84,9 @@ pub trait Module: Clone + Send + Sync + core::fmt::Debug { /// Type to save and load the module. type Record: Record; - /// Get the device list of the module and all of its sub-modules. - fn devices(&self) -> Vec { - module!( - visit = self, - ops = |tensor: &Tensor, state: &mut Vec| { - let device = tensor.device(); - if !state.contains(&device) { - state.push(device); - } - }, - state = Vec, - init = Vec::new - ) - } + /// Collects devices in the given vector and returns it with the devices found in the module + /// structure without duplicates. + fn devices(&self, devices: Devices) -> Devices; /// Fork the module and all of its sub-modules to the given device. /// @@ -115,22 +94,7 @@ pub trait Module: Clone + Send + Sync + core::fmt::Debug { /// /// This is similar to [to_device](Module::to_device), but it ensures the module will /// have its own autodiff graph. - fn fork(self, device: &B::Device) -> Self { - module!( - map = self, - ops = |tensor: Tensor, device: &B::Device| { - let is_require_grad = tensor.is_require_grad(); - let mut tensor = tensor.to_device(device).detach(); - - if is_require_grad { - tensor = tensor.require_grad(); - } - - tensor - }, - capture = { device: B::Device } - ) - } + fn fork(self, device: &B::Device) -> Self; /// Move the module and all of its sub-modules to the given device. /// @@ -139,13 +103,7 @@ pub trait Module: Clone + Send + Sync + core::fmt::Debug { /// The device operations will be registered in the autodiff graph. Therefore, be sure to call /// backward only one time even if you have the same module on multiple devices. If you want to /// call backward multiple times, look into using [fork](Module::fork) instead. - fn to_device(self, device: &B::Device) -> Self { - module!( - map = self, - ops = |tensor: Tensor, device: &B::Device| tensor.to_device(device), - capture = { device: B::Device } - ) - } + fn to_device(self, device: &B::Device) -> Self; /// Each tensor in the module tree will not require grad. /// @@ -164,7 +122,7 @@ pub trait Module: Clone + Send + Sync + core::fmt::Debug { /// Get the number of parameters the module has, including all of its sub-modules. fn num_params(&self) -> usize { module!( - visit = self, + visit_float = self, ops = |tensor: &Tensor, state: &mut usize| { *state += tensor.shape().num_elements(); }, @@ -172,10 +130,10 @@ pub trait Module: Clone + Send + Sync + core::fmt::Debug { init = || 0 ) } - /// Visit each tensor in the module with a [visitor](ModuleVisitor). 
+ /// Visit each tensor parameter in the module with a [visitor](ModuleVisitor). fn visit>(&self, visitor: &mut V); - /// Map each tensor in the module with a [mapper](ModuleMapper). + /// Map each tensor parameter in the module with a [mapper](ModuleMapper). fn map>(self, mapper: &mut M) -> Self; /// Load the module state from a record. @@ -233,14 +191,36 @@ pub trait Module: Clone + Send + Sync + core::fmt::Debug { /// Module visitor trait. pub trait ModuleVisitor { - /// Visit a tensor in the module. - fn visit(&mut self, id: &ParamId, tensor: &Tensor); + /// Visit a float tensor in the module. + fn visit_float(&mut self, _id: &ParamId, _tensor: &Tensor) {} + /// Visit an int tensor in the module. + fn visit_int(&mut self, _id: &ParamId, _tensor: &Tensor) {} + /// Visit a bool tensor in the module. + fn visit_bool(&mut self, _id: &ParamId, _tensor: &Tensor) {} } /// Module mapper trait. pub trait ModuleMapper { - /// Map a tensor in the module. - fn map(&mut self, id: &ParamId, tensor: Tensor) -> Tensor; + /// Map a float tensor in the module. + fn map_float(&mut self, _id: &ParamId, tensor: Tensor) -> Tensor { + tensor + } + /// Map an int tensor in the module. + fn map_int( + &mut self, + _id: &ParamId, + tensor: Tensor, + ) -> Tensor { + tensor + } + /// Map a bool tensor in the module. + fn map_bool( + &mut self, + _id: &ParamId, + tensor: Tensor, + ) -> Tensor { + tensor + } } /// Module with auto-differentiation backend. diff --git a/burn-core/src/module/param/constant.rs b/burn-core/src/module/param/constant.rs index 9c33d1409..9010fe3b6 100644 --- a/burn-core/src/module/param/constant.rs +++ b/burn-core/src/module/param/constant.rs @@ -1,17 +1,14 @@ -use core::marker::PhantomData; - use crate::{ self as burn, - module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor}, + module::{AutodiffModule, Devices, Module, ModuleMapper, ModuleVisitor}, record::Record, }; use burn::record::PrecisionSettings; use burn_tensor::{ backend::{AutodiffBackend, Backend}, - Tensor, + BasicAutodiffOps, BasicOps, Tensor, }; - -use super::ParamId; +use core::marker::PhantomData; /// Record used for constant type implementing the [module](crate::module::Module) trait. #[derive(Debug, Clone, Copy, new)] @@ -69,6 +66,18 @@ macro_rules! 
constant { fn into_record(self) -> Self::Record { burn::module::ConstantRecord::new() } + + fn to_device(self, _: &B::Device) -> Self { + self + } + + fn fork(self, _: &B::Device) -> Self { + self + } + + fn devices(&self, devices: burn::module::Devices) -> burn::module::Devices { + devices + } }; (ad_module, $type:ty) => { @@ -113,27 +122,13 @@ constant!(i32); constant!(i16); constant!(i8); -impl Module for Tensor { +impl> Module for Tensor { type Record = ConstantRecord; - fn visit>(&self, visitor: &mut V) { - // Important: - // We need to implement visit method for Tensor Module because - // to_device will be called during the visit method of the ModuleVisitor + fn visit>(&self, _visitor: &mut V) {} - // We are using a dummy param id because the visit method requires a param id - let dummy_param_id = ParamId::new(); - visitor.visit(&dummy_param_id, self) - } - - fn map>(self, mapper: &mut M) -> Self { - // Important: - // We need to implement visit method for Tensor Module because - // to_device will be called during the visit method of the ModuleVisitor - - // We are using a dummy param id because the visit method requires a param id - let dummy_param_id = ParamId::new(); - mapper.map(&dummy_param_id, self) + fn map>(self, _mapper: &mut M) -> Self { + self } fn into_record(self) -> Self::Record { @@ -143,10 +138,30 @@ impl Module for Tensor { fn load_record(self, _record: Self::Record) -> Self { self } + + fn to_device(self, device: &B::Device) -> Self { + self.to_device(device) + } + + fn fork(self, device: &B::Device) -> Self { + self.to_device(device) + } + + fn devices(&self, mut devices: Devices) -> Devices { + let device = self.device(); + + if !devices.contains(&device) { + devices.push(device) + } + + devices + } } -impl AutodiffModule for Tensor { - type InnerModule = Tensor; +impl> AutodiffModule + for Tensor +{ + type InnerModule = Tensor; fn valid(&self) -> Self::InnerModule { self.clone().inner() @@ -171,6 +186,18 @@ impl Module for PhantomData { fn into_record(self) -> Self::Record { ConstantRecord::new() } + + fn to_device(self, _: &::Device) -> Self { + self + } + + fn fork(self, _: &::Device) -> Self { + self + } + + fn devices(&self, devices: Devices) -> Devices { + devices + } } impl AutodiffModule for PhantomData { diff --git a/burn-core/src/module/param/primitive.rs b/burn-core/src/module/param/primitive.rs index dfa23e2d8..e23beac9d 100644 --- a/burn-core/src/module/param/primitive.rs +++ b/burn-core/src/module/param/primitive.rs @@ -28,6 +28,22 @@ where fn into_record(self) -> Self::Record { self.map(Module::into_record) } + + fn to_device(self, device: &::Device) -> Self { + self.map(|module| module.to_device(device)) + } + + fn fork(self, device: &::Device) -> Self { + self.map(|module| module.fork(device)) + } + + fn devices(&self, mut devices: Vec) -> Vec { + if let Some(module) = self.as_ref() { + devices = module.devices(devices); + } + + devices + } } impl AutodiffModule for Option @@ -78,6 +94,24 @@ where .map(|(module, record)| module.load_record(record)) .collect() } + + fn to_device(self, device: &::Device) -> Self { + self.into_iter() + .map(|module| module.to_device(device)) + .collect() + } + + fn fork(self, device: &::Device) -> Self { + self.into_iter().map(|module| module.fork(device)).collect() + } + + fn devices(&self, mut devices: Vec) -> Vec { + for module in self.iter() { + devices = module.devices(devices); + } + + devices + } } impl AutodiffModule for Vec @@ -100,11 +134,11 @@ where { type Record = [T::Record; N]; - fn devices(&self) -> 
Vec<::Device> { - let mut devices = Vec::new(); + fn devices(&self, mut devices: Vec) -> Vec { for module in self.iter() { - devices.append(&mut module.devices()); + devices = module.devices(devices); } + devices } @@ -139,6 +173,14 @@ where fn into_record(self) -> Self::Record { self.map(Module::into_record) } + + fn to_device(self, device: &::Device) -> Self { + self.map(|module| module.to_device(device)) + } + + fn fork(self, device: &::Device) -> Self { + self.map(|module| module.fork(device)) + } } impl AutodiffModule for [T; N] diff --git a/burn-core/src/module/param/running.rs b/burn-core/src/module/param/running.rs index ab17f181f..1e72284ff 100644 --- a/burn-core/src/module/param/running.rs +++ b/burn-core/src/module/param/running.rs @@ -1,7 +1,7 @@ -use alloc::sync::Arc; - use super::ParamId; use crate::module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor, Param}; +use alloc::sync::Arc; +use alloc::vec::Vec; use burn_tensor::{ backend::{AutodiffBackend, Backend}, Tensor, @@ -51,12 +51,12 @@ impl Module for RunningState> { fn visit>(&self, visitor: &mut V) { let tensor = self.value.read().unwrap(); - visitor.visit(&self.id, &tensor) + visitor.visit_float(&self.id, &tensor) } fn map>(self, mapper: &mut M) -> Self { let mut tensor = self.value.write().unwrap(); - let tensor_out = mapper.map(&self.id, tensor.clone()); + let tensor_out = mapper.map_float(&self.id, tensor.clone()); *tensor = tensor_out; core::mem::drop(tensor); @@ -80,6 +80,30 @@ impl Module for RunningState> { self } + + fn to_device(self, device: &::Device) -> Self { + let mut tensor = self.value.write().unwrap(); + let tensor_out = tensor.clone().to_device(device); + + *tensor = tensor_out; + core::mem::drop(tensor); + + self + } + + fn fork(self, device: &::Device) -> Self { + self.to_device(device) // Same thing here since no grad. + } + + fn devices(&self, mut devices: Vec<::Device>) -> Vec<::Device> { + let device = self.value.read().unwrap().device(); + + if !devices.contains(&device) { + devices.push(device) + } + + devices + } } impl RunningState> { diff --git a/burn-core/src/module/param/tensor.rs b/burn-core/src/module/param/tensor.rs index 3bbe519dc..e52a21542 100644 --- a/burn-core/src/module/param/tensor.rs +++ b/burn-core/src/module/param/tensor.rs @@ -4,22 +4,38 @@ use crate::tensor::{ backend::{AutodiffBackend, Backend}, Tensor, }; +use alloc::vec::Vec; +use burn_tensor::{Bool, Int}; impl From> for Param> { fn from(value: Tensor) -> Self { + // When creating a parameter from a float tensor, we automatically mark it as requiring + // gradients, so that it can be updated by an optimizer. 
Param::new(ParamId::new(), value.require_grad()) } } +impl From> for Param> { + fn from(value: Tensor) -> Self { + Param::new(ParamId::new(), value) + } +} + +impl From> for Param> { + fn from(value: Tensor) -> Self { + Param::new(ParamId::new(), value) + } +} + impl Module for Param> { type Record = Param>; fn visit>(&self, visitor: &mut V) { - visitor.visit(&self.id, &self.value) + visitor.visit_float(&self.id, &self.value) } fn map>(self, mapper: &mut M) -> Self { - let value = mapper.map(&self.id, self.value); + let value = mapper.map_float(&self.id, self.value); Self::new(self.id, value) } @@ -41,6 +57,127 @@ impl Module for Param> { Self::new(record.id, tensor) } + + fn to_device(self, device: &::Device) -> Self { + self.map(|tensor| tensor.to_device(device)) + } + + fn fork(self, device: &::Device) -> Self { + self.map(|tensor| { + let is_require_grad = tensor.is_require_grad(); + let mut tensor = tensor.to_device(device).detach(); + + if is_require_grad { + tensor = tensor.require_grad(); + } + + tensor + }) + } + + fn devices(&self, mut devices: Vec<::Device>) -> Vec<::Device> { + let device = self.device(); + + if !devices.contains(&device) { + devices.push(device) + } + + devices + } +} + +impl Module for Param> { + type Record = Param>; + + fn visit>(&self, visitor: &mut V) { + visitor.visit_int(&self.id, &self.value) + } + + fn map>(self, mapper: &mut M) -> Self { + let value = mapper.map_int(&self.id, self.value); + Self::new(self.id, value) + } + + fn into_record(self) -> Self::Record { + self + } + + fn load_record(self, record: Self::Record) -> Self { + let mut tensor = record.value; + let device = self.device(); + + // Make sure we load the record into the same module device. + if tensor.device() != device { + tensor = tensor.to_device(&device); + } + + Self::new(record.id, tensor) + } + + fn to_device(self, device: &::Device) -> Self { + self.map(|tensor| tensor.to_device(device)) + } + + fn fork(self, device: &::Device) -> Self { + self.to_device(device) // Don't support autodiff. + } + + fn devices(&self, mut devices: Vec<::Device>) -> Vec<::Device> { + let device = self.device(); + + if !devices.contains(&device) { + devices.push(device) + } + + devices + } +} + +impl Module for Param> { + type Record = Param>; + + fn visit>(&self, visitor: &mut V) { + visitor.visit_bool(&self.id, &self.value) + } + + fn map>(self, mapper: &mut M) -> Self { + let value = mapper.map_bool(&self.id, self.value); + Self::new(self.id, value) + } + + fn into_record(self) -> Self::Record { + self + } + + fn load_record(self, record: Self::Record) -> Self { + let mut tensor = record.value; + let device = self.device(); + + // Make sure we load the record into the same module device. + if tensor.device() != device { + tensor = tensor.to_device(&device); + } + + Self::new(record.id, tensor) + } + + fn to_device(self, device: &::Device) -> Self { + self.map(|tensor| tensor.to_device(device)) + } + + fn fork(self, device: &::Device) -> Self { + self.to_device(device) // Don't support autodiff. 
+ } + + fn devices(&self, mut devices: Vec<::Device>) -> Vec<::Device> { + let device = self.device(); + + if !devices.contains(&device) { + devices.push(device) + } + + devices + } } impl AutodiffModule for Param> { @@ -54,6 +191,22 @@ impl AutodiffModule for Param AutodiffModule for Param> { + type InnerModule = Param>; + + fn valid(&self) -> Self::InnerModule { + Param::new(self.id.clone(), self.value.clone().inner()) + } +} + +impl AutodiffModule for Param> { + type InnerModule = Param>; + + fn valid(&self) -> Self::InnerModule { + Param::new(self.id.clone(), self.value.clone().inner()) + } +} + #[cfg(all(test, feature = "std"))] mod tests { use super::*; diff --git a/burn-core/src/module/param/visitor.rs b/burn-core/src/module/param/visitor.rs index 9e27e3b6d..c83c410b1 100644 --- a/burn-core/src/module/param/visitor.rs +++ b/burn-core/src/module/param/visitor.rs @@ -1,7 +1,7 @@ use super::ParamId; use crate::module::{Module, ModuleVisitor}; use alloc::vec::Vec; -use burn_tensor::{backend::Backend, Tensor}; +use burn_tensor::{backend::Backend, Bool, Int, Tensor}; use core::marker::PhantomData; struct ParamIdCollector<'a, M> { @@ -14,7 +14,13 @@ where B: Backend, M: Module, { - fn visit(&mut self, id: &ParamId, _tensor: &Tensor) { + fn visit_float(&mut self, id: &ParamId, _tensor: &Tensor) { + self.param_ids.push(id.clone()); + } + fn visit_int(&mut self, id: &ParamId, _tensor: &Tensor) { + self.param_ids.push(id.clone()); + } + fn visit_bool(&mut self, id: &ParamId, _tensor: &Tensor) { self.param_ids.push(id.clone()); } } diff --git a/burn-core/src/optim/grad_accum.rs b/burn-core/src/optim/grad_accum.rs index e0e455bdf..e6d059e2e 100644 --- a/burn-core/src/optim/grad_accum.rs +++ b/burn-core/src/optim/grad_accum.rs @@ -57,7 +57,7 @@ struct ModuleGradsAccumulator<'a, M> { impl<'a, B: AutodiffBackend, M: AutodiffModule> ModuleVisitor for ModuleGradsAccumulator<'a, M> { - fn visit(&mut self, id: &ParamId, _tensor: &Tensor) { + fn visit_float(&mut self, id: &ParamId, _tensor: &Tensor) { let grad_updated = match self.grads_new.remove::(id) { Some(new) => match self.grads.remove::(id) { Some(grad) => grad.add(new), diff --git a/burn-core/src/optim/simple/adaptor.rs b/burn-core/src/optim/simple/adaptor.rs index 0b44c8418..37648ce56 100644 --- a/burn-core/src/optim/simple/adaptor.rs +++ b/burn-core/src/optim/simple/adaptor.rs @@ -115,7 +115,7 @@ where B: AutodiffBackend, O: SimpleOptimizer, { - fn map(&mut self, id: &ParamId, tensor: Tensor) -> Tensor { + fn map_float(&mut self, id: &ParamId, tensor: Tensor) -> Tensor { let grad = self.grads.remove(id); if let Some(grad) = grad { diff --git a/burn-core/src/optim/visitor.rs b/burn-core/src/optim/visitor.rs index 1631fcf74..3d9f3af52 100644 --- a/burn-core/src/optim/visitor.rs +++ b/burn-core/src/optim/visitor.rs @@ -22,7 +22,7 @@ where B: AutodiffBackend, M: AutodiffModule, { - fn visit(&mut self, id: &ParamId, tensor: &Tensor) { + fn visit_float(&mut self, id: &ParamId, tensor: &Tensor) { if let Some(grad) = tensor.grad_remove(&mut self.grads) { self.grads_params .register::(id.clone(), grad); @@ -35,7 +35,7 @@ where B: AutodiffBackend, M: AutodiffModule, { - fn visit(&mut self, id: &ParamId, _tensor: &Tensor) { + fn visit_float(&mut self, id: &ParamId, _tensor: &Tensor) { if let Some(grad) = self.grads.remove::(id) { self.grads .register::(id.clone(), grad.to_device(self.device)); diff --git a/burn-core/src/record/primitive.rs b/burn-core/src/record/primitive.rs index 507635b20..1192acc13 100644 --- a/burn-core/src/record/primitive.rs +++ 
b/burn-core/src/record/primitive.rs @@ -2,11 +2,15 @@ use alloc::string::String; use alloc::string::ToString; use alloc::vec::Vec; use burn_tensor::backend::Backend; +use burn_tensor::Bool; +use burn_tensor::Int; use burn_tensor::Tensor; use serde::Deserialize; use serde::Serialize; +use super::tensor::BoolTensorSerde; use super::tensor::FloatTensorSerde; +use super::tensor::IntTensorSerde; use super::{PrecisionSettings, Record}; use crate::module::{Param, ParamId}; use burn_tensor::{DataSerialize, Element}; @@ -115,6 +119,30 @@ impl Record for Param> { } } +impl Record for Param> { + type Item = ParamSerde>; + + fn into_item(self) -> Self::Item { + ParamSerde::new(self.id.into_string(), self.value.into_item()) + } + + fn from_item(item: Self::Item) -> Self { + Param::new(ParamId::from(item.id), Tensor::from_item(item.param)) + } +} + +impl Record for Param> { + type Item = ParamSerde; + + fn into_item(self) -> Self::Item { + ParamSerde::new(self.id.into_string(), self.value.into_item::()) + } + + fn from_item(item: Self::Item) -> Self { + Param::new(ParamId::from(item.id), Tensor::from_item::(item.param)) + } +} + // Type that can be serialized as is without any conversion. macro_rules! primitive { ($type:ty) => { diff --git a/burn-core/tests/derive_module.rs b/burn-core/tests/derive_module.rs index 87beafc42..6303077af 100644 --- a/burn-core/tests/derive_module.rs +++ b/burn-core/tests/derive_module.rs @@ -1,6 +1,6 @@ use burn::module::{Module, Param}; use burn::tensor::backend::Backend; -use burn::tensor::{Distribution, Shape, Tensor}; +use burn::tensor::{Distribution, Int, Shape, Tensor}; use burn_core as burn; pub type TestBackend = burn_ndarray::NdArray; @@ -12,6 +12,11 @@ pub struct ModuleBasic { weight_basic: Param>, } +#[derive(Module, Debug)] +struct ModuleTensorConstInt { + weight_basic: Tensor, +} + impl ModuleBasic { fn new() -> Self { let weight_basic = Tensor::random(Shape::new([20, 20]), Distribution::Default); diff --git a/burn-derive/src/module/base.rs b/burn-derive/src/module/base.rs index 1cf41bb5e..0ac6d7ceb 100644 --- a/burn-derive/src/module/base.rs +++ b/burn-derive/src/module/base.rs @@ -29,6 +29,9 @@ pub(crate) fn derive_impl(ast: &syn::DeriveInput) -> TokenStream { let num_params_fn = generator.gen_num_params(); let visit = generator.gen_visit(); let map_mut = generator.gen_map(); + let devices = generator.gen_devices(); + let to_device = generator.gen_to_device(); + let fork = generator.gen_fork(); let valid_fn = generator.gen_valid(); let into_record_fn = generator.gen_into_record(); let load_record_fn = generator.gen_load_record(); @@ -50,6 +53,10 @@ pub(crate) fn derive_impl(ast: &syn::DeriveInput) -> TokenStream { #visit #map_mut + + #devices + #to_device + #fork } impl #generics burn::module::AutodiffModule for #name #generics_ty diff --git a/burn-derive/src/module/codegen.rs b/burn-derive/src/module/codegen.rs index 852138018..bf1bf6bc0 100644 --- a/burn-derive/src/module/codegen.rs +++ b/burn-derive/src/module/codegen.rs @@ -4,6 +4,9 @@ use proc_macro2::TokenStream; pub(crate) trait ModuleCodegen { fn gen_num_params(&self) -> TokenStream; fn gen_visit(&self) -> TokenStream; + fn gen_devices(&self) -> TokenStream; + fn gen_to_device(&self) -> TokenStream; + fn gen_fork(&self) -> TokenStream; fn gen_map(&self) -> TokenStream; fn gen_valid(&self) -> TokenStream; fn gen_into_record(&self) -> TokenStream; diff --git a/burn-derive/src/module/codegen_struct.rs b/burn-derive/src/module/codegen_struct.rs index a6988b2bd..806aebca0 100644 --- 
a/burn-derive/src/module/codegen_struct.rs +++ b/burn-derive/src/module/codegen_struct.rs @@ -39,10 +39,62 @@ impl ModuleCodegen for StructModuleCodegen { } } + fn gen_devices(&self) -> TokenStream { + let body = self.gen_fields_fn(|name| { + quote! { + let devices = burn::module::Module::::devices(&self.#name, devices); + } + }); + + quote! { + fn devices(&self, devices: burn::module::Devices) -> burn::module::Devices { + #body + + devices + } + } + } + + fn gen_to_device(&self) -> TokenStream { + let (names, body) = self.gen_fields_fn_names(|name| { + quote! { + let #name = burn::module::Module::::to_device(self.#name, device); + } + }); + + quote! { + fn to_device(self, device: &B::Device) -> Self { + #body + + Self { + #(#names),* + } + } + } + } + + fn gen_fork(&self) -> TokenStream { + let (names, body) = self.gen_fields_fn_names(|name| { + quote! { + let #name = burn::module::Module::::fork(self.#name, device); + } + }); + + quote! { + fn fork(self, device: &B::Device) -> Self { + #body + + Self { + #(#names),* + } + } + } + } + fn gen_map(&self) -> TokenStream { let (names, body) = self.gen_fields_fn_names(|name| { quote! { - let #name = burn::module::Module::map(self.#name, mapper); + let #name = burn::module::Module::::map(self.#name, mapper); } }); diff --git a/burn-tensor/src/tensor/api/autodiff.rs b/burn-tensor/src/tensor/api/autodiff.rs new file mode 100644 index 000000000..b7bb1ff89 --- /dev/null +++ b/burn-tensor/src/tensor/api/autodiff.rs @@ -0,0 +1,134 @@ +use crate::{backend::AutodiffBackend, BasicOps, Bool, Float, Int, Tensor, TensorKind}; + +impl Tensor { + /// Backward pass of the tensor. + pub fn backward(&self) -> B::Gradients { + B::backward::(self.primitive.clone()) + } + + /// Get the gradients of a tensor if it exist. + /// + /// Returns a new reference to the same tensor. Therefore the same grad tensor can + /// be accessed multiple times. If you only need to get the gradients one time, + /// consider using [grad_remove](Tensor::grad_remove) for better performance. + pub fn grad(&self, grads: &B::Gradients) -> Option> { + B::grad(&self.primitive, grads).map(Tensor::new) + } + + /// Remove the grad tensor from the [grads](AutodiffBackend::Gradients) struct returning the result. + pub fn grad_remove(&self, grads: &mut B::Gradients) -> Option> { + B::grad_remove(&self.primitive, grads).map(Tensor::new) + } + + /// Replace the grad tensor from the [grads](AutodiffBackend::Gradients) struct with the provided + /// gradient. + pub fn grad_replace(&self, grads: &mut B::Gradients, grad: Tensor) { + B::grad_replace(&self.primitive, grads, grad.primitive); + } +} + +impl> Tensor { + /// Returns the inner tensor without the autodiff information. + pub fn inner(self) -> Tensor { + Tensor::new(K::inner(self.primitive)) + } + + /// Convert a tensor to the autodiff backend. + /// + /// # Arguments + /// + /// * `inner` - The tensor to convert. + /// + /// # Returns + /// + /// The tensor converted to the autodiff backend. 
+ pub fn from_inner(inner: Tensor) -> Self { + Self::new(K::from_inner(inner.primitive)) + } +} + +impl BasicAutodiffOps for Float { + type InnerKind = Float; + + fn inner( + tensor: >::Primitive, + ) -> ::InnerBackend>>::Primitive { + B::inner(tensor) + } + + fn from_inner( + inner: ::InnerBackend>>::Primitive, + ) -> >::Primitive { + B::from_inner(inner) + } +} + +impl BasicAutodiffOps for Int { + type InnerKind = Int; + + fn inner( + tensor: >::Primitive, + ) -> ::InnerBackend>>::Primitive { + B::int_inner(tensor) + } + + fn from_inner( + inner: ::InnerBackend>>::Primitive, + ) -> >::Primitive { + B::int_from_inner(inner) + } +} + +impl BasicAutodiffOps for Bool { + type InnerKind = Bool; + + fn inner( + tensor: >::Primitive, + ) -> ::InnerBackend>>::Primitive { + B::bool_inner(tensor) + } + + fn from_inner( + inner: ::InnerBackend>>::Primitive, + ) -> >::Primitive { + B::bool_from_inner(inner) + } +} + +/// Trait that list all operations that can be applied on all tensors on an autodiff backend. +/// +/// # Warnings +/// +/// This is an internal trait, use the public API provided by [tensor struct](Tensor). +pub trait BasicAutodiffOps: BasicOps + BasicOps { + /// Inner primitive tensor. + type InnerKind: BasicOps; + + /// Returns the inner tensor without the autodiff information. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// Users should prefer the [Tensor::inner](Tensor::inner) function, + /// which is more high-level and designed for public use. + fn inner( + tensor: >::Primitive, + ) -> >::Primitive; + + /// Convert a tensor to the autodiff backend. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// Users should prefer the [Tensor::from_inner](Tensor::from_inner) function, + /// which is more high-level and designed for public use. + fn from_inner( + inner: >::Primitive, + ) -> >::Primitive; +} diff --git a/burn-tensor/src/tensor/api/float.rs b/burn-tensor/src/tensor/api/float.rs index d2cededba..846ad3787 100644 --- a/burn-tensor/src/tensor/api/float.rs +++ b/burn-tensor/src/tensor/api/float.rs @@ -1,7 +1,6 @@ use alloc::vec::Vec; use core::convert::TryInto; -use crate::backend::AutodiffBackend; use crate::check; use crate::check::TensorCheck; use crate::tensor::backend::Backend; @@ -233,6 +232,7 @@ where } /// Detach the current tensor from the autodiff graph. + /// /// This function does nothing when autodiff is not enabled. /// This can be used in batchers or elsewhere to ensure that previous operations are not /// considered in the autodiff graph. @@ -241,6 +241,7 @@ where } /// Mark the tensor to keep gradients during the backward pass. + /// /// This function does nothing when autodiff is not enabled. pub fn require_grad(self) -> Self { self.set_require_grad(true) @@ -280,48 +281,3 @@ where .div_scalar(n as f32 - correction_factor as f32) } } - -impl Tensor { - /// Backward pass of the tensor. - pub fn backward(&self) -> B::Gradients { - B::backward::(self.primitive.clone()) - } - - /// Get the gradients of a tensor if it exist. - /// - /// Returns a new reference to the same tensor. 
Therefore the same grad tensor can - /// be accessed multiple times. If you only need to get the gradients one time, - /// consider using [grad_remove](Tensor::grad_remove) for better performance. - pub fn grad(&self, grads: &B::Gradients) -> Option> { - B::grad(&self.primitive, grads).map(Tensor::new) - } - - /// Remove the grad tensor from the [grads](AutodiffBackend::Gradients) struct returning the result. - pub fn grad_remove(&self, grads: &mut B::Gradients) -> Option> { - B::grad_remove(&self.primitive, grads).map(Tensor::new) - } - - /// Replace the grad tensor from the [grads](AutodiffBackend::Gradients) struct with the provided - /// gradient. - pub fn grad_replace(&self, grads: &mut B::Gradients, grad: Tensor) { - B::grad_replace(&self.primitive, grads, grad.primitive); - } - - /// Returns the inner tensor without the autodiff information. - pub fn inner(self) -> Tensor { - Tensor::new(B::inner(self.primitive)) - } - - /// Convert a tensor to the autodiff backend. - /// - /// # Arguments - /// - /// * `inner` - The tensor to convert. - /// - /// # Returns - /// - /// The tensor converted to the autodiff backend. - pub fn from_inner(inner: Tensor) -> Self { - Self::new(B::from_inner(inner.primitive)) - } -} diff --git a/burn-tensor/src/tensor/api/kind.rs b/burn-tensor/src/tensor/api/kind.rs index 208aa26ef..0aa836064 100644 --- a/burn-tensor/src/tensor/api/kind.rs +++ b/burn-tensor/src/tensor/api/kind.rs @@ -15,7 +15,7 @@ pub struct Bool; /// A type-level representation of the kind of a tensor. pub trait TensorKind: Clone + core::fmt::Debug { /// The primitive type of the tensor. - type Primitive: Clone + core::fmt::Debug; + type Primitive: Clone + core::fmt::Debug + Sync + Send; /// The name of the tensor kind. fn name() -> &'static str; diff --git a/burn-tensor/src/tensor/api/mod.rs b/burn-tensor/src/tensor/api/mod.rs index e601eb3e3..77fe61ef2 100644 --- a/burn-tensor/src/tensor/api/mod.rs +++ b/burn-tensor/src/tensor/api/mod.rs @@ -1,5 +1,6 @@ pub(crate) mod check; +mod autodiff; mod base; mod bool; mod float; @@ -7,6 +8,7 @@ mod int; mod kind; mod numeric; +pub use autodiff::*; pub use base::*; pub use kind::*; pub use numeric::*; diff --git a/burn-tensor/src/tensor/backend/base.rs b/burn-tensor/src/tensor/backend/base.rs index 131855cce..5193fd67c 100644 --- a/burn-tensor/src/tensor/backend/base.rs +++ b/burn-tensor/src/tensor/backend/base.rs @@ -178,6 +178,29 @@ pub trait AutodiffBackend: Backend { /// The inner backend tensor. fn inner(tensor: FloatTensor) -> FloatTensor; + /// Returns the tensor with inner backend type. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to get the inner backend tensor for. + /// + /// # Returns + /// + /// The inner backend tensor. + fn int_inner(tensor: IntTensor) -> IntTensor; + + /// Returns the tensor with inner backend type. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to get the inner backend tensor for. + /// + /// # Returns + /// + /// The inner backend tensor. + fn bool_inner(tensor: BoolTensor) + -> BoolTensor; + /// Converts the inner backend tensor to the autodiff backend tensor. /// /// # Arguments @@ -191,4 +214,32 @@ pub trait AutodiffBackend: Backend { fn from_inner( tensor: FloatTensor, ) -> FloatTensor; + + /// Converts the inner backend tensor to the autodiff backend tensor. + /// + /// # Arguments + /// + /// * `tensor` - The inner backend tensor to convert. + /// + /// + /// # Returns + /// + /// The autodiff backend tensor. 
+ fn int_from_inner( + tensor: IntTensor, + ) -> IntTensor; + + /// Converts the inner backend tensor to the autodiff backend tensor. + /// + /// # Arguments + /// + /// * `tensor` - The inner backend tensor to convert. + /// + /// + /// # Returns + /// + /// The autodiff backend tensor. + fn bool_from_inner( + tensor: BoolTensor, + ) -> BoolTensor; } diff --git a/examples/text-classification/examples/ag-news-train.rs b/examples/text-classification/examples/ag-news-train.rs index 4b336c270..07e3d11e7 100644 --- a/examples/text-classification/examples/ag-news-train.rs +++ b/examples/text-classification/examples/ag-news-train.rs @@ -11,14 +11,14 @@ type ElemType = f32; #[cfg(feature = "f16")] type ElemType = burn::tensor::f16; -pub fn launch(device: B::Device) { +pub fn launch(devices: Vec) { let config = ExperimentConfig::new( TransformerEncoderConfig::new(256, 1024, 8, 4).with_norm_first(true), AdamConfig::new().with_weight_decay(Some(WeightDecayConfig::new(5e-5))), ); text_classification::training::train::( - device, + devices, AgNewsDataset::train(), AgNewsDataset::test(), config, @@ -39,7 +39,7 @@ mod ndarray { use crate::{launch, ElemType}; pub fn run() { - launch::>>(NdArrayDevice::Cpu); + launch::>>(vec![NdArrayDevice::Cpu]); } } @@ -56,7 +56,7 @@ mod tch_gpu { #[cfg(target_os = "macos")] let device = LibTorchDevice::Mps; - launch::>>(device); + launch::>>(vec![device]); } } @@ -68,7 +68,7 @@ mod tch_cpu { use crate::{launch, ElemType}; pub fn run() { - launch::>>(LibTorchDevice::Cpu); + launch::>>(vec![LibTorchDevice::Cpu]); } } @@ -79,7 +79,9 @@ mod wgpu { use burn::backend::{Autodiff, Fusion}; pub fn run() { - launch::>>>(WgpuDevice::default()); + launch::>>>(vec![ + WgpuDevice::default(), + ]); } } diff --git a/examples/text-classification/examples/db-pedia-train.rs b/examples/text-classification/examples/db-pedia-train.rs index 81319c32c..dba57cc27 100644 --- a/examples/text-classification/examples/db-pedia-train.rs +++ b/examples/text-classification/examples/db-pedia-train.rs @@ -11,14 +11,14 @@ type ElemType = f32; #[cfg(feature = "f16")] type ElemType = burn::tensor::f16; -pub fn launch(device: B::Device) { +pub fn launch(devices: Vec) { let config = ExperimentConfig::new( TransformerEncoderConfig::new(256, 1024, 8, 4).with_norm_first(true), AdamConfig::new().with_weight_decay(Some(WeightDecayConfig::new(5e-5))), ); text_classification::training::train::( - device, + devices, DbPediaDataset::train(), DbPediaDataset::test(), config, @@ -38,7 +38,7 @@ mod ndarray { use burn::backend::Autodiff; pub fn run() { - launch::>>(NdArrayDevice::Cpu); + launch::>>(vec![NdArrayDevice::Cpu]); } } @@ -55,7 +55,7 @@ mod tch_gpu { #[cfg(target_os = "macos")] let device = LibTorchDevice::Mps; - launch::>>(device); + launch::>>(vec![device]); } } @@ -67,7 +67,7 @@ mod tch_cpu { use crate::{launch, ElemType}; pub fn run() { - launch::>>(LibTorchDevice::Cpu); + launch::>>(vec![LibTorchDevice::Cpu]); } } @@ -79,7 +79,7 @@ mod wgpu { use crate::{launch, ElemType}; pub fn run() { - launch::>>(WgpuDevice::default()); + launch::>>(vec![WgpuDevice::default()]); } } diff --git a/examples/text-classification/src/model.rs b/examples/text-classification/src/model.rs index 914b14576..d5f5db07e 100644 --- a/examples/text-classification/src/model.rs +++ b/examples/text-classification/src/model.rs @@ -88,7 +88,7 @@ impl TextClassificationModel { pub fn forward(&self, item: TextClassificationTrainingBatch) -> ClassificationOutput { // Get batch and sequence length, and the device let [batch_size, seq_length] = 
item.tokens.dims(); - let device = &self.embedding_token.devices()[0]; + let device = &self.embedding_token.devices(Vec::new())[0]; // Move tensors to the correct device let tokens = item.tokens.to_device(device); @@ -128,7 +128,7 @@ impl TextClassificationModel { pub fn infer(&self, item: TextClassificationInferenceBatch) -> Tensor { // Get batch and sequence length, and the device let [batch_size, seq_length] = item.tokens.dims(); - let device = &self.embedding_token.devices()[0]; + let device = &self.embedding_token.devices(Vec::new())[0]; // Move tensors to the correct device let tokens = item.tokens.to_device(device); diff --git a/examples/text-classification/src/training.rs b/examples/text-classification/src/training.rs index eed1ddc4d..5bb6de0d4 100644 --- a/examples/text-classification/src/training.rs +++ b/examples/text-classification/src/training.rs @@ -40,11 +40,11 @@ pub struct ExperimentConfig { // Define train function pub fn train( - device: B::Device, // Device on which to perform computation (e.g., CPU or CUDA device) - dataset_train: D, // Training dataset - dataset_test: D, // Testing dataset + devices: Vec, // Device on which to perform computation (e.g., CPU or CUDA device) + dataset_train: D, // Training dataset + dataset_test: D, // Testing dataset config: ExperimentConfig, // Experiment configuration - artifact_dir: &str, // Directory to save model and config files + artifact_dir: &str, // Directory to save model and config files ) { // Initialize tokenizer let tokenizer = Arc::new(BertCasedTokenizer::default()); @@ -52,12 +52,12 @@ pub fn train( // Initialize batchers for training and testing data let batcher_train = TextClassificationBatcher::::new( tokenizer.clone(), - device.clone(), + devices[0].clone(), config.max_seq_length, ); let batcher_test = TextClassificationBatcher::::new( tokenizer.clone(), - device.clone(), + devices[0].clone(), config.max_seq_length, ); @@ -93,13 +93,15 @@ pub fn train( let learner = LearnerBuilder::new(artifact_dir) .metric_train(CUDAMetric::new()) .metric_valid(CUDAMetric::new()) - .metric_train(AccuracyMetric::new()) - .metric_valid(AccuracyMetric::new()) + .metric_train_numeric(AccuracyMetric::new()) + .metric_valid_numeric(AccuracyMetric::new()) + .metric_train_numeric(LossMetric::new()) + .metric_valid_numeric(LossMetric::new()) .metric_train_numeric(LossMetric::new()) .metric_valid_numeric(LossMetric::new()) .metric_train_numeric(LearningRateMetric::new()) .with_file_checkpointer(CompactRecorder::new()) - .devices(vec![device]) + .devices(devices) .num_epochs(config.num_epochs) .build(model, optim, lr_scheduler); diff --git a/examples/text-generation/src/model.rs b/examples/text-generation/src/model.rs index 6e2312142..60bbf152f 100644 --- a/examples/text-generation/src/model.rs +++ b/examples/text-generation/src/model.rs @@ -58,7 +58,7 @@ impl TextGenerationModel { item: TrainingTextGenerationBatch, ) -> ClassificationOutput { let [batch_size, seq_length] = item.tokens_inputs.dims(); - let device = &self.devices()[0]; + let device = &self.devices(Vec::new())[0]; let inputs = item.tokens_inputs.to_device(device); let targets = item.targets.to_device(device);
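
Illustrative usage sketch (not part of the patch): how downstream code might adapt to the split ModuleVisitor API and the new devices(Vec::new()) signature introduced above. The struct and visitor names below are hypothetical, and the snippet assumes the burn-core API exactly as changed in this diff.

use burn::module::{Module, ModuleVisitor, Param, ParamId};
use burn::tensor::backend::Backend;
use burn::tensor::{Int, Tensor};

#[derive(Module, Debug)]
struct ExampleModule<B: Backend> {
    weight: Param<Tensor<B, 2>>, // float parameter, reached through `visit_float`
    lookup: Tensor<B, 1, Int>,   // constant int tensor, now a valid module field
}

#[derive(Default)]
struct CountFloatParams {
    count: usize,
}

impl<B: Backend> ModuleVisitor<B> for CountFloatParams {
    // Only the float overload is implemented; `visit_int` and `visit_bool`
    // fall back to the new default no-op implementations.
    fn visit_float<const D: usize>(&mut self, _id: &ParamId, tensor: &Tensor<B, D>) {
        self.count += tensor.shape().num_elements();
    }
}

fn example<B: Backend>(module: ExampleModule<B>) {
    // `devices` now threads a Vec through the module tree instead of allocating one.
    let devices = module.devices(Vec::new());

    let mut visitor = CountFloatParams::default();
    module.visit(&mut visitor);

    println!("float elements: {}, devices: {}", visitor.count, devices.len());
}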