Skip to content

Commit

Permalink
Implements pagination support for GetAll style endpoints (#429)
Browse files Browse the repository at this point in the history
Signed-off-by: Alex Creasy <[email protected]>
  • Loading branch information
alexcreasy authored Oct 3, 2024
1 parent 50bf73e commit 9cc09d9
Show file tree
Hide file tree
Showing 10 changed files with 171 additions and 27 deletions.
22 changes: 21 additions & 1 deletion clients/ui/bff/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -188,4 +188,24 @@ curl -i -X POST "http://localhost:4000/api/v1/model_registry/model-registry/mode
"state": "LIVE",
"artifactType": "TYPE_ONE"
}}'
```
```

### Pagination
The following query parameters are supported by "Get All" style endpoints to control pagination.

| Parameter Name | Description |
|----------------|-----------------------------------------------------------------------------------------------------------|
| pageSize | Number of entities in each page |
| orderBy | Specifies the order by criteria for listing entities. Available values: CREATE_TIME, LAST_UPDATE_TIME, ID |
| sortOrder | Specifies the sort order for listing entities. Available values: ASC, DESC. Default: ASC |
| nextPageToken | Token to use to retrieve next page of results. |

### Sample local calls
```
# Get with a page size of 5 getting a specific page.
curl -i "http://localhost:4000/api/v1/model_registry/model-registry/registered_models?pageSize=5&nextPageToken=CAEQARoCCAE"
```
```
# Get with a page size of 5, order by last update time in descending order.
curl -i "http://localhost:4000/api/v1/model_registry/model-registry/registered_models?pageSize=5&orderBy=LAST_UPDATE_TIME&sortOrder=DESC"
```
2 changes: 1 addition & 1 deletion clients/ui/bff/internal/api/model_versions_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ func (app *App) GetAllModelArtifactsByModelVersionHandler(w http.ResponseWriter,
return
}

data, err := app.modelRegistryClient.GetModelArtifactsByModelVersion(client, ps.ByName(ModelVersionId))
data, err := app.modelRegistryClient.GetModelArtifactsByModelVersion(client, ps.ByName(ModelVersionId), r.URL.Query())
if err != nil {
app.serverErrorResponse(w, r, err)
return
Expand Down
15 changes: 6 additions & 9 deletions clients/ui/bff/internal/api/registered_models_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,25 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/julienschmidt/httprouter"
"github.com/kubeflow/model-registry/pkg/openapi"
"github.com/kubeflow/model-registry/ui/bff/internal/integrations"
"github.com/kubeflow/model-registry/ui/bff/internal/validation"
"net/http"

"github.com/julienschmidt/httprouter"
"github.com/kubeflow/model-registry/pkg/openapi"
)

type RegisteredModelEnvelope Envelope[*openapi.RegisteredModel, None]
type RegisteredModelListEnvelope Envelope[*openapi.RegisteredModelList, None]
type RegisteredModelUpdateEnvelope Envelope[*openapi.RegisteredModelUpdate, None]

func (app *App) GetAllRegisteredModelsHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
//TODO (ederign) implement pagination
func (app *App) GetAllRegisteredModelsHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface)
if !ok {
app.serverErrorResponse(w, r, errors.New("REST client not found"))
return
}

modelList, err := app.modelRegistryClient.GetAllRegisteredModels(client)
modelList, err := app.modelRegistryClient.GetAllRegisteredModels(client, r.URL.Query())
if err != nil {
app.serverErrorResponse(w, r, err)
return
Expand All @@ -40,7 +38,7 @@ func (app *App) GetAllRegisteredModelsHandler(w http.ResponseWriter, r *http.Req
}
}

func (app *App) CreateRegisteredModelHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
func (app *App) CreateRegisteredModelHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface)
if !ok {
app.serverErrorResponse(w, r, errors.New("REST client not found"))
Expand Down Expand Up @@ -173,14 +171,13 @@ func (app *App) UpdateRegisteredModelHandler(w http.ResponseWriter, r *http.Requ
}

func (app *App) GetAllModelVersionsForRegisteredModelHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
//TODO (acreasy) implement pagination
client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface)
if !ok {
app.serverErrorResponse(w, r, errors.New("REST client not found"))
return
}

