mirror of https://github.com/tracel-ai/burn.git
Added tuple modules (#1186)
This commit is contained in:
parent
67ffa1e54b
commit
f5ac5d8e9f
|
@ -197,3 +197,73 @@ where
|
|||
self.map(|module| module.valid())
|
||||
}
|
||||
}
|
||||
|
||||
/// A macro for generating implementations for tuple modules of different sizes.
|
||||
/// For example: `impl_module_tuple!([L0, L1][0, 1])`.
|
||||
/// Would generate an implementation for a tuple of size 2.
|
||||
/// For this macro to work properly, please adhear to the convention:
|
||||
/// `impl_module_tuple!([L0, L1, ..., Ln][0, 1, ..., n])`.
|
||||
macro_rules! impl_module_tuple {
|
||||
// `$l` represents the generic modules.
|
||||
// `$i` represents the indices of the modules in the tuple.
|
||||
([$($l:ident),*][$($i:tt),*]) => {
|
||||
impl<B, $($l,)*> Module<B> for ($($l,)*)
|
||||
where
|
||||
B: Backend,
|
||||
$($l: Module<B> + Debug + Send + Sync + Clone,)*
|
||||
{
|
||||
type Record = ($($l::Record),*);
|
||||
|
||||
fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
|
||||
$(devices = self.$i.collect_devices(devices);)*
|
||||
devices
|
||||
}
|
||||
|
||||
fn fork(self, device: &<B as Backend>::Device) -> Self {
|
||||
($(self.$i.fork(device),)*)
|
||||
}
|
||||
|
||||
fn to_device(self, device: &<B as Backend>::Device) -> Self {
|
||||
($(self.$i.to_device(device),)*)
|
||||
}
|
||||
|
||||
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
|
||||
$(self.$i.visit(visitor);)*
|
||||
}
|
||||
|
||||
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
|
||||
($(self.$i.map(mapper),)*)
|
||||
}
|
||||
|
||||
fn load_record(self, record: Self::Record) -> Self {
|
||||
($(self.$i.load_record(record.$i),)*)
|
||||
}
|
||||
|
||||
fn into_record(self) -> Self::Record {
|
||||
($(self.$i.into_record(),)*)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B, $($l,)*> AutodiffModule<B> for ($($l,)*)
|
||||
where
|
||||
B: AutodiffBackend,
|
||||
$($l: AutodiffModule<B> + Debug + Send + Sync + Clone,)*
|
||||
{
|
||||
type InnerModule = ($($l::InnerModule,)*);
|
||||
|
||||
fn valid(&self) -> Self::InnerModule {
|
||||
($(self.$i.valid(),)*)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_module_tuple!([L0, L1][0, 1]);
|
||||
impl_module_tuple!([L0, L1, L2][0, 1, 2]);
|
||||
impl_module_tuple!([L0, L1, L2, L3][0, 1, 2, 3]);
|
||||
impl_module_tuple!([L0, L1, L2, L3, L4][0, 1, 2, 3, 4]);
|
||||
impl_module_tuple!([L0, L1, L2, L3, L4, L5][0, 1, 2, 3, 4, 5]);
|
||||
impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6][0, 1, 2, 3, 4, 5, 6]);
|
||||
impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7][0, 1, 2, 3, 4, 5, 6, 7]);
|
||||
impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7, L8][0, 1, 2, 3, 4, 5, 6, 7, 8]);
|
||||
impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7, L8, L9][0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
|
||||
|
|
|
@ -84,6 +84,43 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
/// A macro for generating implementations for tuple records of different sizes.
|
||||
/// For example: `impl_record_tuple!([R0, R1][0, 1])`.
|
||||
/// Would generate an implementation for a tuple of size 2.
|
||||
/// For this macro to work properly, please adhear to the convention:
|
||||
/// `impl_record_tuple!([R0, R1, ..., Rn][0, 1, ..., n])`.
|
||||
macro_rules! impl_record_tuple {
|
||||
// `$r` represents the generic records.
|
||||
// `$i` represents the indices of the records in the tuple.
|
||||
([$($r:ident),*][$($i:tt),*]) => {
|
||||
impl<B, $($r,)*> Record<B> for ($($r,)*)
|
||||
where
|
||||
B: Backend,
|
||||
$($r: Record<B>),*
|
||||
{
|
||||
type Item<S: PrecisionSettings> = ($($r::Item<S>,)*);
|
||||
|
||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
|
||||
($(self.$i.into_item(),)*)
|
||||
}
|
||||
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
|
||||
($(Record::from_item(item.$i, device),)*)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_record_tuple!([R0, R1][0, 1]);
|
||||
impl_record_tuple!([R0, R1, R2][0, 1, 2]);
|
||||
impl_record_tuple!([R0, R1, R2, R3][0, 1, 2, 3]);
|
||||
impl_record_tuple!([R0, R1, R2, R3, R4][0, 1, 2, 3, 4]);
|
||||
impl_record_tuple!([R0, R1, R2, R3, R4, R5][0, 1, 2, 3, 4, 5]);
|
||||
impl_record_tuple!([R0, R1, R2, R3, R4, R5, R6][0, 1, 2, 3, 4, 5, 6]);
|
||||
impl_record_tuple!([R0, R1, R2, R3, R4, R5, R6, R7][0, 1, 2, 3, 4, 5, 6, 7]);
|
||||
impl_record_tuple!([R0, R1, R2, R3, R4, R5, R6, R7, R8][0, 1, 2, 3, 4, 5, 6, 7, 8]);
|
||||
impl_record_tuple!([R0, R1, R2, R3, R4, R5, R6, R7, R8, R9][0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
|
||||
|
||||
impl<T, B> Record<B> for HashMap<ParamId, T>
|
||||
where
|
||||
T: Record<B>,
|
||||
|
|
|
@ -38,6 +38,7 @@ struct ModuleWithGenericModule<B: Backend, M> {
|
|||
pub struct ModuleComposed<B: Backend> {
|
||||
weight: Param<Tensor<B, 2>>,
|
||||
basic: ModuleBasic<B>,
|
||||
tuple: (ModuleBasic<B>, ModuleBasic<B>),
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleComposed<B> {
|
||||
|
@ -46,6 +47,7 @@ impl<B: Backend> ModuleComposed<B> {
|
|||
Self {
|
||||
weight: Param::from(weight),
|
||||
basic: ModuleBasic::new(device),
|
||||
tuple: (ModuleBasic::new(device), ModuleBasic::new(device)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -109,7 +111,7 @@ mod num_params {
|
|||
fn should_output_state_composed() {
|
||||
let device = <TestBackend as Backend>::Device::default();
|
||||
let module = ModuleComposed::<TestBackend>::new(&device);
|
||||
assert_eq!(2 * 20 * 20, module.num_params());
|
||||
assert_eq!(4 * 20 * 20, module.num_params());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue