Skip to content

Commit

Permalink
Openai model version (#99)
Browse files Browse the repository at this point in the history
* intermediate

* ready-to-discuss state

* Fix merge; Format

* Add javadoc

* Fix PMD warnings

* Initial change of draft

* Initial API change

* Fix merge conflicts

* Update readme

* Fix PMD

* Formatting

* Apply suggestions from code review

Co-authored-by: Charles Dubois <[email protected]>

* Defuse functions

* Update core/src/main/java/com/sap/ai/sdk/core/AiCoreServiceWithDeployment.java

Co-authored-by: Charles Dubois <[email protected]>

* Make service binding accessor changeable, e.g. for future testing

* Rename methods and types; Fix merge

* Formatting

* Fix code, tests are working

* Fix header provisioning

* Add tests; Move stuff around

* make base class extensible

* work in progress

* refine work in progress

* Fix tests

* Update test

* Add annotations

* Format

* Format; JavaDoc

* Restructure code

* Fix PMD

* Fix PMD

* Fix test

* Fix test

* Added model version filtering

* Fix model version filtering

* Apply review comments

* Apply review comments

* Add assertion on AI-Client-Type header

* Minor JavaDoc fix

* Formatting

* Add unit test for model comparison

* Improve test coverage

* Add missing JavaDoc comments

* Add missing JavaDoc comments

* Formatting

* Add missing JavaDoc comments

* Improve coding practice

* Remove duplicate anonymous classes

* Accept review changes

* Fix merge conflicts

* Fix naming

---------

Co-authored-by: Alexander Dümont <[email protected]>
Co-authored-by: SAP Cloud SDK Bot <[email protected]>
Co-authored-by: Alexander Dümont <[email protected]>
Co-authored-by: Charles Dubois <[email protected]>
Co-authored-by: Roshin Rajan Panackal <[email protected]>
  • Loading branch information
6 people authored Oct 18, 2024
1 parent e0078c8 commit 26bc938
Show file tree
Hide file tree
Showing 8 changed files with 152 additions and 50 deletions.
11 changes: 5 additions & 6 deletions core/src/main/java/com/sap/ai/sdk/core/AiCoreService.java
Original file line number Diff line number Diff line change
Expand Up @@ -119,19 +119,18 @@ public AiCoreDeployment forDeployment(@Nonnull final String deploymentId) {
}

/**
* Set a specific deployment by model name. If there are multiple deployments of the same model,
* the first one is returned.
* Set a specific deployment by model. If there are multiple deployments of the same model, the
* first one is returned.
*
* @param modelName The model name to be used for AI Core service calls.
* @param model The model to be used for AI Core service calls.
* @return A new instance of the AI Core Deployment.
* @throws NoSuchElementException if no running deployment is found for the model.
*/
@Nonnull
public AiCoreDeployment forDeploymentByModel(@Nonnull final String modelName)
public AiCoreDeployment forDeploymentByModel(@Nonnull final AiModel model)
throws NoSuchElementException {
return new AiCoreDeployment(
this,
() -> DEPLOYMENT_CACHE.getDeploymentIdByModel(this.client(), resourceGroup, modelName));
this, () -> DEPLOYMENT_CACHE.getDeploymentIdByModel(this.client(), resourceGroup, model));
}

