Create cross-encoder demo (#617)
This commit is contained in:
parent
e072ee6b70
commit
382ff3af35
|
@ -0,0 +1,20 @@
|
|||
module.exports = {
|
||||
root: true,
|
||||
env: { browser: true, es2020: true },
|
||||
extends: [
|
||||
'eslint:recommended',
|
||||
'plugin:react/recommended',
|
||||
'plugin:react/jsx-runtime',
|
||||
'plugin:react-hooks/recommended',
|
||||
],
|
||||
ignorePatterns: ['dist', '.eslintrc.cjs'],
|
||||
parserOptions: { ecmaVersion: 'latest', sourceType: 'module' },
|
||||
settings: { react: { version: '18.2' } },
|
||||
plugins: ['react-refresh'],
|
||||
rules: {
|
||||
'react-refresh/only-export-components': [
|
||||
'warn',
|
||||
{ allowConstantExport: true },
|
||||
],
|
||||
},
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
# Logs
|
||||
logs
|
||||
*.log
|
||||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
pnpm-debug.log*
|
||||
lerna-debug.log*
|
||||
|
||||
node_modules
|
||||
dist
|
||||
dist-ssr
|
||||
*.local
|
||||
|
||||
# Editor directories and files
|
||||
.vscode/*
|
||||
!.vscode/extensions.json
|
||||
.idea
|
||||
.DS_Store
|
||||
*.suo
|
||||
*.ntvs*
|
||||
*.njsproj
|
||||
*.sln
|
||||
*.sw?
|
|
@ -0,0 +1,13 @@
|
|||
<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>Reranking w/ Transformers.js</title>
|
||||
</head>
|
||||
<body>
|
||||
<div id="root"></div>
|
||||
<script type="module" src="/src/main.jsx"></script>
|
||||
</body>
|
||||
</html>
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,30 @@
|
|||
{
|
||||
"name": "cross-encoder",
|
||||
"private": true,
|
||||
"version": "0.0.0",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"dev": "vite",
|
||||
"build": "vite build",
|
||||
"lint": "eslint . --ext js,jsx --report-unused-disable-directives --max-warnings 0",
|
||||
"preview": "vite preview"
|
||||
},
|
||||
"dependencies": {
|
||||
"@xenova/transformers": "^2.15.0",
|
||||
"react": "^18.2.0",
|
||||
"react-dom": "^18.2.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/react": "^18.2.43",
|
||||
"@types/react-dom": "^18.2.17",
|
||||
"@vitejs/plugin-react": "^4.2.1",
|
||||
"autoprefixer": "^10.4.16",
|
||||
"eslint": "^8.55.0",
|
||||
"eslint-plugin-react": "^7.33.2",
|
||||
"eslint-plugin-react-hooks": "^4.6.0",
|
||||
"eslint-plugin-react-refresh": "^0.4.5",
|
||||
"postcss": "^8.4.33",
|
||||
"tailwindcss": "^3.4.1",
|
||||
"vite": "^5.0.12"
|
||||
}
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
// PostCSS plugins, listed in the order they are declared: Tailwind first, then Autoprefixer.
// Both are used with their default options.
const plugins = {
    tailwindcss: {},
    autoprefixer: {},
};

export default { plugins };
|
|
@ -0,0 +1,5 @@
|
|||
#root {
|
||||
height: 100vh;
|
||||
width: 100vw;
|
||||
padding: 1rem;
|
||||
}
|
|
@ -0,0 +1,124 @@
|
|||
import { useState, useRef, useEffect, useCallback } from 'react'
|
||||
import './App.css'
|
||||
|
||||
/**
 * Return a new array containing the elements of `items` in uniformly random order
 * (Fisher–Yates shuffle). The input array is not modified.
 *
 * A random comparator passed to `Array.prototype.sort` is not a valid replacement:
 * it is inconsistent between calls, so the resulting order is implementation-defined
 * and measurably biased.
 *
 * @param {Array} items - Values to shuffle.
 * @returns {Array} A shuffled shallow copy of `items`.
 */
function shuffle(items) {
    const shuffled = [...items];
    for (let i = shuffled.length - 1; i > 0; i--) {
        // Pick a random index in [0, i] and swap it into position i.
        const j = Math.floor(Math.random() * (i + 1));
        [shuffled[i], shuffled[j]] = [shuffled[j], shuffled[i]];
    }
    return shuffled;
}

// Sample documents shown in the "Documents" textarea on first render.
// They are shuffled once at module load so the relevant passages are not always listed first.
const PLACEHOLDER_TEXTS = shuffle([
    "'To Kill a Mockingbird' is a novel by Harper Lee published in 1960. It was immediately successful, winning the Pulitzer Prize, and has become a classic of modern American literature.",
    "The novel 'Moby-Dick' was written by Herman Melville and first published in 1851. It is considered a masterpiece of American literature and deals with complex themes of obsession, revenge, and the conflict between good and evil.",
    "Harper Lee, an American novelist widely known for her novel 'To Kill a Mockingbird', was born in 1926 in Monroeville, Alabama. She received the Pulitzer Prize for Fiction in 1961.",
    "Jane Austen was an English novelist known primarily for her six major novels, which interpret, critique and comment upon the British landed gentry at the end of the 18th century.",
    "The 'Harry Potter' series, which consists of seven fantasy novels written by British author J.K. Rowling, is among the most popular and critically acclaimed books of the modern era.",
    "'The Great Gatsby', a novel written by American author F. Scott Fitzgerald, was published in 1925. The story is set in the Jazz Age and follows the life of millionaire Jay Gatsby and his pursuit of Daisy Buchanan.",
]);
|
||||
|
||||
function App() {
|
||||
const [status, setStatus] = useState('idle');
|
||||
|
||||
const [query, setQuery] = useState(`Who wrote 'To Kill a Mockingbird'?`);
|
||||
const [documents, setDocuments] = useState(PLACEHOLDER_TEXTS.join('\n'));
|
||||
|
||||
const [results, setResults] = useState([]);
|
||||
|
||||
// Create a reference to the worker object.
|
||||
const worker = useRef(null);
|
||||
|
||||
// We use the `useEffect` hook to setup the worker as soon as the `App` component is mounted.
|
||||
useEffect(() => {
|
||||
if (!worker.current) {
|
||||
// Create the worker if it does not yet exist.
|
||||
worker.current = new Worker(new URL('./worker.js', import.meta.url), {
|
||||
type: 'module'
|
||||
});
|
||||
}
|
||||
|
||||
// Create a callback function for messages from the worker thread.
|
||||
const onMessageReceived = (e) => {
|
||||
const status = e.data.status;
|
||||
if (e.data.file?.endsWith('.onnx')) {
|
||||
if (status === 'initiate') {
|
||||
setStatus('loading');
|
||||
} else if (status === 'done') {
|
||||
setStatus('ready');
|
||||
}
|
||||
} else if (status === 'complete') {
|
||||
setResults(e.data.output);
|
||||
setStatus('idle');
|
||||
}
|
||||
};
|
||||
|
||||
// Attach the callback function as an event listener.
|
||||
worker.current.addEventListener('message', onMessageReceived);
|
||||
|
||||
// Define a cleanup function for when the component is unmounted.
|
||||
return () => worker.current.removeEventListener('message', onMessageReceived);
|
||||
}, []);
|
||||
|
||||
const run = useCallback(() => {
|
||||
setStatus('processing');
|
||||
worker.current.postMessage({
|
||||
query,
|
||||
documents,
|
||||
});
|
||||
}, [query, documents])
|
||||
|
||||
const busy = status !== 'idle';
|
||||
|
||||
return (
|
||||
<div className='flex flex-col h-full'>
|
||||
<h1 className='text-2xl md:text-4xl font-bold text-center mb-1'>Reranking w/ The Crispy mixedbread Rerank Models</h1>
|
||||
<p className='text-lg md:text-xl font-medium text-center mb-2'>Powered by <a href='https://huggingface.co/mixedbread-ai/mxbai-rerank-xsmall-v1' target='_blank' rel='noreferrer'>mxbai-rerank-xsmall-v1</a> and <a href='http://huggingface.co/docs/transformers.js' target='_blank' rel='noreferrer'>🤗 Transformers.js</a></p>
|
||||
<div className='flex-grow flex flex-wrap p-4'>
|
||||
<div className='flex flex-col items-center gap-y-1 w-full md:w-1/2'>
|
||||
<label className='text-lg font-medium'>Query</label>
|
||||
<textarea
|
||||
placeholder='Enter query.'
|
||||
className='border w-full p-1 resize-none overflow-hidden h-10'
|
||||
value={query}
|
||||
onChange={e => {
|
||||
setQuery(e.target.value);
|
||||
setResults([]);
|
||||
}}
|
||||
></textarea>
|
||||
<label className='text-lg font-medium mt-1'>Documents</label>
|
||||
<textarea
|
||||
placeholder='Enter documents to compare with the query. One sentence per line.'
|
||||
className='border w-full p-1 h-full resize-none'
|
||||
value={documents}
|
||||
onChange={e => {
|
||||
setDocuments(e.target.value);
|
||||
setResults([]);
|
||||
}}
|
||||
></textarea>
|
||||
|
||||
<button
|
||||
className='border py-1 px-2 bg-green-400 rounded text-white text-lg font-medium disabled:opacity-50 disabled:cursor-not-allowed'
|
||||
disabled={busy}
|
||||
onClick={run}>{
|
||||
!busy
|
||||
? 'Rerank'
|
||||
: status === 'loading'
|
||||
? 'Model loading...'
|
||||
: 'Processing'
|
||||
}</button>
|
||||
</div>
|
||||
<div className='flex flex-col items-center w-full md:w-1/2 gap-y-1'>
|
||||
{results.length > 0 && (<>
|
||||
<div className='w-full flex flex-col gap-y-1'>
|
||||
<label className='text-lg font-medium text-center'>Results</label>
|
||||
<div className='flex flex-col gap-y-1'>
|
||||
{results.map((result, i) => (
|
||||
<div key={i} className='flex gap-x-2 border mx-2 p-1'>
|
||||
<span className='font-bold'>{result.score.toFixed(3)}</span>
|
||||
<span>{result.text}</span>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
</>)
|
||||
}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default App
|
|
@ -0,0 +1,3 @@
|
|||
@tailwind base;
|
||||
@tailwind components;
|
||||
@tailwind utilities;
|
|
@ -0,0 +1,10 @@
|
|||
import React from 'react'
|
||||
import ReactDOM from 'react-dom/client'
|
||||
import App from './App.jsx'
|
||||
import './index.css'
|
||||
|
||||
ReactDOM.createRoot(document.getElementById('root')).render(
|
||||
<React.StrictMode>
|
||||
<App />
|
||||
</React.StrictMode>,
|
||||
)
|
|
@ -0,0 +1,62 @@
|
|||
import { env, AutoTokenizer, AutoModelForSequenceClassification } from '@xenova/transformers';
|
||||
|
||||
// Skip local model check since we are downloading the model from the Hugging Face Hub.
|
||||
env.allowLocalModels = false;
|
||||
|
||||
/**
 * Lazily loads the tokenizer and the sequence-classification model for `model_id`
 * and shares them between all requests handled by this worker.
 *
 * The static fields cache the *promises* returned by `from_pretrained`, so calls
 * made while a load is still in flight await the same load instead of starting another.
 */
class CrossEncoderSingleton {
    static model_id = 'mixedbread-ai/mxbai-rerank-xsmall-v1';
    static model = null;
    static tokenizer = null;

    /**
     * Get the shared tokenizer and model, starting their load on first use.
     *
     * @param {Function} progress_callback - Passed to the model's `from_pretrained` call.
     *   Only used by the call that actually starts the model load.
     * @returns {Promise<Array>} Resolves to `[tokenizer, model]`.
     * @throws Rejects with the underlying error if either load fails. The cached
     *   promises are cleared in that case so the next call retries, rather than
     *   re-throwing the same stale failure forever.
     */
    static async getInstance(progress_callback) {
        if (!this.tokenizer) {
            this.tokenizer = AutoTokenizer.from_pretrained(this.model_id);
        }

        if (!this.model) {
            this.model = AutoModelForSequenceClassification.from_pretrained(this.model_id, {
                quantized: true,
                progress_callback,
            });
        }

        const pendingTokenizer = this.tokenizer;
        const pendingModel = this.model;

        try {
            return await Promise.all([pendingTokenizer, pendingModel]);
        } catch (err) {
            // Only clear what this call awaited: a newer call may already have
            // started a fresh load that must not be discarded.
            if (this.tokenizer === pendingTokenizer) {
                this.tokenizer = null;
            }
            if (this.model === pendingModel) {
                this.model = null;
            }
            throw err;
        }
    }
}
|
||||
|
||||
// Listen for rerank requests from the main thread.
// Each request is `{ query, documents }`, where `documents` is a newline-separated string.
// The reply is `{ status: 'complete', output }`, where `output` is a list of
// `{ corpus_id, score, text }` objects sorted by descending score.
self.addEventListener('message', async (event) => {
    const { query, documents } = event.data;

    // One document per line. Accept both "\n" and "\r\n" line endings and drop
    // blank lines so that empty strings are never scored as documents.
    const docs = String(documents ?? '')
        .split(/\r?\n/)
        .map((line) => line.trim())
        .filter((line) => line.length > 0);

    if (docs.length === 0) {
        // Nothing to score: answer immediately without loading the model.
        self.postMessage({ status: 'complete', output: [] });
        return;
    }

    try {
        // Get the tokenizer and model. The first call starts loading them;
        // loading progress events are forwarded to the main thread as-is.
        const [tokenizer, model] = await CrossEncoderSingleton.getInstance(x => {
            self.postMessage(x);
        });

        // Pair the query with every document: (query, doc_0), (query, doc_1), ...
        const inputs = tokenizer(
            docs.map(() => query),
            {
                text_pair: docs,
                padding: true,
                truncation: true,
            }
        );
        const { logits } = await model(inputs);
        const output = logits
            .sigmoid()
            .tolist()
            .map(([score], i) => ({
                corpus_id: i,
                score,
                text: docs[i],
            }))
            .sort((a, b) => b.score - a.score);

        // Send the output back to the main thread
        self.postMessage({ status: 'complete', output });
    } catch (err) {
        // 'complete' is the only terminal status this worker emits, so a failure is
        // reported as a completed run with no results (plus an `error` description).
        // Otherwise the requester would wait forever for an answer.
        console.error(err);
        self.postMessage({
            status: 'complete',
            output: [],
            error: String(err?.message ?? err),
        });
    }
});
|
|
@ -0,0 +1,12 @@
|
|||
// Files whose content is listed for Tailwind: the HTML entry point and all sources under src/.
const content = [
    "./index.html",
    "./src/**/*.{js,ts,jsx,tsx}",
];

/** @type {import('tailwindcss').Config} */
const config = {
    content,
    // No theme extensions and no plugins are configured.
    theme: {
        extend: {},
    },
    plugins: [],
};

export default config;
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
import { defineConfig } from 'vite'
|
||||
import react from '@vitejs/plugin-react'
|
||||
|
||||
// Vite configuration (reference: https://vitejs.dev/config/).
// The only customisation is enabling the React plugin.
const plugins = [react()];

export default defineConfig({ plugins });
|
Loading…
Reference in New Issue