From 26bc938d1569e110650d22b84f1fc760cb7ea1a4 Mon Sep 17 00:00:00 2001 From: Roshin Rajan Panackal <36329474+rpanackal@users.noreply.github.com> Date: Fri, 18 Oct 2024 15:01:46 +0200 Subject: [PATCH] Openai model version (#99) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * intermediate * ready-to-discuss state * Fix merge; Format * Add javadoc * Fix PMD warnings * Initial change of draft * Initial API change * Fix merge conflicts * Update readme * Fix PMD * Formatting * Apply suggestions from code review Co-authored-by: Charles Dubois <103174266+CharlesDuboisSAP@users.noreply.github.com> * Defuse functions * Update core/src/main/java/com/sap/ai/sdk/core/AiCoreServiceWithDeployment.java Co-authored-by: Charles Dubois <103174266+CharlesDuboisSAP@users.noreply.github.com> * Make service binding accessor changeable, e.g. for future testing * Rename methods and types; Fix merge * Formatting * Fix code, tests are working * Fix header provisioning * Add tests; Move stuff around * make base class extensible * work in progress * refine work in progress * Fix tests * Update test * Add annotations * Format * Format; JavaDoc * Restructure code * Fix PMD * Fix PMD * Fix test * Fix test * Added model version filtering * Fix model version filtering * Apply review comments * Apply review comments * Add assertion on AI-Client-Type header * Minor JavaDoc fix * Formatting * Add unit test for model comparison * Improve test coverage * Add missing JavaDoc comments * Add missing JavaDoc comments * Formatting * Add missing JavaDoc comments * Improve coding practice * Remove duplicate anonymous classes * Accept review changes * Fix merge conflicts * Fix naming --------- Co-authored-by: Alexander Dümont Co-authored-by: SAP Cloud SDK Bot Co-authored-by: Alexander Dümont <22489773+newtork@users.noreply.github.com> Co-authored-by: Charles Dubois <103174266+CharlesDuboisSAP@users.noreply.github.com> Co-authored-by: Roshin Rajan Panackal --- .../com/sap/ai/sdk/core/AiCoreService.java | 11 ++- .../java/com/sap/ai/sdk/core/AiModel.java | 24 ++++++ .../com/sap/ai/sdk/core/DeploymentCache.java | 44 ++++++----- .../java/com/sap/ai/sdk/core/CacheTest.java | 73 +++++++++++++++++-- .../foundationmodels/openai/OpenAiClient.java | 2 +- .../foundationmodels/openai/OpenAiModel.java | 42 +++++++---- .../app/controllers/DeploymentController.java | 4 +- .../ai/sdk/app/controllers/ScenarioTest.java | 2 +- 8 files changed, 152 insertions(+), 50 deletions(-) create mode 100644 core/src/main/java/com/sap/ai/sdk/core/AiModel.java diff --git a/core/src/main/java/com/sap/ai/sdk/core/AiCoreService.java b/core/src/main/java/com/sap/ai/sdk/core/AiCoreService.java index 49ae1691..bb7b345d 100644 --- a/core/src/main/java/com/sap/ai/sdk/core/AiCoreService.java +++ b/core/src/main/java/com/sap/ai/sdk/core/AiCoreService.java @@ -119,19 +119,18 @@ public AiCoreDeployment forDeployment(@Nonnull final String deploymentId) { } /** - * Set a specific deployment by model name. If there are multiple deployments of the same model, - * the first one is returned. + * Set a specific deployment by model. If there are multiple deployments of the same model, the + * first one is returned. * - * @param modelName The model name to be used for AI Core service calls. + * @param model The model to be used for AI Core service calls. * @return A new instance of the AI Core Deployment. * @throws NoSuchElementException if no running deployment is found for the model. */ @Nonnull - public AiCoreDeployment forDeploymentByModel(@Nonnull final String modelName) + public AiCoreDeployment forDeploymentByModel(@Nonnull final AiModel model) throws NoSuchElementException { return new AiCoreDeployment( - this, - () -> DEPLOYMENT_CACHE.getDeploymentIdByModel(this.client(), resourceGroup, modelName)); + this, () -> DEPLOYMENT_CACHE.getDeploymentIdByModel(this.client(), resourceGroup, model)); } /** diff --git a/core/src/main/java/com/sap/ai/sdk/core/AiModel.java b/core/src/main/java/com/sap/ai/sdk/core/AiModel.java new file mode 100644 index 00000000..f0d63a55 --- /dev/null +++ b/core/src/main/java/com/sap/ai/sdk/core/AiModel.java @@ -0,0 +1,24 @@ +package com.sap.ai.sdk.core; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +/** An interface defining essential attributes of an AI model. */ +public interface AiModel { + + /** + * Get the model's name. + * + * @return The name of the model. + */ + @Nonnull + String name(); + + /** + * Get the model's version. + * + * @return The version of the model, or null if not specified. + */ + @Nullable + String version(); +} diff --git a/core/src/main/java/com/sap/ai/sdk/core/DeploymentCache.java b/core/src/main/java/com/sap/ai/sdk/core/DeploymentCache.java index d14028c6..f39e93ba 100644 --- a/core/src/main/java/com/sap/ai/sdk/core/DeploymentCache.java +++ b/core/src/main/java/com/sap/ai/sdk/core/DeploymentCache.java @@ -7,6 +7,7 @@ import java.util.HashSet; import java.util.Map; import java.util.NoSuchElementException; +import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; @@ -21,7 +22,7 @@ class DeploymentCache { /** Cache for deployment ids. The key is the model name and the value is the deployment id. */ - private final Map> CACHE = new ConcurrentHashMap<>(); + private final Map> cache = new ConcurrentHashMap<>(); /** * Remove all entries from the cache then load all deployments into the cache. @@ -32,23 +33,23 @@ class DeploymentCache { * @param resourceGroup the resource group, usually "default". */ void resetCache(@Nonnull final ApiClient client, @Nonnull final String resourceGroup) { - CACHE.remove(resourceGroup); + cache.remove(resourceGroup); try { final var deployments = new HashSet<>(new DeploymentApi(client).query(resourceGroup).getResources()); - CACHE.put(resourceGroup, deployments); + cache.put(resourceGroup, deployments); } catch (final OpenApiRequestException e) { log.error("Failed to load deployments into cache", e); } } /** - * Get the deployment id from the foundation model name. If there are multiple deployments of the - * same model, the first one is returned. + * Get the deployment id from the foundation model object. If there are multiple deployments of + * the same model, the first one is returned. * * @param client the API client to maybe reset the cache if the deployment is not found. * @param resourceGroup the resource group, usually "default". - * @param modelName the name of the foundation model. + * @param model the foundation model. * @return the deployment id. * @throws NoSuchElementException if no running deployment is found for the model. */ @@ -56,24 +57,24 @@ void resetCache(@Nonnull final ApiClient client, @Nonnull final String resourceG String getDeploymentIdByModel( @Nonnull final ApiClient client, @Nonnull final String resourceGroup, - @Nonnull final String modelName) + @Nonnull final AiModel model) throws NoSuchElementException { - return getDeploymentIdByModel(resourceGroup, modelName) + return getDeploymentIdByModel(resourceGroup, model) .orElseGet( () -> { resetCache(client, resourceGroup); - return getDeploymentIdByModel(resourceGroup, modelName) + return getDeploymentIdByModel(resourceGroup, model) .orElseThrow( () -> new NoSuchElementException( - "No running deployment found for model: " + modelName)); + "No running deployment found for model: " + model)); }); } private Optional getDeploymentIdByModel( - @Nonnull final String resourceGroup, @Nonnull final String modelName) { - return CACHE.getOrDefault(resourceGroup, new HashSet<>()).stream() - .filter(deployment -> isDeploymentOfModel(modelName, deployment)) + @Nonnull final String resourceGroup, @Nonnull final AiModel model) { + return cache.getOrDefault(resourceGroup, new HashSet<>()).stream() + .filter(deployment -> isDeploymentOfModel(model, deployment)) .findFirst() .map(AiDeployment::getId); } @@ -108,7 +109,7 @@ String getDeploymentIdByScenario( private Optional getDeploymentIdByScenario( @Nonnull final String resourceGroup, @Nonnull final String scenarioId) { - return CACHE.getOrDefault(resourceGroup, new HashSet<>()).stream() + return cache.getOrDefault(resourceGroup, new HashSet<>()).stream() .filter(deployment -> scenarioId.equals(deployment.getScenarioId())) .findFirst() .map(AiDeployment::getId); @@ -117,12 +118,12 @@ private Optional getDeploymentIdByScenario( /** * This exists because getBackendDetails() is broken * - * @param modelName The model name. + * @param targetModel The target model object. * @param deployment The deployment. * @return true if the deployment is of the model. */ - private static boolean isDeploymentOfModel( - @Nonnull final String modelName, @Nonnull final AiDeployment deployment) { + protected static boolean isDeploymentOfModel( + @Nonnull final AiModel targetModel, @Nonnull final AiDeployment deployment) { final var deploymentDetails = deployment.getDetails(); // The AI Core specification doesn't mention that this is nullable, but it can be. // Remove this check when the specification is fixed. @@ -144,9 +145,14 @@ private static boolean isDeploymentOfModel( if (detailsObject instanceof Map details && details.get("model") instanceof Map model - && model.get("name") instanceof String name) { - return modelName.equals(name); + && model.get("name") instanceof String name + && targetModel.name().equals(name)) { + // if target version is not specified (null), any version is accepted, otherwise they must + // match + return targetModel.version() == null + || Objects.equals(targetModel.version(), model.get("version")); } + return false; } } diff --git a/core/src/test/java/com/sap/ai/sdk/core/CacheTest.java b/core/src/test/java/com/sap/ai/sdk/core/CacheTest.java index 950b27af..bd907ed0 100644 --- a/core/src/test/java/com/sap/ai/sdk/core/CacheTest.java +++ b/core/src/test/java/com/sap/ai/sdk/core/CacheTest.java @@ -5,11 +5,19 @@ import static com.github.tomakehurst.wiremock.client.WireMock.get; import static com.github.tomakehurst.wiremock.client.WireMock.getRequestedFor; import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo; +import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import com.sap.ai.sdk.core.client.WireMockTestServer; +import com.sap.ai.sdk.core.client.model.AiDeployment; +import com.sap.ai.sdk.core.client.model.AiDeploymentDetails; +import com.sap.ai.sdk.core.client.model.AiDeploymentStatus; +import com.sap.ai.sdk.core.client.model.AiResourcesDetails; import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination; +import java.time.OffsetDateTime; +import java.util.Map; import java.util.NoSuchElementException; +import javax.annotation.Nonnull; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -56,10 +64,12 @@ void newDeployment() { String resourceGroup = "default"; stubGPT4(resourceGroup); - cacheUnderTest.getDeploymentIdByModel(client, resourceGroup, "gpt-4-32k"); + final AiModel gpt4 = createAiModel("gpt-4-32k", null); + + cacheUnderTest.getDeploymentIdByModel(client, resourceGroup, gpt4); wireMockServer.verify(1, getRequestedFor(urlPathEqualTo("/v2/lm/deployments"))); - cacheUnderTest.getDeploymentIdByModel(client, resourceGroup, "gpt-4-32k"); + cacheUnderTest.getDeploymentIdByModel(client, resourceGroup, gpt4); wireMockServer.verify(1, getRequestedFor(urlPathEqualTo("/v2/lm/deployments"))); } @@ -79,11 +89,13 @@ void newDeploymentAfterReset() { cacheUnderTest.resetCache(client, resourceGroup); stubGPT4(resourceGroup); - cacheUnderTest.getDeploymentIdByModel(client, resourceGroup, "gpt-4-32k"); + final AiModel gpt4 = createAiModel("gpt-4-32k", null); + + cacheUnderTest.getDeploymentIdByModel(client, resourceGroup, gpt4); // 1 reset empty and 1 cache miss wireMockServer.verify(2, getRequestedFor(urlPathEqualTo("/v2/lm/deployments"))); - cacheUnderTest.getDeploymentIdByModel(client, resourceGroup, "gpt-4-32k"); + cacheUnderTest.getDeploymentIdByModel(client, resourceGroup, gpt4); wireMockServer.verify(2, getRequestedFor(urlPathEqualTo("/v2/lm/deployments"))); } @@ -94,7 +106,9 @@ void resourceGroupIsolation() { stubGPT4(resourceGroupA); stubGPT4(resourceGroupB); - cacheUnderTest.getDeploymentIdByModel(client, resourceGroupA, "gpt-4-32k"); + final AiModel gpt4 = createAiModel("gpt-4-32k", null); + + cacheUnderTest.getDeploymentIdByModel(client, resourceGroupA, gpt4); wireMockServer.verify( 1, getRequestedFor(urlPathEqualTo("/v2/lm/deployments")) @@ -110,8 +124,9 @@ void exceptionDeploymentNotFound() { String resourceGroup = "default"; stubEmpty(resourceGroup); - assertThatThrownBy( - () -> cacheUnderTest.getDeploymentIdByModel(client, resourceGroup, "gpt-4-32k")) + final AiModel gpt4 = createAiModel("gpt-4-32k", null); + + assertThatThrownBy(() -> cacheUnderTest.getDeploymentIdByModel(client, resourceGroup, gpt4)) .isExactlyInstanceOf(NoSuchElementException.class) .hasMessageContaining("No running deployment found for model: gpt-4-32k"); } @@ -127,4 +142,48 @@ void resetCache() { new AiCoreService().withDestination(destination).reloadCachedDeployments(resourceGroup); wireMockServer.verify(2, getRequestedFor(urlPathEqualTo("/v2/lm/deployments"))); } + + @Test + public void isDeploymentOfModel() { + // Create a target model + final AiModel gpt4AnyVersion = createAiModel("gpt-4-32k", null); + final AiModel gpt4Version1 = createAiModel("gpt-4-32k", "1.0"); + final AiModel gpt4VersionLatest = createAiModel("gpt-4-32k", "latest"); + + // Create a deployment with a different model by version + final var model = Map.of("model", Map.of("name", "gpt-4-32k", "version", "latest")); + final var deployment = + AiDeployment.create() + .id("test-deployment") + .configurationId("test-configuration") + .status(AiDeploymentStatus.RUNNING) + .createdAt(OffsetDateTime.parse("2024-01-22T17:57:23+00:00")) + .modifiedAt(OffsetDateTime.parse("2024-02-08T08:41:23+00:00")); + deployment.setDetails(AiDeploymentDetails.create().resources(AiResourcesDetails.create())); + deployment.getDetails().getResources().setCustomField("backend_details", model); + + // Check if the deployment is of the target model + assertThat(DeploymentCache.isDeploymentOfModel(gpt4AnyVersion, deployment)).isTrue(); + assertThat(DeploymentCache.isDeploymentOfModel(gpt4Version1, deployment)).isFalse(); + assertThat(DeploymentCache.isDeploymentOfModel(gpt4VersionLatest, deployment)).isTrue(); + } + + static AiModel createAiModel(String name, String version) { + return new AiModel() { + @Nonnull + @Override + public String name() { + return name; + } + + @Override + public String version() { + return version; + } + + public String toString() { + return name; + } + }; + } } diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClient.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClient.java index 25a5d922..35f98610 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClient.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClient.java @@ -60,7 +60,7 @@ public final class OpenAiClient { public static OpenAiClient forModel(@Nonnull final OpenAiModel foundationModel) { final var destination = new AiCoreService() - .forDeploymentByModel(foundationModel.model()) + .forDeploymentByModel(foundationModel) .withResourceGroup("default") .destination(); final var client = new OpenAiClient(destination); diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiModel.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiModel.java index c14ac402..ec50a205 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiModel.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiModel.java @@ -1,47 +1,61 @@ package com.sap.ai.sdk.foundationmodels.openai; +import com.sap.ai.sdk.core.AiModel; import javax.annotation.Nonnull; +import javax.annotation.Nullable; /** - * Available OpenAI models. + * OpenAI models that are available in AI Core. * - * @param model a deployed OpenAI model + * @param name The name of the model. + * @param version The version of the model (optional). */ -public record OpenAiModel(@Nonnull String model) { +public record OpenAiModel(@Nonnull String name, @Nullable String version) implements AiModel { /** Azure OpenAI dall-e-3 image generate model */ - public static final OpenAiModel DALL_E_3 = new OpenAiModel("dall-e-3"); + public static final OpenAiModel DALL_E_3 = new OpenAiModel("dall-e-3", null); /** Azure OpenAI GPT-3.5 Turbo chat completions model */ - public static final OpenAiModel GPT_35_TURBO = new OpenAiModel("gpt-35-turbo"); + public static final OpenAiModel GPT_35_TURBO = new OpenAiModel("gpt-35-turbo", null); /** Azure OpenAI GPT-3.5 Turbo chat completions model */ - public static final OpenAiModel GPT_35_TURBO_1025 = new OpenAiModel("gpt-35-turbo-0125"); + public static final OpenAiModel GPT_35_TURBO_1025 = new OpenAiModel("gpt-35-turbo-0125", null); /** Azure OpenAI GPT-3.5 Turbo chat completions model */ - public static final OpenAiModel GPT_35_TURBO_16K = new OpenAiModel("gpt-35-turbo-16k"); + public static final OpenAiModel GPT_35_TURBO_16K = new OpenAiModel("gpt-35-turbo-16k", null); /** Azure OpenAI GPT-4 chat completions model */ - public static final OpenAiModel GPT_4 = new OpenAiModel("gpt-4"); + public static final OpenAiModel GPT_4 = new OpenAiModel("gpt-4", null); /** Azure OpenAI GPT-4-32k chat completions model */ - public static final OpenAiModel GPT_4_32K = new OpenAiModel("gpt-4-32k"); + public static final OpenAiModel GPT_4_32K = new OpenAiModel("gpt-4-32k", null); /** Azure OpenAI GPT-4o chat completions model */ - public static final OpenAiModel GPT_4O = new OpenAiModel("gpt-4o"); + public static final OpenAiModel GPT_4O = new OpenAiModel("gpt-4o", null); /** Azure OpenAI GPT-4o Mini chat completions model */ - public static final OpenAiModel GPT_4O_MINI = new OpenAiModel("gpt-4o-mini"); + public static final OpenAiModel GPT_4O_MINI = new OpenAiModel("gpt-4o-mini", null); /** Azure OpenAI Text Embedding 3 Large model */ public static final OpenAiModel TEXT_EMBEDDING_3_LARGE = - new OpenAiModel("text-embedding-3-large"); + new OpenAiModel("text-embedding-3-large", null); /** Azure OpenAI Text Embedding 3 Small model */ public static final OpenAiModel TEXT_EMBEDDING_3_SMALL = - new OpenAiModel("text-embedding-3-small"); + new OpenAiModel("text-embedding-3-small", null); /** Azure OpenAI Text Embedding ADA 002 model */ public static final OpenAiModel TEXT_EMBEDDING_ADA_002 = - new OpenAiModel("text-embedding-ada-002"); + new OpenAiModel("text-embedding-ada-002", null); + + /** + * Create a new instance of OpenAiModel with the provided version. + * + * @param version The version of the model. + * @return The new instance of OpenAiModel. + */ + @Nonnull + public OpenAiModel withVersion(@Nonnull final String version) { + return new OpenAiModel(name, version); + } } diff --git a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/DeploymentController.java b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/DeploymentController.java index 99abb794..b944309e 100644 --- a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/DeploymentController.java +++ b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/DeploymentController.java @@ -142,12 +142,12 @@ public AiDeploymentCreationResponse createConfigAndDeploy(final OpenAiModel mode // Create a configuration final var modelNameParameter = - AiParameterArgumentBinding.create().key("model").value(model.model()); + AiParameterArgumentBinding.create().key("model").value(model.name()); final var modelVersion = AiParameterArgumentBinding.create().key("modelVersion").value("latest"); final var configurationBaseData = AiConfigurationBaseData.create() - .name(model.model()) + .name(model.name()) .executableId("azure-openai") .scenarioId("foundation-models") .addParameterBindingsItem(modelNameParameter) diff --git a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/ScenarioTest.java b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/ScenarioTest.java index bc154d32..a13bc6cb 100644 --- a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/ScenarioTest.java +++ b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/ScenarioTest.java @@ -34,7 +34,7 @@ public void openAiModelAvailability() { List declaredOpenAiModelList = new ArrayList<>(); for (Field field : declaredFields) { if (field.getType().equals(OpenAiModel.class)) { - declaredOpenAiModelList.add(((OpenAiModel) field.get(null)).model()); + declaredOpenAiModelList.add(((OpenAiModel) field.get(null)).name()); } }