mirror of https://github.com/tracel-ai/burn.git
Fix/constant tensors (#984)
* Generalize autodiff tensor * Can have int const module * Update example * Support no-std with burn-import * Fix typos * Fix alloc problems * Revert burn-import changes * Fix examples * Support Int and Bool Params * Fix * Add comment
This commit is contained in:
parent
2f079e991b
commit
cabbaab0c4
|
@ -79,4 +79,28 @@ impl<B: Backend> AutodiffBackend for Autodiff<B> {
|
|||
grads.remove(tensor);
|
||||
grads.register::<B, D>(tensor.node.clone(), grad);
|
||||
}
|
||||
|
||||
fn int_inner<const D: usize>(
|
||||
tensor: burn_tensor::ops::IntTensor<Self, D>,
|
||||
) -> burn_tensor::ops::IntTensor<Self::InnerBackend, D> {
|
||||
tensor
|
||||
}
|
||||
|
||||
fn bool_inner<const D: usize>(
|
||||
tensor: burn_tensor::ops::BoolTensor<Self, D>,
|
||||
) -> burn_tensor::ops::BoolTensor<Self::InnerBackend, D> {
|
||||
tensor
|
||||
}
|
||||
|
||||
fn int_from_inner<const D: usize>(
|
||||
tensor: burn_tensor::ops::IntTensor<Self::InnerBackend, D>,
|
||||
) -> burn_tensor::ops::IntTensor<Self, D> {
|
||||
tensor
|
||||
}
|
||||
|
||||
fn bool_from_inner<const D: usize>(
|
||||
tensor: burn_tensor::ops::BoolTensor<Self::InnerBackend, D>,
|
||||
) -> burn_tensor::ops::BoolTensor<Self, D> {
|
||||
tensor
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,12 +1,15 @@
|
|||
use alloc::vec::Vec;
|
||||
|
||||
use super::ParamId;
|
||||
use crate::{
|
||||
record::Record,
|
||||
tensor::backend::{AutodiffBackend, Backend},
|
||||
};
|
||||
use alloc::vec::Vec;
|
||||
pub use burn_derive::Module;
|
||||
use burn_tensor::Tensor;
|
||||
use burn_tensor::{Bool, Int, Tensor};
|
||||
|
||||
/// Type alias to `Vec<B::Device>` which supports `no_std` environements, but automatically using
|
||||
/// the `alloc` crate.
|
||||
pub type Devices<B> = Vec<<B as Backend>::Device>;
|
||||
|
||||
// At the moment, our plan is to continue experimenting with the macro internally and monitor its development.
|
||||
// We may consider making it public in the future.
|
||||
|
@ -14,7 +17,11 @@ macro_rules! module {
|
|||
(map=$module:ident, ops=$item:expr) => {{
|
||||
struct Mapper;
|
||||
impl<B: Backend> ModuleMapper<B> for Mapper {
|
||||
fn map<const D: usize>(&mut self, _id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
|
||||
fn map_float<const D: usize>(
|
||||
&mut self,
|
||||
_id: &ParamId,
|
||||
tensor: Tensor<B, D>,
|
||||
) -> Tensor<B, D> {
|
||||
let func = $item;
|
||||
func(tensor)
|
||||
}
|
||||
|
@ -22,30 +29,13 @@ macro_rules! module {
|
|||
let mut mapper = Mapper;
|
||||
$module.map(&mut mapper)
|
||||
}};
|
||||
(map=$module:ident, ops=$item:expr, capture={$capture:ident: $ty:ty}) => {{
|
||||
struct Mapper<'a, B: Backend> {
|
||||
capture: &'a $ty,
|
||||
backend: core::marker::PhantomData<B>,
|
||||
}
|
||||
impl<'a, B: Backend> ModuleMapper<B> for Mapper<'a, B> {
|
||||
fn map<const D: usize>(&mut self, _id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
|
||||
let func = $item;
|
||||
func(tensor, self.capture)
|
||||
}
|
||||
}
|
||||
let mut mapper = Mapper {
|
||||
capture: $capture,
|
||||
backend: core::marker::PhantomData,
|
||||
};
|
||||
$module.map(&mut mapper)
|
||||
}};
|
||||
(visit=$module:ident, ops=$item:expr, state=$state_ty:ty, init=$init:expr) => {{
|
||||
(visit_float=$module:ident, ops=$item:expr, state=$state_ty:ty, init=$init:expr) => {{
|
||||
struct Visitor<'a, B: Backend> {
|
||||
state: &'a mut $state_ty,
|
||||
backend: core::marker::PhantomData<B>,
|
||||
}
|
||||
impl<'a, B: Backend> ModuleVisitor<B> for Visitor<'a, B> {
|
||||
fn visit<const D: usize>(&mut self, _id: &ParamId, tensor: &Tensor<B, D>) {
|
||||
fn visit_float<const D: usize>(&mut self, _id: &ParamId, tensor: &Tensor<B, D>) {
|
||||
let func = $item;
|
||||
func(tensor, &mut self.state)
|
||||
}
|
||||
|
@ -94,20 +84,9 @@ pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
|
|||
/// Type to save and load the module.
|
||||
type Record: Record;
|
||||
|
||||
/// Get the device list of the module and all of its sub-modules.
|
||||
fn devices(&self) -> Vec<B::Device> {
|
||||
module!(
|
||||
visit = self,
|
||||
ops = |tensor: &Tensor<B, D>, state: &mut Vec<B::Device>| {
|
||||
let device = tensor.device();
|
||||
if !state.contains(&device) {
|
||||
state.push(device);
|
||||
}
|
||||
},
|
||||
state = Vec<B::Device>,
|
||||
init = Vec::new
|
||||
)
|
||||
}
|
||||
/// Collects devices in the given vector and returns it with the devices found in the module
|
||||
/// structure without duplicates.
|
||||
fn devices(&self, devices: Devices<B>) -> Devices<B>;
|
||||
|
||||
/// Fork the module and all of its sub-modules to the given device.
|
||||
///
|
||||
|
@ -115,22 +94,7 @@ pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
|
|||
///
|
||||
/// This is similar to [to_device](Module::to_device), but it ensures the module will
|
||||
/// have its own autodiff graph.
|
||||
fn fork(self, device: &B::Device) -> Self {
|
||||
module!(
|
||||
map = self,
|
||||
ops = |tensor: Tensor<B, D>, device: &B::Device| {
|
||||
let is_require_grad = tensor.is_require_grad();
|
||||
let mut tensor = tensor.to_device(device).detach();
|
||||
|
||||
if is_require_grad {
|
||||
tensor = tensor.require_grad();
|
||||
}
|
||||
|
||||
tensor
|
||||
},
|
||||
capture = { device: B::Device }
|
||||
)
|
||||
}
|
||||
fn fork(self, device: &B::Device) -> Self;
|
||||
|
||||
/// Move the module and all of its sub-modules to the given device.
|
||||
///
|
||||
|
@ -139,13 +103,7 @@ pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
|
|||
/// The device operations will be registered in the autodiff graph. Therefore, be sure to call
|
||||
/// backward only one time even if you have the same module on multiple devices. If you want to
|
||||
/// call backward multiple times, look into using [fork](Module::fork) instead.
|
||||
fn to_device(self, device: &B::Device) -> Self {
|
||||
module!(
|
||||
map = self,
|
||||
ops = |tensor: Tensor<B, D>, device: &B::Device| tensor.to_device(device),
|
||||
capture = { device: B::Device }
|
||||
)
|
||||
}
|
||||
fn to_device(self, device: &B::Device) -> Self;
|
||||
|
||||
/// Each tensor in the module tree will not require grad.
|
||||
///
|
||||
|
@ -164,7 +122,7 @@ pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
|
|||
/// Get the number of parameters the module has, including all of its sub-modules.
|
||||
fn num_params(&self) -> usize {
|
||||
module!(
|
||||
visit = self,
|
||||
visit_float = self,
|
||||
ops = |tensor: &Tensor<B, D>, state: &mut usize| {
|
||||
*state += tensor.shape().num_elements();
|
||||
},
|
||||
|
@ -172,10 +130,10 @@ pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
|
|||
init = || 0
|
||||
)
|
||||
}
|
||||
/// Visit each tensor in the module with a [visitor](ModuleVisitor).
|
||||
/// Visit each tensor parameter in the module with a [visitor](ModuleVisitor).
|
||||
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V);
|
||||
|
||||
/// Map each tensor in the module with a [mapper](ModuleMapper).
|
||||
/// Map each tensor parameter in the module with a [mapper](ModuleMapper).
|
||||
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self;
|
||||
|
||||
/// Load the module state from a record.
|
||||
|
@ -233,14 +191,36 @@ pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
|
|||
|
||||
/// Module visitor trait.
|
||||
pub trait ModuleVisitor<B: Backend> {
|
||||
/// Visit a tensor in the module.
|
||||
fn visit<const D: usize>(&mut self, id: &ParamId, tensor: &Tensor<B, D>);
|
||||
/// Visit a float tensor in the module.
|
||||
fn visit_float<const D: usize>(&mut self, _id: &ParamId, _tensor: &Tensor<B, D>) {}
|
||||
/// Visit an int tensor in the module.
|
||||
fn visit_int<const D: usize>(&mut self, _id: &ParamId, _tensor: &Tensor<B, D, Int>) {}
|
||||
/// Visit a bool tensor in the module.
|
||||
fn visit_bool<const D: usize>(&mut self, _id: &ParamId, _tensor: &Tensor<B, D, Bool>) {}
|
||||
}
|
||||
|
||||
/// Module mapper trait.
|
||||
pub trait ModuleMapper<B: Backend> {
|
||||
/// Map a tensor in the module.
|
||||
fn map<const D: usize>(&mut self, id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D>;
|
||||
/// Map a float tensor in the module.
|
||||
fn map_float<const D: usize>(&mut self, _id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
|
||||
tensor
|
||||
}
|
||||
/// Map an int tensor in the module.
|
||||
fn map_int<const D: usize>(
|
||||
&mut self,
|
||||
_id: &ParamId,
|
||||
tensor: Tensor<B, D, Int>,
|
||||
) -> Tensor<B, D, Int> {
|
||||
tensor
|
||||
}
|
||||
/// Map a bool tensor in the module.
|
||||
fn map_bool<const D: usize>(
|
||||
&mut self,
|
||||
_id: &ParamId,
|
||||
tensor: Tensor<B, D, Bool>,
|
||||
) -> Tensor<B, D, Bool> {
|
||||
tensor
|
||||
}
|
||||
}
|
||||
|
||||
/// Module with auto-differentiation backend.
|
||||
|
|
|
@ -1,17 +1,14 @@
|
|||
use core::marker::PhantomData;
|
||||
|
||||
use crate::{
|
||||
self as burn,
|
||||
module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor},
|
||||
module::{AutodiffModule, Devices, Module, ModuleMapper, ModuleVisitor},
|
||||
record::Record,
|
||||
};
|
||||
use burn::record::PrecisionSettings;
|
||||
use burn_tensor::{
|
||||
backend::{AutodiffBackend, Backend},
|
||||
Tensor,
|
||||
BasicAutodiffOps, BasicOps, Tensor,
|
||||
};
|
||||
|
||||
use super::ParamId;
|
||||
use core::marker::PhantomData;
|
||||
|
||||
/// Record used for constant type implementing the [module](crate::module::Module) trait.
|
||||
#[derive(Debug, Clone, Copy, new)]
|
||||
|
@ -69,6 +66,18 @@ macro_rules! constant {
|
|||
fn into_record(self) -> Self::Record {
|
||||
burn::module::ConstantRecord::new()
|
||||
}
|
||||
|
||||
fn to_device(self, _: &B::Device) -> Self {
|
||||
self
|
||||
}
|
||||
|
||||
fn fork(self, _: &B::Device) -> Self {
|
||||
self
|
||||
}
|
||||
|
||||
fn devices(&self, devices: burn::module::Devices<B>) -> burn::module::Devices<B> {
|
||||
devices
|
||||
}
|
||||
};
|
||||
|
||||
(ad_module, $type:ty) => {
|
||||
|
@ -113,27 +122,13 @@ constant!(i32);
|
|||
constant!(i16);
|
||||
constant!(i8);
|
||||
|
||||
impl<const D: usize, B: Backend> Module<B> for Tensor<B, D> {
|
||||
impl<const D: usize, B: Backend, K: BasicOps<B>> Module<B> for Tensor<B, D, K> {
|
||||
type Record = ConstantRecord;
|
||||
|
||||
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
|
||||
// Important:
|
||||
// We need to implement visit method for Tensor Module because
|
||||
// to_device will be called during the visit method of the ModuleVisitor
|
||||
fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V) {}
|
||||
|
||||
// We are using a dummy param id because the visit method requires a param id
|
||||
let dummy_param_id = ParamId::new();
|
||||
visitor.visit(&dummy_param_id, self)
|
||||
}
|
||||
|
||||
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
|
||||
// Important:
|
||||
// We need to implement visit method for Tensor Module because
|
||||
// to_device will be called during the visit method of the ModuleVisitor
|
||||
|
||||
// We are using a dummy param id because the visit method requires a param id
|
||||
let dummy_param_id = ParamId::new();
|
||||
mapper.map(&dummy_param_id, self)
|
||||
fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
|
||||
self
|
||||
}
|
||||
|
||||
fn into_record(self) -> Self::Record {
|
||||
|
@ -143,10 +138,30 @@ impl<const D: usize, B: Backend> Module<B> for Tensor<B, D> {
|
|||
fn load_record(self, _record: Self::Record) -> Self {
|
||||
self
|
||||
}
|
||||
|
||||
fn to_device(self, device: &B::Device) -> Self {
|
||||
self.to_device(device)
|
||||
}
|
||||
|
||||
impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Tensor<B, D> {
|
||||
type InnerModule = Tensor<B::InnerBackend, D>;
|
||||
fn fork(self, device: &B::Device) -> Self {
|
||||
self.to_device(device)
|
||||
}
|
||||
|
||||
fn devices(&self, mut devices: Devices<B>) -> Devices<B> {
|
||||
let device = self.device();
|
||||
|
||||
if !devices.contains(&device) {
|
||||
devices.push(device)
|
||||
}
|
||||
|
||||
devices
|
||||
}
|
||||
}
|
||||
|
||||
impl<const D: usize, B: AutodiffBackend, K: BasicAutodiffOps<B>> AutodiffModule<B>
|
||||
for Tensor<B, D, K>
|
||||
{
|
||||
type InnerModule = Tensor<B::InnerBackend, D, K::InnerKind>;
|
||||
|
||||
fn valid(&self) -> Self::InnerModule {
|
||||
self.clone().inner()
|
||||
|
@ -171,6 +186,18 @@ impl<B: Backend> Module<B> for PhantomData<B> {
|
|||
fn into_record(self) -> Self::Record {
|
||||
ConstantRecord::new()
|
||||
}
|
||||
|
||||
fn to_device(self, _: &<B as Backend>::Device) -> Self {
|
||||
self
|
||||
}
|
||||
|
||||
fn fork(self, _: &<B as Backend>::Device) -> Self {
|
||||
self
|
||||
}
|
||||
|
||||
fn devices(&self, devices: Devices<B>) -> Devices<B> {
|
||||
devices
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: AutodiffBackend> AutodiffModule<B> for PhantomData<B> {
|
||||
|
|
|
@ -28,6 +28,22 @@ where
|
|||
fn into_record(self) -> Self::Record {
|
||||
self.map(Module::into_record)
|
||||
}
|
||||
|
||||
fn to_device(self, device: &<B as Backend>::Device) -> Self {
|
||||
self.map(|module| module.to_device(device))
|
||||
}
|
||||
|
||||
fn fork(self, device: &<B as Backend>::Device) -> Self {
|
||||
self.map(|module| module.fork(device))
|
||||
}
|
||||
|
||||
fn devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
|
||||
if let Some(module) = self.as_ref() {
|
||||
devices = module.devices(devices);
|
||||
}
|
||||
|
||||
devices
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, B> AutodiffModule<B> for Option<T>
|
||||
|
@ -78,6 +94,24 @@ where
|
|||
.map(|(module, record)| module.load_record(record))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn to_device(self, device: &<B as Backend>::Device) -> Self {
|
||||
self.into_iter()
|
||||
.map(|module| module.to_device(device))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn fork(self, device: &<B as Backend>::Device) -> Self {
|
||||
self.into_iter().map(|module| module.fork(device)).collect()
|
||||
}
|
||||
|
||||
fn devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
|
||||
for module in self.iter() {
|
||||
devices = module.devices(devices);
|
||||
}
|
||||
|
||||
devices
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, B> AutodiffModule<B> for Vec<T>
|
||||
|
@ -100,11 +134,11 @@ where
|
|||
{
|
||||
type Record = [T::Record; N];
|
||||
|
||||
fn devices(&self) -> Vec<<B as burn_tensor::backend::Backend>::Device> {
|
||||
let mut devices = Vec::new();
|
||||
fn devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
|
||||
for module in self.iter() {
|
||||
devices.append(&mut module.devices());
|
||||
devices = module.devices(devices);
|
||||
}
|
||||
|
||||
devices
|
||||
}
|
||||
|
||||
|
@ -139,6 +173,14 @@ where
|
|||
fn into_record(self) -> Self::Record {
|
||||
self.map(Module::into_record)
|
||||
}
|
||||
|
||||
fn to_device(self, device: &<B as Backend>::Device) -> Self {
|
||||
self.map(|module| module.to_device(device))
|
||||
}
|
||||
|
||||
fn fork(self, device: &<B as Backend>::Device) -> Self {
|
||||
self.map(|module| module.fork(device))
|
||||
}
|
||||
}
|
||||
|
||||
impl<const N: usize, T, B> AutodiffModule<B> for [T; N]
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use alloc::sync::Arc;
|
||||
|
||||
use super::ParamId;
|
||||
use crate::module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor, Param};
|
||||
use alloc::sync::Arc;
|
||||
use alloc::vec::Vec;
|
||||
use burn_tensor::{
|
||||
backend::{AutodiffBackend, Backend},
|
||||
Tensor,
|
||||
|
@ -51,12 +51,12 @@ impl<const D: usize, B: Backend> Module<B> for RunningState<Tensor<B, D>> {
|
|||
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
|
||||
let tensor = self.value.read().unwrap();
|
||||
|
||||
visitor.visit(&self.id, &tensor)
|
||||
visitor.visit_float(&self.id, &tensor)
|
||||
}
|
||||
|
||||
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
|
||||
let mut tensor = self.value.write().unwrap();
|
||||
let tensor_out = mapper.map(&self.id, tensor.clone());
|
||||
let tensor_out = mapper.map_float(&self.id, tensor.clone());
|
||||
|
||||
*tensor = tensor_out;
|
||||
core::mem::drop(tensor);
|
||||
|
@ -80,6 +80,30 @@ impl<const D: usize, B: Backend> Module<B> for RunningState<Tensor<B, D>> {
|
|||
|
||||
self
|
||||
}
|
||||
|
||||
fn to_device(self, device: &<B as Backend>::Device) -> Self {
|
||||
let mut tensor = self.value.write().unwrap();
|
||||
let tensor_out = tensor.clone().to_device(device);
|
||||
|
||||
*tensor = tensor_out;
|
||||
core::mem::drop(tensor);
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
fn fork(self, device: &<B as Backend>::Device) -> Self {
|
||||
self.to_device(device) // Same thing here since no grad.
|
||||
}
|
||||
|
||||
fn devices(&self, mut devices: Vec<<B as Backend>::Device>) -> Vec<<B as Backend>::Device> {
|
||||
let device = self.value.read().unwrap().device();
|
||||
|
||||
if !devices.contains(&device) {
|
||||
devices.push(device)
|
||||
}
|
||||
|
||||
devices
|
||||
}
|
||||
}
|
||||
|
||||
impl<const D: usize, B: Backend> RunningState<Tensor<B, D>> {
|
||||
|
|
|
@ -4,22 +4,38 @@ use crate::tensor::{
|
|||
backend::{AutodiffBackend, Backend},
|
||||
Tensor,
|
||||
};
|
||||
use alloc::vec::Vec;
|
||||
use burn_tensor::{Bool, Int};
|
||||
|
||||
impl<B: Backend, const D: usize> From<Tensor<B, D>> for Param<Tensor<B, D>> {
|
||||
fn from(value: Tensor<B, D>) -> Self {
|
||||
// When creating a parameter from a float tensor, we automatically mark it as requiring
|
||||
// gradients, so that it can be updated by an optimizer.
|
||||
Param::new(ParamId::new(), value.require_grad())
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> From<Tensor<B, D, Int>> for Param<Tensor<B, D, Int>> {
|
||||
fn from(value: Tensor<B, D, Int>) -> Self {
|
||||
Param::new(ParamId::new(), value)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> From<Tensor<B, D, Bool>> for Param<Tensor<B, D, Bool>> {
|
||||
fn from(value: Tensor<B, D, Bool>) -> Self {
|
||||
Param::new(ParamId::new(), value)
|
||||
}
|
||||
}
|
||||
|
||||
impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D>> {
|
||||
type Record = Param<Tensor<B, D>>;
|
||||
|
||||
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
|
||||
visitor.visit(&self.id, &self.value)
|
||||
visitor.visit_float(&self.id, &self.value)
|
||||
}
|
||||
|
||||
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
|
||||
let value = mapper.map(&self.id, self.value);
|
||||
let value = mapper.map_float(&self.id, self.value);
|
||||
Self::new(self.id, value)
|
||||
}
|
||||
|
||||
|
@ -41,6 +57,127 @@ impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D>> {
|
|||
|
||||
Self::new(record.id, tensor)
|
||||
}
|
||||
|
||||
fn to_device(self, device: &<B as Backend>::Device) -> Self {
|
||||
self.map(|tensor| tensor.to_device(device))
|
||||
}
|
||||
|
||||
fn fork(self, device: &<B as Backend>::Device) -> Self {
|
||||
self.map(|tensor| {
|
||||
let is_require_grad = tensor.is_require_grad();
|
||||
let mut tensor = tensor.to_device(device).detach();
|
||||
|
||||
if is_require_grad {
|
||||
tensor = tensor.require_grad();
|
||||
}
|
||||
|
||||
tensor
|
||||
})
|
||||
}
|
||||
|
||||
fn devices(&self, mut devices: Vec<<B as Backend>::Device>) -> Vec<<B as Backend>::Device> {
|
||||
let device = self.device();
|
||||
|
||||
if !devices.contains(&device) {
|
||||
devices.push(device)
|
||||
}
|
||||
|
||||
devices
|
||||
}
|
||||
}
|
||||
|
||||
impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Int>> {
|
||||
type Record = Param<Tensor<B, D, Int>>;
|
||||
|
||||
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
|
||||
visitor.visit_int(&self.id, &self.value)
|
||||
}
|
||||
|
||||
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
|
||||
let value = mapper.map_int(&self.id, self.value);
|
||||
Self::new(self.id, value)
|
||||
}
|
||||
|
||||
fn into_record(self) -> Self::Record {
|
||||
self
|
||||
}
|
||||
|
||||
fn load_record(self, record: Self::Record) -> Self {
|
||||
let mut tensor = record.value;
|
||||
let device = self.device();
|
||||
|
||||
// Make sure we load the record into the same module device.
|
||||
if tensor.device() != device {
|
||||
tensor = tensor.to_device(&device);
|
||||
}
|
||||
|
||||
Self::new(record.id, tensor)
|
||||
}
|
||||
|
||||
fn to_device(self, device: &<B as Backend>::Device) -> Self {
|
||||
self.map(|tensor| tensor.to_device(device))
|
||||
}
|
||||
|
||||
fn fork(self, device: &<B as Backend>::Device) -> Self {
|
||||
self.to_device(device) // Don't support autodiff.
|
||||
}
|
||||
|
||||
fn devices(&self, mut devices: Vec<<B as Backend>::Device>) -> Vec<<B as Backend>::Device> {
|
||||
let device = self.device();
|
||||
|
||||
if !devices.contains(&device) {
|
||||
devices.push(device)
|
||||
}
|
||||
|
||||
devices
|
||||
}
|
||||
}
|
||||
|
||||
impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Bool>> {
|
||||
type Record = Param<Tensor<B, D, Bool>>;
|
||||
|
||||
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
|
||||
visitor.visit_bool(&self.id, &self.value)
|
||||
}
|
||||
|
||||
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
|
||||
let value = mapper.map_bool(&self.id, self.value);
|
||||
Self::new(self.id, value)
|
||||
}
|
||||
|
||||
fn into_record(self) -> Self::Record {
|
||||
self
|
||||
}
|
||||
|
||||
fn load_record(self, record: Self::Record) -> Self {
|
||||
let mut tensor = record.value;
|
||||
let device = self.device();
|
||||
|
||||
// Make sure we load the record into the same module device.
|
||||
if tensor.device() != device {
|
||||
tensor = tensor.to_device(&device);
|
||||
}
|
||||
|
||||
Self::new(record.id, tensor)
|
||||
}
|
||||
|
||||
fn to_device(self, device: &<B as Backend>::Device) -> Self {
|
||||
self.map(|tensor| tensor.to_device(device))
|
||||
}
|
||||
|
||||
fn fork(self, device: &<B as Backend>::Device) -> Self {
|
||||
self.to_device(device) // Don't support autodiff.
|
||||
}
|
||||
|
||||
fn devices(&self, mut devices: Vec<<B as Backend>::Device>) -> Vec<<B as Backend>::Device> {
|
||||
let device = self.device();
|
||||
|
||||
if !devices.contains(&device) {
|
||||
devices.push(device)
|
||||
}
|
||||
|
||||
devices
|
||||
}
|
||||
}
|
||||
|
||||
impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D>> {
|
||||
|
@ -54,6 +191,22 @@ impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D
|
|||
}
|
||||
}
|
||||
|
||||
impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D, Int>> {
|
||||
type InnerModule = Param<Tensor<B::InnerBackend, D, Int>>;
|
||||
|
||||
fn valid(&self) -> Self::InnerModule {
|
||||
Param::new(self.id.clone(), self.value.clone().inner())
|
||||
}
|
||||
}
|
||||
|
||||
impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D, Bool>> {
|
||||
type InnerModule = Param<Tensor<B::InnerBackend, D, Bool>>;
|
||||
|
||||
fn valid(&self) -> Self::InnerModule {
|
||||
Param::new(self.id.clone(), self.value.clone().inner())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(test, feature = "std"))]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use super::ParamId;
|
||||
use crate::module::{Module, ModuleVisitor};
|
||||
use alloc::vec::Vec;
|
||||
use burn_tensor::{backend::Backend, Tensor};
|
||||
use burn_tensor::{backend::Backend, Bool, Int, Tensor};
|
||||
use core::marker::PhantomData;
|
||||
|
||||
struct ParamIdCollector<'a, M> {
|
||||
|
@ -14,7 +14,13 @@ where
|
|||
B: Backend,
|
||||
M: Module<B>,
|
||||
{
|
||||
fn visit<const D: usize>(&mut self, id: &ParamId, _tensor: &Tensor<B, D>) {
|
||||
fn visit_float<const D: usize>(&mut self, id: &ParamId, _tensor: &Tensor<B, D>) {
|
||||
self.param_ids.push(id.clone());
|
||||
}
|
||||
fn visit_int<const D: usize>(&mut self, id: &ParamId, _tensor: &Tensor<B, D, Int>) {
|
||||
self.param_ids.push(id.clone());
|
||||
}
|
||||
fn visit_bool<const D: usize>(&mut self, id: &ParamId, _tensor: &Tensor<B, D, Bool>) {
|
||||
self.param_ids.push(id.clone());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -57,7 +57,7 @@ struct ModuleGradsAccumulator<'a, M> {
|
|||
impl<'a, B: AutodiffBackend, M: AutodiffModule<B>> ModuleVisitor<B>
|
||||
for ModuleGradsAccumulator<'a, M>
|
||||
{
|
||||
fn visit<const D: usize>(&mut self, id: &ParamId, _tensor: &Tensor<B, D>) {
|
||||
fn visit_float<const D: usize>(&mut self, id: &ParamId, _tensor: &Tensor<B, D>) {
|
||||
let grad_updated = match self.grads_new.remove::<B::InnerBackend, D>(id) {
|
||||
Some(new) => match self.grads.remove::<B::InnerBackend, D>(id) {
|
||||
Some(grad) => grad.add(new),
|
||||
|
|
|
@ -115,7 +115,7 @@ where
|
|||
B: AutodiffBackend,
|
||||
O: SimpleOptimizer<B::InnerBackend>,
|
||||
{
|
||||
fn map<const D: usize>(&mut self, id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
|
||||
fn map_float<const D: usize>(&mut self, id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
|
||||
let grad = self.grads.remove(id);
|
||||
|
||||
if let Some(grad) = grad {
|
||||
|
|
|
@ -22,7 +22,7 @@ where
|
|||
B: AutodiffBackend,
|
||||
M: AutodiffModule<B>,
|
||||
{
|
||||
fn visit<const D: usize>(&mut self, id: &ParamId, tensor: &Tensor<B, D>) {
|
||||
fn visit_float<const D: usize>(&mut self, id: &ParamId, tensor: &Tensor<B, D>) {
|
||||
if let Some(grad) = tensor.grad_remove(&mut self.grads) {
|
||||
self.grads_params
|
||||
.register::<B::InnerBackend, D>(id.clone(), grad);
|
||||
|
@ -35,7 +35,7 @@ where
|
|||
B: AutodiffBackend,
|
||||
M: AutodiffModule<B>,
|
||||
{
|
||||
fn visit<const D: usize>(&mut self, id: &ParamId, _tensor: &Tensor<B, D>) {
|
||||
fn visit_float<const D: usize>(&mut self, id: &ParamId, _tensor: &Tensor<B, D>) {
|
||||
if let Some(grad) = self.grads.remove::<B::InnerBackend, D>(id) {
|
||||
self.grads
|
||||
.register::<B::InnerBackend, D>(id.clone(), grad.to_device(self.device));
|
||||
|
|
|
@ -2,11 +2,15 @@ use alloc::string::String;
|
|||
use alloc::string::ToString;
|
||||
use alloc::vec::Vec;
|
||||
use burn_tensor::backend::Backend;
|
||||
use burn_tensor::Bool;
|
||||
use burn_tensor::Int;
|
||||
use burn_tensor::Tensor;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
|
||||
use super::tensor::BoolTensorSerde;
|
||||
use super::tensor::FloatTensorSerde;
|
||||
use super::tensor::IntTensorSerde;
|
||||
use super::{PrecisionSettings, Record};
|
||||
use crate::module::{Param, ParamId};
|
||||
use burn_tensor::{DataSerialize, Element};
|
||||
|
@ -115,6 +119,30 @@ impl<B: Backend, const D: usize> Record for Param<Tensor<B, D>> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> Record for Param<Tensor<B, D, Int>> {
|
||||
type Item<S: PrecisionSettings> = ParamSerde<IntTensorSerde<S>>;
|
||||
|
||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
|
||||
ParamSerde::new(self.id.into_string(), self.value.into_item())
|
||||
}
|
||||
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
|
||||
Param::new(ParamId::from(item.id), Tensor::from_item(item.param))
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> Record for Param<Tensor<B, D, Bool>> {
|
||||
type Item<S: PrecisionSettings> = ParamSerde<BoolTensorSerde>;
|
||||
|
||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
|
||||
ParamSerde::new(self.id.into_string(), self.value.into_item::<S>())
|
||||
}
|
||||
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
|
||||
Param::new(ParamId::from(item.id), Tensor::from_item::<S>(item.param))
|
||||
}
|
||||
}
|
||||
|
||||
// Type that can be serialized as is without any conversion.
|
||||
macro_rules! primitive {
|
||||
($type:ty) => {
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use burn::module::{Module, Param};
|
||||
use burn::tensor::backend::Backend;
|
||||
use burn::tensor::{Distribution, Shape, Tensor};
|
||||
use burn::tensor::{Distribution, Int, Shape, Tensor};
|
||||
use burn_core as burn;
|
||||
|
||||
pub type TestBackend = burn_ndarray::NdArray<f32>;
|
||||
|
@ -12,6 +12,11 @@ pub struct ModuleBasic<B: Backend> {
|
|||
weight_basic: Param<Tensor<B, 2>>,
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
struct ModuleTensorConstInt<B: Backend> {
|
||||
weight_basic: Tensor<B, 2, Int>,
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleBasic<B> {
|
||||
fn new() -> Self {
|
||||
let weight_basic = Tensor::random(Shape::new([20, 20]), Distribution::Default);
|
||||
|
|
|
@ -29,6 +29,9 @@ pub(crate) fn derive_impl(ast: &syn::DeriveInput) -> TokenStream {
|
|||
let num_params_fn = generator.gen_num_params();
|
||||
let visit = generator.gen_visit();
|
||||
let map_mut = generator.gen_map();
|
||||
let devices = generator.gen_devices();
|
||||
let to_device = generator.gen_to_device();
|
||||
let fork = generator.gen_fork();
|
||||
let valid_fn = generator.gen_valid();
|
||||
let into_record_fn = generator.gen_into_record();
|
||||
let load_record_fn = generator.gen_load_record();
|
||||
|
@ -50,6 +53,10 @@ pub(crate) fn derive_impl(ast: &syn::DeriveInput) -> TokenStream {
|
|||
|
||||
#visit
|
||||
#map_mut
|
||||
|
||||
#devices
|
||||
#to_device
|
||||
#fork
|
||||
}
|
||||
|
||||
impl #generics burn::module::AutodiffModule<B> for #name #generics_ty
|
||||
|
|
|
@ -4,6 +4,9 @@ use proc_macro2::TokenStream;
|
|||
pub(crate) trait ModuleCodegen {
|
||||
fn gen_num_params(&self) -> TokenStream;
|
||||
fn gen_visit(&self) -> TokenStream;
|
||||
fn gen_devices(&self) -> TokenStream;
|
||||
fn gen_to_device(&self) -> TokenStream;
|
||||
fn gen_fork(&self) -> TokenStream;
|
||||
fn gen_map(&self) -> TokenStream;
|
||||
fn gen_valid(&self) -> TokenStream;
|
||||
fn gen_into_record(&self) -> TokenStream;
|
||||
|
|
|
@ -39,10 +39,62 @@ impl ModuleCodegen for StructModuleCodegen {
|
|||
}
|
||||
}
|
||||
|
||||
fn gen_devices(&self) -> TokenStream {
|
||||
let body = self.gen_fields_fn(|name| {
|
||||
quote! {
|
||||
let devices = burn::module::Module::<B>::devices(&self.#name, devices);
|
||||
}
|
||||
});
|
||||
|
||||
quote! {
|
||||
fn devices(&self, devices: burn::module::Devices<B>) -> burn::module::Devices<B> {
|
||||
#body
|
||||
|
||||
devices
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_to_device(&self) -> TokenStream {
|
||||
let (names, body) = self.gen_fields_fn_names(|name| {
|
||||
quote! {
|
||||
let #name = burn::module::Module::<B>::to_device(self.#name, device);
|
||||
}
|
||||
});
|
||||
|
||||
quote! {
|
||||
fn to_device(self, device: &B::Device) -> Self {
|
||||
#body
|
||||
|
||||
Self {
|
||||
#(#names),*
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_fork(&self) -> TokenStream {
|
||||
let (names, body) = self.gen_fields_fn_names(|name| {
|
||||
quote! {
|
||||
let #name = burn::module::Module::<B>::fork(self.#name, device);
|
||||
}
|
||||
});
|
||||
|
||||
quote! {
|
||||
fn fork(self, device: &B::Device) -> Self {
|
||||
#body
|
||||
|
||||
Self {
|
||||
#(#names),*
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_map(&self) -> TokenStream {
|
||||
let (names, body) = self.gen_fields_fn_names(|name| {
|
||||
quote! {
|
||||
let #name = burn::module::Module::map(self.#name, mapper);
|
||||
let #name = burn::module::Module::<B>::map(self.#name, mapper);
|
||||
}
|
||||
});
|
||||
|
||||
|
|
|
@ -0,0 +1,134 @@
|
|||
use crate::{backend::AutodiffBackend, BasicOps, Bool, Float, Int, Tensor, TensorKind};
|
||||
|
||||
impl<const D: usize, B: AutodiffBackend> Tensor<B, D> {
|
||||
/// Backward pass of the tensor.
|
||||
pub fn backward(&self) -> B::Gradients {
|
||||
B::backward::<D>(self.primitive.clone())
|
||||
}
|
||||
|
||||
/// Get the gradients of a tensor if it exist.
|
||||
///
|
||||
/// Returns a new reference to the same tensor. Therefore the same grad tensor can
|
||||
/// be accessed multiple times. If you only need to get the gradients one time,
|
||||
/// consider using [grad_remove](Tensor::grad_remove) for better performance.
|
||||
pub fn grad(&self, grads: &B::Gradients) -> Option<Tensor<B::InnerBackend, D>> {
|
||||
B::grad(&self.primitive, grads).map(Tensor::new)
|
||||
}
|
||||
|
||||
/// Remove the grad tensor from the [grads](AutodiffBackend::Gradients) struct returning the result.
|
||||
pub fn grad_remove(&self, grads: &mut B::Gradients) -> Option<Tensor<B::InnerBackend, D>> {
|
||||
B::grad_remove(&self.primitive, grads).map(Tensor::new)
|
||||
}
|
||||
|
||||
/// Replace the grad tensor from the [grads](AutodiffBackend::Gradients) struct with the provided
|
||||
/// gradient.
|
||||
pub fn grad_replace(&self, grads: &mut B::Gradients, grad: Tensor<B::InnerBackend, D>) {
|
||||
B::grad_replace(&self.primitive, grads, grad.primitive);
|
||||
}
|
||||
}
|
||||
|
||||
impl<const D: usize, B: AutodiffBackend, K: BasicAutodiffOps<B>> Tensor<B, D, K> {
|
||||
/// Returns the inner tensor without the autodiff information.
|
||||
pub fn inner(self) -> Tensor<B::InnerBackend, D, K::InnerKind> {
|
||||
Tensor::new(K::inner(self.primitive))
|
||||
}
|
||||
|
||||
/// Convert a tensor to the autodiff backend.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `inner` - The tensor to convert.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor converted to the autodiff backend.
|
||||
pub fn from_inner(inner: Tensor<B::InnerBackend, D, K::InnerKind>) -> Self {
|
||||
Self::new(K::from_inner(inner.primitive))
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: AutodiffBackend> BasicAutodiffOps<B> for Float {
|
||||
type InnerKind = Float;
|
||||
|
||||
fn inner<const D: usize>(
|
||||
tensor: <Self as TensorKind<B>>::Primitive<D>,
|
||||
) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive<D> {
|
||||
B::inner(tensor)
|
||||
}
|
||||
|
||||
fn from_inner<const D: usize>(
|
||||
inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive<D>,
|
||||
) -> <Self as TensorKind<B>>::Primitive<D> {
|
||||
B::from_inner(inner)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: AutodiffBackend> BasicAutodiffOps<B> for Int {
|
||||
type InnerKind = Int;
|
||||
|
||||
fn inner<const D: usize>(
|
||||
tensor: <Self as TensorKind<B>>::Primitive<D>,
|
||||
) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive<D> {
|
||||
B::int_inner(tensor)
|
||||
}
|
||||
|
||||
fn from_inner<const D: usize>(
|
||||
inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive<D>,
|
||||
) -> <Self as TensorKind<B>>::Primitive<D> {
|
||||
B::int_from_inner(inner)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: AutodiffBackend> BasicAutodiffOps<B> for Bool {
|
||||
type InnerKind = Bool;
|
||||
|
||||
fn inner<const D: usize>(
|
||||
tensor: <Self as TensorKind<B>>::Primitive<D>,
|
||||
) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive<D> {
|
||||
B::bool_inner(tensor)
|
||||
}
|
||||
|
||||
fn from_inner<const D: usize>(
|
||||
inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive<D>,
|
||||
) -> <Self as TensorKind<B>>::Primitive<D> {
|
||||
B::bool_from_inner(inner)
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait that list all operations that can be applied on all tensors on an autodiff backend.
|
||||
///
|
||||
/// # Warnings
|
||||
///
|
||||
/// This is an internal trait, use the public API provided by [tensor struct](Tensor).
|
||||
pub trait BasicAutodiffOps<B: AutodiffBackend>: BasicOps<B> + BasicOps<B::InnerBackend> {
|
||||
/// Inner primitive tensor.
|
||||
type InnerKind: BasicOps<B::InnerBackend>;
|
||||
|
||||
/// Returns the inner tensor without the autodiff information.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// Users should prefer the [Tensor::inner](Tensor::inner) function,
|
||||
/// which is more high-level and designed for public use.
|
||||
fn inner<const D: usize>(
|
||||
tensor: <Self as TensorKind<B>>::Primitive<D>,
|
||||
) -> <Self::InnerKind as TensorKind<B::InnerBackend>>::Primitive<D>;
|
||||
|
||||
/// Convert a tensor to the autodiff backend.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// Users should prefer the [Tensor::from_inner](Tensor::from_inner) function,
|
||||
/// which is more high-level and designed for public use.
|
||||
fn from_inner<const D: usize>(
|
||||
inner: <Self::InnerKind as TensorKind<B::InnerBackend>>::Primitive<D>,
|
||||
) -> <Self as TensorKind<B>>::Primitive<D>;
|
||||
}
|
|
@ -1,7 +1,6 @@
|
|||
use alloc::vec::Vec;
|
||||
use core::convert::TryInto;
|
||||
|
||||
use crate::backend::AutodiffBackend;
|
||||
use crate::check;
|
||||
use crate::check::TensorCheck;
|
||||
use crate::tensor::backend::Backend;
|
||||
|
@ -233,6 +232,7 @@ where
|
|||
}
|
||||
|
||||
/// Detach the current tensor from the autodiff graph.
|
||||
///
|
||||
/// This function does nothing when autodiff is not enabled.
|
||||
/// This can be used in batchers or elsewhere to ensure that previous operations are not
|
||||
/// considered in the autodiff graph.
|
||||
|
@ -241,6 +241,7 @@ where
|
|||
}
|
||||
|
||||
/// Mark the tensor to keep gradients during the backward pass.
|
||||
///
|
||||
/// This function does nothing when autodiff is not enabled.
|
||||
pub fn require_grad(self) -> Self {
|
||||
self.set_require_grad(true)
|
||||
|
@ -280,48 +281,3 @@ where
|
|||
.div_scalar(n as f32 - correction_factor as f32)
|
||||
}
|
||||
}
|
||||
|
||||
impl<const D: usize, B: AutodiffBackend> Tensor<B, D> {
|
||||
/// Backward pass of the tensor.
|
||||
pub fn backward(&self) -> B::Gradients {
|
||||
B::backward::<D>(self.primitive.clone())
|
||||
}
|
||||
|
||||
/// Get the gradients of a tensor if it exist.
|
||||
///
|
||||
/// Returns a new reference to the same tensor. Therefore the same grad tensor can
|
||||
/// be accessed multiple times. If you only need to get the gradients one time,
|
||||
/// consider using [grad_remove](Tensor::grad_remove) for better performance.
|
||||
pub fn grad(&self, grads: &B::Gradients) -> Option<Tensor<B::InnerBackend, D>> {
|
||||
B::grad(&self.primitive, grads).map(Tensor::new)
|
||||
}
|
||||
|
||||
/// Remove the grad tensor from the [grads](AutodiffBackend::Gradients) struct returning the result.
|
||||
pub fn grad_remove(&self, grads: &mut B::Gradients) -> Option<Tensor<B::InnerBackend, D>> {
|
||||
B::grad_remove(&self.primitive, grads).map(Tensor::new)
|
||||
}
|
||||
|
||||
/// Replace the grad tensor from the [grads](AutodiffBackend::Gradients) struct with the provided
|
||||
/// gradient.
|
||||
pub fn grad_replace(&self, grads: &mut B::Gradients, grad: Tensor<B::InnerBackend, D>) {
|
||||
B::grad_replace(&self.primitive, grads, grad.primitive);
|
||||
}
|
||||
|
||||
/// Returns the inner tensor without the autodiff information.
|
||||
pub fn inner(self) -> Tensor<B::InnerBackend, D> {
|
||||
Tensor::new(B::inner(self.primitive))
|
||||
}
|
||||
|
||||
/// Convert a tensor to the autodiff backend.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `inner` - The tensor to convert.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor converted to the autodiff backend.
|
||||
pub fn from_inner(inner: Tensor<B::InnerBackend, D>) -> Self {
|
||||
Self::new(B::from_inner(inner.primitive))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,7 +15,7 @@ pub struct Bool;
|
|||
/// A type-level representation of the kind of a tensor.
|
||||
pub trait TensorKind<B: Backend>: Clone + core::fmt::Debug {
|
||||
/// The primitive type of the tensor.
|
||||
type Primitive<const D: usize>: Clone + core::fmt::Debug;
|
||||
type Primitive<const D: usize>: Clone + core::fmt::Debug + Sync + Send;
|
||||
|
||||
/// The name of the tensor kind.
|
||||
fn name() -> &'static str;
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
pub(crate) mod check;
|
||||
|
||||
mod autodiff;
|
||||
mod base;
|
||||
mod bool;
|
||||
mod float;
|
||||
|
@ -7,6 +8,7 @@ mod int;
|
|||
mod kind;
|
||||
mod numeric;
|
||||
|
||||
pub use autodiff::*;
|
||||
pub use base::*;
|
||||
pub use kind::*;
|
||||
pub use numeric::*;
|
||||
|
|
|
@ -178,6 +178,29 @@ pub trait AutodiffBackend: Backend {
|
|||
/// The inner backend tensor.
|
||||
fn inner<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self::InnerBackend, D>;
|
||||
|
||||
/// Returns the tensor with inner backend type.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to get the inner backend tensor for.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The inner backend tensor.
|
||||
fn int_inner<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self::InnerBackend, D>;
|
||||
|
||||
/// Returns the tensor with inner backend type.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to get the inner backend tensor for.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The inner backend tensor.
|
||||
fn bool_inner<const D: usize>(tensor: BoolTensor<Self, D>)
|
||||
-> BoolTensor<Self::InnerBackend, D>;
|
||||
|
||||
/// Converts the inner backend tensor to the autodiff backend tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
|
@ -191,4 +214,32 @@ pub trait AutodiffBackend: Backend {
|
|||
fn from_inner<const D: usize>(
|
||||
tensor: FloatTensor<Self::InnerBackend, D>,
|
||||
) -> FloatTensor<Self, D>;
|
||||
|
||||
/// Converts the inner backend tensor to the autodiff backend tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The inner backend tensor to convert.
|
||||
///
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The autodiff backend tensor.
|
||||
fn int_from_inner<const D: usize>(
|
||||
tensor: IntTensor<Self::InnerBackend, D>,
|
||||
) -> IntTensor<Self, D>;
|
||||
|
||||
/// Converts the inner backend tensor to the autodiff backend tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The inner backend tensor to convert.
|
||||
///
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The autodiff backend tensor.
|
||||
fn bool_from_inner<const D: usize>(
|
||||
tensor: BoolTensor<Self::InnerBackend, D>,
|
||||
) -> BoolTensor<Self, D>;
|
||||
}
|
||||
|
|
|
@ -11,14 +11,14 @@ type ElemType = f32;
|
|||
#[cfg(feature = "f16")]
|
||||
type ElemType = burn::tensor::f16;
|
||||
|
||||
pub fn launch<B: AutodiffBackend>(device: B::Device) {
|
||||
pub fn launch<B: AutodiffBackend>(devices: Vec<B::Device>) {
|
||||
let config = ExperimentConfig::new(
|
||||
TransformerEncoderConfig::new(256, 1024, 8, 4).with_norm_first(true),
|
||||
AdamConfig::new().with_weight_decay(Some(WeightDecayConfig::new(5e-5))),
|
||||
);
|
||||
|
||||
text_classification::training::train::<B, AgNewsDataset>(
|
||||
device,
|
||||
devices,
|
||||
AgNewsDataset::train(),
|
||||
AgNewsDataset::test(),
|
||||
config,
|
||||
|
@ -39,7 +39,7 @@ mod ndarray {
|
|||
use crate::{launch, ElemType};
|
||||
|
||||
pub fn run() {
|
||||
launch::<Autodiff<NdArray<ElemType>>>(NdArrayDevice::Cpu);
|
||||
launch::<Autodiff<NdArray<ElemType>>>(vec![NdArrayDevice::Cpu]);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -56,7 +56,7 @@ mod tch_gpu {
|
|||
#[cfg(target_os = "macos")]
|
||||
let device = LibTorchDevice::Mps;
|
||||
|
||||
launch::<Autodiff<LibTorch<ElemType>>>(device);
|
||||
launch::<Autodiff<LibTorch<ElemType>>>(vec![device]);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -68,7 +68,7 @@ mod tch_cpu {
|
|||
use crate::{launch, ElemType};
|
||||
|
||||
pub fn run() {
|
||||
launch::<Autodiff<LibTorch<ElemType>>>(LibTorchDevice::Cpu);
|
||||
launch::<Autodiff<LibTorch<ElemType>>>(vec![LibTorchDevice::Cpu]);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -79,7 +79,9 @@ mod wgpu {
|
|||
use burn::backend::{Autodiff, Fusion};
|
||||
|
||||
pub fn run() {
|
||||
launch::<Autodiff<Fusion<Wgpu<AutoGraphicsApi, ElemType, i32>>>>(WgpuDevice::default());
|
||||
launch::<Autodiff<Fusion<Wgpu<AutoGraphicsApi, ElemType, i32>>>>(vec![
|
||||
WgpuDevice::default(),
|
||||
]);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -11,14 +11,14 @@ type ElemType = f32;
|
|||
#[cfg(feature = "f16")]
|
||||
type ElemType = burn::tensor::f16;
|
||||
|
||||
pub fn launch<B: AutodiffBackend>(device: B::Device) {
|
||||
pub fn launch<B: AutodiffBackend>(devices: Vec<B::Device>) {
|
||||
let config = ExperimentConfig::new(
|
||||
TransformerEncoderConfig::new(256, 1024, 8, 4).with_norm_first(true),
|
||||
AdamConfig::new().with_weight_decay(Some(WeightDecayConfig::new(5e-5))),
|
||||
);
|
||||
|
||||
text_classification::training::train::<B, DbPediaDataset>(
|
||||
device,
|
||||
devices,
|
||||
DbPediaDataset::train(),
|
||||
DbPediaDataset::test(),
|
||||
config,
|
||||
|
@ -38,7 +38,7 @@ mod ndarray {
|
|||
use burn::backend::Autodiff;
|
||||
|
||||
pub fn run() {
|
||||
launch::<Autodiff<NdArray<ElemType>>>(NdArrayDevice::Cpu);
|
||||
launch::<Autodiff<NdArray<ElemType>>>(vec![NdArrayDevice::Cpu]);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -55,7 +55,7 @@ mod tch_gpu {
|
|||
#[cfg(target_os = "macos")]
|
||||
let device = LibTorchDevice::Mps;
|
||||
|
||||
launch::<Autodiff<LibTorch<ElemType>>>(device);
|
||||
launch::<Autodiff<LibTorch<ElemType>>>(vec![device]);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -67,7 +67,7 @@ mod tch_cpu {
|
|||
use crate::{launch, ElemType};
|
||||
|
||||
pub fn run() {
|
||||
launch::<Autodiff<LibTorch<ElemType>>>(LibTorchDevice::Cpu);
|
||||
launch::<Autodiff<LibTorch<ElemType>>>(vec![LibTorchDevice::Cpu]);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -79,7 +79,7 @@ mod wgpu {
|
|||
use crate::{launch, ElemType};
|
||||
|
||||
pub fn run() {
|
||||
launch::<Autodiff<Wgpu<AutoGraphicsApi, ElemType, i32>>>(WgpuDevice::default());
|
||||
launch::<Autodiff<Wgpu<AutoGraphicsApi, ElemType, i32>>>(vec![WgpuDevice::default()]);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -88,7 +88,7 @@ impl<B: Backend> TextClassificationModel<B> {
|
|||
pub fn forward(&self, item: TextClassificationTrainingBatch<B>) -> ClassificationOutput<B> {
|
||||
// Get batch and sequence length, and the device
|
||||
let [batch_size, seq_length] = item.tokens.dims();
|
||||
let device = &self.embedding_token.devices()[0];
|
||||
let device = &self.embedding_token.devices(Vec::new())[0];
|
||||
|
||||
// Move tensors to the correct device
|
||||
let tokens = item.tokens.to_device(device);
|
||||
|
@ -128,7 +128,7 @@ impl<B: Backend> TextClassificationModel<B> {
|
|||
pub fn infer(&self, item: TextClassificationInferenceBatch<B>) -> Tensor<B, 2> {
|
||||
// Get batch and sequence length, and the device
|
||||
let [batch_size, seq_length] = item.tokens.dims();
|
||||
let device = &self.embedding_token.devices()[0];
|
||||
let device = &self.embedding_token.devices(Vec::new())[0];
|
||||
|
||||
// Move tensors to the correct device
|
||||
let tokens = item.tokens.to_device(device);
|
||||
|
|
|
@ -40,7 +40,7 @@ pub struct ExperimentConfig {
|
|||
|
||||
// Define train function
|
||||
pub fn train<B: AutodiffBackend, D: TextClassificationDataset + 'static>(
|
||||
device: B::Device, // Device on which to perform computation (e.g., CPU or CUDA device)
|
||||
devices: Vec<B::Device>, // Device on which to perform computation (e.g., CPU or CUDA device)
|
||||
dataset_train: D, // Training dataset
|
||||
dataset_test: D, // Testing dataset
|
||||
config: ExperimentConfig, // Experiment configuration
|
||||
|
@ -52,12 +52,12 @@ pub fn train<B: AutodiffBackend, D: TextClassificationDataset + 'static>(
|
|||
// Initialize batchers for training and testing data
|
||||
let batcher_train = TextClassificationBatcher::<B>::new(
|
||||
tokenizer.clone(),
|
||||
device.clone(),
|
||||
devices[0].clone(),
|
||||
config.max_seq_length,
|
||||
);
|
||||
let batcher_test = TextClassificationBatcher::<B::InnerBackend>::new(
|
||||
tokenizer.clone(),
|
||||
device.clone(),
|
||||
devices[0].clone(),
|
||||
config.max_seq_length,
|
||||
);
|
||||
|
||||
|
@ -93,13 +93,15 @@ pub fn train<B: AutodiffBackend, D: TextClassificationDataset + 'static>(
|
|||
let learner = LearnerBuilder::new(artifact_dir)
|
||||
.metric_train(CUDAMetric::new())
|
||||
.metric_valid(CUDAMetric::new())
|
||||
.metric_train(AccuracyMetric::new())
|
||||
.metric_valid(AccuracyMetric::new())
|
||||
.metric_train_numeric(AccuracyMetric::new())
|
||||
.metric_valid_numeric(AccuracyMetric::new())
|
||||
.metric_train_numeric(LossMetric::new())
|
||||
.metric_valid_numeric(LossMetric::new())
|
||||
.metric_train_numeric(LossMetric::new())
|
||||
.metric_valid_numeric(LossMetric::new())
|
||||
.metric_train_numeric(LearningRateMetric::new())
|
||||
.with_file_checkpointer(CompactRecorder::new())
|
||||
.devices(vec![device])
|
||||
.devices(devices)
|
||||
.num_epochs(config.num_epochs)
|
||||
.build(model, optim, lr_scheduler);
|
||||
|
||||
|
|
|
@ -58,7 +58,7 @@ impl<B: Backend> TextGenerationModel<B> {
|
|||
item: TrainingTextGenerationBatch<B>,
|
||||
) -> ClassificationOutput<B> {
|
||||
let [batch_size, seq_length] = item.tokens_inputs.dims();
|
||||
let device = &self.devices()[0];
|
||||
let device = &self.devices(Vec::new())[0];
|
||||
|
||||
let inputs = item.tokens_inputs.to_device(device);
|
||||
let targets = item.targets.to_device(device);
|
||||
|
|
Loading…
Reference in New Issue