/**
Expand Down
24 changes: 24 additions & 0 deletions core/src/main/java/com/sap/ai/sdk/core/AiModel.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package com.sap.ai.sdk.core;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

/**
* An interface defining essential attributes of an AI model.
*
* <p>A model is identified by its name and, optionally, by a version. Implementations are passed
* to the deployment lookup to select a deployment serving the described model.
*/
public interface AiModel {

/**
* Get the model's name.
*
* <p>The deployment lookup compares this value for equality with the model name reported by a
* deployment.
*
* @return The name of the model, never null.
*/
@Nonnull
String name();

/**
* Get the model's version.
*
* <p>A {@code null} version means that no specific version is requested; the deployment lookup
* then accepts a deployment of any version of the model.
*
* @return The version of the model, or null if not specified.
*/
@Nullable
String version();
}
44 changes: 25 additions & 19 deletions core/src/main/java/com/sap/ai/sdk/core/DeploymentCache.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import java.util.HashSet;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
Expand All @@ -21,7 +22,7 @@
class DeploymentCache {

/**
* Cache for deployments. The key is the resource group and the value is the set of deployments
* loaded for that resource group.
*/
private final Map<String, Set<AiDeployment>> CACHE = new ConcurrentHashMap<>();
private final Map<String, Set<AiDeployment>> cache = new ConcurrentHashMap<>();

/**
* Remove all entries from the cache then load all deployments into the cache.
Expand All @@ -32,48 +33,48 @@ class DeploymentCache {
* @param resourceGroup the resource group, usually "default".
*/
void resetCache(@Nonnull final ApiClient client, @Nonnull final String resourceGroup) {
CACHE.remove(resourceGroup);
cache.remove(resourceGroup);
try {
final var deployments =
new HashSet<>(new DeploymentApi(client).query(resourceGroup).getResources());
CACHE.put(resourceGroup, deployments);
cache.put(resourceGroup, deployments);
} catch (final OpenApiRequestException e) {
log.error("Failed to load deployments into cache", e);
}
}

/**
* Get the deployment id from the foundation model name. If there are multiple deployments of the
* same model, the first one is returned.
* Get the deployment id from the foundation model object. If there are multiple deployments of
* the same model, the first one is returned.
*
* @param client the API client to maybe reset the cache if the deployment is not found.
* @param resourceGroup the resource group, usually "default".
* @param modelName the name of the foundation model.
* @param model the foundation model.
* @return the deployment id.
* @throws NoSuchElementException if no running deployment is found for the model.
*/
@Nonnull
String getDeploymentIdByModel(
@Nonnull final ApiClient client,
@Nonnull final String resourceGroup,
@Nonnull final String modelName)
@Nonnull final AiModel model)
throws NoSuchElementException {
return getDeploymentIdByModel(resourceGroup, modelName)
return getDeploymentIdByModel(resourceGroup, model)
.orElseGet(
() -> {
resetCache(client, resourceGroup);
return getDeploymentIdByModel(resourceGroup, modelName)
return getDeploymentIdByModel(resourceGroup, model)
.orElseThrow(
() ->
new NoSuchElementException(
"No running deployment found for model: " + modelName));
"No running deployment found for model: " + model));
});
}

