Add image classification web demo with WebGPU, CPU backends (#840)

Dilshod Tadjibaev 2023-10-05 09:29:13 -05:00 committed by GitHub
parent 28e2a99efe
commit e2a17e4295
31 changed files with 2045 additions and 63 deletions

View File

@ -4,31 +4,32 @@
resolver = "2"
members = [
"burn",
"burn-autodiff",
"burn-common",
"burn-compute",
"burn-core",
"burn-dataset",
"burn-derive",
"burn-import",
"burn-import/onnx-tests",
"burn-ndarray",
"burn-no-std-tests",
"burn-tch",
"burn-wgpu",
"burn-candle",
"burn-tensor-testgen",
"burn-tensor",
"burn-train",
"xtask",
"examples/*",
"backend-comparison",
"burn",
"burn-autodiff",
"burn-common",
"burn-compute",
"burn-core",
"burn-dataset",
"burn-derive",
"burn-import",
"burn-import/onnx-tests",
"burn-ndarray",
"burn-no-std-tests",
"burn-tch",
"burn-wgpu",
"burn-candle",
"burn-tensor-testgen",
"burn-tensor",
"burn-train",
"xtask",
"examples/*",
"backend-comparison",
]
exclude = ["examples/notebook"]
[workspace.dependencies]
async-trait = "0.1.73"
bytemuck = "1.13"
const-random = "0.1.15"
csv = "1.2.2"
@ -37,11 +38,12 @@ dirs = "5.0.1"
fake = "2.6.1"
flate2 = "1.0.26"
float-cmp = "0.9.0"
getrandom = { version = "0.2.10", default-features = false }
gix-tempfile = { version = "8.0.0", features = ["signals"] }
hashbrown = "0.14.0"
indicatif = "0.17.5"
libm = "0.2.7"
log = "0.4.19"
log = { default-features = false, version = "0.4.19" }
pretty_assertions = "1.3"
proc-macro2 = "1.0.60"
protobuf-codegen = "3.2"
@ -55,15 +57,17 @@ rusqlite = { version = "0.29" }
sanitize-filename = "0.5.0"
serde_rusqlite = "0.33.1"
spin = { version = "0.9.8", features = ["mutex", "spin_mutex"] }
strum = "0.24"
strum_macros = "0.24"
strum = "0.25.0"
strum_macros = "0.25.2"
syn = { version = "2.0", features = ["full", "extra-traits"] }
tempfile = "3.6.0"
thiserror = "1.0.40"
tracing-subscriber = "0.3.17"
tracing-core = "0.1.31"
tracing-appender = "0.2.2"
async-trait = "0.1.73"
tracing-core = "0.1.31"
tracing-subscriber = "0.3.17"
wasm-bindgen = "0.2.87"
wasm-bindgen-futures = "0.4.37"
wasm-logger = "0.2.0"
# WGPU stuff
futures-intrusive = "0.5"
@ -75,26 +79,27 @@ wgpu = "0.17.0"
# The following packages disable the "std" feature for no_std compatibility
#
bincode = { version = "2.0.0-rc.3", features = [
"alloc",
"serde",
"alloc",
"serde",
], default-features = false }
derive-new = { version = "0.5.9", default-features = false }
half = { version = "2.3.1", features = [
"alloc",
"num-traits",
"serde",
"alloc",
"num-traits",
"serde",
], default-features = false }
ndarray = { version = "0.15.6", default-features = false }
num-traits = { version = "0.2.15", default-features = false, features = [
"libm",
"libm",
] } # libm is for no_std
rand = { version = "0.8.5", default-features = false, features = [
"std_rng",
"std_rng",
] } # std_rng is for no_std
rand_distr = { version = "0.4.3", default-features = false }
serde = { version = "1.0.164", default-features = false, features = [
"derive",
"alloc",
"derive",
"alloc",
] } # alloc is for no_std, derive is needed
serde_json = { version = "1.0.96", default-features = false }
uuid = { version = "1.3.4", default-features = false }

View File

@ -1,9 +1,8 @@
[default]
extend-ignore-identifiers-re = [
"ratatui",
"NdArray*",
"ND"
]
extend-ignore-identifiers-re = ["ratatui", "NdArray*", "ND"]
[files]
extend-exclude = ["assets/ModuleSerialization.xml"]
extend-exclude = [
"assets/ModuleSerialization.xml",
"examples/image-classification-web/src/model/label.txt",
]

View File

@ -14,8 +14,8 @@ version = "0.10.0"
[dependencies]
derive-new = { workspace = true }
burn-tensor = { path = "../burn-tensor", version = "0.10.0" }
half = { workspace = true, features = ["std"] }
burn-tensor = { path = "../burn-tensor", version = "0.10.0", default-features = false }
half = { workspace = true }
# candle-core = { version = "0.1.2" }
candle-core = { git = "https://github.com/huggingface/candle", rev = "237323c" }

View File

@ -18,7 +18,7 @@ std = ["rand/std"]
[target.'cfg(target_family = "wasm")'.dependencies]
async-trait = { workspace = true }
getrandom = { version = "0.2", features = ["js"] }
getrandom = { workspace = true, features = ["js"] }
[dependencies]
# ** Please make sure all dependencies support no_std when std is disabled **

View File

@ -13,17 +13,16 @@ version = "0.10.0"
[features]
default = ["std", "dataset-minimal"]
std = [
"burn-common/std",
"burn-tensor/std",
"flate2",
"log",
"rand/std",
"rmp-serde",
"serde/std",
"serde_json/std",
"bincode/std",
"half/std",
"derive-new/std",
"burn-common/std",
"burn-tensor/std",
"flate2",
"log",
"rand/std",
"rmp-serde",
"serde/std",
"serde_json/std",
"bincode/std",
"half/std",
]
dataset = ["burn-dataset/default"]
dataset-minimal = ["burn-dataset"]
@ -35,10 +34,18 @@ autodiff = ["burn-autodiff"]
ndarray = ["__ndarray", "burn-ndarray/default"]
ndarray-no-std = ["__ndarray", "burn-ndarray"]
ndarray-blas-accelerate = ["__ndarray", "ndarray", "burn-ndarray/blas-accelerate"]
ndarray-blas-accelerate = [
"__ndarray",
"ndarray",
"burn-ndarray/blas-accelerate",
]
ndarray-blas-netlib = ["__ndarray", "ndarray", "burn-ndarray/blas-netlib"]
ndarray-blas-openblas = ["__ndarray", "ndarray", "burn-ndarray/blas-openblas"]
ndarray-blas-openblas-system = ["__ndarray", "ndarray", "burn-ndarray/blas-openblas-system"]
ndarray-blas-openblas-system = [
"__ndarray",
"ndarray",
"burn-ndarray/blas-openblas-system",
]
__ndarray = [] # Internal flag to know when one ndarray feature is enabled.
wgpu = ["burn-wgpu/default"]
@ -48,8 +55,8 @@ tch = ["burn-tch"]
# Serialization formats
experimental-named-tensor = ["burn-tensor/experimental-named-tensor"]
test-tch = ["tch"] # To use tch during testing, default uses ndarray.
test-wgpu = ["wgpu"] # To use wgpu during testing, default uses ndarray.
test-tch = ["tch"] # To use tch during testing, default uses ndarray.
test-wgpu = ["wgpu"] # To use wgpu during testing, default uses ndarray.
[dependencies]
@ -66,7 +73,7 @@ burn-autodiff = { path = "../burn-autodiff", version = "0.10.0", optional = true
burn-wgpu = { path = "../burn-wgpu", version = "0.10.0", optional = true }
burn-tch = { path = "../burn-tch", version = "0.10.0", optional = true }
derive-new = { workspace = true, default-features = false }
derive-new = { workspace = true }
libm = { workspace = true }
log = { workspace = true, optional = true }
rand = { workspace = true, features = ["std_rng"] } # Default enables std
@ -88,7 +95,7 @@ serde_json = { workspace = true, features = ["alloc"] } #Default enables std
[dev-dependencies]
tempfile = { workspace = true }
burn-dataset = { path = "../burn-dataset", version = "0.10.0", features = [
"fake",
"fake",
] }
burn-ndarray = { path = "../burn-ndarray", version = "0.10.0", default-features = false }

View File

@ -17,8 +17,7 @@ default = ["onnx"]
onnx = []
[dependencies]
burn = {path = "../burn", version = "0.10.0" }
burn-common = {path = "../burn-common", version = "0.10.0" }
burn = {path = "../burn", version = "0.10.0"}
burn-ndarray = {path = "../burn-ndarray", version = "0.10.0" }
bytemuck = {workspace = true}

View File

@ -0,0 +1,35 @@
[package]
authors = ["Dilshod Tadjibaev (@antimora)"]
edition = "2021"
license = "MIT OR Apache-2.0"
name = "image-classification-web"
publish = false
version = "0.10.0"
[lib]
crate-type = ["cdylib"]
[features]
default = []
[dependencies]
burn = { path = "../../burn", default-features = false, features = [
"ndarray-no-std",
"wgpu",
] }
burn-candle = { path = "../../burn-candle", version = "0.10.0", default-features = false }
js-sys = "0.3.64"
log = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
serde-wasm-bindgen = "0.6.0"
wasm-bindgen = "0.2.87"
wasm-bindgen-futures = "0.4.37"
wasm-logger = "0.2.0"
wasm-timer = "0.2.5"
[build-dependencies]
# Used to generate code from ONNX model
burn-import = { path = "../../burn-import" }

View File

@ -0,0 +1,52 @@
# NOTICES AND INFORMATION
This file contains notices and information required by the libraries and sample resources that this repository copied or derived content from. The use of the following resources complies with their licenses.
## Sample Images
Image Title: Domestic cat, a ten month old female.
Author: Von.grzanka
Source: https://commons.wikimedia.org/wiki/File:Felis_catus-cat_on_snow.jpg
License: https://creativecommons.org/licenses/by-sa/3.0/
Image Title: The George Washington Bridge over the Hudson River leading to New York City as seen from Fort Lee, New Jersey.
Author: John O'Connell
Source: https://commons.wikimedia.org/wiki/File:George_Washington_Bridge_from_New_Jersey-edit.jpg
License: https://creativecommons.org/licenses/by/2.0/deed.en
Image Title: Coyote from Yosemite National Park, California in snow
Author: Yathin S Krishnappa
Source: https://commons.wikimedia.org/wiki/File:2009-Coyote-Yosemite.jpg
License: https://creativecommons.org/licenses/by-sa/3.0/deed.en
Image Title: Table lamp with a lampshade illuminated by sunlight.
Author: LoMit
Source: https://commons.wikimedia.org/wiki/File:Lamp_with_a_lampshade_illuminated_by_sunlight.jpg
License: https://creativecommons.org/licenses/by-sa/4.0/deed.en
Image Title: White Pelican Pelecanus onocrotalus at Walvis Bay, Namibia
Author: Rui Ornelas
Source: https://commons.wikimedia.org/wiki/File:Pelikan_Walvis_Bay.jpg
License: https://creativecommons.org/licenses/by/2.0/deed.en
Image Title: Photo of a traditional torch to be posted at gates
Author: Faizul Latif Chowdhury
Source: https://commons.wikimedia.org/wiki/File:Torch_traditional.jpg
License: https://creativecommons.org/licenses/by-sa/3.0/deed.en
Image Title: American Flamingo Phoenicopterus ruber at Gotomeer, Riscado, Bonaire
Author: Paul Asman and Jill Lenoble
Source: https://commons.wikimedia.org/wiki/File:Phoenicopterus_ruber_Bonaire_2.jpg
License: https://creativecommons.org/licenses/by/2.0/deed.en
## ONNX Model
The SqueezeNet 1.1 model is licensed under the Apache License 2.0. The model is downloaded from the [ONNX model zoo](https://github.com/onnx/models/tree/main).
Source: https://github.com/onnx/models/blob/main/vision/classification/squeezenet/model/squeezenet1.1-7.onnx
License: Apache License 2.0
License URL: https://github.com/onnx/models/blob/main/LICENSE
## ONNX Labels
The labels for the SqueezeNet 1.1 model are licensed under the Apache License 2.0. The labels are downloaded from the [ONNX model zoo](https://github.com/onnx/models/blob/main/vision/classification/synset.txt).

View File

@ -0,0 +1,74 @@
# Image Classification Web Demo Using Burn and WebAssembly
## Overview
This demo showcases how to execute an image classification task in a web browser using a model
converted to Rust code. The project utilizes the Burn deep learning framework, WebGPU, and
WebAssembly. Specifically, it demonstrates:
1. Converting an ONNX (Open Neural Network Exchange) model into Rust code compatible with the Burn
   framework.
2. Executing the model within a web browser using WebGPU via the `burn-wgpu` backend and
   WebAssembly through the `burn-ndarray` and `burn-candle` backends (see the sketch below).
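For a rough sense of how the generated model is used from Rust, here is a minimal sketch. It
assumes the `ndarray` backend and the `squeezenet1` module that `burn-import` generates from the
ONNX file; the real flow, including input normalization and backend switching, lives in
`src/web.rs`:

```rust
use burn::backend::NdArrayBackend;
use burn::tensor::{activation::softmax, Tensor};

type B = NdArrayBackend<f32>;

fn classify(pixels: &[f32]) -> Vec<f32> {
    // `from_embedded` loads the weights that were compiled into the binary.
    let model = crate::model::squeezenet1::Model::<B>::from_embedded();
    // A 224x224 RGB image, flattened channel-first and scaled to [0, 1].
    // (The real demo also normalizes with the ImageNet mean/std first.)
    let input: Tensor<B, 4> = Tensor::from_floats(pixels).reshape([1, 3, 224, 224]);
    // Softmax turns the raw logits into a probability distribution.
    softmax(model.forward(input), 1).into_data().convert::<f32>().value
}
```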
## Running the Demo
### Step 1: Build the WebAssembly Binary and Other Assets
To compile the Rust code into WebAssembly and build other essential files, execute the following
script:
```bash
./build-for-web.sh
```
### Step 2: Launch the Web Server
Run the following command to initiate a web server on your local machine:
```bash
./run-server.sh
```
### Step 3: Access the Web Demo
Open your web browser and navigate to:
```plaintext
http://localhost:8000
```
## Backend Compatibility
At the time of writing, the WebGPU backend is compatible only with Chrome running on macOS and
Windows. The application dynamically detects whether WebGPU support is available and proceeds
accordingly.
## SIMD Support
The build targets two sets of binaries, one with SIMD support and one without. The web application
dynamically detects if SIMD support is available and downloads the appropriate binary.
## Model Information
The image classification task is achieved using the SqueezeNet model, a compact Convolutional Neural
Network (CNN). It is trained on the ImageNet dataset and can classify images into 1,000 distinct
categories. The included ONNX model is sourced from the
[ONNX Model Zoo](https://github.com/onnx/models/tree/main/vision/classification/squeezenet). For
further details about the model's architecture and performance, you can refer to the
[original paper](https://arxiv.org/abs/1602.07360).
## Credits
This demo was inspired by the ONNX Runtime web demo featuring the
[SqueezeNet model trained on ImageNet](https://microsoft.github.io/onnxruntime-web-demo/#/squeezenet).
The complete list of credits/attribution can be found in the [NOTICES](NOTICES.md) file.
## Future Enhancements
- [ ] Fall back to WebGL if WebGPU is not supported by the browser. See
  [wgpu's WebGL support](https://github.com/gfx-rs/wgpu/wiki/Running-on-the-Web-with-WebGPU-and-WebGL).
- [ ] Enable SIMD support for Safari browsers after Release 179.
- [ ] Add image paste functionality to allow users to paste an image from the clipboard.

View File

@ -0,0 +1,18 @@
#!/usr/bin/env bash

# Add the wasm32 target for the compiler.
rustup target add wasm32-unknown-unknown
if ! command -v wasm-pack &>/dev/null; then
echo "wasm-pack could not be found. Installing ..."
cargo install wasm-pack
exit
fi
mkdir -p pkg
echo "Building SIMD version of wasm for web ..."
export RUSTFLAGS="-C lto=fat -C embed-bitcode=yes -C codegen-units=1 -C opt-level=3 -Ctarget-feature=+simd128 --cfg web_sys_unstable_apis"
wasm-pack build --dev --out-dir pkg/simd --target web --no-typescript
echo "Building Non-SIMD version of wasm for web ..."
export RUSTFLAGS="-C lto=fat -C embed-bitcode=yes -C codegen-units=1 -C opt-level=3 --cfg web_sys_unstable_apis"
wasm-pack build --dev --out-dir pkg/no_simd --target web --no-typescript

View File

@ -0,0 +1,75 @@
//! This build script generates the model code from the ONNX file and the labels from the text file.
use std::env;
use std::fs::File;
use std::io::{BufRead, BufReader, Write};
use std::path::Path;
use burn_import::burn::graph::RecordType;
use burn_import::onnx::ModelGen;
const LABEL_SOURCE_FILE: &str = "src/model/label.txt";
const LABEL_DEST_FILE: &str = "model/label.rs";
const INPUT_ONNX_FILE: &str = "src/model/squeezenet1.onnx";
const OUT_DIR: &str = "model/";
fn main() {
// Re-run the build script if model files change.
println!("cargo:rerun-if-changed=src/model");
// Check if half precision is enabled.
let half_precision = cfg!(feature = "half_precision");
// Generate the model code from the ONNX file.
ModelGen::new()
.input(INPUT_ONNX_FILE)
.out_dir(OUT_DIR)
.record_type(RecordType::Bincode)
.embed_states(true)
.half_precision(half_precision)
.run_from_script();
// Generate the labels from the label.txt file.
generate_labels_from_txt_file().unwrap();
}
/// Read labels from label.txt (the ImageNet synset labels) and store them in a static array in a generated Rust file.
fn generate_labels_from_txt_file() -> std::io::Result<()> {
let out_dir = env::var("OUT_DIR").unwrap();
let dest_path = Path::new(&out_dir).join(LABEL_DEST_FILE);
let mut f = File::create(&dest_path)?;
let file = File::open(LABEL_SOURCE_FILE)?;
let reader = BufReader::new(file);
writeln!(f, "pub static LABELS: &[&str] = &[")?;
for line in reader.lines() {
writeln!(
f,
" \"{}\",",
extract_simple_label(line.unwrap()).unwrap()
)?;
}
writeln!(f, "];")?;
Ok(())
}
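// For illustration, the generated OUT_DIR/model/label.rs is expected to look
// roughly like this (label values assume the standard ImageNet synset order):
//
// pub static LABELS: &[&str] = &[
//     "tench",
//     "goldfish",
//     // ... one entry per class, 1,000 in total
// ];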
/// Extract the simple label from the full label.
///
/// The full label is of the form: "n01537544 indigo bunting, indigo finch, indigo bird, Passerina cyanea"
/// The simple label is of the form: "indigo bunting"
fn extract_simple_label(input: String) -> Option<String> {
// Split the string based on the space character.
let mut parts = input.split(' ');
// Skip the first part (the alphanumeric code).
parts.next()?;
// Get the remaining string.
let remaining = parts.collect::<Vec<&str>>().join(" ");
// Find the first comma, if it exists, and take the substring before it.
let end_index = remaining.find(',').unwrap_or(remaining.len());
Some(remaining[0..end_index].to_string())
}
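// A quick sanity check of the label parsing above. Illustrative only: unit
// tests in a build script are not picked up by `cargo test`.
#[cfg(test)]
mod tests {
    use super::extract_simple_label;

    #[test]
    fn takes_text_after_the_code_and_before_the_first_comma() {
        let full = "n01537544 indigo bunting, indigo finch, indigo bird, Passerina cyanea";
        assert_eq!(
            extract_simple_label(full.to_string()),
            Some("indigo bunting".to_string())
        );
    }

    #[test]
    fn handles_labels_without_commas() {
        assert_eq!(
            extract_simple_label("n01440764 tench".to_string()),
            Some("tench".to_string())
        );
    }
}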

View File

@ -0,0 +1,11 @@
import http.server
import socketserver
PORT = 8000
Handler = http.server.SimpleHTTPRequestHandler
with socketserver.TCPServer(("", PORT), Handler) as httpd:
print(f"Running local python HTTP server on port {PORT} ...")
print(f"Serving HTTP on http://localhost:{PORT}/ ...")
httpd.serve_forever()

View File

@ -0,0 +1,57 @@
.container {
width: 100%;
max-width: 800px;
margin: auto;
}
.selections {
display: flex;
justify-content: space-between;
}
.select-box {
margin-bottom: 20px;
}
.file-input-box {
margin-bottom: 20px;
}
.actions {
display: flex;
justify-content: space-between;
margin-top: 20px;
}
#chart {
border: 1px solid #aaa;
/* width: 600px; */
/* height: 300px; */
}
#imageCanvas {
border: 1px solid #aaa;
}
/* Wrapping container that lays the boxes out side by side */
.row-container {
display: flex;
align-items: center; /* Vertically center the content */
justify-content: space-between; /* Distributes space between the items */
flex-wrap: wrap; /* Allows the flex items to wrap */
}
.canvas-box,
.chart-box {
flex: 1; /* Takes up equal width */
}
#time {
text-align: center;
white-space: nowrap;
padding: 10px;
}
#time > span {
font-weight: bold;
}

View File

@ -0,0 +1,243 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Image Classification</title>
<script
src="https://cdn.jsdelivr.net/npm/wasm-feature-detect@1.5.1/dist/umd/index.min.js"
integrity="sha256-9+AQR2dApXE+f/D998vy0RATN/o4++mqVjAZ3lo432g="
crossorigin="anonymous"
></script>
<script
src="https://cdn.jsdelivr.net/npm/chart.js@4.2.1/dist/chart.umd.min.js"
integrity="sha256-tgiW1vJqfIKxE0F2uVvsXbgUlTyrhPMY/sm30hh/Sxc="
crossorigin="anonymous"
></script>
<script
src="https://cdn.jsdelivr.net/npm/chartjs-plugin-datalabels@2.2.0/dist/chartjs-plugin-datalabels.min.js"
integrity="sha256-IMCPPZxtLvdt9tam8RJ8ABMzn+Mq3SQiInbDmMYwjDg="
crossorigin="anonymous"
></script>
<script src="./index.js"></script>
<link
rel="stylesheet"
href="https://cdn.jsdelivr.net/npm/normalize.min.css@8.0.1/normalize.min.css"
integrity="sha256-oeib74n7OcB5VoyaI+aGxJKkNEdyxYjd2m3fi/3gKls="
crossorigin="anonymous"
/>
<link rel="stylesheet" href="./index.css" />
</head>
<body>
<div class="container">
<div class="selections">
<!-- Backend Selection -->
<div class="select-box">
1.
<label for="backend">Backend:</label>
<select id="backend">
<option value="ndarray" selected>CPU - Ndarray</option>
<option value="candle">CPU - Candle</option>
<option value="webgpu">GPU - WebGPU</option>
</select>
</div>
</div>
<div class="row-container">
<!-- Image Selection -->
<div class="select-box">
2.
<select id="imageDropdown">
<option value="" selected>Select Image</option>
<option value="samples/bridge.jpg">Bridge</option>
<option value="samples/cat.jpg">Cat</option>
<option value="samples/coyote.jpg">Coyote</option>
<option value="samples/flamingo.jpg">Flamingo</option>
<option value="samples/pelican.jpg">Pelican</option>
<option value="samples/table-lamp.jpg">Table Lamp</option>
<option value="samples/torch.jpg">Torch</option>
</select>
or
<input type="file" id="fileInput" accept="image/*" />
</div>
</div>
<!-- Time Taken -->
<div id="time">&nbsp;</div>
<!-- Container for the canvas and chart boxes -->
<div class="row-container">
<!-- Canvas to Display Image -->
<div class="canvas-box">
<canvas id="imageCanvas" width="224" height="224"></canvas>
</div>
<!-- Chart -->
<div class="chart-box">
<canvas id="chart" width="500" height="224"></canvas>
</div>
</div>
<!-- Clear Button -->
<div class="actions">
<button id="clearButton">Clear</button>
</div>
</div>
<!-- JavaScript Logic -->
<script type="module">
// TODO - Move this to a separate file (index.js)
// DOM Elements
const imgDropdown = $("imageDropdown");
const backendDropdown = $("backend");
const fileInput = $("fileInput");
const canvas = $("imageCanvas");
const ctx = canvas.getContext("2d", { willReadFrequently: true });
const clearButton = $("clearButton");
const time = $("time");
const chart = chartConfigBuilder($("chart"));
// Event Handlers
imgDropdown.addEventListener("change", handleImageDropdownChange);
backendDropdown.addEventListener("change", handleBackendDropdownChange);
fileInput.addEventListener("change", handleFileInputChange);
clearButton.addEventListener("click", resetCanvasAndInputs);
// Module level variables
let imageClassifier;
async function initWasm() {
let simdSupported = await wasmFeatureDetect.simd();
if (isSafari()) {
// TODO enable simd for Safari once it works
// For some reason NDarray backend is not working on Safari with SIMD enabled
// Got the following error:
// recursive use of an object detected which would lead to unsafe aliasing in rust
console.warn("Safari detected. Disabling wasm simd for now ...");
simdSupported = false;
}
if (simdSupported) {
console.debug("SIMD is supported");
} else {
console.debug("SIMD is not supported");
}
let modulePath = simdSupported
? "./pkg/simd/image_classification_web.js"
: "./pkg/no_simd/image_classification_web.js";
const { default: wasm, ImageClassifier } = await import(modulePath);
wasm().then(() => {
// Initialize the classifier and save to module level variable
imageClassifier = new ImageClassifier();
});
}
initWasm();
// Check if WebGPU is supported
if (!navigator.gpu) {
backendDropdown.options[2].disabled = true;
alert("WebGPU is not supported on this device.\n\nDisabling WebGPU backend ...");
}
// Function Definitions
async function handleImageDropdownChange() {
if (this.value) {
await loadImage(this.value);
}
// Reset file input
fileInput.value = "";
}
async function handleBackendDropdownChange() {
const backend = this.value;
if (backend === "ndarray") await imageClassifier.set_backend_ndarray();
if (backend === "candle") await imageClassifier.set_backend_candle();
if (backend === "webgpu") await imageClassifier.set_backend_wgpu();
resetCanvasAndInputs();
}
function handleFileInputChange() {
if (this.files && this.files[0]) {
const reader = new FileReader();
reader.onload = (event) => loadImage(event.target.result);
reader.readAsDataURL(this.files[0]);
// Reset image dropdown
imgDropdown.selectedIndex = 0;
}
}
function resetCanvasAndInputs() {
// Clear canvas and reset inputs
ctx.clearRect(0, 0, canvas.width, canvas.height);
// Reset dropdowns
imgDropdown.selectedIndex = 0;
// Reset file input
fileInput.value = "";
// Clear chart
chart.data.labels = ["", "", "", "", ""];
chart.data.datasets[0].data = [0.0, 0.0, 0.0, 0.0, 0.0];
chart.update();
// Clear time
time.innerHTML = " ";
console.log("Cleared canvas");
}
async function loadImage(src) {
const img = new Image();
img.src = src;
await new Promise((resolve) => {
img.onload = resolve;
});
clearAndDrawCanvas(img);
runInference();
}
async function runInference() {
const data = extractRGBValuesFromCanvas(canvas, ctx);
// Run inference
const startTime = performance.now();
const output = await imageClassifier.inference(data);
const timeTaken = performance.now() - startTime;
// Update chart
const { labels, probabilities } = extractLabelsAndProbabilities(output);
chart.data.labels = labels;
chart.data.datasets[0].data = probabilities;
chart.update();
time.innerHTML = `Inference Time: <span> ${toFixed(timeTaken)} </span> ms.`;
}
function clearAndDrawCanvas(img) {
// Clear canvas
ctx.clearRect(0, 0, canvas.width, canvas.height);
ctx.drawImage(img, 0, 0, 224, 224);
}
</script>
</body>
</html>

View File

@ -0,0 +1,152 @@
/**
*
* This demo is part of the Burn project: https://github.com/burn-rs/burn
*
* Released under a dual license:
* https://github.com/burn-rs/burn/blob/main/LICENSE-MIT
* https://github.com/burn-rs/burn/blob/main/LICENSE-APACHE
*
*/
/**
* Looks up element by an id.
* @param {string} id - Element id.
*/
function $(id) {
return document.getElementById(id);
}
/**
* Truncates number to a given decimal position
* @param {number} num - Number to truncate.
* @param {number} fixed - Decimal positions.
* src: https://stackoverflow.com/a/11818658
*/
function toFixed(num, fixed) {
const re = new RegExp('^-?\\d+(?:\\.\\d{0,' + (fixed || -1) + '})?');
return num.toString().match(re)[0];
}
/**
* Helper function that builds a chart using Chart.js library.
* @param {object} chartEl - Chart canvas element.
*
* NOTE: Assumes chart.js is loaded into the global.
*/
function chartConfigBuilder(chartEl) {
Chart.register(ChartDataLabels);
return new Chart(chartEl, {
plugins: [ChartDataLabels],
type: "bar",
data: {
labels: ["", "", "", "", "",],
datasets: [
{
data: [0.0, 0.0, 0.0, 0.0, 0.0], // Five slots, one per top-5 prediction
borderWidth: 0,
fill: true,
backgroundColor: "#247ABF",
axis: 'y',
},
],
},
options: {
responsive: false,
maintainAspectRatio: false,
animation: true,
plugins: {
legend: {
display: false,
},
tooltip: {
enabled: true,
},
datalabels: {
color: "white",
formatter: function (value, context) {
return toFixed(value, 2);
},
},
},
indexAxis: 'y',
scales: {
y: {
},
x: {
suggestedMin: 0.0,
suggestedMax: 1.0,
beginAtZero: true,
},
},
},
});
}
/** Helper function that extracts labels and probabilities from the data.
* @param {object} data - Data object.
* @returns {object} - Object with labels and probabilities.
*/
function extractLabelsAndProbabilities(data) {
const labels = [];
const probabilities = [];
for (let item of data) {
if (item.hasOwnProperty('label') && item.hasOwnProperty('probability')) {
labels.push(item.label);
probabilities.push(item.probability);
}
}
return {
labels,
probabilities
};
}
/**
* Helper function that extracts RGB values from a canvas.
* @param {object} canvas - Canvas element.
* @param {object} ctx - Canvas context.
* @returns {object} - Flattened array of RGB values.
*/
function extractRGBValuesFromCanvas(canvas, ctx) {
// Get image data from the canvas
const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
// Get canvas dimensions
const height = canvas.height;
const width = canvas.width;
// Create a flattened array to hold the RGB values in channel-first order
const flattenedArray = new Float32Array(3 * height * width);
// Initialize indices for R, G, B channels in the flattened array
let kR = 0,
kG = height * width,
kB = 2 * height * width;
for (let y = 0; y < height; y++) {
for (let x = 0; x < width; x++) {
// Compute the index for the image data array
const index = (y * width + x) * 4;
// Fill in the R, G, B channels in the flattened array
flattenedArray[kR++] = imageData.data[index] / 255.0; // Red
flattenedArray[kG++] = imageData.data[index + 1] / 255.0; // Green
flattenedArray[kB++] = imageData.data[index + 2] / 255.0; // Blue
}
}
return flattenedArray;
}
/** Detect if browser is safari
* @returns {boolean} - True if browser is safari.
*/
function isSafari() {
// https://stackoverflow.com/questions/7944460/detect-safari-browser
let isSafari = /^((?!chrome|android).)*safari/i.test(navigator.userAgent);
return isSafari;
}

View File

@ -0,0 +1,9 @@
#!/usr/bin/env bash

# Opening the index.html file directly in a browser does not work because of
# browser security restrictions.
if ! command -v python3 &>/dev/null; then
echo "python3 could not be found. Running server requires python3."
exit
fi
python3 https_server.py

Binary files not shown: seven sample images added (324 KiB, 105 KiB, 154 KiB, 795 KiB, 4.1 MiB, 6.7 MiB, 187 KiB).

View File

@ -0,0 +1,6 @@
#![cfg_attr(not(test), no_std)]
pub mod model;
pub mod web;
extern crate alloc;

View File

@ -0,0 +1,2 @@
// Generated labels from label.txt
include!(concat!(env!("OUT_DIR"), "/model/label.rs"));

File diff suppressed because it is too large.

View File

@ -0,0 +1,3 @@
pub mod label;
pub mod normalizer;
pub mod squeezenet1;

View File

@ -0,0 +1,38 @@
use burn::tensor::{backend::Backend, Tensor};
// Values are taken from the [ONNX SqueezeNet]
// (https://github.com/onnx/models/tree/main/vision/classification/squeezenet#preprocessing)
const MEAN: [f32; 3] = [0.485, 0.456, 0.406];
const STD: [f32; 3] = [0.229, 0.224, 0.225];
/// Normalizer for the ImageNet dataset.
pub struct Normalizer<B: Backend> {
pub mean: Tensor<B, 4>,
pub std: Tensor<B, 4>,
}
impl<B: Backend> Normalizer<B> {
/// Creates a new normalizer.
pub fn new() -> Self {
let mean = Tensor::from_floats(MEAN).reshape([1, 3, 1, 1]);
let std = Tensor::from_floats(STD).reshape([1, 3, 1, 1]);
Self { mean, std }
}
/// Normalizes the input image according to the ImageNet dataset.
///
/// The input image should be in the range [0, 1].
/// The output is standardized per channel (roughly zero mean and unit variance);
/// with these constants the values fall approximately within [-2.2, 2.7].
///
/// The normalization is done according to the following formula:
/// `input = (input - mean) / std`
pub fn normalize(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
(input - self.mean.clone()) / self.std.clone()
}
}
impl<B: Backend> Default for Normalizer<B> {
fn default() -> Self {
Self::new()
}
}
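// A quick check of the formula: channel values equal to the ImageNet means
// should normalize to (approximately) zero. This is a sketch; it assumes a
// std-enabled backend such as NdArrayBackend is available when tests compile.
#[cfg(test)]
mod tests {
    use super::Normalizer;
    use burn::backend::NdArrayBackend;
    use burn::tensor::Tensor;

    #[test]
    fn channel_means_normalize_to_zero() {
        type B = NdArrayBackend<f32>;
        let normalizer = Normalizer::<B>::new();
        // A 1x3x1x1 "image" holding exactly the per-channel means.
        let input: Tensor<B, 4> = Tensor::from_floats([0.485, 0.456, 0.406]).reshape([1, 3, 1, 1]);
        let output = normalizer.normalize(input);
        for value in output.into_data().value {
            assert!(value.abs() < 1e-5);
        }
    }
}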

View File

@ -0,0 +1,6 @@
// Generated model from squeezenet1.onnx
mod internal_model {
include!(concat!(env!("OUT_DIR"), "/model/squeezenet1.rs"));
}
pub use internal_model::*;

View File

@ -0,0 +1,191 @@
#![allow(clippy::new_without_default)]
use alloc::{
string::{String, ToString},
vec::Vec,
};
use core::convert::Into;
use crate::model::{label::LABELS, normalizer::Normalizer, squeezenet1::Model as SqueezenetModel};
use burn::{
backend::{
wgpu::{compute::init_async, AutoGraphicsApi, WgpuBackend, WgpuDevice},
NdArrayBackend,
},
tensor::{activation::softmax, backend::Backend, Tensor},
};
use burn_candle::CandleBackend;
use serde::Serialize;
use wasm_bindgen::prelude::*;
use wasm_timer::Instant;
#[allow(clippy::large_enum_variant)]
/// The model, loaded for a specific backend
pub enum ModelType {
/// The model loaded on the Candle backend
WithCandleBackend(Model<CandleBackend<f32, i64>>),
/// The model loaded on the NdArray backend
WithNdarrayBackend(Model<NdArrayBackend<f32>>),
/// The model loaded on the Wgpu backend
WithWgpuBackend(Model<WgpuBackend<AutoGraphicsApi, f32, i32>>),
}
/// The image is 224x224 pixels with 3 channels (RGB)
const HEIGHT: usize = 224;
const WIDTH: usize = 224;
const CHANNELS: usize = 3;
/// The image classifier
#[wasm_bindgen]
pub struct ImageClassifier {
model: ModelType,
}
#[wasm_bindgen]
impl ImageClassifier {
/// Constructor called by JavaScript with the `new` keyword.
#[wasm_bindgen(constructor)]
pub fn new() -> Self {
// Initialize the logger so that the logs are printed to the console
wasm_logger::init(wasm_logger::Config::default());
log::info!("Initializing the image classifier");
Self {
model: ModelType::WithNdarrayBackend(Model::new()),
}
}
/// Runs inference on the image
pub async fn inference(&self, input: &[f32]) -> Result<JsValue, JsValue> {
log::info!("Running inference on the image");
let start = Instant::now();
let result = match self.model {
ModelType::WithCandleBackend(ref model) => model.forward(input).await,
ModelType::WithNdarrayBackend(ref model) => model.forward(input).await,
ModelType::WithWgpuBackend(ref model) => model.forward(input).await,
};
let duration = start.elapsed();
log::debug!("Inference is completed in {:?}", duration);
top_5_classes(result)
}
/// Sets the backend to Candle
pub async fn set_backend_candle(&mut self) -> Result<(), JsValue> {
log::info!("Loading the model to the Candle backend");
let start = Instant::now();
self.model = ModelType::WithCandleBackend(Model::new());
let duration = start.elapsed();
log::debug!("Model is loaded to the Candle backend in {:?}", duration);
Ok(())
}
/// Sets the backend to NdArray
pub async fn set_backend_ndarray(&mut self) -> Result<(), JsValue> {
log::info!("Loading the model to the NdArray backend");
let start = Instant::now();
self.model = ModelType::WithNdarrayBackend(Model::new());
let duration = start.elapsed();
log::debug!("Model is loaded to the NdArray backend in {:?}", duration);
Ok(())
}
/// Sets the backend to Wgpu
pub async fn set_backend_wgpu(&mut self) -> Result<(), JsValue> {
log::info!("Loading the model to the Wgpu backend");
let start = Instant::now();
init_async::<AutoGraphicsApi>(&WgpuDevice::default()).await;
self.model = ModelType::WithWgpuBackend(Model::new());
let duration = start.elapsed();
log::debug!("Model is loaded to the Wgpu backend in {:?}", duration);
log::debug!("Warming up the model");
let start = Instant::now();
let _ = self.inference(&[0.0; HEIGHT * WIDTH * CHANNELS]).await;
let duration = start.elapsed();
log::debug!("Warming up is completed in {:?}", duration);
Ok(())
}
}
/// The image classifier model
pub struct Model<B: Backend> {
model: SqueezenetModel<B>,
normalizer: Normalizer<B>,
}
impl<B: Backend> Model<B> {
/// Constructor
pub fn new() -> Self {
Self {
model: SqueezenetModel::from_embedded(),
normalizer: Normalizer::new(),
}
}
/// Normalizes input and runs inference on the image
pub async fn forward(&self, input: &[f32]) -> Vec<f32> {
// Reshape the flat input array into a 4D tensor [batch, channels, height, width]
let input: Tensor<B, 4> = Tensor::from_floats(input).reshape([1, CHANNELS, HEIGHT, WIDTH]);
// Normalize the input with the ImageNet per-channel mean and standard deviation
let input = self.normalizer.normalize(input);
// Run the tensor input through the model
let output = self.model.forward(input);
// Convert the model output into a probability distribution using softmax
let probabilities = softmax(output, 1);
#[cfg(not(target_family = "wasm"))]
let result = probabilities.into_data().convert::<f32>().value;
// On wasm, reading the data back is async; awaiting also forces the result to be computed
#[cfg(target_family = "wasm")]
let result = probabilities.into_data().await.convert::<f32>().value;
result
}
}
#[wasm_bindgen]
#[derive(Serialize)]
pub struct InferenceResult {
index: usize,
probability: f32,
label: String,
}
/// Returns the top 5 classes and converts them into a JsValue
fn top_5_classes(probabilities: Vec<f32>) -> Result<JsValue, JsValue> {
// Convert the probabilities into a vector of (index, probability)
let mut probabilities: Vec<_> = probabilities.iter().enumerate().collect();
// Sort the probabilities in descending order
probabilities.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
// Take the top 5 probabilities
probabilities.truncate(5);
// Convert the probabilities into InferenceResult
let result: Vec<InferenceResult> = probabilities
.into_iter()
.map(|(index, probability)| InferenceResult {
index,
probability: *probability,
label: LABELS[index].to_string(),
})
.collect();
// Convert the InferenceResult into a JsValue
Ok(serde_wasm_bindgen::to_value(&result)?)
}