diff --git a/enterprise/cmd/cody-gateway/internal/httpapi/embeddings/BUILD.bazel b/enterprise/cmd/cody-gateway/internal/httpapi/embeddings/BUILD.bazel
index 67e5ca86908f..8463cd28bb32 100644
--- a/enterprise/cmd/cody-gateway/internal/httpapi/embeddings/BUILD.bazel
+++ b/enterprise/cmd/cody-gateway/internal/httpapi/embeddings/BUILD.bazel
@@ -1,3 +1,4 @@
+load("//dev:go_defs.bzl", "go_test")
 load("@io_bazel_rules_go//go:def.bzl", "go_library")
 
 go_library(
@@ -24,3 +25,13 @@ go_library(
         "@org_golang_x_exp//slices",
     ],
 )
+
+go_test(
+    name = "embeddings_test",
+    srcs = ["openai_test.go"],
+    embed = [":embeddings"],
+    deps = [
+        "//internal/codygateway",
+        "@com_github_stretchr_testify//require",
+    ],
+)
diff --git a/enterprise/cmd/cody-gateway/internal/httpapi/embeddings/openai.go b/enterprise/cmd/cody-gateway/internal/httpapi/embeddings/openai.go
index 02321cef344a..57a94b93d483 100644
--- a/enterprise/cmd/cody-gateway/internal/httpapi/embeddings/openai.go
+++ b/enterprise/cmd/cody-gateway/internal/httpapi/embeddings/openai.go
@@ -27,6 +27,14 @@ type openaiClient struct {
 const apiURL = "https://api.openai.com/v1/embeddings"
 
 func (c *openaiClient) GenerateEmbeddings(ctx context.Context, input codygateway.EmbeddingsRequest) (*codygateway.EmbeddingsResponse, int, error) {
+	for _, s := range input.Input {
+		if s == "" {
+			// The OpenAI API will return an error if any of the strings in input.Input is an empty string,
+			// so fail fast to avoid making tons of retryable requests.
+			return nil, 0, response.NewHTTPStatusCodeError(http.StatusBadRequest, errors.New("cannot generate embeddings for an empty string"))
+		}
+	}
+
 	openAIModel, ok := openAIModelMappings[input.Model]
 	if !ok {
 		return nil, 0, response.NewHTTPStatusCodeError(http.StatusBadRequest, errors.Newf("no OpenAI model found for %q", input.Model))
diff --git a/enterprise/cmd/cody-gateway/internal/httpapi/embeddings/openai_test.go b/enterprise/cmd/cody-gateway/internal/httpapi/embeddings/openai_test.go
new file mode 100644
index 000000000000..beaefddf849c
--- /dev/null
+++ b/enterprise/cmd/cody-gateway/internal/httpapi/embeddings/openai_test.go
@@ -0,0 +1,19 @@
+package embeddings
+
+import (
+	"context"
+	"testing"
+
+	"github.com/sourcegraph/sourcegraph/internal/codygateway"
+	"github.com/stretchr/testify/require"
+)
+
+func TestOpenAI(t *testing.T) {
+	t.Run("errors on empty embedding string", func(t *testing.T) {
+		client := NewOpenAIClient("")
+		_, _, err := client.GenerateEmbeddings(context.Background(), codygateway.EmbeddingsRequest{
+			Input: []string{"a", ""}, // empty string is invalid
+		})
+		require.ErrorContains(t, err, "empty string")
+	})
+}
diff --git a/enterprise/internal/embeddings/embed/client/openai/BUILD.bazel b/enterprise/internal/embeddings/embed/client/openai/BUILD.bazel
index cd641477607f..2c8b196285e8 100644
--- a/enterprise/internal/embeddings/embed/client/openai/BUILD.bazel
+++ b/enterprise/internal/embeddings/embed/client/openai/BUILD.bazel
@@ -1,3 +1,4 @@
+load("//dev:go_defs.bzl", "go_test")
 load("@io_bazel_rules_go//go:def.bzl", "go_library")
 
 go_library(
@@ -11,3 +12,13 @@ go_library(
         "//lib/errors",
     ],
 )
+
+go_test(
+    name = "openai_test",
+    srcs = ["client_test.go"],
+    embed = [":openai"],
+    deps = [
+        "//internal/conf/conftypes",
+        "@com_github_stretchr_testify//require",
+    ],
+)
diff --git a/enterprise/internal/embeddings/embed/client/openai/client.go b/enterprise/internal/embeddings/embed/client/openai/client.go
index 7512bb110b48..8dc3965a8438 100644
--- a/enterprise/internal/embeddings/embed/client/openai/client.go
+++ b/enterprise/internal/embeddings/embed/client/openai/client.go
@@ -48,6 +48,14 @@ func (c *openaiEmbeddingsClient) GetModelIdentifier() string {
 // In case of failure, it retries the embedding procedure up to maxRetries. This due to the OpenAI API which
 // often hangs up when downloading large embedding responses.
 func (c *openaiEmbeddingsClient) GetEmbeddingsWithRetries(ctx context.Context, texts []string, maxRetries int) ([]float32, error) {
+	for _, text := range texts {
+		if text == "" {
+			// The OpenAI API will return an error if any of the strings in texts is an empty string,
+			// so fail fast to avoid making tons of retryable requests.
+			return nil, errors.New("cannot generate embeddings for an empty string")
+		}
+	}
+
 	embeddings, err := c.getEmbeddings(ctx, texts)
 	if err == nil {
 		return embeddings, nil
diff --git a/enterprise/internal/embeddings/embed/client/openai/client_test.go b/enterprise/internal/embeddings/embed/client/openai/client_test.go
new file mode 100644
index 000000000000..95c7d8b0ef0e
--- /dev/null
+++ b/enterprise/internal/embeddings/embed/client/openai/client_test.go
@@ -0,0 +1,18 @@
+package openai
+
+import (
+	"context"
+	"testing"
+
+	"github.com/sourcegraph/sourcegraph/internal/conf/conftypes"
+	"github.com/stretchr/testify/require"
+)
+
+func TestOpenAI(t *testing.T) {
+	t.Run("errors on empty embedding string", func(t *testing.T) {
+		client := NewClient(&conftypes.EmbeddingsConfig{})
+		invalidTexts := []string{"a", ""} // empty string is invalid
+		_, err := client.GetEmbeddingsWithRetries(context.Background(), invalidTexts, 10)
+		require.ErrorContains(t, err, "empty string")
+	})
+}