From 5ebfc411148f7cf408f963f19bf975d7ed4e0da3 Mon Sep 17 00:00:00 2001
From: krassowski <5832902+krassowski@users.noreply.github.com>
Date: Mon, 2 Oct 2023 06:19:56 +0100
Subject: [PATCH] Initial implementation

---
 package.json  |   4 +-
 src/index.ts  | 241 +++++++++++++++++++++++++++++++++++++++++++++++++-
 src/worker.ts | 113 +++++++++++++++++++++++
 3 files changed, 355 insertions(+), 3 deletions(-)
 create mode 100644 src/worker.ts

diff --git a/package.json b/package.json
index 2e84691..1218c58 100644
--- a/package.json
+++ b/package.json
@@ -54,7 +54,9 @@
     "watch:labextension": "jupyter labextension watch ."
   },
   "dependencies": {
-    "@jupyterlab/application": "^4.0.0"
+    "@jupyterlab/application": "^4.0.0",
+    "@jupyterlab/completer": "^4.0.0",
+    "@xenova/transformers": "^2.6.2"
   },
   "devDependencies": {
     "@jupyterlab/builder": "^4.0.0",
diff --git a/src/index.ts b/src/index.ts
index 3c65157..a09dd84 100644
--- a/src/index.ts
+++ b/src/index.ts
@@ -2,6 +2,235 @@ import {
   JupyterFrontEnd,
   JupyterFrontEndPlugin
 } from '@jupyterlab/application';
+import {
+  ICompletionProviderManager,
+  IInlineCompletionProvider,
+  IInlineCompletionContext,
+  CompletionHandler,
+  IInlineCompletionList,
+  IInlineCompletionItem
+} from '@jupyterlab/completer';
+import type { ISettingRegistry } from '@jupyterlab/settingregistry';
+import { Notification } from '@jupyterlab/apputils';
+import { JSONValue, PromiseDelegate } from '@lumino/coreutils';
+
+interface ISettings {
+  temperature: number;
+  model: string;
+  maxNewTokens: number;
+  topK: number;
+  doSample: boolean;
+  generateN: number;
+}
+
+const DEFAULT_SETTINGS: ISettings = {
+  model: 'Xenova/tiny_starcoder_py',
+  temperature: 0.5,
+  doSample: false,
+  topK: 5,
+  maxNewTokens: 50,
+  generateN: 2
+};
+
+interface IOptions {
+  worker: Worker;
+}
+
+interface IStream {
+  done: boolean;
+  response: IInlineCompletionItem;
+}
+
+class TransformersInlineProvider implements IInlineCompletionProvider {
+  constructor(protected options: IOptions) {
+    options.worker.addEventListener(
+      'message',
+      this.onMessageReceived.bind(this)
+    );
+  }
+
+  readonly identifier = '@krassowski/inline-completer';
+  readonly name = 'Transformers powered completions';
+
+  get schema(): ISettingRegistry.IProperty {
+    return {
+      properties: {
+        model: {
+          title: 'Model',
+          enum: [
+            // https://huggingface.co/bigcode/tiny_starcoder_py
+            'Xenova/tiny_starcoder_py',
+            // https://huggingface.co/Salesforce/codegen-350M-mono
+            'Xenova/codegen-350M-mono'
+          ],
+          type: 'string'
+        },
+        temperature: {
+          minimum: 0,
+          maximum: 1,
+          type: 'number',
+          description: 'The value used to modulate the next token probabilities'
+        },
+        doSample: {
+          type: 'boolean',
+          description: 'Whether to use sampling; use greedy decoding otherwise'
+        },
+        topK: {
+          minimum: 0,
+          maximum: 50,
+          type: 'number',
+          description:
+            'The number of highest probability vocabulary tokens to keep for top-k filtering'
+        },
+        maxNewTokens: {
+          minimum: 1,
+          maximum: 512,
+          type: 'number'
+        },
+        generateN: {
+          minimum: 1,
+          maximum: 10,
+          type: 'number'
+        }
+      },
+      default: DEFAULT_SETTINGS as any
+    };
+  }
+
+  configure(settings: { [property: string]: JSONValue }): void {
+    this._settings = settings as any as ISettings;
+    console.log(this._settings);
+    this.options.worker.postMessage({
+      action: 'initializeModel',
+      model: this._settings.model
+    });
+  }
+
+  // TODO types
+  onMessageReceived(e: any) {
+    console.log(e);
+    const data = e.data;
+    switch (e.data.status) {
+      case 'initiate':
+        this._ready = new PromiseDelegate<void>();
+        if (data.model !== this._lastModel) {
+          this._notificationId = Notification.info(
+            'Loading model ' + data.model + ': fetching ' + data.file
+          );
+          this._lastModel = data.model;
+        }
+        break;
+      case 'progress':
+        Notification.update({
+          id: this._notificationId,
+          message:
+            'Loading model ' +
+            data.model +
+            ' ' +
+            Math.round(data.progress) +
+            '%',
+          type: 'in-progress',
+          progress: data.progress
+        });
+        break;
+
+      case 'done':
+        Notification.dismiss(this._notificationId);
+        break;
+
+      case 'ready':
+        this._ready.resolve(void 0);
+        break;
+
+      case 'update': {
+        const token = data.idToken;
+        const delegate = this._streamPromises.get(token);
+        if (!delegate) {
+          console.warn('Completion updated but stream absent');
+        } else {
+          delegate.resolve({
+            done: false,
+            response: {
+              insertText: data.output
+            }
+          });
+        }
+        break;
+      }
+      case 'complete': {
+        const token = data.idToken;
+        const delegate = this._streamPromises.get(token);
+        if (!delegate) {
+          console.warn('Completion done but stream absent');
+        } else {
+          delegate.resolve({
+            done: true,
+            response: {
+              insertText: data.output
+            }
+          });
+        }
+        break;
+      }
+    }
+  }
+
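+  /**
+   * Return `generateN` placeholder items immediately (each tagged with a
+   * unique token) and post a `generate` request to the worker; the text for
+   * each item then arrives incrementally through `stream()` below.
+   */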
+  async fetch(
+    request: CompletionHandler.IRequest,
+    context: IInlineCompletionContext
+  ): Promise<IInlineCompletionList<IInlineCompletionItem>> {
+    // TODO:
+    await this._ready.promise;
+    this._streamPromises = new Map();
+    const multiLinePrefix = request.text.slice(0, request.offset);
+    const linePrefix = multiLinePrefix.split('\n').slice(-1)[0];
+    console.log(linePrefix);
+    const items: IInlineCompletionItem[] = [];
+    const idTokens = [];
+    for (let i = 0; i < this._settings.generateN; i++) {
+      const token = 'T' + ++this._tokenCounter;
+      idTokens.push(token);
+      items.push({
+        insertText: '',
+        isIncomplete: true,
+        token: token
+      });
+    }
+    this.options.worker.postMessage({
+      model: this._settings.model,
+      text: multiLinePrefix,
+      max_new_tokens: this._settings.maxNewTokens,
+      temperature: this._settings.temperature,
+      top_k: this._settings.topK,
+      do_sample: this._settings.doSample,
+      num_return_sequences: this._settings.generateN,
+      idTokens,
+      action: 'generate'
+    });
+    return { items };
+  }
+
+  async *stream(token: string) {
+    let done = false;
+    console.log('streaming', token);
+    while (!done) {
+      const delegate = new PromiseDelegate<IStream>();
+      this._streamPromises.set(token, delegate);
+      const promise = delegate.promise;
+      yield promise;
+      done = (await promise).done;
+    }
+  }
+
+  private _notificationId: string = '';
+  private _settings: ISettings = DEFAULT_SETTINGS;
+  private _streamPromises: Map<string, PromiseDelegate<IStream>> = new Map();
+  private _ready = new PromiseDelegate<void>();
+  private _tokenCounter = 0;
+  private _lastModel = '';
+}
+
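+// The `new Worker(new URL(...))` form lets webpack 5 detect the worker
+// entry point and emit worker.js as a separate chunk.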
+const worker = new Worker(new URL('./worker.js', import.meta.url));
 /**
  * Initialization data for the @jupyterlab/transformers-completer extension.
  */
@@ -9,9 +9,17 @@
 const plugin: JupyterFrontEndPlugin<void> = {
   id: '@jupyterlab/transformers-completer:plugin',
   description: 'A JupyterLab extension.',
+  requires: [ICompletionProviderManager],
   autoStart: true,
-  activate: (app: JupyterFrontEnd) => {
-    console.log('JupyterLab extension @jupyterlab/transformers-completer is activated!');
+  activate: (
+    app: JupyterFrontEnd,
+    providerManager: ICompletionProviderManager
+  ) => {
+    const provider = new TransformersInlineProvider({ worker });
+    providerManager.registerInlineProvider(provider);
+    console.log(
+      'JupyterLab extension @jupyterlab/transformers-completer is activated!'
+    );
   }
 };
 
diff --git a/src/worker.ts b/src/worker.ts
new file mode 100644
index 0000000..33c2faf
--- /dev/null
+++ b/src/worker.ts
@@ -0,0 +1,113 @@
+// Apache-2.0 license
+// Based on code by Joshua Lochner
+import type { Pipeline } from '@xenova/transformers';
+import type * as transformersModuleNamespace from '@xenova/transformers';
+type transformersModuleType = {
+  env: typeof transformersModuleNamespace.env;
+  pipeline: typeof transformersModuleNamespace.pipeline;
+};
+
+/**
+ * This class uses the Singleton pattern to ensure that only one instance of
+ * the pipeline is loaded.
+ */
+class CodeCompletionPipeline {
+  static task = 'text-generation';
+  static model?: string;
+  static instance?: Promise<Pipeline>;
+
+  static async getInstance(
+    progress_callback?: (message: any) => void
+  ): Promise<Pipeline> {
+    // note: neither importScripts nor module import worked, see:
+    // https://github.com/webpack/webpack/issues/16633
+    // https://github.com/webpack/webpack/issues/16173
+    const transformers = (await import(
+      /* webpackIgnore: true */
+      // @ts-ignore
+      'https://cdn.jsdelivr.net/npm/@xenova/transformers@2.6.2'
+    )) as transformersModuleType;
+
+    // @ts-ignore
+    transformers.env.allowLocalModels = false;
+
+    if (!this.instance) {
+      this.instance = transformers.pipeline(this.task, this.model, {
+        progress_callback
+      });
+    }
+
+    return this.instance;
+  }
+}
+
+// Listen for messages from the main thread
+self.addEventListener('message', async event => {
+  const {
+    model,
+    text,
+    max_new_tokens,
+
+    // Generation parameters
+    temperature,
+    top_k,
+    do_sample,
+    num_return_sequences,
+    idTokens,
+    action
+  } = event.data;
+
+  if (CodeCompletionPipeline.model !== model) {
+    // Invalidate model if different
+    CodeCompletionPipeline.model = model;
+
+    const instance = CodeCompletionPipeline.instance;
+    if (instance) {
+      (await instance).dispose();
+      CodeCompletionPipeline.instance = undefined;
+    }
+  }
+
+  // Retrieve the code-completion pipeline. When called for the first time,
+  // this will load the pipeline and save it for future use.
+  const generator = await CodeCompletionPipeline.getInstance(x => {
+    // We also add a progress callback to the pipeline so that we can
+    // track model loading.
+    self.postMessage({ ...x, model });
+  });
+
+  if (action !== 'generate') {
+    return;
+  }
+
+  // Actually perform the code-completion
+  const output = await generator(text, {
+    max_new_tokens,
+    temperature,
+    top_k,
+    do_sample,
+    num_beams: num_return_sequences,
+    num_return_sequences,
+    // Allows for partial output
+    callback_function: (x: any) => {
+      for (let i = 0; i < x.length; i++) {
+        const output = generator.tokenizer.decode(x[i].output_token_ids, {
+          skip_special_tokens: true
+        });
+        self.postMessage({
+          status: 'update',
+          output: output.substring(text.length),
+          idToken: idTokens[i]
+        });
+      }
+    }
+  });
+
+  // Send the output back to the main thread
+  for (let i = 0; i < output.length; i++) {
+    self.postMessage({
+      status: 'complete',
+      output: output[i].generated_text.substring(text.length),
+      idToken: idTokens[i]
+    });
+  }
+});