Skip to content

Commit

Permalink
Added local TensorFlow model caching
Browse files Browse the repository at this point in the history
  • Loading branch information
SimplyBoo committed Mar 21, 2021
1 parent 40fa5e6 commit 78e9390
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 11 deletions.
4 changes: 2 additions & 2 deletions server/src/tasks/tensorflow/classify.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { Database, RouterTask, TaskRunnerCallback } from '../../types';
import { IMAGENET_MODELS, loadClasses } from './imagenet';
import { IMAGENET_MODELS, loadClasses, loadModel } from './imagenet';
import { Transcoder } from '../../cache/transcoder';
import { execute } from 'proper-job';
import { loadImageFileCommon } from './common';
Expand All @@ -25,7 +25,7 @@ export function getTask(database: Database): RouterTask[] {
return [];
}

const model = await TensorFlow.loadGraphModel(modelDefinition.url, { fromTFHub: true });
const model = await loadModel(modelDefinition);

updateStatus(0, hashes.length);
return {
Expand Down
91 changes: 82 additions & 9 deletions server/src/tasks/tensorflow/imagenet.ts
Original file line number Diff line number Diff line change
@@ -1,22 +1,95 @@
import { ImportUtils } from '../../cache/import-utils';
import Config from '../../config';
import FS from 'fs';
import Fetch from 'node-fetch';
import Path from 'path';
import TensorFlow from './tensorflow';
// Explicitly import from the sub-directory so the TFJS node backend isn't
// loaded on non-AVX CPUs.
import { NodeFileSystem } from '@tensorflow/tfjs-node/dist/io/file_system';

export async function loadClasses(): Promise<string[]> {
interface TensorFlowHubModel {
id: string;
name: string;
url: string;
width: number;
height: number;
}

const MODELS_DIR = Path.resolve(Config.get().cachePath, 'tensorflow');

function loadClassesLocal(): Promise<string> {
return new Promise<string>((resolve, reject) => {
FS.readFile(Path.resolve(MODELS_DIR, 'ImageNetLabels.txt'), (err, data) => {
if (err) {
reject(err);
} else {
resolve(data.toString());
}
});
});
}

async function loadClassesRemote(): Promise<string> {
const res = await Fetch(
'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt',
);
if (!res.ok) {
throw new Error('Failed to fetch imagenet classes');
}
const text = await res.text();
return text.split('\n').map(line => line.trim());
return res.text();
}

interface TensorFlowHubModel {
id: string;
name: string;
url: string;
width: number;
height: number;
function saveClasses(classes: string): Promise<void> {
return new Promise((resolve, reject) => {
FS.writeFile(Path.resolve(MODELS_DIR, 'ImageNetLabels.txt'), classes, err => {
if (err) {
reject(err);
} else {
resolve();
}
});
});
}

async function loadClassesRaw(): Promise<string> {
try {
const classes = await loadClassesLocal();
return classes;
} catch {
console.log('Downloading ImageNet classes');
const classes = await loadClassesRemote();
await saveClasses(classes);
console.log('ImageNet classes saved');
return classes;
}
}

export async function loadClasses(): Promise<string[]> {
await ImportUtils.mkdir(MODELS_DIR);
const raw = await loadClassesRaw();
return raw.split('\n').map(line => line.trim());
}

export async function loadModel(
modelDefinition: TensorFlowHubModel,
): Promise<TensorFlow.GraphModel> {
await ImportUtils.mkdir(MODELS_DIR);
try {
const io = new NodeFileSystem(Path.resolve(MODELS_DIR, modelDefinition.id, 'model.json'));
console.log(`Load model: ${modelDefinition.id}`);
const model = await TensorFlow.loadGraphModel(io);
console.log(`Model loaded: ${modelDefinition.id}`);
return model;
} catch (err) {
console.error(err);
console.log(`Downloading model: ${modelDefinition.id}`);
const model = await TensorFlow.loadGraphModel(modelDefinition.url, { fromTFHub: true });
const io = new NodeFileSystem(Path.resolve(MODELS_DIR, modelDefinition.id));
await model.save(io);
console.log(`Model saved: ${modelDefinition.id}`);
return model;
}
}

// All use the common image input.
Expand Down

0 comments on commit 78e9390

Please sign in to comment.