versionList, err := app.modelRegistryClient.GetAllModelVersions(client, ps.ByName(RegisteredModelId))
versionList, err := app.modelRegistryClient.GetAllModelVersions(client, ps.ByName(RegisteredModelId), r.URL.Query())

if err != nil {
app.serverErrorResponse(w, r, err)
Expand Down
38 changes: 38 additions & 0 deletions clients/ui/bff/internal/data/helpers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package data

import (
"fmt"
"net/url"
)

func FilterPageValues(values url.Values) url.Values {
result := url.Values{}

if v := values.Get("pageSize"); v != "" {
result.Set("pageSize", v)
}
if v := values.Get("orderBy"); v != "" {
result.Set("orderBy", v)
}
if v := values.Get("sortOrder"); v != "" {
result.Set("sortOrder", v)
}
if v := values.Get("nextPageToken"); v != "" {
result.Set("nextPageToken", v)
}

return result
}

func UrlWithParams(url string, values url.Values) string {
queryString := values.Encode()
if queryString == "" {
return url
}
return fmt.Sprintf("%s?%s", url, queryString)
}

func UrlWithPageParams(url string, values url.Values) string {
pageValues := FilterPageValues(values)
return UrlWithParams(url, pageValues)
}
6 changes: 3 additions & 3 deletions clients/ui/bff/internal/data/model_version.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type ModelVersionInterface interface {
GetModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelVersion, error)
CreateModelVersion(client integrations.HTTPClientInterface, jsonData []byte) (*openapi.ModelVersion, error)
UpdateModelVersion(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelVersion, error)
GetModelArtifactsByModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelArtifactList, error)
GetModelArtifactsByModelVersion(client integrations.HTTPClientInterface, id string, pageValues url.Values) (*openapi.ModelArtifactList, error)
CreateModelArtifactByModelVersion(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelArtifact, error)
}

Expand Down Expand Up @@ -79,14 +79,14 @@ func (v ModelVersion) UpdateModelVersion(client integrations.HTTPClientInterface
return &model, nil
}

