Skip to content

Commit

Permalink
Run TensorFlow in worker_threads
Browse files Browse the repository at this point in the history
Moved all TensorFlow calculations to worker_threads. This was primarily
for the JS only variant. It hogged nearly the entire event loop causing
the web interface to appear very glitchy.
  • Loading branch information
SimplyBoo committed Mar 22, 2021
1 parent 78e9390 commit 2e5a8ef
Show file tree
Hide file tree
Showing 7 changed files with 170 additions and 36 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ This will start an instance of the program listening on (http://localhost:3523)[
- The keyword search supports quotes ("magic phrase") for sentences and negation (-) on words and sentences.
- When doing the keyword search there's a limit of 1200 results.
- All code is TypeScript and the UI framework is Angular.
- TensorFlow will run on the CPU. If you have AVX support on your CPU it uses native code and depending on the model and your CPU classifies in the range of 45ms. If your CPU does not support AVX it uses the JS only library and each classification will take in the range of 5s.
- TensorFlow currently requires internet connectivity. It fetches the list of ImageNet classes and the selected pre-trained model when it starts. Media is all classified locally.
- TensorFlow will run on the CPU. If your CPU supports AVX it uses native code and, depending on the model and your CPU, classifies an image in around 45ms. If your CPU does not support AVX it uses the JS-only library and each classification takes around 5s. However, it does use worker threads to make use of multi-core CPUs.
- TensorFlow currently requires internet connectivity the first time a model is used. It fetches the list of ImageNet classes and the selected pre-trained model when it starts. Media is all classified locally. Further classifications with that model will use the cached download.

### Github Sponsors

Expand Down
64 changes: 34 additions & 30 deletions server/src/tasks/tensorflow/classify.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import { ClassifierWorkerWrapper } from './workers/classifier-worker-wrapper';
import { Database, RouterTask, TaskRunnerCallback } from '../../types';
import { IMAGENET_MODELS, loadClasses, loadModel } from './imagenet';
import { ScalingConnectionPool, execute } from 'proper-job';
import { Transcoder } from '../../cache/transcoder';
import { execute } from 'proper-job';
import { loadImageFileCommon } from './common';
import OS from 'os';
import TensorFlow from './tensorflow';

// Two if native because it doesn't quite achieve 100% multi-core use,
// likely due to not feeding it data enough.
// One per core if JS only since it's single threaded.
const parallel = TensorFlow.isNative() ? 2 : OS.cpus().length;

export function getTask(database: Database): RouterTask[] {
const transcoder = new Transcoder(database);

Expand All @@ -25,16 +31,29 @@ export function getTask(database: Database): RouterTask[] {
return [];
}

const model = await loadModel(modelDefinition);

updateStatus(0, hashes.length);

// This pre-downloads the classes and model so that the runners
// don't each attempt parallel downloads.
await loadClasses();
await loadModel(modelDefinition);

// Only begin creating runners when the pool completes.
const pool = new ScalingConnectionPool(
() => {
return new ClassifierWorkerWrapper();
},
{
maxInstances: parallel,
},
);

return {
iterable: hashes,
init: {
model,
classes: await loadClasses(),
current: 0,
max: hashes.length,
pool,
},
};
},
Expand All @@ -47,35 +66,20 @@ export function getTask(database: Database): RouterTask[] {
throw new Error(`Couldn't find media to generate preview: ${hash}`);
}

const tensor = await loadImageFileCommon(
transcoder.getThumbnailPath(media),
modelDefinition.width,
modelDefinition.height,
);
try {
const classified = init.model.predict(tensor) as TensorFlow.Tensor<TensorFlow.Rank>;
try {
const classificationNestedArray = (await classified.array()) as number[][];
const classificationArray = classificationNestedArray[0];

const probabilities = classificationArray
.map((probability, index) => {
return {
probability,
label: init.classes[index],
};
})
.sort((a, b) => b.probability - a.probability);
const topFive = probabilities.slice(0, 5);
await database.saveMedia(hash, { autoTags: topFive.map(output => output.label) });
} finally {
classified.dispose();
}
const results = await init.pool.run(instance =>
instance.classify({
definition: modelDefinition,
absolutePath: transcoder.getThumbnailPath(media),
}),
);
await database.saveMedia(hash, { autoTags: results.probabilities });
} finally {
tensor.dispose();
updateStatus(init.current++, init.max);
}
},
{ parallel, continueOnError: true },
init => init?.pool?.quit(),
);
},
};
Expand Down
4 changes: 1 addition & 3 deletions server/src/tasks/tensorflow/imagenet.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import TensorFlow from './tensorflow';
// loaded on non-AVX CPUs.
import { NodeFileSystem } from '@tensorflow/tfjs-node/dist/io/file_system';

