Create adaptive retrieval demo (#587)
* Create adaptive retrieval demo * Remove waiting label
This commit is contained in:
parent
d1eabaeb0b
commit
6d2808b571
|
@ -0,0 +1,20 @@
|
|||
module.exports = {
|
||||
root: true,
|
||||
env: { browser: true, es2020: true },
|
||||
extends: [
|
||||
'eslint:recommended',
|
||||
'plugin:react/recommended',
|
||||
'plugin:react/jsx-runtime',
|
||||
'plugin:react-hooks/recommended',
|
||||
],
|
||||
ignorePatterns: ['dist', '.eslintrc.cjs'],
|
||||
parserOptions: { ecmaVersion: 'latest', sourceType: 'module' },
|
||||
settings: { react: { version: '18.2' } },
|
||||
plugins: ['react-refresh'],
|
||||
rules: {
|
||||
'react-refresh/only-export-components': [
|
||||
'warn',
|
||||
{ allowConstantExport: true },
|
||||
],
|
||||
},
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
# Logs
|
||||
logs
|
||||
*.log
|
||||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
pnpm-debug.log*
|
||||
lerna-debug.log*
|
||||
|
||||
node_modules
|
||||
dist
|
||||
dist-ssr
|
||||
*.local
|
||||
|
||||
# Editor directories and files
|
||||
.vscode/*
|
||||
!.vscode/extensions.json
|
||||
.idea
|
||||
.DS_Store
|
||||
*.suo
|
||||
*.ntvs*
|
||||
*.njsproj
|
||||
*.sln
|
||||
*.sw?
|
|
@ -0,0 +1,13 @@
|
|||
<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>Adaptive Retrieval w/ Transformers.js</title>
|
||||
</head>
|
||||
<body>
|
||||
<div id="root"></div>
|
||||
<script type="module" src="/src/main.jsx"></script>
|
||||
</body>
|
||||
</html>
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,30 @@
|
|||
{
|
||||
"name": "adaptive-retrieval",
|
||||
"private": true,
|
||||
"version": "0.0.0",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"dev": "vite",
|
||||
"build": "vite build",
|
||||
"lint": "eslint . --ext js,jsx --report-unused-disable-directives --max-warnings 0",
|
||||
"preview": "vite preview"
|
||||
},
|
||||
"dependencies": {
|
||||
"@xenova/transformers": "^2.15.0",
|
||||
"react": "^18.2.0",
|
||||
"react-dom": "^18.2.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/react": "^18.2.43",
|
||||
"@types/react-dom": "^18.2.17",
|
||||
"@vitejs/plugin-react": "^4.2.1",
|
||||
"autoprefixer": "^10.4.16",
|
||||
"eslint": "^8.55.0",
|
||||
"eslint-plugin-react": "^7.33.2",
|
||||
"eslint-plugin-react-hooks": "^4.6.0",
|
||||
"eslint-plugin-react-refresh": "^0.4.5",
|
||||
"postcss": "^8.4.33",
|
||||
"tailwindcss": "^3.4.1",
|
||||
"vite": "^5.0.12"
|
||||
}
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
export default {
|
||||
plugins: {
|
||||
tailwindcss: {},
|
||||
autoprefixer: {},
|
||||
},
|
||||
}
|
|
@ -0,0 +1,5 @@
|
|||
#root {
|
||||
height: 100vh;
|
||||
width: 100vw;
|
||||
padding: 1rem;
|
||||
}
|
|
@ -0,0 +1,173 @@
|
|||
import { useState, useRef, useEffect, useCallback } from 'react'
|
||||
import './App.css'
|
||||
|
||||
const PLACEHOLDER_TEXTS = [
|
||||
"A panda is a large black-and-white bear native to China.",
|
||||
"The typical life span of a panda is 20 years in the wild.",
|
||||
"A panda's diet consists almost entirely of bamboo.",
|
||||
"Ailuropoda melanoleuca is a bear species endemic to China.",
|
||||
"I love pandas so much!",
|
||||
"Bamboo is a fast-growing, woody grass.",
|
||||
"My favorite movie is Kung Fu Panda.",
|
||||
"I love the color blue.",
|
||||
"Once upon a time, in a land far, far away...",
|
||||
"Hello world.",
|
||||
"This is an example sentence.",
|
||||
].sort(() => Math.random() - 0.5);
|
||||
|
||||
function normalize(embedding) {
|
||||
const magnitude = Math.sqrt(embedding.reduce((sum, val) => sum + val * val, 0));
|
||||
return embedding.map((val) => val / magnitude);
|
||||
}
|
||||
|
||||
function dot(a, b) {
|
||||
return a.reduce((acc, val, i) => acc + val * b[i], 0);
|
||||
}
|
||||
|
||||
function App() {
|
||||
const [status, setStatus] = useState('idle');
|
||||
|
||||
const [source, setSource] = useState('What is a panda?');
|
||||
const [text, setText] = useState(PLACEHOLDER_TEXTS.join('\n'));
|
||||
|
||||
const [dimensions, setDimensions] = useState(768);
|
||||
|
||||
const [embeddings, setEmbeddings] = useState([]);
|
||||
const [results, setResults] = useState([]);
|
||||
|
||||
// Create a reference to the worker object.
|
||||
const worker = useRef(null);
|
||||
|
||||
// We use the `useEffect` hook to setup the worker as soon as the `App` component is mounted.
|
||||
useEffect(() => {
|
||||
if (!worker.current) {
|
||||
// Create the worker if it does not yet exist.
|
||||
worker.current = new Worker(new URL('./worker.js', import.meta.url), {
|
||||
type: 'module'
|
||||
});
|
||||
}
|
||||
|
||||
// Create a callback function for messages from the worker thread.
|
||||
const onMessageReceived = (e) => {
|
||||
const status = e.data.status;
|
||||
if (status === 'initiate') {
|
||||
setStatus('loading');
|
||||
} else if (status === 'ready') {
|
||||
setStatus('ready');
|
||||
} else if (status === 'complete') {
|
||||
const embeddings = e.data.embeddings;
|
||||
setDimensions(embeddings[0].length);
|
||||
setEmbeddings(embeddings);
|
||||
setStatus('idle');
|
||||
}
|
||||
};
|
||||
|
||||
// Attach the callback function as an event listener.
|
||||
worker.current.addEventListener('message', onMessageReceived);
|
||||
|
||||
// Define a cleanup function for when the component is unmounted.
|
||||
return () => worker.current.removeEventListener('message', onMessageReceived);
|
||||
}, []);
|
||||
|
||||
const run = useCallback(() => {
|
||||
setStatus('processing');
|
||||
worker.current.postMessage({
|
||||
source,
|
||||
text,
|
||||
});
|
||||
}, [source, text])
|
||||
|
||||
useEffect(() => {
|
||||
if (embeddings.length === 0) return;
|
||||
const slicedEmbeddings = embeddings.map(x => normalize(x.slice(0, dimensions)));
|
||||
|
||||
const sourceEmbedding = slicedEmbeddings[0];
|
||||
const sentenceEmbeddings = slicedEmbeddings.slice(1);
|
||||
|
||||
// Compute the cosine similarity between the source sentence and the other sentences.
|
||||
// NOTE: Since vectors are normalized, we use the dot product.
|
||||
const similarities = sentenceEmbeddings.map((embedding) => dot(sourceEmbedding, embedding));
|
||||
|
||||
setResults(text.trim().split('\n').map((sentence, i) => ({
|
||||
sentence,
|
||||
similarity: similarities[i]
|
||||
})).sort((a, b) => b.similarity - a.similarity));
|
||||
}, [text, embeddings, dimensions])
|
||||
|
||||
const busy = status !== 'idle';
|
||||
|
||||
return (
|
||||
<div className='flex flex-col h-full'>
|
||||
<h1 className='text-2xl md:text-4xl font-bold text-center mb-1'>Adaptive Retrieval w/ Matryoshka Embeddings</h1>
|
||||
<p className='text-lg md:text-xl font-medium text-center mb-2'>Powered by <a href='https://huggingface.co/nomic-ai/nomic-embed-text-v1.5'>Nomic Embed v1.5</a> and <a href='http://huggingface.co/docs/transformers.js'>🤗 Transformers.js</a></p>
|
||||
<div className='flex-grow flex flex-wrap p-4'>
|
||||
<div className='flex flex-col items-center gap-y-1 w-full md:w-1/2'>
|
||||
<label className='text-lg font-medium'>Query</label>
|
||||
<textarea
|
||||
placeholder='Enter source sentence.'
|
||||
className='border w-full p-1 resize-none overflow-hidden h-10'
|
||||
value={source}
|
||||
onChange={e => {
|
||||
setSource(e.target.value);
|
||||
setResults([]);
|
||||
setEmbeddings([]);
|
||||
}}
|
||||
></textarea>
|
||||
<label className='text-lg font-medium mt-1'>Text</label>
|
||||
<textarea
|
||||
placeholder='Enter sentences to compare with the source sentence. One sentence per line.'
|
||||
className='border w-full p-1 h-full resize-none'
|
||||
value={text}
|
||||
onChange={e => {
|
||||
setText(e.target.value);
|
||||
setResults([]);
|
||||
setEmbeddings([]);
|
||||
}}
|
||||
></textarea>
|
||||
|
||||
<button
|
||||
className='border py-1 px-2 bg-blue-400 rounded text-white text-lg font-medium disabled:opacity-50 disabled:cursor-not-allowed'
|
||||
disabled={busy}
|
||||
onClick={run}>{
|
||||
!busy
|
||||
? (embeddings.length === 0 ? 'Compute Embeddings' : 'Recompute Embeddings')
|
||||
: status === 'loading'
|
||||
? 'Model loading...'
|
||||
: 'Processing'
|
||||
}</button>
|
||||
</div>
|
||||
<div className='flex flex-col items-center w-full md:w-1/2 gap-y-1'>
|
||||
{embeddings.length > 0 && (<>
|
||||
<label className='text-lg font-medium'>Dimensions</label>
|
||||
<input
|
||||
type="range"
|
||||
min="64"
|
||||
max="768"
|
||||
step="1"
|
||||
value={dimensions}
|
||||
onChange={e => {
|
||||
setDimensions(e.target.value);
|
||||
}}
|
||||
className="w-[98%] h-[10px]"
|
||||
/>
|
||||
<p className="font-bold text-sm">{dimensions}</p>
|
||||
<div className='w-full flex flex-col gap-y-1'>
|
||||
<label className='text-lg font-medium text-center mt-1'>Results</label>
|
||||
<div className='flex flex-col gap-y-1'>
|
||||
{results.map((result, i) => (
|
||||
<div key={i} className='flex gap-x-2 border mx-2 p-1'>
|
||||
<span className='font-bold'>{result.similarity.toFixed(3)}</span>
|
||||
<span>{result.sentence}</span>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
</>)
|
||||
}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default App
|
|
@ -0,0 +1,3 @@
|
|||
@tailwind base;
|
||||
@tailwind components;
|
||||
@tailwind utilities;
|
|
@ -0,0 +1,10 @@
|
|||
import React from 'react'
|
||||
import ReactDOM from 'react-dom/client'
|
||||
import App from './App.jsx'
|
||||
import './index.css'
|
||||
|
||||
ReactDOM.createRoot(document.getElementById('root')).render(
|
||||
<React.StrictMode>
|
||||
<App />
|
||||
</React.StrictMode>,
|
||||
)
|
|
@ -0,0 +1,47 @@
|
|||
import { env, pipeline } from '@xenova/transformers';
|
||||
|
||||
// Skip local model check since we are downloading the model from the Hugging Face Hub.
|
||||
env.allowLocalModels = false;
|
||||
|
||||
class MyFeatureExtractionPipeline {
|
||||
static task = 'feature-extraction';
|
||||
static model = 'nomic-ai/nomic-embed-text-v1.5';
|
||||
static instance = null;
|
||||
|
||||
static async getInstance(progress_callback = null) {
|
||||
if (this.instance === null) {
|
||||
this.instance = pipeline(this.task, this.model, {
|
||||
quantized: true,
|
||||
progress_callback,
|
||||
});
|
||||
}
|
||||
|
||||
return this.instance;
|
||||
}
|
||||
}
|
||||
|
||||
// https://huggingface.co/nomic-ai/nomic-embed-text-v1.5#usage
|
||||
const SEARCH_PREFIX = 'search_query: ';
|
||||
const DOCUMENT_PREFIX = 'search_document: ';
|
||||
|
||||
// Listen for messages from the main thread
|
||||
self.addEventListener('message', async (event) => {
|
||||
// Retrieve the pipeline. When called for the first time,
|
||||
// this will load the pipeline and save it for future use.
|
||||
const extractor = await MyFeatureExtractionPipeline.getInstance(x => {
|
||||
// We also add a progress callback to the pipeline so that we can
|
||||
// track model loading.
|
||||
self.postMessage(x);
|
||||
});
|
||||
|
||||
const { source, text } = event.data;
|
||||
|
||||
const split = [
|
||||
SEARCH_PREFIX + source,
|
||||
...text.trim().split('\n').map(x => DOCUMENT_PREFIX + x),
|
||||
];
|
||||
const embeddings = await extractor(split, { pooling: 'mean', normalize: true });
|
||||
|
||||
// Send the output back to the main thread
|
||||
self.postMessage({ status: 'complete', embeddings: embeddings.tolist() });
|
||||
});
|
|
@ -0,0 +1,12 @@
|
|||
/** @type {import('tailwindcss').Config} */
|
||||
export default {
|
||||
content: [
|
||||
"./index.html",
|
||||
"./src/**/*.{js,ts,jsx,tsx}",
|
||||
],
|
||||
theme: {
|
||||
extend: {},
|
||||
},
|
||||
plugins: [],
|
||||
}
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
import { defineConfig } from 'vite'
|
||||
import react from '@vitejs/plugin-react'
|
||||
|
||||
// https://vitejs.dev/config/
|
||||
export default defineConfig({
|
||||
plugins: [react()],
|
||||
})
|
Loading…
Reference in New Issue