diff --git a/README.md b/README.md index 3f26531..5ac6545 100755 --- a/README.md +++ b/README.md @@ -41,8 +41,8 @@ This will start an instance of the program listening on (http://localhost:3523)[ - The keyword search supports quotes ("magic phrase" for sentences and negation (-) on words and sentences). - When doing the keyword search there's a limit of 1200 results. - All code is TypeScript and the UI framework is Angular. -- TensorFlow will run on the CPU. If you have AVX support on your CPU it uses native code and depending on the model and your CPU classifies in the range of 45ms. If your CPU does not support AVX it uses the JS only library and each classification will take in the range of 5s. -- TensorFlow currently requires internet connectivity. It fetches the list of ImageNet classes and the selected pre-trained model when it starts. Media is all classified locally. +- TensorFlow will run on the CPU. If you have AVX support on your CPU it uses native code and depending on the model and your CPU classifies in the range of 45ms. If your CPU does not support AVX it uses the JS only library and each classification will take in the range of 5s. However, it does use worker threads to make use of multi-core CPUs. +- TensorFlow currently requires internet connectivity the first time a model is used. It fetches the list of ImageNet classes and the selected pre-trained model when it starts. Media is all classified locally. Further classifications with that model will use the cached download. ### Github Sponsors diff --git a/server/src/tasks/tensorflow/classify.ts b/server/src/tasks/tensorflow/classify.ts index e2a0f71..ab6ba8d 100755 --- a/server/src/tasks/tensorflow/classify.ts +++ b/server/src/tasks/tensorflow/classify.ts @@ -1,10 +1,16 @@ +import { ClassifierWorkerWrapper } from './workers/classifier-worker-wrapper'; import { Database, RouterTask, TaskRunnerCallback } from '../../types'; import { IMAGENET_MODELS, loadClasses, loadModel } from './imagenet'; +import { ScalingConnectionPool, execute } from 'proper-job'; import { Transcoder } from '../../cache/transcoder'; -import { execute } from 'proper-job'; -import { loadImageFileCommon } from './common'; +import OS from 'os'; import TensorFlow from './tensorflow'; +// Two if native because it doesn't quite achieve 100% multi-core use, +// likely due to not feeding it data enough. +// One per core if JS only since it's single threaded. +const parallel = TensorFlow.isNative() ? 2 : OS.cpus().length; + export function getTask(database: Database): RouterTask[] { const transcoder = new Transcoder(database); @@ -25,16 +31,29 @@ export function getTask(database: Database): RouterTask[] { return []; } - const model = await loadModel(modelDefinition); - updateStatus(0, hashes.length); + + // This pre-downloads the classes and model so that the runners + // don't each attempt parallel downloads. + await loadClasses(); + await loadModel(modelDefinition); + + // Only begin creating runners when the pool completes. + const pool = new ScalingConnectionPool( + () => { + return new ClassifierWorkerWrapper(); + }, + { + maxInstances: parallel, + }, + ); + return { iterable: hashes, init: { - model, - classes: await loadClasses(), current: 0, max: hashes.length, + pool, }, }; }, @@ -47,35 +66,20 @@ export function getTask(database: Database): RouterTask[] { throw new Error(`Couldn't find media to generate preview: ${hash}`); } - const tensor = await loadImageFileCommon( - transcoder.getThumbnailPath(media), - modelDefinition.width, - modelDefinition.height, - ); try { - const classified = init.model.predict(tensor) as TensorFlow.Tensor; - try { - const classificationNestedArray = (await classified.array()) as number[][]; - const classificationArray = classificationNestedArray[0]; - - const probabilities = classificationArray - .map((probability, index) => { - return { - probability, - label: init.classes[index], - }; - }) - .sort((a, b) => b.probability - a.probability); - const topFive = probabilities.slice(0, 5); - await database.saveMedia(hash, { autoTags: topFive.map(output => output.label) }); - } finally { - classified.dispose(); - } + const results = await init.pool.run(instance => + instance.classify({ + definition: modelDefinition, + absolutePath: transcoder.getThumbnailPath(media), + }), + ); + await database.saveMedia(hash, { autoTags: results.probabilities }); } finally { - tensor.dispose(); updateStatus(init.current++, init.max); } }, + { parallel, continueOnError: true }, + init => init?.pool?.quit(), ); }, }; diff --git a/server/src/tasks/tensorflow/imagenet.ts b/server/src/tasks/tensorflow/imagenet.ts index 1dd2659..cf59d80 100755 --- a/server/src/tasks/tensorflow/imagenet.ts +++ b/server/src/tasks/tensorflow/imagenet.ts @@ -8,7 +8,7 @@ import TensorFlow from './tensorflow'; // loaded on non-AVX CPUs. import { NodeFileSystem } from '@tensorflow/tfjs-node/dist/io/file_system'; -interface TensorFlowHubModel { +export interface TensorFlowHubModel { id: string; name: string; url: string; @@ -77,9 +77,7 @@ export async function loadModel( await ImportUtils.mkdir(MODELS_DIR); try { const io = new NodeFileSystem(Path.resolve(MODELS_DIR, modelDefinition.id, 'model.json')); - console.log(`Load model: ${modelDefinition.id}`); const model = await TensorFlow.loadGraphModel(io); - console.log(`Model loaded: ${modelDefinition.id}`); return model; } catch (err) { console.error(err); diff --git a/server/src/tasks/tensorflow/tensorflow.ts b/server/src/tasks/tensorflow/tensorflow.ts index 4fab5c6..d8d0ee8 100755 --- a/server/src/tasks/tensorflow/tensorflow.ts +++ b/server/src/tasks/tensorflow/tensorflow.ts @@ -32,6 +32,7 @@ function isAvxSupported(): boolean { declare module '@tensorflow/tfjs' { export function decodeImage(buffer: Buffer): TFJS.Tensor3D; + export function isNative(): boolean; } export default TFJS; @@ -49,6 +50,7 @@ if (isAvxSupported()) { module.exports = { ...tfjs, decodeImage, + isNative: () => true, }; } else { console.warn( @@ -74,5 +76,6 @@ if (isAvxSupported()) { module.exports = { ...tfjs, decodeImage, + isNative: () => false, }; } diff --git a/server/src/tasks/tensorflow/workers/classifier-worker-wrapper.ts b/server/src/tasks/tensorflow/workers/classifier-worker-wrapper.ts new file mode 100755 index 0000000..827390c --- /dev/null +++ b/server/src/tasks/tensorflow/workers/classifier-worker-wrapper.ts @@ -0,0 +1,44 @@ +import { SHARE_ENV, Worker } from 'worker_threads'; +import { TensorFlowHubModel } from '../imagenet'; +import Path from 'path'; + +export interface ClassifierWorkerRequest { + definition: TensorFlowHubModel; + absolutePath: string; +} + +export interface ClassifierWorkerError { + err: string; +} + +export interface ClassifierWorkerSuccess { + probabilities: string[]; +} + +export type ClassifierWorkerResponse = ClassifierWorkerError | ClassifierWorkerSuccess; + +export class ClassifierWorkerWrapper { + // Silence stdout and stderr. + private worker = new Worker(Path.resolve(__dirname, 'classifier-worker.js'), { + env: SHARE_ENV, + stdout: true, + stderr: true, + }); + + public async quit(): Promise { + await this.worker.terminate(); + } + + public classify(request: ClassifierWorkerRequest): Promise { + return new Promise((resolve, reject) => { + this.worker.once('message', (response: ClassifierWorkerResponse) => { + if ((response as ClassifierWorkerError).err) { + reject(new Error((response as ClassifierWorkerError).err)); + } else { + resolve(response as ClassifierWorkerSuccess); + } + }); + this.worker.postMessage(request); + }); + } +} diff --git a/server/src/tasks/tensorflow/workers/classifier-worker.ts b/server/src/tasks/tensorflow/workers/classifier-worker.ts new file mode 100755 index 0000000..5901275 --- /dev/null +++ b/server/src/tasks/tensorflow/workers/classifier-worker.ts @@ -0,0 +1,85 @@ +import { + ClassifierWorkerRequest, + ClassifierWorkerResponse, + ClassifierWorkerSuccess, +} from './classifier-worker-wrapper'; +import { TensorFlowHubModel, loadClasses, loadModel } from '../imagenet'; +import { loadImageFileCommon } from '../common'; +import { parentPort } from 'worker_threads'; +import TensorFlow from '../tensorflow'; + +if (!parentPort) { + throw new Error('Worker missing fields'); +} + +interface ModelInfo { + model: Promise; + classes: Promise; + definition: TensorFlowHubModel; +} + +let modelInfo: ModelInfo | undefined = undefined; + +async function onMessage(message: ClassifierWorkerRequest): Promise { + if (modelInfo && modelInfo.definition.id !== message.definition.id) { + modelInfo = undefined; + } + if (!modelInfo) { + modelInfo = { + definition: message.definition, + model: loadModel(message.definition), + classes: loadClasses(), + }; + } + + const model = await modelInfo.model; + const classes = await modelInfo.classes; + + const tensor = await loadImageFileCommon( + message.absolutePath, + message.definition.width, + message.definition.height, + ); + + try { + const classified = model.predict(tensor) as TensorFlow.Tensor; + try { + const classificationNestedArray = (await classified.array()) as number[][]; + const classificationArray = classificationNestedArray[0]; + + const probabilities = classificationArray + .map((probability, index) => { + return { + probability, + label: classes[index], + }; + }) + .sort((a, b) => b.probability - a.probability); + const topFive = probabilities.slice(0, 5); + return { probabilities: topFive.map(output => output.label) }; + } finally { + classified.dispose(); + } + } finally { + tensor.dispose(); + } +} + +// Wrapper to enforce typing of response. +function postResult(result: ClassifierWorkerResponse): void { + parentPort!.postMessage(result); +} + +parentPort.on('message', (message: ClassifierWorkerRequest) => { + try { + onMessage(message) + .then(result => { + postResult(result); + }) + .catch(err => { + postResult({ err: err.message }); + }); + } catch (err) { + postResult({ err: err.message }); + } +}); diff --git a/server/src/utils/validator.ts b/server/src/utils/validator.ts index a5a7088..7f5f628 100755 --- a/server/src/utils/validator.ts +++ b/server/src/utils/validator.ts @@ -11,7 +11,7 @@ if (!require.main) { throw new Error('require.main missing'); } -const ROOT_DIR = Path.dirname(require.main.filename); +const ROOT_DIR = Path.resolve(__dirname, '..'); const SCHEMA_PATH = process.env['SCHEMA_PATH'] || `${ROOT_DIR}/schemas`; export class Validator {