-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Moved all TensorFlow calculations to worker_threads. This was primarily for the JS only variant. It hogged nearly the entire event loop causing the web interface to appear very glitchy.
- Loading branch information
SimplyBoo
committed
Mar 22, 2021
1 parent
78e9390
commit 2e5a8ef
Showing
7 changed files
with
170 additions
and
36 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
44 changes: 44 additions & 0 deletions
44
server/src/tasks/tensorflow/workers/classifier-worker-wrapper.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import { SHARE_ENV, Worker } from 'worker_threads'; | ||
import { TensorFlowHubModel } from '../imagenet'; | ||
import Path from 'path'; | ||
|
||
export interface ClassifierWorkerRequest { | ||
definition: TensorFlowHubModel; | ||
absolutePath: string; | ||
} | ||
|
||
export interface ClassifierWorkerError { | ||
err: string; | ||
} | ||
|
||
export interface ClassifierWorkerSuccess { | ||
probabilities: string[]; | ||
} | ||
|
||
export type ClassifierWorkerResponse = ClassifierWorkerError | ClassifierWorkerSuccess; | ||
|
||
export class ClassifierWorkerWrapper { | ||
// Silence stdout and stderr. | ||
private worker = new Worker(Path.resolve(__dirname, 'classifier-worker.js'), { | ||
env: SHARE_ENV, | ||
stdout: true, | ||
stderr: true, | ||
}); | ||
|
||
public async quit(): Promise<void> { | ||
await this.worker.terminate(); | ||
} | ||
|
||
public classify(request: ClassifierWorkerRequest): Promise<ClassifierWorkerSuccess> { | ||
return new Promise<ClassifierWorkerSuccess>((resolve, reject) => { | ||
this.worker.once('message', (response: ClassifierWorkerResponse) => { | ||
if ((response as ClassifierWorkerError).err) { | ||
reject(new Error((response as ClassifierWorkerError).err)); | ||
} else { | ||
resolve(response as ClassifierWorkerSuccess); | ||
} | ||
}); | ||
this.worker.postMessage(request); | ||
}); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
import { | ||
ClassifierWorkerRequest, | ||
ClassifierWorkerResponse, | ||
ClassifierWorkerSuccess, | ||
} from './classifier-worker-wrapper'; | ||
import { TensorFlowHubModel, loadClasses, loadModel } from '../imagenet'; | ||
import { loadImageFileCommon } from '../common'; | ||
import { parentPort } from 'worker_threads'; | ||
import TensorFlow from '../tensorflow'; | ||
|
||
if (!parentPort) { | ||
throw new Error('Worker missing fields'); | ||
} | ||
|
||
interface ModelInfo { | ||
model: Promise<TensorFlow.GraphModel>; | ||
classes: Promise<string[]>; | ||
definition: TensorFlowHubModel; | ||
} | ||
|
||
let modelInfo: ModelInfo | undefined = undefined; | ||
|
||
async function onMessage(message: ClassifierWorkerRequest): Promise<ClassifierWorkerSuccess> { | ||
if (modelInfo && modelInfo.definition.id !== message.definition.id) { | ||
modelInfo = undefined; | ||
} | ||
if (!modelInfo) { | ||
modelInfo = { | ||
definition: message.definition, | ||
model: loadModel(message.definition), | ||
classes: loadClasses(), | ||
}; | ||
} | ||
|
||
const model = await modelInfo.model; | ||
const classes = await modelInfo.classes; | ||
|
||
const tensor = await loadImageFileCommon( | ||
message.absolutePath, | ||
message.definition.width, | ||
message.definition.height, | ||
); | ||
|
||
try { | ||
const classified = model.predict(tensor) as TensorFlow.Tensor<TensorFlow.Rank>; | ||
try { | ||
const classificationNestedArray = (await classified.array()) as number[][]; | ||
const classificationArray = classificationNestedArray[0]; | ||
|
||
const probabilities = classificationArray | ||
.map((probability, index) => { | ||
return { | ||
probability, | ||
label: classes[index], | ||
}; | ||
}) | ||
.sort((a, b) => b.probability - a.probability); | ||
const topFive = probabilities.slice(0, 5); | ||
return { probabilities: topFive.map(output => output.label) }; | ||
} finally { | ||
classified.dispose(); | ||
} | ||
} finally { | ||
tensor.dispose(); | ||
} | ||
} | ||
|
||
// Wrapper to enforce typing of response. | ||
function postResult(result: ClassifierWorkerResponse): void { | ||
parentPort!.postMessage(result); | ||
} | ||
|
||
parentPort.on('message', (message: ClassifierWorkerRequest) => { | ||
try { | ||
onMessage(message) | ||
.then(result => { | ||
postResult(result); | ||
}) | ||
.catch(err => { | ||
postResult({ err: err.message }); | ||
}); | ||
} catch (err) { | ||
postResult({ err: err.message }); | ||
} | ||
}); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters