Fix/constant tensors (#984)

* Generalize autodiff tensor

* Can have int const module

* Update example

* Support no-std with burn-import

* Fix typos

* Fix alloc problems

* Revert burn-import changes

* Fix examples

* Support Int and Bool Params

* Fix

* Add comment
Nathaniel Simard 2023-11-21 16:27:28 -05:00 committed by GitHub
parent 2f079e991b
commit cabbaab0c4
25 changed files with 679 additions and 181 deletions

View File

@ -79,4 +79,28 @@ impl<B: Backend> AutodiffBackend for Autodiff<B> {
grads.remove(tensor);
grads.register::<B, D>(tensor.node.clone(), grad);
}
fn int_inner<const D: usize>(
tensor: burn_tensor::ops::IntTensor<Self, D>,
) -> burn_tensor::ops::IntTensor<Self::InnerBackend, D> {
tensor
}
fn bool_inner<const D: usize>(
tensor: burn_tensor::ops::BoolTensor<Self, D>,
) -> burn_tensor::ops::BoolTensor<Self::InnerBackend, D> {
tensor
}
fn int_from_inner<const D: usize>(
tensor: burn_tensor::ops::IntTensor<Self::InnerBackend, D>,
) -> burn_tensor::ops::IntTensor<Self, D> {
tensor
}
fn bool_from_inner<const D: usize>(
tensor: burn_tensor::ops::BoolTensor<Self::InnerBackend, D>,
) -> burn_tensor::ops::BoolTensor<Self, D> {
tensor
}
}
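
Since int and bool tensors carry no gradient state, these new hooks are identity functions on the autodiff decorator. A minimal sketch (not part of the diff) of what that enables at the tensor level, assuming the `Tensor::inner` generalization added later in this commit:

use burn_tensor::backend::AutodiffBackend;
use burn_tensor::{Int, Tensor};

/// Strips autodiff tracking from an int tensor. Dispatches to `B::int_inner`,
/// which simply returns the same tensor.
fn strip_autodiff<B: AutodiffBackend, const D: usize>(
    tensor: Tensor<B, D, Int>,
) -> Tensor<B::InnerBackend, D, Int> {
    tensor.inner()
}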

View File

@ -1,12 +1,15 @@
use alloc::vec::Vec;
use super::ParamId;
use crate::{
record::Record,
tensor::backend::{AutodiffBackend, Backend},
};
use alloc::vec::Vec;
pub use burn_derive::Module;
use burn_tensor::Tensor;
use burn_tensor::{Bool, Int, Tensor};
/// Type alias to `Vec<B::Device>` which supports `no_std` environments by automatically using
/// the `alloc` crate.
pub type Devices<B> = Vec<<B as Backend>::Device>;
// At the moment, our plan is to continue experimenting with the macro internally and monitor its development.
// We may consider making it public in the future.
@ -14,7 +17,11 @@ macro_rules! module {
(map=$module:ident, ops=$item:expr) => {{
struct Mapper;
impl<B: Backend> ModuleMapper<B> for Mapper {
fn map<const D: usize>(&mut self, _id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
fn map_float<const D: usize>(
&mut self,
_id: &ParamId,
tensor: Tensor<B, D>,
) -> Tensor<B, D> {
let func = $item;
func(tensor)
}
@ -22,30 +29,13 @@ macro_rules! module {
let mut mapper = Mapper;
$module.map(&mut mapper)
}};
(map=$module:ident, ops=$item:expr, capture={$capture:ident: $ty:ty}) => {{
struct Mapper<'a, B: Backend> {
capture: &'a $ty,
backend: core::marker::PhantomData<B>,
}
impl<'a, B: Backend> ModuleMapper<B> for Mapper<'a, B> {
fn map<const D: usize>(&mut self, _id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
let func = $item;
func(tensor, self.capture)
}
}
let mut mapper = Mapper {
capture: $capture,
backend: core::marker::PhantomData,
};
$module.map(&mut mapper)
}};
(visit=$module:ident, ops=$item:expr, state=$state_ty:ty, init=$init:expr) => {{
(visit_float=$module:ident, ops=$item:expr, state=$state_ty:ty, init=$init:expr) => {{
struct Visitor<'a, B: Backend> {
state: &'a mut $state_ty,
backend: core::marker::PhantomData<B>,
}
impl<'a, B: Backend> ModuleVisitor<B> for Visitor<'a, B> {
fn visit<const D: usize>(&mut self, _id: &ParamId, tensor: &Tensor<B, D>) {
fn visit_float<const D: usize>(&mut self, _id: &ParamId, tensor: &Tensor<B, D>) {
let func = $item;
func(tensor, &mut self.state)
}
@ -94,20 +84,9 @@ pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
/// Type to save and load the module.
type Record: Record;
/// Get the device list of the module and all of its sub-modules.
fn devices(&self) -> Vec<B::Device> {
module!(
visit = self,
ops = |tensor: &Tensor<B, D>, state: &mut Vec<B::Device>| {
let device = tensor.device();
if !state.contains(&device) {
state.push(device);
}
},
state = Vec<B::Device>,
init = Vec::new
)
}
/// Collects the devices found in the module and all of its sub-modules into the given
/// vector, skipping duplicates, and returns it.
fn devices(&self, devices: Devices<B>) -> Devices<B>;
/// Fork the module and all of its sub-modules to the given device.
///
@ -115,22 +94,7 @@ pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
///
/// This is similar to [to_device](Module::to_device), but it ensures the module will
/// have its own autodiff graph.
fn fork(self, device: &B::Device) -> Self {
module!(
map = self,
ops = |tensor: Tensor<B, D>, device: &B::Device| {
let is_require_grad = tensor.is_require_grad();
let mut tensor = tensor.to_device(device).detach();
if is_require_grad {
tensor = tensor.require_grad();
}
tensor
},
capture = { device: B::Device }
)
}
fn fork(self, device: &B::Device) -> Self;
/// Move the module and all of its sub-modules to the given device.
///
@ -139,13 +103,7 @@ pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
/// The device operations will be registered in the autodiff graph. Therefore, be sure to call
/// backward only one time even if you have the same module on multiple devices. If you want to
/// call backward multiple times, look into using [fork](Module::fork) instead.
fn to_device(self, device: &B::Device) -> Self {
module!(
map = self,
ops = |tensor: Tensor<B, D>, device: &B::Device| tensor.to_device(device),
capture = { device: B::Device }
)
}
fn to_device(self, device: &B::Device) -> Self;
/// Each tensor in the module tree will not require grad.
///
@ -164,7 +122,7 @@ pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
/// Get the number of parameters the module has, including all of its sub-modules.
fn num_params(&self) -> usize {
module!(
visit = self,
visit_float = self,
ops = |tensor: &Tensor<B, D>, state: &mut usize| {
*state += tensor.shape().num_elements();
},
@ -172,10 +130,10 @@ pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
init = || 0
)
}
/// Visit each tensor in the module with a [visitor](ModuleVisitor).
/// Visit each tensor parameter in the module with a [visitor](ModuleVisitor).
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V);
/// Map each tensor in the module with a [mapper](ModuleMapper).
/// Map each tensor parameter in the module with a [mapper](ModuleMapper).
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self;
/// Load the module state from a record.
@ -233,14 +191,36 @@ pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
/// Module visitor trait.
pub trait ModuleVisitor<B: Backend> {
/// Visit a tensor in the module.
fn visit<const D: usize>(&mut self, id: &ParamId, tensor: &Tensor<B, D>);
/// Visit a float tensor in the module.
fn visit_float<const D: usize>(&mut self, _id: &ParamId, _tensor: &Tensor<B, D>) {}
/// Visit an int tensor in the module.
fn visit_int<const D: usize>(&mut self, _id: &ParamId, _tensor: &Tensor<B, D, Int>) {}
/// Visit a bool tensor in the module.
fn visit_bool<const D: usize>(&mut self, _id: &ParamId, _tensor: &Tensor<B, D, Bool>) {}
}
/// Module mapper trait.
pub trait ModuleMapper<B: Backend> {
/// Map a tensor in the module.
fn map<const D: usize>(&mut self, id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D>;
/// Map a float tensor in the module.
fn map_float<const D: usize>(&mut self, _id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
tensor
}
/// Map an int tensor in the module.
fn map_int<const D: usize>(
&mut self,
_id: &ParamId,
tensor: Tensor<B, D, Int>,
) -> Tensor<B, D, Int> {
tensor
}
/// Map a bool tensor in the module.
fn map_bool<const D: usize>(
&mut self,
_id: &ParamId,
tensor: Tensor<B, D, Bool>,
) -> Tensor<B, D, Bool> {
tensor
}
}
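
With the split into per-kind methods that default to no-ops, a visitor or mapper only overrides the kinds it cares about. A hedged sketch (not from the diff) that counts float parameters and ignores int/bool ones:

use burn::module::{Module, ModuleVisitor, ParamId};
use burn::tensor::{backend::Backend, Tensor};

struct FloatParamCounter {
    count: usize,
}

impl<B: Backend> ModuleVisitor<B> for FloatParamCounter {
    fn visit_float<const D: usize>(&mut self, _id: &ParamId, tensor: &Tensor<B, D>) {
        self.count += tensor.shape().num_elements();
    }
    // `visit_int` and `visit_bool` fall back to the new empty defaults.
}

fn count_float_params<B: Backend, M: Module<B>>(module: &M) -> usize {
    let mut visitor = FloatParamCounter { count: 0 };
    module.visit(&mut visitor);
    visitor.count
}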
/// Module with auto-differentiation backend.

View File

@ -1,17 +1,14 @@
use core::marker::PhantomData;
use crate::{
self as burn,
module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor},
module::{AutodiffModule, Devices, Module, ModuleMapper, ModuleVisitor},
record::Record,
};
use burn::record::PrecisionSettings;
use burn_tensor::{
backend::{AutodiffBackend, Backend},
Tensor,
BasicAutodiffOps, BasicOps, Tensor,
};
use super::ParamId;
use core::marker::PhantomData;
/// Record used for constant type implementing the [module](crate::module::Module) trait.
#[derive(Debug, Clone, Copy, new)]
@ -69,6 +66,18 @@ macro_rules! constant {
fn into_record(self) -> Self::Record {
burn::module::ConstantRecord::new()
}
fn to_device(self, _: &B::Device) -> Self {
self
}
fn fork(self, _: &B::Device) -> Self {
self
}
fn devices(&self, devices: burn::module::Devices<B>) -> burn::module::Devices<B> {
devices
}
};
(ad_module, $type:ty) => {
@ -113,27 +122,13 @@ constant!(i32);
constant!(i16);
constant!(i8);
impl<const D: usize, B: Backend> Module<B> for Tensor<B, D> {
impl<const D: usize, B: Backend, K: BasicOps<B>> Module<B> for Tensor<B, D, K> {
type Record = ConstantRecord;
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
// Important:
// We need to implement visit method for Tensor Module because
// to_device will be called during the visit method of the ModuleVisitor
fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V) {}
// We are using a dummy param id because the visit method requires a param id
let dummy_param_id = ParamId::new();
visitor.visit(&dummy_param_id, self)
}
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
// Important:
// We need to implement visit method for Tensor Module because
// to_device will be called during the visit method of the ModuleVisitor
// We are using a dummy param id because the visit method requires a param id
let dummy_param_id = ParamId::new();
mapper.map(&dummy_param_id, self)
fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
self
}
fn into_record(self) -> Self::Record {
@ -143,10 +138,30 @@ impl<const D: usize, B: Backend> Module<B> for Tensor<B, D> {
fn load_record(self, _record: Self::Record) -> Self {
self
}
fn to_device(self, device: &B::Device) -> Self {
self.to_device(device)
}
fn fork(self, device: &B::Device) -> Self {
self.to_device(device)
}
fn devices(&self, mut devices: Devices<B>) -> Devices<B> {
let device = self.device();
if !devices.contains(&device) {
devices.push(device)
}
devices
}
}
impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Tensor<B, D> {
type InnerModule = Tensor<B::InnerBackend, D>;
impl<const D: usize, B: AutodiffBackend, K: BasicAutodiffOps<B>> AutodiffModule<B>
for Tensor<B, D, K>
{
type InnerModule = Tensor<B::InnerBackend, D, K::InnerKind>;
fn valid(&self) -> Self::InnerModule {
self.clone().inner()
@ -171,6 +186,18 @@ impl<B: Backend> Module<B> for PhantomData<B> {
fn into_record(self) -> Self::Record {
ConstantRecord::new()
}
fn to_device(self, _: &<B as Backend>::Device) -> Self {
self
}
fn fork(self, _: &<B as Backend>::Device) -> Self {
self
}
fn devices(&self, devices: Devices<B>) -> Devices<B> {
devices
}
}
impl<B: AutodiffBackend> AutodiffModule<B> for PhantomData<B> {

View File

@ -28,6 +28,22 @@ where
fn into_record(self) -> Self::Record {
self.map(Module::into_record)
}
fn to_device(self, device: &<B as Backend>::Device) -> Self {
self.map(|module| module.to_device(device))
}
fn fork(self, device: &<B as Backend>::Device) -> Self {
self.map(|module| module.fork(device))
}
fn devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
if let Some(module) = self.as_ref() {
devices = module.devices(devices);
}
devices
}
}
impl<T, B> AutodiffModule<B> for Option<T>
@ -78,6 +94,24 @@ where
.map(|(module, record)| module.load_record(record))
.collect()
}
fn to_device(self, device: &<B as Backend>::Device) -> Self {
self.into_iter()
.map(|module| module.to_device(device))
.collect()
}
fn fork(self, device: &<B as Backend>::Device) -> Self {
self.into_iter().map(|module| module.fork(device)).collect()
}
fn devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
for module in self.iter() {
devices = module.devices(devices);
}
devices
}
}
impl<T, B> AutodiffModule<B> for Vec<T>
@ -100,11 +134,11 @@ where
{
type Record = [T::Record; N];
fn devices(&self) -> Vec<<B as burn_tensor::backend::Backend>::Device> {
let mut devices = Vec::new();
fn devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
for module in self.iter() {
devices.append(&mut module.devices());
devices = module.devices(devices);
}
devices
}
@ -139,6 +173,14 @@ where
fn into_record(self) -> Self::Record {
self.map(Module::into_record)
}
fn to_device(self, device: &<B as Backend>::Device) -> Self {
self.map(|module| module.to_device(device))
}
fn fork(self, device: &<B as Backend>::Device) -> Self {
self.map(|module| module.fork(device))
}
}
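
The container impls above thread the device accumulator through every child, which is why `devices` now takes and returns a `Vec`. A small sketch of the typical call site (assumption: the module holds at least one tensor):

use burn::module::Module;
use burn::tensor::backend::Backend;

/// Returns the first device found in the module tree.
fn primary_device<B: Backend, M: Module<B>>(module: &M) -> B::Device {
    module
        .devices(Vec::new())
        .into_iter()
        .next()
        .expect("module contains at least one tensor")
}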
impl<const N: usize, T, B> AutodiffModule<B> for [T; N]

View File

@ -1,7 +1,7 @@
use alloc::sync::Arc;
use super::ParamId;
use crate::module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor, Param};
use alloc::sync::Arc;
use alloc::vec::Vec;
use burn_tensor::{
backend::{AutodiffBackend, Backend},
Tensor,
@ -51,12 +51,12 @@ impl<const D: usize, B: Backend> Module<B> for RunningState<Tensor<B, D>> {
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
let tensor = self.value.read().unwrap();
visitor.visit(&self.id, &tensor)
visitor.visit_float(&self.id, &tensor)
}
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
let mut tensor = self.value.write().unwrap();
let tensor_out = mapper.map(&self.id, tensor.clone());
let tensor_out = mapper.map_float(&self.id, tensor.clone());
*tensor = tensor_out;
core::mem::drop(tensor);
@ -80,6 +80,30 @@ impl<const D: usize, B: Backend> Module<B> for RunningState<Tensor<B, D>> {
self
}
fn to_device(self, device: &<B as Backend>::Device) -> Self {
let mut tensor = self.value.write().unwrap();
let tensor_out = tensor.clone().to_device(device);
*tensor = tensor_out;
core::mem::drop(tensor);
self
}
fn fork(self, device: &<B as Backend>::Device) -> Self {
self.to_device(device) // Same thing here since no grad.
}
fn devices(&self, mut devices: Vec<<B as Backend>::Device>) -> Vec<<B as Backend>::Device> {
let device = self.value.read().unwrap().device();
if !devices.contains(&device) {
devices.push(device)
}
devices
}
}
impl<const D: usize, B: Backend> RunningState<Tensor<B, D>> {

View File

@ -4,22 +4,38 @@ use crate::tensor::{
backend::{AutodiffBackend, Backend},
Tensor,
};
use alloc::vec::Vec;
use burn_tensor::{Bool, Int};
impl<B: Backend, const D: usize> From<Tensor<B, D>> for Param<Tensor<B, D>> {
fn from(value: Tensor<B, D>) -> Self {
// When creating a parameter from a float tensor, we automatically mark it as requiring
// gradients, so that it can be updated by an optimizer.
Param::new(ParamId::new(), value.require_grad())
}
}
impl<B: Backend, const D: usize> From<Tensor<B, D, Int>> for Param<Tensor<B, D, Int>> {
fn from(value: Tensor<B, D, Int>) -> Self {
Param::new(ParamId::new(), value)
}
}
impl<B: Backend, const D: usize> From<Tensor<B, D, Bool>> for Param<Tensor<B, D, Bool>> {
fn from(value: Tensor<B, D, Bool>) -> Self {
Param::new(ParamId::new(), value)
}
}
impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D>> {
type Record = Param<Tensor<B, D>>;
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
visitor.visit(&self.id, &self.value)
visitor.visit_float(&self.id, &self.value)
}
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
let value = mapper.map(&self.id, self.value);
let value = mapper.map_float(&self.id, self.value);
Self::new(self.id, value)
}
@ -41,6 +57,127 @@ impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D>> {
Self::new(record.id, tensor)
}
fn to_device(self, device: &<B as Backend>::Device) -> Self {
self.map(|tensor| tensor.to_device(device))
}
fn fork(self, device: &<B as Backend>::Device) -> Self {
self.map(|tensor| {
let is_require_grad = tensor.is_require_grad();
let mut tensor = tensor.to_device(device).detach();
if is_require_grad {
tensor = tensor.require_grad();
}
tensor
})
}
fn devices(&self, mut devices: Vec<<B as Backend>::Device>) -> Vec<<B as Backend>::Device> {
let device = self.device();
if !devices.contains(&device) {
devices.push(device)
}
devices
}
}
impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Int>> {
type Record = Param<Tensor<B, D, Int>>;
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
visitor.visit_int(&self.id, &self.value)
}
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
let value = mapper.map_int(&self.id, self.value);
Self::new(self.id, value)
}
fn into_record(self) -> Self::Record {
self
}
fn load_record(self, record: Self::Record) -> Self {
let mut tensor = record.value;
let device = self.device();
// Make sure we load the record into the same module device.
if tensor.device() != device {
tensor = tensor.to_device(&device);
}
Self::new(record.id, tensor)
}
fn to_device(self, device: &<B as Backend>::Device) -> Self {
self.map(|tensor| tensor.to_device(device))
}
fn fork(self, device: &<B as Backend>::Device) -> Self {
self.to_device(device) // Don't support autodiff.
}
fn devices(&self, mut devices: Vec<<B as Backend>::Device>) -> Vec<<B as Backend>::Device> {
let device = self.device();
if !devices.contains(&device) {
devices.push(device)
}
devices
}
}
impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Bool>> {
type Record = Param<Tensor<B, D, Bool>>;
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
visitor.visit_bool(&self.id, &self.value)
}
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
let value = mapper.map_bool(&self.id, self.value);
Self::new(self.id, value)
}
fn into_record(self) -> Self::Record {
self
}
fn load_record(self, record: Self::Record) -> Self {
let mut tensor = record.value;
let device = self.device();
// Make sure we load the record into the same module device.
if tensor.device() != device {
tensor = tensor.to_device(&device);
}
Self::new(record.id, tensor)
}
fn to_device(self, device: &<B as Backend>::Device) -> Self {
self.map(|tensor| tensor.to_device(device))
}
fn fork(self, device: &<B as Backend>::Device) -> Self {
self.to_device(device) // Don't support autodiff.
}
fn devices(&self, mut devices: Vec<<B as Backend>::Device>) -> Vec<<B as Backend>::Device> {
let device = self.device();
if !devices.contains(&device) {
devices.push(device)
}
devices
}
}
impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D>> {
@ -54,6 +191,22 @@ impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D
}
}
impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D, Int>> {
type InnerModule = Param<Tensor<B::InnerBackend, D, Int>>;
fn valid(&self) -> Self::InnerModule {
Param::new(self.id.clone(), self.value.clone().inner())
}
}
impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D, Bool>> {
type InnerModule = Param<Tensor<B::InnerBackend, D, Bool>>;
fn valid(&self) -> Self::InnerModule {
Param::new(self.id.clone(), self.value.clone().inner())
}
}
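
The new `From` impls make the asymmetry explicit: a float tensor wrapped in `Param` is marked `require_grad`, while int and bool tensors become plain, non-trainable parameters. A hedged sketch (helper name is illustrative):

use burn::module::Param;
use burn::tensor::{backend::Backend, Int, Tensor};

/// Wraps existing tensors as parameters: the float tensor becomes trainable
/// via its `From` impl, the int tensor is stored as-is.
fn wrap_params<B: Backend>(
    weight: Tensor<B, 2>,
    lookup: Tensor<B, 1, Int>,
) -> (Param<Tensor<B, 2>>, Param<Tensor<B, 1, Int>>) {
    (Param::from(weight), Param::from(lookup))
}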
#[cfg(all(test, feature = "std"))]
mod tests {
use super::*;

View File

@ -1,7 +1,7 @@
use super::ParamId;
use crate::module::{Module, ModuleVisitor};
use alloc::vec::Vec;
use burn_tensor::{backend::Backend, Tensor};
use burn_tensor::{backend::Backend, Bool, Int, Tensor};
use core::marker::PhantomData;
struct ParamIdCollector<'a, M> {
@ -14,7 +14,13 @@ where
B: Backend,
M: Module<B>,
{
fn visit<const D: usize>(&mut self, id: &ParamId, _tensor: &Tensor<B, D>) {
fn visit_float<const D: usize>(&mut self, id: &ParamId, _tensor: &Tensor<B, D>) {
self.param_ids.push(id.clone());
}
fn visit_int<const D: usize>(&mut self, id: &ParamId, _tensor: &Tensor<B, D, Int>) {
self.param_ids.push(id.clone());
}
fn visit_bool<const D: usize>(&mut self, id: &ParamId, _tensor: &Tensor<B, D, Bool>) {
self.param_ids.push(id.clone());
}
}

View File

@ -57,7 +57,7 @@ struct ModuleGradsAccumulator<'a, M> {
impl<'a, B: AutodiffBackend, M: AutodiffModule<B>> ModuleVisitor<B>
for ModuleGradsAccumulator<'a, M>
{
fn visit<const D: usize>(&mut self, id: &ParamId, _tensor: &Tensor<B, D>) {
fn visit_float<const D: usize>(&mut self, id: &ParamId, _tensor: &Tensor<B, D>) {
let grad_updated = match self.grads_new.remove::<B::InnerBackend, D>(id) {
Some(new) => match self.grads.remove::<B::InnerBackend, D>(id) {
Some(grad) => grad.add(new),

View File

@ -115,7 +115,7 @@ where
B: AutodiffBackend,
O: SimpleOptimizer<B::InnerBackend>,
{
fn map<const D: usize>(&mut self, id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
fn map_float<const D: usize>(&mut self, id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
let grad = self.grads.remove(id);
if let Some(grad) = grad {

View File

@ -22,7 +22,7 @@ where
B: AutodiffBackend,
M: AutodiffModule<B>,
{
fn visit<const D: usize>(&mut self, id: &ParamId, tensor: &Tensor<B, D>) {
fn visit_float<const D: usize>(&mut self, id: &ParamId, tensor: &Tensor<B, D>) {
if let Some(grad) = tensor.grad_remove(&mut self.grads) {
self.grads_params
.register::<B::InnerBackend, D>(id.clone(), grad);
@ -35,7 +35,7 @@ where
B: AutodiffBackend,
M: AutodiffModule<B>,
{
fn visit<const D: usize>(&mut self, id: &ParamId, _tensor: &Tensor<B, D>) {
fn visit_float<const D: usize>(&mut self, id: &ParamId, _tensor: &Tensor<B, D>) {
if let Some(grad) = self.grads.remove::<B::InnerBackend, D>(id) {
self.grads
.register::<B::InnerBackend, D>(id.clone(), grad.to_device(self.device));

View File

@ -2,11 +2,15 @@ use alloc::string::String;
use alloc::string::ToString;
use alloc::vec::Vec;
use burn_tensor::backend::Backend;
use burn_tensor::Bool;
use burn_tensor::Int;
use burn_tensor::Tensor;
use serde::Deserialize;
use serde::Serialize;
use super::tensor::BoolTensorSerde;
use super::tensor::FloatTensorSerde;
use super::tensor::IntTensorSerde;
use super::{PrecisionSettings, Record};
use crate::module::{Param, ParamId};
use burn_tensor::{DataSerialize, Element};
@ -115,6 +119,30 @@ impl<B: Backend, const D: usize> Record for Param<Tensor<B, D>> {
}
}
impl<B: Backend, const D: usize> Record for Param<Tensor<B, D, Int>> {
type Item<S: PrecisionSettings> = ParamSerde<IntTensorSerde<S>>;
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
ParamSerde::new(self.id.into_string(), self.value.into_item())
}
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
Param::new(ParamId::from(item.id), Tensor::from_item(item.param))
}
}
impl<B: Backend, const D: usize> Record for Param<Tensor<B, D, Bool>> {
type Item<S: PrecisionSettings> = ParamSerde<BoolTensorSerde>;
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
ParamSerde::new(self.id.into_string(), self.value.into_item::<S>())
}
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
Param::new(ParamId::from(item.id), Tensor::from_item::<S>(item.param))
}
}
// Type that can be serialized as is without any conversion.
macro_rules! primitive {
($type:ty) => {

View File

@ -1,6 +1,6 @@
use burn::module::{Module, Param};
use burn::tensor::backend::Backend;
use burn::tensor::{Distribution, Shape, Tensor};
use burn::tensor::{Distribution, Int, Shape, Tensor};
use burn_core as burn;
pub type TestBackend = burn_ndarray::NdArray<f32>;
@ -12,6 +12,11 @@ pub struct ModuleBasic<B: Backend> {
weight_basic: Param<Tensor<B, 2>>,
}
#[derive(Module, Debug)]
struct ModuleTensorConstInt<B: Backend> {
weight_basic: Tensor<B, 2, Int>,
}
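
This test struct exercises the new constant support: a bare (non-`Param`) tensor field is moved across devices but never visited as a parameter. A hedged sketch of what that implies (`num_params` only accumulates over float parameters, per the hunks above):

fn assert_constant<B: Backend>(module: &ModuleTensorConstInt<B>) {
    // Constant tensors implement `visit`/`map` as no-ops, so they are not
    // counted as parameters and are ignored by optimizers.
    assert_eq!(module.num_params(), 0);
}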
impl<B: Backend> ModuleBasic<B> {
fn new() -> Self {
let weight_basic = Tensor::random(Shape::new([20, 20]), Distribution::Default);

View File

@ -29,6 +29,9 @@ pub(crate) fn derive_impl(ast: &syn::DeriveInput) -> TokenStream {
let num_params_fn = generator.gen_num_params();
let visit = generator.gen_visit();
let map_mut = generator.gen_map();
let devices = generator.gen_devices();
let to_device = generator.gen_to_device();
let fork = generator.gen_fork();
let valid_fn = generator.gen_valid();
let into_record_fn = generator.gen_into_record();
let load_record_fn = generator.gen_load_record();
@ -50,6 +53,10 @@ pub(crate) fn derive_impl(ast: &syn::DeriveInput) -> TokenStream {
#visit
#map_mut
#devices
#to_device
#fork
}
impl #generics burn::module::AutodiffModule<B> for #name #generics_ty

View File

@ -4,6 +4,9 @@ use proc_macro2::TokenStream;
pub(crate) trait ModuleCodegen {
fn gen_num_params(&self) -> TokenStream;
fn gen_visit(&self) -> TokenStream;
fn gen_devices(&self) -> TokenStream;
fn gen_to_device(&self) -> TokenStream;
fn gen_fork(&self) -> TokenStream;
fn gen_map(&self) -> TokenStream;
fn gen_valid(&self) -> TokenStream;
fn gen_into_record(&self) -> TokenStream;

View File

@ -39,10 +39,62 @@ impl ModuleCodegen for StructModuleCodegen {
}
}
fn gen_devices(&self) -> TokenStream {
let body = self.gen_fields_fn(|name| {
quote! {
let devices = burn::module::Module::<B>::devices(&self.#name, devices);
}
});
quote! {
fn devices(&self, devices: burn::module::Devices<B>) -> burn::module::Devices<B> {
#body
devices
}
}
}
fn gen_to_device(&self) -> TokenStream {
let (names, body) = self.gen_fields_fn_names(|name| {
quote! {
let #name = burn::module::Module::<B>::to_device(self.#name, device);
}
});
quote! {
fn to_device(self, device: &B::Device) -> Self {
#body
Self {
#(#names),*
}
}
}
}
fn gen_fork(&self) -> TokenStream {
let (names, body) = self.gen_fields_fn_names(|name| {
quote! {
let #name = burn::module::Module::<B>::fork(self.#name, device);
}
});
quote! {
fn fork(self, device: &B::Device) -> Self {
#body
Self {
#(#names),*
}
}
}
}
fn gen_map(&self) -> TokenStream {
let (names, body) = self.gen_fields_fn_names(|name| {
quote! {
let #name = burn::module::Module::map(self.#name, mapper);
let #name = burn::module::Module::<B>::map(self.#name, mapper);
}
});

View File

@ -0,0 +1,134 @@
use crate::{backend::AutodiffBackend, BasicOps, Bool, Float, Int, Tensor, TensorKind};
impl<const D: usize, B: AutodiffBackend> Tensor<B, D> {
/// Backward pass of the tensor.
pub fn backward(&self) -> B::Gradients {
B::backward::<D>(self.primitive.clone())
}
/// Get the gradients of a tensor if they exist.
///
/// Returns a new reference to the same tensor. Therefore the same grad tensor can
/// be accessed multiple times. If you only need to get the gradients one time,
/// consider using [grad_remove](Tensor::grad_remove) for better performance.
pub fn grad(&self, grads: &B::Gradients) -> Option<Tensor<B::InnerBackend, D>> {
B::grad(&self.primitive, grads).map(Tensor::new)
}
/// Remove the grad tensor from the [grads](AutodiffBackend::Gradients) struct returning the result.
pub fn grad_remove(&self, grads: &mut B::Gradients) -> Option<Tensor<B::InnerBackend, D>> {
B::grad_remove(&self.primitive, grads).map(Tensor::new)
}
/// Replace the grad tensor from the [grads](AutodiffBackend::Gradients) struct with the provided
/// gradient.
pub fn grad_replace(&self, grads: &mut B::Gradients, grad: Tensor<B::InnerBackend, D>) {
B::grad_replace(&self.primitive, grads, grad.primitive);
}
}
impl<const D: usize, B: AutodiffBackend, K: BasicAutodiffOps<B>> Tensor<B, D, K> {
/// Returns the inner tensor without the autodiff information.
pub fn inner(self) -> Tensor<B::InnerBackend, D, K::InnerKind> {
Tensor::new(K::inner(self.primitive))
}
/// Convert a tensor to the autodiff backend.
///
/// # Arguments
///
/// * `inner` - The tensor to convert.
///
/// # Returns
///
/// The tensor converted to the autodiff backend.
pub fn from_inner(inner: Tensor<B::InnerBackend, D, K::InnerKind>) -> Self {
Self::new(K::from_inner(inner.primitive))
}
}
impl<B: AutodiffBackend> BasicAutodiffOps<B> for Float {
type InnerKind = Float;
fn inner<const D: usize>(
tensor: <Self as TensorKind<B>>::Primitive<D>,
) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive<D> {
B::inner(tensor)
}
fn from_inner<const D: usize>(
inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive<D>,
) -> <Self as TensorKind<B>>::Primitive<D> {
B::from_inner(inner)
}
}
impl<B: AutodiffBackend> BasicAutodiffOps<B> for Int {
type InnerKind = Int;
fn inner<const D: usize>(
tensor: <Self as TensorKind<B>>::Primitive<D>,
) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive<D> {
B::int_inner(tensor)
}
fn from_inner<const D: usize>(
inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive<D>,
) -> <Self as TensorKind<B>>::Primitive<D> {
B::int_from_inner(inner)
}
}
impl<B: AutodiffBackend> BasicAutodiffOps<B> for Bool {
type InnerKind = Bool;
fn inner<const D: usize>(
tensor: <Self as TensorKind<B>>::Primitive<D>,
) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive<D> {
B::bool_inner(tensor)
}
fn from_inner<const D: usize>(
inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive<D>,
) -> <Self as TensorKind<B>>::Primitive<D> {
B::bool_from_inner(inner)
}
}
/// Trait that lists all operations that can be applied to all tensors on an autodiff backend.
///
/// # Warnings
///
/// This is an internal trait, use the public API provided by [tensor struct](Tensor).
pub trait BasicAutodiffOps<B: AutodiffBackend>: BasicOps<B> + BasicOps<B::InnerBackend> {
/// Inner primitive tensor.
type InnerKind: BasicOps<B::InnerBackend>;
/// Returns the inner tensor without the autodiff information.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// Users should prefer the [Tensor::inner](Tensor::inner) function,
/// which is more high-level and designed for public use.
fn inner<const D: usize>(
tensor: <Self as TensorKind<B>>::Primitive<D>,
) -> <Self::InnerKind as TensorKind<B::InnerBackend>>::Primitive<D>;
/// Convert a tensor to the autodiff backend.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// Users should prefer the [Tensor::from_inner](Tensor::from_inner) function,
/// which is more high-level and designed for public use.
fn from_inner<const D: usize>(
inner: <Self::InnerKind as TensorKind<B::InnerBackend>>::Primitive<D>,
) -> <Self as TensorKind<B>>::Primitive<D>;
}
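
Since every kind now implements `BasicAutodiffOps`, the `inner`/`from_inner` round trip works uniformly, with dispatch resolved at compile time through the associated `InnerKind`. A brief sketch (not part of the diff):

use burn_tensor::backend::AutodiffBackend;
use burn_tensor::{Bool, Tensor};

/// Round-trips a bool mask through the inner backend and back.
fn round_trip<B: AutodiffBackend, const D: usize>(
    mask: Tensor<B, D, Bool>,
) -> Tensor<B, D, Bool> {
    let inner: Tensor<B::InnerBackend, D, Bool> = mask.inner();
    Tensor::from_inner(inner)
}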

View File

@ -1,7 +1,6 @@
use alloc::vec::Vec;
use core::convert::TryInto;
use crate::backend::AutodiffBackend;
use crate::check;
use crate::check::TensorCheck;
use crate::tensor::backend::Backend;
@ -233,6 +232,7 @@ where
}
/// Detach the current tensor from the autodiff graph.
///
/// This function does nothing when autodiff is not enabled.
/// This can be used in batchers or elsewhere to ensure that previous operations are not
/// considered in the autodiff graph.
@ -241,6 +241,7 @@ where
}
/// Mark the tensor to keep gradients during the backward pass.
///
/// This function does nothing when autodiff is not enabled.
pub fn require_grad(self) -> Self {
self.set_require_grad(true)
@ -280,48 +281,3 @@ where
.div_scalar(n as f32 - correction_factor as f32)
}
}
impl<const D: usize, B: AutodiffBackend> Tensor<B, D> {
/// Backward pass of the tensor.
pub fn backward(&self) -> B::Gradients {
B::backward::<D>(self.primitive.clone())
}
/// Get the gradients of a tensor if it exist.
///
/// Returns a new reference to the same tensor. Therefore the same grad tensor can
/// be accessed multiple times. If you only need to get the gradients one time,
/// consider using [grad_remove](Tensor::grad_remove) for better performance.
pub fn grad(&self, grads: &B::Gradients) -> Option<Tensor<B::InnerBackend, D>> {
B::grad(&self.primitive, grads).map(Tensor::new)
}
/// Remove the grad tensor from the [grads](AutodiffBackend::Gradients) struct returning the result.
pub fn grad_remove(&self, grads: &mut B::Gradients) -> Option<Tensor<B::InnerBackend, D>> {
B::grad_remove(&self.primitive, grads).map(Tensor::new)
}
/// Replace the grad tensor from the [grads](AutodiffBackend::Gradients) struct with the provided
/// gradient.
pub fn grad_replace(&self, grads: &mut B::Gradients, grad: Tensor<B::InnerBackend, D>) {
B::grad_replace(&self.primitive, grads, grad.primitive);
}
/// Returns the inner tensor without the autodiff information.
pub fn inner(self) -> Tensor<B::InnerBackend, D> {
Tensor::new(B::inner(self.primitive))
}
/// Convert a tensor to the autodiff backend.
///
/// # Arguments
///
/// * `inner` - The tensor to convert.
///
/// # Returns
///
/// The tensor converted to the autodiff backend.
pub fn from_inner(inner: Tensor<B::InnerBackend, D>) -> Self {
Self::new(B::from_inner(inner.primitive))
}
}

View File

@ -15,7 +15,7 @@ pub struct Bool;
/// A type-level representation of the kind of a tensor.
pub trait TensorKind<B: Backend>: Clone + core::fmt::Debug {
/// The primitive type of the tensor.
type Primitive<const D: usize>: Clone + core::fmt::Debug;
type Primitive<const D: usize>: Clone + core::fmt::Debug + Sync + Send;
/// The name of the tensor kind.
fn name() -> &'static str;

View File

@ -1,5 +1,6 @@
pub(crate) mod check;
mod autodiff;
mod base;
mod bool;
mod float;
@ -7,6 +8,7 @@ mod int;
mod kind;
mod numeric;
pub use autodiff::*;
pub use base::*;
pub use kind::*;
pub use numeric::*;

View File

@ -178,6 +178,29 @@ pub trait AutodiffBackend: Backend {
/// The inner backend tensor.
fn inner<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self::InnerBackend, D>;
/// Returns the tensor with inner backend type.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the inner backend tensor for.
///
/// # Returns
///
/// The inner backend tensor.
fn int_inner<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self::InnerBackend, D>;
/// Returns the tensor with inner backend type.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the inner backend tensor for.
///
/// # Returns
///
/// The inner backend tensor.
fn bool_inner<const D: usize>(tensor: BoolTensor<Self, D>)
-> BoolTensor<Self::InnerBackend, D>;
/// Converts the inner backend tensor to the autodiff backend tensor.
///
/// # Arguments
@ -191,4 +214,32 @@ pub trait AutodiffBackend: Backend {
fn from_inner<const D: usize>(
tensor: FloatTensor<Self::InnerBackend, D>,
) -> FloatTensor<Self, D>;
/// Converts the inner backend tensor to the autodiff backend tensor.
///
/// # Arguments
///
/// * `tensor` - The inner backend tensor to convert.
///
///
/// # Returns
///
/// The autodiff backend tensor.
fn int_from_inner<const D: usize>(
tensor: IntTensor<Self::InnerBackend, D>,
) -> IntTensor<Self, D>;
/// Converts the inner backend tensor to the autodiff backend tensor.
///
/// # Arguments
///
/// * `tensor` - The inner backend tensor to convert.
///
///
/// # Returns
///
/// The autodiff backend tensor.
fn bool_from_inner<const D: usize>(
tensor: BoolTensor<Self::InnerBackend, D>,
) -> BoolTensor<Self, D>;
}

View File

@ -11,14 +11,14 @@ type ElemType = f32;
#[cfg(feature = "f16")]
type ElemType = burn::tensor::f16;
pub fn launch<B: AutodiffBackend>(device: B::Device) {
pub fn launch<B: AutodiffBackend>(devices: Vec<B::Device>) {
let config = ExperimentConfig::new(
TransformerEncoderConfig::new(256, 1024, 8, 4).with_norm_first(true),
AdamConfig::new().with_weight_decay(Some(WeightDecayConfig::new(5e-5))),
);
text_classification::training::train::<B, AgNewsDataset>(
device,
devices,
AgNewsDataset::train(),
AgNewsDataset::test(),
config,
@ -39,7 +39,7 @@ mod ndarray {
use crate::{launch, ElemType};
pub fn run() {
launch::<Autodiff<NdArray<ElemType>>>(NdArrayDevice::Cpu);
launch::<Autodiff<NdArray<ElemType>>>(vec![NdArrayDevice::Cpu]);
}
}
@ -56,7 +56,7 @@ mod tch_gpu {
#[cfg(target_os = "macos")]
let device = LibTorchDevice::Mps;
launch::<Autodiff<LibTorch<ElemType>>>(device);
launch::<Autodiff<LibTorch<ElemType>>>(vec![device]);
}
}
@ -68,7 +68,7 @@ mod tch_cpu {
use crate::{launch, ElemType};
pub fn run() {
launch::<Autodiff<LibTorch<ElemType>>>(LibTorchDevice::Cpu);
launch::<Autodiff<LibTorch<ElemType>>>(vec![LibTorchDevice::Cpu]);
}
}
@ -79,7 +79,9 @@ mod wgpu {
use burn::backend::{Autodiff, Fusion};
pub fn run() {
launch::<Autodiff<Fusion<Wgpu<AutoGraphicsApi, ElemType, i32>>>>(WgpuDevice::default());
launch::<Autodiff<Fusion<Wgpu<AutoGraphicsApi, ElemType, i32>>>>(vec![
WgpuDevice::default(),
]);
}
}
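
Because `launch` now takes a `Vec` of devices, a multi-GPU run becomes a one-line change. A hypothetical sketch for the tch backend, assuming the same imports as the `tch_gpu` module above (the CUDA indices are assumptions, not part of the diff):

pub fn run_multi_gpu() {
    // Assumes two CUDA devices are available.
    launch::<Autodiff<LibTorch<ElemType>>>(vec![
        LibTorchDevice::Cuda(0),
        LibTorchDevice::Cuda(1),
    ]);
}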

View File

@ -11,14 +11,14 @@ type ElemType = f32;
#[cfg(feature = "f16")]
type ElemType = burn::tensor::f16;
pub fn launch<B: AutodiffBackend>(device: B::Device) {
pub fn launch<B: AutodiffBackend>(devices: Vec<B::Device>) {
let config = ExperimentConfig::new(
TransformerEncoderConfig::new(256, 1024, 8, 4).with_norm_first(true),
AdamConfig::new().with_weight_decay(Some(WeightDecayConfig::new(5e-5))),
);
text_classification::training::train::<B, DbPediaDataset>(
device,
devices,
DbPediaDataset::train(),
DbPediaDataset::test(),
config,
@ -38,7 +38,7 @@ mod ndarray {
use burn::backend::Autodiff;
pub fn run() {
launch::<Autodiff<NdArray<ElemType>>>(NdArrayDevice::Cpu);
launch::<Autodiff<NdArray<ElemType>>>(vec![NdArrayDevice::Cpu]);
}
}
@ -55,7 +55,7 @@ mod tch_gpu {
#[cfg(target_os = "macos")]
let device = LibTorchDevice::Mps;
launch::<Autodiff<LibTorch<ElemType>>>(device);
launch::<Autodiff<LibTorch<ElemType>>>(vec![device]);
}
}
@ -67,7 +67,7 @@ mod tch_cpu {
use crate::{launch, ElemType};
pub fn run() {
launch::<Autodiff<LibTorch<ElemType>>>(LibTorchDevice::Cpu);
launch::<Autodiff<LibTorch<ElemType>>>(vec![LibTorchDevice::Cpu]);
}
}
@ -79,7 +79,7 @@ mod wgpu {
use crate::{launch, ElemType};
pub fn run() {
launch::<Autodiff<Wgpu<AutoGraphicsApi, ElemType, i32>>>(WgpuDevice::default());
launch::<Autodiff<Wgpu<AutoGraphicsApi, ElemType, i32>>>(vec![WgpuDevice::default()]);
}
}

View File

@ -88,7 +88,7 @@ impl<B: Backend> TextClassificationModel<B> {
pub fn forward(&self, item: TextClassificationTrainingBatch<B>) -> ClassificationOutput<B> {
// Get batch and sequence length, and the device
let [batch_size, seq_length] = item.tokens.dims();
let device = &self.embedding_token.devices()[0];
let device = &self.embedding_token.devices(Vec::new())[0];
// Move tensors to the correct device
let tokens = item.tokens.to_device(device);
@ -128,7 +128,7 @@ impl<B: Backend> TextClassificationModel<B> {
pub fn infer(&self, item: TextClassificationInferenceBatch<B>) -> Tensor<B, 2> {
// Get batch and sequence length, and the device
let [batch_size, seq_length] = item.tokens.dims();
let device = &self.embedding_token.devices()[0];
let device = &self.embedding_token.devices(Vec::new())[0];
// Move tensors to the correct device
let tokens = item.tokens.to_device(device);

View File

@ -40,11 +40,11 @@ pub struct ExperimentConfig {
// Define train function
pub fn train<B: AutodiffBackend, D: TextClassificationDataset + 'static>(
device: B::Device, // Device on which to perform computation (e.g., CPU or CUDA device)
dataset_train: D, // Training dataset
dataset_test: D, // Testing dataset
devices: Vec<B::Device>, // Devices on which to perform computation (e.g., CPU or CUDA devices)
dataset_train: D, // Training dataset
dataset_test: D, // Testing dataset
config: ExperimentConfig, // Experiment configuration
artifact_dir: &str, // Directory to save model and config files
artifact_dir: &str, // Directory to save model and config files
) {
// Initialize tokenizer
let tokenizer = Arc::new(BertCasedTokenizer::default());
@ -52,12 +52,12 @@ pub fn train<B: AutodiffBackend, D: TextClassificationDataset + 'static>(
// Initialize batchers for training and testing data
let batcher_train = TextClassificationBatcher::<B>::new(
tokenizer.clone(),
device.clone(),
devices[0].clone(),
config.max_seq_length,
);
let batcher_test = TextClassificationBatcher::<B::InnerBackend>::new(
tokenizer.clone(),
device.clone(),
devices[0].clone(),
config.max_seq_length,
);
@ -93,13 +93,15 @@ pub fn train<B: AutodiffBackend, D: TextClassificationDataset + 'static>(
let learner = LearnerBuilder::new(artifact_dir)
.metric_train(CUDAMetric::new())
.metric_valid(CUDAMetric::new())
.metric_train(AccuracyMetric::new())
.metric_valid(AccuracyMetric::new())
.metric_train_numeric(AccuracyMetric::new())
.metric_valid_numeric(AccuracyMetric::new())
.metric_train_numeric(LossMetric::new())
.metric_valid_numeric(LossMetric::new())
.metric_train_numeric(LossMetric::new())
.metric_valid_numeric(LossMetric::new())
.metric_train_numeric(LearningRateMetric::new())
.with_file_checkpointer(CompactRecorder::new())
.devices(vec![device])
.devices(devices)
.num_epochs(config.num_epochs)
.build(model, optim, lr_scheduler);

View File

@ -58,7 +58,7 @@ impl<B: Backend> TextGenerationModel<B> {
item: TrainingTextGenerationBatch<B>,
) -> ClassificationOutput<B> {
let [batch_size, seq_length] = item.tokens_inputs.dims();
let device = &self.devices()[0];
let device = &self.devices(Vec::new())[0];
let inputs = item.tokens_inputs.to_device(device);
let targets = item.targets.to_device(device);