func (v ModelVersion) GetModelArtifactsByModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelArtifactList, error) {
func (v ModelVersion) GetModelArtifactsByModelVersion(client integrations.HTTPClientInterface, id string, pageValues url.Values) (*openapi.ModelArtifactList, error) {
path, err := url.JoinPath(modelVersionPath, id, artifactsByModelVersionPath)

if err != nil {
return nil, err
}

responseData, err := client.GET(path)
responseData, err := client.GET(UrlWithPageParams(path, pageValues))
if err != nil {
return nil, fmt.Errorf("error fetching model version artifacts: %w", err)
}
Expand Down
28 changes: 27 additions & 1 deletion clients/ui/bff/internal/data/model_version_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package data

import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"testing"
Expand Down Expand Up @@ -106,7 +107,7 @@ func TestGetModelArtifactsByModelVersion(t *testing.T) {
mockClient := new(mocks.MockHTTPClient)
mockClient.On(http.MethodGet, path, mock.Anything).Return(mockData, nil)

actual, err := modelVersion.GetModelArtifactsByModelVersion(mockClient, "1")
actual, err := modelVersion.GetModelArtifactsByModelVersion(mockClient, "1", nil)
assert.NoError(t, err)

assert.NotNil(t, actual)
Expand All @@ -116,6 +117,31 @@ func TestGetModelArtifactsByModelVersion(t *testing.T) {
assert.Equal(t, len(expected.Items), len(actual.Items))
}

func TestGetModelArtifactsByModelVersionWithPageParams(t *testing.T) {
gofakeit.Seed(0) //nolint:errcheck

pageValues := mocks.GenerateMockPageValues()
expected := mocks.GenerateMockModelArtifactList()

mockData, err := json.Marshal(expected)
assert.NoError(t, err)

modelVersion := ModelVersion{}

path, err := url.JoinPath(modelVersionPath, "1", artifactsByModelVersionPath)
assert.NoError(t, err)
reqUrl := fmt.Sprintf("%s?%s", path, pageValues.Encode())

mockClient := new(mocks.MockHTTPClient)
mockClient.On(http.MethodGet, reqUrl, mock.Anything).Return(mockData, nil)

actual, err := modelVersion.GetModelArtifactsByModelVersion(mockClient, "1", pageValues)
assert.NoError(t, err)

assert.NotNil(t, actual)
mockClient.AssertExpectations(t)
}

func TestCreateModelArtifactByModelVersion(t *testing.T) {
gofakeit.Seed(0) //nolint:errcheck

Expand Down
12 changes: 6 additions & 6 deletions clients/ui/bff/internal/data/registered_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,21 @@ const registeredModelPath = "/registered_models"
const versionsPath = "/versions"

type RegisteredModelInterface interface {
GetAllRegisteredModels(client integrations.HTTPClientInterface) (*openapi.RegisteredModelList, error)
GetAllRegisteredModels(client integrations.HTTPClientInterface, pageValues url.Values) (*openapi.RegisteredModelList, error)
CreateRegisteredModel(client integrations.HTTPClientInterface, jsonData []byte) (*openapi.RegisteredModel, error)
GetRegisteredModel(client integrations.HTTPClientInterface, id string) (*openapi.RegisteredModel, error)
UpdateRegisteredModel(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.RegisteredModel, error)
GetAllModelVersions(client integrations.HTTPClientInterface, id string) (*openapi.ModelVersionList, error)
GetAllModelVersions(client integrations.HTTPClientInterface, id string, pageValues url.Values) (*openapi.ModelVersionList, error)
CreateModelVersionForRegisteredModel(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelVersion, error)
}

type RegisteredModel struct {
RegisteredModelInterface
}

func (m RegisteredModel) GetAllRegisteredModels(client integrations.HTTPClientInterface) (*openapi.RegisteredModelList, error) {
func (m RegisteredModel) GetAllRegisteredModels(client integrations.HTTPClientInterface, pageValues url.Values) (*openapi.RegisteredModelList, error) {
responseData, err := client.GET(UrlWithPageParams(registeredModelPath, pageValues))

responseData, err := client.GET(registeredModelPath)
if err != nil {
return nil, fmt.Errorf("error fetching registered models: %w", err)
}
Expand Down Expand Up @@ -94,14 +94,14 @@ func (m RegisteredModel) UpdateRegisteredModel(client integrations.HTTPClientInt
return &model, nil
}

func (m RegisteredModel) GetAllModelVersions(client integrations.HTTPClientInterface, id string) (*openapi.ModelVersionList, error) {
func (m RegisteredModel) GetAllModelVersions(client integrations.HTTPClientInterface, id string, pageValues url.Values) (*openapi.ModelVersionList, error) {
path, err := url.JoinPath(registeredModelPath, id, versionsPath)

if err != nil {
return nil, err
}

responseData, err := client.GET(path)
responseData, err := client.GET(UrlWithPageParams(path, pageValues))

if err != nil {
return nil, fmt.Errorf("error fetching model versions: %w", err)
Expand Down
52 changes: 50 additions & 2 deletions clients/ui/bff/internal/data/registered_model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package data

import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"testing"
Expand All @@ -25,7 +26,7 @@ func TestGetAllRegisteredModels(t *testing.T) {
mockClient := new(mocks.MockHTTPClient)
mockClient.On("GET", registeredModelPath).Return(mockData, nil)

actual, err := registeredModel.GetAllRegisteredModels(mockClient)
actual, err := registeredModel.GetAllRegisteredModels(mockClient, nil)
assert.NoError(t, err)
assert.NotNil(t, actual)
assert.Equal(t, expected.NextPageToken, actual.NextPageToken)
Expand All @@ -36,6 +37,28 @@ func TestGetAllRegisteredModels(t *testing.T) {
mockClient.AssertExpectations(t)
}

func TestGetAllRegisteredModelsWithPageParams(t *testing.T) {
gofakeit.Seed(0) //nolint:errcheck

pageValues := mocks.GenerateMockPageValues()
expected := mocks.GenerateMockRegisteredModelList()

mockData, err := json.Marshal(expected)
assert.NoError(t, err)

reqUrl := fmt.Sprintf("%s?%s", registeredModelPath, pageValues.Encode())

registeredModel := RegisteredModel{}

mockClient := new(mocks.MockHTTPClient)
mockClient.On("GET", reqUrl).Return(mockData, nil)

actual, err := registeredModel.GetAllRegisteredModels(mockClient, pageValues)
assert.NoError(t, err)
assert.NotNil(t, actual)
mockClient.AssertExpectations(t)
}

func TestCreateRegisteredModel(t *testing.T) {
gofakeit.Seed(0) //nolint:errcheck

Expand Down Expand Up @@ -126,7 +149,7 @@ func TestGetAllModelVersions(t *testing.T) {
assert.NoError(t, err)
mockClient.On("GET", path).Return(mockData, nil)

actual, err := registeredModel.GetAllModelVersions(mockClient, "1")
actual, err := registeredModel.GetAllModelVersions(mockClient, "1", nil)
assert.NoError(t, err)
assert.NotNil(t, actual)
assert.NoError(t, err)
Expand All @@ -139,6 +162,31 @@ func TestGetAllModelVersions(t *testing.T) {
mockClient.AssertExpectations(t)
}

func TestGetAllModelVersionsWithPageParams(t *testing.T) {
gofakeit.Seed(0) //nolint:errcheck

pageValues := mocks.GenerateMockPageValues()
expected := mocks.GenerateMockModelVersionList()

mockData, err := json.Marshal(expected)
assert.NoError(t, err)

registeredModel := RegisteredModel{}

mockClient := new(mocks.MockHTTPClient)
path, err := url.JoinPath(registeredModelPath, "1", versionsPath)
assert.NoError(t, err)
reqUrl := fmt.Sprintf("%s?%s", path, pageValues.Encode())

mockClient.On("GET", reqUrl).Return(mockData, nil)

actual, err := registeredModel.GetAllModelVersions(mockClient, "1", pageValues)
assert.NoError(t, err)
assert.NotNil(t, actual)

mockClient.AssertExpectations(t)
}

func TestCreateModelVersionForRegisteredModel(t *testing.T) {
gofakeit.Seed(0) //nolint:errcheck

Expand Down
10 changes: 6 additions & 4 deletions clients/ui/bff/internal/mocks/model_registry_client_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@ import (
"github.com/kubeflow/model-registry/ui/bff/internal/integrations"
"github.com/stretchr/testify/mock"
"log/slog"
"net/url"
)

type ModelRegistryClientMock struct {
mock.Mock
}

func NewModelRegistryClient(logger *slog.Logger) (*ModelRegistryClientMock, error) {
func NewModelRegistryClient(_ *slog.Logger) (*ModelRegistryClientMock, error) {
return &ModelRegistryClientMock{}, nil
}

func (m *ModelRegistryClientMock) GetAllRegisteredModels(client integrations.HTTPClientInterface) (*openapi.RegisteredModelList, error) {
func (m *ModelRegistryClientMock) GetAllRegisteredModels(_ integrations.HTTPClientInterface, _ url.Values) (*openapi.RegisteredModelList, error) {
mockData := GetRegisteredModelListMock()
return &mockData, nil
}
Expand Down Expand Up @@ -50,7 +51,7 @@ func (m *ModelRegistryClientMock) UpdateModelVersion(client integrations.HTTPCli
return &mockData, nil
}

func (m *ModelRegistryClientMock) GetAllModelVersions(client integrations.HTTPClientInterface, id string) (*openapi.ModelVersionList, error) {
func (m *ModelRegistryClientMock) GetAllModelVersions(_ integrations.HTTPClientInterface, _ string, _ url.Values) (*openapi.ModelVersionList, error) {
mockData := GetModelVersionListMock()
return &mockData, nil
}
Expand All @@ -60,10 +61,11 @@ func (m *ModelRegistryClientMock) CreateModelVersionForRegisteredModel(client in
return &mockData, nil
}

func (m *ModelRegistryClientMock) GetModelArtifactsByModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelArtifactList, error) {
func (m *ModelRegistryClientMock) GetModelArtifactsByModelVersion(_ integrations.HTTPClientInterface, _ string, _ url.Values) (*openapi.ModelArtifactList, error) {
mockData := GetModelArtifactListMock()
return &mockData, nil
}

func (m *ModelRegistryClientMock) CreateModelArtifactByModelVersion(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelArtifact, error) {
mockData := GetModelArtifactMocks()[0]
return &mockData, nil
Expand Down
Loading

0 comments on commit 9cc09d9

Please sign in to comment.