Create cross-encoder demo (#617)

This commit is contained in:
Joshua Lochner 2024-03-06 17:31:57 +02:00 committed by GitHub
parent e072ee6b70
commit 382ff3af35
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 6211 additions and 0 deletions

View File

@ -0,0 +1,20 @@
module.exports = {
root: true,
env: { browser: true, es2020: true },
extends: [
'eslint:recommended',
'plugin:react/recommended',
'plugin:react/jsx-runtime',
'plugin:react-hooks/recommended',
],
ignorePatterns: ['dist', '.eslintrc.cjs'],
parserOptions: { ecmaVersion: 'latest', sourceType: 'module' },
settings: { react: { version: '18.2' } },
plugins: ['react-refresh'],
rules: {
'react-refresh/only-export-components': [
'warn',
{ allowConstantExport: true },
],
},
}

24
examples/cross-encoder/.gitignore vendored Normal file
View File

@ -0,0 +1,24 @@
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
pnpm-debug.log*
lerna-debug.log*
node_modules
dist
dist-ssr
*.local
# Editor directories and files
.vscode/*
!.vscode/extensions.json
.idea
.DS_Store
*.suo
*.ntvs*
*.njsproj
*.sln
*.sw?

View File

@ -0,0 +1,13 @@
<!doctype html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Reranking w/ Transformers.js</title>
</head>
<body>
<div id="root"></div>
<script type="module" src="/src/main.jsx"></script>
</body>
</html>

5895
examples/cross-encoder/package-lock.json generated Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,30 @@
{
"name": "cross-encoder",
"private": true,
"version": "0.0.0",
"type": "module",
"scripts": {
"dev": "vite",
"build": "vite build",
"lint": "eslint . --ext js,jsx --report-unused-disable-directives --max-warnings 0",
"preview": "vite preview"
},
"dependencies": {
"@xenova/transformers": "^2.15.0",
"react": "^18.2.0",
"react-dom": "^18.2.0"
},
"devDependencies": {
"@types/react": "^18.2.43",
"@types/react-dom": "^18.2.17",
"@vitejs/plugin-react": "^4.2.1",
"autoprefixer": "^10.4.16",
"eslint": "^8.55.0",
"eslint-plugin-react": "^7.33.2",
"eslint-plugin-react-hooks": "^4.6.0",
"eslint-plugin-react-refresh": "^0.4.5",
"postcss": "^8.4.33",
"tailwindcss": "^3.4.1",
"vite": "^5.0.12"
}
}

View File

@ -0,0 +1,6 @@
export default {
plugins: {
tailwindcss: {},
autoprefixer: {},
},
}

View File

@ -0,0 +1,5 @@
#root {
height: 100vh;
width: 100vw;
padding: 1rem;
}

View File

@ -0,0 +1,124 @@
import { useState, useRef, useEffect, useCallback } from 'react'
import './App.css'
const PLACEHOLDER_TEXTS = [
"'To Kill a Mockingbird' is a novel by Harper Lee published in 1960. It was immediately successful, winning the Pulitzer Prize, and has become a classic of modern American literature.",
"The novel 'Moby-Dick' was written by Herman Melville and first published in 1851. It is considered a masterpiece of American literature and deals with complex themes of obsession, revenge, and the conflict between good and evil.",
"Harper Lee, an American novelist widely known for her novel 'To Kill a Mockingbird', was born in 1926 in Monroeville, Alabama. She received the Pulitzer Prize for Fiction in 1961.",
"Jane Austen was an English novelist known primarily for her six major novels, which interpret, critique and comment upon the British landed gentry at the end of the 18th century.",
"The 'Harry Potter' series, which consists of seven fantasy novels written by British author J.K. Rowling, is among the most popular and critically acclaimed books of the modern era.",
"'The Great Gatsby', a novel written by American author F. Scott Fitzgerald, was published in 1925. The story is set in the Jazz Age and follows the life of millionaire Jay Gatsby and his pursuit of Daisy Buchanan."
].sort(() => Math.random() - 0.5);
function App() {
const [status, setStatus] = useState('idle');
const [query, setQuery] = useState(`Who wrote 'To Kill a Mockingbird'?`);
const [documents, setDocuments] = useState(PLACEHOLDER_TEXTS.join('\n'));
const [results, setResults] = useState([]);
// Create a reference to the worker object.
const worker = useRef(null);
// We use the `useEffect` hook to setup the worker as soon as the `App` component is mounted.
useEffect(() => {
if (!worker.current) {
// Create the worker if it does not yet exist.
worker.current = new Worker(new URL('./worker.js', import.meta.url), {
type: 'module'
});
}
// Create a callback function for messages from the worker thread.
const onMessageReceived = (e) => {
const status = e.data.status;
if (e.data.file?.endsWith('.onnx')) {
if (status === 'initiate') {
setStatus('loading');
} else if (status === 'done') {
setStatus('ready');
}
} else if (status === 'complete') {
setResults(e.data.output);
setStatus('idle');
}
};
// Attach the callback function as an event listener.
worker.current.addEventListener('message', onMessageReceived);
// Define a cleanup function for when the component is unmounted.
return () => worker.current.removeEventListener('message', onMessageReceived);
}, []);
const run = useCallback(() => {
setStatus('processing');
worker.current.postMessage({
query,
documents,
});
}, [query, documents])
const busy = status !== 'idle';
return (
<div className='flex flex-col h-full'>
<h1 className='text-2xl md:text-4xl font-bold text-center mb-1'>Reranking w/ The Crispy mixedbread Rerank Models</h1>
<p className='text-lg md:text-xl font-medium text-center mb-2'>Powered by <a href='https://huggingface.co/mixedbread-ai/mxbai-rerank-xsmall-v1' target='_blank' rel='noreferrer'>mxbai-rerank-xsmall-v1</a> and <a href='http://huggingface.co/docs/transformers.js' target='_blank' rel='noreferrer'>🤗 Transformers.js</a></p>
<div className='flex-grow flex flex-wrap p-4'>
<div className='flex flex-col items-center gap-y-1 w-full md:w-1/2'>
<label className='text-lg font-medium'>Query</label>
<textarea
placeholder='Enter query.'
className='border w-full p-1 resize-none overflow-hidden h-10'
value={query}
onChange={e => {
setQuery(e.target.value);
setResults([]);
}}
></textarea>
<label className='text-lg font-medium mt-1'>Documents</label>
<textarea
placeholder='Enter documents to compare with the query. One sentence per line.'
className='border w-full p-1 h-full resize-none'
value={documents}
onChange={e => {
setDocuments(e.target.value);
setResults([]);
}}
></textarea>
<button
className='border py-1 px-2 bg-green-400 rounded text-white text-lg font-medium disabled:opacity-50 disabled:cursor-not-allowed'
disabled={busy}
onClick={run}>{
!busy
? 'Rerank'
: status === 'loading'
? 'Model loading...'
: 'Processing'
}</button>
</div>
<div className='flex flex-col items-center w-full md:w-1/2 gap-y-1'>
{results.length > 0 && (<>
<div className='w-full flex flex-col gap-y-1'>
<label className='text-lg font-medium text-center'>Results</label>
<div className='flex flex-col gap-y-1'>
{results.map((result, i) => (
<div key={i} className='flex gap-x-2 border mx-2 p-1'>
<span className='font-bold'>{result.score.toFixed(3)}</span>
<span>{result.text}</span>
</div>
))}
</div>
</div>
</>)
}
</div>
</div>
</div>
)
}
export default App

View File

@ -0,0 +1,3 @@
@tailwind base;
@tailwind components;
@tailwind utilities;

View File

@ -0,0 +1,10 @@
import React from 'react'
import ReactDOM from 'react-dom/client'
import App from './App.jsx'
import './index.css'
ReactDOM.createRoot(document.getElementById('root')).render(
<React.StrictMode>
<App />
</React.StrictMode>,
)

View File

@ -0,0 +1,62 @@
import { env, AutoTokenizer, AutoModelForSequenceClassification } from '@xenova/transformers';
// Skip local model check since we are downloading the model from the Hugging Face Hub.
env.allowLocalModels = false;
class CrossEncoderSingleton {
static model_id = 'mixedbread-ai/mxbai-rerank-xsmall-v1';
static model = null;
static tokenizer = null;
static async getInstance(progress_callback) {
if (!this.tokenizer) {
this.tokenizer = AutoTokenizer.from_pretrained(this.model_id);
}
if (!this.model) {
this.model = AutoModelForSequenceClassification.from_pretrained(this.model_id, {
quantized: true,
progress_callback,
});
}
return Promise.all([this.tokenizer, this.model]);
}
}
// Listen for messages from the main thread
self.addEventListener('message', async (event) => {
// Retrieve the pipeline. When called for the first time,
// this will load the pipeline and save it for future use.
const [tokenizer, model] = await CrossEncoderSingleton.getInstance(x => {
// We also add a progress callback to the pipeline so that we can
// track model loading.
self.postMessage(x);
});
const { query, documents } = event.data;
const docs = documents.trim().split('\n');
const inputs = tokenizer(
new Array(docs.length).fill(query),
{
text_pair: docs,
padding: true,
truncation: true,
}
)
const { logits } = await model(inputs);
const output = logits
.sigmoid()
.tolist()
.map(([score], i) => ({
corpus_id: i,
score,
text: docs[i],
}))
.sort((a, b) => b.score - a.score);
// Send the output back to the main thread
self.postMessage({ status: 'complete', output });
});

View File

@ -0,0 +1,12 @@
/** @type {import('tailwindcss').Config} */
export default {
content: [
"./index.html",
"./src/**/*.{js,ts,jsx,tsx}",
],
theme: {
extend: {},
},
plugins: [],
}

View File

@ -0,0 +1,7 @@
import { defineConfig } from 'vite'
import react from '@vitejs/plugin-react'
// https://vitejs.dev/config/
export default defineConfig({
plugins: [react()],
})