diff --git a/api/src/main/java/ai/djl/Application.java b/api/src/main/java/ai/djl/Application.java index 349eea205e0..61bbfb37d33 100644 --- a/api/src/main/java/ai/djl/Application.java +++ b/api/src/main/java/ai/djl/Application.java @@ -86,6 +86,16 @@ public String toString() { return path.replace('/', '.').toUpperCase(); } + /** + * Returns whether this application matches the test application set. + * + * @param test a application or application set to test against + * @return true if it fits within the application set + */ + public boolean matches(Application test) { + return path.startsWith(test.path); + } + /** {@inheritDoc} */ @Override public boolean equals(Object o) { diff --git a/api/src/main/java/ai/djl/repository/zoo/BaseModelLoader.java b/api/src/main/java/ai/djl/repository/zoo/BaseModelLoader.java index 0d1506871c0..e350e8d11fa 100644 --- a/api/src/main/java/ai/djl/repository/zoo/BaseModelLoader.java +++ b/api/src/main/java/ai/djl/repository/zoo/BaseModelLoader.java @@ -198,6 +198,12 @@ public String toString() { @SuppressWarnings("unchecked") private TranslatorFactory getTranslatorFactory(Criteria criteria) { + if (criteria.getInputClass() == null) { + throw new IllegalArgumentException("The criteria must set an input class"); + } + if (criteria.getOutputClass() == null) { + throw new IllegalArgumentException("The criteria must set an output class"); + } return (TranslatorFactory) factories.get(new Pair<>(criteria.getInputClass(), criteria.getOutputClass())); } diff --git a/api/src/main/java/ai/djl/repository/zoo/Criteria.java b/api/src/main/java/ai/djl/repository/zoo/Criteria.java index 0c2c94b5501..0667d672a95 100644 --- a/api/src/main/java/ai/djl/repository/zoo/Criteria.java +++ b/api/src/main/java/ai/djl/repository/zoo/Criteria.java @@ -481,10 +481,6 @@ public Builder optProgress(Progress progress) { * @return the {@link Criteria} instance */ public Criteria build() { - if (inputClass == null || outputClass == null) { - throw new IllegalArgumentException( - "Input and output type are required for a Criteria."); - } return new Criteria<>(this); } } diff --git a/api/src/main/java/ai/djl/repository/zoo/ModelZoo.java b/api/src/main/java/ai/djl/repository/zoo/ModelZoo.java index a7afad67c8e..ab74afe80b7 100644 --- a/api/src/main/java/ai/djl/repository/zoo/ModelZoo.java +++ b/api/src/main/java/ai/djl/repository/zoo/ModelZoo.java @@ -85,7 +85,7 @@ default ModelLoader getModelLoader(String name) { /** * Gets the {@link ModelLoader} based on the model name. * - * @param criteria the name of the model + * @param criteria the requirements for the model * @param the input data type for preprocessing * @param the output data type after postprocessing * @return the model that matches the criteria @@ -153,7 +153,7 @@ static ZooModel loadModel(Criteria criteria) } if (application != Application.UNDEFINED && app != Application.UNDEFINED - && !app.equals(application)) { + && !app.matches(application)) { // filter out ModelLoader by application continue; } @@ -183,6 +183,25 @@ static ZooModel loadModel(Criteria criteria) */ static Map> listModels() throws IOException, ModelNotFoundException { + return listModels(Criteria.builder().build()); + } + + /** + * Returns the available {@link Application} and their model artifact metadata. + * + * @param criteria the requirements for the model + * @return the available {@link Application} and their model artifact metadata + * @throws IOException if failed to download to repository metadata + * @throws ModelNotFoundException if failed to parse repository metadata + */ + static Map> listModels(Criteria criteria) + throws IOException, ModelNotFoundException { + String artifactId = criteria.getArtifactId(); + ModelZoo modelZoo = criteria.getModelZoo(); + String groupId = criteria.getGroupId(); + String engine = criteria.getEngine(); + Application application = criteria.getApplication(); + @SuppressWarnings("PMD.UseConcurrentHashMap") Map> models = new TreeMap<>(Comparator.comparing(Application::getPath)); @@ -192,9 +211,29 @@ static Map> listModels() if (zoo == null) { continue; } + if (modelZoo != null) { + if (groupId != null && !modelZoo.getGroupId().equals(groupId)) { + continue; + } + Set supportedEngine = modelZoo.getSupportedEngines(); + if (engine != null && !supportedEngine.contains(engine)) { + continue; + } + } List list = zoo.getModelLoaders(); for (ModelLoader loader : list) { Application app = loader.getApplication(); + String loaderArtifactId = loader.getArtifactId(); + if (artifactId != null && !artifactId.equals(loaderArtifactId)) { + // filter out by model loader artifactId + continue; + } + if (application != Application.UNDEFINED + && app != Application.UNDEFINED + && !app.matches(application)) { + // filter out ModelLoader by application + continue; + } final List artifacts = loader.listModels(); models.compute( app, diff --git a/examples/src/test/java/ai/djl/examples/inference/ListModelsTest.java b/examples/src/test/java/ai/djl/examples/inference/ListModelsTest.java index 55b2a43eac8..5f26852fe7a 100644 --- a/examples/src/test/java/ai/djl/examples/inference/ListModelsTest.java +++ b/examples/src/test/java/ai/djl/examples/inference/ListModelsTest.java @@ -13,8 +13,11 @@ package ai.djl.examples.inference; import ai.djl.Application; +import ai.djl.Application.CV; +import ai.djl.Application.NLP; import ai.djl.ModelException; import ai.djl.repository.Artifact; +import ai.djl.repository.zoo.Criteria; import ai.djl.repository.zoo.ModelZoo; import java.io.IOException; import java.nio.file.Path; @@ -41,4 +44,19 @@ public void testListModels() throws ModelException, IOException { List artifacts = models.get(Application.UNDEFINED); Assert.assertFalse(artifacts.isEmpty()); } + + @Test + public void testListModelsWithApplication() throws ModelException, IOException { + Path path = Paths.get("../model-zoo/src/test/resources/mlrepo"); + String repoUrl = path.toRealPath().toAbsolutePath().toUri().toURL().toExternalForm(); + System.setProperty("ai.djl.repository.zoo.location", "src/test/resources," + repoUrl); + Criteria criteria = Criteria.builder().optApplication(NLP.ANY).build(); + Map> models = ModelZoo.listModels(criteria); + + for (Application application : models.keySet()) { + Assert.assertTrue( + application.matches(NLP.ANY) || application.matches(Application.UNDEFINED)); + Assert.assertFalse(application.matches(CV.ANY)); + } + } }