private Optional<String> getDeploymentIdByModel(
@Nonnull final String resourceGroup, @Nonnull final String modelName) {
return CACHE.getOrDefault(resourceGroup, new HashSet<>()).stream()
.filter(deployment -> isDeploymentOfModel(modelName, deployment))
@Nonnull final String resourceGroup, @Nonnull final AiModel model) {
return cache.getOrDefault(resourceGroup, new HashSet<>()).stream()
.filter(deployment -> isDeploymentOfModel(model, deployment))
.findFirst()
.map(AiDeployment::getId);
}
Expand Down Expand Up @@ -108,7 +109,7 @@ String getDeploymentIdByScenario(

private Optional<String> getDeploymentIdByScenario(
@Nonnull final String resourceGroup, @Nonnull final String scenarioId) {
return CACHE.getOrDefault(resourceGroup, new HashSet<>()).stream()
return cache.getOrDefault(resourceGroup, new HashSet<>()).stream()
.filter(deployment -> scenarioId.equals(deployment.getScenarioId()))
.findFirst()
.map(AiDeployment::getId);
Expand All @@ -117,12 +118,12 @@ private Optional<String> getDeploymentIdByScenario(
/**
* This exists because getBackendDetails() is broken
*
* @param modelName The model name.
* @param targetModel The target model object.
* @param deployment The deployment.
* @return true if the deployment is of the model.
*/
private static boolean isDeploymentOfModel(
@Nonnull final String modelName, @Nonnull final AiDeployment deployment) {
protected static boolean isDeploymentOfModel(
@Nonnull final AiModel targetModel, @Nonnull final AiDeployment deployment) {
final var deploymentDetails = deployment.getDetails();
// The AI Core specification doesn't mention that this is nullable, but it can be.
// Remove this check when the specification is fixed.
Expand All @@ -144,9 +145,14 @@ private static boolean isDeploymentOfModel(

if (detailsObject instanceof Map<?, ?> details
&& details.get("model") instanceof Map<?, ?> model
&& model.get("name") instanceof String name) {
return modelName.equals(name);
&& model.get("name") instanceof String name
&& targetModel.name().equals(name)) {
// if target version is not specified (null), any version is accepted, otherwise they must
// match
return targetModel.version() == null
|| Objects.equals(targetModel.version(), model.get("version"));
}

return false;
}
}
73 changes: 66 additions & 7 deletions core/src/test/java/com/sap/ai/sdk/core/CacheTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,19 @@
import static com.github.tomakehurst.wiremock.client.WireMock.get;
import static com.github.tomakehurst.wiremock.client.WireMock.getRequestedFor;
import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

import com.sap.ai.sdk.core.client.WireMockTestServer;
import com.sap.ai.sdk.core.client.model.AiDeployment;
import com.sap.ai.sdk.core.client.model.AiDeploymentDetails;
import com.sap.ai.sdk.core.client.model.AiDeploymentStatus;
import com.sap.ai.sdk.core.client.model.AiResourcesDetails;
import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination;
import java.time.OffsetDateTime;
import java.util.Map;
import java.util.NoSuchElementException;
import javax.annotation.Nonnull;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

Expand Down Expand Up @@ -56,10 +64,12 @@ void newDeployment() {
String resourceGroup = "default";
stubGPT4(resourceGroup);

cacheUnderTest.getDeploymentIdByModel(client, resourceGroup, "gpt-4-32k");
final AiModel gpt4 = createAiModel("gpt-4-32k", null);

cacheUnderTest.getDeploymentIdByModel(client, resourceGroup, gpt4);
wireMockServer.verify(1, getRequestedFor(urlPathEqualTo("/v2/lm/deployments")));

cacheUnderTest.getDeploymentIdByModel(client, resourceGroup, "gpt-4-32k");
cacheUnderTest.getDeploymentIdByModel(client, resourceGroup, gpt4);
wireMockServer.verify(1, getRequestedFor(urlPathEqualTo("/v2/lm/deployments")));
}

Expand All @@ -79,11 +89,13 @@ void newDeploymentAfterReset() {
cacheUnderTest.resetCache(client, resourceGroup);
stubGPT4(resourceGroup);

cacheUnderTest.getDeploymentIdByModel(client, resourceGroup, "gpt-4-32k");
final AiModel gpt4 = createAiModel("gpt-4-32k", null);

cacheUnderTest.getDeploymentIdByModel(client, resourceGroup, gpt4);
// 1 reset empty and 1 cache miss
wireMockServer.verify(2, getRequestedFor(urlPathEqualTo("/v2/lm/deployments")));

cacheUnderTest.getDeploymentIdByModel(client, resourceGroup, "gpt-4-32k");
cacheUnderTest.getDeploymentIdByModel(client, resourceGroup, gpt4);
wireMockServer.verify(2, getRequestedFor(urlPathEqualTo("/v2/lm/deployments")));
}

Expand All @@ -94,7 +106,9 @@ void resourceGroupIsolation() {
stubGPT4(resourceGroupA);
stubGPT4(resourceGroupB);

cacheUnderTest.getDeploymentIdByModel(client, resourceGroupA, "gpt-4-32k");
final AiModel gpt4 = createAiModel("gpt-4-32k", null);

cacheUnderTest.getDeploymentIdByModel(client, resourceGroupA, gpt4);
wireMockServer.verify(
1,
getRequestedFor(urlPathEqualTo("/v2/lm/deployments"))
Expand All @@ -110,8 +124,9 @@ void exceptionDeploymentNotFound() {
String resourceGroup = "default";
stubEmpty(resourceGroup);

assertThatThrownBy(
() -> cacheUnderTest.getDeploymentIdByModel(client, resourceGroup, "gpt-4-32k"))
final AiModel gpt4 = createAiModel("gpt-4-32k", null);

assertThatThrownBy(() -> cacheUnderTest.getDeploymentIdByModel(client, resourceGroup, gpt4))
.isExactlyInstanceOf(NoSuchElementException.class)
.hasMessageContaining("No running deployment found for model: gpt-4-32k");
}
Expand All @@ -127,4 +142,48 @@ void resetCache() {
new AiCoreService().withDestination(destination).reloadCachedDeployments(resourceGroup);
wireMockServer.verify(2, getRequestedFor(urlPathEqualTo("/v2/lm/deployments")));
}

@Test
void isDeploymentOfModel() {
  // Three targets sharing the same model name: one without a version constraint, one pinned to a
  // version the deployment does not have, and one pinned to the version the deployment has.
  final AiModel gpt4AnyVersion = createAiModel("gpt-4-32k", null);
  final AiModel gpt4Version1 = createAiModel("gpt-4-32k", "1.0");
  final AiModel gpt4VersionLatest = createAiModel("gpt-4-32k", "latest");

  // Create a deployment whose backend details report model "gpt-4-32k" in version "latest"
  final var backendDetails =
      Map.of("model", Map.of("name", "gpt-4-32k", "version", "latest"));
  final var deployment =
      AiDeployment.create()
          .id("test-deployment")
          .configurationId("test-configuration")
          .status(AiDeploymentStatus.RUNNING)
          .createdAt(OffsetDateTime.parse("2024-01-22T17:57:23+00:00"))
          .modifiedAt(OffsetDateTime.parse("2024-02-08T08:41:23+00:00"));
  deployment.setDetails(AiDeploymentDetails.create().resources(AiResourcesDetails.create()));
  deployment.getDetails().getResources().setCustomField("backend_details", backendDetails);

  // No version requested: any version of the same model matches
  assertThat(DeploymentCache.isDeploymentOfModel(gpt4AnyVersion, deployment)).isTrue();
  // Version requested but different from the deployed one: no match
  assertThat(DeploymentCache.isDeploymentOfModel(gpt4Version1, deployment)).isFalse();
  // Version requested and equal to the deployed one: match
  assertThat(DeploymentCache.isDeploymentOfModel(gpt4VersionLatest, deployment)).isTrue();
}

/**
 * Create a minimal {@link AiModel} for tests.
 *
 * @param name The model name returned by {@code name()}.
 * @param version The model version returned by {@code version()}; may be null to express that no
 *     specific version is requested.
 * @return An anonymous {@link AiModel} backed by the given values.
 */
@Nonnull
static AiModel createAiModel(@Nonnull final String name, final String version) {
  return new AiModel() {
    @Nonnull
    @Override
    public String name() {
      return name;
    }

    @Override
    public String version() {
      return version;
    }

    // Only the name is printed: the "not found" test asserts on an exception message that is
    // built by concatenating the model object, and expects it to end with the plain model name.
    @Override
    public String toString() {
      return name;
    }
  };
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public final class OpenAiClient {
public static OpenAiClient forModel(@Nonnull final OpenAiModel foundationModel) {
final var destination =
new AiCoreService()
.forDeploymentByModel(foundationModel.model())
.forDeploymentByModel(foundationModel)
.withResourceGroup("default")
.destination();
final var client = new OpenAiClient(destination);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,47 +1,61 @@
package com.sap.ai.sdk.foundationmodels.openai;

import com.sap.ai.sdk.core.AiModel;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

/**
* Available OpenAI models.
* OpenAI models that are available in AI Core.
*
* @param model a deployed OpenAI model
* @param name The name of the model.
* @param version The version of the model (optional).
*/
public record OpenAiModel(@Nonnull String model) {
public record OpenAiModel(@Nonnull String name, @Nullable String version) implements AiModel {

/** Azure OpenAI dall-e-3 image generation model */
public static final OpenAiModel DALL_E_3 = new OpenAiModel("dall-e-3");
public static final OpenAiModel DALL_E_3 = new OpenAiModel("dall-e-3", null);

/** Azure OpenAI GPT-3.5 Turbo chat completions model */
public static final OpenAiModel GPT_35_TURBO = new OpenAiModel("gpt-35-turbo");
public static final OpenAiModel GPT_35_TURBO = new OpenAiModel("gpt-35-turbo", null);

/** Azure OpenAI GPT-3.5 Turbo 0125 chat completions model */
public static final OpenAiModel GPT_35_TURBO_1025 = new OpenAiModel("gpt-35-turbo-0125");
public static final OpenAiModel GPT_35_TURBO_1025 = new OpenAiModel("gpt-35-turbo-0125", null);

/** Azure OpenAI GPT-3.5 Turbo 16k chat completions model */
public static final OpenAiModel GPT_35_TURBO_16K = new OpenAiModel("gpt-35-turbo-16k");
public static final OpenAiModel GPT_35_TURBO_16K = new OpenAiModel("gpt-35-turbo-16k", null);

/** Azure OpenAI GPT-4 chat completions model */
public static final OpenAiModel GPT_4 = new OpenAiModel("gpt-4");
public static final OpenAiModel GPT_4 = new OpenAiModel("gpt-4", null);

/** Azure OpenAI GPT-4-32k chat completions model */
public static final OpenAiModel GPT_4_32K = new OpenAiModel("gpt-4-32k");
public static final OpenAiModel GPT_4_32K = new OpenAiModel("gpt-4-32k", null);

/** Azure OpenAI GPT-4o chat completions model */
public static final OpenAiModel GPT_4O = new OpenAiModel("gpt-4o");
public static final OpenAiModel GPT_4O = new OpenAiModel("gpt-4o", null);

/** Azure OpenAI GPT-4o Mini chat completions model */
public static final OpenAiModel GPT_4O_MINI = new OpenAiModel("gpt-4o-mini");
public static final OpenAiModel GPT_4O_MINI = new OpenAiModel("gpt-4o-mini", null);

/** Azure OpenAI Text Embedding 3 Large model */
public static final OpenAiModel TEXT_EMBEDDING_3_LARGE =
new OpenAiModel("text-embedding-3-large");
new OpenAiModel("text-embedding-3-large", null);

/** Azure OpenAI Text Embedding 3 Small model */
public static final OpenAiModel TEXT_EMBEDDING_3_SMALL =
new OpenAiModel("text-embedding-3-small");
new OpenAiModel("text-embedding-3-small", null);

/** Azure OpenAI Text Embedding ADA 002 model */
public static final OpenAiModel TEXT_EMBEDDING_ADA_002 =
new OpenAiModel("text-embedding-ada-002");
new OpenAiModel("text-embedding-ada-002", null);

/**
 * Derive a copy of this model that carries the given version.
 *
 * <p>This instance is left untouched; the returned model has the same name as this one.
 *
 * @param version The version the returned model should carry.
 * @return A new OpenAiModel with this model's name and the given version.
 */
@Nonnull
public OpenAiModel withVersion(@Nonnull final String version) {
  final String modelName = this.name;
  return new OpenAiModel(modelName, version);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,12 @@ public AiDeploymentCreationResponse createConfigAndDeploy(final OpenAiModel mode

// Create a configuration
final var modelNameParameter =
AiParameterArgumentBinding.create().key("model").value(model.model());
AiParameterArgumentBinding.create().key("model").value(model.name());
final var modelVersion =
AiParameterArgumentBinding.create().key("modelVersion").value("latest");
final var configurationBaseData =
AiConfigurationBaseData.create()
.name(model.model())
.name(model.name())
.executableId("azure-openai")
.scenarioId("foundation-models")
.addParameterBindingsItem(modelNameParameter)
Expand Down
Loading

0 comments on commit 26bc938

Please sign in to comment.