Add zero-shot classification demo (#519)
* Create zero-shot-classification demo * Update hypothesis template
This commit is contained in:
parent
b1f96a2fc9
commit
1edf683e64
|
@ -0,0 +1,20 @@
|
|||
module.exports = {
|
||||
root: true,
|
||||
env: { browser: true, es2020: true },
|
||||
extends: [
|
||||
'eslint:recommended',
|
||||
'plugin:react/recommended',
|
||||
'plugin:react/jsx-runtime',
|
||||
'plugin:react-hooks/recommended',
|
||||
],
|
||||
ignorePatterns: ['dist', '.eslintrc.cjs'],
|
||||
parserOptions: { ecmaVersion: 'latest', sourceType: 'module' },
|
||||
settings: { react: { version: '18.2' } },
|
||||
plugins: ['react-refresh'],
|
||||
rules: {
|
||||
'react-refresh/only-export-components': [
|
||||
'warn',
|
||||
{ allowConstantExport: true },
|
||||
],
|
||||
},
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
# Logs
|
||||
logs
|
||||
*.log
|
||||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
pnpm-debug.log*
|
||||
lerna-debug.log*
|
||||
|
||||
node_modules
|
||||
dist
|
||||
dist-ssr
|
||||
*.local
|
||||
|
||||
# Editor directories and files
|
||||
.vscode/*
|
||||
!.vscode/extensions.json
|
||||
.idea
|
||||
.DS_Store
|
||||
*.suo
|
||||
*.ntvs*
|
||||
*.njsproj
|
||||
*.sln
|
||||
*.sw?
|
|
@ -0,0 +1,13 @@
|
|||
<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>Zero-shot classication w/ Transformers.js</title>
|
||||
</head>
|
||||
<body>
|
||||
<div id="root"></div>
|
||||
<script type="module" src="/src/main.jsx"></script>
|
||||
</body>
|
||||
</html>
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,30 @@
|
|||
{
|
||||
"name": "zero-shot-classification",
|
||||
"private": true,
|
||||
"version": "0.0.0",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"dev": "vite",
|
||||
"build": "vite build",
|
||||
"lint": "eslint . --ext js,jsx --report-unused-disable-directives --max-warnings 0",
|
||||
"preview": "vite preview"
|
||||
},
|
||||
"dependencies": {
|
||||
"@xenova/transformers": "^2.14.0",
|
||||
"react": "^18.2.0",
|
||||
"react-dom": "^18.2.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/react": "^18.2.43",
|
||||
"@types/react-dom": "^18.2.17",
|
||||
"@vitejs/plugin-react": "^4.2.1",
|
||||
"autoprefixer": "^10.4.16",
|
||||
"eslint": "^8.55.0",
|
||||
"eslint-plugin-react": "^7.33.2",
|
||||
"eslint-plugin-react-hooks": "^4.6.0",
|
||||
"eslint-plugin-react-refresh": "^0.4.5",
|
||||
"postcss": "^8.4.33",
|
||||
"tailwindcss": "^3.4.1",
|
||||
"vite": "^5.0.8"
|
||||
}
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
export default {
|
||||
plugins: {
|
||||
tailwindcss: {},
|
||||
autoprefixer: {},
|
||||
},
|
||||
}
|
|
@ -0,0 +1,5 @@
|
|||
#root {
|
||||
height: 100vh;
|
||||
width: 100vw;
|
||||
padding: 1rem;
|
||||
}
|
|
@ -0,0 +1,189 @@
|
|||
import { useState, useRef, useEffect, useCallback } from 'react'
|
||||
import './App.css'
|
||||
|
||||
const PLACEHOLDER_REVIEWS = [
|
||||
// battery/charging problems
|
||||
"Disappointed with the battery life! The phone barely lasts half a day with regular use. Considering how much I paid for it, I expected better performance in this department.",
|
||||
"I bought this phone a week ago, and I'm already frustrated with the battery life. It barely lasts half a day with normal usage. I expected more from a supposedly high-end device",
|
||||
"The charging port is so finicky. Sometimes it takes forever to charge, and other times it doesn't even recognize the charger. Frustrating experience!",
|
||||
|
||||
// overheating
|
||||
"This phone heats up way too quickly, especially when using demanding apps. It's uncomfortable to hold, and I'm concerned it might damage the internal components over time. Not what I expected",
|
||||
"This phone is like holding a hot potato. Video calls turn it into a scalding nightmare. Seriously, can't it keep its cool?",
|
||||
"Forget about a heatwave outside; my phone's got its own. It's like a little portable heater. Not what I signed up for.",
|
||||
|
||||
// poor build quality
|
||||
"I dropped the phone from a short distance, and the screen cracked easily. Not as durable as I expected from a flagship device.",
|
||||
"Took a slight bump in my bag, and the frame got dinged. Are we back in the flip phone era?",
|
||||
"So, my phone's been in my pocket with just keys – no ninja moves or anything. Still, it managed to get some scratches. Disappointed with the build quality.",
|
||||
|
||||
// software
|
||||
"The software updates are a nightmare. Each update seems to introduce new bugs, and it takes forever for them to be fixed.",
|
||||
"Constant crashes and freezes make me want to throw it into a black hole.",
|
||||
"Every time I open Instagram, my phone freezes and crashes. It's so frustrating!",
|
||||
|
||||
// other
|
||||
"I'm not sure what to make of this phone. It's not bad, but it's not great either. I'm on the fence about it.",
|
||||
"I hate the color of this phone. It's so ugly!",
|
||||
"This phone sucks! I'm returning it."
|
||||
].sort(() => Math.random() - 0.5)
|
||||
|
||||
const PLACEHOLDER_SECTIONS = [
|
||||
'Battery and charging problems',
|
||||
'Overheating',
|
||||
'Poor build quality',
|
||||
'Software issues',
|
||||
'Other',
|
||||
];
|
||||
|
||||
function App() {
|
||||
|
||||
const [text, setText] = useState(PLACEHOLDER_REVIEWS.join('\n'));
|
||||
|
||||
const [sections, setSections] = useState(
|
||||
PLACEHOLDER_SECTIONS.map(title => ({ title, items: [] }))
|
||||
);
|
||||
|
||||
const [status, setStatus] = useState('idle');
|
||||
|
||||
// Create a reference to the worker object.
|
||||
const worker = useRef(null);
|
||||
|
||||
// We use the `useEffect` hook to setup the worker as soon as the `App` component is mounted.
|
||||
useEffect(() => {
|
||||
if (!worker.current) {
|
||||
// Create the worker if it does not yet exist.
|
||||
worker.current = new Worker(new URL('./worker.js', import.meta.url), {
|
||||
type: 'module'
|
||||
});
|
||||
}
|
||||
|
||||
// Create a callback function for messages from the worker thread.
|
||||
const onMessageReceived = (e) => {
|
||||
const status = e.data.status;
|
||||
if (status === 'initiate') {
|
||||
setStatus('loading');
|
||||
} else if (status === 'ready') {
|
||||
setStatus('ready');
|
||||
} else if (status === 'output') {
|
||||
const { sequence, labels, scores } = e.data.output;
|
||||
|
||||
// Threshold for classification
|
||||
const label = scores[0] > 0.5 ? labels[0] : 'Other';
|
||||
|
||||
const sectionID = sections.map(x => x.title).indexOf(label) ?? sections.length - 1;
|
||||
setSections((sections) => {
|
||||
const newSections = [...sections]
|
||||
newSections[sectionID] = {
|
||||
...newSections[sectionID],
|
||||
items: [...newSections[sectionID].items, sequence]
|
||||
}
|
||||
return newSections
|
||||
})
|
||||
} else if (status === 'complete') {
|
||||
setStatus('idle');
|
||||
}
|
||||
};
|
||||
|
||||
// Attach the callback function as an event listener.
|
||||
worker.current.addEventListener('message', onMessageReceived);
|
||||
|
||||
// Define a cleanup function for when the component is unmounted.
|
||||
return () => worker.current.removeEventListener('message', onMessageReceived);
|
||||
}, [sections]);
|
||||
|
||||
const classify = useCallback(() => {
|
||||
setStatus('processing');
|
||||
worker.current.postMessage({
|
||||
text,
|
||||
labels: sections.slice(0, sections.length - 1).map(section => section.title)
|
||||
});
|
||||
}, [text, sections])
|
||||
|
||||
const busy = status !== 'idle';
|
||||
|
||||
return (
|
||||
<div className='flex flex-col h-full'>
|
||||
<textarea
|
||||
className='border w-full p-1 h-1/2'
|
||||
value={text}
|
||||
onChange={e => setText(e.target.value)}
|
||||
></textarea>
|
||||
<div className='flex flex-col justify-center items-center m-2 gap-1'>
|
||||
<button
|
||||
className='border py-1 px-2 bg-blue-400 rounded text-white text-lg font-medium disabled:opacity-50 disabled:cursor-not-allowed'
|
||||
disabled={busy}
|
||||
onClick={classify}>{
|
||||
!busy
|
||||
? 'Categorize'
|
||||
: status === 'loading'
|
||||
? 'Model loading...'
|
||||
: 'Processing'
|
||||
}</button>
|
||||
<div className='flex gap-1'>
|
||||
<button
|
||||
className='border py-1 px-2 bg-green-400 rounded text-white text-sm font-medium'
|
||||
onClick={e => {
|
||||
setSections((sections) => {
|
||||
const newSections = [...sections];
|
||||
// add at position 2 from the end
|
||||
newSections.splice(newSections.length - 1, 0, {
|
||||
title: 'New Category',
|
||||
items: [],
|
||||
})
|
||||
return newSections;
|
||||
})
|
||||
}}>Add category</button>
|
||||
<button
|
||||
className='border py-1 px-2 bg-red-400 rounded text-white text-sm font-medium'
|
||||
disabled={sections.length <= 1}
|
||||
onClick={e => {
|
||||
setSections((sections) => {
|
||||
const newSections = [...sections];
|
||||
newSections.splice(newSections.length - 2, 1); // Remove second last element
|
||||
return newSections;
|
||||
})
|
||||
}}>Remove category</button>
|
||||
<button
|
||||
className='border py-1 px-2 bg-orange-400 rounded text-white text-sm font-medium'
|
||||
onClick={e => {
|
||||
setSections((sections) => (sections.map(section => ({
|
||||
...section,
|
||||
items: [],
|
||||
}))))
|
||||
}}>Clear</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className='flex justify-between flex-grow overflow-x-auto max-h-[40%]'>
|
||||
{sections.map((section, index) => (
|
||||
<div
|
||||
key={index}
|
||||
className="flex flex-col w-full"
|
||||
>
|
||||
<input
|
||||
disabled={section.title === 'Other'}
|
||||
className="w-full border px-1 text-center"
|
||||
value={section.title} onChange={e => {
|
||||
setSections((sections) => {
|
||||
const newSections = [...sections];
|
||||
newSections[index].title = e.target.value;
|
||||
return newSections;
|
||||
})
|
||||
}}></input>
|
||||
<div className="overflow-y-auto h-full border">
|
||||
{section.items.map((item, index) => (
|
||||
<div
|
||||
className="m-2 border bg-red-50 rounded p-1 text-sm"
|
||||
key={index}>{item}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default App
|
|
@ -0,0 +1,3 @@
|
|||
@tailwind base;
|
||||
@tailwind components;
|
||||
@tailwind utilities;
|
|
@ -0,0 +1,10 @@
|
|||
import React from 'react'
|
||||
import ReactDOM from 'react-dom/client'
|
||||
import App from './App.jsx'
|
||||
import './index.css'
|
||||
|
||||
ReactDOM.createRoot(document.getElementById('root')).render(
|
||||
<React.StrictMode>
|
||||
<App />
|
||||
</React.StrictMode>,
|
||||
)
|
|
@ -0,0 +1,46 @@
|
|||
import { env, pipeline } from '@xenova/transformers';
|
||||
|
||||
// Skip local model check since we are downloading the model from the Hugging Face Hub.
|
||||
env.allowLocalModels = false;
|
||||
|
||||
class MyZeroShotClassificationPipeline {
|
||||
static task = 'zero-shot-classification';
|
||||
static model = 'MoritzLaurer/deberta-v3-xsmall-zeroshot-v1.1-all-33';
|
||||
static instance = null;
|
||||
|
||||
static async getInstance(progress_callback = null) {
|
||||
if (this.instance === null) {
|
||||
this.instance = pipeline(this.task, this.model, {
|
||||
quantized: true,
|
||||
progress_callback,
|
||||
});
|
||||
}
|
||||
|
||||
return this.instance;
|
||||
}
|
||||
}
|
||||
|
||||
// Listen for messages from the main thread
|
||||
self.addEventListener('message', async (event) => {
|
||||
// Retrieve the pipeline. When called for the first time,
|
||||
// this will load the pipeline and save it for future use.
|
||||
const classifier = await MyZeroShotClassificationPipeline.getInstance(x => {
|
||||
// We also add a progress callback to the pipeline so that we can
|
||||
// track model loading.
|
||||
self.postMessage(x);
|
||||
});
|
||||
|
||||
const { text, labels } = event.data;
|
||||
|
||||
const split = text.split('\n');
|
||||
for (const line of split) {
|
||||
const output = await classifier(line, labels, {
|
||||
hypothesis_template: 'This text is about {}.',
|
||||
multi_label: true,
|
||||
});
|
||||
// Send the output back to the main thread
|
||||
self.postMessage({ status: 'output', output });
|
||||
}
|
||||
// Send the output back to the main thread
|
||||
self.postMessage({ status: 'complete' });
|
||||
});
|
|
@ -0,0 +1,12 @@
|
|||
/** @type {import('tailwindcss').Config} */
|
||||
export default {
|
||||
content: [
|
||||
"./index.html",
|
||||
"./src/**/*.{js,ts,jsx,tsx}",
|
||||
],
|
||||
theme: {
|
||||
extend: {},
|
||||
},
|
||||
plugins: [],
|
||||
}
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
import { defineConfig } from 'vite'
|
||||
import react from '@vitejs/plugin-react'
|
||||
|
||||
// https://vitejs.dev/config/
|
||||
export default defineConfig({
|
||||
plugins: [react()],
|
||||
})
|
Loading…
Reference in New Issue