diff --git a/examples/segment-anything-client/index.css b/examples/segment-anything-client/index.css
new file mode 100644
index 0000000..a896b88
--- /dev/null
+++ b/examples/segment-anything-client/index.css
@@ -0,0 +1,116 @@
+* {
+    box-sizing: border-box;
+    padding: 0;
+    margin: 0;
+    font-family: sans-serif;
+}
+
+html,
+body {
+    height: 100%;
+}
+
+body {
+    padding: 16px 32px;
+}
+
+body,
+#container,
+#upload-button {
+    display: flex;
+    flex-direction: column;
+    justify-content: center;
+    align-items: center;
+}
+
+h1 {
+    text-align: center;
+}
+
+#container {
+    position: relative;
+    width: 640px;
+    height: 420px;
+    max-width: 100%;
+    max-height: 100%;
+    border: 2px dashed #D1D5DB;
+    border-radius: 0.75rem;
+    overflow: hidden;
+    cursor: pointer;
+    margin-top: 1rem;
+    background-size: 100% 100%;
+    background-position: center;
+    background-repeat: no-repeat;
+}
+
+#mask-output {
+    position: absolute;
+    width: 100%;
+    height: 100%;
+    pointer-events: none;
+}
+
+#upload-button {
+    gap: 0.4rem;
+    font-size: 18px;
+    cursor: pointer;
+}
+
+#upload {
+    display: none;
+}
+
+svg {
+    pointer-events: none;
+}
+
+#example {
+    font-size: 14px;
+    text-decoration: underline;
+    cursor: pointer;
+}
+
+#example:hover {
+    color: #2563EB;
+}
+
+canvas {
+    position: absolute;
+    width: 100%;
+    height: 100%;
+    opacity: 0.6;
+}
+
+#status {
+    min-height: 16px;
+    margin: 8px 0;
+}
+
+.icon {
+    height: 16px;
+    width: 16px;
+    position: absolute;
+    transform: translate(-50%, -50%);
+}
+
+#controls>button {
+    padding: 6px 12px;
+    background-color: #3498db;
+    color: white;
+    border: 1px solid #2980b9;
+    border-radius: 5px;
+    cursor: pointer;
+    font-size: 16px;
+}
+
+#controls>button:disabled {
+    background-color: #d1d5db;
+    color: #6b7280;
+    border: 1px solid #9ca3af;
+    cursor: not-allowed;
+}
+
+#information {
+    margin-top: 0.25rem;
+    font-size: 15px;
+}
\ No newline at end of file
diff --git a/examples/segment-anything-client/index.html b/examples/segment-anything-client/index.html
new file mode 100644
index 0000000..5e8a2e9
--- /dev/null
+++ b/examples/segment-anything-client/index.html
@@ -0,0 +1,40 @@
+<!DOCTYPE html>
+<html lang="en">
+
+<head>
+    <meta charset="UTF-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+    <link rel="stylesheet" href="index.css" />
+    <title>Transformers.js - Segment Anything</title>
+</head>
+
+<body>
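+    <!-- These ids are the contract with index.js and index.css: #container hosts the
+         chosen image plus the #mask-output canvas overlay, #upload is a hidden file
+         input opened by clicking the #upload-button label, and #status reports
+         model/embedding progress. -->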
+    <h1>Segment Anything w/ 🤗 Transformers.js</h1>
+
+    <div id="container">
+        <label id="upload-button" for="upload">
+            <svg width="25" height="25" viewBox="0 0 25 25" fill="none" xmlns="http://www.w3.org/2000/svg">
+                <!-- upload icon path -->
+            </svg>
+            Click to upload image
+            <label id="example">(or try example)</label>
+        </label>
+        <canvas id="mask-output"></canvas>
+    </div>
+
+    <label id="status"></label>
+
+    <div id="controls">
+        <button id="reset-image">Reset image</button>
+        <button id="clear-points">Clear points</button>
+        <button id="cut-mask" disabled>Cut mask</button>
+    </div>
+
+    <p id="information">
+        Left click = positive points, right click = negative points.
+    </p>
+
+    <input id="upload" type="file" accept="image/*" />
+    <script src="index.js" type="module"></script>
+</body>
+
+</html>
\ No newline at end of file
diff --git a/examples/segment-anything-client/index.js b/examples/segment-anything-client/index.js
new file mode 100644
index 0000000..e01b59c
--- /dev/null
+++ b/examples/segment-anything-client/index.js
@@ -0,0 +1,295 @@
+
+// Reference the elements we will use
+const statusLabel = document.getElementById('status');
+const fileUpload = document.getElementById('upload');
+const imageContainer = document.getElementById('container');
+const example = document.getElementById('example');
+const maskCanvas = document.getElementById('mask-output');
+const uploadButton = document.getElementById('upload-button');
+const resetButton = document.getElementById('reset-image');
+const clearButton = document.getElementById('clear-points');
+const cutButton = document.getElementById('cut-mask');
+
+// State variables
+let lastPoints = null;
+let isEncoded = false;
+let isDecoding = false;
+let isMultiMaskMode = false;
+let modelReady = false;
+let imageDataURI = null;
+
+// Constants
+const BASE_URL = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/';
+const EXAMPLE_URL = BASE_URL + 'corgi.jpg';
+
+// Create a web worker so that the main (UI) thread is not blocked during inference.
+const worker = new Worker('worker.js', {
+    type: 'module',
+});
+
+// Preload star and cross images to avoid lag on first click
+const star = new Image();
+star.src = BASE_URL + 'star-icon.png';
+star.className = 'icon';
+
+const cross = new Image();
+cross.src = BASE_URL + 'cross-icon.png';
+cross.className = 'icon';
+
+// Set up message handler
+worker.addEventListener('message', (e) => {
+    const { type, data } = e.data;
+    if (type === 'ready') {
+        modelReady = true;
+        statusLabel.textContent = 'Ready';
+
+    } else if (type === 'decode_result') {
+        isDecoding = false;
+
+        if (!isEncoded) {
+            return; // We are not ready to decode yet
+        }
+
+        if (!isMultiMaskMode && lastPoints) {
+            // Perform decoding with the last point
+            decode();
+            lastPoints = null;
+        }
+
+        const { mask, scores } = data;
+
+        // Update canvas dimensions (if different)
+        if (maskCanvas.width !== mask.width || maskCanvas.height !== mask.height) {
+            maskCanvas.width = mask.width;
+            maskCanvas.height = mask.height;
+        }
+
+        // Create context and allocate buffer for pixel data
+        const context = maskCanvas.getContext('2d');
+        const imageData = context.createImageData(maskCanvas.width, maskCanvas.height);
+
+        // Select best mask
+        const numMasks = scores.length; // 3
+        let bestIndex = 0;
+        for (let i = 1; i < numMasks; ++i) {
+            if (scores[i] > scores[bestIndex]) {
+                bestIndex = i;
+            }
+        }
+        statusLabel.textContent = `Segment score: ${scores[bestIndex].toFixed(2)}`;
+
+        // Fill mask with colour (mask.data stores, for each pixel, one value per candidate mask)
+        const numPixels = maskCanvas.width * maskCanvas.height;
+        const pixelData = imageData.data;
+        for (let i = 0; i < numPixels; ++i) {
+            if (mask.data[numMasks * i + bestIndex] === 1) {
+                const offset = 4 * i;
+                pixelData[offset] = 0; // red
+                pixelData[offset + 1] = 114; // green
+                pixelData[offset + 2] = 189; // blue
+                pixelData[offset + 3] = 255; // alpha
+            }
+        }
+
+        // Draw image data to context
+        context.putImageData(imageData, 0, 0);
+
+    } else if (type === 'segment_result') {
+        if (data === 'start') {
+            statusLabel.textContent = 'Extracting image embedding...';
+        } else {
+            statusLabel.textContent = 'Embedding extracted!';
+            isEncoded = true;
+        }
+    }
+});
+
+function decode() {
+    isDecoding = true;
+    worker.postMessage({ type: 'decode', data: lastPoints });
+}
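+
+// Note on flow control: `decode()` marks `isDecoding`, and the 'decode_result'
+// handler above clears it and, in single-mask (hover) mode, immediately issues a
+// follow-up decode with whatever `lastPoints` accumulated while the worker was
+// busy. At most one decode request is therefore in flight at a time, throttling
+// hover segmentation to the worker's speed instead of posting a message for
+// every mousemove event.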
+
+function clearPointsAndMask() {
+    // Reset state
+    isMultiMaskMode = false;
+    lastPoints = null;
+
+    // Remove points from previous mask (if any)
+    document.querySelectorAll('.icon').forEach(e => e.remove());
+
+    // Disable cut button
+    cutButton.disabled = true;
+
+    // Reset mask canvas
+    maskCanvas.getContext('2d').clearRect(0, 0, maskCanvas.width, maskCanvas.height);
+}
+clearButton.addEventListener('click', clearPointsAndMask);
+
+resetButton.addEventListener('click', () => {
+    // Update state
+    isEncoded = false;
+    imageDataURI = null;
+
+    // Indicate to worker that we have reset the state
+    worker.postMessage({ type: 'reset' });
+
+    // Clear points and mask (if present)
+    clearPointsAndMask();
+
+    // Update UI
+    cutButton.disabled = true;
+    imageContainer.style.backgroundImage = 'none';
+    uploadButton.style.display = 'flex';
+    statusLabel.textContent = 'Ready';
+});
+
+function segment(data) {
+    // Update state
+    isEncoded = false;
+    if (!modelReady) {
+        statusLabel.textContent = 'Loading model...';
+    }
+    imageDataURI = data;
+
+    // Update UI
+    imageContainer.style.backgroundImage = `url(${data})`;
+    uploadButton.style.display = 'none';
+    cutButton.disabled = true;
+
+    // Instruct worker to segment the image
+    worker.postMessage({ type: 'segment', data });
+}
+
+// Handle file selection
+fileUpload.addEventListener('change', function (e) {
+    const file = e.target.files[0];
+    if (!file) {
+        return;
+    }
+
+    const reader = new FileReader();
+
+    // Set up a callback when the file is loaded
+    reader.onload = e2 => segment(e2.target.result);
+
+    reader.readAsDataURL(file);
+});
+
+example.addEventListener('click', (e) => {
+    e.preventDefault();
+    segment(EXAMPLE_URL);
+});
+
+function addIcon({ point, label }) {
+    const icon = (label === 1 ? star : cross).cloneNode();
+    icon.style.left = `${point[0] * 100}%`;
+    icon.style.top = `${point[1] * 100}%`;
+    imageContainer.appendChild(icon);
+}
+
+// Handle mousedown events on the image container
+imageContainer.addEventListener('mousedown', e => {
+    if (e.button !== 0 && e.button !== 2) {
+        return; // Ignore other buttons
+    }
+    if (!isEncoded) {
+        return; // Ignore if not encoded yet
+    }
+    if (!isMultiMaskMode) {
+        // The first click enters multi-mask mode: points now accumulate until cleared
+        lastPoints = [];
+        isMultiMaskMode = true;
+        cutButton.disabled = false;
+    }
+
+    const point = getPoint(e);
+    lastPoints.push(point);
+
+    // Add icon
+    addIcon(point);
+
+    decode();
+});
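+
+// Points are stored in coordinates normalized to [0, 1] relative to the container,
+// so they stay valid however the image is scaled. For example, a click at
+// (320, 210) inside the default 640x420 container becomes [0.5, 0.5]; the worker
+// rescales these to the model's reshaped input size before decoding.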
+
+// Clamp a value inside a range [min, max]
+function clamp(x, min = 0, max = 1) {
+    return Math.max(Math.min(x, max), min);
+}
+
+function getPoint(e) {
+    // Get bounding box
+    const bb = imageContainer.getBoundingClientRect();
+
+    // Get the mouse coordinates relative to the container
+    const mouseX = clamp((e.clientX - bb.left) / bb.width);
+    const mouseY = clamp((e.clientY - bb.top) / bb.height);
+
+    return {
+        point: [mouseX, mouseY],
+        label: e.button === 2 // right click
+            ? 0 // negative prompt
+            : 1, // positive prompt
+    };
+}
+
+// Do not show context menu on right click
+imageContainer.addEventListener('contextmenu', e => {
+    e.preventDefault();
+});
+
+// Attach hover event to image container
+imageContainer.addEventListener('mousemove', e => {
+    if (!isEncoded || isMultiMaskMode) {
+        // Ignore mousemove events if the image is not encoded yet,
+        // or we are in multi-mask mode
+        return;
+    }
+    lastPoints = [getPoint(e)];
+
+    if (!isDecoding) {
+        decode(); // Only decode if we are not already decoding
+    }
+});
+
+// Handle cut button click
+cutButton.addEventListener('click', () => {
+    const [w, h] = [maskCanvas.width, maskCanvas.height];
+
+    // Get the mask pixel data
+    const maskContext = maskCanvas.getContext('2d');
+    const maskPixelData = maskContext.getImageData(0, 0, w, h);
+
+    // Load the image
+    const image = new Image();
+    image.crossOrigin = 'anonymous';
+    image.onload = async () => {
+        // Create a new canvas to hold the image
+        const imageCanvas = new OffscreenCanvas(w, h);
+        const imageContext = imageCanvas.getContext('2d');
+        imageContext.drawImage(image, 0, 0, w, h);
+        const imagePixelData = imageContext.getImageData(0, 0, w, h);
+
+        // Create a new canvas to hold the cut-out
+        const cutCanvas = new OffscreenCanvas(w, h);
+        const cutContext = cutCanvas.getContext('2d');
+        const cutPixelData = cutContext.getImageData(0, 0, w, h);
+
+        // Copy the image pixel data to the cut canvas wherever the mask is opaque
+        // (i is the alpha channel of pixel (i - 3) / 4, so i - j walks back over A, B, G, R)
+        for (let i = 3; i < maskPixelData.data.length; i += 4) {
+            if (maskPixelData.data[i] > 0) {
+                for (let j = 0; j < 4; ++j) {
+                    const offset = i - j;
+                    cutPixelData.data[offset] = imagePixelData.data[offset];
+                }
+            }
+        }
+        cutContext.putImageData(cutPixelData, 0, 0);
+
+        // Download image
+        const link = document.createElement('a');
+        link.download = 'image.png';
+        link.href = URL.createObjectURL(await cutCanvas.convertToBlob());
+        link.click();
+        link.remove();
+    };
+    image.src = imageDataURI;
+});
diff --git a/examples/segment-anything-client/worker.js b/examples/segment-anything-client/worker.js
new file mode 100644
index 0000000..5dd6369
--- /dev/null
+++ b/examples/segment-anything-client/worker.js
@@ -0,0 +1,109 @@
+import { env, SamModel, AutoProcessor, RawImage, Tensor } from 'https://cdn.jsdelivr.net/npm/@xenova/transformers@2.14.0';
+
+// Since we will download the model from the Hugging Face Hub, we can skip the local model check
+env.allowLocalModels = false;
+
+// We adopt the singleton pattern to enable lazy-loading of the model and processor.
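+// (It caches the *promises* returned by `from_pretrained`, so messages that arrive
+// while the model is still downloading await the same in-flight load rather than
+// triggering a second download.)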
+export class SegmentAnythingSingleton {
+    static model_id = 'Xenova/slimsam-77-uniform';
+    static model;
+    static processor;
+    static quantized = true;
+
+    static getInstance() {
+        if (!this.model) {
+            this.model = SamModel.from_pretrained(this.model_id, {
+                quantized: this.quantized,
+            });
+        }
+        if (!this.processor) {
+            this.processor = AutoProcessor.from_pretrained(this.model_id);
+        }
+
+        return Promise.all([this.model, this.processor]);
+    }
+}
+
+
+// State variables
+let image_embeddings = null;
+let image_inputs = null;
+let ready = false;
+
+self.onmessage = async (e) => {
+    const [model, processor] = await SegmentAnythingSingleton.getInstance();
+    if (!ready) {
+        // Indicate that we are ready to accept requests
+        ready = true;
+        self.postMessage({
+            type: 'ready',
+        });
+    }
+
+    const { type, data } = e.data;
+    if (type === 'reset') {
+        image_inputs = null;
+        image_embeddings = null;
+
+    } else if (type === 'segment') {
+        // Indicate that we are starting to segment the image
+        self.postMessage({
+            type: 'segment_result',
+            data: 'start',
+        });
+
+        // Read the image and recompute image embeddings
+        const image = await RawImage.read(e.data.data);
+        image_inputs = await processor(image);
+        image_embeddings = await model.get_image_embeddings(image_inputs);
+
+        // Indicate that we have computed the image embeddings, and we are ready to accept decoding requests
+        self.postMessage({
+            type: 'segment_result',
+            data: 'done',
+        });
+
+    } else if (type === 'decode') {
+        // Prepare inputs for decoding: scale the normalized points up to the reshaped
+        // input size (reshaped is [height, width]), and convert labels to BigInt for int64
+        const reshaped = image_inputs.reshaped_input_sizes[0];
+        const points = data.map(x => [x.point[0] * reshaped[1], x.point[1] * reshaped[0]]);
+        const labels = data.map(x => BigInt(x.label));
+
+        const input_points = new Tensor(
+            'float32',
+            points.flat(Infinity),
+            [1, 1, points.length, 2],
+        );
+        const input_labels = new Tensor(
+            'int64',
+            labels.flat(Infinity),
+            [1, 1, labels.length],
+        );
+
+        // Generate the mask
+        const outputs = await model({
+            ...image_embeddings,
+            input_points,
+            input_labels,
+        });
+
+        // Post-process the mask
+        const masks = await processor.post_process_masks(
+            outputs.pred_masks,
+            image_inputs.original_sizes,
+            image_inputs.reshaped_input_sizes,
+        );
+
+        // Send the result back to the main thread
+        self.postMessage({
+            type: 'decode_result',
+            data: {
+                mask: RawImage.fromTensor(masks[0][0]),
+                scores: outputs.iou_scores.data,
+            },
+        });
+
+    } else {
+        throw new Error(`Unknown message type: ${type}`);
+    }
+};
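+
+// Shape summary for the 'decode' branch (single-image batch):
+//   input_points: float32 [1, 1, num_points, 2] -- (x, y) in reshaped-image pixels
+//   input_labels: int64   [1, 1, num_points]    -- 1n = positive point, 0n = negative point
+//   masks[0][0]: the candidate masks for the image, ranked on the main thread by iou_scores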