Fix: serde dependency (#1091)

* Re-export serde

* Fix

* USe another strategy

* Fix

* Fix

* Update de book
This commit is contained in:
Nathaniel Simard 2023-12-22 16:53:34 -05:00 committed by GitHub
parent b0a2b30ed1
commit fceb036c6f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 21 additions and 18 deletions

View File

@ -1,8 +1,7 @@
# Model
The first step is to create a project and add the different Burn dependencies. In the `Cargo.toml`
file, add the `burn` dependency with `train` and `wgpu` features. Note that the `serde` dependency
is also mandatory for the time being, as it is needed for serialization.
file, add the `burn` dependency with `train` and `wgpu` features.
```toml
[package]
@ -12,9 +11,6 @@ edition = "2021"
[dependencies]
burn = { version = "0.12.0", features=["train", "wgpu"]}
# Serialization
serde = "1"
```
Our goal will be to create a basic convolutional neural network used for image classification. We

View File

@ -6,6 +6,9 @@
#[macro_use]
extern crate derive_new;
/// Re-export serde for proc macros.
pub use serde;
/// The configuration module.
pub mod config;

View File

@ -22,7 +22,8 @@ impl ConfigEnumAnalyzer {
let data = &self.data.variants;
quote! {
#[derive(serde::Serialize, serde::Deserialize)]
#[derive(burn::serde::Serialize, burn::serde::Deserialize)]
#[serde(crate = "burn::serde")]
enum #enum_name {
#data
}
@ -80,10 +81,10 @@ impl ConfigEnumAnalyzer {
let name = &self.name;
quote! {
impl serde::Serialize for #name {
impl burn::serde::Serialize for #name {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer {
S: burn::serde::Serializer {
let serde_state = match self {
#(#variants),*
};
@ -105,10 +106,10 @@ impl ConfigEnumAnalyzer {
let name = &self.name;
quote! {
impl<'de> serde::Deserialize<'de> for #name {
impl<'de> burn::serde::Deserialize<'de> for #name {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de> {
D: burn::serde::Deserializer<'de> {
let serde_state = #enum_name::deserialize(deserializer)?;
Ok(match serde_state {
#(#variants),*

View File

@ -85,12 +85,13 @@ impl ConfigStructAnalyzer {
});
quote! {
impl serde::Serialize for #name {
impl burn::serde::Serialize for #name {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer {
#[derive(serde::Serialize)]
S: burn::serde::Serializer {
#[derive(burn::serde::Serialize)]
#[serde(crate = "burn::serde")]
#struct_gen
let serde_state = #struct_name {
@ -116,11 +117,12 @@ impl ConfigStructAnalyzer {
});
quote! {
impl<'de> serde::Deserialize<'de> for #name {
impl<'de> burn::serde::Deserialize<'de> for #name {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de> {
#[derive(serde::Deserialize)]
D: burn::serde::Deserializer<'de> {
#[derive(burn::serde::Deserialize)]
#[serde(crate = "burn::serde")]
#struct_gen
let serde_state = #struct_name::deserialize(deserializer)?;

View File

@ -24,7 +24,7 @@ impl RecordItemCodegen for StructRecordItemCodegen {
pub #name: <#ty as burn::record::Record>::Item<S>,
});
bounds.extend(quote! {
<#ty as burn::record::Record>::Item<S>: serde::Serialize + serde::de::DeserializeOwned,
<#ty as burn::record::Record>::Item<S>: burn::serde::Serialize + burn::serde::de::DeserializeOwned,
});
}
let bound = bounds.to_string();
@ -32,7 +32,8 @@ impl RecordItemCodegen for StructRecordItemCodegen {
quote! {
/// The record item type for the module.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[derive(Debug, Clone, burn::serde::Serialize, burn::serde::Deserialize)]
#[serde(crate = "burn::serde")]
#[serde(bound = #bound)]
pub struct #item_name #generics {
#fields