diff --git a/burn-core/src/module/state.rs b/burn-core/src/module/state.rs
index c319e2587..d55cd8f93 100644
--- a/burn-core/src/module/state.rs
+++ b/burn-core/src/module/state.rs
@@ -227,10 +227,6 @@ mod tests {
     }
 
     fn create_model() -> nn::Linear<TestBackend> {
-        nn::Linear::<TestBackend>::new(&nn::LinearConfig {
-            d_input: 32,
-            d_output: 32,
-            bias: true,
-        })
+        nn::Linear::<TestBackend>::new(&nn::LinearConfig::new(32, 32).with_bias(true))
     }
 }
diff --git a/burn-core/src/nn/conv/conv1d.rs b/burn-core/src/nn/conv/conv1d.rs
index 69ee7ac4d..d3bff6b06 100644
--- a/burn-core/src/nn/conv/conv1d.rs
+++ b/burn-core/src/nn/conv/conv1d.rs
@@ -6,9 +6,9 @@ use crate as burn;
 use crate::config::Config;
 use crate::module::Module;
 use crate::module::Param;
+use crate::nn::Initializer;
 use crate::tensor::backend::Backend;
-use crate::tensor::ElementConversion;
-use crate::tensor::{Distribution, Tensor};
+use crate::tensor::Tensor;
 use burn_tensor::module::conv1d;
 use burn_tensor::ops::conv::calculate_padding;
 
@@ -28,6 +28,9 @@ pub struct Conv1dConfig {
     /// If bias should be added to the output.
     #[config(default = true)]
     pub bias: bool,
+    /// The type of function used to initialize neural network parameters
+    #[config(default = "Initializer::UniformDefault")]
+    pub initializer: Initializer,
 }
 
 /// Padding configuration for 1D convolution [config](Conv1dConfig).
@@ -64,19 +67,17 @@ impl<B: Backend> Conv1d<B> {
         let k = (config.channels_in * config.kernel_size) as f64;
         let k = sqrt(1.0 / k);
 
-        let k1: B::FloatElem = (-k).elem();
-        let k2: B::FloatElem = k.elem();
+        let initializer = if let Initializer::UniformDefault = config.initializer {
+            Initializer::Uniform(-k, k)
+        } else {
+            config.initializer.clone()
+        };
 
-        let weight = Tensor::random(
-            [config.channels_out, config.channels_in, config.kernel_size],
-            Distribution::Uniform(k1, k2),
-        );
+        let weight =
+            initializer.init([config.channels_out, config.channels_in, config.kernel_size]);
 
         let bias = if config.bias {
-            Some(Tensor::random(
-                [config.channels_out],
-                Distribution::Uniform(k1, k2),
-            ))
+            Some(initializer.init([config.channels_out]))
         } else {
             None
         };
@@ -119,3 +120,35 @@ impl<B: Backend> Conv1d<B> {
         )
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    pub type TB = burn_ndarray::NdArrayBackend<f32>;
+
+    #[test]
+    fn initializer_default() {
+        TB::seed(0);
+        let config = Conv1dConfig::new(5, 5, 5);
+        let k = (config.channels_in * config.kernel_size) as f64;
+        let k = sqrt(1.0 / k);
+        assert_eq!(config.initializer, Initializer::UniformDefault);
+        let conv: Conv1d<TB> = Conv1d::new(&config);
+        for item in conv.weight.to_data().value.iter() {
+            if *item < -k as f32 || *item > k as f32 {
+                panic!("Element ({item}) is not within the range of (-{k},{k})");
+            }
+        }
+    }
+
+    #[test]
+    fn initializer_zeros() {
+        TB::seed(0);
+        let config = Conv1dConfig::new(5, 5, 5).with_initializer(Initializer::Zeros);
+        assert_eq!(config.initializer, Initializer::Zeros);
+        let conv: Conv1d<TB> = Conv1d::new(&config);
+        for item in conv.weight.to_data().value.iter() {
+            assert_eq!(*item, 0.0f32);
+        }
+    }
+}
diff --git a/burn-core/src/nn/conv/conv2d.rs b/burn-core/src/nn/conv/conv2d.rs
index c2544bcf0..0a897313c 100644
--- a/burn-core/src/nn/conv/conv2d.rs
+++ b/burn-core/src/nn/conv/conv2d.rs
@@ -5,9 +5,9 @@ use crate as burn;
 use crate::config::Config;
 use crate::module::Module;
 use crate::module::Param;
+use crate::nn::Initializer;
 use crate::tensor::backend::Backend;
-use crate::tensor::ElementConversion;
-use crate::tensor::{Distribution, Tensor};
+use crate::tensor::Tensor;
 use burn_tensor::module::conv2d;
 use burn_tensor::ops::conv::calculate_padding;
 
@@ -26,6 +26,9 @@ pub struct Conv2dConfig {
     /// If bias should be added to the output.
     #[config(default = true)]
     pub bias: bool,
+    /// The type of function used to initialize neural network parameters
+    #[config(default = "Initializer::UniformDefault")]
+    pub initializer: Initializer,
 }
 
 /// Padding configuration for 2D convolution [config](Conv2dConfig).
@@ -64,24 +67,21 @@ impl<B: Backend> Conv2d<B> {
         let k = (config.channels[0] * config.kernel_size[0] * config.kernel_size[1]) as f64;
         let k = sqrt(1.0 / k);
 
-        let k1: B::FloatElem = (-k).elem();
-        let k2: B::FloatElem = k.elem();
+        let initializer = if let Initializer::UniformDefault = config.initializer {
+            Initializer::Uniform(-k, k)
+        } else {
+            config.initializer.clone()
+        };
 
-        let weight = Tensor::random(
-            [
-                config.channels[1],
-                config.channels[0],
-                config.kernel_size[0],
-                config.kernel_size[1],
-            ],
-            Distribution::Uniform(k1, k2),
-        );
+        let weight = initializer.init([
+            config.channels[1],
+            config.channels[0],
+            config.kernel_size[0],
+            config.kernel_size[1],
+        ]);
 
         let bias = if config.bias {
-            Some(Tensor::random(
-                [config.channels[1]],
-                Distribution::Uniform(k1, k2),
-            ))
+            Some(initializer.init([config.channels[1]]))
         } else {
             None
         };
@@ -138,3 +138,35 @@ impl Conv2dPaddingConfig {
         }
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    pub type TB = burn_ndarray::NdArrayBackend<f32>;
+
+    #[test]
+    fn initializer_default() {
+        TB::seed(0);
+        let config = Conv2dConfig::new([5, 1], [5, 5]);
+        let k = (config.channels[0] * config.kernel_size[0] * config.kernel_size[1]) as f64;
+        let k = sqrt(1.0 / k);
+        assert_eq!(config.initializer, Initializer::UniformDefault);
+        let conv: Conv2d<TB> = Conv2d::new(&config);
+        for item in conv.weight.to_data().value.iter() {
+            if *item < -k as f32 || *item > k as f32 {
+                panic!("Element ({item}) is not within the range of (-{k},{k})");
+            }
+        }
+    }
+
+    #[test]
+    fn initializer_zeros() {
+        TB::seed(0);
+        let config = Conv2dConfig::new([5, 2], [5, 5]).with_initializer(Initializer::Zeros);
+        assert_eq!(config.initializer, Initializer::Zeros);
+        let conv: Conv2d<TB> = Conv2d::new(&config);
+        for item in conv.weight.to_data().value.iter() {
+            assert_eq!(*item, 0.0f32);
+        }
+    }
+}
diff --git a/burn-core/src/nn/embedding.rs b/burn-core/src/nn/embedding.rs
index a6ffeffd1..1e3548436 100644
--- a/burn-core/src/nn/embedding.rs
+++ b/burn-core/src/nn/embedding.rs
@@ -3,11 +3,12 @@ use burn_tensor::Int;
 
 use crate as burn;
 
+use super::Initializer;
 use crate::config::Config;
 use crate::module::Module;
 use crate::module::Param;
 use crate::tensor::backend::Backend;
-use crate::tensor::{Distribution, Tensor};
+use crate::tensor::Tensor;
 
 /// Configuration to create an [Embedding](Embedding) layer.
 #[derive(Config)]
@@ -16,6 +17,9 @@ pub struct EmbeddingConfig {
     n_embedding: usize,
     /// The size of each vector.
     d_model: usize,
+    /// The type of function used to initialize neural network parameters
+    #[config(default = "Initializer::Normal(0.0,1.0)")]
+    pub initializer: Initializer,
 }
 
 /// Lookup table to store a fix number of vectors.
@@ -32,11 +36,10 @@ pub struct Embedding<B: Backend> {
 impl<B: Backend> Embedding<B> {
     /// Create the module from the given configuration.
     pub fn new(config: &EmbeddingConfig) -> Self {
-        let weight = Tensor::random(
-            [config.n_embedding, config.d_model],
-            Distribution::Normal(0.0, 1.0),
-        )
-        .require_grad();
+        let weight = config
+            .initializer
+            .init([config.n_embedding, config.d_model])
+            .require_grad();
 
         Self {
             weight: Param::from(weight),
@@ -53,3 +56,36 @@ impl<B: Backend> Embedding<B> {
         burn_tensor::module::embedding(self.weight.val(), input)
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use burn_tensor::Data;
+
+    use super::*;
+    pub type TB = burn_ndarray::NdArrayBackend<f32>;
+
+    #[test]
+    fn initializer_default() {
+        TB::seed(0);
+        let config = EmbeddingConfig::new(100, 10);
+        assert_eq!(config.initializer, Initializer::Normal(0.0, 1.0));
+        let embed: Embedding<TB> = Embedding::new(&config);
+        let weights = embed.weight.val().reshape([1000]);
+        let (var_act, mean_act) = weights.var_mean(0);
+        var_act.to_data().assert_approx_eq(&Data::from([1.0f32]), 1);
+        mean_act
+            .to_data()
+            .assert_approx_eq(&Data::from([0.0f32]), 1);
+    }
+
+    #[test]
+    fn initializer_zeros() {
+        TB::seed(0);
+        let config = EmbeddingConfig::new(5, 5).with_initializer(Initializer::Zeros);
+        assert_eq!(config.initializer, Initializer::Zeros);
+        let conv: Embedding<TB> = Embedding::new(&config);
+        for item in conv.weight.to_data().value.iter() {
+            assert_eq!(*item, 0.0f32);
+        }
+    }
+}
diff --git a/burn-core/src/nn/initializer.rs b/burn-core/src/nn/initializer.rs
new file mode 100644
index 000000000..abaa70eb4
--- /dev/null
+++ b/burn-core/src/nn/initializer.rs
@@ -0,0 +1,107 @@
+use burn_tensor::Shape;
+
+use crate::config::Config;
+use crate::tensor::backend::Backend;
+use crate::tensor::{Distribution, ElementConversion, Tensor};
+
+use crate as burn;
+
+#[derive(Config, Debug, PartialEq)]
+pub enum Initializer {
+    Uniform(f64, f64),
+    UniformDefault,
+    Normal(f64, f64),
+    Constant(f64),
+    Ones,
+    Zeros,
+    // TODO: add Xavier initialization
+}
+
+impl Initializer {
+    pub fn init<B: Backend, const D: usize, S: Into<Shape<D>>>(&self, shape: S) -> Tensor<B, D> {
+        match self {
+            Self::Uniform(a, b) => Tensor::<B, D>::random(
+                shape,
+                Distribution::Uniform((*a).elem::<B::FloatElem>(), (*b).elem::<B::FloatElem>()),
+            ),
+            Self::UniformDefault => unimplemented!("The caller should implement the default"),
+            Self::Normal(mean, std) => {
+                Tensor::<B, D>::random(shape, Distribution::Normal(*mean, *std))
+            }
+            Self::Constant(value) => Tensor::<B, D>::zeros(shape) + *value, //TODO replace with fill()
+            Self::Ones => Tensor::<B, D>::ones(shape),
+            Self::Zeros => Tensor::<B, D>::zeros(shape),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+
+    use super::*;
+
+    use burn_tensor::Data;
+
+    pub type TB = burn_ndarray::NdArrayBackend<f32>;
+
+    #[test]
+    fn initializer_uniform_init() {
+        // seed random generator
+        TB::seed(0);
+        let (a, b) = (0.0, 1.0);
+        let uniform: Tensor<TB, 4> = Initializer::Uniform(a, b).init([2, 2, 2, 2]);
+        for item in uniform.to_data().value.iter() {
+            if *item < a as f32 || *item > b as f32 {
+                panic!("Element ({item}) is not within range ({a},{b})");
+            }
+        }
+    }
+
+    #[test]
+    #[should_panic]
+    fn initializer_uniform_default_init() {
+        let _: Tensor<TB, 4> = Initializer::UniformDefault.init([2, 2, 2, 2]);
+    }
+
+    #[test]
+    fn initializer_normal_init() {
+        // seed random generator
+        TB::seed(0);
+        let (mean, std) = (0.0, 1.0);
+        let normal: Tensor<TB, 1> = Initializer::Normal(mean, std).init([1000]);
+        let (var_act, mean_act) = normal.var_mean(0);
+        var_act
+            .to_data()
+            .assert_approx_eq(&Data::from([std as f32]), 1);
+        mean_act
+            .to_data()
+            .assert_approx_eq(&Data::from([mean as f32]), 1);
+    }
+
+    #[test]
+    fn initializer_constant_init() {
+        let value = 5.0;
+        let constants: Tensor<TB, 4> = Initializer::Constant(value).init([2, 2, 2, 2]);
+        constants
+            .sum()
+            .to_data()
+            .assert_approx_eq(&Data::from([value as f32 * 16.0]), 3);
+    }
+
+    #[test]
+    fn initializer_zeros_init() {
+        let zeros: Tensor<TB, 4> = Initializer::Zeros.init([2, 2, 2, 2]);
+        zeros
+            .sum()
+            .to_data()
+            .assert_approx_eq(&Data::from([0.0]), 3);
+    }
+
+    #[test]
+    fn initializer_ones_init() {
+        let ones: Tensor<TB, 4> = Initializer::Ones.init([2, 2, 2, 2]);
+        ones.sum()
+            .to_data()
+            .assert_approx_eq(&Data::from([16.0]), 3);
+    }
+}
diff --git a/burn-core/src/nn/linear.rs b/burn-core/src/nn/linear.rs
index a611b21e0..cdb64a45a 100644
--- a/burn-core/src/nn/linear.rs
+++ b/burn-core/src/nn/linear.rs
@@ -5,11 +5,12 @@ use crate as burn;
 use crate::config::Config;
 use crate::module::Module;
 use crate::module::Param;
-use crate::tensor::backend::Backend;
-use crate::tensor::{Distribution, ElementConversion, Tensor};
+use crate::tensor::{backend::Backend, Tensor};
 
 use libm::sqrt;
 
+use super::Initializer;
+
 /// Configuration to create a [Linear](Linear) layer.
 #[derive(Config)]
 pub struct LinearConfig {
@@ -20,6 +21,9 @@ pub struct LinearConfig {
     /// If a bias should be applied during the linear transformation.
     #[config(default = true)]
     pub bias: bool,
+    /// The type of function used to initialize neural network parameters
+    #[config(default = "Initializer::UniformDefault")]
+    pub initializer: Initializer,
 }
 
 /// Applies a linear transformation to the input tensor:
@@ -43,12 +47,19 @@ impl<B: Backend> Linear<B> {
     /// Create the module from the given configuration.
     pub fn new(config: &LinearConfig) -> Self {
         let k = sqrt(1.0 / config.d_input as f64);
 
-        let distribution = Distribution::Uniform((-1.0 * k).elem(), k.elem());
-        let weight = Tensor::random([config.d_input, config.d_output], distribution);
-        let bias = match config.bias {
-            true => Some(Tensor::random([config.d_output], distribution)),
-            false => None,
+        let initializer = if let Initializer::UniformDefault = config.initializer {
+            Initializer::Uniform(-k, k)
+        } else {
+            config.initializer.clone()
+        };
+
+        let weight = initializer.init([config.d_input, config.d_output]);
+
+        let bias = if config.bias {
+            Some(initializer.init([config.d_output]))
+        } else {
+            None
         };
 
         Self {
@@ -72,3 +83,35 @@ impl<B: Backend> Linear<B> {
         }
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    pub type TB = burn_ndarray::NdArrayBackend<f32>;
+
+    #[test]
+    fn initializer_default() {
+        TB::seed(0);
+        let config = LinearConfig::new(5, 5);
+        let k = sqrt(1.0 / config.d_input as f64);
+
+        assert_eq!(config.initializer, Initializer::UniformDefault);
+        let conv: Linear<TB> = Linear::new(&config);
+        for item in conv.weight.to_data().value.iter() {
+            if *item < -k as f32 || *item > k as f32 {
+                panic!("Element ({item}) is not within the range of (-{k},{k})");
+            }
+        }
+    }
+
+    #[test]
+    fn initializer_zeros() {
+        TB::seed(0);
+        let config = LinearConfig::new(5, 5).with_initializer(Initializer::Zeros);
+        assert_eq!(config.initializer, Initializer::Zeros);
+        let conv: Linear<TB> = Linear::new(&config);
+        for item in conv.weight.to_data().value.iter() {
+            assert_eq!(*item, 0.0f32);
+        }
+    }
+}
diff --git a/burn-core/src/nn/mod.rs b/burn-core/src/nn/mod.rs
index 95bf92319..12c867eac 100644
--- a/burn-core/src/nn/mod.rs
+++ b/burn-core/src/nn/mod.rs
@@ -8,6 +8,7 @@ pub mod transformer;
 mod dropout;
 mod embedding;
 mod gelu;
+mod initializer;
 mod linear;
 mod norm;
 mod relu;
@@ -15,6 +16,7 @@ mod relu;
 pub use dropout::*;
 pub use embedding::*;
 pub use gelu::*;
+pub use initializer::*;
 pub use linear::*;
 pub use norm::*;
 pub use relu::*;
diff --git a/burn-core/src/optim/grad_accum.rs b/burn-core/src/optim/grad_accum.rs
index c6f89df38..3751c8a26 100644
--- a/burn-core/src/optim/grad_accum.rs
+++ b/burn-core/src/optim/grad_accum.rs
@@ -106,11 +106,7 @@ mod tests {
     }
 
     fn layer() -> Linear<TestADBackend> {
-        Linear::<TestADBackend>::new(&LinearConfig {
-            d_input: 20,
-            d_output: 20,
-            bias: true,
-        })
+        Linear::<TestADBackend>::new(&LinearConfig::new(20, 20).with_bias(true))
     }
 
     fn random_tensor() -> Tensor<TestADBackend, 2> {
diff --git a/burn-core/src/optim/grads.rs b/burn-core/src/optim/grads.rs
index 827f7588f..79560faaf 100644
--- a/burn-core/src/optim/grads.rs
+++ b/burn-core/src/optim/grads.rs
@@ -117,11 +117,7 @@ mod tests {
     }
 
     fn layer() -> Linear<TestADBackend> {
-        Linear::<TestADBackend>::new(&LinearConfig {
-            d_input: 20,
-            d_output: 20,
-            bias: true,
-        })
+        Linear::<TestADBackend>::new(&LinearConfig::new(20, 20).with_bias(true))
     }
 
     fn random_tensor() -> Tensor<TestADBackend, 2> {
diff --git a/burn-core/src/optim/sgd.rs b/burn-core/src/optim/sgd.rs
index 7fb2fdb10..72a146bec 100644
--- a/burn-core/src/optim/sgd.rs
+++ b/burn-core/src/optim/sgd.rs
@@ -170,11 +170,7 @@ mod tests {
     }
 
     fn layer() -> Linear<TestADBackend> {
-        Linear::<TestADBackend>::new(&LinearConfig {
-            d_input: 20,
-            d_output: 20,
-            bias: true,
-        })
+        Linear::<TestADBackend>::new(&LinearConfig::new(20, 20).with_bias(true))
     }
 
     fn sgd_with_all() -> Sgd<TestADBackend> {
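
A minimal usage sketch of the Initializer option introduced in this patch, from a downstream crate's point of view. It is illustrative only and not part of the diff: the `burn::nn` re-export path and the `NdArrayBackend<f32>` backend alias are assumptions mirroring the `pub type TB = burn_ndarray::NdArrayBackend<f32>;` aliases in the new test modules.

use burn::nn::{Initializer, Linear, LinearConfig};

// Assumed backend alias, matching the test modules above.
type B = burn_ndarray::NdArrayBackend<f32>;

fn main() {
    // Default behaviour is unchanged: `Initializer::UniformDefault` is resolved to
    // `Initializer::Uniform(-k, k)` inside `Linear::new`, reproducing the old random init.
    let _default: Linear<B> = Linear::new(&LinearConfig::new(32, 64));

    // Opting into an explicit initializer through the `with_initializer` builder method
    // generated by `#[derive(Config)]`, as exercised by the `initializer_zeros` tests.
    let config = LinearConfig::new(32, 64).with_initializer(Initializer::Zeros);
    let _zeroed: Linear<B> = Linear::new(&config);
}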