mirror of https://github.com/tracel-ai/burn.git
Reduce MNIST parameters and update mnist-inference-web to match (#242)
This commit is contained in:
parent
04d72631d7
commit
8e74ce4bb9
|
@ -48,16 +48,16 @@ values.
|
|||
Layers:
|
||||
|
||||
1. Input Image (28,28, 1ch)
|
||||
2. `Conv2d`(3x3, 8ch), `GELU`
|
||||
3. `Conv2d`(3x3, 16ch), `GELU`
|
||||
4. `Conv2d`(3x3, 24ch), `GELU`
|
||||
2. `Conv2d`(3x3, 8ch), `BatchNorm2d`, `GELU`
|
||||
3. `Conv2d`(3x3, 16ch), `BatchNorm2d`, `GELU`
|
||||
4. `Conv2d`(3x3, 24ch), `BatchNorm2d`, `GELU`
|
||||
5. `Linear`(11616, 32), `GELU`
|
||||
6. `Linear`(32, 10)
|
||||
7. Softmax Output
|
||||
|
||||
The total number of parameters is 376,712.
|
||||
The total number of parameters is 376,952.
|
||||
|
||||
The model is trained with 6 epochs and the final test accuracy is 98.03%.
|
||||
The model is trained with 4 epochs and the final test accuracy is 98.67%.
|
||||
|
||||
The training and hyperparameter information can be found in the
|
||||
[`burn` MNIST example](https://github.com/burn-rs/burn/tree/main/examples/mnist).
|
||||
|
@ -68,8 +68,8 @@ The main differentiating factor of this example's approach (compiling rust model
|
|||
other popular tools, such as [TensorFlow.js](https://www.tensorflow.org/js),
|
||||
[ONNX Runtime JS](https://onnxruntime.ai/docs/tutorials/web/) and
|
||||
[TVM Web](https://github.com/apache/tvm/tree/main/web) is the absence of runtime code. The rust
|
||||
compiler optimizes and includes only used `burn` routines. 1,507,884 bytes out of Wasm's 1,831,094
|
||||
byte file is the model's parameters. The rest of 323,210 bytes contain all the code (including
|
||||
compiler optimizes and includes only used `burn` routines. 1,509,747 bytes out of Wasm's 1,866,491
|
||||
byte file is the model's parameters. The remaining 356,744 bytes contain all the code (including
|
||||
`burn`'s `nn` components, the data deserialization library, and math operations).
|
||||
|
||||
## Future Improvements
|
||||
|
|
|
@ -5,7 +5,7 @@ use burn::module::State;
|
|||
use bincode::config;
|
||||
|
||||
const GENERATED_FILE_NAME: &str = "mnist_model_state.bincode";
|
||||
const MODEL_STATE_FILE_NAME: &str = "model-6.json.gz";
|
||||
const MODEL_STATE_FILE_NAME: &str = "model-4.json.gz";
|
||||
|
||||
/// This build step is responsible for converting the JSON-serialized model state to Bincode serialization
|
||||
/// in order to make the file small and efficient for bundling the binary into wasm code.
|
||||
|
|
|
@ -54,7 +54,7 @@
|
|||
<table>
|
||||
<tr>
|
||||
<th>Draw a digit here</th>
|
||||
<th>Auto cropped and scaled</th>
|
||||
<th>Cropped and scaled</th>
|
||||
<th>Probability result</th>
|
||||
</tr>
|
||||
<tr>
|
||||
|
|
Binary file not shown.
Binary file not shown.
|
@ -6,7 +6,7 @@ use alloc::{format, vec::Vec};
|
|||
|
||||
use burn::{
|
||||
module::{Module, Param},
|
||||
nn,
|
||||
nn::{self, conv::Conv2dPaddingConfig, BatchNorm2d},
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
|
||||
|
@ -24,12 +24,12 @@ const NUM_CLASSES: usize = 10;
|
|||
|
||||
impl<B: Backend> Model<B> {
|
||||
pub fn new() -> Self {
|
||||
let conv1 = ConvBlock::new([1, 8], [3, 3]); // out: [Batch,1,26,26]
|
||||
let conv2 = ConvBlock::new([8, 16], [3, 3]); // out: [Batch,1,24x24]
|
||||
let conv3 = ConvBlock::new([16, 24], [3, 3]); // out: [Batch,1,22x22]
|
||||
|
||||
let fc1 = nn::Linear::new(&nn::LinearConfig::new(24 * 22 * 22, 32).with_bias(false));
|
||||
let conv1 = ConvBlock::new([1, 8], [3, 3]); // out: [Batch,8,26,26]
|
||||
let conv2 = ConvBlock::new([8, 16], [3, 3]); // out: [Batch,16,24x24]
|
||||
let conv3 = ConvBlock::new([16, 24], [3, 3]); // out: [Batch,24,22x22]
|
||||
|
||||
let hidden_size = 24 * 22 * 22;
|
||||
let fc1 = nn::Linear::new(&nn::LinearConfig::new(hidden_size, 32).with_bias(false));
|
||||
let fc2 = nn::Linear::new(&nn::LinearConfig::new(32, NUM_CLASSES).with_bias(false));
|
||||
|
||||
Self {
|
||||
|
@ -50,7 +50,8 @@ impl<B: Backend> Model<B> {
|
|||
let x = self.conv2.forward(x);
|
||||
let x = self.conv3.forward(x);
|
||||
|
||||
let x = x.reshape([batch_size, 24 * 22 * 22]);
|
||||
let [batch_size, channels, heigth, width] = x.dims();
|
||||
let x = x.reshape([batch_size, channels * heigth * width]);
|
||||
|
||||
let x = self.fc1.forward(x);
|
||||
let x = self.activation.forward(x);
|
||||
|
@ -62,23 +63,29 @@ impl<B: Backend> Model<B> {
|
|||
#[derive(Module, Debug)]
|
||||
pub struct ConvBlock<B: Backend> {
|
||||
conv: Param<nn::conv::Conv2d<B>>,
|
||||
norm: Param<BatchNorm2d<B>>,
|
||||
activation: nn::GELU,
|
||||
}
|
||||
|
||||
impl<B: Backend> ConvBlock<B> {
|
||||
pub fn new(channels: [usize; 2], kernel_size: [usize; 2]) -> Self {
|
||||
let conv = nn::conv::Conv2d::new(
|
||||
&nn::conv::Conv2dConfig::new(channels, kernel_size).with_bias(false),
|
||||
&nn::conv::Conv2dConfig::new(channels, kernel_size)
|
||||
.with_padding(Conv2dPaddingConfig::Valid),
|
||||
);
|
||||
let norm = nn::BatchNorm2d::new(&nn::BatchNorm2dConfig::new(channels[1]));
|
||||
|
||||
Self {
|
||||
conv: Param::from(conv),
|
||||
norm: Param::from(norm),
|
||||
activation: nn::GELU::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let x = self.conv.forward(input);
|
||||
let x = self.norm.forward(x);
|
||||
|
||||
self.activation.forward(x)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,13 +25,12 @@ const NUM_CLASSES: usize = 10;
|
|||
|
||||
impl<B: Backend> Model<B> {
|
||||
pub fn new() -> Self {
|
||||
let conv1 = ConvBlock::new([1, 32], [3, 3]); // out: [Batch,32,28,28]
|
||||
let conv2 = ConvBlock::new([32, 32], [3, 3]); // out: [Batch,32,28x28]
|
||||
let conv3 = ConvBlock::new([32, 1], [3, 3]); // out: [Batch,1,28x28]
|
||||
|
||||
let hidden_size = 28 * 28;
|
||||
let fc1 = nn::Linear::new(&nn::LinearConfig::new(hidden_size, hidden_size));
|
||||
let fc2 = nn::Linear::new(&nn::LinearConfig::new(hidden_size, NUM_CLASSES));
|
||||
let conv1 = ConvBlock::new([1, 8], [3, 3]); // out: [Batch,8,26,26]
|
||||
let conv2 = ConvBlock::new([8, 16], [3, 3]); // out: [Batch,16,24x24]
|
||||
let conv3 = ConvBlock::new([16, 24], [3, 3]); // out: [Batch,24,22x22]
|
||||
let hidden_size = 24 * 22 * 22;
|
||||
let fc1 = nn::Linear::new(&nn::LinearConfig::new(hidden_size, 32).with_bias(false));
|
||||
let fc2 = nn::Linear::new(&nn::LinearConfig::new(32, NUM_CLASSES).with_bias(false));
|
||||
|
||||
let dropout = nn::Dropout::new(&nn::DropoutConfig::new(0.3));
|
||||
|
||||
|
@ -89,7 +88,7 @@ impl<B: Backend> ConvBlock<B> {
|
|||
pub fn new(channels: [usize; 2], kernel_size: [usize; 2]) -> Self {
|
||||
let conv = nn::conv::Conv2d::new(
|
||||
&nn::conv::Conv2dConfig::new(channels, kernel_size)
|
||||
.with_padding(Conv2dPaddingConfig::Same),
|
||||
.with_padding(Conv2dPaddingConfig::Valid),
|
||||
);
|
||||
let norm = nn::BatchNorm2d::new(&nn::BatchNorm2dConfig::new(channels[1]));
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ static ARTIFACT_DIR: &str = "/tmp/burn-example-mnist";
|
|||
|
||||
#[derive(Config)]
|
||||
pub struct MnistTrainingConfig {
|
||||
#[config(default = 10)]
|
||||
#[config(default = 4)]
|
||||
pub num_epochs: usize,
|
||||
|
||||
#[config(default = 64)]
|
||||
|
|
Loading…
Reference in New Issue