Skip to content

Commit

Permalink
List models with criteria
Browse files Browse the repository at this point in the history
Change-Id: I9364e5d67a7a8de95f76eb7c69707b71fafef871
  • Loading branch information
zachgk committed Sep 17, 2020
1 parent a20892a commit 13e1c60
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 6 deletions.
10 changes: 10 additions & 0 deletions api/src/main/java/ai/djl/Application.java
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,16 @@ public String toString() {
return path.replace('/', '.').toUpperCase();
}

/**
* Returns whether this application matches the test application set.
*
* @param test a application or application set to test against
* @return true if it fits within the application set
*/
public boolean matches(Application test) {
return path.startsWith(test.path);
}

/** {@inheritDoc} */
@Override
public boolean equals(Object o) {
Expand Down
6 changes: 6 additions & 0 deletions api/src/main/java/ai/djl/repository/zoo/BaseModelLoader.java
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,12 @@ public String toString() {

@SuppressWarnings("unchecked")
private <I, O> TranslatorFactory<I, O> getTranslatorFactory(Criteria<I, O> criteria) {
if (criteria.getInputClass() == null) {
throw new IllegalArgumentException("The criteria must set an input class");
}
if (criteria.getOutputClass() == null) {
throw new IllegalArgumentException("The criteria must set an output class");
}
return (TranslatorFactory<I, O>)
factories.get(new Pair<>(criteria.getInputClass(), criteria.getOutputClass()));
}
Expand Down
4 changes: 0 additions & 4 deletions api/src/main/java/ai/djl/repository/zoo/Criteria.java
Original file line number Diff line number Diff line change
Expand Up @@ -481,10 +481,6 @@ public Builder<I, O> optProgress(Progress progress) {
* @return the {@link Criteria} instance
*/
public Criteria<I, O> build() {
if (inputClass == null || outputClass == null) {
throw new IllegalArgumentException(
"Input and output type are required for a Criteria.");
}
return new Criteria<>(this);
}
}
Expand Down
43 changes: 41 additions & 2 deletions api/src/main/java/ai/djl/repository/zoo/ModelZoo.java
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ default ModelLoader getModelLoader(String name) {
/**
* Gets the {@link ModelLoader} based on the model name.
*
* @param criteria the name of the model
* @param criteria the requirements for the model
* @param <I> the input data type for preprocessing
* @param <O> the output data type after postprocessing
* @return the model that matches the criteria
Expand Down Expand Up @@ -153,7 +153,7 @@ static <I, O> ZooModel<I, O> loadModel(Criteria<I, O> criteria)
}
if (application != Application.UNDEFINED
&& app != Application.UNDEFINED
&& !app.equals(application)) {
&& !app.matches(application)) {
// filter out ModelLoader by application
continue;
}
Expand Down Expand Up @@ -183,6 +183,25 @@ static <I, O> ZooModel<I, O> loadModel(Criteria<I, O> criteria)
*/
static Map<Application, List<Artifact>> listModels()
throws IOException, ModelNotFoundException {
return listModels(Criteria.builder().build());
}

/**
* Returns the available {@link Application} and their model artifact metadata.
*
* @param criteria the requirements for the model
* @return the available {@link Application} and their model artifact metadata
* @throws IOException if failed to download to repository metadata
* @throws ModelNotFoundException if failed to parse repository metadata
*/
static Map<Application, List<Artifact>> listModels(Criteria<?, ?> criteria)
throws IOException, ModelNotFoundException {
String artifactId = criteria.getArtifactId();
ModelZoo modelZoo = criteria.getModelZoo();
String groupId = criteria.getGroupId();
String engine = criteria.getEngine();
Application application = criteria.getApplication();

@SuppressWarnings("PMD.UseConcurrentHashMap")
Map<Application, List<Artifact>> models =
new TreeMap<>(Comparator.comparing(Application::getPath));
Expand All @@ -192,9 +211,29 @@ static Map<Application, List<Artifact>> listModels()
if (zoo == null) {
continue;
}
if (modelZoo != null) {
if (groupId != null && !modelZoo.getGroupId().equals(groupId)) {
continue;
}
Set<String> supportedEngine = modelZoo.getSupportedEngines();
if (engine != null && !supportedEngine.contains(engine)) {
continue;
}
}
List<ModelLoader> list = zoo.getModelLoaders();
for (ModelLoader loader : list) {
Application app = loader.getApplication();
String loaderArtifactId = loader.getArtifactId();
if (artifactId != null && !artifactId.equals(loaderArtifactId)) {
// filter out by model loader artifactId
continue;
}
if (application != Application.UNDEFINED
&& app != Application.UNDEFINED
&& !app.matches(application)) {
// filter out ModelLoader by application
continue;
}
final List<Artifact> artifacts = loader.listModels();
models.compute(
app,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@
package ai.djl.examples.inference;

import ai.djl.Application;
import ai.djl.Application.CV;
import ai.djl.Application.NLP;
import ai.djl.ModelException;
import ai.djl.repository.Artifact;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import java.io.IOException;
import java.nio.file.Path;
Expand All @@ -41,4 +44,19 @@ public void testListModels() throws ModelException, IOException {
List<Artifact> artifacts = models.get(Application.UNDEFINED);
Assert.assertFalse(artifacts.isEmpty());
}

@Test
public void testListModelsWithApplication() throws ModelException, IOException {
Path path = Paths.get("../model-zoo/src/test/resources/mlrepo");
String repoUrl = path.toRealPath().toAbsolutePath().toUri().toURL().toExternalForm();
System.setProperty("ai.djl.repository.zoo.location", "src/test/resources," + repoUrl);
Criteria<?, ?> criteria = Criteria.builder().optApplication(NLP.ANY).build();
Map<Application, List<Artifact>> models = ModelZoo.listModels(criteria);

for (Application application : models.keySet()) {
Assert.assertTrue(
application.matches(NLP.ANY) || application.matches(Application.UNDEFINED));
Assert.assertFalse(application.matches(CV.ANY));
}
}
}

0 comments on commit 13e1c60

Please sign in to comment.