Create adaptive retrieval demo (#587)

* Create adaptive retrieval demo

* Remove waiting label
This commit is contained in:
Joshua Lochner 2024-02-21 15:16:51 +02:00 committed by GitHub
parent d1eabaeb0b
commit 6d2808b571
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 6245 additions and 0 deletions

View File

@ -0,0 +1,20 @@
module.exports = {
root: true,
env: { browser: true, es2020: true },
extends: [
'eslint:recommended',
'plugin:react/recommended',
'plugin:react/jsx-runtime',
'plugin:react-hooks/recommended',
],
ignorePatterns: ['dist', '.eslintrc.cjs'],
parserOptions: { ecmaVersion: 'latest', sourceType: 'module' },
settings: { react: { version: '18.2' } },
plugins: ['react-refresh'],
rules: {
'react-refresh/only-export-components': [
'warn',
{ allowConstantExport: true },
],
},
}

24
examples/adaptive-retrieval/.gitignore vendored Normal file
View File

@ -0,0 +1,24 @@
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
pnpm-debug.log*
lerna-debug.log*
node_modules
dist
dist-ssr
*.local
# Editor directories and files
.vscode/*
!.vscode/extensions.json
.idea
.DS_Store
*.suo
*.ntvs*
*.njsproj
*.sln
*.sw?

View File

@ -0,0 +1,13 @@
<!doctype html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Adaptive Retrieval w/ Transformers.js</title>
</head>
<body>
<div id="root"></div>
<script type="module" src="/src/main.jsx"></script>
</body>
</html>

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,30 @@
{
"name": "adaptive-retrieval",
"private": true,
"version": "0.0.0",
"type": "module",
"scripts": {
"dev": "vite",
"build": "vite build",
"lint": "eslint . --ext js,jsx --report-unused-disable-directives --max-warnings 0",
"preview": "vite preview"
},
"dependencies": {
"@xenova/transformers": "^2.15.0",
"react": "^18.2.0",
"react-dom": "^18.2.0"
},
"devDependencies": {
"@types/react": "^18.2.43",
"@types/react-dom": "^18.2.17",
"@vitejs/plugin-react": "^4.2.1",
"autoprefixer": "^10.4.16",
"eslint": "^8.55.0",
"eslint-plugin-react": "^7.33.2",
"eslint-plugin-react-hooks": "^4.6.0",
"eslint-plugin-react-refresh": "^0.4.5",
"postcss": "^8.4.33",
"tailwindcss": "^3.4.1",
"vite": "^5.0.12"
}
}

View File

@ -0,0 +1,6 @@
export default {
plugins: {
tailwindcss: {},
autoprefixer: {},
},
}

View File

@ -0,0 +1,5 @@
#root {
height: 100vh;
width: 100vw;
padding: 1rem;
}

View File

@ -0,0 +1,173 @@
import { useState, useRef, useEffect, useCallback } from 'react'
import './App.css'
const PLACEHOLDER_TEXTS = [
"A panda is a large black-and-white bear native to China.",
"The typical life span of a panda is 20 years in the wild.",
"A panda's diet consists almost entirely of bamboo.",
"Ailuropoda melanoleuca is a bear species endemic to China.",
"I love pandas so much!",
"Bamboo is a fast-growing, woody grass.",
"My favorite movie is Kung Fu Panda.",
"I love the color blue.",
"Once upon a time, in a land far, far away...",
"Hello world.",
"This is an example sentence.",
].sort(() => Math.random() - 0.5);
function normalize(embedding) {
const magnitude = Math.sqrt(embedding.reduce((sum, val) => sum + val * val, 0));
return embedding.map((val) => val / magnitude);
}
function dot(a, b) {
return a.reduce((acc, val, i) => acc + val * b[i], 0);
}
function App() {
const [status, setStatus] = useState('idle');
const [source, setSource] = useState('What is a panda?');
const [text, setText] = useState(PLACEHOLDER_TEXTS.join('\n'));
const [dimensions, setDimensions] = useState(768);
const [embeddings, setEmbeddings] = useState([]);
const [results, setResults] = useState([]);
// Create a reference to the worker object.
const worker = useRef(null);
// We use the `useEffect` hook to setup the worker as soon as the `App` component is mounted.
useEffect(() => {
if (!worker.current) {
// Create the worker if it does not yet exist.
worker.current = new Worker(new URL('./worker.js', import.meta.url), {
type: 'module'
});
}
// Create a callback function for messages from the worker thread.
const onMessageReceived = (e) => {
const status = e.data.status;
if (status === 'initiate') {
setStatus('loading');
} else if (status === 'ready') {
setStatus('ready');
} else if (status === 'complete') {
const embeddings = e.data.embeddings;
setDimensions(embeddings[0].length);
setEmbeddings(embeddings);
setStatus('idle');
}
};
// Attach the callback function as an event listener.
worker.current.addEventListener('message', onMessageReceived);
// Define a cleanup function for when the component is unmounted.
return () => worker.current.removeEventListener('message', onMessageReceived);
}, []);
const run = useCallback(() => {
setStatus('processing');
worker.current.postMessage({
source,
text,
});
}, [source, text])
useEffect(() => {
if (embeddings.length === 0) return;
const slicedEmbeddings = embeddings.map(x => normalize(x.slice(0, dimensions)));
const sourceEmbedding = slicedEmbeddings[0];
const sentenceEmbeddings = slicedEmbeddings.slice(1);
// Compute the cosine similarity between the source sentence and the other sentences.
// NOTE: Since vectors are normalized, we use the dot product.
const similarities = sentenceEmbeddings.map((embedding) => dot(sourceEmbedding, embedding));
setResults(text.trim().split('\n').map((sentence, i) => ({
sentence,
similarity: similarities[i]
})).sort((a, b) => b.similarity - a.similarity));
}, [text, embeddings, dimensions])
const busy = status !== 'idle';
return (
<div className='flex flex-col h-full'>
<h1 className='text-2xl md:text-4xl font-bold text-center mb-1'>Adaptive Retrieval w/ Matryoshka Embeddings</h1>
<p className='text-lg md:text-xl font-medium text-center mb-2'>Powered by <a href='https://huggingface.co/nomic-ai/nomic-embed-text-v1.5'>Nomic Embed v1.5</a> and <a href='http://huggingface.co/docs/transformers.js'>🤗 Transformers.js</a></p>
<div className='flex-grow flex flex-wrap p-4'>
<div className='flex flex-col items-center gap-y-1 w-full md:w-1/2'>
<label className='text-lg font-medium'>Query</label>
<textarea
placeholder='Enter source sentence.'
className='border w-full p-1 resize-none overflow-hidden h-10'
value={source}
onChange={e => {
setSource(e.target.value);
setResults([]);
setEmbeddings([]);
}}
></textarea>
<label className='text-lg font-medium mt-1'>Text</label>
<textarea
placeholder='Enter sentences to compare with the source sentence. One sentence per line.'
className='border w-full p-1 h-full resize-none'
value={text}
onChange={e => {
setText(e.target.value);
setResults([]);
setEmbeddings([]);
}}
></textarea>
<button
className='border py-1 px-2 bg-blue-400 rounded text-white text-lg font-medium disabled:opacity-50 disabled:cursor-not-allowed'
disabled={busy}
onClick={run}>{
!busy
? (embeddings.length === 0 ? 'Compute Embeddings' : 'Recompute Embeddings')
: status === 'loading'
? 'Model loading...'
: 'Processing'
}</button>
</div>
<div className='flex flex-col items-center w-full md:w-1/2 gap-y-1'>
{embeddings.length > 0 && (<>
<label className='text-lg font-medium'>Dimensions</label>
<input
type="range"
min="64"
max="768"
step="1"
value={dimensions}
onChange={e => {
setDimensions(e.target.value);
}}
className="w-[98%] h-[10px]"
/>
<p className="font-bold text-sm">{dimensions}</p>
<div className='w-full flex flex-col gap-y-1'>
<label className='text-lg font-medium text-center mt-1'>Results</label>
<div className='flex flex-col gap-y-1'>
{results.map((result, i) => (
<div key={i} className='flex gap-x-2 border mx-2 p-1'>
<span className='font-bold'>{result.similarity.toFixed(3)}</span>
<span>{result.sentence}</span>
</div>
))}
</div>
</div>
</>)
}
</div>
</div>
</div>
)
}
export default App

View File

@ -0,0 +1,3 @@
@tailwind base;
@tailwind components;
@tailwind utilities;

View File

@ -0,0 +1,10 @@
import React from 'react'
import ReactDOM from 'react-dom/client'
import App from './App.jsx'
import './index.css'
ReactDOM.createRoot(document.getElementById('root')).render(
<React.StrictMode>
<App />
</React.StrictMode>,
)

View File

@ -0,0 +1,47 @@
import { env, pipeline } from '@xenova/transformers';
// Skip local model check since we are downloading the model from the Hugging Face Hub.
env.allowLocalModels = false;
class MyFeatureExtractionPipeline {
static task = 'feature-extraction';
static model = 'nomic-ai/nomic-embed-text-v1.5';
static instance = null;
static async getInstance(progress_callback = null) {
if (this.instance === null) {
this.instance = pipeline(this.task, this.model, {
quantized: true,
progress_callback,
});
}
return this.instance;
}
}
// https://huggingface.co/nomic-ai/nomic-embed-text-v1.5#usage
const SEARCH_PREFIX = 'search_query: ';
const DOCUMENT_PREFIX = 'search_document: ';
// Listen for messages from the main thread
self.addEventListener('message', async (event) => {
// Retrieve the pipeline. When called for the first time,
// this will load the pipeline and save it for future use.
const extractor = await MyFeatureExtractionPipeline.getInstance(x => {
// We also add a progress callback to the pipeline so that we can
// track model loading.
self.postMessage(x);
});
const { source, text } = event.data;
const split = [
SEARCH_PREFIX + source,
...text.trim().split('\n').map(x => DOCUMENT_PREFIX + x),
];
const embeddings = await extractor(split, { pooling: 'mean', normalize: true });
// Send the output back to the main thread
self.postMessage({ status: 'complete', embeddings: embeddings.tolist() });
});

View File

@ -0,0 +1,12 @@
/** @type {import('tailwindcss').Config} */
export default {
content: [
"./index.html",
"./src/**/*.{js,ts,jsx,tsx}",
],
theme: {
extend: {},
},
plugins: [],
}

View File

@ -0,0 +1,7 @@
import { defineConfig } from 'vite'
import react from '@vitejs/plugin-react'
// https://vitejs.dev/config/
export default defineConfig({
plugins: [react()],
})