From 41c7754e058c4993ac7664bd15e619d8c640baa0 Mon Sep 17 00:00:00 2001 From: Juan Escalada Date: Fri, 4 Oct 2024 02:56:49 +0000 Subject: [PATCH 01/24] Add SetTag endpoint Signed-off-by: Juan Escalada --- .devcontainer/Dockerfile | 40 +- .devcontainer/devcontainer.json | 130 +- .devcontainer/docker-compose.yml | 76 +- .devcontainer/postCreate.sh | 14 +- .github/dependabot.yml | 32 +- .github/workflows/ci.yml | 36 +- .github/workflows/lint.yml | 44 +- .github/workflows/test.yml | 118 +- .gitignore | 30 +- .golangci.yml | 110 +- .mockery.yaml | 16 +- .pre-commit-config.yaml | 68 +- LICENSE | 402 ++-- README.md | 630 +++--- conftest.py | 128 +- go.mod | 124 +- go.sum | 472 ++--- magefiles/dev.go | 56 +- magefiles/endpoints.go | 112 +- magefiles/generate.go | 90 +- magefiles/generate/ast_creation.go | 204 +- magefiles/generate/discovery/discovery.go | 194 +- .../generate/discovery/discovery_test.go | 110 +- magefiles/generate/endpoints.go | 174 +- magefiles/generate/protos.go | 116 +- magefiles/generate/query_annotations.go | 232 +-- magefiles/generate/source_code.go | 936 ++++----- magefiles/generate/validations.go | 60 +- magefiles/repo.go | 440 ++-- magefiles/temp.go | 148 +- magefiles/tests.go | 177 +- mlflow_go/__init__.py | 40 +- mlflow_go/cli.py | 224 +- mlflow_go/lib.py | 248 +-- mlflow_go/server.py | 62 +- mlflow_go/store/_service_proxy.py | 86 +- mlflow_go/store/model_registry.py | 110 +- mlflow_go/store/tracking.py | 375 ++-- pkg/artifacts/service/service.go | 34 +- pkg/cmd/server/main.go | 42 +- pkg/config/config.go | 212 +- pkg/config/config_test.go | 106 +- pkg/contract/error.go | 164 +- pkg/contract/http_request_parser.go | 16 +- pkg/contract/service/tracking.g.go | 2 + pkg/entities/dataset.go | 70 +- pkg/entities/dataset_input.go | 40 +- pkg/entities/experiment.go | 72 +- pkg/entities/experiment_tag.go | 44 +- pkg/entities/input_tag.go | 30 +- pkg/entities/metric.go | 104 +- pkg/entities/param.go | 44 +- pkg/entities/run.go | 150 +- 
pkg/entities/run_data.go | 14 +- pkg/entities/run_info.go | 68 +- pkg/entities/run_inputs.go | 10 +- pkg/entities/run_tag.go | 44 +- pkg/lib/artifacts.go | 44 +- pkg/lib/ffi.go | 182 +- pkg/lib/instance_map.go | 156 +- pkg/lib/main.go | 6 +- pkg/lib/model_registry.go | 44 +- pkg/lib/server.go | 166 +- pkg/lib/tracking.g.go | 16 + pkg/lib/tracking.go | 44 +- pkg/lib/validation.go | 46 +- pkg/model_registry/service/model_versions.go | 42 +- pkg/model_registry/service/service.go | 54 +- .../store/sql/model_versions.go | 188 +- .../store/sql/models/model_version_stage.go | 66 +- .../store/sql/models/model_version_tags.go | 22 +- .../store/sql/models/model_versions.go | 98 +- .../sql/models/registered_model_aliases.go | 16 +- .../store/sql/models/registered_model_tags.go | 16 +- .../store/sql/models/registered_models.go | 18 +- pkg/model_registry/store/sql/store.go | 56 +- pkg/model_registry/store/store.go | 24 +- pkg/protos/artifacts/mlflow_artifacts.pb.go | 2 +- pkg/protos/databricks.pb.go | 2 +- pkg/protos/databricks_artifacts.pb.go | 2 +- pkg/protos/internal.pb.go | 2 +- pkg/protos/model_registry.pb.go | 2 +- pkg/protos/scalapb/scalapb.pb.go | 2 +- pkg/protos/service.pb.go | 10 +- pkg/server/command/command.go | 84 +- pkg/server/command/command_posix.go | 60 +- pkg/server/command/command_windows.go | 154 +- pkg/server/launch.go | 172 +- pkg/server/parser/http_request_parser.go | 144 +- pkg/server/routes/tracking.g.go | 22 + pkg/server/server.go | 446 ++-- pkg/sql/logger.go | 278 +-- pkg/sql/sql.go | 180 +- pkg/tracking/service/experiments.go | 268 +-- pkg/tracking/service/experiments_test.go | 122 +- pkg/tracking/service/metrics.go | 40 +- pkg/tracking/service/query/README.md | 16 +- pkg/tracking/service/query/lexer/token.go | 222 +- pkg/tracking/service/query/lexer/tokenizer.go | 290 +-- .../service/query/lexer/tokenizer_test.go | 228 +-- pkg/tracking/service/query/parser/ast.go | 274 +-- pkg/tracking/service/query/parser/parser.go | 530 ++--- 
.../service/query/parser/parser_test.go | 364 ++-- pkg/tracking/service/query/parser/validate.go | 658 +++--- pkg/tracking/service/query/query.go | 74 +- pkg/tracking/service/query/query_test.go | 224 +- pkg/tracking/service/runs.go | 336 +-- pkg/tracking/service/service.go | 54 +- pkg/tracking/service/tags.go | 27 + pkg/tracking/store/mock_tracking_store.go | 165 +- pkg/tracking/store/sql/experiments.go | 508 ++--- pkg/tracking/store/sql/metrics.go | 386 ++-- .../store/sql/models/alembic_version.go | 22 +- pkg/tracking/store/sql/models/datasets.go | 56 +- .../store/sql/models/experiment_tags.go | 20 +- pkg/tracking/store/sql/models/experiments.go | 80 +- pkg/tracking/store/sql/models/input_tags.go | 38 +- pkg/tracking/store/sql/models/inputs.go | 56 +- .../store/sql/models/latest_metrics.go | 50 +- pkg/tracking/store/sql/models/lifecycle.go | 24 +- pkg/tracking/store/sql/models/metrics.go | 114 +- pkg/tracking/store/sql/models/params.go | 54 +- pkg/tracking/store/sql/models/runs.go | 212 +- pkg/tracking/store/sql/models/tags.go | 62 +- pkg/tracking/store/sql/params.go | 238 +-- pkg/tracking/store/sql/runs.go | 1800 ++++++++--------- pkg/tracking/store/sql/runs_internal_test.go | 1036 +++++----- pkg/tracking/store/sql/store.go | 56 +- pkg/tracking/store/sql/tags.go | 344 +++- pkg/tracking/store/store.go | 160 +- pkg/utils/logger.go | 98 +- pkg/utils/naming.go | 142 +- pkg/utils/path.go | 180 +- pkg/utils/pointers.go | 82 +- pkg/utils/strings.go | 48 +- pkg/utils/tags.go | 12 +- pkg/validation/validation.go | 598 +++--- pkg/validation/validation_test.go | 488 ++--- pyproject.toml | 346 ++-- setup.py | 132 +- tests/override_model_registry_store.py | 10 +- tests/override_server.py | 154 +- tests/override_test_sqlalchemy_store.py | 34 +- tests/override_tracking_store.py | 10 +- 144 files changed, 11338 insertions(+), 11000 deletions(-) create mode 100644 pkg/tracking/service/tags.go diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index 2b4ad23..3c75a68 
100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -1,20 +1,20 @@ -FROM mcr.microsoft.com/devcontainers/go:1-1.22-bookworm - -# [Optional] Uncomment this section to install additional OS packages. -RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \ - && apt-get -y install --no-install-recommends \ - postgresql-client \ - sqlite3 \ - && rm -rf /var/lib/apt/lists/* - -# [Optional] Uncomment the next lines to use go get to install anything else you need -USER vscode -RUN go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.33.0 \ - && go install github.com/vektra/mockery/v2@v2.43.2 \ - && go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.59.1 \ - && go install github.com/magefile/mage@v1.15.0 \ - && go clean -cache -modcache -USER root - -# [Optional] Uncomment this line to install global node packages. -# RUN su vscode -c "source /usr/local/share/nvm/nvm.sh && npm install -g " 2>&1 +FROM mcr.microsoft.com/devcontainers/go:1-1.22-bookworm + +# [Optional] Uncomment this section to install additional OS packages. +RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \ + && apt-get -y install --no-install-recommends \ + postgresql-client \ + sqlite3 \ + && rm -rf /var/lib/apt/lists/* + +# [Optional] Uncomment the next lines to use go get to install anything else you need +USER vscode +RUN go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.33.0 \ + && go install github.com/vektra/mockery/v2@v2.43.2 \ + && go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.59.1 \ + && go install github.com/magefile/mage@v1.15.0 \ + && go clean -cache -modcache +USER root + +# [Optional] Uncomment this line to install global node packages. 
+# RUN su vscode -c "source /usr/local/share/nvm/nvm.sh && npm install -g " 2>&1 diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 2a3f0fb..13d7b1f 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -1,66 +1,66 @@ -// For format details, see https://aka.ms/devcontainer.json. -{ - "name": "MLflow Go", - "dockerComposeFile": "docker-compose.yml", - "service": "app", - "workspaceFolder": "/workspaces/${localWorkspaceFolderBasename}", - - // Features to add to the dev container. More info: https://containers.dev/features. - "features": { - "ghcr.io/devcontainers/features/github-cli:1": {}, - "ghcr.io/devcontainers/features/python:1": { - "version": "3.8" - }, - "ghcr.io/devcontainers/features/docker-in-docker:2": {}, - "ghcr.io/devcontainers-contrib/features/k6:1": {}, - "ghcr.io/devcontainers-contrib/features/pre-commit:2": {}, - "ghcr.io/devcontainers-contrib/features/protoc-asdf:1": { - "version": "26.0" - }, - "ghcr.io/devcontainers-contrib/features/ruff:1": {} - }, - - // Configure tool-specific properties. - "customizations": { - "vscode": { - "settings": { - "terminal.integrated.defaultProfile.linux": "zsh", - "editor.rulers": [ - 80, - 100 - ], - "editor.formatOnSave": true, - "git.alwaysSignOff": true, - "go.lintTool": "golangci-lint", - "gopls": { - "formatting.local": "github.com/mlflow/mlflow-go", - "formatting.gofumpt": true, - "build.buildFlags": ["-tags=mage"] - }, - "[python]": { - "editor.codeActionsOnSave": { - "source.fixAll": "explicit", - "source.organizeImports": "explicit" - }, - "editor.defaultFormatter": "charliermarsh.ruff" - } - }, - "extensions": [ - "charliermarsh.ruff", - "golang.Go", - "humao.rest-client", - "pbkit.vscode-pbkit", - "tamasfe.even-better-toml" - ] - } - }, - - // Use 'forwardPorts' to make a list of ports inside the container available locally. - // "forwardPorts": [5432], - - // Use 'postCreateCommand' to run commands after the container is created. 
- "postCreateCommand": ".devcontainer/postCreate.sh" - - // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. - // "remoteUser": "root" +// For format details, see https://aka.ms/devcontainer.json. +{ + "name": "MLflow Go", + "dockerComposeFile": "docker-compose.yml", + "service": "app", + "workspaceFolder": "/workspaces/${localWorkspaceFolderBasename}", + + // Features to add to the dev container. More info: https://containers.dev/features. + "features": { + "ghcr.io/devcontainers/features/github-cli:1": {}, + "ghcr.io/devcontainers/features/python:1": { + "version": "3.8" + }, + "ghcr.io/devcontainers/features/docker-in-docker:2": {}, + "ghcr.io/devcontainers-contrib/features/k6:1": {}, + "ghcr.io/devcontainers-contrib/features/pre-commit:2": {}, + "ghcr.io/devcontainers-contrib/features/protoc-asdf:1": { + "version": "26.0" + }, + "ghcr.io/devcontainers-contrib/features/ruff:1": {} + }, + + // Configure tool-specific properties. + "customizations": { + "vscode": { + "settings": { + "terminal.integrated.defaultProfile.linux": "zsh", + "editor.rulers": [ + 80, + 100 + ], + "editor.formatOnSave": true, + "git.alwaysSignOff": true, + "go.lintTool": "golangci-lint", + "gopls": { + "formatting.local": "github.com/mlflow/mlflow-go", + "formatting.gofumpt": true, + "build.buildFlags": ["-tags=mage"] + }, + "[python]": { + "editor.codeActionsOnSave": { + "source.fixAll": "explicit", + "source.organizeImports": "explicit" + }, + "editor.defaultFormatter": "charliermarsh.ruff" + } + }, + "extensions": [ + "charliermarsh.ruff", + "golang.Go", + "humao.rest-client", + "pbkit.vscode-pbkit", + "tamasfe.even-better-toml" + ] + } + }, + + // Use 'forwardPorts' to make a list of ports inside the container available locally. + // "forwardPorts": [5432], + + // Use 'postCreateCommand' to run commands after the container is created. + "postCreateCommand": ".devcontainer/postCreate.sh", + + // Uncomment to connect as root instead. 
More info: https://aka.ms/dev-containers-non-root. + "remoteUser": "root" } \ No newline at end of file diff --git a/.devcontainer/docker-compose.yml b/.devcontainer/docker-compose.yml index 9ebb9ec..11e19cb 100644 --- a/.devcontainer/docker-compose.yml +++ b/.devcontainer/docker-compose.yml @@ -1,38 +1,38 @@ -volumes: - go-cache: - postgres-data: - -services: - app: - build: - context: . - dockerfile: Dockerfile - volumes: - - ../..:/workspaces:cached - - go-cache:/var/cache/go - environment: - - GOCACHE=/var/cache/go/build - - GOMODCACHE=/var/cache/go/mod - - # Overrides default command so things don't shut down after the process ends. - command: sleep infinity - - # Runs app on the same network as the database container, allows "forwardPorts" in devcontainer.json function. - network_mode: service:db - - # Use "forwardPorts" in **devcontainer.json** to forward an app port locally. - # (Adding the "ports" property to this file will not forward from a Codespace.) - - db: - image: postgres:latest - restart: unless-stopped - volumes: - - postgres-data:/var/lib/postgresql/data - environment: - - POSTGRES_USER=postgres - - POSTGRES_PASSWORD=postgres - - POSTGRES_DB=postgres - - POSTGRES_HOSTNAME=localhost=value - - # Add "forwardPorts": ["5432"] to **devcontainer.json** to forward PostgreSQL locally. - # (Adding the "ports" property to this file will not forward from a Codespace.) +volumes: + go-cache: + postgres-data: + +services: + app: + build: + context: . + dockerfile: Dockerfile + volumes: + - ../..:/workspaces:cached + - go-cache:/var/cache/go + environment: + - GOCACHE=/var/cache/go/build + - GOMODCACHE=/var/cache/go/mod + + # Overrides default command so things don't shut down after the process ends. + command: sleep infinity + + # Runs app on the same network as the database container, allows "forwardPorts" in devcontainer.json function. + network_mode: service:db + + # Use "forwardPorts" in **devcontainer.json** to forward an app port locally. 
+ # (Adding the "ports" property to this file will not forward from a Codespace.) + + db: + image: postgres:latest + restart: unless-stopped + volumes: + - postgres-data:/var/lib/postgresql/data + environment: + - POSTGRES_USER=postgres + - POSTGRES_PASSWORD=postgres + - POSTGRES_DB=postgres + - POSTGRES_HOSTNAME=localhost=value + + # Add "forwardPorts": ["5432"] to **devcontainer.json** to forward PostgreSQL locally. + # (Adding the "ports" property to this file will not forward from a Codespace.) diff --git a/.devcontainer/postCreate.sh b/.devcontainer/postCreate.sh index d71c8a4..68ef182 100755 --- a/.devcontainer/postCreate.sh +++ b/.devcontainer/postCreate.sh @@ -1,7 +1,7 @@ -#!/bin/sh - -# Fix permissions for the Go cache -sudo chown -R $(id -u):$(id -g) /var/cache/go - -# Install precommit (https://pre-commit.com/) -pre-commit install -t pre-commit -t prepare-commit-msg +#!/bin/sh + +# Fix permissions for the Go cache +sudo chown -R $(id -u):$(id -g) /var/cache/go + +# Install precommit (https://pre-commit.com/) +pre-commit install -t pre-commit -t prepare-commit-msg diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 97720e7..505633d 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -1,16 +1,16 @@ -# To get started with Dependabot version updates, you'll need to specify which -# package ecosystems to update and where the package manifests are located. -# Please see the documentation for more information: -# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates -# https://containers.dev/guide/dependabot - -version: 2 -updates: - - package-ecosystem: "devcontainers" - directory: "/" - schedule: - interval: weekly - - package-ecosystem: "github-actions" - directory: "/" - schedule: - interval: weekly +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. 
+# Please see the documentation for more information: +# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates +# https://containers.dev/guide/dependabot + +version: 2 +updates: + - package-ecosystem: "devcontainers" + directory: "/" + schedule: + interval: weekly + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: weekly diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ae98d9a..850711b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,18 +1,18 @@ -name: CI - -on: - push: - branches: - - main - pull_request: - schedule: - # Run daily at 01:34 so we get notified if CI is broken before a pull request - # is submitted. - - cron: "34 1 * * *" - -jobs: - lint: - uses: ./.github/workflows/lint.yml - - test: - uses: ./.github/workflows/test.yml +name: CI + +on: + push: + branches: + - main + pull_request: + schedule: + # Run daily at 01:34 so we get notified if CI is broken before a pull request + # is submitted. 
+ - cron: "34 1 * * *" + +jobs: + lint: + uses: ./.github/workflows/lint.yml + + test: + uses: ./.github/workflows/test.yml diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 83e5398..d4b0f19 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -1,22 +1,22 @@ -name: Lint - -on: - workflow_call: - -permissions: - contents: read - -jobs: - lint: - name: Lint - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - name: Setup Go - uses: actions/setup-go@v5 - with: - go-version: "1.22" - check-latest: true - cache: false - - name: Run pre-commit hooks - run: pipx run pre-commit run --all-files +name: Lint + +on: + workflow_call: + +permissions: + contents: read + +jobs: + lint: + name: Lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version: "1.22" + check-latest: true + cache: false + - name: Run pre-commit hooks + run: pipx run pre-commit run --all-files diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0becdec..24b20a5 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,59 +1,59 @@ -name: Test - -on: - workflow_call: - -permissions: - contents: read - -jobs: - go: - name: Test Go - strategy: - matrix: - runner: [macos-latest, ubuntu-latest, windows-latest] - runs-on: ${{ matrix.runner }} - steps: - - uses: actions/checkout@v4 - - name: Setup Go - uses: actions/setup-go@v5 - with: - go-version: "1.22" - check-latest: true - cache: false - - name: Install mage - run: go install github.com/magefile/mage@v1.15.0 - - name: Run unit tests - run: mage test:unit - - python: - name: Test Python - strategy: - matrix: - runner: [macos-latest, ubuntu-latest, windows-latest] - python: ["3.8", "3.9", "3.10", "3.11", "3.12"] - runs-on: ${{ matrix.runner }} - steps: - - uses: actions/checkout@v4 - - name: Setup Python - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python 
}} - - name: Setup Go - uses: actions/setup-go@v5 - with: - go-version: "1.22" - check-latest: true - cache: false - - name: Install mage - run: go install github.com/magefile/mage@v1.15.0 - - name: Install our package in editable mode - run: pip install -e . - - name: Initialize MLflow repo - run: mage repo:init - - name: Install dependencies - run: pip install pytest==8.1.1 psycopg2-binary -e .mlflow.repo - - name: Run integration tests - run: mage test:python - # Temporary workaround for failing tests - continue-on-error: ${{ matrix.runner != 'ubuntu-latest' }} +name: Test + +on: + workflow_call: + +permissions: + contents: read + +jobs: + go: + name: Test Go + strategy: + matrix: + runner: [macos-latest, ubuntu-latest, windows-latest] + runs-on: ${{ matrix.runner }} + steps: + - uses: actions/checkout@v4 + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version: "1.22" + check-latest: true + cache: false + - name: Install mage + run: go install github.com/magefile/mage@v1.15.0 + - name: Run unit tests + run: mage test:unit + + python: + name: Test Python + strategy: + matrix: + runner: [macos-latest, ubuntu-latest, windows-latest] + python: ["3.8", "3.9", "3.10", "3.11", "3.12"] + runs-on: ${{ matrix.runner }} + steps: + - uses: actions/checkout@v4 + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python }} + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version: "1.22" + check-latest: true + cache: false + - name: Install mage + run: go install github.com/magefile/mage@v1.15.0 + - name: Install our package in editable mode + run: pip install -e . 
+ - name: Initialize MLflow repo + run: mage repo:init + - name: Install dependencies + run: pip install pytest==8.1.1 psycopg2-binary -e .mlflow.repo + - name: Run integration tests + run: mage test:python + # Temporary workaround for failing tests + continue-on-error: ${{ matrix.runner != 'ubuntu-latest' }} diff --git a/.gitignore b/.gitignore index af82b33..cf9ad19 100644 --- a/.gitignore +++ b/.gitignore @@ -1,16 +1,16 @@ -# Artifacts -dist/ -*.egg-info/ -*.so - -# Runs -mlruns/ - -# Cache -__pycache__/ - -# MLflow repo -.mlflow.repo/ - -# JetBrains +# Artifacts +dist/ +*.egg-info/ +*.so + +# Runs +mlruns/ + +# Cache +__pycache__/ + +# MLflow repo +.mlflow.repo/ + +# JetBrains .idea \ No newline at end of file diff --git a/.golangci.yml b/.golangci.yml index 383e76b..6252ee3 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,55 +1,55 @@ -run: - build-tags: - - mage - timeout: 5m - -linters: - enable: - - errcheck - - gosimple - - lll - disable: - - depguard - - gochecknoglobals # Immutable globals are fine. - - exhaustruct # Often the case for protobuf generated code or gorm structs. - - protogetter # We do want to use pointers for memory optimization. - presets: - - bugs - - comment - - complexity - - error - - format - - import - - metalinter - - module - - performance - - sql - - style - - test - - unused - -linters-settings: - gci: - custom-order: true - - sections: - - standard # Standard section: captures all standard packages. - - default # Default section: contains all imports that could not be matched to another section type. - - alias # Alias section: contains all alias imports. This section is not present unless explicitly enabled. - - prefix(github.com/mlflow/mlflow-go) # Custom section: groups all imports with the specified Prefix. - - blank # Blank section: contains all blank imports. This section is not present unless explicitly enabled. - - dot # Dot section: contains all dot imports. This section is not present unless explicitly enabled. 
- - gofumpt: - module-path: github.com/mlflow/mlflow-go - extra-rules: true - - tagliatelle: - case: - rules: - json: snake - -issues: - exclude-files: - - ".*\\.g\\.go$" - - ".*\\.pb\\.go$" +run: + build-tags: + - mage + timeout: 5m + +linters: + enable: + - errcheck + - gosimple + - lll + disable: + - depguard + - gochecknoglobals # Immutable globals are fine. + - exhaustruct # Often the case for protobuf generated code or gorm structs. + - protogetter # We do want to use pointers for memory optimization. + presets: + - bugs + - comment + - complexity + - error + - format + - import + - metalinter + - module + - performance + - sql + - style + - test + - unused + +linters-settings: + gci: + custom-order: true + + sections: + - standard # Standard section: captures all standard packages. + - default # Default section: contains all imports that could not be matched to another section type. + - alias # Alias section: contains all alias imports. This section is not present unless explicitly enabled. + - prefix(github.com/mlflow/mlflow-go) # Custom section: groups all imports with the specified Prefix. + - blank # Blank section: contains all blank imports. This section is not present unless explicitly enabled. + - dot # Dot section: contains all dot imports. This section is not present unless explicitly enabled. 
+ + gofumpt: + module-path: github.com/mlflow/mlflow-go + extra-rules: true + + tagliatelle: + case: + rules: + json: snake + +issues: + exclude-files: + - ".*\\.g\\.go$" + - ".*\\.pb\\.go$" diff --git a/.mockery.yaml b/.mockery.yaml index b3021f7..33ec946 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -1,8 +1,8 @@ -dir: "{{ .InterfaceDir }}" -filename: "mock_{{ .InterfaceNameSnake }}.go" -with-expecter: true -inpackage: true -packages: - github.com/mlflow/mlflow-go/pkg/tracking/store: - interfaces: - TrackingStore: +dir: "{{ .InterfaceDir }}" +filename: "mock_{{ .InterfaceNameSnake }}.go" +with-expecter: true +inpackage: true +packages: + github.com/mlflow/mlflow-go/pkg/tracking/store: + interfaces: + TrackingStore: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9bfe726..6dc1a46 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,34 +1,34 @@ -repos: - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 - hooks: - - id: end-of-file-fixer - files: \.(proto|txt|sh|rst)$ - - repo: https://github.com/golangci/golangci-lint - rev: "v1.59.1" - hooks: - - id: golangci-lint-full - - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.5.7 - hooks: - - id: ruff - types_or: [python, pyi, jupyter] - args: [--fix] - - id: ruff-format - types_or: [python, pyi, jupyter] - - repo: local - hooks: - # - id: rstcheck - # name: rstcheck - # entry: rstcheck - # language: system - # files: README.rst - # stages: [commit] - # require_serial: true - - - id: must-have-signoff - name: must-have-signoff - entry: 'grep "Signed-off-by:"' - language: system - stages: [prepare-commit-msg] - require_serial: true +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: end-of-file-fixer + files: \.(proto|txt|sh|rst)$ + - repo: https://github.com/golangci/golangci-lint + rev: "v1.59.1" + hooks: + - id: golangci-lint-full + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.5.7 
+ hooks: + - id: ruff + types_or: [python, pyi, jupyter] + args: [--fix] + - id: ruff-format + types_or: [python, pyi, jupyter] + - repo: local + hooks: + # - id: rstcheck + # name: rstcheck + # entry: rstcheck + # language: system + # files: README.rst + # stages: [commit] + # require_serial: true + + - id: must-have-signoff + name: must-have-signoff + entry: 'grep "Signed-off-by:"' + language: system + stages: [prepare-commit-msg] + require_serial: true diff --git a/LICENSE b/LICENSE index 261eeb9..29f81d8 100644 --- a/LICENSE +++ b/LICENSE @@ -1,201 +1,201 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. 
- - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. 
Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. 
You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. 
- - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. 
In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. 
- - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. 
+ + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." 
+ + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md index f4d3991..e184a13 100644 --- a/README.md +++ b/README.md @@ -1,316 +1,316 @@ -# Go backend for MLflow - -In order to increase the performance of the tracking server and the various stores, we propose to rewrite the server and store implementation in Go. - -## Usage - -### Installation - -This package is not yet available on PyPI and currently requires the [Go SDK](https://go.dev) to be installed. 
- -You can then install the package via pip: -```bash -pip install git+https://github.com/jgiannuzzi/mlflow-go.git -``` - -### Using the Go server - -```bash -# Start the Go server with a database URI -# Other databases are supported as well: postgresql, mysql and mssql -mlflow-go server --backend-store-uri sqlite:///mlflow.db -``` - -```python -import mlflow - -# Use the Go server -mlflow.set_tracking_uri("http://localhost:5000") - -# Use MLflow as usual -mlflow.set_experiment("my-experiment") - -with mlflow.start_run(): - mlflow.log_param("param", 1) - mlflow.log_metric("metric", 2) -``` - -### Using the client-side Go implementation - -```python -import mlflow -import mlflow_go - -# Enable the Go client implementation (disabled by default) -mlflow_go.enable_go() - -# Set the tracking URI (you can also set it via the environment variable MLFLOW_TRACKING_URI) -# Currently only database URIs are supported -mlflow.set_tracking_uri("sqlite:///mlflow.db") - -# Use MLflow as usual -mlflow.set_experiment("my-experiment") - -with mlflow.start_run(): - mlflow.log_param("param", 1) - mlflow.log_metric("metric", 2) -``` - -## Temp stuff - -### Dev setup - -```bash -# Install our Python package and its dependencies -pip install -e . - -# Install the dreaded psycho -pip install psycopg2-binary - -# Archive the MLFlow pre-built UI -tar -C /usr/local/python/current/lib/python3.8/site-packages/mlflow -czvf ./ui.tgz ./server/js/build - -# Clone the MLflow repo -git clone https://github.com/jgiannuzzi/mlflow.git -b master .mlflow.repo - -# Add the UI back to it -tar -C .mlflow.repo/mlflow -xzvf ./ui.tgz - -# Install it in editable mode -pip install -e .mlflow.repo -``` - -or run `mage temp`. - -### Run the tests manually - -```bash -# Build the Go binary in a temporary directory -libpath=$(mktemp -d) -python -m mlflow_go.lib . $libpath - -# Run the tests (currently just the server ones) -MLFLOW_GO_LIBRARY_PATH=$libpath pytest --confcutdir=. 
\ - .mlflow.repo/tests/tracking/test_rest_tracking.py \ - .mlflow.repo/tests/tracking/test_model_registry.py \ - .mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py \ - .mlflow.repo/tests/store/model_registry/test_sqlalchemy_store.py \ - -k 'not [file' - -# Remove the Go binary -rm -rf $libpath - -# If you want to run a specific test with more verbosity -# -s for live output -# --log-level=debug for more verbosity (passed down to the Go server/stores) -MLFLOW_GO_LIBRARY_PATH=$libpath pytest --confcutdir=. \ - .mlflow.repo/tests/tracking/test_rest_tracking.py::test_create_experiment_validation \ - -k 'not [file' \ - -s --log-level=debug -``` - -Or run the `mage test:python` target. - -### Use the Go store directly in Python - -```python -import logging -import mlflow -import mlflow_go - -# Enable debug logging -logging.basicConfig() -logging.getLogger('mlflow_go').setLevel(logging.DEBUG) - -# Enable the Go client implementation (disabled by default) -mlflow_go.enable_go() - -# Instantiate the tracking store with a database URI -tracking_store = mlflow.tracking._tracking_service.utils._get_store('sqlite:///mlflow.db') - -# Call any tracking store method -tracking_store.get_experiment(0) - -# Instantiate the model registry store with a database URI -model_registry_store = mlflow.tracking._model_registry.utils._get_store('sqlite:///mlflow.db') - -# Call any model registry store method -model_registry_store.get_latest_versions("model") -``` - -## General setup - -### Mage - -This repository uses [mage](https://magefile.org/) to streamline some utilily functions. - -```bash -# Install mage (already done in the dev container) -go install github.com/magefile/mage@v1.15.0 - -# See all targets -mage - -# Execute single target -mage dev -``` - -The beauty of Mage is that we can use regular Go code for our scripting. -That being said, we are not married to this tool. - -### mlflow source code - -To integrate with MLflow, you need to include the source code. 
The [mlflow/mlflow](https://github.com/mlflow/mlflow/) repository contains proto files that define the tracking API. It also includes Python tests that we use to verify our Go implementation produces identical behaviour. - -We use a `.mlflow.ref` file to specify the exact location from which to pull our sources. The format should be `remote#reference`, where `remote` is a git remote and `reference` is a branch, tag, or commit SHA. - -If the `.mlflow.ref` file is modified and becomes out of sync with the current source files, the mage target will automatically detect this. To manually force a sync, you can run `mage repo:update`. - -### Protos - -To ensure we stay compatible with the Python implementation, we aim to generate as much as possible based on the `.proto` files. - -By running - -```bash -mage generate -``` - -Go code will be generated. Use the protos files from `.mlflow.repo` repository. - -This incudes the generation of: - -- Structs for each endpoint. ([pkg/protos](./protos/service.pb.go)) -- Go interfaces for each service. ([pkg/contract/service/*.g.go](./contract/service/tracking.g.go)) -- [fiber](https://gofiber.io/) routes for each endpoint. ([pkg/server/routes/*.g.go](./server/routes/tracking.g.go)) - -If there is any change in the proto files, this should ripple into the Go code. - -## Launching the Go server - -To enable use of the Go server, users can run the `mlflow-go server` command. - -```bash -mlflow-go server --backend-store-uri postgresql://postgres:postgres@localhost:5432/postgres -``` - -This will launch the python process as usual. Within Python, a random port is chosen to start the existing server and a Go child process is spawned. The Go server will use the user specified port (5000 by default) and spawn the actual Python server as its own child process (`gunicorn` or `waitress`). -Any incoming requests the Go server cannot process will be proxied to the existing Python server. 
- -Any Go-specific options can be passed with `--go-opts`, which takes a comma-separated list of key-value pairs. - -```bash -mlflow-go server --backend-store-uri postgresql://postgres:postgres@localhost:5432/postgres --go-opts log_level=debug,shutdown_timeout=5s -``` - -## Building the Go binary - -To ensure everything still compiles: - -```bash -go build -o /dev/null ./pkg/cmd/server -``` - -or - -```bash -python -m mlflow_go.lib . /tmp -``` - -## Request validation - -We use [Go validator](https://github.com/go-playground/validator) to validate all incoming request structs. -As the proto files don't specify any validation rules, we map them manually in [pkg/cmd/generate/validations.go](./cmd/generate/validations.go). - -Once the mapping has been done, validation will be invoked automatically in the generated fiber code. - -When the need arises, we can write custom validation function in [pkg/validation/validation.go](./validation/validation.go). - -## Data access - -Initially, we want to focus on supporting Postgres SQL. We chose [Gorm](https://gorm.io/) as ORM to interact with the database. - -We do not generate any Go code based on the database schema. Gorm has generation capabilities but they didn't fit our needs. The plan would be to eventually assert the current code stil matches the database schema via an intergration test. - -All the models use pointers for their fields. We do this for performance reasons and to distinguish between zero values and null values. - -## Testing strategy - -> [!WARNING] -> TODO rewrite this whole section - -The Python integration tests have been adapted to also run against the Go implementation. Just run them as usual, e.g. 
- -```bash -pytest tests/tracking/test_rest_tracking.py -``` - -To run only the tests targetting the Go implementation, you can use the `-k` flag: - -```bash -pytest tests/tracking/test_rest_tracking.py -k '[go-' -``` - -If you'd like to run a specific test and see its output 'live', you can use the `-s` flag: - -```bash -pytest -s "tests/tracking/test_rest_tracking.py::test_create_experiment_validation[go-postgresql]" -``` - -See the [pytest documentation](https://docs.pytest.org/en/8.2.x/how-to/usage.html#specifying-which-tests-to-run) for more details. - -## Supported endpoints - -The currently supported endpoints can be found by running - -```bash -mage endpoints -``` - -## Linters - -We have enabled various linters from [golangci-lint](https://golangci-lint.run/), you can run these via: - -```bash -pre-commit run golangci-lint --all-files -``` - -Sometimes `golangci-lint` can complain about unrelated files, run `golangci-lint cache clean` to clear the cache. - -## Failing tests - -The following Python tests are currently failing: - -``` -===================================================================================================================== short test summary info ====================================================================================================================== -FAILED .mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py::test_log_inputs_with_large_inputs_limit_check - AssertionError: assert {'digest': 'd...ema': '', ...} == {'digest': 'd...a': None, ...} -======================================================================================== 1 failed, 358 passed, 9 skipped, 128 deselected, 10 warnings in 227.64s (0:03:47) ========================================================================================= -``` - -## Debug Failing Tests - -Sometimes, it can be very useful to modify failing tests and use `print` statements to display the current state or differences between objects from Python or Go services. 
- -Adding `"-vv"` to the `pytest` command in `magefiles/tests.go` can also provide more information when assertions are not met. - -### Targeting Local Postgres in Integration Tests - -At times, you might want to apply store calls to your local database to investigate certain read operations via the local tracking server. - -You can achieve this by changing: - -```python -def test_search_runs_datasets(store: SqlAlchemyStore): -``` - -to: - -```python -def test_search_runs_datasets(): - db_uri = "postgresql://postgres:postgres@localhost:5432/postgres" - artifact_uri = Path("/tmp/artifacts") - artifact_uri.mkdir(exist_ok=True) - store = SqlAlchemyStore(db_uri, artifact_uri.as_uri()) -``` - +# Go backend for MLflow + +In order to increase the performance of the tracking server and the various stores, we propose to rewrite the server and store implementation in Go. + +## Usage + +### Installation + +This package is not yet available on PyPI and currently requires the [Go SDK](https://go.dev) to be installed. 
+ +You can then install the package via pip: +```bash +pip install git+https://github.com/jgiannuzzi/mlflow-go.git +``` + +### Using the Go server + +```bash +# Start the Go server with a database URI +# Other databases are supported as well: postgresql, mysql and mssql +mlflow-go server --backend-store-uri sqlite:///mlflow.db +``` + +```python +import mlflow + +# Use the Go server +mlflow.set_tracking_uri("http://localhost:5000") + +# Use MLflow as usual +mlflow.set_experiment("my-experiment") + +with mlflow.start_run(): + mlflow.log_param("param", 1) + mlflow.log_metric("metric", 2) +``` + +### Using the client-side Go implementation + +```python +import mlflow +import mlflow_go + +# Enable the Go client implementation (disabled by default) +mlflow_go.enable_go() + +# Set the tracking URI (you can also set it via the environment variable MLFLOW_TRACKING_URI) +# Currently only database URIs are supported +mlflow.set_tracking_uri("sqlite:///mlflow.db") + +# Use MLflow as usual +mlflow.set_experiment("my-experiment") + +with mlflow.start_run(): + mlflow.log_param("param", 1) + mlflow.log_metric("metric", 2) +``` + +## Temp stuff + +### Dev setup + +```bash +# Install our Python package and its dependencies +pip install -e . + +# Install the dreaded psycho +pip install psycopg2-binary + +# Archive the MLFlow pre-built UI +tar -C /usr/local/python/current/lib/python3.8/site-packages/mlflow -czvf ./ui.tgz ./server/js/build + +# Clone the MLflow repo +git clone https://github.com/jgiannuzzi/mlflow.git -b master .mlflow.repo + +# Add the UI back to it +tar -C .mlflow.repo/mlflow -xzvf ./ui.tgz + +# Install it in editable mode +pip install -e .mlflow.repo +``` + +or run `mage temp`. + +### Run the tests manually + +```bash +# Build the Go binary in a temporary directory +libpath=$(mktemp -d) +python -m mlflow_go.lib . $libpath + +# Run the tests (currently just the server ones) +MLFLOW_GO_LIBRARY_PATH=$libpath pytest --confcutdir=. 
\ + .mlflow.repo/tests/tracking/test_rest_tracking.py \ + .mlflow.repo/tests/tracking/test_model_registry.py \ + .mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py \ + .mlflow.repo/tests/store/model_registry/test_sqlalchemy_store.py \ + -k 'not [file' + +# Remove the Go binary +rm -rf $libpath + +# If you want to run a specific test with more verbosity +# -s for live output +# --log-level=debug for more verbosity (passed down to the Go server/stores) +MLFLOW_GO_LIBRARY_PATH=$libpath pytest --confcutdir=. \ + .mlflow.repo/tests/tracking/test_rest_tracking.py::test_create_experiment_validation \ + -k 'not [file' \ + -s --log-level=debug +``` + +Or run the `mage test:python` target. + +### Use the Go store directly in Python + +```python +import logging +import mlflow +import mlflow_go + +# Enable debug logging +logging.basicConfig() +logging.getLogger('mlflow_go').setLevel(logging.DEBUG) + +# Enable the Go client implementation (disabled by default) +mlflow_go.enable_go() + +# Instantiate the tracking store with a database URI +tracking_store = mlflow.tracking._tracking_service.utils._get_store('sqlite:///mlflow.db') + +# Call any tracking store method +tracking_store.get_experiment(0) + +# Instantiate the model registry store with a database URI +model_registry_store = mlflow.tracking._model_registry.utils._get_store('sqlite:///mlflow.db') + +# Call any model registry store method +model_registry_store.get_latest_versions("model") +``` + +## General setup + +### Mage + +This repository uses [mage](https://magefile.org/) to streamline some utilily functions. + +```bash +# Install mage (already done in the dev container) +go install github.com/magefile/mage@v1.15.0 + +# See all targets +mage + +# Execute single target +mage dev +``` + +The beauty of Mage is that we can use regular Go code for our scripting. +That being said, we are not married to this tool. + +### mlflow source code + +To integrate with MLflow, you need to include the source code. 
The [mlflow/mlflow](https://github.com/mlflow/mlflow/) repository contains proto files that define the tracking API. It also includes Python tests that we use to verify our Go implementation produces identical behaviour.
+
+We use a `.mlflow.ref` file to specify the exact location from which to pull our sources. The format should be `remote#reference`, where `remote` is a git remote and `reference` is a branch, tag, or commit SHA.
+
+If the `.mlflow.ref` file is modified and becomes out of sync with the current source files, the mage target will automatically detect this. To manually force a sync, you can run `mage repo:update`.
+
+### Protos
+
+To ensure we stay compatible with the Python implementation, we aim to generate as much as possible based on the `.proto` files.
+
+By running
+
+```bash
+mage generate
+```
+
+Go code will be generated. Use the protos files from `.mlflow.repo` repository.
+
+This includes the generation of:
+
+- Structs for each endpoint. ([pkg/protos](./protos/service.pb.go))
+- Go interfaces for each service. ([pkg/contract/service/*.g.go](./contract/service/tracking.g.go))
+- [fiber](https://gofiber.io/) routes for each endpoint. ([pkg/server/routes/*.g.go](./server/routes/tracking.g.go))
+
+If there is any change in the proto files, this should ripple into the Go code.
+
+## Launching the Go server
+
+To enable use of the Go server, users can run the `mlflow-go server` command.
+
+```bash
+mlflow-go server --backend-store-uri postgresql://postgres:postgres@localhost:5432/postgres
+```
+
+This will launch the python process as usual. Within Python, a random port is chosen to start the existing server and a Go child process is spawned. The Go server will use the user specified port (5000 by default) and spawn the actual Python server as its own child process (`gunicorn` or `waitress`).
+Any incoming requests the Go server cannot process will be proxied to the existing Python server.
+
+Any Go-specific options can be passed with `--go-opts`, which takes a comma-separated list of key-value pairs.
+
+```bash
+mlflow-go server --backend-store-uri postgresql://postgres:postgres@localhost:5432/postgres --go-opts log_level=debug,shutdown_timeout=5s
+```
+
+## Building the Go binary
+
+To ensure everything still compiles:
+
+```bash
+go build -o /dev/null ./pkg/cmd/server
+```
+
+or
+
+```bash
+python -m mlflow_go.lib . /tmp
+```
+
+## Request validation
+
+We use [Go validator](https://github.com/go-playground/validator) to validate all incoming request structs.
+As the proto files don't specify any validation rules, we map them manually in [pkg/cmd/generate/validations.go](./cmd/generate/validations.go).
+
+Once the mapping has been done, validation will be invoked automatically in the generated fiber code.
+
+When the need arises, we can write custom validation functions in [pkg/validation/validation.go](./validation/validation.go).
+
+## Data access
+
+Initially, we want to focus on supporting Postgres SQL. We chose [Gorm](https://gorm.io/) as ORM to interact with the database.
+
+We do not generate any Go code based on the database schema. Gorm has generation capabilities but they didn't fit our needs. The plan would be to eventually assert the current code still matches the database schema via an integration test.
+
+All the models use pointers for their fields. We do this for performance reasons and to distinguish between zero values and null values.
+
+## Testing strategy
+
+> [!WARNING]
+> TODO rewrite this whole section
+
+The Python integration tests have been adapted to also run against the Go implementation. Just run them as usual, e.g.
+ +```bash +pytest tests/tracking/test_rest_tracking.py +``` + +To run only the tests targetting the Go implementation, you can use the `-k` flag: + +```bash +pytest tests/tracking/test_rest_tracking.py -k '[go-' +``` + +If you'd like to run a specific test and see its output 'live', you can use the `-s` flag: + +```bash +pytest -s "tests/tracking/test_rest_tracking.py::test_create_experiment_validation[go-postgresql]" +``` + +See the [pytest documentation](https://docs.pytest.org/en/8.2.x/how-to/usage.html#specifying-which-tests-to-run) for more details. + +## Supported endpoints + +The currently supported endpoints can be found by running + +```bash +mage endpoints +``` + +## Linters + +We have enabled various linters from [golangci-lint](https://golangci-lint.run/), you can run these via: + +```bash +pre-commit run golangci-lint --all-files +``` + +Sometimes `golangci-lint` can complain about unrelated files, run `golangci-lint cache clean` to clear the cache. + +## Failing tests + +The following Python tests are currently failing: + +``` +===================================================================================================================== short test summary info ====================================================================================================================== +FAILED .mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py::test_log_inputs_with_large_inputs_limit_check - AssertionError: assert {'digest': 'd...ema': '', ...} == {'digest': 'd...a': None, ...} +======================================================================================== 1 failed, 358 passed, 9 skipped, 128 deselected, 10 warnings in 227.64s (0:03:47) ========================================================================================= +``` + +## Debug Failing Tests + +Sometimes, it can be very useful to modify failing tests and use `print` statements to display the current state or differences between objects from Python or Go services. 
+ +Adding `"-vv"` to the `pytest` command in `magefiles/tests.go` can also provide more information when assertions are not met. + +### Targeting Local Postgres in Integration Tests + +At times, you might want to apply store calls to your local database to investigate certain read operations via the local tracking server. + +You can achieve this by changing: + +```python +def test_search_runs_datasets(store: SqlAlchemyStore): +``` + +to: + +```python +def test_search_runs_datasets(): + db_uri = "postgresql://postgres:postgres@localhost:5432/postgres" + artifact_uri = Path("/tmp/artifacts") + artifact_uri.mkdir(exist_ok=True) + store = SqlAlchemyStore(db_uri, artifact_uri.as_uri()) +``` + in the test file located in `.mlflow.repo`. \ No newline at end of file diff --git a/conftest.py b/conftest.py index 4ea53bd..999a2a7 100644 --- a/conftest.py +++ b/conftest.py @@ -1,64 +1,64 @@ -import logging -import pathlib -from unittest.mock import patch - -_logger = logging.getLogger(__name__) - - -def load_new_function(file_path, func_name): - with open(file_path) as f: - new_func_code = f.read() - - local_dict = {} - exec(new_func_code, local_dict) - return local_dict[func_name] - - -def pytest_configure(config): - for func_to_patch, new_func_file_relative in ( - ( - "tests.tracking.integration_test_utils._init_server", - "tests/override_server.py", - ), - ( - "mlflow.store.tracking.sqlalchemy_store.SqlAlchemyStore", - "tests/override_tracking_store.py", - ), - ( - "mlflow.store.model_registry.sqlalchemy_store.SqlAlchemyStore", - "tests/override_model_registry_store.py", - ), - # This test will patch some Python internals to invoke an internal exception. - # We cannot do this in Go. 
- ( - "tests.store.tracking.test_sqlalchemy_store.test_log_batch_internal_error", - "tests/override_test_sqlalchemy_store.py", - ), - # This test uses monkeypatch.setenv which does not flow through to the - ( - "tests.store.tracking.test_sqlalchemy_store.test_log_batch_params_max_length_value", - "tests/override_test_sqlalchemy_store.py", - ), - # This tests calls the store using invalid metric entity that cannot be converted - # to its proto counterpart. - # Example: entities.Metric("invalid_metric", None, (int(time.time() * 1000)), 0).to_proto() - ( - "tests.store.tracking.test_sqlalchemy_store.test_log_batch_null_metrics", - "tests/override_test_sqlalchemy_store.py", - ), - # We do not support applying the SQL schema to sqlite like Python does. - # So we do not support sqlite:////:memory: database. - ( - "tests.store.tracking.test_sqlalchemy_store.test_sqlalchemy_store_behaves_as_expected_with_inmemory_sqlite_db", - "tests/override_test_sqlalchemy_store.py", - ), - ): - func_name = func_to_patch.rsplit(".", 1)[1] - new_func_file = ( - pathlib.Path(__file__).parent.joinpath(new_func_file_relative).resolve().as_posix() - ) - - new_func = load_new_function(new_func_file, func_name) - - _logger.info(f"Patching function: {func_to_patch}") - patch(func_to_patch, new_func).start() +import logging +import pathlib +from unittest.mock import patch + +_logger = logging.getLogger(__name__) + + +def load_new_function(file_path, func_name): + with open(file_path) as f: + new_func_code = f.read() + + local_dict = {} + exec(new_func_code, local_dict) + return local_dict[func_name] + + +def pytest_configure(config): + for func_to_patch, new_func_file_relative in ( + ( + "tests.tracking.integration_test_utils._init_server", + "tests/override_server.py", + ), + ( + "mlflow.store.tracking.sqlalchemy_store.SqlAlchemyStore", + "tests/override_tracking_store.py", + ), + ( + "mlflow.store.model_registry.sqlalchemy_store.SqlAlchemyStore", + "tests/override_model_registry_store.py", + ), 
+ # This test will patch some Python internals to invoke an internal exception. + # We cannot do this in Go. + ( + "tests.store.tracking.test_sqlalchemy_store.test_log_batch_internal_error", + "tests/override_test_sqlalchemy_store.py", + ), + # This test uses monkeypatch.setenv which does not flow through to the + ( + "tests.store.tracking.test_sqlalchemy_store.test_log_batch_params_max_length_value", + "tests/override_test_sqlalchemy_store.py", + ), + # This tests calls the store using invalid metric entity that cannot be converted + # to its proto counterpart. + # Example: entities.Metric("invalid_metric", None, (int(time.time() * 1000)), 0).to_proto() + ( + "tests.store.tracking.test_sqlalchemy_store.test_log_batch_null_metrics", + "tests/override_test_sqlalchemy_store.py", + ), + # We do not support applying the SQL schema to sqlite like Python does. + # So we do not support sqlite:////:memory: database. + ( + "tests.store.tracking.test_sqlalchemy_store.test_sqlalchemy_store_behaves_as_expected_with_inmemory_sqlite_db", + "tests/override_test_sqlalchemy_store.py", + ), + ): + func_name = func_to_patch.rsplit(".", 1)[1] + new_func_file = ( + pathlib.Path(__file__).parent.joinpath(new_func_file_relative).resolve().as_posix() + ) + + new_func = load_new_function(new_func_file, func_name) + + _logger.info(f"Patching function: {func_to_patch}") + patch(func_to_patch, new_func).start() diff --git a/go.mod b/go.mod index 44ef743..25ad617 100644 --- a/go.mod +++ b/go.mod @@ -1,62 +1,62 @@ -module github.com/mlflow/mlflow-go - -go 1.22 - -require ( - github.com/DATA-DOG/go-sqlmock v1.5.2 - github.com/go-playground/validator/v10 v10.20.0 - github.com/gofiber/fiber/v2 v2.52.4 - github.com/google/uuid v1.6.0 - github.com/iancoleman/strcase v0.3.0 - github.com/magefile/mage v1.15.0 - github.com/sirupsen/logrus v1.9.3 - github.com/stretchr/testify v1.8.4 - github.com/tidwall/gjson v1.17.1 - golang.org/x/sys v0.20.0 - google.golang.org/protobuf v1.34.1 - gorm.io/driver/mysql 
v1.5.6 - gorm.io/driver/postgres v1.5.7 - gorm.io/driver/sqlite v1.5.6 - gorm.io/driver/sqlserver v1.5.3 - gorm.io/gorm v1.25.10 -) - -require ( - github.com/andybalholm/brotli v1.1.0 // indirect - github.com/davecgh/go-spew v1.1.1 // indirect - github.com/gabriel-vasile/mimetype v1.4.3 // indirect - github.com/go-playground/locales v0.14.1 // indirect - github.com/go-playground/universal-translator v0.18.1 // indirect - github.com/go-sql-driver/mysql v1.7.0 // indirect - github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect - github.com/golang-sql/sqlexp v0.1.0 // indirect - github.com/jackc/pgpassfile v1.0.0 // indirect - github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect - github.com/jackc/pgx/v5 v5.5.5 // indirect - github.com/jackc/puddle/v2 v2.2.1 // indirect - github.com/jinzhu/inflection v1.0.0 // indirect - github.com/jinzhu/now v1.1.5 // indirect - github.com/klauspost/compress v1.17.8 // indirect - github.com/kr/text v0.2.0 // indirect - github.com/leodido/go-urn v1.4.0 // indirect - github.com/mattn/go-colorable v0.1.13 // indirect - github.com/mattn/go-isatty v0.0.20 // indirect - github.com/mattn/go-runewidth v0.0.15 // indirect - github.com/mattn/go-sqlite3 v1.14.22 // indirect - github.com/microsoft/go-mssqldb v1.6.0 // indirect - github.com/olekukonko/tablewriter v0.0.5 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/rivo/uniseg v0.4.7 // indirect - github.com/rogpeppe/go-internal v1.12.0 // indirect - github.com/stretchr/objx v0.5.0 // indirect - github.com/tidwall/match v1.1.1 // indirect - github.com/tidwall/pretty v1.2.0 // indirect - github.com/valyala/bytebufferpool v1.0.0 // indirect - github.com/valyala/fasthttp v1.53.0 // indirect - github.com/valyala/tcplisten v1.0.0 // indirect - golang.org/x/crypto v0.23.0 // indirect - golang.org/x/net v0.25.0 // indirect - golang.org/x/sync v0.7.0 // indirect - golang.org/x/text v0.15.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // 
indirect -) +module github.com/mlflow/mlflow-go + +go 1.22 + +require ( + github.com/DATA-DOG/go-sqlmock v1.5.2 + github.com/go-playground/validator/v10 v10.20.0 + github.com/gofiber/fiber/v2 v2.52.4 + github.com/google/uuid v1.6.0 + github.com/iancoleman/strcase v0.3.0 + github.com/magefile/mage v1.15.0 + github.com/sirupsen/logrus v1.9.3 + github.com/stretchr/testify v1.8.4 + github.com/tidwall/gjson v1.17.1 + golang.org/x/sys v0.20.0 + google.golang.org/protobuf v1.34.1 + gorm.io/driver/mysql v1.5.6 + gorm.io/driver/postgres v1.5.7 + gorm.io/driver/sqlite v1.5.6 + gorm.io/driver/sqlserver v1.5.3 + gorm.io/gorm v1.25.10 +) + +require ( + github.com/andybalholm/brotli v1.1.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/gabriel-vasile/mimetype v1.4.3 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/go-sql-driver/mysql v1.7.0 // indirect + github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect + github.com/golang-sql/sqlexp v0.1.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect + github.com/jackc/pgx/v5 v5.5.5 // indirect + github.com/jackc/puddle/v2 v2.2.1 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + github.com/klauspost/compress v1.17.8 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/leodido/go-urn v1.4.0 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-runewidth v0.0.15 // indirect + github.com/mattn/go-sqlite3 v1.14.22 // indirect + github.com/microsoft/go-mssqldb v1.6.0 // indirect + github.com/olekukonko/tablewriter v0.0.5 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rivo/uniseg v0.4.7 // indirect + github.com/rogpeppe/go-internal 
v1.12.0 // indirect + github.com/stretchr/objx v0.5.0 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.0 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasthttp v1.53.0 // indirect + github.com/valyala/tcplisten v1.0.0 // indirect + golang.org/x/crypto v0.23.0 // indirect + golang.org/x/net v0.25.0 // indirect + golang.org/x/sync v0.7.0 // indirect + golang.org/x/text v0.15.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum index f58db98..16c2faf 100644 --- a/go.sum +++ b/go.sum @@ -1,236 +1,236 @@ -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.4.0/go.mod h1:ON4tFdPTwRcgWEaVDrN3584Ef+b7GgSJaXxe5fW9t4M= -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q= -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.1/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q= -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.1 h1:/iHxaJhsFr0+xVFfbMr5vxz848jyiWuIEDhYq3y5odY= -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.1/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q= -github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0 h1:vcYCAze6p19qBW7MhZybIsqD8sMV8js0NyQM8JDnVtg= -github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0/go.mod h1:OQeznEEkTZ9OrhHJoDD8ZDq51FHgXjqtP9z6bEwBq9U= -github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2/go.mod h1:eWRD7oawr1Mu1sLCawqVc0CUiF43ia3qQMxLscsKQ9w= -github.com/Azure/azure-sdk-for-go/sdk/internal v1.2.0/go.mod h1:eWRD7oawr1Mu1sLCawqVc0CUiF43ia3qQMxLscsKQ9w= -github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0 h1:sXr+ck84g/ZlZUOZiNELInmMgOsuGwdjjVkEIde0OtY= -github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0/go.mod h1:okt5dMMTOFjX/aovMlrjvvXoPMBVSPzk9185BT0+eZM= -github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.0.0 h1:yfJe15aSwEQ6Oo6J+gdfdulPNoZ3TEhmbhLIoxZcA+U= -github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.0.0/go.mod 
h1:Q28U+75mpCaSCDowNEmhIo/rmgdkqmkmzI7N6TGR4UY= -github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v0.8.0 h1:T028gtTPiYt/RMUfs8nVsAL7FDQrfLlrm/NnRG/zcC4= -github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v0.8.0/go.mod h1:cw4zVQgBby0Z5f2v0itn6se2dDP17nTjbZFXW5uPyHA= -github.com/AzureAD/microsoft-authentication-library-for-go v1.0.0/go.mod h1:kgDmCTgBzIEPFElEF+FK0SdjAor06dRq2Go927dnQ6o= -github.com/AzureAD/microsoft-authentication-library-for-go v1.1.0 h1:HCc0+LpPfpCKs6LGGLAhwBARt9632unrVcI6i8s/8os= -github.com/AzureAD/microsoft-authentication-library-for-go v1.1.0/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI= -github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= -github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= -github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M= -github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dnaeon/go-vcr v1.1.0/go.mod h1:M7tiix8f0r6mKKJ3Yq/kqU1OYf3MnfmBWVbPx/yU9ko= -github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ= -github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= -github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= -github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= -github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= -github.com/go-playground/locales v0.14.1 
h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= -github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= -github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= -github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= -github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8= -github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= -github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= -github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= -github.com/gofiber/fiber/v2 v2.52.4 h1:P+T+4iK7VaqUsq2PALYEfBBo6bJZ4q3FP8cZ84EggTM= -github.com/gofiber/fiber/v2 v2.52.4/go.mod h1:KEOE+cXMhXG0zHc9d8+E38hoX+ZN7bhOtgeF2oT6jrQ= -github.com/golang-jwt/jwt/v4 v4.4.3/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= -github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= -github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE= -github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= -github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA= -github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= -github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A= -github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI= -github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/google/uuid v1.6.0 
h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= -github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= -github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= -github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= -github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= -github.com/iancoleman/strcase v0.3.0 h1:nTXanmYxhfFAMjZL34Ov6gkzEsSJZ5DbhxWjvSASxEI= -github.com/iancoleman/strcase v0.3.0/go.mod h1:iwCmte+B7n89clKwxIoIXy/HfoL7AsD47ZCWhYzw7ho= -github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= -github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA= -github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw= -github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= -github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= -github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= -github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs= -github.com/jcmturner/dnsutils/v2 v2.0.0/go.mod h1:b0TnjGOvI/n42bZa+hmXL+kFJZsFT7G4t3HTlQ184QM= -github.com/jcmturner/gofork v1.7.6/go.mod h1:1622LH6i/EZqLloHfE7IeZ0uEJwMSUyQ/nDd82IeqRo= -github.com/jcmturner/goidentity/v6 v6.0.1/go.mod h1:X1YW3bgtvwAXju7V3LCIMpY0Gbxyjn/mY9zx4tFonSg= -github.com/jcmturner/gokrb5/v8 v8.4.4/go.mod h1:1btQEpgT6k+unzCwX1KdWMEwPPkkgBtP+F6aCACiMrs= -github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc= 
-github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= -github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= -github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= -github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= -github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= -github.com/klauspost/compress v1.17.8 h1:YcnTYrq7MikUT7k0Yb5eceMmALQPYBW/Xltxn0NAMnU= -github.com/klauspost/compress v1.17.8/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= -github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= -github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= -github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= -github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= -github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= -github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= -github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= -github.com/magefile/mage v1.15.0 h1:BvGheCMAsG3bWUDbZ8AyXXpCNwU9u5CB6sM+HNb9HYg= -github.com/magefile/mage v1.15.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= -github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= -github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= -github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= -github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= -github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-runewidth v0.0.9/go.mod 
h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= -github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= -github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= -github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= -github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/microsoft/go-mssqldb v1.6.0 h1:mM3gYdVwEPFrlg/Dvr2DNVEgYFG7L42l+dGc67NNNpc= -github.com/microsoft/go-mssqldb v1.6.0/go.mod h1:00mDtPbeQCRGC1HwOOR5K/gr30P1NcEG0vx6Kbv2aJU= -github.com/modocache/gover v0.0.0-20171022184752-b58185e213c5/go.mod h1:caMODM3PzxT8aQXRPkAt8xlV/e7d7w8GM5g0fa5F0D8= -github.com/montanaflynn/stats v0.7.0/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow= -github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= -github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= -github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 h1:KoWmjvw+nsYOo29YJK9vDA65RGE3NrOnUtO7a+RF9HU= -github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= -github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= -github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= -github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= -github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= -github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= -github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= 
-github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= -github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/tidwall/gjson v1.17.1 h1:wlYEnwqAHgzmhNUFfw7Xalt2JzQvsMx2Se4PcoFCT/U= -github.com/tidwall/gjson v1.17.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= -github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= -github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= -github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= -github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= -github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -github.com/valyala/fasthttp v1.53.0 h1:lW/+SUkOxCx2vlIu0iaImv4JLrVRnbbkpCoaawvA4zc= -github.com/valyala/fasthttp v1.53.0/go.mod 
h1:6dt4/8olwq9QARP/TDuPmWyWcl4byhpvTJ4AAtcz+QM= -github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8= -github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc= -github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= -golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= -golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= -golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= -golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= -golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= 
-golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= -golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= -golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= -golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 
-golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= -golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= -golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= -golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= -golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= -golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools 
v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= -google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= -gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gorm.io/driver/mysql v1.5.6 h1:Ld4mkIickM+EliaQZQx3uOJDJHtrd70MxAUqWqlx3Y8= -gorm.io/driver/mysql v1.5.6/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM= -gorm.io/driver/postgres v1.5.7 h1:8ptbNJTDbEmhdr62uReG5BGkdQyeasu/FZHxI0IMGnM= -gorm.io/driver/postgres v1.5.7/go.mod h1:3e019WlBaYI5o5LIdNV+LyxCMNtLOQETBXL2h4chKpA= 
-gorm.io/driver/sqlite v1.5.6 h1:fO/X46qn5NUEEOZtnjJRWRzZMe8nqJiQ9E+0hi+hKQE= -gorm.io/driver/sqlite v1.5.6/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDah4= -gorm.io/driver/sqlserver v1.5.3 h1:rjupPS4PVw+rjJkfvr8jn2lJ8BMhT4UW5FwuJY0P3Z0= -gorm.io/driver/sqlserver v1.5.3/go.mod h1:B+CZ0/7oFJ6tAlefsKoyxdgDCXJKSgwS2bMOQZT0I00= -gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= -gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= -gorm.io/gorm v1.25.10 h1:dQpO+33KalOA+aFYGlK+EfxcI5MbO7EP2yYygwh9h+s= -gorm.io/gorm v1.25.10/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.4.0/go.mod h1:ON4tFdPTwRcgWEaVDrN3584Ef+b7GgSJaXxe5fW9t4M= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.1/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.1 h1:/iHxaJhsFr0+xVFfbMr5vxz848jyiWuIEDhYq3y5odY= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.1/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0 h1:vcYCAze6p19qBW7MhZybIsqD8sMV8js0NyQM8JDnVtg= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0/go.mod h1:OQeznEEkTZ9OrhHJoDD8ZDq51FHgXjqtP9z6bEwBq9U= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2/go.mod h1:eWRD7oawr1Mu1sLCawqVc0CUiF43ia3qQMxLscsKQ9w= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.2.0/go.mod h1:eWRD7oawr1Mu1sLCawqVc0CUiF43ia3qQMxLscsKQ9w= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0 h1:sXr+ck84g/ZlZUOZiNELInmMgOsuGwdjjVkEIde0OtY= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0/go.mod h1:okt5dMMTOFjX/aovMlrjvvXoPMBVSPzk9185BT0+eZM= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.0.0 h1:yfJe15aSwEQ6Oo6J+gdfdulPNoZ3TEhmbhLIoxZcA+U= 
+github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.0.0/go.mod h1:Q28U+75mpCaSCDowNEmhIo/rmgdkqmkmzI7N6TGR4UY= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v0.8.0 h1:T028gtTPiYt/RMUfs8nVsAL7FDQrfLlrm/NnRG/zcC4= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v0.8.0/go.mod h1:cw4zVQgBby0Z5f2v0itn6se2dDP17nTjbZFXW5uPyHA= +github.com/AzureAD/microsoft-authentication-library-for-go v1.0.0/go.mod h1:kgDmCTgBzIEPFElEF+FK0SdjAor06dRq2Go927dnQ6o= +github.com/AzureAD/microsoft-authentication-library-for-go v1.1.0 h1:HCc0+LpPfpCKs6LGGLAhwBARt9632unrVcI6i8s/8os= +github.com/AzureAD/microsoft-authentication-library-for-go v1.1.0/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI= +github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= +github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= +github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M= +github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dnaeon/go-vcr v1.1.0/go.mod h1:M7tiix8f0r6mKKJ3Yq/kqU1OYf3MnfmBWVbPx/yU9ko= +github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ= +github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= +github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= 
+github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8= +github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= +github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= +github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/gofiber/fiber/v2 v2.52.4 h1:P+T+4iK7VaqUsq2PALYEfBBo6bJZ4q3FP8cZ84EggTM= +github.com/gofiber/fiber/v2 v2.52.4/go.mod h1:KEOE+cXMhXG0zHc9d8+E38hoX+ZN7bhOtgeF2oT6jrQ= +github.com/golang-jwt/jwt/v4 v4.4.3/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE= +github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA= +github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= +github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A= +github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI= +github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 
+github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= +github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= +github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/iancoleman/strcase v0.3.0 h1:nTXanmYxhfFAMjZL34Ov6gkzEsSJZ5DbhxWjvSASxEI= +github.com/iancoleman/strcase v0.3.0/go.mod h1:iwCmte+B7n89clKwxIoIXy/HfoL7AsD47ZCWhYzw7ho= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA= +github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw= +github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= +github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= +github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs= +github.com/jcmturner/dnsutils/v2 v2.0.0/go.mod h1:b0TnjGOvI/n42bZa+hmXL+kFJZsFT7G4t3HTlQ184QM= +github.com/jcmturner/gofork v1.7.6/go.mod h1:1622LH6i/EZqLloHfE7IeZ0uEJwMSUyQ/nDd82IeqRo= +github.com/jcmturner/goidentity/v6 v6.0.1/go.mod h1:X1YW3bgtvwAXju7V3LCIMpY0Gbxyjn/mY9zx4tFonSg= +github.com/jcmturner/gokrb5/v8 v8.4.4/go.mod h1:1btQEpgT6k+unzCwX1KdWMEwPPkkgBtP+F6aCACiMrs= +github.com/jcmturner/rpc/v2 v2.0.3/go.mod 
h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= +github.com/klauspost/compress v1.17.8 h1:YcnTYrq7MikUT7k0Yb5eceMmALQPYBW/Xltxn0NAMnU= +github.com/klauspost/compress v1.17.8/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= +github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= +github.com/magefile/mage v1.15.0 h1:BvGheCMAsG3bWUDbZ8AyXXpCNwU9u5CB6sM+HNb9HYg= +github.com/magefile/mage v1.15.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= 
+github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= +github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= +github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/microsoft/go-mssqldb v1.6.0 h1:mM3gYdVwEPFrlg/Dvr2DNVEgYFG7L42l+dGc67NNNpc= +github.com/microsoft/go-mssqldb v1.6.0/go.mod h1:00mDtPbeQCRGC1HwOOR5K/gr30P1NcEG0vx6Kbv2aJU= +github.com/modocache/gover v0.0.0-20171022184752-b58185e213c5/go.mod h1:caMODM3PzxT8aQXRPkAt8xlV/e7d7w8GM5g0fa5F0D8= +github.com/montanaflynn/stats v0.7.0/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow= +github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= +github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= +github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 h1:KoWmjvw+nsYOo29YJK9vDA65RGE3NrOnUtO7a+RF9HU= +github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= +github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= +github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod 
h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/tidwall/gjson v1.17.1 h1:wlYEnwqAHgzmhNUFfw7Xalt2JzQvsMx2Se4PcoFCT/U= +github.com/tidwall/gjson v1.17.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.53.0 h1:lW/+SUkOxCx2vlIu0iaImv4JLrVRnbbkpCoaawvA4zc= 
+github.com/valyala/fasthttp v1.53.0/go.mod h1:6dt4/8olwq9QARP/TDuPmWyWcl4byhpvTJ4AAtcz+QM= +github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8= +github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= +golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= +golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= +golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= +golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= +golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod 
h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= +golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= +golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod 
h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= +golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= +golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= +golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= +golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod 
h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= +google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/mysql v1.5.6 h1:Ld4mkIickM+EliaQZQx3uOJDJHtrd70MxAUqWqlx3Y8= +gorm.io/driver/mysql v1.5.6/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM= +gorm.io/driver/postgres v1.5.7 h1:8ptbNJTDbEmhdr62uReG5BGkdQyeasu/FZHxI0IMGnM= +gorm.io/driver/postgres 
v1.5.7/go.mod h1:3e019WlBaYI5o5LIdNV+LyxCMNtLOQETBXL2h4chKpA= +gorm.io/driver/sqlite v1.5.6 h1:fO/X46qn5NUEEOZtnjJRWRzZMe8nqJiQ9E+0hi+hKQE= +gorm.io/driver/sqlite v1.5.6/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDah4= +gorm.io/driver/sqlserver v1.5.3 h1:rjupPS4PVw+rjJkfvr8jn2lJ8BMhT4UW5FwuJY0P3Z0= +gorm.io/driver/sqlserver v1.5.3/go.mod h1:B+CZ0/7oFJ6tAlefsKoyxdgDCXJKSgwS2bMOQZT0I00= +gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= +gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= +gorm.io/gorm v1.25.10 h1:dQpO+33KalOA+aFYGlK+EfxcI5MbO7EP2yYygwh9h+s= +gorm.io/gorm v1.25.10/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= diff --git a/magefiles/dev.go b/magefiles/dev.go index 9b6d739..1012c84 100644 --- a/magefiles/dev.go +++ b/magefiles/dev.go @@ -1,28 +1,28 @@ -//go:build mage - -//nolint:wrapcheck -package main - -import ( - "github.com/magefile/mage/mg" - "github.com/magefile/mage/sh" -) - -// Start the mlflow-go dev server connecting to postgres. -func Dev() error { - mg.Deps(Generate) - - envs := make(map[string]string) - envs["MLFLOW_TRUNCATE_LONG_VALUES"] = "false" - envs["MLFLOW_SQLALCHEMYSTORE_ECHO"] = "true" - - return sh.RunWithV( - envs, - "mlflow-go", - "server", - "--backend-store-uri", - "postgresql://postgres:postgres@localhost:5432/postgres", - "--go-opts", - "log_level=debug,shutdown_timeout=5s", - ) -} +//go:build mage + +//nolint:wrapcheck +package main + +import ( + "github.com/magefile/mage/mg" + "github.com/magefile/mage/sh" +) + +// Start the mlflow-go dev server connecting to postgres. 
+func Dev() error { + mg.Deps(Generate) + + envs := make(map[string]string) + envs["MLFLOW_TRUNCATE_LONG_VALUES"] = "false" + envs["MLFLOW_SQLALCHEMYSTORE_ECHO"] = "true" + + return sh.RunWithV( + envs, + "mlflow-go", + "server", + "--backend-store-uri", + "postgresql://postgres:postgres@localhost:5432/postgres", + "--go-opts", + "log_level=debug,shutdown_timeout=5s", + ) +} diff --git a/magefiles/endpoints.go b/magefiles/endpoints.go index 5239cac..a7f3163 100644 --- a/magefiles/endpoints.go +++ b/magefiles/endpoints.go @@ -1,56 +1,56 @@ -//go:build mage - -//nolint:wrapcheck -package main - -import ( - "os" - - "github.com/olekukonko/tablewriter" - - "github.com/mlflow/mlflow-go/magefiles/generate" - "github.com/mlflow/mlflow-go/magefiles/generate/discovery" -) - -func contains(slice []string, value string) bool { - for _, v := range slice { - if v == value { - return true - } - } - - return false -} - -// Print an overview of implementated API endpoints. -func Endpoints() error { - services, err := discovery.GetServiceInfos() - if err != nil { - return err - } - - table := tablewriter.NewWriter(os.Stdout) - table.SetHeader([]string{"Service", "Endpoint", "Implemented"}) - table.SetColumnAlignment([]int{tablewriter.ALIGN_LEFT, tablewriter.ALIGN_LEFT, tablewriter.ALIGN_CENTER}) - table.SetRowLine(true) - - for _, service := range services { - servinceInfo, ok := generate.ServiceInfoMap[service.Name] - if !ok { - continue - } - - for _, method := range service.Methods { - implemented := "No" - if contains(servinceInfo.ImplementedEndpoints, method.Name) { - implemented = "Yes" - } - - table.Append([]string{service.Name, method.Name, implemented}) - } - } - - table.Render() - - return nil -} +//go:build mage + +//nolint:wrapcheck +package main + +import ( + "os" + + "github.com/olekukonko/tablewriter" + + "github.com/mlflow/mlflow-go/magefiles/generate" + "github.com/mlflow/mlflow-go/magefiles/generate/discovery" +) + +func contains(slice []string, value string) bool 
{ + for _, v := range slice { + if v == value { + return true + } + } + + return false +} + +// Print an overview of implementated API endpoints. +func Endpoints() error { + services, err := discovery.GetServiceInfos() + if err != nil { + return err + } + + table := tablewriter.NewWriter(os.Stdout) + table.SetHeader([]string{"Service", "Endpoint", "Implemented"}) + table.SetColumnAlignment([]int{tablewriter.ALIGN_LEFT, tablewriter.ALIGN_LEFT, tablewriter.ALIGN_CENTER}) + table.SetRowLine(true) + + for _, service := range services { + servinceInfo, ok := generate.ServiceInfoMap[service.Name] + if !ok { + continue + } + + for _, method := range service.Methods { + implemented := "No" + if contains(servinceInfo.ImplementedEndpoints, method.Name) { + implemented = "Yes" + } + + table.Append([]string{service.Name, method.Name, implemented}) + } + } + + table.Render() + + return nil +} diff --git a/magefiles/generate.go b/magefiles/generate.go index 7b22a58..2ed8537 100644 --- a/magefiles/generate.go +++ b/magefiles/generate.go @@ -1,45 +1,45 @@ -//go:build mage - -//nolint:wrapcheck -package main - -import ( - "path" - "path/filepath" - - "github.com/gofiber/fiber/v2/log" - "github.com/magefile/mage/mg" - - "github.com/mlflow/mlflow-go/magefiles/generate" -) - -// Generate Go files based on proto files and other configuration. 
-func Generate() error { - mg.Deps(Repo.Init) - - protoFolder, err := filepath.Abs(path.Join(MLFlowRepoFolderName, "mlflow", "protos")) - if err != nil { - return err - } - - if err := generate.RunProtoc(protoFolder); err != nil { - return err - } - - pkgFolder, err := filepath.Abs("pkg") - if err != nil { - return err - } - - if err := generate.AddQueryAnnotations(pkgFolder); err != nil { - return err - } - - if err := generate.SourceCode(pkgFolder); err != nil { - return err - } - - log.Info("Successfully added query annotations and generated services!") - - return nil -} +//go:build mage + +//nolint:wrapcheck +package main + +import ( + "path" + "path/filepath" + + "github.com/gofiber/fiber/v2/log" + "github.com/magefile/mage/mg" + + "github.com/mlflow/mlflow-go/magefiles/generate" +) + +// Generate Go files based on proto files and other configuration. +func Generate() error { + mg.Deps(Repo.Init) + + protoFolder, err := filepath.Abs(path.Join(MLFlowRepoFolderName, "mlflow", "protos")) + if err != nil { + return err + } + + if err := generate.RunProtoc(protoFolder); err != nil { + return err + } + + pkgFolder, err := filepath.Abs("pkg") + if err != nil { + return err + } + + if err := generate.AddQueryAnnotations(pkgFolder); err != nil { + return err + } + + if err := generate.SourceCode(pkgFolder); err != nil { + return err + } + + log.Info("Successfully added query annotations and generated services!") + + return nil +} diff --git a/magefiles/generate/ast_creation.go b/magefiles/generate/ast_creation.go index 1c7d048..d40875b 100644 --- a/magefiles/generate/ast_creation.go +++ b/magefiles/generate/ast_creation.go @@ -1,102 +1,102 @@ -package generate - -import ( - "go/ast" - "go/token" -) - -func mkImportSpec(value string) *ast.ImportSpec { - return &ast.ImportSpec{Path: &ast.BasicLit{Kind: token.STRING, Value: value}} -} - -func mkImportStatements(importStatements ...string) ast.Decl { - specs := make([]ast.Spec, 0, len(importStatements)) - - for _, 
importStatement := range importStatements { - specs = append(specs, mkImportSpec(importStatement)) - } - - return &ast.GenDecl{ - Tok: token.IMPORT, - Specs: specs, - } -} - -func mkStarExpr(e ast.Expr) *ast.StarExpr { - return &ast.StarExpr{ - X: e, - } -} - -func mkSelectorExpr(x, sel string) *ast.SelectorExpr { - return &ast.SelectorExpr{X: ast.NewIdent(x), Sel: ast.NewIdent(sel)} -} - -func mkNamedField(name string, typ ast.Expr) *ast.Field { - return &ast.Field{ - Names: []*ast.Ident{ast.NewIdent(name)}, - Type: typ, - } -} - -func mkField(typ ast.Expr) *ast.Field { - return &ast.Field{ - Type: typ, - } -} - -// fun(arg1, arg2, ...) -func mkCallExpr(fun ast.Expr, args ...ast.Expr) *ast.CallExpr { - return &ast.CallExpr{ - Fun: fun, - Args: args, - } -} - -// Shorthand for creating &expr. -func mkAmpExpr(expr ast.Expr) *ast.UnaryExpr { - return &ast.UnaryExpr{ - Op: token.AND, - X: expr, - } -} - -// err != nil. -var errNotEqualNil = &ast.BinaryExpr{ - X: ast.NewIdent("err"), - Op: token.NEQ, - Y: ast.NewIdent("nil"), -} - -// return err. 
-var returnErr = &ast.ReturnStmt{ - Results: []ast.Expr{ast.NewIdent("err")}, -} - -func mkBlockStmt(stmts ...ast.Stmt) *ast.BlockStmt { - return &ast.BlockStmt{ - List: stmts, - } -} - -func mkIfStmt(init ast.Stmt, cond ast.Expr, body *ast.BlockStmt) *ast.IfStmt { - return &ast.IfStmt{ - Init: init, - Cond: cond, - Body: body, - } -} - -func mkAssignStmt(lhs, rhs []ast.Expr) *ast.AssignStmt { - return &ast.AssignStmt{ - Lhs: lhs, - Tok: token.DEFINE, - Rhs: rhs, - } -} - -func mkReturnStmt(results ...ast.Expr) *ast.ReturnStmt { - return &ast.ReturnStmt{ - Results: results, - } -} +package generate + +import ( + "go/ast" + "go/token" +) + +func mkImportSpec(value string) *ast.ImportSpec { + return &ast.ImportSpec{Path: &ast.BasicLit{Kind: token.STRING, Value: value}} +} + +func mkImportStatements(importStatements ...string) ast.Decl { + specs := make([]ast.Spec, 0, len(importStatements)) + + for _, importStatement := range importStatements { + specs = append(specs, mkImportSpec(importStatement)) + } + + return &ast.GenDecl{ + Tok: token.IMPORT, + Specs: specs, + } +} + +func mkStarExpr(e ast.Expr) *ast.StarExpr { + return &ast.StarExpr{ + X: e, + } +} + +func mkSelectorExpr(x, sel string) *ast.SelectorExpr { + return &ast.SelectorExpr{X: ast.NewIdent(x), Sel: ast.NewIdent(sel)} +} + +func mkNamedField(name string, typ ast.Expr) *ast.Field { + return &ast.Field{ + Names: []*ast.Ident{ast.NewIdent(name)}, + Type: typ, + } +} + +func mkField(typ ast.Expr) *ast.Field { + return &ast.Field{ + Type: typ, + } +} + +// fun(arg1, arg2, ...) +func mkCallExpr(fun ast.Expr, args ...ast.Expr) *ast.CallExpr { + return &ast.CallExpr{ + Fun: fun, + Args: args, + } +} + +// Shorthand for creating &expr. +func mkAmpExpr(expr ast.Expr) *ast.UnaryExpr { + return &ast.UnaryExpr{ + Op: token.AND, + X: expr, + } +} + +// err != nil. +var errNotEqualNil = &ast.BinaryExpr{ + X: ast.NewIdent("err"), + Op: token.NEQ, + Y: ast.NewIdent("nil"), +} + +// return err. 
+var returnErr = &ast.ReturnStmt{ + Results: []ast.Expr{ast.NewIdent("err")}, +} + +func mkBlockStmt(stmts ...ast.Stmt) *ast.BlockStmt { + return &ast.BlockStmt{ + List: stmts, + } +} + +func mkIfStmt(init ast.Stmt, cond ast.Expr, body *ast.BlockStmt) *ast.IfStmt { + return &ast.IfStmt{ + Init: init, + Cond: cond, + Body: body, + } +} + +func mkAssignStmt(lhs, rhs []ast.Expr) *ast.AssignStmt { + return &ast.AssignStmt{ + Lhs: lhs, + Tok: token.DEFINE, + Rhs: rhs, + } +} + +func mkReturnStmt(results ...ast.Expr) *ast.ReturnStmt { + return &ast.ReturnStmt{ + Results: results, + } +} diff --git a/magefiles/generate/discovery/discovery.go b/magefiles/generate/discovery/discovery.go index 18a10f2..237f272 100644 --- a/magefiles/generate/discovery/discovery.go +++ b/magefiles/generate/discovery/discovery.go @@ -1,97 +1,97 @@ -package discovery - -import ( - "fmt" - "regexp" - "strings" - - "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/reflect/protoreflect" - - "github.com/mlflow/mlflow-go/pkg/protos" - "github.com/mlflow/mlflow-go/pkg/protos/artifacts" -) - -type ServiceInfo struct { - Name string - Methods []MethodInfo -} - -type MethodInfo struct { - Name string - PackageName string - Input string - Output string - Endpoints []Endpoint -} - -type Endpoint struct { - Method string - Path string -} - -var routeParameterRegex = regexp.MustCompile(`<[^>]+:([^>]+)>`) - -// Get the safe path to use in Fiber registration. 
-func (e Endpoint) GetFiberPath() string { - // e.Path cannot be trusted, it could be something like /mlflow-artifacts/artifacts/ - // Which would need to converted to /mlflow-artifacts/artifacts/:path - path := routeParameterRegex.ReplaceAllStringFunc(e.Path, func(s string) string { - parts := strings.Split(s, ":") - - return ":" + strings.Trim(parts[0], "< ") - }) - - return path -} - -func GetServiceInfos() ([]ServiceInfo, error) { - serviceInfos := make([]ServiceInfo, 0) - - services := []struct { - Name string - PackageName string - Descriptor protoreflect.FileDescriptor - }{ - {"MlflowService", "protos", protos.File_service_proto}, - {"ModelRegistryService", "protos", protos.File_model_registry_proto}, - {"MlflowArtifactsService", "artifacts", artifacts.File_mlflow_artifacts_proto}, - } - - for _, service := range services { - serviceDescriptor := service.Descriptor.Services().ByName(protoreflect.Name(service.Name)) - - if serviceDescriptor == nil { - //nolint:goerr113 - return nil, fmt.Errorf("service %s not found", service.Name) - } - - serviceInfo := ServiceInfo{Name: service.Name, Methods: make([]MethodInfo, 0)} - - methods := serviceDescriptor.Methods() - for mIdx := range methods.Len() { - method := methods.Get(mIdx) - options := method.Options() - extension := proto.GetExtension(options, protos.E_Rpc) - - endpoints := make([]Endpoint, 0) - rpcOptions, ok := extension.(*protos.DatabricksRpcOptions) - - if ok { - for _, endpoint := range rpcOptions.GetEndpoints() { - endpoints = append(endpoints, Endpoint{Method: endpoint.GetMethod(), Path: endpoint.GetPath()}) - } - } - - output := fmt.Sprintf("%s_%s", string(method.Output().Parent().Name()), string(method.Output().Name())) - methodInfo := MethodInfo{ - string(method.Name()), service.PackageName, string(method.Input().Name()), output, endpoints, - } - serviceInfo.Methods = append(serviceInfo.Methods, methodInfo) - } - - serviceInfos = append(serviceInfos, serviceInfo) - } - - return serviceInfos, nil -} 
+package discovery + +import ( + "fmt" + "regexp" + "strings" + + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/protos/artifacts" +) + +type ServiceInfo struct { + Name string + Methods []MethodInfo +} + +type MethodInfo struct { + Name string + PackageName string + Input string + Output string + Endpoints []Endpoint +} + +type Endpoint struct { + Method string + Path string +} + +var routeParameterRegex = regexp.MustCompile(`<[^>]+:([^>]+)>`) + +// Get the safe path to use in Fiber registration. +func (e Endpoint) GetFiberPath() string { + // e.Path cannot be trusted, it could be something like /mlflow-artifacts/artifacts/ + // Which would need to converted to /mlflow-artifacts/artifacts/:path + path := routeParameterRegex.ReplaceAllStringFunc(e.Path, func(s string) string { + parts := strings.Split(s, ":") + + return ":" + strings.Trim(parts[0], "< ") + }) + + return path +} + +func GetServiceInfos() ([]ServiceInfo, error) { + serviceInfos := make([]ServiceInfo, 0) + + services := []struct { + Name string + PackageName string + Descriptor protoreflect.FileDescriptor + }{ + {"MlflowService", "protos", protos.File_service_proto}, + {"ModelRegistryService", "protos", protos.File_model_registry_proto}, + {"MlflowArtifactsService", "artifacts", artifacts.File_mlflow_artifacts_proto}, + } + + for _, service := range services { + serviceDescriptor := service.Descriptor.Services().ByName(protoreflect.Name(service.Name)) + + if serviceDescriptor == nil { + //nolint:goerr113 + return nil, fmt.Errorf("service %s not found", service.Name) + } + + serviceInfo := ServiceInfo{Name: service.Name, Methods: make([]MethodInfo, 0)} + + methods := serviceDescriptor.Methods() + for mIdx := range methods.Len() { + method := methods.Get(mIdx) + options := method.Options() + extension := proto.GetExtension(options, protos.E_Rpc) + + endpoints := make([]Endpoint, 0) + 
rpcOptions, ok := extension.(*protos.DatabricksRpcOptions) + + if ok { + for _, endpoint := range rpcOptions.GetEndpoints() { + endpoints = append(endpoints, Endpoint{Method: endpoint.GetMethod(), Path: endpoint.GetPath()}) + } + } + + output := fmt.Sprintf("%s_%s", string(method.Output().Parent().Name()), string(method.Output().Name())) + methodInfo := MethodInfo{ + string(method.Name()), service.PackageName, string(method.Input().Name()), output, endpoints, + } + serviceInfo.Methods = append(serviceInfo.Methods, methodInfo) + } + + serviceInfos = append(serviceInfos, serviceInfo) + } + + return serviceInfos, nil +} diff --git a/magefiles/generate/discovery/discovery_test.go b/magefiles/generate/discovery/discovery_test.go index 447f7b7..f4bdeaf 100644 --- a/magefiles/generate/discovery/discovery_test.go +++ b/magefiles/generate/discovery/discovery_test.go @@ -1,55 +1,55 @@ -package discovery_test - -import ( - "testing" - - "github.com/mlflow/mlflow-go/magefiles/generate/discovery" -) - -func TestPattern(t *testing.T) { - t.Parallel() - - scenarios := []struct { - name string - endpoint discovery.Endpoint - expected string - }{ - { - name: "simple GET", - endpoint: discovery.Endpoint{ - Method: "GET", - Path: "/mlflow/experiments/get-by-name", - }, - expected: "/mlflow/experiments/get-by-name", - }, - { - name: "simple POST", - endpoint: discovery.Endpoint{ - Method: "POST", - Path: "/mlflow/experiments/create", - }, - expected: "/mlflow/experiments/create", - }, - { - name: "PUT with route parameter", - endpoint: discovery.Endpoint{ - Method: "PUT", - Path: "/mlflow-artifacts/artifacts/", - }, - expected: "/mlflow-artifacts/artifacts/:path", - }, - } - - for _, scenario := range scenarios { - currentScenario := scenario - t.Run(currentScenario.name, func(t *testing.T) { - t.Parallel() - - actual := currentScenario.endpoint.GetFiberPath() - - if actual != currentScenario.expected { - t.Errorf("Expected %s, got %s", currentScenario.expected, actual) - } - }) - } 
-} +package discovery_test + +import ( + "testing" + + "github.com/mlflow/mlflow-go/magefiles/generate/discovery" +) + +func TestPattern(t *testing.T) { + t.Parallel() + + scenarios := []struct { + name string + endpoint discovery.Endpoint + expected string + }{ + { + name: "simple GET", + endpoint: discovery.Endpoint{ + Method: "GET", + Path: "/mlflow/experiments/get-by-name", + }, + expected: "/mlflow/experiments/get-by-name", + }, + { + name: "simple POST", + endpoint: discovery.Endpoint{ + Method: "POST", + Path: "/mlflow/experiments/create", + }, + expected: "/mlflow/experiments/create", + }, + { + name: "PUT with route parameter", + endpoint: discovery.Endpoint{ + Method: "PUT", + Path: "/mlflow-artifacts/artifacts/", + }, + expected: "/mlflow-artifacts/artifacts/:path", + }, + } + + for _, scenario := range scenarios { + currentScenario := scenario + t.Run(currentScenario.name, func(t *testing.T) { + t.Parallel() + + actual := currentScenario.endpoint.GetFiberPath() + + if actual != currentScenario.expected { + t.Errorf("Expected %s, got %s", currentScenario.expected, actual) + } + }) + } +} diff --git a/magefiles/generate/endpoints.go b/magefiles/generate/endpoints.go index 43fb65f..63b22dc 100644 --- a/magefiles/generate/endpoints.go +++ b/magefiles/generate/endpoints.go @@ -1,87 +1,87 @@ -package generate - -type ServiceGenerationInfo struct { - FileNameWithoutExtension string - ServiceName string - ImplementedEndpoints []string -} - -var ServiceInfoMap = map[string]ServiceGenerationInfo{ - "MlflowService": { - FileNameWithoutExtension: "tracking", - ServiceName: "TrackingService", - ImplementedEndpoints: []string{ - "getExperimentByName", - "createExperiment", - // "searchExperiments", - "getExperiment", - "deleteExperiment", - "restoreExperiment", - "updateExperiment", - "getRun", - "createRun", - "updateRun", - "deleteRun", - "restoreRun", - "logMetric", - // "logParam", - // "setExperimentTag", - // "setTag", - // "setTraceTag", - // "deleteTraceTag", 
- // "deleteTag", - "searchRuns", - // "listArtifacts", - // "getMetricHistory", - // "getMetricHistoryBulkInterval", - "logBatch", - // "logModel", - // "logInputs", - // "startTrace", - // "endTrace", - // "getTraceInfo", - // "searchTraces", - // "deleteTraces", - }, - }, - "ModelRegistryService": { - FileNameWithoutExtension: "model_registry", - ServiceName: "ModelRegistryService", - ImplementedEndpoints: []string{ - // "createRegisteredModel", - // "renameRegisteredModel", - // "updateRegisteredModel", - // "deleteRegisteredModel", - // "getRegisteredModel", - // "searchRegisteredModels", - "getLatestVersions", - // "createModelVersion", - // "updateModelVersion", - // "transitionModelVersionStage", - // "deleteModelVersion", - // "getModelVersion", - // "searchModelVersions", - // "getModelVersionDownloadUri", - // "setRegisteredModelTag", - // "setModelVersionTag", - // "deleteRegisteredModelTag", - // "deleteModelVersionTag", - // "setRegisteredModelAlias", - // "deleteRegisteredModelAlias", - // "getModelVersionByAlias", - }, - }, - "MlflowArtifactsService": { - FileNameWithoutExtension: "artifacts", - ServiceName: "ArtifactsService", - ImplementedEndpoints: []string{ - // "downloadArtifact", - // "uploadArtifact", - // "listArtifacts", - // "deleteArtifact", - // "createMultipartUpload", - // "completeMultipartUpload", - // "abortMultipartUpload", - }, - }, -} +package generate + +type ServiceGenerationInfo struct { + FileNameWithoutExtension string + ServiceName string + ImplementedEndpoints []string +} + +var ServiceInfoMap = map[string]ServiceGenerationInfo{ + "MlflowService": { + FileNameWithoutExtension: "tracking", + ServiceName: "TrackingService", + ImplementedEndpoints: []string{ + "getExperimentByName", + "createExperiment", + // "searchExperiments", + "getExperiment", + "deleteExperiment", + "restoreExperiment", + "updateExperiment", + "getRun", + "createRun", + "updateRun", + "deleteRun", + "restoreRun", + "logMetric", + // "logParam", + // 
"setExperimentTag", + "setTag", + // "setTraceTag", + // "deleteTraceTag", + "deleteTag", + "searchRuns", + // "listArtifacts", + // "getMetricHistory", + // "getMetricHistoryBulkInterval", + "logBatch", + // "logModel", + // "logInputs", + // "startTrace", + // "endTrace", + // "getTraceInfo", + // "searchTraces", + // "deleteTraces", + }, + }, + "ModelRegistryService": { + FileNameWithoutExtension: "model_registry", + ServiceName: "ModelRegistryService", + ImplementedEndpoints: []string{ + // "createRegisteredModel", + // "renameRegisteredModel", + // "updateRegisteredModel", + // "deleteRegisteredModel", + // "getRegisteredModel", + // "searchRegisteredModels", + "getLatestVersions", + // "createModelVersion", + // "updateModelVersion", + // "transitionModelVersionStage", + // "deleteModelVersion", + // "getModelVersion", + // "searchModelVersions", + // "getModelVersionDownloadUri", + // "setRegisteredModelTag", + // "setModelVersionTag", + // "deleteRegisteredModelTag", + // "deleteModelVersionTag", + // "setRegisteredModelAlias", + // "deleteRegisteredModelAlias", + // "getModelVersionByAlias", + }, + }, + "MlflowArtifactsService": { + FileNameWithoutExtension: "artifacts", + ServiceName: "ArtifactsService", + ImplementedEndpoints: []string{ + // "downloadArtifact", + // "uploadArtifact", + // "listArtifacts", + // "deleteArtifact", + // "createMultipartUpload", + // "completeMultipartUpload", + // "abortMultipartUpload", + }, + }, +} diff --git a/magefiles/generate/protos.go b/magefiles/generate/protos.go index 5586e36..f26d4a3 100644 --- a/magefiles/generate/protos.go +++ b/magefiles/generate/protos.go @@ -1,58 +1,58 @@ -package generate - -import ( - "fmt" - "os/exec" - "path" - "strings" -) - -const MLFlowCommit = "3effa7380c86946f4557f03aa81119a097d8b433" - -var protoFiles = map[string]string{ - "databricks.proto": "github.com/mlflow/mlflow-go/pkg/protos", - "service.proto": "github.com/mlflow/mlflow-go/pkg/protos", - "model_registry.proto": 
"github.com/mlflow/mlflow-go/pkg/protos", - "databricks_artifacts.proto": "github.com/mlflow/mlflow-go/pkg/protos", - "mlflow_artifacts.proto": "github.com/mlflow/mlflow-go/pkg/protos/artifacts", - "internal.proto": "github.com/mlflow/mlflow-go/pkg/protos", - "scalapb/scalapb.proto": "github.com/mlflow/mlflow-go/pkg/protos/scalapb", -} - -const fixedArguments = 3 - -func RunProtoc(protoDir string) error { - arguments := make([]string, 0, len(protoFiles)*2+fixedArguments) - - arguments = append( - arguments, - "-I="+protoDir, - `--go_out=.`, - `--go_opt=module=github.com/mlflow/mlflow-go`, - ) - - for fileName, goPackage := range protoFiles { - arguments = append( - arguments, - fmt.Sprintf("--go_opt=M%s=%s", fileName, goPackage), - ) - } - - for fileName := range protoFiles { - arguments = append(arguments, path.Join(protoDir, fileName)) - } - - cmd := exec.Command("protoc", arguments...) - - output, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf( - "failed to run protoc %s process, err: %s: %w", - strings.Join(arguments, " "), - output, - err, - ) - } - - return nil -} +package generate + +import ( + "fmt" + "os/exec" + "path" + "strings" +) + +const MLFlowCommit = "3effa7380c86946f4557f03aa81119a097d8b433" + +var protoFiles = map[string]string{ + "databricks.proto": "github.com/mlflow/mlflow-go/pkg/protos", + "service.proto": "github.com/mlflow/mlflow-go/pkg/protos", + "model_registry.proto": "github.com/mlflow/mlflow-go/pkg/protos", + "databricks_artifacts.proto": "github.com/mlflow/mlflow-go/pkg/protos", + "mlflow_artifacts.proto": "github.com/mlflow/mlflow-go/pkg/protos/artifacts", + "internal.proto": "github.com/mlflow/mlflow-go/pkg/protos", + "scalapb/scalapb.proto": "github.com/mlflow/mlflow-go/pkg/protos/scalapb", +} + +const fixedArguments = 3 + +func RunProtoc(protoDir string) error { + arguments := make([]string, 0, len(protoFiles)*2+fixedArguments) + + arguments = append( + arguments, + "-I="+protoDir, + `--go_out=.`, + 
`--go_opt=module=github.com/mlflow/mlflow-go`, + ) + + for fileName, goPackage := range protoFiles { + arguments = append( + arguments, + fmt.Sprintf("--go_opt=M%s=%s", fileName, goPackage), + ) + } + + for fileName := range protoFiles { + arguments = append(arguments, path.Join(protoDir, fileName)) + } + + cmd := exec.Command("protoc", arguments...) + + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf( + "failed to run protoc %s process, err: %s: %w", + strings.Join(arguments, " "), + output, + err, + ) + } + + return nil +} diff --git a/magefiles/generate/query_annotations.go b/magefiles/generate/query_annotations.go index c661e9b..a7c0b41 100644 --- a/magefiles/generate/query_annotations.go +++ b/magefiles/generate/query_annotations.go @@ -1,116 +1,116 @@ -package generate - -import ( - "fmt" - "go/ast" - "go/parser" - "go/token" - "io/fs" - "os" - "path/filepath" - "regexp" - "strings" -) - -// Inspect the AST of the incoming file and add a query annotation to the struct tags which have a json tag. 
-// -//nolint:funlen,cyclop -func addQueryAnnotation(generatedGoFile string) error { - // Parse the file into an AST - fset := token.NewFileSet() - - node, err := parser.ParseFile(fset, generatedGoFile, nil, parser.ParseComments) - if err != nil { - return fmt.Errorf("add query annotation failed: %w", err) - } - - // Create an AST inspector to modify specific struct tags - ast.Inspect(node, func(n ast.Node) bool { - // Look for struct type declarations - typeSpec, isTypeSpec := n.(*ast.TypeSpec) - if !isTypeSpec { - return true - } - - structType, isStructType := typeSpec.Type.(*ast.StructType) - - if !isStructType { - return true - } - - // Iterate over fields in the struct - for _, field := range structType.Fields.List { - if field.Tag == nil { - continue - } - - tagValue := field.Tag.Value - - hasQuery := strings.Contains(tagValue, "query:") - hasValidate := strings.Contains(tagValue, "validate:") - validationKey := fmt.Sprintf("%s_%s", typeSpec.Name, field.Names[0]) - validationRule, needsValidation := validations[validationKey] - - if hasQuery && (!needsValidation || needsValidation && hasValidate) { - continue - } - - // With opening ` tick - newTag := tagValue[0 : len(tagValue)-1] - - matches := jsonFieldTagRegexp.FindStringSubmatch(tagValue) - if len(matches) > 0 && !hasQuery { - // Modify the tag here - // The json annotation could be something like `json:"key,omitempty"` - // We only want the key part, the `omitempty` is not relevant for the query annotation - key := matches[1] - if strings.Contains(key, ",") { - key = strings.Split(key, ",")[0] - } - // Add query annotation - newTag += fmt.Sprintf(" query:\"%s\"", key) - } - - if needsValidation { - // Add validation annotation - newTag += fmt.Sprintf(" validate:\"%s\"", validationRule) - } - - // Closing ` tick - newTag += "`" - field.Tag.Value = newTag - } - - return false - }) - - return saveASTToFile(fset, node, false, generatedGoFile) -} - -var jsonFieldTagRegexp = 
regexp.MustCompile(`json:"([^"]+)"`) - -//nolint:err113 -func AddQueryAnnotations(pkgFolder string) error { - protoFolder := filepath.Join(pkgFolder, "protos") - - if _, pathError := os.Stat(protoFolder); os.IsNotExist(pathError) { - return fmt.Errorf("the %s folder does not exist. Are the Go protobuf files generated?", protoFolder) - } - - err := filepath.WalkDir(protoFolder, func(path string, _ fs.DirEntry, err error) error { - if err != nil { - return err - } - - if filepath.Ext(path) == ".go" { - err = addQueryAnnotation(path) - } - - return err - }) - if err != nil { - return fmt.Errorf("failed to add query annotation: %w", err) - } - - return nil -} +package generate + +import ( + "fmt" + "go/ast" + "go/parser" + "go/token" + "io/fs" + "os" + "path/filepath" + "regexp" + "strings" +) + +// Inspect the AST of the incoming file and add a query annotation to the struct tags which have a json tag. +// +//nolint:funlen,cyclop +func addQueryAnnotation(generatedGoFile string) error { + // Parse the file into an AST + fset := token.NewFileSet() + + node, err := parser.ParseFile(fset, generatedGoFile, nil, parser.ParseComments) + if err != nil { + return fmt.Errorf("add query annotation failed: %w", err) + } + + // Create an AST inspector to modify specific struct tags + ast.Inspect(node, func(n ast.Node) bool { + // Look for struct type declarations + typeSpec, isTypeSpec := n.(*ast.TypeSpec) + if !isTypeSpec { + return true + } + + structType, isStructType := typeSpec.Type.(*ast.StructType) + + if !isStructType { + return true + } + + // Iterate over fields in the struct + for _, field := range structType.Fields.List { + if field.Tag == nil { + continue + } + + tagValue := field.Tag.Value + + hasQuery := strings.Contains(tagValue, "query:") + hasValidate := strings.Contains(tagValue, "validate:") + validationKey := fmt.Sprintf("%s_%s", typeSpec.Name, field.Names[0]) + validationRule, needsValidation := validations[validationKey] + + if hasQuery && (!needsValidation 
|| needsValidation && hasValidate) { + continue + } + + // With opening ` tick + newTag := tagValue[0 : len(tagValue)-1] + + matches := jsonFieldTagRegexp.FindStringSubmatch(tagValue) + if len(matches) > 0 && !hasQuery { + // Modify the tag here + // The json annotation could be something like `json:"key,omitempty"` + // We only want the key part, the `omitempty` is not relevant for the query annotation + key := matches[1] + if strings.Contains(key, ",") { + key = strings.Split(key, ",")[0] + } + // Add query annotation + newTag += fmt.Sprintf(" query:\"%s\"", key) + } + + if needsValidation { + // Add validation annotation + newTag += fmt.Sprintf(" validate:\"%s\"", validationRule) + } + + // Closing ` tick + newTag += "`" + field.Tag.Value = newTag + } + + return false + }) + + return saveASTToFile(fset, node, false, generatedGoFile) +} + +var jsonFieldTagRegexp = regexp.MustCompile(`json:"([^"]+)"`) + +//nolint:err113 +func AddQueryAnnotations(pkgFolder string) error { + protoFolder := filepath.Join(pkgFolder, "protos") + + if _, pathError := os.Stat(protoFolder); os.IsNotExist(pathError) { + return fmt.Errorf("the %s folder does not exist. 
Are the Go protobuf files generated?", protoFolder) + } + + err := filepath.WalkDir(protoFolder, func(path string, _ fs.DirEntry, err error) error { + if err != nil { + return err + } + + if filepath.Ext(path) == ".go" { + err = addQueryAnnotation(path) + } + + return err + }) + if err != nil { + return fmt.Errorf("failed to add query annotation: %w", err) + } + + return nil +} diff --git a/magefiles/generate/source_code.go b/magefiles/generate/source_code.go index 7146d2c..0f15325 100644 --- a/magefiles/generate/source_code.go +++ b/magefiles/generate/source_code.go @@ -1,468 +1,468 @@ -package generate - -import ( - "bufio" - "fmt" - "go/ast" - "go/format" - "go/token" - "net/http" - "os" - "path/filepath" - - "github.com/iancoleman/strcase" - - "github.com/mlflow/mlflow-go/magefiles/generate/discovery" -) - -func mkMethodInfoInputPointerType(methodInfo discovery.MethodInfo) *ast.StarExpr { - return mkStarExpr(mkSelectorExpr(methodInfo.PackageName, methodInfo.Input)) -} - -// Generate a method declaration on an service interface. -func mkServiceInterfaceMethod(methodInfo discovery.MethodInfo) *ast.Field { - return &ast.Field{ - Names: []*ast.Ident{ast.NewIdent(strcase.ToCamel(methodInfo.Name))}, - Type: &ast.FuncType{ - Params: &ast.FieldList{ - List: []*ast.Field{ - mkNamedField("ctx", mkSelectorExpr("context", "Context")), - mkNamedField("input", mkMethodInfoInputPointerType(methodInfo)), - }, - }, - Results: &ast.FieldList{ - List: []*ast.Field{ - mkField(mkStarExpr(mkSelectorExpr(methodInfo.PackageName, methodInfo.Output))), - mkField(mkStarExpr(mkSelectorExpr("contract", "Error"))), - }, - }, - }, - } -} - -// Generate a service interface declaration. 
-func mkServiceInterfaceNode( - endpoints map[string]any, interfaceName string, serviceInfo discovery.ServiceInfo, -) *ast.GenDecl { - // We add one method to validate any of the input structs - methods := make([]*ast.Field, 0, len(serviceInfo.Methods)) - - for _, method := range serviceInfo.Methods { - if _, ok := endpoints[method.Name]; ok { - methods = append(methods, mkServiceInterfaceMethod(method)) - } - } - - // Create an interface declaration - return &ast.GenDecl{ - Tok: token.TYPE, // Specifies a type declaration - Specs: []ast.Spec{ - &ast.TypeSpec{ - Name: ast.NewIdent(interfaceName), - Type: &ast.InterfaceType{ - Methods: &ast.FieldList{ - List: methods, - }, - }, - }, - }, - } -} - -func saveASTToFile(fset *token.FileSet, file *ast.File, addComment bool, outputPath string) error { - // Create or truncate the output file - outputFile, err := os.Create(outputPath) - if err != nil { - return fmt.Errorf("failed to create output file: %w", err) - } - defer outputFile.Close() - - // Use a bufio.Writer for buffered writing (optional) - writer := bufio.NewWriter(outputFile) - defer writer.Flush() - - if addComment { - _, err := writer.WriteString("// Code generated by mlflow/go/cmd/generate/main.go. 
DO NOT EDIT.\n\n") - if err != nil { - return fmt.Errorf("failed to add comment to generated file: %w", err) - } - } - - // Write the generated code to the file - err = format.Node(writer, fset, file) - if err != nil { - return fmt.Errorf("failed to write generated AST to file: %w", err) - } - - return nil -} - -//nolint:funlen -func mkAppRoute(method discovery.MethodInfo, endpoint discovery.Endpoint) ast.Stmt { - urlExpr := &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf(`"%s"`, endpoint.GetFiberPath())} - - // input := &protos.SearchExperiments - inputExpr := mkAssignStmt( - []ast.Expr{ast.NewIdent("input")}, - []ast.Expr{ - mkAmpExpr(&ast.CompositeLit{ - Type: mkSelectorExpr(method.PackageName, method.Input), - }), - }) - - // if err := parser.ParseQuery(ctx, input); err != nil { return err } - // if err := parser.ParseBody(ctx, input); err != nil { return err } - var extractModel ast.Expr - if endpoint.Method == http.MethodGet { - extractModel = mkCallExpr(mkSelectorExpr("parser", "ParseQuery"), ast.NewIdent("ctx"), ast.NewIdent("input")) - } else { - extractModel = mkCallExpr(mkSelectorExpr("parser", "ParseBody"), ast.NewIdent("ctx"), ast.NewIdent("input")) - } - - inputErrorCheck := mkIfStmt( - mkAssignStmt([]ast.Expr{ast.NewIdent("err")}, []ast.Expr{extractModel}), - errNotEqualNil, - mkBlockStmt(returnErr), - ) - - // output, err := service.Method(input) - outputExpr := mkAssignStmt([]ast.Expr{ - ast.NewIdent("output"), - ast.NewIdent("err"), - }, []ast.Expr{ - mkCallExpr( - mkSelectorExpr( - "service", - strcase.ToCamel(method.Name), - ), - mkCallExpr( - mkSelectorExpr("utils", "NewContextWithLoggerFromFiberContext"), - ast.NewIdent("ctx"), - ), - ast.NewIdent("input"), - ), - }) - - // if err != nil { - // return err - // } - errorCheck := mkIfStmt( - nil, - errNotEqualNil, - mkBlockStmt( - mkReturnStmt(ast.NewIdent("err")), - ), - ) - - // return ctx.JSON(output) - returnExpr := mkReturnStmt(mkCallExpr(mkSelectorExpr("ctx", "JSON"), 
ast.NewIdent("output"))) - - // func(ctx *fiber.Ctx) error { .. } - funcExpr := &ast.FuncLit{ - Type: &ast.FuncType{ - Params: &ast.FieldList{ - List: []*ast.Field{ - mkNamedField("ctx", mkStarExpr(mkSelectorExpr("fiber", "Ctx"))), - }, - }, - Results: &ast.FieldList{ - List: []*ast.Field{ - mkField(ast.NewIdent("error")), - }, - }, - }, - Body: &ast.BlockStmt{ - List: []ast.Stmt{ - inputExpr, - inputErrorCheck, - outputExpr, - errorCheck, - returnExpr, - }, - }, - } - - return &ast.ExprStmt{ - // app.Get("/mlflow/experiments/search", func(ctx *fiber.Ctx) error { .. }) - X: mkCallExpr( - mkSelectorExpr("app", strcase.ToCamel(endpoint.Method)), urlExpr, funcExpr, - ), - } -} - -func mkRouteRegistrationFunction( - endpoints map[string]any, interfaceName string, serviceInfo discovery.ServiceInfo, -) *ast.FuncDecl { - routes := make([]ast.Stmt, 0, len(serviceInfo.Methods)) - - for _, method := range serviceInfo.Methods { - for _, endpoint := range method.Endpoints { - if _, ok := endpoints[method.Name]; ok { - routes = append(routes, mkAppRoute(method, endpoint)) - } - } - } - - return &ast.FuncDecl{ - Name: ast.NewIdent(fmt.Sprintf("Register%sRoutes", interfaceName)), - Type: &ast.FuncType{ - Params: &ast.FieldList{ - List: []*ast.Field{ - mkNamedField("service", mkSelectorExpr("service", interfaceName)), - mkNamedField("parser", mkStarExpr(mkSelectorExpr("parser", "HTTPRequestParser"))), - mkNamedField("app", mkStarExpr(ast.NewIdent("fiber.App"))), - }, - }, - }, - Body: &ast.BlockStmt{ - List: routes, - }, - } -} - -func mkGeneratedFile(pkg, outputPath string, decls []ast.Decl) error { - // Set up the FileSet and the AST File - fset := token.NewFileSet() - - file := &ast.File{ - Name: ast.NewIdent(pkg), - Decls: decls, - } - - err := saveASTToFile(fset, file, true, outputPath) - if err != nil { - return fmt.Errorf("failed to save AST to file: %w", err) - } - - return nil -} - -const expectedImportStatements = 2 - -// Generate the service interface. 
-func generateServices( - pkgFolder string, - serviceInfo discovery.ServiceInfo, - generationInfo ServiceGenerationInfo, - endpoints map[string]any, -) error { - decls := make([]ast.Decl, 0, len(endpoints)+expectedImportStatements) - - if len(endpoints) > 0 { - decls = append(decls, - mkImportStatements( - `"context"`, - `"github.com/mlflow/mlflow-go/pkg/protos"`, - `"github.com/mlflow/mlflow-go/pkg/contract"`, - )) - } - - decls = append(decls, mkServiceInterfaceNode( - endpoints, - generationInfo.ServiceName, - serviceInfo, - )) - - fileName := generationInfo.FileNameWithoutExtension + ".g.go" - pkg := "service" - outputPath := filepath.Join(pkgFolder, "contract", pkg, fileName) - - return mkGeneratedFile(pkg, outputPath, decls) -} - -func generateRouteRegistrations( - pkgFolder string, - serviceInfo discovery.ServiceInfo, - generationInfo ServiceGenerationInfo, - endpoints map[string]any, -) error { - importStatements := []string{ - `"github.com/gofiber/fiber/v2"`, - `"github.com/mlflow/mlflow-go/pkg/server/parser"`, - `"github.com/mlflow/mlflow-go/pkg/contract/service"`, - } - - if len(endpoints) > 0 { - importStatements = append( - importStatements, - `"github.com/mlflow/mlflow-go/pkg/utils"`, - `"github.com/mlflow/mlflow-go/pkg/protos"`, - ) - } - - decls := []ast.Decl{ - mkImportStatements(importStatements...), - mkRouteRegistrationFunction(endpoints, generationInfo.ServiceName, serviceInfo), - } - - fileName := generationInfo.FileNameWithoutExtension + ".g.go" - pkg := "routes" - outputPath := filepath.Join(pkgFolder, "server", pkg, fileName) - - return mkGeneratedFile(pkg, outputPath, decls) -} - -func mkCEndpointBody(serviceName string, method discovery.MethodInfo) *ast.BlockStmt { - mapName := strcase.ToLowerCamel(serviceName) + "s" - - return &ast.BlockStmt{ - List: []ast.Stmt{ - // service, err := trackingServices.Get(serviceID) - mkAssignStmt( - []ast.Expr{ - ast.NewIdent("service"), - ast.NewIdent("err"), - }, - []ast.Expr{ - 
mkCallExpr(mkSelectorExpr(mapName, "Get"), ast.NewIdent("serviceID")), - }, - ), - // if err != nil { - // return makePointerFromError(err, responseSize) - // } - mkIfStmt( - nil, - errNotEqualNil, - mkBlockStmt( - mkReturnStmt( - mkCallExpr( - ast.NewIdent("makePointerFromError"), - ast.NewIdent("err"), - ast.NewIdent("responseSize"), - ), - ), - ), - ), - // return invokeServiceMethod( - // service.GetExperiment, - // new(protos.GetExperiment), - // requestData, - // requestSize, - // responseSize, - // ) - mkReturnStmt( - mkCallExpr( - ast.NewIdent("invokeServiceMethod"), - mkSelectorExpr("service", strcase.ToCamel(method.Name)), - mkCallExpr(ast.NewIdent("new"), mkSelectorExpr("protos", method.Input)), - ast.NewIdent("requestData"), - ast.NewIdent("requestSize"), - ast.NewIdent("responseSize"), - ), - ), - }, - } -} - -func mkCEndpoint(serviceName string, method discovery.MethodInfo) *ast.FuncDecl { - functionName := fmt.Sprintf("%s%s", serviceName, strcase.ToCamel(method.Name)) - - return &ast.FuncDecl{ - Doc: &ast.CommentGroup{ - List: []*ast.Comment{ - { - Text: "//export " + functionName, - }, - }, - }, - Name: ast.NewIdent(functionName), - Type: &ast.FuncType{ - Params: &ast.FieldList{ - List: []*ast.Field{ - mkNamedField("serviceID", ast.NewIdent("int64")), - mkNamedField("requestData", mkSelectorExpr("unsafe", "Pointer")), - mkNamedField("requestSize", mkSelectorExpr("C", "int")), - mkNamedField("responseSize", mkStarExpr(mkSelectorExpr("C", "int"))), - }, - }, - Results: &ast.FieldList{ - List: []*ast.Field{ - mkField(mkSelectorExpr("unsafe", "Pointer")), - }, - }, - }, - Body: mkCEndpointBody(serviceName, method), - } -} - -func mkCEndpoints( - endpoints map[string]any, serviceName string, serviceInfo discovery.ServiceInfo, -) []*ast.FuncDecl { - funcs := make([]*ast.FuncDecl, 0, len(endpoints)) - - for _, method := range serviceInfo.Methods { - if _, ok := endpoints[method.Name]; ok { - funcs = append(funcs, mkCEndpoint(serviceName, method)) - } - } - 
- return funcs -} - -func generateEndpoints( - pkgFolder string, - serviceInfo discovery.ServiceInfo, - generationInfo ServiceGenerationInfo, - endpoints map[string]any, -) error { - decls := []ast.Decl{ - mkImportStatements(`"C"`), - } - - if len(endpoints) > 0 { - decls = append( - decls, - mkImportStatements( - `"unsafe"`, - `"github.com/mlflow/mlflow-go/pkg/protos"`, - ), - ) - - endpoints := mkCEndpoints(endpoints, generationInfo.ServiceName, serviceInfo) - for _, endpoint := range endpoints { - decls = append(decls, endpoint) - } - } - - fileName := generationInfo.FileNameWithoutExtension + ".g.go" - outputPath := filepath.Join(pkgFolder, "lib", fileName) - - return mkGeneratedFile("main", outputPath, decls) -} - -func SourceCode(pkgFolder string) error { - services, err := discovery.GetServiceInfos() - if err != nil { - return fmt.Errorf("failed to get service info: %w", err) - } - - for _, serviceInfo := range services { - generationInfo, ok := ServiceInfoMap[serviceInfo.Name] - if !ok { - continue - } - - endpoints := make(map[string]any, len(generationInfo.ImplementedEndpoints)) - - for _, endpoint := range generationInfo.ImplementedEndpoints { - endpoints[endpoint] = nil - } - - err = generateServices(pkgFolder, serviceInfo, generationInfo, endpoints) - if err != nil { - return err - } - - err = generateRouteRegistrations(pkgFolder, serviceInfo, generationInfo, endpoints) - if err != nil { - return err - } - - err = generateEndpoints(pkgFolder, serviceInfo, generationInfo, endpoints) - if err != nil { - return err - } - } - - return nil -} +package generate + +import ( + "bufio" + "fmt" + "go/ast" + "go/format" + "go/token" + "net/http" + "os" + "path/filepath" + + "github.com/iancoleman/strcase" + + "github.com/mlflow/mlflow-go/magefiles/generate/discovery" +) + +func mkMethodInfoInputPointerType(methodInfo discovery.MethodInfo) *ast.StarExpr { + return mkStarExpr(mkSelectorExpr(methodInfo.PackageName, methodInfo.Input)) +} + +// Generate a method 
declaration on an service interface. +func mkServiceInterfaceMethod(methodInfo discovery.MethodInfo) *ast.Field { + return &ast.Field{ + Names: []*ast.Ident{ast.NewIdent(strcase.ToCamel(methodInfo.Name))}, + Type: &ast.FuncType{ + Params: &ast.FieldList{ + List: []*ast.Field{ + mkNamedField("ctx", mkSelectorExpr("context", "Context")), + mkNamedField("input", mkMethodInfoInputPointerType(methodInfo)), + }, + }, + Results: &ast.FieldList{ + List: []*ast.Field{ + mkField(mkStarExpr(mkSelectorExpr(methodInfo.PackageName, methodInfo.Output))), + mkField(mkStarExpr(mkSelectorExpr("contract", "Error"))), + }, + }, + }, + } +} + +// Generate a service interface declaration. +func mkServiceInterfaceNode( + endpoints map[string]any, interfaceName string, serviceInfo discovery.ServiceInfo, +) *ast.GenDecl { + // We add one method to validate any of the input structs + methods := make([]*ast.Field, 0, len(serviceInfo.Methods)) + + for _, method := range serviceInfo.Methods { + if _, ok := endpoints[method.Name]; ok { + methods = append(methods, mkServiceInterfaceMethod(method)) + } + } + + // Create an interface declaration + return &ast.GenDecl{ + Tok: token.TYPE, // Specifies a type declaration + Specs: []ast.Spec{ + &ast.TypeSpec{ + Name: ast.NewIdent(interfaceName), + Type: &ast.InterfaceType{ + Methods: &ast.FieldList{ + List: methods, + }, + }, + }, + }, + } +} + +func saveASTToFile(fset *token.FileSet, file *ast.File, addComment bool, outputPath string) error { + // Create or truncate the output file + outputFile, err := os.Create(outputPath) + if err != nil { + return fmt.Errorf("failed to create output file: %w", err) + } + defer outputFile.Close() + + // Use a bufio.Writer for buffered writing (optional) + writer := bufio.NewWriter(outputFile) + defer writer.Flush() + + if addComment { + _, err := writer.WriteString("// Code generated by mlflow/go/cmd/generate/main.go. 
DO NOT EDIT.\n\n") + if err != nil { + return fmt.Errorf("failed to add comment to generated file: %w", err) + } + } + + // Write the generated code to the file + err = format.Node(writer, fset, file) + if err != nil { + return fmt.Errorf("failed to write generated AST to file: %w", err) + } + + return nil +} + +//nolint:funlen +func mkAppRoute(method discovery.MethodInfo, endpoint discovery.Endpoint) ast.Stmt { + urlExpr := &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf(`"%s"`, endpoint.GetFiberPath())} + + // input := &protos.SearchExperiments + inputExpr := mkAssignStmt( + []ast.Expr{ast.NewIdent("input")}, + []ast.Expr{ + mkAmpExpr(&ast.CompositeLit{ + Type: mkSelectorExpr(method.PackageName, method.Input), + }), + }) + + // if err := parser.ParseQuery(ctx, input); err != nil { return err } + // if err := parser.ParseBody(ctx, input); err != nil { return err } + var extractModel ast.Expr + if endpoint.Method == http.MethodGet { + extractModel = mkCallExpr(mkSelectorExpr("parser", "ParseQuery"), ast.NewIdent("ctx"), ast.NewIdent("input")) + } else { + extractModel = mkCallExpr(mkSelectorExpr("parser", "ParseBody"), ast.NewIdent("ctx"), ast.NewIdent("input")) + } + + inputErrorCheck := mkIfStmt( + mkAssignStmt([]ast.Expr{ast.NewIdent("err")}, []ast.Expr{extractModel}), + errNotEqualNil, + mkBlockStmt(returnErr), + ) + + // output, err := service.Method(input) + outputExpr := mkAssignStmt([]ast.Expr{ + ast.NewIdent("output"), + ast.NewIdent("err"), + }, []ast.Expr{ + mkCallExpr( + mkSelectorExpr( + "service", + strcase.ToCamel(method.Name), + ), + mkCallExpr( + mkSelectorExpr("utils", "NewContextWithLoggerFromFiberContext"), + ast.NewIdent("ctx"), + ), + ast.NewIdent("input"), + ), + }) + + // if err != nil { + // return err + // } + errorCheck := mkIfStmt( + nil, + errNotEqualNil, + mkBlockStmt( + mkReturnStmt(ast.NewIdent("err")), + ), + ) + + // return ctx.JSON(output) + returnExpr := mkReturnStmt(mkCallExpr(mkSelectorExpr("ctx", "JSON"), 
ast.NewIdent("output"))) + + // func(ctx *fiber.Ctx) error { .. } + funcExpr := &ast.FuncLit{ + Type: &ast.FuncType{ + Params: &ast.FieldList{ + List: []*ast.Field{ + mkNamedField("ctx", mkStarExpr(mkSelectorExpr("fiber", "Ctx"))), + }, + }, + Results: &ast.FieldList{ + List: []*ast.Field{ + mkField(ast.NewIdent("error")), + }, + }, + }, + Body: &ast.BlockStmt{ + List: []ast.Stmt{ + inputExpr, + inputErrorCheck, + outputExpr, + errorCheck, + returnExpr, + }, + }, + } + + return &ast.ExprStmt{ + // app.Get("/mlflow/experiments/search", func(ctx *fiber.Ctx) error { .. }) + X: mkCallExpr( + mkSelectorExpr("app", strcase.ToCamel(endpoint.Method)), urlExpr, funcExpr, + ), + } +} + +func mkRouteRegistrationFunction( + endpoints map[string]any, interfaceName string, serviceInfo discovery.ServiceInfo, +) *ast.FuncDecl { + routes := make([]ast.Stmt, 0, len(serviceInfo.Methods)) + + for _, method := range serviceInfo.Methods { + for _, endpoint := range method.Endpoints { + if _, ok := endpoints[method.Name]; ok { + routes = append(routes, mkAppRoute(method, endpoint)) + } + } + } + + return &ast.FuncDecl{ + Name: ast.NewIdent(fmt.Sprintf("Register%sRoutes", interfaceName)), + Type: &ast.FuncType{ + Params: &ast.FieldList{ + List: []*ast.Field{ + mkNamedField("service", mkSelectorExpr("service", interfaceName)), + mkNamedField("parser", mkStarExpr(mkSelectorExpr("parser", "HTTPRequestParser"))), + mkNamedField("app", mkStarExpr(ast.NewIdent("fiber.App"))), + }, + }, + }, + Body: &ast.BlockStmt{ + List: routes, + }, + } +} + +func mkGeneratedFile(pkg, outputPath string, decls []ast.Decl) error { + // Set up the FileSet and the AST File + fset := token.NewFileSet() + + file := &ast.File{ + Name: ast.NewIdent(pkg), + Decls: decls, + } + + err := saveASTToFile(fset, file, true, outputPath) + if err != nil { + return fmt.Errorf("failed to save AST to file: %w", err) + } + + return nil +} + +const expectedImportStatements = 2 + +// Generate the service interface. 
+func generateServices( + pkgFolder string, + serviceInfo discovery.ServiceInfo, + generationInfo ServiceGenerationInfo, + endpoints map[string]any, +) error { + decls := make([]ast.Decl, 0, len(endpoints)+expectedImportStatements) + + if len(endpoints) > 0 { + decls = append(decls, + mkImportStatements( + `"context"`, + `"github.com/mlflow/mlflow-go/pkg/protos"`, + `"github.com/mlflow/mlflow-go/pkg/contract"`, + )) + } + + decls = append(decls, mkServiceInterfaceNode( + endpoints, + generationInfo.ServiceName, + serviceInfo, + )) + + fileName := generationInfo.FileNameWithoutExtension + ".g.go" + pkg := "service" + outputPath := filepath.Join(pkgFolder, "contract", pkg, fileName) + + return mkGeneratedFile(pkg, outputPath, decls) +} + +func generateRouteRegistrations( + pkgFolder string, + serviceInfo discovery.ServiceInfo, + generationInfo ServiceGenerationInfo, + endpoints map[string]any, +) error { + importStatements := []string{ + `"github.com/gofiber/fiber/v2"`, + `"github.com/mlflow/mlflow-go/pkg/server/parser"`, + `"github.com/mlflow/mlflow-go/pkg/contract/service"`, + } + + if len(endpoints) > 0 { + importStatements = append( + importStatements, + `"github.com/mlflow/mlflow-go/pkg/utils"`, + `"github.com/mlflow/mlflow-go/pkg/protos"`, + ) + } + + decls := []ast.Decl{ + mkImportStatements(importStatements...), + mkRouteRegistrationFunction(endpoints, generationInfo.ServiceName, serviceInfo), + } + + fileName := generationInfo.FileNameWithoutExtension + ".g.go" + pkg := "routes" + outputPath := filepath.Join(pkgFolder, "server", pkg, fileName) + + return mkGeneratedFile(pkg, outputPath, decls) +} + +func mkCEndpointBody(serviceName string, method discovery.MethodInfo) *ast.BlockStmt { + mapName := strcase.ToLowerCamel(serviceName) + "s" + + return &ast.BlockStmt{ + List: []ast.Stmt{ + // service, err := trackingServices.Get(serviceID) + mkAssignStmt( + []ast.Expr{ + ast.NewIdent("service"), + ast.NewIdent("err"), + }, + []ast.Expr{ + 
mkCallExpr(mkSelectorExpr(mapName, "Get"), ast.NewIdent("serviceID")), + }, + ), + // if err != nil { + // return makePointerFromError(err, responseSize) + // } + mkIfStmt( + nil, + errNotEqualNil, + mkBlockStmt( + mkReturnStmt( + mkCallExpr( + ast.NewIdent("makePointerFromError"), + ast.NewIdent("err"), + ast.NewIdent("responseSize"), + ), + ), + ), + ), + // return invokeServiceMethod( + // service.GetExperiment, + // new(protos.GetExperiment), + // requestData, + // requestSize, + // responseSize, + // ) + mkReturnStmt( + mkCallExpr( + ast.NewIdent("invokeServiceMethod"), + mkSelectorExpr("service", strcase.ToCamel(method.Name)), + mkCallExpr(ast.NewIdent("new"), mkSelectorExpr("protos", method.Input)), + ast.NewIdent("requestData"), + ast.NewIdent("requestSize"), + ast.NewIdent("responseSize"), + ), + ), + }, + } +} + +func mkCEndpoint(serviceName string, method discovery.MethodInfo) *ast.FuncDecl { + functionName := fmt.Sprintf("%s%s", serviceName, strcase.ToCamel(method.Name)) + + return &ast.FuncDecl{ + Doc: &ast.CommentGroup{ + List: []*ast.Comment{ + { + Text: "//export " + functionName, + }, + }, + }, + Name: ast.NewIdent(functionName), + Type: &ast.FuncType{ + Params: &ast.FieldList{ + List: []*ast.Field{ + mkNamedField("serviceID", ast.NewIdent("int64")), + mkNamedField("requestData", mkSelectorExpr("unsafe", "Pointer")), + mkNamedField("requestSize", mkSelectorExpr("C", "int")), + mkNamedField("responseSize", mkStarExpr(mkSelectorExpr("C", "int"))), + }, + }, + Results: &ast.FieldList{ + List: []*ast.Field{ + mkField(mkSelectorExpr("unsafe", "Pointer")), + }, + }, + }, + Body: mkCEndpointBody(serviceName, method), + } +} + +func mkCEndpoints( + endpoints map[string]any, serviceName string, serviceInfo discovery.ServiceInfo, +) []*ast.FuncDecl { + funcs := make([]*ast.FuncDecl, 0, len(endpoints)) + + for _, method := range serviceInfo.Methods { + if _, ok := endpoints[method.Name]; ok { + funcs = append(funcs, mkCEndpoint(serviceName, method)) + } + } + 
+ return funcs +} + +func generateEndpoints( + pkgFolder string, + serviceInfo discovery.ServiceInfo, + generationInfo ServiceGenerationInfo, + endpoints map[string]any, +) error { + decls := []ast.Decl{ + mkImportStatements(`"C"`), + } + + if len(endpoints) > 0 { + decls = append( + decls, + mkImportStatements( + `"unsafe"`, + `"github.com/mlflow/mlflow-go/pkg/protos"`, + ), + ) + + endpoints := mkCEndpoints(endpoints, generationInfo.ServiceName, serviceInfo) + for _, endpoint := range endpoints { + decls = append(decls, endpoint) + } + } + + fileName := generationInfo.FileNameWithoutExtension + ".g.go" + outputPath := filepath.Join(pkgFolder, "lib", fileName) + + return mkGeneratedFile("main", outputPath, decls) +} + +func SourceCode(pkgFolder string) error { + services, err := discovery.GetServiceInfos() + if err != nil { + return fmt.Errorf("failed to get service info: %w", err) + } + + for _, serviceInfo := range services { + generationInfo, ok := ServiceInfoMap[serviceInfo.Name] + if !ok { + continue + } + + endpoints := make(map[string]any, len(generationInfo.ImplementedEndpoints)) + + for _, endpoint := range generationInfo.ImplementedEndpoints { + endpoints[endpoint] = nil + } + + err = generateServices(pkgFolder, serviceInfo, generationInfo, endpoints) + if err != nil { + return err + } + + err = generateRouteRegistrations(pkgFolder, serviceInfo, generationInfo, endpoints) + if err != nil { + return err + } + + err = generateEndpoints(pkgFolder, serviceInfo, generationInfo, endpoints) + if err != nil { + return err + } + } + + return nil +} diff --git a/magefiles/generate/validations.go b/magefiles/generate/validations.go index 2f6a59e..c625861 100644 --- a/magefiles/generate/validations.go +++ b/magefiles/generate/validations.go @@ -1,28 +1,32 @@ -package generate - -var validations = map[string]string{ - "GetExperiment_ExperimentId": "required,stringAsPositiveInteger", - "CreateExperiment_Name": "required,max=500", - "CreateExperiment_ArtifactLocation": 
"omitempty,uriWithoutFragmentsOrParamsOrDotDotInQuery", - "SearchRuns_RunViewType": "omitempty", - "SearchRuns_MaxResults": "gt=0,max=50000", - "DeleteExperiment_ExperimentId": "required,stringAsPositiveInteger", - "LogBatch_RunId": "required,runId", - "LogBatch_Params": "omitempty,uniqueParams,max=100,dive", - "LogBatch_Metrics": "max=1000,dive", - "LogBatch_Tags": "max=100", - "RunTag_Key": "required,max=250,validMetricParamOrTagName,pathIsUnique", - "RunTag_Value": "omitempty,max=5000", - "Param_Key": "required,max=250,validMetricParamOrTagName,pathIsUnique", - "Param_Value": "omitempty,truncate=6000", - "Metric_Key": "required,max=250,validMetricParamOrTagName,pathIsUnique", - "Metric_Timestamp": "required", - "Metric_Value": "required", - "CreateRun_ExperimentId": "required,stringAsPositiveInteger", - "GetExperimentByName_ExperimentName": "required", - "GetLatestVersions_Name": "required", - "LogMetric_RunId": "required", - "LogMetric_Key": "required", - "LogMetric_Value": "required", - "LogMetric_Timestamp": "required", -} +package generate + +var validations = map[string]string{ + "GetExperiment_ExperimentId": "required,stringAsPositiveInteger", + "CreateExperiment_Name": "required,max=500", + "CreateExperiment_ArtifactLocation": "omitempty,uriWithoutFragmentsOrParamsOrDotDotInQuery", + "SearchRuns_RunViewType": "omitempty", + "SearchRuns_MaxResults": "gt=0,max=50000", + "DeleteExperiment_ExperimentId": "required,stringAsPositiveInteger", + "LogBatch_RunId": "required,runId", + "LogBatch_Params": "omitempty,uniqueParams,max=100,dive", + "LogBatch_Metrics": "max=1000,dive", + "LogBatch_Tags": "max=100", + "RunTag_Key": "required,max=250,validMetricParamOrTagName,pathIsUnique", + "RunTag_Value": "omitempty,max=5000", + "Param_Key": "required,max=250,validMetricParamOrTagName,pathIsUnique", + "Param_Value": "omitempty,truncate=6000", + "Metric_Key": "required,max=250,validMetricParamOrTagName,pathIsUnique", + "Metric_Timestamp": "required", + "Metric_Value": 
"required", + "CreateRun_ExperimentId": "required,stringAsPositiveInteger", + "GetExperimentByName_ExperimentName": "required", + "GetLatestVersions_Name": "required", + "LogMetric_RunId": "required", + "LogMetric_Key": "required", + "LogMetric_Value": "required", + "LogMetric_Timestamp": "required", + "SetTag_RunId": "required", + "SetTag_Key": "required", + "DeleteTag_RunId": "required", + "DeleteTag_Key": "required", +} diff --git a/magefiles/repo.go b/magefiles/repo.go index 25647e7..e5da12d 100644 --- a/magefiles/repo.go +++ b/magefiles/repo.go @@ -1,220 +1,220 @@ -//go:build mage - -//nolint:wrapcheck -package main - -import ( - "errors" - "fmt" - "log" - "os" - "path/filepath" - "strings" - - "github.com/magefile/mage/mg" - "github.com/magefile/mage/sh" -) - -const ( - MLFlowRepoFolderName = ".mlflow.repo" -) - -type Repo mg.Namespace - -func folderExists(path string) bool { - info, err := os.Stat(path) - if os.IsNotExist(err) { - return false - } - - return info.IsDir() -} - -func git(args ...string) error { - return sh.RunV("git", args...) -} - -func gitMlflowRepo(args ...string) error { - allArgs := append([]string{"-C", MLFlowRepoFolderName}, args...) - - return sh.RunV("git", allArgs...) -} - -func gitMlflowRepoOutput(args ...string) (string, error) { - allArgs := append([]string{"-C", MLFlowRepoFolderName}, args...) - - return sh.Output("git", allArgs...) 
-} - -type gitReference struct { - remote string - reference string -} - -const refFileName = ".mlflow.ref" - -func readFile(filename string) (string, error) { - content, err := os.ReadFile(filename) - if err != nil { - return "", err - } - - return string(content), nil -} - -var ErrInvalidGitRefFormat = errors.New("invalid format in .mlflow.ref file: expected 'remote#reference'") - -func readGitReference() (gitReference, error) { - refFilePath, err := filepath.Abs(refFileName) - if err != nil { - return gitReference{}, fmt.Errorf("failed to get .mlflow.ref: %w", err) - } - - content, err := readFile(refFilePath) - if err != nil { - return gitReference{}, err - } - - parts := strings.Split(content, "#") - - if len(parts) != 2 || parts[0] == "" || parts[1] == "" { - return gitReference{}, ErrInvalidGitRefFormat - } - - remote := strings.TrimSpace(parts[0]) - reference := strings.TrimSpace(parts[1]) - - return gitReference{remote: remote, reference: reference}, nil -} - -func freshCheckout(gitReference gitReference) error { - if err := git("clone", "--no-checkout", gitReference.remote, MLFlowRepoFolderName); err != nil { - return err - } - - return gitMlflowRepo("checkout", gitReference.reference) -} - -func checkRemote(gitReference gitReference) bool { - // git -C .mlflow.repo remote get-url origin - output, err := gitMlflowRepoOutput("remote", "get-url", "origin") - if err != nil { - return false - } - - return strings.TrimSpace(output) == gitReference.remote -} - -func checkBranch(gitReference gitReference) bool { - // git -C .mlflow.repo rev-parse --abbrev-ref HEAD - output, err := gitMlflowRepoOutput("rev-parse", "--abbrev-ref", "HEAD") - if err != nil { - return false - } - - return strings.TrimSpace(output) == gitReference.reference -} - -func checkTag(gitReference gitReference) bool { - // git -C .mlflow.repo describe --tags HEAD - output, err := gitMlflowRepoOutput("describe", "--tags", "HEAD") - if err != nil { - return false - } - - return 
strings.TrimSpace(output) == gitReference.reference -} - -func checkCommit(gitReference gitReference) bool { - // git -C .mlflow.repo rev-parse HEAD - output, err := gitMlflowRepoOutput("rev-parse", "HEAD") - if err != nil { - return false - } - - return strings.TrimSpace(output) == gitReference.reference -} - -func checkReference(gitReference gitReference) bool { - switch { - case checkBranch(gitReference): - log.Printf("Already on branch %q", gitReference.reference) - - return true - case checkTag(gitReference): - log.Printf("Already on tag %q", gitReference.reference) - - return true - case checkCommit(gitReference): - log.Printf("Already on commit %q", gitReference.reference) - - return true - } - - return false -} - -func syncRepo(gitReference gitReference) error { - log.Printf("syncing mlflow repo to %s#%s", gitReference.remote, gitReference.reference) - - if err := gitMlflowRepo("remote", "set-url", "origin", gitReference.remote); err != nil { - return err - } - - if err := gitMlflowRepo("fetch", "origin"); err != nil { - return err - } - - if err := gitMlflowRepo("checkout", gitReference.reference); err != nil { - return err - } - - if checkBranch(gitReference) { - return gitMlflowRepo("pull") - } - - return nil -} - -// Clone or reset the .mlflow.repo fork. -func (Repo) Init() error { - gitReference, err := readGitReference() - if err != nil { - return err - } - - repoPath, err := filepath.Abs(MLFlowRepoFolderName) - if err != nil { - return err - } - - if !folderExists(repoPath) { - return freshCheckout(gitReference) - } - - // Verify remote - if !checkRemote(gitReference) { - log.Printf("Remote %s no longer matches", gitReference.remote) - - return syncRepo(gitReference) - } - - // Verify reference - if !checkReference(gitReference) { - log.Printf("The current reference %q no longer matches", gitReference.reference) - - return syncRepo(gitReference) - } - - return nil -} - -// Forcefully update the .mlflow.repo according to the .mlflow.ref. 
-func (Repo) Update() error { - gitReference, err := readGitReference() - if err != nil { - return err - } - - return syncRepo(gitReference) -} +//go:build mage + +//nolint:wrapcheck +package main + +import ( + "errors" + "fmt" + "log" + "os" + "path/filepath" + "strings" + + "github.com/magefile/mage/mg" + "github.com/magefile/mage/sh" +) + +const ( + MLFlowRepoFolderName = ".mlflow.repo" +) + +type Repo mg.Namespace + +func folderExists(path string) bool { + info, err := os.Stat(path) + if os.IsNotExist(err) { + return false + } + + return info.IsDir() +} + +func git(args ...string) error { + return sh.RunV("git", args...) +} + +func gitMlflowRepo(args ...string) error { + allArgs := append([]string{"-C", MLFlowRepoFolderName}, args...) + + return sh.RunV("git", allArgs...) +} + +func gitMlflowRepoOutput(args ...string) (string, error) { + allArgs := append([]string{"-C", MLFlowRepoFolderName}, args...) + + return sh.Output("git", allArgs...) +} + +type gitReference struct { + remote string + reference string +} + +const refFileName = ".mlflow.ref" + +func readFile(filename string) (string, error) { + content, err := os.ReadFile(filename) + if err != nil { + return "", err + } + + return string(content), nil +} + +var ErrInvalidGitRefFormat = errors.New("invalid format in .mlflow.ref file: expected 'remote#reference'") + +func readGitReference() (gitReference, error) { + refFilePath, err := filepath.Abs(refFileName) + if err != nil { + return gitReference{}, fmt.Errorf("failed to get .mlflow.ref: %w", err) + } + + content, err := readFile(refFilePath) + if err != nil { + return gitReference{}, err + } + + parts := strings.Split(content, "#") + + if len(parts) != 2 || parts[0] == "" || parts[1] == "" { + return gitReference{}, ErrInvalidGitRefFormat + } + + remote := strings.TrimSpace(parts[0]) + reference := strings.TrimSpace(parts[1]) + + return gitReference{remote: remote, reference: reference}, nil +} + +func freshCheckout(gitReference gitReference) error { + 
if err := git("clone", "--no-checkout", gitReference.remote, MLFlowRepoFolderName); err != nil { + return err + } + + return gitMlflowRepo("checkout", gitReference.reference) +} + +func checkRemote(gitReference gitReference) bool { + // git -C .mlflow.repo remote get-url origin + output, err := gitMlflowRepoOutput("remote", "get-url", "origin") + if err != nil { + return false + } + + return strings.TrimSpace(output) == gitReference.remote +} + +func checkBranch(gitReference gitReference) bool { + // git -C .mlflow.repo rev-parse --abbrev-ref HEAD + output, err := gitMlflowRepoOutput("rev-parse", "--abbrev-ref", "HEAD") + if err != nil { + return false + } + + return strings.TrimSpace(output) == gitReference.reference +} + +func checkTag(gitReference gitReference) bool { + // git -C .mlflow.repo describe --tags HEAD + output, err := gitMlflowRepoOutput("describe", "--tags", "HEAD") + if err != nil { + return false + } + + return strings.TrimSpace(output) == gitReference.reference +} + +func checkCommit(gitReference gitReference) bool { + // git -C .mlflow.repo rev-parse HEAD + output, err := gitMlflowRepoOutput("rev-parse", "HEAD") + if err != nil { + return false + } + + return strings.TrimSpace(output) == gitReference.reference +} + +func checkReference(gitReference gitReference) bool { + switch { + case checkBranch(gitReference): + log.Printf("Already on branch %q", gitReference.reference) + + return true + case checkTag(gitReference): + log.Printf("Already on tag %q", gitReference.reference) + + return true + case checkCommit(gitReference): + log.Printf("Already on commit %q", gitReference.reference) + + return true + } + + return false +} + +func syncRepo(gitReference gitReference) error { + log.Printf("syncing mlflow repo to %s#%s", gitReference.remote, gitReference.reference) + + if err := gitMlflowRepo("remote", "set-url", "origin", gitReference.remote); err != nil { + return err + } + + if err := gitMlflowRepo("fetch", "origin"); err != nil { + return err 
+ } + + if err := gitMlflowRepo("checkout", gitReference.reference); err != nil { + return err + } + + if checkBranch(gitReference) { + return gitMlflowRepo("pull") + } + + return nil +} + +// Clone or reset the .mlflow.repo fork. +func (Repo) Init() error { + gitReference, err := readGitReference() + if err != nil { + return err + } + + repoPath, err := filepath.Abs(MLFlowRepoFolderName) + if err != nil { + return err + } + + if !folderExists(repoPath) { + return freshCheckout(gitReference) + } + + // Verify remote + if !checkRemote(gitReference) { + log.Printf("Remote %s no longer matches", gitReference.remote) + + return syncRepo(gitReference) + } + + // Verify reference + if !checkReference(gitReference) { + log.Printf("The current reference %q no longer matches", gitReference.reference) + + return syncRepo(gitReference) + } + + return nil +} + +// Forcefully update the .mlflow.repo according to the .mlflow.ref. +func (Repo) Update() error { + gitReference, err := readGitReference() + if err != nil { + return err + } + + return syncRepo(gitReference) +} diff --git a/magefiles/temp.go b/magefiles/temp.go index b835db7..484d553 100644 --- a/magefiles/temp.go +++ b/magefiles/temp.go @@ -1,74 +1,74 @@ -//go:build mage - -//nolint:wrapcheck -package main - -import ( - "os" - "path/filepath" - - "github.com/magefile/mage/mg" - "github.com/magefile/mage/sh" -) - -func pipInstall(args ...string) error { - allArgs := append([]string{"install"}, args...) - - return sh.RunV("pip", allArgs...) -} - -func tar(args ...string) error { - return sh.RunV("tar", args...) 
-} - -func Temp() error { - mg.Deps(Repo.Init) - - // Install our Python package and its dependencies - if err := pipInstall("-e", "."); err != nil { - return err - } - - // Install the dreaded psycho - if err := pipInstall("psycopg2-binary"); err != nil { - return err - } - - // Archive the MLFlow pre-built UI - if err := tar( - "-C", "/usr/local/python/current/lib/python3.8/site-packages/mlflow", - "-czvf", - "./ui.tgz", - "./server/js/build", - ); err != nil { - return err - } - - mlflowRepoPath, err := filepath.Abs(MLFlowRepoFolderName) - if err != nil { - return err - } - - // Add the UI back to it - if err := tar( - "-C", mlflowRepoPath, - "-xzvf", "./ui.tgz", - ); err != nil { - return err - } - - // Remove tar file - tarPath, err := filepath.Abs("ui.tgz") - if err != nil { - return err - } - - defer os.Remove(tarPath) - - // Install it in editable mode - if err := pipInstall("-e", mlflowRepoPath); err != nil { - return err - } - - return nil -} +//go:build mage + +//nolint:wrapcheck +package main + +import ( + "os" + "path/filepath" + + "github.com/magefile/mage/mg" + "github.com/magefile/mage/sh" +) + +func pipInstall(args ...string) error { + allArgs := append([]string{"install"}, args...) + + return sh.RunV("pip", allArgs...) +} + +func tar(args ...string) error { + return sh.RunV("tar", args...) 
+} + +func Temp() error { + mg.Deps(Repo.Init) + + // Install our Python package and its dependencies + if err := pipInstall("-e", "."); err != nil { + return err + } + + // Install the dreaded psycho + if err := pipInstall("psycopg2-binary"); err != nil { + return err + } + + // Archive the MLFlow pre-built UI + if err := tar( + "-C", "/usr/local/python/current/lib/python3.8/site-packages/mlflow", + "-czvf", + "./ui.tgz", + "./server/js/build", + ); err != nil { + return err + } + + mlflowRepoPath, err := filepath.Abs(MLFlowRepoFolderName) + if err != nil { + return err + } + + // Add the UI back to it + if err := tar( + "-C", mlflowRepoPath, + "-xzvf", "./ui.tgz", + ); err != nil { + return err + } + + // Remove tar file + tarPath, err := filepath.Abs("ui.tgz") + if err != nil { + return err + } + + defer os.Remove(tarPath) + + // Install it in editable mode + if err := pipInstall("-e", mlflowRepoPath); err != nil { + return err + } + + return nil +} diff --git a/magefiles/tests.go b/magefiles/tests.go index 0139af6..e7a1cf3 100644 --- a/magefiles/tests.go +++ b/magefiles/tests.go @@ -1,75 +1,102 @@ -//go:build mage - -//nolint:wrapcheck -package main - -import ( - "os" - - "github.com/magefile/mage/mg" - "github.com/magefile/mage/sh" -) - -type Test mg.Namespace - -func cleanUpMemoryFile() error { - // Clean up :memory: file - filename := ":memory:" - _, err := os.Stat(filename) - - if err == nil { - // File exists, delete it - err = os.Remove(filename) - if err != nil { - return err - } - } - - return nil -} - -// Run mlflow Python tests against the Go backend. 
-func (Test) Python() error { - libpath, err := os.MkdirTemp("", "") - if err != nil { - return err - } - - // Remove the Go binary - defer os.RemoveAll(libpath) - //nolint:errcheck - defer cleanUpMemoryFile() - - // Build the Go binary in a temporary directory - if err := sh.RunV("python", "-m", "mlflow_go.lib", ".", libpath); err != nil { - return nil - } - - // Run the tests (currently just the server ones) - if err := sh.RunWithV(map[string]string{ - "MLFLOW_GO_LIBRARY_PATH": libpath, - }, "pytest", - "--confcutdir=.", - ".mlflow.repo/tests/tracking/test_rest_tracking.py", - ".mlflow.repo/tests/tracking/test_model_registry.py", - ".mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py", - ".mlflow.repo/tests/store/model_registry/test_sqlalchemy_store.py", - "-k", - "not [file", - // "-vv", - ); err != nil { - return err - } - - return nil -} - -// Run the Go unit tests. -func (Test) Unit() error { - return sh.RunV("go", "test", "./pkg/...") -} - -// Run all tests. -func (Test) All() { - mg.Deps(Test.Unit, Test.Python) -} +//go:build mage + +//nolint:wrapcheck +package main + +import ( + "os" + + "github.com/magefile/mage/mg" + "github.com/magefile/mage/sh" +) + +type Test mg.Namespace + +func cleanUpMemoryFile() error { + // Clean up :memory: file + filename := ":memory:" + _, err := os.Stat(filename) + + if err == nil { + // File exists, delete it + err = os.Remove(filename) + if err != nil { + return err + } + } + + return nil +} + +// Run mlflow Python tests against the Go backend. 
+func (Test) Python() error { + libpath, err := os.MkdirTemp("", "") + if err != nil { + return err + } + + // Remove the Go binary + defer os.RemoveAll(libpath) + //nolint:errcheck + defer cleanUpMemoryFile() + + // Build the Go binary in a temporary directory + if err := sh.RunV("python", "-m", "mlflow_go.lib", ".", libpath); err != nil { + return nil + } + + // Run the tests (currently just the server ones) + if err := sh.RunWithV(map[string]string{ + "MLFLOW_GO_LIBRARY_PATH": libpath, + }, "pytest", + "--confcutdir=.", + ".mlflow.repo/tests/tracking/test_rest_tracking.py", + ".mlflow.repo/tests/tracking/test_model_registry.py", + ".mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py", + ".mlflow.repo/tests/store/model_registry/test_sqlalchemy_store.py", + "-k", + "not [file", + // "-vv", + ); err != nil { + return err + } + + return nil +} + +// Run specific Python test against the Go backend. +func (Test) PythonSpecific(testName string) error { + libpath, err := os.MkdirTemp("", "") + if err != nil { + return err + } + + defer os.RemoveAll(libpath) + defer cleanUpMemoryFile() + + if err := sh.RunV("python", "-m", "mlflow_go.lib", ".", libpath); err != nil { + return nil + } + + if err := sh.RunWithV(map[string]string{ + "MLFLOW_GO_LIBRARY_PATH": libpath, + }, "pytest", + "--confcutdir=.", + ".mlflow.repo/tests/tracking/test_rest_tracking.py", + "-k", testName, + ); err != nil { + return err + } + + return nil +} + +// Run the Go unit tests. +func (Test) Unit() error { + return sh.RunV("go", "test", "./pkg/...") +} + +// Run all tests. 
+func (Test) All() { + mg.Deps(Test.Unit, Test.Python) +} diff --git a/mlflow_go/__init__.py b/mlflow_go/__init__.py index d7ac0f1..a9275a8 100644 --- a/mlflow_go/__init__.py +++ b/mlflow_go/__init__.py @@ -1,20 +1,20 @@ -import os - -_go_enabled = "MLFLOW_GO_ENABLED" in os.environ - - -def _set_go_enabled(enabled: bool): - global _go_enabled - _go_enabled = enabled - - -def is_go_enabled(): - return _go_enabled - - -def disable_go(): - _set_go_enabled(False) - - -def enable_go(): - _set_go_enabled(True) +import os + +_go_enabled = "MLFLOW_GO_ENABLED" in os.environ + + +def _set_go_enabled(enabled: bool): + global _go_enabled + _go_enabled = enabled + + +def is_go_enabled(): + return _go_enabled + + +def disable_go(): + _set_go_enabled(False) + + +def enable_go(): + _set_go_enabled(True) diff --git a/mlflow_go/cli.py b/mlflow_go/cli.py index a6ed160..a9d0530 100644 --- a/mlflow_go/cli.py +++ b/mlflow_go/cli.py @@ -1,112 +1,112 @@ -import json -import pathlib -import shlex - -import click -import mlflow.cli -import mlflow.version -from mlflow.utils import find_free_port - -from mlflow_go.lib import get_lib - - -def _get_commands(): - """Returns the MLflow CLI commands with the `server` command replaced with a Go server.""" - commands = mlflow.cli.cli.commands.copy() - - def server( - go_opts, - **kwargs, - ): - # convert the Go options to a dictionary - opts = {} - if go_opts: - for opt in go_opts.split(","): - key, value = opt.split("=", 1) - opts[key] = value - - # validate the Python server configuration if set - if ("python_address" in opts) ^ ("python_command" in opts): - raise click.ClickException("python_address and python_command have to be set together") - - if "python_address" and "python_command" in opts: - # use the provided Python server configuration - python_address = opts["python_address"] - python_command = shlex.split(opts["python_command"]) - else: - # assign a random port for the Python server - python_host = "127.0.0.1" - python_port = 
find_free_port() - python_address = f"{python_host}:{python_port}" - python_args = kwargs.copy() - python_args.update( - { - "host": python_host, - "port": python_port, - } - ) - - # construct the Python server command - python_command = [ - "mlflow", - "server", - ] - for key, value in python_args.items(): - if isinstance(value, bool): - if value: - python_command.append(f"--{key.replace('_', '-')}") - elif value is not None: - python_command.append(f"--{key.replace('_', '-')}") - python_command.append(str(value)) - - # initialize the Go server configuration - tracking_store_uri = kwargs["backend_store_uri"] - config = { - "address": f'{kwargs["host"]}:{kwargs["port"]}', - "default_artifact_root": mlflow.cli.resolve_default_artifact_root( - kwargs["serve_artifacts"], kwargs["default_artifact_root"], tracking_store_uri - ), - "log_level": opts.get("log_level", "DEBUG" if kwargs["dev"] else "INFO"), - "python_address": python_address, - "python_command": python_command, - "shutdown_timeout": opts.get("shutdown_timeout", "1m"), - "static_folder": pathlib.Path(mlflow.server.__file__) - .parent.joinpath(mlflow.server.REL_STATIC_DIR) - .resolve() - .as_posix(), - "tracking_store_uri": tracking_store_uri, - "model_registry_store_uri": kwargs["registry_store_uri"] or tracking_store_uri, - "version": mlflow.version.VERSION, - } - config_bytes = json.dumps(config).encode("utf-8") - - # start the Go server and check for errors - ret = get_lib().LaunchServer(config_bytes, len(config_bytes)) - if ret != 0: - raise click.ClickException(f"Non-zero exit code: {ret}") - - server.__doc__ = mlflow.cli.server.callback.__doc__ - - server_params = mlflow.cli.server.params.copy() - idx = next((i for i, x in enumerate(mlflow.cli.server.params) if x.name == "gunicorn_opts"), -1) - server_params.insert( - idx, - click.Option( - ["--go-opts"], - default=None, - help="Additional options forwarded to the Go server", - ), - ) - - commands["server"] = click.command(params=server_params)(server) 
- - return commands - - -@click.group(commands=_get_commands()) -def cli(): - pass - - -if __name__ == "__main__": - cli() +import json +import pathlib +import shlex + +import click +import mlflow.cli +import mlflow.version +from mlflow.utils import find_free_port + +from mlflow_go.lib import get_lib + + +def _get_commands(): + """Returns the MLflow CLI commands with the `server` command replaced with a Go server.""" + commands = mlflow.cli.cli.commands.copy() + + def server( + go_opts, + **kwargs, + ): + # convert the Go options to a dictionary + opts = {} + if go_opts: + for opt in go_opts.split(","): + key, value = opt.split("=", 1) + opts[key] = value + + # validate the Python server configuration if set + if ("python_address" in opts) ^ ("python_command" in opts): + raise click.ClickException("python_address and python_command have to be set together") + + if "python_address" and "python_command" in opts: + # use the provided Python server configuration + python_address = opts["python_address"] + python_command = shlex.split(opts["python_command"]) + else: + # assign a random port for the Python server + python_host = "127.0.0.1" + python_port = find_free_port() + python_address = f"{python_host}:{python_port}" + python_args = kwargs.copy() + python_args.update( + { + "host": python_host, + "port": python_port, + } + ) + + # construct the Python server command + python_command = [ + "mlflow", + "server", + ] + for key, value in python_args.items(): + if isinstance(value, bool): + if value: + python_command.append(f"--{key.replace('_', '-')}") + elif value is not None: + python_command.append(f"--{key.replace('_', '-')}") + python_command.append(str(value)) + + # initialize the Go server configuration + tracking_store_uri = kwargs["backend_store_uri"] + config = { + "address": f'{kwargs["host"]}:{kwargs["port"]}', + "default_artifact_root": mlflow.cli.resolve_default_artifact_root( + kwargs["serve_artifacts"], kwargs["default_artifact_root"], tracking_store_uri 
+ ), + "log_level": opts.get("log_level", "DEBUG" if kwargs["dev"] else "INFO"), + "python_address": python_address, + "python_command": python_command, + "shutdown_timeout": opts.get("shutdown_timeout", "1m"), + "static_folder": pathlib.Path(mlflow.server.__file__) + .parent.joinpath(mlflow.server.REL_STATIC_DIR) + .resolve() + .as_posix(), + "tracking_store_uri": tracking_store_uri, + "model_registry_store_uri": kwargs["registry_store_uri"] or tracking_store_uri, + "version": mlflow.version.VERSION, + } + config_bytes = json.dumps(config).encode("utf-8") + + # start the Go server and check for errors + ret = get_lib().LaunchServer(config_bytes, len(config_bytes)) + if ret != 0: + raise click.ClickException(f"Non-zero exit code: {ret}") + + server.__doc__ = mlflow.cli.server.callback.__doc__ + + server_params = mlflow.cli.server.params.copy() + idx = next((i for i, x in enumerate(mlflow.cli.server.params) if x.name == "gunicorn_opts"), -1) + server_params.insert( + idx, + click.Option( + ["--go-opts"], + default=None, + help="Additional options forwarded to the Go server", + ), + ) + + commands["server"] = click.command(params=server_params)(server) + + return commands + + +@click.group(commands=_get_commands()) +def cli(): + pass + + +if __name__ == "__main__": + cli() diff --git a/mlflow_go/lib.py b/mlflow_go/lib.py index ad5c482..d4cbb32 100644 --- a/mlflow_go/lib.py +++ b/mlflow_go/lib.py @@ -1,124 +1,124 @@ -import logging -import os -import pathlib -import re -import subprocess -import sys -import tempfile - - -def _get_lib_name() -> str: - ext = ".so" - if sys.platform == "win32": - ext = ".dll" - elif sys.platform == "darwin": - ext = ".dylib" - return "libmlflow-go" + ext - - -def build_lib(src_dir: pathlib.Path, out_dir: pathlib.Path) -> pathlib.Path: - out_path = out_dir.joinpath(_get_lib_name()) - env = os.environ.copy() - env.update( - { - "CGO_ENABLED": "1", - } - ) - subprocess.check_call( - [ - "go", - "build", - "-trimpath", - "-ldflags", - "-w 
-s", - "-o", - out_path.resolve().as_posix(), - "-buildmode", - "c-shared", - src_dir.joinpath("pkg", "lib").resolve().as_posix(), - ], - cwd=src_dir.resolve().as_posix(), - env=env, - ) - return out_path - - -def _get_lib(): - # check if the library exists and load it - path = pathlib.Path( - os.environ.get("MLFLOW_GO_LIBRARY_PATH", pathlib.Path(__file__).parent.as_posix()) - ).joinpath(_get_lib_name()) - if path.is_file(): - return _load_lib(path) - - logging.getLogger(__name__).warn("Go library not found, building it now") - - # build the library in a temporary directory and load it - with tempfile.TemporaryDirectory() as tmpdir: - return _load_lib( - build_lib( - pathlib.Path(__file__).parent.parent, - pathlib.Path(tmpdir), - ) - ) - - -def _load_lib(path: pathlib.Path): - ffi = get_ffi() - - # load from header file - ffi.cdef(_parse_header(path.with_suffix(".h"))) - - # load the library - return ffi.dlopen(path.as_posix()) - - -def _parse_header(path: pathlib.Path): - with open(path) as file: - content = file.read() - - # Find all matches in the header - functions = re.findall(r"extern\s+\w+\s*\*?\s+\w+\s*\([^)]*\);", content, re.MULTILINE) - - # Replace GoInt64 with int64_t in each function - transformed_functions = [func.replace("GoInt64", "int64_t") for func in functions] - - return "\n".join(transformed_functions) - - -def _get_ffi(): - import cffi - - return cffi.FFI() - - -_ffi = None - - -def get_ffi(): - global _ffi - if _ffi is None: - _ffi = _get_ffi() - _ffi.cdef("void free(void*);") - return _ffi - - -_lib = None - - -def get_lib(): - global _lib - if _lib is None: - _lib = _get_lib() - return _lib - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser("build_lib", description="Build Go library") - parser.add_argument("src", help="the Go source directory") - parser.add_argument("out", help="the output directory") - args = parser.parse_args() - - build_lib(pathlib.Path(args.src), pathlib.Path(args.out)) +import 
logging +import os +import pathlib +import re +import subprocess +import sys +import tempfile + + +def _get_lib_name() -> str: + ext = ".so" + if sys.platform == "win32": + ext = ".dll" + elif sys.platform == "darwin": + ext = ".dylib" + return "libmlflow-go" + ext + + +def build_lib(src_dir: pathlib.Path, out_dir: pathlib.Path) -> pathlib.Path: + out_path = out_dir.joinpath(_get_lib_name()) + env = os.environ.copy() + env.update( + { + "CGO_ENABLED": "1", + } + ) + subprocess.check_call( + [ + "go", + "build", + "-trimpath", + "-ldflags", + "-w -s", + "-o", + out_path.resolve().as_posix(), + "-buildmode", + "c-shared", + src_dir.joinpath("pkg", "lib").resolve().as_posix(), + ], + cwd=src_dir.resolve().as_posix(), + env=env, + ) + return out_path + + +def _get_lib(): + # check if the library exists and load it + path = pathlib.Path( + os.environ.get("MLFLOW_GO_LIBRARY_PATH", pathlib.Path(__file__).parent.as_posix()) + ).joinpath(_get_lib_name()) + if path.is_file(): + return _load_lib(path) + + logging.getLogger(__name__).warn("Go library not found, building it now") + + # build the library in a temporary directory and load it + with tempfile.TemporaryDirectory() as tmpdir: + return _load_lib( + build_lib( + pathlib.Path(__file__).parent.parent, + pathlib.Path(tmpdir), + ) + ) + + +def _load_lib(path: pathlib.Path): + ffi = get_ffi() + + # load from header file + ffi.cdef(_parse_header(path.with_suffix(".h"))) + + # load the library + return ffi.dlopen(path.as_posix()) + + +def _parse_header(path: pathlib.Path): + with open(path) as file: + content = file.read() + + # Find all matches in the header + functions = re.findall(r"extern\s+\w+\s*\*?\s+\w+\s*\([^)]*\);", content, re.MULTILINE) + + # Replace GoInt64 with int64_t in each function + transformed_functions = [func.replace("GoInt64", "int64_t") for func in functions] + + return "\n".join(transformed_functions) + + +def _get_ffi(): + import cffi + + return cffi.FFI() + + +_ffi = None + + +def get_ffi(): + global 
_ffi + if _ffi is None: + _ffi = _get_ffi() + _ffi.cdef("void free(void*);") + return _ffi + + +_lib = None + + +def get_lib(): + global _lib + if _lib is None: + _lib = _get_lib() + return _lib + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser("build_lib", description="Build Go library") + parser.add_argument("src", help="the Go source directory") + parser.add_argument("out", help="the output directory") + args = parser.parse_args() + + build_lib(pathlib.Path(args.src), pathlib.Path(args.out)) diff --git a/mlflow_go/server.py b/mlflow_go/server.py index 102eb23..bc82c44 100644 --- a/mlflow_go/server.py +++ b/mlflow_go/server.py @@ -1,31 +1,31 @@ -import json -from contextlib import contextmanager - -from mlflow_go.lib import get_lib - - -def launch_server(**config): - config_bytes = json.dumps(config).encode("utf-8") - - # start the Go server and check for errors - ret = get_lib().LaunchServer(config_bytes, len(config_bytes)) - if ret != 0: - raise Exception(f"Non-zero exit code: {ret}") - - -@contextmanager -def server(**config): - config_bytes = json.dumps(config).encode("utf-8") - - # start the Go server and check for errors - id = get_lib().LaunchServerAsync(config_bytes, len(config_bytes)) - if id < 0: - raise Exception(f"Non-zero exit code: {id}") - - try: - yield - finally: - # stop the Go server and check for errors - ret = get_lib().StopServer(id) - if ret != 0: - raise Exception(f"Non-zero exit code: {ret}") +import json +from contextlib import contextmanager + +from mlflow_go.lib import get_lib + + +def launch_server(**config): + config_bytes = json.dumps(config).encode("utf-8") + + # start the Go server and check for errors + ret = get_lib().LaunchServer(config_bytes, len(config_bytes)) + if ret != 0: + raise Exception(f"Non-zero exit code: {ret}") + + +@contextmanager +def server(**config): + config_bytes = json.dumps(config).encode("utf-8") + + # start the Go server and check for errors + id = 
get_lib().LaunchServerAsync(config_bytes, len(config_bytes)) + if id < 0: + raise Exception(f"Non-zero exit code: {id}") + + try: + yield + finally: + # stop the Go server and check for errors + ret = get_lib().StopServer(id) + if ret != 0: + raise Exception(f"Non-zero exit code: {ret}") diff --git a/mlflow_go/store/_service_proxy.py b/mlflow_go/store/_service_proxy.py index be854ac..4241b1a 100644 --- a/mlflow_go/store/_service_proxy.py +++ b/mlflow_go/store/_service_proxy.py @@ -1,43 +1,43 @@ -import json - -from google.protobuf.message import DecodeError -from mlflow.exceptions import MlflowException -from mlflow.protos.databricks_pb2 import INTERNAL_ERROR, ErrorCode - -from mlflow_go.lib import get_ffi, get_lib - - -class _ServiceProxy: - def __init__(self, id): - self.id = id - - def call_endpoint(self, endpoint, request): - request_data = request.SerializeToString() - response_size = get_ffi().new("int*") - - response_data = endpoint( - self.id, - request_data, - len(request_data), - response_size, - ) - - response_bytes = get_ffi().buffer(response_data, response_size[0])[:] - get_lib().free(response_data) - - try: - response = type(request).Response() - response.ParseFromString(response_bytes) - return response - except DecodeError: - try: - e = json.loads(response_bytes) - error_code = e.get("error_code", ErrorCode.Name(INTERNAL_ERROR)) - raise MlflowException( - message=e["message"], - error_code=ErrorCode.Value(error_code), - ) from None - except json.JSONDecodeError as e: - raise MlflowException( - message=f"Failed to parse response: {e}", - ) +import json + +from google.protobuf.message import DecodeError +from mlflow.exceptions import MlflowException +from mlflow.protos.databricks_pb2 import INTERNAL_ERROR, ErrorCode + +from mlflow_go.lib import get_ffi, get_lib + + +class _ServiceProxy: + def __init__(self, id): + self.id = id + + def call_endpoint(self, endpoint, request): + request_data = request.SerializeToString() + response_size = 
get_ffi().new("int*") + + response_data = endpoint( + self.id, + request_data, + len(request_data), + response_size, + ) + + response_bytes = get_ffi().buffer(response_data, response_size[0])[:] + get_lib().free(response_data) + + try: + response = type(request).Response() + response.ParseFromString(response_bytes) + return response + except DecodeError: + try: + e = json.loads(response_bytes) + error_code = e.get("error_code", ErrorCode.Name(INTERNAL_ERROR)) + raise MlflowException( + message=e["message"], + error_code=ErrorCode.Value(error_code), + ) from None + except json.JSONDecodeError as e: + raise MlflowException( + message=f"Failed to parse response: {e}", + ) diff --git a/mlflow_go/store/model_registry.py b/mlflow_go/store/model_registry.py index bc5ee11..ba80a17 100644 --- a/mlflow_go/store/model_registry.py +++ b/mlflow_go/store/model_registry.py @@ -1,55 +1,55 @@ -import json -import logging - -from mlflow.entities.model_registry import ( - ModelVersion, -) -from mlflow.protos.model_registry_pb2 import ( - GetLatestVersions, -) - -from mlflow_go import is_go_enabled -from mlflow_go.lib import get_lib -from mlflow_go.store._service_proxy import _ServiceProxy - -_logger = logging.getLogger(__name__) - - -class _ModelRegistryStore: - def __init__(self, *args, **kwargs): - store_uri = args[0] if len(args) > 0 else kwargs.get("db_uri", kwargs.get("root_directory")) - config = json.dumps( - { - "model_registry_store_uri": store_uri, - "log_level": logging.getLevelName(_logger.getEffectiveLevel()), - } - ).encode("utf-8") - self.service = _ServiceProxy(get_lib().CreateModelRegistryService(config, len(config))) - super().__init__(store_uri) - - def __del__(self): - if hasattr(self, "service"): - get_lib().DestroyModelRegistryService(self.service.id) - - def get_latest_versions(self, name, stages=None): - request = GetLatestVersions( - name=name, - stages=stages, - ) - response = self.service.call_endpoint( - get_lib().ModelRegistryServiceGetLatestVersions, 
request - ) - return [ModelVersion.from_proto(mv) for mv in response.model_versions] - - -def ModelRegistryStore(cls): - return type(cls.__name__, (_ModelRegistryStore, cls), {}) - - -def _get_sqlalchemy_store(store_uri): - from mlflow.store.model_registry.sqlalchemy_store import SqlAlchemyStore - - if is_go_enabled(): - SqlAlchemyStore = ModelRegistryStore(SqlAlchemyStore) - - return SqlAlchemyStore(store_uri) +import json +import logging + +from mlflow.entities.model_registry import ( + ModelVersion, +) +from mlflow.protos.model_registry_pb2 import ( + GetLatestVersions, +) + +from mlflow_go import is_go_enabled +from mlflow_go.lib import get_lib +from mlflow_go.store._service_proxy import _ServiceProxy + +_logger = logging.getLogger(__name__) + + +class _ModelRegistryStore: + def __init__(self, *args, **kwargs): + store_uri = args[0] if len(args) > 0 else kwargs.get("db_uri", kwargs.get("root_directory")) + config = json.dumps( + { + "model_registry_store_uri": store_uri, + "log_level": logging.getLevelName(_logger.getEffectiveLevel()), + } + ).encode("utf-8") + self.service = _ServiceProxy(get_lib().CreateModelRegistryService(config, len(config))) + super().__init__(store_uri) + + def __del__(self): + if hasattr(self, "service"): + get_lib().DestroyModelRegistryService(self.service.id) + + def get_latest_versions(self, name, stages=None): + request = GetLatestVersions( + name=name, + stages=stages, + ) + response = self.service.call_endpoint( + get_lib().ModelRegistryServiceGetLatestVersions, request + ) + return [ModelVersion.from_proto(mv) for mv in response.model_versions] + + +def ModelRegistryStore(cls): + return type(cls.__name__, (_ModelRegistryStore, cls), {}) + + +def _get_sqlalchemy_store(store_uri): + from mlflow.store.model_registry.sqlalchemy_store import SqlAlchemyStore + + if is_go_enabled(): + SqlAlchemyStore = ModelRegistryStore(SqlAlchemyStore) + + return SqlAlchemyStore(store_uri) diff --git a/mlflow_go/store/tracking.py 
b/mlflow_go/store/tracking.py index 80767ef..97c4e9c 100644 --- a/mlflow_go/store/tracking.py +++ b/mlflow_go/store/tracking.py @@ -1,183 +1,192 @@ -import json -import logging - -from mlflow.entities import ( - Experiment, - Run, - RunInfo, - ViewType, -) -from mlflow.exceptions import MlflowException -from mlflow.protos import databricks_pb2 -from mlflow.protos.service_pb2 import ( - CreateExperiment, - CreateRun, - DeleteExperiment, - DeleteRun, - GetExperiment, - GetExperimentByName, - GetRun, - LogBatch, - LogMetric, - RestoreExperiment, - RestoreRun, - SearchRuns, - UpdateExperiment, - UpdateRun, -) -from mlflow.utils.uri import resolve_uri_if_local - -from mlflow_go import is_go_enabled -from mlflow_go.lib import get_lib -from mlflow_go.store._service_proxy import _ServiceProxy - -_logger = logging.getLogger(__name__) - - -class _TrackingStore: - def __init__(self, *args, **kwargs): - store_uri = args[0] if len(args) > 0 else kwargs.get("db_uri", kwargs.get("root_directory")) - default_artifact_root = ( - args[1] - if len(args) > 1 - else kwargs.get("default_artifact_root", kwargs.get("artifact_root_uri")) - ) - config = json.dumps( - { - "default_artifact_root": resolve_uri_if_local(default_artifact_root), - "tracking_store_uri": store_uri, - "log_level": logging.getLevelName(_logger.getEffectiveLevel()), - } - ).encode("utf-8") - self.service = _ServiceProxy(get_lib().CreateTrackingService(config, len(config))) - super().__init__(store_uri, default_artifact_root) - - def __del__(self): - if hasattr(self, "service"): - get_lib().DestroyTrackingService(self.service.id) - - def get_experiment(self, experiment_id): - request = GetExperiment(experiment_id=str(experiment_id)) - response = self.service.call_endpoint(get_lib().TrackingServiceGetExperiment, request) - return Experiment.from_proto(response.experiment) - - def get_experiment_by_name(self, experiment_name): - request = GetExperimentByName(experiment_name=experiment_name) - try: - response = 
self.service.call_endpoint( - get_lib().TrackingServiceGetExperimentByName, request - ) - return Experiment.from_proto(response.experiment) - except MlflowException as e: - if e.error_code == databricks_pb2.ErrorCode.Name( - databricks_pb2.RESOURCE_DOES_NOT_EXIST - ): - return None - raise - - def create_experiment(self, name, artifact_location=None, tags=None): - request = CreateExperiment( - name=name, - artifact_location=artifact_location, - tags=[tag.to_proto() for tag in tags] if tags else [], - ) - response = self.service.call_endpoint(get_lib().TrackingServiceCreateExperiment, request) - return response.experiment_id - - def delete_experiment(self, experiment_id): - request = DeleteExperiment(experiment_id=str(experiment_id)) - self.service.call_endpoint(get_lib().TrackingServiceDeleteExperiment, request) - - def restore_experiment(self, experiment_id): - request = RestoreExperiment(experiment_id=str(experiment_id)) - self.service.call_endpoint(get_lib().TrackingServiceRestoreExperiment, request) - - def rename_experiment(self, experiment_id, new_name): - request = UpdateExperiment(experiment_id=str(experiment_id), new_name=new_name) - self.service.call_endpoint(get_lib().TrackingServiceUpdateExperiment, request) - - def get_run(self, run_id): - request = GetRun(run_uuid=run_id, run_id=run_id) - response = self.service.call_endpoint(get_lib().TrackingServiceGetRun, request) - return Run.from_proto(response.run) - - def create_run(self, experiment_id, user_id, start_time, tags, run_name): - request = CreateRun( - experiment_id=str(experiment_id), - user_id=user_id, - start_time=start_time, - tags=[tag.to_proto() for tag in tags] if tags else [], - run_name=run_name, - ) - response = self.service.call_endpoint(get_lib().TrackingServiceCreateRun, request) - return Run.from_proto(response.run) - - def delete_run(self, run_id): - request = DeleteRun(run_id=run_id) - self.service.call_endpoint(get_lib().TrackingServiceDeleteRun, request) - - def restore_run(self, 
run_id): - request = RestoreRun(run_id=run_id) - self.service.call_endpoint(get_lib().TrackingServiceRestoreRun, request) - - def update_run(self, run_id, run_status, end_time, run_name): - request = UpdateRun( - run_uuid=run_id, - run_id=run_id, - status=run_status, - end_time=end_time, - run_name=run_name, - ) - response = self.service.call_endpoint(get_lib().TrackingServiceUpdateRun, request) - return RunInfo.from_proto(response.run_info) - - def _search_runs( - self, experiment_ids, filter_string, run_view_type, max_results, order_by, page_token - ): - request = SearchRuns( - experiment_ids=[str(experiment_id) for experiment_id in experiment_ids], - filter=filter_string, - run_view_type=ViewType.to_proto(run_view_type), - max_results=max_results, - order_by=order_by, - page_token=page_token, - ) - response = self.service.call_endpoint(get_lib().TrackingServiceSearchRuns, request) - runs = [Run.from_proto(proto_run) for proto_run in response.runs] - return runs, (response.next_page_token or None) - - def log_batch(self, run_id, metrics, params, tags): - request = LogBatch( - run_id=run_id, - metrics=[metric.to_proto() for metric in metrics], - params=[param.to_proto() for param in params], - tags=[tag.to_proto() for tag in tags], - ) - self.service.call_endpoint(get_lib().TrackingServiceLogBatch, request) - - def log_metric(self, run_id, metric): - request = LogMetric( - run_id=run_id, - key=metric.key, - value=metric.value, - timestamp=metric.timestamp, - step=metric.step, - ) - self.service.call_endpoint(get_lib().TrackingServiceLogMetric, request) - - -def TrackingStore(cls): - return type(cls.__name__, (_TrackingStore, cls), {}) - - -def _get_sqlalchemy_store(store_uri, artifact_uri): - from mlflow.store.tracking import DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH - from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore - - if is_go_enabled(): - SqlAlchemyStore = TrackingStore(SqlAlchemyStore) - - if artifact_uri is None: - artifact_uri = 
DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH - - return SqlAlchemyStore(store_uri, artifact_uri) +import json +import logging + +from mlflow.entities import ( + Experiment, + Run, + RunInfo, + ViewType, +) +from mlflow.exceptions import MlflowException +from mlflow.protos import databricks_pb2 +from mlflow.protos.service_pb2 import ( + CreateExperiment, + CreateRun, + DeleteExperiment, + DeleteRun, + DeleteTag, + GetExperiment, + GetExperimentByName, + GetRun, + LogBatch, + LogMetric, + RestoreExperiment, + RestoreRun, + SearchRuns, + SetTag, + UpdateExperiment, + UpdateRun, +) +from mlflow.utils.uri import resolve_uri_if_local + +from mlflow_go import is_go_enabled +from mlflow_go.lib import get_lib +from mlflow_go.store._service_proxy import _ServiceProxy + +_logger = logging.getLogger(__name__) + + +class _TrackingStore: + def __init__(self, *args, **kwargs): + store_uri = args[0] if len(args) > 0 else kwargs.get("db_uri", kwargs.get("root_directory")) + default_artifact_root = ( + args[1] + if len(args) > 1 + else kwargs.get("default_artifact_root", kwargs.get("artifact_root_uri")) + ) + config = json.dumps( + { + "default_artifact_root": resolve_uri_if_local(default_artifact_root), + "tracking_store_uri": store_uri, + "log_level": logging.getLevelName(_logger.getEffectiveLevel()), + } + ).encode("utf-8") + self.service = _ServiceProxy(get_lib().CreateTrackingService(config, len(config))) + super().__init__(store_uri, default_artifact_root) + + def __del__(self): + if hasattr(self, "service"): + get_lib().DestroyTrackingService(self.service.id) + + def get_experiment(self, experiment_id): + request = GetExperiment(experiment_id=str(experiment_id)) + response = self.service.call_endpoint(get_lib().TrackingServiceGetExperiment, request) + return Experiment.from_proto(response.experiment) + + def get_experiment_by_name(self, experiment_name): + request = GetExperimentByName(experiment_name=experiment_name) + try: + response = self.service.call_endpoint( + 
get_lib().TrackingServiceGetExperimentByName, request + ) + return Experiment.from_proto(response.experiment) + except MlflowException as e: + if e.error_code == databricks_pb2.ErrorCode.Name( + databricks_pb2.RESOURCE_DOES_NOT_EXIST + ): + return None + raise + + def create_experiment(self, name, artifact_location=None, tags=None): + request = CreateExperiment( + name=name, + artifact_location=artifact_location, + tags=[tag.to_proto() for tag in tags] if tags else [], + ) + response = self.service.call_endpoint(get_lib().TrackingServiceCreateExperiment, request) + return response.experiment_id + + def delete_experiment(self, experiment_id): + request = DeleteExperiment(experiment_id=str(experiment_id)) + self.service.call_endpoint(get_lib().TrackingServiceDeleteExperiment, request) + + def restore_experiment(self, experiment_id): + request = RestoreExperiment(experiment_id=str(experiment_id)) + self.service.call_endpoint(get_lib().TrackingServiceRestoreExperiment, request) + + def rename_experiment(self, experiment_id, new_name): + request = UpdateExperiment(experiment_id=str(experiment_id), new_name=new_name) + self.service.call_endpoint(get_lib().TrackingServiceUpdateExperiment, request) + + def get_run(self, run_id): + request = GetRun(run_uuid=run_id, run_id=run_id) + response = self.service.call_endpoint(get_lib().TrackingServiceGetRun, request) + return Run.from_proto(response.run) + + def create_run(self, experiment_id, user_id, start_time, tags, run_name): + request = CreateRun( + experiment_id=str(experiment_id), + user_id=user_id, + start_time=start_time, + tags=[tag.to_proto() for tag in tags] if tags else [], + run_name=run_name, + ) + response = self.service.call_endpoint(get_lib().TrackingServiceCreateRun, request) + return Run.from_proto(response.run) + + def delete_run(self, run_id): + request = DeleteRun(run_id=run_id) + self.service.call_endpoint(get_lib().TrackingServiceDeleteRun, request) + + def restore_run(self, run_id): + request = 
RestoreRun(run_id=run_id) + self.service.call_endpoint(get_lib().TrackingServiceRestoreRun, request) + + def update_run(self, run_id, run_status, end_time, run_name): + request = UpdateRun( + run_uuid=run_id, + run_id=run_id, + status=run_status, + end_time=end_time, + run_name=run_name, + ) + response = self.service.call_endpoint(get_lib().TrackingServiceUpdateRun, request) + return RunInfo.from_proto(response.run_info) + + def _search_runs( + self, experiment_ids, filter_string, run_view_type, max_results, order_by, page_token + ): + request = SearchRuns( + experiment_ids=[str(experiment_id) for experiment_id in experiment_ids], + filter=filter_string, + run_view_type=ViewType.to_proto(run_view_type), + max_results=max_results, + order_by=order_by, + page_token=page_token, + ) + response = self.service.call_endpoint(get_lib().TrackingServiceSearchRuns, request) + runs = [Run.from_proto(proto_run) for proto_run in response.runs] + return runs, (response.next_page_token or None) + + def log_batch(self, run_id, metrics, params, tags): + request = LogBatch( + run_id=run_id, + metrics=[metric.to_proto() for metric in metrics], + params=[param.to_proto() for param in params], + tags=[tag.to_proto() for tag in tags], + ) + self.service.call_endpoint(get_lib().TrackingServiceLogBatch, request) + + def log_metric(self, run_id, metric): + request = LogMetric( + run_id=run_id, + key=metric.key, + value=metric.value, + timestamp=metric.timestamp, + step=metric.step, + ) + self.service.call_endpoint(get_lib().TrackingServiceLogMetric, request) + + def set_tag(self, run_id, tag): + request = SetTag(run_id=run_id, key=tag.key, value=tag.value) + self.service.call_endpoint(get_lib().TrackingServiceSetTag, request) + + def delete_tag(self, run_id, key): + request = DeleteTag(run_id=run_id, key=key) + self.service.call_endpoint(get_lib().TrackingServiceDeleteTag, request) + +def TrackingStore(cls): + return type(cls.__name__, (_TrackingStore, cls), {}) + + +def 
_get_sqlalchemy_store(store_uri, artifact_uri): + from mlflow.store.tracking import DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH + from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore + + if is_go_enabled(): + SqlAlchemyStore = TrackingStore(SqlAlchemyStore) + + if artifact_uri is None: + artifact_uri = DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH + + return SqlAlchemyStore(store_uri, artifact_uri) diff --git a/pkg/artifacts/service/service.go b/pkg/artifacts/service/service.go index 0f21d38..e4b6971 100644 --- a/pkg/artifacts/service/service.go +++ b/pkg/artifacts/service/service.go @@ -1,17 +1,17 @@ -package service - -import ( - "context" - - "github.com/mlflow/mlflow-go/pkg/config" -) - -type ArtifactsService struct { - config *config.Config -} - -func NewArtifactsService(_ context.Context, config *config.Config) (*ArtifactsService, error) { - return &ArtifactsService{ - config: config, - }, nil -} +package service + +import ( + "context" + + "github.com/mlflow/mlflow-go/pkg/config" +) + +type ArtifactsService struct { + config *config.Config +} + +func NewArtifactsService(_ context.Context, config *config.Config) (*ArtifactsService, error) { + return &ArtifactsService{ + config: config, + }, nil +} diff --git a/pkg/cmd/server/main.go b/pkg/cmd/server/main.go index 5b71e49..de96c60 100644 --- a/pkg/cmd/server/main.go +++ b/pkg/cmd/server/main.go @@ -1,21 +1,21 @@ -package main - -import ( - "os" - - "github.com/sirupsen/logrus" - - "github.com/mlflow/mlflow-go/pkg/config" - "github.com/mlflow/mlflow-go/pkg/server" -) - -func main() { - cfg, err := config.NewConfigFromString(os.Getenv("MLFLOW_GO_CONFIG")) - if err != nil { - logrus.Fatal("Failed to read config from MLFLOW_GO_CONFIG environment variable: ", err) - } - - if err := server.LaunchWithSignalHandler(cfg); err != nil { - logrus.Fatal("Failed to launch server: ", err) - } -} +package main + +import ( + "os" + + "github.com/sirupsen/logrus" + + "github.com/mlflow/mlflow-go/pkg/config" + 
"github.com/mlflow/mlflow-go/pkg/server" +) + +func main() { + cfg, err := config.NewConfigFromString(os.Getenv("MLFLOW_GO_CONFIG")) + if err != nil { + logrus.Fatal("Failed to read config from MLFLOW_GO_CONFIG environment variable: ", err) + } + + if err := server.LaunchWithSignalHandler(cfg); err != nil { + logrus.Fatal("Failed to launch server: ", err) + } +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 2116518..59fe948 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -1,106 +1,106 @@ -package config - -import ( - "encoding/json" - "errors" - "fmt" - "time" -) - -type Duration struct { - time.Duration -} - -var ErrDuration = errors.New("invalid duration") - -func (d *Duration) UnmarshalJSON(b []byte) error { - var v interface{} - if err := json.Unmarshal(b, &v); err != nil { - return fmt.Errorf("failed to unmarshal duration: %w", err) - } - - switch value := v.(type) { - case float64: - d.Duration = time.Duration(value) - - return nil - case string: - var err error - - d.Duration, err = time.ParseDuration(value) - if err != nil { - return fmt.Errorf("failed to parse duration \"%s\": %w", value, err) - } - - return nil - default: - return ErrDuration - } -} - -type Config struct { - Address string `json:"address"` - DefaultArtifactRoot string `json:"default_artifact_root"` - LogLevel string `json:"log_level"` - ModelRegistryStoreURI string `json:"model_registry_store_uri"` - PythonAddress string `json:"python_address"` - PythonCommand []string `json:"python_command"` - PythonEnv []string `json:"python_env"` - ShutdownTimeout Duration `json:"shutdown_timeout"` - StaticFolder string `json:"static_folder"` - TrackingStoreURI string `json:"tracking_store_uri"` - Version string `json:"version"` -} - -func NewConfigFromBytes(cfgBytes []byte) (*Config, error) { - if len(cfgBytes) == 0 { - cfgBytes = []byte("{}") - } - - var cfg Config - if err := json.Unmarshal(cfgBytes, &cfg); err != nil { - return nil, fmt.Errorf("failed to parse JSON 
config: %w", err) - } - - cfg.applyDefaults() - - return &cfg, nil -} - -func NewConfigFromString(s string) (*Config, error) { - return NewConfigFromBytes([]byte(s)) -} - -func (c *Config) applyDefaults() { - if c.Address == "" { - c.Address = "localhost:5000" - } - - if c.DefaultArtifactRoot == "" { - c.DefaultArtifactRoot = "mlflow-artifacts:/" - } - - if c.LogLevel == "" { - c.LogLevel = "INFO" - } - - if c.ShutdownTimeout.Duration == 0 { - c.ShutdownTimeout.Duration = time.Minute - } - - if c.TrackingStoreURI == "" { - if c.ModelRegistryStoreURI != "" { - c.TrackingStoreURI = c.ModelRegistryStoreURI - } else { - c.TrackingStoreURI = "sqlite:///mlflow.db" - } - } - - if c.ModelRegistryStoreURI == "" { - c.ModelRegistryStoreURI = c.TrackingStoreURI - } - - if c.Version == "" { - c.Version = "dev" - } -} +package config + +import ( + "encoding/json" + "errors" + "fmt" + "time" +) + +type Duration struct { + time.Duration +} + +var ErrDuration = errors.New("invalid duration") + +func (d *Duration) UnmarshalJSON(b []byte) error { + var v interface{} + if err := json.Unmarshal(b, &v); err != nil { + return fmt.Errorf("failed to unmarshal duration: %w", err) + } + + switch value := v.(type) { + case float64: + d.Duration = time.Duration(value) + + return nil + case string: + var err error + + d.Duration, err = time.ParseDuration(value) + if err != nil { + return fmt.Errorf("failed to parse duration \"%s\": %w", value, err) + } + + return nil + default: + return ErrDuration + } +} + +type Config struct { + Address string `json:"address"` + DefaultArtifactRoot string `json:"default_artifact_root"` + LogLevel string `json:"log_level"` + ModelRegistryStoreURI string `json:"model_registry_store_uri"` + PythonAddress string `json:"python_address"` + PythonCommand []string `json:"python_command"` + PythonEnv []string `json:"python_env"` + ShutdownTimeout Duration `json:"shutdown_timeout"` + StaticFolder string `json:"static_folder"` + TrackingStoreURI string 
`json:"tracking_store_uri"` + Version string `json:"version"` +} + +func NewConfigFromBytes(cfgBytes []byte) (*Config, error) { + if len(cfgBytes) == 0 { + cfgBytes = []byte("{}") + } + + var cfg Config + if err := json.Unmarshal(cfgBytes, &cfg); err != nil { + return nil, fmt.Errorf("failed to parse JSON config: %w", err) + } + + cfg.applyDefaults() + + return &cfg, nil +} + +func NewConfigFromString(s string) (*Config, error) { + return NewConfigFromBytes([]byte(s)) +} + +func (c *Config) applyDefaults() { + if c.Address == "" { + c.Address = "localhost:5000" + } + + if c.DefaultArtifactRoot == "" { + c.DefaultArtifactRoot = "mlflow-artifacts:/" + } + + if c.LogLevel == "" { + c.LogLevel = "INFO" + } + + if c.ShutdownTimeout.Duration == 0 { + c.ShutdownTimeout.Duration = time.Minute + } + + if c.TrackingStoreURI == "" { + if c.ModelRegistryStoreURI != "" { + c.TrackingStoreURI = c.ModelRegistryStoreURI + } else { + c.TrackingStoreURI = "sqlite:///mlflow.db" + } + } + + if c.ModelRegistryStoreURI == "" { + c.ModelRegistryStoreURI = c.TrackingStoreURI + } + + if c.Version == "" { + c.Version = "dev" + } +} diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 765305a..1a98b86 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -1,53 +1,53 @@ -package config_test - -import ( - "encoding/json" - "fmt" - "testing" - "time" - - "github.com/stretchr/testify/require" - - "github.com/mlflow/mlflow-go/pkg/config" -) - -type validSample struct { - input string - duration config.Duration -} - -func TestValidDuration(t *testing.T) { - t.Parallel() - - samples := []validSample{ - {input: "1000", duration: config.Duration{Duration: 1000 * time.Nanosecond}}, - {input: `"1s"`, duration: config.Duration{Duration: 1 * time.Second}}, - {input: `"2h45m"`, duration: config.Duration{Duration: 2*time.Hour + 45*time.Minute}}, - } - - for _, sample := range samples { - currentSample := sample - t.Run(currentSample.input, func(t *testing.T) { - 
t.Parallel() - - jsonConfig := fmt.Sprintf(`{ "shutdown_timeout": %s }`, currentSample.input) - - var cfg config.Config - - err := json.Unmarshal([]byte(jsonConfig), &cfg) - require.NoError(t, err) - - require.Equal(t, currentSample.duration, cfg.ShutdownTimeout) - }) - } -} - -func TestInvalidDuration(t *testing.T) { - t.Parallel() - - var cfg config.Config - - if err := json.Unmarshal([]byte(`{ "shutdown_timeout": "two seconds" }`), &cfg); err == nil { - t.Error("expected error") - } -} +package config_test + +import ( + "encoding/json" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/mlflow/mlflow-go/pkg/config" +) + +type validSample struct { + input string + duration config.Duration +} + +func TestValidDuration(t *testing.T) { + t.Parallel() + + samples := []validSample{ + {input: "1000", duration: config.Duration{Duration: 1000 * time.Nanosecond}}, + {input: `"1s"`, duration: config.Duration{Duration: 1 * time.Second}}, + {input: `"2h45m"`, duration: config.Duration{Duration: 2*time.Hour + 45*time.Minute}}, + } + + for _, sample := range samples { + currentSample := sample + t.Run(currentSample.input, func(t *testing.T) { + t.Parallel() + + jsonConfig := fmt.Sprintf(`{ "shutdown_timeout": %s }`, currentSample.input) + + var cfg config.Config + + err := json.Unmarshal([]byte(jsonConfig), &cfg) + require.NoError(t, err) + + require.Equal(t, currentSample.duration, cfg.ShutdownTimeout) + }) + } +} + +func TestInvalidDuration(t *testing.T) { + t.Parallel() + + var cfg config.Config + + if err := json.Unmarshal([]byte(`{ "shutdown_timeout": "two seconds" }`), &cfg); err == nil { + t.Error("expected error") + } +} diff --git a/pkg/contract/error.go b/pkg/contract/error.go index 04b3f2b..8c95ba0 100644 --- a/pkg/contract/error.go +++ b/pkg/contract/error.go @@ -1,82 +1,82 @@ -package contract - -import ( - "encoding/json" - "fmt" - - "github.com/mlflow/mlflow-go/pkg/protos" -) - -type ErrorCode protos.ErrorCode - -func (e 
ErrorCode) String() string { - return protos.ErrorCode(e).String() -} - -// Custom json marshalling for ErrorCode. -func (e ErrorCode) MarshalJSON() ([]byte, error) { - //nolint:wrapcheck - return json.Marshal(e.String()) -} - -type Error struct { - Code ErrorCode `json:"error_code"` - Message string `json:"message"` - Inner error `json:"-"` -} - -func NewError(code protos.ErrorCode, message string) *Error { - return NewErrorWith(code, message, nil) -} - -func NewErrorWith(code protos.ErrorCode, message string, err error) *Error { - return &Error{ - Code: ErrorCode(code), - Message: message, - Inner: err, - } -} - -func (e *Error) Error() string { - msg := fmt.Sprintf("[%s] %s", e.Code.String(), e.Message) - if e.Inner != nil { - return fmt.Sprintf("%s: %s", msg, e.Inner) - } - - return msg -} - -func (e *Error) Unwrap() error { - return e.Inner -} - -//nolint:cyclop -func (e *Error) StatusCode() int { - //nolint:exhaustive,mnd - switch protos.ErrorCode(e.Code) { - case protos.ErrorCode_BAD_REQUEST, protos.ErrorCode_INVALID_PARAMETER_VALUE, protos.ErrorCode_RESOURCE_ALREADY_EXISTS: - return 400 - case protos.ErrorCode_CUSTOMER_UNAUTHORIZED, protos.ErrorCode_UNAUTHENTICATED: - return 401 - case protos.ErrorCode_PERMISSION_DENIED: - return 403 - case protos.ErrorCode_ENDPOINT_NOT_FOUND, protos.ErrorCode_NOT_FOUND, protos.ErrorCode_RESOURCE_DOES_NOT_EXIST: - return 404 - case protos.ErrorCode_ABORTED, protos.ErrorCode_ALREADY_EXISTS, protos.ErrorCode_RESOURCE_CONFLICT: - return 409 - case protos.ErrorCode_RESOURCE_EXHAUSTED, protos.ErrorCode_RESOURCE_LIMIT_EXCEEDED: - return 429 - case protos.ErrorCode_CANCELLED: - return 499 - case protos.ErrorCode_DATA_LOSS, protos.ErrorCode_INTERNAL_ERROR, protos.ErrorCode_INVALID_STATE: - return 500 - case protos.ErrorCode_NOT_IMPLEMENTED: - return 501 - case protos.ErrorCode_TEMPORARILY_UNAVAILABLE: - return 503 - case protos.ErrorCode_DEADLINE_EXCEEDED: - return 504 - default: - return 500 - } -} +package contract + +import ( + 
"encoding/json" + "fmt" + + "github.com/mlflow/mlflow-go/pkg/protos" +) + +type ErrorCode protos.ErrorCode + +func (e ErrorCode) String() string { + return protos.ErrorCode(e).String() +} + +// Custom json marshalling for ErrorCode. +func (e ErrorCode) MarshalJSON() ([]byte, error) { + //nolint:wrapcheck + return json.Marshal(e.String()) +} + +type Error struct { + Code ErrorCode `json:"error_code"` + Message string `json:"message"` + Inner error `json:"-"` +} + +func NewError(code protos.ErrorCode, message string) *Error { + return NewErrorWith(code, message, nil) +} + +func NewErrorWith(code protos.ErrorCode, message string, err error) *Error { + return &Error{ + Code: ErrorCode(code), + Message: message, + Inner: err, + } +} + +func (e *Error) Error() string { + msg := fmt.Sprintf("[%s] %s", e.Code.String(), e.Message) + if e.Inner != nil { + return fmt.Sprintf("%s: %s", msg, e.Inner) + } + + return msg +} + +func (e *Error) Unwrap() error { + return e.Inner +} + +//nolint:cyclop +func (e *Error) StatusCode() int { + //nolint:exhaustive,mnd + switch protos.ErrorCode(e.Code) { + case protos.ErrorCode_BAD_REQUEST, protos.ErrorCode_INVALID_PARAMETER_VALUE, protos.ErrorCode_RESOURCE_ALREADY_EXISTS: + return 400 + case protos.ErrorCode_CUSTOMER_UNAUTHORIZED, protos.ErrorCode_UNAUTHENTICATED: + return 401 + case protos.ErrorCode_PERMISSION_DENIED: + return 403 + case protos.ErrorCode_ENDPOINT_NOT_FOUND, protos.ErrorCode_NOT_FOUND, protos.ErrorCode_RESOURCE_DOES_NOT_EXIST: + return 404 + case protos.ErrorCode_ABORTED, protos.ErrorCode_ALREADY_EXISTS, protos.ErrorCode_RESOURCE_CONFLICT: + return 409 + case protos.ErrorCode_RESOURCE_EXHAUSTED, protos.ErrorCode_RESOURCE_LIMIT_EXCEEDED: + return 429 + case protos.ErrorCode_CANCELLED: + return 499 + case protos.ErrorCode_DATA_LOSS, protos.ErrorCode_INTERNAL_ERROR, protos.ErrorCode_INVALID_STATE: + return 500 + case protos.ErrorCode_NOT_IMPLEMENTED: + return 501 + case protos.ErrorCode_TEMPORARILY_UNAVAILABLE: + return 503 + 
case protos.ErrorCode_DEADLINE_EXCEEDED: + return 504 + default: + return 500 + } +} diff --git a/pkg/contract/http_request_parser.go b/pkg/contract/http_request_parser.go index d17201a..f2fca20 100644 --- a/pkg/contract/http_request_parser.go +++ b/pkg/contract/http_request_parser.go @@ -1,8 +1,8 @@ -package contract - -import "github.com/gofiber/fiber/v2" - -type HTTPRequestParser interface { - ParseBody(ctx *fiber.Ctx, out interface{}) *Error - ParseQuery(ctx *fiber.Ctx, out interface{}) *Error -} +package contract + +import "github.com/gofiber/fiber/v2" + +type HTTPRequestParser interface { + ParseBody(ctx *fiber.Ctx, out interface{}) *Error + ParseQuery(ctx *fiber.Ctx, out interface{}) *Error +} diff --git a/pkg/contract/service/tracking.g.go b/pkg/contract/service/tracking.g.go index 3854199..8704537 100644 --- a/pkg/contract/service/tracking.g.go +++ b/pkg/contract/service/tracking.g.go @@ -20,6 +20,8 @@ type TrackingService interface { DeleteRun(ctx context.Context, input *protos.DeleteRun) (*protos.DeleteRun_Response, *contract.Error) RestoreRun(ctx context.Context, input *protos.RestoreRun) (*protos.RestoreRun_Response, *contract.Error) LogMetric(ctx context.Context, input *protos.LogMetric) (*protos.LogMetric_Response, *contract.Error) + SetTag(ctx context.Context, input *protos.SetTag) (*protos.SetTag_Response, *contract.Error) + DeleteTag(ctx context.Context, input *protos.DeleteTag) (*protos.DeleteTag_Response, *contract.Error) GetRun(ctx context.Context, input *protos.GetRun) (*protos.GetRun_Response, *contract.Error) SearchRuns(ctx context.Context, input *protos.SearchRuns) (*protos.SearchRuns_Response, *contract.Error) LogBatch(ctx context.Context, input *protos.LogBatch) (*protos.LogBatch_Response, *contract.Error) diff --git a/pkg/entities/dataset.go b/pkg/entities/dataset.go index c0fe7e2..157c696 100644 --- a/pkg/entities/dataset.go +++ b/pkg/entities/dataset.go @@ -1,35 +1,35 @@ -package entities - -import ( - 
"github.com/mlflow/mlflow-go/pkg/protos" -) - -type Dataset struct { - Name string - Digest string - SourceType string - Source string - Schema string - Profile string -} - -func (d *Dataset) ToProto() *protos.Dataset { - var schema *string - if d.Schema != "" { - schema = &d.Schema - } - - var profile *string - if d.Profile != "" { - profile = &d.Profile - } - - return &protos.Dataset{ - Name: &d.Name, - Digest: &d.Digest, - SourceType: &d.SourceType, - Source: &d.Source, - Schema: schema, - Profile: profile, - } -} +package entities + +import ( + "github.com/mlflow/mlflow-go/pkg/protos" +) + +type Dataset struct { + Name string + Digest string + SourceType string + Source string + Schema string + Profile string +} + +func (d *Dataset) ToProto() *protos.Dataset { + var schema *string + if d.Schema != "" { + schema = &d.Schema + } + + var profile *string + if d.Profile != "" { + profile = &d.Profile + } + + return &protos.Dataset{ + Name: &d.Name, + Digest: &d.Digest, + SourceType: &d.SourceType, + Source: &d.Source, + Schema: schema, + Profile: profile, + } +} diff --git a/pkg/entities/dataset_input.go b/pkg/entities/dataset_input.go index 9284671..7f48dc4 100644 --- a/pkg/entities/dataset_input.go +++ b/pkg/entities/dataset_input.go @@ -1,20 +1,20 @@ -package entities - -import "github.com/mlflow/mlflow-go/pkg/protos" - -type DatasetInput struct { - Tags []*InputTag - Dataset *Dataset -} - -func (ds DatasetInput) ToProto() *protos.DatasetInput { - tags := make([]*protos.InputTag, 0, len(ds.Tags)) - for _, tag := range ds.Tags { - tags = append(tags, tag.ToProto()) - } - - return &protos.DatasetInput{ - Tags: tags, - Dataset: ds.Dataset.ToProto(), - } -} +package entities + +import "github.com/mlflow/mlflow-go/pkg/protos" + +type DatasetInput struct { + Tags []*InputTag + Dataset *Dataset +} + +func (ds DatasetInput) ToProto() *protos.DatasetInput { + tags := make([]*protos.InputTag, 0, len(ds.Tags)) + for _, tag := range ds.Tags { + tags = append(tags, 
tag.ToProto()) + } + + return &protos.DatasetInput{ + Tags: tags, + Dataset: ds.Dataset.ToProto(), + } +} diff --git a/pkg/entities/experiment.go b/pkg/entities/experiment.go index 0bb58b4..3c081bc 100644 --- a/pkg/entities/experiment.go +++ b/pkg/entities/experiment.go @@ -1,36 +1,36 @@ -package entities - -import ( - "github.com/mlflow/mlflow-go/pkg/protos" - "github.com/mlflow/mlflow-go/pkg/utils" -) - -type Experiment struct { - Name string - ExperimentID string - ArtifactLocation string - LifecycleStage string - LastUpdateTime int64 - CreationTime int64 - Tags []*ExperimentTag -} - -func (e Experiment) ToProto() *protos.Experiment { - tags := make([]*protos.ExperimentTag, len(e.Tags)) - - for i, tag := range e.Tags { - tags[i] = tag.ToProto() - } - - experiment := protos.Experiment{ - ExperimentId: &e.ExperimentID, - Name: &e.Name, - ArtifactLocation: &e.ArtifactLocation, - LifecycleStage: utils.PtrTo(e.LifecycleStage), - CreationTime: &e.CreationTime, - LastUpdateTime: &e.LastUpdateTime, - Tags: tags, - } - - return &experiment -} +package entities + +import ( + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/utils" +) + +type Experiment struct { + Name string + ExperimentID string + ArtifactLocation string + LifecycleStage string + LastUpdateTime int64 + CreationTime int64 + Tags []*ExperimentTag +} + +func (e Experiment) ToProto() *protos.Experiment { + tags := make([]*protos.ExperimentTag, len(e.Tags)) + + for i, tag := range e.Tags { + tags[i] = tag.ToProto() + } + + experiment := protos.Experiment{ + ExperimentId: &e.ExperimentID, + Name: &e.Name, + ArtifactLocation: &e.ArtifactLocation, + LifecycleStage: utils.PtrTo(e.LifecycleStage), + CreationTime: &e.CreationTime, + LastUpdateTime: &e.LastUpdateTime, + Tags: tags, + } + + return &experiment +} diff --git a/pkg/entities/experiment_tag.go b/pkg/entities/experiment_tag.go index ced86c1..43ef33c 100644 --- a/pkg/entities/experiment_tag.go +++ b/pkg/entities/experiment_tag.go 
@@ -1,22 +1,22 @@ -package entities - -import "github.com/mlflow/mlflow-go/pkg/protos" - -type ExperimentTag struct { - Key string - Value string -} - -func (et *ExperimentTag) ToProto() *protos.ExperimentTag { - return &protos.ExperimentTag{ - Key: &et.Key, - Value: &et.Value, - } -} - -func NewExperimentTagFromProto(proto *protos.ExperimentTag) *ExperimentTag { - return &ExperimentTag{ - Key: proto.GetKey(), - Value: proto.GetValue(), - } -} +package entities + +import "github.com/mlflow/mlflow-go/pkg/protos" + +type ExperimentTag struct { + Key string + Value string +} + +func (et *ExperimentTag) ToProto() *protos.ExperimentTag { + return &protos.ExperimentTag{ + Key: &et.Key, + Value: &et.Value, + } +} + +func NewExperimentTagFromProto(proto *protos.ExperimentTag) *ExperimentTag { + return &ExperimentTag{ + Key: proto.GetKey(), + Value: proto.GetValue(), + } +} diff --git a/pkg/entities/input_tag.go b/pkg/entities/input_tag.go index 67dda7c..06f01e2 100644 --- a/pkg/entities/input_tag.go +++ b/pkg/entities/input_tag.go @@ -1,15 +1,15 @@ -package entities - -import "github.com/mlflow/mlflow-go/pkg/protos" - -type InputTag struct { - Key string - Value string -} - -func (i InputTag) ToProto() *protos.InputTag { - return &protos.InputTag{ - Key: &i.Key, - Value: &i.Value, - } -} +package entities + +import "github.com/mlflow/mlflow-go/pkg/protos" + +type InputTag struct { + Key string + Value string +} + +func (i InputTag) ToProto() *protos.InputTag { + return &protos.InputTag{ + Key: &i.Key, + Value: &i.Value, + } +} diff --git a/pkg/entities/metric.go b/pkg/entities/metric.go index 9a502f8..75c0600 100644 --- a/pkg/entities/metric.go +++ b/pkg/entities/metric.go @@ -1,52 +1,52 @@ -package entities - -import ( - "math" - - "github.com/mlflow/mlflow-go/pkg/protos" - "github.com/mlflow/mlflow-go/pkg/utils" -) - -type Metric struct { - Key string - Value float64 - Timestamp int64 - Step int64 - IsNaN bool -} - -func (m Metric) ToProto() *protos.Metric { - metric := 
protos.Metric{ - Key: &m.Key, - Value: &m.Value, - Timestamp: &m.Timestamp, - Step: &m.Step, - } - - switch { - case m.IsNaN: - metric.Value = utils.PtrTo(math.NaN()) - default: - metric.Value = &m.Value - } - - return &metric -} - -func MetricFromProto(proto *protos.Metric) *Metric { - return &Metric{ - Key: proto.GetKey(), - Value: proto.GetValue(), - Timestamp: proto.GetTimestamp(), - Step: proto.GetStep(), - } -} - -func MetricFromLogMetricProtoInput(input *protos.LogMetric) *Metric { - return &Metric{ - Key: input.GetKey(), - Value: input.GetValue(), - Timestamp: input.GetTimestamp(), - Step: input.GetStep(), - } -} +package entities + +import ( + "math" + + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/utils" +) + +type Metric struct { + Key string + Value float64 + Timestamp int64 + Step int64 + IsNaN bool +} + +func (m Metric) ToProto() *protos.Metric { + metric := protos.Metric{ + Key: &m.Key, + Value: &m.Value, + Timestamp: &m.Timestamp, + Step: &m.Step, + } + + switch { + case m.IsNaN: + metric.Value = utils.PtrTo(math.NaN()) + default: + metric.Value = &m.Value + } + + return &metric +} + +func MetricFromProto(proto *protos.Metric) *Metric { + return &Metric{ + Key: proto.GetKey(), + Value: proto.GetValue(), + Timestamp: proto.GetTimestamp(), + Step: proto.GetStep(), + } +} + +func MetricFromLogMetricProtoInput(input *protos.LogMetric) *Metric { + return &Metric{ + Key: input.GetKey(), + Value: input.GetValue(), + Timestamp: input.GetTimestamp(), + Step: input.GetStep(), + } +} diff --git a/pkg/entities/param.go b/pkg/entities/param.go index 2d3ec01..9e1693d 100644 --- a/pkg/entities/param.go +++ b/pkg/entities/param.go @@ -1,22 +1,22 @@ -package entities - -import "github.com/mlflow/mlflow-go/pkg/protos" - -type Param struct { - Key string - Value string -} - -func (p Param) ToProto() *protos.Param { - return &protos.Param{ - Key: &p.Key, - Value: &p.Value, - } -} - -func ParamFromProto(proto *protos.Param) *Param { - 
return &Param{ - Key: *proto.Key, - Value: *proto.Value, - } -} +package entities + +import "github.com/mlflow/mlflow-go/pkg/protos" + +type Param struct { + Key string + Value string +} + +func (p Param) ToProto() *protos.Param { + return &protos.Param{ + Key: &p.Key, + Value: &p.Value, + } +} + +func ParamFromProto(proto *protos.Param) *Param { + return &Param{ + Key: *proto.Key, + Value: *proto.Value, + } +} diff --git a/pkg/entities/run.go b/pkg/entities/run.go index c0f231f..3698200 100644 --- a/pkg/entities/run.go +++ b/pkg/entities/run.go @@ -1,75 +1,75 @@ -package entities - -import ( - "strings" - - "github.com/mlflow/mlflow-go/pkg/protos" - "github.com/mlflow/mlflow-go/pkg/utils" -) - -func RunStatusToProto(status string) *protos.RunStatus { - if status == "" { - return nil - } - - if protoStatus, ok := protos.RunStatus_value[strings.ToUpper(status)]; ok { - return (*protos.RunStatus)(&protoStatus) - } - - return nil -} - -type Run struct { - Info *RunInfo - Data *RunData - Inputs *RunInputs -} - -func (r Run) ToProto() *protos.Run { - metrics := make([]*protos.Metric, 0, len(r.Data.Metrics)) - for _, metric := range r.Data.Metrics { - metrics = append(metrics, metric.ToProto()) - } - - params := make([]*protos.Param, 0, len(r.Data.Params)) - for _, param := range r.Data.Params { - params = append(params, param.ToProto()) - } - - tags := make([]*protos.RunTag, 0, len(r.Data.Tags)) - for _, tag := range r.Data.Tags { - tags = append(tags, tag.ToProto()) - } - - data := &protos.RunData{ - Metrics: metrics, - Params: params, - Tags: tags, - } - - datasetInputs := make([]*protos.DatasetInput, 0, len(r.Inputs.DatasetInputs)) - for _, input := range r.Inputs.DatasetInputs { - datasetInputs = append(datasetInputs, input.ToProto()) - } - - inputs := &protos.RunInputs{ - DatasetInputs: datasetInputs, - } - - return &protos.Run{ - Info: &protos.RunInfo{ - RunId: &r.Info.RunID, - RunUuid: &r.Info.RunID, - RunName: &r.Info.RunName, - ExperimentId: 
utils.ConvertInt32PointerToStringPointer(&r.Info.ExperimentID), - UserId: &r.Info.UserID, - Status: RunStatusToProto(r.Info.Status), - StartTime: &r.Info.StartTime, - EndTime: r.Info.EndTime, - ArtifactUri: &r.Info.ArtifactURI, - LifecycleStage: utils.PtrTo(r.Info.LifecycleStage), - }, - Data: data, - Inputs: inputs, - } -} +package entities + +import ( + "strings" + + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/utils" +) + +func RunStatusToProto(status string) *protos.RunStatus { + if status == "" { + return nil + } + + if protoStatus, ok := protos.RunStatus_value[strings.ToUpper(status)]; ok { + return (*protos.RunStatus)(&protoStatus) + } + + return nil +} + +type Run struct { + Info *RunInfo + Data *RunData + Inputs *RunInputs +} + +func (r Run) ToProto() *protos.Run { + metrics := make([]*protos.Metric, 0, len(r.Data.Metrics)) + for _, metric := range r.Data.Metrics { + metrics = append(metrics, metric.ToProto()) + } + + params := make([]*protos.Param, 0, len(r.Data.Params)) + for _, param := range r.Data.Params { + params = append(params, param.ToProto()) + } + + tags := make([]*protos.RunTag, 0, len(r.Data.Tags)) + for _, tag := range r.Data.Tags { + tags = append(tags, tag.ToProto()) + } + + data := &protos.RunData{ + Metrics: metrics, + Params: params, + Tags: tags, + } + + datasetInputs := make([]*protos.DatasetInput, 0, len(r.Inputs.DatasetInputs)) + for _, input := range r.Inputs.DatasetInputs { + datasetInputs = append(datasetInputs, input.ToProto()) + } + + inputs := &protos.RunInputs{ + DatasetInputs: datasetInputs, + } + + return &protos.Run{ + Info: &protos.RunInfo{ + RunId: &r.Info.RunID, + RunUuid: &r.Info.RunID, + RunName: &r.Info.RunName, + ExperimentId: utils.ConvertInt32PointerToStringPointer(&r.Info.ExperimentID), + UserId: &r.Info.UserID, + Status: RunStatusToProto(r.Info.Status), + StartTime: &r.Info.StartTime, + EndTime: r.Info.EndTime, + ArtifactUri: &r.Info.ArtifactURI, + LifecycleStage: 
utils.PtrTo(r.Info.LifecycleStage), + }, + Data: data, + Inputs: inputs, + } +} diff --git a/pkg/entities/run_data.go b/pkg/entities/run_data.go index 92679f2..38eeb21 100644 --- a/pkg/entities/run_data.go +++ b/pkg/entities/run_data.go @@ -1,7 +1,7 @@ -package entities - -type RunData struct { - Tags []*RunTag - Params []*Param - Metrics []*Metric -} +package entities + +type RunData struct { + Tags []*RunTag + Params []*Param + Metrics []*Metric +} diff --git a/pkg/entities/run_info.go b/pkg/entities/run_info.go index 51debc1..15eea71 100644 --- a/pkg/entities/run_info.go +++ b/pkg/entities/run_info.go @@ -1,34 +1,34 @@ -package entities - -import ( - "github.com/mlflow/mlflow-go/pkg/protos" - "github.com/mlflow/mlflow-go/pkg/utils" -) - -type RunInfo struct { - RunID string - RunUUID string - RunName string - ExperimentID int32 - UserID string - Status string - StartTime int64 - EndTime *int64 - ArtifactURI string - LifecycleStage string -} - -func (ri RunInfo) ToProto() *protos.RunInfo { - return &protos.RunInfo{ - RunId: &ri.RunID, - RunUuid: &ri.RunID, - RunName: &ri.RunName, - ExperimentId: utils.ConvertInt32PointerToStringPointer(&ri.ExperimentID), - UserId: &ri.UserID, - Status: RunStatusToProto(ri.Status), - StartTime: &ri.StartTime, - EndTime: ri.EndTime, - ArtifactUri: &ri.ArtifactURI, - LifecycleStage: utils.PtrTo(ri.LifecycleStage), - } -} +package entities + +import ( + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/utils" +) + +type RunInfo struct { + RunID string + RunUUID string + RunName string + ExperimentID int32 + UserID string + Status string + StartTime int64 + EndTime *int64 + ArtifactURI string + LifecycleStage string +} + +func (ri RunInfo) ToProto() *protos.RunInfo { + return &protos.RunInfo{ + RunId: &ri.RunID, + RunUuid: &ri.RunID, + RunName: &ri.RunName, + ExperimentId: utils.ConvertInt32PointerToStringPointer(&ri.ExperimentID), + UserId: &ri.UserID, + Status: RunStatusToProto(ri.Status), + StartTime: 
&ri.StartTime, + EndTime: ri.EndTime, + ArtifactUri: &ri.ArtifactURI, + LifecycleStage: utils.PtrTo(ri.LifecycleStage), + } +} diff --git a/pkg/entities/run_inputs.go b/pkg/entities/run_inputs.go index 20bc444..c66ec9a 100644 --- a/pkg/entities/run_inputs.go +++ b/pkg/entities/run_inputs.go @@ -1,5 +1,5 @@ -package entities - -type RunInputs struct { - DatasetInputs []*DatasetInput -} +package entities + +type RunInputs struct { + DatasetInputs []*DatasetInput +} diff --git a/pkg/entities/run_tag.go b/pkg/entities/run_tag.go index 1b31f2e..2321507 100644 --- a/pkg/entities/run_tag.go +++ b/pkg/entities/run_tag.go @@ -1,22 +1,22 @@ -package entities - -import "github.com/mlflow/mlflow-go/pkg/protos" - -type RunTag struct { - Key string - Value string -} - -func (t RunTag) ToProto() *protos.RunTag { - return &protos.RunTag{ - Key: &t.Key, - Value: &t.Value, - } -} - -func NewTagFromProto(proto *protos.RunTag) *RunTag { - return &RunTag{ - Key: proto.GetKey(), - Value: proto.GetValue(), - } -} +package entities + +import "github.com/mlflow/mlflow-go/pkg/protos" + +type RunTag struct { + Key string + Value string +} + +func (t RunTag) ToProto() *protos.RunTag { + return &protos.RunTag{ + Key: &t.Key, + Value: &t.Value, + } +} + +func NewTagFromProto(proto *protos.RunTag) *RunTag { + return &RunTag{ + Key: proto.GetKey(), + Value: proto.GetValue(), + } +} diff --git a/pkg/lib/artifacts.go b/pkg/lib/artifacts.go index 4aaca17..198dc5e 100644 --- a/pkg/lib/artifacts.go +++ b/pkg/lib/artifacts.go @@ -1,22 +1,22 @@ -package main - -import "C" - -import ( - "unsafe" - - "github.com/mlflow/mlflow-go/pkg/artifacts/service" -) - -var artifactsServices = newInstanceMap[*service.ArtifactsService]() - -//export CreateArtifactsService -func CreateArtifactsService(configData unsafe.Pointer, configSize C.int) int64 { - //nolint:nlreturn - return artifactsServices.Create(service.NewArtifactsService, C.GoBytes(configData, configSize)) -} - -//export DestroyArtifactsService -func 
DestroyArtifactsService(id int64) { - artifactsServices.Destroy(id) -} +package main + +import "C" + +import ( + "unsafe" + + "github.com/mlflow/mlflow-go/pkg/artifacts/service" +) + +var artifactsServices = newInstanceMap[*service.ArtifactsService]() + +//export CreateArtifactsService +func CreateArtifactsService(configData unsafe.Pointer, configSize C.int) int64 { + //nolint:nlreturn + return artifactsServices.Create(service.NewArtifactsService, C.GoBytes(configData, configSize)) +} + +//export DestroyArtifactsService +func DestroyArtifactsService(id int64) { + artifactsServices.Destroy(id) +} diff --git a/pkg/lib/ffi.go b/pkg/lib/ffi.go index 335a828..60751d2 100644 --- a/pkg/lib/ffi.go +++ b/pkg/lib/ffi.go @@ -1,91 +1,91 @@ -package main - -import "C" - -import ( - "context" - "encoding/json" - "unsafe" - - "google.golang.org/protobuf/proto" - - "github.com/mlflow/mlflow-go/pkg/contract" - "github.com/mlflow/mlflow-go/pkg/protos" - "github.com/mlflow/mlflow-go/pkg/validation" -) - -func unmarshalAndValidateProto( - data []byte, - msg proto.Message, -) *contract.Error { - if err := proto.Unmarshal(data, msg); err != nil { - return contract.NewError( - protos.ErrorCode_BAD_REQUEST, - err.Error(), - ) - } - - validate, cErr := getValidator() - if cErr != nil { - return cErr - } - - if err := validate.Struct(msg); err != nil { - return validation.NewErrorFromValidationError(err) - } - - return nil -} - -func marshalProto(msg proto.Message) ([]byte, *contract.Error) { - res, err := proto.Marshal(msg) - if err != nil { - return nil, contract.NewError( - protos.ErrorCode_INTERNAL_ERROR, - err.Error(), - ) - } - - return res, nil -} - -func makePointerFromBytes(data []byte, size *C.int) unsafe.Pointer { - *size = C.int(len(data)) - - return C.CBytes(data) //nolint:nlreturn -} - -func makePointerFromError(err *contract.Error, size *C.int) unsafe.Pointer { - data, _ := json.Marshal(err) //nolint:errchkjson - - return makePointerFromBytes(data, size) -} - -// 
invokeServiceMethod is a helper function that invokes a service method and handles -// marshalling/unmarshalling of request/response data through the FFI boundary. -func invokeServiceMethod[I, O proto.Message]( - serviceMethod func(context.Context, I) (O, *contract.Error), - request I, - requestData unsafe.Pointer, - requestSize C.int, - responseSize *C.int, -) unsafe.Pointer { - requestBytes := C.GoBytes(requestData, requestSize) //nolint:nlreturn - - err := unmarshalAndValidateProto(requestBytes, request) - if err != nil { - return makePointerFromError(err, responseSize) - } - - response, err := serviceMethod(context.Background(), request) - if err != nil { - return makePointerFromError(err, responseSize) - } - - responseBytes, err := marshalProto(response) - if err != nil { - return makePointerFromError(err, responseSize) - } - - return makePointerFromBytes(responseBytes, responseSize) -} +package main + +import "C" + +import ( + "context" + "encoding/json" + "unsafe" + + "google.golang.org/protobuf/proto" + + "github.com/mlflow/mlflow-go/pkg/contract" + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/validation" +) + +func unmarshalAndValidateProto( + data []byte, + msg proto.Message, +) *contract.Error { + if err := proto.Unmarshal(data, msg); err != nil { + return contract.NewError( + protos.ErrorCode_BAD_REQUEST, + err.Error(), + ) + } + + validate, cErr := getValidator() + if cErr != nil { + return cErr + } + + if err := validate.Struct(msg); err != nil { + return validation.NewErrorFromValidationError(err) + } + + return nil +} + +func marshalProto(msg proto.Message) ([]byte, *contract.Error) { + res, err := proto.Marshal(msg) + if err != nil { + return nil, contract.NewError( + protos.ErrorCode_INTERNAL_ERROR, + err.Error(), + ) + } + + return res, nil +} + +func makePointerFromBytes(data []byte, size *C.int) unsafe.Pointer { + *size = C.int(len(data)) + + return C.CBytes(data) //nolint:nlreturn +} + +func 
makePointerFromError(err *contract.Error, size *C.int) unsafe.Pointer { + data, _ := json.Marshal(err) //nolint:errchkjson + + return makePointerFromBytes(data, size) +} + +// invokeServiceMethod is a helper function that invokes a service method and handles +// marshalling/unmarshalling of request/response data through the FFI boundary. +func invokeServiceMethod[I, O proto.Message]( + serviceMethod func(context.Context, I) (O, *contract.Error), + request I, + requestData unsafe.Pointer, + requestSize C.int, + responseSize *C.int, +) unsafe.Pointer { + requestBytes := C.GoBytes(requestData, requestSize) //nolint:nlreturn + + err := unmarshalAndValidateProto(requestBytes, request) + if err != nil { + return makePointerFromError(err, responseSize) + } + + response, err := serviceMethod(context.Background(), request) + if err != nil { + return makePointerFromError(err, responseSize) + } + + responseBytes, err := marshalProto(response) + if err != nil { + return makePointerFromError(err, responseSize) + } + + return makePointerFromBytes(responseBytes, responseSize) +} diff --git a/pkg/lib/instance_map.go b/pkg/lib/instance_map.go index 70c778d..dabf1a9 100644 --- a/pkg/lib/instance_map.go +++ b/pkg/lib/instance_map.go @@ -1,78 +1,78 @@ -package main - -import ( - "context" - "sync" - - "github.com/sirupsen/logrus" - - "github.com/mlflow/mlflow-go/pkg/config" - "github.com/mlflow/mlflow-go/pkg/contract" - "github.com/mlflow/mlflow-go/pkg/protos" - "github.com/mlflow/mlflow-go/pkg/utils" -) - -type instanceMap[T any] struct { - counter int64 - mutex sync.Mutex - instances map[int64]T -} - -func newInstanceMap[T any]() *instanceMap[T] { - return &instanceMap[T]{ - instances: make(map[int64]T), - } -} - -//nolint:ireturn -func (s *instanceMap[T]) Get(id int64) (T, *contract.Error) { - instance, ok := s.instances[id] - if !ok { - return instance, contract.NewError( - protos.ErrorCode_RESOURCE_DOES_NOT_EXIST, - "Instance not found", - ) - } - - return instance, nil -} - 
-func (s *instanceMap[T]) Create( - creator func(ctx context.Context, cfg *config.Config) (T, error), - configBytes []byte, -) int64 { - cfg, err := config.NewConfigFromBytes(configBytes) - if err != nil { - logrus.Error("Failed to read config: ", err) - - return -1 - } - - logger := utils.NewLoggerFromConfig(cfg) - - logger.Debugf("Loaded config: %#v", cfg) - - instance, err := creator( - utils.NewContextWithLogger(context.Background(), logger), - cfg, - ) - if err != nil { - logger.Error("Failed to create instance: ", err) - - return -1 - } - - s.mutex.Lock() - defer s.mutex.Unlock() - - s.counter++ - s.instances[s.counter] = instance - - return s.counter -} - -func (s *instanceMap[T]) Destroy(id int64) { - s.mutex.Lock() - defer s.mutex.Unlock() - delete(s.instances, id) -} +package main + +import ( + "context" + "sync" + + "github.com/sirupsen/logrus" + + "github.com/mlflow/mlflow-go/pkg/config" + "github.com/mlflow/mlflow-go/pkg/contract" + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/utils" +) + +type instanceMap[T any] struct { + counter int64 + mutex sync.Mutex + instances map[int64]T +} + +func newInstanceMap[T any]() *instanceMap[T] { + return &instanceMap[T]{ + instances: make(map[int64]T), + } +} + +//nolint:ireturn +func (s *instanceMap[T]) Get(id int64) (T, *contract.Error) { + instance, ok := s.instances[id] + if !ok { + return instance, contract.NewError( + protos.ErrorCode_RESOURCE_DOES_NOT_EXIST, + "Instance not found", + ) + } + + return instance, nil +} + +func (s *instanceMap[T]) Create( + creator func(ctx context.Context, cfg *config.Config) (T, error), + configBytes []byte, +) int64 { + cfg, err := config.NewConfigFromBytes(configBytes) + if err != nil { + logrus.Error("Failed to read config: ", err) + + return -1 + } + + logger := utils.NewLoggerFromConfig(cfg) + + logger.Debugf("Loaded config: %#v", cfg) + + instance, err := creator( + utils.NewContextWithLogger(context.Background(), logger), + cfg, + ) + if 
err != nil { + logger.Error("Failed to create instance: ", err) + + return -1 + } + + s.mutex.Lock() + defer s.mutex.Unlock() + + s.counter++ + s.instances[s.counter] = instance + + return s.counter +} + +func (s *instanceMap[T]) Destroy(id int64) { + s.mutex.Lock() + defer s.mutex.Unlock() + delete(s.instances, id) +} diff --git a/pkg/lib/main.go b/pkg/lib/main.go index 38dd16d..c8a27b4 100644 --- a/pkg/lib/main.go +++ b/pkg/lib/main.go @@ -1,3 +1,3 @@ -package main - -func main() {} +package main + +func main() {} diff --git a/pkg/lib/model_registry.go b/pkg/lib/model_registry.go index f7efb6c..13133f6 100644 --- a/pkg/lib/model_registry.go +++ b/pkg/lib/model_registry.go @@ -1,22 +1,22 @@ -package main - -import "C" - -import ( - "unsafe" - - "github.com/mlflow/mlflow-go/pkg/model_registry/service" -) - -var modelRegistryServices = newInstanceMap[*service.ModelRegistryService]() - -//export CreateModelRegistryService -func CreateModelRegistryService(configData unsafe.Pointer, configSize C.int) int64 { - //nolint:nlreturn - return modelRegistryServices.Create(service.NewModelRegistryService, C.GoBytes(configData, configSize)) -} - -//export DestroyModelRegistryService -func DestroyModelRegistryService(id int64) { - modelRegistryServices.Destroy(id) -} +package main + +import "C" + +import ( + "unsafe" + + "github.com/mlflow/mlflow-go/pkg/model_registry/service" +) + +var modelRegistryServices = newInstanceMap[*service.ModelRegistryService]() + +//export CreateModelRegistryService +func CreateModelRegistryService(configData unsafe.Pointer, configSize C.int) int64 { + //nolint:nlreturn + return modelRegistryServices.Create(service.NewModelRegistryService, C.GoBytes(configData, configSize)) +} + +//export DestroyModelRegistryService +func DestroyModelRegistryService(id int64) { + modelRegistryServices.Destroy(id) +} diff --git a/pkg/lib/server.go b/pkg/lib/server.go index a2d02d2..912041a 100644 --- a/pkg/lib/server.go +++ b/pkg/lib/server.go @@ -1,83 +1,83 @@ 
-package main - -import "C" - -import ( - "context" - "unsafe" - - "github.com/sirupsen/logrus" - - "github.com/mlflow/mlflow-go/pkg/config" - "github.com/mlflow/mlflow-go/pkg/server" -) - -type serverInstance struct { - cancel context.CancelFunc - errChan <-chan error -} - -var serverInstances = newInstanceMap[serverInstance]() - -//export LaunchServer -func LaunchServer(configData unsafe.Pointer, configSize C.int) int64 { - cfg, err := config.NewConfigFromBytes(C.GoBytes(configData, configSize)) //nolint:nlreturn - if err != nil { - logrus.Error("Failed to read config: ", err) - - return -1 - } - - if err := server.LaunchWithSignalHandler(cfg); err != nil { - logrus.Error("Failed to launch server: ", err) - - return -1 - } - - return 0 -} - -//export LaunchServerAsync -func LaunchServerAsync(configData unsafe.Pointer, configSize C.int) int64 { - serverID := serverInstances.Create( - func(ctx context.Context, cfg *config.Config) (serverInstance, error) { - errChan := make(chan error, 1) - - ctx, cancel := context.WithCancel(ctx) - - go func() { - errChan <- server.Launch(ctx, cfg) - }() - - return serverInstance{ - cancel: cancel, - errChan: errChan, - }, nil - }, - C.GoBytes(configData, configSize), //nolint:nlreturn - ) - - return serverID -} - -//export StopServer -func StopServer(serverID int64) int64 { - instance, cErr := serverInstances.Get(serverID) - if cErr != nil { - logrus.Error("Failed to get instance: ", cErr) - - return -1 - } - defer serverInstances.Destroy(serverID) - - instance.cancel() - - err := <-instance.errChan - if err != nil { - logrus.Error("Server has exited with error: ", err) - - return -1 - } - - return 0 -} +package main + +import "C" + +import ( + "context" + "unsafe" + + "github.com/sirupsen/logrus" + + "github.com/mlflow/mlflow-go/pkg/config" + "github.com/mlflow/mlflow-go/pkg/server" +) + +type serverInstance struct { + cancel context.CancelFunc + errChan <-chan error +} + +var serverInstances = newInstanceMap[serverInstance]() + 
+//export LaunchServer +func LaunchServer(configData unsafe.Pointer, configSize C.int) int64 { + cfg, err := config.NewConfigFromBytes(C.GoBytes(configData, configSize)) //nolint:nlreturn + if err != nil { + logrus.Error("Failed to read config: ", err) + + return -1 + } + + if err := server.LaunchWithSignalHandler(cfg); err != nil { + logrus.Error("Failed to launch server: ", err) + + return -1 + } + + return 0 +} + +//export LaunchServerAsync +func LaunchServerAsync(configData unsafe.Pointer, configSize C.int) int64 { + serverID := serverInstances.Create( + func(ctx context.Context, cfg *config.Config) (serverInstance, error) { + errChan := make(chan error, 1) + + ctx, cancel := context.WithCancel(ctx) + + go func() { + errChan <- server.Launch(ctx, cfg) + }() + + return serverInstance{ + cancel: cancel, + errChan: errChan, + }, nil + }, + C.GoBytes(configData, configSize), //nolint:nlreturn + ) + + return serverID +} + +//export StopServer +func StopServer(serverID int64) int64 { + instance, cErr := serverInstances.Get(serverID) + if cErr != nil { + logrus.Error("Failed to get instance: ", cErr) + + return -1 + } + defer serverInstances.Destroy(serverID) + + instance.cancel() + + err := <-instance.errChan + if err != nil { + logrus.Error("Server has exited with error: ", err) + + return -1 + } + + return 0 +} diff --git a/pkg/lib/tracking.g.go b/pkg/lib/tracking.g.go index 6a3e751..7bf304b 100644 --- a/pkg/lib/tracking.g.go +++ b/pkg/lib/tracking.g.go @@ -95,6 +95,22 @@ func TrackingServiceLogMetric(serviceID int64, requestData unsafe.Pointer, reque } return invokeServiceMethod(service.LogMetric, new(protos.LogMetric), requestData, requestSize, responseSize) } +//export TrackingServiceSetTag +func TrackingServiceSetTag(serviceID int64, requestData unsafe.Pointer, requestSize C.int, responseSize *C.int) unsafe.Pointer { + service, err := trackingServices.Get(serviceID) + if err != nil { + return makePointerFromError(err, responseSize) + } + return 
invokeServiceMethod(service.SetTag, new(protos.SetTag), requestData, requestSize, responseSize) +} +//export TrackingServiceDeleteTag +func TrackingServiceDeleteTag(serviceID int64, requestData unsafe.Pointer, requestSize C.int, responseSize *C.int) unsafe.Pointer { + service, err := trackingServices.Get(serviceID) + if err != nil { + return makePointerFromError(err, responseSize) + } + return invokeServiceMethod(service.DeleteTag, new(protos.DeleteTag), requestData, requestSize, responseSize) +} //export TrackingServiceGetRun func TrackingServiceGetRun(serviceID int64, requestData unsafe.Pointer, requestSize C.int, responseSize *C.int) unsafe.Pointer { service, err := trackingServices.Get(serviceID) diff --git a/pkg/lib/tracking.go b/pkg/lib/tracking.go index 8cd6257..800b5ac 100644 --- a/pkg/lib/tracking.go +++ b/pkg/lib/tracking.go @@ -1,22 +1,22 @@ -package main - -import "C" - -import ( - "unsafe" - - "github.com/mlflow/mlflow-go/pkg/tracking/service" -) - -var trackingServices = newInstanceMap[*service.TrackingService]() - -//export CreateTrackingService -func CreateTrackingService(configData unsafe.Pointer, configSize C.int) int64 { - //nolint:nlreturn - return trackingServices.Create(service.NewTrackingService, C.GoBytes(configData, configSize)) -} - -//export DestroyTrackingService -func DestroyTrackingService(id int64) { - trackingServices.Destroy(id) -} +package main + +import "C" + +import ( + "unsafe" + + "github.com/mlflow/mlflow-go/pkg/tracking/service" +) + +var trackingServices = newInstanceMap[*service.TrackingService]() + +//export CreateTrackingService +func CreateTrackingService(configData unsafe.Pointer, configSize C.int) int64 { + //nolint:nlreturn + return trackingServices.Create(service.NewTrackingService, C.GoBytes(configData, configSize)) +} + +//export DestroyTrackingService +func DestroyTrackingService(id int64) { + trackingServices.Destroy(id) +} diff --git a/pkg/lib/validation.go b/pkg/lib/validation.go index f46526c..b7bf64f 100644 
--- a/pkg/lib/validation.go +++ b/pkg/lib/validation.go @@ -1,23 +1,23 @@ -package main - -import ( - "sync" - - "github.com/go-playground/validator/v10" - - "github.com/mlflow/mlflow-go/pkg/contract" - "github.com/mlflow/mlflow-go/pkg/protos" - "github.com/mlflow/mlflow-go/pkg/validation" -) - -var getValidator = sync.OnceValues(func() (*validator.Validate, *contract.Error) { - validate, err := validation.NewValidator() - if err != nil { - return nil, contract.NewError( - protos.ErrorCode_INTERNAL_ERROR, - err.Error(), - ) - } - - return validate, nil -}) +package main + +import ( + "sync" + + "github.com/go-playground/validator/v10" + + "github.com/mlflow/mlflow-go/pkg/contract" + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/validation" +) + +var getValidator = sync.OnceValues(func() (*validator.Validate, *contract.Error) { + validate, err := validation.NewValidator() + if err != nil { + return nil, contract.NewError( + protos.ErrorCode_INTERNAL_ERROR, + err.Error(), + ) + } + + return validate, nil +}) diff --git a/pkg/model_registry/service/model_versions.go b/pkg/model_registry/service/model_versions.go index f4b3d3f..1c18d8f 100644 --- a/pkg/model_registry/service/model_versions.go +++ b/pkg/model_registry/service/model_versions.go @@ -1,21 +1,21 @@ -package service - -import ( - "context" - - "github.com/mlflow/mlflow-go/pkg/contract" - "github.com/mlflow/mlflow-go/pkg/protos" -) - -func (m *ModelRegistryService) GetLatestVersions( - ctx context.Context, input *protos.GetLatestVersions, -) (*protos.GetLatestVersions_Response, *contract.Error) { - latestVersions, err := m.store.GetLatestVersions(ctx, input.GetName(), input.GetStages()) - if err != nil { - return nil, err - } - - return &protos.GetLatestVersions_Response{ - ModelVersions: latestVersions, - }, nil -} +package service + +import ( + "context" + + "github.com/mlflow/mlflow-go/pkg/contract" + "github.com/mlflow/mlflow-go/pkg/protos" +) + +func (m 
*ModelRegistryService) GetLatestVersions( + ctx context.Context, input *protos.GetLatestVersions, +) (*protos.GetLatestVersions_Response, *contract.Error) { + latestVersions, err := m.store.GetLatestVersions(ctx, input.GetName(), input.GetStages()) + if err != nil { + return nil, err + } + + return &protos.GetLatestVersions_Response{ + ModelVersions: latestVersions, + }, nil +} diff --git a/pkg/model_registry/service/service.go b/pkg/model_registry/service/service.go index 8be5f38..ae59d3e 100644 --- a/pkg/model_registry/service/service.go +++ b/pkg/model_registry/service/service.go @@ -1,27 +1,27 @@ -package service - -import ( - "context" - "fmt" - - "github.com/mlflow/mlflow-go/pkg/config" - "github.com/mlflow/mlflow-go/pkg/model_registry/store" - "github.com/mlflow/mlflow-go/pkg/model_registry/store/sql" -) - -type ModelRegistryService struct { - store store.ModelRegistryStore - config *config.Config -} - -func NewModelRegistryService(ctx context.Context, config *config.Config) (*ModelRegistryService, error) { - store, err := sql.NewModelRegistrySQLStore(ctx, config) - if err != nil { - return nil, fmt.Errorf("failed to create new sql store: %w", err) - } - - return &ModelRegistryService{ - store: store, - config: config, - }, nil -} +package service + +import ( + "context" + "fmt" + + "github.com/mlflow/mlflow-go/pkg/config" + "github.com/mlflow/mlflow-go/pkg/model_registry/store" + "github.com/mlflow/mlflow-go/pkg/model_registry/store/sql" +) + +type ModelRegistryService struct { + store store.ModelRegistryStore + config *config.Config +} + +func NewModelRegistryService(ctx context.Context, config *config.Config) (*ModelRegistryService, error) { + store, err := sql.NewModelRegistrySQLStore(ctx, config) + if err != nil { + return nil, fmt.Errorf("failed to create new sql store: %w", err) + } + + return &ModelRegistryService{ + store: store, + config: config, + }, nil +} diff --git a/pkg/model_registry/store/sql/model_versions.go 
b/pkg/model_registry/store/sql/model_versions.go index cbe3b45..ccf5be0 100644 --- a/pkg/model_registry/store/sql/model_versions.go +++ b/pkg/model_registry/store/sql/model_versions.go @@ -1,94 +1,94 @@ -package sql - -import ( - "context" - "errors" - "fmt" - "strings" - - "gorm.io/gorm" - - "github.com/mlflow/mlflow-go/pkg/contract" - "github.com/mlflow/mlflow-go/pkg/model_registry/store/sql/models" - "github.com/mlflow/mlflow-go/pkg/protos" -) - -// Validate whether there is a registered model with the given name. -func assertModelExists(db *gorm.DB, name string) *contract.Error { - if err := db.Select("name").Where("name = ?", name).First(&models.RegisteredModel{}).Error; err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return contract.NewError( - protos.ErrorCode_RESOURCE_DOES_NOT_EXIST, - fmt.Sprintf("registered model with name=%q not found", name), - ) - } - - return contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - fmt.Sprintf("failed to query registered model with name=%q", name), - err, - ) - } - - return nil -} - -func (m *ModelRegistrySQLStore) GetLatestVersions( - ctx context.Context, name string, stages []string, -) ([]*protos.ModelVersion, *contract.Error) { - if err := assertModelExists(m.db.WithContext(ctx), name); err != nil { - return nil, err - } - - var modelVersions []*models.ModelVersion - - subQuery := m.db. - WithContext(ctx). - Model(&models.ModelVersion{}). - Select("name, MAX(version) AS max_version"). - Where("name = ?", name). - Where("current_stage <> ?", models.StageDeletedInternal). - Group("name, current_stage") - - if len(stages) > 0 { - for idx, stage := range stages { - stages[idx] = strings.ToLower(stage) - if canonicalStage, ok := models.CanonicalMapping[stages[idx]]; ok { - stages[idx] = canonicalStage - - continue - } - - return nil, contract.NewError( - protos.ErrorCode_BAD_REQUEST, - fmt.Sprintf( - "Invalid Model Version stage: %s. 
Value must be one of %s.", - stage, - models.AllModelVersionStages(), - ), - ) - } - - subQuery = subQuery.Where("current_stage IN (?)", stages) - } - - err := m.db. - WithContext(ctx). - Model(&models.ModelVersion{}). - Joins("JOIN (?) AS sub ON model_versions.name = sub.name AND model_versions.version = sub.max_version", subQuery). - Find(&modelVersions).Error - if err != nil { - return nil, contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - fmt.Sprintf("failed to query latest model version for %q", name), - err, - ) - } - - results := make([]*protos.ModelVersion, 0, len(modelVersions)) - for _, modelVersion := range modelVersions { - results = append(results, modelVersion.ToProto()) - } - - return results, nil -} +package sql + +import ( + "context" + "errors" + "fmt" + "strings" + + "gorm.io/gorm" + + "github.com/mlflow/mlflow-go/pkg/contract" + "github.com/mlflow/mlflow-go/pkg/model_registry/store/sql/models" + "github.com/mlflow/mlflow-go/pkg/protos" +) + +// Validate whether there is a registered model with the given name. +func assertModelExists(db *gorm.DB, name string) *contract.Error { + if err := db.Select("name").Where("name = ?", name).First(&models.RegisteredModel{}).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return contract.NewError( + protos.ErrorCode_RESOURCE_DOES_NOT_EXIST, + fmt.Sprintf("registered model with name=%q not found", name), + ) + } + + return contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf("failed to query registered model with name=%q", name), + err, + ) + } + + return nil +} + +func (m *ModelRegistrySQLStore) GetLatestVersions( + ctx context.Context, name string, stages []string, +) ([]*protos.ModelVersion, *contract.Error) { + if err := assertModelExists(m.db.WithContext(ctx), name); err != nil { + return nil, err + } + + var modelVersions []*models.ModelVersion + + subQuery := m.db. + WithContext(ctx). + Model(&models.ModelVersion{}). 
+ Select("name, MAX(version) AS max_version"). + Where("name = ?", name). + Where("current_stage <> ?", models.StageDeletedInternal). + Group("name, current_stage") + + if len(stages) > 0 { + for idx, stage := range stages { + stages[idx] = strings.ToLower(stage) + if canonicalStage, ok := models.CanonicalMapping[stages[idx]]; ok { + stages[idx] = canonicalStage + + continue + } + + return nil, contract.NewError( + protos.ErrorCode_BAD_REQUEST, + fmt.Sprintf( + "Invalid Model Version stage: %s. Value must be one of %s.", + stage, + models.AllModelVersionStages(), + ), + ) + } + + subQuery = subQuery.Where("current_stage IN (?)", stages) + } + + err := m.db. + WithContext(ctx). + Model(&models.ModelVersion{}). + Joins("JOIN (?) AS sub ON model_versions.name = sub.name AND model_versions.version = sub.max_version", subQuery). + Find(&modelVersions).Error + if err != nil { + return nil, contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf("failed to query latest model version for %q", name), + err, + ) + } + + results := make([]*protos.ModelVersion, 0, len(modelVersions)) + for _, modelVersion := range modelVersions { + results = append(results, modelVersion.ToProto()) + } + + return results, nil +} diff --git a/pkg/model_registry/store/sql/models/model_version_stage.go b/pkg/model_registry/store/sql/models/model_version_stage.go index 4020f36..56cc070 100644 --- a/pkg/model_registry/store/sql/models/model_version_stage.go +++ b/pkg/model_registry/store/sql/models/model_version_stage.go @@ -1,33 +1,33 @@ -package models - -import "strings" - -type ModelVersionStage string - -func (s ModelVersionStage) String() string { - return string(s) -} - -const ( - ModelVersionStageNone = "None" - ModelVersionStageStaging = "Staging" - ModelVersionStageProduction = "Production" - ModelVersionStageArchived = "Archived" -) - -var CanonicalMapping = map[string]string{ - strings.ToLower(ModelVersionStageNone): ModelVersionStageNone, - 
strings.ToLower(ModelVersionStageStaging): ModelVersionStageStaging, - strings.ToLower(ModelVersionStageProduction): ModelVersionStageProduction, - strings.ToLower(ModelVersionStageArchived): ModelVersionStageArchived, -} - -func AllModelVersionStages() string { - pairs := make([]string, 0, len(CanonicalMapping)) - - for _, v := range CanonicalMapping { - pairs = append(pairs, v) - } - - return strings.Join(pairs, ",") -} +package models + +import "strings" + +type ModelVersionStage string + +func (s ModelVersionStage) String() string { + return string(s) +} + +const ( + ModelVersionStageNone = "None" + ModelVersionStageStaging = "Staging" + ModelVersionStageProduction = "Production" + ModelVersionStageArchived = "Archived" +) + +var CanonicalMapping = map[string]string{ + strings.ToLower(ModelVersionStageNone): ModelVersionStageNone, + strings.ToLower(ModelVersionStageStaging): ModelVersionStageStaging, + strings.ToLower(ModelVersionStageProduction): ModelVersionStageProduction, + strings.ToLower(ModelVersionStageArchived): ModelVersionStageArchived, +} + +func AllModelVersionStages() string { + pairs := make([]string, 0, len(CanonicalMapping)) + + for _, v := range CanonicalMapping { + pairs = append(pairs, v) + } + + return strings.Join(pairs, ",") +} diff --git a/pkg/model_registry/store/sql/models/model_version_tags.go b/pkg/model_registry/store/sql/models/model_version_tags.go index 7bde392..16c4ab0 100644 --- a/pkg/model_registry/store/sql/models/model_version_tags.go +++ b/pkg/model_registry/store/sql/models/model_version_tags.go @@ -1,11 +1,11 @@ -package models - -// ModelVersionTag mapped from table . -// -//revive:disable:exported -type ModelVersionTag struct { - Key string `db:"key" gorm:"column:key;primaryKey"` - Value string `db:"value" gorm:"column:value"` - Name string `db:"name" gorm:"column:name;primaryKey"` - Version int32 `db:"version" gorm:"column:version;primaryKey"` -} +package models + +// ModelVersionTag mapped from table . 
+// +//revive:disable:exported +type ModelVersionTag struct { + Key string `db:"key" gorm:"column:key;primaryKey"` + Value string `db:"value" gorm:"column:value"` + Name string `db:"name" gorm:"column:name;primaryKey"` + Version int32 `db:"version" gorm:"column:version;primaryKey"` +} diff --git a/pkg/model_registry/store/sql/models/model_versions.go b/pkg/model_registry/store/sql/models/model_versions.go index d2b373c..c6eb4bc 100644 --- a/pkg/model_registry/store/sql/models/model_versions.go +++ b/pkg/model_registry/store/sql/models/model_versions.go @@ -1,49 +1,49 @@ -package models - -import ( - "github.com/mlflow/mlflow-go/pkg/protos" - "github.com/mlflow/mlflow-go/pkg/utils" -) - -// ModelVersion mapped from table . -// -//revive:disable:exported -type ModelVersion struct { - Name string `db:"name" gorm:"column:name;primaryKey"` - Version int32 `db:"version" gorm:"column:version;primaryKey"` - CreationTime int64 `db:"creation_time" gorm:"column:creation_time"` - LastUpdatedTime int64 `db:"last_updated_time" gorm:"column:last_updated_time"` - Description string `db:"description" gorm:"column:description"` - UserID string `db:"user_id" gorm:"column:user_id"` - CurrentStage ModelVersionStage `db:"current_stage" gorm:"column:current_stage"` - Source string `db:"source" gorm:"column:source"` - RunID string `db:"run_id" gorm:"column:run_id"` - Status string `db:"status" gorm:"column:status"` - StatusMessage string `db:"status_message" gorm:"column:status_message"` - RunLink string `db:"run_link" gorm:"column:run_link"` - StorageLocation string `db:"storage_location" gorm:"column:storage_location"` -} - -const StageDeletedInternal = "Deleted_Internal" - -func (mv ModelVersion) ToProto() *protos.ModelVersion { - var status *protos.ModelVersionStatus - if s, ok := protos.ModelVersionStatus_value[mv.Status]; ok { - status = utils.PtrTo(protos.ModelVersionStatus(s)) - } - - return &protos.ModelVersion{ - Name: &mv.Name, - Version: 
utils.ConvertInt32PointerToStringPointer(&mv.Version), - CreationTimestamp: &mv.CreationTime, - LastUpdatedTimestamp: &mv.LastUpdatedTime, - UserId: &mv.UserID, - CurrentStage: utils.PtrTo(mv.CurrentStage.String()), - Description: &mv.Description, - Source: &mv.Source, - RunId: &mv.RunID, - Status: status, - StatusMessage: &mv.StatusMessage, - RunLink: &mv.RunLink, - } -} +package models + +import ( + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/utils" +) + +// ModelVersion mapped from table . +// +//revive:disable:exported +type ModelVersion struct { + Name string `db:"name" gorm:"column:name;primaryKey"` + Version int32 `db:"version" gorm:"column:version;primaryKey"` + CreationTime int64 `db:"creation_time" gorm:"column:creation_time"` + LastUpdatedTime int64 `db:"last_updated_time" gorm:"column:last_updated_time"` + Description string `db:"description" gorm:"column:description"` + UserID string `db:"user_id" gorm:"column:user_id"` + CurrentStage ModelVersionStage `db:"current_stage" gorm:"column:current_stage"` + Source string `db:"source" gorm:"column:source"` + RunID string `db:"run_id" gorm:"column:run_id"` + Status string `db:"status" gorm:"column:status"` + StatusMessage string `db:"status_message" gorm:"column:status_message"` + RunLink string `db:"run_link" gorm:"column:run_link"` + StorageLocation string `db:"storage_location" gorm:"column:storage_location"` +} + +const StageDeletedInternal = "Deleted_Internal" + +func (mv ModelVersion) ToProto() *protos.ModelVersion { + var status *protos.ModelVersionStatus + if s, ok := protos.ModelVersionStatus_value[mv.Status]; ok { + status = utils.PtrTo(protos.ModelVersionStatus(s)) + } + + return &protos.ModelVersion{ + Name: &mv.Name, + Version: utils.ConvertInt32PointerToStringPointer(&mv.Version), + CreationTimestamp: &mv.CreationTime, + LastUpdatedTimestamp: &mv.LastUpdatedTime, + UserId: &mv.UserID, + CurrentStage: utils.PtrTo(mv.CurrentStage.String()), + Description: 
&mv.Description, + Source: &mv.Source, + RunId: &mv.RunID, + Status: status, + StatusMessage: &mv.StatusMessage, + RunLink: &mv.RunLink, + } +} diff --git a/pkg/model_registry/store/sql/models/registered_model_aliases.go b/pkg/model_registry/store/sql/models/registered_model_aliases.go index 2cdf25a..d720770 100644 --- a/pkg/model_registry/store/sql/models/registered_model_aliases.go +++ b/pkg/model_registry/store/sql/models/registered_model_aliases.go @@ -1,8 +1,8 @@ -package models - -// RegisteredModelAlias mapped from table . -type RegisteredModelAlias struct { - Alias string `db:"alias" gorm:"column:alias;primaryKey"` - Version int32 `db:"version" gorm:"column:version;not null"` - Name string `db:"name" gorm:"column:name;primaryKey"` -} +package models + +// RegisteredModelAlias mapped from table . +type RegisteredModelAlias struct { + Alias string `db:"alias" gorm:"column:alias;primaryKey"` + Version int32 `db:"version" gorm:"column:version;not null"` + Name string `db:"name" gorm:"column:name;primaryKey"` +} diff --git a/pkg/model_registry/store/sql/models/registered_model_tags.go b/pkg/model_registry/store/sql/models/registered_model_tags.go index 6935047..99d4896 100644 --- a/pkg/model_registry/store/sql/models/registered_model_tags.go +++ b/pkg/model_registry/store/sql/models/registered_model_tags.go @@ -1,8 +1,8 @@ -package models - -// RegisteredModelTag mapped from table . -type RegisteredModelTag struct { - Key string `db:"key" gorm:"column:key;primaryKey"` - Value string `db:"value" gorm:"column:value"` - Name string `db:"name" gorm:"column:name;primaryKey"` -} +package models + +// RegisteredModelTag mapped from table . 
+type RegisteredModelTag struct { + Key string `db:"key" gorm:"column:key;primaryKey"` + Value string `db:"value" gorm:"column:value"` + Name string `db:"name" gorm:"column:name;primaryKey"` +} diff --git a/pkg/model_registry/store/sql/models/registered_models.go b/pkg/model_registry/store/sql/models/registered_models.go index 0a99a30..d13b6d9 100644 --- a/pkg/model_registry/store/sql/models/registered_models.go +++ b/pkg/model_registry/store/sql/models/registered_models.go @@ -1,9 +1,9 @@ -package models - -// RegisteredModel mapped from table . -type RegisteredModel struct { - Name string `db:"name" gorm:"column:name;primaryKey"` - CreationTime int64 `db:"creation_time" gorm:"column:creation_time"` - LastUpdatedTime int64 `db:"last_updated_time" gorm:"column:last_updated_time"` - Description string `db:"description" gorm:"column:description"` -} +package models + +// RegisteredModel mapped from table . +type RegisteredModel struct { + Name string `db:"name" gorm:"column:name;primaryKey"` + CreationTime int64 `db:"creation_time" gorm:"column:creation_time"` + LastUpdatedTime int64 `db:"last_updated_time" gorm:"column:last_updated_time"` + Description string `db:"description" gorm:"column:description"` +} diff --git a/pkg/model_registry/store/sql/store.go b/pkg/model_registry/store/sql/store.go index 1bfe8b2..a8a0f46 100644 --- a/pkg/model_registry/store/sql/store.go +++ b/pkg/model_registry/store/sql/store.go @@ -1,28 +1,28 @@ -package sql - -import ( - "context" - "fmt" - - "gorm.io/gorm" - - "github.com/mlflow/mlflow-go/pkg/config" - "github.com/mlflow/mlflow-go/pkg/sql" -) - -type ModelRegistrySQLStore struct { - config *config.Config - db *gorm.DB -} - -func NewModelRegistrySQLStore(ctx context.Context, config *config.Config) (*ModelRegistrySQLStore, error) { - database, err := sql.NewDatabase(ctx, config.ModelRegistryStoreURI) - if err != nil { - return nil, fmt.Errorf("failed to connect to database %q: %w", config.ModelRegistryStoreURI, err) - } - - return 
&ModelRegistrySQLStore{ - config: config, - db: database, - }, nil -} +package sql + +import ( + "context" + "fmt" + + "gorm.io/gorm" + + "github.com/mlflow/mlflow-go/pkg/config" + "github.com/mlflow/mlflow-go/pkg/sql" +) + +type ModelRegistrySQLStore struct { + config *config.Config + db *gorm.DB +} + +func NewModelRegistrySQLStore(ctx context.Context, config *config.Config) (*ModelRegistrySQLStore, error) { + database, err := sql.NewDatabase(ctx, config.ModelRegistryStoreURI) + if err != nil { + return nil, fmt.Errorf("failed to connect to database %q: %w", config.ModelRegistryStoreURI, err) + } + + return &ModelRegistrySQLStore{ + config: config, + db: database, + }, nil +} diff --git a/pkg/model_registry/store/store.go b/pkg/model_registry/store/store.go index d9a5c21..f14ba8c 100644 --- a/pkg/model_registry/store/store.go +++ b/pkg/model_registry/store/store.go @@ -1,12 +1,12 @@ -package store - -import ( - "context" - - "github.com/mlflow/mlflow-go/pkg/contract" - "github.com/mlflow/mlflow-go/pkg/protos" -) - -type ModelRegistryStore interface { - GetLatestVersions(ctx context.Context, name string, stages []string) ([]*protos.ModelVersion, *contract.Error) -} +package store + +import ( + "context" + + "github.com/mlflow/mlflow-go/pkg/contract" + "github.com/mlflow/mlflow-go/pkg/protos" +) + +type ModelRegistryStore interface { + GetLatestVersions(ctx context.Context, name string, stages []string) ([]*protos.ModelVersion, *contract.Error) +} diff --git a/pkg/protos/artifacts/mlflow_artifacts.pb.go b/pkg/protos/artifacts/mlflow_artifacts.pb.go index 9633522..ca00c6b 100644 --- a/pkg/protos/artifacts/mlflow_artifacts.pb.go +++ b/pkg/protos/artifacts/mlflow_artifacts.pb.go @@ -7,7 +7,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. 
// versions: // protoc-gen-go v1.33.0 -// protoc v5.26.0 +// protoc v3.21.12 // source: mlflow_artifacts.proto package artifacts diff --git a/pkg/protos/databricks.pb.go b/pkg/protos/databricks.pb.go index 112dda9..cb6a6b9 100644 --- a/pkg/protos/databricks.pb.go +++ b/pkg/protos/databricks.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.33.0 -// protoc v5.26.0 +// protoc v3.21.12 // source: databricks.proto package protos diff --git a/pkg/protos/databricks_artifacts.pb.go b/pkg/protos/databricks_artifacts.pb.go index e8ec09b..edcc57d 100644 --- a/pkg/protos/databricks_artifacts.pb.go +++ b/pkg/protos/databricks_artifacts.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.33.0 -// protoc v5.26.0 +// protoc v3.21.12 // source: databricks_artifacts.proto package protos diff --git a/pkg/protos/internal.pb.go b/pkg/protos/internal.pb.go index 2a9d430..38d9bc8 100644 --- a/pkg/protos/internal.pb.go +++ b/pkg/protos/internal.pb.go @@ -4,7 +4,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.33.0 -// protoc v5.26.0 +// protoc v3.21.12 // source: internal.proto package protos diff --git a/pkg/protos/model_registry.pb.go b/pkg/protos/model_registry.pb.go index a6aec3f..dd6eb30 100644 --- a/pkg/protos/model_registry.pb.go +++ b/pkg/protos/model_registry.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.33.0 -// protoc v5.26.0 +// protoc v3.21.12 // source: model_registry.proto package protos diff --git a/pkg/protos/scalapb/scalapb.pb.go b/pkg/protos/scalapb/scalapb.pb.go index 3b4f090..95aaa19 100644 --- a/pkg/protos/scalapb/scalapb.pb.go +++ b/pkg/protos/scalapb/scalapb.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. 
// versions: // protoc-gen-go v1.33.0 -// protoc v5.26.0 +// protoc v3.21.12 // source: scalapb/scalapb.proto package scalapb diff --git a/pkg/protos/service.pb.go b/pkg/protos/service.pb.go index 5d45861..8a388b5 100644 --- a/pkg/protos/service.pb.go +++ b/pkg/protos/service.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.33.0 -// protoc v5.26.0 +// protoc v3.21.12 // source: service.proto package protos @@ -2074,13 +2074,13 @@ type SetTag struct { unknownFields protoimpl.UnknownFields // ID of the run under which to log the tag. Must be provided. - RunId *string `protobuf:"bytes,4,opt,name=run_id,json=runId" json:"run_id,omitempty" query:"run_id"` + RunId *string `protobuf:"bytes,4,opt,name=run_id,json=runId" json:"run_id,omitempty" query:"run_id" validate:"required"` // [Deprecated, use run_id instead] ID of the run under which to log the tag. This field will // be removed in a future MLflow version. RunUuid *string `protobuf:"bytes,1,opt,name=run_uuid,json=runUuid" json:"run_uuid,omitempty" query:"run_uuid"` // Name of the tag. Maximum size depends on storage backend. // All storage backends are guaranteed to support key values up to 250 bytes in size. - Key *string `protobuf:"bytes,2,opt,name=key" json:"key,omitempty" query:"key"` + Key *string `protobuf:"bytes,2,opt,name=key" json:"key,omitempty" query:"key" validate:"required"` // String value of the tag being logged. Maximum size depends on storage backend. // All storage backends are guaranteed to support key values up to 5000 bytes in size. Value *string `protobuf:"bytes,3,opt,name=value" json:"value,omitempty" query:"value"` @@ -2152,9 +2152,9 @@ type DeleteTag struct { unknownFields protoimpl.UnknownFields // ID of the run that the tag was logged under. Must be provided. 
- RunId *string `protobuf:"bytes,1,opt,name=run_id,json=runId" json:"run_id,omitempty" query:"run_id"` + RunId *string `protobuf:"bytes,1,opt,name=run_id,json=runId" json:"run_id,omitempty" query:"run_id" validate:"required"` // Name of the tag. Maximum size is 255 bytes. Must be provided. - Key *string `protobuf:"bytes,2,opt,name=key" json:"key,omitempty" query:"key"` + Key *string `protobuf:"bytes,2,opt,name=key" json:"key,omitempty" query:"key" validate:"required"` } func (x *DeleteTag) Reset() { diff --git a/pkg/server/command/command.go b/pkg/server/command/command.go index ac218a1..515264a 100644 --- a/pkg/server/command/command.go +++ b/pkg/server/command/command.go @@ -1,42 +1,42 @@ -package command - -import ( - "context" - "fmt" - "os" - "os/exec" - "time" - - "github.com/mlflow/mlflow-go/pkg/config" - "github.com/mlflow/mlflow-go/pkg/utils" -) - -func LaunchCommand(ctx context.Context, cfg *config.Config) error { - logger := utils.GetLoggerFromContext(ctx) - - //nolint:gosec - cmd, err := newProcessGroupCommand( - ctx, - exec.CommandContext(ctx, cfg.PythonCommand[0], cfg.PythonCommand[1:]...), - ) - if err != nil { - return fmt.Errorf("failed to create process group command: %w", err) - } - - cmd.Env = append(os.Environ(), cfg.PythonEnv...) 
- cmd.Stdout = logger.Writer() - cmd.Stderr = logger.Writer() - cmd.WaitDelay = 5 * time.Second //nolint:mnd - - logger.Debugf("Launching command: %v", cmd) - - if err := cmd.Start(); err != nil { - return fmt.Errorf("failed to launch command: %w", err) - } - - if err := cmd.Wait(); err != nil { - return fmt.Errorf("command exited with error: %w", err) - } - - return nil -} +package command + +import ( + "context" + "fmt" + "os" + "os/exec" + "time" + + "github.com/mlflow/mlflow-go/pkg/config" + "github.com/mlflow/mlflow-go/pkg/utils" +) + +func LaunchCommand(ctx context.Context, cfg *config.Config) error { + logger := utils.GetLoggerFromContext(ctx) + + //nolint:gosec + cmd, err := newProcessGroupCommand( + ctx, + exec.CommandContext(ctx, cfg.PythonCommand[0], cfg.PythonCommand[1:]...), + ) + if err != nil { + return fmt.Errorf("failed to create process group command: %w", err) + } + + cmd.Env = append(os.Environ(), cfg.PythonEnv...) + cmd.Stdout = logger.Writer() + cmd.Stderr = logger.Writer() + cmd.WaitDelay = 5 * time.Second //nolint:mnd + + logger.Debugf("Launching command: %v", cmd) + + if err := cmd.Start(); err != nil { + return fmt.Errorf("failed to launch command: %w", err) + } + + if err := cmd.Wait(); err != nil { + return fmt.Errorf("command exited with error: %w", err) + } + + return nil +} diff --git a/pkg/server/command/command_posix.go b/pkg/server/command/command_posix.go index 3e71750..7982a02 100644 --- a/pkg/server/command/command_posix.go +++ b/pkg/server/command/command_posix.go @@ -1,30 +1,30 @@ -//go:build !windows - -package command - -import ( - "context" - "os/exec" - "syscall" - - "github.com/mlflow/mlflow-go/pkg/utils" -) - -func newProcessGroupCommand(ctx context.Context, cmd *exec.Cmd) (*exec.Cmd, error) { - logger := utils.GetLoggerFromContext(ctx) - - // Create the process in a new process group - cmd.SysProcAttr = &syscall.SysProcAttr{ - Setpgid: true, - Pgid: 0, - } - - // Terminate the process group - cmd.Cancel = func() error { 
- logger.Debug("Sending interrupt signal to command process group") - - return syscall.Kill(-cmd.Process.Pid, syscall.SIGINT) - } - - return cmd, nil -} +//go:build !windows + +package command + +import ( + "context" + "os/exec" + "syscall" + + "github.com/mlflow/mlflow-go/pkg/utils" +) + +func newProcessGroupCommand(ctx context.Context, cmd *exec.Cmd) (*exec.Cmd, error) { + logger := utils.GetLoggerFromContext(ctx) + + // Create the process in a new process group + cmd.SysProcAttr = &syscall.SysProcAttr{ + Setpgid: true, + Pgid: 0, + } + + // Terminate the process group + cmd.Cancel = func() error { + logger.Debug("Sending interrupt signal to command process group") + + return syscall.Kill(-cmd.Process.Pid, syscall.SIGINT) + } + + return cmd, nil +} diff --git a/pkg/server/command/command_windows.go b/pkg/server/command/command_windows.go index d1401e3..c94c20e 100644 --- a/pkg/server/command/command_windows.go +++ b/pkg/server/command/command_windows.go @@ -1,77 +1,77 @@ -package command - -import ( - "context" - "fmt" - "os/exec" - "syscall" - "unsafe" - - "golang.org/x/sys/windows" - - "github.com/mlflow/mlflow-go/pkg/utils" -) - -type processGroupCmd struct { - *exec.Cmd - job windows.Handle -} - -const PROCESS_ALL_ACCESS = 2097151 - -func newProcessGroupCommand(ctx context.Context, cmd *exec.Cmd) (*processGroupCmd, error) { - logger := utils.GetLoggerFromContext(ctx) - - // Get the job object handle - jobHandle, err := windows.CreateJobObject(nil, nil) - if err != nil { - return nil, fmt.Errorf("failed to create job object: %w", err) - } - - // Set the job object to kill processes when the job is closed - info := windows.JOBOBJECT_EXTENDED_LIMIT_INFORMATION{ - BasicLimitInformation: windows.JOBOBJECT_BASIC_LIMIT_INFORMATION{ - LimitFlags: windows.JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE, - }, - } - if _, err = windows.SetInformationJobObject( - jobHandle, - windows.JobObjectExtendedLimitInformation, - uintptr(unsafe.Pointer(&info)), - uint32(unsafe.Sizeof(info))); 
err != nil { - return nil, fmt.Errorf("failed to set job object information: %w", err) - } - - // Terminate the job object (which will terminate all processes in the job) - cmd.Cancel = func() error { - logger.Debug("Closing job object to terminate command process group") - - return windows.CloseHandle(jobHandle) - } - - // Create the process in a new process group - cmd.SysProcAttr = &syscall.SysProcAttr{CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP} - - return &processGroupCmd{Cmd: cmd, job: jobHandle}, nil -} - -func (pgc *processGroupCmd) Start() error { - // Start the command - if err := pgc.Cmd.Start(); err != nil { - return fmt.Errorf("failed to start command: %w", err) - } - - // Get the process handle - hProc, err := windows.OpenProcess(PROCESS_ALL_ACCESS, true, uint32(pgc.Process.Pid)) - if err != nil { - return fmt.Errorf("failed to open process: %w", err) - } - defer windows.CloseHandle(hProc) - - // Assign the process to the job object - if err := windows.AssignProcessToJobObject(pgc.job, hProc); err != nil { - return fmt.Errorf("failed to assign process to job object: %w", err) - } - - return nil -} +package command + +import ( + "context" + "fmt" + "os/exec" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" + + "github.com/mlflow/mlflow-go/pkg/utils" +) + +type processGroupCmd struct { + *exec.Cmd + job windows.Handle +} + +const PROCESS_ALL_ACCESS = 2097151 + +func newProcessGroupCommand(ctx context.Context, cmd *exec.Cmd) (*processGroupCmd, error) { + logger := utils.GetLoggerFromContext(ctx) + + // Get the job object handle + jobHandle, err := windows.CreateJobObject(nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to create job object: %w", err) + } + + // Set the job object to kill processes when the job is closed + info := windows.JOBOBJECT_EXTENDED_LIMIT_INFORMATION{ + BasicLimitInformation: windows.JOBOBJECT_BASIC_LIMIT_INFORMATION{ + LimitFlags: windows.JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE, + }, + } + if _, err = 
windows.SetInformationJobObject( + jobHandle, + windows.JobObjectExtendedLimitInformation, + uintptr(unsafe.Pointer(&info)), + uint32(unsafe.Sizeof(info))); err != nil { + return nil, fmt.Errorf("failed to set job object information: %w", err) + } + + // Terminate the job object (which will terminate all processes in the job) + cmd.Cancel = func() error { + logger.Debug("Closing job object to terminate command process group") + + return windows.CloseHandle(jobHandle) + } + + // Create the process in a new process group + cmd.SysProcAttr = &syscall.SysProcAttr{CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP} + + return &processGroupCmd{Cmd: cmd, job: jobHandle}, nil +} + +func (pgc *processGroupCmd) Start() error { + // Start the command + if err := pgc.Cmd.Start(); err != nil { + return fmt.Errorf("failed to start command: %w", err) + } + + // Get the process handle + hProc, err := windows.OpenProcess(PROCESS_ALL_ACCESS, true, uint32(pgc.Process.Pid)) + if err != nil { + return fmt.Errorf("failed to open process: %w", err) + } + defer windows.CloseHandle(hProc) + + // Assign the process to the job object + if err := windows.AssignProcessToJobObject(pgc.job, hProc); err != nil { + return fmt.Errorf("failed to assign process to job object: %w", err) + } + + return nil +} diff --git a/pkg/server/launch.go b/pkg/server/launch.go index 8c50f32..5f70369 100644 --- a/pkg/server/launch.go +++ b/pkg/server/launch.go @@ -1,86 +1,86 @@ -package server - -import ( - "context" - "errors" - "os" - "os/signal" - "sync" - "syscall" - - "github.com/mlflow/mlflow-go/pkg/config" - "github.com/mlflow/mlflow-go/pkg/server/command" - "github.com/mlflow/mlflow-go/pkg/utils" -) - -func Launch(ctx context.Context, cfg *config.Config) error { - if len(cfg.PythonCommand) > 0 { - return launchCommandAndServer(ctx, cfg) - } - - return launchServer(ctx, cfg) -} - -func launchCommandAndServer(ctx context.Context, cfg *config.Config) error { - var errs []error - - logger := 
utils.GetLoggerFromContext(ctx) - - cmdCtx, cmdCancel := context.WithCancel(ctx) - srvCtx, srvCancel := context.WithCancel(ctx) - - waitGroup := sync.WaitGroup{} - waitGroup.Add(1) - - go func() { - defer waitGroup.Done() - - if err := command.LaunchCommand(cmdCtx, cfg); err != nil && cmdCtx.Err() == nil { - errs = append(errs, err) - } - - logger.Debug("Python server has exited") - - srvCancel() - }() - - waitGroup.Add(1) - - go func() { - defer waitGroup.Done() - - if err := launchServer(srvCtx, cfg); err != nil && srvCtx.Err() == nil { - errs = append(errs, err) - } - - logger.Debug("Go server has exited") - - cmdCancel() - }() - - waitGroup.Wait() - - return errors.Join(errs...) -} - -func LaunchWithSignalHandler(cfg *config.Config) error { - logger := utils.NewLoggerFromConfig(cfg) - - logger.Debugf("Loaded config: %#v", cfg) - - sigint := make(chan os.Signal, 1) - signal.Notify(sigint, os.Interrupt, syscall.SIGTERM) - defer signal.Stop(sigint) - - ctx, cancel := context.WithCancel( - utils.NewContextWithLogger(context.Background(), logger)) - - go func() { - sig := <-sigint - logger.Debugf("Received signal: %v", sig) - - cancel() - }() - - return Launch(ctx, cfg) -} +package server + +import ( + "context" + "errors" + "os" + "os/signal" + "sync" + "syscall" + + "github.com/mlflow/mlflow-go/pkg/config" + "github.com/mlflow/mlflow-go/pkg/server/command" + "github.com/mlflow/mlflow-go/pkg/utils" +) + +func Launch(ctx context.Context, cfg *config.Config) error { + if len(cfg.PythonCommand) > 0 { + return launchCommandAndServer(ctx, cfg) + } + + return launchServer(ctx, cfg) +} + +func launchCommandAndServer(ctx context.Context, cfg *config.Config) error { + var errs []error + + logger := utils.GetLoggerFromContext(ctx) + + cmdCtx, cmdCancel := context.WithCancel(ctx) + srvCtx, srvCancel := context.WithCancel(ctx) + + waitGroup := sync.WaitGroup{} + waitGroup.Add(1) + + go func() { + defer waitGroup.Done() + + if err := command.LaunchCommand(cmdCtx, cfg); err != 
nil && cmdCtx.Err() == nil { + errs = append(errs, err) + } + + logger.Debug("Python server has exited") + + srvCancel() + }() + + waitGroup.Add(1) + + go func() { + defer waitGroup.Done() + + if err := launchServer(srvCtx, cfg); err != nil && srvCtx.Err() == nil { + errs = append(errs, err) + } + + logger.Debug("Go server has exited") + + cmdCancel() + }() + + waitGroup.Wait() + + return errors.Join(errs...) +} + +func LaunchWithSignalHandler(cfg *config.Config) error { + logger := utils.NewLoggerFromConfig(cfg) + + logger.Debugf("Loaded config: %#v", cfg) + + sigint := make(chan os.Signal, 1) + signal.Notify(sigint, os.Interrupt, syscall.SIGTERM) + defer signal.Stop(sigint) + + ctx, cancel := context.WithCancel( + utils.NewContextWithLogger(context.Background(), logger)) + + go func() { + sig := <-sigint + logger.Debugf("Received signal: %v", sig) + + cancel() + }() + + return Launch(ctx, cfg) +} diff --git a/pkg/server/parser/http_request_parser.go b/pkg/server/parser/http_request_parser.go index 9b9db53..5e17242 100644 --- a/pkg/server/parser/http_request_parser.go +++ b/pkg/server/parser/http_request_parser.go @@ -1,72 +1,72 @@ -package parser - -import ( - "encoding/json" - "errors" - "fmt" - - "github.com/go-playground/validator/v10" - "github.com/gofiber/fiber/v2" - "github.com/tidwall/gjson" - "google.golang.org/protobuf/encoding/protojson" - "google.golang.org/protobuf/proto" - - "github.com/mlflow/mlflow-go/pkg/contract" - "github.com/mlflow/mlflow-go/pkg/protos" - "github.com/mlflow/mlflow-go/pkg/validation" -) - -type HTTPRequestParser struct { - validator *validator.Validate -} - -func NewHTTPRequestParser() (*HTTPRequestParser, error) { - validator, err := validation.NewValidator() - if err != nil { - return nil, fmt.Errorf("failed to create validator: %w", err) - } - - return &HTTPRequestParser{ - validator: validator, - }, nil -} - -func (p *HTTPRequestParser) ParseBody(ctx *fiber.Ctx, input proto.Message) *contract.Error { - if protojsonErr := 
protojson.Unmarshal(ctx.Body(), input); protojsonErr != nil { - // falling back to JSON, because `protojson` doesn't provide any information - // about `field` name for which ut fails. MLFlow tests expect to know the exact - // `field` name where validation failed. This approach has no effect on MLFlow - // tests, so let's keep it for now. - if jsonErr := json.Unmarshal(ctx.Body(), input); jsonErr != nil { - var unmarshalTypeError *json.UnmarshalTypeError - if errors.As(jsonErr, &unmarshalTypeError) { - result := gjson.GetBytes(ctx.Body(), unmarshalTypeError.Field) - - return contract.NewError( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - fmt.Sprintf("Invalid value %s for parameter '%s' supplied", result.Raw, unmarshalTypeError.Field), - ) - } - } - - return contract.NewError(protos.ErrorCode_BAD_REQUEST, protojsonErr.Error()) - } - - if err := p.validator.Struct(input); err != nil { - return validation.NewErrorFromValidationError(err) - } - - return nil -} - -func (p *HTTPRequestParser) ParseQuery(ctx *fiber.Ctx, input interface{}) *contract.Error { - if err := ctx.QueryParser(input); err != nil { - return contract.NewError(protos.ErrorCode_BAD_REQUEST, err.Error()) - } - - if err := p.validator.Struct(input); err != nil { - return validation.NewErrorFromValidationError(err) - } - - return nil -} +package parser + +import ( + "encoding/json" + "errors" + "fmt" + + "github.com/go-playground/validator/v10" + "github.com/gofiber/fiber/v2" + "github.com/tidwall/gjson" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" + + "github.com/mlflow/mlflow-go/pkg/contract" + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/validation" +) + +type HTTPRequestParser struct { + validator *validator.Validate +} + +func NewHTTPRequestParser() (*HTTPRequestParser, error) { + validator, err := validation.NewValidator() + if err != nil { + return nil, fmt.Errorf("failed to create validator: %w", err) + } + + return 
&HTTPRequestParser{ + validator: validator, + }, nil +} + +func (p *HTTPRequestParser) ParseBody(ctx *fiber.Ctx, input proto.Message) *contract.Error { + if protojsonErr := protojson.Unmarshal(ctx.Body(), input); protojsonErr != nil { + // falling back to JSON, because `protojson` doesn't provide any information + // about `field` name for which it fails. MLFlow tests expect to know the exact + // `field` name where validation failed. This approach has no effect on MLFlow + // tests, so let's keep it for now. + if jsonErr := json.Unmarshal(ctx.Body(), input); jsonErr != nil { + var unmarshalTypeError *json.UnmarshalTypeError + if errors.As(jsonErr, &unmarshalTypeError) { + result := gjson.GetBytes(ctx.Body(), unmarshalTypeError.Field) + + return contract.NewError( + protos.ErrorCode_INVALID_PARAMETER_VALUE, + fmt.Sprintf("Invalid value %s for parameter '%s' supplied", result.Raw, unmarshalTypeError.Field), + ) + } + } + + return contract.NewError(protos.ErrorCode_BAD_REQUEST, protojsonErr.Error()) + } + + if err := p.validator.Struct(input); err != nil { + return validation.NewErrorFromValidationError(err) + } + + return nil +} + +func (p *HTTPRequestParser) ParseQuery(ctx *fiber.Ctx, input interface{}) *contract.Error { + if err := ctx.QueryParser(input); err != nil { + return contract.NewError(protos.ErrorCode_BAD_REQUEST, err.Error()) + } + + if err := p.validator.Struct(input); err != nil { + return validation.NewErrorFromValidationError(err) + } + + return nil +} diff --git a/pkg/server/routes/tracking.g.go b/pkg/server/routes/tracking.g.go index 265dd0b..63a9c4c 100644 --- a/pkg/server/routes/tracking.g.go +++ b/pkg/server/routes/tracking.g.go @@ -132,6 +132,28 @@ func RegisterTrackingServiceRoutes(service service.TrackingService, parser *pars } return ctx.JSON(output) }) + app.Post("/mlflow/runs/set-tag", func(ctx *fiber.Ctx) error { + input := &protos.SetTag{} + if err := parser.ParseBody(ctx, input); err != nil { + return err + } + output, err := 
service.SetTag(utils.NewContextWithLoggerFromFiberContext(ctx), input) + if err != nil { + return err + } + return ctx.JSON(output) + }) + app.Post("/mlflow/runs/delete-tag", func(ctx *fiber.Ctx) error { + input := &protos.DeleteTag{} + if err := parser.ParseBody(ctx, input); err != nil { + return err + } + output, err := service.DeleteTag(utils.NewContextWithLoggerFromFiberContext(ctx), input) + if err != nil { + return err + } + return ctx.JSON(output) + }) app.Get("/mlflow/runs/get", func(ctx *fiber.Ctx) error { input := &protos.GetRun{} if err := parser.ParseQuery(ctx, input); err != nil { diff --git a/pkg/server/server.go b/pkg/server/server.go index 59022a1..3260613 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -1,223 +1,223 @@ -package server - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "net" - "path/filepath" - "time" - - "github.com/gofiber/fiber/v2" - "github.com/gofiber/fiber/v2/middleware/compress" - "github.com/gofiber/fiber/v2/middleware/logger" - "github.com/gofiber/fiber/v2/middleware/proxy" - "github.com/gofiber/fiber/v2/middleware/recover" - "google.golang.org/protobuf/encoding/protojson" - "google.golang.org/protobuf/proto" - - as "github.com/mlflow/mlflow-go/pkg/artifacts/service" - mr "github.com/mlflow/mlflow-go/pkg/model_registry/service" - ts "github.com/mlflow/mlflow-go/pkg/tracking/service" - - "github.com/mlflow/mlflow-go/pkg/config" - "github.com/mlflow/mlflow-go/pkg/contract" - "github.com/mlflow/mlflow-go/pkg/protos" - "github.com/mlflow/mlflow-go/pkg/server/parser" - "github.com/mlflow/mlflow-go/pkg/server/routes" - "github.com/mlflow/mlflow-go/pkg/utils" -) - -//nolint:funlen -func configureApp(ctx context.Context, cfg *config.Config) (*fiber.App, error) { - //nolint:mnd - app := fiber.New(fiber.Config{ - BodyLimit: 16 * 1024 * 1024, - ReadBufferSize: 16384, - ReadTimeout: 5 * time.Second, - WriteTimeout: 600 * time.Second, - IdleTimeout: 120 * time.Second, - ServerHeader: "mlflow/" + cfg.Version, 
- JSONEncoder: func(v interface{}) ([]byte, error) { - if protoMessage, ok := v.(proto.Message); ok { - return protojson.Marshal(protoMessage) - } - - return json.Marshal(v) - }, - JSONDecoder: func(data []byte, v interface{}) error { - if protoMessage, ok := v.(proto.Message); ok { - return protojson.Unmarshal(data, protoMessage) - } - - return json.Unmarshal(data, v) - }, - DisableStartupMessage: true, - }) - - app.Use(compress.New()) - app.Use(recover.New(recover.Config{EnableStackTrace: true})) - app.Use(logger.New(logger.Config{ - Format: "${status} - ${latency} ${method} ${path}\n", - Output: utils.GetLoggerFromContext(ctx).Writer(), - })) - app.Use(func(c *fiber.Ctx) error { - c.SetUserContext(ctx) - - return c.Next() - }) - - apiApp, err := newAPIApp(ctx, cfg) - if err != nil { - return nil, err - } - - app.Mount("/api/2.0", apiApp) - app.Mount("/ajax-api/2.0", apiApp) - - if cfg.StaticFolder != "" { - app.Static("/static-files", cfg.StaticFolder) - app.Get("/", func(c *fiber.Ctx) error { - return c.SendFile(filepath.Join(cfg.StaticFolder, "index.html")) - }) - } - - app.Get("/health", func(c *fiber.Ctx) error { - return c.SendString("OK") - }) - app.Get("/version", func(c *fiber.Ctx) error { - return c.SendString(cfg.Version) - }) - - if cfg.PythonAddress != "" { - app.Use(proxy.BalancerForward([]string{cfg.PythonAddress})) - } - - return app, nil -} - -func launchServer(ctx context.Context, cfg *config.Config) error { - logger := utils.GetLoggerFromContext(ctx) - - app, err := configureApp(ctx, cfg) - if err != nil { - return err - } - - go func() { - <-ctx.Done() - - logger.Info("Shutting down MLflow Go server") - - if err := app.ShutdownWithTimeout(cfg.ShutdownTimeout.Duration); err != nil { - logger.Errorf("Failed to gracefully shutdown MLflow Go server: %v", err) - } - }() - - if cfg.PythonAddress != "" { - logger.Debugf("Waiting for Python server to be ready on http://%s", cfg.PythonAddress) - - for { - dialer := &net.Dialer{} - conn, err := 
dialer.DialContext(ctx, "tcp", cfg.PythonAddress) - - if err == nil { - conn.Close() - - break - } - - if errors.Is(err, context.Canceled) { - return fmt.Errorf("failed to connect to Python server: %w", err) - } - - time.Sleep(50 * time.Millisecond) //nolint:mnd - } - logger.Debugf("Python server is ready on http://%s", cfg.PythonAddress) - } - - logger.Infof("Launching MLflow Go server on http://%s", cfg.Address) - - err = app.Listen(cfg.Address) - if err != nil { - return fmt.Errorf("failed to start MLflow Go server: %w", err) - } - - return nil -} - -func newFiberConfig() fiber.Config { - return fiber.Config{ - ErrorHandler: func(context *fiber.Ctx, err error) error { - var contractError *contract.Error - if !errors.As(err, &contractError) { - code := protos.ErrorCode_INTERNAL_ERROR - - var f *fiber.Error - if errors.As(err, &f) { - switch f.Code { - case fiber.StatusBadRequest: - code = protos.ErrorCode_BAD_REQUEST - case fiber.StatusServiceUnavailable: - code = protos.ErrorCode_SERVICE_UNDER_MAINTENANCE - case fiber.StatusNotFound: - code = protos.ErrorCode_ENDPOINT_NOT_FOUND - } - } - - contractError = contract.NewError(code, err.Error()) - } - - var logFn func(format string, args ...any) - - logger := utils.GetLoggerFromContext(context.Context()) - switch contractError.StatusCode() { - case fiber.StatusBadRequest: - logFn = logger.Infof - case fiber.StatusServiceUnavailable: - logFn = logger.Warnf - case fiber.StatusNotFound: - logFn = logger.Debugf - default: - logFn = logger.Errorf - } - - logFn("Error encountered in %s %s: %s", context.Method(), context.Path(), err) - - return context.Status(contractError.StatusCode()).JSON(contractError) - }, - } -} - -func newAPIApp(ctx context.Context, cfg *config.Config) (*fiber.App, error) { - app := fiber.New(newFiberConfig()) - - parser, err := parser.NewHTTPRequestParser() - if err != nil { - return nil, fmt.Errorf("failed to create new HTTP request parser: %w", err) - } - - trackingService, err := 
ts.NewTrackingService(ctx, cfg) - if err != nil { - return nil, fmt.Errorf("failed to create new tracking service: %w", err) - } - - routes.RegisterTrackingServiceRoutes(trackingService, parser, app) - - modelRegistryService, err := mr.NewModelRegistryService(ctx, cfg) - if err != nil { - return nil, fmt.Errorf("failed to create new model registry service: %w", err) - } - - routes.RegisterModelRegistryServiceRoutes(modelRegistryService, parser, app) - - artifactService, err := as.NewArtifactsService(ctx, cfg) - if err != nil { - return nil, fmt.Errorf("failed to create new artifacts service: %w", err) - } - - routes.RegisterArtifactsServiceRoutes(artifactService, parser, app) - - return app, nil -} +package server + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net" + "path/filepath" + "time" + + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/middleware/compress" + "github.com/gofiber/fiber/v2/middleware/logger" + "github.com/gofiber/fiber/v2/middleware/proxy" + "github.com/gofiber/fiber/v2/middleware/recover" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" + + as "github.com/mlflow/mlflow-go/pkg/artifacts/service" + mr "github.com/mlflow/mlflow-go/pkg/model_registry/service" + ts "github.com/mlflow/mlflow-go/pkg/tracking/service" + + "github.com/mlflow/mlflow-go/pkg/config" + "github.com/mlflow/mlflow-go/pkg/contract" + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/server/parser" + "github.com/mlflow/mlflow-go/pkg/server/routes" + "github.com/mlflow/mlflow-go/pkg/utils" +) + +//nolint:funlen +func configureApp(ctx context.Context, cfg *config.Config) (*fiber.App, error) { + //nolint:mnd + app := fiber.New(fiber.Config{ + BodyLimit: 16 * 1024 * 1024, + ReadBufferSize: 16384, + ReadTimeout: 5 * time.Second, + WriteTimeout: 600 * time.Second, + IdleTimeout: 120 * time.Second, + ServerHeader: "mlflow/" + cfg.Version, + JSONEncoder: func(v interface{}) ([]byte, error) { + 
if protoMessage, ok := v.(proto.Message); ok { + return protojson.Marshal(protoMessage) + } + + return json.Marshal(v) + }, + JSONDecoder: func(data []byte, v interface{}) error { + if protoMessage, ok := v.(proto.Message); ok { + return protojson.Unmarshal(data, protoMessage) + } + + return json.Unmarshal(data, v) + }, + DisableStartupMessage: true, + }) + + app.Use(compress.New()) + app.Use(recover.New(recover.Config{EnableStackTrace: true})) + app.Use(logger.New(logger.Config{ + Format: "${status} - ${latency} ${method} ${path}\n", + Output: utils.GetLoggerFromContext(ctx).Writer(), + })) + app.Use(func(c *fiber.Ctx) error { + c.SetUserContext(ctx) + + return c.Next() + }) + + apiApp, err := newAPIApp(ctx, cfg) + if err != nil { + return nil, err + } + + app.Mount("/api/2.0", apiApp) + app.Mount("/ajax-api/2.0", apiApp) + + if cfg.StaticFolder != "" { + app.Static("/static-files", cfg.StaticFolder) + app.Get("/", func(c *fiber.Ctx) error { + return c.SendFile(filepath.Join(cfg.StaticFolder, "index.html")) + }) + } + + app.Get("/health", func(c *fiber.Ctx) error { + return c.SendString("OK") + }) + app.Get("/version", func(c *fiber.Ctx) error { + return c.SendString(cfg.Version) + }) + + if cfg.PythonAddress != "" { + app.Use(proxy.BalancerForward([]string{cfg.PythonAddress})) + } + + return app, nil +} + +func launchServer(ctx context.Context, cfg *config.Config) error { + logger := utils.GetLoggerFromContext(ctx) + + app, err := configureApp(ctx, cfg) + if err != nil { + return err + } + + go func() { + <-ctx.Done() + + logger.Info("Shutting down MLflow Go server") + + if err := app.ShutdownWithTimeout(cfg.ShutdownTimeout.Duration); err != nil { + logger.Errorf("Failed to gracefully shutdown MLflow Go server: %v", err) + } + }() + + if cfg.PythonAddress != "" { + logger.Debugf("Waiting for Python server to be ready on http://%s", cfg.PythonAddress) + + for { + dialer := &net.Dialer{} + conn, err := dialer.DialContext(ctx, "tcp", cfg.PythonAddress) + + if err == 
nil { + conn.Close() + + break + } + + if errors.Is(err, context.Canceled) { + return fmt.Errorf("failed to connect to Python server: %w", err) + } + + time.Sleep(50 * time.Millisecond) //nolint:mnd + } + logger.Debugf("Python server is ready on http://%s", cfg.PythonAddress) + } + + logger.Infof("Launching MLflow Go server on http://%s", cfg.Address) + + err = app.Listen(cfg.Address) + if err != nil { + return fmt.Errorf("failed to start MLflow Go server: %w", err) + } + + return nil +} + +func newFiberConfig() fiber.Config { + return fiber.Config{ + ErrorHandler: func(context *fiber.Ctx, err error) error { + var contractError *contract.Error + if !errors.As(err, &contractError) { + code := protos.ErrorCode_INTERNAL_ERROR + + var f *fiber.Error + if errors.As(err, &f) { + switch f.Code { + case fiber.StatusBadRequest: + code = protos.ErrorCode_BAD_REQUEST + case fiber.StatusServiceUnavailable: + code = protos.ErrorCode_SERVICE_UNDER_MAINTENANCE + case fiber.StatusNotFound: + code = protos.ErrorCode_ENDPOINT_NOT_FOUND + } + } + + contractError = contract.NewError(code, err.Error()) + } + + var logFn func(format string, args ...any) + + logger := utils.GetLoggerFromContext(context.Context()) + switch contractError.StatusCode() { + case fiber.StatusBadRequest: + logFn = logger.Infof + case fiber.StatusServiceUnavailable: + logFn = logger.Warnf + case fiber.StatusNotFound: + logFn = logger.Debugf + default: + logFn = logger.Errorf + } + + logFn("Error encountered in %s %s: %s", context.Method(), context.Path(), err) + + return context.Status(contractError.StatusCode()).JSON(contractError) + }, + } +} + +func newAPIApp(ctx context.Context, cfg *config.Config) (*fiber.App, error) { + app := fiber.New(newFiberConfig()) + + parser, err := parser.NewHTTPRequestParser() + if err != nil { + return nil, fmt.Errorf("failed to create new HTTP request parser: %w", err) + } + + trackingService, err := ts.NewTrackingService(ctx, cfg) + if err != nil { + return nil, 
fmt.Errorf("failed to create new tracking service: %w", err) + } + + routes.RegisterTrackingServiceRoutes(trackingService, parser, app) + + modelRegistryService, err := mr.NewModelRegistryService(ctx, cfg) + if err != nil { + return nil, fmt.Errorf("failed to create new model registry service: %w", err) + } + + routes.RegisterModelRegistryServiceRoutes(modelRegistryService, parser, app) + + artifactService, err := as.NewArtifactsService(ctx, cfg) + if err != nil { + return nil, fmt.Errorf("failed to create new artifacts service: %w", err) + } + + routes.RegisterArtifactsServiceRoutes(artifactService, parser, app) + + return app, nil +} diff --git a/pkg/sql/logger.go b/pkg/sql/logger.go index 2699168..3939229 100644 --- a/pkg/sql/logger.go +++ b/pkg/sql/logger.go @@ -1,139 +1,139 @@ -//nolint:goprintffuncname -package sql - -import ( - "context" - "errors" - "fmt" - "runtime" - "strings" - "time" - - "github.com/sirupsen/logrus" - "gorm.io/gorm" - "gorm.io/gorm/logger" -) - -type loggerAdaptor struct { - Logger *logrus.Logger - Config LoggerAdaptorConfig -} - -type LoggerAdaptorConfig struct { - SlowThreshold time.Duration - IgnoreRecordNotFoundError bool - ParameterizedQueries bool -} - -// NewLoggerAdaptor creates a new logger adaptor. -// -//nolint:ireturn -func NewLoggerAdaptor(l *logrus.Logger, cfg LoggerAdaptorConfig) logger.Interface { - return &loggerAdaptor{l, cfg} -} - -// LogMode implements the gorm.io/gorm/logger.Interface interface and is a no-op. -// -//nolint:ireturn -func (l *loggerAdaptor) LogMode(_ logger.LogLevel) logger.Interface { - return l -} - -const ( - maximumCallerDepth int = 15 - minimumCallerDepth int = 4 -) - -// getLoggerEntry gets a logger entry with context and caller information added. 
-func (l *loggerAdaptor) getLoggerEntry(ctx context.Context) *logrus.Entry { - entry := l.Logger.WithContext(ctx) - // We want to report the caller of the function that called gorm's logger, - // not the caller of the loggerAdaptor, so we skip the first few frames and - // then look for the first frame that is not in the gorm package. - pcs := make([]uintptr, maximumCallerDepth) - depth := runtime.Callers(minimumCallerDepth, pcs) - frames := runtime.CallersFrames(pcs[:depth]) - - for f, again := frames.Next(); again; f, again = frames.Next() { - if !strings.HasPrefix(f.Function, "gorm.io/gorm") { - entry = entry.WithFields(logrus.Fields{ - "app_file": fmt.Sprintf("%s:%d", f.File, f.Line), - "app_func": f.Function + "()", - }) - - break - } - } - - return entry -} - -// Info logs message at info level and implements the gorm.io/gorm/logger.Interface interface. -func (l *loggerAdaptor) Info(ctx context.Context, format string, args ...interface{}) { - l.getLoggerEntry(ctx).Infof(format, args...) -} - -// Warn logs message at warn level and implements the gorm.io/gorm/logger.Interface interface. -func (l *loggerAdaptor) Warn(ctx context.Context, format string, args ...interface{}) { - l.getLoggerEntry(ctx).Warnf(format, args...) -} - -// Error logs message at error level and implements the gorm.io/gorm/logger.Interface interface. -func (l *loggerAdaptor) Error(ctx context.Context, format string, args ...interface{}) { - l.getLoggerEntry(ctx).Errorf(format, args...) -} - -const NanosecondsPerMillisecond = 1e6 - -// getLoggerEntryWithSQL gets a logger entry with context, caller information and SQL information added. 
-func (l *loggerAdaptor) getLoggerEntryWithSQL( - ctx context.Context, - elapsed time.Duration, - fc func() (sql string, rowsAffected int64), -) *logrus.Entry { - entry := l.getLoggerEntry(ctx) - - if fc != nil { - sql, rows := fc() - entry = entry.WithFields(logrus.Fields{ - "elapsed": fmt.Sprintf("%.3fms", float64(elapsed.Nanoseconds())/NanosecondsPerMillisecond), - "rows": rows, - "sql": sql, - }) - - if rows == -1 { - entry = entry.WithField("rows", "-") - } - } - - return entry -} - -// Trace logs SQL statement, amount of affected rows, and elapsed time. -// It implements the gorm.io/gorm/logger.Interface interface. -func (l *loggerAdaptor) Trace( - ctx context.Context, - begin time.Time, - function func() (sql string, rowsAffected int64), - err error, -) { - if l.Logger.GetLevel() <= logrus.FatalLevel { - return - } - - // This logic is similar to the default logger in gorm.io/gorm/logger. - elapsed := time.Since(begin) - - switch { - case err != nil && - l.Logger.IsLevelEnabled(logrus.ErrorLevel) && - (!errors.Is(err, gorm.ErrRecordNotFound) || !l.Config.IgnoreRecordNotFoundError): - l.getLoggerEntryWithSQL(ctx, elapsed, function).WithError(err).Error("SQL error") - case elapsed > l.Config.SlowThreshold && - l.Config.SlowThreshold != 0 && - l.Logger.IsLevelEnabled(logrus.WarnLevel): - l.getLoggerEntryWithSQL(ctx, elapsed, function).Warnf("SLOW SQL >= %v", l.Config.SlowThreshold) - case l.Logger.IsLevelEnabled(logrus.DebugLevel): - l.getLoggerEntryWithSQL(ctx, elapsed, function).Debug("SQL trace") - } -} +//nolint:goprintffuncname +package sql + +import ( + "context" + "errors" + "fmt" + "runtime" + "strings" + "time" + + "github.com/sirupsen/logrus" + "gorm.io/gorm" + "gorm.io/gorm/logger" +) + +type loggerAdaptor struct { + Logger *logrus.Logger + Config LoggerAdaptorConfig +} + +type LoggerAdaptorConfig struct { + SlowThreshold time.Duration + IgnoreRecordNotFoundError bool + ParameterizedQueries bool +} + +// NewLoggerAdaptor creates a new logger adaptor. 
+// +//nolint:ireturn +func NewLoggerAdaptor(l *logrus.Logger, cfg LoggerAdaptorConfig) logger.Interface { + return &loggerAdaptor{l, cfg} +} + +// LogMode implements the gorm.io/gorm/logger.Interface interface and is a no-op. +// +//nolint:ireturn +func (l *loggerAdaptor) LogMode(_ logger.LogLevel) logger.Interface { + return l +} + +const ( + maximumCallerDepth int = 15 + minimumCallerDepth int = 4 +) + +// getLoggerEntry gets a logger entry with context and caller information added. +func (l *loggerAdaptor) getLoggerEntry(ctx context.Context) *logrus.Entry { + entry := l.Logger.WithContext(ctx) + // We want to report the caller of the function that called gorm's logger, + // not the caller of the loggerAdaptor, so we skip the first few frames and + // then look for the first frame that is not in the gorm package. + pcs := make([]uintptr, maximumCallerDepth) + depth := runtime.Callers(minimumCallerDepth, pcs) + frames := runtime.CallersFrames(pcs[:depth]) + + for f, again := frames.Next(); again; f, again = frames.Next() { + if !strings.HasPrefix(f.Function, "gorm.io/gorm") { + entry = entry.WithFields(logrus.Fields{ + "app_file": fmt.Sprintf("%s:%d", f.File, f.Line), + "app_func": f.Function + "()", + }) + + break + } + } + + return entry +} + +// Info logs message at info level and implements the gorm.io/gorm/logger.Interface interface. +func (l *loggerAdaptor) Info(ctx context.Context, format string, args ...interface{}) { + l.getLoggerEntry(ctx).Infof(format, args...) +} + +// Warn logs message at warn level and implements the gorm.io/gorm/logger.Interface interface. +func (l *loggerAdaptor) Warn(ctx context.Context, format string, args ...interface{}) { + l.getLoggerEntry(ctx).Warnf(format, args...) +} + +// Error logs message at error level and implements the gorm.io/gorm/logger.Interface interface. +func (l *loggerAdaptor) Error(ctx context.Context, format string, args ...interface{}) { + l.getLoggerEntry(ctx).Errorf(format, args...) 
+} + +const NanosecondsPerMillisecond = 1e6 + +// getLoggerEntryWithSQL gets a logger entry with context, caller information and SQL information added. +func (l *loggerAdaptor) getLoggerEntryWithSQL( + ctx context.Context, + elapsed time.Duration, + fc func() (sql string, rowsAffected int64), +) *logrus.Entry { + entry := l.getLoggerEntry(ctx) + + if fc != nil { + sql, rows := fc() + entry = entry.WithFields(logrus.Fields{ + "elapsed": fmt.Sprintf("%.3fms", float64(elapsed.Nanoseconds())/NanosecondsPerMillisecond), + "rows": rows, + "sql": sql, + }) + + if rows == -1 { + entry = entry.WithField("rows", "-") + } + } + + return entry +} + +// Trace logs SQL statement, amount of affected rows, and elapsed time. +// It implements the gorm.io/gorm/logger.Interface interface. +func (l *loggerAdaptor) Trace( + ctx context.Context, + begin time.Time, + function func() (sql string, rowsAffected int64), + err error, +) { + if l.Logger.GetLevel() <= logrus.FatalLevel { + return + } + + // This logic is similar to the default logger in gorm.io/gorm/logger. 
+ elapsed := time.Since(begin) + + switch { + case err != nil && + l.Logger.IsLevelEnabled(logrus.ErrorLevel) && + (!errors.Is(err, gorm.ErrRecordNotFound) || !l.Config.IgnoreRecordNotFoundError): + l.getLoggerEntryWithSQL(ctx, elapsed, function).WithError(err).Error("SQL error") + case elapsed > l.Config.SlowThreshold && + l.Config.SlowThreshold != 0 && + l.Logger.IsLevelEnabled(logrus.WarnLevel): + l.getLoggerEntryWithSQL(ctx, elapsed, function).Warnf("SLOW SQL >= %v", l.Config.SlowThreshold) + case l.Logger.IsLevelEnabled(logrus.DebugLevel): + l.getLoggerEntryWithSQL(ctx, elapsed, function).Debug("SQL trace") + } +} diff --git a/pkg/sql/sql.go b/pkg/sql/sql.go index a7e88d4..bc6b8e5 100644 --- a/pkg/sql/sql.go +++ b/pkg/sql/sql.go @@ -1,90 +1,90 @@ -package sql - -import ( - "context" - "errors" - "fmt" - "net/url" - "strings" - - "gorm.io/driver/mysql" - "gorm.io/driver/postgres" - "gorm.io/driver/sqlite" - "gorm.io/driver/sqlserver" - "gorm.io/gorm" - - "github.com/mlflow/mlflow-go/pkg/utils" -) - -var errSqliteMemory = errors.New("go implementation does not support :memory: for sqlite") - -//nolint:ireturn -func getDialector(uri *url.URL) (gorm.Dialector, error) { - uri.Scheme, _, _ = strings.Cut(uri.Scheme, "+") - - switch uri.Scheme { - case "mssql": - uri.Scheme = "sqlserver" - - return sqlserver.Open(uri.String()), nil - case "mysql": - return mysql.Open(fmt.Sprintf("%s@tcp(%s)%s?%s", uri.User, uri.Host, uri.Path, uri.RawQuery)), nil - case "postgres", "postgresql": - return postgres.Open(uri.String()), nil - case "sqlite": - uri.Scheme = "" - uri.Path = uri.Path[1:] - - if uri.Path == ":memory:" { - return nil, errSqliteMemory - } - - return sqlite.Open(uri.String()), nil - default: - return nil, fmt.Errorf("unsupported store URL scheme %q", uri.Scheme) //nolint:err113 - } -} - -func initSqlite(database *gorm.DB) error { - database.Exec("PRAGMA case_sensitive_like = true;") - - sqlDB, err := database.DB() - if err != nil { - return fmt.Errorf("failed to 
get database instance: %w", err) - } - // set SetMaxOpenConns to be 1 only in case of SQLite to avoid `database is locked` - // in case of parallel calls to some endpoints that use `transactions`. - sqlDB.SetMaxOpenConns(1) - - return nil -} - -func NewDatabase(ctx context.Context, storeURL string) (*gorm.DB, error) { - logger := utils.GetLoggerFromContext(ctx) - - uri, err := url.Parse(storeURL) - if err != nil { - return nil, fmt.Errorf("failed to parse store URL %q: %w", storeURL, err) - } - - dialector, err := getDialector(uri) - if err != nil { - return nil, err - } - - database, err := gorm.Open(dialector, &gorm.Config{ - TranslateError: true, - Logger: NewLoggerAdaptor(logger, LoggerAdaptorConfig{}), - }) - if err != nil { - return nil, fmt.Errorf("failed to connect to database %q: %w", uri.String(), err) - } - - if dialector.Name() == "sqlite" { - if err := initSqlite(database); err != nil { - return nil, err - } - } - - return database, nil -} +package sql + +import ( + "context" + "errors" + "fmt" + "net/url" + "strings" + + "gorm.io/driver/mysql" + "gorm.io/driver/postgres" + "gorm.io/driver/sqlite" + "gorm.io/driver/sqlserver" + "gorm.io/gorm" + + "github.com/mlflow/mlflow-go/pkg/utils" +) + +var errSqliteMemory = errors.New("go implementation does not support :memory: for sqlite") + +//nolint:ireturn +func getDialector(uri *url.URL) (gorm.Dialector, error) { + uri.Scheme, _, _ = strings.Cut(uri.Scheme, "+") + + switch uri.Scheme { + case "mssql": + uri.Scheme = "sqlserver" + + return sqlserver.Open(uri.String()), nil + case "mysql": + return mysql.Open(fmt.Sprintf("%s@tcp(%s)%s?%s", uri.User, uri.Host, uri.Path, uri.RawQuery)), nil + case "postgres", "postgresql": + return postgres.Open(uri.String()), nil + case "sqlite": + uri.Scheme = "" + uri.Path = uri.Path[1:] + + if uri.Path == ":memory:" { + return nil, errSqliteMemory + } + + return sqlite.Open(uri.String()), nil + default: + return nil, fmt.Errorf("unsupported store URL scheme %q", uri.Scheme) 
//nolint:err113 + } +} + +func initSqlite(database *gorm.DB) error { + database.Exec("PRAGMA case_sensitive_like = true;") + + sqlDB, err := database.DB() + if err != nil { + return fmt.Errorf("failed to get database instance: %w", err) + } + // set SetMaxOpenConns to be 1 only in case of SQLite to avoid `database is locked` + // in case of parallel calls to some endpoints that use `transactions`. + sqlDB.SetMaxOpenConns(1) + + return nil +} + +func NewDatabase(ctx context.Context, storeURL string) (*gorm.DB, error) { + logger := utils.GetLoggerFromContext(ctx) + + uri, err := url.Parse(storeURL) + if err != nil { + return nil, fmt.Errorf("failed to parse store URL %q: %w", storeURL, err) + } + + dialector, err := getDialector(uri) + if err != nil { + return nil, err + } + + database, err := gorm.Open(dialector, &gorm.Config{ + TranslateError: true, + Logger: NewLoggerAdaptor(logger, LoggerAdaptorConfig{}), + }) + if err != nil { + return nil, fmt.Errorf("failed to connect to database %q: %w", uri.String(), err) + } + + if dialector.Name() == "sqlite" { + if err := initSqlite(database); err != nil { + return nil, err + } + } + + return database, nil +} diff --git a/pkg/tracking/service/experiments.go b/pkg/tracking/service/experiments.go index 0eeefa2..6e4291f 100644 --- a/pkg/tracking/service/experiments.go +++ b/pkg/tracking/service/experiments.go @@ -1,134 +1,134 @@ -package service - -import ( - "context" - "fmt" - "net/url" - "path/filepath" - "runtime" - "strings" - - "github.com/mlflow/mlflow-go/pkg/contract" - "github.com/mlflow/mlflow-go/pkg/entities" - "github.com/mlflow/mlflow-go/pkg/protos" - "github.com/mlflow/mlflow-go/pkg/tracking/store/sql/models" -) - -// CreateExperiment implements TrackingService. 
-func (ts TrackingService) CreateExperiment(ctx context.Context, input *protos.CreateExperiment) ( - *protos.CreateExperiment_Response, *contract.Error, -) { - if input.GetArtifactLocation() != "" { - artifactLocation := strings.TrimRight(input.GetArtifactLocation(), "/") - - // We don't check the validation here as this was already covered in the validator. - url, _ := url.Parse(artifactLocation) - switch url.Scheme { - case "file", "": - path, err := filepath.Abs(url.Path) - if err != nil { - return nil, contract.NewError( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - fmt.Sprintf("error getting absolute path: %v", err), - ) - } - - if runtime.GOOS == "windows" { - url.Scheme = "file" - path = "/" + strings.ReplaceAll(path, "\\", "/") - } - - url.Path = path - artifactLocation = url.String() - } - - input.ArtifactLocation = &artifactLocation - } - - tags := make([]*entities.ExperimentTag, len(input.GetTags())) - for i, tag := range input.GetTags() { - tags[i] = entities.NewExperimentTagFromProto(tag) - } - - experimentID, err := ts.Store.CreateExperiment(ctx, input.GetName(), input.GetArtifactLocation(), tags) - if err != nil { - return nil, err - } - - return &protos.CreateExperiment_Response{ - ExperimentId: &experimentID, - }, nil -} - -// GetExperiment implements TrackingService. 
-func (ts TrackingService) GetExperiment( - ctx context.Context, input *protos.GetExperiment, -) (*protos.GetExperiment_Response, *contract.Error) { - experiment, err := ts.Store.GetExperiment(ctx, input.GetExperimentId()) - if err != nil { - return nil, err - } - - return &protos.GetExperiment_Response{ - Experiment: experiment.ToProto(), - }, nil -} - -func (ts TrackingService) DeleteExperiment( - ctx context.Context, input *protos.DeleteExperiment, -) (*protos.DeleteExperiment_Response, *contract.Error) { - err := ts.Store.DeleteExperiment(ctx, input.GetExperimentId()) - if err != nil { - return nil, err - } - - return &protos.DeleteExperiment_Response{}, nil -} - -func (ts TrackingService) RestoreExperiment( - ctx context.Context, input *protos.RestoreExperiment, -) (*protos.RestoreExperiment_Response, *contract.Error) { - err := ts.Store.RestoreExperiment(ctx, input.GetExperimentId()) - if err != nil { - return nil, err - } - - return &protos.RestoreExperiment_Response{}, nil -} - -func (ts TrackingService) UpdateExperiment( - ctx context.Context, input *protos.UpdateExperiment, -) (*protos.UpdateExperiment_Response, *contract.Error) { - experiment, err := ts.Store.GetExperiment(ctx, input.GetExperimentId()) - if err != nil { - return nil, err - } - - if experiment.LifecycleStage != string(models.LifecycleStageActive) { - return nil, contract.NewError( - protos.ErrorCode_INVALID_STATE, - "Cannot rename a non-active experiment.", - ) - } - - if name := input.GetNewName(); name != "" { - if err := ts.Store.RenameExperiment(ctx, input.GetExperimentId(), input.GetNewName()); err != nil { - return nil, err - } - } - - return &protos.UpdateExperiment_Response{}, nil -} - -func (ts TrackingService) GetExperimentByName( - ctx context.Context, input *protos.GetExperimentByName, -) (*protos.GetExperimentByName_Response, *contract.Error) { - experiment, err := ts.Store.GetExperimentByName(ctx, input.GetExperimentName()) - if err != nil { - return nil, err - } - - return 
&protos.GetExperimentByName_Response{ - Experiment: experiment.ToProto(), - }, nil -} +package service + +import ( + "context" + "fmt" + "net/url" + "path/filepath" + "runtime" + "strings" + + "github.com/mlflow/mlflow-go/pkg/contract" + "github.com/mlflow/mlflow-go/pkg/entities" + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/tracking/store/sql/models" +) + +// CreateExperiment implements TrackingService. +func (ts TrackingService) CreateExperiment(ctx context.Context, input *protos.CreateExperiment) ( + *protos.CreateExperiment_Response, *contract.Error, +) { + if input.GetArtifactLocation() != "" { + artifactLocation := strings.TrimRight(input.GetArtifactLocation(), "/") + + // We don't check the validation here as this was already covered in the validator. + url, _ := url.Parse(artifactLocation) + switch url.Scheme { + case "file", "": + path, err := filepath.Abs(url.Path) + if err != nil { + return nil, contract.NewError( + protos.ErrorCode_INVALID_PARAMETER_VALUE, + fmt.Sprintf("error getting absolute path: %v", err), + ) + } + + if runtime.GOOS == "windows" { + url.Scheme = "file" + path = "/" + strings.ReplaceAll(path, "\\", "/") + } + + url.Path = path + artifactLocation = url.String() + } + + input.ArtifactLocation = &artifactLocation + } + + tags := make([]*entities.ExperimentTag, len(input.GetTags())) + for i, tag := range input.GetTags() { + tags[i] = entities.NewExperimentTagFromProto(tag) + } + + experimentID, err := ts.Store.CreateExperiment(ctx, input.GetName(), input.GetArtifactLocation(), tags) + if err != nil { + return nil, err + } + + return &protos.CreateExperiment_Response{ + ExperimentId: &experimentID, + }, nil +} + +// GetExperiment implements TrackingService. 
+func (ts TrackingService) GetExperiment( + ctx context.Context, input *protos.GetExperiment, +) (*protos.GetExperiment_Response, *contract.Error) { + experiment, err := ts.Store.GetExperiment(ctx, input.GetExperimentId()) + if err != nil { + return nil, err + } + + return &protos.GetExperiment_Response{ + Experiment: experiment.ToProto(), + }, nil +} + +func (ts TrackingService) DeleteExperiment( + ctx context.Context, input *protos.DeleteExperiment, +) (*protos.DeleteExperiment_Response, *contract.Error) { + err := ts.Store.DeleteExperiment(ctx, input.GetExperimentId()) + if err != nil { + return nil, err + } + + return &protos.DeleteExperiment_Response{}, nil +} + +func (ts TrackingService) RestoreExperiment( + ctx context.Context, input *protos.RestoreExperiment, +) (*protos.RestoreExperiment_Response, *contract.Error) { + err := ts.Store.RestoreExperiment(ctx, input.GetExperimentId()) + if err != nil { + return nil, err + } + + return &protos.RestoreExperiment_Response{}, nil +} + +func (ts TrackingService) UpdateExperiment( + ctx context.Context, input *protos.UpdateExperiment, +) (*protos.UpdateExperiment_Response, *contract.Error) { + experiment, err := ts.Store.GetExperiment(ctx, input.GetExperimentId()) + if err != nil { + return nil, err + } + + if experiment.LifecycleStage != string(models.LifecycleStageActive) { + return nil, contract.NewError( + protos.ErrorCode_INVALID_STATE, + "Cannot rename a non-active experiment.", + ) + } + + if name := input.GetNewName(); name != "" { + if err := ts.Store.RenameExperiment(ctx, input.GetExperimentId(), input.GetNewName()); err != nil { + return nil, err + } + } + + return &protos.UpdateExperiment_Response{}, nil +} + +func (ts TrackingService) GetExperimentByName( + ctx context.Context, input *protos.GetExperimentByName, +) (*protos.GetExperimentByName_Response, *contract.Error) { + experiment, err := ts.Store.GetExperimentByName(ctx, input.GetExperimentName()) + if err != nil { + return nil, err + } + + return 
&protos.GetExperimentByName_Response{ + Experiment: experiment.ToProto(), + }, nil +} diff --git a/pkg/tracking/service/experiments_test.go b/pkg/tracking/service/experiments_test.go index cbee4e9..52ec5fd 100644 --- a/pkg/tracking/service/experiments_test.go +++ b/pkg/tracking/service/experiments_test.go @@ -1,61 +1,61 @@ -package service //nolint:testpackage - -import ( - "context" - "testing" - - "github.com/stretchr/testify/mock" - - "github.com/mlflow/mlflow-go/pkg/protos" - "github.com/mlflow/mlflow-go/pkg/tracking/store" - "github.com/mlflow/mlflow-go/pkg/utils" -) - -type testRelativeArtifactLocationScenario struct { - name string - input string -} - -func TestRelativeArtifactLocation(t *testing.T) { - t.Parallel() - - scenarios := []testRelativeArtifactLocationScenario{ - {name: "without scheme", input: "../yow"}, - {name: "with file scheme", input: "file:///../yow"}, - } - - for _, scenario := range scenarios { - t.Run(scenario.name, func(t *testing.T) { - t.Parallel() - - store := store.NewMockTrackingStore(t) - store.EXPECT().CreateExperiment( - context.Background(), - mock.Anything, - mock.Anything, - mock.Anything, - ).Return(mock.Anything, nil) - - service := TrackingService{ - Store: store, - } - - input := protos.CreateExperiment{ - ArtifactLocation: utils.PtrTo(scenario.input), - } - - response, err := service.CreateExperiment(context.Background(), &input) - if err != nil { - t.Error("expected create experiment to succeed") - } - - if response == nil { - t.Error("expected response to be non-nil") - } - - if input.GetArtifactLocation() == scenario.input { - t.Errorf("expected artifact location to be absolute, got %s", input.GetArtifactLocation()) - } - }) - } -} +package service //nolint:testpackage + +import ( + "context" + "testing" + + "github.com/stretchr/testify/mock" + + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/tracking/store" + "github.com/mlflow/mlflow-go/pkg/utils" +) + +type 
testRelativeArtifactLocationScenario struct { + name string + input string +} + +func TestRelativeArtifactLocation(t *testing.T) { + t.Parallel() + + scenarios := []testRelativeArtifactLocationScenario{ + {name: "without scheme", input: "../yow"}, + {name: "with file scheme", input: "file:///../yow"}, + } + + for _, scenario := range scenarios { + t.Run(scenario.name, func(t *testing.T) { + t.Parallel() + + store := store.NewMockTrackingStore(t) + store.EXPECT().CreateExperiment( + context.Background(), + mock.Anything, + mock.Anything, + mock.Anything, + ).Return(mock.Anything, nil) + + service := TrackingService{ + Store: store, + } + + input := protos.CreateExperiment{ + ArtifactLocation: utils.PtrTo(scenario.input), + } + + response, err := service.CreateExperiment(context.Background(), &input) + if err != nil { + t.Error("expected create experiment to succeed") + } + + if response == nil { + t.Error("expected response to be non-nil") + } + + if input.GetArtifactLocation() == scenario.input { + t.Errorf("expected artifact location to be absolute, got %s", input.GetArtifactLocation()) + } + }) + } +} diff --git a/pkg/tracking/service/metrics.go b/pkg/tracking/service/metrics.go index e62bf0b..edb2005 100644 --- a/pkg/tracking/service/metrics.go +++ b/pkg/tracking/service/metrics.go @@ -1,20 +1,20 @@ -package service - -import ( - "context" - - "github.com/mlflow/mlflow-go/pkg/contract" - "github.com/mlflow/mlflow-go/pkg/entities" - "github.com/mlflow/mlflow-go/pkg/protos" -) - -func (ts TrackingService) LogMetric( - ctx context.Context, - input *protos.LogMetric, -) (*protos.LogMetric_Response, *contract.Error) { - if err := ts.Store.LogMetric(ctx, input.GetRunId(), entities.MetricFromLogMetricProtoInput(input)); err != nil { - return nil, err - } - - return &protos.LogMetric_Response{}, nil -} +package service + +import ( + "context" + + "github.com/mlflow/mlflow-go/pkg/contract" + "github.com/mlflow/mlflow-go/pkg/entities" + 
"github.com/mlflow/mlflow-go/pkg/protos" +) + +func (ts TrackingService) LogMetric( + ctx context.Context, + input *protos.LogMetric, +) (*protos.LogMetric_Response, *contract.Error) { + if err := ts.Store.LogMetric(ctx, input.GetRunId(), entities.MetricFromLogMetricProtoInput(input)); err != nil { + return nil, err + } + + return &protos.LogMetric_Response{}, nil +} diff --git a/pkg/tracking/service/query/README.md b/pkg/tracking/service/query/README.md index 019bb82..a5dba0e 100644 --- a/pkg/tracking/service/query/README.md +++ b/pkg/tracking/service/query/README.md @@ -1,8 +1,8 @@ -# Search Query Syntax - -Mlflow has a [query syntax](https://mlflow.org/docs/latest/search-runs.html#search-query-syntax-deep-dive). - -This package is meant to lex and parse this query dialect. - -The code is slightly based on the https://github.com/tlaceby/parser-series. -I did not implement a proper Pratt parser because of how limited the query language is. +# Search Query Syntax + +Mlflow has a [query syntax](https://mlflow.org/docs/latest/search-runs.html#search-query-syntax-deep-dive). + +This package is meant to lex and parse this query dialect. + +The code is slightly based on the https://github.com/tlaceby/parser-series. +I did not implement a proper Pratt parser because of how limited the query language is. diff --git a/pkg/tracking/service/query/lexer/token.go b/pkg/tracking/service/query/lexer/token.go index 97fddf3..064246a 100644 --- a/pkg/tracking/service/query/lexer/token.go +++ b/pkg/tracking/service/query/lexer/token.go @@ -1,111 +1,111 @@ -package lexer - -import "fmt" - -type TokenKind int - -const ( - EOF TokenKind = iota - Number - String - Identifier - - // Grouping & Braces. - OpenParen - CloseParen - - // Equivilance. - Equals - NotEquals - - // Conditional. - Less - LessEquals - Greater - GreaterEquals - - // Symbols. - Dot - Comma - - // Reserved Keywords. 
- In //nolint:varnamelen - Not - Like - ILike - And -) - -//nolint:gochecknoglobals -var reservedLu = map[string]TokenKind{ - "AND": And, - "NOT": Not, - "IN": In, - "LIKE": Like, - "ILIKE": ILike, -} - -type Token struct { - Kind TokenKind - Value string -} - -func (token Token) Debug() string { - if token.Kind == Identifier || token.Kind == Number || token.Kind == String { - return fmt.Sprintf("%s(%s)", TokenKindString(token.Kind), token.Value) - } - - return TokenKindString(token.Kind) -} - -//nolint:funlen,cyclop -func TokenKindString(kind TokenKind) string { - switch kind { - case EOF: - return "eof" - case Number: - return "number" - case String: - return "string" - case Identifier: - return "identifier" - case OpenParen: - return "open_paren" - case CloseParen: - return "close_paren" - case Equals: - return "equals" - case NotEquals: - return "not_equals" - case Less: - return "less" - case LessEquals: - return "less_equals" - case Greater: - return "greater" - case GreaterEquals: - return "greater_equals" - case And: - return "and" - case Dot: - return "dot" - case Comma: - return "comma" - case In: - return "in" - case Not: - return "not" - case Like: - return "like" - case ILike: - return "ilike" - default: - return fmt.Sprintf("unknown(%d)", kind) - } -} - -func newUniqueToken(kind TokenKind, value string) Token { - return Token{ - kind, value, - } -} +package lexer + +import "fmt" + +type TokenKind int + +const ( + EOF TokenKind = iota + Number + String + Identifier + + // Grouping & Braces. + OpenParen + CloseParen + + // Equivilance. + Equals + NotEquals + + // Conditional. + Less + LessEquals + Greater + GreaterEquals + + // Symbols. + Dot + Comma + + // Reserved Keywords. 
+ In //nolint:varnamelen + Not + Like + ILike + And +) + +//nolint:gochecknoglobals +var reservedLu = map[string]TokenKind{ + "AND": And, + "NOT": Not, + "IN": In, + "LIKE": Like, + "ILIKE": ILike, +} + +type Token struct { + Kind TokenKind + Value string +} + +func (token Token) Debug() string { + if token.Kind == Identifier || token.Kind == Number || token.Kind == String { + return fmt.Sprintf("%s(%s)", TokenKindString(token.Kind), token.Value) + } + + return TokenKindString(token.Kind) +} + +//nolint:funlen,cyclop +func TokenKindString(kind TokenKind) string { + switch kind { + case EOF: + return "eof" + case Number: + return "number" + case String: + return "string" + case Identifier: + return "identifier" + case OpenParen: + return "open_paren" + case CloseParen: + return "close_paren" + case Equals: + return "equals" + case NotEquals: + return "not_equals" + case Less: + return "less" + case LessEquals: + return "less_equals" + case Greater: + return "greater" + case GreaterEquals: + return "greater_equals" + case And: + return "and" + case Dot: + return "dot" + case Comma: + return "comma" + case In: + return "in" + case Not: + return "not" + case Like: + return "like" + case ILike: + return "ilike" + default: + return fmt.Sprintf("unknown(%d)", kind) + } +} + +func newUniqueToken(kind TokenKind, value string) Token { + return Token{ + kind, value, + } +} diff --git a/pkg/tracking/service/query/lexer/tokenizer.go b/pkg/tracking/service/query/lexer/tokenizer.go index 9d4fa61..6fd3c64 100644 --- a/pkg/tracking/service/query/lexer/tokenizer.go +++ b/pkg/tracking/service/query/lexer/tokenizer.go @@ -1,145 +1,145 @@ -package lexer - -import ( - "fmt" - "regexp" - "strings" -) - -type regexPattern struct { - regex *regexp.Regexp - handler regexHandler -} - -type lexer struct { - patterns []regexPattern - Tokens []Token - source *string - pos int - line int -} - -type Error struct { - message string -} - -func NewLexerError(format string, a ...any) *Error { - 
return &Error{message: fmt.Sprintf(format, a...)} -} - -func (e *Error) Error() string { - return e.message -} - -func Tokenize(source *string) ([]Token, error) { - lex := createLexer(source) - - for !lex.atEOF() { - matched := false - - for _, pattern := range lex.patterns { - loc := pattern.regex.FindStringIndex(lex.remainder()) - if loc != nil && loc[0] == 0 { - pattern.handler(lex, pattern.regex) - - matched = true - - break // Exit the loop after the first match - } - } - - if !matched { - return lex.Tokens, NewLexerError("unrecognized token near '%v'", lex.remainder()) - } - } - - lex.push(newUniqueToken(EOF, "EOF")) - - return lex.Tokens, nil -} - -func (lex *lexer) advanceN(n int) { - lex.pos += n -} - -func (lex *lexer) remainder() string { - return (*lex.source)[lex.pos:] -} - -func (lex *lexer) push(token Token) { - lex.Tokens = append(lex.Tokens, token) -} - -func (lex *lexer) atEOF() bool { - return lex.pos >= len(*lex.source) -} - -func createLexer(source *string) *lexer { - return &lexer{ - pos: 0, - line: 1, - source: source, - Tokens: make([]Token, 0), - patterns: []regexPattern{ - {regexp.MustCompile(`\s+`), skipHandler}, - {regexp.MustCompile(`"[^"]*"`), stringHandler}, - {regexp.MustCompile(`'[^\']*\'`), stringHandler}, - {regexp.MustCompile("`[^`]*`"), stringHandler}, - {regexp.MustCompile(`\-?[0-9]+(\.[0-9]+)?`), numberHandler}, - {regexp.MustCompile(`[a-zA-Z_][a-zA-Z0-9_]*`), symbolHandler}, - {regexp.MustCompile(`\(`), defaultHandler(OpenParen, "(")}, - {regexp.MustCompile(`\)`), defaultHandler(CloseParen, ")")}, - {regexp.MustCompile(`!=`), defaultHandler(NotEquals, "!=")}, - {regexp.MustCompile(`=`), defaultHandler(Equals, "=")}, - {regexp.MustCompile(`<=`), defaultHandler(LessEquals, "<=")}, - {regexp.MustCompile(`<`), defaultHandler(Less, "<")}, - {regexp.MustCompile(`>=`), defaultHandler(GreaterEquals, ">=")}, - {regexp.MustCompile(`>`), defaultHandler(Greater, ">")}, - {regexp.MustCompile(`\.`), defaultHandler(Dot, ".")}, - 
{regexp.MustCompile(`,`), defaultHandler(Comma, ",")}, - }, - } -} - -type regexHandler func(lex *lexer, regex *regexp.Regexp) - -// Created a default handler which will simply create a token with the matched contents. -// This handler is used with most simple tokens. -func defaultHandler(kind TokenKind, value string) regexHandler { - return func(lex *lexer, _ *regexp.Regexp) { - lex.advanceN(len(value)) - lex.push(newUniqueToken(kind, value)) - } -} - -func stringHandler(lex *lexer, regex *regexp.Regexp) { - match := regex.FindStringIndex(lex.remainder()) - stringLiteral := lex.remainder()[match[0]:match[1]] - - lex.push(newUniqueToken(String, stringLiteral)) - lex.advanceN(len(stringLiteral)) -} - -func numberHandler(lex *lexer, regex *regexp.Regexp) { - match := regex.FindString(lex.remainder()) - lex.push(newUniqueToken(Number, match)) - lex.advanceN(len(match)) -} - -func symbolHandler(lex *lexer, regex *regexp.Regexp) { - match := regex.FindString(lex.remainder()) - keyword := strings.ToUpper(match) - - if kind, found := reservedLu[keyword]; found { - lex.push(newUniqueToken(kind, match)) - } else { - lex.push(newUniqueToken(Identifier, match)) - } - - lex.advanceN(len(match)) -} - -func skipHandler(lex *lexer, regex *regexp.Regexp) { - match := regex.FindStringIndex(lex.remainder()) - lex.advanceN(match[1]) -} +package lexer + +import ( + "fmt" + "regexp" + "strings" +) + +type regexPattern struct { + regex *regexp.Regexp + handler regexHandler +} + +type lexer struct { + patterns []regexPattern + Tokens []Token + source *string + pos int + line int +} + +type Error struct { + message string +} + +func NewLexerError(format string, a ...any) *Error { + return &Error{message: fmt.Sprintf(format, a...)} +} + +func (e *Error) Error() string { + return e.message +} + +func Tokenize(source *string) ([]Token, error) { + lex := createLexer(source) + + for !lex.atEOF() { + matched := false + + for _, pattern := range lex.patterns { + loc := 
pattern.regex.FindStringIndex(lex.remainder()) + if loc != nil && loc[0] == 0 { + pattern.handler(lex, pattern.regex) + + matched = true + + break // Exit the loop after the first match + } + } + + if !matched { + return lex.Tokens, NewLexerError("unrecognized token near '%v'", lex.remainder()) + } + } + + lex.push(newUniqueToken(EOF, "EOF")) + + return lex.Tokens, nil +} + +func (lex *lexer) advanceN(n int) { + lex.pos += n +} + +func (lex *lexer) remainder() string { + return (*lex.source)[lex.pos:] +} + +func (lex *lexer) push(token Token) { + lex.Tokens = append(lex.Tokens, token) +} + +func (lex *lexer) atEOF() bool { + return lex.pos >= len(*lex.source) +} + +func createLexer(source *string) *lexer { + return &lexer{ + pos: 0, + line: 1, + source: source, + Tokens: make([]Token, 0), + patterns: []regexPattern{ + {regexp.MustCompile(`\s+`), skipHandler}, + {regexp.MustCompile(`"[^"]*"`), stringHandler}, + {regexp.MustCompile(`'[^\']*\'`), stringHandler}, + {regexp.MustCompile("`[^`]*`"), stringHandler}, + {regexp.MustCompile(`\-?[0-9]+(\.[0-9]+)?`), numberHandler}, + {regexp.MustCompile(`[a-zA-Z_][a-zA-Z0-9_]*`), symbolHandler}, + {regexp.MustCompile(`\(`), defaultHandler(OpenParen, "(")}, + {regexp.MustCompile(`\)`), defaultHandler(CloseParen, ")")}, + {regexp.MustCompile(`!=`), defaultHandler(NotEquals, "!=")}, + {regexp.MustCompile(`=`), defaultHandler(Equals, "=")}, + {regexp.MustCompile(`<=`), defaultHandler(LessEquals, "<=")}, + {regexp.MustCompile(`<`), defaultHandler(Less, "<")}, + {regexp.MustCompile(`>=`), defaultHandler(GreaterEquals, ">=")}, + {regexp.MustCompile(`>`), defaultHandler(Greater, ">")}, + {regexp.MustCompile(`\.`), defaultHandler(Dot, ".")}, + {regexp.MustCompile(`,`), defaultHandler(Comma, ",")}, + }, + } +} + +type regexHandler func(lex *lexer, regex *regexp.Regexp) + +// Created a default handler which will simply create a token with the matched contents. +// This handler is used with most simple tokens. 
+func defaultHandler(kind TokenKind, value string) regexHandler { + return func(lex *lexer, _ *regexp.Regexp) { + lex.advanceN(len(value)) + lex.push(newUniqueToken(kind, value)) + } +} + +func stringHandler(lex *lexer, regex *regexp.Regexp) { + match := regex.FindStringIndex(lex.remainder()) + stringLiteral := lex.remainder()[match[0]:match[1]] + + lex.push(newUniqueToken(String, stringLiteral)) + lex.advanceN(len(stringLiteral)) +} + +func numberHandler(lex *lexer, regex *regexp.Regexp) { + match := regex.FindString(lex.remainder()) + lex.push(newUniqueToken(Number, match)) + lex.advanceN(len(match)) +} + +func symbolHandler(lex *lexer, regex *regexp.Regexp) { + match := regex.FindString(lex.remainder()) + keyword := strings.ToUpper(match) + + if kind, found := reservedLu[keyword]; found { + lex.push(newUniqueToken(kind, match)) + } else { + lex.push(newUniqueToken(Identifier, match)) + } + + lex.advanceN(len(match)) +} + +func skipHandler(lex *lexer, regex *regexp.Regexp) { + match := regex.FindStringIndex(lex.remainder()) + lex.advanceN(match[1]) +} diff --git a/pkg/tracking/service/query/lexer/tokenizer_test.go b/pkg/tracking/service/query/lexer/tokenizer_test.go index fa5fdf1..d4c9cea 100644 --- a/pkg/tracking/service/query/lexer/tokenizer_test.go +++ b/pkg/tracking/service/query/lexer/tokenizer_test.go @@ -1,114 +1,114 @@ -package lexer_test - -import ( - "strings" - "testing" - - "github.com/mlflow/mlflow-go/pkg/tracking/service/query/lexer" -) - -type Sample struct { - input string - expected string -} - -//nolint:lll,funlen -func TestQueries(t *testing.T) { - t.Parallel() - - samples := []Sample{ - { - input: "metrics.accuracy > 0.72", - expected: "identifier(metrics) dot identifier(accuracy) greater number(0.72) eof", - }, - { - input: "metrics.\"accuracy\" > 0.72", - expected: "identifier(metrics) dot string(\"accuracy\") greater number(0.72) eof", - }, - { - input: "metrics.accuracy > 0.72 AND metrics.loss <= 0.15", - expected: "identifier(metrics) dot 
identifier(accuracy) greater number(0.72) and identifier(metrics) dot identifier(loss) less_equals number(0.15) eof", - }, - { - input: "params.batch_size = \"2\"", - expected: "identifier(params) dot identifier(batch_size) equals string(\"2\") eof", - }, - { - input: "tags.task ILIKE \"classif%\"", - expected: "identifier(tags) dot identifier(task) ilike string(\"classif%\") eof", - }, - { - input: "datasets.digest IN ('s8ds293b', 'jks834s2')", - expected: "identifier(datasets) dot identifier(digest) in open_paren string('s8ds293b') comma string('jks834s2') close_paren eof", - }, - { - input: "attributes.created > 1664067852747", - expected: "identifier(attributes) dot identifier(created) greater number(1664067852747) eof", - }, - { - input: "params.batch_size != \"None\"", - expected: "identifier(params) dot identifier(batch_size) not_equals string(\"None\") eof", - }, - { - input: "datasets.digest NOT IN ('s8ds293b', 'jks834s2')", - expected: "identifier(datasets) dot identifier(digest) not in open_paren string('s8ds293b') comma string('jks834s2') close_paren eof", - }, - { - input: "params.`random_state` = \"8888\"", - expected: "identifier(params) dot string(`random_state`) equals string(\"8888\") eof", - }, - { - input: "metrics.measure_a != -12.0", - expected: "identifier(metrics) dot identifier(measure_a) not_equals number(-12.0) eof", - }, - } - - for _, sample := range samples { - currentSample := sample - - t.Run(currentSample.input, func(t *testing.T) { - t.Parallel() - - tokens, err := lexer.Tokenize(¤tSample.input) - if err != nil { - t.Errorf("unexpected error: %v", err) - } - - output := "" - - for _, token := range tokens { - output += " " + token.Debug() - } - - output = strings.TrimLeft(output, " ") - - if output != currentSample.expected { - t.Errorf("expected %s, got %s", currentSample.expected, output) - } - }) - } -} - -func TestInvalidInput(t *testing.T) { - t.Parallel() - - samples := []string{ - "params.'acc = LR", - "params.acc = 'LR", - 
"params.acc = LR'", - "params.acc = \"LR'", - "tags.acc = \"LR'", - } - - for _, sample := range samples { - currentSample := sample - t.Run(currentSample, func(t *testing.T) { - t.Parallel() - - _, err := lexer.Tokenize(¤tSample) - if err == nil { - t.Errorf("expected error, got nil") - } - }) - } -} +package lexer_test + +import ( + "strings" + "testing" + + "github.com/mlflow/mlflow-go/pkg/tracking/service/query/lexer" +) + +type Sample struct { + input string + expected string +} + +//nolint:lll,funlen +func TestQueries(t *testing.T) { + t.Parallel() + + samples := []Sample{ + { + input: "metrics.accuracy > 0.72", + expected: "identifier(metrics) dot identifier(accuracy) greater number(0.72) eof", + }, + { + input: "metrics.\"accuracy\" > 0.72", + expected: "identifier(metrics) dot string(\"accuracy\") greater number(0.72) eof", + }, + { + input: "metrics.accuracy > 0.72 AND metrics.loss <= 0.15", + expected: "identifier(metrics) dot identifier(accuracy) greater number(0.72) and identifier(metrics) dot identifier(loss) less_equals number(0.15) eof", + }, + { + input: "params.batch_size = \"2\"", + expected: "identifier(params) dot identifier(batch_size) equals string(\"2\") eof", + }, + { + input: "tags.task ILIKE \"classif%\"", + expected: "identifier(tags) dot identifier(task) ilike string(\"classif%\") eof", + }, + { + input: "datasets.digest IN ('s8ds293b', 'jks834s2')", + expected: "identifier(datasets) dot identifier(digest) in open_paren string('s8ds293b') comma string('jks834s2') close_paren eof", + }, + { + input: "attributes.created > 1664067852747", + expected: "identifier(attributes) dot identifier(created) greater number(1664067852747) eof", + }, + { + input: "params.batch_size != \"None\"", + expected: "identifier(params) dot identifier(batch_size) not_equals string(\"None\") eof", + }, + { + input: "datasets.digest NOT IN ('s8ds293b', 'jks834s2')", + expected: "identifier(datasets) dot identifier(digest) not in open_paren string('s8ds293b') comma 
string('jks834s2') close_paren eof", + }, + { + input: "params.`random_state` = \"8888\"", + expected: "identifier(params) dot string(`random_state`) equals string(\"8888\") eof", + }, + { + input: "metrics.measure_a != -12.0", + expected: "identifier(metrics) dot identifier(measure_a) not_equals number(-12.0) eof", + }, + } + + for _, sample := range samples { + currentSample := sample + + t.Run(currentSample.input, func(t *testing.T) { + t.Parallel() + + tokens, err := lexer.Tokenize(¤tSample.input) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + output := "" + + for _, token := range tokens { + output += " " + token.Debug() + } + + output = strings.TrimLeft(output, " ") + + if output != currentSample.expected { + t.Errorf("expected %s, got %s", currentSample.expected, output) + } + }) + } +} + +func TestInvalidInput(t *testing.T) { + t.Parallel() + + samples := []string{ + "params.'acc = LR", + "params.acc = 'LR", + "params.acc = LR'", + "params.acc = \"LR'", + "tags.acc = \"LR'", + } + + for _, sample := range samples { + currentSample := sample + t.Run(currentSample, func(t *testing.T) { + t.Parallel() + + _, err := lexer.Tokenize(¤tSample) + if err == nil { + t.Errorf("expected error, got nil") + } + }) + } +} diff --git a/pkg/tracking/service/query/parser/ast.go b/pkg/tracking/service/query/parser/ast.go index 9f2673c..6a92b31 100644 --- a/pkg/tracking/service/query/parser/ast.go +++ b/pkg/tracking/service/query/parser/ast.go @@ -1,137 +1,137 @@ -package parser - -import ( - "fmt" - "strings" -) - -// -------------------- -// Literal Expressions -// -------------------- - -type Value interface { - value() interface{} - fmt.Stringer -} - -type NumberExpr struct { - Value float64 -} - -func (n NumberExpr) value() interface{} { - return n.Value -} - -func (n NumberExpr) String() string { - return fmt.Sprintf("%f", n.Value) -} - -type StringExpr struct { - Value string -} - -func (n StringExpr) value() interface{} { - return n.Value -} - -func 
(n StringExpr) String() string { - return fmt.Sprintf("\"%s\"", n.Value) -} - -type StringListExpr struct { - Values []string -} - -func (n StringListExpr) value() interface{} { - return n.Values -} - -func (n StringListExpr) String() string { - items := make([]string, 0, len(n.Values)) - for _, v := range n.Values { - items = append(items, fmt.Sprintf("\"%s\"", v)) - } - - return strings.Join(items, ", ") -} - -//----------------------- -// Identifier Expressions -// ---------------------- - -// identifier.key expression, like metric.foo. -type Identifier struct { - Identifier string - Key string -} - -func (i Identifier) String() string { - if i.Key == "" { - return i.Identifier - } - - return fmt.Sprintf("%s.%s", i.Identifier, i.Key) -} - -// -------------------- -// Comparison Expression -// -------------------- - -type OperatorKind int - -const ( - Equals OperatorKind = iota - NotEquals - Less - LessEquals - Greater - GreaterEquals - Like - ILike - In //nolint:varnamelen - NotIn -) - -//nolint:cyclop -func (op OperatorKind) String() string { - switch op { - case Equals: - return "=" - case NotEquals: - return "!=" - case Less: - return "<" - case LessEquals: - return "<=" - case Greater: - return ">" - case GreaterEquals: - return ">=" - case Like: - return "LIKE" - case ILike: - return "ILIKE" - case In: - return "IN" - case NotIn: - return "NOT IN" - default: - return "UNKNOWN" - } -} - -// a operator b. -type CompareExpr struct { - Left Identifier - Operator OperatorKind - Right Value -} - -func (expr *CompareExpr) String() string { - return fmt.Sprintf("%s %s %s", expr.Left, expr.Operator, expr.Right) -} - -// AND. 
-type AndExpr struct { - Exprs []*CompareExpr -} +package parser + +import ( + "fmt" + "strings" +) + +// -------------------- +// Literal Expressions +// -------------------- + +type Value interface { + value() interface{} + fmt.Stringer +} + +type NumberExpr struct { + Value float64 +} + +func (n NumberExpr) value() interface{} { + return n.Value +} + +func (n NumberExpr) String() string { + return fmt.Sprintf("%f", n.Value) +} + +type StringExpr struct { + Value string +} + +func (n StringExpr) value() interface{} { + return n.Value +} + +func (n StringExpr) String() string { + return fmt.Sprintf("\"%s\"", n.Value) +} + +type StringListExpr struct { + Values []string +} + +func (n StringListExpr) value() interface{} { + return n.Values +} + +func (n StringListExpr) String() string { + items := make([]string, 0, len(n.Values)) + for _, v := range n.Values { + items = append(items, fmt.Sprintf("\"%s\"", v)) + } + + return strings.Join(items, ", ") +} + +//----------------------- +// Identifier Expressions +// ---------------------- + +// identifier.key expression, like metric.foo. +type Identifier struct { + Identifier string + Key string +} + +func (i Identifier) String() string { + if i.Key == "" { + return i.Identifier + } + + return fmt.Sprintf("%s.%s", i.Identifier, i.Key) +} + +// -------------------- +// Comparison Expression +// -------------------- + +type OperatorKind int + +const ( + Equals OperatorKind = iota + NotEquals + Less + LessEquals + Greater + GreaterEquals + Like + ILike + In //nolint:varnamelen + NotIn +) + +//nolint:cyclop +func (op OperatorKind) String() string { + switch op { + case Equals: + return "=" + case NotEquals: + return "!=" + case Less: + return "<" + case LessEquals: + return "<=" + case Greater: + return ">" + case GreaterEquals: + return ">=" + case Like: + return "LIKE" + case ILike: + return "ILIKE" + case In: + return "IN" + case NotIn: + return "NOT IN" + default: + return "UNKNOWN" + } +} + +// a operator b. 
+type CompareExpr struct { + Left Identifier + Operator OperatorKind + Right Value +} + +func (expr *CompareExpr) String() string { + return fmt.Sprintf("%s %s %s", expr.Left, expr.Operator, expr.Right) +} + +// AND. +type AndExpr struct { + Exprs []*CompareExpr +} diff --git a/pkg/tracking/service/query/parser/parser.go b/pkg/tracking/service/query/parser/parser.go index 6bc39de..978aa29 100644 --- a/pkg/tracking/service/query/parser/parser.go +++ b/pkg/tracking/service/query/parser/parser.go @@ -1,265 +1,265 @@ -package parser - -import ( - "fmt" - "strconv" - - "github.com/mlflow/mlflow-go/pkg/tracking/service/query/lexer" -) - -type parser struct { - tokens []lexer.Token - pos int -} - -func newParser(tokens []lexer.Token) *parser { - return &parser{ - tokens: tokens, - pos: 0, - } -} - -func (p *parser) currentTokenKind() lexer.TokenKind { - return p.tokens[p.pos].Kind -} - -func (p *parser) hasTokens() bool { - return p.pos < len(p.tokens) && p.currentTokenKind() != lexer.EOF -} - -func (p *parser) printCurrentToken() string { - return p.tokens[p.pos].Debug() -} - -func (p *parser) currentToken() lexer.Token { - return p.tokens[p.pos] -} - -func (p *parser) advance() lexer.Token { - tk := p.currentToken() - p.pos++ - - return tk -} - -type Error struct { - message string -} - -func NewParserError(format string, a ...any) *Error { - return &Error{message: fmt.Sprintf(format, a...)} -} - -func (e *Error) Error() string { - return e.message -} - -func (p *parser) parseIdentifier() (Identifier, error) { - emptyIdentifier := Identifier{Identifier: "", Key: ""} - if p.hasTokens() && p.currentTokenKind() != lexer.Identifier { - return emptyIdentifier, NewParserError( - "expected identifier, got %s", - p.printCurrentToken(), - ) - } - - identToken := p.advance() - - if p.currentTokenKind() == lexer.Dot { - p.advance() // Consume the DOT - //nolint:exhaustive - switch p.currentTokenKind() { - case lexer.Identifier: - column := p.advance().Value - - return 
Identifier{Identifier: identToken.Value, Key: column}, nil - case lexer.String: - column := p.advance().Value - column = column[1 : len(column)-1] // Remove quotes - - return Identifier{Identifier: identToken.Value, Key: column}, nil - default: - return emptyIdentifier, NewParserError( - "expected IDENTIFIER or STRING, got %s", - p.printCurrentToken(), - ) - } - } else { - return Identifier{Identifier: "", Key: identToken.Value}, nil - } -} - -func (p *parser) parseOperator() (OperatorKind, error) { - //nolint:exhaustive - switch p.advance().Kind { - case lexer.Equals: - return Equals, nil - case lexer.NotEquals: - return NotEquals, nil - case lexer.Less: - return Less, nil - case lexer.LessEquals: - return LessEquals, nil - case lexer.Greater: - return Greater, nil - case lexer.GreaterEquals: - return GreaterEquals, nil - case lexer.Like: - return Like, nil - case lexer.ILike: - return ILike, nil - default: - return -1, NewParserError("expected operator, got %s", p.printCurrentToken()) - } -} - -//nolint:ireturn -func (p *parser) parseValue() (Value, error) { - //nolint:exhaustive - switch p.currentTokenKind() { - case lexer.Number: - n, err := strconv.ParseFloat(p.advance().Value, 64) - if err != nil { - return nil, fmt.Errorf("failed to parse number token to float: %w", err) - } - - return NumberExpr{Value: n}, nil - case lexer.String: - value := p.advance().Value - value = value[1 : len(value)-1] // Remove quotes - - return StringExpr{Value: value}, nil - default: - return nil, NewParserError( - "Expected NUMBER or STRING, got %s", - p.printCurrentToken(), - ) - } -} - -func (p *parser) parseInSetExpr(ident Identifier) (*CompareExpr, error) { - if p.currentTokenKind() != lexer.OpenParen { - return nil, NewParserError( - "expected '(', got %s", - p.printCurrentToken(), - ) - } - - p.advance() // Consume the OPEN_PAREN - - set := make([]string, 0) - - for p.hasTokens() && p.currentTokenKind() != lexer.CloseParen { - if p.currentTokenKind() != lexer.String { - 
return nil, NewParserError( - "expected STRING, got %s", - p.printCurrentToken(), - ) - } - - value := p.advance().Value - value = value[1 : len(value)-1] // Remove quotes - - set = append(set, value) - - if p.currentTokenKind() == lexer.Comma { - p.advance() // Consume the COMMA - } - } - - if p.currentTokenKind() != lexer.CloseParen { - return nil, NewParserError( - "expected ')', got %s", - p.printCurrentToken(), - ) - } - - p.advance() // Consume the CLOSE_PAREN - - return &CompareExpr{Left: ident, Operator: In, Right: StringListExpr{Values: set}}, nil -} - -func (p *parser) parseExpression() (*CompareExpr, error) { - ident, err := p.parseIdentifier() - if err != nil { - return nil, err - } - - //nolint:exhaustive - switch p.currentTokenKind() { - case lexer.In: - p.advance() // Consume the IN - - return p.parseInSetExpr(ident) - case lexer.Not: - p.advance() // Consume the NOT - - if p.currentTokenKind() != lexer.In { - return nil, NewParserError( - "expected IN after NOT, got %s", - p.printCurrentToken(), - ) - } - - p.advance() // Consume the IN - - expr, err := p.parseInSetExpr(ident) - if err != nil { - return nil, err - } - - expr.Operator = NotIn - - return expr, nil - default: - operator, err := p.parseOperator() - if err != nil { - return nil, err - } - - value, err := p.parseValue() - if err != nil { - return nil, err - } - - return &CompareExpr{Left: ident, Operator: operator, Right: value}, nil - } -} - -func (p *parser) parse() (*AndExpr, error) { - exprs := make([]*CompareExpr, 0) - - leftExpr, err := p.parseExpression() - if err != nil { - return nil, fmt.Errorf("error while parsing initial expression: %w", err) - } - - exprs = append(exprs, leftExpr) - - // While there are tokens and the next token is AND - for p.currentTokenKind() == lexer.And { - p.advance() // Consume the AND - - rightExpr, err := p.parseExpression() - if err != nil { - return nil, err - } - - exprs = append(exprs, rightExpr) - } - - if p.hasTokens() { - return nil, 
NewParserError( - "unexpected leftover token(s) after parsing: %s", - p.printCurrentToken(), - ) - } - - return &AndExpr{Exprs: exprs}, nil -} - -func Parse(tokens []lexer.Token) (*AndExpr, error) { - parser := newParser(tokens) - - return parser.parse() -} +package parser + +import ( + "fmt" + "strconv" + + "github.com/mlflow/mlflow-go/pkg/tracking/service/query/lexer" +) + +type parser struct { + tokens []lexer.Token + pos int +} + +func newParser(tokens []lexer.Token) *parser { + return &parser{ + tokens: tokens, + pos: 0, + } +} + +func (p *parser) currentTokenKind() lexer.TokenKind { + return p.tokens[p.pos].Kind +} + +func (p *parser) hasTokens() bool { + return p.pos < len(p.tokens) && p.currentTokenKind() != lexer.EOF +} + +func (p *parser) printCurrentToken() string { + return p.tokens[p.pos].Debug() +} + +func (p *parser) currentToken() lexer.Token { + return p.tokens[p.pos] +} + +func (p *parser) advance() lexer.Token { + tk := p.currentToken() + p.pos++ + + return tk +} + +type Error struct { + message string +} + +func NewParserError(format string, a ...any) *Error { + return &Error{message: fmt.Sprintf(format, a...)} +} + +func (e *Error) Error() string { + return e.message +} + +func (p *parser) parseIdentifier() (Identifier, error) { + emptyIdentifier := Identifier{Identifier: "", Key: ""} + if p.hasTokens() && p.currentTokenKind() != lexer.Identifier { + return emptyIdentifier, NewParserError( + "expected identifier, got %s", + p.printCurrentToken(), + ) + } + + identToken := p.advance() + + if p.currentTokenKind() == lexer.Dot { + p.advance() // Consume the DOT + //nolint:exhaustive + switch p.currentTokenKind() { + case lexer.Identifier: + column := p.advance().Value + + return Identifier{Identifier: identToken.Value, Key: column}, nil + case lexer.String: + column := p.advance().Value + column = column[1 : len(column)-1] // Remove quotes + + return Identifier{Identifier: identToken.Value, Key: column}, nil + default: + return emptyIdentifier, 
NewParserError( + "expected IDENTIFIER or STRING, got %s", + p.printCurrentToken(), + ) + } + } else { + return Identifier{Identifier: "", Key: identToken.Value}, nil + } +} + +func (p *parser) parseOperator() (OperatorKind, error) { + //nolint:exhaustive + switch p.advance().Kind { + case lexer.Equals: + return Equals, nil + case lexer.NotEquals: + return NotEquals, nil + case lexer.Less: + return Less, nil + case lexer.LessEquals: + return LessEquals, nil + case lexer.Greater: + return Greater, nil + case lexer.GreaterEquals: + return GreaterEquals, nil + case lexer.Like: + return Like, nil + case lexer.ILike: + return ILike, nil + default: + return -1, NewParserError("expected operator, got %s", p.printCurrentToken()) + } +} + +//nolint:ireturn +func (p *parser) parseValue() (Value, error) { + //nolint:exhaustive + switch p.currentTokenKind() { + case lexer.Number: + n, err := strconv.ParseFloat(p.advance().Value, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse number token to float: %w", err) + } + + return NumberExpr{Value: n}, nil + case lexer.String: + value := p.advance().Value + value = value[1 : len(value)-1] // Remove quotes + + return StringExpr{Value: value}, nil + default: + return nil, NewParserError( + "Expected NUMBER or STRING, got %s", + p.printCurrentToken(), + ) + } +} + +func (p *parser) parseInSetExpr(ident Identifier) (*CompareExpr, error) { + if p.currentTokenKind() != lexer.OpenParen { + return nil, NewParserError( + "expected '(', got %s", + p.printCurrentToken(), + ) + } + + p.advance() // Consume the OPEN_PAREN + + set := make([]string, 0) + + for p.hasTokens() && p.currentTokenKind() != lexer.CloseParen { + if p.currentTokenKind() != lexer.String { + return nil, NewParserError( + "expected STRING, got %s", + p.printCurrentToken(), + ) + } + + value := p.advance().Value + value = value[1 : len(value)-1] // Remove quotes + + set = append(set, value) + + if p.currentTokenKind() == lexer.Comma { + p.advance() // Consume the 
COMMA + } + } + + if p.currentTokenKind() != lexer.CloseParen { + return nil, NewParserError( + "expected ')', got %s", + p.printCurrentToken(), + ) + } + + p.advance() // Consume the CLOSE_PAREN + + return &CompareExpr{Left: ident, Operator: In, Right: StringListExpr{Values: set}}, nil +} + +func (p *parser) parseExpression() (*CompareExpr, error) { + ident, err := p.parseIdentifier() + if err != nil { + return nil, err + } + + //nolint:exhaustive + switch p.currentTokenKind() { + case lexer.In: + p.advance() // Consume the IN + + return p.parseInSetExpr(ident) + case lexer.Not: + p.advance() // Consume the NOT + + if p.currentTokenKind() != lexer.In { + return nil, NewParserError( + "expected IN after NOT, got %s", + p.printCurrentToken(), + ) + } + + p.advance() // Consume the IN + + expr, err := p.parseInSetExpr(ident) + if err != nil { + return nil, err + } + + expr.Operator = NotIn + + return expr, nil + default: + operator, err := p.parseOperator() + if err != nil { + return nil, err + } + + value, err := p.parseValue() + if err != nil { + return nil, err + } + + return &CompareExpr{Left: ident, Operator: operator, Right: value}, nil + } +} + +func (p *parser) parse() (*AndExpr, error) { + exprs := make([]*CompareExpr, 0) + + leftExpr, err := p.parseExpression() + if err != nil { + return nil, fmt.Errorf("error while parsing initial expression: %w", err) + } + + exprs = append(exprs, leftExpr) + + // While there are tokens and the next token is AND + for p.currentTokenKind() == lexer.And { + p.advance() // Consume the AND + + rightExpr, err := p.parseExpression() + if err != nil { + return nil, err + } + + exprs = append(exprs, rightExpr) + } + + if p.hasTokens() { + return nil, NewParserError( + "unexpected leftover token(s) after parsing: %s", + p.printCurrentToken(), + ) + } + + return &AndExpr{Exprs: exprs}, nil +} + +func Parse(tokens []lexer.Token) (*AndExpr, error) { + parser := newParser(tokens) + + return parser.parse() +} diff --git 
a/pkg/tracking/service/query/parser/parser_test.go b/pkg/tracking/service/query/parser/parser_test.go index 64129c3..9b99094 100644 --- a/pkg/tracking/service/query/parser/parser_test.go +++ b/pkg/tracking/service/query/parser/parser_test.go @@ -1,182 +1,182 @@ -package parser_test - -import ( - "reflect" - "testing" - - "github.com/mlflow/mlflow-go/pkg/tracking/service/query/lexer" - "github.com/mlflow/mlflow-go/pkg/tracking/service/query/parser" -) - -type Sample struct { - input string - expected *parser.AndExpr -} - -//nolint:funlen -func TestQueries(t *testing.T) { - t.Parallel() - - samples := []Sample{ - { - input: "metrics.accuracy > 0.72", - expected: &parser.AndExpr{ - Exprs: []*parser.CompareExpr{ - { - Left: parser.Identifier{"metrics", "accuracy"}, - Operator: parser.Greater, - Right: parser.NumberExpr{Value: 0.72}, - }, - }, - }, - }, - { - input: "metrics.\"accuracy\" > 0.72", - expected: &parser.AndExpr{ - Exprs: []*parser.CompareExpr{ - { - Left: parser.Identifier{"metrics", "accuracy"}, - Operator: parser.Greater, - Right: parser.NumberExpr{Value: 0.72}, - }, - }, - }, - }, - { - input: "metrics.accuracy > 0.72 AND metrics.loss <= 0.15", - expected: &parser.AndExpr{ - Exprs: []*parser.CompareExpr{ - { - Left: parser.Identifier{"metrics", "accuracy"}, - Operator: parser.Greater, - Right: parser.NumberExpr{Value: 0.72}, - }, - { - Left: parser.Identifier{"metrics", "loss"}, - Operator: parser.LessEquals, - Right: parser.NumberExpr{Value: 0.15}, - }, - }, - }, - }, - { - input: "params.batch_size = \"2\"", - expected: &parser.AndExpr{ - Exprs: []*parser.CompareExpr{ - { - Left: parser.Identifier{"params", "batch_size"}, - Operator: parser.Equals, - Right: parser.StringExpr{Value: "2"}, - }, - }, - }, - }, - { - input: "tags.task ILIKE \"classif%\"", - expected: &parser.AndExpr{ - Exprs: []*parser.CompareExpr{ - { - Left: parser.Identifier{"tags", "task"}, - Operator: parser.ILike, - Right: parser.StringExpr{Value: "classif%"}, - }, - }, - }, - }, - { 
- input: "datasets.digest IN ('s8ds293b', 'jks834s2')", - expected: &parser.AndExpr{ - Exprs: []*parser.CompareExpr{ - { - Left: parser.Identifier{"datasets", "digest"}, - Operator: parser.In, - Right: parser.StringListExpr{Values: []string{"s8ds293b", "jks834s2"}}, - }, - }, - }, - }, - { - input: "attributes.created > 1664067852747", - expected: &parser.AndExpr{ - []*parser.CompareExpr{ - { - Left: parser.Identifier{"attributes", "created"}, - Operator: parser.Greater, - Right: parser.NumberExpr{Value: 1664067852747}, - }, - }, - }, - }, - { - input: "params.batch_size != \"None\"", - expected: &parser.AndExpr{ - Exprs: []*parser.CompareExpr{ - { - Left: parser.Identifier{"params", "batch_size"}, - Operator: parser.NotEquals, - Right: parser.StringExpr{Value: "None"}, - }, - }, - }, - }, - { - input: "datasets.digest NOT IN ('s8ds293b', 'jks834s2')", - expected: &parser.AndExpr{ - Exprs: []*parser.CompareExpr{ - { - Left: parser.Identifier{"datasets", "digest"}, - Operator: parser.NotIn, - Right: parser.StringListExpr{Values: []string{"s8ds293b", "jks834s2"}}, - }, - }, - }, - }, - } - - for _, sample := range samples { - currentSample := sample - - t.Run(currentSample.input, func(t *testing.T) { - t.Parallel() - - tokens, err := lexer.Tokenize(¤tSample.input) - if err != nil { - t.Errorf("unexpected lex error: %v", err) - } - - ast, err := parser.Parse(tokens) - if err != nil { - t.Errorf("error parsing: %s", err) - } - - if !reflect.DeepEqual(ast, currentSample.expected) { - t.Errorf("expected %#v, got %#v", currentSample.expected, ast) - } - }) - } -} - -func TestInvalidSyntax(t *testing.T) { - t.Parallel() - - samples := []string{ - "attribute.status IS 'RUNNING'", - } - - for _, sample := range samples { - currentSample := sample - t.Run(currentSample, func(t *testing.T) { - t.Parallel() - - tokens, err := lexer.Tokenize(¤tSample) - if err != nil { - t.Errorf("unexpected lex error: %v", err) - } - - _, err = parser.Parse(tokens) - if err == nil { - 
t.Errorf("expected parse error, got nil") - } - }) - } -} +package parser_test + +import ( + "reflect" + "testing" + + "github.com/mlflow/mlflow-go/pkg/tracking/service/query/lexer" + "github.com/mlflow/mlflow-go/pkg/tracking/service/query/parser" +) + +type Sample struct { + input string + expected *parser.AndExpr +} + +//nolint:funlen +func TestQueries(t *testing.T) { + t.Parallel() + + samples := []Sample{ + { + input: "metrics.accuracy > 0.72", + expected: &parser.AndExpr{ + Exprs: []*parser.CompareExpr{ + { + Left: parser.Identifier{"metrics", "accuracy"}, + Operator: parser.Greater, + Right: parser.NumberExpr{Value: 0.72}, + }, + }, + }, + }, + { + input: "metrics.\"accuracy\" > 0.72", + expected: &parser.AndExpr{ + Exprs: []*parser.CompareExpr{ + { + Left: parser.Identifier{"metrics", "accuracy"}, + Operator: parser.Greater, + Right: parser.NumberExpr{Value: 0.72}, + }, + }, + }, + }, + { + input: "metrics.accuracy > 0.72 AND metrics.loss <= 0.15", + expected: &parser.AndExpr{ + Exprs: []*parser.CompareExpr{ + { + Left: parser.Identifier{"metrics", "accuracy"}, + Operator: parser.Greater, + Right: parser.NumberExpr{Value: 0.72}, + }, + { + Left: parser.Identifier{"metrics", "loss"}, + Operator: parser.LessEquals, + Right: parser.NumberExpr{Value: 0.15}, + }, + }, + }, + }, + { + input: "params.batch_size = \"2\"", + expected: &parser.AndExpr{ + Exprs: []*parser.CompareExpr{ + { + Left: parser.Identifier{"params", "batch_size"}, + Operator: parser.Equals, + Right: parser.StringExpr{Value: "2"}, + }, + }, + }, + }, + { + input: "tags.task ILIKE \"classif%\"", + expected: &parser.AndExpr{ + Exprs: []*parser.CompareExpr{ + { + Left: parser.Identifier{"tags", "task"}, + Operator: parser.ILike, + Right: parser.StringExpr{Value: "classif%"}, + }, + }, + }, + }, + { + input: "datasets.digest IN ('s8ds293b', 'jks834s2')", + expected: &parser.AndExpr{ + Exprs: []*parser.CompareExpr{ + { + Left: parser.Identifier{"datasets", "digest"}, + Operator: parser.In, + Right: 
parser.StringListExpr{Values: []string{"s8ds293b", "jks834s2"}}, + }, + }, + }, + }, + { + input: "attributes.created > 1664067852747", + expected: &parser.AndExpr{ + []*parser.CompareExpr{ + { + Left: parser.Identifier{"attributes", "created"}, + Operator: parser.Greater, + Right: parser.NumberExpr{Value: 1664067852747}, + }, + }, + }, + }, + { + input: "params.batch_size != \"None\"", + expected: &parser.AndExpr{ + Exprs: []*parser.CompareExpr{ + { + Left: parser.Identifier{"params", "batch_size"}, + Operator: parser.NotEquals, + Right: parser.StringExpr{Value: "None"}, + }, + }, + }, + }, + { + input: "datasets.digest NOT IN ('s8ds293b', 'jks834s2')", + expected: &parser.AndExpr{ + Exprs: []*parser.CompareExpr{ + { + Left: parser.Identifier{"datasets", "digest"}, + Operator: parser.NotIn, + Right: parser.StringListExpr{Values: []string{"s8ds293b", "jks834s2"}}, + }, + }, + }, + }, + } + + for _, sample := range samples { + currentSample := sample + + t.Run(currentSample.input, func(t *testing.T) { + t.Parallel() + + tokens, err := lexer.Tokenize(¤tSample.input) + if err != nil { + t.Errorf("unexpected lex error: %v", err) + } + + ast, err := parser.Parse(tokens) + if err != nil { + t.Errorf("error parsing: %s", err) + } + + if !reflect.DeepEqual(ast, currentSample.expected) { + t.Errorf("expected %#v, got %#v", currentSample.expected, ast) + } + }) + } +} + +func TestInvalidSyntax(t *testing.T) { + t.Parallel() + + samples := []string{ + "attribute.status IS 'RUNNING'", + } + + for _, sample := range samples { + currentSample := sample + t.Run(currentSample, func(t *testing.T) { + t.Parallel() + + tokens, err := lexer.Tokenize(¤tSample) + if err != nil { + t.Errorf("unexpected lex error: %v", err) + } + + _, err = parser.Parse(tokens) + if err == nil { + t.Errorf("expected parse error, got nil") + } + }) + } +} diff --git a/pkg/tracking/service/query/parser/validate.go b/pkg/tracking/service/query/parser/validate.go index d2292c0..011589a 100644 --- 
a/pkg/tracking/service/query/parser/validate.go +++ b/pkg/tracking/service/query/parser/validate.go @@ -1,329 +1,329 @@ -package parser - -import ( - "errors" - "fmt" - "strings" - - "github.com/mlflow/mlflow-go/pkg/contract" - "github.com/mlflow/mlflow-go/pkg/protos" -) - -/* - -This is the equivalent of type-checking the untyped tree. -Not every parsed tree is a valid one. - -Grammar rule: identifier.key operator value - -The rules are: - -For identifiers: - -identifier.key - -Or if only key is passed, the identifier is "attribute" - -Identifiers can have aliases. - -if the identifier is dataset, the allowed keys are: name, digest and context. - -*/ - -type ValidIdentifier int - -const ( - Metric ValidIdentifier = iota - Parameter - Tag - Attribute - Dataset -) - -func (v ValidIdentifier) String() string { - switch v { - case Metric: - return "metric" - case Parameter: - return "parameter" - case Tag: - return "tag" - case Attribute: - return "attribute" - case Dataset: - return "dataset" - default: - return "unknown" - } -} - -type ValidCompareExpr struct { - Identifier ValidIdentifier - Key string - Operator OperatorKind - Value interface{} -} - -func (v ValidCompareExpr) String() string { - return fmt.Sprintf("%s.%s %s %v", v.Identifier, v.Key, v.Operator, v.Value) -} - -type ValidationError struct { - message string -} - -func (e *ValidationError) Error() string { - return e.message -} - -func NewValidationError(format string, a ...interface{}) *ValidationError { - return &ValidationError{message: fmt.Sprintf(format, a...)} -} - -const ( - metricIdentifier = "metric" - parameterIdentifier = "parameter" - tagIdentifier = "tag" - attributeIdentifier = "attribute" - datasetIdentifier = "dataset" -) - -var identifiers = []string{ - metricIdentifier, - parameterIdentifier, - tagIdentifier, - attributeIdentifier, - datasetIdentifier, -} - -func parseValidIdentifier(identifier string) (ValidIdentifier, error) { - switch identifier { - case metricIdentifier, 
"metrics": - return Metric, nil - case parameterIdentifier, "parameters", "param", "params": - return Parameter, nil - case tagIdentifier, "tags": - return Tag, nil - case "", attributeIdentifier, "attr", "attributes", "run": - return Attribute, nil - case datasetIdentifier, "datasets": - return Dataset, nil - default: - return -1, NewValidationError("invalid identifier %q", identifier) - } -} - -const ( - RunID = "run_id" - RunName = "run_name" - Created = "created" - StartTime = "start_time" -) - -// This should be configurable and only applies to the runs table. -var searchableRunAttributes = []string{ - RunID, - RunName, - "user_id", - "status", - StartTime, - "end_time", - "artifact_uri", -} - -var datasetAttributes = []string{"name", "digest", "context"} - -func parseAttributeKey(key string) (string, error) { - switch key { - case "run_id": - // We return run_uuid before that is the SQL column name. - return "run_uuid", nil - case - "user_id", - "status", - StartTime, - "end_time", - "artifact_uri": - return key, nil - case Created, "Created": - return StartTime, nil - case RunName, "run name", "Run name", "Run Name": - return RunName, nil - default: - return "", contract.NewError(protos.ErrorCode_BAD_REQUEST, - fmt.Sprintf( - "Invalid attribute key '{%s}' specified. Valid keys are '%v'", - key, - searchableRunAttributes, - ), - ) - } -} - -func parseKey(identifier ValidIdentifier, key string) (string, error) { - if key == "" { - return attributeIdentifier, nil - } - - //nolint:exhaustive - switch identifier { - case Attribute: - return parseAttributeKey(key) - case Dataset: - switch key { - case "name", "digest", "context": - return key, nil - default: - return "", contract.NewError(protos.ErrorCode_BAD_REQUEST, - fmt.Sprintf( - "Invalid dataset key '{%s}' specified. Valid keys are '%v'", - key, - searchableRunAttributes, - ), - ) - } - default: - return key, nil - } -} - -// Returns a standardized LongIdentifierExpr. 
-func validatedIdentifier(identifier *Identifier) (ValidIdentifier, string, error) { - validIdentifier, err := parseValidIdentifier(identifier.Identifier) - if err != nil { - return -1, "", err - } - - validKey, err := parseKey(validIdentifier, identifier.Key) - if err != nil { - return -1, "", err - } - - identifier.Key = validKey - - return validIdentifier, validKey, nil -} - -/* - -The value part is determined by the identifier - -"metric" takes numbers -"parameter" and "tag" takes strings - -"attribute" could be either string or number, -number when StartTime, "end_time" or "created", "Created" -otherwise string - -"dataset" takes strings for "name", "digest" and "context" - -*/ - -func validateDatasetValue(key string, value Value) (interface{}, error) { - switch key { - case "name", "digest", "context": - if _, ok := value.(NumberExpr); ok { - return nil, NewValidationError( - "expected datasets.%s to be either a string or list of strings. Found %s", - key, - value, - ) - } - - return value.value(), nil - default: - return nil, NewValidationError( - "expected dataset attribute key to be one of %s. Found %s", - strings.Join(datasetAttributes, ", "), - key, - ) - } -} - -// Port of _get_value in search_utils.py. -func validateValue(identifier ValidIdentifier, key string, value Value) (interface{}, error) { - switch identifier { - case Metric: - if _, ok := value.(NumberExpr); !ok { - return nil, NewValidationError( - "expected numeric value type for metric. Found %s", - value, - ) - } - - return value.value(), nil - case Parameter, Tag: - if _, ok := value.(StringExpr); !ok { - return nil, NewValidationError( - "expected a quoted string value for %s. Found %s", - identifier, value, - ) - } - - return value.value(), nil - case Attribute: - value, err := validateAttributeValue(key, value) - - return value, err - case Dataset: - return validateDatasetValue(key, value) - default: - return nil, NewValidationError( - "Invalid identifier type %s. 
Expected one of %s", - identifier, - strings.Join(identifiers, ", "), - ) - } -} - -func validateAttributeValue(key string, value Value) (interface{}, error) { - switch key { - case StartTime, "end_time", Created: - if _, ok := value.(NumberExpr); !ok { - return nil, NewValidationError( - "expected numeric value type for numeric attribute: %s. Found %s", - key, - value, - ) - } - - return value.value(), nil - default: - // run_id was earlier converted to run_uuid - if _, ok := value.(StringListExpr); key != "run_uuid" && ok { - return nil, NewValidationError( - "only the 'run_id' attribute supports comparison with a list of quoted string values", - ) - } - - return value.value(), nil - } -} - -// Validate an expression according to the mlflow domain. -// This represent is a simple type-checker for the expression. -// Not every identifier is valid according to the mlflow domain. -// The same for the value part. -func ValidateExpression(expression *CompareExpr) (*ValidCompareExpr, error) { - validIdentifier, validKey, err := validatedIdentifier(&expression.Left) - if err != nil { - var contractError *contract.Error - if errors.As(err, &contractError) { - return nil, contractError - } - - return nil, fmt.Errorf("Error on parsing filter expression: %w", err) - } - - value, err := validateValue(validIdentifier, validKey, expression.Right) - if err != nil { - return nil, fmt.Errorf("Error on parsing filter expression: %w", err) - } - - return &ValidCompareExpr{ - Identifier: validIdentifier, - Key: validKey, - Operator: expression.Operator, - Value: value, - }, nil -} +package parser + +import ( + "errors" + "fmt" + "strings" + + "github.com/mlflow/mlflow-go/pkg/contract" + "github.com/mlflow/mlflow-go/pkg/protos" +) + +/* + +This is the equivalent of type-checking the untyped tree. +Not every parsed tree is a valid one. 
+ +Grammar rule: identifier.key operator value + +The rules are: + +For identifiers: + +identifier.key + +Or if only key is passed, the identifier is "attribute" + +Identifiers can have aliases. + +if the identifier is dataset, the allowed keys are: name, digest and context. + +*/ + +type ValidIdentifier int + +const ( + Metric ValidIdentifier = iota + Parameter + Tag + Attribute + Dataset +) + +func (v ValidIdentifier) String() string { + switch v { + case Metric: + return "metric" + case Parameter: + return "parameter" + case Tag: + return "tag" + case Attribute: + return "attribute" + case Dataset: + return "dataset" + default: + return "unknown" + } +} + +type ValidCompareExpr struct { + Identifier ValidIdentifier + Key string + Operator OperatorKind + Value interface{} +} + +func (v ValidCompareExpr) String() string { + return fmt.Sprintf("%s.%s %s %v", v.Identifier, v.Key, v.Operator, v.Value) +} + +type ValidationError struct { + message string +} + +func (e *ValidationError) Error() string { + return e.message +} + +func NewValidationError(format string, a ...interface{}) *ValidationError { + return &ValidationError{message: fmt.Sprintf(format, a...)} +} + +const ( + metricIdentifier = "metric" + parameterIdentifier = "parameter" + tagIdentifier = "tag" + attributeIdentifier = "attribute" + datasetIdentifier = "dataset" +) + +var identifiers = []string{ + metricIdentifier, + parameterIdentifier, + tagIdentifier, + attributeIdentifier, + datasetIdentifier, +} + +func parseValidIdentifier(identifier string) (ValidIdentifier, error) { + switch identifier { + case metricIdentifier, "metrics": + return Metric, nil + case parameterIdentifier, "parameters", "param", "params": + return Parameter, nil + case tagIdentifier, "tags": + return Tag, nil + case "", attributeIdentifier, "attr", "attributes", "run": + return Attribute, nil + case datasetIdentifier, "datasets": + return Dataset, nil + default: + return -1, NewValidationError("invalid identifier %q", 
identifier) + } +} + +const ( + RunID = "run_id" + RunName = "run_name" + Created = "created" + StartTime = "start_time" +) + +// This should be configurable and only applies to the runs table. +var searchableRunAttributes = []string{ + RunID, + RunName, + "user_id", + "status", + StartTime, + "end_time", + "artifact_uri", +} + +var datasetAttributes = []string{"name", "digest", "context"} + +func parseAttributeKey(key string) (string, error) { + switch key { + case "run_id": + // We return run_uuid before that is the SQL column name. + return "run_uuid", nil + case + "user_id", + "status", + StartTime, + "end_time", + "artifact_uri": + return key, nil + case Created, "Created": + return StartTime, nil + case RunName, "run name", "Run name", "Run Name": + return RunName, nil + default: + return "", contract.NewError(protos.ErrorCode_BAD_REQUEST, + fmt.Sprintf( + "Invalid attribute key '{%s}' specified. Valid keys are '%v'", + key, + searchableRunAttributes, + ), + ) + } +} + +func parseKey(identifier ValidIdentifier, key string) (string, error) { + if key == "" { + return attributeIdentifier, nil + } + + //nolint:exhaustive + switch identifier { + case Attribute: + return parseAttributeKey(key) + case Dataset: + switch key { + case "name", "digest", "context": + return key, nil + default: + return "", contract.NewError(protos.ErrorCode_BAD_REQUEST, + fmt.Sprintf( + "Invalid dataset key '{%s}' specified. Valid keys are '%v'", + key, + searchableRunAttributes, + ), + ) + } + default: + return key, nil + } +} + +// Returns a standardized LongIdentifierExpr. 
+func validatedIdentifier(identifier *Identifier) (ValidIdentifier, string, error) { + validIdentifier, err := parseValidIdentifier(identifier.Identifier) + if err != nil { + return -1, "", err + } + + validKey, err := parseKey(validIdentifier, identifier.Key) + if err != nil { + return -1, "", err + } + + identifier.Key = validKey + + return validIdentifier, validKey, nil +} + +/* + +The value part is determined by the identifier + +"metric" takes numbers +"parameter" and "tag" takes strings + +"attribute" could be either string or number, +number when StartTime, "end_time" or "created", "Created" +otherwise string + +"dataset" takes strings for "name", "digest" and "context" + +*/ + +func validateDatasetValue(key string, value Value) (interface{}, error) { + switch key { + case "name", "digest", "context": + if _, ok := value.(NumberExpr); ok { + return nil, NewValidationError( + "expected datasets.%s to be either a string or list of strings. Found %s", + key, + value, + ) + } + + return value.value(), nil + default: + return nil, NewValidationError( + "expected dataset attribute key to be one of %s. Found %s", + strings.Join(datasetAttributes, ", "), + key, + ) + } +} + +// Port of _get_value in search_utils.py. +func validateValue(identifier ValidIdentifier, key string, value Value) (interface{}, error) { + switch identifier { + case Metric: + if _, ok := value.(NumberExpr); !ok { + return nil, NewValidationError( + "expected numeric value type for metric. Found %s", + value, + ) + } + + return value.value(), nil + case Parameter, Tag: + if _, ok := value.(StringExpr); !ok { + return nil, NewValidationError( + "expected a quoted string value for %s. Found %s", + identifier, value, + ) + } + + return value.value(), nil + case Attribute: + value, err := validateAttributeValue(key, value) + + return value, err + case Dataset: + return validateDatasetValue(key, value) + default: + return nil, NewValidationError( + "Invalid identifier type %s. 
Expected one of %s", + identifier, + strings.Join(identifiers, ", "), + ) + } +} + +func validateAttributeValue(key string, value Value) (interface{}, error) { + switch key { + case StartTime, "end_time", Created: + if _, ok := value.(NumberExpr); !ok { + return nil, NewValidationError( + "expected numeric value type for numeric attribute: %s. Found %s", + key, + value, + ) + } + + return value.value(), nil + default: + // run_id was earlier converted to run_uuid + if _, ok := value.(StringListExpr); key != "run_uuid" && ok { + return nil, NewValidationError( + "only the 'run_id' attribute supports comparison with a list of quoted string values", + ) + } + + return value.value(), nil + } +} + +// Validate an expression according to the mlflow domain. +// This represent is a simple type-checker for the expression. +// Not every identifier is valid according to the mlflow domain. +// The same for the value part. +func ValidateExpression(expression *CompareExpr) (*ValidCompareExpr, error) { + validIdentifier, validKey, err := validatedIdentifier(&expression.Left) + if err != nil { + var contractError *contract.Error + if errors.As(err, &contractError) { + return nil, contractError + } + + return nil, fmt.Errorf("Error on parsing filter expression: %w", err) + } + + value, err := validateValue(validIdentifier, validKey, expression.Right) + if err != nil { + return nil, fmt.Errorf("Error on parsing filter expression: %w", err) + } + + return &ValidCompareExpr{ + Identifier: validIdentifier, + Key: validKey, + Operator: expression.Operator, + Value: value, + }, nil +} diff --git a/pkg/tracking/service/query/query.go b/pkg/tracking/service/query/query.go index 9f6cb30..eac82cc 100644 --- a/pkg/tracking/service/query/query.go +++ b/pkg/tracking/service/query/query.go @@ -1,37 +1,37 @@ -package query - -import ( - "fmt" - - "github.com/mlflow/mlflow-go/pkg/tracking/service/query/lexer" - "github.com/mlflow/mlflow-go/pkg/tracking/service/query/parser" -) - -func 
ParseFilter(input string) ([]*parser.ValidCompareExpr, error) { - if input == "" { - return make([]*parser.ValidCompareExpr, 0), nil - } - - tokens, err := lexer.Tokenize(&input) - if err != nil { - return nil, fmt.Errorf("error while lexing %s: %w", input, err) - } - - ast, err := parser.Parse(tokens) - if err != nil { - return nil, fmt.Errorf("error while parsing %s: %w", input, err) - } - - validExpressions := make([]*parser.ValidCompareExpr, 0, len(ast.Exprs)) - - for _, expr := range ast.Exprs { - ve, err := parser.ValidateExpression(expr) - if err != nil { - return nil, fmt.Errorf("error while validating %s: %w", input, err) - } - - validExpressions = append(validExpressions, ve) - } - - return validExpressions, nil -} +package query + +import ( + "fmt" + + "github.com/mlflow/mlflow-go/pkg/tracking/service/query/lexer" + "github.com/mlflow/mlflow-go/pkg/tracking/service/query/parser" +) + +func ParseFilter(input string) ([]*parser.ValidCompareExpr, error) { + if input == "" { + return make([]*parser.ValidCompareExpr, 0), nil + } + + tokens, err := lexer.Tokenize(&input) + if err != nil { + return nil, fmt.Errorf("error while lexing %s: %w", input, err) + } + + ast, err := parser.Parse(tokens) + if err != nil { + return nil, fmt.Errorf("error while parsing %s: %w", input, err) + } + + validExpressions := make([]*parser.ValidCompareExpr, 0, len(ast.Exprs)) + + for _, expr := range ast.Exprs { + ve, err := parser.ValidateExpression(expr) + if err != nil { + return nil, fmt.Errorf("error while validating %s: %w", input, err) + } + + validExpressions = append(validExpressions, ve) + } + + return validExpressions, nil +} diff --git a/pkg/tracking/service/query/query_test.go b/pkg/tracking/service/query/query_test.go index 50192cb..95b2277 100644 --- a/pkg/tracking/service/query/query_test.go +++ b/pkg/tracking/service/query/query_test.go @@ -1,112 +1,112 @@ -package query_test - -import ( - "strings" - "testing" - - 
"github.com/mlflow/mlflow-go/pkg/tracking/service/query" -) - -func TestValidQueries(t *testing.T) { - t.Parallel() - - samples := []string{ - "metrics.foobar = 40", - "metrics.foobar = 40 AND run_name = \"bouncy-boar-498\"", - "tags.\"mlflow.source.name\" = \"scratch.py\"", - "metrics.accuracy > 0.9", - "params.\"random_state\" = \"8888\"", - "params.`random_state` = \"8888\"", - "params.solver ILIKE \"L%\"", - "params.solver LIKE \"l%\"", - "datasets.digest IN ('77a19fc0')", - "attributes.run_id IN ('meh')", - } - - for _, sample := range samples { - currentSample := sample - t.Run(currentSample, func(t *testing.T) { - t.Parallel() - - _, err := query.ParseFilter(currentSample) - if err != nil { - t.Errorf("unexpected parse error: %v", err) - } - }) - } -} - -type invalidSample struct { - input string - expectedError string -} - -//nolint:funlen -func TestInvalidQueries(t *testing.T) { - t.Parallel() - - samples := []invalidSample{ - { - input: "yow.foobar = 40", - expectedError: "invalid identifier", - }, - { - input: "attributes.foobar = 40", - expectedError: "Invalid attribute key '{foobar}' specified. " + - "Valid keys are '[run_id run_name user_id status start_time end_time artifact_uri]'", - }, - { - input: "datasets.foobar = 40", - expectedError: "Invalid dataset key '{foobar}' specified. 
" + - "Valid keys are '[run_id run_name user_id status start_time end_time artifact_uri]'", - }, - { - input: "metric.yow = 'z'", - expectedError: "expected numeric value type for metric.", - }, - { - input: "parameter.tag = 2", - expectedError: "expected a quoted string value", - }, - { - input: "attributes.start_time = 'now'", - expectedError: "expected numeric value type for numeric attribute", - }, - { - input: "attributes.run_name IN ('foo','bar')", - expectedError: "only the 'run_id' attribute supports comparison with a list", - }, - { - input: "datasets.name = 40", - expectedError: "expected datasets.name to be either a string or list of strings", - }, - { - input: "datasets.digest = 50", - expectedError: "expected datasets.digest to be either a string or list of strings", - }, - { - input: "datasets.context = 60", - expectedError: "expected datasets.context to be either a string or list of strings", - }, - } - - for _, sample := range samples { - currentSample := sample - t.Run(currentSample.input, func(t *testing.T) { - t.Parallel() - - _, err := query.ParseFilter(currentSample.input) - if err == nil { - t.Errorf("expected parse error but got nil") - } - - if !strings.Contains(err.Error(), currentSample.expectedError) { - t.Errorf( - "expected error to contain %q, got %q", - currentSample.expectedError, - err.Error(), - ) - } - }) - } -} +package query_test + +import ( + "strings" + "testing" + + "github.com/mlflow/mlflow-go/pkg/tracking/service/query" +) + +func TestValidQueries(t *testing.T) { + t.Parallel() + + samples := []string{ + "metrics.foobar = 40", + "metrics.foobar = 40 AND run_name = \"bouncy-boar-498\"", + "tags.\"mlflow.source.name\" = \"scratch.py\"", + "metrics.accuracy > 0.9", + "params.\"random_state\" = \"8888\"", + "params.`random_state` = \"8888\"", + "params.solver ILIKE \"L%\"", + "params.solver LIKE \"l%\"", + "datasets.digest IN ('77a19fc0')", + "attributes.run_id IN ('meh')", + } + + for _, sample := range samples { + 
currentSample := sample + t.Run(currentSample, func(t *testing.T) { + t.Parallel() + + _, err := query.ParseFilter(currentSample) + if err != nil { + t.Errorf("unexpected parse error: %v", err) + } + }) + } +} + +type invalidSample struct { + input string + expectedError string +} + +//nolint:funlen +func TestInvalidQueries(t *testing.T) { + t.Parallel() + + samples := []invalidSample{ + { + input: "yow.foobar = 40", + expectedError: "invalid identifier", + }, + { + input: "attributes.foobar = 40", + expectedError: "Invalid attribute key '{foobar}' specified. " + + "Valid keys are '[run_id run_name user_id status start_time end_time artifact_uri]'", + }, + { + input: "datasets.foobar = 40", + expectedError: "Invalid dataset key '{foobar}' specified. " + + "Valid keys are '[run_id run_name user_id status start_time end_time artifact_uri]'", + }, + { + input: "metric.yow = 'z'", + expectedError: "expected numeric value type for metric.", + }, + { + input: "parameter.tag = 2", + expectedError: "expected a quoted string value", + }, + { + input: "attributes.start_time = 'now'", + expectedError: "expected numeric value type for numeric attribute", + }, + { + input: "attributes.run_name IN ('foo','bar')", + expectedError: "only the 'run_id' attribute supports comparison with a list", + }, + { + input: "datasets.name = 40", + expectedError: "expected datasets.name to be either a string or list of strings", + }, + { + input: "datasets.digest = 50", + expectedError: "expected datasets.digest to be either a string or list of strings", + }, + { + input: "datasets.context = 60", + expectedError: "expected datasets.context to be either a string or list of strings", + }, + } + + for _, sample := range samples { + currentSample := sample + t.Run(currentSample.input, func(t *testing.T) { + t.Parallel() + + _, err := query.ParseFilter(currentSample.input) + if err == nil { + t.Errorf("expected parse error but got nil") + } + + if !strings.Contains(err.Error(), 
currentSample.expectedError) { + t.Errorf( + "expected error to contain %q, got %q", + currentSample.expectedError, + err.Error(), + ) + } + }) + } +} diff --git a/pkg/tracking/service/runs.go b/pkg/tracking/service/runs.go index 15835c5..cf20e2b 100644 --- a/pkg/tracking/service/runs.go +++ b/pkg/tracking/service/runs.go @@ -1,168 +1,168 @@ -package service - -import ( - "context" - "fmt" - - "github.com/mlflow/mlflow-go/pkg/contract" - "github.com/mlflow/mlflow-go/pkg/entities" - "github.com/mlflow/mlflow-go/pkg/protos" - "github.com/mlflow/mlflow-go/pkg/tracking/store/sql/models" -) - -func (ts TrackingService) SearchRuns( - ctx context.Context, input *protos.SearchRuns, -) (*protos.SearchRuns_Response, *contract.Error) { - var runViewType protos.ViewType - if input.RunViewType == nil { - runViewType = protos.ViewType_ALL - } else { - runViewType = input.GetRunViewType() - } - - maxResults := int(input.GetMaxResults()) - - runs, nextPageToken, err := ts.Store.SearchRuns( - ctx, - input.GetExperimentIds(), - input.GetFilter(), - runViewType, - maxResults, - input.GetOrderBy(), - input.GetPageToken(), - ) - if err != nil { - return nil, contract.NewError(protos.ErrorCode_INTERNAL_ERROR, fmt.Sprintf("error getting runs: %v", err)) - } - - response := protos.SearchRuns_Response{ - Runs: make([]*protos.Run, len(runs)), - NextPageToken: &nextPageToken, - } - - for i, run := range runs { - response.Runs[i] = run.ToProto() - } - - return &response, nil -} - -func (ts TrackingService) LogBatch( - ctx context.Context, input *protos.LogBatch, -) (*protos.LogBatch_Response, *contract.Error) { - metrics := make([]*entities.Metric, len(input.GetMetrics())) - for i, metric := range input.GetMetrics() { - metrics[i] = entities.MetricFromProto(metric) - } - - params := make([]*entities.Param, len(input.GetParams())) - for i, param := range input.GetParams() { - params[i] = entities.ParamFromProto(param) - } - - tags := make([]*entities.RunTag, len(input.GetTags())) - for i, tag 
:= range input.GetTags() { - tags[i] = entities.NewTagFromProto(tag) - } - - err := ts.Store.LogBatch(ctx, input.GetRunId(), metrics, params, tags) - if err != nil { - return nil, err - } - - return &protos.LogBatch_Response{}, nil -} - -func (ts TrackingService) GetRun( - ctx context.Context, input *protos.GetRun, -) (*protos.GetRun_Response, *contract.Error) { - run, err := ts.Store.GetRun(ctx, input.GetRunId()) - if err != nil { - return nil, err - } - - return &protos.GetRun_Response{Run: run.ToProto()}, nil -} - -func (ts TrackingService) CreateRun( - ctx context.Context, input *protos.CreateRun, -) (*protos.CreateRun_Response, *contract.Error) { - tags := make([]*entities.RunTag, 0, len(input.GetTags())) - for _, tag := range input.GetTags() { - tags = append(tags, entities.NewTagFromProto(tag)) - } - - run, err := ts.Store.CreateRun( - ctx, - input.GetExperimentId(), - input.GetUserId(), - input.GetStartTime(), - tags, - input.GetRunName(), - ) - if err != nil { - return nil, err - } - - return &protos.CreateRun_Response{Run: run.ToProto()}, nil -} - -func (ts TrackingService) UpdateRun( - ctx context.Context, input *protos.UpdateRun, -) (*protos.UpdateRun_Response, *contract.Error) { - run, err := ts.Store.GetRun(ctx, input.GetRunId()) - if err != nil { - return nil, err - } - - if run.Info.LifecycleStage != string(models.LifecycleStageActive) { - return nil, contract.NewError( - protos.ErrorCode_INVALID_STATE, - fmt.Sprintf( - "The run %s must be in the 'active' state. 
Current state is %s.", - input.GetRunUuid(), - run.Info.LifecycleStage, - ), - ) - } - - if status := input.GetStatus(); status != 0 { - run.Info.Status = status.String() - } - - if runName := input.GetRunName(); runName != "" { - run.Info.RunName = runName - } - - if err := ts.Store.UpdateRun( - ctx, - run.Info.RunID, - run.Info.Status, - input.EndTime, - run.Info.RunName, - ); err != nil { - return nil, err - } - - return &protos.UpdateRun_Response{RunInfo: run.Info.ToProto()}, nil -} - -func (ts TrackingService) DeleteRun( - ctx context.Context, input *protos.DeleteRun, -) (*protos.DeleteRun_Response, *contract.Error) { - if err := ts.Store.DeleteRun(ctx, input.GetRunId()); err != nil { - return nil, err - } - - return &protos.DeleteRun_Response{}, nil -} - -func (ts TrackingService) RestoreRun( - ctx context.Context, input *protos.RestoreRun, -) (*protos.RestoreRun_Response, *contract.Error) { - if err := ts.Store.RestoreRun(ctx, input.GetRunId()); err != nil { - return nil, err - } - - return &protos.RestoreRun_Response{}, nil -} +package service + +import ( + "context" + "fmt" + + "github.com/mlflow/mlflow-go/pkg/contract" + "github.com/mlflow/mlflow-go/pkg/entities" + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/tracking/store/sql/models" +) + +func (ts TrackingService) SearchRuns( + ctx context.Context, input *protos.SearchRuns, +) (*protos.SearchRuns_Response, *contract.Error) { + var runViewType protos.ViewType + if input.RunViewType == nil { + runViewType = protos.ViewType_ALL + } else { + runViewType = input.GetRunViewType() + } + + maxResults := int(input.GetMaxResults()) + + runs, nextPageToken, err := ts.Store.SearchRuns( + ctx, + input.GetExperimentIds(), + input.GetFilter(), + runViewType, + maxResults, + input.GetOrderBy(), + input.GetPageToken(), + ) + if err != nil { + return nil, contract.NewError(protos.ErrorCode_INTERNAL_ERROR, fmt.Sprintf("error getting runs: %v", err)) + } + + response := 
protos.SearchRuns_Response{ + Runs: make([]*protos.Run, len(runs)), + NextPageToken: &nextPageToken, + } + + for i, run := range runs { + response.Runs[i] = run.ToProto() + } + + return &response, nil +} + +func (ts TrackingService) LogBatch( + ctx context.Context, input *protos.LogBatch, +) (*protos.LogBatch_Response, *contract.Error) { + metrics := make([]*entities.Metric, len(input.GetMetrics())) + for i, metric := range input.GetMetrics() { + metrics[i] = entities.MetricFromProto(metric) + } + + params := make([]*entities.Param, len(input.GetParams())) + for i, param := range input.GetParams() { + params[i] = entities.ParamFromProto(param) + } + + tags := make([]*entities.RunTag, len(input.GetTags())) + for i, tag := range input.GetTags() { + tags[i] = entities.NewTagFromProto(tag) + } + + err := ts.Store.LogBatch(ctx, input.GetRunId(), metrics, params, tags) + if err != nil { + return nil, err + } + + return &protos.LogBatch_Response{}, nil +} + +func (ts TrackingService) GetRun( + ctx context.Context, input *protos.GetRun, +) (*protos.GetRun_Response, *contract.Error) { + run, err := ts.Store.GetRun(ctx, input.GetRunId()) + if err != nil { + return nil, err + } + + return &protos.GetRun_Response{Run: run.ToProto()}, nil +} + +func (ts TrackingService) CreateRun( + ctx context.Context, input *protos.CreateRun, +) (*protos.CreateRun_Response, *contract.Error) { + tags := make([]*entities.RunTag, 0, len(input.GetTags())) + for _, tag := range input.GetTags() { + tags = append(tags, entities.NewTagFromProto(tag)) + } + + run, err := ts.Store.CreateRun( + ctx, + input.GetExperimentId(), + input.GetUserId(), + input.GetStartTime(), + tags, + input.GetRunName(), + ) + if err != nil { + return nil, err + } + + return &protos.CreateRun_Response{Run: run.ToProto()}, nil +} + +func (ts TrackingService) UpdateRun( + ctx context.Context, input *protos.UpdateRun, +) (*protos.UpdateRun_Response, *contract.Error) { + run, err := ts.Store.GetRun(ctx, input.GetRunId()) + if 
err != nil { + return nil, err + } + + if run.Info.LifecycleStage != string(models.LifecycleStageActive) { + return nil, contract.NewError( + protos.ErrorCode_INVALID_STATE, + fmt.Sprintf( + "The run %s must be in the 'active' state. Current state is %s.", + input.GetRunUuid(), + run.Info.LifecycleStage, + ), + ) + } + + if status := input.GetStatus(); status != 0 { + run.Info.Status = status.String() + } + + if runName := input.GetRunName(); runName != "" { + run.Info.RunName = runName + } + + if err := ts.Store.UpdateRun( + ctx, + run.Info.RunID, + run.Info.Status, + input.EndTime, + run.Info.RunName, + ); err != nil { + return nil, err + } + + return &protos.UpdateRun_Response{RunInfo: run.Info.ToProto()}, nil +} + +func (ts TrackingService) DeleteRun( + ctx context.Context, input *protos.DeleteRun, +) (*protos.DeleteRun_Response, *contract.Error) { + if err := ts.Store.DeleteRun(ctx, input.GetRunId()); err != nil { + return nil, err + } + + return &protos.DeleteRun_Response{}, nil +} + +func (ts TrackingService) RestoreRun( + ctx context.Context, input *protos.RestoreRun, +) (*protos.RestoreRun_Response, *contract.Error) { + if err := ts.Store.RestoreRun(ctx, input.GetRunId()); err != nil { + return nil, err + } + + return &protos.RestoreRun_Response{}, nil +} diff --git a/pkg/tracking/service/service.go b/pkg/tracking/service/service.go index 074b854..8218627 100644 --- a/pkg/tracking/service/service.go +++ b/pkg/tracking/service/service.go @@ -1,27 +1,27 @@ -package service - -import ( - "context" - "fmt" - - "github.com/mlflow/mlflow-go/pkg/config" - "github.com/mlflow/mlflow-go/pkg/tracking/store" - "github.com/mlflow/mlflow-go/pkg/tracking/store/sql" -) - -type TrackingService struct { - config *config.Config - Store store.TrackingStore -} - -func NewTrackingService(ctx context.Context, config *config.Config) (*TrackingService, error) { - store, err := sql.NewTrackingSQLStore(ctx, config) - if err != nil { - return nil, fmt.Errorf("failed to create new sql 
store: %w", err) - } - - return &TrackingService{ - config: config, - Store: store, - }, nil -} +package service + +import ( + "context" + "fmt" + + "github.com/mlflow/mlflow-go/pkg/config" + "github.com/mlflow/mlflow-go/pkg/tracking/store" + "github.com/mlflow/mlflow-go/pkg/tracking/store/sql" +) + +type TrackingService struct { + config *config.Config + Store store.TrackingStore +} + +func NewTrackingService(ctx context.Context, config *config.Config) (*TrackingService, error) { + store, err := sql.NewTrackingSQLStore(ctx, config) + if err != nil { + return nil, fmt.Errorf("failed to create new sql store: %w", err) + } + + return &TrackingService{ + config: config, + Store: store, + }, nil +} diff --git a/pkg/tracking/service/tags.go b/pkg/tracking/service/tags.go new file mode 100644 index 0000000..2f3464f --- /dev/null +++ b/pkg/tracking/service/tags.go @@ -0,0 +1,26 @@ +package service + +import ( + "context" + + "github.com/mlflow/mlflow-go/pkg/contract" + "github.com/mlflow/mlflow-go/pkg/protos" +) + +// SetTag sets the key/value tag from the request on the given run by delegating to the tracking store. +func (ts TrackingService) SetTag(ctx context.Context, input *protos.SetTag) (*protos.SetTag_Response, *contract.Error) { + if err := ts.Store.SetTag(ctx, input.GetRunId(), input.GetKey(), input.GetValue()); err != nil { + return nil, err + } + + return &protos.SetTag_Response{}, nil +} + +// DeleteTag removes the tag with the given key from the run by delegating to the tracking store. +func (ts TrackingService) DeleteTag(ctx context.Context, input *protos.DeleteTag) (*protos.DeleteTag_Response, *contract.Error) { + if err := ts.Store.DeleteTag(ctx, input.GetRunId(), input.GetKey()); err != nil { + return nil, err + } + + return &protos.DeleteTag_Response{}, nil +} diff --git a/pkg/tracking/store/mock_tracking_store.go b/pkg/tracking/store/mock_tracking_store.go index d9403c6..c6d52be 100644 --- a/pkg/tracking/store/mock_tracking_store.go +++ b/pkg/tracking/store/mock_tracking_store.go @@ -13,7 +13,7 @@ import ( protos "github.com/mlflow/mlflow-go/pkg/protos" ) -// MockTrackingStore is 
an autogenerated mock type for the TrackingStore type. +// MockTrackingStore is an autogenerated mock type for the TrackingStore type type MockTrackingStore struct { mock.Mock } @@ -26,8 +26,8 @@ func (_m *MockTrackingStore) EXPECT() *MockTrackingStore_Expecter { return &MockTrackingStore_Expecter{mock: &_m.Mock} } -// CreateExperiment provides a mock function with given fields: ctx, name, artifactLocation, tags. -func (_m *MockTrackingStore) CreateExperiment(ctx context.Context, name, artifactLocation string, tags []*entities.ExperimentTag) (string, *contract.Error) { +// CreateExperiment provides a mock function with given fields: ctx, name, artifactLocation, tags +func (_m *MockTrackingStore) CreateExperiment(ctx context.Context, name string, artifactLocation string, tags []*entities.ExperimentTag) (string, *contract.Error) { ret := _m.Called(ctx, name, artifactLocation, tags) if len(ret) == 0 { @@ -56,7 +56,7 @@ func (_m *MockTrackingStore) CreateExperiment(ctx context.Context, name, artifac return r0, r1 } -// MockTrackingStore_CreateExperiment_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateExperiment'. 
+// MockTrackingStore_CreateExperiment_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateExperiment' type MockTrackingStore_CreateExperiment_Call struct { *mock.Call } @@ -66,11 +66,11 @@ type MockTrackingStore_CreateExperiment_Call struct { // - name string // - artifactLocation string // - tags []*entities.ExperimentTag -func (_e *MockTrackingStore_Expecter) CreateExperiment(ctx, name, artifactLocation, tags interface{}) *MockTrackingStore_CreateExperiment_Call { +func (_e *MockTrackingStore_Expecter) CreateExperiment(ctx interface{}, name interface{}, artifactLocation interface{}, tags interface{}) *MockTrackingStore_CreateExperiment_Call { return &MockTrackingStore_CreateExperiment_Call{Call: _e.mock.On("CreateExperiment", ctx, name, artifactLocation, tags)} } -func (_c *MockTrackingStore_CreateExperiment_Call) Run(run func(ctx context.Context, name, artifactLocation string, tags []*entities.ExperimentTag)) *MockTrackingStore_CreateExperiment_Call { +func (_c *MockTrackingStore_CreateExperiment_Call) Run(run func(ctx context.Context, name string, artifactLocation string, tags []*entities.ExperimentTag)) *MockTrackingStore_CreateExperiment_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].([]*entities.ExperimentTag)) }) @@ -87,8 +87,8 @@ func (_c *MockTrackingStore_CreateExperiment_Call) RunAndReturn(run func(context return _c } -// CreateRun provides a mock function with given fields: ctx, experimentID, userID, startTime, tags, runName. 
-func (_m *MockTrackingStore) CreateRun(ctx context.Context, experimentID, userID string, startTime int64, tags []*entities.RunTag, runName string) (*entities.Run, *contract.Error) { +// CreateRun provides a mock function with given fields: ctx, experimentID, userID, startTime, tags, runName +func (_m *MockTrackingStore) CreateRun(ctx context.Context, experimentID string, userID string, startTime int64, tags []*entities.RunTag, runName string) (*entities.Run, *contract.Error) { ret := _m.Called(ctx, experimentID, userID, startTime, tags, runName) if len(ret) == 0 { @@ -119,7 +119,7 @@ func (_m *MockTrackingStore) CreateRun(ctx context.Context, experimentID, userID return r0, r1 } -// MockTrackingStore_CreateRun_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateRun'. +// MockTrackingStore_CreateRun_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateRun' type MockTrackingStore_CreateRun_Call struct { *mock.Call } @@ -131,11 +131,11 @@ type MockTrackingStore_CreateRun_Call struct { // - startTime int64 // - tags []*entities.RunTag // - runName string -func (_e *MockTrackingStore_Expecter) CreateRun(ctx, experimentID, userID, startTime, tags, runName interface{}) *MockTrackingStore_CreateRun_Call { +func (_e *MockTrackingStore_Expecter) CreateRun(ctx interface{}, experimentID interface{}, userID interface{}, startTime interface{}, tags interface{}, runName interface{}) *MockTrackingStore_CreateRun_Call { return &MockTrackingStore_CreateRun_Call{Call: _e.mock.On("CreateRun", ctx, experimentID, userID, startTime, tags, runName)} } -func (_c *MockTrackingStore_CreateRun_Call) Run(run func(ctx context.Context, experimentID, userID string, startTime int64, tags []*entities.RunTag, runName string)) *MockTrackingStore_CreateRun_Call { +func (_c *MockTrackingStore_CreateRun_Call) Run(run func(ctx context.Context, experimentID string, userID string, startTime int64, tags 
[]*entities.RunTag, runName string)) *MockTrackingStore_CreateRun_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(int64), args[4].([]*entities.RunTag), args[5].(string)) }) @@ -152,7 +152,7 @@ func (_c *MockTrackingStore_CreateRun_Call) RunAndReturn(run func(context.Contex return _c } -// DeleteExperiment provides a mock function with given fields: ctx, id. +// DeleteExperiment provides a mock function with given fields: ctx, id func (_m *MockTrackingStore) DeleteExperiment(ctx context.Context, id string) *contract.Error { ret := _m.Called(ctx, id) @@ -172,7 +172,7 @@ func (_m *MockTrackingStore) DeleteExperiment(ctx context.Context, id string) *c return r0 } -// MockTrackingStore_DeleteExperiment_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteExperiment'. +// MockTrackingStore_DeleteExperiment_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteExperiment' type MockTrackingStore_DeleteExperiment_Call struct { *mock.Call } @@ -180,7 +180,7 @@ type MockTrackingStore_DeleteExperiment_Call struct { // DeleteExperiment is a helper method to define mock.On call // - ctx context.Context // - id string -func (_e *MockTrackingStore_Expecter) DeleteExperiment(ctx, id interface{}) *MockTrackingStore_DeleteExperiment_Call { +func (_e *MockTrackingStore_Expecter) DeleteExperiment(ctx interface{}, id interface{}) *MockTrackingStore_DeleteExperiment_Call { return &MockTrackingStore_DeleteExperiment_Call{Call: _e.mock.On("DeleteExperiment", ctx, id)} } @@ -201,7 +201,7 @@ func (_c *MockTrackingStore_DeleteExperiment_Call) RunAndReturn(run func(context return _c } -// DeleteRun provides a mock function with given fields: ctx, runID. 
+// DeleteRun provides a mock function with given fields: ctx, runID func (_m *MockTrackingStore) DeleteRun(ctx context.Context, runID string) *contract.Error { ret := _m.Called(ctx, runID) @@ -221,7 +221,7 @@ func (_m *MockTrackingStore) DeleteRun(ctx context.Context, runID string) *contr return r0 } -// MockTrackingStore_DeleteRun_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteRun'. +// MockTrackingStore_DeleteRun_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteRun' type MockTrackingStore_DeleteRun_Call struct { *mock.Call } @@ -229,7 +229,7 @@ type MockTrackingStore_DeleteRun_Call struct { // DeleteRun is a helper method to define mock.On call // - ctx context.Context // - runID string -func (_e *MockTrackingStore_Expecter) DeleteRun(ctx, runID interface{}) *MockTrackingStore_DeleteRun_Call { +func (_e *MockTrackingStore_Expecter) DeleteRun(ctx interface{}, runID interface{}) *MockTrackingStore_DeleteRun_Call { return &MockTrackingStore_DeleteRun_Call{Call: _e.mock.On("DeleteRun", ctx, runID)} } @@ -250,7 +250,57 @@ func (_c *MockTrackingStore_DeleteRun_Call) RunAndReturn(run func(context.Contex return _c } -// GetExperiment provides a mock function with given fields: ctx, id. 
+// DeleteTag provides a mock function with given fields: ctx, runID, key +func (_m *MockTrackingStore) DeleteTag(ctx context.Context, runID string, key string) *contract.Error { + ret := _m.Called(ctx, runID, key) + + if len(ret) == 0 { + panic("no return value specified for DeleteTag") + } + + var r0 *contract.Error + if rf, ok := ret.Get(0).(func(context.Context, string, string) *contract.Error); ok { + r0 = rf(ctx, runID, key) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*contract.Error) + } + } + + return r0 +} + +// MockTrackingStore_DeleteTag_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteTag' +type MockTrackingStore_DeleteTag_Call struct { + *mock.Call +} + +// DeleteTag is a helper method to define mock.On call +// - ctx context.Context +// - runID string +// - key string +func (_e *MockTrackingStore_Expecter) DeleteTag(ctx interface{}, runID interface{}, key interface{}) *MockTrackingStore_DeleteTag_Call { + return &MockTrackingStore_DeleteTag_Call{Call: _e.mock.On("DeleteTag", ctx, runID, key)} +} + +func (_c *MockTrackingStore_DeleteTag_Call) Run(run func(ctx context.Context, runID string, key string)) *MockTrackingStore_DeleteTag_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string)) + }) + return _c +} + +func (_c *MockTrackingStore_DeleteTag_Call) Return(_a0 *contract.Error) *MockTrackingStore_DeleteTag_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockTrackingStore_DeleteTag_Call) RunAndReturn(run func(context.Context, string, string) *contract.Error) *MockTrackingStore_DeleteTag_Call { + _c.Call.Return(run) + return _c +} + +// GetExperiment provides a mock function with given fields: ctx, id func (_m *MockTrackingStore) GetExperiment(ctx context.Context, id string) (*entities.Experiment, *contract.Error) { ret := _m.Called(ctx, id) @@ -282,7 +332,7 @@ func (_m *MockTrackingStore) GetExperiment(ctx context.Context, 
id string) (*ent return r0, r1 } -// MockTrackingStore_GetExperiment_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetExperiment'. +// MockTrackingStore_GetExperiment_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetExperiment' type MockTrackingStore_GetExperiment_Call struct { *mock.Call } @@ -290,7 +340,7 @@ type MockTrackingStore_GetExperiment_Call struct { // GetExperiment is a helper method to define mock.On call // - ctx context.Context // - id string -func (_e *MockTrackingStore_Expecter) GetExperiment(ctx, id interface{}) *MockTrackingStore_GetExperiment_Call { +func (_e *MockTrackingStore_Expecter) GetExperiment(ctx interface{}, id interface{}) *MockTrackingStore_GetExperiment_Call { return &MockTrackingStore_GetExperiment_Call{Call: _e.mock.On("GetExperiment", ctx, id)} } @@ -311,7 +361,7 @@ func (_c *MockTrackingStore_GetExperiment_Call) RunAndReturn(run func(context.Co return _c } -// GetExperimentByName provides a mock function with given fields: ctx, name. +// GetExperimentByName provides a mock function with given fields: ctx, name func (_m *MockTrackingStore) GetExperimentByName(ctx context.Context, name string) (*entities.Experiment, *contract.Error) { ret := _m.Called(ctx, name) @@ -343,7 +393,7 @@ func (_m *MockTrackingStore) GetExperimentByName(ctx context.Context, name strin return r0, r1 } -// MockTrackingStore_GetExperimentByName_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetExperimentByName'. 
+// MockTrackingStore_GetExperimentByName_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetExperimentByName' type MockTrackingStore_GetExperimentByName_Call struct { *mock.Call } @@ -351,7 +401,7 @@ type MockTrackingStore_GetExperimentByName_Call struct { // GetExperimentByName is a helper method to define mock.On call // - ctx context.Context // - name string -func (_e *MockTrackingStore_Expecter) GetExperimentByName(ctx, name interface{}) *MockTrackingStore_GetExperimentByName_Call { +func (_e *MockTrackingStore_Expecter) GetExperimentByName(ctx interface{}, name interface{}) *MockTrackingStore_GetExperimentByName_Call { return &MockTrackingStore_GetExperimentByName_Call{Call: _e.mock.On("GetExperimentByName", ctx, name)} } @@ -372,7 +422,7 @@ func (_c *MockTrackingStore_GetExperimentByName_Call) RunAndReturn(run func(cont return _c } -// GetRun provides a mock function with given fields: ctx, runID. +// GetRun provides a mock function with given fields: ctx, runID func (_m *MockTrackingStore) GetRun(ctx context.Context, runID string) (*entities.Run, *contract.Error) { ret := _m.Called(ctx, runID) @@ -404,7 +454,7 @@ func (_m *MockTrackingStore) GetRun(ctx context.Context, runID string) (*entitie return r0, r1 } -// MockTrackingStore_GetRun_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetRun'. 
+// MockTrackingStore_GetRun_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetRun' type MockTrackingStore_GetRun_Call struct { *mock.Call } @@ -412,7 +462,7 @@ type MockTrackingStore_GetRun_Call struct { // GetRun is a helper method to define mock.On call // - ctx context.Context // - runID string -func (_e *MockTrackingStore_Expecter) GetRun(ctx, runID interface{}) *MockTrackingStore_GetRun_Call { +func (_e *MockTrackingStore_Expecter) GetRun(ctx interface{}, runID interface{}) *MockTrackingStore_GetRun_Call { return &MockTrackingStore_GetRun_Call{Call: _e.mock.On("GetRun", ctx, runID)} } @@ -433,8 +483,8 @@ func (_c *MockTrackingStore_GetRun_Call) RunAndReturn(run func(context.Context, return _c } -// GetRunTag provides a mock function with given fields: ctx, runID, tagKey. -func (_m *MockTrackingStore) GetRunTag(ctx context.Context, runID, tagKey string) (*entities.RunTag, *contract.Error) { +// GetRunTag provides a mock function with given fields: ctx, runID, tagKey +func (_m *MockTrackingStore) GetRunTag(ctx context.Context, runID string, tagKey string) (*entities.RunTag, *contract.Error) { ret := _m.Called(ctx, runID, tagKey) if len(ret) == 0 { @@ -465,7 +515,7 @@ func (_m *MockTrackingStore) GetRunTag(ctx context.Context, runID, tagKey string return r0, r1 } -// MockTrackingStore_GetRunTag_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetRunTag'. 
+// MockTrackingStore_GetRunTag_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetRunTag' type MockTrackingStore_GetRunTag_Call struct { *mock.Call } @@ -474,11 +524,11 @@ type MockTrackingStore_GetRunTag_Call struct { // - ctx context.Context // - runID string // - tagKey string -func (_e *MockTrackingStore_Expecter) GetRunTag(ctx, runID, tagKey interface{}) *MockTrackingStore_GetRunTag_Call { +func (_e *MockTrackingStore_Expecter) GetRunTag(ctx interface{}, runID interface{}, tagKey interface{}) *MockTrackingStore_GetRunTag_Call { return &MockTrackingStore_GetRunTag_Call{Call: _e.mock.On("GetRunTag", ctx, runID, tagKey)} } -func (_c *MockTrackingStore_GetRunTag_Call) Run(run func(ctx context.Context, runID, tagKey string)) *MockTrackingStore_GetRunTag_Call { +func (_c *MockTrackingStore_GetRunTag_Call) Run(run func(ctx context.Context, runID string, tagKey string)) *MockTrackingStore_GetRunTag_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(string), args[2].(string)) }) @@ -495,7 +545,7 @@ func (_c *MockTrackingStore_GetRunTag_Call) RunAndReturn(run func(context.Contex return _c } -// LogBatch provides a mock function with given fields: ctx, runID, metrics, params, tags. +// LogBatch provides a mock function with given fields: ctx, runID, metrics, params, tags func (_m *MockTrackingStore) LogBatch(ctx context.Context, runID string, metrics []*entities.Metric, params []*entities.Param, tags []*entities.RunTag) *contract.Error { ret := _m.Called(ctx, runID, metrics, params, tags) @@ -515,7 +565,7 @@ func (_m *MockTrackingStore) LogBatch(ctx context.Context, runID string, metrics return r0 } -// MockTrackingStore_LogBatch_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LogBatch'. 
+// MockTrackingStore_LogBatch_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LogBatch' type MockTrackingStore_LogBatch_Call struct { *mock.Call } @@ -526,7 +576,7 @@ type MockTrackingStore_LogBatch_Call struct { // - metrics []*entities.Metric // - params []*entities.Param // - tags []*entities.RunTag -func (_e *MockTrackingStore_Expecter) LogBatch(ctx, runID, metrics, params, tags interface{}) *MockTrackingStore_LogBatch_Call { +func (_e *MockTrackingStore_Expecter) LogBatch(ctx interface{}, runID interface{}, metrics interface{}, params interface{}, tags interface{}) *MockTrackingStore_LogBatch_Call { return &MockTrackingStore_LogBatch_Call{Call: _e.mock.On("LogBatch", ctx, runID, metrics, params, tags)} } @@ -547,7 +597,7 @@ func (_c *MockTrackingStore_LogBatch_Call) RunAndReturn(run func(context.Context return _c } -// LogMetric provides a mock function with given fields: ctx, runID, metric. +// LogMetric provides a mock function with given fields: ctx, runID, metric func (_m *MockTrackingStore) LogMetric(ctx context.Context, runID string, metric *entities.Metric) *contract.Error { ret := _m.Called(ctx, runID, metric) @@ -567,7 +617,7 @@ func (_m *MockTrackingStore) LogMetric(ctx context.Context, runID string, metric return r0 } -// MockTrackingStore_LogMetric_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LogMetric'. 
+// MockTrackingStore_LogMetric_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LogMetric' type MockTrackingStore_LogMetric_Call struct { *mock.Call } @@ -576,7 +626,7 @@ type MockTrackingStore_LogMetric_Call struct { // - ctx context.Context // - runID string // - metric *entities.Metric -func (_e *MockTrackingStore_Expecter) LogMetric(ctx, runID, metric interface{}) *MockTrackingStore_LogMetric_Call { +func (_e *MockTrackingStore_Expecter) LogMetric(ctx interface{}, runID interface{}, metric interface{}) *MockTrackingStore_LogMetric_Call { return &MockTrackingStore_LogMetric_Call{Call: _e.mock.On("LogMetric", ctx, runID, metric)} } @@ -597,8 +647,8 @@ func (_c *MockTrackingStore_LogMetric_Call) RunAndReturn(run func(context.Contex return _c } -// RenameExperiment provides a mock function with given fields: ctx, experimentID, name. -func (_m *MockTrackingStore) RenameExperiment(ctx context.Context, experimentID, name string) *contract.Error { +// RenameExperiment provides a mock function with given fields: ctx, experimentID, name +func (_m *MockTrackingStore) RenameExperiment(ctx context.Context, experimentID string, name string) *contract.Error { ret := _m.Called(ctx, experimentID, name) if len(ret) == 0 { @@ -617,7 +667,7 @@ func (_m *MockTrackingStore) RenameExperiment(ctx context.Context, experimentID, return r0 } -// MockTrackingStore_RenameExperiment_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RenameExperiment'. 
+// MockTrackingStore_RenameExperiment_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RenameExperiment' type MockTrackingStore_RenameExperiment_Call struct { *mock.Call } @@ -626,11 +676,11 @@ type MockTrackingStore_RenameExperiment_Call struct { // - ctx context.Context // - experimentID string // - name string -func (_e *MockTrackingStore_Expecter) RenameExperiment(ctx, experimentID, name interface{}) *MockTrackingStore_RenameExperiment_Call { +func (_e *MockTrackingStore_Expecter) RenameExperiment(ctx interface{}, experimentID interface{}, name interface{}) *MockTrackingStore_RenameExperiment_Call { return &MockTrackingStore_RenameExperiment_Call{Call: _e.mock.On("RenameExperiment", ctx, experimentID, name)} } -func (_c *MockTrackingStore_RenameExperiment_Call) Run(run func(ctx context.Context, experimentID, name string)) *MockTrackingStore_RenameExperiment_Call { +func (_c *MockTrackingStore_RenameExperiment_Call) Run(run func(ctx context.Context, experimentID string, name string)) *MockTrackingStore_RenameExperiment_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(string), args[2].(string)) }) @@ -647,7 +697,7 @@ func (_c *MockTrackingStore_RenameExperiment_Call) RunAndReturn(run func(context return _c } -// RestoreExperiment provides a mock function with given fields: ctx, id. +// RestoreExperiment provides a mock function with given fields: ctx, id func (_m *MockTrackingStore) RestoreExperiment(ctx context.Context, id string) *contract.Error { ret := _m.Called(ctx, id) @@ -667,7 +717,7 @@ func (_m *MockTrackingStore) RestoreExperiment(ctx context.Context, id string) * return r0 } -// MockTrackingStore_RestoreExperiment_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RestoreExperiment'. 
+// MockTrackingStore_RestoreExperiment_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RestoreExperiment' type MockTrackingStore_RestoreExperiment_Call struct { *mock.Call } @@ -675,7 +725,7 @@ type MockTrackingStore_RestoreExperiment_Call struct { // RestoreExperiment is a helper method to define mock.On call // - ctx context.Context // - id string -func (_e *MockTrackingStore_Expecter) RestoreExperiment(ctx, id interface{}) *MockTrackingStore_RestoreExperiment_Call { +func (_e *MockTrackingStore_Expecter) RestoreExperiment(ctx interface{}, id interface{}) *MockTrackingStore_RestoreExperiment_Call { return &MockTrackingStore_RestoreExperiment_Call{Call: _e.mock.On("RestoreExperiment", ctx, id)} } @@ -696,7 +746,7 @@ func (_c *MockTrackingStore_RestoreExperiment_Call) RunAndReturn(run func(contex return _c } -// RestoreRun provides a mock function with given fields: ctx, runID. +// RestoreRun provides a mock function with given fields: ctx, runID func (_m *MockTrackingStore) RestoreRun(ctx context.Context, runID string) *contract.Error { ret := _m.Called(ctx, runID) @@ -716,7 +766,7 @@ func (_m *MockTrackingStore) RestoreRun(ctx context.Context, runID string) *cont return r0 } -// MockTrackingStore_RestoreRun_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RestoreRun'. 
+// MockTrackingStore_RestoreRun_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RestoreRun' type MockTrackingStore_RestoreRun_Call struct { *mock.Call } @@ -724,7 +774,7 @@ type MockTrackingStore_RestoreRun_Call struct { // RestoreRun is a helper method to define mock.On call // - ctx context.Context // - runID string -func (_e *MockTrackingStore_Expecter) RestoreRun(ctx, runID interface{}) *MockTrackingStore_RestoreRun_Call { +func (_e *MockTrackingStore_Expecter) RestoreRun(ctx interface{}, runID interface{}) *MockTrackingStore_RestoreRun_Call { return &MockTrackingStore_RestoreRun_Call{Call: _e.mock.On("RestoreRun", ctx, runID)} } @@ -745,7 +795,7 @@ func (_c *MockTrackingStore_RestoreRun_Call) RunAndReturn(run func(context.Conte return _c } -// SearchRuns provides a mock function with given fields: ctx, experimentIDs, filter, runViewType, maxResults, orderBy, pageToken. +// SearchRuns provides a mock function with given fields: ctx, experimentIDs, filter, runViewType, maxResults, orderBy, pageToken func (_m *MockTrackingStore) SearchRuns(ctx context.Context, experimentIDs []string, filter string, runViewType protos.ViewType, maxResults int, orderBy []string, pageToken string) ([]*entities.Run, string, *contract.Error) { ret := _m.Called(ctx, experimentIDs, filter, runViewType, maxResults, orderBy, pageToken) @@ -784,7 +834,7 @@ func (_m *MockTrackingStore) SearchRuns(ctx context.Context, experimentIDs []str return r0, r1, r2 } -// MockTrackingStore_SearchRuns_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SearchRuns'. 
+// MockTrackingStore_SearchRuns_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SearchRuns' type MockTrackingStore_SearchRuns_Call struct { *mock.Call } @@ -797,7 +847,7 @@ type MockTrackingStore_SearchRuns_Call struct { // - maxResults int // - orderBy []string // - pageToken string -func (_e *MockTrackingStore_Expecter) SearchRuns(ctx, experimentIDs, filter, runViewType, maxResults, orderBy, pageToken interface{}) *MockTrackingStore_SearchRuns_Call { +func (_e *MockTrackingStore_Expecter) SearchRuns(ctx interface{}, experimentIDs interface{}, filter interface{}, runViewType interface{}, maxResults interface{}, orderBy interface{}, pageToken interface{}) *MockTrackingStore_SearchRuns_Call { return &MockTrackingStore_SearchRuns_Call{Call: _e.mock.On("SearchRuns", ctx, experimentIDs, filter, runViewType, maxResults, orderBy, pageToken)} } @@ -818,8 +868,8 @@ func (_c *MockTrackingStore_SearchRuns_Call) RunAndReturn(run func(context.Conte return _c } -// UpdateRun provides a mock function with given fields: ctx, runID, runStatus, endTime, runName. -func (_m *MockTrackingStore) UpdateRun(ctx context.Context, runID, runStatus string, endTime *int64, runName string) *contract.Error { +// UpdateRun provides a mock function with given fields: ctx, runID, runStatus, endTime, runName +func (_m *MockTrackingStore) UpdateRun(ctx context.Context, runID string, runStatus string, endTime *int64, runName string) *contract.Error { ret := _m.Called(ctx, runID, runStatus, endTime, runName) if len(ret) == 0 { @@ -838,7 +888,7 @@ func (_m *MockTrackingStore) UpdateRun(ctx context.Context, runID, runStatus str return r0 } -// MockTrackingStore_UpdateRun_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateRun'. 
+// MockTrackingStore_UpdateRun_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateRun' type MockTrackingStore_UpdateRun_Call struct { *mock.Call } @@ -849,11 +899,11 @@ type MockTrackingStore_UpdateRun_Call struct { // - runStatus string // - endTime *int64 // - runName string -func (_e *MockTrackingStore_Expecter) UpdateRun(ctx, runID, runStatus, endTime, runName interface{}) *MockTrackingStore_UpdateRun_Call { +func (_e *MockTrackingStore_Expecter) UpdateRun(ctx interface{}, runID interface{}, runStatus interface{}, endTime interface{}, runName interface{}) *MockTrackingStore_UpdateRun_Call { return &MockTrackingStore_UpdateRun_Call{Call: _e.mock.On("UpdateRun", ctx, runID, runStatus, endTime, runName)} } -func (_c *MockTrackingStore_UpdateRun_Call) Run(run func(ctx context.Context, runID, runStatus string, endTime *int64, runName string)) *MockTrackingStore_UpdateRun_Call { +func (_c *MockTrackingStore_UpdateRun_Call) Run(run func(ctx context.Context, runID string, runStatus string, endTime *int64, runName string)) *MockTrackingStore_UpdateRun_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(*int64), args[4].(string)) }) @@ -875,8 +925,7 @@ func (_c *MockTrackingStore_UpdateRun_Call) RunAndReturn(run func(context.Contex func NewMockTrackingStore(t interface { mock.TestingT Cleanup(func()) -}, -) *MockTrackingStore { +}) *MockTrackingStore { mock := &MockTrackingStore{} mock.Mock.Test(t) diff --git a/pkg/tracking/store/sql/experiments.go b/pkg/tracking/store/sql/experiments.go index 7c5eba6..1c65e4a 100644 --- a/pkg/tracking/store/sql/experiments.go +++ b/pkg/tracking/store/sql/experiments.go @@ -1,254 +1,254 @@ -package sql - -import ( - "context" - "database/sql" - "errors" - "fmt" - "strconv" - "time" - - "gorm.io/gorm" - - "github.com/mlflow/mlflow-go/pkg/contract" - "github.com/mlflow/mlflow-go/pkg/entities" - 
"github.com/mlflow/mlflow-go/pkg/protos" - "github.com/mlflow/mlflow-go/pkg/tracking/store/sql/models" - "github.com/mlflow/mlflow-go/pkg/utils" -) - -func (s TrackingSQLStore) GetExperiment(ctx context.Context, id string) (*entities.Experiment, *contract.Error) { - idInt, err := strconv.ParseInt(id, 10, 32) - if err != nil { - return nil, contract.NewErrorWith( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - fmt.Sprintf("failed to convert experiment id %q to int", id), - err, - ) - } - - experiment := models.Experiment{ID: int32(idInt)} - if err := s.db.WithContext(ctx).Preload("Tags").First(&experiment).Error; err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, contract.NewError( - protos.ErrorCode_RESOURCE_DOES_NOT_EXIST, - fmt.Sprintf("No Experiment with id=%d exists", idInt), - ) - } - - return nil, contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - "failed to get experiment", - err, - ) - } - - return experiment.ToEntity(), nil -} - -func (s TrackingSQLStore) CreateExperiment( - ctx context.Context, - name string, - artifactLocation string, - tags []*entities.ExperimentTag, -) (string, *contract.Error) { - experiment := models.Experiment{ - Name: name, - Tags: make([]models.ExperimentTag, len(tags)), - ArtifactLocation: artifactLocation, - LifecycleStage: models.LifecycleStageActive, - CreationTime: time.Now().UnixMilli(), - LastUpdateTime: time.Now().UnixMilli(), - } - - for i, tag := range tags { - experiment.Tags[i] = models.ExperimentTag{ - Key: tag.Key, - Value: tag.Value, - } - } - - if err := s.db.WithContext(ctx).Transaction(func(transaction *gorm.DB) error { - if err := transaction.Create(&experiment).Error; err != nil { - return fmt.Errorf("failed to insert experiment: %w", err) - } - - if experiment.ArtifactLocation == "" { - artifactLocation, err := utils.AppendToURIPath(s.config.DefaultArtifactRoot, strconv.Itoa(int(experiment.ID))) - if err != nil { - return fmt.Errorf("failed to join artifact location: %w", err) - } 
- experiment.ArtifactLocation = artifactLocation - if err := transaction.Model(&experiment).UpdateColumn("artifact_location", artifactLocation).Error; err != nil { - return fmt.Errorf("failed to update experiment artifact location: %w", err) - } - } - - return nil - }); err != nil { - if errors.Is(err, gorm.ErrDuplicatedKey) { - return "", contract.NewError( - protos.ErrorCode_RESOURCE_ALREADY_EXISTS, - fmt.Sprintf("Experiment(name=%s) already exists.", experiment.Name), - ) - } - - return "", contract.NewErrorWith(protos.ErrorCode_INTERNAL_ERROR, "failed to create experiment", err) - } - - return strconv.Itoa(int(experiment.ID)), nil -} - -func (s TrackingSQLStore) RenameExperiment( - ctx context.Context, experimentID, name string, -) *contract.Error { - if err := s.db.WithContext(ctx).Model(&models.Experiment{}). - Where("experiment_id = ?", experimentID). - Updates(&models.Experiment{ - Name: name, - LastUpdateTime: time.Now().UnixMilli(), - }).Error; err != nil { - return contract.NewErrorWith(protos.ErrorCode_INTERNAL_ERROR, "failed to update experiment", err) - } - - return nil -} - -func (s TrackingSQLStore) DeleteExperiment(ctx context.Context, id string) *contract.Error { - idInt, err := strconv.ParseInt(id, 10, 32) - if err != nil { - return contract.NewErrorWith( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - fmt.Sprintf("failed to convert experiment id (%s) to int", id), - err, - ) - } - - if err := s.db.WithContext(ctx).Transaction(func(transaction *gorm.DB) error { - // Update experiment - uex := transaction.Model(&models.Experiment{}). - Where("experiment_id = ?", idInt). - Updates(&models.Experiment{ - LifecycleStage: models.LifecycleStageDeleted, - LastUpdateTime: time.Now().UnixMilli(), - }) - - if uex.Error != nil { - return fmt.Errorf("failed to update experiment (%d) during delete: %w", idInt, err) - } - - if uex.RowsAffected != 1 { - return gorm.ErrRecordNotFound - } - - // Update runs - if err := transaction.Model(&models.Run{}). 
- Where("experiment_id = ?", idInt). - Updates(&models.Run{ - LifecycleStage: models.LifecycleStageDeleted, - DeletedTime: sql.NullInt64{Valid: true, Int64: time.Now().UnixMilli()}, - }).Error; err != nil { - return fmt.Errorf("failed to update runs during delete: %w", err) - } - - return nil - }); err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return contract.NewError( - protos.ErrorCode_RESOURCE_DOES_NOT_EXIST, - fmt.Sprintf("No Experiment with id=%d exists", idInt), - ) - } - - return contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - "failed to delete experiment", - err, - ) - } - - return nil -} - -func (s TrackingSQLStore) RestoreExperiment(ctx context.Context, id string) *contract.Error { - idInt, err := strconv.ParseInt(id, 10, 32) - if err != nil { - return contract.NewErrorWith( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - fmt.Sprintf("failed to convert experiment id (%s) to int", id), - err, - ) - } - - if err := s.db.WithContext(ctx).Transaction(func(transaction *gorm.DB) error { - // Update experiment - uex := transaction.Model(&models.Experiment{}). - Where("experiment_id = ?", idInt). - Where("lifecycle_stage = ?", models.LifecycleStageDeleted). - Updates(&models.Experiment{ - LifecycleStage: models.LifecycleStageActive, - LastUpdateTime: time.Now().UnixMilli(), - }) - - if uex.Error != nil { - return fmt.Errorf("failed to update experiment (%d) during delete: %w", idInt, err) - } - - if uex.RowsAffected != 1 { - return gorm.ErrRecordNotFound - } - - // Update runs - if err := transaction.Model(&models.Run{}). - Where("experiment_id = ?", idInt). - Select("DeletedTime", "LifecycleStage"). 
- Updates(&models.Run{ - LifecycleStage: models.LifecycleStageActive, - DeletedTime: sql.NullInt64{}, - }).Error; err != nil { - return fmt.Errorf("failed to update runs during restore: %w", err) - } - - return nil - }); err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return contract.NewError( - protos.ErrorCode_RESOURCE_DOES_NOT_EXIST, - fmt.Sprintf("No Experiment with id=%d exists", idInt), - ) - } - - return contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - "failed to delete experiment", - err, - ) - } - - return nil -} - -//nolint:perfsprint -func (s TrackingSQLStore) GetExperimentByName( - ctx context.Context, name string, -) (*entities.Experiment, *contract.Error) { - var experiment models.Experiment - - err := s.db.WithContext(ctx).Preload("Tags").Where("name = ?", name).First(&experiment).Error - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, contract.NewError( - protos.ErrorCode_RESOURCE_DOES_NOT_EXIST, - fmt.Sprintf("Could not find experiment with name %s", name), - ) - } - - return nil, contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - fmt.Sprintf("failed to get experiment by name %s", name), - err, - ) - } - - return experiment.ToEntity(), nil -} +package sql + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strconv" + "time" + + "gorm.io/gorm" + + "github.com/mlflow/mlflow-go/pkg/contract" + "github.com/mlflow/mlflow-go/pkg/entities" + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/tracking/store/sql/models" + "github.com/mlflow/mlflow-go/pkg/utils" +) + +func (s TrackingSQLStore) GetExperiment(ctx context.Context, id string) (*entities.Experiment, *contract.Error) { + idInt, err := strconv.ParseInt(id, 10, 32) + if err != nil { + return nil, contract.NewErrorWith( + protos.ErrorCode_INVALID_PARAMETER_VALUE, + fmt.Sprintf("failed to convert experiment id %q to int", id), + err, + ) + } + + experiment := models.Experiment{ID: int32(idInt)} + if err := 
s.db.WithContext(ctx).Preload("Tags").First(&experiment).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, contract.NewError( + protos.ErrorCode_RESOURCE_DOES_NOT_EXIST, + fmt.Sprintf("No Experiment with id=%d exists", idInt), + ) + } + + return nil, contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + "failed to get experiment", + err, + ) + } + + return experiment.ToEntity(), nil +} + +func (s TrackingSQLStore) CreateExperiment( + ctx context.Context, + name string, + artifactLocation string, + tags []*entities.ExperimentTag, +) (string, *contract.Error) { + experiment := models.Experiment{ + Name: name, + Tags: make([]models.ExperimentTag, len(tags)), + ArtifactLocation: artifactLocation, + LifecycleStage: models.LifecycleStageActive, + CreationTime: time.Now().UnixMilli(), + LastUpdateTime: time.Now().UnixMilli(), + } + + for i, tag := range tags { + experiment.Tags[i] = models.ExperimentTag{ + Key: tag.Key, + Value: tag.Value, + } + } + + if err := s.db.WithContext(ctx).Transaction(func(transaction *gorm.DB) error { + if err := transaction.Create(&experiment).Error; err != nil { + return fmt.Errorf("failed to insert experiment: %w", err) + } + + if experiment.ArtifactLocation == "" { + artifactLocation, err := utils.AppendToURIPath(s.config.DefaultArtifactRoot, strconv.Itoa(int(experiment.ID))) + if err != nil { + return fmt.Errorf("failed to join artifact location: %w", err) + } + experiment.ArtifactLocation = artifactLocation + if err := transaction.Model(&experiment).UpdateColumn("artifact_location", artifactLocation).Error; err != nil { + return fmt.Errorf("failed to update experiment artifact location: %w", err) + } + } + + return nil + }); err != nil { + if errors.Is(err, gorm.ErrDuplicatedKey) { + return "", contract.NewError( + protos.ErrorCode_RESOURCE_ALREADY_EXISTS, + fmt.Sprintf("Experiment(name=%s) already exists.", experiment.Name), + ) + } + + return "", contract.NewErrorWith(protos.ErrorCode_INTERNAL_ERROR, 
"failed to create experiment", err) + } + + return strconv.Itoa(int(experiment.ID)), nil +} + +func (s TrackingSQLStore) RenameExperiment( + ctx context.Context, experimentID, name string, +) *contract.Error { + if err := s.db.WithContext(ctx).Model(&models.Experiment{}). + Where("experiment_id = ?", experimentID). + Updates(&models.Experiment{ + Name: name, + LastUpdateTime: time.Now().UnixMilli(), + }).Error; err != nil { + return contract.NewErrorWith(protos.ErrorCode_INTERNAL_ERROR, "failed to update experiment", err) + } + + return nil +} + +func (s TrackingSQLStore) DeleteExperiment(ctx context.Context, id string) *contract.Error { + idInt, err := strconv.ParseInt(id, 10, 32) + if err != nil { + return contract.NewErrorWith( + protos.ErrorCode_INVALID_PARAMETER_VALUE, + fmt.Sprintf("failed to convert experiment id (%s) to int", id), + err, + ) + } + + if err := s.db.WithContext(ctx).Transaction(func(transaction *gorm.DB) error { + // Update experiment + uex := transaction.Model(&models.Experiment{}). + Where("experiment_id = ?", idInt). + Updates(&models.Experiment{ + LifecycleStage: models.LifecycleStageDeleted, + LastUpdateTime: time.Now().UnixMilli(), + }) + + if uex.Error != nil { + return fmt.Errorf("failed to update experiment (%d) during delete: %w", idInt, err) + } + + if uex.RowsAffected != 1 { + return gorm.ErrRecordNotFound + } + + // Update runs + if err := transaction.Model(&models.Run{}). + Where("experiment_id = ?", idInt). 
+ Updates(&models.Run{ + LifecycleStage: models.LifecycleStageDeleted, + DeletedTime: sql.NullInt64{Valid: true, Int64: time.Now().UnixMilli()}, + }).Error; err != nil { + return fmt.Errorf("failed to update runs during delete: %w", err) + } + + return nil + }); err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return contract.NewError( + protos.ErrorCode_RESOURCE_DOES_NOT_EXIST, + fmt.Sprintf("No Experiment with id=%d exists", idInt), + ) + } + + return contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + "failed to delete experiment", + err, + ) + } + + return nil +} + +func (s TrackingSQLStore) RestoreExperiment(ctx context.Context, id string) *contract.Error { + idInt, err := strconv.ParseInt(id, 10, 32) + if err != nil { + return contract.NewErrorWith( + protos.ErrorCode_INVALID_PARAMETER_VALUE, + fmt.Sprintf("failed to convert experiment id (%s) to int", id), + err, + ) + } + + if err := s.db.WithContext(ctx).Transaction(func(transaction *gorm.DB) error { + // Update experiment + uex := transaction.Model(&models.Experiment{}). + Where("experiment_id = ?", idInt). + Where("lifecycle_stage = ?", models.LifecycleStageDeleted). + Updates(&models.Experiment{ + LifecycleStage: models.LifecycleStageActive, + LastUpdateTime: time.Now().UnixMilli(), + }) + + if uex.Error != nil { + return fmt.Errorf("failed to update experiment (%d) during delete: %w", idInt, err) + } + + if uex.RowsAffected != 1 { + return gorm.ErrRecordNotFound + } + + // Update runs + if err := transaction.Model(&models.Run{}). + Where("experiment_id = ?", idInt). + Select("DeletedTime", "LifecycleStage"). 
+ Updates(&models.Run{ + LifecycleStage: models.LifecycleStageActive, + DeletedTime: sql.NullInt64{}, + }).Error; err != nil { + return fmt.Errorf("failed to update runs during restore: %w", err) + } + + return nil + }); err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return contract.NewError( + protos.ErrorCode_RESOURCE_DOES_NOT_EXIST, + fmt.Sprintf("No Experiment with id=%d exists", idInt), + ) + } + + return contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + "failed to delete experiment", + err, + ) + } + + return nil +} + +//nolint:perfsprint +func (s TrackingSQLStore) GetExperimentByName( + ctx context.Context, name string, +) (*entities.Experiment, *contract.Error) { + var experiment models.Experiment + + err := s.db.WithContext(ctx).Preload("Tags").Where("name = ?", name).First(&experiment).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, contract.NewError( + protos.ErrorCode_RESOURCE_DOES_NOT_EXIST, + fmt.Sprintf("Could not find experiment with name %s", name), + ) + } + + return nil, contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf("failed to get experiment by name %s", name), + err, + ) + } + + return experiment.ToEntity(), nil +} diff --git a/pkg/tracking/store/sql/metrics.go b/pkg/tracking/store/sql/metrics.go index fbfdccc..df2b60e 100644 --- a/pkg/tracking/store/sql/metrics.go +++ b/pkg/tracking/store/sql/metrics.go @@ -1,193 +1,193 @@ -package sql - -import ( - "context" - "errors" - "fmt" - "math" - - "gorm.io/gorm" - "gorm.io/gorm/clause" - - "github.com/mlflow/mlflow-go/pkg/contract" - "github.com/mlflow/mlflow-go/pkg/entities" - "github.com/mlflow/mlflow-go/pkg/protos" - "github.com/mlflow/mlflow-go/pkg/tracking/store/sql/models" -) - -const metricsBatchSize = 500 - -func getDistinctMetricKeys(metrics []models.Metric) []string { - metricKeysMap := make(map[string]any) - for _, m := range metrics { - metricKeysMap[m.Key] = nil - } - - metricKeys := make([]string, 0, 
len(metricKeysMap)) - for key := range metricKeysMap { - metricKeys = append(metricKeys, key) - } - - return metricKeys -} - -func getLatestMetrics(transaction *gorm.DB, runID string, metricKeys []string) ([]models.LatestMetric, error) { - const batchSize = 500 - - latestMetrics := make([]models.LatestMetric, 0, len(metricKeys)) - - for skip := 0; skip < len(metricKeys); skip += batchSize { - take := int(math.Max(float64(skip+batchSize), float64(len(metricKeys)))) - if take > len(metricKeys) { - take = len(metricKeys) - } - - currentBatch := make([]models.LatestMetric, 0, take-skip) - keys := metricKeys[skip:take] - - err := transaction. - Model(&models.LatestMetric{}). - Where("run_uuid = ?", runID).Where("key IN ?", keys). - Clauses(clause.Locking{Strength: "UPDATE"}). // https://gorm.io/docs/advanced_query.html#Locking - Order("run_uuid"). - Order("key"). - Find(¤tBatch).Error - if err != nil { - return latestMetrics, fmt.Errorf( - "failed to get latest metrics for run_uuid %q, skip %d, take %d : %w", - runID, skip, take, err, - ) - } - - latestMetrics = append(latestMetrics, currentBatch...) 
- } - - return latestMetrics, nil -} - -func isNewerMetric(a models.Metric, b models.LatestMetric) bool { - return a.Step > b.Step || - (a.Step == b.Step && a.Timestamp > b.Timestamp) || - (a.Step == b.Step && a.Timestamp == b.Timestamp && a.Value > b.Value) -} - -//nolint:cyclop -func updateLatestMetricsIfNecessary(transaction *gorm.DB, runID string, metrics []models.Metric) error { - if len(metrics) == 0 { - return nil - } - - metricKeys := getDistinctMetricKeys(metrics) - - latestMetrics, err := getLatestMetrics(transaction, runID, metricKeys) - if err != nil { - return fmt.Errorf("failed to get latest metrics for run_uuid %q: %w", runID, err) - } - - latestMetricsMap := make(map[string]models.LatestMetric, len(latestMetrics)) - for _, m := range latestMetrics { - latestMetricsMap[m.Key] = m - } - - nextLatestMetricsMap := make(map[string]models.LatestMetric, len(metrics)) - - for _, metric := range metrics { - latestMetric, found := latestMetricsMap[metric.Key] - nextLatestMetric, alreadyPresent := nextLatestMetricsMap[metric.Key] - - switch { - case !found && !alreadyPresent: - // brand new latest metric - nextLatestMetricsMap[metric.Key] = metric.NewLatestMetricFromProto() - case !found && alreadyPresent && isNewerMetric(metric, nextLatestMetric): - // there is no row in the database but the metric is present twice - // and we need to take the latest one from the batch. 
- nextLatestMetricsMap[metric.Key] = metric.NewLatestMetricFromProto() - case found && isNewerMetric(metric, latestMetric): - // compare with the row in the database - nextLatestMetricsMap[metric.Key] = metric.NewLatestMetricFromProto() - } - } - - nextLatestMetrics := make([]models.LatestMetric, 0, len(nextLatestMetricsMap)) - for _, nextLatestMetric := range nextLatestMetricsMap { - nextLatestMetrics = append(nextLatestMetrics, nextLatestMetric) - } - - if len(nextLatestMetrics) != 0 { - if err := transaction.Clauses(clause.OnConflict{ - UpdateAll: true, - }).Create(nextLatestMetrics).Error; err != nil { - return fmt.Errorf("failed to upsert latest metrics for run_uuid %q: %w", runID, err) - } - } - - return nil -} - -func (s TrackingSQLStore) logMetricsWithTransaction( - transaction *gorm.DB, runID string, metrics []*entities.Metric, -) *contract.Error { - // Duplicate metric values are eliminated - seenMetrics := make(map[models.Metric]struct{}) - modelMetrics := make([]models.Metric, 0, len(metrics)) - - for _, metric := range metrics { - currentMetric := models.NewMetricFromEntity(runID, metric) - if _, ok := seenMetrics[*currentMetric]; !ok { - seenMetrics[*currentMetric] = struct{}{} - - modelMetrics = append(modelMetrics, *currentMetric) - } - } - - if err := transaction.Clauses(clause.OnConflict{DoNothing: true}). 
- CreateInBatches(modelMetrics, metricsBatchSize).Error; err != nil { - return contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - fmt.Sprintf("error creating metrics in batch for run_uuid %q", runID), - err, - ) - } - - if err := updateLatestMetricsIfNecessary(transaction, runID, modelMetrics); err != nil { - return contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - fmt.Sprintf("error updating latest metrics for run_uuid %q", runID), - err, - ) - } - - return nil -} - -func (s TrackingSQLStore) LogMetric(ctx context.Context, runID string, metric *entities.Metric) *contract.Error { - err := s.db.WithContext(ctx).Transaction(func(transaction *gorm.DB) error { - contractError := checkRunIsActive(transaction, runID) - if contractError != nil { - return contractError - } - - if err := s.logMetricsWithTransaction(transaction, runID, []*entities.Metric{ - metric, - }); err != nil { - return err - } - - return nil - }) - if err != nil { - var contractError *contract.Error - if errors.As(err, &contractError) { - return contractError - } - - return contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - fmt.Sprintf("log metric transaction failed for %q", runID), - err, - ) - } - - return nil -} +package sql + +import ( + "context" + "errors" + "fmt" + "math" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + + "github.com/mlflow/mlflow-go/pkg/contract" + "github.com/mlflow/mlflow-go/pkg/entities" + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/tracking/store/sql/models" +) + +const metricsBatchSize = 500 + +func getDistinctMetricKeys(metrics []models.Metric) []string { + metricKeysMap := make(map[string]any) + for _, m := range metrics { + metricKeysMap[m.Key] = nil + } + + metricKeys := make([]string, 0, len(metricKeysMap)) + for key := range metricKeysMap { + metricKeys = append(metricKeys, key) + } + + return metricKeys +} + +func getLatestMetrics(transaction *gorm.DB, runID string, metricKeys []string) ([]models.LatestMetric, 
error) { + const batchSize = 500 + + latestMetrics := make([]models.LatestMetric, 0, len(metricKeys)) + + for skip := 0; skip < len(metricKeys); skip += batchSize { + take := int(math.Max(float64(skip+batchSize), float64(len(metricKeys)))) + if take > len(metricKeys) { + take = len(metricKeys) + } + + currentBatch := make([]models.LatestMetric, 0, take-skip) + keys := metricKeys[skip:take] + + err := transaction. + Model(&models.LatestMetric{}). + Where("run_uuid = ?", runID).Where("key IN ?", keys). + Clauses(clause.Locking{Strength: "UPDATE"}). // https://gorm.io/docs/advanced_query.html#Locking + Order("run_uuid"). + Order("key"). + Find(¤tBatch).Error + if err != nil { + return latestMetrics, fmt.Errorf( + "failed to get latest metrics for run_uuid %q, skip %d, take %d : %w", + runID, skip, take, err, + ) + } + + latestMetrics = append(latestMetrics, currentBatch...) + } + + return latestMetrics, nil +} + +func isNewerMetric(a models.Metric, b models.LatestMetric) bool { + return a.Step > b.Step || + (a.Step == b.Step && a.Timestamp > b.Timestamp) || + (a.Step == b.Step && a.Timestamp == b.Timestamp && a.Value > b.Value) +} + +//nolint:cyclop +func updateLatestMetricsIfNecessary(transaction *gorm.DB, runID string, metrics []models.Metric) error { + if len(metrics) == 0 { + return nil + } + + metricKeys := getDistinctMetricKeys(metrics) + + latestMetrics, err := getLatestMetrics(transaction, runID, metricKeys) + if err != nil { + return fmt.Errorf("failed to get latest metrics for run_uuid %q: %w", runID, err) + } + + latestMetricsMap := make(map[string]models.LatestMetric, len(latestMetrics)) + for _, m := range latestMetrics { + latestMetricsMap[m.Key] = m + } + + nextLatestMetricsMap := make(map[string]models.LatestMetric, len(metrics)) + + for _, metric := range metrics { + latestMetric, found := latestMetricsMap[metric.Key] + nextLatestMetric, alreadyPresent := nextLatestMetricsMap[metric.Key] + + switch { + case !found && !alreadyPresent: + // brand new 
latest metric + nextLatestMetricsMap[metric.Key] = metric.NewLatestMetricFromProto() + case !found && alreadyPresent && isNewerMetric(metric, nextLatestMetric): + // there is no row in the database but the metric is present twice + // and we need to take the latest one from the batch. + nextLatestMetricsMap[metric.Key] = metric.NewLatestMetricFromProto() + case found && isNewerMetric(metric, latestMetric): + // compare with the row in the database + nextLatestMetricsMap[metric.Key] = metric.NewLatestMetricFromProto() + } + } + + nextLatestMetrics := make([]models.LatestMetric, 0, len(nextLatestMetricsMap)) + for _, nextLatestMetric := range nextLatestMetricsMap { + nextLatestMetrics = append(nextLatestMetrics, nextLatestMetric) + } + + if len(nextLatestMetrics) != 0 { + if err := transaction.Clauses(clause.OnConflict{ + UpdateAll: true, + }).Create(nextLatestMetrics).Error; err != nil { + return fmt.Errorf("failed to upsert latest metrics for run_uuid %q: %w", runID, err) + } + } + + return nil +} + +func (s TrackingSQLStore) logMetricsWithTransaction( + transaction *gorm.DB, runID string, metrics []*entities.Metric, +) *contract.Error { + // Duplicate metric values are eliminated + seenMetrics := make(map[models.Metric]struct{}) + modelMetrics := make([]models.Metric, 0, len(metrics)) + + for _, metric := range metrics { + currentMetric := models.NewMetricFromEntity(runID, metric) + if _, ok := seenMetrics[*currentMetric]; !ok { + seenMetrics[*currentMetric] = struct{}{} + + modelMetrics = append(modelMetrics, *currentMetric) + } + } + + if err := transaction.Clauses(clause.OnConflict{DoNothing: true}). 
+ CreateInBatches(modelMetrics, metricsBatchSize).Error; err != nil { + return contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf("error creating metrics in batch for run_uuid %q", runID), + err, + ) + } + + if err := updateLatestMetricsIfNecessary(transaction, runID, modelMetrics); err != nil { + return contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf("error updating latest metrics for run_uuid %q", runID), + err, + ) + } + + return nil +} + +func (s TrackingSQLStore) LogMetric(ctx context.Context, runID string, metric *entities.Metric) *contract.Error { + err := s.db.WithContext(ctx).Transaction(func(transaction *gorm.DB) error { + contractError := checkRunIsActive(transaction, runID) + if contractError != nil { + return contractError + } + + if err := s.logMetricsWithTransaction(transaction, runID, []*entities.Metric{ + metric, + }); err != nil { + return err + } + + return nil + }) + if err != nil { + var contractError *contract.Error + if errors.As(err, &contractError) { + return contractError + } + + return contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf("log metric transaction failed for %q", runID), + err, + ) + } + + return nil +} diff --git a/pkg/tracking/store/sql/models/alembic_version.go b/pkg/tracking/store/sql/models/alembic_version.go index bf30c95..b97ed95 100644 --- a/pkg/tracking/store/sql/models/alembic_version.go +++ b/pkg/tracking/store/sql/models/alembic_version.go @@ -1,11 +1,11 @@ -package models - -// AlembicVersion mapped from table . -type AlembicVersion struct { - VersionNum *string `db:"version_num" gorm:"column:version_num;primaryKey"` -} - -// TableName AlembicVersion's table name. -func (*AlembicVersion) TableName() string { - return "alembic_version" -} +package models + +// AlembicVersion mapped from table . +type AlembicVersion struct { + VersionNum *string `db:"version_num" gorm:"column:version_num;primaryKey"` +} + +// TableName AlembicVersion's table name. 
+func (*AlembicVersion) TableName() string { + return "alembic_version" +} diff --git a/pkg/tracking/store/sql/models/datasets.go b/pkg/tracking/store/sql/models/datasets.go index 1fda973..6618375 100644 --- a/pkg/tracking/store/sql/models/datasets.go +++ b/pkg/tracking/store/sql/models/datasets.go @@ -1,28 +1,28 @@ -package models - -import ( - "github.com/mlflow/mlflow-go/pkg/entities" -) - -// Dataset mapped from table . -type Dataset struct { - ID string `db:"dataset_uuid" gorm:"column:dataset_uuid;not null"` - ExperimentID int32 `db:"experiment_id" gorm:"column:experiment_id;primaryKey"` - Name string `db:"name" gorm:"column:name;primaryKey"` - Digest string `db:"digest" gorm:"column:digest;primaryKey"` - SourceType string `db:"dataset_source_type" gorm:"column:dataset_source_type;not null"` - Source string `db:"dataset_source" gorm:"column:dataset_source;not null"` - Schema string `db:"dataset_schema" gorm:"column:dataset_schema"` - Profile string `db:"dataset_profile" gorm:"column:dataset_profile"` -} - -func (d *Dataset) ToEntity() *entities.Dataset { - return &entities.Dataset{ - Name: d.Name, - Digest: d.Digest, - SourceType: d.SourceType, - Source: d.Source, - Schema: d.Schema, - Profile: d.Profile, - } -} +package models + +import ( + "github.com/mlflow/mlflow-go/pkg/entities" +) + +// Dataset mapped from table . 
+type Dataset struct { + ID string `db:"dataset_uuid" gorm:"column:dataset_uuid;not null"` + ExperimentID int32 `db:"experiment_id" gorm:"column:experiment_id;primaryKey"` + Name string `db:"name" gorm:"column:name;primaryKey"` + Digest string `db:"digest" gorm:"column:digest;primaryKey"` + SourceType string `db:"dataset_source_type" gorm:"column:dataset_source_type;not null"` + Source string `db:"dataset_source" gorm:"column:dataset_source;not null"` + Schema string `db:"dataset_schema" gorm:"column:dataset_schema"` + Profile string `db:"dataset_profile" gorm:"column:dataset_profile"` +} + +func (d *Dataset) ToEntity() *entities.Dataset { + return &entities.Dataset{ + Name: d.Name, + Digest: d.Digest, + SourceType: d.SourceType, + Source: d.Source, + Schema: d.Schema, + Profile: d.Profile, + } +} diff --git a/pkg/tracking/store/sql/models/experiment_tags.go b/pkg/tracking/store/sql/models/experiment_tags.go index 54aca6b..8808f7f 100644 --- a/pkg/tracking/store/sql/models/experiment_tags.go +++ b/pkg/tracking/store/sql/models/experiment_tags.go @@ -1,10 +1,10 @@ -package models - -const TableNameExperimentTag = "experiment_tags" - -// ExperimentTag mapped from table . -type ExperimentTag struct { - Key string `db:"key" gorm:"column:key;primaryKey"` - Value string `db:"value" gorm:"column:value"` - ExperimentID int32 `db:"experiment_id" gorm:"column:experiment_id;primaryKey"` -} +package models + +const TableNameExperimentTag = "experiment_tags" + +// ExperimentTag mapped from table . 
+type ExperimentTag struct { + Key string `db:"key" gorm:"column:key;primaryKey"` + Value string `db:"value" gorm:"column:value"` + ExperimentID int32 `db:"experiment_id" gorm:"column:experiment_id;primaryKey"` +} diff --git a/pkg/tracking/store/sql/models/experiments.go b/pkg/tracking/store/sql/models/experiments.go index 2d057b3..fcecc46 100644 --- a/pkg/tracking/store/sql/models/experiments.go +++ b/pkg/tracking/store/sql/models/experiments.go @@ -1,40 +1,40 @@ -package models - -import ( - "strconv" - - "github.com/mlflow/mlflow-go/pkg/entities" -) - -// Experiment mapped from table . -type Experiment struct { - ID int32 `gorm:"column:experiment_id;primaryKey;autoIncrement:true"` - Name string `gorm:"column:name;not null"` - ArtifactLocation string `gorm:"column:artifact_location"` - LifecycleStage LifecycleStage `gorm:"column:lifecycle_stage"` - CreationTime int64 `gorm:"column:creation_time"` - LastUpdateTime int64 `gorm:"column:last_update_time"` - Tags []ExperimentTag - Runs []Run -} - -func (e Experiment) ToEntity() *entities.Experiment { - experiment := entities.Experiment{ - ExperimentID: strconv.Itoa(int(e.ID)), - Name: e.Name, - ArtifactLocation: e.ArtifactLocation, - LifecycleStage: e.LifecycleStage.String(), - CreationTime: e.CreationTime, - LastUpdateTime: e.LastUpdateTime, - Tags: make([]*entities.ExperimentTag, len(e.Tags)), - } - - for i, tag := range e.Tags { - experiment.Tags[i] = &entities.ExperimentTag{ - Key: tag.Key, - Value: tag.Value, - } - } - - return &experiment -} +package models + +import ( + "strconv" + + "github.com/mlflow/mlflow-go/pkg/entities" +) + +// Experiment mapped from table . 
+type Experiment struct { + ID int32 `gorm:"column:experiment_id;primaryKey;autoIncrement:true"` + Name string `gorm:"column:name;not null"` + ArtifactLocation string `gorm:"column:artifact_location"` + LifecycleStage LifecycleStage `gorm:"column:lifecycle_stage"` + CreationTime int64 `gorm:"column:creation_time"` + LastUpdateTime int64 `gorm:"column:last_update_time"` + Tags []ExperimentTag + Runs []Run +} + +func (e Experiment) ToEntity() *entities.Experiment { + experiment := entities.Experiment{ + ExperimentID: strconv.Itoa(int(e.ID)), + Name: e.Name, + ArtifactLocation: e.ArtifactLocation, + LifecycleStage: e.LifecycleStage.String(), + CreationTime: e.CreationTime, + LastUpdateTime: e.LastUpdateTime, + Tags: make([]*entities.ExperimentTag, len(e.Tags)), + } + + for i, tag := range e.Tags { + experiment.Tags[i] = &entities.ExperimentTag{ + Key: tag.Key, + Value: tag.Value, + } + } + + return &experiment +} diff --git a/pkg/tracking/store/sql/models/input_tags.go b/pkg/tracking/store/sql/models/input_tags.go index ca33747..b1d55b3 100644 --- a/pkg/tracking/store/sql/models/input_tags.go +++ b/pkg/tracking/store/sql/models/input_tags.go @@ -1,19 +1,19 @@ -package models - -import ( - "github.com/mlflow/mlflow-go/pkg/entities" -) - -// InputTag mapped from table . -type InputTag struct { - Key string `gorm:"column:name;primaryKey"` - Value string `gorm:"column:value;not null"` - InputID string `gorm:"column:input_uuid;primaryKey"` -} - -func (i *InputTag) ToEntity() *entities.InputTag { - return &entities.InputTag{ - Key: i.Key, - Value: i.Value, - } -} +package models + +import ( + "github.com/mlflow/mlflow-go/pkg/entities" +) + +// InputTag mapped from table . 
+type InputTag struct { + Key string `gorm:"column:name;primaryKey"` + Value string `gorm:"column:value;not null"` + InputID string `gorm:"column:input_uuid;primaryKey"` +} + +func (i *InputTag) ToEntity() *entities.InputTag { + return &entities.InputTag{ + Key: i.Key, + Value: i.Value, + } +} diff --git a/pkg/tracking/store/sql/models/inputs.go b/pkg/tracking/store/sql/models/inputs.go index b9846a6..ed904eb 100644 --- a/pkg/tracking/store/sql/models/inputs.go +++ b/pkg/tracking/store/sql/models/inputs.go @@ -1,28 +1,28 @@ -package models - -import ( - "github.com/mlflow/mlflow-go/pkg/entities" -) - -// Input mapped from table . -type Input struct { - ID string `db:"input_uuid" gorm:"column:input_uuid;not null"` - SourceType string `db:"source_type" gorm:"column:source_type;primaryKey"` - SourceID string `db:"source_id" gorm:"column:source_id;primaryKey"` - DestinationType string `db:"destination_type" gorm:"column:destination_type;primaryKey"` - DestinationID string `db:"destination_id" gorm:"column:destination_id;primaryKey"` - Tags []InputTag `gorm:"foreignKey:InputID;references:ID"` - Dataset Dataset `gorm:"foreignKey:ID;references:SourceID"` -} - -func (i *Input) ToEntity() *entities.DatasetInput { - tags := make([]*entities.InputTag, 0, len(i.Tags)) - for _, tag := range i.Tags { - tags = append(tags, tag.ToEntity()) - } - - return &entities.DatasetInput{ - Tags: tags, - Dataset: i.Dataset.ToEntity(), - } -} +package models + +import ( + "github.com/mlflow/mlflow-go/pkg/entities" +) + +// Input mapped from table . 
+type Input struct { + ID string `db:"input_uuid" gorm:"column:input_uuid;not null"` + SourceType string `db:"source_type" gorm:"column:source_type;primaryKey"` + SourceID string `db:"source_id" gorm:"column:source_id;primaryKey"` + DestinationType string `db:"destination_type" gorm:"column:destination_type;primaryKey"` + DestinationID string `db:"destination_id" gorm:"column:destination_id;primaryKey"` + Tags []InputTag `gorm:"foreignKey:InputID;references:ID"` + Dataset Dataset `gorm:"foreignKey:ID;references:SourceID"` +} + +func (i *Input) ToEntity() *entities.DatasetInput { + tags := make([]*entities.InputTag, 0, len(i.Tags)) + for _, tag := range i.Tags { + tags = append(tags, tag.ToEntity()) + } + + return &entities.DatasetInput{ + Tags: tags, + Dataset: i.Dataset.ToEntity(), + } +} diff --git a/pkg/tracking/store/sql/models/latest_metrics.go b/pkg/tracking/store/sql/models/latest_metrics.go index 021e27e..0271650 100644 --- a/pkg/tracking/store/sql/models/latest_metrics.go +++ b/pkg/tracking/store/sql/models/latest_metrics.go @@ -1,25 +1,25 @@ -package models - -import ( - "github.com/mlflow/mlflow-go/pkg/entities" -) - -// LatestMetric mapped from table . -type LatestMetric struct { - Key string `db:"key" gorm:"column:key;primaryKey"` - Value float64 `db:"value" gorm:"column:value;not null"` - Timestamp int64 `db:"timestamp" gorm:"column:timestamp"` - Step int64 `db:"step" gorm:"column:step;not null"` - IsNaN bool `db:"is_nan" gorm:"column:is_nan;not null"` - RunID string `db:"run_uuid" gorm:"column:run_uuid;primaryKey"` -} - -func (lm LatestMetric) ToEntity() *entities.Metric { - return &entities.Metric{ - Key: lm.Key, - Value: lm.Value, - Timestamp: lm.Timestamp, - Step: lm.Step, - IsNaN: lm.IsNaN, - } -} +package models + +import ( + "github.com/mlflow/mlflow-go/pkg/entities" +) + +// LatestMetric mapped from table . 
+type LatestMetric struct { + Key string `db:"key" gorm:"column:key;primaryKey"` + Value float64 `db:"value" gorm:"column:value;not null"` + Timestamp int64 `db:"timestamp" gorm:"column:timestamp"` + Step int64 `db:"step" gorm:"column:step;not null"` + IsNaN bool `db:"is_nan" gorm:"column:is_nan;not null"` + RunID string `db:"run_uuid" gorm:"column:run_uuid;primaryKey"` +} + +func (lm LatestMetric) ToEntity() *entities.Metric { + return &entities.Metric{ + Key: lm.Key, + Value: lm.Value, + Timestamp: lm.Timestamp, + Step: lm.Step, + IsNaN: lm.IsNaN, + } +} diff --git a/pkg/tracking/store/sql/models/lifecycle.go b/pkg/tracking/store/sql/models/lifecycle.go index c01ad5e..13cc716 100644 --- a/pkg/tracking/store/sql/models/lifecycle.go +++ b/pkg/tracking/store/sql/models/lifecycle.go @@ -1,12 +1,12 @@ -package models - -type LifecycleStage string - -func (s LifecycleStage) String() string { - return string(s) -} - -const ( - LifecycleStageActive LifecycleStage = "active" - LifecycleStageDeleted LifecycleStage = "deleted" -) +package models + +type LifecycleStage string + +func (s LifecycleStage) String() string { + return string(s) +} + +const ( + LifecycleStageActive LifecycleStage = "active" + LifecycleStageDeleted LifecycleStage = "deleted" +) diff --git a/pkg/tracking/store/sql/models/metrics.go b/pkg/tracking/store/sql/models/metrics.go index ef410a7..e100cdf 100644 --- a/pkg/tracking/store/sql/models/metrics.go +++ b/pkg/tracking/store/sql/models/metrics.go @@ -1,57 +1,57 @@ -package models - -import ( - "math" - - "github.com/mlflow/mlflow-go/pkg/entities" -) - -// Metric mapped from table . 
-type Metric struct { - Key string `db:"key" gorm:"column:key;primaryKey"` - Value float64 `db:"value" gorm:"column:value;primaryKey"` - Timestamp int64 `db:"timestamp" gorm:"column:timestamp;primaryKey"` - RunID string `db:"run_uuid" gorm:"column:run_uuid;primaryKey"` - Step int64 `db:"step" gorm:"column:step;primaryKey"` - IsNaN bool `db:"is_nan" gorm:"column:is_nan;primaryKey"` -} - -func NewMetricFromEntity(runID string, metric *entities.Metric) *Metric { - model := Metric{ - RunID: runID, - Key: metric.Key, - Timestamp: metric.Timestamp, - } - - if metric.Step != 0 { - model.Step = metric.Step - } - - switch { - case math.IsNaN(metric.Value): - model.Value = 0 - model.IsNaN = true - case math.IsInf(metric.Value, 0): - // NB: SQL cannot represent Infs => We replace +/- Inf with max/min 64b float value - if metric.Value > 0 { - model.Value = math.MaxFloat64 - } else { - model.Value = -math.MaxFloat64 - } - default: - model.Value = metric.Value - } - - return &model -} - -func (m Metric) NewLatestMetricFromProto() LatestMetric { - return LatestMetric{ - RunID: m.RunID, - Key: m.Key, - Value: m.Value, - Timestamp: m.Timestamp, - Step: m.Step, - IsNaN: m.IsNaN, - } -} +package models + +import ( + "math" + + "github.com/mlflow/mlflow-go/pkg/entities" +) + +// Metric mapped from table . 
+type Metric struct { + Key string `db:"key" gorm:"column:key;primaryKey"` + Value float64 `db:"value" gorm:"column:value;primaryKey"` + Timestamp int64 `db:"timestamp" gorm:"column:timestamp;primaryKey"` + RunID string `db:"run_uuid" gorm:"column:run_uuid;primaryKey"` + Step int64 `db:"step" gorm:"column:step;primaryKey"` + IsNaN bool `db:"is_nan" gorm:"column:is_nan;primaryKey"` +} + +func NewMetricFromEntity(runID string, metric *entities.Metric) *Metric { + model := Metric{ + RunID: runID, + Key: metric.Key, + Timestamp: metric.Timestamp, + } + + if metric.Step != 0 { + model.Step = metric.Step + } + + switch { + case math.IsNaN(metric.Value): + model.Value = 0 + model.IsNaN = true + case math.IsInf(metric.Value, 0): + // NB: SQL cannot represent Infs => We replace +/- Inf with max/min 64b float value + if metric.Value > 0 { + model.Value = math.MaxFloat64 + } else { + model.Value = -math.MaxFloat64 + } + default: + model.Value = metric.Value + } + + return &model +} + +func (m Metric) NewLatestMetricFromProto() LatestMetric { + return LatestMetric{ + RunID: m.RunID, + Key: m.Key, + Value: m.Value, + Timestamp: m.Timestamp, + Step: m.Step, + IsNaN: m.IsNaN, + } +} diff --git a/pkg/tracking/store/sql/models/params.go b/pkg/tracking/store/sql/models/params.go index 1f2f8f5..276fc97 100644 --- a/pkg/tracking/store/sql/models/params.go +++ b/pkg/tracking/store/sql/models/params.go @@ -1,27 +1,27 @@ -package models - -import ( - "github.com/mlflow/mlflow-go/pkg/entities" -) - -// Param mapped from table . 
-type Param struct { - Key string `db:"key" gorm:"column:key;primaryKey"` - Value string `db:"value" gorm:"column:value;not null"` - RunID string `db:"run_uuid" gorm:"column:run_uuid;primaryKey"` -} - -func (p Param) ToEntity() *entities.Param { - return &entities.Param{ - Key: p.Key, - Value: p.Value, - } -} - -func NewParamFromEntity(runID string, param *entities.Param) Param { - return Param{ - Key: param.Key, - Value: param.Value, - RunID: runID, - } -} +package models + +import ( + "github.com/mlflow/mlflow-go/pkg/entities" +) + +// Param mapped from table . +type Param struct { + Key string `db:"key" gorm:"column:key;primaryKey"` + Value string `db:"value" gorm:"column:value;not null"` + RunID string `db:"run_uuid" gorm:"column:run_uuid;primaryKey"` +} + +func (p Param) ToEntity() *entities.Param { + return &entities.Param{ + Key: p.Key, + Value: p.Value, + } +} + +func NewParamFromEntity(runID string, param *entities.Param) Param { + return Param{ + Key: param.Key, + Value: param.Value, + RunID: runID, + } +} diff --git a/pkg/tracking/store/sql/models/runs.go b/pkg/tracking/store/sql/models/runs.go index 23ef484..810e9a8 100644 --- a/pkg/tracking/store/sql/models/runs.go +++ b/pkg/tracking/store/sql/models/runs.go @@ -1,106 +1,106 @@ -package models - -import ( - "database/sql" - - "github.com/mlflow/mlflow-go/pkg/entities" - "github.com/mlflow/mlflow-go/pkg/utils" -) - -// Run mapped from table . 
-type Run struct { - ID string `db:"run_uuid" gorm:"column:run_uuid;primaryKey"` - Name string `db:"name" gorm:"column:name"` - SourceType SourceType `db:"source_type" gorm:"column:source_type"` - SourceName string `db:"source_name" gorm:"column:source_name"` - EntryPointName string `db:"entry_point_name" gorm:"column:entry_point_name"` - UserID string `db:"user_id" gorm:"column:user_id"` - Status RunStatus `db:"status" gorm:"column:status"` - StartTime int64 `db:"start_time" gorm:"column:start_time"` - EndTime sql.NullInt64 `db:"end_time" gorm:"column:end_time"` - SourceVersion string `db:"source_version" gorm:"column:source_version"` - LifecycleStage LifecycleStage `db:"lifecycle_stage" gorm:"column:lifecycle_stage"` - ArtifactURI string `db:"artifact_uri" gorm:"column:artifact_uri"` - ExperimentID int32 `db:"experiment_id" gorm:"column:experiment_id"` - DeletedTime sql.NullInt64 `db:"deleted_time" gorm:"column:deleted_time"` - Params []Param - Tags []Tag - Metrics []Metric - LatestMetrics []LatestMetric - Inputs []Input `gorm:"foreignKey:DestinationID"` -} - -type RunStatus string - -func (s RunStatus) String() string { - return string(s) -} - -const ( - RunStatusRunning RunStatus = "RUNNING" - RunStatusScheduled RunStatus = "SCHEDULED" - RunStatusFinished RunStatus = "FINISHED" - RunStatusFailed RunStatus = "FAILED" - RunStatusKilled RunStatus = "KILLED" -) - -type SourceType string - -const ( - SourceTypeNotebook SourceType = "NOTEBOOK" - SourceTypeJob SourceType = "JOB" - SourceTypeProject SourceType = "PROJECT" - SourceTypeLocal SourceType = "LOCAL" - SourceTypeUnknown SourceType = "UNKNOWN" - SourceTypeRecipe SourceType = "RECIPE" -) - -func (r Run) ToEntity() *entities.Run { - metrics := make([]*entities.Metric, 0, len(r.LatestMetrics)) - for _, metric := range r.LatestMetrics { - metrics = append(metrics, metric.ToEntity()) - } - - params := make([]*entities.Param, 0, len(r.Params)) - for _, param := range r.Params { - params = append(params, 
param.ToEntity()) - } - - tags := make([]*entities.RunTag, 0, len(r.Tags)) - for _, tag := range r.Tags { - tags = append(tags, tag.ToEntity()) - } - - datasetInputs := make([]*entities.DatasetInput, 0, len(r.Inputs)) - for _, input := range r.Inputs { - datasetInputs = append(datasetInputs, input.ToEntity()) - } - - var endTime *int64 - if r.EndTime.Valid { - endTime = utils.PtrTo(r.EndTime.Int64) - } - - return &entities.Run{ - Info: &entities.RunInfo{ - RunID: r.ID, - RunUUID: r.ID, - RunName: r.Name, - ExperimentID: r.ExperimentID, - UserID: r.UserID, - Status: r.Status.String(), - StartTime: r.StartTime, - EndTime: endTime, - ArtifactURI: r.ArtifactURI, - LifecycleStage: r.LifecycleStage.String(), - }, - Data: &entities.RunData{ - Tags: tags, - Params: params, - Metrics: metrics, - }, - Inputs: &entities.RunInputs{ - DatasetInputs: datasetInputs, - }, - } -} +package models + +import ( + "database/sql" + + "github.com/mlflow/mlflow-go/pkg/entities" + "github.com/mlflow/mlflow-go/pkg/utils" +) + +// Run mapped from table . 
+type Run struct { + ID string `db:"run_uuid" gorm:"column:run_uuid;primaryKey"` + Name string `db:"name" gorm:"column:name"` + SourceType SourceType `db:"source_type" gorm:"column:source_type"` + SourceName string `db:"source_name" gorm:"column:source_name"` + EntryPointName string `db:"entry_point_name" gorm:"column:entry_point_name"` + UserID string `db:"user_id" gorm:"column:user_id"` + Status RunStatus `db:"status" gorm:"column:status"` + StartTime int64 `db:"start_time" gorm:"column:start_time"` + EndTime sql.NullInt64 `db:"end_time" gorm:"column:end_time"` + SourceVersion string `db:"source_version" gorm:"column:source_version"` + LifecycleStage LifecycleStage `db:"lifecycle_stage" gorm:"column:lifecycle_stage"` + ArtifactURI string `db:"artifact_uri" gorm:"column:artifact_uri"` + ExperimentID int32 `db:"experiment_id" gorm:"column:experiment_id"` + DeletedTime sql.NullInt64 `db:"deleted_time" gorm:"column:deleted_time"` + Params []Param + Tags []Tag + Metrics []Metric + LatestMetrics []LatestMetric + Inputs []Input `gorm:"foreignKey:DestinationID"` +} + +type RunStatus string + +func (s RunStatus) String() string { + return string(s) +} + +const ( + RunStatusRunning RunStatus = "RUNNING" + RunStatusScheduled RunStatus = "SCHEDULED" + RunStatusFinished RunStatus = "FINISHED" + RunStatusFailed RunStatus = "FAILED" + RunStatusKilled RunStatus = "KILLED" +) + +type SourceType string + +const ( + SourceTypeNotebook SourceType = "NOTEBOOK" + SourceTypeJob SourceType = "JOB" + SourceTypeProject SourceType = "PROJECT" + SourceTypeLocal SourceType = "LOCAL" + SourceTypeUnknown SourceType = "UNKNOWN" + SourceTypeRecipe SourceType = "RECIPE" +) + +func (r Run) ToEntity() *entities.Run { + metrics := make([]*entities.Metric, 0, len(r.LatestMetrics)) + for _, metric := range r.LatestMetrics { + metrics = append(metrics, metric.ToEntity()) + } + + params := make([]*entities.Param, 0, len(r.Params)) + for _, param := range r.Params { + params = append(params, 
param.ToEntity()) + } + + tags := make([]*entities.RunTag, 0, len(r.Tags)) + for _, tag := range r.Tags { + tags = append(tags, tag.ToEntity()) + } + + datasetInputs := make([]*entities.DatasetInput, 0, len(r.Inputs)) + for _, input := range r.Inputs { + datasetInputs = append(datasetInputs, input.ToEntity()) + } + + var endTime *int64 + if r.EndTime.Valid { + endTime = utils.PtrTo(r.EndTime.Int64) + } + + return &entities.Run{ + Info: &entities.RunInfo{ + RunID: r.ID, + RunUUID: r.ID, + RunName: r.Name, + ExperimentID: r.ExperimentID, + UserID: r.UserID, + Status: r.Status.String(), + StartTime: r.StartTime, + EndTime: endTime, + ArtifactURI: r.ArtifactURI, + LifecycleStage: r.LifecycleStage.String(), + }, + Data: &entities.RunData{ + Tags: tags, + Params: params, + Metrics: metrics, + }, + Inputs: &entities.RunInputs{ + DatasetInputs: datasetInputs, + }, + } +} diff --git a/pkg/tracking/store/sql/models/tags.go b/pkg/tracking/store/sql/models/tags.go index 0c8dbe1..d955a2a 100644 --- a/pkg/tracking/store/sql/models/tags.go +++ b/pkg/tracking/store/sql/models/tags.go @@ -1,31 +1,31 @@ -package models - -import ( - "github.com/mlflow/mlflow-go/pkg/entities" -) - -// Tag mapped from table . -type Tag struct { - Key string `db:"key" gorm:"column:key;primaryKey"` - Value string `db:"value" gorm:"column:value"` - RunID string `db:"run_uuid" gorm:"column:run_uuid;primaryKey"` -} - -func (t Tag) ToEntity() *entities.RunTag { - return &entities.RunTag{ - Key: t.Key, - Value: t.Value, - } -} - -func NewTagFromEntity(runID string, entity *entities.RunTag) Tag { - tag := Tag{ - Key: entity.Key, - Value: entity.Value, - } - if runID != "" { - tag.RunID = runID - } - - return tag -} +package models + +import ( + "github.com/mlflow/mlflow-go/pkg/entities" +) + +// Tag mapped from table . 
+type Tag struct { + Key string `db:"key" gorm:"column:key;primaryKey"` + Value string `db:"value" gorm:"column:value"` + RunID string `db:"run_uuid" gorm:"column:run_uuid;primaryKey"` +} + +func (t Tag) ToEntity() *entities.RunTag { + return &entities.RunTag{ + Key: t.Key, + Value: t.Value, + } +} + +func NewTagFromEntity(runID string, entity *entities.RunTag) Tag { + tag := Tag{ + Key: entity.Key, + Value: entity.Value, + } + if runID != "" { + tag.RunID = runID + } + + return tag +} diff --git a/pkg/tracking/store/sql/params.go b/pkg/tracking/store/sql/params.go index 2d6d6c5..cbe4da7 100644 --- a/pkg/tracking/store/sql/params.go +++ b/pkg/tracking/store/sql/params.go @@ -1,119 +1,119 @@ -package sql - -import ( - "fmt" - - "gorm.io/gorm" - "gorm.io/gorm/clause" - - "github.com/mlflow/mlflow-go/pkg/contract" - "github.com/mlflow/mlflow-go/pkg/entities" - "github.com/mlflow/mlflow-go/pkg/protos" - "github.com/mlflow/mlflow-go/pkg/tracking/store/sql/models" -) - -const paramsBatchSize = 100 - -func verifyBatchParamsInserts( - transaction *gorm.DB, runID string, deduplicatedParamsMap map[string]string, -) *contract.Error { - keys := make([]string, 0, len(deduplicatedParamsMap)) - for key := range deduplicatedParamsMap { - keys = append(keys, key) - } - - var existingParams []models.Param - - err := transaction. - Model(&models.Param{}). - Select("key, value"). - Where("run_uuid = ?", runID). - Where("key IN ?", keys). - Find(&existingParams).Error - if err != nil { - return contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - fmt.Sprintf( - "failed to get existing params to check if duplicates for run_id %q", - runID, - ), - err) - } - - for _, existingParam := range existingParams { - if currentValue, ok := deduplicatedParamsMap[existingParam.Key]; ok && currentValue != existingParam.Value { - return contract.NewError( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - fmt.Sprintf( - "Changing param values is not allowed. 
"+ - "Params with key=%q was already logged "+ - "with value=%q for run ID=%q. "+ - "Attempted logging new value %q", - existingParam.Key, - existingParam.Value, - runID, - currentValue, - ), - ) - } - } - - return nil -} - -func (s TrackingSQLStore) logParamsWithTransaction( - transaction *gorm.DB, runID string, params []*entities.Param, -) *contract.Error { - deduplicatedParamsMap := make(map[string]string, len(params)) - deduplicatedParams := make([]models.Param, 0, len(deduplicatedParamsMap)) - - for _, param := range params { - oldValue, paramIsPresent := deduplicatedParamsMap[param.Key] - if paramIsPresent && param.Value != oldValue { - return contract.NewError( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - fmt.Sprintf( - "Changing param values is not allowed. "+ - "Params with key=%q was already logged "+ - "with value=%q for run ID=%q. "+ - "Attempted logging new value %q", - param.Key, - oldValue, - runID, - param.Value, - ), - ) - } - - if !paramIsPresent { - deduplicatedParamsMap[param.Key] = param.Value - deduplicatedParams = append(deduplicatedParams, models.NewParamFromEntity(runID, param)) - } - } - - // Try and create all params. - // Params are unique by (run_uuid, key) so any potentially conflicts will not be inserted. - err := transaction. - Clauses(clause.OnConflict{ - Columns: []clause.Column{{Name: "run_uuid"}, {Name: "key"}}, - DoNothing: true, - }). - CreateInBatches(deduplicatedParams, paramsBatchSize).Error - if err != nil { - return contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - fmt.Sprintf("error creating params in batch for run_uuid %q", runID), - err, - ) - } - - // if there were ignored conflicts, we assert that the values are the same. 
- if transaction.RowsAffected != int64(len(params)) { - contractError := verifyBatchParamsInserts(transaction, runID, deduplicatedParamsMap) - if contractError != nil { - return contractError - } - } - - return nil -} +package sql + +import ( + "fmt" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + + "github.com/mlflow/mlflow-go/pkg/contract" + "github.com/mlflow/mlflow-go/pkg/entities" + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/tracking/store/sql/models" +) + +const paramsBatchSize = 100 + +func verifyBatchParamsInserts( + transaction *gorm.DB, runID string, deduplicatedParamsMap map[string]string, +) *contract.Error { + keys := make([]string, 0, len(deduplicatedParamsMap)) + for key := range deduplicatedParamsMap { + keys = append(keys, key) + } + + var existingParams []models.Param + + err := transaction. + Model(&models.Param{}). + Select("key, value"). + Where("run_uuid = ?", runID). + Where("key IN ?", keys). + Find(&existingParams).Error + if err != nil { + return contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf( + "failed to get existing params to check if duplicates for run_id %q", + runID, + ), + err) + } + + for _, existingParam := range existingParams { + if currentValue, ok := deduplicatedParamsMap[existingParam.Key]; ok && currentValue != existingParam.Value { + return contract.NewError( + protos.ErrorCode_INVALID_PARAMETER_VALUE, + fmt.Sprintf( + "Changing param values is not allowed. "+ + "Params with key=%q was already logged "+ + "with value=%q for run ID=%q. 
"+ + "Attempted logging new value %q", + existingParam.Key, + existingParam.Value, + runID, + currentValue, + ), + ) + } + } + + return nil +} + +func (s TrackingSQLStore) logParamsWithTransaction( + transaction *gorm.DB, runID string, params []*entities.Param, +) *contract.Error { + deduplicatedParamsMap := make(map[string]string, len(params)) + deduplicatedParams := make([]models.Param, 0, len(deduplicatedParamsMap)) + + for _, param := range params { + oldValue, paramIsPresent := deduplicatedParamsMap[param.Key] + if paramIsPresent && param.Value != oldValue { + return contract.NewError( + protos.ErrorCode_INVALID_PARAMETER_VALUE, + fmt.Sprintf( + "Changing param values is not allowed. "+ + "Params with key=%q was already logged "+ + "with value=%q for run ID=%q. "+ + "Attempted logging new value %q", + param.Key, + oldValue, + runID, + param.Value, + ), + ) + } + + if !paramIsPresent { + deduplicatedParamsMap[param.Key] = param.Value + deduplicatedParams = append(deduplicatedParams, models.NewParamFromEntity(runID, param)) + } + } + + // Try and create all params. + // Params are unique by (run_uuid, key) so any potentially conflicts will not be inserted. + err := transaction. + Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "run_uuid"}, {Name: "key"}}, + DoNothing: true, + }). + CreateInBatches(deduplicatedParams, paramsBatchSize).Error + if err != nil { + return contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf("error creating params in batch for run_uuid %q", runID), + err, + ) + } + + // if there were ignored conflicts, we assert that the values are the same. 
+ if transaction.RowsAffected != int64(len(params)) { + contractError := verifyBatchParamsInserts(transaction, runID, deduplicatedParamsMap) + if contractError != nil { + return contractError + } + } + + return nil +} diff --git a/pkg/tracking/store/sql/runs.go b/pkg/tracking/store/sql/runs.go index 3c17920..7110671 100644 --- a/pkg/tracking/store/sql/runs.go +++ b/pkg/tracking/store/sql/runs.go @@ -1,900 +1,900 @@ -package sql - -import ( - "context" - "database/sql" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "strings" - "time" - - "gorm.io/gorm" - "gorm.io/gorm/clause" - - "github.com/mlflow/mlflow-go/pkg/contract" - "github.com/mlflow/mlflow-go/pkg/entities" - "github.com/mlflow/mlflow-go/pkg/protos" - "github.com/mlflow/mlflow-go/pkg/tracking/service/query" - "github.com/mlflow/mlflow-go/pkg/tracking/service/query/parser" - "github.com/mlflow/mlflow-go/pkg/tracking/store/sql/models" - "github.com/mlflow/mlflow-go/pkg/utils" -) - -type PageToken struct { - Offset int32 `json:"offset"` -} - -func checkRunIsActive(transaction *gorm.DB, runID string) *contract.Error { - var run models.Run - - err := transaction. - Model(&models.Run{}). - Where("run_uuid = ?", runID). - Select("lifecycle_stage"). - First(&run). 
- Error - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return contract.NewError( - protos.ErrorCode_RESOURCE_DOES_NOT_EXIST, - fmt.Sprintf("Run with id=%s not found", runID), - ) - } - - return contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - fmt.Sprintf( - "failed to get lifecycle stage for run %q", - runID, - ), - err, - ) - } - - if run.LifecycleStage != models.LifecycleStageActive { - return contract.NewError( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - fmt.Sprintf( - "The run %s must be in the 'active' state.\n"+ - "Current state is %v.", - runID, - run.LifecycleStage, - ), - ) - } - - return nil -} - -func getLifecyleStages(runViewType protos.ViewType) []models.LifecycleStage { - switch runViewType { - case protos.ViewType_ACTIVE_ONLY: - return []models.LifecycleStage{ - models.LifecycleStageActive, - } - case protos.ViewType_DELETED_ONLY: - return []models.LifecycleStage{ - models.LifecycleStageDeleted, - } - case protos.ViewType_ALL: - return []models.LifecycleStage{ - models.LifecycleStageActive, - models.LifecycleStageDeleted, - } - } - - return []models.LifecycleStage{ - models.LifecycleStageActive, - models.LifecycleStageDeleted, - } -} - -func getOffset(pageToken string) (int, *contract.Error) { - if pageToken != "" { - var token PageToken - if err := json.NewDecoder( - base64.NewDecoder( - base64.StdEncoding, - strings.NewReader(pageToken), - ), - ).Decode(&token); err != nil { - return 0, contract.NewErrorWith( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - fmt.Sprintf("invalid page_token: %q", pageToken), - err, - ) - } - - return int(token.Offset), nil - } - - return 0, nil -} - -//nolint:funlen,cyclop,gocognit -func applyFilter(ctx context.Context, database, transaction *gorm.DB, filter string) *contract.Error { - filterConditions, err := query.ParseFilter(filter) - if err != nil { - return contract.NewErrorWith( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - "error parsing search filter", - err, - ) - } - - 
utils.GetLoggerFromContext(ctx).Debugf("Filter conditions: %v", filterConditions) - - for index, clause := range filterConditions { - var kind any - - key := clause.Key - comparison := strings.ToUpper(clause.Operator.String()) - value := clause.Value - - switch clause.Identifier { - case parser.Metric: - kind = &models.LatestMetric{} - case parser.Parameter: - kind = &models.Param{} - case parser.Tag: - kind = &models.Tag{} - case parser.Dataset: - kind = &models.Dataset{} - case parser.Attribute: - kind = nil - } - - // Treat "attributes.run_name == " as "tags.`mlflow.runName` == ". - // The name column in the runs table is empty for runs logged in MLflow <= 1.29.0. - if key == "run_name" { - kind = &models.Tag{} - key = utils.TagRunName - } - - isSqliteAndILike := database.Dialector.Name() == "sqlite" && comparison == "ILIKE" - table := fmt.Sprintf("filter_%d", index) - - switch { - case kind == nil: - if isSqliteAndILike { - key = fmt.Sprintf("LOWER(runs.%s)", key) - comparison = "LIKE" - - if str, ok := value.(string); ok { - value = strings.ToLower(str) - } - - transaction.Where(fmt.Sprintf("%s %s ?", key, comparison), value) - } else { - transaction.Where(fmt.Sprintf("runs.%s %s ?", key, comparison), value) - } - case clause.Identifier == parser.Dataset && key == "context": - // SELECT * - // FROM runs - // JOIN ( - // SELECT inputs.destination_id AS run_uuid - // FROM inputs - // JOIN input_tags - // ON inputs.input_uuid = input_tags.input_uuid - // AND input_tags.name = 'mlflow.data.context' - // AND input_tags.value %s ? - // WHERE inputs.destination_type = 'RUN' - // ) AS filter_0 - // ON runs.run_uuid = filter_0.run_uuid - valueColumn := "input_tags.value " - if isSqliteAndILike { - valueColumn = "LOWER(input_tags.value) " - comparison = "LIKE" - - if str, ok := value.(string); ok { - value = strings.ToLower(str) - } - } - - transaction.Joins( - fmt.Sprintf("JOIN (?) 
AS %s ON runs.run_uuid = %s.run_uuid", table, table), - database.Select("inputs.destination_id AS run_uuid"). - Joins( - "JOIN input_tags ON inputs.input_uuid = input_tags.input_uuid"+ - " AND input_tags.name = 'mlflow.data.context'"+ - " AND "+valueColumn+comparison+" ?", - value, - ). - Where("inputs.destination_type = 'RUN'"). - Model(&models.Input{}), - ) - case clause.Identifier == parser.Dataset: - // add join with datasets - // JOIN ( - // SELECT "experiment_id", key - // FROM datasests d - // JOIN inputs ON inputs.source_id = datasets.dataset_uuid - // WHERE key comparison value - // ) AS filter_0 ON runs.experiment_id = dataset.experiment_id - // - // columns: name, digest, context - where := key + " " + comparison + " ?" - if isSqliteAndILike { - where = "LOWER(" + key + ") LIKE ?" - - if str, ok := value.(string); ok { - value = strings.ToLower(str) - } - } - - transaction.Joins( - fmt.Sprintf("JOIN (?) AS %s ON runs.run_uuid = %s.destination_id", table, table), - database.Model(kind). - Joins("JOIN inputs ON inputs.source_id = datasets.dataset_uuid"). - Where(where, value). - Select("destination_id", key), - ) - default: - where := fmt.Sprintf("value %s ?", comparison) - if isSqliteAndILike { - where = "LOWER(value) LIKE ?" - - if str, ok := value.(string); ok { - value = strings.ToLower(str) - } - } - - transaction.Joins( - fmt.Sprintf("JOIN (?) 
AS %s ON runs.run_uuid = %s.run_uuid", table, table), - database.Select("run_uuid", "value").Where("key = ?", key).Where(where, value).Model(kind), - ) - } - } - - return nil -} - -type orderByExpr struct { - identifier *string - key string - order *string -} - -var ErrInvalidOrderClauseInput = errors.New("input string is empty or only contains quote characters") - -const ( - identifierAndKeyLength = 2 - startTime = "start_time" - name = "name" - attribute = "attribute" - metric = "metric" -) - -func orderByKeyAlias(input string) string { - switch input { - case "created", "Created": - return startTime - case "run_name", "run name", "Run name", "Run Name": - return name - case "run_id": - return "run_uuid" - default: - return input - } -} - -func handleInsideQuote( - char, quoteChar rune, insideQuote bool, current strings.Builder, result []string, -) (bool, strings.Builder, []string) { - if char == quoteChar { - insideQuote = false - - result = append(result, current.String()) - current.Reset() - } else { - current.WriteRune(char) - } - - return insideQuote, current, result -} - -func handleOutsideQuote( - char rune, insideQuote bool, quoteChar rune, current strings.Builder, result []string, -) (bool, rune, strings.Builder, []string) { - switch char { - case ' ': - if current.Len() > 0 { - result = append(result, current.String()) - current.Reset() - } - case '"', '\'', '`': - insideQuote = true - quoteChar = char - default: - current.WriteRune(char) - } - - return insideQuote, quoteChar, current, result -} - -// Process an order by input string to split the string into the separate parts. -// We can't simply split by space, because the column name could be wrapped in quotes, e.g. "Run name" ASC. 
-func splitOrderByClauseWithQuotes(input string) []string { - input = strings.ToLower(strings.Trim(input, " ")) - - var result []string - - var current strings.Builder - - var insideQuote bool - - var quoteChar rune - - // Process char per char, split items on spaces unless inside a quoted entry. - for _, char := range input { - if insideQuote { - insideQuote, current, result = handleInsideQuote(char, quoteChar, insideQuote, current, result) - } else { - insideQuote, quoteChar, current, result = handleOutsideQuote(char, insideQuote, quoteChar, current, result) - } - } - - if current.Len() > 0 { - result = append(result, current.String()) - } - - return result -} - -func translateIdentifierAlias(identifier string) string { - switch strings.ToLower(identifier) { - case "metrics": - return metric - case "parameters", "param", "params": - return "parameter" - case "tags": - return "tag" - case "attr", "attributes", "run": - return attribute - case "datasets": - return "dataset" - default: - return identifier - } -} - -func processOrderByClause(input string) (orderByExpr, error) { - parts := splitOrderByClauseWithQuotes(input) - - if len(parts) == 0 { - return orderByExpr{}, ErrInvalidOrderClauseInput - } - - var expr orderByExpr - - identifierKey := strings.Split(parts[0], ".") - - if len(identifierKey) == identifierAndKeyLength { - expr.identifier = utils.PtrTo(translateIdentifierAlias(identifierKey[0])) - expr.key = orderByKeyAlias(identifierKey[1]) - } else if len(identifierKey) == 1 { - expr.key = orderByKeyAlias(identifierKey[0]) - } - - if len(parts) > 1 { - expr.order = utils.PtrTo(strings.ToUpper(parts[1])) - } - - return expr, nil -} - -//nolint:funlen, cyclop, gocognit -func applyOrderBy(ctx context.Context, database, transaction *gorm.DB, orderBy []string) *contract.Error { - startTimeOrder := false - columnSelection := "runs.*" - - for index, orderByClause := range orderBy { - orderByExpr, err := processOrderByClause(orderByClause) - if err != nil { - 
return contract.NewError( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - fmt.Sprintf( - "invalid order_by clause %q.", - orderByClause, - ), - ) - } - - logger := utils.GetLoggerFromContext(ctx) - logger. - Debugf( - "OrderByExpr: identifier: %v, key: %v, order: %v", - utils.DumpStringPointer(orderByExpr.identifier), - orderByExpr.key, - utils.DumpStringPointer(orderByExpr.order), - ) - - var kind any - - if orderByExpr.identifier == nil && orderByExpr.key == "start_time" { - startTimeOrder = true - } else if orderByExpr.identifier != nil { - switch { - case *orderByExpr.identifier == attribute && orderByExpr.key == "start_time": - startTimeOrder = true - case *orderByExpr.identifier == metric: - kind = &models.LatestMetric{} - case *orderByExpr.identifier == "parameter": - kind = &models.Param{} - case *orderByExpr.identifier == "tag": - kind = &models.Tag{} - } - } - - table := fmt.Sprintf("order_%d", index) - - if kind != nil { - columnsInJoin := []string{"run_uuid", "value"} - if *orderByExpr.identifier == metric { - columnsInJoin = append(columnsInJoin, "is_nan") - } - - transaction.Joins( - fmt.Sprintf("LEFT OUTER JOIN (?) AS %s ON runs.run_uuid = %s.run_uuid", table, table), - database.Select(columnsInJoin).Where("key = ?", orderByExpr.key).Model(kind), - ) - - orderByExpr.key = table + ".value" - } - - desc := false - if orderByExpr.order != nil { - desc = *orderByExpr.order == "DESC" - } - - nullableColumnAlias := fmt.Sprintf("order_null_%d", index) - - if orderByExpr.identifier == nil || *orderByExpr.identifier != metric { - var originalColumn string - - switch { - case orderByExpr.identifier != nil && *orderByExpr.identifier == "attribute": - originalColumn = "runs." 
+ orderByExpr.key - case orderByExpr.identifier != nil: - originalColumn = table + ".value" - default: - originalColumn = orderByExpr.key - } - - columnSelection = fmt.Sprintf( - "%s, (CASE WHEN (%s IS NULL) THEN 1 ELSE 0 END) AS %s", - columnSelection, - originalColumn, - nullableColumnAlias, - ) - - transaction.Order(nullableColumnAlias) - } - - // the metric table has the is_nan column - if orderByExpr.identifier != nil && *orderByExpr.identifier == metric { - trueColumnValue := "true" - if database.Dialector.Name() == "sqlite" { - trueColumnValue = "1" - } - - columnSelection = fmt.Sprintf( - "%s, (CASE WHEN (%s.is_nan = %s) THEN 1 WHEN (%s.value IS NULL) THEN 2 ELSE 0 END) AS %s", - columnSelection, - table, - trueColumnValue, - table, - nullableColumnAlias, - ) - - transaction.Order(nullableColumnAlias) - } - - transaction.Order(clause.OrderByColumn{ - Column: clause.Column{ - Name: orderByExpr.key, - }, - Desc: desc, - }) - } - - if !startTimeOrder { - transaction.Order("runs.start_time DESC") - } - - transaction.Order("runs.run_uuid") - - // mlflow orders all nullable columns to have null last. - // For each order by clause, an additional dynamic order clause was added. - // We need to include these columns in the select clause. 
- transaction.Select(columnSelection) - - return nil -} - -func mkNextPageToken(runLength, maxResults, offset int) (string, *contract.Error) { - var nextPageToken string - - if runLength == maxResults { - var token strings.Builder - if err := json.NewEncoder( - base64.NewEncoder(base64.StdEncoding, &token), - ).Encode(PageToken{ - Offset: int32(offset + maxResults), - }); err != nil { - return "", contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - "error encoding 'nextPageToken' value", - err, - ) - } - - nextPageToken = token.String() - } - - return nextPageToken, nil -} - -//nolint:funlen -func (s TrackingSQLStore) SearchRuns( - ctx context.Context, - experimentIDs []string, filter string, - runViewType protos.ViewType, maxResults int, orderBy []string, pageToken string, -) ([]*entities.Run, string, *contract.Error) { - // ViewType - lifecyleStages := getLifecyleStages(runViewType) - transaction := s.db.WithContext(ctx).Where( - "runs.experiment_id IN ?", experimentIDs, - ).Where( - "runs.lifecycle_stage IN ?", lifecyleStages, - ) - - // MaxResults - transaction.Limit(maxResults) - - // PageToken - offset, contractError := getOffset(pageToken) - if contractError != nil { - return nil, "", contractError - } - - transaction.Offset(offset) - - // Filter - contractError = applyFilter(ctx, s.db, transaction, filter) - if contractError != nil { - return nil, "", contractError - } - - // OrderBy - contractError = applyOrderBy(ctx, s.db, transaction, orderBy) - if contractError != nil { - return nil, "", contractError - } - - // Actual query - var runs []models.Run - - transaction.Preload("LatestMetrics").Preload("Params").Preload("Tags"). - Preload("Inputs", "inputs.destination_type = 'RUN'"). 
- Preload("Inputs.Dataset").Preload("Inputs.Tags").Find(&runs) - - if transaction.Error != nil { - return nil, "", contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - "Failed to query search runs", - transaction.Error, - ) - } - - entityRuns := make([]*entities.Run, len(runs)) - for i, run := range runs { - entityRuns[i] = run.ToEntity() - } - - nextPageToken, contractError := mkNextPageToken(len(runs), maxResults, offset) - if contractError != nil { - return nil, "", contractError - } - - return entityRuns, nextPageToken, nil -} - -const RunIDMaxLength = 32 - -const ( - ArtifactFolderName = "artifacts" - RunNameIntegerScale = 3 - RunNameMaxLength = 20 -) - -func getRunNameFromTags(tags []models.Tag) string { - for _, tag := range tags { - if tag.Key == utils.TagRunName { - return tag.Value - } - } - - return "" -} - -func ensureRunName(runModel *models.Run) *contract.Error { - runNameFromTags := getRunNameFromTags(runModel.Tags) - - switch { - // run_name and name in tags differ - case utils.IsNotNilOrEmptyString(&runModel.Name) && runNameFromTags != "" && runModel.Name != runNameFromTags: - return contract.NewError( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - fmt.Sprintf( - "Both 'run_name' argument and 'mlflow.runName' tag are specified, but with "+ - "different values (run_name='%s', mlflow.runName='%s').", - runModel.Name, - runNameFromTags, - ), - ) - // no name was provided, generate a random name - case utils.IsNilOrEmptyString(&runModel.Name) && runNameFromTags == "": - randomName, err := utils.GenerateRandomName() - if err != nil { - return contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - "failed to generate random run name", - err, - ) - } - - runModel.Name = randomName - // use name from tags - case utils.IsNilOrEmptyString(&runModel.Name) && runNameFromTags != "": - runModel.Name = runNameFromTags - } - - if runNameFromTags == "" { - runModel.Tags = append(runModel.Tags, models.Tag{ - Key: utils.TagRunName, - Value: runModel.Name, 
- }) - } - - return nil -} - -func (s TrackingSQLStore) GetRun(ctx context.Context, runID string) (*entities.Run, *contract.Error) { - var run models.Run - if err := s.db.WithContext(ctx).Where( - "run_uuid = ?", runID, - ).Preload( - "Tags", - ).Preload( - "Params", - ).Preload( - "Inputs.Tags", - ).Preload( - "LatestMetrics", - ).Preload( - "Inputs.Dataset", - ).First(&run).Error; err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, contract.NewError( - protos.ErrorCode_RESOURCE_DOES_NOT_EXIST, - fmt.Sprintf("Run with id=%s not found", runID), - ) - } - - return nil, contract.NewErrorWith(protos.ErrorCode_INTERNAL_ERROR, "failed to get run", err) - } - - return run.ToEntity(), nil -} - -//nolint:funlen -func (s TrackingSQLStore) CreateRun( - ctx context.Context, - experimentID, userID string, - startTime int64, - tags []*entities.RunTag, - runName string, -) (*entities.Run, *contract.Error) { - experiment, err := s.GetExperiment(ctx, experimentID) - if err != nil { - return nil, err - } - - if models.LifecycleStage(experiment.LifecycleStage) != models.LifecycleStageActive { - return nil, contract.NewError( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - fmt.Sprintf( - "The experiment %q must be in the 'active' state.\n"+ - "Current state is %q.", - experiment.ExperimentID, - experiment.LifecycleStage, - ), - ) - } - - runModel := &models.Run{ - ID: utils.NewUUID(), - Name: runName, - ExperimentID: utils.ConvertStringPointerToInt32Pointer(&experimentID), - StartTime: startTime, - UserID: userID, - Tags: make([]models.Tag, 0, len(tags)), - LifecycleStage: models.LifecycleStageActive, - Status: models.RunStatusRunning, - SourceType: models.SourceTypeUnknown, - } - - for _, tag := range tags { - runModel.Tags = append(runModel.Tags, models.NewTagFromEntity(runModel.ID, tag)) - } - - artifactLocation, appendErr := utils.AppendToURIPath( - experiment.ArtifactLocation, - runModel.ID, - ArtifactFolderName, - ) - if appendErr != nil { - return nil, 
contract.NewError( - protos.ErrorCode_INTERNAL_ERROR, - "failed to append run ID to experiment artifact location", - ) - } - - runModel.ArtifactURI = artifactLocation - - errRunName := ensureRunName(runModel) - if errRunName != nil { - return nil, errRunName - } - - if err := s.db.Create(&runModel).Error; err != nil { - return nil, contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - fmt.Sprintf( - "failed to create run for experiment_id %q", - experiment.ExperimentID, - ), - err, - ) - } - - return runModel.ToEntity(), nil -} - -func (s TrackingSQLStore) UpdateRun( - ctx context.Context, - runID string, - runStatus string, - endTime *int64, - runName string, -) *contract.Error { - runTag, err := s.GetRunTag(ctx, runID, utils.TagRunName) - if err != nil { - return err - } - - tags := make([]models.Tag, 0, 1) - if runTag == nil { - tags = append(tags, models.Tag{ - RunID: runID, - Key: utils.TagRunName, - Value: runName, - }) - } - - var endTimeValue sql.NullInt64 - if endTime == nil { - endTimeValue = sql.NullInt64{} - } else { - endTimeValue = sql.NullInt64{Int64: *endTime, Valid: true} - } - - if err := s.db.WithContext(ctx).Transaction(func(transaction *gorm.DB) error { - if err := transaction.Model(&models.Run{}). - Where("run_uuid = ?", runID). 
- Updates(&models.Run{ - Name: runName, - Status: models.RunStatus(runStatus), - EndTime: endTimeValue, - }).Error; err != nil { - return err - } - - if len(tags) > 0 { - if err := transaction.Clauses(clause.OnConflict{ - UpdateAll: true, - }).CreateInBatches(tags, tagsBatchSize).Error; err != nil { - return fmt.Errorf("failed to create tags for run %q: %w", runID, err) - } - } - - return nil - }); err != nil { - return contract.NewErrorWith(protos.ErrorCode_INTERNAL_ERROR, "failed to update run", err) - } - - return nil -} - -func (s TrackingSQLStore) DeleteRun(ctx context.Context, runID string) *contract.Error { - run, err := s.GetRun(ctx, runID) - if err != nil { - return err - } - - if err := s.db.WithContext(ctx).Model(&models.Run{}). - Where("run_uuid = ?", run.Info.RunID). - Updates(&models.Run{ - DeletedTime: sql.NullInt64{Valid: true, Int64: time.Now().UnixMilli()}, - LifecycleStage: models.LifecycleStageDeleted, - }).Error; err != nil { - return contract.NewErrorWith(protos.ErrorCode_INTERNAL_ERROR, "failed to delete run", err) - } - - return nil -} - -func (s TrackingSQLStore) RestoreRun(ctx context.Context, runID string) *contract.Error { - run, err := s.GetRun(ctx, runID) - if err != nil { - return err - } - - if err := s.db.WithContext(ctx).Model(&models.Run{}). - Where("run_uuid = ?", run.Info.RunID). - // Force GORM to update fields with zero values by selecting them. - Select("DeletedTime", "LifecycleStage"). 
- Updates(&models.Run{ - DeletedTime: sql.NullInt64{}, - LifecycleStage: models.LifecycleStageActive, - }).Error; err != nil { - return contract.NewErrorWith(protos.ErrorCode_INTERNAL_ERROR, "failed to restore run", err) - } - - return nil -} - -func (s TrackingSQLStore) LogBatch( - ctx context.Context, runID string, metrics []*entities.Metric, params []*entities.Param, tags []*entities.RunTag, -) *contract.Error { - err := s.db.WithContext(ctx).Transaction(func(transaction *gorm.DB) error { - contractError := checkRunIsActive(transaction, runID) - if contractError != nil { - return contractError - } - - err := s.setTagsWithTransaction(transaction, runID, tags) - if err != nil { - return fmt.Errorf("error setting tags for run_id %q: %w", runID, err) - } - - contractError = s.logParamsWithTransaction(transaction, runID, params) - if contractError != nil { - return contractError - } - - contractError = s.logMetricsWithTransaction(transaction, runID, metrics) - if contractError != nil { - return contractError - } - - return nil - }) - if err != nil { - var contractError *contract.Error - if errors.As(err, &contractError) { - return contractError - } - - return contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - fmt.Sprintf("log batch transaction failed for %q", runID), - err, - ) - } - - return nil -} +package sql + +import ( + "context" + "database/sql" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "strings" + "time" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + + "github.com/mlflow/mlflow-go/pkg/contract" + "github.com/mlflow/mlflow-go/pkg/entities" + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/tracking/service/query" + "github.com/mlflow/mlflow-go/pkg/tracking/service/query/parser" + "github.com/mlflow/mlflow-go/pkg/tracking/store/sql/models" + "github.com/mlflow/mlflow-go/pkg/utils" +) + +type PageToken struct { + Offset int32 `json:"offset"` +} + +func checkRunIsActive(transaction *gorm.DB, runID string) 
*contract.Error { + var run models.Run + + err := transaction. + Model(&models.Run{}). + Where("run_uuid = ?", runID). + Select("lifecycle_stage"). + First(&run). + Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return contract.NewError( + protos.ErrorCode_RESOURCE_DOES_NOT_EXIST, + fmt.Sprintf("Run with id=%s not found", runID), + ) + } + + return contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf( + "failed to get lifecycle stage for run %q", + runID, + ), + err, + ) + } + + if run.LifecycleStage != models.LifecycleStageActive { + return contract.NewError( + protos.ErrorCode_INVALID_PARAMETER_VALUE, + fmt.Sprintf( + "The run %s must be in the 'active' state.\n"+ + "Current state is %v.", + runID, + run.LifecycleStage, + ), + ) + } + + return nil +} + +func getLifecyleStages(runViewType protos.ViewType) []models.LifecycleStage { + switch runViewType { + case protos.ViewType_ACTIVE_ONLY: + return []models.LifecycleStage{ + models.LifecycleStageActive, + } + case protos.ViewType_DELETED_ONLY: + return []models.LifecycleStage{ + models.LifecycleStageDeleted, + } + case protos.ViewType_ALL: + return []models.LifecycleStage{ + models.LifecycleStageActive, + models.LifecycleStageDeleted, + } + } + + return []models.LifecycleStage{ + models.LifecycleStageActive, + models.LifecycleStageDeleted, + } +} + +func getOffset(pageToken string) (int, *contract.Error) { + if pageToken != "" { + var token PageToken + if err := json.NewDecoder( + base64.NewDecoder( + base64.StdEncoding, + strings.NewReader(pageToken), + ), + ).Decode(&token); err != nil { + return 0, contract.NewErrorWith( + protos.ErrorCode_INVALID_PARAMETER_VALUE, + fmt.Sprintf("invalid page_token: %q", pageToken), + err, + ) + } + + return int(token.Offset), nil + } + + return 0, nil +} + +//nolint:funlen,cyclop,gocognit +func applyFilter(ctx context.Context, database, transaction *gorm.DB, filter string) *contract.Error { + filterConditions, err := 
query.ParseFilter(filter) + if err != nil { + return contract.NewErrorWith( + protos.ErrorCode_INVALID_PARAMETER_VALUE, + "error parsing search filter", + err, + ) + } + + utils.GetLoggerFromContext(ctx).Debugf("Filter conditions: %v", filterConditions) + + for index, clause := range filterConditions { + var kind any + + key := clause.Key + comparison := strings.ToUpper(clause.Operator.String()) + value := clause.Value + + switch clause.Identifier { + case parser.Metric: + kind = &models.LatestMetric{} + case parser.Parameter: + kind = &models.Param{} + case parser.Tag: + kind = &models.Tag{} + case parser.Dataset: + kind = &models.Dataset{} + case parser.Attribute: + kind = nil + } + + // Treat "attributes.run_name == " as "tags.`mlflow.runName` == ". + // The name column in the runs table is empty for runs logged in MLflow <= 1.29.0. + if key == "run_name" { + kind = &models.Tag{} + key = utils.TagRunName + } + + isSqliteAndILike := database.Dialector.Name() == "sqlite" && comparison == "ILIKE" + table := fmt.Sprintf("filter_%d", index) + + switch { + case kind == nil: + if isSqliteAndILike { + key = fmt.Sprintf("LOWER(runs.%s)", key) + comparison = "LIKE" + + if str, ok := value.(string); ok { + value = strings.ToLower(str) + } + + transaction.Where(fmt.Sprintf("%s %s ?", key, comparison), value) + } else { + transaction.Where(fmt.Sprintf("runs.%s %s ?", key, comparison), value) + } + case clause.Identifier == parser.Dataset && key == "context": + // SELECT * + // FROM runs + // JOIN ( + // SELECT inputs.destination_id AS run_uuid + // FROM inputs + // JOIN input_tags + // ON inputs.input_uuid = input_tags.input_uuid + // AND input_tags.name = 'mlflow.data.context' + // AND input_tags.value %s ? 
+ // WHERE inputs.destination_type = 'RUN'
+ // ) AS filter_0
+ // ON runs.run_uuid = filter_0.run_uuid
+ valueColumn := "input_tags.value "
+ if isSqliteAndILike {
+ valueColumn = "LOWER(input_tags.value) "
+ comparison = "LIKE"
+
+ if str, ok := value.(string); ok {
+ value = strings.ToLower(str)
+ }
+ }
+
+ transaction.Joins(
+ fmt.Sprintf("JOIN (?) AS %s ON runs.run_uuid = %s.run_uuid", table, table),
+ database.Select("inputs.destination_id AS run_uuid").
+ Joins(
+ "JOIN input_tags ON inputs.input_uuid = input_tags.input_uuid"+
+ " AND input_tags.name = 'mlflow.data.context'"+
+ " AND "+valueColumn+comparison+" ?",
+ value,
+ ).
+ Where("inputs.destination_type = 'RUN'").
+ Model(&models.Input{}),
+ )
+ case clause.Identifier == parser.Dataset:
+ // add join with datasets
+ // JOIN (
+ // SELECT destination_id, key
+ // FROM datasets
+ // JOIN inputs ON inputs.source_id = datasets.dataset_uuid
+ // WHERE key comparison value
+ // ) AS filter_0 ON runs.run_uuid = filter_0.destination_id
+ //
+ // columns: name, digest, context
+ where := key + " " + comparison + " ?"
+ if isSqliteAndILike {
+ where = "LOWER(" + key + ") LIKE ?"
+
+ if str, ok := value.(string); ok {
+ value = strings.ToLower(str)
+ }
+ }
+
+ transaction.Joins(
+ fmt.Sprintf("JOIN (?) AS %s ON runs.run_uuid = %s.destination_id", table, table),
+ database.Model(kind).
+ Joins("JOIN inputs ON inputs.source_id = datasets.dataset_uuid").
+ Where(where, value).
+ Select("destination_id", key),
+ )
+ default:
+ where := fmt.Sprintf("value %s ?", comparison)
+ if isSqliteAndILike {
+ where = "LOWER(value) LIKE ?"
+
+ if str, ok := value.(string); ok {
+ value = strings.ToLower(str)
+ }
+ }
+
+ transaction.Joins(
+ fmt.Sprintf("JOIN (?) 
AS %s ON runs.run_uuid = %s.run_uuid", table, table), + database.Select("run_uuid", "value").Where("key = ?", key).Where(where, value).Model(kind), + ) + } + } + + return nil +} + +type orderByExpr struct { + identifier *string + key string + order *string +} + +var ErrInvalidOrderClauseInput = errors.New("input string is empty or only contains quote characters") + +const ( + identifierAndKeyLength = 2 + startTime = "start_time" + name = "name" + attribute = "attribute" + metric = "metric" +) + +func orderByKeyAlias(input string) string { + switch input { + case "created", "Created": + return startTime + case "run_name", "run name", "Run name", "Run Name": + return name + case "run_id": + return "run_uuid" + default: + return input + } +} + +func handleInsideQuote( + char, quoteChar rune, insideQuote bool, current strings.Builder, result []string, +) (bool, strings.Builder, []string) { + if char == quoteChar { + insideQuote = false + + result = append(result, current.String()) + current.Reset() + } else { + current.WriteRune(char) + } + + return insideQuote, current, result +} + +func handleOutsideQuote( + char rune, insideQuote bool, quoteChar rune, current strings.Builder, result []string, +) (bool, rune, strings.Builder, []string) { + switch char { + case ' ': + if current.Len() > 0 { + result = append(result, current.String()) + current.Reset() + } + case '"', '\'', '`': + insideQuote = true + quoteChar = char + default: + current.WriteRune(char) + } + + return insideQuote, quoteChar, current, result +} + +// Process an order by input string to split the string into the separate parts. +// We can't simply split by space, because the column name could be wrapped in quotes, e.g. "Run name" ASC. 
+func splitOrderByClauseWithQuotes(input string) []string { + input = strings.ToLower(strings.Trim(input, " ")) + + var result []string + + var current strings.Builder + + var insideQuote bool + + var quoteChar rune + + // Process char per char, split items on spaces unless inside a quoted entry. + for _, char := range input { + if insideQuote { + insideQuote, current, result = handleInsideQuote(char, quoteChar, insideQuote, current, result) + } else { + insideQuote, quoteChar, current, result = handleOutsideQuote(char, insideQuote, quoteChar, current, result) + } + } + + if current.Len() > 0 { + result = append(result, current.String()) + } + + return result +} + +func translateIdentifierAlias(identifier string) string { + switch strings.ToLower(identifier) { + case "metrics": + return metric + case "parameters", "param", "params": + return "parameter" + case "tags": + return "tag" + case "attr", "attributes", "run": + return attribute + case "datasets": + return "dataset" + default: + return identifier + } +} + +func processOrderByClause(input string) (orderByExpr, error) { + parts := splitOrderByClauseWithQuotes(input) + + if len(parts) == 0 { + return orderByExpr{}, ErrInvalidOrderClauseInput + } + + var expr orderByExpr + + identifierKey := strings.Split(parts[0], ".") + + if len(identifierKey) == identifierAndKeyLength { + expr.identifier = utils.PtrTo(translateIdentifierAlias(identifierKey[0])) + expr.key = orderByKeyAlias(identifierKey[1]) + } else if len(identifierKey) == 1 { + expr.key = orderByKeyAlias(identifierKey[0]) + } + + if len(parts) > 1 { + expr.order = utils.PtrTo(strings.ToUpper(parts[1])) + } + + return expr, nil +} + +//nolint:funlen, cyclop, gocognit +func applyOrderBy(ctx context.Context, database, transaction *gorm.DB, orderBy []string) *contract.Error { + startTimeOrder := false + columnSelection := "runs.*" + + for index, orderByClause := range orderBy { + orderByExpr, err := processOrderByClause(orderByClause) + if err != nil { + 
return contract.NewError( + protos.ErrorCode_INVALID_PARAMETER_VALUE, + fmt.Sprintf( + "invalid order_by clause %q.", + orderByClause, + ), + ) + } + + logger := utils.GetLoggerFromContext(ctx) + logger. + Debugf( + "OrderByExpr: identifier: %v, key: %v, order: %v", + utils.DumpStringPointer(orderByExpr.identifier), + orderByExpr.key, + utils.DumpStringPointer(orderByExpr.order), + ) + + var kind any + + if orderByExpr.identifier == nil && orderByExpr.key == "start_time" { + startTimeOrder = true + } else if orderByExpr.identifier != nil { + switch { + case *orderByExpr.identifier == attribute && orderByExpr.key == "start_time": + startTimeOrder = true + case *orderByExpr.identifier == metric: + kind = &models.LatestMetric{} + case *orderByExpr.identifier == "parameter": + kind = &models.Param{} + case *orderByExpr.identifier == "tag": + kind = &models.Tag{} + } + } + + table := fmt.Sprintf("order_%d", index) + + if kind != nil { + columnsInJoin := []string{"run_uuid", "value"} + if *orderByExpr.identifier == metric { + columnsInJoin = append(columnsInJoin, "is_nan") + } + + transaction.Joins( + fmt.Sprintf("LEFT OUTER JOIN (?) AS %s ON runs.run_uuid = %s.run_uuid", table, table), + database.Select(columnsInJoin).Where("key = ?", orderByExpr.key).Model(kind), + ) + + orderByExpr.key = table + ".value" + } + + desc := false + if orderByExpr.order != nil { + desc = *orderByExpr.order == "DESC" + } + + nullableColumnAlias := fmt.Sprintf("order_null_%d", index) + + if orderByExpr.identifier == nil || *orderByExpr.identifier != metric { + var originalColumn string + + switch { + case orderByExpr.identifier != nil && *orderByExpr.identifier == "attribute": + originalColumn = "runs." 
+ orderByExpr.key + case orderByExpr.identifier != nil: + originalColumn = table + ".value" + default: + originalColumn = orderByExpr.key + } + + columnSelection = fmt.Sprintf( + "%s, (CASE WHEN (%s IS NULL) THEN 1 ELSE 0 END) AS %s", + columnSelection, + originalColumn, + nullableColumnAlias, + ) + + transaction.Order(nullableColumnAlias) + } + + // the metric table has the is_nan column + if orderByExpr.identifier != nil && *orderByExpr.identifier == metric { + trueColumnValue := "true" + if database.Dialector.Name() == "sqlite" { + trueColumnValue = "1" + } + + columnSelection = fmt.Sprintf( + "%s, (CASE WHEN (%s.is_nan = %s) THEN 1 WHEN (%s.value IS NULL) THEN 2 ELSE 0 END) AS %s", + columnSelection, + table, + trueColumnValue, + table, + nullableColumnAlias, + ) + + transaction.Order(nullableColumnAlias) + } + + transaction.Order(clause.OrderByColumn{ + Column: clause.Column{ + Name: orderByExpr.key, + }, + Desc: desc, + }) + } + + if !startTimeOrder { + transaction.Order("runs.start_time DESC") + } + + transaction.Order("runs.run_uuid") + + // mlflow orders all nullable columns to have null last. + // For each order by clause, an additional dynamic order clause was added. + // We need to include these columns in the select clause. 
+ transaction.Select(columnSelection) + + return nil +} + +func mkNextPageToken(runLength, maxResults, offset int) (string, *contract.Error) { + var nextPageToken string + + if runLength == maxResults { + var token strings.Builder + if err := json.NewEncoder( + base64.NewEncoder(base64.StdEncoding, &token), + ).Encode(PageToken{ + Offset: int32(offset + maxResults), + }); err != nil { + return "", contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + "error encoding 'nextPageToken' value", + err, + ) + } + + nextPageToken = token.String() + } + + return nextPageToken, nil +} + +//nolint:funlen +func (s TrackingSQLStore) SearchRuns( + ctx context.Context, + experimentIDs []string, filter string, + runViewType protos.ViewType, maxResults int, orderBy []string, pageToken string, +) ([]*entities.Run, string, *contract.Error) { + // ViewType + lifecyleStages := getLifecyleStages(runViewType) + transaction := s.db.WithContext(ctx).Where( + "runs.experiment_id IN ?", experimentIDs, + ).Where( + "runs.lifecycle_stage IN ?", lifecyleStages, + ) + + // MaxResults + transaction.Limit(maxResults) + + // PageToken + offset, contractError := getOffset(pageToken) + if contractError != nil { + return nil, "", contractError + } + + transaction.Offset(offset) + + // Filter + contractError = applyFilter(ctx, s.db, transaction, filter) + if contractError != nil { + return nil, "", contractError + } + + // OrderBy + contractError = applyOrderBy(ctx, s.db, transaction, orderBy) + if contractError != nil { + return nil, "", contractError + } + + // Actual query + var runs []models.Run + + transaction.Preload("LatestMetrics").Preload("Params").Preload("Tags"). + Preload("Inputs", "inputs.destination_type = 'RUN'"). 
+ Preload("Inputs.Dataset").Preload("Inputs.Tags").Find(&runs) + + if transaction.Error != nil { + return nil, "", contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + "Failed to query search runs", + transaction.Error, + ) + } + + entityRuns := make([]*entities.Run, len(runs)) + for i, run := range runs { + entityRuns[i] = run.ToEntity() + } + + nextPageToken, contractError := mkNextPageToken(len(runs), maxResults, offset) + if contractError != nil { + return nil, "", contractError + } + + return entityRuns, nextPageToken, nil +} + +const RunIDMaxLength = 32 + +const ( + ArtifactFolderName = "artifacts" + RunNameIntegerScale = 3 + RunNameMaxLength = 20 +) + +func getRunNameFromTags(tags []models.Tag) string { + for _, tag := range tags { + if tag.Key == utils.TagRunName { + return tag.Value + } + } + + return "" +} + +func ensureRunName(runModel *models.Run) *contract.Error { + runNameFromTags := getRunNameFromTags(runModel.Tags) + + switch { + // run_name and name in tags differ + case utils.IsNotNilOrEmptyString(&runModel.Name) && runNameFromTags != "" && runModel.Name != runNameFromTags: + return contract.NewError( + protos.ErrorCode_INVALID_PARAMETER_VALUE, + fmt.Sprintf( + "Both 'run_name' argument and 'mlflow.runName' tag are specified, but with "+ + "different values (run_name='%s', mlflow.runName='%s').", + runModel.Name, + runNameFromTags, + ), + ) + // no name was provided, generate a random name + case utils.IsNilOrEmptyString(&runModel.Name) && runNameFromTags == "": + randomName, err := utils.GenerateRandomName() + if err != nil { + return contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + "failed to generate random run name", + err, + ) + } + + runModel.Name = randomName + // use name from tags + case utils.IsNilOrEmptyString(&runModel.Name) && runNameFromTags != "": + runModel.Name = runNameFromTags + } + + if runNameFromTags == "" { + runModel.Tags = append(runModel.Tags, models.Tag{ + Key: utils.TagRunName, + Value: runModel.Name, 
+ }) + } + + return nil +} + +func (s TrackingSQLStore) GetRun(ctx context.Context, runID string) (*entities.Run, *contract.Error) { + var run models.Run + if err := s.db.WithContext(ctx).Where( + "run_uuid = ?", runID, + ).Preload( + "Tags", + ).Preload( + "Params", + ).Preload( + "Inputs.Tags", + ).Preload( + "LatestMetrics", + ).Preload( + "Inputs.Dataset", + ).First(&run).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, contract.NewError( + protos.ErrorCode_RESOURCE_DOES_NOT_EXIST, + fmt.Sprintf("Run with id=%s not found", runID), + ) + } + + return nil, contract.NewErrorWith(protos.ErrorCode_INTERNAL_ERROR, "failed to get run", err) + } + + return run.ToEntity(), nil +} + +//nolint:funlen +func (s TrackingSQLStore) CreateRun( + ctx context.Context, + experimentID, userID string, + startTime int64, + tags []*entities.RunTag, + runName string, +) (*entities.Run, *contract.Error) { + experiment, err := s.GetExperiment(ctx, experimentID) + if err != nil { + return nil, err + } + + if models.LifecycleStage(experiment.LifecycleStage) != models.LifecycleStageActive { + return nil, contract.NewError( + protos.ErrorCode_INVALID_PARAMETER_VALUE, + fmt.Sprintf( + "The experiment %q must be in the 'active' state.\n"+ + "Current state is %q.", + experiment.ExperimentID, + experiment.LifecycleStage, + ), + ) + } + + runModel := &models.Run{ + ID: utils.NewUUID(), + Name: runName, + ExperimentID: utils.ConvertStringPointerToInt32Pointer(&experimentID), + StartTime: startTime, + UserID: userID, + Tags: make([]models.Tag, 0, len(tags)), + LifecycleStage: models.LifecycleStageActive, + Status: models.RunStatusRunning, + SourceType: models.SourceTypeUnknown, + } + + for _, tag := range tags { + runModel.Tags = append(runModel.Tags, models.NewTagFromEntity(runModel.ID, tag)) + } + + artifactLocation, appendErr := utils.AppendToURIPath( + experiment.ArtifactLocation, + runModel.ID, + ArtifactFolderName, + ) + if appendErr != nil { + return nil, 
contract.NewError( + protos.ErrorCode_INTERNAL_ERROR, + "failed to append run ID to experiment artifact location", + ) + } + + runModel.ArtifactURI = artifactLocation + + errRunName := ensureRunName(runModel) + if errRunName != nil { + return nil, errRunName + } + + if err := s.db.Create(&runModel).Error; err != nil { + return nil, contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf( + "failed to create run for experiment_id %q", + experiment.ExperimentID, + ), + err, + ) + } + + return runModel.ToEntity(), nil +} + +func (s TrackingSQLStore) UpdateRun( + ctx context.Context, + runID string, + runStatus string, + endTime *int64, + runName string, +) *contract.Error { + runTag, err := s.GetRunTag(ctx, runID, utils.TagRunName) + if err != nil { + return err + } + + tags := make([]models.Tag, 0, 1) + if runTag == nil { + tags = append(tags, models.Tag{ + RunID: runID, + Key: utils.TagRunName, + Value: runName, + }) + } + + var endTimeValue sql.NullInt64 + if endTime == nil { + endTimeValue = sql.NullInt64{} + } else { + endTimeValue = sql.NullInt64{Int64: *endTime, Valid: true} + } + + if err := s.db.WithContext(ctx).Transaction(func(transaction *gorm.DB) error { + if err := transaction.Model(&models.Run{}). + Where("run_uuid = ?", runID). 
+ Updates(&models.Run{ + Name: runName, + Status: models.RunStatus(runStatus), + EndTime: endTimeValue, + }).Error; err != nil { + return err + } + + if len(tags) > 0 { + if err := transaction.Clauses(clause.OnConflict{ + UpdateAll: true, + }).CreateInBatches(tags, tagsBatchSize).Error; err != nil { + return fmt.Errorf("failed to create tags for run %q: %w", runID, err) + } + } + + return nil + }); err != nil { + return contract.NewErrorWith(protos.ErrorCode_INTERNAL_ERROR, "failed to update run", err) + } + + return nil +} + +func (s TrackingSQLStore) DeleteRun(ctx context.Context, runID string) *contract.Error { + run, err := s.GetRun(ctx, runID) + if err != nil { + return err + } + + if err := s.db.WithContext(ctx).Model(&models.Run{}). + Where("run_uuid = ?", run.Info.RunID). + Updates(&models.Run{ + DeletedTime: sql.NullInt64{Valid: true, Int64: time.Now().UnixMilli()}, + LifecycleStage: models.LifecycleStageDeleted, + }).Error; err != nil { + return contract.NewErrorWith(protos.ErrorCode_INTERNAL_ERROR, "failed to delete run", err) + } + + return nil +} + +func (s TrackingSQLStore) RestoreRun(ctx context.Context, runID string) *contract.Error { + run, err := s.GetRun(ctx, runID) + if err != nil { + return err + } + + if err := s.db.WithContext(ctx).Model(&models.Run{}). + Where("run_uuid = ?", run.Info.RunID). + // Force GORM to update fields with zero values by selecting them. + Select("DeletedTime", "LifecycleStage"). 
+ Updates(&models.Run{ + DeletedTime: sql.NullInt64{}, + LifecycleStage: models.LifecycleStageActive, + }).Error; err != nil { + return contract.NewErrorWith(protos.ErrorCode_INTERNAL_ERROR, "failed to restore run", err) + } + + return nil +} + +func (s TrackingSQLStore) LogBatch( + ctx context.Context, runID string, metrics []*entities.Metric, params []*entities.Param, tags []*entities.RunTag, +) *contract.Error { + err := s.db.WithContext(ctx).Transaction(func(transaction *gorm.DB) error { + contractError := checkRunIsActive(transaction, runID) + if contractError != nil { + return contractError + } + + err := s.setTagsWithTransaction(transaction, runID, tags) + if err != nil { + return fmt.Errorf("error setting tags for run_id %q: %w", runID, err) + } + + contractError = s.logParamsWithTransaction(transaction, runID, params) + if contractError != nil { + return contractError + } + + contractError = s.logMetricsWithTransaction(transaction, runID, metrics) + if contractError != nil { + return contractError + } + + return nil + }) + if err != nil { + var contractError *contract.Error + if errors.As(err, &contractError) { + return contractError + } + + return contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf("log batch transaction failed for %q", runID), + err, + ) + } + + return nil +} diff --git a/pkg/tracking/store/sql/runs_internal_test.go b/pkg/tracking/store/sql/runs_internal_test.go index f42583d..94cb3d4 100644 --- a/pkg/tracking/store/sql/runs_internal_test.go +++ b/pkg/tracking/store/sql/runs_internal_test.go @@ -1,518 +1,518 @@ -//nolint:ireturn -package sql - -import ( - "context" - "reflect" - "regexp" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - "github.com/iancoleman/strcase" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "gorm.io/driver/mysql" - "gorm.io/driver/postgres" - "gorm.io/driver/sqlite" - "gorm.io/driver/sqlserver" - "gorm.io/gorm" - - 
"github.com/mlflow/mlflow-go/pkg/tracking/store/sql/models" - "github.com/mlflow/mlflow-go/pkg/utils" -) - -type testData struct { - name string - query string - orderBy []string - expectedSQL map[string]string - expectedVars []any -} - -var whitespaceRegex = regexp.MustCompile(`\s` + "|`") - -func removeWhitespace(s string) string { - return whitespaceRegex.ReplaceAllString(s, "") -} - -var tests = []testData{ - { - name: "SimpleMetricQuery", - query: "metrics.accuracy > 0.72", - expectedSQL: map[string]string{ - "postgres": ` - SELECT "run_uuid" FROM "runs" - JOIN (SELECT "run_uuid","value" FROM "latest_metrics" WHERE key = $1 AND value > $2) - AS filter_0 - ON runs.run_uuid = filter_0.run_uuid - ORDER BY runs.start_time DESC,runs.run_uuid`, - "sqlite": ` - SELECT run_uuid FROM runs - JOIN (SELECT run_uuid,value FROM latest_metrics WHERE key = ? AND value > ?) - AS filter_0 ON runs.run_uuid = filter_0.run_uuid - ORDER BY runs.start_time DESC,runs.run_uuid`, - "sqlserver": ` - SELECT "run_uuid" FROM "runs" - JOIN (SELECT "run_uuid","value" FROM "latest_metrics" WHERE key = @p1 AND value > @p2) - AS filter_0 ON runs.run_uuid = filter_0.run_uuid - ORDER BY runs.start_time DESC,runs.run_uuid`, - "mysql": ` - SELECT run_uuid FROM runs - JOIN (SELECT run_uuid,value FROM latest_metrics WHERE key = ? AND value > ?) 
- AS filter_0 ON runs.run_uuid = filter_0.run_uuid - ORDER BY runs.start_time DESC,runs.run_uuid`, - }, - expectedVars: []any{"accuracy", 0.72}, - }, - { - name: "SimpleMetricAndParamQuery", - query: "metrics.accuracy > 0.72 AND params.batch_size = '2'", - expectedSQL: map[string]string{ - "postgres": ` - SELECT "run_uuid" FROM "runs" - JOIN (SELECT "run_uuid","value" FROM "latest_metrics" WHERE key = $1 AND value > $2) - AS filter_0 ON runs.run_uuid = filter_0.run_uuid - JOIN (SELECT "run_uuid","value" FROM "params" WHERE key = $3 AND value = $4) - AS filter_1 ON runs.run_uuid = filter_1.run_uuid - ORDER BY runs.start_time DESC,runs.run_uuid`, - "sqlite": ` - SELECT run_uuid FROM runs - JOIN (SELECT run_uuid,value FROM latest_metrics WHERE key = ? AND value > ?) - AS filter_0 ON runs.run_uuid = filter_0.run_uuid - JOIN (SELECT run_uuid,value FROM params WHERE key = ? AND value = ?) - AS filter_1 ON runs.run_uuid = filter_1.run_uuid - ORDER BY runs.start_time DESC,runs.run_uuid`, - }, - expectedVars: []any{"accuracy", 0.72, "batch_size", "2"}, - }, - { - name: "TagQuery", - query: "tags.environment = 'notebook' AND tags.task ILIKE 'classif%'", - expectedSQL: map[string]string{ - "postgres": ` - SELECT "run_uuid" FROM "runs" - JOIN (SELECT "run_uuid","value" FROM "tags" WHERE key = $1 AND value = $2) - AS filter_0 ON runs.run_uuid = filter_0.run_uuid - JOIN (SELECT "run_uuid","value" FROM "tags" WHERE key = $3 AND value ILIKE $4) - AS filter_1 ON runs.run_uuid = filter_1.run_uuid - ORDER BY runs.start_time DESC,runs.run_uuid`, - "sqlite": ` - SELECT run_uuid FROM runs - JOIN (SELECT run_uuid,value FROM tags WHERE key = ? AND value = ?) - AS filter_0 ON runs.run_uuid = filter_0.run_uuid - JOIN (SELECT run_uuid,value FROM tags WHERE key = ? AND LOWER(value) LIKE ?) 
- AS filter_1 ON runs.run_uuid = filter_1.run_uuid - ORDER BY runs.start_time DESC,runs.run_uuid`, - }, - expectedVars: []any{"environment", "notebook", "task", "classif%"}, - }, - { - name: "DatasestsInQuery", - query: "datasets.digest IN ('s8ds293b', 'jks834s2')", - expectedSQL: map[string]string{ - "postgres": ` - SELECT "run_uuid" FROM "runs" - JOIN ( - SELECT destination_id,"digest" - FROM "datasets" JOIN inputs ON inputs.source_id = datasets.dataset_uuid - WHERE digest IN ($1,$2) - ) - AS filter_0 ON runs.run_uuid = filter_0.destination_id - ORDER BY runs.start_time DESC,runs.run_uuid`, - "sqlite": ` - SELECT run_uuid FROM runs - JOIN ( - SELECT destination_id,digest - FROM datasets JOIN inputs - ON inputs.source_id = datasets.dataset_uuid - WHERE digest IN (?,?) - ) - AS filter_0 - ON runs.run_uuid = filter_0.destination_id - ORDER BY runs.start_time DESC,runs.run_uuid`, - }, - expectedVars: []any{"s8ds293b", "jks834s2"}, - }, - { - name: "AttributesQuery", - query: "attributes.run_id = 'a1b2c3d4'", - expectedSQL: map[string]string{ - "postgres": ` - SELECT "run_uuid" FROM "runs" - WHERE runs.run_uuid = $1 - ORDER BY runs.start_time DESC,runs.run_uuid - `, - "sqlite": `SELECT run_uuid FROM runs WHERE runs.run_uuid = ? ORDER BY runs.start_time DESC,runs.run_uuid`, - }, - expectedVars: []any{"a1b2c3d4"}, - }, - { - name: "Run_nameQuery", - query: "attributes.run_name = 'my-run'", - expectedSQL: map[string]string{ - "postgres": ` - SELECT "run_uuid" FROM "runs" - JOIN (SELECT "run_uuid","value" FROM "tags" WHERE key = $1 AND value = $2) - AS filter_0 ON runs.run_uuid = filter_0.run_uuid - ORDER BY runs.start_time DESC,runs.run_uuid`, - "sqlite": ` - SELECT run_uuid FROM runs - JOIN (SELECT run_uuid,value FROM tags WHERE key = ? AND value = ?) 
- AS filter_0 ON runs.run_uuid = filter_0.run_uuid - ORDER BY runs.start_time DESC,runs.run_uuid`, - }, - expectedVars: []any{"mlflow.runName", "my-run"}, - }, - { - name: "DatasetsContextQuery", - query: "datasets.context = 'train'", - expectedSQL: map[string]string{ - "postgres": ` - SELECT "run_uuid" FROM "runs" - JOIN ( - SELECT inputs.destination_id AS run_uuid - FROM "inputs" - JOIN input_tags - ON inputs.input_uuid = input_tags.input_uuid - AND input_tags.name = 'mlflow.data.context' - AND input_tags.value = $1 - WHERE inputs.destination_type = 'RUN' - ) AS filter_0 ON runs.run_uuid = filter_0.run_uuid - ORDER BY runs.start_time DESC,runs.run_uuid`, - "sqlite": ` - SELECT run_uuid FROM runs - JOIN ( - SELECT inputs.destination_id AS run_uuid - FROM inputs - JOIN input_tags ON inputs.input_uuid = input_tags.input_uuid - AND input_tags.name = 'mlflow.data.context' - AND input_tags.value = ? WHERE inputs.destination_type = 'RUN' - ) AS filter_0 ON runs.run_uuid = filter_0.run_uuid - ORDER BY runs.start_time DESC,runs.run_uuid`, - }, - expectedVars: []any{"train"}, - }, - { - name: "Run_nameQuery", - query: "attributes.run_name ILIKE 'my-run%'", - expectedSQL: map[string]string{ - "postgres": ` - SELECT "run_uuid" FROM "runs" - JOIN (SELECT "run_uuid","value" FROM "tags" WHERE key = $1 AND value ILIKE $2) - AS filter_0 ON runs.run_uuid = filter_0.run_uuid - ORDER BY runs.start_time DESC,runs.run_uuid`, - "sqlite": ` - SELECT run_uuid FROM runs - JOIN (SELECT run_uuid, value FROM tags WHERE key = ? AND LOWER(value) LIKE ?) 
- AS filter_0 ON runs.run_uuid = filter_0.run_uuid - ORDER BY runs.start_time DESC,runs.run_uuid`, - }, - expectedVars: []any{"mlflow.runName", "my-run%"}, - }, - { - name: "DatasetsContextQuery", - query: "datasets.context ILIKE '%train'", - expectedSQL: map[string]string{ - "postgres": ` - SELECT "run_uuid" FROM "runs" - JOIN ( - SELECT inputs.destination_id AS run_uuid FROM "inputs" - JOIN input_tags ON inputs.input_uuid = input_tags.input_uuid - AND input_tags.name = 'mlflow.data.context' - AND input_tags.value ILIKE $1 WHERE inputs.destination_type = 'RUN' - ) AS filter_0 ON runs.run_uuid = filter_0.run_uuid - ORDER BY runs.start_time DESC,runs.run_uuid`, - "sqlite": ` - SELECT run_uuid FROM runs - JOIN ( - SELECT inputs.destination_id AS run_uuid FROM inputs - JOIN input_tags ON inputs.input_uuid = input_tags.input_uuid - AND input_tags.name = 'mlflow.data.context' - AND LOWER(input_tags.value) LIKE ? WHERE inputs.destination_type = 'RUN') - AS filter_0 ON runs.run_uuid = filter_0.run_uuid - ORDER BY runs.start_time DESC,runs.run_uuid - `, - }, - expectedVars: []any{"%train"}, - }, - { - name: "DatasestsDigest", - query: "datasets.digest ILIKE '%s'", - expectedSQL: map[string]string{ - "postgres": ` - SELECT "run_uuid" FROM "runs" - JOIN ( - SELECT destination_id,"digest" - FROM "datasets" - JOIN inputs ON inputs.source_id = datasets.dataset_uuid - WHERE digest ILIKE $1 - ) - AS filter_0 ON runs.run_uuid = filter_0.destination_id - ORDER BY runs.start_time DESC,runs.run_uuid`, - "sqlite": ` - SELECT run_uuid FROM runs - JOIN ( - SELECT destination_id,digest - FROM datasets - JOIN inputs ON inputs.source_id = datasets.dataset_uuid - WHERE LOWER(digest) LIKE ?) 
- AS filter_0 ON runs.run_uuid = filter_0.destination_id - ORDER BY runs.start_time DESC,runs.run_uuid`, - }, - expectedVars: []any{"%s"}, - }, - { - name: "ParamQuery", - query: "metrics.accuracy > 0.72 AND params.batch_size ILIKE '%a'", - expectedSQL: map[string]string{ - "postgres": ` - SELECT "run_uuid" FROM "runs" - JOIN (SELECT "run_uuid","value" FROM "latest_metrics" WHERE key = $1 AND value > $2) - AS filter_0 ON runs.run_uuid = filter_0.run_uuid - JOIN (SELECT "run_uuid","value" FROM "params" WHERE key = $3 AND value ILIKE $4) - AS filter_1 ON runs.run_uuid = filter_1.run_uuid - ORDER BY runs.start_time DESC,runs.run_uuid`, - "sqlite": ` - SELECT run_uuid FROM runs - JOIN (SELECT run_uuid, value FROM latest_metrics WHERE key = ? AND value > ?) - AS filter_0 ON runs.run_uuid = filter_0.run_uuid - JOIN (SELECT run_uuid,value FROM params WHERE key = ? AND LOWER(value) LIKE ?) - AS filter_1 ON runs.run_uuid = filter_1.run_uuid - ORDER BY runs.start_time DESC,runs.run_uuid - `, - }, - expectedVars: []any{"accuracy", 0.72, "batch_size", "%a"}, - }, - { - name: "OrderByStartTimeASC", - query: "", - orderBy: []string{"start_time ASC"}, - expectedSQL: map[string]string{ - "postgres": `SELECT "run_uuid" FROM "runs" ORDER BY order_null_0, "start_time",runs.run_uuid`, - }, - expectedVars: []any{}, - }, - { - name: "OrderByStatusDesc", - query: "", - expectedSQL: map[string]string{ - "postgres": `SELECT "run_uuid" FROM "runs" ORDER BY order_null_0, "status" DESC,runs.start_time DESC,runs.run_uuid`, - }, - orderBy: []string{"status DESC"}, - expectedVars: []any{}, - }, - { - name: "OrderByRunNameSnakeCase", - query: "", - expectedSQL: map[string]string{ - "postgres": `SELECT "run_uuid" FROM "runs" ORDER BY order_null_0, "name",runs.start_time DESC,runs.run_uuid`, - }, - orderBy: []string{"run_name"}, - expectedVars: []any{}, - }, - { - name: "OrderByRunNameLowerName", - query: "", - expectedSQL: map[string]string{ - "postgres": `SELECT "run_uuid" FROM "runs" ORDER BY 
order_null_0, "name",runs.start_time DESC,runs.run_uuid`, - }, - orderBy: []string{"`Run name`"}, - expectedVars: []any{}, - }, - { - name: "OrderByRunNamePascal", - query: "", - expectedSQL: map[string]string{ - "postgres": `SELECT "run_uuid" FROM "runs" ORDER BY order_null_0, "name",runs.start_time DESC,runs.run_uuid`, - }, - orderBy: []string{"`Run Name`"}, - expectedVars: []any{}, - }, -} - -func newPostgresDialector() gorm.Dialector { - mockedDB, _, _ := sqlmock.New() - - return postgres.New(postgres.Config{ - Conn: mockedDB, - DriverName: "postgres", - }) -} - -func newSqliteDialector() gorm.Dialector { - mockedDB, mock, _ := sqlmock.New() - mock.ExpectQuery("select sqlite_version()").WillReturnRows( - sqlmock.NewRows([]string{"sqlite_version()"}).AddRow("3.41.1")) - - return sqlite.New(sqlite.Config{ - DriverName: "sqlite3", - Conn: mockedDB, - }) -} - -func newSQLServerDialector() gorm.Dialector { - mockedDB, _, _ := sqlmock.New() - - return sqlserver.New(sqlserver.Config{ - DriverName: "sqlserver", - Conn: mockedDB, - }) -} - -func newMySQLDialector() gorm.Dialector { - mockedDB, _, _ := sqlmock.New() - - return mysql.New(mysql.Config{ - DriverName: "mysql", - Conn: mockedDB, - SkipInitializeWithVersion: true, - }) -} - -var dialectors = []gorm.Dialector{ - newPostgresDialector(), - newSqliteDialector(), - newSQLServerDialector(), - newMySQLDialector(), -} - -func assertTestData( - t *testing.T, database *gorm.DB, expectedSQL string, testData testData, -) { - t.Helper() - - transaction := database.Model(&models.Run{}) - - contractErr := applyFilter(context.Background(), database, transaction, testData.query) - if contractErr != nil { - t.Fatal("contractErr: ", contractErr) - } - - contractErr = applyOrderBy(context.Background(), database, transaction, testData.orderBy) - if contractErr != nil { - t.Fatal("contractErr: ", contractErr) - } - - sqlErr := transaction.Select("ID").Find(&models.Run{}).Error - require.NoError(t, sqlErr) - - actualSQL := 
transaction.Statement.SQL.String() - - // if removeWhitespace(expectedSQL) != removeWhitespace(actualSQL) { - // fmt.Println(strings.ReplaceAll(actualSQL, "`", "")) - // } - - assert.Equal(t, removeWhitespace(expectedSQL), removeWhitespace(actualSQL)) - assert.Equal(t, testData.expectedVars, transaction.Statement.Vars) -} - -func TestSearchRuns(t *testing.T) { - t.Parallel() - - for _, dialector := range dialectors { - database, err := gorm.Open(dialector, &gorm.Config{DryRun: true}) - require.NoError(t, err) - - dialectorName := database.Dialector.Name() - - for _, testData := range tests { - currentTestData := testData - if expectedSQL, ok := currentTestData.expectedSQL[dialectorName]; ok { - t.Run(currentTestData.name+"_"+dialectorName, func(t *testing.T) { - t.Parallel() - assertTestData(t, database, expectedSQL, currentTestData) - }) - } - } - } -} - -func TestInvalidSearchRunsQuery(t *testing.T) { - t.Parallel() - - database, err := gorm.Open(newSqliteDialector(), &gorm.Config{DryRun: true}) - require.NoError(t, err) - - transaction := database.Model(&models.Run{}) - - contractErr := applyFilter(context.Background(), database, transaction, "⚡✱*@❖$#&") - if contractErr == nil { - t.Fatal("expected contract error") - } -} - -//nolint:funlen -func TestOrderByClauseParsing(t *testing.T) { - t.Parallel() - - testData := []struct { - input string - expected orderByExpr - }{ - { - input: "status DESC", - expected: orderByExpr{ - key: "status", - order: utils.PtrTo("DESC"), - }, - }, - { - input: "run_name", - expected: orderByExpr{ - key: "name", - }, - }, - { - input: "params.input DESC", - expected: orderByExpr{ - identifier: utils.PtrTo("parameter"), - key: "input", - order: utils.PtrTo("DESC"), - }, - }, - { - input: "metrics.alpha ASC", - expected: orderByExpr{ - identifier: utils.PtrTo("metric"), - key: "alpha", - order: utils.PtrTo("ASC"), - }, - }, - { - input: "`Run name`", - expected: orderByExpr{ - key: "name", - }, - }, - { - input: "tags.`foo bar` ASC", 
- expected: orderByExpr{ - identifier: utils.PtrTo("tag"), - key: "foo bar", - order: utils.PtrTo("ASC"), - }, - }, - } - - for _, testData := range testData { - t.Run(strcase.ToKebab(testData.input), func(t *testing.T) { - t.Parallel() - - result, err := processOrderByClause(testData.input) - if err != nil { - t.Fatalf("unexpected error: %A", err) - } - - if !reflect.DeepEqual(testData.expected, result) { - t.Fatalf("expected (%s, %s, %s), got (%s, %s, %s)", - *testData.expected.identifier, - testData.expected.key, - *testData.expected.order, - *result.identifier, - result.key, - *result.order, - ) - } - }) - } -} +//nolint:ireturn +package sql + +import ( + "context" + "reflect" + "regexp" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/iancoleman/strcase" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/driver/mysql" + "gorm.io/driver/postgres" + "gorm.io/driver/sqlite" + "gorm.io/driver/sqlserver" + "gorm.io/gorm" + + "github.com/mlflow/mlflow-go/pkg/tracking/store/sql/models" + "github.com/mlflow/mlflow-go/pkg/utils" +) + +type testData struct { + name string + query string + orderBy []string + expectedSQL map[string]string + expectedVars []any +} + +var whitespaceRegex = regexp.MustCompile(`\s` + "|`") + +func removeWhitespace(s string) string { + return whitespaceRegex.ReplaceAllString(s, "") +} + +var tests = []testData{ + { + name: "SimpleMetricQuery", + query: "metrics.accuracy > 0.72", + expectedSQL: map[string]string{ + "postgres": ` + SELECT "run_uuid" FROM "runs" + JOIN (SELECT "run_uuid","value" FROM "latest_metrics" WHERE key = $1 AND value > $2) + AS filter_0 + ON runs.run_uuid = filter_0.run_uuid + ORDER BY runs.start_time DESC,runs.run_uuid`, + "sqlite": ` + SELECT run_uuid FROM runs + JOIN (SELECT run_uuid,value FROM latest_metrics WHERE key = ? AND value > ?) 
+ AS filter_0 ON runs.run_uuid = filter_0.run_uuid + ORDER BY runs.start_time DESC,runs.run_uuid`, + "sqlserver": ` + SELECT "run_uuid" FROM "runs" + JOIN (SELECT "run_uuid","value" FROM "latest_metrics" WHERE key = @p1 AND value > @p2) + AS filter_0 ON runs.run_uuid = filter_0.run_uuid + ORDER BY runs.start_time DESC,runs.run_uuid`, + "mysql": ` + SELECT run_uuid FROM runs + JOIN (SELECT run_uuid,value FROM latest_metrics WHERE key = ? AND value > ?) + AS filter_0 ON runs.run_uuid = filter_0.run_uuid + ORDER BY runs.start_time DESC,runs.run_uuid`, + }, + expectedVars: []any{"accuracy", 0.72}, + }, + { + name: "SimpleMetricAndParamQuery", + query: "metrics.accuracy > 0.72 AND params.batch_size = '2'", + expectedSQL: map[string]string{ + "postgres": ` + SELECT "run_uuid" FROM "runs" + JOIN (SELECT "run_uuid","value" FROM "latest_metrics" WHERE key = $1 AND value > $2) + AS filter_0 ON runs.run_uuid = filter_0.run_uuid + JOIN (SELECT "run_uuid","value" FROM "params" WHERE key = $3 AND value = $4) + AS filter_1 ON runs.run_uuid = filter_1.run_uuid + ORDER BY runs.start_time DESC,runs.run_uuid`, + "sqlite": ` + SELECT run_uuid FROM runs + JOIN (SELECT run_uuid,value FROM latest_metrics WHERE key = ? AND value > ?) + AS filter_0 ON runs.run_uuid = filter_0.run_uuid + JOIN (SELECT run_uuid,value FROM params WHERE key = ? AND value = ?) 
+ AS filter_1 ON runs.run_uuid = filter_1.run_uuid + ORDER BY runs.start_time DESC,runs.run_uuid`, + }, + expectedVars: []any{"accuracy", 0.72, "batch_size", "2"}, + }, + { + name: "TagQuery", + query: "tags.environment = 'notebook' AND tags.task ILIKE 'classif%'", + expectedSQL: map[string]string{ + "postgres": ` + SELECT "run_uuid" FROM "runs" + JOIN (SELECT "run_uuid","value" FROM "tags" WHERE key = $1 AND value = $2) + AS filter_0 ON runs.run_uuid = filter_0.run_uuid + JOIN (SELECT "run_uuid","value" FROM "tags" WHERE key = $3 AND value ILIKE $4) + AS filter_1 ON runs.run_uuid = filter_1.run_uuid + ORDER BY runs.start_time DESC,runs.run_uuid`, + "sqlite": ` + SELECT run_uuid FROM runs + JOIN (SELECT run_uuid,value FROM tags WHERE key = ? AND value = ?) + AS filter_0 ON runs.run_uuid = filter_0.run_uuid + JOIN (SELECT run_uuid,value FROM tags WHERE key = ? AND LOWER(value) LIKE ?) + AS filter_1 ON runs.run_uuid = filter_1.run_uuid + ORDER BY runs.start_time DESC,runs.run_uuid`, + }, + expectedVars: []any{"environment", "notebook", "task", "classif%"}, + }, + { + name: "DatasestsInQuery", + query: "datasets.digest IN ('s8ds293b', 'jks834s2')", + expectedSQL: map[string]string{ + "postgres": ` + SELECT "run_uuid" FROM "runs" + JOIN ( + SELECT destination_id,"digest" + FROM "datasets" JOIN inputs ON inputs.source_id = datasets.dataset_uuid + WHERE digest IN ($1,$2) + ) + AS filter_0 ON runs.run_uuid = filter_0.destination_id + ORDER BY runs.start_time DESC,runs.run_uuid`, + "sqlite": ` + SELECT run_uuid FROM runs + JOIN ( + SELECT destination_id,digest + FROM datasets JOIN inputs + ON inputs.source_id = datasets.dataset_uuid + WHERE digest IN (?,?) 
+ ) + AS filter_0 + ON runs.run_uuid = filter_0.destination_id + ORDER BY runs.start_time DESC,runs.run_uuid`, + }, + expectedVars: []any{"s8ds293b", "jks834s2"}, + }, + { + name: "AttributesQuery", + query: "attributes.run_id = 'a1b2c3d4'", + expectedSQL: map[string]string{ + "postgres": ` + SELECT "run_uuid" FROM "runs" + WHERE runs.run_uuid = $1 + ORDER BY runs.start_time DESC,runs.run_uuid + `, + "sqlite": `SELECT run_uuid FROM runs WHERE runs.run_uuid = ? ORDER BY runs.start_time DESC,runs.run_uuid`, + }, + expectedVars: []any{"a1b2c3d4"}, + }, + { + name: "Run_nameQuery", + query: "attributes.run_name = 'my-run'", + expectedSQL: map[string]string{ + "postgres": ` + SELECT "run_uuid" FROM "runs" + JOIN (SELECT "run_uuid","value" FROM "tags" WHERE key = $1 AND value = $2) + AS filter_0 ON runs.run_uuid = filter_0.run_uuid + ORDER BY runs.start_time DESC,runs.run_uuid`, + "sqlite": ` + SELECT run_uuid FROM runs + JOIN (SELECT run_uuid,value FROM tags WHERE key = ? AND value = ?) + AS filter_0 ON runs.run_uuid = filter_0.run_uuid + ORDER BY runs.start_time DESC,runs.run_uuid`, + }, + expectedVars: []any{"mlflow.runName", "my-run"}, + }, + { + name: "DatasetsContextQuery", + query: "datasets.context = 'train'", + expectedSQL: map[string]string{ + "postgres": ` + SELECT "run_uuid" FROM "runs" + JOIN ( + SELECT inputs.destination_id AS run_uuid + FROM "inputs" + JOIN input_tags + ON inputs.input_uuid = input_tags.input_uuid + AND input_tags.name = 'mlflow.data.context' + AND input_tags.value = $1 + WHERE inputs.destination_type = 'RUN' + ) AS filter_0 ON runs.run_uuid = filter_0.run_uuid + ORDER BY runs.start_time DESC,runs.run_uuid`, + "sqlite": ` + SELECT run_uuid FROM runs + JOIN ( + SELECT inputs.destination_id AS run_uuid + FROM inputs + JOIN input_tags ON inputs.input_uuid = input_tags.input_uuid + AND input_tags.name = 'mlflow.data.context' + AND input_tags.value = ? 
WHERE inputs.destination_type = 'RUN' + ) AS filter_0 ON runs.run_uuid = filter_0.run_uuid + ORDER BY runs.start_time DESC,runs.run_uuid`, + }, + expectedVars: []any{"train"}, + }, + { + name: "Run_nameQuery", + query: "attributes.run_name ILIKE 'my-run%'", + expectedSQL: map[string]string{ + "postgres": ` + SELECT "run_uuid" FROM "runs" + JOIN (SELECT "run_uuid","value" FROM "tags" WHERE key = $1 AND value ILIKE $2) + AS filter_0 ON runs.run_uuid = filter_0.run_uuid + ORDER BY runs.start_time DESC,runs.run_uuid`, + "sqlite": ` + SELECT run_uuid FROM runs + JOIN (SELECT run_uuid, value FROM tags WHERE key = ? AND LOWER(value) LIKE ?) + AS filter_0 ON runs.run_uuid = filter_0.run_uuid + ORDER BY runs.start_time DESC,runs.run_uuid`, + }, + expectedVars: []any{"mlflow.runName", "my-run%"}, + }, + { + name: "DatasetsContextQuery", + query: "datasets.context ILIKE '%train'", + expectedSQL: map[string]string{ + "postgres": ` + SELECT "run_uuid" FROM "runs" + JOIN ( + SELECT inputs.destination_id AS run_uuid FROM "inputs" + JOIN input_tags ON inputs.input_uuid = input_tags.input_uuid + AND input_tags.name = 'mlflow.data.context' + AND input_tags.value ILIKE $1 WHERE inputs.destination_type = 'RUN' + ) AS filter_0 ON runs.run_uuid = filter_0.run_uuid + ORDER BY runs.start_time DESC,runs.run_uuid`, + "sqlite": ` + SELECT run_uuid FROM runs + JOIN ( + SELECT inputs.destination_id AS run_uuid FROM inputs + JOIN input_tags ON inputs.input_uuid = input_tags.input_uuid + AND input_tags.name = 'mlflow.data.context' + AND LOWER(input_tags.value) LIKE ? 
WHERE inputs.destination_type = 'RUN') + AS filter_0 ON runs.run_uuid = filter_0.run_uuid + ORDER BY runs.start_time DESC,runs.run_uuid + `, + }, + expectedVars: []any{"%train"}, + }, + { + name: "DatasestsDigest", + query: "datasets.digest ILIKE '%s'", + expectedSQL: map[string]string{ + "postgres": ` + SELECT "run_uuid" FROM "runs" + JOIN ( + SELECT destination_id,"digest" + FROM "datasets" + JOIN inputs ON inputs.source_id = datasets.dataset_uuid + WHERE digest ILIKE $1 + ) + AS filter_0 ON runs.run_uuid = filter_0.destination_id + ORDER BY runs.start_time DESC,runs.run_uuid`, + "sqlite": ` + SELECT run_uuid FROM runs + JOIN ( + SELECT destination_id,digest + FROM datasets + JOIN inputs ON inputs.source_id = datasets.dataset_uuid + WHERE LOWER(digest) LIKE ?) + AS filter_0 ON runs.run_uuid = filter_0.destination_id + ORDER BY runs.start_time DESC,runs.run_uuid`, + }, + expectedVars: []any{"%s"}, + }, + { + name: "ParamQuery", + query: "metrics.accuracy > 0.72 AND params.batch_size ILIKE '%a'", + expectedSQL: map[string]string{ + "postgres": ` + SELECT "run_uuid" FROM "runs" + JOIN (SELECT "run_uuid","value" FROM "latest_metrics" WHERE key = $1 AND value > $2) + AS filter_0 ON runs.run_uuid = filter_0.run_uuid + JOIN (SELECT "run_uuid","value" FROM "params" WHERE key = $3 AND value ILIKE $4) + AS filter_1 ON runs.run_uuid = filter_1.run_uuid + ORDER BY runs.start_time DESC,runs.run_uuid`, + "sqlite": ` + SELECT run_uuid FROM runs + JOIN (SELECT run_uuid, value FROM latest_metrics WHERE key = ? AND value > ?) + AS filter_0 ON runs.run_uuid = filter_0.run_uuid + JOIN (SELECT run_uuid,value FROM params WHERE key = ? AND LOWER(value) LIKE ?) 
+ AS filter_1 ON runs.run_uuid = filter_1.run_uuid + ORDER BY runs.start_time DESC,runs.run_uuid + `, + }, + expectedVars: []any{"accuracy", 0.72, "batch_size", "%a"}, + }, + { + name: "OrderByStartTimeASC", + query: "", + orderBy: []string{"start_time ASC"}, + expectedSQL: map[string]string{ + "postgres": `SELECT "run_uuid" FROM "runs" ORDER BY order_null_0, "start_time",runs.run_uuid`, + }, + expectedVars: []any{}, + }, + { + name: "OrderByStatusDesc", + query: "", + expectedSQL: map[string]string{ + "postgres": `SELECT "run_uuid" FROM "runs" ORDER BY order_null_0, "status" DESC,runs.start_time DESC,runs.run_uuid`, + }, + orderBy: []string{"status DESC"}, + expectedVars: []any{}, + }, + { + name: "OrderByRunNameSnakeCase", + query: "", + expectedSQL: map[string]string{ + "postgres": `SELECT "run_uuid" FROM "runs" ORDER BY order_null_0, "name",runs.start_time DESC,runs.run_uuid`, + }, + orderBy: []string{"run_name"}, + expectedVars: []any{}, + }, + { + name: "OrderByRunNameLowerName", + query: "", + expectedSQL: map[string]string{ + "postgres": `SELECT "run_uuid" FROM "runs" ORDER BY order_null_0, "name",runs.start_time DESC,runs.run_uuid`, + }, + orderBy: []string{"`Run name`"}, + expectedVars: []any{}, + }, + { + name: "OrderByRunNamePascal", + query: "", + expectedSQL: map[string]string{ + "postgres": `SELECT "run_uuid" FROM "runs" ORDER BY order_null_0, "name",runs.start_time DESC,runs.run_uuid`, + }, + orderBy: []string{"`Run Name`"}, + expectedVars: []any{}, + }, +} + +func newPostgresDialector() gorm.Dialector { + mockedDB, _, _ := sqlmock.New() + + return postgres.New(postgres.Config{ + Conn: mockedDB, + DriverName: "postgres", + }) +} + +func newSqliteDialector() gorm.Dialector { + mockedDB, mock, _ := sqlmock.New() + mock.ExpectQuery("select sqlite_version()").WillReturnRows( + sqlmock.NewRows([]string{"sqlite_version()"}).AddRow("3.41.1")) + + return sqlite.New(sqlite.Config{ + DriverName: "sqlite3", + Conn: mockedDB, + }) +} + +func 
newSQLServerDialector() gorm.Dialector { + mockedDB, _, _ := sqlmock.New() + + return sqlserver.New(sqlserver.Config{ + DriverName: "sqlserver", + Conn: mockedDB, + }) +} + +func newMySQLDialector() gorm.Dialector { + mockedDB, _, _ := sqlmock.New() + + return mysql.New(mysql.Config{ + DriverName: "mysql", + Conn: mockedDB, + SkipInitializeWithVersion: true, + }) +} + +var dialectors = []gorm.Dialector{ + newPostgresDialector(), + newSqliteDialector(), + newSQLServerDialector(), + newMySQLDialector(), +} + +func assertTestData( + t *testing.T, database *gorm.DB, expectedSQL string, testData testData, +) { + t.Helper() + + transaction := database.Model(&models.Run{}) + + contractErr := applyFilter(context.Background(), database, transaction, testData.query) + if contractErr != nil { + t.Fatal("contractErr: ", contractErr) + } + + contractErr = applyOrderBy(context.Background(), database, transaction, testData.orderBy) + if contractErr != nil { + t.Fatal("contractErr: ", contractErr) + } + + sqlErr := transaction.Select("ID").Find(&models.Run{}).Error + require.NoError(t, sqlErr) + + actualSQL := transaction.Statement.SQL.String() + + // if removeWhitespace(expectedSQL) != removeWhitespace(actualSQL) { + // fmt.Println(strings.ReplaceAll(actualSQL, "`", "")) + // } + + assert.Equal(t, removeWhitespace(expectedSQL), removeWhitespace(actualSQL)) + assert.Equal(t, testData.expectedVars, transaction.Statement.Vars) +} + +func TestSearchRuns(t *testing.T) { + t.Parallel() + + for _, dialector := range dialectors { + database, err := gorm.Open(dialector, &gorm.Config{DryRun: true}) + require.NoError(t, err) + + dialectorName := database.Dialector.Name() + + for _, testData := range tests { + currentTestData := testData + if expectedSQL, ok := currentTestData.expectedSQL[dialectorName]; ok { + t.Run(currentTestData.name+"_"+dialectorName, func(t *testing.T) { + t.Parallel() + assertTestData(t, database, expectedSQL, currentTestData) + }) + } + } + } +} + +func 
TestInvalidSearchRunsQuery(t *testing.T) { + t.Parallel() + + database, err := gorm.Open(newSqliteDialector(), &gorm.Config{DryRun: true}) + require.NoError(t, err) + + transaction := database.Model(&models.Run{}) + + contractErr := applyFilter(context.Background(), database, transaction, "⚡✱*@❖$#&") + if contractErr == nil { + t.Fatal("expected contract error") + } +} + +//nolint:funlen +func TestOrderByClauseParsing(t *testing.T) { + t.Parallel() + + testData := []struct { + input string + expected orderByExpr + }{ + { + input: "status DESC", + expected: orderByExpr{ + key: "status", + order: utils.PtrTo("DESC"), + }, + }, + { + input: "run_name", + expected: orderByExpr{ + key: "name", + }, + }, + { + input: "params.input DESC", + expected: orderByExpr{ + identifier: utils.PtrTo("parameter"), + key: "input", + order: utils.PtrTo("DESC"), + }, + }, + { + input: "metrics.alpha ASC", + expected: orderByExpr{ + identifier: utils.PtrTo("metric"), + key: "alpha", + order: utils.PtrTo("ASC"), + }, + }, + { + input: "`Run name`", + expected: orderByExpr{ + key: "name", + }, + }, + { + input: "tags.`foo bar` ASC", + expected: orderByExpr{ + identifier: utils.PtrTo("tag"), + key: "foo bar", + order: utils.PtrTo("ASC"), + }, + }, + } + + for _, testData := range testData { + t.Run(strcase.ToKebab(testData.input), func(t *testing.T) { + t.Parallel() + + result, err := processOrderByClause(testData.input) + if err != nil { + t.Fatalf("unexpected error: %A", err) + } + + if !reflect.DeepEqual(testData.expected, result) { + t.Fatalf("expected (%s, %s, %s), got (%s, %s, %s)", + *testData.expected.identifier, + testData.expected.key, + *testData.expected.order, + *result.identifier, + result.key, + *result.order, + ) + } + }) + } +} diff --git a/pkg/tracking/store/sql/store.go b/pkg/tracking/store/sql/store.go index 0570508..9fb40f3 100644 --- a/pkg/tracking/store/sql/store.go +++ b/pkg/tracking/store/sql/store.go @@ -1,28 +1,28 @@ -package sql - -import ( - "context" - "fmt" - 
- "gorm.io/gorm" - - "github.com/mlflow/mlflow-go/pkg/config" - "github.com/mlflow/mlflow-go/pkg/sql" -) - -type TrackingSQLStore struct { - config *config.Config - db *gorm.DB -} - -func NewTrackingSQLStore(ctx context.Context, config *config.Config) (*TrackingSQLStore, error) { - database, err := sql.NewDatabase(ctx, config.TrackingStoreURI) - if err != nil { - return nil, fmt.Errorf("failed to connect to database %q: %w", config.TrackingStoreURI, err) - } - - return &TrackingSQLStore{ - config: config, - db: database, - }, nil -} +package sql + +import ( + "context" + "fmt" + + "gorm.io/gorm" + + "github.com/mlflow/mlflow-go/pkg/config" + "github.com/mlflow/mlflow-go/pkg/sql" +) + +type TrackingSQLStore struct { + config *config.Config + db *gorm.DB +} + +func NewTrackingSQLStore(ctx context.Context, config *config.Config) (*TrackingSQLStore, error) { + database, err := sql.NewDatabase(ctx, config.TrackingStoreURI) + if err != nil { + return nil, fmt.Errorf("failed to connect to database %q: %w", config.TrackingStoreURI, err) + } + + return &TrackingSQLStore{ + config: config, + db: database, + }, nil +} diff --git a/pkg/tracking/store/sql/tags.go b/pkg/tracking/store/sql/tags.go index e20ae13..107d487 100644 --- a/pkg/tracking/store/sql/tags.go +++ b/pkg/tracking/store/sql/tags.go @@ -1,82 +1,262 @@ -package sql - -import ( - "context" - "errors" - "fmt" - - "gorm.io/gorm" - "gorm.io/gorm/clause" - - "github.com/mlflow/mlflow-go/pkg/contract" - "github.com/mlflow/mlflow-go/pkg/entities" - "github.com/mlflow/mlflow-go/pkg/protos" - "github.com/mlflow/mlflow-go/pkg/tracking/store/sql/models" - "github.com/mlflow/mlflow-go/pkg/utils" -) - -const tagsBatchSize = 100 - -func (s TrackingSQLStore) GetRunTag( - ctx context.Context, runID, tagKey string, -) (*entities.RunTag, *contract.Error) { - var runTag models.Tag - if err := s.db.WithContext( - ctx, - ).Where( - "run_uuid = ?", runID, - ).Where( - "key = ?", tagKey, - ).First(&runTag).Error; err != nil { - if 
errors.Is(err, gorm.ErrRecordNotFound) { - return nil, nil - } - - return nil, contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - fmt.Sprintf("failed to get run tag for run id %q", runID), - err, - ) - } - - return runTag.ToEntity(), nil -} - -func (s TrackingSQLStore) setTagsWithTransaction( - transaction *gorm.DB, runID string, tags []*entities.RunTag, -) error { - runColumns := make(map[string]interface{}) - - for _, tag := range tags { - switch tag.Key { - case utils.TagUser: - runColumns["user_id"] = tag.Value - case utils.TagRunName: - runColumns["name"] = tag.Value - } - } - - if len(runColumns) != 0 { - err := transaction. - Model(&models.Run{}). - Where("run_uuid = ?", runID). - UpdateColumns(runColumns).Error - if err != nil { - return fmt.Errorf("failed to update run columns: %w", err) - } - } - - runTags := make([]models.Tag, 0, len(tags)) - - for _, tag := range tags { - runTags = append(runTags, models.NewTagFromEntity(runID, tag)) - } - - if err := transaction.Clauses(clause.OnConflict{ - UpdateAll: true, - }).CreateInBatches(runTags, tagsBatchSize).Error; err != nil { - return fmt.Errorf("failed to create tags for run %q: %w", runID, err) - } - - return nil -} +package sql + +import ( + "context" + "errors" + "fmt" + "strconv" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + + "github.com/mlflow/mlflow-go/pkg/contract" + "github.com/mlflow/mlflow-go/pkg/entities" + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/tracking/store/sql/models" + "github.com/mlflow/mlflow-go/pkg/utils" +) + +const tagsBatchSize = 100 + +func (s TrackingSQLStore) GetRunTag( + ctx context.Context, runID, tagKey string, +) (*entities.RunTag, *contract.Error) { + var runTag models.Tag + if err := s.db.WithContext( + ctx, + ).Where( + "run_uuid = ?", runID, + ).Where( + "key = ?", tagKey, + ).First(&runTag).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, nil + } + + return nil, contract.NewErrorWith( + 
protos.ErrorCode_INTERNAL_ERROR,
+			fmt.Sprintf("failed to get run tag for run id %q", runID),
+			err,
+		)
+	}
+
+	return runTag.ToEntity(), nil
+}
+
+// setTagsWithTransaction upserts a batch of run tags inside an existing
+// transaction, mirroring the special user/run-name tag keys onto the runs table.
+func (s TrackingSQLStore) setTagsWithTransaction(
+	transaction *gorm.DB, runID string, tags []*entities.RunTag,
+) error {
+	runColumns := make(map[string]interface{})
+
+	for _, tag := range tags {
+		switch tag.Key {
+		case utils.TagUser:
+			runColumns["user_id"] = tag.Value
+		case utils.TagRunName:
+			runColumns["name"] = tag.Value
+		}
+	}
+
+	if len(runColumns) != 0 {
+		err := transaction.
+			Model(&models.Run{}).
+			Where("run_uuid = ?", runID).
+			UpdateColumns(runColumns).Error
+		if err != nil {
+			return fmt.Errorf("failed to update run columns: %w", err)
+		}
+	}
+
+	runTags := make([]models.Tag, 0, len(tags))
+
+	for _, tag := range tags {
+		runTags = append(runTags, models.NewTagFromEntity(runID, tag))
+	}
+
+	if err := transaction.Clauses(clause.OnConflict{
+		UpdateAll: true,
+	}).CreateInBatches(runTags, tagsBatchSize).Error; err != nil {
+		return fmt.Errorf("failed to create tags for run %q: %w", runID, err)
+	}
+
+	return nil
+}
+
+const (
+	maxEntityKeyLength = 250
+	maxTagValueLength  = 8000
+)
+
+// validateTag checks a tag key/value pair against MLflow's constraints:
+// a non-empty key of at most maxEntityKeyLength characters and a value of
+// at most maxTagValueLength characters.
+func validateTag(key, value string) *contract.Error {
+	if key == "" {
+		return contract.NewError(
+			protos.ErrorCode_INVALID_PARAMETER_VALUE,
+			"Missing value for required parameter 'key'",
+		)
+	}
+
+	if len(key) > maxEntityKeyLength {
+		return contract.NewError(
+			protos.ErrorCode_INVALID_PARAMETER_VALUE,
+			fmt.Sprintf("Tag key '%s' had length %d, which exceeded length limit of %d", key, len(key), maxEntityKeyLength),
+		)
+	}
+
+	if len(value) > maxTagValueLength {
+		return contract.NewError(
+			protos.ErrorCode_INVALID_PARAMETER_VALUE,
+			fmt.Sprintf("Tag value exceeded length limit of %d characters", maxTagValueLength),
+		)
+	}
+
+	// TODO(review): rejecting any value that parses as a float looks overly
+	// strict (legitimate tag values can be numeric) - confirm against the
+	// Python-side validation rules before relying on this behavior.
+	if _, err := strconv.ParseFloat(value, 64); err == nil {
+		return contract.NewError(
+			protos.ErrorCode_INVALID_PARAMETER_VALUE,
+			fmt.Sprintf("Invalid value %s for parameter 'value' supplied", value),
+		)
+	}
+
+	return nil
+}
+
+// SetTag creates or updates a single tag on an active run. The lookup and
+// the write happen inside one transaction so the run-active check cannot be
+// invalidated between the two statements.
+func (s TrackingSQLStore) SetTag(
+	ctx context.Context, runID, key, value string,
+) *contract.Error {
+	if runID == "" {
+		return contract.NewError(
+			protos.ErrorCode_INVALID_PARAMETER_VALUE,
+			"RunID cannot be empty",
+		)
+	}
+
+	// A purely numeric string cannot be a valid run UUID.
+	if _, err := strconv.ParseFloat(runID, 64); err == nil {
+		return contract.NewError(
+			protos.ErrorCode_INVALID_PARAMETER_VALUE,
+			fmt.Sprintf("Invalid value %s for parameter 'run_id' supplied", runID),
+		)
+	}
+
+	if err := validateTag(key, value); err != nil {
+		return err
+	}
+
+	err := s.db.WithContext(ctx).Transaction(func(transaction *gorm.DB) error {
+		contractError := checkRunIsActive(transaction, runID)
+		if contractError != nil {
+			return contractError
+		}
+
+		var tag models.Tag
+
+		result := transaction.Where("run_uuid = ? AND key = ?", runID, key).First(&tag)
+		if result.Error != nil && !errors.Is(result.Error, gorm.ErrRecordNotFound) {
+			return contract.NewErrorWith(
+				protos.ErrorCode_INTERNAL_ERROR,
+				fmt.Sprintf("Failed to query tag for run_id %q and key %q", runID, key),
+				result.Error,
+			)
+		}
+
+		if result.RowsAffected == 1 {
+			// Tag already exists: overwrite its value in place.
+			tag.Value = value
+			if err := transaction.Save(&tag).Error; err != nil {
+				return contract.NewErrorWith(
+					protos.ErrorCode_INTERNAL_ERROR,
+					fmt.Sprintf("Failed to update tag for run_id %q and key %q", runID, key),
+					err,
+				)
+			}
+		} else {
+			newTag := models.Tag{
+				RunID: runID,
+				Key:   key,
+				Value: value,
+			}
+			if err := transaction.Create(&newTag).Error; err != nil {
+				return contract.NewErrorWith(
+					protos.ErrorCode_INTERNAL_ERROR,
+					fmt.Sprintf("Failed to create tag for run_id %q and key %q", runID, key),
+					err,
+				)
+			}
+		}
+
+		return nil
+	})
+	if err != nil {
+		var contractError *contract.Error
+		if errors.As(err, &contractError) {
+			return contractError
+		}
+
+		return contract.NewErrorWith(
+			protos.ErrorCode_INTERNAL_ERROR,
+			fmt.Sprintf("set tag transaction failed for %q", runID),
+			err,
+		)
+	}
+
+	return nil
+}
+
+const badDataMessage = "Bad data in database - tags for a specific run must have\n" +
+	"a single unique value.\n" +
+	"See https://mlflow.org/docs/latest/tracking.html#adding-tags-to-runs"
+
+// DeleteTag removes a single tag from an active run. It is an error if the
+// tag does not exist, or if more than one row matches (corrupt data).
+func (s TrackingSQLStore) DeleteTag(
+	ctx context.Context, runID, key string,
+) *contract.Error {
+	err := s.db.WithContext(ctx).Transaction(func(transaction *gorm.DB) error {
+		contractError := checkRunIsActive(transaction, runID)
+		if contractError != nil {
+			return contractError
+		}
+
+		var tags []models.Tag
+
+		// Capture the chained statement: GORM reports Find/Delete failures on
+		// the returned *gorm.DB, not on the outer transaction handle.
+		result := transaction.Model(models.Tag{}).Where("run_uuid = ?", runID).Where("key = ?", key).Find(&tags)
+		if result.Error != nil {
+			return contract.NewErrorWith(
+				protos.ErrorCode_INTERNAL_ERROR,
+				fmt.Sprintf("Failed to query tags for run_id %q and key %q", runID, key),
+				result.Error,
+			)
+		}
+
+		switch len(tags) {
+		case 0:
+			return contract.NewError(
+				protos.ErrorCode_RESOURCE_DOES_NOT_EXIST,
+				fmt.Sprintf("No tag with name: %s in run with id %s", key, runID),
+			)
+		case 1:
+			if err := transaction.Delete(&tags[0]).Error; err != nil {
+				return contract.NewErrorWith(
+					protos.ErrorCode_INTERNAL_ERROR,
+					fmt.Sprintf("Failed to delete tag for run_id %q and key %q", runID, key),
+					err,
+				)
+			}
+
+			return nil
+		default:
+			return contract.NewError(protos.ErrorCode_INVALID_STATE, badDataMessage)
+		}
+	})
+	if err != nil {
+		var contractError *contract.Error
+		if errors.As(err, &contractError) {
+			return contractError
+		}
+
+		return contract.NewErrorWith(
+			protos.ErrorCode_INTERNAL_ERROR,
+			fmt.Sprintf("delete tag transaction failed for %q", runID),
+			err,
+		)
+	}
+
+	return nil
+}
diff --git a/pkg/tracking/store/store.go b/pkg/tracking/store/store.go
index 24839e7..f83fb66 100644
--- a/pkg/tracking/store/store.go
+++ b/pkg/tracking/store/store.go
@@ -1,79 +1,81 @@
-package store
-
-import (
-	"context"
-
-	"github.com/mlflow/mlflow-go/pkg/contract"
-	"github.com/mlflow/mlflow-go/pkg/entities"
-	"github.com/mlflow/mlflow-go/pkg/protos"
-)
-
-//go:generate mockery
-type TrackingStore interface {
-	RunTrackingStore
-	MetricTrackingStore
-	ExperimentTrackingStore
-}
-
-type (
-	RunTrackingStore interface {
-		GetRun(ctx context.Context, runID string) (*entities.Run, *contract.Error)
-		CreateRun(
-			ctx context.Context,
-			experimentID string,
-			userID string,
-			startTime int64,
-			tags []*entities.RunTag,
-			runName string,
-		) (*entities.Run, *contract.Error)
-		UpdateRun(
-			ctx context.Context,
-			runID string,
-			runStatus string,
-			endTime *int64,
-			runName string,
-		) *contract.Error
-		DeleteRun(ctx context.Context, runID string) *contract.Error
-		RestoreRun(ctx context.Context, runID string) *contract.Error
-
-		GetRunTag(ctx context.Context, runID, tagKey string) (*entities.RunTag, *contract.Error)
-	}
-	MetricTrackingStore interface {
-		
LogBatch( - ctx context.Context, - runID string, - metrics []*entities.Metric, - params []*entities.Param, - tags []*entities.RunTag) *contract.Error - - LogMetric(ctx context.Context, runID string, metric *entities.Metric) *contract.Error - } -) - -type ExperimentTrackingStore interface { - // GetExperiment returns experiment by the experiment ID. - // The experiment should contain the linked tags. - GetExperiment(ctx context.Context, id string) (*entities.Experiment, *contract.Error) - GetExperimentByName(ctx context.Context, name string) (*entities.Experiment, *contract.Error) - - CreateExperiment( - ctx context.Context, - name string, - artifactLocation string, - tags []*entities.ExperimentTag, - ) (string, *contract.Error) - RestoreExperiment(ctx context.Context, id string) *contract.Error - RenameExperiment(ctx context.Context, experimentID, name string) *contract.Error - - SearchRuns( - ctx context.Context, - experimentIDs []string, - filter string, - runViewType protos.ViewType, - maxResults int, - orderBy []string, - pageToken string, - ) ([]*entities.Run, string, *contract.Error) - - DeleteExperiment(ctx context.Context, id string) *contract.Error -} +package store + +import ( + "context" + + "github.com/mlflow/mlflow-go/pkg/contract" + "github.com/mlflow/mlflow-go/pkg/entities" + "github.com/mlflow/mlflow-go/pkg/protos" +) + +//go:generate mockery +type TrackingStore interface { + RunTrackingStore + MetricTrackingStore + ExperimentTrackingStore +} + +type ( + RunTrackingStore interface { + GetRun(ctx context.Context, runID string) (*entities.Run, *contract.Error) + CreateRun( + ctx context.Context, + experimentID string, + userID string, + startTime int64, + tags []*entities.RunTag, + runName string, + ) (*entities.Run, *contract.Error) + UpdateRun( + ctx context.Context, + runID string, + runStatus string, + endTime *int64, + runName string, + ) *contract.Error + DeleteRun(ctx context.Context, runID string) *contract.Error + RestoreRun(ctx 
context.Context, runID string) *contract.Error + + GetRunTag(ctx context.Context, runID, tagKey string) (*entities.RunTag, *contract.Error) + SetTag(ctx context.Context, runID, key string, value string) *contract.Error + DeleteTag(ctx context.Context, runID, key string) *contract.Error + } + MetricTrackingStore interface { + LogBatch( + ctx context.Context, + runID string, + metrics []*entities.Metric, + params []*entities.Param, + tags []*entities.RunTag) *contract.Error + + LogMetric(ctx context.Context, runID string, metric *entities.Metric) *contract.Error + } +) + +type ExperimentTrackingStore interface { + // GetExperiment returns experiment by the experiment ID. + // The experiment should contain the linked tags. + GetExperiment(ctx context.Context, id string) (*entities.Experiment, *contract.Error) + GetExperimentByName(ctx context.Context, name string) (*entities.Experiment, *contract.Error) + + CreateExperiment( + ctx context.Context, + name string, + artifactLocation string, + tags []*entities.ExperimentTag, + ) (string, *contract.Error) + RestoreExperiment(ctx context.Context, id string) *contract.Error + RenameExperiment(ctx context.Context, experimentID, name string) *contract.Error + + SearchRuns( + ctx context.Context, + experimentIDs []string, + filter string, + runViewType protos.ViewType, + maxResults int, + orderBy []string, + pageToken string, + ) ([]*entities.Run, string, *contract.Error) + + DeleteExperiment(ctx context.Context, id string) *contract.Error +} diff --git a/pkg/utils/logger.go b/pkg/utils/logger.go index 13249f7..ff44dff 100644 --- a/pkg/utils/logger.go +++ b/pkg/utils/logger.go @@ -1,49 +1,49 @@ -package utils - -import ( - "context" - - "github.com/gofiber/fiber/v2" - "github.com/sirupsen/logrus" - - "github.com/mlflow/mlflow-go/pkg/config" -) - -type loggerKey struct{} - -func NewContextWithLogger(ctx context.Context, logger *logrus.Logger) context.Context { - return context.WithValue(ctx, loggerKey{}, logger) -} - -// 
NewContextWithLoggerFromFiberContext transfer logger from Fiber context to a normal context.Context object. -func NewContextWithLoggerFromFiberContext(c *fiber.Ctx) context.Context { - logger := GetLoggerFromContext(c.UserContext()) - - return NewContextWithLogger(c.Context(), logger) -} - -func GetLoggerFromContext(ctx context.Context) *logrus.Logger { - logger := ctx.Value(loggerKey{}) - if logger != nil { - logger, ok := logger.(*logrus.Logger) - if ok { - return logger - } - } - - return logrus.StandardLogger() -} - -func NewLoggerFromConfig(cfg *config.Config) *logrus.Logger { - logger := logrus.New() - - logLevel, err := logrus.ParseLevel(cfg.LogLevel) - if err != nil { - logLevel = logrus.InfoLevel - logger.Warnf("failed to parse log level: %s - assuming %q", err, logrus.InfoLevel) - } - - logger.SetLevel(logLevel) - - return logger -} +package utils + +import ( + "context" + + "github.com/gofiber/fiber/v2" + "github.com/sirupsen/logrus" + + "github.com/mlflow/mlflow-go/pkg/config" +) + +type loggerKey struct{} + +func NewContextWithLogger(ctx context.Context, logger *logrus.Logger) context.Context { + return context.WithValue(ctx, loggerKey{}, logger) +} + +// NewContextWithLoggerFromFiberContext transfer logger from Fiber context to a normal context.Context object. 
+func NewContextWithLoggerFromFiberContext(c *fiber.Ctx) context.Context { + logger := GetLoggerFromContext(c.UserContext()) + + return NewContextWithLogger(c.Context(), logger) +} + +func GetLoggerFromContext(ctx context.Context) *logrus.Logger { + logger := ctx.Value(loggerKey{}) + if logger != nil { + logger, ok := logger.(*logrus.Logger) + if ok { + return logger + } + } + + return logrus.StandardLogger() +} + +func NewLoggerFromConfig(cfg *config.Config) *logrus.Logger { + logger := logrus.New() + + logLevel, err := logrus.ParseLevel(cfg.LogLevel) + if err != nil { + logLevel = logrus.InfoLevel + logger.Warnf("failed to parse log level: %s - assuming %q", err, logrus.InfoLevel) + } + + logger.SetLevel(logLevel) + + return logger +} diff --git a/pkg/utils/naming.go b/pkg/utils/naming.go index 76ee216..2b7e493 100644 --- a/pkg/utils/naming.go +++ b/pkg/utils/naming.go @@ -1,71 +1,71 @@ -package utils - -import ( - "crypto/rand" - "fmt" - "math/big" -) - -var nouns = []string{ - "ant", "ape", "asp", "auk", "bass", "bat", "bear", "bee", "bird", "boar", - "bug", "calf", "carp", "cat", "chimp", "cod", "colt", "conch", "cow", - "crab", "crane", "croc", "crow", "cub", "deer", "doe", "dog", "dolphin", - "donkey", "dove", "duck", "eel", "elk", "fawn", "finch", "fish", "flea", - "fly", "foal", "fowl", "fox", "frog", "gnat", "gnu", "goat", "goose", - "grouse", "grub", "gull", "hare", "hawk", "hen", "hog", "horse", "hound", - "jay", "kit", "kite", "koi", "lamb", "lark", "loon", "lynx", "mare", - "midge", "mink", "mole", "moose", "moth", "mouse", "mule", "newt", "owl", - "ox", "panda", "penguin", "perch", "pig", "pug", "quail", "ram", "rat", - "ray", "robin", "roo", "rook", "seal", "shad", "shark", "sheep", "shoat", - "shrew", "shrike", "shrimp", "skink", "skunk", "sloth", "slug", "smelt", - "snail", "snake", "snipe", "sow", "sponge", "squid", "squirrel", "stag", - "steed", "stoat", "stork", "swan", "tern", "toad", "trout", "turtle", - "vole", "wasp", "whale", "wolf", 
"worm", "wren", "yak", "zebra", -} - -var predicates = []string{ - "abundant", "able", "abrasive", "adorable", "adaptable", "adventurous", - "aged", "agreeable", "ambitious", "amazing", "amusing", "angry", - "auspicious", "awesome", "bald", "beautiful", "bemused", "bedecked", "big", - "bittersweet", "blushing", "bold", "bouncy", "brawny", "bright", "burly", - "bustling", "calm", "capable", "carefree", "capricious", "caring", - "casual", "charming", "chill", "classy", "clean", "clumsy", "colorful", - "crawling", "dapper", "debonair", "dashing", "defiant", "delicate", - "delightful", "dazzling", "efficient", "enchanting", "entertaining", - "enthused", "exultant", "fearless", "flawless", "fortunate", "fun", - "funny", "gaudy", "gentle", "gifted", "glamorous", "grandiose", - "gregarious", "handsome", "hilarious", "honorable", "illustrious", - "incongruous", "indecisive", "industrious", "intelligent", "inquisitive", - "intrigued", "invincible", "judicious", "kindly", "languid", "learned", - "legendary", "likeable", "loud", "luminous", "luxuriant", "lyrical", - "magnificent", "marvelous", "masked", "melodic", "merciful", "mercurial", - "monumental", "mysterious", "nebulous", "nervous", "nimble", "nosy", - "omniscient", "orderly", "overjoyed", "peaceful", "painted", "persistent", - "placid", "polite", "popular", "powerful", "puzzled", "rambunctious", - "rare", "rebellious", "respected", "resilient", "righteous", "receptive", - "redolent", "resilient", "rogue", "rumbling", "salty", "sassy", "secretive", - "selective", "sedate", "serious", "shivering", "skillful", "sincere", - "skittish", "silent", "smiling", -} - -const numRange = 1000 - -// GenerateRandomName generates random name for `run`. 
-func GenerateRandomName() (string, error) { - predicateIndex, err := rand.Int(rand.Reader, big.NewInt(int64(len(predicates)))) - if err != nil { - return "", fmt.Errorf("error getting random integer number: %w", err) - } - - nounIndex, err := rand.Int(rand.Reader, big.NewInt(int64(len(nouns)))) - if err != nil { - return "", fmt.Errorf("error getting random integer number: %w", err) - } - - num, err := rand.Int(rand.Reader, big.NewInt(numRange)) - if err != nil { - return "", fmt.Errorf("error getting random integer number: %w", err) - } - - return fmt.Sprintf("%s-%s-%d", predicates[predicateIndex.Int64()], nouns[nounIndex.Int64()], num), nil -} +package utils + +import ( + "crypto/rand" + "fmt" + "math/big" +) + +var nouns = []string{ + "ant", "ape", "asp", "auk", "bass", "bat", "bear", "bee", "bird", "boar", + "bug", "calf", "carp", "cat", "chimp", "cod", "colt", "conch", "cow", + "crab", "crane", "croc", "crow", "cub", "deer", "doe", "dog", "dolphin", + "donkey", "dove", "duck", "eel", "elk", "fawn", "finch", "fish", "flea", + "fly", "foal", "fowl", "fox", "frog", "gnat", "gnu", "goat", "goose", + "grouse", "grub", "gull", "hare", "hawk", "hen", "hog", "horse", "hound", + "jay", "kit", "kite", "koi", "lamb", "lark", "loon", "lynx", "mare", + "midge", "mink", "mole", "moose", "moth", "mouse", "mule", "newt", "owl", + "ox", "panda", "penguin", "perch", "pig", "pug", "quail", "ram", "rat", + "ray", "robin", "roo", "rook", "seal", "shad", "shark", "sheep", "shoat", + "shrew", "shrike", "shrimp", "skink", "skunk", "sloth", "slug", "smelt", + "snail", "snake", "snipe", "sow", "sponge", "squid", "squirrel", "stag", + "steed", "stoat", "stork", "swan", "tern", "toad", "trout", "turtle", + "vole", "wasp", "whale", "wolf", "worm", "wren", "yak", "zebra", +} + +var predicates = []string{ + "abundant", "able", "abrasive", "adorable", "adaptable", "adventurous", + "aged", "agreeable", "ambitious", "amazing", "amusing", "angry", + "auspicious", "awesome", "bald", 
"beautiful", "bemused", "bedecked", "big", + "bittersweet", "blushing", "bold", "bouncy", "brawny", "bright", "burly", + "bustling", "calm", "capable", "carefree", "capricious", "caring", + "casual", "charming", "chill", "classy", "clean", "clumsy", "colorful", + "crawling", "dapper", "debonair", "dashing", "defiant", "delicate", + "delightful", "dazzling", "efficient", "enchanting", "entertaining", + "enthused", "exultant", "fearless", "flawless", "fortunate", "fun", + "funny", "gaudy", "gentle", "gifted", "glamorous", "grandiose", + "gregarious", "handsome", "hilarious", "honorable", "illustrious", + "incongruous", "indecisive", "industrious", "intelligent", "inquisitive", + "intrigued", "invincible", "judicious", "kindly", "languid", "learned", + "legendary", "likeable", "loud", "luminous", "luxuriant", "lyrical", + "magnificent", "marvelous", "masked", "melodic", "merciful", "mercurial", + "monumental", "mysterious", "nebulous", "nervous", "nimble", "nosy", + "omniscient", "orderly", "overjoyed", "peaceful", "painted", "persistent", + "placid", "polite", "popular", "powerful", "puzzled", "rambunctious", + "rare", "rebellious", "respected", "resilient", "righteous", "receptive", + "redolent", "resilient", "rogue", "rumbling", "salty", "sassy", "secretive", + "selective", "sedate", "serious", "shivering", "skillful", "sincere", + "skittish", "silent", "smiling", +} + +const numRange = 1000 + +// GenerateRandomName generates random name for `run`. 
+func GenerateRandomName() (string, error) { + predicateIndex, err := rand.Int(rand.Reader, big.NewInt(int64(len(predicates)))) + if err != nil { + return "", fmt.Errorf("error getting random integer number: %w", err) + } + + nounIndex, err := rand.Int(rand.Reader, big.NewInt(int64(len(nouns)))) + if err != nil { + return "", fmt.Errorf("error getting random integer number: %w", err) + } + + num, err := rand.Int(rand.Reader, big.NewInt(numRange)) + if err != nil { + return "", fmt.Errorf("error getting random integer number: %w", err) + } + + return fmt.Sprintf("%s-%s-%d", predicates[predicateIndex.Int64()], nouns[nounIndex.Int64()], num), nil +} diff --git a/pkg/utils/path.go b/pkg/utils/path.go index 9ac35e0..39230a5 100644 --- a/pkg/utils/path.go +++ b/pkg/utils/path.go @@ -1,90 +1,90 @@ -package utils - -import ( - "errors" - "fmt" - "net/url" - "path" - "strings" -) - -var ( - errFailedToDecodeURL = errors.New("failed to decode url") - errInvalidQueryString = errors.New("invalid query string") -) - -func decode(input string) (string, error) { - current := input - - for range 10 { - decoded, err := url.QueryUnescape(current) - if err != nil { - return "", fmt.Errorf("could not unescape %s: %w", current, err) - } - - parsed, err := url.Parse(decoded) - if err != nil { - return "", fmt.Errorf("could not parsed %s: %w", decoded, err) - } - - if current == parsed.String() { - return current, nil - } - } - - return "", errFailedToDecodeURL -} - -func validateQueryString(query string) error { - query, err := decode(query) - if err != nil { - return err - } - - if strings.Contains(query, "..") { - return errInvalidQueryString - } - - return nil -} - -func joinPosixPathsAndAppendAbsoluteSuffixes(prefixPath, suffixPath string) string { - if len(prefixPath) == 0 { - return suffixPath - } - - suffixPath = strings.TrimPrefix(suffixPath, "/") - - return path.Join(prefixPath, suffixPath) -} - -func AppendToURIPath(uri string, paths ...string) (string, error) { - path := "" - 
for _, subpath := range paths { - path = joinPosixPathsAndAppendAbsoluteSuffixes(path, subpath) - } - - parsedURI, err := url.Parse(uri) - if err != nil { - return "", fmt.Errorf("failed to parse uri %s: %w", uri, err) - } - - if err := validateQueryString(parsedURI.RawQuery); err != nil { - return "", err - } - - if len(parsedURI.Scheme) == 0 { - return joinPosixPathsAndAppendAbsoluteSuffixes(uri, path), nil - } - - prefix := "" - if !strings.HasPrefix(parsedURI.Path, "/") { - prefix = parsedURI.Scheme + ":" - parsedURI.Scheme = "" - } - - newURIPath := joinPosixPathsAndAppendAbsoluteSuffixes(parsedURI.Path, path) - parsedURI.Path = newURIPath - - return prefix + parsedURI.String(), nil -} +package utils + +import ( + "errors" + "fmt" + "net/url" + "path" + "strings" +) + +var ( + errFailedToDecodeURL = errors.New("failed to decode url") + errInvalidQueryString = errors.New("invalid query string") +) + +func decode(input string) (string, error) { + current := input + + for range 10 { + decoded, err := url.QueryUnescape(current) + if err != nil { + return "", fmt.Errorf("could not unescape %s: %w", current, err) + } + + parsed, err := url.Parse(decoded) + if err != nil { + return "", fmt.Errorf("could not parsed %s: %w", decoded, err) + } + + if current == parsed.String() { + return current, nil + } + } + + return "", errFailedToDecodeURL +} + +func validateQueryString(query string) error { + query, err := decode(query) + if err != nil { + return err + } + + if strings.Contains(query, "..") { + return errInvalidQueryString + } + + return nil +} + +func joinPosixPathsAndAppendAbsoluteSuffixes(prefixPath, suffixPath string) string { + if len(prefixPath) == 0 { + return suffixPath + } + + suffixPath = strings.TrimPrefix(suffixPath, "/") + + return path.Join(prefixPath, suffixPath) +} + +func AppendToURIPath(uri string, paths ...string) (string, error) { + path := "" + for _, subpath := range paths { + path = joinPosixPathsAndAppendAbsoluteSuffixes(path, subpath) + } + 
+ parsedURI, err := url.Parse(uri) + if err != nil { + return "", fmt.Errorf("failed to parse uri %s: %w", uri, err) + } + + if err := validateQueryString(parsedURI.RawQuery); err != nil { + return "", err + } + + if len(parsedURI.Scheme) == 0 { + return joinPosixPathsAndAppendAbsoluteSuffixes(uri, path), nil + } + + prefix := "" + if !strings.HasPrefix(parsedURI.Path, "/") { + prefix = parsedURI.Scheme + ":" + parsedURI.Scheme = "" + } + + newURIPath := joinPosixPathsAndAppendAbsoluteSuffixes(parsedURI.Path, path) + parsedURI.Path = newURIPath + + return prefix + parsedURI.String(), nil +} diff --git a/pkg/utils/pointers.go b/pkg/utils/pointers.go index b26e729..cd44985 100644 --- a/pkg/utils/pointers.go +++ b/pkg/utils/pointers.go @@ -1,41 +1,41 @@ -package utils - -import ( - "strconv" -) - -func PtrTo[T any](v T) *T { - return &v -} - -func ConvertInt32PointerToStringPointer(iPtr *int32) *string { - if iPtr == nil { - return nil - } - - iValue := *iPtr - sValue := strconv.Itoa(int(iValue)) - - return &sValue -} - -func ConvertStringPointerToInt32Pointer(s *string) int32 { - if s == nil { - return 0 - } - - iValue, err := strconv.ParseInt(*s, 10, 32) - if err != nil { - return 0 - } - - return int32(iValue) -} - -func DumpStringPointer(s *string) string { - if s == nil { - return "" - } - - return *s -} +package utils + +import ( + "strconv" +) + +func PtrTo[T any](v T) *T { + return &v +} + +func ConvertInt32PointerToStringPointer(iPtr *int32) *string { + if iPtr == nil { + return nil + } + + iValue := *iPtr + sValue := strconv.Itoa(int(iValue)) + + return &sValue +} + +func ConvertStringPointerToInt32Pointer(s *string) int32 { + if s == nil { + return 0 + } + + iValue, err := strconv.ParseInt(*s, 10, 32) + if err != nil { + return 0 + } + + return int32(iValue) +} + +func DumpStringPointer(s *string) string { + if s == nil { + return "" + } + + return *s +} diff --git a/pkg/utils/strings.go b/pkg/utils/strings.go index e0b8773..0dec0e7 100644 --- 
a/pkg/utils/strings.go +++ b/pkg/utils/strings.go @@ -1,24 +1,24 @@ -package utils - -import ( - "encoding/hex" - - "github.com/google/uuid" -) - -func IsNotNilOrEmptyString(v *string) bool { - return v != nil && *v != "" -} - -func IsNilOrEmptyString(v *string) bool { - return v == nil || *v == "" -} - -func NewUUID() string { - var r [32]byte - - u := uuid.New() - hex.Encode(r[:], u[:]) - - return string(r[:]) -} +package utils + +import ( + "encoding/hex" + + "github.com/google/uuid" +) + +func IsNotNilOrEmptyString(v *string) bool { + return v != nil && *v != "" +} + +func IsNilOrEmptyString(v *string) bool { + return v == nil || *v == "" +} + +func NewUUID() string { + var r [32]byte + + u := uuid.New() + hex.Encode(r[:], u[:]) + + return string(r[:]) +} diff --git a/pkg/utils/tags.go b/pkg/utils/tags.go index 1bba1aa..98474af 100644 --- a/pkg/utils/tags.go +++ b/pkg/utils/tags.go @@ -1,6 +1,6 @@ -package utils - -const ( - TagRunName = "mlflow.runName" - TagUser = "mlflow.user" -) +package utils + +const ( + TagRunName = "mlflow.runName" + TagUser = "mlflow.user" +) diff --git a/pkg/validation/validation.go b/pkg/validation/validation.go index 6349abb..006e994 100644 --- a/pkg/validation/validation.go +++ b/pkg/validation/validation.go @@ -1,299 +1,299 @@ -package validation - -import ( - "encoding/json" - "errors" - "fmt" - "net/url" - "os" - "path/filepath" - "reflect" - "regexp" - "strconv" - "strings" - - "github.com/go-playground/validator/v10" - - "github.com/mlflow/mlflow-go/pkg/contract" - "github.com/mlflow/mlflow-go/pkg/protos" -) - -const ( - QuoteLength = 2 - MaxEntitiesPerBatch = 1000 - MaxValidationInputLength = 100 -) - -// regex for valid param and metric names: may only contain slashes, alphanumerics, -// underscores, periods, dashes, and spaces. -var paramAndMetricNameRegex = regexp.MustCompile(`^[/\w.\- ]*$`) - -// regex for valid run IDs: must be an alphanumeric string of length 1 to 256. 
-var runIDRegex = regexp.MustCompile(`^[a-zA-Z0-9][\w\-]{0,255}$`) - -func stringAsPositiveIntegerValidation(fl validator.FieldLevel) bool { - valueStr := fl.Field().String() - - value, err := strconv.Atoi(valueStr) - if err != nil { - return false - } - - return value > -1 -} - -func uriWithoutFragmentsOrParamsOrDotDotInQueryValidation(fl validator.FieldLevel) bool { - valueStr := fl.Field().String() - if valueStr == "" { - return true - } - - u, err := url.Parse(valueStr) - if err != nil { - return false - } - - return u.Fragment == "" && u.RawQuery == "" && !strings.Contains(u.RawQuery, "..") -} - -func uniqueParamsValidation(fl validator.FieldLevel) bool { - value := fl.Field() - - params, areParams := value.Interface().([]*protos.Param) - if !areParams || len(params) == 0 { - return true - } - - hasDuplicates := false - keys := make(map[string]bool, len(params)) - - for _, param := range params { - if _, ok := keys[param.GetKey()]; ok { - hasDuplicates = true - - break - } - - keys[param.GetKey()] = true - } - - return !hasDuplicates -} - -func pathIsClean(fl validator.FieldLevel) bool { - valueStr := fl.Field().String() - norm := filepath.Clean(valueStr) - - return !(norm != valueStr || norm == "." || strings.HasPrefix(norm, "..") || strings.HasPrefix(norm, "/")) -} - -func regexValidation(regex *regexp.Regexp) validator.Func { - return func(fl validator.FieldLevel) bool { - valueStr := fl.Field().String() - - return regex.MatchString(valueStr) - } -} - -// see _validate_batch_log_limits in validation.py. 
-func validateLogBatchLimits(structLevel validator.StructLevel) { - logBatch, isLogBatch := structLevel.Current().Interface().(*protos.LogBatch) - - if isLogBatch { - total := len(logBatch.GetParams()) + len(logBatch.GetMetrics()) + len(logBatch.GetTags()) - if total > MaxEntitiesPerBatch { - structLevel.ReportError(&logBatch, "metrics, params, and tags", "", "", "") - } - } -} - -func truncateFn(fieldLevel validator.FieldLevel) bool { - param := fieldLevel.Param() // Get the parameter from the tag - - maxLength, err := strconv.Atoi(param) - if err != nil { - return false // If the parameter isn't a valid integer, fail the validation. - } - - truncateLongValues, shouldTruncate := os.LookupEnv("MLFLOW_TRUNCATE_LONG_VALUES") - shouldTruncate = shouldTruncate && truncateLongValues == "true" - - field := fieldLevel.Field() - - if field.Kind() == reflect.String { - strValue := field.String() - if len(strValue) <= maxLength { - return true - } - - if shouldTruncate { - field.SetString(strValue[:maxLength]) - - return true - } - - return false - } - - return true -} - -func NewValidator() (*validator.Validate, error) { - validate := validator.New() - - validate.RegisterTagNameFunc(func(fld reflect.StructField) string { - name := strings.SplitN(fld.Tag.Get("json"), ",", QuoteLength)[0] - // skip if tag key says it should be ignored - if name == "-" { - return "" - } - - return name - }) - - // Verify that the input string is a positive integer. 
- if err := validate.RegisterValidation( - "stringAsPositiveInteger", stringAsPositiveIntegerValidation, - ); err != nil { - return nil, fmt.Errorf("validation registration for 'stringAsPositiveInteger' failed: %w", err) - } - - // Verify that the input string, if present, is a Url without fragment or query parameters - if err := validate.RegisterValidation( - "uriWithoutFragmentsOrParamsOrDotDotInQuery", uriWithoutFragmentsOrParamsOrDotDotInQueryValidation); err != nil { - return nil, fmt.Errorf("validation registration for 'uriWithoutFragmentsOrParamsOrDotDotInQuery' failed: %w", err) - } - - if err := validate.RegisterValidation( - "validMetricParamOrTagName", regexValidation(paramAndMetricNameRegex), - ); err != nil { - return nil, fmt.Errorf("validation registration for 'validMetricParamOrTagName' failed: %w", err) - } - - if err := validate.RegisterValidation("pathIsUnique", pathIsClean); err != nil { - return nil, fmt.Errorf("validation registration for 'validMetricParamOrTagValue' failed: %w", err) - } - - // unique params in LogBatch - if err := validate.RegisterValidation("uniqueParams", uniqueParamsValidation); err != nil { - return nil, fmt.Errorf("validation registration for 'uniqueParams' failed: %w", err) - } - - if err := validate.RegisterValidation("runId", regexValidation(runIDRegex)); err != nil { - return nil, fmt.Errorf("validation registration for 'runId' failed: %w", err) - } - - if err := validate.RegisterValidation("truncate", truncateFn); err != nil { - return nil, fmt.Errorf("validation registration for 'truncateFn' failed: %w", err) - } - - validate.RegisterStructValidation(validateLogBatchLimits, &protos.LogBatch{}) - - return validate, nil -} - -func dereference(value interface{}) interface{} { - valueOf := reflect.ValueOf(value) - if valueOf.Kind() == reflect.Ptr { - if valueOf.IsNil() { - return "" - } - - return valueOf.Elem().Interface() - } - - return value -} - -func getErrorPath(err validator.FieldError) string { - path := 
err.Field() - - if err.Namespace() != "" { - // Strip first item in struct namespace - idx := strings.Index(err.Namespace(), ".") - if idx != -1 { - path = err.Namespace()[(idx + 1):] - } - } - - return path -} - -func constructValidationError(field string, value any, suffix string) string { - formattedValue, err := json.Marshal(value) - if err != nil { - formattedValue = []byte(fmt.Sprintf("%v", value)) - } - - return fmt.Sprintf("Invalid value %s for parameter '%s' supplied%s", formattedValue, field, suffix) -} - -func mkTruncateValidationError(field string, value interface{}, err validator.FieldError) string { - strValue, ok := value.(string) - if ok { - expected := len(strValue) - - if expected > MaxValidationInputLength { - strValue = strValue[:MaxValidationInputLength] + "..." - } - - return constructValidationError( - field, - strValue, - fmt.Sprintf(": length %d exceeded length limit of %s", expected, err.Param()), - ) - } - - return constructValidationError(field, value, "") -} - -func mkMaxValidationError(field string, value interface{}, err validator.FieldError) string { - if _, ok := value.(string); ok { - return fmt.Sprintf( - "'%s' exceeds the maximum length of %s characters", - field, - err.Param(), - ) - } - - return constructValidationError(field, value, "") -} - -func NewErrorFromValidationError(err error) *contract.Error { - var validatorValidationErrors validator.ValidationErrors - if errors.As(err, &validatorValidationErrors) { - validationErrors := make([]string, 0) - - for _, err := range validatorValidationErrors { - field := getErrorPath(err) - tag := err.Tag() - value := dereference(err.Value()) - - switch tag { - case "required": - validationErrors = append( - validationErrors, - fmt.Sprintf("Missing value for required parameter '%s'", field), - ) - case "truncate": - validationErrors = append(validationErrors, mkTruncateValidationError(field, value, err)) - case "uniqueParams": - validationErrors = append( - validationErrors, - 
"Duplicate parameter keys have been submitted", - ) - case "max": - validationErrors = append(validationErrors, mkMaxValidationError(field, value, err)) - default: - validationErrors = append( - validationErrors, - constructValidationError(field, value, ""), - ) - } - } - - return contract.NewError(protos.ErrorCode_INVALID_PARAMETER_VALUE, strings.Join(validationErrors, ", ")) - } - - return contract.NewError(protos.ErrorCode_INTERNAL_ERROR, err.Error()) -} +package validation + +import ( + "encoding/json" + "errors" + "fmt" + "net/url" + "os" + "path/filepath" + "reflect" + "regexp" + "strconv" + "strings" + + "github.com/go-playground/validator/v10" + + "github.com/mlflow/mlflow-go/pkg/contract" + "github.com/mlflow/mlflow-go/pkg/protos" +) + +const ( + QuoteLength = 2 + MaxEntitiesPerBatch = 1000 + MaxValidationInputLength = 100 +) + +// regex for valid param and metric names: may only contain slashes, alphanumerics, +// underscores, periods, dashes, and spaces. +var paramAndMetricNameRegex = regexp.MustCompile(`^[/\w.\- ]*$`) + +// regex for valid run IDs: must be an alphanumeric string of length 1 to 256. 
+var runIDRegex = regexp.MustCompile(`^[a-zA-Z0-9][\w\-]{0,255}$`) + +func stringAsPositiveIntegerValidation(fl validator.FieldLevel) bool { + valueStr := fl.Field().String() + + value, err := strconv.Atoi(valueStr) + if err != nil { + return false + } + + return value > -1 +} + +func uriWithoutFragmentsOrParamsOrDotDotInQueryValidation(fl validator.FieldLevel) bool { + valueStr := fl.Field().String() + if valueStr == "" { + return true + } + + u, err := url.Parse(valueStr) + if err != nil { + return false + } + + return u.Fragment == "" && u.RawQuery == "" && !strings.Contains(u.RawQuery, "..") +} + +func uniqueParamsValidation(fl validator.FieldLevel) bool { + value := fl.Field() + + params, areParams := value.Interface().([]*protos.Param) + if !areParams || len(params) == 0 { + return true + } + + hasDuplicates := false + keys := make(map[string]bool, len(params)) + + for _, param := range params { + if _, ok := keys[param.GetKey()]; ok { + hasDuplicates = true + + break + } + + keys[param.GetKey()] = true + } + + return !hasDuplicates +} + +func pathIsClean(fl validator.FieldLevel) bool { + valueStr := fl.Field().String() + norm := filepath.Clean(valueStr) + + return !(norm != valueStr || norm == "." || strings.HasPrefix(norm, "..") || strings.HasPrefix(norm, "/")) +} + +func regexValidation(regex *regexp.Regexp) validator.Func { + return func(fl validator.FieldLevel) bool { + valueStr := fl.Field().String() + + return regex.MatchString(valueStr) + } +} + +// see _validate_batch_log_limits in validation.py. 
+func validateLogBatchLimits(structLevel validator.StructLevel) { + logBatch, isLogBatch := structLevel.Current().Interface().(*protos.LogBatch) + + if isLogBatch { + total := len(logBatch.GetParams()) + len(logBatch.GetMetrics()) + len(logBatch.GetTags()) + if total > MaxEntitiesPerBatch { + structLevel.ReportError(&logBatch, "metrics, params, and tags", "", "", "") + } + } +} + +func truncateFn(fieldLevel validator.FieldLevel) bool { + param := fieldLevel.Param() // Get the parameter from the tag + + maxLength, err := strconv.Atoi(param) + if err != nil { + return false // If the parameter isn't a valid integer, fail the validation. + } + + truncateLongValues, shouldTruncate := os.LookupEnv("MLFLOW_TRUNCATE_LONG_VALUES") + shouldTruncate = shouldTruncate && truncateLongValues == "true" + + field := fieldLevel.Field() + + if field.Kind() == reflect.String { + strValue := field.String() + if len(strValue) <= maxLength { + return true + } + + if shouldTruncate { + field.SetString(strValue[:maxLength]) + + return true + } + + return false + } + + return true +} + +func NewValidator() (*validator.Validate, error) { + validate := validator.New() + + validate.RegisterTagNameFunc(func(fld reflect.StructField) string { + name := strings.SplitN(fld.Tag.Get("json"), ",", QuoteLength)[0] + // skip if tag key says it should be ignored + if name == "-" { + return "" + } + + return name + }) + + // Verify that the input string is a positive integer. 
+ if err := validate.RegisterValidation( + "stringAsPositiveInteger", stringAsPositiveIntegerValidation, + ); err != nil { + return nil, fmt.Errorf("validation registration for 'stringAsPositiveInteger' failed: %w", err) + } + + // Verify that the input string, if present, is a Url without fragment or query parameters + if err := validate.RegisterValidation( + "uriWithoutFragmentsOrParamsOrDotDotInQuery", uriWithoutFragmentsOrParamsOrDotDotInQueryValidation); err != nil { + return nil, fmt.Errorf("validation registration for 'uriWithoutFragmentsOrParamsOrDotDotInQuery' failed: %w", err) + } + + if err := validate.RegisterValidation( + "validMetricParamOrTagName", regexValidation(paramAndMetricNameRegex), + ); err != nil { + return nil, fmt.Errorf("validation registration for 'validMetricParamOrTagName' failed: %w", err) + } + + if err := validate.RegisterValidation("pathIsUnique", pathIsClean); err != nil { + return nil, fmt.Errorf("validation registration for 'validMetricParamOrTagValue' failed: %w", err) + } + + // unique params in LogBatch + if err := validate.RegisterValidation("uniqueParams", uniqueParamsValidation); err != nil { + return nil, fmt.Errorf("validation registration for 'uniqueParams' failed: %w", err) + } + + if err := validate.RegisterValidation("runId", regexValidation(runIDRegex)); err != nil { + return nil, fmt.Errorf("validation registration for 'runId' failed: %w", err) + } + + if err := validate.RegisterValidation("truncate", truncateFn); err != nil { + return nil, fmt.Errorf("validation registration for 'truncateFn' failed: %w", err) + } + + validate.RegisterStructValidation(validateLogBatchLimits, &protos.LogBatch{}) + + return validate, nil +} + +func dereference(value interface{}) interface{} { + valueOf := reflect.ValueOf(value) + if valueOf.Kind() == reflect.Ptr { + if valueOf.IsNil() { + return "" + } + + return valueOf.Elem().Interface() + } + + return value +} + +func getErrorPath(err validator.FieldError) string { + path := 
err.Field() + + if err.Namespace() != "" { + // Strip first item in struct namespace + idx := strings.Index(err.Namespace(), ".") + if idx != -1 { + path = err.Namespace()[(idx + 1):] + } + } + + return path +} + +func constructValidationError(field string, value any, suffix string) string { + formattedValue, err := json.Marshal(value) + if err != nil { + formattedValue = []byte(fmt.Sprintf("%v", value)) + } + + return fmt.Sprintf("Invalid value %s for parameter '%s' supplied%s", formattedValue, field, suffix) +} + +func mkTruncateValidationError(field string, value interface{}, err validator.FieldError) string { + strValue, ok := value.(string) + if ok { + expected := len(strValue) + + if expected > MaxValidationInputLength { + strValue = strValue[:MaxValidationInputLength] + "..." + } + + return constructValidationError( + field, + strValue, + fmt.Sprintf(": length %d exceeded length limit of %s", expected, err.Param()), + ) + } + + return constructValidationError(field, value, "") +} + +func mkMaxValidationError(field string, value interface{}, err validator.FieldError) string { + if _, ok := value.(string); ok { + return fmt.Sprintf( + "'%s' exceeds the maximum length of %s characters", + field, + err.Param(), + ) + } + + return constructValidationError(field, value, "") +} + +func NewErrorFromValidationError(err error) *contract.Error { + var validatorValidationErrors validator.ValidationErrors + if errors.As(err, &validatorValidationErrors) { + validationErrors := make([]string, 0) + + for _, err := range validatorValidationErrors { + field := getErrorPath(err) + tag := err.Tag() + value := dereference(err.Value()) + + switch tag { + case "required": + validationErrors = append( + validationErrors, + fmt.Sprintf("Missing value for required parameter '%s'", field), + ) + case "truncate": + validationErrors = append(validationErrors, mkTruncateValidationError(field, value, err)) + case "uniqueParams": + validationErrors = append( + validationErrors, + 
"Duplicate parameter keys have been submitted", + ) + case "max": + validationErrors = append(validationErrors, mkMaxValidationError(field, value, err)) + default: + validationErrors = append( + validationErrors, + constructValidationError(field, value, ""), + ) + } + } + + return contract.NewError(protos.ErrorCode_INVALID_PARAMETER_VALUE, strings.Join(validationErrors, ", ")) + } + + return contract.NewError(protos.ErrorCode_INTERNAL_ERROR, err.Error()) +} diff --git a/pkg/validation/validation_test.go b/pkg/validation/validation_test.go index ea480ff..52a5db7 100644 --- a/pkg/validation/validation_test.go +++ b/pkg/validation/validation_test.go @@ -1,244 +1,244 @@ -package validation_test - -import ( - "strings" - "testing" - - "github.com/stretchr/testify/require" - - "github.com/mlflow/mlflow-go/pkg/protos" - "github.com/mlflow/mlflow-go/pkg/utils" - "github.com/mlflow/mlflow-go/pkg/validation" -) - -type PositiveInteger struct { - Value string `validate:"stringAsPositiveInteger"` -} - -type validationScenario struct { - name string - input any - shouldTrigger bool -} - -func runscenarios(t *testing.T, scenarios []validationScenario) { - t.Helper() - - validator, err := validation.NewValidator() - require.NoError(t, err) - - for _, scenario := range scenarios { - currentScenario := scenario - t.Run(currentScenario.name, func(t *testing.T) { - t.Parallel() - - errs := validator.Struct(currentScenario.input) - - if currentScenario.shouldTrigger && errs == nil { - t.Errorf("Expected validation error, got nil") - } - - if !currentScenario.shouldTrigger && errs != nil { - t.Errorf("Expected no validation error, got %v", errs) - } - }) - } -} - -func TestStringAsPositiveInteger(t *testing.T) { - t.Parallel() - - scenarios := []validationScenario{ - { - name: "positive integer", - input: PositiveInteger{Value: "1"}, - shouldTrigger: false, - }, - { - name: "zero", - input: PositiveInteger{Value: "0"}, - shouldTrigger: false, - }, - { - name: "negative integer", - 
input: PositiveInteger{Value: "-1"}, - shouldTrigger: true, - }, - { - name: "alphabet", - input: PositiveInteger{Value: "a"}, - shouldTrigger: true, - }, - } - - runscenarios(t, scenarios) -} - -type uriWithoutFragmentsOrParams struct { - Value string `validate:"uriWithoutFragmentsOrParamsOrDotDotInQuery"` -} - -func TestUriWithoutFragmentsOrParams(t *testing.T) { - t.Parallel() - - scenarios := []validationScenario{ - { - name: "valid url", - input: uriWithoutFragmentsOrParams{Value: "http://example.com"}, - shouldTrigger: false, - }, - { - name: "only trigger when url is not empty", - input: uriWithoutFragmentsOrParams{Value: ""}, - shouldTrigger: false, - }, - { - name: "url with fragment", - input: uriWithoutFragmentsOrParams{Value: "http://example.com#fragment"}, - shouldTrigger: true, - }, - { - name: "url with query parameters", - input: uriWithoutFragmentsOrParams{Value: "http://example.com?query=param"}, - shouldTrigger: true, - }, - { - name: "unparsable url", - input: uriWithoutFragmentsOrParams{Value: ":invalid-url"}, - shouldTrigger: true, - }, - { - name: ".. 
in query", - input: uriWithoutFragmentsOrParams{Value: "http://example.com?query=./.."}, - shouldTrigger: true, - }, - } - - runscenarios(t, scenarios) -} - -func TestUniqueParamsInLogBatch(t *testing.T) { - t.Parallel() - - logBatchRequest := &protos.LogBatch{ - Params: []*protos.Param{ - {Key: utils.PtrTo("key1"), Value: utils.PtrTo("value1")}, - {Key: utils.PtrTo("key1"), Value: utils.PtrTo("value2")}, - }, - } - - validator, err := validation.NewValidator() - require.NoError(t, err) - - err = validator.Struct(logBatchRequest) - if err == nil { - t.Error("Expected uniqueParams validation error, got none") - } -} - -func TestEmptyParamsInLogBatch(t *testing.T) { - t.Parallel() - - logBatchRequest := &protos.LogBatch{ - RunId: utils.PtrTo("odcppTsGTMkHeDcqfZOYDMZSf"), - Params: make([]*protos.Param, 0), - } - - validator, err := validation.NewValidator() - require.NoError(t, err) - - err = validator.Struct(logBatchRequest) - if err != nil { - t.Errorf("Unexpected uniqueParams validation error, got %v", err) - } -} - -func TestMissingTimestampInNestedMetric(t *testing.T) { - t.Parallel() - - serverValidator, err := validation.NewValidator() - require.NoError(t, err) - - logBatch := protos.LogBatch{ - RunId: utils.PtrTo("odcppTsGTMkHeDcqfZOYDMZSf"), - Metrics: []*protos.Metric{ - { - Key: utils.PtrTo("mae"), - Value: utils.PtrTo(2.5), - }, - }, - } - - err = serverValidator.Struct(&logBatch) - if err == nil { - t.Error("Expected dive validation error, got none") - } - - msg := validation.NewErrorFromValidationError(err).Message - if !strings.Contains(msg, "metrics[0].timestamp") { - t.Errorf("Expected required validation error for nested property, got %v", msg) - } -} - -type avecTruncate struct { - X *string `validate:"truncate=3"` - Y string `validate:"truncate=3"` -} - -func TestTruncate(t *testing.T) { - input := &avecTruncate{ - X: utils.PtrTo("123456"), - Y: "654321", - } - - t.Setenv("MLFLOW_TRUNCATE_LONG_VALUES", "true") - - validator, err := 
validation.NewValidator() - require.NoError(t, err) - - err = validator.Struct(input) - require.NoError(t, err) - - if len(*input.X) != 3 { - t.Errorf("Expected the length of x to be 3, was %d", len(*input.X)) - } - - if len(input.Y) != 3 { - t.Errorf("Expected the length of y to be 3, was %d", len(input.Y)) - } -} - -// This unit test is a sanity test that confirms the `dive` validation -// enters a nested slice of pointer structs. -func TestNestedErrorsInSubCollection(t *testing.T) { - t.Parallel() - - value := strings.Repeat("X", 6001) + "Y" - - logBatchRequest := &protos.LogBatch{ - RunId: utils.PtrTo("odcppTsGTMkHeDcqfZOYDMZSf"), - Params: []*protos.Param{ - {Key: utils.PtrTo("key1"), Value: utils.PtrTo(value)}, - {Key: utils.PtrTo("key2"), Value: utils.PtrTo(value)}, - }, - } - - validator, err := validation.NewValidator() - require.NoError(t, err) - - err = validator.Struct(logBatchRequest) - if err != nil { - msg := validation.NewErrorFromValidationError(err).Message - // Assert the root struct name is not present in the error message - if strings.Contains(msg, "logBatch") { - t.Errorf("Validation message contained root struct name, got %s", msg) - } - - // Assert the index is listed in the parameter path - if !strings.Contains(msg, "params[0].value") || - !strings.Contains(msg, "params[1].value") || - !strings.Contains(msg, "length 6002 exceeded length limit of 6000") { - t.Errorf("Unexpected validation error message, got %s", msg) - } - } -} +package validation_test + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/utils" + "github.com/mlflow/mlflow-go/pkg/validation" +) + +type PositiveInteger struct { + Value string `validate:"stringAsPositiveInteger"` +} + +type validationScenario struct { + name string + input any + shouldTrigger bool +} + +func runscenarios(t *testing.T, scenarios []validationScenario) { + t.Helper() + + validator, err := 
validation.NewValidator() + require.NoError(t, err) + + for _, scenario := range scenarios { + currentScenario := scenario + t.Run(currentScenario.name, func(t *testing.T) { + t.Parallel() + + errs := validator.Struct(currentScenario.input) + + if currentScenario.shouldTrigger && errs == nil { + t.Errorf("Expected validation error, got nil") + } + + if !currentScenario.shouldTrigger && errs != nil { + t.Errorf("Expected no validation error, got %v", errs) + } + }) + } +} + +func TestStringAsPositiveInteger(t *testing.T) { + t.Parallel() + + scenarios := []validationScenario{ + { + name: "positive integer", + input: PositiveInteger{Value: "1"}, + shouldTrigger: false, + }, + { + name: "zero", + input: PositiveInteger{Value: "0"}, + shouldTrigger: false, + }, + { + name: "negative integer", + input: PositiveInteger{Value: "-1"}, + shouldTrigger: true, + }, + { + name: "alphabet", + input: PositiveInteger{Value: "a"}, + shouldTrigger: true, + }, + } + + runscenarios(t, scenarios) +} + +type uriWithoutFragmentsOrParams struct { + Value string `validate:"uriWithoutFragmentsOrParamsOrDotDotInQuery"` +} + +func TestUriWithoutFragmentsOrParams(t *testing.T) { + t.Parallel() + + scenarios := []validationScenario{ + { + name: "valid url", + input: uriWithoutFragmentsOrParams{Value: "http://example.com"}, + shouldTrigger: false, + }, + { + name: "only trigger when url is not empty", + input: uriWithoutFragmentsOrParams{Value: ""}, + shouldTrigger: false, + }, + { + name: "url with fragment", + input: uriWithoutFragmentsOrParams{Value: "http://example.com#fragment"}, + shouldTrigger: true, + }, + { + name: "url with query parameters", + input: uriWithoutFragmentsOrParams{Value: "http://example.com?query=param"}, + shouldTrigger: true, + }, + { + name: "unparsable url", + input: uriWithoutFragmentsOrParams{Value: ":invalid-url"}, + shouldTrigger: true, + }, + { + name: ".. 
in query", + input: uriWithoutFragmentsOrParams{Value: "http://example.com?query=./.."}, + shouldTrigger: true, + }, + } + + runscenarios(t, scenarios) +} + +func TestUniqueParamsInLogBatch(t *testing.T) { + t.Parallel() + + logBatchRequest := &protos.LogBatch{ + Params: []*protos.Param{ + {Key: utils.PtrTo("key1"), Value: utils.PtrTo("value1")}, + {Key: utils.PtrTo("key1"), Value: utils.PtrTo("value2")}, + }, + } + + validator, err := validation.NewValidator() + require.NoError(t, err) + + err = validator.Struct(logBatchRequest) + if err == nil { + t.Error("Expected uniqueParams validation error, got none") + } +} + +func TestEmptyParamsInLogBatch(t *testing.T) { + t.Parallel() + + logBatchRequest := &protos.LogBatch{ + RunId: utils.PtrTo("odcppTsGTMkHeDcqfZOYDMZSf"), + Params: make([]*protos.Param, 0), + } + + validator, err := validation.NewValidator() + require.NoError(t, err) + + err = validator.Struct(logBatchRequest) + if err != nil { + t.Errorf("Unexpected uniqueParams validation error, got %v", err) + } +} + +func TestMissingTimestampInNestedMetric(t *testing.T) { + t.Parallel() + + serverValidator, err := validation.NewValidator() + require.NoError(t, err) + + logBatch := protos.LogBatch{ + RunId: utils.PtrTo("odcppTsGTMkHeDcqfZOYDMZSf"), + Metrics: []*protos.Metric{ + { + Key: utils.PtrTo("mae"), + Value: utils.PtrTo(2.5), + }, + }, + } + + err = serverValidator.Struct(&logBatch) + if err == nil { + t.Error("Expected dive validation error, got none") + } + + msg := validation.NewErrorFromValidationError(err).Message + if !strings.Contains(msg, "metrics[0].timestamp") { + t.Errorf("Expected required validation error for nested property, got %v", msg) + } +} + +type avecTruncate struct { + X *string `validate:"truncate=3"` + Y string `validate:"truncate=3"` +} + +func TestTruncate(t *testing.T) { + input := &avecTruncate{ + X: utils.PtrTo("123456"), + Y: "654321", + } + + t.Setenv("MLFLOW_TRUNCATE_LONG_VALUES", "true") + + validator, err := 
validation.NewValidator() + require.NoError(t, err) + + err = validator.Struct(input) + require.NoError(t, err) + + if len(*input.X) != 3 { + t.Errorf("Expected the length of x to be 3, was %d", len(*input.X)) + } + + if len(input.Y) != 3 { + t.Errorf("Expected the length of y to be 3, was %d", len(input.Y)) + } +} + +// This unit test is a sanity test that confirms the `dive` validation +// enters a nested slice of pointer structs. +func TestNestedErrorsInSubCollection(t *testing.T) { + t.Parallel() + + value := strings.Repeat("X", 6001) + "Y" + + logBatchRequest := &protos.LogBatch{ + RunId: utils.PtrTo("odcppTsGTMkHeDcqfZOYDMZSf"), + Params: []*protos.Param{ + {Key: utils.PtrTo("key1"), Value: utils.PtrTo(value)}, + {Key: utils.PtrTo("key2"), Value: utils.PtrTo(value)}, + }, + } + + validator, err := validation.NewValidator() + require.NoError(t, err) + + err = validator.Struct(logBatchRequest) + if err != nil { + msg := validation.NewErrorFromValidationError(err).Message + // Assert the root struct name is not present in the error message + if strings.Contains(msg, "logBatch") { + t.Errorf("Validation message contained root struct name, got %s", msg) + } + + // Assert the index is listed in the parameter path + if !strings.Contains(msg, "params[0].value") || + !strings.Contains(msg, "params[1].value") || + !strings.Contains(msg, "length 6002 exceeded length limit of 6000") { + t.Errorf("Unexpected validation error message, got %s", msg) + } + } +} diff --git a/pyproject.toml b/pyproject.toml index b4cf50f..1e40bae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,173 +1,173 @@ -[build-system] -requires = ["setuptools", "wheel"] -build-backend = "setuptools.build_meta" - -[project] -name = "mlflow-go" -version = "2.14.1" -description = "MLflow is an open source platform for the complete machine learning lifecycle" -readme = "README.md" -keywords = ["mlflow", "ai", "databricks"] -classifiers = [ - "Development Status :: 5 - Production/Stable", - "Intended 
Audience :: Developers", - "Intended Audience :: End Users/Desktop", - "Intended Audience :: Science/Research", - "Intended Audience :: Information Technology", - "Topic :: Scientific/Engineering :: Artificial Intelligence", - "Topic :: Software Development :: Libraries :: Python Modules", - "License :: OSI Approved :: Apache Software License", - "Operating System :: OS Independent", - "Programming Language :: Python :: 3.8", -] -requires-python = ">=3.8" -dependencies = ["mlflow==2.14.1", "cffi"] -license = { file = "LICENSE.txt" } - -[[project.maintainers]] -name = "Databricks" -email = "mlflow-oss-maintainers@googlegroups.com" - -[project.urls] -homepage = "https://mlflow.org" -issues = "https://github.com/mlflow/mlflow-go/issues" -documentation = "https://mlflow.org/docs/latest/index.html" -repository = "https://github.com/mlflow/mlflow-go" - -[project.scripts] -mlflow-go = "mlflow_go.cli:cli" - -[project.entry-points."mlflow.tracking_store"] -mssql = "mlflow_go.store.tracking:_get_sqlalchemy_store" -mysql = "mlflow_go.store.tracking:_get_sqlalchemy_store" -postgresql = "mlflow_go.store.tracking:_get_sqlalchemy_store" -sqlite = "mlflow_go.store.tracking:_get_sqlalchemy_store" - -[project.entry-points."mlflow.model_registry_store"] -mssql = "mlflow_go.store.model_registry:_get_sqlalchemy_store" -mysql = "mlflow_go.store.model_registry:_get_sqlalchemy_store" -postgresql = "mlflow_go.store.model_registry:_get_sqlalchemy_store" -sqlite = "mlflow_go.store.model_registry:_get_sqlalchemy_store" - -[tool.setuptools.packages.find] -where = ["."] -include = ["mlflow_go", "mlflow_go.*"] -exclude = ["tests", "tests.*"] - -[tool.ruff] -line-length = 100 -target-version = "py38" -force-exclude = true -extend-include = ["*.ipynb"] -extend-exclude = [ - "examples/recipes", - "mlflow/protos", - "mlflow/ml_package_versions.py", - "mlflow/server/graphql/autogenerated_graphql_schema.py", - "mlflow/server/js", - "mlflow/store/db_migrations", - "tests/protos", -] - 
-[tool.ruff.format] -docstring-code-format = true -docstring-code-line-length = 88 - -[tool.ruff.lint] -dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" -select = [ - "B006", # multiple-argument-default - "B015", # useless-comparison - "D209", # new-line-after-last-paragraph - "D411", # no-blank-line-before-section - "E", # error - "F", # Pyflakes - "C4", # flake8-comprehensions - "I", # isort - "ISC001", # single-line-implicit-string-concatenation - "PIE790", # unnecessary-placeholder - "PLR0402", # manual-from-import - "PLE1205", # logging-too-many-args - "PT001", # pytest-fixture-incorrect-parentheses-style - "PT002", # pytest-fixture-positional-args - "PT003", # pytest-extraneous-scope-function - "PT006", # pytest-parameterize-names-wrong-type - "PT007", # pytest-parameterize-values-wrong-type - "PT009", # pytest-unittest-assertion - "PT010", # pytest-raises-without-exception - "PT011", # pytest-raises-too-broad - "PT012", # pytest-raises-with-multiple-statements - "PT013", # pytest-incorrect-pytest-import - "PT014", # pytest-duplicate-parametrize-test-cases - "PT018", # pytest-composite-assertion - "PT022", # pytest-useless-yield-fixture - "PT023", # pytest-incorrect-mark-parentheses-style - "PT026", # pytest-use-fixtures-without-parameters - "PT027", # pytest-unittest-raises-assertion - "RET504", # unnecessary-assign - "RUF010", # explicit-f-string-type-conversion - "RUF013", # implicit-optional - "RUF100", # unused-noqa - "S307", # suspicious-eval-usage - "S324", # hashlib-insecure-hash-function - "SIM101", # duplicate-isinstance-call - "SIM103", # needless-bool - "SIM108", # if-else-block-instead-of-if-exp - "SIM114", # if-with-same-arms - "SIM115", # open-file-with-context-handler - "SIM210", # if-expr-with-true-false - "SIM910", # dict-get-with-none-default - "T20", # flake8-print - "TID251", # banned-api - "TID252", # relative-improt - "TRY302", # useless-try-except - "UP004", # useless-object-inheritance - "UP008", # 
super-call-with-parameters - "UP011", # lru-cache-without-parameters - "UP012", # unecessary-encode-utf8 - "UP015", # redundant-open-modes - "UP030", # format-literals - "UP031", # printf-string-format - "UP032", # f-string - "UP034", # extraneous-parenthesis - "W", # warning -] -ignore = [ - "E402", # module-import-not-at-top-of-file - "E721", # type-comparison - "E741", # ambiguous-variable-name - "F811", # redefined-while-unused -] - -[tool.ruff.lint.per-file-ignores] -"dev/*" = ["T201", "PT018"] -"examples/*" = ["T20", "RET504", "E501"] -"docs/*" = ["T20", "RET504", "E501"] -"mlflow/*" = ["PT018"] - -[tool.ruff.lint.flake8-pytest-style] -mark-parentheses = false -fixture-parentheses = false -raises-require-match-for = ["*"] - -[tool.ruff.lint.flake8-tidy-imports] -ban-relative-imports = "all" - -[tool.ruff.lint.isort] -forced-separate = ["tests"] - -[tool.ruff.lint.flake8-tidy-imports.banned-api] -"pkg_resources".msg = "We're migrating away from pkg_resources. Please use importlib.resources or importlib.metadata instead." 
- -[tool.ruff.lint.pydocstyle] -convention = "google" - -[tool.clint] -exclude = [ - "docs", - "mlflow/protos", - "mlflow/ml_package_versions.py", - "mlflow/server/js", - "mlflow/store/db_migrations", - "tests/protos", -] +[build-system] +requires = ["setuptools", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "mlflow-go" +version = "2.14.1" +description = "MLflow is an open source platform for the complete machine learning lifecycle" +readme = "README.md" +keywords = ["mlflow", "ai", "databricks"] +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "Intended Audience :: End Users/Desktop", + "Intended Audience :: Science/Research", + "Intended Audience :: Information Technology", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development :: Libraries :: Python Modules", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3.8", +] +requires-python = ">=3.8" +dependencies = ["mlflow==2.14.1", "cffi"] +license = { file = "LICENSE.txt" } + +[[project.maintainers]] +name = "Databricks" +email = "mlflow-oss-maintainers@googlegroups.com" + +[project.urls] +homepage = "https://mlflow.org" +issues = "https://github.com/mlflow/mlflow-go/issues" +documentation = "https://mlflow.org/docs/latest/index.html" +repository = "https://github.com/mlflow/mlflow-go" + +[project.scripts] +mlflow-go = "mlflow_go.cli:cli" + +[project.entry-points."mlflow.tracking_store"] +mssql = "mlflow_go.store.tracking:_get_sqlalchemy_store" +mysql = "mlflow_go.store.tracking:_get_sqlalchemy_store" +postgresql = "mlflow_go.store.tracking:_get_sqlalchemy_store" +sqlite = "mlflow_go.store.tracking:_get_sqlalchemy_store" + +[project.entry-points."mlflow.model_registry_store"] +mssql = "mlflow_go.store.model_registry:_get_sqlalchemy_store" +mysql = "mlflow_go.store.model_registry:_get_sqlalchemy_store" 
+postgresql = "mlflow_go.store.model_registry:_get_sqlalchemy_store" +sqlite = "mlflow_go.store.model_registry:_get_sqlalchemy_store" + +[tool.setuptools.packages.find] +where = ["."] +include = ["mlflow_go", "mlflow_go.*"] +exclude = ["tests", "tests.*"] + +[tool.ruff] +line-length = 100 +target-version = "py38" +force-exclude = true +extend-include = ["*.ipynb"] +extend-exclude = [ + "examples/recipes", + "mlflow/protos", + "mlflow/ml_package_versions.py", + "mlflow/server/graphql/autogenerated_graphql_schema.py", + "mlflow/server/js", + "mlflow/store/db_migrations", + "tests/protos", +] + +[tool.ruff.format] +docstring-code-format = true +docstring-code-line-length = 88 + +[tool.ruff.lint] +dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" +select = [ + "B006", # multiple-argument-default + "B015", # useless-comparison + "D209", # new-line-after-last-paragraph + "D411", # no-blank-line-before-section + "E", # error + "F", # Pyflakes + "C4", # flake8-comprehensions + "I", # isort + "ISC001", # single-line-implicit-string-concatenation + "PIE790", # unnecessary-placeholder + "PLR0402", # manual-from-import + "PLE1205", # logging-too-many-args + "PT001", # pytest-fixture-incorrect-parentheses-style + "PT002", # pytest-fixture-positional-args + "PT003", # pytest-extraneous-scope-function + "PT006", # pytest-parameterize-names-wrong-type + "PT007", # pytest-parameterize-values-wrong-type + "PT009", # pytest-unittest-assertion + "PT010", # pytest-raises-without-exception + "PT011", # pytest-raises-too-broad + "PT012", # pytest-raises-with-multiple-statements + "PT013", # pytest-incorrect-pytest-import + "PT014", # pytest-duplicate-parametrize-test-cases + "PT018", # pytest-composite-assertion + "PT022", # pytest-useless-yield-fixture + "PT023", # pytest-incorrect-mark-parentheses-style + "PT026", # pytest-use-fixtures-without-parameters + "PT027", # pytest-unittest-raises-assertion + "RET504", # unnecessary-assign + "RUF010", # 
explicit-f-string-type-conversion + "RUF013", # implicit-optional + "RUF100", # unused-noqa + "S307", # suspicious-eval-usage + "S324", # hashlib-insecure-hash-function + "SIM101", # duplicate-isinstance-call + "SIM103", # needless-bool + "SIM108", # if-else-block-instead-of-if-exp + "SIM114", # if-with-same-arms + "SIM115", # open-file-with-context-handler + "SIM210", # if-expr-with-true-false + "SIM910", # dict-get-with-none-default + "T20", # flake8-print + "TID251", # banned-api + "TID252", # relative-improt + "TRY302", # useless-try-except + "UP004", # useless-object-inheritance + "UP008", # super-call-with-parameters + "UP011", # lru-cache-without-parameters + "UP012", # unecessary-encode-utf8 + "UP015", # redundant-open-modes + "UP030", # format-literals + "UP031", # printf-string-format + "UP032", # f-string + "UP034", # extraneous-parenthesis + "W", # warning +] +ignore = [ + "E402", # module-import-not-at-top-of-file + "E721", # type-comparison + "E741", # ambiguous-variable-name + "F811", # redefined-while-unused +] + +[tool.ruff.lint.per-file-ignores] +"dev/*" = ["T201", "PT018"] +"examples/*" = ["T20", "RET504", "E501"] +"docs/*" = ["T20", "RET504", "E501"] +"mlflow/*" = ["PT018"] + +[tool.ruff.lint.flake8-pytest-style] +mark-parentheses = false +fixture-parentheses = false +raises-require-match-for = ["*"] + +[tool.ruff.lint.flake8-tidy-imports] +ban-relative-imports = "all" + +[tool.ruff.lint.isort] +forced-separate = ["tests"] + +[tool.ruff.lint.flake8-tidy-imports.banned-api] +"pkg_resources".msg = "We're migrating away from pkg_resources. Please use importlib.resources or importlib.metadata instead." 
+ +[tool.ruff.lint.pydocstyle] +convention = "google" + +[tool.clint] +exclude = [ + "docs", + "mlflow/protos", + "mlflow/ml_package_versions.py", + "mlflow/server/js", + "mlflow/store/db_migrations", + "tests/protos", +] diff --git a/setup.py b/setup.py index 74688c8..f209bc5 100644 --- a/setup.py +++ b/setup.py @@ -1,66 +1,66 @@ -import os -import pathlib -import sys -from glob import glob -from typing import List, Tuple - -from setuptools import Distribution, setup - -sys.path.insert(0, pathlib.Path(__file__).parent.joinpath("mlflow_go").as_posix()) -from lib import build_lib - - -def _prune_go_files(path: str): - for root, dirnames, filenames in os.walk(path, topdown=False): - for filename in filenames: - if filename.endswith(".go"): - os.unlink(os.path.join(root, filename)) - for dirname in dirnames: - try: - os.rmdir(os.path.join(root, dirname)) - except OSError: - pass - - -def finalize_distribution_options(dist: Distribution) -> None: - dist.has_ext_modules = lambda: True - - # this allows us to set the tag for the wheel without the python version - bdist_wheel_base_class = dist.get_command_class("bdist_wheel") - - class bdist_wheel_go(bdist_wheel_base_class): - def get_tag(self) -> Tuple[str, str, str]: - _, _, plat = super().get_tag() - return "py3", "none", plat - - dist.cmdclass["bdist_wheel"] = bdist_wheel_go - - # this allows us to build the go binary and add the Go source files to the sdist - build_base_class = dist.get_command_class("build") - - class build_go(build_base_class): - def initialize_options(self) -> None: - self.editable_mode = False - self.build_lib = None - - def finalize_options(self) -> None: - self.set_undefined_options("build_py", ("build_lib", "build_lib")) - - def run(self) -> None: - if not self.editable_mode: - _prune_go_files(self.build_lib) - build_lib( - pathlib.Path("."), - pathlib.Path(self.build_lib).joinpath("mlflow_go"), - ) - - def get_source_files(self) -> List[str]: - return ["go.mod", "go.sum", *glob("pkg/**/*.go", 
recursive=True)] - - dist.cmdclass["build_go"] = build_go - build_base_class.sub_commands.append(("build_go", None)) - - -Distribution.finalize_options = finalize_distribution_options - -setup() +import os +import pathlib +import sys +from glob import glob +from typing import List, Tuple + +from setuptools import Distribution, setup + +sys.path.insert(0, pathlib.Path(__file__).parent.joinpath("mlflow_go").as_posix()) +from lib import build_lib + + +def _prune_go_files(path: str): + for root, dirnames, filenames in os.walk(path, topdown=False): + for filename in filenames: + if filename.endswith(".go"): + os.unlink(os.path.join(root, filename)) + for dirname in dirnames: + try: + os.rmdir(os.path.join(root, dirname)) + except OSError: + pass + + +def finalize_distribution_options(dist: Distribution) -> None: + dist.has_ext_modules = lambda: True + + # this allows us to set the tag for the wheel without the python version + bdist_wheel_base_class = dist.get_command_class("bdist_wheel") + + class bdist_wheel_go(bdist_wheel_base_class): + def get_tag(self) -> Tuple[str, str, str]: + _, _, plat = super().get_tag() + return "py3", "none", plat + + dist.cmdclass["bdist_wheel"] = bdist_wheel_go + + # this allows us to build the go binary and add the Go source files to the sdist + build_base_class = dist.get_command_class("build") + + class build_go(build_base_class): + def initialize_options(self) -> None: + self.editable_mode = False + self.build_lib = None + + def finalize_options(self) -> None: + self.set_undefined_options("build_py", ("build_lib", "build_lib")) + + def run(self) -> None: + if not self.editable_mode: + _prune_go_files(self.build_lib) + build_lib( + pathlib.Path("."), + pathlib.Path(self.build_lib).joinpath("mlflow_go"), + ) + + def get_source_files(self) -> List[str]: + return ["go.mod", "go.sum", *glob("pkg/**/*.go", recursive=True)] + + dist.cmdclass["build_go"] = build_go + build_base_class.sub_commands.append(("build_go", None)) + + 
+Distribution.finalize_options = finalize_distribution_options + +setup() diff --git a/tests/override_model_registry_store.py b/tests/override_model_registry_store.py index d1678fb..e711e22 100644 --- a/tests/override_model_registry_store.py +++ b/tests/override_model_registry_store.py @@ -1,5 +1,5 @@ -from mlflow.store.model_registry.sqlalchemy_store import SqlAlchemyStore - -from mlflow_go.store.model_registry import ModelRegistryStore - -SqlAlchemyStore = ModelRegistryStore(SqlAlchemyStore) +from mlflow.store.model_registry.sqlalchemy_store import SqlAlchemyStore + +from mlflow_go.store.model_registry import ModelRegistryStore + +SqlAlchemyStore = ModelRegistryStore(SqlAlchemyStore) diff --git a/tests/override_server.py b/tests/override_server.py index a529d9c..6fd7550 100644 --- a/tests/override_server.py +++ b/tests/override_server.py @@ -1,77 +1,77 @@ -import contextlib -import logging -import sys - -import mlflow -import pytest -from mlflow.server import ARTIFACT_ROOT_ENV_VAR, BACKEND_STORE_URI_ENV_VAR -from mlflow.server.handlers import ModelRegistryStoreRegistryWrapper, TrackingStoreRegistryWrapper -from mlflow.utils import find_free_port - -from mlflow_go.server import server - -from tests.helper_functions import LOCALHOST -from tests.tracking.integration_test_utils import _await_server_up_or_die - -_logger = logging.getLogger(__name__) - - -@contextlib.contextmanager -def _init_server(backend_uri, root_artifact_uri, extra_env=None, app="mlflow.server:app"): - """ - Launch a new REST server using the tracking store specified by backend_uri and root artifact - directory specified by root_artifact_uri. - :returns A string URL to the server. 
- """ - scheme = backend_uri.split("://")[0] - if ("sqlite" or "postgresql" or "mysql" or "mssql") not in scheme: - pytest.skip(f'Unsupported scheme "{scheme}" for the Go server') - - mlflow.set_tracking_uri(None) - - server_port = find_free_port() - python_port = find_free_port() - url = f"http://{LOCALHOST}:{server_port}" - - _logger.info( - f"Initializing stores with backend URI {backend_uri} and artifact root {root_artifact_uri}" - ) - TrackingStoreRegistryWrapper().get_store(backend_uri, root_artifact_uri) - ModelRegistryStoreRegistryWrapper().get_store(backend_uri, root_artifact_uri) - - _logger.info( - f"Launching tracking server on {url} with backend URI {backend_uri} and " - f"artifact root {root_artifact_uri}" - ) - - with server( - address=f"{LOCALHOST}:{server_port}", - default_artifact_root=root_artifact_uri, - log_level=logging.getLevelName(_logger.getEffectiveLevel()), - model_registry_store_uri=backend_uri, - python_address=f"{LOCALHOST}:{python_port}", - python_command=[ - sys.executable, - "-m", - "flask", - "--app", - app, - "run", - "--host", - LOCALHOST, - "--port", - str(python_port), - ], - python_env=[ - f"{k}={v}" - for k, v in { - BACKEND_STORE_URI_ENV_VAR: backend_uri, - ARTIFACT_ROOT_ENV_VAR: root_artifact_uri, - **(extra_env or {}), - }.items() - ], - shutdown_timeout="5s", - tracking_store_uri=backend_uri, - ): - _await_server_up_or_die(server_port) - yield url +import contextlib +import logging +import sys + +import mlflow +import pytest +from mlflow.server import ARTIFACT_ROOT_ENV_VAR, BACKEND_STORE_URI_ENV_VAR +from mlflow.server.handlers import ModelRegistryStoreRegistryWrapper, TrackingStoreRegistryWrapper +from mlflow.utils import find_free_port + +from mlflow_go.server import server + +from tests.helper_functions import LOCALHOST +from tests.tracking.integration_test_utils import _await_server_up_or_die + +_logger = logging.getLogger(__name__) + + +@contextlib.contextmanager +def _init_server(backend_uri, root_artifact_uri, 
extra_env=None, app="mlflow.server:app"): + """ + Launch a new REST server using the tracking store specified by backend_uri and root artifact + directory specified by root_artifact_uri. + :returns A string URL to the server. + """ + scheme = backend_uri.split("://")[0] + if ("sqlite" or "postgresql" or "mysql" or "mssql") not in scheme: + pytest.skip(f'Unsupported scheme "{scheme}" for the Go server') + + mlflow.set_tracking_uri(None) + + server_port = find_free_port() + python_port = find_free_port() + url = f"http://{LOCALHOST}:{server_port}" + + _logger.info( + f"Initializing stores with backend URI {backend_uri} and artifact root {root_artifact_uri}" + ) + TrackingStoreRegistryWrapper().get_store(backend_uri, root_artifact_uri) + ModelRegistryStoreRegistryWrapper().get_store(backend_uri, root_artifact_uri) + + _logger.info( + f"Launching tracking server on {url} with backend URI {backend_uri} and " + f"artifact root {root_artifact_uri}" + ) + + with server( + address=f"{LOCALHOST}:{server_port}", + default_artifact_root=root_artifact_uri, + log_level=logging.getLevelName(_logger.getEffectiveLevel()), + model_registry_store_uri=backend_uri, + python_address=f"{LOCALHOST}:{python_port}", + python_command=[ + sys.executable, + "-m", + "flask", + "--app", + app, + "run", + "--host", + LOCALHOST, + "--port", + str(python_port), + ], + python_env=[ + f"{k}={v}" + for k, v in { + BACKEND_STORE_URI_ENV_VAR: backend_uri, + ARTIFACT_ROOT_ENV_VAR: root_artifact_uri, + **(extra_env or {}), + }.items() + ], + shutdown_timeout="5s", + tracking_store_uri=backend_uri, + ): + _await_server_up_or_die(server_port) + yield url diff --git a/tests/override_test_sqlalchemy_store.py b/tests/override_test_sqlalchemy_store.py index 13cad4b..4ed89cb 100644 --- a/tests/override_test_sqlalchemy_store.py +++ b/tests/override_test_sqlalchemy_store.py @@ -1,17 +1,17 @@ -from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore - - -def test_log_batch_internal_error(store: 
SqlAlchemyStore): - () - - -def test_log_batch_params_max_length_value(store: SqlAlchemyStore, monkeypatch): - () - - -def test_log_batch_null_metrics(store: SqlAlchemyStore): - () - - -def test_sqlalchemy_store_behaves_as_expected_with_inmemory_sqlite_db(monkeypatch): - () +from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore + + +def test_log_batch_internal_error(store: SqlAlchemyStore): + () + + +def test_log_batch_params_max_length_value(store: SqlAlchemyStore, monkeypatch): + () + + +def test_log_batch_null_metrics(store: SqlAlchemyStore): + () + + +def test_sqlalchemy_store_behaves_as_expected_with_inmemory_sqlite_db(monkeypatch): + () diff --git a/tests/override_tracking_store.py b/tests/override_tracking_store.py index 26a7577..05dcb56 100644 --- a/tests/override_tracking_store.py +++ b/tests/override_tracking_store.py @@ -1,5 +1,5 @@ -from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore - -from mlflow_go.store.tracking import TrackingStore - -SqlAlchemyStore = TrackingStore(SqlAlchemyStore) +from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore + +from mlflow_go.store.tracking import TrackingStore + +SqlAlchemyStore = TrackingStore(SqlAlchemyStore) From 97927f600f199ec3465e72209a527a2c33188828 Mon Sep 17 00:00:00 2001 From: Juan Escalada Date: Fri, 4 Oct 2024 03:11:23 +0000 Subject: [PATCH 02/24] Fix line endings to LF --- .devcontainer/Dockerfile | 40 +- .devcontainer/devcontainer.json | 130 +- .devcontainer/docker-compose.yml | 76 +- .devcontainer/postCreate.sh | 14 +- .github/dependabot.yml | 32 +- .github/workflows/ci.yml | 36 +- .github/workflows/lint.yml | 44 +- .github/workflows/test.yml | 118 +- .gitignore | 30 +- .golangci.yml | 110 +- .mockery.yaml | 16 +- .pre-commit-config.yaml | 68 +- LICENSE | 402 ++-- README.md | 630 +++--- conftest.py | 128 +- go.mod | 124 +- go.sum | 472 ++--- magefiles/dev.go | 56 +- magefiles/endpoints.go | 112 +- magefiles/generate.go | 90 +- magefiles/generate/ast_creation.go | 
204 +- magefiles/generate/discovery/discovery.go | 194 +- .../generate/discovery/discovery_test.go | 110 +- magefiles/generate/endpoints.go | 174 +- magefiles/generate/protos.go | 116 +- magefiles/generate/query_annotations.go | 232 +-- magefiles/generate/source_code.go | 936 ++++----- magefiles/generate/validations.go | 64 +- magefiles/repo.go | 440 ++-- magefiles/temp.go | 148 +- magefiles/tests.go | 204 +- mlflow_go/__init__.py | 40 +- mlflow_go/cli.py | 224 +- mlflow_go/lib.py | 248 +-- mlflow_go/server.py | 62 +- mlflow_go/store/_service_proxy.py | 86 +- mlflow_go/store/model_registry.py | 110 +- mlflow_go/store/tracking.py | 384 ++-- pkg/artifacts/service/service.go | 34 +- pkg/cmd/server/main.go | 42 +- pkg/config/config.go | 212 +- pkg/config/config_test.go | 106 +- pkg/contract/error.go | 164 +- pkg/contract/http_request_parser.go | 16 +- pkg/entities/dataset.go | 70 +- pkg/entities/dataset_input.go | 40 +- pkg/entities/experiment.go | 72 +- pkg/entities/experiment_tag.go | 44 +- pkg/entities/input_tag.go | 30 +- pkg/entities/metric.go | 104 +- pkg/entities/param.go | 44 +- pkg/entities/run.go | 150 +- pkg/entities/run_data.go | 14 +- pkg/entities/run_info.go | 68 +- pkg/entities/run_inputs.go | 10 +- pkg/entities/run_tag.go | 44 +- pkg/lib/artifacts.go | 44 +- pkg/lib/ffi.go | 182 +- pkg/lib/instance_map.go | 156 +- pkg/lib/main.go | 6 +- pkg/lib/model_registry.go | 44 +- pkg/lib/server.go | 166 +- pkg/lib/tracking.go | 44 +- pkg/lib/validation.go | 46 +- pkg/model_registry/service/model_versions.go | 42 +- pkg/model_registry/service/service.go | 54 +- .../store/sql/model_versions.go | 188 +- .../store/sql/models/model_version_stage.go | 66 +- .../store/sql/models/model_version_tags.go | 22 +- .../store/sql/models/model_versions.go | 98 +- .../sql/models/registered_model_aliases.go | 16 +- .../store/sql/models/registered_model_tags.go | 16 +- .../store/sql/models/registered_models.go | 18 +- pkg/model_registry/store/sql/store.go | 56 +- 
pkg/model_registry/store/store.go | 24 +- pkg/server/command/command.go | 84 +- pkg/server/command/command_posix.go | 60 +- pkg/server/command/command_windows.go | 154 +- pkg/server/launch.go | 172 +- pkg/server/parser/http_request_parser.go | 144 +- pkg/server/server.go | 446 ++-- pkg/sql/logger.go | 278 +-- pkg/sql/sql.go | 180 +- pkg/tracking/service/experiments.go | 268 +-- pkg/tracking/service/experiments_test.go | 122 +- pkg/tracking/service/metrics.go | 40 +- pkg/tracking/service/query/README.md | 16 +- pkg/tracking/service/query/lexer/token.go | 222 +- pkg/tracking/service/query/lexer/tokenizer.go | 290 +-- .../service/query/lexer/tokenizer_test.go | 228 +-- pkg/tracking/service/query/parser/ast.go | 274 +-- pkg/tracking/service/query/parser/parser.go | 530 ++--- .../service/query/parser/parser_test.go | 364 ++-- pkg/tracking/service/query/parser/validate.go | 658 +++--- pkg/tracking/service/query/query.go | 74 +- pkg/tracking/service/query/query_test.go | 224 +- pkg/tracking/service/runs.go | 336 +-- pkg/tracking/service/service.go | 54 +- pkg/tracking/store/sql/experiments.go | 508 ++--- pkg/tracking/store/sql/metrics.go | 386 ++-- .../store/sql/models/alembic_version.go | 22 +- pkg/tracking/store/sql/models/datasets.go | 56 +- .../store/sql/models/experiment_tags.go | 20 +- pkg/tracking/store/sql/models/experiments.go | 80 +- pkg/tracking/store/sql/models/input_tags.go | 38 +- pkg/tracking/store/sql/models/inputs.go | 56 +- .../store/sql/models/latest_metrics.go | 50 +- pkg/tracking/store/sql/models/lifecycle.go | 24 +- pkg/tracking/store/sql/models/metrics.go | 114 +- pkg/tracking/store/sql/models/params.go | 54 +- pkg/tracking/store/sql/models/runs.go | 212 +- pkg/tracking/store/sql/models/tags.go | 62 +- pkg/tracking/store/sql/params.go | 238 +-- pkg/tracking/store/sql/runs.go | 1800 ++++++++--------- pkg/tracking/store/sql/runs_internal_test.go | 1036 +++++----- pkg/tracking/store/sql/store.go | 56 +- pkg/tracking/store/sql/tags.go | 522 ++--- 
pkg/tracking/store/store.go | 162 +- pkg/utils/logger.go | 98 +- pkg/utils/naming.go | 142 +- pkg/utils/path.go | 180 +- pkg/utils/pointers.go | 82 +- pkg/utils/strings.go | 48 +- pkg/utils/tags.go | 12 +- pkg/validation/validation.go | 598 +++--- pkg/validation/validation_test.go | 488 ++--- pyproject.toml | 346 ++-- setup.py | 132 +- tests/override_model_registry_store.py | 10 +- tests/override_server.py | 154 +- tests/override_test_sqlalchemy_store.py | 34 +- tests/override_tracking_store.py | 10 +- 132 files changed, 11152 insertions(+), 11152 deletions(-) mode change 100755 => 100644 .devcontainer/postCreate.sh diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index 3c75a68..2b4ad23 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -1,20 +1,20 @@ -FROM mcr.microsoft.com/devcontainers/go:1-1.22-bookworm - -# [Optional] Uncomment this section to install additional OS packages. -RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \ - && apt-get -y install --no-install-recommends \ - postgresql-client \ - sqlite3 \ - && rm -rf /var/lib/apt/lists/* - -# [Optional] Uncomment the next lines to use go get to install anything else you need -USER vscode -RUN go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.33.0 \ - && go install github.com/vektra/mockery/v2@v2.43.2 \ - && go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.59.1 \ - && go install github.com/magefile/mage@v1.15.0 \ - && go clean -cache -modcache -USER root - -# [Optional] Uncomment this line to install global node packages. -# RUN su vscode -c "source /usr/local/share/nvm/nvm.sh && npm install -g " 2>&1 +FROM mcr.microsoft.com/devcontainers/go:1-1.22-bookworm + +# [Optional] Uncomment this section to install additional OS packages. 
+RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \ + && apt-get -y install --no-install-recommends \ + postgresql-client \ + sqlite3 \ + && rm -rf /var/lib/apt/lists/* + +# [Optional] Uncomment the next lines to use go get to install anything else you need +USER vscode +RUN go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.33.0 \ + && go install github.com/vektra/mockery/v2@v2.43.2 \ + && go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.59.1 \ + && go install github.com/magefile/mage@v1.15.0 \ + && go clean -cache -modcache +USER root + +# [Optional] Uncomment this line to install global node packages. +# RUN su vscode -c "source /usr/local/share/nvm/nvm.sh && npm install -g " 2>&1 diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 13d7b1f..6fcc3ba 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -1,66 +1,66 @@ -// For format details, see https://aka.ms/devcontainer.json. -{ - "name": "MLflow Go", - "dockerComposeFile": "docker-compose.yml", - "service": "app", - "workspaceFolder": "/workspaces/${localWorkspaceFolderBasename}", - - // Features to add to the dev container. More info: https://containers.dev/features. - "features": { - "ghcr.io/devcontainers/features/github-cli:1": {}, - "ghcr.io/devcontainers/features/python:1": { - "version": "3.8" - }, - "ghcr.io/devcontainers/features/docker-in-docker:2": {}, - "ghcr.io/devcontainers-contrib/features/k6:1": {}, - "ghcr.io/devcontainers-contrib/features/pre-commit:2": {}, - "ghcr.io/devcontainers-contrib/features/protoc-asdf:1": { - "version": "26.0" - }, - "ghcr.io/devcontainers-contrib/features/ruff:1": {} - }, - - // Configure tool-specific properties. 
- "customizations": { - "vscode": { - "settings": { - "terminal.integrated.defaultProfile.linux": "zsh", - "editor.rulers": [ - 80, - 100 - ], - "editor.formatOnSave": true, - "git.alwaysSignOff": true, - "go.lintTool": "golangci-lint", - "gopls": { - "formatting.local": "github.com/mlflow/mlflow-go", - "formatting.gofumpt": true, - "build.buildFlags": ["-tags=mage"] - }, - "[python]": { - "editor.codeActionsOnSave": { - "source.fixAll": "explicit", - "source.organizeImports": "explicit" - }, - "editor.defaultFormatter": "charliermarsh.ruff" - } - }, - "extensions": [ - "charliermarsh.ruff", - "golang.Go", - "humao.rest-client", - "pbkit.vscode-pbkit", - "tamasfe.even-better-toml" - ] - } - }, - - // Use 'forwardPorts' to make a list of ports inside the container available locally. - // "forwardPorts": [5432], - - // Use 'postCreateCommand' to run commands after the container is created. - "postCreateCommand": ".devcontainer/postCreate.sh", - - // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. - "remoteUser": "root" +// For format details, see https://aka.ms/devcontainer.json. +{ + "name": "MLflow Go", + "dockerComposeFile": "docker-compose.yml", + "service": "app", + "workspaceFolder": "/workspaces/${localWorkspaceFolderBasename}", + + // Features to add to the dev container. More info: https://containers.dev/features. + "features": { + "ghcr.io/devcontainers/features/github-cli:1": {}, + "ghcr.io/devcontainers/features/python:1": { + "version": "3.8" + }, + "ghcr.io/devcontainers/features/docker-in-docker:2": {}, + "ghcr.io/devcontainers-contrib/features/k6:1": {}, + "ghcr.io/devcontainers-contrib/features/pre-commit:2": {}, + "ghcr.io/devcontainers-contrib/features/protoc-asdf:1": { + "version": "26.0" + }, + "ghcr.io/devcontainers-contrib/features/ruff:1": {} + }, + + // Configure tool-specific properties. 
+ "customizations": { + "vscode": { + "settings": { + "terminal.integrated.defaultProfile.linux": "zsh", + "editor.rulers": [ + 80, + 100 + ], + "editor.formatOnSave": true, + "git.alwaysSignOff": true, + "go.lintTool": "golangci-lint", + "gopls": { + "formatting.local": "github.com/mlflow/mlflow-go", + "formatting.gofumpt": true, + "build.buildFlags": ["-tags=mage"] + }, + "[python]": { + "editor.codeActionsOnSave": { + "source.fixAll": "explicit", + "source.organizeImports": "explicit" + }, + "editor.defaultFormatter": "charliermarsh.ruff" + } + }, + "extensions": [ + "charliermarsh.ruff", + "golang.Go", + "humao.rest-client", + "pbkit.vscode-pbkit", + "tamasfe.even-better-toml" + ] + } + }, + + // Use 'forwardPorts' to make a list of ports inside the container available locally. + // "forwardPorts": [5432], + + // Use 'postCreateCommand' to run commands after the container is created. + "postCreateCommand": ".devcontainer/postCreate.sh", + + // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. + "remoteUser": "root" } \ No newline at end of file diff --git a/.devcontainer/docker-compose.yml b/.devcontainer/docker-compose.yml index 11e19cb..9ebb9ec 100644 --- a/.devcontainer/docker-compose.yml +++ b/.devcontainer/docker-compose.yml @@ -1,38 +1,38 @@ -volumes: - go-cache: - postgres-data: - -services: - app: - build: - context: . - dockerfile: Dockerfile - volumes: - - ../..:/workspaces:cached - - go-cache:/var/cache/go - environment: - - GOCACHE=/var/cache/go/build - - GOMODCACHE=/var/cache/go/mod - - # Overrides default command so things don't shut down after the process ends. - command: sleep infinity - - # Runs app on the same network as the database container, allows "forwardPorts" in devcontainer.json function. - network_mode: service:db - - # Use "forwardPorts" in **devcontainer.json** to forward an app port locally. - # (Adding the "ports" property to this file will not forward from a Codespace.) 
- - db: - image: postgres:latest - restart: unless-stopped - volumes: - - postgres-data:/var/lib/postgresql/data - environment: - - POSTGRES_USER=postgres - - POSTGRES_PASSWORD=postgres - - POSTGRES_DB=postgres - - POSTGRES_HOSTNAME=localhost=value - - # Add "forwardPorts": ["5432"] to **devcontainer.json** to forward PostgreSQL locally. - # (Adding the "ports" property to this file will not forward from a Codespace.) +volumes: + go-cache: + postgres-data: + +services: + app: + build: + context: . + dockerfile: Dockerfile + volumes: + - ../..:/workspaces:cached + - go-cache:/var/cache/go + environment: + - GOCACHE=/var/cache/go/build + - GOMODCACHE=/var/cache/go/mod + + # Overrides default command so things don't shut down after the process ends. + command: sleep infinity + + # Runs app on the same network as the database container, allows "forwardPorts" in devcontainer.json function. + network_mode: service:db + + # Use "forwardPorts" in **devcontainer.json** to forward an app port locally. + # (Adding the "ports" property to this file will not forward from a Codespace.) + + db: + image: postgres:latest + restart: unless-stopped + volumes: + - postgres-data:/var/lib/postgresql/data + environment: + - POSTGRES_USER=postgres + - POSTGRES_PASSWORD=postgres + - POSTGRES_DB=postgres + - POSTGRES_HOSTNAME=localhost=value + + # Add "forwardPorts": ["5432"] to **devcontainer.json** to forward PostgreSQL locally. + # (Adding the "ports" property to this file will not forward from a Codespace.) 
diff --git a/.devcontainer/postCreate.sh b/.devcontainer/postCreate.sh old mode 100755 new mode 100644 index 68ef182..d71c8a4 --- a/.devcontainer/postCreate.sh +++ b/.devcontainer/postCreate.sh @@ -1,7 +1,7 @@ -#!/bin/sh - -# Fix permissions for the Go cache -sudo chown -R $(id -u):$(id -g) /var/cache/go - -# Install precommit (https://pre-commit.com/) -pre-commit install -t pre-commit -t prepare-commit-msg +#!/bin/sh + +# Fix permissions for the Go cache +sudo chown -R $(id -u):$(id -g) /var/cache/go + +# Install precommit (https://pre-commit.com/) +pre-commit install -t pre-commit -t prepare-commit-msg diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 505633d..97720e7 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -1,16 +1,16 @@ -# To get started with Dependabot version updates, you'll need to specify which -# package ecosystems to update and where the package manifests are located. -# Please see the documentation for more information: -# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates -# https://containers.dev/guide/dependabot - -version: 2 -updates: - - package-ecosystem: "devcontainers" - directory: "/" - schedule: - interval: weekly - - package-ecosystem: "github-actions" - directory: "/" - schedule: - interval: weekly +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. 
+# Please see the documentation for more information: +# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates +# https://containers.dev/guide/dependabot + +version: 2 +updates: + - package-ecosystem: "devcontainers" + directory: "/" + schedule: + interval: weekly + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: weekly diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 850711b..ae98d9a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,18 +1,18 @@ -name: CI - -on: - push: - branches: - - main - pull_request: - schedule: - # Run daily at 01:34 so we get notified if CI is broken before a pull request - # is submitted. - - cron: "34 1 * * *" - -jobs: - lint: - uses: ./.github/workflows/lint.yml - - test: - uses: ./.github/workflows/test.yml +name: CI + +on: + push: + branches: + - main + pull_request: + schedule: + # Run daily at 01:34 so we get notified if CI is broken before a pull request + # is submitted. 
+ - cron: "34 1 * * *" + +jobs: + lint: + uses: ./.github/workflows/lint.yml + + test: + uses: ./.github/workflows/test.yml diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index d4b0f19..83e5398 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -1,22 +1,22 @@ -name: Lint - -on: - workflow_call: - -permissions: - contents: read - -jobs: - lint: - name: Lint - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - name: Setup Go - uses: actions/setup-go@v5 - with: - go-version: "1.22" - check-latest: true - cache: false - - name: Run pre-commit hooks - run: pipx run pre-commit run --all-files +name: Lint + +on: + workflow_call: + +permissions: + contents: read + +jobs: + lint: + name: Lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version: "1.22" + check-latest: true + cache: false + - name: Run pre-commit hooks + run: pipx run pre-commit run --all-files diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 24b20a5..0becdec 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,59 +1,59 @@ -name: Test - -on: - workflow_call: - -permissions: - contents: read - -jobs: - go: - name: Test Go - strategy: - matrix: - runner: [macos-latest, ubuntu-latest, windows-latest] - runs-on: ${{ matrix.runner }} - steps: - - uses: actions/checkout@v4 - - name: Setup Go - uses: actions/setup-go@v5 - with: - go-version: "1.22" - check-latest: true - cache: false - - name: Install mage - run: go install github.com/magefile/mage@v1.15.0 - - name: Run unit tests - run: mage test:unit - - python: - name: Test Python - strategy: - matrix: - runner: [macos-latest, ubuntu-latest, windows-latest] - python: ["3.8", "3.9", "3.10", "3.11", "3.12"] - runs-on: ${{ matrix.runner }} - steps: - - uses: actions/checkout@v4 - - name: Setup Python - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python 
}} - - name: Setup Go - uses: actions/setup-go@v5 - with: - go-version: "1.22" - check-latest: true - cache: false - - name: Install mage - run: go install github.com/magefile/mage@v1.15.0 - - name: Install our package in editable mode - run: pip install -e . - - name: Initialize MLflow repo - run: mage repo:init - - name: Install dependencies - run: pip install pytest==8.1.1 psycopg2-binary -e .mlflow.repo - - name: Run integration tests - run: mage test:python - # Temporary workaround for failing tests - continue-on-error: ${{ matrix.runner != 'ubuntu-latest' }} +name: Test + +on: + workflow_call: + +permissions: + contents: read + +jobs: + go: + name: Test Go + strategy: + matrix: + runner: [macos-latest, ubuntu-latest, windows-latest] + runs-on: ${{ matrix.runner }} + steps: + - uses: actions/checkout@v4 + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version: "1.22" + check-latest: true + cache: false + - name: Install mage + run: go install github.com/magefile/mage@v1.15.0 + - name: Run unit tests + run: mage test:unit + + python: + name: Test Python + strategy: + matrix: + runner: [macos-latest, ubuntu-latest, windows-latest] + python: ["3.8", "3.9", "3.10", "3.11", "3.12"] + runs-on: ${{ matrix.runner }} + steps: + - uses: actions/checkout@v4 + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python }} + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version: "1.22" + check-latest: true + cache: false + - name: Install mage + run: go install github.com/magefile/mage@v1.15.0 + - name: Install our package in editable mode + run: pip install -e . 
+ - name: Initialize MLflow repo + run: mage repo:init + - name: Install dependencies + run: pip install pytest==8.1.1 psycopg2-binary -e .mlflow.repo + - name: Run integration tests + run: mage test:python + # Temporary workaround for failing tests + continue-on-error: ${{ matrix.runner != 'ubuntu-latest' }} diff --git a/.gitignore b/.gitignore index cf9ad19..af82b33 100644 --- a/.gitignore +++ b/.gitignore @@ -1,16 +1,16 @@ -# Artifacts -dist/ -*.egg-info/ -*.so - -# Runs -mlruns/ - -# Cache -__pycache__/ - -# MLflow repo -.mlflow.repo/ - -# JetBrains +# Artifacts +dist/ +*.egg-info/ +*.so + +# Runs +mlruns/ + +# Cache +__pycache__/ + +# MLflow repo +.mlflow.repo/ + +# JetBrains .idea \ No newline at end of file diff --git a/.golangci.yml b/.golangci.yml index 6252ee3..383e76b 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,55 +1,55 @@ -run: - build-tags: - - mage - timeout: 5m - -linters: - enable: - - errcheck - - gosimple - - lll - disable: - - depguard - - gochecknoglobals # Immutable globals are fine. - - exhaustruct # Often the case for protobuf generated code or gorm structs. - - protogetter # We do want to use pointers for memory optimization. - presets: - - bugs - - comment - - complexity - - error - - format - - import - - metalinter - - module - - performance - - sql - - style - - test - - unused - -linters-settings: - gci: - custom-order: true - - sections: - - standard # Standard section: captures all standard packages. - - default # Default section: contains all imports that could not be matched to another section type. - - alias # Alias section: contains all alias imports. This section is not present unless explicitly enabled. - - prefix(github.com/mlflow/mlflow-go) # Custom section: groups all imports with the specified Prefix. - - blank # Blank section: contains all blank imports. This section is not present unless explicitly enabled. - - dot # Dot section: contains all dot imports. This section is not present unless explicitly enabled. 
- - gofumpt: - module-path: github.com/mlflow/mlflow-go - extra-rules: true - - tagliatelle: - case: - rules: - json: snake - -issues: - exclude-files: - - ".*\\.g\\.go$" - - ".*\\.pb\\.go$" +run: + build-tags: + - mage + timeout: 5m + +linters: + enable: + - errcheck + - gosimple + - lll + disable: + - depguard + - gochecknoglobals # Immutable globals are fine. + - exhaustruct # Often the case for protobuf generated code or gorm structs. + - protogetter # We do want to use pointers for memory optimization. + presets: + - bugs + - comment + - complexity + - error + - format + - import + - metalinter + - module + - performance + - sql + - style + - test + - unused + +linters-settings: + gci: + custom-order: true + + sections: + - standard # Standard section: captures all standard packages. + - default # Default section: contains all imports that could not be matched to another section type. + - alias # Alias section: contains all alias imports. This section is not present unless explicitly enabled. + - prefix(github.com/mlflow/mlflow-go) # Custom section: groups all imports with the specified Prefix. + - blank # Blank section: contains all blank imports. This section is not present unless explicitly enabled. + - dot # Dot section: contains all dot imports. This section is not present unless explicitly enabled. 
+ + gofumpt: + module-path: github.com/mlflow/mlflow-go + extra-rules: true + + tagliatelle: + case: + rules: + json: snake + +issues: + exclude-files: + - ".*\\.g\\.go$" + - ".*\\.pb\\.go$" diff --git a/.mockery.yaml b/.mockery.yaml index 33ec946..b3021f7 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -1,8 +1,8 @@ -dir: "{{ .InterfaceDir }}" -filename: "mock_{{ .InterfaceNameSnake }}.go" -with-expecter: true -inpackage: true -packages: - github.com/mlflow/mlflow-go/pkg/tracking/store: - interfaces: - TrackingStore: +dir: "{{ .InterfaceDir }}" +filename: "mock_{{ .InterfaceNameSnake }}.go" +with-expecter: true +inpackage: true +packages: + github.com/mlflow/mlflow-go/pkg/tracking/store: + interfaces: + TrackingStore: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6dc1a46..9bfe726 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,34 +1,34 @@ -repos: - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 - hooks: - - id: end-of-file-fixer - files: \.(proto|txt|sh|rst)$ - - repo: https://github.com/golangci/golangci-lint - rev: "v1.59.1" - hooks: - - id: golangci-lint-full - - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.5.7 - hooks: - - id: ruff - types_or: [python, pyi, jupyter] - args: [--fix] - - id: ruff-format - types_or: [python, pyi, jupyter] - - repo: local - hooks: - # - id: rstcheck - # name: rstcheck - # entry: rstcheck - # language: system - # files: README.rst - # stages: [commit] - # require_serial: true - - - id: must-have-signoff - name: must-have-signoff - entry: 'grep "Signed-off-by:"' - language: system - stages: [prepare-commit-msg] - require_serial: true +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: end-of-file-fixer + files: \.(proto|txt|sh|rst)$ + - repo: https://github.com/golangci/golangci-lint + rev: "v1.59.1" + hooks: + - id: golangci-lint-full + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.5.7 
+ hooks: + - id: ruff + types_or: [python, pyi, jupyter] + args: [--fix] + - id: ruff-format + types_or: [python, pyi, jupyter] + - repo: local + hooks: + # - id: rstcheck + # name: rstcheck + # entry: rstcheck + # language: system + # files: README.rst + # stages: [commit] + # require_serial: true + + - id: must-have-signoff + name: must-have-signoff + entry: 'grep "Signed-off-by:"' + language: system + stages: [prepare-commit-msg] + require_serial: true diff --git a/LICENSE b/LICENSE index 29f81d8..261eeb9 100644 --- a/LICENSE +++ b/LICENSE @@ -1,201 +1,201 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. 
- - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. 
Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. 
You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. 
- - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. 
In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. 
- - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. 
+ + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." 
+ + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md index e184a13..f4d3991 100644 --- a/README.md +++ b/README.md @@ -1,316 +1,316 @@ -# Go backend for MLflow - -In order to increase the performance of the tracking server and the various stores, we propose to rewrite the server and store implementation in Go. - -## Usage - -### Installation - -This package is not yet available on PyPI and currently requires the [Go SDK](https://go.dev) to be installed. 
- -You can then install the package via pip: -```bash -pip install git+https://github.com/jgiannuzzi/mlflow-go.git -``` - -### Using the Go server - -```bash -# Start the Go server with a database URI -# Other databases are supported as well: postgresql, mysql and mssql -mlflow-go server --backend-store-uri sqlite:///mlflow.db -``` - -```python -import mlflow - -# Use the Go server -mlflow.set_tracking_uri("http://localhost:5000") - -# Use MLflow as usual -mlflow.set_experiment("my-experiment") - -with mlflow.start_run(): - mlflow.log_param("param", 1) - mlflow.log_metric("metric", 2) -``` - -### Using the client-side Go implementation - -```python -import mlflow -import mlflow_go - -# Enable the Go client implementation (disabled by default) -mlflow_go.enable_go() - -# Set the tracking URI (you can also set it via the environment variable MLFLOW_TRACKING_URI) -# Currently only database URIs are supported -mlflow.set_tracking_uri("sqlite:///mlflow.db") - -# Use MLflow as usual -mlflow.set_experiment("my-experiment") - -with mlflow.start_run(): - mlflow.log_param("param", 1) - mlflow.log_metric("metric", 2) -``` - -## Temp stuff - -### Dev setup - -```bash -# Install our Python package and its dependencies -pip install -e . - -# Install the dreaded psycho -pip install psycopg2-binary - -# Archive the MLFlow pre-built UI -tar -C /usr/local/python/current/lib/python3.8/site-packages/mlflow -czvf ./ui.tgz ./server/js/build - -# Clone the MLflow repo -git clone https://github.com/jgiannuzzi/mlflow.git -b master .mlflow.repo - -# Add the UI back to it -tar -C .mlflow.repo/mlflow -xzvf ./ui.tgz - -# Install it in editable mode -pip install -e .mlflow.repo -``` - -or run `mage temp`. - -### Run the tests manually - -```bash -# Build the Go binary in a temporary directory -libpath=$(mktemp -d) -python -m mlflow_go.lib . $libpath - -# Run the tests (currently just the server ones) -MLFLOW_GO_LIBRARY_PATH=$libpath pytest --confcutdir=. 
\ - .mlflow.repo/tests/tracking/test_rest_tracking.py \ - .mlflow.repo/tests/tracking/test_model_registry.py \ - .mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py \ - .mlflow.repo/tests/store/model_registry/test_sqlalchemy_store.py \ - -k 'not [file' - -# Remove the Go binary -rm -rf $libpath - -# If you want to run a specific test with more verbosity -# -s for live output -# --log-level=debug for more verbosity (passed down to the Go server/stores) -MLFLOW_GO_LIBRARY_PATH=$libpath pytest --confcutdir=. \ - .mlflow.repo/tests/tracking/test_rest_tracking.py::test_create_experiment_validation \ - -k 'not [file' \ - -s --log-level=debug -``` - -Or run the `mage test:python` target. - -### Use the Go store directly in Python - -```python -import logging -import mlflow -import mlflow_go - -# Enable debug logging -logging.basicConfig() -logging.getLogger('mlflow_go').setLevel(logging.DEBUG) - -# Enable the Go client implementation (disabled by default) -mlflow_go.enable_go() - -# Instantiate the tracking store with a database URI -tracking_store = mlflow.tracking._tracking_service.utils._get_store('sqlite:///mlflow.db') - -# Call any tracking store method -tracking_store.get_experiment(0) - -# Instantiate the model registry store with a database URI -model_registry_store = mlflow.tracking._model_registry.utils._get_store('sqlite:///mlflow.db') - -# Call any model registry store method -model_registry_store.get_latest_versions("model") -``` - -## General setup - -### Mage - -This repository uses [mage](https://magefile.org/) to streamline some utilily functions. - -```bash -# Install mage (already done in the dev container) -go install github.com/magefile/mage@v1.15.0 - -# See all targets -mage - -# Execute single target -mage dev -``` - -The beauty of Mage is that we can use regular Go code for our scripting. -That being said, we are not married to this tool. - -### mlflow source code - -To integrate with MLflow, you need to include the source code. 
The [mlflow/mlflow](https://github.com/mlflow/mlflow/) repository contains proto files that define the tracking API. It also includes Python tests that we use to verify our Go implementation produces identical behaviour. - -We use a `.mlflow.ref` file to specify the exact location from which to pull our sources. The format should be `remote#reference`, where `remote` is a git remote and `reference` is a branch, tag, or commit SHA. - -If the `.mlflow.ref` file is modified and becomes out of sync with the current source files, the mage target will automatically detect this. To manually force a sync, you can run `mage repo:update`. - -### Protos - -To ensure we stay compatible with the Python implementation, we aim to generate as much as possible based on the `.proto` files. - -By running - -```bash -mage generate -``` - -Go code will be generated. Use the protos files from `.mlflow.repo` repository. - -This incudes the generation of: - -- Structs for each endpoint. ([pkg/protos](./protos/service.pb.go)) -- Go interfaces for each service. ([pkg/contract/service/*.g.go](./contract/service/tracking.g.go)) -- [fiber](https://gofiber.io/) routes for each endpoint. ([pkg/server/routes/*.g.go](./server/routes/tracking.g.go)) - -If there is any change in the proto files, this should ripple into the Go code. - -## Launching the Go server - -To enable use of the Go server, users can run the `mlflow-go server` command. - -```bash -mlflow-go server --backend-store-uri postgresql://postgres:postgres@localhost:5432/postgres -``` - -This will launch the python process as usual. Within Python, a random port is chosen to start the existing server and a Go child process is spawned. The Go server will use the user specified port (5000 by default) and spawn the actual Python server as its own child process (`gunicorn` or `waitress`). -Any incoming requests the Go server cannot process will be proxied to the existing Python server. 
- -Any Go-specific options can be passed with `--go-opts`, which takes a comma-separated list of key-value pairs. - -```bash -mlflow-go server --backend-store-uri postgresql://postgres:postgres@localhost:5432/postgres --go-opts log_level=debug,shutdown_timeout=5s -``` - -## Building the Go binary - -To ensure everything still compiles: - -```bash -go build -o /dev/null ./pkg/cmd/server -``` - -or - -```bash -python -m mlflow_go.lib . /tmp -``` - -## Request validation - -We use [Go validator](https://github.com/go-playground/validator) to validate all incoming request structs. -As the proto files don't specify any validation rules, we map them manually in [pkg/cmd/generate/validations.go](./cmd/generate/validations.go). - -Once the mapping has been done, validation will be invoked automatically in the generated fiber code. - -When the need arises, we can write custom validation function in [pkg/validation/validation.go](./validation/validation.go). - -## Data access - -Initially, we want to focus on supporting Postgres SQL. We chose [Gorm](https://gorm.io/) as ORM to interact with the database. - -We do not generate any Go code based on the database schema. Gorm has generation capabilities but they didn't fit our needs. The plan would be to eventually assert the current code stil matches the database schema via an intergration test. - -All the models use pointers for their fields. We do this for performance reasons and to distinguish between zero values and null values. - -## Testing strategy - -> [!WARNING] -> TODO rewrite this whole section - -The Python integration tests have been adapted to also run against the Go implementation. Just run them as usual, e.g. 
- -```bash -pytest tests/tracking/test_rest_tracking.py -``` - -To run only the tests targetting the Go implementation, you can use the `-k` flag: - -```bash -pytest tests/tracking/test_rest_tracking.py -k '[go-' -``` - -If you'd like to run a specific test and see its output 'live', you can use the `-s` flag: - -```bash -pytest -s "tests/tracking/test_rest_tracking.py::test_create_experiment_validation[go-postgresql]" -``` - -See the [pytest documentation](https://docs.pytest.org/en/8.2.x/how-to/usage.html#specifying-which-tests-to-run) for more details. - -## Supported endpoints - -The currently supported endpoints can be found by running - -```bash -mage endpoints -``` - -## Linters - -We have enabled various linters from [golangci-lint](https://golangci-lint.run/), you can run these via: - -```bash -pre-commit run golangci-lint --all-files -``` - -Sometimes `golangci-lint` can complain about unrelated files, run `golangci-lint cache clean` to clear the cache. - -## Failing tests - -The following Python tests are currently failing: - -``` -===================================================================================================================== short test summary info ====================================================================================================================== -FAILED .mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py::test_log_inputs_with_large_inputs_limit_check - AssertionError: assert {'digest': 'd...ema': '', ...} == {'digest': 'd...a': None, ...} -======================================================================================== 1 failed, 358 passed, 9 skipped, 128 deselected, 10 warnings in 227.64s (0:03:47) ========================================================================================= -``` - -## Debug Failing Tests - -Sometimes, it can be very useful to modify failing tests and use `print` statements to display the current state or differences between objects from Python or Go services. 
- -Adding `"-vv"` to the `pytest` command in `magefiles/tests.go` can also provide more information when assertions are not met. - -### Targeting Local Postgres in Integration Tests - -At times, you might want to apply store calls to your local database to investigate certain read operations via the local tracking server. - -You can achieve this by changing: - -```python -def test_search_runs_datasets(store: SqlAlchemyStore): -``` - -to: - -```python -def test_search_runs_datasets(): - db_uri = "postgresql://postgres:postgres@localhost:5432/postgres" - artifact_uri = Path("/tmp/artifacts") - artifact_uri.mkdir(exist_ok=True) - store = SqlAlchemyStore(db_uri, artifact_uri.as_uri()) -``` - +# Go backend for MLflow + +In order to increase the performance of the tracking server and the various stores, we propose to rewrite the server and store implementation in Go. + +## Usage + +### Installation + +This package is not yet available on PyPI and currently requires the [Go SDK](https://go.dev) to be installed. 
+ +You can then install the package via pip: +```bash +pip install git+https://github.com/jgiannuzzi/mlflow-go.git +``` + +### Using the Go server + +```bash +# Start the Go server with a database URI +# Other databases are supported as well: postgresql, mysql and mssql +mlflow-go server --backend-store-uri sqlite:///mlflow.db +``` + +```python +import mlflow + +# Use the Go server +mlflow.set_tracking_uri("http://localhost:5000") + +# Use MLflow as usual +mlflow.set_experiment("my-experiment") + +with mlflow.start_run(): + mlflow.log_param("param", 1) + mlflow.log_metric("metric", 2) +``` + +### Using the client-side Go implementation + +```python +import mlflow +import mlflow_go + +# Enable the Go client implementation (disabled by default) +mlflow_go.enable_go() + +# Set the tracking URI (you can also set it via the environment variable MLFLOW_TRACKING_URI) +# Currently only database URIs are supported +mlflow.set_tracking_uri("sqlite:///mlflow.db") + +# Use MLflow as usual +mlflow.set_experiment("my-experiment") + +with mlflow.start_run(): + mlflow.log_param("param", 1) + mlflow.log_metric("metric", 2) +``` + +## Temp stuff + +### Dev setup + +```bash +# Install our Python package and its dependencies +pip install -e . + +# Install the dreaded psycho +pip install psycopg2-binary + +# Archive the MLFlow pre-built UI +tar -C /usr/local/python/current/lib/python3.8/site-packages/mlflow -czvf ./ui.tgz ./server/js/build + +# Clone the MLflow repo +git clone https://github.com/jgiannuzzi/mlflow.git -b master .mlflow.repo + +# Add the UI back to it +tar -C .mlflow.repo/mlflow -xzvf ./ui.tgz + +# Install it in editable mode +pip install -e .mlflow.repo +``` + +or run `mage temp`. + +### Run the tests manually + +```bash +# Build the Go binary in a temporary directory +libpath=$(mktemp -d) +python -m mlflow_go.lib . $libpath + +# Run the tests (currently just the server ones) +MLFLOW_GO_LIBRARY_PATH=$libpath pytest --confcutdir=. 
\ + .mlflow.repo/tests/tracking/test_rest_tracking.py \ + .mlflow.repo/tests/tracking/test_model_registry.py \ + .mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py \ + .mlflow.repo/tests/store/model_registry/test_sqlalchemy_store.py \ + -k 'not [file' + +# Remove the Go binary +rm -rf $libpath + +# If you want to run a specific test with more verbosity +# -s for live output +# --log-level=debug for more verbosity (passed down to the Go server/stores) +MLFLOW_GO_LIBRARY_PATH=$libpath pytest --confcutdir=. \ + .mlflow.repo/tests/tracking/test_rest_tracking.py::test_create_experiment_validation \ + -k 'not [file' \ + -s --log-level=debug +``` + +Or run the `mage test:python` target. + +### Use the Go store directly in Python + +```python +import logging +import mlflow +import mlflow_go + +# Enable debug logging +logging.basicConfig() +logging.getLogger('mlflow_go').setLevel(logging.DEBUG) + +# Enable the Go client implementation (disabled by default) +mlflow_go.enable_go() + +# Instantiate the tracking store with a database URI +tracking_store = mlflow.tracking._tracking_service.utils._get_store('sqlite:///mlflow.db') + +# Call any tracking store method +tracking_store.get_experiment(0) + +# Instantiate the model registry store with a database URI +model_registry_store = mlflow.tracking._model_registry.utils._get_store('sqlite:///mlflow.db') + +# Call any model registry store method +model_registry_store.get_latest_versions("model") +``` + +## General setup + +### Mage + +This repository uses [mage](https://magefile.org/) to streamline some utility functions. + +```bash +# Install mage (already done in the dev container) +go install github.com/magefile/mage@v1.15.0 + +# See all targets +mage + +# Execute single target +mage dev +``` + +The beauty of Mage is that we can use regular Go code for our scripting. +That being said, we are not married to this tool. + +### mlflow source code + +To integrate with MLflow, you need to include the source code. 
The [mlflow/mlflow](https://github.com/mlflow/mlflow/) repository contains proto files that define the tracking API. It also includes Python tests that we use to verify our Go implementation produces identical behaviour. + +We use a `.mlflow.ref` file to specify the exact location from which to pull our sources. The format should be `remote#reference`, where `remote` is a git remote and `reference` is a branch, tag, or commit SHA. + +If the `.mlflow.ref` file is modified and becomes out of sync with the current source files, the mage target will automatically detect this. To manually force a sync, you can run `mage repo:update`. + +### Protos + +To ensure we stay compatible with the Python implementation, we aim to generate as much as possible based on the `.proto` files. + +By running + +```bash +mage generate +``` + +Go code will be generated using the proto files from the `.mlflow.repo` repository. + +This includes the generation of: + +- Structs for each endpoint. ([pkg/protos](./protos/service.pb.go)) +- Go interfaces for each service. ([pkg/contract/service/*.g.go](./contract/service/tracking.g.go)) +- [fiber](https://gofiber.io/) routes for each endpoint. ([pkg/server/routes/*.g.go](./server/routes/tracking.g.go)) + +If there is any change in the proto files, this should ripple into the Go code. + +## Launching the Go server + +To enable use of the Go server, users can run the `mlflow-go server` command. + +```bash +mlflow-go server --backend-store-uri postgresql://postgres:postgres@localhost:5432/postgres +``` + +This will launch the python process as usual. Within Python, a random port is chosen to start the existing server and a Go child process is spawned. The Go server will use the user specified port (5000 by default) and spawn the actual Python server as its own child process (`gunicorn` or `waitress`). +Any incoming requests the Go server cannot process will be proxied to the existing Python server. 
+ +Any Go-specific options can be passed with `--go-opts`, which takes a comma-separated list of key-value pairs. + +```bash +mlflow-go server --backend-store-uri postgresql://postgres:postgres@localhost:5432/postgres --go-opts log_level=debug,shutdown_timeout=5s +``` + +## Building the Go binary + +To ensure everything still compiles: + +```bash +go build -o /dev/null ./pkg/cmd/server +``` + +or + +```bash +python -m mlflow_go.lib . /tmp +``` + +## Request validation + +We use [Go validator](https://github.com/go-playground/validator) to validate all incoming request structs. +As the proto files don't specify any validation rules, we map them manually in [pkg/cmd/generate/validations.go](./cmd/generate/validations.go). + +Once the mapping has been done, validation will be invoked automatically in the generated fiber code. + +When the need arises, we can write custom validation functions in [pkg/validation/validation.go](./validation/validation.go). + +## Data access + +Initially, we want to focus on supporting PostgreSQL. We chose [Gorm](https://gorm.io/) as ORM to interact with the database. + +We do not generate any Go code based on the database schema. Gorm has generation capabilities but they didn't fit our needs. The plan would be to eventually assert the current code still matches the database schema via an integration test. + +All the models use pointers for their fields. We do this for performance reasons and to distinguish between zero values and null values. + +## Testing strategy + +> [!WARNING] +> TODO rewrite this whole section + +The Python integration tests have been adapted to also run against the Go implementation. Just run them as usual, e.g. 
+ +```bash +pytest tests/tracking/test_rest_tracking.py +``` + +To run only the tests targeting the Go implementation, you can use the `-k` flag: + +```bash +pytest tests/tracking/test_rest_tracking.py -k '[go-' +``` + +If you'd like to run a specific test and see its output 'live', you can use the `-s` flag: + +```bash +pytest -s "tests/tracking/test_rest_tracking.py::test_create_experiment_validation[go-postgresql]" +``` + +See the [pytest documentation](https://docs.pytest.org/en/8.2.x/how-to/usage.html#specifying-which-tests-to-run) for more details. + +## Supported endpoints + +The currently supported endpoints can be found by running + +```bash +mage endpoints +``` + +## Linters + +We have enabled various linters from [golangci-lint](https://golangci-lint.run/); you can run these via: + +```bash +pre-commit run golangci-lint --all-files +``` + +Sometimes `golangci-lint` can complain about unrelated files, run `golangci-lint cache clean` to clear the cache. + +## Failing tests + +The following Python tests are currently failing: + +``` +===================================================================================================================== short test summary info ====================================================================================================================== +FAILED .mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py::test_log_inputs_with_large_inputs_limit_check - AssertionError: assert {'digest': 'd...ema': '', ...} == {'digest': 'd...a': None, ...} +======================================================================================== 1 failed, 358 passed, 9 skipped, 128 deselected, 10 warnings in 227.64s (0:03:47) ========================================================================================= +``` + +## Debug Failing Tests + +Sometimes, it can be very useful to modify failing tests and use `print` statements to display the current state or differences between objects from Python or Go services. 
+ +Adding `"-vv"` to the `pytest` command in `magefiles/tests.go` can also provide more information when assertions are not met. + +### Targeting Local Postgres in Integration Tests + +At times, you might want to apply store calls to your local database to investigate certain read operations via the local tracking server. + +You can achieve this by changing: + +```python +def test_search_runs_datasets(store: SqlAlchemyStore): +``` + +to: + +```python +def test_search_runs_datasets(): + db_uri = "postgresql://postgres:postgres@localhost:5432/postgres" + artifact_uri = Path("/tmp/artifacts") + artifact_uri.mkdir(exist_ok=True) + store = SqlAlchemyStore(db_uri, artifact_uri.as_uri()) +``` + in the test file located in `.mlflow.repo`. \ No newline at end of file diff --git a/conftest.py b/conftest.py index 999a2a7..4ea53bd 100644 --- a/conftest.py +++ b/conftest.py @@ -1,64 +1,64 @@ -import logging -import pathlib -from unittest.mock import patch - -_logger = logging.getLogger(__name__) - - -def load_new_function(file_path, func_name): - with open(file_path) as f: - new_func_code = f.read() - - local_dict = {} - exec(new_func_code, local_dict) - return local_dict[func_name] - - -def pytest_configure(config): - for func_to_patch, new_func_file_relative in ( - ( - "tests.tracking.integration_test_utils._init_server", - "tests/override_server.py", - ), - ( - "mlflow.store.tracking.sqlalchemy_store.SqlAlchemyStore", - "tests/override_tracking_store.py", - ), - ( - "mlflow.store.model_registry.sqlalchemy_store.SqlAlchemyStore", - "tests/override_model_registry_store.py", - ), - # This test will patch some Python internals to invoke an internal exception. - # We cannot do this in Go. 
- ( - "tests.store.tracking.test_sqlalchemy_store.test_log_batch_internal_error", - "tests/override_test_sqlalchemy_store.py", - ), - # This test uses monkeypatch.setenv which does not flow through to the - ( - "tests.store.tracking.test_sqlalchemy_store.test_log_batch_params_max_length_value", - "tests/override_test_sqlalchemy_store.py", - ), - # This tests calls the store using invalid metric entity that cannot be converted - # to its proto counterpart. - # Example: entities.Metric("invalid_metric", None, (int(time.time() * 1000)), 0).to_proto() - ( - "tests.store.tracking.test_sqlalchemy_store.test_log_batch_null_metrics", - "tests/override_test_sqlalchemy_store.py", - ), - # We do not support applying the SQL schema to sqlite like Python does. - # So we do not support sqlite:////:memory: database. - ( - "tests.store.tracking.test_sqlalchemy_store.test_sqlalchemy_store_behaves_as_expected_with_inmemory_sqlite_db", - "tests/override_test_sqlalchemy_store.py", - ), - ): - func_name = func_to_patch.rsplit(".", 1)[1] - new_func_file = ( - pathlib.Path(__file__).parent.joinpath(new_func_file_relative).resolve().as_posix() - ) - - new_func = load_new_function(new_func_file, func_name) - - _logger.info(f"Patching function: {func_to_patch}") - patch(func_to_patch, new_func).start() +import logging +import pathlib +from unittest.mock import patch + +_logger = logging.getLogger(__name__) + + +def load_new_function(file_path, func_name): + with open(file_path) as f: + new_func_code = f.read() + + local_dict = {} + exec(new_func_code, local_dict) + return local_dict[func_name] + + +def pytest_configure(config): + for func_to_patch, new_func_file_relative in ( + ( + "tests.tracking.integration_test_utils._init_server", + "tests/override_server.py", + ), + ( + "mlflow.store.tracking.sqlalchemy_store.SqlAlchemyStore", + "tests/override_tracking_store.py", + ), + ( + "mlflow.store.model_registry.sqlalchemy_store.SqlAlchemyStore", + "tests/override_model_registry_store.py", + ), 
+ # This test will patch some Python internals to invoke an internal exception. + # We cannot do this in Go. + ( + "tests.store.tracking.test_sqlalchemy_store.test_log_batch_internal_error", + "tests/override_test_sqlalchemy_store.py", + ), + # This test uses monkeypatch.setenv which does not flow through to the + ( + "tests.store.tracking.test_sqlalchemy_store.test_log_batch_params_max_length_value", + "tests/override_test_sqlalchemy_store.py", + ), + # This tests calls the store using invalid metric entity that cannot be converted + # to its proto counterpart. + # Example: entities.Metric("invalid_metric", None, (int(time.time() * 1000)), 0).to_proto() + ( + "tests.store.tracking.test_sqlalchemy_store.test_log_batch_null_metrics", + "tests/override_test_sqlalchemy_store.py", + ), + # We do not support applying the SQL schema to sqlite like Python does. + # So we do not support sqlite:////:memory: database. + ( + "tests.store.tracking.test_sqlalchemy_store.test_sqlalchemy_store_behaves_as_expected_with_inmemory_sqlite_db", + "tests/override_test_sqlalchemy_store.py", + ), + ): + func_name = func_to_patch.rsplit(".", 1)[1] + new_func_file = ( + pathlib.Path(__file__).parent.joinpath(new_func_file_relative).resolve().as_posix() + ) + + new_func = load_new_function(new_func_file, func_name) + + _logger.info(f"Patching function: {func_to_patch}") + patch(func_to_patch, new_func).start() diff --git a/go.mod b/go.mod index 25ad617..44ef743 100644 --- a/go.mod +++ b/go.mod @@ -1,62 +1,62 @@ -module github.com/mlflow/mlflow-go - -go 1.22 - -require ( - github.com/DATA-DOG/go-sqlmock v1.5.2 - github.com/go-playground/validator/v10 v10.20.0 - github.com/gofiber/fiber/v2 v2.52.4 - github.com/google/uuid v1.6.0 - github.com/iancoleman/strcase v0.3.0 - github.com/magefile/mage v1.15.0 - github.com/sirupsen/logrus v1.9.3 - github.com/stretchr/testify v1.8.4 - github.com/tidwall/gjson v1.17.1 - golang.org/x/sys v0.20.0 - google.golang.org/protobuf v1.34.1 - gorm.io/driver/mysql 
v1.5.6 - gorm.io/driver/postgres v1.5.7 - gorm.io/driver/sqlite v1.5.6 - gorm.io/driver/sqlserver v1.5.3 - gorm.io/gorm v1.25.10 -) - -require ( - github.com/andybalholm/brotli v1.1.0 // indirect - github.com/davecgh/go-spew v1.1.1 // indirect - github.com/gabriel-vasile/mimetype v1.4.3 // indirect - github.com/go-playground/locales v0.14.1 // indirect - github.com/go-playground/universal-translator v0.18.1 // indirect - github.com/go-sql-driver/mysql v1.7.0 // indirect - github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect - github.com/golang-sql/sqlexp v0.1.0 // indirect - github.com/jackc/pgpassfile v1.0.0 // indirect - github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect - github.com/jackc/pgx/v5 v5.5.5 // indirect - github.com/jackc/puddle/v2 v2.2.1 // indirect - github.com/jinzhu/inflection v1.0.0 // indirect - github.com/jinzhu/now v1.1.5 // indirect - github.com/klauspost/compress v1.17.8 // indirect - github.com/kr/text v0.2.0 // indirect - github.com/leodido/go-urn v1.4.0 // indirect - github.com/mattn/go-colorable v0.1.13 // indirect - github.com/mattn/go-isatty v0.0.20 // indirect - github.com/mattn/go-runewidth v0.0.15 // indirect - github.com/mattn/go-sqlite3 v1.14.22 // indirect - github.com/microsoft/go-mssqldb v1.6.0 // indirect - github.com/olekukonko/tablewriter v0.0.5 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/rivo/uniseg v0.4.7 // indirect - github.com/rogpeppe/go-internal v1.12.0 // indirect - github.com/stretchr/objx v0.5.0 // indirect - github.com/tidwall/match v1.1.1 // indirect - github.com/tidwall/pretty v1.2.0 // indirect - github.com/valyala/bytebufferpool v1.0.0 // indirect - github.com/valyala/fasthttp v1.53.0 // indirect - github.com/valyala/tcplisten v1.0.0 // indirect - golang.org/x/crypto v0.23.0 // indirect - golang.org/x/net v0.25.0 // indirect - golang.org/x/sync v0.7.0 // indirect - golang.org/x/text v0.15.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // 
indirect -) +module github.com/mlflow/mlflow-go + +go 1.22 + +require ( + github.com/DATA-DOG/go-sqlmock v1.5.2 + github.com/go-playground/validator/v10 v10.20.0 + github.com/gofiber/fiber/v2 v2.52.4 + github.com/google/uuid v1.6.0 + github.com/iancoleman/strcase v0.3.0 + github.com/magefile/mage v1.15.0 + github.com/sirupsen/logrus v1.9.3 + github.com/stretchr/testify v1.8.4 + github.com/tidwall/gjson v1.17.1 + golang.org/x/sys v0.20.0 + google.golang.org/protobuf v1.34.1 + gorm.io/driver/mysql v1.5.6 + gorm.io/driver/postgres v1.5.7 + gorm.io/driver/sqlite v1.5.6 + gorm.io/driver/sqlserver v1.5.3 + gorm.io/gorm v1.25.10 +) + +require ( + github.com/andybalholm/brotli v1.1.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/gabriel-vasile/mimetype v1.4.3 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/go-sql-driver/mysql v1.7.0 // indirect + github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect + github.com/golang-sql/sqlexp v0.1.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect + github.com/jackc/pgx/v5 v5.5.5 // indirect + github.com/jackc/puddle/v2 v2.2.1 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + github.com/klauspost/compress v1.17.8 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/leodido/go-urn v1.4.0 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-runewidth v0.0.15 // indirect + github.com/mattn/go-sqlite3 v1.14.22 // indirect + github.com/microsoft/go-mssqldb v1.6.0 // indirect + github.com/olekukonko/tablewriter v0.0.5 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rivo/uniseg v0.4.7 // indirect + github.com/rogpeppe/go-internal 
v1.12.0 // indirect + github.com/stretchr/objx v0.5.0 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.0 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasthttp v1.53.0 // indirect + github.com/valyala/tcplisten v1.0.0 // indirect + golang.org/x/crypto v0.23.0 // indirect + golang.org/x/net v0.25.0 // indirect + golang.org/x/sync v0.7.0 // indirect + golang.org/x/text v0.15.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum index 16c2faf..f58db98 100644 --- a/go.sum +++ b/go.sum @@ -1,236 +1,236 @@ -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.4.0/go.mod h1:ON4tFdPTwRcgWEaVDrN3584Ef+b7GgSJaXxe5fW9t4M= -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q= -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.1/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q= -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.1 h1:/iHxaJhsFr0+xVFfbMr5vxz848jyiWuIEDhYq3y5odY= -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.1/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q= -github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0 h1:vcYCAze6p19qBW7MhZybIsqD8sMV8js0NyQM8JDnVtg= -github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0/go.mod h1:OQeznEEkTZ9OrhHJoDD8ZDq51FHgXjqtP9z6bEwBq9U= -github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2/go.mod h1:eWRD7oawr1Mu1sLCawqVc0CUiF43ia3qQMxLscsKQ9w= -github.com/Azure/azure-sdk-for-go/sdk/internal v1.2.0/go.mod h1:eWRD7oawr1Mu1sLCawqVc0CUiF43ia3qQMxLscsKQ9w= -github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0 h1:sXr+ck84g/ZlZUOZiNELInmMgOsuGwdjjVkEIde0OtY= -github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0/go.mod h1:okt5dMMTOFjX/aovMlrjvvXoPMBVSPzk9185BT0+eZM= -github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.0.0 h1:yfJe15aSwEQ6Oo6J+gdfdulPNoZ3TEhmbhLIoxZcA+U= -github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.0.0/go.mod 
h1:Q28U+75mpCaSCDowNEmhIo/rmgdkqmkmzI7N6TGR4UY= -github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v0.8.0 h1:T028gtTPiYt/RMUfs8nVsAL7FDQrfLlrm/NnRG/zcC4= -github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v0.8.0/go.mod h1:cw4zVQgBby0Z5f2v0itn6se2dDP17nTjbZFXW5uPyHA= -github.com/AzureAD/microsoft-authentication-library-for-go v1.0.0/go.mod h1:kgDmCTgBzIEPFElEF+FK0SdjAor06dRq2Go927dnQ6o= -github.com/AzureAD/microsoft-authentication-library-for-go v1.1.0 h1:HCc0+LpPfpCKs6LGGLAhwBARt9632unrVcI6i8s/8os= -github.com/AzureAD/microsoft-authentication-library-for-go v1.1.0/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI= -github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= -github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= -github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M= -github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dnaeon/go-vcr v1.1.0/go.mod h1:M7tiix8f0r6mKKJ3Yq/kqU1OYf3MnfmBWVbPx/yU9ko= -github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ= -github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= -github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= -github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= -github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= -github.com/go-playground/locales v0.14.1 
h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= -github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= -github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= -github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= -github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8= -github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= -github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= -github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= -github.com/gofiber/fiber/v2 v2.52.4 h1:P+T+4iK7VaqUsq2PALYEfBBo6bJZ4q3FP8cZ84EggTM= -github.com/gofiber/fiber/v2 v2.52.4/go.mod h1:KEOE+cXMhXG0zHc9d8+E38hoX+ZN7bhOtgeF2oT6jrQ= -github.com/golang-jwt/jwt/v4 v4.4.3/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= -github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= -github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE= -github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= -github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA= -github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= -github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A= -github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI= -github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/google/uuid v1.6.0 
h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= -github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= -github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= -github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= -github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= -github.com/iancoleman/strcase v0.3.0 h1:nTXanmYxhfFAMjZL34Ov6gkzEsSJZ5DbhxWjvSASxEI= -github.com/iancoleman/strcase v0.3.0/go.mod h1:iwCmte+B7n89clKwxIoIXy/HfoL7AsD47ZCWhYzw7ho= -github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= -github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA= -github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw= -github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= -github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= -github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= -github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs= -github.com/jcmturner/dnsutils/v2 v2.0.0/go.mod h1:b0TnjGOvI/n42bZa+hmXL+kFJZsFT7G4t3HTlQ184QM= -github.com/jcmturner/gofork v1.7.6/go.mod h1:1622LH6i/EZqLloHfE7IeZ0uEJwMSUyQ/nDd82IeqRo= -github.com/jcmturner/goidentity/v6 v6.0.1/go.mod h1:X1YW3bgtvwAXju7V3LCIMpY0Gbxyjn/mY9zx4tFonSg= -github.com/jcmturner/gokrb5/v8 v8.4.4/go.mod h1:1btQEpgT6k+unzCwX1KdWMEwPPkkgBtP+F6aCACiMrs= -github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc= 
-github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= -github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= -github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= -github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= -github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= -github.com/klauspost/compress v1.17.8 h1:YcnTYrq7MikUT7k0Yb5eceMmALQPYBW/Xltxn0NAMnU= -github.com/klauspost/compress v1.17.8/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= -github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= -github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= -github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= -github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= -github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= -github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= -github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= -github.com/magefile/mage v1.15.0 h1:BvGheCMAsG3bWUDbZ8AyXXpCNwU9u5CB6sM+HNb9HYg= -github.com/magefile/mage v1.15.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= -github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= -github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= -github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= -github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= -github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-runewidth v0.0.9/go.mod 
h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= -github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= -github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= -github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= -github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/microsoft/go-mssqldb v1.6.0 h1:mM3gYdVwEPFrlg/Dvr2DNVEgYFG7L42l+dGc67NNNpc= -github.com/microsoft/go-mssqldb v1.6.0/go.mod h1:00mDtPbeQCRGC1HwOOR5K/gr30P1NcEG0vx6Kbv2aJU= -github.com/modocache/gover v0.0.0-20171022184752-b58185e213c5/go.mod h1:caMODM3PzxT8aQXRPkAt8xlV/e7d7w8GM5g0fa5F0D8= -github.com/montanaflynn/stats v0.7.0/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow= -github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= -github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= -github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 h1:KoWmjvw+nsYOo29YJK9vDA65RGE3NrOnUtO7a+RF9HU= -github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= -github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= -github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= -github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= -github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= -github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= -github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= 
-github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= -github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/tidwall/gjson v1.17.1 h1:wlYEnwqAHgzmhNUFfw7Xalt2JzQvsMx2Se4PcoFCT/U= -github.com/tidwall/gjson v1.17.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= -github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= -github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= -github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= -github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= -github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -github.com/valyala/fasthttp v1.53.0 h1:lW/+SUkOxCx2vlIu0iaImv4JLrVRnbbkpCoaawvA4zc= -github.com/valyala/fasthttp v1.53.0/go.mod 
h1:6dt4/8olwq9QARP/TDuPmWyWcl4byhpvTJ4AAtcz+QM= -github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8= -github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc= -github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= -golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= -golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= -golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= -golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= -golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= 
-golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= -golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= -golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= -golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 
-golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= -golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= -golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= -golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= -golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= -golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools 
v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= -google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= -gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gorm.io/driver/mysql v1.5.6 h1:Ld4mkIickM+EliaQZQx3uOJDJHtrd70MxAUqWqlx3Y8= -gorm.io/driver/mysql v1.5.6/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM= -gorm.io/driver/postgres v1.5.7 h1:8ptbNJTDbEmhdr62uReG5BGkdQyeasu/FZHxI0IMGnM= -gorm.io/driver/postgres v1.5.7/go.mod h1:3e019WlBaYI5o5LIdNV+LyxCMNtLOQETBXL2h4chKpA= 
-gorm.io/driver/sqlite v1.5.6 h1:fO/X46qn5NUEEOZtnjJRWRzZMe8nqJiQ9E+0hi+hKQE= -gorm.io/driver/sqlite v1.5.6/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDah4= -gorm.io/driver/sqlserver v1.5.3 h1:rjupPS4PVw+rjJkfvr8jn2lJ8BMhT4UW5FwuJY0P3Z0= -gorm.io/driver/sqlserver v1.5.3/go.mod h1:B+CZ0/7oFJ6tAlefsKoyxdgDCXJKSgwS2bMOQZT0I00= -gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= -gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= -gorm.io/gorm v1.25.10 h1:dQpO+33KalOA+aFYGlK+EfxcI5MbO7EP2yYygwh9h+s= -gorm.io/gorm v1.25.10/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.4.0/go.mod h1:ON4tFdPTwRcgWEaVDrN3584Ef+b7GgSJaXxe5fW9t4M= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.1/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.1 h1:/iHxaJhsFr0+xVFfbMr5vxz848jyiWuIEDhYq3y5odY= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.1/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0 h1:vcYCAze6p19qBW7MhZybIsqD8sMV8js0NyQM8JDnVtg= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0/go.mod h1:OQeznEEkTZ9OrhHJoDD8ZDq51FHgXjqtP9z6bEwBq9U= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2/go.mod h1:eWRD7oawr1Mu1sLCawqVc0CUiF43ia3qQMxLscsKQ9w= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.2.0/go.mod h1:eWRD7oawr1Mu1sLCawqVc0CUiF43ia3qQMxLscsKQ9w= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0 h1:sXr+ck84g/ZlZUOZiNELInmMgOsuGwdjjVkEIde0OtY= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0/go.mod h1:okt5dMMTOFjX/aovMlrjvvXoPMBVSPzk9185BT0+eZM= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.0.0 h1:yfJe15aSwEQ6Oo6J+gdfdulPNoZ3TEhmbhLIoxZcA+U= 
+github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.0.0/go.mod h1:Q28U+75mpCaSCDowNEmhIo/rmgdkqmkmzI7N6TGR4UY= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v0.8.0 h1:T028gtTPiYt/RMUfs8nVsAL7FDQrfLlrm/NnRG/zcC4= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v0.8.0/go.mod h1:cw4zVQgBby0Z5f2v0itn6se2dDP17nTjbZFXW5uPyHA= +github.com/AzureAD/microsoft-authentication-library-for-go v1.0.0/go.mod h1:kgDmCTgBzIEPFElEF+FK0SdjAor06dRq2Go927dnQ6o= +github.com/AzureAD/microsoft-authentication-library-for-go v1.1.0 h1:HCc0+LpPfpCKs6LGGLAhwBARt9632unrVcI6i8s/8os= +github.com/AzureAD/microsoft-authentication-library-for-go v1.1.0/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI= +github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= +github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= +github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M= +github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dnaeon/go-vcr v1.1.0/go.mod h1:M7tiix8f0r6mKKJ3Yq/kqU1OYf3MnfmBWVbPx/yU9ko= +github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ= +github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= +github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= 
+github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8= +github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= +github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= +github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/gofiber/fiber/v2 v2.52.4 h1:P+T+4iK7VaqUsq2PALYEfBBo6bJZ4q3FP8cZ84EggTM= +github.com/gofiber/fiber/v2 v2.52.4/go.mod h1:KEOE+cXMhXG0zHc9d8+E38hoX+ZN7bhOtgeF2oT6jrQ= +github.com/golang-jwt/jwt/v4 v4.4.3/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE= +github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA= +github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= +github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A= +github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI= +github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 
+github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= +github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= +github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/iancoleman/strcase v0.3.0 h1:nTXanmYxhfFAMjZL34Ov6gkzEsSJZ5DbhxWjvSASxEI= +github.com/iancoleman/strcase v0.3.0/go.mod h1:iwCmte+B7n89clKwxIoIXy/HfoL7AsD47ZCWhYzw7ho= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA= +github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw= +github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= +github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= +github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs= +github.com/jcmturner/dnsutils/v2 v2.0.0/go.mod h1:b0TnjGOvI/n42bZa+hmXL+kFJZsFT7G4t3HTlQ184QM= +github.com/jcmturner/gofork v1.7.6/go.mod h1:1622LH6i/EZqLloHfE7IeZ0uEJwMSUyQ/nDd82IeqRo= +github.com/jcmturner/goidentity/v6 v6.0.1/go.mod h1:X1YW3bgtvwAXju7V3LCIMpY0Gbxyjn/mY9zx4tFonSg= +github.com/jcmturner/gokrb5/v8 v8.4.4/go.mod h1:1btQEpgT6k+unzCwX1KdWMEwPPkkgBtP+F6aCACiMrs= +github.com/jcmturner/rpc/v2 v2.0.3/go.mod 
h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= +github.com/klauspost/compress v1.17.8 h1:YcnTYrq7MikUT7k0Yb5eceMmALQPYBW/Xltxn0NAMnU= +github.com/klauspost/compress v1.17.8/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= +github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= +github.com/magefile/mage v1.15.0 h1:BvGheCMAsG3bWUDbZ8AyXXpCNwU9u5CB6sM+HNb9HYg= +github.com/magefile/mage v1.15.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= 
+github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= +github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= +github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/microsoft/go-mssqldb v1.6.0 h1:mM3gYdVwEPFrlg/Dvr2DNVEgYFG7L42l+dGc67NNNpc= +github.com/microsoft/go-mssqldb v1.6.0/go.mod h1:00mDtPbeQCRGC1HwOOR5K/gr30P1NcEG0vx6Kbv2aJU= +github.com/modocache/gover v0.0.0-20171022184752-b58185e213c5/go.mod h1:caMODM3PzxT8aQXRPkAt8xlV/e7d7w8GM5g0fa5F0D8= +github.com/montanaflynn/stats v0.7.0/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow= +github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= +github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= +github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 h1:KoWmjvw+nsYOo29YJK9vDA65RGE3NrOnUtO7a+RF9HU= +github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= +github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= +github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod 
h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/tidwall/gjson v1.17.1 h1:wlYEnwqAHgzmhNUFfw7Xalt2JzQvsMx2Se4PcoFCT/U= +github.com/tidwall/gjson v1.17.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.53.0 h1:lW/+SUkOxCx2vlIu0iaImv4JLrVRnbbkpCoaawvA4zc= 
+github.com/valyala/fasthttp v1.53.0/go.mod h1:6dt4/8olwq9QARP/TDuPmWyWcl4byhpvTJ4AAtcz+QM= +github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8= +github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= +golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= +golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= +golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= +golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= +golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod 
h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= +golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= +golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod 
h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= +golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= +golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= +golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= +golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod 
h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= +google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/mysql v1.5.6 h1:Ld4mkIickM+EliaQZQx3uOJDJHtrd70MxAUqWqlx3Y8= +gorm.io/driver/mysql v1.5.6/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM= +gorm.io/driver/postgres v1.5.7 h1:8ptbNJTDbEmhdr62uReG5BGkdQyeasu/FZHxI0IMGnM= +gorm.io/driver/postgres 
v1.5.7/go.mod h1:3e019WlBaYI5o5LIdNV+LyxCMNtLOQETBXL2h4chKpA= +gorm.io/driver/sqlite v1.5.6 h1:fO/X46qn5NUEEOZtnjJRWRzZMe8nqJiQ9E+0hi+hKQE= +gorm.io/driver/sqlite v1.5.6/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDah4= +gorm.io/driver/sqlserver v1.5.3 h1:rjupPS4PVw+rjJkfvr8jn2lJ8BMhT4UW5FwuJY0P3Z0= +gorm.io/driver/sqlserver v1.5.3/go.mod h1:B+CZ0/7oFJ6tAlefsKoyxdgDCXJKSgwS2bMOQZT0I00= +gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= +gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= +gorm.io/gorm v1.25.10 h1:dQpO+33KalOA+aFYGlK+EfxcI5MbO7EP2yYygwh9h+s= +gorm.io/gorm v1.25.10/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= diff --git a/magefiles/dev.go b/magefiles/dev.go index 1012c84..9b6d739 100644 --- a/magefiles/dev.go +++ b/magefiles/dev.go @@ -1,28 +1,28 @@ -//go:build mage - -//nolint:wrapcheck -package main - -import ( - "github.com/magefile/mage/mg" - "github.com/magefile/mage/sh" -) - -// Start the mlflow-go dev server connecting to postgres. -func Dev() error { - mg.Deps(Generate) - - envs := make(map[string]string) - envs["MLFLOW_TRUNCATE_LONG_VALUES"] = "false" - envs["MLFLOW_SQLALCHEMYSTORE_ECHO"] = "true" - - return sh.RunWithV( - envs, - "mlflow-go", - "server", - "--backend-store-uri", - "postgresql://postgres:postgres@localhost:5432/postgres", - "--go-opts", - "log_level=debug,shutdown_timeout=5s", - ) -} +//go:build mage + +//nolint:wrapcheck +package main + +import ( + "github.com/magefile/mage/mg" + "github.com/magefile/mage/sh" +) + +// Start the mlflow-go dev server connecting to postgres. 
+func Dev() error { + mg.Deps(Generate) + + envs := make(map[string]string) + envs["MLFLOW_TRUNCATE_LONG_VALUES"] = "false" + envs["MLFLOW_SQLALCHEMYSTORE_ECHO"] = "true" + + return sh.RunWithV( + envs, + "mlflow-go", + "server", + "--backend-store-uri", + "postgresql://postgres:postgres@localhost:5432/postgres", + "--go-opts", + "log_level=debug,shutdown_timeout=5s", + ) +} diff --git a/magefiles/endpoints.go b/magefiles/endpoints.go index a7f3163..5239cac 100644 --- a/magefiles/endpoints.go +++ b/magefiles/endpoints.go @@ -1,56 +1,56 @@ -//go:build mage - -//nolint:wrapcheck -package main - -import ( - "os" - - "github.com/olekukonko/tablewriter" - - "github.com/mlflow/mlflow-go/magefiles/generate" - "github.com/mlflow/mlflow-go/magefiles/generate/discovery" -) - -func contains(slice []string, value string) bool { - for _, v := range slice { - if v == value { - return true - } - } - - return false -} - -// Print an overview of implementated API endpoints. -func Endpoints() error { - services, err := discovery.GetServiceInfos() - if err != nil { - return err - } - - table := tablewriter.NewWriter(os.Stdout) - table.SetHeader([]string{"Service", "Endpoint", "Implemented"}) - table.SetColumnAlignment([]int{tablewriter.ALIGN_LEFT, tablewriter.ALIGN_LEFT, tablewriter.ALIGN_CENTER}) - table.SetRowLine(true) - - for _, service := range services { - servinceInfo, ok := generate.ServiceInfoMap[service.Name] - if !ok { - continue - } - - for _, method := range service.Methods { - implemented := "No" - if contains(servinceInfo.ImplementedEndpoints, method.Name) { - implemented = "Yes" - } - - table.Append([]string{service.Name, method.Name, implemented}) - } - } - - table.Render() - - return nil -} +//go:build mage + +//nolint:wrapcheck +package main + +import ( + "os" + + "github.com/olekukonko/tablewriter" + + "github.com/mlflow/mlflow-go/magefiles/generate" + "github.com/mlflow/mlflow-go/magefiles/generate/discovery" +) + +func contains(slice []string, value string) bool 
{ + for _, v := range slice { + if v == value { + return true + } + } + + return false +} + +// Print an overview of implementated API endpoints. +func Endpoints() error { + services, err := discovery.GetServiceInfos() + if err != nil { + return err + } + + table := tablewriter.NewWriter(os.Stdout) + table.SetHeader([]string{"Service", "Endpoint", "Implemented"}) + table.SetColumnAlignment([]int{tablewriter.ALIGN_LEFT, tablewriter.ALIGN_LEFT, tablewriter.ALIGN_CENTER}) + table.SetRowLine(true) + + for _, service := range services { + servinceInfo, ok := generate.ServiceInfoMap[service.Name] + if !ok { + continue + } + + for _, method := range service.Methods { + implemented := "No" + if contains(servinceInfo.ImplementedEndpoints, method.Name) { + implemented = "Yes" + } + + table.Append([]string{service.Name, method.Name, implemented}) + } + } + + table.Render() + + return nil +} diff --git a/magefiles/generate.go b/magefiles/generate.go index 2ed8537..7b22a58 100644 --- a/magefiles/generate.go +++ b/magefiles/generate.go @@ -1,45 +1,45 @@ -//go:build mage - -//nolint:wrapcheck -package main - -import ( - "path" - "path/filepath" - - "github.com/gofiber/fiber/v2/log" - "github.com/magefile/mage/mg" - - "github.com/mlflow/mlflow-go/magefiles/generate" -) - -// Generate Go files based on proto files and other configuration. 
-func Generate() error { - mg.Deps(Repo.Init) - - protoFolder, err := filepath.Abs(path.Join(MLFlowRepoFolderName, "mlflow", "protos")) - if err != nil { - return err - } - - if err := generate.RunProtoc(protoFolder); err != nil { - return err - } - - pkgFolder, err := filepath.Abs("pkg") - if err != nil { - return err - } - - if err := generate.AddQueryAnnotations(pkgFolder); err != nil { - return err - } - - if err := generate.SourceCode(pkgFolder); err != nil { - return err - } - - log.Info("Successfully added query annotations and generated services!") - - return nil -} +//go:build mage + +//nolint:wrapcheck +package main + +import ( + "path" + "path/filepath" + + "github.com/gofiber/fiber/v2/log" + "github.com/magefile/mage/mg" + + "github.com/mlflow/mlflow-go/magefiles/generate" +) + +// Generate Go files based on proto files and other configuration. +func Generate() error { + mg.Deps(Repo.Init) + + protoFolder, err := filepath.Abs(path.Join(MLFlowRepoFolderName, "mlflow", "protos")) + if err != nil { + return err + } + + if err := generate.RunProtoc(protoFolder); err != nil { + return err + } + + pkgFolder, err := filepath.Abs("pkg") + if err != nil { + return err + } + + if err := generate.AddQueryAnnotations(pkgFolder); err != nil { + return err + } + + if err := generate.SourceCode(pkgFolder); err != nil { + return err + } + + log.Info("Successfully added query annotations and generated services!") + + return nil +} diff --git a/magefiles/generate/ast_creation.go b/magefiles/generate/ast_creation.go index d40875b..1c7d048 100644 --- a/magefiles/generate/ast_creation.go +++ b/magefiles/generate/ast_creation.go @@ -1,102 +1,102 @@ -package generate - -import ( - "go/ast" - "go/token" -) - -func mkImportSpec(value string) *ast.ImportSpec { - return &ast.ImportSpec{Path: &ast.BasicLit{Kind: token.STRING, Value: value}} -} - -func mkImportStatements(importStatements ...string) ast.Decl { - specs := make([]ast.Spec, 0, len(importStatements)) - - for _, 
importStatement := range importStatements { - specs = append(specs, mkImportSpec(importStatement)) - } - - return &ast.GenDecl{ - Tok: token.IMPORT, - Specs: specs, - } -} - -func mkStarExpr(e ast.Expr) *ast.StarExpr { - return &ast.StarExpr{ - X: e, - } -} - -func mkSelectorExpr(x, sel string) *ast.SelectorExpr { - return &ast.SelectorExpr{X: ast.NewIdent(x), Sel: ast.NewIdent(sel)} -} - -func mkNamedField(name string, typ ast.Expr) *ast.Field { - return &ast.Field{ - Names: []*ast.Ident{ast.NewIdent(name)}, - Type: typ, - } -} - -func mkField(typ ast.Expr) *ast.Field { - return &ast.Field{ - Type: typ, - } -} - -// fun(arg1, arg2, ...) -func mkCallExpr(fun ast.Expr, args ...ast.Expr) *ast.CallExpr { - return &ast.CallExpr{ - Fun: fun, - Args: args, - } -} - -// Shorthand for creating &expr. -func mkAmpExpr(expr ast.Expr) *ast.UnaryExpr { - return &ast.UnaryExpr{ - Op: token.AND, - X: expr, - } -} - -// err != nil. -var errNotEqualNil = &ast.BinaryExpr{ - X: ast.NewIdent("err"), - Op: token.NEQ, - Y: ast.NewIdent("nil"), -} - -// return err. 
-var returnErr = &ast.ReturnStmt{ - Results: []ast.Expr{ast.NewIdent("err")}, -} - -func mkBlockStmt(stmts ...ast.Stmt) *ast.BlockStmt { - return &ast.BlockStmt{ - List: stmts, - } -} - -func mkIfStmt(init ast.Stmt, cond ast.Expr, body *ast.BlockStmt) *ast.IfStmt { - return &ast.IfStmt{ - Init: init, - Cond: cond, - Body: body, - } -} - -func mkAssignStmt(lhs, rhs []ast.Expr) *ast.AssignStmt { - return &ast.AssignStmt{ - Lhs: lhs, - Tok: token.DEFINE, - Rhs: rhs, - } -} - -func mkReturnStmt(results ...ast.Expr) *ast.ReturnStmt { - return &ast.ReturnStmt{ - Results: results, - } -} +package generate + +import ( + "go/ast" + "go/token" +) + +func mkImportSpec(value string) *ast.ImportSpec { + return &ast.ImportSpec{Path: &ast.BasicLit{Kind: token.STRING, Value: value}} +} + +func mkImportStatements(importStatements ...string) ast.Decl { + specs := make([]ast.Spec, 0, len(importStatements)) + + for _, importStatement := range importStatements { + specs = append(specs, mkImportSpec(importStatement)) + } + + return &ast.GenDecl{ + Tok: token.IMPORT, + Specs: specs, + } +} + +func mkStarExpr(e ast.Expr) *ast.StarExpr { + return &ast.StarExpr{ + X: e, + } +} + +func mkSelectorExpr(x, sel string) *ast.SelectorExpr { + return &ast.SelectorExpr{X: ast.NewIdent(x), Sel: ast.NewIdent(sel)} +} + +func mkNamedField(name string, typ ast.Expr) *ast.Field { + return &ast.Field{ + Names: []*ast.Ident{ast.NewIdent(name)}, + Type: typ, + } +} + +func mkField(typ ast.Expr) *ast.Field { + return &ast.Field{ + Type: typ, + } +} + +// fun(arg1, arg2, ...) +func mkCallExpr(fun ast.Expr, args ...ast.Expr) *ast.CallExpr { + return &ast.CallExpr{ + Fun: fun, + Args: args, + } +} + +// Shorthand for creating &expr. +func mkAmpExpr(expr ast.Expr) *ast.UnaryExpr { + return &ast.UnaryExpr{ + Op: token.AND, + X: expr, + } +} + +// err != nil. +var errNotEqualNil = &ast.BinaryExpr{ + X: ast.NewIdent("err"), + Op: token.NEQ, + Y: ast.NewIdent("nil"), +} + +// return err. 
+var returnErr = &ast.ReturnStmt{ + Results: []ast.Expr{ast.NewIdent("err")}, +} + +func mkBlockStmt(stmts ...ast.Stmt) *ast.BlockStmt { + return &ast.BlockStmt{ + List: stmts, + } +} + +func mkIfStmt(init ast.Stmt, cond ast.Expr, body *ast.BlockStmt) *ast.IfStmt { + return &ast.IfStmt{ + Init: init, + Cond: cond, + Body: body, + } +} + +func mkAssignStmt(lhs, rhs []ast.Expr) *ast.AssignStmt { + return &ast.AssignStmt{ + Lhs: lhs, + Tok: token.DEFINE, + Rhs: rhs, + } +} + +func mkReturnStmt(results ...ast.Expr) *ast.ReturnStmt { + return &ast.ReturnStmt{ + Results: results, + } +} diff --git a/magefiles/generate/discovery/discovery.go b/magefiles/generate/discovery/discovery.go index 237f272..18a10f2 100644 --- a/magefiles/generate/discovery/discovery.go +++ b/magefiles/generate/discovery/discovery.go @@ -1,97 +1,97 @@ -package discovery - -import ( - "fmt" - "regexp" - "strings" - - "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/reflect/protoreflect" - - "github.com/mlflow/mlflow-go/pkg/protos" - "github.com/mlflow/mlflow-go/pkg/protos/artifacts" -) - -type ServiceInfo struct { - Name string - Methods []MethodInfo -} - -type MethodInfo struct { - Name string - PackageName string - Input string - Output string - Endpoints []Endpoint -} - -type Endpoint struct { - Method string - Path string -} - -var routeParameterRegex = regexp.MustCompile(`<[^>]+:([^>]+)>`) - -// Get the safe path to use in Fiber registration. 
-func (e Endpoint) GetFiberPath() string { - // e.Path cannot be trusted, it could be something like /mlflow-artifacts/artifacts/ - // Which would need to converted to /mlflow-artifacts/artifacts/:path - path := routeParameterRegex.ReplaceAllStringFunc(e.Path, func(s string) string { - parts := strings.Split(s, ":") - - return ":" + strings.Trim(parts[0], "< ") - }) - - return path -} - -func GetServiceInfos() ([]ServiceInfo, error) { - serviceInfos := make([]ServiceInfo, 0) - - services := []struct { - Name string - PackageName string - Descriptor protoreflect.FileDescriptor - }{ - {"MlflowService", "protos", protos.File_service_proto}, - {"ModelRegistryService", "protos", protos.File_model_registry_proto}, - {"MlflowArtifactsService", "artifacts", artifacts.File_mlflow_artifacts_proto}, - } - - for _, service := range services { - serviceDescriptor := service.Descriptor.Services().ByName(protoreflect.Name(service.Name)) - - if serviceDescriptor == nil { - //nolint:goerr113 - return nil, fmt.Errorf("service %s not found", service.Name) - } - - serviceInfo := ServiceInfo{Name: service.Name, Methods: make([]MethodInfo, 0)} - - methods := serviceDescriptor.Methods() - for mIdx := range methods.Len() { - method := methods.Get(mIdx) - options := method.Options() - extension := proto.GetExtension(options, protos.E_Rpc) - - endpoints := make([]Endpoint, 0) - rpcOptions, ok := extension.(*protos.DatabricksRpcOptions) - - if ok { - for _, endpoint := range rpcOptions.GetEndpoints() { - endpoints = append(endpoints, Endpoint{Method: endpoint.GetMethod(), Path: endpoint.GetPath()}) - } - } - - output := fmt.Sprintf("%s_%s", string(method.Output().Parent().Name()), string(method.Output().Name())) - methodInfo := MethodInfo{ - string(method.Name()), service.PackageName, string(method.Input().Name()), output, endpoints, - } - serviceInfo.Methods = append(serviceInfo.Methods, methodInfo) - } - - serviceInfos = append(serviceInfos, serviceInfo) - } - - return serviceInfos, nil -} 
+package discovery + +import ( + "fmt" + "regexp" + "strings" + + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/protos/artifacts" +) + +type ServiceInfo struct { + Name string + Methods []MethodInfo +} + +type MethodInfo struct { + Name string + PackageName string + Input string + Output string + Endpoints []Endpoint +} + +type Endpoint struct { + Method string + Path string +} + +var routeParameterRegex = regexp.MustCompile(`<[^>]+:([^>]+)>`) + +// Get the safe path to use in Fiber registration. +func (e Endpoint) GetFiberPath() string { + // e.Path cannot be trusted, it could be something like /mlflow-artifacts/artifacts/ + // Which would need to converted to /mlflow-artifacts/artifacts/:path + path := routeParameterRegex.ReplaceAllStringFunc(e.Path, func(s string) string { + parts := strings.Split(s, ":") + + return ":" + strings.Trim(parts[0], "< ") + }) + + return path +} + +func GetServiceInfos() ([]ServiceInfo, error) { + serviceInfos := make([]ServiceInfo, 0) + + services := []struct { + Name string + PackageName string + Descriptor protoreflect.FileDescriptor + }{ + {"MlflowService", "protos", protos.File_service_proto}, + {"ModelRegistryService", "protos", protos.File_model_registry_proto}, + {"MlflowArtifactsService", "artifacts", artifacts.File_mlflow_artifacts_proto}, + } + + for _, service := range services { + serviceDescriptor := service.Descriptor.Services().ByName(protoreflect.Name(service.Name)) + + if serviceDescriptor == nil { + //nolint:goerr113 + return nil, fmt.Errorf("service %s not found", service.Name) + } + + serviceInfo := ServiceInfo{Name: service.Name, Methods: make([]MethodInfo, 0)} + + methods := serviceDescriptor.Methods() + for mIdx := range methods.Len() { + method := methods.Get(mIdx) + options := method.Options() + extension := proto.GetExtension(options, protos.E_Rpc) + + endpoints := make([]Endpoint, 0) + 
rpcOptions, ok := extension.(*protos.DatabricksRpcOptions) + + if ok { + for _, endpoint := range rpcOptions.GetEndpoints() { + endpoints = append(endpoints, Endpoint{Method: endpoint.GetMethod(), Path: endpoint.GetPath()}) + } + } + + output := fmt.Sprintf("%s_%s", string(method.Output().Parent().Name()), string(method.Output().Name())) + methodInfo := MethodInfo{ + string(method.Name()), service.PackageName, string(method.Input().Name()), output, endpoints, + } + serviceInfo.Methods = append(serviceInfo.Methods, methodInfo) + } + + serviceInfos = append(serviceInfos, serviceInfo) + } + + return serviceInfos, nil +} diff --git a/magefiles/generate/discovery/discovery_test.go b/magefiles/generate/discovery/discovery_test.go index f4bdeaf..447f7b7 100644 --- a/magefiles/generate/discovery/discovery_test.go +++ b/magefiles/generate/discovery/discovery_test.go @@ -1,55 +1,55 @@ -package discovery_test - -import ( - "testing" - - "github.com/mlflow/mlflow-go/magefiles/generate/discovery" -) - -func TestPattern(t *testing.T) { - t.Parallel() - - scenarios := []struct { - name string - endpoint discovery.Endpoint - expected string - }{ - { - name: "simple GET", - endpoint: discovery.Endpoint{ - Method: "GET", - Path: "/mlflow/experiments/get-by-name", - }, - expected: "/mlflow/experiments/get-by-name", - }, - { - name: "simple POST", - endpoint: discovery.Endpoint{ - Method: "POST", - Path: "/mlflow/experiments/create", - }, - expected: "/mlflow/experiments/create", - }, - { - name: "PUT with route parameter", - endpoint: discovery.Endpoint{ - Method: "PUT", - Path: "/mlflow-artifacts/artifacts/", - }, - expected: "/mlflow-artifacts/artifacts/:path", - }, - } - - for _, scenario := range scenarios { - currentScenario := scenario - t.Run(currentScenario.name, func(t *testing.T) { - t.Parallel() - - actual := currentScenario.endpoint.GetFiberPath() - - if actual != currentScenario.expected { - t.Errorf("Expected %s, got %s", currentScenario.expected, actual) - } - }) - } 
-} +package discovery_test + +import ( + "testing" + + "github.com/mlflow/mlflow-go/magefiles/generate/discovery" +) + +func TestPattern(t *testing.T) { + t.Parallel() + + scenarios := []struct { + name string + endpoint discovery.Endpoint + expected string + }{ + { + name: "simple GET", + endpoint: discovery.Endpoint{ + Method: "GET", + Path: "/mlflow/experiments/get-by-name", + }, + expected: "/mlflow/experiments/get-by-name", + }, + { + name: "simple POST", + endpoint: discovery.Endpoint{ + Method: "POST", + Path: "/mlflow/experiments/create", + }, + expected: "/mlflow/experiments/create", + }, + { + name: "PUT with route parameter", + endpoint: discovery.Endpoint{ + Method: "PUT", + Path: "/mlflow-artifacts/artifacts/", + }, + expected: "/mlflow-artifacts/artifacts/:path", + }, + } + + for _, scenario := range scenarios { + currentScenario := scenario + t.Run(currentScenario.name, func(t *testing.T) { + t.Parallel() + + actual := currentScenario.endpoint.GetFiberPath() + + if actual != currentScenario.expected { + t.Errorf("Expected %s, got %s", currentScenario.expected, actual) + } + }) + } +} diff --git a/magefiles/generate/endpoints.go b/magefiles/generate/endpoints.go index 63b22dc..e0f5989 100644 --- a/magefiles/generate/endpoints.go +++ b/magefiles/generate/endpoints.go @@ -1,87 +1,87 @@ -package generate - -type ServiceGenerationInfo struct { - FileNameWithoutExtension string - ServiceName string - ImplementedEndpoints []string -} - -var ServiceInfoMap = map[string]ServiceGenerationInfo{ - "MlflowService": { - FileNameWithoutExtension: "tracking", - ServiceName: "TrackingService", - ImplementedEndpoints: []string{ - "getExperimentByName", - "createExperiment", - // "searchExperiments", - "getExperiment", - "deleteExperiment", - "restoreExperiment", - "updateExperiment", - "getRun", - "createRun", - "updateRun", - "deleteRun", - "restoreRun", - "logMetric", - // "logParam", - // "setExperimentTag", - "setTag", - // "setTraceTag", - // "deleteTraceTag", - 
"deleteTag", - "searchRuns", - // "listArtifacts", - // "getMetricHistory", - // "getMetricHistoryBulkInterval", - "logBatch", - // "logModel", - // "logInputs", - // "startTrace", - // "endTrace", - // "getTraceInfo", - // "searchTraces", - // "deleteTraces", - }, - }, - "ModelRegistryService": { - FileNameWithoutExtension: "model_registry", - ServiceName: "ModelRegistryService", - ImplementedEndpoints: []string{ - // "createRegisteredModel", - // "renameRegisteredModel", - // "updateRegisteredModel", - // "deleteRegisteredModel", - // "getRegisteredModel", - // "searchRegisteredModels", - "getLatestVersions", - // "createModelVersion", - // "updateModelVersion", - // "transitionModelVersionStage", - // "deleteModelVersion", - // "getModelVersion", - // "searchModelVersions", - // "getModelVersionDownloadUri", - // "setRegisteredModelTag", - // "setModelVersionTag", - // "deleteRegisteredModelTag", - // "deleteModelVersionTag", - // "setRegisteredModelAlias", - // "deleteRegisteredModelAlias", - // "getModelVersionByAlias", - }, - }, - "MlflowArtifactsService": { - FileNameWithoutExtension: "artifacts", - ServiceName: "ArtifactsService", - ImplementedEndpoints: []string{ - // "downloadArtifact", - // "uploadArtifact", - // "listArtifacts", - // "deleteArtifact", - // "createMultipartUpload", - // "completeMultipartUpload", - // "abortMultipartUpload", - }, - }, -} +package generate + +type ServiceGenerationInfo struct { + FileNameWithoutExtension string + ServiceName string + ImplementedEndpoints []string +} + +var ServiceInfoMap = map[string]ServiceGenerationInfo{ + "MlflowService": { + FileNameWithoutExtension: "tracking", + ServiceName: "TrackingService", + ImplementedEndpoints: []string{ + "getExperimentByName", + "createExperiment", + // "searchExperiments", + "getExperiment", + "deleteExperiment", + "restoreExperiment", + "updateExperiment", + "getRun", + "createRun", + "updateRun", + "deleteRun", + "restoreRun", + "logMetric", + // "logParam", + // 
"setExperimentTag", + "setTag", + // "setTraceTag", + // "deleteTraceTag", + "deleteTag", + "searchRuns", + // "listArtifacts", + // "getMetricHistory", + // "getMetricHistoryBulkInterval", + "logBatch", + // "logModel", + // "logInputs", + // "startTrace", + // "endTrace", + // "getTraceInfo", + // "searchTraces", + // "deleteTraces", + }, + }, + "ModelRegistryService": { + FileNameWithoutExtension: "model_registry", + ServiceName: "ModelRegistryService", + ImplementedEndpoints: []string{ + // "createRegisteredModel", + // "renameRegisteredModel", + // "updateRegisteredModel", + // "deleteRegisteredModel", + // "getRegisteredModel", + // "searchRegisteredModels", + "getLatestVersions", + // "createModelVersion", + // "updateModelVersion", + // "transitionModelVersionStage", + // "deleteModelVersion", + // "getModelVersion", + // "searchModelVersions", + // "getModelVersionDownloadUri", + // "setRegisteredModelTag", + // "setModelVersionTag", + // "deleteRegisteredModelTag", + // "deleteModelVersionTag", + // "setRegisteredModelAlias", + // "deleteRegisteredModelAlias", + // "getModelVersionByAlias", + }, + }, + "MlflowArtifactsService": { + FileNameWithoutExtension: "artifacts", + ServiceName: "ArtifactsService", + ImplementedEndpoints: []string{ + // "downloadArtifact", + // "uploadArtifact", + // "listArtifacts", + // "deleteArtifact", + // "createMultipartUpload", + // "completeMultipartUpload", + // "abortMultipartUpload", + }, + }, +} diff --git a/magefiles/generate/protos.go b/magefiles/generate/protos.go index f26d4a3..5586e36 100644 --- a/magefiles/generate/protos.go +++ b/magefiles/generate/protos.go @@ -1,58 +1,58 @@ -package generate - -import ( - "fmt" - "os/exec" - "path" - "strings" -) - -const MLFlowCommit = "3effa7380c86946f4557f03aa81119a097d8b433" - -var protoFiles = map[string]string{ - "databricks.proto": "github.com/mlflow/mlflow-go/pkg/protos", - "service.proto": "github.com/mlflow/mlflow-go/pkg/protos", - "model_registry.proto": 
"github.com/mlflow/mlflow-go/pkg/protos", - "databricks_artifacts.proto": "github.com/mlflow/mlflow-go/pkg/protos", - "mlflow_artifacts.proto": "github.com/mlflow/mlflow-go/pkg/protos/artifacts", - "internal.proto": "github.com/mlflow/mlflow-go/pkg/protos", - "scalapb/scalapb.proto": "github.com/mlflow/mlflow-go/pkg/protos/scalapb", -} - -const fixedArguments = 3 - -func RunProtoc(protoDir string) error { - arguments := make([]string, 0, len(protoFiles)*2+fixedArguments) - - arguments = append( - arguments, - "-I="+protoDir, - `--go_out=.`, - `--go_opt=module=github.com/mlflow/mlflow-go`, - ) - - for fileName, goPackage := range protoFiles { - arguments = append( - arguments, - fmt.Sprintf("--go_opt=M%s=%s", fileName, goPackage), - ) - } - - for fileName := range protoFiles { - arguments = append(arguments, path.Join(protoDir, fileName)) - } - - cmd := exec.Command("protoc", arguments...) - - output, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf( - "failed to run protoc %s process, err: %s: %w", - strings.Join(arguments, " "), - output, - err, - ) - } - - return nil -} +package generate + +import ( + "fmt" + "os/exec" + "path" + "strings" +) + +const MLFlowCommit = "3effa7380c86946f4557f03aa81119a097d8b433" + +var protoFiles = map[string]string{ + "databricks.proto": "github.com/mlflow/mlflow-go/pkg/protos", + "service.proto": "github.com/mlflow/mlflow-go/pkg/protos", + "model_registry.proto": "github.com/mlflow/mlflow-go/pkg/protos", + "databricks_artifacts.proto": "github.com/mlflow/mlflow-go/pkg/protos", + "mlflow_artifacts.proto": "github.com/mlflow/mlflow-go/pkg/protos/artifacts", + "internal.proto": "github.com/mlflow/mlflow-go/pkg/protos", + "scalapb/scalapb.proto": "github.com/mlflow/mlflow-go/pkg/protos/scalapb", +} + +const fixedArguments = 3 + +func RunProtoc(protoDir string) error { + arguments := make([]string, 0, len(protoFiles)*2+fixedArguments) + + arguments = append( + arguments, + "-I="+protoDir, + `--go_out=.`, + 
`--go_opt=module=github.com/mlflow/mlflow-go`, + ) + + for fileName, goPackage := range protoFiles { + arguments = append( + arguments, + fmt.Sprintf("--go_opt=M%s=%s", fileName, goPackage), + ) + } + + for fileName := range protoFiles { + arguments = append(arguments, path.Join(protoDir, fileName)) + } + + cmd := exec.Command("protoc", arguments...) + + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf( + "failed to run protoc %s process, err: %s: %w", + strings.Join(arguments, " "), + output, + err, + ) + } + + return nil +} diff --git a/magefiles/generate/query_annotations.go b/magefiles/generate/query_annotations.go index a7c0b41..c661e9b 100644 --- a/magefiles/generate/query_annotations.go +++ b/magefiles/generate/query_annotations.go @@ -1,116 +1,116 @@ -package generate - -import ( - "fmt" - "go/ast" - "go/parser" - "go/token" - "io/fs" - "os" - "path/filepath" - "regexp" - "strings" -) - -// Inspect the AST of the incoming file and add a query annotation to the struct tags which have a json tag. 
-// -//nolint:funlen,cyclop -func addQueryAnnotation(generatedGoFile string) error { - // Parse the file into an AST - fset := token.NewFileSet() - - node, err := parser.ParseFile(fset, generatedGoFile, nil, parser.ParseComments) - if err != nil { - return fmt.Errorf("add query annotation failed: %w", err) - } - - // Create an AST inspector to modify specific struct tags - ast.Inspect(node, func(n ast.Node) bool { - // Look for struct type declarations - typeSpec, isTypeSpec := n.(*ast.TypeSpec) - if !isTypeSpec { - return true - } - - structType, isStructType := typeSpec.Type.(*ast.StructType) - - if !isStructType { - return true - } - - // Iterate over fields in the struct - for _, field := range structType.Fields.List { - if field.Tag == nil { - continue - } - - tagValue := field.Tag.Value - - hasQuery := strings.Contains(tagValue, "query:") - hasValidate := strings.Contains(tagValue, "validate:") - validationKey := fmt.Sprintf("%s_%s", typeSpec.Name, field.Names[0]) - validationRule, needsValidation := validations[validationKey] - - if hasQuery && (!needsValidation || needsValidation && hasValidate) { - continue - } - - // With opening ` tick - newTag := tagValue[0 : len(tagValue)-1] - - matches := jsonFieldTagRegexp.FindStringSubmatch(tagValue) - if len(matches) > 0 && !hasQuery { - // Modify the tag here - // The json annotation could be something like `json:"key,omitempty"` - // We only want the key part, the `omitempty` is not relevant for the query annotation - key := matches[1] - if strings.Contains(key, ",") { - key = strings.Split(key, ",")[0] - } - // Add query annotation - newTag += fmt.Sprintf(" query:\"%s\"", key) - } - - if needsValidation { - // Add validation annotation - newTag += fmt.Sprintf(" validate:\"%s\"", validationRule) - } - - // Closing ` tick - newTag += "`" - field.Tag.Value = newTag - } - - return false - }) - - return saveASTToFile(fset, node, false, generatedGoFile) -} - -var jsonFieldTagRegexp = 
regexp.MustCompile(`json:"([^"]+)"`) - -//nolint:err113 -func AddQueryAnnotations(pkgFolder string) error { - protoFolder := filepath.Join(pkgFolder, "protos") - - if _, pathError := os.Stat(protoFolder); os.IsNotExist(pathError) { - return fmt.Errorf("the %s folder does not exist. Are the Go protobuf files generated?", protoFolder) - } - - err := filepath.WalkDir(protoFolder, func(path string, _ fs.DirEntry, err error) error { - if err != nil { - return err - } - - if filepath.Ext(path) == ".go" { - err = addQueryAnnotation(path) - } - - return err - }) - if err != nil { - return fmt.Errorf("failed to add query annotation: %w", err) - } - - return nil -} +package generate + +import ( + "fmt" + "go/ast" + "go/parser" + "go/token" + "io/fs" + "os" + "path/filepath" + "regexp" + "strings" +) + +// Inspect the AST of the incoming file and add a query annotation to the struct tags which have a json tag. +// +//nolint:funlen,cyclop +func addQueryAnnotation(generatedGoFile string) error { + // Parse the file into an AST + fset := token.NewFileSet() + + node, err := parser.ParseFile(fset, generatedGoFile, nil, parser.ParseComments) + if err != nil { + return fmt.Errorf("add query annotation failed: %w", err) + } + + // Create an AST inspector to modify specific struct tags + ast.Inspect(node, func(n ast.Node) bool { + // Look for struct type declarations + typeSpec, isTypeSpec := n.(*ast.TypeSpec) + if !isTypeSpec { + return true + } + + structType, isStructType := typeSpec.Type.(*ast.StructType) + + if !isStructType { + return true + } + + // Iterate over fields in the struct + for _, field := range structType.Fields.List { + if field.Tag == nil { + continue + } + + tagValue := field.Tag.Value + + hasQuery := strings.Contains(tagValue, "query:") + hasValidate := strings.Contains(tagValue, "validate:") + validationKey := fmt.Sprintf("%s_%s", typeSpec.Name, field.Names[0]) + validationRule, needsValidation := validations[validationKey] + + if hasQuery && (!needsValidation 
|| needsValidation && hasValidate) { + continue + } + + // With opening ` tick + newTag := tagValue[0 : len(tagValue)-1] + + matches := jsonFieldTagRegexp.FindStringSubmatch(tagValue) + if len(matches) > 0 && !hasQuery { + // Modify the tag here + // The json annotation could be something like `json:"key,omitempty"` + // We only want the key part, the `omitempty` is not relevant for the query annotation + key := matches[1] + if strings.Contains(key, ",") { + key = strings.Split(key, ",")[0] + } + // Add query annotation + newTag += fmt.Sprintf(" query:\"%s\"", key) + } + + if needsValidation { + // Add validation annotation + newTag += fmt.Sprintf(" validate:\"%s\"", validationRule) + } + + // Closing ` tick + newTag += "`" + field.Tag.Value = newTag + } + + return false + }) + + return saveASTToFile(fset, node, false, generatedGoFile) +} + +var jsonFieldTagRegexp = regexp.MustCompile(`json:"([^"]+)"`) + +//nolint:err113 +func AddQueryAnnotations(pkgFolder string) error { + protoFolder := filepath.Join(pkgFolder, "protos") + + if _, pathError := os.Stat(protoFolder); os.IsNotExist(pathError) { + return fmt.Errorf("the %s folder does not exist. 
Are the Go protobuf files generated?", protoFolder) + } + + err := filepath.WalkDir(protoFolder, func(path string, _ fs.DirEntry, err error) error { + if err != nil { + return err + } + + if filepath.Ext(path) == ".go" { + err = addQueryAnnotation(path) + } + + return err + }) + if err != nil { + return fmt.Errorf("failed to add query annotation: %w", err) + } + + return nil +} diff --git a/magefiles/generate/source_code.go b/magefiles/generate/source_code.go index 0f15325..7146d2c 100644 --- a/magefiles/generate/source_code.go +++ b/magefiles/generate/source_code.go @@ -1,468 +1,468 @@ -package generate - -import ( - "bufio" - "fmt" - "go/ast" - "go/format" - "go/token" - "net/http" - "os" - "path/filepath" - - "github.com/iancoleman/strcase" - - "github.com/mlflow/mlflow-go/magefiles/generate/discovery" -) - -func mkMethodInfoInputPointerType(methodInfo discovery.MethodInfo) *ast.StarExpr { - return mkStarExpr(mkSelectorExpr(methodInfo.PackageName, methodInfo.Input)) -} - -// Generate a method declaration on an service interface. -func mkServiceInterfaceMethod(methodInfo discovery.MethodInfo) *ast.Field { - return &ast.Field{ - Names: []*ast.Ident{ast.NewIdent(strcase.ToCamel(methodInfo.Name))}, - Type: &ast.FuncType{ - Params: &ast.FieldList{ - List: []*ast.Field{ - mkNamedField("ctx", mkSelectorExpr("context", "Context")), - mkNamedField("input", mkMethodInfoInputPointerType(methodInfo)), - }, - }, - Results: &ast.FieldList{ - List: []*ast.Field{ - mkField(mkStarExpr(mkSelectorExpr(methodInfo.PackageName, methodInfo.Output))), - mkField(mkStarExpr(mkSelectorExpr("contract", "Error"))), - }, - }, - }, - } -} - -// Generate a service interface declaration. 
-func mkServiceInterfaceNode( - endpoints map[string]any, interfaceName string, serviceInfo discovery.ServiceInfo, -) *ast.GenDecl { - // We add one method to validate any of the input structs - methods := make([]*ast.Field, 0, len(serviceInfo.Methods)) - - for _, method := range serviceInfo.Methods { - if _, ok := endpoints[method.Name]; ok { - methods = append(methods, mkServiceInterfaceMethod(method)) - } - } - - // Create an interface declaration - return &ast.GenDecl{ - Tok: token.TYPE, // Specifies a type declaration - Specs: []ast.Spec{ - &ast.TypeSpec{ - Name: ast.NewIdent(interfaceName), - Type: &ast.InterfaceType{ - Methods: &ast.FieldList{ - List: methods, - }, - }, - }, - }, - } -} - -func saveASTToFile(fset *token.FileSet, file *ast.File, addComment bool, outputPath string) error { - // Create or truncate the output file - outputFile, err := os.Create(outputPath) - if err != nil { - return fmt.Errorf("failed to create output file: %w", err) - } - defer outputFile.Close() - - // Use a bufio.Writer for buffered writing (optional) - writer := bufio.NewWriter(outputFile) - defer writer.Flush() - - if addComment { - _, err := writer.WriteString("// Code generated by mlflow/go/cmd/generate/main.go. 
DO NOT EDIT.\n\n") - if err != nil { - return fmt.Errorf("failed to add comment to generated file: %w", err) - } - } - - // Write the generated code to the file - err = format.Node(writer, fset, file) - if err != nil { - return fmt.Errorf("failed to write generated AST to file: %w", err) - } - - return nil -} - -//nolint:funlen -func mkAppRoute(method discovery.MethodInfo, endpoint discovery.Endpoint) ast.Stmt { - urlExpr := &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf(`"%s"`, endpoint.GetFiberPath())} - - // input := &protos.SearchExperiments - inputExpr := mkAssignStmt( - []ast.Expr{ast.NewIdent("input")}, - []ast.Expr{ - mkAmpExpr(&ast.CompositeLit{ - Type: mkSelectorExpr(method.PackageName, method.Input), - }), - }) - - // if err := parser.ParseQuery(ctx, input); err != nil { return err } - // if err := parser.ParseBody(ctx, input); err != nil { return err } - var extractModel ast.Expr - if endpoint.Method == http.MethodGet { - extractModel = mkCallExpr(mkSelectorExpr("parser", "ParseQuery"), ast.NewIdent("ctx"), ast.NewIdent("input")) - } else { - extractModel = mkCallExpr(mkSelectorExpr("parser", "ParseBody"), ast.NewIdent("ctx"), ast.NewIdent("input")) - } - - inputErrorCheck := mkIfStmt( - mkAssignStmt([]ast.Expr{ast.NewIdent("err")}, []ast.Expr{extractModel}), - errNotEqualNil, - mkBlockStmt(returnErr), - ) - - // output, err := service.Method(input) - outputExpr := mkAssignStmt([]ast.Expr{ - ast.NewIdent("output"), - ast.NewIdent("err"), - }, []ast.Expr{ - mkCallExpr( - mkSelectorExpr( - "service", - strcase.ToCamel(method.Name), - ), - mkCallExpr( - mkSelectorExpr("utils", "NewContextWithLoggerFromFiberContext"), - ast.NewIdent("ctx"), - ), - ast.NewIdent("input"), - ), - }) - - // if err != nil { - // return err - // } - errorCheck := mkIfStmt( - nil, - errNotEqualNil, - mkBlockStmt( - mkReturnStmt(ast.NewIdent("err")), - ), - ) - - // return ctx.JSON(output) - returnExpr := mkReturnStmt(mkCallExpr(mkSelectorExpr("ctx", "JSON"), 
ast.NewIdent("output"))) - - // func(ctx *fiber.Ctx) error { .. } - funcExpr := &ast.FuncLit{ - Type: &ast.FuncType{ - Params: &ast.FieldList{ - List: []*ast.Field{ - mkNamedField("ctx", mkStarExpr(mkSelectorExpr("fiber", "Ctx"))), - }, - }, - Results: &ast.FieldList{ - List: []*ast.Field{ - mkField(ast.NewIdent("error")), - }, - }, - }, - Body: &ast.BlockStmt{ - List: []ast.Stmt{ - inputExpr, - inputErrorCheck, - outputExpr, - errorCheck, - returnExpr, - }, - }, - } - - return &ast.ExprStmt{ - // app.Get("/mlflow/experiments/search", func(ctx *fiber.Ctx) error { .. }) - X: mkCallExpr( - mkSelectorExpr("app", strcase.ToCamel(endpoint.Method)), urlExpr, funcExpr, - ), - } -} - -func mkRouteRegistrationFunction( - endpoints map[string]any, interfaceName string, serviceInfo discovery.ServiceInfo, -) *ast.FuncDecl { - routes := make([]ast.Stmt, 0, len(serviceInfo.Methods)) - - for _, method := range serviceInfo.Methods { - for _, endpoint := range method.Endpoints { - if _, ok := endpoints[method.Name]; ok { - routes = append(routes, mkAppRoute(method, endpoint)) - } - } - } - - return &ast.FuncDecl{ - Name: ast.NewIdent(fmt.Sprintf("Register%sRoutes", interfaceName)), - Type: &ast.FuncType{ - Params: &ast.FieldList{ - List: []*ast.Field{ - mkNamedField("service", mkSelectorExpr("service", interfaceName)), - mkNamedField("parser", mkStarExpr(mkSelectorExpr("parser", "HTTPRequestParser"))), - mkNamedField("app", mkStarExpr(ast.NewIdent("fiber.App"))), - }, - }, - }, - Body: &ast.BlockStmt{ - List: routes, - }, - } -} - -func mkGeneratedFile(pkg, outputPath string, decls []ast.Decl) error { - // Set up the FileSet and the AST File - fset := token.NewFileSet() - - file := &ast.File{ - Name: ast.NewIdent(pkg), - Decls: decls, - } - - err := saveASTToFile(fset, file, true, outputPath) - if err != nil { - return fmt.Errorf("failed to save AST to file: %w", err) - } - - return nil -} - -const expectedImportStatements = 2 - -// Generate the service interface. 
-func generateServices( - pkgFolder string, - serviceInfo discovery.ServiceInfo, - generationInfo ServiceGenerationInfo, - endpoints map[string]any, -) error { - decls := make([]ast.Decl, 0, len(endpoints)+expectedImportStatements) - - if len(endpoints) > 0 { - decls = append(decls, - mkImportStatements( - `"context"`, - `"github.com/mlflow/mlflow-go/pkg/protos"`, - `"github.com/mlflow/mlflow-go/pkg/contract"`, - )) - } - - decls = append(decls, mkServiceInterfaceNode( - endpoints, - generationInfo.ServiceName, - serviceInfo, - )) - - fileName := generationInfo.FileNameWithoutExtension + ".g.go" - pkg := "service" - outputPath := filepath.Join(pkgFolder, "contract", pkg, fileName) - - return mkGeneratedFile(pkg, outputPath, decls) -} - -func generateRouteRegistrations( - pkgFolder string, - serviceInfo discovery.ServiceInfo, - generationInfo ServiceGenerationInfo, - endpoints map[string]any, -) error { - importStatements := []string{ - `"github.com/gofiber/fiber/v2"`, - `"github.com/mlflow/mlflow-go/pkg/server/parser"`, - `"github.com/mlflow/mlflow-go/pkg/contract/service"`, - } - - if len(endpoints) > 0 { - importStatements = append( - importStatements, - `"github.com/mlflow/mlflow-go/pkg/utils"`, - `"github.com/mlflow/mlflow-go/pkg/protos"`, - ) - } - - decls := []ast.Decl{ - mkImportStatements(importStatements...), - mkRouteRegistrationFunction(endpoints, generationInfo.ServiceName, serviceInfo), - } - - fileName := generationInfo.FileNameWithoutExtension + ".g.go" - pkg := "routes" - outputPath := filepath.Join(pkgFolder, "server", pkg, fileName) - - return mkGeneratedFile(pkg, outputPath, decls) -} - -func mkCEndpointBody(serviceName string, method discovery.MethodInfo) *ast.BlockStmt { - mapName := strcase.ToLowerCamel(serviceName) + "s" - - return &ast.BlockStmt{ - List: []ast.Stmt{ - // service, err := trackingServices.Get(serviceID) - mkAssignStmt( - []ast.Expr{ - ast.NewIdent("service"), - ast.NewIdent("err"), - }, - []ast.Expr{ - 
mkCallExpr(mkSelectorExpr(mapName, "Get"), ast.NewIdent("serviceID")), - }, - ), - // if err != nil { - // return makePointerFromError(err, responseSize) - // } - mkIfStmt( - nil, - errNotEqualNil, - mkBlockStmt( - mkReturnStmt( - mkCallExpr( - ast.NewIdent("makePointerFromError"), - ast.NewIdent("err"), - ast.NewIdent("responseSize"), - ), - ), - ), - ), - // return invokeServiceMethod( - // service.GetExperiment, - // new(protos.GetExperiment), - // requestData, - // requestSize, - // responseSize, - // ) - mkReturnStmt( - mkCallExpr( - ast.NewIdent("invokeServiceMethod"), - mkSelectorExpr("service", strcase.ToCamel(method.Name)), - mkCallExpr(ast.NewIdent("new"), mkSelectorExpr("protos", method.Input)), - ast.NewIdent("requestData"), - ast.NewIdent("requestSize"), - ast.NewIdent("responseSize"), - ), - ), - }, - } -} - -func mkCEndpoint(serviceName string, method discovery.MethodInfo) *ast.FuncDecl { - functionName := fmt.Sprintf("%s%s", serviceName, strcase.ToCamel(method.Name)) - - return &ast.FuncDecl{ - Doc: &ast.CommentGroup{ - List: []*ast.Comment{ - { - Text: "//export " + functionName, - }, - }, - }, - Name: ast.NewIdent(functionName), - Type: &ast.FuncType{ - Params: &ast.FieldList{ - List: []*ast.Field{ - mkNamedField("serviceID", ast.NewIdent("int64")), - mkNamedField("requestData", mkSelectorExpr("unsafe", "Pointer")), - mkNamedField("requestSize", mkSelectorExpr("C", "int")), - mkNamedField("responseSize", mkStarExpr(mkSelectorExpr("C", "int"))), - }, - }, - Results: &ast.FieldList{ - List: []*ast.Field{ - mkField(mkSelectorExpr("unsafe", "Pointer")), - }, - }, - }, - Body: mkCEndpointBody(serviceName, method), - } -} - -func mkCEndpoints( - endpoints map[string]any, serviceName string, serviceInfo discovery.ServiceInfo, -) []*ast.FuncDecl { - funcs := make([]*ast.FuncDecl, 0, len(endpoints)) - - for _, method := range serviceInfo.Methods { - if _, ok := endpoints[method.Name]; ok { - funcs = append(funcs, mkCEndpoint(serviceName, method)) - } - } - 
- return funcs -} - -func generateEndpoints( - pkgFolder string, - serviceInfo discovery.ServiceInfo, - generationInfo ServiceGenerationInfo, - endpoints map[string]any, -) error { - decls := []ast.Decl{ - mkImportStatements(`"C"`), - } - - if len(endpoints) > 0 { - decls = append( - decls, - mkImportStatements( - `"unsafe"`, - `"github.com/mlflow/mlflow-go/pkg/protos"`, - ), - ) - - endpoints := mkCEndpoints(endpoints, generationInfo.ServiceName, serviceInfo) - for _, endpoint := range endpoints { - decls = append(decls, endpoint) - } - } - - fileName := generationInfo.FileNameWithoutExtension + ".g.go" - outputPath := filepath.Join(pkgFolder, "lib", fileName) - - return mkGeneratedFile("main", outputPath, decls) -} - -func SourceCode(pkgFolder string) error { - services, err := discovery.GetServiceInfos() - if err != nil { - return fmt.Errorf("failed to get service info: %w", err) - } - - for _, serviceInfo := range services { - generationInfo, ok := ServiceInfoMap[serviceInfo.Name] - if !ok { - continue - } - - endpoints := make(map[string]any, len(generationInfo.ImplementedEndpoints)) - - for _, endpoint := range generationInfo.ImplementedEndpoints { - endpoints[endpoint] = nil - } - - err = generateServices(pkgFolder, serviceInfo, generationInfo, endpoints) - if err != nil { - return err - } - - err = generateRouteRegistrations(pkgFolder, serviceInfo, generationInfo, endpoints) - if err != nil { - return err - } - - err = generateEndpoints(pkgFolder, serviceInfo, generationInfo, endpoints) - if err != nil { - return err - } - } - - return nil -} +package generate + +import ( + "bufio" + "fmt" + "go/ast" + "go/format" + "go/token" + "net/http" + "os" + "path/filepath" + + "github.com/iancoleman/strcase" + + "github.com/mlflow/mlflow-go/magefiles/generate/discovery" +) + +func mkMethodInfoInputPointerType(methodInfo discovery.MethodInfo) *ast.StarExpr { + return mkStarExpr(mkSelectorExpr(methodInfo.PackageName, methodInfo.Input)) +} + +// Generate a method 
declaration on an service interface. +func mkServiceInterfaceMethod(methodInfo discovery.MethodInfo) *ast.Field { + return &ast.Field{ + Names: []*ast.Ident{ast.NewIdent(strcase.ToCamel(methodInfo.Name))}, + Type: &ast.FuncType{ + Params: &ast.FieldList{ + List: []*ast.Field{ + mkNamedField("ctx", mkSelectorExpr("context", "Context")), + mkNamedField("input", mkMethodInfoInputPointerType(methodInfo)), + }, + }, + Results: &ast.FieldList{ + List: []*ast.Field{ + mkField(mkStarExpr(mkSelectorExpr(methodInfo.PackageName, methodInfo.Output))), + mkField(mkStarExpr(mkSelectorExpr("contract", "Error"))), + }, + }, + }, + } +} + +// Generate a service interface declaration. +func mkServiceInterfaceNode( + endpoints map[string]any, interfaceName string, serviceInfo discovery.ServiceInfo, +) *ast.GenDecl { + // We add one method to validate any of the input structs + methods := make([]*ast.Field, 0, len(serviceInfo.Methods)) + + for _, method := range serviceInfo.Methods { + if _, ok := endpoints[method.Name]; ok { + methods = append(methods, mkServiceInterfaceMethod(method)) + } + } + + // Create an interface declaration + return &ast.GenDecl{ + Tok: token.TYPE, // Specifies a type declaration + Specs: []ast.Spec{ + &ast.TypeSpec{ + Name: ast.NewIdent(interfaceName), + Type: &ast.InterfaceType{ + Methods: &ast.FieldList{ + List: methods, + }, + }, + }, + }, + } +} + +func saveASTToFile(fset *token.FileSet, file *ast.File, addComment bool, outputPath string) error { + // Create or truncate the output file + outputFile, err := os.Create(outputPath) + if err != nil { + return fmt.Errorf("failed to create output file: %w", err) + } + defer outputFile.Close() + + // Use a bufio.Writer for buffered writing (optional) + writer := bufio.NewWriter(outputFile) + defer writer.Flush() + + if addComment { + _, err := writer.WriteString("// Code generated by mlflow/go/cmd/generate/main.go. 
DO NOT EDIT.\n\n") + if err != nil { + return fmt.Errorf("failed to add comment to generated file: %w", err) + } + } + + // Write the generated code to the file + err = format.Node(writer, fset, file) + if err != nil { + return fmt.Errorf("failed to write generated AST to file: %w", err) + } + + return nil +} + +//nolint:funlen +func mkAppRoute(method discovery.MethodInfo, endpoint discovery.Endpoint) ast.Stmt { + urlExpr := &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf(`"%s"`, endpoint.GetFiberPath())} + + // input := &protos.SearchExperiments + inputExpr := mkAssignStmt( + []ast.Expr{ast.NewIdent("input")}, + []ast.Expr{ + mkAmpExpr(&ast.CompositeLit{ + Type: mkSelectorExpr(method.PackageName, method.Input), + }), + }) + + // if err := parser.ParseQuery(ctx, input); err != nil { return err } + // if err := parser.ParseBody(ctx, input); err != nil { return err } + var extractModel ast.Expr + if endpoint.Method == http.MethodGet { + extractModel = mkCallExpr(mkSelectorExpr("parser", "ParseQuery"), ast.NewIdent("ctx"), ast.NewIdent("input")) + } else { + extractModel = mkCallExpr(mkSelectorExpr("parser", "ParseBody"), ast.NewIdent("ctx"), ast.NewIdent("input")) + } + + inputErrorCheck := mkIfStmt( + mkAssignStmt([]ast.Expr{ast.NewIdent("err")}, []ast.Expr{extractModel}), + errNotEqualNil, + mkBlockStmt(returnErr), + ) + + // output, err := service.Method(input) + outputExpr := mkAssignStmt([]ast.Expr{ + ast.NewIdent("output"), + ast.NewIdent("err"), + }, []ast.Expr{ + mkCallExpr( + mkSelectorExpr( + "service", + strcase.ToCamel(method.Name), + ), + mkCallExpr( + mkSelectorExpr("utils", "NewContextWithLoggerFromFiberContext"), + ast.NewIdent("ctx"), + ), + ast.NewIdent("input"), + ), + }) + + // if err != nil { + // return err + // } + errorCheck := mkIfStmt( + nil, + errNotEqualNil, + mkBlockStmt( + mkReturnStmt(ast.NewIdent("err")), + ), + ) + + // return ctx.JSON(output) + returnExpr := mkReturnStmt(mkCallExpr(mkSelectorExpr("ctx", "JSON"), 
ast.NewIdent("output"))) + + // func(ctx *fiber.Ctx) error { .. } + funcExpr := &ast.FuncLit{ + Type: &ast.FuncType{ + Params: &ast.FieldList{ + List: []*ast.Field{ + mkNamedField("ctx", mkStarExpr(mkSelectorExpr("fiber", "Ctx"))), + }, + }, + Results: &ast.FieldList{ + List: []*ast.Field{ + mkField(ast.NewIdent("error")), + }, + }, + }, + Body: &ast.BlockStmt{ + List: []ast.Stmt{ + inputExpr, + inputErrorCheck, + outputExpr, + errorCheck, + returnExpr, + }, + }, + } + + return &ast.ExprStmt{ + // app.Get("/mlflow/experiments/search", func(ctx *fiber.Ctx) error { .. }) + X: mkCallExpr( + mkSelectorExpr("app", strcase.ToCamel(endpoint.Method)), urlExpr, funcExpr, + ), + } +} + +func mkRouteRegistrationFunction( + endpoints map[string]any, interfaceName string, serviceInfo discovery.ServiceInfo, +) *ast.FuncDecl { + routes := make([]ast.Stmt, 0, len(serviceInfo.Methods)) + + for _, method := range serviceInfo.Methods { + for _, endpoint := range method.Endpoints { + if _, ok := endpoints[method.Name]; ok { + routes = append(routes, mkAppRoute(method, endpoint)) + } + } + } + + return &ast.FuncDecl{ + Name: ast.NewIdent(fmt.Sprintf("Register%sRoutes", interfaceName)), + Type: &ast.FuncType{ + Params: &ast.FieldList{ + List: []*ast.Field{ + mkNamedField("service", mkSelectorExpr("service", interfaceName)), + mkNamedField("parser", mkStarExpr(mkSelectorExpr("parser", "HTTPRequestParser"))), + mkNamedField("app", mkStarExpr(ast.NewIdent("fiber.App"))), + }, + }, + }, + Body: &ast.BlockStmt{ + List: routes, + }, + } +} + +func mkGeneratedFile(pkg, outputPath string, decls []ast.Decl) error { + // Set up the FileSet and the AST File + fset := token.NewFileSet() + + file := &ast.File{ + Name: ast.NewIdent(pkg), + Decls: decls, + } + + err := saveASTToFile(fset, file, true, outputPath) + if err != nil { + return fmt.Errorf("failed to save AST to file: %w", err) + } + + return nil +} + +const expectedImportStatements = 2 + +// Generate the service interface. 
+func generateServices( + pkgFolder string, + serviceInfo discovery.ServiceInfo, + generationInfo ServiceGenerationInfo, + endpoints map[string]any, +) error { + decls := make([]ast.Decl, 0, len(endpoints)+expectedImportStatements) + + if len(endpoints) > 0 { + decls = append(decls, + mkImportStatements( + `"context"`, + `"github.com/mlflow/mlflow-go/pkg/protos"`, + `"github.com/mlflow/mlflow-go/pkg/contract"`, + )) + } + + decls = append(decls, mkServiceInterfaceNode( + endpoints, + generationInfo.ServiceName, + serviceInfo, + )) + + fileName := generationInfo.FileNameWithoutExtension + ".g.go" + pkg := "service" + outputPath := filepath.Join(pkgFolder, "contract", pkg, fileName) + + return mkGeneratedFile(pkg, outputPath, decls) +} + +func generateRouteRegistrations( + pkgFolder string, + serviceInfo discovery.ServiceInfo, + generationInfo ServiceGenerationInfo, + endpoints map[string]any, +) error { + importStatements := []string{ + `"github.com/gofiber/fiber/v2"`, + `"github.com/mlflow/mlflow-go/pkg/server/parser"`, + `"github.com/mlflow/mlflow-go/pkg/contract/service"`, + } + + if len(endpoints) > 0 { + importStatements = append( + importStatements, + `"github.com/mlflow/mlflow-go/pkg/utils"`, + `"github.com/mlflow/mlflow-go/pkg/protos"`, + ) + } + + decls := []ast.Decl{ + mkImportStatements(importStatements...), + mkRouteRegistrationFunction(endpoints, generationInfo.ServiceName, serviceInfo), + } + + fileName := generationInfo.FileNameWithoutExtension + ".g.go" + pkg := "routes" + outputPath := filepath.Join(pkgFolder, "server", pkg, fileName) + + return mkGeneratedFile(pkg, outputPath, decls) +} + +func mkCEndpointBody(serviceName string, method discovery.MethodInfo) *ast.BlockStmt { + mapName := strcase.ToLowerCamel(serviceName) + "s" + + return &ast.BlockStmt{ + List: []ast.Stmt{ + // service, err := trackingServices.Get(serviceID) + mkAssignStmt( + []ast.Expr{ + ast.NewIdent("service"), + ast.NewIdent("err"), + }, + []ast.Expr{ + 
mkCallExpr(mkSelectorExpr(mapName, "Get"), ast.NewIdent("serviceID")), + }, + ), + // if err != nil { + // return makePointerFromError(err, responseSize) + // } + mkIfStmt( + nil, + errNotEqualNil, + mkBlockStmt( + mkReturnStmt( + mkCallExpr( + ast.NewIdent("makePointerFromError"), + ast.NewIdent("err"), + ast.NewIdent("responseSize"), + ), + ), + ), + ), + // return invokeServiceMethod( + // service.GetExperiment, + // new(protos.GetExperiment), + // requestData, + // requestSize, + // responseSize, + // ) + mkReturnStmt( + mkCallExpr( + ast.NewIdent("invokeServiceMethod"), + mkSelectorExpr("service", strcase.ToCamel(method.Name)), + mkCallExpr(ast.NewIdent("new"), mkSelectorExpr("protos", method.Input)), + ast.NewIdent("requestData"), + ast.NewIdent("requestSize"), + ast.NewIdent("responseSize"), + ), + ), + }, + } +} + +func mkCEndpoint(serviceName string, method discovery.MethodInfo) *ast.FuncDecl { + functionName := fmt.Sprintf("%s%s", serviceName, strcase.ToCamel(method.Name)) + + return &ast.FuncDecl{ + Doc: &ast.CommentGroup{ + List: []*ast.Comment{ + { + Text: "//export " + functionName, + }, + }, + }, + Name: ast.NewIdent(functionName), + Type: &ast.FuncType{ + Params: &ast.FieldList{ + List: []*ast.Field{ + mkNamedField("serviceID", ast.NewIdent("int64")), + mkNamedField("requestData", mkSelectorExpr("unsafe", "Pointer")), + mkNamedField("requestSize", mkSelectorExpr("C", "int")), + mkNamedField("responseSize", mkStarExpr(mkSelectorExpr("C", "int"))), + }, + }, + Results: &ast.FieldList{ + List: []*ast.Field{ + mkField(mkSelectorExpr("unsafe", "Pointer")), + }, + }, + }, + Body: mkCEndpointBody(serviceName, method), + } +} + +func mkCEndpoints( + endpoints map[string]any, serviceName string, serviceInfo discovery.ServiceInfo, +) []*ast.FuncDecl { + funcs := make([]*ast.FuncDecl, 0, len(endpoints)) + + for _, method := range serviceInfo.Methods { + if _, ok := endpoints[method.Name]; ok { + funcs = append(funcs, mkCEndpoint(serviceName, method)) + } + } + 
+ return funcs +} + +func generateEndpoints( + pkgFolder string, + serviceInfo discovery.ServiceInfo, + generationInfo ServiceGenerationInfo, + endpoints map[string]any, +) error { + decls := []ast.Decl{ + mkImportStatements(`"C"`), + } + + if len(endpoints) > 0 { + decls = append( + decls, + mkImportStatements( + `"unsafe"`, + `"github.com/mlflow/mlflow-go/pkg/protos"`, + ), + ) + + endpoints := mkCEndpoints(endpoints, generationInfo.ServiceName, serviceInfo) + for _, endpoint := range endpoints { + decls = append(decls, endpoint) + } + } + + fileName := generationInfo.FileNameWithoutExtension + ".g.go" + outputPath := filepath.Join(pkgFolder, "lib", fileName) + + return mkGeneratedFile("main", outputPath, decls) +} + +func SourceCode(pkgFolder string) error { + services, err := discovery.GetServiceInfos() + if err != nil { + return fmt.Errorf("failed to get service info: %w", err) + } + + for _, serviceInfo := range services { + generationInfo, ok := ServiceInfoMap[serviceInfo.Name] + if !ok { + continue + } + + endpoints := make(map[string]any, len(generationInfo.ImplementedEndpoints)) + + for _, endpoint := range generationInfo.ImplementedEndpoints { + endpoints[endpoint] = nil + } + + err = generateServices(pkgFolder, serviceInfo, generationInfo, endpoints) + if err != nil { + return err + } + + err = generateRouteRegistrations(pkgFolder, serviceInfo, generationInfo, endpoints) + if err != nil { + return err + } + + err = generateEndpoints(pkgFolder, serviceInfo, generationInfo, endpoints) + if err != nil { + return err + } + } + + return nil +} diff --git a/magefiles/generate/validations.go b/magefiles/generate/validations.go index c625861..b7f40c2 100644 --- a/magefiles/generate/validations.go +++ b/magefiles/generate/validations.go @@ -1,32 +1,32 @@ -package generate - -var validations = map[string]string{ - "GetExperiment_ExperimentId": "required,stringAsPositiveInteger", - "CreateExperiment_Name": "required,max=500", - "CreateExperiment_ArtifactLocation": 
"omitempty,uriWithoutFragmentsOrParamsOrDotDotInQuery", - "SearchRuns_RunViewType": "omitempty", - "SearchRuns_MaxResults": "gt=0,max=50000", - "DeleteExperiment_ExperimentId": "required,stringAsPositiveInteger", - "LogBatch_RunId": "required,runId", - "LogBatch_Params": "omitempty,uniqueParams,max=100,dive", - "LogBatch_Metrics": "max=1000,dive", - "LogBatch_Tags": "max=100", - "RunTag_Key": "required,max=250,validMetricParamOrTagName,pathIsUnique", - "RunTag_Value": "omitempty,max=5000", - "Param_Key": "required,max=250,validMetricParamOrTagName,pathIsUnique", - "Param_Value": "omitempty,truncate=6000", - "Metric_Key": "required,max=250,validMetricParamOrTagName,pathIsUnique", - "Metric_Timestamp": "required", - "Metric_Value": "required", - "CreateRun_ExperimentId": "required,stringAsPositiveInteger", - "GetExperimentByName_ExperimentName": "required", - "GetLatestVersions_Name": "required", - "LogMetric_RunId": "required", - "LogMetric_Key": "required", - "LogMetric_Value": "required", - "LogMetric_Timestamp": "required", - "SetTag_RunId": "required", - "SetTag_Key": "required", - "DeleteTag_RunId": "required", - "DeleteTag_Key": "required", -} +package generate + +var validations = map[string]string{ + "GetExperiment_ExperimentId": "required,stringAsPositiveInteger", + "CreateExperiment_Name": "required,max=500", + "CreateExperiment_ArtifactLocation": "omitempty,uriWithoutFragmentsOrParamsOrDotDotInQuery", + "SearchRuns_RunViewType": "omitempty", + "SearchRuns_MaxResults": "gt=0,max=50000", + "DeleteExperiment_ExperimentId": "required,stringAsPositiveInteger", + "LogBatch_RunId": "required,runId", + "LogBatch_Params": "omitempty,uniqueParams,max=100,dive", + "LogBatch_Metrics": "max=1000,dive", + "LogBatch_Tags": "max=100", + "RunTag_Key": "required,max=250,validMetricParamOrTagName,pathIsUnique", + "RunTag_Value": "omitempty,max=5000", + "Param_Key": "required,max=250,validMetricParamOrTagName,pathIsUnique", + "Param_Value": "omitempty,truncate=6000", + 
"Metric_Key": "required,max=250,validMetricParamOrTagName,pathIsUnique", + "Metric_Timestamp": "required", + "Metric_Value": "required", + "CreateRun_ExperimentId": "required,stringAsPositiveInteger", + "GetExperimentByName_ExperimentName": "required", + "GetLatestVersions_Name": "required", + "LogMetric_RunId": "required", + "LogMetric_Key": "required", + "LogMetric_Value": "required", + "LogMetric_Timestamp": "required", + "SetTag_RunId": "required", + "SetTag_Key": "required", + "DeleteTag_RunId": "required", + "DeleteTag_Key": "required", +} diff --git a/magefiles/repo.go b/magefiles/repo.go index e5da12d..25647e7 100644 --- a/magefiles/repo.go +++ b/magefiles/repo.go @@ -1,220 +1,220 @@ -//go:build mage - -//nolint:wrapcheck -package main - -import ( - "errors" - "fmt" - "log" - "os" - "path/filepath" - "strings" - - "github.com/magefile/mage/mg" - "github.com/magefile/mage/sh" -) - -const ( - MLFlowRepoFolderName = ".mlflow.repo" -) - -type Repo mg.Namespace - -func folderExists(path string) bool { - info, err := os.Stat(path) - if os.IsNotExist(err) { - return false - } - - return info.IsDir() -} - -func git(args ...string) error { - return sh.RunV("git", args...) -} - -func gitMlflowRepo(args ...string) error { - allArgs := append([]string{"-C", MLFlowRepoFolderName}, args...) - - return sh.RunV("git", allArgs...) -} - -func gitMlflowRepoOutput(args ...string) (string, error) { - allArgs := append([]string{"-C", MLFlowRepoFolderName}, args...) - - return sh.Output("git", allArgs...) 
-} - -type gitReference struct { - remote string - reference string -} - -const refFileName = ".mlflow.ref" - -func readFile(filename string) (string, error) { - content, err := os.ReadFile(filename) - if err != nil { - return "", err - } - - return string(content), nil -} - -var ErrInvalidGitRefFormat = errors.New("invalid format in .mlflow.ref file: expected 'remote#reference'") - -func readGitReference() (gitReference, error) { - refFilePath, err := filepath.Abs(refFileName) - if err != nil { - return gitReference{}, fmt.Errorf("failed to get .mlflow.ref: %w", err) - } - - content, err := readFile(refFilePath) - if err != nil { - return gitReference{}, err - } - - parts := strings.Split(content, "#") - - if len(parts) != 2 || parts[0] == "" || parts[1] == "" { - return gitReference{}, ErrInvalidGitRefFormat - } - - remote := strings.TrimSpace(parts[0]) - reference := strings.TrimSpace(parts[1]) - - return gitReference{remote: remote, reference: reference}, nil -} - -func freshCheckout(gitReference gitReference) error { - if err := git("clone", "--no-checkout", gitReference.remote, MLFlowRepoFolderName); err != nil { - return err - } - - return gitMlflowRepo("checkout", gitReference.reference) -} - -func checkRemote(gitReference gitReference) bool { - // git -C .mlflow.repo remote get-url origin - output, err := gitMlflowRepoOutput("remote", "get-url", "origin") - if err != nil { - return false - } - - return strings.TrimSpace(output) == gitReference.remote -} - -func checkBranch(gitReference gitReference) bool { - // git -C .mlflow.repo rev-parse --abbrev-ref HEAD - output, err := gitMlflowRepoOutput("rev-parse", "--abbrev-ref", "HEAD") - if err != nil { - return false - } - - return strings.TrimSpace(output) == gitReference.reference -} - -func checkTag(gitReference gitReference) bool { - // git -C .mlflow.repo describe --tags HEAD - output, err := gitMlflowRepoOutput("describe", "--tags", "HEAD") - if err != nil { - return false - } - - return 
strings.TrimSpace(output) == gitReference.reference -} - -func checkCommit(gitReference gitReference) bool { - // git -C .mlflow.repo rev-parse HEAD - output, err := gitMlflowRepoOutput("rev-parse", "HEAD") - if err != nil { - return false - } - - return strings.TrimSpace(output) == gitReference.reference -} - -func checkReference(gitReference gitReference) bool { - switch { - case checkBranch(gitReference): - log.Printf("Already on branch %q", gitReference.reference) - - return true - case checkTag(gitReference): - log.Printf("Already on tag %q", gitReference.reference) - - return true - case checkCommit(gitReference): - log.Printf("Already on commit %q", gitReference.reference) - - return true - } - - return false -} - -func syncRepo(gitReference gitReference) error { - log.Printf("syncing mlflow repo to %s#%s", gitReference.remote, gitReference.reference) - - if err := gitMlflowRepo("remote", "set-url", "origin", gitReference.remote); err != nil { - return err - } - - if err := gitMlflowRepo("fetch", "origin"); err != nil { - return err - } - - if err := gitMlflowRepo("checkout", gitReference.reference); err != nil { - return err - } - - if checkBranch(gitReference) { - return gitMlflowRepo("pull") - } - - return nil -} - -// Clone or reset the .mlflow.repo fork. -func (Repo) Init() error { - gitReference, err := readGitReference() - if err != nil { - return err - } - - repoPath, err := filepath.Abs(MLFlowRepoFolderName) - if err != nil { - return err - } - - if !folderExists(repoPath) { - return freshCheckout(gitReference) - } - - // Verify remote - if !checkRemote(gitReference) { - log.Printf("Remote %s no longer matches", gitReference.remote) - - return syncRepo(gitReference) - } - - // Verify reference - if !checkReference(gitReference) { - log.Printf("The current reference %q no longer matches", gitReference.reference) - - return syncRepo(gitReference) - } - - return nil -} - -// Forcefully update the .mlflow.repo according to the .mlflow.ref. 
-func (Repo) Update() error { - gitReference, err := readGitReference() - if err != nil { - return err - } - - return syncRepo(gitReference) -} +//go:build mage + +//nolint:wrapcheck +package main + +import ( + "errors" + "fmt" + "log" + "os" + "path/filepath" + "strings" + + "github.com/magefile/mage/mg" + "github.com/magefile/mage/sh" +) + +const ( + MLFlowRepoFolderName = ".mlflow.repo" +) + +type Repo mg.Namespace + +func folderExists(path string) bool { + info, err := os.Stat(path) + if os.IsNotExist(err) { + return false + } + + return info.IsDir() +} + +func git(args ...string) error { + return sh.RunV("git", args...) +} + +func gitMlflowRepo(args ...string) error { + allArgs := append([]string{"-C", MLFlowRepoFolderName}, args...) + + return sh.RunV("git", allArgs...) +} + +func gitMlflowRepoOutput(args ...string) (string, error) { + allArgs := append([]string{"-C", MLFlowRepoFolderName}, args...) + + return sh.Output("git", allArgs...) +} + +type gitReference struct { + remote string + reference string +} + +const refFileName = ".mlflow.ref" + +func readFile(filename string) (string, error) { + content, err := os.ReadFile(filename) + if err != nil { + return "", err + } + + return string(content), nil +} + +var ErrInvalidGitRefFormat = errors.New("invalid format in .mlflow.ref file: expected 'remote#reference'") + +func readGitReference() (gitReference, error) { + refFilePath, err := filepath.Abs(refFileName) + if err != nil { + return gitReference{}, fmt.Errorf("failed to get .mlflow.ref: %w", err) + } + + content, err := readFile(refFilePath) + if err != nil { + return gitReference{}, err + } + + parts := strings.Split(content, "#") + + if len(parts) != 2 || parts[0] == "" || parts[1] == "" { + return gitReference{}, ErrInvalidGitRefFormat + } + + remote := strings.TrimSpace(parts[0]) + reference := strings.TrimSpace(parts[1]) + + return gitReference{remote: remote, reference: reference}, nil +} + +func freshCheckout(gitReference gitReference) error { + 
if err := git("clone", "--no-checkout", gitReference.remote, MLFlowRepoFolderName); err != nil { + return err + } + + return gitMlflowRepo("checkout", gitReference.reference) +} + +func checkRemote(gitReference gitReference) bool { + // git -C .mlflow.repo remote get-url origin + output, err := gitMlflowRepoOutput("remote", "get-url", "origin") + if err != nil { + return false + } + + return strings.TrimSpace(output) == gitReference.remote +} + +func checkBranch(gitReference gitReference) bool { + // git -C .mlflow.repo rev-parse --abbrev-ref HEAD + output, err := gitMlflowRepoOutput("rev-parse", "--abbrev-ref", "HEAD") + if err != nil { + return false + } + + return strings.TrimSpace(output) == gitReference.reference +} + +func checkTag(gitReference gitReference) bool { + // git -C .mlflow.repo describe --tags HEAD + output, err := gitMlflowRepoOutput("describe", "--tags", "HEAD") + if err != nil { + return false + } + + return strings.TrimSpace(output) == gitReference.reference +} + +func checkCommit(gitReference gitReference) bool { + // git -C .mlflow.repo rev-parse HEAD + output, err := gitMlflowRepoOutput("rev-parse", "HEAD") + if err != nil { + return false + } + + return strings.TrimSpace(output) == gitReference.reference +} + +func checkReference(gitReference gitReference) bool { + switch { + case checkBranch(gitReference): + log.Printf("Already on branch %q", gitReference.reference) + + return true + case checkTag(gitReference): + log.Printf("Already on tag %q", gitReference.reference) + + return true + case checkCommit(gitReference): + log.Printf("Already on commit %q", gitReference.reference) + + return true + } + + return false +} + +func syncRepo(gitReference gitReference) error { + log.Printf("syncing mlflow repo to %s#%s", gitReference.remote, gitReference.reference) + + if err := gitMlflowRepo("remote", "set-url", "origin", gitReference.remote); err != nil { + return err + } + + if err := gitMlflowRepo("fetch", "origin"); err != nil { + return err 
+ } + + if err := gitMlflowRepo("checkout", gitReference.reference); err != nil { + return err + } + + if checkBranch(gitReference) { + return gitMlflowRepo("pull") + } + + return nil +} + +// Clone or reset the .mlflow.repo fork. +func (Repo) Init() error { + gitReference, err := readGitReference() + if err != nil { + return err + } + + repoPath, err := filepath.Abs(MLFlowRepoFolderName) + if err != nil { + return err + } + + if !folderExists(repoPath) { + return freshCheckout(gitReference) + } + + // Verify remote + if !checkRemote(gitReference) { + log.Printf("Remote %s no longer matches", gitReference.remote) + + return syncRepo(gitReference) + } + + // Verify reference + if !checkReference(gitReference) { + log.Printf("The current reference %q no longer matches", gitReference.reference) + + return syncRepo(gitReference) + } + + return nil +} + +// Forcefully update the .mlflow.repo according to the .mlflow.ref. +func (Repo) Update() error { + gitReference, err := readGitReference() + if err != nil { + return err + } + + return syncRepo(gitReference) +} diff --git a/magefiles/temp.go b/magefiles/temp.go index 484d553..b835db7 100644 --- a/magefiles/temp.go +++ b/magefiles/temp.go @@ -1,74 +1,74 @@ -//go:build mage - -//nolint:wrapcheck -package main - -import ( - "os" - "path/filepath" - - "github.com/magefile/mage/mg" - "github.com/magefile/mage/sh" -) - -func pipInstall(args ...string) error { - allArgs := append([]string{"install"}, args...) - - return sh.RunV("pip", allArgs...) -} - -func tar(args ...string) error { - return sh.RunV("tar", args...) 
-} - -func Temp() error { - mg.Deps(Repo.Init) - - // Install our Python package and its dependencies - if err := pipInstall("-e", "."); err != nil { - return err - } - - // Install the dreaded psycho - if err := pipInstall("psycopg2-binary"); err != nil { - return err - } - - // Archive the MLFlow pre-built UI - if err := tar( - "-C", "/usr/local/python/current/lib/python3.8/site-packages/mlflow", - "-czvf", - "./ui.tgz", - "./server/js/build", - ); err != nil { - return err - } - - mlflowRepoPath, err := filepath.Abs(MLFlowRepoFolderName) - if err != nil { - return err - } - - // Add the UI back to it - if err := tar( - "-C", mlflowRepoPath, - "-xzvf", "./ui.tgz", - ); err != nil { - return err - } - - // Remove tar file - tarPath, err := filepath.Abs("ui.tgz") - if err != nil { - return err - } - - defer os.Remove(tarPath) - - // Install it in editable mode - if err := pipInstall("-e", mlflowRepoPath); err != nil { - return err - } - - return nil -} +//go:build mage + +//nolint:wrapcheck +package main + +import ( + "os" + "path/filepath" + + "github.com/magefile/mage/mg" + "github.com/magefile/mage/sh" +) + +func pipInstall(args ...string) error { + allArgs := append([]string{"install"}, args...) + + return sh.RunV("pip", allArgs...) +} + +func tar(args ...string) error { + return sh.RunV("tar", args...) 
+} + +func Temp() error { + mg.Deps(Repo.Init) + + // Install our Python package and its dependencies + if err := pipInstall("-e", "."); err != nil { + return err + } + + // Install the dreaded psycho + if err := pipInstall("psycopg2-binary"); err != nil { + return err + } + + // Archive the MLFlow pre-built UI + if err := tar( + "-C", "/usr/local/python/current/lib/python3.8/site-packages/mlflow", + "-czvf", + "./ui.tgz", + "./server/js/build", + ); err != nil { + return err + } + + mlflowRepoPath, err := filepath.Abs(MLFlowRepoFolderName) + if err != nil { + return err + } + + // Add the UI back to it + if err := tar( + "-C", mlflowRepoPath, + "-xzvf", "./ui.tgz", + ); err != nil { + return err + } + + // Remove tar file + tarPath, err := filepath.Abs("ui.tgz") + if err != nil { + return err + } + + defer os.Remove(tarPath) + + // Install it in editable mode + if err := pipInstall("-e", mlflowRepoPath); err != nil { + return err + } + + return nil +} diff --git a/magefiles/tests.go b/magefiles/tests.go index e7a1cf3..2e10272 100644 --- a/magefiles/tests.go +++ b/magefiles/tests.go @@ -1,102 +1,102 @@ -//go:build mage - -//nolint:wrapcheck -package main - -import ( - "os" - - "github.com/magefile/mage/mg" - "github.com/magefile/mage/sh" -) - -type Test mg.Namespace - -func cleanUpMemoryFile() error { - // Clean up :memory: file - filename := ":memory:" - _, err := os.Stat(filename) - - if err == nil { - // File exists, delete it - err = os.Remove(filename) - if err != nil { - return err - } - } - - return nil -} - -// Run mlflow Python tests against the Go backend. 
-func (Test) Python() error { - libpath, err := os.MkdirTemp("", "") - if err != nil { - return err - } - - // Remove the Go binary - defer os.RemoveAll(libpath) - //nolint:errcheck - defer cleanUpMemoryFile() - - // Build the Go binary in a temporary directory - if err := sh.RunV("python", "-m", "mlflow_go.lib", ".", libpath); err != nil { - return nil - } - - // Run the tests (currently just the server ones) - if err := sh.RunWithV(map[string]string{ - "MLFLOW_GO_LIBRARY_PATH": libpath, - }, "pytest", - "--confcutdir=.", - ".mlflow.repo/tests/tracking/test_rest_tracking.py", - ".mlflow.repo/tests/tracking/test_model_registry.py", - ".mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py", - ".mlflow.repo/tests/store/model_registry/test_sqlalchemy_store.py", - "-k", - "not [file", - // "-vv", - ); err != nil { - return err - } - - return nil -} - -// Run specific Python test against the Go backend. -func (Test) PythonSpecific(testName string) error { - libpath, err := os.MkdirTemp("", "") - if err != nil { - return err - } - - defer os.RemoveAll(libpath) - defer cleanUpMemoryFile() - - if err := sh.RunV("python", "-m", "mlflow_go.lib", ".", libpath); err != nil { - return nil - } - - if err := sh.RunWithV(map[string]string{ - "MLFLOW_GO_LIBRARY_PATH": libpath, - }, "pytest", - "--confcutdir=.", - ".mlflow.repo/tests/tracking/test_rest_tracking.py", - "-k", testName, - ); err != nil { - return err - } - - return nil -} - -// Run the Go unit tests. -func (Test) Unit() error { - return sh.RunV("go", "test", "./pkg/...") -} - -// Run all tests. 
-func (Test) All() { - mg.Deps(Test.Unit, Test.Python) -} +//go:build mage + +//nolint:wrapcheck +package main + +import ( + "os" + + "github.com/magefile/mage/mg" + "github.com/magefile/mage/sh" +) + +type Test mg.Namespace + +func cleanUpMemoryFile() error { + // Clean up :memory: file + filename := ":memory:" + _, err := os.Stat(filename) + + if err == nil { + // File exists, delete it + err = os.Remove(filename) + if err != nil { + return err + } + } + + return nil +} + +// Run mlflow Python tests against the Go backend. +func (Test) Python() error { + libpath, err := os.MkdirTemp("", "") + if err != nil { + return err + } + + // Remove the Go binary + defer os.RemoveAll(libpath) + //nolint:errcheck + defer cleanUpMemoryFile() + + // Build the Go binary in a temporary directory + if err := sh.RunV("python", "-m", "mlflow_go.lib", ".", libpath); err != nil { + return nil + } + + // Run the tests (currently just the server ones) + if err := sh.RunWithV(map[string]string{ + "MLFLOW_GO_LIBRARY_PATH": libpath, + }, "pytest", + "--confcutdir=.", + ".mlflow.repo/tests/tracking/test_rest_tracking.py", + ".mlflow.repo/tests/tracking/test_model_registry.py", + ".mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py", + ".mlflow.repo/tests/store/model_registry/test_sqlalchemy_store.py", + "-k", + "not [file", + // "-vv", + ); err != nil { + return err + } + + return nil +} + +// Run specific Python test against the Go backend. 
+func (Test) PythonSpecific(testName string) error { + libpath, err := os.MkdirTemp("", "") + if err != nil { + return err + } + + defer os.RemoveAll(libpath) + defer cleanUpMemoryFile() + + if err := sh.RunV("python", "-m", "mlflow_go.lib", ".", libpath); err != nil { + return nil + } + + if err := sh.RunWithV(map[string]string{ + "MLFLOW_GO_LIBRARY_PATH": libpath, + }, "pytest", + "--confcutdir=.", + ".mlflow.repo/tests/tracking/test_rest_tracking.py", + "-k", testName, + ); err != nil { + return err + } + + return nil +} + +// Run the Go unit tests. +func (Test) Unit() error { + return sh.RunV("go", "test", "./pkg/...") +} + +// Run all tests. +func (Test) All() { + mg.Deps(Test.Unit, Test.Python) +} diff --git a/mlflow_go/__init__.py b/mlflow_go/__init__.py index a9275a8..d7ac0f1 100644 --- a/mlflow_go/__init__.py +++ b/mlflow_go/__init__.py @@ -1,20 +1,20 @@ -import os - -_go_enabled = "MLFLOW_GO_ENABLED" in os.environ - - -def _set_go_enabled(enabled: bool): - global _go_enabled - _go_enabled = enabled - - -def is_go_enabled(): - return _go_enabled - - -def disable_go(): - _set_go_enabled(False) - - -def enable_go(): - _set_go_enabled(True) +import os + +_go_enabled = "MLFLOW_GO_ENABLED" in os.environ + + +def _set_go_enabled(enabled: bool): + global _go_enabled + _go_enabled = enabled + + +def is_go_enabled(): + return _go_enabled + + +def disable_go(): + _set_go_enabled(False) + + +def enable_go(): + _set_go_enabled(True) diff --git a/mlflow_go/cli.py b/mlflow_go/cli.py index a9d0530..a6ed160 100644 --- a/mlflow_go/cli.py +++ b/mlflow_go/cli.py @@ -1,112 +1,112 @@ -import json -import pathlib -import shlex - -import click -import mlflow.cli -import mlflow.version -from mlflow.utils import find_free_port - -from mlflow_go.lib import get_lib - - -def _get_commands(): - """Returns the MLflow CLI commands with the `server` command replaced with a Go server.""" - commands = mlflow.cli.cli.commands.copy() - - def server( - go_opts, - **kwargs, - ): - # convert 
the Go options to a dictionary - opts = {} - if go_opts: - for opt in go_opts.split(","): - key, value = opt.split("=", 1) - opts[key] = value - - # validate the Python server configuration if set - if ("python_address" in opts) ^ ("python_command" in opts): - raise click.ClickException("python_address and python_command have to be set together") - - if "python_address" and "python_command" in opts: - # use the provided Python server configuration - python_address = opts["python_address"] - python_command = shlex.split(opts["python_command"]) - else: - # assign a random port for the Python server - python_host = "127.0.0.1" - python_port = find_free_port() - python_address = f"{python_host}:{python_port}" - python_args = kwargs.copy() - python_args.update( - { - "host": python_host, - "port": python_port, - } - ) - - # construct the Python server command - python_command = [ - "mlflow", - "server", - ] - for key, value in python_args.items(): - if isinstance(value, bool): - if value: - python_command.append(f"--{key.replace('_', '-')}") - elif value is not None: - python_command.append(f"--{key.replace('_', '-')}") - python_command.append(str(value)) - - # initialize the Go server configuration - tracking_store_uri = kwargs["backend_store_uri"] - config = { - "address": f'{kwargs["host"]}:{kwargs["port"]}', - "default_artifact_root": mlflow.cli.resolve_default_artifact_root( - kwargs["serve_artifacts"], kwargs["default_artifact_root"], tracking_store_uri - ), - "log_level": opts.get("log_level", "DEBUG" if kwargs["dev"] else "INFO"), - "python_address": python_address, - "python_command": python_command, - "shutdown_timeout": opts.get("shutdown_timeout", "1m"), - "static_folder": pathlib.Path(mlflow.server.__file__) - .parent.joinpath(mlflow.server.REL_STATIC_DIR) - .resolve() - .as_posix(), - "tracking_store_uri": tracking_store_uri, - "model_registry_store_uri": kwargs["registry_store_uri"] or tracking_store_uri, - "version": mlflow.version.VERSION, - } - 
config_bytes = json.dumps(config).encode("utf-8") - - # start the Go server and check for errors - ret = get_lib().LaunchServer(config_bytes, len(config_bytes)) - if ret != 0: - raise click.ClickException(f"Non-zero exit code: {ret}") - - server.__doc__ = mlflow.cli.server.callback.__doc__ - - server_params = mlflow.cli.server.params.copy() - idx = next((i for i, x in enumerate(mlflow.cli.server.params) if x.name == "gunicorn_opts"), -1) - server_params.insert( - idx, - click.Option( - ["--go-opts"], - default=None, - help="Additional options forwarded to the Go server", - ), - ) - - commands["server"] = click.command(params=server_params)(server) - - return commands - - -@click.group(commands=_get_commands()) -def cli(): - pass - - -if __name__ == "__main__": - cli() +import json +import pathlib +import shlex + +import click +import mlflow.cli +import mlflow.version +from mlflow.utils import find_free_port + +from mlflow_go.lib import get_lib + + +def _get_commands(): + """Returns the MLflow CLI commands with the `server` command replaced with a Go server.""" + commands = mlflow.cli.cli.commands.copy() + + def server( + go_opts, + **kwargs, + ): + # convert the Go options to a dictionary + opts = {} + if go_opts: + for opt in go_opts.split(","): + key, value = opt.split("=", 1) + opts[key] = value + + # validate the Python server configuration if set + if ("python_address" in opts) ^ ("python_command" in opts): + raise click.ClickException("python_address and python_command have to be set together") + + if "python_address" and "python_command" in opts: + # use the provided Python server configuration + python_address = opts["python_address"] + python_command = shlex.split(opts["python_command"]) + else: + # assign a random port for the Python server + python_host = "127.0.0.1" + python_port = find_free_port() + python_address = f"{python_host}:{python_port}" + python_args = kwargs.copy() + python_args.update( + { + "host": python_host, + "port": python_port, + } + 
) + + # construct the Python server command + python_command = [ + "mlflow", + "server", + ] + for key, value in python_args.items(): + if isinstance(value, bool): + if value: + python_command.append(f"--{key.replace('_', '-')}") + elif value is not None: + python_command.append(f"--{key.replace('_', '-')}") + python_command.append(str(value)) + + # initialize the Go server configuration + tracking_store_uri = kwargs["backend_store_uri"] + config = { + "address": f'{kwargs["host"]}:{kwargs["port"]}', + "default_artifact_root": mlflow.cli.resolve_default_artifact_root( + kwargs["serve_artifacts"], kwargs["default_artifact_root"], tracking_store_uri + ), + "log_level": opts.get("log_level", "DEBUG" if kwargs["dev"] else "INFO"), + "python_address": python_address, + "python_command": python_command, + "shutdown_timeout": opts.get("shutdown_timeout", "1m"), + "static_folder": pathlib.Path(mlflow.server.__file__) + .parent.joinpath(mlflow.server.REL_STATIC_DIR) + .resolve() + .as_posix(), + "tracking_store_uri": tracking_store_uri, + "model_registry_store_uri": kwargs["registry_store_uri"] or tracking_store_uri, + "version": mlflow.version.VERSION, + } + config_bytes = json.dumps(config).encode("utf-8") + + # start the Go server and check for errors + ret = get_lib().LaunchServer(config_bytes, len(config_bytes)) + if ret != 0: + raise click.ClickException(f"Non-zero exit code: {ret}") + + server.__doc__ = mlflow.cli.server.callback.__doc__ + + server_params = mlflow.cli.server.params.copy() + idx = next((i for i, x in enumerate(mlflow.cli.server.params) if x.name == "gunicorn_opts"), -1) + server_params.insert( + idx, + click.Option( + ["--go-opts"], + default=None, + help="Additional options forwarded to the Go server", + ), + ) + + commands["server"] = click.command(params=server_params)(server) + + return commands + + +@click.group(commands=_get_commands()) +def cli(): + pass + + +if __name__ == "__main__": + cli() diff --git a/mlflow_go/lib.py b/mlflow_go/lib.py 
index d4cbb32..ad5c482 100644 --- a/mlflow_go/lib.py +++ b/mlflow_go/lib.py @@ -1,124 +1,124 @@ -import logging -import os -import pathlib -import re -import subprocess -import sys -import tempfile - - -def _get_lib_name() -> str: - ext = ".so" - if sys.platform == "win32": - ext = ".dll" - elif sys.platform == "darwin": - ext = ".dylib" - return "libmlflow-go" + ext - - -def build_lib(src_dir: pathlib.Path, out_dir: pathlib.Path) -> pathlib.Path: - out_path = out_dir.joinpath(_get_lib_name()) - env = os.environ.copy() - env.update( - { - "CGO_ENABLED": "1", - } - ) - subprocess.check_call( - [ - "go", - "build", - "-trimpath", - "-ldflags", - "-w -s", - "-o", - out_path.resolve().as_posix(), - "-buildmode", - "c-shared", - src_dir.joinpath("pkg", "lib").resolve().as_posix(), - ], - cwd=src_dir.resolve().as_posix(), - env=env, - ) - return out_path - - -def _get_lib(): - # check if the library exists and load it - path = pathlib.Path( - os.environ.get("MLFLOW_GO_LIBRARY_PATH", pathlib.Path(__file__).parent.as_posix()) - ).joinpath(_get_lib_name()) - if path.is_file(): - return _load_lib(path) - - logging.getLogger(__name__).warn("Go library not found, building it now") - - # build the library in a temporary directory and load it - with tempfile.TemporaryDirectory() as tmpdir: - return _load_lib( - build_lib( - pathlib.Path(__file__).parent.parent, - pathlib.Path(tmpdir), - ) - ) - - -def _load_lib(path: pathlib.Path): - ffi = get_ffi() - - # load from header file - ffi.cdef(_parse_header(path.with_suffix(".h"))) - - # load the library - return ffi.dlopen(path.as_posix()) - - -def _parse_header(path: pathlib.Path): - with open(path) as file: - content = file.read() - - # Find all matches in the header - functions = re.findall(r"extern\s+\w+\s*\*?\s+\w+\s*\([^)]*\);", content, re.MULTILINE) - - # Replace GoInt64 with int64_t in each function - transformed_functions = [func.replace("GoInt64", "int64_t") for func in functions] - - return 
"\n".join(transformed_functions) - - -def _get_ffi(): - import cffi - - return cffi.FFI() - - -_ffi = None - - -def get_ffi(): - global _ffi - if _ffi is None: - _ffi = _get_ffi() - _ffi.cdef("void free(void*);") - return _ffi - - -_lib = None - - -def get_lib(): - global _lib - if _lib is None: - _lib = _get_lib() - return _lib - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser("build_lib", description="Build Go library") - parser.add_argument("src", help="the Go source directory") - parser.add_argument("out", help="the output directory") - args = parser.parse_args() - - build_lib(pathlib.Path(args.src), pathlib.Path(args.out)) +import logging +import os +import pathlib +import re +import subprocess +import sys +import tempfile + + +def _get_lib_name() -> str: + ext = ".so" + if sys.platform == "win32": + ext = ".dll" + elif sys.platform == "darwin": + ext = ".dylib" + return "libmlflow-go" + ext + + +def build_lib(src_dir: pathlib.Path, out_dir: pathlib.Path) -> pathlib.Path: + out_path = out_dir.joinpath(_get_lib_name()) + env = os.environ.copy() + env.update( + { + "CGO_ENABLED": "1", + } + ) + subprocess.check_call( + [ + "go", + "build", + "-trimpath", + "-ldflags", + "-w -s", + "-o", + out_path.resolve().as_posix(), + "-buildmode", + "c-shared", + src_dir.joinpath("pkg", "lib").resolve().as_posix(), + ], + cwd=src_dir.resolve().as_posix(), + env=env, + ) + return out_path + + +def _get_lib(): + # check if the library exists and load it + path = pathlib.Path( + os.environ.get("MLFLOW_GO_LIBRARY_PATH", pathlib.Path(__file__).parent.as_posix()) + ).joinpath(_get_lib_name()) + if path.is_file(): + return _load_lib(path) + + logging.getLogger(__name__).warn("Go library not found, building it now") + + # build the library in a temporary directory and load it + with tempfile.TemporaryDirectory() as tmpdir: + return _load_lib( + build_lib( + pathlib.Path(__file__).parent.parent, + pathlib.Path(tmpdir), + ) + ) + + +def 
_load_lib(path: pathlib.Path): + ffi = get_ffi() + + # load from header file + ffi.cdef(_parse_header(path.with_suffix(".h"))) + + # load the library + return ffi.dlopen(path.as_posix()) + + +def _parse_header(path: pathlib.Path): + with open(path) as file: + content = file.read() + + # Find all matches in the header + functions = re.findall(r"extern\s+\w+\s*\*?\s+\w+\s*\([^)]*\);", content, re.MULTILINE) + + # Replace GoInt64 with int64_t in each function + transformed_functions = [func.replace("GoInt64", "int64_t") for func in functions] + + return "\n".join(transformed_functions) + + +def _get_ffi(): + import cffi + + return cffi.FFI() + + +_ffi = None + + +def get_ffi(): + global _ffi + if _ffi is None: + _ffi = _get_ffi() + _ffi.cdef("void free(void*);") + return _ffi + + +_lib = None + + +def get_lib(): + global _lib + if _lib is None: + _lib = _get_lib() + return _lib + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser("build_lib", description="Build Go library") + parser.add_argument("src", help="the Go source directory") + parser.add_argument("out", help="the output directory") + args = parser.parse_args() + + build_lib(pathlib.Path(args.src), pathlib.Path(args.out)) diff --git a/mlflow_go/server.py b/mlflow_go/server.py index bc82c44..102eb23 100644 --- a/mlflow_go/server.py +++ b/mlflow_go/server.py @@ -1,31 +1,31 @@ -import json -from contextlib import contextmanager - -from mlflow_go.lib import get_lib - - -def launch_server(**config): - config_bytes = json.dumps(config).encode("utf-8") - - # start the Go server and check for errors - ret = get_lib().LaunchServer(config_bytes, len(config_bytes)) - if ret != 0: - raise Exception(f"Non-zero exit code: {ret}") - - -@contextmanager -def server(**config): - config_bytes = json.dumps(config).encode("utf-8") - - # start the Go server and check for errors - id = get_lib().LaunchServerAsync(config_bytes, len(config_bytes)) - if id < 0: - raise Exception(f"Non-zero exit code: 
{id}") - - try: - yield - finally: - # stop the Go server and check for errors - ret = get_lib().StopServer(id) - if ret != 0: - raise Exception(f"Non-zero exit code: {ret}") +import json +from contextlib import contextmanager + +from mlflow_go.lib import get_lib + + +def launch_server(**config): + config_bytes = json.dumps(config).encode("utf-8") + + # start the Go server and check for errors + ret = get_lib().LaunchServer(config_bytes, len(config_bytes)) + if ret != 0: + raise Exception(f"Non-zero exit code: {ret}") + + +@contextmanager +def server(**config): + config_bytes = json.dumps(config).encode("utf-8") + + # start the Go server and check for errors + id = get_lib().LaunchServerAsync(config_bytes, len(config_bytes)) + if id < 0: + raise Exception(f"Non-zero exit code: {id}") + + try: + yield + finally: + # stop the Go server and check for errors + ret = get_lib().StopServer(id) + if ret != 0: + raise Exception(f"Non-zero exit code: {ret}") diff --git a/mlflow_go/store/_service_proxy.py b/mlflow_go/store/_service_proxy.py index 4241b1a..be854ac 100644 --- a/mlflow_go/store/_service_proxy.py +++ b/mlflow_go/store/_service_proxy.py @@ -1,43 +1,43 @@ -import json - -from google.protobuf.message import DecodeError -from mlflow.exceptions import MlflowException -from mlflow.protos.databricks_pb2 import INTERNAL_ERROR, ErrorCode - -from mlflow_go.lib import get_ffi, get_lib - - -class _ServiceProxy: - def __init__(self, id): - self.id = id - - def call_endpoint(self, endpoint, request): - request_data = request.SerializeToString() - response_size = get_ffi().new("int*") - - response_data = endpoint( - self.id, - request_data, - len(request_data), - response_size, - ) - - response_bytes = get_ffi().buffer(response_data, response_size[0])[:] - get_lib().free(response_data) - - try: - response = type(request).Response() - response.ParseFromString(response_bytes) - return response - except DecodeError: - try: - e = json.loads(response_bytes) - error_code = 
e.get("error_code", ErrorCode.Name(INTERNAL_ERROR)) - raise MlflowException( - message=e["message"], - error_code=ErrorCode.Value(error_code), - ) from None - except json.JSONDecodeError as e: - raise MlflowException( - message=f"Failed to parse response: {e}", - ) +import json + +from google.protobuf.message import DecodeError +from mlflow.exceptions import MlflowException +from mlflow.protos.databricks_pb2 import INTERNAL_ERROR, ErrorCode + +from mlflow_go.lib import get_ffi, get_lib + + +class _ServiceProxy: + def __init__(self, id): + self.id = id + + def call_endpoint(self, endpoint, request): + request_data = request.SerializeToString() + response_size = get_ffi().new("int*") + + response_data = endpoint( + self.id, + request_data, + len(request_data), + response_size, + ) + + response_bytes = get_ffi().buffer(response_data, response_size[0])[:] + get_lib().free(response_data) + + try: + response = type(request).Response() + response.ParseFromString(response_bytes) + return response + except DecodeError: + try: + e = json.loads(response_bytes) + error_code = e.get("error_code", ErrorCode.Name(INTERNAL_ERROR)) + raise MlflowException( + message=e["message"], + error_code=ErrorCode.Value(error_code), + ) from None + except json.JSONDecodeError as e: + raise MlflowException( + message=f"Failed to parse response: {e}", + ) diff --git a/mlflow_go/store/model_registry.py b/mlflow_go/store/model_registry.py index ba80a17..bc5ee11 100644 --- a/mlflow_go/store/model_registry.py +++ b/mlflow_go/store/model_registry.py @@ -1,55 +1,55 @@ -import json -import logging - -from mlflow.entities.model_registry import ( - ModelVersion, -) -from mlflow.protos.model_registry_pb2 import ( - GetLatestVersions, -) - -from mlflow_go import is_go_enabled -from mlflow_go.lib import get_lib -from mlflow_go.store._service_proxy import _ServiceProxy - -_logger = logging.getLogger(__name__) - - -class _ModelRegistryStore: - def __init__(self, *args, **kwargs): - store_uri = args[0] if 
len(args) > 0 else kwargs.get("db_uri", kwargs.get("root_directory")) - config = json.dumps( - { - "model_registry_store_uri": store_uri, - "log_level": logging.getLevelName(_logger.getEffectiveLevel()), - } - ).encode("utf-8") - self.service = _ServiceProxy(get_lib().CreateModelRegistryService(config, len(config))) - super().__init__(store_uri) - - def __del__(self): - if hasattr(self, "service"): - get_lib().DestroyModelRegistryService(self.service.id) - - def get_latest_versions(self, name, stages=None): - request = GetLatestVersions( - name=name, - stages=stages, - ) - response = self.service.call_endpoint( - get_lib().ModelRegistryServiceGetLatestVersions, request - ) - return [ModelVersion.from_proto(mv) for mv in response.model_versions] - - -def ModelRegistryStore(cls): - return type(cls.__name__, (_ModelRegistryStore, cls), {}) - - -def _get_sqlalchemy_store(store_uri): - from mlflow.store.model_registry.sqlalchemy_store import SqlAlchemyStore - - if is_go_enabled(): - SqlAlchemyStore = ModelRegistryStore(SqlAlchemyStore) - - return SqlAlchemyStore(store_uri) +import json +import logging + +from mlflow.entities.model_registry import ( + ModelVersion, +) +from mlflow.protos.model_registry_pb2 import ( + GetLatestVersions, +) + +from mlflow_go import is_go_enabled +from mlflow_go.lib import get_lib +from mlflow_go.store._service_proxy import _ServiceProxy + +_logger = logging.getLogger(__name__) + + +class _ModelRegistryStore: + def __init__(self, *args, **kwargs): + store_uri = args[0] if len(args) > 0 else kwargs.get("db_uri", kwargs.get("root_directory")) + config = json.dumps( + { + "model_registry_store_uri": store_uri, + "log_level": logging.getLevelName(_logger.getEffectiveLevel()), + } + ).encode("utf-8") + self.service = _ServiceProxy(get_lib().CreateModelRegistryService(config, len(config))) + super().__init__(store_uri) + + def __del__(self): + if hasattr(self, "service"): + get_lib().DestroyModelRegistryService(self.service.id) + + def 
get_latest_versions(self, name, stages=None): + request = GetLatestVersions( + name=name, + stages=stages, + ) + response = self.service.call_endpoint( + get_lib().ModelRegistryServiceGetLatestVersions, request + ) + return [ModelVersion.from_proto(mv) for mv in response.model_versions] + + +def ModelRegistryStore(cls): + return type(cls.__name__, (_ModelRegistryStore, cls), {}) + + +def _get_sqlalchemy_store(store_uri): + from mlflow.store.model_registry.sqlalchemy_store import SqlAlchemyStore + + if is_go_enabled(): + SqlAlchemyStore = ModelRegistryStore(SqlAlchemyStore) + + return SqlAlchemyStore(store_uri) diff --git a/mlflow_go/store/tracking.py b/mlflow_go/store/tracking.py index 97c4e9c..8d22239 100644 --- a/mlflow_go/store/tracking.py +++ b/mlflow_go/store/tracking.py @@ -1,192 +1,192 @@ -import json -import logging - -from mlflow.entities import ( - Experiment, - Run, - RunInfo, - ViewType, -) -from mlflow.exceptions import MlflowException -from mlflow.protos import databricks_pb2 -from mlflow.protos.service_pb2 import ( - CreateExperiment, - CreateRun, - DeleteExperiment, - DeleteRun, - DeleteTag, - GetExperiment, - GetExperimentByName, - GetRun, - LogBatch, - LogMetric, - RestoreExperiment, - RestoreRun, - SearchRuns, - SetTag, - UpdateExperiment, - UpdateRun, -) -from mlflow.utils.uri import resolve_uri_if_local - -from mlflow_go import is_go_enabled -from mlflow_go.lib import get_lib -from mlflow_go.store._service_proxy import _ServiceProxy - -_logger = logging.getLogger(__name__) - - -class _TrackingStore: - def __init__(self, *args, **kwargs): - store_uri = args[0] if len(args) > 0 else kwargs.get("db_uri", kwargs.get("root_directory")) - default_artifact_root = ( - args[1] - if len(args) > 1 - else kwargs.get("default_artifact_root", kwargs.get("artifact_root_uri")) - ) - config = json.dumps( - { - "default_artifact_root": resolve_uri_if_local(default_artifact_root), - "tracking_store_uri": store_uri, - "log_level": 
logging.getLevelName(_logger.getEffectiveLevel()), - } - ).encode("utf-8") - self.service = _ServiceProxy(get_lib().CreateTrackingService(config, len(config))) - super().__init__(store_uri, default_artifact_root) - - def __del__(self): - if hasattr(self, "service"): - get_lib().DestroyTrackingService(self.service.id) - - def get_experiment(self, experiment_id): - request = GetExperiment(experiment_id=str(experiment_id)) - response = self.service.call_endpoint(get_lib().TrackingServiceGetExperiment, request) - return Experiment.from_proto(response.experiment) - - def get_experiment_by_name(self, experiment_name): - request = GetExperimentByName(experiment_name=experiment_name) - try: - response = self.service.call_endpoint( - get_lib().TrackingServiceGetExperimentByName, request - ) - return Experiment.from_proto(response.experiment) - except MlflowException as e: - if e.error_code == databricks_pb2.ErrorCode.Name( - databricks_pb2.RESOURCE_DOES_NOT_EXIST - ): - return None - raise - - def create_experiment(self, name, artifact_location=None, tags=None): - request = CreateExperiment( - name=name, - artifact_location=artifact_location, - tags=[tag.to_proto() for tag in tags] if tags else [], - ) - response = self.service.call_endpoint(get_lib().TrackingServiceCreateExperiment, request) - return response.experiment_id - - def delete_experiment(self, experiment_id): - request = DeleteExperiment(experiment_id=str(experiment_id)) - self.service.call_endpoint(get_lib().TrackingServiceDeleteExperiment, request) - - def restore_experiment(self, experiment_id): - request = RestoreExperiment(experiment_id=str(experiment_id)) - self.service.call_endpoint(get_lib().TrackingServiceRestoreExperiment, request) - - def rename_experiment(self, experiment_id, new_name): - request = UpdateExperiment(experiment_id=str(experiment_id), new_name=new_name) - self.service.call_endpoint(get_lib().TrackingServiceUpdateExperiment, request) - - def get_run(self, run_id): - request = 
GetRun(run_uuid=run_id, run_id=run_id) - response = self.service.call_endpoint(get_lib().TrackingServiceGetRun, request) - return Run.from_proto(response.run) - - def create_run(self, experiment_id, user_id, start_time, tags, run_name): - request = CreateRun( - experiment_id=str(experiment_id), - user_id=user_id, - start_time=start_time, - tags=[tag.to_proto() for tag in tags] if tags else [], - run_name=run_name, - ) - response = self.service.call_endpoint(get_lib().TrackingServiceCreateRun, request) - return Run.from_proto(response.run) - - def delete_run(self, run_id): - request = DeleteRun(run_id=run_id) - self.service.call_endpoint(get_lib().TrackingServiceDeleteRun, request) - - def restore_run(self, run_id): - request = RestoreRun(run_id=run_id) - self.service.call_endpoint(get_lib().TrackingServiceRestoreRun, request) - - def update_run(self, run_id, run_status, end_time, run_name): - request = UpdateRun( - run_uuid=run_id, - run_id=run_id, - status=run_status, - end_time=end_time, - run_name=run_name, - ) - response = self.service.call_endpoint(get_lib().TrackingServiceUpdateRun, request) - return RunInfo.from_proto(response.run_info) - - def _search_runs( - self, experiment_ids, filter_string, run_view_type, max_results, order_by, page_token - ): - request = SearchRuns( - experiment_ids=[str(experiment_id) for experiment_id in experiment_ids], - filter=filter_string, - run_view_type=ViewType.to_proto(run_view_type), - max_results=max_results, - order_by=order_by, - page_token=page_token, - ) - response = self.service.call_endpoint(get_lib().TrackingServiceSearchRuns, request) - runs = [Run.from_proto(proto_run) for proto_run in response.runs] - return runs, (response.next_page_token or None) - - def log_batch(self, run_id, metrics, params, tags): - request = LogBatch( - run_id=run_id, - metrics=[metric.to_proto() for metric in metrics], - params=[param.to_proto() for param in params], - tags=[tag.to_proto() for tag in tags], - ) - 
self.service.call_endpoint(get_lib().TrackingServiceLogBatch, request) - - def log_metric(self, run_id, metric): - request = LogMetric( - run_id=run_id, - key=metric.key, - value=metric.value, - timestamp=metric.timestamp, - step=metric.step, - ) - self.service.call_endpoint(get_lib().TrackingServiceLogMetric, request) - - def set_tag(self, run_id, tag): - request = SetTag(run_id=run_id, key=tag.key, value=tag.value) - self.service.call_endpoint(get_lib().TrackingServiceSetTag, request) - - def delete_tag(self, run_id, key): - request = DeleteTag(run_id=run_id, key=key) - self.service.call_endpoint(get_lib().TrackingServiceDeleteTag, request) - -def TrackingStore(cls): - return type(cls.__name__, (_TrackingStore, cls), {}) - - -def _get_sqlalchemy_store(store_uri, artifact_uri): - from mlflow.store.tracking import DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH - from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore - - if is_go_enabled(): - SqlAlchemyStore = TrackingStore(SqlAlchemyStore) - - if artifact_uri is None: - artifact_uri = DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH - - return SqlAlchemyStore(store_uri, artifact_uri) +import json +import logging + +from mlflow.entities import ( + Experiment, + Run, + RunInfo, + ViewType, +) +from mlflow.exceptions import MlflowException +from mlflow.protos import databricks_pb2 +from mlflow.protos.service_pb2 import ( + CreateExperiment, + CreateRun, + DeleteExperiment, + DeleteRun, + DeleteTag, + GetExperiment, + GetExperimentByName, + GetRun, + LogBatch, + LogMetric, + RestoreExperiment, + RestoreRun, + SearchRuns, + SetTag, + UpdateExperiment, + UpdateRun, +) +from mlflow.utils.uri import resolve_uri_if_local + +from mlflow_go import is_go_enabled +from mlflow_go.lib import get_lib +from mlflow_go.store._service_proxy import _ServiceProxy + +_logger = logging.getLogger(__name__) + + +class _TrackingStore: + def __init__(self, *args, **kwargs): + store_uri = args[0] if len(args) > 0 else kwargs.get("db_uri", 
kwargs.get("root_directory")) + default_artifact_root = ( + args[1] + if len(args) > 1 + else kwargs.get("default_artifact_root", kwargs.get("artifact_root_uri")) + ) + config = json.dumps( + { + "default_artifact_root": resolve_uri_if_local(default_artifact_root), + "tracking_store_uri": store_uri, + "log_level": logging.getLevelName(_logger.getEffectiveLevel()), + } + ).encode("utf-8") + self.service = _ServiceProxy(get_lib().CreateTrackingService(config, len(config))) + super().__init__(store_uri, default_artifact_root) + + def __del__(self): + if hasattr(self, "service"): + get_lib().DestroyTrackingService(self.service.id) + + def get_experiment(self, experiment_id): + request = GetExperiment(experiment_id=str(experiment_id)) + response = self.service.call_endpoint(get_lib().TrackingServiceGetExperiment, request) + return Experiment.from_proto(response.experiment) + + def get_experiment_by_name(self, experiment_name): + request = GetExperimentByName(experiment_name=experiment_name) + try: + response = self.service.call_endpoint( + get_lib().TrackingServiceGetExperimentByName, request + ) + return Experiment.from_proto(response.experiment) + except MlflowException as e: + if e.error_code == databricks_pb2.ErrorCode.Name( + databricks_pb2.RESOURCE_DOES_NOT_EXIST + ): + return None + raise + + def create_experiment(self, name, artifact_location=None, tags=None): + request = CreateExperiment( + name=name, + artifact_location=artifact_location, + tags=[tag.to_proto() for tag in tags] if tags else [], + ) + response = self.service.call_endpoint(get_lib().TrackingServiceCreateExperiment, request) + return response.experiment_id + + def delete_experiment(self, experiment_id): + request = DeleteExperiment(experiment_id=str(experiment_id)) + self.service.call_endpoint(get_lib().TrackingServiceDeleteExperiment, request) + + def restore_experiment(self, experiment_id): + request = RestoreExperiment(experiment_id=str(experiment_id)) + 
self.service.call_endpoint(get_lib().TrackingServiceRestoreExperiment, request) + + def rename_experiment(self, experiment_id, new_name): + request = UpdateExperiment(experiment_id=str(experiment_id), new_name=new_name) + self.service.call_endpoint(get_lib().TrackingServiceUpdateExperiment, request) + + def get_run(self, run_id): + request = GetRun(run_uuid=run_id, run_id=run_id) + response = self.service.call_endpoint(get_lib().TrackingServiceGetRun, request) + return Run.from_proto(response.run) + + def create_run(self, experiment_id, user_id, start_time, tags, run_name): + request = CreateRun( + experiment_id=str(experiment_id), + user_id=user_id, + start_time=start_time, + tags=[tag.to_proto() for tag in tags] if tags else [], + run_name=run_name, + ) + response = self.service.call_endpoint(get_lib().TrackingServiceCreateRun, request) + return Run.from_proto(response.run) + + def delete_run(self, run_id): + request = DeleteRun(run_id=run_id) + self.service.call_endpoint(get_lib().TrackingServiceDeleteRun, request) + + def restore_run(self, run_id): + request = RestoreRun(run_id=run_id) + self.service.call_endpoint(get_lib().TrackingServiceRestoreRun, request) + + def update_run(self, run_id, run_status, end_time, run_name): + request = UpdateRun( + run_uuid=run_id, + run_id=run_id, + status=run_status, + end_time=end_time, + run_name=run_name, + ) + response = self.service.call_endpoint(get_lib().TrackingServiceUpdateRun, request) + return RunInfo.from_proto(response.run_info) + + def _search_runs( + self, experiment_ids, filter_string, run_view_type, max_results, order_by, page_token + ): + request = SearchRuns( + experiment_ids=[str(experiment_id) for experiment_id in experiment_ids], + filter=filter_string, + run_view_type=ViewType.to_proto(run_view_type), + max_results=max_results, + order_by=order_by, + page_token=page_token, + ) + response = self.service.call_endpoint(get_lib().TrackingServiceSearchRuns, request) + runs = [Run.from_proto(proto_run) for 
proto_run in response.runs] + return runs, (response.next_page_token or None) + + def log_batch(self, run_id, metrics, params, tags): + request = LogBatch( + run_id=run_id, + metrics=[metric.to_proto() for metric in metrics], + params=[param.to_proto() for param in params], + tags=[tag.to_proto() for tag in tags], + ) + self.service.call_endpoint(get_lib().TrackingServiceLogBatch, request) + + def log_metric(self, run_id, metric): + request = LogMetric( + run_id=run_id, + key=metric.key, + value=metric.value, + timestamp=metric.timestamp, + step=metric.step, + ) + self.service.call_endpoint(get_lib().TrackingServiceLogMetric, request) + + def set_tag(self, run_id, tag): + request = SetTag(run_id=run_id, key=tag.key, value=tag.value) + self.service.call_endpoint(get_lib().TrackingServiceSetTag, request) + + def delete_tag(self, run_id, key): + request = DeleteTag(run_id=run_id, key=key) + self.service.call_endpoint(get_lib().TrackingServiceDeleteTag, request) + +def TrackingStore(cls): + return type(cls.__name__, (_TrackingStore, cls), {}) + + +def _get_sqlalchemy_store(store_uri, artifact_uri): + from mlflow.store.tracking import DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH + from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore + + if is_go_enabled(): + SqlAlchemyStore = TrackingStore(SqlAlchemyStore) + + if artifact_uri is None: + artifact_uri = DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH + + return SqlAlchemyStore(store_uri, artifact_uri) diff --git a/pkg/artifacts/service/service.go b/pkg/artifacts/service/service.go index e4b6971..0f21d38 100644 --- a/pkg/artifacts/service/service.go +++ b/pkg/artifacts/service/service.go @@ -1,17 +1,17 @@ -package service - -import ( - "context" - - "github.com/mlflow/mlflow-go/pkg/config" -) - -type ArtifactsService struct { - config *config.Config -} - -func NewArtifactsService(_ context.Context, config *config.Config) (*ArtifactsService, error) { - return &ArtifactsService{ - config: config, - }, nil -} +package service + 
+import ( + "context" + + "github.com/mlflow/mlflow-go/pkg/config" +) + +type ArtifactsService struct { + config *config.Config +} + +func NewArtifactsService(_ context.Context, config *config.Config) (*ArtifactsService, error) { + return &ArtifactsService{ + config: config, + }, nil +} diff --git a/pkg/cmd/server/main.go b/pkg/cmd/server/main.go index de96c60..5b71e49 100644 --- a/pkg/cmd/server/main.go +++ b/pkg/cmd/server/main.go @@ -1,21 +1,21 @@ -package main - -import ( - "os" - - "github.com/sirupsen/logrus" - - "github.com/mlflow/mlflow-go/pkg/config" - "github.com/mlflow/mlflow-go/pkg/server" -) - -func main() { - cfg, err := config.NewConfigFromString(os.Getenv("MLFLOW_GO_CONFIG")) - if err != nil { - logrus.Fatal("Failed to read config from MLFLOW_GO_CONFIG environment variable: ", err) - } - - if err := server.LaunchWithSignalHandler(cfg); err != nil { - logrus.Fatal("Failed to launch server: ", err) - } -} +package main + +import ( + "os" + + "github.com/sirupsen/logrus" + + "github.com/mlflow/mlflow-go/pkg/config" + "github.com/mlflow/mlflow-go/pkg/server" +) + +func main() { + cfg, err := config.NewConfigFromString(os.Getenv("MLFLOW_GO_CONFIG")) + if err != nil { + logrus.Fatal("Failed to read config from MLFLOW_GO_CONFIG environment variable: ", err) + } + + if err := server.LaunchWithSignalHandler(cfg); err != nil { + logrus.Fatal("Failed to launch server: ", err) + } +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 59fe948..2116518 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -1,106 +1,106 @@ -package config - -import ( - "encoding/json" - "errors" - "fmt" - "time" -) - -type Duration struct { - time.Duration -} - -var ErrDuration = errors.New("invalid duration") - -func (d *Duration) UnmarshalJSON(b []byte) error { - var v interface{} - if err := json.Unmarshal(b, &v); err != nil { - return fmt.Errorf("failed to unmarshal duration: %w", err) - } - - switch value := v.(type) { - case float64: - d.Duration = 
time.Duration(value) - - return nil - case string: - var err error - - d.Duration, err = time.ParseDuration(value) - if err != nil { - return fmt.Errorf("failed to parse duration \"%s\": %w", value, err) - } - - return nil - default: - return ErrDuration - } -} - -type Config struct { - Address string `json:"address"` - DefaultArtifactRoot string `json:"default_artifact_root"` - LogLevel string `json:"log_level"` - ModelRegistryStoreURI string `json:"model_registry_store_uri"` - PythonAddress string `json:"python_address"` - PythonCommand []string `json:"python_command"` - PythonEnv []string `json:"python_env"` - ShutdownTimeout Duration `json:"shutdown_timeout"` - StaticFolder string `json:"static_folder"` - TrackingStoreURI string `json:"tracking_store_uri"` - Version string `json:"version"` -} - -func NewConfigFromBytes(cfgBytes []byte) (*Config, error) { - if len(cfgBytes) == 0 { - cfgBytes = []byte("{}") - } - - var cfg Config - if err := json.Unmarshal(cfgBytes, &cfg); err != nil { - return nil, fmt.Errorf("failed to parse JSON config: %w", err) - } - - cfg.applyDefaults() - - return &cfg, nil -} - -func NewConfigFromString(s string) (*Config, error) { - return NewConfigFromBytes([]byte(s)) -} - -func (c *Config) applyDefaults() { - if c.Address == "" { - c.Address = "localhost:5000" - } - - if c.DefaultArtifactRoot == "" { - c.DefaultArtifactRoot = "mlflow-artifacts:/" - } - - if c.LogLevel == "" { - c.LogLevel = "INFO" - } - - if c.ShutdownTimeout.Duration == 0 { - c.ShutdownTimeout.Duration = time.Minute - } - - if c.TrackingStoreURI == "" { - if c.ModelRegistryStoreURI != "" { - c.TrackingStoreURI = c.ModelRegistryStoreURI - } else { - c.TrackingStoreURI = "sqlite:///mlflow.db" - } - } - - if c.ModelRegistryStoreURI == "" { - c.ModelRegistryStoreURI = c.TrackingStoreURI - } - - if c.Version == "" { - c.Version = "dev" - } -} +package config + +import ( + "encoding/json" + "errors" + "fmt" + "time" +) + +type Duration struct { + time.Duration +} + +var 
ErrDuration = errors.New("invalid duration") + +func (d *Duration) UnmarshalJSON(b []byte) error { + var v interface{} + if err := json.Unmarshal(b, &v); err != nil { + return fmt.Errorf("failed to unmarshal duration: %w", err) + } + + switch value := v.(type) { + case float64: + d.Duration = time.Duration(value) + + return nil + case string: + var err error + + d.Duration, err = time.ParseDuration(value) + if err != nil { + return fmt.Errorf("failed to parse duration \"%s\": %w", value, err) + } + + return nil + default: + return ErrDuration + } +} + +type Config struct { + Address string `json:"address"` + DefaultArtifactRoot string `json:"default_artifact_root"` + LogLevel string `json:"log_level"` + ModelRegistryStoreURI string `json:"model_registry_store_uri"` + PythonAddress string `json:"python_address"` + PythonCommand []string `json:"python_command"` + PythonEnv []string `json:"python_env"` + ShutdownTimeout Duration `json:"shutdown_timeout"` + StaticFolder string `json:"static_folder"` + TrackingStoreURI string `json:"tracking_store_uri"` + Version string `json:"version"` +} + +func NewConfigFromBytes(cfgBytes []byte) (*Config, error) { + if len(cfgBytes) == 0 { + cfgBytes = []byte("{}") + } + + var cfg Config + if err := json.Unmarshal(cfgBytes, &cfg); err != nil { + return nil, fmt.Errorf("failed to parse JSON config: %w", err) + } + + cfg.applyDefaults() + + return &cfg, nil +} + +func NewConfigFromString(s string) (*Config, error) { + return NewConfigFromBytes([]byte(s)) +} + +func (c *Config) applyDefaults() { + if c.Address == "" { + c.Address = "localhost:5000" + } + + if c.DefaultArtifactRoot == "" { + c.DefaultArtifactRoot = "mlflow-artifacts:/" + } + + if c.LogLevel == "" { + c.LogLevel = "INFO" + } + + if c.ShutdownTimeout.Duration == 0 { + c.ShutdownTimeout.Duration = time.Minute + } + + if c.TrackingStoreURI == "" { + if c.ModelRegistryStoreURI != "" { + c.TrackingStoreURI = c.ModelRegistryStoreURI + } else { + c.TrackingStoreURI = 
"sqlite:///mlflow.db" + } + } + + if c.ModelRegistryStoreURI == "" { + c.ModelRegistryStoreURI = c.TrackingStoreURI + } + + if c.Version == "" { + c.Version = "dev" + } +} diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 1a98b86..765305a 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -1,53 +1,53 @@ -package config_test - -import ( - "encoding/json" - "fmt" - "testing" - "time" - - "github.com/stretchr/testify/require" - - "github.com/mlflow/mlflow-go/pkg/config" -) - -type validSample struct { - input string - duration config.Duration -} - -func TestValidDuration(t *testing.T) { - t.Parallel() - - samples := []validSample{ - {input: "1000", duration: config.Duration{Duration: 1000 * time.Nanosecond}}, - {input: `"1s"`, duration: config.Duration{Duration: 1 * time.Second}}, - {input: `"2h45m"`, duration: config.Duration{Duration: 2*time.Hour + 45*time.Minute}}, - } - - for _, sample := range samples { - currentSample := sample - t.Run(currentSample.input, func(t *testing.T) { - t.Parallel() - - jsonConfig := fmt.Sprintf(`{ "shutdown_timeout": %s }`, currentSample.input) - - var cfg config.Config - - err := json.Unmarshal([]byte(jsonConfig), &cfg) - require.NoError(t, err) - - require.Equal(t, currentSample.duration, cfg.ShutdownTimeout) - }) - } -} - -func TestInvalidDuration(t *testing.T) { - t.Parallel() - - var cfg config.Config - - if err := json.Unmarshal([]byte(`{ "shutdown_timeout": "two seconds" }`), &cfg); err == nil { - t.Error("expected error") - } -} +package config_test + +import ( + "encoding/json" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/mlflow/mlflow-go/pkg/config" +) + +type validSample struct { + input string + duration config.Duration +} + +func TestValidDuration(t *testing.T) { + t.Parallel() + + samples := []validSample{ + {input: "1000", duration: config.Duration{Duration: 1000 * time.Nanosecond}}, + {input: `"1s"`, duration: 
config.Duration{Duration: 1 * time.Second}}, + {input: `"2h45m"`, duration: config.Duration{Duration: 2*time.Hour + 45*time.Minute}}, + } + + for _, sample := range samples { + currentSample := sample + t.Run(currentSample.input, func(t *testing.T) { + t.Parallel() + + jsonConfig := fmt.Sprintf(`{ "shutdown_timeout": %s }`, currentSample.input) + + var cfg config.Config + + err := json.Unmarshal([]byte(jsonConfig), &cfg) + require.NoError(t, err) + + require.Equal(t, currentSample.duration, cfg.ShutdownTimeout) + }) + } +} + +func TestInvalidDuration(t *testing.T) { + t.Parallel() + + var cfg config.Config + + if err := json.Unmarshal([]byte(`{ "shutdown_timeout": "two seconds" }`), &cfg); err == nil { + t.Error("expected error") + } +} diff --git a/pkg/contract/error.go b/pkg/contract/error.go index 8c95ba0..04b3f2b 100644 --- a/pkg/contract/error.go +++ b/pkg/contract/error.go @@ -1,82 +1,82 @@ -package contract - -import ( - "encoding/json" - "fmt" - - "github.com/mlflow/mlflow-go/pkg/protos" -) - -type ErrorCode protos.ErrorCode - -func (e ErrorCode) String() string { - return protos.ErrorCode(e).String() -} - -// Custom json marshalling for ErrorCode. 
-func (e ErrorCode) MarshalJSON() ([]byte, error) { - //nolint:wrapcheck - return json.Marshal(e.String()) -} - -type Error struct { - Code ErrorCode `json:"error_code"` - Message string `json:"message"` - Inner error `json:"-"` -} - -func NewError(code protos.ErrorCode, message string) *Error { - return NewErrorWith(code, message, nil) -} - -func NewErrorWith(code protos.ErrorCode, message string, err error) *Error { - return &Error{ - Code: ErrorCode(code), - Message: message, - Inner: err, - } -} - -func (e *Error) Error() string { - msg := fmt.Sprintf("[%s] %s", e.Code.String(), e.Message) - if e.Inner != nil { - return fmt.Sprintf("%s: %s", msg, e.Inner) - } - - return msg -} - -func (e *Error) Unwrap() error { - return e.Inner -} - -//nolint:cyclop -func (e *Error) StatusCode() int { - //nolint:exhaustive,mnd - switch protos.ErrorCode(e.Code) { - case protos.ErrorCode_BAD_REQUEST, protos.ErrorCode_INVALID_PARAMETER_VALUE, protos.ErrorCode_RESOURCE_ALREADY_EXISTS: - return 400 - case protos.ErrorCode_CUSTOMER_UNAUTHORIZED, protos.ErrorCode_UNAUTHENTICATED: - return 401 - case protos.ErrorCode_PERMISSION_DENIED: - return 403 - case protos.ErrorCode_ENDPOINT_NOT_FOUND, protos.ErrorCode_NOT_FOUND, protos.ErrorCode_RESOURCE_DOES_NOT_EXIST: - return 404 - case protos.ErrorCode_ABORTED, protos.ErrorCode_ALREADY_EXISTS, protos.ErrorCode_RESOURCE_CONFLICT: - return 409 - case protos.ErrorCode_RESOURCE_EXHAUSTED, protos.ErrorCode_RESOURCE_LIMIT_EXCEEDED: - return 429 - case protos.ErrorCode_CANCELLED: - return 499 - case protos.ErrorCode_DATA_LOSS, protos.ErrorCode_INTERNAL_ERROR, protos.ErrorCode_INVALID_STATE: - return 500 - case protos.ErrorCode_NOT_IMPLEMENTED: - return 501 - case protos.ErrorCode_TEMPORARILY_UNAVAILABLE: - return 503 - case protos.ErrorCode_DEADLINE_EXCEEDED: - return 504 - default: - return 500 - } -} +package contract + +import ( + "encoding/json" + "fmt" + + "github.com/mlflow/mlflow-go/pkg/protos" +) + +type ErrorCode protos.ErrorCode + +func 
(e ErrorCode) String() string { + return protos.ErrorCode(e).String() +} + +// Custom json marshalling for ErrorCode. +func (e ErrorCode) MarshalJSON() ([]byte, error) { + //nolint:wrapcheck + return json.Marshal(e.String()) +} + +type Error struct { + Code ErrorCode `json:"error_code"` + Message string `json:"message"` + Inner error `json:"-"` +} + +func NewError(code protos.ErrorCode, message string) *Error { + return NewErrorWith(code, message, nil) +} + +func NewErrorWith(code protos.ErrorCode, message string, err error) *Error { + return &Error{ + Code: ErrorCode(code), + Message: message, + Inner: err, + } +} + +func (e *Error) Error() string { + msg := fmt.Sprintf("[%s] %s", e.Code.String(), e.Message) + if e.Inner != nil { + return fmt.Sprintf("%s: %s", msg, e.Inner) + } + + return msg +} + +func (e *Error) Unwrap() error { + return e.Inner +} + +//nolint:cyclop +func (e *Error) StatusCode() int { + //nolint:exhaustive,mnd + switch protos.ErrorCode(e.Code) { + case protos.ErrorCode_BAD_REQUEST, protos.ErrorCode_INVALID_PARAMETER_VALUE, protos.ErrorCode_RESOURCE_ALREADY_EXISTS: + return 400 + case protos.ErrorCode_CUSTOMER_UNAUTHORIZED, protos.ErrorCode_UNAUTHENTICATED: + return 401 + case protos.ErrorCode_PERMISSION_DENIED: + return 403 + case protos.ErrorCode_ENDPOINT_NOT_FOUND, protos.ErrorCode_NOT_FOUND, protos.ErrorCode_RESOURCE_DOES_NOT_EXIST: + return 404 + case protos.ErrorCode_ABORTED, protos.ErrorCode_ALREADY_EXISTS, protos.ErrorCode_RESOURCE_CONFLICT: + return 409 + case protos.ErrorCode_RESOURCE_EXHAUSTED, protos.ErrorCode_RESOURCE_LIMIT_EXCEEDED: + return 429 + case protos.ErrorCode_CANCELLED: + return 499 + case protos.ErrorCode_DATA_LOSS, protos.ErrorCode_INTERNAL_ERROR, protos.ErrorCode_INVALID_STATE: + return 500 + case protos.ErrorCode_NOT_IMPLEMENTED: + return 501 + case protos.ErrorCode_TEMPORARILY_UNAVAILABLE: + return 503 + case protos.ErrorCode_DEADLINE_EXCEEDED: + return 504 + default: + return 500 + } +} diff --git 
a/pkg/contract/http_request_parser.go b/pkg/contract/http_request_parser.go index f2fca20..d17201a 100644 --- a/pkg/contract/http_request_parser.go +++ b/pkg/contract/http_request_parser.go @@ -1,8 +1,8 @@ -package contract - -import "github.com/gofiber/fiber/v2" - -type HTTPRequestParser interface { - ParseBody(ctx *fiber.Ctx, out interface{}) *Error - ParseQuery(ctx *fiber.Ctx, out interface{}) *Error -} +package contract + +import "github.com/gofiber/fiber/v2" + +type HTTPRequestParser interface { + ParseBody(ctx *fiber.Ctx, out interface{}) *Error + ParseQuery(ctx *fiber.Ctx, out interface{}) *Error +} diff --git a/pkg/entities/dataset.go b/pkg/entities/dataset.go index 157c696..c0fe7e2 100644 --- a/pkg/entities/dataset.go +++ b/pkg/entities/dataset.go @@ -1,35 +1,35 @@ -package entities - -import ( - "github.com/mlflow/mlflow-go/pkg/protos" -) - -type Dataset struct { - Name string - Digest string - SourceType string - Source string - Schema string - Profile string -} - -func (d *Dataset) ToProto() *protos.Dataset { - var schema *string - if d.Schema != "" { - schema = &d.Schema - } - - var profile *string - if d.Profile != "" { - profile = &d.Profile - } - - return &protos.Dataset{ - Name: &d.Name, - Digest: &d.Digest, - SourceType: &d.SourceType, - Source: &d.Source, - Schema: schema, - Profile: profile, - } -} +package entities + +import ( + "github.com/mlflow/mlflow-go/pkg/protos" +) + +type Dataset struct { + Name string + Digest string + SourceType string + Source string + Schema string + Profile string +} + +func (d *Dataset) ToProto() *protos.Dataset { + var schema *string + if d.Schema != "" { + schema = &d.Schema + } + + var profile *string + if d.Profile != "" { + profile = &d.Profile + } + + return &protos.Dataset{ + Name: &d.Name, + Digest: &d.Digest, + SourceType: &d.SourceType, + Source: &d.Source, + Schema: schema, + Profile: profile, + } +} diff --git a/pkg/entities/dataset_input.go b/pkg/entities/dataset_input.go index 7f48dc4..9284671 100644 
--- a/pkg/entities/dataset_input.go +++ b/pkg/entities/dataset_input.go @@ -1,20 +1,20 @@ -package entities - -import "github.com/mlflow/mlflow-go/pkg/protos" - -type DatasetInput struct { - Tags []*InputTag - Dataset *Dataset -} - -func (ds DatasetInput) ToProto() *protos.DatasetInput { - tags := make([]*protos.InputTag, 0, len(ds.Tags)) - for _, tag := range ds.Tags { - tags = append(tags, tag.ToProto()) - } - - return &protos.DatasetInput{ - Tags: tags, - Dataset: ds.Dataset.ToProto(), - } -} +package entities + +import "github.com/mlflow/mlflow-go/pkg/protos" + +type DatasetInput struct { + Tags []*InputTag + Dataset *Dataset +} + +func (ds DatasetInput) ToProto() *protos.DatasetInput { + tags := make([]*protos.InputTag, 0, len(ds.Tags)) + for _, tag := range ds.Tags { + tags = append(tags, tag.ToProto()) + } + + return &protos.DatasetInput{ + Tags: tags, + Dataset: ds.Dataset.ToProto(), + } +} diff --git a/pkg/entities/experiment.go b/pkg/entities/experiment.go index 3c081bc..0bb58b4 100644 --- a/pkg/entities/experiment.go +++ b/pkg/entities/experiment.go @@ -1,36 +1,36 @@ -package entities - -import ( - "github.com/mlflow/mlflow-go/pkg/protos" - "github.com/mlflow/mlflow-go/pkg/utils" -) - -type Experiment struct { - Name string - ExperimentID string - ArtifactLocation string - LifecycleStage string - LastUpdateTime int64 - CreationTime int64 - Tags []*ExperimentTag -} - -func (e Experiment) ToProto() *protos.Experiment { - tags := make([]*protos.ExperimentTag, len(e.Tags)) - - for i, tag := range e.Tags { - tags[i] = tag.ToProto() - } - - experiment := protos.Experiment{ - ExperimentId: &e.ExperimentID, - Name: &e.Name, - ArtifactLocation: &e.ArtifactLocation, - LifecycleStage: utils.PtrTo(e.LifecycleStage), - CreationTime: &e.CreationTime, - LastUpdateTime: &e.LastUpdateTime, - Tags: tags, - } - - return &experiment -} +package entities + +import ( + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/utils" +) + +type Experiment 
struct { + Name string + ExperimentID string + ArtifactLocation string + LifecycleStage string + LastUpdateTime int64 + CreationTime int64 + Tags []*ExperimentTag +} + +func (e Experiment) ToProto() *protos.Experiment { + tags := make([]*protos.ExperimentTag, len(e.Tags)) + + for i, tag := range e.Tags { + tags[i] = tag.ToProto() + } + + experiment := protos.Experiment{ + ExperimentId: &e.ExperimentID, + Name: &e.Name, + ArtifactLocation: &e.ArtifactLocation, + LifecycleStage: utils.PtrTo(e.LifecycleStage), + CreationTime: &e.CreationTime, + LastUpdateTime: &e.LastUpdateTime, + Tags: tags, + } + + return &experiment +} diff --git a/pkg/entities/experiment_tag.go b/pkg/entities/experiment_tag.go index 43ef33c..ced86c1 100644 --- a/pkg/entities/experiment_tag.go +++ b/pkg/entities/experiment_tag.go @@ -1,22 +1,22 @@ -package entities - -import "github.com/mlflow/mlflow-go/pkg/protos" - -type ExperimentTag struct { - Key string - Value string -} - -func (et *ExperimentTag) ToProto() *protos.ExperimentTag { - return &protos.ExperimentTag{ - Key: &et.Key, - Value: &et.Value, - } -} - -func NewExperimentTagFromProto(proto *protos.ExperimentTag) *ExperimentTag { - return &ExperimentTag{ - Key: proto.GetKey(), - Value: proto.GetValue(), - } -} +package entities + +import "github.com/mlflow/mlflow-go/pkg/protos" + +type ExperimentTag struct { + Key string + Value string +} + +func (et *ExperimentTag) ToProto() *protos.ExperimentTag { + return &protos.ExperimentTag{ + Key: &et.Key, + Value: &et.Value, + } +} + +func NewExperimentTagFromProto(proto *protos.ExperimentTag) *ExperimentTag { + return &ExperimentTag{ + Key: proto.GetKey(), + Value: proto.GetValue(), + } +} diff --git a/pkg/entities/input_tag.go b/pkg/entities/input_tag.go index 06f01e2..67dda7c 100644 --- a/pkg/entities/input_tag.go +++ b/pkg/entities/input_tag.go @@ -1,15 +1,15 @@ -package entities - -import "github.com/mlflow/mlflow-go/pkg/protos" - -type InputTag struct { - Key string - Value string -} - -func 
(i InputTag) ToProto() *protos.InputTag { - return &protos.InputTag{ - Key: &i.Key, - Value: &i.Value, - } -} +package entities + +import "github.com/mlflow/mlflow-go/pkg/protos" + +type InputTag struct { + Key string + Value string +} + +func (i InputTag) ToProto() *protos.InputTag { + return &protos.InputTag{ + Key: &i.Key, + Value: &i.Value, + } +} diff --git a/pkg/entities/metric.go b/pkg/entities/metric.go index 75c0600..9a502f8 100644 --- a/pkg/entities/metric.go +++ b/pkg/entities/metric.go @@ -1,52 +1,52 @@ -package entities - -import ( - "math" - - "github.com/mlflow/mlflow-go/pkg/protos" - "github.com/mlflow/mlflow-go/pkg/utils" -) - -type Metric struct { - Key string - Value float64 - Timestamp int64 - Step int64 - IsNaN bool -} - -func (m Metric) ToProto() *protos.Metric { - metric := protos.Metric{ - Key: &m.Key, - Value: &m.Value, - Timestamp: &m.Timestamp, - Step: &m.Step, - } - - switch { - case m.IsNaN: - metric.Value = utils.PtrTo(math.NaN()) - default: - metric.Value = &m.Value - } - - return &metric -} - -func MetricFromProto(proto *protos.Metric) *Metric { - return &Metric{ - Key: proto.GetKey(), - Value: proto.GetValue(), - Timestamp: proto.GetTimestamp(), - Step: proto.GetStep(), - } -} - -func MetricFromLogMetricProtoInput(input *protos.LogMetric) *Metric { - return &Metric{ - Key: input.GetKey(), - Value: input.GetValue(), - Timestamp: input.GetTimestamp(), - Step: input.GetStep(), - } -} +package entities + +import ( + "math" + + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/utils" +) + +type Metric struct { + Key string + Value float64 + Timestamp int64 + Step int64 + IsNaN bool +} + +func (m Metric) ToProto() *protos.Metric { + metric := protos.Metric{ + Key: &m.Key, + Value: &m.Value, + Timestamp: &m.Timestamp, + Step: &m.Step, + } + + switch { + case m.IsNaN: + metric.Value = utils.PtrTo(math.NaN()) + default: + metric.Value = &m.Value + } + + return &metric +} + +func MetricFromProto(proto *protos.Metric) 
*Metric { + return &Metric{ + Key: proto.GetKey(), + Value: proto.GetValue(), + Timestamp: proto.GetTimestamp(), + Step: proto.GetStep(), + } +} + +func MetricFromLogMetricProtoInput(input *protos.LogMetric) *Metric { + return &Metric{ + Key: input.GetKey(), + Value: input.GetValue(), + Timestamp: input.GetTimestamp(), + Step: input.GetStep(), + } +} diff --git a/pkg/entities/param.go b/pkg/entities/param.go index 9e1693d..2d3ec01 100644 --- a/pkg/entities/param.go +++ b/pkg/entities/param.go @@ -1,22 +1,22 @@ -package entities - -import "github.com/mlflow/mlflow-go/pkg/protos" - -type Param struct { - Key string - Value string -} - -func (p Param) ToProto() *protos.Param { - return &protos.Param{ - Key: &p.Key, - Value: &p.Value, - } -} - -func ParamFromProto(proto *protos.Param) *Param { - return &Param{ - Key: *proto.Key, - Value: *proto.Value, - } -} +package entities + +import "github.com/mlflow/mlflow-go/pkg/protos" + +type Param struct { + Key string + Value string +} + +func (p Param) ToProto() *protos.Param { + return &protos.Param{ + Key: &p.Key, + Value: &p.Value, + } +} + +func ParamFromProto(proto *protos.Param) *Param { + return &Param{ + Key: *proto.Key, + Value: *proto.Value, + } +} diff --git a/pkg/entities/run.go b/pkg/entities/run.go index 3698200..c0f231f 100644 --- a/pkg/entities/run.go +++ b/pkg/entities/run.go @@ -1,75 +1,75 @@ -package entities - -import ( - "strings" - - "github.com/mlflow/mlflow-go/pkg/protos" - "github.com/mlflow/mlflow-go/pkg/utils" -) - -func RunStatusToProto(status string) *protos.RunStatus { - if status == "" { - return nil - } - - if protoStatus, ok := protos.RunStatus_value[strings.ToUpper(status)]; ok { - return (*protos.RunStatus)(&protoStatus) - } - - return nil -} - -type Run struct { - Info *RunInfo - Data *RunData - Inputs *RunInputs -} - -func (r Run) ToProto() *protos.Run { - metrics := make([]*protos.Metric, 0, len(r.Data.Metrics)) - for _, metric := range r.Data.Metrics { - metrics = append(metrics, 
metric.ToProto()) - } - - params := make([]*protos.Param, 0, len(r.Data.Params)) - for _, param := range r.Data.Params { - params = append(params, param.ToProto()) - } - - tags := make([]*protos.RunTag, 0, len(r.Data.Tags)) - for _, tag := range r.Data.Tags { - tags = append(tags, tag.ToProto()) - } - - data := &protos.RunData{ - Metrics: metrics, - Params: params, - Tags: tags, - } - - datasetInputs := make([]*protos.DatasetInput, 0, len(r.Inputs.DatasetInputs)) - for _, input := range r.Inputs.DatasetInputs { - datasetInputs = append(datasetInputs, input.ToProto()) - } - - inputs := &protos.RunInputs{ - DatasetInputs: datasetInputs, - } - - return &protos.Run{ - Info: &protos.RunInfo{ - RunId: &r.Info.RunID, - RunUuid: &r.Info.RunID, - RunName: &r.Info.RunName, - ExperimentId: utils.ConvertInt32PointerToStringPointer(&r.Info.ExperimentID), - UserId: &r.Info.UserID, - Status: RunStatusToProto(r.Info.Status), - StartTime: &r.Info.StartTime, - EndTime: r.Info.EndTime, - ArtifactUri: &r.Info.ArtifactURI, - LifecycleStage: utils.PtrTo(r.Info.LifecycleStage), - }, - Data: data, - Inputs: inputs, - } -} +package entities + +import ( + "strings" + + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/utils" +) + +func RunStatusToProto(status string) *protos.RunStatus { + if status == "" { + return nil + } + + if protoStatus, ok := protos.RunStatus_value[strings.ToUpper(status)]; ok { + return (*protos.RunStatus)(&protoStatus) + } + + return nil +} + +type Run struct { + Info *RunInfo + Data *RunData + Inputs *RunInputs +} + +func (r Run) ToProto() *protos.Run { + metrics := make([]*protos.Metric, 0, len(r.Data.Metrics)) + for _, metric := range r.Data.Metrics { + metrics = append(metrics, metric.ToProto()) + } + + params := make([]*protos.Param, 0, len(r.Data.Params)) + for _, param := range r.Data.Params { + params = append(params, param.ToProto()) + } + + tags := make([]*protos.RunTag, 0, len(r.Data.Tags)) + for _, tag := range r.Data.Tags { + 
tags = append(tags, tag.ToProto()) + } + + data := &protos.RunData{ + Metrics: metrics, + Params: params, + Tags: tags, + } + + datasetInputs := make([]*protos.DatasetInput, 0, len(r.Inputs.DatasetInputs)) + for _, input := range r.Inputs.DatasetInputs { + datasetInputs = append(datasetInputs, input.ToProto()) + } + + inputs := &protos.RunInputs{ + DatasetInputs: datasetInputs, + } + + return &protos.Run{ + Info: &protos.RunInfo{ + RunId: &r.Info.RunID, + RunUuid: &r.Info.RunID, + RunName: &r.Info.RunName, + ExperimentId: utils.ConvertInt32PointerToStringPointer(&r.Info.ExperimentID), + UserId: &r.Info.UserID, + Status: RunStatusToProto(r.Info.Status), + StartTime: &r.Info.StartTime, + EndTime: r.Info.EndTime, + ArtifactUri: &r.Info.ArtifactURI, + LifecycleStage: utils.PtrTo(r.Info.LifecycleStage), + }, + Data: data, + Inputs: inputs, + } +} diff --git a/pkg/entities/run_data.go b/pkg/entities/run_data.go index 38eeb21..92679f2 100644 --- a/pkg/entities/run_data.go +++ b/pkg/entities/run_data.go @@ -1,7 +1,7 @@ -package entities - -type RunData struct { - Tags []*RunTag - Params []*Param - Metrics []*Metric -} +package entities + +type RunData struct { + Tags []*RunTag + Params []*Param + Metrics []*Metric +} diff --git a/pkg/entities/run_info.go b/pkg/entities/run_info.go index 15eea71..51debc1 100644 --- a/pkg/entities/run_info.go +++ b/pkg/entities/run_info.go @@ -1,34 +1,34 @@ -package entities - -import ( - "github.com/mlflow/mlflow-go/pkg/protos" - "github.com/mlflow/mlflow-go/pkg/utils" -) - -type RunInfo struct { - RunID string - RunUUID string - RunName string - ExperimentID int32 - UserID string - Status string - StartTime int64 - EndTime *int64 - ArtifactURI string - LifecycleStage string -} - -func (ri RunInfo) ToProto() *protos.RunInfo { - return &protos.RunInfo{ - RunId: &ri.RunID, - RunUuid: &ri.RunID, - RunName: &ri.RunName, - ExperimentId: utils.ConvertInt32PointerToStringPointer(&ri.ExperimentID), - UserId: &ri.UserID, - Status: 
RunStatusToProto(ri.Status), - StartTime: &ri.StartTime, - EndTime: ri.EndTime, - ArtifactUri: &ri.ArtifactURI, - LifecycleStage: utils.PtrTo(ri.LifecycleStage), - } -} +package entities + +import ( + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/utils" +) + +type RunInfo struct { + RunID string + RunUUID string + RunName string + ExperimentID int32 + UserID string + Status string + StartTime int64 + EndTime *int64 + ArtifactURI string + LifecycleStage string +} + +func (ri RunInfo) ToProto() *protos.RunInfo { + return &protos.RunInfo{ + RunId: &ri.RunID, + RunUuid: &ri.RunID, + RunName: &ri.RunName, + ExperimentId: utils.ConvertInt32PointerToStringPointer(&ri.ExperimentID), + UserId: &ri.UserID, + Status: RunStatusToProto(ri.Status), + StartTime: &ri.StartTime, + EndTime: ri.EndTime, + ArtifactUri: &ri.ArtifactURI, + LifecycleStage: utils.PtrTo(ri.LifecycleStage), + } +} diff --git a/pkg/entities/run_inputs.go b/pkg/entities/run_inputs.go index c66ec9a..20bc444 100644 --- a/pkg/entities/run_inputs.go +++ b/pkg/entities/run_inputs.go @@ -1,5 +1,5 @@ -package entities - -type RunInputs struct { - DatasetInputs []*DatasetInput -} +package entities + +type RunInputs struct { + DatasetInputs []*DatasetInput +} diff --git a/pkg/entities/run_tag.go b/pkg/entities/run_tag.go index 2321507..1b31f2e 100644 --- a/pkg/entities/run_tag.go +++ b/pkg/entities/run_tag.go @@ -1,22 +1,22 @@ -package entities - -import "github.com/mlflow/mlflow-go/pkg/protos" - -type RunTag struct { - Key string - Value string -} - -func (t RunTag) ToProto() *protos.RunTag { - return &protos.RunTag{ - Key: &t.Key, - Value: &t.Value, - } -} - -func NewTagFromProto(proto *protos.RunTag) *RunTag { - return &RunTag{ - Key: proto.GetKey(), - Value: proto.GetValue(), - } -} +package entities + +import "github.com/mlflow/mlflow-go/pkg/protos" + +type RunTag struct { + Key string + Value string +} + +func (t RunTag) ToProto() *protos.RunTag { + return &protos.RunTag{ + Key: 
&t.Key, + Value: &t.Value, + } +} + +func NewTagFromProto(proto *protos.RunTag) *RunTag { + return &RunTag{ + Key: proto.GetKey(), + Value: proto.GetValue(), + } +} diff --git a/pkg/lib/artifacts.go b/pkg/lib/artifacts.go index 198dc5e..4aaca17 100644 --- a/pkg/lib/artifacts.go +++ b/pkg/lib/artifacts.go @@ -1,22 +1,22 @@ -package main - -import "C" - -import ( - "unsafe" - - "github.com/mlflow/mlflow-go/pkg/artifacts/service" -) - -var artifactsServices = newInstanceMap[*service.ArtifactsService]() - -//export CreateArtifactsService -func CreateArtifactsService(configData unsafe.Pointer, configSize C.int) int64 { - //nolint:nlreturn - return artifactsServices.Create(service.NewArtifactsService, C.GoBytes(configData, configSize)) -} - -//export DestroyArtifactsService -func DestroyArtifactsService(id int64) { - artifactsServices.Destroy(id) -} +package main + +import "C" + +import ( + "unsafe" + + "github.com/mlflow/mlflow-go/pkg/artifacts/service" +) + +var artifactsServices = newInstanceMap[*service.ArtifactsService]() + +//export CreateArtifactsService +func CreateArtifactsService(configData unsafe.Pointer, configSize C.int) int64 { + //nolint:nlreturn + return artifactsServices.Create(service.NewArtifactsService, C.GoBytes(configData, configSize)) +} + +//export DestroyArtifactsService +func DestroyArtifactsService(id int64) { + artifactsServices.Destroy(id) +} diff --git a/pkg/lib/ffi.go b/pkg/lib/ffi.go index 60751d2..335a828 100644 --- a/pkg/lib/ffi.go +++ b/pkg/lib/ffi.go @@ -1,91 +1,91 @@ -package main - -import "C" - -import ( - "context" - "encoding/json" - "unsafe" - - "google.golang.org/protobuf/proto" - - "github.com/mlflow/mlflow-go/pkg/contract" - "github.com/mlflow/mlflow-go/pkg/protos" - "github.com/mlflow/mlflow-go/pkg/validation" -) - -func unmarshalAndValidateProto( - data []byte, - msg proto.Message, -) *contract.Error { - if err := proto.Unmarshal(data, msg); err != nil { - return contract.NewError( - protos.ErrorCode_BAD_REQUEST, - 
err.Error(), - ) - } - - validate, cErr := getValidator() - if cErr != nil { - return cErr - } - - if err := validate.Struct(msg); err != nil { - return validation.NewErrorFromValidationError(err) - } - - return nil -} - -func marshalProto(msg proto.Message) ([]byte, *contract.Error) { - res, err := proto.Marshal(msg) - if err != nil { - return nil, contract.NewError( - protos.ErrorCode_INTERNAL_ERROR, - err.Error(), - ) - } - - return res, nil -} - -func makePointerFromBytes(data []byte, size *C.int) unsafe.Pointer { - *size = C.int(len(data)) - - return C.CBytes(data) //nolint:nlreturn -} - -func makePointerFromError(err *contract.Error, size *C.int) unsafe.Pointer { - data, _ := json.Marshal(err) //nolint:errchkjson - - return makePointerFromBytes(data, size) -} - -// invokeServiceMethod is a helper function that invokes a service method and handles -// marshalling/unmarshalling of request/response data through the FFI boundary. -func invokeServiceMethod[I, O proto.Message]( - serviceMethod func(context.Context, I) (O, *contract.Error), - request I, - requestData unsafe.Pointer, - requestSize C.int, - responseSize *C.int, -) unsafe.Pointer { - requestBytes := C.GoBytes(requestData, requestSize) //nolint:nlreturn - - err := unmarshalAndValidateProto(requestBytes, request) - if err != nil { - return makePointerFromError(err, responseSize) - } - - response, err := serviceMethod(context.Background(), request) - if err != nil { - return makePointerFromError(err, responseSize) - } - - responseBytes, err := marshalProto(response) - if err != nil { - return makePointerFromError(err, responseSize) - } - - return makePointerFromBytes(responseBytes, responseSize) -} +package main + +import "C" + +import ( + "context" + "encoding/json" + "unsafe" + + "google.golang.org/protobuf/proto" + + "github.com/mlflow/mlflow-go/pkg/contract" + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/validation" +) + +func unmarshalAndValidateProto( + data []byte, + 
msg proto.Message, +) *contract.Error { + if err := proto.Unmarshal(data, msg); err != nil { + return contract.NewError( + protos.ErrorCode_BAD_REQUEST, + err.Error(), + ) + } + + validate, cErr := getValidator() + if cErr != nil { + return cErr + } + + if err := validate.Struct(msg); err != nil { + return validation.NewErrorFromValidationError(err) + } + + return nil +} + +func marshalProto(msg proto.Message) ([]byte, *contract.Error) { + res, err := proto.Marshal(msg) + if err != nil { + return nil, contract.NewError( + protos.ErrorCode_INTERNAL_ERROR, + err.Error(), + ) + } + + return res, nil +} + +func makePointerFromBytes(data []byte, size *C.int) unsafe.Pointer { + *size = C.int(len(data)) + + return C.CBytes(data) //nolint:nlreturn +} + +func makePointerFromError(err *contract.Error, size *C.int) unsafe.Pointer { + data, _ := json.Marshal(err) //nolint:errchkjson + + return makePointerFromBytes(data, size) +} + +// invokeServiceMethod is a helper function that invokes a service method and handles +// marshalling/unmarshalling of request/response data through the FFI boundary. 
+func invokeServiceMethod[I, O proto.Message]( + serviceMethod func(context.Context, I) (O, *contract.Error), + request I, + requestData unsafe.Pointer, + requestSize C.int, + responseSize *C.int, +) unsafe.Pointer { + requestBytes := C.GoBytes(requestData, requestSize) //nolint:nlreturn + + err := unmarshalAndValidateProto(requestBytes, request) + if err != nil { + return makePointerFromError(err, responseSize) + } + + response, err := serviceMethod(context.Background(), request) + if err != nil { + return makePointerFromError(err, responseSize) + } + + responseBytes, err := marshalProto(response) + if err != nil { + return makePointerFromError(err, responseSize) + } + + return makePointerFromBytes(responseBytes, responseSize) +} diff --git a/pkg/lib/instance_map.go b/pkg/lib/instance_map.go index dabf1a9..70c778d 100644 --- a/pkg/lib/instance_map.go +++ b/pkg/lib/instance_map.go @@ -1,78 +1,78 @@ -package main - -import ( - "context" - "sync" - - "github.com/sirupsen/logrus" - - "github.com/mlflow/mlflow-go/pkg/config" - "github.com/mlflow/mlflow-go/pkg/contract" - "github.com/mlflow/mlflow-go/pkg/protos" - "github.com/mlflow/mlflow-go/pkg/utils" -) - -type instanceMap[T any] struct { - counter int64 - mutex sync.Mutex - instances map[int64]T -} - -func newInstanceMap[T any]() *instanceMap[T] { - return &instanceMap[T]{ - instances: make(map[int64]T), - } -} - -//nolint:ireturn -func (s *instanceMap[T]) Get(id int64) (T, *contract.Error) { - instance, ok := s.instances[id] - if !ok { - return instance, contract.NewError( - protos.ErrorCode_RESOURCE_DOES_NOT_EXIST, - "Instance not found", - ) - } - - return instance, nil -} - -func (s *instanceMap[T]) Create( - creator func(ctx context.Context, cfg *config.Config) (T, error), - configBytes []byte, -) int64 { - cfg, err := config.NewConfigFromBytes(configBytes) - if err != nil { - logrus.Error("Failed to read config: ", err) - - return -1 - } - - logger := utils.NewLoggerFromConfig(cfg) - - logger.Debugf("Loaded 
config: %#v", cfg) - - instance, err := creator( - utils.NewContextWithLogger(context.Background(), logger), - cfg, - ) - if err != nil { - logger.Error("Failed to create instance: ", err) - - return -1 - } - - s.mutex.Lock() - defer s.mutex.Unlock() - - s.counter++ - s.instances[s.counter] = instance - - return s.counter -} - -func (s *instanceMap[T]) Destroy(id int64) { - s.mutex.Lock() - defer s.mutex.Unlock() - delete(s.instances, id) -} +package main + +import ( + "context" + "sync" + + "github.com/sirupsen/logrus" + + "github.com/mlflow/mlflow-go/pkg/config" + "github.com/mlflow/mlflow-go/pkg/contract" + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/utils" +) + +type instanceMap[T any] struct { + counter int64 + mutex sync.Mutex + instances map[int64]T +} + +func newInstanceMap[T any]() *instanceMap[T] { + return &instanceMap[T]{ + instances: make(map[int64]T), + } +} + +//nolint:ireturn +func (s *instanceMap[T]) Get(id int64) (T, *contract.Error) { + instance, ok := s.instances[id] + if !ok { + return instance, contract.NewError( + protos.ErrorCode_RESOURCE_DOES_NOT_EXIST, + "Instance not found", + ) + } + + return instance, nil +} + +func (s *instanceMap[T]) Create( + creator func(ctx context.Context, cfg *config.Config) (T, error), + configBytes []byte, +) int64 { + cfg, err := config.NewConfigFromBytes(configBytes) + if err != nil { + logrus.Error("Failed to read config: ", err) + + return -1 + } + + logger := utils.NewLoggerFromConfig(cfg) + + logger.Debugf("Loaded config: %#v", cfg) + + instance, err := creator( + utils.NewContextWithLogger(context.Background(), logger), + cfg, + ) + if err != nil { + logger.Error("Failed to create instance: ", err) + + return -1 + } + + s.mutex.Lock() + defer s.mutex.Unlock() + + s.counter++ + s.instances[s.counter] = instance + + return s.counter +} + +func (s *instanceMap[T]) Destroy(id int64) { + s.mutex.Lock() + defer s.mutex.Unlock() + delete(s.instances, id) +} diff --git 
a/pkg/lib/main.go b/pkg/lib/main.go index c8a27b4..38dd16d 100644 --- a/pkg/lib/main.go +++ b/pkg/lib/main.go @@ -1,3 +1,3 @@ -package main - -func main() {} +package main + +func main() {} diff --git a/pkg/lib/model_registry.go b/pkg/lib/model_registry.go index 13133f6..f7efb6c 100644 --- a/pkg/lib/model_registry.go +++ b/pkg/lib/model_registry.go @@ -1,22 +1,22 @@ -package main - -import "C" - -import ( - "unsafe" - - "github.com/mlflow/mlflow-go/pkg/model_registry/service" -) - -var modelRegistryServices = newInstanceMap[*service.ModelRegistryService]() - -//export CreateModelRegistryService -func CreateModelRegistryService(configData unsafe.Pointer, configSize C.int) int64 { - //nolint:nlreturn - return modelRegistryServices.Create(service.NewModelRegistryService, C.GoBytes(configData, configSize)) -} - -//export DestroyModelRegistryService -func DestroyModelRegistryService(id int64) { - modelRegistryServices.Destroy(id) -} +package main + +import "C" + +import ( + "unsafe" + + "github.com/mlflow/mlflow-go/pkg/model_registry/service" +) + +var modelRegistryServices = newInstanceMap[*service.ModelRegistryService]() + +//export CreateModelRegistryService +func CreateModelRegistryService(configData unsafe.Pointer, configSize C.int) int64 { + //nolint:nlreturn + return modelRegistryServices.Create(service.NewModelRegistryService, C.GoBytes(configData, configSize)) +} + +//export DestroyModelRegistryService +func DestroyModelRegistryService(id int64) { + modelRegistryServices.Destroy(id) +} diff --git a/pkg/lib/server.go b/pkg/lib/server.go index 912041a..a2d02d2 100644 --- a/pkg/lib/server.go +++ b/pkg/lib/server.go @@ -1,83 +1,83 @@ -package main - -import "C" - -import ( - "context" - "unsafe" - - "github.com/sirupsen/logrus" - - "github.com/mlflow/mlflow-go/pkg/config" - "github.com/mlflow/mlflow-go/pkg/server" -) - -type serverInstance struct { - cancel context.CancelFunc - errChan <-chan error -} - -var serverInstances = newInstanceMap[serverInstance]() - 
-//export LaunchServer -func LaunchServer(configData unsafe.Pointer, configSize C.int) int64 { - cfg, err := config.NewConfigFromBytes(C.GoBytes(configData, configSize)) //nolint:nlreturn - if err != nil { - logrus.Error("Failed to read config: ", err) - - return -1 - } - - if err := server.LaunchWithSignalHandler(cfg); err != nil { - logrus.Error("Failed to launch server: ", err) - - return -1 - } - - return 0 -} - -//export LaunchServerAsync -func LaunchServerAsync(configData unsafe.Pointer, configSize C.int) int64 { - serverID := serverInstances.Create( - func(ctx context.Context, cfg *config.Config) (serverInstance, error) { - errChan := make(chan error, 1) - - ctx, cancel := context.WithCancel(ctx) - - go func() { - errChan <- server.Launch(ctx, cfg) - }() - - return serverInstance{ - cancel: cancel, - errChan: errChan, - }, nil - }, - C.GoBytes(configData, configSize), //nolint:nlreturn - ) - - return serverID -} - -//export StopServer -func StopServer(serverID int64) int64 { - instance, cErr := serverInstances.Get(serverID) - if cErr != nil { - logrus.Error("Failed to get instance: ", cErr) - - return -1 - } - defer serverInstances.Destroy(serverID) - - instance.cancel() - - err := <-instance.errChan - if err != nil { - logrus.Error("Server has exited with error: ", err) - - return -1 - } - - return 0 -} +package main + +import "C" + +import ( + "context" + "unsafe" + + "github.com/sirupsen/logrus" + + "github.com/mlflow/mlflow-go/pkg/config" + "github.com/mlflow/mlflow-go/pkg/server" +) + +type serverInstance struct { + cancel context.CancelFunc + errChan <-chan error +} + +var serverInstances = newInstanceMap[serverInstance]() + +//export LaunchServer +func LaunchServer(configData unsafe.Pointer, configSize C.int) int64 { + cfg, err := config.NewConfigFromBytes(C.GoBytes(configData, configSize)) //nolint:nlreturn + if err != nil { + logrus.Error("Failed to read config: ", err) + + return -1 + } + + if err := server.LaunchWithSignalHandler(cfg); err != nil 
{ + logrus.Error("Failed to launch server: ", err) + + return -1 + } + + return 0 +} + +//export LaunchServerAsync +func LaunchServerAsync(configData unsafe.Pointer, configSize C.int) int64 { + serverID := serverInstances.Create( + func(ctx context.Context, cfg *config.Config) (serverInstance, error) { + errChan := make(chan error, 1) + + ctx, cancel := context.WithCancel(ctx) + + go func() { + errChan <- server.Launch(ctx, cfg) + }() + + return serverInstance{ + cancel: cancel, + errChan: errChan, + }, nil + }, + C.GoBytes(configData, configSize), //nolint:nlreturn + ) + + return serverID +} + +//export StopServer +func StopServer(serverID int64) int64 { + instance, cErr := serverInstances.Get(serverID) + if cErr != nil { + logrus.Error("Failed to get instance: ", cErr) + + return -1 + } + defer serverInstances.Destroy(serverID) + + instance.cancel() + + err := <-instance.errChan + if err != nil { + logrus.Error("Server has exited with error: ", err) + + return -1 + } + + return 0 +} diff --git a/pkg/lib/tracking.go b/pkg/lib/tracking.go index 800b5ac..8cd6257 100644 --- a/pkg/lib/tracking.go +++ b/pkg/lib/tracking.go @@ -1,22 +1,22 @@ -package main - -import "C" - -import ( - "unsafe" - - "github.com/mlflow/mlflow-go/pkg/tracking/service" -) - -var trackingServices = newInstanceMap[*service.TrackingService]() - -//export CreateTrackingService -func CreateTrackingService(configData unsafe.Pointer, configSize C.int) int64 { - //nolint:nlreturn - return trackingServices.Create(service.NewTrackingService, C.GoBytes(configData, configSize)) -} - -//export DestroyTrackingService -func DestroyTrackingService(id int64) { - trackingServices.Destroy(id) -} +package main + +import "C" + +import ( + "unsafe" + + "github.com/mlflow/mlflow-go/pkg/tracking/service" +) + +var trackingServices = newInstanceMap[*service.TrackingService]() + +//export CreateTrackingService +func CreateTrackingService(configData unsafe.Pointer, configSize C.int) int64 { + //nolint:nlreturn + return 
trackingServices.Create(service.NewTrackingService, C.GoBytes(configData, configSize)) +} + +//export DestroyTrackingService +func DestroyTrackingService(id int64) { + trackingServices.Destroy(id) +} diff --git a/pkg/lib/validation.go b/pkg/lib/validation.go index b7bf64f..f46526c 100644 --- a/pkg/lib/validation.go +++ b/pkg/lib/validation.go @@ -1,23 +1,23 @@ -package main - -import ( - "sync" - - "github.com/go-playground/validator/v10" - - "github.com/mlflow/mlflow-go/pkg/contract" - "github.com/mlflow/mlflow-go/pkg/protos" - "github.com/mlflow/mlflow-go/pkg/validation" -) - -var getValidator = sync.OnceValues(func() (*validator.Validate, *contract.Error) { - validate, err := validation.NewValidator() - if err != nil { - return nil, contract.NewError( - protos.ErrorCode_INTERNAL_ERROR, - err.Error(), - ) - } - - return validate, nil -}) +package main + +import ( + "sync" + + "github.com/go-playground/validator/v10" + + "github.com/mlflow/mlflow-go/pkg/contract" + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/validation" +) + +var getValidator = sync.OnceValues(func() (*validator.Validate, *contract.Error) { + validate, err := validation.NewValidator() + if err != nil { + return nil, contract.NewError( + protos.ErrorCode_INTERNAL_ERROR, + err.Error(), + ) + } + + return validate, nil +}) diff --git a/pkg/model_registry/service/model_versions.go b/pkg/model_registry/service/model_versions.go index 1c18d8f..f4b3d3f 100644 --- a/pkg/model_registry/service/model_versions.go +++ b/pkg/model_registry/service/model_versions.go @@ -1,21 +1,21 @@ -package service - -import ( - "context" - - "github.com/mlflow/mlflow-go/pkg/contract" - "github.com/mlflow/mlflow-go/pkg/protos" -) - -func (m *ModelRegistryService) GetLatestVersions( - ctx context.Context, input *protos.GetLatestVersions, -) (*protos.GetLatestVersions_Response, *contract.Error) { - latestVersions, err := m.store.GetLatestVersions(ctx, input.GetName(), input.GetStages()) - if err 
!= nil { - return nil, err - } - - return &protos.GetLatestVersions_Response{ - ModelVersions: latestVersions, - }, nil -} +package service + +import ( + "context" + + "github.com/mlflow/mlflow-go/pkg/contract" + "github.com/mlflow/mlflow-go/pkg/protos" +) + +func (m *ModelRegistryService) GetLatestVersions( + ctx context.Context, input *protos.GetLatestVersions, +) (*protos.GetLatestVersions_Response, *contract.Error) { + latestVersions, err := m.store.GetLatestVersions(ctx, input.GetName(), input.GetStages()) + if err != nil { + return nil, err + } + + return &protos.GetLatestVersions_Response{ + ModelVersions: latestVersions, + }, nil +} diff --git a/pkg/model_registry/service/service.go b/pkg/model_registry/service/service.go index ae59d3e..8be5f38 100644 --- a/pkg/model_registry/service/service.go +++ b/pkg/model_registry/service/service.go @@ -1,27 +1,27 @@ -package service - -import ( - "context" - "fmt" - - "github.com/mlflow/mlflow-go/pkg/config" - "github.com/mlflow/mlflow-go/pkg/model_registry/store" - "github.com/mlflow/mlflow-go/pkg/model_registry/store/sql" -) - -type ModelRegistryService struct { - store store.ModelRegistryStore - config *config.Config -} - -func NewModelRegistryService(ctx context.Context, config *config.Config) (*ModelRegistryService, error) { - store, err := sql.NewModelRegistrySQLStore(ctx, config) - if err != nil { - return nil, fmt.Errorf("failed to create new sql store: %w", err) - } - - return &ModelRegistryService{ - store: store, - config: config, - }, nil -} +package service + +import ( + "context" + "fmt" + + "github.com/mlflow/mlflow-go/pkg/config" + "github.com/mlflow/mlflow-go/pkg/model_registry/store" + "github.com/mlflow/mlflow-go/pkg/model_registry/store/sql" +) + +type ModelRegistryService struct { + store store.ModelRegistryStore + config *config.Config +} + +func NewModelRegistryService(ctx context.Context, config *config.Config) (*ModelRegistryService, error) { + store, err := sql.NewModelRegistrySQLStore(ctx, 
config) + if err != nil { + return nil, fmt.Errorf("failed to create new sql store: %w", err) + } + + return &ModelRegistryService{ + store: store, + config: config, + }, nil +} diff --git a/pkg/model_registry/store/sql/model_versions.go b/pkg/model_registry/store/sql/model_versions.go index ccf5be0..cbe3b45 100644 --- a/pkg/model_registry/store/sql/model_versions.go +++ b/pkg/model_registry/store/sql/model_versions.go @@ -1,94 +1,94 @@ -package sql - -import ( - "context" - "errors" - "fmt" - "strings" - - "gorm.io/gorm" - - "github.com/mlflow/mlflow-go/pkg/contract" - "github.com/mlflow/mlflow-go/pkg/model_registry/store/sql/models" - "github.com/mlflow/mlflow-go/pkg/protos" -) - -// Validate whether there is a registered model with the given name. -func assertModelExists(db *gorm.DB, name string) *contract.Error { - if err := db.Select("name").Where("name = ?", name).First(&models.RegisteredModel{}).Error; err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return contract.NewError( - protos.ErrorCode_RESOURCE_DOES_NOT_EXIST, - fmt.Sprintf("registered model with name=%q not found", name), - ) - } - - return contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - fmt.Sprintf("failed to query registered model with name=%q", name), - err, - ) - } - - return nil -} - -func (m *ModelRegistrySQLStore) GetLatestVersions( - ctx context.Context, name string, stages []string, -) ([]*protos.ModelVersion, *contract.Error) { - if err := assertModelExists(m.db.WithContext(ctx), name); err != nil { - return nil, err - } - - var modelVersions []*models.ModelVersion - - subQuery := m.db. - WithContext(ctx). - Model(&models.ModelVersion{}). - Select("name, MAX(version) AS max_version"). - Where("name = ?", name). - Where("current_stage <> ?", models.StageDeletedInternal). 
- Group("name, current_stage") - - if len(stages) > 0 { - for idx, stage := range stages { - stages[idx] = strings.ToLower(stage) - if canonicalStage, ok := models.CanonicalMapping[stages[idx]]; ok { - stages[idx] = canonicalStage - - continue - } - - return nil, contract.NewError( - protos.ErrorCode_BAD_REQUEST, - fmt.Sprintf( - "Invalid Model Version stage: %s. Value must be one of %s.", - stage, - models.AllModelVersionStages(), - ), - ) - } - - subQuery = subQuery.Where("current_stage IN (?)", stages) - } - - err := m.db. - WithContext(ctx). - Model(&models.ModelVersion{}). - Joins("JOIN (?) AS sub ON model_versions.name = sub.name AND model_versions.version = sub.max_version", subQuery). - Find(&modelVersions).Error - if err != nil { - return nil, contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - fmt.Sprintf("failed to query latest model version for %q", name), - err, - ) - } - - results := make([]*protos.ModelVersion, 0, len(modelVersions)) - for _, modelVersion := range modelVersions { - results = append(results, modelVersion.ToProto()) - } - - return results, nil -} +package sql + +import ( + "context" + "errors" + "fmt" + "strings" + + "gorm.io/gorm" + + "github.com/mlflow/mlflow-go/pkg/contract" + "github.com/mlflow/mlflow-go/pkg/model_registry/store/sql/models" + "github.com/mlflow/mlflow-go/pkg/protos" +) + +// Validate whether there is a registered model with the given name. 
+func assertModelExists(db *gorm.DB, name string) *contract.Error { + if err := db.Select("name").Where("name = ?", name).First(&models.RegisteredModel{}).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return contract.NewError( + protos.ErrorCode_RESOURCE_DOES_NOT_EXIST, + fmt.Sprintf("registered model with name=%q not found", name), + ) + } + + return contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf("failed to query registered model with name=%q", name), + err, + ) + } + + return nil +} + +func (m *ModelRegistrySQLStore) GetLatestVersions( + ctx context.Context, name string, stages []string, +) ([]*protos.ModelVersion, *contract.Error) { + if err := assertModelExists(m.db.WithContext(ctx), name); err != nil { + return nil, err + } + + var modelVersions []*models.ModelVersion + + subQuery := m.db. + WithContext(ctx). + Model(&models.ModelVersion{}). + Select("name, MAX(version) AS max_version"). + Where("name = ?", name). + Where("current_stage <> ?", models.StageDeletedInternal). + Group("name, current_stage") + + if len(stages) > 0 { + for idx, stage := range stages { + stages[idx] = strings.ToLower(stage) + if canonicalStage, ok := models.CanonicalMapping[stages[idx]]; ok { + stages[idx] = canonicalStage + + continue + } + + return nil, contract.NewError( + protos.ErrorCode_BAD_REQUEST, + fmt.Sprintf( + "Invalid Model Version stage: %s. Value must be one of %s.", + stage, + models.AllModelVersionStages(), + ), + ) + } + + subQuery = subQuery.Where("current_stage IN (?)", stages) + } + + err := m.db. + WithContext(ctx). + Model(&models.ModelVersion{}). + Joins("JOIN (?) AS sub ON model_versions.name = sub.name AND model_versions.version = sub.max_version", subQuery). 
+ Find(&modelVersions).Error + if err != nil { + return nil, contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf("failed to query latest model version for %q", name), + err, + ) + } + + results := make([]*protos.ModelVersion, 0, len(modelVersions)) + for _, modelVersion := range modelVersions { + results = append(results, modelVersion.ToProto()) + } + + return results, nil +} diff --git a/pkg/model_registry/store/sql/models/model_version_stage.go b/pkg/model_registry/store/sql/models/model_version_stage.go index 56cc070..4020f36 100644 --- a/pkg/model_registry/store/sql/models/model_version_stage.go +++ b/pkg/model_registry/store/sql/models/model_version_stage.go @@ -1,33 +1,33 @@ -package models - -import "strings" - -type ModelVersionStage string - -func (s ModelVersionStage) String() string { - return string(s) -} - -const ( - ModelVersionStageNone = "None" - ModelVersionStageStaging = "Staging" - ModelVersionStageProduction = "Production" - ModelVersionStageArchived = "Archived" -) - -var CanonicalMapping = map[string]string{ - strings.ToLower(ModelVersionStageNone): ModelVersionStageNone, - strings.ToLower(ModelVersionStageStaging): ModelVersionStageStaging, - strings.ToLower(ModelVersionStageProduction): ModelVersionStageProduction, - strings.ToLower(ModelVersionStageArchived): ModelVersionStageArchived, -} - -func AllModelVersionStages() string { - pairs := make([]string, 0, len(CanonicalMapping)) - - for _, v := range CanonicalMapping { - pairs = append(pairs, v) - } - - return strings.Join(pairs, ",") -} +package models + +import "strings" + +type ModelVersionStage string + +func (s ModelVersionStage) String() string { + return string(s) +} + +const ( + ModelVersionStageNone = "None" + ModelVersionStageStaging = "Staging" + ModelVersionStageProduction = "Production" + ModelVersionStageArchived = "Archived" +) + +var CanonicalMapping = map[string]string{ + strings.ToLower(ModelVersionStageNone): ModelVersionStageNone, + 
strings.ToLower(ModelVersionStageStaging): ModelVersionStageStaging, + strings.ToLower(ModelVersionStageProduction): ModelVersionStageProduction, + strings.ToLower(ModelVersionStageArchived): ModelVersionStageArchived, +} + +func AllModelVersionStages() string { + pairs := make([]string, 0, len(CanonicalMapping)) + + for _, v := range CanonicalMapping { + pairs = append(pairs, v) + } + + return strings.Join(pairs, ",") +} diff --git a/pkg/model_registry/store/sql/models/model_version_tags.go b/pkg/model_registry/store/sql/models/model_version_tags.go index 16c4ab0..7bde392 100644 --- a/pkg/model_registry/store/sql/models/model_version_tags.go +++ b/pkg/model_registry/store/sql/models/model_version_tags.go @@ -1,11 +1,11 @@ -package models - -// ModelVersionTag mapped from table . -// -//revive:disable:exported -type ModelVersionTag struct { - Key string `db:"key" gorm:"column:key;primaryKey"` - Value string `db:"value" gorm:"column:value"` - Name string `db:"name" gorm:"column:name;primaryKey"` - Version int32 `db:"version" gorm:"column:version;primaryKey"` -} +package models + +// ModelVersionTag mapped from table . +// +//revive:disable:exported +type ModelVersionTag struct { + Key string `db:"key" gorm:"column:key;primaryKey"` + Value string `db:"value" gorm:"column:value"` + Name string `db:"name" gorm:"column:name;primaryKey"` + Version int32 `db:"version" gorm:"column:version;primaryKey"` +} diff --git a/pkg/model_registry/store/sql/models/model_versions.go b/pkg/model_registry/store/sql/models/model_versions.go index c6eb4bc..d2b373c 100644 --- a/pkg/model_registry/store/sql/models/model_versions.go +++ b/pkg/model_registry/store/sql/models/model_versions.go @@ -1,49 +1,49 @@ -package models - -import ( - "github.com/mlflow/mlflow-go/pkg/protos" - "github.com/mlflow/mlflow-go/pkg/utils" -) - -// ModelVersion mapped from table . 
-// -//revive:disable:exported -type ModelVersion struct { - Name string `db:"name" gorm:"column:name;primaryKey"` - Version int32 `db:"version" gorm:"column:version;primaryKey"` - CreationTime int64 `db:"creation_time" gorm:"column:creation_time"` - LastUpdatedTime int64 `db:"last_updated_time" gorm:"column:last_updated_time"` - Description string `db:"description" gorm:"column:description"` - UserID string `db:"user_id" gorm:"column:user_id"` - CurrentStage ModelVersionStage `db:"current_stage" gorm:"column:current_stage"` - Source string `db:"source" gorm:"column:source"` - RunID string `db:"run_id" gorm:"column:run_id"` - Status string `db:"status" gorm:"column:status"` - StatusMessage string `db:"status_message" gorm:"column:status_message"` - RunLink string `db:"run_link" gorm:"column:run_link"` - StorageLocation string `db:"storage_location" gorm:"column:storage_location"` -} - -const StageDeletedInternal = "Deleted_Internal" - -func (mv ModelVersion) ToProto() *protos.ModelVersion { - var status *protos.ModelVersionStatus - if s, ok := protos.ModelVersionStatus_value[mv.Status]; ok { - status = utils.PtrTo(protos.ModelVersionStatus(s)) - } - - return &protos.ModelVersion{ - Name: &mv.Name, - Version: utils.ConvertInt32PointerToStringPointer(&mv.Version), - CreationTimestamp: &mv.CreationTime, - LastUpdatedTimestamp: &mv.LastUpdatedTime, - UserId: &mv.UserID, - CurrentStage: utils.PtrTo(mv.CurrentStage.String()), - Description: &mv.Description, - Source: &mv.Source, - RunId: &mv.RunID, - Status: status, - StatusMessage: &mv.StatusMessage, - RunLink: &mv.RunLink, - } -} +package models + +import ( + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/utils" +) + +// ModelVersion mapped from table . 
+// +//revive:disable:exported +type ModelVersion struct { + Name string `db:"name" gorm:"column:name;primaryKey"` + Version int32 `db:"version" gorm:"column:version;primaryKey"` + CreationTime int64 `db:"creation_time" gorm:"column:creation_time"` + LastUpdatedTime int64 `db:"last_updated_time" gorm:"column:last_updated_time"` + Description string `db:"description" gorm:"column:description"` + UserID string `db:"user_id" gorm:"column:user_id"` + CurrentStage ModelVersionStage `db:"current_stage" gorm:"column:current_stage"` + Source string `db:"source" gorm:"column:source"` + RunID string `db:"run_id" gorm:"column:run_id"` + Status string `db:"status" gorm:"column:status"` + StatusMessage string `db:"status_message" gorm:"column:status_message"` + RunLink string `db:"run_link" gorm:"column:run_link"` + StorageLocation string `db:"storage_location" gorm:"column:storage_location"` +} + +const StageDeletedInternal = "Deleted_Internal" + +func (mv ModelVersion) ToProto() *protos.ModelVersion { + var status *protos.ModelVersionStatus + if s, ok := protos.ModelVersionStatus_value[mv.Status]; ok { + status = utils.PtrTo(protos.ModelVersionStatus(s)) + } + + return &protos.ModelVersion{ + Name: &mv.Name, + Version: utils.ConvertInt32PointerToStringPointer(&mv.Version), + CreationTimestamp: &mv.CreationTime, + LastUpdatedTimestamp: &mv.LastUpdatedTime, + UserId: &mv.UserID, + CurrentStage: utils.PtrTo(mv.CurrentStage.String()), + Description: &mv.Description, + Source: &mv.Source, + RunId: &mv.RunID, + Status: status, + StatusMessage: &mv.StatusMessage, + RunLink: &mv.RunLink, + } +} diff --git a/pkg/model_registry/store/sql/models/registered_model_aliases.go b/pkg/model_registry/store/sql/models/registered_model_aliases.go index d720770..2cdf25a 100644 --- a/pkg/model_registry/store/sql/models/registered_model_aliases.go +++ b/pkg/model_registry/store/sql/models/registered_model_aliases.go @@ -1,8 +1,8 @@ -package models - -// RegisteredModelAlias mapped from table . 
-type RegisteredModelAlias struct { - Alias string `db:"alias" gorm:"column:alias;primaryKey"` - Version int32 `db:"version" gorm:"column:version;not null"` - Name string `db:"name" gorm:"column:name;primaryKey"` -} +package models + +// RegisteredModelAlias mapped from table . +type RegisteredModelAlias struct { + Alias string `db:"alias" gorm:"column:alias;primaryKey"` + Version int32 `db:"version" gorm:"column:version;not null"` + Name string `db:"name" gorm:"column:name;primaryKey"` +} diff --git a/pkg/model_registry/store/sql/models/registered_model_tags.go b/pkg/model_registry/store/sql/models/registered_model_tags.go index 99d4896..6935047 100644 --- a/pkg/model_registry/store/sql/models/registered_model_tags.go +++ b/pkg/model_registry/store/sql/models/registered_model_tags.go @@ -1,8 +1,8 @@ -package models - -// RegisteredModelTag mapped from table . -type RegisteredModelTag struct { - Key string `db:"key" gorm:"column:key;primaryKey"` - Value string `db:"value" gorm:"column:value"` - Name string `db:"name" gorm:"column:name;primaryKey"` -} +package models + +// RegisteredModelTag mapped from table . +type RegisteredModelTag struct { + Key string `db:"key" gorm:"column:key;primaryKey"` + Value string `db:"value" gorm:"column:value"` + Name string `db:"name" gorm:"column:name;primaryKey"` +} diff --git a/pkg/model_registry/store/sql/models/registered_models.go b/pkg/model_registry/store/sql/models/registered_models.go index d13b6d9..0a99a30 100644 --- a/pkg/model_registry/store/sql/models/registered_models.go +++ b/pkg/model_registry/store/sql/models/registered_models.go @@ -1,9 +1,9 @@ -package models - -// RegisteredModel mapped from table . 
-type RegisteredModel struct { - Name string `db:"name" gorm:"column:name;primaryKey"` - CreationTime int64 `db:"creation_time" gorm:"column:creation_time"` - LastUpdatedTime int64 `db:"last_updated_time" gorm:"column:last_updated_time"` - Description string `db:"description" gorm:"column:description"` -} +package models + +// RegisteredModel mapped from table . +type RegisteredModel struct { + Name string `db:"name" gorm:"column:name;primaryKey"` + CreationTime int64 `db:"creation_time" gorm:"column:creation_time"` + LastUpdatedTime int64 `db:"last_updated_time" gorm:"column:last_updated_time"` + Description string `db:"description" gorm:"column:description"` +} diff --git a/pkg/model_registry/store/sql/store.go b/pkg/model_registry/store/sql/store.go index a8a0f46..1bfe8b2 100644 --- a/pkg/model_registry/store/sql/store.go +++ b/pkg/model_registry/store/sql/store.go @@ -1,28 +1,28 @@ -package sql - -import ( - "context" - "fmt" - - "gorm.io/gorm" - - "github.com/mlflow/mlflow-go/pkg/config" - "github.com/mlflow/mlflow-go/pkg/sql" -) - -type ModelRegistrySQLStore struct { - config *config.Config - db *gorm.DB -} - -func NewModelRegistrySQLStore(ctx context.Context, config *config.Config) (*ModelRegistrySQLStore, error) { - database, err := sql.NewDatabase(ctx, config.ModelRegistryStoreURI) - if err != nil { - return nil, fmt.Errorf("failed to connect to database %q: %w", config.ModelRegistryStoreURI, err) - } - - return &ModelRegistrySQLStore{ - config: config, - db: database, - }, nil -} +package sql + +import ( + "context" + "fmt" + + "gorm.io/gorm" + + "github.com/mlflow/mlflow-go/pkg/config" + "github.com/mlflow/mlflow-go/pkg/sql" +) + +type ModelRegistrySQLStore struct { + config *config.Config + db *gorm.DB +} + +func NewModelRegistrySQLStore(ctx context.Context, config *config.Config) (*ModelRegistrySQLStore, error) { + database, err := sql.NewDatabase(ctx, config.ModelRegistryStoreURI) + if err != nil { + return nil, fmt.Errorf("failed to connect to 
database %q: %w", config.ModelRegistryStoreURI, err) + } + + return &ModelRegistrySQLStore{ + config: config, + db: database, + }, nil +} diff --git a/pkg/model_registry/store/store.go b/pkg/model_registry/store/store.go index f14ba8c..d9a5c21 100644 --- a/pkg/model_registry/store/store.go +++ b/pkg/model_registry/store/store.go @@ -1,12 +1,12 @@ -package store - -import ( - "context" - - "github.com/mlflow/mlflow-go/pkg/contract" - "github.com/mlflow/mlflow-go/pkg/protos" -) - -type ModelRegistryStore interface { - GetLatestVersions(ctx context.Context, name string, stages []string) ([]*protos.ModelVersion, *contract.Error) -} +package store + +import ( + "context" + + "github.com/mlflow/mlflow-go/pkg/contract" + "github.com/mlflow/mlflow-go/pkg/protos" +) + +type ModelRegistryStore interface { + GetLatestVersions(ctx context.Context, name string, stages []string) ([]*protos.ModelVersion, *contract.Error) +} diff --git a/pkg/server/command/command.go b/pkg/server/command/command.go index 515264a..ac218a1 100644 --- a/pkg/server/command/command.go +++ b/pkg/server/command/command.go @@ -1,42 +1,42 @@ -package command - -import ( - "context" - "fmt" - "os" - "os/exec" - "time" - - "github.com/mlflow/mlflow-go/pkg/config" - "github.com/mlflow/mlflow-go/pkg/utils" -) - -func LaunchCommand(ctx context.Context, cfg *config.Config) error { - logger := utils.GetLoggerFromContext(ctx) - - //nolint:gosec - cmd, err := newProcessGroupCommand( - ctx, - exec.CommandContext(ctx, cfg.PythonCommand[0], cfg.PythonCommand[1:]...), - ) - if err != nil { - return fmt.Errorf("failed to create process group command: %w", err) - } - - cmd.Env = append(os.Environ(), cfg.PythonEnv...) 
- cmd.Stdout = logger.Writer() - cmd.Stderr = logger.Writer() - cmd.WaitDelay = 5 * time.Second //nolint:mnd - - logger.Debugf("Launching command: %v", cmd) - - if err := cmd.Start(); err != nil { - return fmt.Errorf("failed to launch command: %w", err) - } - - if err := cmd.Wait(); err != nil { - return fmt.Errorf("command exited with error: %w", err) - } - - return nil -} +package command + +import ( + "context" + "fmt" + "os" + "os/exec" + "time" + + "github.com/mlflow/mlflow-go/pkg/config" + "github.com/mlflow/mlflow-go/pkg/utils" +) + +func LaunchCommand(ctx context.Context, cfg *config.Config) error { + logger := utils.GetLoggerFromContext(ctx) + + //nolint:gosec + cmd, err := newProcessGroupCommand( + ctx, + exec.CommandContext(ctx, cfg.PythonCommand[0], cfg.PythonCommand[1:]...), + ) + if err != nil { + return fmt.Errorf("failed to create process group command: %w", err) + } + + cmd.Env = append(os.Environ(), cfg.PythonEnv...) + cmd.Stdout = logger.Writer() + cmd.Stderr = logger.Writer() + cmd.WaitDelay = 5 * time.Second //nolint:mnd + + logger.Debugf("Launching command: %v", cmd) + + if err := cmd.Start(); err != nil { + return fmt.Errorf("failed to launch command: %w", err) + } + + if err := cmd.Wait(); err != nil { + return fmt.Errorf("command exited with error: %w", err) + } + + return nil +} diff --git a/pkg/server/command/command_posix.go b/pkg/server/command/command_posix.go index 7982a02..3e71750 100644 --- a/pkg/server/command/command_posix.go +++ b/pkg/server/command/command_posix.go @@ -1,30 +1,30 @@ -//go:build !windows - -package command - -import ( - "context" - "os/exec" - "syscall" - - "github.com/mlflow/mlflow-go/pkg/utils" -) - -func newProcessGroupCommand(ctx context.Context, cmd *exec.Cmd) (*exec.Cmd, error) { - logger := utils.GetLoggerFromContext(ctx) - - // Create the process in a new process group - cmd.SysProcAttr = &syscall.SysProcAttr{ - Setpgid: true, - Pgid: 0, - } - - // Terminate the process group - cmd.Cancel = func() error { 
- logger.Debug("Sending interrupt signal to command process group") - - return syscall.Kill(-cmd.Process.Pid, syscall.SIGINT) - } - - return cmd, nil -} +//go:build !windows + +package command + +import ( + "context" + "os/exec" + "syscall" + + "github.com/mlflow/mlflow-go/pkg/utils" +) + +func newProcessGroupCommand(ctx context.Context, cmd *exec.Cmd) (*exec.Cmd, error) { + logger := utils.GetLoggerFromContext(ctx) + + // Create the process in a new process group + cmd.SysProcAttr = &syscall.SysProcAttr{ + Setpgid: true, + Pgid: 0, + } + + // Terminate the process group + cmd.Cancel = func() error { + logger.Debug("Sending interrupt signal to command process group") + + return syscall.Kill(-cmd.Process.Pid, syscall.SIGINT) + } + + return cmd, nil +} diff --git a/pkg/server/command/command_windows.go b/pkg/server/command/command_windows.go index c94c20e..d1401e3 100644 --- a/pkg/server/command/command_windows.go +++ b/pkg/server/command/command_windows.go @@ -1,77 +1,77 @@ -package command - -import ( - "context" - "fmt" - "os/exec" - "syscall" - "unsafe" - - "golang.org/x/sys/windows" - - "github.com/mlflow/mlflow-go/pkg/utils" -) - -type processGroupCmd struct { - *exec.Cmd - job windows.Handle -} - -const PROCESS_ALL_ACCESS = 2097151 - -func newProcessGroupCommand(ctx context.Context, cmd *exec.Cmd) (*processGroupCmd, error) { - logger := utils.GetLoggerFromContext(ctx) - - // Get the job object handle - jobHandle, err := windows.CreateJobObject(nil, nil) - if err != nil { - return nil, fmt.Errorf("failed to create job object: %w", err) - } - - // Set the job object to kill processes when the job is closed - info := windows.JOBOBJECT_EXTENDED_LIMIT_INFORMATION{ - BasicLimitInformation: windows.JOBOBJECT_BASIC_LIMIT_INFORMATION{ - LimitFlags: windows.JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE, - }, - } - if _, err = windows.SetInformationJobObject( - jobHandle, - windows.JobObjectExtendedLimitInformation, - uintptr(unsafe.Pointer(&info)), - uint32(unsafe.Sizeof(info))); 
err != nil { - return nil, fmt.Errorf("failed to set job object information: %w", err) - } - - // Terminate the job object (which will terminate all processes in the job) - cmd.Cancel = func() error { - logger.Debug("Closing job object to terminate command process group") - - return windows.CloseHandle(jobHandle) - } - - // Create the process in a new process group - cmd.SysProcAttr = &syscall.SysProcAttr{CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP} - - return &processGroupCmd{Cmd: cmd, job: jobHandle}, nil -} - -func (pgc *processGroupCmd) Start() error { - // Start the command - if err := pgc.Cmd.Start(); err != nil { - return fmt.Errorf("failed to start command: %w", err) - } - - // Get the process handle - hProc, err := windows.OpenProcess(PROCESS_ALL_ACCESS, true, uint32(pgc.Process.Pid)) - if err != nil { - return fmt.Errorf("failed to open process: %w", err) - } - defer windows.CloseHandle(hProc) - - // Assign the process to the job object - if err := windows.AssignProcessToJobObject(pgc.job, hProc); err != nil { - return fmt.Errorf("failed to assign process to job object: %w", err) - } - - return nil -} +package command + +import ( + "context" + "fmt" + "os/exec" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" + + "github.com/mlflow/mlflow-go/pkg/utils" +) + +type processGroupCmd struct { + *exec.Cmd + job windows.Handle +} + +const PROCESS_ALL_ACCESS = 2097151 + +func newProcessGroupCommand(ctx context.Context, cmd *exec.Cmd) (*processGroupCmd, error) { + logger := utils.GetLoggerFromContext(ctx) + + // Get the job object handle + jobHandle, err := windows.CreateJobObject(nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to create job object: %w", err) + } + + // Set the job object to kill processes when the job is closed + info := windows.JOBOBJECT_EXTENDED_LIMIT_INFORMATION{ + BasicLimitInformation: windows.JOBOBJECT_BASIC_LIMIT_INFORMATION{ + LimitFlags: windows.JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE, + }, + } + if _, err = 
windows.SetInformationJobObject( + jobHandle, + windows.JobObjectExtendedLimitInformation, + uintptr(unsafe.Pointer(&info)), + uint32(unsafe.Sizeof(info))); err != nil { + return nil, fmt.Errorf("failed to set job object information: %w", err) + } + + // Terminate the job object (which will terminate all processes in the job) + cmd.Cancel = func() error { + logger.Debug("Closing job object to terminate command process group") + + return windows.CloseHandle(jobHandle) + } + + // Create the process in a new process group + cmd.SysProcAttr = &syscall.SysProcAttr{CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP} + + return &processGroupCmd{Cmd: cmd, job: jobHandle}, nil +} + +func (pgc *processGroupCmd) Start() error { + // Start the command + if err := pgc.Cmd.Start(); err != nil { + return fmt.Errorf("failed to start command: %w", err) + } + + // Get the process handle + hProc, err := windows.OpenProcess(PROCESS_ALL_ACCESS, true, uint32(pgc.Process.Pid)) + if err != nil { + return fmt.Errorf("failed to open process: %w", err) + } + defer windows.CloseHandle(hProc) + + // Assign the process to the job object + if err := windows.AssignProcessToJobObject(pgc.job, hProc); err != nil { + return fmt.Errorf("failed to assign process to job object: %w", err) + } + + return nil +} diff --git a/pkg/server/launch.go b/pkg/server/launch.go index 5f70369..8c50f32 100644 --- a/pkg/server/launch.go +++ b/pkg/server/launch.go @@ -1,86 +1,86 @@ -package server - -import ( - "context" - "errors" - "os" - "os/signal" - "sync" - "syscall" - - "github.com/mlflow/mlflow-go/pkg/config" - "github.com/mlflow/mlflow-go/pkg/server/command" - "github.com/mlflow/mlflow-go/pkg/utils" -) - -func Launch(ctx context.Context, cfg *config.Config) error { - if len(cfg.PythonCommand) > 0 { - return launchCommandAndServer(ctx, cfg) - } - - return launchServer(ctx, cfg) -} - -func launchCommandAndServer(ctx context.Context, cfg *config.Config) error { - var errs []error - - logger := 
utils.GetLoggerFromContext(ctx) - - cmdCtx, cmdCancel := context.WithCancel(ctx) - srvCtx, srvCancel := context.WithCancel(ctx) - - waitGroup := sync.WaitGroup{} - waitGroup.Add(1) - - go func() { - defer waitGroup.Done() - - if err := command.LaunchCommand(cmdCtx, cfg); err != nil && cmdCtx.Err() == nil { - errs = append(errs, err) - } - - logger.Debug("Python server has exited") - - srvCancel() - }() - - waitGroup.Add(1) - - go func() { - defer waitGroup.Done() - - if err := launchServer(srvCtx, cfg); err != nil && srvCtx.Err() == nil { - errs = append(errs, err) - } - - logger.Debug("Go server has exited") - - cmdCancel() - }() - - waitGroup.Wait() - - return errors.Join(errs...) -} - -func LaunchWithSignalHandler(cfg *config.Config) error { - logger := utils.NewLoggerFromConfig(cfg) - - logger.Debugf("Loaded config: %#v", cfg) - - sigint := make(chan os.Signal, 1) - signal.Notify(sigint, os.Interrupt, syscall.SIGTERM) - defer signal.Stop(sigint) - - ctx, cancel := context.WithCancel( - utils.NewContextWithLogger(context.Background(), logger)) - - go func() { - sig := <-sigint - logger.Debugf("Received signal: %v", sig) - - cancel() - }() - - return Launch(ctx, cfg) -} +package server + +import ( + "context" + "errors" + "os" + "os/signal" + "sync" + "syscall" + + "github.com/mlflow/mlflow-go/pkg/config" + "github.com/mlflow/mlflow-go/pkg/server/command" + "github.com/mlflow/mlflow-go/pkg/utils" +) + +func Launch(ctx context.Context, cfg *config.Config) error { + if len(cfg.PythonCommand) > 0 { + return launchCommandAndServer(ctx, cfg) + } + + return launchServer(ctx, cfg) +} + +func launchCommandAndServer(ctx context.Context, cfg *config.Config) error { + var errs []error + + logger := utils.GetLoggerFromContext(ctx) + + cmdCtx, cmdCancel := context.WithCancel(ctx) + srvCtx, srvCancel := context.WithCancel(ctx) + + waitGroup := sync.WaitGroup{} + waitGroup.Add(1) + + go func() { + defer waitGroup.Done() + + if err := command.LaunchCommand(cmdCtx, cfg); err != 
nil && cmdCtx.Err() == nil { + errs = append(errs, err) + } + + logger.Debug("Python server has exited") + + srvCancel() + }() + + waitGroup.Add(1) + + go func() { + defer waitGroup.Done() + + if err := launchServer(srvCtx, cfg); err != nil && srvCtx.Err() == nil { + errs = append(errs, err) + } + + logger.Debug("Go server has exited") + + cmdCancel() + }() + + waitGroup.Wait() + + return errors.Join(errs...) +} + +func LaunchWithSignalHandler(cfg *config.Config) error { + logger := utils.NewLoggerFromConfig(cfg) + + logger.Debugf("Loaded config: %#v", cfg) + + sigint := make(chan os.Signal, 1) + signal.Notify(sigint, os.Interrupt, syscall.SIGTERM) + defer signal.Stop(sigint) + + ctx, cancel := context.WithCancel( + utils.NewContextWithLogger(context.Background(), logger)) + + go func() { + sig := <-sigint + logger.Debugf("Received signal: %v", sig) + + cancel() + }() + + return Launch(ctx, cfg) +} diff --git a/pkg/server/parser/http_request_parser.go b/pkg/server/parser/http_request_parser.go index 5e17242..9b9db53 100644 --- a/pkg/server/parser/http_request_parser.go +++ b/pkg/server/parser/http_request_parser.go @@ -1,72 +1,72 @@ -package parser - -import ( - "encoding/json" - "errors" - "fmt" - - "github.com/go-playground/validator/v10" - "github.com/gofiber/fiber/v2" - "github.com/tidwall/gjson" - "google.golang.org/protobuf/encoding/protojson" - "google.golang.org/protobuf/proto" - - "github.com/mlflow/mlflow-go/pkg/contract" - "github.com/mlflow/mlflow-go/pkg/protos" - "github.com/mlflow/mlflow-go/pkg/validation" -) - -type HTTPRequestParser struct { - validator *validator.Validate -} - -func NewHTTPRequestParser() (*HTTPRequestParser, error) { - validator, err := validation.NewValidator() - if err != nil { - return nil, fmt.Errorf("failed to create validator: %w", err) - } - - return &HTTPRequestParser{ - validator: validator, - }, nil -} - -func (p *HTTPRequestParser) ParseBody(ctx *fiber.Ctx, input proto.Message) *contract.Error { - if protojsonErr := 
protojson.Unmarshal(ctx.Body(), input); protojsonErr != nil { - // falling back to JSON, because `protojson` doesn't provide any information - // about `field` name for which ut fails. MLFlow tests expect to know the exact - // `field` name where validation failed. This approach has no effect on MLFlow - // tests, so let's keep it for now. - if jsonErr := json.Unmarshal(ctx.Body(), input); jsonErr != nil { - var unmarshalTypeError *json.UnmarshalTypeError - if errors.As(jsonErr, &unmarshalTypeError) { - result := gjson.GetBytes(ctx.Body(), unmarshalTypeError.Field) - - return contract.NewError( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - fmt.Sprintf("Invalid value %s for parameter '%s' supplied", result.Raw, unmarshalTypeError.Field), - ) - } - } - - return contract.NewError(protos.ErrorCode_BAD_REQUEST, protojsonErr.Error()) - } - - if err := p.validator.Struct(input); err != nil { - return validation.NewErrorFromValidationError(err) - } - - return nil -} - -func (p *HTTPRequestParser) ParseQuery(ctx *fiber.Ctx, input interface{}) *contract.Error { - if err := ctx.QueryParser(input); err != nil { - return contract.NewError(protos.ErrorCode_BAD_REQUEST, err.Error()) - } - - if err := p.validator.Struct(input); err != nil { - return validation.NewErrorFromValidationError(err) - } - - return nil -} +package parser + +import ( + "encoding/json" + "errors" + "fmt" + + "github.com/go-playground/validator/v10" + "github.com/gofiber/fiber/v2" + "github.com/tidwall/gjson" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" + + "github.com/mlflow/mlflow-go/pkg/contract" + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/validation" +) + +type HTTPRequestParser struct { + validator *validator.Validate +} + +func NewHTTPRequestParser() (*HTTPRequestParser, error) { + validator, err := validation.NewValidator() + if err != nil { + return nil, fmt.Errorf("failed to create validator: %w", err) + } + + return 
&HTTPRequestParser{ + validator: validator, + }, nil +} + +func (p *HTTPRequestParser) ParseBody(ctx *fiber.Ctx, input proto.Message) *contract.Error { + if protojsonErr := protojson.Unmarshal(ctx.Body(), input); protojsonErr != nil { + // falling back to JSON, because `protojson` doesn't provide any information + // about `field` name for which ut fails. MLFlow tests expect to know the exact + // `field` name where validation failed. This approach has no effect on MLFlow + // tests, so let's keep it for now. + if jsonErr := json.Unmarshal(ctx.Body(), input); jsonErr != nil { + var unmarshalTypeError *json.UnmarshalTypeError + if errors.As(jsonErr, &unmarshalTypeError) { + result := gjson.GetBytes(ctx.Body(), unmarshalTypeError.Field) + + return contract.NewError( + protos.ErrorCode_INVALID_PARAMETER_VALUE, + fmt.Sprintf("Invalid value %s for parameter '%s' supplied", result.Raw, unmarshalTypeError.Field), + ) + } + } + + return contract.NewError(protos.ErrorCode_BAD_REQUEST, protojsonErr.Error()) + } + + if err := p.validator.Struct(input); err != nil { + return validation.NewErrorFromValidationError(err) + } + + return nil +} + +func (p *HTTPRequestParser) ParseQuery(ctx *fiber.Ctx, input interface{}) *contract.Error { + if err := ctx.QueryParser(input); err != nil { + return contract.NewError(protos.ErrorCode_BAD_REQUEST, err.Error()) + } + + if err := p.validator.Struct(input); err != nil { + return validation.NewErrorFromValidationError(err) + } + + return nil +} diff --git a/pkg/server/server.go b/pkg/server/server.go index 3260613..59022a1 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -1,223 +1,223 @@ -package server - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "net" - "path/filepath" - "time" - - "github.com/gofiber/fiber/v2" - "github.com/gofiber/fiber/v2/middleware/compress" - "github.com/gofiber/fiber/v2/middleware/logger" - "github.com/gofiber/fiber/v2/middleware/proxy" - 
"github.com/gofiber/fiber/v2/middleware/recover" - "google.golang.org/protobuf/encoding/protojson" - "google.golang.org/protobuf/proto" - - as "github.com/mlflow/mlflow-go/pkg/artifacts/service" - mr "github.com/mlflow/mlflow-go/pkg/model_registry/service" - ts "github.com/mlflow/mlflow-go/pkg/tracking/service" - - "github.com/mlflow/mlflow-go/pkg/config" - "github.com/mlflow/mlflow-go/pkg/contract" - "github.com/mlflow/mlflow-go/pkg/protos" - "github.com/mlflow/mlflow-go/pkg/server/parser" - "github.com/mlflow/mlflow-go/pkg/server/routes" - "github.com/mlflow/mlflow-go/pkg/utils" -) - -//nolint:funlen -func configureApp(ctx context.Context, cfg *config.Config) (*fiber.App, error) { - //nolint:mnd - app := fiber.New(fiber.Config{ - BodyLimit: 16 * 1024 * 1024, - ReadBufferSize: 16384, - ReadTimeout: 5 * time.Second, - WriteTimeout: 600 * time.Second, - IdleTimeout: 120 * time.Second, - ServerHeader: "mlflow/" + cfg.Version, - JSONEncoder: func(v interface{}) ([]byte, error) { - if protoMessage, ok := v.(proto.Message); ok { - return protojson.Marshal(protoMessage) - } - - return json.Marshal(v) - }, - JSONDecoder: func(data []byte, v interface{}) error { - if protoMessage, ok := v.(proto.Message); ok { - return protojson.Unmarshal(data, protoMessage) - } - - return json.Unmarshal(data, v) - }, - DisableStartupMessage: true, - }) - - app.Use(compress.New()) - app.Use(recover.New(recover.Config{EnableStackTrace: true})) - app.Use(logger.New(logger.Config{ - Format: "${status} - ${latency} ${method} ${path}\n", - Output: utils.GetLoggerFromContext(ctx).Writer(), - })) - app.Use(func(c *fiber.Ctx) error { - c.SetUserContext(ctx) - - return c.Next() - }) - - apiApp, err := newAPIApp(ctx, cfg) - if err != nil { - return nil, err - } - - app.Mount("/api/2.0", apiApp) - app.Mount("/ajax-api/2.0", apiApp) - - if cfg.StaticFolder != "" { - app.Static("/static-files", cfg.StaticFolder) - app.Get("/", func(c *fiber.Ctx) error { - return 
c.SendFile(filepath.Join(cfg.StaticFolder, "index.html")) - }) - } - - app.Get("/health", func(c *fiber.Ctx) error { - return c.SendString("OK") - }) - app.Get("/version", func(c *fiber.Ctx) error { - return c.SendString(cfg.Version) - }) - - if cfg.PythonAddress != "" { - app.Use(proxy.BalancerForward([]string{cfg.PythonAddress})) - } - - return app, nil -} - -func launchServer(ctx context.Context, cfg *config.Config) error { - logger := utils.GetLoggerFromContext(ctx) - - app, err := configureApp(ctx, cfg) - if err != nil { - return err - } - - go func() { - <-ctx.Done() - - logger.Info("Shutting down MLflow Go server") - - if err := app.ShutdownWithTimeout(cfg.ShutdownTimeout.Duration); err != nil { - logger.Errorf("Failed to gracefully shutdown MLflow Go server: %v", err) - } - }() - - if cfg.PythonAddress != "" { - logger.Debugf("Waiting for Python server to be ready on http://%s", cfg.PythonAddress) - - for { - dialer := &net.Dialer{} - conn, err := dialer.DialContext(ctx, "tcp", cfg.PythonAddress) - - if err == nil { - conn.Close() - - break - } - - if errors.Is(err, context.Canceled) { - return fmt.Errorf("failed to connect to Python server: %w", err) - } - - time.Sleep(50 * time.Millisecond) //nolint:mnd - } - logger.Debugf("Python server is ready on http://%s", cfg.PythonAddress) - } - - logger.Infof("Launching MLflow Go server on http://%s", cfg.Address) - - err = app.Listen(cfg.Address) - if err != nil { - return fmt.Errorf("failed to start MLflow Go server: %w", err) - } - - return nil -} - -func newFiberConfig() fiber.Config { - return fiber.Config{ - ErrorHandler: func(context *fiber.Ctx, err error) error { - var contractError *contract.Error - if !errors.As(err, &contractError) { - code := protos.ErrorCode_INTERNAL_ERROR - - var f *fiber.Error - if errors.As(err, &f) { - switch f.Code { - case fiber.StatusBadRequest: - code = protos.ErrorCode_BAD_REQUEST - case fiber.StatusServiceUnavailable: - code = protos.ErrorCode_SERVICE_UNDER_MAINTENANCE - 
case fiber.StatusNotFound: - code = protos.ErrorCode_ENDPOINT_NOT_FOUND - } - } - - contractError = contract.NewError(code, err.Error()) - } - - var logFn func(format string, args ...any) - - logger := utils.GetLoggerFromContext(context.Context()) - switch contractError.StatusCode() { - case fiber.StatusBadRequest: - logFn = logger.Infof - case fiber.StatusServiceUnavailable: - logFn = logger.Warnf - case fiber.StatusNotFound: - logFn = logger.Debugf - default: - logFn = logger.Errorf - } - - logFn("Error encountered in %s %s: %s", context.Method(), context.Path(), err) - - return context.Status(contractError.StatusCode()).JSON(contractError) - }, - } -} - -func newAPIApp(ctx context.Context, cfg *config.Config) (*fiber.App, error) { - app := fiber.New(newFiberConfig()) - - parser, err := parser.NewHTTPRequestParser() - if err != nil { - return nil, fmt.Errorf("failed to create new HTTP request parser: %w", err) - } - - trackingService, err := ts.NewTrackingService(ctx, cfg) - if err != nil { - return nil, fmt.Errorf("failed to create new tracking service: %w", err) - } - - routes.RegisterTrackingServiceRoutes(trackingService, parser, app) - - modelRegistryService, err := mr.NewModelRegistryService(ctx, cfg) - if err != nil { - return nil, fmt.Errorf("failed to create new model registry service: %w", err) - } - - routes.RegisterModelRegistryServiceRoutes(modelRegistryService, parser, app) - - artifactService, err := as.NewArtifactsService(ctx, cfg) - if err != nil { - return nil, fmt.Errorf("failed to create new artifacts service: %w", err) - } - - routes.RegisterArtifactsServiceRoutes(artifactService, parser, app) - - return app, nil -} +package server + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net" + "path/filepath" + "time" + + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/middleware/compress" + "github.com/gofiber/fiber/v2/middleware/logger" + "github.com/gofiber/fiber/v2/middleware/proxy" + 
"github.com/gofiber/fiber/v2/middleware/recover" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" + + as "github.com/mlflow/mlflow-go/pkg/artifacts/service" + mr "github.com/mlflow/mlflow-go/pkg/model_registry/service" + ts "github.com/mlflow/mlflow-go/pkg/tracking/service" + + "github.com/mlflow/mlflow-go/pkg/config" + "github.com/mlflow/mlflow-go/pkg/contract" + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/server/parser" + "github.com/mlflow/mlflow-go/pkg/server/routes" + "github.com/mlflow/mlflow-go/pkg/utils" +) + +//nolint:funlen +func configureApp(ctx context.Context, cfg *config.Config) (*fiber.App, error) { + //nolint:mnd + app := fiber.New(fiber.Config{ + BodyLimit: 16 * 1024 * 1024, + ReadBufferSize: 16384, + ReadTimeout: 5 * time.Second, + WriteTimeout: 600 * time.Second, + IdleTimeout: 120 * time.Second, + ServerHeader: "mlflow/" + cfg.Version, + JSONEncoder: func(v interface{}) ([]byte, error) { + if protoMessage, ok := v.(proto.Message); ok { + return protojson.Marshal(protoMessage) + } + + return json.Marshal(v) + }, + JSONDecoder: func(data []byte, v interface{}) error { + if protoMessage, ok := v.(proto.Message); ok { + return protojson.Unmarshal(data, protoMessage) + } + + return json.Unmarshal(data, v) + }, + DisableStartupMessage: true, + }) + + app.Use(compress.New()) + app.Use(recover.New(recover.Config{EnableStackTrace: true})) + app.Use(logger.New(logger.Config{ + Format: "${status} - ${latency} ${method} ${path}\n", + Output: utils.GetLoggerFromContext(ctx).Writer(), + })) + app.Use(func(c *fiber.Ctx) error { + c.SetUserContext(ctx) + + return c.Next() + }) + + apiApp, err := newAPIApp(ctx, cfg) + if err != nil { + return nil, err + } + + app.Mount("/api/2.0", apiApp) + app.Mount("/ajax-api/2.0", apiApp) + + if cfg.StaticFolder != "" { + app.Static("/static-files", cfg.StaticFolder) + app.Get("/", func(c *fiber.Ctx) error { + return 
c.SendFile(filepath.Join(cfg.StaticFolder, "index.html")) + }) + } + + app.Get("/health", func(c *fiber.Ctx) error { + return c.SendString("OK") + }) + app.Get("/version", func(c *fiber.Ctx) error { + return c.SendString(cfg.Version) + }) + + if cfg.PythonAddress != "" { + app.Use(proxy.BalancerForward([]string{cfg.PythonAddress})) + } + + return app, nil +} + +func launchServer(ctx context.Context, cfg *config.Config) error { + logger := utils.GetLoggerFromContext(ctx) + + app, err := configureApp(ctx, cfg) + if err != nil { + return err + } + + go func() { + <-ctx.Done() + + logger.Info("Shutting down MLflow Go server") + + if err := app.ShutdownWithTimeout(cfg.ShutdownTimeout.Duration); err != nil { + logger.Errorf("Failed to gracefully shutdown MLflow Go server: %v", err) + } + }() + + if cfg.PythonAddress != "" { + logger.Debugf("Waiting for Python server to be ready on http://%s", cfg.PythonAddress) + + for { + dialer := &net.Dialer{} + conn, err := dialer.DialContext(ctx, "tcp", cfg.PythonAddress) + + if err == nil { + conn.Close() + + break + } + + if errors.Is(err, context.Canceled) { + return fmt.Errorf("failed to connect to Python server: %w", err) + } + + time.Sleep(50 * time.Millisecond) //nolint:mnd + } + logger.Debugf("Python server is ready on http://%s", cfg.PythonAddress) + } + + logger.Infof("Launching MLflow Go server on http://%s", cfg.Address) + + err = app.Listen(cfg.Address) + if err != nil { + return fmt.Errorf("failed to start MLflow Go server: %w", err) + } + + return nil +} + +func newFiberConfig() fiber.Config { + return fiber.Config{ + ErrorHandler: func(context *fiber.Ctx, err error) error { + var contractError *contract.Error + if !errors.As(err, &contractError) { + code := protos.ErrorCode_INTERNAL_ERROR + + var f *fiber.Error + if errors.As(err, &f) { + switch f.Code { + case fiber.StatusBadRequest: + code = protos.ErrorCode_BAD_REQUEST + case fiber.StatusServiceUnavailable: + code = protos.ErrorCode_SERVICE_UNDER_MAINTENANCE + 
case fiber.StatusNotFound: + code = protos.ErrorCode_ENDPOINT_NOT_FOUND + } + } + + contractError = contract.NewError(code, err.Error()) + } + + var logFn func(format string, args ...any) + + logger := utils.GetLoggerFromContext(context.Context()) + switch contractError.StatusCode() { + case fiber.StatusBadRequest: + logFn = logger.Infof + case fiber.StatusServiceUnavailable: + logFn = logger.Warnf + case fiber.StatusNotFound: + logFn = logger.Debugf + default: + logFn = logger.Errorf + } + + logFn("Error encountered in %s %s: %s", context.Method(), context.Path(), err) + + return context.Status(contractError.StatusCode()).JSON(contractError) + }, + } +} + +func newAPIApp(ctx context.Context, cfg *config.Config) (*fiber.App, error) { + app := fiber.New(newFiberConfig()) + + parser, err := parser.NewHTTPRequestParser() + if err != nil { + return nil, fmt.Errorf("failed to create new HTTP request parser: %w", err) + } + + trackingService, err := ts.NewTrackingService(ctx, cfg) + if err != nil { + return nil, fmt.Errorf("failed to create new tracking service: %w", err) + } + + routes.RegisterTrackingServiceRoutes(trackingService, parser, app) + + modelRegistryService, err := mr.NewModelRegistryService(ctx, cfg) + if err != nil { + return nil, fmt.Errorf("failed to create new model registry service: %w", err) + } + + routes.RegisterModelRegistryServiceRoutes(modelRegistryService, parser, app) + + artifactService, err := as.NewArtifactsService(ctx, cfg) + if err != nil { + return nil, fmt.Errorf("failed to create new artifacts service: %w", err) + } + + routes.RegisterArtifactsServiceRoutes(artifactService, parser, app) + + return app, nil +} diff --git a/pkg/sql/logger.go b/pkg/sql/logger.go index 3939229..2699168 100644 --- a/pkg/sql/logger.go +++ b/pkg/sql/logger.go @@ -1,139 +1,139 @@ -//nolint:goprintffuncname -package sql - -import ( - "context" - "errors" - "fmt" - "runtime" - "strings" - "time" - - "github.com/sirupsen/logrus" - "gorm.io/gorm" - 
"gorm.io/gorm/logger" -) - -type loggerAdaptor struct { - Logger *logrus.Logger - Config LoggerAdaptorConfig -} - -type LoggerAdaptorConfig struct { - SlowThreshold time.Duration - IgnoreRecordNotFoundError bool - ParameterizedQueries bool -} - -// NewLoggerAdaptor creates a new logger adaptor. -// -//nolint:ireturn -func NewLoggerAdaptor(l *logrus.Logger, cfg LoggerAdaptorConfig) logger.Interface { - return &loggerAdaptor{l, cfg} -} - -// LogMode implements the gorm.io/gorm/logger.Interface interface and is a no-op. -// -//nolint:ireturn -func (l *loggerAdaptor) LogMode(_ logger.LogLevel) logger.Interface { - return l -} - -const ( - maximumCallerDepth int = 15 - minimumCallerDepth int = 4 -) - -// getLoggerEntry gets a logger entry with context and caller information added. -func (l *loggerAdaptor) getLoggerEntry(ctx context.Context) *logrus.Entry { - entry := l.Logger.WithContext(ctx) - // We want to report the caller of the function that called gorm's logger, - // not the caller of the loggerAdaptor, so we skip the first few frames and - // then look for the first frame that is not in the gorm package. - pcs := make([]uintptr, maximumCallerDepth) - depth := runtime.Callers(minimumCallerDepth, pcs) - frames := runtime.CallersFrames(pcs[:depth]) - - for f, again := frames.Next(); again; f, again = frames.Next() { - if !strings.HasPrefix(f.Function, "gorm.io/gorm") { - entry = entry.WithFields(logrus.Fields{ - "app_file": fmt.Sprintf("%s:%d", f.File, f.Line), - "app_func": f.Function + "()", - }) - - break - } - } - - return entry -} - -// Info logs message at info level and implements the gorm.io/gorm/logger.Interface interface. -func (l *loggerAdaptor) Info(ctx context.Context, format string, args ...interface{}) { - l.getLoggerEntry(ctx).Infof(format, args...) -} - -// Warn logs message at warn level and implements the gorm.io/gorm/logger.Interface interface. 
-func (l *loggerAdaptor) Warn(ctx context.Context, format string, args ...interface{}) { - l.getLoggerEntry(ctx).Warnf(format, args...) -} - -// Error logs message at error level and implements the gorm.io/gorm/logger.Interface interface. -func (l *loggerAdaptor) Error(ctx context.Context, format string, args ...interface{}) { - l.getLoggerEntry(ctx).Errorf(format, args...) -} - -const NanosecondsPerMillisecond = 1e6 - -// getLoggerEntryWithSQL gets a logger entry with context, caller information and SQL information added. -func (l *loggerAdaptor) getLoggerEntryWithSQL( - ctx context.Context, - elapsed time.Duration, - fc func() (sql string, rowsAffected int64), -) *logrus.Entry { - entry := l.getLoggerEntry(ctx) - - if fc != nil { - sql, rows := fc() - entry = entry.WithFields(logrus.Fields{ - "elapsed": fmt.Sprintf("%.3fms", float64(elapsed.Nanoseconds())/NanosecondsPerMillisecond), - "rows": rows, - "sql": sql, - }) - - if rows == -1 { - entry = entry.WithField("rows", "-") - } - } - - return entry -} - -// Trace logs SQL statement, amount of affected rows, and elapsed time. -// It implements the gorm.io/gorm/logger.Interface interface. -func (l *loggerAdaptor) Trace( - ctx context.Context, - begin time.Time, - function func() (sql string, rowsAffected int64), - err error, -) { - if l.Logger.GetLevel() <= logrus.FatalLevel { - return - } - - // This logic is similar to the default logger in gorm.io/gorm/logger. 
- elapsed := time.Since(begin) - - switch { - case err != nil && - l.Logger.IsLevelEnabled(logrus.ErrorLevel) && - (!errors.Is(err, gorm.ErrRecordNotFound) || !l.Config.IgnoreRecordNotFoundError): - l.getLoggerEntryWithSQL(ctx, elapsed, function).WithError(err).Error("SQL error") - case elapsed > l.Config.SlowThreshold && - l.Config.SlowThreshold != 0 && - l.Logger.IsLevelEnabled(logrus.WarnLevel): - l.getLoggerEntryWithSQL(ctx, elapsed, function).Warnf("SLOW SQL >= %v", l.Config.SlowThreshold) - case l.Logger.IsLevelEnabled(logrus.DebugLevel): - l.getLoggerEntryWithSQL(ctx, elapsed, function).Debug("SQL trace") - } -} +//nolint:goprintffuncname +package sql + +import ( + "context" + "errors" + "fmt" + "runtime" + "strings" + "time" + + "github.com/sirupsen/logrus" + "gorm.io/gorm" + "gorm.io/gorm/logger" +) + +type loggerAdaptor struct { + Logger *logrus.Logger + Config LoggerAdaptorConfig +} + +type LoggerAdaptorConfig struct { + SlowThreshold time.Duration + IgnoreRecordNotFoundError bool + ParameterizedQueries bool +} + +// NewLoggerAdaptor creates a new logger adaptor. +// +//nolint:ireturn +func NewLoggerAdaptor(l *logrus.Logger, cfg LoggerAdaptorConfig) logger.Interface { + return &loggerAdaptor{l, cfg} +} + +// LogMode implements the gorm.io/gorm/logger.Interface interface and is a no-op. +// +//nolint:ireturn +func (l *loggerAdaptor) LogMode(_ logger.LogLevel) logger.Interface { + return l +} + +const ( + maximumCallerDepth int = 15 + minimumCallerDepth int = 4 +) + +// getLoggerEntry gets a logger entry with context and caller information added. +func (l *loggerAdaptor) getLoggerEntry(ctx context.Context) *logrus.Entry { + entry := l.Logger.WithContext(ctx) + // We want to report the caller of the function that called gorm's logger, + // not the caller of the loggerAdaptor, so we skip the first few frames and + // then look for the first frame that is not in the gorm package. 
+ pcs := make([]uintptr, maximumCallerDepth) + depth := runtime.Callers(minimumCallerDepth, pcs) + frames := runtime.CallersFrames(pcs[:depth]) + + for f, again := frames.Next(); again; f, again = frames.Next() { + if !strings.HasPrefix(f.Function, "gorm.io/gorm") { + entry = entry.WithFields(logrus.Fields{ + "app_file": fmt.Sprintf("%s:%d", f.File, f.Line), + "app_func": f.Function + "()", + }) + + break + } + } + + return entry +} + +// Info logs message at info level and implements the gorm.io/gorm/logger.Interface interface. +func (l *loggerAdaptor) Info(ctx context.Context, format string, args ...interface{}) { + l.getLoggerEntry(ctx).Infof(format, args...) +} + +// Warn logs message at warn level and implements the gorm.io/gorm/logger.Interface interface. +func (l *loggerAdaptor) Warn(ctx context.Context, format string, args ...interface{}) { + l.getLoggerEntry(ctx).Warnf(format, args...) +} + +// Error logs message at error level and implements the gorm.io/gorm/logger.Interface interface. +func (l *loggerAdaptor) Error(ctx context.Context, format string, args ...interface{}) { + l.getLoggerEntry(ctx).Errorf(format, args...) +} + +const NanosecondsPerMillisecond = 1e6 + +// getLoggerEntryWithSQL gets a logger entry with context, caller information and SQL information added. +func (l *loggerAdaptor) getLoggerEntryWithSQL( + ctx context.Context, + elapsed time.Duration, + fc func() (sql string, rowsAffected int64), +) *logrus.Entry { + entry := l.getLoggerEntry(ctx) + + if fc != nil { + sql, rows := fc() + entry = entry.WithFields(logrus.Fields{ + "elapsed": fmt.Sprintf("%.3fms", float64(elapsed.Nanoseconds())/NanosecondsPerMillisecond), + "rows": rows, + "sql": sql, + }) + + if rows == -1 { + entry = entry.WithField("rows", "-") + } + } + + return entry +} + +// Trace logs SQL statement, amount of affected rows, and elapsed time. +// It implements the gorm.io/gorm/logger.Interface interface. 
+func (l *loggerAdaptor) Trace( + ctx context.Context, + begin time.Time, + function func() (sql string, rowsAffected int64), + err error, +) { + if l.Logger.GetLevel() <= logrus.FatalLevel { + return + } + + // This logic is similar to the default logger in gorm.io/gorm/logger. + elapsed := time.Since(begin) + + switch { + case err != nil && + l.Logger.IsLevelEnabled(logrus.ErrorLevel) && + (!errors.Is(err, gorm.ErrRecordNotFound) || !l.Config.IgnoreRecordNotFoundError): + l.getLoggerEntryWithSQL(ctx, elapsed, function).WithError(err).Error("SQL error") + case elapsed > l.Config.SlowThreshold && + l.Config.SlowThreshold != 0 && + l.Logger.IsLevelEnabled(logrus.WarnLevel): + l.getLoggerEntryWithSQL(ctx, elapsed, function).Warnf("SLOW SQL >= %v", l.Config.SlowThreshold) + case l.Logger.IsLevelEnabled(logrus.DebugLevel): + l.getLoggerEntryWithSQL(ctx, elapsed, function).Debug("SQL trace") + } +} diff --git a/pkg/sql/sql.go b/pkg/sql/sql.go index bc6b8e5..a7e88d4 100644 --- a/pkg/sql/sql.go +++ b/pkg/sql/sql.go @@ -1,90 +1,90 @@ -package sql - -import ( - "context" - "errors" - "fmt" - "net/url" - "strings" - - "gorm.io/driver/mysql" - "gorm.io/driver/postgres" - "gorm.io/driver/sqlite" - "gorm.io/driver/sqlserver" - "gorm.io/gorm" - - "github.com/mlflow/mlflow-go/pkg/utils" -) - -var errSqliteMemory = errors.New("go implementation does not support :memory: for sqlite") - -//nolint:ireturn -func getDialector(uri *url.URL) (gorm.Dialector, error) { - uri.Scheme, _, _ = strings.Cut(uri.Scheme, "+") - - switch uri.Scheme { - case "mssql": - uri.Scheme = "sqlserver" - - return sqlserver.Open(uri.String()), nil - case "mysql": - return mysql.Open(fmt.Sprintf("%s@tcp(%s)%s?%s", uri.User, uri.Host, uri.Path, uri.RawQuery)), nil - case "postgres", "postgresql": - return postgres.Open(uri.String()), nil - case "sqlite": - uri.Scheme = "" - uri.Path = uri.Path[1:] - - if uri.Path == ":memory:" { - return nil, errSqliteMemory - } - - return sqlite.Open(uri.String()), nil - 
default: - return nil, fmt.Errorf("unsupported store URL scheme %q", uri.Scheme) //nolint:err113 - } -} - -func initSqlite(database *gorm.DB) error { - database.Exec("PRAGMA case_sensitive_like = true;") - - sqlDB, err := database.DB() - if err != nil { - return fmt.Errorf("failed to get database instance: %w", err) - } - // set SetMaxOpenConns to be 1 only in case of SQLite to avoid `database is locked` - // in case of parallel calls to some endpoints that use `transactions`. - sqlDB.SetMaxOpenConns(1) - - return nil -} - -func NewDatabase(ctx context.Context, storeURL string) (*gorm.DB, error) { - logger := utils.GetLoggerFromContext(ctx) - - uri, err := url.Parse(storeURL) - if err != nil { - return nil, fmt.Errorf("failed to parse store URL %q: %w", storeURL, err) - } - - dialector, err := getDialector(uri) - if err != nil { - return nil, err - } - - database, err := gorm.Open(dialector, &gorm.Config{ - TranslateError: true, - Logger: NewLoggerAdaptor(logger, LoggerAdaptorConfig{}), - }) - if err != nil { - return nil, fmt.Errorf("failed to connect to database %q: %w", uri.String(), err) - } - - if dialector.Name() == "sqlite" { - if err := initSqlite(database); err != nil { - return nil, err - } - } - - return database, nil -} +package sql + +import ( + "context" + "errors" + "fmt" + "net/url" + "strings" + + "gorm.io/driver/mysql" + "gorm.io/driver/postgres" + "gorm.io/driver/sqlite" + "gorm.io/driver/sqlserver" + "gorm.io/gorm" + + "github.com/mlflow/mlflow-go/pkg/utils" +) + +var errSqliteMemory = errors.New("go implementation does not support :memory: for sqlite") + +//nolint:ireturn +func getDialector(uri *url.URL) (gorm.Dialector, error) { + uri.Scheme, _, _ = strings.Cut(uri.Scheme, "+") + + switch uri.Scheme { + case "mssql": + uri.Scheme = "sqlserver" + + return sqlserver.Open(uri.String()), nil + case "mysql": + return mysql.Open(fmt.Sprintf("%s@tcp(%s)%s?%s", uri.User, uri.Host, uri.Path, uri.RawQuery)), nil + case "postgres", "postgresql": + return 
postgres.Open(uri.String()), nil + case "sqlite": + uri.Scheme = "" + uri.Path = uri.Path[1:] + + if uri.Path == ":memory:" { + return nil, errSqliteMemory + } + + return sqlite.Open(uri.String()), nil + default: + return nil, fmt.Errorf("unsupported store URL scheme %q", uri.Scheme) //nolint:err113 + } +} + +func initSqlite(database *gorm.DB) error { + database.Exec("PRAGMA case_sensitive_like = true;") + + sqlDB, err := database.DB() + if err != nil { + return fmt.Errorf("failed to get database instance: %w", err) + } + // set SetMaxOpenConns to be 1 only in case of SQLite to avoid `database is locked` + // in case of parallel calls to some endpoints that use `transactions`. + sqlDB.SetMaxOpenConns(1) + + return nil +} + +func NewDatabase(ctx context.Context, storeURL string) (*gorm.DB, error) { + logger := utils.GetLoggerFromContext(ctx) + + uri, err := url.Parse(storeURL) + if err != nil { + return nil, fmt.Errorf("failed to parse store URL %q: %w", storeURL, err) + } + + dialector, err := getDialector(uri) + if err != nil { + return nil, err + } + + database, err := gorm.Open(dialector, &gorm.Config{ + TranslateError: true, + Logger: NewLoggerAdaptor(logger, LoggerAdaptorConfig{}), + }) + if err != nil { + return nil, fmt.Errorf("failed to connect to database %q: %w", uri.String(), err) + } + + if dialector.Name() == "sqlite" { + if err := initSqlite(database); err != nil { + return nil, err + } + } + + return database, nil +} diff --git a/pkg/tracking/service/experiments.go b/pkg/tracking/service/experiments.go index 6e4291f..0eeefa2 100644 --- a/pkg/tracking/service/experiments.go +++ b/pkg/tracking/service/experiments.go @@ -1,134 +1,134 @@ -package service - -import ( - "context" - "fmt" - "net/url" - "path/filepath" - "runtime" - "strings" - - "github.com/mlflow/mlflow-go/pkg/contract" - "github.com/mlflow/mlflow-go/pkg/entities" - "github.com/mlflow/mlflow-go/pkg/protos" - "github.com/mlflow/mlflow-go/pkg/tracking/store/sql/models" -) - -// 
CreateExperiment implements TrackingService. -func (ts TrackingService) CreateExperiment(ctx context.Context, input *protos.CreateExperiment) ( - *protos.CreateExperiment_Response, *contract.Error, -) { - if input.GetArtifactLocation() != "" { - artifactLocation := strings.TrimRight(input.GetArtifactLocation(), "/") - - // We don't check the validation here as this was already covered in the validator. - url, _ := url.Parse(artifactLocation) - switch url.Scheme { - case "file", "": - path, err := filepath.Abs(url.Path) - if err != nil { - return nil, contract.NewError( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - fmt.Sprintf("error getting absolute path: %v", err), - ) - } - - if runtime.GOOS == "windows" { - url.Scheme = "file" - path = "/" + strings.ReplaceAll(path, "\\", "/") - } - - url.Path = path - artifactLocation = url.String() - } - - input.ArtifactLocation = &artifactLocation - } - - tags := make([]*entities.ExperimentTag, len(input.GetTags())) - for i, tag := range input.GetTags() { - tags[i] = entities.NewExperimentTagFromProto(tag) - } - - experimentID, err := ts.Store.CreateExperiment(ctx, input.GetName(), input.GetArtifactLocation(), tags) - if err != nil { - return nil, err - } - - return &protos.CreateExperiment_Response{ - ExperimentId: &experimentID, - }, nil -} - -// GetExperiment implements TrackingService. 
-func (ts TrackingService) GetExperiment( - ctx context.Context, input *protos.GetExperiment, -) (*protos.GetExperiment_Response, *contract.Error) { - experiment, err := ts.Store.GetExperiment(ctx, input.GetExperimentId()) - if err != nil { - return nil, err - } - - return &protos.GetExperiment_Response{ - Experiment: experiment.ToProto(), - }, nil -} - -func (ts TrackingService) DeleteExperiment( - ctx context.Context, input *protos.DeleteExperiment, -) (*protos.DeleteExperiment_Response, *contract.Error) { - err := ts.Store.DeleteExperiment(ctx, input.GetExperimentId()) - if err != nil { - return nil, err - } - - return &protos.DeleteExperiment_Response{}, nil -} - -func (ts TrackingService) RestoreExperiment( - ctx context.Context, input *protos.RestoreExperiment, -) (*protos.RestoreExperiment_Response, *contract.Error) { - err := ts.Store.RestoreExperiment(ctx, input.GetExperimentId()) - if err != nil { - return nil, err - } - - return &protos.RestoreExperiment_Response{}, nil -} - -func (ts TrackingService) UpdateExperiment( - ctx context.Context, input *protos.UpdateExperiment, -) (*protos.UpdateExperiment_Response, *contract.Error) { - experiment, err := ts.Store.GetExperiment(ctx, input.GetExperimentId()) - if err != nil { - return nil, err - } - - if experiment.LifecycleStage != string(models.LifecycleStageActive) { - return nil, contract.NewError( - protos.ErrorCode_INVALID_STATE, - "Cannot rename a non-active experiment.", - ) - } - - if name := input.GetNewName(); name != "" { - if err := ts.Store.RenameExperiment(ctx, input.GetExperimentId(), input.GetNewName()); err != nil { - return nil, err - } - } - - return &protos.UpdateExperiment_Response{}, nil -} - -func (ts TrackingService) GetExperimentByName( - ctx context.Context, input *protos.GetExperimentByName, -) (*protos.GetExperimentByName_Response, *contract.Error) { - experiment, err := ts.Store.GetExperimentByName(ctx, input.GetExperimentName()) - if err != nil { - return nil, err - } - - return 
&protos.GetExperimentByName_Response{ - Experiment: experiment.ToProto(), - }, nil -} +package service + +import ( + "context" + "fmt" + "net/url" + "path/filepath" + "runtime" + "strings" + + "github.com/mlflow/mlflow-go/pkg/contract" + "github.com/mlflow/mlflow-go/pkg/entities" + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/tracking/store/sql/models" +) + +// CreateExperiment implements TrackingService. +func (ts TrackingService) CreateExperiment(ctx context.Context, input *protos.CreateExperiment) ( + *protos.CreateExperiment_Response, *contract.Error, +) { + if input.GetArtifactLocation() != "" { + artifactLocation := strings.TrimRight(input.GetArtifactLocation(), "/") + + // We don't check the validation here as this was already covered in the validator. + url, _ := url.Parse(artifactLocation) + switch url.Scheme { + case "file", "": + path, err := filepath.Abs(url.Path) + if err != nil { + return nil, contract.NewError( + protos.ErrorCode_INVALID_PARAMETER_VALUE, + fmt.Sprintf("error getting absolute path: %v", err), + ) + } + + if runtime.GOOS == "windows" { + url.Scheme = "file" + path = "/" + strings.ReplaceAll(path, "\\", "/") + } + + url.Path = path + artifactLocation = url.String() + } + + input.ArtifactLocation = &artifactLocation + } + + tags := make([]*entities.ExperimentTag, len(input.GetTags())) + for i, tag := range input.GetTags() { + tags[i] = entities.NewExperimentTagFromProto(tag) + } + + experimentID, err := ts.Store.CreateExperiment(ctx, input.GetName(), input.GetArtifactLocation(), tags) + if err != nil { + return nil, err + } + + return &protos.CreateExperiment_Response{ + ExperimentId: &experimentID, + }, nil +} + +// GetExperiment implements TrackingService. 
+func (ts TrackingService) GetExperiment( + ctx context.Context, input *protos.GetExperiment, +) (*protos.GetExperiment_Response, *contract.Error) { + experiment, err := ts.Store.GetExperiment(ctx, input.GetExperimentId()) + if err != nil { + return nil, err + } + + return &protos.GetExperiment_Response{ + Experiment: experiment.ToProto(), + }, nil +} + +func (ts TrackingService) DeleteExperiment( + ctx context.Context, input *protos.DeleteExperiment, +) (*protos.DeleteExperiment_Response, *contract.Error) { + err := ts.Store.DeleteExperiment(ctx, input.GetExperimentId()) + if err != nil { + return nil, err + } + + return &protos.DeleteExperiment_Response{}, nil +} + +func (ts TrackingService) RestoreExperiment( + ctx context.Context, input *protos.RestoreExperiment, +) (*protos.RestoreExperiment_Response, *contract.Error) { + err := ts.Store.RestoreExperiment(ctx, input.GetExperimentId()) + if err != nil { + return nil, err + } + + return &protos.RestoreExperiment_Response{}, nil +} + +func (ts TrackingService) UpdateExperiment( + ctx context.Context, input *protos.UpdateExperiment, +) (*protos.UpdateExperiment_Response, *contract.Error) { + experiment, err := ts.Store.GetExperiment(ctx, input.GetExperimentId()) + if err != nil { + return nil, err + } + + if experiment.LifecycleStage != string(models.LifecycleStageActive) { + return nil, contract.NewError( + protos.ErrorCode_INVALID_STATE, + "Cannot rename a non-active experiment.", + ) + } + + if name := input.GetNewName(); name != "" { + if err := ts.Store.RenameExperiment(ctx, input.GetExperimentId(), input.GetNewName()); err != nil { + return nil, err + } + } + + return &protos.UpdateExperiment_Response{}, nil +} + +func (ts TrackingService) GetExperimentByName( + ctx context.Context, input *protos.GetExperimentByName, +) (*protos.GetExperimentByName_Response, *contract.Error) { + experiment, err := ts.Store.GetExperimentByName(ctx, input.GetExperimentName()) + if err != nil { + return nil, err + } + + return 
&protos.GetExperimentByName_Response{ + Experiment: experiment.ToProto(), + }, nil +} diff --git a/pkg/tracking/service/experiments_test.go b/pkg/tracking/service/experiments_test.go index 52ec5fd..cbee4e9 100644 --- a/pkg/tracking/service/experiments_test.go +++ b/pkg/tracking/service/experiments_test.go @@ -1,61 +1,61 @@ -package service //nolint:testpackage - -import ( - "context" - "testing" - - "github.com/stretchr/testify/mock" - - "github.com/mlflow/mlflow-go/pkg/protos" - "github.com/mlflow/mlflow-go/pkg/tracking/store" - "github.com/mlflow/mlflow-go/pkg/utils" -) - -type testRelativeArtifactLocationScenario struct { - name string - input string -} - -func TestRelativeArtifactLocation(t *testing.T) { - t.Parallel() - - scenarios := []testRelativeArtifactLocationScenario{ - {name: "without scheme", input: "../yow"}, - {name: "with file scheme", input: "file:///../yow"}, - } - - for _, scenario := range scenarios { - t.Run(scenario.name, func(t *testing.T) { - t.Parallel() - - store := store.NewMockTrackingStore(t) - store.EXPECT().CreateExperiment( - context.Background(), - mock.Anything, - mock.Anything, - mock.Anything, - ).Return(mock.Anything, nil) - - service := TrackingService{ - Store: store, - } - - input := protos.CreateExperiment{ - ArtifactLocation: utils.PtrTo(scenario.input), - } - - response, err := service.CreateExperiment(context.Background(), &input) - if err != nil { - t.Error("expected create experiment to succeed") - } - - if response == nil { - t.Error("expected response to be non-nil") - } - - if input.GetArtifactLocation() == scenario.input { - t.Errorf("expected artifact location to be absolute, got %s", input.GetArtifactLocation()) - } - }) - } -} +package service //nolint:testpackage + +import ( + "context" + "testing" + + "github.com/stretchr/testify/mock" + + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/tracking/store" + "github.com/mlflow/mlflow-go/pkg/utils" +) + +type 
testRelativeArtifactLocationScenario struct { + name string + input string +} + +func TestRelativeArtifactLocation(t *testing.T) { + t.Parallel() + + scenarios := []testRelativeArtifactLocationScenario{ + {name: "without scheme", input: "../yow"}, + {name: "with file scheme", input: "file:///../yow"}, + } + + for _, scenario := range scenarios { + t.Run(scenario.name, func(t *testing.T) { + t.Parallel() + + store := store.NewMockTrackingStore(t) + store.EXPECT().CreateExperiment( + context.Background(), + mock.Anything, + mock.Anything, + mock.Anything, + ).Return(mock.Anything, nil) + + service := TrackingService{ + Store: store, + } + + input := protos.CreateExperiment{ + ArtifactLocation: utils.PtrTo(scenario.input), + } + + response, err := service.CreateExperiment(context.Background(), &input) + if err != nil { + t.Error("expected create experiment to succeed") + } + + if response == nil { + t.Error("expected response to be non-nil") + } + + if input.GetArtifactLocation() == scenario.input { + t.Errorf("expected artifact location to be absolute, got %s", input.GetArtifactLocation()) + } + }) + } +} diff --git a/pkg/tracking/service/metrics.go b/pkg/tracking/service/metrics.go index edb2005..e62bf0b 100644 --- a/pkg/tracking/service/metrics.go +++ b/pkg/tracking/service/metrics.go @@ -1,20 +1,20 @@ -package service - -import ( - "context" - - "github.com/mlflow/mlflow-go/pkg/contract" - "github.com/mlflow/mlflow-go/pkg/entities" - "github.com/mlflow/mlflow-go/pkg/protos" -) - -func (ts TrackingService) LogMetric( - ctx context.Context, - input *protos.LogMetric, -) (*protos.LogMetric_Response, *contract.Error) { - if err := ts.Store.LogMetric(ctx, input.GetRunId(), entities.MetricFromLogMetricProtoInput(input)); err != nil { - return nil, err - } - - return &protos.LogMetric_Response{}, nil -} +package service + +import ( + "context" + + "github.com/mlflow/mlflow-go/pkg/contract" + "github.com/mlflow/mlflow-go/pkg/entities" + 
"github.com/mlflow/mlflow-go/pkg/protos" +) + +func (ts TrackingService) LogMetric( + ctx context.Context, + input *protos.LogMetric, +) (*protos.LogMetric_Response, *contract.Error) { + if err := ts.Store.LogMetric(ctx, input.GetRunId(), entities.MetricFromLogMetricProtoInput(input)); err != nil { + return nil, err + } + + return &protos.LogMetric_Response{}, nil +} diff --git a/pkg/tracking/service/query/README.md b/pkg/tracking/service/query/README.md index a5dba0e..019bb82 100644 --- a/pkg/tracking/service/query/README.md +++ b/pkg/tracking/service/query/README.md @@ -1,8 +1,8 @@ -# Search Query Syntax - -Mlflow has a [query syntax](https://mlflow.org/docs/latest/search-runs.html#search-query-syntax-deep-dive). - -This package is meant to lex and parse this query dialect. - -The code is slightly based on the https://github.com/tlaceby/parser-series. -I did not implement a proper Pratt parser because of how limited the query language is. +# Search Query Syntax + +Mlflow has a [query syntax](https://mlflow.org/docs/latest/search-runs.html#search-query-syntax-deep-dive). + +This package is meant to lex and parse this query dialect. + +The code is slightly based on the https://github.com/tlaceby/parser-series. +I did not implement a proper Pratt parser because of how limited the query language is. diff --git a/pkg/tracking/service/query/lexer/token.go b/pkg/tracking/service/query/lexer/token.go index 064246a..97fddf3 100644 --- a/pkg/tracking/service/query/lexer/token.go +++ b/pkg/tracking/service/query/lexer/token.go @@ -1,111 +1,111 @@ -package lexer - -import "fmt" - -type TokenKind int - -const ( - EOF TokenKind = iota - Number - String - Identifier - - // Grouping & Braces. - OpenParen - CloseParen - - // Equivilance. - Equals - NotEquals - - // Conditional. - Less - LessEquals - Greater - GreaterEquals - - // Symbols. - Dot - Comma - - // Reserved Keywords. 
// TokenKind enumerates every token the search-query lexer can produce.
type TokenKind int

const (
	EOF TokenKind = iota
	Number
	String
	Identifier

	// Grouping & Braces.
	OpenParen
	CloseParen

	// Equivalence.
	Equals
	NotEquals

	// Conditional.
	Less
	LessEquals
	Greater
	GreaterEquals

	// Symbols.
	Dot
	Comma

	// Reserved Keywords.
	In //nolint:varnamelen
	Not
	Like
	ILike
	And
)

// reservedLu maps the upper-cased spelling of a reserved keyword to its
// token kind; lookups are done on the upper-cased lexeme.
//
//nolint:gochecknoglobals
var reservedLu = map[string]TokenKind{
	"AND":   And,
	"NOT":   Not,
	"IN":    In,
	"LIKE":  Like,
	"ILIKE": ILike,
}

// Token is one lexeme: its kind plus the matched source text.
type Token struct {
	Kind  TokenKind
	Value string
}

// Debug renders the token for diagnostics; value-bearing kinds include
// their lexeme, all others print only the kind name.
func (token Token) Debug() string {
	if token.Kind == Identifier || token.Kind == Number || token.Kind == String {
		return fmt.Sprintf("%s(%s)", TokenKindString(token.Kind), token.Value)
	}

	return TokenKindString(token.Kind)
}

// TokenKindString returns the lower-case name of a token kind, or
// "unknown(n)" for values outside the enumeration.
//
//nolint:funlen,cyclop
func TokenKindString(kind TokenKind) string {
	switch kind {
	case EOF:
		return "eof"
	case Number:
		return "number"
	case String:
		return "string"
	case Identifier:
		return "identifier"
	case OpenParen:
		return "open_paren"
	case CloseParen:
		return "close_paren"
	case Equals:
		return "equals"
	case NotEquals:
		return "not_equals"
	case Less:
		return "less"
	case LessEquals:
		return "less_equals"
	case Greater:
		return "greater"
	case GreaterEquals:
		return "greater_equals"
	case And:
		return "and"
	case Dot:
		return "dot"
	case Comma:
		return "comma"
	case In:
		return "in"
	case Not:
		return "not"
	case Like:
		return "like"
	case ILike:
		return "ilike"
	default:
		return fmt.Sprintf("unknown(%d)", kind)
	}
}

// newUniqueToken builds a Token from a kind and its lexeme.
func newUniqueToken(kind TokenKind, value string) Token {
	return Token{
		kind, value,
	}
}
return &Error{message: fmt.Sprintf(format, a...)} -} - -func (e *Error) Error() string { - return e.message -} - -func Tokenize(source *string) ([]Token, error) { - lex := createLexer(source) - - for !lex.atEOF() { - matched := false - - for _, pattern := range lex.patterns { - loc := pattern.regex.FindStringIndex(lex.remainder()) - if loc != nil && loc[0] == 0 { - pattern.handler(lex, pattern.regex) - - matched = true - - break // Exit the loop after the first match - } - } - - if !matched { - return lex.Tokens, NewLexerError("unrecognized token near '%v'", lex.remainder()) - } - } - - lex.push(newUniqueToken(EOF, "EOF")) - - return lex.Tokens, nil -} - -func (lex *lexer) advanceN(n int) { - lex.pos += n -} - -func (lex *lexer) remainder() string { - return (*lex.source)[lex.pos:] -} - -func (lex *lexer) push(token Token) { - lex.Tokens = append(lex.Tokens, token) -} - -func (lex *lexer) atEOF() bool { - return lex.pos >= len(*lex.source) -} - -func createLexer(source *string) *lexer { - return &lexer{ - pos: 0, - line: 1, - source: source, - Tokens: make([]Token, 0), - patterns: []regexPattern{ - {regexp.MustCompile(`\s+`), skipHandler}, - {regexp.MustCompile(`"[^"]*"`), stringHandler}, - {regexp.MustCompile(`'[^\']*\'`), stringHandler}, - {regexp.MustCompile("`[^`]*`"), stringHandler}, - {regexp.MustCompile(`\-?[0-9]+(\.[0-9]+)?`), numberHandler}, - {regexp.MustCompile(`[a-zA-Z_][a-zA-Z0-9_]*`), symbolHandler}, - {regexp.MustCompile(`\(`), defaultHandler(OpenParen, "(")}, - {regexp.MustCompile(`\)`), defaultHandler(CloseParen, ")")}, - {regexp.MustCompile(`!=`), defaultHandler(NotEquals, "!=")}, - {regexp.MustCompile(`=`), defaultHandler(Equals, "=")}, - {regexp.MustCompile(`<=`), defaultHandler(LessEquals, "<=")}, - {regexp.MustCompile(`<`), defaultHandler(Less, "<")}, - {regexp.MustCompile(`>=`), defaultHandler(GreaterEquals, ">=")}, - {regexp.MustCompile(`>`), defaultHandler(Greater, ">")}, - {regexp.MustCompile(`\.`), defaultHandler(Dot, ".")}, - 
{regexp.MustCompile(`,`), defaultHandler(Comma, ",")}, - }, - } -} - -type regexHandler func(lex *lexer, regex *regexp.Regexp) - -// Created a default handler which will simply create a token with the matched contents. -// This handler is used with most simple tokens. -func defaultHandler(kind TokenKind, value string) regexHandler { - return func(lex *lexer, _ *regexp.Regexp) { - lex.advanceN(len(value)) - lex.push(newUniqueToken(kind, value)) - } -} - -func stringHandler(lex *lexer, regex *regexp.Regexp) { - match := regex.FindStringIndex(lex.remainder()) - stringLiteral := lex.remainder()[match[0]:match[1]] - - lex.push(newUniqueToken(String, stringLiteral)) - lex.advanceN(len(stringLiteral)) -} - -func numberHandler(lex *lexer, regex *regexp.Regexp) { - match := regex.FindString(lex.remainder()) - lex.push(newUniqueToken(Number, match)) - lex.advanceN(len(match)) -} - -func symbolHandler(lex *lexer, regex *regexp.Regexp) { - match := regex.FindString(lex.remainder()) - keyword := strings.ToUpper(match) - - if kind, found := reservedLu[keyword]; found { - lex.push(newUniqueToken(kind, match)) - } else { - lex.push(newUniqueToken(Identifier, match)) - } - - lex.advanceN(len(match)) -} - -func skipHandler(lex *lexer, regex *regexp.Regexp) { - match := regex.FindStringIndex(lex.remainder()) - lex.advanceN(match[1]) -} +package lexer + +import ( + "fmt" + "regexp" + "strings" +) + +type regexPattern struct { + regex *regexp.Regexp + handler regexHandler +} + +type lexer struct { + patterns []regexPattern + Tokens []Token + source *string + pos int + line int +} + +type Error struct { + message string +} + +func NewLexerError(format string, a ...any) *Error { + return &Error{message: fmt.Sprintf(format, a...)} +} + +func (e *Error) Error() string { + return e.message +} + +func Tokenize(source *string) ([]Token, error) { + lex := createLexer(source) + + for !lex.atEOF() { + matched := false + + for _, pattern := range lex.patterns { + loc := 
pattern.regex.FindStringIndex(lex.remainder()) + if loc != nil && loc[0] == 0 { + pattern.handler(lex, pattern.regex) + + matched = true + + break // Exit the loop after the first match + } + } + + if !matched { + return lex.Tokens, NewLexerError("unrecognized token near '%v'", lex.remainder()) + } + } + + lex.push(newUniqueToken(EOF, "EOF")) + + return lex.Tokens, nil +} + +func (lex *lexer) advanceN(n int) { + lex.pos += n +} + +func (lex *lexer) remainder() string { + return (*lex.source)[lex.pos:] +} + +func (lex *lexer) push(token Token) { + lex.Tokens = append(lex.Tokens, token) +} + +func (lex *lexer) atEOF() bool { + return lex.pos >= len(*lex.source) +} + +func createLexer(source *string) *lexer { + return &lexer{ + pos: 0, + line: 1, + source: source, + Tokens: make([]Token, 0), + patterns: []regexPattern{ + {regexp.MustCompile(`\s+`), skipHandler}, + {regexp.MustCompile(`"[^"]*"`), stringHandler}, + {regexp.MustCompile(`'[^\']*\'`), stringHandler}, + {regexp.MustCompile("`[^`]*`"), stringHandler}, + {regexp.MustCompile(`\-?[0-9]+(\.[0-9]+)?`), numberHandler}, + {regexp.MustCompile(`[a-zA-Z_][a-zA-Z0-9_]*`), symbolHandler}, + {regexp.MustCompile(`\(`), defaultHandler(OpenParen, "(")}, + {regexp.MustCompile(`\)`), defaultHandler(CloseParen, ")")}, + {regexp.MustCompile(`!=`), defaultHandler(NotEquals, "!=")}, + {regexp.MustCompile(`=`), defaultHandler(Equals, "=")}, + {regexp.MustCompile(`<=`), defaultHandler(LessEquals, "<=")}, + {regexp.MustCompile(`<`), defaultHandler(Less, "<")}, + {regexp.MustCompile(`>=`), defaultHandler(GreaterEquals, ">=")}, + {regexp.MustCompile(`>`), defaultHandler(Greater, ">")}, + {regexp.MustCompile(`\.`), defaultHandler(Dot, ".")}, + {regexp.MustCompile(`,`), defaultHandler(Comma, ",")}, + }, + } +} + +type regexHandler func(lex *lexer, regex *regexp.Regexp) + +// Created a default handler which will simply create a token with the matched contents. +// This handler is used with most simple tokens. 
+func defaultHandler(kind TokenKind, value string) regexHandler { + return func(lex *lexer, _ *regexp.Regexp) { + lex.advanceN(len(value)) + lex.push(newUniqueToken(kind, value)) + } +} + +func stringHandler(lex *lexer, regex *regexp.Regexp) { + match := regex.FindStringIndex(lex.remainder()) + stringLiteral := lex.remainder()[match[0]:match[1]] + + lex.push(newUniqueToken(String, stringLiteral)) + lex.advanceN(len(stringLiteral)) +} + +func numberHandler(lex *lexer, regex *regexp.Regexp) { + match := regex.FindString(lex.remainder()) + lex.push(newUniqueToken(Number, match)) + lex.advanceN(len(match)) +} + +func symbolHandler(lex *lexer, regex *regexp.Regexp) { + match := regex.FindString(lex.remainder()) + keyword := strings.ToUpper(match) + + if kind, found := reservedLu[keyword]; found { + lex.push(newUniqueToken(kind, match)) + } else { + lex.push(newUniqueToken(Identifier, match)) + } + + lex.advanceN(len(match)) +} + +func skipHandler(lex *lexer, regex *regexp.Regexp) { + match := regex.FindStringIndex(lex.remainder()) + lex.advanceN(match[1]) +} diff --git a/pkg/tracking/service/query/lexer/tokenizer_test.go b/pkg/tracking/service/query/lexer/tokenizer_test.go index d4c9cea..fa5fdf1 100644 --- a/pkg/tracking/service/query/lexer/tokenizer_test.go +++ b/pkg/tracking/service/query/lexer/tokenizer_test.go @@ -1,114 +1,114 @@ -package lexer_test - -import ( - "strings" - "testing" - - "github.com/mlflow/mlflow-go/pkg/tracking/service/query/lexer" -) - -type Sample struct { - input string - expected string -} - -//nolint:lll,funlen -func TestQueries(t *testing.T) { - t.Parallel() - - samples := []Sample{ - { - input: "metrics.accuracy > 0.72", - expected: "identifier(metrics) dot identifier(accuracy) greater number(0.72) eof", - }, - { - input: "metrics.\"accuracy\" > 0.72", - expected: "identifier(metrics) dot string(\"accuracy\") greater number(0.72) eof", - }, - { - input: "metrics.accuracy > 0.72 AND metrics.loss <= 0.15", - expected: "identifier(metrics) dot 
identifier(accuracy) greater number(0.72) and identifier(metrics) dot identifier(loss) less_equals number(0.15) eof", - }, - { - input: "params.batch_size = \"2\"", - expected: "identifier(params) dot identifier(batch_size) equals string(\"2\") eof", - }, - { - input: "tags.task ILIKE \"classif%\"", - expected: "identifier(tags) dot identifier(task) ilike string(\"classif%\") eof", - }, - { - input: "datasets.digest IN ('s8ds293b', 'jks834s2')", - expected: "identifier(datasets) dot identifier(digest) in open_paren string('s8ds293b') comma string('jks834s2') close_paren eof", - }, - { - input: "attributes.created > 1664067852747", - expected: "identifier(attributes) dot identifier(created) greater number(1664067852747) eof", - }, - { - input: "params.batch_size != \"None\"", - expected: "identifier(params) dot identifier(batch_size) not_equals string(\"None\") eof", - }, - { - input: "datasets.digest NOT IN ('s8ds293b', 'jks834s2')", - expected: "identifier(datasets) dot identifier(digest) not in open_paren string('s8ds293b') comma string('jks834s2') close_paren eof", - }, - { - input: "params.`random_state` = \"8888\"", - expected: "identifier(params) dot string(`random_state`) equals string(\"8888\") eof", - }, - { - input: "metrics.measure_a != -12.0", - expected: "identifier(metrics) dot identifier(measure_a) not_equals number(-12.0) eof", - }, - } - - for _, sample := range samples { - currentSample := sample - - t.Run(currentSample.input, func(t *testing.T) { - t.Parallel() - - tokens, err := lexer.Tokenize(¤tSample.input) - if err != nil { - t.Errorf("unexpected error: %v", err) - } - - output := "" - - for _, token := range tokens { - output += " " + token.Debug() - } - - output = strings.TrimLeft(output, " ") - - if output != currentSample.expected { - t.Errorf("expected %s, got %s", currentSample.expected, output) - } - }) - } -} - -func TestInvalidInput(t *testing.T) { - t.Parallel() - - samples := []string{ - "params.'acc = LR", - "params.acc = 'LR", - 
"params.acc = LR'", - "params.acc = \"LR'", - "tags.acc = \"LR'", - } - - for _, sample := range samples { - currentSample := sample - t.Run(currentSample, func(t *testing.T) { - t.Parallel() - - _, err := lexer.Tokenize(¤tSample) - if err == nil { - t.Errorf("expected error, got nil") - } - }) - } -} +package lexer_test + +import ( + "strings" + "testing" + + "github.com/mlflow/mlflow-go/pkg/tracking/service/query/lexer" +) + +type Sample struct { + input string + expected string +} + +//nolint:lll,funlen +func TestQueries(t *testing.T) { + t.Parallel() + + samples := []Sample{ + { + input: "metrics.accuracy > 0.72", + expected: "identifier(metrics) dot identifier(accuracy) greater number(0.72) eof", + }, + { + input: "metrics.\"accuracy\" > 0.72", + expected: "identifier(metrics) dot string(\"accuracy\") greater number(0.72) eof", + }, + { + input: "metrics.accuracy > 0.72 AND metrics.loss <= 0.15", + expected: "identifier(metrics) dot identifier(accuracy) greater number(0.72) and identifier(metrics) dot identifier(loss) less_equals number(0.15) eof", + }, + { + input: "params.batch_size = \"2\"", + expected: "identifier(params) dot identifier(batch_size) equals string(\"2\") eof", + }, + { + input: "tags.task ILIKE \"classif%\"", + expected: "identifier(tags) dot identifier(task) ilike string(\"classif%\") eof", + }, + { + input: "datasets.digest IN ('s8ds293b', 'jks834s2')", + expected: "identifier(datasets) dot identifier(digest) in open_paren string('s8ds293b') comma string('jks834s2') close_paren eof", + }, + { + input: "attributes.created > 1664067852747", + expected: "identifier(attributes) dot identifier(created) greater number(1664067852747) eof", + }, + { + input: "params.batch_size != \"None\"", + expected: "identifier(params) dot identifier(batch_size) not_equals string(\"None\") eof", + }, + { + input: "datasets.digest NOT IN ('s8ds293b', 'jks834s2')", + expected: "identifier(datasets) dot identifier(digest) not in open_paren string('s8ds293b') comma 
string('jks834s2') close_paren eof", + }, + { + input: "params.`random_state` = \"8888\"", + expected: "identifier(params) dot string(`random_state`) equals string(\"8888\") eof", + }, + { + input: "metrics.measure_a != -12.0", + expected: "identifier(metrics) dot identifier(measure_a) not_equals number(-12.0) eof", + }, + } + + for _, sample := range samples { + currentSample := sample + + t.Run(currentSample.input, func(t *testing.T) { + t.Parallel() + + tokens, err := lexer.Tokenize(¤tSample.input) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + output := "" + + for _, token := range tokens { + output += " " + token.Debug() + } + + output = strings.TrimLeft(output, " ") + + if output != currentSample.expected { + t.Errorf("expected %s, got %s", currentSample.expected, output) + } + }) + } +} + +func TestInvalidInput(t *testing.T) { + t.Parallel() + + samples := []string{ + "params.'acc = LR", + "params.acc = 'LR", + "params.acc = LR'", + "params.acc = \"LR'", + "tags.acc = \"LR'", + } + + for _, sample := range samples { + currentSample := sample + t.Run(currentSample, func(t *testing.T) { + t.Parallel() + + _, err := lexer.Tokenize(¤tSample) + if err == nil { + t.Errorf("expected error, got nil") + } + }) + } +} diff --git a/pkg/tracking/service/query/parser/ast.go b/pkg/tracking/service/query/parser/ast.go index 6a92b31..9f2673c 100644 --- a/pkg/tracking/service/query/parser/ast.go +++ b/pkg/tracking/service/query/parser/ast.go @@ -1,137 +1,137 @@ -package parser - -import ( - "fmt" - "strings" -) - -// -------------------- -// Literal Expressions -// -------------------- - -type Value interface { - value() interface{} - fmt.Stringer -} - -type NumberExpr struct { - Value float64 -} - -func (n NumberExpr) value() interface{} { - return n.Value -} - -func (n NumberExpr) String() string { - return fmt.Sprintf("%f", n.Value) -} - -type StringExpr struct { - Value string -} - -func (n StringExpr) value() interface{} { - return n.Value -} - -func 
// --------------------
// Literal Expressions
// --------------------

// Value is the right-hand side of a comparison: a number, a string, or a
// list of strings.
type Value interface {
	value() interface{}
	fmt.Stringer
}

// NumberExpr is a numeric literal.
type NumberExpr struct {
	Value float64
}

func (num NumberExpr) value() interface{} {
	return num.Value
}

// String renders the number with fmt's default float precision (6 digits).
func (num NumberExpr) String() string {
	return fmt.Sprintf("%f", num.Value)
}

// StringExpr is a single string literal.
type StringExpr struct {
	Value string
}

func (str StringExpr) value() interface{} {
	return str.Value
}

// String renders the literal wrapped in double quotes.
func (str StringExpr) String() string {
	return "\"" + str.Value + "\""
}

// StringListExpr is a parenthesized list of string literals (IN sets).
type StringListExpr struct {
	Values []string
}

func (list StringListExpr) value() interface{} {
	return list.Values
}

// String renders the values double-quoted and comma separated.
func (list StringListExpr) String() string {
	quoted := make([]string, len(list.Values))
	for idx, val := range list.Values {
		quoted[idx] = "\"" + val + "\""
	}

	return strings.Join(quoted, ", ")
}

//-----------------------
// Identifier Expressions
// ----------------------

// Identifier is an identifier.key expression, like metric.foo; Key alone is
// used when no qualifying identifier was written.
type Identifier struct {
	Identifier string
	Key        string
}

// String renders "identifier.key", or just the identifier when Key is empty.
func (i Identifier) String() string {
	if i.Key == "" {
		return i.Identifier
	}

	return i.Identifier + "." + i.Key
}

// --------------------
// Comparison Expression
// --------------------

// OperatorKind enumerates the comparison operators of the query language.
type OperatorKind int

const (
	Equals OperatorKind = iota
	NotEquals
	Less
	LessEquals
	Greater
	GreaterEquals
	Like
	ILike
	In //nolint:varnamelen
	NotIn
)

// String renders the operator in its query-syntax spelling.
//
//nolint:cyclop
func (op OperatorKind) String() string {
	switch op {
	case Equals:
		return "="
	case NotEquals:
		return "!="
	case Less:
		return "<"
	case LessEquals:
		return "<="
	case Greater:
		return ">"
	case GreaterEquals:
		return ">="
	case Like:
		return "LIKE"
	case ILike:
		return "ILIKE"
	case In:
		return "IN"
	case NotIn:
		return "NOT IN"
	default:
		return "UNKNOWN"
	}
}

// CompareExpr is one "left operator right" comparison.
type CompareExpr struct {
	Left     Identifier
	Operator OperatorKind
	Right    Value
}

// String renders the comparison with single spaces between the parts.
func (expr *CompareExpr) String() string {
	return expr.Left.String() + " " + expr.Operator.String() + " " + expr.Right.String()
}

// AndExpr is the conjunction of one or more comparisons (the only boolean
// connective the query language supports).
type AndExpr struct {
	Exprs []*CompareExpr
}
// parser is a hand-written recursive-descent parser over the lexer's token
// stream; pos always indexes the next unconsumed token.
type parser struct {
	tokens []lexer.Token
	pos    int
}

// newParser wraps a token slice (which ends with an EOF token) in a parser.
func newParser(tokens []lexer.Token) *parser {
	return &parser{
		tokens: tokens,
		pos:    0,
	}
}

// currentTokenKind peeks at the kind of the next unconsumed token.
func (p *parser) currentTokenKind() lexer.TokenKind {
	return p.tokens[p.pos].Kind
}

// hasTokens reports whether any non-EOF tokens remain.
func (p *parser) hasTokens() bool {
	return p.pos < len(p.tokens) && p.currentTokenKind() != lexer.EOF
}

// printCurrentToken renders the current token for error messages.
func (p *parser) printCurrentToken() string {
	return p.tokens[p.pos].Debug()
}

func (p *parser) currentToken() lexer.Token {
	return p.tokens[p.pos]
}

// advance consumes and returns the current token.
func (p *parser) advance() lexer.Token {
	tk := p.currentToken()
	p.pos++

	return tk
}

// Error is a parse failure with a preformatted message.
type Error struct {
	message string
}

// NewParserError builds an *Error from a printf-style format.
func NewParserError(format string, a ...any) *Error {
	return &Error{message: fmt.Sprintf(format, a...)}
}

func (e *Error) Error() string {
	return e.message
}

// parseIdentifier parses either a qualified "identifier.key" (the key may be
// a quoted string, whose quotes are stripped) or a bare name, which is
// returned with an empty Identifier and the name in Key.
//
// NOTE(review): when the stream is already at EOF the guard below passes and
// advance() still consumes a token; a stream containing only EOF would then
// index past the slice in currentTokenKind — confirm callers never parse an
// empty query.
func (p *parser) parseIdentifier() (Identifier, error) {
	emptyIdentifier := Identifier{Identifier: "", Key: ""}
	if p.hasTokens() && p.currentTokenKind() != lexer.Identifier {
		return emptyIdentifier, NewParserError(
			"expected identifier, got %s",
			p.printCurrentToken(),
		)
	}

	identToken := p.advance()

	if p.currentTokenKind() == lexer.Dot {
		p.advance() // Consume the DOT
		//nolint:exhaustive
		switch p.currentTokenKind() {
		case lexer.Identifier:
			column := p.advance().Value

			return Identifier{Identifier: identToken.Value, Key: column}, nil
		case lexer.String:
			column := p.advance().Value
			column = column[1 : len(column)-1] // Remove quotes

			return Identifier{Identifier: identToken.Value, Key: column}, nil
		default:
			return emptyIdentifier, NewParserError(
				"expected IDENTIFIER or STRING, got %s",
				p.printCurrentToken(),
			)
		}
	} else {
		return Identifier{Identifier: "", Key: identToken.Value}, nil
	}
}

// parseOperator consumes one binary comparison operator token and maps it to
// its OperatorKind; IN/NOT IN are handled separately in parseExpression.
func (p *parser) parseOperator() (OperatorKind, error) {
	//nolint:exhaustive
	switch p.advance().Kind {
	case lexer.Equals:
		return Equals, nil
	case lexer.NotEquals:
		return NotEquals, nil
	case lexer.Less:
		return Less, nil
	case lexer.LessEquals:
		return LessEquals, nil
	case lexer.Greater:
		return Greater, nil
	case lexer.GreaterEquals:
		return GreaterEquals, nil
	case lexer.Like:
		return Like, nil
	case lexer.ILike:
		return ILike, nil
	default:
		return -1, NewParserError("expected operator, got %s", p.printCurrentToken())
	}
}

// parseValue consumes one literal: a number (parsed as float64) or a string
// (surrounding quotes stripped).
//
//nolint:ireturn
func (p *parser) parseValue() (Value, error) {
	//nolint:exhaustive
	switch p.currentTokenKind() {
	case lexer.Number:
		n, err := strconv.ParseFloat(p.advance().Value, 64)
		if err != nil {
			return nil, fmt.Errorf("failed to parse number token to float: %w", err)
		}

		return NumberExpr{Value: n}, nil
	case lexer.String:
		value := p.advance().Value
		value = value[1 : len(value)-1] // Remove quotes

		return StringExpr{Value: value}, nil
	default:
		return nil, NewParserError(
			"Expected NUMBER or STRING, got %s",
			p.printCurrentToken(),
		)
	}
}

// parseInSetExpr parses the parenthesized string list of an IN expression;
// the IN (or NOT IN) keywords themselves were already consumed by the
// caller, which also rewrites Operator to NotIn when needed.
func (p *parser) parseInSetExpr(ident Identifier) (*CompareExpr, error) {
	if p.currentTokenKind() != lexer.OpenParen {
		return nil, NewParserError(
			"expected '(', got %s",
			p.printCurrentToken(),
		)
	}

	p.advance() // Consume the OPEN_PAREN

	set := make([]string, 0)

	for p.hasTokens() && p.currentTokenKind() != lexer.CloseParen {
		if p.currentTokenKind() != lexer.String {
			return nil, NewParserError(
				"expected STRING, got %s",
				p.printCurrentToken(),
			)
		}

		value := p.advance().Value
		value = value[1 : len(value)-1] // Remove quotes

		set = append(set, value)

		if p.currentTokenKind() == lexer.Comma {
			p.advance() // Consume the COMMA
		}
	}

	if p.currentTokenKind() != lexer.CloseParen {
		return nil, NewParserError(
			"expected ')', got %s",
			p.printCurrentToken(),
		)
	}

	p.advance() // Consume the CLOSE_PAREN

	return &CompareExpr{Left: ident, Operator: In, Right: StringListExpr{Values: set}}, nil
}

// parseExpression parses one comparison: identifier [NOT] IN (...) or
// identifier operator literal.
func (p *parser) parseExpression() (*CompareExpr, error) {
	ident, err := p.parseIdentifier()
	if err != nil {
		return nil, err
	}

	//nolint:exhaustive
	switch p.currentTokenKind() {
	case lexer.In:
		p.advance() // Consume the IN

		return p.parseInSetExpr(ident)
	case lexer.Not:
		p.advance() // Consume the NOT

		if p.currentTokenKind() != lexer.In {
			return nil, NewParserError(
				"expected IN after NOT, got %s",
				p.printCurrentToken(),
			)
		}

		p.advance() // Consume the IN

		expr, err := p.parseInSetExpr(ident)
		if err != nil {
			return nil, err
		}

		expr.Operator = NotIn

		return expr, nil
	default:
		operator, err := p.parseOperator()
		if err != nil {
			return nil, err
		}

		value, err := p.parseValue()
		if err != nil {
			return nil, err
		}

		return &CompareExpr{Left: ident, Operator: operator, Right: value}, nil
	}
}

// parse parses the whole token stream as an AND-separated list of
// comparisons and fails on any leftover tokens.
func (p *parser) parse() (*AndExpr, error) {
	exprs := make([]*CompareExpr, 0)

	leftExpr, err := p.parseExpression()
	if err != nil {
		return nil, fmt.Errorf("error while parsing initial expression: %w", err)
	}

	exprs = append(exprs, leftExpr)

	// While there are tokens and the next token is AND
	for p.currentTokenKind() == lexer.And {
		p.advance() // Consume the AND

		rightExpr, err := p.parseExpression()
		if err != nil {
			return nil, err
		}

		exprs = append(exprs, rightExpr)
	}

	if p.hasTokens() {
		return nil, NewParserError(
			"unexpected leftover token(s) after parsing: %s",
			p.printCurrentToken(),
		)
	}

	return &AndExpr{Exprs: exprs}, nil
}

// Parse is the package entry point: it parses a lexed token stream into an
// AndExpr of comparisons.
func Parse(tokens []lexer.Token) (*AndExpr, error) {
	parser := newParser(tokens)

	return parser.parse()
}
COMMA + } + } + + if p.currentTokenKind() != lexer.CloseParen { + return nil, NewParserError( + "expected ')', got %s", + p.printCurrentToken(), + ) + } + + p.advance() // Consume the CLOSE_PAREN + + return &CompareExpr{Left: ident, Operator: In, Right: StringListExpr{Values: set}}, nil +} + +func (p *parser) parseExpression() (*CompareExpr, error) { + ident, err := p.parseIdentifier() + if err != nil { + return nil, err + } + + //nolint:exhaustive + switch p.currentTokenKind() { + case lexer.In: + p.advance() // Consume the IN + + return p.parseInSetExpr(ident) + case lexer.Not: + p.advance() // Consume the NOT + + if p.currentTokenKind() != lexer.In { + return nil, NewParserError( + "expected IN after NOT, got %s", + p.printCurrentToken(), + ) + } + + p.advance() // Consume the IN + + expr, err := p.parseInSetExpr(ident) + if err != nil { + return nil, err + } + + expr.Operator = NotIn + + return expr, nil + default: + operator, err := p.parseOperator() + if err != nil { + return nil, err + } + + value, err := p.parseValue() + if err != nil { + return nil, err + } + + return &CompareExpr{Left: ident, Operator: operator, Right: value}, nil + } +} + +func (p *parser) parse() (*AndExpr, error) { + exprs := make([]*CompareExpr, 0) + + leftExpr, err := p.parseExpression() + if err != nil { + return nil, fmt.Errorf("error while parsing initial expression: %w", err) + } + + exprs = append(exprs, leftExpr) + + // While there are tokens and the next token is AND + for p.currentTokenKind() == lexer.And { + p.advance() // Consume the AND + + rightExpr, err := p.parseExpression() + if err != nil { + return nil, err + } + + exprs = append(exprs, rightExpr) + } + + if p.hasTokens() { + return nil, NewParserError( + "unexpected leftover token(s) after parsing: %s", + p.printCurrentToken(), + ) + } + + return &AndExpr{Exprs: exprs}, nil +} + +func Parse(tokens []lexer.Token) (*AndExpr, error) { + parser := newParser(tokens) + + return parser.parse() +} diff --git 
a/pkg/tracking/service/query/parser/parser_test.go b/pkg/tracking/service/query/parser/parser_test.go index 9b99094..64129c3 100644 --- a/pkg/tracking/service/query/parser/parser_test.go +++ b/pkg/tracking/service/query/parser/parser_test.go @@ -1,182 +1,182 @@ -package parser_test - -import ( - "reflect" - "testing" - - "github.com/mlflow/mlflow-go/pkg/tracking/service/query/lexer" - "github.com/mlflow/mlflow-go/pkg/tracking/service/query/parser" -) - -type Sample struct { - input string - expected *parser.AndExpr -} - -//nolint:funlen -func TestQueries(t *testing.T) { - t.Parallel() - - samples := []Sample{ - { - input: "metrics.accuracy > 0.72", - expected: &parser.AndExpr{ - Exprs: []*parser.CompareExpr{ - { - Left: parser.Identifier{"metrics", "accuracy"}, - Operator: parser.Greater, - Right: parser.NumberExpr{Value: 0.72}, - }, - }, - }, - }, - { - input: "metrics.\"accuracy\" > 0.72", - expected: &parser.AndExpr{ - Exprs: []*parser.CompareExpr{ - { - Left: parser.Identifier{"metrics", "accuracy"}, - Operator: parser.Greater, - Right: parser.NumberExpr{Value: 0.72}, - }, - }, - }, - }, - { - input: "metrics.accuracy > 0.72 AND metrics.loss <= 0.15", - expected: &parser.AndExpr{ - Exprs: []*parser.CompareExpr{ - { - Left: parser.Identifier{"metrics", "accuracy"}, - Operator: parser.Greater, - Right: parser.NumberExpr{Value: 0.72}, - }, - { - Left: parser.Identifier{"metrics", "loss"}, - Operator: parser.LessEquals, - Right: parser.NumberExpr{Value: 0.15}, - }, - }, - }, - }, - { - input: "params.batch_size = \"2\"", - expected: &parser.AndExpr{ - Exprs: []*parser.CompareExpr{ - { - Left: parser.Identifier{"params", "batch_size"}, - Operator: parser.Equals, - Right: parser.StringExpr{Value: "2"}, - }, - }, - }, - }, - { - input: "tags.task ILIKE \"classif%\"", - expected: &parser.AndExpr{ - Exprs: []*parser.CompareExpr{ - { - Left: parser.Identifier{"tags", "task"}, - Operator: parser.ILike, - Right: parser.StringExpr{Value: "classif%"}, - }, - }, - }, - }, - { 
- input: "datasets.digest IN ('s8ds293b', 'jks834s2')", - expected: &parser.AndExpr{ - Exprs: []*parser.CompareExpr{ - { - Left: parser.Identifier{"datasets", "digest"}, - Operator: parser.In, - Right: parser.StringListExpr{Values: []string{"s8ds293b", "jks834s2"}}, - }, - }, - }, - }, - { - input: "attributes.created > 1664067852747", - expected: &parser.AndExpr{ - []*parser.CompareExpr{ - { - Left: parser.Identifier{"attributes", "created"}, - Operator: parser.Greater, - Right: parser.NumberExpr{Value: 1664067852747}, - }, - }, - }, - }, - { - input: "params.batch_size != \"None\"", - expected: &parser.AndExpr{ - Exprs: []*parser.CompareExpr{ - { - Left: parser.Identifier{"params", "batch_size"}, - Operator: parser.NotEquals, - Right: parser.StringExpr{Value: "None"}, - }, - }, - }, - }, - { - input: "datasets.digest NOT IN ('s8ds293b', 'jks834s2')", - expected: &parser.AndExpr{ - Exprs: []*parser.CompareExpr{ - { - Left: parser.Identifier{"datasets", "digest"}, - Operator: parser.NotIn, - Right: parser.StringListExpr{Values: []string{"s8ds293b", "jks834s2"}}, - }, - }, - }, - }, - } - - for _, sample := range samples { - currentSample := sample - - t.Run(currentSample.input, func(t *testing.T) { - t.Parallel() - - tokens, err := lexer.Tokenize(¤tSample.input) - if err != nil { - t.Errorf("unexpected lex error: %v", err) - } - - ast, err := parser.Parse(tokens) - if err != nil { - t.Errorf("error parsing: %s", err) - } - - if !reflect.DeepEqual(ast, currentSample.expected) { - t.Errorf("expected %#v, got %#v", currentSample.expected, ast) - } - }) - } -} - -func TestInvalidSyntax(t *testing.T) { - t.Parallel() - - samples := []string{ - "attribute.status IS 'RUNNING'", - } - - for _, sample := range samples { - currentSample := sample - t.Run(currentSample, func(t *testing.T) { - t.Parallel() - - tokens, err := lexer.Tokenize(¤tSample) - if err != nil { - t.Errorf("unexpected lex error: %v", err) - } - - _, err = parser.Parse(tokens) - if err == nil { - 
t.Errorf("expected parse error, got nil") - } - }) - } -} +package parser_test + +import ( + "reflect" + "testing" + + "github.com/mlflow/mlflow-go/pkg/tracking/service/query/lexer" + "github.com/mlflow/mlflow-go/pkg/tracking/service/query/parser" +) + +type Sample struct { + input string + expected *parser.AndExpr +} + +//nolint:funlen +func TestQueries(t *testing.T) { + t.Parallel() + + samples := []Sample{ + { + input: "metrics.accuracy > 0.72", + expected: &parser.AndExpr{ + Exprs: []*parser.CompareExpr{ + { + Left: parser.Identifier{"metrics", "accuracy"}, + Operator: parser.Greater, + Right: parser.NumberExpr{Value: 0.72}, + }, + }, + }, + }, + { + input: "metrics.\"accuracy\" > 0.72", + expected: &parser.AndExpr{ + Exprs: []*parser.CompareExpr{ + { + Left: parser.Identifier{"metrics", "accuracy"}, + Operator: parser.Greater, + Right: parser.NumberExpr{Value: 0.72}, + }, + }, + }, + }, + { + input: "metrics.accuracy > 0.72 AND metrics.loss <= 0.15", + expected: &parser.AndExpr{ + Exprs: []*parser.CompareExpr{ + { + Left: parser.Identifier{"metrics", "accuracy"}, + Operator: parser.Greater, + Right: parser.NumberExpr{Value: 0.72}, + }, + { + Left: parser.Identifier{"metrics", "loss"}, + Operator: parser.LessEquals, + Right: parser.NumberExpr{Value: 0.15}, + }, + }, + }, + }, + { + input: "params.batch_size = \"2\"", + expected: &parser.AndExpr{ + Exprs: []*parser.CompareExpr{ + { + Left: parser.Identifier{"params", "batch_size"}, + Operator: parser.Equals, + Right: parser.StringExpr{Value: "2"}, + }, + }, + }, + }, + { + input: "tags.task ILIKE \"classif%\"", + expected: &parser.AndExpr{ + Exprs: []*parser.CompareExpr{ + { + Left: parser.Identifier{"tags", "task"}, + Operator: parser.ILike, + Right: parser.StringExpr{Value: "classif%"}, + }, + }, + }, + }, + { + input: "datasets.digest IN ('s8ds293b', 'jks834s2')", + expected: &parser.AndExpr{ + Exprs: []*parser.CompareExpr{ + { + Left: parser.Identifier{"datasets", "digest"}, + Operator: parser.In, + Right: 
parser.StringListExpr{Values: []string{"s8ds293b", "jks834s2"}}, + }, + }, + }, + }, + { + input: "attributes.created > 1664067852747", + expected: &parser.AndExpr{ + []*parser.CompareExpr{ + { + Left: parser.Identifier{"attributes", "created"}, + Operator: parser.Greater, + Right: parser.NumberExpr{Value: 1664067852747}, + }, + }, + }, + }, + { + input: "params.batch_size != \"None\"", + expected: &parser.AndExpr{ + Exprs: []*parser.CompareExpr{ + { + Left: parser.Identifier{"params", "batch_size"}, + Operator: parser.NotEquals, + Right: parser.StringExpr{Value: "None"}, + }, + }, + }, + }, + { + input: "datasets.digest NOT IN ('s8ds293b', 'jks834s2')", + expected: &parser.AndExpr{ + Exprs: []*parser.CompareExpr{ + { + Left: parser.Identifier{"datasets", "digest"}, + Operator: parser.NotIn, + Right: parser.StringListExpr{Values: []string{"s8ds293b", "jks834s2"}}, + }, + }, + }, + }, + } + + for _, sample := range samples { + currentSample := sample + + t.Run(currentSample.input, func(t *testing.T) { + t.Parallel() + + tokens, err := lexer.Tokenize(¤tSample.input) + if err != nil { + t.Errorf("unexpected lex error: %v", err) + } + + ast, err := parser.Parse(tokens) + if err != nil { + t.Errorf("error parsing: %s", err) + } + + if !reflect.DeepEqual(ast, currentSample.expected) { + t.Errorf("expected %#v, got %#v", currentSample.expected, ast) + } + }) + } +} + +func TestInvalidSyntax(t *testing.T) { + t.Parallel() + + samples := []string{ + "attribute.status IS 'RUNNING'", + } + + for _, sample := range samples { + currentSample := sample + t.Run(currentSample, func(t *testing.T) { + t.Parallel() + + tokens, err := lexer.Tokenize(¤tSample) + if err != nil { + t.Errorf("unexpected lex error: %v", err) + } + + _, err = parser.Parse(tokens) + if err == nil { + t.Errorf("expected parse error, got nil") + } + }) + } +} diff --git a/pkg/tracking/service/query/parser/validate.go b/pkg/tracking/service/query/parser/validate.go index 011589a..d2292c0 100644 --- 
a/pkg/tracking/service/query/parser/validate.go +++ b/pkg/tracking/service/query/parser/validate.go @@ -1,329 +1,329 @@ -package parser - -import ( - "errors" - "fmt" - "strings" - - "github.com/mlflow/mlflow-go/pkg/contract" - "github.com/mlflow/mlflow-go/pkg/protos" -) - -/* - -This is the equivalent of type-checking the untyped tree. -Not every parsed tree is a valid one. - -Grammar rule: identifier.key operator value - -The rules are: - -For identifiers: - -identifier.key - -Or if only key is passed, the identifier is "attribute" - -Identifiers can have aliases. - -if the identifier is dataset, the allowed keys are: name, digest and context. - -*/ - -type ValidIdentifier int - -const ( - Metric ValidIdentifier = iota - Parameter - Tag - Attribute - Dataset -) - -func (v ValidIdentifier) String() string { - switch v { - case Metric: - return "metric" - case Parameter: - return "parameter" - case Tag: - return "tag" - case Attribute: - return "attribute" - case Dataset: - return "dataset" - default: - return "unknown" - } -} - -type ValidCompareExpr struct { - Identifier ValidIdentifier - Key string - Operator OperatorKind - Value interface{} -} - -func (v ValidCompareExpr) String() string { - return fmt.Sprintf("%s.%s %s %v", v.Identifier, v.Key, v.Operator, v.Value) -} - -type ValidationError struct { - message string -} - -func (e *ValidationError) Error() string { - return e.message -} - -func NewValidationError(format string, a ...interface{}) *ValidationError { - return &ValidationError{message: fmt.Sprintf(format, a...)} -} - -const ( - metricIdentifier = "metric" - parameterIdentifier = "parameter" - tagIdentifier = "tag" - attributeIdentifier = "attribute" - datasetIdentifier = "dataset" -) - -var identifiers = []string{ - metricIdentifier, - parameterIdentifier, - tagIdentifier, - attributeIdentifier, - datasetIdentifier, -} - -func parseValidIdentifier(identifier string) (ValidIdentifier, error) { - switch identifier { - case metricIdentifier, 
"metrics": - return Metric, nil - case parameterIdentifier, "parameters", "param", "params": - return Parameter, nil - case tagIdentifier, "tags": - return Tag, nil - case "", attributeIdentifier, "attr", "attributes", "run": - return Attribute, nil - case datasetIdentifier, "datasets": - return Dataset, nil - default: - return -1, NewValidationError("invalid identifier %q", identifier) - } -} - -const ( - RunID = "run_id" - RunName = "run_name" - Created = "created" - StartTime = "start_time" -) - -// This should be configurable and only applies to the runs table. -var searchableRunAttributes = []string{ - RunID, - RunName, - "user_id", - "status", - StartTime, - "end_time", - "artifact_uri", -} - -var datasetAttributes = []string{"name", "digest", "context"} - -func parseAttributeKey(key string) (string, error) { - switch key { - case "run_id": - // We return run_uuid before that is the SQL column name. - return "run_uuid", nil - case - "user_id", - "status", - StartTime, - "end_time", - "artifact_uri": - return key, nil - case Created, "Created": - return StartTime, nil - case RunName, "run name", "Run name", "Run Name": - return RunName, nil - default: - return "", contract.NewError(protos.ErrorCode_BAD_REQUEST, - fmt.Sprintf( - "Invalid attribute key '{%s}' specified. Valid keys are '%v'", - key, - searchableRunAttributes, - ), - ) - } -} - -func parseKey(identifier ValidIdentifier, key string) (string, error) { - if key == "" { - return attributeIdentifier, nil - } - - //nolint:exhaustive - switch identifier { - case Attribute: - return parseAttributeKey(key) - case Dataset: - switch key { - case "name", "digest", "context": - return key, nil - default: - return "", contract.NewError(protos.ErrorCode_BAD_REQUEST, - fmt.Sprintf( - "Invalid dataset key '{%s}' specified. Valid keys are '%v'", - key, - searchableRunAttributes, - ), - ) - } - default: - return key, nil - } -} - -// Returns a standardized LongIdentifierExpr. 
-func validatedIdentifier(identifier *Identifier) (ValidIdentifier, string, error) { - validIdentifier, err := parseValidIdentifier(identifier.Identifier) - if err != nil { - return -1, "", err - } - - validKey, err := parseKey(validIdentifier, identifier.Key) - if err != nil { - return -1, "", err - } - - identifier.Key = validKey - - return validIdentifier, validKey, nil -} - -/* - -The value part is determined by the identifier - -"metric" takes numbers -"parameter" and "tag" takes strings - -"attribute" could be either string or number, -number when StartTime, "end_time" or "created", "Created" -otherwise string - -"dataset" takes strings for "name", "digest" and "context" - -*/ - -func validateDatasetValue(key string, value Value) (interface{}, error) { - switch key { - case "name", "digest", "context": - if _, ok := value.(NumberExpr); ok { - return nil, NewValidationError( - "expected datasets.%s to be either a string or list of strings. Found %s", - key, - value, - ) - } - - return value.value(), nil - default: - return nil, NewValidationError( - "expected dataset attribute key to be one of %s. Found %s", - strings.Join(datasetAttributes, ", "), - key, - ) - } -} - -// Port of _get_value in search_utils.py. -func validateValue(identifier ValidIdentifier, key string, value Value) (interface{}, error) { - switch identifier { - case Metric: - if _, ok := value.(NumberExpr); !ok { - return nil, NewValidationError( - "expected numeric value type for metric. Found %s", - value, - ) - } - - return value.value(), nil - case Parameter, Tag: - if _, ok := value.(StringExpr); !ok { - return nil, NewValidationError( - "expected a quoted string value for %s. Found %s", - identifier, value, - ) - } - - return value.value(), nil - case Attribute: - value, err := validateAttributeValue(key, value) - - return value, err - case Dataset: - return validateDatasetValue(key, value) - default: - return nil, NewValidationError( - "Invalid identifier type %s. 
Expected one of %s", - identifier, - strings.Join(identifiers, ", "), - ) - } -} - -func validateAttributeValue(key string, value Value) (interface{}, error) { - switch key { - case StartTime, "end_time", Created: - if _, ok := value.(NumberExpr); !ok { - return nil, NewValidationError( - "expected numeric value type for numeric attribute: %s. Found %s", - key, - value, - ) - } - - return value.value(), nil - default: - // run_id was earlier converted to run_uuid - if _, ok := value.(StringListExpr); key != "run_uuid" && ok { - return nil, NewValidationError( - "only the 'run_id' attribute supports comparison with a list of quoted string values", - ) - } - - return value.value(), nil - } -} - -// Validate an expression according to the mlflow domain. -// This represent is a simple type-checker for the expression. -// Not every identifier is valid according to the mlflow domain. -// The same for the value part. -func ValidateExpression(expression *CompareExpr) (*ValidCompareExpr, error) { - validIdentifier, validKey, err := validatedIdentifier(&expression.Left) - if err != nil { - var contractError *contract.Error - if errors.As(err, &contractError) { - return nil, contractError - } - - return nil, fmt.Errorf("Error on parsing filter expression: %w", err) - } - - value, err := validateValue(validIdentifier, validKey, expression.Right) - if err != nil { - return nil, fmt.Errorf("Error on parsing filter expression: %w", err) - } - - return &ValidCompareExpr{ - Identifier: validIdentifier, - Key: validKey, - Operator: expression.Operator, - Value: value, - }, nil -} +package parser + +import ( + "errors" + "fmt" + "strings" + + "github.com/mlflow/mlflow-go/pkg/contract" + "github.com/mlflow/mlflow-go/pkg/protos" +) + +/* + +This is the equivalent of type-checking the untyped tree. +Not every parsed tree is a valid one. 
+ +Grammar rule: identifier.key operator value + +The rules are: + +For identifiers: + +identifier.key + +Or if only key is passed, the identifier is "attribute" + +Identifiers can have aliases. + +if the identifier is dataset, the allowed keys are: name, digest and context. + +*/ + +type ValidIdentifier int + +const ( + Metric ValidIdentifier = iota + Parameter + Tag + Attribute + Dataset +) + +func (v ValidIdentifier) String() string { + switch v { + case Metric: + return "metric" + case Parameter: + return "parameter" + case Tag: + return "tag" + case Attribute: + return "attribute" + case Dataset: + return "dataset" + default: + return "unknown" + } +} + +type ValidCompareExpr struct { + Identifier ValidIdentifier + Key string + Operator OperatorKind + Value interface{} +} + +func (v ValidCompareExpr) String() string { + return fmt.Sprintf("%s.%s %s %v", v.Identifier, v.Key, v.Operator, v.Value) +} + +type ValidationError struct { + message string +} + +func (e *ValidationError) Error() string { + return e.message +} + +func NewValidationError(format string, a ...interface{}) *ValidationError { + return &ValidationError{message: fmt.Sprintf(format, a...)} +} + +const ( + metricIdentifier = "metric" + parameterIdentifier = "parameter" + tagIdentifier = "tag" + attributeIdentifier = "attribute" + datasetIdentifier = "dataset" +) + +var identifiers = []string{ + metricIdentifier, + parameterIdentifier, + tagIdentifier, + attributeIdentifier, + datasetIdentifier, +} + +func parseValidIdentifier(identifier string) (ValidIdentifier, error) { + switch identifier { + case metricIdentifier, "metrics": + return Metric, nil + case parameterIdentifier, "parameters", "param", "params": + return Parameter, nil + case tagIdentifier, "tags": + return Tag, nil + case "", attributeIdentifier, "attr", "attributes", "run": + return Attribute, nil + case datasetIdentifier, "datasets": + return Dataset, nil + default: + return -1, NewValidationError("invalid identifier %q", 
identifier) + } +} + +const ( + RunID = "run_id" + RunName = "run_name" + Created = "created" + StartTime = "start_time" +) + +// This should be configurable and only applies to the runs table. +var searchableRunAttributes = []string{ + RunID, + RunName, + "user_id", + "status", + StartTime, + "end_time", + "artifact_uri", +} + +var datasetAttributes = []string{"name", "digest", "context"} + +func parseAttributeKey(key string) (string, error) { + switch key { + case "run_id": + // We return run_uuid before that is the SQL column name. + return "run_uuid", nil + case + "user_id", + "status", + StartTime, + "end_time", + "artifact_uri": + return key, nil + case Created, "Created": + return StartTime, nil + case RunName, "run name", "Run name", "Run Name": + return RunName, nil + default: + return "", contract.NewError(protos.ErrorCode_BAD_REQUEST, + fmt.Sprintf( + "Invalid attribute key '{%s}' specified. Valid keys are '%v'", + key, + searchableRunAttributes, + ), + ) + } +} + +func parseKey(identifier ValidIdentifier, key string) (string, error) { + if key == "" { + return attributeIdentifier, nil + } + + //nolint:exhaustive + switch identifier { + case Attribute: + return parseAttributeKey(key) + case Dataset: + switch key { + case "name", "digest", "context": + return key, nil + default: + return "", contract.NewError(protos.ErrorCode_BAD_REQUEST, + fmt.Sprintf( + "Invalid dataset key '{%s}' specified. Valid keys are '%v'", + key, + searchableRunAttributes, + ), + ) + } + default: + return key, nil + } +} + +// Returns a standardized LongIdentifierExpr. 
+func validatedIdentifier(identifier *Identifier) (ValidIdentifier, string, error) { + validIdentifier, err := parseValidIdentifier(identifier.Identifier) + if err != nil { + return -1, "", err + } + + validKey, err := parseKey(validIdentifier, identifier.Key) + if err != nil { + return -1, "", err + } + + identifier.Key = validKey + + return validIdentifier, validKey, nil +} + +/* + +The value part is determined by the identifier + +"metric" takes numbers +"parameter" and "tag" takes strings + +"attribute" could be either string or number, +number when StartTime, "end_time" or "created", "Created" +otherwise string + +"dataset" takes strings for "name", "digest" and "context" + +*/ + +func validateDatasetValue(key string, value Value) (interface{}, error) { + switch key { + case "name", "digest", "context": + if _, ok := value.(NumberExpr); ok { + return nil, NewValidationError( + "expected datasets.%s to be either a string or list of strings. Found %s", + key, + value, + ) + } + + return value.value(), nil + default: + return nil, NewValidationError( + "expected dataset attribute key to be one of %s. Found %s", + strings.Join(datasetAttributes, ", "), + key, + ) + } +} + +// Port of _get_value in search_utils.py. +func validateValue(identifier ValidIdentifier, key string, value Value) (interface{}, error) { + switch identifier { + case Metric: + if _, ok := value.(NumberExpr); !ok { + return nil, NewValidationError( + "expected numeric value type for metric. Found %s", + value, + ) + } + + return value.value(), nil + case Parameter, Tag: + if _, ok := value.(StringExpr); !ok { + return nil, NewValidationError( + "expected a quoted string value for %s. Found %s", + identifier, value, + ) + } + + return value.value(), nil + case Attribute: + value, err := validateAttributeValue(key, value) + + return value, err + case Dataset: + return validateDatasetValue(key, value) + default: + return nil, NewValidationError( + "Invalid identifier type %s. 
Expected one of %s", + identifier, + strings.Join(identifiers, ", "), + ) + } +} + +func validateAttributeValue(key string, value Value) (interface{}, error) { + switch key { + case StartTime, "end_time", Created: + if _, ok := value.(NumberExpr); !ok { + return nil, NewValidationError( + "expected numeric value type for numeric attribute: %s. Found %s", + key, + value, + ) + } + + return value.value(), nil + default: + // run_id was earlier converted to run_uuid + if _, ok := value.(StringListExpr); key != "run_uuid" && ok { + return nil, NewValidationError( + "only the 'run_id' attribute supports comparison with a list of quoted string values", + ) + } + + return value.value(), nil + } +} + +// Validate an expression according to the mlflow domain. +// This represent is a simple type-checker for the expression. +// Not every identifier is valid according to the mlflow domain. +// The same for the value part. +func ValidateExpression(expression *CompareExpr) (*ValidCompareExpr, error) { + validIdentifier, validKey, err := validatedIdentifier(&expression.Left) + if err != nil { + var contractError *contract.Error + if errors.As(err, &contractError) { + return nil, contractError + } + + return nil, fmt.Errorf("Error on parsing filter expression: %w", err) + } + + value, err := validateValue(validIdentifier, validKey, expression.Right) + if err != nil { + return nil, fmt.Errorf("Error on parsing filter expression: %w", err) + } + + return &ValidCompareExpr{ + Identifier: validIdentifier, + Key: validKey, + Operator: expression.Operator, + Value: value, + }, nil +} diff --git a/pkg/tracking/service/query/query.go b/pkg/tracking/service/query/query.go index eac82cc..9f6cb30 100644 --- a/pkg/tracking/service/query/query.go +++ b/pkg/tracking/service/query/query.go @@ -1,37 +1,37 @@ -package query - -import ( - "fmt" - - "github.com/mlflow/mlflow-go/pkg/tracking/service/query/lexer" - "github.com/mlflow/mlflow-go/pkg/tracking/service/query/parser" -) - -func 
ParseFilter(input string) ([]*parser.ValidCompareExpr, error) { - if input == "" { - return make([]*parser.ValidCompareExpr, 0), nil - } - - tokens, err := lexer.Tokenize(&input) - if err != nil { - return nil, fmt.Errorf("error while lexing %s: %w", input, err) - } - - ast, err := parser.Parse(tokens) - if err != nil { - return nil, fmt.Errorf("error while parsing %s: %w", input, err) - } - - validExpressions := make([]*parser.ValidCompareExpr, 0, len(ast.Exprs)) - - for _, expr := range ast.Exprs { - ve, err := parser.ValidateExpression(expr) - if err != nil { - return nil, fmt.Errorf("error while validating %s: %w", input, err) - } - - validExpressions = append(validExpressions, ve) - } - - return validExpressions, nil -} +package query + +import ( + "fmt" + + "github.com/mlflow/mlflow-go/pkg/tracking/service/query/lexer" + "github.com/mlflow/mlflow-go/pkg/tracking/service/query/parser" +) + +func ParseFilter(input string) ([]*parser.ValidCompareExpr, error) { + if input == "" { + return make([]*parser.ValidCompareExpr, 0), nil + } + + tokens, err := lexer.Tokenize(&input) + if err != nil { + return nil, fmt.Errorf("error while lexing %s: %w", input, err) + } + + ast, err := parser.Parse(tokens) + if err != nil { + return nil, fmt.Errorf("error while parsing %s: %w", input, err) + } + + validExpressions := make([]*parser.ValidCompareExpr, 0, len(ast.Exprs)) + + for _, expr := range ast.Exprs { + ve, err := parser.ValidateExpression(expr) + if err != nil { + return nil, fmt.Errorf("error while validating %s: %w", input, err) + } + + validExpressions = append(validExpressions, ve) + } + + return validExpressions, nil +} diff --git a/pkg/tracking/service/query/query_test.go b/pkg/tracking/service/query/query_test.go index 95b2277..50192cb 100644 --- a/pkg/tracking/service/query/query_test.go +++ b/pkg/tracking/service/query/query_test.go @@ -1,112 +1,112 @@ -package query_test - -import ( - "strings" - "testing" - - 
"github.com/mlflow/mlflow-go/pkg/tracking/service/query" -) - -func TestValidQueries(t *testing.T) { - t.Parallel() - - samples := []string{ - "metrics.foobar = 40", - "metrics.foobar = 40 AND run_name = \"bouncy-boar-498\"", - "tags.\"mlflow.source.name\" = \"scratch.py\"", - "metrics.accuracy > 0.9", - "params.\"random_state\" = \"8888\"", - "params.`random_state` = \"8888\"", - "params.solver ILIKE \"L%\"", - "params.solver LIKE \"l%\"", - "datasets.digest IN ('77a19fc0')", - "attributes.run_id IN ('meh')", - } - - for _, sample := range samples { - currentSample := sample - t.Run(currentSample, func(t *testing.T) { - t.Parallel() - - _, err := query.ParseFilter(currentSample) - if err != nil { - t.Errorf("unexpected parse error: %v", err) - } - }) - } -} - -type invalidSample struct { - input string - expectedError string -} - -//nolint:funlen -func TestInvalidQueries(t *testing.T) { - t.Parallel() - - samples := []invalidSample{ - { - input: "yow.foobar = 40", - expectedError: "invalid identifier", - }, - { - input: "attributes.foobar = 40", - expectedError: "Invalid attribute key '{foobar}' specified. " + - "Valid keys are '[run_id run_name user_id status start_time end_time artifact_uri]'", - }, - { - input: "datasets.foobar = 40", - expectedError: "Invalid dataset key '{foobar}' specified. 
" + - "Valid keys are '[run_id run_name user_id status start_time end_time artifact_uri]'", - }, - { - input: "metric.yow = 'z'", - expectedError: "expected numeric value type for metric.", - }, - { - input: "parameter.tag = 2", - expectedError: "expected a quoted string value", - }, - { - input: "attributes.start_time = 'now'", - expectedError: "expected numeric value type for numeric attribute", - }, - { - input: "attributes.run_name IN ('foo','bar')", - expectedError: "only the 'run_id' attribute supports comparison with a list", - }, - { - input: "datasets.name = 40", - expectedError: "expected datasets.name to be either a string or list of strings", - }, - { - input: "datasets.digest = 50", - expectedError: "expected datasets.digest to be either a string or list of strings", - }, - { - input: "datasets.context = 60", - expectedError: "expected datasets.context to be either a string or list of strings", - }, - } - - for _, sample := range samples { - currentSample := sample - t.Run(currentSample.input, func(t *testing.T) { - t.Parallel() - - _, err := query.ParseFilter(currentSample.input) - if err == nil { - t.Errorf("expected parse error but got nil") - } - - if !strings.Contains(err.Error(), currentSample.expectedError) { - t.Errorf( - "expected error to contain %q, got %q", - currentSample.expectedError, - err.Error(), - ) - } - }) - } -} +package query_test + +import ( + "strings" + "testing" + + "github.com/mlflow/mlflow-go/pkg/tracking/service/query" +) + +func TestValidQueries(t *testing.T) { + t.Parallel() + + samples := []string{ + "metrics.foobar = 40", + "metrics.foobar = 40 AND run_name = \"bouncy-boar-498\"", + "tags.\"mlflow.source.name\" = \"scratch.py\"", + "metrics.accuracy > 0.9", + "params.\"random_state\" = \"8888\"", + "params.`random_state` = \"8888\"", + "params.solver ILIKE \"L%\"", + "params.solver LIKE \"l%\"", + "datasets.digest IN ('77a19fc0')", + "attributes.run_id IN ('meh')", + } + + for _, sample := range samples { + 
currentSample := sample + t.Run(currentSample, func(t *testing.T) { + t.Parallel() + + _, err := query.ParseFilter(currentSample) + if err != nil { + t.Errorf("unexpected parse error: %v", err) + } + }) + } +} + +type invalidSample struct { + input string + expectedError string +} + +//nolint:funlen +func TestInvalidQueries(t *testing.T) { + t.Parallel() + + samples := []invalidSample{ + { + input: "yow.foobar = 40", + expectedError: "invalid identifier", + }, + { + input: "attributes.foobar = 40", + expectedError: "Invalid attribute key '{foobar}' specified. " + + "Valid keys are '[run_id run_name user_id status start_time end_time artifact_uri]'", + }, + { + input: "datasets.foobar = 40", + expectedError: "Invalid dataset key '{foobar}' specified. " + + "Valid keys are '[run_id run_name user_id status start_time end_time artifact_uri]'", + }, + { + input: "metric.yow = 'z'", + expectedError: "expected numeric value type for metric.", + }, + { + input: "parameter.tag = 2", + expectedError: "expected a quoted string value", + }, + { + input: "attributes.start_time = 'now'", + expectedError: "expected numeric value type for numeric attribute", + }, + { + input: "attributes.run_name IN ('foo','bar')", + expectedError: "only the 'run_id' attribute supports comparison with a list", + }, + { + input: "datasets.name = 40", + expectedError: "expected datasets.name to be either a string or list of strings", + }, + { + input: "datasets.digest = 50", + expectedError: "expected datasets.digest to be either a string or list of strings", + }, + { + input: "datasets.context = 60", + expectedError: "expected datasets.context to be either a string or list of strings", + }, + } + + for _, sample := range samples { + currentSample := sample + t.Run(currentSample.input, func(t *testing.T) { + t.Parallel() + + _, err := query.ParseFilter(currentSample.input) + if err == nil { + t.Errorf("expected parse error but got nil") + } + + if !strings.Contains(err.Error(), 
currentSample.expectedError) { + t.Errorf( + "expected error to contain %q, got %q", + currentSample.expectedError, + err.Error(), + ) + } + }) + } +} diff --git a/pkg/tracking/service/runs.go b/pkg/tracking/service/runs.go index cf20e2b..15835c5 100644 --- a/pkg/tracking/service/runs.go +++ b/pkg/tracking/service/runs.go @@ -1,168 +1,168 @@ -package service - -import ( - "context" - "fmt" - - "github.com/mlflow/mlflow-go/pkg/contract" - "github.com/mlflow/mlflow-go/pkg/entities" - "github.com/mlflow/mlflow-go/pkg/protos" - "github.com/mlflow/mlflow-go/pkg/tracking/store/sql/models" -) - -func (ts TrackingService) SearchRuns( - ctx context.Context, input *protos.SearchRuns, -) (*protos.SearchRuns_Response, *contract.Error) { - var runViewType protos.ViewType - if input.RunViewType == nil { - runViewType = protos.ViewType_ALL - } else { - runViewType = input.GetRunViewType() - } - - maxResults := int(input.GetMaxResults()) - - runs, nextPageToken, err := ts.Store.SearchRuns( - ctx, - input.GetExperimentIds(), - input.GetFilter(), - runViewType, - maxResults, - input.GetOrderBy(), - input.GetPageToken(), - ) - if err != nil { - return nil, contract.NewError(protos.ErrorCode_INTERNAL_ERROR, fmt.Sprintf("error getting runs: %v", err)) - } - - response := protos.SearchRuns_Response{ - Runs: make([]*protos.Run, len(runs)), - NextPageToken: &nextPageToken, - } - - for i, run := range runs { - response.Runs[i] = run.ToProto() - } - - return &response, nil -} - -func (ts TrackingService) LogBatch( - ctx context.Context, input *protos.LogBatch, -) (*protos.LogBatch_Response, *contract.Error) { - metrics := make([]*entities.Metric, len(input.GetMetrics())) - for i, metric := range input.GetMetrics() { - metrics[i] = entities.MetricFromProto(metric) - } - - params := make([]*entities.Param, len(input.GetParams())) - for i, param := range input.GetParams() { - params[i] = entities.ParamFromProto(param) - } - - tags := make([]*entities.RunTag, len(input.GetTags())) - for i, tag 
:= range input.GetTags() { - tags[i] = entities.NewTagFromProto(tag) - } - - err := ts.Store.LogBatch(ctx, input.GetRunId(), metrics, params, tags) - if err != nil { - return nil, err - } - - return &protos.LogBatch_Response{}, nil -} - -func (ts TrackingService) GetRun( - ctx context.Context, input *protos.GetRun, -) (*protos.GetRun_Response, *contract.Error) { - run, err := ts.Store.GetRun(ctx, input.GetRunId()) - if err != nil { - return nil, err - } - - return &protos.GetRun_Response{Run: run.ToProto()}, nil -} - -func (ts TrackingService) CreateRun( - ctx context.Context, input *protos.CreateRun, -) (*protos.CreateRun_Response, *contract.Error) { - tags := make([]*entities.RunTag, 0, len(input.GetTags())) - for _, tag := range input.GetTags() { - tags = append(tags, entities.NewTagFromProto(tag)) - } - - run, err := ts.Store.CreateRun( - ctx, - input.GetExperimentId(), - input.GetUserId(), - input.GetStartTime(), - tags, - input.GetRunName(), - ) - if err != nil { - return nil, err - } - - return &protos.CreateRun_Response{Run: run.ToProto()}, nil -} - -func (ts TrackingService) UpdateRun( - ctx context.Context, input *protos.UpdateRun, -) (*protos.UpdateRun_Response, *contract.Error) { - run, err := ts.Store.GetRun(ctx, input.GetRunId()) - if err != nil { - return nil, err - } - - if run.Info.LifecycleStage != string(models.LifecycleStageActive) { - return nil, contract.NewError( - protos.ErrorCode_INVALID_STATE, - fmt.Sprintf( - "The run %s must be in the 'active' state. 
Current state is %s.", - input.GetRunUuid(), - run.Info.LifecycleStage, - ), - ) - } - - if status := input.GetStatus(); status != 0 { - run.Info.Status = status.String() - } - - if runName := input.GetRunName(); runName != "" { - run.Info.RunName = runName - } - - if err := ts.Store.UpdateRun( - ctx, - run.Info.RunID, - run.Info.Status, - input.EndTime, - run.Info.RunName, - ); err != nil { - return nil, err - } - - return &protos.UpdateRun_Response{RunInfo: run.Info.ToProto()}, nil -} - -func (ts TrackingService) DeleteRun( - ctx context.Context, input *protos.DeleteRun, -) (*protos.DeleteRun_Response, *contract.Error) { - if err := ts.Store.DeleteRun(ctx, input.GetRunId()); err != nil { - return nil, err - } - - return &protos.DeleteRun_Response{}, nil -} - -func (ts TrackingService) RestoreRun( - ctx context.Context, input *protos.RestoreRun, -) (*protos.RestoreRun_Response, *contract.Error) { - if err := ts.Store.RestoreRun(ctx, input.GetRunId()); err != nil { - return nil, err - } - - return &protos.RestoreRun_Response{}, nil -} +package service + +import ( + "context" + "fmt" + + "github.com/mlflow/mlflow-go/pkg/contract" + "github.com/mlflow/mlflow-go/pkg/entities" + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/tracking/store/sql/models" +) + +func (ts TrackingService) SearchRuns( + ctx context.Context, input *protos.SearchRuns, +) (*protos.SearchRuns_Response, *contract.Error) { + var runViewType protos.ViewType + if input.RunViewType == nil { + runViewType = protos.ViewType_ALL + } else { + runViewType = input.GetRunViewType() + } + + maxResults := int(input.GetMaxResults()) + + runs, nextPageToken, err := ts.Store.SearchRuns( + ctx, + input.GetExperimentIds(), + input.GetFilter(), + runViewType, + maxResults, + input.GetOrderBy(), + input.GetPageToken(), + ) + if err != nil { + return nil, contract.NewError(protos.ErrorCode_INTERNAL_ERROR, fmt.Sprintf("error getting runs: %v", err)) + } + + response := 
protos.SearchRuns_Response{ + Runs: make([]*protos.Run, len(runs)), + NextPageToken: &nextPageToken, + } + + for i, run := range runs { + response.Runs[i] = run.ToProto() + } + + return &response, nil +} + +func (ts TrackingService) LogBatch( + ctx context.Context, input *protos.LogBatch, +) (*protos.LogBatch_Response, *contract.Error) { + metrics := make([]*entities.Metric, len(input.GetMetrics())) + for i, metric := range input.GetMetrics() { + metrics[i] = entities.MetricFromProto(metric) + } + + params := make([]*entities.Param, len(input.GetParams())) + for i, param := range input.GetParams() { + params[i] = entities.ParamFromProto(param) + } + + tags := make([]*entities.RunTag, len(input.GetTags())) + for i, tag := range input.GetTags() { + tags[i] = entities.NewTagFromProto(tag) + } + + err := ts.Store.LogBatch(ctx, input.GetRunId(), metrics, params, tags) + if err != nil { + return nil, err + } + + return &protos.LogBatch_Response{}, nil +} + +func (ts TrackingService) GetRun( + ctx context.Context, input *protos.GetRun, +) (*protos.GetRun_Response, *contract.Error) { + run, err := ts.Store.GetRun(ctx, input.GetRunId()) + if err != nil { + return nil, err + } + + return &protos.GetRun_Response{Run: run.ToProto()}, nil +} + +func (ts TrackingService) CreateRun( + ctx context.Context, input *protos.CreateRun, +) (*protos.CreateRun_Response, *contract.Error) { + tags := make([]*entities.RunTag, 0, len(input.GetTags())) + for _, tag := range input.GetTags() { + tags = append(tags, entities.NewTagFromProto(tag)) + } + + run, err := ts.Store.CreateRun( + ctx, + input.GetExperimentId(), + input.GetUserId(), + input.GetStartTime(), + tags, + input.GetRunName(), + ) + if err != nil { + return nil, err + } + + return &protos.CreateRun_Response{Run: run.ToProto()}, nil +} + +func (ts TrackingService) UpdateRun( + ctx context.Context, input *protos.UpdateRun, +) (*protos.UpdateRun_Response, *contract.Error) { + run, err := ts.Store.GetRun(ctx, input.GetRunId()) + if 
err != nil { + return nil, err + } + + if run.Info.LifecycleStage != string(models.LifecycleStageActive) { + return nil, contract.NewError( + protos.ErrorCode_INVALID_STATE, + fmt.Sprintf( + "The run %s must be in the 'active' state. Current state is %s.", + input.GetRunUuid(), + run.Info.LifecycleStage, + ), + ) + } + + if status := input.GetStatus(); status != 0 { + run.Info.Status = status.String() + } + + if runName := input.GetRunName(); runName != "" { + run.Info.RunName = runName + } + + if err := ts.Store.UpdateRun( + ctx, + run.Info.RunID, + run.Info.Status, + input.EndTime, + run.Info.RunName, + ); err != nil { + return nil, err + } + + return &protos.UpdateRun_Response{RunInfo: run.Info.ToProto()}, nil +} + +func (ts TrackingService) DeleteRun( + ctx context.Context, input *protos.DeleteRun, +) (*protos.DeleteRun_Response, *contract.Error) { + if err := ts.Store.DeleteRun(ctx, input.GetRunId()); err != nil { + return nil, err + } + + return &protos.DeleteRun_Response{}, nil +} + +func (ts TrackingService) RestoreRun( + ctx context.Context, input *protos.RestoreRun, +) (*protos.RestoreRun_Response, *contract.Error) { + if err := ts.Store.RestoreRun(ctx, input.GetRunId()); err != nil { + return nil, err + } + + return &protos.RestoreRun_Response{}, nil +} diff --git a/pkg/tracking/service/service.go b/pkg/tracking/service/service.go index 8218627..074b854 100644 --- a/pkg/tracking/service/service.go +++ b/pkg/tracking/service/service.go @@ -1,27 +1,27 @@ -package service - -import ( - "context" - "fmt" - - "github.com/mlflow/mlflow-go/pkg/config" - "github.com/mlflow/mlflow-go/pkg/tracking/store" - "github.com/mlflow/mlflow-go/pkg/tracking/store/sql" -) - -type TrackingService struct { - config *config.Config - Store store.TrackingStore -} - -func NewTrackingService(ctx context.Context, config *config.Config) (*TrackingService, error) { - store, err := sql.NewTrackingSQLStore(ctx, config) - if err != nil { - return nil, fmt.Errorf("failed to create new sql 
store: %w", err) - } - - return &TrackingService{ - config: config, - Store: store, - }, nil -} +package service + +import ( + "context" + "fmt" + + "github.com/mlflow/mlflow-go/pkg/config" + "github.com/mlflow/mlflow-go/pkg/tracking/store" + "github.com/mlflow/mlflow-go/pkg/tracking/store/sql" +) + +type TrackingService struct { + config *config.Config + Store store.TrackingStore +} + +func NewTrackingService(ctx context.Context, config *config.Config) (*TrackingService, error) { + store, err := sql.NewTrackingSQLStore(ctx, config) + if err != nil { + return nil, fmt.Errorf("failed to create new sql store: %w", err) + } + + return &TrackingService{ + config: config, + Store: store, + }, nil +} diff --git a/pkg/tracking/store/sql/experiments.go b/pkg/tracking/store/sql/experiments.go index 1c65e4a..7c5eba6 100644 --- a/pkg/tracking/store/sql/experiments.go +++ b/pkg/tracking/store/sql/experiments.go @@ -1,254 +1,254 @@ -package sql - -import ( - "context" - "database/sql" - "errors" - "fmt" - "strconv" - "time" - - "gorm.io/gorm" - - "github.com/mlflow/mlflow-go/pkg/contract" - "github.com/mlflow/mlflow-go/pkg/entities" - "github.com/mlflow/mlflow-go/pkg/protos" - "github.com/mlflow/mlflow-go/pkg/tracking/store/sql/models" - "github.com/mlflow/mlflow-go/pkg/utils" -) - -func (s TrackingSQLStore) GetExperiment(ctx context.Context, id string) (*entities.Experiment, *contract.Error) { - idInt, err := strconv.ParseInt(id, 10, 32) - if err != nil { - return nil, contract.NewErrorWith( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - fmt.Sprintf("failed to convert experiment id %q to int", id), - err, - ) - } - - experiment := models.Experiment{ID: int32(idInt)} - if err := s.db.WithContext(ctx).Preload("Tags").First(&experiment).Error; err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, contract.NewError( - protos.ErrorCode_RESOURCE_DOES_NOT_EXIST, - fmt.Sprintf("No Experiment with id=%d exists", idInt), - ) - } - - return nil, contract.NewErrorWith( - 
protos.ErrorCode_INTERNAL_ERROR, - "failed to get experiment", - err, - ) - } - - return experiment.ToEntity(), nil -} - -func (s TrackingSQLStore) CreateExperiment( - ctx context.Context, - name string, - artifactLocation string, - tags []*entities.ExperimentTag, -) (string, *contract.Error) { - experiment := models.Experiment{ - Name: name, - Tags: make([]models.ExperimentTag, len(tags)), - ArtifactLocation: artifactLocation, - LifecycleStage: models.LifecycleStageActive, - CreationTime: time.Now().UnixMilli(), - LastUpdateTime: time.Now().UnixMilli(), - } - - for i, tag := range tags { - experiment.Tags[i] = models.ExperimentTag{ - Key: tag.Key, - Value: tag.Value, - } - } - - if err := s.db.WithContext(ctx).Transaction(func(transaction *gorm.DB) error { - if err := transaction.Create(&experiment).Error; err != nil { - return fmt.Errorf("failed to insert experiment: %w", err) - } - - if experiment.ArtifactLocation == "" { - artifactLocation, err := utils.AppendToURIPath(s.config.DefaultArtifactRoot, strconv.Itoa(int(experiment.ID))) - if err != nil { - return fmt.Errorf("failed to join artifact location: %w", err) - } - experiment.ArtifactLocation = artifactLocation - if err := transaction.Model(&experiment).UpdateColumn("artifact_location", artifactLocation).Error; err != nil { - return fmt.Errorf("failed to update experiment artifact location: %w", err) - } - } - - return nil - }); err != nil { - if errors.Is(err, gorm.ErrDuplicatedKey) { - return "", contract.NewError( - protos.ErrorCode_RESOURCE_ALREADY_EXISTS, - fmt.Sprintf("Experiment(name=%s) already exists.", experiment.Name), - ) - } - - return "", contract.NewErrorWith(protos.ErrorCode_INTERNAL_ERROR, "failed to create experiment", err) - } - - return strconv.Itoa(int(experiment.ID)), nil -} - -func (s TrackingSQLStore) RenameExperiment( - ctx context.Context, experimentID, name string, -) *contract.Error { - if err := s.db.WithContext(ctx).Model(&models.Experiment{}). 
- Where("experiment_id = ?", experimentID). - Updates(&models.Experiment{ - Name: name, - LastUpdateTime: time.Now().UnixMilli(), - }).Error; err != nil { - return contract.NewErrorWith(protos.ErrorCode_INTERNAL_ERROR, "failed to update experiment", err) - } - - return nil -} - -func (s TrackingSQLStore) DeleteExperiment(ctx context.Context, id string) *contract.Error { - idInt, err := strconv.ParseInt(id, 10, 32) - if err != nil { - return contract.NewErrorWith( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - fmt.Sprintf("failed to convert experiment id (%s) to int", id), - err, - ) - } - - if err := s.db.WithContext(ctx).Transaction(func(transaction *gorm.DB) error { - // Update experiment - uex := transaction.Model(&models.Experiment{}). - Where("experiment_id = ?", idInt). - Updates(&models.Experiment{ - LifecycleStage: models.LifecycleStageDeleted, - LastUpdateTime: time.Now().UnixMilli(), - }) - - if uex.Error != nil { - return fmt.Errorf("failed to update experiment (%d) during delete: %w", idInt, err) - } - - if uex.RowsAffected != 1 { - return gorm.ErrRecordNotFound - } - - // Update runs - if err := transaction.Model(&models.Run{}). - Where("experiment_id = ?", idInt). 
- Updates(&models.Run{ - LifecycleStage: models.LifecycleStageDeleted, - DeletedTime: sql.NullInt64{Valid: true, Int64: time.Now().UnixMilli()}, - }).Error; err != nil { - return fmt.Errorf("failed to update runs during delete: %w", err) - } - - return nil - }); err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return contract.NewError( - protos.ErrorCode_RESOURCE_DOES_NOT_EXIST, - fmt.Sprintf("No Experiment with id=%d exists", idInt), - ) - } - - return contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - "failed to delete experiment", - err, - ) - } - - return nil -} - -func (s TrackingSQLStore) RestoreExperiment(ctx context.Context, id string) *contract.Error { - idInt, err := strconv.ParseInt(id, 10, 32) - if err != nil { - return contract.NewErrorWith( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - fmt.Sprintf("failed to convert experiment id (%s) to int", id), - err, - ) - } - - if err := s.db.WithContext(ctx).Transaction(func(transaction *gorm.DB) error { - // Update experiment - uex := transaction.Model(&models.Experiment{}). - Where("experiment_id = ?", idInt). - Where("lifecycle_stage = ?", models.LifecycleStageDeleted). - Updates(&models.Experiment{ - LifecycleStage: models.LifecycleStageActive, - LastUpdateTime: time.Now().UnixMilli(), - }) - - if uex.Error != nil { - return fmt.Errorf("failed to update experiment (%d) during delete: %w", idInt, err) - } - - if uex.RowsAffected != 1 { - return gorm.ErrRecordNotFound - } - - // Update runs - if err := transaction.Model(&models.Run{}). - Where("experiment_id = ?", idInt). - Select("DeletedTime", "LifecycleStage"). 
- Updates(&models.Run{ - LifecycleStage: models.LifecycleStageActive, - DeletedTime: sql.NullInt64{}, - }).Error; err != nil { - return fmt.Errorf("failed to update runs during restore: %w", err) - } - - return nil - }); err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return contract.NewError( - protos.ErrorCode_RESOURCE_DOES_NOT_EXIST, - fmt.Sprintf("No Experiment with id=%d exists", idInt), - ) - } - - return contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - "failed to delete experiment", - err, - ) - } - - return nil -} - -//nolint:perfsprint -func (s TrackingSQLStore) GetExperimentByName( - ctx context.Context, name string, -) (*entities.Experiment, *contract.Error) { - var experiment models.Experiment - - err := s.db.WithContext(ctx).Preload("Tags").Where("name = ?", name).First(&experiment).Error - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, contract.NewError( - protos.ErrorCode_RESOURCE_DOES_NOT_EXIST, - fmt.Sprintf("Could not find experiment with name %s", name), - ) - } - - return nil, contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - fmt.Sprintf("failed to get experiment by name %s", name), - err, - ) - } - - return experiment.ToEntity(), nil -} +package sql + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strconv" + "time" + + "gorm.io/gorm" + + "github.com/mlflow/mlflow-go/pkg/contract" + "github.com/mlflow/mlflow-go/pkg/entities" + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/tracking/store/sql/models" + "github.com/mlflow/mlflow-go/pkg/utils" +) + +func (s TrackingSQLStore) GetExperiment(ctx context.Context, id string) (*entities.Experiment, *contract.Error) { + idInt, err := strconv.ParseInt(id, 10, 32) + if err != nil { + return nil, contract.NewErrorWith( + protos.ErrorCode_INVALID_PARAMETER_VALUE, + fmt.Sprintf("failed to convert experiment id %q to int", id), + err, + ) + } + + experiment := models.Experiment{ID: int32(idInt)} + if err := 
s.db.WithContext(ctx).Preload("Tags").First(&experiment).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, contract.NewError( + protos.ErrorCode_RESOURCE_DOES_NOT_EXIST, + fmt.Sprintf("No Experiment with id=%d exists", idInt), + ) + } + + return nil, contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + "failed to get experiment", + err, + ) + } + + return experiment.ToEntity(), nil +} + +func (s TrackingSQLStore) CreateExperiment( + ctx context.Context, + name string, + artifactLocation string, + tags []*entities.ExperimentTag, +) (string, *contract.Error) { + experiment := models.Experiment{ + Name: name, + Tags: make([]models.ExperimentTag, len(tags)), + ArtifactLocation: artifactLocation, + LifecycleStage: models.LifecycleStageActive, + CreationTime: time.Now().UnixMilli(), + LastUpdateTime: time.Now().UnixMilli(), + } + + for i, tag := range tags { + experiment.Tags[i] = models.ExperimentTag{ + Key: tag.Key, + Value: tag.Value, + } + } + + if err := s.db.WithContext(ctx).Transaction(func(transaction *gorm.DB) error { + if err := transaction.Create(&experiment).Error; err != nil { + return fmt.Errorf("failed to insert experiment: %w", err) + } + + if experiment.ArtifactLocation == "" { + artifactLocation, err := utils.AppendToURIPath(s.config.DefaultArtifactRoot, strconv.Itoa(int(experiment.ID))) + if err != nil { + return fmt.Errorf("failed to join artifact location: %w", err) + } + experiment.ArtifactLocation = artifactLocation + if err := transaction.Model(&experiment).UpdateColumn("artifact_location", artifactLocation).Error; err != nil { + return fmt.Errorf("failed to update experiment artifact location: %w", err) + } + } + + return nil + }); err != nil { + if errors.Is(err, gorm.ErrDuplicatedKey) { + return "", contract.NewError( + protos.ErrorCode_RESOURCE_ALREADY_EXISTS, + fmt.Sprintf("Experiment(name=%s) already exists.", experiment.Name), + ) + } + + return "", contract.NewErrorWith(protos.ErrorCode_INTERNAL_ERROR, 
"failed to create experiment", err) + } + + return strconv.Itoa(int(experiment.ID)), nil +} + +func (s TrackingSQLStore) RenameExperiment( + ctx context.Context, experimentID, name string, +) *contract.Error { + if err := s.db.WithContext(ctx).Model(&models.Experiment{}). + Where("experiment_id = ?", experimentID). + Updates(&models.Experiment{ + Name: name, + LastUpdateTime: time.Now().UnixMilli(), + }).Error; err != nil { + return contract.NewErrorWith(protos.ErrorCode_INTERNAL_ERROR, "failed to update experiment", err) + } + + return nil +} + +func (s TrackingSQLStore) DeleteExperiment(ctx context.Context, id string) *contract.Error { + idInt, err := strconv.ParseInt(id, 10, 32) + if err != nil { + return contract.NewErrorWith( + protos.ErrorCode_INVALID_PARAMETER_VALUE, + fmt.Sprintf("failed to convert experiment id (%s) to int", id), + err, + ) + } + + if err := s.db.WithContext(ctx).Transaction(func(transaction *gorm.DB) error { + // Update experiment + uex := transaction.Model(&models.Experiment{}). + Where("experiment_id = ?", idInt). + Updates(&models.Experiment{ + LifecycleStage: models.LifecycleStageDeleted, + LastUpdateTime: time.Now().UnixMilli(), + }) + + if uex.Error != nil { + return fmt.Errorf("failed to update experiment (%d) during delete: %w", idInt, err) + } + + if uex.RowsAffected != 1 { + return gorm.ErrRecordNotFound + } + + // Update runs + if err := transaction.Model(&models.Run{}). + Where("experiment_id = ?", idInt). 
+ Updates(&models.Run{ + LifecycleStage: models.LifecycleStageDeleted, + DeletedTime: sql.NullInt64{Valid: true, Int64: time.Now().UnixMilli()}, + }).Error; err != nil { + return fmt.Errorf("failed to update runs during delete: %w", err) + } + + return nil + }); err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return contract.NewError( + protos.ErrorCode_RESOURCE_DOES_NOT_EXIST, + fmt.Sprintf("No Experiment with id=%d exists", idInt), + ) + } + + return contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + "failed to delete experiment", + err, + ) + } + + return nil +} + +func (s TrackingSQLStore) RestoreExperiment(ctx context.Context, id string) *contract.Error { + idInt, err := strconv.ParseInt(id, 10, 32) + if err != nil { + return contract.NewErrorWith( + protos.ErrorCode_INVALID_PARAMETER_VALUE, + fmt.Sprintf("failed to convert experiment id (%s) to int", id), + err, + ) + } + + if err := s.db.WithContext(ctx).Transaction(func(transaction *gorm.DB) error { + // Update experiment + uex := transaction.Model(&models.Experiment{}). + Where("experiment_id = ?", idInt). + Where("lifecycle_stage = ?", models.LifecycleStageDeleted). + Updates(&models.Experiment{ + LifecycleStage: models.LifecycleStageActive, + LastUpdateTime: time.Now().UnixMilli(), + }) + + if uex.Error != nil { + return fmt.Errorf("failed to update experiment (%d) during delete: %w", idInt, err) + } + + if uex.RowsAffected != 1 { + return gorm.ErrRecordNotFound + } + + // Update runs + if err := transaction.Model(&models.Run{}). + Where("experiment_id = ?", idInt). + Select("DeletedTime", "LifecycleStage"). 
+ Updates(&models.Run{ + LifecycleStage: models.LifecycleStageActive, + DeletedTime: sql.NullInt64{}, + }).Error; err != nil { + return fmt.Errorf("failed to update runs during restore: %w", err) + } + + return nil + }); err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return contract.NewError( + protos.ErrorCode_RESOURCE_DOES_NOT_EXIST, + fmt.Sprintf("No Experiment with id=%d exists", idInt), + ) + } + + return contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + "failed to delete experiment", + err, + ) + } + + return nil +} + +//nolint:perfsprint +func (s TrackingSQLStore) GetExperimentByName( + ctx context.Context, name string, +) (*entities.Experiment, *contract.Error) { + var experiment models.Experiment + + err := s.db.WithContext(ctx).Preload("Tags").Where("name = ?", name).First(&experiment).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, contract.NewError( + protos.ErrorCode_RESOURCE_DOES_NOT_EXIST, + fmt.Sprintf("Could not find experiment with name %s", name), + ) + } + + return nil, contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf("failed to get experiment by name %s", name), + err, + ) + } + + return experiment.ToEntity(), nil +} diff --git a/pkg/tracking/store/sql/metrics.go b/pkg/tracking/store/sql/metrics.go index df2b60e..fbfdccc 100644 --- a/pkg/tracking/store/sql/metrics.go +++ b/pkg/tracking/store/sql/metrics.go @@ -1,193 +1,193 @@ -package sql - -import ( - "context" - "errors" - "fmt" - "math" - - "gorm.io/gorm" - "gorm.io/gorm/clause" - - "github.com/mlflow/mlflow-go/pkg/contract" - "github.com/mlflow/mlflow-go/pkg/entities" - "github.com/mlflow/mlflow-go/pkg/protos" - "github.com/mlflow/mlflow-go/pkg/tracking/store/sql/models" -) - -const metricsBatchSize = 500 - -func getDistinctMetricKeys(metrics []models.Metric) []string { - metricKeysMap := make(map[string]any) - for _, m := range metrics { - metricKeysMap[m.Key] = nil - } - - metricKeys := make([]string, 0, 
len(metricKeysMap)) - for key := range metricKeysMap { - metricKeys = append(metricKeys, key) - } - - return metricKeys -} - -func getLatestMetrics(transaction *gorm.DB, runID string, metricKeys []string) ([]models.LatestMetric, error) { - const batchSize = 500 - - latestMetrics := make([]models.LatestMetric, 0, len(metricKeys)) - - for skip := 0; skip < len(metricKeys); skip += batchSize { - take := int(math.Max(float64(skip+batchSize), float64(len(metricKeys)))) - if take > len(metricKeys) { - take = len(metricKeys) - } - - currentBatch := make([]models.LatestMetric, 0, take-skip) - keys := metricKeys[skip:take] - - err := transaction. - Model(&models.LatestMetric{}). - Where("run_uuid = ?", runID).Where("key IN ?", keys). - Clauses(clause.Locking{Strength: "UPDATE"}). // https://gorm.io/docs/advanced_query.html#Locking - Order("run_uuid"). - Order("key"). - Find(¤tBatch).Error - if err != nil { - return latestMetrics, fmt.Errorf( - "failed to get latest metrics for run_uuid %q, skip %d, take %d : %w", - runID, skip, take, err, - ) - } - - latestMetrics = append(latestMetrics, currentBatch...) 
- } - - return latestMetrics, nil -} - -func isNewerMetric(a models.Metric, b models.LatestMetric) bool { - return a.Step > b.Step || - (a.Step == b.Step && a.Timestamp > b.Timestamp) || - (a.Step == b.Step && a.Timestamp == b.Timestamp && a.Value > b.Value) -} - -//nolint:cyclop -func updateLatestMetricsIfNecessary(transaction *gorm.DB, runID string, metrics []models.Metric) error { - if len(metrics) == 0 { - return nil - } - - metricKeys := getDistinctMetricKeys(metrics) - - latestMetrics, err := getLatestMetrics(transaction, runID, metricKeys) - if err != nil { - return fmt.Errorf("failed to get latest metrics for run_uuid %q: %w", runID, err) - } - - latestMetricsMap := make(map[string]models.LatestMetric, len(latestMetrics)) - for _, m := range latestMetrics { - latestMetricsMap[m.Key] = m - } - - nextLatestMetricsMap := make(map[string]models.LatestMetric, len(metrics)) - - for _, metric := range metrics { - latestMetric, found := latestMetricsMap[metric.Key] - nextLatestMetric, alreadyPresent := nextLatestMetricsMap[metric.Key] - - switch { - case !found && !alreadyPresent: - // brand new latest metric - nextLatestMetricsMap[metric.Key] = metric.NewLatestMetricFromProto() - case !found && alreadyPresent && isNewerMetric(metric, nextLatestMetric): - // there is no row in the database but the metric is present twice - // and we need to take the latest one from the batch. 
- nextLatestMetricsMap[metric.Key] = metric.NewLatestMetricFromProto() - case found && isNewerMetric(metric, latestMetric): - // compare with the row in the database - nextLatestMetricsMap[metric.Key] = metric.NewLatestMetricFromProto() - } - } - - nextLatestMetrics := make([]models.LatestMetric, 0, len(nextLatestMetricsMap)) - for _, nextLatestMetric := range nextLatestMetricsMap { - nextLatestMetrics = append(nextLatestMetrics, nextLatestMetric) - } - - if len(nextLatestMetrics) != 0 { - if err := transaction.Clauses(clause.OnConflict{ - UpdateAll: true, - }).Create(nextLatestMetrics).Error; err != nil { - return fmt.Errorf("failed to upsert latest metrics for run_uuid %q: %w", runID, err) - } - } - - return nil -} - -func (s TrackingSQLStore) logMetricsWithTransaction( - transaction *gorm.DB, runID string, metrics []*entities.Metric, -) *contract.Error { - // Duplicate metric values are eliminated - seenMetrics := make(map[models.Metric]struct{}) - modelMetrics := make([]models.Metric, 0, len(metrics)) - - for _, metric := range metrics { - currentMetric := models.NewMetricFromEntity(runID, metric) - if _, ok := seenMetrics[*currentMetric]; !ok { - seenMetrics[*currentMetric] = struct{}{} - - modelMetrics = append(modelMetrics, *currentMetric) - } - } - - if err := transaction.Clauses(clause.OnConflict{DoNothing: true}). 
- CreateInBatches(modelMetrics, metricsBatchSize).Error; err != nil { - return contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - fmt.Sprintf("error creating metrics in batch for run_uuid %q", runID), - err, - ) - } - - if err := updateLatestMetricsIfNecessary(transaction, runID, modelMetrics); err != nil { - return contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - fmt.Sprintf("error updating latest metrics for run_uuid %q", runID), - err, - ) - } - - return nil -} - -func (s TrackingSQLStore) LogMetric(ctx context.Context, runID string, metric *entities.Metric) *contract.Error { - err := s.db.WithContext(ctx).Transaction(func(transaction *gorm.DB) error { - contractError := checkRunIsActive(transaction, runID) - if contractError != nil { - return contractError - } - - if err := s.logMetricsWithTransaction(transaction, runID, []*entities.Metric{ - metric, - }); err != nil { - return err - } - - return nil - }) - if err != nil { - var contractError *contract.Error - if errors.As(err, &contractError) { - return contractError - } - - return contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - fmt.Sprintf("log metric transaction failed for %q", runID), - err, - ) - } - - return nil -} +package sql + +import ( + "context" + "errors" + "fmt" + "math" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + + "github.com/mlflow/mlflow-go/pkg/contract" + "github.com/mlflow/mlflow-go/pkg/entities" + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/tracking/store/sql/models" +) + +const metricsBatchSize = 500 + +func getDistinctMetricKeys(metrics []models.Metric) []string { + metricKeysMap := make(map[string]any) + for _, m := range metrics { + metricKeysMap[m.Key] = nil + } + + metricKeys := make([]string, 0, len(metricKeysMap)) + for key := range metricKeysMap { + metricKeys = append(metricKeys, key) + } + + return metricKeys +} + +func getLatestMetrics(transaction *gorm.DB, runID string, metricKeys []string) ([]models.LatestMetric, 
error) { + const batchSize = 500 + + latestMetrics := make([]models.LatestMetric, 0, len(metricKeys)) + + for skip := 0; skip < len(metricKeys); skip += batchSize { + take := int(math.Max(float64(skip+batchSize), float64(len(metricKeys)))) + if take > len(metricKeys) { + take = len(metricKeys) + } + + currentBatch := make([]models.LatestMetric, 0, take-skip) + keys := metricKeys[skip:take] + + err := transaction. + Model(&models.LatestMetric{}). + Where("run_uuid = ?", runID).Where("key IN ?", keys). + Clauses(clause.Locking{Strength: "UPDATE"}). // https://gorm.io/docs/advanced_query.html#Locking + Order("run_uuid"). + Order("key"). + Find(¤tBatch).Error + if err != nil { + return latestMetrics, fmt.Errorf( + "failed to get latest metrics for run_uuid %q, skip %d, take %d : %w", + runID, skip, take, err, + ) + } + + latestMetrics = append(latestMetrics, currentBatch...) + } + + return latestMetrics, nil +} + +func isNewerMetric(a models.Metric, b models.LatestMetric) bool { + return a.Step > b.Step || + (a.Step == b.Step && a.Timestamp > b.Timestamp) || + (a.Step == b.Step && a.Timestamp == b.Timestamp && a.Value > b.Value) +} + +//nolint:cyclop +func updateLatestMetricsIfNecessary(transaction *gorm.DB, runID string, metrics []models.Metric) error { + if len(metrics) == 0 { + return nil + } + + metricKeys := getDistinctMetricKeys(metrics) + + latestMetrics, err := getLatestMetrics(transaction, runID, metricKeys) + if err != nil { + return fmt.Errorf("failed to get latest metrics for run_uuid %q: %w", runID, err) + } + + latestMetricsMap := make(map[string]models.LatestMetric, len(latestMetrics)) + for _, m := range latestMetrics { + latestMetricsMap[m.Key] = m + } + + nextLatestMetricsMap := make(map[string]models.LatestMetric, len(metrics)) + + for _, metric := range metrics { + latestMetric, found := latestMetricsMap[metric.Key] + nextLatestMetric, alreadyPresent := nextLatestMetricsMap[metric.Key] + + switch { + case !found && !alreadyPresent: + // brand new 
latest metric + nextLatestMetricsMap[metric.Key] = metric.NewLatestMetricFromProto() + case !found && alreadyPresent && isNewerMetric(metric, nextLatestMetric): + // there is no row in the database but the metric is present twice + // and we need to take the latest one from the batch. + nextLatestMetricsMap[metric.Key] = metric.NewLatestMetricFromProto() + case found && isNewerMetric(metric, latestMetric): + // compare with the row in the database + nextLatestMetricsMap[metric.Key] = metric.NewLatestMetricFromProto() + } + } + + nextLatestMetrics := make([]models.LatestMetric, 0, len(nextLatestMetricsMap)) + for _, nextLatestMetric := range nextLatestMetricsMap { + nextLatestMetrics = append(nextLatestMetrics, nextLatestMetric) + } + + if len(nextLatestMetrics) != 0 { + if err := transaction.Clauses(clause.OnConflict{ + UpdateAll: true, + }).Create(nextLatestMetrics).Error; err != nil { + return fmt.Errorf("failed to upsert latest metrics for run_uuid %q: %w", runID, err) + } + } + + return nil +} + +func (s TrackingSQLStore) logMetricsWithTransaction( + transaction *gorm.DB, runID string, metrics []*entities.Metric, +) *contract.Error { + // Duplicate metric values are eliminated + seenMetrics := make(map[models.Metric]struct{}) + modelMetrics := make([]models.Metric, 0, len(metrics)) + + for _, metric := range metrics { + currentMetric := models.NewMetricFromEntity(runID, metric) + if _, ok := seenMetrics[*currentMetric]; !ok { + seenMetrics[*currentMetric] = struct{}{} + + modelMetrics = append(modelMetrics, *currentMetric) + } + } + + if err := transaction.Clauses(clause.OnConflict{DoNothing: true}). 
+ CreateInBatches(modelMetrics, metricsBatchSize).Error; err != nil { + return contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf("error creating metrics in batch for run_uuid %q", runID), + err, + ) + } + + if err := updateLatestMetricsIfNecessary(transaction, runID, modelMetrics); err != nil { + return contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf("error updating latest metrics for run_uuid %q", runID), + err, + ) + } + + return nil +} + +func (s TrackingSQLStore) LogMetric(ctx context.Context, runID string, metric *entities.Metric) *contract.Error { + err := s.db.WithContext(ctx).Transaction(func(transaction *gorm.DB) error { + contractError := checkRunIsActive(transaction, runID) + if contractError != nil { + return contractError + } + + if err := s.logMetricsWithTransaction(transaction, runID, []*entities.Metric{ + metric, + }); err != nil { + return err + } + + return nil + }) + if err != nil { + var contractError *contract.Error + if errors.As(err, &contractError) { + return contractError + } + + return contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf("log metric transaction failed for %q", runID), + err, + ) + } + + return nil +} diff --git a/pkg/tracking/store/sql/models/alembic_version.go b/pkg/tracking/store/sql/models/alembic_version.go index b97ed95..bf30c95 100644 --- a/pkg/tracking/store/sql/models/alembic_version.go +++ b/pkg/tracking/store/sql/models/alembic_version.go @@ -1,11 +1,11 @@ -package models - -// AlembicVersion mapped from table . -type AlembicVersion struct { - VersionNum *string `db:"version_num" gorm:"column:version_num;primaryKey"` -} - -// TableName AlembicVersion's table name. -func (*AlembicVersion) TableName() string { - return "alembic_version" -} +package models + +// AlembicVersion mapped from table . +type AlembicVersion struct { + VersionNum *string `db:"version_num" gorm:"column:version_num;primaryKey"` +} + +// TableName AlembicVersion's table name. 
+func (*AlembicVersion) TableName() string { + return "alembic_version" +} diff --git a/pkg/tracking/store/sql/models/datasets.go b/pkg/tracking/store/sql/models/datasets.go index 6618375..1fda973 100644 --- a/pkg/tracking/store/sql/models/datasets.go +++ b/pkg/tracking/store/sql/models/datasets.go @@ -1,28 +1,28 @@ -package models - -import ( - "github.com/mlflow/mlflow-go/pkg/entities" -) - -// Dataset mapped from table . -type Dataset struct { - ID string `db:"dataset_uuid" gorm:"column:dataset_uuid;not null"` - ExperimentID int32 `db:"experiment_id" gorm:"column:experiment_id;primaryKey"` - Name string `db:"name" gorm:"column:name;primaryKey"` - Digest string `db:"digest" gorm:"column:digest;primaryKey"` - SourceType string `db:"dataset_source_type" gorm:"column:dataset_source_type;not null"` - Source string `db:"dataset_source" gorm:"column:dataset_source;not null"` - Schema string `db:"dataset_schema" gorm:"column:dataset_schema"` - Profile string `db:"dataset_profile" gorm:"column:dataset_profile"` -} - -func (d *Dataset) ToEntity() *entities.Dataset { - return &entities.Dataset{ - Name: d.Name, - Digest: d.Digest, - SourceType: d.SourceType, - Source: d.Source, - Schema: d.Schema, - Profile: d.Profile, - } -} +package models + +import ( + "github.com/mlflow/mlflow-go/pkg/entities" +) + +// Dataset mapped from table . 
+type Dataset struct { + ID string `db:"dataset_uuid" gorm:"column:dataset_uuid;not null"` + ExperimentID int32 `db:"experiment_id" gorm:"column:experiment_id;primaryKey"` + Name string `db:"name" gorm:"column:name;primaryKey"` + Digest string `db:"digest" gorm:"column:digest;primaryKey"` + SourceType string `db:"dataset_source_type" gorm:"column:dataset_source_type;not null"` + Source string `db:"dataset_source" gorm:"column:dataset_source;not null"` + Schema string `db:"dataset_schema" gorm:"column:dataset_schema"` + Profile string `db:"dataset_profile" gorm:"column:dataset_profile"` +} + +func (d *Dataset) ToEntity() *entities.Dataset { + return &entities.Dataset{ + Name: d.Name, + Digest: d.Digest, + SourceType: d.SourceType, + Source: d.Source, + Schema: d.Schema, + Profile: d.Profile, + } +} diff --git a/pkg/tracking/store/sql/models/experiment_tags.go b/pkg/tracking/store/sql/models/experiment_tags.go index 8808f7f..54aca6b 100644 --- a/pkg/tracking/store/sql/models/experiment_tags.go +++ b/pkg/tracking/store/sql/models/experiment_tags.go @@ -1,10 +1,10 @@ -package models - -const TableNameExperimentTag = "experiment_tags" - -// ExperimentTag mapped from table . -type ExperimentTag struct { - Key string `db:"key" gorm:"column:key;primaryKey"` - Value string `db:"value" gorm:"column:value"` - ExperimentID int32 `db:"experiment_id" gorm:"column:experiment_id;primaryKey"` -} +package models + +const TableNameExperimentTag = "experiment_tags" + +// ExperimentTag mapped from table . 
+type ExperimentTag struct { + Key string `db:"key" gorm:"column:key;primaryKey"` + Value string `db:"value" gorm:"column:value"` + ExperimentID int32 `db:"experiment_id" gorm:"column:experiment_id;primaryKey"` +} diff --git a/pkg/tracking/store/sql/models/experiments.go b/pkg/tracking/store/sql/models/experiments.go index fcecc46..2d057b3 100644 --- a/pkg/tracking/store/sql/models/experiments.go +++ b/pkg/tracking/store/sql/models/experiments.go @@ -1,40 +1,40 @@ -package models - -import ( - "strconv" - - "github.com/mlflow/mlflow-go/pkg/entities" -) - -// Experiment mapped from table . -type Experiment struct { - ID int32 `gorm:"column:experiment_id;primaryKey;autoIncrement:true"` - Name string `gorm:"column:name;not null"` - ArtifactLocation string `gorm:"column:artifact_location"` - LifecycleStage LifecycleStage `gorm:"column:lifecycle_stage"` - CreationTime int64 `gorm:"column:creation_time"` - LastUpdateTime int64 `gorm:"column:last_update_time"` - Tags []ExperimentTag - Runs []Run -} - -func (e Experiment) ToEntity() *entities.Experiment { - experiment := entities.Experiment{ - ExperimentID: strconv.Itoa(int(e.ID)), - Name: e.Name, - ArtifactLocation: e.ArtifactLocation, - LifecycleStage: e.LifecycleStage.String(), - CreationTime: e.CreationTime, - LastUpdateTime: e.LastUpdateTime, - Tags: make([]*entities.ExperimentTag, len(e.Tags)), - } - - for i, tag := range e.Tags { - experiment.Tags[i] = &entities.ExperimentTag{ - Key: tag.Key, - Value: tag.Value, - } - } - - return &experiment -} +package models + +import ( + "strconv" + + "github.com/mlflow/mlflow-go/pkg/entities" +) + +// Experiment mapped from table . 
+type Experiment struct { + ID int32 `gorm:"column:experiment_id;primaryKey;autoIncrement:true"` + Name string `gorm:"column:name;not null"` + ArtifactLocation string `gorm:"column:artifact_location"` + LifecycleStage LifecycleStage `gorm:"column:lifecycle_stage"` + CreationTime int64 `gorm:"column:creation_time"` + LastUpdateTime int64 `gorm:"column:last_update_time"` + Tags []ExperimentTag + Runs []Run +} + +func (e Experiment) ToEntity() *entities.Experiment { + experiment := entities.Experiment{ + ExperimentID: strconv.Itoa(int(e.ID)), + Name: e.Name, + ArtifactLocation: e.ArtifactLocation, + LifecycleStage: e.LifecycleStage.String(), + CreationTime: e.CreationTime, + LastUpdateTime: e.LastUpdateTime, + Tags: make([]*entities.ExperimentTag, len(e.Tags)), + } + + for i, tag := range e.Tags { + experiment.Tags[i] = &entities.ExperimentTag{ + Key: tag.Key, + Value: tag.Value, + } + } + + return &experiment +} diff --git a/pkg/tracking/store/sql/models/input_tags.go b/pkg/tracking/store/sql/models/input_tags.go index b1d55b3..ca33747 100644 --- a/pkg/tracking/store/sql/models/input_tags.go +++ b/pkg/tracking/store/sql/models/input_tags.go @@ -1,19 +1,19 @@ -package models - -import ( - "github.com/mlflow/mlflow-go/pkg/entities" -) - -// InputTag mapped from table . -type InputTag struct { - Key string `gorm:"column:name;primaryKey"` - Value string `gorm:"column:value;not null"` - InputID string `gorm:"column:input_uuid;primaryKey"` -} - -func (i *InputTag) ToEntity() *entities.InputTag { - return &entities.InputTag{ - Key: i.Key, - Value: i.Value, - } -} +package models + +import ( + "github.com/mlflow/mlflow-go/pkg/entities" +) + +// InputTag mapped from table . 
+type InputTag struct { + Key string `gorm:"column:name;primaryKey"` + Value string `gorm:"column:value;not null"` + InputID string `gorm:"column:input_uuid;primaryKey"` +} + +func (i *InputTag) ToEntity() *entities.InputTag { + return &entities.InputTag{ + Key: i.Key, + Value: i.Value, + } +} diff --git a/pkg/tracking/store/sql/models/inputs.go b/pkg/tracking/store/sql/models/inputs.go index ed904eb..b9846a6 100644 --- a/pkg/tracking/store/sql/models/inputs.go +++ b/pkg/tracking/store/sql/models/inputs.go @@ -1,28 +1,28 @@ -package models - -import ( - "github.com/mlflow/mlflow-go/pkg/entities" -) - -// Input mapped from table . -type Input struct { - ID string `db:"input_uuid" gorm:"column:input_uuid;not null"` - SourceType string `db:"source_type" gorm:"column:source_type;primaryKey"` - SourceID string `db:"source_id" gorm:"column:source_id;primaryKey"` - DestinationType string `db:"destination_type" gorm:"column:destination_type;primaryKey"` - DestinationID string `db:"destination_id" gorm:"column:destination_id;primaryKey"` - Tags []InputTag `gorm:"foreignKey:InputID;references:ID"` - Dataset Dataset `gorm:"foreignKey:ID;references:SourceID"` -} - -func (i *Input) ToEntity() *entities.DatasetInput { - tags := make([]*entities.InputTag, 0, len(i.Tags)) - for _, tag := range i.Tags { - tags = append(tags, tag.ToEntity()) - } - - return &entities.DatasetInput{ - Tags: tags, - Dataset: i.Dataset.ToEntity(), - } -} +package models + +import ( + "github.com/mlflow/mlflow-go/pkg/entities" +) + +// Input mapped from table . 
+type Input struct { + ID string `db:"input_uuid" gorm:"column:input_uuid;not null"` + SourceType string `db:"source_type" gorm:"column:source_type;primaryKey"` + SourceID string `db:"source_id" gorm:"column:source_id;primaryKey"` + DestinationType string `db:"destination_type" gorm:"column:destination_type;primaryKey"` + DestinationID string `db:"destination_id" gorm:"column:destination_id;primaryKey"` + Tags []InputTag `gorm:"foreignKey:InputID;references:ID"` + Dataset Dataset `gorm:"foreignKey:ID;references:SourceID"` +} + +func (i *Input) ToEntity() *entities.DatasetInput { + tags := make([]*entities.InputTag, 0, len(i.Tags)) + for _, tag := range i.Tags { + tags = append(tags, tag.ToEntity()) + } + + return &entities.DatasetInput{ + Tags: tags, + Dataset: i.Dataset.ToEntity(), + } +} diff --git a/pkg/tracking/store/sql/models/latest_metrics.go b/pkg/tracking/store/sql/models/latest_metrics.go index 0271650..021e27e 100644 --- a/pkg/tracking/store/sql/models/latest_metrics.go +++ b/pkg/tracking/store/sql/models/latest_metrics.go @@ -1,25 +1,25 @@ -package models - -import ( - "github.com/mlflow/mlflow-go/pkg/entities" -) - -// LatestMetric mapped from table . -type LatestMetric struct { - Key string `db:"key" gorm:"column:key;primaryKey"` - Value float64 `db:"value" gorm:"column:value;not null"` - Timestamp int64 `db:"timestamp" gorm:"column:timestamp"` - Step int64 `db:"step" gorm:"column:step;not null"` - IsNaN bool `db:"is_nan" gorm:"column:is_nan;not null"` - RunID string `db:"run_uuid" gorm:"column:run_uuid;primaryKey"` -} - -func (lm LatestMetric) ToEntity() *entities.Metric { - return &entities.Metric{ - Key: lm.Key, - Value: lm.Value, - Timestamp: lm.Timestamp, - Step: lm.Step, - IsNaN: lm.IsNaN, - } -} +package models + +import ( + "github.com/mlflow/mlflow-go/pkg/entities" +) + +// LatestMetric mapped from table . 
+type LatestMetric struct { + Key string `db:"key" gorm:"column:key;primaryKey"` + Value float64 `db:"value" gorm:"column:value;not null"` + Timestamp int64 `db:"timestamp" gorm:"column:timestamp"` + Step int64 `db:"step" gorm:"column:step;not null"` + IsNaN bool `db:"is_nan" gorm:"column:is_nan;not null"` + RunID string `db:"run_uuid" gorm:"column:run_uuid;primaryKey"` +} + +func (lm LatestMetric) ToEntity() *entities.Metric { + return &entities.Metric{ + Key: lm.Key, + Value: lm.Value, + Timestamp: lm.Timestamp, + Step: lm.Step, + IsNaN: lm.IsNaN, + } +} diff --git a/pkg/tracking/store/sql/models/lifecycle.go b/pkg/tracking/store/sql/models/lifecycle.go index 13cc716..c01ad5e 100644 --- a/pkg/tracking/store/sql/models/lifecycle.go +++ b/pkg/tracking/store/sql/models/lifecycle.go @@ -1,12 +1,12 @@ -package models - -type LifecycleStage string - -func (s LifecycleStage) String() string { - return string(s) -} - -const ( - LifecycleStageActive LifecycleStage = "active" - LifecycleStageDeleted LifecycleStage = "deleted" -) +package models + +type LifecycleStage string + +func (s LifecycleStage) String() string { + return string(s) +} + +const ( + LifecycleStageActive LifecycleStage = "active" + LifecycleStageDeleted LifecycleStage = "deleted" +) diff --git a/pkg/tracking/store/sql/models/metrics.go b/pkg/tracking/store/sql/models/metrics.go index e100cdf..ef410a7 100644 --- a/pkg/tracking/store/sql/models/metrics.go +++ b/pkg/tracking/store/sql/models/metrics.go @@ -1,57 +1,57 @@ -package models - -import ( - "math" - - "github.com/mlflow/mlflow-go/pkg/entities" -) - -// Metric mapped from table . 
-type Metric struct { - Key string `db:"key" gorm:"column:key;primaryKey"` - Value float64 `db:"value" gorm:"column:value;primaryKey"` - Timestamp int64 `db:"timestamp" gorm:"column:timestamp;primaryKey"` - RunID string `db:"run_uuid" gorm:"column:run_uuid;primaryKey"` - Step int64 `db:"step" gorm:"column:step;primaryKey"` - IsNaN bool `db:"is_nan" gorm:"column:is_nan;primaryKey"` -} - -func NewMetricFromEntity(runID string, metric *entities.Metric) *Metric { - model := Metric{ - RunID: runID, - Key: metric.Key, - Timestamp: metric.Timestamp, - } - - if metric.Step != 0 { - model.Step = metric.Step - } - - switch { - case math.IsNaN(metric.Value): - model.Value = 0 - model.IsNaN = true - case math.IsInf(metric.Value, 0): - // NB: SQL cannot represent Infs => We replace +/- Inf with max/min 64b float value - if metric.Value > 0 { - model.Value = math.MaxFloat64 - } else { - model.Value = -math.MaxFloat64 - } - default: - model.Value = metric.Value - } - - return &model -} - -func (m Metric) NewLatestMetricFromProto() LatestMetric { - return LatestMetric{ - RunID: m.RunID, - Key: m.Key, - Value: m.Value, - Timestamp: m.Timestamp, - Step: m.Step, - IsNaN: m.IsNaN, - } -} +package models + +import ( + "math" + + "github.com/mlflow/mlflow-go/pkg/entities" +) + +// Metric mapped from table . 
+type Metric struct { + Key string `db:"key" gorm:"column:key;primaryKey"` + Value float64 `db:"value" gorm:"column:value;primaryKey"` + Timestamp int64 `db:"timestamp" gorm:"column:timestamp;primaryKey"` + RunID string `db:"run_uuid" gorm:"column:run_uuid;primaryKey"` + Step int64 `db:"step" gorm:"column:step;primaryKey"` + IsNaN bool `db:"is_nan" gorm:"column:is_nan;primaryKey"` +} + +func NewMetricFromEntity(runID string, metric *entities.Metric) *Metric { + model := Metric{ + RunID: runID, + Key: metric.Key, + Timestamp: metric.Timestamp, + } + + if metric.Step != 0 { + model.Step = metric.Step + } + + switch { + case math.IsNaN(metric.Value): + model.Value = 0 + model.IsNaN = true + case math.IsInf(metric.Value, 0): + // NB: SQL cannot represent Infs => We replace +/- Inf with max/min 64b float value + if metric.Value > 0 { + model.Value = math.MaxFloat64 + } else { + model.Value = -math.MaxFloat64 + } + default: + model.Value = metric.Value + } + + return &model +} + +func (m Metric) NewLatestMetricFromProto() LatestMetric { + return LatestMetric{ + RunID: m.RunID, + Key: m.Key, + Value: m.Value, + Timestamp: m.Timestamp, + Step: m.Step, + IsNaN: m.IsNaN, + } +} diff --git a/pkg/tracking/store/sql/models/params.go b/pkg/tracking/store/sql/models/params.go index 276fc97..1f2f8f5 100644 --- a/pkg/tracking/store/sql/models/params.go +++ b/pkg/tracking/store/sql/models/params.go @@ -1,27 +1,27 @@ -package models - -import ( - "github.com/mlflow/mlflow-go/pkg/entities" -) - -// Param mapped from table . 
-type Param struct { - Key string `db:"key" gorm:"column:key;primaryKey"` - Value string `db:"value" gorm:"column:value;not null"` - RunID string `db:"run_uuid" gorm:"column:run_uuid;primaryKey"` -} - -func (p Param) ToEntity() *entities.Param { - return &entities.Param{ - Key: p.Key, - Value: p.Value, - } -} - -func NewParamFromEntity(runID string, param *entities.Param) Param { - return Param{ - Key: param.Key, - Value: param.Value, - RunID: runID, - } -} +package models + +import ( + "github.com/mlflow/mlflow-go/pkg/entities" +) + +// Param mapped from table . +type Param struct { + Key string `db:"key" gorm:"column:key;primaryKey"` + Value string `db:"value" gorm:"column:value;not null"` + RunID string `db:"run_uuid" gorm:"column:run_uuid;primaryKey"` +} + +func (p Param) ToEntity() *entities.Param { + return &entities.Param{ + Key: p.Key, + Value: p.Value, + } +} + +func NewParamFromEntity(runID string, param *entities.Param) Param { + return Param{ + Key: param.Key, + Value: param.Value, + RunID: runID, + } +} diff --git a/pkg/tracking/store/sql/models/runs.go b/pkg/tracking/store/sql/models/runs.go index 810e9a8..23ef484 100644 --- a/pkg/tracking/store/sql/models/runs.go +++ b/pkg/tracking/store/sql/models/runs.go @@ -1,106 +1,106 @@ -package models - -import ( - "database/sql" - - "github.com/mlflow/mlflow-go/pkg/entities" - "github.com/mlflow/mlflow-go/pkg/utils" -) - -// Run mapped from table . 
-type Run struct { - ID string `db:"run_uuid" gorm:"column:run_uuid;primaryKey"` - Name string `db:"name" gorm:"column:name"` - SourceType SourceType `db:"source_type" gorm:"column:source_type"` - SourceName string `db:"source_name" gorm:"column:source_name"` - EntryPointName string `db:"entry_point_name" gorm:"column:entry_point_name"` - UserID string `db:"user_id" gorm:"column:user_id"` - Status RunStatus `db:"status" gorm:"column:status"` - StartTime int64 `db:"start_time" gorm:"column:start_time"` - EndTime sql.NullInt64 `db:"end_time" gorm:"column:end_time"` - SourceVersion string `db:"source_version" gorm:"column:source_version"` - LifecycleStage LifecycleStage `db:"lifecycle_stage" gorm:"column:lifecycle_stage"` - ArtifactURI string `db:"artifact_uri" gorm:"column:artifact_uri"` - ExperimentID int32 `db:"experiment_id" gorm:"column:experiment_id"` - DeletedTime sql.NullInt64 `db:"deleted_time" gorm:"column:deleted_time"` - Params []Param - Tags []Tag - Metrics []Metric - LatestMetrics []LatestMetric - Inputs []Input `gorm:"foreignKey:DestinationID"` -} - -type RunStatus string - -func (s RunStatus) String() string { - return string(s) -} - -const ( - RunStatusRunning RunStatus = "RUNNING" - RunStatusScheduled RunStatus = "SCHEDULED" - RunStatusFinished RunStatus = "FINISHED" - RunStatusFailed RunStatus = "FAILED" - RunStatusKilled RunStatus = "KILLED" -) - -type SourceType string - -const ( - SourceTypeNotebook SourceType = "NOTEBOOK" - SourceTypeJob SourceType = "JOB" - SourceTypeProject SourceType = "PROJECT" - SourceTypeLocal SourceType = "LOCAL" - SourceTypeUnknown SourceType = "UNKNOWN" - SourceTypeRecipe SourceType = "RECIPE" -) - -func (r Run) ToEntity() *entities.Run { - metrics := make([]*entities.Metric, 0, len(r.LatestMetrics)) - for _, metric := range r.LatestMetrics { - metrics = append(metrics, metric.ToEntity()) - } - - params := make([]*entities.Param, 0, len(r.Params)) - for _, param := range r.Params { - params = append(params, 
param.ToEntity()) - } - - tags := make([]*entities.RunTag, 0, len(r.Tags)) - for _, tag := range r.Tags { - tags = append(tags, tag.ToEntity()) - } - - datasetInputs := make([]*entities.DatasetInput, 0, len(r.Inputs)) - for _, input := range r.Inputs { - datasetInputs = append(datasetInputs, input.ToEntity()) - } - - var endTime *int64 - if r.EndTime.Valid { - endTime = utils.PtrTo(r.EndTime.Int64) - } - - return &entities.Run{ - Info: &entities.RunInfo{ - RunID: r.ID, - RunUUID: r.ID, - RunName: r.Name, - ExperimentID: r.ExperimentID, - UserID: r.UserID, - Status: r.Status.String(), - StartTime: r.StartTime, - EndTime: endTime, - ArtifactURI: r.ArtifactURI, - LifecycleStage: r.LifecycleStage.String(), - }, - Data: &entities.RunData{ - Tags: tags, - Params: params, - Metrics: metrics, - }, - Inputs: &entities.RunInputs{ - DatasetInputs: datasetInputs, - }, - } -} +package models + +import ( + "database/sql" + + "github.com/mlflow/mlflow-go/pkg/entities" + "github.com/mlflow/mlflow-go/pkg/utils" +) + +// Run mapped from table . 
+type Run struct { + ID string `db:"run_uuid" gorm:"column:run_uuid;primaryKey"` + Name string `db:"name" gorm:"column:name"` + SourceType SourceType `db:"source_type" gorm:"column:source_type"` + SourceName string `db:"source_name" gorm:"column:source_name"` + EntryPointName string `db:"entry_point_name" gorm:"column:entry_point_name"` + UserID string `db:"user_id" gorm:"column:user_id"` + Status RunStatus `db:"status" gorm:"column:status"` + StartTime int64 `db:"start_time" gorm:"column:start_time"` + EndTime sql.NullInt64 `db:"end_time" gorm:"column:end_time"` + SourceVersion string `db:"source_version" gorm:"column:source_version"` + LifecycleStage LifecycleStage `db:"lifecycle_stage" gorm:"column:lifecycle_stage"` + ArtifactURI string `db:"artifact_uri" gorm:"column:artifact_uri"` + ExperimentID int32 `db:"experiment_id" gorm:"column:experiment_id"` + DeletedTime sql.NullInt64 `db:"deleted_time" gorm:"column:deleted_time"` + Params []Param + Tags []Tag + Metrics []Metric + LatestMetrics []LatestMetric + Inputs []Input `gorm:"foreignKey:DestinationID"` +} + +type RunStatus string + +func (s RunStatus) String() string { + return string(s) +} + +const ( + RunStatusRunning RunStatus = "RUNNING" + RunStatusScheduled RunStatus = "SCHEDULED" + RunStatusFinished RunStatus = "FINISHED" + RunStatusFailed RunStatus = "FAILED" + RunStatusKilled RunStatus = "KILLED" +) + +type SourceType string + +const ( + SourceTypeNotebook SourceType = "NOTEBOOK" + SourceTypeJob SourceType = "JOB" + SourceTypeProject SourceType = "PROJECT" + SourceTypeLocal SourceType = "LOCAL" + SourceTypeUnknown SourceType = "UNKNOWN" + SourceTypeRecipe SourceType = "RECIPE" +) + +func (r Run) ToEntity() *entities.Run { + metrics := make([]*entities.Metric, 0, len(r.LatestMetrics)) + for _, metric := range r.LatestMetrics { + metrics = append(metrics, metric.ToEntity()) + } + + params := make([]*entities.Param, 0, len(r.Params)) + for _, param := range r.Params { + params = append(params, 
param.ToEntity()) + } + + tags := make([]*entities.RunTag, 0, len(r.Tags)) + for _, tag := range r.Tags { + tags = append(tags, tag.ToEntity()) + } + + datasetInputs := make([]*entities.DatasetInput, 0, len(r.Inputs)) + for _, input := range r.Inputs { + datasetInputs = append(datasetInputs, input.ToEntity()) + } + + var endTime *int64 + if r.EndTime.Valid { + endTime = utils.PtrTo(r.EndTime.Int64) + } + + return &entities.Run{ + Info: &entities.RunInfo{ + RunID: r.ID, + RunUUID: r.ID, + RunName: r.Name, + ExperimentID: r.ExperimentID, + UserID: r.UserID, + Status: r.Status.String(), + StartTime: r.StartTime, + EndTime: endTime, + ArtifactURI: r.ArtifactURI, + LifecycleStage: r.LifecycleStage.String(), + }, + Data: &entities.RunData{ + Tags: tags, + Params: params, + Metrics: metrics, + }, + Inputs: &entities.RunInputs{ + DatasetInputs: datasetInputs, + }, + } +} diff --git a/pkg/tracking/store/sql/models/tags.go b/pkg/tracking/store/sql/models/tags.go index d955a2a..0c8dbe1 100644 --- a/pkg/tracking/store/sql/models/tags.go +++ b/pkg/tracking/store/sql/models/tags.go @@ -1,31 +1,31 @@ -package models - -import ( - "github.com/mlflow/mlflow-go/pkg/entities" -) - -// Tag mapped from table . -type Tag struct { - Key string `db:"key" gorm:"column:key;primaryKey"` - Value string `db:"value" gorm:"column:value"` - RunID string `db:"run_uuid" gorm:"column:run_uuid;primaryKey"` -} - -func (t Tag) ToEntity() *entities.RunTag { - return &entities.RunTag{ - Key: t.Key, - Value: t.Value, - } -} - -func NewTagFromEntity(runID string, entity *entities.RunTag) Tag { - tag := Tag{ - Key: entity.Key, - Value: entity.Value, - } - if runID != "" { - tag.RunID = runID - } - - return tag -} +package models + +import ( + "github.com/mlflow/mlflow-go/pkg/entities" +) + +// Tag mapped from table . 
+type Tag struct { + Key string `db:"key" gorm:"column:key;primaryKey"` + Value string `db:"value" gorm:"column:value"` + RunID string `db:"run_uuid" gorm:"column:run_uuid;primaryKey"` +} + +func (t Tag) ToEntity() *entities.RunTag { + return &entities.RunTag{ + Key: t.Key, + Value: t.Value, + } +} + +func NewTagFromEntity(runID string, entity *entities.RunTag) Tag { + tag := Tag{ + Key: entity.Key, + Value: entity.Value, + } + if runID != "" { + tag.RunID = runID + } + + return tag +} diff --git a/pkg/tracking/store/sql/params.go b/pkg/tracking/store/sql/params.go index cbe4da7..2d6d6c5 100644 --- a/pkg/tracking/store/sql/params.go +++ b/pkg/tracking/store/sql/params.go @@ -1,119 +1,119 @@ -package sql - -import ( - "fmt" - - "gorm.io/gorm" - "gorm.io/gorm/clause" - - "github.com/mlflow/mlflow-go/pkg/contract" - "github.com/mlflow/mlflow-go/pkg/entities" - "github.com/mlflow/mlflow-go/pkg/protos" - "github.com/mlflow/mlflow-go/pkg/tracking/store/sql/models" -) - -const paramsBatchSize = 100 - -func verifyBatchParamsInserts( - transaction *gorm.DB, runID string, deduplicatedParamsMap map[string]string, -) *contract.Error { - keys := make([]string, 0, len(deduplicatedParamsMap)) - for key := range deduplicatedParamsMap { - keys = append(keys, key) - } - - var existingParams []models.Param - - err := transaction. - Model(&models.Param{}). - Select("key, value"). - Where("run_uuid = ?", runID). - Where("key IN ?", keys). - Find(&existingParams).Error - if err != nil { - return contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - fmt.Sprintf( - "failed to get existing params to check if duplicates for run_id %q", - runID, - ), - err) - } - - for _, existingParam := range existingParams { - if currentValue, ok := deduplicatedParamsMap[existingParam.Key]; ok && currentValue != existingParam.Value { - return contract.NewError( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - fmt.Sprintf( - "Changing param values is not allowed. 
"+ - "Params with key=%q was already logged "+ - "with value=%q for run ID=%q. "+ - "Attempted logging new value %q", - existingParam.Key, - existingParam.Value, - runID, - currentValue, - ), - ) - } - } - - return nil -} - -func (s TrackingSQLStore) logParamsWithTransaction( - transaction *gorm.DB, runID string, params []*entities.Param, -) *contract.Error { - deduplicatedParamsMap := make(map[string]string, len(params)) - deduplicatedParams := make([]models.Param, 0, len(deduplicatedParamsMap)) - - for _, param := range params { - oldValue, paramIsPresent := deduplicatedParamsMap[param.Key] - if paramIsPresent && param.Value != oldValue { - return contract.NewError( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - fmt.Sprintf( - "Changing param values is not allowed. "+ - "Params with key=%q was already logged "+ - "with value=%q for run ID=%q. "+ - "Attempted logging new value %q", - param.Key, - oldValue, - runID, - param.Value, - ), - ) - } - - if !paramIsPresent { - deduplicatedParamsMap[param.Key] = param.Value - deduplicatedParams = append(deduplicatedParams, models.NewParamFromEntity(runID, param)) - } - } - - // Try and create all params. - // Params are unique by (run_uuid, key) so any potentially conflicts will not be inserted. - err := transaction. - Clauses(clause.OnConflict{ - Columns: []clause.Column{{Name: "run_uuid"}, {Name: "key"}}, - DoNothing: true, - }). - CreateInBatches(deduplicatedParams, paramsBatchSize).Error - if err != nil { - return contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - fmt.Sprintf("error creating params in batch for run_uuid %q", runID), - err, - ) - } - - // if there were ignored conflicts, we assert that the values are the same. 
- if transaction.RowsAffected != int64(len(params)) { - contractError := verifyBatchParamsInserts(transaction, runID, deduplicatedParamsMap) - if contractError != nil { - return contractError - } - } - - return nil -} +package sql + +import ( + "fmt" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + + "github.com/mlflow/mlflow-go/pkg/contract" + "github.com/mlflow/mlflow-go/pkg/entities" + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/tracking/store/sql/models" +) + +const paramsBatchSize = 100 + +func verifyBatchParamsInserts( + transaction *gorm.DB, runID string, deduplicatedParamsMap map[string]string, +) *contract.Error { + keys := make([]string, 0, len(deduplicatedParamsMap)) + for key := range deduplicatedParamsMap { + keys = append(keys, key) + } + + var existingParams []models.Param + + err := transaction. + Model(&models.Param{}). + Select("key, value"). + Where("run_uuid = ?", runID). + Where("key IN ?", keys). + Find(&existingParams).Error + if err != nil { + return contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf( + "failed to get existing params to check if duplicates for run_id %q", + runID, + ), + err) + } + + for _, existingParam := range existingParams { + if currentValue, ok := deduplicatedParamsMap[existingParam.Key]; ok && currentValue != existingParam.Value { + return contract.NewError( + protos.ErrorCode_INVALID_PARAMETER_VALUE, + fmt.Sprintf( + "Changing param values is not allowed. "+ + "Params with key=%q was already logged "+ + "with value=%q for run ID=%q. 
"+ + "Attempted logging new value %q", + existingParam.Key, + existingParam.Value, + runID, + currentValue, + ), + ) + } + } + + return nil +} + +func (s TrackingSQLStore) logParamsWithTransaction( + transaction *gorm.DB, runID string, params []*entities.Param, +) *contract.Error { + deduplicatedParamsMap := make(map[string]string, len(params)) + deduplicatedParams := make([]models.Param, 0, len(deduplicatedParamsMap)) + + for _, param := range params { + oldValue, paramIsPresent := deduplicatedParamsMap[param.Key] + if paramIsPresent && param.Value != oldValue { + return contract.NewError( + protos.ErrorCode_INVALID_PARAMETER_VALUE, + fmt.Sprintf( + "Changing param values is not allowed. "+ + "Params with key=%q was already logged "+ + "with value=%q for run ID=%q. "+ + "Attempted logging new value %q", + param.Key, + oldValue, + runID, + param.Value, + ), + ) + } + + if !paramIsPresent { + deduplicatedParamsMap[param.Key] = param.Value + deduplicatedParams = append(deduplicatedParams, models.NewParamFromEntity(runID, param)) + } + } + + // Try and create all params. + // Params are unique by (run_uuid, key) so any potentially conflicts will not be inserted. + err := transaction. + Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "run_uuid"}, {Name: "key"}}, + DoNothing: true, + }). + CreateInBatches(deduplicatedParams, paramsBatchSize).Error + if err != nil { + return contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf("error creating params in batch for run_uuid %q", runID), + err, + ) + } + + // if there were ignored conflicts, we assert that the values are the same. 
+ if transaction.RowsAffected != int64(len(params)) { + contractError := verifyBatchParamsInserts(transaction, runID, deduplicatedParamsMap) + if contractError != nil { + return contractError + } + } + + return nil +} diff --git a/pkg/tracking/store/sql/runs.go b/pkg/tracking/store/sql/runs.go index 7110671..3c17920 100644 --- a/pkg/tracking/store/sql/runs.go +++ b/pkg/tracking/store/sql/runs.go @@ -1,900 +1,900 @@ -package sql - -import ( - "context" - "database/sql" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "strings" - "time" - - "gorm.io/gorm" - "gorm.io/gorm/clause" - - "github.com/mlflow/mlflow-go/pkg/contract" - "github.com/mlflow/mlflow-go/pkg/entities" - "github.com/mlflow/mlflow-go/pkg/protos" - "github.com/mlflow/mlflow-go/pkg/tracking/service/query" - "github.com/mlflow/mlflow-go/pkg/tracking/service/query/parser" - "github.com/mlflow/mlflow-go/pkg/tracking/store/sql/models" - "github.com/mlflow/mlflow-go/pkg/utils" -) - -type PageToken struct { - Offset int32 `json:"offset"` -} - -func checkRunIsActive(transaction *gorm.DB, runID string) *contract.Error { - var run models.Run - - err := transaction. - Model(&models.Run{}). - Where("run_uuid = ?", runID). - Select("lifecycle_stage"). - First(&run). 
- Error - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return contract.NewError( - protos.ErrorCode_RESOURCE_DOES_NOT_EXIST, - fmt.Sprintf("Run with id=%s not found", runID), - ) - } - - return contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - fmt.Sprintf( - "failed to get lifecycle stage for run %q", - runID, - ), - err, - ) - } - - if run.LifecycleStage != models.LifecycleStageActive { - return contract.NewError( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - fmt.Sprintf( - "The run %s must be in the 'active' state.\n"+ - "Current state is %v.", - runID, - run.LifecycleStage, - ), - ) - } - - return nil -} - -func getLifecyleStages(runViewType protos.ViewType) []models.LifecycleStage { - switch runViewType { - case protos.ViewType_ACTIVE_ONLY: - return []models.LifecycleStage{ - models.LifecycleStageActive, - } - case protos.ViewType_DELETED_ONLY: - return []models.LifecycleStage{ - models.LifecycleStageDeleted, - } - case protos.ViewType_ALL: - return []models.LifecycleStage{ - models.LifecycleStageActive, - models.LifecycleStageDeleted, - } - } - - return []models.LifecycleStage{ - models.LifecycleStageActive, - models.LifecycleStageDeleted, - } -} - -func getOffset(pageToken string) (int, *contract.Error) { - if pageToken != "" { - var token PageToken - if err := json.NewDecoder( - base64.NewDecoder( - base64.StdEncoding, - strings.NewReader(pageToken), - ), - ).Decode(&token); err != nil { - return 0, contract.NewErrorWith( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - fmt.Sprintf("invalid page_token: %q", pageToken), - err, - ) - } - - return int(token.Offset), nil - } - - return 0, nil -} - -//nolint:funlen,cyclop,gocognit -func applyFilter(ctx context.Context, database, transaction *gorm.DB, filter string) *contract.Error { - filterConditions, err := query.ParseFilter(filter) - if err != nil { - return contract.NewErrorWith( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - "error parsing search filter", - err, - ) - } - - 
utils.GetLoggerFromContext(ctx).Debugf("Filter conditions: %v", filterConditions) - - for index, clause := range filterConditions { - var kind any - - key := clause.Key - comparison := strings.ToUpper(clause.Operator.String()) - value := clause.Value - - switch clause.Identifier { - case parser.Metric: - kind = &models.LatestMetric{} - case parser.Parameter: - kind = &models.Param{} - case parser.Tag: - kind = &models.Tag{} - case parser.Dataset: - kind = &models.Dataset{} - case parser.Attribute: - kind = nil - } - - // Treat "attributes.run_name == " as "tags.`mlflow.runName` == ". - // The name column in the runs table is empty for runs logged in MLflow <= 1.29.0. - if key == "run_name" { - kind = &models.Tag{} - key = utils.TagRunName - } - - isSqliteAndILike := database.Dialector.Name() == "sqlite" && comparison == "ILIKE" - table := fmt.Sprintf("filter_%d", index) - - switch { - case kind == nil: - if isSqliteAndILike { - key = fmt.Sprintf("LOWER(runs.%s)", key) - comparison = "LIKE" - - if str, ok := value.(string); ok { - value = strings.ToLower(str) - } - - transaction.Where(fmt.Sprintf("%s %s ?", key, comparison), value) - } else { - transaction.Where(fmt.Sprintf("runs.%s %s ?", key, comparison), value) - } - case clause.Identifier == parser.Dataset && key == "context": - // SELECT * - // FROM runs - // JOIN ( - // SELECT inputs.destination_id AS run_uuid - // FROM inputs - // JOIN input_tags - // ON inputs.input_uuid = input_tags.input_uuid - // AND input_tags.name = 'mlflow.data.context' - // AND input_tags.value %s ? - // WHERE inputs.destination_type = 'RUN' - // ) AS filter_0 - // ON runs.run_uuid = filter_0.run_uuid - valueColumn := "input_tags.value " - if isSqliteAndILike { - valueColumn = "LOWER(input_tags.value) " - comparison = "LIKE" - - if str, ok := value.(string); ok { - value = strings.ToLower(str) - } - } - - transaction.Joins( - fmt.Sprintf("JOIN (?) 
AS %s ON runs.run_uuid = %s.run_uuid", table, table), - database.Select("inputs.destination_id AS run_uuid"). - Joins( - "JOIN input_tags ON inputs.input_uuid = input_tags.input_uuid"+ - " AND input_tags.name = 'mlflow.data.context'"+ - " AND "+valueColumn+comparison+" ?", - value, - ). - Where("inputs.destination_type = 'RUN'"). - Model(&models.Input{}), - ) - case clause.Identifier == parser.Dataset: - // add join with datasets - // JOIN ( - // SELECT "experiment_id", key - // FROM datasests d - // JOIN inputs ON inputs.source_id = datasets.dataset_uuid - // WHERE key comparison value - // ) AS filter_0 ON runs.experiment_id = dataset.experiment_id - // - // columns: name, digest, context - where := key + " " + comparison + " ?" - if isSqliteAndILike { - where = "LOWER(" + key + ") LIKE ?" - - if str, ok := value.(string); ok { - value = strings.ToLower(str) - } - } - - transaction.Joins( - fmt.Sprintf("JOIN (?) AS %s ON runs.run_uuid = %s.destination_id", table, table), - database.Model(kind). - Joins("JOIN inputs ON inputs.source_id = datasets.dataset_uuid"). - Where(where, value). - Select("destination_id", key), - ) - default: - where := fmt.Sprintf("value %s ?", comparison) - if isSqliteAndILike { - where = "LOWER(value) LIKE ?" - - if str, ok := value.(string); ok { - value = strings.ToLower(str) - } - } - - transaction.Joins( - fmt.Sprintf("JOIN (?) 
AS %s ON runs.run_uuid = %s.run_uuid", table, table), - database.Select("run_uuid", "value").Where("key = ?", key).Where(where, value).Model(kind), - ) - } - } - - return nil -} - -type orderByExpr struct { - identifier *string - key string - order *string -} - -var ErrInvalidOrderClauseInput = errors.New("input string is empty or only contains quote characters") - -const ( - identifierAndKeyLength = 2 - startTime = "start_time" - name = "name" - attribute = "attribute" - metric = "metric" -) - -func orderByKeyAlias(input string) string { - switch input { - case "created", "Created": - return startTime - case "run_name", "run name", "Run name", "Run Name": - return name - case "run_id": - return "run_uuid" - default: - return input - } -} - -func handleInsideQuote( - char, quoteChar rune, insideQuote bool, current strings.Builder, result []string, -) (bool, strings.Builder, []string) { - if char == quoteChar { - insideQuote = false - - result = append(result, current.String()) - current.Reset() - } else { - current.WriteRune(char) - } - - return insideQuote, current, result -} - -func handleOutsideQuote( - char rune, insideQuote bool, quoteChar rune, current strings.Builder, result []string, -) (bool, rune, strings.Builder, []string) { - switch char { - case ' ': - if current.Len() > 0 { - result = append(result, current.String()) - current.Reset() - } - case '"', '\'', '`': - insideQuote = true - quoteChar = char - default: - current.WriteRune(char) - } - - return insideQuote, quoteChar, current, result -} - -// Process an order by input string to split the string into the separate parts. -// We can't simply split by space, because the column name could be wrapped in quotes, e.g. "Run name" ASC. 
-func splitOrderByClauseWithQuotes(input string) []string { - input = strings.ToLower(strings.Trim(input, " ")) - - var result []string - - var current strings.Builder - - var insideQuote bool - - var quoteChar rune - - // Process char per char, split items on spaces unless inside a quoted entry. - for _, char := range input { - if insideQuote { - insideQuote, current, result = handleInsideQuote(char, quoteChar, insideQuote, current, result) - } else { - insideQuote, quoteChar, current, result = handleOutsideQuote(char, insideQuote, quoteChar, current, result) - } - } - - if current.Len() > 0 { - result = append(result, current.String()) - } - - return result -} - -func translateIdentifierAlias(identifier string) string { - switch strings.ToLower(identifier) { - case "metrics": - return metric - case "parameters", "param", "params": - return "parameter" - case "tags": - return "tag" - case "attr", "attributes", "run": - return attribute - case "datasets": - return "dataset" - default: - return identifier - } -} - -func processOrderByClause(input string) (orderByExpr, error) { - parts := splitOrderByClauseWithQuotes(input) - - if len(parts) == 0 { - return orderByExpr{}, ErrInvalidOrderClauseInput - } - - var expr orderByExpr - - identifierKey := strings.Split(parts[0], ".") - - if len(identifierKey) == identifierAndKeyLength { - expr.identifier = utils.PtrTo(translateIdentifierAlias(identifierKey[0])) - expr.key = orderByKeyAlias(identifierKey[1]) - } else if len(identifierKey) == 1 { - expr.key = orderByKeyAlias(identifierKey[0]) - } - - if len(parts) > 1 { - expr.order = utils.PtrTo(strings.ToUpper(parts[1])) - } - - return expr, nil -} - -//nolint:funlen, cyclop, gocognit -func applyOrderBy(ctx context.Context, database, transaction *gorm.DB, orderBy []string) *contract.Error { - startTimeOrder := false - columnSelection := "runs.*" - - for index, orderByClause := range orderBy { - orderByExpr, err := processOrderByClause(orderByClause) - if err != nil { - 
return contract.NewError( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - fmt.Sprintf( - "invalid order_by clause %q.", - orderByClause, - ), - ) - } - - logger := utils.GetLoggerFromContext(ctx) - logger. - Debugf( - "OrderByExpr: identifier: %v, key: %v, order: %v", - utils.DumpStringPointer(orderByExpr.identifier), - orderByExpr.key, - utils.DumpStringPointer(orderByExpr.order), - ) - - var kind any - - if orderByExpr.identifier == nil && orderByExpr.key == "start_time" { - startTimeOrder = true - } else if orderByExpr.identifier != nil { - switch { - case *orderByExpr.identifier == attribute && orderByExpr.key == "start_time": - startTimeOrder = true - case *orderByExpr.identifier == metric: - kind = &models.LatestMetric{} - case *orderByExpr.identifier == "parameter": - kind = &models.Param{} - case *orderByExpr.identifier == "tag": - kind = &models.Tag{} - } - } - - table := fmt.Sprintf("order_%d", index) - - if kind != nil { - columnsInJoin := []string{"run_uuid", "value"} - if *orderByExpr.identifier == metric { - columnsInJoin = append(columnsInJoin, "is_nan") - } - - transaction.Joins( - fmt.Sprintf("LEFT OUTER JOIN (?) AS %s ON runs.run_uuid = %s.run_uuid", table, table), - database.Select(columnsInJoin).Where("key = ?", orderByExpr.key).Model(kind), - ) - - orderByExpr.key = table + ".value" - } - - desc := false - if orderByExpr.order != nil { - desc = *orderByExpr.order == "DESC" - } - - nullableColumnAlias := fmt.Sprintf("order_null_%d", index) - - if orderByExpr.identifier == nil || *orderByExpr.identifier != metric { - var originalColumn string - - switch { - case orderByExpr.identifier != nil && *orderByExpr.identifier == "attribute": - originalColumn = "runs." 
+ orderByExpr.key - case orderByExpr.identifier != nil: - originalColumn = table + ".value" - default: - originalColumn = orderByExpr.key - } - - columnSelection = fmt.Sprintf( - "%s, (CASE WHEN (%s IS NULL) THEN 1 ELSE 0 END) AS %s", - columnSelection, - originalColumn, - nullableColumnAlias, - ) - - transaction.Order(nullableColumnAlias) - } - - // the metric table has the is_nan column - if orderByExpr.identifier != nil && *orderByExpr.identifier == metric { - trueColumnValue := "true" - if database.Dialector.Name() == "sqlite" { - trueColumnValue = "1" - } - - columnSelection = fmt.Sprintf( - "%s, (CASE WHEN (%s.is_nan = %s) THEN 1 WHEN (%s.value IS NULL) THEN 2 ELSE 0 END) AS %s", - columnSelection, - table, - trueColumnValue, - table, - nullableColumnAlias, - ) - - transaction.Order(nullableColumnAlias) - } - - transaction.Order(clause.OrderByColumn{ - Column: clause.Column{ - Name: orderByExpr.key, - }, - Desc: desc, - }) - } - - if !startTimeOrder { - transaction.Order("runs.start_time DESC") - } - - transaction.Order("runs.run_uuid") - - // mlflow orders all nullable columns to have null last. - // For each order by clause, an additional dynamic order clause was added. - // We need to include these columns in the select clause. 
- transaction.Select(columnSelection) - - return nil -} - -func mkNextPageToken(runLength, maxResults, offset int) (string, *contract.Error) { - var nextPageToken string - - if runLength == maxResults { - var token strings.Builder - if err := json.NewEncoder( - base64.NewEncoder(base64.StdEncoding, &token), - ).Encode(PageToken{ - Offset: int32(offset + maxResults), - }); err != nil { - return "", contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - "error encoding 'nextPageToken' value", - err, - ) - } - - nextPageToken = token.String() - } - - return nextPageToken, nil -} - -//nolint:funlen -func (s TrackingSQLStore) SearchRuns( - ctx context.Context, - experimentIDs []string, filter string, - runViewType protos.ViewType, maxResults int, orderBy []string, pageToken string, -) ([]*entities.Run, string, *contract.Error) { - // ViewType - lifecyleStages := getLifecyleStages(runViewType) - transaction := s.db.WithContext(ctx).Where( - "runs.experiment_id IN ?", experimentIDs, - ).Where( - "runs.lifecycle_stage IN ?", lifecyleStages, - ) - - // MaxResults - transaction.Limit(maxResults) - - // PageToken - offset, contractError := getOffset(pageToken) - if contractError != nil { - return nil, "", contractError - } - - transaction.Offset(offset) - - // Filter - contractError = applyFilter(ctx, s.db, transaction, filter) - if contractError != nil { - return nil, "", contractError - } - - // OrderBy - contractError = applyOrderBy(ctx, s.db, transaction, orderBy) - if contractError != nil { - return nil, "", contractError - } - - // Actual query - var runs []models.Run - - transaction.Preload("LatestMetrics").Preload("Params").Preload("Tags"). - Preload("Inputs", "inputs.destination_type = 'RUN'"). 
- Preload("Inputs.Dataset").Preload("Inputs.Tags").Find(&runs) - - if transaction.Error != nil { - return nil, "", contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - "Failed to query search runs", - transaction.Error, - ) - } - - entityRuns := make([]*entities.Run, len(runs)) - for i, run := range runs { - entityRuns[i] = run.ToEntity() - } - - nextPageToken, contractError := mkNextPageToken(len(runs), maxResults, offset) - if contractError != nil { - return nil, "", contractError - } - - return entityRuns, nextPageToken, nil -} - -const RunIDMaxLength = 32 - -const ( - ArtifactFolderName = "artifacts" - RunNameIntegerScale = 3 - RunNameMaxLength = 20 -) - -func getRunNameFromTags(tags []models.Tag) string { - for _, tag := range tags { - if tag.Key == utils.TagRunName { - return tag.Value - } - } - - return "" -} - -func ensureRunName(runModel *models.Run) *contract.Error { - runNameFromTags := getRunNameFromTags(runModel.Tags) - - switch { - // run_name and name in tags differ - case utils.IsNotNilOrEmptyString(&runModel.Name) && runNameFromTags != "" && runModel.Name != runNameFromTags: - return contract.NewError( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - fmt.Sprintf( - "Both 'run_name' argument and 'mlflow.runName' tag are specified, but with "+ - "different values (run_name='%s', mlflow.runName='%s').", - runModel.Name, - runNameFromTags, - ), - ) - // no name was provided, generate a random name - case utils.IsNilOrEmptyString(&runModel.Name) && runNameFromTags == "": - randomName, err := utils.GenerateRandomName() - if err != nil { - return contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - "failed to generate random run name", - err, - ) - } - - runModel.Name = randomName - // use name from tags - case utils.IsNilOrEmptyString(&runModel.Name) && runNameFromTags != "": - runModel.Name = runNameFromTags - } - - if runNameFromTags == "" { - runModel.Tags = append(runModel.Tags, models.Tag{ - Key: utils.TagRunName, - Value: runModel.Name, 
- }) - } - - return nil -} - -func (s TrackingSQLStore) GetRun(ctx context.Context, runID string) (*entities.Run, *contract.Error) { - var run models.Run - if err := s.db.WithContext(ctx).Where( - "run_uuid = ?", runID, - ).Preload( - "Tags", - ).Preload( - "Params", - ).Preload( - "Inputs.Tags", - ).Preload( - "LatestMetrics", - ).Preload( - "Inputs.Dataset", - ).First(&run).Error; err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, contract.NewError( - protos.ErrorCode_RESOURCE_DOES_NOT_EXIST, - fmt.Sprintf("Run with id=%s not found", runID), - ) - } - - return nil, contract.NewErrorWith(protos.ErrorCode_INTERNAL_ERROR, "failed to get run", err) - } - - return run.ToEntity(), nil -} - -//nolint:funlen -func (s TrackingSQLStore) CreateRun( - ctx context.Context, - experimentID, userID string, - startTime int64, - tags []*entities.RunTag, - runName string, -) (*entities.Run, *contract.Error) { - experiment, err := s.GetExperiment(ctx, experimentID) - if err != nil { - return nil, err - } - - if models.LifecycleStage(experiment.LifecycleStage) != models.LifecycleStageActive { - return nil, contract.NewError( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - fmt.Sprintf( - "The experiment %q must be in the 'active' state.\n"+ - "Current state is %q.", - experiment.ExperimentID, - experiment.LifecycleStage, - ), - ) - } - - runModel := &models.Run{ - ID: utils.NewUUID(), - Name: runName, - ExperimentID: utils.ConvertStringPointerToInt32Pointer(&experimentID), - StartTime: startTime, - UserID: userID, - Tags: make([]models.Tag, 0, len(tags)), - LifecycleStage: models.LifecycleStageActive, - Status: models.RunStatusRunning, - SourceType: models.SourceTypeUnknown, - } - - for _, tag := range tags { - runModel.Tags = append(runModel.Tags, models.NewTagFromEntity(runModel.ID, tag)) - } - - artifactLocation, appendErr := utils.AppendToURIPath( - experiment.ArtifactLocation, - runModel.ID, - ArtifactFolderName, - ) - if appendErr != nil { - return nil, 
contract.NewError( - protos.ErrorCode_INTERNAL_ERROR, - "failed to append run ID to experiment artifact location", - ) - } - - runModel.ArtifactURI = artifactLocation - - errRunName := ensureRunName(runModel) - if errRunName != nil { - return nil, errRunName - } - - if err := s.db.Create(&runModel).Error; err != nil { - return nil, contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - fmt.Sprintf( - "failed to create run for experiment_id %q", - experiment.ExperimentID, - ), - err, - ) - } - - return runModel.ToEntity(), nil -} - -func (s TrackingSQLStore) UpdateRun( - ctx context.Context, - runID string, - runStatus string, - endTime *int64, - runName string, -) *contract.Error { - runTag, err := s.GetRunTag(ctx, runID, utils.TagRunName) - if err != nil { - return err - } - - tags := make([]models.Tag, 0, 1) - if runTag == nil { - tags = append(tags, models.Tag{ - RunID: runID, - Key: utils.TagRunName, - Value: runName, - }) - } - - var endTimeValue sql.NullInt64 - if endTime == nil { - endTimeValue = sql.NullInt64{} - } else { - endTimeValue = sql.NullInt64{Int64: *endTime, Valid: true} - } - - if err := s.db.WithContext(ctx).Transaction(func(transaction *gorm.DB) error { - if err := transaction.Model(&models.Run{}). - Where("run_uuid = ?", runID). 
- Updates(&models.Run{ - Name: runName, - Status: models.RunStatus(runStatus), - EndTime: endTimeValue, - }).Error; err != nil { - return err - } - - if len(tags) > 0 { - if err := transaction.Clauses(clause.OnConflict{ - UpdateAll: true, - }).CreateInBatches(tags, tagsBatchSize).Error; err != nil { - return fmt.Errorf("failed to create tags for run %q: %w", runID, err) - } - } - - return nil - }); err != nil { - return contract.NewErrorWith(protos.ErrorCode_INTERNAL_ERROR, "failed to update run", err) - } - - return nil -} - -func (s TrackingSQLStore) DeleteRun(ctx context.Context, runID string) *contract.Error { - run, err := s.GetRun(ctx, runID) - if err != nil { - return err - } - - if err := s.db.WithContext(ctx).Model(&models.Run{}). - Where("run_uuid = ?", run.Info.RunID). - Updates(&models.Run{ - DeletedTime: sql.NullInt64{Valid: true, Int64: time.Now().UnixMilli()}, - LifecycleStage: models.LifecycleStageDeleted, - }).Error; err != nil { - return contract.NewErrorWith(protos.ErrorCode_INTERNAL_ERROR, "failed to delete run", err) - } - - return nil -} - -func (s TrackingSQLStore) RestoreRun(ctx context.Context, runID string) *contract.Error { - run, err := s.GetRun(ctx, runID) - if err != nil { - return err - } - - if err := s.db.WithContext(ctx).Model(&models.Run{}). - Where("run_uuid = ?", run.Info.RunID). - // Force GORM to update fields with zero values by selecting them. - Select("DeletedTime", "LifecycleStage"). 
- Updates(&models.Run{ - DeletedTime: sql.NullInt64{}, - LifecycleStage: models.LifecycleStageActive, - }).Error; err != nil { - return contract.NewErrorWith(protos.ErrorCode_INTERNAL_ERROR, "failed to restore run", err) - } - - return nil -} - -func (s TrackingSQLStore) LogBatch( - ctx context.Context, runID string, metrics []*entities.Metric, params []*entities.Param, tags []*entities.RunTag, -) *contract.Error { - err := s.db.WithContext(ctx).Transaction(func(transaction *gorm.DB) error { - contractError := checkRunIsActive(transaction, runID) - if contractError != nil { - return contractError - } - - err := s.setTagsWithTransaction(transaction, runID, tags) - if err != nil { - return fmt.Errorf("error setting tags for run_id %q: %w", runID, err) - } - - contractError = s.logParamsWithTransaction(transaction, runID, params) - if contractError != nil { - return contractError - } - - contractError = s.logMetricsWithTransaction(transaction, runID, metrics) - if contractError != nil { - return contractError - } - - return nil - }) - if err != nil { - var contractError *contract.Error - if errors.As(err, &contractError) { - return contractError - } - - return contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - fmt.Sprintf("log batch transaction failed for %q", runID), - err, - ) - } - - return nil -} +package sql + +import ( + "context" + "database/sql" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "strings" + "time" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + + "github.com/mlflow/mlflow-go/pkg/contract" + "github.com/mlflow/mlflow-go/pkg/entities" + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/tracking/service/query" + "github.com/mlflow/mlflow-go/pkg/tracking/service/query/parser" + "github.com/mlflow/mlflow-go/pkg/tracking/store/sql/models" + "github.com/mlflow/mlflow-go/pkg/utils" +) + +type PageToken struct { + Offset int32 `json:"offset"` +} + +func checkRunIsActive(transaction *gorm.DB, runID string) 
*contract.Error { + var run models.Run + + err := transaction. + Model(&models.Run{}). + Where("run_uuid = ?", runID). + Select("lifecycle_stage"). + First(&run). + Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return contract.NewError( + protos.ErrorCode_RESOURCE_DOES_NOT_EXIST, + fmt.Sprintf("Run with id=%s not found", runID), + ) + } + + return contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf( + "failed to get lifecycle stage for run %q", + runID, + ), + err, + ) + } + + if run.LifecycleStage != models.LifecycleStageActive { + return contract.NewError( + protos.ErrorCode_INVALID_PARAMETER_VALUE, + fmt.Sprintf( + "The run %s must be in the 'active' state.\n"+ + "Current state is %v.", + runID, + run.LifecycleStage, + ), + ) + } + + return nil +} + +func getLifecyleStages(runViewType protos.ViewType) []models.LifecycleStage { + switch runViewType { + case protos.ViewType_ACTIVE_ONLY: + return []models.LifecycleStage{ + models.LifecycleStageActive, + } + case protos.ViewType_DELETED_ONLY: + return []models.LifecycleStage{ + models.LifecycleStageDeleted, + } + case protos.ViewType_ALL: + return []models.LifecycleStage{ + models.LifecycleStageActive, + models.LifecycleStageDeleted, + } + } + + return []models.LifecycleStage{ + models.LifecycleStageActive, + models.LifecycleStageDeleted, + } +} + +func getOffset(pageToken string) (int, *contract.Error) { + if pageToken != "" { + var token PageToken + if err := json.NewDecoder( + base64.NewDecoder( + base64.StdEncoding, + strings.NewReader(pageToken), + ), + ).Decode(&token); err != nil { + return 0, contract.NewErrorWith( + protos.ErrorCode_INVALID_PARAMETER_VALUE, + fmt.Sprintf("invalid page_token: %q", pageToken), + err, + ) + } + + return int(token.Offset), nil + } + + return 0, nil +} + +//nolint:funlen,cyclop,gocognit +func applyFilter(ctx context.Context, database, transaction *gorm.DB, filter string) *contract.Error { + filterConditions, err := 
query.ParseFilter(filter) + if err != nil { + return contract.NewErrorWith( + protos.ErrorCode_INVALID_PARAMETER_VALUE, + "error parsing search filter", + err, + ) + } + + utils.GetLoggerFromContext(ctx).Debugf("Filter conditions: %v", filterConditions) + + for index, clause := range filterConditions { + var kind any + + key := clause.Key + comparison := strings.ToUpper(clause.Operator.String()) + value := clause.Value + + switch clause.Identifier { + case parser.Metric: + kind = &models.LatestMetric{} + case parser.Parameter: + kind = &models.Param{} + case parser.Tag: + kind = &models.Tag{} + case parser.Dataset: + kind = &models.Dataset{} + case parser.Attribute: + kind = nil + } + + // Treat "attributes.run_name == " as "tags.`mlflow.runName` == ". + // The name column in the runs table is empty for runs logged in MLflow <= 1.29.0. + if key == "run_name" { + kind = &models.Tag{} + key = utils.TagRunName + } + + isSqliteAndILike := database.Dialector.Name() == "sqlite" && comparison == "ILIKE" + table := fmt.Sprintf("filter_%d", index) + + switch { + case kind == nil: + if isSqliteAndILike { + key = fmt.Sprintf("LOWER(runs.%s)", key) + comparison = "LIKE" + + if str, ok := value.(string); ok { + value = strings.ToLower(str) + } + + transaction.Where(fmt.Sprintf("%s %s ?", key, comparison), value) + } else { + transaction.Where(fmt.Sprintf("runs.%s %s ?", key, comparison), value) + } + case clause.Identifier == parser.Dataset && key == "context": + // SELECT * + // FROM runs + // JOIN ( + // SELECT inputs.destination_id AS run_uuid + // FROM inputs + // JOIN input_tags + // ON inputs.input_uuid = input_tags.input_uuid + // AND input_tags.name = 'mlflow.data.context' + // AND input_tags.value %s ? 
+ // WHERE inputs.destination_type = 'RUN' + // ) AS filter_0 + // ON runs.run_uuid = filter_0.run_uuid + valueColumn := "input_tags.value " + if isSqliteAndILike { + valueColumn = "LOWER(input_tags.value) " + comparison = "LIKE" + + if str, ok := value.(string); ok { + value = strings.ToLower(str) + } + } + + transaction.Joins( + fmt.Sprintf("JOIN (?) AS %s ON runs.run_uuid = %s.run_uuid", table, table), + database.Select("inputs.destination_id AS run_uuid"). + Joins( + "JOIN input_tags ON inputs.input_uuid = input_tags.input_uuid"+ + " AND input_tags.name = 'mlflow.data.context'"+ + " AND "+valueColumn+comparison+" ?", + value, + ). + Where("inputs.destination_type = 'RUN'"). + Model(&models.Input{}), + ) + case clause.Identifier == parser.Dataset: + // add join with datasets + // JOIN ( + // SELECT "experiment_id", key + // FROM datasests d + // JOIN inputs ON inputs.source_id = datasets.dataset_uuid + // WHERE key comparison value + // ) AS filter_0 ON runs.experiment_id = dataset.experiment_id + // + // columns: name, digest, context + where := key + " " + comparison + " ?" + if isSqliteAndILike { + where = "LOWER(" + key + ") LIKE ?" + + if str, ok := value.(string); ok { + value = strings.ToLower(str) + } + } + + transaction.Joins( + fmt.Sprintf("JOIN (?) AS %s ON runs.run_uuid = %s.destination_id", table, table), + database.Model(kind). + Joins("JOIN inputs ON inputs.source_id = datasets.dataset_uuid"). + Where(where, value). + Select("destination_id", key), + ) + default: + where := fmt.Sprintf("value %s ?", comparison) + if isSqliteAndILike { + where = "LOWER(value) LIKE ?" + + if str, ok := value.(string); ok { + value = strings.ToLower(str) + } + } + + transaction.Joins( + fmt.Sprintf("JOIN (?) 
AS %s ON runs.run_uuid = %s.run_uuid", table, table), + database.Select("run_uuid", "value").Where("key = ?", key).Where(where, value).Model(kind), + ) + } + } + + return nil +} + +type orderByExpr struct { + identifier *string + key string + order *string +} + +var ErrInvalidOrderClauseInput = errors.New("input string is empty or only contains quote characters") + +const ( + identifierAndKeyLength = 2 + startTime = "start_time" + name = "name" + attribute = "attribute" + metric = "metric" +) + +func orderByKeyAlias(input string) string { + switch input { + case "created", "Created": + return startTime + case "run_name", "run name", "Run name", "Run Name": + return name + case "run_id": + return "run_uuid" + default: + return input + } +} + +func handleInsideQuote( + char, quoteChar rune, insideQuote bool, current strings.Builder, result []string, +) (bool, strings.Builder, []string) { + if char == quoteChar { + insideQuote = false + + result = append(result, current.String()) + current.Reset() + } else { + current.WriteRune(char) + } + + return insideQuote, current, result +} + +func handleOutsideQuote( + char rune, insideQuote bool, quoteChar rune, current strings.Builder, result []string, +) (bool, rune, strings.Builder, []string) { + switch char { + case ' ': + if current.Len() > 0 { + result = append(result, current.String()) + current.Reset() + } + case '"', '\'', '`': + insideQuote = true + quoteChar = char + default: + current.WriteRune(char) + } + + return insideQuote, quoteChar, current, result +} + +// Process an order by input string to split the string into the separate parts. +// We can't simply split by space, because the column name could be wrapped in quotes, e.g. "Run name" ASC. 
+func splitOrderByClauseWithQuotes(input string) []string { + input = strings.ToLower(strings.Trim(input, " ")) + + var result []string + + var current strings.Builder + + var insideQuote bool + + var quoteChar rune + + // Process char per char, split items on spaces unless inside a quoted entry. + for _, char := range input { + if insideQuote { + insideQuote, current, result = handleInsideQuote(char, quoteChar, insideQuote, current, result) + } else { + insideQuote, quoteChar, current, result = handleOutsideQuote(char, insideQuote, quoteChar, current, result) + } + } + + if current.Len() > 0 { + result = append(result, current.String()) + } + + return result +} + +func translateIdentifierAlias(identifier string) string { + switch strings.ToLower(identifier) { + case "metrics": + return metric + case "parameters", "param", "params": + return "parameter" + case "tags": + return "tag" + case "attr", "attributes", "run": + return attribute + case "datasets": + return "dataset" + default: + return identifier + } +} + +func processOrderByClause(input string) (orderByExpr, error) { + parts := splitOrderByClauseWithQuotes(input) + + if len(parts) == 0 { + return orderByExpr{}, ErrInvalidOrderClauseInput + } + + var expr orderByExpr + + identifierKey := strings.Split(parts[0], ".") + + if len(identifierKey) == identifierAndKeyLength { + expr.identifier = utils.PtrTo(translateIdentifierAlias(identifierKey[0])) + expr.key = orderByKeyAlias(identifierKey[1]) + } else if len(identifierKey) == 1 { + expr.key = orderByKeyAlias(identifierKey[0]) + } + + if len(parts) > 1 { + expr.order = utils.PtrTo(strings.ToUpper(parts[1])) + } + + return expr, nil +} + +//nolint:funlen, cyclop, gocognit +func applyOrderBy(ctx context.Context, database, transaction *gorm.DB, orderBy []string) *contract.Error { + startTimeOrder := false + columnSelection := "runs.*" + + for index, orderByClause := range orderBy { + orderByExpr, err := processOrderByClause(orderByClause) + if err != nil { + 
return contract.NewError( + protos.ErrorCode_INVALID_PARAMETER_VALUE, + fmt.Sprintf( + "invalid order_by clause %q.", + orderByClause, + ), + ) + } + + logger := utils.GetLoggerFromContext(ctx) + logger. + Debugf( + "OrderByExpr: identifier: %v, key: %v, order: %v", + utils.DumpStringPointer(orderByExpr.identifier), + orderByExpr.key, + utils.DumpStringPointer(orderByExpr.order), + ) + + var kind any + + if orderByExpr.identifier == nil && orderByExpr.key == "start_time" { + startTimeOrder = true + } else if orderByExpr.identifier != nil { + switch { + case *orderByExpr.identifier == attribute && orderByExpr.key == "start_time": + startTimeOrder = true + case *orderByExpr.identifier == metric: + kind = &models.LatestMetric{} + case *orderByExpr.identifier == "parameter": + kind = &models.Param{} + case *orderByExpr.identifier == "tag": + kind = &models.Tag{} + } + } + + table := fmt.Sprintf("order_%d", index) + + if kind != nil { + columnsInJoin := []string{"run_uuid", "value"} + if *orderByExpr.identifier == metric { + columnsInJoin = append(columnsInJoin, "is_nan") + } + + transaction.Joins( + fmt.Sprintf("LEFT OUTER JOIN (?) AS %s ON runs.run_uuid = %s.run_uuid", table, table), + database.Select(columnsInJoin).Where("key = ?", orderByExpr.key).Model(kind), + ) + + orderByExpr.key = table + ".value" + } + + desc := false + if orderByExpr.order != nil { + desc = *orderByExpr.order == "DESC" + } + + nullableColumnAlias := fmt.Sprintf("order_null_%d", index) + + if orderByExpr.identifier == nil || *orderByExpr.identifier != metric { + var originalColumn string + + switch { + case orderByExpr.identifier != nil && *orderByExpr.identifier == "attribute": + originalColumn = "runs." 
+ orderByExpr.key + case orderByExpr.identifier != nil: + originalColumn = table + ".value" + default: + originalColumn = orderByExpr.key + } + + columnSelection = fmt.Sprintf( + "%s, (CASE WHEN (%s IS NULL) THEN 1 ELSE 0 END) AS %s", + columnSelection, + originalColumn, + nullableColumnAlias, + ) + + transaction.Order(nullableColumnAlias) + } + + // the metric table has the is_nan column + if orderByExpr.identifier != nil && *orderByExpr.identifier == metric { + trueColumnValue := "true" + if database.Dialector.Name() == "sqlite" { + trueColumnValue = "1" + } + + columnSelection = fmt.Sprintf( + "%s, (CASE WHEN (%s.is_nan = %s) THEN 1 WHEN (%s.value IS NULL) THEN 2 ELSE 0 END) AS %s", + columnSelection, + table, + trueColumnValue, + table, + nullableColumnAlias, + ) + + transaction.Order(nullableColumnAlias) + } + + transaction.Order(clause.OrderByColumn{ + Column: clause.Column{ + Name: orderByExpr.key, + }, + Desc: desc, + }) + } + + if !startTimeOrder { + transaction.Order("runs.start_time DESC") + } + + transaction.Order("runs.run_uuid") + + // mlflow orders all nullable columns to have null last. + // For each order by clause, an additional dynamic order clause was added. + // We need to include these columns in the select clause. 
+ transaction.Select(columnSelection) + + return nil +} + +func mkNextPageToken(runLength, maxResults, offset int) (string, *contract.Error) { + var nextPageToken string + + if runLength == maxResults { + var token strings.Builder + if err := json.NewEncoder( + base64.NewEncoder(base64.StdEncoding, &token), + ).Encode(PageToken{ + Offset: int32(offset + maxResults), + }); err != nil { + return "", contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + "error encoding 'nextPageToken' value", + err, + ) + } + + nextPageToken = token.String() + } + + return nextPageToken, nil +} + +//nolint:funlen +func (s TrackingSQLStore) SearchRuns( + ctx context.Context, + experimentIDs []string, filter string, + runViewType protos.ViewType, maxResults int, orderBy []string, pageToken string, +) ([]*entities.Run, string, *contract.Error) { + // ViewType + lifecyleStages := getLifecyleStages(runViewType) + transaction := s.db.WithContext(ctx).Where( + "runs.experiment_id IN ?", experimentIDs, + ).Where( + "runs.lifecycle_stage IN ?", lifecyleStages, + ) + + // MaxResults + transaction.Limit(maxResults) + + // PageToken + offset, contractError := getOffset(pageToken) + if contractError != nil { + return nil, "", contractError + } + + transaction.Offset(offset) + + // Filter + contractError = applyFilter(ctx, s.db, transaction, filter) + if contractError != nil { + return nil, "", contractError + } + + // OrderBy + contractError = applyOrderBy(ctx, s.db, transaction, orderBy) + if contractError != nil { + return nil, "", contractError + } + + // Actual query + var runs []models.Run + + transaction.Preload("LatestMetrics").Preload("Params").Preload("Tags"). + Preload("Inputs", "inputs.destination_type = 'RUN'"). 
+ Preload("Inputs.Dataset").Preload("Inputs.Tags").Find(&runs) + + if transaction.Error != nil { + return nil, "", contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + "Failed to query search runs", + transaction.Error, + ) + } + + entityRuns := make([]*entities.Run, len(runs)) + for i, run := range runs { + entityRuns[i] = run.ToEntity() + } + + nextPageToken, contractError := mkNextPageToken(len(runs), maxResults, offset) + if contractError != nil { + return nil, "", contractError + } + + return entityRuns, nextPageToken, nil +} + +const RunIDMaxLength = 32 + +const ( + ArtifactFolderName = "artifacts" + RunNameIntegerScale = 3 + RunNameMaxLength = 20 +) + +func getRunNameFromTags(tags []models.Tag) string { + for _, tag := range tags { + if tag.Key == utils.TagRunName { + return tag.Value + } + } + + return "" +} + +func ensureRunName(runModel *models.Run) *contract.Error { + runNameFromTags := getRunNameFromTags(runModel.Tags) + + switch { + // run_name and name in tags differ + case utils.IsNotNilOrEmptyString(&runModel.Name) && runNameFromTags != "" && runModel.Name != runNameFromTags: + return contract.NewError( + protos.ErrorCode_INVALID_PARAMETER_VALUE, + fmt.Sprintf( + "Both 'run_name' argument and 'mlflow.runName' tag are specified, but with "+ + "different values (run_name='%s', mlflow.runName='%s').", + runModel.Name, + runNameFromTags, + ), + ) + // no name was provided, generate a random name + case utils.IsNilOrEmptyString(&runModel.Name) && runNameFromTags == "": + randomName, err := utils.GenerateRandomName() + if err != nil { + return contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + "failed to generate random run name", + err, + ) + } + + runModel.Name = randomName + // use name from tags + case utils.IsNilOrEmptyString(&runModel.Name) && runNameFromTags != "": + runModel.Name = runNameFromTags + } + + if runNameFromTags == "" { + runModel.Tags = append(runModel.Tags, models.Tag{ + Key: utils.TagRunName, + Value: runModel.Name, 
+ }) + } + + return nil +} + +func (s TrackingSQLStore) GetRun(ctx context.Context, runID string) (*entities.Run, *contract.Error) { + var run models.Run + if err := s.db.WithContext(ctx).Where( + "run_uuid = ?", runID, + ).Preload( + "Tags", + ).Preload( + "Params", + ).Preload( + "Inputs.Tags", + ).Preload( + "LatestMetrics", + ).Preload( + "Inputs.Dataset", + ).First(&run).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, contract.NewError( + protos.ErrorCode_RESOURCE_DOES_NOT_EXIST, + fmt.Sprintf("Run with id=%s not found", runID), + ) + } + + return nil, contract.NewErrorWith(protos.ErrorCode_INTERNAL_ERROR, "failed to get run", err) + } + + return run.ToEntity(), nil +} + +//nolint:funlen +func (s TrackingSQLStore) CreateRun( + ctx context.Context, + experimentID, userID string, + startTime int64, + tags []*entities.RunTag, + runName string, +) (*entities.Run, *contract.Error) { + experiment, err := s.GetExperiment(ctx, experimentID) + if err != nil { + return nil, err + } + + if models.LifecycleStage(experiment.LifecycleStage) != models.LifecycleStageActive { + return nil, contract.NewError( + protos.ErrorCode_INVALID_PARAMETER_VALUE, + fmt.Sprintf( + "The experiment %q must be in the 'active' state.\n"+ + "Current state is %q.", + experiment.ExperimentID, + experiment.LifecycleStage, + ), + ) + } + + runModel := &models.Run{ + ID: utils.NewUUID(), + Name: runName, + ExperimentID: utils.ConvertStringPointerToInt32Pointer(&experimentID), + StartTime: startTime, + UserID: userID, + Tags: make([]models.Tag, 0, len(tags)), + LifecycleStage: models.LifecycleStageActive, + Status: models.RunStatusRunning, + SourceType: models.SourceTypeUnknown, + } + + for _, tag := range tags { + runModel.Tags = append(runModel.Tags, models.NewTagFromEntity(runModel.ID, tag)) + } + + artifactLocation, appendErr := utils.AppendToURIPath( + experiment.ArtifactLocation, + runModel.ID, + ArtifactFolderName, + ) + if appendErr != nil { + return nil, 
contract.NewError( + protos.ErrorCode_INTERNAL_ERROR, + "failed to append run ID to experiment artifact location", + ) + } + + runModel.ArtifactURI = artifactLocation + + errRunName := ensureRunName(runModel) + if errRunName != nil { + return nil, errRunName + } + + if err := s.db.Create(&runModel).Error; err != nil { + return nil, contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf( + "failed to create run for experiment_id %q", + experiment.ExperimentID, + ), + err, + ) + } + + return runModel.ToEntity(), nil +} + +func (s TrackingSQLStore) UpdateRun( + ctx context.Context, + runID string, + runStatus string, + endTime *int64, + runName string, +) *contract.Error { + runTag, err := s.GetRunTag(ctx, runID, utils.TagRunName) + if err != nil { + return err + } + + tags := make([]models.Tag, 0, 1) + if runTag == nil { + tags = append(tags, models.Tag{ + RunID: runID, + Key: utils.TagRunName, + Value: runName, + }) + } + + var endTimeValue sql.NullInt64 + if endTime == nil { + endTimeValue = sql.NullInt64{} + } else { + endTimeValue = sql.NullInt64{Int64: *endTime, Valid: true} + } + + if err := s.db.WithContext(ctx).Transaction(func(transaction *gorm.DB) error { + if err := transaction.Model(&models.Run{}). + Where("run_uuid = ?", runID). 
+ Updates(&models.Run{ + Name: runName, + Status: models.RunStatus(runStatus), + EndTime: endTimeValue, + }).Error; err != nil { + return err + } + + if len(tags) > 0 { + if err := transaction.Clauses(clause.OnConflict{ + UpdateAll: true, + }).CreateInBatches(tags, tagsBatchSize).Error; err != nil { + return fmt.Errorf("failed to create tags for run %q: %w", runID, err) + } + } + + return nil + }); err != nil { + return contract.NewErrorWith(protos.ErrorCode_INTERNAL_ERROR, "failed to update run", err) + } + + return nil +} + +func (s TrackingSQLStore) DeleteRun(ctx context.Context, runID string) *contract.Error { + run, err := s.GetRun(ctx, runID) + if err != nil { + return err + } + + if err := s.db.WithContext(ctx).Model(&models.Run{}). + Where("run_uuid = ?", run.Info.RunID). + Updates(&models.Run{ + DeletedTime: sql.NullInt64{Valid: true, Int64: time.Now().UnixMilli()}, + LifecycleStage: models.LifecycleStageDeleted, + }).Error; err != nil { + return contract.NewErrorWith(protos.ErrorCode_INTERNAL_ERROR, "failed to delete run", err) + } + + return nil +} + +func (s TrackingSQLStore) RestoreRun(ctx context.Context, runID string) *contract.Error { + run, err := s.GetRun(ctx, runID) + if err != nil { + return err + } + + if err := s.db.WithContext(ctx).Model(&models.Run{}). + Where("run_uuid = ?", run.Info.RunID). + // Force GORM to update fields with zero values by selecting them. + Select("DeletedTime", "LifecycleStage"). 
+ Updates(&models.Run{ + DeletedTime: sql.NullInt64{}, + LifecycleStage: models.LifecycleStageActive, + }).Error; err != nil { + return contract.NewErrorWith(protos.ErrorCode_INTERNAL_ERROR, "failed to restore run", err) + } + + return nil +} + +func (s TrackingSQLStore) LogBatch( + ctx context.Context, runID string, metrics []*entities.Metric, params []*entities.Param, tags []*entities.RunTag, +) *contract.Error { + err := s.db.WithContext(ctx).Transaction(func(transaction *gorm.DB) error { + contractError := checkRunIsActive(transaction, runID) + if contractError != nil { + return contractError + } + + err := s.setTagsWithTransaction(transaction, runID, tags) + if err != nil { + return fmt.Errorf("error setting tags for run_id %q: %w", runID, err) + } + + contractError = s.logParamsWithTransaction(transaction, runID, params) + if contractError != nil { + return contractError + } + + contractError = s.logMetricsWithTransaction(transaction, runID, metrics) + if contractError != nil { + return contractError + } + + return nil + }) + if err != nil { + var contractError *contract.Error + if errors.As(err, &contractError) { + return contractError + } + + return contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf("log batch transaction failed for %q", runID), + err, + ) + } + + return nil +} diff --git a/pkg/tracking/store/sql/runs_internal_test.go b/pkg/tracking/store/sql/runs_internal_test.go index 94cb3d4..f42583d 100644 --- a/pkg/tracking/store/sql/runs_internal_test.go +++ b/pkg/tracking/store/sql/runs_internal_test.go @@ -1,518 +1,518 @@ -//nolint:ireturn -package sql - -import ( - "context" - "reflect" - "regexp" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - "github.com/iancoleman/strcase" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "gorm.io/driver/mysql" - "gorm.io/driver/postgres" - "gorm.io/driver/sqlite" - "gorm.io/driver/sqlserver" - "gorm.io/gorm" - - 
"github.com/mlflow/mlflow-go/pkg/tracking/store/sql/models" - "github.com/mlflow/mlflow-go/pkg/utils" -) - -type testData struct { - name string - query string - orderBy []string - expectedSQL map[string]string - expectedVars []any -} - -var whitespaceRegex = regexp.MustCompile(`\s` + "|`") - -func removeWhitespace(s string) string { - return whitespaceRegex.ReplaceAllString(s, "") -} - -var tests = []testData{ - { - name: "SimpleMetricQuery", - query: "metrics.accuracy > 0.72", - expectedSQL: map[string]string{ - "postgres": ` - SELECT "run_uuid" FROM "runs" - JOIN (SELECT "run_uuid","value" FROM "latest_metrics" WHERE key = $1 AND value > $2) - AS filter_0 - ON runs.run_uuid = filter_0.run_uuid - ORDER BY runs.start_time DESC,runs.run_uuid`, - "sqlite": ` - SELECT run_uuid FROM runs - JOIN (SELECT run_uuid,value FROM latest_metrics WHERE key = ? AND value > ?) - AS filter_0 ON runs.run_uuid = filter_0.run_uuid - ORDER BY runs.start_time DESC,runs.run_uuid`, - "sqlserver": ` - SELECT "run_uuid" FROM "runs" - JOIN (SELECT "run_uuid","value" FROM "latest_metrics" WHERE key = @p1 AND value > @p2) - AS filter_0 ON runs.run_uuid = filter_0.run_uuid - ORDER BY runs.start_time DESC,runs.run_uuid`, - "mysql": ` - SELECT run_uuid FROM runs - JOIN (SELECT run_uuid,value FROM latest_metrics WHERE key = ? AND value > ?) 
- AS filter_0 ON runs.run_uuid = filter_0.run_uuid - ORDER BY runs.start_time DESC,runs.run_uuid`, - }, - expectedVars: []any{"accuracy", 0.72}, - }, - { - name: "SimpleMetricAndParamQuery", - query: "metrics.accuracy > 0.72 AND params.batch_size = '2'", - expectedSQL: map[string]string{ - "postgres": ` - SELECT "run_uuid" FROM "runs" - JOIN (SELECT "run_uuid","value" FROM "latest_metrics" WHERE key = $1 AND value > $2) - AS filter_0 ON runs.run_uuid = filter_0.run_uuid - JOIN (SELECT "run_uuid","value" FROM "params" WHERE key = $3 AND value = $4) - AS filter_1 ON runs.run_uuid = filter_1.run_uuid - ORDER BY runs.start_time DESC,runs.run_uuid`, - "sqlite": ` - SELECT run_uuid FROM runs - JOIN (SELECT run_uuid,value FROM latest_metrics WHERE key = ? AND value > ?) - AS filter_0 ON runs.run_uuid = filter_0.run_uuid - JOIN (SELECT run_uuid,value FROM params WHERE key = ? AND value = ?) - AS filter_1 ON runs.run_uuid = filter_1.run_uuid - ORDER BY runs.start_time DESC,runs.run_uuid`, - }, - expectedVars: []any{"accuracy", 0.72, "batch_size", "2"}, - }, - { - name: "TagQuery", - query: "tags.environment = 'notebook' AND tags.task ILIKE 'classif%'", - expectedSQL: map[string]string{ - "postgres": ` - SELECT "run_uuid" FROM "runs" - JOIN (SELECT "run_uuid","value" FROM "tags" WHERE key = $1 AND value = $2) - AS filter_0 ON runs.run_uuid = filter_0.run_uuid - JOIN (SELECT "run_uuid","value" FROM "tags" WHERE key = $3 AND value ILIKE $4) - AS filter_1 ON runs.run_uuid = filter_1.run_uuid - ORDER BY runs.start_time DESC,runs.run_uuid`, - "sqlite": ` - SELECT run_uuid FROM runs - JOIN (SELECT run_uuid,value FROM tags WHERE key = ? AND value = ?) - AS filter_0 ON runs.run_uuid = filter_0.run_uuid - JOIN (SELECT run_uuid,value FROM tags WHERE key = ? AND LOWER(value) LIKE ?) 
- AS filter_1 ON runs.run_uuid = filter_1.run_uuid - ORDER BY runs.start_time DESC,runs.run_uuid`, - }, - expectedVars: []any{"environment", "notebook", "task", "classif%"}, - }, - { - name: "DatasestsInQuery", - query: "datasets.digest IN ('s8ds293b', 'jks834s2')", - expectedSQL: map[string]string{ - "postgres": ` - SELECT "run_uuid" FROM "runs" - JOIN ( - SELECT destination_id,"digest" - FROM "datasets" JOIN inputs ON inputs.source_id = datasets.dataset_uuid - WHERE digest IN ($1,$2) - ) - AS filter_0 ON runs.run_uuid = filter_0.destination_id - ORDER BY runs.start_time DESC,runs.run_uuid`, - "sqlite": ` - SELECT run_uuid FROM runs - JOIN ( - SELECT destination_id,digest - FROM datasets JOIN inputs - ON inputs.source_id = datasets.dataset_uuid - WHERE digest IN (?,?) - ) - AS filter_0 - ON runs.run_uuid = filter_0.destination_id - ORDER BY runs.start_time DESC,runs.run_uuid`, - }, - expectedVars: []any{"s8ds293b", "jks834s2"}, - }, - { - name: "AttributesQuery", - query: "attributes.run_id = 'a1b2c3d4'", - expectedSQL: map[string]string{ - "postgres": ` - SELECT "run_uuid" FROM "runs" - WHERE runs.run_uuid = $1 - ORDER BY runs.start_time DESC,runs.run_uuid - `, - "sqlite": `SELECT run_uuid FROM runs WHERE runs.run_uuid = ? ORDER BY runs.start_time DESC,runs.run_uuid`, - }, - expectedVars: []any{"a1b2c3d4"}, - }, - { - name: "Run_nameQuery", - query: "attributes.run_name = 'my-run'", - expectedSQL: map[string]string{ - "postgres": ` - SELECT "run_uuid" FROM "runs" - JOIN (SELECT "run_uuid","value" FROM "tags" WHERE key = $1 AND value = $2) - AS filter_0 ON runs.run_uuid = filter_0.run_uuid - ORDER BY runs.start_time DESC,runs.run_uuid`, - "sqlite": ` - SELECT run_uuid FROM runs - JOIN (SELECT run_uuid,value FROM tags WHERE key = ? AND value = ?) 
- AS filter_0 ON runs.run_uuid = filter_0.run_uuid - ORDER BY runs.start_time DESC,runs.run_uuid`, - }, - expectedVars: []any{"mlflow.runName", "my-run"}, - }, - { - name: "DatasetsContextQuery", - query: "datasets.context = 'train'", - expectedSQL: map[string]string{ - "postgres": ` - SELECT "run_uuid" FROM "runs" - JOIN ( - SELECT inputs.destination_id AS run_uuid - FROM "inputs" - JOIN input_tags - ON inputs.input_uuid = input_tags.input_uuid - AND input_tags.name = 'mlflow.data.context' - AND input_tags.value = $1 - WHERE inputs.destination_type = 'RUN' - ) AS filter_0 ON runs.run_uuid = filter_0.run_uuid - ORDER BY runs.start_time DESC,runs.run_uuid`, - "sqlite": ` - SELECT run_uuid FROM runs - JOIN ( - SELECT inputs.destination_id AS run_uuid - FROM inputs - JOIN input_tags ON inputs.input_uuid = input_tags.input_uuid - AND input_tags.name = 'mlflow.data.context' - AND input_tags.value = ? WHERE inputs.destination_type = 'RUN' - ) AS filter_0 ON runs.run_uuid = filter_0.run_uuid - ORDER BY runs.start_time DESC,runs.run_uuid`, - }, - expectedVars: []any{"train"}, - }, - { - name: "Run_nameQuery", - query: "attributes.run_name ILIKE 'my-run%'", - expectedSQL: map[string]string{ - "postgres": ` - SELECT "run_uuid" FROM "runs" - JOIN (SELECT "run_uuid","value" FROM "tags" WHERE key = $1 AND value ILIKE $2) - AS filter_0 ON runs.run_uuid = filter_0.run_uuid - ORDER BY runs.start_time DESC,runs.run_uuid`, - "sqlite": ` - SELECT run_uuid FROM runs - JOIN (SELECT run_uuid, value FROM tags WHERE key = ? AND LOWER(value) LIKE ?) 
- AS filter_0 ON runs.run_uuid = filter_0.run_uuid - ORDER BY runs.start_time DESC,runs.run_uuid`, - }, - expectedVars: []any{"mlflow.runName", "my-run%"}, - }, - { - name: "DatasetsContextQuery", - query: "datasets.context ILIKE '%train'", - expectedSQL: map[string]string{ - "postgres": ` - SELECT "run_uuid" FROM "runs" - JOIN ( - SELECT inputs.destination_id AS run_uuid FROM "inputs" - JOIN input_tags ON inputs.input_uuid = input_tags.input_uuid - AND input_tags.name = 'mlflow.data.context' - AND input_tags.value ILIKE $1 WHERE inputs.destination_type = 'RUN' - ) AS filter_0 ON runs.run_uuid = filter_0.run_uuid - ORDER BY runs.start_time DESC,runs.run_uuid`, - "sqlite": ` - SELECT run_uuid FROM runs - JOIN ( - SELECT inputs.destination_id AS run_uuid FROM inputs - JOIN input_tags ON inputs.input_uuid = input_tags.input_uuid - AND input_tags.name = 'mlflow.data.context' - AND LOWER(input_tags.value) LIKE ? WHERE inputs.destination_type = 'RUN') - AS filter_0 ON runs.run_uuid = filter_0.run_uuid - ORDER BY runs.start_time DESC,runs.run_uuid - `, - }, - expectedVars: []any{"%train"}, - }, - { - name: "DatasestsDigest", - query: "datasets.digest ILIKE '%s'", - expectedSQL: map[string]string{ - "postgres": ` - SELECT "run_uuid" FROM "runs" - JOIN ( - SELECT destination_id,"digest" - FROM "datasets" - JOIN inputs ON inputs.source_id = datasets.dataset_uuid - WHERE digest ILIKE $1 - ) - AS filter_0 ON runs.run_uuid = filter_0.destination_id - ORDER BY runs.start_time DESC,runs.run_uuid`, - "sqlite": ` - SELECT run_uuid FROM runs - JOIN ( - SELECT destination_id,digest - FROM datasets - JOIN inputs ON inputs.source_id = datasets.dataset_uuid - WHERE LOWER(digest) LIKE ?) 
- AS filter_0 ON runs.run_uuid = filter_0.destination_id - ORDER BY runs.start_time DESC,runs.run_uuid`, - }, - expectedVars: []any{"%s"}, - }, - { - name: "ParamQuery", - query: "metrics.accuracy > 0.72 AND params.batch_size ILIKE '%a'", - expectedSQL: map[string]string{ - "postgres": ` - SELECT "run_uuid" FROM "runs" - JOIN (SELECT "run_uuid","value" FROM "latest_metrics" WHERE key = $1 AND value > $2) - AS filter_0 ON runs.run_uuid = filter_0.run_uuid - JOIN (SELECT "run_uuid","value" FROM "params" WHERE key = $3 AND value ILIKE $4) - AS filter_1 ON runs.run_uuid = filter_1.run_uuid - ORDER BY runs.start_time DESC,runs.run_uuid`, - "sqlite": ` - SELECT run_uuid FROM runs - JOIN (SELECT run_uuid, value FROM latest_metrics WHERE key = ? AND value > ?) - AS filter_0 ON runs.run_uuid = filter_0.run_uuid - JOIN (SELECT run_uuid,value FROM params WHERE key = ? AND LOWER(value) LIKE ?) - AS filter_1 ON runs.run_uuid = filter_1.run_uuid - ORDER BY runs.start_time DESC,runs.run_uuid - `, - }, - expectedVars: []any{"accuracy", 0.72, "batch_size", "%a"}, - }, - { - name: "OrderByStartTimeASC", - query: "", - orderBy: []string{"start_time ASC"}, - expectedSQL: map[string]string{ - "postgres": `SELECT "run_uuid" FROM "runs" ORDER BY order_null_0, "start_time",runs.run_uuid`, - }, - expectedVars: []any{}, - }, - { - name: "OrderByStatusDesc", - query: "", - expectedSQL: map[string]string{ - "postgres": `SELECT "run_uuid" FROM "runs" ORDER BY order_null_0, "status" DESC,runs.start_time DESC,runs.run_uuid`, - }, - orderBy: []string{"status DESC"}, - expectedVars: []any{}, - }, - { - name: "OrderByRunNameSnakeCase", - query: "", - expectedSQL: map[string]string{ - "postgres": `SELECT "run_uuid" FROM "runs" ORDER BY order_null_0, "name",runs.start_time DESC,runs.run_uuid`, - }, - orderBy: []string{"run_name"}, - expectedVars: []any{}, - }, - { - name: "OrderByRunNameLowerName", - query: "", - expectedSQL: map[string]string{ - "postgres": `SELECT "run_uuid" FROM "runs" ORDER BY 
order_null_0, "name",runs.start_time DESC,runs.run_uuid`, - }, - orderBy: []string{"`Run name`"}, - expectedVars: []any{}, - }, - { - name: "OrderByRunNamePascal", - query: "", - expectedSQL: map[string]string{ - "postgres": `SELECT "run_uuid" FROM "runs" ORDER BY order_null_0, "name",runs.start_time DESC,runs.run_uuid`, - }, - orderBy: []string{"`Run Name`"}, - expectedVars: []any{}, - }, -} - -func newPostgresDialector() gorm.Dialector { - mockedDB, _, _ := sqlmock.New() - - return postgres.New(postgres.Config{ - Conn: mockedDB, - DriverName: "postgres", - }) -} - -func newSqliteDialector() gorm.Dialector { - mockedDB, mock, _ := sqlmock.New() - mock.ExpectQuery("select sqlite_version()").WillReturnRows( - sqlmock.NewRows([]string{"sqlite_version()"}).AddRow("3.41.1")) - - return sqlite.New(sqlite.Config{ - DriverName: "sqlite3", - Conn: mockedDB, - }) -} - -func newSQLServerDialector() gorm.Dialector { - mockedDB, _, _ := sqlmock.New() - - return sqlserver.New(sqlserver.Config{ - DriverName: "sqlserver", - Conn: mockedDB, - }) -} - -func newMySQLDialector() gorm.Dialector { - mockedDB, _, _ := sqlmock.New() - - return mysql.New(mysql.Config{ - DriverName: "mysql", - Conn: mockedDB, - SkipInitializeWithVersion: true, - }) -} - -var dialectors = []gorm.Dialector{ - newPostgresDialector(), - newSqliteDialector(), - newSQLServerDialector(), - newMySQLDialector(), -} - -func assertTestData( - t *testing.T, database *gorm.DB, expectedSQL string, testData testData, -) { - t.Helper() - - transaction := database.Model(&models.Run{}) - - contractErr := applyFilter(context.Background(), database, transaction, testData.query) - if contractErr != nil { - t.Fatal("contractErr: ", contractErr) - } - - contractErr = applyOrderBy(context.Background(), database, transaction, testData.orderBy) - if contractErr != nil { - t.Fatal("contractErr: ", contractErr) - } - - sqlErr := transaction.Select("ID").Find(&models.Run{}).Error - require.NoError(t, sqlErr) - - actualSQL := 
transaction.Statement.SQL.String() - - // if removeWhitespace(expectedSQL) != removeWhitespace(actualSQL) { - // fmt.Println(strings.ReplaceAll(actualSQL, "`", "")) - // } - - assert.Equal(t, removeWhitespace(expectedSQL), removeWhitespace(actualSQL)) - assert.Equal(t, testData.expectedVars, transaction.Statement.Vars) -} - -func TestSearchRuns(t *testing.T) { - t.Parallel() - - for _, dialector := range dialectors { - database, err := gorm.Open(dialector, &gorm.Config{DryRun: true}) - require.NoError(t, err) - - dialectorName := database.Dialector.Name() - - for _, testData := range tests { - currentTestData := testData - if expectedSQL, ok := currentTestData.expectedSQL[dialectorName]; ok { - t.Run(currentTestData.name+"_"+dialectorName, func(t *testing.T) { - t.Parallel() - assertTestData(t, database, expectedSQL, currentTestData) - }) - } - } - } -} - -func TestInvalidSearchRunsQuery(t *testing.T) { - t.Parallel() - - database, err := gorm.Open(newSqliteDialector(), &gorm.Config{DryRun: true}) - require.NoError(t, err) - - transaction := database.Model(&models.Run{}) - - contractErr := applyFilter(context.Background(), database, transaction, "⚡✱*@❖$#&") - if contractErr == nil { - t.Fatal("expected contract error") - } -} - -//nolint:funlen -func TestOrderByClauseParsing(t *testing.T) { - t.Parallel() - - testData := []struct { - input string - expected orderByExpr - }{ - { - input: "status DESC", - expected: orderByExpr{ - key: "status", - order: utils.PtrTo("DESC"), - }, - }, - { - input: "run_name", - expected: orderByExpr{ - key: "name", - }, - }, - { - input: "params.input DESC", - expected: orderByExpr{ - identifier: utils.PtrTo("parameter"), - key: "input", - order: utils.PtrTo("DESC"), - }, - }, - { - input: "metrics.alpha ASC", - expected: orderByExpr{ - identifier: utils.PtrTo("metric"), - key: "alpha", - order: utils.PtrTo("ASC"), - }, - }, - { - input: "`Run name`", - expected: orderByExpr{ - key: "name", - }, - }, - { - input: "tags.`foo bar` ASC", 
- expected: orderByExpr{ - identifier: utils.PtrTo("tag"), - key: "foo bar", - order: utils.PtrTo("ASC"), - }, - }, - } - - for _, testData := range testData { - t.Run(strcase.ToKebab(testData.input), func(t *testing.T) { - t.Parallel() - - result, err := processOrderByClause(testData.input) - if err != nil { - t.Fatalf("unexpected error: %A", err) - } - - if !reflect.DeepEqual(testData.expected, result) { - t.Fatalf("expected (%s, %s, %s), got (%s, %s, %s)", - *testData.expected.identifier, - testData.expected.key, - *testData.expected.order, - *result.identifier, - result.key, - *result.order, - ) - } - }) - } -} +//nolint:ireturn +package sql + +import ( + "context" + "reflect" + "regexp" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/iancoleman/strcase" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/driver/mysql" + "gorm.io/driver/postgres" + "gorm.io/driver/sqlite" + "gorm.io/driver/sqlserver" + "gorm.io/gorm" + + "github.com/mlflow/mlflow-go/pkg/tracking/store/sql/models" + "github.com/mlflow/mlflow-go/pkg/utils" +) + +type testData struct { + name string + query string + orderBy []string + expectedSQL map[string]string + expectedVars []any +} + +var whitespaceRegex = regexp.MustCompile(`\s` + "|`") + +func removeWhitespace(s string) string { + return whitespaceRegex.ReplaceAllString(s, "") +} + +var tests = []testData{ + { + name: "SimpleMetricQuery", + query: "metrics.accuracy > 0.72", + expectedSQL: map[string]string{ + "postgres": ` + SELECT "run_uuid" FROM "runs" + JOIN (SELECT "run_uuid","value" FROM "latest_metrics" WHERE key = $1 AND value > $2) + AS filter_0 + ON runs.run_uuid = filter_0.run_uuid + ORDER BY runs.start_time DESC,runs.run_uuid`, + "sqlite": ` + SELECT run_uuid FROM runs + JOIN (SELECT run_uuid,value FROM latest_metrics WHERE key = ? AND value > ?) 
+ AS filter_0 ON runs.run_uuid = filter_0.run_uuid + ORDER BY runs.start_time DESC,runs.run_uuid`, + "sqlserver": ` + SELECT "run_uuid" FROM "runs" + JOIN (SELECT "run_uuid","value" FROM "latest_metrics" WHERE key = @p1 AND value > @p2) + AS filter_0 ON runs.run_uuid = filter_0.run_uuid + ORDER BY runs.start_time DESC,runs.run_uuid`, + "mysql": ` + SELECT run_uuid FROM runs + JOIN (SELECT run_uuid,value FROM latest_metrics WHERE key = ? AND value > ?) + AS filter_0 ON runs.run_uuid = filter_0.run_uuid + ORDER BY runs.start_time DESC,runs.run_uuid`, + }, + expectedVars: []any{"accuracy", 0.72}, + }, + { + name: "SimpleMetricAndParamQuery", + query: "metrics.accuracy > 0.72 AND params.batch_size = '2'", + expectedSQL: map[string]string{ + "postgres": ` + SELECT "run_uuid" FROM "runs" + JOIN (SELECT "run_uuid","value" FROM "latest_metrics" WHERE key = $1 AND value > $2) + AS filter_0 ON runs.run_uuid = filter_0.run_uuid + JOIN (SELECT "run_uuid","value" FROM "params" WHERE key = $3 AND value = $4) + AS filter_1 ON runs.run_uuid = filter_1.run_uuid + ORDER BY runs.start_time DESC,runs.run_uuid`, + "sqlite": ` + SELECT run_uuid FROM runs + JOIN (SELECT run_uuid,value FROM latest_metrics WHERE key = ? AND value > ?) + AS filter_0 ON runs.run_uuid = filter_0.run_uuid + JOIN (SELECT run_uuid,value FROM params WHERE key = ? AND value = ?) 
+ AS filter_1 ON runs.run_uuid = filter_1.run_uuid + ORDER BY runs.start_time DESC,runs.run_uuid`, + }, + expectedVars: []any{"accuracy", 0.72, "batch_size", "2"}, + }, + { + name: "TagQuery", + query: "tags.environment = 'notebook' AND tags.task ILIKE 'classif%'", + expectedSQL: map[string]string{ + "postgres": ` + SELECT "run_uuid" FROM "runs" + JOIN (SELECT "run_uuid","value" FROM "tags" WHERE key = $1 AND value = $2) + AS filter_0 ON runs.run_uuid = filter_0.run_uuid + JOIN (SELECT "run_uuid","value" FROM "tags" WHERE key = $3 AND value ILIKE $4) + AS filter_1 ON runs.run_uuid = filter_1.run_uuid + ORDER BY runs.start_time DESC,runs.run_uuid`, + "sqlite": ` + SELECT run_uuid FROM runs + JOIN (SELECT run_uuid,value FROM tags WHERE key = ? AND value = ?) + AS filter_0 ON runs.run_uuid = filter_0.run_uuid + JOIN (SELECT run_uuid,value FROM tags WHERE key = ? AND LOWER(value) LIKE ?) + AS filter_1 ON runs.run_uuid = filter_1.run_uuid + ORDER BY runs.start_time DESC,runs.run_uuid`, + }, + expectedVars: []any{"environment", "notebook", "task", "classif%"}, + }, + { + name: "DatasestsInQuery", + query: "datasets.digest IN ('s8ds293b', 'jks834s2')", + expectedSQL: map[string]string{ + "postgres": ` + SELECT "run_uuid" FROM "runs" + JOIN ( + SELECT destination_id,"digest" + FROM "datasets" JOIN inputs ON inputs.source_id = datasets.dataset_uuid + WHERE digest IN ($1,$2) + ) + AS filter_0 ON runs.run_uuid = filter_0.destination_id + ORDER BY runs.start_time DESC,runs.run_uuid`, + "sqlite": ` + SELECT run_uuid FROM runs + JOIN ( + SELECT destination_id,digest + FROM datasets JOIN inputs + ON inputs.source_id = datasets.dataset_uuid + WHERE digest IN (?,?) 
+ ) + AS filter_0 + ON runs.run_uuid = filter_0.destination_id + ORDER BY runs.start_time DESC,runs.run_uuid`, + }, + expectedVars: []any{"s8ds293b", "jks834s2"}, + }, + { + name: "AttributesQuery", + query: "attributes.run_id = 'a1b2c3d4'", + expectedSQL: map[string]string{ + "postgres": ` + SELECT "run_uuid" FROM "runs" + WHERE runs.run_uuid = $1 + ORDER BY runs.start_time DESC,runs.run_uuid + `, + "sqlite": `SELECT run_uuid FROM runs WHERE runs.run_uuid = ? ORDER BY runs.start_time DESC,runs.run_uuid`, + }, + expectedVars: []any{"a1b2c3d4"}, + }, + { + name: "Run_nameQuery", + query: "attributes.run_name = 'my-run'", + expectedSQL: map[string]string{ + "postgres": ` + SELECT "run_uuid" FROM "runs" + JOIN (SELECT "run_uuid","value" FROM "tags" WHERE key = $1 AND value = $2) + AS filter_0 ON runs.run_uuid = filter_0.run_uuid + ORDER BY runs.start_time DESC,runs.run_uuid`, + "sqlite": ` + SELECT run_uuid FROM runs + JOIN (SELECT run_uuid,value FROM tags WHERE key = ? AND value = ?) + AS filter_0 ON runs.run_uuid = filter_0.run_uuid + ORDER BY runs.start_time DESC,runs.run_uuid`, + }, + expectedVars: []any{"mlflow.runName", "my-run"}, + }, + { + name: "DatasetsContextQuery", + query: "datasets.context = 'train'", + expectedSQL: map[string]string{ + "postgres": ` + SELECT "run_uuid" FROM "runs" + JOIN ( + SELECT inputs.destination_id AS run_uuid + FROM "inputs" + JOIN input_tags + ON inputs.input_uuid = input_tags.input_uuid + AND input_tags.name = 'mlflow.data.context' + AND input_tags.value = $1 + WHERE inputs.destination_type = 'RUN' + ) AS filter_0 ON runs.run_uuid = filter_0.run_uuid + ORDER BY runs.start_time DESC,runs.run_uuid`, + "sqlite": ` + SELECT run_uuid FROM runs + JOIN ( + SELECT inputs.destination_id AS run_uuid + FROM inputs + JOIN input_tags ON inputs.input_uuid = input_tags.input_uuid + AND input_tags.name = 'mlflow.data.context' + AND input_tags.value = ? 
WHERE inputs.destination_type = 'RUN' + ) AS filter_0 ON runs.run_uuid = filter_0.run_uuid + ORDER BY runs.start_time DESC,runs.run_uuid`, + }, + expectedVars: []any{"train"}, + }, + { + name: "Run_nameQuery", + query: "attributes.run_name ILIKE 'my-run%'", + expectedSQL: map[string]string{ + "postgres": ` + SELECT "run_uuid" FROM "runs" + JOIN (SELECT "run_uuid","value" FROM "tags" WHERE key = $1 AND value ILIKE $2) + AS filter_0 ON runs.run_uuid = filter_0.run_uuid + ORDER BY runs.start_time DESC,runs.run_uuid`, + "sqlite": ` + SELECT run_uuid FROM runs + JOIN (SELECT run_uuid, value FROM tags WHERE key = ? AND LOWER(value) LIKE ?) + AS filter_0 ON runs.run_uuid = filter_0.run_uuid + ORDER BY runs.start_time DESC,runs.run_uuid`, + }, + expectedVars: []any{"mlflow.runName", "my-run%"}, + }, + { + name: "DatasetsContextQuery", + query: "datasets.context ILIKE '%train'", + expectedSQL: map[string]string{ + "postgres": ` + SELECT "run_uuid" FROM "runs" + JOIN ( + SELECT inputs.destination_id AS run_uuid FROM "inputs" + JOIN input_tags ON inputs.input_uuid = input_tags.input_uuid + AND input_tags.name = 'mlflow.data.context' + AND input_tags.value ILIKE $1 WHERE inputs.destination_type = 'RUN' + ) AS filter_0 ON runs.run_uuid = filter_0.run_uuid + ORDER BY runs.start_time DESC,runs.run_uuid`, + "sqlite": ` + SELECT run_uuid FROM runs + JOIN ( + SELECT inputs.destination_id AS run_uuid FROM inputs + JOIN input_tags ON inputs.input_uuid = input_tags.input_uuid + AND input_tags.name = 'mlflow.data.context' + AND LOWER(input_tags.value) LIKE ? 
WHERE inputs.destination_type = 'RUN') + AS filter_0 ON runs.run_uuid = filter_0.run_uuid + ORDER BY runs.start_time DESC,runs.run_uuid + `, + }, + expectedVars: []any{"%train"}, + }, + { + name: "DatasestsDigest", + query: "datasets.digest ILIKE '%s'", + expectedSQL: map[string]string{ + "postgres": ` + SELECT "run_uuid" FROM "runs" + JOIN ( + SELECT destination_id,"digest" + FROM "datasets" + JOIN inputs ON inputs.source_id = datasets.dataset_uuid + WHERE digest ILIKE $1 + ) + AS filter_0 ON runs.run_uuid = filter_0.destination_id + ORDER BY runs.start_time DESC,runs.run_uuid`, + "sqlite": ` + SELECT run_uuid FROM runs + JOIN ( + SELECT destination_id,digest + FROM datasets + JOIN inputs ON inputs.source_id = datasets.dataset_uuid + WHERE LOWER(digest) LIKE ?) + AS filter_0 ON runs.run_uuid = filter_0.destination_id + ORDER BY runs.start_time DESC,runs.run_uuid`, + }, + expectedVars: []any{"%s"}, + }, + { + name: "ParamQuery", + query: "metrics.accuracy > 0.72 AND params.batch_size ILIKE '%a'", + expectedSQL: map[string]string{ + "postgres": ` + SELECT "run_uuid" FROM "runs" + JOIN (SELECT "run_uuid","value" FROM "latest_metrics" WHERE key = $1 AND value > $2) + AS filter_0 ON runs.run_uuid = filter_0.run_uuid + JOIN (SELECT "run_uuid","value" FROM "params" WHERE key = $3 AND value ILIKE $4) + AS filter_1 ON runs.run_uuid = filter_1.run_uuid + ORDER BY runs.start_time DESC,runs.run_uuid`, + "sqlite": ` + SELECT run_uuid FROM runs + JOIN (SELECT run_uuid, value FROM latest_metrics WHERE key = ? AND value > ?) + AS filter_0 ON runs.run_uuid = filter_0.run_uuid + JOIN (SELECT run_uuid,value FROM params WHERE key = ? AND LOWER(value) LIKE ?) 
+ AS filter_1 ON runs.run_uuid = filter_1.run_uuid + ORDER BY runs.start_time DESC,runs.run_uuid + `, + }, + expectedVars: []any{"accuracy", 0.72, "batch_size", "%a"}, + }, + { + name: "OrderByStartTimeASC", + query: "", + orderBy: []string{"start_time ASC"}, + expectedSQL: map[string]string{ + "postgres": `SELECT "run_uuid" FROM "runs" ORDER BY order_null_0, "start_time",runs.run_uuid`, + }, + expectedVars: []any{}, + }, + { + name: "OrderByStatusDesc", + query: "", + expectedSQL: map[string]string{ + "postgres": `SELECT "run_uuid" FROM "runs" ORDER BY order_null_0, "status" DESC,runs.start_time DESC,runs.run_uuid`, + }, + orderBy: []string{"status DESC"}, + expectedVars: []any{}, + }, + { + name: "OrderByRunNameSnakeCase", + query: "", + expectedSQL: map[string]string{ + "postgres": `SELECT "run_uuid" FROM "runs" ORDER BY order_null_0, "name",runs.start_time DESC,runs.run_uuid`, + }, + orderBy: []string{"run_name"}, + expectedVars: []any{}, + }, + { + name: "OrderByRunNameLowerName", + query: "", + expectedSQL: map[string]string{ + "postgres": `SELECT "run_uuid" FROM "runs" ORDER BY order_null_0, "name",runs.start_time DESC,runs.run_uuid`, + }, + orderBy: []string{"`Run name`"}, + expectedVars: []any{}, + }, + { + name: "OrderByRunNamePascal", + query: "", + expectedSQL: map[string]string{ + "postgres": `SELECT "run_uuid" FROM "runs" ORDER BY order_null_0, "name",runs.start_time DESC,runs.run_uuid`, + }, + orderBy: []string{"`Run Name`"}, + expectedVars: []any{}, + }, +} + +func newPostgresDialector() gorm.Dialector { + mockedDB, _, _ := sqlmock.New() + + return postgres.New(postgres.Config{ + Conn: mockedDB, + DriverName: "postgres", + }) +} + +func newSqliteDialector() gorm.Dialector { + mockedDB, mock, _ := sqlmock.New() + mock.ExpectQuery("select sqlite_version()").WillReturnRows( + sqlmock.NewRows([]string{"sqlite_version()"}).AddRow("3.41.1")) + + return sqlite.New(sqlite.Config{ + DriverName: "sqlite3", + Conn: mockedDB, + }) +} + +func 
newSQLServerDialector() gorm.Dialector { + mockedDB, _, _ := sqlmock.New() + + return sqlserver.New(sqlserver.Config{ + DriverName: "sqlserver", + Conn: mockedDB, + }) +} + +func newMySQLDialector() gorm.Dialector { + mockedDB, _, _ := sqlmock.New() + + return mysql.New(mysql.Config{ + DriverName: "mysql", + Conn: mockedDB, + SkipInitializeWithVersion: true, + }) +} + +var dialectors = []gorm.Dialector{ + newPostgresDialector(), + newSqliteDialector(), + newSQLServerDialector(), + newMySQLDialector(), +} + +func assertTestData( + t *testing.T, database *gorm.DB, expectedSQL string, testData testData, +) { + t.Helper() + + transaction := database.Model(&models.Run{}) + + contractErr := applyFilter(context.Background(), database, transaction, testData.query) + if contractErr != nil { + t.Fatal("contractErr: ", contractErr) + } + + contractErr = applyOrderBy(context.Background(), database, transaction, testData.orderBy) + if contractErr != nil { + t.Fatal("contractErr: ", contractErr) + } + + sqlErr := transaction.Select("ID").Find(&models.Run{}).Error + require.NoError(t, sqlErr) + + actualSQL := transaction.Statement.SQL.String() + + // if removeWhitespace(expectedSQL) != removeWhitespace(actualSQL) { + // fmt.Println(strings.ReplaceAll(actualSQL, "`", "")) + // } + + assert.Equal(t, removeWhitespace(expectedSQL), removeWhitespace(actualSQL)) + assert.Equal(t, testData.expectedVars, transaction.Statement.Vars) +} + +func TestSearchRuns(t *testing.T) { + t.Parallel() + + for _, dialector := range dialectors { + database, err := gorm.Open(dialector, &gorm.Config{DryRun: true}) + require.NoError(t, err) + + dialectorName := database.Dialector.Name() + + for _, testData := range tests { + currentTestData := testData + if expectedSQL, ok := currentTestData.expectedSQL[dialectorName]; ok { + t.Run(currentTestData.name+"_"+dialectorName, func(t *testing.T) { + t.Parallel() + assertTestData(t, database, expectedSQL, currentTestData) + }) + } + } + } +} + +func 
TestInvalidSearchRunsQuery(t *testing.T) { + t.Parallel() + + database, err := gorm.Open(newSqliteDialector(), &gorm.Config{DryRun: true}) + require.NoError(t, err) + + transaction := database.Model(&models.Run{}) + + contractErr := applyFilter(context.Background(), database, transaction, "⚡✱*@❖$#&") + if contractErr == nil { + t.Fatal("expected contract error") + } +} + +//nolint:funlen +func TestOrderByClauseParsing(t *testing.T) { + t.Parallel() + + testData := []struct { + input string + expected orderByExpr + }{ + { + input: "status DESC", + expected: orderByExpr{ + key: "status", + order: utils.PtrTo("DESC"), + }, + }, + { + input: "run_name", + expected: orderByExpr{ + key: "name", + }, + }, + { + input: "params.input DESC", + expected: orderByExpr{ + identifier: utils.PtrTo("parameter"), + key: "input", + order: utils.PtrTo("DESC"), + }, + }, + { + input: "metrics.alpha ASC", + expected: orderByExpr{ + identifier: utils.PtrTo("metric"), + key: "alpha", + order: utils.PtrTo("ASC"), + }, + }, + { + input: "`Run name`", + expected: orderByExpr{ + key: "name", + }, + }, + { + input: "tags.`foo bar` ASC", + expected: orderByExpr{ + identifier: utils.PtrTo("tag"), + key: "foo bar", + order: utils.PtrTo("ASC"), + }, + }, + } + + for _, testData := range testData { + t.Run(strcase.ToKebab(testData.input), func(t *testing.T) { + t.Parallel() + + result, err := processOrderByClause(testData.input) + if err != nil { + t.Fatalf("unexpected error: %A", err) + } + + if !reflect.DeepEqual(testData.expected, result) { + t.Fatalf("expected (%s, %s, %s), got (%s, %s, %s)", + *testData.expected.identifier, + testData.expected.key, + *testData.expected.order, + *result.identifier, + result.key, + *result.order, + ) + } + }) + } +} diff --git a/pkg/tracking/store/sql/store.go b/pkg/tracking/store/sql/store.go index 9fb40f3..0570508 100644 --- a/pkg/tracking/store/sql/store.go +++ b/pkg/tracking/store/sql/store.go @@ -1,28 +1,28 @@ -package sql - -import ( - "context" - "fmt" - 
- "gorm.io/gorm" - - "github.com/mlflow/mlflow-go/pkg/config" - "github.com/mlflow/mlflow-go/pkg/sql" -) - -type TrackingSQLStore struct { - config *config.Config - db *gorm.DB -} - -func NewTrackingSQLStore(ctx context.Context, config *config.Config) (*TrackingSQLStore, error) { - database, err := sql.NewDatabase(ctx, config.TrackingStoreURI) - if err != nil { - return nil, fmt.Errorf("failed to connect to database %q: %w", config.TrackingStoreURI, err) - } - - return &TrackingSQLStore{ - config: config, - db: database, - }, nil -} +package sql + +import ( + "context" + "fmt" + + "gorm.io/gorm" + + "github.com/mlflow/mlflow-go/pkg/config" + "github.com/mlflow/mlflow-go/pkg/sql" +) + +type TrackingSQLStore struct { + config *config.Config + db *gorm.DB +} + +func NewTrackingSQLStore(ctx context.Context, config *config.Config) (*TrackingSQLStore, error) { + database, err := sql.NewDatabase(ctx, config.TrackingStoreURI) + if err != nil { + return nil, fmt.Errorf("failed to connect to database %q: %w", config.TrackingStoreURI, err) + } + + return &TrackingSQLStore{ + config: config, + db: database, + }, nil +} diff --git a/pkg/tracking/store/sql/tags.go b/pkg/tracking/store/sql/tags.go index 107d487..9507133 100644 --- a/pkg/tracking/store/sql/tags.go +++ b/pkg/tracking/store/sql/tags.go @@ -1,262 +1,262 @@ -package sql - -import ( - "context" - "errors" - "fmt" - "strconv" - - "gorm.io/gorm" - "gorm.io/gorm/clause" - - "github.com/mlflow/mlflow-go/pkg/contract" - "github.com/mlflow/mlflow-go/pkg/entities" - "github.com/mlflow/mlflow-go/pkg/protos" - "github.com/mlflow/mlflow-go/pkg/tracking/store/sql/models" - "github.com/mlflow/mlflow-go/pkg/utils" -) - -const tagsBatchSize = 100 - -func (s TrackingSQLStore) GetRunTag( - ctx context.Context, runID, tagKey string, -) (*entities.RunTag, *contract.Error) { - var runTag models.Tag - if err := s.db.WithContext( - ctx, - ).Where( - "run_uuid = ?", runID, - ).Where( - "key = ?", tagKey, - ).First(&runTag).Error; err != nil 
{ - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, nil - } - - return nil, contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - fmt.Sprintf("failed to get run tag for run id %q", runID), - err, - ) - } - - return runTag.ToEntity(), nil -} - -func (s TrackingSQLStore) setTagsWithTransaction( - transaction *gorm.DB, runID string, tags []*entities.RunTag, -) error { - runColumns := make(map[string]interface{}) - - for _, tag := range tags { - switch tag.Key { - case utils.TagUser: - runColumns["user_id"] = tag.Value - case utils.TagRunName: - runColumns["name"] = tag.Value - } - } - - if len(runColumns) != 0 { - err := transaction. - Model(&models.Run{}). - Where("run_uuid = ?", runID). - UpdateColumns(runColumns).Error - if err != nil { - return fmt.Errorf("failed to update run columns: %w", err) - } - } - - runTags := make([]models.Tag, 0, len(tags)) - - for _, tag := range tags { - runTags = append(runTags, models.NewTagFromEntity(runID, tag)) - } - - if err := transaction.Clauses(clause.OnConflict{ - UpdateAll: true, - }).CreateInBatches(runTags, tagsBatchSize).Error; err != nil { - return fmt.Errorf("failed to create tags for run %q: %w", runID, err) - } - - return nil -} - -const ( - maxEntityKeyLength = 250 - maxTagValueLength = 8000 -) - -// Helper function to validate the tag key and value -func validateTag(key, value string) *contract.Error { - if key == "" { - return contract.NewError( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - "Missing value for required parameter 'key'", - ) - } - if len(key) > maxEntityKeyLength { - return contract.NewError( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - fmt.Sprintf("Tag key '%s' had length %d, which exceeded length limit of %d", key, len(key), maxEntityKeyLength), - ) - } - if len(value) > maxTagValueLength { - return contract.NewError( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - fmt.Sprintf("Tag value exceeded length limit of %d characters", maxTagValueLength), - ) - } - // TODO: Check if this is 
the correct way to prevent invalid values - if _, err := strconv.ParseFloat(value, 64); err == nil { - return contract.NewError( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - fmt.Sprintf("Invalid value %s for parameter 'value' supplied", value), - ) - } - return nil -} - -func (s TrackingSQLStore) SetTag( - ctx context.Context, runID, key, value string, -) *contract.Error { - if runID == "" { - return contract.NewError( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - "RunID cannot be empty", - ) - } - - // If the runID can be parsed as a number, it should throw an error - if _, err := strconv.ParseFloat(runID, 64); err == nil { - return contract.NewError( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - fmt.Sprintf("Invalid value %s for parameter 'run_id' supplied", runID), - ) - } - - if err := validateTag(key, value); err != nil { - return err - } - - err := s.db.WithContext(ctx).Transaction(func(transaction *gorm.DB) error { - contractError := checkRunIsActive(transaction, runID) - if contractError != nil { - return contractError - } - - var tag models.Tag - result := transaction.Where("run_uuid = ? 
AND key = ?", runID, key).First(&tag) - - if result.Error != nil && !errors.Is(result.Error, gorm.ErrRecordNotFound) { - return contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - fmt.Sprintf("Failed to query tag for run_id %q and key %q", runID, key), - result.Error, - ) - } - - if result.RowsAffected == 1 { - tag.Value = value - if err := transaction.Save(&tag).Error; err != nil { - return contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - fmt.Sprintf("Failed to update tag for run_id %q and key %q", runID, key), - err, - ) - } - } else { - newTag := models.Tag{ - RunID: runID, - Key: key, - Value: value, - } - if err := transaction.Create(&newTag).Error; err != nil { - return contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - fmt.Sprintf("Failed to create tag for run_id %q and key %q", runID, key), - err, - ) - } - } - - return nil - }) - - if err != nil { - var contractError *contract.Error - if errors.As(err, &contractError) { - return contractError - } - return contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - fmt.Sprintf("set tag transaction failed for %q", runID), - err, - ) - } - - return nil -} - -const badDataMessage = "Bad data in database - tags for a specific run must have\n" + - "a single unique value.\n" + - "See https://mlflow.org/docs/latest/tracking.html#adding-tags-to-runs" - -func (s TrackingSQLStore) DeleteTag( - ctx context.Context, runID, key string, -) *contract.Error { - err := s.db.WithContext(ctx).Transaction(func(transaction *gorm.DB) error { - contractError := checkRunIsActive(transaction, runID) - if contractError != nil { - return contractError - } - - var tags []models.Tag - - transaction.Model(models.Tag{}).Where("run_uuid = ?", runID).Where("key = ?", key).Find(&tags) - - if transaction.Error != nil { - return contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - fmt.Sprintf("Failed to query tags for run_id %q and key %q", runID, key), - transaction.Error, - ) - } - - switch len(tags) { - 
case 0: - return contract.NewError( - protos.ErrorCode_RESOURCE_DOES_NOT_EXIST, - fmt.Sprintf("No tag with name: %s in run with id %s", key, runID), - ) - case 1: - transaction.Delete(tags[0]) - - if transaction.Error != nil { - return contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - fmt.Sprintf("Failed to query tags for run_id %q and key %q", runID, key), - transaction.Error, - ) - } - - return nil - default: - return contract.NewError(protos.ErrorCode_INVALID_STATE, badDataMessage) - } - }) - if err != nil { - var contractError *contract.Error - if errors.As(err, &contractError) { - return contractError - } - - return contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - fmt.Sprintf("delete tag transaction failed for %q", runID), - err, - ) - } - - return nil +package sql + +import ( + "context" + "errors" + "fmt" + "strconv" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + + "github.com/mlflow/mlflow-go/pkg/contract" + "github.com/mlflow/mlflow-go/pkg/entities" + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/tracking/store/sql/models" + "github.com/mlflow/mlflow-go/pkg/utils" +) + +const tagsBatchSize = 100 + +func (s TrackingSQLStore) GetRunTag( + ctx context.Context, runID, tagKey string, +) (*entities.RunTag, *contract.Error) { + var runTag models.Tag + if err := s.db.WithContext( + ctx, + ).Where( + "run_uuid = ?", runID, + ).Where( + "key = ?", tagKey, + ).First(&runTag).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, nil + } + + return nil, contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf("failed to get run tag for run id %q", runID), + err, + ) + } + + return runTag.ToEntity(), nil +} + +func (s TrackingSQLStore) setTagsWithTransaction( + transaction *gorm.DB, runID string, tags []*entities.RunTag, +) error { + runColumns := make(map[string]interface{}) + + for _, tag := range tags { + switch tag.Key { + case utils.TagUser: + runColumns["user_id"] = tag.Value 
+ case utils.TagRunName: + runColumns["name"] = tag.Value + } + } + + if len(runColumns) != 0 { + err := transaction. + Model(&models.Run{}). + Where("run_uuid = ?", runID). + UpdateColumns(runColumns).Error + if err != nil { + return fmt.Errorf("failed to update run columns: %w", err) + } + } + + runTags := make([]models.Tag, 0, len(tags)) + + for _, tag := range tags { + runTags = append(runTags, models.NewTagFromEntity(runID, tag)) + } + + if err := transaction.Clauses(clause.OnConflict{ + UpdateAll: true, + }).CreateInBatches(runTags, tagsBatchSize).Error; err != nil { + return fmt.Errorf("failed to create tags for run %q: %w", runID, err) + } + + return nil +} + +const ( + maxEntityKeyLength = 250 + maxTagValueLength = 8000 +) + +// Helper function to validate the tag key and value +func validateTag(key, value string) *contract.Error { + if key == "" { + return contract.NewError( + protos.ErrorCode_INVALID_PARAMETER_VALUE, + "Missing value for required parameter 'key'", + ) + } + if len(key) > maxEntityKeyLength { + return contract.NewError( + protos.ErrorCode_INVALID_PARAMETER_VALUE, + fmt.Sprintf("Tag key '%s' had length %d, which exceeded length limit of %d", key, len(key), maxEntityKeyLength), + ) + } + if len(value) > maxTagValueLength { + return contract.NewError( + protos.ErrorCode_INVALID_PARAMETER_VALUE, + fmt.Sprintf("Tag value exceeded length limit of %d characters", maxTagValueLength), + ) + } + // TODO: Check if this is the correct way to prevent invalid values + if _, err := strconv.ParseFloat(value, 64); err == nil { + return contract.NewError( + protos.ErrorCode_INVALID_PARAMETER_VALUE, + fmt.Sprintf("Invalid value %s for parameter 'value' supplied", value), + ) + } + return nil +} + +func (s TrackingSQLStore) SetTag( + ctx context.Context, runID, key, value string, +) *contract.Error { + if runID == "" { + return contract.NewError( + protos.ErrorCode_INVALID_PARAMETER_VALUE, + "RunID cannot be empty", + ) + } + + // If the runID can be parsed 
as a number, it should throw an error + if _, err := strconv.ParseFloat(runID, 64); err == nil { + return contract.NewError( + protos.ErrorCode_INVALID_PARAMETER_VALUE, + fmt.Sprintf("Invalid value %s for parameter 'run_id' supplied", runID), + ) + } + + if err := validateTag(key, value); err != nil { + return err + } + + err := s.db.WithContext(ctx).Transaction(func(transaction *gorm.DB) error { + contractError := checkRunIsActive(transaction, runID) + if contractError != nil { + return contractError + } + + var tag models.Tag + result := transaction.Where("run_uuid = ? AND key = ?", runID, key).First(&tag) + + if result.Error != nil && !errors.Is(result.Error, gorm.ErrRecordNotFound) { + return contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf("Failed to query tag for run_id %q and key %q", runID, key), + result.Error, + ) + } + + if result.RowsAffected == 1 { + tag.Value = value + if err := transaction.Save(&tag).Error; err != nil { + return contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf("Failed to update tag for run_id %q and key %q", runID, key), + err, + ) + } + } else { + newTag := models.Tag{ + RunID: runID, + Key: key, + Value: value, + } + if err := transaction.Create(&newTag).Error; err != nil { + return contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf("Failed to create tag for run_id %q and key %q", runID, key), + err, + ) + } + } + + return nil + }) + + if err != nil { + var contractError *contract.Error + if errors.As(err, &contractError) { + return contractError + } + return contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf("set tag transaction failed for %q", runID), + err, + ) + } + + return nil +} + +const badDataMessage = "Bad data in database - tags for a specific run must have\n" + + "a single unique value.\n" + + "See https://mlflow.org/docs/latest/tracking.html#adding-tags-to-runs" + +func (s TrackingSQLStore) DeleteTag( + ctx context.Context, runID, 
key string, +) *contract.Error { + err := s.db.WithContext(ctx).Transaction(func(transaction *gorm.DB) error { + contractError := checkRunIsActive(transaction, runID) + if contractError != nil { + return contractError + } + + var tags []models.Tag + + transaction.Model(models.Tag{}).Where("run_uuid = ?", runID).Where("key = ?", key).Find(&tags) + + if transaction.Error != nil { + return contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf("Failed to query tags for run_id %q and key %q", runID, key), + transaction.Error, + ) + } + + switch len(tags) { + case 0: + return contract.NewError( + protos.ErrorCode_RESOURCE_DOES_NOT_EXIST, + fmt.Sprintf("No tag with name: %s in run with id %s", key, runID), + ) + case 1: + transaction.Delete(tags[0]) + + if transaction.Error != nil { + return contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf("Failed to query tags for run_id %q and key %q", runID, key), + transaction.Error, + ) + } + + return nil + default: + return contract.NewError(protos.ErrorCode_INVALID_STATE, badDataMessage) + } + }) + if err != nil { + var contractError *contract.Error + if errors.As(err, &contractError) { + return contractError + } + + return contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf("delete tag transaction failed for %q", runID), + err, + ) + } + + return nil } \ No newline at end of file diff --git a/pkg/tracking/store/store.go b/pkg/tracking/store/store.go index f83fb66..7374400 100644 --- a/pkg/tracking/store/store.go +++ b/pkg/tracking/store/store.go @@ -1,81 +1,81 @@ -package store - -import ( - "context" - - "github.com/mlflow/mlflow-go/pkg/contract" - "github.com/mlflow/mlflow-go/pkg/entities" - "github.com/mlflow/mlflow-go/pkg/protos" -) - -//go:generate mockery -type TrackingStore interface { - RunTrackingStore - MetricTrackingStore - ExperimentTrackingStore -} - -type ( - RunTrackingStore interface { - GetRun(ctx context.Context, runID string) (*entities.Run, 
package store

import (
	"context"

	"github.com/mlflow/mlflow-go/pkg/contract"
	"github.com/mlflow/mlflow-go/pkg/entities"
	"github.com/mlflow/mlflow-go/pkg/protos"
)

// TrackingStore is the full contract a tracking backend must implement.
// It composes the run, metric, and experiment sub-interfaces below.
//
//go:generate mockery
type TrackingStore interface {
	RunTrackingStore
	MetricTrackingStore
	ExperimentTrackingStore
}

type (
	// RunTrackingStore covers run lifecycle operations and per-run tags.
	// All methods return *contract.Error so transport layers can map
	// failures onto protocol error codes.
	RunTrackingStore interface {
		GetRun(ctx context.Context, runID string) (*entities.Run, *contract.Error)
		CreateRun(
			ctx context.Context,
			experimentID string,
			userID string,
			startTime int64,
			tags []*entities.RunTag,
			runName string,
		) (*entities.Run, *contract.Error)
		UpdateRun(
			ctx context.Context,
			runID string,
			runStatus string,
			endTime *int64,
			runName string,
		) *contract.Error
		DeleteRun(ctx context.Context, runID string) *contract.Error
		RestoreRun(ctx context.Context, runID string) *contract.Error

		GetRunTag(ctx context.Context, runID, tagKey string) (*entities.RunTag, *contract.Error)
		SetTag(ctx context.Context, runID, key string, value string) *contract.Error
		DeleteTag(ctx context.Context, runID, key string) *contract.Error
	}
	// MetricTrackingStore covers logging of metrics, params, and tags
	// against an existing run, individually or in batches.
	MetricTrackingStore interface {
		LogBatch(
			ctx context.Context,
			runID string,
			metrics []*entities.Metric,
			params []*entities.Param,
			tags []*entities.RunTag) *contract.Error

		LogMetric(ctx context.Context, runID string, metric *entities.Metric) *contract.Error
	}
)

// ExperimentTrackingStore covers experiment lifecycle operations and
// run search within experiments.
type ExperimentTrackingStore interface {
	// GetExperiment returns experiment by the experiment ID.
	// The experiment should contain the linked tags.
	GetExperiment(ctx context.Context, id string) (*entities.Experiment, *contract.Error)
	GetExperimentByName(ctx context.Context, name string) (*entities.Experiment, *contract.Error)

	CreateExperiment(
		ctx context.Context,
		name string,
		artifactLocation string,
		tags []*entities.ExperimentTag,
	) (string, *contract.Error)
	RestoreExperiment(ctx context.Context, id string) *contract.Error
	RenameExperiment(ctx context.Context, experimentID, name string) *contract.Error

	// SearchRuns returns the matching runs plus a page token for pagination.
	SearchRuns(
		ctx context.Context,
		experimentIDs []string,
		filter string,
		runViewType protos.ViewType,
		maxResults int,
		orderBy []string,
		pageToken string,
	) ([]*entities.Run, string, *contract.Error)

	DeleteExperiment(ctx context.Context, id string) *contract.Error
}
-func NewContextWithLoggerFromFiberContext(c *fiber.Ctx) context.Context { - logger := GetLoggerFromContext(c.UserContext()) - - return NewContextWithLogger(c.Context(), logger) -} - -func GetLoggerFromContext(ctx context.Context) *logrus.Logger { - logger := ctx.Value(loggerKey{}) - if logger != nil { - logger, ok := logger.(*logrus.Logger) - if ok { - return logger - } - } - - return logrus.StandardLogger() -} - -func NewLoggerFromConfig(cfg *config.Config) *logrus.Logger { - logger := logrus.New() - - logLevel, err := logrus.ParseLevel(cfg.LogLevel) - if err != nil { - logLevel = logrus.InfoLevel - logger.Warnf("failed to parse log level: %s - assuming %q", err, logrus.InfoLevel) - } - - logger.SetLevel(logLevel) - - return logger -} +package utils + +import ( + "context" + + "github.com/gofiber/fiber/v2" + "github.com/sirupsen/logrus" + + "github.com/mlflow/mlflow-go/pkg/config" +) + +type loggerKey struct{} + +func NewContextWithLogger(ctx context.Context, logger *logrus.Logger) context.Context { + return context.WithValue(ctx, loggerKey{}, logger) +} + +// NewContextWithLoggerFromFiberContext transfer logger from Fiber context to a normal context.Context object. 
+func NewContextWithLoggerFromFiberContext(c *fiber.Ctx) context.Context { + logger := GetLoggerFromContext(c.UserContext()) + + return NewContextWithLogger(c.Context(), logger) +} + +func GetLoggerFromContext(ctx context.Context) *logrus.Logger { + logger := ctx.Value(loggerKey{}) + if logger != nil { + logger, ok := logger.(*logrus.Logger) + if ok { + return logger + } + } + + return logrus.StandardLogger() +} + +func NewLoggerFromConfig(cfg *config.Config) *logrus.Logger { + logger := logrus.New() + + logLevel, err := logrus.ParseLevel(cfg.LogLevel) + if err != nil { + logLevel = logrus.InfoLevel + logger.Warnf("failed to parse log level: %s - assuming %q", err, logrus.InfoLevel) + } + + logger.SetLevel(logLevel) + + return logger +} diff --git a/pkg/utils/naming.go b/pkg/utils/naming.go index 2b7e493..76ee216 100644 --- a/pkg/utils/naming.go +++ b/pkg/utils/naming.go @@ -1,71 +1,71 @@ -package utils - -import ( - "crypto/rand" - "fmt" - "math/big" -) - -var nouns = []string{ - "ant", "ape", "asp", "auk", "bass", "bat", "bear", "bee", "bird", "boar", - "bug", "calf", "carp", "cat", "chimp", "cod", "colt", "conch", "cow", - "crab", "crane", "croc", "crow", "cub", "deer", "doe", "dog", "dolphin", - "donkey", "dove", "duck", "eel", "elk", "fawn", "finch", "fish", "flea", - "fly", "foal", "fowl", "fox", "frog", "gnat", "gnu", "goat", "goose", - "grouse", "grub", "gull", "hare", "hawk", "hen", "hog", "horse", "hound", - "jay", "kit", "kite", "koi", "lamb", "lark", "loon", "lynx", "mare", - "midge", "mink", "mole", "moose", "moth", "mouse", "mule", "newt", "owl", - "ox", "panda", "penguin", "perch", "pig", "pug", "quail", "ram", "rat", - "ray", "robin", "roo", "rook", "seal", "shad", "shark", "sheep", "shoat", - "shrew", "shrike", "shrimp", "skink", "skunk", "sloth", "slug", "smelt", - "snail", "snake", "snipe", "sow", "sponge", "squid", "squirrel", "stag", - "steed", "stoat", "stork", "swan", "tern", "toad", "trout", "turtle", - "vole", "wasp", "whale", "wolf", 
"worm", "wren", "yak", "zebra", -} - -var predicates = []string{ - "abundant", "able", "abrasive", "adorable", "adaptable", "adventurous", - "aged", "agreeable", "ambitious", "amazing", "amusing", "angry", - "auspicious", "awesome", "bald", "beautiful", "bemused", "bedecked", "big", - "bittersweet", "blushing", "bold", "bouncy", "brawny", "bright", "burly", - "bustling", "calm", "capable", "carefree", "capricious", "caring", - "casual", "charming", "chill", "classy", "clean", "clumsy", "colorful", - "crawling", "dapper", "debonair", "dashing", "defiant", "delicate", - "delightful", "dazzling", "efficient", "enchanting", "entertaining", - "enthused", "exultant", "fearless", "flawless", "fortunate", "fun", - "funny", "gaudy", "gentle", "gifted", "glamorous", "grandiose", - "gregarious", "handsome", "hilarious", "honorable", "illustrious", - "incongruous", "indecisive", "industrious", "intelligent", "inquisitive", - "intrigued", "invincible", "judicious", "kindly", "languid", "learned", - "legendary", "likeable", "loud", "luminous", "luxuriant", "lyrical", - "magnificent", "marvelous", "masked", "melodic", "merciful", "mercurial", - "monumental", "mysterious", "nebulous", "nervous", "nimble", "nosy", - "omniscient", "orderly", "overjoyed", "peaceful", "painted", "persistent", - "placid", "polite", "popular", "powerful", "puzzled", "rambunctious", - "rare", "rebellious", "respected", "resilient", "righteous", "receptive", - "redolent", "resilient", "rogue", "rumbling", "salty", "sassy", "secretive", - "selective", "sedate", "serious", "shivering", "skillful", "sincere", - "skittish", "silent", "smiling", -} - -const numRange = 1000 - -// GenerateRandomName generates random name for `run`. 
-func GenerateRandomName() (string, error) { - predicateIndex, err := rand.Int(rand.Reader, big.NewInt(int64(len(predicates)))) - if err != nil { - return "", fmt.Errorf("error getting random integer number: %w", err) - } - - nounIndex, err := rand.Int(rand.Reader, big.NewInt(int64(len(nouns)))) - if err != nil { - return "", fmt.Errorf("error getting random integer number: %w", err) - } - - num, err := rand.Int(rand.Reader, big.NewInt(numRange)) - if err != nil { - return "", fmt.Errorf("error getting random integer number: %w", err) - } - - return fmt.Sprintf("%s-%s-%d", predicates[predicateIndex.Int64()], nouns[nounIndex.Int64()], num), nil -} +package utils + +import ( + "crypto/rand" + "fmt" + "math/big" +) + +var nouns = []string{ + "ant", "ape", "asp", "auk", "bass", "bat", "bear", "bee", "bird", "boar", + "bug", "calf", "carp", "cat", "chimp", "cod", "colt", "conch", "cow", + "crab", "crane", "croc", "crow", "cub", "deer", "doe", "dog", "dolphin", + "donkey", "dove", "duck", "eel", "elk", "fawn", "finch", "fish", "flea", + "fly", "foal", "fowl", "fox", "frog", "gnat", "gnu", "goat", "goose", + "grouse", "grub", "gull", "hare", "hawk", "hen", "hog", "horse", "hound", + "jay", "kit", "kite", "koi", "lamb", "lark", "loon", "lynx", "mare", + "midge", "mink", "mole", "moose", "moth", "mouse", "mule", "newt", "owl", + "ox", "panda", "penguin", "perch", "pig", "pug", "quail", "ram", "rat", + "ray", "robin", "roo", "rook", "seal", "shad", "shark", "sheep", "shoat", + "shrew", "shrike", "shrimp", "skink", "skunk", "sloth", "slug", "smelt", + "snail", "snake", "snipe", "sow", "sponge", "squid", "squirrel", "stag", + "steed", "stoat", "stork", "swan", "tern", "toad", "trout", "turtle", + "vole", "wasp", "whale", "wolf", "worm", "wren", "yak", "zebra", +} + +var predicates = []string{ + "abundant", "able", "abrasive", "adorable", "adaptable", "adventurous", + "aged", "agreeable", "ambitious", "amazing", "amusing", "angry", + "auspicious", "awesome", "bald", 
"beautiful", "bemused", "bedecked", "big", + "bittersweet", "blushing", "bold", "bouncy", "brawny", "bright", "burly", + "bustling", "calm", "capable", "carefree", "capricious", "caring", + "casual", "charming", "chill", "classy", "clean", "clumsy", "colorful", + "crawling", "dapper", "debonair", "dashing", "defiant", "delicate", + "delightful", "dazzling", "efficient", "enchanting", "entertaining", + "enthused", "exultant", "fearless", "flawless", "fortunate", "fun", + "funny", "gaudy", "gentle", "gifted", "glamorous", "grandiose", + "gregarious", "handsome", "hilarious", "honorable", "illustrious", + "incongruous", "indecisive", "industrious", "intelligent", "inquisitive", + "intrigued", "invincible", "judicious", "kindly", "languid", "learned", + "legendary", "likeable", "loud", "luminous", "luxuriant", "lyrical", + "magnificent", "marvelous", "masked", "melodic", "merciful", "mercurial", + "monumental", "mysterious", "nebulous", "nervous", "nimble", "nosy", + "omniscient", "orderly", "overjoyed", "peaceful", "painted", "persistent", + "placid", "polite", "popular", "powerful", "puzzled", "rambunctious", + "rare", "rebellious", "respected", "resilient", "righteous", "receptive", + "redolent", "resilient", "rogue", "rumbling", "salty", "sassy", "secretive", + "selective", "sedate", "serious", "shivering", "skillful", "sincere", + "skittish", "silent", "smiling", +} + +const numRange = 1000 + +// GenerateRandomName generates random name for `run`. 
+func GenerateRandomName() (string, error) { + predicateIndex, err := rand.Int(rand.Reader, big.NewInt(int64(len(predicates)))) + if err != nil { + return "", fmt.Errorf("error getting random integer number: %w", err) + } + + nounIndex, err := rand.Int(rand.Reader, big.NewInt(int64(len(nouns)))) + if err != nil { + return "", fmt.Errorf("error getting random integer number: %w", err) + } + + num, err := rand.Int(rand.Reader, big.NewInt(numRange)) + if err != nil { + return "", fmt.Errorf("error getting random integer number: %w", err) + } + + return fmt.Sprintf("%s-%s-%d", predicates[predicateIndex.Int64()], nouns[nounIndex.Int64()], num), nil +} diff --git a/pkg/utils/path.go b/pkg/utils/path.go index 39230a5..9ac35e0 100644 --- a/pkg/utils/path.go +++ b/pkg/utils/path.go @@ -1,90 +1,90 @@ -package utils - -import ( - "errors" - "fmt" - "net/url" - "path" - "strings" -) - -var ( - errFailedToDecodeURL = errors.New("failed to decode url") - errInvalidQueryString = errors.New("invalid query string") -) - -func decode(input string) (string, error) { - current := input - - for range 10 { - decoded, err := url.QueryUnescape(current) - if err != nil { - return "", fmt.Errorf("could not unescape %s: %w", current, err) - } - - parsed, err := url.Parse(decoded) - if err != nil { - return "", fmt.Errorf("could not parsed %s: %w", decoded, err) - } - - if current == parsed.String() { - return current, nil - } - } - - return "", errFailedToDecodeURL -} - -func validateQueryString(query string) error { - query, err := decode(query) - if err != nil { - return err - } - - if strings.Contains(query, "..") { - return errInvalidQueryString - } - - return nil -} - -func joinPosixPathsAndAppendAbsoluteSuffixes(prefixPath, suffixPath string) string { - if len(prefixPath) == 0 { - return suffixPath - } - - suffixPath = strings.TrimPrefix(suffixPath, "/") - - return path.Join(prefixPath, suffixPath) -} - -func AppendToURIPath(uri string, paths ...string) (string, error) { - path := "" - 
for _, subpath := range paths { - path = joinPosixPathsAndAppendAbsoluteSuffixes(path, subpath) - } - - parsedURI, err := url.Parse(uri) - if err != nil { - return "", fmt.Errorf("failed to parse uri %s: %w", uri, err) - } - - if err := validateQueryString(parsedURI.RawQuery); err != nil { - return "", err - } - - if len(parsedURI.Scheme) == 0 { - return joinPosixPathsAndAppendAbsoluteSuffixes(uri, path), nil - } - - prefix := "" - if !strings.HasPrefix(parsedURI.Path, "/") { - prefix = parsedURI.Scheme + ":" - parsedURI.Scheme = "" - } - - newURIPath := joinPosixPathsAndAppendAbsoluteSuffixes(parsedURI.Path, path) - parsedURI.Path = newURIPath - - return prefix + parsedURI.String(), nil -} +package utils + +import ( + "errors" + "fmt" + "net/url" + "path" + "strings" +) + +var ( + errFailedToDecodeURL = errors.New("failed to decode url") + errInvalidQueryString = errors.New("invalid query string") +) + +func decode(input string) (string, error) { + current := input + + for range 10 { + decoded, err := url.QueryUnescape(current) + if err != nil { + return "", fmt.Errorf("could not unescape %s: %w", current, err) + } + + parsed, err := url.Parse(decoded) + if err != nil { + return "", fmt.Errorf("could not parsed %s: %w", decoded, err) + } + + if current == parsed.String() { + return current, nil + } + } + + return "", errFailedToDecodeURL +} + +func validateQueryString(query string) error { + query, err := decode(query) + if err != nil { + return err + } + + if strings.Contains(query, "..") { + return errInvalidQueryString + } + + return nil +} + +func joinPosixPathsAndAppendAbsoluteSuffixes(prefixPath, suffixPath string) string { + if len(prefixPath) == 0 { + return suffixPath + } + + suffixPath = strings.TrimPrefix(suffixPath, "/") + + return path.Join(prefixPath, suffixPath) +} + +func AppendToURIPath(uri string, paths ...string) (string, error) { + path := "" + for _, subpath := range paths { + path = joinPosixPathsAndAppendAbsoluteSuffixes(path, subpath) + } + 
+ parsedURI, err := url.Parse(uri) + if err != nil { + return "", fmt.Errorf("failed to parse uri %s: %w", uri, err) + } + + if err := validateQueryString(parsedURI.RawQuery); err != nil { + return "", err + } + + if len(parsedURI.Scheme) == 0 { + return joinPosixPathsAndAppendAbsoluteSuffixes(uri, path), nil + } + + prefix := "" + if !strings.HasPrefix(parsedURI.Path, "/") { + prefix = parsedURI.Scheme + ":" + parsedURI.Scheme = "" + } + + newURIPath := joinPosixPathsAndAppendAbsoluteSuffixes(parsedURI.Path, path) + parsedURI.Path = newURIPath + + return prefix + parsedURI.String(), nil +} diff --git a/pkg/utils/pointers.go b/pkg/utils/pointers.go index cd44985..b26e729 100644 --- a/pkg/utils/pointers.go +++ b/pkg/utils/pointers.go @@ -1,41 +1,41 @@ -package utils - -import ( - "strconv" -) - -func PtrTo[T any](v T) *T { - return &v -} - -func ConvertInt32PointerToStringPointer(iPtr *int32) *string { - if iPtr == nil { - return nil - } - - iValue := *iPtr - sValue := strconv.Itoa(int(iValue)) - - return &sValue -} - -func ConvertStringPointerToInt32Pointer(s *string) int32 { - if s == nil { - return 0 - } - - iValue, err := strconv.ParseInt(*s, 10, 32) - if err != nil { - return 0 - } - - return int32(iValue) -} - -func DumpStringPointer(s *string) string { - if s == nil { - return "" - } - - return *s -} +package utils + +import ( + "strconv" +) + +func PtrTo[T any](v T) *T { + return &v +} + +func ConvertInt32PointerToStringPointer(iPtr *int32) *string { + if iPtr == nil { + return nil + } + + iValue := *iPtr + sValue := strconv.Itoa(int(iValue)) + + return &sValue +} + +func ConvertStringPointerToInt32Pointer(s *string) int32 { + if s == nil { + return 0 + } + + iValue, err := strconv.ParseInt(*s, 10, 32) + if err != nil { + return 0 + } + + return int32(iValue) +} + +func DumpStringPointer(s *string) string { + if s == nil { + return "" + } + + return *s +} diff --git a/pkg/utils/strings.go b/pkg/utils/strings.go index 0dec0e7..e0b8773 100644 --- 
a/pkg/utils/strings.go +++ b/pkg/utils/strings.go @@ -1,24 +1,24 @@ -package utils - -import ( - "encoding/hex" - - "github.com/google/uuid" -) - -func IsNotNilOrEmptyString(v *string) bool { - return v != nil && *v != "" -} - -func IsNilOrEmptyString(v *string) bool { - return v == nil || *v == "" -} - -func NewUUID() string { - var r [32]byte - - u := uuid.New() - hex.Encode(r[:], u[:]) - - return string(r[:]) -} +package utils + +import ( + "encoding/hex" + + "github.com/google/uuid" +) + +func IsNotNilOrEmptyString(v *string) bool { + return v != nil && *v != "" +} + +func IsNilOrEmptyString(v *string) bool { + return v == nil || *v == "" +} + +func NewUUID() string { + var r [32]byte + + u := uuid.New() + hex.Encode(r[:], u[:]) + + return string(r[:]) +} diff --git a/pkg/utils/tags.go b/pkg/utils/tags.go index 98474af..1bba1aa 100644 --- a/pkg/utils/tags.go +++ b/pkg/utils/tags.go @@ -1,6 +1,6 @@ -package utils - -const ( - TagRunName = "mlflow.runName" - TagUser = "mlflow.user" -) +package utils + +const ( + TagRunName = "mlflow.runName" + TagUser = "mlflow.user" +) diff --git a/pkg/validation/validation.go b/pkg/validation/validation.go index 006e994..6349abb 100644 --- a/pkg/validation/validation.go +++ b/pkg/validation/validation.go @@ -1,299 +1,299 @@ -package validation - -import ( - "encoding/json" - "errors" - "fmt" - "net/url" - "os" - "path/filepath" - "reflect" - "regexp" - "strconv" - "strings" - - "github.com/go-playground/validator/v10" - - "github.com/mlflow/mlflow-go/pkg/contract" - "github.com/mlflow/mlflow-go/pkg/protos" -) - -const ( - QuoteLength = 2 - MaxEntitiesPerBatch = 1000 - MaxValidationInputLength = 100 -) - -// regex for valid param and metric names: may only contain slashes, alphanumerics, -// underscores, periods, dashes, and spaces. -var paramAndMetricNameRegex = regexp.MustCompile(`^[/\w.\- ]*$`) - -// regex for valid run IDs: must be an alphanumeric string of length 1 to 256. 
-var runIDRegex = regexp.MustCompile(`^[a-zA-Z0-9][\w\-]{0,255}$`) - -func stringAsPositiveIntegerValidation(fl validator.FieldLevel) bool { - valueStr := fl.Field().String() - - value, err := strconv.Atoi(valueStr) - if err != nil { - return false - } - - return value > -1 -} - -func uriWithoutFragmentsOrParamsOrDotDotInQueryValidation(fl validator.FieldLevel) bool { - valueStr := fl.Field().String() - if valueStr == "" { - return true - } - - u, err := url.Parse(valueStr) - if err != nil { - return false - } - - return u.Fragment == "" && u.RawQuery == "" && !strings.Contains(u.RawQuery, "..") -} - -func uniqueParamsValidation(fl validator.FieldLevel) bool { - value := fl.Field() - - params, areParams := value.Interface().([]*protos.Param) - if !areParams || len(params) == 0 { - return true - } - - hasDuplicates := false - keys := make(map[string]bool, len(params)) - - for _, param := range params { - if _, ok := keys[param.GetKey()]; ok { - hasDuplicates = true - - break - } - - keys[param.GetKey()] = true - } - - return !hasDuplicates -} - -func pathIsClean(fl validator.FieldLevel) bool { - valueStr := fl.Field().String() - norm := filepath.Clean(valueStr) - - return !(norm != valueStr || norm == "." || strings.HasPrefix(norm, "..") || strings.HasPrefix(norm, "/")) -} - -func regexValidation(regex *regexp.Regexp) validator.Func { - return func(fl validator.FieldLevel) bool { - valueStr := fl.Field().String() - - return regex.MatchString(valueStr) - } -} - -// see _validate_batch_log_limits in validation.py. 
-func validateLogBatchLimits(structLevel validator.StructLevel) { - logBatch, isLogBatch := structLevel.Current().Interface().(*protos.LogBatch) - - if isLogBatch { - total := len(logBatch.GetParams()) + len(logBatch.GetMetrics()) + len(logBatch.GetTags()) - if total > MaxEntitiesPerBatch { - structLevel.ReportError(&logBatch, "metrics, params, and tags", "", "", "") - } - } -} - -func truncateFn(fieldLevel validator.FieldLevel) bool { - param := fieldLevel.Param() // Get the parameter from the tag - - maxLength, err := strconv.Atoi(param) - if err != nil { - return false // If the parameter isn't a valid integer, fail the validation. - } - - truncateLongValues, shouldTruncate := os.LookupEnv("MLFLOW_TRUNCATE_LONG_VALUES") - shouldTruncate = shouldTruncate && truncateLongValues == "true" - - field := fieldLevel.Field() - - if field.Kind() == reflect.String { - strValue := field.String() - if len(strValue) <= maxLength { - return true - } - - if shouldTruncate { - field.SetString(strValue[:maxLength]) - - return true - } - - return false - } - - return true -} - -func NewValidator() (*validator.Validate, error) { - validate := validator.New() - - validate.RegisterTagNameFunc(func(fld reflect.StructField) string { - name := strings.SplitN(fld.Tag.Get("json"), ",", QuoteLength)[0] - // skip if tag key says it should be ignored - if name == "-" { - return "" - } - - return name - }) - - // Verify that the input string is a positive integer. 
- if err := validate.RegisterValidation( - "stringAsPositiveInteger", stringAsPositiveIntegerValidation, - ); err != nil { - return nil, fmt.Errorf("validation registration for 'stringAsPositiveInteger' failed: %w", err) - } - - // Verify that the input string, if present, is a Url without fragment or query parameters - if err := validate.RegisterValidation( - "uriWithoutFragmentsOrParamsOrDotDotInQuery", uriWithoutFragmentsOrParamsOrDotDotInQueryValidation); err != nil { - return nil, fmt.Errorf("validation registration for 'uriWithoutFragmentsOrParamsOrDotDotInQuery' failed: %w", err) - } - - if err := validate.RegisterValidation( - "validMetricParamOrTagName", regexValidation(paramAndMetricNameRegex), - ); err != nil { - return nil, fmt.Errorf("validation registration for 'validMetricParamOrTagName' failed: %w", err) - } - - if err := validate.RegisterValidation("pathIsUnique", pathIsClean); err != nil { - return nil, fmt.Errorf("validation registration for 'validMetricParamOrTagValue' failed: %w", err) - } - - // unique params in LogBatch - if err := validate.RegisterValidation("uniqueParams", uniqueParamsValidation); err != nil { - return nil, fmt.Errorf("validation registration for 'uniqueParams' failed: %w", err) - } - - if err := validate.RegisterValidation("runId", regexValidation(runIDRegex)); err != nil { - return nil, fmt.Errorf("validation registration for 'runId' failed: %w", err) - } - - if err := validate.RegisterValidation("truncate", truncateFn); err != nil { - return nil, fmt.Errorf("validation registration for 'truncateFn' failed: %w", err) - } - - validate.RegisterStructValidation(validateLogBatchLimits, &protos.LogBatch{}) - - return validate, nil -} - -func dereference(value interface{}) interface{} { - valueOf := reflect.ValueOf(value) - if valueOf.Kind() == reflect.Ptr { - if valueOf.IsNil() { - return "" - } - - return valueOf.Elem().Interface() - } - - return value -} - -func getErrorPath(err validator.FieldError) string { - path := 
err.Field() - - if err.Namespace() != "" { - // Strip first item in struct namespace - idx := strings.Index(err.Namespace(), ".") - if idx != -1 { - path = err.Namespace()[(idx + 1):] - } - } - - return path -} - -func constructValidationError(field string, value any, suffix string) string { - formattedValue, err := json.Marshal(value) - if err != nil { - formattedValue = []byte(fmt.Sprintf("%v", value)) - } - - return fmt.Sprintf("Invalid value %s for parameter '%s' supplied%s", formattedValue, field, suffix) -} - -func mkTruncateValidationError(field string, value interface{}, err validator.FieldError) string { - strValue, ok := value.(string) - if ok { - expected := len(strValue) - - if expected > MaxValidationInputLength { - strValue = strValue[:MaxValidationInputLength] + "..." - } - - return constructValidationError( - field, - strValue, - fmt.Sprintf(": length %d exceeded length limit of %s", expected, err.Param()), - ) - } - - return constructValidationError(field, value, "") -} - -func mkMaxValidationError(field string, value interface{}, err validator.FieldError) string { - if _, ok := value.(string); ok { - return fmt.Sprintf( - "'%s' exceeds the maximum length of %s characters", - field, - err.Param(), - ) - } - - return constructValidationError(field, value, "") -} - -func NewErrorFromValidationError(err error) *contract.Error { - var validatorValidationErrors validator.ValidationErrors - if errors.As(err, &validatorValidationErrors) { - validationErrors := make([]string, 0) - - for _, err := range validatorValidationErrors { - field := getErrorPath(err) - tag := err.Tag() - value := dereference(err.Value()) - - switch tag { - case "required": - validationErrors = append( - validationErrors, - fmt.Sprintf("Missing value for required parameter '%s'", field), - ) - case "truncate": - validationErrors = append(validationErrors, mkTruncateValidationError(field, value, err)) - case "uniqueParams": - validationErrors = append( - validationErrors, - 
"Duplicate parameter keys have been submitted", - ) - case "max": - validationErrors = append(validationErrors, mkMaxValidationError(field, value, err)) - default: - validationErrors = append( - validationErrors, - constructValidationError(field, value, ""), - ) - } - } - - return contract.NewError(protos.ErrorCode_INVALID_PARAMETER_VALUE, strings.Join(validationErrors, ", ")) - } - - return contract.NewError(protos.ErrorCode_INTERNAL_ERROR, err.Error()) -} +package validation + +import ( + "encoding/json" + "errors" + "fmt" + "net/url" + "os" + "path/filepath" + "reflect" + "regexp" + "strconv" + "strings" + + "github.com/go-playground/validator/v10" + + "github.com/mlflow/mlflow-go/pkg/contract" + "github.com/mlflow/mlflow-go/pkg/protos" +) + +const ( + QuoteLength = 2 + MaxEntitiesPerBatch = 1000 + MaxValidationInputLength = 100 +) + +// regex for valid param and metric names: may only contain slashes, alphanumerics, +// underscores, periods, dashes, and spaces. +var paramAndMetricNameRegex = regexp.MustCompile(`^[/\w.\- ]*$`) + +// regex for valid run IDs: must be an alphanumeric string of length 1 to 256. 
+var runIDRegex = regexp.MustCompile(`^[a-zA-Z0-9][\w\-]{0,255}$`) + +func stringAsPositiveIntegerValidation(fl validator.FieldLevel) bool { + valueStr := fl.Field().String() + + value, err := strconv.Atoi(valueStr) + if err != nil { + return false + } + + return value > -1 +} + +func uriWithoutFragmentsOrParamsOrDotDotInQueryValidation(fl validator.FieldLevel) bool { + valueStr := fl.Field().String() + if valueStr == "" { + return true + } + + u, err := url.Parse(valueStr) + if err != nil { + return false + } + + return u.Fragment == "" && u.RawQuery == "" && !strings.Contains(u.RawQuery, "..") +} + +func uniqueParamsValidation(fl validator.FieldLevel) bool { + value := fl.Field() + + params, areParams := value.Interface().([]*protos.Param) + if !areParams || len(params) == 0 { + return true + } + + hasDuplicates := false + keys := make(map[string]bool, len(params)) + + for _, param := range params { + if _, ok := keys[param.GetKey()]; ok { + hasDuplicates = true + + break + } + + keys[param.GetKey()] = true + } + + return !hasDuplicates +} + +func pathIsClean(fl validator.FieldLevel) bool { + valueStr := fl.Field().String() + norm := filepath.Clean(valueStr) + + return !(norm != valueStr || norm == "." || strings.HasPrefix(norm, "..") || strings.HasPrefix(norm, "/")) +} + +func regexValidation(regex *regexp.Regexp) validator.Func { + return func(fl validator.FieldLevel) bool { + valueStr := fl.Field().String() + + return regex.MatchString(valueStr) + } +} + +// see _validate_batch_log_limits in validation.py. 
+func validateLogBatchLimits(structLevel validator.StructLevel) { + logBatch, isLogBatch := structLevel.Current().Interface().(*protos.LogBatch) + + if isLogBatch { + total := len(logBatch.GetParams()) + len(logBatch.GetMetrics()) + len(logBatch.GetTags()) + if total > MaxEntitiesPerBatch { + structLevel.ReportError(&logBatch, "metrics, params, and tags", "", "", "") + } + } +} + +func truncateFn(fieldLevel validator.FieldLevel) bool { + param := fieldLevel.Param() // Get the parameter from the tag + + maxLength, err := strconv.Atoi(param) + if err != nil { + return false // If the parameter isn't a valid integer, fail the validation. + } + + truncateLongValues, shouldTruncate := os.LookupEnv("MLFLOW_TRUNCATE_LONG_VALUES") + shouldTruncate = shouldTruncate && truncateLongValues == "true" + + field := fieldLevel.Field() + + if field.Kind() == reflect.String { + strValue := field.String() + if len(strValue) <= maxLength { + return true + } + + if shouldTruncate { + field.SetString(strValue[:maxLength]) + + return true + } + + return false + } + + return true +} + +func NewValidator() (*validator.Validate, error) { + validate := validator.New() + + validate.RegisterTagNameFunc(func(fld reflect.StructField) string { + name := strings.SplitN(fld.Tag.Get("json"), ",", QuoteLength)[0] + // skip if tag key says it should be ignored + if name == "-" { + return "" + } + + return name + }) + + // Verify that the input string is a positive integer. 
+ if err := validate.RegisterValidation( + "stringAsPositiveInteger", stringAsPositiveIntegerValidation, + ); err != nil { + return nil, fmt.Errorf("validation registration for 'stringAsPositiveInteger' failed: %w", err) + } + + // Verify that the input string, if present, is a Url without fragment or query parameters + if err := validate.RegisterValidation( + "uriWithoutFragmentsOrParamsOrDotDotInQuery", uriWithoutFragmentsOrParamsOrDotDotInQueryValidation); err != nil { + return nil, fmt.Errorf("validation registration for 'uriWithoutFragmentsOrParamsOrDotDotInQuery' failed: %w", err) + } + + if err := validate.RegisterValidation( + "validMetricParamOrTagName", regexValidation(paramAndMetricNameRegex), + ); err != nil { + return nil, fmt.Errorf("validation registration for 'validMetricParamOrTagName' failed: %w", err) + } + + if err := validate.RegisterValidation("pathIsUnique", pathIsClean); err != nil { + return nil, fmt.Errorf("validation registration for 'validMetricParamOrTagValue' failed: %w", err) + } + + // unique params in LogBatch + if err := validate.RegisterValidation("uniqueParams", uniqueParamsValidation); err != nil { + return nil, fmt.Errorf("validation registration for 'uniqueParams' failed: %w", err) + } + + if err := validate.RegisterValidation("runId", regexValidation(runIDRegex)); err != nil { + return nil, fmt.Errorf("validation registration for 'runId' failed: %w", err) + } + + if err := validate.RegisterValidation("truncate", truncateFn); err != nil { + return nil, fmt.Errorf("validation registration for 'truncateFn' failed: %w", err) + } + + validate.RegisterStructValidation(validateLogBatchLimits, &protos.LogBatch{}) + + return validate, nil +} + +func dereference(value interface{}) interface{} { + valueOf := reflect.ValueOf(value) + if valueOf.Kind() == reflect.Ptr { + if valueOf.IsNil() { + return "" + } + + return valueOf.Elem().Interface() + } + + return value +} + +func getErrorPath(err validator.FieldError) string { + path := 
err.Field() + + if err.Namespace() != "" { + // Strip first item in struct namespace + idx := strings.Index(err.Namespace(), ".") + if idx != -1 { + path = err.Namespace()[(idx + 1):] + } + } + + return path +} + +func constructValidationError(field string, value any, suffix string) string { + formattedValue, err := json.Marshal(value) + if err != nil { + formattedValue = []byte(fmt.Sprintf("%v", value)) + } + + return fmt.Sprintf("Invalid value %s for parameter '%s' supplied%s", formattedValue, field, suffix) +} + +func mkTruncateValidationError(field string, value interface{}, err validator.FieldError) string { + strValue, ok := value.(string) + if ok { + expected := len(strValue) + + if expected > MaxValidationInputLength { + strValue = strValue[:MaxValidationInputLength] + "..." + } + + return constructValidationError( + field, + strValue, + fmt.Sprintf(": length %d exceeded length limit of %s", expected, err.Param()), + ) + } + + return constructValidationError(field, value, "") +} + +func mkMaxValidationError(field string, value interface{}, err validator.FieldError) string { + if _, ok := value.(string); ok { + return fmt.Sprintf( + "'%s' exceeds the maximum length of %s characters", + field, + err.Param(), + ) + } + + return constructValidationError(field, value, "") +} + +func NewErrorFromValidationError(err error) *contract.Error { + var validatorValidationErrors validator.ValidationErrors + if errors.As(err, &validatorValidationErrors) { + validationErrors := make([]string, 0) + + for _, err := range validatorValidationErrors { + field := getErrorPath(err) + tag := err.Tag() + value := dereference(err.Value()) + + switch tag { + case "required": + validationErrors = append( + validationErrors, + fmt.Sprintf("Missing value for required parameter '%s'", field), + ) + case "truncate": + validationErrors = append(validationErrors, mkTruncateValidationError(field, value, err)) + case "uniqueParams": + validationErrors = append( + validationErrors, + 
"Duplicate parameter keys have been submitted", + ) + case "max": + validationErrors = append(validationErrors, mkMaxValidationError(field, value, err)) + default: + validationErrors = append( + validationErrors, + constructValidationError(field, value, ""), + ) + } + } + + return contract.NewError(protos.ErrorCode_INVALID_PARAMETER_VALUE, strings.Join(validationErrors, ", ")) + } + + return contract.NewError(protos.ErrorCode_INTERNAL_ERROR, err.Error()) +} diff --git a/pkg/validation/validation_test.go b/pkg/validation/validation_test.go index 52a5db7..ea480ff 100644 --- a/pkg/validation/validation_test.go +++ b/pkg/validation/validation_test.go @@ -1,244 +1,244 @@ -package validation_test - -import ( - "strings" - "testing" - - "github.com/stretchr/testify/require" - - "github.com/mlflow/mlflow-go/pkg/protos" - "github.com/mlflow/mlflow-go/pkg/utils" - "github.com/mlflow/mlflow-go/pkg/validation" -) - -type PositiveInteger struct { - Value string `validate:"stringAsPositiveInteger"` -} - -type validationScenario struct { - name string - input any - shouldTrigger bool -} - -func runscenarios(t *testing.T, scenarios []validationScenario) { - t.Helper() - - validator, err := validation.NewValidator() - require.NoError(t, err) - - for _, scenario := range scenarios { - currentScenario := scenario - t.Run(currentScenario.name, func(t *testing.T) { - t.Parallel() - - errs := validator.Struct(currentScenario.input) - - if currentScenario.shouldTrigger && errs == nil { - t.Errorf("Expected validation error, got nil") - } - - if !currentScenario.shouldTrigger && errs != nil { - t.Errorf("Expected no validation error, got %v", errs) - } - }) - } -} - -func TestStringAsPositiveInteger(t *testing.T) { - t.Parallel() - - scenarios := []validationScenario{ - { - name: "positive integer", - input: PositiveInteger{Value: "1"}, - shouldTrigger: false, - }, - { - name: "zero", - input: PositiveInteger{Value: "0"}, - shouldTrigger: false, - }, - { - name: "negative integer", - 
input: PositiveInteger{Value: "-1"}, - shouldTrigger: true, - }, - { - name: "alphabet", - input: PositiveInteger{Value: "a"}, - shouldTrigger: true, - }, - } - - runscenarios(t, scenarios) -} - -type uriWithoutFragmentsOrParams struct { - Value string `validate:"uriWithoutFragmentsOrParamsOrDotDotInQuery"` -} - -func TestUriWithoutFragmentsOrParams(t *testing.T) { - t.Parallel() - - scenarios := []validationScenario{ - { - name: "valid url", - input: uriWithoutFragmentsOrParams{Value: "http://example.com"}, - shouldTrigger: false, - }, - { - name: "only trigger when url is not empty", - input: uriWithoutFragmentsOrParams{Value: ""}, - shouldTrigger: false, - }, - { - name: "url with fragment", - input: uriWithoutFragmentsOrParams{Value: "http://example.com#fragment"}, - shouldTrigger: true, - }, - { - name: "url with query parameters", - input: uriWithoutFragmentsOrParams{Value: "http://example.com?query=param"}, - shouldTrigger: true, - }, - { - name: "unparsable url", - input: uriWithoutFragmentsOrParams{Value: ":invalid-url"}, - shouldTrigger: true, - }, - { - name: ".. 
in query", - input: uriWithoutFragmentsOrParams{Value: "http://example.com?query=./.."}, - shouldTrigger: true, - }, - } - - runscenarios(t, scenarios) -} - -func TestUniqueParamsInLogBatch(t *testing.T) { - t.Parallel() - - logBatchRequest := &protos.LogBatch{ - Params: []*protos.Param{ - {Key: utils.PtrTo("key1"), Value: utils.PtrTo("value1")}, - {Key: utils.PtrTo("key1"), Value: utils.PtrTo("value2")}, - }, - } - - validator, err := validation.NewValidator() - require.NoError(t, err) - - err = validator.Struct(logBatchRequest) - if err == nil { - t.Error("Expected uniqueParams validation error, got none") - } -} - -func TestEmptyParamsInLogBatch(t *testing.T) { - t.Parallel() - - logBatchRequest := &protos.LogBatch{ - RunId: utils.PtrTo("odcppTsGTMkHeDcqfZOYDMZSf"), - Params: make([]*protos.Param, 0), - } - - validator, err := validation.NewValidator() - require.NoError(t, err) - - err = validator.Struct(logBatchRequest) - if err != nil { - t.Errorf("Unexpected uniqueParams validation error, got %v", err) - } -} - -func TestMissingTimestampInNestedMetric(t *testing.T) { - t.Parallel() - - serverValidator, err := validation.NewValidator() - require.NoError(t, err) - - logBatch := protos.LogBatch{ - RunId: utils.PtrTo("odcppTsGTMkHeDcqfZOYDMZSf"), - Metrics: []*protos.Metric{ - { - Key: utils.PtrTo("mae"), - Value: utils.PtrTo(2.5), - }, - }, - } - - err = serverValidator.Struct(&logBatch) - if err == nil { - t.Error("Expected dive validation error, got none") - } - - msg := validation.NewErrorFromValidationError(err).Message - if !strings.Contains(msg, "metrics[0].timestamp") { - t.Errorf("Expected required validation error for nested property, got %v", msg) - } -} - -type avecTruncate struct { - X *string `validate:"truncate=3"` - Y string `validate:"truncate=3"` -} - -func TestTruncate(t *testing.T) { - input := &avecTruncate{ - X: utils.PtrTo("123456"), - Y: "654321", - } - - t.Setenv("MLFLOW_TRUNCATE_LONG_VALUES", "true") - - validator, err := 
validation.NewValidator() - require.NoError(t, err) - - err = validator.Struct(input) - require.NoError(t, err) - - if len(*input.X) != 3 { - t.Errorf("Expected the length of x to be 3, was %d", len(*input.X)) - } - - if len(input.Y) != 3 { - t.Errorf("Expected the length of y to be 3, was %d", len(input.Y)) - } -} - -// This unit test is a sanity test that confirms the `dive` validation -// enters a nested slice of pointer structs. -func TestNestedErrorsInSubCollection(t *testing.T) { - t.Parallel() - - value := strings.Repeat("X", 6001) + "Y" - - logBatchRequest := &protos.LogBatch{ - RunId: utils.PtrTo("odcppTsGTMkHeDcqfZOYDMZSf"), - Params: []*protos.Param{ - {Key: utils.PtrTo("key1"), Value: utils.PtrTo(value)}, - {Key: utils.PtrTo("key2"), Value: utils.PtrTo(value)}, - }, - } - - validator, err := validation.NewValidator() - require.NoError(t, err) - - err = validator.Struct(logBatchRequest) - if err != nil { - msg := validation.NewErrorFromValidationError(err).Message - // Assert the root struct name is not present in the error message - if strings.Contains(msg, "logBatch") { - t.Errorf("Validation message contained root struct name, got %s", msg) - } - - // Assert the index is listed in the parameter path - if !strings.Contains(msg, "params[0].value") || - !strings.Contains(msg, "params[1].value") || - !strings.Contains(msg, "length 6002 exceeded length limit of 6000") { - t.Errorf("Unexpected validation error message, got %s", msg) - } - } -} +package validation_test + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/utils" + "github.com/mlflow/mlflow-go/pkg/validation" +) + +type PositiveInteger struct { + Value string `validate:"stringAsPositiveInteger"` +} + +type validationScenario struct { + name string + input any + shouldTrigger bool +} + +func runscenarios(t *testing.T, scenarios []validationScenario) { + t.Helper() + + validator, err := 
validation.NewValidator() + require.NoError(t, err) + + for _, scenario := range scenarios { + currentScenario := scenario + t.Run(currentScenario.name, func(t *testing.T) { + t.Parallel() + + errs := validator.Struct(currentScenario.input) + + if currentScenario.shouldTrigger && errs == nil { + t.Errorf("Expected validation error, got nil") + } + + if !currentScenario.shouldTrigger && errs != nil { + t.Errorf("Expected no validation error, got %v", errs) + } + }) + } +} + +func TestStringAsPositiveInteger(t *testing.T) { + t.Parallel() + + scenarios := []validationScenario{ + { + name: "positive integer", + input: PositiveInteger{Value: "1"}, + shouldTrigger: false, + }, + { + name: "zero", + input: PositiveInteger{Value: "0"}, + shouldTrigger: false, + }, + { + name: "negative integer", + input: PositiveInteger{Value: "-1"}, + shouldTrigger: true, + }, + { + name: "alphabet", + input: PositiveInteger{Value: "a"}, + shouldTrigger: true, + }, + } + + runscenarios(t, scenarios) +} + +type uriWithoutFragmentsOrParams struct { + Value string `validate:"uriWithoutFragmentsOrParamsOrDotDotInQuery"` +} + +func TestUriWithoutFragmentsOrParams(t *testing.T) { + t.Parallel() + + scenarios := []validationScenario{ + { + name: "valid url", + input: uriWithoutFragmentsOrParams{Value: "http://example.com"}, + shouldTrigger: false, + }, + { + name: "only trigger when url is not empty", + input: uriWithoutFragmentsOrParams{Value: ""}, + shouldTrigger: false, + }, + { + name: "url with fragment", + input: uriWithoutFragmentsOrParams{Value: "http://example.com#fragment"}, + shouldTrigger: true, + }, + { + name: "url with query parameters", + input: uriWithoutFragmentsOrParams{Value: "http://example.com?query=param"}, + shouldTrigger: true, + }, + { + name: "unparsable url", + input: uriWithoutFragmentsOrParams{Value: ":invalid-url"}, + shouldTrigger: true, + }, + { + name: ".. 
in query", + input: uriWithoutFragmentsOrParams{Value: "http://example.com?query=./.."}, + shouldTrigger: true, + }, + } + + runscenarios(t, scenarios) +} + +func TestUniqueParamsInLogBatch(t *testing.T) { + t.Parallel() + + logBatchRequest := &protos.LogBatch{ + Params: []*protos.Param{ + {Key: utils.PtrTo("key1"), Value: utils.PtrTo("value1")}, + {Key: utils.PtrTo("key1"), Value: utils.PtrTo("value2")}, + }, + } + + validator, err := validation.NewValidator() + require.NoError(t, err) + + err = validator.Struct(logBatchRequest) + if err == nil { + t.Error("Expected uniqueParams validation error, got none") + } +} + +func TestEmptyParamsInLogBatch(t *testing.T) { + t.Parallel() + + logBatchRequest := &protos.LogBatch{ + RunId: utils.PtrTo("odcppTsGTMkHeDcqfZOYDMZSf"), + Params: make([]*protos.Param, 0), + } + + validator, err := validation.NewValidator() + require.NoError(t, err) + + err = validator.Struct(logBatchRequest) + if err != nil { + t.Errorf("Unexpected uniqueParams validation error, got %v", err) + } +} + +func TestMissingTimestampInNestedMetric(t *testing.T) { + t.Parallel() + + serverValidator, err := validation.NewValidator() + require.NoError(t, err) + + logBatch := protos.LogBatch{ + RunId: utils.PtrTo("odcppTsGTMkHeDcqfZOYDMZSf"), + Metrics: []*protos.Metric{ + { + Key: utils.PtrTo("mae"), + Value: utils.PtrTo(2.5), + }, + }, + } + + err = serverValidator.Struct(&logBatch) + if err == nil { + t.Error("Expected dive validation error, got none") + } + + msg := validation.NewErrorFromValidationError(err).Message + if !strings.Contains(msg, "metrics[0].timestamp") { + t.Errorf("Expected required validation error for nested property, got %v", msg) + } +} + +type avecTruncate struct { + X *string `validate:"truncate=3"` + Y string `validate:"truncate=3"` +} + +func TestTruncate(t *testing.T) { + input := &avecTruncate{ + X: utils.PtrTo("123456"), + Y: "654321", + } + + t.Setenv("MLFLOW_TRUNCATE_LONG_VALUES", "true") + + validator, err := 
validation.NewValidator() + require.NoError(t, err) + + err = validator.Struct(input) + require.NoError(t, err) + + if len(*input.X) != 3 { + t.Errorf("Expected the length of x to be 3, was %d", len(*input.X)) + } + + if len(input.Y) != 3 { + t.Errorf("Expected the length of y to be 3, was %d", len(input.Y)) + } +} + +// This unit test is a sanity test that confirms the `dive` validation +// enters a nested slice of pointer structs. +func TestNestedErrorsInSubCollection(t *testing.T) { + t.Parallel() + + value := strings.Repeat("X", 6001) + "Y" + + logBatchRequest := &protos.LogBatch{ + RunId: utils.PtrTo("odcppTsGTMkHeDcqfZOYDMZSf"), + Params: []*protos.Param{ + {Key: utils.PtrTo("key1"), Value: utils.PtrTo(value)}, + {Key: utils.PtrTo("key2"), Value: utils.PtrTo(value)}, + }, + } + + validator, err := validation.NewValidator() + require.NoError(t, err) + + err = validator.Struct(logBatchRequest) + if err != nil { + msg := validation.NewErrorFromValidationError(err).Message + // Assert the root struct name is not present in the error message + if strings.Contains(msg, "logBatch") { + t.Errorf("Validation message contained root struct name, got %s", msg) + } + + // Assert the index is listed in the parameter path + if !strings.Contains(msg, "params[0].value") || + !strings.Contains(msg, "params[1].value") || + !strings.Contains(msg, "length 6002 exceeded length limit of 6000") { + t.Errorf("Unexpected validation error message, got %s", msg) + } + } +} diff --git a/pyproject.toml b/pyproject.toml index 1e40bae..b4cf50f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,173 +1,173 @@ -[build-system] -requires = ["setuptools", "wheel"] -build-backend = "setuptools.build_meta" - -[project] -name = "mlflow-go" -version = "2.14.1" -description = "MLflow is an open source platform for the complete machine learning lifecycle" -readme = "README.md" -keywords = ["mlflow", "ai", "databricks"] -classifiers = [ - "Development Status :: 5 - Production/Stable", - "Intended 
Audience :: Developers", - "Intended Audience :: End Users/Desktop", - "Intended Audience :: Science/Research", - "Intended Audience :: Information Technology", - "Topic :: Scientific/Engineering :: Artificial Intelligence", - "Topic :: Software Development :: Libraries :: Python Modules", - "License :: OSI Approved :: Apache Software License", - "Operating System :: OS Independent", - "Programming Language :: Python :: 3.8", -] -requires-python = ">=3.8" -dependencies = ["mlflow==2.14.1", "cffi"] -license = { file = "LICENSE.txt" } - -[[project.maintainers]] -name = "Databricks" -email = "mlflow-oss-maintainers@googlegroups.com" - -[project.urls] -homepage = "https://mlflow.org" -issues = "https://github.com/mlflow/mlflow-go/issues" -documentation = "https://mlflow.org/docs/latest/index.html" -repository = "https://github.com/mlflow/mlflow-go" - -[project.scripts] -mlflow-go = "mlflow_go.cli:cli" - -[project.entry-points."mlflow.tracking_store"] -mssql = "mlflow_go.store.tracking:_get_sqlalchemy_store" -mysql = "mlflow_go.store.tracking:_get_sqlalchemy_store" -postgresql = "mlflow_go.store.tracking:_get_sqlalchemy_store" -sqlite = "mlflow_go.store.tracking:_get_sqlalchemy_store" - -[project.entry-points."mlflow.model_registry_store"] -mssql = "mlflow_go.store.model_registry:_get_sqlalchemy_store" -mysql = "mlflow_go.store.model_registry:_get_sqlalchemy_store" -postgresql = "mlflow_go.store.model_registry:_get_sqlalchemy_store" -sqlite = "mlflow_go.store.model_registry:_get_sqlalchemy_store" - -[tool.setuptools.packages.find] -where = ["."] -include = ["mlflow_go", "mlflow_go.*"] -exclude = ["tests", "tests.*"] - -[tool.ruff] -line-length = 100 -target-version = "py38" -force-exclude = true -extend-include = ["*.ipynb"] -extend-exclude = [ - "examples/recipes", - "mlflow/protos", - "mlflow/ml_package_versions.py", - "mlflow/server/graphql/autogenerated_graphql_schema.py", - "mlflow/server/js", - "mlflow/store/db_migrations", - "tests/protos", -] - 
-[tool.ruff.format] -docstring-code-format = true -docstring-code-line-length = 88 - -[tool.ruff.lint] -dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" -select = [ - "B006", # multiple-argument-default - "B015", # useless-comparison - "D209", # new-line-after-last-paragraph - "D411", # no-blank-line-before-section - "E", # error - "F", # Pyflakes - "C4", # flake8-comprehensions - "I", # isort - "ISC001", # single-line-implicit-string-concatenation - "PIE790", # unnecessary-placeholder - "PLR0402", # manual-from-import - "PLE1205", # logging-too-many-args - "PT001", # pytest-fixture-incorrect-parentheses-style - "PT002", # pytest-fixture-positional-args - "PT003", # pytest-extraneous-scope-function - "PT006", # pytest-parameterize-names-wrong-type - "PT007", # pytest-parameterize-values-wrong-type - "PT009", # pytest-unittest-assertion - "PT010", # pytest-raises-without-exception - "PT011", # pytest-raises-too-broad - "PT012", # pytest-raises-with-multiple-statements - "PT013", # pytest-incorrect-pytest-import - "PT014", # pytest-duplicate-parametrize-test-cases - "PT018", # pytest-composite-assertion - "PT022", # pytest-useless-yield-fixture - "PT023", # pytest-incorrect-mark-parentheses-style - "PT026", # pytest-use-fixtures-without-parameters - "PT027", # pytest-unittest-raises-assertion - "RET504", # unnecessary-assign - "RUF010", # explicit-f-string-type-conversion - "RUF013", # implicit-optional - "RUF100", # unused-noqa - "S307", # suspicious-eval-usage - "S324", # hashlib-insecure-hash-function - "SIM101", # duplicate-isinstance-call - "SIM103", # needless-bool - "SIM108", # if-else-block-instead-of-if-exp - "SIM114", # if-with-same-arms - "SIM115", # open-file-with-context-handler - "SIM210", # if-expr-with-true-false - "SIM910", # dict-get-with-none-default - "T20", # flake8-print - "TID251", # banned-api - "TID252", # relative-improt - "TRY302", # useless-try-except - "UP004", # useless-object-inheritance - "UP008", # 
super-call-with-parameters - "UP011", # lru-cache-without-parameters - "UP012", # unecessary-encode-utf8 - "UP015", # redundant-open-modes - "UP030", # format-literals - "UP031", # printf-string-format - "UP032", # f-string - "UP034", # extraneous-parenthesis - "W", # warning -] -ignore = [ - "E402", # module-import-not-at-top-of-file - "E721", # type-comparison - "E741", # ambiguous-variable-name - "F811", # redefined-while-unused -] - -[tool.ruff.lint.per-file-ignores] -"dev/*" = ["T201", "PT018"] -"examples/*" = ["T20", "RET504", "E501"] -"docs/*" = ["T20", "RET504", "E501"] -"mlflow/*" = ["PT018"] - -[tool.ruff.lint.flake8-pytest-style] -mark-parentheses = false -fixture-parentheses = false -raises-require-match-for = ["*"] - -[tool.ruff.lint.flake8-tidy-imports] -ban-relative-imports = "all" - -[tool.ruff.lint.isort] -forced-separate = ["tests"] - -[tool.ruff.lint.flake8-tidy-imports.banned-api] -"pkg_resources".msg = "We're migrating away from pkg_resources. Please use importlib.resources or importlib.metadata instead." 
- -[tool.ruff.lint.pydocstyle] -convention = "google" - -[tool.clint] -exclude = [ - "docs", - "mlflow/protos", - "mlflow/ml_package_versions.py", - "mlflow/server/js", - "mlflow/store/db_migrations", - "tests/protos", -] +[build-system] +requires = ["setuptools", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "mlflow-go" +version = "2.14.1" +description = "MLflow is an open source platform for the complete machine learning lifecycle" +readme = "README.md" +keywords = ["mlflow", "ai", "databricks"] +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "Intended Audience :: End Users/Desktop", + "Intended Audience :: Science/Research", + "Intended Audience :: Information Technology", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development :: Libraries :: Python Modules", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3.8", +] +requires-python = ">=3.8" +dependencies = ["mlflow==2.14.1", "cffi"] +license = { file = "LICENSE.txt" } + +[[project.maintainers]] +name = "Databricks" +email = "mlflow-oss-maintainers@googlegroups.com" + +[project.urls] +homepage = "https://mlflow.org" +issues = "https://github.com/mlflow/mlflow-go/issues" +documentation = "https://mlflow.org/docs/latest/index.html" +repository = "https://github.com/mlflow/mlflow-go" + +[project.scripts] +mlflow-go = "mlflow_go.cli:cli" + +[project.entry-points."mlflow.tracking_store"] +mssql = "mlflow_go.store.tracking:_get_sqlalchemy_store" +mysql = "mlflow_go.store.tracking:_get_sqlalchemy_store" +postgresql = "mlflow_go.store.tracking:_get_sqlalchemy_store" +sqlite = "mlflow_go.store.tracking:_get_sqlalchemy_store" + +[project.entry-points."mlflow.model_registry_store"] +mssql = "mlflow_go.store.model_registry:_get_sqlalchemy_store" +mysql = "mlflow_go.store.model_registry:_get_sqlalchemy_store" 
+postgresql = "mlflow_go.store.model_registry:_get_sqlalchemy_store" +sqlite = "mlflow_go.store.model_registry:_get_sqlalchemy_store" + +[tool.setuptools.packages.find] +where = ["."] +include = ["mlflow_go", "mlflow_go.*"] +exclude = ["tests", "tests.*"] + +[tool.ruff] +line-length = 100 +target-version = "py38" +force-exclude = true +extend-include = ["*.ipynb"] +extend-exclude = [ + "examples/recipes", + "mlflow/protos", + "mlflow/ml_package_versions.py", + "mlflow/server/graphql/autogenerated_graphql_schema.py", + "mlflow/server/js", + "mlflow/store/db_migrations", + "tests/protos", +] + +[tool.ruff.format] +docstring-code-format = true +docstring-code-line-length = 88 + +[tool.ruff.lint] +dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" +select = [ + "B006", # multiple-argument-default + "B015", # useless-comparison + "D209", # new-line-after-last-paragraph + "D411", # no-blank-line-before-section + "E", # error + "F", # Pyflakes + "C4", # flake8-comprehensions + "I", # isort + "ISC001", # single-line-implicit-string-concatenation + "PIE790", # unnecessary-placeholder + "PLR0402", # manual-from-import + "PLE1205", # logging-too-many-args + "PT001", # pytest-fixture-incorrect-parentheses-style + "PT002", # pytest-fixture-positional-args + "PT003", # pytest-extraneous-scope-function + "PT006", # pytest-parameterize-names-wrong-type + "PT007", # pytest-parameterize-values-wrong-type + "PT009", # pytest-unittest-assertion + "PT010", # pytest-raises-without-exception + "PT011", # pytest-raises-too-broad + "PT012", # pytest-raises-with-multiple-statements + "PT013", # pytest-incorrect-pytest-import + "PT014", # pytest-duplicate-parametrize-test-cases + "PT018", # pytest-composite-assertion + "PT022", # pytest-useless-yield-fixture + "PT023", # pytest-incorrect-mark-parentheses-style + "PT026", # pytest-use-fixtures-without-parameters + "PT027", # pytest-unittest-raises-assertion + "RET504", # unnecessary-assign + "RUF010", # 
explicit-f-string-type-conversion + "RUF013", # implicit-optional + "RUF100", # unused-noqa + "S307", # suspicious-eval-usage + "S324", # hashlib-insecure-hash-function + "SIM101", # duplicate-isinstance-call + "SIM103", # needless-bool + "SIM108", # if-else-block-instead-of-if-exp + "SIM114", # if-with-same-arms + "SIM115", # open-file-with-context-handler + "SIM210", # if-expr-with-true-false + "SIM910", # dict-get-with-none-default + "T20", # flake8-print + "TID251", # banned-api + "TID252", # relative-improt + "TRY302", # useless-try-except + "UP004", # useless-object-inheritance + "UP008", # super-call-with-parameters + "UP011", # lru-cache-without-parameters + "UP012", # unecessary-encode-utf8 + "UP015", # redundant-open-modes + "UP030", # format-literals + "UP031", # printf-string-format + "UP032", # f-string + "UP034", # extraneous-parenthesis + "W", # warning +] +ignore = [ + "E402", # module-import-not-at-top-of-file + "E721", # type-comparison + "E741", # ambiguous-variable-name + "F811", # redefined-while-unused +] + +[tool.ruff.lint.per-file-ignores] +"dev/*" = ["T201", "PT018"] +"examples/*" = ["T20", "RET504", "E501"] +"docs/*" = ["T20", "RET504", "E501"] +"mlflow/*" = ["PT018"] + +[tool.ruff.lint.flake8-pytest-style] +mark-parentheses = false +fixture-parentheses = false +raises-require-match-for = ["*"] + +[tool.ruff.lint.flake8-tidy-imports] +ban-relative-imports = "all" + +[tool.ruff.lint.isort] +forced-separate = ["tests"] + +[tool.ruff.lint.flake8-tidy-imports.banned-api] +"pkg_resources".msg = "We're migrating away from pkg_resources. Please use importlib.resources or importlib.metadata instead." 
+ +[tool.ruff.lint.pydocstyle] +convention = "google" + +[tool.clint] +exclude = [ + "docs", + "mlflow/protos", + "mlflow/ml_package_versions.py", + "mlflow/server/js", + "mlflow/store/db_migrations", + "tests/protos", +] diff --git a/setup.py b/setup.py index f209bc5..74688c8 100644 --- a/setup.py +++ b/setup.py @@ -1,66 +1,66 @@ -import os -import pathlib -import sys -from glob import glob -from typing import List, Tuple - -from setuptools import Distribution, setup - -sys.path.insert(0, pathlib.Path(__file__).parent.joinpath("mlflow_go").as_posix()) -from lib import build_lib - - -def _prune_go_files(path: str): - for root, dirnames, filenames in os.walk(path, topdown=False): - for filename in filenames: - if filename.endswith(".go"): - os.unlink(os.path.join(root, filename)) - for dirname in dirnames: - try: - os.rmdir(os.path.join(root, dirname)) - except OSError: - pass - - -def finalize_distribution_options(dist: Distribution) -> None: - dist.has_ext_modules = lambda: True - - # this allows us to set the tag for the wheel without the python version - bdist_wheel_base_class = dist.get_command_class("bdist_wheel") - - class bdist_wheel_go(bdist_wheel_base_class): - def get_tag(self) -> Tuple[str, str, str]: - _, _, plat = super().get_tag() - return "py3", "none", plat - - dist.cmdclass["bdist_wheel"] = bdist_wheel_go - - # this allows us to build the go binary and add the Go source files to the sdist - build_base_class = dist.get_command_class("build") - - class build_go(build_base_class): - def initialize_options(self) -> None: - self.editable_mode = False - self.build_lib = None - - def finalize_options(self) -> None: - self.set_undefined_options("build_py", ("build_lib", "build_lib")) - - def run(self) -> None: - if not self.editable_mode: - _prune_go_files(self.build_lib) - build_lib( - pathlib.Path("."), - pathlib.Path(self.build_lib).joinpath("mlflow_go"), - ) - - def get_source_files(self) -> List[str]: - return ["go.mod", "go.sum", *glob("pkg/**/*.go", 
recursive=True)] - - dist.cmdclass["build_go"] = build_go - build_base_class.sub_commands.append(("build_go", None)) - - -Distribution.finalize_options = finalize_distribution_options - -setup() +import os +import pathlib +import sys +from glob import glob +from typing import List, Tuple + +from setuptools import Distribution, setup + +sys.path.insert(0, pathlib.Path(__file__).parent.joinpath("mlflow_go").as_posix()) +from lib import build_lib + + +def _prune_go_files(path: str): + for root, dirnames, filenames in os.walk(path, topdown=False): + for filename in filenames: + if filename.endswith(".go"): + os.unlink(os.path.join(root, filename)) + for dirname in dirnames: + try: + os.rmdir(os.path.join(root, dirname)) + except OSError: + pass + + +def finalize_distribution_options(dist: Distribution) -> None: + dist.has_ext_modules = lambda: True + + # this allows us to set the tag for the wheel without the python version + bdist_wheel_base_class = dist.get_command_class("bdist_wheel") + + class bdist_wheel_go(bdist_wheel_base_class): + def get_tag(self) -> Tuple[str, str, str]: + _, _, plat = super().get_tag() + return "py3", "none", plat + + dist.cmdclass["bdist_wheel"] = bdist_wheel_go + + # this allows us to build the go binary and add the Go source files to the sdist + build_base_class = dist.get_command_class("build") + + class build_go(build_base_class): + def initialize_options(self) -> None: + self.editable_mode = False + self.build_lib = None + + def finalize_options(self) -> None: + self.set_undefined_options("build_py", ("build_lib", "build_lib")) + + def run(self) -> None: + if not self.editable_mode: + _prune_go_files(self.build_lib) + build_lib( + pathlib.Path("."), + pathlib.Path(self.build_lib).joinpath("mlflow_go"), + ) + + def get_source_files(self) -> List[str]: + return ["go.mod", "go.sum", *glob("pkg/**/*.go", recursive=True)] + + dist.cmdclass["build_go"] = build_go + build_base_class.sub_commands.append(("build_go", None)) + + 
+Distribution.finalize_options = finalize_distribution_options + +setup() diff --git a/tests/override_model_registry_store.py b/tests/override_model_registry_store.py index e711e22..d1678fb 100644 --- a/tests/override_model_registry_store.py +++ b/tests/override_model_registry_store.py @@ -1,5 +1,5 @@ -from mlflow.store.model_registry.sqlalchemy_store import SqlAlchemyStore - -from mlflow_go.store.model_registry import ModelRegistryStore - -SqlAlchemyStore = ModelRegistryStore(SqlAlchemyStore) +from mlflow.store.model_registry.sqlalchemy_store import SqlAlchemyStore + +from mlflow_go.store.model_registry import ModelRegistryStore + +SqlAlchemyStore = ModelRegistryStore(SqlAlchemyStore) diff --git a/tests/override_server.py b/tests/override_server.py index 6fd7550..a529d9c 100644 --- a/tests/override_server.py +++ b/tests/override_server.py @@ -1,77 +1,77 @@ -import contextlib -import logging -import sys - -import mlflow -import pytest -from mlflow.server import ARTIFACT_ROOT_ENV_VAR, BACKEND_STORE_URI_ENV_VAR -from mlflow.server.handlers import ModelRegistryStoreRegistryWrapper, TrackingStoreRegistryWrapper -from mlflow.utils import find_free_port - -from mlflow_go.server import server - -from tests.helper_functions import LOCALHOST -from tests.tracking.integration_test_utils import _await_server_up_or_die - -_logger = logging.getLogger(__name__) - - -@contextlib.contextmanager -def _init_server(backend_uri, root_artifact_uri, extra_env=None, app="mlflow.server:app"): - """ - Launch a new REST server using the tracking store specified by backend_uri and root artifact - directory specified by root_artifact_uri. - :returns A string URL to the server. 
- """ - scheme = backend_uri.split("://")[0] - if ("sqlite" or "postgresql" or "mysql" or "mssql") not in scheme: - pytest.skip(f'Unsupported scheme "{scheme}" for the Go server') - - mlflow.set_tracking_uri(None) - - server_port = find_free_port() - python_port = find_free_port() - url = f"http://{LOCALHOST}:{server_port}" - - _logger.info( - f"Initializing stores with backend URI {backend_uri} and artifact root {root_artifact_uri}" - ) - TrackingStoreRegistryWrapper().get_store(backend_uri, root_artifact_uri) - ModelRegistryStoreRegistryWrapper().get_store(backend_uri, root_artifact_uri) - - _logger.info( - f"Launching tracking server on {url} with backend URI {backend_uri} and " - f"artifact root {root_artifact_uri}" - ) - - with server( - address=f"{LOCALHOST}:{server_port}", - default_artifact_root=root_artifact_uri, - log_level=logging.getLevelName(_logger.getEffectiveLevel()), - model_registry_store_uri=backend_uri, - python_address=f"{LOCALHOST}:{python_port}", - python_command=[ - sys.executable, - "-m", - "flask", - "--app", - app, - "run", - "--host", - LOCALHOST, - "--port", - str(python_port), - ], - python_env=[ - f"{k}={v}" - for k, v in { - BACKEND_STORE_URI_ENV_VAR: backend_uri, - ARTIFACT_ROOT_ENV_VAR: root_artifact_uri, - **(extra_env or {}), - }.items() - ], - shutdown_timeout="5s", - tracking_store_uri=backend_uri, - ): - _await_server_up_or_die(server_port) - yield url +import contextlib +import logging +import sys + +import mlflow +import pytest +from mlflow.server import ARTIFACT_ROOT_ENV_VAR, BACKEND_STORE_URI_ENV_VAR +from mlflow.server.handlers import ModelRegistryStoreRegistryWrapper, TrackingStoreRegistryWrapper +from mlflow.utils import find_free_port + +from mlflow_go.server import server + +from tests.helper_functions import LOCALHOST +from tests.tracking.integration_test_utils import _await_server_up_or_die + +_logger = logging.getLogger(__name__) + + +@contextlib.contextmanager +def _init_server(backend_uri, root_artifact_uri, 
extra_env=None, app="mlflow.server:app"): + """ + Launch a new REST server using the tracking store specified by backend_uri and root artifact + directory specified by root_artifact_uri. + :returns A string URL to the server. + """ + scheme = backend_uri.split("://")[0] + if ("sqlite" or "postgresql" or "mysql" or "mssql") not in scheme: + pytest.skip(f'Unsupported scheme "{scheme}" for the Go server') + + mlflow.set_tracking_uri(None) + + server_port = find_free_port() + python_port = find_free_port() + url = f"http://{LOCALHOST}:{server_port}" + + _logger.info( + f"Initializing stores with backend URI {backend_uri} and artifact root {root_artifact_uri}" + ) + TrackingStoreRegistryWrapper().get_store(backend_uri, root_artifact_uri) + ModelRegistryStoreRegistryWrapper().get_store(backend_uri, root_artifact_uri) + + _logger.info( + f"Launching tracking server on {url} with backend URI {backend_uri} and " + f"artifact root {root_artifact_uri}" + ) + + with server( + address=f"{LOCALHOST}:{server_port}", + default_artifact_root=root_artifact_uri, + log_level=logging.getLevelName(_logger.getEffectiveLevel()), + model_registry_store_uri=backend_uri, + python_address=f"{LOCALHOST}:{python_port}", + python_command=[ + sys.executable, + "-m", + "flask", + "--app", + app, + "run", + "--host", + LOCALHOST, + "--port", + str(python_port), + ], + python_env=[ + f"{k}={v}" + for k, v in { + BACKEND_STORE_URI_ENV_VAR: backend_uri, + ARTIFACT_ROOT_ENV_VAR: root_artifact_uri, + **(extra_env or {}), + }.items() + ], + shutdown_timeout="5s", + tracking_store_uri=backend_uri, + ): + _await_server_up_or_die(server_port) + yield url diff --git a/tests/override_test_sqlalchemy_store.py b/tests/override_test_sqlalchemy_store.py index 4ed89cb..13cad4b 100644 --- a/tests/override_test_sqlalchemy_store.py +++ b/tests/override_test_sqlalchemy_store.py @@ -1,17 +1,17 @@ -from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore - - -def test_log_batch_internal_error(store: 
SqlAlchemyStore): - () - - -def test_log_batch_params_max_length_value(store: SqlAlchemyStore, monkeypatch): - () - - -def test_log_batch_null_metrics(store: SqlAlchemyStore): - () - - -def test_sqlalchemy_store_behaves_as_expected_with_inmemory_sqlite_db(monkeypatch): - () +from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore + + +def test_log_batch_internal_error(store: SqlAlchemyStore): + () + + +def test_log_batch_params_max_length_value(store: SqlAlchemyStore, monkeypatch): + () + + +def test_log_batch_null_metrics(store: SqlAlchemyStore): + () + + +def test_sqlalchemy_store_behaves_as_expected_with_inmemory_sqlite_db(monkeypatch): + () diff --git a/tests/override_tracking_store.py b/tests/override_tracking_store.py index 05dcb56..26a7577 100644 --- a/tests/override_tracking_store.py +++ b/tests/override_tracking_store.py @@ -1,5 +1,5 @@ -from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore - -from mlflow_go.store.tracking import TrackingStore - -SqlAlchemyStore = TrackingStore(SqlAlchemyStore) +from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore + +from mlflow_go.store.tracking import TrackingStore + +SqlAlchemyStore = TrackingStore(SqlAlchemyStore) From 2cc3ee1da8d0eae8a3e9a197c989093ae1225d96 Mon Sep 17 00:00:00 2001 From: Juan Escalada Date: Tue, 8 Oct 2024 01:27:30 +0000 Subject: [PATCH 03/24] Update mock store Signed-off-by: Juan Escalada --- pkg/tracking/store/mock_tracking_store.go | 51 +++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/pkg/tracking/store/mock_tracking_store.go b/pkg/tracking/store/mock_tracking_store.go index c6d52be..d5c73b1 100644 --- a/pkg/tracking/store/mock_tracking_store.go +++ b/pkg/tracking/store/mock_tracking_store.go @@ -868,6 +868,57 @@ func (_c *MockTrackingStore_SearchRuns_Call) RunAndReturn(run func(context.Conte return _c } +// SetTag provides a mock function with given fields: ctx, runID, key, value +func (_m *MockTrackingStore) SetTag(ctx context.Context, 
runID string, key string, value string) *contract.Error { + ret := _m.Called(ctx, runID, key, value) + + if len(ret) == 0 { + panic("no return value specified for SetTag") + } + + var r0 *contract.Error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) *contract.Error); ok { + r0 = rf(ctx, runID, key, value) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*contract.Error) + } + } + + return r0 +} + +// MockTrackingStore_SetTag_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetTag' +type MockTrackingStore_SetTag_Call struct { + *mock.Call +} + +// SetTag is a helper method to define mock.On call +// - ctx context.Context +// - runID string +// - key string +// - value string +func (_e *MockTrackingStore_Expecter) SetTag(ctx interface{}, runID interface{}, key interface{}, value interface{}) *MockTrackingStore_SetTag_Call { + return &MockTrackingStore_SetTag_Call{Call: _e.mock.On("SetTag", ctx, runID, key, value)} +} + +func (_c *MockTrackingStore_SetTag_Call) Run(run func(ctx context.Context, runID string, key string, value string)) *MockTrackingStore_SetTag_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string)) + }) + return _c +} + +func (_c *MockTrackingStore_SetTag_Call) Return(_a0 *contract.Error) *MockTrackingStore_SetTag_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockTrackingStore_SetTag_Call) RunAndReturn(run func(context.Context, string, string, string) *contract.Error) *MockTrackingStore_SetTag_Call { + _c.Call.Return(run) + return _c +} + // UpdateRun provides a mock function with given fields: ctx, runID, runStatus, endTime, runName func (_m *MockTrackingStore) UpdateRun(ctx context.Context, runID string, runStatus string, endTime *int64, runName string) *contract.Error { ret := _m.Called(ctx, runID, runStatus, endTime, runName) From 59f30ac567311133622c4b7ae8626f579feebc58 Mon Sep 17 
00:00:00 2001 From: Juan Escalada Date: Tue, 8 Oct 2024 01:29:25 +0000 Subject: [PATCH 04/24] Clean up unused code Signed-off-by: Juan Escalada --- .devcontainer/devcontainer.json | 2 +- pkg/tracking/service/tags.go | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 6fcc3ba..2498fab 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -62,5 +62,5 @@ "postCreateCommand": ".devcontainer/postCreate.sh", // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. - "remoteUser": "root" + // "remoteUser": "root" } \ No newline at end of file diff --git a/pkg/tracking/service/tags.go b/pkg/tracking/service/tags.go index 2f3464f..ba0456d 100644 --- a/pkg/tracking/service/tags.go +++ b/pkg/tracking/service/tags.go @@ -2,15 +2,12 @@ package service import ( "context" - "fmt" "github.com/mlflow/mlflow-go/pkg/contract" "github.com/mlflow/mlflow-go/pkg/protos" ) func (ts TrackingService) SetTag(ctx context.Context, input *protos.SetTag) (*protos.SetTag_Response, *contract.Error) { - // Print input - fmt.Println(input) if err := ts.Store.SetTag(ctx, input.GetRunId(), input.GetKey(), input.GetValue()); err != nil { return nil, err } From a3235d3453902e1b6b7a71bde0102a04adbf9683 Mon Sep 17 00:00:00 2001 From: Juan Escalada Date: Tue, 8 Oct 2024 03:09:22 +0000 Subject: [PATCH 05/24] Refactor tests.go Signed-off-by: Juan Escalada --- magefiles/tests.go | 58 +++++++++++++++++++++------------------------- 1 file changed, 26 insertions(+), 32 deletions(-) diff --git a/magefiles/tests.go b/magefiles/tests.go index 2e10272..d497617 100644 --- a/magefiles/tests.go +++ b/magefiles/tests.go @@ -28,8 +28,7 @@ func cleanUpMemoryFile() error { return nil } -// Run mlflow Python tests against the Go backend. 
-func (Test) Python() error { +func RunPythonTests(testFiles []string, testName string) error { libpath, err := os.MkdirTemp("", "") if err != nil { return err @@ -45,17 +44,22 @@ func (Test) Python() error { return nil } + args := []string{ + "--confcutdir=.", + } + args = append(args, testFiles...) + + // Add testName filter if provided + if testName != "" { + args = append(args, "-k", testName) + } else { + args = append(args, "-k", "not [file") + } + // Run the tests (currently just the server ones) if err := sh.RunWithV(map[string]string{ "MLFLOW_GO_LIBRARY_PATH": libpath, - }, "pytest", - "--confcutdir=.", - ".mlflow.repo/tests/tracking/test_rest_tracking.py", - ".mlflow.repo/tests/tracking/test_model_registry.py", - ".mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py", - ".mlflow.repo/tests/store/model_registry/test_sqlalchemy_store.py", - "-k", - "not [file", + }, "pytest", args..., // "-vv", ); err != nil { return err @@ -64,31 +68,21 @@ func (Test) Python() error { return nil } +// Run mlflow Python tests against the Go backend. +func (Test) Python() error { + return RunPythonTests([]string{ + ".mlflow.repo/tests/tracking/test_rest_tracking.py", + ".mlflow.repo/tests/tracking/test_model_registry.py", + ".mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py", + ".mlflow.repo/tests/store/model_registry/test_sqlalchemy_store.py", + }, "") +} + // Run specific Python test against the Go backend. 
func (Test) PythonSpecific(testName string) error { - libpath, err := os.MkdirTemp("", "") - if err != nil { - return err - } - - defer os.RemoveAll(libpath) - defer cleanUpMemoryFile() - - if err := sh.RunV("python", "-m", "mlflow_go.lib", ".", libpath); err != nil { - return nil - } - - if err := sh.RunWithV(map[string]string{ - "MLFLOW_GO_LIBRARY_PATH": libpath, - }, "pytest", - "--confcutdir=.", + return RunPythonTests([]string{ ".mlflow.repo/tests/tracking/test_rest_tracking.py", - "-k", testName, - ); err != nil { - return err - } - - return nil + }, testName) } // Run the Go unit tests. From b810993957e8920d8249f2e312e181a1fbb1c6f7 Mon Sep 17 00:00:00 2001 From: Juan Escalada Date: Tue, 8 Oct 2024 03:54:08 +0000 Subject: [PATCH 06/24] Remove unused validation code Signed-off-by: Juan Escalada --- pkg/tracking/store/sql/tags.go | 40 +--------------------------------- 1 file changed, 1 insertion(+), 39 deletions(-) diff --git a/pkg/tracking/store/sql/tags.go b/pkg/tracking/store/sql/tags.go index 9507133..5125882 100644 --- a/pkg/tracking/store/sql/tags.go +++ b/pkg/tracking/store/sql/tags.go @@ -82,41 +82,6 @@ func (s TrackingSQLStore) setTagsWithTransaction( return nil } -const ( - maxEntityKeyLength = 250 - maxTagValueLength = 8000 -) - -// Helper function to validate the tag key and value -func validateTag(key, value string) *contract.Error { - if key == "" { - return contract.NewError( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - "Missing value for required parameter 'key'", - ) - } - if len(key) > maxEntityKeyLength { - return contract.NewError( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - fmt.Sprintf("Tag key '%s' had length %d, which exceeded length limit of %d", key, len(key), maxEntityKeyLength), - ) - } - if len(value) > maxTagValueLength { - return contract.NewError( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - fmt.Sprintf("Tag value exceeded length limit of %d characters", maxTagValueLength), - ) - } - // TODO: Check if this is the correct way 
to prevent invalid values - if _, err := strconv.ParseFloat(value, 64); err == nil { - return contract.NewError( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - fmt.Sprintf("Invalid value %s for parameter 'value' supplied", value), - ) - } - return nil -} - func (s TrackingSQLStore) SetTag( ctx context.Context, runID, key, value string, ) *contract.Error { @@ -135,10 +100,6 @@ func (s TrackingSQLStore) SetTag( ) } - if err := validateTag(key, value); err != nil { - return err - } - err := s.db.WithContext(ctx).Transaction(func(transaction *gorm.DB) error { contractError := checkRunIsActive(transaction, runID) if contractError != nil { @@ -188,6 +149,7 @@ func (s TrackingSQLStore) SetTag( if errors.As(err, &contractError) { return contractError } + return contract.NewErrorWith( protos.ErrorCode_INTERNAL_ERROR, fmt.Sprintf("set tag transaction failed for %q", runID), From 084c471b6ab750f6a404f43c85ad15f1c80ae817 Mon Sep 17 00:00:00 2001 From: Juan Escalada Date: Wed, 9 Oct 2024 04:01:34 +0000 Subject: [PATCH 07/24] Add SetTag struct validation Signed-off-by: Juan Escalada --- pkg/validation/validation.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/pkg/validation/validation.go b/pkg/validation/validation.go index 6349abb..b0e98c3 100644 --- a/pkg/validation/validation.go +++ b/pkg/validation/validation.go @@ -107,6 +107,15 @@ func validateLogBatchLimits(structLevel validator.StructLevel) { } } +// SetTag must have either a run_id or a run_uuid present. 
+func validateSetTagRunIDExists(structLevel validator.StructLevel) { + tag, isTag := structLevel.Current().Interface().(*protos.SetTag) + + if isTag && tag.GetRunId() == "" && tag.GetRunUuid() == "" { + structLevel.ReportError(&tag, "run_id", "", "", "") + } +} + func truncateFn(fieldLevel validator.FieldLevel) bool { param := fieldLevel.Param() // Get the parameter from the tag @@ -188,6 +197,7 @@ func NewValidator() (*validator.Validate, error) { } validate.RegisterStructValidation(validateLogBatchLimits, &protos.LogBatch{}) + validate.RegisterStructValidation(validateSetTagRunIDExists, &protos.SetTag{}) return validate, nil } From 9039d3f3bf315dfa8a7dd8e09793495fb82f9c7d Mon Sep 17 00:00:00 2001 From: Juan Escalada Date: Wed, 9 Oct 2024 04:02:19 +0000 Subject: [PATCH 08/24] Replace runId requirement with manual check Signed-off-by: Juan Escalada --- magefiles/generate/validations.go | 7 +++---- pkg/protos/service.pb.go | 2 +- pkg/tracking/service/tags.go | 8 +++++++- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/magefiles/generate/validations.go b/magefiles/generate/validations.go index b7f40c2..ea3b9fb 100644 --- a/magefiles/generate/validations.go +++ b/magefiles/generate/validations.go @@ -25,8 +25,7 @@ var validations = map[string]string{ "LogMetric_Key": "required", "LogMetric_Value": "required", "LogMetric_Timestamp": "required", - "SetTag_RunId": "required", - "SetTag_Key": "required", - "DeleteTag_RunId": "required", - "DeleteTag_Key": "required", + "SetTag_Key": "required", + "DeleteTag_RunId": "required", + "DeleteTag_Key": "required", } diff --git a/pkg/protos/service.pb.go b/pkg/protos/service.pb.go index 8a388b5..e31d83b 100644 --- a/pkg/protos/service.pb.go +++ b/pkg/protos/service.pb.go @@ -2074,7 +2074,7 @@ type SetTag struct { unknownFields protoimpl.UnknownFields // ID of the run under which to log the tag. Must be provided. 
- RunId *string `protobuf:"bytes,4,opt,name=run_id,json=runId" json:"run_id,omitempty" query:"run_id" validate:"required"` + RunId *string `protobuf:"bytes,4,opt,name=run_id,json=runId" json:"run_id,omitempty" query:"run_id"` // [Deprecated, use run_id instead] ID of the run under which to log the tag. This field will // be removed in a future MLflow version. RunUuid *string `protobuf:"bytes,1,opt,name=run_uuid,json=runUuid" json:"run_uuid,omitempty" query:"run_uuid"` diff --git a/pkg/tracking/service/tags.go b/pkg/tracking/service/tags.go index ba0456d..3336fdf 100644 --- a/pkg/tracking/service/tags.go +++ b/pkg/tracking/service/tags.go @@ -8,7 +8,13 @@ import ( ) func (ts TrackingService) SetTag(ctx context.Context, input *protos.SetTag) (*protos.SetTag_Response, *contract.Error) { - if err := ts.Store.SetTag(ctx, input.GetRunId(), input.GetKey(), input.GetValue()); err != nil { + runID := input.GetRunId() + + if runID == "" { + runID = input.GetRunUuid() + } + + if err := ts.Store.SetTag(ctx, runID, input.GetKey(), input.GetValue()); err != nil { return nil, err } From fd44cbcba32bb75cad5a8a737fdfdef72cfbce54 Mon Sep 17 00:00:00 2001 From: Juan Escalada Date: Wed, 9 Oct 2024 04:02:39 +0000 Subject: [PATCH 09/24] Clean up SetTag store Signed-off-by: Juan Escalada --- pkg/tracking/store/sql/tags.go | 9 --------- 1 file changed, 9 deletions(-) diff --git a/pkg/tracking/store/sql/tags.go b/pkg/tracking/store/sql/tags.go index 5125882..aee4be8 100644 --- a/pkg/tracking/store/sql/tags.go +++ b/pkg/tracking/store/sql/tags.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "strconv" "gorm.io/gorm" "gorm.io/gorm/clause" @@ -92,14 +91,6 @@ func (s TrackingSQLStore) SetTag( ) } - // If the runID can be parsed as a number, it should throw an error - if _, err := strconv.ParseFloat(runID, 64); err == nil { - return contract.NewError( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - fmt.Sprintf("Invalid value %s for parameter 'run_id' supplied", runID), - ) - } - err := 
s.db.WithContext(ctx).Transaction(func(transaction *gorm.DB) error { contractError := checkRunIsActive(transaction, runID) if contractError != nil { From 4344210039012fb11c635e81f21331e49b017571 Mon Sep 17 00:00:00 2001 From: Juan Escalada Date: Wed, 9 Oct 2024 04:09:22 +0000 Subject: [PATCH 10/24] Minor adjustments Signed-off-by: Juan Escalada --- .devcontainer/devcontainer.json | 2 +- magefiles/tests.go | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 2498fab..2a3f0fb 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -59,7 +59,7 @@ // "forwardPorts": [5432], // Use 'postCreateCommand' to run commands after the container is created. - "postCreateCommand": ".devcontainer/postCreate.sh", + "postCreateCommand": ".devcontainer/postCreate.sh" // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. // "remoteUser": "root" diff --git a/magefiles/tests.go b/magefiles/tests.go index d497617..9919a8a 100644 --- a/magefiles/tests.go +++ b/magefiles/tests.go @@ -28,7 +28,7 @@ func cleanUpMemoryFile() error { return nil } -func RunPythonTests(testFiles []string, testName string) error { +func runPythonTests(testFiles []string, testName string) error { libpath, err := os.MkdirTemp("", "") if err != nil { return err @@ -70,7 +70,7 @@ func RunPythonTests(testFiles []string, testName string) error { // Run mlflow Python tests against the Go backend. func (Test) Python() error { - return RunPythonTests([]string{ + return runPythonTests([]string{ ".mlflow.repo/tests/tracking/test_rest_tracking.py", ".mlflow.repo/tests/tracking/test_model_registry.py", ".mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py", @@ -80,7 +80,7 @@ func (Test) Python() error { // Run specific Python test against the Go backend. 
func (Test) PythonSpecific(testName string) error { - return RunPythonTests([]string{ + return runPythonTests([]string{ ".mlflow.repo/tests/tracking/test_rest_tracking.py", }, testName) } From 02ed6d9bbcb23b42dbba1c91f4897f61cbe6f3c1 Mon Sep 17 00:00:00 2001 From: Juan Escalada Date: Wed, 9 Oct 2024 04:17:20 +0000 Subject: [PATCH 11/24] Update validations Signed-off-by: Juan Escalada --- magefiles/generate/validations.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/magefiles/generate/validations.go b/magefiles/generate/validations.go index ea3b9fb..46d733d 100644 --- a/magefiles/generate/validations.go +++ b/magefiles/generate/validations.go @@ -25,7 +25,8 @@ var validations = map[string]string{ "LogMetric_Key": "required", "LogMetric_Value": "required", "LogMetric_Timestamp": "required", - "SetTag_Key": "required", + "SetTag_Key": "required,max=1000,validMetricParamOrTagName,pathIsUnique", + "SetTag_Value": "required,max=8000", "DeleteTag_RunId": "required", "DeleteTag_Key": "required", } From bc25c96a25beac815d722845e49c32e50c33f1ad Mon Sep 17 00:00:00 2001 From: Juan Escalada Date: Fri, 11 Oct 2024 08:55:35 +0000 Subject: [PATCH 12/24] Add preliminary SetTag missing logic Signed-off-by: Juan Escalada --- magefiles/generate/validations.go | 2 +- magefiles/tests.go | 7 +++-- pkg/tracking/store/sql/tags.go | 52 ++++++++++++++++++++++++++++--- 3 files changed, 54 insertions(+), 7 deletions(-) diff --git a/magefiles/generate/validations.go b/magefiles/generate/validations.go index 46d733d..2950516 100644 --- a/magefiles/generate/validations.go +++ b/magefiles/generate/validations.go @@ -26,7 +26,7 @@ var validations = map[string]string{ "LogMetric_Value": "required", "LogMetric_Timestamp": "required", "SetTag_Key": "required,max=1000,validMetricParamOrTagName,pathIsUnique", - "SetTag_Value": "required,max=8000", + "SetTag_Value": "required,truncate=8000", "DeleteTag_RunId": "required", "DeleteTag_Key": "required", } diff --git 
a/magefiles/tests.go b/magefiles/tests.go index 9919a8a..5f466fd 100644 --- a/magefiles/tests.go +++ b/magefiles/tests.go @@ -51,9 +51,9 @@ func runPythonTests(testFiles []string, testName string) error { // Add testName filter if provided if testName != "" { - args = append(args, "-k", testName) + args = append(args, "-k", testName, "-v") } else { - args = append(args, "-k", "not [file") + args = append(args, "-k", "not [file", "-v") } // Run the tests (currently just the server ones) @@ -82,6 +82,9 @@ func (Test) Python() error { func (Test) PythonSpecific(testName string) error { return runPythonTests([]string{ ".mlflow.repo/tests/tracking/test_rest_tracking.py", + ".mlflow.repo/tests/tracking/test_model_registry.py", + ".mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py", + ".mlflow.repo/tests/store/model_registry/test_sqlalchemy_store.py", }, testName) } diff --git a/pkg/tracking/store/sql/tags.go b/pkg/tracking/store/sql/tags.go index aee4be8..06a338e 100644 --- a/pkg/tracking/store/sql/tags.go +++ b/pkg/tracking/store/sql/tags.go @@ -84,7 +84,11 @@ func (s TrackingSQLStore) setTagsWithTransaction( func (s TrackingSQLStore) SetTag( ctx context.Context, runID, key, value string, ) *contract.Error { + // Retrieve the logger from the context + logger := utils.GetLoggerFromContext(ctx) + if runID == "" { + logger.Info("RunID cannot be empty") return contract.NewError( protos.ErrorCode_INVALID_PARAMETER_VALUE, "RunID cannot be empty", @@ -94,13 +98,51 @@ func (s TrackingSQLStore) SetTag( err := s.db.WithContext(ctx).Transaction(func(transaction *gorm.DB) error { contractError := checkRunIsActive(transaction, runID) if contractError != nil { + logger.Info("Run is not active") return contractError } + if key == utils.TagRunName { + var run models.Run + result := transaction.Where("run_uuid = ?", runID).First(&run) + + if result.Error != nil { + logger.Info("Failed to query run for run_id %q", runID) + return contract.NewErrorWith( + 
protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf("Failed to query run for run_id %q", runID), + result.Error, + ) + } + + runStatus := run.Status.String() + + var endTimePtr *int64 + if run.EndTime.Valid { + endTimePtr = &run.EndTime.Int64 + } + + logger.Info("Updating run info for run_id %q", runID) + if err := s.UpdateRun(ctx, runID, runStatus, endTimePtr, value); err != nil { + logger.Printf("Failed to update run info for run_id %q", runID) + return contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf("Failed to update run info for run_id %q", runID), + err, + ) + } + + return nil + } + + // Logging tag update + logger.Info("Setting tag for run_id %q", runID) + var tag models.Tag result := transaction.Where("run_uuid = ? AND key = ?", runID, key).First(&tag) if result.Error != nil && !errors.Is(result.Error, gorm.ErrRecordNotFound) { + logger.Printf("Failed to query tag for run_id %q and key %q", runID, key) return contract.NewErrorWith( protos.ErrorCode_INTERNAL_ERROR, fmt.Sprintf("Failed to query tag for run_id %q and key %q", runID, key), @@ -111,6 +153,7 @@ func (s TrackingSQLStore) SetTag( if result.RowsAffected == 1 { tag.Value = value if err := transaction.Save(&tag).Error; err != nil { + logger.Printf("Failed to update tag for run_id %q and key %q", runID, key) return contract.NewErrorWith( protos.ErrorCode_INTERNAL_ERROR, fmt.Sprintf("Failed to update tag for run_id %q and key %q", runID, key), @@ -120,10 +163,11 @@ func (s TrackingSQLStore) SetTag( } else { newTag := models.Tag{ RunID: runID, - Key: key, - Value: value, + Key: key, + Value: value, } if err := transaction.Create(&newTag).Error; err != nil { + logger.Printf("Failed to create tag for run_id %q and key %q", runID, key) return contract.NewErrorWith( protos.ErrorCode_INTERNAL_ERROR, fmt.Sprintf("Failed to create tag for run_id %q and key %q", runID, key), @@ -134,8 +178,8 @@ func (s TrackingSQLStore) SetTag( return nil }) - if err != nil { + logger.Printf("SetTag transaction 
failed for run_id %q", runID) var contractError *contract.Error if errors.As(err, &contractError) { return contractError @@ -212,4 +256,4 @@ func (s TrackingSQLStore) DeleteTag( } return nil -} \ No newline at end of file +} From 7086ea1ac09a07fd7e966f83d356f06cbe8777cf Mon Sep 17 00:00:00 2001 From: Juan Escalada Date: Fri, 11 Oct 2024 08:59:06 +0000 Subject: [PATCH 13/24] Add missing generated file Signed-off-by: Juan Escalada --- pkg/protos/service.pb.go | 984 +++++++++++++++++++-------------------- 1 file changed, 477 insertions(+), 507 deletions(-) diff --git a/pkg/protos/service.pb.go b/pkg/protos/service.pb.go index e31d83b..0039e7b 100644 --- a/pkg/protos/service.pb.go +++ b/pkg/protos/service.pb.go @@ -2080,10 +2080,10 @@ type SetTag struct { RunUuid *string `protobuf:"bytes,1,opt,name=run_uuid,json=runUuid" json:"run_uuid,omitempty" query:"run_uuid"` // Name of the tag. Maximum size depends on storage backend. // All storage backends are guaranteed to support key values up to 250 bytes in size. - Key *string `protobuf:"bytes,2,opt,name=key" json:"key,omitempty" query:"key" validate:"required"` + Key *string `protobuf:"bytes,2,opt,name=key" json:"key,omitempty" query:"key" validate:"required,max=1000,validMetricParamOrTagName,pathIsUnique"` // String value of the tag being logged. Maximum size depends on storage backend. // All storage backends are guaranteed to support key values up to 5000 bytes in size. 
- Value *string `protobuf:"bytes,3,opt,name=value" json:"value,omitempty" query:"value"` + Value *string `protobuf:"bytes,3,opt,name=value" json:"value,omitempty" query:"value" validate:"required,truncate=8000"` } func (x *SetTag) Reset() { @@ -5661,7 +5661,7 @@ var file_service_proto_rawDesc = []byte{ 0x03, 0x52, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x12, 0x15, 0x0a, 0x04, 0x73, 0x74, 0x65, 0x70, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x3a, 0x01, 0x30, 0x52, 0x04, 0x73, 0x74, 0x65, 0x70, 0x12, 0x15, 0x0a, 0x06, 0x72, 0x75, 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x05, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x05, 0x72, 0x75, 0x6e, 0x49, 0x64, 0x22, 0xd4, 0x02, 0x0a, 0x1c, 0x47, + 0x01, 0x28, 0x09, 0x52, 0x05, 0x72, 0x75, 0x6e, 0x49, 0x64, 0x22, 0xa3, 0x02, 0x0a, 0x1c, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x48, 0x69, 0x73, 0x74, 0x6f, 0x72, 0x79, 0x42, 0x75, 0x6c, 0x6b, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x12, 0x17, 0x0a, 0x07, 0x72, 0x75, 0x6e, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x72, 0x75, @@ -5677,523 +5677,493 @@ var file_service_proto_rawDesc = []byte{ 0x12, 0x31, 0x0a, 0x07, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x57, 0x69, 0x74, 0x68, 0x52, 0x75, 0x6e, 0x49, 0x64, 0x52, 0x07, 0x6d, 0x65, 0x74, 0x72, - 0x69, 0x63, 0x73, 0x3a, 0x5c, 0xe2, 0x3f, 0x59, 0x0a, 0x26, 0x63, 0x6f, 0x6d, 0x2e, 0x64, 0x61, + 0x69, 0x63, 0x73, 0x3a, 0x2b, 0xe2, 0x3f, 0x28, 0x0a, 0x26, 0x63, 0x6f, 0x6d, 0x2e, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x69, 0x63, 0x6b, 0x73, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x50, 0x43, 0x5b, 0x24, 0x74, 0x68, 0x69, 0x73, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x5d, - 0x0a, 0x2f, 0x63, 0x6f, 0x6d, 0x2e, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x69, 0x63, 0x6b, 0x73, - 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x4d, 
0x6c, 0x66, 0x6c, - 0x6f, 0x77, 0x54, 0x72, 0x61, 0x63, 0x6b, 0x69, 0x6e, 0x67, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, - 0x65, 0x22, 0xcf, 0x01, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x42, 0x61, 0x74, 0x63, 0x68, 0x12, 0x15, + 0x22, 0xcf, 0x01, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x42, 0x61, 0x74, 0x63, 0x68, 0x12, 0x15, 0x0a, + 0x06, 0x72, 0x75, 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x72, + 0x75, 0x6e, 0x49, 0x64, 0x12, 0x28, 0x0a, 0x07, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x18, + 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x4d, + 0x65, 0x74, 0x72, 0x69, 0x63, 0x52, 0x07, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x12, 0x25, + 0x0a, 0x06, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0d, + 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x52, 0x06, 0x70, + 0x61, 0x72, 0x61, 0x6d, 0x73, 0x12, 0x22, 0x0a, 0x04, 0x74, 0x61, 0x67, 0x73, 0x18, 0x04, 0x20, + 0x03, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x52, 0x75, 0x6e, + 0x54, 0x61, 0x67, 0x52, 0x04, 0x74, 0x61, 0x67, 0x73, 0x1a, 0x0a, 0x0a, 0x08, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x3a, 0x2b, 0xe2, 0x3f, 0x28, 0x0a, 0x26, 0x63, 0x6f, 0x6d, 0x2e, + 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x69, 0x63, 0x6b, 0x73, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, + 0x50, 0x43, 0x5b, 0x24, 0x74, 0x68, 0x69, 0x73, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x5d, 0x22, 0x79, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x12, 0x15, 0x0a, 0x06, 0x72, 0x75, 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, - 0x72, 0x75, 0x6e, 0x49, 0x64, 0x12, 0x28, 0x0a, 0x07, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, - 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, - 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x52, 0x07, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x12, - 0x25, 0x0a, 0x06, 
0x70, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, - 0x0d, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x52, 0x06, - 0x70, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x12, 0x22, 0x0a, 0x04, 0x74, 0x61, 0x67, 0x73, 0x18, 0x04, - 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x52, 0x75, - 0x6e, 0x54, 0x61, 0x67, 0x52, 0x04, 0x74, 0x61, 0x67, 0x73, 0x1a, 0x0a, 0x0a, 0x08, 0x52, 0x65, + 0x72, 0x75, 0x6e, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x5f, 0x6a, + 0x73, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x6d, 0x6f, 0x64, 0x65, 0x6c, + 0x4a, 0x73, 0x6f, 0x6e, 0x1a, 0x0a, 0x0a, 0x08, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x3a, 0x2b, 0xe2, 0x3f, 0x28, 0x0a, 0x26, 0x63, 0x6f, 0x6d, 0x2e, 0x64, 0x61, 0x74, 0x61, 0x62, + 0x72, 0x69, 0x63, 0x6b, 0x73, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x50, 0x43, 0x5b, 0x24, 0x74, + 0x68, 0x69, 0x73, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x5d, 0x22, 0x93, 0x01, + 0x0a, 0x09, 0x4c, 0x6f, 0x67, 0x49, 0x6e, 0x70, 0x75, 0x74, 0x73, 0x12, 0x1b, 0x0a, 0x06, 0x72, + 0x75, 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x42, 0x04, 0xf8, 0x86, 0x19, + 0x01, 0x52, 0x05, 0x72, 0x75, 0x6e, 0x49, 0x64, 0x12, 0x30, 0x0a, 0x08, 0x64, 0x61, 0x74, 0x61, + 0x73, 0x65, 0x74, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x6c, 0x66, + 0x6c, 0x6f, 0x77, 0x2e, 0x44, 0x61, 0x74, 0x61, 0x73, 0x65, 0x74, 0x49, 0x6e, 0x70, 0x75, 0x74, + 0x52, 0x08, 0x64, 0x61, 0x74, 0x61, 0x73, 0x65, 0x74, 0x73, 0x1a, 0x0a, 0x0a, 0x08, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x3a, 0x2b, 0xe2, 0x3f, 0x28, 0x0a, 0x26, 0x63, 0x6f, 0x6d, 0x2e, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x69, 0x63, 0x6b, 0x73, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x50, 0x43, 0x5b, 0x24, 0x74, 0x68, 0x69, 0x73, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, - 0x73, 0x65, 0x5d, 0x22, 0x79, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 
0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x12, - 0x15, 0x0a, 0x06, 0x72, 0x75, 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x05, 0x72, 0x75, 0x6e, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x5f, - 0x6a, 0x73, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x6d, 0x6f, 0x64, 0x65, - 0x6c, 0x4a, 0x73, 0x6f, 0x6e, 0x1a, 0x0a, 0x0a, 0x08, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, - 0x65, 0x3a, 0x2b, 0xe2, 0x3f, 0x28, 0x0a, 0x26, 0x63, 0x6f, 0x6d, 0x2e, 0x64, 0x61, 0x74, 0x61, - 0x62, 0x72, 0x69, 0x63, 0x6b, 0x73, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x50, 0x43, 0x5b, 0x24, - 0x74, 0x68, 0x69, 0x73, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x5d, 0x22, 0xc4, - 0x01, 0x0a, 0x09, 0x4c, 0x6f, 0x67, 0x49, 0x6e, 0x70, 0x75, 0x74, 0x73, 0x12, 0x1b, 0x0a, 0x06, - 0x72, 0x75, 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x42, 0x04, 0xf8, 0x86, - 0x19, 0x01, 0x52, 0x05, 0x72, 0x75, 0x6e, 0x49, 0x64, 0x12, 0x30, 0x0a, 0x08, 0x64, 0x61, 0x74, - 0x61, 0x73, 0x65, 0x74, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x6c, - 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x44, 0x61, 0x74, 0x61, 0x73, 0x65, 0x74, 0x49, 0x6e, 0x70, 0x75, - 0x74, 0x52, 0x08, 0x64, 0x61, 0x74, 0x61, 0x73, 0x65, 0x74, 0x73, 0x1a, 0x0a, 0x0a, 0x08, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x3a, 0x5c, 0xe2, 0x3f, 0x59, 0x0a, 0x26, 0x63, 0x6f, - 0x6d, 0x2e, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x69, 0x63, 0x6b, 0x73, 0x2e, 0x72, 0x70, 0x63, - 0x2e, 0x52, 0x50, 0x43, 0x5b, 0x24, 0x74, 0x68, 0x69, 0x73, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x5d, 0x0a, 0x2f, 0x63, 0x6f, 0x6d, 0x2e, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, - 0x69, 0x63, 0x6b, 0x73, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x61, 0x70, 0x69, 0x2e, - 0x4d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x54, 0x72, 0x61, 0x63, 0x6b, 0x69, 0x6e, 0x67, 0x4d, 0x65, - 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0xb1, 0x01, 0x0a, 0x13, 0x47, 0x65, 0x74, 0x45, 0x78, 0x70, - 
0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x42, 0x79, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x2d, 0x0a, - 0x0f, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x6e, 0x61, 0x6d, 0x65, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x42, 0x04, 0xf8, 0x86, 0x19, 0x01, 0x52, 0x0e, 0x65, 0x78, - 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x4e, 0x61, 0x6d, 0x65, 0x1a, 0x3e, 0x0a, 0x08, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x32, 0x0a, 0x0a, 0x65, 0x78, 0x70, 0x65, - 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x6d, - 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, - 0x52, 0x0a, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x3a, 0x2b, 0xe2, 0x3f, + 0x73, 0x65, 0x5d, 0x22, 0xb1, 0x01, 0x0a, 0x13, 0x47, 0x65, 0x74, 0x45, 0x78, 0x70, 0x65, 0x72, + 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x42, 0x79, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x2d, 0x0a, 0x0f, 0x65, + 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x09, 0x42, 0x04, 0xf8, 0x86, 0x19, 0x01, 0x52, 0x0e, 0x65, 0x78, 0x70, 0x65, + 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x4e, 0x61, 0x6d, 0x65, 0x1a, 0x3e, 0x0a, 0x08, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x32, 0x0a, 0x0a, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, + 0x6d, 0x65, 0x6e, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x6d, 0x6c, 0x66, + 0x6c, 0x6f, 0x77, 0x2e, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x0a, + 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x3a, 0x2b, 0xe2, 0x3f, 0x28, 0x0a, + 0x26, 0x63, 0x6f, 0x6d, 0x2e, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x69, 0x63, 0x6b, 0x73, 0x2e, + 0x72, 0x70, 0x63, 0x2e, 0x52, 0x50, 0x43, 0x5b, 0x24, 0x74, 0x68, 0x69, 0x73, 0x2e, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x5d, 0x22, 0xba, 0x02, 0x0a, 0x09, 0x54, 0x72, 0x61, 0x63, + 0x65, 0x49, 0x6e, 0x66, 0x6f, 0x12, 
0x1d, 0x0a, 0x0a, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x72, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x49, 0x64, 0x12, 0x23, 0x0a, 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, + 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x65, 0x78, 0x70, + 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x21, 0x0a, 0x0c, 0x74, 0x69, 0x6d, + 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x5f, 0x6d, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, + 0x0b, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x4d, 0x73, 0x12, 0x2a, 0x0a, 0x11, + 0x65, 0x78, 0x65, 0x63, 0x75, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x5f, 0x6d, + 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0f, 0x65, 0x78, 0x65, 0x63, 0x75, 0x74, 0x69, + 0x6f, 0x6e, 0x54, 0x69, 0x6d, 0x65, 0x4d, 0x73, 0x12, 0x2b, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, + 0x75, 0x73, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x13, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, + 0x77, 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x06, 0x73, + 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x47, 0x0a, 0x10, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x5f, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, + 0x1c, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x52, 0x0f, 0x72, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x12, 0x24, + 0x0a, 0x04, 0x74, 0x61, 0x67, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x10, 0x2e, 0x6d, + 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x54, 0x61, 0x67, 0x52, 0x04, + 0x74, 0x61, 0x67, 0x73, 0x22, 0x3e, 0x0a, 0x14, 0x54, 0x72, 0x61, 0x63, 0x65, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 
0x12, 0x10, 0x0a, 0x03, + 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, + 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, + 0x61, 0x6c, 0x75, 0x65, 0x22, 0x32, 0x0a, 0x08, 0x54, 0x72, 0x61, 0x63, 0x65, 0x54, 0x61, 0x67, + 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, + 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x22, 0xae, 0x02, 0x0a, 0x0a, 0x53, 0x74, 0x61, + 0x72, 0x74, 0x54, 0x72, 0x61, 0x63, 0x65, 0x12, 0x23, 0x0a, 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, + 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, + 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x21, 0x0a, 0x0c, + 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x5f, 0x6d, 0x73, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x03, 0x52, 0x0b, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x4d, 0x73, 0x12, + 0x47, 0x0a, 0x10, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x6d, 0x65, 0x74, 0x61, 0x64, + 0x61, 0x74, 0x61, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, + 0x6f, 0x77, 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x4d, + 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x52, 0x0f, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x12, 0x24, 0x0a, 0x04, 0x74, 0x61, 0x67, 0x73, + 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x10, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, + 0x54, 0x72, 0x61, 0x63, 0x65, 0x54, 0x61, 0x67, 0x52, 0x04, 0x74, 0x61, 0x67, 0x73, 0x1a, 0x3c, + 0x0a, 0x08, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x30, 0x0a, 0x0a, 0x74, 0x72, + 0x61, 0x63, 0x65, 0x5f, 0x69, 0x6e, 0x66, 0x6f, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x11, + 0x2e, 0x6d, 
0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x49, 0x6e, 0x66, + 0x6f, 0x52, 0x09, 0x74, 0x72, 0x61, 0x63, 0x65, 0x49, 0x6e, 0x66, 0x6f, 0x3a, 0x2b, 0xe2, 0x3f, 0x28, 0x0a, 0x26, 0x63, 0x6f, 0x6d, 0x2e, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x69, 0x63, 0x6b, 0x73, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x50, 0x43, 0x5b, 0x24, 0x74, 0x68, 0x69, 0x73, 0x2e, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x5d, 0x22, 0xba, 0x02, 0x0a, 0x09, 0x54, 0x72, - 0x61, 0x63, 0x65, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1d, 0x0a, 0x0a, 0x72, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x72, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x49, 0x64, 0x12, 0x23, 0x0a, 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, - 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x65, - 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x21, 0x0a, 0x0c, 0x74, - 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x5f, 0x6d, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, - 0x03, 0x52, 0x0b, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x4d, 0x73, 0x12, 0x2a, - 0x0a, 0x11, 0x65, 0x78, 0x65, 0x63, 0x75, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x74, 0x69, 0x6d, 0x65, - 0x5f, 0x6d, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0f, 0x65, 0x78, 0x65, 0x63, 0x75, - 0x74, 0x69, 0x6f, 0x6e, 0x54, 0x69, 0x6d, 0x65, 0x4d, 0x73, 0x12, 0x2b, 0x0a, 0x06, 0x73, 0x74, - 0x61, 0x74, 0x75, 0x73, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x13, 0x2e, 0x6d, 0x6c, 0x66, - 0x6c, 0x6f, 0x77, 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, - 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x47, 0x0a, 0x10, 0x72, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x5f, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x06, 0x20, 0x03, 0x28, - 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x4d, 0x65, 
0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x52, - 0x0f, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, - 0x12, 0x24, 0x0a, 0x04, 0x74, 0x61, 0x67, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x10, - 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x54, 0x61, 0x67, - 0x52, 0x04, 0x74, 0x61, 0x67, 0x73, 0x22, 0x3e, 0x0a, 0x14, 0x54, 0x72, 0x61, 0x63, 0x65, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x12, 0x10, - 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, - 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x22, 0x32, 0x0a, 0x08, 0x54, 0x72, 0x61, 0x63, 0x65, 0x54, - 0x61, 0x67, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x22, 0xdf, 0x02, 0x0a, 0x0a, 0x53, - 0x74, 0x61, 0x72, 0x74, 0x54, 0x72, 0x61, 0x63, 0x65, 0x12, 0x23, 0x0a, 0x0d, 0x65, 0x78, 0x70, - 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x0c, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x21, - 0x0a, 0x0c, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x5f, 0x6d, 0x73, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x4d, - 0x73, 0x12, 0x47, 0x0a, 0x10, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x6d, 0x65, 0x74, - 0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x6c, - 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x52, 0x0f, 0x72, 0x65, 0x71, 0x75, 
0x65, - 0x73, 0x74, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x12, 0x24, 0x0a, 0x04, 0x74, 0x61, - 0x67, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x10, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, - 0x77, 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x54, 0x61, 0x67, 0x52, 0x04, 0x74, 0x61, 0x67, 0x73, - 0x1a, 0x3c, 0x0a, 0x08, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x30, 0x0a, 0x0a, - 0x74, 0x72, 0x61, 0x63, 0x65, 0x5f, 0x69, 0x6e, 0x66, 0x6f, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, - 0x32, 0x11, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x49, - 0x6e, 0x66, 0x6f, 0x52, 0x09, 0x74, 0x72, 0x61, 0x63, 0x65, 0x49, 0x6e, 0x66, 0x6f, 0x3a, 0x5c, - 0xe2, 0x3f, 0x59, 0x0a, 0x26, 0x63, 0x6f, 0x6d, 0x2e, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x69, - 0x63, 0x6b, 0x73, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x50, 0x43, 0x5b, 0x24, 0x74, 0x68, 0x69, - 0x73, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x5d, 0x0a, 0x2f, 0x63, 0x6f, 0x6d, - 0x2e, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x69, 0x63, 0x6b, 0x73, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, - 0x6f, 0x77, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x4d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x54, 0x72, 0x61, - 0x63, 0x6b, 0x69, 0x6e, 0x67, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x84, 0x03, 0x0a, - 0x08, 0x45, 0x6e, 0x64, 0x54, 0x72, 0x61, 0x63, 0x65, 0x12, 0x1d, 0x0a, 0x0a, 0x72, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x72, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x49, 0x64, 0x12, 0x21, 0x0a, 0x0c, 0x74, 0x69, 0x6d, 0x65, - 0x73, 0x74, 0x61, 0x6d, 0x70, 0x5f, 0x6d, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, - 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x4d, 0x73, 0x12, 0x2b, 0x0a, 0x06, 0x73, - 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x13, 0x2e, 0x6d, 0x6c, - 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, - 0x52, 0x06, 0x73, 0x74, 0x61, 
0x74, 0x75, 0x73, 0x12, 0x47, 0x0a, 0x10, 0x72, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x5f, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x04, 0x20, 0x03, - 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x54, 0x72, 0x61, 0x63, - 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, - 0x52, 0x0f, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, - 0x61, 0x12, 0x24, 0x0a, 0x04, 0x74, 0x61, 0x67, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, - 0x10, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x54, 0x61, - 0x67, 0x52, 0x04, 0x74, 0x61, 0x67, 0x73, 0x1a, 0x3c, 0x0a, 0x08, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x12, 0x30, 0x0a, 0x0a, 0x74, 0x72, 0x61, 0x63, 0x65, 0x5f, 0x69, 0x6e, 0x66, - 0x6f, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, - 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x09, 0x74, 0x72, 0x61, 0x63, - 0x65, 0x49, 0x6e, 0x66, 0x6f, 0x3a, 0x5c, 0xe2, 0x3f, 0x59, 0x0a, 0x26, 0x63, 0x6f, 0x6d, 0x2e, - 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x69, 0x63, 0x6b, 0x73, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, - 0x50, 0x43, 0x5b, 0x24, 0x74, 0x68, 0x69, 0x73, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, - 0x65, 0x5d, 0x0a, 0x2f, 0x63, 0x6f, 0x6d, 0x2e, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x69, 0x63, - 0x6b, 0x73, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x4d, 0x6c, - 0x66, 0x6c, 0x6f, 0x77, 0x54, 0x72, 0x61, 0x63, 0x6b, 0x69, 0x6e, 0x67, 0x4d, 0x65, 0x73, 0x73, - 0x61, 0x67, 0x65, 0x22, 0xc9, 0x01, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x54, 0x72, 0x61, 0x63, 0x65, - 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1d, 0x0a, 0x0a, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, - 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x49, 0x64, 0x1a, 0x3c, 0x0a, 0x08, 0x52, 0x65, 0x73, 0x70, 
0x6f, 0x6e, 0x73, 0x65, 0x12, - 0x30, 0x0a, 0x0a, 0x74, 0x72, 0x61, 0x63, 0x65, 0x5f, 0x69, 0x6e, 0x66, 0x6f, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x54, 0x72, 0x61, - 0x63, 0x65, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x09, 0x74, 0x72, 0x61, 0x63, 0x65, 0x49, 0x6e, 0x66, - 0x6f, 0x3a, 0x5c, 0xe2, 0x3f, 0x59, 0x0a, 0x26, 0x63, 0x6f, 0x6d, 0x2e, 0x64, 0x61, 0x74, 0x61, - 0x62, 0x72, 0x69, 0x63, 0x6b, 0x73, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x50, 0x43, 0x5b, 0x24, - 0x74, 0x68, 0x69, 0x73, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x5d, 0x0a, 0x2f, - 0x63, 0x6f, 0x6d, 0x2e, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x69, 0x63, 0x6b, 0x73, 0x2e, 0x6d, - 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x4d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, - 0x54, 0x72, 0x61, 0x63, 0x6b, 0x69, 0x6e, 0x67, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, - 0xea, 0x02, 0x0a, 0x0c, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x54, 0x72, 0x61, 0x63, 0x65, 0x73, - 0x12, 0x25, 0x0a, 0x0e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, - 0x64, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, - 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x66, 0x69, 0x6c, 0x74, 0x65, - 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x12, - 0x24, 0x0a, 0x0b, 0x6d, 0x61, 0x78, 0x5f, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x18, 0x03, - 0x20, 0x01, 0x28, 0x05, 0x3a, 0x03, 0x31, 0x30, 0x30, 0x52, 0x0a, 0x6d, 0x61, 0x78, 0x52, 0x65, - 0x73, 0x75, 0x6c, 0x74, 0x73, 0x12, 0x19, 0x0a, 0x08, 0x6f, 0x72, 0x64, 0x65, 0x72, 0x5f, 0x62, - 0x79, 0x18, 0x04, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x6f, 0x72, 0x64, 0x65, 0x72, 0x42, 0x79, - 0x12, 0x1d, 0x0a, 0x0a, 0x70, 0x61, 0x67, 0x65, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x05, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x70, 0x61, 0x67, 0x65, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x1a, - 0x5d, 
0x0a, 0x08, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x29, 0x0a, 0x06, 0x74, - 0x72, 0x61, 0x63, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x6c, - 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x06, - 0x74, 0x72, 0x61, 0x63, 0x65, 0x73, 0x12, 0x26, 0x0a, 0x0f, 0x6e, 0x65, 0x78, 0x74, 0x5f, 0x70, - 0x61, 0x67, 0x65, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x0d, 0x6e, 0x65, 0x78, 0x74, 0x50, 0x61, 0x67, 0x65, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x3a, 0x5c, - 0xe2, 0x3f, 0x59, 0x0a, 0x26, 0x63, 0x6f, 0x6d, 0x2e, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x69, - 0x63, 0x6b, 0x73, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x50, 0x43, 0x5b, 0x24, 0x74, 0x68, 0x69, - 0x73, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x5d, 0x0a, 0x2f, 0x63, 0x6f, 0x6d, - 0x2e, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x69, 0x63, 0x6b, 0x73, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, - 0x6f, 0x77, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x4d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x54, 0x72, 0x61, - 0x63, 0x6b, 0x69, 0x6e, 0x67, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0xbc, 0x02, 0x0a, - 0x0c, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x54, 0x72, 0x61, 0x63, 0x65, 0x73, 0x12, 0x29, 0x0a, - 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x09, 0x42, 0x04, 0xf8, 0x86, 0x19, 0x01, 0x52, 0x0c, 0x65, 0x78, 0x70, 0x65, - 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x30, 0x0a, 0x14, 0x6d, 0x61, 0x78, 0x5f, - 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x5f, 0x6d, 0x69, 0x6c, 0x6c, 0x69, 0x73, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x12, 0x6d, 0x61, 0x78, 0x54, 0x69, 0x6d, 0x65, 0x73, - 0x74, 0x61, 0x6d, 0x70, 0x4d, 0x69, 0x6c, 0x6c, 0x69, 0x73, 0x12, 0x1d, 0x0a, 0x0a, 0x6d, 0x61, - 0x78, 0x5f, 0x74, 0x72, 0x61, 0x63, 0x65, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x09, - 0x6d, 0x61, 0x78, 0x54, 0x72, 0x61, 0x63, 
0x65, 0x73, 0x12, 0x1f, 0x0a, 0x0b, 0x72, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, - 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x49, 0x64, 0x73, 0x1a, 0x31, 0x0a, 0x08, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x0e, 0x74, 0x72, 0x61, 0x63, 0x65, 0x73, - 0x5f, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0d, - 0x74, 0x72, 0x61, 0x63, 0x65, 0x73, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x64, 0x3a, 0x5c, 0xe2, - 0x3f, 0x59, 0x0a, 0x26, 0x63, 0x6f, 0x6d, 0x2e, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x69, 0x63, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x5d, 0x22, 0xd3, 0x02, 0x0a, 0x08, 0x45, 0x6e, + 0x64, 0x54, 0x72, 0x61, 0x63, 0x65, 0x12, 0x1d, 0x0a, 0x0a, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x72, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x49, 0x64, 0x12, 0x21, 0x0a, 0x0c, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, + 0x6d, 0x70, 0x5f, 0x6d, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x74, 0x69, 0x6d, + 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x4d, 0x73, 0x12, 0x2b, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, + 0x75, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x13, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, + 0x77, 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x06, 0x73, + 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x47, 0x0a, 0x10, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x5f, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, + 0x1c, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x52, 0x0f, 0x72, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x12, 0x24, + 0x0a, 0x04, 0x74, 0x61, 0x67, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 
0x10, 0x2e, 0x6d, + 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x54, 0x61, 0x67, 0x52, 0x04, + 0x74, 0x61, 0x67, 0x73, 0x1a, 0x3c, 0x0a, 0x08, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x12, 0x30, 0x0a, 0x0a, 0x74, 0x72, 0x61, 0x63, 0x65, 0x5f, 0x69, 0x6e, 0x66, 0x6f, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x54, 0x72, + 0x61, 0x63, 0x65, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x09, 0x74, 0x72, 0x61, 0x63, 0x65, 0x49, 0x6e, + 0x66, 0x6f, 0x3a, 0x2b, 0xe2, 0x3f, 0x28, 0x0a, 0x26, 0x63, 0x6f, 0x6d, 0x2e, 0x64, 0x61, 0x74, + 0x61, 0x62, 0x72, 0x69, 0x63, 0x6b, 0x73, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x50, 0x43, 0x5b, + 0x24, 0x74, 0x68, 0x69, 0x73, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x5d, 0x22, + 0x98, 0x01, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x54, 0x72, 0x61, 0x63, 0x65, 0x49, 0x6e, 0x66, 0x6f, + 0x12, 0x1d, 0x0a, 0x0a, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x49, 0x64, 0x1a, + 0x3c, 0x0a, 0x08, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x30, 0x0a, 0x0a, 0x74, + 0x72, 0x61, 0x63, 0x65, 0x5f, 0x69, 0x6e, 0x66, 0x6f, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x11, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x49, 0x6e, + 0x66, 0x6f, 0x52, 0x09, 0x74, 0x72, 0x61, 0x63, 0x65, 0x49, 0x6e, 0x66, 0x6f, 0x3a, 0x2b, 0xe2, + 0x3f, 0x28, 0x0a, 0x26, 0x63, 0x6f, 0x6d, 0x2e, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x69, 0x63, 0x6b, 0x73, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x50, 0x43, 0x5b, 0x24, 0x74, 0x68, 0x69, 0x73, - 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x5d, 0x0a, 0x2f, 0x63, 0x6f, 0x6d, 0x2e, - 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x69, 0x63, 0x6b, 0x73, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, - 0x77, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x4d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x54, 0x72, 0x61, 0x63, - 0x6b, 0x69, 0x6e, 
0x67, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0xbe, 0x01, 0x0a, 0x0b, - 0x53, 0x65, 0x74, 0x54, 0x72, 0x61, 0x63, 0x65, 0x54, 0x61, 0x67, 0x12, 0x1d, 0x0a, 0x0a, 0x72, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x09, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x49, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, - 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, - 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, - 0x75, 0x65, 0x1a, 0x0a, 0x0a, 0x08, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x3a, 0x5c, - 0xe2, 0x3f, 0x59, 0x0a, 0x26, 0x63, 0x6f, 0x6d, 0x2e, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x69, - 0x63, 0x6b, 0x73, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x50, 0x43, 0x5b, 0x24, 0x74, 0x68, 0x69, - 0x73, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x5d, 0x0a, 0x2f, 0x63, 0x6f, 0x6d, - 0x2e, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x69, 0x63, 0x6b, 0x73, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, - 0x6f, 0x77, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x4d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x54, 0x72, 0x61, - 0x63, 0x6b, 0x69, 0x6e, 0x67, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0xab, 0x01, 0x0a, - 0x0e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x54, 0x72, 0x61, 0x63, 0x65, 0x54, 0x61, 0x67, 0x12, - 0x1d, 0x0a, 0x0a, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x09, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x49, 0x64, 0x12, 0x10, - 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, - 0x1a, 0x0a, 0x0a, 0x08, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x3a, 0x5c, 0xe2, 0x3f, - 0x59, 0x0a, 0x26, 0x63, 0x6f, 0x6d, 0x2e, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x69, 0x63, 0x6b, - 0x73, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x50, 0x43, 0x5b, 0x24, 0x74, 0x68, 0x69, 0x73, 0x2e, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x5d, 
0x0a, 0x2f, 0x63, 0x6f, 0x6d, 0x2e, 0x64, - 0x61, 0x74, 0x61, 0x62, 0x72, 0x69, 0x63, 0x6b, 0x73, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, - 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x4d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x54, 0x72, 0x61, 0x63, 0x6b, - 0x69, 0x6e, 0x67, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x8d, 0x01, 0x0a, 0x0e, 0x44, - 0x61, 0x74, 0x61, 0x73, 0x65, 0x74, 0x53, 0x75, 0x6d, 0x6d, 0x61, 0x72, 0x79, 0x12, 0x29, 0x0a, - 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x09, 0x42, 0x04, 0xf8, 0x86, 0x19, 0x01, 0x52, 0x0c, 0x65, 0x78, 0x70, 0x65, - 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x18, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x42, 0x04, 0xf8, 0x86, 0x19, 0x01, 0x52, 0x04, 0x6e, 0x61, - 0x6d, 0x65, 0x12, 0x1c, 0x0a, 0x06, 0x64, 0x69, 0x67, 0x65, 0x73, 0x74, 0x18, 0x03, 0x20, 0x01, - 0x28, 0x09, 0x42, 0x04, 0xf8, 0x86, 0x19, 0x01, 0x52, 0x06, 0x64, 0x69, 0x67, 0x65, 0x73, 0x74, - 0x12, 0x18, 0x0a, 0x07, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x07, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x22, 0xe6, 0x01, 0x0a, 0x0e, 0x53, - 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x61, 0x74, 0x61, 0x73, 0x65, 0x74, 0x73, 0x12, 0x25, 0x0a, - 0x0e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x73, 0x18, - 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, - 0x74, 0x49, 0x64, 0x73, 0x1a, 0x4f, 0x0a, 0x08, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x12, 0x43, 0x0a, 0x11, 0x64, 0x61, 0x74, 0x61, 0x73, 0x65, 0x74, 0x5f, 0x73, 0x75, 0x6d, 0x6d, - 0x61, 0x72, 0x69, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x6c, - 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x44, 0x61, 0x74, 0x61, 0x73, 0x65, 0x74, 0x53, 0x75, 0x6d, 0x6d, - 0x61, 0x72, 0x79, 0x52, 0x10, 0x64, 0x61, 0x74, 0x61, 0x73, 0x65, 0x74, 0x53, 0x75, 0x6d, 
0x6d, - 0x61, 0x72, 0x69, 0x65, 0x73, 0x3a, 0x5c, 0xe2, 0x3f, 0x59, 0x0a, 0x26, 0x63, 0x6f, 0x6d, 0x2e, - 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x69, 0x63, 0x6b, 0x73, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, - 0x50, 0x43, 0x5b, 0x24, 0x74, 0x68, 0x69, 0x73, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, - 0x65, 0x5d, 0x0a, 0x2f, 0x63, 0x6f, 0x6d, 0x2e, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x69, 0x63, - 0x6b, 0x73, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x4d, 0x6c, - 0x66, 0x6c, 0x6f, 0x77, 0x54, 0x72, 0x61, 0x63, 0x6b, 0x69, 0x6e, 0x67, 0x4d, 0x65, 0x73, 0x73, - 0x61, 0x67, 0x65, 0x2a, 0x36, 0x0a, 0x08, 0x56, 0x69, 0x65, 0x77, 0x54, 0x79, 0x70, 0x65, 0x12, - 0x0f, 0x0a, 0x0b, 0x41, 0x43, 0x54, 0x49, 0x56, 0x45, 0x5f, 0x4f, 0x4e, 0x4c, 0x59, 0x10, 0x01, - 0x12, 0x10, 0x0a, 0x0c, 0x44, 0x45, 0x4c, 0x45, 0x54, 0x45, 0x44, 0x5f, 0x4f, 0x4e, 0x4c, 0x59, - 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c, 0x10, 0x03, 0x2a, 0x49, 0x0a, 0x0a, 0x53, - 0x6f, 0x75, 0x72, 0x63, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x0c, 0x0a, 0x08, 0x4e, 0x4f, 0x54, - 0x45, 0x42, 0x4f, 0x4f, 0x4b, 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x4a, 0x4f, 0x42, 0x10, 0x02, - 0x12, 0x0b, 0x0a, 0x07, 0x50, 0x52, 0x4f, 0x4a, 0x45, 0x43, 0x54, 0x10, 0x03, 0x12, 0x09, 0x0a, - 0x05, 0x4c, 0x4f, 0x43, 0x41, 0x4c, 0x10, 0x04, 0x12, 0x0c, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, - 0x4f, 0x57, 0x4e, 0x10, 0xe8, 0x07, 0x2a, 0x4d, 0x0a, 0x09, 0x52, 0x75, 0x6e, 0x53, 0x74, 0x61, - 0x74, 0x75, 0x73, 0x12, 0x0b, 0x0a, 0x07, 0x52, 0x55, 0x4e, 0x4e, 0x49, 0x4e, 0x47, 0x10, 0x01, - 0x12, 0x0d, 0x0a, 0x09, 0x53, 0x43, 0x48, 0x45, 0x44, 0x55, 0x4c, 0x45, 0x44, 0x10, 0x02, 0x12, - 0x0c, 0x0a, 0x08, 0x46, 0x49, 0x4e, 0x49, 0x53, 0x48, 0x45, 0x44, 0x10, 0x03, 0x12, 0x0a, 0x0a, - 0x06, 0x46, 0x41, 0x49, 0x4c, 0x45, 0x44, 0x10, 0x04, 0x12, 0x0a, 0x0a, 0x06, 0x4b, 0x49, 0x4c, - 0x4c, 0x45, 0x44, 0x10, 0x05, 0x2a, 0x4f, 0x0a, 0x0b, 0x54, 0x72, 0x61, 0x63, 0x65, 0x53, 0x74, - 0x61, 0x74, 0x75, 0x73, 0x12, 
0x1c, 0x0a, 0x18, 0x54, 0x52, 0x41, 0x43, 0x45, 0x5f, 0x53, 0x54, - 0x41, 0x54, 0x55, 0x53, 0x5f, 0x55, 0x4e, 0x53, 0x50, 0x45, 0x43, 0x49, 0x46, 0x49, 0x45, 0x44, - 0x10, 0x00, 0x12, 0x06, 0x0a, 0x02, 0x4f, 0x4b, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x45, 0x52, - 0x52, 0x4f, 0x52, 0x10, 0x02, 0x12, 0x0f, 0x0a, 0x0b, 0x49, 0x4e, 0x5f, 0x50, 0x52, 0x4f, 0x47, - 0x52, 0x45, 0x53, 0x53, 0x10, 0x03, 0x32, 0xe7, 0x21, 0x0a, 0x0d, 0x4d, 0x6c, 0x66, 0x6c, 0x6f, - 0x77, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0xa6, 0x01, 0x0a, 0x13, 0x67, 0x65, 0x74, - 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x42, 0x79, 0x4e, 0x61, 0x6d, 0x65, - 0x12, 0x1b, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x47, 0x65, 0x74, 0x45, 0x78, 0x70, - 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x42, 0x79, 0x4e, 0x61, 0x6d, 0x65, 0x1a, 0x24, 0x2e, - 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x47, 0x65, 0x74, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, - 0x6d, 0x65, 0x6e, 0x74, 0x42, 0x79, 0x4e, 0x61, 0x6d, 0x65, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x4c, 0xf2, 0x86, 0x19, 0x48, 0x0a, 0x2c, 0x0a, 0x03, 0x47, 0x45, 0x54, - 0x12, 0x1f, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, - 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x2f, 0x67, 0x65, 0x74, 0x2d, 0x62, 0x79, 0x2d, 0x6e, 0x61, 0x6d, - 0x65, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, 0x10, 0x01, 0x2a, 0x16, 0x47, 0x65, 0x74, 0x20, 0x45, - 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x20, 0x42, 0x79, 0x20, 0x4e, 0x61, 0x6d, - 0x65, 0x12, 0x94, 0x01, 0x0a, 0x10, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x45, 0x78, 0x70, 0x65, - 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x12, 0x18, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, - 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, - 0x1a, 0x21, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, - 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 
0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x43, 0xf2, 0x86, 0x19, 0x3f, 0x0a, 0x28, 0x0a, 0x04, 0x50, 0x4f, 0x53, - 0x54, 0x12, 0x1a, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x65, 0x78, 0x70, 0x65, 0x72, - 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x2f, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x1a, 0x04, 0x08, - 0x02, 0x10, 0x00, 0x10, 0x01, 0x2a, 0x11, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x20, 0x45, 0x78, - 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x12, 0xc1, 0x01, 0x0a, 0x11, 0x73, 0x65, 0x61, - 0x72, 0x63, 0x68, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x12, 0x19, + 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x5d, 0x22, 0xb9, 0x02, 0x0a, 0x0c, 0x53, + 0x65, 0x61, 0x72, 0x63, 0x68, 0x54, 0x72, 0x61, 0x63, 0x65, 0x73, 0x12, 0x25, 0x0a, 0x0e, 0x65, + 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x01, 0x20, + 0x03, 0x28, 0x09, 0x52, 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x49, + 0x64, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x06, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x12, 0x24, 0x0a, 0x0b, 0x6d, 0x61, + 0x78, 0x5f, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x3a, + 0x03, 0x31, 0x30, 0x30, 0x52, 0x0a, 0x6d, 0x61, 0x78, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, + 0x12, 0x19, 0x0a, 0x08, 0x6f, 0x72, 0x64, 0x65, 0x72, 0x5f, 0x62, 0x79, 0x18, 0x04, 0x20, 0x03, + 0x28, 0x09, 0x52, 0x07, 0x6f, 0x72, 0x64, 0x65, 0x72, 0x42, 0x79, 0x12, 0x1d, 0x0a, 0x0a, 0x70, + 0x61, 0x67, 0x65, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x09, 0x70, 0x61, 0x67, 0x65, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x1a, 0x5d, 0x0a, 0x08, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x29, 0x0a, 0x06, 0x74, 0x72, 0x61, 0x63, 0x65, 0x73, + 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, + 0x54, 
0x72, 0x61, 0x63, 0x65, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x06, 0x74, 0x72, 0x61, 0x63, 0x65, + 0x73, 0x12, 0x26, 0x0a, 0x0f, 0x6e, 0x65, 0x78, 0x74, 0x5f, 0x70, 0x61, 0x67, 0x65, 0x5f, 0x74, + 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x6e, 0x65, 0x78, 0x74, + 0x50, 0x61, 0x67, 0x65, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x3a, 0x2b, 0xe2, 0x3f, 0x28, 0x0a, 0x26, + 0x63, 0x6f, 0x6d, 0x2e, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x69, 0x63, 0x6b, 0x73, 0x2e, 0x72, + 0x70, 0x63, 0x2e, 0x52, 0x50, 0x43, 0x5b, 0x24, 0x74, 0x68, 0x69, 0x73, 0x2e, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x5d, 0x22, 0x8b, 0x02, 0x0a, 0x0c, 0x44, 0x65, 0x6c, 0x65, 0x74, + 0x65, 0x54, 0x72, 0x61, 0x63, 0x65, 0x73, 0x12, 0x29, 0x0a, 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, + 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x42, 0x04, + 0xf8, 0x86, 0x19, 0x01, 0x52, 0x0c, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, + 0x49, 0x64, 0x12, 0x30, 0x0a, 0x14, 0x6d, 0x61, 0x78, 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, + 0x61, 0x6d, 0x70, 0x5f, 0x6d, 0x69, 0x6c, 0x6c, 0x69, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, + 0x52, 0x12, 0x6d, 0x61, 0x78, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x4d, 0x69, + 0x6c, 0x6c, 0x69, 0x73, 0x12, 0x1d, 0x0a, 0x0a, 0x6d, 0x61, 0x78, 0x5f, 0x74, 0x72, 0x61, 0x63, + 0x65, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x09, 0x6d, 0x61, 0x78, 0x54, 0x72, 0x61, + 0x63, 0x65, 0x73, 0x12, 0x1f, 0x0a, 0x0b, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x69, + 0x64, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x49, 0x64, 0x73, 0x1a, 0x31, 0x0a, 0x08, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x12, 0x25, 0x0a, 0x0e, 0x74, 0x72, 0x61, 0x63, 0x65, 0x73, 0x5f, 0x64, 0x65, 0x6c, 0x65, 0x74, + 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0d, 0x74, 0x72, 0x61, 0x63, 0x65, 0x73, + 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x64, 
0x3a, 0x2b, 0xe2, 0x3f, 0x28, 0x0a, 0x26, 0x63, 0x6f, + 0x6d, 0x2e, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x69, 0x63, 0x6b, 0x73, 0x2e, 0x72, 0x70, 0x63, + 0x2e, 0x52, 0x50, 0x43, 0x5b, 0x24, 0x74, 0x68, 0x69, 0x73, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x5d, 0x22, 0x8d, 0x01, 0x0a, 0x0b, 0x53, 0x65, 0x74, 0x54, 0x72, 0x61, 0x63, + 0x65, 0x54, 0x61, 0x67, 0x12, 0x1d, 0x0a, 0x0a, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, + 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x49, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x1a, 0x0a, 0x0a, 0x08, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x3a, 0x2b, 0xe2, 0x3f, 0x28, 0x0a, 0x26, 0x63, 0x6f, + 0x6d, 0x2e, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x69, 0x63, 0x6b, 0x73, 0x2e, 0x72, 0x70, 0x63, + 0x2e, 0x52, 0x50, 0x43, 0x5b, 0x24, 0x74, 0x68, 0x69, 0x73, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x5d, 0x22, 0x7a, 0x0a, 0x0e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x54, 0x72, + 0x61, 0x63, 0x65, 0x54, 0x61, 0x67, 0x12, 0x1d, 0x0a, 0x0a, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x72, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x49, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x1a, 0x0a, 0x0a, 0x08, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x3a, 0x2b, 0xe2, 0x3f, 0x28, 0x0a, 0x26, 0x63, 0x6f, 0x6d, 0x2e, 0x64, 0x61, + 0x74, 0x61, 0x62, 0x72, 0x69, 0x63, 0x6b, 0x73, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x50, 0x43, + 0x5b, 0x24, 0x74, 0x68, 0x69, 0x73, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x5d, + 0x22, 0x8d, 0x01, 0x0a, 0x0e, 0x44, 0x61, 0x74, 0x61, 0x73, 0x65, 0x74, 0x53, 
0x75, 0x6d, 0x6d, + 0x61, 0x72, 0x79, 0x12, 0x29, 0x0a, 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, + 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x42, 0x04, 0xf8, 0x86, 0x19, 0x01, + 0x52, 0x0c, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x18, + 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x42, 0x04, 0xf8, 0x86, + 0x19, 0x01, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x1c, 0x0a, 0x06, 0x64, 0x69, 0x67, 0x65, + 0x73, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x42, 0x04, 0xf8, 0x86, 0x19, 0x01, 0x52, 0x06, + 0x64, 0x69, 0x67, 0x65, 0x73, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x78, + 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, + 0x22, 0xb5, 0x01, 0x0a, 0x0e, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x61, 0x74, 0x61, 0x73, + 0x65, 0x74, 0x73, 0x12, 0x25, 0x0a, 0x0e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, + 0x74, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0d, 0x65, 0x78, 0x70, + 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x73, 0x1a, 0x4f, 0x0a, 0x08, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x43, 0x0a, 0x11, 0x64, 0x61, 0x74, 0x61, 0x73, 0x65, + 0x74, 0x5f, 0x73, 0x75, 0x6d, 0x6d, 0x61, 0x72, 0x69, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, + 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x44, 0x61, 0x74, 0x61, 0x73, + 0x65, 0x74, 0x53, 0x75, 0x6d, 0x6d, 0x61, 0x72, 0x79, 0x52, 0x10, 0x64, 0x61, 0x74, 0x61, 0x73, + 0x65, 0x74, 0x53, 0x75, 0x6d, 0x6d, 0x61, 0x72, 0x69, 0x65, 0x73, 0x3a, 0x2b, 0xe2, 0x3f, 0x28, + 0x0a, 0x26, 0x63, 0x6f, 0x6d, 0x2e, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x69, 0x63, 0x6b, 0x73, + 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x50, 0x43, 0x5b, 0x24, 0x74, 0x68, 0x69, 0x73, 0x2e, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x5d, 0x2a, 0x36, 0x0a, 0x08, 0x56, 0x69, 0x65, 0x77, + 0x54, 0x79, 0x70, 
0x65, 0x12, 0x0f, 0x0a, 0x0b, 0x41, 0x43, 0x54, 0x49, 0x56, 0x45, 0x5f, 0x4f, + 0x4e, 0x4c, 0x59, 0x10, 0x01, 0x12, 0x10, 0x0a, 0x0c, 0x44, 0x45, 0x4c, 0x45, 0x54, 0x45, 0x44, + 0x5f, 0x4f, 0x4e, 0x4c, 0x59, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c, 0x10, 0x03, + 0x2a, 0x49, 0x0a, 0x0a, 0x53, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x0c, + 0x0a, 0x08, 0x4e, 0x4f, 0x54, 0x45, 0x42, 0x4f, 0x4f, 0x4b, 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, + 0x4a, 0x4f, 0x42, 0x10, 0x02, 0x12, 0x0b, 0x0a, 0x07, 0x50, 0x52, 0x4f, 0x4a, 0x45, 0x43, 0x54, + 0x10, 0x03, 0x12, 0x09, 0x0a, 0x05, 0x4c, 0x4f, 0x43, 0x41, 0x4c, 0x10, 0x04, 0x12, 0x0c, 0x0a, + 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0xe8, 0x07, 0x2a, 0x4d, 0x0a, 0x09, 0x52, + 0x75, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x0b, 0x0a, 0x07, 0x52, 0x55, 0x4e, 0x4e, + 0x49, 0x4e, 0x47, 0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x53, 0x43, 0x48, 0x45, 0x44, 0x55, 0x4c, + 0x45, 0x44, 0x10, 0x02, 0x12, 0x0c, 0x0a, 0x08, 0x46, 0x49, 0x4e, 0x49, 0x53, 0x48, 0x45, 0x44, + 0x10, 0x03, 0x12, 0x0a, 0x0a, 0x06, 0x46, 0x41, 0x49, 0x4c, 0x45, 0x44, 0x10, 0x04, 0x12, 0x0a, + 0x0a, 0x06, 0x4b, 0x49, 0x4c, 0x4c, 0x45, 0x44, 0x10, 0x05, 0x2a, 0x4f, 0x0a, 0x0b, 0x54, 0x72, + 0x61, 0x63, 0x65, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x1c, 0x0a, 0x18, 0x54, 0x52, 0x41, + 0x43, 0x45, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x55, 0x4e, 0x53, 0x50, 0x45, 0x43, + 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x00, 0x12, 0x06, 0x0a, 0x02, 0x4f, 0x4b, 0x10, 0x01, 0x12, + 0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x02, 0x12, 0x0f, 0x0a, 0x0b, 0x49, 0x4e, + 0x5f, 0x50, 0x52, 0x4f, 0x47, 0x52, 0x45, 0x53, 0x53, 0x10, 0x03, 0x32, 0xe7, 0x21, 0x0a, 0x0d, + 0x4d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0xa6, 0x01, + 0x0a, 0x13, 0x67, 0x65, 0x74, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x42, + 0x79, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x1b, 0x2e, 0x6d, 
0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x47, + 0x65, 0x74, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x42, 0x79, 0x4e, 0x61, + 0x6d, 0x65, 0x1a, 0x24, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x47, 0x65, 0x74, 0x45, + 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x42, 0x79, 0x4e, 0x61, 0x6d, 0x65, 0x2e, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x4c, 0xf2, 0x86, 0x19, 0x48, 0x0a, 0x2c, + 0x0a, 0x03, 0x47, 0x45, 0x54, 0x12, 0x1f, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x65, + 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x2f, 0x67, 0x65, 0x74, 0x2d, 0x62, + 0x79, 0x2d, 0x6e, 0x61, 0x6d, 0x65, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, 0x10, 0x01, 0x2a, 0x16, + 0x47, 0x65, 0x74, 0x20, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x20, 0x42, + 0x79, 0x20, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x94, 0x01, 0x0a, 0x10, 0x63, 0x72, 0x65, 0x61, 0x74, + 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x12, 0x18, 0x2e, 0x6d, 0x6c, + 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, + 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x1a, 0x21, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x43, + 0x72, 0x65, 0x61, 0x74, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x43, 0xf2, 0x86, 0x19, 0x3f, 0x0a, 0x28, + 0x0a, 0x04, 0x50, 0x4f, 0x53, 0x54, 0x12, 0x1a, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, + 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x2f, 0x63, 0x72, 0x65, 0x61, + 0x74, 0x65, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, 0x10, 0x01, 0x2a, 0x11, 0x43, 0x72, 0x65, 0x61, + 0x74, 0x65, 0x20, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x12, 0xc1, 0x01, + 0x0a, 0x11, 0x73, 0x65, 0x61, 0x72, 0x63, 0x68, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, + 0x6e, 0x74, 0x73, 0x12, 0x19, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x53, 0x65, 
0x61, + 0x72, 0x63, 0x68, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x1a, 0x22, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x45, 0x78, - 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x1a, 0x22, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, - 0x6f, 0x77, 0x2e, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, - 0x65, 0x6e, 0x74, 0x73, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x6d, 0xf2, - 0x86, 0x19, 0x69, 0x0a, 0x28, 0x0a, 0x04, 0x50, 0x4f, 0x53, 0x54, 0x12, 0x1a, 0x2f, 0x6d, 0x6c, - 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, - 0x2f, 0x73, 0x65, 0x61, 0x72, 0x63, 0x68, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, 0x0a, 0x27, 0x0a, - 0x03, 0x47, 0x45, 0x54, 0x12, 0x1a, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x65, 0x78, - 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x2f, 0x73, 0x65, 0x61, 0x72, 0x63, 0x68, - 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, 0x10, 0x01, 0x2a, 0x12, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, - 0x20, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x12, 0x88, 0x01, 0x0a, - 0x0d, 0x67, 0x65, 0x74, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x12, 0x15, - 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x47, 0x65, 0x74, 0x45, 0x78, 0x70, 0x65, 0x72, - 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x1a, 0x1e, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x47, - 0x65, 0x74, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x40, 0xf2, 0x86, 0x19, 0x38, 0x0a, 0x24, 0x0a, 0x03, 0x47, - 0x45, 0x54, 0x12, 0x17, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x65, 0x78, 0x70, 0x65, - 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x2f, 0x67, 0x65, 0x74, 0x1a, 0x04, 0x08, 0x02, 0x10, - 0x00, 0x10, 0x01, 0x2a, 0x0e, 0x47, 0x65, 0x74, 0x20, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, - 0x65, 0x6e, 0x74, 0xba, 0x8c, 
0x19, 0x00, 0x12, 0x94, 0x01, 0x0a, 0x10, 0x64, 0x65, 0x6c, 0x65, - 0x74, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x12, 0x18, 0x2e, 0x6d, - 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x45, 0x78, 0x70, 0x65, - 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x1a, 0x21, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, - 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x43, 0xf2, 0x86, 0x19, 0x3f, 0x0a, - 0x28, 0x0a, 0x04, 0x50, 0x4f, 0x53, 0x54, 0x12, 0x1a, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, - 0x2f, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x2f, 0x64, 0x65, 0x6c, - 0x65, 0x74, 0x65, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, 0x10, 0x01, 0x2a, 0x11, 0x44, 0x65, 0x6c, - 0x65, 0x74, 0x65, 0x20, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x12, 0x99, - 0x01, 0x0a, 0x11, 0x72, 0x65, 0x73, 0x74, 0x6f, 0x72, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, - 0x6d, 0x65, 0x6e, 0x74, 0x12, 0x19, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x52, 0x65, - 0x73, 0x74, 0x6f, 0x72, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x1a, - 0x22, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x52, 0x65, 0x73, 0x74, 0x6f, 0x72, 0x65, - 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x45, 0xf2, 0x86, 0x19, 0x41, 0x0a, 0x29, 0x0a, 0x04, 0x50, 0x4f, 0x53, - 0x54, 0x12, 0x1b, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x65, 0x78, 0x70, 0x65, 0x72, - 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x2f, 0x72, 0x65, 0x73, 0x74, 0x6f, 0x72, 0x65, 0x1a, 0x04, - 0x08, 0x02, 0x10, 0x00, 0x10, 0x01, 0x2a, 0x12, 0x52, 0x65, 0x73, 0x74, 0x6f, 0x72, 0x65, 0x20, - 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x12, 0x94, 0x01, 0x0a, 0x10, 0x75, - 0x70, 0x64, 0x61, 0x74, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 
0x6d, 0x65, 0x6e, 0x74, 0x12, - 0x18, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x45, - 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x1a, 0x21, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, - 0x6f, 0x77, 0x2e, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x43, 0xf2, 0x86, - 0x19, 0x3f, 0x0a, 0x28, 0x0a, 0x04, 0x50, 0x4f, 0x53, 0x54, 0x12, 0x1a, 0x2f, 0x6d, 0x6c, 0x66, - 0x6c, 0x6f, 0x77, 0x2f, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x2f, - 0x75, 0x70, 0x64, 0x61, 0x74, 0x65, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, 0x10, 0x01, 0x2a, 0x11, - 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x20, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, - 0x74, 0x12, 0x71, 0x0a, 0x09, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x52, 0x75, 0x6e, 0x12, 0x11, - 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x52, 0x75, - 0x6e, 0x1a, 0x1a, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, - 0x65, 0x52, 0x75, 0x6e, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x35, 0xf2, - 0x86, 0x19, 0x31, 0x0a, 0x21, 0x0a, 0x04, 0x50, 0x4f, 0x53, 0x54, 0x12, 0x13, 0x2f, 0x6d, 0x6c, - 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x72, 0x75, 0x6e, 0x73, 0x2f, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, - 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, 0x10, 0x01, 0x2a, 0x0a, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, - 0x20, 0x52, 0x75, 0x6e, 0x12, 0x71, 0x0a, 0x09, 0x75, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x75, - 0x6e, 0x12, 0x11, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x55, 0x70, 0x64, 0x61, 0x74, - 0x65, 0x52, 0x75, 0x6e, 0x1a, 0x1a, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x55, 0x70, - 0x64, 0x61, 0x74, 0x65, 0x52, 0x75, 0x6e, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x35, 0xf2, 0x86, 0x19, 0x31, 0x0a, 0x21, 0x0a, 0x04, 0x50, 0x4f, 0x53, 0x54, 0x12, 0x13, - 0x2f, 
0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x72, 0x75, 0x6e, 0x73, 0x2f, 0x75, 0x70, 0x64, - 0x61, 0x74, 0x65, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, 0x10, 0x01, 0x2a, 0x0a, 0x55, 0x70, 0x64, - 0x61, 0x74, 0x65, 0x20, 0x52, 0x75, 0x6e, 0x12, 0x71, 0x0a, 0x09, 0x64, 0x65, 0x6c, 0x65, 0x74, - 0x65, 0x52, 0x75, 0x6e, 0x12, 0x11, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x44, 0x65, - 0x6c, 0x65, 0x74, 0x65, 0x52, 0x75, 0x6e, 0x1a, 0x1a, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, - 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x52, 0x75, 0x6e, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x35, 0xf2, 0x86, 0x19, 0x31, 0x0a, 0x21, 0x0a, 0x04, 0x50, 0x4f, 0x53, - 0x54, 0x12, 0x13, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x72, 0x75, 0x6e, 0x73, 0x2f, - 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, 0x10, 0x01, 0x2a, 0x0a, - 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x20, 0x52, 0x75, 0x6e, 0x12, 0x76, 0x0a, 0x0a, 0x72, 0x65, - 0x73, 0x74, 0x6f, 0x72, 0x65, 0x52, 0x75, 0x6e, 0x12, 0x12, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, - 0x77, 0x2e, 0x52, 0x65, 0x73, 0x74, 0x6f, 0x72, 0x65, 0x52, 0x75, 0x6e, 0x1a, 0x1b, 0x2e, 0x6d, - 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x52, 0x65, 0x73, 0x74, 0x6f, 0x72, 0x65, 0x52, 0x75, 0x6e, - 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x37, 0xf2, 0x86, 0x19, 0x33, 0x0a, - 0x22, 0x0a, 0x04, 0x50, 0x4f, 0x53, 0x54, 0x12, 0x14, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, - 0x2f, 0x72, 0x75, 0x6e, 0x73, 0x2f, 0x72, 0x65, 0x73, 0x74, 0x6f, 0x72, 0x65, 0x1a, 0x04, 0x08, - 0x02, 0x10, 0x00, 0x10, 0x01, 0x2a, 0x0b, 0x52, 0x65, 0x73, 0x74, 0x6f, 0x72, 0x65, 0x20, 0x52, - 0x75, 0x6e, 0x12, 0x75, 0x0a, 0x09, 0x6c, 0x6f, 0x67, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x12, - 0x11, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x4c, 0x6f, 0x67, 0x4d, 0x65, 0x74, 0x72, - 0x69, 0x63, 0x1a, 0x1a, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x4c, 0x6f, 0x67, 0x4d, - 0x65, 0x74, 0x72, 0x69, 0x63, 0x2e, 0x52, 
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x39, - 0xf2, 0x86, 0x19, 0x35, 0x0a, 0x25, 0x0a, 0x04, 0x50, 0x4f, 0x53, 0x54, 0x12, 0x17, 0x2f, 0x6d, - 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x72, 0x75, 0x6e, 0x73, 0x2f, 0x6c, 0x6f, 0x67, 0x2d, 0x6d, - 0x65, 0x74, 0x72, 0x69, 0x63, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, 0x10, 0x01, 0x2a, 0x0a, 0x4c, - 0x6f, 0x67, 0x20, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x12, 0x74, 0x0a, 0x08, 0x6c, 0x6f, 0x67, - 0x50, 0x61, 0x72, 0x61, 0x6d, 0x12, 0x10, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x4c, - 0x6f, 0x67, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x1a, 0x19, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, - 0x2e, 0x4c, 0x6f, 0x67, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, - 0x73, 0x65, 0x22, 0x3b, 0xf2, 0x86, 0x19, 0x37, 0x0a, 0x28, 0x0a, 0x04, 0x50, 0x4f, 0x53, 0x54, - 0x12, 0x1a, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x72, 0x75, 0x6e, 0x73, 0x2f, 0x6c, - 0x6f, 0x67, 0x2d, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x1a, 0x04, 0x08, 0x02, - 0x10, 0x00, 0x10, 0x01, 0x2a, 0x09, 0x4c, 0x6f, 0x67, 0x20, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x12, - 0xa1, 0x01, 0x0a, 0x10, 0x73, 0x65, 0x74, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, - 0x74, 0x54, 0x61, 0x67, 0x12, 0x18, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x53, 0x65, - 0x74, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x54, 0x61, 0x67, 0x1a, 0x21, - 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x53, 0x65, 0x74, 0x45, 0x78, 0x70, 0x65, 0x72, - 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x54, 0x61, 0x67, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, - 0x65, 0x22, 0x50, 0xf2, 0x86, 0x19, 0x4c, 0x0a, 0x34, 0x0a, 0x04, 0x50, 0x4f, 0x53, 0x54, 0x12, - 0x26, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, - 0x65, 0x6e, 0x74, 0x73, 0x2f, 0x73, 0x65, 0x74, 0x2d, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, - 0x65, 0x6e, 0x74, 0x2d, 0x74, 0x61, 0x67, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, 
0x10, 0x01, 0x2a, - 0x12, 0x53, 0x65, 0x74, 0x20, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x20, - 0x54, 0x61, 0x67, 0x12, 0x66, 0x0a, 0x06, 0x73, 0x65, 0x74, 0x54, 0x61, 0x67, 0x12, 0x0e, 0x2e, - 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x53, 0x65, 0x74, 0x54, 0x61, 0x67, 0x1a, 0x17, 0x2e, - 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x53, 0x65, 0x74, 0x54, 0x61, 0x67, 0x2e, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x33, 0xf2, 0x86, 0x19, 0x2f, 0x0a, 0x22, 0x0a, 0x04, - 0x50, 0x4f, 0x53, 0x54, 0x12, 0x14, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x72, 0x75, - 0x6e, 0x73, 0x2f, 0x73, 0x65, 0x74, 0x2d, 0x74, 0x61, 0x67, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, - 0x10, 0x01, 0x2a, 0x07, 0x53, 0x65, 0x74, 0x20, 0x54, 0x61, 0x67, 0x12, 0x88, 0x01, 0x0a, 0x0b, - 0x73, 0x65, 0x74, 0x54, 0x72, 0x61, 0x63, 0x65, 0x54, 0x61, 0x67, 0x12, 0x13, 0x2e, 0x6d, 0x6c, - 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x53, 0x65, 0x74, 0x54, 0x72, 0x61, 0x63, 0x65, 0x54, 0x61, 0x67, - 0x1a, 0x1c, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x53, 0x65, 0x74, 0x54, 0x72, 0x61, - 0x63, 0x65, 0x54, 0x61, 0x67, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x46, - 0xf2, 0x86, 0x19, 0x42, 0x0a, 0x2f, 0x0a, 0x05, 0x50, 0x41, 0x54, 0x43, 0x48, 0x12, 0x20, 0x2f, - 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x74, 0x72, 0x61, 0x63, 0x65, 0x73, 0x2f, 0x7b, 0x72, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x69, 0x64, 0x7d, 0x2f, 0x74, 0x61, 0x67, 0x73, 0x1a, - 0x04, 0x08, 0x02, 0x10, 0x00, 0x10, 0x03, 0x2a, 0x0d, 0x53, 0x65, 0x74, 0x20, 0x54, 0x72, 0x61, - 0x63, 0x65, 0x20, 0x54, 0x61, 0x67, 0x12, 0x95, 0x01, 0x0a, 0x0e, 0x64, 0x65, 0x6c, 0x65, 0x74, - 0x65, 0x54, 0x72, 0x61, 0x63, 0x65, 0x54, 0x61, 0x67, 0x12, 0x16, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, - 0x6f, 0x77, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x54, 0x72, 0x61, 0x63, 0x65, 0x54, 0x61, - 0x67, 0x1a, 0x1f, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, - 0x65, 0x54, 0x72, 
0x61, 0x63, 0x65, 0x54, 0x61, 0x67, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, - 0x73, 0x65, 0x22, 0x4a, 0xf2, 0x86, 0x19, 0x46, 0x0a, 0x30, 0x0a, 0x06, 0x44, 0x45, 0x4c, 0x45, - 0x54, 0x45, 0x12, 0x20, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x74, 0x72, 0x61, 0x63, - 0x65, 0x73, 0x2f, 0x7b, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x69, 0x64, 0x7d, 0x2f, - 0x74, 0x61, 0x67, 0x73, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, 0x10, 0x03, 0x2a, 0x10, 0x44, 0x65, - 0x6c, 0x65, 0x74, 0x65, 0x20, 0x54, 0x72, 0x61, 0x63, 0x65, 0x20, 0x54, 0x61, 0x67, 0x12, 0x75, - 0x0a, 0x09, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x54, 0x61, 0x67, 0x12, 0x11, 0x2e, 0x6d, 0x6c, - 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x54, 0x61, 0x67, 0x1a, 0x1a, - 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x54, 0x61, - 0x67, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x39, 0xf2, 0x86, 0x19, 0x35, - 0x0a, 0x25, 0x0a, 0x04, 0x50, 0x4f, 0x53, 0x54, 0x12, 0x17, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, - 0x77, 0x2f, 0x72, 0x75, 0x6e, 0x73, 0x2f, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x2d, 0x74, 0x61, - 0x67, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, 0x10, 0x01, 0x2a, 0x0a, 0x44, 0x65, 0x6c, 0x65, 0x74, - 0x65, 0x20, 0x54, 0x61, 0x67, 0x12, 0x65, 0x0a, 0x06, 0x67, 0x65, 0x74, 0x52, 0x75, 0x6e, 0x12, - 0x0e, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x47, 0x65, 0x74, 0x52, 0x75, 0x6e, 0x1a, - 0x17, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x47, 0x65, 0x74, 0x52, 0x75, 0x6e, 0x2e, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x32, 0xf2, 0x86, 0x19, 0x2a, 0x0a, 0x1d, - 0x0a, 0x03, 0x47, 0x45, 0x54, 0x12, 0x10, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x72, - 0x75, 0x6e, 0x73, 0x2f, 0x67, 0x65, 0x74, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, 0x10, 0x01, 0x2a, - 0x07, 0x47, 0x65, 0x74, 0x20, 0x52, 0x75, 0x6e, 0xba, 0x8c, 0x19, 0x00, 0x12, 0x79, 0x0a, 0x0a, - 0x73, 0x65, 0x61, 0x72, 0x63, 0x68, 0x52, 0x75, 0x6e, 
0x73, 0x12, 0x12, 0x2e, 0x6d, 0x6c, 0x66, - 0x6c, 0x6f, 0x77, 0x2e, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x52, 0x75, 0x6e, 0x73, 0x1a, 0x1b, - 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x52, 0x75, - 0x6e, 0x73, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x3a, 0xf2, 0x86, 0x19, - 0x32, 0x0a, 0x21, 0x0a, 0x04, 0x50, 0x4f, 0x53, 0x54, 0x12, 0x13, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, - 0x6f, 0x77, 0x2f, 0x72, 0x75, 0x6e, 0x73, 0x2f, 0x73, 0x65, 0x61, 0x72, 0x63, 0x68, 0x1a, 0x04, - 0x08, 0x02, 0x10, 0x00, 0x10, 0x01, 0x2a, 0x0b, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x20, 0x52, - 0x75, 0x6e, 0x73, 0xba, 0x8c, 0x19, 0x00, 0x12, 0x87, 0x01, 0x0a, 0x0d, 0x6c, 0x69, 0x73, 0x74, - 0x41, 0x72, 0x74, 0x69, 0x66, 0x61, 0x63, 0x74, 0x73, 0x12, 0x15, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, - 0x6f, 0x77, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x41, 0x72, 0x74, 0x69, 0x66, 0x61, 0x63, 0x74, 0x73, - 0x1a, 0x1e, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x41, 0x72, - 0x74, 0x69, 0x66, 0x61, 0x63, 0x74, 0x73, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x3f, 0xf2, 0x86, 0x19, 0x37, 0x0a, 0x23, 0x0a, 0x03, 0x47, 0x45, 0x54, 0x12, 0x16, 0x2f, - 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x61, 0x72, 0x74, 0x69, 0x66, 0x61, 0x63, 0x74, 0x73, - 0x2f, 0x6c, 0x69, 0x73, 0x74, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, 0x10, 0x01, 0x2a, 0x0e, 0x4c, - 0x69, 0x73, 0x74, 0x20, 0x41, 0x72, 0x74, 0x69, 0x66, 0x61, 0x63, 0x74, 0x73, 0xba, 0x8c, 0x19, - 0x00, 0x12, 0x95, 0x01, 0x0a, 0x10, 0x67, 0x65, 0x74, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x48, - 0x69, 0x73, 0x74, 0x6f, 0x72, 0x79, 0x12, 0x18, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, - 0x47, 0x65, 0x74, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x48, 0x69, 0x73, 0x74, 0x6f, 0x72, 0x79, - 0x1a, 0x21, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x74, - 0x72, 0x69, 0x63, 0x48, 0x69, 0x73, 0x74, 0x6f, 0x72, 0x79, 0x2e, 0x52, 0x65, 0x73, 0x70, 
0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x44, 0xf2, 0x86, 0x19, 0x40, 0x0a, 0x28, 0x0a, 0x03, 0x47, 0x45, 0x54, - 0x12, 0x1b, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, - 0x73, 0x2f, 0x67, 0x65, 0x74, 0x2d, 0x68, 0x69, 0x73, 0x74, 0x6f, 0x72, 0x79, 0x1a, 0x04, 0x08, - 0x02, 0x10, 0x00, 0x10, 0x01, 0x2a, 0x12, 0x47, 0x65, 0x74, 0x20, 0x4d, 0x65, 0x74, 0x72, 0x69, - 0x63, 0x20, 0x48, 0x69, 0x73, 0x74, 0x6f, 0x72, 0x79, 0x12, 0xb7, 0x01, 0x0a, 0x1c, 0x67, 0x65, - 0x74, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x48, 0x69, 0x73, 0x74, 0x6f, 0x72, 0x79, 0x42, 0x75, - 0x6c, 0x6b, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x12, 0x24, 0x2e, 0x6d, 0x6c, 0x66, - 0x6c, 0x6f, 0x77, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x48, 0x69, 0x73, - 0x74, 0x6f, 0x72, 0x79, 0x42, 0x75, 0x6c, 0x6b, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, - 0x1a, 0x2d, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x74, - 0x72, 0x69, 0x63, 0x48, 0x69, 0x73, 0x74, 0x6f, 0x72, 0x79, 0x42, 0x75, 0x6c, 0x6b, 0x49, 0x6e, - 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, - 0x42, 0xf2, 0x86, 0x19, 0x3a, 0x0a, 0x36, 0x0a, 0x03, 0x47, 0x45, 0x54, 0x12, 0x29, 0x2f, 0x6d, - 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x2f, 0x67, 0x65, - 0x74, 0x2d, 0x68, 0x69, 0x73, 0x74, 0x6f, 0x72, 0x79, 0x2d, 0x62, 0x75, 0x6c, 0x6b, 0x2d, 0x69, - 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x0b, 0x10, 0x03, 0xba, - 0x8c, 0x19, 0x00, 0x12, 0x70, 0x0a, 0x08, 0x6c, 0x6f, 0x67, 0x42, 0x61, 0x74, 0x63, 0x68, 0x12, - 0x10, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x4c, 0x6f, 0x67, 0x42, 0x61, 0x74, 0x63, - 0x68, 0x1a, 0x19, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x4c, 0x6f, 0x67, 0x42, 0x61, - 0x74, 0x63, 0x68, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x37, 0xf2, 0x86, - 0x19, 0x33, 0x0a, 0x24, 0x0a, 
0x04, 0x50, 0x4f, 0x53, 0x54, 0x12, 0x16, 0x2f, 0x6d, 0x6c, 0x66, - 0x6c, 0x6f, 0x77, 0x2f, 0x72, 0x75, 0x6e, 0x73, 0x2f, 0x6c, 0x6f, 0x67, 0x2d, 0x62, 0x61, 0x74, - 0x63, 0x68, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, 0x10, 0x01, 0x2a, 0x09, 0x4c, 0x6f, 0x67, 0x20, - 0x42, 0x61, 0x74, 0x63, 0x68, 0x12, 0x70, 0x0a, 0x08, 0x6c, 0x6f, 0x67, 0x4d, 0x6f, 0x64, 0x65, - 0x6c, 0x12, 0x10, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x4c, 0x6f, 0x67, 0x4d, 0x6f, - 0x64, 0x65, 0x6c, 0x1a, 0x19, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x4c, 0x6f, 0x67, - 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x37, - 0xf2, 0x86, 0x19, 0x33, 0x0a, 0x24, 0x0a, 0x04, 0x50, 0x4f, 0x53, 0x54, 0x12, 0x16, 0x2f, 0x6d, - 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x72, 0x75, 0x6e, 0x73, 0x2f, 0x6c, 0x6f, 0x67, 0x2d, 0x6d, - 0x6f, 0x64, 0x65, 0x6c, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, 0x10, 0x01, 0x2a, 0x09, 0x4c, 0x6f, - 0x67, 0x20, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x12, 0x75, 0x0a, 0x09, 0x6c, 0x6f, 0x67, 0x49, 0x6e, - 0x70, 0x75, 0x74, 0x73, 0x12, 0x11, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x4c, 0x6f, - 0x67, 0x49, 0x6e, 0x70, 0x75, 0x74, 0x73, 0x1a, 0x1a, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, - 0x2e, 0x4c, 0x6f, 0x67, 0x49, 0x6e, 0x70, 0x75, 0x74, 0x73, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x22, 0x6d, 0xf2, 0x86, 0x19, 0x69, 0x0a, 0x28, 0x0a, 0x04, 0x50, 0x4f, 0x53, 0x54, + 0x12, 0x1a, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, + 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x2f, 0x73, 0x65, 0x61, 0x72, 0x63, 0x68, 0x1a, 0x04, 0x08, 0x02, + 0x10, 0x00, 0x0a, 0x27, 0x0a, 0x03, 0x47, 0x45, 0x54, 0x12, 0x1a, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, + 0x6f, 0x77, 0x2f, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x2f, 0x73, + 0x65, 0x61, 0x72, 0x63, 0x68, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, 
0x10, 0x01, 0x2a, 0x12, 0x53, + 0x65, 0x61, 0x72, 0x63, 0x68, 0x20, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, + 0x73, 0x12, 0x88, 0x01, 0x0a, 0x0d, 0x67, 0x65, 0x74, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, + 0x65, 0x6e, 0x74, 0x12, 0x15, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x47, 0x65, 0x74, + 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x1a, 0x1e, 0x2e, 0x6d, 0x6c, 0x66, + 0x6c, 0x6f, 0x77, 0x2e, 0x47, 0x65, 0x74, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x40, 0xf2, 0x86, 0x19, 0x38, + 0x0a, 0x24, 0x0a, 0x03, 0x47, 0x45, 0x54, 0x12, 0x17, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, + 0x2f, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x2f, 0x67, 0x65, 0x74, + 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, 0x10, 0x01, 0x2a, 0x0e, 0x47, 0x65, 0x74, 0x20, 0x45, 0x78, + 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0xba, 0x8c, 0x19, 0x00, 0x12, 0x94, 0x01, 0x0a, + 0x10, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, + 0x74, 0x12, 0x18, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, + 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x1a, 0x21, 0x2e, 0x6d, 0x6c, + 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, + 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x43, + 0xf2, 0x86, 0x19, 0x3f, 0x0a, 0x28, 0x0a, 0x04, 0x50, 0x4f, 0x53, 0x54, 0x12, 0x1a, 0x2f, 0x6d, + 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, + 0x73, 0x2f, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, 0x10, 0x01, + 0x2a, 0x11, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x20, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, + 0x65, 0x6e, 0x74, 0x12, 0x99, 0x01, 0x0a, 0x11, 0x72, 0x65, 0x73, 0x74, 0x6f, 0x72, 0x65, 0x45, + 0x78, 
0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x12, 0x19, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, + 0x6f, 0x77, 0x2e, 0x52, 0x65, 0x73, 0x74, 0x6f, 0x72, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, + 0x6d, 0x65, 0x6e, 0x74, 0x1a, 0x22, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x52, 0x65, + 0x73, 0x74, 0x6f, 0x72, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x45, 0xf2, 0x86, 0x19, 0x41, 0x0a, 0x29, + 0x0a, 0x04, 0x50, 0x4f, 0x53, 0x54, 0x12, 0x1b, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, + 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x2f, 0x72, 0x65, 0x73, 0x74, + 0x6f, 0x72, 0x65, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, 0x10, 0x01, 0x2a, 0x12, 0x52, 0x65, 0x73, + 0x74, 0x6f, 0x72, 0x65, 0x20, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x12, + 0x94, 0x01, 0x0a, 0x10, 0x75, 0x70, 0x64, 0x61, 0x74, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, + 0x6d, 0x65, 0x6e, 0x74, 0x12, 0x18, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x55, 0x70, + 0x64, 0x61, 0x74, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x1a, 0x21, + 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x45, 0x78, + 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x22, 0x43, 0xf2, 0x86, 0x19, 0x3f, 0x0a, 0x28, 0x0a, 0x04, 0x50, 0x4f, 0x53, 0x54, 0x12, + 0x1a, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, + 0x65, 0x6e, 0x74, 0x73, 0x2f, 0x75, 0x70, 0x64, 0x61, 0x74, 0x65, 0x1a, 0x04, 0x08, 0x02, 0x10, + 0x00, 0x10, 0x01, 0x2a, 0x11, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x20, 0x45, 0x78, 0x70, 0x65, + 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x12, 0x71, 0x0a, 0x09, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, + 0x52, 0x75, 0x6e, 0x12, 0x11, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x43, 0x72, 0x65, + 0x61, 0x74, 0x65, 0x52, 0x75, 0x6e, 0x1a, 
0x1a, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, + 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x52, 0x75, 0x6e, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x22, 0x35, 0xf2, 0x86, 0x19, 0x31, 0x0a, 0x21, 0x0a, 0x04, 0x50, 0x4f, 0x53, 0x54, + 0x12, 0x13, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x72, 0x75, 0x6e, 0x73, 0x2f, 0x63, + 0x72, 0x65, 0x61, 0x74, 0x65, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, 0x10, 0x01, 0x2a, 0x0a, 0x43, + 0x72, 0x65, 0x61, 0x74, 0x65, 0x20, 0x52, 0x75, 0x6e, 0x12, 0x71, 0x0a, 0x09, 0x75, 0x70, 0x64, + 0x61, 0x74, 0x65, 0x52, 0x75, 0x6e, 0x12, 0x11, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, + 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x75, 0x6e, 0x1a, 0x1a, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, + 0x6f, 0x77, 0x2e, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x75, 0x6e, 0x2e, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x35, 0xf2, 0x86, 0x19, 0x31, 0x0a, 0x21, 0x0a, 0x04, 0x50, + 0x4f, 0x53, 0x54, 0x12, 0x13, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x72, 0x75, 0x6e, + 0x73, 0x2f, 0x75, 0x70, 0x64, 0x61, 0x74, 0x65, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, 0x10, 0x01, + 0x2a, 0x0a, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x20, 0x52, 0x75, 0x6e, 0x12, 0x71, 0x0a, 0x09, + 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x52, 0x75, 0x6e, 0x12, 0x11, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, + 0x6f, 0x77, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x52, 0x75, 0x6e, 0x1a, 0x1a, 0x2e, 0x6d, + 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x52, 0x75, 0x6e, 0x2e, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x35, 0xf2, 0x86, 0x19, 0x31, 0x0a, 0x21, + 0x0a, 0x04, 0x50, 0x4f, 0x53, 0x54, 0x12, 0x13, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, + 0x72, 0x75, 0x6e, 0x73, 0x2f, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x1a, 0x04, 0x08, 0x02, 0x10, + 0x00, 0x10, 0x01, 0x2a, 0x0a, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x20, 0x52, 0x75, 0x6e, 0x12, + 0x76, 0x0a, 0x0a, 0x72, 0x65, 0x73, 0x74, 0x6f, 0x72, 0x65, 0x52, 0x75, 0x6e, 
0x12, 0x12, 0x2e, + 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x52, 0x65, 0x73, 0x74, 0x6f, 0x72, 0x65, 0x52, 0x75, + 0x6e, 0x1a, 0x1b, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x52, 0x65, 0x73, 0x74, 0x6f, + 0x72, 0x65, 0x52, 0x75, 0x6e, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x37, + 0xf2, 0x86, 0x19, 0x33, 0x0a, 0x22, 0x0a, 0x04, 0x50, 0x4f, 0x53, 0x54, 0x12, 0x14, 0x2f, 0x6d, + 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x72, 0x75, 0x6e, 0x73, 0x2f, 0x72, 0x65, 0x73, 0x74, 0x6f, + 0x72, 0x65, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, 0x10, 0x01, 0x2a, 0x0b, 0x52, 0x65, 0x73, 0x74, + 0x6f, 0x72, 0x65, 0x20, 0x52, 0x75, 0x6e, 0x12, 0x75, 0x0a, 0x09, 0x6c, 0x6f, 0x67, 0x4d, 0x65, + 0x74, 0x72, 0x69, 0x63, 0x12, 0x11, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x4c, 0x6f, + 0x67, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x1a, 0x1a, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, + 0x2e, 0x4c, 0x6f, 0x67, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x39, 0xf2, 0x86, 0x19, 0x35, 0x0a, 0x25, 0x0a, 0x04, 0x50, 0x4f, 0x53, 0x54, 0x12, 0x17, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x72, 0x75, 0x6e, 0x73, 0x2f, - 0x6c, 0x6f, 0x67, 0x2d, 0x69, 0x6e, 0x70, 0x75, 0x74, 0x73, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, - 0x10, 0x01, 0x2a, 0x0a, 0x4c, 0x6f, 0x67, 0x20, 0x49, 0x6e, 0x70, 0x75, 0x74, 0x73, 0x12, 0x87, - 0x01, 0x0a, 0x0e, 0x73, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x61, 0x74, 0x61, 0x73, 0x65, 0x74, - 0x73, 0x12, 0x16, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x53, 0x65, 0x61, 0x72, 0x63, - 0x68, 0x44, 0x61, 0x74, 0x61, 0x73, 0x65, 0x74, 0x73, 0x1a, 0x1f, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, - 0x6f, 0x77, 0x2e, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x61, 0x74, 0x61, 0x73, 0x65, 0x74, - 0x73, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x3c, 0xf2, 0x86, 0x19, 0x34, - 0x0a, 0x30, 0x0a, 0x04, 0x50, 0x4f, 0x53, 0x54, 0x12, 0x22, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, - 0x2f, 0x65, 0x78, 0x70, 
0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x2f, 0x73, 0x65, 0x61, - 0x72, 0x63, 0x68, 0x2d, 0x64, 0x61, 0x74, 0x61, 0x73, 0x65, 0x74, 0x73, 0x1a, 0x04, 0x08, 0x02, - 0x10, 0x00, 0x10, 0x03, 0xba, 0x8c, 0x19, 0x00, 0x12, 0x70, 0x0a, 0x0a, 0x73, 0x74, 0x61, 0x72, - 0x74, 0x54, 0x72, 0x61, 0x63, 0x65, 0x12, 0x12, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, - 0x53, 0x74, 0x61, 0x72, 0x74, 0x54, 0x72, 0x61, 0x63, 0x65, 0x1a, 0x1b, 0x2e, 0x6d, 0x6c, 0x66, - 0x6c, 0x6f, 0x77, 0x2e, 0x53, 0x74, 0x61, 0x72, 0x74, 0x54, 0x72, 0x61, 0x63, 0x65, 0x2e, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x31, 0xf2, 0x86, 0x19, 0x2d, 0x0a, 0x1c, 0x0a, - 0x04, 0x50, 0x4f, 0x53, 0x54, 0x12, 0x0e, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x74, - 0x72, 0x61, 0x63, 0x65, 0x73, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, 0x10, 0x03, 0x2a, 0x0b, 0x53, - 0x74, 0x61, 0x72, 0x74, 0x20, 0x54, 0x72, 0x61, 0x63, 0x65, 0x12, 0x76, 0x0a, 0x08, 0x65, 0x6e, - 0x64, 0x54, 0x72, 0x61, 0x63, 0x65, 0x12, 0x10, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, - 0x45, 0x6e, 0x64, 0x54, 0x72, 0x61, 0x63, 0x65, 0x1a, 0x19, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, - 0x77, 0x2e, 0x45, 0x6e, 0x64, 0x54, 0x72, 0x61, 0x63, 0x65, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x3d, 0xf2, 0x86, 0x19, 0x39, 0x0a, 0x2a, 0x0a, 0x05, 0x50, 0x41, 0x54, - 0x43, 0x48, 0x12, 0x1b, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x74, 0x72, 0x61, 0x63, - 0x65, 0x73, 0x2f, 0x7b, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x69, 0x64, 0x7d, 0x1a, - 0x04, 0x08, 0x02, 0x10, 0x00, 0x10, 0x03, 0x2a, 0x09, 0x45, 0x6e, 0x64, 0x20, 0x54, 0x72, 0x61, - 0x63, 0x65, 0x12, 0x89, 0x01, 0x0a, 0x0c, 0x67, 0x65, 0x74, 0x54, 0x72, 0x61, 0x63, 0x65, 0x49, - 0x6e, 0x66, 0x6f, 0x12, 0x14, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x47, 0x65, 0x74, - 0x54, 0x72, 0x61, 0x63, 0x65, 0x49, 0x6e, 0x66, 0x6f, 0x1a, 0x1d, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, - 0x6f, 0x77, 0x2e, 0x47, 0x65, 0x74, 0x54, 0x72, 0x61, 0x63, 
0x65, 0x49, 0x6e, 0x66, 0x6f, 0x2e, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x44, 0xf2, 0x86, 0x19, 0x40, 0x0a, 0x2d, - 0x0a, 0x03, 0x47, 0x45, 0x54, 0x12, 0x20, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x74, - 0x72, 0x61, 0x63, 0x65, 0x73, 0x2f, 0x7b, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x69, - 0x64, 0x7d, 0x2f, 0x69, 0x6e, 0x66, 0x6f, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, 0x10, 0x03, 0x2a, - 0x0d, 0x47, 0x65, 0x74, 0x20, 0x54, 0x72, 0x61, 0x63, 0x65, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x77, - 0x0a, 0x0c, 0x73, 0x65, 0x61, 0x72, 0x63, 0x68, 0x54, 0x72, 0x61, 0x63, 0x65, 0x73, 0x12, 0x14, - 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x54, 0x72, - 0x61, 0x63, 0x65, 0x73, 0x1a, 0x1d, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x53, 0x65, - 0x61, 0x72, 0x63, 0x68, 0x54, 0x72, 0x61, 0x63, 0x65, 0x73, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x32, 0xf2, 0x86, 0x19, 0x2e, 0x0a, 0x1b, 0x0a, 0x03, 0x47, 0x45, 0x54, - 0x12, 0x0e, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x74, 0x72, 0x61, 0x63, 0x65, 0x73, - 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, 0x10, 0x03, 0x2a, 0x0d, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, - 0x20, 0x54, 0x72, 0x61, 0x63, 0x65, 0x73, 0x12, 0x86, 0x01, 0x0a, 0x0c, 0x64, 0x65, 0x6c, 0x65, - 0x74, 0x65, 0x54, 0x72, 0x61, 0x63, 0x65, 0x73, 0x12, 0x14, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, - 0x77, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x54, 0x72, 0x61, 0x63, 0x65, 0x73, 0x1a, 0x1d, + 0x6c, 0x6f, 0x67, 0x2d, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, + 0x10, 0x01, 0x2a, 0x0a, 0x4c, 0x6f, 0x67, 0x20, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x12, 0x74, + 0x0a, 0x08, 0x6c, 0x6f, 0x67, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x12, 0x10, 0x2e, 0x6d, 0x6c, 0x66, + 0x6c, 0x6f, 0x77, 0x2e, 0x4c, 0x6f, 0x67, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x1a, 0x19, 0x2e, 0x6d, + 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x4c, 0x6f, 0x67, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x2e, 0x52, + 
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x3b, 0xf2, 0x86, 0x19, 0x37, 0x0a, 0x28, 0x0a, + 0x04, 0x50, 0x4f, 0x53, 0x54, 0x12, 0x1a, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x72, + 0x75, 0x6e, 0x73, 0x2f, 0x6c, 0x6f, 0x67, 0x2d, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, + 0x72, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, 0x10, 0x01, 0x2a, 0x09, 0x4c, 0x6f, 0x67, 0x20, 0x50, + 0x61, 0x72, 0x61, 0x6d, 0x12, 0xa1, 0x01, 0x0a, 0x10, 0x73, 0x65, 0x74, 0x45, 0x78, 0x70, 0x65, + 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x54, 0x61, 0x67, 0x12, 0x18, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, + 0x6f, 0x77, 0x2e, 0x53, 0x65, 0x74, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, + 0x54, 0x61, 0x67, 0x1a, 0x21, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x53, 0x65, 0x74, + 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x54, 0x61, 0x67, 0x2e, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x50, 0xf2, 0x86, 0x19, 0x4c, 0x0a, 0x34, 0x0a, 0x04, + 0x50, 0x4f, 0x53, 0x54, 0x12, 0x26, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x65, 0x78, + 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x2f, 0x73, 0x65, 0x74, 0x2d, 0x65, 0x78, + 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x2d, 0x74, 0x61, 0x67, 0x1a, 0x04, 0x08, 0x02, + 0x10, 0x00, 0x10, 0x01, 0x2a, 0x12, 0x53, 0x65, 0x74, 0x20, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, + 0x6d, 0x65, 0x6e, 0x74, 0x20, 0x54, 0x61, 0x67, 0x12, 0x66, 0x0a, 0x06, 0x73, 0x65, 0x74, 0x54, + 0x61, 0x67, 0x12, 0x0e, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x53, 0x65, 0x74, 0x54, + 0x61, 0x67, 0x1a, 0x17, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x53, 0x65, 0x74, 0x54, + 0x61, 0x67, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x33, 0xf2, 0x86, 0x19, + 0x2f, 0x0a, 0x22, 0x0a, 0x04, 0x50, 0x4f, 0x53, 0x54, 0x12, 0x14, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, + 0x6f, 0x77, 0x2f, 0x72, 0x75, 0x6e, 0x73, 0x2f, 0x73, 0x65, 0x74, 0x2d, 0x74, 0x61, 0x67, 0x1a, + 0x04, 0x08, 0x02, 0x10, 0x00, 0x10, 
0x01, 0x2a, 0x07, 0x53, 0x65, 0x74, 0x20, 0x54, 0x61, 0x67, + 0x12, 0x88, 0x01, 0x0a, 0x0b, 0x73, 0x65, 0x74, 0x54, 0x72, 0x61, 0x63, 0x65, 0x54, 0x61, 0x67, + 0x12, 0x13, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x53, 0x65, 0x74, 0x54, 0x72, 0x61, + 0x63, 0x65, 0x54, 0x61, 0x67, 0x1a, 0x1c, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x53, + 0x65, 0x74, 0x54, 0x72, 0x61, 0x63, 0x65, 0x54, 0x61, 0x67, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x22, 0x46, 0xf2, 0x86, 0x19, 0x42, 0x0a, 0x2f, 0x0a, 0x05, 0x50, 0x41, 0x54, + 0x43, 0x48, 0x12, 0x20, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x74, 0x72, 0x61, 0x63, + 0x65, 0x73, 0x2f, 0x7b, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x69, 0x64, 0x7d, 0x2f, + 0x74, 0x61, 0x67, 0x73, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, 0x10, 0x03, 0x2a, 0x0d, 0x53, 0x65, + 0x74, 0x20, 0x54, 0x72, 0x61, 0x63, 0x65, 0x20, 0x54, 0x61, 0x67, 0x12, 0x95, 0x01, 0x0a, 0x0e, + 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x54, 0x72, 0x61, 0x63, 0x65, 0x54, 0x61, 0x67, 0x12, 0x16, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x54, 0x72, - 0x61, 0x63, 0x65, 0x73, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x41, 0xf2, - 0x86, 0x19, 0x3d, 0x0a, 0x2a, 0x0a, 0x04, 0x50, 0x4f, 0x53, 0x54, 0x12, 0x1c, 0x2f, 0x6d, 0x6c, - 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x74, 0x72, 0x61, 0x63, 0x65, 0x73, 0x2f, 0x64, 0x65, 0x6c, 0x65, - 0x74, 0x65, 0x2d, 0x74, 0x72, 0x61, 0x63, 0x65, 0x73, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, 0x10, - 0x03, 0x2a, 0x0d, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x20, 0x54, 0x72, 0x61, 0x63, 0x65, 0x73, - 0x42, 0x1e, 0xe2, 0x3f, 0x02, 0x10, 0x01, 0x0a, 0x14, 0x6f, 0x72, 0x67, 0x2e, 0x6d, 0x6c, 0x66, - 0x6c, 0x6f, 0x77, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x90, 0x01, 0x01, + 0x61, 0x63, 0x65, 0x54, 0x61, 0x67, 0x1a, 0x1f, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, + 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x54, 0x72, 0x61, 0x63, 0x65, 0x54, 0x61, 
0x67, 0x2e, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x4a, 0xf2, 0x86, 0x19, 0x46, 0x0a, 0x30, 0x0a, + 0x06, 0x44, 0x45, 0x4c, 0x45, 0x54, 0x45, 0x12, 0x20, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, + 0x2f, 0x74, 0x72, 0x61, 0x63, 0x65, 0x73, 0x2f, 0x7b, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x5f, 0x69, 0x64, 0x7d, 0x2f, 0x74, 0x61, 0x67, 0x73, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, 0x10, + 0x03, 0x2a, 0x10, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x20, 0x54, 0x72, 0x61, 0x63, 0x65, 0x20, + 0x54, 0x61, 0x67, 0x12, 0x75, 0x0a, 0x09, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x54, 0x61, 0x67, + 0x12, 0x11, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, + 0x54, 0x61, 0x67, 0x1a, 0x1a, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x44, 0x65, 0x6c, + 0x65, 0x74, 0x65, 0x54, 0x61, 0x67, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, + 0x39, 0xf2, 0x86, 0x19, 0x35, 0x0a, 0x25, 0x0a, 0x04, 0x50, 0x4f, 0x53, 0x54, 0x12, 0x17, 0x2f, + 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x72, 0x75, 0x6e, 0x73, 0x2f, 0x64, 0x65, 0x6c, 0x65, + 0x74, 0x65, 0x2d, 0x74, 0x61, 0x67, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, 0x10, 0x01, 0x2a, 0x0a, + 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x20, 0x54, 0x61, 0x67, 0x12, 0x65, 0x0a, 0x06, 0x67, 0x65, + 0x74, 0x52, 0x75, 0x6e, 0x12, 0x0e, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x47, 0x65, + 0x74, 0x52, 0x75, 0x6e, 0x1a, 0x17, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x47, 0x65, + 0x74, 0x52, 0x75, 0x6e, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x32, 0xf2, + 0x86, 0x19, 0x2a, 0x0a, 0x1d, 0x0a, 0x03, 0x47, 0x45, 0x54, 0x12, 0x10, 0x2f, 0x6d, 0x6c, 0x66, + 0x6c, 0x6f, 0x77, 0x2f, 0x72, 0x75, 0x6e, 0x73, 0x2f, 0x67, 0x65, 0x74, 0x1a, 0x04, 0x08, 0x02, + 0x10, 0x00, 0x10, 0x01, 0x2a, 0x07, 0x47, 0x65, 0x74, 0x20, 0x52, 0x75, 0x6e, 0xba, 0x8c, 0x19, + 0x00, 0x12, 0x79, 0x0a, 0x0a, 0x73, 0x65, 0x61, 0x72, 0x63, 0x68, 0x52, 0x75, 0x6e, 0x73, 0x12, + 0x12, 0x2e, 0x6d, 
0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x52, + 0x75, 0x6e, 0x73, 0x1a, 0x1b, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x53, 0x65, 0x61, + 0x72, 0x63, 0x68, 0x52, 0x75, 0x6e, 0x73, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x22, 0x3a, 0xf2, 0x86, 0x19, 0x32, 0x0a, 0x21, 0x0a, 0x04, 0x50, 0x4f, 0x53, 0x54, 0x12, 0x13, + 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x72, 0x75, 0x6e, 0x73, 0x2f, 0x73, 0x65, 0x61, + 0x72, 0x63, 0x68, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, 0x10, 0x01, 0x2a, 0x0b, 0x53, 0x65, 0x61, + 0x72, 0x63, 0x68, 0x20, 0x52, 0x75, 0x6e, 0x73, 0xba, 0x8c, 0x19, 0x00, 0x12, 0x87, 0x01, 0x0a, + 0x0d, 0x6c, 0x69, 0x73, 0x74, 0x41, 0x72, 0x74, 0x69, 0x66, 0x61, 0x63, 0x74, 0x73, 0x12, 0x15, + 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x41, 0x72, 0x74, 0x69, + 0x66, 0x61, 0x63, 0x74, 0x73, 0x1a, 0x1e, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x4c, + 0x69, 0x73, 0x74, 0x41, 0x72, 0x74, 0x69, 0x66, 0x61, 0x63, 0x74, 0x73, 0x2e, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x3f, 0xf2, 0x86, 0x19, 0x37, 0x0a, 0x23, 0x0a, 0x03, 0x47, + 0x45, 0x54, 0x12, 0x16, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x61, 0x72, 0x74, 0x69, + 0x66, 0x61, 0x63, 0x74, 0x73, 0x2f, 0x6c, 0x69, 0x73, 0x74, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, + 0x10, 0x01, 0x2a, 0x0e, 0x4c, 0x69, 0x73, 0x74, 0x20, 0x41, 0x72, 0x74, 0x69, 0x66, 0x61, 0x63, + 0x74, 0x73, 0xba, 0x8c, 0x19, 0x00, 0x12, 0x95, 0x01, 0x0a, 0x10, 0x67, 0x65, 0x74, 0x4d, 0x65, + 0x74, 0x72, 0x69, 0x63, 0x48, 0x69, 0x73, 0x74, 0x6f, 0x72, 0x79, 0x12, 0x18, 0x2e, 0x6d, 0x6c, + 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x48, 0x69, + 0x73, 0x74, 0x6f, 0x72, 0x79, 0x1a, 0x21, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x47, + 0x65, 0x74, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x48, 0x69, 0x73, 0x74, 0x6f, 0x72, 0x79, 0x2e, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 
0x44, 0xf2, 0x86, 0x19, 0x40, 0x0a, 0x28, + 0x0a, 0x03, 0x47, 0x45, 0x54, 0x12, 0x1b, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x6d, + 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x2f, 0x67, 0x65, 0x74, 0x2d, 0x68, 0x69, 0x73, 0x74, 0x6f, + 0x72, 0x79, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, 0x10, 0x01, 0x2a, 0x12, 0x47, 0x65, 0x74, 0x20, + 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x20, 0x48, 0x69, 0x73, 0x74, 0x6f, 0x72, 0x79, 0x12, 0xb7, + 0x01, 0x0a, 0x1c, 0x67, 0x65, 0x74, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x48, 0x69, 0x73, 0x74, + 0x6f, 0x72, 0x79, 0x42, 0x75, 0x6c, 0x6b, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x12, + 0x24, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x74, 0x72, + 0x69, 0x63, 0x48, 0x69, 0x73, 0x74, 0x6f, 0x72, 0x79, 0x42, 0x75, 0x6c, 0x6b, 0x49, 0x6e, 0x74, + 0x65, 0x72, 0x76, 0x61, 0x6c, 0x1a, 0x2d, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x47, + 0x65, 0x74, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x48, 0x69, 0x73, 0x74, 0x6f, 0x72, 0x79, 0x42, + 0x75, 0x6c, 0x6b, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x2e, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x42, 0xf2, 0x86, 0x19, 0x3a, 0x0a, 0x36, 0x0a, 0x03, 0x47, 0x45, + 0x54, 0x12, 0x29, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x6d, 0x65, 0x74, 0x72, 0x69, + 0x63, 0x73, 0x2f, 0x67, 0x65, 0x74, 0x2d, 0x68, 0x69, 0x73, 0x74, 0x6f, 0x72, 0x79, 0x2d, 0x62, + 0x75, 0x6c, 0x6b, 0x2d, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x1a, 0x04, 0x08, 0x02, + 0x10, 0x0b, 0x10, 0x03, 0xba, 0x8c, 0x19, 0x00, 0x12, 0x70, 0x0a, 0x08, 0x6c, 0x6f, 0x67, 0x42, + 0x61, 0x74, 0x63, 0x68, 0x12, 0x10, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x4c, 0x6f, + 0x67, 0x42, 0x61, 0x74, 0x63, 0x68, 0x1a, 0x19, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, + 0x4c, 0x6f, 0x67, 0x42, 0x61, 0x74, 0x63, 0x68, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x22, 0x37, 0xf2, 0x86, 0x19, 0x33, 0x0a, 0x24, 0x0a, 0x04, 0x50, 0x4f, 0x53, 0x54, 
0x12, + 0x16, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x72, 0x75, 0x6e, 0x73, 0x2f, 0x6c, 0x6f, + 0x67, 0x2d, 0x62, 0x61, 0x74, 0x63, 0x68, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, 0x10, 0x01, 0x2a, + 0x09, 0x4c, 0x6f, 0x67, 0x20, 0x42, 0x61, 0x74, 0x63, 0x68, 0x12, 0x70, 0x0a, 0x08, 0x6c, 0x6f, + 0x67, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x12, 0x10, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, + 0x4c, 0x6f, 0x67, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x1a, 0x19, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, + 0x77, 0x2e, 0x4c, 0x6f, 0x67, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x22, 0x37, 0xf2, 0x86, 0x19, 0x33, 0x0a, 0x24, 0x0a, 0x04, 0x50, 0x4f, 0x53, + 0x54, 0x12, 0x16, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x72, 0x75, 0x6e, 0x73, 0x2f, + 0x6c, 0x6f, 0x67, 0x2d, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, 0x10, + 0x01, 0x2a, 0x09, 0x4c, 0x6f, 0x67, 0x20, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x12, 0x75, 0x0a, 0x09, + 0x6c, 0x6f, 0x67, 0x49, 0x6e, 0x70, 0x75, 0x74, 0x73, 0x12, 0x11, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, + 0x6f, 0x77, 0x2e, 0x4c, 0x6f, 0x67, 0x49, 0x6e, 0x70, 0x75, 0x74, 0x73, 0x1a, 0x1a, 0x2e, 0x6d, + 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x4c, 0x6f, 0x67, 0x49, 0x6e, 0x70, 0x75, 0x74, 0x73, 0x2e, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x39, 0xf2, 0x86, 0x19, 0x35, 0x0a, 0x25, + 0x0a, 0x04, 0x50, 0x4f, 0x53, 0x54, 0x12, 0x17, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, + 0x72, 0x75, 0x6e, 0x73, 0x2f, 0x6c, 0x6f, 0x67, 0x2d, 0x69, 0x6e, 0x70, 0x75, 0x74, 0x73, 0x1a, + 0x04, 0x08, 0x02, 0x10, 0x00, 0x10, 0x01, 0x2a, 0x0a, 0x4c, 0x6f, 0x67, 0x20, 0x49, 0x6e, 0x70, + 0x75, 0x74, 0x73, 0x12, 0x87, 0x01, 0x0a, 0x0e, 0x73, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x61, + 0x74, 0x61, 0x73, 0x65, 0x74, 0x73, 0x12, 0x16, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, + 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x61, 0x74, 0x61, 0x73, 0x65, 0x74, 0x73, 0x1a, 0x1f, + 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 
0x6f, 0x77, 0x2e, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x61, + 0x74, 0x61, 0x73, 0x65, 0x74, 0x73, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, + 0x3c, 0xf2, 0x86, 0x19, 0x34, 0x0a, 0x30, 0x0a, 0x04, 0x50, 0x4f, 0x53, 0x54, 0x12, 0x22, 0x6d, + 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, + 0x73, 0x2f, 0x73, 0x65, 0x61, 0x72, 0x63, 0x68, 0x2d, 0x64, 0x61, 0x74, 0x61, 0x73, 0x65, 0x74, + 0x73, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, 0x10, 0x03, 0xba, 0x8c, 0x19, 0x00, 0x12, 0x70, 0x0a, + 0x0a, 0x73, 0x74, 0x61, 0x72, 0x74, 0x54, 0x72, 0x61, 0x63, 0x65, 0x12, 0x12, 0x2e, 0x6d, 0x6c, + 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x53, 0x74, 0x61, 0x72, 0x74, 0x54, 0x72, 0x61, 0x63, 0x65, 0x1a, + 0x1b, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x53, 0x74, 0x61, 0x72, 0x74, 0x54, 0x72, + 0x61, 0x63, 0x65, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x31, 0xf2, 0x86, + 0x19, 0x2d, 0x0a, 0x1c, 0x0a, 0x04, 0x50, 0x4f, 0x53, 0x54, 0x12, 0x0e, 0x2f, 0x6d, 0x6c, 0x66, + 0x6c, 0x6f, 0x77, 0x2f, 0x74, 0x72, 0x61, 0x63, 0x65, 0x73, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, + 0x10, 0x03, 0x2a, 0x0b, 0x53, 0x74, 0x61, 0x72, 0x74, 0x20, 0x54, 0x72, 0x61, 0x63, 0x65, 0x12, + 0x76, 0x0a, 0x08, 0x65, 0x6e, 0x64, 0x54, 0x72, 0x61, 0x63, 0x65, 0x12, 0x10, 0x2e, 0x6d, 0x6c, + 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x45, 0x6e, 0x64, 0x54, 0x72, 0x61, 0x63, 0x65, 0x1a, 0x19, 0x2e, + 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x45, 0x6e, 0x64, 0x54, 0x72, 0x61, 0x63, 0x65, 0x2e, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x3d, 0xf2, 0x86, 0x19, 0x39, 0x0a, 0x2a, + 0x0a, 0x05, 0x50, 0x41, 0x54, 0x43, 0x48, 0x12, 0x1b, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, + 0x2f, 0x74, 0x72, 0x61, 0x63, 0x65, 0x73, 0x2f, 0x7b, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x5f, 0x69, 0x64, 0x7d, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, 0x10, 0x03, 0x2a, 0x09, 0x45, 0x6e, + 0x64, 0x20, 0x54, 0x72, 0x61, 0x63, 0x65, 0x12, 0x89, 0x01, 0x0a, 
0x0c, 0x67, 0x65, 0x74, 0x54, + 0x72, 0x61, 0x63, 0x65, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x14, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, + 0x77, 0x2e, 0x47, 0x65, 0x74, 0x54, 0x72, 0x61, 0x63, 0x65, 0x49, 0x6e, 0x66, 0x6f, 0x1a, 0x1d, + 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x47, 0x65, 0x74, 0x54, 0x72, 0x61, 0x63, 0x65, + 0x49, 0x6e, 0x66, 0x6f, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x44, 0xf2, + 0x86, 0x19, 0x40, 0x0a, 0x2d, 0x0a, 0x03, 0x47, 0x45, 0x54, 0x12, 0x20, 0x2f, 0x6d, 0x6c, 0x66, + 0x6c, 0x6f, 0x77, 0x2f, 0x74, 0x72, 0x61, 0x63, 0x65, 0x73, 0x2f, 0x7b, 0x72, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x5f, 0x69, 0x64, 0x7d, 0x2f, 0x69, 0x6e, 0x66, 0x6f, 0x1a, 0x04, 0x08, 0x02, + 0x10, 0x00, 0x10, 0x03, 0x2a, 0x0d, 0x47, 0x65, 0x74, 0x20, 0x54, 0x72, 0x61, 0x63, 0x65, 0x49, + 0x6e, 0x66, 0x6f, 0x12, 0x77, 0x0a, 0x0c, 0x73, 0x65, 0x61, 0x72, 0x63, 0x68, 0x54, 0x72, 0x61, + 0x63, 0x65, 0x73, 0x12, 0x14, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x53, 0x65, 0x61, + 0x72, 0x63, 0x68, 0x54, 0x72, 0x61, 0x63, 0x65, 0x73, 0x1a, 0x1d, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, + 0x6f, 0x77, 0x2e, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x54, 0x72, 0x61, 0x63, 0x65, 0x73, 0x2e, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x32, 0xf2, 0x86, 0x19, 0x2e, 0x0a, 0x1b, + 0x0a, 0x03, 0x47, 0x45, 0x54, 0x12, 0x0e, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x74, + 0x72, 0x61, 0x63, 0x65, 0x73, 0x1a, 0x04, 0x08, 0x02, 0x10, 0x00, 0x10, 0x03, 0x2a, 0x0d, 0x53, + 0x65, 0x61, 0x72, 0x63, 0x68, 0x20, 0x54, 0x72, 0x61, 0x63, 0x65, 0x73, 0x12, 0x86, 0x01, 0x0a, + 0x0c, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x54, 0x72, 0x61, 0x63, 0x65, 0x73, 0x12, 0x14, 0x2e, + 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x54, 0x72, 0x61, + 0x63, 0x65, 0x73, 0x1a, 0x1d, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x44, 0x65, 0x6c, + 0x65, 0x74, 0x65, 0x54, 0x72, 0x61, 0x63, 0x65, 0x73, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 
0x65, 0x22, 0x41, 0xf2, 0x86, 0x19, 0x3d, 0x0a, 0x2a, 0x0a, 0x04, 0x50, 0x4f, 0x53, 0x54, + 0x12, 0x1c, 0x2f, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x74, 0x72, 0x61, 0x63, 0x65, 0x73, + 0x2f, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x2d, 0x74, 0x72, 0x61, 0x63, 0x65, 0x73, 0x1a, 0x04, + 0x08, 0x02, 0x10, 0x00, 0x10, 0x03, 0x2a, 0x0d, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x20, 0x54, + 0x72, 0x61, 0x63, 0x65, 0x73, 0x42, 0x1e, 0xe2, 0x3f, 0x02, 0x10, 0x01, 0x0a, 0x14, 0x6f, 0x72, + 0x67, 0x2e, 0x6d, 0x6c, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x90, 0x01, 0x01, } var ( From 8eeaacb16fb80716dd133c62a431246da3495805 Mon Sep 17 00:00:00 2001 From: Juan Escalada Date: Fri, 11 Oct 2024 11:09:44 +0000 Subject: [PATCH 14/24] Fix protoc version issue Signed-off-by: Juan Escalada --- magefiles/tests.go | 6 +++--- pkg/protos/artifacts/mlflow_artifacts.pb.go | 2 +- pkg/protos/databricks.pb.go | 2 +- pkg/protos/databricks_artifacts.pb.go | 2 +- pkg/protos/internal.pb.go | 2 +- pkg/protos/model_registry.pb.go | 2 +- pkg/protos/scalapb/scalapb.pb.go | 2 +- pkg/protos/service.pb.go | 2 +- pkg/tracking/service/tags.go | 2 +- pkg/tracking/store/sql/runs_internal_test.go | 1 - pkg/tracking/store/sql/tags.go | 18 +++++++++--------- pkg/tracking/store/store.go | 2 +- 12 files changed, 21 insertions(+), 22 deletions(-) diff --git a/magefiles/tests.go b/magefiles/tests.go index 5f466fd..a66a639 100644 --- a/magefiles/tests.go +++ b/magefiles/tests.go @@ -51,16 +51,16 @@ func runPythonTests(testFiles []string, testName string) error { // Add testName filter if provided if testName != "" { - args = append(args, "-k", testName, "-v") + args = append(args, "-k", testName, "-vv") } else { - args = append(args, "-k", "not [file", "-v") + args = append(args, "-k", "not [file") } // Run the tests (currently just the server ones) if err := sh.RunWithV(map[string]string{ "MLFLOW_GO_LIBRARY_PATH": libpath, }, "pytest", args..., - // "-vv", + // "-vv", ); 
err != nil { return err } diff --git a/pkg/protos/artifacts/mlflow_artifacts.pb.go b/pkg/protos/artifacts/mlflow_artifacts.pb.go index ca00c6b..9633522 100644 --- a/pkg/protos/artifacts/mlflow_artifacts.pb.go +++ b/pkg/protos/artifacts/mlflow_artifacts.pb.go @@ -7,7 +7,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.33.0 -// protoc v3.21.12 +// protoc v5.26.0 // source: mlflow_artifacts.proto package artifacts diff --git a/pkg/protos/databricks.pb.go b/pkg/protos/databricks.pb.go index cb6a6b9..112dda9 100644 --- a/pkg/protos/databricks.pb.go +++ b/pkg/protos/databricks.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.33.0 -// protoc v3.21.12 +// protoc v5.26.0 // source: databricks.proto package protos diff --git a/pkg/protos/databricks_artifacts.pb.go b/pkg/protos/databricks_artifacts.pb.go index 26484ae..beddd2b 100644 --- a/pkg/protos/databricks_artifacts.pb.go +++ b/pkg/protos/databricks_artifacts.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.33.0 -// protoc v3.21.12 +// protoc v5.26.0 // source: databricks_artifacts.proto package protos diff --git a/pkg/protos/internal.pb.go b/pkg/protos/internal.pb.go index 38d9bc8..2a9d430 100644 --- a/pkg/protos/internal.pb.go +++ b/pkg/protos/internal.pb.go @@ -4,7 +4,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.33.0 -// protoc v3.21.12 +// protoc v5.26.0 // source: internal.proto package protos diff --git a/pkg/protos/model_registry.pb.go b/pkg/protos/model_registry.pb.go index dd6eb30..a6aec3f 100644 --- a/pkg/protos/model_registry.pb.go +++ b/pkg/protos/model_registry.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. 
// versions: // protoc-gen-go v1.33.0 -// protoc v3.21.12 +// protoc v5.26.0 // source: model_registry.proto package protos diff --git a/pkg/protos/scalapb/scalapb.pb.go b/pkg/protos/scalapb/scalapb.pb.go index 95aaa19..3b4f090 100644 --- a/pkg/protos/scalapb/scalapb.pb.go +++ b/pkg/protos/scalapb/scalapb.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.33.0 -// protoc v3.21.12 +// protoc v5.26.0 // source: scalapb/scalapb.proto package scalapb diff --git a/pkg/protos/service.pb.go b/pkg/protos/service.pb.go index 3390a31..fd61eff 100644 --- a/pkg/protos/service.pb.go +++ b/pkg/protos/service.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.33.0 -// protoc v3.21.12 +// protoc v5.26.0 // source: service.proto package protos diff --git a/pkg/tracking/service/tags.go b/pkg/tracking/service/tags.go index bb96ea2..6e22327 100644 --- a/pkg/tracking/service/tags.go +++ b/pkg/tracking/service/tags.go @@ -23,7 +23,7 @@ func (ts TrackingService) SetTag(ctx context.Context, input *protos.SetTag) (*pr if runID == "" { runID = input.GetRunUuid() } - + if err := ts.Store.SetTag(ctx, runID, input.GetKey(), input.GetValue()); err != nil { return nil, err } diff --git a/pkg/tracking/store/sql/runs_internal_test.go b/pkg/tracking/store/sql/runs_internal_test.go index f42583d..d20b437 100644 --- a/pkg/tracking/store/sql/runs_internal_test.go +++ b/pkg/tracking/store/sql/runs_internal_test.go @@ -1,4 +1,3 @@ -//nolint:ireturn package sql import ( diff --git a/pkg/tracking/store/sql/tags.go b/pkg/tracking/store/sql/tags.go index cdce536..9781111 100644 --- a/pkg/tracking/store/sql/tags.go +++ b/pkg/tracking/store/sql/tags.go @@ -151,7 +151,7 @@ func (s TrackingSQLStore) SetTag( logger := utils.GetLoggerFromContext(ctx) if runID == "" { - logger.Info("RunID cannot be empty") + logger.Infof("RunID cannot be empty") return contract.NewError( 
protos.ErrorCode_INVALID_PARAMETER_VALUE, "RunID cannot be empty", @@ -161,7 +161,7 @@ func (s TrackingSQLStore) SetTag( err := s.db.WithContext(ctx).Transaction(func(transaction *gorm.DB) error { contractError := checkRunIsActive(transaction, runID) if contractError != nil { - logger.Info("Run is not active") + logger.Infof("Run is not active") return contractError } @@ -170,7 +170,7 @@ func (s TrackingSQLStore) SetTag( result := transaction.Where("run_uuid = ?", runID).First(&run) if result.Error != nil { - logger.Info("Failed to query run for run_id %q", runID) + logger.Infof("Failed to query run for run_id %q", runID) return contract.NewErrorWith( protos.ErrorCode_INTERNAL_ERROR, fmt.Sprintf("Failed to query run for run_id %q", runID), @@ -185,7 +185,7 @@ func (s TrackingSQLStore) SetTag( endTimePtr = &run.EndTime.Int64 } - logger.Info("Updating run info for run_id %q", runID) + logger.Infof("Updating run info for run_id %q", runID) if err := s.UpdateRun(ctx, runID, runStatus, endTimePtr, value); err != nil { logger.Printf("Failed to update run info for run_id %q", runID) return contract.NewErrorWith( @@ -199,13 +199,13 @@ func (s TrackingSQLStore) SetTag( } // Logging tag update - logger.Info("Setting tag for run_id %q", runID) + logger.Infof("Setting tag for run_id %q", runID) var tag models.Tag result := transaction.Where("run_uuid = ? 
AND key = ?", runID, key).First(&tag) if result.Error != nil && !errors.Is(result.Error, gorm.ErrRecordNotFound) { - logger.Printf("Failed to query tag for run_id %q and key %q", runID, key) + logger.Infof("Failed to query tag for run_id %q and key %q", runID, key) return contract.NewErrorWith( protos.ErrorCode_INTERNAL_ERROR, fmt.Sprintf("Failed to query tag for run_id %q and key %q", runID, key), @@ -216,7 +216,7 @@ func (s TrackingSQLStore) SetTag( if result.RowsAffected == 1 { tag.Value = value if err := transaction.Save(&tag).Error; err != nil { - logger.Printf("Failed to update tag for run_id %q and key %q", runID, key) + logger.Infof("Failed to update tag for run_id %q and key %q", runID, key) return contract.NewErrorWith( protos.ErrorCode_INTERNAL_ERROR, fmt.Sprintf("Failed to update tag for run_id %q and key %q", runID, key), @@ -230,7 +230,7 @@ func (s TrackingSQLStore) SetTag( Value: value, } if err := transaction.Create(&newTag).Error; err != nil { - logger.Printf("Failed to create tag for run_id %q and key %q", runID, key) + logger.Infof("Failed to create tag for run_id %q and key %q", runID, key) return contract.NewErrorWith( protos.ErrorCode_INTERNAL_ERROR, fmt.Sprintf("Failed to create tag for run_id %q and key %q", runID, key), @@ -242,7 +242,7 @@ func (s TrackingSQLStore) SetTag( return nil }) if err != nil { - logger.Printf("SetTag transaction failed for run_id %q", runID) + logger.Infof("SetTag transaction failed for run_id %q", runID) var contractError *contract.Error if errors.As(err, &contractError) { return contractError diff --git a/pkg/tracking/store/store.go b/pkg/tracking/store/store.go index 40d0f01..45c8714 100644 --- a/pkg/tracking/store/store.go +++ b/pkg/tracking/store/store.go @@ -37,7 +37,7 @@ type ( RestoreRun(ctx context.Context, runID string) *contract.Error GetRunTag(ctx context.Context, runID, tagKey string) (*entities.RunTag, *contract.Error) DeleteTag(ctx context.Context, runID, key string) *contract.Error - SetTag(ctx 
context.Context, runID, key string, value string) *contract.Error + SetTag(ctx context.Context, runID, key, value string) *contract.Error } MetricTrackingStore interface { LogBatch( From 91660ed758b0c7dd6a24d855aa1c9d5fad185ab8 Mon Sep 17 00:00:00 2001 From: Juan Escalada Date: Mon, 14 Oct 2024 02:24:33 +0000 Subject: [PATCH 15/24] Fix test_update_run_name test Signed-off-by: Juan Escalada --- magefiles/generate/validations.go | 2 +- pkg/protos/service.pb.go | 2 +- pkg/tracking/store/sql/tags.go | 80 +++++++++++++------------------ 3 files changed, 35 insertions(+), 49 deletions(-) diff --git a/magefiles/generate/validations.go b/magefiles/generate/validations.go index a734601..6b31217 100644 --- a/magefiles/generate/validations.go +++ b/magefiles/generate/validations.go @@ -33,5 +33,5 @@ var validations = map[string]string{ "SetExperimentTag_Key": "required,max=250,validMetricParamOrTagName", "SetExperimentTag_Value": "max=5000", "SetTag_Key": "required,max=1000,validMetricParamOrTagName,pathIsUnique", - "SetTag_Value": "required,truncate=8000", + "SetTag_Value": "omitempty,truncate=8000", } diff --git a/pkg/protos/service.pb.go b/pkg/protos/service.pb.go index fd61eff..d7a98a5 100644 --- a/pkg/protos/service.pb.go +++ b/pkg/protos/service.pb.go @@ -2083,7 +2083,7 @@ type SetTag struct { Key *string `protobuf:"bytes,2,opt,name=key" json:"key,omitempty" query:"key" validate:"required,max=1000,validMetricParamOrTagName,pathIsUnique"` // String value of the tag being logged. Maximum size depends on storage backend. // All storage backends are guaranteed to support key values up to 5000 bytes in size. 
- Value *string `protobuf:"bytes,3,opt,name=value" json:"value,omitempty" query:"value" validate:"required,truncate=8000"` + Value *string `protobuf:"bytes,3,opt,name=value" json:"value,omitempty" query:"value" validate:"omitempty,truncate=8001"` } func (x *SetTag) Reset() { diff --git a/pkg/tracking/store/sql/tags.go b/pkg/tracking/store/sql/tags.go index 9781111..d92c11a 100644 --- a/pkg/tracking/store/sql/tags.go +++ b/pkg/tracking/store/sql/tags.go @@ -147,65 +147,34 @@ func (s TrackingSQLStore) DeleteTag( func (s TrackingSQLStore) SetTag( ctx context.Context, runID, key, value string, ) *contract.Error { - // Retrieve the logger from the context - logger := utils.GetLoggerFromContext(ctx) - if runID == "" { - logger.Infof("RunID cannot be empty") return contract.NewError( protos.ErrorCode_INVALID_PARAMETER_VALUE, "RunID cannot be empty", ) } - err := s.db.WithContext(ctx).Transaction(func(transaction *gorm.DB) error { + var run models.Run + err := s.db.Where("run_uuid = ?", runID).First(&run).Error + + if err != nil { + return contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf("Failed to query run for run_id %q", runID), + err, + ) + } + + err = s.db.WithContext(ctx).Transaction(func(transaction *gorm.DB) error { contractError := checkRunIsActive(transaction, runID) if contractError != nil { - logger.Infof("Run is not active") return contractError } - if key == utils.TagRunName { - var run models.Run - result := transaction.Where("run_uuid = ?", runID).First(&run) - - if result.Error != nil { - logger.Infof("Failed to query run for run_id %q", runID) - return contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - fmt.Sprintf("Failed to query run for run_id %q", runID), - result.Error, - ) - } - - runStatus := run.Status.String() - - var endTimePtr *int64 - if run.EndTime.Valid { - endTimePtr = &run.EndTime.Int64 - } - - logger.Infof("Updating run info for run_id %q", runID) - if err := s.UpdateRun(ctx, runID, runStatus, endTimePtr, 
value); err != nil { - logger.Printf("Failed to update run info for run_id %q", runID) - return contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - fmt.Sprintf("Failed to update run info for run_id %q", runID), - err, - ) - } - - return nil - } - - // Logging tag update - logger.Infof("Setting tag for run_id %q", runID) - var tag models.Tag result := transaction.Where("run_uuid = ? AND key = ?", runID, key).First(&tag) if result.Error != nil && !errors.Is(result.Error, gorm.ErrRecordNotFound) { - logger.Infof("Failed to query tag for run_id %q and key %q", runID, key) return contract.NewErrorWith( protos.ErrorCode_INTERNAL_ERROR, fmt.Sprintf("Failed to query tag for run_id %q and key %q", runID, key), @@ -216,7 +185,6 @@ func (s TrackingSQLStore) SetTag( if result.RowsAffected == 1 { tag.Value = value if err := transaction.Save(&tag).Error; err != nil { - logger.Infof("Failed to update tag for run_id %q and key %q", runID, key) return contract.NewErrorWith( protos.ErrorCode_INTERNAL_ERROR, fmt.Sprintf("Failed to update tag for run_id %q and key %q", runID, key), @@ -230,7 +198,6 @@ func (s TrackingSQLStore) SetTag( Value: value, } if err := transaction.Create(&newTag).Error; err != nil { - logger.Infof("Failed to create tag for run_id %q and key %q", runID, key) return contract.NewErrorWith( protos.ErrorCode_INTERNAL_ERROR, fmt.Sprintf("Failed to create tag for run_id %q and key %q", runID, key), @@ -241,8 +208,27 @@ func (s TrackingSQLStore) SetTag( return nil }) + + if key == utils.TagRunName { + runStatus := run.Status.String() + + var endTimePtr *int64 + if run.EndTime.Valid { + endTimePtr = &run.EndTime.Int64 + } + + if err := s.UpdateRun(ctx, runID, runStatus, endTimePtr, value); err != nil { + return contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf("Failed to update run info for run_id %q", runID), + err, + ) + } + + return nil + } + if err != nil { - logger.Infof("SetTag transaction failed for run_id %q", runID) var 
contractError *contract.Error if errors.As(err, &contractError) { return contractError @@ -250,7 +236,7 @@ func (s TrackingSQLStore) SetTag( return contract.NewErrorWith( protos.ErrorCode_INTERNAL_ERROR, - fmt.Sprintf("set tag transaction failed for %q", runID), + fmt.Sprintf("Set tag transaction failed for run_id %q", runID), err, ) } From f91e43b81b1a366df89c5ae5051830f307c67119 Mon Sep 17 00:00:00 2001 From: Juan Escalada Date: Mon, 14 Oct 2024 03:37:18 +0000 Subject: [PATCH 16/24] Add pythonSpecific to docs, normalize MLflow spelling Signed-off-by: Juan Escalada --- CONTRIBUTING.md | 13 +++++++++---- docs/porting-a-new-endpoint.md | 2 +- magefiles/configure.go | 2 +- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 4b5c1cf..d96147d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -8,7 +8,7 @@ to configure all the development environment just run `mage` target: mage configure ``` -it will configure mlflow and all the Python dependencies required by the project or run each step manually: +it will configure MLflow and all the Python dependencies required by the project or run each step manually: ```bash # Install our Python package and its dependencies @@ -17,7 +17,7 @@ pip install -e . 
# Install the dreaded psycho pip install psycopg2-binary -# Archive the MLFlow pre-built UI +# Archive the MLflow pre-built UI tar -C /usr/local/python/current/lib/python3.8/site-packages/mlflow -czvf ./ui.tgz ./server/js/build # Clone the MLflow repo @@ -30,7 +30,7 @@ tar -C .mlflow.repo/mlflow -xzvf ./ui.tgz pip install -e .mlflow.repo ``` -## Run Go Mlflow server +## Run Go MLflow server To start the mlflow-go dev server connecting to postgres just run next `mage` target: @@ -61,10 +61,15 @@ mage test:all ``` ```bash -# Run just MLFlow Python tests +# Run just MLflow Python tests mage test:python ``` +```bash +# Run specific MLflow Python tests (matches all tests containing the argument) +mage test:pythonSpecific +``` + ```bash # Run just unit tests mage test:unit diff --git a/docs/porting-a-new-endpoint.md b/docs/porting-a-new-endpoint.md index e887a44..1e14f60 100644 --- a/docs/porting-a-new-endpoint.md +++ b/docs/porting-a-new-endpoint.md @@ -219,7 +219,7 @@ An example use case where unit tests proved to be highly beneficial is the `filt ## Run Tests -Run `mage test:python` and verify that our Go implementation passes the existing tests. +Run `mage test:python` and verify that our Go implementation passes the existing tests. You can also use `mage test:pythonSpecific ` to run a specific set of tests. There is one caveat to these tests; occasionally, they may have a Python bias, meaning that the tests pass due to Python's dynamic nature, while our Go tests might fail because they are strongly typed. Another issue may arise if the Python implementation does not consistently return the same error messages. Therefore, it may be necessary to submit a PR to [mlflow](https://github.com/mlflow/mlflow) to adjust the existing tests. 
diff --git a/magefiles/configure.go b/magefiles/configure.go index ecfb4f6..a8b5b5a 100644 --- a/magefiles/configure.go +++ b/magefiles/configure.go @@ -35,7 +35,7 @@ func Configure() error { return err } - // Archive the MLFlow pre-built UI + // Archive the MLflow pre-built UI if err := tar( "-C", "/usr/local/python/current/lib/python3.8/site-packages/mlflow", "-czvf", From a3bd537dcb58fdc10420103ea43d82b2e7738ce4 Mon Sep 17 00:00:00 2001 From: Juan Escalada Date: Wed, 16 Oct 2024 00:02:17 +0000 Subject: [PATCH 17/24] Override test_set_tag execution Signed-off-by: Juan Escalada --- conftest.py | 5 +++++ tests/override_test_sqlalchemy_store.py | 4 ++++ 2 files changed, 9 insertions(+) diff --git a/conftest.py b/conftest.py index d613c12..c47b04f 100644 --- a/conftest.py +++ b/conftest.py @@ -44,6 +44,11 @@ def pytest_configure(config): "tests.store.tracking.test_sqlalchemy_store.test_log_param_max_length_value", "tests/override_test_sqlalchemy_store.py", ), + # This test uses monkeypatch.setenv which does not flow through to the Go side. + ( + "tests.store.tracking.test_sqlalchemy_store.test_set_tag", + "tests/override_test_sqlalchemy_store.py", + ), # This tests calls the store using invalid metric entity that cannot be converted # to its proto counterpart. 
# Example: entities.Metric("invalid_metric", None, (int(time.time() * 1000)), 0).to_proto() diff --git a/tests/override_test_sqlalchemy_store.py b/tests/override_test_sqlalchemy_store.py index 6bc5e5a..79e932a 100644 --- a/tests/override_test_sqlalchemy_store.py +++ b/tests/override_test_sqlalchemy_store.py @@ -13,6 +13,10 @@ def test_log_param_max_length_value(store: SqlAlchemyStore, monkeypatch): () +def test_set_tag(store: SqlAlchemyStore, monkeypatch): + () + + def test_log_batch_null_metrics(store: SqlAlchemyStore): () From 19404be432e51380ad3a36972021d50e56b7e171 Mon Sep 17 00:00:00 2001 From: Juan Escalada Date: Wed, 16 Oct 2024 07:33:53 +0000 Subject: [PATCH 18/24] Extract helper functions from SetTag Signed-off-by: Juan Escalada --- pkg/tracking/store/sql/tags.go | 119 +++++++++++++++++---------------- 1 file changed, 63 insertions(+), 56 deletions(-) diff --git a/pkg/tracking/store/sql/tags.go b/pkg/tracking/store/sql/tags.go index d92c11a..f13bd2c 100644 --- a/pkg/tracking/store/sql/tags.go +++ b/pkg/tracking/store/sql/tags.go @@ -147,13 +147,6 @@ func (s TrackingSQLStore) DeleteTag( func (s TrackingSQLStore) SetTag( ctx context.Context, runID, key, value string, ) *contract.Error { - if runID == "" { - return contract.NewError( - protos.ErrorCode_INVALID_PARAMETER_VALUE, - "RunID cannot be empty", - ) - } - var run models.Run err := s.db.Where("run_uuid = ?", runID).First(&run).Error @@ -171,72 +164,86 @@ func (s TrackingSQLStore) SetTag( return contractError } - var tag models.Tag - result := transaction.Where("run_uuid = ? 
AND key = ?", runID, key).First(&tag) - - if result.Error != nil && !errors.Is(result.Error, gorm.ErrRecordNotFound) { - return contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - fmt.Sprintf("Failed to query tag for run_id %q and key %q", runID, key), - result.Error, - ) - } + return s.handleTagUpsert(transaction, runID, key, value) + }) - if result.RowsAffected == 1 { - tag.Value = value - if err := transaction.Save(&tag).Error; err != nil { - return contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - fmt.Sprintf("Failed to update tag for run_id %q and key %q", runID, key), - err, - ) - } - } else { - newTag := models.Tag{ - RunID: runID, - Key: key, - Value: value, - } - if err := transaction.Create(&newTag).Error; err != nil { - return contract.NewErrorWith( - protos.ErrorCode_INTERNAL_ERROR, - fmt.Sprintf("Failed to create tag for run_id %q and key %q", runID, key), - err, - ) - } + if err != nil { + var contractError *contract.Error + if errors.As(err, &contractError) { + return contractError } - return nil - }) + return contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf("Set tag transaction failed for run_id %q", runID), + err, + ) + } if key == utils.TagRunName { - runStatus := run.Status.String() + return s.handleRunNameUpdate(ctx, run, runID, value) + } - var endTimePtr *int64 - if run.EndTime.Valid { - endTimePtr = &run.EndTime.Int64 - } + return nil +} + +// Handle tag creation and update. +func (s TrackingSQLStore) handleTagUpsert( + transaction *gorm.DB, runID, key, value string, +) error { + var tag models.Tag + result := transaction.Where("run_uuid = ? 
AND key = ?", runID, key).First(&tag) - if err := s.UpdateRun(ctx, runID, runStatus, endTimePtr, value); err != nil { + if result.Error != nil && !errors.Is(result.Error, gorm.ErrRecordNotFound) { + return contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf("Failed to query tag for run_id %q and key %q", runID, key), + result.Error, + ) + } + + if result.RowsAffected == 1 { + tag.Value = value + if err := transaction.Save(&tag).Error; err != nil { return contract.NewErrorWith( protos.ErrorCode_INTERNAL_ERROR, - fmt.Sprintf("Failed to update run info for run_id %q", runID), + fmt.Sprintf("Failed to update tag for run_id %q and key %q", runID, key), + err, + ) + } + } else { + newTag := models.Tag{ + RunID: runID, + Key: key, + Value: value, + } + if err := transaction.Create(&newTag).Error; err != nil { + return contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf("Failed to create tag for run_id %q and key %q", runID, key), err, ) } - - return nil } - if err != nil { - var contractError *contract.Error - if errors.As(err, &contractError) { - return contractError - } + return nil +} + +// Handles updating the run name when setting tag. 
+func (s TrackingSQLStore) handleRunNameUpdate( + ctx context.Context, run models.Run, runID, value string, +) *contract.Error { + runStatus := run.Status.String() + var endTimePtr *int64 + if run.EndTime.Valid { + endTimePtr = &run.EndTime.Int64 + } + + if err := s.UpdateRun(ctx, runID, runStatus, endTimePtr, value); err != nil { return contract.NewErrorWith( protos.ErrorCode_INTERNAL_ERROR, - fmt.Sprintf("Set tag transaction failed for run_id %q", runID), + fmt.Sprintf("Failed to update run info for run_id %q", runID), err, ) } From 87c2efef739eeac3117ca21e4c404f14d8cccfa0 Mon Sep 17 00:00:00 2001 From: Juan Escalada Date: Wed, 16 Oct 2024 08:15:24 +0000 Subject: [PATCH 19/24] Revert unnecessary changes Signed-off-by: Juan Escalada --- pkg/tracking/store/sql/runs_internal_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pkg/tracking/store/sql/runs_internal_test.go b/pkg/tracking/store/sql/runs_internal_test.go index d20b437..f42583d 100644 --- a/pkg/tracking/store/sql/runs_internal_test.go +++ b/pkg/tracking/store/sql/runs_internal_test.go @@ -1,3 +1,4 @@ +//nolint:ireturn package sql import ( From a7b73901faba744fb42c2af16499757db33fe118 Mon Sep 17 00:00:00 2001 From: Juan Escalada Date: Wed, 16 Oct 2024 08:20:27 +0000 Subject: [PATCH 20/24] Fix postCreate.sh Signed-off-by: Juan Escalada --- .devcontainer/postCreate.sh | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 .devcontainer/postCreate.sh diff --git a/.devcontainer/postCreate.sh b/.devcontainer/postCreate.sh old mode 100644 new mode 100755 From b27e4f6eb570ab839bb0665b59821194585ce326 Mon Sep 17 00:00:00 2001 From: Juan Escalada Date: Wed, 16 Oct 2024 09:09:05 +0000 Subject: [PATCH 21/24] Fix linter error Signed-off-by: Juan Escalada --- pkg/tracking/store/sql/tags.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pkg/tracking/store/sql/tags.go b/pkg/tracking/store/sql/tags.go index f13bd2c..530bc4e 100644 --- a/pkg/tracking/store/sql/tags.go 
+++ b/pkg/tracking/store/sql/tags.go @@ -148,8 +148,8 @@ func (s TrackingSQLStore) SetTag( ctx context.Context, runID, key, value string, ) *contract.Error { var run models.Run - err := s.db.Where("run_uuid = ?", runID).First(&run).Error + err := s.db.Where("run_uuid = ?", runID).First(&run).Error if err != nil { return contract.NewErrorWith( protos.ErrorCode_INTERNAL_ERROR, @@ -166,7 +166,6 @@ func (s TrackingSQLStore) SetTag( return s.handleTagUpsert(transaction, runID, key, value) }) - if err != nil { var contractError *contract.Error if errors.As(err, &contractError) { @@ -192,8 +191,8 @@ func (s TrackingSQLStore) handleTagUpsert( transaction *gorm.DB, runID, key, value string, ) error { var tag models.Tag - result := transaction.Where("run_uuid = ? AND key = ?", runID, key).First(&tag) + result := transaction.Where("run_uuid = ? AND key = ?", runID, key).First(&tag) if result.Error != nil && !errors.Is(result.Error, gorm.ErrRecordNotFound) { return contract.NewErrorWith( protos.ErrorCode_INTERNAL_ERROR, From fe585cc323a7f182ed0d2383e691ee7cb4fa5920 Mon Sep 17 00:00:00 2001 From: Juan Escalada Date: Wed, 16 Oct 2024 09:57:46 +0000 Subject: [PATCH 22/24] Fix format issue Signed-off-by: Juan Escalada --- mlflow_go/store/tracking.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mlflow_go/store/tracking.py b/mlflow_go/store/tracking.py index 09f0151..f1f4f5f 100644 --- a/mlflow_go/store/tracking.py +++ b/mlflow_go/store/tracking.py @@ -205,6 +205,7 @@ def set_tag(self, run_id, tag): request = SetTag(run_id=run_id, key=tag.key, value=tag.value) self.service.call_endpoint(get_lib().TrackingServiceSetTag, request) + def TrackingStore(cls): return type(cls.__name__, (_TrackingStore, cls), {}) From 2f753ad19792975b2f4b925ef76a5dd99eef7ec5 Mon Sep 17 00:00:00 2001 From: Juan Escalada Date: Sat, 19 Oct 2024 05:59:49 +0000 Subject: [PATCH 23/24] Simplify handleRunNameUpdate Signed-off-by: Juan Escalada --- pkg/tracking/store/sql/tags.go | 4 +--- 1 file changed, 1 
insertion(+), 3 deletions(-) diff --git a/pkg/tracking/store/sql/tags.go b/pkg/tracking/store/sql/tags.go index 530bc4e..72c7ab4 100644 --- a/pkg/tracking/store/sql/tags.go +++ b/pkg/tracking/store/sql/tags.go @@ -232,14 +232,12 @@ func (s TrackingSQLStore) handleTagUpsert( func (s TrackingSQLStore) handleRunNameUpdate( ctx context.Context, run models.Run, runID, value string, ) *contract.Error { - runStatus := run.Status.String() - var endTimePtr *int64 if run.EndTime.Valid { endTimePtr = &run.EndTime.Int64 } - if err := s.UpdateRun(ctx, runID, runStatus, endTimePtr, value); err != nil { + if err := s.UpdateRun(ctx, runID, run.Status.String(), endTimePtr, value); err != nil { return contract.NewErrorWith( protos.ErrorCode_INTERNAL_ERROR, fmt.Sprintf("Failed to update run info for run_id %q", runID), From 069212ca8fe263008248b17c6ca36523872c61c2 Mon Sep 17 00:00:00 2001 From: nojaf Date: Sat, 19 Oct 2024 09:39:58 +0000 Subject: [PATCH 24/24] Refactor PythonSpecific Signed-off-by: nojaf --- CONTRIBUTING.md | 5 ++++- magefiles/tests.go | 22 +++++++--------------- 2 files changed, 11 insertions(+), 16 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index a16f9c7..36f18ec 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -67,7 +67,10 @@ mage test:python ```bash # Run specific MLflow Python tests (matches all tests containing the argument) -mage test:pythonSpecific +mage test:pythonSpecific + +#Example +mage test:pythonSpecific ".mlflow.repo/tests/tracking/test_rest_tracking.py::test_rename_experiment" ``` ```bash diff --git a/magefiles/tests.go b/magefiles/tests.go index a66a639..d4d6629 100644 --- a/magefiles/tests.go +++ b/magefiles/tests.go @@ -28,7 +28,7 @@ func cleanUpMemoryFile() error { return nil } -func runPythonTests(testFiles []string, testName string) error { +func runPythonTests(pytestArgs []string) error { libpath, err := os.MkdirTemp("", "") if err != nil { return err @@ -46,15 +46,9 @@ func runPythonTests(testFiles []string, testName 
string) error { args := []string{ "--confcutdir=.", + "-k", "not [file", } - args = append(args, testFiles...) - - // Add testName filter if provided - if testName != "" { - args = append(args, "-k", testName, "-vv") - } else { - args = append(args, "-k", "not [file") - } + args = append(args, pytestArgs...) // Run the tests (currently just the server ones) if err := sh.RunWithV(map[string]string{ @@ -75,17 +69,15 @@ func (Test) Python() error { ".mlflow.repo/tests/tracking/test_model_registry.py", ".mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py", ".mlflow.repo/tests/store/model_registry/test_sqlalchemy_store.py", - }, "") + }) } // Run specific Python test against the Go backend. func (Test) PythonSpecific(testName string) error { return runPythonTests([]string{ - ".mlflow.repo/tests/tracking/test_rest_tracking.py", - ".mlflow.repo/tests/tracking/test_model_registry.py", - ".mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py", - ".mlflow.repo/tests/store/model_registry/test_sqlalchemy_store.py", - }, testName) + testName, + "-vv", + }) } // Run the Go unit tests.