Add zero-shot classification demo (#519)
* Create zero-shot-classification demo * Update hypothesis template
This commit is contained in:
parent
b1f96a2fc9
commit
1edf683e64
|
@ -0,0 +1,20 @@
|
||||||
|
module.exports = {
|
||||||
|
root: true,
|
||||||
|
env: { browser: true, es2020: true },
|
||||||
|
extends: [
|
||||||
|
'eslint:recommended',
|
||||||
|
'plugin:react/recommended',
|
||||||
|
'plugin:react/jsx-runtime',
|
||||||
|
'plugin:react-hooks/recommended',
|
||||||
|
],
|
||||||
|
ignorePatterns: ['dist', '.eslintrc.cjs'],
|
||||||
|
parserOptions: { ecmaVersion: 'latest', sourceType: 'module' },
|
||||||
|
settings: { react: { version: '18.2' } },
|
||||||
|
plugins: ['react-refresh'],
|
||||||
|
rules: {
|
||||||
|
'react-refresh/only-export-components': [
|
||||||
|
'warn',
|
||||||
|
{ allowConstantExport: true },
|
||||||
|
],
|
||||||
|
},
|
||||||
|
}
|
|
@ -0,0 +1,24 @@
|
||||||
|
# Logs
|
||||||
|
logs
|
||||||
|
*.log
|
||||||
|
npm-debug.log*
|
||||||
|
yarn-debug.log*
|
||||||
|
yarn-error.log*
|
||||||
|
pnpm-debug.log*
|
||||||
|
lerna-debug.log*
|
||||||
|
|
||||||
|
node_modules
|
||||||
|
dist
|
||||||
|
dist-ssr
|
||||||
|
*.local
|
||||||
|
|
||||||
|
# Editor directories and files
|
||||||
|
.vscode/*
|
||||||
|
!.vscode/extensions.json
|
||||||
|
.idea
|
||||||
|
.DS_Store
|
||||||
|
*.suo
|
||||||
|
*.ntvs*
|
||||||
|
*.njsproj
|
||||||
|
*.sln
|
||||||
|
*.sw?
|
|
@ -0,0 +1,13 @@
|
||||||
|
<!doctype html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8" />
|
||||||
|
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||||
|
<title>Zero-shot classication w/ Transformers.js</title>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div id="root"></div>
|
||||||
|
<script type="module" src="/src/main.jsx"></script>
|
||||||
|
</body>
|
||||||
|
</html>
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,30 @@
|
||||||
|
{
|
||||||
|
"name": "zero-shot-classification",
|
||||||
|
"private": true,
|
||||||
|
"version": "0.0.0",
|
||||||
|
"type": "module",
|
||||||
|
"scripts": {
|
||||||
|
"dev": "vite",
|
||||||
|
"build": "vite build",
|
||||||
|
"lint": "eslint . --ext js,jsx --report-unused-disable-directives --max-warnings 0",
|
||||||
|
"preview": "vite preview"
|
||||||
|
},
|
||||||
|
"dependencies": {
|
||||||
|
"@xenova/transformers": "^2.14.0",
|
||||||
|
"react": "^18.2.0",
|
||||||
|
"react-dom": "^18.2.0"
|
||||||
|
},
|
||||||
|
"devDependencies": {
|
||||||
|
"@types/react": "^18.2.43",
|
||||||
|
"@types/react-dom": "^18.2.17",
|
||||||
|
"@vitejs/plugin-react": "^4.2.1",
|
||||||
|
"autoprefixer": "^10.4.16",
|
||||||
|
"eslint": "^8.55.0",
|
||||||
|
"eslint-plugin-react": "^7.33.2",
|
||||||
|
"eslint-plugin-react-hooks": "^4.6.0",
|
||||||
|
"eslint-plugin-react-refresh": "^0.4.5",
|
||||||
|
"postcss": "^8.4.33",
|
||||||
|
"tailwindcss": "^3.4.1",
|
||||||
|
"vite": "^5.0.8"
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,6 @@
|
||||||
|
export default {
|
||||||
|
plugins: {
|
||||||
|
tailwindcss: {},
|
||||||
|
autoprefixer: {},
|
||||||
|
},
|
||||||
|
}
|
|
@ -0,0 +1,5 @@
|
||||||
|
#root {
|
||||||
|
height: 100vh;
|
||||||
|
width: 100vw;
|
||||||
|
padding: 1rem;
|
||||||
|
}
|
|
@ -0,0 +1,189 @@
|
||||||
|
import { useState, useRef, useEffect, useCallback } from 'react'
|
||||||
|
import './App.css'
|
||||||
|
|
||||||
|
const PLACEHOLDER_REVIEWS = [
|
||||||
|
// battery/charging problems
|
||||||
|
"Disappointed with the battery life! The phone barely lasts half a day with regular use. Considering how much I paid for it, I expected better performance in this department.",
|
||||||
|
"I bought this phone a week ago, and I'm already frustrated with the battery life. It barely lasts half a day with normal usage. I expected more from a supposedly high-end device",
|
||||||
|
"The charging port is so finicky. Sometimes it takes forever to charge, and other times it doesn't even recognize the charger. Frustrating experience!",
|
||||||
|
|
||||||
|
// overheating
|
||||||
|
"This phone heats up way too quickly, especially when using demanding apps. It's uncomfortable to hold, and I'm concerned it might damage the internal components over time. Not what I expected",
|
||||||
|
"This phone is like holding a hot potato. Video calls turn it into a scalding nightmare. Seriously, can't it keep its cool?",
|
||||||
|
"Forget about a heatwave outside; my phone's got its own. It's like a little portable heater. Not what I signed up for.",
|
||||||
|
|
||||||
|
// poor build quality
|
||||||
|
"I dropped the phone from a short distance, and the screen cracked easily. Not as durable as I expected from a flagship device.",
|
||||||
|
"Took a slight bump in my bag, and the frame got dinged. Are we back in the flip phone era?",
|
||||||
|
"So, my phone's been in my pocket with just keys – no ninja moves or anything. Still, it managed to get some scratches. Disappointed with the build quality.",
|
||||||
|
|
||||||
|
// software
|
||||||
|
"The software updates are a nightmare. Each update seems to introduce new bugs, and it takes forever for them to be fixed.",
|
||||||
|
"Constant crashes and freezes make me want to throw it into a black hole.",
|
||||||
|
"Every time I open Instagram, my phone freezes and crashes. It's so frustrating!",
|
||||||
|
|
||||||
|
// other
|
||||||
|
"I'm not sure what to make of this phone. It's not bad, but it's not great either. I'm on the fence about it.",
|
||||||
|
"I hate the color of this phone. It's so ugly!",
|
||||||
|
"This phone sucks! I'm returning it."
|
||||||
|
].sort(() => Math.random() - 0.5)
|
||||||
|
|
||||||
|
const PLACEHOLDER_SECTIONS = [
|
||||||
|
'Battery and charging problems',
|
||||||
|
'Overheating',
|
||||||
|
'Poor build quality',
|
||||||
|
'Software issues',
|
||||||
|
'Other',
|
||||||
|
];
|
||||||
|
|
||||||
|
function App() {
|
||||||
|
|
||||||
|
const [text, setText] = useState(PLACEHOLDER_REVIEWS.join('\n'));
|
||||||
|
|
||||||
|
const [sections, setSections] = useState(
|
||||||
|
PLACEHOLDER_SECTIONS.map(title => ({ title, items: [] }))
|
||||||
|
);
|
||||||
|
|
||||||
|
const [status, setStatus] = useState('idle');
|
||||||
|
|
||||||
|
// Create a reference to the worker object.
|
||||||
|
const worker = useRef(null);
|
||||||
|
|
||||||
|
// We use the `useEffect` hook to setup the worker as soon as the `App` component is mounted.
|
||||||
|
useEffect(() => {
|
||||||
|
if (!worker.current) {
|
||||||
|
// Create the worker if it does not yet exist.
|
||||||
|
worker.current = new Worker(new URL('./worker.js', import.meta.url), {
|
||||||
|
type: 'module'
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a callback function for messages from the worker thread.
|
||||||
|
const onMessageReceived = (e) => {
|
||||||
|
const status = e.data.status;
|
||||||
|
if (status === 'initiate') {
|
||||||
|
setStatus('loading');
|
||||||
|
} else if (status === 'ready') {
|
||||||
|
setStatus('ready');
|
||||||
|
} else if (status === 'output') {
|
||||||
|
const { sequence, labels, scores } = e.data.output;
|
||||||
|
|
||||||
|
// Threshold for classification
|
||||||
|
const label = scores[0] > 0.5 ? labels[0] : 'Other';
|
||||||
|
|
||||||
|
const sectionID = sections.map(x => x.title).indexOf(label) ?? sections.length - 1;
|
||||||
|
setSections((sections) => {
|
||||||
|
const newSections = [...sections]
|
||||||
|
newSections[sectionID] = {
|
||||||
|
...newSections[sectionID],
|
||||||
|
items: [...newSections[sectionID].items, sequence]
|
||||||
|
}
|
||||||
|
return newSections
|
||||||
|
})
|
||||||
|
} else if (status === 'complete') {
|
||||||
|
setStatus('idle');
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Attach the callback function as an event listener.
|
||||||
|
worker.current.addEventListener('message', onMessageReceived);
|
||||||
|
|
||||||
|
// Define a cleanup function for when the component is unmounted.
|
||||||
|
return () => worker.current.removeEventListener('message', onMessageReceived);
|
||||||
|
}, [sections]);
|
||||||
|
|
||||||
|
const classify = useCallback(() => {
|
||||||
|
setStatus('processing');
|
||||||
|
worker.current.postMessage({
|
||||||
|
text,
|
||||||
|
labels: sections.slice(0, sections.length - 1).map(section => section.title)
|
||||||
|
});
|
||||||
|
}, [text, sections])
|
||||||
|
|
||||||
|
const busy = status !== 'idle';
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className='flex flex-col h-full'>
|
||||||
|
<textarea
|
||||||
|
className='border w-full p-1 h-1/2'
|
||||||
|
value={text}
|
||||||
|
onChange={e => setText(e.target.value)}
|
||||||
|
></textarea>
|
||||||
|
<div className='flex flex-col justify-center items-center m-2 gap-1'>
|
||||||
|
<button
|
||||||
|
className='border py-1 px-2 bg-blue-400 rounded text-white text-lg font-medium disabled:opacity-50 disabled:cursor-not-allowed'
|
||||||
|
disabled={busy}
|
||||||
|
onClick={classify}>{
|
||||||
|
!busy
|
||||||
|
? 'Categorize'
|
||||||
|
: status === 'loading'
|
||||||
|
? 'Model loading...'
|
||||||
|
: 'Processing'
|
||||||
|
}</button>
|
||||||
|
<div className='flex gap-1'>
|
||||||
|
<button
|
||||||
|
className='border py-1 px-2 bg-green-400 rounded text-white text-sm font-medium'
|
||||||
|
onClick={e => {
|
||||||
|
setSections((sections) => {
|
||||||
|
const newSections = [...sections];
|
||||||
|
// add at position 2 from the end
|
||||||
|
newSections.splice(newSections.length - 1, 0, {
|
||||||
|
title: 'New Category',
|
||||||
|
items: [],
|
||||||
|
})
|
||||||
|
return newSections;
|
||||||
|
})
|
||||||
|
}}>Add category</button>
|
||||||
|
<button
|
||||||
|
className='border py-1 px-2 bg-red-400 rounded text-white text-sm font-medium'
|
||||||
|
disabled={sections.length <= 1}
|
||||||
|
onClick={e => {
|
||||||
|
setSections((sections) => {
|
||||||
|
const newSections = [...sections];
|
||||||
|
newSections.splice(newSections.length - 2, 1); // Remove second last element
|
||||||
|
return newSections;
|
||||||
|
})
|
||||||
|
}}>Remove category</button>
|
||||||
|
<button
|
||||||
|
className='border py-1 px-2 bg-orange-400 rounded text-white text-sm font-medium'
|
||||||
|
onClick={e => {
|
||||||
|
setSections((sections) => (sections.map(section => ({
|
||||||
|
...section,
|
||||||
|
items: [],
|
||||||
|
}))))
|
||||||
|
}}>Clear</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className='flex justify-between flex-grow overflow-x-auto max-h-[40%]'>
|
||||||
|
{sections.map((section, index) => (
|
||||||
|
<div
|
||||||
|
key={index}
|
||||||
|
className="flex flex-col w-full"
|
||||||
|
>
|
||||||
|
<input
|
||||||
|
disabled={section.title === 'Other'}
|
||||||
|
className="w-full border px-1 text-center"
|
||||||
|
value={section.title} onChange={e => {
|
||||||
|
setSections((sections) => {
|
||||||
|
const newSections = [...sections];
|
||||||
|
newSections[index].title = e.target.value;
|
||||||
|
return newSections;
|
||||||
|
})
|
||||||
|
}}></input>
|
||||||
|
<div className="overflow-y-auto h-full border">
|
||||||
|
{section.items.map((item, index) => (
|
||||||
|
<div
|
||||||
|
className="m-2 border bg-red-50 rounded p-1 text-sm"
|
||||||
|
key={index}>{item}
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default App
|
|
@ -0,0 +1,3 @@
|
||||||
|
@tailwind base;
|
||||||
|
@tailwind components;
|
||||||
|
@tailwind utilities;
|
|
@ -0,0 +1,10 @@
|
||||||
|
import React from 'react'
|
||||||
|
import ReactDOM from 'react-dom/client'
|
||||||
|
import App from './App.jsx'
|
||||||
|
import './index.css'
|
||||||
|
|
||||||
|
ReactDOM.createRoot(document.getElementById('root')).render(
|
||||||
|
<React.StrictMode>
|
||||||
|
<App />
|
||||||
|
</React.StrictMode>,
|
||||||
|
)
|
|
@ -0,0 +1,46 @@
|
||||||
|
import { env, pipeline } from '@xenova/transformers';
|
||||||
|
|
||||||
|
// Skip local model check since we are downloading the model from the Hugging Face Hub.
|
||||||
|
env.allowLocalModels = false;
|
||||||
|
|
||||||
|
class MyZeroShotClassificationPipeline {
|
||||||
|
static task = 'zero-shot-classification';
|
||||||
|
static model = 'MoritzLaurer/deberta-v3-xsmall-zeroshot-v1.1-all-33';
|
||||||
|
static instance = null;
|
||||||
|
|
||||||
|
static async getInstance(progress_callback = null) {
|
||||||
|
if (this.instance === null) {
|
||||||
|
this.instance = pipeline(this.task, this.model, {
|
||||||
|
quantized: true,
|
||||||
|
progress_callback,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
return this.instance;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Listen for messages from the main thread
|
||||||
|
self.addEventListener('message', async (event) => {
|
||||||
|
// Retrieve the pipeline. When called for the first time,
|
||||||
|
// this will load the pipeline and save it for future use.
|
||||||
|
const classifier = await MyZeroShotClassificationPipeline.getInstance(x => {
|
||||||
|
// We also add a progress callback to the pipeline so that we can
|
||||||
|
// track model loading.
|
||||||
|
self.postMessage(x);
|
||||||
|
});
|
||||||
|
|
||||||
|
const { text, labels } = event.data;
|
||||||
|
|
||||||
|
const split = text.split('\n');
|
||||||
|
for (const line of split) {
|
||||||
|
const output = await classifier(line, labels, {
|
||||||
|
hypothesis_template: 'This text is about {}.',
|
||||||
|
multi_label: true,
|
||||||
|
});
|
||||||
|
// Send the output back to the main thread
|
||||||
|
self.postMessage({ status: 'output', output });
|
||||||
|
}
|
||||||
|
// Send the output back to the main thread
|
||||||
|
self.postMessage({ status: 'complete' });
|
||||||
|
});
|
|
@ -0,0 +1,12 @@
|
||||||
|
/** @type {import('tailwindcss').Config} */
|
||||||
|
export default {
|
||||||
|
content: [
|
||||||
|
"./index.html",
|
||||||
|
"./src/**/*.{js,ts,jsx,tsx}",
|
||||||
|
],
|
||||||
|
theme: {
|
||||||
|
extend: {},
|
||||||
|
},
|
||||||
|
plugins: [],
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,7 @@
|
||||||
|
import { defineConfig } from 'vite'
|
||||||
|
import react from '@vitejs/plugin-react'
|
||||||
|
|
||||||
|
// https://vitejs.dev/config/
|
||||||
|
export default defineConfig({
|
||||||
|
plugins: [react()],
|
||||||
|
})
|
Loading…
Reference in New Issue