From 78e9390fe220266913ca1f30977ae051d6eae674 Mon Sep 17 00:00:00 2001 From: SimplyBoo Date: Sun, 21 Mar 2021 20:13:36 +0000 Subject: [PATCH] Added local TensorFlow model caching --- server/src/tasks/tensorflow/classify.ts | 4 +- server/src/tasks/tensorflow/imagenet.ts | 91 ++++++++++++++++++++++--- 2 files changed, 84 insertions(+), 11 deletions(-) diff --git a/server/src/tasks/tensorflow/classify.ts b/server/src/tasks/tensorflow/classify.ts index a46def8..e2a0f71 100755 --- a/server/src/tasks/tensorflow/classify.ts +++ b/server/src/tasks/tensorflow/classify.ts @@ -1,5 +1,5 @@ import { Database, RouterTask, TaskRunnerCallback } from '../../types'; -import { IMAGENET_MODELS, loadClasses } from './imagenet'; +import { IMAGENET_MODELS, loadClasses, loadModel } from './imagenet'; import { Transcoder } from '../../cache/transcoder'; import { execute } from 'proper-job'; import { loadImageFileCommon } from './common'; @@ -25,7 +25,7 @@ export function getTask(database: Database): RouterTask[] { return []; } - const model = await TensorFlow.loadGraphModel(modelDefinition.url, { fromTFHub: true }); + const model = await loadModel(modelDefinition); updateStatus(0, hashes.length); return { diff --git a/server/src/tasks/tensorflow/imagenet.ts b/server/src/tasks/tensorflow/imagenet.ts index 5adb7f9..1dd2659 100755 --- a/server/src/tasks/tensorflow/imagenet.ts +++ b/server/src/tasks/tensorflow/imagenet.ts @@ -1,22 +1,95 @@ +import { ImportUtils } from '../../cache/import-utils'; +import Config from '../../config'; +import FS from 'fs'; import Fetch from 'node-fetch'; +import Path from 'path'; +import TensorFlow from './tensorflow'; +// Explicitly import from the sub-directory so the TFJS node backend isn't +// loaded on non-AVX CPUs. +import { NodeFileSystem } from '@tensorflow/tfjs-node/dist/io/file_system'; -export async function loadClasses(): Promise { +interface TensorFlowHubModel { + id: string; + name: string; + url: string; + width: number; + height: number; +} + +const MODELS_DIR = Path.resolve(Config.get().cachePath, 'tensorflow'); + +function loadClassesLocal(): Promise { + return new Promise((resolve, reject) => { + FS.readFile(Path.resolve(MODELS_DIR, 'ImageNetLabels.txt'), (err, data) => { + if (err) { + reject(err); + } else { + resolve(data.toString()); + } + }); + }); +} + +async function loadClassesRemote(): Promise { const res = await Fetch( 'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt', ); if (!res.ok) { throw new Error('Failed to fetch imagenet classes'); } - const text = await res.text(); - return text.split('\n').map(line => line.trim()); + return res.text(); } -interface TensorFlowHubModel { - id: string; - name: string; - url: string; - width: number; - height: number; +function saveClasses(classes: string): Promise { + return new Promise((resolve, reject) => { + FS.writeFile(Path.resolve(MODELS_DIR, 'ImageNetLabels.txt'), classes, err => { + if (err) { + reject(err); + } else { + resolve(); + } + }); + }); +} + +async function loadClassesRaw(): Promise { + try { + const classes = await loadClassesLocal(); + return classes; + } catch { + console.log('Downloading ImageNet classes'); + const classes = await loadClassesRemote(); + await saveClasses(classes); + console.log('ImageNet classes saved'); + return classes; + } +} + +export async function loadClasses(): Promise { + await ImportUtils.mkdir(MODELS_DIR); + const raw = await loadClassesRaw(); + return raw.split('\n').map(line => line.trim()); +} + +export async function loadModel( + modelDefinition: TensorFlowHubModel, +): Promise { + await ImportUtils.mkdir(MODELS_DIR); + try { + const io = new NodeFileSystem(Path.resolve(MODELS_DIR, modelDefinition.id, 'model.json')); + console.log(`Load model: ${modelDefinition.id}`); + const model = await TensorFlow.loadGraphModel(io); + console.log(`Model loaded: ${modelDefinition.id}`); + return model; + } catch (err) { + console.error(err); + console.log(`Downloading model: ${modelDefinition.id}`); + const model = await TensorFlow.loadGraphModel(modelDefinition.url, { fromTFHub: true }); + const io = new NodeFileSystem(Path.resolve(MODELS_DIR, modelDefinition.id)); + await model.save(io); + console.log(`Model saved: ${modelDefinition.id}`); + return model; + } } // All use the common image input.