Added tuple modules (#1186)

This commit is contained in:
Roy Varon 2024-01-29 22:56:31 +00:00 committed by GitHub
parent 67ffa1e54b
commit f5ac5d8e9f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 110 additions and 1 deletions

View File

@ -197,3 +197,73 @@ where
self.map(|module| module.valid())
}
}
/// A macro for generating implementations for tuple modules of different sizes.
/// For example: `impl_module_tuple!([L0, L1][0, 1])`.
/// Would generate an implementation for a tuple of size 2.
/// For this macro to work properly, please adhear to the convention:
/// `impl_module_tuple!([L0, L1, ..., Ln][0, 1, ..., n])`.
macro_rules! impl_module_tuple {
// `$l` represents the generic modules.
// `$i` represents the indices of the modules in the tuple.
([$($l:ident),*][$($i:tt),*]) => {
impl<B, $($l,)*> Module<B> for ($($l,)*)
where
B: Backend,
$($l: Module<B> + Debug + Send + Sync + Clone,)*
{
type Record = ($($l::Record),*);
fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
$(devices = self.$i.collect_devices(devices);)*
devices
}
fn fork(self, device: &<B as Backend>::Device) -> Self {
($(self.$i.fork(device),)*)
}
fn to_device(self, device: &<B as Backend>::Device) -> Self {
($(self.$i.to_device(device),)*)
}
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
$(self.$i.visit(visitor);)*
}
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
($(self.$i.map(mapper),)*)
}
fn load_record(self, record: Self::Record) -> Self {
($(self.$i.load_record(record.$i),)*)
}
fn into_record(self) -> Self::Record {
($(self.$i.into_record(),)*)
}
}
impl<B, $($l,)*> AutodiffModule<B> for ($($l,)*)
where
B: AutodiffBackend,
$($l: AutodiffModule<B> + Debug + Send + Sync + Clone,)*
{
type InnerModule = ($($l::InnerModule,)*);
fn valid(&self) -> Self::InnerModule {
($(self.$i.valid(),)*)
}
}
};
}
impl_module_tuple!([L0, L1][0, 1]);
impl_module_tuple!([L0, L1, L2][0, 1, 2]);
impl_module_tuple!([L0, L1, L2, L3][0, 1, 2, 3]);
impl_module_tuple!([L0, L1, L2, L3, L4][0, 1, 2, 3, 4]);
impl_module_tuple!([L0, L1, L2, L3, L4, L5][0, 1, 2, 3, 4, 5]);
impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6][0, 1, 2, 3, 4, 5, 6]);
impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7][0, 1, 2, 3, 4, 5, 6, 7]);
impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7, L8][0, 1, 2, 3, 4, 5, 6, 7, 8]);
impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7, L8, L9][0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);

View File

@ -84,6 +84,43 @@ where
}
}
/// A macro for generating implementations for tuple records of different sizes.
/// For example: `impl_record_tuple!([R0, R1][0, 1])`.
/// Would generate an implementation for a tuple of size 2.
/// For this macro to work properly, please adhear to the convention:
/// `impl_record_tuple!([R0, R1, ..., Rn][0, 1, ..., n])`.
macro_rules! impl_record_tuple {
// `$r` represents the generic records.
// `$i` represents the indices of the records in the tuple.
([$($r:ident),*][$($i:tt),*]) => {
impl<B, $($r,)*> Record<B> for ($($r,)*)
where
B: Backend,
$($r: Record<B>),*
{
type Item<S: PrecisionSettings> = ($($r::Item<S>,)*);
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
($(self.$i.into_item(),)*)
}
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
($(Record::from_item(item.$i, device),)*)
}
}
};
}
impl_record_tuple!([R0, R1][0, 1]);
impl_record_tuple!([R0, R1, R2][0, 1, 2]);
impl_record_tuple!([R0, R1, R2, R3][0, 1, 2, 3]);
impl_record_tuple!([R0, R1, R2, R3, R4][0, 1, 2, 3, 4]);
impl_record_tuple!([R0, R1, R2, R3, R4, R5][0, 1, 2, 3, 4, 5]);
impl_record_tuple!([R0, R1, R2, R3, R4, R5, R6][0, 1, 2, 3, 4, 5, 6]);
impl_record_tuple!([R0, R1, R2, R3, R4, R5, R6, R7][0, 1, 2, 3, 4, 5, 6, 7]);
impl_record_tuple!([R0, R1, R2, R3, R4, R5, R6, R7, R8][0, 1, 2, 3, 4, 5, 6, 7, 8]);
impl_record_tuple!([R0, R1, R2, R3, R4, R5, R6, R7, R8, R9][0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
impl<T, B> Record<B> for HashMap<ParamId, T>
where
T: Record<B>,

View File

@ -38,6 +38,7 @@ struct ModuleWithGenericModule<B: Backend, M> {
pub struct ModuleComposed<B: Backend> {
weight: Param<Tensor<B, 2>>,
basic: ModuleBasic<B>,
tuple: (ModuleBasic<B>, ModuleBasic<B>),
}
impl<B: Backend> ModuleComposed<B> {
@ -46,6 +47,7 @@ impl<B: Backend> ModuleComposed<B> {
Self {
weight: Param::from(weight),
basic: ModuleBasic::new(device),
tuple: (ModuleBasic::new(device), ModuleBasic::new(device)),
}
}
}
@ -109,7 +111,7 @@ mod num_params {
fn should_output_state_composed() {
let device = <TestBackend as Backend>::Device::default();
let module = ModuleComposed::<TestBackend>::new(&device);
assert_eq!(2 * 20 * 20, module.num_params());
assert_eq!(4 * 20 * 20, module.num_params());
}
}