Add zero-shot classification demo (#519)

* Create zero-shot-classification demo

* Update hypothesis template
This commit is contained in:
Joshua Lochner 2024-01-29 12:29:33 +02:00 committed by GitHub
parent b1f96a2fc9
commit 1edf683e64
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 6170 additions and 0 deletions

View File

@ -0,0 +1,20 @@
module.exports = {
root: true,
env: { browser: true, es2020: true },
extends: [
'eslint:recommended',
'plugin:react/recommended',
'plugin:react/jsx-runtime',
'plugin:react-hooks/recommended',
],
ignorePatterns: ['dist', '.eslintrc.cjs'],
parserOptions: { ecmaVersion: 'latest', sourceType: 'module' },
settings: { react: { version: '18.2' } },
plugins: ['react-refresh'],
rules: {
'react-refresh/only-export-components': [
'warn',
{ allowConstantExport: true },
],
},
}

View File

@ -0,0 +1,24 @@
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
pnpm-debug.log*
lerna-debug.log*
node_modules
dist
dist-ssr
*.local
# Editor directories and files
.vscode/*
!.vscode/extensions.json
.idea
.DS_Store
*.suo
*.ntvs*
*.njsproj
*.sln
*.sw?

View File

@ -0,0 +1,13 @@
<!doctype html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Zero-shot classication w/ Transformers.js</title>
</head>
<body>
<div id="root"></div>
<script type="module" src="/src/main.jsx"></script>
</body>
</html>

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,30 @@
{
"name": "zero-shot-classification",
"private": true,
"version": "0.0.0",
"type": "module",
"scripts": {
"dev": "vite",
"build": "vite build",
"lint": "eslint . --ext js,jsx --report-unused-disable-directives --max-warnings 0",
"preview": "vite preview"
},
"dependencies": {
"@xenova/transformers": "^2.14.0",
"react": "^18.2.0",
"react-dom": "^18.2.0"
},
"devDependencies": {
"@types/react": "^18.2.43",
"@types/react-dom": "^18.2.17",
"@vitejs/plugin-react": "^4.2.1",
"autoprefixer": "^10.4.16",
"eslint": "^8.55.0",
"eslint-plugin-react": "^7.33.2",
"eslint-plugin-react-hooks": "^4.6.0",
"eslint-plugin-react-refresh": "^0.4.5",
"postcss": "^8.4.33",
"tailwindcss": "^3.4.1",
"vite": "^5.0.8"
}
}

View File

@ -0,0 +1,6 @@
export default {
plugins: {
tailwindcss: {},
autoprefixer: {},
},
}

View File

@ -0,0 +1,5 @@
#root {
height: 100vh;
width: 100vw;
padding: 1rem;
}

View File

@ -0,0 +1,189 @@
import { useState, useRef, useEffect, useCallback } from 'react'
import './App.css'
const PLACEHOLDER_REVIEWS = [
// battery/charging problems
"Disappointed with the battery life! The phone barely lasts half a day with regular use. Considering how much I paid for it, I expected better performance in this department.",
"I bought this phone a week ago, and I'm already frustrated with the battery life. It barely lasts half a day with normal usage. I expected more from a supposedly high-end device",
"The charging port is so finicky. Sometimes it takes forever to charge, and other times it doesn't even recognize the charger. Frustrating experience!",
// overheating
"This phone heats up way too quickly, especially when using demanding apps. It's uncomfortable to hold, and I'm concerned it might damage the internal components over time. Not what I expected",
"This phone is like holding a hot potato. Video calls turn it into a scalding nightmare. Seriously, can't it keep its cool?",
"Forget about a heatwave outside; my phone's got its own. It's like a little portable heater. Not what I signed up for.",
// poor build quality
"I dropped the phone from a short distance, and the screen cracked easily. Not as durable as I expected from a flagship device.",
"Took a slight bump in my bag, and the frame got dinged. Are we back in the flip phone era?",
"So, my phone's been in my pocket with just keys no ninja moves or anything. Still, it managed to get some scratches. Disappointed with the build quality.",
// software
"The software updates are a nightmare. Each update seems to introduce new bugs, and it takes forever for them to be fixed.",
"Constant crashes and freezes make me want to throw it into a black hole.",
"Every time I open Instagram, my phone freezes and crashes. It's so frustrating!",
// other
"I'm not sure what to make of this phone. It's not bad, but it's not great either. I'm on the fence about it.",
"I hate the color of this phone. It's so ugly!",
"This phone sucks! I'm returning it."
].sort(() => Math.random() - 0.5)
const PLACEHOLDER_SECTIONS = [
'Battery and charging problems',
'Overheating',
'Poor build quality',
'Software issues',
'Other',
];
function App() {
const [text, setText] = useState(PLACEHOLDER_REVIEWS.join('\n'));
const [sections, setSections] = useState(
PLACEHOLDER_SECTIONS.map(title => ({ title, items: [] }))
);
const [status, setStatus] = useState('idle');
// Create a reference to the worker object.
const worker = useRef(null);
// We use the `useEffect` hook to setup the worker as soon as the `App` component is mounted.
useEffect(() => {
if (!worker.current) {
// Create the worker if it does not yet exist.
worker.current = new Worker(new URL('./worker.js', import.meta.url), {
type: 'module'
});
}
// Create a callback function for messages from the worker thread.
const onMessageReceived = (e) => {
const status = e.data.status;
if (status === 'initiate') {
setStatus('loading');
} else if (status === 'ready') {
setStatus('ready');
} else if (status === 'output') {
const { sequence, labels, scores } = e.data.output;
// Threshold for classification
const label = scores[0] > 0.5 ? labels[0] : 'Other';
const sectionID = sections.map(x => x.title).indexOf(label) ?? sections.length - 1;
setSections((sections) => {
const newSections = [...sections]
newSections[sectionID] = {
...newSections[sectionID],
items: [...newSections[sectionID].items, sequence]
}
return newSections
})
} else if (status === 'complete') {
setStatus('idle');
}
};
// Attach the callback function as an event listener.
worker.current.addEventListener('message', onMessageReceived);
// Define a cleanup function for when the component is unmounted.
return () => worker.current.removeEventListener('message', onMessageReceived);
}, [sections]);
const classify = useCallback(() => {
setStatus('processing');
worker.current.postMessage({
text,
labels: sections.slice(0, sections.length - 1).map(section => section.title)
});
}, [text, sections])
const busy = status !== 'idle';
return (
<div className='flex flex-col h-full'>
<textarea
className='border w-full p-1 h-1/2'
value={text}
onChange={e => setText(e.target.value)}
></textarea>
<div className='flex flex-col justify-center items-center m-2 gap-1'>
<button
className='border py-1 px-2 bg-blue-400 rounded text-white text-lg font-medium disabled:opacity-50 disabled:cursor-not-allowed'
disabled={busy}
onClick={classify}>{
!busy
? 'Categorize'
: status === 'loading'
? 'Model loading...'
: 'Processing'
}</button>
<div className='flex gap-1'>
<button
className='border py-1 px-2 bg-green-400 rounded text-white text-sm font-medium'
onClick={e => {
setSections((sections) => {
const newSections = [...sections];
// add at position 2 from the end
newSections.splice(newSections.length - 1, 0, {
title: 'New Category',
items: [],
})
return newSections;
})
}}>Add category</button>
<button
className='border py-1 px-2 bg-red-400 rounded text-white text-sm font-medium'
disabled={sections.length <= 1}
onClick={e => {
setSections((sections) => {
const newSections = [...sections];
newSections.splice(newSections.length - 2, 1); // Remove second last element
return newSections;
})
}}>Remove category</button>
<button
className='border py-1 px-2 bg-orange-400 rounded text-white text-sm font-medium'
onClick={e => {
setSections((sections) => (sections.map(section => ({
...section,
items: [],
}))))
}}>Clear</button>
</div>
</div>
<div className='flex justify-between flex-grow overflow-x-auto max-h-[40%]'>
{sections.map((section, index) => (
<div
key={index}
className="flex flex-col w-full"
>
<input
disabled={section.title === 'Other'}
className="w-full border px-1 text-center"
value={section.title} onChange={e => {
setSections((sections) => {
const newSections = [...sections];
newSections[index].title = e.target.value;
return newSections;
})
}}></input>
<div className="overflow-y-auto h-full border">
{section.items.map((item, index) => (
<div
className="m-2 border bg-red-50 rounded p-1 text-sm"
key={index}>{item}
</div>
))}
</div>
</div>
))}
</div>
</div>
)
}
export default App

View File

@ -0,0 +1,3 @@
@tailwind base;
@tailwind components;
@tailwind utilities;

View File

@ -0,0 +1,10 @@
import React from 'react'
import ReactDOM from 'react-dom/client'
import App from './App.jsx'
import './index.css'
ReactDOM.createRoot(document.getElementById('root')).render(
<React.StrictMode>
<App />
</React.StrictMode>,
)

View File

@ -0,0 +1,46 @@
import { env, pipeline } from '@xenova/transformers';
// Skip local model check since we are downloading the model from the Hugging Face Hub.
env.allowLocalModels = false;
class MyZeroShotClassificationPipeline {
static task = 'zero-shot-classification';
static model = 'MoritzLaurer/deberta-v3-xsmall-zeroshot-v1.1-all-33';
static instance = null;
static async getInstance(progress_callback = null) {
if (this.instance === null) {
this.instance = pipeline(this.task, this.model, {
quantized: true,
progress_callback,
});
}
return this.instance;
}
}
// Listen for messages from the main thread
self.addEventListener('message', async (event) => {
// Retrieve the pipeline. When called for the first time,
// this will load the pipeline and save it for future use.
const classifier = await MyZeroShotClassificationPipeline.getInstance(x => {
// We also add a progress callback to the pipeline so that we can
// track model loading.
self.postMessage(x);
});
const { text, labels } = event.data;
const split = text.split('\n');
for (const line of split) {
const output = await classifier(line, labels, {
hypothesis_template: 'This text is about {}.',
multi_label: true,
});
// Send the output back to the main thread
self.postMessage({ status: 'output', output });
}
// Send the output back to the main thread
self.postMessage({ status: 'complete' });
});

View File

@ -0,0 +1,12 @@
/** @type {import('tailwindcss').Config} */
export default {
content: [
"./index.html",
"./src/**/*.{js,ts,jsx,tsx}",
],
theme: {
extend: {},
},
plugins: [],
}

View File

@ -0,0 +1,7 @@
import { defineConfig } from 'vite'
import react from '@vitejs/plugin-react'
// https://vitejs.dev/config/
export default defineConfig({
plugins: [react()],
})