Reduce MNIST parameters and update mnist-inference-web to match (#242)

Dilshod Tadjibaev 2023-03-17 18:46:26 -05:00 committed by GitHub
parent 04d72631d7
commit 8e74ce4bb9
8 changed files with 32 additions and 26 deletions

View File

@@ -48,16 +48,16 @@ values.
Layers:

1. Input Image (28,28, 1ch)
-2. `Conv2d`(3x3, 8ch), `GELU`
-3. `Conv2d`(3x3, 16ch), `GELU`
-4. `Conv2d`(3x3, 24ch), `GELU`
+2. `Conv2d`(3x3, 8ch), `BatchNorm2d`, `GELU`
+3. `Conv2d`(3x3, 16ch), `BatchNorm2d`, `GELU`
+4. `Conv2d`(3x3, 24ch), `BatchNorm2d`, `GELU`
5. `Linear`(11616, 32), `GELU`
6. `Linear`(32, 10)
7. Softmax Output

-The total number of parameters is 376,712.
+The total number of parameters is 376,952.

-The model is trained with 6 epochs and the final test accuracy is 98.03%.
+The model is trained with 4 epochs and the final test accuracy is 98.67%.

The training and hyperparameter information can be found in
[`burn` MNIST example](https://github.com/burn-rs/burn/tree/main/examples/mnist).
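For anyone who wants to verify the new figure, here is a back-of-the-envelope count in plain Rust. It assumes the conv layers keep their default bias and the `Linear` layers are biasless (both visible in the code diffs below), and that the serialized state counts `BatchNorm2d`'s running mean and variance alongside its learned scale and shift — an assumption about `burn`'s state format, under which the arithmetic lands exactly on 376,952 (trainable parameters alone come to 376,856):

```rust
fn main() {
    // Conv2d(3x3) with bias: out_ch * in_ch * 3 * 3 weights + out_ch biases
    let conv = |i: usize, o: usize| o * i * 9 + o;
    // BatchNorm2d: learned gamma/beta plus running mean/variance
    let bn = |c: usize| 4 * c;
    let convs = conv(1, 8) + bn(8) + conv(8, 16) + bn(16) + conv(16, 24) + bn(24);
    // Both Linear layers are configured with `.with_bias(false)`
    let fcs = 11_616 * 32 + 32 * 10;
    println!("{}", convs + fcs); // prints 376952
}
```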
@@ -68,8 +68,8 @@ The main differentiating factor of this example's approach (compiling rust model
other popular tools, such as [TensorFlow.js](https://www.tensorflow.org/js),
[ONNX Runtime JS](https://onnxruntime.ai/docs/tutorials/web/) and
[TVM Web](https://github.com/apache/tvm/tree/main/web) is the absence of runtime code. The rust
-compiler optimizes and includes only used `burn` routines. 1,507,884 bytes out of Wasm's 1,831,094
-byte file is the model's parameters. The rest of 323,210 bytes contain all the code (including
+compiler optimizes and includes only the `burn` routines actually used. Of the 1,866,491-byte Wasm
+file, 1,509,747 bytes are the model's parameters; the remaining 356,744 bytes contain all the code (including
`burn`'s `nn` components, the data deserialization library, and math operations).
## Future Improvements
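As a rough cross-check (assuming 4-byte `f32` parameters, which is an assumption about the serialized format rather than something stated in this diff): 376,952 parameters × 4 bytes ≈ 1,507,808 bytes, so all but roughly 2 KB of the 1,509,747-byte parameter portion is raw weight data, the remainder presumably being serialization framing.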

View File

@@ -5,7 +5,7 @@ use burn::module::State;
use bincode::config;

const GENERATED_FILE_NAME: &str = "mnist_model_state.bincode";
-const MODEL_STATE_FILE_NAME: &str = "model-6.json.gz";
+const MODEL_STATE_FILE_NAME: &str = "model-4.json.gz";

/// This build step converts the JSON-serialized model state to Bincode serialization
/// in order to keep the file small and efficient for bundling into the Wasm binary.
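For readers outside the repo, a minimal sketch of what such a JSON-to-Bincode re-encoding step can look like. Apart from the two constants, every name here is an assumption: the actual build script deserializes into `burn`'s `State` type (imported above) rather than a generic `serde_json::Value`, and the `bincode`/`flate2` calls reflect one plausible API version:

```rust
use std::{fs::File, io::Read};

use flate2::read::GzDecoder;

const GENERATED_FILE_NAME: &str = "mnist_model_state.bincode";
const MODEL_STATE_FILE_NAME: &str = "model-4.json.gz";

fn main() {
    // Decompress the gzipped JSON model state shipped with the crate.
    let mut json = String::new();
    GzDecoder::new(File::open(MODEL_STATE_FILE_NAME).unwrap())
        .read_to_string(&mut json)
        .unwrap();

    // Re-encode with Bincode: binary floats are far smaller than their
    // ASCII-decimal JSON form, which is the point of this build step.
    let value: serde_json::Value = serde_json::from_str(&json).unwrap();
    let bytes = bincode::serde::encode_to_vec(&value, bincode::config::standard()).unwrap();
    std::fs::write(GENERATED_FILE_NAME, bytes).unwrap();
}
```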

View File

@@ -54,7 +54,7 @@
  <table>
    <tr>
      <th>Draw a digit here</th>
-     <th>Auto cropped and scaled</th>
+     <th>Cropped and scaled</th>
      <th>Probability result</th>
    </tr>
    <tr>

Binary file not shown.

View File

@@ -6,7 +6,7 @@ use alloc::{format, vec::Vec};
use burn::{
    module::{Module, Param},
-    nn,
+    nn::{self, conv::Conv2dPaddingConfig, BatchNorm2d},
    tensor::{backend::Backend, Tensor},
};
@@ -24,12 +24,12 @@ const NUM_CLASSES: usize = 10;
impl<B: Backend> Model<B> {
    pub fn new() -> Self {
-        let conv1 = ConvBlock::new([1, 8], [3, 3]); // out: [Batch,1,26,26]
-        let conv2 = ConvBlock::new([8, 16], [3, 3]); // out: [Batch,1,24x24]
-        let conv3 = ConvBlock::new([16, 24], [3, 3]); // out: [Batch,1,22x22]
-        let fc1 = nn::Linear::new(&nn::LinearConfig::new(24 * 22 * 22, 32).with_bias(false));
+        let conv1 = ConvBlock::new([1, 8], [3, 3]); // out: [Batch,8,26,26]
+        let conv2 = ConvBlock::new([8, 16], [3, 3]); // out: [Batch,16,24,24]
+        let conv3 = ConvBlock::new([16, 24], [3, 3]); // out: [Batch,24,22,22]
+        let hidden_size = 24 * 22 * 22;
+        let fc1 = nn::Linear::new(&nn::LinearConfig::new(hidden_size, 32).with_bias(false));
        let fc2 = nn::Linear::new(&nn::LinearConfig::new(32, NUM_CLASSES).with_bias(false));

        Self {
@@ -50,7 +50,8 @@ impl<B: Backend> Model<B> {
        let x = self.conv2.forward(x);
        let x = self.conv3.forward(x);
-        let x = x.reshape([batch_size, 24 * 22 * 22]);
+        let [batch_size, channels, height, width] = x.dims();
+        let x = x.reshape([batch_size, channels * height * width]);
        let x = self.fc1.forward(x);
        let x = self.activation.forward(x);
@@ -62,23 +63,29 @@
#[derive(Module, Debug)]
pub struct ConvBlock<B: Backend> {
    conv: Param<nn::conv::Conv2d<B>>,
+    norm: Param<BatchNorm2d<B>>,
    activation: nn::GELU,
}

impl<B: Backend> ConvBlock<B> {
    pub fn new(channels: [usize; 2], kernel_size: [usize; 2]) -> Self {
        let conv = nn::conv::Conv2d::new(
-            &nn::conv::Conv2dConfig::new(channels, kernel_size).with_bias(false),
+            &nn::conv::Conv2dConfig::new(channels, kernel_size)
+                .with_padding(Conv2dPaddingConfig::Valid),
        );
+        let norm = nn::BatchNorm2d::new(&nn::BatchNorm2dConfig::new(channels[1]));

        Self {
            conv: Param::from(conv),
+            norm: Param::from(norm),
            activation: nn::GELU::new(),
        }
    }

    pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
        let x = self.conv.forward(input);
+        let x = self.norm.forward(x);
        self.activation.forward(x)
    }
}
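As a sanity check on the shape bookkeeping, a hypothetical smoke test (not part of this commit; it assumes the module's imports above plus `burn`'s ndarray backend, and the exact `zeros`/`forward` signatures may differ by version):

```rust
use burn_ndarray::NdArrayBackend;

fn main() {
    // Each 3x3 convolution with Valid padding shrinks the spatial dims by 2:
    // 28x28 -> 26x26 (8 ch) -> 24x24 (16 ch) -> 22x22 (24 ch),
    // so fc1 sees 24 * 22 * 22 = 11_616 flattened features.
    let model: Model<NdArrayBackend<f32>> = Model::new();
    let input = Tensor::zeros([1, 1, 28, 28]); // one 28x28 grayscale digit
    let output = model.forward(input);
    assert_eq!(output.dims(), [1, 10]); // one score per digit class
}
```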

View File

@@ -25,13 +25,12 @@ const NUM_CLASSES: usize = 10;
impl<B: Backend> Model<B> {
    pub fn new() -> Self {
-        let conv1 = ConvBlock::new([1, 32], [3, 3]); // out: [Batch,32,28,28]
-        let conv2 = ConvBlock::new([32, 32], [3, 3]); // out: [Batch,32,28x28]
-        let conv3 = ConvBlock::new([32, 1], [3, 3]); // out: [Batch,1,28x28]
-        let hidden_size = 28 * 28;
-        let fc1 = nn::Linear::new(&nn::LinearConfig::new(hidden_size, hidden_size));
-        let fc2 = nn::Linear::new(&nn::LinearConfig::new(hidden_size, NUM_CLASSES));
+        let conv1 = ConvBlock::new([1, 8], [3, 3]); // out: [Batch,8,26,26]
+        let conv2 = ConvBlock::new([8, 16], [3, 3]); // out: [Batch,16,24,24]
+        let conv3 = ConvBlock::new([16, 24], [3, 3]); // out: [Batch,24,22,22]
+        let hidden_size = 24 * 22 * 22;
+        let fc1 = nn::Linear::new(&nn::LinearConfig::new(hidden_size, 32).with_bias(false));
+        let fc2 = nn::Linear::new(&nn::LinearConfig::new(32, NUM_CLASSES).with_bias(false));
        let dropout = nn::Dropout::new(&nn::DropoutConfig::new(0.3));
@@ -89,7 +88,7 @@ impl<B: Backend> ConvBlock<B> {
    pub fn new(channels: [usize; 2], kernel_size: [usize; 2]) -> Self {
        let conv = nn::conv::Conv2d::new(
            &nn::conv::Conv2dConfig::new(channels, kernel_size)
-                .with_padding(Conv2dPaddingConfig::Same),
+                .with_padding(Conv2dPaddingConfig::Valid),
        );
        let norm = nn::BatchNorm2d::new(&nn::BatchNorm2dConfig::new(channels[1]));
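A note on the `Same` → `Valid` switch above: with `Same` padding, a 3x3 convolution preserves the 28x28 spatial size, which is why the old `hidden_size` was `28 * 28`; with `Valid` padding, each 3x3 convolution shrinks every spatial dimension to `input - kernel + 1` (28 → 26 → 24 → 22), yielding the new `hidden_size` of `24 * 22 * 22 = 11_616` and keeping this training model in lockstep with the web inference model above.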

View File

@@ -19,7 +19,7 @@ static ARTIFACT_DIR: &str = "/tmp/burn-example-mnist";
#[derive(Config)]
pub struct MnistTrainingConfig {
-    #[config(default = 10)]
+    #[config(default = 4)]
    pub num_epochs: usize,

    #[config(default = 64)]
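For context, `burn`'s `#[derive(Config)]` turns `#[config(default = ...)]` attributes into constructor defaults plus generated `with_*` setters, so the new 4-epoch default applies without touching call sites. A minimal sketch (the zero-argument `new()` is an assumption, since any required fields of the struct fall outside this hunk):

```rust
// Sketch: a freshly built config now trains for 4 epochs by default.
let config = MnistTrainingConfig::new();
assert_eq!(config.num_epochs, 4);

// The derive also generates setters to override a default per run.
let config = config.with_num_epochs(10);
assert_eq!(config.num_epochs, 10);
```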