From 2015c685c760fe994a3d212abdfccbf2aeaaa97a Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Fri, 28 Jul 2023 13:24:32 +0200 Subject: [PATCH] Add Starcoder model support + demo (#225) * Add support for `gpt_bigcode` models * Create basic code-completion sample application * Update sidebar * Remove debug statement * Disable 1B model (for now) * Display progress bars * Reuse config if not specified * Update supported_models.py * Update comment * Add temperature/sample/topk generation params * Update sidebar * Add `gpt_bigcode` to supported models list * Add code playground example * Update title * Cleanup * Ignore `bigcode/starcoderbase-1b` from tests * Update transformers.js version for demo --- README.md | 2 + docs/snippets/3_examples.snippet | 1 + docs/snippets/6_supported-models.snippet | 1 + examples/code-completion/.eslintrc.cjs | 16 ++ examples/code-completion/.gitignore | 24 ++ examples/code-completion/index.html | 12 + examples/code-completion/package.json | 34 +++ examples/code-completion/postcss.config.js | 6 + examples/code-completion/src/App.css | 34 +++ examples/code-completion/src/App.jsx | 251 ++++++++++++++++++ examples/code-completion/src/assets/react.svg | 1 + .../src/components/Progress.jsx | 20 ++ examples/code-completion/src/index.css | 72 +++++ examples/code-completion/src/main.jsx | 10 + examples/code-completion/src/worker.js | 73 +++++ examples/code-completion/tailwind.config.js | 12 + examples/code-completion/vite.config.js | 7 + scripts/supported_models.py | 5 + src/models.js | 104 +++++++- 19 files changed, 679 insertions(+), 6 deletions(-) create mode 100644 examples/code-completion/.eslintrc.cjs create mode 100644 examples/code-completion/.gitignore create mode 100644 examples/code-completion/index.html create mode 100644 examples/code-completion/package.json create mode 100644 examples/code-completion/postcss.config.js create mode 100644 examples/code-completion/src/App.css create mode 100644 examples/code-completion/src/App.jsx create mode 100644 examples/code-completion/src/assets/react.svg create mode 100644 examples/code-completion/src/components/Progress.jsx create mode 100644 examples/code-completion/src/index.css create mode 100644 examples/code-completion/src/main.jsx create mode 100644 examples/code-completion/src/worker.js create mode 100644 examples/code-completion/tailwind.config.js create mode 100644 examples/code-completion/vite.config.js diff --git a/README.md b/README.md index db3e6f7..05176c4 100644 --- a/README.md +++ b/README.md @@ -111,6 +111,7 @@ Want to jump straight in? Get started with one of our sample applications/templa |-------------------|----------------------------------|-------------------------------| | Whisper Web | Speech recognition w/ Whisper | [link](https://github.com/xenova/whisper-web) | | Doodle Dash | Real-time sketch-recognition game (see [blog](https://huggingface.co/blog/ml-web-games)) | [link](https://github.com/xenova/doodle-dash) | +| Code Playground | In-browser code completion website | [link](./examples/code-completion/) | | React | Multilingual translation website | [link](./examples/react-translator/) | | Browser extension | Text classification extension | [link](./examples/extension/) | | Electron | Text classification application | [link](./examples/electron/) | @@ -261,6 +262,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te 1. **[FLAN-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5)** (from Google AI) released in the repository [google-research/t5x](https://github.com/google-research/t5x/blob/main/docs/models.md#flan-t5-checkpoints) by Hyung Won Chung, Le Hou, Shayne Longpre, Barret Zoph, Yi Tay, William Fedus, Eric Li, Xuezhi Wang, Mostafa Dehghani, Siddhartha Brahma, Albert Webson, Shixiang Shane Gu, Zhuyun Dai, Mirac Suzgun, Xinyun Chen, Aakanksha Chowdhery, Sharan Narang, Gaurav Mishra, Adams Yu, Vincent Zhao, Yanping Huang, Andrew Dai, Hongkun Yu, Slav Petrov, Ed H. Chi, Jeff Dean, Jacob Devlin, Adam Roberts, Denny Zhou, Quoc V. Le, and Jason Wei 1. **[GPT Neo](https://huggingface.co/docs/transformers/model_doc/gpt_neo)** (from EleutherAI) released in the repository [EleutherAI/gpt-neo](https://github.com/EleutherAI/gpt-neo) by Sid Black, Stella Biderman, Leo Gao, Phil Wang and Connor Leahy. 1. **[GPT-2](https://huggingface.co/docs/transformers/model_doc/gpt2)** (from OpenAI) released with the paper [Language Models are Unsupervised Multitask Learners](https://blog.openai.com/better-language-models/) by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**. +1. **[GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode)** (from BigCode) released with the paper [SantaCoder: don't reach for the stars!](https://arxiv.org/abs/2301.03988) by Loubna Ben Allal, Raymond Li, Denis Kocetkov, Chenghao Mou, Christopher Akiki, Carlos Munoz Ferrandis, Niklas Muennighoff, Mayank Mishra, Alex Gu, Manan Dey, Logesh Kumar Umapathi, Carolyn Jane Anderson, Yangtian Zi, Joel Lamy Poirier, Hailey Schoelkopf, Sergey Troshin, Dmitry Abulkhanov, Manuel Romero, Michael Lappert, Francesco De Toni, Bernardo García del Río, Qian Liu, Shamik Bose, Urvashi Bhattacharyya, Terry Yue Zhuo, Ian Yu, Paulo Villegas, Marco Zocca, Sourab Mangrulkar, David Lansky, Huu Nguyen, Danish Contractor, Luis Villa, Jia Li, Dzmitry Bahdanau, Yacine Jernite, Sean Hughes, Daniel Fried, Arjun Guha, Harm de Vries, Leandro von Werra. 1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (from Facebook) released with the paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin. 1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team. 1. **[MobileBERT](https://huggingface.co/docs/transformers/model_doc/mobilebert)** (from CMU/Google Brain) released with the paper [MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited Devices](https://arxiv.org/abs/2004.02984) by Zhiqing Sun, Hongkun Yu, Xiaodan Song, Renjie Liu, Yiming Yang, and Denny Zhou. diff --git a/docs/snippets/3_examples.snippet b/docs/snippets/3_examples.snippet index aec5d65..776e9df 100644 --- a/docs/snippets/3_examples.snippet +++ b/docs/snippets/3_examples.snippet @@ -4,6 +4,7 @@ Want to jump straight in? Get started with one of our sample applications/templa |-------------------|----------------------------------|-------------------------------| | Whisper Web | Speech recognition w/ Whisper | [link](https://github.com/xenova/whisper-web) | | Doodle Dash | Real-time sketch-recognition game (see [blog](https://huggingface.co/blog/ml-web-games)) | [link](https://github.com/xenova/doodle-dash) | +| Code Playground | In-browser code completion website | [link](./examples/code-completion/) | | React | Multilingual translation website | [link](./examples/react-translator/) | | Browser extension | Text classification extension | [link](./examples/extension/) | | Electron | Text classification application | [link](./examples/electron/) | diff --git a/docs/snippets/6_supported-models.snippet b/docs/snippets/6_supported-models.snippet index 0b91355..ab2102d 100644 --- a/docs/snippets/6_supported-models.snippet +++ b/docs/snippets/6_supported-models.snippet @@ -11,6 +11,7 @@ 1. **[FLAN-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5)** (from Google AI) released in the repository [google-research/t5x](https://github.com/google-research/t5x/blob/main/docs/models.md#flan-t5-checkpoints) by Hyung Won Chung, Le Hou, Shayne Longpre, Barret Zoph, Yi Tay, William Fedus, Eric Li, Xuezhi Wang, Mostafa Dehghani, Siddhartha Brahma, Albert Webson, Shixiang Shane Gu, Zhuyun Dai, Mirac Suzgun, Xinyun Chen, Aakanksha Chowdhery, Sharan Narang, Gaurav Mishra, Adams Yu, Vincent Zhao, Yanping Huang, Andrew Dai, Hongkun Yu, Slav Petrov, Ed H. Chi, Jeff Dean, Jacob Devlin, Adam Roberts, Denny Zhou, Quoc V. Le, and Jason Wei 1. **[GPT Neo](https://huggingface.co/docs/transformers/model_doc/gpt_neo)** (from EleutherAI) released in the repository [EleutherAI/gpt-neo](https://github.com/EleutherAI/gpt-neo) by Sid Black, Stella Biderman, Leo Gao, Phil Wang and Connor Leahy. 1. **[GPT-2](https://huggingface.co/docs/transformers/model_doc/gpt2)** (from OpenAI) released with the paper [Language Models are Unsupervised Multitask Learners](https://blog.openai.com/better-language-models/) by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**. +1. **[GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode)** (from BigCode) released with the paper [SantaCoder: don't reach for the stars!](https://arxiv.org/abs/2301.03988) by Loubna Ben Allal, Raymond Li, Denis Kocetkov, Chenghao Mou, Christopher Akiki, Carlos Munoz Ferrandis, Niklas Muennighoff, Mayank Mishra, Alex Gu, Manan Dey, Logesh Kumar Umapathi, Carolyn Jane Anderson, Yangtian Zi, Joel Lamy Poirier, Hailey Schoelkopf, Sergey Troshin, Dmitry Abulkhanov, Manuel Romero, Michael Lappert, Francesco De Toni, Bernardo García del Río, Qian Liu, Shamik Bose, Urvashi Bhattacharyya, Terry Yue Zhuo, Ian Yu, Paulo Villegas, Marco Zocca, Sourab Mangrulkar, David Lansky, Huu Nguyen, Danish Contractor, Luis Villa, Jia Li, Dzmitry Bahdanau, Yacine Jernite, Sean Hughes, Daniel Fried, Arjun Guha, Harm de Vries, Leandro von Werra. 1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (from Facebook) released with the paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin. 1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team. 1. **[MobileBERT](https://huggingface.co/docs/transformers/model_doc/mobilebert)** (from CMU/Google Brain) released with the paper [MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited Devices](https://arxiv.org/abs/2004.02984) by Zhiqing Sun, Hongkun Yu, Xiaodan Song, Renjie Liu, Yiming Yang, and Denny Zhou. diff --git a/examples/code-completion/.eslintrc.cjs b/examples/code-completion/.eslintrc.cjs new file mode 100644 index 0000000..c12e55c --- /dev/null +++ b/examples/code-completion/.eslintrc.cjs @@ -0,0 +1,16 @@ +module.exports = { + env: { browser: true, es2020: true, 'node': true }, + extends: [ + 'eslint:recommended', + 'plugin:react/recommended', + 'plugin:react/jsx-runtime', + 'plugin:react-hooks/recommended', + ], + parserOptions: { ecmaVersion: 'latest', sourceType: 'module' }, + settings: { react: { version: '18.2' } }, + plugins: ['react-refresh'], + rules: { + 'react-refresh/only-export-components': 'warn', + 'react/prop-types': 'off', + }, +} diff --git a/examples/code-completion/.gitignore b/examples/code-completion/.gitignore new file mode 100644 index 0000000..a547bf3 --- /dev/null +++ b/examples/code-completion/.gitignore @@ -0,0 +1,24 @@ +# Logs +logs +*.log +npm-debug.log* +yarn-debug.log* +yarn-error.log* +pnpm-debug.log* +lerna-debug.log* + +node_modules +dist +dist-ssr +*.local + +# Editor directories and files +.vscode/* +!.vscode/extensions.json +.idea +.DS_Store +*.suo +*.ntvs* +*.njsproj +*.sln +*.sw? diff --git a/examples/code-completion/index.html b/examples/code-completion/index.html new file mode 100644 index 0000000..b23c61f --- /dev/null +++ b/examples/code-completion/index.html @@ -0,0 +1,12 @@ + + + + + + Transformers.js - Code completion playground + + +
+ + + diff --git a/examples/code-completion/package.json b/examples/code-completion/package.json new file mode 100644 index 0000000..9a49e03 --- /dev/null +++ b/examples/code-completion/package.json @@ -0,0 +1,34 @@ +{ + "name": "code-completion", + "private": true, + "version": "0.0.0", + "type": "module", + "scripts": { + "dev": "vite", + "build": "vite build", + "lint": "eslint src --ext js,jsx --report-unused-disable-directives --max-warnings 0", + "preview": "vite preview" + }, + "dependencies": { + "@xenova/transformers": "^2.4.4", + "@monaco-editor/react": "^4.5.1", + "react": "^18.2.0", + "react-dom": "^18.2.0" + }, + "devDependencies": { + "@types/react": "^18.2.15", + "@types/react-dom": "^18.2.7", + "@vitejs/plugin-react": "^4.0.3", + "autoprefixer": "^10.4.14", + "eslint": "^8.45.0", + "eslint-plugin-react": "^7.32.2", + "eslint-plugin-react-hooks": "^4.6.0", + "eslint-plugin-react-refresh": "^0.4.3", + "postcss": "^8.4.27", + "tailwindcss": "^3.3.3", + "vite": "^4.4.5" + }, + "overrides": { + "protobufjs": "^7.2.4" + } +} diff --git a/examples/code-completion/postcss.config.js b/examples/code-completion/postcss.config.js new file mode 100644 index 0000000..2e7af2b --- /dev/null +++ b/examples/code-completion/postcss.config.js @@ -0,0 +1,6 @@ +export default { + plugins: { + tailwindcss: {}, + autoprefixer: {}, + }, +} diff --git a/examples/code-completion/src/App.css b/examples/code-completion/src/App.css new file mode 100644 index 0000000..749e884 --- /dev/null +++ b/examples/code-completion/src/App.css @@ -0,0 +1,34 @@ +.sidebar { + background-color: #181818; + color: #CCCCCC; +} + +body{ + background-color: #1F1F1F; + color: white; +} + +.progress-container { + position: relative; + font-size: 16px; + color: white; + /* background-color: #e9ecef; */ + border-radius: 8px; + text-align: left; + overflow: hidden; +} + +.progress-bar { + padding: 2px 4px; + z-index: 0; + top: 0; + width: 1%; + height: 100%; + overflow: hidden; + background-color: #007bff; + white-space: nowrap; +} + +.progress-text { + z-index: 2; +} diff --git a/examples/code-completion/src/App.jsx b/examples/code-completion/src/App.jsx new file mode 100644 index 0000000..a532f72 --- /dev/null +++ b/examples/code-completion/src/App.jsx @@ -0,0 +1,251 @@ +import { useState, useRef, useEffect } from "react"; + +import Editor from "@monaco-editor/react"; +import Progress from './components/Progress'; + +import './App.css' + + +const MODELS = [ + 'Xenova/tiny_starcoder_py', + 'Xenova/codegen-350M-mono', + // 'Xenova/starcoderbase-1b', +] +function App() { + // Editor setup + const monaco = useRef(null); + const [monacoReady, setMonacoReady] = useState(false); + const [language, setLanguage] = useState('python'); // Only allow python for now + + // Model loading + const [ready, setReady] = useState(null); + const [disabled, setDisabled] = useState(false); + const [progressItems, setProgressItems] = useState([]); + + // Inputs and outputs + const [model, setModel] = useState(MODELS[0]); + const [maxNewTokens, setMaxNewTokens] = useState(45); + const [code, setCode] = useState('\ndef fib(n):\n """Calculates the nth Fibonacci number"""\n'); + + // Generation parameters + const [temperature, setTemperature] = useState(0.5); + const [topK, setTopK] = useState(5); + const [doSample, setDoSample] = useState(false); + + // Create a reference to the worker object. + const worker = useRef(null); + + // We use the `useEffect` hook to setup the worker as soon as the `App` component is mounted. + useEffect(() => { + if (!worker.current) { + // Create the worker if it does not yet exist. + worker.current = new Worker(new URL('./worker.js', import.meta.url), { + type: 'module' + }); + } + + // Create a callback function for messages from the worker thread. + const onMessageReceived = (e) => { + switch (e.data.status) { + case 'initiate': + // Model file start load: add a new progress item to the list. + setReady(false); + setProgressItems(prev => [...prev, e.data]); + break; + + case 'progress': + // Model file progress: update one of the progress items. + setProgressItems( + prev => prev.map(item => { + if (item.file === e.data.file) { + return { ...item, ...e.data } + } + return item; + }) + ); + break; + + case 'done': + // Model file loaded: remove the progress item from the list. + setProgressItems( + prev => prev.filter(item => item.file !== e.data.file) + ); + break; + + case 'ready': + // Pipeline ready: the worker is ready to accept messages. + setReady(true); + break; + + case 'update': + // Generation update: update the output text. + setCode(e.data.output); + break; + + case 'complete': + // Generation complete: re-enable the "Generate" button + setDisabled(false); + break; + } + }; + + // Attach the callback function as an event listener. + worker.current.addEventListener('message', onMessageReceived); + + // Define a cleanup function for when the component is unmounted. + return () => worker.current.removeEventListener('message', onMessageReceived); + }); + + + useEffect(() => { + const m = monaco.current; + if (!m) return; + + let actionRegistration = m.addAction({ + id: "generate", + label: "Generate", + contextMenuGroupId: "0_custom", + run: () => { + const val = m.getValue(); + if (!val) return; + + worker.current.postMessage({ + model, + text: val, + max_new_tokens: maxNewTokens, + temperature, + top_k: topK, + do_sample: doSample + }); + } + }); + + // Define a cleanup function for when the component is unmounted. + return () => actionRegistration.dispose(); + }, [monacoReady, model, maxNewTokens, temperature, topK, doSample]); + + const showLoading = ready === false || progressItems.length > 0 + + return ( +
+
+ + {ready === false && ( + + )} + {progressItems.map(data => ( +
+ +
+ ))} +
+
+ { + monaco.current = m; + setMonacoReady(true); + }} + options={{ + fontSize: 24 + }} + /> +
+
+

In-browser code completion

+
+ Made with 🤗 Transformers.js +
+ + + + +
+ + +
+ { + const newValue = parseInt(event.target.value); + setMaxNewTokens(newValue); + }} + /> + +
+ + +
+ { + const newValue = Number(event.target.value); + setTemperature(newValue); + }} + /> +
+ { + setDoSample(event.target.checked); + }} + className="w-4 h-4 text-blue-600 rounded focus:ring-blue-600 ring-offset-gray-800 focus:ring-2 bg-gray-700 border-gray-600" + /> + +
+ +
+ + +
+ { + const newValue = parseInt(event.target.value); + setTopK(newValue); + }} + /> + +
+ +
+ + + + Source code +
+
+
+ ); +} + +export default App; diff --git a/examples/code-completion/src/assets/react.svg b/examples/code-completion/src/assets/react.svg new file mode 100644 index 0000000..6c87de9 --- /dev/null +++ b/examples/code-completion/src/assets/react.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/examples/code-completion/src/components/Progress.jsx b/examples/code-completion/src/components/Progress.jsx new file mode 100644 index 0000000..28491df --- /dev/null +++ b/examples/code-completion/src/components/Progress.jsx @@ -0,0 +1,20 @@ +function formatBytes(bytes, decimals = 0) { + const sizes = ["Bytes", "KB", "MB", "GB", "TB"]; + if (bytes === 0) return "0 Bytes"; + const i = parseInt(Math.floor(Math.log(bytes) / Math.log(1000)), 10); + const rounded = (bytes / Math.pow(1000, i)).toFixed(decimals); + return rounded + " " + sizes[i]; +} + +export default function Progress({ data }) { + const progress = data.progress ?? 0; + const text = data.file; + + const a = formatBytes(data.loaded); + const b = formatBytes(data.total); + return ( +
+
{text} ({`${a} / ${b}`})
+
+ ); +} \ No newline at end of file diff --git a/examples/code-completion/src/index.css b/examples/code-completion/src/index.css new file mode 100644 index 0000000..b90bf28 --- /dev/null +++ b/examples/code-completion/src/index.css @@ -0,0 +1,72 @@ +@tailwind base; +@tailwind components; +@tailwind utilities; + +:root { + font-family: Inter, system-ui, Avenir, Helvetica, Arial, sans-serif; + line-height: 1.5; + font-weight: 400; + + color-scheme: light dark; + color: rgba(255, 255, 255, 0.87); + background-color: #242424; + + font-synthesis: none; + text-rendering: optimizeLegibility; + -webkit-font-smoothing: antialiased; + -moz-osx-font-smoothing: grayscale; + -webkit-text-size-adjust: 100%; +} + + +a { + font-weight: 500; + color: #646cff; + text-decoration: inherit; +} +a:hover { + color: #535bf2; +} + +body { + margin: 0; + display: flex; + place-items: center; +} + +h1 { + font-size: 3.2em; + line-height: 1.1; +} + +button { + border-radius: 8px; + border: 1px solid transparent; + padding: 0.6em 1.2em; + font-size: 1em; + font-weight: 500; + font-family: inherit; + background-color: #1a1a1a; + cursor: pointer; + transition: border-color 0.25s; +} +button:hover { + border-color: #646cff; +} +button:focus, +button:focus-visible { + outline: 4px auto -webkit-focus-ring-color; +} + +@media (prefers-color-scheme: light) { + :root { + color: #213547; + background-color: #ffffff; + } + a:hover { + color: #747bff; + } + button { + background-color: #f9f9f9; + } +} diff --git a/examples/code-completion/src/main.jsx b/examples/code-completion/src/main.jsx new file mode 100644 index 0000000..54b39dd --- /dev/null +++ b/examples/code-completion/src/main.jsx @@ -0,0 +1,10 @@ +import React from 'react' +import ReactDOM from 'react-dom/client' +import App from './App.jsx' +import './index.css' + +ReactDOM.createRoot(document.getElementById('root')).render( + + + , +) diff --git a/examples/code-completion/src/worker.js b/examples/code-completion/src/worker.js new file mode 100644 index 0000000..31ceed6 --- /dev/null +++ b/examples/code-completion/src/worker.js @@ -0,0 +1,73 @@ + +import { pipeline, env } from '@xenova/transformers'; + +env.allowLocalModels = false; + +/** + * This class uses the Singleton pattern to ensure that only one instance of the pipeline is loaded. + */ +class CodeCompletionPipeline { + static task = 'text-generation'; + static model = null; + static instance = null; + + static async getInstance(progress_callback = null) { + if (this.instance === null) { + this.instance = pipeline(this.task, this.model, { progress_callback }); + } + + return this.instance; + } +} + +// Listen for messages from the main thread +self.addEventListener('message', async (event) => { + const { + model, text, max_new_tokens, + + // Generation parameters + temperature, + top_k, + do_sample, + } = event.data; + + if (CodeCompletionPipeline.model !== model) { + // Invalidate model if different + CodeCompletionPipeline.model = model; + + if (CodeCompletionPipeline.instance !== null) { + (await CodeCompletionPipeline.getInstance()).dispose(); + CodeCompletionPipeline.instance = null; + } + } + + // Retrieve the code-completion pipeline. When called for the first time, + // this will load the pipeline and save it for future use. + let generator = await CodeCompletionPipeline.getInstance(x => { + // We also add a progress callback to the pipeline so that we can + // track model loading. + self.postMessage(x); + }); + + // Actually perform the code-completion + let output = await generator(text, { + max_new_tokens, + temperature, + top_k, + do_sample, + + // Allows for partial output + callback_function: x => { + self.postMessage({ + status: 'update', + output: generator.tokenizer.decode(x[0].output_token_ids, { skip_special_tokens: true }) + }); + } + }); + + // Send the output back to the main thread + self.postMessage({ + status: 'complete', + output: output, + }); +}); diff --git a/examples/code-completion/tailwind.config.js b/examples/code-completion/tailwind.config.js new file mode 100644 index 0000000..d37737f --- /dev/null +++ b/examples/code-completion/tailwind.config.js @@ -0,0 +1,12 @@ +/** @type {import('tailwindcss').Config} */ +export default { + content: [ + "./index.html", + "./src/**/*.{js,ts,jsx,tsx}", + ], + theme: { + extend: {}, + }, + plugins: [], +} + diff --git a/examples/code-completion/vite.config.js b/examples/code-completion/vite.config.js new file mode 100644 index 0000000..5a33944 --- /dev/null +++ b/examples/code-completion/vite.config.js @@ -0,0 +1,7 @@ +import { defineConfig } from 'vite' +import react from '@vitejs/plugin-react' + +// https://vitejs.dev/config/ +export default defineConfig({ + plugins: [react()], +}) diff --git a/scripts/supported_models.py b/scripts/supported_models.py index 673af7c..e604555 100644 --- a/scripts/supported_models.py +++ b/scripts/supported_models.py @@ -109,6 +109,11 @@ SUPPORTED_MODELS = { 'MBZUAI/LaMini-Cerebras-590M', 'MBZUAI/LaMini-GPT-124M', ], + 'gpt_bigcode': [ + # See branch: https://github.com/huggingface/optimum/tree/xenova-bigcode-testing + 'bigcode/tiny_starcoder_py', + # 'bigcode/starcoderbase-1b', # NOTE: This model is gated, so we ignore it when testing + ], 'm2m_100': [ 'facebook/nllb-200-distilled-600M', ], diff --git a/src/models.js b/src/models.js index 13739d2..4ce1acf 100644 --- a/src/models.js +++ b/src/models.js @@ -1168,12 +1168,21 @@ export class PreTrainedModel extends Callable { } } else { - // @ts-ignore - let dims = [1, this.num_heads, 0, this.dim_kv] - // @ts-ignore - for (let i = 0; i < this.num_layers; ++i) { - decoderFeeds[`past_key_values.${i}.key`] = new Tensor('float32', [], dims) - decoderFeeds[`past_key_values.${i}.value`] = new Tensor('float32', [], dims) + if (this.config.multi_query) { + // @ts-ignore + let dims = [1, 0, 2 * this.dim_kv] + // @ts-ignore + for (let i = 0; i < this.num_layers; ++i) { + decoderFeeds[`past_key_values.${i}.key_value`] = new Tensor('float32', [], dims) + } + } else { + // @ts-ignore + let dims = [1, this.num_heads, 0, this.dim_kv] + // @ts-ignore + for (let i = 0; i < this.num_layers; ++i) { + decoderFeeds[`past_key_values.${i}.key`] = new Tensor('float32', [], dims) + decoderFeeds[`past_key_values.${i}.value`] = new Tensor('float32', [], dims) + } } } } @@ -2450,6 +2459,83 @@ export class GPTNeoForCausalLM extends GPTNeoPreTrainedModel { return await decoderForward(this, model_inputs); } } +////////////////////////////////////////////////// + +////////////////////////////////////////////////// +// GPTBigCode models +export class GPTBigCodePreTrainedModel extends PreTrainedModel { + /** + * Creates a new instance of the `GPTBigCodePreTrainedModel` class. + * @param {Object} config The configuration of the model. + * @param {any} session The ONNX session containing the model weights. + */ + constructor(config, session) { + super(config, session); + + // config doesn't contain pad_token_id, so we assume it is the eos_token_id + this.config.pad_token_id = this.config.eos_token_id + + this.num_heads = this.config.n_head + this.num_layers = this.config.n_layer + this.dim_kv = this.config.n_embd / this.num_heads; + } +} + +export class GPTBigCodeModel extends GPTBigCodePreTrainedModel { + /** + * + * @param {...any} args + * @throws {Error} + * @returns {Promise} + */ + async generate(...args) { + throw Error( + "The current model class (GPTBigCodeModel) is not compatible with `.generate()`, as it doesn't have a language model head. Please use one of the following classes instead: {'GPTBigCodeForCausalLM'}" + ) + } +} + +export class GPTBigCodeForCausalLM extends GPTBigCodePreTrainedModel { + + /** + * Initializes and returns the beam for text generation task + * @param {Tensor} inputTokenIds The input token ids. + * @param {number} numOutputTokens The number of tokens to be generated. + * @param {Tensor} inputs_attention_mask Optional input attention mask. + * @returns {any} A Beam object representing the initialized beam. + */ + getStartBeams(inputTokenIds, numOutputTokens, inputs_attention_mask) { + return decoderStartBeams(this, inputTokenIds, numOutputTokens, inputs_attention_mask) + } + + /** + * Runs a single step of the beam search generation algorithm. + * @param {any} beam The current beam being generated. + * @returns {Promise} The updated beam after a single generation step. + */ + async runBeam(beam) { + return await decoderRunBeam(this, beam); + } + + /** + * Updates the given beam with the new generated token id. + * @param {any} beam The Beam object representing the beam. + * @param {number} newTokenId The new generated token id to be added to the beam. + */ + updateBeam(beam, newTokenId) { + return decoderUpdatebeam(beam, newTokenId); + } + + /** + * Forward pass for the model. + * @param {Object} model_inputs The inputs for the model. + * @returns {Promise} The output tensor of the model. + */ + async forward(model_inputs) { + return await decoderForward(this, model_inputs); + } +} +////////////////////////////////////////////////// ////////////////////////////////////////////////// // CodeGen models @@ -2854,6 +2940,10 @@ export class PretrainedMixin { revision, } config = await AutoConfig.from_pretrained(pretrained_model_name_or_path, options); + if (!options.config) { + // If no config was passed, reuse this config for future processing + options.config = config; + } if (!this.MODEL_CLASS_MAPPINGS) { throw new Error("`MODEL_CLASS_MAPPINGS` not implemented for this type of `AutoClass`: " + this.name); @@ -2904,6 +2994,7 @@ const MODEL_MAPPING_NAMES_ENCODER_DECODER = new Map([ const MODEL_MAPPING_NAMES_DECODER_ONLY = new Map([ ['gpt2', GPT2Model], + ['gpt_bigcode', GPTBigCodeModel], ['gpt_neo', GPTNeoModel], ['codegen', CodeGenModel], ]); @@ -2939,6 +3030,7 @@ const MODEL_FOR_SEQ_2_SEQ_MAPPING_NAMES = new Map([ const MODEL_WITH_LM_HEAD_MAPPING_NAMES = new Map([ ['gpt2', GPT2LMHeadModel], + ['gpt_bigcode', GPTBigCodeForCausalLM], ['gpt_neo', GPTNeoForCausalLM], ['codegen', CodeGenForCausalLM], ]);