From fceb036c6f85c77189ce460c342ac010a5b7240b Mon Sep 17 00:00:00 2001 From: Nathaniel Simard Date: Fri, 22 Dec 2023 16:53:34 -0500 Subject: [PATCH] Fix: serde dependency (#1091) * Re-export serde * Fix * USe another strategy * Fix * Fix * Update de book --- burn-book/src/basic-workflow/model.md | 6 +----- burn-core/src/lib.rs | 3 +++ burn-derive/src/config/analyzer_enum.rs | 11 ++++++----- burn-derive/src/config/analyzer_struct.rs | 14 ++++++++------ burn-derive/src/record/codegen_struct.rs | 5 +++-- 5 files changed, 21 insertions(+), 18 deletions(-) diff --git a/burn-book/src/basic-workflow/model.md b/burn-book/src/basic-workflow/model.md index ed4285d99..f10d62d2f 100644 --- a/burn-book/src/basic-workflow/model.md +++ b/burn-book/src/basic-workflow/model.md @@ -1,8 +1,7 @@ # Model The first step is to create a project and add the different Burn dependencies. In the `Cargo.toml` -file, add the `burn` dependency with `train` and `wgpu` features. Note that the `serde` dependency -is also mandatory for the time being, as it is needed for serialization. +file, add the `burn` dependency with `train` and `wgpu` features. ```toml [package] @@ -12,9 +11,6 @@ edition = "2021" [dependencies] burn = { version = "0.12.0", features=["train", "wgpu"]} - -# Serialization -serde = "1" ``` Our goal will be to create a basic convolutional neural network used for image classification. We diff --git a/burn-core/src/lib.rs b/burn-core/src/lib.rs index 7d729948e..d7effc133 100644 --- a/burn-core/src/lib.rs +++ b/burn-core/src/lib.rs @@ -6,6 +6,9 @@ #[macro_use] extern crate derive_new; +/// Re-export serde for proc macros. +pub use serde; + /// The configuration module. pub mod config; diff --git a/burn-derive/src/config/analyzer_enum.rs b/burn-derive/src/config/analyzer_enum.rs index 2f7e2347b..0f00c5ea2 100644 --- a/burn-derive/src/config/analyzer_enum.rs +++ b/burn-derive/src/config/analyzer_enum.rs @@ -22,7 +22,8 @@ impl ConfigEnumAnalyzer { let data = &self.data.variants; quote! { - #[derive(serde::Serialize, serde::Deserialize)] + #[derive(burn::serde::Serialize, burn::serde::Deserialize)] + #[serde(crate = "burn::serde")] enum #enum_name { #data } @@ -80,10 +81,10 @@ impl ConfigEnumAnalyzer { let name = &self.name; quote! { - impl serde::Serialize for #name { + impl burn::serde::Serialize for #name { fn serialize(&self, serializer: S) -> Result where - S: serde::Serializer { + S: burn::serde::Serializer { let serde_state = match self { #(#variants),* }; @@ -105,10 +106,10 @@ impl ConfigEnumAnalyzer { let name = &self.name; quote! { - impl<'de> serde::Deserialize<'de> for #name { + impl<'de> burn::serde::Deserialize<'de> for #name { fn deserialize(deserializer: D) -> Result where - D: serde::Deserializer<'de> { + D: burn::serde::Deserializer<'de> { let serde_state = #enum_name::deserialize(deserializer)?; Ok(match serde_state { #(#variants),* diff --git a/burn-derive/src/config/analyzer_struct.rs b/burn-derive/src/config/analyzer_struct.rs index 18ec62c16..42bfd98b0 100644 --- a/burn-derive/src/config/analyzer_struct.rs +++ b/burn-derive/src/config/analyzer_struct.rs @@ -85,12 +85,13 @@ impl ConfigStructAnalyzer { }); quote! { - impl serde::Serialize for #name { + impl burn::serde::Serialize for #name { fn serialize(&self, serializer: S) -> Result where - S: serde::Serializer { - #[derive(serde::Serialize)] + S: burn::serde::Serializer { + #[derive(burn::serde::Serialize)] + #[serde(crate = "burn::serde")] #struct_gen let serde_state = #struct_name { @@ -116,11 +117,12 @@ impl ConfigStructAnalyzer { }); quote! { - impl<'de> serde::Deserialize<'de> for #name { + impl<'de> burn::serde::Deserialize<'de> for #name { fn deserialize(deserializer: D) -> Result where - D: serde::Deserializer<'de> { - #[derive(serde::Deserialize)] + D: burn::serde::Deserializer<'de> { + #[derive(burn::serde::Deserialize)] + #[serde(crate = "burn::serde")] #struct_gen let serde_state = #struct_name::deserialize(deserializer)?; diff --git a/burn-derive/src/record/codegen_struct.rs b/burn-derive/src/record/codegen_struct.rs index 331b38b30..9a5dca26e 100644 --- a/burn-derive/src/record/codegen_struct.rs +++ b/burn-derive/src/record/codegen_struct.rs @@ -24,7 +24,7 @@ impl RecordItemCodegen for StructRecordItemCodegen { pub #name: <#ty as burn::record::Record>::Item, }); bounds.extend(quote! { - <#ty as burn::record::Record>::Item: serde::Serialize + serde::de::DeserializeOwned, + <#ty as burn::record::Record>::Item: burn::serde::Serialize + burn::serde::de::DeserializeOwned, }); } let bound = bounds.to_string(); @@ -32,7 +32,8 @@ impl RecordItemCodegen for StructRecordItemCodegen { quote! { /// The record item type for the module. - #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] + #[derive(Debug, Clone, burn::serde::Serialize, burn::serde::Deserialize)] + #[serde(crate = "burn::serde")] #[serde(bound = #bound)] pub struct #item_name #generics { #fields