From 59d41bd4b2c2883f14d5aa763e5b40d4b411c9cf Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Tue, 3 Sep 2024 09:39:12 -0400 Subject: [PATCH] Remove copy restriction for const generic modules (#2222) --- .../burn-core/src/module/param/primitive.rs | 11 ++--- crates/burn-core/tests/test_derive_module.rs | 43 +++++++++++++++++++ 2 files changed, 47 insertions(+), 7 deletions(-) diff --git a/crates/burn-core/src/module/param/primitive.rs b/crates/burn-core/src/module/param/primitive.rs index 719b61d5c..6bde30fd7 100644 --- a/crates/burn-core/src/module/param/primitive.rs +++ b/crates/burn-core/src/module/param/primitive.rs @@ -173,8 +173,7 @@ where impl Module for [T; N] where - T: Module + Debug + Send + Clone + Copy, - T::Record: Debug, + T: Module + Debug + Send + Clone, B: Backend, { type Record = [T::Record; N]; @@ -245,16 +244,14 @@ impl ModuleDisplay for [T; N] {} impl AutodiffModule for [T; N] where - T: AutodiffModule + Debug + Send + Clone + Copy, - T::InnerModule: Copy + Debug, - >::Record: Debug, - >::Record: Debug, + T: AutodiffModule + Debug + Send + Clone, + T::InnerModule: Debug, B: AutodiffBackend, { type InnerModule = [T::InnerModule; N]; fn valid(&self) -> Self::InnerModule { - self.map(|module| module.valid()) + self.clone().map(|module| module.valid()) } } diff --git a/crates/burn-core/tests/test_derive_module.rs b/crates/burn-core/tests/test_derive_module.rs index 6c75297fc..45270638a 100644 --- a/crates/burn-core/tests/test_derive_module.rs +++ b/crates/burn-core/tests/test_derive_module.rs @@ -32,6 +32,11 @@ impl ModuleBasic { } } +#[derive(Module, Debug)] +struct ModuleWithConstGeneric { + modules: [ModuleBasic; N], +} + #[derive(Module, Debug)] struct ModuleWithGenericModule { module: M, @@ -151,6 +156,44 @@ mod state { ); } + #[test] + fn should_load_from_record_const_generic() { + let device = ::Device::default(); + let module_1 = ModuleWithConstGeneric { + modules: [ + ModuleBasic::::new(&device), + ModuleBasic::::new(&device), + ], + }; + let mut module_2 = ModuleWithConstGeneric { + modules: [ + ModuleBasic::::new(&device), + ModuleBasic::::new(&device), + ], + }; + let state_1 = module_1.clone().into_record(); + + assert_ne!( + module_1.modules[0].weight_basic.to_data(), + module_2.modules[0].weight_basic.to_data(), + ); + assert_ne!( + module_1.modules[1].weight_basic.to_data(), + module_2.modules[1].weight_basic.to_data(), + ); + + module_2 = module_2.load_record(state_1); + + assert_eq!( + module_1.modules[0].weight_basic.to_data(), + module_2.modules[0].weight_basic.to_data(), + ); + assert_eq!( + module_1.modules[1].weight_basic.to_data(), + module_2.modules[1].weight_basic.to_data(), + ); + } + #[test] #[should_panic(expected = "Can't parse record from a different variant")] fn should_panic_load_from_incorrect_enum_variant() {