From 9cc09d960528a496e129186b1715cfd6f923b6ac Mon Sep 17 00:00:00 2001 From: Alex Creasy Date: Thu, 3 Oct 2024 21:02:03 +0100 Subject: [PATCH] Implements pagination support for GetAll style endpoints (#429) Signed-off-by: Alex Creasy --- clients/ui/bff/README.md | 22 +++++++- .../internal/api/model_versions_handler.go | 2 +- .../internal/api/registered_models_handler.go | 15 +++--- clients/ui/bff/internal/data/helpers.go | 38 ++++++++++++++ clients/ui/bff/internal/data/model_version.go | 6 +-- .../bff/internal/data/model_version_test.go | 28 +++++++++- .../ui/bff/internal/data/registered_model.go | 12 ++--- .../internal/data/registered_model_test.go | 52 ++++++++++++++++++- .../mocks/model_registry_client_mock.go | 10 ++-- clients/ui/bff/internal/mocks/types_mock.go | 13 +++++ 10 files changed, 171 insertions(+), 27 deletions(-) create mode 100644 clients/ui/bff/internal/data/helpers.go diff --git a/clients/ui/bff/README.md b/clients/ui/bff/README.md index 37028090..4c707bbb 100644 --- a/clients/ui/bff/README.md +++ b/clients/ui/bff/README.md @@ -188,4 +188,24 @@ curl -i -X POST "http://localhost:4000/api/v1/model_registry/model-registry/mode "state": "LIVE", "artifactType": "TYPE_ONE" }}' -``` \ No newline at end of file +``` + +### Pagination +The following query parameters are supported by "Get All" style endpoints to control pagination. + +| Parameter Name | Description | +|----------------|-----------------------------------------------------------------------------------------------------------| +| pageSize | Number of entities in each page | +| orderBy | Specifies the order by criteria for listing entities. Available values: CREATE_TIME, LAST_UPDATE_TIME, ID | +| sortOrder | Specifies the sort order for listing entities. Available values: ASC, DESC. Default: ASC | +| nextPageToken | Token to use to retrieve next page of results. | + +### Sample local calls +``` +# Get with a page size of 5 getting a specific page. +curl -i "http://localhost:4000/api/v1/model_registry/model-registry/registered_models?pageSize=5&nextPageToken=CAEQARoCCAE" +``` +``` +# Get with a page size of 5, order by last update time in descending order. +curl -i "http://localhost:4000/api/v1/model_registry/model-registry/registered_models?pageSize=5&orderBy=LAST_UPDATE_TIME&sortOrder=DESC" +``` diff --git a/clients/ui/bff/internal/api/model_versions_handler.go b/clients/ui/bff/internal/api/model_versions_handler.go index 7b979267..88b29f8a 100644 --- a/clients/ui/bff/internal/api/model_versions_handler.go +++ b/clients/ui/bff/internal/api/model_versions_handler.go @@ -157,7 +157,7 @@ func (app *App) GetAllModelArtifactsByModelVersionHandler(w http.ResponseWriter, return } - data, err := app.modelRegistryClient.GetModelArtifactsByModelVersion(client, ps.ByName(ModelVersionId)) + data, err := app.modelRegistryClient.GetModelArtifactsByModelVersion(client, ps.ByName(ModelVersionId), r.URL.Query()) if err != nil { app.serverErrorResponse(w, r, err) return diff --git a/clients/ui/bff/internal/api/registered_models_handler.go b/clients/ui/bff/internal/api/registered_models_handler.go index ffff7cdc..e75d121f 100644 --- a/clients/ui/bff/internal/api/registered_models_handler.go +++ b/clients/ui/bff/internal/api/registered_models_handler.go @@ -4,27 +4,25 @@ import ( "encoding/json" "errors" "fmt" + "github.com/julienschmidt/httprouter" + "github.com/kubeflow/model-registry/pkg/openapi" "github.com/kubeflow/model-registry/ui/bff/internal/integrations" "github.com/kubeflow/model-registry/ui/bff/internal/validation" "net/http" - - "github.com/julienschmidt/httprouter" - "github.com/kubeflow/model-registry/pkg/openapi" ) type RegisteredModelEnvelope Envelope[*openapi.RegisteredModel, None] type RegisteredModelListEnvelope Envelope[*openapi.RegisteredModelList, None] type RegisteredModelUpdateEnvelope Envelope[*openapi.RegisteredModelUpdate, None] -func (app *App) GetAllRegisteredModelsHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - //TODO (ederign) implement pagination +func (app *App) GetAllRegisteredModelsHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface) if !ok { app.serverErrorResponse(w, r, errors.New("REST client not found")) return } - modelList, err := app.modelRegistryClient.GetAllRegisteredModels(client) + modelList, err := app.modelRegistryClient.GetAllRegisteredModels(client, r.URL.Query()) if err != nil { app.serverErrorResponse(w, r, err) return @@ -40,7 +38,7 @@ func (app *App) GetAllRegisteredModelsHandler(w http.ResponseWriter, r *http.Req } } -func (app *App) CreateRegisteredModelHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { +func (app *App) CreateRegisteredModelHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface) if !ok { app.serverErrorResponse(w, r, errors.New("REST client not found")) @@ -173,14 +171,13 @@ func (app *App) UpdateRegisteredModelHandler(w http.ResponseWriter, r *http.Requ } func (app *App) GetAllModelVersionsForRegisteredModelHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - //TODO (acreasy) implement pagination client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface) if !ok { app.serverErrorResponse(w, r, errors.New("REST client not found")) return } - versionList, err := app.modelRegistryClient.GetAllModelVersions(client, ps.ByName(RegisteredModelId)) + versionList, err := app.modelRegistryClient.GetAllModelVersions(client, ps.ByName(RegisteredModelId), r.URL.Query()) if err != nil { app.serverErrorResponse(w, r, err) diff --git a/clients/ui/bff/internal/data/helpers.go b/clients/ui/bff/internal/data/helpers.go new file mode 100644 index 00000000..8c19f2c2 --- /dev/null +++ b/clients/ui/bff/internal/data/helpers.go @@ -0,0 +1,38 @@ +package data + +import ( + "fmt" + "net/url" +) + +func FilterPageValues(values url.Values) url.Values { + result := url.Values{} + + if v := values.Get("pageSize"); v != "" { + result.Set("pageSize", v) + } + if v := values.Get("orderBy"); v != "" { + result.Set("orderBy", v) + } + if v := values.Get("sortOrder"); v != "" { + result.Set("sortOrder", v) + } + if v := values.Get("nextPageToken"); v != "" { + result.Set("nextPageToken", v) + } + + return result +} + +func UrlWithParams(url string, values url.Values) string { + queryString := values.Encode() + if queryString == "" { + return url + } + return fmt.Sprintf("%s?%s", url, queryString) +} + +func UrlWithPageParams(url string, values url.Values) string { + pageValues := FilterPageValues(values) + return UrlWithParams(url, pageValues) +} diff --git a/clients/ui/bff/internal/data/model_version.go b/clients/ui/bff/internal/data/model_version.go index c84c08ef..c149a177 100644 --- a/clients/ui/bff/internal/data/model_version.go +++ b/clients/ui/bff/internal/data/model_version.go @@ -16,7 +16,7 @@ type ModelVersionInterface interface { GetModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelVersion, error) CreateModelVersion(client integrations.HTTPClientInterface, jsonData []byte) (*openapi.ModelVersion, error) UpdateModelVersion(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelVersion, error) - GetModelArtifactsByModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelArtifactList, error) + GetModelArtifactsByModelVersion(client integrations.HTTPClientInterface, id string, pageValues url.Values) (*openapi.ModelArtifactList, error) CreateModelArtifactByModelVersion(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelArtifact, error) } @@ -79,14 +79,14 @@ func (v ModelVersion) UpdateModelVersion(client integrations.HTTPClientInterface return &model, nil } -func (v ModelVersion) GetModelArtifactsByModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelArtifactList, error) { +func (v ModelVersion) GetModelArtifactsByModelVersion(client integrations.HTTPClientInterface, id string, pageValues url.Values) (*openapi.ModelArtifactList, error) { path, err := url.JoinPath(modelVersionPath, id, artifactsByModelVersionPath) if err != nil { return nil, err } - responseData, err := client.GET(path) + responseData, err := client.GET(UrlWithPageParams(path, pageValues)) if err != nil { return nil, fmt.Errorf("error fetching model version artifacts: %w", err) } diff --git a/clients/ui/bff/internal/data/model_version_test.go b/clients/ui/bff/internal/data/model_version_test.go index 1f842042..b2ad30d9 100644 --- a/clients/ui/bff/internal/data/model_version_test.go +++ b/clients/ui/bff/internal/data/model_version_test.go @@ -2,6 +2,7 @@ package data import ( "encoding/json" + "fmt" "net/http" "net/url" "testing" @@ -106,7 +107,7 @@ func TestGetModelArtifactsByModelVersion(t *testing.T) { mockClient := new(mocks.MockHTTPClient) mockClient.On(http.MethodGet, path, mock.Anything).Return(mockData, nil) - actual, err := modelVersion.GetModelArtifactsByModelVersion(mockClient, "1") + actual, err := modelVersion.GetModelArtifactsByModelVersion(mockClient, "1", nil) assert.NoError(t, err) assert.NotNil(t, actual) @@ -116,6 +117,31 @@ func TestGetModelArtifactsByModelVersion(t *testing.T) { assert.Equal(t, len(expected.Items), len(actual.Items)) } +func TestGetModelArtifactsByModelVersionWithPageParams(t *testing.T) { + gofakeit.Seed(0) //nolint:errcheck + + pageValues := mocks.GenerateMockPageValues() + expected := mocks.GenerateMockModelArtifactList() + + mockData, err := json.Marshal(expected) + assert.NoError(t, err) + + modelVersion := ModelVersion{} + + path, err := url.JoinPath(modelVersionPath, "1", artifactsByModelVersionPath) + assert.NoError(t, err) + reqUrl := fmt.Sprintf("%s?%s", path, pageValues.Encode()) + + mockClient := new(mocks.MockHTTPClient) + mockClient.On(http.MethodGet, reqUrl, mock.Anything).Return(mockData, nil) + + actual, err := modelVersion.GetModelArtifactsByModelVersion(mockClient, "1", pageValues) + assert.NoError(t, err) + + assert.NotNil(t, actual) + mockClient.AssertExpectations(t) +} + func TestCreateModelArtifactByModelVersion(t *testing.T) { gofakeit.Seed(0) //nolint:errcheck diff --git a/clients/ui/bff/internal/data/registered_model.go b/clients/ui/bff/internal/data/registered_model.go index 10acd55c..bc082feb 100644 --- a/clients/ui/bff/internal/data/registered_model.go +++ b/clients/ui/bff/internal/data/registered_model.go @@ -13,11 +13,11 @@ const registeredModelPath = "/registered_models" const versionsPath = "/versions" type RegisteredModelInterface interface { - GetAllRegisteredModels(client integrations.HTTPClientInterface) (*openapi.RegisteredModelList, error) + GetAllRegisteredModels(client integrations.HTTPClientInterface, pageValues url.Values) (*openapi.RegisteredModelList, error) CreateRegisteredModel(client integrations.HTTPClientInterface, jsonData []byte) (*openapi.RegisteredModel, error) GetRegisteredModel(client integrations.HTTPClientInterface, id string) (*openapi.RegisteredModel, error) UpdateRegisteredModel(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.RegisteredModel, error) - GetAllModelVersions(client integrations.HTTPClientInterface, id string) (*openapi.ModelVersionList, error) + GetAllModelVersions(client integrations.HTTPClientInterface, id string, pageValues url.Values) (*openapi.ModelVersionList, error) CreateModelVersionForRegisteredModel(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelVersion, error) } @@ -25,9 +25,9 @@ type RegisteredModel struct { RegisteredModelInterface } -func (m RegisteredModel) GetAllRegisteredModels(client integrations.HTTPClientInterface) (*openapi.RegisteredModelList, error) { +func (m RegisteredModel) GetAllRegisteredModels(client integrations.HTTPClientInterface, pageValues url.Values) (*openapi.RegisteredModelList, error) { + responseData, err := client.GET(UrlWithPageParams(registeredModelPath, pageValues)) - responseData, err := client.GET(registeredModelPath) if err != nil { return nil, fmt.Errorf("error fetching registered models: %w", err) } @@ -94,14 +94,14 @@ func (m RegisteredModel) UpdateRegisteredModel(client integrations.HTTPClientInt return &model, nil } -func (m RegisteredModel) GetAllModelVersions(client integrations.HTTPClientInterface, id string) (*openapi.ModelVersionList, error) { +func (m RegisteredModel) GetAllModelVersions(client integrations.HTTPClientInterface, id string, pageValues url.Values) (*openapi.ModelVersionList, error) { path, err := url.JoinPath(registeredModelPath, id, versionsPath) if err != nil { return nil, err } - responseData, err := client.GET(path) + responseData, err := client.GET(UrlWithPageParams(path, pageValues)) if err != nil { return nil, fmt.Errorf("error fetching model versions: %w", err) diff --git a/clients/ui/bff/internal/data/registered_model_test.go b/clients/ui/bff/internal/data/registered_model_test.go index f340527d..5fdd1db8 100644 --- a/clients/ui/bff/internal/data/registered_model_test.go +++ b/clients/ui/bff/internal/data/registered_model_test.go @@ -2,6 +2,7 @@ package data import ( "encoding/json" + "fmt" "net/http" "net/url" "testing" @@ -25,7 +26,7 @@ func TestGetAllRegisteredModels(t *testing.T) { mockClient := new(mocks.MockHTTPClient) mockClient.On("GET", registeredModelPath).Return(mockData, nil) - actual, err := registeredModel.GetAllRegisteredModels(mockClient) + actual, err := registeredModel.GetAllRegisteredModels(mockClient, nil) assert.NoError(t, err) assert.NotNil(t, actual) assert.Equal(t, expected.NextPageToken, actual.NextPageToken) @@ -36,6 +37,28 @@ func TestGetAllRegisteredModels(t *testing.T) { mockClient.AssertExpectations(t) } +func TestGetAllRegisteredModelsWithPageParams(t *testing.T) { + gofakeit.Seed(0) //nolint:errcheck + + pageValues := mocks.GenerateMockPageValues() + expected := mocks.GenerateMockRegisteredModelList() + + mockData, err := json.Marshal(expected) + assert.NoError(t, err) + + reqUrl := fmt.Sprintf("%s?%s", registeredModelPath, pageValues.Encode()) + + registeredModel := RegisteredModel{} + + mockClient := new(mocks.MockHTTPClient) + mockClient.On("GET", reqUrl).Return(mockData, nil) + + actual, err := registeredModel.GetAllRegisteredModels(mockClient, pageValues) + assert.NoError(t, err) + assert.NotNil(t, actual) + mockClient.AssertExpectations(t) +} + func TestCreateRegisteredModel(t *testing.T) { gofakeit.Seed(0) //nolint:errcheck @@ -126,7 +149,7 @@ func TestGetAllModelVersions(t *testing.T) { assert.NoError(t, err) mockClient.On("GET", path).Return(mockData, nil) - actual, err := registeredModel.GetAllModelVersions(mockClient, "1") + actual, err := registeredModel.GetAllModelVersions(mockClient, "1", nil) assert.NoError(t, err) assert.NotNil(t, actual) assert.NoError(t, err) @@ -139,6 +162,31 @@ func TestGetAllModelVersions(t *testing.T) { mockClient.AssertExpectations(t) } +func TestGetAllModelVersionsWithPageParams(t *testing.T) { + gofakeit.Seed(0) //nolint:errcheck + + pageValues := mocks.GenerateMockPageValues() + expected := mocks.GenerateMockModelVersionList() + + mockData, err := json.Marshal(expected) + assert.NoError(t, err) + + registeredModel := RegisteredModel{} + + mockClient := new(mocks.MockHTTPClient) + path, err := url.JoinPath(registeredModelPath, "1", versionsPath) + assert.NoError(t, err) + reqUrl := fmt.Sprintf("%s?%s", path, pageValues.Encode()) + + mockClient.On("GET", reqUrl).Return(mockData, nil) + + actual, err := registeredModel.GetAllModelVersions(mockClient, "1", pageValues) + assert.NoError(t, err) + assert.NotNil(t, actual) + + mockClient.AssertExpectations(t) +} + func TestCreateModelVersionForRegisteredModel(t *testing.T) { gofakeit.Seed(0) //nolint:errcheck diff --git a/clients/ui/bff/internal/mocks/model_registry_client_mock.go b/clients/ui/bff/internal/mocks/model_registry_client_mock.go index b65aca0a..8b397939 100644 --- a/clients/ui/bff/internal/mocks/model_registry_client_mock.go +++ b/clients/ui/bff/internal/mocks/model_registry_client_mock.go @@ -5,17 +5,18 @@ import ( "github.com/kubeflow/model-registry/ui/bff/internal/integrations" "github.com/stretchr/testify/mock" "log/slog" + "net/url" ) type ModelRegistryClientMock struct { mock.Mock } -func NewModelRegistryClient(logger *slog.Logger) (*ModelRegistryClientMock, error) { +func NewModelRegistryClient(_ *slog.Logger) (*ModelRegistryClientMock, error) { return &ModelRegistryClientMock{}, nil } -func (m *ModelRegistryClientMock) GetAllRegisteredModels(client integrations.HTTPClientInterface) (*openapi.RegisteredModelList, error) { +func (m *ModelRegistryClientMock) GetAllRegisteredModels(_ integrations.HTTPClientInterface, _ url.Values) (*openapi.RegisteredModelList, error) { mockData := GetRegisteredModelListMock() return &mockData, nil } @@ -50,7 +51,7 @@ func (m *ModelRegistryClientMock) UpdateModelVersion(client integrations.HTTPCli return &mockData, nil } -func (m *ModelRegistryClientMock) GetAllModelVersions(client integrations.HTTPClientInterface, id string) (*openapi.ModelVersionList, error) { +func (m *ModelRegistryClientMock) GetAllModelVersions(_ integrations.HTTPClientInterface, _ string, _ url.Values) (*openapi.ModelVersionList, error) { mockData := GetModelVersionListMock() return &mockData, nil } @@ -60,10 +61,11 @@ func (m *ModelRegistryClientMock) CreateModelVersionForRegisteredModel(client in return &mockData, nil } -func (m *ModelRegistryClientMock) GetModelArtifactsByModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelArtifactList, error) { +func (m *ModelRegistryClientMock) GetModelArtifactsByModelVersion(_ integrations.HTTPClientInterface, _ string, _ url.Values) (*openapi.ModelArtifactList, error) { mockData := GetModelArtifactListMock() return &mockData, nil } + func (m *ModelRegistryClientMock) CreateModelArtifactByModelVersion(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelArtifact, error) { mockData := GetModelArtifactMocks()[0] return &mockData, nil diff --git a/clients/ui/bff/internal/mocks/types_mock.go b/clients/ui/bff/internal/mocks/types_mock.go index f1d81a9a..096d3dfe 100644 --- a/clients/ui/bff/internal/mocks/types_mock.go +++ b/clients/ui/bff/internal/mocks/types_mock.go @@ -4,6 +4,8 @@ import ( "fmt" "github.com/brianvoe/gofakeit/v7" "github.com/kubeflow/model-registry/pkg/openapi" + "net/url" + "strconv" ) func GenerateMockRegisteredModelList() openapi.RegisteredModelList { @@ -125,6 +127,17 @@ func GenerateMockModelArtifactList() openapi.ModelArtifactList { } } +func GenerateMockPageValues() url.Values { + pageValues := url.Values{} + + pageValues.Add("pageSize", strconv.Itoa(gofakeit.Number(1, 100))) + pageValues.Add("orderBy", gofakeit.RandomString([]string{"CREATE_TIME", "LAST_UPDATE_TIME", "ID"})) + pageValues.Add("sortOrder", gofakeit.RandomString([]string{"ASC", "DESC"})) + pageValues.Add("nextPageToken", gofakeit.UUID()) + + return pageValues +} + func randomEpochTime() *string { return stringToPointer(fmt.Sprintf("%d", gofakeit.Date().UnixMilli())) }