interface TensorFlowHubModel {
export interface TensorFlowHubModel {
id: string;
name: string;
url: string;
Expand Down Expand Up @@ -77,9 +77,7 @@ export async function loadModel(
await ImportUtils.mkdir(MODELS_DIR);
try {
const io = new NodeFileSystem(Path.resolve(MODELS_DIR, modelDefinition.id, 'model.json'));
console.log(`Load model: ${modelDefinition.id}`);
const model = await TensorFlow.loadGraphModel(io);
console.log(`Model loaded: ${modelDefinition.id}`);
return model;
} catch (err) {
console.error(err);
Expand Down
3 changes: 3 additions & 0 deletions server/src/tasks/tensorflow/tensorflow.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ function isAvxSupported(): boolean {

declare module '@tensorflow/tfjs' {
export function decodeImage(buffer: Buffer): TFJS.Tensor3D;
export function isNative(): boolean;
}

export default TFJS;
Expand All @@ -49,6 +50,7 @@ if (isAvxSupported()) {
module.exports = {
...tfjs,
decodeImage,
isNative: () => true,
};
} else {
console.warn(
Expand All @@ -74,5 +76,6 @@ if (isAvxSupported()) {
module.exports = {
...tfjs,
decodeImage,
isNative: () => false,
};
}
44 changes: 44 additions & 0 deletions server/src/tasks/tensorflow/workers/classifier-worker-wrapper.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import { SHARE_ENV, Worker } from 'worker_threads';
import { TensorFlowHubModel } from '../imagenet';
import Path from 'path';

// Message sent to the worker: which model definition to use and the
// absolute path of the image file to classify.
export interface ClassifierWorkerRequest {
  definition: TensorFlowHubModel;
  absolutePath: string;
}

// Failure reply: the error message raised inside the worker.
export interface ClassifierWorkerError {
  err: string;
}

// Success reply: the top classification labels, most probable first.
// NOTE(review): the field is named 'probabilities' but the worker fills it
// with label strings, not numeric probabilities — consider renaming.
export interface ClassifierWorkerSuccess {
  probabilities: string[];
}

// Every reply from the worker is either an error or a success payload.
export type ClassifierWorkerResponse = ClassifierWorkerError | ClassifierWorkerSuccess;

/**
 * Thin promise-based wrapper around a classifier worker thread.
 *
 * One request is expected to be in flight at a time (the caller drives this
 * through a connection pool); `classify` resolves with the worker's reply or
 * rejects if the worker reports an error or the thread itself fails.
 */
export class ClassifierWorkerWrapper {
  // stdout/stderr: true stops the worker's output being piped automatically
  // to the parent's stdio (it is captured on the worker object instead),
  // which silences the worker's console noise.
  private worker = new Worker(Path.resolve(__dirname, 'classifier-worker.js'), {
    env: SHARE_ENV,
    stdout: true,
    stderr: true,
  });

  /** Terminate the underlying worker thread. */
  public async quit(): Promise<void> {
    await this.worker.terminate();
  }

  /**
   * Send one classification request and await its reply.
   *
   * Also listens for the worker's 'error' event so that a crashed worker
   * rejects the promise instead of leaving it pending forever (which would
   * hang the pool). Whichever event fires first detaches the other listener
   * to avoid leaking handlers across requests.
   */
  public classify(request: ClassifierWorkerRequest): Promise<ClassifierWorkerSuccess> {
    return new Promise<ClassifierWorkerSuccess>((resolve, reject) => {
      const onError = (err: Error): void => {
        this.worker.off('message', onMessage);
        reject(err);
      };
      const onMessage = (response: ClassifierWorkerResponse): void => {
        this.worker.off('error', onError);
        if ((response as ClassifierWorkerError).err) {
          reject(new Error((response as ClassifierWorkerError).err));
        } else {
          resolve(response as ClassifierWorkerSuccess);
        }
      };
      this.worker.once('message', onMessage);
      this.worker.once('error', onError);
      this.worker.postMessage(request);
    });
  }
}
85 changes: 85 additions & 0 deletions server/src/tasks/tensorflow/workers/classifier-worker.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import {
ClassifierWorkerRequest,
ClassifierWorkerResponse,
ClassifierWorkerSuccess,
} from './classifier-worker-wrapper';
import { TensorFlowHubModel, loadClasses, loadModel } from '../imagenet';
import { loadImageFileCommon } from '../common';
import { parentPort } from 'worker_threads';
import TensorFlow from '../tensorflow';

// This module runs inside a worker thread; parentPort is the message channel
// back to the spawning thread and is null only when loaded as a normal module.
if (!parentPort) {
  throw new Error('Worker missing fields');
}

// Per-worker cache of the most recently requested model.
interface ModelInfo {
  // Promises (not resolved values) so a single in-flight load is shared.
  model: Promise<TensorFlow.GraphModel>;
  classes: Promise<string[]>;
  // The definition the cached model/classes were loaded for.
  definition: TensorFlowHubModel;
}

// Cached between messages; replaced when a different model is requested.
let modelInfo: ModelInfo | undefined = undefined;

/**
 * Classify one image with the requested model and return the top five
 * ImageNet labels, most probable first.
 *
 * The model and class list are cached per worker and reloaded only when a
 * different model id is requested.
 */
async function onMessage(message: ClassifierWorkerRequest): Promise<ClassifierWorkerSuccess> {
  // Invalidate the cache when a different model is requested.
  if (modelInfo && modelInfo.definition.id !== message.definition.id) {
    modelInfo = undefined;
  }
  if (!modelInfo) {
    modelInfo = {
      definition: message.definition,
      model: loadModel(message.definition),
      classes: loadClasses(),
    };
  }

  let model: TensorFlow.GraphModel;
  let classes: string[];
  try {
    model = await modelInfo.model;
    classes = await modelInfo.classes;
  } catch (err) {
    // Bug fix: a rejected load (e.g. a transient network failure fetching
    // the model) must not stay cached, otherwise every later request for
    // this model fails forever without retrying the load.
    modelInfo = undefined;
    throw err;
  }

  const tensor = await loadImageFileCommon(
    message.absolutePath,
    message.definition.width,
    message.definition.height,
  );

  try {
    const classified = model.predict(tensor) as TensorFlow.Tensor<TensorFlow.Rank>;
    try {
      const classificationNestedArray = (await classified.array()) as number[][];
      const classificationArray = classificationNestedArray[0];

      // Pair each probability with its class label, then keep the top five.
      const probabilities = classificationArray
        .map((probability, index) => {
          return {
            probability,
            label: classes[index],
          };
        })
        .sort((a, b) => b.probability - a.probability);
      const topFive = probabilities.slice(0, 5);
      return { probabilities: topFive.map(output => output.label) };
    } finally {
      // Tensors are not garbage collected; dispose explicitly to avoid leaks.
      classified.dispose();
    }
  } finally {
    tensor.dispose();
  }
}

// Wrapper to enforce typing of response.
// The non-null assertion is safe: the guard at module load throws if
// parentPort is missing, so this module never runs without it.
function postResult(result: ClassifierWorkerResponse): void {
  parentPort!.postMessage(result);
}

// Bridge incoming requests to onMessage, always replying with a typed
// success or error payload so the parent's pending promise settles.
// (An async function never throws synchronously, so a single catch around
// the await covers every failure path.)
parentPort.on('message', async (message: ClassifierWorkerRequest) => {
  try {
    const result = await onMessage(message);
    postResult(result);
  } catch (err) {
    postResult({ err: err.message });
  }
});
2 changes: 1 addition & 1 deletion server/src/utils/validator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ if (!require.main) {
throw new Error('require.main missing');
}

const ROOT_DIR = Path.dirname(require.main.filename);
const ROOT_DIR = Path.resolve(__dirname, '..');
const SCHEMA_PATH = process.env['SCHEMA_PATH'] || `${ROOT_DIR}/schemas`;

export class Validator {
Expand Down

0 comments on commit 2e5a8ef

Please sign in to comment.