Remove copy restriction for const generic modules (#2222)

This commit is contained in:
Guillaume Lagrange 2024-09-03 09:39:12 -04:00 committed by GitHub
parent cc214d366c
commit 59d41bd4b2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 47 additions and 7 deletions

View File

@ -173,8 +173,7 @@ where
impl<const N: usize, T, B> Module<B> for [T; N]
where
T: Module<B> + Debug + Send + Clone + Copy,
T::Record: Debug,
T: Module<B> + Debug + Send + Clone,
B: Backend,
{
type Record = [T::Record; N];
@ -245,16 +244,14 @@ impl<const N: usize, T: ModuleDisplay> ModuleDisplay for [T; N] {}
impl<const N: usize, T, B> AutodiffModule<B> for [T; N]
where
T: AutodiffModule<B> + Debug + Send + Clone + Copy,
T::InnerModule: Copy + Debug,
<T::InnerModule as Module<B::InnerBackend>>::Record: Debug,
<T as Module<B>>::Record: Debug,
T: AutodiffModule<B> + Debug + Send + Clone,
T::InnerModule: Debug,
B: AutodiffBackend,
{
type InnerModule = [T::InnerModule; N];
fn valid(&self) -> Self::InnerModule {
self.map(|module| module.valid())
self.clone().map(|module| module.valid())
}
}

View File

@ -32,6 +32,11 @@ impl<B: Backend> ModuleBasic<B> {
}
}
#[derive(Module, Debug)]
struct ModuleWithConstGeneric<B: Backend, const N: usize> {
modules: [ModuleBasic<B>; N],
}
#[derive(Module, Debug)]
struct ModuleWithGenericModule<B: Backend, M> {
module: M,
@ -151,6 +156,44 @@ mod state {
);
}
#[test]
fn should_load_from_record_const_generic() {
let device = <TestBackend as Backend>::Device::default();
let module_1 = ModuleWithConstGeneric {
modules: [
ModuleBasic::<TestBackend>::new(&device),
ModuleBasic::<TestBackend>::new(&device),
],
};
let mut module_2 = ModuleWithConstGeneric {
modules: [
ModuleBasic::<TestBackend>::new(&device),
ModuleBasic::<TestBackend>::new(&device),
],
};
let state_1 = module_1.clone().into_record();
assert_ne!(
module_1.modules[0].weight_basic.to_data(),
module_2.modules[0].weight_basic.to_data(),
);
assert_ne!(
module_1.modules[1].weight_basic.to_data(),
module_2.modules[1].weight_basic.to_data(),
);
module_2 = module_2.load_record(state_1);
assert_eq!(
module_1.modules[0].weight_basic.to_data(),
module_2.modules[0].weight_basic.to_data(),
);
assert_eq!(
module_1.modules[1].weight_basic.to_data(),
module_2.modules[1].weight_basic.to_data(),
);
}
#[test]
#[should_panic(expected = "Can't parse record from a different variant")]
fn should_panic_load_from_incorrect_enum_variant() {