Skip to content

Commit

Permalink
Initial implementation
Browse files Browse the repository at this point in the history
krassowski committed Oct 2, 2023
1 parent d6d6bd2 commit 5ebfc41
Showing 3 changed files with 355 additions and 3 deletions.
4 changes: 3 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -54,7 +54,9 @@
"watch:labextension": "jupyter labextension watch ."
},
"dependencies": {
"@jupyterlab/application": "^4.0.0"
"@jupyterlab/application": "^4.0.0",
"@jupyterlab/completer": "^4.0.0",
"@xenova/transformers": "^2.6.2"
},
"devDependencies": {
"@jupyterlab/builder": "^4.0.0",
241 changes: 239 additions & 2 deletions src/index.ts
Original file line number Diff line number Diff line change
@@ -2,16 +2,253 @@ import {
JupyterFrontEnd,
JupyterFrontEndPlugin
} from '@jupyterlab/application';
import {
ICompletionProviderManager,
IInlineCompletionProvider,
IInlineCompletionContext,
CompletionHandler,
IInlineCompletionList,
IInlineCompletionItem
} from '@jupyterlab/completer';
import type { ISettingRegistry } from '@jupyterlab/settingregistry';
import { Notification } from '@jupyterlab/apputils';
import { JSONValue, PromiseDelegate } from '@lumino/coreutils';

/**
 * User-configurable settings of the provider; mirrors the JSON schema
 * exposed by `TransformersInlineProvider.schema`.
 */
interface ISettings {
  // Sampling temperature forwarded to the worker as `temperature`.
  temperature: number;
  // Hugging Face model identifier (see `schema` for allowed values).
  model: string;
  // Upper bound on tokens generated per completion (`max_new_tokens`).
  maxNewTokens: number;
  // Number of highest-probability tokens kept for top-k filtering.
  topK: number;
  // Whether to sample; greedy decoding is used otherwise.
  doSample: boolean;
  // How many alternative completions to request per fetch.
  generateN: number;
}

/**
 * Defaults in effect until `configure()` is called with user settings.
 */
const DEFAULT_SETTINGS: ISettings = {
  model: 'Xenova/tiny_starcoder_py',
  temperature: 0.5,
  doSample: false,
  topK: 5,
  maxNewTokens: 50,
  generateN: 2
};

/**
 * Constructor options for `TransformersInlineProvider`.
 */
interface IOptions {
  // Web worker hosting the transformers.js pipeline (see worker.ts).
  worker: Worker;
}

/**
 * Shape of a single chunk produced by `stream()`.
 */
interface IStream {
  // True when this is the final chunk for the stream's token.
  done: boolean;
  // The (partial or final) completion item to show.
  response: IInlineCompletionItem;
}

/**
 * Inline completion provider backed by transformers.js models running in a
 * web worker. Completions are returned as empty placeholder items from
 * `fetch()` and filled in incrementally through `stream()` as the worker
 * posts 'update'/'complete' messages.
 */
class TransformersInlineProvider implements IInlineCompletionProvider {
  constructor(protected options: IOptions) {
    // All worker → main-thread traffic (model loading lifecycle and
    // generated text) flows through this single handler.
    options.worker.addEventListener(
      'message',
      this.onMessageReceived.bind(this)
    );
  }

  readonly identifier = '@krassowski/inline-completer';
  readonly name = 'Transformers powered completions';

  /**
   * JSON schema of the user-facing settings; `configure()` receives
   * values validated against it.
   */
  get schema(): ISettingRegistry.IProperty {
    return {
      properties: {
        model: {
          title: 'Model',
          enum: [
            // https://huggingface.co/bigcode/tiny_starcoder_py
            'Xenova/tiny_starcoder_py',
            // https://huggingface.co/Salesforce/codegen-350M-mono
            'Xenova/codegen-350M-mono'
          ],
          type: 'string'
        },
        temperature: {
          minimum: 0,
          maximum: 1,
          type: 'number',
          description:
            'The value used to modulate the next token probabilities'
        },
        doSample: {
          type: 'boolean',
          description: 'Whether to use sampling; use greedy decoding otherwise'
        },
        topK: {
          minimum: 0,
          maximum: 50,
          type: 'number',
          description:
            'The number of highest probability vocabulary tokens to keep for top-k-filtering'
        },
        maxNewTokens: {
          minimum: 1,
          maximum: 512,
          type: 'number'
        },
        generateN: {
          minimum: 1,
          maximum: 10,
          type: 'number'
        }
      },
      // ISettings has no index signature, so a two-step cast is needed to
      // satisfy the JSON-typed `default` field.
      default: DEFAULT_SETTINGS as unknown as JSONValue
    };
  }

  /**
   * Apply new settings and ask the worker to (re)initialize the model so
   * that download/compilation happens ahead of the first completion.
   */
  configure(settings: { [property: string]: JSONValue }): void {
    this._settings = settings as unknown as ISettings;
    this.options.worker.postMessage({
      action: 'initializeModel',
      model: this._settings.model
    });
  }

  /**
   * Dispatch a message posted by the worker: model-loading lifecycle
   * notifications ('initiate'/'progress'/'done'/'ready') and streamed
   * completion chunks ('update'/'complete').
   */
  onMessageReceived(e: MessageEvent): void {
    const data = e.data;
    switch (data.status) {
      case 'initiate':
        this._ready = new PromiseDelegate();
        if (data.model !== this._lastModel) {
          this._notificationId = Notification.info(
            `Loading model ${data.model}: fetching ${data.file}`
          );
          this._lastModel = data.model;
        }
        break;
      case 'progress':
        Notification.update({
          id: this._notificationId,
          message: `Loading model ${data.model} ${Math.round(data.progress)}%`,
          type: 'in-progress',
          progress: data.progress
        });
        break;
      case 'done':
        Notification.dismiss(this._notificationId);
        break;
      case 'ready':
        this._ready.resolve(void 0);
        break;
      case 'update': {
        const delegate = this._streamPromises.get(data.idToken);
        if (!delegate) {
          console.warn('Completion updated but stream absent');
        } else {
          delegate.resolve({
            done: false,
            response: { insertText: data.output }
          });
        }
        break;
      }
      case 'complete': {
        const token = data.idToken;
        const delegate = this._streamPromises.get(token);
        if (!delegate) {
          console.warn('Completion done but stream absent');
        } else {
          delegate.resolve({
            done: true,
            response: { insertText: data.output }
          });
          // The stream is finished and `stream()` will not register a new
          // delegate for this token; drop the entry so the map does not
          // grow without bound.
          this._streamPromises.delete(token);
        }
        break;
      }
    }
  }

  /**
   * Request `generateN` completions for the text before the cursor.
   * Returns placeholder items immediately; actual text arrives through
   * `stream()` keyed by each item's token.
   */
  async fetch(
    request: CompletionHandler.IRequest,
    context: IInlineCompletionContext
  ): Promise<IInlineCompletionList<IInlineCompletionItem>> {
    // Block until the worker reports the model is ready.
    await this._ready.promise;
    // NOTE(review): this discards delegates of any in-flight request, so
    // generators awaiting them never resolve; acceptable while the
    // completer debounces requests, but worth revisiting.
    this._streamPromises = new Map();
    const multiLinePrefix = request.text.slice(0, request.offset);
    const items: IInlineCompletionItem[] = [];
    const idTokens: string[] = [];
    for (let i = 0; i < this._settings.generateN; i++) {
      // Unique correlation token tying each item to its worker stream.
      const token = 'T' + ++this._tokenCounter;
      idTokens.push(token);
      items.push({
        insertText: '',
        isIncomplete: true,
        token: token
      });
    }
    this.options.worker.postMessage({
      model: this._settings.model,
      text: multiLinePrefix,
      max_new_tokens: this._settings.maxNewTokens,
      temperature: this._settings.temperature,
      top_k: this._settings.topK,
      do_sample: this._settings.doSample,
      num_return_sequences: this._settings.generateN,
      idTokens,
      action: 'generate'
    });
    return { items };
  }

  /**
   * Yield incremental updates for the completion identified by `token`
   * until the worker marks it done. Each iteration registers a fresh
   * delegate that `onMessageReceived` resolves.
   */
  async *stream(token: string): AsyncGenerator<IStream> {
    let done = false;
    while (!done) {
      const delegate = new PromiseDelegate<IStream>();
      this._streamPromises.set(token, delegate);
      const promise = delegate.promise;
      yield promise;
      done = (await promise).done;
    }
  }

  // Id of the "Loading model" notification, reused for progress updates.
  private _notificationId: string = '';
  private _settings: ISettings = DEFAULT_SETTINGS;
  // Pending stream delegates keyed by completion token.
  private _streamPromises: Map<string, PromiseDelegate<IStream>> = new Map();
  // Resolved when the worker reports the current model as 'ready'.
  private _ready = new PromiseDelegate();
  private _tokenCounter = 0;
  // Last model a loading notification was shown for (avoids duplicates).
  private _lastModel = '';
}

const worker = new Worker(new URL('./worker.js', import.meta.url));

/**
* Initialization data for the @jupyterlab/transformers-completer extension.
*/
const plugin: JupyterFrontEndPlugin<void> = {
id: '@jupyterlab/transformers-completer:plugin',
description: 'A JupyterLab extension.',
requires: [ICompletionProviderManager],
autoStart: true,
activate: (app: JupyterFrontEnd) => {
console.log('JupyterLab extension @jupyterlab/transformers-completer is activated!');
activate: (
app: JupyterFrontEnd,
providerManager: ICompletionProviderManager
) => {
const provider = new TransformersInlineProvider({ worker });
providerManager.registerInlineProvider(provider);
console.log(
'JupyterLab extension @jupyterlab/transformers-completer is activated!'
);
}
};

113 changes: 113 additions & 0 deletions src/worker.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Apache-2.0 license
// Based on code by Joshua Lochner
import type { Pipeline } from '@xenova/transformers';
import type * as transformersModuleNamespace from '@xenova/transformers';
// Narrow, type-only view of the transformers.js module: just the two
// members this worker uses from the runtime CDN import below.
type transformersModuleType = {
  env: typeof transformersModuleNamespace.env;
  pipeline: typeof transformersModuleNamespace.pipeline;
};

/**
* This class uses the Singleton pattern to ensure that only one instance of the pipeline is loaded.
*/
/**
 * This class uses the Singleton pattern to ensure that only one instance
 * of the pipeline is loaded per worker.
 */
class CodeCompletionPipeline {
  static task = 'text-generation';
  // Current model id; the message handler resets `instance` when it changes.
  static model?: string;
  // Cached pipeline promise; `undefined` until first use or after disposal.
  static instance?: Promise<Pipeline>;

  /**
   * Return the cached pipeline, creating (and caching) it on first use.
   *
   * The promise is assigned synchronously before any `await`, so two
   * interleaved calls cannot both create a pipeline; warm calls also skip
   * the dynamic import entirely (the original ran it on every call).
   *
   * @param progress_callback - forwarded to transformers.js; receives
   *   model download / initialization progress events.
   */
  static async getInstance(
    progress_callback?: (message: any) => void
  ): Promise<Pipeline> {
    if (!this.instance) {
      this.instance = (async () => {
        // note: neither importScripts nor module import worked, see:
        // https://github.com/webpack/webpack/issues/16633
        // https://github.com/webpack/webpack/issues/16173
        const transformers = (await import(
          /* webpackIgnore: true */
          // @ts-ignore
          'https://cdn.jsdelivr.net/npm/@xenova/transformers@2.6.2'
        )) as transformersModuleType;

        // Always fetch models from the Hugging Face hub, never local paths.
        transformers.env.allowLocalModels = false;

        return transformers.pipeline(this.task, this.model, {
          progress_callback
        });
      })();
    }

    return this.instance;
  }
}

// Listen for messages from the main thread. Two actions arrive here:
// 'initializeModel' (pre-load the model only) and 'generate' (produce
// completions, streaming partial output as it is decoded).
self.addEventListener('message', async event => {
  const {
    model,
    text,
    max_new_tokens,

    // Generation parameters
    temperature,
    top_k,
    do_sample,
    num_return_sequences,
    // Correlation tokens, one per requested sequence; echoed back so the
    // main thread can route each update to the right stream.
    idTokens,
    action
  } = event.data;

  if (CodeCompletionPipeline.model !== model) {
    // Invalidate model if different
    CodeCompletionPipeline.model = model;

    const instance = CodeCompletionPipeline.instance;
    if (instance) {
      (await instance).dispose();
      CodeCompletionPipeline.instance = undefined;
    }
  }

  // Retrieve the code-completion pipeline. When called for the first time,
  // this will load the pipeline and save it for future use.
  const generator = await CodeCompletionPipeline.getInstance(x => {
    // We also add a progress callback to the pipeline so that we can
    // track model loading. `x.status` drives the notification lifecycle
    // on the main thread; `model` is attached for de-duplication there.
    self.postMessage({ ...x, model });
  });

  // 'initializeModel' stops here: the pipeline above is now warmed up.
  if (action !== 'generate') {
    return;
  }

  // Actually perform the code-completion
  const output = await generator(text, {
    max_new_tokens,
    temperature,
    top_k,
    do_sample,
    num_beams: num_return_sequences,
    num_return_sequences,
    // Allows for partial output
    callback_function: (x: any) => {
      for (let i = 0; i < x.length; i++) {
        // Decode the full sequence so far, then strip the prompt.
        // NOTE(review): assumes the decoded text starts with `text`
        // verbatim (tokenization round-trips the prompt) — confirm for
        // each supported model.
        const output = generator.tokenizer.decode(x[i].output_token_ids, {
          skip_special_tokens: true
        });
        self.postMessage({
          status: 'update',
          output: output.substring(text.length),
          idToken: idTokens[i]
        });
      }
    }
  });

  // Send the output back to the main thread, one final 'complete'
  // message per generated sequence (same prompt-stripping as above).
  for (let i = 0; i < output.length; i++) {
    self.postMessage({
      status: 'complete',
      output: output[i].generated_text.substring(text.length),
      idToken: idTokens[i]
    });
  }
});

0 comments on commit 5ebfc41

Please sign in to comment.