From 8e74ce4bb939daf17950d56b62c80678d980a1da Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Fri, 17 Mar 2023 18:46:26 -0500 Subject: [PATCH] Reduce MNIST parameters and update mnist-inference-web to match (#242) --- examples/mnist-inference-web/README.md | 14 +++++------ examples/mnist-inference-web/build.rs | 2 +- examples/mnist-inference-web/index.html | 2 +- examples/mnist-inference-web/model-4.json.gz | Bin 0 -> 1823622 bytes examples/mnist-inference-web/model-6.json.gz | Bin 1821352 -> 0 bytes examples/mnist-inference-web/src/model.rs | 23 ++++++++++++------- examples/mnist/src/model.rs | 15 ++++++------ examples/mnist/src/training.rs | 2 +- 8 files changed, 32 insertions(+), 26 deletions(-) create mode 100644 examples/mnist-inference-web/model-4.json.gz delete mode 100644 examples/mnist-inference-web/model-6.json.gz diff --git a/examples/mnist-inference-web/README.md b/examples/mnist-inference-web/README.md index 8720b8036..d1a39a3ad 100644 --- a/examples/mnist-inference-web/README.md +++ b/examples/mnist-inference-web/README.md @@ -48,16 +48,16 @@ values. Layers: 1. Input Image (28,28, 1ch) -2. `Conv2d`(3x3, 8ch), `GELU` -3. `Conv2d`(3x3, 16ch), `GELU` -4. `Conv2d`(3x3, 24ch), `GELU` +2. `Conv2d`(3x3, 8ch), `BatchNorm2d`, `GELU` +3. `Conv2d`(3x3, 16ch), `BatchNorm2d`, `GELU` +4. `Conv2d`(3x3, 24ch), `BatchNorm2d`, `GELU` 5. `Linear`(11616, 32), `GELU` 6. `Linear`(32, 10) 7. Softmax Output -The total number of parameters is 376,712. +The total number of parameters is 376,952. -The model is trained with 6 epochs and the final test accuracy is 98.03%. +The model is trained with 4 epochs and the final test accuracy is 98.67%. The training and hyper parameter information in can be found in [`burn` MNIST example](https://github.com/burn-rs/burn/tree/main/examples/mnist). @@ -68,8 +68,8 @@ The main differentiating factor of this example's approach (compiling rust model other popular tools, such as [TensorFlow.js](https://www.tensorflow.org/js), [ONNX Runtime JS](https://onnxruntime.ai/docs/tutorials/web/) and [TVM Web](https://github.com/apache/tvm/tree/main/web) is the absence of runtime code. The rust -compiler optimizes and includes only used `burn` routines. 1,507,884 bytes out of Wasm's 1,831,094 -byte file is the model's parameters. The rest of 323,210 bytes contain all the code (including +compiler optimizes and includes only used `burn` routines. 1,509,747 bytes out of Wasm's 1,866,491 +byte file is the model's parameters. The rest of 356,744 bytes contain all the code (including `burn`'s `nn` components, the data deserialization library, and math operations). ## Future Improvements diff --git a/examples/mnist-inference-web/build.rs b/examples/mnist-inference-web/build.rs index 8edca323f..241b14144 100644 --- a/examples/mnist-inference-web/build.rs +++ b/examples/mnist-inference-web/build.rs @@ -5,7 +5,7 @@ use burn::module::State; use bincode::config; const GENERATED_FILE_NAME: &str = "mnist_model_state.bincode"; -const MODEL_STATE_FILE_NAME: &str = "model-6.json.gz"; +const MODEL_STATE_FILE_NAME: &str = "model-4.json.gz"; /// This build step is responsible for converting JSON serialized to Bincode serilization /// in order to make the file small and efficient for bundling the binary into wasm code. diff --git a/examples/mnist-inference-web/index.html b/examples/mnist-inference-web/index.html index 4d8b8f17c..15bda279d 100644 --- a/examples/mnist-inference-web/index.html +++ b/examples/mnist-inference-web/index.html @@ -54,7 +54,7 @@