From 3910426dcd4a17f4503d0e088f726a2383374ce1 Mon Sep 17 00:00:00 2001 From: Anda Zhou <83614683+azhou-determined@users.noreply.github.com> Date: Fri, 25 Oct 2024 18:22:19 -0700 Subject: [PATCH] feat: remove searcher context from harness and master [MD-498] (#10131) Co-authored-by: Ryan Co-authored-by: Guangqing Tang <40620519+gt2345@users.noreply.github.com> Co-authored-by: Michael Kardash --- .circleci/real_config.yml | 1 + docs/.redirects/redirects.json | 11 +- .../get-started/architecture/introduction.rst | 2 - docs/get-started/example-solutions/_index.rst | 15 - docs/get-started/webui-qs.rst | 2 - .../apis-howto/api-core-ug-basic.rst | 50 +- .../api-guides/apis-howto/api-core-ug.rst | 6 +- .../api-guides/apis-howto/api-keras-ug.rst | 200 +- .../api-guides/apis-howto/api-pytorch-ug.rst | 17 +- .../apis-howto/deepspeed/_index.rst | 4 - .../apis-howto/deepspeed/autotuning.rst | 312 -- docs/model-dev-guide/create-experiment.rst | 25 + docs/model-dev-guide/debug-models.rst | 10 +- .../dtrain/config-templates.rst | 4 - .../dtrain/reproducibility.rst | 14 +- .../hyperparameter/search-methods/_index.rst | 8 +- .../search-methods/hp-adaptive-asha.rst | 38 +- .../search-methods/hp-custom.rst | 154 - .../hyperparameter/search-methods/hp-grid.rst | 15 +- .../search-methods/hp-random.rst | 4 +- .../search-methods/hp-single.rst | 5 +- docs/model-dev-guide/profiling.rst | 6 +- docs/reference/_index.rst | 1 - docs/reference/custom-searcher-reference.rst | 89 - .../reference/experiment-config-reference.rst | 260 +- docs/reference/training/_index.rst | 1 + .../reference/training/api-core-reference.rst | 7 - docs/reference/training/api-det-reference.rst | 7 + .../training/api-keras-reference.rst | 23 + .../training/api-transformers-reference.rst | 11 + docs/release-notes.rst | 7 +- docs/release-notes/remove-custom-searcher.rst | 7 + docs/tools/tensorboard.rst | 19 +- .../transition-managed-determined.rst | 2 - docs/tutorials/pachyderm-cat-dog.rst | 2 - docs/tutorials/pytorch-mnist-tutorial.rst | 11 +- docs/tutorials/quickstart-mdldev.rst | 4 +- .../tutorials/viewing-epoch-based-metrics.rst | 24 +- e2e_tests/tests/cluster/test_slurm.py | 4 +- e2e_tests/tests/cluster/test_users.py | 1 - e2e_tests/tests/config.py | 14 +- e2e_tests/tests/experiment/__init__.py | 1 - e2e_tests/tests/experiment/experiment.py | 123 - e2e_tests/tests/experiment/noop.py | 12 +- e2e_tests/tests/experiment/test_core.py | 40 +- .../tests/experiment/test_custom_searcher.py | 498 --- .../test_custom_searcher_asha_2a.py | 91 - e2e_tests/tests/experiment/test_launch.py | 3 +- e2e_tests/tests/experiment/test_metrics.py | 12 +- e2e_tests/tests/experiment/test_noop.py | 1 - .../tests/experiment/test_pending_hpc.py | 7 +- e2e_tests/tests/experiment/test_tf_keras.py | 51 +- .../fixtures/core_api/11_generic_metrics.yaml | 1 - .../core_api/arbitrary_workload_order.yaml | 3 +- .../core_api/pytorch_profiler_sync.yaml | 1 - e2e_tests/tests/fixtures/core_api/sleep.yaml | 2 - e2e_tests/tests/fixtures/core_api/whoami.yaml | 1 - .../core_api_custom_searcher.py | 105 - .../custom_searcher/core_api_model.yaml | 8 - .../core_api_searcher_asha.yaml | 9 - .../core_api_searcher_random.yaml | 9 - .../fixtures/custom_searcher/model_coreapi.py | 84 - .../fixtures/custom_searcher/searchers.py | 645 ---- .../custom_searcher_exp/adaptive.yaml | 22 - .../fixtures/custom_searcher_exp/model_def.py | 228 -- .../fixtures/custom_searcher_exp/single.yaml | 24 - .../tests/fixtures/failures/bad-image.yaml | 4 +- .../fixtures/failures/bad-pbs-option.yaml | 2 - .../fixtures/failures/bad-slurm-option.yaml | 2 - .../failures/docker-login-failure.yaml | 2 - .../slurm-requested-node-not-available.yaml | 4 +- .../failures/unsupported-slurm-option.yaml | 2 - .../tests/fixtures/hpc/embedded-quotes.yaml | 4 +- .../fixtures/hpc/embedded-single-quote.yaml | 4 +- .../mnist_pytorch/adaptive_short.yaml | 6 +- .../mnist_pytorch/const-profiling.yaml | 4 +- .../mnist_pytorch/const-pytorch11.yaml | 4 +- .../distributed-stop-requested.yaml | 6 +- .../fixtures/mnist_pytorch/failable.yaml | 2 - .../mnist_pytorch/failable_model_def.py | 1 + .../tests/fixtures/mnist_pytorch/profiling.py | 16 +- .../tests/fixtures/mnist_pytorch/random.yaml | 8 +- .../mnist_pytorch/stop_requested_model_def.py | 2 +- e2e_tests/tests/fixtures/noop/train.py | 10 +- .../tests/fixtures/ports-proxy/config.yaml | 1 - e2e_tests/tests/nightly/test_distributed.py | 94 +- examples/Makefile | 5 +- .../iris_tf_keras/adaptive.yaml | 11 +- .../computer_vision/iris_tf_keras/const.yaml | 9 +- .../iris_tf_keras/distributed.yaml | 9 +- .../iris_tf_keras/model_def.py | 105 - .../computer_vision/iris_tf_keras/train.py | 143 + examples/deepspeed/dcgan/README.md | 49 + examples/deepspeed/dcgan/data.py | 104 + examples/deepspeed/dcgan/ds_config.json | 15 + examples/deepspeed/dcgan/gan_model.py | 73 + examples/deepspeed/dcgan/mnist.yaml | 33 + examples/deepspeed/dcgan/model.py | 208 ++ examples/deepspeed/dcgan/trainer.py | 38 + examples/deepspeed/gpt_neox/det_utils.py | 2 +- .../deepspeed_autotune/torchvision/README.md | 61 - .../torchvision/core_api/deepspeed.yaml | 19 - .../torchvision/core_api/ds_config.json | 29 - .../torchvision/core_api/script.py | 123 - .../deepspeed_trial/deepspeed.yaml | 22 - .../deepspeed_trial/ds_config.json | 29 - .../torchvision/deepspeed_trial/model_def.py | 89 - .../detsd/pipeline.py | 61 +- .../detsd/trainer.py | 65 +- .../finetune_const.yaml | 10 +- .../finetune_const_advanced.yaml | 10 +- .../generate_grid.yaml | 8 +- .../distributed_inference.yaml | 3 +- examples/features/ports/ray_launcher.yaml | 1 - .../core_api_config.yaml | 1 - .../torch_batch_process_config.yaml | 1 - .../distributed.yaml | 1 - examples/features/unmanaged/1.yaml | 2 - examples/features/unmanaged/2.yaml | 2 - examples/features/unmanaged/3.yaml | 2 - .../features/unmanaged/ray/ray_hp_search.py | 3 +- examples/hf_trainer_api/README.md | 12 - .../hf_image_classification/adaptive.yaml | 4 +- .../hf_image_classification/const.yaml | 2 - .../hf_image_classification/const_epochs.yaml | 2 - .../hf_image_classification/deepspeed.yaml | 2 - .../hf_image_classification/distributed.yaml | 2 - .../image_classification.py | 7 +- .../hf_image_classification/util.py | 140 + .../hf_language_modeling/adaptive.yaml | 4 +- .../hf_language_modeling/const.yaml | 2 - .../hf_language_modeling/const_epochs.yaml | 2 - .../hf_language_modeling/deepspeed.yaml | 2 - .../hf_language_modeling/distributed.yaml | 2 - .../hf_language_modeling/run_clm.py | 7 +- .../hf_language_modeling/util.py | 140 + examples/tutorials/core_api/0_start.yaml | 2 - examples/tutorials/core_api/1_metrics.yaml | 1 - examples/tutorials/core_api/2_checkpoints.py | 4 +- .../tutorials/core_api/2_checkpoints.yaml | 1 - examples/tutorials/core_api/3_hpsearch.py | 62 +- examples/tutorials/core_api/3_hpsearch.yaml | 3 +- examples/tutorials/core_api/4_distributed.py | 73 +- .../tutorials/core_api/4_distributed.yaml | 1 - .../core_api_pytorch_mnist/adaptive.yaml | 4 +- .../core_api_pytorch_mnist/checkpoints.yaml | 3 +- .../core_api_pytorch_mnist/const.yaml | 3 +- .../core_api_pytorch_mnist/distributed.yaml | 2 - .../core_api_pytorch_mnist/metrics.yaml | 3 +- .../model_def_adaptive.py | 68 +- .../model_def_checkpoints.py | 3 +- .../model_def_distributed.py | 56 +- .../model_def_metrics.py | 28 +- examples/tutorials/mnist_pytorch/README.md | 32 +- .../tutorials/mnist_pytorch/adaptive.yaml | 6 +- examples/tutorials/mnist_pytorch/const.yaml | 4 +- .../tutorials/mnist_pytorch/dist_random.yaml | 6 +- .../tutorials/mnist_pytorch/distributed.yaml | 6 +- examples/tutorials/mnist_pytorch/train.py | 32 +- harness/determined/_execution.py | 18 +- harness/determined/cli/cli.py | 86 +- harness/determined/cli/experiment.py | 20 + harness/determined/common/api/bindings.py | 980 +---- harness/determined/common/api/errors.py | 5 +- harness/determined/constants.py | 19 - harness/determined/core/__init__.py | 2 + harness/determined/core/_context.py | 24 +- harness/determined/core/_distributed.py | 63 +- harness/determined/core/_searcher.py | 206 +- harness/determined/core/_train.py | 3 - harness/determined/exec/harness.py | 90 +- .../experimental/core_v2/_core_context_v2.py | 22 +- .../experimental/core_v2/_core_v2.py | 1 - harness/determined/keras/__init__.py | 1 + harness/determined/keras/_callback.py | 465 +++ harness/determined/keras/_load.py | 18 + .../determined/keras/_tensorboard_callback.py | 7 +- harness/determined/keras/_tf_keras_trial.py | 26 + harness/determined/keras/callbacks.py | 64 +- harness/determined/launch/deepspeed.py | 13 +- harness/determined/launch/horovod.py | 14 +- harness/determined/launch/tensorflow.py | 102 + .../determined/launch/torch_distributed.py | 18 +- .../determined/layers/_workload_sequencer.py | 2 +- harness/determined/pytorch/__init__.py | 15 +- harness/determined/pytorch/_pytorch_trial.py | 363 +- harness/determined/pytorch/_trainer.py | 102 +- harness/determined/pytorch/_trainer_utils.py | 145 + .../determined/pytorch/deepspeed/__init__.py | 1 + .../pytorch/deepspeed/_deepspeed_context.py | 207 +- .../pytorch/deepspeed/_deepspeed_trial.py | 1031 ++++-- .../determined/pytorch/deepspeed/_trainer.py | 335 ++ harness/determined/pytorch/dsat/__init__.py | 27 - harness/determined/pytorch/dsat/__main__.py | 48 - .../pytorch/dsat/_dsat_search_method.py | 1432 -------- harness/determined/pytorch/dsat/_run_dsat.py | 91 - harness/determined/pytorch/dsat/_utils.py | 530 --- harness/determined/pytorch/dsat/defaults.py | 76 - harness/determined/searcher/__init__.py | 13 - .../searcher/_remote_search_runner.py | 106 - harness/determined/searcher/_search_method.py | 494 --- harness/determined/searcher/_search_runner.py | 375 -- .../determined/transformers/_hf_callback.py | 381 +- harness/tests/cli/test_cli.py | 115 + harness/tests/cli/util.py | 18 + harness/tests/core/test_searcher.py | 40 +- harness/tests/custom_search_mocks.py | 161 - .../fixtures/deepspeed_linear_model.py | 67 +- .../fixtures/pytorch_amp/apex_amp.yaml | 17 - .../pytorch_amp/apex_amp_distributed.yaml | 25 - .../fixtures/pytorch_amp/auto_amp.yaml | 17 - .../pytorch_amp/auto_amp_distributed.yaml | 25 - .../fixtures/pytorch_amp/manual_amp.yaml | 17 - .../pytorch_amp/manual_amp_distributed.yaml | 25 - .../integrations/test_deepspeed_trial.py | 657 +--- .../tests/experiment/keras}/__init__.py | 0 .../tests/experiment/keras/test_callback.py | 511 +++ .../experiment/keras/test_tf_keras_trial.py | 79 - harness/tests/experiment/keras/train.py | 68 + .../pytorch/test_deepspeed_autotuning.py | 2078 ----------- .../experiment/pytorch/test_pytorch_trial.py | 24 +- harness/tests/experiment/test_utils.py | 47 + .../experiment/transformers}/__init__.py | 0 .../experiment/transformers/test_callback.py | 495 +++ harness/tests/experiment/utils.py | 121 +- harness/tests/launch/test_tensorflow.py | 83 + harness/tests/search_methods.py | 555 --- harness/tests/test_custom_searcher.py | 46 - .../internal/api_config_policies_intg_test.go | 3 + master/internal/api_experiment.go | 130 +- master/internal/api_experiment_intg_test.go | 27 +- master/internal/api_logretention_intg_test.go | 7 +- master/internal/api_runs.go | 9 - master/internal/api_trials.go | 96 +- master/internal/api_trials_intg_test.go | 18 +- master/internal/core.go | 3 - master/internal/core_searcher.go | 49 - master/internal/db/postgres_experiments.go | 38 +- .../db/postgres_experiments_intg_test.go | 6 +- master/internal/db/postgres_snapshots_test.go | 66 - master/internal/db/postgres_test_utils.go | 9 +- master/internal/db/postgres_trial.go | 31 +- master/internal/experiment.go | 340 +- .../internal/experiment/authz_basic_impl.go | 7 - master/internal/experiment/authz_iface.go | 4 - .../internal/experiment/authz_permissive.go | 8 - master/internal/experiment/authz_rbac.go | 7 - .../internal/experiment/experiment_iface.go | 33 +- master/internal/restore.go | 153 +- master/internal/restore_test.go | 73 +- master/internal/telemetry/telemetry_test.go | 7 +- .../internal/templates/service_intg_test.go | 9 +- master/internal/trial.go | 35 +- master/internal/trial_intg_test.go | 51 +- master/internal/trials/postgres_trials.go | 21 +- master/pkg/model/searcher.go | 18 - master/pkg/model/test_utils.go | 7 +- .../pkg/schemas/expconf/experiment_config.go | 10 - master/pkg/schemas/expconf/latest.go | 1 - master/pkg/schemas/expconf/searcher_config.go | 174 +- .../expconf/zgen_adaptive_asha_config_v0.go | 46 +- .../expconf/zgen_adaptive_config_v0.go | 11 +- .../expconf/zgen_adaptive_simple_config_v0.go | 11 +- .../expconf/zgen_async_halving_config_v0.go | 46 +- .../schemas/expconf/zgen_grid_config_v0.go | 11 +- .../schemas/expconf/zgen_random_config_v0.go | 11 +- .../expconf/zgen_searcher_config_v0.go | 6 +- .../schemas/expconf/zgen_single_config_v0.go | 11 +- .../expconf/zgen_sync_halving_config_v0.go | 11 +- master/pkg/schemas/zgen_schemas.go | 52 +- master/pkg/searcher/actions.go | 73 + master/pkg/searcher/adaptive_asha.go | 62 +- master/pkg/searcher/adaptive_asha_test.go | 227 +- master/pkg/searcher/asha.go | 327 -- master/pkg/searcher/asha_stopping.go | 291 +- master/pkg/searcher/asha_stopping_test.go | 618 ++-- master/pkg/searcher/asha_test.go | 231 -- master/pkg/searcher/custom_search.go | 147 - master/pkg/searcher/custom_search_test.go | 198 - .../searcher/custom_searcher_events_queue.go | 158 - master/pkg/searcher/grid.go | 62 +- master/pkg/searcher/grid_test.go | 83 +- master/pkg/searcher/operations.go | 295 -- master/pkg/searcher/random.go | 48 +- master/pkg/searcher/random_test.go | 130 +- master/pkg/searcher/search_method.go | 62 +- master/pkg/searcher/searcher.go | 138 +- master/pkg/searcher/simulate.go | 285 +- master/pkg/searcher/simulate_test.go | 83 + master/pkg/searcher/tournament.go | 72 +- master/pkg/searcher/tournament_test.go | 152 +- master/pkg/searcher/util_test.go | 374 +- performance/daist/daist/metrics/base.py | 6 +- performance/daist/daist/metrics/config.yaml | 9 +- performance/daist/daist/metrics/model_def.py | 12 +- .../daist/daist/metrics/test_latency.py | 2 +- proto/buf.image.bin | Bin 673241 -> 665762 bytes proto/patches/api.json | 14 - proto/pkg/apiv1/api.pb.go | 3227 ++++++++--------- proto/pkg/apiv1/api.pb.gw.go | 424 --- proto/pkg/apiv1/experiment.pb.go | 1599 ++++---- proto/pkg/apiv1/trial.pb.go | 1229 +++---- proto/pkg/experimentv1/searcher.pb.go | 927 ++--- proto/src/determined/api/v1/api.proto | 45 - proto/src/determined/api/v1/experiment.proto | 29 +- proto/src/determined/api/v1/trial.proto | 31 - .../determined/experiment/v1/searcher.proto | 110 +- .../expconf/v0/searcher-adaptive-asha.json | 17 +- .../expconf/v0/searcher-adaptive-simple.json | 1 - schemas/expconf/v0/searcher-adaptive.json | 1 - .../expconf/v0/searcher-async-halving.json | 17 +- schemas/expconf/v0/searcher-custom.json | 1 + schemas/expconf/v0/searcher-grid.json | 1 - schemas/expconf/v0/searcher-random.json | 1 - schemas/expconf/v0/searcher-single.json | 1 - schemas/expconf/v0/searcher-sync-halving.json | 1 - schemas/expconf/v0/searcher.json | 11 +- schemas/test_cases/v0/defaults.yaml | 30 +- schemas/test_cases/v0/experiment.yaml | 10 +- schemas/test_cases/v0/merging.yaml | 4 - .../CompareHyperparameters.test.mock.tsx | 3 - .../components/ComparisonView.test.mock.tsx | 1 - .../ExperimentContinueModal.module.scss | 3 - .../components/ExperimentContinueModal.tsx | 142 +- .../src/components/ExperimentCreateModal.tsx | 59 +- .../components/HyperparameterSearchModal.tsx | 37 - .../non-scalar-metrics-4078.json | 3 - .../responses/experiment-details/set-a.json | 49 +- webui/react/src/pages/Dashboard.test.tsx | 3 - .../ExperimentDetails.test.mock.ts | 4 - .../ExperimentDetailsHeader.tsx | 10 +- .../pages/TrialDetails/TrialInfoBox.test.tsx | 1 - webui/react/src/services/api-ts-sdk/api.ts | 890 +---- webui/react/src/types.ts | 4 - webui/react/src/utils/experiment.test.ts | 8 - webui/react/src/utils/experiment.ts | 15 - .../react/src/utils/tests/generateTestData.ts | 2 - 337 files changed, 11242 insertions(+), 22501 deletions(-) delete mode 100644 docs/model-dev-guide/api-guides/apis-howto/deepspeed/autotuning.rst delete mode 100644 docs/model-dev-guide/hyperparameter/search-methods/hp-custom.rst delete mode 100644 docs/reference/custom-searcher-reference.rst create mode 100644 docs/reference/training/api-transformers-reference.rst create mode 100644 docs/release-notes/remove-custom-searcher.rst delete mode 100644 e2e_tests/tests/experiment/test_custom_searcher.py delete mode 100644 e2e_tests/tests/experiment/test_custom_searcher_asha_2a.py delete mode 100644 e2e_tests/tests/fixtures/custom_searcher/core_api_custom_searcher.py delete mode 100644 e2e_tests/tests/fixtures/custom_searcher/core_api_model.yaml delete mode 100644 e2e_tests/tests/fixtures/custom_searcher/core_api_searcher_asha.yaml delete mode 100644 e2e_tests/tests/fixtures/custom_searcher/core_api_searcher_random.yaml delete mode 100644 e2e_tests/tests/fixtures/custom_searcher/model_coreapi.py delete mode 100644 e2e_tests/tests/fixtures/custom_searcher/searchers.py delete mode 100644 e2e_tests/tests/fixtures/custom_searcher_exp/adaptive.yaml delete mode 100644 e2e_tests/tests/fixtures/custom_searcher_exp/model_def.py delete mode 100644 e2e_tests/tests/fixtures/custom_searcher_exp/single.yaml delete mode 100644 examples/computer_vision/iris_tf_keras/model_def.py create mode 100644 examples/computer_vision/iris_tf_keras/train.py create mode 100644 examples/deepspeed/dcgan/README.md create mode 100644 examples/deepspeed/dcgan/data.py create mode 100644 examples/deepspeed/dcgan/ds_config.json create mode 100644 examples/deepspeed/dcgan/gan_model.py create mode 100644 examples/deepspeed/dcgan/mnist.yaml create mode 100644 examples/deepspeed/dcgan/model.py create mode 100644 examples/deepspeed/dcgan/trainer.py delete mode 100644 examples/deepspeed_autotune/torchvision/README.md delete mode 100644 examples/deepspeed_autotune/torchvision/core_api/deepspeed.yaml delete mode 100644 examples/deepspeed_autotune/torchvision/core_api/ds_config.json delete mode 100644 examples/deepspeed_autotune/torchvision/core_api/script.py delete mode 100644 examples/deepspeed_autotune/torchvision/deepspeed_trial/deepspeed.yaml delete mode 100644 examples/deepspeed_autotune/torchvision/deepspeed_trial/ds_config.json delete mode 100644 examples/deepspeed_autotune/torchvision/deepspeed_trial/model_def.py create mode 100644 examples/hf_trainer_api/hf_image_classification/util.py create mode 100644 examples/hf_trainer_api/hf_language_modeling/util.py create mode 100644 harness/determined/keras/_callback.py create mode 100644 harness/determined/launch/tensorflow.py create mode 100644 harness/determined/pytorch/_trainer_utils.py create mode 100644 harness/determined/pytorch/deepspeed/_trainer.py delete mode 100644 harness/determined/pytorch/dsat/__init__.py delete mode 100644 harness/determined/pytorch/dsat/__main__.py delete mode 100644 harness/determined/pytorch/dsat/_dsat_search_method.py delete mode 100644 harness/determined/pytorch/dsat/_run_dsat.py delete mode 100644 harness/determined/pytorch/dsat/_utils.py delete mode 100644 harness/determined/pytorch/dsat/defaults.py delete mode 100644 harness/determined/searcher/__init__.py delete mode 100644 harness/determined/searcher/_remote_search_runner.py delete mode 100644 harness/determined/searcher/_search_method.py delete mode 100644 harness/determined/searcher/_search_runner.py delete mode 100644 harness/tests/custom_search_mocks.py delete mode 100644 harness/tests/experiment/fixtures/pytorch_amp/apex_amp.yaml delete mode 100644 harness/tests/experiment/fixtures/pytorch_amp/apex_amp_distributed.yaml delete mode 100644 harness/tests/experiment/fixtures/pytorch_amp/auto_amp.yaml delete mode 100644 harness/tests/experiment/fixtures/pytorch_amp/auto_amp_distributed.yaml delete mode 100644 harness/tests/experiment/fixtures/pytorch_amp/manual_amp.yaml delete mode 100644 harness/tests/experiment/fixtures/pytorch_amp/manual_amp_distributed.yaml rename {e2e_tests/tests/fixtures/custom_searcher => harness/tests/experiment/keras}/__init__.py (100%) create mode 100644 harness/tests/experiment/keras/test_callback.py create mode 100644 harness/tests/experiment/keras/train.py delete mode 100644 harness/tests/experiment/pytorch/test_deepspeed_autotuning.py create mode 100644 harness/tests/experiment/test_utils.py rename {e2e_tests/tests/fixtures/custom_searcher_exp => harness/tests/experiment/transformers}/__init__.py (100%) create mode 100644 harness/tests/experiment/transformers/test_callback.py create mode 100644 harness/tests/launch/test_tensorflow.py delete mode 100644 harness/tests/search_methods.py delete mode 100644 harness/tests/test_custom_searcher.py delete mode 100644 master/internal/db/postgres_snapshots_test.go create mode 100644 master/pkg/searcher/actions.go delete mode 100644 master/pkg/searcher/asha.go delete mode 100644 master/pkg/searcher/asha_test.go delete mode 100644 master/pkg/searcher/custom_search.go delete mode 100644 master/pkg/searcher/custom_search_test.go delete mode 100644 master/pkg/searcher/custom_searcher_events_queue.go delete mode 100644 master/pkg/searcher/operations.go create mode 100644 master/pkg/searcher/simulate_test.go delete mode 100644 webui/react/src/components/ExperimentContinueModal.module.scss diff --git a/.circleci/real_config.yml b/.circleci/real_config.yml index ddf809fc97d..a5c92919fee 100644 --- a/.circleci/real_config.yml +++ b/.circleci/real_config.yml @@ -2603,6 +2603,7 @@ jobs: - run: pip install mypy pytest coverage - install-codecov - setup-paths + - run: make -C harness install - run: COVERAGE_FILE=$PWD/test-unit-harness-tf2-pycov make -C harness test-tf2 - run: coverage xml -i --data-file=./test-unit-harness-tf2-pycov - run: codecov -v -t $CODECOV_TOKEN -F harness diff --git a/docs/.redirects/redirects.json b/docs/.redirects/redirects.json index c9e6a0025f6..e6610643735 100644 --- a/docs/.redirects/redirects.json +++ b/docs/.redirects/redirects.json @@ -69,7 +69,6 @@ "reference/deploy/config/helm-config-reference": "../helm-config-reference.html", "reference/deploy/config/common-config-options": "../common-config-options.html", "reference/deploy/config/agent-config-reference": "../agent-config-reference.html", - "reference/searcher/custom-searcher-reference": "../custom-searcher-reference.html", "setup-cluster/security/tls": "../../manage/security/tls.html", "setup-cluster/security/scim": "../../manage/security/scim.html", "setup-cluster/security/saml": "../../manage/security/saml.html", @@ -86,6 +85,7 @@ "model-hub-library/transformers/tutorial": "../../model-dev-guide/api-guides/_index.html", "model-hub-library/mmdetection/overview": "../../model-dev-guide/api-guides/_index.html", "model-dev-guide/hyperparameter/search-methods/index": "_index.html", + "model-dev-guide/hyperparameter/search-methods/hp-custom": "_index.html", "model-dev-guide/api-guides/batch-processing/batch-process-api-ug": "../batch-process-api-ug.html", "model-dev-guide/best-practices/index": "../_index.html", "model-dev-guide/best-practices/_index": "../_index.html", @@ -93,6 +93,7 @@ "model-dev-guide/hyperparameter/index": "_index.html", "model-dev-guide/prepare-container/index": "_index.html", "model-dev-guide/dtrain/index": "_index.html", + "model-dev-guide/api-guides/apis-howto/deepspeed/autotuning": "_index.html", "model-dev-guide/api-guides/apis-howto/deepspeed/index": "_index.html", "model-dev-guide/api-guides/apis-howto/index": "_index.html", "model-dev-guide/api-guides/index": "_index.html", @@ -110,6 +111,8 @@ "integrations/pachyderm/pachyderm": "../data-transformers/pachyderm.html", "architecture/index": "../get-started/architecture/_index.html", "reference/index": "_index.html", + "reference/custom-searcher-reference": "_index.html", + "reference/searcher/custom-searcher-reference": "../_index.html", "model-dev-guide/index": "_index.html", "model-hub-library/index": "../model-dev-guide/api-guides/_index.html", "tutorials/index": "_index.html", @@ -133,9 +136,9 @@ "tutorials/tf-mnist-tutorial": "_index.html", "model-dev-guide/batch-processing/batch-process-api-ug": "../api-guides/batch-process-api-ug.html", "model-dev-guide/apis-howto/deepspeed/advanced": "../../api-guides/apis-howto/deepspeed/advanced.html", + "model-dev-guide/apis-howto/deepspeed/autotuning": "../../api-guides/apis-howto/deepspeed/_index.html", "model-dev-guide/apis-howto/deepspeed/deepspeed": "../../api-guides/apis-howto/deepspeed/deepspeed.html", "model-dev-guide/apis-howto/deepspeed/overview": "../../api-guides/apis-howto/deepspeed/_index.html", - "model-dev-guide/apis-howto/deepspeed/autotuning": "../../api-guides/apis-howto/deepspeed/autotuning.html", "model-dev-guide/apis-howto/deepspeed/pytorch2deepspeed": "../../api-guides/apis-howto/deepspeed/pytorch2deepspeed.html", "model-dev-guide/apis-howto/api-core-ug": "../api-guides/apis-howto/api-core-ug.html", "model-dev-guide/apis-howto/api-pytorch-ug": "../api-guides/apis-howto/api-pytorch-ug.html", @@ -150,11 +153,11 @@ "training/model-management/overview": "../../model-dev-guide/model-management/_index.html", "training/model-management/checkpoints": "../../model-dev-guide/model-management/checkpoints.html", "training/best-practices/overview": "../../model-dev-guide/_index.html", + "training/hyperparameter/search-methods/hp-custom": "../../../model-dev-guide/hyperparameter/search-methods/_index.html", "training/hyperparameter/search-methods/hp-random": "../../../model-dev-guide/hyperparameter/search-methods/hp-random.html", "training/hyperparameter/search-methods/hp-adaptive-asha": "../../../model-dev-guide/hyperparameter/search-methods/hp-adaptive-asha.html", "training/hyperparameter/search-methods/hp-grid": "../../../model-dev-guide/hyperparameter/search-methods/hp-grid.html", "training/hyperparameter/search-methods/hp-single": "../../../model-dev-guide/hyperparameter/search-methods/hp-single.html", - "training/hyperparameter/search-methods/hp-custom": "../../../model-dev-guide/hyperparameter/search-methods/hp-custom.html", "training/hyperparameter/search-methods/overview": "../../../model-dev-guide/hyperparameter/search-methods/_index.html", "training/hyperparameter/hp-constraints-det": "../../model-dev-guide/hyperparameter/hp-constraints-det.html", "training/hyperparameter/handle-trial-errors": "../../model-dev-guide/hyperparameter/handle-trial-errors.html", @@ -223,7 +226,6 @@ "cluster-setup-guide/historical-cluster-usage-data": "../manage/historical-cluster-usage-data.html", "cluster-setup-guide/workspaces": "../manage/workspaces.html", "quickstart-mdldev": "tutorials/quickstart-mdldev.html", - "reference/reference-searcher/custom-searcher-reference": "../custom-searcher-reference.html", "reference/reference-model-hub/modelhub/transformers-api": "../../training/_index.html", "reference/reference-model-hub/modelhub/mmdetection-api": "../../training/_index.html", "reference/reference-model-hub/index": "../training/_index.html", @@ -233,6 +235,7 @@ "reference/reference-deploy/config/helm-config-reference": "../../deploy/helm-config-reference.html", "reference/reference-deploy/config/common-config-options": "../../deploy/common-config-options.html", "reference/reference-deploy/index": "../deploy/_index.html", + "reference/reference-searcher/custom-searcher-reference": "../_index.html", "reference/reference-training/training/api-deepspeed-reference": "../../training/api-deepspeed-reference.html", "reference/reference-training/training/api-pytorch-reference": "../../training/api-pytorch-reference.html", "reference/reference-training/training/api-det-reference": "../../training/api-det-reference.html", diff --git a/docs/get-started/architecture/introduction.rst b/docs/get-started/architecture/introduction.rst index 815e4fbd261..46ed20db96f 100644 --- a/docs/get-started/architecture/introduction.rst +++ b/docs/get-started/architecture/introduction.rst @@ -810,8 +810,6 @@ In this example experiment configuration, numbers, strings, maps, and an array a searcher: name: single metric: error - max_length: - batches: 500 smaller_is_better: true environment: environment_variables: diff --git a/docs/get-started/example-solutions/_index.rst b/docs/get-started/example-solutions/_index.rst index 376a856a2e9..d9defbdf7cf 100644 --- a/docs/get-started/example-solutions/_index.rst +++ b/docs/get-started/example-solutions/_index.rst @@ -55,21 +55,6 @@ For an introduction to using the training APIs, please visit :ref:`Training APIs - Enron Email Corpus - :download:`gpt_neox.tgz ` -******************** - DeepSpeed Autotune -******************** - -.. list-table:: - :header-rows: 1 - - - - Framework - - Dataset - - Filename - - - - DeepSpeed (PyTorch) - - ImageNet (Generated) - - :download:`torchvision.tgz ` - - - Hugging Face (DeepSpeed/PyTorch) - Beans (Hugging Face) - :download:`hf_image_classification.tgz ` diff --git a/docs/get-started/webui-qs.rst b/docs/get-started/webui-qs.rst index d26b63c882d..12f44580889 100644 --- a/docs/get-started/webui-qs.rst +++ b/docs/get-started/webui-qs.rst @@ -158,8 +158,6 @@ our multi-trial search. Finally, we'll run a remote distributed training job. name: random metric: validation_loss max_trials: 20 - max_length: - batches: 1000 smaller_is_better: true entrypoint: python3 train.py diff --git a/docs/model-dev-guide/api-guides/apis-howto/api-core-ug-basic.rst b/docs/model-dev-guide/api-guides/apis-howto/api-core-ug-basic.rst index b809a265672..46f43851b53 100644 --- a/docs/model-dev-guide/api-guides/apis-howto/api-core-ug-basic.rst +++ b/docs/model-dev-guide/api-guides/apis-howto/api-core-ug-basic.rst @@ -25,8 +25,8 @@ the the following capabilities: - hyperparameter search - distributing work across multiple GPUs and/or nodes -These are the same features provided by the higher-level PyTorchTrial, DeepSpeedTrial, and -TFKerasTrial APIs: those APIs are implemented using the Core API. +These features are also available in the higher-level PyTorchTrial and DeepSpeedTrial APIs, both of +which are built on top of the Core API. This user guide shows you how to get started using the Core API. @@ -85,7 +85,7 @@ with only a few new lines of code. .. literalinclude:: ../../../../examples/tutorials/core_api/1_metrics.py :language: python :start-after: NEW: import determined - :end-before: def main + :end-at: import determined as det #. Enable ``logging``, using the ``det.LOG_FORMAT`` for logs. This enables useful log messages from the ``determined`` library, and ``det.LOG_FORMAT`` enables filter-by-level in the WebUI. @@ -250,27 +250,6 @@ runs a train-validate-report loop: :dedent: :start-at: hparams = info.trial.hparams -#. Modify ``main()`` to run the train-validate-report loop mentioned above by iterating through - ``core_context.searcher.operations()``. Each :class:`~determined.core.SearcherOperation` from - :meth:`~determined.core.SearcherContext.operations` has a ``length`` attribute that specifies the - absolute length of training to complete. After validating, report the searcher metric value using - ``op.report_completed()``. - - .. literalinclude:: ../../../../examples/tutorials/core_api/3_hpsearch.py - :language: python - :dedent: - :start-at: batch = starting_batch - :end-at: op.report_completed - -#. Because the training length can vary, you might exit the train-validate-report loop before saving - the last of your progress. To handle this, add a conditional save after the loop ends: - - .. literalinclude:: ../../../../examples/tutorials/core_api/3_hpsearch.py - :language: python - :dedent: - :start-at: if last_checkpoint_batch != steps_completed - :end-at: save_state - #. Create a new ``3_hpsearch.yaml`` file and add an ``entrypoint`` that invokes ``3_hpsearch.py``: .. literalinclude:: ../../../../examples/tutorials/core_api/3_hpsearch.yaml @@ -365,9 +344,8 @@ considerations are: :start-after: some logs are easier to read :end-at: logging.info -#. Only the chief worker is permitted to report training metrics, report validation metrics, upload - checkpoints, or report searcher operations completed. This rule applies to the steps you take - periodically during training: +#. Only the chief worker is permitted to report metrics, upload checkpoints, or report progress. + This rule applies to the steps you take periodically during training: .. literalinclude:: ../../../../examples/tutorials/core_api/4_distributed.py :language: python @@ -375,22 +353,6 @@ considerations are: :start-at: if steps_completed % 10 == 0 :end-at: return - The rule also applies to the steps you take after validating: - - .. literalinclude:: ../../../../examples/tutorials/core_api/4_distributed.py - :language: python - :dedent: - :start-after: only the chief may report validation metrics - :end-at: op.report_completed - - The rule also applies to the conditional save after the main loop completes: - - .. literalinclude:: ../../../../examples/tutorials/core_api/4_distributed.py - :language: python - :dedent: - :start-at: again, only the chief may upload checkpoints - :end-at: save_state - #. Create a ``4_distributed.yaml`` file by copying the ``3_distributed.yaml`` file and changing the first couple of lines: @@ -411,7 +373,7 @@ considerations are: .. literalinclude:: ../../../../examples/tutorials/core_api/4_distributed.yaml :language: yaml :start-at: searcher: - :end-at: max_length: + :end-at: metric: #. Run the code using the Determined CLI with the following command: diff --git a/docs/model-dev-guide/api-guides/apis-howto/api-core-ug.rst b/docs/model-dev-guide/api-guides/apis-howto/api-core-ug.rst index ca5ad851f98..5b51199eac9 100644 --- a/docs/model-dev-guide/api-guides/apis-howto/api-core-ug.rst +++ b/docs/model-dev-guide/api-guides/apis-howto/api-core-ug.rst @@ -332,8 +332,10 @@ settings in our experiment configuration file: - ``smaller_is_better``: ``True`` (This is equivalent to minimization vs. maximization of objective.) - ``max_trials``: 500 (This is the maximum number of trials the searcher should run.) -- ``max_length``: 20 epochs (The max length of a trial. For more information, visit Adaptive ASHA - in the :ref:`Experiment Configuration Reference `. +- ``time_metric``: ``epochs`` (This is the name of the "time" metric which we report in validation + metrics). +- ``max_time``: 20 (The max number of epochs a trial will report. For more information, visit + Adaptive ASHA in the :ref:`Experiment Configuration Reference `. In addition, we also need to define the hyperparameters themselves. Adaptive ASHA will pick values between the ``minval`` and ``maxval`` for each hyperparameter for each trial. diff --git a/docs/model-dev-guide/api-guides/apis-howto/api-keras-ug.rst b/docs/model-dev-guide/api-guides/apis-howto/api-keras-ug.rst index f871c9172ec..d59fd5621b3 100644 --- a/docs/model-dev-guide/api-guides/apis-howto/api-keras-ug.rst +++ b/docs/model-dev-guide/api-guides/apis-howto/api-keras-ug.rst @@ -7,7 +7,8 @@ .. meta:: :description: Learn how to use the Keras API to train a Keras model. This user guide walks you through loading your data, defining the model, customizing how the model.fit function is called, checkpointing, and callbacks. -In this guide, you'll learn how to use the Keras API. +In this guide, you'll learn how to use Determined's ``keras.DeterminedCallback`` while training your +Keras model. +---------------------------------------------------------------------+ | Visit the API reference | @@ -15,121 +16,148 @@ In this guide, you'll learn how to use the Keras API. | :ref:`keras-reference` | +---------------------------------------------------------------------+ -This document guides you through training a Keras model in Determined. You need to implement a trial -class that inherits :class:`~determined.keras.TFKerasTrial` and specify it as the entrypoint in the -:ref:`experiment-configuration`. +This document guides you through training a Keras model in Determined. You will need to update your +``model.fit()`` call to include a :class:`~determined.keras.DeterminedCallback` and submit it to a +Determined cluster. -To learn about this API, you can start by reading the trial definitions in the `Iris categorization -example +To learn about this API, you can start by reading the ``train.py`` script in the `Iris +categorization example `__. -*********** - Load Data -*********** +********************** + Configure Entrypoint +********************** -.. note:: +Determined requires you to launch training jobs by submitting them with an +:ref:`experiment-configuration`, which tells the Determined master how to start your container. For +Keras training, you should always wrap your training script in Determined's :ref:`TensorFlow +launcher `: - Before loading data, visit :ref:`load-model-data` to understand how to work with different - sources of data. +.. code:: yaml -Loading data is done by defining :meth:`~determined.keras.TFKerasTrial.build_training_data_loader` -and :meth:`~determined.keras.TFKerasTrial.build_validation_data_loader` methods. Each should return -one of the following data types: + entrypoint: >- + python3 -m determined.launch.tensorflow -- + python3 my_train.py --my-arg... -#. A tuple ``(x, y)`` of NumPy arrays. x must be a NumPy array (or array-like), a list of arrays (in - case the model has multiple inputs), or a dict mapping input names to the corresponding array, if - the model has named inputs. y should be a numpy array. +Determined's TensorFlow launcher will automatically configure your training script with the right +``TF_CONFIG`` environment variable for distributed training when distributed resources are +available, and will safely do nothing when they are not. -#. A tuple ``(x, y, sample_weights)`` of NumPy arrays. +**************************************************************** + Obtain a ``det.core.Context`` and a ``tf.distribute.Strategy`` +**************************************************************** -#. A ``tf.data.dataset`` returning a tuple of either (inputs, targets) or (inputs, targets, - sample_weights). +When using distributed training, TensorFlow requires you to create your ``Strategy`` early in the +process lifetime, before creating your model. -#. A ``keras.utils.Sequence`` returning a tuple of either (inputs, targets) or (inputs, targets, - sample weights). +Since you wrapped your training script in Determined's TensorFlow launcher, you can use Determined's +``core.DistributedContext.from_tf_config()`` helper, which will create both a suitable +``DistributedContext`` and ``Strategy`` for the training environment in your training job. Then you +can feed that ``DistributedContext`` to ``det.core.init()`` to get a ``core.Context``, and feed all +of that to your ``main()`` function (or equivalent) in your training script: -If using ``tf.data.Dataset``, users are required to wrap both their training and validation dataset -using :meth:`self.context.wrap_dataset `. This -wrapper is used to shard the dataset for distributed training. For optimal performance, users should -wrap a dataset immediately after creating it. +.. code:: python -.. include:: ../../../_shared/note-dtrain-learn-more.txt + if __name__ == "__main__": + distributed, strategy = det.core.DistributedContext.from_tf_config() + with det.core.init(distributed=distributed) as core_context: + main(core_context, strategy) -****************** - Define the Model -****************** +***************** + Build the Model +***************** -Users are required wrap their model prior to compiling it using :meth:`self.context.wrap_model -`. This is typically done inside -:meth:`~determined.keras.TFKerasTrial.build_model`. +Building a distributed-capable model is easy in Keras; you just need to wrap your model building and +compiling in the ``strategy.scope()``. See the `TensorFlow documentation +`__ +for more details -****************************************** - Customize Calling Model Fitting Function -****************************************** +.. code:: python -The :class:`~determined.keras.TFKerasTrial` interface allows the user to configure how ``model.fit`` -is called by calling :meth:`self.context.configure_fit() -`. + def main(core_context, strategy): + with strategy.scope(): + model = my_build_model() + model.compile(...) + +*********************************** + Create the ``DeterminedCallback`` +*********************************** + +The :class:`~determined.keras.DeterminedCallback` automatically integrates your training with the +Determined cluster. It reports both train and test metrics, reports progress, saves checkpoints, and +uploads them to checkpoint storage. Additionally, it manages preemption signals from the Determined +master (for example, when you pause your experiment), gracefully halting training and later resuming +from where it left off. + +The ``DeterminedCallback`` has only three required inputs: + - the ``core_context`` you already created + - a ``checkpoint`` UUID to start training from, or ``None`` + - a ``continue_id`` used to decide how to treat the checkpoint + +In training jobs, an easy value for ``checkpoint`` is ``det.get_cluster_info().latest_checkpoint``, +which will automatically be populated with the latest checkpoint saved by this trial, or ``None``. +If, for example, you wanted to start training from a checkpoint and support pausing and resuming, +you could use ``info.latest_checkpoint or my_starting_checkpoint``. + +The ``continue_id`` helps the ``DeterminedCallback`` decide if the provided checkpoint represents +just the starting weights and training should begin at epoch=0, or if the checkpoint represents a +partially complete training that should pick up where it left off (at epoch > 0). The provided +``continue_id`` is saved along with every checkpoint, and when loading the starting checkpoint, if +the ``continue_id`` matches what was in the checkpoint, training state is also loaded from the +checkpoint. In training jobs, an easy value for ``continue_id`` is +``det.get_cluster_info.trial.trial_id``. + +See the reference for :class:`~determined.keras.DeterminedCallback` for details on its optional +parameters. -*************** - Checkpointing -*************** +.. code:: python -A checkpoint includes the model definition (Python source code), experiment configuration file, -network architecture, and the values of the model's parameters (i.e., weights) and hyperparameters. -When using a stateful optimizer during training, checkpoints will also include the state of the -optimizer (i.e., learning rate). You can also embed arbitrary metadata in checkpoints via a -:ref:`Python SDK `. + info = det.get_cluster_info() + assert info and info.task_type == "TRIAL", "this example only runs as a trial on the cluster" -TensorFlow Keras trials are checkpointed to a file named ``determined-keras-model.h5`` using -``tf.keras.models.save_model``. You can learn more from the `TF Keras docs -`__. + det_cb = det.keras.DeterminedCallback( + core_context, + checkpoint=info.latest_checkpoint, + continue_id=info.trial.trial_id, + ) *********** - Callbacks + Load Data *********** -To execute arbitrary Python code during the lifecycle of a :class:`~determined.keras.TFKerasTrial`, -implement the :class:`determined.keras.callbacks.Callback` interface (an extension of the -``tf.keras.callbacks.Callbacks`` interface) and supply them to the -:class:`~determined.keras.TFKerasTrial` by implementing -:meth:`~determined.keras.TFKerasTrial.keras_callbacks`. +Loading data is done as usual, though additional considerations may arise if your existing +data-loading code is not container-ready. For more details, see :ref:`load-model-data`. -.. _keras-profiler: +If you want to take advantage Determined's distributed training, you may need to ensure that your +input data is properly sharded. See `TensorFlow documentation +`__ for details. -*********** - Profiling -*********** - -Determined supports integration with the native TF Keras profiler. Results will automatically be -uploaded to the trial's TensorBoard path and can be viewed in the Determined Web UI. +.. include:: ../../../_shared/note-dtrain-learn-more.txt -The Keras profiler is configured as a callback in the :class:`~determined.keras.TFKerasTrial` class. -The :class:`determined.keras.callbacks.TensorBoard` callback is a thin wrapper around the native -Keras TensorBoard callback, ``tf.keras.callbacks.TensorBoard``. It overrides the ``log_dir`` -argument to set the Determined TensorBoard path, while other arguments are passed directly into -``tf.keras.callbacks.TensorBoard``. For a list of accepted arguments, consult the `official Keras -API documentation `_. +************************* + TensorBoard Integration +************************* -The following code snippet will configure profiling for batches 5 and 10, and will compute weight -histograms every 1 epochs. +Optionally, you can use Determined's :class:`~determined.keras.TensorBoard` callback, which extends +Keras' ``TensorBoard`` callback with the ability to automatically upload metrics to Determined's +checkpoint storage. Determined's ``TensorBoard`` callback is configured identically to Keras' except +it takes an additional ``core_context`` initial argument: .. code:: python - from determined import keras + tb_cb = det.keras.TensorBoard(core_context, ...) - def keras_callbacks(self) -> List[tf.keras.callbacks.Callback]: - return [ - keras.callbacks.TensorBoard( - update_freq="batch", - profile_batch='5, 10', - histogram_freq=1, - ) - ] +Then simply include it in your ``model.fit()`` as normal. -.. note:: +************************* + Calling ``model.fit()`` +************************* + +The only remaining step is to pass your callbacks to your ``model.fit()``: + +.. code:: python - Though specifying batches to profile with ``profile_batch`` is optional, profiling every batch - may cause a large amount of data to be uploaded to Tensorboard. This may result in long rendering - times for Tensorboard and memory issues. For long-running experiments, it is recommended to - configure profiling only on desired batches. + model.fit( + ..., + callbacks=[det_cb, tb_cb], + ) diff --git a/docs/model-dev-guide/api-guides/apis-howto/api-pytorch-ug.rst b/docs/model-dev-guide/api-guides/apis-howto/api-pytorch-ug.rst index 39389631b09..da4dd49159a 100644 --- a/docs/model-dev-guide/api-guides/apis-howto/api-pytorch-ug.rst +++ b/docs/model-dev-guide/api-guides/apis-howto/api-pytorch-ug.rst @@ -745,6 +745,7 @@ arguments, such as checkpointing periods, validation periods, and checkpointing trial = MyTrial(train_context) trainer = det.pytorch.Trainer(trial, train_context) + trainer.fit( + + max_length=pytorch.Epoch(10), + checkpoint_period=pytorch.Batch(100), + validation_period=pytorch.Batch(100), + checkpoint_policy="all" @@ -760,8 +761,7 @@ Run Your Training Script Locally ================================ Run training scripts locally without submitting to a cluster or defining an experiment configuration -file. Be sure to specify ``max_length`` in the ``.fit()`` call, which is used in local training mode -to determine the maximum number of steps to train for. +file. .. code:: python @@ -773,7 +773,7 @@ to determine the maximum number of steps to train for. trial = MyTrial(train_context) trainer = det.pytorch.Trainer(trial, train_context) trainer.fit( - max_length=pytorch.Epoch(1), + max_length=pytorch.Epoch(10), checkpoint_period=pytorch.Batch(100), validation_period=pytorch.Batch(100), checkpoint_policy="all", @@ -786,7 +786,7 @@ to determine the maximum number of steps to train for. main() You can run this Python script directly (``python3 train.py``), or in a Jupyter notebook. This code -will train for one epoch, and checkpoint and validate every 100 batches. +will train for ten epochs, and checkpoint and validate every 100 batches. Local Distributed Training ========================== @@ -808,7 +808,7 @@ code. Both Horovod and PyTorch Distributed backends are supported. trial = MyTrial(train_context) trainer = det.pytorch.Trainer(trial, train_context) trainer.fit( - max_length=pytorch.Epoch(1), + max_length=pytorch.Epoch(10), checkpoint_period=pytorch.Batch(100), validation_period=pytorch.Batch(100), checkpoint_policy="all" @@ -827,7 +827,7 @@ tests around your model code. .. code:: diff trainer.fit( - max_length=pytorch.Epoch(1), + max_length=pytorch.Epoch(10), checkpoint_period=pytorch.Batch(100), validation_period=pytorch.Batch(100), + test_mode=True @@ -864,7 +864,7 @@ Example workflow of frequent iterations between local debugging and cluster depl trial = MNistTrial(train_context) trainer = det.pytorch.Trainer(trial, train_context) trainer.fit( - max_length=pytorch.Epoch(1), + max_length=pytorch.Epoch(11), checkpoint_period=pytorch.Batch(100), validation_period=pytorch.Batch(100), + latest_checkpoint=latest_checkpoint, @@ -879,6 +879,7 @@ To run Trainer API solely on-cluster, the code is much simpler: trial_inst = model.MNistTrial(train_context) trainer = det.pytorch.Trainer(trial_inst, train_context) trainer.fit( + max_length=pytorch.Epoch(11), checkpoint_period=pytorch.Batch(100), validation_period=pytorch.Batch(100), latest_checkpoint=det.get_cluster_info().latest_checkpoint, @@ -896,8 +897,6 @@ Your experiment configuration file must contain searcher configuration and entry searcher: name: single metric: validation_loss - max_length: - epochs: 1 resources: slots_per_trial: 8 entrypoint: python3 -m determined.launch.torch_distributed python3 train.py diff --git a/docs/model-dev-guide/api-guides/apis-howto/deepspeed/_index.rst b/docs/model-dev-guide/api-guides/apis-howto/deepspeed/_index.rst index 649dcc8a7c7..42795a6df92 100644 --- a/docs/model-dev-guide/api-guides/apis-howto/deepspeed/_index.rst +++ b/docs/model-dev-guide/api-guides/apis-howto/deepspeed/_index.rst @@ -33,9 +33,6 @@ Determined DeepSpeed documentation: :class:`~determined.pytorch.PyTorchTrial` to :class:`~determined.pytorch.deepspeed.DeepSpeedTrial`. -- :ref:`DeepSpeed Autotune: User Guide ` demonstrates how to use DeepSpeed - Autotune to take full advantage of your hardware and model. - - :ref:`API Reference ` lays out the classes and methods related to DeepSpeed support including the full API specification for :class:`~determined.pytorch.deepspeed.DeepSpeedTrial` and @@ -46,6 +43,5 @@ Determined DeepSpeed documentation: :hidden: API Usage Guide - Autotuning advanced pytorch2deepspeed diff --git a/docs/model-dev-guide/api-guides/apis-howto/deepspeed/autotuning.rst b/docs/model-dev-guide/api-guides/apis-howto/deepspeed/autotuning.rst deleted file mode 100644 index 143895b5d36..00000000000 --- a/docs/model-dev-guide/api-guides/apis-howto/deepspeed/autotuning.rst +++ /dev/null @@ -1,312 +0,0 @@ -.. _deepspeed-autotuning: - -################################ - DeepSpeed Autotune: User Guide -################################ - -.. important:: - - **Deprecation Notice**: DeepSpeed Autotune is deprecated and will be removed in an upcoming - release. For information on supported search methods, please visit :ref:`search-methods`. - -.. meta:: - :description: This user guide demonstrates how to optimize DeepSpeed parameters in order to take full advantage of the user's hardware and model. - -Getting the most out of DeepSpeed (DS) requires aligning the many DS parameters with the specific -properties of your hardware and model. Determined AI's DeepSpeed Autotune (``dsat``) helps to -optimize these settings through an easy-to-use API with very few changes required in user-code, as -we describe in the remainder of this user guide. ``dsat`` can be used with -:class:`~determined.pytorch.deepspeed.DeepSpeedTrial`, :ref:`Core API `, and -`Hugging Face Trainer `__. - -************** - How it Works -************** - -You do not need to create a special configuration file to use ``dsat``. Assuming you have DeepSpeed -code which already functions, autotuning is as easy as inserting one or two helper functions into -your code and modifying the launch command. - -For instance, let's say your directory contains DeepSpeed code and a corresponding ``single`` trial -experiment configuration file ``deepspeed.yaml``. Then, after inserting a line or two of -``dsat``-specific code per the instructions in the following sections, launching the ``dsat`` -experiments is as easy as replacing the usual experiment-launching command: - -.. code:: - - det experiment create deepspeed.yaml . - -with: - -.. code:: - - python3 -m determined.pytorch.dsat asha deepspeed.yaml . - -The above uses Determined AI's DeepSpeed Autotune with the ``asha`` algorithm, one of three -available search methods: - -- ``asha``: Adaptively searches over randomly selected DeepSpeed configurations, allocating more - compute resources to well-performing configurations. See :ref:`this introduction to ASHA - ` for more details. - -- ``binary``: Performs a simple binary search over the batch size for randomly-generated DS - configurations. - -- ``random``: Conducts a search over random DeepSpeed configurations with an aggressive - early-stopping criteria based on domain-knowledge of DeepSpeed and the search history. - -DeepSpeed Autotune is built on top of Custom Searcher (see :ref:`topic-guides_hp-tuning-det_custom`) -which starts up two separate experiments: - -- ``single`` Search Runner Experiment: This experiment coordinates and schedules the trials that - run the model code. -- ``custom`` Experiment: This experiment contains the trials referenced above whose results are - reported back to the search runner. - -Initially, a profiling trial is created to gather information regarding the model and computational -resources. The search runner experiment takes this initial profiling information and creates a -series of trials to search for the DS settings which optimize ``FLOPS_per_gpu``, ``throughput`` -(samples/second), or latency timing information. The results of all such trials can be viewed in the -``custom`` experiment above. The search is informed both by the initial profiling trial and the -results of each subsequent trial, all of whose results are fed back to the search runner. - -.. warning:: - - Determined's DeepSpeed Autotune is not compatible with pipeline or model parallelism. The - to-be-trained model must be a ``DeepSpeedEngine`` instance (not a ``PipelineEngine`` instance). - -******************* - User Code Changes -******************* - -To use ``dsat`` with :class:`~determined.pytorch.deepspeed.DeepSpeedTrial`, Core API, and Hugging -Face Trainer, specific changes must be made to your user code. In the following sections, we will -describe specific use cases and the changes needed for each. - -.. _using_deepspeed_trial: - -DeepSpeedTrial -============== - -To use Determined's DeepSpeed Autotune with ``DeepSpeedTrial``, you must meet the following -requirements. - -First, it is assumed that a base DeepSpeed configuration exists in a file (written following the -`DeepSpeed documentation here `_). We then require that -your Determined ``yaml`` configuration points to the location of that file through a -``deepspeed_config`` key in its ``hyperparameters`` section. For example, if your default DeepSpeed -configuration is stored in ``ds_config.json`` at the top-level of your model directory, your -``hyperparameters`` section should include: - -.. code:: yaml - - hyperparameters: - deepspeed_config: ds_config.json - -Second, your ``DeepSpeedTrial`` code must use our -:func:`~determined.pytorch.dsat.get_ds_config_from_hparams` helper function to get the DeepSpeed -configuration dictionary which is generated by DeepSpeed Autotune for each trial. These dictionaries -are generated by overwriting certain fields in the base DeepSpeed configuration referenced in the -step above. The returned dictionary can then be passed to ``deepspeed.initialize`` as usual: - -.. code:: python - - from determined.pytorch.deepspeed import DeepSpeedTrial, DeepSpeedTrialContext - from determined.pytorch import dsat - - - class MyDeepSpeedTrial(DeepSpeedTrial): - def __init__(self, context: DeepSpeedTrialContext) -> None: - self.hparams = self.context.get_hparams() - config = dsat.get_ds_config_from_hparams(self.hparams) - model = ... - model_parameters= ... - - model_engine, optimizer, train_loader, lr_scheduler = deepspeed.initialize( - model=model, model_parameters=model_parameters, config=config - ) - -Using Determined's DeepSpeed Autotune with a :class:`~determined.pytorch.deepspeed.DeepSpeedTrial` -instance requires no further changes to your code. - -For a complete example of how to use DeepSpeed Autotune with ``DeepSpeedTrial``, visit the -`Determined GitHub Repo -`__ -and navigate to ``examples/deepspeed_autotune/torchvision/deepspeed_trial`` . - -.. note:: - - To find out more about ``DeepSpeedTrial``, visit :ref:`deepspeed-api`. - -Core API -======== - -When using DeepSpeed Autotune with a Core API experiment, there is one additional change to be made -following the steps in the :ref:`using_deepspeed_trial` section above. - -The ``forward``, ``backward``, and ``step`` methods of the ``DeepSpeedEngine`` class need to be -wrapped in the :func:`~determined.pytorch.dsat.dsat_reporting_context` context manager. This -addition ensures that the autotuning metrics from each trial are captured and reported back to the -Determined master. - -Here is an example sketch of ``dsat`` code with Core API: - -.. code:: python - - for op in core_context.searcher.operations(): - for (inputs, labels) in trainloader: - with dsat.dsat_reporting_context(core_context, op): # <-- The new code - outputs = model_engine(inputs) - loss = criterion(outputs, labels) - model_engine.backward(loss) - model_engine.step() - -In this code snippet, ``core_context`` is the :class:`~determined.core.Context` instance which was -initialized with :func:`determined.core.init`. The context manager requires access to both -``core_context`` and the current :class:`~determined.core.SearcherOperation` instance (``op``) to -appropriately report results. Outside of a ``dsat`` context, ``dsat_reporting_context`` is a no-op, -so there is no need to remove the context manager after the ``dsat`` trials have completed. - -For a complete example of how to use DeepSpeed Autotune with Core API, visit the `Determined GitHub -Repo -`__ -and navigate to ``examples/deepspeed_autotune/torchvision/core_api`` . - -Hugging Face Trainer -==================== - -You can also use Determined's DeepSpeed Autotune with the Hugging Face (HF) Trainer and Determined's -:class:`~determined.transformers.DetCallback` callback object to optimize your DeepSpeed parameters. - -Similar to the previous case (Core API), you need to add a ``deepspeed_config`` field to the -``hyperparameters`` section of your experiment configuration file, specifying the relative path to -the DS ``json`` config file. - -Reporting results back to the Determined master requires both the ``dsat.dsat_reporting_context`` -context manager and ``DetCallback``. - -Furthermore, since ``dsat`` performs a search over different batch sizes and Hugging Face expects -parameters to be specified as command-line arguments, an additional helper function, -:func:`~determined.pytorch.dsat.get_hf_args_with_overwrites`, is needed to create consistent Hugging -Face arguments. - -Here is an example code snippet from a Hugging Face Trainer script that contains key pieces of -relevant code: - -.. code:: python - - from determined.transformers import DetCallback - from determined.pytorch import dsat - from transformers import HfArgumentParser,Trainer, TrainingArguments, - - hparams = self.context.get_hparams() - parser = HfArgumentParser(TrainingArguments) - args = sys.argv[1:] - args = dsat.get_hf_args_with_overwrites(args, hparams) - training_args = parser.parse_args_into_dataclasses(args, look_for_args_file=False) - - det_callback = DetCallback(core_context, ...) - trainer = Trainer(args=training_args, ...) - with dsat.dsat_reporting_context(core_context, op=det_callback.current_op): - train_result = trainer.train(resume_from_checkpoint=checkpoint) - -.. important:: - - - The ``dsat_reporting_context`` context manager shares the same initial - :class:`~determined.core.SearcherOperation` as the ``DetCallback`` instance through its - ``op=det_callback.current_op`` argument. - - - The entire ``train`` method of the Hugging Face trainer is wrapped in the - ``dsat_reporting_context`` context manager. - -To find examples that use DeepSpeed Autotune with Hugging Face Trainer, visit the `Determined GitHub -Repo `__ and navigate -to ``examples/hf_trainer_api``. - -****************** - Advanced Options -****************** - -The command-line entrypoint to ``dsat`` has various available options, some of them -search-algorithm-specific. All available options for any given search method can be found through -the command: - -.. code:: - - python3 -m determined.pytorch.dsat asha --help - -and similar for the ``binary`` and ``random`` search methods. - -Flags that are particularly important are detailed below. - -General Options -=============== - -The following options are available for every search method. - -- ``--max-trials``: The maximum number of trials to run. Default: ``64``. - -- ``--max-concurrent-trials``: The maximum number of trials that can run concurrently. Default: - ``16``. - -- ``--max-slots``: The maximum number of slots that can be used concurrently. Defaults to ``None``, - i.e., there is no limit by default. - -- ``--metric``: The metric to be optimized. Defaults to ``FLOPS-per-gpu``. Other available options - are ``throughput``, ``forward``, ``backward``, and ``latency``. - -- ``--run-full-experiment``: If specified, after the ``dsat`` experiment has completed, a - ``single`` experiment will be launched using the specifications in the ``deepspeed.yaml`` - overwritten with the best-found DS configuration parameters. - -- ``--zero-stages``: This flag allows the user to limit the search to a subset of the stages by - providing a space-separated list, as in ``--zero-stages 2 3``. Default: ``1 2 3``. - -.. _asha-options: - -``asha`` Options -================ - -The ``asha`` search algorithm randomly generates various DeepSpeed configurations and attempts to -tune the batch size for each configuration through a binary search. ``asha`` adaptively allocates -resources to explore each configuration (providing more resources to promising lineages) where the -resource is the number of steps taken in each binary search (i.e., the number of trials). - -``asha`` can be configured with the following flags: - -- ``--max-rungs``: The maximum total number of rungs to use in the ASHA algorithm. Larger values - allow for longer binary searches. Default: ``5``. - -- ``--min-binary-search-trials``: The minimum number of trials to use for each binary search. The - ``r`` parameter in `the ASHA paper `_. Default: ``3``. - -- ``--divisor``: Factor controlling the increased computational allotment across rungs, and the - decrease in their population size. The ``eta`` parameter in `the ASHA paper - `_. Default: ``2``. - -- ``--search-range-factor``: The inclusive, initial ``hi`` bound on the binary search is set by an - approximate computation (the ``lo`` bound is always initialized to ``1``). This parameter adjusts - the ``hi`` bound by a factor of ``search-range-factor``. Default: ``1.0``. - -``binary`` Options -================== - -The ``binary`` search algorithm performs a straightforward search over the batch size for a -collection of randomly-drawn DS configurations. A single option is available for this search: -``--search-range-factor``, which plays precisely the same role as in the :ref:`asha-options` section -above. - -``random`` Options -================== - -The ``random`` search algorithm performs a search over randomly drawn DS configurations and uses a -semi-random search over the batch size. - -``random`` can be configured with the following flags: - -- ``--trials-per-random-config``: The maximum batch size configuration which will tested for a - given DS configuration. Default: ``5``. - -- ``--early-stopping``: If provided, the experiment will terminate if a new best-configuration has - not been found in the last ``early-stopping`` trials. Default: ``None``, corresponding to no such - early stopping. diff --git a/docs/model-dev-guide/create-experiment.rst b/docs/model-dev-guide/create-experiment.rst index 7c301850576..bd62ac1b320 100644 --- a/docs/model-dev-guide/create-experiment.rst +++ b/docs/model-dev-guide/create-experiment.rst @@ -172,6 +172,31 @@ Use the ``-h`` option to get the latest usage: python3 -m determined.launch.deepspeed -h +.. _launch-tensorflow: + +TensorFlow Launcher +=================== + +Format: + +``determined.launch.tensorflow [--] SCRIPT...`` + +This launcher configures a ``TF_CONFIG`` environment variable suitable for whichever level of +TensorFlow distributed training is appropriate for the available training resources +(``MultiWorkerMirroredStrategy``, ``MirroredStrategy``, or the default strategy). + +Example: + +.. code:: bash + + python3 -m determined.launch.tensorflow -- python3 ./my_train.py --my-arg=value + +Use the ``-h`` option to get the latest usage: + +.. code:: bash + + python3 -m determined.launch.tensorflow -h + Legacy Launcher =============== diff --git a/docs/model-dev-guide/debug-models.rst b/docs/model-dev-guide/debug-models.rst index 7594cdb08aa..74385d612e9 100644 --- a/docs/model-dev-guide/debug-models.rst +++ b/docs/model-dev-guide/debug-models.rst @@ -70,9 +70,9 @@ for debugging. See :ref:`pytorch_trainer_ug` for usage details. #. Create simple tests to verify each ``Trial`` subclass method. Examples of what these tests might look like for - :class:`~determined.pytorch.deepspeed.DeepSpeedTrial` and :class:`~determined.keras.TFKerasTrial` - can be found in the :meth:`determined.TrialContext.from_config` documentation, but only you can - verify what is reasonable for your test. + :class:`~determined.pytorch.deepspeed.DeepSpeedTrial` can be found in the + :meth:`determined.TrialContext.from_config` documentation, but only you can verify what is + reasonable for your test. #. Diagnose failures: @@ -385,8 +385,8 @@ step only applies if you have multiple GPUs and want to use distributed training consume too many resources and prevent the experiment from starting. - Determined is designed to control the details of distributed training for you. If you also try - to control those details, such as by calling ``tf.config.set_visible_devices()`` in a - :class:`~determined.keras.TFKerasTrial`, it is likely to cause issues. + to control those details, such as by calling ``tf.config.set_visible_devices()`` while + training a Keras model, it is likely to cause issues. - Some classes of metrics must be specially calculated during distributed training. Most metrics, such as loss or accuracy, can be calculated piecemeal on each worker in a distributed diff --git a/docs/model-dev-guide/dtrain/config-templates.rst b/docs/model-dev-guide/dtrain/config-templates.rst index 0e42b6bf737..671f9de287e 100644 --- a/docs/model-dev-guide/dtrain/config-templates.rst +++ b/docs/model-dev-guide/dtrain/config-templates.rst @@ -50,8 +50,6 @@ and a simplified configuration. searcher: name: single metric: error - max_length: - batches: 500 smaller_is_better: true You may find that many experiments share the same values for the ``checkpoint_storage`` field, @@ -86,8 +84,6 @@ The experiment configuration for this experiment can then be written using the f searcher: name: single metric: error - max_length: - batches: 500 smaller_is_better: true To launch the experiment with the template: diff --git a/docs/model-dev-guide/dtrain/reproducibility.rst b/docs/model-dev-guide/dtrain/reproducibility.rst index 6aad22e2d4b..cc87e0fbd49 100644 --- a/docs/model-dev-guide/dtrain/reproducibility.rst +++ b/docs/model-dev-guide/dtrain/reproducibility.rst @@ -43,8 +43,8 @@ The experiment seed is used as a source of randomness for any hyperparameter sam The experiment seed is also used to generate a **trial seed** for every trial associated with the experiment. -In the ``Trial`` interface, the trial seed is accessible within the trial class using -``self.ctx.get_trial_seed()``. +When training on-cluster, the trial seed is accessible via +:class:`det.get_cluster_info().trial.trial_seed ` ******************* Coding Guidelines @@ -67,16 +67,12 @@ To achieve reproducible initial conditions in an experiment, please follow these ************************************** When doing CPU-only training with TensorFlow, it is possible to achieve floating-point -reproducibility throughout optimization. If using the :class:`~determined.keras.TFKerasTrial` API, -implement the optional :meth:`~determined.keras.TFKerasTrial.session_config` method to override the -default session configuration: +reproducibility throughout optimization: .. code:: python - def session_config(self) -> tf.ConfigProto: - return tf.ConfigProto( - intra_op_parallelism_threads=1, inter_op_parallelism_threads=1 - ) + tf.config.threading.set_intra_op_parallelism_threads(1) + tf.config.threading.set_inter_op_parallelism_threads(1) .. warning:: diff --git a/docs/model-dev-guide/hyperparameter/search-methods/_index.rst b/docs/model-dev-guide/hyperparameter/search-methods/_index.rst index fdbea0ff0ca..cd6410e5918 100644 --- a/docs/model-dev-guide/hyperparameter/search-methods/_index.rst +++ b/docs/model-dev-guide/hyperparameter/search-methods/_index.rst @@ -10,7 +10,7 @@ values to use in each trial. Every searcher is configured with the name of the v optimize (via the ``metric`` field), in addition to other searcher-specific options. For example, the ``adaptive_asha`` searcher (`arXiv:1810.0593 `_), suitable for larger experiments with many trials, is configured with the maximum number of trials to run, the -maximum training length allowed per trial, and the maximum number of trials that can be worked on +name and maxmimum value of a "time" metric, and the maximum number of trials that can be worked on simultaneously: .. code:: yaml @@ -19,8 +19,8 @@ simultaneously: name: "adaptive_asha" metric: "validation_loss" max_trials: 16 - max_length: - epochs: 1 + time_metric: batches + max_time: 100000 max_concurrent_trials: 8 For details on the supported searchers and their respective configuration options, refer to @@ -59,8 +59,6 @@ Determined also supports other common hyperparameter search algorithms: - :ref:`Random ` evaluates a set of hyperparameter configurations chosen at random and returns the best. -You can also implement your own :ref:`custom search methods `. - .. toctree:: :maxdepth: 1 :hidden: diff --git a/docs/model-dev-guide/hyperparameter/search-methods/hp-adaptive-asha.rst b/docs/model-dev-guide/hyperparameter/search-methods/hp-adaptive-asha.rst index cf07b997154..1e7846d6295 100644 --- a/docs/model-dev-guide/hyperparameter/search-methods/hp-adaptive-asha.rst +++ b/docs/model-dev-guide/hyperparameter/search-methods/hp-adaptive-asha.rst @@ -20,12 +20,13 @@ Search mode: Resource budget: -- ``max_length``: The maximum training length (see :ref:`Training Units - `) of any trial that survives to the end of the - experiment. This quantity is domain-specific and should roughly reflect the number of minibatches - the model must be trained on for it to converge on the data set. For users who would like to - determine this number experimentally, train a model with reasonable hyperparameters using the - ``single`` search method. +- ``time_metric``, ``max_time``: The name of the "time" metric and the maximum value it will take + for a trial that survives to the end of the experiment (see :ref:`Training Units + `). Note that the searcher will expect this metric to + appear in validation metrics reported by the model. This quantity is domain-specific and should + roughly reflect the number of minibatches the model must be trained on for it to converge on the + dataset. For users who would like to determine this number experimentally, train a model with + reasonable hyperparameters using the ``single`` search method. - ``max_trials``: This indicates the total number of hyperparameter settings that will be evaluated in the experiment. Set ``max_trials`` to at least 500 to take advantage of speedups from @@ -85,7 +86,8 @@ At the end, the trial with best performance has the hyperparameter setting the S In the example above, we generalize "halving" with a field called divisor, which determines what fraction of trials are kept in successive rungs, as well as the training length in successive rungs. -``max_length`` is 16 epochs, which is the maximum length a trial is trained for. +In this example, ``time_metric`` would probably be "epochs" and ``max_time`` would be 16, since the +maximum time a trial is trained for is 16 epochs. In general, SHA has a fixed ``divisor`` d. In the first rung, it generates an initial set of randomly chosen trials and runs until each trial has trained for the same length. In the next rung, @@ -131,9 +133,9 @@ On one end, ``aggressive`` applies early stopping in a very eager manner; this m corresponds to only making a single call to ASHA. With the default ``divisor`` of 4, 75% of the remaining trials will be eliminated in each rung after only being trained for 25% the length of the next rung. This implies that relatively few trials will be allowed to finish even a small fraction -of the length needed train to convergence (``max_length``). This aggressive early stopping behavior -allows the searcher to start more trials for a wider exploration of hyperparameter configurations, -at the risk of discarding a configuration too soon. +of the length needed train to convergence (``time_metric``, ``max_time``). This aggressive early +stopping behavior allows the searcher to start more trials for a wider exploration of hyperparameter +configurations, at the risk of discarding a configuration too soon. On the other end, ``conservative`` mode is more similar to a ``random`` search, in that it performs significantly less pruning. Extra ASHA subroutines are spawned with fewer rungs and longer training @@ -152,22 +154,22 @@ rung (N in the above ASHA example). **Q: How do I control how long a trial is trained for before it is potentially discarded?** -The training length is guaranteed to be at least ``max_length / 256`` by default, or ``max_length / -divisor ^ max_rungs-1`` in general. It is recommended to configure this in records or epochs if the -``global_batch_size`` hyperparameter is not constant, to ensure each trial trains on the same amount -of data. +The training length is guaranteed to be at least ``max_time / 256`` by default, or ``max_time / +divisor ^ max_rungs-1`` in general. It is recommended to use records or epochs as your +``time_metric`` if your batch size is not constant across all trials, to ensure each trial trains on +the same amount of data. -**Q: How do I make sure ``x`` trials are run the full training length (``max_length``)?** +**Q: How do I make sure ``x`` trials are run the full training length (``max_time``)?** The number of initial trials is determined by a combination of ``mode``, ``max_trials``, -``divisor``, ``max_rungs``, ``max_length`` and ``bracket_rungs``. Here is a rule of thumb for the +``divisor``, ``max_rungs``, ``max_time`` and ``bracket_rungs``. Here is a rule of thumb for the default configuration of ``max_rungs: 5`` and ``divisor: 4``, with ``mode: standard`` and a large enough ``max_trials``: - The initial number of trials is ``max_trials``. -- To ensure that ``x`` trials are run ``max_length``, set ``max_trials`` high enough for the - brackets with their halving rate (the ``divisor``) to allow ``x`` trials to make it to the final +- To ensure that ``x`` trials are run ``max_time``, set ``max_time`` high enough for the brackets + with their halving rate (the ``divisor``) to allow ``x`` trials to make it to the final ``rungs``. This can be viewed by the command describe below. A configuration setting that meets set goals can be found by trial and error. The command diff --git a/docs/model-dev-guide/hyperparameter/search-methods/hp-custom.rst b/docs/model-dev-guide/hyperparameter/search-methods/hp-custom.rst deleted file mode 100644 index bf8e047cd1d..00000000000 --- a/docs/model-dev-guide/hyperparameter/search-methods/hp-custom.rst +++ /dev/null @@ -1,154 +0,0 @@ -.. _topic-guides_hp-tuning-det_custom: - -####################### - Custom Search Methods -####################### - -.. important:: - - **Deprecation Notice**: Support for all custom search methods have been deprecated and will be - removed in a future release. Please see :ref:`search-methods` for details on supported preset - searchers. - -+----------------------------------------------------------------+ -| API reference | -+================================================================+ -| :ref:`custom-searcher-reference` | -+----------------------------------------------------------------+ - -Determined supports defining your own hyperparameter search algorithms and provides search runner -utilities for executing them. - -.. tip:: - - Remember that a Determined experiment is a set of trials, each corresponding to a point in the - hyperparameter space. - -To implement a custom hyperparameter tuning algorithm, subclass -:class:`~determined.searcher.SearchMethod`, overriding its event handler methods. If you want to -achieve fault tolerance and your search method carries any state in addition to the SearcherState -passed into the event handlers, also override -:meth:`~determined.searcher.SearchMethod.save_method_state` and -:meth:`~determined.searcher.SearchMethod.load_method_state`. - -To run the custom hyperparameter tuning algorithm, you can use: - -- :class:`~determined.searcher.LocalSearchRunner` to run on your machine, -- :class:`~determined.searcher.RemoteSearchRunner` to run on a Determined cluster. - -.. note:: - - Using :class:`~determined.searcher.RemoteSearchRunner` will create two experiments, with one - orchestrating the hyperparameter search of the other. - -Both search runners execute the custom hyperparameter tuning algorithm and start a multi-trial -experiment on a Determined cluster. - -The following sections describe the steps needed to implement and use a custom hyperparameter search -algorithm. - -********************************************** - Experiment Configuration for Custom Searcher -********************************************** - -Specify the ``custom`` searcher type in the experiment configuration: - -.. code:: yaml - - searcher: - name: custom - metric: validation_loss - smaller_is_better: true - unit: batches - -*********************************** - Run Hyperparameter Search Locally -*********************************** - -A script performing hyperparameter tuning using :class:`~determined.searcher.LocalSearchRunner` may -look like the following ``run_local_searcher.py``: - -.. code:: python - - import logging - from pathlib import Path - from determined import searcher - - - if __name__ == "__main__": - # The content of the following directory is uploaded to Determined cluster. - # It should include all files necessary to run the experiment (as usual). - model_context_dir = "experiment_files" - - # Path to the .yaml file with the multi-trial experiment configuration. - model_config = "experiment_files/config.yaml" - - # While LocalSearchRunner saves its own state and ensures invoking save() and - # load() methods when necessary, a user is responsible for implementing - # SearchMethod.save_method_state() and SearchMethod.load_method_state() to ensure - # correct resumption of the SearchMethod. - searcher_dir = Path("local_search_runner/searcher_dir") - - # Instantiate your search method, passing the necessary parameters. - search_method = MySearchMethod(...) - - search_runner = searcher.LocalSearchRunner(search_method, searcher_dir=searcher_dir) - - experiment_id = search_runner.run(model_config, model_dir=model_context_dir) - logging.info(f"Experiment {experiment_id} has been completed.") - -To start the custom search method locally, you can use the following CLI command: - -.. code:: bash - - $ python run_local_searcher.py - -**************************************** - Run Hyperparameter Search on a Cluster -**************************************** - -A script to run your custom search method on a Determined cluster may look like the following -``run_remote_searcher.py``: - -.. code:: python - - import determined as det - from pathlib import Path - from determined import searcher - - if __name__ == "__main__": - model_context_dir = "experiment_files" - - model_config = "experiment_files/config.yaml" - - with det.core.init() as core_context: - info = det.get_cluster_info() - assert info is not None - - search_method = MySearchMethod(...) - - search_runner = searcher.RemoteSearchRunner(search_method, context=core_context) - search_runner.run(model_config, model_dir=model_context_dir) - -To start the custom search method on a cluster, you need to submit it to the master as a -single-trial experiment. To this end, you can use the following CLI command: - -.. code:: bash - - $ det e create searcher_config.yaml context_dir - -The custom search method runs on a Determined cluster as a single trial experiment. Configuration -for the search method experiment is specified in the ``searcher_config.yaml`` and may look like -this: - -.. code:: yaml - - name: remote-searcher - entrypoint: python3 run_remote_searcher.py - searcher: - metric: validation_error - smaller_is_better: true - name: single - max_length: - batches: 1000 - max_restarts: 0 diff --git a/docs/model-dev-guide/hyperparameter/search-methods/hp-grid.rst b/docs/model-dev-guide/hyperparameter/search-methods/hp-grid.rst index 9bccc326b37..b0e590c0956 100644 --- a/docs/model-dev-guide/hyperparameter/search-methods/hp-grid.rst +++ b/docs/model-dev-guide/hyperparameter/search-methods/hp-grid.rst @@ -5,14 +5,13 @@ ############# The ``grid`` search method generates trials on a "grid" of hyperparameter configurations and trains -each trial for the number of training units specified by ``max_length``. The user specifies a set of -values for each hyperparameter via the ``hyperparameters`` field in the :ref:`experiment -configuration `. The "grid" of hyperparameter -configurations is generated by taking the `product -`__ of these sets. For example, if the set of -values for three separate hyperparameters ``aparam``, ``bparam``, and ``cparam`` are specified as -``{0, 1, 2}``, ``{10, 20}``, and ``{"c"}`` respectively, then the grid of tuples ``(aparam, bparam, -cparam)`` generated is: +each trial to completion. The user specifies a set of values for each hyperparameter via the +``hyperparameters`` field in the :ref:`experiment configuration +`. The "grid" of hyperparameter configurations is +generated by taking the `product `__ of these sets. +For example, if the set of values for three separate hyperparameters ``aparam``, ``bparam``, and +``cparam`` are specified as ``{0, 1, 2}``, ``{10, 20}``, and ``{"c"}`` respectively, then the grid +of tuples ``(aparam, bparam, cparam)`` generated is: .. code:: diff --git a/docs/model-dev-guide/hyperparameter/search-methods/hp-random.rst b/docs/model-dev-guide/hyperparameter/search-methods/hp-random.rst index 0d453b625a5..c8d20108a34 100644 --- a/docs/model-dev-guide/hyperparameter/search-methods/hp-random.rst +++ b/docs/model-dev-guide/hyperparameter/search-methods/hp-random.rst @@ -5,8 +5,6 @@ ############### The ``random`` search method generates ``max_trials`` trials with hyperparameters chosen uniformly -at random from the configured hyperparameter space. Each trial is trained for the number of units -specified by ``max_length`` (see :ref:`Training Units `) -and then then the trial's validation metrics are computed. +at random from the configured hyperparameter space. Each trial is trained to completion. See :ref:`Experiment Configuration `. diff --git a/docs/model-dev-guide/hyperparameter/search-methods/hp-single.rst b/docs/model-dev-guide/hyperparameter/search-methods/hp-single.rst index f8d6ca7f9a1..7546fbfff92 100644 --- a/docs/model-dev-guide/hyperparameter/search-methods/hp-single.rst +++ b/docs/model-dev-guide/hyperparameter/search-methods/hp-single.rst @@ -5,8 +5,7 @@ ###################### The ``single`` search method does a very minimal "search": it trains a single hyperparameter -configuration for the number of units specified by ``max_length`` (see :ref:`Training Units -`) and then performs validation. This method is useful for -testing or for training a single model configuration until convergence. +configuration to completion. This method is useful for testing or for training a single model +configuration until convergence. See :ref:`Experiment Configuration `. diff --git a/docs/model-dev-guide/profiling.rst b/docs/model-dev-guide/profiling.rst index 277c9adf719..f59c2452c2a 100644 --- a/docs/model-dev-guide/profiling.rst +++ b/docs/model-dev-guide/profiling.rst @@ -82,9 +82,9 @@ training code. Identifying inefficiencies in individual training operations or s fine-grained context than generic system metrics can provide. For this level of profiling, Determined supports integration with training profilers that are native to their frameworks: -- PyTorch Profiler (:ref:`PyTorch API `) -- DeepSpeed Profiler (:ref:`DeepSpeed API `) -- TensorFlow Keras Profiler (:ref:`Keras API `) +- :ref:`PyTorch Profiler ` +- :ref:`DeepSpeed Profiler ` +- :class:`Keras TensorBoard callback ` Please see your framework's profiler documentation and the Determined Training API guide for usage details. diff --git a/docs/reference/_index.rst b/docs/reference/_index.rst index 3475640054d..6d4022ae26c 100644 --- a/docs/reference/_index.rst +++ b/docs/reference/_index.rst @@ -71,4 +71,3 @@ Python SDK REST API Determined CLI Reference - Customer Searcher Reference diff --git a/docs/reference/custom-searcher-reference.rst b/docs/reference/custom-searcher-reference.rst deleted file mode 100644 index 02c111abb19..00000000000 --- a/docs/reference/custom-searcher-reference.rst +++ /dev/null @@ -1,89 +0,0 @@ -.. _custom-searcher-reference: - -########################### - Custom Searcher Reference -########################### - -.. important:: - - **Deprecation Notice**: Custom Searcher is deprecated and will be removed in a future release. - Please see :ref:`search-methods` for details on supported preset searchers. - -******************************************* - ``determined.searcher.LocalSearchRunner`` -******************************************* - -.. autoclass:: determined.searcher.LocalSearchRunner - :members: - :member-order: bysource - -******************************************** - ``determined.searcher.RemoteSearchRunner`` -******************************************** - -.. autoclass:: determined.searcher.RemoteSearchRunner - :members: - :member-order: bysource - -************************************** - ``determined.searcher.SearchMethod`` -************************************** - -.. autoclass:: determined.searcher.SearchMethod - :members: - -*************************************** - ``determined.searcher.SearcherState`` -*************************************** - -.. autoclass:: determined.searcher.SearcherState - :members: - -*********************************** - ``determined.searcher.Operation`` -*********************************** - -.. autoclass:: determined.searcher.Operation - :members: - -******************************* - ``determined.searcher.Close`` -******************************* - -.. autoclass:: determined.searcher.Close - :members: - -********************************** - ``determined.searcher.Progress`` -********************************** - -.. autoclass:: determined.searcher.Progress - :members: - -******************************** - ``determined.searcher.Create`` -******************************** - -.. autoclass:: determined.searcher.Create - :members: - -*************************************** - ``determined.searcher.ValidateAfter`` -*************************************** - -.. autoclass:: determined.searcher.ValidateAfter - :members: - -********************************** - ``determined.searcher.Shutdown`` -********************************** - -.. autoclass:: determined.searcher.Shutdown - :members: - -************************************** - ``determined.searcher.ExitedReason`` -************************************** - -.. autoclass:: determined.searcher.ExitedReason - :members: diff --git a/docs/reference/experiment-config-reference.rst b/docs/reference/experiment-config-reference.rst index 5ba662831e6..65896f57cb4 100644 --- a/docs/reference/experiment-config-reference.rst +++ b/docs/reference/experiment-config-reference.rst @@ -17,84 +17,6 @@ example: det experiment create config-file.yaml model-directory -**************** - Training Units -**************** - -Some configuration settings, such as searcher training lengths and budgets, -``min_validation_period``, and ``min_checkpoint_period``, can be expressed in terms of a few -training units: records, batches, or epochs. - -- ``records``: A *record* is a single labeled example (sometimes called a sample). -- ``batches``: A *batch* is a group of records. The number of records in a batch is configured via - the ``global_batch_size`` hyperparameter. -- ``epoch``: An *epoch* is a single copy of the entire training data set. - -For example, to specify the ``max_length`` for a searcher in terms of batches, the configuration -would read as shown below. - -.. code:: yaml - - max_length: - batches: 900 - -To express it in terms of records or epochs, ``records`` or ``epochs`` would be specified in place -of ``batches``. For :class:`~determined.pytorch.deepspeed.DeepSpeedTrial` and -:class:`~determined.keras.TFKerasTrial`, :ref:`records_per_epoch ` must -also be specified if using epochs. Below is an example that configures a ``single`` searcher to -train a model for 64 epochs. - -.. code:: yaml - - records_per_epoch: 50000 - searcher: - name: single - metric: validation_error - max_length: - epochs: 64 - smaller_is_better: true - -The configured :ref:`records_per_epoch ` is only used for interpreting -configuration fields that are expressed in epochs. Actual epoch boundaries are still determined by -the dataset itself (specifically, the end of an epoch occurs when the training data loader runs out -of records). - -.. note:: - - When the amount of training data for a model is specified using records or epochs, and the batch - size does not evenly divide the configured number of inputs, the remaining "partial batch" of - data will be dropped (ignored). For example, if an experiment is configured to train a single - model on 10 records with a batch size of 3, the model will be trained on only 9 records of data. - In the special case where a trial is configured to train on less than a single batch of data, a - single complete batch will be used instead. - -Training Unit Conversion Limitations (Caveats) -============================================== - -In most cases, values expressed in one type of training unit can be converted to another type while -maintaining the same behavior. However, there are some limitations to consider: - -- Since training units must be positive integers, it is not always possible to convert between - different types of units. For example, converting 50 ``records`` into batches is not possible if - the batch size is 64. - -- When performing a hyperparameter search over a range of values for ``global_batch_size``, the - specified ``batches`` cannot be converted to a fixed number of records or epochs and hence cause - different behaviors in different trials of the search. - -- When using :ref:`adaptive_asha `, a single training - unit is treated as atomic (unable to be divided into fractional parts) when dividing - ``max_length`` into the series of rounds (or rungs) by which we early-stop underperforming - trials. This rounding may result in unexpected behavior when configuring ``max_length`` with a - small number of large epochs or batches. - -To verify your search is working as intended before committing to a full run, you can use the CLI's -"preview search" feature: - -.. code:: - - det preview-search - ********** Metadata ********** @@ -200,30 +122,31 @@ field is empty. Arbitrary Script ---------------- -Required. An arbitrary entrypoint script name. +Required. An arbitrary entrypoint script with args. Example: .. code:: yaml - entrypoint: ./hello.sh + entrypoint: ./hello.sh args... Preconfigured Launch Module with Script --------------------------------------- -Required. The name of a preconfigured launch module and script name. +Required. The name of a preconfigured launch module and script with args. Example: .. code:: yaml - entrypoint: python3 -m (LAUNCH_MODULE) train.py + entrypoint: python3 -m (LAUNCH_MODULE) train.py args... ``LAUNCH_MODULE`` options: - Horovod (determined.launch.horovod) - PyTorch (determined.launch.torch_distributed) - Deepspeed (determined.launch.deepspeed) +- TensorFlow (determined.launch.tensorflow) Preconfigured Launch Module with Legacy Trial Definition -------------------------------------------------------- @@ -273,20 +196,6 @@ preemption, in the unit of batches. The number of records in a batch is controll - As a rule of thumb, it should be set to the number of batches that can be trained in roughly 60--180 seconds. -.. _config-records-per-epoch: - -``records_per_epoch`` -===================== - -Optional. The number of records in the training data set. It must be configured if you want to -specify ``min_validation_period``, ``min_checkpoint_period``, and ``searcher.max_length`` in units -of ``epochs``. - -.. note:: - - For :class:`~determined.pytorch.PyTorchTrial`, epoch length is automatically determined using the - chief worker's dataset length, and this value will be ignored. - .. _max-restarts: ``max_restarts`` @@ -426,8 +335,8 @@ Optional. Specifies the minimum frequency at which validation should be run for epochs: 2 - :class:`~determined.pytorch.deepspeed.DeepSpeedTrial` and - :class:`~determined.keras.TFKerasTrial`: If this is in the unit of epochs, - :ref:`records_per_epoch ` must be specified. + :class:`~determined.keras.TFKerasTrial`: If this is in the unit of epochs, ``records_per_epoch`` + must be specified. .. _experiment-config-perform-initial-validation: @@ -468,7 +377,7 @@ Optional. Specifies the minimum frequency for running checkpointing for each tri - :class:`~determined.pytorch.deepspeed.DeepSpeedTrial` and :class:`~determined.keras.TFKerasTrial`: If the unit is in epochs, you must also specify - :ref:`records_per_epoch `. + ``records_per_epoch``. ``checkpoint_policy`` ===================== @@ -915,8 +824,6 @@ example, to configure a ``random`` hyperparameter search that trains 5 trials fo name: random metric: accuracy max_trials: 5 - max_length: - batches: 1000 For details on using Determined to perform hyperparameter search, refer to :ref:`hyperparameter-tuning`. For more information on the search methods supported by Determined, @@ -940,23 +847,6 @@ configuration. .. _experiment-configuration_single-searcher-max-length: -``max_length`` --------------- - -Required. The length of the trial. - -- This needs to be set in the unit of records, batches, or epochs using a nested dictionary. For - example: - - .. code:: yaml - - max_length: - epochs: 2 - -- :class:`~determined.pytorch.deepspeed.DeepSpeedTrial` and - :class:`~determined.keras.TFKerasTrial`: If this is in the unit of epochs, - :ref:`records_per_epoch ` must be specified. - **Optional Fields** ``smaller_is_better`` @@ -999,23 +889,6 @@ configuration. Required. The number of trials, i.e., hyperparameter configurations, to evaluate. -``max_length`` --------------- - -Required. The length of each trial. - -- This needs to be set in the unit of records, batches, or epochs using a nested dictionary. For - example: - - .. code:: yaml - - max_length: - epochs: 2 - -- :class:`~determined.pytorch.deepspeed.DeepSpeedTrial` and - :class:`~determined.keras.TFKerasTrial`: If this is in the unit of epochs, - :ref:`records_per_epoch ` must be specified. - **Optional Fields** ``smaller_is_better`` @@ -1056,24 +929,69 @@ specified via the ``hyperparameters`` field. For more details see the Required. The name of the validation metric used to evaluate the performance of a hyperparameter configuration. -``max_length`` --------------- +**Optional Fields** -Required. The length of each trial. +``smaller_is_better`` +--------------------- -- This needs to be set in the unit of records, batches, or epochs using a nested dictionary. For - example: +Optional. Whether to minimize or maximize the metric defined above. The default value is ``true`` +(minimize). - .. code:: yaml +``max_concurrent_trials`` +------------------------- - max_length: - epochs: 2 +Optional. The maximum number of trials that can be worked on simultaneously. The default value is +``16``. When the value is ``0`` we will work on as many trials as possible. -- :class:`~determined.pytorch.deepspeed.DeepSpeedTrial` and - :class:`~determined.keras.TFKerasTrial`: If this is in the unit of epochs, - :ref:`records_per_epoch ` must be specified. +``source_trial_id`` +------------------- -**Optional Fields** +Optional. If specified, the weights of this trial will be initialized to the most recent checkpoint +of the given trial ID. This will fail if the source trial's model architecture is inconsistent with +the model architecture of this experiment. + +``source_checkpoint_uuid`` +-------------------------- + +Optional. Like ``source_trial_id``, but specifies an arbitrary checkpoint from which to initialize +weights. At most one of ``source_trial_id`` or ``source_checkpoint_uuid`` should be set. + +.. _experiment-configuration-searcher-asha: + +Asynchronous Halving (ASHA) +=========================== + +The ``async_halving`` search performs a version of the asynchronous successive halving algorithm +(`ASHA `_) that stops trials early if there is enough evidence +to terminate training. Once trials are stopped, they will not be resumed. + +``metric`` +---------- + +Required. The name of the validation metric used to evaluate the performance of a hyperparameter +configuration. + +``time_metric`` +--------------- + +Required. The name of the validation metric used to evaluate the progress of a given trial. + +``max_time`` +------------ + +Required. The maximum value that ``time_metric`` should take when a trial finishes training. Early +stopping is decided based on how far the ``time_metric`` has progressed towards this ``max_time`` +value. + +``max_trials`` +-------------- + +Required. The number of trials, i.e., hyperparameter configurations, to evaluate. + +``num_rungs`` +------------- + +Required. The number of rounds of successive halving to perform. ``smaller_is_better`` --------------------- @@ -1081,18 +999,26 @@ Required. The length of each trial. Optional. Whether to minimize or maximize the metric defined above. The default value is ``true`` (minimize). +``divisor`` +----------- + +Optional. The fraction of trials to keep at each rung, and also determines the training length for +each rung. The default setting is ``4``; only advanced users should consider changing this value. + ``max_concurrent_trials`` ------------------------- Optional. The maximum number of trials that can be worked on simultaneously. The default value is -``16``. When the value is ``0`` we will work on as many trials as possible. +``16``, and we set reasonable values depending on ``max_trials`` and the number of rungs in the +brackets. This is akin to controlling the degree of parallelism of the experiment. If this value is +less than the number of brackets produced by the adaptive algorithm, it will be rounded up. ``source_trial_id`` ------------------- -Optional. If specified, the weights of this trial will be initialized to the most recent checkpoint -of the given trial ID. This will fail if the source trial's model architecture is inconsistent with -the model architecture of this experiment. +Optional. If specified, the weights of *every* trial in the search will be initialized to the most +recent checkpoint of the given trial ID. This will fail if the source trial's model architecture is +inconsistent with the model architecture of any of the trials in this experiment. ``source_checkpoint_uuid`` -------------------------- @@ -1115,25 +1041,17 @@ experiments with hundreds or thousands of trials. Required. The name of the validation metric used to evaluate the performance of a hyperparameter configuration. -``max_length`` --------------- - -Required. The maximum training length of any one trial. The vast majority of trials will be stopped -early, and thus only a small fraction of trials will actually be trained for this long. This -quantity is domain-specific and should roughly reflect the length of training needed for the model -to converge on the data set. - -- This needs to be set in the unit of records, batches, or epochs using a nested dictionary. For - example: +``time_metric`` +--------------- - .. code:: yaml +Required. The name of the validation metric used to evaluate the progress of a given trial. - max_length: - epochs: 2 +``max_time`` +------------ -- :class:`~determined.pytorch.deepspeed.DeepSpeedTrial` and - :class:`~determined.keras.TFKerasTrial`: If this is in the unit of epochs, - :ref:`records_per_epoch ` must be specified. +Required. The maximum value that ``time_metric`` should take when a trial finishes training. Early +stopping is decided based on how far the ``time_metric`` has progressed towards this ``max_time`` +value. ``max_trials`` -------------- @@ -1159,14 +1077,6 @@ end of the spectrum, ``conservative`` mode performs significantly less downsampl consequence does not explore as many configurations given the same budget. We recommend using either ``aggressive`` or ``standard`` mode. -``stop_once`` -------------- - -Optional. If ``stop_once`` is set to ``true``, we will use a variant of ASHA that will not resume -trials once stopped. This variant defaults to continuing training and will only stop trials if there -is enough evidence to terminate training. We recommend using this version of ASHA when training a -trial for the max length as fast as possible is important or when fault tolerance is too expensive. - ``divisor`` ----------- diff --git a/docs/reference/training/_index.rst b/docs/reference/training/_index.rst index 84ace82203d..aa221c7b906 100644 --- a/docs/reference/training/_index.rst +++ b/docs/reference/training/_index.rst @@ -15,6 +15,7 @@ - :ref:`det.pytorch.samplers ` - :ref:`det.pytorch.deepspeed ` - :ref:`det.keras ` +- :ref:`det.transformers ` ******************************* Experiment Configuration File diff --git a/docs/reference/training/api-core-reference.rst b/docs/reference/training/api-core-reference.rst index 5bb74671a43..49513066944 100644 --- a/docs/reference/training/api-core-reference.rst +++ b/docs/reference/training/api-core-reference.rst @@ -99,10 +99,3 @@ ************************************* .. autoclass:: determined.core.TensorboardMode - -************************** - ``determined.TrialInfo`` -************************** - -.. autoclass:: determined.TrialInfo - :members: diff --git a/docs/reference/training/api-det-reference.rst b/docs/reference/training/api-det-reference.rst index 857d4bcaf1d..aa5c1e5d9d6 100644 --- a/docs/reference/training/api-det-reference.rst +++ b/docs/reference/training/api-det-reference.rst @@ -13,6 +13,13 @@ .. autoclass:: determined.ClusterInfo :members: +************************** + ``determined.TrialInfo`` +************************** + +.. autoclass:: determined.TrialInfo + :members: + ********************************* ``determined.import_from_path`` ********************************* diff --git a/docs/reference/training/api-keras-reference.rst b/docs/reference/training/api-keras-reference.rst index b24db8644e0..150722d88be 100644 --- a/docs/reference/training/api-keras-reference.rst +++ b/docs/reference/training/api-keras-reference.rst @@ -10,6 +10,29 @@ | :ref:`api-keras-ug` | +-------------------------------------------------+ +***************************************** + ``determined.keras.DeterminedCallback`` +***************************************** + +.. autoclass:: determined.keras.DeterminedCallback + :members: save_model, load_model + :member-order: bysource + :special-members: __init__ + +********************************** + ``determined.keras.TensorBoard`` +********************************** + +.. autoclass:: determined.keras.TensorBoard + +################# + Deprecated APIs +################# + +The following APIs have been deprecated as of Determined 0.38.0 and will be removed in a future +version. Please migrate your ``TFKerasTrial``-based training to use the new +:class:`~determined.keras.DeterminedCallback` instead. + *********************************** ``determined.keras.TFKerasTrial`` *********************************** diff --git a/docs/reference/training/api-transformers-reference.rst b/docs/reference/training/api-transformers-reference.rst new file mode 100644 index 00000000000..8ded8cf68c6 --- /dev/null +++ b/docs/reference/training/api-transformers-reference.rst @@ -0,0 +1,11 @@ +.. _transformers-reference: + +#################################### + ``det.transformers`` API Reference +#################################### + +***************************************** + ``determined.transformers.DetCallback`` +***************************************** + +.. autoclass:: determined.transformers.DetCallback diff --git a/docs/release-notes.rst b/docs/release-notes.rst index 2a741eb64e4..98bf8843186 100644 --- a/docs/release-notes.rst +++ b/docs/release-notes.rst @@ -1408,8 +1408,8 @@ Version 0.23.0 **New Features** -- Experiment: :ref:`Custom hyperparameter searchers ` can - include extra directories to pass into the ``client.create_experiment`` context. +- Experiment: Custom hyperparameter searchers can include extra directories to pass into the + ``client.create_experiment`` context. - Checkpoints: Add support for deleting a subset of files from checkpoints. @@ -2063,8 +2063,7 @@ Version 0.19.6 - Custom Searcher: users can now define their own logic to coordinate across multiple trials within an experiment. Examples of use cases are custom hyperparameter searching algorithms, ensembling, - active learning, neural architecture search, reinforcement learning. See - :ref:`topic-guides_hp-tuning-det_custom` for more information. + active learning, neural architecture search, reinforcement learning. - Cluster: The enterprise edition of `HPE Machine Learning Development Environment `_ can now be diff --git a/docs/release-notes/remove-custom-searcher.rst b/docs/release-notes/remove-custom-searcher.rst new file mode 100644 index 00000000000..3e6c1a642e5 --- /dev/null +++ b/docs/release-notes/remove-custom-searcher.rst @@ -0,0 +1,7 @@ +:orphan: + +**Breaking Changes** + +- API: Custom Searcher (including DeepSpeed AutoTune) was deprecated in 0.36.0 and is now removed. + We will maintain first-class support for a variety of preset searchers, which can be easily + configured for any experiment. Visit :ref:`search-methods` for details. diff --git a/docs/tools/tensorboard.rst b/docs/tools/tensorboard.rst index 7263b66f0a4..8f0796e4d58 100644 --- a/docs/tools/tensorboard.rst +++ b/docs/tools/tensorboard.rst @@ -138,20 +138,8 @@ To configure TensorBoard for a specific framework, follow the examples below: TensorFlow Keras ================ -For models using :class:`~determined.keras.TFKerasTrial`, add a -:class:`determined.keras.callabacks.TensorBoard` callback to your trial class: - -.. code:: python - - from determined.keras import TFKerasTrial - from determined.keras.callbacks import TensorBoard - - - class MyModel(TFKerasTrial): - ... - - def keras_callbacks(self): - return [TensorBoard()] +For models using :class:`~determined.keras.DeterminedCallback`, include a +:class:`determined.keras.TensorBoard` callback in your ``model.fit()`` call.: PyTorch ======= @@ -196,10 +184,9 @@ Any additional TFEvent files that are written to the appropriate path during tra to TensorBoard. The appropriate path varies by worker rank and can be obtained by one of the following functions: -- For CoreAPI users: :func:`~determined.core.TrainContext.get_tensorboard_path` +- For CoreAPI and Keras users: :func:`~determined.core.TrainContext.get_tensorboard_path` - For PyTorchTrial users: :func:`~determined.pytorch.PyTorchTrialContext.get_tensorboard_path` - For DeepSpeedTrial users: :func:`~determined.pytorch.deepspeed.DeepSpeedTrialContext.get_tensorboard_path` -- For TFKerasTrial users: :func:`~determined.keras.TFKerasTrialContext.get_tensorboard_path` For more details and examples, refer to the :ref:`TensorBoard How-To Guide `. diff --git a/docs/tutorials/detached-mode/transition-managed-determined.rst b/docs/tutorials/detached-mode/transition-managed-determined.rst index 45ac356a6d8..4c3c6ef56b1 100644 --- a/docs/tutorials/detached-mode/transition-managed-determined.rst +++ b/docs/tutorials/detached-mode/transition-managed-determined.rst @@ -72,8 +72,6 @@ Use the following code to create the experiment configuration file: name: single # metric is required but it shouldn't hurt to ignore it at this point. metric: x - # max_length is ignored if the training script ignores it. - max_length: 1 max_restarts: 0 diff --git a/docs/tutorials/pachyderm-cat-dog.rst b/docs/tutorials/pachyderm-cat-dog.rst index 7ecd65475da..606005961ea 100644 --- a/docs/tutorials/pachyderm-cat-dog.rst +++ b/docs/tutorials/pachyderm-cat-dog.rst @@ -243,8 +243,6 @@ The configuration should resemble the following: searcher: name: single metric: accuracy - max_length: - batches: 100 smaller_is_better: false entrypoint: model_def:CatDogModel scheduling_unit: 10 diff --git a/docs/tutorials/pytorch-mnist-tutorial.rst b/docs/tutorials/pytorch-mnist-tutorial.rst index 8f441e7d980..889f533cbca 100644 --- a/docs/tutorials/pytorch-mnist-tutorial.rst +++ b/docs/tutorials/pytorch-mnist-tutorial.rst @@ -261,8 +261,6 @@ fixed values for the model's hyperparameters: .. code:: yaml name: mnist_pytorch_const - data: - url: https://s3-us-west-2.amazonaws.com/determined-ai-test-data/pytorch_mnist.tar.gz hyperparameters: learning_rate: 1.0 global_batch_size: 64 @@ -273,15 +271,8 @@ fixed values for the model's hyperparameters: searcher: name: single metric: validation_loss - max_length: - epochs: 1 smaller_is_better: true - entrypoint: model_def:MNistTrial - -The ``entrypoint`` specifies the name of the trial class to use. This is useful if the model code -contains more than one trial class. In this case, we use an entrypoint of ``model_def:MNistTrial`` -because our trial class is named ``MNistTrial`` and it is defined in a Python file named -``model_def.py``. + entrypoint: python3 train.py --epochs 1 For more information on experiment configuration, see the :ref:`experiment configuration reference `. diff --git a/docs/tutorials/quickstart-mdldev.rst b/docs/tutorials/quickstart-mdldev.rst index 0bdaaba2da3..25b85ad4541 100644 --- a/docs/tutorials/quickstart-mdldev.rst +++ b/docs/tutorials/quickstart-mdldev.rst @@ -323,8 +323,8 @@ The ``adaptive_asha`` search method and maximum number of trials, max_trials` ar metric: validation_loss smaller_is_better: true max_trials: 16 - max_length: - batches: 937 + time_metric: batch + max_time: 937 This example uses a fixed batch size and searches on dropout size, filters, and learning rate. The ``max_trials`` setting of ``16`` indicates how many model configurations to explore. diff --git a/docs/tutorials/viewing-epoch-based-metrics.rst b/docs/tutorials/viewing-epoch-based-metrics.rst index 9481b8efbba..43a94ba869e 100644 --- a/docs/tutorials/viewing-epoch-based-metrics.rst +++ b/docs/tutorials/viewing-epoch-based-metrics.rst @@ -12,8 +12,8 @@ Sometimes, you want to analyze and visualize your model's training progress and performance over multiple epochs. In this article, we'll show you how to view epoch-based metric data in the WebUI by reporting an -epoch metric to the Determined master via the Core API. To do this, we'll define an epoch metric and -use it as the X-Axis label in the WebUI. +``epochs`` metric to the Determined master via the Core API. To do this, we'll define an ``epochs`` +metric and use it as the X-Axis label in the WebUI. **Recommended** @@ -73,7 +73,7 @@ In the WebUI, we can select our experiment and visit the **Logs** tab. Step 2: Report Epoch-Based Metrics ************************************ -In this section, we'll define our epoch metric. +In this section, we'll define our ``epochs`` metric. - To follow along, use the ``model_def_metrics.py`` script and its accompanying ``metrics.yaml`` experiment configuration file. @@ -96,27 +96,31 @@ training and validation metrics. However, we also want to report epoch-based metrics and to allow Determined to keep track of the specific epoch for which training loss is being reported. -- To do this, we'll modify the train() method to include ``epoch_idx`` as a metric: +- To do this, we'll modify the train() method to include ``epochs`` as a metric. We will calculate + fractional completed epochs based on ``batches_completed``, since this training code reports more + frequently than once per epoch: .. code:: python + partial_epoch = batches_completed / len(training_loader) core_context.train.report_training_metrics( steps_completed=batches_completed + epoch_idx * len(train_loader), - metrics={"train_loss": loss.item(), "epoch": epoch_idx}, + metrics={"train_loss": loss.item(), "epochs": epoch_idx + partial_epoch}, ) -- Similarly, we'll include ``epoch`` as a metric in the reported validation metrics. This allows +- Similarly, we'll include ``epochs`` as a metric in the reported validation metrics. This allows Determined to track the specific epoch for which the validation loss is being reported: .. code:: python + epochs_completed = epoch_idx + 1 core_context.train.report_validation_metrics( steps_completed=steps_completed, - metrics={"test_loss": test_loss, "epoch": epoch}, + metrics={"test_loss": test_loss, "epochs": epochs_completed}, ) -Now that we've reported an epoch value, **Epoch** will be an available option for the X-Axis when we -view our metric data graph in the WebUI. +Now that we've reported an ``epochs`` metric, **Epochs** will be an available option for the X-Axis +when we view our metric data graph in the WebUI. Step 2.2: Run the Experiment & View Epoch-Based Metrics ======================================================= @@ -133,7 +137,7 @@ Our modified script is ready to report epoch-based metrics to the Determined mas Our experiment opens in the **Overview** tab. -- We'll go to the **Metrics** tab, select the **X-Axis** menu and then choose **Epoch**. +- We'll go to the **Metrics** tab, select the **X-Axis** menu and then choose **Epochs**. - If we scroll down, we'll be able to see the epoch-based metrics graph. .. image:: ../assets/images/webui-metrics-epoch-based.png diff --git a/e2e_tests/tests/cluster/test_slurm.py b/e2e_tests/tests/cluster/test_slurm.py index 54b26f0de2f..5a8166e0d56 100644 --- a/e2e_tests/tests/cluster/test_slurm.py +++ b/e2e_tests/tests/cluster/test_slurm.py @@ -177,8 +177,8 @@ def test_docker_login() -> None: def test_mnist_pytorch_distributed() -> None: sess = api_utils.user_session() config = conf.load_config(conf.tutorials_path("mnist_pytorch/distributed.yaml")) - config["searcher"]["max_length"] = {"epochs": 1} - config["records_per_epoch"] = 64 + assert "--epochs 1" in config["entrypoint"], "update test to match tutorial" + config["entrypoint"] = config["entrypoint"].replace("--epochs 1", "--batches 64") config["max_restarts"] = 0 exp.run_basic_test_with_temp_config(sess, config, conf.fixtures_path("mnist_pytorch"), 1) diff --git a/e2e_tests/tests/cluster/test_users.py b/e2e_tests/tests/cluster/test_users.py index cef3572c2e5..5609bce09d2 100644 --- a/e2e_tests/tests/cluster/test_users.py +++ b/e2e_tests/tests/cluster/test_users.py @@ -547,7 +547,6 @@ def test_non_root_experiment(tmp_path: pathlib.Path) -> None: "searcher": { "name": "single", "metric": "x", - "max_length": 1, }, } exp_ref = noop.create_experiment(sess, config=config) diff --git a/e2e_tests/tests/config.py b/e2e_tests/tests/config.py index 212ef35d2d9..cfac33d5fd9 100644 --- a/e2e_tests/tests/config.py +++ b/e2e_tests/tests/config.py @@ -1,6 +1,6 @@ import os import pathlib -from typing import Any, Dict, List, Union +from typing import Any, Dict, List from determined.common import api, util @@ -70,10 +70,6 @@ def deepspeed_examples_path(path: str) -> str: return os.path.join(os.path.dirname(__file__), "../../examples/deepspeed", path) -def deepspeed_autotune_examples_path(path: str) -> str: - return os.path.join(os.path.dirname(__file__), "../../examples/deepspeed_autotune", path) - - def hf_trainer_examples_path(path: str) -> str: return os.path.join(os.path.dirname(__file__), "../../examples/hf_trainer_api", path) @@ -101,14 +97,6 @@ def set_slots_per_trial(config: Dict[Any, Any], slots: int) -> Dict[Any, Any]: return config -def set_max_length( - config: Dict[Any, Any], max_length: Union[Dict[str, int], int] -) -> Dict[Any, Any]: - config = config.copy() - config["searcher"]["max_length"] = max_length - return config - - def set_min_validation_period( config: Dict[Any, Any], min_validation_period: Dict[str, int] ) -> Dict[Any, Any]: diff --git a/e2e_tests/tests/experiment/__init__.py b/e2e_tests/tests/experiment/__init__.py index c5133437c0c..ed36ec370af 100644 --- a/e2e_tests/tests/experiment/__init__.py +++ b/e2e_tests/tests/experiment/__init__.py @@ -28,7 +28,6 @@ root_user_home_bind_mount, run_basic_test, run_basic_test_with_temp_config, - run_basic_autotuning_test, run_failure_test, run_failure_test_with_temp_config, s3_checkpoint_config, diff --git a/e2e_tests/tests/experiment/experiment.py b/e2e_tests/tests/experiment/experiment.py index 21ceb462e34..76814de48ba 100644 --- a/e2e_tests/tests/experiment/experiment.py +++ b/e2e_tests/tests/experiment/experiment.py @@ -62,59 +62,6 @@ def create_experiment( return int(m.group(1)) -def maybe_run_autotuning_experiment( - sess: api.Session, - config_file: str, - model_def_file: str, - create_args: Optional[List[str]] = None, - search_method_name: str = "_test", - max_trials: int = 4, -) -> subprocess.CompletedProcess: - command = [ - "python3", - "-m", - "determined.pytorch.dsat", - search_method_name, - config_file, - model_def_file, - "--max-trials", - str(max_trials), - ] - - if create_args is not None: - command += create_args - - env = os.environ.copy() - env["DET_DEBUG"] = "true" - env["DET_MASTER"] = conf.make_master_url() - - return detproc.run( - sess, - command, - universal_newlines=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - env=env, - ) - - -def run_autotuning_experiment( - sess: api.Session, - config_file: str, - model_def_file: str, - create_args: Optional[List[str]] = None, - search_method_name: str = "_test", - max_trials: int = 4, -) -> int: - p = maybe_run_autotuning_experiment( - sess, config_file, model_def_file, create_args, search_method_name, max_trials - ) - assert p.returncode == 0, f"\nstdout:\n{p.stdout}\nstderr:\n{p.stderr}" - m = re.search(r"Created experiment (\d+)\n", str(p.stdout)) - assert m is not None - return int(m.group(1)) - - def archive_experiments( sess: api.Session, experiment_ids: List[int], project_id: int, name: Optional[str] = None ) -> None: @@ -805,76 +752,6 @@ def run_basic_test( return experiment_id -def run_basic_autotuning_test( - sess: api.Session, - config_file: str, - model_def_file: str, - expected_trials: Optional[int], - create_args: Optional[List[str]] = None, - max_wait_secs: int = conf.DEFAULT_MAX_WAIT_SECS, - expect_workloads: bool = True, - expect_checkpoints: bool = True, - priority: int = -1, - expect_client_failed: bool = False, - search_method_name: str = "_test", - max_trials: int = 4, -) -> int: - assert os.path.isdir(model_def_file) - orchestrator_exp_id = run_autotuning_experiment( - sess, config_file, model_def_file, create_args, search_method_name, max_trials - ) - if priority != -1: - set_priority(sess, experiment_id=orchestrator_exp_id, priority=priority) - - # Wait for the Autotuning Single Searcher ("Orchestrator") to finish - wait_for_experiment_state( - sess, - orchestrator_exp_id, - bindings.experimentv1State.COMPLETED, - max_wait_secs=max_wait_secs, - ) - assert num_active_trials(sess, orchestrator_exp_id) == 0 - verify_completed_experiment_metadata( - sess, orchestrator_exp_id, expected_trials, expect_workloads, expect_checkpoints - ) - client_exp_id = fetch_autotuning_client_experiment(sess, orchestrator_exp_id) - - # Wait for the Autotuning Custom Searcher Experiment ("Client Experiment") to finish - wait_for_experiment_state( - sess, - client_exp_id, - ( - bindings.experimentv1State.COMPLETED - if not expect_client_failed - else bindings.experimentv1State.ERROR - ), - max_wait_secs=max_wait_secs, - ) - assert num_active_trials(sess, orchestrator_exp_id) == 0 - verify_completed_experiment_metadata( - sess, orchestrator_exp_id, expected_trials, expect_workloads, expect_checkpoints - ) - return client_exp_id - - -def fetch_autotuning_client_experiment(sess: api.Session, exp_id: int) -> int: - command = ["det", "experiment", "logs", str(exp_id)] - env = os.environ.copy() - env["DET_DEBUG"] = "true" - p = detproc.run( - sess, - command, - universal_newlines=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - env=env, - ) - assert p.returncode == 0, f"\nstdout:\n{p.stdout} \nstderr:\n{p.stderr}" - m = re.search(r"Created experiment (\d+)\n", str(p.stdout)) - assert m is not None - return int(m.group(1)) - - def set_priority(sess: api.Session, experiment_id: int, priority: int) -> None: command = [ "det", diff --git a/e2e_tests/tests/experiment/noop.py b/e2e_tests/tests/experiment/noop.py index b80dddc3d66..57b28ee1dee 100644 --- a/e2e_tests/tests/experiment/noop.py +++ b/e2e_tests/tests/experiment/noop.py @@ -55,15 +55,7 @@ def to_dict(self) -> Dict[str, Any]: return {"action": "log", "base64": self.base64, "level": self.level} -class CompleteSearcherOperation: - def __init__(self, metric: float) -> None: - self.metric = metric - - def to_dict(self) -> Dict[str, Any]: - return {"action": "complete_searcher_operation", "metric": self.metric} - - -Action = Union[Exit, Sleep, Report, Checkpoint, Log, CompleteSearcherOperation] +Action = Union[Exit, Sleep, Report, Checkpoint, Log] def merge_config(old: Any, new: Any) -> Any: @@ -92,7 +84,6 @@ def generate_config( "searcher": { "name": "single", "metric": "x", - "max_length": 1, }, } @@ -161,7 +152,6 @@ def create_paused_experiment( "searcher": { "name": "single", "metric": "x", - "max_length": 1, }, "entrypoint": "echo yo", } diff --git a/e2e_tests/tests/experiment/test_core.py b/e2e_tests/tests/experiment/test_core.py index 894e2cdf5c2..29505459961 100644 --- a/e2e_tests/tests/experiment/test_core.py +++ b/e2e_tests/tests/experiment/test_core.py @@ -2,7 +2,7 @@ import pytest -from determined.common import api +from determined.common import api, experimental from determined.common.api import bindings from determined.experimental import client from tests import api_utils @@ -139,16 +139,13 @@ def test_end_to_end_adaptive() -> None: d = client.Determined._from_session(sess) exp_ref = d.get_experiment(exp_id) - top_2 = exp_ref.top_n_checkpoints(2) - top_k = exp_ref.top_n_checkpoints( - len(trials), sort_by="validation_loss", smaller_is_better=True + top_k = exp_ref.list_checkpoints( + sort_by=experimental.checkpoint.CheckpointSortBy.SEARCHER_METRIC, + order_by=experimental.OrderBy.ASCENDING, ) - top_2_uuids = [c.uuid for c in top_2] top_k_uuids = [c.uuid for c in top_k] - assert top_2_uuids == top_k_uuids[:2] - # Check that metrics are truly in sorted order. assert all(c.training is not None for c in top_k) metrics = [ @@ -160,11 +157,12 @@ def test_end_to_end_adaptive() -> None: assert metrics == sorted(metrics) # Check that changing smaller is better reverses the checkpoint ordering. - top_k_reversed = exp_ref.top_n_checkpoints( - len(trials), sort_by="validation_loss", smaller_is_better=False + top_k_reversed = exp_ref.list_checkpoints( + sort_by=experimental.checkpoint.CheckpointSortBy.SEARCHER_METRIC, + order_by=experimental.OrderBy.DESCENDING, ) - top_k_reversed_uuids = [c.uuid for c in top_k_reversed] + top_k_reversed_uuids = [c.uuid for c in top_k_reversed] assert top_k_uuids == top_k_reversed_uuids[::-1] checkpoint = top_k[0] @@ -208,28 +206,6 @@ def test_end_to_end_adaptive() -> None: assert checkpoint.metadata == db_check.metadata -@pytest.mark.e2e_cpu -def test_graceful_trial_termination() -> None: - sess = api_utils.user_session() - config = { - "hyperparameters": { - "actions": { - "1": { - "type": "categorical", - "vals": [ - # One trial completes its searcher operation. - noop.CompleteSearcherOperation(1.0).to_dict(), - # The other trial just exits 0. - noop.Exit(0).to_dict(), - ], - } - } - } - } - exp_ref = noop.create_experiment(sess, config=config) - assert exp_ref.wait(interval=0.01) == client.ExperimentState.COMPLETED - - @pytest.mark.e2e_cpu def test_kill_experiment_ignoring_preemption() -> None: sess = api_utils.user_session() diff --git a/e2e_tests/tests/experiment/test_custom_searcher.py b/e2e_tests/tests/experiment/test_custom_searcher.py deleted file mode 100644 index fcd78bfdff4..00000000000 --- a/e2e_tests/tests/experiment/test_custom_searcher.py +++ /dev/null @@ -1,498 +0,0 @@ -import logging -import pathlib -import tempfile -import time -from typing import List, Optional - -import pytest -from urllib3 import connectionpool - -from determined import searcher -from determined.common import api, util -from determined.common.api import bindings -from determined.experimental import client -from tests import api_utils -from tests import config as conf -from tests import detproc -from tests import experiment as exp -from tests.fixtures.custom_searcher import searchers - -TIMESTAMP = int(time.time()) - - -def check_trial_state( - sess: api.Session, trial: bindings.trialv1Trial, expect: bindings.trialv1State -) -> bool: - """If the trial is in an unexpected state, dump logs and return False.""" - if trial.state == expect: - return True - exp.print_trial_logs(sess, trial.id) - return False - - -@pytest.mark.e2e_cpu -def test_run_custom_searcher_experiment(tmp_path: pathlib.Path) -> None: - sess = api_utils.user_session() - client._determined = client.Determined._from_session(sess) - # example searcher script - config = conf.load_config(conf.fixtures_path("custom_searcher_exp/single.yaml")) - config["searcher"] = { - "name": "custom", - "metric": "validation_error", - "smaller_is_better": True, - "unit": "batches", - } - config["name"] = "single" - config["description"] = "custom searcher" - search_method = searchers.SingleSearchMethod(config, 500) - search_runner = searcher.LocalSearchRunner(search_method, tmp_path, session=sess) - experiment_id = search_runner.run(config, model_dir=conf.fixtures_path("custom_searcher_exp")) - - assert client._determined is not None - response = bindings.get_GetExperiment(sess, experimentId=experiment_id) - assert response.experiment.numTrials == 1 - - -@pytest.mark.e2e_cpu_2a -def test_run_random_searcher_exp() -> None: - sess = api_utils.user_session() - client._determined = client.Determined._from_session(sess) - config = conf.load_config(conf.fixtures_path("custom_searcher_exp/single.yaml")) - config["searcher"] = { - "name": "custom", - "metric": "validation_error", - "smaller_is_better": True, - "unit": "batches", - } - config["name"] = "random" - config["description"] = "custom searcher" - - max_trials = 5 - max_concurrent_trials = 2 - max_length = 500 - - with tempfile.TemporaryDirectory() as searcher_dir: - search_method = searchers.RandomSearchMethod( - max_trials, max_concurrent_trials, max_length, test_type="noop" - ) - search_runner = searcher.LocalSearchRunner( - search_method, pathlib.Path(searcher_dir), session=sess - ) - experiment_id = search_runner.run( - config, model_dir=conf.fixtures_path("custom_searcher_exp") - ) - - response = bindings.get_GetExperiment(sess, experimentId=experiment_id) - assert response.experiment.numTrials == 5 - assert search_method.created_trials == 5 - assert search_method.pending_trials == 0 - assert search_method.closed_trials == 5 - assert len(search_runner.state.trials_created) == search_method.created_trials - assert len(search_runner.state.trials_closed) == search_method.closed_trials - - -@pytest.mark.e2e_cpu_2a -@pytest.mark.parametrize( - "config_name,exp_name,exception_points,metric_as_dict", - [ - ("core_api_model.yaml", f"custom-searcher-random-test-{TIMESTAMP}", [], True), - ( - "core_api_model.yaml", - f"custom-searcher-random-test-fail1-{TIMESTAMP}", - ["initial_operations_start", "progress_middle", "on_trial_closed_shutdown"], - False, - ), - ( - "core_api_model.yaml", - f"custom-searcher-random-test-fail2-{TIMESTAMP}", - ["on_validation_completed", "on_trial_closed_end", "on_trial_created_5"], - False, - ), - ( - "core_api_model.yaml", - f"custom-searcher-random-test-fail3-{TIMESTAMP}", - ["on_trial_created", "after_save"], - False, - ), - ( - "core_api_model.yaml", - f"custom-searcher-random-test-fail5-{TIMESTAMP}", - [ - "on_trial_created", - "after_save", - "after_save", - "on_validation_completed", - "after_save", - ], - False, - ), - ], -) -def test_run_random_searcher_exp_core_api( - config_name: str, - exp_name: str, - exception_points: List[str], - metric_as_dict: bool, -) -> None: - sess = api_utils.user_session() - config = conf.load_config(conf.fixtures_path("custom_searcher/core_api_searcher_random.yaml")) - config["entrypoint"] += " --exp-name " + exp_name - config["entrypoint"] += " --config-name " + config_name - if len(exception_points) > 0: - config["entrypoint"] += " --exception-points " + " ".join(exception_points) - if metric_as_dict: - config["entrypoint"] += " --metric-as-dict" - config["max_restarts"] = len(exception_points) - - experiment_id = exp.run_basic_test_with_temp_config( - sess, config, conf.fixtures_path("custom_searcher"), 1 - ) - - # searcher experiment - searcher_exp = bindings.get_GetExperiment(sess, experimentId=experiment_id).experiment - assert searcher_exp.state == bindings.experimentv1State.COMPLETED - - # actual experiment - response = bindings.get_GetExperiments(sess, name=exp_name) - experiments = response.experiments - assert len(experiments) == 1 - - experiment = experiments[0] - assert experiment.numTrials == 5 - - trials = bindings.get_GetExperimentTrials(sess, experimentId=experiment.id).trials - - ok = True - for trial in trials: - ok = ok and check_trial_state(sess, trial, bindings.trialv1State.COMPLETED) - assert ok, "some trials failed" - - for trial in trials: - assert trial.totalBatchesProcessed == 500 - - # check logs to ensure failures actually happened - logs = detproc.check_output(sess, ["det", "experiment", "logs", str(experiment_id)]) - failures = logs.count("Max retries exceeded with url: http://dummyurl (Caused by None)") - assert failures == len(exception_points) - - # check for resubmitting operations - resubmissions = logs.count("determined.searcher: Resubmitting operations for event.id=") - assert resubmissions == sum([x == "after_save" for x in exception_points]) - - -@pytest.mark.e2e_cpu_2a -def test_pause_multi_trial_random_searcher_core_api() -> None: - sess = api_utils.user_session() - config = conf.load_config(conf.fixtures_path("custom_searcher/core_api_searcher_random.yaml")) - exp_name = f"random-pause-{TIMESTAMP}" - config["entrypoint"] += " --exp-name " + exp_name - config["entrypoint"] += " --config-name core_api_model.yaml" - - model_def_path = conf.fixtures_path("custom_searcher") - - with tempfile.NamedTemporaryFile() as tf: - with open(tf.name, "w") as f: - util.yaml_safe_dump(config, f) - - searcher_exp_id = exp.create_experiment(sess, tf.name, model_def_path, None) - exp.wait_for_experiment_state( - sess, - searcher_exp_id, - bindings.experimentv1State.RUNNING, - ) - # make sure both experiments have started by checking - # that multi-trial experiment has at least 1 running trials - multi_trial_exp_id = exp.wait_for_experiment_by_name_is_active(sess, exp_name, 1) - - # pause multi-trial experiment - exp.pause_experiment(sess, multi_trial_exp_id) - exp.wait_for_experiment_state(sess, multi_trial_exp_id, bindings.experimentv1State.PAUSED) - - # activate multi-trial experiment - exp.activate_experiment(sess, multi_trial_exp_id) - - # wait for searcher to complete - exp.wait_for_experiment_state(sess, searcher_exp_id, bindings.experimentv1State.COMPLETED) - - # searcher experiment - searcher_exp = bindings.get_GetExperiment(sess, experimentId=searcher_exp_id).experiment - assert searcher_exp.state == bindings.experimentv1State.COMPLETED - - # actual experiment - experiment = bindings.get_GetExperiment(sess, experimentId=multi_trial_exp_id).experiment - assert experiment.numTrials == 5 - - trials = bindings.get_GetExperimentTrials(sess, experimentId=experiment.id).trials - - ok = True - for trial in trials: - ok = ok and check_trial_state(sess, trial, bindings.trialv1State.COMPLETED) - assert ok, "some trials failed" - - for trial in trials: - assert trial.totalBatchesProcessed == 500 - - -@pytest.mark.e2e_cpu_2a -@pytest.mark.parametrize( - "exceptions", - [ - ["initial_operations_start", "progress_middle", "on_trial_closed_shutdown"], - ["on_validation_completed", "on_trial_closed_end", "on_trial_created_5"], - ["on_trial_created", "save_method_state", "after_save"], - [ - "on_trial_created", - "save_method_state", - "load_method_state", - "after_save", - "after_save", - "on_validation_completed", - "after_save", - "save_method_state", - ], - ], -) -def test_resume_random_searcher_exp(exceptions: List[str]) -> None: - sess = api_utils.user_session() - config = conf.load_config(conf.fixtures_path("custom_searcher_exp/single.yaml")) - config["searcher"] = { - "name": "custom", - "metric": "validation_error", - "smaller_is_better": True, - "unit": "batches", - } - config["description"] = ";".join(exceptions) if exceptions else "custom searcher" - - max_trials = 5 - max_concurrent_trials = 2 - max_length = 500 - failures_expected = len(exceptions) - logging.info(f"expected_failures={failures_expected}") - - # do not use pytest tmp_path to experience LocalSearchRunner in the wild - with tempfile.TemporaryDirectory() as searcher_dir: - failures = 0 - while failures < failures_expected: - try: - exception_point = exceptions.pop(0) - # re-create RandomSearchMethod and LocalSearchRunner after every fail - # to simulate python process crash - search_method = searchers.RandomSearchMethod( - max_trials, - max_concurrent_trials, - max_length, - test_type="noop", - exception_points=[exception_point], - ) - search_runner_mock = FallibleSearchRunner( - exception_point, search_method, pathlib.Path(searcher_dir) - ) - search_runner_mock.run(config, model_dir=conf.fixtures_path("custom_searcher_exp")) - pytest.fail("Expected an exception") - except connectionpool.MaxRetryError: - failures += 1 - - assert failures == failures_expected - - search_method = searchers.RandomSearchMethod( - max_trials, max_concurrent_trials, max_length, test_type="noop" - ) - search_runner = searcher.LocalSearchRunner( - search_method, pathlib.Path(searcher_dir), session=sess - ) - experiment_id = search_runner.run( - config, model_dir=conf.fixtures_path("custom_searcher_exp") - ) - - assert search_runner.state.last_event_id == 41 - assert search_runner.state.experiment_completed is True - response = bindings.get_GetExperiment(sess, experimentId=experiment_id) - assert response.experiment.numTrials == 5 - assert search_method.created_trials == 5 - assert search_method.pending_trials == 0 - assert search_method.closed_trials == 5 - assert len(search_runner.state.trials_created) == search_method.created_trials - assert len(search_runner.state.trials_closed) == search_method.closed_trials - - assert search_method.progress(search_runner.state) == pytest.approx(1.0) - - -@pytest.mark.nightly -def test_run_asha_batches_exp(tmp_path: pathlib.Path) -> None: - sess = api_utils.user_session() - client._determined = client.Determined._from_session(sess) - config = conf.load_config(conf.fixtures_path("custom_searcher_exp/adaptive.yaml")) - config["searcher"] = { - "name": "custom", - "metric": "validation_error", - "smaller_is_better": True, - "unit": "batches", - } - config["name"] = "asha" - config["description"] = "custom searcher" - - max_length = 2000 - max_trials = 16 - num_rungs = 3 - divisor = 4 - - search_method = searchers.ASHASearchMethod( - max_length, max_trials, num_rungs, divisor, test_type="noop" - ) - search_runner = searcher.LocalSearchRunner(search_method, tmp_path) - experiment_id = search_runner.run(config, model_dir=conf.fixtures_path("custom_searcher_exp")) - - assert client._determined is not None - response = bindings.get_GetExperiment(sess, experimentId=experiment_id) - - assert response.experiment.numTrials == 16 - assert search_method.asha_search_state.pending_trials == 0 - assert search_method.asha_search_state.completed_trials == 16 - assert len(search_runner.state.trials_closed) == len( - search_method.asha_search_state.closed_trials - ) - - response_trials = bindings.get_GetExperimentTrials(sess, experimentId=experiment_id).trials - - # 16 trials in rung 1 (#batches = 125) - assert sum(t.totalBatchesProcessed >= 125 for t in response_trials) == 16 - # at least 4 trials in rung 2 (#batches = 500) - assert sum(t.totalBatchesProcessed >= 500 for t in response_trials) >= 4 - # at least 1 trial in rung 3 (#batches = 2000) - assert sum(t.totalBatchesProcessed == 2000 for t in response_trials) >= 1 - - ok = True - for trial in response_trials: - ok = ok and check_trial_state(sess, trial, bindings.trialv1State.COMPLETED) - assert ok, "some trials failed" - - -@pytest.mark.nightly -@pytest.mark.parametrize( - "exceptions", - [ - [ - "initial_operations_start", # fail before sending initial operations - "after_save", # fail on save - should not send initial operations again - "save_method_state", - "save_method_state", - "after_save", - "on_trial_created", - "_get_close_rungs_ops", - ], - [ # searcher state and search method state are restored to last saved state - "on_validation_completed", - "on_validation_completed", - "save_method_state", - "save_method_state", - "after_save", - "after_save", - "load_method_state", - "on_validation_completed", - "shutdown", - ], - ], -) -def test_resume_asha_batches_exp(exceptions: List[str]) -> None: - sess = api_utils.user_session() - client._determined = client.Determined._from_session(sess) - config = conf.load_config(conf.fixtures_path("custom_searcher_exp/adaptive.yaml")) - config["searcher"] = { - "name": "custom", - "metric": "validation_error", - "smaller_is_better": True, - "unit": "batches", - } - config["name"] = "asha" - config["description"] = ";".join(exceptions) if exceptions else "custom searcher" - - max_length = 2000 - max_trials = 16 - num_rungs = 3 - divisor = 4 - failures_expected = len(exceptions) - - with tempfile.TemporaryDirectory() as searcher_dir: - logging.info(f"searcher_dir type = {type(searcher_dir)}") - failures = 0 - while failures < failures_expected: - try: - exception_point = exceptions.pop(0) - search_method = searchers.ASHASearchMethod( - max_length, - max_trials, - num_rungs, - divisor, - test_type="noop", - exception_points=[exception_point], - ) - search_runner_mock = FallibleSearchRunner( - exception_point, search_method, pathlib.Path(searcher_dir) - ) - search_runner_mock.run(config, model_dir=conf.fixtures_path("custom_searcher_exp")) - pytest.fail("Expected an exception") - except connectionpool.MaxRetryError: - failures += 1 - - assert failures == failures_expected - - search_method = searchers.ASHASearchMethod( - max_length, max_trials, num_rungs, divisor, test_type="noop" - ) - search_runner = searcher.LocalSearchRunner(search_method, pathlib.Path(searcher_dir)) - experiment_id = search_runner.run( - config, model_dir=conf.fixtures_path("custom_searcher_exp") - ) - - assert search_runner.state.experiment_completed is True - response = bindings.get_GetExperiment(sess, experimentId=experiment_id) - - assert response.experiment.numTrials == 16 - # asha search method state - assert search_method.asha_search_state.pending_trials == 0 - assert search_method.asha_search_state.completed_trials == 16 - # searcher state - assert len(search_runner.state.trials_created) == 16 - assert len(search_runner.state.trials_closed) == 16 - - assert len(search_runner.state.trials_closed) == len( - search_method.asha_search_state.closed_trials - ) - - response_trials = bindings.get_GetExperimentTrials(sess, experimentId=experiment_id).trials - - # 16 trials in rung 1 (#batches = 125) - assert sum(t.totalBatchesProcessed >= 125 for t in response_trials) == 16 - # at least 4 trials in rung 2 (#batches = 500) - assert sum(t.totalBatchesProcessed >= 500 for t in response_trials) >= 4 - # at least 1 trial in rung 3 (#batches = 2000) - assert sum(t.totalBatchesProcessed == 2000 for t in response_trials) >= 1 - - for trial in response_trials: - assert trial.state == bindings.trialv1State.COMPLETED - - assert search_method.progress(search_runner.state) == pytest.approx(1.0) - - -class FallibleSearchRunner(searcher.LocalSearchRunner): - def __init__( - self, - exception_point: str, - search_method: searcher.SearchMethod, - searcher_dir: Optional[pathlib.Path] = None, - ): - super(FallibleSearchRunner, self).__init__(search_method, searcher_dir) - self.fail_on_save = False - if exception_point == "after_save": - self.fail_on_save = True - - def save_state(self, experiment_id: int, operations: List[searcher.Operation]) -> None: - super(FallibleSearchRunner, self).save_state(experiment_id, operations) - if self.fail_on_save: - logging.info( - "Raising exception in after saving the state and before posting operations" - ) - ex = connectionpool.MaxRetryError( - connectionpool.HTTPConnectionPool(host="dummyhost", port=8080), "http://dummyurl" - ) - raise ex diff --git a/e2e_tests/tests/experiment/test_custom_searcher_asha_2a.py b/e2e_tests/tests/experiment/test_custom_searcher_asha_2a.py deleted file mode 100644 index 44c54a4c574..00000000000 --- a/e2e_tests/tests/experiment/test_custom_searcher_asha_2a.py +++ /dev/null @@ -1,91 +0,0 @@ -import time -from typing import List - -import pytest - -from determined.common.api import bindings -from tests import api_utils -from tests import config as conf -from tests import detproc -from tests import experiment as exp - -TIMESTAMP = int(time.time()) - - -@pytest.mark.e2e_cpu_2a -@pytest.mark.parametrize( - "config_name,exp_name,exception_points", - [ - ("core_api_model.yaml", f"custom-searcher-asha-test-{TIMESTAMP}", []), - ( # test fail on initialization - # test single resubmit of operations - # test resumption on fail before saving - "core_api_model.yaml", - f"custom-searcher-asha-test-fail1-{TIMESTAMP}", - [ - "initial_operations_start", - "after_save", - "on_validation_completed", - ], - ), - ( # test resubmitting operations multiple times - # test fail on shutdown - "core_api_model.yaml", - f"custom-searcher-asha-test-fail2-{TIMESTAMP}", - [ - "on_validation_completed", - "after_save", - "after_save", - "after_save", - "shutdown", - ], - ), - ], -) -def test_run_asha_searcher_exp_core_api( - config_name: str, exp_name: str, exception_points: List[str] -) -> None: - sess = api_utils.user_session() - config = conf.load_config(conf.fixtures_path("custom_searcher/core_api_searcher_asha.yaml")) - config["entrypoint"] += " --exp-name " + exp_name - config["entrypoint"] += " --config-name " + config_name - if len(exception_points) > 0: - config["entrypoint"] += " --exception-points " + " ".join(exception_points) - config["max_restarts"] = len(exception_points) - - experiment_id = exp.run_basic_test_with_temp_config( - sess, config, conf.fixtures_path("custom_searcher"), 1 - ) - - # searcher experiment - searcher_exp = bindings.get_GetExperiment(sess, experimentId=experiment_id).experiment - assert searcher_exp.state == bindings.experimentv1State.COMPLETED - - # actual experiment - response = bindings.get_GetExperiments(sess, name=exp_name) - experiments = response.experiments - assert len(experiments) == 1 - - experiment = experiments[0] - assert experiment.numTrials == 16 - - response_trials = bindings.get_GetExperimentTrials(sess, experimentId=experiment.id).trials - - # 16 trials in rung 1 (#batches = 150) - assert sum(t.totalBatchesProcessed >= 150 for t in response_trials) == 16 - # at least 4 trials in rung 2 (#batches = 600) - assert sum(t.totalBatchesProcessed >= 600 for t in response_trials) >= 4 - # at least 1 trial in rung 3 (#batches = 2400) - assert sum(t.totalBatchesProcessed == 2400 for t in response_trials) >= 1 - - for trial in response_trials: - assert trial.state == bindings.trialv1State.COMPLETED - - # check logs to ensure failures actually happened - logs = detproc.check_output(sess, ["det", "experiment", "logs", str(experiment_id)]) - failures = logs.count("Max retries exceeded with url: http://dummyurl (Caused by None)") - assert failures == len(exception_points) - - # check for resubmitting operations - resubmissions = logs.count("determined.searcher: Resubmitting operations for event.id=") - assert resubmissions == sum([x == "after_save" for x in exception_points]) diff --git a/e2e_tests/tests/experiment/test_launch.py b/e2e_tests/tests/experiment/test_launch.py index fcb413bb868..ef8400f675b 100644 --- a/e2e_tests/tests/experiment/test_launch.py +++ b/e2e_tests/tests/experiment/test_launch.py @@ -12,11 +12,10 @@ def test_launch_layer_mnist() -> None: sess = api_utils.user_session() config = conf.load_config(conf.tutorials_path("mnist_pytorch/const.yaml")) - config = conf.set_max_length(config, {"batches": 200}) config = conf.set_slots_per_trial(config, 1) config = conf.set_profiling_enabled(config) config = conf.set_entrypoint( - config, "python3 -m determined.launch.horovod --autohorovod python3 train.py" + config, "python3 -m determined.launch.horovod --autohorovod python3 train.py --batches 200" ) experiment_id = exp.run_basic_test_with_temp_config( diff --git a/e2e_tests/tests/experiment/test_metrics.py b/e2e_tests/tests/experiment/test_metrics.py index 32a893a87af..a6db6f1179c 100644 --- a/e2e_tests/tests/experiment/test_metrics.py +++ b/e2e_tests/tests/experiment/test_metrics.py @@ -104,10 +104,10 @@ def request_metric_names(experiment_id): # type: ignore return ("training metric appeared twice", results) accumulated_validation.add(validation) - if accumulated_training != {"loss"}: - return ("unexpected set of training metrics", results) - if accumulated_validation != {"validation_loss", "accuracy"}: - return ("unexpected set of validation metrics", results) + if accumulated_training != {"loss", "batches", "epochs"}: + return (f"unexpected set of training metrics {accumulated_training}", results) + if accumulated_validation != {"validation_loss", "accuracy", "batches", "epochs"}: + return (f"unexpected set of validation metrics {accumulated_validation}", results) return None @@ -158,8 +158,8 @@ def request_valid_metric_batches(experiment_id): # type: ignore if batch in accumulated: return ("batch appears twice", results) accumulated.add(batch) - if accumulated != {200, 400}: - return ("unexpected set of batches", results) + if accumulated != {100, 200, 300, 400}: + return (f"unexpected set of batches: {accumulated}", results) return None diff --git a/e2e_tests/tests/experiment/test_noop.py b/e2e_tests/tests/experiment/test_noop.py index da15f1b7ca1..820c41630ff 100644 --- a/e2e_tests/tests/experiment/test_noop.py +++ b/e2e_tests/tests/experiment/test_noop.py @@ -269,7 +269,6 @@ def test_experiment_config_override() -> None: searcher: name: single metric: x - max_length: 1 entrypoint: echo yo dawg """ ) diff --git a/e2e_tests/tests/experiment/test_pending_hpc.py b/e2e_tests/tests/experiment/test_pending_hpc.py index c5d9d25fe6a..1ca490bb241 100644 --- a/e2e_tests/tests/experiment/test_pending_hpc.py +++ b/e2e_tests/tests/experiment/test_pending_hpc.py @@ -32,12 +32,8 @@ def test_hpc_job_pending_reason() -> None: detobj = client.Determined._from_session(sess) config = conf.load_config(conf.tutorials_path("mnist_pytorch/const.yaml")) - config = conf.set_max_length(config, {"batches": 200}) config = conf.set_slots_per_trial(config, 1) config = conf.set_profiling_enabled(config) - config = conf.set_entrypoint( - config, "python3 -m determined.launch.torch_distributed python3 train.py" - ) config["max_restarts"] = 0 # The experiment will request 6 CPUs @@ -45,6 +41,9 @@ def test_hpc_job_pending_reason() -> None: config["slurm"]["slots_per_node"] = 6 config.setdefault("pbs", {}) config["pbs"]["slots_per_node"] = 6 + # Wrap entrypoint in torch_distributed for dtrain support. + assert "torch_distributed" not in config["entrypoint"], "update test to match tutorial" + config["entrypoint"] = "python3 -m determined.launch.torch_distributed " + config["entrypoint"] running_exp = detobj.create_experiment(config, conf.fixtures_path("mnist_pytorch")) print(f"Created running experiment {running_exp.id}") diff --git a/e2e_tests/tests/experiment/test_tf_keras.py b/e2e_tests/tests/experiment/test_tf_keras.py index d2e17063fee..9bdec5c1858 100644 --- a/e2e_tests/tests/experiment/test_tf_keras.py +++ b/e2e_tests/tests/experiment/test_tf_keras.py @@ -1,46 +1,17 @@ -import multiprocessing - import pytest -from determined import keras -from determined.common import api -from determined.experimental import client from tests import api_utils from tests import config as conf from tests import experiment as exp -def _export_and_load_model(sess: api.Session, experiment_id: int, master_url: str) -> None: - # Normally verifying that we can load a model would be a good unit test, but making this an e2e - # test ensures that our model saving and loading works with all the versions of tf that we test. - ckpt = client.Determined._from_session(sess).get_experiment(experiment_id).top_checkpoint() - _ = keras.load_model_from_checkpoint_path(ckpt.download()) - - -def export_and_load_model(sess: api.Session, experiment_id: int) -> None: - # We run this in a subprocess to avoid module name collisions - # when performing checkpoint export of different models. - ctx = multiprocessing.get_context("spawn") - p = ctx.Process( - target=_export_and_load_model, - args=( - sess, - experiment_id, - conf.make_master_url(), - ), - ) - p.start() - p.join() - assert p.exitcode == 0, p.exitcode - - @pytest.mark.parallel @pytest.mark.parametrize("aggregation_frequency", [1, 4]) def test_tf_keras_parallel(aggregation_frequency: int) -> None: sess = api_utils.user_session() config = conf.load_config(conf.cv_examples_path("iris_tf_keras/const.yaml")) - config = conf.set_slots_per_trial(config, 8) - config = conf.set_max_length(config, {"batches": 200}) + assert "--epochs" not in config["entrypoint"], "please update test" + config["entrypoint"] += " --epochs 1" config = conf.set_aggregation_frequency(config, aggregation_frequency) config = conf.set_tf2_image(config) config = conf.set_profiling_enabled(config) @@ -50,21 +21,3 @@ def test_tf_keras_parallel(aggregation_frequency: int) -> None: ) trials = exp.experiment_trials(sess, experiment_id) assert len(trials) == 1 - - # Test exporting a checkpoint. - export_and_load_model(sess, experiment_id) - - # Check on record/batch counts we emitted in logs. - validation_size = 30 - num_workers = config.get("resources", {}).get("slots_per_trial", 1) - global_batch_size = config["hyperparameters"]["global_batch_size"] - scheduling_unit = config.get("scheduling_unit", 100) - per_slot_batch_size = global_batch_size // num_workers - exp_val_batches = (validation_size + (per_slot_batch_size - 1)) // per_slot_batch_size - patterns = [ - # Expect two copies of matching training reports. - f"trained: {scheduling_unit * global_batch_size} records.*in {scheduling_unit} batches", - f"trained: {scheduling_unit * global_batch_size} records.*in {scheduling_unit} batches", - f"validated: {validation_size} records.*in {exp_val_batches} batches", - ] - exp.assert_patterns_in_trial_logs(sess, trials[0].trial.id, patterns) diff --git a/e2e_tests/tests/fixtures/core_api/11_generic_metrics.yaml b/e2e_tests/tests/fixtures/core_api/11_generic_metrics.yaml index 1aafd98aa5d..cea9408fdd8 100644 --- a/e2e_tests/tests/fixtures/core_api/11_generic_metrics.yaml +++ b/e2e_tests/tests/fixtures/core_api/11_generic_metrics.yaml @@ -4,6 +4,5 @@ entrypoint: python3 11_generic_metrics.py searcher: name: single metric: x - max_length: 1 max_restarts: 0 diff --git a/e2e_tests/tests/fixtures/core_api/arbitrary_workload_order.yaml b/e2e_tests/tests/fixtures/core_api/arbitrary_workload_order.yaml index 984e34e72a2..ebdd608abca 100644 --- a/e2e_tests/tests/fixtures/core_api/arbitrary_workload_order.yaml +++ b/e2e_tests/tests/fixtures/core_api/arbitrary_workload_order.yaml @@ -3,5 +3,4 @@ entrypoint: python3 arbitrary_workload_order.py searcher: name: single metric: x - max_length: 1 -max_restarts: 0 \ No newline at end of file +max_restarts: 0 diff --git a/e2e_tests/tests/fixtures/core_api/pytorch_profiler_sync.yaml b/e2e_tests/tests/fixtures/core_api/pytorch_profiler_sync.yaml index a86f35c9ae9..1abedb4364f 100644 --- a/e2e_tests/tests/fixtures/core_api/pytorch_profiler_sync.yaml +++ b/e2e_tests/tests/fixtures/core_api/pytorch_profiler_sync.yaml @@ -4,6 +4,5 @@ entrypoint: python3 pytorch_profiler_sync.py searcher: name: single metric: x - max_length: 1 max_restarts: 0 diff --git a/e2e_tests/tests/fixtures/core_api/sleep.yaml b/e2e_tests/tests/fixtures/core_api/sleep.yaml index dcc0f199ac6..59a048227a9 100644 --- a/e2e_tests/tests/fixtures/core_api/sleep.yaml +++ b/e2e_tests/tests/fixtures/core_api/sleep.yaml @@ -2,7 +2,5 @@ name: sleep for ten minutes searcher: name: single metric: loss - max_length: - batches: 1 entrypoint: python3 sleep.py max_restarts: 0 diff --git a/e2e_tests/tests/fixtures/core_api/whoami.yaml b/e2e_tests/tests/fixtures/core_api/whoami.yaml index 484f1e35e1a..4c9f80ffba9 100644 --- a/e2e_tests/tests/fixtures/core_api/whoami.yaml +++ b/e2e_tests/tests/fixtures/core_api/whoami.yaml @@ -4,6 +4,5 @@ entrypoint: python3 whoami.py searcher: name: single metric: x - max_length: 1 max_restarts: 0 diff --git a/e2e_tests/tests/fixtures/custom_searcher/core_api_custom_searcher.py b/e2e_tests/tests/fixtures/custom_searcher/core_api_custom_searcher.py deleted file mode 100644 index a4248469247..00000000000 --- a/e2e_tests/tests/fixtures/custom_searcher/core_api_custom_searcher.py +++ /dev/null @@ -1,105 +0,0 @@ -import argparse -import logging -from typing import List, Optional, Tuple - -import searchers -from urllib3 import connectionpool - -import determined as det -from determined import searcher -from determined.common import util - - -def load_config(config_path: str): - with open(config_path) as f: - config = util.safe_load_yaml_with_exceptions(f) - return config - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--searcher", choices=["asha", "random"], required=True) - parser.add_argument("--exp-name", type=str, required=True) - parser.add_argument("--max-length", type=int, required=True) - parser.add_argument("--max-trials", type=int, required=True) - parser.add_argument("--max-concurrent-trials", type=int, default=0) - parser.add_argument("--divisor", type=int, default=3) - parser.add_argument("--num-rungs", type=int, default=16) - parser.add_argument("--exception-points", type=str, nargs="+", default=[]) - parser.add_argument("--config-name", type=str, required=True) - parser.add_argument("--metric-as-dict", action="store_true", default=False) - return parser.parse_args() - - -def create_search_method(args, exception_points: Optional[List[str]] = None): - if args.searcher == "asha": - return searchers.ASHASearchMethod( - max_trials=args.max_trials, - max_length=args.max_length, - divisor=args.divisor, - num_rungs=args.num_rungs, - exception_points=exception_points, - ) - elif args.searcher == "random": - return searchers.RandomSearchMethod( - max_trials=args.max_trials, - max_length=args.max_length, - max_concurrent_trials=args.max_concurrent_trials, - exception_points=exception_points, - metric_as_dict=args.metric_as_dict, - ) - else: - raise ValueError("Unknown searcher type") - - -class FallibleSearchRunner(searcher.RemoteSearchRunner): - def __init__( - self, search_method: searcher.SearchMethod, core_context: det.core.Context - ) -> None: - super(FallibleSearchRunner, self).__init__(search_method, core_context) - self.fail_on_save = False - - def load_state(self, storage_id: str) -> Tuple[int, List[searcher.Operation]]: - result = super(FallibleSearchRunner, self).load_state(storage_id) - - # on every load remove first exception from the list - # since that exception was raised in the previous run; - # this testing approach works as long as the there is - # at least one save between consecutive exceptions - if len(search_method.exception_points) > 0: - self.search_method.exception_points.pop(0) - - if len(self.search_method.exception_points) > 0: - if self.search_method.exception_points[0] == "after_save": - self.fail_on_save = True - - return result - - def save_state(self, experiment_id: int, operations: List[searcher.Operation]) -> None: - super(FallibleSearchRunner, self).save_state(experiment_id, operations) - if self.fail_on_save: - logging.info( - "Raising exception in after saving the state and before posting operations" - ) - ex = connectionpool.MaxRetryError( - connectionpool.HTTPConnectionPool(host="dummyhost", port=8080), "http://dummyurl" - ) - raise ex - - -if __name__ == "__main__": - logging.basicConfig(level=logging.INFO, format=det.LOG_FORMAT) - - info = det.get_cluster_info() - assert info is not None, "this example only runs on-cluster" - args = parse_args() - - config = load_config(args.config_name) - config["name"] = args.exp_name - if args.metric_as_dict: - config["entrypoint"] += " dict" - - with det.core.init() as core_context: - search_method = create_search_method(args, args.exception_points) - search_runner = FallibleSearchRunner(search_method, core_context) - search_runner.run(config) diff --git a/e2e_tests/tests/fixtures/custom_searcher/core_api_model.yaml b/e2e_tests/tests/fixtures/custom_searcher/core_api_model.yaml deleted file mode 100644 index ef754668ccd..00000000000 --- a/e2e_tests/tests/fixtures/custom_searcher/core_api_model.yaml +++ /dev/null @@ -1,8 +0,0 @@ -name: custom-searcher-core-api -entrypoint: python3 model_coreapi.py -searcher: - metric: validation_error - smaller_is_better: true - name: custom - unit: batches -max_restarts: 0 diff --git a/e2e_tests/tests/fixtures/custom_searcher/core_api_searcher_asha.yaml b/e2e_tests/tests/fixtures/custom_searcher/core_api_searcher_asha.yaml deleted file mode 100644 index ef792734917..00000000000 --- a/e2e_tests/tests/fixtures/custom_searcher/core_api_searcher_asha.yaml +++ /dev/null @@ -1,9 +0,0 @@ -name: core-api-searcher-asha -entrypoint: python3 core_api_custom_searcher.py --searcher asha --num-rungs 3 --max-trials 16 --divisor 4 --max-length 2400 -searcher: - metric: validation_error - smaller_is_better: true - name: single - max_length: - batches: 100 -max_restarts: 0 diff --git a/e2e_tests/tests/fixtures/custom_searcher/core_api_searcher_random.yaml b/e2e_tests/tests/fixtures/custom_searcher/core_api_searcher_random.yaml deleted file mode 100644 index fddc71c34b1..00000000000 --- a/e2e_tests/tests/fixtures/custom_searcher/core_api_searcher_random.yaml +++ /dev/null @@ -1,9 +0,0 @@ -name: core-api-searcher-random -entrypoint: python3 core_api_custom_searcher.py --searcher random --max-trials 5 --max-concurrent-trials 2 --max-length 500 -searcher: - metric: validation_error - smaller_is_better: true - name: single - max_length: - batches: 100 -max_restarts: 0 diff --git a/e2e_tests/tests/fixtures/custom_searcher/model_coreapi.py b/e2e_tests/tests/fixtures/custom_searcher/model_coreapi.py deleted file mode 100644 index aa7d35c36b1..00000000000 --- a/e2e_tests/tests/fixtures/custom_searcher/model_coreapi.py +++ /dev/null @@ -1,84 +0,0 @@ -import logging -import pathlib -import sys - -import determined as det -from determined.common import util - - -def save_state(x, steps_completed, trial_id, checkpoint_directory): - with checkpoint_directory.joinpath("state").open("w") as f: - f.write(f"{x},{steps_completed},{trial_id}") - - -def load_state(trial_id, checkpoint_directory): - checkpoint_directory = pathlib.Path(checkpoint_directory) - with checkpoint_directory.joinpath("state").open("r") as f: - x, steps_completed, ckpt_trial_id = [int(field) for field in f.read().split(",")] - if ckpt_trial_id == trial_id: - return x, steps_completed - else: - return x, 0 - - -def main(core_context, latest_checkpoint, trial_id, increment_by, metric_as_dict): - x = 0 - - starting_batch = 0 - if latest_checkpoint is not None: - with core_context.checkpoint.restore_path(latest_checkpoint) as path: - x, starting_batch = load_state(trial_id, path) - - batch = starting_batch - last_checkpoint_batch = None - for op in core_context.searcher.operations(): - while batch < op.length: - x += increment_by - steps_completed = batch + 1 - if steps_completed % 100 == 0: - core_context.train.report_metrics( - group=util._LEGACY_TRAINING, - steps_completed=steps_completed, - metrics={"validation_error": x}, - ) - - op.report_progress(batch) - - checkpoint_metadata = {"steps_completed": steps_completed} - with core_context.checkpoint.store_path(checkpoint_metadata) as (path, uuid): - save_state(x, steps_completed, trial_id, path) - last_checkpoint_batch = steps_completed - if core_context.preempt.should_preempt(): - return - batch += 1 - - core_context.train.report_metrics( - group="validation", steps_completed=steps_completed, metrics={"validation_error": x} - ) - op.report_completed({"foo": x} if metric_as_dict else x) - - if last_checkpoint_batch != steps_completed: - checkpoint_metadata = {"steps_completed": steps_completed} - with core_context.checkpoint.store_path(checkpoint_metadata) as (path, uuid): - save_state(x, steps_completed, trial_id, path) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.INFO, format=det.LOG_FORMAT) - - info = det.get_cluster_info() - assert info is not None, "this example only runs on-cluster" - latest_checkpoint = info.latest_checkpoint - trial_id = info.trial.trial_id - hparams = info.trial.hparams - - metric_as_dict = len(sys.argv) > 1 and sys.argv[1] == "dict" - - with det.core.init() as core_context: - main( - core_context, - latest_checkpoint, - trial_id, - increment_by=hparams["increment_by"], - metric_as_dict=metric_as_dict, - ) diff --git a/e2e_tests/tests/fixtures/custom_searcher/searchers.py b/e2e_tests/tests/fixtures/custom_searcher/searchers.py deleted file mode 100644 index 7ad1a562c59..00000000000 --- a/e2e_tests/tests/fixtures/custom_searcher/searchers.py +++ /dev/null @@ -1,645 +0,0 @@ -import dataclasses -import json -import logging -import pathlib -import pickle -import random -import sys -import uuid -import warnings -from typing import Any, Dict, List, Optional, Set - -from urllib3 import connectionpool - -from determined import searcher - - -class SingleSearchMethod(searcher.SearchMethod): - def __init__(self, experiment_config: dict, max_length: int) -> None: - warnings.warn( - "`SingleSearchMethod` and all custom searchers have been deprecated. " - "This feature will be removed in a future release. Consider configuring a preset " - "searcher instead (see Determined docs for details).", - FutureWarning, - stacklevel=2, - ) - # since this is a single trial the hyperparameter space comprises a single point - self.hyperparameters = experiment_config["hyperparameters"] - self.max_length = max_length - self.trial_closed = False - - def on_trial_created( - self, _: searcher.SearcherState, __: uuid.UUID - ) -> List[searcher.Operation]: - return [] - - def on_validation_completed( - self, _: searcher.SearcherState, request_id: uuid.UUID, metric: Any, train_length: int - ) -> List[searcher.Operation]: - return [] - - def on_trial_closed( - self, _: searcher.SearcherState, request_id: uuid.UUID - ) -> List[searcher.Operation]: - self.trial_closed = True - return [searcher.Shutdown()] - - def progress(self, searcher_state: searcher.SearcherState) -> float: - if self.trial_closed: - return 1.0 - (the_trial,) = searcher_state.trials_created - return searcher_state.trial_progress[the_trial] / self.max_length - - def on_trial_exited_early( - self, _: searcher.SearcherState, request_id: uuid.UUID, exited_reason: searcher.ExitedReason - ) -> List[searcher.Operation]: - logging.warning(f"Trial {request_id} exited early: {exited_reason}") - return [searcher.Shutdown()] - - def initial_operations(self, _: searcher.SearcherState) -> List[searcher.Operation]: - logging.info("initial_operations") - - create = searcher.Create( - request_id=uuid.uuid4(), - hparams=self.hyperparameters, - checkpoint=None, - ) - validate_after = searcher.ValidateAfter( - request_id=create.request_id, length=self.max_length - ) - close = searcher.Close(request_id=create.request_id) - logging.debug(f"Create({create.request_id}, {create.hparams})") - return [create, validate_after, close] - - -class RandomSearchMethod(searcher.SearchMethod): - def __init__( - self, - max_trials: int, - max_concurrent_trials: int, - max_length: int, - test_type: str = "core_api", - exception_points: Optional[List[str]] = None, - metric_as_dict: bool = False, - ) -> None: - warnings.warn( - "`RandomSearchMethod` and all custom searchers have been deprecated. " - "This feature will be removed in a future release. Consider configuring a preset " - "searcher instead (see Determined docs for details).", - FutureWarning, - stacklevel=2, - ) - self.max_trials = max_trials - self.max_concurrent_trials = max_concurrent_trials - self.max_length = max_length - - self.test_type = test_type - self.exception_points = exception_points - self.metric_as_dict = metric_as_dict - - self.created_trials = 0 - self.pending_trials = 0 - self.closed_trials = 0 - self.tried = set() - - def on_trial_created( - self, _: searcher.SearcherState, request_id: uuid.UUID - ) -> List[searcher.Operation]: - self.raise_exception("on_trial_created") - if self.created_trials == 5: - self.raise_exception("on_trial_created_5") - self._log_stats() - return [] - - def on_validation_completed( - self, _: searcher.SearcherState, request_id: uuid.UUID, metric: Any, train_length: int - ) -> List[searcher.Operation]: - self.raise_exception("on_validation_completed") - if self.metric_as_dict: - logging.debug(f"metric={metric.items()}") - assert isinstance(metric, dict) - return [] - - def on_trial_closed( - self, _: searcher.SearcherState, request_id: uuid.UUID - ) -> List[searcher.Operation]: - self.pending_trials -= 1 - self.closed_trials += 1 - ops: List[searcher.Operation] = [] - if self.created_trials < self.max_trials: - request_id = uuid.uuid4() - ops.append( - searcher.Create( - request_id=request_id, hparams=self.sample_params(), checkpoint=None - ) - ) - ops.append(searcher.ValidateAfter(request_id=request_id, length=self.max_length)) - ops.append(searcher.Close(request_id=request_id)) - self.created_trials += 1 - self.pending_trials += 1 - elif self.pending_trials == 0: - self.raise_exception("on_trial_closed_shutdown") - ops.append(searcher.Shutdown()) - - self._log_stats() - self.raise_exception("on_trial_closed_end") - return ops - - def progress(self, searcher_state: searcher.SearcherState) -> float: - if 0 < self.max_concurrent_trials < self.pending_trials: - logging.error("pending trials is greater than max_concurrent_trial") - units_completed = sum( - ( - ( - self.max_length - if r in searcher_state.trials_closed - else searcher_state.trial_progress[r] - ) - for r in searcher_state.trial_progress - ) - ) - units_expected = self.max_length * self.max_trials - progress = units_completed / units_expected - logging.debug( - f"progress = {progress} = {units_completed} / {units_expected}," - f" {searcher_state.trial_progress}" - ) - - if progress >= 0.5: - self.raise_exception("progress_middle") - - return progress - - def on_trial_exited_early( - self, _: searcher.SearcherState, request_id: uuid.UUID, exited_reason: searcher.ExitedReason - ) -> List[searcher.Operation]: - self.pending_trials -= 1 - - ops: List[searcher.Operation] = [] - if exited_reason == searcher.ExitedReason.INVALID_HP: - request_id = uuid.uuid4() - ops.append( - searcher.Create( - request_id=request_id, hparams=self.sample_params(), checkpoint=None - ) - ) - ops.append(searcher.ValidateAfter(request_id=request_id, length=self.max_length)) - ops.append(searcher.Close(request_id=request_id)) - self.pending_trials += 1 - return ops - - self.closed_trials += 1 - self._log_stats() - return ops - - def initial_operations(self, _: searcher.SearcherState) -> List[searcher.Operation]: - self.raise_exception("initial_operations_start") - initial_trials = self.max_trials - max_concurrent_trials = self.max_concurrent_trials - if max_concurrent_trials > 0: - initial_trials = min(initial_trials, max_concurrent_trials) - - ops: List[searcher.Operation] = [] - - for _ in range(initial_trials): - create = searcher.Create( - request_id=uuid.uuid4(), - hparams=self.sample_params(), - checkpoint=None, - ) - ops.append(create) - ops.append(searcher.ValidateAfter(request_id=create.request_id, length=self.max_length)) - ops.append(searcher.Close(request_id=create.request_id)) - - self.created_trials += 1 - self.pending_trials += 1 - - self._log_stats() - return ops - - def _log_stats(self) -> None: - logging.info(f"created trials={self.created_trials}") - logging.info(f"pending trials={self.pending_trials}") - logging.info(f"closed trials={self.closed_trials}") - - def sample_params(self) -> Dict[str, int]: - if self.test_type == "core_api": - increment = random.randint(1, 20) - while increment in self.tried: - increment = random.randint(1, 20) - hparams = {"increment_by": increment} - else: - hparams = {"global_batch_size": random.randint(10, 100)} - logging.info(f"hparams={hparams}") - return hparams - - def save_method_state(self, path: pathlib.Path) -> None: - self.raise_exception("save_method_state") - checkpoint_path = path.joinpath("method_state") - with checkpoint_path.open("w") as f: - state = { - "max_trials": self.max_trials, - "max_concurrent_trials": self.max_concurrent_trials, - "max_length": self.max_length, - "created_trials": self.created_trials, - "pending_trials": self.pending_trials, - "closed_trials": self.closed_trials, - "exception_points": self.exception_points, - } - json.dump(state, f) - - def load_method_state(self, path: pathlib.Path) -> None: - self.raise_exception("load_method_state") - checkpoint_path = path.joinpath("method_state") - with checkpoint_path.open("r") as f: - state = json.load(f) - self.max_trials = state["max_trials"] - self.max_concurrent_trials = state["max_concurrent_trials"] - self.max_length = state["max_length"] - self.created_trials = state["created_trials"] - self.pending_trials = state["pending_trials"] - self.closed_trials = state["closed_trials"] - - if self.test_type == "core_api": - # ony restore exception points for core_api searcher tests; - # local searcher is providing new exception point on resumption, - # and it shouldn't be overridden - self.exception_points = state["exception_points"] - - def raise_exception(self, exception_id: str) -> None: - if ( - self.exception_points is not None - and len(self.exception_points) > 0 - and exception_id == self.exception_points[0] - ): - logging.info(f"Raising exception in {exception_id}") - ex = connectionpool.MaxRetryError( - connectionpool.HTTPConnectionPool(host="dummyhost", port=8080), - "http://dummyurl", - ) - raise ex - - -@dataclasses.dataclass -class TrialMetric: - request_id: uuid.UUID - metric: float - promoted: bool = False - - -@dataclasses.dataclass -class Rung: - units_needed: int - idx: int - metrics: List[TrialMetric] = dataclasses.field(default_factory=list) - outstanding_trials: int = 0 - - def promotions_async( - self, request_id: uuid.UUID, metric: float, divisor: int - ) -> List[uuid.UUID]: - logging.info(f"Rung {self.idx}") - logging.info(f"outstanding_trials {self.outstanding_trials}") - - old_num_promote = len(self.metrics) // divisor - num_promote = (len(self.metrics) + 1) // divisor - - index = self._search_metric_index(metric) - promote_now = index < num_promote - trial_metric = TrialMetric(request_id=request_id, metric=metric, promoted=promote_now) - self.metrics.insert(index, trial_metric) - - if promote_now: - return [request_id] - if num_promote != old_num_promote and not self.metrics[old_num_promote].promoted: - self.metrics[old_num_promote].promoted = True - return [self.metrics[old_num_promote].request_id] - - logging.info("No promotion") - return [] - - def _search_metric_index(self, metric: float) -> int: - i: int = 0 - j: int = len(self.metrics) - while i < j: - mid = (i + j) >> 1 - if self.metrics[mid].metric <= metric: - i = mid + 1 - else: - j = mid - return i - - -class ASHASearchMethodState: - def __init__( - self, - max_length: int, - max_trials: int, - num_rungs: int, - divisor: int, - max_concurrent_trials: int = 16, - ) -> None: - # Asha params - self.max_length = max_length - self.max_trials = max_trials - self.num_rungs = num_rungs - self.divisor = divisor - self.max_concurrent_trials = max_concurrent_trials - self.is_smaller_better = True - - # structs - self.rungs: List[Rung] = [] - self.trial_rungs: Dict[uuid.UUID, int] = {} - - # accounting - self.pending_trials: int = 0 - self.completed_trials: int = 0 - self.invalid_trials: int = 0 - self.early_exit_trials: Set[uuid.UUID] = set() - self.closed_trials: Set[uuid.UUID] = set() - - self._init_rungs() - - def _init_rungs(self) -> None: - units_needed = 0 - for idx in range(self.num_rungs): - downsampling_rate = pow(self.divisor, float(self.num_rungs - idx - 1)) - units_needed += max(int(self.max_length / downsampling_rate), 1) - self.rungs.append(Rung(units_needed, idx)) - - -class ASHASearchMethod(searcher.SearchMethod): - def __init__( - self, - max_length: int, - max_trials: int, - num_rungs: int, - divisor: int, - test_type: str = "core_api", - max_concurrent_trials: int = 16, - exception_points: Optional[List[str]] = None, - ) -> None: - warnings.warn( - "`ASHASearchMethod` and all custom searchers have been deprecated. " - "This feature will be removed in a future release. Consider configuring a preset " - "searcher instead (see Determined docs for details).", - FutureWarning, - stacklevel=2, - ) - self.asha_search_state = ASHASearchMethodState( - max_length, max_trials, num_rungs, divisor, max_concurrent_trials - ) - self.test_type = test_type - self.exception_points = exception_points - - def on_trial_closed( - self, _: searcher.SearcherState, request_id: uuid.UUID - ) -> List[searcher.Operation]: - self.asha_search_state.completed_trials += 1 - self.asha_search_state.closed_trials.add(request_id) - - if ( - self.asha_search_state.pending_trials == 0 - and self.asha_search_state.completed_trials == self.asha_search_state.max_trials - ): - self.raise_exception("shutdown") - return [searcher.Shutdown()] - - return [] - - def on_trial_created( - self, _: searcher.SearcherState, request_id: uuid.UUID - ) -> List[searcher.Operation]: - self.asha_search_state.rungs[0].outstanding_trials += 1 - self.asha_search_state.trial_rungs[request_id] = 0 - self.raise_exception("on_trial_created") - return [] - - def on_validation_completed( - self, _: searcher.SearcherState, request_id: uuid.UUID, metric: Any, train_length: int - ) -> List[searcher.Operation]: - self.asha_search_state.pending_trials -= 1 - if self.asha_search_state.is_smaller_better is False: - metric *= -1 - ops = self.promote_async(request_id, metric) - self.raise_exception("on_validation_completed") - return ops - - def on_trial_exited_early( - self, - _: searcher.SearcherState, - request_id: uuid.UUID, - exited_reason: searcher.ExitedReason, - ) -> List[searcher.Operation]: - self.asha_search_state.pending_trials -= 1 - if exited_reason == searcher.ExitedReason.INVALID_HP: - ops: List[searcher.Operation] = [] - - self.asha_search_state.early_exit_trials.add(request_id) - ops.append(searcher.Close(request_id)) - self.asha_search_state.closed_trials.add(request_id) - self.asha_search_state.invalid_trials += 1 - - highest_rung_index = self.asha_search_state.trial_rungs[request_id] - rung = self.asha_search_state.rungs[highest_rung_index] - rung.outstanding_trials -= 1 - - for rung_idx in range(0, highest_rung_index + 1): - rung = self.asha_search_state.rungs[rung_idx] - rung.metrics = list(filter(lambda x: x.request_id != request_id, rung.metrics)) - - create = searcher.Create( - request_id=uuid.uuid4(), - hparams=self.sample_params(), - checkpoint=None, - ) - ops.append(create) - ops.append( - searcher.ValidateAfter( - request_id=create.request_id, - length=self.asha_search_state.rungs[0].units_needed, - ) - ) - - self.asha_search_state.trial_rungs[create.request_id] = 0 - self.asha_search_state.pending_trials += 1 - - return ops - - self.asha_search_state.early_exit_trials.add(request_id) - self.asha_search_state.closed_trials.add(request_id) - return self.promote_async(request_id, sys.float_info.max) - - def initial_operations(self, _: searcher.SearcherState) -> List[searcher.Operation]: - self.raise_exception("initial_operations_start") - ops: List[searcher.Operation] = [] - - if self.asha_search_state.max_concurrent_trials > 0: - max_concurrent_trials = min( - self.asha_search_state.max_concurrent_trials, self.asha_search_state.max_trials - ) - else: - max_concurrent_trials = max( - 1, - min( - int(pow(self.asha_search_state.divisor, self.asha_search_state.num_rungs - 1)), - self.asha_search_state.max_trials, - ), - ) - - for _ in range(0, max_concurrent_trials): - create = searcher.Create( - request_id=uuid.uuid4(), - hparams=self.sample_params(), - checkpoint=None, - ) - ops.append(create) - ops.append( - searcher.ValidateAfter( - request_id=create.request_id, - length=self.asha_search_state.rungs[0].units_needed, - ) - ) - - self.asha_search_state.trial_rungs[create.request_id] = 0 - self.asha_search_state.pending_trials += 1 - - return ops - - def promote_async(self, request_id: uuid.UUID, metric: float) -> List[searcher.Operation]: - rung_idx = self.asha_search_state.trial_rungs[request_id] - rung = self.asha_search_state.rungs[rung_idx] - rung.outstanding_trials -= 1 - added_train_workload = False - - ops: List[searcher.Operation] = [] - - if rung_idx == self.asha_search_state.num_rungs - 1: - rung.metrics.append(TrialMetric(request_id=request_id, metric=metric)) - - if request_id not in self.asha_search_state.early_exit_trials: - self.raise_exception("promote_async_close_trials") - ops.append(searcher.Close(request_id=request_id)) - logging.info(f"Closing trial {request_id}") - self.asha_search_state.closed_trials.add(request_id) - else: - next_rung = self.asha_search_state.rungs[rung_idx + 1] - self.raise_exception("promote_async") - logging.info(f"Promoting in rung {rung_idx}") - for promoted_request_id in rung.promotions_async( - request_id, metric, self.asha_search_state.divisor - ): - self.asha_search_state.trial_rungs[promoted_request_id] = rung_idx + 1 - next_rung.outstanding_trials += 1 - if promoted_request_id not in self.asha_search_state.early_exit_trials: - logging.info(f"Promoted {promoted_request_id}") - units_needed = max(next_rung.units_needed - rung.units_needed, 1) - ops.append(searcher.ValidateAfter(promoted_request_id, units_needed)) - added_train_workload = True - self.asha_search_state.pending_trials += 1 - else: - return self.promote_async(promoted_request_id, sys.float_info.max) - - all_trials = len(self.asha_search_state.trial_rungs) - self.asha_search_state.invalid_trials - if not added_train_workload and all_trials < self.asha_search_state.max_trials: - logging.info("Creating new trial instead of promoting") - self.asha_search_state.pending_trials += 1 - - create = searcher.Create( - request_id=uuid.uuid4(), - hparams=self.sample_params(), - checkpoint=None, - ) - ops.append(create) - ops.append( - searcher.ValidateAfter( - request_id=create.request_id, - length=self.asha_search_state.rungs[0].units_needed, - ) - ) - self.asha_search_state.trial_rungs[create.request_id] = 0 - - if len(self.asha_search_state.rungs[0].metrics) == self.asha_search_state.max_trials: - ops.extend(self._get_close_rungs_ops()) - - return ops - - def _get_close_rungs_ops(self) -> List[searcher.Operation]: - self.raise_exception("_get_close_rungs_ops") - ops: List[searcher.Operation] = [] - - for rung in self.asha_search_state.rungs: - if rung.outstanding_trials > 0: - break - for trial_metric in rung.metrics: - if ( - not trial_metric.promoted - and trial_metric.request_id not in self.asha_search_state.closed_trials - ): - if trial_metric.request_id not in self.asha_search_state.early_exit_trials: - logging.info(f"Closing trial {trial_metric.request_id}") - ops.append(searcher.Close(trial_metric.request_id)) - self.asha_search_state.closed_trials.add(trial_metric.request_id) - return ops - - def sample_params(self) -> Dict[str, object]: - hparams = { - "metrics_base": 0.05 * (len(self.asha_search_state.trial_rungs) + 1), - "metrics_progression": "constant", - } - if self.test_type == "core_api": - hparams["increment_by"] = 10 - else: - hparams["global_batch_size"] = 10 - logging.info(f"hparams={hparams}") - return hparams - - def progress(self, _: searcher.SearcherState) -> float: - if 0 < self.asha_search_state.max_concurrent_trials < self.asha_search_state.pending_trials: - raise RuntimeError("Pending trial is greater than max concurrent trials") - all_trials = len(self.asha_search_state.rungs[0].metrics) - - progress = all_trials / (1.2 * self.asha_search_state.max_trials) - if all_trials == self.asha_search_state.max_trials: - num_valid_trials = ( - self.asha_search_state.completed_trials - self.asha_search_state.invalid_trials - ) - progress_no_overhead = num_valid_trials / self.asha_search_state.max_trials - progress = max(progress_no_overhead, progress) - - return progress - - def save_method_state(self, path: pathlib.Path) -> None: - self.raise_exception("save_method_state") - checkpoint_path = path.joinpath("method_state") - with checkpoint_path.open("wb") as f: - pickle.dump(self.asha_search_state, f) - - exception_path = path.joinpath("exceptions") - with exception_path.open("wb") as f: - pickle.dump(self.exception_points, f) - - def load_method_state(self, path: pathlib.Path) -> None: - self.raise_exception("load_method_state") - checkpoint_path = path.joinpath("method_state") - with checkpoint_path.open("rb") as f: - self.asha_search_state = pickle.load(f) - - if self.test_type == "core_api": - # ony restore exception points for core_api searcher tests; - # local searcher is providing new exception point on resumption, - # and it shouldn't be overridden - exception_path = path.joinpath("exceptions") - with exception_path.open("rb") as f: - self.exception_points = pickle.load(f) - - def raise_exception(self, exception_id: str) -> None: - if ( - self.exception_points is not None - and len(self.exception_points) > 0 - and exception_id == self.exception_points[0] - ): - logging.info(f"Raising exception in {exception_id}") - ex = connectionpool.MaxRetryError( - connectionpool.HTTPConnectionPool(host="dummyhost", port=8080), "http://dummyurl" - ) - raise ex diff --git a/e2e_tests/tests/fixtures/custom_searcher_exp/adaptive.yaml b/e2e_tests/tests/fixtures/custom_searcher_exp/adaptive.yaml deleted file mode 100644 index af65b0c2ee1..00000000000 --- a/e2e_tests/tests/fixtures/custom_searcher_exp/adaptive.yaml +++ /dev/null @@ -1,22 +0,0 @@ -name: noop_adaptive -checkpoint_storage: - type: shared_fs - host_path: /tmp - storage_path: determined-integration-checkpoints -hyperparameters: - global_batch_size: 32 - metrics_progression: decreasing - metrics_base: - type: double - minval: 0.5 - maxval: 0.9 -searcher: - name: adaptive_asha - metric: validation_error - max_trials: 30 - max_length: - batches: 640 -reproducibility: - experiment_seed: 999 -max_restarts: 0 -entrypoint: model_def:NoOpTrial diff --git a/e2e_tests/tests/fixtures/custom_searcher_exp/model_def.py b/e2e_tests/tests/fixtures/custom_searcher_exp/model_def.py deleted file mode 100644 index e42f44d5888..00000000000 --- a/e2e_tests/tests/fixtures/custom_searcher_exp/model_def.py +++ /dev/null @@ -1,228 +0,0 @@ -import collections -import json -import logging -import os -import pathlib -import pickle -import random -import sys -import time -from typing import Any, Dict - -import numpy as np - -import determined as det -from determined import layers, tensorboard, util, workload -from determined.common import check - - -class NoOpTrialContext(det.TrialContext): - """ - NoOpTrial needs batch sizes. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._per_slot_batch_size, self._global_batch_size = util.calculate_batch_sizes( - self.get_hparams(), - self.env.experiment_config.slots_per_trial(), - "NoOpTrial", - ) - - def get_per_slot_batch_size(self) -> int: - return self._per_slot_batch_size - - def get_global_batch_size(self) -> int: - return self._global_batch_size - - -class NoOpTrialController(det.TrialController): - """ - A trial class which does nothing (except for maybe sleep) during training - and validation. For testing purposes. - """ - - CHECKPOINT_FILENAME = "no_op_checkpoint" - - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self.metric_writer = self.create_metric_writer() - - check_startup_hook_ran = self.env.hparams.get("check_startup_hook_ran", False) - if check_startup_hook_ran: - check.true(os.path.isfile("startup-hook-ran"), "File should exists.") - - self._batch_size = self.context.get_per_slot_batch_size() - self.nan_probability_validate = self.env.hparams.get("nan_probability_validate", 0) - self.validation_set_size = 256 - self.train_batch_secs = self.env.hparams.get("training_batch_seconds", 0) - self.num_training_metrics = self.env.hparams.get("num_training_metrics", 1) - assert self.num_training_metrics > 0 - self.num_validation_metrics = self.env.hparams.get("num_validation_metrics", 1) - assert self.num_validation_metrics > 0 - self.metrics_progression = self.env.hparams.get("metrics_progression", "decreasing") - assert self.metrics_progression in ("increasing", "decreasing", "constant") - self.metrics_base = self.env.hparams.get("metrics_base", 0.9) - assert 0 < self.metrics_base < 1 - self.write_null = self.env.hparams.get("write_null", False) - - self.request_stop = self.env.hparams.get("request_stop", False) - - self.crash_on_startup = self.env.hparams.get("crash_on_startup", False) - self.non_chief_exit_immediately = self.env.hparams.get("non_chief_exit_immediately", False) - - self.wlsq = None - if self.workloads is None: - self.workloads, self.wlsq = layers.make_compatibility_workloads( - self.context._core, self.env, self.context.get_global_batch_size() - ) - - self.steps_completed = self.env.steps_completed - - if self.env.latest_checkpoint is not None: - with self.context._core.checkpoint.restore_path( - self.env.latest_checkpoint - ) as load_path: - self.load(pathlib.Path(load_path)) - else: - self.trained_steps = collections.Counter() - - @staticmethod - def from_trial(trial_inst: det.LegacyTrial, *args: Any, **kwargs: Any) -> det.TrialController: - return NoOpTrialController(*args, **kwargs) - - @staticmethod - def pre_execute_hook(env: det.EnvContext, distributed_backend: det._DistributedBackend) -> None: - pass - - def create_metric_writer(self) -> tensorboard.BatchMetricWriter: - return tensorboard.get_metric_writer() - - def run(self) -> None: - assert not self.crash_on_startup - if self.non_chief_exit_immediately: - if self.context.distributed.get_rank() != 0: - sys.exit() - else: - time.sleep(1800) - - for w, response_func in self.workloads: - if w.kind == workload.Workload.Kind.RUN_STEP: - response = self.train_for_step(w.step_id, w.num_batches) - elif w.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS: - response = self.compute_validation_metrics(w.step_id) - elif w.kind == workload.Workload.Kind.CHECKPOINT_MODEL: - metadata = {"steps_completed": self.steps_completed} - if self.is_chief: - with self.context._core.checkpoint.store_path(metadata) as ( - path, - storage_id, - ): - self.save(path) - response = {"uuid": storage_id} - else: - response = {} - else: - raise AssertionError("Unexpected workload: {}".format(w.kind)) - - response_func(response) - self.upload_tb_files() - - def steps_trained(self) -> int: - return sum(self.trained_steps.values()) - - def current_metric(self) -> float: - if self.metrics_progression == "constant": - return self.metrics_base - elif self.metrics_progression == "decreasing": - return self.metrics_base ** self.steps_trained() - elif self.metrics_progression == "increasing": - return 1 - (self.metrics_base ** self.steps_trained()) - else: - raise ValueError("Invalid `metrics_progression` {}".format(self.metrics_progression)) - - def train_for_step(self, step_id: int, num_batches: int) -> Dict[str, Any]: - if self.request_stop: - self.context.set_stop_requested(True) - time.sleep(self.train_batch_secs * num_batches) - if self.write_null: - with open("/dev/stdout", "wb") as f: - f.write(b"\x00") - self.trained_steps[step_id] += 1 - metrics = {name: self.current_metric() for name in ["loss", *self.training_metrics()]} - response = { - "metrics": det.util.make_metrics( - self._batch_size * num_batches, [metrics] * num_batches - ), - "stop_requested": self.context.get_stop_requested(), - } - self.steps_completed += num_batches - self.metric_writer.on_train_step_end( - self.steps_completed, - metrics=response["metrics"]["avg_metrics"], - batch_metrics=response["metrics"]["batch_metrics"], - ) - return response - - def compute_validation_metrics(self, step_id: int) -> Dict[str, Any]: - metrics = { - name: ( - np.nan if random.random() < self.nan_probability_validate else self.current_metric() - ) - for name in ["validation_error", *self.validation_metrics()] - } - response = { - "metrics": {"validation_metrics": metrics, "num_inputs": self.validation_set_size}, - "stop_requested": self.context.get_stop_requested(), - } - return response - - def training_metrics(self) -> Dict[str, Any]: - return {"metric_{}".format(i): None for i in range(1, self.num_training_metrics)} - - def validation_metrics(self) -> Dict[str, Any]: - return { - "validation_metric_{}".format(i): None for i in range(1, self.num_validation_metrics) - } - - def batch_size(self) -> int: - return self._batch_size - - def save(self, path: pathlib.Path) -> None: - fpath = path.joinpath(self.CHECKPOINT_FILENAME) - logging.info("Saving checkpoint {}, steps_trained {}".format(fpath, self.steps_trained())) - with fpath.open("w") as f: - json.dump(self.trained_steps, f, sort_keys=True, indent=4) - path.chmod(0o777) - fpath.chmod(0o777) - - wlsq_path = path.joinpath("workload_sequencer.pkl") - if self.wlsq is not None: - with wlsq_path.open("wb") as f: - pickle.dump(self.wlsq.get_state(), f) - - def load(self, path: pathlib.Path) -> None: - fpath = path.joinpath(self.CHECKPOINT_FILENAME) - with fpath.open("r") as f: - jbody = {int(k): v for k, v in json.load(f).items()} - for k, v in jbody.items(): - check.gt_eq(k, 0) - check.is_type(v, int) - check.gt_eq(v, 0) - self.trained_steps = collections.Counter(jbody) - logging.info( - "Loaded checkpoint {}, steps_trained {}".format(fpath, self.steps_trained()) - ) - - wlsq_path = path.joinpath("workload_sequencer.pkl") - if self.wlsq is not None and wlsq_path.exists(): - with wlsq_path.open("rb") as f: - self.wlsq.load_state(pickle.load(f)) - - -class NoOpTrial(det.LegacyTrial): - trial_context_class = NoOpTrialContext - trial_controller_class = NoOpTrialController - - def __init__(self, context: det.TrialContext) -> None: - self.context = context diff --git a/e2e_tests/tests/fixtures/custom_searcher_exp/single.yaml b/e2e_tests/tests/fixtures/custom_searcher_exp/single.yaml deleted file mode 100644 index 2caac8fd2dd..00000000000 --- a/e2e_tests/tests/fixtures/custom_searcher_exp/single.yaml +++ /dev/null @@ -1,24 +0,0 @@ -name: noop_single -checkpoint_storage: - type: shared_fs - host_path: /tmp - storage_path: determined-integration-checkpoints - save_trial_best: 30 -hyperparameters: - global_batch_size: 32 - metrics_progression: decreasing - metrics_base: 0.9 -searcher: - metric: validation_error - smaller_is_better: true - name: single - max_length: - batches: 3000 -reproducibility: - experiment_seed: 999 -min_validation_period: - batches: 100 -min_checkpoint_period: - batches: 100 -max_restarts: 0 -entrypoint: model_def:NoOpTrial diff --git a/e2e_tests/tests/fixtures/failures/bad-image.yaml b/e2e_tests/tests/fixtures/failures/bad-image.yaml index b47628b005e..147154d081d 100644 --- a/e2e_tests/tests/fixtures/failures/bad-image.yaml +++ b/e2e_tests/tests/fixtures/failures/bad-image.yaml @@ -4,7 +4,5 @@ environment: searcher: name: single metric: error - max_length: - batches: 1000 max_restarts: 0 -entrypoint: failures:FailureTrial \ No newline at end of file +entrypoint: failures:FailureTrial diff --git a/e2e_tests/tests/fixtures/failures/bad-pbs-option.yaml b/e2e_tests/tests/fixtures/failures/bad-pbs-option.yaml index 0b4915404d9..cd139b34bcb 100644 --- a/e2e_tests/tests/fixtures/failures/bad-pbs-option.yaml +++ b/e2e_tests/tests/fixtures/failures/bad-pbs-option.yaml @@ -5,7 +5,5 @@ pbs: searcher: name: single metric: error - max_length: - batches: 1000 max_restarts: 0 entrypoint: failures:FailureTrial diff --git a/e2e_tests/tests/fixtures/failures/bad-slurm-option.yaml b/e2e_tests/tests/fixtures/failures/bad-slurm-option.yaml index 2dec4157080..dd86a118ca8 100644 --- a/e2e_tests/tests/fixtures/failures/bad-slurm-option.yaml +++ b/e2e_tests/tests/fixtures/failures/bad-slurm-option.yaml @@ -5,7 +5,5 @@ slurm: searcher: name: single metric: error - max_length: - batches: 1000 max_restarts: 0 entrypoint: failures:FailureTrial diff --git a/e2e_tests/tests/fixtures/failures/docker-login-failure.yaml b/e2e_tests/tests/fixtures/failures/docker-login-failure.yaml index df38384e8d4..e1aa1fa5e6d 100644 --- a/e2e_tests/tests/fixtures/failures/docker-login-failure.yaml +++ b/e2e_tests/tests/fixtures/failures/docker-login-failure.yaml @@ -7,7 +7,5 @@ environment: searcher: name: single metric: error - max_length: - batches: 1000 max_restarts: 0 entrypoint: failures:FailureTrial diff --git a/e2e_tests/tests/fixtures/failures/slurm-requested-node-not-available.yaml b/e2e_tests/tests/fixtures/failures/slurm-requested-node-not-available.yaml index d0abe3ca184..e30a9442de4 100644 --- a/e2e_tests/tests/fixtures/failures/slurm-requested-node-not-available.yaml +++ b/e2e_tests/tests/fixtures/failures/slurm-requested-node-not-available.yaml @@ -4,7 +4,5 @@ resources: searcher: name: single metric: error - max_length: - batches: 1000 max_restarts: 0 -entrypoint: failures:FailureTrial \ No newline at end of file +entrypoint: failures:FailureTrial diff --git a/e2e_tests/tests/fixtures/failures/unsupported-slurm-option.yaml b/e2e_tests/tests/fixtures/failures/unsupported-slurm-option.yaml index 6959e078e6f..5ec5016a0fb 100644 --- a/e2e_tests/tests/fixtures/failures/unsupported-slurm-option.yaml +++ b/e2e_tests/tests/fixtures/failures/unsupported-slurm-option.yaml @@ -6,7 +6,5 @@ slurm: searcher: name: single metric: error - max_length: - batches: 1000 max_restarts: 0 entrypoint: failures:FailureTrial diff --git a/e2e_tests/tests/fixtures/hpc/embedded-quotes.yaml b/e2e_tests/tests/fixtures/hpc/embedded-quotes.yaml index 27f4a2154e7..edc5d84676f 100644 --- a/e2e_tests/tests/fixtures/hpc/embedded-quotes.yaml +++ b/e2e_tests/tests/fixtures/hpc/embedded-quotes.yaml @@ -1,11 +1,9 @@ -description: metric_maker +description: embedded-quotes data: user_defined_key: datakey="datavalue with embedded " searcher: name: single metric: error - max_length: - batches: 1000 max_restarts: 0 entrypoint: python3 data_validator.py diff --git a/e2e_tests/tests/fixtures/hpc/embedded-single-quote.yaml b/e2e_tests/tests/fixtures/hpc/embedded-single-quote.yaml index ae5dbdfb7b6..0c2e8efc297 100644 --- a/e2e_tests/tests/fixtures/hpc/embedded-single-quote.yaml +++ b/e2e_tests/tests/fixtures/hpc/embedded-single-quote.yaml @@ -1,11 +1,9 @@ -description: metric_maker +description: embedded-single-quote data: user_defined_key: datakey="datavalue with ' embedded " searcher: name: single metric: error - max_length: - batches: 1000 max_restarts: 0 entrypoint: python3 data_validator.py diff --git a/e2e_tests/tests/fixtures/mnist_pytorch/adaptive_short.yaml b/e2e_tests/tests/fixtures/mnist_pytorch/adaptive_short.yaml index 378d8bdb9fc..bcbd69e79d1 100644 --- a/e2e_tests/tests/fixtures/mnist_pytorch/adaptive_short.yaml +++ b/e2e_tests/tests/fixtures/mnist_pytorch/adaptive_short.yaml @@ -1,5 +1,5 @@ name: mnist_pytorch_adaptive -entrypoint: python3 train.py +entrypoint: python3 train.py --batches 400 hyperparameters: learning_rate: type: double @@ -25,8 +25,8 @@ searcher: name: adaptive_asha metric: validation_loss smaller_is_better: true - max_length: - batches: 400 + time_metric: batches + max_time: 400 max_trials: 5 max_rungs: 2 mode: aggressive diff --git a/e2e_tests/tests/fixtures/mnist_pytorch/const-profiling.yaml b/e2e_tests/tests/fixtures/mnist_pytorch/const-profiling.yaml index 81434d144ab..fe2b42ce998 100644 --- a/e2e_tests/tests/fixtures/mnist_pytorch/const-profiling.yaml +++ b/e2e_tests/tests/fixtures/mnist_pytorch/const-profiling.yaml @@ -8,7 +8,5 @@ hyperparameters: searcher: name: single metric: validation_loss - max_length: - epochs: 1 smaller_is_better: true -entrypoint: python3 profiling.py +entrypoint: python3 profiling.py --epochs 1 diff --git a/e2e_tests/tests/fixtures/mnist_pytorch/const-pytorch11.yaml b/e2e_tests/tests/fixtures/mnist_pytorch/const-pytorch11.yaml index 7f07423e10f..b6265526c56 100644 --- a/e2e_tests/tests/fixtures/mnist_pytorch/const-pytorch11.yaml +++ b/e2e_tests/tests/fixtures/mnist_pytorch/const-pytorch11.yaml @@ -1,5 +1,5 @@ name: mnist_pytorch_const -entrypoint: python3 train.py +entrypoint: python3 train.py --batches 200 hyperparameters: learning_rate: 0.001 dropout: 0.5 @@ -11,8 +11,6 @@ hyperparameters: searcher: name: single metric: validation_loss - max_length: - batches: 200 smaller_is_better: true max_restarts: 0 # bind-mounting the /tmp/work_dir directory for the mnist_pytorch experiment diff --git a/e2e_tests/tests/fixtures/mnist_pytorch/distributed-stop-requested.yaml b/e2e_tests/tests/fixtures/mnist_pytorch/distributed-stop-requested.yaml index cc359951d73..eaced3e5c82 100644 --- a/e2e_tests/tests/fixtures/mnist_pytorch/distributed-stop-requested.yaml +++ b/e2e_tests/tests/fixtures/mnist_pytorch/distributed-stop-requested.yaml @@ -11,7 +11,7 @@ max_restarts: 0 searcher: name: single metric: validation_loss - max_length: - batches: 937 #60,000 training images with batch size 64 smaller_is_better: true -entrypoint: python3 -m determined.launch.torch_distributed python3 stop_requested_model_def.py +entrypoint: >- + python3 -m determined.launch.torch_distributed + python3 stop_requested_model_def.py diff --git a/e2e_tests/tests/fixtures/mnist_pytorch/failable.yaml b/e2e_tests/tests/fixtures/mnist_pytorch/failable.yaml index fc4c60de3b1..9068ebf312f 100644 --- a/e2e_tests/tests/fixtures/mnist_pytorch/failable.yaml +++ b/e2e_tests/tests/fixtures/mnist_pytorch/failable.yaml @@ -12,7 +12,5 @@ max_restarts: 0 searcher: name: single metric: validation_loss - max_length: - batches: 8 smaller_is_better: true entrypoint: python3 failable_model_def.py diff --git a/e2e_tests/tests/fixtures/mnist_pytorch/failable_model_def.py b/e2e_tests/tests/fixtures/mnist_pytorch/failable_model_def.py index af6bfab602b..647cf7a98f5 100644 --- a/e2e_tests/tests/fixtures/mnist_pytorch/failable_model_def.py +++ b/e2e_tests/tests/fixtures/mnist_pytorch/failable_model_def.py @@ -27,6 +27,7 @@ def train_batch(self, batch, epoch_idx, batch_idx): trial = MNistFailable(context=train_context, hparams=info.trial.hparams) trainer = pytorch.Trainer(trial, train_context) trainer.fit( + max_length=pytorch.Batch(8), checkpoint_policy="none", checkpoint_period=pytorch.Batch(3), validation_period=pytorch.Batch(1), diff --git a/e2e_tests/tests/fixtures/mnist_pytorch/profiling.py b/e2e_tests/tests/fixtures/mnist_pytorch/profiling.py index 2444662a10a..461983baff3 100644 --- a/e2e_tests/tests/fixtures/mnist_pytorch/profiling.py +++ b/e2e_tests/tests/fixtures/mnist_pytorch/profiling.py @@ -1,3 +1,4 @@ +import argparse import logging import train as mnist_pytorch @@ -6,7 +7,7 @@ from determined import pytorch -def run(): +def run(epochs): """Initializes the trial and runs the training loop with profiling enabled.""" info = det.get_cluster_info() @@ -15,10 +16,19 @@ def run(): with pytorch.init() as train_context: trial = mnist_pytorch.MNistTrial(train_context, hparams=info.trial.hparams) trainer = pytorch.Trainer(trial, train_context) - trainer.fit(latest_checkpoint=info.latest_checkpoint, profiling_enabled=True) + trainer.fit( + max_length=pytorch.Epoch(epochs), + latest_checkpoint=info.latest_checkpoint, + profiling_enabled=True, + ) if __name__ == "__main__": # Configure logging logging.basicConfig(level=logging.INFO, format=det.LOG_FORMAT) - run() + + parser = argparse.ArgumentParser() + parser.add_argument("--epochs", type=int, default=1) + args = parser.parse_args() + + run(args.epochs) diff --git a/e2e_tests/tests/fixtures/mnist_pytorch/random.yaml b/e2e_tests/tests/fixtures/mnist_pytorch/random.yaml index 6b70ec6b893..110b0db07c8 100644 --- a/e2e_tests/tests/fixtures/mnist_pytorch/random.yaml +++ b/e2e_tests/tests/fixtures/mnist_pytorch/random.yaml @@ -1,7 +1,5 @@ -name: mnist_pytorch_adaptive_random -data: - url: https://s3-us-west-2.amazonaws.com/determined-ai-test-data/pytorch_mnist.tar.gz -entrypoint: python3 train.py +name: mnist_pytorch_random +entrypoint: python3 train.py --batches 10 hyperparameters: global_batch_size: 64 learning_rate: @@ -28,8 +26,6 @@ searcher: name: random metric: validation_loss smaller_is_better: true - max_length: - batches: 10 max_trials: 50 max_restarts: 0 # bind-mounting the /tmp/work_dir directory for the mnist_pytorch experiment diff --git a/e2e_tests/tests/fixtures/mnist_pytorch/stop_requested_model_def.py b/e2e_tests/tests/fixtures/mnist_pytorch/stop_requested_model_def.py index 6b0e164359b..63862f61b6c 100644 --- a/e2e_tests/tests/fixtures/mnist_pytorch/stop_requested_model_def.py +++ b/e2e_tests/tests/fixtures/mnist_pytorch/stop_requested_model_def.py @@ -24,4 +24,4 @@ def __init__( with pytorch.init() as train_context: trial = MNistTrialStopRequested(train_context, hparams=info.trial.hparams) trainer = pytorch.Trainer(trial, train_context) - trainer.fit(latest_checkpoint=info.latest_checkpoint) + trainer.fit(max_length=pytorch.Epoch(1), latest_checkpoint=info.latest_checkpoint) diff --git a/e2e_tests/tests/fixtures/noop/train.py b/e2e_tests/tests/fixtures/noop/train.py index 8ac84bc82fb..71684e5c443 100644 --- a/e2e_tests/tests/fixtures/noop/train.py +++ b/e2e_tests/tests/fixtures/noop/train.py @@ -18,7 +18,7 @@ import pathlib import sys import time -from typing import Iterator, Optional, Tuple +from typing import Optional, Tuple import determined as det from determined import core @@ -64,8 +64,6 @@ def main( last_action_id, steps_completed = load_state(trial_id, path) starting_action_id = last_action_id + 1 - operations = None # type: Iterator[core.SearcherOperation] - for action_id, action in enumerate(actions[starting_action_id:], start=starting_action_id): logging.info(f"executing {action}") if action["action"] == "exit": @@ -88,12 +86,6 @@ def main( elif action["action"] == "log": msg = base64.b64decode(action["base64"]).decode("utf8") logging.log(action["level"], msg) - elif action["action"] == "complete_searcher_operation": - # Get operations if we haven't already. - if not operations: - operations = core_context.searcher.operations(core.SearcherMode.ChiefOnly) - op = next(operations) - op.report_completed(action["metric"]) else: raise ValueError(f"unexpected action type: {action}") diff --git a/e2e_tests/tests/fixtures/ports-proxy/config.yaml b/e2e_tests/tests/fixtures/ports-proxy/config.yaml index cc8d2d44734..738df5c36b0 100644 --- a/e2e_tests/tests/fixtures/ports-proxy/config.yaml +++ b/e2e_tests/tests/fixtures/ports-proxy/config.yaml @@ -5,7 +5,6 @@ resources: slots_per_trial: 2 searcher: - max_length: 10000000 name: grid metric: x max_concurrent_trials: 2 diff --git a/e2e_tests/tests/nightly/test_distributed.py b/e2e_tests/tests/nightly/test_distributed.py index 76da98c1433..f3b0fa1c60a 100644 --- a/e2e_tests/tests/nightly/test_distributed.py +++ b/e2e_tests/tests/nightly/test_distributed.py @@ -4,7 +4,6 @@ import pytest -from determined.common import util from tests import api_utils from tests import config as conf from tests import experiment as exp @@ -14,7 +13,8 @@ def test_mnist_pytorch_distributed() -> None: sess = api_utils.user_session() config = conf.load_config(conf.tutorials_path("mnist_pytorch/distributed.yaml")) - config = conf.set_max_length(config, {"batches": 200}) + assert "--epochs 1" in config["entrypoint"], "update test to match tutorial" + config["entrypoint"] = config["entrypoint"].replace("--epochs 1", "--batches 64") exp.run_basic_test_with_temp_config(sess, config, conf.fixtures_path("mnist_pytorch"), 1) @@ -39,7 +39,7 @@ def test_hf_trainer_api_integration() -> None: def test_gpt_neox_zero1() -> None: sess = api_utils.user_session() config = conf.load_config(conf.deepspeed_examples_path("gpt_neox/zero1.yaml")) - config = conf.set_max_length(config, {"batches": 100}) + config["searcher"]["max_length"] = {"batches": 100} config = conf.set_min_validation_period(config, {"batches": 100}) # Changing to satisfy cluter size and gpu mem limitations. config = conf.set_slots_per_trial(config, 8) @@ -79,7 +79,7 @@ def test_textual_inversion_stable_diffusion_finetune() -> None: "textual_inversion_stable_diffusion/finetune_const_advanced.yaml" ) ) - config = conf.set_max_length(config, 10) + config["hyperparameters"]["training"]["num_sgd_steps"] = 10 try: config = conf.set_environment_variables( config, [f'HF_AUTH_TOKEN={os.environ["HF_READ_ONLY_TOKEN"]}'] @@ -111,7 +111,7 @@ def test_textual_inversion_stable_diffusion_generate() -> None: conf.diffusion_examples_path("textual_inversion_stable_diffusion/generate_grid.yaml") ) # Shorten the Experiment and reduce to two Trials. - config = conf.set_max_length(config, 2) + config["hyperparameters"]["num_batches"] = 2 prompt_vals = config["hyperparameters"]["call_kwargs"]["prompt"]["vals"] config["hyperparameters"]["call_kwargs"]["guidance_scale"] = 7.5 while len(prompt_vals) > 1: @@ -131,90 +131,6 @@ def test_textual_inversion_stable_diffusion_generate() -> None: raise k -@pytest.mark.distributed -@pytest.mark.gpu_required -def test_hf_trainer_image_classification_deepspeed_autotuning() -> None: - sess = api_utils.user_session() - test_dir = "hf_image_classification" - config_path = conf.hf_trainer_examples_path(f"{test_dir}/deepspeed.yaml") - config = conf.load_config(config_path) - with tempfile.NamedTemporaryFile() as tf: - with open(tf.name, "w") as f: - util.yaml_safe_dump(config, f) - # expected_trials=1 in run_basic_autotuning_test because the search runner only generates - # a single trial (which in turn generates a second, possibly multi-trial experiment). - _ = exp.run_basic_autotuning_test( - sess, - tf.name, - conf.hf_trainer_examples_path(test_dir), - 1, - search_method_name="asha", - ) - - -@pytest.mark.distributed -@pytest.mark.gpu_required -def test_hf_trainer_language_modeling_deepspeed_autotuning() -> None: - sess = api_utils.user_session() - test_dir = "hf_language_modeling" - config_path = conf.hf_trainer_examples_path(f"{test_dir}/deepspeed.yaml") - config = conf.load_config(config_path) - with tempfile.NamedTemporaryFile() as tf: - with open(tf.name, "w") as f: - util.yaml_safe_dump(config, f) - # expected_trials=1 in run_basic_autotuning_test because the search runner only generates - # a single trial (which in turn generates a second, possibly multi-trial experiment). - _ = exp.run_basic_autotuning_test( - sess, - tf.name, - conf.hf_trainer_examples_path(test_dir), - 1, - search_method_name="binary", - ) - - -@pytest.mark.distributed -@pytest.mark.gpu_required -def test_torchvision_core_api_deepspeed_autotuning() -> None: - sess = api_utils.user_session() - test_dir = "torchvision/core_api" - config_path = conf.deepspeed_autotune_examples_path(f"{test_dir}/deepspeed.yaml") - config = conf.load_config(config_path) - with tempfile.NamedTemporaryFile() as tf: - with open(tf.name, "w") as f: - util.yaml_safe_dump(config, f) - # expected_trials=1 in run_basic_autotuning_test because the search runner only generates - # a single trial (which in turn generates a second, possibly multi-trial experiment). - _ = exp.run_basic_autotuning_test( - sess, - tf.name, - conf.deepspeed_autotune_examples_path(test_dir), - 1, - search_method_name="asha", - ) - - -@pytest.mark.distributed -@pytest.mark.gpu_required -def test_torchvision_deepspeed_trial_deepspeed_autotuning() -> None: - sess = api_utils.user_session() - test_dir = "torchvision/deepspeed_trial" - config_path = conf.deepspeed_autotune_examples_path(f"{test_dir}/deepspeed.yaml") - config = conf.load_config(config_path) - with tempfile.NamedTemporaryFile() as tf: - with open(tf.name, "w") as f: - util.yaml_safe_dump(config, f) - # expected_trials=1 in run_basic_autotuning_test because the search runner only generates - # a single trial (which in turn generates a second, possibly multi-trial experiment). - _ = exp.run_basic_autotuning_test( - sess, - tf.name, - conf.deepspeed_autotune_examples_path(test_dir), - 1, - search_method_name="random", - ) - - @pytest.mark.distributed @pytest.mark.gpu_required def test_torch_batch_process_generate_embedding() -> None: diff --git a/examples/Makefile b/examples/Makefile index c650066b63b..2b7d6e60d78 100644 --- a/examples/Makefile +++ b/examples/Makefile @@ -7,9 +7,6 @@ CV_EXAMPLES_DIRS := $(patsubst computer_vision/%/., build/%.tgz, $(CV_EXAMPLES)) DEEPSPEED_EXAMPLES := $(wildcard deepspeed/*/.) DEEPSPEED_EXAMPLES_DIRS := $(patsubst deepspeed/%/., build/%.tgz, $(DEEPSPEED_EXAMPLES)) -DEEPSPEED_AUTOTUNE_EXAMPLES := $(wildcard deepspeed_autotune/*/.) -DEEPSPEED_AUTOTUNE_EXAMPLES_DIRS := $(patsubst deepspeed_autotune/%/., build/%.tgz, $(DEEPSPEED_AUTOTUNE_EXAMPLES)) - HF_TRAINER_EXAMPLES := $(wildcard hf_trainer_api/*/.) HF_TRAINER_EXAMPLES_DIRS := $(patsubst hf_trainer_api/%/., build/%.tgz, $(HF_TRAINER_EXAMPLES)) @@ -26,7 +23,7 @@ IGNORE := \( -path ./build -o -path ./tests -o -name __pycache__ -o -name \*.pyc # SRCS is a list of all files that could affect our outputs. SRCS := $(shell find . $(IGNORE) -prune -o -type f -print | sort) -build/stamp: $(TUTORIAL_EXAMPLES_DIRS) $(CV_EXAMPLES_DIRS) $(DEEPSPEED_EXAMPLES_DIRS) $(DEEPSPEED_AUTOTUNE_EXAMPLES_DIRS) $(HF_TRAINER_EXAMPLES_DIRS) $(DIFFUSION_EXAMPLES_DIRS) $(FEATURES_EXAMPLES_DIRS) +build/stamp: $(TUTORIAL_EXAMPLES_DIRS) $(CV_EXAMPLES_DIRS) $(DEEPSPEED_EXAMPLES_DIRS) $(HF_TRAINER_EXAMPLES_DIRS) $(DIFFUSION_EXAMPLES_DIRS) $(FEATURES_EXAMPLES_DIRS) touch $@ .PHONY: build diff --git a/examples/computer_vision/iris_tf_keras/adaptive.yaml b/examples/computer_vision/iris_tf_keras/adaptive.yaml index 5c91087dac6..70863f27e59 100644 --- a/examples/computer_vision/iris_tf_keras/adaptive.yaml +++ b/examples/computer_vision/iris_tf_keras/adaptive.yaml @@ -1,7 +1,4 @@ name: iris_tf_keras_adaptive_search -data: - train_url: http://download.tensorflow.org/data/iris_training.csv - test_url: http://download.tensorflow.org/data/iris_test.csv environment: image: cpu: determinedai/tensorflow-ngc-dev:0736b6d @@ -25,7 +22,9 @@ searcher: name: adaptive_asha metric: val_categorical_accuracy smaller_is_better: false - max_length: - batches: 6400 + time_metric: batches + max_time: 6400 max_trials: 512 -entrypoint: python3 -m determined.launch.horovod --autohorovod --trial model_def:IrisTrial +entrypoint: >- + python3 -m determined.launch.tensorflow -- + python3 train.py diff --git a/examples/computer_vision/iris_tf_keras/const.yaml b/examples/computer_vision/iris_tf_keras/const.yaml index 595a754272c..37fd8de1e66 100644 --- a/examples/computer_vision/iris_tf_keras/const.yaml +++ b/examples/computer_vision/iris_tf_keras/const.yaml @@ -1,7 +1,4 @@ name: iris_tf_keras_const -data: - train_url: http://download.tensorflow.org/data/iris_training.csv - test_url: http://download.tensorflow.org/data/iris_test.csv environment: image: cpu: determinedai/tensorflow-ngc-dev:0736b6d @@ -15,6 +12,6 @@ searcher: name: single metric: val_categorical_accuracy smaller_is_better: false - max_length: - batches: 5000 -entrypoint: python3 -m determined.launch.horovod --autohorovod --trial model_def:IrisTrial +entrypoint: >- + python3 -m determined.launch.tensorflow -- + python3 train.py diff --git a/examples/computer_vision/iris_tf_keras/distributed.yaml b/examples/computer_vision/iris_tf_keras/distributed.yaml index 4dedcdec475..35ee042f776 100644 --- a/examples/computer_vision/iris_tf_keras/distributed.yaml +++ b/examples/computer_vision/iris_tf_keras/distributed.yaml @@ -1,7 +1,4 @@ name: iris_tf_keras_distributed -data: - train_url: http://download.tensorflow.org/data/iris_training.csv - test_url: http://download.tensorflow.org/data/iris_test.csv environment: image: cpu: determinedai/tensorflow-ngc-dev:0736b6d @@ -17,6 +14,6 @@ searcher: name: single metric: val_categorical_accuracy smaller_is_better: false - max_length: - batches: 2500 -entrypoint: python3 -m determined.launch.horovod --autohorovod --trial model_def:IrisTrial +entrypoint: >- + python3 -m determined.launch.tensorflow -- + python3 train.py diff --git a/examples/computer_vision/iris_tf_keras/model_def.py b/examples/computer_vision/iris_tf_keras/model_def.py deleted file mode 100644 index 1624a7705de..00000000000 --- a/examples/computer_vision/iris_tf_keras/model_def.py +++ /dev/null @@ -1,105 +0,0 @@ -""" -This example shows how you could use Keras `Sequence`s and multiprocessing/multithreading for Keras -models in Determined. - -Useful References: - http://docs.determined.ai/latest/keras.html - https://keras.io/utils/ - -Based off of: https://medium.com/@nickbortolotti/iris-species-categorization-using-tf-keras-tf-data- - and-differences-between-eager-mode-on-and-off-9b4693e0b22 -""" -from typing import List - -import pandas as pd -import tensorflow as tf -from tensorflow.keras.layers import Dense, Input -from tensorflow.keras.losses import categorical_crossentropy -from tensorflow.keras.metrics import categorical_accuracy -from tensorflow.keras.models import Model -from tensorflow.keras.optimizers.legacy import RMSprop -from tensorflow.keras.utils import to_categorical - -from determined import keras - -# Constants about the data set. -NUM_CLASSES = 3 - -# The first row of each data set is not a typical CSV header with column labels, but rather a -# dataset descriptor of the following format: -# -# ,,,, -# -# The remaining rows then contain observations, with the four features followed by label. The -# label values in the observation rows take on the values 0, 1, or 2 which correspond to the -# three species in the header. Define the columns explicitly here so that we can more easily -# separate features and labels below. -LABEL_HEADER = "Species" -DS_COLUMNS = [ - "SepalLength", - "SepalWidth", - "PetalLength", - "PetalWidth", - LABEL_HEADER, -] - - -class IrisTrial(keras.TFKerasTrial): - def __init__(self, context: keras.TFKerasTrialContext) -> None: - self.context = context - - def build_model(self) -> Model: - """ - Define model for iris classification. - - This is a simple model with one hidden layer to predict iris species (setosa, versicolor, or - virginica) based on four input features (length and width of sepals and petals). - """ - inputs = Input(shape=(4,)) - dense1 = Dense(self.context.get_hparam("layer1_dense_size"))(inputs) - dense2 = Dense(NUM_CLASSES, activation="softmax")(dense1) - - # Wrap the model. - model = self.context.wrap_model(Model(inputs=inputs, outputs=dense2)) - - # Create and wrap the optimizer. - optimizer = RMSprop( - lr=self.context.get_hparam("learning_rate"), - decay=self.context.get_hparam("learning_rate_decay"), - ) - optimizer = self.context.wrap_optimizer(optimizer) - - model.compile( - optimizer, - categorical_crossentropy, - [categorical_accuracy], - ) - - return model - - def keras_callbacks(self) -> List[tf.keras.callbacks.Callback]: - return [keras.callbacks.TensorBoard(update_freq="batch", profile_batch=0, histogram_freq=1)] - - def build_training_data_loader(self) -> keras.InputData: - # Ignore header line and read the training CSV observations into a pandas DataFrame. - train = pd.read_csv(self.context.get_data_config()["train_url"], names=DS_COLUMNS, header=0) - train_features, train_labels = train, train.pop(LABEL_HEADER) - - # Since we're building a classifier, convert the labels in the raw - # dataset (0, 1, or 2) to one-hot vector encodings that we'll to - # construct the Sequence data loaders that Determined expects. - train_labels_categorical = to_categorical(train_labels, num_classes=3) - - return train_features.values, train_labels_categorical - - def build_validation_data_loader(self) -> keras.InputData: - # Ignore header line and read the test CSV observations into a pandas DataFrame. - test = pd.read_csv(self.context.get_data_config()["test_url"], names=DS_COLUMNS, header=0) - test_features, test_labels = test, test.pop(LABEL_HEADER) - - # Since we're building a classifier, convert the labels in the raw - # dataset (0, 1, or 2) to one-hot vector encodings that we'll to - # construct the Sequence data loaders that Determined expects. - test_labels_categorical = to_categorical(test_labels, num_classes=3) - - return test_features.values, test_labels_categorical diff --git a/examples/computer_vision/iris_tf_keras/train.py b/examples/computer_vision/iris_tf_keras/train.py new file mode 100644 index 00000000000..2e5fccc252c --- /dev/null +++ b/examples/computer_vision/iris_tf_keras/train.py @@ -0,0 +1,143 @@ +""" +This example shows you how to train a model with Determined's keras callback. + +Useful References: + https://docs.determined.ai/latest/reference/training/api-keras-reference.html + https://keras.io/api/ + +Based off of: https://medium.com/@nickbortolotti/iris-species-categorization-using-tf-keras-tf-data- + and-differences-between-eager-mode-on-and-off-9b4693e0b22 +""" +import argparse +import logging +from typing import List + +import pandas as pd +from tensorflow.keras import layers, losses, metrics, models, utils +from tensorflow.keras.optimizers import legacy + +import determined as det +import determined.keras + +# Where to download data from. +TRAIN_DATA = "http://download.tensorflow.org/data/iris_training.csv" +TEST_DATA = "http://download.tensorflow.org/data/iris_test.csv" + +# Constants about the data set. +NUM_CLASSES = 3 + +# The first row of each data set is not a typical CSV header with column labels, but rather a +# dataset descriptor of the following format: +# +# ,,,, +# +# The remaining rows then contain observations, with the four features followed by label. The +# label values in the observation rows take on the values 0, 1, or 2 which correspond to the +# three species in the header. Define the columns explicitly here so that we can more easily +# separate features and labels below. +LABEL_HEADER = "Species" +DS_COLUMNS = [ + "SepalLength", + "SepalWidth", + "PetalLength", + "PetalWidth", + LABEL_HEADER, +] + + +def get_train_data(): + # Ignore header line and read the training CSV observations into a pandas DataFrame. + train = pd.read_csv(TRAIN_DATA, names=DS_COLUMNS, header=0) + train_features, train_labels = train, train.pop(LABEL_HEADER) + + # Since we're building a classifier, convert the labels in the raw + # dataset (0, 1, or 2) to one-hot vector encodings that we'll to + # construct the Sequence data loaders that Determined expects. + train_labels_categorical = utils.to_categorical(train_labels, num_classes=3) + + return train_features.values, train_labels_categorical + + +def get_test_data(): + test = pd.read_csv(TEST_DATA, names=DS_COLUMNS, header=0) + test_features, test_labels = test, test.pop(LABEL_HEADER) + test_labels_categorical = utils.to_categorical(test_labels, num_classes=3) + return test_features.values, test_labels_categorical + + +def main(core_context, strategy, checkpoint, continue_id, hparams, epochs): + # Download train and test data. + train_x, train_y = get_train_data() + validation_data = get_test_data() + + # Create and compile the model within a strategy's scope. + with strategy.scope(): + inputs = layers.Input(shape=(4,)) + dense1 = layers.Dense(hparams["layer1_dense_size"])(inputs) + dense2 = layers.Dense(NUM_CLASSES, activation="softmax")(dense1) + model = models.Model(inputs=inputs, outputs=dense2) + + optimizer = legacy.RMSprop( + lr=hparams["learning_rate"], + decay=hparams["learning_rate_decay"], + ) + + model.compile( + optimizer, + losses.categorical_crossentropy, + [metrics.categorical_accuracy], + ) + + # Create the main DeterminedCallback that connects training to the Determined cluster. + det_cb = det.keras.DeterminedCallback( + core_context, + checkpoint=checkpoint, + continue_id=continue_id, + # Iris epochs are very short, so we don't even bother to save checkpoints until we finish. + checkpoint_epochs=0, + ) + + # Also include a Determined-aware version of the Keras' TensorBoard callback. + tb_cb = det.keras.TensorBoard( + core_context, update_freq="batch", profile_batch=0, histogram_freq=1 + ) + + # Call model.fit() with our callbacks. + model.fit( + x=train_x, + y=train_y, + batch_size=hparams["global_batch_size"], + validation_data=validation_data, + epochs=epochs, + callbacks=[det_cb, tb_cb], + ) + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format=det.LOG_FORMAT) + + parser = argparse.ArgumentParser() + parser.add_argument("--epochs", type=int, default=100, help="how long to train for") + args = parser.parse_args() + + info = det.get_cluster_info() + if info and info.task_type == "TRIAL": + # We are a training a trial on-cluster. + continue_id = info.trial.trial_id + checkpoint = info.latest_checkpoint + # Use the hparams selected by the searcher for this trial. + hparams = info.trial.hparams + else: + # We are either in a notebook on-cluster or off-cluster entirely. + continue_id = "local-train-task" + checkpoint = None + # Pick some hparams for ourselves. + hparams = { + "learning_rate": 1.0e-4, + "learning_rate_decay": 1.0e-6, + "layer1_dense_size": 16, + "global_batch_size": 16, + } + + distributed, strategy = det.core.DistributedContext.from_tf_config() + with det.core.init(distributed=distributed) as core_context: + main(core_context, strategy, checkpoint, continue_id, hparams, args.epochs) diff --git a/examples/deepspeed/dcgan/README.md b/examples/deepspeed/dcgan/README.md new file mode 100644 index 00000000000..f0b9811b9c9 --- /dev/null +++ b/examples/deepspeed/dcgan/README.md @@ -0,0 +1,49 @@ +# DeepSpeed CIFAR Example +This example is adapted from the +[DCGAN example in the DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples/tree/master/training/gan) +repository. It is intended to demonstrate a simple usecase of DeepSpeed with Determined. + +## Files +* **model.py**: The DCGANTrial definition. +* **gan_model.py**: Network definitions for generator and discriminator. +* **data.py**: Dataset loading/downloading code. + +### Configuration Files +* **ds_config.json**: The DeepSpeed config file. +* **mnist.yaml**: Determined config to train the model on mnist on a cluster. + +## Data +This repo supports the same datasets as the original example: `["imagenet", "lfw", "lsun", "cifar10", "mnist", "fake", "celeba"]`. The `cifar10` and `mnist` datasets will be downloaded as needed, whereas the rest must be mounted on the agent. For `lsun`, the `data_config.classes` setting must be set. The `folder` dataset can be used to load an arbitrary torchvision `ImageFolder` that is mounted on the agent. + +## To Run Locally + +It is recommended to run this from within one of our agent docker images, found at +https://hub.docker.com/r/determinedai/pytorch-ngc/tags + +After installing docker and pulling an image, users can launch a container via +`docker run --gpus=all -v ~path/to/repo:/src/proj -it ` + +Install necessary dependencies via `pip install determined mpi4py` + +Then, run the following command: +``` +python trainer.py +``` + +Any additional configs can be specified in `mnist.yaml` and `ds_config.json` accordingly. + +## To Run on Cluster +If you have not yet installed Determined, installation instructions can be found +under `docs/install-admin.html` or at https://docs.determined.ai/latest/index.html + +Run the following command: +``` +det experiment create mnist.yaml . +``` +The other configurations can be run by specifying the appropriate configuration file in place +of `mnist.yaml`. + +## Results +Training `mnist` should yield reasonable looking fake digit images on the images tab in TensorBoard after ~5k steps. + +Training `cifar10` does not converge as convincingly, but should look image-like after ~10k steps. diff --git a/examples/deepspeed/dcgan/data.py b/examples/deepspeed/dcgan/data.py new file mode 100644 index 00000000000..c950df584d1 --- /dev/null +++ b/examples/deepspeed/dcgan/data.py @@ -0,0 +1,104 @@ +import contextlib +import os +from typing import cast + +import filelock +import torch +import torchvision.datasets as dset +import torchvision.transforms as transforms + +CHANNELS_BY_DATASET = { + "imagenet": 3, + "folder": 3, + "lfw": 3, + "lsun": 3, + "cifar10": 3, + "mnist": 1, + "fake": 3, + "celeba": 3, +} + + +def get_dataset(data_config: dict) -> torch.utils.data.Dataset: + if data_config.get("dataroot", None) is None: + if str(data_config.get("dataset"),"").lower() != "fake": + raise ValueError('`dataroot` parameter is required for dataset "%s"' + % data_config.get("dataset", "")) + else: + context = contextlib.nullcontext() + else: + # Ensure that only one local process attempts to download/validate datasets at once. + context = filelock.FileLock(os.path.join(data_config["dataroot"], ".lock")) + with context: + if data_config["dataset"] in ["imagenet", "folder", "lfw"]: + # folder dataset + dataset = dset.ImageFolder( + root=data_config["dataroot"], + transform=transforms.Compose( + [ + transforms.Resize(data_config["image_size"]), + transforms.CenterCrop(data_config["image_size"]), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ] + ), + ) + elif data_config["dataset"] == "lsun": + classes = [c + "_train" for c in data_config["classes"].split(",")] + dataset = dset.LSUN( + root=data_config["dataroot"], + classes=classes, + transform=transforms.Compose( + [ + transforms.Resize(data_config["image_size"]), + transforms.CenterCrop(data_config["image_size"]), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ] + ), + ) + elif data_config["dataset"] == "cifar10": + dataset = dset.CIFAR10( + root=data_config["dataroot"], + download=True, + transform=transforms.Compose( + [ + transforms.Resize(data_config["image_size"]), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ] + ), + ) + elif data_config["dataset"] == "mnist": + dataset = dset.MNIST( + root=data_config["dataroot"], + download=True, + transform=transforms.Compose( + [ + transforms.Resize(data_config["image_size"]), + transforms.ToTensor(), + transforms.Normalize((0.5,), (0.5,)), + ] + ), + ) + elif data_config["dataset"] == "fake": + dataset = dset.FakeData( + image_size=(3, data_config["image_size"], data_config["image_size"]), + transform=transforms.ToTensor(), + ) + elif data_config["dataset"] == "celeba": + dataset = dset.ImageFolder( + root=data_config["dataroot"], + transform=transforms.Compose( + [ + transforms.Resize(data_config["image_size"]), + transforms.CenterCrop(data_config["image_size"]), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ] + ), + ) + else: + unknown_dataset_name = data_config["dataset"] + raise Exception(f"Unknown dataset {unknown_dataset_name}") + return cast(torch.utils.data.Dataset, dataset) diff --git a/examples/deepspeed/dcgan/ds_config.json b/examples/deepspeed/dcgan/ds_config.json new file mode 100644 index 00000000000..708952b50b2 --- /dev/null +++ b/examples/deepspeed/dcgan/ds_config.json @@ -0,0 +1,15 @@ +{ + "train_batch_size": 64, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.0002, + "betas": [ + 0.5, + 0.999 + ], + "eps": 1e-8 + } + }, + "steps_per_print": 10 +} diff --git a/examples/deepspeed/dcgan/gan_model.py b/examples/deepspeed/dcgan/gan_model.py new file mode 100644 index 00000000000..97ed726f45b --- /dev/null +++ b/examples/deepspeed/dcgan/gan_model.py @@ -0,0 +1,73 @@ +from typing import cast + +import torch +import torch.nn as nn + + +def weights_init(m: nn.Module) -> None: + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + nn.init.normal_(cast(torch.Tensor, m.weight.data), 0.0, 0.02) + elif classname.find("BatchNorm") != -1: + nn.init.normal_(cast(torch.Tensor, m.weight.data), 1.0, 0.02) + nn.init.constant_(cast(torch.Tensor, m.bias.data), 0) + + +class Generator(nn.Module): + def __init__(self, ngf: int, nc: int, nz: int) -> None: + super(Generator, self).__init__() # type: ignore + self.main = nn.Sequential( + # input is Z, going into a convolution + nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False), + nn.BatchNorm2d(ngf * 8), # type: ignore + nn.ReLU(True), + # state size. (ngf*8) x 4 x 4 + nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False), + nn.BatchNorm2d(ngf * 4), # type: ignore + nn.ReLU(True), + # state size. (ngf*4) x 8 x 8 + nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False), + nn.BatchNorm2d(ngf * 2), # type: ignore + nn.ReLU(True), + # state size. (ngf*2) x 16 x 16 + nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False), + nn.BatchNorm2d(ngf), # type: ignore + nn.ReLU(True), + # state size. (ngf) x 32 x 32 + nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False), + nn.Tanh() # type: ignore + # state size. (nc) x 64 x 64 + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + output = self.main(input) + return cast(torch.Tensor, output) + + +class Discriminator(nn.Module): + def __init__(self, ndf: int, nc: int) -> None: + super(Discriminator, self).__init__() # type: ignore + self.main = nn.Sequential( + # input is (nc) x 64 x 64 + nn.Conv2d(nc, ndf, 4, 2, 1, bias=False), + nn.LeakyReLU(0.2, inplace=True), + # state size. (ndf) x 32 x 32 + nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), + nn.BatchNorm2d(ndf * 2), # type: ignore + nn.LeakyReLU(0.2, inplace=True), + # state size. (ndf*2) x 16 x 16 + nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), + nn.BatchNorm2d(ndf * 4), # type: ignore + nn.LeakyReLU(0.2, inplace=True), + # state size. (ndf*4) x 8 x 8 + nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), + nn.BatchNorm2d(ndf * 8), # type: ignore + nn.LeakyReLU(0.2, inplace=True), + # state size. (ndf*8) x 4 x 4 + nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False), + nn.Sigmoid(), # type: ignore + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + output = self.main(input) + return cast(torch.Tensor, output.view(-1, 1).squeeze(1)) diff --git a/examples/deepspeed/dcgan/mnist.yaml b/examples/deepspeed/dcgan/mnist.yaml new file mode 100644 index 00000000000..fb996c55532 --- /dev/null +++ b/examples/deepspeed/dcgan/mnist.yaml @@ -0,0 +1,33 @@ +name: dcgan_deepspeed_mnist +data: + dataroot: /data + dataset: mnist + image_size: 64 +hyperparameters: + deepspeed_config: ds_config.json + noise_length: 100 + generator_width_base: 64 + discriminator_width_base: 64 + data_workers: 16 +environment: + environment_variables: + - NCCL_DEBUG=INFO + - NCCL_SOCKET_IFNAME=ens,eth,ib + image: determinedai/pytorch-ngc-dev:0736b6d +bind_mounts: + - host_path: /tmp + container_path: /data +resources: + slots_per_trial: 2 +searcher: + name: single + metric: no_validation_metric +min_validation_period: + batches: 0 +entrypoint: + - python3 + - -m + - determined.launch.deepspeed + - python3 + - trainer.py +max_restarts: 0 diff --git a/examples/deepspeed/dcgan/model.py b/examples/deepspeed/dcgan/model.py new file mode 100644 index 00000000000..8ceab93dc6a --- /dev/null +++ b/examples/deepspeed/dcgan/model.py @@ -0,0 +1,208 @@ +import logging +from typing import Any, Dict, Iterator, Optional, Tuple, Union, cast + +import data +import deepspeed +import torch +import torch.nn as nn +import torch.utils.data +import torchvision +from gan_model import Discriminator, Generator, weights_init + +from determined.pytorch import DataLoader, TorchData +from determined.pytorch import deepspeed as det_ds + +REAL_LABEL = 1 +FAKE_LABEL = 0 + + +class DCGANTrial(det_ds.DeepSpeedTrial): + def __init__(self, context: det_ds.DeepSpeedTrialContext, + hparams: dict, data_config: dict) -> None: + self.context = context + self.hparams = hparams + self.data_config = data_config + self.logger = self.context.get_tensorboard_writer() + num_channels = data.CHANNELS_BY_DATASET[self.data_config["dataset"]] + gen_net = Generator( + self.hparams["generator_width_base"], num_channels, self.hparams["noise_length"] + ) + gen_net.apply(weights_init) + disc_net = Discriminator(self.hparams["discriminator_width_base"], num_channels) + disc_net.apply(weights_init) + gen_parameters = filter(lambda p: p.requires_grad, gen_net.parameters()) + disc_parameters = filter(lambda p: p.requires_grad, disc_net.parameters()) + ds_config = det_ds.overwrite_deepspeed_config( + self.hparams["deepspeed_config"], self.hparams.get("overwrite_deepspeed_args", {}) + ) + generator, _, _, _ = deepspeed.initialize( + model=gen_net, model_parameters=gen_parameters, config=ds_config + ) + discriminator, _, _, _ = deepspeed.initialize( + model=disc_net, model_parameters=disc_parameters, config=ds_config + ) + + self.generator = self.context.wrap_model_engine(generator) + self.discriminator = self.context.wrap_model_engine(discriminator) + self.fixed_noise = self.context.to_device( + torch.randn( + self.context.train_micro_batch_size_per_gpu, self.hparams["noise_length"], 1, 1 + ) + ) + self.criterion = nn.BCELoss() + self.fp16 = generator.fp16_enabled() + self.gradient_accumulation_steps = generator.gradient_accumulation_steps() + # Manually perform gradient accumulation. + if self.gradient_accumulation_steps > 1: + logging.info("Disabling automatic gradient accumulation.") + self.context.disable_auto_grad_accumulation() + + def _get_noise(self, dtype: torch.dtype) -> torch.Tensor: + return cast( + torch.Tensor, + self.context.to_device( + torch.randn( + self.context.train_micro_batch_size_per_gpu, + self.hparams["noise_length"], + 1, + 1, + dtype=dtype, + ) + ), + ) + + def _get_label_constants( + self, batch_size: int, dtype: torch.dtype + ) -> Tuple[torch.Tensor, torch.Tensor]: + real_label = cast( + torch.Tensor, + self.context.to_device(torch.full((batch_size,), REAL_LABEL, dtype=dtype)), + ) + fake_label = cast( + torch.Tensor, + self.context.to_device(torch.full((batch_size,), FAKE_LABEL, dtype=dtype)), + ) + return real_label, fake_label + + def train_batch( + self, iter_dataloader: Optional[Iterator[TorchData]], epoch_idx: int, batch_idx: int + ) -> Union[torch.Tensor, Dict[str, Any]]: + assert iter_dataloader is not None + if self.fp16: + dtype = torch.float16 + else: + dtype = torch.float32 + real_label, fake_label = self._get_label_constants( + self.context.train_micro_batch_size_per_gpu, dtype + ) + ############################ + # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z))) + ########################### + self.discriminator.zero_grad() + + real_sample_count = 0 + errD_real_sum = 0.0 + errD_fake_sum = 0.0 + D_x = 0.0 + D_G_z1 = 0.0 + fake_sample_count = ( + self.context.train_micro_batch_size_per_gpu * self.gradient_accumulation_steps + ) + + for i in range(self.gradient_accumulation_steps): + # Note: at end of epoch, may receive a batch of size smaller than train_micro_batch_size_per_gpu. + # In that case, we end up training on more fake examples than real examples. + # train with real + real, _ = self.context.to_device(next(iter_dataloader)) + real = cast(torch.Tensor, real) + actual_batch_size = real.shape[0] + real_sample_count += actual_batch_size + if self.fp16: + real = real.half() + output = self.discriminator(real) + # For edge-case small batches, must cut real_label size to match. + errD_real = self.criterion(output, real_label[:actual_batch_size]) + self.discriminator.backward(errD_real) + # Undo averaging so we can re-average at end when reporting metrics. + errD_real_sum += errD_real * actual_batch_size + D_x += output.sum().item() + # train with fake + noise = self._get_noise(dtype) + fake = self.generator(noise) + output = self.discriminator(fake.detach()) + errD_fake = self.criterion(output, fake_label) + self.discriminator.backward(errD_fake) + errD_fake_sum += errD_fake * self.context.train_micro_batch_size_per_gpu + D_G_z1 += output.sum().item() + # update + self.discriminator.step() + D_x /= real_sample_count + D_G_z1 /= fake_sample_count + errD = (errD_real_sum / real_sample_count) + (errD_fake_sum / fake_sample_count) + ############################ + # (2) Update G network: maximize log(D(G(z))) + ########################### + self.generator.zero_grad() + D_G_z2_sum = 0.0 + errG_sum = 0.0 + for i in range(self.gradient_accumulation_steps): + if i > 0: + # Must repeat forward pass of generator for accumulation steps beyond the first. + noise = self._get_noise(dtype) + fake = self.generator(noise) + output = self.discriminator(fake) + errG = self.criterion(output, real_label) # fake labels are real for generator cost + self.generator.backward(errG) + errG_sum += errG * self.context._train_micro_batch_size_per_gpu + D_G_z2_sum += output.sum().item() + self.generator.step() + + if batch_idx % 100 == 0: + fake = self.generator(self.fixed_noise) + denormalized_real = (real + 1) / 2 + denormalized_fake = (fake + 1) / 2 + self.logger.add_image( + "real_images", torchvision.utils.make_grid(denormalized_real), batch_idx + ) + self.logger.add_image( + "fake_images", torchvision.utils.make_grid(denormalized_fake), batch_idx + ) + + return { + "errD": errD, + "errG": errG_sum / fake_sample_count, + "D_x": D_x, + "D_G_z1": D_G_z1, + "D_G_z2": D_G_z2_sum / fake_sample_count, + } + + def evaluate_batch( + self, dataloader_iter: Optional[Iterator[TorchData]], batch_idx: int + ) -> Dict[str, Any]: + # TODO: We could add an evaluation metric like FID here. + assert dataloader_iter is not None + next(dataloader_iter) + return {"no_validation_metric": 0.0} + + def build_training_data_loader(self) -> Any: + dataset = data.get_dataset(self.data_config) + return DataLoader( + dataset, + batch_size=self.context.train_micro_batch_size_per_gpu, + shuffle=True, + num_workers=int(self.hparams["data_workers"]), + ) + + def build_validation_data_loader(self) -> Any: + dataset = data.get_dataset(self.data_config) + # Since we're not doing validation, limit to single batch. + dataset = torch.utils.data.Subset( + dataset, + list( + range( + self.context.train_micro_batch_size_per_gpu + * self.context.distributed.get_size() + ) + ), + ) + return DataLoader(dataset, batch_size=self.context.train_micro_batch_size_per_gpu) diff --git a/examples/deepspeed/dcgan/trainer.py b/examples/deepspeed/dcgan/trainer.py new file mode 100644 index 00000000000..1d114430d6f --- /dev/null +++ b/examples/deepspeed/dcgan/trainer.py @@ -0,0 +1,38 @@ +import logging + +import model +import yaml + +import determined as det +from determined import pytorch +from determined.pytorch import deepspeed as det_ds + + +def main(config_file: str, local: bool = True): + info = det.get_cluster_info() + + if local: + # For convenience, use hparams from const.yaml for local mode. + with open(config_file, "r") as f: + experiment_config = yaml.load(f, Loader=yaml.SafeLoader) + hparams = experiment_config["hyperparameters"] + data_config = experiment_config["data"] + latest_checkpoint = None + else: + hparams = info.trial.hparams + data_config = info.trial._config["data"] + latest_checkpoint = ( + info.latest_checkpoint + ) # (Optional) Configure checkpoint for pause/resume functionality. + + with det_ds.init() as train_context: + trial = model.DCGANTrial(train_context, hparams, data_config) + trainer = det_ds.Trainer(trial, train_context) + trainer.fit(max_length=pytorch.Batch(200), latest_checkpoint=latest_checkpoint) + + +if __name__ == "__main__": + local = det.get_cluster_info() is None + # Configure logging + logging.basicConfig(level=logging.INFO, format=det.LOG_FORMAT) + main(config_file="mnist.yaml", local=local) diff --git a/examples/deepspeed/gpt_neox/det_utils.py b/examples/deepspeed/gpt_neox/det_utils.py index 608d30c7cfd..3a6eac44f1c 100644 --- a/examples/deepspeed/gpt_neox/det_utils.py +++ b/examples/deepspeed/gpt_neox/det_utils.py @@ -30,7 +30,7 @@ def get_neox_args(context): "checkpoint_factor": exp_config["min_validation_period"]["batches"], "eval_interval": exp_config["min_validation_period"]["batches"], "hostfile": os.environ.get("DET_DEEPSPEED_HOSTFILE_PATH"), - "seed": context.env.trial_seed, + "seed": context.get_trial_seed(), } ) for k, v in overwrite_values.items(): diff --git a/examples/deepspeed_autotune/torchvision/README.md b/examples/deepspeed_autotune/torchvision/README.md deleted file mode 100644 index 08e0a74431e..00000000000 --- a/examples/deepspeed_autotune/torchvision/README.md +++ /dev/null @@ -1,61 +0,0 @@ -# DeepSpeed Autotuning - -This example demonstrates how to use the DeepSpeed Autotune (`dsat`) feature with two parallel examples, one -written as a [`DeepSpeedTrial`](https://docs.determined.ai/latest/training/apis-howto/deepspeed/overview.html) -class and the other written using [Core API](https://docs.determined.ai/latest/training/apis-howto/api-core-ug.html#core-api). -The relevant code can be found under `deepspeed_trial/` and `core_api/`, respectively. Each example -trains a [`torchvision`](https://pytorch.org/vision/stable/models.html) model on randomly generated -ImageNet-like data (for speed and simplicity). - -## Files - -The two subdirectories closely mirror each other. - -Both contain identical `ds_config.json` files which -use a simple zero-1 DeepSpeed (DS) configuration. They also contain nearly identical`deepspeed.yaml` Determined -configuration files: `core_api/deepspeed.yaml` only differs from `deepspeed_trial/deepspeed.yaml` -in its entrypoint and the inclusion of parameters which control the metric-reporting and checkpointing -frequencies. - -Model code can be found in the following files: - -- `deepspeed_trial/model_def.py` contains the `DeepSpeedTrial` subclass. The only `dsat`-specific code - in this file comes is the `dsat.get_ds_config_from_hparams` helper function. -- `core_api/script.py` contains a bare-bones training loop written with Core API. The script handles - preemption, metric-reporting, and checkpointing. In addition to the `dsat.get_ds_config_from_hparams` - helper function, the forward and backward steps are wrapped in the `dsat.dsat_reporting_context` - context manager. - -The `deepspeed.yaml` files define standard single-Trial experiments which can be run in the usual way -by calling - -```bash -python3 -m determined.pytorch.dsat binary deepspeed.yaml . -``` - -after `cd`-ing into the relevant directory. The code path which utilizes `dsat` is described in the -following section. - -## Basic Usage - -There are three available search methods for DeepSpeed Autotune: - -- `asha`: uses the [ASHA](https://docs.determined.ai/latest/training/hyperparameter/search-methods/hp-adaptive-asha.html#id1) - algorithm to adaptively search over randomly selected DeepSpeed configurations -- `binary`: tunes the optimal batch size for a handful of randomly generated DeepSpeed configurations - via binary search. -- `random`: performs a search over randomly generated DeepSpeed configurations which implements - aggressive early-stopping criteria based on domain-knowledge of DeepSpeed and the search history. - - After `cd`-ing into either of the two subdirectories above, a `asha dsat` experiment can be launched - by entering the following command, for instance: - -```bash -python3 -m determined.pytorch.dsat asha deepspeed.yaml . -``` - -Similar commands are available for `binary` and `random`. The full options for each `dsat` search -method can be found as in `python3 -m determined.pytorch.dsat asha --help` and similar for the other -search methods. - -See [the documentation](https://docs.determined.ai/latest/model-dev-guide/apis-howto/deepspeed/autotuning.html) for more on the available DeepSpeed Autotuning options. diff --git a/examples/deepspeed_autotune/torchvision/core_api/deepspeed.yaml b/examples/deepspeed_autotune/torchvision/core_api/deepspeed.yaml deleted file mode 100644 index 8fc1d9a6a65..00000000000 --- a/examples/deepspeed_autotune/torchvision/core_api/deepspeed.yaml +++ /dev/null @@ -1,19 +0,0 @@ -name: torchvision dsat core_api -max_restarts: 0 -environment: - image: - gpu: determinedai/pytorch-ngc-dev:0736b6d -resources: - slots_per_trial: 2 - shm_size: 4294967296 # 4 GiB. -searcher: - name: single - metric: loss - max_length: 100 -hyperparameters: - model_name: resnet152 - # NOTE: dsat expects the yaml config to reference the DS json config path as in the below. - deepspeed_config: ds_config.json - checkpoint_rate: 50 - metric_reporting_rate: 10 -entrypoint: python3 -m determined.launch.deepspeed python3 script.py diff --git a/examples/deepspeed_autotune/torchvision/core_api/ds_config.json b/examples/deepspeed_autotune/torchvision/core_api/ds_config.json deleted file mode 100644 index f695474e56d..00000000000 --- a/examples/deepspeed_autotune/torchvision/core_api/ds_config.json +++ /dev/null @@ -1,29 +0,0 @@ -{ - "train_batch_size": 256, - "steps_per_print": 2000, - "optimizer": { - "type": "Adam", - "params": { - "lr": 0.001, - "betas": [0.8, 0.999], - "eps": 1e-8, - "weight_decay": 3e-7 - } - }, - "gradient_clipping": 1.0, - "prescale_gradients": false, - "fp16": { - "enabled": true - }, - "wall_clock_breakdown": false, - "zero_optimization": { - "stage": 1, - "allgather_partitions": true, - "reduce_scatter": true, - "allgather_bucket_size": 50000000, - "reduce_bucket_size": 50000000, - "overlap_comm": true, - "contiguous_gradients": true, - "cpu_offload": false - } -} diff --git a/examples/deepspeed_autotune/torchvision/core_api/script.py b/examples/deepspeed_autotune/torchvision/core_api/script.py deleted file mode 100644 index f1d2d270313..00000000000 --- a/examples/deepspeed_autotune/torchvision/core_api/script.py +++ /dev/null @@ -1,123 +0,0 @@ -import logging -import uuid -from typing import Any, Optional, Tuple - -import attrdict -import deepspeed -import numpy as np -import torch -import torch.nn as nn -from torch.utils.data import Dataset -from torchvision import models - -import determined as det -from determined.pytorch import dsat - - -class RandImageNetDataset(Dataset): - """ - A fake, ImageNet-like dataset which only actually contains `num_actual_datapoints` independent - datapoints, but pretends to have the number reported in `__len__`. Used for speed and - simplicity. Replace with your own ImageNet-like dataset as desired. - """ - - def __init__(self, num_actual_datapoints: int = 128) -> None: - self.num_actual_datapoints = num_actual_datapoints - self.imgs = torch.randn(self.num_actual_datapoints, 3, 224, 224) - self.labels = torch.randint(1000, size=(self.num_actual_datapoints,)) - - def __len__(self) -> int: - return 10**6 - - def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: - img = self.imgs[idx % self.num_actual_datapoints] - label = self.labels[idx % self.num_actual_datapoints] - return img, label - - -def main( - core_context: det.core.Context, - hparams: attrdict.AttrDict, - latest_checkpoint: Optional[uuid.UUID], -) -> None: - is_chief = core_context.distributed.rank == 0 - ds_config = dsat.get_ds_config_from_hparams(hparams) - deepspeed.init_distributed() - - trainset = RandImageNetDataset() - model = getattr(models, hparams.model_name)() - parameters = filter(lambda p: p.requires_grad, model.parameters()) - - model_engine, _, trainloader, _ = deepspeed.initialize( - model=model, - model_parameters=parameters, - training_data=trainset, - config=ds_config, - ) - # Restore from latest checkpoint, if any. - if latest_checkpoint is not None: - with core_context.checkpoint.restore_path(storage_id=latest_checkpoint) as path: - model_engine.load_checkpoint(path) - - fp16 = model_engine.fp16_enabled() - criterion = nn.CrossEntropyLoss() - - steps_completed = 0 - local_loss_bucket = [] - for op in core_context.searcher.operations(): - while steps_completed < op.length: - for data in trainloader: - with dsat.dsat_reporting_context(core_context, op): - inputs, labels = data - inputs, labels = inputs.to(model_engine.local_rank), labels.to( - model_engine.local_rank - ) - if fp16: - inputs = inputs.half() - outputs = model_engine(inputs) - loss = criterion(outputs, labels) - local_loss_bucket.append(loss.item()) - model_engine.backward(loss) - model_engine.step() - - # Only increment `steps_completed` when an actual optimizer step is taken, - # accounting for the gradient accumulation rate. - if model_engine.is_gradient_accumulation_boundary(): - steps_completed += 1 - # Metrics reporting. - if not steps_completed % hparams.metric_reporting_rate: - mean_local_loss = np.array(local_loss_bucket).mean() - local_loss_bucket = [] - gathered_losses = core_context.distributed.gather(mean_local_loss) - if is_chief: - mean_global_loss = np.array(gathered_losses).mean() - metrics_dict = {"loss": mean_global_loss} - core_context.train.report_training_metrics( - steps_completed=steps_completed, metrics=metrics_dict - ) - # Checkpointing. - if not steps_completed % hparams.checkpoint_rate: - metadata = {"steps_completed": steps_completed} - with core_context.checkpoint.store_path(metadata=metadata, shard=True) as ( - path, - _, - ): - model_engine.save_checkpoint(path) - # Preemption after checkpointing. - if core_context.preempt.should_preempt(): - return - # Completion. - if steps_completed == op.length: - if is_chief: - op.report_completed(mean_global_loss) - return - - -if __name__ == "__main__": - logging.basicConfig(level=logging.INFO, format=det.LOG_FORMAT) - info = det.get_cluster_info() - latest_checkpoint = info.latest_checkpoint - hparams = attrdict.AttrDict(info.trial.hparams) - distributed = det.core.DistributedContext.from_deepspeed() - with det.core.init(distributed=distributed) as core_context: - main(core_context, hparams, latest_checkpoint) diff --git a/examples/deepspeed_autotune/torchvision/deepspeed_trial/deepspeed.yaml b/examples/deepspeed_autotune/torchvision/deepspeed_trial/deepspeed.yaml deleted file mode 100644 index b892d8cbb90..00000000000 --- a/examples/deepspeed_autotune/torchvision/deepspeed_trial/deepspeed.yaml +++ /dev/null @@ -1,22 +0,0 @@ -name: torchvision dsat deepspeed_trial -max_restarts: 0 -environment: - image: - gpu: determinedai/pytorch-ngc-dev:0736b6d -resources: - slots_per_trial: 2 - shm_size: 4294967296 # 4 GiB. -searcher: - name: single - metric: val_loss - max_length: 100 -hyperparameters: - model_name: resnet152 - # NOTE: dsat expects the yaml config to reference the DS json config path as in the below. - deepspeed_config: ds_config.json -entrypoint: - - python3 - - -m - - determined.launch.deepspeed - - --trial - - model_def:TorchvisionTrial diff --git a/examples/deepspeed_autotune/torchvision/deepspeed_trial/ds_config.json b/examples/deepspeed_autotune/torchvision/deepspeed_trial/ds_config.json deleted file mode 100644 index f695474e56d..00000000000 --- a/examples/deepspeed_autotune/torchvision/deepspeed_trial/ds_config.json +++ /dev/null @@ -1,29 +0,0 @@ -{ - "train_batch_size": 256, - "steps_per_print": 2000, - "optimizer": { - "type": "Adam", - "params": { - "lr": 0.001, - "betas": [0.8, 0.999], - "eps": 1e-8, - "weight_decay": 3e-7 - } - }, - "gradient_clipping": 1.0, - "prescale_gradients": false, - "fp16": { - "enabled": true - }, - "wall_clock_breakdown": false, - "zero_optimization": { - "stage": 1, - "allgather_partitions": true, - "reduce_scatter": true, - "allgather_bucket_size": 50000000, - "reduce_bucket_size": 50000000, - "overlap_comm": true, - "contiguous_gradients": true, - "cpu_offload": false - } -} diff --git a/examples/deepspeed_autotune/torchvision/deepspeed_trial/model_def.py b/examples/deepspeed_autotune/torchvision/deepspeed_trial/model_def.py deleted file mode 100644 index ccb0f7bd26a..00000000000 --- a/examples/deepspeed_autotune/torchvision/deepspeed_trial/model_def.py +++ /dev/null @@ -1,89 +0,0 @@ -from typing import Any, Dict - -import deepspeed -import torch -import torch.nn as nn -from attrdict import AttrDict -from torch.utils.data import Dataset -from torchvision import models - -from determined.pytorch import DataLoader, dsat -from determined.pytorch.deepspeed import DeepSpeedTrial, DeepSpeedTrialContext - - -class RandImageNetDataset(Dataset): - """ - A fake, ImageNet-like dataset which only actually contains `num_actual_datapoints` independent - datapoints, but pretends to have the number reported in `__len__`. Used for speed and - simplicity. Replace with your own ImageNet-like dataset as desired. - """ - - def __init__(self, num_actual_datapoints: int = 128) -> None: - self.num_actual_datapoints = num_actual_datapoints - self.imgs = torch.randn(self.num_actual_datapoints, 3, 224, 224) - self.labels = torch.randint(1000, size=(self.num_actual_datapoints,)) - - def __len__(self) -> int: - return 10**6 - - def __getitem__(self, idx: int) -> torch.Tensor: - img = self.imgs[idx % self.num_actual_datapoints] - label = self.labels[idx % self.num_actual_datapoints] - return img, label - - -class TorchvisionTrial(DeepSpeedTrial): - def __init__(self, context: DeepSpeedTrialContext) -> None: - self.context = context - self.hparams = AttrDict(self.context.get_hparams()) - - model = getattr(models, self.hparams.model_name)() - parameters = filter(lambda p: p.requires_grad, model.parameters()) - - ds_config = dsat.get_ds_config_from_hparams(self.hparams) - model_engine, _, _, _ = deepspeed.initialize( - model=model, model_parameters=parameters, config=ds_config - ) - - self.fp16 = model_engine.fp16_enabled() - self.model_engine = self.context.wrap_model_engine(model_engine) - - self.criterion = nn.CrossEntropyLoss().to(self.context.device) - - def train_batch(self, iter_dataloader, epoch_idx, batch_idx) -> Dict[str, torch.Tensor]: - inputs, labels = self.context.to_device(next(iter_dataloader)) - if self.fp16: - inputs = inputs.half() - outputs = self.model_engine(inputs) - loss = self.criterion(outputs, labels) - - self.model_engine.backward(loss) - self.model_engine.step() - return {"train_loss": loss.item()} - - def evaluate_batch(self, iter_dataloader, batch_idx) -> Dict[str, Any]: - inputs, labels = self.context.to_device(next(iter_dataloader)) - if self.fp16: - inputs = inputs.half() - outputs = self.model_engine(inputs) - loss = self.criterion(outputs, labels) - return {"val_loss": loss.item()} - - def build_training_data_loader(self) -> Any: - trainset = RandImageNetDataset() - train_loader = DataLoader( - trainset, - batch_size=self.context.train_micro_batch_size_per_gpu, - shuffle=True, - num_workers=2, - ) - return train_loader - - def build_validation_data_loader(self) -> Any: - testset = RandImageNetDataset() - return DataLoader( - testset, - batch_size=self.context.train_micro_batch_size_per_gpu, - shuffle=False, - num_workers=2, - ) diff --git a/examples/diffusion/textual_inversion_stable_diffusion/detsd/pipeline.py b/examples/diffusion/textual_inversion_stable_diffusion/detsd/pipeline.py index 92fa0df9a7e..c5ae30141d2 100644 --- a/examples/diffusion/textual_inversion_stable_diffusion/detsd/pipeline.py +++ b/examples/diffusion/textual_inversion_stable_diffusion/detsd/pipeline.py @@ -102,6 +102,7 @@ def generate_on_cluster(cls) -> None: # Extract relevant groups from hparams. batch_size = hparams["batch_size"] + num_batches = hparams["num_batches"] main_process_generator_seed = hparams["main_process_generator_seed"] save_freq = hparams["save_freq"] pipeline_init_kwargs = hparams["pipeline"] @@ -185,42 +186,36 @@ def generate_on_cluster(cls) -> None: if is_main_process: logger.info("--------------- Generating Images ---------------") - # There will be a single op of len max_length, as defined in the searcher config. - for op in core_context.searcher.operations(): - while pipeline.steps_completed < op.length: - pipeline.image_history.extend(pipeline(**call_kwargs).images) - pipeline.steps_completed += 1 - - # Write to tensorboard and checkpoint at the specified frequency. - if ( - pipeline.steps_completed % save_freq == 0 - or pipeline.steps_completed == op.length - ): - pipeline._write_tb_imgs( - core_context=core_context, tb_writer=tb_writer, tb_tag=tb_tag - ) + while pipeline.steps_completed < num_batches: + pipeline.image_history.extend(pipeline(**call_kwargs).images) + pipeline.steps_completed += 1 + + # Write to tensorboard and checkpoint at the specified frequency. + if ( + pipeline.steps_completed % save_freq == 0 + or pipeline.steps_completed == num_batches + ): + pipeline._write_tb_imgs( + core_context=core_context, tb_writer=tb_writer, tb_tag=tb_tag + ) - # Checkpointing. - devices_and_generators = core_context.distributed.gather( - (device, generator.get_state()) + # Checkpointing. + devices_and_generators = core_context.distributed.gather( + (device, generator.get_state()) + ) + if is_main_process: + logger.info(f"Saving at step {pipeline.steps_completed}") + # Save the state of the generators as the checkpoint. + pipeline._save( + core_context=core_context, + devices_and_generators=devices_and_generators, + trial_id=trial_id, ) - if is_main_process: - logger.info(f"Saving at step {pipeline.steps_completed}") - # Save the state of the generators as the checkpoint. - pipeline._save( - core_context=core_context, - devices_and_generators=devices_and_generators, - trial_id=trial_id, - ) - op.report_progress(pipeline.steps_completed) - - # Only preempt after a checkpoint has been saved. - if core_context.preempt.should_preempt(): - return + core_context.train.report_progress(pipeline.steps_completed / MAX_LENGTH) - if is_main_process: - # Report zero upon completion. - op.report_completed(0) + # Only preempt after a checkpoint has been saved. + if core_context.preempt.should_preempt(): + return def load_from_checkpoint_dir( self, diff --git a/examples/diffusion/textual_inversion_stable_diffusion/detsd/trainer.py b/examples/diffusion/textual_inversion_stable_diffusion/detsd/trainer.py index d9e481db487..f8204250060 100644 --- a/examples/diffusion/textual_inversion_stable_diffusion/detsd/trainer.py +++ b/examples/diffusion/textual_inversion_stable_diffusion/detsd/trainer.py @@ -30,6 +30,7 @@ def __init__( initializer_strs: Union[str, Sequence[str]], learnable_properties: Sequence[Literal["object", "style"]], pretrained_model_name_or_path: str = "CompVis/stable-diffusion-v1-4", + num_sgd_steps: int = 100, train_batch_size: int = 1, gradient_accumulation_steps: int = 4, optimizer_name: Literal["adam", "adamw", "sgd"] = "adam", @@ -74,6 +75,7 @@ def __init__( self.logger = accelerate.logging.get_logger(__name__) self.pretrained_model_name_or_path = pretrained_model_name_or_path + self.num_sgd_steps = num_sgd_steps if isinstance(learnable_properties, str): learnable_properties = [learnable_properties] @@ -233,39 +235,36 @@ def train_on_cluster(cls) -> None: trial_id=trial_id, ) - # There will be a single op of len max_length, as defined in the searcher config. - for op in core_context.searcher.operations(): - while trainer.steps_completed < op.length: - for batch in trainer.train_dataloader: - # Use the accumulate method for efficient gradient accumulation. - with trainer.accelerator.accumulate(trainer.text_encoder): - trainer._train_one_batch(batch) - took_sgd_step = trainer.accelerator.sync_gradients - if took_sgd_step: - trainer.steps_completed += 1 - trainer.logger.info(f"Step {trainer.steps_completed} completed.") - - is_end_of_training = trainer.steps_completed == op.length - time_to_report = ( - trainer.steps_completed % trainer.metric_report_freq == 0 - ) - time_to_ckpt = trainer.steps_completed % trainer.checkpoint_freq == 0 - - # Report metrics, checkpoint, and preempt as appropriate. - if is_end_of_training or time_to_report or time_to_ckpt: - trainer._report_train_metrics(core_context) - # report_progress for Web UI progress-bar rendering. - if trainer.accelerator.is_main_process: - op.report_progress(trainer.steps_completed) - if is_end_of_training or time_to_ckpt: - trainer._save(core_context, trial_id) - if core_context.preempt.should_preempt(): - return - if is_end_of_training: - break - if trainer.accelerator.is_main_process: - # Report the final mean loss. - op.report_completed(trainer.last_mean_loss) + while trainer.steps_completed < trainer.num_sgd_steps: + for batch in trainer.train_dataloader: + # Use the accumulate method for efficient gradient accumulation. + with trainer.accelerator.accumulate(trainer.text_encoder): + trainer._train_one_batch(batch) + took_sgd_step = trainer.accelerator.sync_gradients + if took_sgd_step: + trainer.steps_completed += 1 + trainer.logger.info(f"Step {trainer.steps_completed} completed.") + + is_end_of_training = trainer.steps_completed == trainer.num_sgd_steps + time_to_report = ( + trainer.steps_completed % trainer.metric_report_freq == 0 + ) + time_to_ckpt = trainer.steps_completed % trainer.checkpoint_freq == 0 + + # Report metrics, checkpoint, and preempt as appropriate. + if is_end_of_training or time_to_report or time_to_ckpt: + trainer._report_train_metrics(core_context) + # report_progress for Web UI progress-bar rendering. + if trainer.accelerator.is_main_process: + core_context.train.report_progress( + trainer.steps_completed / trainer.num_sgd_steps + ) + if is_end_of_training or time_to_ckpt: + trainer._save(core_context, trial_id) + if core_context.preempt.should_preempt(): + return + if is_end_of_training: + break def _train_one_batch(self, batch: TorchData) -> None: """Train on a single batch and update internal metrics.""" diff --git a/examples/diffusion/textual_inversion_stable_diffusion/finetune_const.yaml b/examples/diffusion/textual_inversion_stable_diffusion/finetune_const.yaml index afe578e2a63..daa16ae1608 100644 --- a/examples/diffusion/textual_inversion_stable_diffusion/finetune_const.yaml +++ b/examples/diffusion/textual_inversion_stable_diffusion/finetune_const.yaml @@ -3,12 +3,11 @@ entrypoint: python3 -m determined.launch.torch_distributed python3 finetune.py searcher: name: single metric: loss - max_length: 500 # Number of SGD steps. resources: slots_per_trial: 2 max_restarts: 1 environment: - environment_variables: + environment_variables: - HF_AUTH_TOKEN=YOUR_HF_AUTH_TOKEN_HERE checkpoint_storage: save_trial_latest: 5 @@ -16,15 +15,16 @@ hyperparameters: model: pretrained_model_name_or_path: CompVis/stable-diffusion-v1-4 concepts: - learnable_properties: # One of 'object' or 'style'. + learnable_properties: # One of 'object' or 'style'. - object concept_strs: # Individual strings representing new concepts. Must not exist in tokenizer. - det-logo - initializer_strs: # Strings which describe the added concepts. + initializer_strs: # Strings which describe the added concepts. - brain logo, sharp lines, connected circles, concept art img_dirs: - det_logos training: + num_sgd_steps: 500 train_batch_size: 1 gradient_accumulation_steps: 4 optimizer_name: adam @@ -38,5 +38,3 @@ hyperparameters: - a watercolor painting on textured paper of a det-logo using soft strokes, pastel colors, incredible composition, masterpiece - a Van Gogh painting of a det-logo with vibrant colors, thick strokes, masterpiece, incredible composition - Beautiful tarot illustration of a det-logo, in the style of james jean and victo ngai, mystical colors, trending on artstation - - diff --git a/examples/diffusion/textual_inversion_stable_diffusion/finetune_const_advanced.yaml b/examples/diffusion/textual_inversion_stable_diffusion/finetune_const_advanced.yaml index 3f3ec15ccdf..a3e6ede03ba 100644 --- a/examples/diffusion/textual_inversion_stable_diffusion/finetune_const_advanced.yaml +++ b/examples/diffusion/textual_inversion_stable_diffusion/finetune_const_advanced.yaml @@ -3,7 +3,6 @@ entrypoint: python3 -m determined.launch.torch_distributed python3 finetune.py searcher: name: single metric: loss - max_length: 1000 # Number of SGD steps. resources: slots_per_trial: 4 max_restarts: 1 @@ -16,11 +15,11 @@ hyperparameters: model: pretrained_model_name_or_path: CompVis/stable-diffusion-v1-4 concepts: - learnable_properties: # One of 'object' or 'style'. + learnable_properties: # One of 'object' or 'style'. - object - concept_strs: # Individual strings representing new concepts. Must not exist in tokenizer. + concept_strs: # Individual strings representing new concepts. Must not exist in tokenizer. - det-logo - initializer_strs: # Strings which describe the added concepts. + initializer_strs: # Strings which describe the added concepts. - brain logo, sharp lines, connected circles, concept art img_dirs: - det_logos @@ -33,6 +32,7 @@ hyperparameters: num_blank_prompts: 10 num_a_prompts: 10 training: + num_sgd_steps: 1000 train_batch_size: 1 gradient_accumulation_steps: 4 optimizer_name: adam @@ -61,5 +61,3 @@ hyperparameters: main_process_generator_seed: 2147483647 other_inference_scheduler_kwargs: skip_prk_steps: True - - diff --git a/examples/diffusion/textual_inversion_stable_diffusion/generate_grid.yaml b/examples/diffusion/textual_inversion_stable_diffusion/generate_grid.yaml index 0522a770e8c..491cb433e3b 100644 --- a/examples/diffusion/textual_inversion_stable_diffusion/generate_grid.yaml +++ b/examples/diffusion/textual_inversion_stable_diffusion/generate_grid.yaml @@ -1,8 +1,5 @@ # A grid search over multiple prompts and pipeline settings. -# Due to tensorboard limitations, one should keep batch_size * max_length <= 10 in order for all -# generated images to be easily viewable in tensorboard. Scanning over generator seeds is the -# easiest way to get more samples while also diversifying the results. name: detsd_generate entrypoint: python3 -m determined.launch.torch_distributed python3 generate.py searcher: @@ -17,7 +14,11 @@ environment: environment_variables: - HF_AUTH_TOKEN=YOUR_HF_AUTH_TOKEN_HERE hyperparameters: + # Keep batch_size * num_batches <= 10 in order for all generated images to + # be easily viewable in tensorboard. Scanning over generator seeds is the + # easiest way to get more samples while also diversifying the results. batch_size: 2 + num_batches: 5 # Number of times the generation pipeline is called, per-worker. main_process_generator_seed: type: int minval: 14748367 @@ -51,4 +52,3 @@ hyperparameters: minval: 2 maxval: 6 count: 3 - diff --git a/examples/features/inference_mnist_pytorch/distributed_inference.yaml b/examples/features/inference_mnist_pytorch/distributed_inference.yaml index cbb4132c4a1..65eda6e4728 100644 --- a/examples/features/inference_mnist_pytorch/distributed_inference.yaml +++ b/examples/features/inference_mnist_pytorch/distributed_inference.yaml @@ -9,7 +9,6 @@ resources: searcher: name: grid metric: x - max_length: 100 hyperparameters: # Change this to your model name. model_name: mnist_models @@ -27,4 +26,4 @@ max_restarts: 0 bind_mounts: - host_path: /tmp container_path: /tmp - read_only: false \ No newline at end of file + read_only: false diff --git a/examples/features/ports/ray_launcher.yaml b/examples/features/ports/ray_launcher.yaml index 4997e153f74..83d92a5438d 100644 --- a/examples/features/ports/ray_launcher.yaml +++ b/examples/features/ports/ray_launcher.yaml @@ -7,7 +7,6 @@ resources: searcher: name: single metric: x - max_length: 10000000 max_restarts: 0 diff --git a/examples/features/torch_batch_process_core_api_comparison/core_api_config.yaml b/examples/features/torch_batch_process_core_api_comparison/core_api_config.yaml index 82e7606cf40..5f4a5c8978f 100644 --- a/examples/features/torch_batch_process_core_api_comparison/core_api_config.yaml +++ b/examples/features/torch_batch_process_core_api_comparison/core_api_config.yaml @@ -9,7 +9,6 @@ resources: searcher: name: single metric: x - max_length: 100 max_restarts: 2 bind_mounts: - host_path: /tmp diff --git a/examples/features/torch_batch_process_core_api_comparison/torch_batch_process_config.yaml b/examples/features/torch_batch_process_core_api_comparison/torch_batch_process_config.yaml index f484e4fa0fc..f7c26ff1d47 100644 --- a/examples/features/torch_batch_process_core_api_comparison/torch_batch_process_config.yaml +++ b/examples/features/torch_batch_process_core_api_comparison/torch_batch_process_config.yaml @@ -9,7 +9,6 @@ resources: searcher: name: single metric: x - max_length: 100 max_restarts: 2 bind_mounts: diff --git a/examples/features/torch_batch_process_embeddings/distributed.yaml b/examples/features/torch_batch_process_embeddings/distributed.yaml index 6f6dc82e742..da96c285b02 100644 --- a/examples/features/torch_batch_process_embeddings/distributed.yaml +++ b/examples/features/torch_batch_process_embeddings/distributed.yaml @@ -9,7 +9,6 @@ resources: searcher: name: single metric: x - max_length: 100 max_restarts: 0 bind_mounts: diff --git a/examples/features/unmanaged/1.yaml b/examples/features/unmanaged/1.yaml index 24325b23f86..becd96e2a15 100644 --- a/examples/features/unmanaged/1.yaml +++ b/examples/features/unmanaged/1.yaml @@ -6,7 +6,5 @@ searcher: name: single # metric is required but it shouldn't hurt to ignore it at this point. metric: x - # max_length is ignored if the training script ignores it. - max_length: 1 max_restarts: 0 diff --git a/examples/features/unmanaged/2.yaml b/examples/features/unmanaged/2.yaml index 09c4f84740a..f3b6ce030cd 100644 --- a/examples/features/unmanaged/2.yaml +++ b/examples/features/unmanaged/2.yaml @@ -6,8 +6,6 @@ searcher: name: single # metric is required but it shouldn't hurt to ignore it at this point. metric: x - # max_length is ignored if the training script ignores it. - max_length: 1 max_restarts: 0 diff --git a/examples/features/unmanaged/3.yaml b/examples/features/unmanaged/3.yaml index 041c83ed08b..d1a6706f337 100644 --- a/examples/features/unmanaged/3.yaml +++ b/examples/features/unmanaged/3.yaml @@ -11,7 +11,5 @@ searcher: name: single # metric is required but it shouldn't hurt to ignore it at this point. metric: x - # max_length is ignored if the training script ignores it. - max_length: 1 max_restarts: 0 diff --git a/examples/features/unmanaged/ray/ray_hp_search.py b/examples/features/unmanaged/ray/ray_hp_search.py index 79347d6a906..f17a0a822cd 100644 --- a/examples/features/unmanaged/ray/ray_hp_search.py +++ b/examples/features/unmanaged/ray/ray_hp_search.py @@ -25,9 +25,10 @@ def objective(config): # We need to pass a non-single searcher config to have the WebUI display our experiment # as HP search. searcher={ - "name": "custom", + "name": "random", "metric": "loss", "smaller_is_better": True, + "max_trials": 2, }, external_experiment_id=experiment_name, external_trial_id=trial_name, diff --git a/examples/hf_trainer_api/README.md b/examples/hf_trainer_api/README.md index a412219375b..85269146958 100644 --- a/examples/hf_trainer_api/README.md +++ b/examples/hf_trainer_api/README.md @@ -75,15 +75,3 @@ det experiment create deepspeed.yaml . The deepspeed configuration can be changed by altering the `hyperparameters.deepspeed_config` entry of the `deepspeed.yaml` config, as well as the corresponding line in the `entrypoint`. The default configuration is `ds_configs/ds_config_stage_1.json`. - -## DeepSpeed Autotune - -One can also use Determined's DeepSpeed Autotune functionality to automatically optimize the -DeepSpeed settings. From either subdirectory, DeepSpeed parameters can be tuned to maximize the -model FLOPs via the ASHA algorithm by running the following script, for instance: - -``` -python3 -m determined.pytorch.dsat asha deepspeed.yaml . -``` - -See [the documentation](https://docs.determined.ai/latest/model-dev-guide/apis-howto/deepspeed/autotuning.html) for more on the available DeepSpeed Autotuning options. diff --git a/examples/hf_trainer_api/hf_image_classification/adaptive.yaml b/examples/hf_trainer_api/hf_image_classification/adaptive.yaml index 4407ebf100d..59850ae5150 100644 --- a/examples/hf_trainer_api/hf_image_classification/adaptive.yaml +++ b/examples/hf_trainer_api/hf_image_classification/adaptive.yaml @@ -9,8 +9,8 @@ resources: slots_per_trial: 2 searcher: name: adaptive_asha - max_length: - batches: 100 + time_metric: batches + max_time: 100 max_trials: 64 max_rungs: 4 divisor: 4 diff --git a/examples/hf_trainer_api/hf_image_classification/const.yaml b/examples/hf_trainer_api/hf_image_classification/const.yaml index 2a3b95535fa..fcf6a9d3844 100644 --- a/examples/hf_trainer_api/hf_image_classification/const.yaml +++ b/examples/hf_trainer_api/hf_image_classification/const.yaml @@ -9,8 +9,6 @@ resources: slots_per_trial: 1 searcher: name: single - max_length: - batches: 100 metric: eval_loss hyperparameters: training_arguments: diff --git a/examples/hf_trainer_api/hf_image_classification/const_epochs.yaml b/examples/hf_trainer_api/hf_image_classification/const_epochs.yaml index 49425d854af..2fb6a69f5a4 100644 --- a/examples/hf_trainer_api/hf_image_classification/const_epochs.yaml +++ b/examples/hf_trainer_api/hf_image_classification/const_epochs.yaml @@ -10,8 +10,6 @@ resources: records_per_epoch: 1000 searcher: name: single - max_length: - epochs: 5 metric: eval_loss hyperparameters: training_arguments: diff --git a/examples/hf_trainer_api/hf_image_classification/deepspeed.yaml b/examples/hf_trainer_api/hf_image_classification/deepspeed.yaml index f60ec4d218d..698d68f8bba 100644 --- a/examples/hf_trainer_api/hf_image_classification/deepspeed.yaml +++ b/examples/hf_trainer_api/hf_image_classification/deepspeed.yaml @@ -11,8 +11,6 @@ resources: slots_per_trial: 2 searcher: name: single - max_length: - batches: 100 metric: eval_loss hyperparameters: deepspeed_config: ds_configs/ds_config_stage_1.json diff --git a/examples/hf_trainer_api/hf_image_classification/distributed.yaml b/examples/hf_trainer_api/hf_image_classification/distributed.yaml index a9ea4ca154b..fe74f1ec1b7 100644 --- a/examples/hf_trainer_api/hf_image_classification/distributed.yaml +++ b/examples/hf_trainer_api/hf_image_classification/distributed.yaml @@ -9,8 +9,6 @@ resources: slots_per_trial: 2 searcher: name: single - max_length: - batches: 100 metric: eval_loss hyperparameters: training_arguments: diff --git a/examples/hf_trainer_api/hf_image_classification/image_classification.py b/examples/hf_trainer_api/hf_image_classification/image_classification.py index 2c31f748877..fdc9c104e61 100644 --- a/examples/hf_trainer_api/hf_image_classification/image_classification.py +++ b/examples/hf_trainer_api/hf_image_classification/image_classification.py @@ -26,6 +26,7 @@ import numpy as np import torch import transformers +import util from PIL import Image from torch.utils.tensorboard import SummaryWriter from torchvision.transforms import ( @@ -51,7 +52,6 @@ from transformers.utils.versions import require_version import determined as det -from determined.pytorch import dsat from determined.transformers import DetCallback """ Fine-tuning a 🤗 Transformers model for image classification""" @@ -223,7 +223,7 @@ def parse_input_arguments( args = sys.argv[1:] args.extend(dict2args(training_arguments)) if any("--deepspeed" == arg.strip() for arg in args): - args = dsat.get_hf_args_with_overwrites(args, hparams) + args = util.get_hf_args_with_overwrites(args, hparams) model_args, data_args, training_args = parser.parse_args_into_dataclasses( args, look_for_args_file=False ) @@ -430,8 +430,7 @@ def val_transforms(example_batch): elif last_checkpoint is not None: checkpoint = last_checkpoint - with dsat.dsat_reporting_context(core_context, op=det_callback.current_op): - train_result = trainer.train(resume_from_checkpoint=checkpoint) + train_result = trainer.train(resume_from_checkpoint=checkpoint) trainer.save_model() trainer.log_metrics("train", train_result.metrics) trainer.save_metrics("train", train_result.metrics) diff --git a/examples/hf_trainer_api/hf_image_classification/util.py b/examples/hf_trainer_api/hf_image_classification/util.py new file mode 100644 index 00000000000..85cdda4f4e7 --- /dev/null +++ b/examples/hf_trainer_api/hf_image_classification/util.py @@ -0,0 +1,140 @@ +import copy +import json +import logging +import pathlib +from typing import Any, Dict, List, Optional, Union + +import filelock + +from determined import util as det_util + +CURR_DIR = pathlib.Path(".") +CONFIG_KEY = "deepspeed_config" +OVERWRITE_KEY = "overwrite_deepspeed_args" + + +def get_ds_config_from_hparams( + hparams: Dict[str, Any], + base_dir: Union[pathlib.Path, str] = CURR_DIR, +) -> Dict[str, Any]: + """Gets the DS config dictionary after merging with overwrite values. + + Follows the rules as described here: + https://docs.determined.ai/latest/training/apis-howto/deepspeed/deepspeed.html#configuration + Args: + hparams (Dict): + Hyperparameters dictionary + base_dir (pathlib.Path): + Base directory relattive to which hparams.deepspeed_config is defined + Returns: + The Deepspeed Configuration for this experiment following the overwriting rules + """ + assert CONFIG_KEY in hparams, ( + f"Expected to find {CONFIG_KEY} in the Hyperparameters section. " f"Instead found {hparams}" + ) + ds_config_relative_path = hparams[CONFIG_KEY] + base_dir = pathlib.Path(base_dir) + full_path = base_dir.joinpath(ds_config_relative_path) + with open(full_path, "r") as f: + base_ds_config: Dict[str, Any] = json.load(f) + overwrite_ds_config = hparams.get(OVERWRITE_KEY, {}) + final_ds_config = det_util.merge_dicts(base_ds_config, overwrite_ds_config) + return final_ds_config + + +def get_hf_ds_config_path_from_args(args: List[str]) -> Optional[str]: + for idx in range(len(args)): + if args[idx] == "--deepspeed": + ds_config_idx = idx + 1 + ds_config_path = args[ds_config_idx] + return ds_config_path + return None + + +def update_hf_args(args: List[str], ds_config_dict: Dict[str, Any]) -> List[str]: + """ + Updates batch-size-related HF CLI args to be consistent with the values specified in the + provided DeepSpeed config dictionary. + + Args: + args: list of CLI arguments passed to the HF entrypoint + ds_config_dict: the DeepSpeed configuration as a dictionary + """ + hf_flag_to_ds_key = { + "--per_device_train_batch_size": "train_micro_batch_size_per_gpu", + "--gradient_accumulation_steps": "gradient_accumulation_steps", + } + # Overwrite CLI args + args = copy.deepcopy(args) + for idx in range(len(args)): + if args[idx] in hf_flag_to_ds_key: + ds_key = hf_flag_to_ds_key[args[idx]] + overwrite_value = ds_config_dict[ds_key] + # Need to avoid copying possible "auto" value from json config to HF CLI. + is_auto = isinstance(overwrite_value, str) and overwrite_value.strip() == "auto" + if not is_auto: + overwrite_value_str = str(overwrite_value) + if args[idx + 1] != overwrite_value_str: + logging.warning( + f"Changing {args[idx]} from {args[idx +1]} to {overwrite_value_str}" + " to match the deespspeed config values." + ) + args[idx + 1] = overwrite_value_str + del hf_flag_to_ds_key[args[idx]] + + # Any remaining keys in hf_flag_to_ds_key were not provided as args to the HF CLI entrypoint, + # but they must be added in explicitly, to avoid falling back to HF defaults. + for hf_flag, ds_key in hf_flag_to_ds_key.items(): + hf_flag_value = ds_config_dict[ds_key] + is_auto = isinstance(hf_flag_value, str) and hf_flag_value.strip() == "auto" + if not is_auto: + hf_flag_value_str = str(hf_flag_value) + args.extend([hf_flag, hf_flag_value_str]) + logging.warning( + f"Adding {hf_flag} {hf_flag_value_str} to HF CLI args to reflect overwrite values." + ) + return args + + +def get_hf_args_with_overwrites(args: List[str], hparams: Dict[str, Any]) -> List[str]: + """Updates the submitted HF CLI Args to account for overwrite values. + + Primarily intended as a helper function for Determined AI DeepSpeed (DS) which provides + overwrite values through the `hparams["overwrite_deepspeed_args"]` which possibly include DS + batch-size related arguments (`train_batch_size`, `train_micro_batch_size_per_gpu`, and + `gradient_accumulation_steps`) which are in conflict with the corresponding HF CLI batch-size + related arguments(`--per_device_train_batch_size` and `--gradient_accumulation_steps`). This + function updates the HF CLI args to relect any such overwrite values. This process also requires + overwriting the corresponding DS json file on-cluster. + + Args: + args: the original HF CLI arguments + hparams: hyperparameter dictionary generated through Determined AI + + Returns: + args: updated HF CLI arguments + """ + if OVERWRITE_KEY not in hparams: + logging.info( + f"{OVERWRITE_KEY} key not found in hparams, `get_hf_args_with_overwrites` " "is a no-op" + ) + return args + + ds_config_path = get_hf_ds_config_path_from_args(args) + assert ds_config_path is not None, "--deepspeed flag not found in HuggingFace args!" + + # A file lock is required during both the writing and reading. + with filelock.FileLock(ds_config_path + ".lock"): + with open(ds_config_path, "r") as f: + ds_config_dict = json.load(f) + + # Then merge all overwrites into the ds_config + overwritten_ds_config_dict = det_util.merge_dicts(ds_config_dict, hparams[OVERWRITE_KEY]) + + # We need to actually overwrite the ds json config file, due to how HF processes args. + with open(ds_config_path, "w") as f: + json.dump(overwritten_ds_config_dict, f) + # Finally overwrite the CLI args + args = update_hf_args(args, overwritten_ds_config_dict) + + return args diff --git a/examples/hf_trainer_api/hf_language_modeling/adaptive.yaml b/examples/hf_trainer_api/hf_language_modeling/adaptive.yaml index 7799582879a..946aa15f9f1 100644 --- a/examples/hf_trainer_api/hf_language_modeling/adaptive.yaml +++ b/examples/hf_trainer_api/hf_language_modeling/adaptive.yaml @@ -9,8 +9,8 @@ resources: slots_per_trial: 2 searcher: name: adaptive_asha - max_length: - batches: 100 + time_metric: batches + max_time: 100 max_trials: 64 max_rungs: 4 divisor: 4 diff --git a/examples/hf_trainer_api/hf_language_modeling/const.yaml b/examples/hf_trainer_api/hf_language_modeling/const.yaml index 294504aed07..e340834457c 100644 --- a/examples/hf_trainer_api/hf_language_modeling/const.yaml +++ b/examples/hf_trainer_api/hf_language_modeling/const.yaml @@ -9,8 +9,6 @@ resources: slots_per_trial: 1 searcher: name: single - max_length: - batches: 100 metric: eval_loss hyperparameters: training_arguments: diff --git a/examples/hf_trainer_api/hf_language_modeling/const_epochs.yaml b/examples/hf_trainer_api/hf_language_modeling/const_epochs.yaml index 75b44dc5f97..b544db63620 100644 --- a/examples/hf_trainer_api/hf_language_modeling/const_epochs.yaml +++ b/examples/hf_trainer_api/hf_language_modeling/const_epochs.yaml @@ -10,8 +10,6 @@ resources: records_per_epoch: 1000 searcher: name: single - max_length: - epochs: 5 metric: eval_loss hyperparameters: training_arguments: diff --git a/examples/hf_trainer_api/hf_language_modeling/deepspeed.yaml b/examples/hf_trainer_api/hf_language_modeling/deepspeed.yaml index 66ff58889fc..8facb3c47ac 100644 --- a/examples/hf_trainer_api/hf_language_modeling/deepspeed.yaml +++ b/examples/hf_trainer_api/hf_language_modeling/deepspeed.yaml @@ -11,8 +11,6 @@ resources: slots_per_trial: 2 searcher: name: single - max_length: - batches: 100 metric: eval_loss hyperparameters: deepspeed_config: ds_configs/ds_config_stage_1.json diff --git a/examples/hf_trainer_api/hf_language_modeling/distributed.yaml b/examples/hf_trainer_api/hf_language_modeling/distributed.yaml index c305d98f490..08b62b79788 100644 --- a/examples/hf_trainer_api/hf_language_modeling/distributed.yaml +++ b/examples/hf_trainer_api/hf_language_modeling/distributed.yaml @@ -9,8 +9,6 @@ resources: slots_per_trial: 2 searcher: name: single - max_length: - batches: 100 metric: eval_loss hyperparameters: training_arguments: diff --git a/examples/hf_trainer_api/hf_language_modeling/run_clm.py b/examples/hf_trainer_api/hf_language_modeling/run_clm.py index f2210f62b1e..cb5daf88537 100644 --- a/examples/hf_trainer_api/hf_language_modeling/run_clm.py +++ b/examples/hf_trainer_api/hf_language_modeling/run_clm.py @@ -33,6 +33,7 @@ import evaluate import torch import transformers +import util from datasets import load_dataset from torch.utils.tensorboard import SummaryWriter from transformers import ( @@ -55,7 +56,6 @@ from transformers.utils.versions import require_version import determined as det -from determined.pytorch import dsat from determined.transformers import DetCallback # Will error if the minimal version of Transformers is not installed. Remove at your own risks. @@ -287,7 +287,7 @@ def parse_input_arguments( args = sys.argv[1:] args.extend(dict2args(training_arguments)) if any("--deepspeed" == arg.strip() for arg in args): - args = dsat.get_hf_args_with_overwrites(args, hparams) + args = util.get_hf_args_with_overwrites(args, hparams) model_args, data_args, training_args = parser.parse_args_into_dataclasses( args, look_for_args_file=False ) @@ -656,8 +656,7 @@ def compute_metrics(eval_preds): checkpoint = training_args.resume_from_checkpoint elif last_checkpoint is not None: checkpoint = last_checkpoint - with dsat.dsat_reporting_context(core_context, op=det_callback.current_op): - train_result = trainer.train(resume_from_checkpoint=checkpoint) + train_result = trainer.train(resume_from_checkpoint=checkpoint) trainer.save_model() # Saves the tokenizer too for easy upload metrics = train_result.metrics diff --git a/examples/hf_trainer_api/hf_language_modeling/util.py b/examples/hf_trainer_api/hf_language_modeling/util.py new file mode 100644 index 00000000000..85cdda4f4e7 --- /dev/null +++ b/examples/hf_trainer_api/hf_language_modeling/util.py @@ -0,0 +1,140 @@ +import copy +import json +import logging +import pathlib +from typing import Any, Dict, List, Optional, Union + +import filelock + +from determined import util as det_util + +CURR_DIR = pathlib.Path(".") +CONFIG_KEY = "deepspeed_config" +OVERWRITE_KEY = "overwrite_deepspeed_args" + + +def get_ds_config_from_hparams( + hparams: Dict[str, Any], + base_dir: Union[pathlib.Path, str] = CURR_DIR, +) -> Dict[str, Any]: + """Gets the DS config dictionary after merging with overwrite values. + + Follows the rules as described here: + https://docs.determined.ai/latest/training/apis-howto/deepspeed/deepspeed.html#configuration + Args: + hparams (Dict): + Hyperparameters dictionary + base_dir (pathlib.Path): + Base directory relattive to which hparams.deepspeed_config is defined + Returns: + The Deepspeed Configuration for this experiment following the overwriting rules + """ + assert CONFIG_KEY in hparams, ( + f"Expected to find {CONFIG_KEY} in the Hyperparameters section. " f"Instead found {hparams}" + ) + ds_config_relative_path = hparams[CONFIG_KEY] + base_dir = pathlib.Path(base_dir) + full_path = base_dir.joinpath(ds_config_relative_path) + with open(full_path, "r") as f: + base_ds_config: Dict[str, Any] = json.load(f) + overwrite_ds_config = hparams.get(OVERWRITE_KEY, {}) + final_ds_config = det_util.merge_dicts(base_ds_config, overwrite_ds_config) + return final_ds_config + + +def get_hf_ds_config_path_from_args(args: List[str]) -> Optional[str]: + for idx in range(len(args)): + if args[idx] == "--deepspeed": + ds_config_idx = idx + 1 + ds_config_path = args[ds_config_idx] + return ds_config_path + return None + + +def update_hf_args(args: List[str], ds_config_dict: Dict[str, Any]) -> List[str]: + """ + Updates batch-size-related HF CLI args to be consistent with the values specified in the + provided DeepSpeed config dictionary. + + Args: + args: list of CLI arguments passed to the HF entrypoint + ds_config_dict: the DeepSpeed configuration as a dictionary + """ + hf_flag_to_ds_key = { + "--per_device_train_batch_size": "train_micro_batch_size_per_gpu", + "--gradient_accumulation_steps": "gradient_accumulation_steps", + } + # Overwrite CLI args + args = copy.deepcopy(args) + for idx in range(len(args)): + if args[idx] in hf_flag_to_ds_key: + ds_key = hf_flag_to_ds_key[args[idx]] + overwrite_value = ds_config_dict[ds_key] + # Need to avoid copying possible "auto" value from json config to HF CLI. + is_auto = isinstance(overwrite_value, str) and overwrite_value.strip() == "auto" + if not is_auto: + overwrite_value_str = str(overwrite_value) + if args[idx + 1] != overwrite_value_str: + logging.warning( + f"Changing {args[idx]} from {args[idx +1]} to {overwrite_value_str}" + " to match the deespspeed config values." + ) + args[idx + 1] = overwrite_value_str + del hf_flag_to_ds_key[args[idx]] + + # Any remaining keys in hf_flag_to_ds_key were not provided as args to the HF CLI entrypoint, + # but they must be added in explicitly, to avoid falling back to HF defaults. + for hf_flag, ds_key in hf_flag_to_ds_key.items(): + hf_flag_value = ds_config_dict[ds_key] + is_auto = isinstance(hf_flag_value, str) and hf_flag_value.strip() == "auto" + if not is_auto: + hf_flag_value_str = str(hf_flag_value) + args.extend([hf_flag, hf_flag_value_str]) + logging.warning( + f"Adding {hf_flag} {hf_flag_value_str} to HF CLI args to reflect overwrite values." + ) + return args + + +def get_hf_args_with_overwrites(args: List[str], hparams: Dict[str, Any]) -> List[str]: + """Updates the submitted HF CLI Args to account for overwrite values. + + Primarily intended as a helper function for Determined AI DeepSpeed (DS) which provides + overwrite values through the `hparams["overwrite_deepspeed_args"]` which possibly include DS + batch-size related arguments (`train_batch_size`, `train_micro_batch_size_per_gpu`, and + `gradient_accumulation_steps`) which are in conflict with the corresponding HF CLI batch-size + related arguments(`--per_device_train_batch_size` and `--gradient_accumulation_steps`). This + function updates the HF CLI args to relect any such overwrite values. This process also requires + overwriting the corresponding DS json file on-cluster. + + Args: + args: the original HF CLI arguments + hparams: hyperparameter dictionary generated through Determined AI + + Returns: + args: updated HF CLI arguments + """ + if OVERWRITE_KEY not in hparams: + logging.info( + f"{OVERWRITE_KEY} key not found in hparams, `get_hf_args_with_overwrites` " "is a no-op" + ) + return args + + ds_config_path = get_hf_ds_config_path_from_args(args) + assert ds_config_path is not None, "--deepspeed flag not found in HuggingFace args!" + + # A file lock is required during both the writing and reading. + with filelock.FileLock(ds_config_path + ".lock"): + with open(ds_config_path, "r") as f: + ds_config_dict = json.load(f) + + # Then merge all overwrites into the ds_config + overwritten_ds_config_dict = det_util.merge_dicts(ds_config_dict, hparams[OVERWRITE_KEY]) + + # We need to actually overwrite the ds json config file, due to how HF processes args. + with open(ds_config_path, "w") as f: + json.dump(overwritten_ds_config_dict, f) + # Finally overwrite the CLI args + args = update_hf_args(args, overwritten_ds_config_dict) + + return args diff --git a/examples/tutorials/core_api/0_start.yaml b/examples/tutorials/core_api/0_start.yaml index 49f5aa0eeb0..e533cfb6230 100644 --- a/examples/tutorials/core_api/0_start.yaml +++ b/examples/tutorials/core_api/0_start.yaml @@ -6,7 +6,5 @@ searcher: name: single # metric is required but it shouldn't hurt to ignore it at this point. metric: x - # max_length is ignored if the training script ignores it. - max_length: 1 max_restarts: 0 diff --git a/examples/tutorials/core_api/1_metrics.yaml b/examples/tutorials/core_api/1_metrics.yaml index 6484cdb6d1b..2a4c3ee2e81 100644 --- a/examples/tutorials/core_api/1_metrics.yaml +++ b/examples/tutorials/core_api/1_metrics.yaml @@ -6,6 +6,5 @@ entrypoint: python3 1_metrics.py searcher: name: single metric: x - max_length: 1 max_restarts: 0 diff --git a/examples/tutorials/core_api/2_checkpoints.py b/examples/tutorials/core_api/2_checkpoints.py index 42fbbcbd606..44167bc2cdb 100644 --- a/examples/tutorials/core_api/2_checkpoints.py +++ b/examples/tutorials/core_api/2_checkpoints.py @@ -66,8 +66,8 @@ def main(core_context, latest_checkpoint, trial_id, increment_by): save_state(x, steps_completed, trial_id, path) # NEW: check for a preemption signal. This could originate from a - # higher-priority task bumping us off the cluster, or for a user pausing - # the experiment via the WebUI or CLI. + # higher-priority task bumping us off the cluster, from the hpsearch + # algorithm, or from a user pausing in the WebUI or CLI. if core_context.preempt.should_preempt(): # At this point, a checkpoint was just saved, so training can exit # immediately and resume when the trial is reactivated. diff --git a/examples/tutorials/core_api/2_checkpoints.yaml b/examples/tutorials/core_api/2_checkpoints.yaml index fddde2c179a..025afedf201 100644 --- a/examples/tutorials/core_api/2_checkpoints.yaml +++ b/examples/tutorials/core_api/2_checkpoints.yaml @@ -6,6 +6,5 @@ entrypoint: python3 2_checkpoints.py searcher: name: single metric: x - max_length: 1 max_restarts: 0 diff --git a/examples/tutorials/core_api/3_hpsearch.py b/examples/tutorials/core_api/3_hpsearch.py index 4afe788ffc7..ef9a49ba0f9 100644 --- a/examples/tutorials/core_api/3_hpsearch.py +++ b/examples/tutorials/core_api/3_hpsearch.py @@ -27,49 +27,37 @@ def load_state(trial_id, checkpoint_directory): def main(core_context, latest_checkpoint, trial_id, increment_by): x = 0 + max_length = 100 starting_batch = 0 if latest_checkpoint is not None: with core_context.checkpoint.restore_path(latest_checkpoint) as path: x, starting_batch = load_state(trial_id, path) - # NEW: Iterate through the core_context.searcher.operations() to decide how long to train for. - batch = starting_batch - last_checkpoint_batch = None - for op in core_context.searcher.operations(): - # NEW: Use a while loop for easier accounting of absolute lengths. - while batch < op.length: - x += increment_by - steps_completed = batch + 1 - time.sleep(0.1) - logging.info(f"x is now {x}") - if steps_completed % 10 == 0: - core_context.train.report_training_metrics( - steps_completed=steps_completed, metrics={"x": x} - ) - - # NEW: report progress once in a while. - op.report_progress(batch) - - checkpoint_metadata = {"steps_completed": steps_completed} - with core_context.checkpoint.store_path(checkpoint_metadata) as (path, uuid): - save_state(x, steps_completed, trial_id, path) - last_checkpoint_batch = steps_completed - if core_context.preempt.should_preempt(): - return - batch += 1 - # NEW: After training for each op, you typically validate and report the - # searcher metric to the master. - core_context.train.report_validation_metrics( - steps_completed=steps_completed, metrics={"x": x} - ) - op.report_completed(x) - - # NEW: after searching, save a checkpoint if our last one is not up-to-date. - if last_checkpoint_batch != steps_completed: - checkpoint_metadata = {"steps_completed": steps_completed} - with core_context.checkpoint.store_path(checkpoint_metadata) as (path, uuid): - save_state(x, steps_completed, trial_id, path) + for batch in range(starting_batch, max_length): + x += increment_by + steps_completed = batch + 1 + time.sleep(0.1) + logging.info(f"x is now {x}") + if steps_completed % 10 == 0: + core_context.train.report_training_metrics( + steps_completed=steps_completed, metrics={"x": x} + ) + core_context.train.report_progress(steps_completed / float(max_length)) + + # NEW: periodically report validation metrics, which the searcher + # may monitor for the purpose of early-stopping. Note that for the + # ASHA searcher you need to provide the "time" metric you + # configured in the experiment config, in this case `batch`. + core_context.train.report_validation_metrics( + steps_completed=steps_completed, metrics={"x": x, "batches": steps_completed} + ) + + checkpoint_metadata = {"steps_completed": steps_completed} + with core_context.checkpoint.store_path(checkpoint_metadata) as (path, uuid): + save_state(x, steps_completed, trial_id, path) + if core_context.preempt.should_preempt(): + return if __name__ == "__main__": diff --git a/examples/tutorials/core_api/3_hpsearch.yaml b/examples/tutorials/core_api/3_hpsearch.yaml index 9471ef9307e..b3d70f6325c 100644 --- a/examples/tutorials/core_api/3_hpsearch.yaml +++ b/examples/tutorials/core_api/3_hpsearch.yaml @@ -12,7 +12,8 @@ hyperparameters: searcher: name: adaptive_asha metric: x - max_length: 100 max_trials: 10 + time_metric: batches + max_time: 100 max_restarts: 0 diff --git a/examples/tutorials/core_api/4_distributed.py b/examples/tutorials/core_api/4_distributed.py index 80c84af7c1c..a2ebd048013 100644 --- a/examples/tutorials/core_api/4_distributed.py +++ b/examples/tutorials/core_api/4_distributed.py @@ -38,57 +38,42 @@ def load_state(trial_id, checkpoint_directory): def main(core_context, latest_checkpoint, trial_id, increment_by): x = 0 + max_length = 100 starting_batch = 0 if latest_checkpoint is not None: with core_context.checkpoint.restore_path(latest_checkpoint) as path: x, starting_batch = load_state(trial_id, path) - batch = starting_batch - last_checkpoint_batch = None - for op in core_context.searcher.operations(): - while batch < op.length: - # NEW: Increment by the sum of every worker's increment_by value. - # In reality, it is just increment_by*num_workers, but the point is - # to show how to use the communication primitives. - all_increment_bys = core_context.distributed.allgather(increment_by) - x += sum(all_increment_bys) - steps_completed = batch + 1 - time.sleep(0.1) - # NEW: some logs are easier to read if you only log from the chief. - if core_context.distributed.rank == 0: - logging.info(f"x is now {x}") - if steps_completed % 10 == 0: - # NEW: only the chief may report training metrics and progress, - # or upload checkpoints. - if core_context.distributed.rank == 0: - core_context.train.report_training_metrics( - steps_completed=steps_completed, metrics={"x": x} - ) - op.report_progress(steps_completed) - checkpoint_metadata = {"steps_completed": steps_completed} - with core_context.checkpoint.store_path(checkpoint_metadata) as ( - checkpoint_directory, - uuid, - ): - save_state(x, steps_completed, trial_id, checkpoint_directory) - last_checkpoint_batch = steps_completed - if core_context.preempt.should_preempt(): - return - batch += 1 - - # NEW: only the chief may report validation metrics and completed operations. + for batch in range(starting_batch, max_length): + # NEW: Increment by the sum of every worker's increment_by value. + # In reality, it is just increment_by*num_workers, but the point is + # to show how to use the communication primitives. + all_increment_bys = core_context.distributed.allgather(increment_by) + x += sum(all_increment_bys) + steps_completed = batch + 1 + time.sleep(0.1) + # NEW: some logs are easier to read if you only log from the chief. if core_context.distributed.rank == 0: - core_context.train.report_validation_metrics( - steps_completed=steps_completed, metrics={"x": x} - ) - op.report_completed(x) - - # NEW: again, only the chief may upload checkpoints. - if core_context.distributed.rank == 0 and last_checkpoint_batch != steps_completed: - checkpoint_metadata = {"steps_completed": steps_completed} - with core_context.checkpoint.store_path(checkpoint_metadata) as (path, uuid): - save_state(x, steps_completed, trial_id, path) + logging.info(f"x is now {x}") + if steps_completed % 10 == 0: + # NEW: only the chief may report metrics and progress, + # or upload checkpoints. + if core_context.distributed.rank == 0: + core_context.train.report_training_metrics( + steps_completed=steps_completed, metrics={"x": x} + ) + core_context.train.report_progress(steps_completed / float(max_length)) + + core_context.train.report_validation_metrics( + steps_completed=steps_completed, metrics={"x": x} + ) + + checkpoint_metadata = {"steps_completed": steps_completed} + with core_context.checkpoint.store_path(checkpoint_metadata) as (path, uuid): + save_state(x, steps_completed, trial_id, path) + if core_context.preempt.should_preempt(): + return # NEW: Launch one process per slot. In many distributed training frameworks, like horovod, diff --git a/examples/tutorials/core_api/4_distributed.yaml b/examples/tutorials/core_api/4_distributed.yaml index 559a8479c16..a1a2878bc26 100644 --- a/examples/tutorials/core_api/4_distributed.yaml +++ b/examples/tutorials/core_api/4_distributed.yaml @@ -14,6 +14,5 @@ hyperparameters: searcher: name: single metric: x - max_length: 100 max_restarts: 0 diff --git a/examples/tutorials/core_api_pytorch_mnist/adaptive.yaml b/examples/tutorials/core_api_pytorch_mnist/adaptive.yaml index 58728ef7bf7..b87936cc26d 100644 --- a/examples/tutorials/core_api_pytorch_mnist/adaptive.yaml +++ b/examples/tutorials/core_api_pytorch_mnist/adaptive.yaml @@ -29,6 +29,6 @@ searcher: metric: test_loss smaller_is_better: true max_trials: 500 - max_length: - epochs: 20 + time_metric: epochs + max_time: 20 entrypoint: python3 model_def_adaptive.py diff --git a/examples/tutorials/core_api_pytorch_mnist/checkpoints.yaml b/examples/tutorials/core_api_pytorch_mnist/checkpoints.yaml index 6dfd1e0997b..807a50ff003 100644 --- a/examples/tutorials/core_api_pytorch_mnist/checkpoints.yaml +++ b/examples/tutorials/core_api_pytorch_mnist/checkpoints.yaml @@ -1,8 +1,7 @@ name: coreapi_mnist_tutorial_checkpoints description: Save and load checkpoints as well as pause and resume experiments in the WebUI. -entrypoint: python3 model_def_checkpoints.py +entrypoint: python3 model_def_checkpoints.py max_restarts: 0 searcher: name: single - max_length: 1 metric: val_loss diff --git a/examples/tutorials/core_api_pytorch_mnist/const.yaml b/examples/tutorials/core_api_pytorch_mnist/const.yaml index 3cc0f81c1fe..b8dc76a0b76 100644 --- a/examples/tutorials/core_api_pytorch_mnist/const.yaml +++ b/examples/tutorials/core_api_pytorch_mnist/const.yaml @@ -1,8 +1,7 @@ name: coreapi_mnist_tutorial description: A bare-bones experiment configuration file to run the model_def.py script on a Determined cluster. -entrypoint: python3 model_def.py +entrypoint: python3 model_def.py max_restarts: 0 searcher: name: single - max_length: 1 metric: val_loss diff --git a/examples/tutorials/core_api_pytorch_mnist/distributed.yaml b/examples/tutorials/core_api_pytorch_mnist/distributed.yaml index 6d2b8c39116..7c4511ca114 100644 --- a/examples/tutorials/core_api_pytorch_mnist/distributed.yaml +++ b/examples/tutorials/core_api_pytorch_mnist/distributed.yaml @@ -32,7 +32,5 @@ searcher: metric: test_loss smaller_is_better: true max_trials: 500 - max_length: - epochs: 20 resources: slots_per_trial: 4 diff --git a/examples/tutorials/core_api_pytorch_mnist/metrics.yaml b/examples/tutorials/core_api_pytorch_mnist/metrics.yaml index c54a681519f..e9820440a3a 100644 --- a/examples/tutorials/core_api_pytorch_mnist/metrics.yaml +++ b/examples/tutorials/core_api_pytorch_mnist/metrics.yaml @@ -1,8 +1,7 @@ name: coreapi_mnist_tutorial_metrics description: Report training and testing metrics to the master. -entrypoint: python3 model_def_metrics.py +entrypoint: python3 model_def_metrics.py max_restarts: 0 searcher: name: single - max_length: 1 metric: val_loss diff --git a/examples/tutorials/core_api_pytorch_mnist/model_def_adaptive.py b/examples/tutorials/core_api_pytorch_mnist/model_def_adaptive.py index 36b03e41dc7..b599673d2d7 100644 --- a/examples/tutorials/core_api_pytorch_mnist/model_def_adaptive.py +++ b/examples/tutorials/core_api_pytorch_mnist/model_def_adaptive.py @@ -49,9 +49,7 @@ def forward(self, x): return output -# NEW: Modify function header to include op for reporting training -# progress to master. -def train(args, model, device, train_loader, optimizer, core_context, epoch_idx, op): +def train(args, model, device, train_loader, optimizer, epoch_idx, core_context): model.train() for batch_idx, (data, target) in enumerate(train_loader): data, target = data.to(device), target.to(device) @@ -79,13 +77,9 @@ def train(args, model, device, train_loader, optimizer, core_context, epoch_idx, if args.dry_run: break - # NEW: Report progress once in a while. - op.report_progress(epoch_idx) - -# NEW: Modify function header to include op for reporting training -# progress to master and return test loss. -def test(args, model, device, test_loader, core_context, steps_completed, op) -> int: +# NEW: pass epochs_completed through to metrics, for the ASHA searcher to use. +def test(args, model, device, test_loader, core_context, steps_completed, epochs_completed): model.eval() test_loss = 0 correct = 0 @@ -107,12 +101,9 @@ def test(args, model, device, test_loader, core_context, steps_completed, op) -> core_context.train.report_validation_metrics( steps_completed=steps_completed, - metrics={"test_loss": test_loss}, + metrics={"test_loss": test_loss, "epochs": epochs_completed}, ) - # NEW: Return test_loss. - return test_loss - def load_state(checkpoint_directory, trial_id): checkpoint_directory = pathlib.Path(checkpoint_directory) @@ -232,41 +223,22 @@ def main(core_context): scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) - # NEW: Iterate through the core_context.searcher.operations() to - # decide how long to train for. - # Start with the number of epochs completed, in case of pausing - # and resuming the experiment. - epoch_idx = epochs_completed - last_checkpoint_batch = None - - for op in core_context.searcher.operations(): - # NEW: Use a while loop for easier accounting of absolute lengths. - while epoch_idx < op.length: - # NEW: Pass op into train() and test(). - train(args, model, device, train_loader, optimizer, core_context, epoch_idx, op) - epochs_completed = epoch_idx + 1 - steps_completed = epochs_completed * len(train_loader) - test_loss = test(args, model, device, test_loader, core_context, steps_completed, op) - - scheduler.step() - - checkpoint_metadata_dict = { - "steps_completed": steps_completed, - } - - epoch_idx += 1 - - with core_context.checkpoint.store_path(checkpoint_metadata_dict) as (path, storage_id): - torch.save(model.state_dict(), path / "checkpoint.pt") - with path.joinpath("state").open("w") as f: - f.write(f"{epochs_completed},{info.trial.trial_id}") - - if core_context.preempt.should_preempt(): - return - - # NEW: After training for each op, validate and report the - # searcher metric to the master. - op.report_completed(test_loss) + for epoch_idx in range(epochs_completed, args.epochs): + train(args, model, device, train_loader, optimizer, epoch_idx, core_context) + epochs_completed = epoch_idx + 1 + steps_completed = epochs_completed * len(train_loader) + test(args, model, device, test_loader, core_context, steps_completed, epochs_completed) + + scheduler.step() + + checkpoint_metadata_dict = {"steps_completed": steps_completed} + with core_context.checkpoint.store_path(checkpoint_metadata_dict) as (path, storage_id): + torch.save(model.state_dict(), path / "checkpoint.pt") + with path.joinpath("state").open("w") as f: + f.write(f"{epochs_completed},{info.trial.trial_id}") + + if core_context.preempt.should_preempt(): + return if __name__ == "__main__": diff --git a/examples/tutorials/core_api_pytorch_mnist/model_def_checkpoints.py b/examples/tutorials/core_api_pytorch_mnist/model_def_checkpoints.py index 325e397bcf5..60d190f5179 100644 --- a/examples/tutorials/core_api_pytorch_mnist/model_def_checkpoints.py +++ b/examples/tutorials/core_api_pytorch_mnist/model_def_checkpoints.py @@ -72,7 +72,7 @@ def train(args, model, device, train_loader, optimizer, epoch_idx, core_context) break -def test(args, model, device, test_loader, epoch, core_context, steps_completed): +def test(args, model, device, test_loader, core_context, steps_completed): model.eval() test_loss = 0 correct = 0 @@ -229,7 +229,6 @@ def main(core_context): model, device, test_loader, - epoch_idx, core_context, steps_completed=steps_completed, ) diff --git a/examples/tutorials/core_api_pytorch_mnist/model_def_distributed.py b/examples/tutorials/core_api_pytorch_mnist/model_def_distributed.py index f9b1cc3d08b..6c309758e8a 100644 --- a/examples/tutorials/core_api_pytorch_mnist/model_def_distributed.py +++ b/examples/tutorials/core_api_pytorch_mnist/model_def_distributed.py @@ -54,7 +54,7 @@ def forward(self, x): return output -def train(args, model, device, train_loader, optimizer, core_context, epoch_idx, op): +def train(args, model, device, train_loader, optimizer, epoch_idx, core_context): model.train() for batch_idx, (data, target) in enumerate(train_loader): data, target = data.to(device), target.to(device) @@ -86,12 +86,8 @@ def train(args, model, device, train_loader, optimizer, core_context, epoch_idx, if args.dry_run: break - # NEW: Report progress only on rank 0. - if core_context.distributed.rank == 0: - op.report_progress(epoch_idx) - -def test(args, model, device, test_loader, core_context, steps_completed, op) -> int: +def test(args, model, device, test_loader, core_context, steps_completed): model.eval() test_loss = 0 correct = 0 @@ -117,8 +113,6 @@ def test(args, model, device, test_loader, core_context, steps_completed, op) -> steps_completed=steps_completed, metrics={"test_loss": test_loss} ) - return test_loss - def load_state(checkpoint_directory, trial_id): checkpoint_directory = pathlib.Path(checkpoint_directory) @@ -264,40 +258,24 @@ def main(core_context): optimizer = optim.Adadelta(model.parameters(), lr=hparams["learning_rate"]) scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) - epoch_idx = epochs_completed - last_checkpoint_batch = None - - for op in core_context.searcher.operations(): - while epoch_idx < op.length: - train(args, model, device, train_loader, optimizer, core_context, epoch_idx, op) - epochs_completed = epoch_idx + 1 - steps_completed = epochs_completed * len(train_loader) - test_loss = test(args, model, device, test_loader, core_context, steps_completed, op) + for epoch_idx in range(epochs_completed, args.epochs): + train(args, model, device, train_loader, optimizer, epoch_idx, core_context) + epochs_completed = epoch_idx + 1 + steps_completed = epochs_completed * len(train_loader) + test(args, model, device, test_loader, core_context, steps_completed) - scheduler.step() + scheduler.step() - checkpoint_metadata_dict = { - "steps_completed": steps_completed, - } - - epoch_idx += 1 - - # Store checkpoints only on rank 0. - if core_context.distributed.rank == 0: - with core_context.checkpoint.store_path(checkpoint_metadata_dict) as ( - path, - storage_id, - ): - torch.save(model.state_dict(), path / "checkpoint.pt") - with path.joinpath("state").open("w") as f: - f.write(f"{epochs_completed},{info.trial.trial_id}") - - if core_context.preempt.should_preempt(): - return - - # Report completed only on rank 0. + # NEW: Store checkpoints only on rank 0. if core_context.distributed.rank == 0: - op.report_completed(test_loss) + checkpoint_metadata_dict = {"steps_completed": steps_completed} + with core_context.checkpoint.store_path(checkpoint_metadata_dict) as (path, storage_id): + torch.save(model.state_dict(), path / "checkpoint.pt") + with path.joinpath("state").open("w") as f: + f.write(f"{epochs_completed},{info.trial.trial_id}") + + if core_context.preempt.should_preempt(): + return # Docs snippet start: initialize process group diff --git a/examples/tutorials/core_api_pytorch_mnist/model_def_metrics.py b/examples/tutorials/core_api_pytorch_mnist/model_def_metrics.py index 8769e20b0fb..ddcedb6d2d7 100644 --- a/examples/tutorials/core_api_pytorch_mnist/model_def_metrics.py +++ b/examples/tutorials/core_api_pytorch_mnist/model_def_metrics.py @@ -69,11 +69,9 @@ def train(args, model, device, train_loader, optimizer, epoch_idx, core_context) ) ) # Docs snippet start: report training metrics - # NEW: Report training metrics to Determined - # master via core_context. - # Index by (batch_idx + 1) * (epoch-1) * len(train_loader) - # to continuously plot loss on one graph for consecutive - # epochs. + # NEW: Report training metrics to Determined master via core_context. + # Index by batches_completed + epoch_idx * len(train_loader) + # to continuously plot loss on one graph for consecutive epochs. core_context.train.report_training_metrics( steps_completed=batches_completed + epoch_idx * len(train_loader), metrics={"train_loss": loss.item()}, @@ -108,8 +106,7 @@ def test(args, model, device, test_loader, epoch, core_context, steps_completed) ) # Docs snippet end: include args # Docs snippet start: report validation metrics - # NEW: Report validation metrics to Determined master - # via core_context. + # NEW: Report validation metrics to Determined master via core_context. core_context.train.report_validation_metrics( steps_completed=steps_completed, metrics={"test_loss": test_loss}, @@ -213,17 +210,17 @@ def main(core_context): scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) for epoch_idx in range(0, args.epochs): - # Docs snippet start: calculate steps completed - # NEW: Calculate steps_completed for plotting test metrics. - steps_completed = epoch_idx * len(train_loader) - # Docs snippet end: calculate steps completed - # Docs snippet start: pass core context # NEW: Pass core_context into train() and test(). train(args, model, device, train_loader, optimizer, epoch_idx, core_context) - # NEW: Pass args, test_loader, epoch, and steps_completed into - # test(). + # Docs snippet start: calculate steps completed + # NEW: Calculate steps_completed for plotting test metrics. + epochs_completed = epoch_idx + 1 + steps_completed = epochs_completed * len(train_loader) + # Docs snippet end: calculate steps completed + + # NEW: Pass args, test_loader, epoch, and steps_completed into test(). test( args, model, @@ -236,8 +233,7 @@ def main(core_context): scheduler.step() # Docs snippet end: pass core context - # NEW: Remove model saving logic, checkpointing shown in next - # stage. + # NEW: Remove model saving logic, checkpointing shown in next stage. # Docs snippet start: modify main loop core context diff --git a/examples/tutorials/mnist_pytorch/README.md b/examples/tutorials/mnist_pytorch/README.md index 5263bc5cdf3..828a06efe13 100644 --- a/examples/tutorials/mnist_pytorch/README.md +++ b/examples/tutorials/mnist_pytorch/README.md @@ -15,47 +15,45 @@ tutorial](https://github.com/pytorch/examples/tree/master/mnist). * **adaptive.yaml**: Perform a hyperparameter search using Determined's state-of-the-art adaptive hyperparameter tuning algorithm. ## Data -This examples uses the MNIST dataset from the `torchvision` datasets subpackage. See -[torchvision docs](https://pytorch.org/vision/main/generated/torchvision.datasets.MNIST.html#torchvision.datasets.MNIST) +This examples uses the MNIST dataset from the `torchvision` datasets subpackage. See +[torchvision docs](https://pytorch.org/vision/main/generated/torchvision.datasets.MNIST.html#torchvision.datasets.MNIST) for details. ## To Run If you have not yet installed Determined, installation instructions can be found under `docs/install-admin.html` or at https://docs.determined.ai/latest/index.html -The training loop is invoked through the -[`Trainer.fit()`](https://docs.determined.ai/latest/reference/training/api-pytorch-reference.html#determined.pytorch.Trainer.fit) -method, which accepts various arguments for configuring training behavior. The example shown in `train.py` +The training loop is invoked through the +[`Trainer.fit()`](https://docs.determined.ai/latest/reference/training/api-pytorch-reference.html#determined.pytorch.Trainer.fit) +method, which accepts various arguments for configuring training behavior. The example shown in `train.py` illustrates a single implementation that can run in two modes (local training and on-cluster) without any code changes. ### Local Training -The training code in `train.py` can be invoked locally as a regular Python script. Configure the appropriate training -lengths, checkpoint/validation periods, or other desired local training functionality in the `Trainer.fit()` call, -then run `python3 train.py` from your local environment. +The training code in `train.py` can be invoked locally as a regular Python script. Customize the training functionality +in the `Trainer.fit()` call, then run `python3 train.py --batches 1000` from your local environment. ### On-cluster -To run training on-cluster, configure the desired training arguments (checkpoint or validation periods, -checkpoint to start from, etc.) in the `Trainer.fit()` call. An experiment configuration file is also required -for on-cluster experiments (several examples are included in the directory). +To run training on-cluster, customize the training functionality in the `Trainer.fit()` call. An experiment +configuration file is also required for on-cluster experiments (several examples are included in the directory). Then the code can be submitted to Determined for on-cluster training by running this command from the current directory: -`det experiment create const.yaml .`. The other configurations can be run by specifying the desired +`det experiment create const.yaml .`. The other configurations can be run by specifying the desired configuration file in place of `const.yaml`. #### Distributed Training -To train on-cluster across multiple nodes, `slots_per_trial` and `entrypoint`must be configured in the experiment configuration. -`entrypoint` should wrap `train.py` with a Determined launch layer module, which will launch the training script across -the slots specified. The launch layer module can be used in single-slot trials as well, to avoid configuration changes +To train on-cluster across multiple nodes, `slots_per_trial` and `entrypoint`must be configured in the experiment configuration. +`entrypoint` should wrap `train.py` with a Determined launch layer module, which will launch the training script across +the slots specified. The launch layer module can be used in single-slot trials as well, to avoid configuration changes between iterations. ```yaml ... resources: slots_per_trial: 2 -entrypoint: python3 -m determined.launch.torch_distributed python3 train.py +entrypoint: python3 -m determined.launch.torch_distributed python3 train.py --epochs 1 ``` ## Results Training the model with the hyperparameter settings in `const.yaml` should yield -a validation accuracy of ~97%. +a validation accuracy of ~97%. diff --git a/examples/tutorials/mnist_pytorch/adaptive.yaml b/examples/tutorials/mnist_pytorch/adaptive.yaml index 0dc9015a3e1..5953cdad5d6 100644 --- a/examples/tutorials/mnist_pytorch/adaptive.yaml +++ b/examples/tutorials/mnist_pytorch/adaptive.yaml @@ -25,6 +25,6 @@ searcher: metric: validation_loss smaller_is_better: true max_trials: 16 - max_length: - batches: 937 #60,000 training images with batch size 64 -entrypoint: python3 train.py + time_metric: batches + max_time: 937 # 60,000 training images with batch size 64 +entrypoint: python3 train.py --epochs 1 diff --git a/examples/tutorials/mnist_pytorch/const.yaml b/examples/tutorials/mnist_pytorch/const.yaml index ccc802d6a59..c81e5e62f92 100644 --- a/examples/tutorials/mnist_pytorch/const.yaml +++ b/examples/tutorials/mnist_pytorch/const.yaml @@ -8,7 +8,5 @@ hyperparameters: searcher: name: single metric: validation_loss - max_length: - batches: 1000 # approximately 1 epoch smaller_is_better: true -entrypoint: python3 train.py +entrypoint: python3 train.py --epochs 1 diff --git a/examples/tutorials/mnist_pytorch/dist_random.yaml b/examples/tutorials/mnist_pytorch/dist_random.yaml index 4f7cc2d66b0..b2d74389db6 100644 --- a/examples/tutorials/mnist_pytorch/dist_random.yaml +++ b/examples/tutorials/mnist_pytorch/dist_random.yaml @@ -25,8 +25,8 @@ searcher: metric: accuracy smaller_is_better: true max_trials: 2 - max_length: - epochs: 1 resources: slots_per_trial: 1 -entrypoint: python3 -m determined.launch.torch_distributed python3 train.py +entrypoint: >- + python3 -m determined.launch.torch_distributed + python3 train.py --epochs 1 diff --git a/examples/tutorials/mnist_pytorch/distributed.yaml b/examples/tutorials/mnist_pytorch/distributed.yaml index a11654dd6db..904a3694bb2 100644 --- a/examples/tutorials/mnist_pytorch/distributed.yaml +++ b/examples/tutorials/mnist_pytorch/distributed.yaml @@ -8,9 +8,9 @@ hyperparameters: searcher: name: single metric: validation_loss - max_length: - epochs: 1 smaller_is_better: true resources: slots_per_trial: 8 -entrypoint: python3 -m determined.launch.torch_distributed python3 train.py +entrypoint: >- + python3 -m determined.launch.torch_distributed + python3 train.py --epochs 1 diff --git a/examples/tutorials/mnist_pytorch/train.py b/examples/tutorials/mnist_pytorch/train.py index 1fa1fee3821..78d7eadd53c 100644 --- a/examples/tutorials/mnist_pytorch/train.py +++ b/examples/tutorials/mnist_pytorch/train.py @@ -12,6 +12,7 @@ The model can be trained either locally or on-cluster with the same training code. """ +import argparse import logging import pathlib from typing import Any, Dict @@ -76,10 +77,13 @@ def evaluate_batch(self, batch: pytorch.TorchData, batch_idx: int) -> Dict[str, pred = output.argmax(dim=1, keepdim=True) accuracy = pred.eq(labels.view_as(pred)).sum().item() / len(batch_data) - return {"validation_loss": validation_loss, "accuracy": accuracy} + return { + "validation_loss": validation_loss, + "accuracy": accuracy, + } -def run(local: bool = False): +def run(max_length, local: bool = False): """Initializes the trial and runs the training loop. This method configures the appropriate training parameters for both local and on-cluster @@ -100,11 +104,9 @@ def run(local: bool = False): yml = yaml.YAML(typ="safe", pure=True) conf = yml.load(pathlib.Path("./const.yaml").read_text()) hparams = conf["hyperparameters"] - max_length = pytorch.Batch(100) # Train for 100 batches. latest_checkpoint = None else: hparams = info.trial.hparams # Get instance of hparam values from Determined cluster info. - max_length = None # On-cluster training trains for the searcher's configured length. latest_checkpoint = ( info.latest_checkpoint ) # (Optional) Configure checkpoint for pause/resume functionality. @@ -112,12 +114,30 @@ def run(local: bool = False): with pytorch.init() as train_context: trial = MNistTrial(train_context, hparams=hparams) trainer = pytorch.Trainer(trial, train_context) - trainer.fit(max_length=max_length, latest_checkpoint=latest_checkpoint) + trainer.fit( + max_length=max_length, + latest_checkpoint=latest_checkpoint, + validation_period=pytorch.Batch(100), + ) if __name__ == "__main__": # Configure logging logging.basicConfig(level=logging.INFO, format=det.LOG_FORMAT) + # Parse command line options + parser = argparse.ArgumentParser() + group = parser.add_mutually_exclusive_group() + group.add_argument("--batches", type=int) + group.add_argument("--epochs", type=int) + args = parser.parse_args() + if args.batches: + max_length = pytorch.Batch(args.batches) + elif args.epochs: + max_length = pytorch.Epoch(args.epochs) + else: + # default training length + max_length = pytorch.Batch(100) + local_training = det.get_cluster_info() is None - run(local=local_training) + run(max_length, local=local_training) diff --git a/harness/determined/_execution.py b/harness/determined/_execution.py index d5f55cc5456..277ac9c21cd 100644 --- a/harness/determined/_execution.py +++ b/harness/determined/_execution.py @@ -9,7 +9,7 @@ from typing import Any, Dict, Iterator, List, Optional, Tuple, Type import determined as det -from determined import constants, core, gpu, load +from determined import core, gpu, load logger = logging.getLogger("determined") @@ -116,6 +116,20 @@ def _make_test_experiment_config(config: Dict[str, Any]) -> Dict[str, Any]: return config_test +DEFAULT_TEST_CONFIG = { + "searcher": {"name": "single", "max_length": {"batches": 100}}, + "scheduling_unit": 100, + "resources": {"slots_per_trial": 1}, + "optimizations": { + "aggregation_frequency": 1, + "average_aggregated_gradients": True, + "average_training_metrics": True, + "gradient_compression": False, + "mixed_precision": "O0", + }, +} + + def _make_local_execution_exp_config( input_config: Optional[Dict[str, Any]], checkpoint_dir: str, @@ -157,7 +171,7 @@ def _make_local_execution_exp_config( "host_path": os.path.abspath(checkpoint_dir), } - return {"checkpoint_storage": checkpoint_storage, **constants.DEFAULT_EXP_CFG, **input_config} + return {"checkpoint_storage": checkpoint_storage, **DEFAULT_TEST_CONFIG, **input_config} def _make_local_execution_env( diff --git a/harness/determined/cli/cli.py b/harness/determined/cli/cli.py index 50044554d3b..17b63a997f6 100644 --- a/harness/determined/cli/cli.py +++ b/harness/determined/cli/cli.py @@ -47,66 +47,52 @@ version, workspace, ) -from determined.common import api, util, yaml +from determined.common import api, util from determined.common.api import bindings, certs +def _render_search_summary(resp: bindings.v1PreviewHPSearchResponse) -> str: + output = [ + termcolor.colored("Using search configuration:", "green"), + ] + + # For mypy + assert resp.summary and resp.summary.config and resp.summary.trials + # Exclude empty configs from rendering. + searcher_config = {k: v for k, v in resp.summary.config.items() if v is not None} + + config_str = render.format_object_as_yaml(searcher_config) + output.append(config_str) + headers = ["Trials", "Training Time"] + trial_summaries = [] + for trial_summary in resp.summary.trials: + num_trials = trial_summary.count + trial_unit = trial_summary.unit + if trial_unit.maxLength: + summary = "train to completion" + else: + summary = f"train for {trial_unit.value} {trial_unit.name}" + trial_summaries.append([num_trials, summary]) + + output.append(tabulate.tabulate(trial_summaries, headers, tablefmt="presto")) + return "\n".join(output) + + def preview_search(args: argparse.Namespace) -> None: sess = cli.setup_session(args) experiment_config = util.safe_load_yaml_with_exceptions(args.config_file) args.config_file.close() if "searcher" not in experiment_config: - print("Experiment configuration must have 'searcher' section") - sys.exit(1) - r = sess.post("searcher/preview", json=experiment_config) - j = r.json() + raise errors.CliError("Missing 'searcher' config section in experiment config.") - def to_full_name(kind: str) -> str: - try: - # The unitless searcher case, for masters newer than 0.17.6. - length = int(kind) - return f"train for {length}" - except ValueError: - pass - if kind[-1] == "R": - return "train {} records".format(kind[:-1]) - if kind[-1] == "B": - return "train {} batch(es)".format(kind[:-1]) - if kind[-1] == "E": - return "train {} epoch(s)".format(kind[:-1]) - if kind == "V": - return "validation" - raise ValueError("unexpected kind: {}".format(kind)) - - def render_sequence(sequence: List[str]) -> str: - if not sequence: - return "N/A" - instructions = [] - current = sequence[0] - count = 0 - for k in sequence: - if k != current: - instructions.append("{} x {}".format(count, to_full_name(current))) - current = k - count = 1 - else: - count += 1 - instructions.append("{} x {}".format(count, to_full_name(current))) - return ", ".join(instructions) - - headers = ["Trials", "Breakdown"] - values = [ - (count, render_sequence(operations.split())) for operations, count in j["results"].items() - ] - - print(termcolor.colored("Using search configuration:", "green")) - yml = yaml.YAML() - yml.indent(mapping=2, sequence=4, offset=2) - yml.dump(experiment_config["searcher"], sys.stdout) - print() - print("This search will create a total of {} trial(s).".format(sum(j["results"].values()))) - print(tabulate.tabulate(values, headers, tablefmt="presto"), flush=False) + resp = bindings.post_PreviewHPSearch( + session=sess, + body=bindings.v1PreviewHPSearchRequest( + config=experiment_config, + ), + ) + print(_render_search_summary(resp=resp)) args_description = [ diff --git a/harness/determined/cli/experiment.py b/harness/determined/cli/experiment.py index 52ed843a67b..8ec1b208cf2 100644 --- a/harness/determined/cli/experiment.py +++ b/harness/determined/cli/experiment.py @@ -6,6 +6,7 @@ import pprint import sys import time +import warnings from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Union import tabulate @@ -188,6 +189,14 @@ def submit_experiment(args: argparse.Namespace) -> None: ) if args.test_mode: + warnings.warn( + "The --test flag to det experiment create has been deprecated in Determined 0.38.0 and " + "will be removed in a future version. The searcher.max_length setting of the " + "experiment config has also been deprecated, and since --test mode relies on that " + "setting, --test mode will cease to work as soon as you remove max_length.", + FutureWarning, + stacklevel=2, + ) print(termcolor.colored("Validating experiment configuration...", "yellow"), end="\r") bindings.post_CreateExperiment(sess, body=req) print(termcolor.colored("Experiment configuration validation succeeded! 🎉", "green")) @@ -266,6 +275,17 @@ def local_experiment(args: argparse.Namespace) -> None: "directly?" ) + warnings.warn( + "The --local and --test flags to det experiment create have both been deprecated in " + "Determined 0.38.0 and will be removed in a future version. The searcher.max_length " + "setting of the experiment config has also been deprecated, and since --test mode relies " + "on that setting, --test mode will cease to work as soon as you remove max_length. " + "Additionally, --local mode should no longer be necessary, as you should be able to just " + "invoke your script directly.", + FutureWarning, + stacklevel=2, + ) + common.set_logger(bool(experiment_config.get("debug", False))) with det._local_execution_manager(args.model_def.resolve()): diff --git a/harness/determined/common/api/bindings.py b/harness/determined/common/api/bindings.py index 662ce773fc2..b6384d66a20 100644 --- a/harness/determined/common/api/bindings.py +++ b/harness/determined/common/api/bindings.py @@ -2464,33 +2464,6 @@ def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]: } return out -class v1CloseTrialOperation(Printable): - """Close a trial with given ID.""" - requestId: "typing.Optional[str]" = None - - def __init__( - self, - *, - requestId: "typing.Union[str, None, Unset]" = _unset, - ): - if not isinstance(requestId, Unset): - self.requestId = requestId - - @classmethod - def from_json(cls, obj: Json) -> "v1CloseTrialOperation": - kwargs: "typing.Dict[str, typing.Any]" = { - } - if "requestId" in obj: - kwargs["requestId"] = obj["requestId"] - return cls(**kwargs) - - def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]: - out: "typing.Dict[str, typing.Any]" = { - } - if not omit_unset or "requestId" in vars(self): - out["requestId"] = self.requestId - return out - class v1ClusterMessage(Printable): """Active notice from the server admin.""" createdTime: "typing.Optional[str]" = None @@ -2681,41 +2654,6 @@ def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]: } return out -class v1CompleteValidateAfterOperation(Printable): - """Used to complete a ValidateAfterOperation.""" - op: "typing.Optional[v1ValidateAfterOperation]" = None - searcherMetric: "typing.Optional[typing.Any]" = None - - def __init__( - self, - *, - op: "typing.Union[v1ValidateAfterOperation, None, Unset]" = _unset, - searcherMetric: "typing.Union[typing.Any, None, Unset]" = _unset, - ): - if not isinstance(op, Unset): - self.op = op - if not isinstance(searcherMetric, Unset): - self.searcherMetric = searcherMetric - - @classmethod - def from_json(cls, obj: Json) -> "v1CompleteValidateAfterOperation": - kwargs: "typing.Dict[str, typing.Any]" = { - } - if "op" in obj: - kwargs["op"] = v1ValidateAfterOperation.from_json(obj["op"]) if obj["op"] is not None else None - if "searcherMetric" in obj: - kwargs["searcherMetric"] = obj["searcherMetric"] - return cls(**kwargs) - - def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]: - out: "typing.Dict[str, typing.Any]" = { - } - if not omit_unset or "op" in vars(self): - out["op"] = None if self.op is None else self.op.to_json(omit_unset) - if not omit_unset or "searcherMetric" in vars(self): - out["searcherMetric"] = self.searcherMetric - return out - class v1Config(Printable): """The config to be patched into Master Config.""" log: "typing.Optional[v1LogConfig]" = None @@ -3132,41 +3070,6 @@ def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]: } return out -class v1CreateTrialOperation(Printable): - """Create a trial with given hyperparameters.""" - hyperparams: "typing.Optional[str]" = None - requestId: "typing.Optional[str]" = None - - def __init__( - self, - *, - hyperparams: "typing.Union[str, None, Unset]" = _unset, - requestId: "typing.Union[str, None, Unset]" = _unset, - ): - if not isinstance(hyperparams, Unset): - self.hyperparams = hyperparams - if not isinstance(requestId, Unset): - self.requestId = requestId - - @classmethod - def from_json(cls, obj: Json) -> "v1CreateTrialOperation": - kwargs: "typing.Dict[str, typing.Any]" = { - } - if "hyperparams" in obj: - kwargs["hyperparams"] = obj["hyperparams"] - if "requestId" in obj: - kwargs["requestId"] = obj["requestId"] - return cls(**kwargs) - - def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]: - out: "typing.Dict[str, typing.Any]" = { - } - if not omit_unset or "hyperparams" in vars(self): - out["hyperparams"] = self.hyperparams - if not omit_unset or "requestId" in vars(self): - out["requestId"] = self.requestId - return out - class v1CreateTrialRequest(Printable): """Create a trial.""" experimentId: "typing.Optional[int]" = None @@ -4340,76 +4243,6 @@ def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]: } return out -class v1ExperimentInactive(Printable): - """ExperimentInactive is a searcher event triggered when an experiment - is no longer active. - """ - - def __init__( - self, - *, - experimentState: "experimentv1State", - ): - self.experimentState = experimentState - - @classmethod - def from_json(cls, obj: Json) -> "v1ExperimentInactive": - kwargs: "typing.Dict[str, typing.Any]" = { - "experimentState": experimentv1State(obj["experimentState"]), - } - return cls(**kwargs) - - def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]: - out: "typing.Dict[str, typing.Any]" = { - "experimentState": self.experimentState.value, - } - return out - -class v1ExperimentSimulation(Printable): - """ExperimentSimulation holds the configuration and results of simulated run of - a searcher. - """ - config: "typing.Optional[typing.Dict[str, typing.Any]]" = None - seed: "typing.Optional[int]" = None - trials: "typing.Optional[typing.Sequence[v1TrialSimulation]]" = None - - def __init__( - self, - *, - config: "typing.Union[typing.Dict[str, typing.Any], None, Unset]" = _unset, - seed: "typing.Union[int, None, Unset]" = _unset, - trials: "typing.Union[typing.Sequence[v1TrialSimulation], None, Unset]" = _unset, - ): - if not isinstance(config, Unset): - self.config = config - if not isinstance(seed, Unset): - self.seed = seed - if not isinstance(trials, Unset): - self.trials = trials - - @classmethod - def from_json(cls, obj: Json) -> "v1ExperimentSimulation": - kwargs: "typing.Dict[str, typing.Any]" = { - } - if "config" in obj: - kwargs["config"] = obj["config"] - if "seed" in obj: - kwargs["seed"] = obj["seed"] - if "trials" in obj: - kwargs["trials"] = [v1TrialSimulation.from_json(x) for x in obj["trials"]] if obj["trials"] is not None else None - return cls(**kwargs) - - def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]: - out: "typing.Dict[str, typing.Any]" = { - } - if not omit_unset or "config" in vars(self): - out["config"] = self.config - if not omit_unset or "seed" in vars(self): - out["seed"] = self.seed - if not omit_unset or "trials" in vars(self): - out["trials"] = None if self.trials is None else [x.to_json(omit_unset) for x in self.trials] - return out - class v1FailureType(DetEnum): """The failure type of a resource. - FAILURE_TYPE_UNSPECIFIED: UNSPECIFIED denotes an error that is not defined below. @@ -5162,40 +4995,6 @@ def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]: out["pagination"] = None if self.pagination is None else self.pagination.to_json(omit_unset) return out -class v1GetCurrentTrialSearcherOperationResponse(Printable): - completed: "typing.Optional[bool]" = None - op: "typing.Optional[v1TrialOperation]" = None - - def __init__( - self, - *, - completed: "typing.Union[bool, None, Unset]" = _unset, - op: "typing.Union[v1TrialOperation, None, Unset]" = _unset, - ): - if not isinstance(completed, Unset): - self.completed = completed - if not isinstance(op, Unset): - self.op = op - - @classmethod - def from_json(cls, obj: Json) -> "v1GetCurrentTrialSearcherOperationResponse": - kwargs: "typing.Dict[str, typing.Any]" = { - } - if "completed" in obj: - kwargs["completed"] = obj["completed"] - if "op" in obj: - kwargs["op"] = v1TrialOperation.from_json(obj["op"]) if obj["op"] is not None else None - return cls(**kwargs) - - def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]: - out: "typing.Dict[str, typing.Any]" = { - } - if not omit_unset or "completed" in vars(self): - out["completed"] = self.completed - if not omit_unset or "op" in vars(self): - out["op"] = None if self.op is None else self.op.to_json(omit_unset) - return out - class v1GetExperimentCheckpointsResponse(Printable): """Response to GetExperimentCheckpointsRequest.""" @@ -6620,33 +6419,6 @@ def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]: out["metadata"] = self.metadata return out -class v1GetSearcherEventsResponse(Printable): - """Response to GetSearcherEventsRequest.""" - searcherEvents: "typing.Optional[typing.Sequence[v1SearcherEvent]]" = None - - def __init__( - self, - *, - searcherEvents: "typing.Union[typing.Sequence[v1SearcherEvent], None, Unset]" = _unset, - ): - if not isinstance(searcherEvents, Unset): - self.searcherEvents = searcherEvents - - @classmethod - def from_json(cls, obj: Json) -> "v1GetSearcherEventsResponse": - kwargs: "typing.Dict[str, typing.Any]" = { - } - if "searcherEvents" in obj: - kwargs["searcherEvents"] = [v1SearcherEvent.from_json(x) for x in obj["searcherEvents"]] if obj["searcherEvents"] is not None else None - return cls(**kwargs) - - def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]: - out: "typing.Dict[str, typing.Any]" = { - } - if not omit_unset or "searcherEvents" in vars(self): - out["searcherEvents"] = None if self.searcherEvents is None else [x.to_json(omit_unset) for x in self.searcherEvents] - return out - class v1GetShellResponse(Printable): """Response to GetShellRequest.""" @@ -7769,35 +7541,6 @@ def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]: out["notebookId"] = self.notebookId return out -class v1InitialOperations(Printable): - """InitialOperations is a searcher event signaling the creation of an - experiment. - """ - placeholder: "typing.Optional[int]" = None - - def __init__( - self, - *, - placeholder: "typing.Union[int, None, Unset]" = _unset, - ): - if not isinstance(placeholder, Unset): - self.placeholder = placeholder - - @classmethod - def from_json(cls, obj: Json) -> "v1InitialOperations": - kwargs: "typing.Dict[str, typing.Any]" = { - } - if "placeholder" in obj: - kwargs["placeholder"] = obj["placeholder"] - return cls(**kwargs) - - def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]: - out: "typing.Dict[str, typing.Any]" = { - } - if not omit_unset or "placeholder" in vars(self): - out["placeholder"] = self.placeholder - return out - class v1Int32FieldFilter(Printable): """Int32 filters.""" gt: "typing.Optional[int]" = None @@ -12144,49 +11887,6 @@ def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]: out["metadata"] = self.metadata return out -class v1PostSearcherOperationsRequest(Printable): - """Request for sending operations from a custom search method.""" - experimentId: "typing.Optional[int]" = None - searcherOperations: "typing.Optional[typing.Sequence[v1SearcherOperation]]" = None - triggeredByEvent: "typing.Optional[v1SearcherEvent]" = None - - def __init__( - self, - *, - experimentId: "typing.Union[int, None, Unset]" = _unset, - searcherOperations: "typing.Union[typing.Sequence[v1SearcherOperation], None, Unset]" = _unset, - triggeredByEvent: "typing.Union[v1SearcherEvent, None, Unset]" = _unset, - ): - if not isinstance(experimentId, Unset): - self.experimentId = experimentId - if not isinstance(searcherOperations, Unset): - self.searcherOperations = searcherOperations - if not isinstance(triggeredByEvent, Unset): - self.triggeredByEvent = triggeredByEvent - - @classmethod - def from_json(cls, obj: Json) -> "v1PostSearcherOperationsRequest": - kwargs: "typing.Dict[str, typing.Any]" = { - } - if "experimentId" in obj: - kwargs["experimentId"] = obj["experimentId"] - if "searcherOperations" in obj: - kwargs["searcherOperations"] = [v1SearcherOperation.from_json(x) for x in obj["searcherOperations"]] if obj["searcherOperations"] is not None else None - if "triggeredByEvent" in obj: - kwargs["triggeredByEvent"] = v1SearcherEvent.from_json(obj["triggeredByEvent"]) if obj["triggeredByEvent"] is not None else None - return cls(**kwargs) - - def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]: - out: "typing.Dict[str, typing.Any]" = { - } - if not omit_unset or "experimentId" in vars(self): - out["experimentId"] = self.experimentId - if not omit_unset or "searcherOperations" in vars(self): - out["searcherOperations"] = None if self.searcherOperations is None else [x.to_json(omit_unset) for x in self.searcherOperations] - if not omit_unset or "triggeredByEvent" in vars(self): - out["triggeredByEvent"] = None if self.triggeredByEvent is None else self.triggeredByEvent.to_json(omit_unset) - return out - class v1PostTaskLogsRequest(Printable): """Request to PostTaskLogs.""" @@ -12581,29 +12281,29 @@ def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]: class v1PreviewHPSearchResponse(Printable): """Response to PreviewSearchRequest.""" - simulation: "typing.Optional[v1ExperimentSimulation]" = None + summary: "typing.Optional[v1SearchSummary]" = None def __init__( self, *, - simulation: "typing.Union[v1ExperimentSimulation, None, Unset]" = _unset, + summary: "typing.Union[v1SearchSummary, None, Unset]" = _unset, ): - if not isinstance(simulation, Unset): - self.simulation = simulation + if not isinstance(summary, Unset): + self.summary = summary @classmethod def from_json(cls, obj: Json) -> "v1PreviewHPSearchResponse": kwargs: "typing.Dict[str, typing.Any]" = { } - if "simulation" in obj: - kwargs["simulation"] = v1ExperimentSimulation.from_json(obj["simulation"]) if obj["simulation"] is not None else None + if "summary" in obj: + kwargs["summary"] = v1SearchSummary.from_json(obj["summary"]) if obj["summary"] is not None else None return cls(**kwargs) def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]: out: "typing.Dict[str, typing.Any]" = { } - if not omit_unset or "simulation" in vars(self): - out["simulation"] = None if self.simulation is None else self.simulation.to_json(omit_unset) + if not omit_unset or "summary" in vars(self): + out["summary"] = None if self.summary is None else self.summary.to_json(omit_unset) return out class v1Project(Printable): @@ -14798,54 +14498,6 @@ def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]: out["storageId"] = self.storageId return out -class v1RunnableOperation(Printable): - """RunnableOperation represents a single runnable operation emitted by a - searcher. - """ - length: "typing.Optional[str]" = None - type: "typing.Optional[v1RunnableType]" = None - - def __init__( - self, - *, - length: "typing.Union[str, None, Unset]" = _unset, - type: "typing.Union[v1RunnableType, None, Unset]" = _unset, - ): - if not isinstance(length, Unset): - self.length = length - if not isinstance(type, Unset): - self.type = type - - @classmethod - def from_json(cls, obj: Json) -> "v1RunnableOperation": - kwargs: "typing.Dict[str, typing.Any]" = { - } - if "length" in obj: - kwargs["length"] = obj["length"] - if "type" in obj: - kwargs["type"] = v1RunnableType(obj["type"]) if obj["type"] is not None else None - return cls(**kwargs) - - def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]: - out: "typing.Dict[str, typing.Any]" = { - } - if not omit_unset or "length" in vars(self): - out["length"] = self.length - if not omit_unset or "type" in vars(self): - out["type"] = None if self.type is None else self.type.value - return out - -class v1RunnableType(DetEnum): - """RunnableType defines the type of operation that should be executed by trial - runners. - - RUNNABLE_TYPE_UNSPECIFIED: Denotes an unknown runnable type. - - RUNNABLE_TYPE_TRAIN: Signals to a trial runner that it should run a train. - - RUNNABLE_TYPE_VALIDATE: Signals to a trial runner it should compute validation metrics. - """ - UNSPECIFIED = "RUNNABLE_TYPE_UNSPECIFIED" - TRAIN = "RUNNABLE_TYPE_TRAIN" - VALIDATE = "RUNNABLE_TYPE_VALIDATE" - class v1SSOProvider(Printable): """Describe one SSO provider.""" alwaysRedirect: "typing.Optional[bool]" = None @@ -15239,144 +14891,76 @@ def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]: } return out -class v1SearcherEvent(Printable): - """SearcherEvent is a message from master to a client-driven custom searcher - informing it of relevant changes in the state of an experiment. +class v1SearchSummary(Printable): + """SearchSummary contains the estimated trials and training lengths that a + search plans to execute. """ - experimentInactive: "typing.Optional[v1ExperimentInactive]" = None - initialOperations: "typing.Optional[v1InitialOperations]" = None - trialClosed: "typing.Optional[v1TrialClosed]" = None - trialCreated: "typing.Optional[v1TrialCreated]" = None - trialExitedEarly: "typing.Optional[v1TrialExitedEarly]" = None - trialProgress: "typing.Optional[v1TrialProgress]" = None - validationCompleted: "typing.Optional[v1ValidationCompleted]" = None + trials: "typing.Optional[typing.Sequence[v1TrialSummary]]" = None def __init__( self, *, - id: int, - experimentInactive: "typing.Union[v1ExperimentInactive, None, Unset]" = _unset, - initialOperations: "typing.Union[v1InitialOperations, None, Unset]" = _unset, - trialClosed: "typing.Union[v1TrialClosed, None, Unset]" = _unset, - trialCreated: "typing.Union[v1TrialCreated, None, Unset]" = _unset, - trialExitedEarly: "typing.Union[v1TrialExitedEarly, None, Unset]" = _unset, - trialProgress: "typing.Union[v1TrialProgress, None, Unset]" = _unset, - validationCompleted: "typing.Union[v1ValidationCompleted, None, Unset]" = _unset, + config: "typing.Dict[str, typing.Any]", + trials: "typing.Union[typing.Sequence[v1TrialSummary], None, Unset]" = _unset, ): - self.id = id - if not isinstance(experimentInactive, Unset): - self.experimentInactive = experimentInactive - if not isinstance(initialOperations, Unset): - self.initialOperations = initialOperations - if not isinstance(trialClosed, Unset): - self.trialClosed = trialClosed - if not isinstance(trialCreated, Unset): - self.trialCreated = trialCreated - if not isinstance(trialExitedEarly, Unset): - self.trialExitedEarly = trialExitedEarly - if not isinstance(trialProgress, Unset): - self.trialProgress = trialProgress - if not isinstance(validationCompleted, Unset): - self.validationCompleted = validationCompleted - - @classmethod - def from_json(cls, obj: Json) -> "v1SearcherEvent": + self.config = config + if not isinstance(trials, Unset): + self.trials = trials + + @classmethod + def from_json(cls, obj: Json) -> "v1SearchSummary": kwargs: "typing.Dict[str, typing.Any]" = { - "id": obj["id"], + "config": obj["config"], } - if "experimentInactive" in obj: - kwargs["experimentInactive"] = v1ExperimentInactive.from_json(obj["experimentInactive"]) if obj["experimentInactive"] is not None else None - if "initialOperations" in obj: - kwargs["initialOperations"] = v1InitialOperations.from_json(obj["initialOperations"]) if obj["initialOperations"] is not None else None - if "trialClosed" in obj: - kwargs["trialClosed"] = v1TrialClosed.from_json(obj["trialClosed"]) if obj["trialClosed"] is not None else None - if "trialCreated" in obj: - kwargs["trialCreated"] = v1TrialCreated.from_json(obj["trialCreated"]) if obj["trialCreated"] is not None else None - if "trialExitedEarly" in obj: - kwargs["trialExitedEarly"] = v1TrialExitedEarly.from_json(obj["trialExitedEarly"]) if obj["trialExitedEarly"] is not None else None - if "trialProgress" in obj: - kwargs["trialProgress"] = v1TrialProgress.from_json(obj["trialProgress"]) if obj["trialProgress"] is not None else None - if "validationCompleted" in obj: - kwargs["validationCompleted"] = v1ValidationCompleted.from_json(obj["validationCompleted"]) if obj["validationCompleted"] is not None else None + if "trials" in obj: + kwargs["trials"] = [v1TrialSummary.from_json(x) for x in obj["trials"]] if obj["trials"] is not None else None return cls(**kwargs) def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]: out: "typing.Dict[str, typing.Any]" = { - "id": self.id, + "config": self.config, } - if not omit_unset or "experimentInactive" in vars(self): - out["experimentInactive"] = None if self.experimentInactive is None else self.experimentInactive.to_json(omit_unset) - if not omit_unset or "initialOperations" in vars(self): - out["initialOperations"] = None if self.initialOperations is None else self.initialOperations.to_json(omit_unset) - if not omit_unset or "trialClosed" in vars(self): - out["trialClosed"] = None if self.trialClosed is None else self.trialClosed.to_json(omit_unset) - if not omit_unset or "trialCreated" in vars(self): - out["trialCreated"] = None if self.trialCreated is None else self.trialCreated.to_json(omit_unset) - if not omit_unset or "trialExitedEarly" in vars(self): - out["trialExitedEarly"] = None if self.trialExitedEarly is None else self.trialExitedEarly.to_json(omit_unset) - if not omit_unset or "trialProgress" in vars(self): - out["trialProgress"] = None if self.trialProgress is None else self.trialProgress.to_json(omit_unset) - if not omit_unset or "validationCompleted" in vars(self): - out["validationCompleted"] = None if self.validationCompleted is None else self.validationCompleted.to_json(omit_unset) + if not omit_unset or "trials" in vars(self): + out["trials"] = None if self.trials is None else [x.to_json(omit_unset) for x in self.trials] return out -class v1SearcherOperation(Printable): - """SearcherOperation is an operation issued by the custom searcher.""" - closeTrial: "typing.Optional[v1CloseTrialOperation]" = None - createTrial: "typing.Optional[v1CreateTrialOperation]" = None - setSearcherProgress: "typing.Optional[v1SetSearcherProgressOperation]" = None - shutDown: "typing.Optional[v1ShutDownOperation]" = None - trialOperation: "typing.Optional[v1TrialOperation]" = None +class v1SearchUnit(Printable): + """SearchUnit describes a length unit used by some searchers to manage training.""" + name: "typing.Optional[str]" = None + value: "typing.Optional[int]" = None def __init__( self, *, - closeTrial: "typing.Union[v1CloseTrialOperation, None, Unset]" = _unset, - createTrial: "typing.Union[v1CreateTrialOperation, None, Unset]" = _unset, - setSearcherProgress: "typing.Union[v1SetSearcherProgressOperation, None, Unset]" = _unset, - shutDown: "typing.Union[v1ShutDownOperation, None, Unset]" = _unset, - trialOperation: "typing.Union[v1TrialOperation, None, Unset]" = _unset, + maxLength: bool, + name: "typing.Union[str, None, Unset]" = _unset, + value: "typing.Union[int, None, Unset]" = _unset, ): - if not isinstance(closeTrial, Unset): - self.closeTrial = closeTrial - if not isinstance(createTrial, Unset): - self.createTrial = createTrial - if not isinstance(setSearcherProgress, Unset): - self.setSearcherProgress = setSearcherProgress - if not isinstance(shutDown, Unset): - self.shutDown = shutDown - if not isinstance(trialOperation, Unset): - self.trialOperation = trialOperation + self.maxLength = maxLength + if not isinstance(name, Unset): + self.name = name + if not isinstance(value, Unset): + self.value = value @classmethod - def from_json(cls, obj: Json) -> "v1SearcherOperation": + def from_json(cls, obj: Json) -> "v1SearchUnit": kwargs: "typing.Dict[str, typing.Any]" = { + "maxLength": obj["maxLength"], } - if "closeTrial" in obj: - kwargs["closeTrial"] = v1CloseTrialOperation.from_json(obj["closeTrial"]) if obj["closeTrial"] is not None else None - if "createTrial" in obj: - kwargs["createTrial"] = v1CreateTrialOperation.from_json(obj["createTrial"]) if obj["createTrial"] is not None else None - if "setSearcherProgress" in obj: - kwargs["setSearcherProgress"] = v1SetSearcherProgressOperation.from_json(obj["setSearcherProgress"]) if obj["setSearcherProgress"] is not None else None - if "shutDown" in obj: - kwargs["shutDown"] = v1ShutDownOperation.from_json(obj["shutDown"]) if obj["shutDown"] is not None else None - if "trialOperation" in obj: - kwargs["trialOperation"] = v1TrialOperation.from_json(obj["trialOperation"]) if obj["trialOperation"] is not None else None + if "name" in obj: + kwargs["name"] = obj["name"] + if "value" in obj: + kwargs["value"] = obj["value"] return cls(**kwargs) def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]: out: "typing.Dict[str, typing.Any]" = { + "maxLength": self.maxLength, } - if not omit_unset or "closeTrial" in vars(self): - out["closeTrial"] = None if self.closeTrial is None else self.closeTrial.to_json(omit_unset) - if not omit_unset or "createTrial" in vars(self): - out["createTrial"] = None if self.createTrial is None else self.createTrial.to_json(omit_unset) - if not omit_unset or "setSearcherProgress" in vars(self): - out["setSearcherProgress"] = None if self.setSearcherProgress is None else self.setSearcherProgress.to_json(omit_unset) - if not omit_unset or "shutDown" in vars(self): - out["shutDown"] = None if self.shutDown is None else self.shutDown.to_json(omit_unset) - if not omit_unset or "trialOperation" in vars(self): - out["trialOperation"] = None if self.trialOperation is None else self.trialOperation.to_json(omit_unset) + if not omit_unset or "name" in vars(self): + out["name"] = self.name + if not omit_unset or "value" in vars(self): + out["value"] = self.value return out class v1SetClusterMessageRequest(Printable): @@ -15577,50 +15161,21 @@ def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]: out["clusterQuotaPairs"] = self.clusterQuotaPairs return out -class v1SetSearcherProgressOperation(Printable): - """SetSearcherProgressOperation informs the master of the progress of the custom - searcher. - """ - progress: "typing.Optional[float]" = None +class v1SetShellPriorityRequest(Printable): + """Set the priority of the requested shell.""" + priority: "typing.Optional[int]" = None + shellId: "typing.Optional[str]" = None def __init__( self, *, - progress: "typing.Union[float, None, Unset]" = _unset, + priority: "typing.Union[int, None, Unset]" = _unset, + shellId: "typing.Union[str, None, Unset]" = _unset, ): - if not isinstance(progress, Unset): - self.progress = progress - - @classmethod - def from_json(cls, obj: Json) -> "v1SetSearcherProgressOperation": - kwargs: "typing.Dict[str, typing.Any]" = { - } - if "progress" in obj: - kwargs["progress"] = float(obj["progress"]) if obj["progress"] is not None else None - return cls(**kwargs) - - def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]: - out: "typing.Dict[str, typing.Any]" = { - } - if not omit_unset or "progress" in vars(self): - out["progress"] = None if self.progress is None else dump_float(self.progress) - return out - -class v1SetShellPriorityRequest(Printable): - """Set the priority of the requested shell.""" - priority: "typing.Optional[int]" = None - shellId: "typing.Optional[str]" = None - - def __init__( - self, - *, - priority: "typing.Union[int, None, Unset]" = _unset, - shellId: "typing.Union[str, None, Unset]" = _unset, - ): - if not isinstance(priority, Unset): - self.priority = priority - if not isinstance(shellId, Unset): - self.shellId = shellId + if not isinstance(priority, Unset): + self.priority = priority + if not isinstance(shellId, Unset): + self.shellId = shellId @classmethod def from_json(cls, obj: Json) -> "v1SetShellPriorityRequest": @@ -15926,41 +15481,6 @@ def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]: out["userId"] = self.userId return out -class v1ShutDownOperation(Printable): - """Shut down custom searcher method.""" - cancel: "typing.Optional[bool]" = None - failure: "typing.Optional[bool]" = None - - def __init__( - self, - *, - cancel: "typing.Union[bool, None, Unset]" = _unset, - failure: "typing.Union[bool, None, Unset]" = _unset, - ): - if not isinstance(cancel, Unset): - self.cancel = cancel - if not isinstance(failure, Unset): - self.failure = failure - - @classmethod - def from_json(cls, obj: Json) -> "v1ShutDownOperation": - kwargs: "typing.Dict[str, typing.Any]" = { - } - if "cancel" in obj: - kwargs["cancel"] = obj["cancel"] - if "failure" in obj: - kwargs["failure"] = obj["failure"] - return cls(**kwargs) - - def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]: - out: "typing.Dict[str, typing.Any]" = { - } - if not omit_unset or "cancel" in vars(self): - out["cancel"] = self.cancel - if not omit_unset or "failure" in vars(self): - out["failure"] = self.failure - return out - class v1Slot(Printable): """Slot wraps a single device on the agent.""" container: "typing.Optional[v1Container]" = None @@ -16773,54 +16293,6 @@ class v1TokenType(DetEnum): USER_SESSION = "TOKEN_TYPE_USER_SESSION" ACCESS_TOKEN = "TOKEN_TYPE_ACCESS_TOKEN" -class v1TrialClosed(Printable): - """TrialClosed is a searcher event triggered when a trial has successfully - finished. - """ - - def __init__( - self, - *, - requestId: str, - ): - self.requestId = requestId - - @classmethod - def from_json(cls, obj: Json) -> "v1TrialClosed": - kwargs: "typing.Dict[str, typing.Any]" = { - "requestId": obj["requestId"], - } - return cls(**kwargs) - - def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]: - out: "typing.Dict[str, typing.Any]" = { - "requestId": self.requestId, - } - return out - -class v1TrialCreated(Printable): - """TrialCreated is a searcher event signaling the creation of a trial.""" - - def __init__( - self, - *, - requestId: str, - ): - self.requestId = requestId - - @classmethod - def from_json(cls, obj: Json) -> "v1TrialCreated": - kwargs: "typing.Dict[str, typing.Any]" = { - "requestId": obj["requestId"], - } - return cls(**kwargs) - - def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]: - out: "typing.Dict[str, typing.Any]" = { - "requestId": self.requestId, - } - return out - class v1TrialEarlyExit(Printable): """Signals to the experiment the trial early exited.""" @@ -16855,48 +16327,6 @@ class v1TrialEarlyExitExitedReason(DetEnum): INVALID_HP = "EXITED_REASON_INVALID_HP" INIT_INVALID_HP = "EXITED_REASON_INIT_INVALID_HP" -class v1TrialExitedEarly(Printable): - """TrialExitedEarly is a searcher event triggered when a trial exited - prematurely. - """ - - def __init__( - self, - *, - exitedReason: "v1TrialExitedEarlyExitedReason", - requestId: str, - ): - self.exitedReason = exitedReason - self.requestId = requestId - - @classmethod - def from_json(cls, obj: Json) -> "v1TrialExitedEarly": - kwargs: "typing.Dict[str, typing.Any]" = { - "exitedReason": v1TrialExitedEarlyExitedReason(obj["exitedReason"]), - "requestId": obj["requestId"], - } - return cls(**kwargs) - - def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]: - out: "typing.Dict[str, typing.Any]" = { - "exitedReason": self.exitedReason.value, - "requestId": self.requestId, - } - return out - -class v1TrialExitedEarlyExitedReason(DetEnum): - """The reason for an early exit. - - EXITED_REASON_UNSPECIFIED: Zero-value (not allowed). - - EXITED_REASON_INVALID_HP: Indicates the trial exited due to an invalid hyperparameter. - - EXITED_REASON_USER_REQUESTED_STOP: Indicates the trial exited due to a user requested stop, from code. - - EXITED_REASON_USER_CANCELED: Indicates the trial exited due to a user requested stop, from the CLI or - UI. - """ - UNSPECIFIED = "EXITED_REASON_UNSPECIFIED" - INVALID_HP = "EXITED_REASON_INVALID_HP" - USER_REQUESTED_STOP = "EXITED_REASON_USER_REQUESTED_STOP" - USER_CANCELED = "EXITED_REASON_USER_CANCELED" - class v1TrialLogsFieldsResponse(Printable): """Response to TrialLogFieldsRequest.""" agentIds: "typing.Optional[typing.Sequence[str]]" = None @@ -17090,33 +16520,6 @@ def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]: out["stepsCompleted"] = self.stepsCompleted return out -class v1TrialOperation(Printable): - """TrialOperation is any operation that a trial can perform while it is active.""" - validateAfter: "typing.Optional[v1ValidateAfterOperation]" = None - - def __init__( - self, - *, - validateAfter: "typing.Union[v1ValidateAfterOperation, None, Unset]" = _unset, - ): - if not isinstance(validateAfter, Unset): - self.validateAfter = validateAfter - - @classmethod - def from_json(cls, obj: Json) -> "v1TrialOperation": - kwargs: "typing.Dict[str, typing.Any]" = { - } - if "validateAfter" in obj: - kwargs["validateAfter"] = v1ValidateAfterOperation.from_json(obj["validateAfter"]) if obj["validateAfter"] is not None else None - return cls(**kwargs) - - def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]: - out: "typing.Dict[str, typing.Any]" = { - } - if not omit_unset or "validateAfter" in vars(self): - out["validateAfter"] = None if self.validateAfter is None else self.validateAfter.to_json(omit_unset) - return out - class v1TrialProfilerMetricLabels(Printable): agentId: "typing.Optional[str]" = None gpuUuid: "typing.Optional[str]" = None @@ -17206,35 +16609,6 @@ def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]: } return out -class v1TrialProgress(Printable): - """TrialProgress is a searcher event that tells you the number of batches - completed in the trial. - """ - - def __init__( - self, - *, - partialUnits: float, - requestId: str, - ): - self.partialUnits = partialUnits - self.requestId = requestId - - @classmethod - def from_json(cls, obj: Json) -> "v1TrialProgress": - kwargs: "typing.Dict[str, typing.Any]" = { - "partialUnits": float(obj["partialUnits"]), - "requestId": obj["requestId"], - } - return cls(**kwargs) - - def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]: - out: "typing.Dict[str, typing.Any]" = { - "partialUnits": dump_float(self.partialUnits), - "requestId": self.requestId, - } - return out - class v1TrialRunnerMetadata(Printable): """The metadata pertaining to the current running task for a trial.""" @@ -17258,43 +16632,6 @@ def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]: } return out -class v1TrialSimulation(Printable): - """TrialSimulation is a specific sequence of workloads that were run before the - trial was completed. - """ - occurrences: "typing.Optional[int]" = None - operations: "typing.Optional[typing.Sequence[v1RunnableOperation]]" = None - - def __init__( - self, - *, - occurrences: "typing.Union[int, None, Unset]" = _unset, - operations: "typing.Union[typing.Sequence[v1RunnableOperation], None, Unset]" = _unset, - ): - if not isinstance(occurrences, Unset): - self.occurrences = occurrences - if not isinstance(operations, Unset): - self.operations = operations - - @classmethod - def from_json(cls, obj: Json) -> "v1TrialSimulation": - kwargs: "typing.Dict[str, typing.Any]" = { - } - if "occurrences" in obj: - kwargs["occurrences"] = obj["occurrences"] - if "operations" in obj: - kwargs["operations"] = [v1RunnableOperation.from_json(x) for x in obj["operations"]] if obj["operations"] is not None else None - return cls(**kwargs) - - def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]: - out: "typing.Dict[str, typing.Any]" = { - } - if not omit_unset or "occurrences" in vars(self): - out["occurrences"] = self.occurrences - if not omit_unset or "operations" in vars(self): - out["operations"] = None if self.operations is None else [x.to_json(omit_unset) for x in self.operations] - return out - class v1TrialSourceInfo(Printable): modelId: "typing.Optional[int]" = None modelVersion: "typing.Optional[int]" = None @@ -17352,6 +16689,35 @@ class v1TrialSourceInfoType(DetEnum): INFERENCE = "TRIAL_SOURCE_INFO_TYPE_INFERENCE" FINE_TUNING = "TRIAL_SOURCE_INFO_TYPE_FINE_TUNING" +class v1TrialSummary(Printable): + """TrialSummary describes the runs that are estimated to train for a certain + length. + """ + + def __init__( + self, + *, + count: int, + unit: "v1SearchUnit", + ): + self.count = count + self.unit = unit + + @classmethod + def from_json(cls, obj: Json) -> "v1TrialSummary": + kwargs: "typing.Dict[str, typing.Any]" = { + "count": obj["count"], + "unit": v1SearchUnit.from_json(obj["unit"]), + } + return cls(**kwargs) + + def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]: + out: "typing.Dict[str, typing.Any]" = { + "count": self.count, + "unit": self.unit.to_json(omit_unset), + } + return out + class v1TrialsSampleResponse(Printable): def __init__( @@ -18059,76 +17425,6 @@ def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]: out["value"] = self.value return out -class v1ValidateAfterOperation(Printable): - """ValidateAfterOperation means the trial should train and validate after - training the given length. - """ - length: "typing.Optional[str]" = None - requestId: "typing.Optional[str]" = None - - def __init__( - self, - *, - length: "typing.Union[str, None, Unset]" = _unset, - requestId: "typing.Union[str, None, Unset]" = _unset, - ): - if not isinstance(length, Unset): - self.length = length - if not isinstance(requestId, Unset): - self.requestId = requestId - - @classmethod - def from_json(cls, obj: Json) -> "v1ValidateAfterOperation": - kwargs: "typing.Dict[str, typing.Any]" = { - } - if "length" in obj: - kwargs["length"] = obj["length"] - if "requestId" in obj: - kwargs["requestId"] = obj["requestId"] - return cls(**kwargs) - - def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]: - out: "typing.Dict[str, typing.Any]" = { - } - if not omit_unset or "length" in vars(self): - out["length"] = self.length - if not omit_unset or "requestId" in vars(self): - out["requestId"] = self.requestId - return out - -class v1ValidationCompleted(Printable): - """ValidationCompleted is a searcher event triggered when a validation has been - completed. - """ - - def __init__( - self, - *, - metric: typing.Any, - requestId: str, - validateAfterLength: str, - ): - self.metric = metric - self.requestId = requestId - self.validateAfterLength = validateAfterLength - - @classmethod - def from_json(cls, obj: Json) -> "v1ValidationCompleted": - kwargs: "typing.Dict[str, typing.Any]" = { - "metric": obj["metric"], - "requestId": obj["requestId"], - "validateAfterLength": obj["validateAfterLength"], - } - return cls(**kwargs) - - def to_json(self, omit_unset: bool = False) -> typing.Dict[str, typing.Any]: - out: "typing.Dict[str, typing.Any]" = { - "metric": self.metric, - "requestId": self.requestId, - "validateAfterLength": self.validateAfterLength, - } - return out - class v1ValidationHistoryEntry(Printable): """ValidationHistoryEntry is a single entry for a validation history for an experiment. @@ -19268,33 +18564,6 @@ def get_CompareTrials( return v1CompareTrialsResponse.from_json(_resp.json()) raise APIHttpError("get_CompareTrials", _resp) -def post_CompleteTrialSearcherValidation( - session: "api.BaseSession", - *, - body: "v1CompleteValidateAfterOperation", - trialId: int, -) -> None: - """Reports to the searcher that the trial has completed the given searcher - operation. - - - body: The completed operation. - - trialId: The id of the trial. - """ - _params = None - _resp = session._do_request( - method="POST", - path=f"/api/v1/trials/{trialId}/searcher/completed_operation", - params=_params, - json=body.to_json(True), - data=None, - headers=None, - timeout=None, - stream=False, - ) - if _resp.status_code == 200: - return - raise APIHttpError("post_CompleteTrialSearcherValidation", _resp) - def post_ContinueExperiment( session: "api.BaseSession", *, @@ -20347,30 +19616,6 @@ def get_GetCommands( return v1GetCommandsResponse.from_json(_resp.json()) raise APIHttpError("get_GetCommands", _resp) -def get_GetCurrentTrialSearcherOperation( - session: "api.BaseSession", - *, - trialId: int, -) -> "v1GetCurrentTrialSearcherOperationResponse": - """Get the current searcher operation. - - - trialId: The id of the trial. - """ - _params = None - _resp = session._do_request( - method="GET", - path=f"/api/v1/trials/{trialId}/searcher/operation", - params=_params, - json=None, - data=None, - headers=None, - timeout=None, - stream=False, - ) - if _resp.status_code == 200: - return v1GetCurrentTrialSearcherOperationResponse.from_json(_resp.json()) - raise APIHttpError("get_GetCurrentTrialSearcherOperation", _resp) - def get_GetExperiment( session: "api.BaseSession", *, @@ -21780,30 +21025,6 @@ def get_GetRunMetadata( return v1GetRunMetadataResponse.from_json(_resp.json()) raise APIHttpError("get_GetRunMetadata", _resp) -def get_GetSearcherEvents( - session: "api.BaseSession", - *, - experimentId: int, -) -> "v1GetSearcherEventsResponse": - """Get the list of custom searcher events with long polling. - - - experimentId: The ID of the experiment. - """ - _params = None - _resp = session._do_request( - method="GET", - path=f"/api/v1/experiments/{experimentId}/searcher_events", - params=_params, - json=None, - data=None, - headers=None, - timeout=None, - stream=False, - ) - if _resp.status_code == 200: - return v1GetSearcherEventsResponse.from_json(_resp.json()) - raise APIHttpError("get_GetSearcherEvents", _resp) - def get_GetShell( session: "api.BaseSession", *, @@ -24582,31 +23803,6 @@ def post_PostRunMetadata( return v1PostRunMetadataResponse.from_json(_resp.json()) raise APIHttpError("post_PostRunMetadata", _resp) -def post_PostSearcherOperations( - session: "api.BaseSession", - *, - body: "v1PostSearcherOperationsRequest", - experimentId: int, -) -> None: - """Submit operations to a custom searcher. - - - experimentId: The experiment ID. - """ - _params = None - _resp = session._do_request( - method="POST", - path=f"/api/v1/experiments/{experimentId}/searcher_operations", - params=_params, - json=body.to_json(True), - data=None, - headers=None, - timeout=None, - stream=False, - ) - if _resp.status_code == 200: - return - raise APIHttpError("post_PostSearcherOperations", _resp) - def post_PostTaskLogs( session: "api.BaseSession", *, diff --git a/harness/determined/common/api/errors.py b/harness/determined/common/api/errors.py index 0d465dd9b37..99d49a88324 100644 --- a/harness/determined/common/api/errors.py +++ b/harness/determined/common/api/errors.py @@ -54,9 +54,10 @@ def __init__(self, response: requests.Response, *args: Any) -> None: try: self.response_error = response.json()["error"] m = self.response_error["error"] - except (ValueError, KeyError): + except (AttributeError, ValueError, KeyError): self.response_error = None - m = response.text + # Requests that don't go through the GRPC gateway will return error plain messages. + m = response.text if hasattr(response, "text") else response super().__init__(m, response, *args) self.status_code = response.status_code diff --git a/harness/determined/constants.py b/harness/determined/constants.py index 58e630474c4..62fe376aa14 100644 --- a/harness/determined/constants.py +++ b/harness/determined/constants.py @@ -1,25 +1,6 @@ import os MAX_SLOTS_PER_AGENT = 16 -# The default configs to use in when running test experiments. -# -# TODO: Unify the defaults used here with the defaults used in master. -DEFAULT_SEARCHER_CFG = {"name": "single", "max_length": {"batches": 100}} -DEFAULT_RESOURCES_CFG = {"slots_per_trial": 1, "native_parallel": False} -DEFAULT_SCHEDULING_UNIT = 100 -DEFAULT_OPTIMIZATIONS = { - "aggregation_frequency": 1, - "average_aggregated_gradients": True, - "average_training_metrics": True, - "gradient_compression": False, - "mixed_precision": "O0", -} -DEFAULT_EXP_CFG = { - "searcher": DEFAULT_SEARCHER_CFG, - "scheduling_unit": DEFAULT_SCHEDULING_UNIT, - "resources": DEFAULT_RESOURCES_CFG, - "optimizations": DEFAULT_OPTIMIZATIONS, -} # Until we implement a more automatic solution, expose a temporary workaround of # allowing ports to be changed using envionment variables for the rare case that diff --git a/harness/determined/core/__init__.py b/harness/determined/core/__init__.py index 90aface2937..3b4037e32d6 100644 --- a/harness/determined/core/__init__.py +++ b/harness/determined/core/__init__.py @@ -25,8 +25,10 @@ DummySearcherOperation, SearcherMode, SearcherContext, + SearcherContextMissing, SearcherOperation, Unit, + _parse_searcher_max_length, _parse_searcher_units, ) from determined.core._preempt import ( diff --git a/harness/determined/core/_context.py b/harness/determined/core/_context.py index 2398ba3c39b..8ea69270d15 100644 --- a/harness/determined/core/_context.py +++ b/harness/determined/core/_context.py @@ -58,7 +58,7 @@ def __init__( self.preempt = preempt or core.DummyPreemptContext(self.distributed) self.train = train or core.DummyTrainContext() self._metrics = _metrics or core._DummyMetricsContext() - self.searcher = searcher or core.DummySearcherContext(self.distributed) + self.searcher = searcher or core.SearcherContextMissing() self.info = info self.experimental = experimental or core.DummyExperimentalCoreContext() self.profiler = profiler or core.DummyProfilerContext() @@ -327,15 +327,19 @@ def init( tbd_writer, ) - units = core._parse_searcher_units(info.trial._config) - searcher = core.SearcherContext( - session, - distributed, - info.trial.trial_id, - info.trial._trial_run_id, - info.allocation_id, - units, - ) + # only provide a .searcher if max_length appears in the experiment config + max_length = core._parse_searcher_max_length(info.trial._config) + if not max_length: + searcher = None + else: + units = core._parse_searcher_units(info.trial._config) + searcher = core.SearcherContext( + session, + distributed, + info.trial.trial_id, + max_length, + units, + ) if storage_manager is None: storage_manager = storage.build( diff --git a/harness/determined/core/_distributed.py b/harness/determined/core/_distributed.py index de5ee92920b..d63a5d93114 100644 --- a/harness/determined/core/_distributed.py +++ b/harness/determined/core/_distributed.py @@ -1,13 +1,17 @@ +import json import logging import os import socket import tempfile -from typing import Any, Callable, List, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, TypeVar from determined import constants, ipc, util logger = logging.getLogger("determined.core") +if TYPE_CHECKING: + import tensorflow + class DistributedContext: """ @@ -232,6 +236,63 @@ def from_torch_distributed(cls, chief_ip: Optional[str] = None) -> "DistributedC chief_ip=chief_ip or os.environ.get("DET_CHIEF_IP"), ) + @classmethod + def from_tf_config(cls) -> Tuple["DistributedContext", "tensorflow.distribute.Strategy"]: + """ + Create a DistributedContext and a tf.distribute.Strategy based on the TF_CONFIG environment + variable. + + Note that the ``determined.launch.tensorflow`` launcher will automatically create a + TF_CONFIG environment variable for on-cluster training, so you may not need to configure it + yourself. + + Presently, the only supported configurations are: + - MultiMirroredWorkerStrategy, when there are multiple nodes participating in training. + - Mirrored strategy, when there is one node but multiple GPUs for training. + - The default strategy otherwise. + """ + import tensorflow as tf + + if "TF_CONFIG" in os.environ: + tf_config = json.loads(os.environ["TF_CONFIG"]) + # We only support worker tasks + task_type = tf_config["task"]["type"] + if task_type != "worker": + raise RuntimeError( + "DistributedContext.from_tf_config() only supports the default strategy, " + "MirroredStrategy or MultiWorkerMirroredStrategy, but found unexpected " + f'task_type="{task_type}" in TF_CONFIG ({os.environ["TF_CONFIG"]})' + ) + assert tf_config["task"]["type"] == "worker", tf_config["task"]["type"] + num_workers = len(tf_config["cluster"]["worker"]) + if num_workers > 1: + # Multiple workers means MultiWorkerMirroredStrategy. + index = tf_config["task"]["index"] + chief_ip = tf_config["cluster"]["worker"][0].split(":")[0] + dist = cls( + rank=index, + size=num_workers, + local_rank=0, + local_size=1, + cross_rank=index, + cross_size=num_workers, + chief_ip=chief_ip, + ) + return dist, tf.distribute.MultiWorkerMirroredStrategy() + # Either no TF_CONFIG or there's only one worker. + dist = cls( + rank=0, + size=1, + local_rank=0, + local_size=1, + cross_rank=0, + cross_size=1, + ) + # Use a MirroredStrategy if we have multiple GPUs. + ngpus = len(tf.config.list_physical_devices("GPU")) + strategy = tf.distribute.MirroredStrategy() if ngpus > 1 else tf.distribute.get_strategy() + return dist, strategy + def close(self) -> None: # if statements in close() mirror the if statements of _init_ipc(). if self._closed or self.size < 2: diff --git a/harness/determined/core/_searcher.py b/harness/determined/core/_searcher.py index 7b6faaf1b58..5b8efe12d11 100644 --- a/harness/determined/core/_searcher.py +++ b/harness/determined/core/_searcher.py @@ -1,5 +1,6 @@ import enum import logging +import warnings from typing import Any, Iterator, Optional import determined as det @@ -15,6 +16,26 @@ class Unit(enum.Enum): BATCHES = "BATCHES" +def _parse_searcher_max_length(experiment_config: dict) -> Optional[int]: + searcher = experiment_config.get("searcher", {}) + + max_length = searcher.get("max_length") + if max_length is None: + return None + + if isinstance(max_length, int): + return max_length + + # assume something like {"epochs": 10} + assert isinstance(max_length, dict), max_length + values = max_length.values() + if not values: + return None + out = next(iter(values)) + assert isinstance(out, int), max_length + return out + + def _parse_searcher_units(experiment_config: dict) -> Optional[Unit]: searcher = experiment_config.get("searcher", {}) @@ -34,6 +55,9 @@ def convert_key(key: Any) -> Optional[Unit]: class SearcherOperation: """ + .. warning:: + SearcherOperation is deprecated in 0.38.0, and will be removed in a future version. + A ``SearcherOperation`` is a request from the hyperparameter-search logic for the training script to execute one train-validate-report cycle. @@ -64,6 +88,11 @@ def __init__( @property def length(self) -> int: """ + .. warning:: + SearcherOperation.length is deprecated in 0.38.0, and will be removed in a future + version. Instead, you should directly specify your training length in your training + code. + ``length`` represents the total amount of training which should be reached by the train step before the validate-report steps. """ @@ -71,6 +100,11 @@ def length(self) -> int: def report_progress(self, length: float) -> None: """ + .. warning:: + SearcherOperation.report_progress is deprecated in 0.38.0, and will be removed in a + future version. Instead, report progess with + :meth:`~determined.core.TrainContext.report_progress`. + ``report_progress()`` reports the training progress to the Determined master so the WebUI can show accurate progress to users. @@ -84,13 +118,19 @@ def report_progress(self, length: float) -> None: if self._completed and length != self._length: raise RuntimeError("you must not call op.report_progress() after op.report_completed()") logger.debug(f"op.report_progress({length})") + # get the floating point progress + progress = min(1.0, max(0.0, length / self._length)) self._session.post( f"/api/v1/trials/{self._trial_id}/progress", - data=det.util.json_encode({"progress": length}), + data=det.util.json_encode({"progress": progress, "is_raw": True}), ) def report_completed(self, searcher_metric: Any) -> None: """ + .. warning:: + SearcherOperation.report_completed is deprecated in 0.38.0, and will be removed in a + future version. Instead, just exit 0 when your training is complete. + ``report_completed()`` is the final step of a train-validate-report cycle. ``report_completed()`` requires the value of the metric you are searching over. This value @@ -103,16 +143,13 @@ def report_completed(self, searcher_metric: Any) -> None: if self._completed: raise RuntimeError("you may only call op.report_completed() once") self._completed = True - body = {"op": {"length": self._length}, "searcherMetric": searcher_metric} - logger.debug(f"op.report_completed({searcher_metric})") - self._session.post( - f"/api/v1/trials/{self._trial_id}/searcher/completed_operation", - data=det.util.json_encode(body), - ) class SearcherMode(enum.Enum): """ + .. warning:: + SearcherMode is deprecated in 0.38.0, and will be removed in a future version. + ``SearcherMode`` defines the calling behavior of the ``SearcherContext.operations()`` call. When mode is ``WorkersAskChief`` (the default), all workers must call @@ -130,6 +167,11 @@ class SearcherMode(enum.Enum): class SearcherContext: """ + .. warning:: + SearcherContext is deprecated in 0.38.0, and will be removed in a future version. Instead + of using ``SearcherContext.operations()`` to decide how long to train for, you should set + your training length directly in your training code. + ``SearcherContext`` gives direct access to operations emitted by the search algorithm in the master. Each ``SearcherOperation`` emitted has a (unitless) length that you should train for, then you complete the op by reporting the validation metric you are searching over. @@ -183,94 +225,95 @@ def __init__( session: api.Session, dist: core.DistributedContext, trial_id: int, - run_id: int, - allocation_id: str, + max_length: int, units: Optional[Unit] = None, ) -> None: self._session = session self._dist = dist self._trial_id = trial_id - self._run_id = run_id - self._allocation_id = allocation_id + self._length = max_length self._units = units - def _get_searcher_op(self) -> Optional[SearcherOperation]: - logger.debug("_get_searcher_op()") - r = self._session.get(f"/api/v1/trials/{self._trial_id}/searcher/operation") - body = r.json() - if body["completed"]: - return None - - # grpc-gateway encodes uint64 as a string, since it is bigger than a JavaScript `number`. - length = int(body["op"]["validateAfter"]["length"]) - is_chief = self._dist.rank == 0 - return SearcherOperation(self._session, self._trial_id, length=length, is_chief=is_chief) - def operations( self, searcher_mode: SearcherMode = SearcherMode.WorkersAskChief, auto_ack: bool = True, ) -> Iterator[SearcherOperation]: """ - Iterate through all the operations this searcher has to offer. + .. warning:: + SearcherContext.operations is deprecated in 0.38.0, and will be removed in a future + version. Instead, you should set your training length directly in your training code. - See :class:`~determined.core.SearcherMode` for details about calling requirements in - distributed training scenarios. + This method no longer talks to the Determined master; it just yields a single + ``SearcherOperation`` objects based on the ``searcher.max_length`` in the experiment config + (which is also deprecated). + """ + + warnings.warn( + "SearcherContext.operations() was deprecated in Determined 0.38.0 and will be removed " + "in a future version. Instead, you should set your training length directly in your " + "training code.", + FutureWarning, + stacklevel=2, + ) + + yield from self._operations(searcher_mode) - After training to the point specified by each ``SearcherOperation``, the chief, and only the - chief, must call ``op.report_completed(``) on each operation. This is true regardless of - the ``searcher_mode`` setting because the Determined master needs a clear, unambiguous - report of when an operation is completed. + def _operations( + self, + searcher_mode: SearcherMode = SearcherMode.WorkersAskChief, + ) -> Iterator[SearcherOperation]: """ - searcher_mode = SearcherMode(searcher_mode) + The internal-only version of .operations which doesn't show a warning. + This is meant to be called by other, deprecated things which internally depend on + .operations() and have their own deprecation warning. That way the user gets the + deprecation warning for what they actually used. + """ + + searcher_mode = SearcherMode(searcher_mode) + # Force the same synchronization behavior we used to have before fabricating operations. if self._dist.rank == 0: - # Chief gets operations from master. - while True: - op = self._get_searcher_op() - if searcher_mode == SearcherMode.WorkersAskChief: - # Broadcast op.length (or None) to workers. We broadcast just the length - # because SearcherOperation is not serializable, and the is_chief attribute - # obviously must be set on a per-worker basis. - _ = self._dist.broadcast(op and op.length) - if op is None: - if auto_ack: - self.acknowledge_out_of_ops() - break - yield op - if not op._completed: - raise RuntimeError("you must call op.report_completed() on each operation") + # Chief fabricates an op. + op = SearcherOperation(self._session, self._trial_id, self._length, True) + if searcher_mode == SearcherMode.WorkersAskChief: + # Broadcast op to workers. + _ = self._dist.broadcast(op and op.length) + yield op + if not op._completed: + raise RuntimeError("you must call op.report_completed() on each operation") + if searcher_mode == SearcherMode.WorkersAskChief: + _ = self._dist.broadcast(None) else: if searcher_mode != SearcherMode.WorkersAskChief: raise RuntimeError( - "you cannot call searcher.operations(searcher_mode=ChiefOnly) from a non-chief " - "worker." + "you cannot call searcher.operations(searcher_mode=ChiefOnly) " + "from a non-chief worker." ) # Worker gets operations from chief. while True: op_length = self._dist.broadcast(None) if op_length is None: break - yield SearcherOperation( - self._session, self._trial_id, length=op_length, is_chief=False - ) + yield SearcherOperation(self._session, self._trial_id, op_length, False) def acknowledge_out_of_ops(self) -> None: """ - acknowledge_out_of_ops() tells the Determined master that you are shutting down because - you have recognized the searcher has no more operations for you to complete at this time. - - This is important for the Determined master to know that it is safe to restart this process - should new operations be assigned to this trial. - - acknowledge_out_of_ops() is normally called automatically just before operations() raises a - StopIteration, unless operations() is called with auto_ack=False. + .. warning:: + SearcherContext.acknowledge_out_of_ops() is deprecated in 0.38.0, and will be removed in + a future version. Current calls to this function are ignored, and there should not need + to be a replacement. """ - logger.debug(f"acknowledge_out_of_ops(allocation_id:{self._allocation_id})") - self._session.post(f"/api/v1/allocations/{self._allocation_id}/signals/ack_preemption") + pass def get_configured_units(self) -> Optional[Unit]: """ + .. warning:: + SearcherContext.get_configured_units() is deprecated in 0.38.0, and will be removed in + a future version. Note that the ``searcher.max_length`` filed of the experiment config + is also deprecated and will be removed as well. Instead, you should directly specify + your training length in your training code. + get_configured_units() reports what units were used in the searcher field of the experiment config. If no units were configured, None is returned. @@ -330,6 +373,14 @@ def operations( searcher_mode: SearcherMode = SearcherMode.WorkersAskChief, auto_ack: bool = True, ) -> Iterator[SearcherOperation]: + warnings.warn( + "SearcherContext.operations() was deprecated in Determined 0.38.0 and will be removed " + "in a future version. Instead, you should set your training length directly in your " + "training code.", + FutureWarning, + stacklevel=2, + ) + searcher_mode = SearcherMode(searcher_mode) # Force the same synchronization behavior in the DummySearcherContext as the real one. if self._dist.rank == 0: @@ -361,3 +412,36 @@ def acknowledge_out_of_ops(self) -> None: def get_configured_units(self) -> Optional[Unit]: return Unit.EPOCHS + + +class SearcherContextMissing(SearcherContext): + def __init__(self) -> None: + pass + + def operations( + self, + searcher_mode: SearcherMode = SearcherMode.WorkersAskChief, + auto_ack: bool = True, + ) -> Iterator[SearcherOperation]: + raise ValueError( + "SearcherContext was not created because your experiment config does not have the " + "searcher.max_length set. Both the searcher.max_length and the SearcherContext are " + "deprecated. Instead, you should specify your training length directly in your " + "training code and avoid all calls to core_context.searcher." + ) + + def acknowledge_out_of_ops(self) -> None: + raise ValueError( + "SearcherContext was not created because your experiment config does not have the " + "searcher.max_length set. Both the searcher.max_length and the SearcherContext are " + "deprecated. Instead, you should specify your training length directly in your " + "training code and avoid all calls to core_context.searcher." + ) + + def get_configured_units(self) -> Optional[Unit]: + raise ValueError( + "SearcherContext was not created because your experiment config does not have the " + "searcher.max_length set. Both the searcher.max_length and the SearcherContext are " + "deprecated. Instead, you should specify your training length directly in your " + "training code and avoid all calls to core_context.searcher." + ) diff --git a/harness/determined/core/_train.py b/harness/determined/core/_train.py index f030db6f3af..7a8d7912c7c 100644 --- a/harness/determined/core/_train.py +++ b/harness/determined/core/_train.py @@ -270,9 +270,6 @@ def report_progress(self, progress: float) -> None: should represent the current iteration step as a fraction of maximum training steps (i.e.: `report_progress(step_num / max_steps)`). - Note that for hyperparameter search, progress should be reported through - ``SearcherOperation.report_progress()`` in the Searcher API instead. - Arguments: progress (float): completion progress in the range [0, 1.0]. """ diff --git a/harness/determined/exec/harness.py b/harness/determined/exec/harness.py index ae950891e13..65b9932c746 100644 --- a/harness/determined/exec/harness.py +++ b/harness/determined/exec/harness.py @@ -3,6 +3,7 @@ import faulthandler import logging import sys +import warnings from typing import Iterator, Optional, Type import determined as det @@ -38,8 +39,13 @@ def main(train_entrypoint: str) -> int: # We can't import pytorch directly because if running TfKerasTrials with an image that contains # both torch and keras, keras will throw exceptions due to unexpected CUDNN library versions. - if hasattr(det, "pytorch") and issubclass(trial_class, det.pytorch.PyTorchTrial): - return _run_pytorch_trial(trial_class, info) + if hasattr(det, "pytorch"): + if hasattr(det.pytorch, "deepspeed") and issubclass( + trial_class, det.pytorch.deepspeed.DeepSpeedTrial + ): + return _run_deepspeed_trial(trial_class, info, train_entrypoint) + elif issubclass(trial_class, det.pytorch.PyTorchTrial): + return _run_pytorch_trial(trial_class, info, train_entrypoint) # TODO: Don't include EnvContext object in the future high-level APIs for PyTorch or Keras. # It was natural to create this big-blob-of-config object, but it was a mistake to pass it into @@ -133,11 +139,25 @@ def main(train_entrypoint: str) -> int: def _run_pytorch_trial( trial_class: "Type[det.pytorch.PyTorchTrial]", info: det.ClusterInfo, + train_entrypoint: str, ) -> int: from determined import pytorch det.common.set_logger(info.trial._debug) + # Only warn here if the user set a legacy entrypoint, not if we arrived here after user passed + # a --trial argument to a launcher. + if train_entrypoint == info.trial._config["entrypoint"]: + warnings.warn( + f"Support for legacy entrypoint format ({train_entrypoint}) has been deprecated in " + "Determined 0.38.0 and will be removed in a future version. You can keep your " + "PyTorchTrial, but please replace your model_def:TrialClass-style entrypoint " + "with a script-style entrypoint, and use the det.pytorch.Trainer() to train your " + "PyTorchTrial.", + FutureWarning, + stacklevel=2, + ) + logger.debug("Starting harness.") with maybe_periodic_stacktraces(info.trial._debug): @@ -194,6 +214,72 @@ def _run_pytorch_trial( return 0 +def _run_deepspeed_trial( + trial_class: "Type[det.pytorch.deepspeed.DeepSpeedTrial]", + info: det.ClusterInfo, + train_entrypoint: str, +) -> int: + from determined import pytorch + from determined.pytorch import deepspeed as det_ds + + det.common.set_logger(info.trial._debug) + + # Only warn here if the user set a legacy entrypoint, not if we arrived here after user passed + # a --trial argument to a launcher. + if train_entrypoint == info.trial._config["entrypoint"]: + warnings.warn( + f"Support for legacy entrypoint format ({train_entrypoint}) has been deprecated in " + "Determined 0.38.0 and will be removed in a future version. You can keep your " + "DeepSpeedTrial, but please replace your model_def:TrialClass-style entrypoint " + "with a script-style entrypoint, and use the new det.pytorch.deepspeed.Trainer() " + "to train your DeepSpeedTrial.", + FutureWarning, + stacklevel=2, + ) + + logger.debug("Starting harness.") + + with det_ds.init( + hparams=info.trial.hparams, + exp_conf=info.trial._config, + ) as train_context: + trial_inst = trial_class(train_context) + + if train_context.distributed.size > 1 and not train_context.distributed.rank == 0: + log_level = logging.DEBUG if info.trial._debug else logging.WARNING + logging.getLogger().setLevel(log_level) + + logger.info( + f"Creating {det_ds.DeepSpeedTrialController.__name__} with {trial_class.__name__}." + ) + + trainer = det_ds.Trainer(trial_inst, train_context) + + if "global_batch_size" in info.trial.hparams: + global_batch_size = int(info.trial.hparams["global_batch_size"]) # type: Optional[int] + else: + global_batch_size = None + + trainer.fit( + checkpoint_period=pytorch.TrainUnit._from_values( + **info.trial._config["min_checkpoint_period"], + global_batch_size=global_batch_size, + ), + validation_period=pytorch.TrainUnit._from_values( + **info.trial._config["min_validation_period"], + global_batch_size=global_batch_size, + ), + reporting_period=pytorch.Batch(info.trial._config["scheduling_unit"]), + checkpoint_policy=info.trial._config["checkpoint_policy"], + latest_checkpoint=info.latest_checkpoint, + step_zero_validation=info.trial._config["perform_initial_validation"], + test_mode=False, + profiling_enabled=bool(info.trial._config["profiling"]["enabled"]), + ) + + return 0 + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("train_entrypoint") diff --git a/harness/determined/experimental/core_v2/_core_context_v2.py b/harness/determined/experimental/core_v2/_core_context_v2.py index 3a0c4c624a2..9cc4eb06d43 100644 --- a/harness/determined/experimental/core_v2/_core_context_v2.py +++ b/harness/determined/experimental/core_v2/_core_context_v2.py @@ -120,15 +120,19 @@ def _make_v2_context( tensorboard_manager, tbd_writer, ) - units = core._parse_searcher_units(info.trial._config) - searcher = core.SearcherContext( - session, - distributed, - info.trial.trial_id, - info.trial._trial_run_id, - info.allocation_id, - units, - ) + # only provide a .searcher if max_length appears in the experiment config + max_length = core._parse_searcher_max_length(info.trial._config) + if not max_length: + searcher = None + else: + units = core._parse_searcher_units(info.trial._config) + searcher = core.SearcherContext( + session, + distributed, + info.trial.trial_id, + max_length, + units, + ) if storage_manager is None: if has_storage: diff --git a/harness/determined/experimental/core_v2/_core_v2.py b/harness/determined/experimental/core_v2/_core_v2.py index 6ef5f7141a1..1d565add316 100644 --- a/harness/determined/experimental/core_v2/_core_v2.py +++ b/harness/determined/experimental/core_v2/_core_v2.py @@ -189,7 +189,6 @@ def _init_context( or { "name": "single", "metric": "unmanaged", - "max_length": 100000000, }, "workspace": config.workspace, "project": config.project, diff --git a/harness/determined/keras/__init__.py b/harness/determined/keras/__init__.py index e47936d3924..57e755cce0a 100644 --- a/harness/determined/keras/__init__.py +++ b/harness/determined/keras/__init__.py @@ -20,3 +20,4 @@ from determined.keras._tf_keras_trial import TFKerasTrial, TFKerasTrialController from determined.keras._load import load_model_from_checkpoint_path from determined.keras._tf_rng import get_rng_state, set_rng_state +from determined.keras._callback import DeterminedCallback, TensorBoard diff --git a/harness/determined/keras/_callback.py b/harness/determined/keras/_callback.py new file mode 100644 index 00000000000..843d15f766e --- /dev/null +++ b/harness/determined/keras/_callback.py @@ -0,0 +1,465 @@ +import contextlib +import logging +import os +import pickle +import shutil +import tempfile +from typing import Any, Dict, Optional, Tuple, Union + +from tensorflow.keras import callbacks, models + +import determined as det +from determined import core + +logger = logging.getLogger("determined.keras") + + +class DeterminedCallback(callbacks.ProgbarLogger): # type: ignore + """ + DeterminedCallback adds Determined tracking, checkpointing, pausing, and restoring to a Keras + ``model.fit()`` call. Just include it as one of your callbacks. + + DeterminedCallback must not be used with a BackupAndRestore callback or a ModelCheckpoint + callback, which have conflicting behaviors. + + When using DeterminedCallback: + - The ``initial_epoch`` parameter to ``model.fit()`` will be overridden. Rely on the + ``checkpoint`` and ``continue_id`` parameters to DeterminedCallback instead. + - Checkpoints are saved and uploaded to Determined's checkpoint storage every epoch by + default, but can be saved less frequently based on the ``checkpoint_epochs`` parameter. + Checkpoints are always saved when training finishes or is preempted. + - Training will check for preemption every epoch. This means, for instance, if you click the + "pause" button in the UI, training will continue until the next epoch boundary. + - The normal verbose=1 TQDM progress bars are replaced with a more log-friendly output. + - By default, checkpoints are saved with ``model.save_weights()`` and restored with + ``model.load_weights()``. This is configurable by subclassing DeterminedCallback and + implementing custom ``save_model`` and ``load_model`` methods. + - By default, weights are saved to the path ``model_checkpoint`` inside the checkpoint + directory, which you can pass to ``model.load_weights()`` to load a trained model from a + downloaded checkpoint after training is complete. + + Arguments: + core_context: the result of a ``det.core.init()`` call + checkpoint: Either None, or a checkpoint uuid to start from. When you are training + on-cluster, this is likely the output of ``det.get_cluster_info().latest_checkpoint``. + continue_id: A unique identifier that is saved with the checkpoint. When you are training + on-cluster, this is likely the output of ``det.get_cluster_info().trial.trial_id``. + When loading an existing checkpoint, if the provided continue_id matches what was in the + checkpoint, training will continue from the epoch where it left off (a pause-and-unpause + scenario). If the provided continue_id does not match the checkpoint, the model weights + will be loaded but training will begin from epoch=0 (a warm-start scenario). + train_metrics_report_period: Either the string ``"epoch"`` or a number of batches to wait + between reporting training metrics to Determined master. Default: ``"epoch"``. + checkpoint_epochs: Save every N epochs. Checkpoints are always saved when training is + preempted, or at the end of training. A value of `0` means to only save at those times. + Default: 1. + + See also: + - :meth:`DistributedContext.from_tf_config + ` + """ + + _chief_worker_only = False + _supports_tf_logs = False + + def __init__( + self, + core_context: core.Context, + checkpoint: Optional[str], + continue_id: Union[int, str], + *, + train_metrics_report_period: Union[int, str] = "epoch", + checkpoint_epochs: int = 1, + ) -> None: + # We subclass ProgbarLogger to disable standard verbose=1 behavior, but really we don't + # want any of its actual behavior. So __init__ the supersuper class directly. + callbacks.Callback.__init__(self) + self._core = core_context + self._checkpoint = checkpoint + self._continue_id = continue_id + self._report_period = train_metrics_report_period + self._checkpoint_epochs = checkpoint_epochs + + self._is_chief = core_context.distributed.rank == 0 + self._is_verbose = False # Configured by .set_params(). + + # This is an undocumented workaround in case off-cluster user-code saving goes awry. + self._save_user_code = True + + self._steps_completed = 0 + + # We only track the last value from on_epoch_begin() in order to handle off-epoch reporting + # of training metrics. + self._epoch = -1 + + # Track on_epoch_end() calls, and the last on_epoch_end() where we saved a checkpoint, in + # order to be able to decide if we have any uncheckpointed work when we hit on_train_end(). + self._last_train_epoch = -1 + self._last_ckpt_epoch = -1 + + # progress + self._training_length: Optional[int] = None + self._validation_length: Optional[int] = None + self._training_batches = 0 + self._validation_batches = 0 + self._percent_reported = -1 + + # We download the checkpoint, then have to keep it for a while until we can delete it. + self._ckpt_context: Optional[contextlib.ExitStack] = None + + # Mask the inherited ProgbarLogger behavior that we don't actually want. + def set_params(self, params: Dict[str, Any]) -> None: + callbacks.Callback.set_params(self, params) + self._is_verbose = bool(params.get("verbose", 0) != 0) + + # Mask the inherited ProgbarLogger behavior that we don't actually want. + def on_predict_begin(self, logs: Optional[Dict[str, Any]]) -> None: + pass + + # Mask the inherited ProgbarLogger behavior that we don't actually want. + def on_predict_batch_end(self, batch: int, logs: Optional[Dict[str, Any]]) -> None: + pass + + # Mask the inherited ProgbarLogger behavior that we don't actually want. + def on_predict_end(self, logs: Optional[Dict[str, Any]]) -> None: + pass + + def _implements_train_batch_hooks(self) -> bool: + return True + + def _implements_test_batch_hooks(self) -> bool: + # Tell keras that we don't need on_test_batch_end unless we are verbose. + return self._is_verbose + + def _implements_predict_batch_hooks(self) -> bool: + # Tell keras that we don't actually want any on_predict_batch_end calls. + return False + + def _print_progress( + self, logs: Optional[Dict[str, Any]], training: bool, batches: int, total: Optional[int] + ) -> None: + # Only report progress if we have the target total. + if total is None: + return + + # Don't report more often than 10% increments. + percent_10 = int((batches / total) * 10) * 10 + if percent_10 <= self._percent_reported: + return + + # When you do report, report to 1% accuracy. + percent = int((batches / total) * 100) + self._percent_reported = percent + + if training: + report = ( + f"total batches trained: {self._steps_completed}, " + f"epoch {percent}% complete ({batches}/{total})" + ) + else: + report = ( + f"validation after batch: {self._steps_completed}, " + f"validation {percent}% complete ({batches}/{total})" + ) + if logs is not None: + metrics = {k: v for k, v in logs.items() if k not in ("batch", "size")} + report += f": {metrics}" + + print(report) + + def on_train_begin(self, logs: Optional[Dict[str, Any]]) -> None: + # Load initial state. Note that we might set model._training_state, which is how we + # override the initial_epoch provided to model.fit(), but this callback occurs just before + # model.fit() trys to read model._training_state. + self._ckpt_context = self._load(self._checkpoint) + + def on_epoch_begin(self, epoch: int, logs: Optional[Dict[str, Any]]) -> None: + self._epoch = epoch + if not self._is_chief: + return + + # Set status. + self._core.train.set_status("training") + + # Print progress. + if self._is_verbose: + self._training_batches = 0 + self._percent_reported = -1 + self._print_progress(logs=None, training=True, batches=0, total=self._training_length) + + def on_train_batch_end(self, batch: int, logs: Optional[Dict[str, Any]]) -> None: + self._steps_completed += 1 + + # Delete the initial checkpoint files, if we haven't already. + if self._ckpt_context: + self._ckpt_context.close() + self._ckpt_context = None + + if not self._is_chief: + return + + assert logs + + # Report metrics. + if ( + isinstance(self._report_period, int) + and self._steps_completed % self._report_period == 0 + ): + # Skip non-metrics data from logs. + metrics = {k: v for k, v in logs.items() if k not in ("batch", "size")} + # Add epochs and batches. + metrics["epochs"] = metrics.get("epochs", self._epoch + 1) + metrics["batches"] = metrics.get("batches", self._steps_completed) + self._core.train.report_metrics("training", self._steps_completed, metrics) + + # Print progress. + if self._is_verbose: + self._training_batches += 1 + self._print_progress( + logs=logs, + training=True, + batches=self._training_batches, + total=self._training_length, + ) + + def on_epoch_end(self, epoch: int, logs: Optional[Dict[str, Any]]) -> None: + # Report metrics. + if self._is_chief and self._report_period == "epoch": + assert logs + # Filter out the validation logs. + metrics = {k: v for k, v in logs.items() if not k.startswith("val_")} + metrics["epochs"] = metrics.get("epochs", epoch + 1) + metrics["batches"] = metrics.get("batches", self._steps_completed) + self._core.train.report_metrics("training", self._steps_completed, metrics) + + # Report progress. + if self._is_chief and self.params["epochs"]: + self._core.train.report_progress((epoch + 1) / self.params["epochs"]) + + # Save a checkpoint. + self._last_train_epoch = epoch + if self._checkpoint_epochs > 0 and (epoch + 1) % self._checkpoint_epochs == 0: + self._save(epoch) + self._last_ckpt_epoch = epoch + + # Check for preemption. Checkpointing time can be non-negligible, so we check for + # preemption here after possibly saving a checkpoint. If we didn't save a checkpoint but we + # did get preempted, we'll catch that in the checkpoint fallback in on_train_end(). + if self._core.preempt.should_preempt(): + self.model.stop_training = True + + # Remember how many batches we trained, for next time. + if self._is_chief: + self._training_length = self._training_batches + + def on_test_begin(self, logs: Optional[Dict[str, Any]]) -> None: + if not self._is_chief: + return + + # Set status. + self._core.train.set_status("validating") + + # Print progress. + if self._is_verbose: + self._validation_batches = 0 + self._percent_reported = -1 + self._print_progress( + logs=None, training=False, batches=0, total=self._validation_length + ) + + def on_test_batch_end(self, batch: int, logs: Optional[Dict[str, Any]]) -> None: + # Print progress. + if self._is_chief and self._is_verbose: + self._validation_batches += 1 + self._print_progress( + logs=logs, + training=False, + batches=self._validation_batches, + total=self._validation_length, + ) + + def on_test_end(self, logs: Optional[Dict[str, Any]]) -> None: + if not self._is_chief: + return + + assert logs + metrics = {**logs} + metrics["epochs"] = metrics.get("epochs", self._epoch + 1) + metrics["batches"] = metrics.get("batches", self._steps_completed) + self._core.train.report_metrics("validation", self._steps_completed, metrics) + + # Remember how many batches we trained, for next time. + self._validation_length = self._validation_batches + + def on_train_end(self, logs: Optional[Dict[str, Any]]) -> None: + # Are we exiting with some amount of uncheckpointed training? + if self._last_train_epoch > self._last_ckpt_epoch: + self._save(self._last_train_epoch) + + if self._is_chief: + self._core.train.set_status("finishing") + + def _save(self, epoch: int) -> None: + if self._is_chief: + self._core.train.set_status("checkpointing") + + metadata = {"steps_completed": self._steps_completed} + # Use shard=True because keras wants every worker to write a checkpoint, even though every + # worker except the chief will end up deleting it. + with self._core.checkpoint.store_path(metadata, shard=True) as (path, storage_id): + # Save the model. + self.save_model(self.model, str(path / "model_checkpoint"), self._core.distributed) + # Only the chief saves the callback state and user code + if self._is_chief: + with (path / "callback_state").open("wb") as f: + state = { + "epoch": epoch, + "steps_completed": self._steps_completed, + "continue_id": self._continue_id, + "training_length": self._training_length, + "validation_length": self._validation_length, + } + pickle.dump(state, f) + # Save user code. + if self._save_user_code: + det.util.write_user_code(path, on_cluster=det.get_cluster_info() is not None) + + def _load(self, checkpoint: Optional[str]) -> Optional[contextlib.ExitStack]: + if checkpoint is None: + return None + + if self._is_chief: + self._core.train.set_status("restoring") + + with contextlib.ExitStack() as exit_stack: + path = exit_stack.enter_context(self._core.checkpoint.restore_path(checkpoint)) + + # Load model. + self.load_model(self.model, str(path / "model_checkpoint"), self._core.distributed) + + # Load training state also. + state_path = path / "callback_state" + if not state_path.exists(): + return None + with state_path.open("rb") as f: + state = pickle.load(f) + if state["continue_id"] != self._continue_id: + return None + # Continue training where we left off. + self._steps_completed = state["steps_completed"] + self._training_length = state["training_length"] + self._validation_length = state["validation_length"] + initial_epoch: int = state["epoch"] + 1 + + # HACK: Trick the training loop into starting on a different epoch. Internally, this is + # how keras.callbacks.BackupAndRestore() sets the initial_epoch. + class WorkerTrainingState: + # For tf.keras. + def maybe_load_initial_epoch_from_ckpt(*_: Any, **__: Any) -> int: + return initial_epoch + + # For plain keras. + def maybe_load_initial_counters_from_ckpt(*_: Any, **__: Any) -> Tuple[int, int]: + # We only save on epoch boundaries. + initial_batch = 0 + return initial_epoch, initial_batch + + self.model._training_state = WorkerTrainingState() + + # Success! Don't delete the checkpoint until after the first batch runs though, because + # the checkpoint isn't actually read until then. + return exit_stack.pop_all() + + # mypy thinks it's possible to arrive here, but it isn't. + raise RuntimeError("impossible codepath") + + def save_model( + self, model: models.Model, path: str, distributed: core.DistributedContext + ) -> None: + """ + Users can subclass this if they need to customize how they save their model. + + This method is responsible for meeting the requirements of checkpointing according to the + needs of the active Strategy. + + See the `TensorFlow docs`_ for more details. + + Arguments: + model: the model to save + path: the destination to save to + distributed: the value of core_context.distributed, which can be used for detecting + the current process's rank, or inter-worker coordination, as needed. + + .. _TensorFlow docs: + https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras + #model_saving_and_loading + """ + + # MultiWorkerMirroredStrategy requires everyone to save the model (to access shared + # variables or something) but you have to delete the non-chief copies. Brilliant. + if distributed.rank == 0: + model.save_weights(path) + else: + tempdir = tempfile.mkdtemp("save-worker-model") + try: + model.save_weights(os.path.join(tempdir, "model_checkpoint")) + finally: + shutil.rmtree(tempdir) + + def load_model( + self, model: models.Model, path: str, distributed: core.DistributedContext + ) -> None: + """ + Users can subclass this if they need to customize how they load their model. + + Arguments: + model: the model to load + path: the destination to load from + distributed: the value of core_context.distributed, which can be used for detecting + the current process's rank, or inter-worker coordination, as needed. + """ + + # Users can subclass this if they just need to change how they save their model. + model.load_weights(path) + + +class TensorBoard(callbacks.TensorBoard): # type: ignore + """ + This is a thin wrapper over the TensorBoard callback that ships with ``tf.keras``. For more + information, see the :ref:`TensorBoard Guide ` or the upstream docs for + `tf.keras.callbacks.TensorBoard + `__. + + Note that if a ``log_dir`` argument is passed to the constructor, it will be ignored if the + ``core_context`` is configured for tensorboard (which is the default when on-cluster). + """ + + def __init__(self, core_context: core.Context, *args: Any, **kwargs: Any): + det_tb_path = core_context.train.get_tensorboard_path() + if det_tb_path: + if "log_dir" in kwargs: + user_log_dir = kwargs.pop("log_dir") + logger.warning( + f"arg log_dir={user_log_dir} to det.keras.TensorBoard will be ignored" + ) + elif args: + user_log_dir, args = args[0], args[1:] + logger.warning( + f"arg log_dir={user_log_dir} to det.keras.TensorBoard will be ignored" + ) + args = [det_tb_path, *args] # type: ignore + super().__init__(*args, **kwargs) + + def _write_logs(self, *args: Any) -> None: + """ + _write_logs calls the original _write_logs() function from the Keras + TensorBoard callback. After the logs are flushed to disk, we close and + reopen the tf event writer so that it serializes the next set of logs + to a new file. This allows the tensorboard manager to treat the + written files as immutable and upload them to persistent storage + without later having to append to them. This behavior is useful for + tensorboard backed by S3. + """ + super()._write_logs(*args) + self.writer.close() + self.writer.reopen() diff --git a/harness/determined/keras/_load.py b/harness/determined/keras/_load.py index bcb6a147c47..8aa68d8b212 100644 --- a/harness/determined/keras/_load.py +++ b/harness/determined/keras/_load.py @@ -1,6 +1,7 @@ import json import logging import pathlib +import warnings from typing import List, Optional, cast import tensorflow as tf @@ -34,8 +35,25 @@ def load_model_from_checkpoint_path( tags (list string, optional): Specifies which tags are loaded from the TensorFlow SavedModel. See documentation for `tf.compat.v1.saved_model.load_v2 `_. + + .. warning:: + + load_model_from_checkpoint_path has been deprecated in Determined 0.38.0 and will be removed + in a future version. This function is designed to work with TFKerasTrial, which is also + deprecated. Please use the new :class:`~determined.keras.DeterminedCallback` for + training instead, which allows you to use ``model.load_weights()`` to restore from + checkpoints. """ + warnings.warn( + "load_model_from_checkpoint_path has been deprecated in Determined 0.38.0 and will be " + "removedin a future version. This function is designed to work with TFKerasTrial, which " + "is alsodeprecated. Please use the new det.keras.DeterminedCallback fortraining instead, " + "which allows you to use ``model.load_weights()`` to restore fromcheckpoints.", + FutureWarning, + stacklevel=2, + ) + ckpt_dir = pathlib.Path(path) load_data_path = ckpt_dir.joinpath("load_data.json") metadata_path = ckpt_dir.joinpath("metadata.json") diff --git a/harness/determined/keras/_tensorboard_callback.py b/harness/determined/keras/_tensorboard_callback.py index bcc6038e9c2..e7090b427b5 100644 --- a/harness/determined/keras/_tensorboard_callback.py +++ b/harness/determined/keras/_tensorboard_callback.py @@ -1,4 +1,5 @@ import logging +import warnings from typing import Any from determined.keras import callbacks @@ -8,9 +9,11 @@ class TFKerasTensorBoard(callbacks.TensorBoard): def __init__(self, *args: Any, **kwargs: Any): - logger.warning( + warnings.warn( "det.keras.TFKerasTensorBoard is a deprecated name for " - "det.keras.callbacks.TensorBoard, please update your code." + "det.keras.callbacks.TensorBoard, please update your code.", + FutureWarning, + stacklevel=2, ) # Avoid using super() due to a diamond inheritance pattern. callbacks.TensorBoard.__init__(self, *args, **kwargs) diff --git a/harness/determined/keras/_tf_keras_trial.py b/harness/determined/keras/_tf_keras_trial.py index e8b7c01c2dc..5e832c02fa3 100644 --- a/harness/determined/keras/_tf_keras_trial.py +++ b/harness/determined/keras/_tf_keras_trial.py @@ -7,6 +7,7 @@ import random import sys import time +import warnings from typing import Any, Dict, List, Optional, Tuple, Type, cast import h5py @@ -179,6 +180,14 @@ def pre_execute_hook( env: det.EnvContext, distributed_backend: det._DistributedBackend, ) -> None: + # TFKerasTrial's __init__ method may not be called by user-defined subclasses, so we fire + # the warning here. Also, this will show up before some of the worst tf log vomit, I hope. + warnings.warn( + "TFKerasTrial has been deprecated in Determined 0.38.0 and will be removed in a future " + "version. Please use the new det.keras.DeterminedCallback for training.", + FutureWarning, + stacklevel=2, + ) # Initialize the correct horovod. if distributed_backend.use_horovod(): hvd = horovod.hvd @@ -881,6 +890,13 @@ def _post_train_batch_end(self, num_inputs: int, logs: Dict) -> None: if self.env.experiment_config.average_training_metrics_enabled(): final_metrics = self._allreduce_logs(final_metrics) + # Inject batches and epochs into avg metrics. + # (this is after batches and possibly epochs have been updated) + final_metrics["batches"] = final_metrics.get( + "batches", self.multiplexer.state.total_batches + ) + final_metrics["epochs"] = final_metrics.get("epochs", self.multiplexer.state.epoch) + self.multiplexer._train_workload_end(final_metrics) self._stop_training_check() @@ -953,6 +969,11 @@ def _compute_validation_metrics(self) -> workload.Response: step_duration = time.time() - validation_start_time logger.info(det.util.make_timing_log("validated", step_duration, num_inputs, num_batches)) + # Inject batches and epochs into validation metrics. + # (this is after batches and possibly epochs have been updated) + metrics["batches"] = metrics.get("batches", self.multiplexer.state.total_batches) + metrics["epochs"] = metrics.get("epochs", self.multiplexer.state.epoch) + self.metric_writer.on_validation_step_end(self.steps_completed, metrics) self.upload_tb_files() return {"num_inputs": num_inputs, "validation_metrics": metrics} @@ -992,6 +1013,11 @@ class TFKerasTrial(det.LegacyTrial): ``tf.compat.v1.disable_eager_execution`` after your import statements. If you are using TensorFlow 1.x in eager mode, please add ``experimental_run_tf_function=False`` to your model compile function. + + .. warning:: + + TFKerasTrial has been deprecated in Determined 0.38.0 and will be removed in a future + version. Please use the new :class:`~determined.keras.DeterminedCallback` for training. """ trial_controller_class = TFKerasTrialController diff --git a/harness/determined/keras/callbacks.py b/harness/determined/keras/callbacks.py index cf404a85c17..f0f03823adc 100644 --- a/harness/determined/keras/callbacks.py +++ b/harness/determined/keras/callbacks.py @@ -1,6 +1,7 @@ import logging import pathlib import time +import warnings from typing import Any, Dict, List, Optional import tensorflow as tf @@ -35,12 +36,19 @@ class Callback(tf.keras.callbacks.Callback): # type: ignore * The tf.keras version of ``EarlyStopping`` will not work right in Determined. You should use you should use :class:`determined.keras.callbacks.EarlyStopping` instead. * The tf.keras version of ``ReduceLROnPlateau`` will not work right in Determined. You - should use you should use :class:`determined.keras.callbacks.ReduceLRScheduler` + should use you should use :class:`determined.keras.callbacks.ReduceLROnPlateau` instead. The Determined versions are based around ``on_test_end`` rather than ``on_epoch_end``, which can be influenced by setting ``min_validation_period`` in the experiment configuration. + + .. warning:: + + det.keras.callbacks.Callback has been deprecated in Determined 0.38.0 and will be removed + in a future version. This Callback class is designed to work with TFKerasTrial, which is + also deprecated. Please use the new :class:`~determined.keras.DeterminedCallback` for + training, and use normal keras Callabacks with it. """ def on_train_workload_begin( @@ -539,6 +547,14 @@ class EarlyStopping(tf.keras.callbacks.EarlyStopping, Callback): # type: ignore In Determined, ``on_test_end`` may be called slightly more often than ``min_validation_period`` during some types of hyperparameter searches, but it is unlikely for that to occur often enough have a meaningful impact on this callback's operation. + + .. warning:: + + EarlyStopping has been deprecated in Determined 0.38.0 and will be removed in a future + version. Determined's EarlyStopping is a customization of keras' EarlyStopping callback + that is specific to TFKerasTrial, which is also deprecated. Please use the new + :class:`~determined.keras.DeterminedCallback` for training, and use keras' EarlyStopping + with it. """ _savable_attributes = { @@ -576,6 +592,16 @@ def __init__(self, *arg: Any, **kwarg: Any) -> None: tf.keras.callbacks.EarlyStopping.__init__(self, *arg, **kwarg) self.test_end_count = 0 + warnings.warn( + "EarlyStopping has been deprecated in Determined 0.38.0 and will be removed in a " + "future version. Determined's EarlyStopping is a customization of keras' " + "EarlyStopping callback that is specific to TFKerasTrial, which is also deprecated. " + "Please use the new det.keras.DeterminedCallback for training, and use keras' " + "EarlyStopping with it.", + FutureWarning, + stacklevel=2, + ) + def on_epoch_end(self, epoch: int, logs: Optional[Dict]) -> None: # Ignore on_epoch_end calls, which never contain metrics in Determined. pass @@ -605,6 +631,14 @@ class ReduceLROnPlateau(tf.keras.callbacks.ReduceLROnPlateau, Callback): # type In Determined, ``on_test_end`` may be called slightly more often than ``min_validation_period`` during some types of hyperparameter searches, but it is unlikely for that to occur often enough have a meaningful impact on this callback's operation. + + .. warning:: + + ReduceLROnPlateau has been deprecated in Determined 0.38.0 and will be removed in a future + version. Determined's ReduceLROnPlateau is a customization of keras' ReduceLROnPlateau + callback that is specific to TFKerasTrial, which is also deprecated. Please use the new + :class:`~determined.keras.DeterminedCallback` for training, and use keras' + ReduceLROnPlateau with it. """ _savable_attributes = { @@ -643,6 +677,16 @@ def __init__(self, *arg: Any, **kwarg: Any) -> None: tf.keras.callbacks.ReduceLROnPlateau.__init__(self, *arg, **kwarg) self.test_end_count = 0 + warnings.warn( + "ReduceLROnPlateau has been deprecated in Determined 0.38.0 and will be removed in a " + "future version. Determined's ReduceLROnPlateau is a customization of keras' " + "ReduceLROnPlateau callback that is specific to TFKerasTrial, which is also " + "deprecated. Please use the new det.keras.DeterminedCallback for training, and use " + "keras' ReduceLROnPlateau with it.", + FutureWarning, + stacklevel=2, + ) + def on_epoch_end(self, epoch: int, logs: Optional[Dict]) -> None: # Ignore on_epoch_end calls, which never contain metrics in Determined. pass @@ -667,6 +711,14 @@ class TensorBoard(tf.keras.callbacks.TensorBoard, Callback): # type: ignore `__. Note that if a ``log_dir`` argument is passed to the constructor, it will be ignored. + + .. warning:: + + det.keras.callbacks.TensorBoard has been deprecated in Determined 0.38.0 and will be removed + in a future version. This version of keras' TensorBoard callback is designed to work with + TFKerasTrial, which is also deprecated. Please use the new + :class:`~determined.keras.DeterminedCallback` for training, and use the new + :class:`det.keras.TensorBoard ` with it. """ # TensorBoard uses on_epoch_end but we manually take care of that. @@ -682,6 +734,16 @@ def __init__(self, *args: Any, **kwargs: Any): log_dir = str(tensorboard.get_base_path({}).resolve()) tf.keras.callbacks.TensorBoard.__init__(self, log_dir=log_dir, *args, **kwargs) + warnings.warn( + "det.keras.callbacks.TensorBoard has been deprecated in Determined 0.38.0 and will be " + "removed in a future version. This version of keras' TensorBoard callback is designed " + "to work with TFKerasTrial, which is also deprecated. Please use the new " + "det.keras.DeterminedCallback for training, and use the new det.keras.TensorBoard with " + "it.", + FutureWarning, + stacklevel=2, + ) + def _write_logs(self, *args: Any) -> None: """ _write_logs calls the original _write_logs() function from the Keras diff --git a/harness/determined/launch/deepspeed.py b/harness/determined/launch/deepspeed.py index 6cff0c91d91..648c0244c00 100644 --- a/harness/determined/launch/deepspeed.py +++ b/harness/determined/launch/deepspeed.py @@ -12,6 +12,7 @@ import subprocess import sys import time +import warnings from typing import Dict, List, Mapping, Optional import deepspeed @@ -375,8 +376,8 @@ def parse_args(args: List[str]) -> List[str]: parser.add_argument( "--trial", help=( - "use a Trial class as the entrypoint to training. When --trial is used, the SCRIPT " - "positional argument must be omitted." + "(deprecated) use a Trial class as the entrypoint to training. When --trial is used, " + "the SCRIPT positional argument must be omitted." ), ) # For training scripts. @@ -396,6 +397,14 @@ def parse_args(args: List[str]) -> List[str]: parser.print_usage() print("error: extra arguments to --trial:", script, file=sys.stderr) sys.exit(1) + warnings.warn( + "Support for --trial argument to determined.launch.deepspeed has been deprecated " + "in Determined 0.38.0 and will be removed in a future version. You can keep your " + "DeepSpeedTrial, but please replace your --trial argument with a script that uses the " + "new det.pytorch.deepspeed.Trainer() to train your DeepSpeedTrial.", + FutureWarning, + stacklevel=2, + ) script = det.util.legacy_trial_entrypoint_to_script(parsed.trial) elif not script: # There needs to be at least one script argument. diff --git a/harness/determined/launch/horovod.py b/harness/determined/launch/horovod.py index 899a290c364..a5da95fb73a 100644 --- a/harness/determined/launch/horovod.py +++ b/harness/determined/launch/horovod.py @@ -11,6 +11,7 @@ import subprocess import sys import time +import warnings from typing import List, Tuple import determined as det @@ -223,8 +224,8 @@ def parse_args(args: List[str]) -> Tuple[List[str], List[str], bool]: parser.add_argument( "--trial", help=( - "use a Trial class as the entrypoint to training. When --trial is used, the SCRIPT " - "positional argument must be omitted." + "(deprecated) use a Trial class as the entrypoint to training. When --trial is used, " + "the SCRIPT positional argument must be omitted." ), ) # For training scripts. @@ -270,5 +271,14 @@ def parse_args(args: List[str]) -> Tuple[List[str], List[str], bool]: if __name__ == "__main__": + warnings.warn( + "determined.launch.horovod has been deprecated in Determined 0.38.0 and will be " + "removed in a future version. For PyTorchTrial users, please switch to " + "determined.launch.torch_distributed instead. For TFKerasTrial users, please migrate " + "to the new det.keras.DeterminedCallback for training and use the new " + "determined.launch.tensorflow launcher with it.", + FutureWarning, + stacklevel=2, + ) hvd_args, script, autohorovod = parse_args(sys.argv[1:]) sys.exit(main(hvd_args, script, autohorovod)) diff --git a/harness/determined/launch/tensorflow.py b/harness/determined/launch/tensorflow.py new file mode 100644 index 00000000000..ebd7a9f2646 --- /dev/null +++ b/harness/determined/launch/tensorflow.py @@ -0,0 +1,102 @@ +import argparse +import json +import logging +import os +import subprocess +import sys +from typing import List, Tuple + +import determined as det + +# We use the same port configure as our torch_distributed launcher, to make network communications +# a little easier for the cluster admin. +C10D_PORT = int(str(os.getenv("C10D_PORT", "29400"))) + +logger = logging.getLogger("determined.launch.tensorflow") + + +def create_log_wrapper(rank: int) -> List[str]: + return [ + "python3", + "-m", + "determined.launch.wrap_rank", + str(rank), + "--", + ] + + +def main(port: int, script: List[str]) -> int: + info = det.get_cluster_info() + assert info is not None, "must be run on-cluster" + + chief_ip = info.container_addrs[0] + env = {**os.environ, "DET_CHIEF_IP": chief_ip} + + if len(info.container_addrs) > 1: + # Multi-node training means MultiWorkerMirroredStrategy. + tf_config = { + "cluster": {"worker": [f"{addr}:{port}" for addr in info.container_addrs]}, + "task": {"type": "worker", "index": info.container_rank}, + } + env["TF_CONFIG"] = json.dumps(tf_config) + log_wrapper = create_log_wrapper(info.container_rank) + else: + # Single-node training means MirroredStrategy or just the default strategy. + # (no point in prefixing every log line with "rank=0") + log_wrapper = [] + + launch_cmd = log_wrapper + script + + logger.debug(f"Tensorflow launching with: {launch_cmd}") + + p = subprocess.Popen(launch_cmd, env=env) + with det.util.forward_signals(p): + return p.wait() + + +def parse_args(args: List[str]) -> Tuple[int, List[str]]: + parser = argparse.ArgumentParser( + usage="%(prog)s [--port PORT] [--] SCRIPT...", + description="Launch a script for tensorflow training on a Determined cluster.", + epilog=( + "This launcher automatically injects a TF_CONFIG environment variable suitable for " + "MirroredStrategy or MultiWorkerMirroredStrategy when multiple nodes and or GPUs are " + "available." + ), + ) + parser.add_argument( + "--port", + type=int, + help="the port that TensorFlow should use for distributed training communication", + default=C10D_PORT, + ) + parser.add_argument( + "script", + metavar="SCRIPT...", + nargs=argparse.REMAINDER, + help="script to launch for training", + ) + + # Manually process the -- because argparse doesn't quite handle it right. + if "--" in args: + split = args.index("--") + args, extra_script = args[:split], args[split + 1 :] + else: + extra_script = [] + + parsed = parser.parse_args(args) + + full_script = parsed.script + extra_script + + if not full_script: + # There needs to be at least one script argument. + parser.print_usage() + print("error: empty script is not allowed", file=sys.stderr) + sys.exit(1) + + return parsed.port, full_script + + +if __name__ == "__main__": + port, script = parse_args(sys.argv[1:]) + sys.exit(main(port, script)) diff --git a/harness/determined/launch/torch_distributed.py b/harness/determined/launch/torch_distributed.py index 5795c401f15..b22946b6e9d 100644 --- a/harness/determined/launch/torch_distributed.py +++ b/harness/determined/launch/torch_distributed.py @@ -3,6 +3,7 @@ import os import subprocess import sys +import warnings from typing import List, Tuple import determined as det @@ -127,8 +128,8 @@ def parse_args(args: List[str]) -> Tuple[List[str], List[str]]: parser.add_argument( "--trial", help=( - "use a Trial class as the entrypoint to training. When --trial is used, the SCRIPT " - "positional argument must be omitted." + "(deprecated) use a Trial class as the entrypoint to training. When --trial is used, " + "the SCRIPT positional argument must be omitted." ), ) # For training scripts. @@ -148,6 +149,19 @@ def parse_args(args: List[str]) -> Tuple[List[str], List[str]]: parser.print_usage() print("error: extra arguments to --trial:", script, file=sys.stderr) sys.exit(1) + # We may have arrived here indirectly if the user specified a legacy entrypoint, in which + # case we have a warning for them in exec/harness.py. Only issue the warning if they + # explicitly configured the --trial argument. + info = det.get_cluster_info() + if info and "--trial" in info.trial._config["entrypoint"]: + warnings.warn( + "Support for --trial argument to determined.launch.torch_distributed has been " + "deprecated in Determined 0.38.0 and will be removed in a future version. You can " + "keep your PyTorchTrial, but please replace your --trial argument with a script " + "that uses the det.pytorch.Trainer() to train your PyTorchTrial.", + FutureWarning, + stacklevel=2, + ) script = det.util.legacy_trial_entrypoint_to_script(parsed.trial) elif not script: # There needs to be at least one script argument. diff --git a/harness/determined/layers/_workload_sequencer.py b/harness/determined/layers/_workload_sequencer.py index c15901e5e04..4d01e64fa9e 100644 --- a/harness/determined/layers/_workload_sequencer.py +++ b/harness/determined/layers/_workload_sequencer.py @@ -388,7 +388,7 @@ def __iter__(self) -> workload.Stream: ): yield from self.validate(None) - for op in self.core_context.searcher.operations(core.SearcherMode.ChiefOnly): + for op in self.core_context.searcher._operations(core.SearcherMode.ChiefOnly): while self.batches_until_op_complete(op) > 0: # Do some training. yield from self.train( diff --git a/harness/determined/pytorch/__init__.py b/harness/determined/pytorch/__init__.py index dbf60ce0316..4c055768abf 100644 --- a/harness/determined/pytorch/__init__.py +++ b/harness/determined/pytorch/__init__.py @@ -24,17 +24,20 @@ _convert_metrics_to_numpy, _log_tb_metrics, ) +from determined.pytorch._trainer_utils import ( + Batch, + Epoch, + _ShouldExit, + _TrainBoundary, + _TrainBoundaryType, + TrainUnit, + _TrialState, +) from determined.pytorch._experimental import PyTorchExperimentalContext from determined.pytorch._pytorch_context import PyTorchTrialContext from determined.pytorch._pytorch_trial import ( PyTorchTrial, _PyTorchTrialController, - TrainUnit, - _TrainBoundary, - _TrainBoundaryType, - _TrialState, - Batch, - Epoch, ) from determined.pytorch._load import CheckpointLoadContext, load_trial_from_checkpoint_path from determined.pytorch._trainer import init, Trainer diff --git a/harness/determined/pytorch/_pytorch_trial.py b/harness/determined/pytorch/_pytorch_trial.py index 1363f321483..b431ed4d20c 100644 --- a/harness/determined/pytorch/_pytorch_trial.py +++ b/harness/determined/pytorch/_pytorch_trial.py @@ -1,6 +1,5 @@ import abc import contextlib -import enum import inspect import json import logging @@ -10,7 +9,6 @@ import sys import time import warnings -from collections import abc as col_abc from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Type, Union import numpy as np @@ -40,153 +38,14 @@ def dataloader_next(dataloader_iter: Iterator) -> Iterator: yield batch -class TrainUnit: - """ - TrainUnit is the base class for the supported training units (Batch, Epoch) containing - the value of unit, where the value can be an int or an implementable collections.abc.Container. - - TrainUnits are used to define periodic training behavior such as checkpointing and validating. - - int values are treated as periods, e.g. Batch(100) will checkpoint/validate every 100 batches. - collections.abc.Container values are treated as schedules, e.g. Batch(1,5,10) will - checkpoint/validate on batches 1, 5, and 10. - """ - - def __init__(self, value: Union[int, col_abc.Container]): - self.value = value - - @staticmethod - def _from_searcher_unit( - length: int, unit: Optional[core.Unit], global_batch_size: Optional[int] = None - ) -> "TrainUnit": - if unit == core.Unit.EPOCHS: - return Epoch(length) - elif unit == core.Unit.RECORDS: - if global_batch_size is None: - raise ValueError("global_batch_size required for searcher unit Records.") - return Batch._from_records(length, global_batch_size) - elif unit == core.Unit.BATCHES: - return Batch(length) - else: - raise ValueError(f"unrecognized searcher unit {unit}") - - def _to_searcher_unit(self) -> core.Unit: - if isinstance(self, Batch): - return core.Unit.BATCHES - return core.Unit.EPOCHS - - @staticmethod - def _from_values( - batches: Optional[int] = None, - records: Optional[int] = None, - epochs: Optional[int] = None, - global_batch_size: Optional[int] = None, - ) -> "TrainUnit": - if sum((batches is not None, records is not None, epochs is not None)) != 1: - raise ValueError(f"invalid config: batches={batches} records={records} epochs={epochs}") - if batches is not None: - if batches < 1: - batches = sys.maxsize - return Batch(batches) - if records is not None: - assert global_batch_size, "global_batch_size is required for RECORD units." - if records < 1: - records = sys.maxsize - return Batch._from_records(records, global_batch_size) - if epochs is not None: - if epochs < 1: - epochs = sys.maxsize - return Epoch(epochs) - - # Make mypy happy - raise ValueError("invalid values") - - def should_stop(self, step_num: int) -> bool: - if isinstance(self.value, int): - return self._divides(step_num) - assert isinstance(self.value, col_abc.Container) - return step_num in self.value - - def _divides(self, steps: int) -> bool: - assert isinstance(steps, int) and isinstance( - self.value, int - ), "_divides can only be called on int types." - # Treat <= 0 values as always step - if self.value < 1: - return True - if steps == 0: - return False - return steps % self.value == 0 - - -class Epoch(TrainUnit): - """ - Epoch step type (e.g. Epoch(1) defines 1 epoch) - """ - - pass - - -class Batch(TrainUnit): - """ - Batch step type (e.g. Batch(1) defines 1 batch) - """ - - @staticmethod - def _from_records(records: int, global_batch_size: int) -> "Batch": - return Batch(max(records // global_batch_size, 1)) - - -class _TrainBoundaryType(enum.Enum): - CHECKPOINT = "CHECKPOINT" - REPORT = "REPORT" - VALIDATE = "VALIDATE" - TRAIN = "TRAIN" - - -class _TrainBoundary: - def __init__(self, step_type: _TrainBoundaryType, unit: TrainUnit): - self.step_type = step_type - self.unit = unit - self.limit_reached = False - - -class ShouldExit(Exception): - """ - ShouldExit breaks out of the top-level train loop from inside function calls. - """ - - def __init__(self, skip_exit_checkpoint: bool = False): - self.skip_exit_checkpoint = skip_exit_checkpoint - - -class _TrialState: - def __init__( - self, - trial_id: int = 0, - last_ckpt: int = 0, - step_id: int = 0, - last_val: int = 0, - batches_trained: int = 0, - epochs_trained: int = 0, - ) -> None: - # Store TrialID to distinguish between e.g. pause/restart and continue training. - self.trial_id = trial_id - self.last_ckpt = last_ckpt - self.step_id = step_id - self.last_val = last_val - self.batches_trained = batches_trained - self.epochs_trained = epochs_trained - - class _PyTorchTrialController: def __init__( self, trial_inst: det.LegacyTrial, context: pytorch.PyTorchTrialContext, - checkpoint_period: TrainUnit, - validation_period: TrainUnit, - reporting_period: TrainUnit, + checkpoint_period: pytorch.TrainUnit, + validation_period: pytorch.TrainUnit, + reporting_period: pytorch.TrainUnit, smaller_is_better: bool, steps_completed: int, latest_checkpoint: Optional[str], @@ -195,7 +54,7 @@ def __init__( searcher_metric_name: Optional[str], checkpoint_policy: str, step_zero_validation: bool, - max_length: Optional[TrainUnit], + max_length: pytorch.TrainUnit, global_batch_size: Optional[int], profiling_enabled: Optional[bool], ) -> None: @@ -219,21 +78,10 @@ def __init__( self.reporting_period = reporting_period # Training loop state - if local_training: - self.trial_id = 0 - assert self.max_length, "max_length must be specified for local-training mode." - self.searcher_unit = self.max_length._to_searcher_unit() - else: - self.trial_id = self.core_context.train._trial_id - configured_units = self.core_context.searcher.get_configured_units() - if configured_units is None: - raise ValueError( - "Searcher units must be configured for training with PyTorchTrial." - ) - self.searcher_unit = configured_units + self.trial_id = 0 if local_training else self.core_context.train._trial_id # Don't initialize the state here because it will be invalid until we load a checkpoint. - self.state = None # type: Optional[_TrialState] + self.state = None # type: Optional[pytorch._TrialState] self.start_from_batch = steps_completed self.val_from_previous_run = self.core_context.train._get_last_validation() self.step_zero_validation = step_zero_validation @@ -247,10 +95,6 @@ def __init__( self.global_batch_size = global_batch_size self.profiling_enabled = profiling_enabled - if self.searcher_unit == core.Unit.RECORDS: - if self.global_batch_size is None: - raise ValueError("global_batch_size required for searcher unit RECORDS.") - self.callbacks = self.trial.build_callbacks() for callback in self.callbacks.values(): if util.is_overridden(callback.on_checkpoint_end, pytorch.PyTorchCallback): @@ -352,6 +196,11 @@ def _aggregate_training_metrics(self, training_metrics: List[Dict]) -> Dict: batch_metrics, ) + # We report "batch" and "epoch" only if these keys are not already reported in user + # metrics. + avg_metrics["batches"] = avg_metrics.get("batches", self.state.batches_trained) + avg_metrics["epochs"] = avg_metrics.get("epochs", self.state.epochs_trained) + self.core_context.train.report_training_metrics( steps_completed=self.state.batches_trained, metrics=avg_metrics, @@ -409,7 +258,7 @@ def _checkpoint(self, already_exiting: bool) -> None: except det.InvalidHP: if not already_exiting: self.core_context.train.report_early_exit(core.EarlyExitReason.INVALID_HP) - raise ShouldExit(skip_exit_checkpoint=True) + raise pytorch._ShouldExit(skip_exit_checkpoint=True) raise def _check_evaluate_implementation(self) -> None: @@ -509,21 +358,22 @@ def _step_batch(self) -> None: def _stop_requested(self) -> None: if self.core_context.preempt.should_preempt(): - raise ShouldExit() + raise pytorch._ShouldExit() if self.context.get_stop_requested(): - raise ShouldExit() + raise pytorch._ShouldExit() - def _report_searcher_progress( - self, op: core.SearcherOperation, unit: Optional[core.Unit] - ) -> None: + def _report_training_progress(self) -> None: assert self.state - if unit == core.Unit.BATCHES: - op.report_progress(self.state.batches_trained) - elif unit == core.Unit.RECORDS: - assert self.global_batch_size, "global_batch_size must be specified for RECORDS" - op.report_progress(self.global_batch_size * self.state.batches_trained) - elif unit == core.Unit.EPOCHS: - op.report_progress(self.state.epochs_trained) + assert isinstance(self.max_length.value, int) + + if isinstance(self.max_length, pytorch.Batch): + progress = self.state.batches_trained / self.max_length.value + elif isinstance(self.max_length, pytorch.Epoch): + progress = self.state.epochs_trained / self.max_length.value + else: + raise ValueError(f"unexpected train unit type {type(self.max_length)}") + + self.core_context.train.report_progress(progress=progress) def _checkpoint_is_current(self) -> bool: assert self.state @@ -535,12 +385,12 @@ def _validation_is_current(self) -> bool: # State persists validation step in batches return self.state.last_val == self.state.batches_trained - def _steps_until_complete(self, train_unit: TrainUnit) -> int: + def _steps_until_complete(self, train_unit: pytorch.TrainUnit) -> int: assert isinstance(train_unit.value, int), "invalid length type" assert self.state - if isinstance(train_unit, Batch): + if isinstance(train_unit, pytorch.Batch): return train_unit.value - self.state.batches_trained - elif isinstance(train_unit, Epoch): + elif isinstance(train_unit, pytorch.Epoch): return train_unit.value - self.state.epochs_trained else: raise ValueError(f"Unrecognized train unit {train_unit}") @@ -597,7 +447,7 @@ def cleanup_iterator() -> None: self._load(load_path) else: # If we are not loading, initialize a fresh state. - self.state = _TrialState(trial_id=self.trial_id) + self.state = pytorch._TrialState(trial_id=self.trial_id) if self.context.distributed.size > 1 and self.use_horovod: hvd = horovod.hvd @@ -615,7 +465,6 @@ def cleanup_iterator() -> None: self._run() def _run(self) -> None: - ops: Iterator[det.core.SearcherOperation] assert self.state try: @@ -626,48 +475,27 @@ def _run(self) -> None: ): self._validate() - if self.local_training: - assert self.max_length and isinstance(self.max_length.value, int) - ops = iter( - [ - det.core.DummySearcherOperation( - length=self.max_length.value, is_chief=self.is_chief - ) - ] - ) - else: - ops = self.core_context.searcher.operations() - - for op in ops: - if self.local_training: - train_unit = self.max_length - else: - train_unit = TrainUnit._from_searcher_unit( - op.length, self.searcher_unit, self.global_batch_size - ) - assert train_unit - - self._train_for_op( - op=op, - train_boundaries=[ - _TrainBoundary( - step_type=_TrainBoundaryType.TRAIN, - unit=train_unit, - ), - _TrainBoundary( - step_type=_TrainBoundaryType.VALIDATE, unit=self.validation_period - ), - _TrainBoundary( - step_type=_TrainBoundaryType.CHECKPOINT, - unit=self.checkpoint_period, - ), - # Scheduling unit is always configured in batches - _TrainBoundary( - step_type=_TrainBoundaryType.REPORT, unit=self.reporting_period - ), - ], - ) - except ShouldExit as e: + self._train( + length=pytorch.Batch(1) if self.test_mode else self.max_length, + train_boundaries=[ + pytorch._TrainBoundary( + step_type=pytorch._TrainBoundaryType.TRAIN, + unit=self.max_length, + ), + pytorch._TrainBoundary( + step_type=pytorch._TrainBoundaryType.VALIDATE, unit=self.validation_period + ), + pytorch._TrainBoundary( + step_type=pytorch._TrainBoundaryType.CHECKPOINT, + unit=self.checkpoint_period, + ), + # Scheduling unit is always configured in batches + pytorch._TrainBoundary( + step_type=pytorch._TrainBoundaryType.REPORT, unit=self.reporting_period + ), + ], + ) + except pytorch._ShouldExit as e: # Checkpoint unsaved work and exit. if not e.skip_exit_checkpoint and not self._checkpoint_is_current(): self._checkpoint(already_exiting=True) @@ -679,8 +507,8 @@ def _run(self) -> None: return def _train_with_boundaries( - self, training_enumerator: Iterator, train_boundaries: List[_TrainBoundary] - ) -> Tuple[List[_TrainBoundary], List]: + self, training_enumerator: Iterator, train_boundaries: List[pytorch._TrainBoundary] + ) -> Tuple[List[pytorch._TrainBoundary], List]: training_metrics = [] # Start of train step: tell core API and set model mode @@ -711,19 +539,19 @@ def _train_with_boundaries( # Batch complete: check if any training periods have been reached and exit if any for step in train_boundaries: - if isinstance(step.unit, Batch): + if isinstance(step.unit, pytorch.Batch): if step.unit.should_stop(batch_idx + 1): step.limit_reached = True # True epoch based training not supported, detect last batch of epoch to calculate # fully-trained epochs - if isinstance(step.unit, Epoch): + if isinstance(step.unit, pytorch.Epoch): if step.unit.should_stop(epoch_idx + 1): if batch_in_epoch_idx == epoch_len - 1: step.limit_reached = True # Break early after one batch for test mode - if step.step_type == _TrainBoundaryType.TRAIN and self.test_mode: + if step.step_type == pytorch._TrainBoundaryType.TRAIN and self.test_mode: step.limit_reached = True # Exit if any train step limits have been reached @@ -733,20 +561,10 @@ def _train_with_boundaries( # True epoch end return train_boundaries, training_metrics - def _train_for_op( - self, op: core.SearcherOperation, train_boundaries: List[_TrainBoundary] + def _train( + self, length: pytorch.TrainUnit, train_boundaries: List[pytorch._TrainBoundary] ) -> None: - if self.test_mode: - train_length = Batch(1) - elif self.local_training: - train_length = self.max_length # type: ignore - else: - train_length = TrainUnit._from_searcher_unit( - op.length, self.searcher_unit, self.global_batch_size - ) # type: ignore - assert train_length - - while self._steps_until_complete(train_length) > 0: + while self._steps_until_complete(length) > 0: train_boundaries, training_metrics = self._train_with_boundaries( self.training_enumerator, train_boundaries ) @@ -766,18 +584,18 @@ def _train_for_op( continue # Train step limits reached, proceed accordingly. - if train_boundary.step_type == _TrainBoundaryType.TRAIN: - if not op._completed and self.is_chief and not step_reported: - self._report_searcher_progress(op, self.searcher_unit) + if train_boundary.step_type == pytorch._TrainBoundaryType.TRAIN: + if self.is_chief and not step_reported: + self._report_training_progress() step_reported = True - elif train_boundary.step_type == _TrainBoundaryType.REPORT: - if not op._completed and self.is_chief and not step_reported: - self._report_searcher_progress(op, self.searcher_unit) + elif train_boundary.step_type == pytorch._TrainBoundaryType.REPORT: + if self.is_chief and not step_reported: + self._report_training_progress() step_reported = True - elif train_boundary.step_type == _TrainBoundaryType.VALIDATE: + elif train_boundary.step_type == pytorch._TrainBoundaryType.VALIDATE: if not self._validation_is_current(): - self._validate(op) - elif train_boundary.step_type == _TrainBoundaryType.CHECKPOINT: + self._validate() + elif train_boundary.step_type == pytorch._TrainBoundaryType.CHECKPOINT: if not self._checkpoint_is_current(): self._checkpoint(already_exiting=False) @@ -789,20 +607,12 @@ def _train_for_op( self._upload_tb_files() self._stop_requested() - # Finished training for op. Perform final checkpoint/validation if necessary. + # Finished training. Perform final checkpoint/validation if necessary. if not self._validation_is_current(): - self._validate(op) + self._validate() if not self._checkpoint_is_current(): self._checkpoint(already_exiting=False) - # Test mode will break after one batch despite not completing op. - if self.is_chief and not self.test_mode: - # The only case where op isn't reported as completed is if we restarted but - # op.length was already trained for and validated on; in that case just raise - # ShouldExit; we have nothing to do. - if not op._completed: - raise ShouldExit(skip_exit_checkpoint=True) - def _check_searcher_metric(self, val_metrics: Dict) -> Any: if self.searcher_metric_name not in val_metrics: raise RuntimeError( @@ -913,7 +723,7 @@ def _train_batch( return training_metrics @torch.no_grad() - def _validate(self, searcher_op: Optional[core.SearcherOperation] = None) -> Dict[str, Any]: + def _validate(self) -> Dict[str, Any]: # Report a validation step is starting. if self.is_chief: self.core_context.train.set_status("validating") @@ -1049,24 +859,14 @@ def _validate(self, searcher_op: Optional[core.SearcherOperation] = None) -> Dic # Get best validation before reporting metrics. best_validation_before = self.core_context.train.get_experiment_best_validation() - self.core_context.train.report_validation_metrics(self.state.batches_trained, metrics) - - searcher_metric = None + # We report "batch" and "epoch" only if these keys are not already reported in user + # metrics. + metrics["batches"] = metrics.get("batches", self.state.batches_trained) + metrics["epochs"] = metrics.get("epochs", self.state.epochs_trained) - # Report searcher status. - if self.is_chief and searcher_op: - if self.local_training: - searcher_length = self.max_length - else: - searcher_length = TrainUnit._from_searcher_unit( - searcher_op.length, self.searcher_unit, self.global_batch_size - ) - if self.searcher_metric_name: - searcher_metric = self._check_searcher_metric(metrics) - - assert searcher_length - if self._steps_until_complete(searcher_length) < 1 and not searcher_op._completed: - searcher_op.report_completed(searcher_metric) + self.core_context.train.report_validation_metrics( + steps_completed=self.state.batches_trained, metrics=metrics + ) should_checkpoint = False @@ -1079,6 +879,7 @@ def _validate(self, searcher_op: Optional[core.SearcherOperation] = None) -> Dic assert ( self.searcher_metric_name ), "checkpoint policy 'best' but searcher metric name not defined" + searcher_metric = self._check_searcher_metric(metrics) assert searcher_metric is not None if self._is_best_validation(now=searcher_metric, before=best_validation_before): @@ -1250,10 +1051,10 @@ def _load_state(self, state: Any) -> None: # If the trial_id doesn't match our current trial id, we're continuing training a previous # trial and should start from a fresh state. if state.get("trial_id") != self.trial_id: - self.state = _TrialState(trial_id=self.trial_id) + self.state = pytorch._TrialState(trial_id=self.trial_id) return - self.state = _TrialState(**state) + self.state = pytorch._TrialState(**state) assert self.state # Detect the case where the final validation we made was against this exact checkpoint. In @@ -1266,10 +1067,10 @@ def _load_state(self, state: Any) -> None: def _load_wlsq_state(self, state: Any) -> None: if state.get("trial_id") != self.trial_id: - self.state = _TrialState(trial_id=self.trial_id) + self.state = pytorch._TrialState(trial_id=self.trial_id) return - self.state = _TrialState( + self.state = pytorch._TrialState( trial_id=state.get("trial_id"), last_ckpt=state.get("last_ckpt"), last_val=state.get("last_val"), diff --git a/harness/determined/pytorch/_trainer.py b/harness/determined/pytorch/_trainer.py index 180b4ca214a..bdb51208e70 100644 --- a/harness/determined/pytorch/_trainer.py +++ b/harness/determined/pytorch/_trainer.py @@ -2,6 +2,7 @@ import logging import random import sys +import warnings from typing import Any, Dict, Iterator, Optional import numpy as np @@ -94,11 +95,16 @@ def fit( of ``collections.abc.Container`` (list, tuple, etc.). For example, ``Batch(100)`` would validate every 100 batches, while ``Batch([5, 30, 45])`` would validate after every 5th, 30th, and 45th batch. - max_length: The maximum number of steps to train for. This value is required and - only applicable in local training mode. For on-cluster training, this value will - be ignored; the searcher’s ``max_length`` must be configured from the experiment - configuration. This is a ``TrainUnit`` type (``Batch`` or ``Epoch``) which takes an - ``int``. For example, ``Epoch(1)`` would train for a maximum length of one epoch. + max_length: The maximum number of steps to train for. This is a ``TrainUnit`` type + (``Batch`` or ``Epoch``) which takes an ``int``. For example, ``Epoch(1)`` would + train for a maximum length of one epoch. + + .. note:: + + If using an ASHA searcher, this value should match the searcher config values in + the experiment config (i.e. ``Epoch(1)`` = `max_time: 1` and `time_metric: + "epochs"`). + reporting_period: The number of steps to train for before reporting metrics and searcher progress. For local training mode, metrics are printed to stdout. This is a ``TrainUnit`` type (``Batch`` or ``Epoch``) which can take an ``int`` or @@ -146,8 +152,13 @@ def fit( if max_length is None: raise ValueError("max_length must be defined in local training mode.") - if not isinstance(max_length.value, int): - raise TypeError("max_length must be configured in TrainUnit(int) types.") + if not isinstance(max_length, (pytorch.Batch, pytorch.Epoch)) or not isinstance( + max_length.value, int + ): + raise TypeError( + "max_length must either be a det.pytorch.Batch(int) or det.pytorch.Epoch(int) " + "type" + ) if profiling_enabled: logger.warning("Profiling is not supported in local training mode.") @@ -160,12 +171,6 @@ def fit( if test_mode: raise ValueError("test_mode is only supported in local training mode.") - if max_length is not None: - logger.warning( - "max_length is ignored when training on-cluster. Please configure the " - "searcher instead." - ) - assert self._info, "Unable to detect cluster info." if latest_checkpoint is None and self._info.latest_checkpoint is not None: logger.warning( @@ -176,11 +181,45 @@ def fit( smaller_is_better = bool(self._info.trial._config["searcher"]["smaller_is_better"]) searcher_metric_name = self._info.trial._config["searcher"]["metric"] + steps_completed = int(self._info.trial._steps_completed) global_batch_size = self._info.trial.hparams.get("global_batch_size", None) if global_batch_size: global_batch_size = int(global_batch_size) + # Backwards compatibility: try to parse legacy `searcher.max_length` if `max_length` + # isn't passed in. + if max_length is None: + max_length_val = core._parse_searcher_max_length(self._info.trial._config) + if max_length_val: + warnings.warn( + "Configuring `max_length` from the `searcher.max_length` experiment " + "config, which was deprecated in 0.38.0 and will be removed in a future " + "release. Please set `fit(max_length=X)` with your desired training length " + "directly.", + FutureWarning, + stacklevel=2, + ) + max_length_unit = core._parse_searcher_units(self._info.trial._config) + max_length = pytorch.TrainUnit._from_searcher_unit( + max_length_val, max_length_unit, global_batch_size + ) + + # If we couldn't parse the legacy `searcher.max_length`, raise an error. + if not max_length: + raise ValueError( + "`fit(max_length=X)` must be set with your desired training length." + ) + if not isinstance(max_length, (pytorch.Batch, pytorch.Epoch)) or not isinstance( + max_length.value, int + ): + raise TypeError( + "max_length must either be a det.pytorch.Batch(int) or det.pytorch.Epoch(int) " + "type." + ) + + _check_searcher_length(exp_conf=self._info.trial._config, max_length=max_length) + trial_controller = pytorch._PyTorchTrialController( trial_inst=self._trial, context=self._context, @@ -203,6 +242,43 @@ def fit( trial_controller.run() +def _check_searcher_length( + exp_conf: Dict[str, Any], + max_length: pytorch.TrainUnit, +) -> None: + """ + Certain searchers (ASHA and Adaptive ASHA) require configuring the maximum training length in + the experiment config. We check that the `max_length` passed to `fit()` matches the experiment + config and log warnings if it doesn't. + """ + time_metric = exp_conf["searcher"].get("time_metric") + if time_metric is not None: + max_time = exp_conf["searcher"].get("max_time") + assert max_time, "`searcher.max_time` not configured" + if time_metric == "batches": + if not isinstance(max_length, pytorch.Batch) or max_length.value != max_time: + logger.warning( + f"`max_length` passed into `fit()` method ({max_length}) does not match " + f"`searcher.max_time` and `searcher.time_metric` from the experiment config " + f"(Batch(value={max_time})). This may result in unexpected hyperparameter " + f"search behavior." + ) + elif time_metric == "epochs": + if not isinstance(max_length, pytorch.Epoch) or max_length.value != max_time: + logger.warning( + f"`max_length` passed into `fit()` method ({max_length}) does not match " + f"`searcher.max_time` and `searcher.time_metric` from the experiment config " + f"(Epoch(value={max_time})). This may result in unexpected hyperparameter " + f"search behavior." + ) + else: + logger.warning( + "`searcher.time_metric` must be either 'batches' or 'epochs' " + f"for training with PyTorchTrials, but got {time_metric}. " + f"Training will proceed with {max_length} but may result in unexpected behavior." + ) + + def _initialize_distributed_backend() -> Optional[core.DistributedContext]: info = det.get_cluster_info() diff --git a/harness/determined/pytorch/_trainer_utils.py b/harness/determined/pytorch/_trainer_utils.py new file mode 100644 index 00000000000..254fad6e150 --- /dev/null +++ b/harness/determined/pytorch/_trainer_utils.py @@ -0,0 +1,145 @@ +import enum +import sys +from collections import abc +from typing import Optional, Union + +from determined import core + + +class TrainUnit: + """ + TrainUnit is the base class for the supported training units (Batch, Epoch) containing + the value of unit, where the value can be an int or an implementable collections.abc.Container. + + TrainUnits are used to define periodic training behavior such as checkpointing and validating. + + int values are treated as periods, e.g. Batch(100) will checkpoint/validate every 100 batches. + collections.abc.Container values are treated as schedules, e.g. Batch(1,5,10) will + checkpoint/validate on batches 1, 5, and 10. + """ + + def __init__(self, value: Union[int, abc.Container]): + self.value = value + + @staticmethod + def _from_searcher_unit( + length: int, unit: Optional[core.Unit], global_batch_size: Optional[int] = None + ) -> "TrainUnit": + if unit == core.Unit.EPOCHS: + return Epoch(length) + elif unit == core.Unit.RECORDS: + if global_batch_size is None: + raise ValueError("global_batch_size required for searcher unit Records.") + return Batch._from_records(length, global_batch_size) + elif unit == core.Unit.BATCHES: + return Batch(length) + else: + raise ValueError(f"unrecognized searcher unit {unit}") + + def _to_searcher_unit(self) -> core.Unit: + if isinstance(self, Batch): + return core.Unit.BATCHES + return core.Unit.EPOCHS + + @staticmethod + def _from_values( + batches: Optional[int] = None, + records: Optional[int] = None, + epochs: Optional[int] = None, + global_batch_size: Optional[int] = None, + ) -> "TrainUnit": + if sum((batches is not None, records is not None, epochs is not None)) != 1: + raise ValueError(f"invalid config: batches={batches} records={records} epochs={epochs}") + if batches is not None: + if batches < 1: + batches = sys.maxsize + return Batch(batches) + if records is not None: + assert global_batch_size, "global_batch_size is required for RECORD units." + if records < 1: + records = sys.maxsize + return Batch._from_records(records, global_batch_size) + if epochs is not None: + if epochs < 1: + epochs = sys.maxsize + return Epoch(epochs) + + # Make mypy happy + raise ValueError("invalid values") + + def should_stop(self, step_num: int) -> bool: + if isinstance(self.value, int): + return self._divides(step_num) + assert isinstance(self.value, abc.Container) + return step_num in self.value + + def _divides(self, steps: int) -> bool: + assert isinstance(steps, int) and isinstance( + self.value, int + ), "_divides can only be called on int types." + # Treat <= 0 values as always step + if self.value < 1: + return True + if steps == 0: + return False + return steps % self.value == 0 + + +class Epoch(TrainUnit): + """ + Epoch step type (e.g. Epoch(1) defines 1 epoch) + """ + + pass + + +class Batch(TrainUnit): + """ + Batch step type (e.g. Batch(1) defines 1 batch) + """ + + @staticmethod + def _from_records(records: int, global_batch_size: int) -> "Batch": + return Batch(max(records // global_batch_size, 1)) + + +class _ShouldExit(Exception): + """ + ShouldExit breaks out of the top-level train loop from inside function calls. + """ + + def __init__(self, skip_exit_checkpoint: bool = False): + self.skip_exit_checkpoint = skip_exit_checkpoint + + +class _TrialState: + def __init__( + self, + trial_id: int = 0, + last_ckpt: int = 0, + step_id: int = 0, + last_val: int = 0, + batches_trained: int = 0, + epochs_trained: int = 0, + ) -> None: + # Store TrialID to distinguish between e.g. pause/restart and continue training. + self.trial_id = trial_id + self.last_ckpt = last_ckpt + self.step_id = step_id + self.last_val = last_val + self.batches_trained = batches_trained + self.epochs_trained = epochs_trained + + +class _TrainBoundaryType(enum.Enum): + CHECKPOINT = "CHECKPOINT" + REPORT = "REPORT" + VALIDATE = "VALIDATE" + TRAIN = "TRAIN" + + +class _TrainBoundary: + def __init__(self, step_type: _TrainBoundaryType, unit: TrainUnit): + self.step_type = step_type + self.unit = unit + self.limit_reached = False diff --git a/harness/determined/pytorch/deepspeed/__init__.py b/harness/determined/pytorch/deepspeed/__init__.py index 46b40dc66f7..62cb79dfaaf 100644 --- a/harness/determined/pytorch/deepspeed/__init__.py +++ b/harness/determined/pytorch/deepspeed/__init__.py @@ -8,3 +8,4 @@ overwrite_deepspeed_config, ) from determined.pytorch.deepspeed._deepspeed_trial import DeepSpeedTrial, DeepSpeedTrialController +from determined.pytorch.deepspeed._trainer import init, Trainer diff --git a/harness/determined/pytorch/deepspeed/_deepspeed_context.py b/harness/determined/pytorch/deepspeed/_deepspeed_context.py index dbb80c7f651..b71f44e31da 100644 --- a/harness/determined/pytorch/deepspeed/_deepspeed_context.py +++ b/harness/determined/pytorch/deepspeed/_deepspeed_context.py @@ -1,5 +1,6 @@ import json import logging +import pathlib import time from importlib import util as importutil from typing import Any, Dict, List, Optional, Set, Type, Union, cast @@ -42,7 +43,7 @@ def overwrite_deepspeed_config( return util.merge_dicts(cast(Dict[str, Any], base_ds_config), source_ds_dict) -class DeepSpeedTrialContext(det.TrialContext, pytorch._PyTorchReducerContext): +class DeepSpeedTrialContext(pytorch._PyTorchReducerContext): """Contains runtime information for any Determined workflow that uses the ``DeepSpeedTrial`` API. @@ -65,10 +66,38 @@ class DeepSpeedTrialContext(det.TrialContext, pytorch._PyTorchReducerContext): 5. Disable automatic gradient aggregation for non-pipeline-parallel training. """ - def __init__(self, *args: Any, **kwargs: Any) -> None: - det.TrialContext.__init__(self, *args, **kwargs) + def __init__( + self, + core_context: det.core.Context, + trial_seed: Optional[int], + hparams: Optional[Dict], + slots_per_trial: int, + num_gpus: int, + exp_conf: Optional[Dict[str, Any]], + steps_completed: int, + enable_tensorboard_logging: bool = True, + ) -> None: + self._core = core_context + self.distributed = self._core.distributed + pytorch._PyTorchReducerContext.__init__(self, self.distributed.allgather) + self._per_slot_batch_size, self._global_batch_size = ( + util.calculate_batch_sizes( + hparams=hparams, + slots_per_trial=slots_per_trial, + trialname="DeepSpeedTrial", + ) + if hparams and hparams.get("global_batch_size", None) + else (None, None) + ) + self._hparams = hparams + self._num_gpus = num_gpus + self._exp_conf = exp_conf + + self._trial_seed = trial_seed + self._steps_completed = steps_completed + self._init_device() # Track which types we have issued warnings for in to_device(). @@ -85,14 +114,13 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # The following attributes are initialized during the lifetime of # a DeepSpeedTrialContext. self.models = [] # type: List[deepspeed.DeepSpeedEngine] + self.profiler = None # type: Any self._epoch_len = None # type: Optional[int] self._loss_ids = {} # type: Dict[torch.Tensor, int] self._last_backward_batch_idx = None # type: Optional[int] self._current_batch_idx = None # type: Optional[int] - self.profiler = None # type: Any - self._mpu = det_ds.make_data_parallel_mpu( self.distributed ) # type: det_ds.ModelParallelUnit @@ -103,48 +131,13 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._data_repro_checks_disabled = False self._manual_grad_accumulation = False - self._check_experiment_config_optimizations() + self._stop_requested = False self._tbd_writer = None # type: Optional[Any] - self._enable_tensorboard_logging = True + self._enable_tensorboard_logging = enable_tensorboard_logging # Timestamp for batching TensorBoard uploads self._last_tb_reset_ts: Optional[float] = None - def _check_experiment_config_optimizations(self) -> None: - """ - Check if the user specified options in optimizations are incompatible with - DeepSpeedTrial. - """ - optimizations_config = self.env.experiment_config.get_optimizations_config() - self._average_training_metrics = optimizations_config.get("average_training_metrics", False) - - mixed_precision_val = optimizations_config.get("mixed_precision", "O0") - if mixed_precision_val != "O0": - raise det.errors.InvalidExperimentException( - "Mixed precision is specified through the deepspeed config instead of the " - "Determined experiment config.", - ) - aggregation_frequency = optimizations_config.get("aggregation_frequency", 1) - if aggregation_frequency > 1: - raise det.errors.InvalidExperimentException( - "Gradient aggregation is specified through the deepspeed config instead of the " - "Determined experiment config.", - ) - other_optimizations_default_values = { - "average_aggregated_gradients": True, - "gradient_compression": False, - "tensor_fusion_threshold": 64, - "tensor_fusion_cycle_time": 5, - "autotune_tensor_fusion": False, - } - for opt_field, default_value in other_optimizations_default_values.items(): - opt_value = optimizations_config.get(opt_field, default_value) - if opt_value != default_value: - logger.warning( - f"{opt_field}={opt_value} ignored since the setting does not apply " - "to DeepSpeedTrial." - ) - def set_mpu(self, mpu: det_ds.ModelParallelUnit) -> None: """Use a custom model parallel configuration. @@ -166,12 +159,6 @@ def set_mpu(self, mpu: det_ds.ModelParallelUnit) -> None: "Only one MPU can be passed to DeepSpeedTrialContext. " "Please make sure wrap_mpu is only called once in the trial definition." ) - if self.distributed.rank == 0: - if not self._mpu.should_report_metrics and not self._average_training_metrics: - raise det.errors.InvalidExperimentException( - "Please set optimizations.average_training_metrics in the experiment config " - "to true so that metrics will exist on the chief for report to the master." - ) self._called_set_mpu = True self._mpu = mpu @@ -245,16 +232,14 @@ def disable_dataset_reproducibility_checks(self) -> None: def use_pipeline_parallel(self) -> bool: return self._use_pipeline_parallel - @property - def train_micro_batch_size_per_gpu(self) -> int: + def get_train_micro_batch_size_per_gpu(self) -> int: if self._train_micro_batch_size_per_gpu is None: raise det.errors.InvalidExperimentException( "Please call wrap_model_engine before accessing train_micro_batch_size." ) return self._train_micro_batch_size_per_gpu - @property - def num_micro_batches_per_slot(self) -> int: + def get_num_micro_batches_per_slot(self) -> int: if self._num_micro_batches_per_slot is None: raise det.errors.InvalidExperimentException( "Please call wrap_model_engine before accessing num_micro_batches_per_slot." @@ -262,8 +247,7 @@ def num_micro_batches_per_slot(self) -> int: return self._num_micro_batches_per_slot def _init_device(self) -> None: - self.n_gpus = len(self.env.container_gpus) - if not self.n_gpus: + if not self._num_gpus: raise det.errors.InvalidExperimentException("GPUs required for DeepSpeedTrial.") if self.distributed.size > 1: self.device = torch.device("cuda", self.distributed.get_local_rank()) @@ -359,6 +343,38 @@ def set_profiler(self, *args: List[str], **kwargs: Any) -> None: **kwargs, ) + def get_initial_batch(self) -> int: + return self._steps_completed + + def get_data_config(self) -> Dict[str, Any]: + """ + Return the data configuration. + """ + return self.get_experiment_config().get("data", {}) + + def get_experiment_id(self) -> int: + """ + Return the experiment ID of the current trial. + """ + return int(self._core.train._exp_id) + + def get_trial_id(self) -> int: + """ + Return the trial ID of the current trial. + """ + return int(self._core.train._trial_id) + + def get_trial_seed(self) -> int: + if self._trial_seed is None: + raise det.errors.InternalException("Trial seed not set.") + return self._trial_seed + + def get_tensorboard_path(self) -> pathlib.Path: + """ + Get the path where files for consumption by TensorBoard should be written + """ + return self._core.train.get_tensorboard_path() + def get_tensorboard_writer(self) -> Any: """ This function returns an instance of ``torch.utils.tensorboard.SummaryWriter`` @@ -442,3 +458,86 @@ def get_enable_tensorboard_logging(self) -> bool: Return whether automatic tensorboard logging is enabled """ return self._enable_tensorboard_logging + + def get_global_batch_size(self) -> int: + """ + Return the global batch size. + """ + if self._global_batch_size is None: + raise ValueError( + "global_batch_size is undefined in this Trial because hparams was not " + "configured. Please check the init() call to Trainer API." + ) + return self._global_batch_size + + def get_per_slot_batch_size(self) -> int: + """ + Return the per-slot batch size. When a model is trained with a single GPU, this is equal to + the global batch size. When multi-GPU training is used, this is equal to the global batch + size divided by the number of GPUs used to train the model. + """ + if self._per_slot_batch_size is None: + raise ValueError( + "per_slot_batch_size is undefined in this Trial because hparams was not " + "configured. Please check the init() call to Trainer API." + ) + + return self._per_slot_batch_size + + def get_experiment_config(self) -> Dict[str, Any]: + if self._exp_conf is None: + raise ValueError( + "exp_conf is undefined in this Trial. Please check the init() call to Trainer API." + ) + return self._exp_conf + + def get_hparam(self, name: str) -> Any: + """ + Return the current value of the hyperparameter with the given name. + """ + if self._hparams is None: + raise ValueError( + "hparams is undefined in this Trial because hparams was not " + "configured. Please check the init() call to Trainer API." + ) + if name not in self.get_hparams(): + raise ValueError( + "Could not find name '{}' in experiment " + "hyperparameters. Please check your experiment " + "configuration 'hyperparameters' section.".format(name) + ) + if name == "global_batch_size": + logger.warning( + "Please use `context.get_per_slot_batch_size()` and " + "`context.get_global_batch_size()` instead of accessing " + "`global_batch_size` directly." + ) + return self.get_hparams()[name] + + def get_hparams(self) -> Dict[str, Any]: + if self._hparams is None: + raise ValueError( + "hparams is undefined in this Trial because hparams was not " + "configured. Please check the init() call to Trainer API." + ) + return self._hparams + + def get_stop_requested(self) -> bool: + """ + Return whether a trial stoppage has been requested. + """ + return self._stop_requested + + def set_stop_requested(self, stop_requested: bool) -> None: + """ + Set a flag to request a trial stoppage. When this flag is set to True, + we finish the step, checkpoint, then exit. + """ + if not isinstance(stop_requested, bool): + raise AssertionError("stop_requested must be a boolean") + + logger.info( + "A trial stoppage has requested. The trial will be stopped " + "at the end of the current step." + ) + self._stop_requested = stop_requested diff --git a/harness/determined/pytorch/deepspeed/_deepspeed_trial.py b/harness/determined/pytorch/deepspeed/_deepspeed_trial.py index f59152e0ea0..8c8d3f5d599 100644 --- a/harness/determined/pytorch/deepspeed/_deepspeed_trial.py +++ b/harness/determined/pytorch/deepspeed/_deepspeed_trial.py @@ -1,5 +1,7 @@ import abc import contextlib +import inspect +import json import logging import os import pathlib @@ -7,7 +9,7 @@ import random import time import warnings -from typing import Any, Callable, Dict, Iterator, List, Optional, Type, Union, cast +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Type, Union, cast import deepspeed import numpy as np @@ -15,9 +17,8 @@ from deepspeed.runtime import dataloader as ds_loader import determined as det -from determined import layers, pytorch, util, workload +from determined import core, pytorch, tensorboard, util from determined.pytorch import deepspeed as det_ds -from determined.pytorch import dsat logger = logging.getLogger("determined.pytorch") @@ -31,24 +32,48 @@ def get_length(self: ds_loader.RepeatingLoader) -> int: return len(self.loader) -ds_loader.RepeatingLoader.__len__ = get_length +def dataloader_next(dataloader_iter: Optional[Iterator]) -> Iterator: + if dataloader_iter is None: + return None + while True: + try: + batch = next(dataloader_iter) + except StopIteration: + return + yield batch + +ds_loader.RepeatingLoader.__len__ = get_length -class DeepSpeedTrialController(det.TrialController): - def __init__(self, trial_inst: det.LegacyTrial, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) +class DeepSpeedTrialController: + def __init__( + self, + trial_inst: det.LegacyTrial, + context: det_ds.DeepSpeedTrialContext, + checkpoint_period: pytorch.TrainUnit, + validation_period: pytorch.TrainUnit, + reporting_period: pytorch.TrainUnit, + smaller_is_better: bool, + steps_completed: int, + latest_checkpoint: Optional[str], + local_training: bool, + test_mode: bool, + searcher_metric_name: Optional[str], + checkpoint_policy: str, + step_zero_validation: bool, + max_length: pytorch.TrainUnit, + global_batch_size: Optional[int], + profiling_enabled: Optional[bool], + ) -> None: assert isinstance( trial_inst, DeepSpeedTrial ), "DeepSpeedTrialController needs a DeepSpeedTrial" self.trial = trial_inst - self.context = cast(det_ds.DeepSpeedTrialContext, self.context) - self._dsat_mode = self.context.get_hparams().get(dsat.defaults.USE_DSAT_MODE_KEY, False) - if self._dsat_mode: - searcher_name = self.context.get_experiment_config()["searcher"]["name"] - assert ( - searcher_name == "custom" - ), "`_dsat_mode` can only be set to true for Custom Searcher trials." + self.context = context + self.core_context = self.context._core + + self.is_chief = self.context.distributed.rank == 0 self.callbacks = self.trial.build_callbacks() for callback in self.callbacks.values(): @@ -66,18 +91,35 @@ def __init__(self, trial_inst: det.LegacyTrial, *args: Any, **kwargs: Any) -> No "This might be caused by not wrapping your model with wrap_model_engine()" ) - self.wlsq = None # type: Optional[layers.WorkloadSequencer] - if self.workloads is None: - self.workloads, self.wlsq = layers.make_compatibility_workloads( - self.context._core, self.env, self.context.models[0].train_batch_size() - ) - - self.steps_completed = self.env.steps_completed + # Don't initialize the state here because it will be invalid until we load a checkpoint. + self.state = None # type: Optional[pytorch._TrialState] + self.start_from_batch = steps_completed + self.val_from_previous_run = self.core_context.train._get_last_validation() + self.step_zero_validation = step_zero_validation + + # Training configs + self.latest_checkpoint = latest_checkpoint + self.test_mode = test_mode + self.searcher_metric_name = searcher_metric_name + self.checkpoint_policy = checkpoint_policy + self.smaller_is_better = smaller_is_better + self.global_batch_size = global_batch_size + self.profiling_enabled = profiling_enabled + + # Training loop variables + self.max_length = max_length + self.checkpoint_period = checkpoint_period + self.validation_period = validation_period + self.reporting_period = reporting_period + + # Training loop state + self.local_training = local_training + self.trial_id = 0 if local_training else self.core_context.train._trial_id @classmethod def pre_execute_hook( cls: Type["DeepSpeedTrialController"], - env: det.EnvContext, + trial_seed: int, distributed_backend: det._DistributedBackend, ) -> None: # We use an environment variable to allow users to enable custom initialization routine for @@ -94,18 +136,19 @@ def pre_execute_hook( # training batch. # TODO (Liam): seed data loading workers so that we can configure different seeds for # data augmentations per slot per worker. - random.seed(env.trial_seed) - np.random.seed(env.trial_seed) - torch.random.manual_seed(env.trial_seed) - - @classmethod - def from_trial( - cls: Type["DeepSpeedTrialController"], *args: Any, **kwargs: Any - ) -> det.TrialController: - return cls(*args, **kwargs) + random.seed(trial_seed) + np.random.seed(trial_seed) + torch.random.manual_seed(trial_seed) + + def _upload_tb_files(self) -> None: + self.context._maybe_reset_tbd_writer() + self.core_context.train.upload_tensorboard_files( + (lambda _: True) if self.is_chief else (lambda p: not p.match("*tfevents*")), + tensorboard.util.get_rank_aware_path, + ) def _set_data_loaders(self) -> None: - skip_batches = self.env.steps_completed + skip_batches = self.start_from_batch # Training and validation data loaders are not built for every slot when model parallelism # is used. @@ -151,14 +194,14 @@ def _set_data_loaders(self) -> None: ) if self.context.use_pipeline_parallel: - if len(self.validation_loader) < self.context.num_micro_batches_per_slot: + if len(self.validation_loader) < self.context.get_num_micro_batches_per_slot(): raise det.errors.InvalidExperimentException( "Number of train micro batches in validation data loader should not be " "less than the number of gradient accumulation steps when using " "pipeline parallelism." ) excluded_micro_batches = ( - len(validation_data) % self.context.num_micro_batches_per_slot + len(validation_data) % self.context.get_num_micro_batches_per_slot() ) if excluded_micro_batches: logger.warning( @@ -189,9 +232,9 @@ def _set_data_loaders(self) -> None: if self.context.use_pipeline_parallel: self.num_validation_batches = ( - self.num_validation_batches // self.context.num_micro_batches_per_slot + self.num_validation_batches // self.context.get_num_micro_batches_per_slot() ) - self.validation_batch_size *= self.context.num_micro_batches_per_slot + self.validation_batch_size *= self.context.get_num_micro_batches_per_slot() # We will do a gather on to get train and val loader lengths and broadcast to all slots. self.context._epoch_len = ( @@ -199,28 +242,34 @@ def _set_data_loaders(self) -> None: ) all_epoch_lens = self.context.distributed.gather(self.context._epoch_len) if self.is_chief: - all_epoch_lens = [le for le in all_epoch_lens if le is not None] + all_epoch_lens = [le for le in all_epoch_lens if le is not None] # type: ignore if min(all_epoch_lens) < max(all_epoch_lens): logger.warning( "Training data loader length inconsistent across ranks. " "Using the minimum for epoch length." ) - self.context._epoch_len = min(all_epoch_lens) // self.context.num_micro_batches_per_slot + self.context._epoch_len = ( + min(all_epoch_lens) // self.context.get_num_micro_batches_per_slot() + ) self.context._epoch_len = self.context.distributed.broadcast(self.context._epoch_len) all_tuples = self.context.distributed.gather( (self.num_validation_batches, self.validation_batch_size) ) if self.is_chief: - all_num_validation_batches, all_validation_batch_size = zip(*all_tuples) - all_num_validation_batches = [le for le in all_num_validation_batches if le is not None] + all_num_validation_batches, all_validation_batch_size = zip(*all_tuples) # type: ignore + all_num_validation_batches = [ + le for le in all_num_validation_batches if le is not None + ] # type: ignore if min(all_num_validation_batches) < max(all_num_validation_batches): logger.warning( "Validation data loader length inconsistent across ranks. " "Using the minimum for validation length." ) self.num_validation_batches = min(all_num_validation_batches) - all_validation_batch_size = [le for le in all_validation_batch_size if le is not None] + all_validation_batch_size = [ + le for le in all_validation_batch_size if le is not None + ] # type: ignore if min(all_validation_batch_size) < max(all_validation_batch_size): logger.warning( "Validation batch size inconsistent across ranks. " @@ -251,7 +300,7 @@ def on_shutdown(callback_name: str, on_trial_shutdown: Callable) -> None: with contextlib.ExitStack() as exit_stack: for callback in self.callbacks.values(): - callback.on_trial_startup(self.steps_completed, self.env.latest_checkpoint) + callback.on_trial_startup(self.start_from_batch, self.latest_checkpoint) exit_stack.enter_context( defer(on_shutdown, callback.__class__.__name__, callback.on_trial_shutdown) ) @@ -271,19 +320,22 @@ def on_shutdown(callback_name: str, on_trial_shutdown: Callable) -> None: ) def cleanup_iterator() -> None: - # Explicitly trigger the training iterator's shutdown (which happens in __del__). + # Explicitly trigger the iterator's shutdown (which happens in __del__). # See the rather long note in pytorch/torch/utils/data/dataloader.py. del self.training_iterator exit_stack.enter_context(defer(cleanup_iterator)) # If a load path is provided load weights and restore the data location. - if self.env.latest_checkpoint is not None: - logger.info(f"Restoring trial from checkpoint {self.env.latest_checkpoint}") + if self.latest_checkpoint is not None: + logger.info(f"Restoring trial from checkpoint {self.latest_checkpoint}") with self.context._core.checkpoint.restore_path( - self.env.latest_checkpoint + self.latest_checkpoint ) as load_path: self._load(load_path) + else: + # If we are not loading, initialize a fresh state. + self.state = pytorch._TrialState(trial_id=self.trial_id) for callback in self.callbacks.values(): callback.on_training_start() @@ -295,184 +347,247 @@ def cleanup_iterator() -> None: self._run() def _run(self) -> None: - # Special code path only used for DeepSpeed Autotuning. - if self._dsat_mode: - ops = self.context._core.searcher.operations() - op = next(ops) - for _ in range(op.length): - with dsat.dsat_reporting_context(core_context=self.context._core, op=op): - _ = self._train_for_step( - step_id=self.steps_completed + 1, - num_batches=1, - total_batches_processed=self.steps_completed, - ) + assert self.state + + try: + if ( + self.step_zero_validation + and self.val_from_previous_run is None + and self.state.batches_trained == 0 + ): + self._validate() + + self._train( + length=pytorch.Batch(1) if self.test_mode else self.max_length, + train_boundaries=[ + pytorch._TrainBoundary( + step_type=pytorch._TrainBoundaryType.TRAIN, + unit=self.max_length, + ), + pytorch._TrainBoundary( + step_type=pytorch._TrainBoundaryType.VALIDATE, unit=self.validation_period + ), + pytorch._TrainBoundary( + step_type=pytorch._TrainBoundaryType.CHECKPOINT, + unit=self.checkpoint_period, + ), + # Scheduling unit is always configured in batches + pytorch._TrainBoundary( + step_type=pytorch._TrainBoundaryType.REPORT, unit=self.reporting_period + ), + ], + ) + except pytorch._ShouldExit as e: + # Checkpoint unsaved work and exit. + if not e.skip_exit_checkpoint and not self._checkpoint_is_current(): + self._checkpoint(already_exiting=True) - assert self.workloads is not None - for w, response_func in self.workloads: - try: - if w.kind == workload.Workload.Kind.RUN_STEP: - action = "training" - metrics = self._train_for_step( - w.step_id, - w.num_batches, - w.total_batches_processed, - ) - response = { - "metrics": metrics, - "stop_requested": self.context.get_stop_requested(), - } # type: workload.Response - metrics = self.context.distributed.broadcast(metrics) - for callback in self.callbacks.values(): - callback.on_training_workload_end( - avg_metrics=metrics["avg_metrics"], - batch_metrics=metrics["batch_metrics"], - ) - elif w.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS: - action = "validation" - response = { - "metrics": self._compute_validation_metrics(), - "stop_requested": self.context.get_stop_requested(), - } - elif w.kind == workload.Workload.Kind.CHECKPOINT_MODEL: - action = "checkpointing" - metadata = { - "steps_completed": self.steps_completed, - "framework": f"torch-{torch.__version__}", - "format": "pickle", - } - with self.context._core.checkpoint.store_path(metadata, shard=True) as ( - path, - storage_id, - ): - self._save(path) - response = {"uuid": storage_id} - for callback in self.callbacks.values(): - callback.on_checkpoint_upload_end(uuid=storage_id) - else: - raise AssertionError("Unexpected workload: {}".format(w.kind)) + except det.InvalidHP as e: + # Catch InvalidHP to checkpoint before exiting and re-raise for cleanup by core.init() + if not self._checkpoint_is_current(): + self._checkpoint(already_exiting=True) + raise e - except det.InvalidHP as e: - logger.info(f"Invalid hyperparameter exception during {action}: {e}") - response = workload.InvalidHP() - response_func(response) - self.context._maybe_reset_tbd_writer() - self.upload_tb_files() + return - def get_epoch_idx(self, batch_id: int) -> int: + def _get_epoch_idx(self, batch_id: int) -> int: return batch_id // cast(int, self.context._epoch_len) - def _train_for_step( - self, step_id: int, num_batches: int, total_batches_processed: int - ) -> workload.Metrics: - """ - DeepSpeed allows specifying train_batch_size, train_micro_batch_size_per_gpu, and - gradient_accumulation_steps. The three are related as follows: - train_batch_size = train_micro_batch_size * gradient_accumulation_steps. - Hence, if two are specified, the third can be inferred. - - For pipeline parallel training, DeepSpeed will automatically interleave - gradient_accumulation_steps worth of micro batches in one train_batch/eval_batch call. - - With the default DeepSpeed model engine (no pipeline parallel training), the backward - and optimizer step calls track micro batches and will automatically update model weights - and lr scheduler if micro batches % gradient_accumulation_steps == 0. - - Comparing training with and without pipeline parallel is a common goal. Since DeepSpeed's - PipelineEngine trains on a number of micro batches equal to gradient accumulation steps, - we automatically perform gradient accumulation by default when pipeline parallelism is not - enabled. This makes it fair to compare training with and without pipeline parallelism - at a given batch idx. This can be turned off by setting - context.disable_auto_grad_accumulation. - """ - assert step_id > 0, "step_id should be greater than 0" - step_start_time = time.time() - self.context.reset_reducers() + def _train( + self, length: pytorch.TrainUnit, train_boundaries: List[pytorch._TrainBoundary] + ) -> None: + while self._steps_until_complete(length) > 0: + train_boundaries, training_metrics = self._train_with_boundaries(train_boundaries) + + metrics = self._aggregate_training_metrics(training_metrics) + metrics = self.context.distributed.broadcast(metrics) + for callback in self.callbacks.values(): + callback.on_training_workload_end( + avg_metrics=metrics["avg_metrics"], + batch_metrics=metrics["batch_metrics"], + ) + + step_reported = False + + for train_boundary in train_boundaries: + if not train_boundary.limit_reached: + continue + + # Train step limits reached, proceed accordingly. + if train_boundary.step_type == pytorch._TrainBoundaryType.TRAIN: + if self.is_chief and not step_reported: + self._report_training_progress() + elif train_boundary.step_type == pytorch._TrainBoundaryType.REPORT: + if self.is_chief and not step_reported: + self._report_training_progress() + elif train_boundary.step_type == pytorch._TrainBoundaryType.VALIDATE: + if not self._validation_is_current(): + self._validate() + elif train_boundary.step_type == pytorch._TrainBoundaryType.CHECKPOINT: + if not self._checkpoint_is_current(): + self._checkpoint(already_exiting=False) + + # Reset train step limit + train_boundary.limit_reached = False + + # After checkpoint/validation steps, check preemption and upload to tensorboard + if self.context.get_enable_tensorboard_logging(): + self._upload_tb_files() + self._stop_requested() + + # Finished training. Perform final checkpoint/validation if necessary. + if not self._validation_is_current(): + self._validate() + if not self._checkpoint_is_current(): + self._checkpoint(already_exiting=False) + + def _train_with_boundaries( + self, train_boundaries: List[pytorch._TrainBoundary] + ) -> Tuple[List[pytorch._TrainBoundary], List]: + training_metrics = [] + + # Start of train step: tell core API and set model mode + if self.is_chief: + self.core_context.train.set_status("training") - # Set the behavior of certain layers (e.g., dropout) that are different - # between training and inference. for model in self.context.models: model.train() - start = total_batches_processed - end = start + num_batches + self.context.reset_reducers() - per_batch_metrics = [] # type: List[Dict] - num_inputs = 0 + epoch_len = self.context._epoch_len + assert epoch_len, "Training dataloader uninitialized." - for batch_idx in range(start, end): - self.steps_completed += 1 - batch_start_time = time.time() + for batch_idx in range(epoch_len): + epoch_idx, batch_in_epoch_idx = divmod(batch_idx, epoch_len) + + # Set the batch index on the trial context used by step_optimizer. self.context._current_batch_idx = batch_idx - if self.context.is_epoch_start(): - for callback in self.callbacks.values(): - callback.on_training_epoch_start(self.get_epoch_idx(batch_idx)) - # This can be inaccurate if the user's data loader does not return batches with - # the micro batch size. It is also slightly inaccurate if the data loader can return - # partial batches. The same sort of assumptions is made in the DeepSpeed - # model engine's accounting and profiling computations. - batch_inputs = ( - self.context.train_micro_batch_size_per_gpu - * self.context.num_micro_batches_per_slot - ) - num_inputs += batch_inputs - num_train_batch_calls = self.context.num_micro_batches_per_slot - if self.context.use_pipeline_parallel or self.context._manual_grad_accumulation: - num_train_batch_calls = 1 - self.context._loss_ids = {} - for _ in range(num_train_batch_calls): - with contextlib.ExitStack() as exit_stack: - if self.context.profiler: - exit_stack.enter_context(self.context.profiler) - - tr_metrics = self.trial.train_batch( - self.training_iterator, - self.get_epoch_idx(batch_idx), - batch_idx, - ) - if self.context.profiler: - self.context.profiler.step() + # Call epoch start callbacks before training first batch in epoch. + if batch_in_epoch_idx == 0: + self._on_epoch_start(epoch_idx) - if self.context._mpu.should_report_metrics: - if isinstance(tr_metrics, torch.Tensor): - tr_metrics = {"loss": tr_metrics} - if not isinstance(tr_metrics, dict): - raise det.errors.InvalidExperimentException( - "train_batch must return a dictionary " - f"mapping string names to Tensor metrics, got {type(tr_metrics)}", - ) + batch_metrics = self._train_batch(batch_idx=batch_idx, epoch_idx=epoch_idx) + training_metrics.extend(batch_metrics) + self._step_batch() - for name, metric in tr_metrics.items(): - # Convert PyTorch metric values to NumPy, so that - # `det.util.encode_json` handles them properly without - # needing a dependency on PyTorch. - if isinstance(metric, torch.Tensor): - metric = metric.cpu().detach().numpy() - tr_metrics[name] = metric - per_batch_metrics.append(tr_metrics) - # We do a check here to make sure that we do indeed process `num_micro_batches_per_slot` - # micro batches when training a batch for models that do not use pipeline parallelism. - model0 = self.context.models[0] - if not isinstance(model0, deepspeed.PipelineEngine): - assert ( - model0.micro_steps % self.context.num_micro_batches_per_slot == 0 - ), "did not train for gradient accumulation steps" - - batch_dur = time.time() - batch_start_time - samples_per_second = batch_inputs / batch_dur - samples_per_second *= self.context._mpu.data_parallel_world_size - - if self.context.is_epoch_end(): - for callback in self.callbacks.values(): - callback.on_training_epoch_end(self.get_epoch_idx(batch_idx)) + # Batch complete: check if any training periods have been reached and exit if any + for step in train_boundaries: + if isinstance(step.unit, pytorch.Batch): + if step.unit.should_stop(batch_idx + 1): + step.limit_reached = True + + # True epoch based training not supported, detect last batch of epoch to calculate + # fully-trained epochs + if isinstance(step.unit, pytorch.Epoch): + if step.unit.should_stop(epoch_idx + 1): + if batch_in_epoch_idx == epoch_len - 1: + step.limit_reached = True + + # Break early after one batch for test mode + if step.step_type == pytorch._TrainBoundaryType.TRAIN and self.test_mode: + step.limit_reached = True + + # Exit if any train step limits have been reached + if any(step.limit_reached for step in train_boundaries): + return train_boundaries, training_metrics + + # True epoch end + return train_boundaries, training_metrics + + def _train_batch(self, epoch_idx: int, batch_idx: int) -> List[dict]: + num_micro_batches = self.context.get_num_micro_batches_per_slot() + if self.context.use_pipeline_parallel or self.context._manual_grad_accumulation: + num_micro_batches = 1 + + # Reset loss IDs for AMP + self.context._loss_ids = {} + + batch_start_time = time.time() + per_batch_metrics = [] # type: List[Dict] + + for _ in range(num_micro_batches): + with contextlib.ExitStack() as exit_stack: + if self.context.profiler: + exit_stack.enter_context(self.context.profiler) + + training_metrics = self.trial.train_batch( + self.training_iterator, + epoch_idx, + batch_idx, + ) + + if self.context.profiler: + self.context.profiler.step() + + if self.context._mpu.should_report_metrics: + if isinstance(training_metrics, torch.Tensor): + training_metrics = {"loss": training_metrics} + if not isinstance(training_metrics, dict): + raise det.errors.InvalidExperimentException( + "train_batch must return a dictionary " + f"mapping string names to Tensor metrics, got {type(training_metrics)}", + ) + + for name, metric in training_metrics.items(): + # Convert PyTorch metric values to NumPy, so that + # `det.util.encode_json` handles them properly without + # needing a dependency on PyTorch. + if isinstance(metric, torch.Tensor): + metric = metric.cpu().detach().numpy() + training_metrics[name] = metric + per_batch_metrics.append(training_metrics) + # We do a check here to make sure that we do indeed process `num_micro_batches_per_slot` + # micro batches when training a batch for models that do not use pipeline parallelism. + model0 = self.context.models[0] + if not isinstance(model0, deepspeed.PipelineEngine): + assert ( + model0.micro_steps % self.context.get_num_micro_batches_per_slot() == 0 + ), "did not train for gradient accumulation steps" + + batch_dur = time.time() - batch_start_time + batch_inputs = ( + self.context.get_train_micro_batch_size_per_gpu() + * self.context.get_num_micro_batches_per_slot() + ) + samples_per_second = batch_inputs / batch_dur + samples_per_second *= self.context.distributed.size # Aggregate and reduce training metrics from all the training processes. - if self.context.distributed.size > 1 and self.context._average_training_metrics: - per_batch_metrics = pytorch._combine_and_average_training_metrics( + if self.context.distributed.size > 1: + metrics = pytorch._combine_and_average_training_metrics( self.context.distributed, per_batch_metrics ) - num_inputs *= self.context._mpu.data_parallel_world_size - metrics = det.util.make_metrics(num_inputs, per_batch_metrics) + else: + metrics = per_batch_metrics + + return metrics + + def _step_batch(self) -> None: + assert self.state + self.state.batches_trained += 1 + + epoch_len = self.context._epoch_len + assert epoch_len, "Training dataloader not initialized." + + # True epoch-based training is not supported. Epoch end is calculated with batch. + epoch_idx, batch_in_epoch_idx = divmod(self.state.batches_trained - 1, epoch_len) + + if batch_in_epoch_idx == epoch_len - 1: + self._on_epoch_end(epoch_idx) + self.state.epochs_trained += 1 + + def _aggregate_training_metrics(self, training_metrics: List[Dict]) -> Dict: + # Aggregate and reduce training metrics from all the training processes. + if self.context.distributed.size > 1: + batch_metrics = pytorch._combine_and_average_training_metrics( + self.context.distributed, training_metrics + ) + else: + batch_metrics = training_metrics + + metrics = det.util.make_metrics(None, batch_metrics) # Ignore batch_metrics entirely for custom reducers; there's no guarantee that per-batch # metrics are even logical for a custom reducer. @@ -480,27 +595,127 @@ def _train_for_step( pytorch._convert_metrics_to_numpy(self.context.reduce_metrics(for_training=True)) ) - if self.is_chief: - step_duration = time.time() - step_start_time - logger.info(det.util.make_timing_log("trained", step_duration, num_inputs, num_batches)) - - if self.context.get_enable_tensorboard_logging(): - det.pytorch._log_tb_metrics( - self.context.get_tensorboard_writer(), - "train", - self.steps_completed, - metrics["avg_metrics"], - metrics["batch_metrics"], - ) - if not self.is_chief: return {} + # Only report on the chief worker + avg_metrics = metrics.get("avg_metrics", {}) + batch_metrics = metrics.get("batch_metrics", []) + + assert self.state + if self.context.get_enable_tensorboard_logging(): + pytorch._log_tb_metrics( + self.context.get_tensorboard_writer(), + "train", + self.state.batches_trained, + avg_metrics, + batch_metrics, + ) + + self.core_context.train.report_training_metrics( + steps_completed=self.state.batches_trained, + metrics=avg_metrics, + batch_metrics=batch_metrics, + ) return metrics + def _is_best_validation(self, now: float, before: Optional[float]) -> bool: + if before is None: + return True + + return (now < before) if self.smaller_is_better else (now > before) + + def _on_epoch_start(self, epoch_idx: int) -> None: + for callback in self.callbacks.values(): + sig = inspect.signature(callback.on_training_epoch_start) + if sig.parameters: + callback.on_training_epoch_start(epoch_idx) + else: + logger.warning( + "on_training_epoch_start() without parameters is deprecated" + " since 0.17.8. Please add epoch_idx parameter." + ) + callback.on_training_epoch_start() # type: ignore[call-arg] + + def _on_epoch_end(self, epoch_idx: int) -> None: + for callback in self.callbacks.values(): + callback.on_training_epoch_end(epoch_idx) + + def _checkpoint(self, already_exiting: bool) -> None: + if self.is_chief: + self.core_context.train.set_status("checkpointing") + + assert self.state + self.state.last_ckpt = self.state.batches_trained + try: + uuid = "" + metadata = { + "determined_version": det.__version__, + "steps_completed": self.state.batches_trained, + "framework": f"torch-{torch.__version__}", + "format": "pickle", + } + with self.context._core.checkpoint.store_path(metadata, shard=True) as ( + path, + storage_id, + ): + self._save(path) + uuid = storage_id + for callback in self.callbacks.values(): + callback.on_checkpoint_upload_end(uuid=uuid) + except det.InvalidHP: + if not already_exiting: + self.core_context.train.report_early_exit(core.EarlyExitReason.INVALID_HP) + raise pytorch._ShouldExit(skip_exit_checkpoint=True) + raise + + def _stop_requested(self) -> None: + if self.core_context.preempt.should_preempt(): + raise pytorch._ShouldExit() + if self.context.get_stop_requested(): + raise pytorch._ShouldExit() + + def _report_training_progress(self) -> None: + assert self.state + assert isinstance(self.max_length.value, int) + + if isinstance(self.max_length, pytorch.Batch): + progress = self.state.batches_trained / self.max_length.value + elif isinstance(self.max_length, pytorch.Epoch): + progress = self.state.epochs_trained / self.max_length.value + else: + raise ValueError(f"unexpected train unit type {type(self.max_length)}") + + self.core_context.train.report_progress(progress=progress) + + def _checkpoint_is_current(self) -> bool: + assert self.state + # State always persists checkpoint step in batches + return self.state.last_ckpt == self.state.batches_trained + + def _validation_is_current(self) -> bool: + assert self.state + # State persists validation step in batches + return self.state.last_val == self.state.batches_trained + + def _steps_until_complete(self, train_unit: pytorch.TrainUnit) -> int: + assert isinstance(train_unit.value, int), "invalid length type" + assert self.state + if isinstance(train_unit, pytorch.Batch): + return train_unit.value - self.state.batches_trained + elif isinstance(train_unit, pytorch.Epoch): + return train_unit.value - self.state.epochs_trained + else: + raise ValueError(f"Unrecognized train unit {train_unit}") + @torch.no_grad() - def _compute_validation_metrics(self) -> workload.Response: + def _validate(self) -> Dict[str, Any]: + # Report a validation step is starting. + if self.is_chief: + self.core_context.train.set_status("validating") + self.context.reset_reducers() + # Set the behavior of certain layers (e.g., dropout) that are # different between training and inference. for model in self.context.models: @@ -512,57 +727,83 @@ def _compute_validation_metrics(self) -> workload.Response: callback.on_validation_start() num_inputs = 0 - keys = None - batch_metrics = [] + metrics = {} # type: Dict[str, Any] - for callback in self.callbacks.values(): - callback.on_validation_epoch_start() - - validation_iterator = iter(self.validation_loader) if self.validation_loader else None - for idx in range(cast(int, self.num_validation_batches)): - num_inputs += cast(int, self.validation_batch_size) - # Note that when using pipeline parallelism, each call to evaluate_batch will request - # self.context.num_micro_batches_per_slot batches from the validation iterator. - # This is why we set self.num_validation_batches differently for pipeline parallel - # and no pipeline parallel when building the data loaders. - vld_metrics = self.trial.evaluate_batch(validation_iterator, idx) - if self.context._mpu.should_report_metrics: - if not isinstance(vld_metrics, dict): - raise det.errors.InvalidExperimentException( - "evaluate_batch must return a dictionary of string names " - "to Tensor metrics", - ) - # Verify validation metric names are the same across batches. - if keys is None: - keys = vld_metrics.keys() + batches_evaluated = -1 + + if self._evaluate_batch_defined(): + keys = None + batch_metrics = [] + + for callback in self.callbacks.values(): + callback.on_validation_epoch_start() + + validation_iterator = iter(self.validation_loader) if self.validation_loader else None + for idx in range(cast(int, self.num_validation_batches)): + batches_evaluated += 1 + num_inputs += cast(int, self.validation_batch_size) + # Note that when using pipeline parallelism, each call to evaluate_batch will + # request self.context.num_micro_batches_per_slot batches from the validation + # iterator. This is why we set self.num_validation_batches differently for + # pipeline parallel and no pipeline parallel when building the data loaders. + if util.has_param(self.trial.evaluate_batch, "batch_idx", 2): + vld_metrics = self.trial.evaluate_batch(validation_iterator, idx) else: - if keys != vld_metrics.keys(): + vld_metrics = self.trial.evaluate_batch(validation_iterator) # type: ignore + if self.context._mpu.should_report_metrics: + if not isinstance(vld_metrics, dict): raise det.errors.InvalidExperimentException( - "Validation metric names must match across all batches of data.", + "evaluate_batch must return a dictionary " + f"mapping string names to Tensor metrics, got {type(vld_metrics)}", ) - # TODO: For performance perform -> cpu() only at the end of validation. - batch_metrics.append(pytorch._convert_metrics_to_numpy(vld_metrics)) - if self.env.test_mode: - break + for name, metric in vld_metrics.items(): + # Convert PyTorch metric values to NumPy, so that + # `det.util.encode_json` handles them properly without + # needing a dependency on PyTorch. + if isinstance(metric, torch.Tensor): + metric = metric.cpu().detach().numpy() + vld_metrics[name] = metric + # Verify validation metric names are the same across batches. + if keys is None: + keys = vld_metrics.keys() + else: + if keys != vld_metrics.keys(): + raise ValueError( + "Validation metric names must match across all batches of data: " + f"{keys} != {vld_metrics.keys()}.", + ) + batch_metrics.append(pytorch._convert_metrics_to_numpy(vld_metrics)) + if self.test_mode: + break - # keys and list(keys) does not satisfy all cases because it will return dict_keys type if - # keys is an empty dict. this will then break when passed to zmq_broadcast since it does - # not know how to serialize dict_keys type. - all_keys = self.context.distributed.gather(keys if keys is None else list(keys)) - if self.is_chief: - all_keys = [k for k in all_keys if k is not None] - keys = all_keys[0] - keys = self.context.distributed.broadcast(keys) + for callback in self.callbacks.values(): + callback.on_validation_epoch_end(batch_metrics) + + metrics = pytorch._reduce_metrics( + self.context.distributed, + batch_metrics=batch_metrics, + keys=keys, + metrics_reducers=pytorch._prepare_metrics_reducers( + self.trial.evaluation_reducer(), keys=keys + ), + ) - for callback in self.callbacks.values(): - callback.on_validation_epoch_end(batch_metrics) + # Gather a list of per-worker (num_inputs, num_batches) tuples. + input_counts = self.context.distributed.gather((num_inputs, batches_evaluated + 1)) + + else: + assert self._evaluate_full_dataset_defined(), "evaluate_full_dataset not defined." + if self.is_chief: + assert self.validation_loader is not None + metrics = self.trial.evaluate_full_dataset(data_loader=self.validation_loader) + + if not isinstance(metrics, dict): + raise TypeError( + f"eval() must return a dictionary, got {type(metrics).__name__}." + ) + + metrics = pytorch._convert_metrics_to_numpy(metrics) - metrics = pytorch._reduce_metrics( - self.context.distributed, - batch_metrics=batch_metrics, - keys=keys, - metrics_reducers=pytorch._prepare_metrics_reducers(pytorch.Reducer.AVG, keys=keys), - ) metrics.update( pytorch._convert_metrics_to_numpy(self.context.reduce_metrics(for_training=False)) ) @@ -573,51 +814,119 @@ def _compute_validation_metrics(self) -> workload.Response: ): logger.debug( "Broadcasting metrics to all worker processes to execute a " - "validation step end callback" + "validation step end callback." ) metrics = self.context.distributed.broadcast(metrics) for callback in self.callbacks.values(): callback.on_validation_end(metrics) + assert self.state + self.state.last_val = self.state.batches_trained + + # Report metrics. if self.is_chief: - num_inputs *= self.context._mpu.data_parallel_world_size - step_duration = time.time() - step_start_time - logger.info( - det.util.make_timing_log( - "validated", step_duration, num_inputs, cast(int, self.num_validation_batches) + # Skip reporting timings if evaluate_full_dataset() was defined. This is far less + # common than evaluate_batch() and we can't know how the user processed their + # validation data. + if self._evaluate_batch_defined(): + # Reshape and sum. + # TODO: remove the type directive once we upgrade to mypy >= 1.7.0 + inputs_total, batches_total = [sum(n) for n in zip(*input_counts)] # type: ignore + step_duration = time.time() - step_start_time + logger.info( + det.util.make_timing_log( + "validated", step_duration, inputs_total, batches_total + ) ) - ) - if self.context.get_enable_tensorboard_logging(): - det.pytorch._log_tb_metrics( - self.context.get_tensorboard_writer(), "val", self.steps_completed, metrics + pytorch._log_tb_metrics( + self.context.get_tensorboard_writer(), + "val", + self.state.batches_trained, + metrics, ) - if not self.is_chief: - return {} + # Get best validation before reporting metrics. + best_validation_before = self.core_context.train.get_experiment_best_validation() - return {"num_inputs": num_inputs, "validation_metrics": metrics} + # We report "batch" and "epoch" only if these keys are not already reported in user + # metrics. + metrics["batches"] = metrics.get("batches", self.state.batches_trained) + metrics["epochs"] = metrics.get("epochs", self.state.epochs_trained) - def on_validation_step_end(self, metrics: Dict[str, Any]) -> None: - if self.context.get_enable_tensorboard_logging(): - det.pytorch._log_tb_metrics( - self.context.get_tensorboard_writer(), "val", self.steps_completed, metrics + self.core_context.train.report_validation_metrics( + steps_completed=self.state.batches_trained, metrics=metrics + ) + should_checkpoint = False + + # Checkpoint according to policy. + if self.is_chief: + if not self._checkpoint_is_current(): + if self.checkpoint_policy == "all": + should_checkpoint = True + elif self.checkpoint_policy == "best": + assert ( + self.searcher_metric_name + ), "checkpoint policy 'best' but searcher metric name not defined" + searcher_metric = self._check_searcher_metric(metrics) + assert searcher_metric is not None + + if self._is_best_validation(now=searcher_metric, before=best_validation_before): + should_checkpoint = True + should_checkpoint = self.context.distributed.broadcast(should_checkpoint) + if should_checkpoint: + self._checkpoint(already_exiting=False) + return metrics + + def _check_searcher_metric(self, val_metrics: Dict) -> Any: + if self.searcher_metric_name not in val_metrics: + raise RuntimeError( + f"Search method is configured to use metric '{self.searcher_metric_name}' but " + f"model definition returned validation metrics {list(val_metrics.keys())}. The " + f"metric used by the search method must be one of the validation " + "metrics returned by the model definition." ) + # Check that the searcher metric has a scalar value so that it can be compared for + # search purposes. Other metrics don't have to be scalars. + searcher_metric = val_metrics[self.searcher_metric_name] + if not util.is_numerical_scalar(searcher_metric): + raise RuntimeError( + f"Searcher validation metric '{self.searcher_metric_name}' returned " + f"a non-scalar value: {searcher_metric}." + ) + return searcher_metric + + def _evaluate_batch_defined(self) -> bool: + return util.is_overridden(self.trial.evaluate_batch, DeepSpeedTrial) + + def _evaluate_full_dataset_defined(self) -> bool: + return util.is_overridden(self.trial.evaluate_full_dataset, DeepSpeedTrial) + def _load(self, load_path: pathlib.Path) -> None: # Right now we will load all checkpoint shards on each node regardless of which # checkpoints are needed. # TODO (Liam): revisit later to optimize sharded checkpoint loading. + potential_paths = [ + ["state_dict.pth"], + ["determined", "state_dict.pth"], + ["pedl", "state_dict.pth"], + ["checkpoint.pt"], + [f"det_state_dict_rank{self.context.distributed.rank}.pth"], + ] # Load stateful things tracked by Determined on all slots. - ckpt_path = f"det_state_dict_rank{self.context.distributed.rank}.pth" - maybe_ckpt = load_path.joinpath(ckpt_path) + checkpoint: Optional[Dict[str, Any]] = None + for ckpt_path in potential_paths: + maybe_ckpt = load_path.joinpath(*ckpt_path) + if maybe_ckpt.exists(): + checkpoint = torch.load(str(maybe_ckpt), map_location="cpu") + break - if not maybe_ckpt.exists(): + if checkpoint is None or not isinstance(checkpoint, dict): return - checkpoint = torch.load(str(maybe_ckpt), map_location="cpu") if not isinstance(checkpoint, dict): raise det.errors.InvalidExperimentException( f"Expected checkpoint at {maybe_ckpt} to be a dict " @@ -665,27 +974,68 @@ def _load(self, load_path: pathlib.Path) -> None: "callback will be initialized from scratch" ) - # Load workload sequencer state. - wlsq_path = load_path.joinpath("workload_sequencer.pkl") - if self.wlsq is not None and wlsq_path.exists(): - with wlsq_path.open("rb") as f: - self.wlsq.load_state(pickle.load(f)) + save_path = load_path.joinpath("trial_state.pkl") + + if save_path.exists(): + with save_path.open("rb") as f: + self._load_state(pickle.load(f)) + else: + # Support legacy save states. + wlsq_path = load_path.joinpath("workload_sequencer.pkl") + if wlsq_path.exists(): + with wlsq_path.open("rb") as f: + self._load_wlsq_state(pickle.load(f)) + + def _load_state(self, state: Any) -> None: + # Load our state from the checkpoint if we are continuing training after a pause or restart. + # If the trial_id doesn't match our current trial id, we're continuing training a previous + # trial and should start from a fresh state. + if state.get("trial_id") != self.trial_id: + self.state = pytorch._TrialState(trial_id=self.trial_id) + return + + self.state = pytorch._TrialState(**state) + assert self.state + + # Detect the case where the final validation we made was against this exact checkpoint. In + # that case, the master will know about the validation, but it would not appear in the + # checkpoint state. If the validation was before the last checkpoint, the checkpoint state + # is already correct, while any validations after the last checkpoint aren't valid anymore + # and can be safely ignored. + if self.state.batches_trained == self.val_from_previous_run: + self.state.last_val = self.state.batches_trained + + def _load_wlsq_state(self, state: Any) -> None: + if state.get("trial_id") != self.trial_id: + self.state = pytorch._TrialState(trial_id=self.trial_id) + return + + self.state = pytorch._TrialState( + trial_id=state.get("trial_id"), + last_ckpt=state.get("last_ckpt"), + last_val=state.get("last_val"), + step_id=state.get("step_id"), + # steps_completed is a legacy field kept to support loading from older checkpoints. + # checkpoints should only persist batches_trained and epochs_trained + batches_trained=state.get("steps_completed"), + epochs_trained=self._get_epoch_idx(state.get("steps_completed")), + ) + + assert self.state + if self.state.batches_trained == self.val_from_previous_run: + self.state.last_val = self.state.batches_trained def _save(self, path: pathlib.Path) -> None: - if self.context.distributed.local_rank == 0: - path.mkdir(parents=True, exist_ok=True) - _ = self.context.distributed.gather_local(None) # sync + path.mkdir(parents=True, exist_ok=True) if self.is_chief: # We assume these stateful objects should be the same across slots and only have # the chief save them. - util.write_user_code(path, self.env.on_cluster) + util.write_user_code(path, not self.local_training) + assert self.state + with path.joinpath("trial_state.pkl").open("wb") as f: + pickle.dump(vars(self.state), f) - if self.wlsq is not None: - with path.joinpath("workload_sequencer.pkl").open("wb") as f: - pickle.dump(self.wlsq.get_state(), f) - - # Save per rank Determined checkpoint. rng_state = { "cpu_rng_state": torch.random.get_rng_state(), "np_rng_state": np.random.get_state(), @@ -694,22 +1044,21 @@ def _save(self, path: pathlib.Path) -> None: if torch.cuda.device_count(): rng_state["gpu_rng_state"] = torch.cuda.get_rng_state( - self.context.distributed.get_local_rank() + self.context.distributed.local_rank ) - checkpoint = {"rng_state": rng_state} # PyTorch uses optimizer objects that take the model parameters to # optimize on construction, so we store and reload the `state_dict()` # of the model and optimizer explicitly (instead of dumping the entire # objects) to avoid breaking the connection between the model and the # optimizer. - checkpoint["callbacks"] = { - name: callback.state_dict() for name, callback in self.callbacks.items() + checkpoint = { + "callbacks": {name: callback.state_dict() for name, callback in self.callbacks.items()}, + "rng_state": rng_state, } for callback in self.callbacks.values(): callback.on_checkpoint_save_start(checkpoint) - ckpt_name = f"det_state_dict_rank{self.context.distributed.rank}.pth" torch.save(checkpoint, str(path.joinpath(ckpt_name))) @@ -717,6 +1066,22 @@ def _save(self, path: pathlib.Path) -> None: # the save method provided by DeepSpeed. self.trial.save(self.context, path) + with open(path.joinpath("load_data.json"), "w") as f2: + try: + exp_conf = self.context.get_experiment_config() # type: Optional[Dict[str, Any]] + hparams = self.context.get_hparams() # type: Optional[Dict[str, Any]] + except ValueError: + exp_conf = None + hparams = None + + load_data = { + "trial_type": "DeepSpeedTrial", + "experiment_config": exp_conf, + "hparams": hparams, + } + + json.dump(load_data, f2) + for callback in self.callbacks.values(): # TODO(DET-7912): remove on_checkpoint_end once it has been deprecated long enough. callback.on_checkpoint_end(str(path)) @@ -749,8 +1114,8 @@ class DeepSpeedTrial(det.LegacyTrial): """ - trial_controller_class = DeepSpeedTrialController - trial_context_class = det_ds.DeepSpeedTrialContext + trial_controller_class = DeepSpeedTrialController # type: ignore + trial_context_class = det_ds.DeepSpeedTrialContext # type: ignore @abc.abstractmethod def __init__(self, context: det_ds.DeepSpeedTrialContext) -> None: @@ -905,6 +1270,32 @@ def evaluate_batch( """ pass + def evaluate_full_dataset(self, data_loader: torch.utils.data.DataLoader) -> Dict[str, Any]: + """ + Calculate validation metrics on the entire validation dataset and + return them as a dictionary mapping metric names to reduced metric + values (i.e., each returned metric is the average or sum of that metric + across the entire validation set). + + This validation cannot be distributed and is performed on a single + device, even when multiple devices (slots) are used for training. Only + one of :meth:`evaluate_full_dataset` and :meth:`evaluate_batch` should + be overridden by a trial. + + The metrics returned from this function must be JSON-serializable. + + Arguments: + data_loader (torch.utils.data.DataLoader): data loader for evaluating. + """ + pass + + def evaluation_reducer(self) -> Union[pytorch.Reducer, Dict[str, pytorch.Reducer]]: + """ + Return a reducer for all evaluation metrics, or a dict mapping metric + names to individual reducers. Defaults to :obj:`determined.pytorch.Reducer.AVG`. + """ + return pytorch.Reducer.AVG + def save(self, context: det_ds.DeepSpeedTrialContext, path: pathlib.Path) -> None: """ Save is called on every GPU to make sure all checkpoint shards are saved. @@ -943,3 +1334,33 @@ def load( # DeepSpeed does not provide an error message with many assertion errors in the # checkpoint load module. raise AssertionError("Failed to load deepspeed checkpoint.") + + def get_batch_length(self, batch: Any) -> int: + """Count the number of records in a given batch. + + Override this method when you are using custom batch types, as produced + when iterating over the class:`determined.pytorch.DataLoader`. + For example, when using ``pytorch_geometric``: + + .. code-block:: python + + # Extra imports: + from determined.pytorch import DataLoader + from torch_geometric.data.dataloader import Collater + + # Trial methods: + def build_training_data_loader(self): + return DataLoader( + self.train_subset, + batch_size=self.context.get_per_slot_batch_size(), + collate_fn=Collater([], []), + ) + + def get_batch_length(self, batch): + # `batch` is `torch_geometric.data.batch.Batch`. + return batch.num_graphs + + Arguments: + batch (Any): input training or validation data batch object. + """ + return pytorch.data_length(batch) diff --git a/harness/determined/pytorch/deepspeed/_trainer.py b/harness/determined/pytorch/deepspeed/_trainer.py new file mode 100644 index 00000000000..8e36f345235 --- /dev/null +++ b/harness/determined/pytorch/deepspeed/_trainer.py @@ -0,0 +1,335 @@ +import contextlib +import logging +import os +import random +import sys +import warnings +from typing import Any, Dict, Iterator, Optional + +import deepspeed +import numpy as np +import torch + +import determined as det +from determined import core, gpu, pytorch +from determined.pytorch import deepspeed as det_ds + +logger = logging.getLogger("determined.pytorch.deepspeed") + + +class Trainer: + """ + ``pytorch.deepspeed.Trainer`` is an abstraction on top of a DeepSpeed training loop + that handles many training details under-the-hood, and exposes APIs for configuring + training-related features such as automatic checkpointing, validation, profiling, + metrics reporting, etc. + + ``Trainer`` must be initialized and called from within a + ``pytorch.deepspeed.DeepSpeedTrialContext``. + """ + + def __init__(self, trial: det_ds.DeepSpeedTrial, context: det_ds.DeepSpeedTrialContext): + self._trial = trial + self._context = context + self._core = self._context._core + self._info = det.get_cluster_info() + self._local_training = self._info is None or self._info.task_type != "TRIAL" + + def fit( + self, + checkpoint_period: Optional[pytorch.TrainUnit] = None, + validation_period: Optional[pytorch.TrainUnit] = None, + max_length: Optional[pytorch.TrainUnit] = None, + reporting_period: pytorch.TrainUnit = pytorch.Batch(100), # noqa: B008 + checkpoint_policy: str = "best", + latest_checkpoint: Optional[str] = None, + step_zero_validation: bool = False, + test_mode: bool = False, + profiling_enabled: bool = False, + ) -> None: + """ + ``fit()`` trains a ``DeepSpeedTrial`` configured from the ``Trainer`` and handles + checkpointing and validation steps, and metrics reporting. + + Arguments: + checkpoint_period: The number of steps to train for before checkpointing. This is + a ``TrainUnit`` type (``Batch`` or ``Epoch``) which can take an ``int`` or + instance of ``collections.abc.Container`` (list, tuple, etc.). For example, + ``Batch(100)`` would checkpoint every 100 batches, while ``Batch([5, 30, 45])`` + would checkpoint after every 5th, 30th, and 45th batch. + validation_period: The number of steps to train for before validating. This is a + ``TrainUnit`` type (``Batch`` or ``Epoch``) which can take an ``int`` or instance + of ``collections.abc.Container`` (list, tuple, etc.). For example, ``Batch(100)`` + would validate every 100 batches, while ``Batch([5, 30, 45])`` would validate + after every 5th, 30th, and 45th batch. + max_length: The maximum number of steps to train for. This is a ``TrainUnit`` type + (``Batch`` or ``Epoch``) which takes an ``int``. For example, ``Epoch(1)`` would + train for a maximum length of one epoch. + .. note:: + If using an ASHA searcher, this value should match the searcher config values in + the experiment config (i.e. ``Epoch(1)`` = `max_time: 1` and `time_metric: + "epochs"`). + + reporting_period: The number of steps to train for before reporting metrics and + searcher progress. For local training mode, metrics are printed to stdout. This + is a ``TrainUnit`` type (``Batch`` or ``Epoch``) which can take an ``int`` or + instance of ``collections.abc.Container`` (list, tuple, etc.). For example, + ``Batch(100)`` would report every 100 batches, while ``Batch([5, 30, 45])`` would + report after every 5th, 30th, and 45th batch. + checkpoint_policy: Controls how Determined performs checkpoints after validation + operations, if at all. Should be set to one of the following values: + + best (default): A checkpoint will be taken after every validation operation + that performs better than all previous validations for this experiment. + Validation metrics are compared according to the ``metric`` and + ``smaller_is_better`` fields in the searcher configuration. This option + is only supported for on-cluster training. + all: A checkpoint will be taken after every validation, no matter the + validation performance. + none: A checkpoint will never be taken due to a validation. However, + even with this policy selected, checkpoints are still expected to be taken + after the trial is finished training, due to cluster scheduling decisions, + before search method decisions, or due to ``min_checkpoint_period``. + latest_checkpoint: Configures the checkpoint used to start or continue training. + This value should be set to ``det.get_cluster_info().latest_checkpoint`` for + standard continue training functionality. + step_zero_validation: Configures whether to perform an initial validation before + training. Defaults to false. + test_mode: Runs a minimal loop of training for testing and debugging purposes. Will + train and validate one batch. Defaults to false. + profiling_enabled: Enables system metric profiling functionality for on-cluster + training. Defaults to false. + """ + # Set defaults. + if checkpoint_period is None: + checkpoint_period = pytorch.Batch(sys.maxsize) + + if validation_period is None: + validation_period = pytorch.Batch(sys.maxsize) + + if self._local_training: + if checkpoint_policy == "best": + logger.warning( + "checkpoint_policy='best' is not supported in local training mode. " + "Falling back to 'all'." + ) + checkpoint_policy = "all" + if max_length is None: + raise ValueError("max_length must be defined in local training mode.") + + if not isinstance(max_length, (pytorch.Batch, pytorch.Epoch)) or not isinstance( + max_length.value, int + ): + raise TypeError( + "max_length must either be a det.pytorch.Batch(int) or det.pytorch.Epoch(int) " + "type" + ) + + if profiling_enabled: + logger.warning("Profiling is not supported in local training mode.") + + smaller_is_better = True + searcher_metric_name = None + steps_completed = 0 + global_batch_size = None + else: + if test_mode: + raise ValueError("test_mode is only supported in local training mode.") + + assert self._info, "Unable to detect cluster info." + if latest_checkpoint is None and self._info.latest_checkpoint is not None: + logger.warning( + "latest_checkpoint has not been configured. Pause/resume training will not " + "be able to continue from latest checkpoint. Did you mean to set " + "`fit(latest_checkpoint=info.latest_checkpoint)'?" + ) + + smaller_is_better = bool(self._info.trial._config["searcher"]["smaller_is_better"]) + searcher_metric_name = self._info.trial._config["searcher"]["metric"] + steps_completed = int(self._info.trial._steps_completed) + global_batch_size = self._info.trial.hparams.get("global_batch_size", None) + if global_batch_size: + global_batch_size = int(global_batch_size) + + # Backwards compatibility: try to parse legacy `searcher.max_length` if `max_length` + # isn't passed in. + if max_length is None: + max_length_val = core._parse_searcher_max_length(self._info.trial._config) + if max_length_val: + warnings.warn( + "Configuring `max_length` from the `searcher.max_length` experiment " + "config, which was deprecated in 0.38.0 and will be removed in a future " + "release. Please set `fit(max_length=X)` with your desired training length " + "directly.", + FutureWarning, + stacklevel=2, + ) + max_length_unit = core._parse_searcher_units(self._info.trial._config) + max_length = pytorch.TrainUnit._from_searcher_unit( + max_length_val, max_length_unit, global_batch_size + ) + + # If we couldn't parse the legacy `searcher.max_length`, raise an error. + if not max_length: + raise ValueError( + "`fit(max_length=X)` must be set with your desired training length." + ) + if not isinstance(max_length, (pytorch.Batch, pytorch.Epoch)) or not isinstance( + max_length.value, int + ): + raise TypeError( + "max_length must either be a det.pytorch.Batch(int) or det.pytorch.Epoch(int) " + "type." + ) + + _check_searcher_length(exp_conf=self._info.trial._config, max_length=max_length) + + trial_controller = det_ds.DeepSpeedTrialController( + trial_inst=self._trial, + context=self._context, + checkpoint_period=checkpoint_period, + validation_period=validation_period, + smaller_is_better=smaller_is_better, + steps_completed=steps_completed, + latest_checkpoint=latest_checkpoint, + local_training=self._local_training, + test_mode=test_mode, + reporting_period=reporting_period, + searcher_metric_name=searcher_metric_name, + checkpoint_policy=checkpoint_policy, + step_zero_validation=step_zero_validation, + max_length=max_length, + global_batch_size=global_batch_size, + profiling_enabled=profiling_enabled, + ) + + trial_controller.run() + + +def _check_searcher_length( + exp_conf: Dict[str, Any], + max_length: pytorch.TrainUnit, +) -> None: + """ + Certain searchers (ASHA and Adaptive ASHA) require configuring the maximum training length in + the experiment config. We check that the `max_length` passed to `fit()` matches the experiment + config and log warnings if it doesn't. + """ + time_metric = exp_conf["searcher"].get("time_metric") + if time_metric is not None: + max_time = exp_conf["searcher"].get("max_time") + assert max_time, "`searcher.max_time` not configured" + if time_metric == "batches": + if not isinstance(max_length, pytorch.Batch) or max_length.value != max_time: + logger.warning( + f"`max_length` passed into `fit()` method ({max_length}) does not match " + f"`searcher.max_time` and `searcher.time_metric` from the experiment config " + f"(Batch(value={max_time})). This may result in unexpected hyperparameter " + f"search behavior." + ) + elif time_metric == "epochs": + if not isinstance(max_length, pytorch.Epoch) or max_length.value != max_time: + logger.warning( + f"`max_length` passed into `fit()` method ({max_length}) does not match " + f"`searcher.max_time` and `searcher.time_metric` from the experiment config " + f"(Epoch(value={max_time})). This may result in unexpected hyperparameter " + f"search behavior." + ) + else: + logger.warning( + "`searcher.time_metric` must be either 'batches' or 'epochs' " + f"for training with PyTorchTrials, but got {time_metric}. " + f"Training will proceed with {max_length} but may result in unexpected behavior." + ) + + +def _initialize_distributed_backend() -> Optional[core.DistributedContext]: + info = det.get_cluster_info() + distributed_backend = det._DistributedBackend() + + if distributed_backend.use_deepspeed(): + # We use an environment variable to allow users to enable custom initialization routine for + # distributed training since the pre_execute_hook runs before trial initialization. + manual_dist_init = os.environ.get("DET_MANUAL_INIT_DISTRIBUTED") + if not manual_dist_init: + deepspeed.init_distributed(auto_mpi_discovery=False) + return core.DistributedContext.from_deepspeed() + elif info and (len(info.container_addrs) > 1 or len(info.slot_ids) > 1): + raise ValueError( + "In multi-slot managed cluster training, you must wrap your training script with a " + "distributed launch layer such as determined.launch.deepspeed." + ) + return None + + +def _set_random_seeds(seed: int) -> None: + # Set identical random seeds on all training processes. + # When doing distributed training, each worker will start at a unique + # offset in the dataset, ensuring that it is processing a unique + # training batch. + random.seed(seed) + np.random.seed(seed) + torch.random.manual_seed(seed) + + +@contextlib.contextmanager +def init( + *, + hparams: Optional[Dict] = None, + exp_conf: Optional[Dict[str, Any]] = None, + distributed: Optional[core.DistributedContext] = None, + enable_tensorboard_logging: bool = True, +) -> Iterator[det_ds.DeepSpeedTrialContext]: + """ + Creates a DeepSpeedTrialContext for use with a DeepSpeedTrial. All trainer.* calls + must be within the scope of this context because there are resources started in + __enter__ that must be cleaned up in __exit__. + + Arguments: + hparams: (Optional) instance of hyperparameters for the trial + exp_conf: (Optional) for local-training mode. If unset, calling + context.get_experiment_config() will fail. + distributed: (Optional) custom distributed training configuration + enable_tensorboard_logging: Configures if upload to tensorboard is enabled + """ + cluster_info = det.get_cluster_info() + local_training = cluster_info is None or cluster_info.task_type != "TRIAL" + + # Pre-execute steps: initialize distributed backend and random seeds. + distributed_context = distributed + + if not local_training: + distributed_context = _initialize_distributed_backend() + + # Initialize default values. + if local_training: + trial_seed = None + steps_completed = 0 + num_gpus = len(gpu.get_gpu_uuids()) + else: + assert cluster_info, "Unable to detect cluster info" + + trial_seed = cluster_info.trial.trial_seed + exp_conf = cluster_info.trial._config + steps_completed = cluster_info.trial._steps_completed + num_gpus = len(cluster_info.gpu_uuids) + + _set_random_seeds(trial_seed) + + with core.init( + distributed=distributed_context, + preempt_mode=core.PreemptMode.WorkersAskChief, + tensorboard_mode=core.TensorboardMode.MANUAL, + ) as core_context: + context = det_ds.DeepSpeedTrialContext( + core_context=core_context, + trial_seed=trial_seed, + hparams=hparams, + slots_per_trial=core_context.distributed.get_size(), + num_gpus=num_gpus, + exp_conf=exp_conf, + steps_completed=steps_completed, + enable_tensorboard_logging=enable_tensorboard_logging, + ) + + yield context diff --git a/harness/determined/pytorch/dsat/__init__.py b/harness/determined/pytorch/dsat/__init__.py deleted file mode 100644 index a6919e6b44b..00000000000 --- a/harness/determined/pytorch/dsat/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -from determined.pytorch.dsat._utils import ( - dsat_reporting_context, - get_full_parser, - get_batch_config_from_mbs_gas_and_slots, - get_ds_config_from_hparams, - get_dict_from_yaml_or_json_path, - get_hf_args_with_overwrites, - get_random_zero_optim_config, - get_search_runner_config_from_args, - smaller_is_better, -) -from determined.pytorch.dsat._dsat_search_method import ( - BaseDSATSearchMethod, - DSATTrial, - DSATTrialTracker, - DSATModelProfileInfoTrial, - ASHADSATSearchData, - DSATSearchData, - RandomDSATSearchMethod, - BinarySearchDSATSearchMethod, - ASHADSATSearchMethod, - TestDSATSearchMethod, -) -from determined.pytorch.dsat._run_dsat import ( - get_custom_dsat_exp_conf_from_args, - get_search_method_class, -) diff --git a/harness/determined/pytorch/dsat/__main__.py b/harness/determined/pytorch/dsat/__main__.py deleted file mode 100644 index bfa441ae2c8..00000000000 --- a/harness/determined/pytorch/dsat/__main__.py +++ /dev/null @@ -1,48 +0,0 @@ -import argparse -import os -import pathlib -import pickle -import tempfile - -from determined.experimental import client -from determined.pytorch import dsat -from determined.pytorch.dsat import defaults - - -def parse_args() -> argparse.Namespace: - parser = dsat.get_full_parser() - args = parser.parse_args() - assert args.max_trials > 1, "--max-trials must be larger than 1" - - # Convert the paths to absolute paths - args.config_path = os.path.abspath(args.config_path) - args.model_dir = os.path.abspath(args.model_dir) - args.include = [os.path.abspath(p) for p in args.include] if args.include is not None else [] - - return args - - -def run_autotuning(args: argparse.Namespace) -> None: - # Build the default SearchRunner's config from the submitted config. The original - # config yaml file is added as an include and is reimported by the SearchRunner later. - - config = dsat.get_search_runner_config_from_args(args) - - # Create empty tempdir as the model_dir and upload everything else as an includes in order to - # preserve the top-level model_dir structure inside the SearchRunner's container. - - with tempfile.TemporaryDirectory() as temp_dir: - # Upload the args, which will be used by the search runner on-cluster. - args_path = pathlib.Path(temp_dir).joinpath(defaults.ARGS_PKL_PATH) - with args_path.open("wb") as f: - pickle.dump(args, f) - includes = [args.model_dir, args.config_path] + args.include - exp = client.create_experiment(config=config, model_dir=temp_dir, includes=includes) - # Note: Simulating the same print functionality as our CLI when making an experiment. - # This line is needed for the e2e tests - print(f"Created experiment {exp.id}") - - -if __name__ == "__main__": - args = parse_args() - run_autotuning(args) diff --git a/harness/determined/pytorch/dsat/_dsat_search_method.py b/harness/determined/pytorch/dsat/_dsat_search_method.py deleted file mode 100644 index 668f03c5942..00000000000 --- a/harness/determined/pytorch/dsat/_dsat_search_method.py +++ /dev/null @@ -1,1432 +0,0 @@ -import abc -import argparse -import collections -import copy -import dataclasses -import json -import logging -import pathlib -import pickle -import random -import uuid -from typing import Any, Deque, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union, cast - -import numpy as np - -from determined import searcher, util -from determined.experimental import client -from determined.pytorch import dsat -from determined.pytorch.dsat import defaults - -logger = logging.getLogger("determined.pytorch") - - -class DSATTrial: - """Encapsulation of DeepSpeed Autotune Trials. - - Simple objects for handling all pertinent information and results for every created Trial. - Contains basic lineage tracking in which each `DSATTrial` instance holds direct references to - its immediate parent and children, along with various helper properties. - """ - - def __init__( - self, - hparams: Dict[str, Any], - model_dir: str, - slots_per_trial: int, - length: int, - request_id: Optional[uuid.UUID] = None, - parent: Optional["DSATTrial"] = None, - search_data: Optional["DSATSearchData"] = None, - searcher_metric_name: Optional[str] = None, - ) -> None: - self.hparams = hparams - self.model_dir = model_dir - self.slots_per_trial = slots_per_trial - self.length = length - self.request_id = request_id or uuid.uuid4() - self.parent = parent - # Arbitrary attribute for search-specific data tracking. - self.search_data: Optional["DSATSearchData"] = search_data - self.searcher_metric_name = searcher_metric_name - - # Other attrs which are updated during training: - - self.metric: Union[float, Dict[str, Any]] = {} - self.error = False - self.running = False - self.children: Set["DSATTrial"] = set() - - # If a parent was specified, register the current Trial as the parent's child. - if self.parent is not None: - self.parent.children.add(self) - - self.lineage_root: DSATTrial = self if self.parent is None else self.parent.lineage_root - - # The DS config json file may either be in the specified model directory or in the base of - # the workdir, if it was added as an `--include` arg. - try: - self.ds_config = dsat.get_ds_config_from_hparams(self.hparams, self.model_dir) - except FileNotFoundError: - self.ds_config = dsat.get_ds_config_from_hparams(self.hparams) - - self._error_in_direct_history = False - - @property - def completed(self) -> bool: - return bool(self.error or self.metric) - - @property - def lineage_set(self) -> Set["DSATTrial"]: - """Computes set of trials in lineage tree.""" - root = self.lineage_root - trials_set = {root} - children = set(root.children) - while children: - random_child = children.pop() - trials_set.add(random_child) - children |= random_child.children - return trials_set - - @property - def num_completed_trials_in_lineage(self) -> int: - """Computes total number of trials in lineage tree.""" - num_trials = sum(trial.completed for trial in self.lineage_set) - return num_trials - - @property - def error_in_direct_history(self) -> bool: - if self._error_in_direct_history: - return self._error_in_direct_history - trial: Optional["DSATTrial"] = self - while trial is not None: - if trial.error: - return True - trial = trial.parent - return False - - @property - def mbs_in_lineage(self) -> Set[int]: - """ - Returns the set of all `train_micro_batch_size_per_gpu` (mbs) used in the Trial's lineage. - """ - mbs_in_lineage = {t.mbs for t in self.lineage_set} - return mbs_in_lineage - - @property - def stage(self) -> int: - return int(self.ds_config.get("zero_optimization", {}).get("stage", 0)) - - @property - def fp16(self) -> bool: - return bool(self.ds_config.get("fp16", {}).get("enabled")) or False - - @property - def mbs(self) -> int: - assert "train_micro_batch_size_per_gpu" in self.ds_config, ( - "The DSATTrial must be provided with a `ds_config` that contains the" - " key `train_micro_batch_size_per_gpu`" - ) - assert isinstance( - self.ds_config["train_micro_batch_size_per_gpu"], int - ), "The DSATTrial must be provided an `int` value for `train_micro_batch_size_per_gpu`" - return self.ds_config["train_micro_batch_size_per_gpu"] - - @property - def create_and_val_ops(self) -> List[searcher.Operation]: - """ - Returns a list with the searcher.Create and searcher.ValidateAfter operations - needed to initiate and run the specified Trial. - """ - create_op = searcher.Create( - request_id=self.request_id, - hparams=self.hparams, - checkpoint=None, - ) - validate_after_op = searcher.ValidateAfter(request_id=self.request_id, length=self.length) - ops_list = [create_op, validate_after_op] - - return ops_list - - @property - def searcher_metric_val(self) -> Optional[float]: - if self.searcher_metric_name is None: - return None - if isinstance(self.metric, float): - return self.metric - val = self.metric.get(self.searcher_metric_name) - if val is not None: - return float(val) - return val - - -class DSATModelProfileInfoTrial(DSATTrial): - """ - Super class for differentiating the model profiling info run. - """ - - -class DSATTrialTracker: - """Primary stateful object for tracking DeepSpeed Autotune Experiments. - - Holds references to all genereated `DSATTrial` instances, as well as the - `DSATModelProfileInfoTrial` and handles queueing through its `queue` attribute. - Class for organizing DSATTrial instances and retrieving pertinent info. Provides helper - functions for generating the appropriate `DSATModelProfileInfoTrial` and `DSATTrial` instances - with consistent batch sizes and configurations in line with CLI arguments. - """ - - def __init__( - self, - args: argparse.Namespace, - exp_config: Dict[str, Any], - ) -> None: - self.exp_config = exp_config - self.max_trials: int = args.max_trials - self.max_concurrent_trials = args.max_concurrent_trials - self.max_slots: int = args.max_slots - self.model_dir = args.model_dir - self.searcher_metric = args.metric - self.start_profile_step = args.start_profile_step - self.end_profile_step = args.end_profile_step - self.zero_stages = set(args.zero_stages) - - # Derived attributes - self.slots_per_trial: int = self.exp_config["resources"]["slots_per_trial"] - self.hparams: Dict[str, Any] = self.exp_config["hyperparameters"] - - self.smaller_is_better = dsat.smaller_is_better(self.searcher_metric) - - self.model_profile_info_trial: Optional["DSATTrial"] = None - self.num_trials_since_best_result: int = 0 - self.successful_stages: Set[int] = set() - self._all_trials_dict: Dict[uuid.UUID, "DSATTrial"] = {} - self.queue: Deque["DSATTrial"] = collections.deque() - - self._mem_per_gpu_per_stage: Optional[Dict[int, int]] = None - self._approx_max_mbs_per_stage: Optional[Dict[int, int]] = None - - def __len__(self) -> int: - return len(self._all_trials_dict) - - def __getitem__(self, request_id: uuid.UUID) -> DSATTrial: - return self._all_trials_dict[request_id] - - def __iter__(self) -> Iterator[Tuple[uuid.UUID, "DSATTrial"]]: - return iter(self._all_trials_dict.items()) - - def __contains__(self, item: Union[uuid.UUID, DSATTrial]) -> bool: - if isinstance(item, uuid.UUID): - return item in self._all_trials_dict - elif isinstance(item, DSATTrial): - return item in self._all_trials_dict.values() - else: - raise ValueError( - f"Expected a `uuid.UUID` or `DSATTrial` instance, instead received an object of" - f" type {type(item)}" - ) - - def create_trial( - self, - hparams: Dict[str, Any], - search_data: Optional[Any] = None, - parent_trial: Optional[DSATTrial] = None, - ) -> DSATTrial: - """ - Helper function which creates a new `DSATTrial` object of the appropriate length, given the - config, while also enforcing a consistent DS batch size configuration. - """ - # Create a consistent batch size configuration which obeys the DS constraints. - self.enforce_consistent_batch_config(hparams) - - # For some reason, DS (0.8.3) exits in the DeepSpeedEngine.step call when - # DeepSpeedEngine.global_step (initiated at zero) equals end_profile_step + 1, - # with global_step updated *before* this check happens. So, we need to run for - # a length of end_profile_step + 1 to trigger the exit. Presumably an off-by-one error - # on their end. - trial = DSATTrial( - hparams=hparams, - model_dir=self.model_dir, - slots_per_trial=self.slots_per_trial, - length=self.end_profile_step + 1, - parent=parent_trial, - search_data=search_data, - searcher_metric_name=self.searcher_metric, - ) - return trial - - def create_model_profile_info_trial( - self, - ) -> DSATModelProfileInfoTrial: - # Create the special hp dictionary used for the model profile info run. - model_profile_info_hps = copy.deepcopy(self.hparams) - model_profile_info_hps[defaults.OVERWRITE_KEY] = util.merge_dicts( - model_profile_info_hps.get(defaults.OVERWRITE_KEY, {}), - defaults.MODEL_INFO_PROFILE_DS_CONFIG, - ) - self.enforce_consistent_batch_config(model_profile_info_hps) - - model_profile_info_trial = DSATModelProfileInfoTrial( - hparams=model_profile_info_hps, - model_dir=self.model_dir, - slots_per_trial=self.slots_per_trial, - length=1, # Only need a single step. - ) - self.model_profile_info_trial = model_profile_info_trial - return model_profile_info_trial - - def queue_and_register_trial(self, trial: DSATTrial) -> None: - """ - Helper function which both adds the `trial` to the queue and the internal dictionary - tracking all trials. - """ - # Verify that the given trial was not previously completed. - for other_trial in self.completed_trials: - if trial.hparams == other_trial.hparams: - logger.warning( - f"Skipping attempt to queue Trial identical to {other_trial.request_id}" - ) - self._all_trials_dict[trial.request_id] = trial - self.queue.append(trial) - - def enforce_consistent_batch_config(self, hparams: Dict[str, Any]) -> None: - """Enforces a consistent batch size configuration by altering `hparams` in-place.""" - try: - ds_config = dsat.get_ds_config_from_hparams(hparams, self.model_dir) - except FileNotFoundError: - # In case the DS json config was added as an `--include` arg. - ds_config = dsat.get_ds_config_from_hparams(hparams) - batch_size_config = dsat.get_batch_config_from_mbs_gas_and_slots( - ds_config, slots=self.slots_per_trial - ) - hparams[defaults.OVERWRITE_KEY] = util.merge_dicts( - hparams[defaults.OVERWRITE_KEY], batch_size_config - ) - - def update_trial_metric( - self, - trial: DSATTrial, - metric: Union[float, Dict[str, Any]], - ) -> None: - """ - Updates the Trial Tracker after metrics have been reported, attaching the reported metrics - to the `DSATTrial` instnace and updating early-stopping bookkeeping. - """ - trial.metric = metric - - # The model info profiling run's metric will not contain the searcher metric key and should - # not be counted against the early stopping criteria. - if not isinstance(trial, DSATModelProfileInfoTrial): - self.successful_stages.add(trial.stage) - trial_is_best = self.best_trial == trial - if trial_is_best: - self.num_trials_since_best_result = 0 - else: - self.num_trials_since_best_result += 1 - trial.running = False - - def report_trial_early_exit(self, trial: DSATTrial) -> None: - # `self.num_trials_since_best_result` is only incremented after a best trial has been - # established. - if self.best_trial is not None: - self.num_trials_since_best_result += 1 - - trial.error = True - trial.running = False - - def _fetch_model_profile_info_data(self, param_name: str) -> int: - assert ( - self.model_profile_info_trial is not None - ), f"The `DSATModelProfileInfoTrial` must be run before requesting its `{param_name}`" - assert isinstance( - self.model_profile_info_trial.metric, dict - ), "The `DSATModelProfileInfoTrial` must be provided with a metric dictionary" - assert param_name in self.model_profile_info_trial.metric, ( - "The `DSATModelProfileInfoTrial` must be provided with a metric dict that contains the" - f" key `{param_name}`" - ) - assert isinstance( - self.model_profile_info_trial.metric[param_name], int - ), f"The `DSATModelProfileInfoTrial` must be provided an `int` value for `{param_name}`" - return int(self.model_profile_info_trial.metric[param_name]) - - @property - def gpu_mem(self) -> int: - """ - Returns the available GPU memory in bytes according to the `DSATModelProfileInfoTrial` - """ - return self._fetch_model_profile_info_data("gpu_mem") - - @property - def num_params(self) -> int: - """ - Returns the number of params according to the `DSATModelProfileInfoTrial` - """ - return self._fetch_model_profile_info_data("num_params") - - @property - def trainable_num_params(self) -> int: - """ - Returns the number of trainable params according to the `DSATModelProfileInfoTrial` - """ - return self._fetch_model_profile_info_data("trainable_num_params") - - @property - def activation_mem_per_gpu(self) -> int: - """ - Returns the amount of activation memory per gpu in bytes according to - the `DSATModelProfileInfoTrial` - """ - return self._fetch_model_profile_info_data("activation_mem_per_gpu") - - @property - def mem_per_gpu_per_stage(self) -> Dict[int, int]: - """ - Returns the required gpu memory in bytes, per stage, according to whether fp16 training was - used (other low-precision cases not handled). - """ - assert ( - self.model_profile_info_trial is not None - ), "The model profile info Trial must be run before calling this method." - fp16 = self.model_profile_info_trial.fp16 - if self._mem_per_gpu_per_stage is None: - params_mem = self.num_params * (2 if fp16 else 4) - # Gradients must be converted to fp32 to update master weights, so they eventually - # require the same memory regardless of whether mixed-precision is used. - gradients_mem = self.trainable_num_params * 4 - # optimizer_mem assumes Adam, following DS. TODO: don't assume this (MLG-584). - master_params_mem = 4 if fp16 else 0 - momentum_mem = variance_mem = 4 - optimizer_mem = self.trainable_num_params * ( - master_params_mem + momentum_mem + variance_mem - ) - - non_activation_mem_per_gpu_per_stage = { - 0: params_mem + gradients_mem + optimizer_mem, - 1: params_mem + gradients_mem + optimizer_mem // self.slots_per_trial, - 2: params_mem + (gradients_mem + optimizer_mem) // self.slots_per_trial, - 3: (params_mem + gradients_mem + optimizer_mem) // self.slots_per_trial, - } - # In DS there is an mp_size int which can be used for model parallelism and also enters - # the memory computation, but we will not support that feature at the moment. - - mem_per_gpu_per_stage = { - stage: mem + self.activation_mem_per_gpu - for stage, mem in non_activation_mem_per_gpu_per_stage.items() - } - self._mem_per_gpu_per_stage = mem_per_gpu_per_stage - return self._mem_per_gpu_per_stage - - @property - def approx_max_mbs_per_stage(self) -> Dict[int, int]: - """ - Returns the approximate max train_micro_batch_size_per_gpu (mbs) per stage. - """ - if self._approx_max_mbs_per_stage is None: - self._approx_max_mbs_per_stage = { - stage: max((self.gpu_mem - mem) // self.activation_mem_per_gpu, 1) - for stage, mem in self.mem_per_gpu_per_stage.items() - } - return self._approx_max_mbs_per_stage - - def _best_trial_fn(self, trials: Iterable["DSATTrial"]) -> Optional["DSATTrial"]: - trials_with_searcher_metric = [ - trial - for trial in trials - if not isinstance(trial, DSATModelProfileInfoTrial) - and isinstance(trial.metric, dict) - and self.searcher_metric in trial.metric - ] - if not trials_with_searcher_metric: - return None - - min_or_max = min if self.smaller_is_better else max - best_trial = min_or_max( - trials_with_searcher_metric, - key=lambda trial: trial.metric - if isinstance(trial.metric, float) - else float(trial.metric[self.searcher_metric]), - ) - return best_trial - - @property - def best_trials_by_stage(self) -> Dict[int, Optional["DSATTrial"]]: - _best_trials_by_stage: Dict[int, Optional["DSATTrial"]] = {} - for stage in range(4): - trials_to_check = [trial for _, trial in self if trial.stage == stage] - best_trial = self._best_trial_fn(trials_to_check) - _best_trials_by_stage[stage] = best_trial - return _best_trials_by_stage - - @property - def best_trial(self) -> Optional["DSATTrial"]: - best_trial = self._best_trial_fn( - trial for trial in self.best_trials_by_stage.values() if trial is not None - ) - return best_trial - - @property - def running_trials(self) -> List[DSATTrial]: - return [trial for _, trial in self if trial.running] - - @property - def completed_trials(self) -> List[DSATTrial]: - return [trial for _, trial in self if trial.completed] - - @property - def num_running_trials(self) -> int: - return len(self.running_trials) - - @property - def num_completed_trials(self) -> int: - return len(self.completed_trials) - - @property - def max_trials_queued(self) -> bool: - return len(self.queue) >= self.max_trials - - @property - def max_trials_are_running_or_closed(self) -> bool: - return self.num_running_trials + self.num_completed_trials >= self.max_trials - - @property - def should_be_failure(self) -> bool: - model_profile_info_trial_failed = ( - self.model_profile_info_trial is not None and self.model_profile_info_trial.error - ) - every_autotuning_trial_failed = all( - trial.error - for _, trial in self - if trial.completed and not isinstance(trial, DSATModelProfileInfoTrial) - ) - return model_profile_info_trial_failed or every_autotuning_trial_failed - - @property - def can_run_more_trials(self) -> int: - if not self.queue: - return False - if self.max_trials_are_running_or_closed: - return False - if self.num_running_trials >= self.max_concurrent_trials: - return False - if self.max_slots is not None: - occupied_slots = self.num_running_trials * self.slots_per_trial - remaining_slots = self.max_slots - occupied_slots - trials_available_with_remaining_slots = remaining_slots // self.slots_per_trial - return trials_available_with_remaining_slots > 0 - return True - - -class BaseDSATSearchMethod(searcher.SearchMethod): - """Base class for all Determined AI DeepSpeed Autotune searchers. - - Contains two abstract methods: `get_trials_after_validation_completed` and - `get_trials_after_early_exit` which return iterables of `DSATTrial` after their respective - events occur. The `early_stopping_triggered` and `choose_next_trial_from_queue` methods are also - provided with the intention of overwriting for further fine-grained control. The base class - ensures that global constraints such as `max_trials`, `max_concurrent_trials`, and `max_slots` - are respected by all subclasses. The `trial_tracker` attribute (a `DSATTrialTracker` instance) - is the stateful object which tracks results and the queued Trials. - """ - - def __init__(self, args: argparse.Namespace, exp_config: Dict[str, Any]) -> None: - # Storing args so that additional args can be inherited by child classes - self.args = args - self.exp_config = exp_config - self.trial_tracker = DSATTrialTracker(args=args, exp_config=exp_config) - self.rng = np.random.default_rng(seed=args.random_seed) - random.seed(args.random_seed) - - self._tracker_ckpt_path = "trial_tracker.pkl" - self._py_rand_ckpt_path = "py_random_state.pkl" - self._np_rand_ckpt_path = "np_rng.pkl" - - @abc.abstractmethod - def get_trials_after_validation_completed( - self, - searcher_state: searcher.SearcherState, - last_trial: DSATTrial, - metric: Union[float, Dict[str, Any]], - ) -> Iterable[DSATTrial]: - """ - All returned `DSATTrial`s will be `append`-ed to `self.trial_tracker.queue` in the order - they are provided. - """ - pass - - @abc.abstractmethod - def get_trials_after_early_exit( - self, - searcher_state: searcher.SearcherState, - last_trial: DSATTrial, - exited_reason: searcher.ExitedReason, - ) -> Iterable[DSATTrial]: - """ - All returned `DSATTrial`s will be `append`-ed to `self.trial_tracker.queue` in the order - they are provided. - """ - pass - - def choose_next_trial_from_queue(self) -> DSATTrial: - """ - Called whenever resources exist to run an additional Trial. Overwrite if more complex - logic is needed. - """ - - next_trial = self.trial_tracker.queue.popleft() - return next_trial - - def initial_operations( - self, searcher_state: searcher.SearcherState - ) -> List[searcher.Operation]: - """ - Submits the model info profiling run in order to collect model and resources info to - inform the search. - """ - - model_profile_info_trial = self.trial_tracker.create_model_profile_info_trial() - self.trial_tracker.queue_and_register_trial(model_profile_info_trial) - self.trial_tracker.queue.popleft() # Needed for bookkeeping. - ops = model_profile_info_trial.create_and_val_ops - return ops - - def on_trial_created( - self, searcher_state: searcher.SearcherState, request_id: uuid.UUID - ) -> List[searcher.Operation]: - return [] - - def on_validation_completed( - self, - searcher_state: searcher.SearcherState, - request_id: uuid.UUID, - metric: Union[float, Dict[str, Any]], - train_length: int, - ) -> List[searcher.Operation]: - last_trial = self.trial_tracker[request_id] - self.trial_tracker.update_trial_metric(trial=last_trial, metric=metric) - - if isinstance(last_trial, DSATModelProfileInfoTrial): - logger.info(f"Approx. max mbs per stage: {self.trial_tracker.approx_max_mbs_per_stage}") - logger.info(f"Approx. GPU memory per stage: {self.trial_tracker.mem_per_gpu_per_stage}") - logger.info(f"Total GPU memory: {self.trial_tracker.gpu_mem}") - - if not self.trial_tracker.max_trials_queued and not self.should_shutdown(): - new_trials = self.get_trials_after_validation_completed( - searcher_state=searcher_state, - last_trial=last_trial, - metric=metric, - ) - for trial in new_trials: - self.trial_tracker.queue_and_register_trial(trial) - - # All DS AT Trials should be closed after validation. - return [searcher.Close(request_id)] - - def on_trial_exited_early( - self, - searcher_state: searcher.SearcherState, - request_id: uuid.UUID, - exited_reason: searcher.ExitedReason, - ) -> List["searcher.Operation"]: - last_trial = self.trial_tracker[request_id] - self.trial_tracker.report_trial_early_exit(last_trial) - - new_ops_list: List["searcher.Operation"] = [] - if exited_reason != searcher.ExitedReason.ERRORED: - # In case of INVALID_HP or USER_CANCELED, shut down the searcher. - logger.info( - f"Shutting down: unexpected early exit due to {exited_reason}" - f"\nLast trial: {last_trial}, request_id: {request_id}" - ) - new_ops_list.append(searcher.Shutdown(failure=self.trial_tracker.should_be_failure)) - if not self.trial_tracker.max_trials_queued and not self.should_shutdown(): - # ERRORED Trials generally corresponds to OOMs, after which we may want to submit - # follow-on Trials. - new_trials = self.get_trials_after_early_exit( - searcher_state=searcher_state, - last_trial=last_trial, - exited_reason=exited_reason, - ) - for trial in new_trials: - self.trial_tracker.queue_and_register_trial(trial) - self.trial_tracker.queue - - return new_ops_list - - def on_trial_closed( - self, searcher_state: searcher.SearcherState, request_id: uuid.UUID - ) -> List[searcher.Operation]: - new_ops_list: List[searcher.Operation] = [] - if self.should_shutdown(): - if self.trial_tracker.best_trial is not None and self.args.run_full_experiment: - submitted_config = dsat.get_dict_from_yaml_or_json_path(self.args.config_path) - optimal_config = util.merge_dicts( - submitted_config, {"hyperparameters": self.trial_tracker.best_trial.hparams} - ) - # Delete the keys which enforce autotuning code paths - del optimal_config["hyperparameters"][defaults.OVERWRITE_KEY]["autotuning"] - del optimal_config["hyperparameters"][defaults.USE_DSAT_MODE_KEY] - client.create_experiment(optimal_config, self.args.model_dir, self.args.include) - - new_ops_list.append(searcher.Shutdown(failure=self.trial_tracker.should_be_failure)) - else: - while self.trial_tracker.can_run_more_trials: - next_trial = self.choose_next_trial_from_queue() - next_trial.running = True - new_ops_list.extend(next_trial.create_and_val_ops) - - return new_ops_list - - def progress(self, searcher_state: searcher.SearcherState) -> float: - progress = len(searcher_state.trials_closed) / self.trial_tracker.max_trials - return progress - - def save_method_state(self, path: pathlib.Path) -> None: - with path.joinpath(self._tracker_ckpt_path).open("wb") as f: - pickle.dump(self.trial_tracker, f) - with path.joinpath(self._py_rand_ckpt_path).open("wb") as f: - pickle.dump(random.getstate(), f) - with path.joinpath(self._np_rand_ckpt_path).open("wb") as f: - pickle.dump(self.rng, f) - if self.trial_tracker.best_trial is not None: - with path.joinpath("best_ds_config.json").open("w") as ds_config_f: - best_ds_metrics = copy.deepcopy(self.trial_tracker.best_trial.ds_config) - del best_ds_metrics["autotuning"] - json.dump(best_ds_metrics, ds_config_f) - with path.joinpath("best_ds_metrics.json").open("w") as ds_metrics_f: - json.dump(self.trial_tracker.best_trial.metric, ds_metrics_f) - - def load_method_state(self, path: pathlib.Path) -> None: - logger.info("Restoring searcher state from checkpoint.") - with path.joinpath(self._tracker_ckpt_path).open("rb") as f: - self.trial_tracker = cast(DSATTrialTracker, pickle.load(f)) - with path.joinpath(self._py_rand_ckpt_path).open("rb") as f: - py_random_state = pickle.load(f) - random.setstate(py_random_state) - with path.joinpath(self._np_rand_ckpt_path).open("rb") as f: - self.rng = pickle.load(f) - - def should_shutdown(self) -> bool: - """ - Conditions on which to shutdown the search. - """ - if ( - self.trial_tracker.model_profile_info_trial is not None - and self.trial_tracker.model_profile_info_trial.error - ): - logger.info( - "Shutting down: error in model profile info Trial." - " You may need to specify a configuration which can successfully run with" - " `train_micro_batch_size_per_gpu = 1`." - ) - return True - if self.early_stopping_triggered(): - logger.info("Shutting down: early stopping criteria met.") - return True - if self.trial_tracker.num_completed_trials >= self.trial_tracker.max_trials: - logger.info("Shutting down: all Trials completed.") - return True - return False - - def early_stopping_triggered(self) -> bool: - """ - Overwrite to implement search-method-specific early-stopping logic. - """ - return False - - -@dataclasses.dataclass -class DSATSearchData: - """Basic binary-search type data used to guide DS AT.""" - - lo: int - hi: int - - -class RandomDSATSearchMethod(BaseDSATSearchMethod): - """ - Implements a random search through DeepSpeed configuration space with an approximate binary - search on batch sizes. Utilizes aggressive early stopping based on the results of other Trials - and heuristics based on domain knowledge of DeepSpeed. Uses two search-specific arguments: - - Args: - trials_per_random_config: - the maximum number of Trials which will be used to optimize each randomly-generated - configuration - early_stopping: - the maximum number of Trials to run without improving results after a best-found - configuration has been established - """ - - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self.trials_per_random_config = self.args.trials_per_random_config - self.early_stopping: int = self.args.early_stopping - - def get_trials_after_validation_completed( - self, - searcher_state: searcher.SearcherState, - last_trial: DSATTrial, - metric: Optional[Union[float, Dict[str, Any]]] = None, - ) -> List[DSATTrial]: - new_trials = [] - if isinstance(last_trial, DSATModelProfileInfoTrial): - new_trials = self.get_trial_list_after_model_profile_info_run() - else: - new_trials = self.get_trial_list_after_successful_run(last_trial) - - return new_trials - - def get_trials_after_early_exit( - self, - searcher_state: searcher.SearcherState, - last_trial: DSATTrial, - exited_reason: searcher.ExitedReason, - ) -> List[DSATTrial]: - new_trials = [] - - if self.should_stop_lineage(last_trial): - logger.info(f"Killing trial {last_trial.request_id}") - new_trials.append(self.get_random_trial()) - else: - if last_trial.search_data is None: - return new_trials - new_search_data = copy.deepcopy(last_trial.search_data) - new_search_data.hi = last_trial.mbs - 1 - - mbs = self.get_random_mbs_from_search_data(new_search_data) - new_hparams = copy.deepcopy(last_trial.hparams) - new_hparams[defaults.OVERWRITE_KEY]["train_micro_batch_size_per_gpu"] = mbs - - trial = self.trial_tracker.create_trial( - hparams=new_hparams, - search_data=new_search_data, - parent_trial=last_trial, - ) - new_trials.append(trial) - return new_trials - - def choose_next_trial_from_queue(self) -> DSATTrial: - """ - Continually removes Trials whose lineages should be stopped from the front of the queue - while adding their corresponding replacements, finally returning the next Trial which should - be run. - """ - - next_trial = self.trial_tracker.queue.popleft() - while self.should_stop_lineage(next_trial): - self.trial_tracker.queue_and_register_trial(self.get_random_trial()) - next_trial = self.trial_tracker.queue.popleft() - - return next_trial - - def get_trial_list_after_model_profile_info_run(self) -> List[DSATTrial]: - new_trials = [] - concurrent_trials = self.args.max_concurrent_trials - if self.args.max_slots is not None: - concurrent_trials_from_slots = self.args.max_slots // self.trial_tracker.slots_per_trial - concurrent_trials = min(concurrent_trials, concurrent_trials_from_slots) - for _ in range(concurrent_trials): - trial = self.get_random_trial() - new_trials.append(trial) - return new_trials - - def get_trial_list_after_successful_run( - self, - last_trial: DSATTrial, - ) -> List[DSATTrial]: - if self.should_stop_lineage(trial=last_trial) or last_trial.search_data is None: - return [self.get_random_trial()] - - new_search_data = copy.deepcopy(last_trial.search_data) - new_search_data.lo = last_trial.mbs + 1 - # It is possible lo > hi in the case where initial soft ceiling computation was innaccurate - # in which case we double hi. - if new_search_data.lo > new_search_data.hi: - new_search_data.hi *= 2 - - mbs = self.get_random_mbs_from_search_data(new_search_data) - - new_hparams = copy.deepcopy(last_trial.hparams) - new_hparams[defaults.OVERWRITE_KEY]["train_micro_batch_size_per_gpu"] = mbs - trial = self.trial_tracker.create_trial( - hparams=new_hparams, - search_data=new_search_data, - parent_trial=last_trial, - ) - return [trial] - - def should_stop_lineage(self, trial: DSATTrial) -> bool: - # General conditions - assert trial.search_data is not None - failed_on_min_mbs = trial.error and trial.search_data and trial.mbs <= trial.search_data.lo - - exceeded_trials_per_random_config_limit = ( - trial.num_completed_trials_in_lineage >= self.trials_per_random_config - ) - - # DS domain knowledge: if stages 1 or 2 run successfully, there is no need to use stage 3. - stage_one_or_two_successful = {1, 2} & self.trial_tracker.successful_stages - should_stop_this_stage_3_trial = trial.stage == 3 and stage_one_or_two_successful - - # Check if other same-stage trials have successfully run with larger batch sizes than this - # lineage can possibly run. - - other_configs_run_larger_batch_sizes = ( - trial.error_in_direct_history - and trial.search_data - and any( - other_trial.mbs >= trial.search_data.hi - for _, other_trial in self.trial_tracker - if other_trial.stage == trial.stage and other_trial.searcher_metric_val is not None - ) - ) - - if ( - failed_on_min_mbs - or exceeded_trials_per_random_config_limit - or should_stop_this_stage_3_trial - or other_configs_run_larger_batch_sizes - ): - return True - - return False - - def get_random_mbs_from_search_data(self, search_data: DSATSearchData) -> int: - """ - Randomly choose a mbs given the `search_data` bounds. Random choice covers a larger search - volume than simply choosing the midpoint. Draws from a binomial distribution, to keep the - results still somewhat focused near the midpoint. - """ - mbs: int = search_data.lo + self.rng.binomial(search_data.hi - search_data.lo, 0.5) - return mbs - - def get_random_hparams_and_search_data( - self, zero_stage: int - ) -> Tuple[Dict[str, Any], DSATSearchData]: - zero_optim_config = dsat.get_random_zero_optim_config(zero_stage) - new_hparams = copy.deepcopy(self.trial_tracker.hparams) - new_hparams[defaults.OVERWRITE_KEY] = util.merge_dicts( - new_hparams.get(defaults.OVERWRITE_KEY, {}), - {"zero_optimization": zero_optim_config}, - ) - - # If a best trial has been established for the given stage, use its search data bounds to - # choose a better starting point. - best_trial_for_stage = self.trial_tracker.best_trials_by_stage[zero_stage] - - if best_trial_for_stage is not None and best_trial_for_stage.search_data is not None: - new_search_data = copy.deepcopy(best_trial_for_stage.search_data) - # Update the floor to one greater than the mbs used and raise the ceiling to be - # the maximum between the largest mbs trial of this stage which was successful, the - # best trial's ceiling, and twice as large as the floor. - new_search_data.lo = best_trial_for_stage.mbs + 1 - largest_successful_batch_size_for_stage = max( - t.mbs - for t in self.trial_tracker.completed_trials - if t.stage == best_trial_for_stage.stage - and isinstance(t.metric, dict) - and t.metric.get(self.trial_tracker.searcher_metric) is not None - ) - new_search_data.hi = max( - largest_successful_batch_size_for_stage, new_search_data.hi, 2 * new_search_data.lo - ) - # Otherwise choose the corresponding search data based on approximate computations - else: - random_zero_stage_max_mbs = self.trial_tracker.approx_max_mbs_per_stage[zero_stage] - new_search_data = DSATSearchData(lo=1, hi=random_zero_stage_max_mbs) - - # Randomly choose the actual batch size. - mbs = self.get_random_mbs_from_search_data(new_search_data) - new_hparams[defaults.OVERWRITE_KEY]["train_micro_batch_size_per_gpu"] = mbs - return new_hparams, new_search_data - - def get_random_trial(self) -> DSATTrial: - # Choose the stage randomly from user provided stages, after some performance filtering. - # If stage one or two was successful, don't continue with stage 3. - stage_one_and_two = {1, 2} - successful_one_or_two_stages = stage_one_and_two & self.trial_tracker.successful_stages - filtered_zero_stages = successful_one_or_two_stages & self.trial_tracker.zero_stages - if filtered_zero_stages: - zero_stage = random.choice(list(filtered_zero_stages)) - else: - zero_stage = random.choice(list(self.trial_tracker.zero_stages)) - - hparams, search_data = self.get_random_hparams_and_search_data(zero_stage) - random_trial = self.trial_tracker.create_trial(hparams=hparams, search_data=search_data) - return random_trial - - def early_stopping_triggered(self) -> bool: - if self.early_stopping is None: - return False - return self.trial_tracker.num_trials_since_best_result >= self.early_stopping - - -class BinarySearchDSATSearchMethod(BaseDSATSearchMethod): - """Basic binary search over randomly generated configurations. - - Randomly generates as many DeepSpeed configurations as can be concurrently tested, per the - CLI arguments, and performs a binary search over batch size. Each such lineage runs to - completion or until the `max_trials` limit is hit. Lineages whose binary search ends before - `max_trials` is hit are replaced with newly generated random configurations. One search-specific - argument: - - Args: - search_range_factor: - adjusts the initial binary search range by raising the ceiling by a factor of - `search_range_factor` - """ - - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self.search_range_factor = self.args.search_range_factor - - def get_trials_after_validation_completed( - self, - searcher_state: searcher.SearcherState, - last_trial: DSATTrial, - metric: Optional[Union[float, Dict[str, Any]]] = None, - ) -> List[DSATTrial]: - new_trials = [] - if isinstance(last_trial, DSATModelProfileInfoTrial): - new_trials = self.get_trial_list_after_model_profile_info_run() - else: - new_trials = self.get_trial_list_after_successful_run(last_trial) - - return new_trials - - def get_trials_after_early_exit( - self, - searcher_state: searcher.SearcherState, - last_trial: DSATTrial, - exited_reason: searcher.ExitedReason, - ) -> List[DSATTrial]: - new_trials = [] - if last_trial.search_data is None: - return [self.get_random_trial()] - new_search_data = copy.deepcopy(last_trial.search_data) - new_search_data.hi = last_trial.mbs - 1 - if new_search_data.lo > new_search_data.hi: - return [self.get_random_trial()] - - mbs = (new_search_data.hi + new_search_data.lo) // 2 - new_hparams = copy.deepcopy(last_trial.hparams) - new_hparams[defaults.OVERWRITE_KEY]["train_micro_batch_size_per_gpu"] = mbs - - trial = self.trial_tracker.create_trial( - hparams=new_hparams, - search_data=new_search_data, - parent_trial=last_trial, - ) - new_trials.append(trial) - return new_trials - - def get_trial_list_after_model_profile_info_run(self) -> List[DSATTrial]: - new_trials = [] - concurrent_trials = self.args.max_concurrent_trials - if self.args.max_slots is not None: - concurrent_trials_from_slots = self.args.max_slots // self.trial_tracker.slots_per_trial - concurrent_trials = min(concurrent_trials, concurrent_trials_from_slots) - for _ in range(concurrent_trials): - trial = self.get_random_trial() - new_trials.append(trial) - return new_trials - - def get_trial_list_after_successful_run( - self, - last_trial: DSATTrial, - ) -> List[DSATTrial]: - if last_trial.search_data is None: - return [self.get_random_trial()] - new_search_data = copy.deepcopy(last_trial.search_data) - new_search_data.lo = last_trial.mbs + 1 - if new_search_data.lo > new_search_data.hi: - return [self.get_random_trial()] - - mbs = (new_search_data.hi + new_search_data.lo) // 2 - new_hparams = copy.deepcopy(last_trial.hparams) - new_hparams[defaults.OVERWRITE_KEY]["train_micro_batch_size_per_gpu"] = mbs - trial = self.trial_tracker.create_trial( - hparams=new_hparams, - search_data=new_search_data, - parent_trial=last_trial, - ) - return [trial] - - def get_random_hparams_and_search_data( - self, zero_stage: int - ) -> Tuple[Dict[str, Any], DSATSearchData]: - zero_optim_config = dsat.get_random_zero_optim_config(zero_stage) - new_hparams = copy.deepcopy(self.trial_tracker.hparams) - new_hparams[defaults.OVERWRITE_KEY] = util.merge_dicts( - new_hparams.get(defaults.OVERWRITE_KEY, {}), - {"zero_optimization": zero_optim_config}, - ) - - random_zero_stage_max_mbs = self.trial_tracker.approx_max_mbs_per_stage[zero_stage] - - # The default `search_range_factor = 1.` value makes the ceiling coincide with - # the predicted max mbs, but we give the user a handle to alter this range as needed. - lo = 1 - hi = int(self.search_range_factor * random_zero_stage_max_mbs) - hi = max(hi, lo) - new_search_data = DSATSearchData(lo=1, hi=hi) - - mbs = (new_search_data.hi + new_search_data.lo) // 2 - new_hparams[defaults.OVERWRITE_KEY]["train_micro_batch_size_per_gpu"] = mbs - return new_hparams, new_search_data - - def get_random_trial(self) -> DSATTrial: - # Choose the stage randomly from user provided stages, after some performance filtering. - # If stage one or two was successful, don't continue with stage 3. - zero_stage = random.choice(list(self.trial_tracker.zero_stages)) - hparams, search_data = self.get_random_hparams_and_search_data(zero_stage) - random_trial = self.trial_tracker.create_trial(hparams=hparams, search_data=search_data) - return random_trial - - -class ASHADSATSearchData(DSATSearchData): - def __init__(self, curr_rung: int, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self.curr_rung = curr_rung - - -class ASHADSATSearchMethod(BaseDSATSearchMethod): - """Asynchronous Successive Halving Algorithm (ASHA) - - Adaptive search through randomly-generated DeepSpeed configurations which tunes the batch size - through a binary search and uses the number of Trials in this search as the finite-resource of - ASHA. Search-specific arguments: - - Args: - asha_early_stopping: - ASHA early stopping parameter (`s` in arxiv:1810.05934) - max_rungs: - Maximum number of rungs - min_binary_search_trials: - Minimum number of binary search Trials to run per random configuration - divisor: - ASHA divisor parameter (`eta` in arxiv:1810.05934), controlling the growth in - resources and population thinning across rungs - search_range_factor: - adjusts the initial binary search range by raising the ceiling by a factor of - `search_range_factor` - """ - - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self.divisor: int = self.args.divisor - self.max_rungs: int = self.args.max_rungs - self.min_binary_search_trials: int = self.args.min_binary_search_trials - self.asha_early_stopping: int = self.args.asha_early_stopping - self.search_range_factor: float = self.args.search_range_factor - - def get_trials_after_validation_completed( - self, - searcher_state: searcher.SearcherState, - last_trial: DSATTrial, - metric: Optional[Union[float, Dict[str, Any]]] = None, - ) -> List[DSATTrial]: - if isinstance(last_trial, DSATModelProfileInfoTrial): - new_trials = self.get_trial_list_after_model_profile_info_run() - else: - new_trials = [self.get_next_trial(last_trial)] - return new_trials - - def get_trials_after_early_exit( - self, - searcher_state: searcher.SearcherState, - last_trial: DSATTrial, - exited_reason: searcher.ExitedReason, - ) -> List[DSATTrial]: - new_trial = [self.get_next_trial(last_trial)] - return new_trial - - def choose_next_trial_from_queue(self) -> DSATTrial: - """ - Schedule the trial with the largest `search_data.curr_rung` value. - """ - - def curr_rung_key(trial: DSATTrial) -> int: - assert trial.search_data - assert isinstance(trial.search_data, ASHADSATSearchData) - return trial.search_data.curr_rung - - highest_rung_trial = max(self.trial_tracker.queue, key=curr_rung_key) - # If there are multiple such trials, choose the one with the longest lineage so that - # trials are promoted more quickly. - assert highest_rung_trial.search_data - assert isinstance(highest_rung_trial.search_data, ASHADSATSearchData) - highest_curr_rung = highest_rung_trial.search_data.curr_rung - all_highest_curr_rung_trials_in_queue = [ - t - for t in self.trial_tracker.queue - if t.search_data - and isinstance(t.search_data, ASHADSATSearchData) - and t.search_data.curr_rung == highest_curr_rung - ] - - next_trial = max( - all_highest_curr_rung_trials_in_queue, key=lambda t: t.num_completed_trials_in_lineage - ) - self.trial_tracker.queue.remove(next_trial) - - return next_trial - - def get_next_trial(self, last_trial: DSATTrial) -> DSATTrial: - next_trial = None - assert last_trial.search_data is not None and isinstance( - last_trial.search_data, ASHADSATSearchData - ) - if not self.lineage_completed_rung(last_trial, last_trial.search_data.curr_rung): - next_trial = self.get_next_trial_in_lineage(last_trial) - if next_trial is None: - next_lineage = self.get_next_promotable_lineage() - if next_lineage is not None: - next_trial = self.get_next_trial_in_lineage(next_lineage) - if next_trial is not None: - assert next_trial.search_data - assert isinstance(next_trial.search_data, ASHADSATSearchData) - # Promote to next rung - next_trial.search_data.curr_rung += 1 - if next_trial is None: - next_trial = self.get_random_trial() - return next_trial - - def get_trial_list_after_model_profile_info_run( - self, - ) -> List[DSATTrial]: - new_trials = [] - max_num_trials = min( - self.trial_tracker.max_concurrent_trials, self.trial_tracker.max_trials - ) - if self.trial_tracker.max_slots: - max_trials_by_slot = self.trial_tracker.max_slots // self.trial_tracker.slots_per_trial - max_num_trials = min(max_num_trials, max_trials_by_slot) - for _ in range(max_num_trials): - new_trials.append(self.get_random_trial()) - return new_trials - - @property - def rungs(self) -> Dict[int, List[DSATTrial]]: - """ - A dictionary of lists of the latest trials in each lineage which have completed the - specified rung. - """ - rungs = collections.defaultdict(list) - for root in self.get_all_latest_trials_in_lineages(): - assert isinstance(root.search_data, ASHADSATSearchData) - rung_idx = 0 - while self.lineage_completed_rung(root, rung_idx): - rungs[rung_idx].append(root) - rung_idx += 1 - return rungs - - def get_all_latest_trials_in_lineages(self) -> List[DSATTrial]: - """ - Returns a list of the latest trials in each lineage. - """ - lineage_root_set = [ - trial - for _, trial in self.trial_tracker - if not isinstance(trial, DSATModelProfileInfoTrial) - and self.get_latest_trial_in_lineage(trial) == trial - and trial.search_data is not None - and isinstance(trial.search_data, ASHADSATSearchData) - ] - return lineage_root_set - - def lineage_completed_rung(self, trial: DSATTrial, rung_idx: int) -> bool: - assert trial.search_data - assert isinstance(trial.search_data, ASHADSATSearchData) - latest_trial = self.get_latest_trial_in_lineage(trial) - assert latest_trial.search_data - assert isinstance(latest_trial.search_data, ASHADSATSearchData) - if latest_trial.search_data.curr_rung > rung_idx: - return True - if trial.num_completed_trials_in_lineage >= self.max_trials_for_rung_idx(rung_idx): - return True - # Also need to cover the cases where a binary search stopped before using all available - # resources (trials) in its current rung. Only need to check for curr_rung = rung_idx. - if latest_trial.search_data.curr_rung == rung_idx: - failed_on_min_mbs = ( - latest_trial.error and latest_trial.mbs == latest_trial.search_data.lo - ) - trivial_search_data = latest_trial.search_data.hi == latest_trial.search_data.lo - if trivial_search_data or failed_on_min_mbs: - return True - return False - - def get_next_promotable_lineage(self) -> Optional[DSATTrial]: - # Cannot promote from the top rung (rung_idx == self.max_rung - 1) - for rung_idx in reversed(range(self.max_rungs - 1)): - next_promotable_trial = self.get_next_promotable_lineage_in_rung(rung_idx) - if next_promotable_trial is not None: - return next_promotable_trial - return None - - def get_next_promotable_lineage_in_rung(self, rung_idx: int) -> Optional[DSATTrial]: - """ - Returns the latest trial in the next promotable lineage in the given rung. - """ - top_trials = self.get_top_lineages_in_rung(rung_idx) - for trial in top_trials: - latest_trial = self.get_latest_trial_in_lineage(trial) - assert latest_trial.search_data - assert isinstance(latest_trial.search_data, ASHADSATSearchData) - already_promoted = latest_trial.search_data.curr_rung > rung_idx - if not already_promoted: - return self.get_latest_trial_in_lineage(trial) - return None - - def get_top_lineages_in_rung(self, rung_idx: int) -> List[DSATTrial]: - """ - Returns the best trial in each of the top 1 / divisor fraction of lineages from the given - rung, per the ASHA paper. - """ - completed_lineages_in_rung = self.rungs[rung_idx] - k = len(completed_lineages_in_rung) // self.divisor - if not k: - return [] - best_trials: List[DSATTrial] = [] - for lin in completed_lineages_in_rung: - best_trial = self.get_best_trial_in_lineage(lin, max_rung_idx=rung_idx) - if best_trial is not None: - best_trials.append(best_trial) - reverse = not self.trial_tracker.smaller_is_better - best_trials.sort( - key=lambda t: t.searcher_metric_val is not None and t.searcher_metric_val, - reverse=reverse, - ) - return best_trials[:k] - - def get_best_trial_in_lineage( - self, trial: DSATTrial, max_rung_idx: Optional[int] = None - ) -> Optional[DSATTrial]: - trials_with_metrics = [t for t in trial.lineage_set if t.searcher_metric_val is not None] - if max_rung_idx is not None: - filtered_trials_with_metrics: List[DSATTrial] = [] - for t in trials_with_metrics: - assert t.search_data - assert isinstance(t.search_data, ASHADSATSearchData) - if t.search_data.curr_rung <= max_rung_idx: - filtered_trials_with_metrics.append(t) - trials_with_metrics = filtered_trials_with_metrics - if not trials_with_metrics: - return None - min_or_max = min if self.trial_tracker.smaller_is_better else max - return min_or_max( - trials_with_metrics, - key=lambda t: t.searcher_metric_val is not None and t.searcher_metric_val, - ) - - def get_latest_trial_in_lineage(self, trial: DSATTrial) -> DSATTrial: - while trial.children: - assert len(trial.children) <= 1 # Sanity check - trial = next(iter(trial.children)) - return trial - - def get_next_trial_in_lineage(self, trial: DSATTrial) -> Optional[DSATTrial]: - latest_trial = self.get_latest_trial_in_lineage(trial) - assert latest_trial.search_data is not None - new_search_data = copy.deepcopy(latest_trial.search_data) - if latest_trial.searcher_metric_val is not None: - new_search_data.lo = latest_trial.mbs + 1 - else: - new_search_data.hi = latest_trial.mbs - 1 - - if new_search_data.hi < new_search_data.lo: - return None - - mbs = (new_search_data.hi + new_search_data.lo) // 2 - - new_hparams = copy.deepcopy(latest_trial.hparams) - new_hparams[defaults.OVERWRITE_KEY]["train_micro_batch_size_per_gpu"] = mbs - next_trial = self.trial_tracker.create_trial( - hparams=new_hparams, - search_data=new_search_data, - parent_trial=latest_trial, - ) - return next_trial - - def max_trials_for_rung_idx(self, rung_idx: int) -> int: - max_trials: int = self.min_binary_search_trials * self.divisor ** ( - self.asha_early_stopping + rung_idx - ) - return max_trials - - def get_random_hparams_and_search_data( - self, zero_stage: int - ) -> Tuple[Dict[str, Any], ASHADSATSearchData]: - zero_optim_config = dsat.get_random_zero_optim_config(zero_stage) - new_hparams = copy.deepcopy(self.trial_tracker.hparams) - new_hparams[defaults.OVERWRITE_KEY] = util.merge_dicts( - new_hparams.get(defaults.OVERWRITE_KEY, {}), - {"zero_optimization": zero_optim_config}, - ) - - random_zero_stage_max_mbs = self.trial_tracker.approx_max_mbs_per_stage[zero_stage] - lo = 1 - hi = int(random_zero_stage_max_mbs * self.search_range_factor) - hi = max(hi, lo) - new_search_data = ASHADSATSearchData(lo=1, hi=hi, curr_rung=0) - - # Randomly choose the actual batch size. - mbs = (new_search_data.hi + new_search_data.lo) // 2 - new_hparams[defaults.OVERWRITE_KEY]["train_micro_batch_size_per_gpu"] = mbs - return new_hparams, new_search_data - - def get_random_trial(self) -> DSATTrial: - zero_stage = random.choice(list(self.trial_tracker.zero_stages)) - hparams, search_data = self.get_random_hparams_and_search_data(zero_stage) - random_trial = self.trial_tracker.create_trial(hparams=hparams, search_data=search_data) - return random_trial - - -class TestDSATSearchMethod(BaseDSATSearchMethod): - """Searcher for basic testing purposes. - - Submits Trials with linearly increasing batch sizes, from 2 up to max_trials - """ - - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - - def get_trials_after_validation_completed( - self, - searcher_state: searcher.SearcherState, - last_trial: DSATTrial, - metric: Optional[Union[float, Dict[str, Any]]] = None, - ) -> List[DSATTrial]: - new_trials = [] - if isinstance(last_trial, DSATModelProfileInfoTrial): - # Delete special DS keys which force a model profiling info run. - hparams_without_profile_info_keys = last_trial.hparams - del hparams_without_profile_info_keys[defaults.OVERWRITE_KEY]["autotuning"][ - "model_info" - ] - del hparams_without_profile_info_keys[defaults.OVERWRITE_KEY]["autotuning"][ - "model_info_path" - ] - for tmbs in range(2, self.trial_tracker.max_trials + 1): - hparams = copy.deepcopy(hparams_without_profile_info_keys) - hparams[defaults.OVERWRITE_KEY]["train_micro_batch_size_per_gpu"] = tmbs - # Choose a random zero stage: - hparams[defaults.OVERWRITE_KEY]["zero_optimization"] = { - "stage": random.choice(list(self.args.zero_stages)) - } - trial = self.trial_tracker.create_trial( - hparams=hparams, - search_data=None, - parent_trial=None, - ) - new_trials.append(trial) - return new_trials - - def get_trials_after_early_exit( - self, - searcher_state: searcher.SearcherState, - last_trial: DSATTrial, - exited_reason: searcher.ExitedReason, - ) -> List[DSATTrial]: - return [] diff --git a/harness/determined/pytorch/dsat/_run_dsat.py b/harness/determined/pytorch/dsat/_run_dsat.py deleted file mode 100644 index 93ab84b55b1..00000000000 --- a/harness/determined/pytorch/dsat/_run_dsat.py +++ /dev/null @@ -1,91 +0,0 @@ -import argparse -import logging -import os -import pathlib -import pickle -from typing import Any, Dict, Type - -import determined as det -from determined import searcher, util -from determined.pytorch import dsat -from determined.pytorch.dsat import defaults - - -def get_search_method_class(method_string: str) -> Type[dsat.BaseDSATSearchMethod]: - string_to_class_map = { - "binary": dsat.BinarySearchDSATSearchMethod, - "random": dsat.RandomDSATSearchMethod, - "asha": dsat.ASHADSATSearchMethod, - "_test": dsat.TestDSATSearchMethod, - } - if method_string not in string_to_class_map: - raise ValueError( - f"`method_string` must be one of {list(string_to_class_map)}, not {method_string}" - ) - return string_to_class_map[method_string] - - -def get_custom_dsat_exp_conf_from_args( - args: argparse.Namespace, -) -> Dict[str, Any]: - """ - Helper function which alters the user-submitted configuration and args into a configuration - for the DS AT custom searchers. - """ - exp_config = dsat.get_dict_from_yaml_or_json_path( - args.config_path - ) # add the search runner's experiment id to the description of the corresonding Trial - additional_description = f"(#{args.experiment_id}) generated" - existing_description = exp_config.get("description") - if existing_description is not None: - exp_config["description"] = f"{additional_description} - {exp_config['description']}" - else: - exp_config["description"] = additional_description - - # Overwrite the searcher section. - exp_config["searcher"] = { - "name": "custom", - "metric": args.metric, - "smaller_is_better": dsat.smaller_is_better(args.metric), - } - # Add all necessary autotuning keys from dsat.defaults and user-supplied args. - autotuning_config = defaults.AUTOTUNING_DICT - autotuning_config["autotuning"]["start_profile_step"] = args.start_profile_step - autotuning_config["autotuning"]["end_profile_step"] = args.end_profile_step - - exp_config["hyperparameters"] = util.merge_dicts( - exp_config["hyperparameters"], {defaults.OVERWRITE_KEY: autotuning_config} - ) - # Add an internal key to the HP dict which enables the DSAT code path for Trial classes. - exp_config["hyperparameters"][defaults.USE_DSAT_MODE_KEY] = True - - return exp_config - - -def main(core_context: det.core.Context) -> None: - with pathlib.Path(defaults.ARGS_PKL_PATH).open("rb") as f: - args = pickle.load(f) - # On-cluster, the relative paths to the below files just come from the base names. - args.config_path = os.path.basename(args.config_path) - args.model_dir = os.path.basename(args.model_dir) - args.include = [os.path.basename(p) for p in args.include] if args.include is not None else [] - cluster_info = det.get_cluster_info() - assert ( - cluster_info and cluster_info._trial_info - ), "Could not find `cluster_info`, the DSAT module must be run on a Determined Cluster" - args.experiment_id = cluster_info._trial_info.experiment_id - - exp_config = get_custom_dsat_exp_conf_from_args(args) - - search_method_class = get_search_method_class(args.search_method) - search_method = search_method_class(args=args, exp_config=exp_config) - - search_runner = searcher.RemoteSearchRunner(search_method, context=core_context) - - search_runner.run(exp_config=exp_config, model_dir=args.model_dir, includes=args.include) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.INFO, format=det.LOG_FORMAT) - with det.core.init() as core_context: - main(core_context) diff --git a/harness/determined/pytorch/dsat/_utils.py b/harness/determined/pytorch/dsat/_utils.py deleted file mode 100644 index 1fad1046ce9..00000000000 --- a/harness/determined/pytorch/dsat/_utils.py +++ /dev/null @@ -1,530 +0,0 @@ -import argparse -import contextlib -import copy -import json -import logging -import pathlib -import random -from typing import Any, Dict, Generator, List, Optional, Union - -import filelock -import torch - -import determined as det -from determined import util as det_util -from determined.common import util -from determined.pytorch.dsat import defaults - -logger = logging.getLogger("determined.pytorch") - -CURR_DIR = pathlib.Path(".") - - -def get_base_parser() -> argparse.ArgumentParser: - base_parser = argparse.ArgumentParser(add_help=False) - base_parser.add_argument("config_path", help="experiment config file (.yaml)") - base_parser.add_argument("model_dir", help="file or directory containing model definition") - base_parser.add_argument( - "-i", - "--include", - type=str, - nargs="+", - help="additional files to copy into the task container", - ) - - base_parser.add_argument( - "-mt", - "--max-trials", - type=int, - default=defaults.AUTOTUNING_ARG_DEFAULTS["max-trials"], - help="Maximum number of trials to run, including the model profile info trial", - ) - base_parser.add_argument( - "-ms", - "--max-slots", - type=int, - help="Maximum number of slots to use concurrently", - ) - base_parser.add_argument( - "-mct", - "--max-concurrent-trials", - type=int, - default=defaults.AUTOTUNING_ARG_DEFAULTS["max-concurrent-trials"], - help="Maximum number of trials to run concurrently", - ) - base_parser.add_argument( - "-m", - "--metric", - type=str, - default=defaults.AUTOTUNING_ARG_DEFAULTS["metric"], - choices=defaults.SMALLER_IS_BETTER_METRICS + defaults.LARGER_IS_BETTER_METRICS, - ) - base_parser.add_argument( - "--run-full-experiment", - action="store_true", - help="Run full-length experiment using best-found configuration after dsat completes", - ) - base_parser.add_argument( - "-z", - "--zero-stages", - type=int, - nargs="+", - default=defaults.AUTOTUNING_ARG_DEFAULTS["zero-stages"], - choices=list(range(4)), - help="Space-separated list of zero stages to search over", - ) - base_parser.add_argument( - "--start-profile-step", - type=int, - default=defaults.AUTOTUNING_ARG_DEFAULTS["start-profile-step"], - help="Step on which to start profiling", - ) - base_parser.add_argument( - "--end-profile-step", - type=int, - default=defaults.AUTOTUNING_ARG_DEFAULTS["end-profile-step"], - help="Step on which to stop profiling", - ) - base_parser.add_argument( - "-r", - "--random-seed", - type=int, - default=defaults.AUTOTUNING_ARG_DEFAULTS["random-seed"], - help="Sets all random seeds", - ) - base_parser.add_argument( - "--search-runner-config", - type=str, - help="Path to an alternative search runner configuration file. For advanced use cases", - ) - base_parser.add_argument( - "--max-search-runner-restarts", type=int, default=5, help="Maximum search runner restarts" - ) - - return base_parser - - -def get_full_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser( - prog="Determined AI DeepSpeed Autotune", - ) - subparsers = parser.add_subparsers(required=True, dest="search_method") - base_parser = get_base_parser() - - subparsers.add_parser( - "_test", - parents=[base_parser], - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - - random_subparser = subparsers.add_parser( - "random", - parents=[base_parser], - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - random_subparser.add_argument( - "--trials-per-random-config", - type=int, - default=defaults.AUTOTUNING_ARG_DEFAULTS["trials-per-random-config"], - help="Maximum number of trials to run per random config", - ) - random_subparser.add_argument( - "--early-stopping", - type=int, - help="Terminates the search if a new best config not found in last `early-stopping` trials", - ) - - binary_subparser = subparsers.add_parser( - "binary", - parents=[base_parser], - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - search_range_factor_help = ( - "Expands the initial search range by a factor of `search-range-factor`" - ) - binary_subparser.add_argument( - "--search-range-factor", - type=float, - default=defaults.AUTOTUNING_ARG_DEFAULTS["search-range-factor"], - help=search_range_factor_help, - ) - - asha_subparser = subparsers.add_parser( - "asha", - parents=[base_parser], - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - asha_subparser.add_argument( - "--max-rungs", - default=defaults.AUTOTUNING_ARG_DEFAULTS["max-rungs"], - help="Maximum rungs to use in the ASHA algorithm", - ) - asha_subparser.add_argument( - "--min-binary-search-trials", - default=defaults.AUTOTUNING_ARG_DEFAULTS["min-binary-search-trials"], - help="Minimum number of binary search trials to run per random configuration", - ) - asha_subparser.add_argument( - "--asha-early-stopping", - default=defaults.AUTOTUNING_ARG_DEFAULTS["asha-early-stopping"], - help="ASHA early stopping parameter (`s` in arxiv:1810.05934)", - ) - asha_subparser.add_argument( - "--divisor", - default=defaults.AUTOTUNING_ARG_DEFAULTS["divisor"], - help="ASHA divisor parameter (`eta` in arxiv:1810.05934)", - ) - asha_subparser.add_argument( - "--search-range-factor", - type=float, - default=defaults.AUTOTUNING_ARG_DEFAULTS["search-range-factor"], - help=search_range_factor_help, - ) - - return parser - - -def smaller_is_better(metric: str) -> bool: - if metric in defaults.SMALLER_IS_BETTER_METRICS: - return True - elif metric in defaults.LARGER_IS_BETTER_METRICS: - return False - else: - valid_metrics = defaults.SMALLER_IS_BETTER_METRICS + defaults.LARGER_IS_BETTER_METRICS - raise ValueError(f"metric must be one of {valid_metrics}, not {metric}") - - -def get_split_entrypoint(submitted_entrypoint: Union[List[str], str]) -> List[str]: - # The entrypoint may be a string or list of strings. Strip all white space from each entry and - # convert to a list, in either case. - if isinstance(submitted_entrypoint, str): - split_entrypoint = submitted_entrypoint.split(" ") - elif isinstance(submitted_entrypoint, list): - # Join and re-split to remove any possile white space. - # submitted_entrypoint: List[str] - str_entrypoint: str = " ".join(submitted_entrypoint) - split_entrypoint = str_entrypoint.split(" ") - else: - raise ValueError( - f"Expected a string or list for an entrypoint, but received " - f"{type(submitted_entrypoint)}" - ) - return [s.strip() for s in split_entrypoint if s.strip()] - - -def get_search_runner_config_from_args(args: argparse.Namespace) -> Dict[str, Any]: - if args.search_runner_config is not None: - submitted_search_runner_config = get_dict_from_yaml_or_json_path(args.search_runner_config) - return submitted_search_runner_config - - submitted_exp_config_dict = get_dict_from_yaml_or_json_path(args.config_path) - assert "deepspeed_config" in submitted_exp_config_dict["hyperparameters"], ( - "DS AT requires a `hyperparameters.deepspeed_config` key which points " - "to the deepspeed config json file" - ) - - # Also sanity check that if a --deepspeed_config (or in the case of HF - # --deepspeed) arg is passed in, both configs match. Probably some gotchas here because - # --deepspeed is also a boolean arg for vanilla deepspeed. - possible_config_flags = ("--deepspeed", "--deepspeed_config") - - submitted_entrypoint: Union[List[str], str] = submitted_exp_config_dict["entrypoint"] - split_entrypoint = get_split_entrypoint(submitted_entrypoint) - - for idx in range(len(split_entrypoint) - 1): - curr_arg, next_arg = split_entrypoint[idx : idx + 2] - next_arg_is_not_a_flag = next_arg != "-" - if curr_arg in possible_config_flags and next_arg_is_not_a_flag: - entrypoint_deepspeed_config = next_arg - hp_deepspeed_config = submitted_exp_config_dict["hyperparameters"]["deepspeed_config"] - if entrypoint_deepspeed_config != hp_deepspeed_config: - raise ValueError( - f"The deepspeed config path in the `hyperparameters` section, " - f"{hp_deepspeed_config}, does not match the path in the entrypoint, " - f"{entrypoint_deepspeed_config}." - ) - - default_search_runner_config = defaults.DEFAULT_SEARCH_RUNNER_CONFIG - if args.max_search_runner_restarts is not None: - default_search_runner_config["max_restarts"] = args.max_search_runner_restarts - # Merge with the submitted experiment config so that the search runner shares the project, - # workspace, etc. - search_runner_config = det_util.merge_dicts( - submitted_exp_config_dict, default_search_runner_config - ) - search_runner_config["name"] = f"(DSAT) {search_runner_config['name']}" - search_runner_config["hyperparameters"] = vars(args) - - return search_runner_config - - -def get_dict_from_yaml_or_json_path( - path: str, convert_json_keys_to_int: bool = True -) -> Dict[Any, Any]: - """ - Load a json or yaml file as a dict. Optionally convert all json dict keys to - ints, where possible. - """ - p = pathlib.Path(path) - if p.suffix == ".json": - try: - with open(p, "r") as f: - json_dict: Dict[Any, Any] = json.load(f) - if convert_json_keys_to_int: - - def try_str_to_int(s: str) -> Union[str, int]: - try: - return int(s) - except ValueError: - return s - - json_dict = {try_str_to_int(k): v for k, v in json_dict.items()} - return json_dict - except Exception as e: - logger.info(f"Exception {e} raised when loading {path} with json. Attempting yaml.") - return {} - else: - with open(p, "r") as f: - yaml_dict: Dict[Any, Any] = util.yaml_safe_load(f) - return yaml_dict - - -@contextlib.contextmanager -def dsat_reporting_context( - core_context: det.core._context.Context, - op: det.core._searcher.SearcherOperation, - steps_completed: Optional[int] = None, -) -> Generator[None, None, None]: - """ - Context manager required for using Determined AI DeepSpeed Autotune with Core API. - - The `forward` and `step` methods of the DeepSpeed model engine must be called inside of this - context manager. - - Args: - core_context: a `Context` instance created with `determined.core.init` - op: the first `SearcherOperation` instance generated by `core_context.searcher.operations` - - """ - if steps_completed is None: - steps_completed = op.length - try: - yield - except SystemExit as se: - model_profiling_path = pathlib.Path(defaults.MODEL_INFO_PROFILING_PATH) - autotuning_results_path = pathlib.Path(defaults.AUTOTUNING_RESULTS_PATH) - possible_paths = [model_profiling_path, autotuning_results_path] - existing_paths = [p for p in possible_paths if p.exists()] - # Exactly one of these files should be generated for each properly exited DS AT Trial. - if len(existing_paths) == 1: - path = existing_paths[0] - add_gpu_info = path == model_profiling_path - report_json_results( - core_context=core_context, - op=op, - steps_completed=steps_completed, - add_gpu_info=add_gpu_info, - path=path, - ) - raise se - - -def report_json_results( - core_context: det.core._context.Context, - op: det.core._searcher.SearcherOperation, - steps_completed: int, - add_gpu_info: bool, - path: Union[str, pathlib.Path], -) -> None: - is_chief = core_context.distributed.rank == 0 - if is_chief: - with open(path, "r") as f: - results_dict = json.load(f) - if add_gpu_info: - gpu_mem = torch.cuda.get_device_properties(0).total_memory - results_dict["gpu_mem"] = gpu_mem - core_context.train.report_validation_metrics( - steps_completed=steps_completed, metrics=results_dict - ) - op.report_completed(results_dict) - # Ensure the operations generator is empty to complete sanity checks. - try: - next(core_context.searcher.operations()) - except StopIteration: - pass - else: - raise AssertionError("Unexpected additional operations found!") - - -def get_zero_stage_search_space( - zero_stage: int, -) -> Dict[str, List[Union[bool, float]]]: - default_settings: Dict[ - int, Dict[str, List[Union[bool, float]]] - ] = defaults.DEFAULT_ZERO_SEARCH_SPACE - assert ( - zero_stage in default_settings - ), f"Invalid zero_stage, must be one of {list(default_settings)}" - search_space = default_settings[1] - for stage in range(2, zero_stage + 1): - search_space = det_util.merge_dicts(search_space, default_settings[stage]) - return search_space - - -def get_random_zero_optim_config(zero_stage: int) -> Dict[str, Union[bool, float]]: - search_space = get_zero_stage_search_space(zero_stage) - zero_optim_dict = {k: random.choice(v) for k, v in search_space.items()} - zero_optim_dict["stage"] = zero_stage - return zero_optim_dict - - -def get_batch_config_from_mbs_gas_and_slots( - ds_config: Dict[str, Any], slots: int -) -> Dict[str, int]: - """ - Returns a consistent batch size configuration by adjusting `train_batch_size` according to the - number of `slots`, `train_micro_batch_size_per_gpu`, and `gradient_accumulation_steps` (or its - default value, if not specified). - """ - mbs = ds_config["train_micro_batch_size_per_gpu"] - gas = ds_config.get("gradient_accumulation_steps", defaults.GAS_DEFAULT) - if gas == "auto": - # Needed for HuggingFace. - gas = 1 - tbs = mbs * gas * slots - return { - "train_batch_size": tbs, - "train_micro_batch_size_per_gpu": mbs, - "gradient_accumulation_steps": gas, - } - - -def get_ds_config_from_hparams( - hparams: Dict[str, Any], - base_dir: Union[pathlib.Path, str] = CURR_DIR, -) -> Dict[str, Any]: - """Gets the DS config dictionary after merging with overwrite values. - - Follows the rules as described here: - https://docs.determined.ai/latest/training/apis-howto/deepspeed/deepspeed.html#configuration - Args: - hparams (Dict): - Hyperparameters dictionary - base_dir (pathlib.Path): - Base directory reltative to which hparams.deepspeed_config is defined - Returns: - The Deepspeed Configuration for this experiment following the overwriting rules - """ - assert defaults.CONFIG_KEY in hparams, ( - f"Expected to find {defaults.CONFIG_KEY} in the Hyperparameters section. " - f"Instead found {hparams}" - ) - ds_config_relative_path = hparams[defaults.CONFIG_KEY] - base_dir = pathlib.Path(base_dir) - full_path = base_dir.joinpath(ds_config_relative_path) - with open(full_path, "r") as f: - base_ds_config: Dict[str, Any] = json.load(f) - overwrite_ds_config = hparams.get(defaults.OVERWRITE_KEY, {}) - final_ds_config = det_util.merge_dicts(base_ds_config, overwrite_ds_config) - return final_ds_config - - -def get_hf_ds_config_path_from_args(args: List[str]) -> Optional[str]: - for idx in range(len(args)): - if args[idx] == "--deepspeed": - ds_config_idx = idx + 1 - ds_config_path = args[ds_config_idx] - return ds_config_path - return None - - -def update_hf_args(args: List[str], ds_config_dict: Dict[str, Any]) -> List[str]: - """ - Updates batch-size-related HF CLI args to be consistent with the values specified in the - provided DeepSpeed config dictionary. - - Args: - args: list of CLI arguments passed to the HF entrypoint - ds_config_dict: the DeepSpeed configuration as a dictionary - """ - hf_flag_to_ds_key = { - "--per_device_train_batch_size": "train_micro_batch_size_per_gpu", - "--gradient_accumulation_steps": "gradient_accumulation_steps", - } - # Overwrite CLI args - args = copy.deepcopy(args) - for idx in range(len(args)): - if args[idx] in hf_flag_to_ds_key: - ds_key = hf_flag_to_ds_key[args[idx]] - overwrite_value = ds_config_dict[ds_key] - # Need to avoid copying possible "auto" value from json config to HF CLI. - is_auto = isinstance(overwrite_value, str) and overwrite_value.strip() == "auto" - if not is_auto: - overwrite_value_str = str(overwrite_value) - if args[idx + 1] != overwrite_value_str: - logger.warning( - f"Changing {args[idx]} from {args[idx +1]} to {overwrite_value_str}" - " to match the deespspeed config values." - ) - args[idx + 1] = overwrite_value_str - del hf_flag_to_ds_key[args[idx]] - - # Any remaining keys in hf_flag_to_ds_key were not provided as args to the HF CLI entrypoint, - # but they must be added in explicitly, to avoid falling back to HF defaults. - for hf_flag, ds_key in hf_flag_to_ds_key.items(): - hf_flag_value = ds_config_dict[ds_key] - is_auto = isinstance(hf_flag_value, str) and hf_flag_value.strip() == "auto" - if not is_auto: - hf_flag_value_str = str(hf_flag_value) - args.extend([hf_flag, hf_flag_value_str]) - logger.warning( - f"Adding {hf_flag} {hf_flag_value_str} to HF CLI args to reflect overwrite values." - ) - return args - - -def get_hf_args_with_overwrites(args: List[str], hparams: Dict[str, Any]) -> List[str]: - """Updates the submitted HF CLI Args to account for overwrite values. - - Primarily intended as a helper function for Determined AI DeepSpeed (DS) Autotune which provides - overwrite values through the `hparams["overwrite_deepspeed_args"]` which possibly include DS - batch-size related arguments (`train_batch_size`, `train_micro_batch_size_per_gpu`, and - `gradient_accumulation_steps`) which are in conflict with the corresponding HF CLI batch-size - related arguments(`--per_device_train_batch_size` and `--gradient_accumulation_steps`). This - function updates the HF CLI args to relect any such overwrite values. This process also requires - overwriting the corresponding DS json file on-cluster. - - Args: - args: the original HF CLI arguments - hparams: hyperparameter dictionary generated through Determined AI - - Returns: - args: updated HF CLI arguments - """ - if defaults.OVERWRITE_KEY not in hparams: - logger.info( - f"{defaults.OVERWRITE_KEY} key not found in hparams, `get_hf_args_with_overwrites` " - "is a no-op" - ) - return args - - ds_config_path = get_hf_ds_config_path_from_args(args) - assert ds_config_path is not None, "--deepspeed flag not found in HuggingFace args!" - - # A file lock is required during both the writing and reading. - with filelock.FileLock(ds_config_path + ".lock"): - with open(ds_config_path, "r") as f: - ds_config_dict = json.load(f) - - # Then merge all overwrites into the ds_config - overwritten_ds_config_dict = det_util.merge_dicts( - ds_config_dict, hparams[defaults.OVERWRITE_KEY] - ) - - # We need to actually overwrite the ds json config file, due to how HF processes args. - with open(ds_config_path, "w") as f: - json.dump(overwritten_ds_config_dict, f) - # Finally overwrite the CLI args - args = update_hf_args(args, overwritten_ds_config_dict) - - return args diff --git a/harness/determined/pytorch/dsat/defaults.py b/harness/determined/pytorch/dsat/defaults.py deleted file mode 100644 index 672c8e34026..00000000000 --- a/harness/determined/pytorch/dsat/defaults.py +++ /dev/null @@ -1,76 +0,0 @@ -from typing import Dict, List, Union - -ALL_SEARCH_METHOD_NAMES = ["binary", "_test", "asha", "random"] - -MODEL_INFO_PROFILING_PATH = "model_info.json" -AUTOTUNING_RESULTS_PATH = "autotuning_metric.json" -SMALLER_IS_BETTER = True -USE_DSAT_MODE_KEY = "_use_dsat_mode" -GAS_DEFAULT = 1 -CONFIG_KEY = "deepspeed_config" -OVERWRITE_KEY = "overwrite_deepspeed_args" -ARGS_PKL_PATH = "args.pkl" - -SMALLER_IS_BETTER_METRICS = ["forward", "backward", "latency"] -LARGER_IS_BETTER_METRICS = ["throughput", "FLOPS_per_gpu"] - -# Native DS AT uses the below settings for the model info profile run, but also with the the stage -# set to 3, presumably since that gives a general model the best chance to run without OOMing. -# However, since some model cannot run with stage 3, we do not enforce that choice here and the -# zero configuration in the submitted deepspeed config will be used. -MODEL_INFO_PROFILE_DS_CONFIG = { - "train_micro_batch_size_per_gpu": 1, - "autotuning": { - "enabled": True, - # The two fields below essentially use DS internals! Maybe fragile. - "model_info_path": MODEL_INFO_PROFILING_PATH, - "model_info": {"profile": True}, - }, -} - - -# Using similar. Written as a diff between successive stages for brevity. -reduce_bucket_size_defaults = [n * 10**m for n in (1, 5) for m in range(6, 10)] -allgather_bucket_size_defaults = [n * 10**m for n in (1, 5) for m in range(6, 10)] - -DEFAULT_ZERO_SEARCH_SPACE: Dict[int, Dict[str, List[Union[bool, float]]]] = { - 0: {}, - 1: { - "reduce_bucket_size": reduce_bucket_size_defaults, - "allgather_bucket_size": allgather_bucket_size_defaults, - }, - 2: { - "overlap_comm": [True, False], - "reduce_scatter": [True, False], - "contiguous_gradients": [True, False], - }, - 3: { - "allgather_partitions": [True, False], - }, -} - -AUTOTUNING_DICT = {"autotuning": {"enabled": True}} - -AUTOTUNING_ARG_DEFAULTS = { - "max-trials": 64, - "max-concurrent-trials": 16, - "zero-stages": [1, 2, 3], - "trials-per-random-config": 5, - "start-profile-step": 3, - "end-profile-step": 5, - "metric": "FLOPS_per_gpu", - "random-seed": 42, - "run-full-experiment": False, - "search-range-factor": 1.0, - "divisor": 2, - "min-binary-search-trials": 3, - "max-rungs": 5, - "asha-early-stopping": 0, -} - -DEFAULT_SEARCH_RUNNER_CONFIG = { - "searcher": {"name": "single", "max_length": 0}, - "max_restarts": 5, - "resources": {"slots_per_trial": 0}, - "entrypoint": "python3 -m determined.pytorch.dsat._run_dsat", -} diff --git a/harness/determined/searcher/__init__.py b/harness/determined/searcher/__init__.py deleted file mode 100644 index bba5853f095..00000000000 --- a/harness/determined/searcher/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -from determined.searcher._search_method import ( - SearchMethod, - SearcherState, - Close, - Create, - ExitedReason, - Operation, - Progress, - Shutdown, - ValidateAfter, -) -from determined.searcher._search_runner import SearchRunner, LocalSearchRunner -from determined.searcher._remote_search_runner import RemoteSearchRunner diff --git a/harness/determined/searcher/_remote_search_runner.py b/harness/determined/searcher/_remote_search_runner.py deleted file mode 100644 index c4d300089c5..00000000000 --- a/harness/determined/searcher/_remote_search_runner.py +++ /dev/null @@ -1,106 +0,0 @@ -import logging -import os -import pathlib -import pickle -import warnings -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union - -import determined as det -from determined import searcher -from determined.experimental import client - -logger = logging.getLogger("determined.searcher") - - -class RemoteSearchRunner(searcher.SearchRunner): - """ - ``RemoteSearchRunner`` performs a search for optimal hyperparameter values, - applying the provided ``SearchMethod`` (you will subclass ``SearchMethod`` and provide - an instance of the derived class). - ``RemoteSearchRunner`` executes on-cluster: it runs a meta-experiment - using ``Core API``. - """ - - def __init__(self, search_method: searcher.SearchMethod, context: det.core.Context) -> None: - warnings.warn( - "`RemoteSearchRunner` and all custom searchers have been deprecated. " - "This feature will be removed in a future release. Consider configuring a preset " - "searcher instead (see Determined docs for details).", - FutureWarning, - stacklevel=2, - ) - super().__init__(search_method) - self.context = context - info = det.get_cluster_info() - assert info is not None, "RemoteSearchRunner only runs on-cluster" - self.info = info - - self.latest_checkpoint = self.info.latest_checkpoint - - def run( - self, - exp_config: Union[Dict[str, Any], str], - model_dir: Optional[str] = None, - includes: Optional[Iterable[Union[str, pathlib.Path]]] = None, - ) -> int: - """ - Run custom search as a Core API experiment (on-cluster). - - Args: - exp_config (dictionary, string): experiment config filename (.yaml) or a dict. - model_dir (string): directory containing model definition. - includes (Iterable[Union[str, pathlib.Path]], optional): Additional files - or directories to include in the model definition. (default: ``None``) - """ - logger.info("RemoteSearchRunner.run") - - operations: Optional[List[searcher.Operation]] = None - - if model_dir is None: - model_dir = os.getcwd() - - if self.latest_checkpoint is not None: - experiment_id, operations = self.load_state(self.latest_checkpoint) - logger.info(f"Resuming HP searcher for experiment {experiment_id}") - else: - logger.info("No latest checkpoint. Starting new experiment.") - exp = client.create_experiment(exp_config, model_dir, includes) - self.state.experiment_id = exp.id - self.state.last_event_id = 0 - self.save_state(exp.id, []) - experiment_id = exp.id - # Note: Simulating the same print functionality as our CLI when making an experiment. - # This line is needed for the e2e tests - logger.info(f"Created experiment {exp.id}") - - # make sure client is initialized - # TODO: remove typing suppression when mypy #14473 is resolved - client._require_singleton(lambda: None)() # type: ignore - assert client._determined is not None - session = client._determined._session - self.run_experiment(experiment_id, session, operations) - - return experiment_id - - def load_state(self, storage_id: str) -> Tuple[int, List[searcher.Operation]]: - with self.context.checkpoint.restore_path(storage_id) as path: - self.state, experiment_id = self.search_method.load(path) - with path.joinpath("ops").open("rb") as ops_file: - operations = pickle.load(ops_file) - return experiment_id, operations - - def save_state(self, experiment_id: int, operations: List[searcher.Operation]) -> None: - steps_completed = self.state.last_event_id - metadata = {"steps_completed": steps_completed} - with self.context.checkpoint.store_path(metadata) as (path, storage_id): - self.search_method.save(self.state, path, experiment_id=experiment_id) - with path.joinpath("ops").open("wb") as ops_file: - pickle.dump(operations, ops_file) - - def _show_experiment_paused_msg(self) -> None: - logger.warning( - f"Experiment {self.state.experiment_id} " - "has been paused. If you leave searcher experiment running, " - "your search method will automatically resume when the experiment " - "becomes active again." - ) diff --git a/harness/determined/searcher/_search_method.py b/harness/determined/searcher/_search_method.py deleted file mode 100644 index 8190f461f35..00000000000 --- a/harness/determined/searcher/_search_method.py +++ /dev/null @@ -1,494 +0,0 @@ -import abc -import dataclasses -import enum -import json -import pathlib -import uuid -import warnings -from typing import Any, Dict, List, Optional, Set, Tuple - -from determined import experimental -from determined.common.api import bindings - -STATE_FILE = "state" - - -@dataclasses.dataclass -class SearcherState: - """ - Custom Searcher State. - - Search runners maintain this state that can be used by a ``SearchMethod`` - to inform event handling. In other words, this state can be taken into account - when deciding which operations to return from your event handler. Do not - modify ``SearcherState`` in your ``SearchMethod``. If your hyperparameter - tuning algorithm needs additional state variables, add those variable to your - ``SearchMethod`` implementation. - - Attributes: - failures: number of failed trials - trial_progress: progress of each trial as a number between 0.0 and 1.0 - trials_closed: set of completed trials - trials_created: set of created trials - """ - - failures: Set[uuid.UUID] - trial_progress: Dict[uuid.UUID, float] - trials_closed: Set[uuid.UUID] - trials_created: Set[uuid.UUID] - last_event_id: int = 0 - experiment_completed: bool = False - experiment_failed: bool = False - - def __init__(self) -> None: - self.failures = set() - self.trial_progress = {} - self.trials_closed = set() - self.trials_created = set() - - def to_dict(self) -> Dict[str, Any]: - return { - "failures": [str(f) for f in self.failures], - "trialProgress": {str(k): v for k, v in self.trial_progress.items()}, - "trialsClosed": [str(t) for t in self.trials_closed], - "trialsCreated": [str(t) for t in self.trials_created], - "lastEventId": self.last_event_id, - "experimentId": self.experiment_id, - "experimentCompleted": self.experiment_completed, - "experimentFailed": self.experiment_failed, - } - - def from_dict(self, d: Dict[str, Any]) -> None: - self.failures = {uuid.UUID(f) for f in d.get("failures", [])} - self.trial_progress = {uuid.UUID(k): v for k, v in d.get("trialProgress", {}).items()} - self.trials_closed = {uuid.UUID(t) for t in d.get("trialsClosed", [])} - self.trials_created = {uuid.UUID(t) for t in d.get("trialsCreated", [])} - self.last_event_id = d.get("lastEventId", 0) - self.experiment_id = d.get("experimentId") - self.experiment_completed = d.get("experimentCompleted", False) - self.experiment_failed = d.get("experimentFailed", False) - - -class ExitedReason(enum.Enum): - """ - The reason why a trial exited early - - The following reasons are supported: - - - `ERRORED`: The Trial encountered an exception - - `USER_CANCELLED`: The Trial was manually closed by the user - - `INVALID_HP`: The hyperparameters the trial was created with were invalid - """ - - ERRORED = "ERRORED" - USER_CANCELED = "USER_CANCELED" - INVALID_HP = "INVALID_HP" - - @classmethod - def _from_bindings( - cls, bindings_exited_reason: bindings.v1TrialExitedEarlyExitedReason - ) -> "ExitedReason": - if bindings_exited_reason == bindings.v1TrialExitedEarlyExitedReason.INVALID_HP: - return cls.INVALID_HP - if bindings_exited_reason == bindings.v1TrialExitedEarlyExitedReason.USER_REQUESTED_STOP: - return cls.USER_CANCELED - if bindings_exited_reason == bindings.v1TrialExitedEarlyExitedReason.UNSPECIFIED: - return cls.ERRORED - raise RuntimeError(f"Invalid exited reason: {bindings_exited_reason}") - - -class Operation(metaclass=abc.ABCMeta): - """ - Abstract base class for all Operations - """ - - @abc.abstractmethod - def _to_searcher_operation(self) -> bindings.v1SearcherOperation: - pass - - -class ValidateAfter(Operation): - """ - Operation signaling the trial to train until its total units trained - equals the specified length, where the units (batches, epochs, etc.) - are specified in the searcher section of the experiment configuration - """ - - def __init__(self, request_id: uuid.UUID, length: int) -> None: - super().__init__() - self.request_id = request_id - self.length = length - - def _to_searcher_operation(self) -> bindings.v1SearcherOperation: - return bindings.v1SearcherOperation( - trialOperation=bindings.v1TrialOperation( - validateAfter=bindings.v1ValidateAfterOperation( - requestId=str(self.request_id), length=str(self.length) - ), - ) - ) - - -class Close(Operation): - """ - Operation for closing the specified trial - """ - - def __init__(self, request_id: uuid.UUID): - super().__init__() - self.request_id = request_id - - def _to_searcher_operation(self) -> bindings.v1SearcherOperation: - return bindings.v1SearcherOperation( - closeTrial=bindings.v1CloseTrialOperation(requestId=str(self.request_id)) - ) - - -class Progress(Operation): - """ - Operation for signalling the relative progress of the hyperparameter search between 0 and 1 - """ - - def __init__(self, progress: float): - super().__init__() - self.progress = progress - - def _to_searcher_operation(self) -> bindings.v1SearcherOperation: - return bindings.v1SearcherOperation( - setSearcherProgress=bindings.v1SetSearcherProgressOperation(progress=self.progress) - ) - - -class Shutdown(Operation): - """ - Operation for shutting the experiment down - """ - - def __init__(self, cancel: bool = False, failure: bool = False) -> None: - super().__init__() - self.cancel = cancel - self.failure = failure - - def _to_searcher_operation(self) -> bindings.v1SearcherOperation: - return bindings.v1SearcherOperation( - shutDown=bindings.v1ShutDownOperation(cancel=self.cancel, failure=self.failure) - ) - - -class Create(Operation): - """ - Operation for creating a trial with a specified combination of hyperparameter values - """ - - def __init__( - self, - request_id: uuid.UUID, - hparams: Dict[str, Any], - checkpoint: Optional[experimental.Checkpoint], - ) -> None: - super().__init__() - self.request_id = request_id - self.hparams = json.dumps(hparams) - self.checkpoint = checkpoint - - def _to_searcher_operation(self) -> bindings.v1SearcherOperation: - return bindings.v1SearcherOperation( - createTrial=bindings.v1CreateTrialOperation( - hyperparams=self.hparams, requestId=str(self.request_id) - ) - ) - - -class SearchMethod: - """ - The implementation of a custom hyperparameter tuning algorithm. - - To implement your specific hyperparameter tuning approach, subclass ``SearchMethod`` - overriding the event handler methods. - - Each event handler, except :meth:`progress() ` - returns a list of operations (``List[Operation]``) that will be submitted to master for - processing. - - Currently, we support the following :class:`~Operation`: - - - :class:`~Create` - starts a new trial with a unique trial id and a set of hyperparameter - values. - - :class:`~ValidateAfter` - sets number of steps (i.e., batches or epochs) after which a - validation is run for a trial with a given id. - - :class:`~Progress` - updates the progress of the multi-trial experiment to the master. - - :class:`~Close` - closes a trial with a given id. - - :class:`~Shutdown` - closes the experiment. - - .. note:: - - Do not modify ``searcher_state`` passed into event handlers. - """ - - def __init__(self) -> None: - warnings.warn( - "`SearchMethod` and all custom searchers have been deprecated. " - "This feature will be removed in a future release. Consider configuring a preset " - "searcher instead (see Determined docs for details).", - FutureWarning, - stacklevel=2, - ) - - @abc.abstractmethod - def initial_operations(self, searcher_state: SearcherState) -> List[Operation]: - """ - Returns a list of initial operations that the custom hyperparameter search should - perform. This is called by the Custom Searcher :class:`~SearchRunner` - to initialize the trials - - Example: - - .. code:: python - - def initial_operations(self, _: searcher.SearcherState) -> List[searcher.Operation]: - ops: List[searcher.Operation] = [] - N = 100 - hparams = { - # ... - } - for _ in range(0, N): - create = searcher.Create( - request_id=uuid.uuid4(), - hparams=hparams, - checkpoint=None, - ) - ops.append(create) - return ops - - Args: - searcher_state(:class:`~SearcherState`): Read-only current searcher state - - Returns: - List[Operation]: Initial list of :class:`~Operation` to start the Hyperparameter - search - """ - pass - - @abc.abstractmethod - def on_trial_created( - self, searcher_state: SearcherState, request_id: uuid.UUID - ) -> List[Operation]: - """ - Informs the searcher that a trial has been created - as a result of Create operation. - - Example: - - .. code:: python - - def on_trial_created( - self, _: SearcherState, request_id: uuid.UUID - ) -> List[Operation]: - return [ - searcher.ValidateAfter( - request_id=request_id, - length=1, # Run for one unit of time (epoch, etc.) - ) - ] - - In this example, we are choosing to deterministically train for one unit of time - - Args: - searcher_state(:class:`~SearcherState`): Read-only current searcher state - request_id (uuid.UUID): Request UUID of the Trial that was created - - Returns: - List[Operation]: List of :class:`~Operation` to run upon creation of the given - trial - """ - pass - - @abc.abstractmethod - def on_validation_completed( - self, searcher_state: SearcherState, request_id: uuid.UUID, metric: Any, train_length: int - ) -> List[Operation]: - """ - Informs the searcher that the validation workload has completed after training for - ``train_length`` units. It returns any new operations as a result of this workload - completing - - Example: - - .. code:: python - - def on_validation_completed( - self, - searcher_state: SearcherState, - request_id: uuid.UUID, - metric: Any, - train_length: int - ) -> List[Operation]: - if train_length < self.max_train_length: - return [ - searcher.ValidateAfter( - request_id=request_id, - length=train_length + 1, # Run an additional unit of time - ) - ] - return [searcher.Close(request_id=request_id)] - - Args: - searcher_state (SearcherState): Read-only current searcher state - request_id (uuid.UUID): Request UUID of the Trial that was trained - metric (Any): Metric data returned by the trial - train_length (int): The cumulative units of time that that trial has finished - training for (epochs, etc.) - - Returns: - List[Operation]: List of :class:`~Operation` to run upon completion of training for - the given trial - """ - pass - - @abc.abstractmethod - def on_trial_closed( - self, searcher_state: SearcherState, request_id: uuid.UUID - ) -> List[Operation]: - """ - Informs the searcher that a trial has been closed as a result of a :class:`~Close` - - Example: - - .. code:: python - - def on_trial_closed( - self, searcher_state: SearcherState, request_id: uuid.UUID - ) -> List[Operation]: - if searcher_state.trials_created < self.max_num_trials: - hparams = { - # ... - } - return [ - searcher.Create( - request_id=uuid.uuid4(), - hparams=hparams, - checkpoint=None, - ) - ] - if searcher_state.trials_closed >= self.max_num_trials: - return [searcher.Shutdown()] - return [] - - Args: - searcher_state (SearcherState): Read-only current searcher state - request_id (uuid.UUID): Request UUID of the Trial that was closed - - Returns: - List[Operation]: List of :class:`~Operation` to run after closing the given - trial - """ - pass - - @abc.abstractmethod - def progress(self, searcher_state: SearcherState) -> float: - """ - Returns experiment progress as a float between 0 and 1. - - Example: - - .. code:: python - - def progress(self, searcher_state: SearcherState) -> float: - return searcher_state.trials_closed / float(self.max_num_trials) - - Args: - searcher_state (SearcherState): Read-only current searcher state - - Returns: - float: Experiment progress as a float between 0 and 1. - """ - pass - - @abc.abstractmethod - def on_trial_exited_early( - self, - searcher_state: SearcherState, - request_id: uuid.UUID, - exited_reason: ExitedReason, - ) -> List[Operation]: - """ - Informs the searcher that a trial has exited earlier than expected. - - Example: - - .. code:: python - - def on_trial_exited_early( - self, - searcher_state: SearcherState, - request_id: uuid.UUID, - exited_reason: ExitedReason, - ) -> List[Operation]: - if exited_reason == searcher.ExitedReason.USER_CANCELED: - return [searcher.Shutdown(cancel=True)] - if exited_reason == searcher.ExitedReason.INVALID_HP: - return [searcher.Shutdown(failure=True)] - if searcher_state.failures >= self.max_failures: - return [searcher.Shutdown(failure=True)] - return [] - - .. note:: - - The trial has already been internally closed when this callback is run. - You do not need to explicitly issue a :class:`~Close` operation - - Args: - searcher_state (SearcherState): Read-only current searcher state - request_id (uuid.UUID): Request UUID of the Trial that exited early - exited_reason (ExitedReason): The reason that the trial exited early - - Returns: - List[Operation]: List of :class:`~Operation` to run in response to the given - trial exiting early - """ - pass - - def save( - self, searcher_state: SearcherState, path: pathlib.Path, *, experiment_id: int - ) -> None: - """ - Saves the searcher state and the search method state. - It will be called by the ``SearchRunner`` after receiving operations - from the ``SearchMethod`` - """ - searcher_state_file = path.joinpath(STATE_FILE) - d = searcher_state.to_dict() - d["experimentId"] = experiment_id - with searcher_state_file.open("w") as f: - json.dump(d, f) - - self.save_method_state(path) - - def save_method_state(self, path: pathlib.Path) -> None: - """ - Saves method-specific state - """ - pass - - def load(self, path: pathlib.Path) -> Tuple[SearcherState, int]: - """ - Loads searcher state and method-specific state. - """ - searcher_state_file = path.joinpath(STATE_FILE) - with searcher_state_file.open("r") as f: - state_dict = json.load(f) - searcher_state = SearcherState() - searcher_state.from_dict(state_dict) - experiment_id = state_dict["experimentId"] # type: int - - self.load_method_state(path) - return searcher_state, experiment_id - - def load_method_state( - self, - path: pathlib.Path, - ) -> None: - """ - Loads method-specific search state. - """ - pass diff --git a/harness/determined/searcher/_search_runner.py b/harness/determined/searcher/_search_runner.py deleted file mode 100644 index 2081f7145e8..00000000000 --- a/harness/determined/searcher/_search_runner.py +++ /dev/null @@ -1,375 +0,0 @@ -import json -import logging -import os -import pathlib -import pickle -import time -import uuid -import warnings -from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union - -from determined import searcher -from determined.common import api -from determined.common.api import bindings, errors -from determined.experimental import client - -EXPERIMENT_ID_FILE = "experiment_id.txt" -logger = logging.getLogger("determined.searcher") - - -class _ExperimentInactiveException(Exception): - def __init__(self, exp_state: bindings.experimentv1State): - self.exp_state = exp_state - - -class SearchRunner: - def __init__( - self, - search_method: searcher.SearchMethod, - ) -> None: - self.search_method = search_method - self.state = searcher.SearcherState() - - def _get_operations(self, event: bindings.v1SearcherEvent) -> List[searcher.Operation]: - if event.initialOperations: - logger.info("initial operations") - operations = self.search_method.initial_operations(self.state) - elif event.trialCreated: - logger.info(f"trialCreated({event.trialCreated.requestId})") - request_id = uuid.UUID(event.trialCreated.requestId) - self.state.trials_created.add(request_id) - self.state.trial_progress[request_id] = 0.0 - operations = self.search_method.on_trial_created(self.state, request_id) - elif event.trialClosed: - logger.info(f"trialClosed({event.trialClosed.requestId})") - request_id = uuid.UUID(event.trialClosed.requestId) - self.state.trials_closed.add(request_id) - operations = self.search_method.on_trial_closed(self.state, request_id) - - # add progress operation - progress = self.search_method.progress(self.state) - operations.append(searcher.Progress(progress)) - elif event.trialExitedEarly: - # duplicate exit accounting already performed by master - searcher_exit_reason = searcher.ExitedReason._from_bindings( - event.trialExitedEarly.exitedReason - ) - logger.info( - f"trialExitedEarly({event.trialExitedEarly.requestId}, {searcher_exit_reason})" - ) - if event.trialExitedEarly.exitedReason is None: - raise RuntimeError("trialExitedEarly event is invalid without exitedReason") - request_id = uuid.UUID(event.trialExitedEarly.requestId) - if ( - event.trialExitedEarly.exitedReason - == bindings.v1TrialExitedEarlyExitedReason.INVALID_HP - ): - self.state.trial_progress.pop(request_id, None) - elif ( - event.trialExitedEarly.exitedReason - == bindings.v1TrialExitedEarlyExitedReason.UNSPECIFIED - ): - self.state.failures.add(request_id) - operations = self.search_method.on_trial_exited_early( - self.state, - request_id, - exited_reason=searcher_exit_reason, - ) - # add progress operation - progress = self.search_method.progress(self.state) - operations.append(searcher.Progress(progress)) - elif event.validationCompleted: - # duplicate completion accounting already performed by master - logger.info( - f"validationCompleted({event.validationCompleted.requestId}," - f" {event.validationCompleted.metric})" - ) - request_id = uuid.UUID(event.validationCompleted.requestId) - if event.validationCompleted.metric is None: - raise RuntimeError("validationCompleted event is invalid without a metric") - - operations = self.search_method.on_validation_completed( - self.state, - request_id, - event.validationCompleted.metric, - int(event.validationCompleted.validateAfterLength), - ) - # add progress operation - progress = self.search_method.progress(self.state) - operations.append(searcher.Progress(progress)) - elif event.experimentInactive: - logger.info( - f"experiment {self.state.experiment_id} is " - f"inactive; state={event.experimentInactive.experimentState}" - ) - - raise _ExperimentInactiveException(event.experimentInactive.experimentState) - elif event.trialProgress: - logger.debug( - f"trialProgress({event.trialProgress.requestId}, " - f"{event.trialProgress.partialUnits})" - ) - request_id = uuid.UUID(event.trialProgress.requestId) - self.state.trial_progress[request_id] = float(event.trialProgress.partialUnits) - progress = self.search_method.progress(self.state) - operations = [searcher.Progress(progress)] - else: - raise RuntimeError(f"Unsupported event {event}") - return operations - - def run_experiment( - self, - experiment_id: int, - session: api.Session, - prior_operations: Optional[List[searcher.Operation]], - sleep_time: float = 1.0, - ) -> None: - experiment_is_active = True - try: - while experiment_is_active: - time.sleep( - sleep_time - ) # we don't want to call long polling API more often than every second. - events = self.get_events(session, experiment_id) - if not events: - continue - logger.info(json.dumps([SearchRunner._searcher_event_as_dict(e) for e in events])) - # the first event is an event we have already processed and told master about it - # however, we may not have saved the state after that event if we crashed - # after POSTing operations but before saving state - last_event_id = self.state.last_event_id - first_event = True - for event in events: - if ( - first_event - and last_event_id != 0 - and last_event_id >= event.id >= 0 - and prior_operations is not None - ): - logger.info(f"Resubmitting operations for event.id={event.id}") - operations = prior_operations - else: - if event.experimentInactive: - logger.info( - f"experiment {self.state.experiment_id} is " - f"inactive; state={event.experimentInactive.experimentState}" - ) - if ( - event.experimentInactive.experimentState - == bindings.experimentv1State.COMPLETED - ): - self.state.experiment_completed = True - elif ( - event.experimentInactive.experimentState - == bindings.experimentv1State.ERROR - ): - self.state.experiment_failed = True - - if ( - event.experimentInactive.experimentState - == bindings.experimentv1State.PAUSED - ): - self._show_experiment_paused_msg() - else: - experiment_is_active = False - break - - operations = self._get_operations(event) - - # save state - self.state.last_event_id = event.id - self.save_state(experiment_id, operations) - - first_event = False - self.post_operations(session, experiment_id, event, operations) - - except KeyboardInterrupt: - print("Runner interrupted") - - def post_operations( - self, - session: api.Session, - experiment_id: int, - event: bindings.v1SearcherEvent, - operations: List[searcher.Operation], - ) -> None: - body = bindings.v1PostSearcherOperationsRequest( - experimentId=self.state.experiment_id, - searcherOperations=[op._to_searcher_operation() for op in operations], - triggeredByEvent=event, - ) - - # This try/except is intended to catch a specific error which occurs for DeepSpeed Autotune. - # DeepSpeed makes an explicit `exit()` call internally when autotuning flags are enabled in - # the DS config. When we also post a `Close` operation, there is a resulting race condition - # and intermittently the process and its corresponding agent die before the `Close` - # operation reaches the agent, resulting in a `APIException` with a `failed to post - # operations: rpc error: code = NotFound desc = actor /experiments/xxx could not be found` - # message. This try/except allows the experiment to continue uninterrupted in such cases. - try: - bindings.post_PostSearcherOperations( - session, - body=body, - experimentId=experiment_id, - ) - except errors.APIException as e: - logger.warning(f"Catching errors.APIException: {str(e)}") - close_op_in_operations = any((isinstance(o, searcher.Close) for o in operations)) - logger.warning(f"operations: {operations}") - if close_op_in_operations and "could not be found" in str(e): - pass - else: - raise e - - def get_events( - self, - session: api.Session, - experiment_id: int, - ) -> Optional[Sequence[bindings.v1SearcherEvent]]: - # API is implemented with long polling. - events = bindings.get_GetSearcherEvents(session, experimentId=experiment_id) - return events.searcherEvents - - def save_state(self, experiment_id: int, operations: List[searcher.Operation]) -> None: - pass - - def _show_experiment_paused_msg(self) -> None: - pass - - @staticmethod - def _searcher_event_as_dict(event: bindings.v1SearcherEvent) -> dict: - return {k: v for k, v in event.to_json().items() if v is not None} - - -class LocalSearchRunner(SearchRunner): - """ - ``LocalSearchRunner`` performs a search for optimal hyperparameter values, - applying the provided ``SearchMethod``. It is executed locally and interacts - with a Determined cluster where it starts a multi-trial experiment. It then - reacts to event notifications coming from the running experiments by forwarding - them to event handler methods in your ``SearchMethod`` implementation and sending - the returned operations back to the experiment. - """ - - def __init__( - self, - search_method: searcher.SearchMethod, - searcher_dir: Optional[pathlib.Path] = None, - session: Optional[api.Session] = None, - ): - warnings.warn( - "`LocalSearchRunner` and all custom searchers have been deprecated. " - "This feature will be removed in a future release. Consider configuring a preset " - "searcher instead (see Determined docs for details).", - FutureWarning, - stacklevel=2, - ) - super().__init__(search_method) - self.state_path = None - self.session = session - - self.searcher_dir = searcher_dir or pathlib.Path.cwd() - if not self.searcher_dir.exists(): - self.searcher_dir.mkdir(parents=True) - elif not self.searcher_dir.is_dir(): - raise FileExistsError( - f"searcher_dir={self.searcher_dir} already exists and is not a directory" - ) - - def run( - self, - exp_config: Union[Dict[str, Any], str], - model_dir: Optional[str] = None, - includes: Optional[Iterable[Union[str, pathlib.Path]]] = None, - ) -> int: - """ - Run custom search. - - Args: - exp_config (dictionary, string): experiment config filename (.yaml) or a dict. - model_dir (string): directory containing model definition. - includes (Iterable[Union[str, pathlib.Path]], optional): Additional files - or directories to include in the model definition. (default: ``None``) - """ - logger.info("LocalSearchRunner.run") - - if model_dir is None: - model_dir = os.getcwd() - experiment_id_file = self.searcher_dir.joinpath(EXPERIMENT_ID_FILE) - operations: Optional[List[searcher.Operation]] = None - if experiment_id_file.exists(): - with experiment_id_file.open("r") as f: - experiment_id = int(f.read()) - logger.info(f"Resuming HP searcher for experiment {experiment_id}") - # load searcher state and search method state - _, operations = self.load_state(experiment_id) - else: - exp = client.create_experiment(exp_config, model_dir, includes) - with experiment_id_file.open("w") as f: - f.write(str(exp.id)) - state_path = self._get_state_path(exp.id) - state_path.mkdir(parents=True) - self.state.experiment_id = exp.id - self.state.last_event_id = 0 - self.save_state(exp.id, []) - experiment_id = exp.id - # Note: Simulating the same print functionality as our CLI when making an experiment. - logger.info(f"Created experiment {experiment_id}") - - # make sure client is initialized - # TODO: remove typing suppression when mypy #14473 is resolved - client._require_singleton(lambda: None)() # type: ignore - assert client._determined is not None - if self.session: - session = self.session - else: - session = client._determined._session - self.run_experiment(experiment_id, session, operations) - return experiment_id - - def load_state(self, experiment_id: int) -> Tuple[int, List[searcher.Operation]]: - experiment_searcher_dir = self._get_state_path(experiment_id) - with experiment_searcher_dir.joinpath("event_id").open("r") as event_id_file: - last_event_id = int(event_id_file.read()) - state_path = experiment_searcher_dir.joinpath(f"event_{last_event_id}") - self.state, loaded_experiment_id = self.search_method.load(state_path) - assert experiment_id == loaded_experiment_id, ( - f"Experiment id mismatch. Expected {experiment_id}." f" Found {loaded_experiment_id}" - ) - with state_path.joinpath("ops").open("rb") as f: - operations = pickle.load(f) - return loaded_experiment_id, operations - - def save_state(self, experiment_id: int, operations: List[searcher.Operation]) -> None: - experiment_searcher_dir = self._get_state_path(experiment_id) - state_path = experiment_searcher_dir.joinpath(f"event_{self.state.last_event_id}") - - if not state_path.exists(): - state_path.mkdir(parents=True) - - self.search_method.save( - self.state, - state_path, - experiment_id=experiment_id, - ) - with state_path.joinpath("ops").open("wb") as ops_file: - pickle.dump(operations, ops_file) - - # commit - event_id_path = experiment_searcher_dir.joinpath("event_id") - event_id_new_path = experiment_searcher_dir.joinpath("event_id_new") - with event_id_new_path.open("w") as f: - f.write(str(self.state.last_event_id)) - os.replace(event_id_new_path, event_id_path) - - def _get_state_path(self, experiment_id: int) -> pathlib.Path: - return self.searcher_dir.joinpath(f"exp_{experiment_id}") - - def _show_experiment_paused_msg(self) -> None: - logger.warning( - f"Experiment {self.state.experiment_id} " - f"has been paused. If you leave searcher process running, your search method" - f" will automatically resume when the experiment becomes active again. " - f"Otherwise, you can terminate this process and restart it " - f"manually to continue the search." - ) diff --git a/harness/determined/transformers/_hf_callback.py b/harness/determined/transformers/_hf_callback.py index b9840db5f55..e4ccc3a819b 100644 --- a/harness/determined/transformers/_hf_callback.py +++ b/harness/determined/transformers/_hf_callback.py @@ -8,10 +8,28 @@ import determined as det -logger = logging.getLogger("determined.transformers") +logger = logging.getLogger("det.transformers") class DetCallback(transformers.TrainerCallback): # type: ignore + """ + ``DetCallback`` integrates a training loop built around ``transformers.Trainer`` with the + Determined cluster. It reports metrics, uploads checkpoints, and handles preemption signals. + It also automatically restores training from the latest checkpoint after pauses or crashes. + + Simply include ``DetCallback`` as in the list of ``callbacks`` that you pass to your + ``Trainer``. + + Args: + core_context: the result of a ``det.core.init()`` call. + args: ``TrainingArgs`` from a ``transformers.HfArgumentParser``, the same ``args`` to be + passed to the ``Trainer``. + filter_metrics: a list of metric names to report to Determined. Default: ``None`` (all + metrics are reported). + user_data: an optional dict of metadata to be stored in every checkpoint. + Default: ``None``. + """ + def __init__( self, core_context: det.core.Context, @@ -20,32 +38,125 @@ def __init__( user_data: Optional[Dict[str, Any]] = None, ) -> None: super().__init__() - self.core_context = core_context - self.filter_metrics = filter_metrics self.user_data = user_data + + self.last_train_metrics = -1 + self.last_eval_metrics = -1 + self.last_save = -1 + self.last_progress = 0 + + info = det.get_cluster_info() + if not info: + raise RuntimeError("det.transformers.DetCallback must be run on a Determined cluster") + self.info = info + self.load_last_checkpoint(args) - self.last_metrics: Dict[str, float] = {"train_step": -1, "eval_step": -1} - self.searcher_ops = self.core_context.searcher.operations() - self.current_op = next(self.searcher_ops) - self.updating_searcher = False - - cluster_info = det.get_cluster_info() - assert ( - cluster_info - ), "Could not find `cluster_info`, the HF Callback must be run on a Determined Cluster" - searcher_config = cluster_info.trial._config["searcher"] - self.searcher_metric = searcher_config["metric"] - # Custom searchers have a different config structure which need to be handled differently - if searcher_config["name"] == "custom": - self.searcher_unit = "batches" - self.searcher_max_length = self.current_op.length + self.searcher_metric = None + self.time_metric = None + if self.info.task_type == "TRIAL": + searcher_config = self.info.trial._config["searcher"] + self._check_searcher_config(searcher_config, args) + self.searcher_metric = searcher_config["metric"] + self.time_metric = searcher_config.get("time_metric") + # Don't allow filtering of the searcher or time_metric metrics. + if self.filter_metrics: + self.filter_metrics.append(self.searcher_metric) + if self.time_metric: + self.filter_metrics.append(self.time_metric) + + # Undocumented workarounds in case forcing the checkpoint and validations at the end of + # non-preempted training is a bad idea somehow. + self._force_final_save = True + self._force_final_evaluate = True + + def load_last_checkpoint(self, args: transformers.TrainingArguments) -> None: + latest_checkpoint = self.info.latest_checkpoint + if latest_checkpoint is None: + return + if args.overwrite_output_dir is True: + logger.info( + "Skipping downloading last checkpoint from Determined due " + "to overwrite_output_dir=True." + ) + return + + # To resume DeepSpeed, each node requires ALL sharded model/optimizer states, + # so we can skip using selector and just download all files. + self.core_context.checkpoint.download(latest_checkpoint, args.output_dir) + + checkpoint_path = trainer_utils.get_last_checkpoint(args.output_dir) + args.resume_from_checkpoint = checkpoint_path + + logger.info(f"Latest checkpoint downloaded to {checkpoint_path}.") + + def _check_searcher_config( + self, cfg: Dict[str, Any], args: transformers.TrainingArguments + ) -> None: + if args.max_steps > -1: + args_unit = "batches" + args_len = args.max_steps + len_arg = "--max_steps" else: - self.searcher_unit = list(searcher_config["max_length"].keys())[0] - self.searcher_max_length = list(searcher_config["max_length"].values())[0] - self._check_searcher_compatibility(args) + args_unit = "epochs" + args_len = args.num_train_epochs + len_arg = "--num_train_epochs" + + if isinstance(cfg.get("max_length"), int): + # Legacy searcher config (unitless). Has never been supported, actually. + raise ValueError( + "HF trainer no longer respects the deprecated searcher.max_length " + "field. searcher.max_length is deprecated; please remove it and rely " + f"on {len_arg} instead to avoid ambiguous training specifications." + ) + elif isinstance(cfg.get("max_length"), dict): + # Legacy searcher config; max_length must match provided args. + search_unit, search_len = next(iter(cfg["max_length"].items())) + if (search_unit, search_len) != (args_unit, args_len): + raise ValueError( + "HF trainer units does not match configured searcher.max_length " + f"({args_unit}={args_len} != {search_unit}={search_len}). The " + "searcher.max_length field is deprecated; please remove it and avoid " + "ambiguous training specifications." + ) + elif cfg["name"] in ["adaptive_asha", "async_halving"]: + # ASHA search: check time_metric and max_time are sane. + self.required_metrics.append(cfg["time_metric"]) + search_unit = cfg["time_metric"] + search_len = cfg["max_time"] + if search_unit not in ("batches", "epochs"): + self.required_metrics.append(search_unit) + elif (search_unit, search_len) != (args_unit, args_len): + name = cfg["name"] + raise ValueError( + "HF trainer units does not match configured the max_time configured for " + f"{name} searcher ({args_unit}={args_len} != {search_unit}={search_len}. " + f"Please update one of the searcher.max_time config field or the {len_arg} " + "to match the other." + ) + + def _check_eval_metrics(self, metrics: Dict[str, Any]) -> None: + search_ok = self.searcher_metric is None or self.searcher_metric in metrics + time_ok = self.time_metric is None or self.time_metric in metrics + if not search_ok and not time_ok: + raise ValueError( + f"Searcher metric '{self.searcher_metric}' set by searcher.metric config field " + f"and time metric '{self.time_metric}' from searcher.time_metric config field are " + "both missing; you must emit those metrics for the hyperparameter search to work." + ) + if not search_ok: + raise ValueError( + f"Searcher metric '{self.searcher_metric}' set by searcher.metric config field " + "is missing; you must emit that metric for features like hyperparameter search, " + "checkpoint garbage collection, and selecting the best checkpoint to work." + ) + if not time_ok: + raise ValueError( + f"Time metric '{self.time_metric}' set by searcher.time_metric config field is " + "missing; you must emit that metric for the hyperparameter search to work." + ) def on_log( self, @@ -60,52 +171,56 @@ def on_log( return metrics, metric_type = self._get_metrics(logs) logger.debug(f"on_log metrics, global_step {state.global_step}", metrics) + metrics["batches"] = metrics.get("batches", state.global_step) + metrics["epochs"] = metrics.get("epochs", state.epoch) if metric_type == TRAIN: # Prevents reporting metrics for the same step twice. This happens after # training is completed and average training metrics are reported with # the same step as the in-progress training metrics. - if self.last_metrics["train_step"] != state.global_step: + if self.last_train_metrics != state.global_step: + self.last_train_metrics = state.global_step if state.is_world_process_zero: - self.core_context.train.report_training_metrics( - steps_completed=state.global_step, metrics=metrics + # Note: state.global_step represents steps_completed, not step index + self.core_context.train.report_metrics( + group="training", steps_completed=state.global_step, metrics=metrics ) - metrics["train_step"] = state.global_step elif metric_type == EVAL: # Prevents reporting metrics for the same step twice. This happens when # after-training evaluation is completed, and it is reported with the same # step as the last during-training evaluation. - if self.last_metrics["eval_step"] != state.global_step: + if self.last_eval_metrics != state.global_step: + self.last_eval_metrics = state.global_step if state.is_world_process_zero: - self.core_context.train.report_validation_metrics( - steps_completed=state.global_step, metrics=metrics + self._check_eval_metrics(metrics) + # Note: state.global_step represents steps_completed, not step index + self.core_context.train.report_metrics( + group="validation", steps_completed=state.global_step, metrics=metrics ) - metrics["eval_step"] = state.global_step else: logger.warning(f"Metrics not reported: metric type = {metric_type}.") - self.last_metrics.update(metrics) - - # Update searcher state after collecting the metrics. - if self.updating_searcher is True: - self._update_searcher(state, control) - - # If searcher is NOT being updated and preemption signal is received - # (e.g., by pausing experiment in the WebUI), notify Trainer (via TrainerControl) - # to save the checkpoint and stop training. - if self.updating_searcher is False and self.core_context.preempt.should_preempt(): + # If we've been preempted, save a checkpoint and shut down training. + if self.core_context.preempt.should_preempt(): control.should_training_stop = True - control.should_save = True + # Don't set control.should_save now, or it can trigger multiple saves, if we trigger + # in a training on_log and arrive here again in an evaluate on_log. We would not cause + # that to happen, but other callbacks could, such as if it were just naturally time for + # an evaluation. So just let the save-at-end logic handle it. def _get_metrics(self, logs: Dict[str, Any]) -> Tuple[Dict[str, Any], str]: - metrics = logs metric_type = get_metric_type(logs) - if self.filter_metrics: - metrics = {} - for k, v in logs.items(): - if any(m in k for m in self.filter_metrics) is True: - metrics[k] = v - + if not self.filter_metrics: + metrics = logs + else: + metrics = {k: v for k, v in logs.items() if any(m in k for m in self.filter_metrics)} + # Remove the default rounded 'epoch' metric. + metrics.pop("epoch", None) + # Also remove speed metrics. + speed_suffixes = ["_runtime", "_per_second", "_compilation_time"] + speed_metrics = [m for m in metrics if any(m.endswith(s) for s in speed_suffixes)] + for m in speed_metrics: + metrics.pop(m, None) return metrics, metric_type def on_save( @@ -115,25 +230,24 @@ def on_save( control: transformers.TrainerControl, **kwargs: Any, ) -> None: - info = det.get_cluster_info() - assert info - + self.last_save = state.global_step # local_path is where HF Trainer saves model and tokenizer in a given step. local_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}") if state.is_world_process_zero: if self.user_data is not None: self._on_save_user_data(local_path) - det_checkpoint_metadata = { + metadata = { "steps_completed": state.global_step, - "trial_id": info.trial.trial_id, } + if self.info.task_type == "TRIAL": + metadata["trial_id"] = self.info.trial.trial_id def selector(x: str) -> bool: return x.startswith((f"checkpoint-{state.global_step}/", "runs/")) self.core_context.checkpoint.upload( - args.output_dir, metadata=det_checkpoint_metadata, shard=True, selector=selector + args.output_dir, metadata=metadata, shard=True, selector=selector ) def _on_save_user_data(self, save_path: str) -> None: @@ -145,28 +259,6 @@ def _on_save_user_data(self, save_path: str) -> None: with open(os.path.join(save_path, "my_data.json"), "w") as f: json.dump(self.user_data, f) - def load_last_checkpoint(self, args: transformers.TrainingArguments) -> None: - info = det.get_cluster_info() - assert info - - latest_checkpoint = info.latest_checkpoint - if latest_checkpoint is not None: - if args.overwrite_output_dir is True: - logger.info( - "Skip downloading last checkpoint from Determined due " - "to overwrite_output_dir=True." - ) - return - - # To resume DeepSpeed, each node requires ALL sharded model/optimizer states, - # so we can skip using selector and just download all files. - self.core_context.checkpoint.download(latest_checkpoint, args.output_dir) - - checkpoint_path = trainer_utils.get_last_checkpoint(args.output_dir) - args.resume_from_checkpoint = checkpoint_path - - logger.info(f"Latest checkpoint downloaded to {checkpoint_path}.") - def on_step_end( self, args: transformers.TrainingArguments, @@ -174,17 +266,14 @@ def on_step_end( control: transformers.TrainerControl, **kwargs: Any, ) -> None: - # state.epoch is not None only during training. - if state.epoch and self.searcher_unit == "batches": - if state.is_world_process_zero: - self.current_op.report_progress(state.global_step) - - if state.global_step >= self.current_op.length: - logger.info( - f"Max length of {self.current_op.length} steps reached for current " - f"searcher operation. Updating searcher." - ) - self._update_searcher(state, control) + if state.is_world_process_zero and args.max_steps > -1: + # There needs to be at least 1% increase in progress to report progress (maximum 100 + # report_progress API calls in per trial). + progress = state.global_step / args.max_steps + percent = int(progress * 100) + if percent > self.last_progress: + self.last_progress = percent + self.core_context.train.report_progress(progress) def on_epoch_end( self, @@ -193,95 +282,30 @@ def on_epoch_end( control: transformers.TrainerControl, **kwargs: Any, ) -> None: - # state.epoch is not None only during training. - if state.epoch and self.searcher_unit == "epochs": - if state.is_world_process_zero: - self.current_op.report_progress(state.epoch) - - if state.epoch >= self.current_op.length: - logger.info( - f"Max length of {state.epoch} epochs reached for current " - f"searcher operation. Updating searcher." - ) - self._update_searcher(state, control) - - def _update_searcher( - self, state: transformers.TrainerState, control: transformers.TrainerControl - ) -> None: - if self._metrics_reported(state.global_step) is False: - self._wait_for_metrics(control) - return - - if state.is_world_process_zero: - if self.last_metrics is None: - logger.warning( - "No training or evaluation metrics has been recorded. Please " - "check your settings for training metrics " - "(--logging_strategy and --logging_steps) or " - "evaluation metrics (--evaluation_strategy and --eval_steps). " - "Reporting trainer_state.best_metric to the searcher." - ) - searcher_metric = state.best_metric - elif self.searcher_metric not in self.last_metrics: - logger.warning( - f"Searcher metric {self.searcher_metric} from the yaml config file does " - "not match any of the recorded metrics " - f"in {self.last_metrics}. " - "Reporting trainer_state.best_metric to the searcher." - ) - searcher_metric = state.best_metric - else: - searcher_metric = self.last_metrics[self.searcher_metric] - - logger.info(f"Metric reported to searcher: {searcher_metric}") - self.current_op.report_completed(searcher_metric) - - self.updating_searcher = False + # Decide if we're about to shut down training. + is_end = False + if control.should_training_stop: + is_end = True + elif args.max_steps > -1: + is_end = state.global_step >= args.max_steps + else: + is_end = state.epoch >= args.num_train_epochs - try: - self.current_op = next(self.searcher_ops) - except StopIteration: - control.should_training_stop = True + # If training is ending, this is our last chance to ask for a eval and/or save. + if is_end: + # Avoid stale evaluate-at-end. + if state.global_step > self.last_eval_metrics: + # Also avoid evaluate-at-end if we have been preempted. + if self._force_final_evaluate and not self.core_context.preempt.should_preempt(): + control.should_evaluate = True + # Avoid stale save-at-end. + if state.global_step > self.last_save: + # You can't disable save-after-preemption. + if self._force_final_save or self.core_context.preempt.should_preempt(): + control.should_save = True - def _metrics_reported(self, step: int) -> bool: - return self.last_metrics["eval_step"] == step and self.last_metrics["train_step"] == step - - def _wait_for_metrics(self, control: transformers.TrainerControl) -> None: - # Notify Trainer (via transformers.TrainerControl) to: - # (1) log current training metrics, - # (2) evaluate the model and log evaluation metrics, - # (3) save the checkpoint. - # updating_searcher is as an internal flag that indicates we are - # in the process of updating the searcher with the current metrics. - control.should_log = True - control.should_evaluate = True - control.should_save = True - self.updating_searcher = True - - def _check_searcher_compatibility(self, args: transformers.TrainingArguments) -> None: - if self.searcher_unit == "batches": - if args.max_steps == -1: - self._raise_config_mismatch("epochs", args.num_train_epochs) - elif args.max_steps != self.searcher_max_length: - self._raise_config_mismatch("batches", args.max_steps) - elif self.searcher_unit == "epochs": - if args.max_steps != -1: - self._raise_config_mismatch("batches", args.max_steps) - elif args.num_train_epochs != self.searcher_max_length: - self._raise_config_mismatch("epochs", args.num_train_epochs) - - def _raise_config_mismatch( - self, - trainer_units: str, - trainer_len: float, - ) -> None: - raise ValueError( - f"HF trainer units {trainer_units}={trainer_len} MUST match searcher config " - f"{self.searcher_unit}={self.searcher_max_length}. " - f"Modify either --num_train_epochs for the training script or " - f"searcher.max_length.epochs in the experiment config so they are the same value " - f"(--max_steps and searcher.max_length.batches if using batches)." - ) + if state.is_world_process_zero and args.max_steps == -1: + self.core_context.train.report_progress(state.epoch / args.num_train_epochs) EVAL = "eval_" @@ -290,13 +314,10 @@ def _raise_config_mismatch( def get_metric_type(d: Dict[str, Any]) -> str: - for k, _ in d.items(): - if k.startswith(EVAL): - return EVAL - elif k.startswith(TEST): - return TEST - else: - return TRAIN + if any(k.startswith(EVAL) for k in d): + return EVAL + if any(k.startswith(TEST) for k in d): + return TEST return TRAIN diff --git a/harness/tests/cli/test_cli.py b/harness/tests/cli/test_cli.py index 9a5b0bf9b40..00ddfaccc59 100644 --- a/harness/tests/cli/test_cli.py +++ b/harness/tests/cli/test_cli.py @@ -1,6 +1,7 @@ import collections import inspect import io +import json import os import pathlib import sys @@ -12,13 +13,16 @@ import pytest import requests import requests_mock +from responses import matchers from determined.cli import cli, ntsc, render from determined.common import constants, context from determined.common.api import bindings from tests import filetree +from tests.cli import util MINIMAL_CONFIG = '{"description": "test"}' +MASTER_HOST = "http://localhost:8080" def test_parse_config() -> None: @@ -566,3 +570,114 @@ def test_dev_bindings_call_arg_unmarshal(case: Tuple[List[str], Dict[str, Any]]) _, params = dev.bindings_sig(bindings.get_ExpMetricNames) kwargs = dev.parse_args_to_kwargs(args, params) assert kwargs == expected, kwargs + + +def test_preview_search(tmp_path: pathlib.Path) -> None: + # Random + max_trials = 10 + searcher_config = { + "hyperparameters": { + "x": 12, + }, + "name": "test preview search (random)", + "searcher": { + "name": "random", + "metric": "loss", + "max_trials": max_trials, + }, + } + conf_path = tmp_path / "config.yaml" + with conf_path.open("w") as tmp_file: + tmp_file.write(json.dumps(searcher_config)) + + mock_resp = bindings.v1PreviewHPSearchResponse( + summary=bindings.v1SearchSummary( + config=searcher_config, + trials=[ + bindings.v1TrialSummary( + count=max_trials, + unit=bindings.v1SearchUnit(maxLength=True), + ) + ], + ) + ) + with util.standard_cli_rsps() as rsps: + rsps.post( + f"{MASTER_HOST}/api/v1/preview-hp-search", + status=200, + match=[ + matchers.json_params_matcher( + params={ + "config": searcher_config, + } + ) + ], + json=mock_resp.to_json(), + ) + expected_output = f"""Using search configuration: +{render.format_object_as_yaml(searcher_config)} + Trials | Training Time +----------+--------------------- + 10 | train to completion +""" + util.check_cli_output(["preview-search", str(conf_path)], expected_output) + + # ASHA + searcher_config = { + "hyperparameters": { + "x": 12, + }, + "name": "test preview search (asha)", + "searcher": { + "bracket_rungs": [], + "divisor": 5, + "max_concurrent_trials": 5, + "max_rungs": 5, + "max_time": 1000, + "max_trials": 10, + "metric": "loss", + "mode": "standard", + "name": "adaptive_asha", + "time_metric": "batch", + }, + } + conf_path = tmp_path / "config.yaml" + with conf_path.open("w") as tmp_file: + tmp_file.write(json.dumps(searcher_config)) + + mock_resp = bindings.v1PreviewHPSearchResponse( + summary=bindings.v1SearchSummary( + config=searcher_config, + trials=[ + bindings.v1TrialSummary( + count=7, + unit=bindings.v1SearchUnit(name="batch", value=200, maxLength=False), + ), + bindings.v1TrialSummary( + count=3, + unit=bindings.v1SearchUnit(name="batch", value=1000, maxLength=False), + ), + ], + ) + ) + with util.standard_cli_rsps() as rsps: + rsps.post( + f"{MASTER_HOST}/api/v1/preview-hp-search", + status=200, + match=[ + matchers.json_params_matcher( + params={ + "config": searcher_config, + } + ) + ], + json=mock_resp.to_json(), + ) + expected_output = f"""Using search configuration: +{render.format_object_as_yaml(searcher_config)} + Trials | Training Time +----------+---------------------- + 7 | train for 200 batch + 3 | train for 1000 batch +""" + util.check_cli_output(["preview-search", str(conf_path)], expected_output) diff --git a/harness/tests/cli/util.py b/harness/tests/cli/util.py index cbe674c68fe..b1d414558cb 100644 --- a/harness/tests/cli/util.py +++ b/harness/tests/cli/util.py @@ -1,4 +1,6 @@ import contextlib +import difflib +import io import os from typing import Any, Iterator, List, Optional, cast @@ -6,6 +8,7 @@ from responses import registries import determined as det +from determined.cli import cli from determined.common.api import authentication @@ -151,3 +154,18 @@ def expect_get_info( rsps.get(f"{master_url}/info", status=200, json={"version": det.__version__}) else: responses.get(f"{master_url}/info", status=200, json={"version": det.__version__}) + + +def check_cli_output(args: List[str], expected: str) -> None: + """ + Helper method to test CLI methods that checks redirected STDOUT from the executed command + matches expected output. + """ + with contextlib.redirect_stdout(io.StringIO()) as f: + cli.main(args=args) + actual = f.getvalue() + exp_lines = expected.splitlines(keepends=True) + act_lines = actual.splitlines(keepends=True) + diff_lines = difflib.ndiff(act_lines, exp_lines) + diff = "".join(diff_lines) + assert actual == expected, f"CLI output for {args} actual(-) != expected(+):\n {diff}" diff --git a/harness/tests/core/test_searcher.py b/harness/tests/core/test_searcher.py index a42b5f0d8f9..785ab4b84c1 100644 --- a/harness/tests/core/test_searcher.py +++ b/harness/tests/core/test_searcher.py @@ -1,4 +1,4 @@ -from typing import Any, List +from typing import Any from unittest import mock import pytest @@ -7,37 +7,13 @@ from tests import parallel -def make_test_searcher(ops: List[int], dist: core.DistributedContext) -> core.SearcherContext: - # Mock the session.get to return a few searcher ops - final_op = ops[-1] - ops = list(ops) - - def session_get(_: Any) -> Any: - assert ( - dist.rank == 0 - ), "worker SearcherContexts must not GET new ops, but ask the chief instead" - resp = mock.MagicMock() - if ops: - resp.json.return_value = { - "op": {"validateAfter": {"length": str(ops.pop(0))}}, - "completed": False, - } - else: - resp.json.return_value = { - "op": {"validateAfter": {"length": str(final_op)}}, - "completed": True, - } - return resp - +def make_test_searcher(max_length: int, dist: core.DistributedContext) -> core.SearcherContext: session = mock.MagicMock() - session.get.side_effect = session_get - searcher = core.SearcherContext( session=session, dist=dist, trial_id=1, - run_id=2, - allocation_id="3", + max_length=max_length, ) return searcher @@ -49,7 +25,7 @@ def test_searcher_workers_ask_chief(dummy: bool) -> None: @pex.run def searchers() -> core.SearcherContext: if not dummy: - searcher = make_test_searcher([5, 10, 15], pex.distributed) + searcher = make_test_searcher(5, pex.distributed) else: searcher = core.DummySearcherContext(dist=pex.distributed) epochs_trained = 0 @@ -72,10 +48,10 @@ def searchers() -> core.SearcherContext: return searcher if not dummy: - # Expect calls from chief: 15x progress, 4x completions + # Expect calls from chief: 5x progress chief = searchers[0] post_mock: Any = chief._session.post - assert post_mock.call_count == 19, post_mock.call_args_list + assert post_mock.call_count == 5, post_mock.call_args_list # The workers must not make any REST API calls at all. worker = searchers[1] @@ -88,7 +64,7 @@ def test_completion_check() -> None: @pex.run def do_test() -> None: - searcher = make_test_searcher([5], pex.distributed) + searcher = make_test_searcher(5, pex.distributed) ops = iter(searcher.operations()) next(ops) @@ -109,7 +85,7 @@ def test_searcher_chief_only(dummy: bool) -> None: @pex.run def do_test() -> None: if not dummy: - searcher = make_test_searcher([5, 10, 15], pex.distributed) + searcher = make_test_searcher(1, pex.distributed) else: searcher = core.DummySearcherContext(dist=pex.distributed) diff --git a/harness/tests/custom_search_mocks.py b/harness/tests/custom_search_mocks.py deleted file mode 100644 index 95549935b7f..00000000000 --- a/harness/tests/custom_search_mocks.py +++ /dev/null @@ -1,161 +0,0 @@ -import abc -import logging -import pathlib -from typing import Any, Dict, Iterable, List, Optional, Sequence, Union -from unittest import mock - -from determined import searcher -from determined.common import api -from determined.common.api import bindings - - -class MockMaster(metaclass=abc.ABCMeta): - @abc.abstractmethod - def handle_post_operations( - self, event: bindings.v1SearcherEvent, operations: List[searcher.Operation] - ) -> None: - pass - - @abc.abstractmethod - def handle_get_events(self) -> Optional[Sequence[bindings.v1SearcherEvent]]: - return [] - - @abc.abstractmethod - def add_event(self, event_obj: bindings.v1SearcherEvent) -> None: - pass - - -class SimulateMaster(MockMaster): - def __init__(self, metric: Union[float, Dict[str, Any]]) -> None: - self.events_queue: List[bindings.v1SearcherEvent] = [] # store event and - self.events_count = 0 - self.metric = metric - self.overall_progress = 0.0 - - def handle_post_operations( - self, event: bindings.v1SearcherEvent, operations: List[searcher.Operation] - ) -> None: - self._remove_upto(event) - self._process_operations(operations) - - def _remove_upto(self, event: bindings.v1SearcherEvent) -> None: - for i, e in enumerate(self.events_queue): - if e.id == event.id: - self.events_queue = self.events_queue[i + 1 :] - return - - raise RuntimeError(f"event not found in events queue: {event}") - - def _process_operations(self, operations: List[searcher.Operation]) -> None: - for op in operations: - self._append_events_for_op(op) # validate_after returns two events. - - def add_event(self, event_obj: bindings.v1SearcherEvent) -> None: - self.events_queue.append(event_obj) - - def handle_get_events(self) -> Optional[Sequence[bindings.v1SearcherEvent]]: - return self.events_queue - - def _append_events_for_op(self, op: searcher.Operation) -> None: - if type(op) == searcher.ValidateAfter: - validation_completed = bindings.v1ValidationCompleted( - requestId=str(op.request_id), - metric=self.metric, - validateAfterLength=str(op.length), - ) - self.events_count += 1 - event = bindings.v1SearcherEvent( - id=self.events_count, validationCompleted=validation_completed - ) - self.events_queue.append(event) - - trial_progress = bindings.v1TrialProgress( - requestId=str(op.request_id), partialUnits=float(op.length) - ) - self.events_count += 1 - event = bindings.v1SearcherEvent(id=self.events_count, trialProgress=trial_progress) - self.events_queue.append(event) - - elif type(op) == searcher.Create: - trial_created = bindings.v1TrialCreated(requestId=str(op.request_id)) - self.events_count += 1 - event = bindings.v1SearcherEvent(id=self.events_count, trialCreated=trial_created) - self.events_queue.append(event) - - elif type(op) == searcher.Progress: # no events - self.overall_progress - - elif type(op) == searcher.Close: - trial_closed = bindings.v1TrialClosed(requestId=str(op.request_id)) - self.events_count += 1 - event = bindings.v1SearcherEvent(id=self.events_count, trialClosed=trial_closed) - self.events_queue.append(event) - - elif type(op) == searcher.Shutdown: - exp_state = bindings.experimentv1State.COMPLETED - exp_inactive = bindings.v1ExperimentInactive(experimentState=exp_state) - self.events_count += 1 - event = bindings.v1SearcherEvent(id=self.events_count, experimentInactive=exp_inactive) - self.events_queue.append(event) - else: - pass - - -class MockMasterSearchRunner(searcher.LocalSearchRunner): - def __init__( - self, - search_method: searcher.SearchMethod, - mock_master_object: MockMaster, - searcher_dir: Optional[pathlib.Path] = None, - ): - super(MockMasterSearchRunner, self).__init__(search_method, searcher_dir) - self.mock_master_obj = mock_master_object - initial_ops = bindings.v1InitialOperations() - event_obj = bindings.v1SearcherEvent(id=1, initialOperations=initial_ops) - self.mock_master_obj.add_event(event_obj) - - def post_operations( - self, - session: api.Session, - experiment_id: int, - event: bindings.v1SearcherEvent, - operations: List[searcher.Operation], - ) -> None: - logging.info("MockMasterSearchRunner.post_operations") - self.mock_master_obj.handle_post_operations(event, operations) - - def get_events( - self, - session: api.Session, - experiment_id: int, - ) -> Optional[Sequence[bindings.v1SearcherEvent]]: - logging.info("MockMasterSearchRunner.get_events") - return self.mock_master_obj.handle_get_events() - - def run( - self, - exp_config: Union[Dict[str, Any], str], - context_dir: Optional[str] = None, - includes: Optional[Iterable[Union[str, pathlib.Path]]] = None, - ) -> int: - logging.info("MockMasterSearchRunner.run") - experiment_id_file = self.searcher_dir.joinpath("experiment_id") - exp_id = 4 # dummy exp - with experiment_id_file.open("w") as f: - f.write(str(exp_id)) - state_path = self._get_state_path(exp_id) - state_path.mkdir(parents=True) - logging.info(f"Starting HP searcher for mock experiment {exp_id}") - self.state.experiment_id = exp_id - self.state.last_event_id = 0 - super(MockMasterSearchRunner, self).save_state(exp_id, []) - experiment_id = exp_id - operations: Optional[List[searcher.Operation]] = None - session: api.Session = mock.Mock() - super(MockMasterSearchRunner, self).run_experiment( - experiment_id, session, operations, sleep_time=0.0 - ) - return exp_id - - def _get_state_path(self, experiment_id: int) -> pathlib.Path: - return self.searcher_dir.joinpath(f"exp_{experiment_id}") diff --git a/harness/tests/experiment/fixtures/deepspeed_linear_model.py b/harness/tests/experiment/fixtures/deepspeed_linear_model.py index 900236c6cb1..3fd06f08cdf 100644 --- a/harness/tests/experiment/fixtures/deepspeed_linear_model.py +++ b/harness/tests/experiment/fixtures/deepspeed_linear_model.py @@ -12,6 +12,43 @@ from determined import pytorch +class MetricsCallbacks(pytorch.PyTorchCallback): + def __init__(self, trial) -> None: + self.trial = trial + super().__init__() + + def on_validation_end(self, metrics: Dict) -> None: + assert "loss" in metrics.keys() + + def on_checkpoint_upload_end(self, uuid: str) -> None: + self.trial.checkpoint_uuid = uuid + + def on_checkpoint_load_start(self, checkpoint: Optional[Dict]): + self.trial.checkpoint_found = checkpoint is not None + + +class ReproducibilityCallbacks(pytorch.PyTorchCallback): + def __init__(self, trial) -> None: + self.trial = trial + super().__init__() + + def on_validation_end(self, metrics: Dict) -> None: + self.trial.val_metrics.append(metrics) + + def on_training_workload_end(self, avg_metrics, batch_metrics): + self.trial.avg_metrics.append(avg_metrics) + self.trial.batch_metrics.append(batch_metrics) + + +class TwoEngineMetricsCallbacks(pytorch.PyTorchCallback): + def __init__(self) -> None: + super().__init__() + + def on_validation_end(self, metrics: Dict) -> None: + assert "loss1" in metrics.keys() + assert "loss2" in metrics.keys() + + class LinearDataset(torch.utils.data.Dataset): def __init__(self, a: int, b: int, num_samples: int): self.a = a @@ -31,9 +68,11 @@ def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]: class LinearDeepSpeedTrial(det_ds.DeepSpeedTrial): _searcher_metric = "loss" - def __init__(self, context: det_ds.DeepSpeedTrialContext): + def __init__(self, context: det_ds.DeepSpeedTrialContext, hparams: Dict): self.context = context - self.hparams = attrdict.AttrDict(context.get_hparams()) + self.hparams = attrdict.AttrDict(hparams) + self.checkpoint_uuid = None + self.checkpoint_found = None if ( self.hparams.test_manual_init_distributed or self.hparams.test_fail_manual_init_distributed @@ -64,6 +103,9 @@ def __init__(self, context: det_ds.DeepSpeedTrialContext): if self.hparams.test_custom_reducer: self.reducer = self.context.wrap_reducer(lambda x: np.mean(x) * 2, name="loss_2x") + def build_callbacks(self) -> Dict[str, pytorch.PyTorchCallback]: + return {"my_callbacks": MetricsCallbacks(trial=self)} + def build_training_data_loader(self) -> Union[pytorch.DataLoader, torch.utils.data.DataLoader]: dataset = LinearDataset(1, 1, self.ds_config.train_batch_size * 2) dataloader = pytorch.DataLoader( @@ -158,8 +200,8 @@ def evaluate_batch( class LinearCallbackTrial(LinearDeepSpeedTrial): - def __init__(self, context: det_ds.DeepSpeedTrialContext): - super().__init__(context) + def __init__(self, context: det_ds.DeepSpeedTrialContext, hparams: Dict): + super().__init__(context, hparams) self.counter = counter.Counter() def build_callbacks(self) -> Dict[str, pytorch.PyTorchCallback]: @@ -167,9 +209,9 @@ def build_callbacks(self) -> Dict[str, pytorch.PyTorchCallback]: class LinearTwoEngineTrial(LinearDeepSpeedTrial): - def __init__(self, context: det_ds.DeepSpeedTrialContext): + def __init__(self, context: det_ds.DeepSpeedTrialContext, hparams: Dict): self.context = context - self.hparams = attrdict.AttrDict(context.get_hparams()) + self.hparams = attrdict.AttrDict(hparams) self.ds_config = attrdict.AttrDict(self.hparams.deepspeed_config) model1 = torch.nn.Linear(1, 1) model2 = torch.nn.Linear(1, 1) @@ -183,6 +225,9 @@ def __init__(self, context: det_ds.DeepSpeedTrialContext): self.model1 = self.context.wrap_model_engine(self.model1) self.model2 = self.context.wrap_model_engine(self.model2) + def build_callbacks(self) -> Dict[str, pytorch.PyTorchCallback]: + return {"my_callbacks": TwoEngineMetricsCallbacks()} + def train_batch( self, dataloader_iter: Optional[Iterator[pytorch.TorchData]], @@ -214,10 +259,13 @@ def take_step(model): class LinearPipelineEngineTrial(LinearDeepSpeedTrial): - def __init__(self, context: det_ds.DeepSpeedTrialContext): + def __init__(self, context: det_ds.DeepSpeedTrialContext, hparams: Dict): self.context = context - self.hparams = attrdict.AttrDict(context.get_hparams()) + self.hparams = attrdict.AttrDict(hparams) self.ds_config = attrdict.AttrDict(self.hparams.deepspeed_config) + self.avg_metrics = [] + self.batch_metrics = [] + self.val_metrics = [] model = torch.nn.Linear(1, 1) model = deepspeed.PipelineModule( layers=[model], @@ -232,6 +280,9 @@ def __init__(self, context: det_ds.DeepSpeedTrialContext): self.model = self.context.wrap_model_engine(self.model) self.context.set_mpu(det_ds.make_deepspeed_mpu(self.model.mpu)) + def build_callbacks(self) -> Dict[str, pytorch.PyTorchCallback]: + return {"my_callbacks": ReproducibilityCallbacks(trial=self)} + def train_batch( self, dataloader_iter: Optional[Iterator[pytorch.TorchData]], diff --git a/harness/tests/experiment/fixtures/pytorch_amp/apex_amp.yaml b/harness/tests/experiment/fixtures/pytorch_amp/apex_amp.yaml deleted file mode 100644 index dd457242478..00000000000 --- a/harness/tests/experiment/fixtures/pytorch_amp/apex_amp.yaml +++ /dev/null @@ -1,17 +0,0 @@ -description: mnist_pytorch with PyTorch APEX support configured -data: - url: https://s3-us-west-2.amazonaws.com/determined-ai-test-data/pytorch_mnist.tar.gz -hyperparameters: - learning_rate: 1.0 - global_batch_size: 64 - n_filters1: 32 - n_filters2: 64 - dropout1: 0.25 - dropout2: 0.5 -searcher: - name: single - metric: validation_loss - max_length: - batches: 937 #60,000 training images with batch size 64 - smaller_is_better: true -entrypoint: python3 apex_amp_model_def.py diff --git a/harness/tests/experiment/fixtures/pytorch_amp/apex_amp_distributed.yaml b/harness/tests/experiment/fixtures/pytorch_amp/apex_amp_distributed.yaml deleted file mode 100644 index 16e46933f7f..00000000000 --- a/harness/tests/experiment/fixtures/pytorch_amp/apex_amp_distributed.yaml +++ /dev/null @@ -1,25 +0,0 @@ -description: mnist_pytorch with PyTorch APEX support configured -data: - url: https://s3-us-west-2.amazonaws.com/determined-ai-test-data/pytorch_mnist.tar.gz -environment: - environment_variables: - #This is specified due to an error in NCCL/veth interface - #https://github.com/pytorch/pytorch/issues/68893 - - NCCL_SOCKET_IFNAME=ens,eth -hyperparameters: - learning_rate: 1.0 - global_batch_size: 64 - n_filters1: 32 - n_filters2: 64 - dropout1: 0.25 - dropout2: 0.5 -resources: - slots_per_trial: 8 -searcher: - name: single - metric: validation_loss - max_length: - batches: 937 #60,000 training images with batch size 64 - smaller_is_better: true -max_restarts: 0 -entrypoint: python3 -m determined.launch.torch_distributed python3 apex_amp_model_def.py diff --git a/harness/tests/experiment/fixtures/pytorch_amp/auto_amp.yaml b/harness/tests/experiment/fixtures/pytorch_amp/auto_amp.yaml deleted file mode 100644 index 2e2e5412774..00000000000 --- a/harness/tests/experiment/fixtures/pytorch_amp/auto_amp.yaml +++ /dev/null @@ -1,17 +0,0 @@ -description: mnist_pytorch with PyTorch AMP API automatically enabled -data: - url: https://s3-us-west-2.amazonaws.com/determined-ai-test-data/pytorch_mnist.tar.gz -hyperparameters: - learning_rate: 1.0 - global_batch_size: 64 - n_filters1: 32 - n_filters2: 64 - dropout1: 0.25 - dropout2: 0.5 -searcher: - name: single - metric: validation_loss - max_length: - batches: 937 #60,000 training images with batch size 64 - smaller_is_better: true -entrypoint: python3 auto_amp_model_def.py diff --git a/harness/tests/experiment/fixtures/pytorch_amp/auto_amp_distributed.yaml b/harness/tests/experiment/fixtures/pytorch_amp/auto_amp_distributed.yaml deleted file mode 100644 index 1f00fea0c79..00000000000 --- a/harness/tests/experiment/fixtures/pytorch_amp/auto_amp_distributed.yaml +++ /dev/null @@ -1,25 +0,0 @@ -description: mnist_pytorch with PyTorch AMP API automatically enabled -data: - url: https://s3-us-west-2.amazonaws.com/determined-ai-test-data/pytorch_mnist.tar.gz -environment: - environment_variables: - #This is specified due to an error in NCCL/veth interface - #https://github.com/pytorch/pytorch/issues/68893 - - NCCL_SOCKET_IFNAME=ens,eth -hyperparameters: - learning_rate: 1.0 - global_batch_size: 64 - n_filters1: 32 - n_filters2: 64 - dropout1: 0.25 - dropout2: 0.5 -resources: - slots_per_trial: 8 -searcher: - name: single - metric: validation_loss - max_length: - batches: 937 #60,000 training images with batch size 64 - smaller_is_better: true -max_restarts: 0 -entrypoint: python3 -m determined.launch.torch_distributed python3 auto_amp_model_def.py diff --git a/harness/tests/experiment/fixtures/pytorch_amp/manual_amp.yaml b/harness/tests/experiment/fixtures/pytorch_amp/manual_amp.yaml deleted file mode 100644 index 95920f58571..00000000000 --- a/harness/tests/experiment/fixtures/pytorch_amp/manual_amp.yaml +++ /dev/null @@ -1,17 +0,0 @@ -description: mnist_pytorch_const with PyTorch AMP API manually applied -data: - url: https://s3-us-west-2.amazonaws.com/determined-ai-test-data/pytorch_mnist.tar.gz -hyperparameters: - learning_rate: 1.0 - global_batch_size: 64 - n_filters1: 32 - n_filters2: 64 - dropout1: 0.25 - dropout2: 0.5 -searcher: - name: single - metric: validation_loss - max_length: - batches: 937 #60,000 training images with batch size 64 - smaller_is_better: true -entrypoint: python3 manual_amp_model_def.py diff --git a/harness/tests/experiment/fixtures/pytorch_amp/manual_amp_distributed.yaml b/harness/tests/experiment/fixtures/pytorch_amp/manual_amp_distributed.yaml deleted file mode 100644 index c9b6eb6d294..00000000000 --- a/harness/tests/experiment/fixtures/pytorch_amp/manual_amp_distributed.yaml +++ /dev/null @@ -1,25 +0,0 @@ -description: mnist_pytorch_const with PyTorch AMP API manually applied -data: - url: https://s3-us-west-2.amazonaws.com/determined-ai-test-data/pytorch_mnist.tar.gz -environment: - environment_variables: - #This is specified due to an error in NCCL/veth interface - #https://github.com/pytorch/pytorch/issues/68893 - - NCCL_SOCKET_IFNAME=ens,eth -hyperparameters: - learning_rate: 1.0 - global_batch_size: 64 - n_filters1: 32 - n_filters2: 64 - dropout1: 0.25 - dropout2: 0.5 -resources: - slots_per_trial: 8 -searcher: - name: single - metric: validation_loss - max_length: - batches: 937 #60,000 training images with batch size 64 - smaller_is_better: true -max_restarts: 0 -entrypoint: python3 -m determined.launch.torch_distributed python3 manual_amp_model_def.py diff --git a/harness/tests/experiment/integrations/test_deepspeed_trial.py b/harness/tests/experiment/integrations/test_deepspeed_trial.py index 980196b265f..365034ffe14 100644 --- a/harness/tests/experiment/integrations/test_deepspeed_trial.py +++ b/harness/tests/experiment/integrations/test_deepspeed_trial.py @@ -4,17 +4,18 @@ import os import pathlib import shutil -from typing import Any, Dict, Iterator, Optional +from typing import Iterator +import appdirs import pytest import torch from deepspeed.runtime import config_utils import determined -import determined.pytorch.deepspeed as det_deepspeed -from determined import workload -from tests.experiment import utils # noqa: I100 -from tests.experiment.fixtures import deepspeed_linear_model +import determined.pytorch.deepspeed as det_ds +from determined import pytorch # noqa: I2041 +from determined.pytorch.deepspeed import _trainer # noqa: I2041 +from tests.experiment.fixtures import deepspeed_linear_model # noqa: I2041 ds_config_path = str( pathlib.Path(__file__).resolve().parent.parent.joinpath("fixtures/ds_config.json") @@ -36,6 +37,12 @@ def manual_init_distributed() -> Iterator[None]: del os.environ["DET_MANUAL_INIT_DISTRIBUTED"] +# Checks shm size and skips certain tests if it can't be determined or isn't enough. +# TODO: Remove these skips after CI is updated (INFENG-659) +def check_shm_size() -> bool: + return pathlib.Path("/dev/shm").exists() and shutil.disk_usage("/dev/shm")[0] < 10**8 + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="no gpu available") @pytest.mark.deepspeed @pytest.mark.gpu @@ -76,522 +83,229 @@ def test_fail_manual_init_distributed(self, manual_init_distributed: None): updated_hparams = copy.deepcopy(self.hparams) updated_hparams["test_fail_manual_init_distributed"] = True - def make_workloads() -> workload.Stream: - trainer = utils.TrainAndValidate() - - yield from trainer.send( - steps=10, - validation_freq=10, - train_batch_calls=self.data_parallel_only_auto_train_batch_calls, - ) - training_metrics, validation_metrics = trainer.result() - - for metrics in validation_metrics: - assert "loss" in metrics - with pytest.raises(AssertionError, match=r"Distributed backend is not initialized. .*"): - _ = utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.LinearDeepSpeedTrial, - hparams=updated_hparams, - workloads=make_workloads(), - trial_seed=self.trial_seed, - expose_gpus=True, - ) + with det_ds.init() as train_context: + trial = deepspeed_linear_model.LinearDeepSpeedTrial(train_context, updated_hparams) + trainer = det_ds.Trainer(trial, train_context) + trainer.fit(max_length=pytorch.Batch(16)) def test_manual_init_distributed(self, manual_init_distributed: None): updated_hparams = copy.deepcopy(self.hparams) updated_hparams["test_manual_init_distributed"] = True - def make_workloads() -> workload.Stream: - trainer = utils.TrainAndValidate() + with det_ds.init() as train_context: + trial = deepspeed_linear_model.LinearDeepSpeedTrial(train_context, updated_hparams) + trainer = det_ds.Trainer(trial, train_context) + trainer.fit(max_length=pytorch.Batch(16)) - yield from trainer.send( - steps=10, - validation_freq=10, - train_batch_calls=self.data_parallel_only_auto_train_batch_calls, - ) - training_metrics, validation_metrics = trainer.result() - - for metrics in validation_metrics: - assert "loss" in metrics - - _ = utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.LinearDeepSpeedTrial, - hparams=updated_hparams, - workloads=make_workloads(), - trial_seed=self.trial_seed, - expose_gpus=True, - ) assert torch.distributed.is_initialized() def test_linear_model(self) -> None: - def make_workloads() -> workload.Stream: - trainer = utils.TrainAndValidate() - - yield from trainer.send( - steps=10, - validation_freq=10, - train_batch_calls=self.data_parallel_only_auto_train_batch_calls, - ) - training_metrics, validation_metrics = trainer.result() - - for metrics in validation_metrics: - assert "loss" in metrics - - controller = utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.LinearDeepSpeedTrial, - hparams=self.hparams, - workloads=make_workloads(), - trial_seed=self.trial_seed, - expose_gpus=True, - ) - controller.run() + with det_ds.init() as train_context: + trial = deepspeed_linear_model.LinearDeepSpeedTrial(train_context, self.hparams) + trainer = det_ds.Trainer(trial, train_context) + trainer.fit(validation_period=pytorch.Batch(16), max_length=pytorch.Batch(16)) def test_manual_grad_acc_metrics(self) -> None: updated_hparams = copy.deepcopy(self.hparams) updated_hparams["test_manual_grad_acc"] = True - def make_workloads() -> workload.Stream: - trainer = utils.TrainAndValidate() - - yield from trainer.send(steps=10, validation_freq=10, train_batch_calls=1) - training_metrics, validation_metrics = trainer.result() - - for metrics in validation_metrics: - assert "loss" in metrics - - controller = utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.LinearDeepSpeedTrial, - hparams=updated_hparams, - workloads=make_workloads(), - trial_seed=self.trial_seed, - expose_gpus=True, - ) - controller.run() + with det_ds.init() as train_context: + trial = deepspeed_linear_model.LinearDeepSpeedTrial(train_context, updated_hparams) + trainer = det_ds.Trainer(trial, train_context) + trainer.fit(max_length=pytorch.Batch(16)) def test_fail_manual_grad_acc_metrics(self) -> None: updated_hparams = copy.deepcopy(self.hparams) updated_hparams["test_fail_manual_grad_acc"] = True - def make_workloads() -> workload.Stream: - trainer = utils.TrainAndValidate() - - yield from trainer.send(steps=10, validation_freq=10, train_batch_calls=1) - training_metrics, validation_metrics = trainer.result() - - for metrics in validation_metrics: - assert "loss" in metrics - with pytest.raises(AssertionError, match="did not train for gradient accumulation steps"): - controller = utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.LinearDeepSpeedTrial, - hparams=updated_hparams, - workloads=make_workloads(), - trial_seed=self.trial_seed, - expose_gpus=True, - ) - controller.run() + with det_ds.init() as train_context: + trial = deepspeed_linear_model.LinearDeepSpeedTrial(train_context, updated_hparams) + trainer = det_ds.Trainer(trial, train_context) + trainer.fit(max_length=pytorch.Batch(16)) def test_custom_dataloader(self) -> None: updated_hparams = copy.deepcopy(self.hparams) updated_hparams["test_manual_dataloader"] = True - def make_workloads() -> workload.Stream: - trainer = utils.TrainAndValidate() - - yield from trainer.send( - steps=10, - validation_freq=10, - train_batch_calls=self.data_parallel_only_auto_train_batch_calls, - ) - training_metrics, validation_metrics = trainer.result() - - for metrics in validation_metrics: - assert "loss" in metrics - - controller = utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.LinearDeepSpeedTrial, - hparams=updated_hparams, - workloads=make_workloads(), - trial_seed=self.trial_seed, - expose_gpus=True, - ) - controller.run() + with det_ds.init() as train_context: + trial = deepspeed_linear_model.LinearDeepSpeedTrial(train_context, updated_hparams) + trainer = det_ds.Trainer(trial, train_context) + trainer.fit(validation_period=pytorch.Batch(16), max_length=pytorch.Batch(16)) def test_fail_dataset_repro_check(self) -> None: updated_hparams = copy.deepcopy(self.hparams) updated_hparams["test_fail_dataset_repro_check"] = True - def make_workloads() -> workload.Stream: - trainer = utils.TrainAndValidate() - - yield from trainer.send( - steps=10, - validation_freq=10, - train_batch_calls=self.data_parallel_only_auto_train_batch_calls, - ) - training_metrics, validation_metrics = trainer.result() - - for metrics in validation_metrics: - assert "loss" in metrics - with pytest.raises(RuntimeError, match=r".* reproducibility .* disable this check .*"): - controller = utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.LinearDeepSpeedTrial, - hparams=updated_hparams, - workloads=make_workloads(), - trial_seed=self.trial_seed, - expose_gpus=True, - ) - controller.run() + with det_ds.init() as train_context: + trial = deepspeed_linear_model.LinearDeepSpeedTrial(train_context, updated_hparams) + trainer = det_ds.Trainer(trial, train_context) + trainer.fit(max_length=pytorch.Batch(16)) def test_invalid_valid_dataset(self) -> None: - def make_workloads() -> workload.Stream: - trainer = utils.TrainAndValidate() - - yield from trainer.send( - steps=10, - validation_freq=10, - train_batch_calls=self.data_parallel_only_auto_train_batch_calls, - ) - with pytest.raises( determined.errors.InvalidExperimentException, match=r".* train micro batches .* should not be less than .*", ): - controller = utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.InvalidValidDatasetTrial, - hparams=self.hparams, - workloads=make_workloads(), - trial_seed=self.trial_seed, - expose_gpus=True, - ) - controller.run() + with det_ds.init() as train_context: + trial = deepspeed_linear_model.InvalidValidDatasetTrial(train_context, self.hparams) + trainer = det_ds.Trainer(trial, train_context) + trainer.fit(validation_period=pytorch.Batch(16), max_length=pytorch.Batch(16)) def test_invalid_train_metric(self) -> None: - def make_workloads() -> workload.Stream: - trainer = utils.TrainAndValidate() - - yield from trainer.send( - steps=10, - validation_freq=10, - train_batch_calls=self.data_parallel_only_auto_train_batch_calls, - ) - with pytest.raises( determined.errors.InvalidExperimentException, match=r"train_batch() must return a dictionary .*", ): - controller = utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.InvalidTrainMetricTrial, - hparams=self.hparams, - workloads=make_workloads(), - trial_seed=self.trial_seed, - expose_gpus=True, - ) - controller.run() + with det_ds.init() as train_context: + trial = deepspeed_linear_model.InvalidTrainMetricTrial(train_context, self.hparams) + trainer = det_ds.Trainer(trial, train_context) + trainer.fit(validation_period=pytorch.Batch(16), max_length=pytorch.Batch(16)) def test_invalid_valid_metric(self) -> None: - def make_workloads() -> workload.Stream: - trainer = utils.TrainAndValidate() - - yield from trainer.send( - steps=10, - validation_freq=10, - train_batch_calls=self.data_parallel_only_auto_train_batch_calls, - ) - with pytest.raises( determined.errors.InvalidExperimentException, match=r"evaluate_batch must return a dictionary .*", ): - controller = utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.InvalidValidMetricTrial, - hparams=self.hparams, - workloads=make_workloads(), - trial_seed=self.trial_seed, - expose_gpus=True, - ) - controller.run() + with det_ds.init() as train_context: + trial = deepspeed_linear_model.InvalidValidMetricTrial(train_context, self.hparams) + trainer = det_ds.Trainer(trial, train_context) + trainer.fit(validation_period=pytorch.Batch(16), max_length=pytorch.Batch(16)) def test_differing_valid_metric_keys(self) -> None: - def make_workloads() -> workload.Stream: - trainer = utils.TrainAndValidate() - - yield from trainer.send( - steps=10, - validation_freq=10, - train_batch_calls=self.data_parallel_only_auto_train_batch_calls, - ) - with pytest.raises( - determined.errors.InvalidExperimentException, - match=r".* metric names must match across all batches .*", + ValueError, + match=r"Validation metric names must match across all batches of data: .*", ): - controller = utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.DifferingValidMetricKeyTrial, - hparams=self.hparams, - workloads=make_workloads(), - trial_seed=self.trial_seed, - expose_gpus=True, - ) - controller.run() + with det_ds.init() as train_context: + trial = deepspeed_linear_model.DifferingValidMetricKeyTrial( + train_context, self.hparams + ) + trainer = det_ds.Trainer(trial, train_context) + trainer.fit(validation_period=pytorch.Batch(16), max_length=pytorch.Batch(16)) def test_fail_multiple_set_mpu(self): - def make_workloads() -> workload.Stream: - trainer = utils.TrainAndValidate() - - yield from trainer.send( - steps=1, - validation_freq=1, - train_batch_calls=self.data_parallel_only_auto_train_batch_calls, - ) - with pytest.raises( - determined.errors.InvalidExperimentException, match=r"Only one MPU can be passed .*" + determined.errors.InvalidExperimentException, + match=r"Only one MPU can be passed to DeepSpeedTrialContext.", ): - controller = utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.LinearDeepSpeedTrial, - hparams=self.hparams, - workloads=make_workloads(), - trial_seed=self.trial_seed, - expose_gpus=True, - ) - controller.context.set_mpu( - det_deepspeed.make_data_parallel_mpu(controller.context.distributed) - ) - controller.context.set_mpu( - det_deepspeed.make_data_parallel_mpu(controller.context.distributed) - ) + with det_ds.init() as train_context: + _ = deepspeed_linear_model.LinearDeepSpeedTrial(train_context, self.hparams) + train_context.set_mpu(det_ds.make_data_parallel_mpu(train_context.distributed)) + train_context.set_mpu(det_ds.make_data_parallel_mpu(train_context.distributed)) def test_custom_reducer(self) -> None: updated_hparams = copy.deepcopy(self.hparams) updated_hparams["test_custom_reducer"] = True - def make_workloads() -> workload.Stream: - trainer = utils.TrainAndValidate() - - yield from trainer.send( - steps=10, - validation_freq=10, - train_batch_calls=self.data_parallel_only_auto_train_batch_calls, - ) - training_metrics, validation_metrics = trainer.result() - - for metrics in validation_metrics: - assert "loss" in metrics - - controller = utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.LinearDeepSpeedTrial, - hparams=updated_hparams, - workloads=make_workloads(), - trial_seed=self.trial_seed, - expose_gpus=True, - ) - controller.run() + with det_ds.init() as train_context: + trial = deepspeed_linear_model.LinearDeepSpeedTrial(train_context, updated_hparams) + trainer = det_ds.Trainer(trial, train_context) + trainer.fit(validation_period=pytorch.Batch(16), max_length=pytorch.Batch(16)) def test_linear_non_scalar_metrics(self) -> None: updated_hparams = copy.deepcopy(self.hparams) updated_hparams["return_non_scalar_metrics"] = True - def make_workloads() -> workload.Stream: - trainer = utils.TrainAndValidate() - - yield from trainer.send( - steps=10, - validation_freq=10, - train_batch_calls=self.data_parallel_only_auto_train_batch_calls, - ) - training_metrics, validation_metrics = trainer.result() - - for metrics in validation_metrics: - assert "loss" in metrics - - controller = utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.LinearDeepSpeedTrial, - hparams=updated_hparams, - workloads=make_workloads(), - trial_seed=self.trial_seed, - expose_gpus=True, - ) - controller.run() + with det_ds.init() as train_context: + trial = deepspeed_linear_model.LinearDeepSpeedTrial(train_context, updated_hparams) + trainer = det_ds.Trainer(trial, train_context) + trainer.fit(validation_period=pytorch.Batch(16), max_length=pytorch.Batch(16)) def test_linear_pipeline_model(self) -> None: - def make_workloads() -> workload.Stream: - trainer = utils.TrainAndValidate() - - yield from trainer.send(steps=1, validation_freq=1, train_batch_calls=1) - training_metrics, validation_metrics = trainer.result() - - for metrics in validation_metrics: - assert "loss" in metrics - - controller = utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.LinearPipelineEngineTrial, - hparams=self.hparams, - workloads=make_workloads(), - trial_seed=self.trial_seed, - expose_gpus=True, - ) - controller.run() + with det_ds.init() as train_context: + trial = deepspeed_linear_model.LinearPipelineEngineTrial(train_context, self.hparams) + trainer = det_ds.Trainer(trial, train_context) + trainer.fit(validation_period=pytorch.Batch(16), max_length=pytorch.Batch(16)) def test_two_model_engines(self) -> None: - def make_workloads() -> workload.Stream: - trainer = utils.TrainAndValidate() - - yield from trainer.send( - steps=1, - validation_freq=1, - train_batch_calls=self.data_parallel_only_auto_train_batch_calls, - ) - training_metrics, validation_metrics = trainer.result() - - for metrics in validation_metrics: - assert "loss1" in metrics - assert "loss2" in metrics - - controller = utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.LinearTwoEngineTrial, - hparams=self.hparams, - workloads=make_workloads(), - trial_seed=self.trial_seed, - expose_gpus=True, - ) - controller.run() - - @pytest.mark.skipif(shutil.disk_usage("/dev/shm")[0] < 10**8, reason="insufficient shm size") - def test_checkpointing_and_restoring(self, tmp_path: pathlib.Path) -> None: - def make_trial_controller_fn( - workloads: workload.Stream, - checkpoint_dir: Optional[str] = None, - latest_checkpoint: Optional[Dict[str, Any]] = None, - steps_completed: int = 0, - ) -> determined.TrialController: - return utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.LinearPipelineEngineTrial, - hparams=self.hparams, - workloads=workloads, - trial_seed=self.trial_seed, - checkpoint_dir=checkpoint_dir, - latest_checkpoint=latest_checkpoint, - steps_completed=steps_completed, - expose_gpus=True, - ) - - utils.checkpointing_and_restoring_test(make_trial_controller_fn, tmp_path) - - def test_restore_invalid_checkpoint(self, tmp_path: pathlib.Path) -> None: - # Build, train, and save a checkpoint with the normal hyperparameters. - checkpoint_dir = str(tmp_path.joinpath("checkpoint")) - latest_checkpoint = None - steps_completed = 0 - - def make_workloads_1() -> workload.Stream: - trainer = utils.TrainAndValidate() - yield from trainer.send( - steps=1, - validation_freq=1, - train_batch_calls=self.data_parallel_only_auto_train_batch_calls, - ) - interceptor = workload.WorkloadResponseInterceptor() - yield from interceptor.send(workload.checkpoint_workload()) - nonlocal latest_checkpoint, steps_completed - latest_checkpoint = interceptor.metrics_result()["uuid"] - steps_completed = trainer.get_steps_completed() - - controller1 = utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.LinearDeepSpeedTrial, - hparams=self.hparams, - workloads=make_workloads_1(), - trial_seed=self.trial_seed, - checkpoint_dir=checkpoint_dir, - expose_gpus=True, - ) - controller1.run() - - # Verify that an invalid architecture fails to load from the checkpoint. - def make_workloads_2() -> workload.Stream: - trainer = utils.TrainAndValidate() - yield from trainer.send( - steps=1, - validation_freq=1, - train_batch_calls=self.data_parallel_only_auto_train_batch_calls, + with det_ds.init() as train_context: + trial = deepspeed_linear_model.LinearTwoEngineTrial(train_context, self.hparams) + trainer = det_ds.Trainer(trial, train_context) + trainer.fit(validation_period=pytorch.Batch(16), max_length=pytorch.Batch(16)) + + def test_checkpointing_and_restoring(self) -> None: + with det_ds.init() as train_context: + trial1 = deepspeed_linear_model.LinearDeepSpeedTrial(train_context, self.hparams) + trainer = det_ds.Trainer(trial1, train_context) + assert trial1.checkpoint_uuid is None + trainer.fit(validation_period=pytorch.Batch(16), max_length=pytorch.Batch(16)) + with det_ds.init() as train_context: + trial2 = deepspeed_linear_model.LinearDeepSpeedTrial(train_context, self.hparams) + trainer = det_ds.Trainer(trial2, train_context) + assert trial1.checkpoint_uuid is not None + trainer.fit( + validation_period=pytorch.Batch(16), + max_length=pytorch.Batch(16), + latest_checkpoint=os.path.join( + appdirs.user_data_dir("determined"), trial1.checkpoint_uuid + ), ) - with pytest.raises(AssertionError, match="Failed to load deepspeed checkpoint."): - controller2 = utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.LinearTwoEngineTrial, - hparams=self.hparams, - workloads=make_workloads_2(), - trial_seed=self.trial_seed, - checkpoint_dir=checkpoint_dir, - latest_checkpoint=latest_checkpoint, - steps_completed=steps_completed, - expose_gpus=True, - ) - controller2.run() - - # TODO: Remove these skips after CI is updated (INFENG-659) - @pytest.mark.skipif(shutil.disk_usage("/dev/shm")[0] < 10**8, reason="insufficient shm size") + def test_restore_invalid_checkpoint(self) -> None: + with det_ds.init() as train_context: + trial1 = deepspeed_linear_model.LinearDeepSpeedTrial(train_context, self.hparams) + trainer = det_ds.Trainer(trial1, train_context) + assert trial1.checkpoint_uuid is None + trainer.fit(validation_period=pytorch.Batch(16), max_length=pytorch.Batch(16)) + + with det_ds.init() as train_context: + trial2 = deepspeed_linear_model.LinearTwoEngineTrial(train_context, self.hparams) + trainer = det_ds.Trainer(trial2, train_context) + assert trial1.checkpoint_uuid is not None + with pytest.raises(AssertionError, match="Failed to load deepspeed checkpoint."): + trainer.fit( + validation_period=pytorch.Batch(16), + max_length=pytorch.Batch(16), + latest_checkpoint=os.path.join( + appdirs.user_data_dir("determined"), trial1.checkpoint_uuid + ), + ) + + # TODO: Remove this particular skip after CI is updated (INFENG-659) + @pytest.mark.skipif(check_shm_size(), reason="insufficient shm size") def test_reproducibility(self) -> None: - def controller_fn(workloads: workload.Stream) -> determined.TrialController: - return utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.LinearPipelineEngineTrial, - hparams=self.hparams, - workloads=workloads, - trial_seed=self.trial_seed, - expose_gpus=True, - ) - - utils.reproducibility_test(controller_fn, steps=1000, validation_freq=100) - - @pytest.mark.skipif(shutil.disk_usage("/dev/shm")[0] < 10**8, reason="insufficient shm size") - def test_callbacks(self, tmp_path: pathlib.Path) -> None: - checkpoint_dir = tmp_path.joinpath("checkpoint") - latest_checkpoint = None - steps_completed = 0 - - controller = None - - def make_workloads1() -> workload.Stream: - nonlocal controller - assert controller.trial.counter.trial_startups == 1 - - yield workload.train_workload(1, 1, 0, 4), workload.ignore_workload_response - assert controller is not None, "controller was never set!" - assert controller.trial.counter.__dict__ == { - "trial_startups": 1, - "validation_steps_started": 0, - "validation_steps_ended": 0, - "checkpoints_written": 0, - "checkpoints_uploaded": 0, - "training_started_times": 1, - "training_epochs_started": 2, - "training_epochs_ended": 2, - "training_workloads_ended": 1, - "trial_shutdowns": 0, - } - - yield workload.validation_workload(), workload.ignore_workload_response - assert controller.trial.counter.__dict__ == { - "trial_startups": 1, - "validation_steps_started": 1, - "validation_steps_ended": 1, - "checkpoints_written": 0, - "checkpoints_uploaded": 0, - "training_started_times": 1, - "training_epochs_started": 2, - "training_epochs_ended": 2, - "training_workloads_ended": 1, - "trial_shutdowns": 0, - } - - interceptor = workload.WorkloadResponseInterceptor() - yield from interceptor.send(workload.checkpoint_workload()) - nonlocal latest_checkpoint, steps_completed - latest_checkpoint = interceptor.metrics_result()["uuid"] - steps_completed = 1 - assert controller.trial.counter.__dict__ == { + with det_ds.init() as train_context: + _trainer._set_random_seeds(self.trial_seed) + train_context._trial_seed = self.trial_seed + trial1 = deepspeed_linear_model.LinearPipelineEngineTrial(train_context, self.hparams) + trainer = det_ds.Trainer(trial1, train_context) + trainer.fit(validation_period=pytorch.Batch(100), max_length=pytorch.Batch(1000)) + + with det_ds.init() as train_context: + _trainer._set_random_seeds(self.trial_seed) + train_context._trial_seed = self.trial_seed + trial2 = deepspeed_linear_model.LinearPipelineEngineTrial(train_context, self.hparams) + trainer = det_ds.Trainer(trial2, train_context) + trainer.fit(validation_period=pytorch.Batch(100), max_length=pytorch.Batch(1000)) + + assert len(trial1.avg_metrics) == len(trial2.avg_metrics) + for A, B in zip(trial1.avg_metrics, trial2.avg_metrics): + assert A.keys() == B.keys() + for key in A.keys(): + assert abs(A[key] - B[key]) < 10e-7 + + assert len(trial1.batch_metrics) == len(trial2.batch_metrics) + for batch_idx in range(len(trial1.batch_metrics)): + for A, B in zip(trial1.batch_metrics[batch_idx], trial2.batch_metrics[batch_idx]): + assert A.keys() == B.keys() + for key in A.keys(): + assert abs(A[key] - B[key]) < 10e-7 + + assert len(trial1.val_metrics) == len(trial2.val_metrics) + for A, B in zip(trial1.val_metrics, trial2.val_metrics): + assert A.keys() == B.keys() + for key in A.keys(): + assert abs(A[key] - B[key]) < 10e-7 + + def test_callbacks(self) -> None: + with det_ds.init() as train_context: + trial = deepspeed_linear_model.LinearCallbackTrial(train_context, self.hparams) + trainer = det_ds.Trainer(trial, train_context) + trainer.fit(max_length=pytorch.Epoch(2)) + assert trial.counter.__dict__ == { "trial_startups": 1, "validation_steps_started": 1, "validation_steps_ended": 1, @@ -600,51 +314,10 @@ def make_workloads1() -> workload.Stream: "training_started_times": 1, "training_epochs_started": 2, "training_epochs_ended": 2, - "training_workloads_ended": 1, - "trial_shutdowns": 0, + "training_workloads_ended": 2, + "trial_shutdowns": 1, } - hparams1 = dict(self.hparams) - controller = utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.LinearCallbackTrial, - hparams=hparams1, - workloads=make_workloads1(), - checkpoint_dir=str(checkpoint_dir), - expose_gpus=True, - ) - controller.run() - assert controller.trial.counter.trial_shutdowns == 1 - - # Verify the checkpoint loading callback works. - def make_workloads2() -> workload.Stream: - yield workload.train_workload(1, 1, 0, 2), workload.ignore_workload_response - - controller = utils.make_trial_controller_from_trial_implementation( - trial_class=deepspeed_linear_model.LinearCallbackTrial, - hparams=self.hparams, - workloads=make_workloads2(), - checkpoint_dir=str(checkpoint_dir), - latest_checkpoint=latest_checkpoint, - steps_completed=steps_completed, - expose_gpus=True, - ) - controller.run() - assert controller.trial.counter.__dict__ == { - # Note: trial_startups will get reset by the loading logic. - "trial_startups": 1, - "validation_steps_started": 1, - "validation_steps_ended": 1, - # Note: checkpoints_written, checkpoints_uploaded, and trial_shutdowns, cannot be - # persisted, as they are all updated after checkpointing. - "checkpoints_written": 0, - "checkpoints_uploaded": 0, - "training_started_times": 2, - "training_epochs_started": 3, - "training_epochs_ended": 3, - "training_workloads_ended": 2, - "trial_shutdowns": 1, - } - @pytest.mark.deepspeed def test_overwrite_deepspeed_config() -> None: @@ -656,16 +329,16 @@ def test_overwrite_deepspeed_config() -> None: expected_config = copy.deepcopy(deepspeed_config) expected_config["train_micro_batch_size_per_gpu"] = 2 expected_config["optimizer"]["params"]["lr"] = 0.001 - result = det_deepspeed.overwrite_deepspeed_config(base_ds_config, source_ds_config) + result = det_ds.overwrite_deepspeed_config(base_ds_config, source_ds_config) assert result == expected_config # Test load base deepspeed config from json file. base_ds_config = str( pathlib.Path(__file__).resolve().parent.parent.joinpath("fixtures/ds_config.json") ) - result = det_deepspeed.overwrite_deepspeed_config(base_ds_config, source_ds_config) + result = det_ds.overwrite_deepspeed_config(base_ds_config, source_ds_config) assert result == expected_config # Test fail invalid base_ds_config argument. with pytest.raises(TypeError, match="Expected string or dict for base_ds_config argument."): - _ = det_deepspeed.overwrite_deepspeed_config([1, 2], source_ds_config) + _ = det_ds.overwrite_deepspeed_config([1, 2], source_ds_config) diff --git a/e2e_tests/tests/fixtures/custom_searcher/__init__.py b/harness/tests/experiment/keras/__init__.py similarity index 100% rename from e2e_tests/tests/fixtures/custom_searcher/__init__.py rename to harness/tests/experiment/keras/__init__.py diff --git a/harness/tests/experiment/keras/test_callback.py b/harness/tests/experiment/keras/test_callback.py new file mode 100644 index 00000000000..ff32e11a729 --- /dev/null +++ b/harness/tests/experiment/keras/test_callback.py @@ -0,0 +1,511 @@ +import json +import os +import pathlib +import re +import subprocess +import sys +from typing import Any, Callable, Dict, Optional, Tuple, Union +from unittest import mock + +import keras +import numpy as np +import pytest +import tensorflow as tf + +import determined as det +import determined.keras +from determined import core +from determined.common import storage +from tests.experiment import utils + + +def mock_core_context( + path: str, events: utils.Events, distributed: Optional[core.DistributedContext] = None +) -> Tuple[core.Context, Callable[[], None]]: + """ + Returns a core_context and a set_preempt() callable. + + The core_context is partially mocked to support triggering preemption from test code and to log + all reports to the provided Events object. + """ + # Set up a functional DistributedContext. + distributed = distributed or core.DummyDistributedContext() + # Set up a functional CheckpointContext. + storage_manager = storage.SharedFSStorageManager(path) + checkpoint = core.DummyCheckpointContext(distributed, storage_manager) + + # Mock everything else, logging report-like calls to events. + + def report_metrics(group: str, steps_completed: int, metrics: Any) -> None: + events.append((f"report_metrics:{group}:{steps_completed}", metrics)) + + def report_progress(progress: float) -> None: + fourdigits = "%.4f" % progress + events.append((f"report_progress:{fourdigits}", progress)) + + def set_status(status: str) -> None: + events.append((f"set_status:{status}", None)) + + preempted = False + + def should_preempt() -> bool: + nonlocal preempted + return preempted + + core_context = mock.Mock() + core_context.distributed = distributed + core_context.preempt.should_preempt.side_effect = should_preempt + core_context.checkpoint = checkpoint + core_context.train.report_metrics.side_effect = report_metrics + core_context.train.report_progress.side_effect = report_progress + core_context.train.set_status.side_effect = set_status + + def set_preempt() -> None: + nonlocal preempted + preempted = True + + return core_context, set_preempt + + +class DeterminedCallbackForTesting(det.keras.DeterminedCallback): + """ + For testing purposes, log events that happen during training for evaluation after training. + """ + + def __init__(self, events: utils.Events, *args: Any, **kwargs: Any) -> None: + self.events = events + super().__init__(*args, **kwargs) + + def on_train_begin(self, logs: Any) -> None: + super().on_train_begin(logs) + weight = self.model.layers[0].get_weights()[0][0] + fourdigits = "%.4f" % weight + self.events.append((f"after_train_begin:{fourdigits}", weight)) + + def on_epoch_end(self, epoch: int, logs: Any) -> None: + self.events.append((f"before_epoch_end:{epoch}", logs)) + super().on_epoch_end(epoch, logs) + self.events.append((f"after_epoch_end:{epoch}", logs)) + + def on_train_end(self, logs: Any) -> None: + self.events.append(("before_train_end", None)) + super().on_train_end(logs) + + def save_model( + self, model: keras.models.Model, path: str, distributed: core.DistributedContext + ) -> None: + super().save_model(model, path, distributed) + ckpt_uuid = os.path.basename(os.path.dirname(path)) + weight = self.model.layers[0].get_weights()[0][0] + self.events.append(("save_model", (ckpt_uuid, weight))) + + def load_model(self, *args: Any, **kwargs: Any) -> None: + super().load_model(*args, **kwargs) + self.events.append(("load_model", None)) + + +def build_model(eager: bool = False) -> keras.models.Model: + layer = keras.layers.Dense( + 1, activation=None, use_bias=False, kernel_initializer="zeros", input_shape=(8,) + ) + model = keras.models.Sequential([layer]) + model.compile( + loss=keras.losses.MeanSquaredError(), + optimizer=keras.optimizers.SGD(), + run_eagerly=eager, + ) + return model + + +def do_fit( + # Basic test configuration. + path: Union[str, pathlib.Path], + model: Optional[keras.models.Model] = None, + distributed: Optional[core.DistributedContext] = None, + # DeterminedCallback settings. + checkpoint: Optional[str] = None, + continue_id: int = 1, + checkpoint_epochs: int = 1, + train_metrics_report_period: Union[str, int] = "epoch", + # Model.compile settings. + eager: bool = False, + # Model.fit settings. + epochs: int = 2, + verbose: int = 0, + set_preempt_on_event: Optional[str] = None, +) -> utils.Events: + x = np.ones((64, 8)) + y = np.ones((64, 8)) + validation_data = (np.ones((64, 8)), np.ones((64, 8))) + + model = model or build_model(eager=eager) + events = utils.Events() + core_context, set_preempt = mock_core_context(str(path), events, distributed) + + if set_preempt_on_event: + # Configure a hook for our Events that calls set_preempt() when a matching event arrives. + p = re.compile(set_preempt_on_event) + + def hook(summary: str, data: Any) -> None: + if p.search(summary): + set_preempt() + + events.hook = hook + + det_cb = DeterminedCallbackForTesting( + events, + core_context, + checkpoint=checkpoint, + continue_id=continue_id, + train_metrics_report_period=train_metrics_report_period, + checkpoint_epochs=checkpoint_epochs, + ) + + model.fit( + x=x, + y=y, + validation_data=validation_data, + batch_size=8, + epochs=epochs, + callbacks=[det_cb], + verbose=verbose, + ) + return events + + +def check_keras_metrics(metrics: Dict[str, Any]) -> None: + # Make sure we are filtering out size and batch, which are pointless to our UI. + assert "size" not in metrics and "batch" not in metrics, metrics + # Make sure we are always injecting epochs and batches. + assert "batches" in metrics and "epochs" in metrics, metrics + # Never allow 'val_' prefix in log names: + # - Validation metrics come in on_test_end, and don't include 'val_' prefix. + # - Training metrics from on_epoch_end have val_* values, which we filter out. + # - Training metrics from on_test_batch_end do not have val_* metrics. + # Training metrics must not contain validation metrics. + assert not any(m.startswith("val_") for m in metrics), metrics + + +@pytest.mark.tensorflow +def test_basic_logic(tmp_path: pathlib.Path) -> None: + # make sure verbose=1 doesn't puke (though we don't really check the output) + events = do_fit(tmp_path, verbose=1) + + # Checks that: + # - set_status() gets called + # - report_metrics() gets called + # - report_progress() gets called + data = utils.assert_events_match( + events, + "!load_model", + "after_train_begin", + "set_status:training", + "set_status:validating", + ("report_metrics:validation", "validation_metrics_sample"), + "before_epoch_end:0", + ("report_metrics:training", "training_metrics_sample"), + "report_progress:0.5000", + "set_status:checkpointing", + "save_model", + "after_epoch_end:0", + "before_epoch_end:1", + "report_progress:1.000", + "save_model", + "after_epoch_end:1", + "before_train_end", + "!save_model", # No final checkpoint. + "set_status:finishing", + ) + # Check examples of training and validation metrics. + check_keras_metrics(data["training_metrics_sample"]) + check_keras_metrics(data["validation_metrics_sample"]) + + +# Pick this test to run eagerly because it both saves and loads checkpoints, which feel like it +# could matter if run_eagerly was set or not. +@pytest.mark.parametrize("eager", [False, True]) +@pytest.mark.tensorflow +def test_save_restore_and_warm_start(tmp_path: pathlib.Path, eager: bool) -> None: + # Train-from-scratch, then check that: + # - initial weight is 0 (no checkpoint was loaded) + # - initial epoch is 0 (no training state was loaded) + # - checkpoint gets saved + events = do_fit(tmp_path, eager=eager, checkpoint=None, continue_id=1) + data = utils.assert_events_match( + events, + "!load_model", + "after_train_begin:0.0000", + "before_epoch_end:0", + ("save_model", "ckpt"), + "after_epoch_end:0", + "before_epoch_end:1", + "save_model", + "after_epoch_end:1", + ) + + # Grab the checkpoint uuid and the weight from the "save_model" match. + ckpt, weight = data["ckpt"] + + # Continue training (continue_id does match), then check that: + # - initial weight is nonzero (checkpoint was loaded) + # - initial epoch is nonzero (training state was loaded) + # - steps_completed was properly restored + events = do_fit(tmp_path, eager=eager, checkpoint=ckpt, continue_id=1) + utils.assert_events_match( + events, + "set_status:restoring", + "load_model", + "after_train_begin:%.4f" % weight, + "!after_epoch_end:0", + "before_epoch_end:1", + "report_metrics:training:16", + "after_epoch_end:1", + "!after_epoch_end", # Don't do two epochs if we started with one already from one. + ) + + # Warm-start training (continue_id does not match), then check that: + # - initial weight is nonzero (no checkpoint was loaded) + # - initial epoch is zero (no training state was loaded) + # - steps_completed was properly reset + events = do_fit(tmp_path, eager=eager, checkpoint=ckpt, continue_id=2) + utils.assert_events_match( + events, + "set_status:restoring", + "load_model", + "after_train_begin:%.4f" % weight, + "report_metrics:training:8", + "after_epoch_end:0", + "after_epoch_end:1", + "!after_epoch_end", + ) + + +@pytest.mark.tensorflow +def test_checkpoint_epochs(tmp_path: pathlib.Path) -> None: + # Never checkpoint, except on preemption or completion + events = do_fit(tmp_path, checkpoint_epochs=0, epochs=4) + utils.assert_events_match( + events, + # The only save is after the final on_epoch_end + "!save_model", + "after_epoch_end:3", + "!after_epoch_end", + "before_train_end", + "save_model", + ) + + # Same thing, but trigger a checkpoint mid-training. + events = do_fit(tmp_path, checkpoint_epochs=0, set_preempt_on_event="report_progress:0.5000") + utils.assert_events_match( + events, + "!save_model", # The preemption-caused checkpoint is in on_train_end, not on_epoch_end. + "after_epoch_end:0", + "!after_epoch_end", + "before_train_end", + "save_model", + ) + + # Checkpoint every other epoch, exiting on a natural checkpoint. + events = do_fit(tmp_path, checkpoint_epochs=2, epochs=4) + utils.assert_events_match( + events, + "!save_model", + "before_epoch_end:1", + "save_model", + "after_epoch_end:1", + "!save_model", + "before_epoch_end:3", + "save_model", + "after_epoch_end:3", + # There is nothing to save in the on_train_end hook. + "!after_epoch_end", + "!save_model", + ) + + # Checkpoint every other epoch, and also at the end, if there is uncheckpointed work. + events = do_fit(tmp_path, checkpoint_epochs=2, epochs=3) + utils.assert_events_match( + events, + "!save_model", + "before_epoch_end:1", + "save_model", + "after_epoch_end:1", + "!save_model", + "after_epoch_end:2", + "!save_model", + "!after_epoch_end", + # Expect an on_train_end checkpoint. + "before_train_end", + "save_model", + ) + + # Checkpoint every other epoch, preempting after a natural checkpoint. + events = do_fit( + tmp_path, checkpoint_epochs=2, epochs=4, set_preempt_on_event="report_progress:0.5000" + ) + utils.assert_events_match( + events, + "!save_model", + "before_epoch_end:1", + "save_model", + "after_epoch_end:1", + # No on_train_end checkpoint. + "!after_epoch_end", + "!save_model", + ) + + # Checkpoint every other epoch, preempting when there wasn't a checkpoint. + events = do_fit( + tmp_path, checkpoint_epochs=2, epochs=4, set_preempt_on_event="report_progress:0.2500" + ) + utils.assert_events_match( + events, + "!save_model", + "after_epoch_end:0", + "!after_epcoh_end", + "!save_model", + # Expect an on_train_end checkpoint. + "before_train_end", + "save_model", + ) + + +@pytest.mark.tensorflow +def test_report_period(tmp_path: pathlib.Path) -> None: + events = do_fit(tmp_path, train_metrics_report_period=3) + # There are 8 batches per epoch. + data = utils.assert_events_match( + events, + "!report_metrics:training:1", + "!report_metrics:training:2", + ("report_metrics:training:3", "training_metrics_sample"), + "!report_metrics:training:4", + "!report_metrics:training:5", + "report_metrics:training:6", + "!report_metrics:training:7", + "!report_metrics:training:8", + "report_metrics:validation:8", + "report_metrics:training:9", + "!report_metrics:training:10", + "!report_metrics:training:11", + "report_metrics:training:12", + "!report_metrics:training:13", + "!report_metrics:training:14", + "report_metrics:training:15", + "!report_metrics:training:16", + "report_metrics:validation:16", + "!report_metrics:training", + ) + # Check training metrics from the non-epoch reporting codepath. + check_keras_metrics(data["training_metrics_sample"]) + + +# Pick this test to run eagerly because multi-gpu training, feel like it might be eager-senstive. +@pytest.mark.parametrize("eager", [False, True]) +@pytest.mark.parametrize("multinode", [False, True]) +@pytest.mark.skipif(len(tf.config.list_physical_devices("GPU")) < 2, reason="not enough gpus") +@pytest.mark.gpu_parallel +def test_multi_gpu(tmp_path: pathlib.Path, eager: bool, multinode: bool) -> None: + """ + Getting an environment where this test can actually pass can be a real pain. + + If you are running on bare metal or a vm with multiple gpus, you can run the test directly, + but you must have your nvidia driver and cuda library installations all squared away. That is + surprisingly difficult to achieve, at least I (rb) couldn't make it work. + + The tedious alternative, but which I find more reliable, is to run it in a docker container + that is compatible with your GPU and driver. I selected an NGC tensorflow image from the NGC + image support matrix: + + https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html + + Then I built a little dockerfile like this: + + FROM ncr.io/nvidia/tensorflow:$YOUR_NGC_IMAGE + RUN pip install determined pytest && pip uninstall --yes determined + ENV PYTHONUNBUFFERED=1 + + Then I configured /etc/docker/daemon.json with some settings commonly used to make dtrain happy: + + { + "runtimes": { + "nvidia": { + "args": [], + "path": "nvidia-container-runtime" + } + }, + "default-shm-size": "4G", + "default-ulimits": { + "memlock": { + "Name": "memlock", + "Hard": -1, + "Soft": -1 + }, + "stack": { + "Name": "stack", + "Hard": 67108864, + "Soft": 67108864 + } + } + } + + Restarted docker: + + sudo systemctl restart docker + + Then I mounted the entire determined project into a container with that new image: + + cd /path/to/determined + docker run -it --rm -v $PWD:$PWD --gpus=all $YOUR_CUSTOM_IMAGE + + And finally, inside the container, I navigate to the harness directory and install determined + with the editable setting: + + cd /path/to/determined/harness + pip install -e . + + And voila, I can finally run the tests: + + pytest -v -s --tb=native tests/experiment/keras/test_callback.py -k test_multi_gpu + + I can also edit the tests from outside the container and rerun them immediately within the + container because I mounted the whole project into the container and used `-e` with the pip + install. + """ + + script = os.path.join(os.path.dirname(__file__), "train.py") + cmd = [sys.executable, script, str(tmp_path)] + if eager: + cmd += ["--eager"] + # NCCL can hit failures in this test, so make it easy to debug. + env = {**os.environ, "NCCL_DEBUG": "info"} + if multinode: + tf_config = { + "cluster": {"worker": ["localhost:12345", "localhost:12346"]}, + "task": {"type": "worker", "index": 0}, + } + # Start worker 0. + env["TF_CONFIG"] = json.dumps(tf_config) + env["CUDA_VISIBLE_DEVICES"] = "0" + p1 = subprocess.Popen(cmd, env=env) + # Start worker 1. + tf_config["task"]["index"] = 1 # type: ignore + env["TF_CONFIG"] = json.dumps(tf_config) + env["CUDA_VISIBLE_DEVICES"] = "1" + p2 = subprocess.Popen(cmd, env=env) + ret1 = p1.wait() + ret2 = p2.wait() + assert ret1 == ret2 == 0, (ret1, ret2) + else: + env.pop("TF_CONFIG", None) + env.pop("CUDA_VISIBLE_DEVICES", None) + subprocess.run(cmd, check=True) + + +@pytest.mark.tensorflow +def test_iris() -> None: + """ + Make sure the DeterminedCallback-based iris example works. + """ + cmd = [sys.executable, utils.cv_examples_path("iris_tf_keras/train.py"), "--epochs", "1"] + subprocess.run(cmd, check=True) diff --git a/harness/tests/experiment/keras/test_tf_keras_trial.py b/harness/tests/experiment/keras/test_tf_keras_trial.py index 3823763d19c..ad4ffe3b09f 100644 --- a/harness/tests/experiment/keras/test_tf_keras_trial.py +++ b/harness/tests/experiment/keras/test_tf_keras_trial.py @@ -368,85 +368,6 @@ def make_workloads() -> workload.Stream: controller.run() -@pytest.mark.tensorflow -def test_iris(tmp_path: pathlib.Path) -> None: - """ - Make sure each example: - - trains - - validates - - checkpoints - - can load from checkpoint - """ - checkpoint_dir = str(tmp_path.joinpath("checkpoint")) - latest_checkpoint = None - steps_completed = 0 - - def make_workloads_1() -> workload.Stream: - """ - Train one batch, validate one batch, checkpoint. - """ - trainer = utils.TrainAndValidate() - yield from trainer.send(steps=2, validation_freq=2, scheduling_unit=1) - - interceptor = workload.WorkloadResponseInterceptor() - yield from interceptor.send(workload.checkpoint_workload()) - nonlocal latest_checkpoint, steps_completed - latest_checkpoint = interceptor.metrics_result()["uuid"] - steps_completed = trainer.get_steps_completed() - - example_path = utils.cv_examples_path("iris_tf_keras/model_def.py") - trial_cls = utils.import_class_from_module("IrisTrial", example_path) - - hparams = { - "learning_rate": 1.0e-4, - "learning_rate_decay": 1.0e-6, - "layer1_dense_size": 16, - "global_batch_size": 30, - } - data = { - "train_url": "http://download.tensorflow.org/data/iris_training.csv", - "test_url": "http://download.tensorflow.org/data/iris_test.csv", - } - - exp_config = utils.make_default_exp_config( - hparams, - scheduling_unit=1, - searcher_metric="random", - checkpoint_dir=checkpoint_dir, - data=data, - ) - - controller = utils.make_trial_controller_from_trial_implementation( - trial_cls, - hparams, - make_workloads_1(), - trial_seed=777, - exp_config=exp_config, - checkpoint_dir=checkpoint_dir, - expose_gpus=True, - ) - # Verify that train/validate/ckpt doesn't puke. - controller.run() - - # Verify that load/train/validate doesn't puke. - def make_workloads_2() -> workload.Stream: - trainer = utils.TrainAndValidate() - yield from trainer.send(steps=1, validation_freq=1, scheduling_unit=1) - - controller = utils.make_trial_controller_from_trial_implementation( - trial_cls, - hparams, - make_workloads_2(), - trial_seed=777, - exp_config=exp_config, - checkpoint_dir=checkpoint_dir, - latest_checkpoint=latest_checkpoint, - steps_completed=steps_completed, - expose_gpus=True, - ) - controller.run() - - @pytest.mark.tensorflow @pytest.mark.gpu def test_tf2_no_op(tmp_path: pathlib.Path) -> None: diff --git a/harness/tests/experiment/keras/train.py b/harness/tests/experiment/keras/train.py new file mode 100644 index 00000000000..84b3d2c7b16 --- /dev/null +++ b/harness/tests/experiment/keras/train.py @@ -0,0 +1,68 @@ +""" +This is a script for testing tf-native dtrain with the DeterminedCallback. + +Tf-native dtrain depends on environment variables and singletons and such, which makes it hard to +test other than in a totally separate script, executed as a sub-process. + +See ./test_callback.py::test_multi_gpu() for additional details. +""" + +import argparse + +import test_callback + +from determined import core + + +def main(path: str, eager: bool) -> None: + distributed, strategy = core.DistributedContext.from_tf_config() + + with strategy.scope(): + model = test_callback.build_model(eager=eager) + + events = test_callback.do_fit(path, model=model, distributed=distributed) + if distributed.rank == 0: + test_callback.assert_events_match( + events, + "!load_model", + "after_train_begin", + "set_status:training", + "set_status:validating", + "report_metrics:validation", + "before_epoch_end:0", + "report_metrics:training", + "report_progress:0.5000", + "set_status:checkpointing", + "save_model", + "after_epoch_end:0", + "before_epoch_end:1", + "report_progress:1.000", + "save_model", + "after_epoch_end:1", + "before_train_end", + "!save_model", # No final checkpoint. + "set_status:finishing", + ) + else: + test_callback.assert_events_match( + events, + "!load_model", + "after_train_begin", + "before_epoch_end:0", + "save_model", + "after_epoch_end:0", + "before_epoch_end:1", + "save_model", + "after_epoch_end:1", + "before_train_end", + "!save_model", # No final checkpoint. + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("path") + parser.add_argument("--eager", action="store_true") + args = parser.parse_args() + + main(args.path, args.eager) diff --git a/harness/tests/experiment/pytorch/test_deepspeed_autotuning.py b/harness/tests/experiment/pytorch/test_deepspeed_autotuning.py deleted file mode 100644 index 539fe525b02..00000000000 --- a/harness/tests/experiment/pytorch/test_deepspeed_autotuning.py +++ /dev/null @@ -1,2078 +0,0 @@ -import argparse -import collections -import copy -import json -import math -import pathlib -import shutil -import tempfile -from typing import Any, Deque, Dict, Generator, List, Mapping, Optional, Sequence, Tuple, cast - -import pytest - -from determined import searcher -from determined.common.api import bindings -from determined.pytorch import dsat -from tests import custom_search_mocks - -ERROR_METRIC_NAME = "error" - -BASE_EXPERIMENT_FIXTURE_PATH = ( - pathlib.Path(__file__).resolve().parent.parent.joinpath("fixtures/deepspeed_autotune") -) -MODEL_DIR = BASE_EXPERIMENT_FIXTURE_PATH.joinpath("example_experiment") -DS_CONFIG_PATH = MODEL_DIR.joinpath("ds_config.json") -CONFIG_PATH = MODEL_DIR.joinpath("deepspeed.yaml") -DEFAULT_ARGS_DICT = { - search_method_name: dsat.get_full_parser().parse_args( - [search_method_name, str(CONFIG_PATH), str(MODEL_DIR)] - ) - for search_method_name in dsat.defaults.ALL_SEARCH_METHOD_NAMES -} -for default_args in DEFAULT_ARGS_DICT.values(): - default_args.experiment_id = 0 - -DEFAULT_SEARCH_RUNNER_CONFIG_DICT = { - search_method_name: dsat.get_search_runner_config_from_args(default_args) - for search_method_name, default_args in DEFAULT_ARGS_DICT.items() -} -DEFAULT_CUSTOM_DSAT_EXP_CONFIG_DICT = { - search_method_name: dsat.get_custom_dsat_exp_conf_from_args(default_args) - for search_method_name, default_args in DEFAULT_ARGS_DICT.items() -} - -MODEL_INFO_PROFILE_METRIC_FIXTURE: Dict[str, Any] = { - "num_params": 60192808, - "trainable_num_params": 60192808, - "activation_mem_per_gpu": 1698283521, - "rank": 0, - "gpu_mem": 15843721216, -} - -DSATTRIAL_ARGS: Mapping[str, Any] = { - "hparams": {"deepspeed_config": "ds_config.json"}, - "model_dir": BASE_EXPERIMENT_FIXTURE_PATH.joinpath("example_experiment"), - "slots_per_trial": 2, - "length": 5, -} - -HPARAMS_FIXTURE: Dict[str, Any] = { - "deepspeed_config": "ds_config.json", - dsat.defaults.OVERWRITE_KEY: { - "train_batch_size": 1, - "gradient_accumulation_steps": 1, - "train_micro_batch_size_per_gpu": 1, - }, -} - -HF_DS_CONFIG_PATH = BASE_EXPERIMENT_FIXTURE_PATH.joinpath("hf_integration_experiment").joinpath( - "ds_config.json" -) -# HF args without any training batch size args and no deepspeed flag. -RAW_DEFAULT_HF_ARGS_WITHOUT_DEEPSPEED = """" ---model_name_or_path gpt2 ---dataset_name wikitext ---dataset_config_name wikitext-2-raw-v1 ---do_train ---do_eval ---max_steps 100 ---logging_strategy steps ---logging_steps 10 ---output_dir /tmp/test-clm ---eval_steps 10 ---evaluation_strategy steps ---save_total_limit 3 ---seed 1337 ---save_strategy steps ---save_steps 20 ---per_device_eval_batch_size 8 -""" -DEFAULT_HF_ARGS_WITHOUT_DEEPSPEED = RAW_DEFAULT_HF_ARGS_WITHOUT_DEEPSPEED.split() - - -def _run_searcher( - search_method_name: str, all_metrics: List[Dict[str, Any]] -) -> custom_search_mocks.MockMasterSearchRunner: - """ - Run a mocked version of the Determined master with a deterministic series of - returned metrics for a given Deepspeed Autotune Custom Search Method - """ - search_method_class = dsat.get_search_method_class(search_method_name) - default_args = DEFAULT_ARGS_DICT[search_method_name] - default_exp_config = DEFAULT_CUSTOM_DSAT_EXP_CONFIG_DICT[search_method_name] - with tempfile.TemporaryDirectory() as searcher_dir: - searcher_path = pathlib.Path(searcher_dir) - search_method = search_method_class(args=default_args, exp_config=default_exp_config) - mock_master_obj = DSATMockMaster(all_metrics=all_metrics) - search_runner = custom_search_mocks.MockMasterSearchRunner( - search_method, mock_master_obj, searcher_path - ) - search_runner.run(exp_config={}, context_dir="", includes=None) - return search_runner - - -@pytest.mark.timeout(10) -def test_deepspeed_autotune_happy_path() -> None: - """ - Simulate the Deepspeed Autotune Search Methods end to end and make sure - nothing falls over - """ - for search_method_name in dsat.defaults.ALL_SEARCH_METHOD_NAMES: - # All of our search methods currently run all of the specified `max-trials` in the - # happy path - exp_num_trials = cast(int, dsat.defaults.AUTOTUNING_ARG_DEFAULTS["max-trials"]) - model_info_profile_trial_metrics: List[Dict[str, Any]] = [MODEL_INFO_PROFILE_METRIC_FIXTURE] - default_metric_name = str(dsat.defaults.AUTOTUNING_ARG_DEFAULTS["metric"]) - successful_trial_metrics: List[Dict[str, Any]] = [ - {default_metric_name: 0.0} for _ in range(exp_num_trials - 1) - ] - all_metrics = model_info_profile_trial_metrics + successful_trial_metrics - search_runner = _run_searcher(search_method_name, all_metrics) - assert len(search_runner.state.trials_created) == exp_num_trials - assert len(search_runner.state.trials_closed) == exp_num_trials - assert len(search_runner.state.trial_progress) == exp_num_trials - for trial_uuid in search_runner.state.trial_progress: - assert search_runner.state.trial_progress[trial_uuid] == 1.0 - assert not search_runner.state.experiment_failed - assert search_runner.state.experiment_completed - - -@pytest.mark.timeout(10) -def test_continuous_failures() -> None: - """ - Make sure that DSAT Search Methods can handle continuous failures. The experiment should be - marked as failed. - """ - for search_method_name in dsat.defaults.ALL_SEARCH_METHOD_NAMES: - exp_num_trials = cast(int, dsat.defaults.AUTOTUNING_ARG_DEFAULTS["max-trials"]) - model_info_profile_trial_metrics = [MODEL_INFO_PROFILE_METRIC_FIXTURE] - failed_trial_metrics = [{ERROR_METRIC_NAME: True} for _ in range(exp_num_trials - 1)] - all_metrics = model_info_profile_trial_metrics + failed_trial_metrics - search_runner = _run_searcher(search_method_name, all_metrics) - - assert len(search_runner.state.trials_created) == exp_num_trials - assert len(search_runner.state.failures) == exp_num_trials - 1 - assert len(search_runner.state.trials_closed) == exp_num_trials - assert len(search_runner.state.trial_progress) == exp_num_trials - assert search_runner.state.experiment_failed - assert not search_runner.state.experiment_completed - - -@pytest.mark.timeout(10) -def test_one_off_failure() -> None: - """Make sure that DSAT Search Methods can properly handle a single failure""" - for search_method_name in dsat.defaults.ALL_SEARCH_METHOD_NAMES: - exp_num_trials = cast(int, dsat.defaults.AUTOTUNING_ARG_DEFAULTS["max-trials"]) - model_info_profile_trial_metrics = [MODEL_INFO_PROFILE_METRIC_FIXTURE] - one_failed_trial_metrics: List[Dict[str, Any]] = [{ERROR_METRIC_NAME: True}] - default_metric_name: str = str(dsat.defaults.AUTOTUNING_ARG_DEFAULTS["metric"]) - successful_trial_metrics = [{default_metric_name: 0.0} for _ in range(exp_num_trials - 2)] - all_metrics = ( - model_info_profile_trial_metrics + one_failed_trial_metrics + successful_trial_metrics - ) - search_runner = _run_searcher(search_method_name, all_metrics) - - assert len(search_runner.state.trials_created) == exp_num_trials - assert len(search_runner.state.failures) == 1 - assert len(search_runner.state.trials_closed) == exp_num_trials - assert len(search_runner.state.trial_progress) == exp_num_trials - assert not search_runner.state.experiment_failed - assert search_runner.state.experiment_completed - - -@pytest.mark.timeout(5) -def test_model_profile_info_run_failure() -> None: - """Test DSAT with a failed model profile info run.""" - for search_method_name in dsat.defaults.ALL_SEARCH_METHOD_NAMES: - failed_model_profile_info_trial_metrics = [ - {ERROR_METRIC_NAME: True}, - ] - search_runner = _run_searcher( - search_method_name, - failed_model_profile_info_trial_metrics, - ) - assert len(search_runner.state.trials_created) == 1 - assert len(search_runner.state.failures) == 1 - assert len(search_runner.state.trials_closed) == 1 - assert len(search_runner.state.trial_progress) == 1 - assert search_runner.state.experiment_failed - assert not search_runner.state.experiment_completed - - -class TestDSATTrial: - @pytest.mark.timeout(5) - def setup_class(self) -> None: - self.first_trial = dsat.DSATTrial(**DSATTRIAL_ARGS) - - @pytest.mark.timeout(5) - def test_lineage_methods(self) -> None: - """ - Testing expected behavior of lineage properties. - """ - trials = [self.first_trial] - for _ in range(10): - trials.append(dsat.DSATTrial(parent=trials[-1], **DSATTRIAL_ARGS)) - - last_trial = None - for idx, trial in enumerate(trials): - if idx == 0: - assert trial.parent is None - else: - assert trial.parent == trials[idx - 1] - if idx != len(trials) - 1: - assert trial.children == {trials[idx + 1]} - else: - assert trial.children == set() - assert trial.lineage_root == self.first_trial - assert trial.lineage_set == set(trials) - assert trial.num_completed_trials_in_lineage == idx - metric_name = ( - "test" if trial.searcher_metric_name is None else trial.searcher_metric_name - ) - trial.metric = {metric_name: 0.0} - last_trial = trial - if last_trial is not None: - assert last_trial.num_completed_trials_in_lineage == len(trials) - - @pytest.mark.timeout(5) - def test_error_history(self) -> None: - """ - Testing error history. - """ - initial_successful_trials = [self.first_trial] - for _ in range(10): - initial_successful_trials.append( - dsat.DSATTrial(parent=initial_successful_trials[-1], **DSATTRIAL_ARGS) - ) - - errored_trial = dsat.DSATTrial(parent=initial_successful_trials[-1], **DSATTRIAL_ARGS) - errored_trial.error = True - alternating_errored_trials = [errored_trial] - for _ in range(10): - last_trial = alternating_errored_trials[-1] - next_trial = dsat.DSATTrial(parent=last_trial, **DSATTRIAL_ARGS) - if not last_trial.error: - next_trial.error - alternating_errored_trials.append(next_trial) - - all_trials = initial_successful_trials + alternating_errored_trials - - seen_errored = False - for trial in all_trials: - if trial.error: - seen_errored = True - if not seen_errored: - assert not trial.error_in_direct_history - else: - assert trial.error_in_direct_history - - -def queue_and_trial_tracker_builder( - args: argparse.Namespace, -) -> Tuple[List[dsat.DSATTrial], dsat.DSATTrialTracker]: - """Completes the model profile into trial and load up a queue of max_trials Trials.""" - exp_config = dsat.get_custom_dsat_exp_conf_from_args(args) - trial_tracker = dsat.DSATTrialTracker(args=args, exp_config=exp_config) - model_profile_info_trial = trial_tracker.create_model_profile_info_trial() - trial_tracker.queue_and_register_trial(model_profile_info_trial) - trial_tracker.update_trial_metric( - trial_tracker.queue.popleft(), MODEL_INFO_PROFILE_METRIC_FIXTURE - ) - - queued_trials = [] - for idx in range(trial_tracker.max_trials - 1): - overwrites = {dsat.defaults.OVERWRITE_KEY: {"zero_optimization": {"stage": 1 + (idx % 3)}}} - hparams = {**HPARAMS_FIXTURE, **overwrites} - # Add an arbitrary hp to avoid the non-duplicate hparams check in `queue_and_register_trial` - hparams["_arbitrary"] = idx - trial = trial_tracker.create_trial(hparams) - queued_trials.append(trial) - trial_tracker.queue_and_register_trial(trial) - return queued_trials, trial_tracker - - -@pytest.fixture -def basic_queue_and_trial_tracker() -> ( - Generator[Tuple[List[dsat.DSATTrial], dsat.DSATTrialTracker], Any, None] -): - yield queue_and_trial_tracker_builder(DEFAULT_ARGS_DICT["_test"]) - - -@pytest.fixture -def max_concurrent_trials_queue_and_tracker() -> ( - Generator[Tuple[List[dsat.DSATTrial], dsat.DSATTrialTracker], Any, None] -): - args = copy.deepcopy(DEFAULT_ARGS_DICT["_test"]) - args.max_concurrent_trials = 2 - yield queue_and_trial_tracker_builder(args) - - -@pytest.fixture -def max_slots_queue_and_trial_tracker() -> ( - Generator[Tuple[List[dsat.DSATTrial], dsat.DSATTrialTracker], Any, None] -): - args = copy.deepcopy(DEFAULT_ARGS_DICT["_test"]) - args.max_slots = 4 - yield queue_and_trial_tracker_builder(args) - - -@pytest.fixture -def failed_model_profile_info_queue_and_trial_tracker() -> ( - Generator[dsat.DSATTrialTracker, Any, None] -): - exp_config = DEFAULT_CUSTOM_DSAT_EXP_CONFIG_DICT["_test"] - trial_tracker = dsat.DSATTrialTracker(args=DEFAULT_ARGS_DICT["_test"], exp_config=exp_config) - model_profile_info_trial = trial_tracker.create_model_profile_info_trial() - trial_tracker.queue_and_register_trial(model_profile_info_trial) - assert trial_tracker.model_profile_info_trial - trial_tracker.report_trial_early_exit(trial_tracker.model_profile_info_trial) - yield trial_tracker - - -@pytest.fixture -def early_stopping_queue_and_trial_tracker() -> dsat.DSATTrialTracker: - """ - Returns a trial tracker whose early_stopping criteria should be triggered. - """ - args = copy.deepcopy(DEFAULT_ARGS_DICT["_test"]) - args.early_stopping = 3 - _, trial_tracker = queue_and_trial_tracker_builder(args) - # One successful initial trial. - trial = trial_tracker.queue.popleft() - assert trial.searcher_metric_name - trial_tracker.update_trial_metric(trial, {trial.searcher_metric_name: 0.0}) - for _ in range(args.early_stopping): - trial = trial_tracker.queue.popleft() - trial_tracker.report_trial_early_exit(trial) - return trial_tracker - - -class TestDSATTrialTracker: - @pytest.mark.timeout(5) - def test_trial_registration( - self, basic_queue_and_trial_tracker: Tuple[List[dsat.DSATTrial], dsat.DSATTrialTracker] - ) -> None: - queued_trials, trial_tracker = basic_queue_and_trial_tracker - for trial in queued_trials: - assert trial.request_id in trial_tracker - - @pytest.mark.timeout(5) - def test_trial_queue_and_state_all_successes( - self, basic_queue_and_trial_tracker: Tuple[List[dsat.DSATTrial], dsat.DSATTrialTracker] - ) -> None: - """ - Verify the expected trial tracker states are accurate when all trials succeed. - """ - queued_trials, trial_tracker = basic_queue_and_trial_tracker - for idx, trial in enumerate(queued_trials): - num_trials_in_queue = len(queued_trials) - idx - assert len(trial_tracker.queue) == num_trials_in_queue - assert trial_tracker.num_completed_trials == 1 + idx - assert not trial.running - assert trial_tracker.can_run_more_trials - - popped_trial = trial_tracker.queue.popleft() - popped_trial.running = True - - assert popped_trial == trial - assert len(trial_tracker.queue) == num_trials_in_queue - 1 - assert trial_tracker.num_completed_trials == 1 + idx - assert trial_tracker.num_running_trials == 1 - assert popped_trial.searcher_metric_name - - trial_tracker.update_trial_metric( - popped_trial, {popped_trial.searcher_metric_name: 0.0} - ) - assert trial_tracker.num_completed_trials == 2 + idx - assert trial_tracker.num_running_trials == 0 - - assert not trial_tracker.can_run_more_trials - assert len(trial_tracker.queue) == 0 - assert trial_tracker.max_trials_are_running_or_closed - assert not trial_tracker.should_be_failure - - @pytest.mark.timeout(5) - def test_trial_queue_and_state_all_errors( - self, basic_queue_and_trial_tracker: Tuple[List[dsat.DSATTrial], dsat.DSATTrialTracker] - ) -> None: - """ - Verify the expected trial tracker states are accurate when all trials fail. - """ - queued_trials, trial_tracker = basic_queue_and_trial_tracker - for idx, trial in enumerate(queued_trials): - num_trials_in_queue = len(queued_trials) - idx - assert len(trial_tracker.queue) == num_trials_in_queue - assert trial_tracker.num_completed_trials == 1 + idx - assert not trial.running - assert trial_tracker.can_run_more_trials - - popped_trial = trial_tracker.queue.popleft() - popped_trial.running = True - - assert popped_trial == trial - assert len(trial_tracker.queue) == num_trials_in_queue - 1 - assert trial_tracker.num_completed_trials == 1 + idx - assert trial_tracker.num_running_trials == 1 - - trial_tracker.report_trial_early_exit(popped_trial) - assert trial_tracker.num_completed_trials == 2 + idx - assert trial_tracker.num_running_trials == 0 - - assert not trial_tracker.can_run_more_trials - assert len(trial_tracker.queue) == 0 - assert trial_tracker.max_trials_are_running_or_closed - assert trial_tracker.should_be_failure - - @pytest.mark.timeout(5) - def test_max_concurrent_trials( - self, - max_concurrent_trials_queue_and_tracker: Tuple[List[dsat.DSATTrial], dsat.DSATTrialTracker], - ) -> None: - """ - Verify that `max_concurrent_trials` is respected. - """ - _, trial_tracker = max_concurrent_trials_queue_and_tracker - while trial_tracker.can_run_more_trials: - popped_trial = trial_tracker.queue.popleft() - assert popped_trial.searcher_metric_name - trial_tracker.update_trial_metric( - popped_trial, {popped_trial.searcher_metric_name: 0.0} - ) - assert trial_tracker.num_running_trials <= trial_tracker.max_concurrent_trials - - @pytest.mark.timeout(5) - def test_max_slots( - self, max_slots_queue_and_trial_tracker: Tuple[List[dsat.DSATTrial], dsat.DSATTrialTracker] - ) -> None: - """ - Verify that `max_slots` is respected. - """ - _, trial_tracker = max_slots_queue_and_trial_tracker - while trial_tracker.can_run_more_trials: - popped_trial = trial_tracker.queue.popleft() - assert popped_trial.searcher_metric_name - trial_tracker.update_trial_metric( - popped_trial, {popped_trial.searcher_metric_name: 0.0} - ) - assert ( - trial_tracker.num_running_trials * popped_trial.slots_per_trial - <= trial_tracker.max_slots - ) - - @pytest.mark.timeout(5) - def test_best_metric_tracking( - self, basic_queue_and_trial_tracker: Tuple[List[dsat.DSATTrial], dsat.DSATTrialTracker] - ) -> None: - """ - Uses a series of successful trials where each trial is better than the previous one. - """ - _, trial_tracker = basic_queue_and_trial_tracker - metrics = list(range(len(trial_tracker) - 1)) - if not trial_tracker.smaller_is_better: - metrics = list(reversed(metrics)) - while trial_tracker.can_run_more_trials: - popped_trial = trial_tracker.queue.popleft() - assert popped_trial.searcher_metric_name - trial_tracker.update_trial_metric( - popped_trial, {popped_trial.searcher_metric_name: metrics.pop()} - ) - assert trial_tracker.best_trial == popped_trial - assert trial_tracker.best_trials_by_stage[popped_trial.stage] == popped_trial - - -def search_state_and_method_builder( - args: argparse.Namespace, -) -> Tuple[searcher.SearcherState, dsat.BaseDSATSearchMethod]: - """ - Creates the appropriate `BaseDSATSearchMethod` superclass instance with a completed model - profile info run and a populated queue. - """ - exp_config = dsat.get_custom_dsat_exp_conf_from_args(args) - search_method = dsat.get_search_method_class(args.search_method)( - args=args, - exp_config=exp_config, - ) - searcher_state = searcher.SearcherState() - search_method.initial_operations(searcher_state) - assert search_method.trial_tracker.model_profile_info_trial - search_method.on_validation_completed( - searcher_state, - search_method.trial_tracker.model_profile_info_trial.request_id, - MODEL_INFO_PROFILE_METRIC_FIXTURE, - search_method.trial_tracker.model_profile_info_trial.length, - ) - return searcher_state, search_method - - -@pytest.fixture -def default_random_state_and_search_method() -> ( - Generator[Tuple[searcher.SearcherState, dsat.BaseDSATSearchMethod], Any, None] -): - searcher_state, search_method = search_state_and_method_builder(DEFAULT_ARGS_DICT["random"]) - yield searcher_state, search_method - - -@pytest.fixture -def long_random_state_and_search_method() -> ( - Generator[Tuple[searcher.SearcherState, dsat.BaseDSATSearchMethod], Any, None] -): - """For long-running tests which need a longer max_trials.""" - args = copy.deepcopy(DEFAULT_ARGS_DICT["random"]) - args.max_trials = 10**3 - args.trials_per_random_config = args.max_trials - searcher_state, search_method = search_state_and_method_builder(args) - yield searcher_state, search_method - - -class TestRandomDSATSearchMethodTrialCreation: - """ - Testing the various `RandomDSATSearchMethod` methods related to trial creation. - """ - - @pytest.mark.timeout(5) - def test_random_hparams_and_search_data( - self, - default_random_state_and_search_method: Tuple[ - searcher.SearcherState, dsat.RandomDSATSearchMethod - ], - ) -> None: - _, search_method = default_random_state_and_search_method - for _ in range(100): - for stage in range(4): - hparams, search_data = search_method.get_random_hparams_and_search_data(stage) - mbs = hparams[dsat.defaults.OVERWRITE_KEY]["train_micro_batch_size_per_gpu"] - assert hparams[dsat.defaults.OVERWRITE_KEY]["zero_optimization"]["stage"] == stage - assert search_data.lo <= mbs <= search_data.hi - - @pytest.mark.timeout(5) - def test_random_hparams_and_search_data_after_best( - self, - default_random_state_and_search_method: Tuple[ - searcher.SearcherState, dsat.RandomDSATSearchMethod - ], - ) -> None: - for _ in range(100): - _, search_method = default_random_state_and_search_method - for stage in range(4): - hparams, search_data = search_method.get_random_hparams_and_search_data(stage) - trial = search_method.trial_tracker.create_trial(hparams, search_data) - search_method.trial_tracker.queue_and_register_trial(trial) - search_method.trial_tracker.queue.popleft() - assert trial.searcher_metric_name - search_method.trial_tracker.update_trial_metric( - trial, {trial.searcher_metric_name: 0.0} - ) - _, new_search_data = search_method.get_random_hparams_and_search_data(stage) - assert new_search_data.lo <= new_search_data.hi - - @pytest.mark.timeout(5) - def test_lineage_continuation_after_failures( - self, - default_random_state_and_search_method: Tuple[ - searcher.SearcherState, dsat.RandomDSATSearchMethod - ], - ) -> None: - """ - Verifying that a lineage will be attempted for `trials_per_random_config` total attempts - even when each trial fails. - """ - searcher_state, search_method = default_random_state_and_search_method - # Take and fail the next trial - first_trial = next_trial = search_method.choose_next_trial_from_queue() - # Remove everything else, so that we only have this lineage to handle. - search_method.trial_tracker.queue.clear() - # The next search_method.trials_per_random_config - 1 trials should have the - # first trial as their parent. - for _ in range(search_method.trials_per_random_config - 1): - search_method.on_trial_exited_early( - searcher_state, next_trial.request_id, searcher.ExitedReason.ERRORED - ) - next_trial = search_method.choose_next_trial_from_queue() - # Force the search data to be non-trivial, so that we avoid exiting due to a trivial - # search range. - assert next_trial.search_data - next_trial.search_data.lo = 1 - next_trial.search_data.hi = 10 - next_trial.ds_config["train_micro_batch_size_per_gpu"] = 5 - assert next_trial.lineage_root == first_trial - # And the next trial should be from a new lineage. - search_method.on_trial_exited_early( - searcher_state, next_trial.request_id, searcher.ExitedReason.ERRORED - ) - next_trial = search_method.choose_next_trial_from_queue() - assert next_trial.lineage_root != first_trial - - @pytest.mark.timeout(5) - def test_lineage_continuation_after_successes( - self, - default_random_state_and_search_method: Tuple[ - searcher.SearcherState, dsat.RandomDSATSearchMethod - ], - ) -> None: - """ - Verifying that a lineage will be attempted for `trials_per_random_config` total attempts - even when each trial succeeds, each improving on the last. - """ - searcher_state, search_method = default_random_state_and_search_method - # Take and fail the next trial - first_trial = next_trial = search_method.choose_next_trial_from_queue() - metrics = list(range(search_method.trials_per_random_config)) - if search_method.trial_tracker.smaller_is_better: - metrics = metrics[::-1] - # Remove everything else, so that we only have this lineage to handle. - search_method.trial_tracker.queue.clear() - # The next search_method.trials_per_random_config - 1 trials should have the - # first trial as their parent. - for idx in range(search_method.trials_per_random_config - 1): - assert next_trial.searcher_metric_name - search_method.on_validation_completed( - searcher_state, - next_trial.request_id, - {next_trial.searcher_metric_name: metrics[idx]}, - train_length=1, - ) - next_trial = search_method.choose_next_trial_from_queue() - assert next_trial.lineage_root == first_trial - # And the next trial should be from a new lineage. - assert next_trial.searcher_metric_name - idx = search_method.trials_per_random_config - 1 - search_method.on_validation_completed( - searcher_state, - next_trial.request_id, - {next_trial.searcher_metric_name: metrics[idx]}, - train_length=1, - ) - next_trial = search_method.choose_next_trial_from_queue() - assert next_trial.lineage_root != first_trial - - -class TestRandomDSATSearchMethodSearch: - @pytest.mark.timeout(5) - def test_search_happy_path( - self, - long_random_state_and_search_method: Tuple[ - searcher.SearcherState, dsat.RandomDSATSearchMethod - ], - ) -> None: - """ - Ensure that when the actual `train_micro_batch_size_per_gpu` lies between the - search bounds, this optimal value will be found. - """ - searcher_state, search_method = long_random_state_and_search_method - search_method.trial_tracker.queue.clear() - # Test for that all stages successfully find all possible values in their search range. - # Reverse the stage range so that early stopping of stage-3 trials is not triggered. - for stage in reversed(range(4)): - _, search_data = search_method.get_random_hparams_and_search_data(stage) - num_possible_mbs = search_data.hi - search_data.lo + 1 - for target_mbs in range(search_data.lo, search_data.hi + 1): - search_method.trial_tracker.queue.clear() - hparams, search_data = search_method.get_random_hparams_and_search_data(stage) - first_trial = search_method.trial_tracker.create_trial(hparams, search_data) - search_method.trial_tracker.queue_and_register_trial(first_trial) - curr_trial = search_method.trial_tracker.queue.popleft() - for _ in range(num_possible_mbs): - assert curr_trial.search_data and curr_trial.search_data - assert curr_trial.search_data.lo <= curr_trial.mbs <= curr_trial.search_data.hi - if curr_trial.mbs > target_mbs: - search_method.on_trial_exited_early( - searcher_state, curr_trial.request_id, searcher.ExitedReason.ERRORED - ) - assert search_method.trial_tracker.queue - else: - assert curr_trial.searcher_metric_name - search_method.on_validation_completed( - searcher_state, - curr_trial.request_id, - {curr_trial.searcher_metric_name: 0.0}, - curr_trial.length, - ) - assert search_method.trial_tracker.queue - if curr_trial.mbs == target_mbs: - break - curr_trial = search_method.trial_tracker.queue.popleft() - # queue should now be empty - assert not search_method.trial_tracker.queue - # Every trial should belong to the same lineage. - assert curr_trial.lineage_root == first_trial - assert curr_trial.mbs == target_mbs - - @pytest.mark.timeout(5) - def test_full_experiment_happy_path( - self, - default_random_state_and_search_method: Tuple[ - searcher.SearcherState, dsat.BinarySearchDSATSearchMethod - ], - ) -> None: - """ - Simulate running a full experiment with all successful trials, each improving on the last, - and verify the expected end state. - """ - searcher_state, search_method = default_random_state_and_search_method - num_trials = 0 - while search_method.trial_tracker.can_run_more_trials: - trial = search_method.choose_next_trial_from_queue() - assert trial.searcher_metric_name is not None - num_trials += 1 - metric_val = ( - -1 * num_trials if search_method.trial_tracker.smaller_is_better else num_trials - ) - search_method.on_validation_completed( - searcher_state=searcher_state, - request_id=trial.request_id, - metric={trial.searcher_metric_name: metric_val}, - train_length=trial.length, - ) - # Verify that all max_trials were run. - assert ( - search_method.trial_tracker.num_completed_trials - == search_method.trial_tracker.max_trials - ) - # Verify that the best-found trial has the expected metric value - assert search_method.trial_tracker.best_trial is not None - assert search_method.trial_tracker.best_trial.metric == { - trial.searcher_metric_name: metric_val - } - - -class TestRandomDSATSearchMethodShouldStopLineage: - """ - Testing the various conditions which should trigger RandomDSATSearchMethod.should_stop_lineage - """ - - @pytest.mark.timeout(5) - def test_trials_per_random_config_stopping( - self, - default_random_state_and_search_method: Tuple[ - searcher.SearcherState, dsat.RandomDSATSearchMethod - ], - ) -> None: - """ - Test that we respect the trials_per_random_config bound. - """ - assert True - _, search_method = default_random_state_and_search_method - trial = None - for stage in range(4): - for _ in range(search_method.trials_per_random_config): - hparams, search_data = search_method.get_random_hparams_and_search_data(stage) - trial = search_method.trial_tracker.create_trial( - HPARAMS_FIXTURE, search_data, parent_trial=trial - ) - search_method.trial_tracker.queue_and_register_trial(trial) - search_method.trial_tracker.report_trial_early_exit(trial) - assert trial - assert search_method.should_stop_lineage(trial) - - @pytest.mark.timeout(5) - def test_stop_stage_3( - self, - default_random_state_and_search_method: Tuple[ - searcher.SearcherState, dsat.RandomDSATSearchMethod - ], - ) -> None: - """ - Verify that we stop a stage 3 lineage when a successful stage-1 or 2 trial has been found. - """ - _, search_method = default_random_state_and_search_method - trial_dict_by_stage: Dict[int, dsat.DSATTrial] = {} - for stage in (1, 2, 3): - overwrites = {dsat.defaults.OVERWRITE_KEY: {"zero_optimization": {"stage": stage}}} - hparams = {**HPARAMS_FIXTURE, **overwrites} - trial_dict_by_stage[stage] = search_method.trial_tracker.create_trial( - hparams, search_data=dsat.DSATSearchData(lo=1, hi=1) - ) - assert trial_dict_by_stage[3].searcher_metric_name - search_method.trial_tracker.update_trial_metric( - trial_dict_by_stage[3], {trial_dict_by_stage[3].searcher_metric_name: 0} - ) - assert not search_method.should_stop_lineage(trial_dict_by_stage[3]) - - search_method.trial_tracker.report_trial_early_exit(trial_dict_by_stage[1]) - assert not search_method.should_stop_lineage(trial_dict_by_stage[3]) - - search_method.trial_tracker.update_trial_metric( - trial_dict_by_stage[2], {trial_dict_by_stage[3].searcher_metric_name: 0} - ) - assert search_method.should_stop_lineage(trial_dict_by_stage[3]) - - @pytest.mark.timeout(5) - def test_stop_after_fail_on_min_mbs( - self, - default_random_state_and_search_method: Tuple[ - searcher.SearcherState, dsat.RandomDSATSearchMethod - ], - ) -> None: - """ - Verify that we stop a lineage after a trial erors out when attempting its minimum batch - size. - """ - searcher_state, search_method = default_random_state_and_search_method - for stage in range(4): - hparams, search_data = search_method.get_random_hparams_and_search_data(stage) - hparams[dsat.defaults.OVERWRITE_KEY]["train_micro_batch_size_per_gpu"] = search_data.lo - trial = search_method.trial_tracker.create_trial(hparams, search_data) - search_method.trial_tracker.queue_and_register_trial(trial) - search_method.trial_tracker.queue.popleft() - search_method.trial_tracker.report_trial_early_exit(trial) - assert search_method.should_stop_lineage(trial) - - @pytest.mark.timeout(5) - def test_stop_after_max_possible_mbs_run( - self, - default_random_state_and_search_method: Tuple[ - searcher.SearcherState, dsat.RandomDSATSearchMethod - ], - ) -> None: - """ - Verify that we stop a lineage after a trial has attempted its largest possible batch size - once a hard ceiling has been established. - """ - searcher_state, search_method = default_random_state_and_search_method - # Go through stages in reversed order, in order to avoid early stage-3 exiting triggers. - for stage in reversed(range(4)): - # Lineage should be abandoned regardless of whether the follow-on Trial errors. - for should_error_next_trial in (True, False): - # First fail on batch size of two, establishing a hard ceiling. - hparams, search_data = search_method.get_random_hparams_and_search_data(stage) - hparams[dsat.defaults.OVERWRITE_KEY]["train_micro_batch_size_per_gpu"] = 2 - errored_trial = search_method.trial_tracker.create_trial(hparams, search_data) - search_method.trial_tracker.queue_and_register_trial(errored_trial) - search_method.trial_tracker.queue.popleft() - search_method.trial_tracker.report_trial_early_exit(errored_trial) - - # Then update the ceiling and run a follow-on trial which attempts to run at the - # established hard ceilng (which should be `train_micro_batch_size_per_gpu = 1`) - next_trial = search_method.get_trials_after_early_exit( - searcher_state, errored_trial, searcher.ExitedReason.ERRORED - )[0] - assert next_trial.mbs == 1 - search_method.trial_tracker.queue_and_register_trial(next_trial) - search_method.trial_tracker.queue.popleft() - if should_error_next_trial: - search_method.trial_tracker.report_trial_early_exit(next_trial) - else: - assert next_trial.searcher_metric_name - search_method.trial_tracker.update_trial_metric( - next_trial, {next_trial.searcher_metric_name: 0.0} - ) - - assert search_method.should_stop_lineage(next_trial) - - @pytest.mark.timeout(5) - def test_stop_when_other_configs_run_larger_batches( - self, - default_random_state_and_search_method: Tuple[ - searcher.SearcherState, dsat.RandomDSATSearchMethod - ], - ) -> None: - """ - Verify that we stop a lineage which cannot possibly run batches as large as other same-stage - configs can run. - """ - searcher_state, search_method = default_random_state_and_search_method - for stage in range(4): - hparams, search_data = search_method.get_random_hparams_and_search_data(stage) - good_hparams = copy.deepcopy(hparams) - good_hparams[dsat.defaults.OVERWRITE_KEY]["train_micro_batch_size_per_gpu"] = 2 - good_trial = search_method.trial_tracker.create_trial(good_hparams, search_data) - search_method.trial_tracker.queue_and_register_trial(good_trial) - search_method.trial_tracker.queue.popleft() - assert good_trial.searcher_metric_name - search_method.trial_tracker.update_trial_metric( - good_trial, {good_trial.searcher_metric_name: 0.0} - ) - - bad_hparams = copy.deepcopy(hparams) - bad_hparams[dsat.defaults.OVERWRITE_KEY]["train_micro_batch_size_per_gpu"] = 1 - bad_trial = search_method.trial_tracker.create_trial(bad_hparams, search_data) - search_method.trial_tracker.queue_and_register_trial(bad_trial) - search_method.trial_tracker.queue.popleft() - search_method.trial_tracker.report_trial_early_exit(bad_trial) - assert search_method.should_stop_lineage(bad_trial) - - -class TestRandomDSATSearchMethodChooseNextTrial: - """ - Testing the various conditions which should non-trivially trigger - RandomDSATSearchMethod.choose_next_trial_from_queue - """ - - @pytest.mark.timeout(5) - def test_pruning_stage_3_trials( - self, - default_random_state_and_search_method: Tuple[ - searcher.SearcherState, dsat.RandomDSATSearchMethod - ], - ) -> None: - """ - Test the pruning of stage 3 trials. - """ - _, search_method = default_random_state_and_search_method - # Run a successful stage-1 trial. - hparams, search_data = search_method.get_random_hparams_and_search_data(1) - successful_trial = search_method.trial_tracker.create_trial(hparams, search_data) - search_method.trial_tracker.queue_and_register_trial(successful_trial) - search_method.trial_tracker.queue.popleft() - assert successful_trial.searcher_metric_name - search_method.trial_tracker.update_trial_metric( - successful_trial, {successful_trial.searcher_metric_name: 0.0} - ) - - # Queue up a number of stage-3 trials and verify that choose_next_trial_from_queue - # returns a non-stage-3 trial and that no other stage-3 trials remain in the queue. - stage_three_trials = [] - for _ in range(10): - hparams, search_data = search_method.get_random_hparams_and_search_data(3) - trial = search_method.trial_tracker.create_trial(hparams, search_data) - stage_three_trials.append(trial) - search_method.trial_tracker.queue_and_register_trial(trial) - - # Then empty the queue and verify that all the trials which actually run are not - # stage 3, but rather their replacements. - while search_method.trial_tracker.queue: - next_trial = search_method.choose_next_trial_from_queue() - assert next_trial.stage != 3 - - -@pytest.fixture -def long_binary_state_and_search_method() -> ( - Generator[Tuple[searcher.SearcherState, dsat.BaseDSATSearchMethod], Any, None] -): - """For long-running tests which need a longer max_trials.""" - args = copy.deepcopy(DEFAULT_ARGS_DICT["binary"]) - args.max_trials = 10**3 - searcher_state, search_method = search_state_and_method_builder(args) - yield searcher_state, search_method - - -class TestBinaryDSATSearchMethod: - @pytest.mark.timeout(5) - def test_binary_happy_path( - self, - long_binary_state_and_search_method: Tuple[ - searcher.SearcherState, dsat.BinarySearchDSATSearchMethod - ], - ) -> None: - """ - Ensure that when the actual `train_micro_batch_size_per_gpu` lies between the - search bounds, this optimal value will be found. - """ - searcher_state, search_method = long_binary_state_and_search_method - search_method.trial_tracker.queue.clear() - # Test for that all stages successfully find all possible values in their search range: - for stage in range(4): - _, search_data = search_method.get_random_hparams_and_search_data(stage) - num_possible_mbs = search_data.hi - search_data.lo + 1 - for target_mbs in range(search_data.lo, search_data.hi + 1): - search_method.trial_tracker.queue.clear() - hparams, search_data = search_method.get_random_hparams_and_search_data(stage) - first_trial = search_method.trial_tracker.create_trial(hparams, search_data) - search_method.trial_tracker.queue_and_register_trial(first_trial) - curr_trial = search_method.trial_tracker.queue.popleft() - for num_halvings in range(1, num_possible_mbs + 1): - assert curr_trial.search_data - assert curr_trial.search_data.lo <= curr_trial.mbs <= curr_trial.search_data.hi - if curr_trial.mbs > target_mbs: - search_method.on_trial_exited_early( - searcher_state, curr_trial.request_id, searcher.ExitedReason.ERRORED - ) - assert search_method.trial_tracker.queue - else: - assert curr_trial.searcher_metric_name - search_method.on_validation_completed( - searcher_state, - curr_trial.request_id, - {curr_trial.searcher_metric_name: 0.0}, - curr_trial.length, - ) - assert search_method.trial_tracker.queue - if curr_trial.mbs == target_mbs: - # Affirm that the solution was found as quickly as expected. - assert num_halvings <= int(math.log(num_possible_mbs, 2)) + 1 - break - curr_trial = search_method.trial_tracker.queue.popleft() - # queue should now be empty - assert not search_method.trial_tracker.queue - # Every trial should belong to the same lineage. - assert curr_trial.lineage_root == first_trial - assert curr_trial.mbs == target_mbs - - @pytest.mark.timeout(5) - def test_full_experiment_happy_path( - self, - long_binary_state_and_search_method: Tuple[ - searcher.SearcherState, dsat.BinarySearchDSATSearchMethod - ], - ) -> None: - """ - Simulate running a full experiment with all successful trials, each improving on the last, - and verify the expected end state. - """ - searcher_state, search_method = long_binary_state_and_search_method - num_trials = 0 - while search_method.trial_tracker.can_run_more_trials: - trial = search_method.choose_next_trial_from_queue() - assert trial.searcher_metric_name is not None - num_trials += 1 - metric_val = ( - -1 * num_trials if search_method.trial_tracker.smaller_is_better else num_trials - ) - search_method.on_validation_completed( - searcher_state=searcher_state, - request_id=trial.request_id, - metric={trial.searcher_metric_name: metric_val}, - train_length=trial.length, - ) - # Verify that all max_trials were run. - assert ( - search_method.trial_tracker.num_completed_trials - == search_method.trial_tracker.max_trials - ) - # Verify that the best-found trial has the expected metric value - assert search_method.trial_tracker.best_trial is not None - assert search_method.trial_tracker.best_trial.metric == { - trial.searcher_metric_name: metric_val - } - - @pytest.mark.timeout(5) - def test_binary_no_trials_can_run( - self, - long_binary_state_and_search_method: Tuple[ - searcher.SearcherState, dsat.BinarySearchDSATSearchMethod - ], - ) -> None: - """ - Verify expected behavior if every trial fails to even run batch size one. - """ - searcher_state, search_method = long_binary_state_and_search_method - search_method.trial_tracker.queue.clear() - # Test for that all stages successfully find all possible values in their search range: - for stage in range(4): - _, search_data = search_method.get_random_hparams_and_search_data(stage) - num_possible_mbs = search_data.hi - search_data.lo + 1 - target_mbs = 0 - search_method.trial_tracker.queue.clear() - hparams, search_data = search_method.get_random_hparams_and_search_data(stage) - first_trial = search_method.trial_tracker.create_trial(hparams, search_data) - search_method.trial_tracker.queue_and_register_trial(first_trial) - curr_trial = search_method.trial_tracker.queue.popleft() - for num_halvings in range(1, num_possible_mbs + 1): - assert curr_trial.search_data - assert curr_trial.search_data.lo <= curr_trial.mbs <= curr_trial.search_data.hi - assert curr_trial.mbs > target_mbs - search_method.on_trial_exited_early( - searcher_state, - curr_trial.request_id, - searcher.ExitedReason.ERRORED, - ) - assert search_method.trial_tracker.queue - if curr_trial.mbs == curr_trial.search_data.lo: - # Next trial should start a new lineage in this case. - next_lineage_trial = search_method.trial_tracker.queue.popleft() - assert not search_method.trial_tracker.queue - assert next_lineage_trial.lineage_root != first_trial - assert num_halvings <= int(math.log(num_possible_mbs, 2)) + 1 - break - else: - curr_trial = search_method.trial_tracker.queue.popleft() - assert not search_method.trial_tracker.queue - assert curr_trial.lineage_root == first_trial - - @pytest.mark.timeout(5) - def test_binary_range_too_small( - self, - long_binary_state_and_search_method: Tuple[ - searcher.SearcherState, dsat.BinarySearchDSATSearchMethod - ], - ) -> None: - """ - Ensure that if the actual optimal batch size is larger than the initial range (which - hopefully never happens, but is possible), then the largest batch size in the range is - returned. - """ - searcher_state, search_method = long_binary_state_and_search_method - search_method.trial_tracker.queue.clear() - # test for that all stages successfully find all possible values in their search range: - for stage in range(4): - _, search_data = search_method.get_random_hparams_and_search_data(stage) - num_possible_mbs = search_data.hi - search_data.lo + 1 - target_mbs = search_data.hi + 1 - search_method.trial_tracker.queue.clear() - hparams, search_data = search_method.get_random_hparams_and_search_data(stage) - first_trial = search_method.trial_tracker.create_trial(hparams, search_data) - search_method.trial_tracker.queue_and_register_trial(first_trial) - curr_trial = search_method.trial_tracker.queue.popleft() - for num_halvings in range(1, num_possible_mbs + 1): - assert curr_trial.search_data - assert curr_trial.search_data.lo <= curr_trial.mbs <= curr_trial.search_data.hi - assert curr_trial.mbs < target_mbs - assert curr_trial.searcher_metric_name - search_method.on_validation_completed( - searcher_state, - curr_trial.request_id, - {curr_trial.searcher_metric_name: 0.0}, - curr_trial.length, - ) - assert search_method.trial_tracker.queue - if curr_trial.mbs == search_data.hi: - # Next trial should start a new lineage in this case. - next_lineage_trial = search_method.trial_tracker.queue.popleft() - assert not search_method.trial_tracker.queue - assert next_lineage_trial.lineage_root != first_trial - assert num_halvings <= int(math.log(num_possible_mbs, 2)) + 1 - break - curr_trial = search_method.trial_tracker.queue.popleft() - assert not search_method.trial_tracker.queue - assert curr_trial.lineage_root == first_trial - - -@pytest.fixture -def default_asha_state_and_search_method() -> ( - Generator[Tuple[searcher.SearcherState, dsat.BaseDSATSearchMethod], Any, None] -): - searcher_state, search_method = search_state_and_method_builder(DEFAULT_ARGS_DICT["asha"]) - yield searcher_state, search_method - - -@pytest.fixture -def long_asha_state_and_search_method() -> ( - Generator[Tuple[searcher.SearcherState, dsat.BaseDSATSearchMethod], Any, None] -): - args = copy.deepcopy(DEFAULT_ARGS_DICT["asha"]) - args.max_trials = 500 - args.max_rungs = 8 - searcher_state, search_method = search_state_and_method_builder(args) - yield searcher_state, search_method - - -@pytest.fixture -def long_large_min_resource_asha_state_and_search_method() -> ( - Generator[Tuple[searcher.SearcherState, dsat.BaseDSATSearchMethod], Any, None] -): - """ - For long-running tests which need a longer max_trials and resources. - """ - args = copy.deepcopy(DEFAULT_ARGS_DICT["asha"]) - args.max_trials = 10**3 - args.min_binary_search_trials = 10**3 - searcher_state, search_method = search_state_and_method_builder(args) - yield searcher_state, search_method - - -class TestASHADSATSearchMethod: - @pytest.mark.timeout(5) - def test_binary_happy_path( - self, - long_large_min_resource_asha_state_and_search_method: Tuple[ - searcher.SearcherState, dsat.BinarySearchDSATSearchMethod - ], - ) -> None: - """ - Ensure that when the actual `train_micro_batch_size_per_gpu` lies between the - search bounds, this optimal value will be found. - """ - searcher_state, search_method = long_large_min_resource_asha_state_and_search_method - search_method.trial_tracker.queue.clear() - # Test for that all stages successfully find all possible values in their search range: - stage = 1 - _, search_data = search_method.get_random_hparams_and_search_data(stage) - num_possible_mbs = search_data.hi - search_data.lo + 1 - for target_mbs in range(search_data.lo, search_data.hi + 1): - search_method.trial_tracker.queue.clear() - hparams, search_data = search_method.get_random_hparams_and_search_data(stage) - first_trial = search_method.trial_tracker.create_trial(hparams, search_data) - search_method.trial_tracker.queue_and_register_trial(first_trial) - curr_trial = search_method.trial_tracker.queue.popleft() - for num_halvings in range(1, num_possible_mbs + 1): - assert curr_trial.search_data is not None - assert curr_trial.search_data.lo <= curr_trial.mbs <= curr_trial.search_data.hi - if curr_trial.mbs > target_mbs: - search_method.on_trial_exited_early( - searcher_state, curr_trial.request_id, searcher.ExitedReason.ERRORED - ) - assert search_method.trial_tracker.queue - else: - assert curr_trial.searcher_metric_name is not None - search_method.on_validation_completed( - searcher_state, - curr_trial.request_id, - {curr_trial.searcher_metric_name: 0.0}, - curr_trial.length, - ) - assert search_method.trial_tracker.queue - if curr_trial.mbs == target_mbs: - # Affirm that the solution was found as quickly as expected. - assert num_halvings <= int(math.log(num_possible_mbs, 2)) + 1 - break - curr_trial = search_method.trial_tracker.queue.popleft() - # queue should now be empty - assert not search_method.trial_tracker.queue - # Every trial should belong to the same lineage. - assert curr_trial.lineage_root == first_trial - assert curr_trial.mbs == target_mbs - - @pytest.mark.timeout(5) - def test_full_experiment_happy_path( - self, - default_asha_state_and_search_method: Tuple[ - searcher.SearcherState, dsat.BinarySearchDSATSearchMethod - ], - ) -> None: - """ - Simulate running a full experiment with all successful trials, each improving on the last - and verify the expected end state. - """ - searcher_state, search_method = default_asha_state_and_search_method - num_trials = 0 - while search_method.trial_tracker.can_run_more_trials: - trial = search_method.choose_next_trial_from_queue() - assert trial.searcher_metric_name is not None - num_trials += 1 - metric_val = ( - -1 * num_trials if search_method.trial_tracker.smaller_is_better else num_trials - ) - search_method.on_validation_completed( - searcher_state=searcher_state, - request_id=trial.request_id, - metric={trial.searcher_metric_name: metric_val}, - train_length=trial.length, - ) - # Verify that all max_trials were run. - assert ( - search_method.trial_tracker.num_completed_trials - == search_method.trial_tracker.max_trials - ) - # Verify that the best-found trial has the expected metric value - assert search_method.trial_tracker.best_trial is not None - assert search_method.trial_tracker.best_trial.metric == { - trial.searcher_metric_name: metric_val - } - - @pytest.mark.timeout(10) - def test_full_experiment_reverse_ordered_results( - self, - long_asha_state_and_search_method: Tuple[ - searcher.SearcherState, dsat.BinarySearchDSATSearchMethod - ], - ) -> None: - """ - Simulate running a full experiment with all successful trials, each worse than the last, - and verify the expected end state, which is that trials in the higher rungs should have - better metrics than those which were never promoted out of the rungs. - """ - searcher_state, search_method = long_asha_state_and_search_method - assert isinstance(search_method, dsat.ASHADSATSearchMethod) - metrics = list(range(search_method.trial_tracker.max_trials - 1)) - if not search_method.trial_tracker.smaller_is_better: - metrics = metrics[::-1] - for metric in metrics: - trial = search_method.choose_next_trial_from_queue() - assert trial.searcher_metric_name is not None - search_method.on_validation_completed( - searcher_state=searcher_state, - request_id=trial.request_id, - metric={trial.searcher_metric_name: metric}, - train_length=trial.length, - ) - # Verify that the higher rungs contain lineages which performed better than lower rungs. - for rung_idx in range(search_method.max_rungs - 1): - lower_rung_trials = search_method.rungs[rung_idx] - higher_rung_trials = search_method.rungs[rung_idx + 1] - if higher_rung_trials: - non_promoted_lower_rung_trials = [ - lo - for lo in lower_rung_trials - if not any(lo in hi.lineage_set for hi in higher_rung_trials) - ] - # Every best-metric result in this set should be worse than every best-metric result - # in higher_rung_trials - for lo in non_promoted_lower_rung_trials: - best_lo_trial = search_method.get_best_trial_in_lineage(lo, rung_idx) - assert best_lo_trial - assert best_lo_trial.metric - assert isinstance(best_lo_trial.metric, dict) - assert best_lo_trial.searcher_metric_name - best_lo_metric = best_lo_trial.metric[best_lo_trial.searcher_metric_name] - for hi in higher_rung_trials: - best_hi_trial = search_method.get_best_trial_in_lineage(hi, rung_idx + 1) - assert best_hi_trial - assert best_hi_trial.metric - assert isinstance(best_hi_trial.metric, dict) - assert best_hi_trial.searcher_metric_name - best_hi_metric = best_hi_trial.metric[best_lo_trial.searcher_metric_name] - assert best_lo_metric < best_hi_metric - - @pytest.mark.timeout(5) - def test_promotion_respects_rung_idx( - self, - long_asha_state_and_search_method: Tuple[searcher.SearcherState, dsat.ASHADSATSearchMethod], - ) -> None: - """ - Test that promotion from a given rung_idx only accounts for the results of each lineage with - curr_rung <= rung_idx. - """ - searcher_state, search_method = long_asha_state_and_search_method - search_method.trial_tracker.queue.clear() - # Create three lineages which complete the first rung with three different metrics - metrics = list(range(search_method.divisor)) - # Order so that the worst lineage is last: - if not search_method.trial_tracker.smaller_is_better: - metrics = metrics[::-1] - for metric in metrics: - hparams, search_data = search_method.get_random_hparams_and_search_data(1) - trial = None - for idx in range(search_method.max_trials_for_rung_idx(0)): - trial = search_method.trial_tracker.create_trial( - hparams=hparams, search_data=copy.deepcopy(search_data), parent_trial=trial - ) - assert isinstance(trial.search_data, dsat.ASHADSATSearchData) - assert trial.search_data.curr_rung == 0 - assert trial.searcher_metric_name is not None - search_method.trial_tracker.queue_and_register_trial(trial) - search_method.trial_tracker.update_trial_metric( - trial, {trial.searcher_metric_name: metric} - ) - assert len(trial.lineage_set) == idx + 1 - assert trial is not None - assert search_method.lineage_completed_rung(trial, 0) - assert not search_method.lineage_completed_rung(trial, 1) - assert search_method.get_next_promotable_lineage() - - # Take the worst lineage in rung zero, promote it, and complete its next rung with better - # metrics than any seen in rung zero. - assert trial - best_metric = ( - min(metrics) if search_method.trial_tracker.smaller_is_better else max(metrics) - ) - next_metric = ( - best_metric - 1 if search_method.trial_tracker.smaller_is_better else best_metric + 1 - ) - while trial.num_completed_trials_in_lineage < search_method.max_trials_for_rung_idx(1): - search_data = copy.deepcopy(search_data) - search_data.curr_rung = 1 - trial = search_method.trial_tracker.create_trial( - hparams=hparams, search_data=search_data, parent_trial=trial - ) - assert isinstance(trial.search_data, dsat.ASHADSATSearchData) - assert trial.searcher_metric_name is not None - search_method.trial_tracker.queue_and_register_trial(trial) - search_method.trial_tracker.update_trial_metric( - trial, {trial.searcher_metric_name: next_metric} - ) - # Next promotable trial should be from the lowest rung. - assert ( - search_method.get_next_promotable_lineage() - == search_method.get_next_promotable_lineage_in_rung(0) - ) - - # And the promoted trial should not take the improved performance of the previously-worst - # rung_idx = 0 lineage into account. - next_promoted_trial = search_method.get_next_promotable_lineage_in_rung(0) - assert next_promoted_trial - assert next_promoted_trial.metric - assert isinstance(next_promoted_trial.metric, dict) - assert next_promoted_trial.searcher_metric_name - assert next_promoted_trial.metric[next_promoted_trial.searcher_metric_name] != next_metric - assert next_promoted_trial.metric[next_promoted_trial.searcher_metric_name] == best_metric - - @pytest.mark.timeout(5) - def test_choose_next_trial_from_queue( - self, - default_asha_state_and_search_method: Tuple[ - searcher.SearcherState, dsat.ASHADSATSearchMethod - ], - ) -> None: - """ - Verify that the `choose_next_trial_from_queue` method both chooses a trial with the largest - curr_rung value and from all such choices choose the trial with the longest lineage - """ - searcher_state, search_method = default_asha_state_and_search_method - search_method.trial_tracker.queue.clear() - hparams, search_data = search_method.get_random_hparams_and_search_data(1) - # Create an arbitrary counter to differentiate hparams and avoid the duplicate check in - # `queue_and_register_trial`. - arbitrary = 0 - - # Create a curr_rung = 0 lineage - trial = None - hparams = copy.deepcopy(hparams) - hparams["_arbitrary"] = arbitrary - arbitrary += 1 - trial = search_method.trial_tracker.create_trial( - hparams=hparams, search_data=copy.deepcopy(search_data), parent_trial=trial - ) - search_method.trial_tracker.queue_and_register_trial(trial) - assert trial.searcher_metric_name - search_method.trial_tracker.update_trial_metric(trial, {trial.searcher_metric_name: 0.0}) - - # Create several curr_rung = 1 lineages of varying lengths - for num_in_lineage in range(1, 3): - trial = None - for _ in range(num_in_lineage): - hparams = copy.deepcopy(hparams) - hparams["_arbitrary"] = arbitrary - arbitrary += 1 - search_data = copy.deepcopy(search_data) - search_data.curr_rung = 1 - trial = search_method.trial_tracker.create_trial( - hparams=hparams, search_data=search_data, parent_trial=trial - ) - search_method.trial_tracker.queue_and_register_trial(trial) - assert trial.searcher_metric_name - search_method.trial_tracker.update_trial_metric( - trial, {trial.searcher_metric_name: 0.0} - ) - - # Get the next trial: - next_trial = search_method.choose_next_trial_from_queue() - assert next_trial.search_data - assert isinstance(next_trial.search_data, dsat.ASHADSATSearchData) - assert next_trial.search_data.curr_rung == 1 - assert next_trial.num_completed_trials_in_lineage == num_in_lineage - - @pytest.mark.timeout(5) - def test_get_best_trial_in_lineage( - self, - default_asha_state_and_search_method: Tuple[ - searcher.SearcherState, dsat.ASHADSATSearchMethod - ], - ) -> None: - """ - Test the `get_best_trial_in_lineage` method and verify that it respects the `max_rung_idx` - arg appropriately. - """ - searcher_state, search_method = default_asha_state_and_search_method - search_method.trial_tracker.queue.clear() - hparams, search_data = search_method.get_random_hparams_and_search_data(1) - trial = None - # Let the metric improve with each rung. - for rung_idx in range(search_method.max_rungs): - while ( - not trial - or trial.num_completed_trials_in_lineage - < search_method.max_trials_for_rung_idx(rung_idx) - ): - metric = ( - -1 * rung_idx if search_method.trial_tracker.smaller_is_better else rung_idx - ) - search_data = copy.deepcopy(search_data) - search_data.curr_rung = rung_idx - trial = search_method.trial_tracker.create_trial( - hparams=hparams, search_data=search_data, parent_trial=trial - ) - assert trial.searcher_metric_name - search_method.trial_tracker.update_trial_metric( - trial, {trial.searcher_metric_name: metric} - ) - for rung_idx in range(search_method.max_rungs): - assert trial - best_trial = search_method.get_best_trial_in_lineage(trial, max_rung_idx=rung_idx) - assert best_trial - assert best_trial.metric - assert isinstance(best_trial.metric, dict) - assert best_trial.searcher_metric_name - best_trial_metric = best_trial.metric[best_trial.searcher_metric_name] - expected_metric = ( - -1 * rung_idx if search_method.trial_tracker.smaller_is_better else rung_idx - ) - assert best_trial_metric == expected_metric - - @pytest.mark.timeout(5) - def test_get_top_lineages_in_rung( - self, - long_asha_state_and_search_method: Tuple[searcher.SearcherState, dsat.ASHADSATSearchMethod], - ) -> None: - """ - Populate the lowest rung with trials with increasing metric values across lineages. - Verify that the reported best lineages are the expected ones. - """ - searcher_state, search_method = long_asha_state_and_search_method - search_method.trial_tracker.queue.clear() - metrics = list(range(10 * search_method.divisor)) - for metric in metrics: - hparams, search_data = search_method.get_random_hparams_and_search_data(1) - trial = None - for idx in range(search_method.min_binary_search_trials): - trial = search_method.trial_tracker.create_trial( - hparams=hparams, search_data=copy.deepcopy(search_data), parent_trial=trial - ) - assert isinstance(trial.search_data, dsat.ASHADSATSearchData) - assert trial.search_data.curr_rung == 0 - assert trial.searcher_metric_name is not None - search_method.trial_tracker.queue_and_register_trial(trial) - search_method.trial_tracker.update_trial_metric( - trial, {trial.searcher_metric_name: metric} - ) - assert len(trial.lineage_set) == idx + 1 - assert trial is not None - assert search_method.lineage_completed_rung(trial, 0) - assert not search_method.lineage_completed_rung(trial, 1) - - top_trials = search_method.get_top_lineages_in_rung(0) - assert len(top_trials) == len(search_method.rungs[0]) // search_method.divisor - # Verify that the metrics of the top trials take on their expected values. - if search_method.trial_tracker.smaller_is_better: - expected_metrics = metrics[: len(top_trials)] - else: - expected_metrics = list(reversed(metrics[len(top_trials) :])) - assert trial is not None - assert trial.searcher_metric_name is not None - actual_metrics = [] - for t in top_trials: - best_trial_in_lineage = search_method.get_best_trial_in_lineage(t) - assert best_trial_in_lineage is not None - assert isinstance(best_trial_in_lineage.metric, dict) - actual_metrics.append(best_trial_in_lineage.metric[trial.searcher_metric_name]) - assert expected_metrics == actual_metrics - - @pytest.mark.timeout(5) - def test_basic_promotion( - self, - long_asha_state_and_search_method: Tuple[searcher.SearcherState, dsat.ASHADSATSearchMethod], - ) -> None: - """ - Populate the rungs such that there is a promotable lineage and test that the promoted - lineage has the expected properties. - """ - searcher_state, search_method = long_asha_state_and_search_method - search_method.trial_tracker.queue.clear() - # Complete enough trials so that some can be promoted. - - for _ in range(search_method.max_trials_for_rung_idx(1)): - hparams, search_data = search_method.get_random_hparams_and_search_data(1) - trial = None - for trial_num in range(search_method.min_binary_search_trials): - trial = search_method.trial_tracker.create_trial( - hparams=hparams, search_data=copy.deepcopy(search_data), parent_trial=trial - ) - assert trial.search_data is not None - assert isinstance(trial.search_data, dsat.ASHADSATSearchData) - assert trial.search_data.curr_rung == 0 - search_method.trial_tracker.queue_and_register_trial(trial) - assert trial.searcher_metric_name is not None - search_method.trial_tracker.update_trial_metric( - trial, {trial.searcher_metric_name: 0.0} - ) - assert len(trial.lineage_set) == trial_num + 1 - assert trial is not None - assert search_method.lineage_completed_rung(trial, 0) - assert not search_method.lineage_completed_rung(trial, 1) - - next_promotable_lineage = search_method.get_next_promotable_lineage() - assert next_promotable_lineage is not None - next_trial = search_method.get_next_trial_in_lineage(next_promotable_lineage) - assert next_trial is not None - assert next_trial.search_data is not None - assert isinstance(next_trial.search_data, dsat.ASHADSATSearchData) - next_trial.search_data.curr_rung += 1 - assert next_trial.search_data.curr_rung == 1 - assert len(next_trial.lineage_set) == search_method.min_binary_search_trials + 1 - - @pytest.mark.timeout(5) - def test_lineage_continutation( - self, - long_asha_state_and_search_method: Tuple[searcher.SearcherState, dsat.ASHADSATSearchMethod], - ) -> None: - """ - Verify that we continue trials which have not yet completed their rung. - """ - searcher_state, search_method = long_asha_state_and_search_method - search_method.trial_tracker.queue.clear() - hparams, search_data = search_method.get_random_hparams_and_search_data(1) - first_trial = curr_trial = search_method.trial_tracker.create_trial( - hparams=hparams, search_data=copy.deepcopy(search_data), parent_trial=None - ) - search_method.trial_tracker.queue_and_register_trial(first_trial) - _ = search_method.trial_tracker.queue.popleft() - for trial_num in range(search_method.min_binary_search_trials): - assert isinstance(curr_trial.search_data, dsat.ASHADSATSearchData) - assert curr_trial.search_data.curr_rung == 0 - assert curr_trial.lineage_root == first_trial - assert curr_trial.num_completed_trials_in_lineage == trial_num - assert not search_method.lineage_completed_rung(curr_trial, 0) - assert curr_trial.searcher_metric_name is not None - search_method.on_validation_completed( - searcher_state=searcher_state, - request_id=curr_trial.request_id, - metric={curr_trial.searcher_metric_name: 0.0}, - train_length=curr_trial.length, - ) - assert curr_trial.completed - curr_trial = search_method.trial_tracker.queue.popleft() - # Force the search data to be non-trivial, so that we avoid exiting due to a trivial - # search range. - assert curr_trial.search_data - curr_trial.search_data.lo = 1 - curr_trial.search_data.hi = 10 - curr_trial.ds_config["train_micro_batch_size_per_gpu"] = 5 - - assert search_method.lineage_completed_rung(first_trial, 0) - assert curr_trial.lineage_root != first_trial - - @pytest.mark.timeout(5) - def test_top_promotion( - self, - long_asha_state_and_search_method: Tuple[searcher.SearcherState, dsat.ASHADSATSearchMethod], - ) -> None: - """ - Verify that if multiple lineages can be promoted, we promote from the higest-rung lineage - available. - """ - searcher_state, search_method = long_asha_state_and_search_method - search_method.trial_tracker.queue.clear() - good_metric, bad_metric = ( - (0.0, 1.0) if search_method.trial_tracker.smaller_is_better else (1.0, 0.0) - ) - # Fill two rungs with trials search_method.divisor trials, so that there are enough to - # promote from the top rung. - - # Create several lineages which complete rung_idx = 1: - max_trials_for_rung_one = search_method.max_trials_for_rung_idx(1) - hparams, search_data = search_method.get_random_hparams_and_search_data(1) - # Create an arbitrary counter to differentiate hparams and avoid the duplicate check in - # `queue_and_register_trial`. - arbitrary = 0 - # Add `divisor` such lineages, so that one can be promoted. - for lineage_number in range(1, 1 + search_method.divisor): - trial = None - while ( - search_method.trial_tracker.num_completed_trials - < lineage_number * max_trials_for_rung_one + 1 - ): - hparams = copy.deepcopy(hparams) - hparams["_arbitrary"] = arbitrary - arbitrary += 1 - trial = search_method.trial_tracker.create_trial( - hparams=hparams, search_data=copy.deepcopy(search_data), parent_trial=trial - ) - search_method.trial_tracker.queue_and_register_trial(trial) - _ = search_method.trial_tracker.queue.popleft() - assert trial.searcher_metric_name is not None - search_method.trial_tracker.update_trial_metric( - trial=trial, - metric={trial.searcher_metric_name: bad_metric}, - ) - # Promote as appropriate - assert trial - assert trial.search_data - assert isinstance(trial.search_data, dsat.ASHADSATSearchData) - if trial.num_completed_trials_in_lineage > search_method.max_trials_for_rung_idx(0): - trial.search_data.curr_rung = 1 - - # Check that we have populated the rungs as expected: - assert all(search_method.rungs[idx] for idx in range(2)) - assert not any(search_method.rungs[idx] for idx in range(2, search_method.max_rungs - 1)) - assert search_method.get_next_promotable_lineage() - assert search_method.get_next_promotable_lineage_in_rung(1) - assert not search_method.get_next_promotable_lineage_in_rung(0) - - # Submit another lineage which completes the lowest rung with better metrics than the - # lineage above, so that it is promotable. - trial = None - for _ in range(search_method.max_trials_for_rung_idx(0)): - hparams = copy.deepcopy(hparams) - hparams["_arbitrary"] = arbitrary - arbitrary += 1 - trial = search_method.trial_tracker.create_trial( - hparams=hparams, search_data=copy.deepcopy(search_data), parent_trial=trial - ) - search_method.trial_tracker.queue_and_register_trial(trial) - _ = search_method.trial_tracker.queue.popleft() - assert trial.searcher_metric_name is not None - search_method.trial_tracker.update_trial_metric( - trial=trial, - metric={trial.searcher_metric_name: good_metric}, - ) - # Verify the counting above and that the next promoted trial will come from the topmost - # possible rung. - assert len(search_method.rungs[0]) == search_method.divisor + 1 - assert len(search_method.rungs[1]) == search_method.divisor - - next_lineage_rung_0 = search_method.get_next_promotable_lineage_in_rung(0) - assert next_lineage_rung_0 - assert next_lineage_rung_0.search_data - assert isinstance(next_lineage_rung_0.search_data, dsat.ASHADSATSearchData) - - next_lineage_rung_1 = search_method.get_next_promotable_lineage_in_rung(1) - assert next_lineage_rung_1 - assert next_lineage_rung_1.search_data - assert isinstance(next_lineage_rung_1.search_data, dsat.ASHADSATSearchData) - - assert next_lineage_rung_0.search_data.curr_rung == 0 - assert next_lineage_rung_1.search_data.curr_rung == 1 - - next_promotable_lineage = search_method.get_next_promotable_lineage() - assert next_promotable_lineage - assert next_promotable_lineage.search_data - assert isinstance(next_promotable_lineage.search_data, dsat.ASHADSATSearchData) - - assert next_promotable_lineage.search_data.curr_rung == 1 - - @pytest.mark.timeout(5) - def test_max_resource_respected( - self, - long_asha_state_and_search_method: Tuple[searcher.SearcherState, dsat.ASHADSATSearchMethod], - ) -> None: - """ - Verify that we respect the maximum resource per lineage. - """ - # Create a lineage with the maximum resource per lineage - searcher_state, search_method = long_asha_state_and_search_method - search_method.trial_tracker.queue.clear() - hparams, search_data = search_method.get_random_hparams_and_search_data(1) - trial = search_method.trial_tracker.create_trial( - hparams=hparams, search_data=copy.deepcopy(search_data), parent_trial=None - ) - search_method.trial_tracker.queue_and_register_trial(trial) - _ = search_method.trial_tracker.queue.popleft() - max_binary_search_trials = ( - search_method.min_binary_search_trials - * search_method.divisor ** (search_method.max_rungs - 1) - ) - for _ in range(max_binary_search_trials): - assert trial.searcher_metric_name is not None - search_method.trial_tracker.update_trial_metric( - trial, {trial.searcher_metric_name: 0.0} - ) - trial = search_method.trial_tracker.create_trial( - hparams=hparams, search_data=copy.deepcopy(search_data), parent_trial=trial - ) - search_method.trial_tracker.queue_and_register_trial(trial) - _ = search_method.trial_tracker.queue.popleft() - assert search_method.lineage_completed_rung(trial, search_method.max_rungs - 1) - assert search_method.get_next_promotable_lineage() is None - - @pytest.mark.timeout(5) - def test_no_continuation_for_completed_lineages( - self, - long_asha_state_and_search_method: Tuple[searcher.SearcherState, dsat.ASHADSATSearchMethod], - ) -> None: - """ - Verify that lineages which have completed their binary search are not continued. - """ - searcher_state, search_method = long_asha_state_and_search_method - search_method.trial_tracker.queue.clear() - hparams, _ = search_method.get_random_hparams_and_search_data(1) - search_data = dsat.ASHADSATSearchData(lo=1, hi=1, curr_rung=0) - hparams[dsat.defaults.OVERWRITE_KEY]["train_micro_batch_size_per_gpu"] = ( - search_data.hi + search_data.lo - ) // 2 - trial = search_method.trial_tracker.create_trial( - hparams=hparams, search_data=search_data, parent_trial=None - ) - search_method.trial_tracker.queue_and_register_trial(trial) - _ = search_method.trial_tracker.queue.popleft() - assert trial.searcher_metric_name is not None - search_method.trial_tracker.update_trial_metric(trial, {trial.searcher_metric_name: 0.0}) - assert search_method.get_next_trial_in_lineage(trial) is None - - @pytest.mark.timeout(5) - def test_completed_binary_search_lineages_are_counted_complete( - self, - long_asha_state_and_search_method: Tuple[searcher.SearcherState, dsat.ASHADSATSearchMethod], - ) -> None: - """ - Verify that if a lineage successfully completes its binary search mid-rung, that lineage - is counted as having completed the rung. - """ - searcher_state, search_method = long_asha_state_and_search_method - search_method.trial_tracker.queue.clear() - hparams, _ = search_method.get_random_hparams_and_search_data(1) - search_data = dsat.ASHADSATSearchData(lo=1, hi=1, curr_rung=0) - hparams[dsat.defaults.OVERWRITE_KEY]["train_micro_batch_size_per_gpu"] = ( - search_data.hi + search_data.lo - ) // 2 - successful_trial = search_method.trial_tracker.create_trial( - hparams=hparams, search_data=search_data, parent_trial=None - ) - search_method.trial_tracker.queue_and_register_trial(successful_trial) - _ = search_method.trial_tracker.queue.popleft() - assert successful_trial.searcher_metric_name is not None - assert successful_trial.search_data is not None - assert isinstance(successful_trial.search_data, dsat.ASHADSATSearchData) - search_method.trial_tracker.update_trial_metric( - successful_trial, {successful_trial.searcher_metric_name: 0.0} - ) - assert search_method.lineage_completed_rung( - successful_trial, successful_trial.search_data.curr_rung - ) - assert search_method.get_next_trial_in_lineage(successful_trial) is None - - @pytest.mark.timeout(5) - def test_failed_binary_search_lineages_are_counted_complete( - self, - long_asha_state_and_search_method: Tuple[searcher.SearcherState, dsat.ASHADSATSearchMethod], - ) -> None: - """ - Verify that if a lineage fails its binary search mid-rung by failing on the minimum - possible batch size, that lineage is counted as having completed the rung. - """ - searcher_state, search_method = long_asha_state_and_search_method - search_method.trial_tracker.queue.clear() - hparams, _ = search_method.get_random_hparams_and_search_data(1) - search_data = dsat.ASHADSATSearchData(lo=1, hi=2, curr_rung=0) - hparams[dsat.defaults.OVERWRITE_KEY]["train_micro_batch_size_per_gpu"] = ( - search_data.hi + search_data.lo - ) // 2 - failed_trial = search_method.trial_tracker.create_trial( - hparams=hparams, search_data=search_data, parent_trial=None - ) - search_method.trial_tracker.queue_and_register_trial(failed_trial) - _ = search_method.trial_tracker.queue.popleft() - search_method.trial_tracker.report_trial_early_exit(failed_trial) - assert isinstance(failed_trial.search_data, dsat.ASHADSATSearchData) - assert search_method.lineage_completed_rung( - failed_trial, failed_trial.search_data.curr_rung - ) - assert search_method.get_next_trial_in_lineage(failed_trial) is None - - @pytest.mark.timeout(5) - def test_lineage_completed_rung( - self, - long_asha_state_and_search_method: Tuple[ - searcher.SearcherState, dsat.BinarySearchDSATSearchMethod - ], - ) -> None: - """ - Testing the `lineage_completed_rung` method by creating a very long lineage and verifying - that this method gives the expected results. - """ - searcher_state, search_method = long_asha_state_and_search_method - search_method.trial_tracker.queue.clear() - hparams, search_data = search_method.get_random_hparams_and_search_data(1) - trial = None - assert isinstance(search_method, dsat.ASHADSATSearchMethod) - num_trials_to_fill_all_rungs = search_method.max_trials_for_rung_idx( - search_method.max_rungs - 1 - ) - for num_trials in range(1, 1 + num_trials_to_fill_all_rungs): - hparams = copy.deepcopy(hparams) - # Add arbitrary hp to avoid non-duplicate hparams check in `queue_and_register_trial` - hparams["_arbitrary"] = num_trials - search_data = copy.deepcopy(search_data) - trial = search_method.trial_tracker.create_trial( - hparams, search_data, parent_trial=trial - ) - assert trial.searcher_metric_name is not None - assert trial.search_data is not None - assert isinstance(trial.search_data, dsat.ASHADSATSearchData) - search_method.trial_tracker.queue_and_register_trial(trial) - _ = search_method.trial_tracker.queue.popleft() - search_method.trial_tracker.update_trial_metric( - trial, {trial.searcher_metric_name: 0.0} - ) - if num_trials < search_method.max_trials_for_rung_idx(trial.search_data.curr_rung): - assert not search_method.lineage_completed_rung(trial, trial.search_data.curr_rung) - else: - old_rung = trial.search_data.curr_rung - assert trial.search_data - assert isinstance(trial.search_data, dsat.ASHADSATSearchData) - trial.search_data.curr_rung += 1 - for t in trial.lineage_set: - assert t.search_data is not None - assert isinstance(t.search_data, dsat.ASHADSATSearchData) - for rung_idx in range(0, old_rung + 1): - assert search_method.lineage_completed_rung(trial, rung_idx) - assert not search_method.lineage_completed_rung(trial, old_rung + 1) - - -class TestHFConfigOverwriting: - @pytest.mark.timeout(5) - def test_overwritten_args(self) -> None: - """ - Verify that `get_hf_args_with_overwrites` returns the expected args. - """ - optional_arg_possibilities: List[List[str]] = [ - [], - ["--per_device_train_batch_size", "8"], - ["--gradient_accumulation_steps", "4"], - ["--per_device_train_batch_size", "8", "--gradient_accumulation_steps", "4"], - ] - for optional_args in optional_arg_possibilities: - with tempfile.TemporaryDirectory() as d: - ds_config_path = pathlib.Path(d).joinpath("ds_config.json") - shutil.copyfile(HF_DS_CONFIG_PATH, ds_config_path) - args = ( - DEFAULT_HF_ARGS_WITHOUT_DEEPSPEED - + ["--deepspeed", str(ds_config_path)] - + optional_args - ) - args = dsat.get_hf_args_with_overwrites(args=args, hparams=HPARAMS_FIXTURE) - hf_flag_to_ds_key = { - "--per_device_train_batch_size": "train_micro_batch_size_per_gpu", - "--gradient_accumulation_steps": "gradient_accumulation_steps", - } - for idx in range(len(args)): - if args[idx] in hf_flag_to_ds_key: - hf_flag = args[idx] - ds_key = hf_flag_to_ds_key[hf_flag] - expected_hf_value = HPARAMS_FIXTURE[dsat.defaults.OVERWRITE_KEY][ds_key] - actual_hf_value = int(args[idx + 1]) - assert actual_hf_value == expected_hf_value - - @pytest.mark.timeout(5) - def test_overwritten_config_file(self) -> None: - """ - Verify that `get_hf_args_with_overwrites` overwrite the ds config file. - """ - with tempfile.TemporaryDirectory() as d: - overwrite_dict = HPARAMS_FIXTURE[dsat.defaults.OVERWRITE_KEY] - ds_config_path = pathlib.Path(d).joinpath("ds_config.json") - shutil.copyfile(HF_DS_CONFIG_PATH, ds_config_path) - - # Verify that the original config values are different from those we are overwriting. - with open(ds_config_path, "r") as f: - original_ds_config = json.load(f) - for k, v in overwrite_dict.items(): - assert original_ds_config.get(k) != v - args = DEFAULT_HF_ARGS_WITHOUT_DEEPSPEED + ["--deepspeed", str(ds_config_path)] - _ = dsat.get_hf_args_with_overwrites(args=args, hparams=HPARAMS_FIXTURE) - with open(ds_config_path, "r") as f: - overwritten_ds_config = json.load(f) - for k, v in overwrite_dict.items(): - assert overwritten_ds_config.get(k) == v - - @pytest.mark.timeout(5) - def test_no_auto_in_cli_args(self) -> None: - """ - Verify that if the user has an "overwrite_deepspeed_args" key in their hparam dict, but the - ds config json still has "auto" for batch size arguments, these "auto" settings are not - propagated as CLI args. Needed for cases where the user wants to overwrite some json - fields via the yaml config, but still wants to configure the batch size through HF CLI - entrypoint flags. - """ - optional_arg_possibilities: List[List[str]] = [ - [], - ["--per_device_train_batch_size", "8"], - ["--gradient_accumulation_steps", "4"], - ["--per_device_train_batch_size", "8", "--gradient_accumulation_steps", "4"], - ] - for optional_args in optional_arg_possibilities: - with tempfile.TemporaryDirectory() as d: - ds_config_path = pathlib.Path(d).joinpath("ds_config.json") - shutil.copyfile(HF_DS_CONFIG_PATH, ds_config_path) - args = ( - DEFAULT_HF_ARGS_WITHOUT_DEEPSPEED - + ["--deepspeed", str(ds_config_path)] - + optional_args - ) - hparams = copy.deepcopy(HPARAMS_FIXTURE) - # Make the overwrite section non-trivial, but also independent of batch-size args. - hparams[dsat.defaults.OVERWRITE_KEY] = {"arbitrary": 0} - args = dsat.get_hf_args_with_overwrites(args=args, hparams=hparams) - hf_flag_to_ds_key = { - "--per_device_train_batch_size": "train_micro_batch_size_per_gpu", - "--gradient_accumulation_steps": "gradient_accumulation_steps", - } - for idx in range(len(args)): - if args[idx] in hf_flag_to_ds_key: - actual_hf_value = args[idx + 1] - assert actual_hf_value != "auto" - - -class DSATMockMaster(custom_search_mocks.MockMaster): - """ - Sends v1 metrics back to the Search Runner in the manner defined with the - `all_metrics` list of dictionaries. - - The metrics are sent as a `v1ValidationCompleted` metric event. When the key for - the metric is instead `ERROR_METRIC_NAME`, this signals to the `MockMaster` to - instead send a `v1TrialExitedEarly` event to the Search Runner. - """ - - def __init__(self, all_metrics: List[Dict[str, Any]]) -> None: - self.events_queue: Deque[bindings.v1SearcherEvent] = collections.deque([]) - self.events_count = 0 - self.all_metrics = all_metrics - self.metric_index = 0 - - def handle_post_operations( - self, event: bindings.v1SearcherEvent, operations: List[searcher.Operation] - ) -> None: - self._remove_upto(event) - self._process_operations(operations) - - def _remove_upto(self, event: bindings.v1SearcherEvent) -> None: - while len(self.events_queue) > 0: - e = self.events_queue.popleft() - if e.id == event.id: - return - - raise RuntimeError(f"event not found in events queue: {event}") - - def _process_operations(self, operations: List[searcher.Operation]) -> None: - for op in operations: - self._append_events_for_op(op) # validate_after returns two events. - - def add_event(self, event_obj: bindings.v1SearcherEvent) -> None: - self.events_queue.append(event_obj) - - def handle_get_events(self) -> Optional[Sequence[bindings.v1SearcherEvent]]: - return list(self.events_queue) - - def _append_events_for_op(self, op: searcher.Operation) -> None: - if isinstance(op, searcher.ValidateAfter): - metric = self.all_metrics[self.metric_index] - self.metric_index += 1 - if isinstance(metric, dict) and ERROR_METRIC_NAME in metric: - trial_exited_early = bindings.v1TrialExitedEarly( - requestId=str(op.request_id), - exitedReason=bindings.v1TrialExitedEarlyExitedReason.UNSPECIFIED, - ) - self.events_count += 1 - event = bindings.v1SearcherEvent( - id=self.events_count, trialExitedEarly=trial_exited_early - ) - self.events_queue.append(event) - - trial_closed = bindings.v1TrialClosed(requestId=str(op.request_id)) - self.events_count += 1 - event = bindings.v1SearcherEvent(id=self.events_count, trialClosed=trial_closed) - self.events_queue.append(event) - else: - validation_completed = bindings.v1ValidationCompleted( - requestId=str(op.request_id), - metric=metric, - validateAfterLength=str(op.length), - ) - - self.events_count += 1 - event = bindings.v1SearcherEvent( - id=self.events_count, validationCompleted=validation_completed - ) - self.events_queue.append(event) - - # Send 1.0 to signal it was completed - trial_progress = bindings.v1TrialProgress( - requestId=str(op.request_id), partialUnits=1.0 - ) - self.events_count += 1 - event = bindings.v1SearcherEvent(id=self.events_count, trialProgress=trial_progress) - self.events_queue.append(event) - - elif isinstance(op, searcher.Create): - trial_created = bindings.v1TrialCreated(requestId=str(op.request_id)) - self.events_count += 1 - event = bindings.v1SearcherEvent(id=self.events_count, trialCreated=trial_created) - self.events_queue.append(event) - - elif isinstance(op, searcher.Progress): # no events - pass - - elif isinstance(op, searcher.Close): - trial_closed = bindings.v1TrialClosed(requestId=str(op.request_id)) - self.events_count += 1 - event = bindings.v1SearcherEvent(id=self.events_count, trialClosed=trial_closed) - self.events_queue.append(event) - - elif isinstance(op, searcher.Shutdown): - exp_state = ( - bindings.experimentv1State.ERROR - if op.failure - else bindings.experimentv1State.COMPLETED - ) - exp_inactive = bindings.v1ExperimentInactive(experimentState=exp_state) - self.events_count += 1 - event = bindings.v1SearcherEvent(id=self.events_count, experimentInactive=exp_inactive) - self.events_queue.append(event) diff --git a/harness/tests/experiment/pytorch/test_pytorch_trial.py b/harness/tests/experiment/pytorch/test_pytorch_trial.py index daa7f0463f8..5c591a9f9e9 100644 --- a/harness/tests/experiment/pytorch/test_pytorch_trial.py +++ b/harness/tests/experiment/pytorch/test_pytorch_trial.py @@ -1228,14 +1228,14 @@ def test_trial_validation_checkpointing(self, tmp_path: pathlib.Path): return_value=checkpoint_condition["best_validation"] ) controller._checkpoint = mock.MagicMock() - controller._validate(det.core.DummySearcherOperation(length=100, is_chief=True)) + controller._validate() controller.core_context.train.get_experiment_best_validation.assert_called_once() if checkpoint_condition["checkpoint"]: controller._checkpoint.assert_called_once() controller.core_context.train.get_experiment_best_validation.reset_mock() controller._checkpoint.reset_mock() - @mock.patch.object(det.core.DummySearcherOperation, "report_progress") + @mock.patch.object(det.core.DummyTrainContext, "report_progress") def test_searcher_progress_reporting(self, mock_report_progress: mock.MagicMock): trial, controller = pytorch_utils.create_trial_and_trial_controller( trial_class=pytorch_onevar_model.OneVarTrial, @@ -1246,8 +1246,8 @@ def test_searcher_progress_reporting(self, mock_report_progress: mock.MagicMock) ) controller.run() - exp_prog = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100] - got_prog = [x.args[0] for x in mock_report_progress.call_args_list] + exp_prog = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0] + got_prog = [x.kwargs["progress"] for x in mock_report_progress.call_args_list] assert exp_prog == got_prog def test_test_mode(self): @@ -1514,12 +1514,14 @@ def run_amp(tmp_path: pathlib.Path, api_style: str, batches_trained: typing.Opti "manual": "MNistManualAMPTrial", } - config = utils.load_config(utils.fixtures_path(f"pytorch_amp/{api_style}_amp_distributed.yaml")) - config = config.copy() - config.setdefault("profiling", {}) - config["profiling"]["enabled"] = True - - hparams = config["hyperparameters"] + hparams = { + "learning_rate": 1.0, + "global_batch_size": 64, + "n_filters1": 32, + "n_filters2": 64, + "dropout1": 0.25, + "dropout2": 0.5, + } exp_config = utils.make_default_exp_config( hparams, @@ -1527,8 +1529,8 @@ def run_amp(tmp_path: pathlib.Path, api_style: str, batches_trained: typing.Opti searcher_metric="validation_loss", checkpoint_dir=checkpoint_dir, ) - exp_config.update(config) exp_config["searcher"]["smaller_is_better"] = True + exp_config.setdefault("profiling", {})["enabled"] = True example_path = utils.fixtures_path(f"pytorch_amp/{api_style}_amp_model_def.py") trial_class = utils.import_class_from_module(class_selector[api_style], example_path) diff --git a/harness/tests/experiment/test_utils.py b/harness/tests/experiment/test_utils.py new file mode 100644 index 00000000000..7040f4988de --- /dev/null +++ b/harness/tests/experiment/test_utils.py @@ -0,0 +1,47 @@ +from typing import Any, Dict, Tuple, Union + +from tests.experiment import utils + + +def test_assert_events_match() -> None: + """ + Make sure our test utility actually works, since it is the basis for + ensuring that our callback-based 3rd-party integrations works. + """ + + def expect_success( + events: utils.Events, *patterns: Union[str, Tuple[str, str]] + ) -> Dict[str, Any]: + try: + return utils.assert_events_match(events, *patterns) + except AssertionError: + raise AssertionError(f"expected success: {patterns}") + + def expect_failure(events: utils.Events, *patterns: str) -> None: + try: + utils.assert_events_match(events, *patterns) + except AssertionError: + pass + else: + raise AssertionError(f"expected failure: {patterns}") + + events = utils.Events([("1", None), ("2", 2), ("3", None)]) + + expect_success(events, "1") + expect_success(events, "2") + expect_success(events, "3") + expect_success(events, "1", "2", "3") + expect_success(events, "1", "!4") + expect_success(events, "!0", "2") + expect_success(events, "!2", "1", "2") + expect_success(events, "[0-3]", "[0-3]", "[0-3]") + # Make sure a positive match takes precedence over a negative match. + expect_success(events, "![3-9]", "3") + + expect_failure(events, "1", "3", "4") + expect_failure(events, "1", "!2") + expect_failure(events, "1", "!2", "3") + expect_failure(events, "!1", "2") + + # Make sure we capture the data for events like we expect. + assert expect_success(events, "1", ("2", "two"), "3") == {"two": 2} diff --git a/e2e_tests/tests/fixtures/custom_searcher_exp/__init__.py b/harness/tests/experiment/transformers/__init__.py similarity index 100% rename from e2e_tests/tests/fixtures/custom_searcher_exp/__init__.py rename to harness/tests/experiment/transformers/__init__.py diff --git a/harness/tests/experiment/transformers/test_callback.py b/harness/tests/experiment/transformers/test_callback.py new file mode 100644 index 00000000000..e1df54a1b65 --- /dev/null +++ b/harness/tests/experiment/transformers/test_callback.py @@ -0,0 +1,495 @@ +import pathlib +import re +from typing import Any, Callable, Dict, Optional, Tuple +from unittest import mock + +import numpy as np +import torch +import transformers + +import determined as det +import determined.transformers +from determined import core +from determined.common import storage +from tests.experiment import utils +from tests.launch import test_util + + +def mock_core_context( + path: str, events: utils.Events, distributed: Optional[core.DistributedContext] = None +) -> Tuple[core.Context, Callable[[], None]]: + """ + Returns a core_context and a set_preempt() callable. + + The core_context is partially mocked to support triggering preemption from test code and to log + all reports to the provided Events object. + """ + # Set up a functional DistributedContext. + distributed = distributed or core.DummyDistributedContext() + # Set up a functional CheckpointContext. + storage_manager = storage.SharedFSStorageManager(path) + + class DummyCheckpointContext(core.DummyCheckpointContext): + def _report_checkpoint( + self, + storage_id: str, + resources: Optional[Dict[str, int]] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + events.append(("report_checkpoint", storage_id)) + super()._report_checkpoint(storage_id, resources, metadata) + + checkpoint = DummyCheckpointContext(distributed, storage_manager) + + # Mock everything else, logging report-like calls to events. + + def report_metrics(group: str, steps_completed: int, metrics: Any) -> None: + events.append((f"report_metrics:{group}:{steps_completed}", metrics)) + + def report_progress(progress: float) -> None: + fourdigits = "%.4f" % progress + events.append((f"report_progress:{fourdigits}", progress)) + + def set_status(status: str) -> None: + events.append((f"set_status:{status}", None)) + + preempted = False + + def should_preempt() -> bool: + nonlocal preempted + return preempted + + core_context = mock.Mock() + core_context.distributed = distributed + core_context.preempt.should_preempt.side_effect = should_preempt + core_context.checkpoint = checkpoint + core_context.train.report_metrics.side_effect = report_metrics + core_context.train.report_progress.side_effect = report_progress + core_context.train.set_status.side_effect = set_status + + def set_preempt() -> None: + nonlocal preempted + preempted = True + + return core_context, set_preempt + + +class MyOneVarModel(torch.nn.Linear): # type: ignore + """ + Subclass torch.nn.Linear with custom behaviors to be Transformers.Trainer-friendly. + """ + + def __init__(self) -> None: + super().__init__(1, 1, False) + self.weight.data.fill_(0) + self._loss_fn = torch.nn.MSELoss() + + # Signature must match key in dataset's output. + def forward(self, x: torch.Tensor, label_y: torch.Tensor) -> Dict[str, torch.Tensor]: + y = super().forward(x) + loss = self._loss_fn(y, label_y) + # We must return a dict with "loss" as a key. + # (technically a tuple with loss as the first element is also ok) + return {"loss": loss, "pred_y": y} + + +class OnesDataset(torch.utils.data.Dataset): + def __init__(self, dataset_len: int) -> None: + self.dataset_len = dataset_len + + def __len__(self) -> int: + return self.dataset_len + + def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: + # Key name must match model's .forward() signature. + return {"x": torch.Tensor([float(1)]), "label_y": torch.Tensor([float(1)])} + + +def compute_metrics(pred: transformers.EvalPrediction) -> Dict[str, float]: + # Return a mean absolute error as a metric. + return {"mae": np.abs(pred.predictions - pred.label_ids).mean()} + + +class DetCallbackForTesting(det.transformers.DetCallback): + def __init__(self, events: utils.Events, *args: Any, **kwargs: Any) -> None: + self.events = events + super().__init__(*args, **kwargs) + + def on_train_begin( + self, + args: transformers.TrainingArguments, + state: transformers.TrainerState, + control: transformers.TrainerControl, + **kwargs: Any, + ) -> None: + epoch = "%.4f" % state.epoch + self.events.append((f"on_train_begin:{state.global_step}:{epoch}", None)) + + def on_epoch_begin( + self, + args: transformers.TrainingArguments, + state: transformers.TrainerState, + control: transformers.TrainerControl, + **kwargs: Any, + ) -> None: + epoch = "%.4f" % state.epoch + self.events.append((f"on_epoch_begin:{state.global_step}:{epoch}", None)) + + def on_epoch_end( + self, + args: transformers.TrainingArguments, + state: transformers.TrainerState, + control: transformers.TrainerControl, + **kwargs: Any, + ) -> None: + epoch = "%.4f" % state.epoch + weight = kwargs["model"].weight.data.item() + self.events.append((f"before_epoch_end:{state.global_step}:{epoch}", weight)) + super().on_epoch_end(args, state, control) + self.events.append((f"after_epoch_end:{state.global_step}:{epoch}", weight)) + + def on_save( + self, + args: transformers.TrainingArguments, + state: transformers.TrainerState, + control: transformers.TrainerControl, + **kwargs: Any, + ) -> None: + epoch = "%.4f" % state.epoch + self.events.append((f"before_save:{state.global_step}:{epoch}", None)) + super().on_save(args, state, control) + self.events.append((f"after_save:{state.global_step}:{epoch}", None)) + + def on_evaluate( + self, + args: transformers.TrainingArguments, + state: transformers.TrainerState, + control: transformers.TrainerControl, + **kwargs: Any, + ) -> None: + epoch = "%.4f" % state.epoch + self.events.append((f"on_evaluate:{state.global_step}:{epoch}", None)) + + def on_train_end(self, *args: Any, **kwargs: Any) -> None: + self.events.append(("on_train_end", None)) + + +def do_train( + tmp_path: pathlib.Path, + force_final_save: Optional[bool] = None, + force_final_evaluate: Optional[bool] = None, + set_preempt_on_event: Optional[str] = None, + latest_checkpoint: Optional[str] = None, + **kwargs: Any, +) -> utils.Events: + args = transformers.TrainingArguments( + output_dir=str(tmp_path / "trainer"), disable_tqdm=True, **kwargs + ) + + with test_util.set_mock_cluster_info(["0.0.0.0"], 0, 1) as info: + info.trial._config = {"searcher": {"name": "single", "metric": "eval_mae"}} + info._latest_checkpoint = latest_checkpoint + + model = MyOneVarModel() + train_dataset = OnesDataset(64) + eval_dataset = OnesDataset(64) + + events = utils.Events() + core_context, set_preempt = mock_core_context(str(tmp_path / "ckpt"), events) + + if set_preempt_on_event: + # Configure a hook for Events that calls set_preempt() when a matching event arrives. + p = re.compile(set_preempt_on_event) + + def hook(summary: str, data: Any) -> None: + if p.search(summary): + set_preempt() + + events.hook = hook + + det_cb = DetCallbackForTesting(events, core_context, args) + if force_final_save is not None: + det_cb._force_final_save = force_final_save + if force_final_evaluate is not None: + det_cb._force_final_evaluate = force_final_evaluate + + t = transformers.Trainer( + model=model, + args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + compute_metrics=compute_metrics, + callbacks=[det_cb], + ) + # The call to train must specify the checkpoint. We do set args.resume_from_checkpoint in + # our DetCallback but it isn't automatically respected. + t.train(resume_from_checkpoint=args.resume_from_checkpoint) + + return events + + +def check_hf_metrics(metrics: Dict[str, Any]) -> None: + # We remove the default rounded 'epoch' metric, and the + assert "epoch" not in metrics, metrics + # We remove the speed metrics. + speed_suffixes = ["_runtime", "_per_second", "_compilation_time"] + assert not any(any(m.endswith(s) for s in speed_suffixes) for m in metrics), metrics + # We inject "epochs" and "batches" + assert "epochs" in metrics, metrics + assert "batches" in metrics, metrics + + +def test_train_metrics(tmp_path: pathlib.Path) -> None: + # Make sure that training metrics happen every 5 steps, as specified. + events = do_train( + tmp_path, + num_train_epochs=2, + evaluation_strategy="epoch", + logging_steps=5, + ) + data = utils.assert_events_match( + events, + "!report_metrics:training", + ("report_metrics:training:5", "metrics"), + "!report_metrics:training", + "report_metrics:training:10", + "!report_metrics:training", + "report_metrics:training:15", + # Trainer always logs training metrics before exiting. + "report_metrics:training:16", + "!report_metrics:training", + ) + # Check non-epoch metrics. + check_hf_metrics(data["metrics"]) + + # If logging_steps aligns with our exit batch (logging_steps == len(data)), we only log once. + events = do_train( + tmp_path, + num_train_epochs=1, + evaluation_strategy="epoch", + logging_steps=8, + ) + data = utils.assert_events_match( + events, + "!report_metrics:training", + ("report_metrics:training:8", "metrics"), + "!report_metrics:training", + ) + # Check epoch metrics. + check_hf_metrics(data["metrics"]) + + +def test_save_at_end(tmp_path: pathlib.Path) -> None: + # We force a save even if Transformers wouldn't. + events = do_train( + tmp_path, + num_train_epochs=1, + ) + utils.assert_events_match( + events, + "!report_checkpoint", + "before_save:8", + "report_checkpoint", + "after_save:8", + "!report_checkpoint", + ) + + # We can override it. Also, this tests that the previous case was valid, because it proves that + # the save that occured was the one we forced. + events = do_train( + tmp_path, + force_final_save=False, + num_train_epochs=1, + ) + utils.assert_events_match( + events, + "!report_checkpoint", + ) + + # Also, if the trainer naturally saves at that time, we don't duplicate the save. + events = do_train( + tmp_path, + # force_final_save=False, + num_train_epochs=1, + save_steps=8, + ) + utils.assert_events_match( + events, + "!report_checkpoint", + "before_save:8", + "report_checkpoint", + "after_save:8", + "!report_checkpoint", + ) + + # Same thing, but force_final_save=False to guarantee that the above test is valid (i.e. the + # save originated with Transformers). + events = do_train( + tmp_path, + force_final_save=False, + num_train_epochs=1, + save_steps=8, + ) + utils.assert_events_match( + events, + "!report_checkpoint", + "before_save:8", + "report_checkpoint", + "after_save:8", + "!report_checkpoint", + ) + + # Save a final checkpoint if we are preempted. + events = do_train( + tmp_path, + set_preempt_on_event="report_metrics:training:3", + logging_steps=1, + num_train_epochs=1, + ) + utils.assert_events_match( + events, + "!report_checkpoint", + "before_save:3", + "report_checkpoint", + "after_save:3", + "!report_checkpoint", + ) + + +def test_eval(tmp_path: pathlib.Path) -> None: + # Eval on epoch boundaries. + # (This test also ensures we don't double-evaluate with our evaluate-at-end logic). + events = do_train( + tmp_path, + num_train_epochs=2, + evaluation_strategy="epoch", + logging_steps=5, + ) + data = utils.assert_events_match( + events, + "!report_metrics:validation", + "!on_evaluate", + ("report_metrics:validation:8", "metrics"), + "on_evaluate:8", + "!report_metrics:validation", + "!on_evaluate", + "report_metrics:validation:16", + "on_evaluate:16", + "!report_metrics:validation", + "!on_evaluate", + ) + # Check epoch metrics. + check_hf_metrics(data["metrics"]) + + # Eval off epoch boundaries, and once at the end. + events = do_train( + tmp_path, + num_train_epochs=1, + evaluation_strategy="steps", + eval_steps=5, + ) + data = utils.assert_events_match( + events, + "!report_metrics:validation", + "!on_evaluate", + ("report_metrics:validation:5", "off-epoch-metrics"), + "on_evaluate:5", + "!report_metrics:validation", + "!on_evaluate", + ("report_metrics:validation:8", "final-metrics"), + "on_evaluate:8", + "!report_metrics:validation", + "!on_evaluate", + ) + # Check non-epoch metrics, and the at-end metrics. + check_hf_metrics(data["off-epoch-metrics"]) + check_hf_metrics(data["final-metrics"]) + + # Same thing, but we can disable the evaluate-at-end. Also this proves that our evaluate-at-end + # was working in the previous case. + events = do_train( + tmp_path, + force_final_evaluate=False, + num_train_epochs=1, + evaluation_strategy="steps", + eval_steps=5, + ) + utils.assert_events_match( + events, + "!report_metrics:validation", + "!on_evaluate", + "report_metrics:validation:5", + "on_evaluate:5", + "!report_metrics:validation", + "!on_evaluate", + ) + + # Same thing, but we can disable the evaluate-at-end. Also this proves that our evaluate-at-end + # was working in the previous case. + events = do_train( + tmp_path, + force_final_evaluate=False, + num_train_epochs=1, + evaluation_strategy="steps", + eval_steps=5, + ) + utils.assert_events_match( + events, + "!report_metrics:validation", + "!on_evaluate", + "report_metrics:validation:5", + "on_evaluate:5", + "!report_metrics:validation", + "!on_evaluate", + ) + + # Never evaluate-at-end if we got preempted. + events = do_train( + tmp_path, + set_preempt_on_event="report_metrics:training:3", + num_train_epochs=1, + logging_steps=1, + evaluation_strategy="steps", + eval_steps=5, + ) + utils.assert_events_match( + events, + "!report_metrics:validation", + "!on_evaluate", + ) + + +def test_save_and_restore(tmp_path: pathlib.Path) -> None: + events = do_train( + tmp_path, + set_preempt_on_event="report_metrics:training:3", + max_steps=5, + logging_steps=1, + ) + data = utils.assert_events_match( + events, + ("after_epoch_end", "weight"), + ("report_checkpoint", "ckpt"), + ) + + # Make sure our next training continues from here. + ckpt = data["ckpt"] + ckpt_weight = data["weight"] + + # Note that model is loaded _after_ on_epoch_begin, so to know that we loaded a model we'll + # compare weight after training one batch to the checkpoint weight (which had more than one + # batch of training behind it). + events = do_train( + tmp_path, + latest_checkpoint=ckpt, + max_steps=1, + ) + data = utils.assert_events_match( + events, + # training should continue from global_step=3 + "on_train_begin:3", + ("after_epoch_end", "weight"), + ) + + # Model weight will be slowly moving from 0 to 1 throughout training. + assert data["weight"] > ckpt_weight diff --git a/harness/tests/experiment/utils.py b/harness/tests/experiment/utils.py index dd1311ad394..ccee5f51e90 100644 --- a/harness/tests/experiment/utils.py +++ b/harness/tests/experiment/utils.py @@ -2,7 +2,7 @@ import os import pathlib import re -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type +from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Type, Union from unittest import mock import mypy_extensions @@ -480,3 +480,122 @@ def get_mock_distributed_context( mock_distributed_context.allgather.return_value = all_gather_return_value mock_distributed_context.gather.return_value = gather_return_value return mock_distributed_context + + +class Events: + """ + Events is basically just a list of (event_string, data) pairs, but where you can add a hook to + the .append() method. + + See assert_events_match for motivation. + """ + + def __init__(self, items: Optional[List[Tuple[str, Any]]] = None) -> None: + self.items = items if items is not None else [] + self.hook: Optional[Callable[[str, Any], None]] = None + + def append(self, item: Tuple[str, Any]) -> None: + self.items.append(item) + self.hook and self.hook(*item) + + def __str__(self) -> str: + return " - " + "\n - ".join(summary for summary, data in self.items) + + +def assert_events_match(events: Events, *patterns: Union[str, Tuple[str, str]]) -> Dict[str, Any]: + """ + Make sure the events from one a run of training loop match a set of patterns. + + This is a tool for testing the overall behavior of callback-based integrations with third party + training loops. The idea is that we have a bunch of individual hooks, probably with each one + being pretty trivial. The features we promise depend on a series of two or more hooks working + together to deliver, say, preemption with a save-at-the-end feature. Testing the individual + hooks does not guarantee that our hooks work together with the 3rd-party training loop to + accomplish the feature. Also, since we don't own the 3rd-paty training loop, testing individual + hooks does not detect if the training loop changes in a way that breaks us. + + So instead, we log what happens in a full run of the training loop by mocks and subclassing to + log events to an Events() object. See keras/test_callback.py or transformers/test_callback.py + for an example. + + Events are just pairs of (str, Any) where the string should uniquely identify the event and the + extra data can be anything you might need to extract from a matched event. + + Then we test the overall log of events against a set of positive and negative regex patterns + that specify an order of required events (positive patterns) and any unallowed events between + them (negative patterns). + + For example, to test that training metrics are reported at steps_completed=5 and 10, but no + validation metrics, training metrics, or checkpoint occurs between them, you could use this + sequence of patterns: + + - report_metrics:training:5 + - !report_metrics + - !report_checkpoint + - report_metrics:training:10 + + Note that the ! at the beginning of a pattern marks it as a negative pattern. Also note that + the negative patterns apply to the events after the previous positive match and before the next + postive match. + + A positive pattern may be specified as a Tuple of (pattern, name), in which case the data + associated with the matched event will be returned under the `name` key. + + Returns a dict containing the requested matched data. + """ + matched_data = {} + + def iter_patterns() -> Iterator[Tuple[List[re.Pattern], Optional[re.Pattern], Optional[str]]]: + negatives = [] + for p in patterns: + if isinstance(p, str): + # Plain string pattern. + pattern = p + name = None + else: + # (pattern, name) case. + pattern, name = p + # Gather up chunks of negative patterns until the next positive pattern. + if pattern.startswith("!"): + negatives.append(re.compile(pattern[1:])) + else: + positive = re.compile(pattern) + yield negatives, positive, name + negatives = [] + # Possibly yield a final chunk of just negative patterns. + yield negatives, None, None + + event_iter = iter(events.items) + for negatives, positive, name in iter_patterns(): + for event, data in event_iter: + # If this is the positive match, don't bother checking negatives. That way, you can + # have a broad negative check like "!report_metrics:training" that applies until some + # specific check like "report_metrics:training:10", meaning "no training metrics until + # steps_completed equals 10". + if positive and positive.search(event): + if name: + matched_data[name] = data + break + # Negatives must not match. + matches = [n.pattern for n in negatives if n.search(event)] + if matches: + if positive: + raise AssertionError( + f"negative pattern (!{matches[0]}) matched to event ({event}) before " + f"{positive.pattern} was found\nevents were:\n{events}" + ) + else: + raise AssertionError( + f"negative pattern (!{matches[0]}) matched to event ({event}) " + f"after final postive pattern\nevents were:\n{events}" + ) + else: + # End of events... did we match all of our postives? + if positive is None: + return matched_data + raise AssertionError( + f"did not match positive expression ({positive.pattern})\n" + f"events were:\n{events}" + ) + # Out of patterns. + return matched_data diff --git a/harness/tests/launch/test_tensorflow.py b/harness/tests/launch/test_tensorflow.py new file mode 100644 index 00000000000..ba52efb7fe9 --- /dev/null +++ b/harness/tests/launch/test_tensorflow.py @@ -0,0 +1,83 @@ +import json +import os +from unittest import mock + +import determined.launch.tensorflow # noqa: F401 +from determined import launch +from tests.launch import test_util + + +def test_parse_args() -> None: + positive_test_cases = { + "script arg": (29400, ["script", "arg"]), + "-- script -- arg": (29400, ["script", "--", "arg"]), + "-- script --port 1": (29400, ["script", "--port", "1"]), + "--port 1 -- script arg": (1, ["script", "arg"]), + "script --port 1": (29400, ["script", "--port", "1"]), + } + + negative_test_cases = { + "": "empty script", + "--port 1": "empty script", + "--port 1 --": "empty script", + "--asdf 1 script ": "unrecognized arguments", + } + test_util.parse_args_check( + positive_test_cases, negative_test_cases, launch.tensorflow.parse_args + ) + + +@mock.patch("subprocess.Popen") +@mock.patch("determined.get_cluster_info") +def test_single_node( + mock_cluster_info: mock.MagicMock, + mock_subprocess: mock.MagicMock, +) -> None: + cluster_info = test_util.make_mock_cluster_info(["0.0.0.0"], 0, 1) + mock_cluster_info.return_value = cluster_info + script = ["python3", "train.py"] + + mock_exit_code = 99 + mock_proc = mock.MagicMock() + mock_proc.wait.return_value = mock_exit_code + + mock_subprocess.return_value = mock_proc + + assert launch.tensorflow.main(88, script) == mock_exit_code + + launch_cmd = script + + # No TF_CONFIG or log wrapper for single node trainings. + env = {**os.environ, "DET_CHIEF_IP": "0.0.0.0"} + mock_subprocess.assert_called_once_with(launch_cmd, env=env) + + +@mock.patch("subprocess.Popen") +@mock.patch("determined.get_cluster_info") +def test_multi_node( + mock_cluster_info: mock.MagicMock, + mock_subprocess: mock.MagicMock, +) -> None: + cluster_info = test_util.make_mock_cluster_info(["0.0.0.0", "0.0.0.1"], 1, 2) + mock_cluster_info.return_value = cluster_info + port = 88 + script = ["python3", "train.py"] + + mock_exit_code = 99 + mock_proc = mock.MagicMock() + mock_proc.wait.return_value = mock_exit_code + + mock_subprocess.return_value = mock_proc + + assert launch.tensorflow.main(port, script) == mock_exit_code + + launch_cmd = launch.tensorflow.create_log_wrapper(1) + script + + env = {**os.environ, "DET_CHIEF_IP": "0.0.0.0"} + env["TF_CONFIG"] = json.dumps( + { + "cluster": {"worker": [f"0.0.0.0:{port}", f"0.0.0.1:{port}"]}, + "task": {"type": "worker", "index": 1}, + } + ) + mock_subprocess.assert_called_once_with(launch_cmd, env=env) diff --git a/harness/tests/search_methods.py b/harness/tests/search_methods.py deleted file mode 100644 index f482e8dbfc8..00000000000 --- a/harness/tests/search_methods.py +++ /dev/null @@ -1,555 +0,0 @@ -import dataclasses -import json -import logging -import pathlib -import pickle -import random -import sys -import uuid -from typing import Any, Dict, List, Optional, Set - -from urllib3 import connectionpool - -from determined import searcher - - -class RandomSearchMethod(searcher.SearchMethod): - def __init__( - self, - max_trials: int, - max_concurrent_trials: int, - max_length: int, - test_type: str = "core_api", - exception_points: Optional[List[str]] = None, - ) -> None: - self.max_trials = max_trials - self.max_concurrent_trials = max_concurrent_trials - self.max_length = max_length - - self.test_type = test_type - self.exception_points = exception_points - - self.created_trials = 0 - self.pending_trials = 0 - self.closed_trials = 0 - - def on_trial_created( - self, _: searcher.SearcherState, request_id: uuid.UUID - ) -> List[searcher.Operation]: - self.raise_exception("on_trial_created") - if self.created_trials == 5: - self.raise_exception("on_trial_created_5") - self._log_stats() - return [] - - def on_validation_completed( - self, _: searcher.SearcherState, request_id: uuid.UUID, metric: Any, train_length: int - ) -> List[searcher.Operation]: - self.raise_exception("on_validation_completed") - return [] - - def on_trial_closed( - self, _: searcher.SearcherState, request_id: uuid.UUID - ) -> List[searcher.Operation]: - self.pending_trials -= 1 - self.closed_trials += 1 - ops: List[searcher.Operation] = [] - if self.created_trials < self.max_trials: - request_id = uuid.uuid4() - ops.append( - searcher.Create( - request_id=request_id, hparams=self.sample_params(), checkpoint=None - ) - ) - ops.append(searcher.ValidateAfter(request_id=request_id, length=self.max_length)) - ops.append(searcher.Close(request_id=request_id)) - self.created_trials += 1 - self.pending_trials += 1 - elif self.pending_trials == 0: - self.raise_exception("on_trial_closed_shutdown") - ops.append(searcher.Shutdown()) - - self._log_stats() - self.raise_exception("on_trial_closed_end") - return ops - - def progress(self, searcher_state: searcher.SearcherState) -> float: - if 0 < self.max_concurrent_trials < self.pending_trials: - logging.error("pending trials is greater than max_concurrent_trial") - units_completed = sum( - ( - ( - self.max_length - if r in searcher_state.trials_closed - else searcher_state.trial_progress[r] - ) - for r in searcher_state.trial_progress - ) - ) - units_expected = self.max_length * self.max_trials - progress = units_completed / units_expected - logging.debug( - f"progress = {progress} = {units_completed} / {units_expected}," - f" {searcher_state.trial_progress}" - ) - - if progress >= 0.5: - self.raise_exception("progress_middle") - - return progress - - def on_trial_exited_early( - self, _: searcher.SearcherState, request_id: uuid.UUID, exited_reason: searcher.ExitedReason - ) -> List[searcher.Operation]: - self.pending_trials -= 1 - - ops: List[searcher.Operation] = [] - if exited_reason == searcher.ExitedReason.INVALID_HP: - request_id = uuid.uuid4() - ops.append( - searcher.Create( - request_id=request_id, hparams=self.sample_params(), checkpoint=None - ) - ) - ops.append(searcher.ValidateAfter(request_id=request_id, length=self.max_length)) - ops.append(searcher.Close(request_id=request_id)) - self.pending_trials += 1 - return ops - - self.closed_trials += 1 - self._log_stats() - return ops - - def initial_operations(self, _: searcher.SearcherState) -> List[searcher.Operation]: - self.raise_exception("initial_operations_start") - initial_trials = self.max_trials - max_concurrent_trials = self.max_concurrent_trials - if max_concurrent_trials > 0: - initial_trials = min(initial_trials, max_concurrent_trials) - - ops: List[searcher.Operation] = [] - - for __ in range(initial_trials): - create = searcher.Create( - request_id=uuid.uuid4(), - hparams=self.sample_params(), - checkpoint=None, - ) - ops.append(create) - ops.append(searcher.ValidateAfter(request_id=create.request_id, length=self.max_length)) - ops.append(searcher.Close(request_id=create.request_id)) - - self.created_trials += 1 - self.pending_trials += 1 - - self._log_stats() - return ops - - def _log_stats(self) -> None: - logging.info(f"created trials={self.created_trials}") - logging.info(f"pending trials={self.pending_trials}") - logging.info(f"closed trials={self.closed_trials}") - - def sample_params(self) -> Dict[str, int]: - hparams = {"global_batch_size": random.randint(10, 100)} - logging.info(f"hparams={hparams}") - return hparams - - def save_method_state(self, path: pathlib.Path) -> None: - self.raise_exception("save_method_state") - checkpoint_path = path.joinpath("method_state") - with checkpoint_path.open("w") as f: - state = { - "max_trials": self.max_trials, - "max_concurrent_trials": self.max_concurrent_trials, - "max_length": self.max_length, - "created_trials": self.created_trials, - "pending_trials": self.pending_trials, - "closed_trials": self.closed_trials, - "exception_points": self.exception_points, - } - json.dump(state, f) - - def load_method_state(self, path: pathlib.Path) -> None: - self.raise_exception("load_method_state") - checkpoint_path = path.joinpath("method_state") - with checkpoint_path.open("r") as f: - state = json.load(f) - self.max_trials = state["max_trials"] - self.max_concurrent_trials = state["max_concurrent_trials"] - self.max_length = state["max_length"] - self.created_trials = state["created_trials"] - self.pending_trials = state["pending_trials"] - self.closed_trials = state["closed_trials"] - - if self.test_type == "core_api": - # ony restore exception points for core_api searcher tests; - # local searcher is providing new exception point on resumption, - # and it shouldn't be overridden - self.exception_points = state["exception_points"] - - def raise_exception(self, exception_id: str) -> None: - if ( - self.exception_points is not None - and len(self.exception_points) > 0 - and exception_id == self.exception_points[0] - ): - logging.info(f"Raising exception in {exception_id}") - ex = connectionpool.MaxRetryError( - connectionpool.HTTPConnectionPool(host="dummyhost", port=8080), - "http://dummyurl", - ) - raise ex - - -@dataclasses.dataclass -class TrialMetric: - request_id: uuid.UUID - metric: float - promoted: bool = False - - -@dataclasses.dataclass -class Rung: - units_needed: int - idx: int - metrics: List[TrialMetric] = dataclasses.field(default_factory=list) - outstanding_trials: int = 0 - - def promotions_async( - self, request_id: uuid.UUID, metric: float, divisor: int - ) -> List[uuid.UUID]: - logging.info(f"Rung {self.idx}") - logging.info(f"outstanding_trials {self.outstanding_trials}") - - old_num_promote = len(self.metrics) // divisor - num_promote = (len(self.metrics) + 1) // divisor - - index = self._search_metric_index(metric) - promote_now = index < num_promote - trial_metric = TrialMetric(request_id=request_id, metric=metric, promoted=promote_now) - self.metrics.insert(index, trial_metric) - - if promote_now: - return [request_id] - if num_promote != old_num_promote and not self.metrics[old_num_promote].promoted: - self.metrics[old_num_promote].promoted = True - return [self.metrics[old_num_promote].request_id] - - logging.info("No promotion") - return [] - - def _search_metric_index(self, metric: float) -> int: - i: int = 0 - j: int = len(self.metrics) - while i < j: - mid = (i + j) >> 1 - if self.metrics[mid].metric <= metric: - i = mid + 1 - else: - j = mid - return i - - -class ASHASearchMethodState: - def __init__( - self, - max_length: int, - max_trials: int, - num_rungs: int, - divisor: int, - max_concurrent_trials: int = 16, - ) -> None: - # Asha params - self.max_length = max_length - self.max_trials = max_trials - self.num_rungs = num_rungs - self.divisor = divisor - self.max_concurrent_trials = max_concurrent_trials - self.is_smaller_better = True - - # structs - self.rungs: List[Rung] = [] - self.trial_rungs: Dict[uuid.UUID, int] = {} - - # accounting - self.pending_trials: int = 0 - self.completed_trials: int = 0 - self.invalid_trials: int = 0 - self.early_exit_trials: Set[uuid.UUID] = set() - self.closed_trials: Set[uuid.UUID] = set() - - self._init_rungs() - - def _init_rungs(self) -> None: - units_needed = 0 - for idx in range(self.num_rungs): - downsampling_rate = pow(self.divisor, float(self.num_rungs - idx - 1)) - units_needed += max(int(self.max_length / downsampling_rate), 1) - self.rungs.append(Rung(units_needed, idx)) - - -class ASHASearchMethod(searcher.SearchMethod): - def __init__( - self, - max_length: int, - max_trials: int, - num_rungs: int, - divisor: int, - test_type: str = "core_api", - max_concurrent_trials: int = 16, - exception_points: Optional[List[str]] = None, - ) -> None: - self.asha_search_state = ASHASearchMethodState( - max_length, max_trials, num_rungs, divisor, max_concurrent_trials - ) - self.test_type = test_type - self.exception_points = exception_points - - def on_trial_closed( - self, _: searcher.SearcherState, request_id: uuid.UUID - ) -> List[searcher.Operation]: - self.asha_search_state.completed_trials += 1 - self.asha_search_state.closed_trials.add(request_id) - - if ( - self.asha_search_state.pending_trials == 0 - and self.asha_search_state.completed_trials == self.asha_search_state.max_trials - ): - self.raise_exception("shutdown") - return [searcher.Shutdown()] - - return [] - - def on_trial_created( - self, _: searcher.SearcherState, request_id: uuid.UUID - ) -> List[searcher.Operation]: - self.asha_search_state.rungs[0].outstanding_trials += 1 - self.asha_search_state.trial_rungs[request_id] = 0 - self.raise_exception("on_trial_created") - return [] - - def on_validation_completed( - self, _: searcher.SearcherState, request_id: uuid.UUID, metric: Any, train_length: int - ) -> List[searcher.Operation]: - assert isinstance(metric, float) - self.asha_search_state.pending_trials -= 1 - if self.asha_search_state.is_smaller_better is False: - metric *= -1 - ops = self.promote_async(request_id, metric) - self.raise_exception("on_validation_completed") - return ops - - def on_trial_exited_early( - self, _: searcher.SearcherState, request_id: uuid.UUID, exited_reason: searcher.ExitedReason - ) -> List[searcher.Operation]: - self.asha_search_state.pending_trials -= 1 - if exited_reason == searcher.ExitedReason.INVALID_HP: - ops: List[searcher.Operation] = [] - - self.asha_search_state.early_exit_trials.add(request_id) - ops.append(searcher.Close(request_id)) - self.asha_search_state.closed_trials.add(request_id) - self.asha_search_state.invalid_trials += 1 - - highest_rung_index = self.asha_search_state.trial_rungs[request_id] - rung = self.asha_search_state.rungs[highest_rung_index] - rung.outstanding_trials -= 1 - - for rung_idx in range(0, highest_rung_index + 1): - rung = self.asha_search_state.rungs[rung_idx] - rung.metrics = list(filter(lambda x: x.request_id != request_id, rung.metrics)) - - create = searcher.Create( - request_id=uuid.uuid4(), - hparams=self.sample_params(), - checkpoint=None, - ) - ops.append(create) - ops.append( - searcher.ValidateAfter( - request_id=create.request_id, - length=self.asha_search_state.rungs[0].units_needed, - ) - ) - - self.asha_search_state.trial_rungs[create.request_id] = 0 - self.asha_search_state.pending_trials += 1 - - return ops - - self.asha_search_state.early_exit_trials.add(request_id) - self.asha_search_state.closed_trials.add(request_id) - return self.promote_async(request_id, sys.float_info.max) - - def initial_operations(self, _: searcher.SearcherState) -> List[searcher.Operation]: - self.raise_exception("initial_operations_start") - ops: List[searcher.Operation] = [] - - if self.asha_search_state.max_concurrent_trials > 0: - max_concurrent_trials = min( - self.asha_search_state.max_concurrent_trials, self.asha_search_state.max_trials - ) - else: - max_concurrent_trials = max( - 1, - min( - int(pow(self.asha_search_state.divisor, self.asha_search_state.num_rungs - 1)), - self.asha_search_state.max_trials, - ), - ) - - for __ in range(0, max_concurrent_trials): - create = searcher.Create( - request_id=uuid.uuid4(), - hparams=self.sample_params(), - checkpoint=None, - ) - ops.append(create) - ops.append( - searcher.ValidateAfter( - request_id=create.request_id, - length=self.asha_search_state.rungs[0].units_needed, - ) - ) - - self.asha_search_state.trial_rungs[create.request_id] = 0 - self.asha_search_state.pending_trials += 1 - - return ops - - def promote_async(self, request_id: uuid.UUID, metric: float) -> List[searcher.Operation]: - rung_idx = self.asha_search_state.trial_rungs[request_id] - rung = self.asha_search_state.rungs[rung_idx] - rung.outstanding_trials -= 1 - added_train_workload = False - - ops: List[searcher.Operation] = [] - - if rung_idx == self.asha_search_state.num_rungs - 1: - rung.metrics.append(TrialMetric(request_id=request_id, metric=metric)) - - if request_id not in self.asha_search_state.early_exit_trials: - self.raise_exception("promote_async_close_trials") - ops.append(searcher.Close(request_id=request_id)) - logging.info(f"Closing trial {request_id}") - self.asha_search_state.closed_trials.add(request_id) - else: - next_rung = self.asha_search_state.rungs[rung_idx + 1] - self.raise_exception("promote_async") - logging.info(f"Promoting in rung {rung_idx}") - for promoted_request_id in rung.promotions_async( - request_id, metric, self.asha_search_state.divisor - ): - self.asha_search_state.trial_rungs[promoted_request_id] = rung_idx + 1 - next_rung.outstanding_trials += 1 - if promoted_request_id not in self.asha_search_state.early_exit_trials: - logging.info(f"Promoted {promoted_request_id}") - units_needed = max(next_rung.units_needed - rung.units_needed, 1) - ops.append(searcher.ValidateAfter(promoted_request_id, units_needed)) - added_train_workload = True - self.asha_search_state.pending_trials += 1 - else: - return self.promote_async(promoted_request_id, sys.float_info.max) - - all_trials = len(self.asha_search_state.trial_rungs) - self.asha_search_state.invalid_trials - if not added_train_workload and all_trials < self.asha_search_state.max_trials: - logging.info("Creating new trial instead of promoting") - self.asha_search_state.pending_trials += 1 - - create = searcher.Create( - request_id=uuid.uuid4(), - hparams=self.sample_params(), - checkpoint=None, - ) - ops.append(create) - ops.append( - searcher.ValidateAfter( - request_id=create.request_id, - length=self.asha_search_state.rungs[0].units_needed, - ) - ) - self.asha_search_state.trial_rungs[create.request_id] = 0 - - if len(self.asha_search_state.rungs[0].metrics) == self.asha_search_state.max_trials: - ops.extend(self._get_close_rungs_ops()) - - return ops - - def _get_close_rungs_ops(self) -> List[searcher.Operation]: - self.raise_exception("_get_close_rungs_ops") - ops: List[searcher.Operation] = [] - - for rung in self.asha_search_state.rungs: - if rung.outstanding_trials > 0: - break - for trial_metric in rung.metrics: - if ( - not trial_metric.promoted - and trial_metric.request_id not in self.asha_search_state.closed_trials - ): - if trial_metric.request_id not in self.asha_search_state.early_exit_trials: - logging.info(f"Closing trial {trial_metric.request_id}") - ops.append(searcher.Close(trial_metric.request_id)) - self.asha_search_state.closed_trials.add(trial_metric.request_id) - return ops - - def sample_params(self) -> Dict[str, object]: - hparams = { - "global_batch_size": 10, - "metrics_base": 0.05 * (len(self.asha_search_state.trial_rungs) + 1), - "metrics_progression": "constant", - } - logging.info(f"hparams={hparams}") - return hparams - - def progress(self, _: searcher.SearcherState) -> float: - if 0 < self.asha_search_state.max_concurrent_trials < self.asha_search_state.pending_trials: - raise RuntimeError("Pending trial is greater than max concurrent trials") - all_trials = len(self.asha_search_state.rungs[0].metrics) - - progress = all_trials / (1.2 * self.asha_search_state.max_trials) - if all_trials == self.asha_search_state.max_trials: - num_valid_trials = ( - self.asha_search_state.completed_trials - self.asha_search_state.invalid_trials - ) - progress_no_overhead = num_valid_trials / self.asha_search_state.max_trials - progress = max(progress_no_overhead, progress) - - return progress - - def save_method_state(self, path: pathlib.Path) -> None: - self.raise_exception("save_method_state") - checkpoint_path = path.joinpath("method_state") - with checkpoint_path.open("wb") as f: - pickle.dump(self.asha_search_state, f) - - exception_path = path.joinpath("exceptions") - with exception_path.open("wb") as f: - pickle.dump(self.exception_points, f) - - def load_method_state(self, path: pathlib.Path) -> None: - self.raise_exception("load_method_state") - checkpoint_path = path.joinpath("method_state") - with checkpoint_path.open("rb") as f: - self.asha_search_state = pickle.load(f) - - if self.test_type == "core_api": - # ony restore exception points for core_api searcher tests; - # local searcher is providing new exception point on resumption, - # and it shouldn't be overridden - exception_path = path.joinpath("exceptions") - with exception_path.open("rb") as f: - self.exception_points = pickle.load(f) - - def raise_exception(self, exception_id: str) -> None: - if ( - self.exception_points is not None - and len(self.exception_points) > 0 - and exception_id == self.exception_points[0] - ): - logging.info(f"Raising exception in {exception_id}") - ex = connectionpool.MaxRetryError( - connectionpool.HTTPConnectionPool(host="dummyhost", port=8080), "http://dummyurl" - ) - raise ex diff --git a/harness/tests/test_custom_searcher.py b/harness/tests/test_custom_searcher.py deleted file mode 100644 index 3d8e0089e08..00000000000 --- a/harness/tests/test_custom_searcher.py +++ /dev/null @@ -1,46 +0,0 @@ -import pathlib -import tempfile - -from tests import custom_search_mocks, search_methods - - -def test_run_random_searcher_exp_mock_master() -> None: - max_trials = 5 - max_concurrent_trials = 2 - max_length = 500 - - with tempfile.TemporaryDirectory() as searcher_dir: - search_method = search_methods.RandomSearchMethod( - max_trials, max_concurrent_trials, max_length - ) - mock_master_obj = custom_search_mocks.SimulateMaster(metric=1.0) - search_runner = custom_search_mocks.MockMasterSearchRunner( - search_method, mock_master_obj, pathlib.Path(searcher_dir) - ) - search_runner.run(exp_config={}, context_dir="", includes=None) - - assert search_method.created_trials == 5 - assert search_method.pending_trials == 0 - assert search_method.closed_trials == 5 - assert len(search_runner.state.trials_created) == search_method.created_trials - assert len(search_runner.state.trials_closed) == search_method.closed_trials - - -def test_run_asha_batches_exp_mock_master(tmp_path: pathlib.Path) -> None: - max_length = 3000 - max_trials = 16 - num_rungs = 3 - divisor = 4 - - search_method = search_methods.ASHASearchMethod(max_length, max_trials, num_rungs, divisor) - mock_master_obj = custom_search_mocks.SimulateMaster(metric=1.0) - search_runner = custom_search_mocks.MockMasterSearchRunner( - search_method, mock_master_obj, tmp_path - ) - search_runner.run(exp_config={}, context_dir="", includes=None) - - assert search_method.asha_search_state.pending_trials == 0 - assert search_method.asha_search_state.completed_trials == 16 - assert len(search_runner.state.trials_closed) == len( - search_method.asha_search_state.closed_trials - ) diff --git a/master/internal/api_config_policies_intg_test.go b/master/internal/api_config_policies_intg_test.go index d2503bac37a..3c3c15a8f59 100644 --- a/master/internal/api_config_policies_intg_test.go +++ b/master/internal/api_config_policies_intg_test.go @@ -1,3 +1,6 @@ +//go:build integration +// +build integration + package internal import ( diff --git a/master/internal/api_experiment.go b/master/internal/api_experiment.go index 54d43b98e9e..5fa1914f172 100644 --- a/master/internal/api_experiment.go +++ b/master/internal/api_experiment.go @@ -251,85 +251,6 @@ func (a *apiServer) getExperimentAndCheckCanDoActions( return experiment.GetExperimentAndCheckCanDoActions(ctx, expID, actions...) } -func (a *apiServer) GetSearcherEvents( - ctx context.Context, req *apiv1.GetSearcherEventsRequest, -) (*apiv1.GetSearcherEventsResponse, error) { - curUser, _, err := grpcutil.GetUser(ctx) - if err != nil { - return nil, err - } - exp, err := a.getExperiment(ctx, *curUser, int(req.ExperimentId)) - if err != nil { - return nil, err - } - if !isActiveExperimentState(exp.State) { - return &apiv1.GetSearcherEventsResponse{ - SearcherEvents: []*experimentv1.SearcherEvent{{ - Id: -1, - Event: &experimentv1.SearcherEvent_ExperimentInactive{ - ExperimentInactive: &experimentv1.ExperimentInactive{ - ExperimentState: exp.State, - }, - }, - }}, - }, nil - } - - e, ok := experiment.ExperimentRegistry.Load(int(req.ExperimentId)) - if !ok { - return nil, api.NotFoundErrs("experiment", strconv.Itoa(int(req.ExperimentId)), true) - } - w, err := e.GetSearcherEventsWatcher() - if err != nil { - return nil, status.Errorf(codes.Internal, - "failed to get searcher events: long polling %v", err) - } - defer func() { - if err := e.UnwatchEvents(w.ID); err != nil { - log.WithError(err).Errorf("error unwatching searcher events") - } - }() - - ctx, cancel := context.WithTimeout(ctx, time.Duration(60)*time.Second) - defer cancel() - - select { - case events := <-w.C: - return &apiv1.GetSearcherEventsResponse{ - SearcherEvents: events, - }, nil - case <-ctx.Done(): - return &apiv1.GetSearcherEventsResponse{ - SearcherEvents: nil, - }, nil - } -} - -func (a *apiServer) PostSearcherOperations( - ctx context.Context, - req *apiv1.PostSearcherOperationsRequest, -) ( - resp *apiv1.PostSearcherOperationsResponse, err error, -) { - _, _, err = a.getExperimentAndCheckCanDoActions( - ctx, int(req.ExperimentId), experiment.AuthZProvider.Get().CanRunCustomSearch, - ) - if err != nil { - return nil, errors.Wrap(err, "fetching experiment from database") - } - - e, ok := experiment.ExperimentRegistry.Load(int(req.ExperimentId)) - if !ok { - return nil, api.NotFoundErrs("experiment", strconv.Itoa(int(req.ExperimentId)), true) - } - if err := e.PerformSearcherOperations(req); err != nil { - return nil, status.Errorf(codes.Internal, "failed to post operations: %v", err) - } - - log.Infof("posted operations %v", req.SearcherOperations) - return &apiv1.PostSearcherOperationsResponse{}, nil -} - func (a *apiServer) GetExperiment( ctx context.Context, req *apiv1.GetExperimentRequest, ) (*apiv1.GetExperimentResponse, error) { @@ -913,44 +834,14 @@ func (a *apiServer) PreviewHPSearch( return nil, errors.Wrap(err, "invalid experiment configuration") } - sm := searcher.NewSearchMethod(sc) - s := searcher.NewSearcher(req.Seed, sm, hc) - sim, err := searcher.Simulate(s, nil, searcher.RandomValidation, true, sc.Metric()) + sim, err := searcher.Simulate(sc, hc) if err != nil { return nil, err } - protoSim := &experimentv1.ExperimentSimulation{Seed: req.Seed} - indexes := make(map[string]int, len(sim.Results)) - toProto := func(op searcher.ValidateAfter) ([]*experimentv1.RunnableOperation, error) { - return []*experimentv1.RunnableOperation{ - { - Type: experimentv1.RunnableType_RUNNABLE_TYPE_TRAIN, - Length: op.Length, - }, - { - Type: experimentv1.RunnableType_RUNNABLE_TYPE_VALIDATE, - }, - }, nil - } - for _, result := range sim.Results { - var operations []*experimentv1.RunnableOperation - for _, msg := range result { - ops, err := toProto(msg) - if err != nil { - return nil, errors.Wrapf(err, "error converting msg in simultion result %s", msg) - } - operations = append(operations, ops...) - } - hash := fmt.Sprint(operations) - if i, ok := indexes[hash]; ok { - protoSim.Trials[i].Occurrences++ - } else { - protoSim.Trials = append(protoSim.Trials, - &experimentv1.TrialSimulation{Operations: operations, Occurrences: 1}) - indexes[hash] = len(protoSim.Trials) - 1 - } - } - return &apiv1.PreviewHPSearchResponse{Simulation: protoSim}, nil + + return &apiv1.PreviewHPSearchResponse{ + Summary: sim.Proto(), + }, nil } func (a *apiServer) ActivateExperiment( @@ -1462,14 +1353,6 @@ func (a *apiServer) parseAndMergeContinueConfig(expID int, overrideConfig string providedConfig, err := expconf.ParseAnyExperimentConfigYAML([]byte(overrideConfig)) if err != nil { - // Add a helpful error message if a user just submits - // searcher.max_length.batches = 2. They would also need - // searcher.name = "single", which all experiments will always be here. - if strings.Contains(err.Error(), `unknown field "max_length"`) { - return nil, false, status.Errorf(codes.InvalidArgument, - `unknown field "max_length", you might also need to specify searcher.name=single`) - } - return nil, false, status.Errorf(codes.InvalidArgument, fmt.Errorf("parsing override config: %w", err).Error()) } @@ -2193,7 +2076,6 @@ func (a *apiServer) fetchTrialSample(trialID int32, metricName string, metricGro var err error var trial apiv1.TrialsSampleResponse_Trial var metricMeasurements []db.MetricMeasurements - xAxisLabelMetrics := []string{"epoch"} trial.TrialId = trialID @@ -2212,7 +2094,7 @@ func (a *apiServer) fetchTrialSample(trialID int32, metricName string, metricGro } metricMeasurements, err = trials.MetricsTimeSeries(trialID, startTime, []string{metricName}, - startBatches, endBatches, xAxisLabelMetrics, maxDatapoints, + startBatches, endBatches, maxDatapoints, "batches", nil, metricGroup) if err != nil { return nil, errors.Wrapf(err, "error fetching time series of metrics") diff --git a/master/internal/api_experiment_intg_test.go b/master/internal/api_experiment_intg_test.go index 8b07e8c3e05..38c79146854 100644 --- a/master/internal/api_experiment_intg_test.go +++ b/master/internal/api_experiment_intg_test.go @@ -478,7 +478,7 @@ func TestHPSearchContinueProvideConfigError(t *testing.T) { _, err := db.Bun().NewUpdate().Table("experiments"). Set("state = ?", model.CompletedState). Set("config = jsonb_set(config, '{searcher}', "+ - `'{"name": "random", "metric": "loss", "max_trials": 5, "max_length": 5}', false)`). + `'{"name": "random", "metric": "loss", "max_trials": 5}', false)`). Where("id = ?", trial.ExperimentID). Exec(ctx) require.NoError(t, err) @@ -497,7 +497,7 @@ func TestHPSearchContinueCompletedError(t *testing.T) { _, err := db.Bun().NewUpdate().Table("experiments"). Set("state = ?", model.CompletedState). Set("config = jsonb_set(config, '{searcher}', "+ - `'{"name": "random", "metric": "loss", "max_trials": 5, "max_length": 5}', false)`). + `'{"name": "random", "metric": "loss", "max_trials": 5}', false)`). Where("id = ?", trial.ExperimentID). Exec(ctx) require.NoError(t, err) @@ -605,26 +605,9 @@ workspace: test _, _, err = api.parseAndMergeContinueConfig(exp.ID, ` searcher: - max_length: - batches: 10 -`) - require.ErrorContains(t, err, "you might also need to specify searcher.name=single") - - _, _, err = api.parseAndMergeContinueConfig(exp.ID, ` -searcher: - max_length: - batches: 10 - name: single -`) - require.NoError(t, err) - - _, _, err = api.parseAndMergeContinueConfig(exp.ID, ` -searcher: name: random metric: accuracy max_trials: 5 - max_length: - batches: 1000 `) require.ErrorContains(t, err, "override config must have single searcher type got 'random' instead") @@ -640,8 +623,6 @@ searcher: smaller_is_better: true name: random max_trials: 3 - max_length: - batches: 10 resources: resource_pool: kubernetes` createReq := &apiv1.CreateExperimentRequest{ @@ -664,8 +645,6 @@ searcher: name: random metric: accuracy max_trials: 5 - max_length: - batches: 1000 `) require.ErrorContains(t, err, "override config is provided and experiment is not single searcher, got 'random' instead") @@ -685,7 +664,6 @@ entrypoint: test searcher: metric: loss name: single - max_length: 10 resources: resource_pool: kubernetes` createReq := &apiv1.CreateExperimentRequest{ @@ -1443,7 +1421,6 @@ func createTestExpWithActiveConfig( Config: activeConfig.AsLegacy(), } require.NoError(t, api.m.db.AddExperiment(exp, []byte{10, 11, 12}, activeConfig)) - // Get experiment as our API mostly will to make it easier to mock. exp, err := db.ExperimentByID(context.TODO(), exp.ID) require.NoError(t, err) diff --git a/master/internal/api_logretention_intg_test.go b/master/internal/api_logretention_intg_test.go index 51f413bcd86..ba16197b0c6 100644 --- a/master/internal/api_logretention_intg_test.go +++ b/master/internal/api_logretention_intg_test.go @@ -43,10 +43,10 @@ retention_policy: func setRetentionTime(timestamp string) error { _, err := db.Bun().NewRaw(fmt.Sprintf(` - CREATE or REPLACE FUNCTION retention_timestamp() RETURNS TIMESTAMPTZ AS $$ + CREATE or REPLACE FUNCTION retention_timestamp() RETURNS TIMESTAMPTZ AS $$ BEGIN RETURN %s; - END + END $$ LANGUAGE PLPGSQL; `, timestamp)).Exec(context.Background()) return err @@ -96,10 +96,9 @@ hyperparameters: searcher: name: grid metric: none - max_length: %d max_concurrent_trials: %d %s -`, numTrials, numTrials, config) +`, numTrials, config) createReq := &apiv1.CreateExperimentRequest{ ModelDefinition: []*utilv1.File{{Content: []byte{1}}}, Config: conf, diff --git a/master/internal/api_runs.go b/master/internal/api_runs.go index 807540947fc..69059b51a29 100644 --- a/master/internal/api_runs.go +++ b/master/internal/api_runs.go @@ -548,7 +548,6 @@ func (a *apiServer) KillRuns(ctx context.Context, req *apiv1.KillRunsRequest, type killRunOKResult struct { ID int32 - RequestID *string IsTerminal bool } @@ -557,7 +556,6 @@ func (a *apiServer) KillRuns(ctx context.Context, req *apiv1.KillRunsRequest, Model(&killCandidatees). Join("LEFT JOIN trials_v2 t ON r.id=t.run_id"). Column("r.id"). - ColumnExpr("t.request_id"). ColumnExpr("r.state IN (?) AS is_terminal", bun.In(model.StatesToStrings(model.TerminalStates))). Where("r.project_id = ?", req.ProjectId) @@ -592,13 +590,6 @@ func (a *apiServer) KillRuns(ctx context.Context, req *apiv1.KillRunsRequest, Error: "", Id: cand.ID, }) - // This should be impossible in the current system but we will leave this check here - // to cover a possible error in integration tests - case cand.RequestID == nil: - results = append(results, &apiv1.RunActionResult{ - Error: "Run has no associated request id.", - Id: cand.ID, - }) default: validIDs = append(validIDs, cand.ID) } diff --git a/master/internal/api_trials.go b/master/internal/api_trials.go index 3be8b83b295..ad1ca9ca10d 100644 --- a/master/internal/api_trials.go +++ b/master/internal/api_trials.go @@ -36,7 +36,6 @@ import ( "github.com/determined-ai/determined/proto/pkg/apiv1" "github.com/determined-ai/determined/proto/pkg/checkpointv1" "github.com/determined-ai/determined/proto/pkg/commonv1" - "github.com/determined-ai/determined/proto/pkg/experimentv1" "github.com/determined-ai/determined/proto/pkg/trialv1" ) @@ -847,10 +846,6 @@ func (a *apiServer) multiTrialSample(trialID int32, metricNames []string, ) ([]*apiv1.DownsampledMetrics, error) { var startTime time.Time var metrics []*apiv1.DownsampledMetrics - // For now "epoch" is the only custom xAxis metric label supported so we - // build the `MetricSeriesEpoch` array. In the future this logic should - // be updated to support any number of xAxis metric options - xAxisLabelMetrics := []string{"epoch"} if err := db.ValidatePolymorphicFilter(timeSeriesFilter); err != nil { return nil, err @@ -902,7 +897,6 @@ func (a *apiServer) multiTrialSample(trialID int32, metricNames []string, var metric apiv1.DownsampledMetrics metricMeasurements, err := trials.MetricsTimeSeries( trialID, startTime, aMetricNames, startBatches, endBatches, - xAxisLabelMetrics, maxDatapoints, *timeSeriesColumn, timeSeriesFilter, aMetricGroup) if err != nil { return nil, errors.Wrapf(err, fmt.Sprintf("error fetching time series of %s metrics", @@ -1378,73 +1372,6 @@ func (a *apiServer) MarkAllocationResourcesDaemon( return &apiv1.MarkAllocationResourcesDaemonResponse{}, nil } -func (a *apiServer) GetCurrentTrialSearcherOperation( - ctx context.Context, req *apiv1.GetCurrentTrialSearcherOperationRequest, -) (*apiv1.GetCurrentTrialSearcherOperationResponse, error) { - curUser, _, err := grpcutil.GetUser(ctx) - if err != nil { - return nil, err - } - if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(req.TrialId), curUser, - experiment.AuthZProvider.Get().CanGetExperimentArtifacts); err != nil { - return nil, err - } - eID, rID, err := a.m.db.TrialExperimentAndRequestID(int(req.TrialId)) - if err != nil { - return nil, err - } - - e, ok := experiment.ExperimentRegistry.Load(eID) - if !ok { - return nil, api.NotFoundErrs("experiment", strconv.Itoa(eID), true) - } - resp, err := e.TrialGetSearcherState(rID) - if err != nil { - return nil, err - } - - return &apiv1.GetCurrentTrialSearcherOperationResponse{ - Op: &experimentv1.TrialOperation{ - Union: &experimentv1.TrialOperation_ValidateAfter{ - ValidateAfter: resp.Op.ToProto(), - }, - }, - Completed: resp.Complete, - }, nil -} - -func (a *apiServer) CompleteTrialSearcherValidation( - ctx context.Context, req *apiv1.CompleteTrialSearcherValidationRequest, -) (*apiv1.CompleteTrialSearcherValidationResponse, error) { - curUser, _, err := grpcutil.GetUser(ctx) - if err != nil { - return nil, err - } - if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(req.TrialId), curUser, - experiment.AuthZProvider.Get().CanEditExperiment); err != nil { - return nil, err - } - eID, rID, err := a.m.db.TrialExperimentAndRequestID(int(req.TrialId)) - if err != nil { - return nil, err - } - - e, ok := experiment.ExperimentRegistry.Load(eID) - if !ok { - return nil, api.NotFoundErrs("experiment", strconv.Itoa(eID), true) - } - - msg := experiment.TrialCompleteOperation{ - RequestID: rID, - Metric: req.CompletedOperation.SearcherMetric.AsInterface(), - Op: searcher.NewValidateAfter(rID, req.CompletedOperation.Op.Length), - } - if err := e.TrialCompleteOperation(msg); err != nil { - return nil, err - } - return &apiv1.CompleteTrialSearcherValidationResponse{}, nil -} - func (a *apiServer) ReportTrialSearcherEarlyExit( ctx context.Context, req *apiv1.ReportTrialSearcherEarlyExitRequest, ) (*apiv1.ReportTrialSearcherEarlyExitResponse, error) { @@ -1503,11 +1430,10 @@ func (a *apiServer) ReportTrialProgress( } msg := experiment.TrialReportProgress{ - RequestID: rID, - Progress: searcher.PartialUnits(req.Progress), - IsRaw: req.IsRaw, + Progress: searcher.PartialUnits(req.Progress), + IsRaw: req.IsRaw, } - if err := e.TrialReportProgress(msg); err != nil { + if err := e.TrialReportProgress(rID, msg); err != nil { return nil, err } return &apiv1.ReportTrialProgressResponse{}, nil @@ -1528,6 +1454,22 @@ func (a *apiServer) ReportTrialMetrics( experiment.AuthZProvider.Get().CanEditExperiment); err != nil { return nil, err } + if metricGroup == model.ValidationMetricGroup { + // Notify searcher of validation metrics. + eID, rID, err := a.m.db.TrialExperimentAndRequestID(int(req.Metrics.TrialId)) + if err != nil { + return nil, errors.Errorf("Failed to get experiment ID from trial ID (%d)", req.Metrics.TrialId) + } + e, ok := experiment.ExperimentRegistry.Load(eID) + if ok { + // Report validation metrics to the searcher. Skip for experiments (such as detached mode) + // that are not already loaded in the master ExperimentRegistry. + err = e.TrialReportValidation(rID, req.Metrics.Metrics.AvgMetrics.AsMap()) + if err != nil { + return nil, err + } + } + } if err := a.m.db.AddTrialMetrics(ctx, req.Metrics, metricGroup); err != nil { return nil, err } diff --git a/master/internal/api_trials_intg_test.go b/master/internal/api_trials_intg_test.go index 66883903279..3ec3d731414 100644 --- a/master/internal/api_trials_intg_test.go +++ b/master/internal/api_trials_intg_test.go @@ -45,16 +45,18 @@ func createTestTrial( ) (*model.Trial, *model.Task) { exp := createTestExpWithProjectID(t, api, curUser, 1) + requestID := model.NewRequestID(rand.Reader) task := &model.Task{ TaskType: model.TaskTypeTrial, LogVersion: model.TaskLogVersion1, StartTime: time.Now(), - TaskID: trialTaskID(exp.ID, model.NewRequestID(rand.Reader)), + TaskID: trialTaskID(exp.ID, requestID), } require.NoError(t, db.AddTask(context.TODO(), task)) trial := &model.Trial{ StartTime: time.Now(), + RequestID: &requestID, State: model.PausedState, ExperimentID: exp.ID, } @@ -746,20 +748,6 @@ func TestTrialAuthZ(t *testing.T) { }) return err }, false}, - {"CanGetExperimentArtifacts", func(id int) error { - _, err := api.GetCurrentTrialSearcherOperation(ctx, - &apiv1.GetCurrentTrialSearcherOperationRequest{ - TrialId: int32(id), - }) - return err - }, false}, - {"CanEditExperiment", func(id int) error { - _, err := api.CompleteTrialSearcherValidation(ctx, - &apiv1.CompleteTrialSearcherValidationRequest{ - TrialId: int32(id), - }) - return err - }, false}, {"CanEditExperiment", func(id int) error { _, err := api.ReportTrialSearcherEarlyExit(ctx, &apiv1.ReportTrialSearcherEarlyExitRequest{ diff --git a/master/internal/core.go b/master/internal/core.go index 12df26b972c..a42dca3534c 100644 --- a/master/internal/core.go +++ b/master/internal/core.go @@ -1512,9 +1512,6 @@ func (m *Master) Run(ctx context.Context, gRPCLogInitDone chan struct{}) error { checkpointsGroup := m.echo.Group("/checkpoints") checkpointsGroup.GET("/:checkpoint_uuid", m.getCheckpoint) - searcherGroup := m.echo.Group("/searcher") - searcherGroup.POST("/preview", api.Route(m.getSearcherPreview)) - resourcesGroup := m.echo.Group("/resources", cluster.CanGetUsageDetails()) resourcesGroup.GET("/allocation/raw", m.getRawResourceAllocation) resourcesGroup.GET("/allocation/allocations-csv", m.getResourceAllocations) diff --git a/master/internal/core_searcher.go b/master/internal/core_searcher.go index f543b1ccc72..f80c1a13b94 100644 --- a/master/internal/core_searcher.go +++ b/master/internal/core_searcher.go @@ -1,58 +1,9 @@ package internal import ( - "io" - - "github.com/labstack/echo/v4" - "github.com/pkg/errors" log "github.com/sirupsen/logrus" - - "github.com/determined-ai/determined/master/pkg/schemas" - "github.com/determined-ai/determined/master/pkg/schemas/expconf" - "github.com/determined-ai/determined/master/pkg/searcher" ) -func (m *Master) getSearcherPreview(c echo.Context) (interface{}, error) { - bytes, err := io.ReadAll(c.Request().Body) - if err != nil { - return nil, err - } - - // Parse the provided experiment config. - config, err := expconf.ParseAnyExperimentConfigYAML(bytes) - if err != nil { - return nil, errors.Wrapf(err, "invalid experiment configuration") - } - - // Get the useful subconfigs for preview search. - if config.RawSearcher == nil { - return nil, errors.New("invalid experiment configuration; missing searcher") - } - sc := *config.RawSearcher - hc := config.RawHyperparameters - - // Apply any json-schema-defined defaults. - sc = schemas.WithDefaults(sc) - hc = schemas.WithDefaults(hc) - - // Make sure the searcher config has all eventuallyRequired fields. - if err = schemas.IsComplete(sc); err != nil { - return nil, errors.Wrapf(err, "invalid searcher configuration") - } - if err = schemas.IsComplete(hc); err != nil { - return nil, errors.Wrapf(err, "invalid hyperparameters configuration") - } - - // Disallow EOL searchers. - if err = sc.AssertCurrent(); err != nil { - return nil, errors.Wrap(err, "invalid experiment configuration") - } - - sm := searcher.NewSearchMethod(sc) - s := searcher.NewSearcher(0, sm, hc) - return searcher.Simulate(s, nil, searcher.RandomValidation, true, config.Searcher().Metric()) -} - // cleanUpExperimentSnapshots deletes all snapshots for terminal state experiments from // the database. func (m *Master) cleanUpExperimentSnapshots() { diff --git a/master/internal/db/postgres_experiments.go b/master/internal/db/postgres_experiments.go index 6346aa60117..1bd04090ad5 100644 --- a/master/internal/db/postgres_experiments.go +++ b/master/internal/db/postgres_experiments.go @@ -387,25 +387,6 @@ LIMIT 1`, metricOrdering), exp.Config.Searcher.Metric, id).Scan(ctx, &metric); e return metric, nil } -// TrialExperimentAndRequestID returns the trial's experiment and request ID. -func (db *PgDB) TrialExperimentAndRequestID(id int) (int, model.RequestID, error) { - var eID int - var rID model.RequestID - err := db.sql.QueryRow(` -SELECT e.id, t.request_id -FROM trials t, experiments e -WHERE t.experiment_id = e.id - AND t.id = $1`, id).Scan(&eID, &rID) - switch { - case err == sql.ErrNoRows: - return eID, rID, errors.WithStack(ErrNotFound) - case err != nil: - return eID, rID, errors.Wrap(err, "failed to get trial exp and req id") - default: - return eID, rID, nil - } -} - // ExperimentConfigRaw returns the full config object for an experiment as a JSON string. func (db *PgDB) ExperimentConfigRaw(id int) ([]byte, error) { return db.rawQuery(` @@ -610,6 +591,25 @@ SELECT experiment_id FROM trials where id = $1 return experimentID, nil } +// TrialExperimentAndRequestID returns the trial's experiment and request ID. +func (db *PgDB) TrialExperimentAndRequestID(id int) (int, model.RequestID, error) { + var eID int + var rID model.RequestID + err := db.sql.QueryRow(` +SELECT e.id, t.request_id +FROM trials t, experiments e +WHERE t.experiment_id = e.id + AND t.id = $1`, id).Scan(&eID, &rID) + switch { + case err == sql.ErrNoRows: + return eID, rID, errors.WithStack(ErrNotFound) + case err != nil: + return eID, rID, errors.Wrap(err, "failed to get trial exp and req id") + default: + return eID, rID, nil + } +} + // NonTerminalExperiments finds all experiments in the database whose states are not terminal. func (db *PgDB) NonTerminalExperiments() ([]*model.Experiment, error) { rows, err := db.sql.Queryx(` diff --git a/master/internal/db/postgres_experiments_intg_test.go b/master/internal/db/postgres_experiments_intg_test.go index 68cab82424c..df20f439f64 100644 --- a/master/internal/db/postgres_experiments_intg_test.go +++ b/master/internal/db/postgres_experiments_intg_test.go @@ -771,12 +771,12 @@ func TestDeleteExperiments(t *testing.T) { // Create experiment snapshot //nolint:exhaustruct config := expconf.SearcherConfig{ - RawCustomConfig: &expconf.CustomConfig{}, + RawSingleConfig: &expconf.SingleConfigV0{}, } searcher1 := searcher.NewSearcher(3, searcher.NewSearchMethod(config), nil) - _, err := searcher1.InitialOperations() + _, err := searcher1.InitialTrials() require.NoError(t, err) - _, err = searcher1.TrialExitedEarly(model.RequestID(uuid.New()), model.Errored) + _, err = searcher1.TrialExitedEarly(model.RequestID{}, model.Errored) require.NoError(t, err) snapshot, err := searcher1.Snapshot() diff --git a/master/internal/db/postgres_snapshots_test.go b/master/internal/db/postgres_snapshots_test.go deleted file mode 100644 index ebcbb8e514b..00000000000 --- a/master/internal/db/postgres_snapshots_test.go +++ /dev/null @@ -1,66 +0,0 @@ -//go:build integration -// +build integration - -package db - -import ( - "context" - "testing" - - "github.com/google/uuid" - - "github.com/stretchr/testify/require" - - "github.com/determined-ai/determined/master/pkg/etc" - "github.com/determined-ai/determined/master/pkg/model" - "github.com/determined-ai/determined/master/pkg/schemas/expconf" - "github.com/determined-ai/determined/master/pkg/searcher" -) - -func TestCustomSearcherSnapshot(t *testing.T) { - err := etc.SetRootPath(RootFromDB) - require.NoError(t, err) - db, closeDB := MustResolveTestPostgres(t) - defer closeDB() - MustMigrateTestPostgres(t, db, MigrationsFromDB) - user := RequireMockUser(t, db) - exp := RequireMockExperiment(t, db, user) - - //nolint:exhaustruct - config := expconf.SearcherConfig{ - RawCustomConfig: &expconf.CustomConfig{}, - } - - // Create a searcher and add some operations to it. - searcher1 := searcher.NewSearcher(3, searcher.NewSearchMethod(config), nil) - _, err = searcher1.InitialOperations() - require.NoError(t, err) - _, err = searcher1.TrialExitedEarly(model.RequestID(uuid.New()), model.Errored) - require.NoError(t, err) - - // Save snapshot to database. - snapshot, err := searcher1.Snapshot() - require.NoError(t, err) - err = db.SaveSnapshot(exp.ID, 2, snapshot) - require.NoError(t, err) - - // Retrieve snapshot from database. - restoredSnapshot, _, err := db.ExperimentSnapshot(exp.ID) - require.NoError(t, err) - - // Verify that restoring the snapshot yields a searcher in the same state as before. - searcher2 := searcher.NewSearcher(4, searcher.NewSearchMethod(config), nil) - err = searcher2.Restore(restoredSnapshot) - require.NoError(t, err) - queue1, err := searcher1.GetCustomSearcherEventQueue() - require.NoError(t, err) - queue2, err := searcher2.GetCustomSearcherEventQueue() - require.NoError(t, err) - require.Equal(t, queue1.GetEvents(), queue2.GetEvents()) - - err = db.DeleteSnapshotsForExperiment(exp.ID) - require.NoError(t, err) - ctx := context.Background() - err = db.DeleteExperiments(ctx, []int{exp.ID}) - require.NoError(t, err) -} diff --git a/master/internal/db/postgres_test_utils.go b/master/internal/db/postgres_test_utils.go index 33d0f52ef1f..e086bf8933d 100644 --- a/master/internal/db/postgres_test_utils.go +++ b/master/internal/db/postgres_test_utils.go @@ -361,13 +361,8 @@ func RequireMockExperimentParams( }, }, RawSearcher: &expconf.SearcherConfigV0{ - RawSingleConfig: &expconf.SingleConfigV0{ - RawMaxLength: &expconf.LengthV0{ - Unit: expconf.Batches, - Units: 1, - }, - }, - RawMetric: ptrs.Ptr(defaultSearcherMetric), + RawSingleConfig: &expconf.SingleConfigV0{}, + RawMetric: ptrs.Ptr(defaultSearcherMetric), }, } if p.HParamNames != nil { diff --git a/master/internal/db/postgres_trial.go b/master/internal/db/postgres_trial.go index d1f03667962..4fd4fe2a8a6 100644 --- a/master/internal/db/postgres_trial.go +++ b/master/internal/db/postgres_trial.go @@ -276,19 +276,6 @@ func TrialTaskIDsByTrialID(ctx context.Context, trialID int) ([]*model.RunTaskID return ids, nil } -// TrialByExperimentAndRequestID looks up a trial, returning an error if none exists. -func TrialByExperimentAndRequestID( - ctx context.Context, experimentID int, requestID model.RequestID, -) (*model.Trial, error) { - t := &model.Trial{} - if err := Bun().NewSelect().Model(t). - Where("experiment_id = ?", experimentID). - Where("request_id = ?", requestID).Scan(ctx); err != nil { - return nil, fmt.Errorf("error querying for trial %s: %w", requestID, err) - } - return t, nil -} - // TrialByTaskID looks up a trial by taskID, returning an error if none exists. // This errors if you called it with a non trial taskID. func TrialByTaskID(ctx context.Context, taskID model.TaskID) (*model.Trial, error) { @@ -666,6 +653,11 @@ func (db *PgDB) addTrialMetrics( default: return 0, fmt.Errorf("cannot add metric with non numeric 'epoch' value got %v", v) } + switch v := m.Metrics.AvgMetrics.Fields["epochs"].AsInterface().(type) { + case float64, nil: + default: + return 0, fmt.Errorf("cannot add metric with non numeric 'epochs' value got %v", v) + } return rollbacks, db.withTransaction(fmt.Sprintf("add trial metrics %s", mGroup), func(tx *sqlx.Tx) error { switch { @@ -1078,3 +1070,16 @@ RETURNING true`, bun.In(uniqueExpIDs)).Scan(ctx, &res) return nil } + +// TrialByExperimentAndRequestID looks up a trial, returning an error if none exists. +func TrialByExperimentAndRequestID( + ctx context.Context, experimentID int, requestID model.RequestID, +) (*model.Trial, error) { + var t model.Trial + if err := Bun().NewSelect().Model(&t). + Where("experiment_id = ?", experimentID). + Where("request_id = ?", requestID).Scan(ctx); err != nil { + return nil, fmt.Errorf("error querying for trial %s: %w", requestID, err) + } + return &t, nil +} diff --git a/master/internal/experiment.go b/master/internal/experiment.go index 01f1c955176..2cee92b62fc 100644 --- a/master/internal/experiment.go +++ b/master/internal/experiment.go @@ -40,8 +40,6 @@ import ( "github.com/determined-ai/determined/master/pkg/searcher" "github.com/determined-ai/determined/master/pkg/ssh" "github.com/determined-ai/determined/master/pkg/tasks" - "github.com/determined-ai/determined/proto/pkg/apiv1" - "github.com/determined-ai/determined/proto/pkg/experimentv1" ) const ( @@ -303,7 +301,7 @@ func (e *internalExperiment) start() error { return nil } - ops, err := e.searcher.InitialOperations() + creates, err := e.searcher.InitialTrials() if err != nil { err = errors.Wrap(err, "failed to generate initial operations") e.updateState(model.StateWithReason{ @@ -312,72 +310,30 @@ func (e *internalExperiment) start() error { }) return err } - e.processOperations(ops, nil) + e.handleSearcherActions(creates, nil) return nil } -func (e *internalExperiment) TrialCompleteOperation(msg experiment.TrialCompleteOperation) error { - e.mu.Lock() - defer e.mu.Unlock() - - state, ok := e.TrialSearcherState[msg.Op.RequestID] - switch { - case !ok: - return api.AsValidationError("no such trial") - case msg.Op != state.Op: - return api.AsValidationError("expected op %v but received op %v", state.Op, msg.Op) - case state.Complete: - return api.AsValidationError("received op %v which was previously completed", msg.Op) - } - - defer func() { - ops, err := e.searcher.ValidationCompleted(msg.RequestID, msg.Metric, msg.Op) - e.processOperations(ops, err) - }() - - state.Complete = true - e.TrialSearcherState[msg.Op.RequestID] = state - - t, ok := e.trials[msg.Op.RequestID] - if !ok { - return api.AsErrNotFound("trial not found") - } - - err := t.PatchSearcherState(state) - if err != nil { - e.syslog.WithError(err).Error("patching trial search state") - return err - } - - return nil -} - -func (e *internalExperiment) TrialReportProgress(msg experiment.TrialReportProgress) error { +func (e *internalExperiment) TrialReportProgress(requestID model.RequestID, msg experiment.TrialReportProgress) error { e.mu.Lock() defer e.mu.Unlock() progress := float64(msg.Progress) - if !msg.IsRaw { - e.searcher.SetTrialProgress(msg.RequestID, msg.Progress) - progress = e.searcher.Progress() - } - - if err := e.db.SaveExperimentProgress(e.ID, &progress); err != nil { + e.searcher.SetTrialProgress(requestID, progress) + experimentProgress := e.searcher.Progress() + if err := e.db.SaveExperimentProgress(e.ID, &experimentProgress); err != nil { e.syslog.WithError(err).Error("failed to save experiment progress") } return nil } -func (e *internalExperiment) TrialGetSearcherState(requestID model.RequestID) (experiment.TrialSearcherState, error) { +func (e *internalExperiment) TrialReportValidation(requestID model.RequestID, metrics map[string]interface{}) error { e.mu.Lock() defer e.mu.Unlock() - - state, ok := e.TrialSearcherState[requestID] - if !ok { - return state, api.AsErrNotFound("trial has no state") - } - return state, nil + ops, err := e.searcher.ValidationCompleted(requestID, metrics) + e.handleSearcherActions(ops, err) + return nil } func (e *internalExperiment) UserInitiatedEarlyTrialExit(msg experiment.UserInitiatedEarlyTrialExit) error { @@ -527,88 +483,6 @@ func (e *internalExperiment) stop() error { return nil } -func (e *internalExperiment) PerformSearcherOperations(msg *apiv1.PostSearcherOperationsRequest) error { - e.mu.Lock() - defer e.mu.Unlock() - - queue, err := e.searcher.GetCustomSearcherEventQueue() - if err != nil { - return status.Error(codes.Internal, err.Error()) - } - var ops []searcher.Operation - for _, searcherOp := range msg.SearcherOperations { - switch concreteOperation := searcherOp.GetUnion().(type) { - case *experimentv1.SearcherOperation_CreateTrial: - op, err := searcher.CreateFromProto(concreteOperation, model.TrialWorkloadSequencerType) - if err != nil { - e.syslog.Error(err) - } else { - ops = append(ops, *op) - } - case *experimentv1.SearcherOperation_ShutDown: - op, err := searcher.ShutdownFromProto(concreteOperation) - if err != nil { - e.syslog.Error(err) - } else { - ops = append(ops, *op) - } - case *experimentv1.SearcherOperation_TrialOperation: - if sub, ok := concreteOperation.TrialOperation.GetUnion().(*experimentv1.TrialOperation_ValidateAfter); ok { - op, err := searcher.ValidateAfterFromProto(sub) - if err != nil { - e.syslog.Error(err) - } else { - ops = append(ops, *op) - } - } - case *experimentv1.SearcherOperation_CloseTrial: - op, err := searcher.CloseFromProto(concreteOperation) - if err != nil { - e.syslog.Error(err) - } else { - ops = append(ops, *op) - } - case *experimentv1.SearcherOperation_SetSearcherProgress: - ops = append(ops, searcher.SetSearcherProgressFromProto(concreteOperation)) - default: - e.syslog.Errorf("unimplemented op %+v", concreteOperation) - } - } - e.syslog.Infof("processing searcher operations %+v", ops) - - // Remove newly processed events from queue. - if err := queue.RemoveUpTo(int(msg.TriggeredByEvent.Id)); err != nil { - return status.Error(codes.Internal, "failed to remove events from queue") - } - e.searcher.Record(ops) - e.processOperations(ops, nil) - return nil -} - -func (e *internalExperiment) GetSearcherEventsWatcher() (*searcher.EventsWatcher, error) { - e.mu.Lock() - defer e.mu.Unlock() - - queue, err := e.searcher.GetCustomSearcherEventQueue() - if err != nil { - return nil, status.Error(codes.Internal, err.Error()) - } - watcher, err := queue.Watch() - return &watcher, err -} - -func (e *internalExperiment) UnwatchEvents(id uuid.UUID) error { - e.mu.Lock() - defer e.mu.Unlock() - - queue, err := e.searcher.GetCustomSearcherEventQueue() - if err != nil { - return status.Error(codes.Internal, err.Error()) - } - queue.Unwatch(id) - return nil -} - func (e *internalExperiment) ActivateExperiment() error { e.mu.Lock() defer e.mu.Unlock() @@ -672,21 +546,21 @@ func (e *internalExperiment) KillExperiment() error { return nil } -func (e *internalExperiment) TrialClosed(requestID model.RequestID, reason *model.ExitedReason) { +func (e *internalExperiment) TrialExited(requestID model.RequestID, reason *model.ExitedReason) { e.mu.Lock() defer e.mu.Unlock() - e.trialClosed(requestID, reason) + e.trialExited(requestID, reason) } -func (e *internalExperiment) trialClosed(requestID model.RequestID, reason *model.ExitedReason) { +func (e *internalExperiment) trialExited(requestID model.RequestID, reason *model.ExitedReason) { if reason != nil { e.trialReportEarlyExit(requestID, *reason) } delete(e.trials, requestID) - ops, err := e.searcher.TrialClosed(requestID) - e.processOperations(ops, err) + ops, err := e.searcher.TrialExited(requestID) + e.handleSearcherActions(ops, err) if e.canTerminate() { if err := e.stop(); err != nil { e.syslog.WithError(err).Error("failed to stop experiment on trial closed") @@ -695,25 +569,24 @@ func (e *internalExperiment) trialClosed(requestID model.RequestID, reason *mode } func (e *internalExperiment) trialReportEarlyExit(requestID model.RequestID, reason model.ExitedReason) { - e.syslog.WithField("requestId", requestID).Info("experiment received trial early exit") + e.syslog.WithField("request-id", requestID).Info("experiment received trial early exit") state, ok := e.TrialSearcherState[requestID] if !ok { - e.syslog.WithField("requestID", requestID).Error("trial has no searcher state on early exit") + e.syslog.WithField("request-id", requestID).Error("trial has no searcher state on early exit") return } defer func() { ops, err := e.searcher.TrialExitedEarly(requestID, reason) - e.processOperations(ops, err) + e.handleSearcherActions(ops, err) }() - state.Complete = true - state.Closed = true + state.EarlyExitedByUserCode = true e.TrialSearcherState[requestID] = state t, ok := e.trials[requestID] if !ok { - e.syslog.WithField("requestID", requestID).Warnf("missing trial to patch on early exit") + e.syslog.WithField("trial-id", requestID).Warnf("missing trial to patch on early exit") return } @@ -726,8 +599,8 @@ func (e *internalExperiment) trialReportEarlyExit(requestID model.RequestID, rea func (e *internalExperiment) trialCreated(t *trial) { requestID := t.searcher.Create.RequestID if !e.searcher.TrialIsCreated(requestID) { - ops, err := e.searcher.TrialCreated(requestID) - e.processOperations(ops, err) + actions, err := e.searcher.TrialCreated(requestID) + e.handleSearcherActions(actions, err) } e.trials[requestID] = t } @@ -736,51 +609,16 @@ func (e *internalExperiment) trialCreated(t *trial) { // last experiment checkpoint. func (e *internalExperiment) restoreTrials() { for _, state := range e.TrialSearcherState { - checkpoint, err := e.checkpointForCreate(state.Create) - if err != nil { - e.updateState(model.StateWithReason{ - State: model.StoppingErrorState, - InformationalReason: fmt.Sprintf("failed getting checkpoint to restore with error %v", err), - }) - e.syslog.Error(err) - return - } - e.restoreTrial(checkpoint, state) + e.restoreTrial(e.warmStartCheckpoint, state) } } -func (e *internalExperiment) handleContinueExperiment(reqID model.RequestID) (*int, bool) { - var continueFromTrialID *int - if e.continueTrials { - switch trial, err := internaldb.TrialByExperimentAndRequestID(context.TODO(), e.ID, reqID); { - case errors.Is(err, sql.ErrNoRows): - // Trial doesn't exist, don't do anything - case err != nil: - e.updateState(model.StateWithReason{ - State: model.StoppingErrorState, - InformationalReason: fmt.Sprintf( - "hp search unable to get trial for the Request ID %v with error %v", reqID, err), - }) - e.syslog.Error(err) - return nil, true - case err == nil: - if trial.State != model.CompletedState { - continueFromTrialID = &trial.ID - } else { - e.trialClosed(reqID, nil) - return nil, true - } - } - } - return continueFromTrialID, false -} - -func (e *internalExperiment) processOperations( - ops []searcher.Operation, err error, +func (e *internalExperiment) handleSearcherActions( + actions []searcher.Action, err error, ) { // Only continue for experiments in stopping states if the searcher operations are all // type Shutdown failures. - if _, ok := model.StoppingStates[e.State]; ok && !allSearcherShutdowns(ops) { + if _, ok := model.StoppingStates[e.State]; ok && !allSearcherShutdowns(actions) { return } @@ -796,77 +634,57 @@ func (e *internalExperiment) processOperations( defer e.snapshotAndSave() updatedTrials := make(map[model.RequestID]bool) - for _, operation := range ops { - e.syslog.Debugf("handling searcher op: %v", operation) - switch op := operation.(type) { + for _, action := range actions { + e.syslog.Debugf("handling searcher action: %v", action) + switch action := action.(type) { case searcher.Create: - _, ok := e.trials[op.RequestID] + _, ok := e.trials[action.RequestID] if ok { - e.syslog.Errorf("trial %s already exists", op.RequestID) + e.syslog.Errorf("trial %s already exists", action.RequestID) continue } - continueFromTrialID, closed := e.handleContinueExperiment(op.RequestID) + continueFromTrialID, closed := e.handleContinueExperiment(action.RequestID) if closed { continue } + state := experiment.TrialSearcherState{Create: action} + e.TrialSearcherState[action.RequestID] = state - checkpoint, err := e.checkpointForCreate(op) - if err != nil { - e.updateState(model.StateWithReason{ - State: model.StoppingErrorState, - InformationalReason: fmt.Sprintf( - "hp search unable to get checkpoint for new trial with error %v", err), - }) - e.syslog.Error(err) - continue - } config := schemas.Copy(e.activeConfig) - state := experiment.TrialSearcherState{Create: op, Complete: true} - e.TrialSearcherState[op.RequestID] = state clonedSpec, err := e.taskSpec.Clone() if err != nil { e.syslog.WithError(err).Error("failed to create trial") - e.trialClosed(op.RequestID, ptrs.Ptr(model.Errored)) + e.trialExited(action.RequestID, ptrs.Ptr(model.Errored)) continue } t, err := newTrial( - e.logCtx, trialTaskID(e.ID, op.RequestID), e.JobID, e.StartTime, e.ID, e.State, - state, e.rm, e.db, config, checkpoint, clonedSpec, e.generatedKeys, false, - nil, continueFromTrialID, e.TrialClosed, + e.logCtx, trialTaskID(e.ID, action.RequestID), e.JobID, e.StartTime, e.ID, e.State, + state, e.rm, e.db, config, e.warmStartCheckpoint, clonedSpec, e.generatedKeys, false, + nil, continueFromTrialID, e.TrialExited, ) if err != nil { e.syslog.WithError(err).Error("failed to create trial") - e.trialClosed(op.RequestID, ptrs.Ptr(model.Errored)) + e.trialExited(action.RequestID, ptrs.Ptr(model.Errored)) continue } e.trialCreated(t) - case searcher.ValidateAfter: - state := e.TrialSearcherState[op.RequestID] - state.Op = op - state.Complete = false - e.TrialSearcherState[op.RequestID] = state - updatedTrials[op.RequestID] = true - case searcher.SetSearcherProgress: - if err := e.searcher.SetCustomSearcherProgress(op.Progress); err != nil { - e.syslog.WithError(err).Error("failed to set searcher progress") - } - case searcher.Close: - state := e.TrialSearcherState[op.RequestID] - state.Closed = true - e.TrialSearcherState[op.RequestID] = state - updatedTrials[op.RequestID] = true + case searcher.Stop: + state := e.TrialSearcherState[action.RequestID] + state.EarlyStoppedBySearcher = true + e.TrialSearcherState[action.RequestID] = state + updatedTrials[action.RequestID] = true case searcher.Shutdown: - e.syslog.WithField("op", operation).Info("searcher shutdown") + e.syslog.WithField("action", action).Info("searcher shutdown") switch { - case op.Failure: + case action.Failure: e.updateState(model.StateWithReason{ State: model.StoppingErrorState, InformationalReason: "hp search failed", }) - case op.Cancel: + case action.Cancel: e.updateState(model.StateWithReason{ State: model.StoppingCanceledState, InformationalReason: "hp search canceled", @@ -878,23 +696,23 @@ func (e *internalExperiment) processOperations( }) } default: - panic(fmt.Sprintf("unexpected operation: %v", op)) + panic(fmt.Sprintf("unexpected action: %v", action)) } } var g errgroup.Group g.SetLimit(maxConcurrentTrialOps) - for requestID := range updatedTrials { - syslog := e.syslog.WithField("requestID", requestID) - t, ok := e.trials[requestID] + for rID := range updatedTrials { + syslog := e.syslog.WithField("trial-id", rID) + t, ok := e.trials[rID] if !ok { - syslog.Errorf("processOperations invalid requestID") + syslog.Errorf("handleSearcherActions invalid trialID") continue } g.Go(func() error { - err := t.PatchSearcherState(e.TrialSearcherState[requestID]) + err := t.PatchSearcherState(e.TrialSearcherState[rID]) if err != nil { - syslog.WithError(err).Error("processOperations updating trial search state") + syslog.WithError(err).Error("handleSearcherActions updating trial search state") } return nil }) @@ -902,6 +720,32 @@ func (e *internalExperiment) processOperations( _ = g.Wait() // Errors are handled in g.Go. } +func (e *internalExperiment) handleContinueExperiment(reqID model.RequestID) (*int, bool) { + var continueFromTrialID *int + if e.continueTrials { + switch trial, err := internaldb.TrialByExperimentAndRequestID(context.TODO(), e.ID, reqID); { + case errors.Is(err, sql.ErrNoRows): + // Trial doesn't exist, don't do anything + case err != nil: + e.updateState(model.StateWithReason{ + State: model.StoppingErrorState, + InformationalReason: fmt.Sprintf( + "hp search unable to get trial for the Request ID %v with error %v", reqID, err), + }) + e.syslog.Error(err) + return nil, true + case err == nil: + if trial.State != model.CompletedState { + continueFromTrialID = &trial.ID + } else { + e.trialExited(reqID, nil) + return nil, true + } + } + } + return continueFromTrialID, false +} + func trialTaskID(eID int, rID model.RequestID) model.TaskID { return model.TaskID(fmt.Sprintf("%d.%s", eID, rID)) } @@ -925,24 +769,6 @@ func experimentIDFromTrialTaskID(taskID model.TaskID) (int, error) { return experimentID, nil } -func (e *internalExperiment) checkpointForCreate(op searcher.Create) (*model.Checkpoint, error) { - checkpoint := e.warmStartCheckpoint - // If the Create specifies a checkpoint, ignore the experiment-wide one. - if op.Checkpoint != nil { - trial, err := internaldb.TrialByExperimentAndRequestID(context.TODO(), e.ID, op.Checkpoint.RequestID) - if err != nil { - return nil, errors.Wrapf(err, - "invalid request ID in Create operation: %d", op.Checkpoint.RequestID) - } - checkpointModel, err := checkpointFromTrialIDOrUUID(e.db, &trial.ID, nil) - if err != nil { - return nil, errors.Wrap(err, "checkpoint not found") - } - checkpoint = checkpointModel - } - return checkpoint, nil -} - func (e *internalExperiment) updateState(state model.StateWithReason) bool { if wasPatched, err := e.Transition(state.State); err != nil { e.syslog.Errorf("error transitioning experiment state: %s", err) @@ -1194,9 +1020,9 @@ func (e *internalExperiment) setRP(resourcePool string) error { return nil } -func allSearcherShutdowns(ops []searcher.Operation) bool { - for _, operation := range ops { - if _, ok := operation.(searcher.Shutdown); !ok { +func allSearcherShutdowns(actions []searcher.Action) bool { + for _, action := range actions { + if _, ok := action.(searcher.Shutdown); !ok { return false } } diff --git a/master/internal/experiment/authz_basic_impl.go b/master/internal/experiment/authz_basic_impl.go index ed7e02a3722..1e8488d0628 100644 --- a/master/internal/experiment/authz_basic_impl.go +++ b/master/internal/experiment/authz_basic_impl.go @@ -118,13 +118,6 @@ func (a *ExperimentAuthZBasic) CanSetExperimentsCheckpointGCPolicy( return nil } -// CanRunCustomSearch always returns a nil error. -func (a *ExperimentAuthZBasic) CanRunCustomSearch( - ctx context.Context, curUser model.User, e *model.Experiment, -) error { - return nil -} - func init() { AuthZProvider.Register("basic", &ExperimentAuthZBasic{}) } diff --git a/master/internal/experiment/authz_iface.go b/master/internal/experiment/authz_iface.go index 923d3ce13f8..51c21048bcb 100644 --- a/master/internal/experiment/authz_iface.go +++ b/master/internal/experiment/authz_iface.go @@ -111,10 +111,6 @@ type ExperimentAuthZ interface { CanSetExperimentsCheckpointGCPolicy( ctx context.Context, curUser model.User, e *model.Experiment, ) error - - // GET /api/v1/experiments/:exp_id/searcher_events - // POST /api/v1/experiments/:exp_id/searcher_operations - CanRunCustomSearch(ctx context.Context, curUser model.User, e *model.Experiment) error } // AuthZProvider is the authz registry for experiments. diff --git a/master/internal/experiment/authz_permissive.go b/master/internal/experiment/authz_permissive.go index bad32d29824..ed7faf8b136 100644 --- a/master/internal/experiment/authz_permissive.go +++ b/master/internal/experiment/authz_permissive.go @@ -127,14 +127,6 @@ func (p *ExperimentAuthZPermissive) CanSetExperimentsCheckpointGCPolicy( return (&ExperimentAuthZBasic{}).CanSetExperimentsCheckpointGCPolicy(ctx, curUser, e) } -// CanRunCustomSearch calls RBAC authz but enforces basic authz. -func (p *ExperimentAuthZPermissive) CanRunCustomSearch( - ctx context.Context, curUser model.User, e *model.Experiment, -) error { - _ = (&ExperimentAuthZRBAC{}).CanRunCustomSearch(ctx, curUser, e) - return (&ExperimentAuthZBasic{}).CanRunCustomSearch(ctx, curUser, e) -} - func init() { AuthZProvider.Register("permissive", &ExperimentAuthZPermissive{}) } diff --git a/master/internal/experiment/authz_rbac.go b/master/internal/experiment/authz_rbac.go index 25f876a0c21..6ea154cf147 100644 --- a/master/internal/experiment/authz_rbac.go +++ b/master/internal/experiment/authz_rbac.go @@ -377,13 +377,6 @@ func (a *ExperimentAuthZRBAC) CanSetExperimentsCheckpointGCPolicy( return a.CanEditExperiment(ctx, curUser, e) } -// CanRunCustomSearch checks if a user has permission to run customer search. -func (a *ExperimentAuthZRBAC) CanRunCustomSearch( - ctx context.Context, curUser model.User, e *model.Experiment, -) error { - return a.CanEditExperiment(ctx, curUser, e) // TODO verify with custom search project. -} - func init() { AuthZProvider.Register("rbac", &ExperimentAuthZRBAC{}) } diff --git a/master/internal/experiment/experiment_iface.go b/master/internal/experiment/experiment_iface.go index 854f78901bb..5320b4ab3e0 100644 --- a/master/internal/experiment/experiment_iface.go +++ b/master/internal/experiment/experiment_iface.go @@ -1,13 +1,10 @@ package experiment import ( - "github.com/google/uuid" - "github.com/determined-ai/determined/master/internal/rm/tasklist" "github.com/determined-ai/determined/master/internal/sproto" "github.com/determined-ai/determined/master/pkg/model" "github.com/determined-ai/determined/master/pkg/searcher" - "github.com/determined-ai/determined/proto/pkg/apiv1" ) // ExperimentRegistry is a registry of all experiments. @@ -17,20 +14,11 @@ var ExperimentRegistry = tasklist.NewRegistry[int, Experiment]() // Experiment-specific interface types. type ( - // TrialCompleteOperation is a message sent to an experiment to indicate that a trial has - // completed an operation. - TrialCompleteOperation struct { - RequestID model.RequestID - Op searcher.ValidateAfter - Metric interface{} - } - // TrialReportProgress is a message sent to an experiment to indicate that a trial has // reported progress. TrialReportProgress struct { - RequestID model.RequestID - Progress searcher.PartialUnits - IsRaw bool + Progress searcher.PartialUnits + IsRaw bool } // UserInitiatedEarlyTrialExit is a user-injected message, provided through the early exit API. It @@ -47,29 +35,24 @@ type ( State model.StateWithReason } - // TrialSearcherState is a message sent to an experiment to indicate that a trial has + // TrialSearcherState is a message sent to an search to indicate that a run has // changed searcher state. TrialSearcherState struct { - Create searcher.Create - Op searcher.ValidateAfter - Complete bool - Closed bool + Create searcher.Create + EarlyStoppedBySearcher bool + EarlyExitedByUserCode bool } ) // Experiment is an interface that represents an experiment. type Experiment interface { - TrialCompleteOperation(msg TrialCompleteOperation) error - TrialReportProgress(msg TrialReportProgress) error - TrialGetSearcherState(requestID model.RequestID) (TrialSearcherState, error) + TrialReportProgress(requestID model.RequestID, msg TrialReportProgress) error + TrialReportValidation(requestID model.RequestID, metrics map[string]interface{}) error UserInitiatedEarlyTrialExit(msg UserInitiatedEarlyTrialExit) error PatchTrialState(msg PatchTrialState) error SetGroupMaxSlots(msg sproto.SetGroupMaxSlots) SetGroupWeight(weight float64) error SetGroupPriority(priority int) error - PerformSearcherOperations(msg *apiv1.PostSearcherOperationsRequest) error - GetSearcherEventsWatcher() (*searcher.EventsWatcher, error) - UnwatchEvents(id uuid.UUID) error ActivateExperiment() error PauseExperiment() error CancelExperiment() error diff --git a/master/internal/restore.go b/master/internal/restore.go index eeb2be41507..cfafe976421 100644 --- a/master/internal/restore.go +++ b/master/internal/restore.go @@ -6,25 +6,25 @@ import ( "encoding/json" "fmt" - "github.com/determined-ai/determined/master/internal/experiment" - "github.com/determined-ai/determined/master/internal/rm" - "github.com/determined-ai/determined/master/internal/sproto" - "github.com/determined-ai/determined/master/internal/workspace" + "github.com/determined-ai/determined/master/pkg/ptrs" + "github.com/determined-ai/determined/master/pkg/schemas" "github.com/pkg/errors" log "github.com/sirupsen/logrus" "github.com/determined-ai/determined/master/internal/db" + "github.com/determined-ai/determined/master/internal/experiment" + "github.com/determined-ai/determined/master/internal/rm" + "github.com/determined-ai/determined/master/internal/sproto" "github.com/determined-ai/determined/master/internal/user" + "github.com/determined-ai/determined/master/internal/workspace" "github.com/determined-ai/determined/master/pkg/model" - "github.com/determined-ai/determined/master/pkg/ptrs" - "github.com/determined-ai/determined/master/pkg/schemas" - "github.com/determined-ai/determined/master/pkg/searcher" + "github.com/determined-ai/determined/master/pkg/nprand" ) // The current experiment snapshot version. Once this is incremented, older versions should be // shimmed. Experiment and trial snapshots share a version currently. -const experimentSnapshotVersion = 5 +const experimentSnapshotVersion = 6 // Restore works by restoring from distributed consistent snapshots taken through the course // of an experiment. Snapshots within the system flow from the bottom up, starting with the @@ -134,14 +134,11 @@ func (m *Master) restoreExperiment(expModel *model.Experiment) error { return nil } -// restoreTrial takes the a searcher.Create and attempts to restore the trial that would be -// associated with it. On failure, the trial is just reset to the start and errors are logged. func (e *internalExperiment) restoreTrial( ckpt *model.Checkpoint, searcher experiment.TrialSearcherState, ) { l := e.syslog.WithField("request-id", searcher.Create.RequestID) l.Debug("restoring trial") - var trialID *int var terminal bool switch trial, err := db.TrialByExperimentAndRequestID(context.TODO(), @@ -162,7 +159,6 @@ func (e *internalExperiment) restoreTrial( terminal = true } } - taskID := trialTaskID(e.ID, searcher.Create.RequestID) if !terminal && trialID != nil { trialTaskIDs, err := db.TrialTaskIDsByTrialID(context.TODO(), *trialID) @@ -177,11 +173,9 @@ func (e *internalExperiment) restoreTrial( taskID = trialTaskIDs[len(trialTaskIDs)-1].TaskID } } - - // In the event a trial is terminal and is not recorded in the searcher, replay the close. if terminal { if !e.searcher.TrialIsClosed(searcher.Create.RequestID) { - e.trialClosed(searcher.Create.RequestID, nil) + e.trialExited(searcher.Create.RequestID, nil) } return } @@ -190,17 +184,16 @@ func (e *internalExperiment) restoreTrial( l.Errorf("trial %s was already restored", searcher.Create.RequestID) return } - config := schemas.Copy(e.activeConfig) t, err := newTrial( e.logCtx, taskID, e.JobID, e.StartTime, e.ID, e.State, searcher, e.rm, e.db, config, ckpt, e.taskSpec, e.generatedKeys, true, trialID, - nil, e.TrialClosed, + nil, e.TrialExited, ) if err != nil { l.WithError(err).Error("failed restoring trial, aborting restore") if !e.searcher.TrialIsClosed(searcher.Create.RequestID) { - e.trialClosed(searcher.Create.RequestID, ptrs.Ptr(model.Errored)) + e.trialExited(searcher.Create.RequestID, ptrs.Ptr(model.Errored)) } return } @@ -246,24 +239,21 @@ var experimentSnapshotShims = map[int]snapshotShimFunc{ 1: shimExperimentSnapshotV1, 2: shimExperimentSnapshotV2, 4: shimExperimentSnapshotV4, + 5: shimExperimentSnapshotV5, } // shimExperimentSnapshot shims an experiment snapshot to the version required by the master, // returning an error in the event the shim fails or the snapshot version is greater // than the current version (which could happen in a downgrade). func shimExperimentSnapshot(snapshot []byte, version int) ([]byte, error) { - return shimSnapshot(experimentSnapshotShims, snapshot, version) -} - -func shimSnapshot(shims map[int]snapshotShimFunc, snapshot []byte, version int) ([]byte, error) { if version > experimentSnapshotVersion { return nil, fmt.Errorf("cannot shim from %d to %d", version, experimentSnapshotVersion) } var err error for version < experimentSnapshotVersion { - shim, ok := shims[version] + shim, ok := experimentSnapshotShims[version] if !ok { - return nil, fmt.Errorf("missing shim from %d to %d", version, experimentSnapshotVersion) + return nil, fmt.Errorf("missing shim from %d to %d", version, version+1) } if snapshot, err = shim(snapshot); err != nil { return nil, errors.Wrapf(err, "failed to shim snapshot") @@ -274,7 +264,7 @@ func shimSnapshot(shims map[int]snapshotShimFunc, snapshot []byte, version int) } // snapshotShimFunc is a shimming function. -type snapshotShimFunc func([]byte) ([]byte, error) +type snapshotShimFunc func(snapshot []byte) ([]byte, error) // Version 0 => 1 shims @@ -344,6 +334,18 @@ func shimExperimentSnapshotV1(snapshot []byte) ([]byte, error) { // Version 2 => 3 shims +// Legacy types which no longer exist in the searcher package, but needed to serialize old snapshots. +const ( + CreateOperation OperationType = 0 + TrainOperation OperationType = 1 + ValidateOperation OperationType = 2 + CloseOperation OperationType = 4 + ValidateAfterOperation OperationType = 5 +) + +// OperationType is a legacy searcher operation type. +type OperationType int + // shimExperimentSnapshotV2 shims a v2 snapshot to a v3 snapshot. From v2 to v3, // Train and Validate operations were merged into a single ValidateAfter operation // that indicates to the trial the total units to train before reporting a validation @@ -361,15 +363,15 @@ func shimExperimentSnapshotV2(snapshot []byte) ([]byte, error) { var newOperationsList []map[string]interface{} for _, iOp := range operationsList { op := iOp.(map[string]interface{}) - switch searcher.OperationType(op["OperationType"].(float64)) { - case searcher.TrainOperation: + switch OperationType(op["OperationType"].(float64)) { + case TrainOperation: op := op["Operation"].(map[string]interface{}) requestID := op["RequestID"].(string) length := op["Length"].(map[string]interface{}) for unit, units := range length { totalUnitsForTrial[requestID] += units.(float64) newOperationsList = append(newOperationsList, map[string]interface{}{ - "OperationType": searcher.ValidateAfterOperation, + "OperationType": ValidateAfterOperation, "Operation": map[string]interface{}{ "RequestID": requestID, "Length": map[string]interface{}{ @@ -378,7 +380,7 @@ func shimExperimentSnapshotV2(snapshot []byte) ([]byte, error) { }, }) } - case searcher.ValidateOperation: + case ValidateOperation: continue default: newOperationsList = append(newOperationsList, op) @@ -444,3 +446,96 @@ func shimExperimentSnapshotV4(snapshot []byte) ([]byte, error) { return json.Marshal(experimentSnapshotV4) } + +type v4SearcherState struct { + TrialsRequested int `json:"trials_requested"` + TrialsCreated map[model.RequestID]bool `json:"trials_created"` + TrialsClosed map[model.RequestID]bool `json:"trials_closed"` + Exits map[model.RequestID]bool `json:"exits"` + Cancels map[model.RequestID]bool `json:"cancels"` + Failures map[model.RequestID]bool `json:"failures"` + TrialProgress map[model.RequestID]float64 `json:"trial_progress"` + Rand *nprand.State `json:"rand"` + SearchMethodState json.RawMessage `json:"search_method_state"` +} +type v4CreateOp struct { + HParams map[string]interface{} `json:"hparams"` + RequestID model.RequestID `json:"request_id"` + TrialSeed uint32 `json:"trial_seed"` +} + +type v4TrialSearcherState struct { + Create v4CreateOp + Stop bool + Closed bool + Complete bool +} +type experimentSnapshotV4 struct { + SearcherState v4SearcherState `json:"searcher_state"` + TrialSearcherState map[model.RequestID]v4TrialSearcherState `json:"trial_searcher_state"` +} + +// shimExperimentSnapshotV5 shims a v5 snapshot to a v6 snapshot. From v5 to v6: +// - `searcher_state.CompletedOperations` -> dropped +// - `searcher_state.Shutdown` -> dropped +// +// - `trial_searcher_state.Create (searcher.Operation)` -> `trial_searcher_state.Create (searcher.Action)` +// - `trial_searcher_state.Complete` -> dropped +// - `trial_searcher_state.Op (searcher.ValidateAfter)` -> dropped +// - `trial_searcher_state.Stop` -> dropped. +func shimExperimentSnapshotV5(snapshot []byte) ([]byte, error) { + v4ExperimentSnapshot := experimentSnapshotV4{} + + if err := json.Unmarshal(snapshot, &v4ExperimentSnapshot); err != nil { + return nil, err + } + + searchMethodStateV4 := map[string]interface{}{} + err := json.Unmarshal(v4ExperimentSnapshot.SearcherState.SearchMethodState, &searchMethodStateV4) + if err != nil { + return nil, ExperimentSnapshotShimError{Message: err.Error()} + } + searchMethodType, ok := searchMethodStateV4["search_method_type"] + if !ok { + return nil, ExperimentSnapshotShimError{Message: "unable to parse search_method_type"} + } + + switch searchMethodType { + case "single": + case "random": + case "grid": + default: + return nil, ExperimentSnapshotShimError{Message: "unsupported search_method_type"} + } + + trialSearcherState := make(map[model.RequestID]interface{}) + + for rID, searcherState := range v4ExperimentSnapshot.TrialSearcherState { + trialSearcherState[rID] = map[string]interface{}{ + "Create": map[string]interface{}{ + "hparams": searcherState.Create.HParams, + "trial_seed": searcherState.Create.TrialSeed, + "request_id": searcherState.Create.RequestID, + }, + "EarlyStoppedBySearcher": searcherState.Stop || searcherState.Complete, + "EarlyExitedByUserCode": searcherState.Closed && searcherState.Complete, + } + } + + experimentSnapshotV5 := map[string]interface{}{ + "searcher_state": map[string]interface{}{ + "trials_requested": v4ExperimentSnapshot.SearcherState.TrialsRequested, + "trials_created": v4ExperimentSnapshot.SearcherState.TrialsCreated, + "trials_closed": v4ExperimentSnapshot.SearcherState.TrialsClosed, + "exits": v4ExperimentSnapshot.SearcherState.Exits, + "cancels": v4ExperimentSnapshot.SearcherState.Cancels, + "failures": v4ExperimentSnapshot.SearcherState.Failures, + "trial_progress": v4ExperimentSnapshot.SearcherState.TrialProgress, + "rand": v4ExperimentSnapshot.SearcherState.Rand, + "search_method_state": searchMethodStateV4, + }, + "trial_searcher_state": trialSearcherState, + } + + return json.Marshal(experimentSnapshotV5) +} diff --git a/master/internal/restore_test.go b/master/internal/restore_test.go index af14e869288..4023a83536b 100644 --- a/master/internal/restore_test.go +++ b/master/internal/restore_test.go @@ -20,8 +20,56 @@ func TestShimExperimentSnapshotV4(t *testing.T) { require.JSONEq(t, string(newSnapshot), string(actual)) } +func TestShimExperimentSnapshotV5(t *testing.T) { + cases := []struct { + name searcher.SearchMethodType + v4Snapshot []byte + v5Snapshot []byte + err string + }{ + { + name: searcher.SingleSearch, + //nolint:lll + v4Snapshot: []byte(`{"searcher_state": {"rand": {"key": [3922433409, 243371046, 1078118500, 751359450, 1341787537, 4110007575, 3830714619, 215586197, 202234178, 695788744, 568122858, 801842164, 915101998, 2089585430, 560765197, 132023793, 3625985341, 1033888033, 4275540481, 4081480301, 3803198840, 2062925182, 3815214353, 2217168443, 284556261, 4216276606, 3143802969, 4077949604, 1360661190, 1579111912, 3397581176, 2605001170, 1444943786, 1103161266, 2475001171, 586405187, 440758113, 69964893, 2302367679, 2318485289, 2699422072, 949642636, 992261637, 3418960830, 2747796370, 1988242324, 1336558633, 4144877284, 1447227155, 2130144917, 249701494, 845842848, 3562805067, 461807133, 316172300, 1019162321, 2042173338, 2357208493, 984738996, 2977486791, 4204759905, 3765157938, 1416537904, 673684512, 1014036428, 3291531021, 2149014833, 211182184, 1899637415, 2967746221, 1721885906, 226438235, 2148590681, 1789583513, 1586306530, 1162436037, 2949013157, 869268556, 2754664383, 190339737, 2683355822, 1858335276, 752629074, 3123663892, 3566745314, 2416019451, 2343621935, 4064029673, 2375292477, 4237178213, 733006903, 3966381652, 1902496609, 1636160960, 1441547075, 837638138, 885960515, 1113758007, 2646826291, 1038503594, 3726531118, 1965291352, 4027030242, 3233398909, 2964755703, 966551311, 1083992523, 1124285624, 3192801484, 2731176820, 2400133886, 2523660457, 1033111323, 2914428172, 1163901245, 1320025743, 2177720309, 1677889669, 2345935772, 2794475418, 9574722, 1397363731, 3692199984, 452561608, 112544138, 2256573095, 2829353338, 1987429988, 555809442, 1666429314, 215970428, 3154343711, 2051241582, 3877008887, 1816890945, 4013344821, 2102130895, 4186655536, 2673431404, 4026559573, 2176376772, 1221630340, 2467768183, 84713789, 3270911636, 1571685966, 1581830524, 4143291311, 2923146623, 3610913076, 2622548353, 464460930, 1486913384, 3667888493, 4266300151, 3341431843, 1965213295, 2565733412, 761474095, 761709082, 3177157908, 3043376221, 396311914, 3167326505, 1734174898, 1402058836, 2225699320, 1370103810, 1206643084, 3622576190, 109652665, 1635918058, 3535739916, 2701481795, 1919438907, 1981477202, 1360788607, 777792032, 4044337196, 2681268168, 2118070045, 1276795554, 3146333618, 746239793, 2520869193, 1894154101, 1005996243, 1898277441, 3296603528, 3591769481, 3145574469, 1323130515, 1238470035, 1076131551, 2785215479, 2287060195, 1378079485, 102967153, 2942614756, 1776625394, 938257756, 2470323776, 2927673692, 129462980, 221619173, 3277410069, 581765105, 1744873192, 1725783416, 2487492363, 3776340706, 848370815, 3284445102, 1648634245, 68108052, 3595739205, 2779510023, 2872989184, 1845125748, 121458802, 786888262, 3047839649, 2036720406, 1066086686, 3620519326, 2825286213, 1043364134, 745183959, 2008388075, 3781903549, 4088204083, 2422084247, 2245918792, 163997977, 4072121451, 377438777, 3434214799, 47957125, 901584641, 1663967634, 2224354240, 2140121772, 2239607504, 1296158592, 1144490013, 363777504, 352049636, 3537828370, 2018830719, 4106625267, 791310915, 1141549081, 383680238, 2211229427, 2955313933, 1786602475, 861117031, 1657038393, 49008346, 2757187031, 2415960764, 2063869385, 4250446539, 4072556442, 3080810042, 3746048829, 1207880898, 743812379, 1015188201, 1250574590, 2697317474, 1490831142, 3975889652, 3507228768, 3917477006, 922984658, 1548413943, 3607819493, 642573978, 2802934424, 1071690339, 2487490598, 353709112, 3041703309, 3991444003, 3766125511, 4022423108, 2255595963, 4244631415, 1858811130, 1272720251, 4254073557, 3076459717, 1790136700, 3582509861, 2769798461, 3229452726, 2752397444, 701766670, 1437507306, 1004148691, 578333132, 82009872, 3366243481, 1209720848, 1659428252, 789199905, 2021015352, 3384069524, 658409068, 1580305341, 4213315272, 4208459453, 2059532450, 893773418, 653840510, 1457966480, 393232134, 3136440598, 1431618763, 1172519077, 2575368795, 939388973, 1509101907, 988758589, 565340075, 2869588947, 1972475685, 4273757903, 204012057, 1101337211, 1533180088, 2801615526, 2440934744, 338200579, 3477442758, 2757306425, 2685737460, 3224302362, 1663459583, 696652877, 4112318289, 2675890357, 745289117, 2365755424, 2933967187, 4022266172, 1505363667, 2297045170, 570110611, 3346324247, 2491642679, 3975465941, 1581427559, 3057064987, 3791566228, 1635441872, 1271004942, 1877908946, 2574434132, 1044045440, 3454314545, 1253401130, 103952148, 3575409376, 490573770, 2501813399, 1298020930, 382668967, 2514416127, 3341542436, 1092494529, 195927818, 924628104, 3732801821, 1514695853, 1522275894, 219321854, 1556133239, 504225130, 1473692730, 406339414, 424484390, 4274713028, 2913973779, 4130369073, 1360365870, 628908339, 4154982639, 757598615, 2791015278, 2080857461, 4149657548, 3529030811, 2769339058, 429034912, 996084582, 1734359068, 1375161721, 156143303, 3037969436, 1190257791, 1676065026, 73997760, 2249183962, 3601724488, 4122290591, 2965613742, 1557037497, 4185416125, 3400325053, 3974330977, 380844763, 2773129515, 152044993, 1429397077, 4224338628, 1029309310, 1646236907, 3601261456, 148996457, 3350178033, 2474815495, 1410658357, 2550027794, 1767790115, 270701719, 3018598882, 4024035625, 887012213, 2098623801, 3632846028, 2921764843, 3962072584, 641857031, 2017495901, 263120295, 2019130833, 1014261205, 4186856768, 2015365872, 279028754, 72415546, 924570728, 3902648166, 1678286441, 2292588029, 2311635131, 2102542545, 2971228400, 1591614386, 434847124, 3404271484, 3154798571, 113413735, 2491451743, 876206292, 1926984841, 3365539483, 4251125583, 793584453, 301069822, 3419075736, 724765157, 3208860939, 374696415, 1335525536, 1919740580, 3663859906, 2324775696, 3864157940, 1565672574, 3802032297, 171709183, 1234689881, 3634865316, 2015586222, 3369223628, 2978260002, 279245834, 70045384, 214564933, 2965926443, 2361247591, 4165854731, 2521208355, 2800346995, 4049169329, 23045802, 1970506914, 3641829100, 1666315902, 2241077078, 271426571, 4088283675, 2162546067, 3127287949, 2587241655, 159689334, 3031939099, 86965855, 685758755, 1240649695, 3226456551, 2957980176, 3555426742, 428862619, 839793577, 773054281, 76694871, 1000736152, 1723860376, 2645296715, 1765515124, 2397562436, 3840534762, 3419177686, 2056708761, 2320819600, 722357493, 1951132286, 2697176849, 3418517148, 603726513, 2078863096, 2316090329, 2970470049, 462414015, 791905947, 1253829863, 78371665, 4043788288, 2830761235, 2395343283, 2760690656, 1435313866, 1069903560, 1768883836, 3478580525, 904619222, 279413896, 2319870694, 4228173768, 3232061024, 1484346351, 4004145551, 3963210362, 1575762141, 1425228619, 4101712633, 483124459, 1978185260, 2962076605, 340777240, 3916617987, 2077585685, 1670667332, 2219902641, 1768998827, 47882544, 38398050, 3045070584, 950674091, 2821356397, 2104948616, 1566315737, 865585600, 1541267679, 3879180856, 2839407679, 1159648826, 1505709233, 3176134596, 3648874638, 415873698, 3782961634, 1010479344, 3886069287, 2148727539, 3631608995, 377214383, 4046490475, 1123005720, 1622745139, 3490975330, 2511527452, 1283291646, 2881450334, 4188742159, 2510804289, 879103178, 2003656783, 2968076365, 4153230545, 1779731768, 2697187974, 938150976, 3424390191, 2962715395, 101985175, 2518895690, 2455009572, 3071599415, 1061702964, 2015878382, 2790438718, 3098661352, 105466182, 2218696853, 2782132502, 1059279116, 1168368222, 377291999, 548884237, 1330528867, 1028242923, 594901465, 2151845525, 1345962394, 829096669, 3769617076, 859176729, 1416740787, 2528276857], "pos": 5}, "exits": {}, "cancels": {}, "failures": {}, "shutdown": false, "trials_closed": {}, "trial_progress": {"c7ffd20e-ec56-4e21-b6dd-2c2e6a27a9ae": 0.75}, "trials_created": {"c7ffd20e-ec56-4e21-b6dd-2c2e6a27a9ae": true}, "trials_requested": 1, "search_method_state": {"created_trials": 1, "pending_trials": 1, "search_method_type": "single"}, "completed_operations": {}}, "trial_searcher_state": {"c7ffd20e-ec56-4e21-b6dd-2c2e6a27a9ae": {"Op": {"Length": 1000, "RequestID": "c7ffd20e-ec56-4e21-b6dd-2c2e6a27a9ae"}, "Closed": true, "Create": {"hparams": {"dropout1": 0.25, "dropout2": 0.5, "n_filters1": 32, "n_filters2": 64, "learning_rate": 1}, "checkpoint": null, "request_id": "c7ffd20e-ec56-4e21-b6dd-2c2e6a27a9ae", "trial_seed": 530298166, "workload_sequencer_type": "TRIAL_WORKLOAD_SEQUENCER"}, "Complete": false}}}`), + //nolint:lll + v5Snapshot: []byte(`{"searcher_state": {"rand": {"key": [3922433409, 243371046, 1078118500, 751359450, 1341787537, 4110007575, 3830714619, 215586197, 202234178, 695788744, 568122858, 801842164, 915101998, 2089585430, 560765197, 132023793, 3625985341, 1033888033, 4275540481, 4081480301, 3803198840, 2062925182, 3815214353, 2217168443, 284556261, 4216276606, 3143802969, 4077949604, 1360661190, 1579111912, 3397581176, 2605001170, 1444943786, 1103161266, 2475001171, 586405187, 440758113, 69964893, 2302367679, 2318485289, 2699422072, 949642636, 992261637, 3418960830, 2747796370, 1988242324, 1336558633, 4144877284, 1447227155, 2130144917, 249701494, 845842848, 3562805067, 461807133, 316172300, 1019162321, 2042173338, 2357208493, 984738996, 2977486791, 4204759905, 3765157938, 1416537904, 673684512, 1014036428, 3291531021, 2149014833, 211182184, 1899637415, 2967746221, 1721885906, 226438235, 2148590681, 1789583513, 1586306530, 1162436037, 2949013157, 869268556, 2754664383, 190339737, 2683355822, 1858335276, 752629074, 3123663892, 3566745314, 2416019451, 2343621935, 4064029673, 2375292477, 4237178213, 733006903, 3966381652, 1902496609, 1636160960, 1441547075, 837638138, 885960515, 1113758007, 2646826291, 1038503594, 3726531118, 1965291352, 4027030242, 3233398909, 2964755703, 966551311, 1083992523, 1124285624, 3192801484, 2731176820, 2400133886, 2523660457, 1033111323, 2914428172, 1163901245, 1320025743, 2177720309, 1677889669, 2345935772, 2794475418, 9574722, 1397363731, 3692199984, 452561608, 112544138, 2256573095, 2829353338, 1987429988, 555809442, 1666429314, 215970428, 3154343711, 2051241582, 3877008887, 1816890945, 4013344821, 2102130895, 4186655536, 2673431404, 4026559573, 2176376772, 1221630340, 2467768183, 84713789, 3270911636, 1571685966, 1581830524, 4143291311, 2923146623, 3610913076, 2622548353, 464460930, 1486913384, 3667888493, 4266300151, 3341431843, 1965213295, 2565733412, 761474095, 761709082, 3177157908, 3043376221, 396311914, 3167326505, 1734174898, 1402058836, 2225699320, 1370103810, 1206643084, 3622576190, 109652665, 1635918058, 3535739916, 2701481795, 1919438907, 1981477202, 1360788607, 777792032, 4044337196, 2681268168, 2118070045, 1276795554, 3146333618, 746239793, 2520869193, 1894154101, 1005996243, 1898277441, 3296603528, 3591769481, 3145574469, 1323130515, 1238470035, 1076131551, 2785215479, 2287060195, 1378079485, 102967153, 2942614756, 1776625394, 938257756, 2470323776, 2927673692, 129462980, 221619173, 3277410069, 581765105, 1744873192, 1725783416, 2487492363, 3776340706, 848370815, 3284445102, 1648634245, 68108052, 3595739205, 2779510023, 2872989184, 1845125748, 121458802, 786888262, 3047839649, 2036720406, 1066086686, 3620519326, 2825286213, 1043364134, 745183959, 2008388075, 3781903549, 4088204083, 2422084247, 2245918792, 163997977, 4072121451, 377438777, 3434214799, 47957125, 901584641, 1663967634, 2224354240, 2140121772, 2239607504, 1296158592, 1144490013, 363777504, 352049636, 3537828370, 2018830719, 4106625267, 791310915, 1141549081, 383680238, 2211229427, 2955313933, 1786602475, 861117031, 1657038393, 49008346, 2757187031, 2415960764, 2063869385, 4250446539, 4072556442, 3080810042, 3746048829, 1207880898, 743812379, 1015188201, 1250574590, 2697317474, 1490831142, 3975889652, 3507228768, 3917477006, 922984658, 1548413943, 3607819493, 642573978, 2802934424, 1071690339, 2487490598, 353709112, 3041703309, 3991444003, 3766125511, 4022423108, 2255595963, 4244631415, 1858811130, 1272720251, 4254073557, 3076459717, 1790136700, 3582509861, 2769798461, 3229452726, 2752397444, 701766670, 1437507306, 1004148691, 578333132, 82009872, 3366243481, 1209720848, 1659428252, 789199905, 2021015352, 3384069524, 658409068, 1580305341, 4213315272, 4208459453, 2059532450, 893773418, 653840510, 1457966480, 393232134, 3136440598, 1431618763, 1172519077, 2575368795, 939388973, 1509101907, 988758589, 565340075, 2869588947, 1972475685, 4273757903, 204012057, 1101337211, 1533180088, 2801615526, 2440934744, 338200579, 3477442758, 2757306425, 2685737460, 3224302362, 1663459583, 696652877, 4112318289, 2675890357, 745289117, 2365755424, 2933967187, 4022266172, 1505363667, 2297045170, 570110611, 3346324247, 2491642679, 3975465941, 1581427559, 3057064987, 3791566228, 1635441872, 1271004942, 1877908946, 2574434132, 1044045440, 3454314545, 1253401130, 103952148, 3575409376, 490573770, 2501813399, 1298020930, 382668967, 2514416127, 3341542436, 1092494529, 195927818, 924628104, 3732801821, 1514695853, 1522275894, 219321854, 1556133239, 504225130, 1473692730, 406339414, 424484390, 4274713028, 2913973779, 4130369073, 1360365870, 628908339, 4154982639, 757598615, 2791015278, 2080857461, 4149657548, 3529030811, 2769339058, 429034912, 996084582, 1734359068, 1375161721, 156143303, 3037969436, 1190257791, 1676065026, 73997760, 2249183962, 3601724488, 4122290591, 2965613742, 1557037497, 4185416125, 3400325053, 3974330977, 380844763, 2773129515, 152044993, 1429397077, 4224338628, 1029309310, 1646236907, 3601261456, 148996457, 3350178033, 2474815495, 1410658357, 2550027794, 1767790115, 270701719, 3018598882, 4024035625, 887012213, 2098623801, 3632846028, 2921764843, 3962072584, 641857031, 2017495901, 263120295, 2019130833, 1014261205, 4186856768, 2015365872, 279028754, 72415546, 924570728, 3902648166, 1678286441, 2292588029, 2311635131, 2102542545, 2971228400, 1591614386, 434847124, 3404271484, 3154798571, 113413735, 2491451743, 876206292, 1926984841, 3365539483, 4251125583, 793584453, 301069822, 3419075736, 724765157, 3208860939, 374696415, 1335525536, 1919740580, 3663859906, 2324775696, 3864157940, 1565672574, 3802032297, 171709183, 1234689881, 3634865316, 2015586222, 3369223628, 2978260002, 279245834, 70045384, 214564933, 2965926443, 2361247591, 4165854731, 2521208355, 2800346995, 4049169329, 23045802, 1970506914, 3641829100, 1666315902, 2241077078, 271426571, 4088283675, 2162546067, 3127287949, 2587241655, 159689334, 3031939099, 86965855, 685758755, 1240649695, 3226456551, 2957980176, 3555426742, 428862619, 839793577, 773054281, 76694871, 1000736152, 1723860376, 2645296715, 1765515124, 2397562436, 3840534762, 3419177686, 2056708761, 2320819600, 722357493, 1951132286, 2697176849, 3418517148, 603726513, 2078863096, 2316090329, 2970470049, 462414015, 791905947, 1253829863, 78371665, 4043788288, 2830761235, 2395343283, 2760690656, 1435313866, 1069903560, 1768883836, 3478580525, 904619222, 279413896, 2319870694, 4228173768, 3232061024, 1484346351, 4004145551, 3963210362, 1575762141, 1425228619, 4101712633, 483124459, 1978185260, 2962076605, 340777240, 3916617987, 2077585685, 1670667332, 2219902641, 1768998827, 47882544, 38398050, 3045070584, 950674091, 2821356397, 2104948616, 1566315737, 865585600, 1541267679, 3879180856, 2839407679, 1159648826, 1505709233, 3176134596, 3648874638, 415873698, 3782961634, 1010479344, 3886069287, 2148727539, 3631608995, 377214383, 4046490475, 1123005720, 1622745139, 3490975330, 2511527452, 1283291646, 2881450334, 4188742159, 2510804289, 879103178, 2003656783, 2968076365, 4153230545, 1779731768, 2697187974, 938150976, 3424390191, 2962715395, 101985175, 2518895690, 2455009572, 3071599415, 1061702964, 2015878382, 2790438718, 3098661352, 105466182, 2218696853, 2782132502, 1059279116, 1168368222, 377291999, 548884237, 1330528867, 1028242923, 594901465, 2151845525, 1345962394, 829096669, 3769617076, 859176729, 1416740787, 2528276857], "pos": 5}, "exits": {}, "cancels": {}, "failures": {}, "trials_closed": {}, "trial_progress": {"c7ffd20e-ec56-4e21-b6dd-2c2e6a27a9ae": 0.75}, "trials_created": {"c7ffd20e-ec56-4e21-b6dd-2c2e6a27a9ae": true}, "trials_requested": 1, "search_method_state": {"created_trials": 1, "pending_trials": 1, "search_method_type": "single"}}, "trial_searcher_state": {"c7ffd20e-ec56-4e21-b6dd-2c2e6a27a9ae": {"EarlyExitedByUserCode": false, "Create": {"hparams": {"dropout1": 0.25, "dropout2": 0.5, "n_filters1": 32, "n_filters2": 64, "learning_rate": 1}, "request_id": "c7ffd20e-ec56-4e21-b6dd-2c2e6a27a9ae", "trial_seed": 530298166}, "EarlyStoppedBySearcher": false}}}`), + }, + { + name: searcher.AdaptiveASHASearch, + //nolint:lll + v4Snapshot: []byte(`{"searcher_state": {"rand": {"key": [1452545913, 2216706256, 2931285415, 1424553963, 2771469721, 1182268625, 2241039347, 812859712, 3441585930, 1306182607, 3485054815, 608308123, 112483053, 682642478, 1312525610, 4135551174, 1567793875, 708291338, 3829781031, 1127597650, 1379589895, 4278788276, 970175326, 4153583291, 3982059578, 1493128882, 1875069103, 3704826304, 2299697046, 3676722567, 3041672913, 3928015435, 3590950234, 1340137182, 1319309903, 945669253, 1297992086, 255387058, 3738978610, 1747520993, 1267854203, 1275083903, 1988788172, 181920895, 1106027729, 1025177369, 2440854937, 188734714, 2791495762, 379644148, 1364486454, 1214940208, 3655512912, 3005587086, 3108187149, 781287699, 4159116078, 4043987339, 1632188677, 1250292127, 3753687125, 1026997361, 3415512821, 3782772083, 1358650193, 2164683139, 564316438, 941046993, 3129882612, 2182601888, 1975366329, 3311855628, 1644174752, 2793779553, 2753842799, 3571302630, 887923171, 3588840141, 193784717, 2471237536, 1421225039, 341938740, 1192996113, 2218409675, 3991457045, 3826598018, 2275285112, 4011863454, 1177877850, 2698633618, 1138152393, 2202893314, 347707955, 2436375966, 1020461438, 792229146, 113320512, 2036693887, 1215276260, 1018416563, 1195865614, 3136674021, 4150724058, 1177584340, 3870966002, 2945594975, 3859621414, 1717325326, 2051142765, 1936740951, 3062042453, 1688513379, 910042307, 1750537331, 304051019, 2636291192, 2013521520, 4227041032, 102628453, 1359940869, 4245248134, 1993619173, 3500025022, 3924853543, 3578800279, 2421667083, 4060793652, 1094276667, 928577324, 3215001437, 2737811296, 3805282341, 408114954, 2248459851, 1476902951, 73200291, 3600387404, 422649030, 2174730155, 3307173078, 1507629695, 113607599, 4116987602, 3510362588, 1500789310, 2511817938, 3576571324, 1756177812, 2413290146, 3140611884, 1850078157, 2095279096, 635232750, 3799246263, 1381854103, 57954085, 69958379, 55700730, 2727860158, 646634153, 646716365, 3060248147, 2645894961, 521556301, 2332134955, 3223404494, 1278684995, 2435087898, 401816126, 1094921611, 563333291, 2556415388, 1739710028, 133744693, 3919682174, 4087793525, 768502435, 4228008790, 210686178, 4150681864, 4091917901, 1835561309, 2232363723, 3031311488, 1776937342, 1233396034, 3476217536, 3799316084, 864878498, 377881092, 3188094097, 2750744302, 3199804590, 1231139403, 2429436965, 2130823117, 757668219, 2740840838, 1377165079, 2589078863, 2810863562, 1162879957, 831096657, 3396087259, 3098412156, 4061324184, 3203871242, 1931886212, 2937742743, 907488495, 3996978148, 2997218253, 2004185458, 2516650587, 593077235, 3489768626, 2150154954, 2049860379, 2472886567, 1430759218, 2957836638, 435817594, 1082252350, 3951520482, 2400586842, 3456205344, 3356182591, 618991329, 2992469709, 1973188425, 450379261, 2798568635, 4037568720, 2839400952, 255422271, 1854524310, 4006718197, 3420074895, 2722893251, 3468530746, 1593574294, 3521350966, 3024147312, 1169215387, 3910087826, 443646771, 2925353376, 2069261626, 2548372079, 2863860503, 4132686053, 3433804206, 950538056, 3552376483, 1870093729, 3384688973, 1158025919, 493730904, 1789038843, 1829026938, 3596529436, 3639269430, 1817201820, 848183402, 537785361, 2081164395, 2475598070, 1575953173, 3376512332, 4036578374, 581826236, 3513291559, 2113632026, 2616353057, 3729275555, 1680494578, 1363619591, 692325427, 97273826, 2579406491, 1531446639, 2390954510, 1317588761, 536858982, 988517096, 2244293128, 2804565134, 3877515395, 839842752, 3405616275, 1039673240, 3962514502, 896670502, 1093259020, 2043255868, 3991535172, 296660859, 4163623869, 4045669283, 3512517932, 1859259214, 79788873, 217768790, 2455520809, 386000238, 3566902734, 2247723896, 674246520, 226356020, 1961300106, 3848127451, 353356745, 52039995, 2770049955, 3021363660, 1624062882, 2986876091, 448633518, 1629505280, 21883362, 1848415473, 756079215, 4190124084, 4131714971, 2392212283, 1308326958, 3371194070, 1147387144, 1285992797, 3405963410, 1892827156, 1967237941, 2398927758, 238263952, 2432436094, 1498668510, 858532034, 3668090456, 1838702366, 426802744, 1194191127, 2885124677, 1279814428, 2856963439, 2166294156, 2653671113, 26555231, 4188511414, 2257856313, 3262711136, 395413330, 2789270448, 4264752234, 1117258307, 2896009146, 1226138881, 3802856348, 2523573944, 4159838139, 1445150309, 264468203, 4150566522, 2240997763, 1509631376, 435995998, 4165403839, 3610262680, 1168629505, 4035006883, 2695323969, 1292644795, 2745853320, 225794602, 1356178611, 955998327, 2238901481, 1106469405, 826123461, 4155393474, 2635289592, 1224474211, 1824797825, 1233049504, 760946898, 1148631409, 3096376358, 4205086568, 2291321688, 537823608, 490724183, 966548030, 213314081, 3997861198, 3555542635, 3937802140, 3037829047, 3057237581, 149746191, 4198823951, 3419012586, 3058993263, 3099857571, 738118719, 2559925817, 429746199, 999300291, 3324548747, 599849818, 51382161, 1656669044, 1614461261, 3437196859, 2914200168, 2522112723, 1406081509, 1013688116, 2269547301, 1493949008, 352929430, 3057939739, 3841398959, 2865277247, 274596832, 3889831750, 4063920528, 322224127, 369484237, 1096644289, 2969910741, 355324547, 3017842962, 652957044, 2288093025, 1319424438, 3597823923, 3187158261, 2300967604, 2444733360, 3346034650, 3052890513, 3412771067, 1644673169, 2004833514, 4176603510, 812339001, 2377375086, 3657608513, 789432047, 1232357156, 311642210, 1135746832, 4278279012, 4187675947, 1484381964, 518250323, 1804044839, 749985106, 1247034074, 1322933383, 2558753984, 2148724241, 3355638721, 1239704585, 458131162, 1469044936, 1816687844, 3310401614, 711732101, 2570107777, 2195543406, 1685156591, 806200714, 485890833, 2324324855, 1361241033, 231249103, 3925784640, 2812874552, 105551818, 3360389267, 236876949, 3837288942, 2773521044, 3397769505, 396727058, 1340986334, 842480288, 3693769194, 1304819435, 3927626290, 1233405874, 2400596270, 3722895074, 4024237244, 1308945353, 2379151461, 1653639100, 2837868071, 3143239362, 3347100096, 505670669, 1220922352, 1683276312, 1409357641, 2864400686, 2154723844, 2865848502, 3711211457, 4174404079, 3760415152, 2635418164, 3817078233, 3797497846, 1843555871, 3097746780, 1631738968, 1273189302, 3422812667, 1884737974, 3996941605, 567541988, 2835894599, 2678807189, 2862594697, 259382272, 4246982418, 2015432239, 1584697754, 950481942, 2676080319, 3060794403, 3206445087, 1623919009, 3227302040, 3207120844, 2227286302, 1395885538, 786668072, 1111364168, 831375560, 2351912451, 1750459989, 1249540546, 1954378595, 377821224, 1223652079, 660872685, 3148523876, 636798872, 3792321202, 1157296821, 2728165650, 2125606106, 286143865, 1929354736, 3282533922, 2910169564, 3416956740, 3467743336, 3115919084, 1499088458, 2537111141, 2756648794, 1989501022, 1972571169, 1594350983, 1687019006, 2282666905, 2791303809, 816687306, 3206038097, 3432004981, 1128341472, 2489750111, 4212962683, 4150103461, 3277624383, 1740516965, 3298397586, 1066119190, 973732348, 1142153902, 2593584651, 1347071644, 476190584, 255904713, 2487444411, 4277502188, 503790438, 3701506354, 4080177843, 1973766567, 2765496076, 2445282555, 3253894374, 69025061, 1973540476, 3707131235, 2463125033, 3650253067, 1261214630, 1834547302, 1646383947, 1135075426, 4138467070, 1543536840, 2107446004, 379532649, 2859826137, 2562206910, 962843355, 1514927822, 1971937745, 693327781, 3542160189, 327935861, 853015903, 3483398805, 3644300960, 1581562223, 719777381, 2642808027, 4103170367, 1148912942], "pos": 32}, "exits": {}, "cancels": {}, "failures": {}, "shutdown": false, "trials_closed": {}, "trial_progress": {"35cf42b8-36cf-49e3-b104-f3cce289d072": 0, "7ad68246-0a8a-4251-916d-db94cf525270": 0}, "trials_created": {"35cf42b8-36cf-49e3-b104-f3cce289d072": true, "7ad68246-0a8a-4251-916d-db94cf525270": true}, "trials_requested": 2, "search_method_state": {"trial_table": {"35cf42b8-36cf-49e3-b104-f3cce289d072": 1, "7ad68246-0a8a-4251-916d-db94cf525270": 0}, "sub_search_states": [{"rungs": [{"metrics": null, "start_trials": 0, "units_needed": 58, "promote_trials": 0, "outstanding_trials": 1}, {"metrics": null, "start_trials": 0, "units_needed": 292, "promote_trials": 0, "outstanding_trials": 0}, {"metrics": null, "start_trials": 0, "units_needed": 1229, "promote_trials": 0, "outstanding_trials": 0}], "trial_rungs": {"7ad68246-0a8a-4251-916d-db94cf525270": 0}, "closed_trials": {}, "invalid_trials": 0, "pending_trials": 0, "trials_completed": 0, "early_exit_trials": {}, "search_method_type": "asha"}, {"rungs": [{"metrics": null, "start_trials": 0, "units_needed": 234, "promote_trials": 0, "outstanding_trials": 1}, {"metrics": null, "start_trials": 0, "units_needed": 1171, "promote_trials": 0, "outstanding_trials": 0}], "trial_rungs": {"35cf42b8-36cf-49e3-b104-f3cce289d072": 0}, "closed_trials": {}, "invalid_trials": 0, "pending_trials": 0, "trials_completed": 0, "early_exit_trials": {}, "search_method_type": "asha"}], "search_method_type": "adaptive_asha", "sub_search_units_completed": [0, 0]}, "completed_operations": {}}, "trial_searcher_state": {"35cf42b8-36cf-49e3-b104-f3cce289d072": {"Op": {"Length": 234, "RequestID": "35cf42b8-36cf-49e3-b104-f3cce289d072"}, "Closed": false, "Create": {"hparams": {"dropout1": 0.4402283511805293, "dropout2": 0.44011555407446223, "n_filters1": 15, "n_filters2": 47, "learning_rate": 0.020523695929929888}, "checkpoint": null, "request_id": "35cf42b8-36cf-49e3-b104-f3cce289d072", "trial_seed": 897420172, "workload_sequencer_type": "TRIAL_WORKLOAD_SEQUENCER"}, "Complete": false}, "7ad68246-0a8a-4251-916d-db94cf525270": {"Op": {"Length": 58, "RequestID": "7ad68246-0a8a-4251-916d-db94cf525270"}, "Closed": false, "Create": {"hparams": {"dropout1": 0.24506761071384414, "dropout2": 0.6338879717356689, "n_filters1": 25, "n_filters2": 32, "learning_rate": 0.9864141286469745}, "checkpoint": null, "request_id": "7ad68246-0a8a-4251-916d-db94cf525270", "trial_seed": 1221021447, "workload_sequencer_type": "TRIAL_WORKLOAD_SEQUENCER"}, "Complete": false}}}`), + err: "unsupported search_method_type", + }, + { + name: searcher.GridSearch, + //nolint:lll + v4Snapshot: []byte(`{"searcher_state": {"rand": {"key": [2076300081, 1180757701, 61390626, 2075653657, 3983276912, 2219280310, 1951783750, 3432627106, 1678661725, 3042313994, 2202725749, 2202588336, 4190101782, 1110057276, 3191526579, 4171973736, 4254763240, 2766827202, 1156928074, 1597039613, 3598380759, 1437190771, 1076126477, 2658152872, 3825090400, 3114761804, 915898666, 291106241, 3036411027, 2744580765, 820531638, 1883042039, 2922514564, 188829505, 1874867577, 640982739, 1714839269, 3901091477, 2804767841, 3254494107, 2484761704, 476134574, 990456848, 3870586239, 620062224, 1995301304, 2389729019, 2462435556, 2472017888, 2504467222, 2188233046, 2306262543, 3560836860, 1978513357, 606650471, 3854035095, 4245776466, 3871782493, 597013963, 2913535589, 1756877383, 1623386184, 496866040, 588131525, 4177340522, 2935519398, 850546169, 109126006, 858579993, 795071074, 3590755528, 1638945842, 1254896771, 2307161967, 3477967610, 692808483, 1899600971, 2769494850, 2779999007, 510215608, 4111810801, 4280921106, 1518864863, 834898192, 1306840028, 835688979, 459513373, 609148069, 3545906245, 1194162898, 258441420, 4179429231, 4022175812, 2648608732, 717531457, 306247172, 2976743141, 237988140, 592196390, 704362957, 7450542, 311689767, 3067438641, 808588439, 1235198799, 742758417, 2780949935, 3446973894, 3873685895, 2153250996, 4065466909, 1497342942, 3724631752, 4275880421, 272199235, 2252962393, 2056155089, 1877112438, 2110296751, 2590941302, 2013827341, 842823052, 1894032322, 1393960121, 3588305616, 2529793378, 3205436904, 596814183, 1254786875, 2966225931, 2121907724, 475901771, 555053869, 1400569825, 3688890119, 1353122628, 1459205986, 2300737010, 1481729865, 756275962, 599878589, 2046378230, 2769239562, 3613422680, 4252206344, 947287240, 2417250410, 1305902854, 2671572052, 1899272897, 3230713963, 2224033552, 1732656815, 1936213189, 443599218, 2770284889, 3459882565, 1332607404, 3278697474, 2224591365, 2257839008, 317213768, 2348128734, 1874716743, 456696413, 1116761881, 3309591553, 2862216319, 102982924, 2150630897, 3549437867, 2842597334, 4045646707, 302008588, 1654281218, 1727189467, 738749640, 3764450573, 3959398424, 3317500773, 3535585186, 3819002453, 3662808846, 1770928073, 3787632432, 2477652163, 1746899438, 3750614957, 761234677, 3689896539, 3520772878, 275147232, 337553768, 2986250282, 3506693941, 591880001, 1455806111, 3804664561, 2479592303, 4013399092, 726617980, 255395900, 178605610, 1293625725, 2646174944, 2376136796, 1142332741, 266983497, 999135041, 1646761432, 3287859002, 2993926759, 3462763364, 4227465400, 3355948815, 1238930500, 2174357849, 379232244, 2722994601, 2303499529, 974370129, 2577648277, 808892366, 3061424190, 4078456653, 1309753062, 2244949359, 371260896, 1136719967, 1708548536, 1658845169, 2768668945, 1531754301, 3191375913, 3511054906, 1062351502, 1744842885, 581610003, 1556303968, 867409273, 1667798951, 209299262, 437489551, 3229818234, 4190510574, 714265718, 2019898663, 564912797, 3376897769, 3559969466, 3938497845, 471145731, 349982265, 1552795612, 1447370340, 3534819495, 3702903169, 2270765654, 3192681951, 204048124, 3420549534, 3712568979, 3780914699, 2771260334, 2314887317, 745613193, 3577922416, 3094639701, 89439972, 1944758030, 1119237373, 586389801, 2520447175, 3625732341, 1427555151, 2087065690, 409476010, 187430597, 3753592208, 2951247134, 885626627, 844754868, 2092667268, 748208300, 3708044571, 3080007883, 3448755645, 3861156834, 1842920493, 818793362, 595126995, 2776784591, 62641362, 840212211, 3985028931, 387936511, 473253991, 1857032791, 4000006968, 4243356941, 1229644438, 153835145, 1283983784, 3918669237, 2721262354, 4294888081, 3113635576, 3580583332, 3521345732, 3528129211, 2061930144, 1634595536, 650979608, 2017813394, 1823871219, 3733362889, 3318357463, 2742200384, 763055433, 358005169, 768692017, 1050455834, 3453938424, 49062937, 3352742911, 3630652047, 2437806883, 1597018682, 2518758128, 1213648650, 4073021622, 4259149854, 3212854626, 79448901, 771981874, 3297404440, 3186097826, 319093164, 3890862606, 2245955576, 1497647520, 828363054, 3988483235, 3157718635, 562359205, 1350548803, 3372491415, 86446595, 1114830016, 2762338015, 1180274773, 2871068129, 2507166170, 2627076257, 1096162219, 2200646305, 3664591154, 3892273969, 476888795, 832867753, 4151853558, 2982123525, 182781907, 241410694, 1341125666, 4028887234, 3884607589, 2732864456, 1605421707, 2038450818, 3362242279, 4122980381, 2985487124, 110300201, 509696857, 4017443718, 1838466952, 2327354958, 2137521982, 138621377, 2133785874, 1413747039, 1739282333, 1675927427, 3185180235, 2373730108, 495353069, 1293977021, 1368037164, 2798684905, 81516419, 3857598893, 3495427721, 1760877692, 1152788660, 970145190, 3017785210, 1990290980, 67842654, 3925233768, 4002022873, 256245794, 1696909255, 738711838, 1446248938, 3868148475, 1939621088, 3658634988, 1784094744, 296610735, 1744996095, 2451305905, 2987883429, 2443943189, 1185159281, 4111976294, 2182983103, 3071467556, 1314957723, 152464689, 655243290, 1120785722, 1363764666, 3687005133, 4166824409, 4000596589, 2610383491, 3239976693, 4016033738, 1658070453, 1873771320, 1404781153, 3039196925, 3067316017, 2982654406, 373430827, 1067861532, 3675311637, 267701726, 3638688126, 2135521145, 1500865973, 1349464109, 3403519167, 2763769271, 4264496249, 1770838146, 1391852856, 2245935765, 2137032072, 1856750853, 688147071, 99608919, 3339156540, 2192512157, 2820609381, 2913853119, 2634309658, 1135799296, 1004115372, 2154400231, 248807841, 1855000157, 1129896866, 715899117, 923401956, 440983241, 3109426185, 477965580, 84854570, 1400593057, 1286402819, 3439792802, 318102008, 1067872774, 4240269308, 2549429311, 3855914333, 3795474048, 319956485, 4133374879, 1467447321, 78080933, 251366958, 3496186988, 3070794250, 2726275807, 2325947483, 53232900, 819143840, 2594863810, 55769358, 391403252, 866292794, 714727242, 2581427793, 698101228, 872117109, 997416829, 3565045060, 1611274371, 1562302439, 3417876422, 1392788396, 3206592320, 2207230324, 721665840, 2883852313, 1919434161, 395546903, 1837162976, 3429796583, 1522595013, 3348506059, 1175004850, 2074412352, 3758985171, 2415299591, 1703486181, 2304170834, 549422017, 2902550119, 2187986969, 3362538344, 768387705, 2456129962, 3429271129, 3706106775, 204072664, 806512945, 3311330988, 1474760185, 1814261184, 1577346575, 2078317992, 3122306338, 3467081158, 1225264382, 1847488518, 599526500, 2488492818, 1688305017, 438612976, 1607069832, 2927329539, 503082697, 1150118168, 1865053291, 2428734033, 4135113359, 821829536, 868562558, 413775071, 1334204977, 673408382, 1882374958, 4291776686, 4039700264, 911935763, 1482546356, 3253895898, 2457857766, 278986188, 1981961411, 1618768772, 658526763, 3374586925, 729845344, 1603013850, 1740519279, 2732939556, 821729772, 4152108407, 2159443430, 2359657320, 3741458889, 2625148830, 2740386969, 508381854, 4017968509, 3047356953, 3577840409, 4191139539, 592911320, 1632820963, 2552033020, 1212957884, 1014120341, 3310543727, 3779267868, 3583152261, 3680986715, 2995658876, 3137928161, 587029290, 2798799836, 4187466127, 594529572, 406159454, 4285083401, 2812893547, 841074247, 2953120157, 466989304, 2691320671, 3368132983, 3549982940, 3530994849, 3593311564, 2291771526, 3269554905, 4281627661, 1457258966, 3464984667, 1432592878, 1387812831, 154474021, 343300151, 1018325484, 2085104744, 4236418319, 2375494721, 1088217159, 389539389], "pos": 10}, "exits": {}, "cancels": {}, "failures": {}, "shutdown": false, "trials_closed": {}, "trial_progress": {"b20fd10b-c039-45fa-b450-86e9ad91ec28": 0, "c8810eb9-937a-4c71-86e6-cefbc9f1ba8e": 0}, "trials_created": {"b20fd10b-c039-45fa-b450-86e9ad91ec28": true, "c8810eb9-937a-4c71-86e6-cefbc9f1ba8e": true}, "trials_requested": 2, "search_method_state": {"pending_trials": 2, "remaining_trials": [{"dropout1": 0.25, "dropout2": 0.5, "n_filters1": 16, "n_filters2": 32, "learning_rate": 1}, {"dropout1": 0.25, "dropout2": 0.5, "n_filters1": 16, "n_filters2": 64, "learning_rate": 1}], "search_method_type": "grid"}, "completed_operations": {}}, "trial_searcher_state": {"b20fd10b-c039-45fa-b450-86e9ad91ec28": {"Op": {"Length": 1, "RequestID": "b20fd10b-c039-45fa-b450-86e9ad91ec28"}, "Closed": true, "Create": {"hparams": {"dropout1": 0.25, "dropout2": 0.5, "n_filters1": 32, "n_filters2": 32, "learning_rate": 1}, "checkpoint": null, "request_id": "b20fd10b-c039-45fa-b450-86e9ad91ec28", "trial_seed": 1367408042, "workload_sequencer_type": "TRIAL_WORKLOAD_SEQUENCER"}, "Complete": false}, "c8810eb9-937a-4c71-86e6-cefbc9f1ba8e": {"Op": {"Length": 1, "RequestID": "c8810eb9-937a-4c71-86e6-cefbc9f1ba8e"}, "Closed": true, "Create": {"hparams": {"dropout1": 0.25, "dropout2": 0.5, "n_filters1": 32, "n_filters2": 64, "learning_rate": 1}, "checkpoint": null, "request_id": "c8810eb9-937a-4c71-86e6-cefbc9f1ba8e", "trial_seed": 1545095049, "workload_sequencer_type": "TRIAL_WORKLOAD_SEQUENCER"}, "Complete": false}}}`), + //nolint:lll + v5Snapshot: []byte(`{"searcher_state": {"rand": {"key": [2076300081, 1180757701, 61390626, 2075653657, 3983276912, 2219280310, 1951783750, 3432627106, 1678661725, 3042313994, 2202725749, 2202588336, 4190101782, 1110057276, 3191526579, 4171973736, 4254763240, 2766827202, 1156928074, 1597039613, 3598380759, 1437190771, 1076126477, 2658152872, 3825090400, 3114761804, 915898666, 291106241, 3036411027, 2744580765, 820531638, 1883042039, 2922514564, 188829505, 1874867577, 640982739, 1714839269, 3901091477, 2804767841, 3254494107, 2484761704, 476134574, 990456848, 3870586239, 620062224, 1995301304, 2389729019, 2462435556, 2472017888, 2504467222, 2188233046, 2306262543, 3560836860, 1978513357, 606650471, 3854035095, 4245776466, 3871782493, 597013963, 2913535589, 1756877383, 1623386184, 496866040, 588131525, 4177340522, 2935519398, 850546169, 109126006, 858579993, 795071074, 3590755528, 1638945842, 1254896771, 2307161967, 3477967610, 692808483, 1899600971, 2769494850, 2779999007, 510215608, 4111810801, 4280921106, 1518864863, 834898192, 1306840028, 835688979, 459513373, 609148069, 3545906245, 1194162898, 258441420, 4179429231, 4022175812, 2648608732, 717531457, 306247172, 2976743141, 237988140, 592196390, 704362957, 7450542, 311689767, 3067438641, 808588439, 1235198799, 742758417, 2780949935, 3446973894, 3873685895, 2153250996, 4065466909, 1497342942, 3724631752, 4275880421, 272199235, 2252962393, 2056155089, 1877112438, 2110296751, 2590941302, 2013827341, 842823052, 1894032322, 1393960121, 3588305616, 2529793378, 3205436904, 596814183, 1254786875, 2966225931, 2121907724, 475901771, 555053869, 1400569825, 3688890119, 1353122628, 1459205986, 2300737010, 1481729865, 756275962, 599878589, 2046378230, 2769239562, 3613422680, 4252206344, 947287240, 2417250410, 1305902854, 2671572052, 1899272897, 3230713963, 2224033552, 1732656815, 1936213189, 443599218, 2770284889, 3459882565, 1332607404, 3278697474, 2224591365, 2257839008, 317213768, 2348128734, 1874716743, 456696413, 1116761881, 3309591553, 2862216319, 102982924, 2150630897, 3549437867, 2842597334, 4045646707, 302008588, 1654281218, 1727189467, 738749640, 3764450573, 3959398424, 3317500773, 3535585186, 3819002453, 3662808846, 1770928073, 3787632432, 2477652163, 1746899438, 3750614957, 761234677, 3689896539, 3520772878, 275147232, 337553768, 2986250282, 3506693941, 591880001, 1455806111, 3804664561, 2479592303, 4013399092, 726617980, 255395900, 178605610, 1293625725, 2646174944, 2376136796, 1142332741, 266983497, 999135041, 1646761432, 3287859002, 2993926759, 3462763364, 4227465400, 3355948815, 1238930500, 2174357849, 379232244, 2722994601, 2303499529, 974370129, 2577648277, 808892366, 3061424190, 4078456653, 1309753062, 2244949359, 371260896, 1136719967, 1708548536, 1658845169, 2768668945, 1531754301, 3191375913, 3511054906, 1062351502, 1744842885, 581610003, 1556303968, 867409273, 1667798951, 209299262, 437489551, 3229818234, 4190510574, 714265718, 2019898663, 564912797, 3376897769, 3559969466, 3938497845, 471145731, 349982265, 1552795612, 1447370340, 3534819495, 3702903169, 2270765654, 3192681951, 204048124, 3420549534, 3712568979, 3780914699, 2771260334, 2314887317, 745613193, 3577922416, 3094639701, 89439972, 1944758030, 1119237373, 586389801, 2520447175, 3625732341, 1427555151, 2087065690, 409476010, 187430597, 3753592208, 2951247134, 885626627, 844754868, 2092667268, 748208300, 3708044571, 3080007883, 3448755645, 3861156834, 1842920493, 818793362, 595126995, 2776784591, 62641362, 840212211, 3985028931, 387936511, 473253991, 1857032791, 4000006968, 4243356941, 1229644438, 153835145, 1283983784, 3918669237, 2721262354, 4294888081, 3113635576, 3580583332, 3521345732, 3528129211, 2061930144, 1634595536, 650979608, 2017813394, 1823871219, 3733362889, 3318357463, 2742200384, 763055433, 358005169, 768692017, 1050455834, 3453938424, 49062937, 3352742911, 3630652047, 2437806883, 1597018682, 2518758128, 1213648650, 4073021622, 4259149854, 3212854626, 79448901, 771981874, 3297404440, 3186097826, 319093164, 3890862606, 2245955576, 1497647520, 828363054, 3988483235, 3157718635, 562359205, 1350548803, 3372491415, 86446595, 1114830016, 2762338015, 1180274773, 2871068129, 2507166170, 2627076257, 1096162219, 2200646305, 3664591154, 3892273969, 476888795, 832867753, 4151853558, 2982123525, 182781907, 241410694, 1341125666, 4028887234, 3884607589, 2732864456, 1605421707, 2038450818, 3362242279, 4122980381, 2985487124, 110300201, 509696857, 4017443718, 1838466952, 2327354958, 2137521982, 138621377, 2133785874, 1413747039, 1739282333, 1675927427, 3185180235, 2373730108, 495353069, 1293977021, 1368037164, 2798684905, 81516419, 3857598893, 3495427721, 1760877692, 1152788660, 970145190, 3017785210, 1990290980, 67842654, 3925233768, 4002022873, 256245794, 1696909255, 738711838, 1446248938, 3868148475, 1939621088, 3658634988, 1784094744, 296610735, 1744996095, 2451305905, 2987883429, 2443943189, 1185159281, 4111976294, 2182983103, 3071467556, 1314957723, 152464689, 655243290, 1120785722, 1363764666, 3687005133, 4166824409, 4000596589, 2610383491, 3239976693, 4016033738, 1658070453, 1873771320, 1404781153, 3039196925, 3067316017, 2982654406, 373430827, 1067861532, 3675311637, 267701726, 3638688126, 2135521145, 1500865973, 1349464109, 3403519167, 2763769271, 4264496249, 1770838146, 1391852856, 2245935765, 2137032072, 1856750853, 688147071, 99608919, 3339156540, 2192512157, 2820609381, 2913853119, 2634309658, 1135799296, 1004115372, 2154400231, 248807841, 1855000157, 1129896866, 715899117, 923401956, 440983241, 3109426185, 477965580, 84854570, 1400593057, 1286402819, 3439792802, 318102008, 1067872774, 4240269308, 2549429311, 3855914333, 3795474048, 319956485, 4133374879, 1467447321, 78080933, 251366958, 3496186988, 3070794250, 2726275807, 2325947483, 53232900, 819143840, 2594863810, 55769358, 391403252, 866292794, 714727242, 2581427793, 698101228, 872117109, 997416829, 3565045060, 1611274371, 1562302439, 3417876422, 1392788396, 3206592320, 2207230324, 721665840, 2883852313, 1919434161, 395546903, 1837162976, 3429796583, 1522595013, 3348506059, 1175004850, 2074412352, 3758985171, 2415299591, 1703486181, 2304170834, 549422017, 2902550119, 2187986969, 3362538344, 768387705, 2456129962, 3429271129, 3706106775, 204072664, 806512945, 3311330988, 1474760185, 1814261184, 1577346575, 2078317992, 3122306338, 3467081158, 1225264382, 1847488518, 599526500, 2488492818, 1688305017, 438612976, 1607069832, 2927329539, 503082697, 1150118168, 1865053291, 2428734033, 4135113359, 821829536, 868562558, 413775071, 1334204977, 673408382, 1882374958, 4291776686, 4039700264, 911935763, 1482546356, 3253895898, 2457857766, 278986188, 1981961411, 1618768772, 658526763, 3374586925, 729845344, 1603013850, 1740519279, 2732939556, 821729772, 4152108407, 2159443430, 2359657320, 3741458889, 2625148830, 2740386969, 508381854, 4017968509, 3047356953, 3577840409, 4191139539, 592911320, 1632820963, 2552033020, 1212957884, 1014120341, 3310543727, 3779267868, 3583152261, 3680986715, 2995658876, 3137928161, 587029290, 2798799836, 4187466127, 594529572, 406159454, 4285083401, 2812893547, 841074247, 2953120157, 466989304, 2691320671, 3368132983, 3549982940, 3530994849, 3593311564, 2291771526, 3269554905, 4281627661, 1457258966, 3464984667, 1432592878, 1387812831, 154474021, 343300151, 1018325484, 2085104744, 4236418319, 2375494721, 1088217159, 389539389], "pos": 10}, "exits": {}, "cancels": {}, "failures": {}, "trials_closed": {}, "trial_progress": {"b20fd10b-c039-45fa-b450-86e9ad91ec28": 0, "c8810eb9-937a-4c71-86e6-cefbc9f1ba8e": 0}, "trials_created": {"b20fd10b-c039-45fa-b450-86e9ad91ec28": true, "c8810eb9-937a-4c71-86e6-cefbc9f1ba8e": true}, "trials_requested": 2, "search_method_state": {"pending_trials": 2, "remaining_trials": [{"dropout1": 0.25, "dropout2": 0.5, "n_filters1": 16, "n_filters2": 32, "learning_rate": 1}, {"dropout1": 0.25, "dropout2": 0.5, "n_filters1": 16, "n_filters2": 64, "learning_rate": 1}], "search_method_type": "grid"}}, "trial_searcher_state": {"b20fd10b-c039-45fa-b450-86e9ad91ec28": {"EarlyExitedByUserCode": false, "Create": {"hparams": {"dropout1": 0.25, "dropout2": 0.5, "n_filters1": 32, "n_filters2": 32, "learning_rate": 1}, "request_id": "b20fd10b-c039-45fa-b450-86e9ad91ec28", "trial_seed": 1367408042}, "EarlyStoppedBySearcher": false}, "c8810eb9-937a-4c71-86e6-cefbc9f1ba8e": {"EarlyExitedByUserCode": false, "Create": {"hparams": {"dropout1": 0.25, "dropout2": 0.5, "n_filters1": 32, "n_filters2": 64, "learning_rate": 1}, "request_id": "c8810eb9-937a-4c71-86e6-cefbc9f1ba8e", "trial_seed": 1545095049}, "EarlyStoppedBySearcher": false}}}`), + }, + { + name: searcher.RandomSearch, + //nolint:lll + v4Snapshot: []byte(`{"searcher_state": {"rand": {"key": [2076300081, 1180757701, 61390626, 2075653657, 3983276912, 2219280310, 1951783750, 3432627106, 1678661725, 3042313994, 2202725749, 2202588336, 4190101782, 1110057276, 3191526579, 4171973736, 4254763240, 2766827202, 1156928074, 1597039613, 3598380759, 1437190771, 1076126477, 2658152872, 3825090400, 3114761804, 915898666, 291106241, 3036411027, 2744580765, 820531638, 1883042039, 2922514564, 188829505, 1874867577, 640982739, 1714839269, 3901091477, 2804767841, 3254494107, 2484761704, 476134574, 990456848, 3870586239, 620062224, 1995301304, 2389729019, 2462435556, 2472017888, 2504467222, 2188233046, 2306262543, 3560836860, 1978513357, 606650471, 3854035095, 4245776466, 3871782493, 597013963, 2913535589, 1756877383, 1623386184, 496866040, 588131525, 4177340522, 2935519398, 850546169, 109126006, 858579993, 795071074, 3590755528, 1638945842, 1254896771, 2307161967, 3477967610, 692808483, 1899600971, 2769494850, 2779999007, 510215608, 4111810801, 4280921106, 1518864863, 834898192, 1306840028, 835688979, 459513373, 609148069, 3545906245, 1194162898, 258441420, 4179429231, 4022175812, 2648608732, 717531457, 306247172, 2976743141, 237988140, 592196390, 704362957, 7450542, 311689767, 3067438641, 808588439, 1235198799, 742758417, 2780949935, 3446973894, 3873685895, 2153250996, 4065466909, 1497342942, 3724631752, 4275880421, 272199235, 2252962393, 2056155089, 1877112438, 2110296751, 2590941302, 2013827341, 842823052, 1894032322, 1393960121, 3588305616, 2529793378, 3205436904, 596814183, 1254786875, 2966225931, 2121907724, 475901771, 555053869, 1400569825, 3688890119, 1353122628, 1459205986, 2300737010, 1481729865, 756275962, 599878589, 2046378230, 2769239562, 3613422680, 4252206344, 947287240, 2417250410, 1305902854, 2671572052, 1899272897, 3230713963, 2224033552, 1732656815, 1936213189, 443599218, 2770284889, 3459882565, 1332607404, 3278697474, 2224591365, 2257839008, 317213768, 2348128734, 1874716743, 456696413, 1116761881, 3309591553, 2862216319, 102982924, 2150630897, 3549437867, 2842597334, 4045646707, 302008588, 1654281218, 1727189467, 738749640, 3764450573, 3959398424, 3317500773, 3535585186, 3819002453, 3662808846, 1770928073, 3787632432, 2477652163, 1746899438, 3750614957, 761234677, 3689896539, 3520772878, 275147232, 337553768, 2986250282, 3506693941, 591880001, 1455806111, 3804664561, 2479592303, 4013399092, 726617980, 255395900, 178605610, 1293625725, 2646174944, 2376136796, 1142332741, 266983497, 999135041, 1646761432, 3287859002, 2993926759, 3462763364, 4227465400, 3355948815, 1238930500, 2174357849, 379232244, 2722994601, 2303499529, 974370129, 2577648277, 808892366, 3061424190, 4078456653, 1309753062, 2244949359, 371260896, 1136719967, 1708548536, 1658845169, 2768668945, 1531754301, 3191375913, 3511054906, 1062351502, 1744842885, 581610003, 1556303968, 867409273, 1667798951, 209299262, 437489551, 3229818234, 4190510574, 714265718, 2019898663, 564912797, 3376897769, 3559969466, 3938497845, 471145731, 349982265, 1552795612, 1447370340, 3534819495, 3702903169, 2270765654, 3192681951, 204048124, 3420549534, 3712568979, 3780914699, 2771260334, 2314887317, 745613193, 3577922416, 3094639701, 89439972, 1944758030, 1119237373, 586389801, 2520447175, 3625732341, 1427555151, 2087065690, 409476010, 187430597, 3753592208, 2951247134, 885626627, 844754868, 2092667268, 748208300, 3708044571, 3080007883, 3448755645, 3861156834, 1842920493, 818793362, 595126995, 2776784591, 62641362, 840212211, 3985028931, 387936511, 473253991, 1857032791, 4000006968, 4243356941, 1229644438, 153835145, 1283983784, 3918669237, 2721262354, 4294888081, 3113635576, 3580583332, 3521345732, 3528129211, 2061930144, 1634595536, 650979608, 2017813394, 1823871219, 3733362889, 3318357463, 2742200384, 763055433, 358005169, 768692017, 1050455834, 3453938424, 49062937, 3352742911, 3630652047, 2437806883, 1597018682, 2518758128, 1213648650, 4073021622, 4259149854, 3212854626, 79448901, 771981874, 3297404440, 3186097826, 319093164, 3890862606, 2245955576, 1497647520, 828363054, 3988483235, 3157718635, 562359205, 1350548803, 3372491415, 86446595, 1114830016, 2762338015, 1180274773, 2871068129, 2507166170, 2627076257, 1096162219, 2200646305, 3664591154, 3892273969, 476888795, 832867753, 4151853558, 2982123525, 182781907, 241410694, 1341125666, 4028887234, 3884607589, 2732864456, 1605421707, 2038450818, 3362242279, 4122980381, 2985487124, 110300201, 509696857, 4017443718, 1838466952, 2327354958, 2137521982, 138621377, 2133785874, 1413747039, 1739282333, 1675927427, 3185180235, 2373730108, 495353069, 1293977021, 1368037164, 2798684905, 81516419, 3857598893, 3495427721, 1760877692, 1152788660, 970145190, 3017785210, 1990290980, 67842654, 3925233768, 4002022873, 256245794, 1696909255, 738711838, 1446248938, 3868148475, 1939621088, 3658634988, 1784094744, 296610735, 1744996095, 2451305905, 2987883429, 2443943189, 1185159281, 4111976294, 2182983103, 3071467556, 1314957723, 152464689, 655243290, 1120785722, 1363764666, 3687005133, 4166824409, 4000596589, 2610383491, 3239976693, 4016033738, 1658070453, 1873771320, 1404781153, 3039196925, 3067316017, 2982654406, 373430827, 1067861532, 3675311637, 267701726, 3638688126, 2135521145, 1500865973, 1349464109, 3403519167, 2763769271, 4264496249, 1770838146, 1391852856, 2245935765, 2137032072, 1856750853, 688147071, 99608919, 3339156540, 2192512157, 2820609381, 2913853119, 2634309658, 1135799296, 1004115372, 2154400231, 248807841, 1855000157, 1129896866, 715899117, 923401956, 440983241, 3109426185, 477965580, 84854570, 1400593057, 1286402819, 3439792802, 318102008, 1067872774, 4240269308, 2549429311, 3855914333, 3795474048, 319956485, 4133374879, 1467447321, 78080933, 251366958, 3496186988, 3070794250, 2726275807, 2325947483, 53232900, 819143840, 2594863810, 55769358, 391403252, 866292794, 714727242, 2581427793, 698101228, 872117109, 997416829, 3565045060, 1611274371, 1562302439, 3417876422, 1392788396, 3206592320, 2207230324, 721665840, 2883852313, 1919434161, 395546903, 1837162976, 3429796583, 1522595013, 3348506059, 1175004850, 2074412352, 3758985171, 2415299591, 1703486181, 2304170834, 549422017, 2902550119, 2187986969, 3362538344, 768387705, 2456129962, 3429271129, 3706106775, 204072664, 806512945, 3311330988, 1474760185, 1814261184, 1577346575, 2078317992, 3122306338, 3467081158, 1225264382, 1847488518, 599526500, 2488492818, 1688305017, 438612976, 1607069832, 2927329539, 503082697, 1150118168, 1865053291, 2428734033, 4135113359, 821829536, 868562558, 413775071, 1334204977, 673408382, 1882374958, 4291776686, 4039700264, 911935763, 1482546356, 3253895898, 2457857766, 278986188, 1981961411, 1618768772, 658526763, 3374586925, 729845344, 1603013850, 1740519279, 2732939556, 821729772, 4152108407, 2159443430, 2359657320, 3741458889, 2625148830, 2740386969, 508381854, 4017968509, 3047356953, 3577840409, 4191139539, 592911320, 1632820963, 2552033020, 1212957884, 1014120341, 3310543727, 3779267868, 3583152261, 3680986715, 2995658876, 3137928161, 587029290, 2798799836, 4187466127, 594529572, 406159454, 4285083401, 2812893547, 841074247, 2953120157, 466989304, 2691320671, 3368132983, 3549982940, 3530994849, 3593311564, 2291771526, 3269554905, 4281627661, 1457258966, 3464984667, 1432592878, 1387812831, 154474021, 343300151, 1018325484, 2085104744, 4236418319, 2375494721, 1088217159, 389539389], "pos": 10}, "exits": {}, "cancels": {}, "failures": {}, "shutdown": false, "trials_closed": {}, "trial_progress": {"b20fd10b-c039-45fa-b450-86e9ad91ec28": 0, "c8810eb9-937a-4c71-86e6-cefbc9f1ba8e": 0}, "trials_created": {"b20fd10b-c039-45fa-b450-86e9ad91ec28": true, "c8810eb9-937a-4c71-86e6-cefbc9f1ba8e": true}, "trials_requested": 2, "search_method_state": {"pending_trials": 2, "remaining_trials": [{"dropout1": 0.25, "dropout2": 0.5, "n_filters1": 16, "n_filters2": 32, "learning_rate": 1}, {"dropout1": 0.25, "dropout2": 0.5, "n_filters1": 16, "n_filters2": 64, "learning_rate": 1}], "search_method_type": "grid"}, "completed_operations": {}}, "trial_searcher_state": {"b20fd10b-c039-45fa-b450-86e9ad91ec28": {"Op": {"Length": 1, "RequestID": "b20fd10b-c039-45fa-b450-86e9ad91ec28"}, "Closed": true, "Create": {"hparams": {"dropout1": 0.25, "dropout2": 0.5, "n_filters1": 32, "n_filters2": 32, "learning_rate": 1}, "checkpoint": null, "request_id": "b20fd10b-c039-45fa-b450-86e9ad91ec28", "trial_seed": 1367408042, "workload_sequencer_type": "TRIAL_WORKLOAD_SEQUENCER"}, "Complete": false}, "c8810eb9-937a-4c71-86e6-cefbc9f1ba8e": {"Op": {"Length": 1, "RequestID": "c8810eb9-937a-4c71-86e6-cefbc9f1ba8e"}, "Closed": true, "Create": {"hparams": {"dropout1": 0.25, "dropout2": 0.5, "n_filters1": 32, "n_filters2": 64, "learning_rate": 1}, "checkpoint": null, "request_id": "c8810eb9-937a-4c71-86e6-cefbc9f1ba8e", "trial_seed": 1545095049, "workload_sequencer_type": "TRIAL_WORKLOAD_SEQUENCER"}, "Complete": false}}}`), + //nolint:lll + v5Snapshot: []byte(`{"searcher_state": {"rand": {"key": [2076300081, 1180757701, 61390626, 2075653657, 3983276912, 2219280310, 1951783750, 3432627106, 1678661725, 3042313994, 2202725749, 2202588336, 4190101782, 1110057276, 3191526579, 4171973736, 4254763240, 2766827202, 1156928074, 1597039613, 3598380759, 1437190771, 1076126477, 2658152872, 3825090400, 3114761804, 915898666, 291106241, 3036411027, 2744580765, 820531638, 1883042039, 2922514564, 188829505, 1874867577, 640982739, 1714839269, 3901091477, 2804767841, 3254494107, 2484761704, 476134574, 990456848, 3870586239, 620062224, 1995301304, 2389729019, 2462435556, 2472017888, 2504467222, 2188233046, 2306262543, 3560836860, 1978513357, 606650471, 3854035095, 4245776466, 3871782493, 597013963, 2913535589, 1756877383, 1623386184, 496866040, 588131525, 4177340522, 2935519398, 850546169, 109126006, 858579993, 795071074, 3590755528, 1638945842, 1254896771, 2307161967, 3477967610, 692808483, 1899600971, 2769494850, 2779999007, 510215608, 4111810801, 4280921106, 1518864863, 834898192, 1306840028, 835688979, 459513373, 609148069, 3545906245, 1194162898, 258441420, 4179429231, 4022175812, 2648608732, 717531457, 306247172, 2976743141, 237988140, 592196390, 704362957, 7450542, 311689767, 3067438641, 808588439, 1235198799, 742758417, 2780949935, 3446973894, 3873685895, 2153250996, 4065466909, 1497342942, 3724631752, 4275880421, 272199235, 2252962393, 2056155089, 1877112438, 2110296751, 2590941302, 2013827341, 842823052, 1894032322, 1393960121, 3588305616, 2529793378, 3205436904, 596814183, 1254786875, 2966225931, 2121907724, 475901771, 555053869, 1400569825, 3688890119, 1353122628, 1459205986, 2300737010, 1481729865, 756275962, 599878589, 2046378230, 2769239562, 3613422680, 4252206344, 947287240, 2417250410, 1305902854, 2671572052, 1899272897, 3230713963, 2224033552, 1732656815, 1936213189, 443599218, 2770284889, 3459882565, 1332607404, 3278697474, 2224591365, 2257839008, 317213768, 2348128734, 1874716743, 456696413, 1116761881, 3309591553, 2862216319, 102982924, 2150630897, 3549437867, 2842597334, 4045646707, 302008588, 1654281218, 1727189467, 738749640, 3764450573, 3959398424, 3317500773, 3535585186, 3819002453, 3662808846, 1770928073, 3787632432, 2477652163, 1746899438, 3750614957, 761234677, 3689896539, 3520772878, 275147232, 337553768, 2986250282, 3506693941, 591880001, 1455806111, 3804664561, 2479592303, 4013399092, 726617980, 255395900, 178605610, 1293625725, 2646174944, 2376136796, 1142332741, 266983497, 999135041, 1646761432, 3287859002, 2993926759, 3462763364, 4227465400, 3355948815, 1238930500, 2174357849, 379232244, 2722994601, 2303499529, 974370129, 2577648277, 808892366, 3061424190, 4078456653, 1309753062, 2244949359, 371260896, 1136719967, 1708548536, 1658845169, 2768668945, 1531754301, 3191375913, 3511054906, 1062351502, 1744842885, 581610003, 1556303968, 867409273, 1667798951, 209299262, 437489551, 3229818234, 4190510574, 714265718, 2019898663, 564912797, 3376897769, 3559969466, 3938497845, 471145731, 349982265, 1552795612, 1447370340, 3534819495, 3702903169, 2270765654, 3192681951, 204048124, 3420549534, 3712568979, 3780914699, 2771260334, 2314887317, 745613193, 3577922416, 3094639701, 89439972, 1944758030, 1119237373, 586389801, 2520447175, 3625732341, 1427555151, 2087065690, 409476010, 187430597, 3753592208, 2951247134, 885626627, 844754868, 2092667268, 748208300, 3708044571, 3080007883, 3448755645, 3861156834, 1842920493, 818793362, 595126995, 2776784591, 62641362, 840212211, 3985028931, 387936511, 473253991, 1857032791, 4000006968, 4243356941, 1229644438, 153835145, 1283983784, 3918669237, 2721262354, 4294888081, 3113635576, 3580583332, 3521345732, 3528129211, 2061930144, 1634595536, 650979608, 2017813394, 1823871219, 3733362889, 3318357463, 2742200384, 763055433, 358005169, 768692017, 1050455834, 3453938424, 49062937, 3352742911, 3630652047, 2437806883, 1597018682, 2518758128, 1213648650, 4073021622, 4259149854, 3212854626, 79448901, 771981874, 3297404440, 3186097826, 319093164, 3890862606, 2245955576, 1497647520, 828363054, 3988483235, 3157718635, 562359205, 1350548803, 3372491415, 86446595, 1114830016, 2762338015, 1180274773, 2871068129, 2507166170, 2627076257, 1096162219, 2200646305, 3664591154, 3892273969, 476888795, 832867753, 4151853558, 2982123525, 182781907, 241410694, 1341125666, 4028887234, 3884607589, 2732864456, 1605421707, 2038450818, 3362242279, 4122980381, 2985487124, 110300201, 509696857, 4017443718, 1838466952, 2327354958, 2137521982, 138621377, 2133785874, 1413747039, 1739282333, 1675927427, 3185180235, 2373730108, 495353069, 1293977021, 1368037164, 2798684905, 81516419, 3857598893, 3495427721, 1760877692, 1152788660, 970145190, 3017785210, 1990290980, 67842654, 3925233768, 4002022873, 256245794, 1696909255, 738711838, 1446248938, 3868148475, 1939621088, 3658634988, 1784094744, 296610735, 1744996095, 2451305905, 2987883429, 2443943189, 1185159281, 4111976294, 2182983103, 3071467556, 1314957723, 152464689, 655243290, 1120785722, 1363764666, 3687005133, 4166824409, 4000596589, 2610383491, 3239976693, 4016033738, 1658070453, 1873771320, 1404781153, 3039196925, 3067316017, 2982654406, 373430827, 1067861532, 3675311637, 267701726, 3638688126, 2135521145, 1500865973, 1349464109, 3403519167, 2763769271, 4264496249, 1770838146, 1391852856, 2245935765, 2137032072, 1856750853, 688147071, 99608919, 3339156540, 2192512157, 2820609381, 2913853119, 2634309658, 1135799296, 1004115372, 2154400231, 248807841, 1855000157, 1129896866, 715899117, 923401956, 440983241, 3109426185, 477965580, 84854570, 1400593057, 1286402819, 3439792802, 318102008, 1067872774, 4240269308, 2549429311, 3855914333, 3795474048, 319956485, 4133374879, 1467447321, 78080933, 251366958, 3496186988, 3070794250, 2726275807, 2325947483, 53232900, 819143840, 2594863810, 55769358, 391403252, 866292794, 714727242, 2581427793, 698101228, 872117109, 997416829, 3565045060, 1611274371, 1562302439, 3417876422, 1392788396, 3206592320, 2207230324, 721665840, 2883852313, 1919434161, 395546903, 1837162976, 3429796583, 1522595013, 3348506059, 1175004850, 2074412352, 3758985171, 2415299591, 1703486181, 2304170834, 549422017, 2902550119, 2187986969, 3362538344, 768387705, 2456129962, 3429271129, 3706106775, 204072664, 806512945, 3311330988, 1474760185, 1814261184, 1577346575, 2078317992, 3122306338, 3467081158, 1225264382, 1847488518, 599526500, 2488492818, 1688305017, 438612976, 1607069832, 2927329539, 503082697, 1150118168, 1865053291, 2428734033, 4135113359, 821829536, 868562558, 413775071, 1334204977, 673408382, 1882374958, 4291776686, 4039700264, 911935763, 1482546356, 3253895898, 2457857766, 278986188, 1981961411, 1618768772, 658526763, 3374586925, 729845344, 1603013850, 1740519279, 2732939556, 821729772, 4152108407, 2159443430, 2359657320, 3741458889, 2625148830, 2740386969, 508381854, 4017968509, 3047356953, 3577840409, 4191139539, 592911320, 1632820963, 2552033020, 1212957884, 1014120341, 3310543727, 3779267868, 3583152261, 3680986715, 2995658876, 3137928161, 587029290, 2798799836, 4187466127, 594529572, 406159454, 4285083401, 2812893547, 841074247, 2953120157, 466989304, 2691320671, 3368132983, 3549982940, 3530994849, 3593311564, 2291771526, 3269554905, 4281627661, 1457258966, 3464984667, 1432592878, 1387812831, 154474021, 343300151, 1018325484, 2085104744, 4236418319, 2375494721, 1088217159, 389539389], "pos": 10}, "exits": {}, "cancels": {}, "failures": {}, "trials_closed": {}, "trial_progress": {"b20fd10b-c039-45fa-b450-86e9ad91ec28": 0, "c8810eb9-937a-4c71-86e6-cefbc9f1ba8e": 0}, "trials_created": {"b20fd10b-c039-45fa-b450-86e9ad91ec28": true, "c8810eb9-937a-4c71-86e6-cefbc9f1ba8e": true}, "trials_requested": 2, "search_method_state": {"pending_trials": 2, "remaining_trials": [{"dropout1": 0.25, "dropout2": 0.5, "n_filters1": 16, "n_filters2": 32, "learning_rate": 1}, {"dropout1": 0.25, "dropout2": 0.5, "n_filters1": 16, "n_filters2": 64, "learning_rate": 1}], "search_method_type": "grid"}}, "trial_searcher_state": {"b20fd10b-c039-45fa-b450-86e9ad91ec28": {"EarlyExitedByUserCode": false, "Create": {"hparams": {"dropout1": 0.25, "dropout2": 0.5, "n_filters1": 32, "n_filters2": 32, "learning_rate": 1}, "request_id": "b20fd10b-c039-45fa-b450-86e9ad91ec28", "trial_seed": 1367408042}, "EarlyStoppedBySearcher": false}, "c8810eb9-937a-4c71-86e6-cefbc9f1ba8e": {"EarlyExitedByUserCode": false, "Create": {"hparams": {"dropout1": 0.25, "dropout2": 0.5, "n_filters1": 32, "n_filters2": 64, "learning_rate": 1}, "request_id": "c8810eb9-937a-4c71-86e6-cefbc9f1ba8e", "trial_seed": 1545095049}, "EarlyStoppedBySearcher": false}}}`), + }, + } + for _, c := range cases { + t.Run(string(c.name), func(t *testing.T) { + snapshot, err := shimExperimentSnapshotV5(c.v4Snapshot) + if c.err != "" { + require.ErrorContains(t, err, c.err) + } else { + require.NoError(t, err) + require.JSONEq(t, string(c.v5Snapshot), string(snapshot)) + } + }) + } +} + func TestDeserExperimentSnapshotIntoCurrent(t *testing.T) { - // This test tries to deserialize a copy of a the current experiment snapshot. + // This test tries to deserialize a copy of the current experiment snapshot. // If this test fails, it means there was a breaking change to snapshots which may've not // received a shim. Please ensure there is a shim and fix this test for the next time. tests := []struct { @@ -36,19 +84,14 @@ func TestDeserExperimentSnapshotIntoCurrent(t *testing.T) { //nolint:exhaustruct single := expconf.SearcherConfig{ //nolint:exhaustruct - RawSingleConfig: &expconf.SingleConfig{ - RawMaxLength: &expconf.Length{ - Unit: expconf.Batches, - Units: 937, - }, - }, + RawSingleConfig: &expconf.SingleConfig{}, } sm := searcher.NewSearchMethod(single) e.searcher = searcher.NewSearcher(0, sm, expconf.Hyperparameters{}) return e }, //nolint:lll - snapshot: []byte(`{"searcher_state": {"rand": {"key": [2970094109, 656686882, 618108684, 3428065983, 3347811667, 2888225350, 3059306387, 3429410465, 120474970, 2272301777, 985108865, 991558874, 3272543769, 1573748485, 942809215, 3888215743, 2210951765, 1718115507, 3963921664, 3557444060, 1499923783, 1829703377, 750493200, 3411092685, 589595500, 2596144409, 1879328096, 1280550458, 2715466210, 2544141428, 1543312021, 2997818084, 765128503, 536629897, 1001278031, 956821445, 2736363088, 947342293, 3898485884, 3425929255, 356849665, 3772908806, 3623557158, 2612581302, 3580922597, 217692149, 2059015628, 2096894728, 1031554394, 2656443257, 856730792, 2596152358, 3710976744, 3544672276, 1734608805, 481137749, 969621572, 754645566, 4116998116, 79516852, 3779707014, 84449565, 2299776967, 3049981717, 209980308, 730595889, 475796649, 176283770, 945344397, 3533051341, 1815446302, 3773966474, 2608085762, 3988702430, 1364580973, 954047394, 3612132653, 772541827, 644272441, 189803165, 931375382, 345347493, 1070927534, 1233280688, 2977342360, 964860546, 311939975, 4227569059, 1111154497, 2670796852, 2936307495, 1340846895, 2187526044, 4049656992, 2523649560, 3366534695, 4129435744, 2210183795, 1260247727, 3172172609, 845427714, 2412303263, 4029524955, 496281340, 1521525374, 1068147028, 715985502, 4093783277, 2159512480, 133552717, 436985375, 3034348399, 1875037974, 4060219881, 2429216519, 1706838315, 3704003030, 1782549491, 3768045061, 3989593374, 93865362, 1767865857, 1137597591, 252302268, 3084248212, 619916972, 4054361685, 2046158286, 1812194877, 3286982519, 3624326839, 2208625006, 3161233673, 3006516503, 2024883981, 1656495788, 587227161, 3021529118, 3172394998, 3398906615, 615744671, 3472224112, 212954520, 1118562041, 1921307781, 1197366600, 1484195533, 1200554730, 729716373, 1652264122, 1007315030, 2434524607, 585522965, 1141433277, 278327678, 3497727325, 356102370, 1394639384, 1857470125, 735035053, 2157227949, 2119739569, 171894210, 839804861, 1047866560, 420010537, 1667233883, 352520083, 1757099499, 2384088069, 251603325, 4218299275, 3730818711, 1003235663, 299940149, 2645812204, 2480727500, 1622974487, 1629312940, 3982198086, 2956489407, 2625953102, 1319536795, 42594099, 1983096735, 973960638, 3621842015, 3277003271, 2353057652, 181048897, 971058611, 596055263, 916085615, 181371615, 1523808158, 1627674241, 1944135612, 1287953114, 427398813, 3657156653, 2038254583, 1728892774, 2979810975, 1203810426, 870115982, 2342890482, 239526972, 161374077, 1298609642, 3361627384, 1120507760, 2857852335, 1817591908, 2278524994, 2381282828, 390752450, 2583666345, 2894166573, 1090621184, 1216468230, 2471775878, 4007164892, 4120390730, 2245347208, 1747432849, 2659194233, 3123421709, 1394627981, 3642599365, 447037831, 2948358439, 2071590290, 620060883, 1189294322, 865221550, 1072574172, 4239383229, 3932654708, 977824870, 4090083517, 996564545, 2215368799, 1639088756, 1063294874, 2589869379, 2931882145, 1119478531, 2913543435, 2863960600, 2220278034, 1514566588, 2171894117, 1040890154, 613591648, 4213726618, 3733950164, 3557521554, 451572785, 1963683489, 2910238750, 1724590625, 2276900333, 2786477774, 194531157, 2427295372, 2668452276, 3024832162, 4118825874, 2048453723, 602839057, 1287626269, 391207769, 1490752784, 631791025, 3012581026, 2165558496, 711793093, 394391602, 3030359078, 753489044, 1912791312, 805796394, 3405161241, 2387156388, 3012369288, 785079410, 1579291678, 2186594177, 1466653116, 2597054161, 2032480878, 2608593331, 245623596, 4030340391, 4102274438, 1820069747, 861878589, 1987396926, 1234162528, 4203068184, 4247911765, 675763300, 776602741, 2725860139, 1101969549, 2417348663, 3640520171, 4117232477, 1498968755, 1465742592, 978609897, 207663660, 3872867123, 2345883885, 3859129626, 2085742042, 2849699660, 766128209, 1220321488, 4036832842, 2625520262, 1672441713, 68307500, 1664835655, 4277760737, 2834495524, 1659112114, 4058377674, 3503936154, 3460939589, 1534417391, 546891454, 2119405354, 1342659996, 4063240713, 2470735676, 4263599371, 1004146457, 4110714736, 161171862, 3755105332, 2045654769, 3373580076, 3804959282, 172165195, 1513507373, 756508282, 2395802730, 1042880862, 424954450, 3528754016, 705074650, 1657675167, 3512260741, 1502017722, 3914739402, 2193244645, 2422500613, 99094108, 2136256596, 2765564582, 1469408647, 422762387, 4256532504, 2587902689, 322034733, 869598360, 1591826706, 3537194108, 432319981, 3238387140, 3998122956, 1638675667, 2820632381, 2704736212, 1001829652, 3842914244, 2163956861, 1247332102, 4161816793, 1047179830, 3332934466, 3373357404, 1755466786, 2066857771, 3041960543, 1984594045, 2873866111, 625922821, 2084260334, 647387503, 1795267011, 2734649093, 1429118207, 1763491842, 2499417212, 1322753865, 2546278618, 3069808828, 2758040961, 2095345089, 1117531393, 1079149953, 199028389, 5727143, 2626507809, 1394644924, 2371859674, 3378009021, 663183580, 1964126253, 4253076758, 1951862122, 3318205484, 246506769, 2065283265, 4062893977, 1357669276, 1910189941, 1768248614, 2642642591, 3082965658, 2821926795, 3387815148, 3811042618, 47443245, 1219528119, 2554982004, 3265451419, 2888244446, 4219157977, 970880074, 1607338684, 3191412771, 2649941819, 3075671217, 1402753481, 4075173340, 1295557797, 911791588, 847594689, 2359505701, 3661496760, 733404895, 249867897, 1640333750, 582965187, 1381547711, 2657453162, 1011026771, 2180307770, 52701968, 2169192963, 3746362727, 509530176, 4097693696, 406641367, 2522117714, 2016171603, 2069773173, 3084060900, 1332397153, 1412022588, 1875059352, 3676296777, 3375212542, 1306176142, 133024327, 2608708004, 139194617, 519502757, 1353538952, 3608390217, 3046145872, 1399364698, 4125351745, 572217286, 2142672378, 2889903141, 3434412210, 2393320012, 3956475466, 562179437, 3542566562, 3807446953, 1580557399, 3614175239, 3960865405, 3703717660, 2310530500, 1372501162, 3987983520, 2273295537, 3041676882, 1184238826, 1347025024, 1978741180, 1085872423, 3721856454, 4017856803, 3452381717, 2476724512, 3844971453, 3783643616, 1806640004, 945069700, 146817685, 3591154906, 3316795159, 3055211927, 3538095801, 2794589177, 2749831736, 3448467389, 2625023521, 183144251, 1181880302, 2668701387, 1211953638, 2740205019, 1424192308, 3005227360, 3518579311, 2030094765, 2163439322, 528876425, 713989960, 49680970, 3843683734, 2284197334, 754118913, 3660862775, 2811591585, 1900142815, 909665247, 3587945731, 1738555018, 4017103008, 363980482, 3596395463, 4036688775, 562009149, 652191781, 3255386890, 742878000, 9948449, 1293599923, 3198724758, 1824688996, 971993200, 545911088, 3111974979, 548678661, 3540201872, 347949716, 3899497825, 3919059349, 4023520489, 3906867502, 2926287461, 1975024135, 1961216548, 4241655703, 882861834, 447828602, 1073917811, 3204969134, 3637120523, 2561353660, 3087651888, 2625085299, 3045978593, 3296945077, 3353240266, 264317791, 3618427144, 3041663648, 3129441565, 1614617882, 742204376, 107307162, 4015043056, 164246162, 2971078786, 3420066832, 1704293461, 459244372, 744914593, 710366705, 2599447851, 3652089701, 712803605, 2980013293, 476563708, 1665840194, 1037190002, 3225477896, 1430942116, 2632665368, 3730881422, 1625879865, 363179597, 3829523707, 14110766, 1150504976, 2891936857, 256611346, 774302498, 184441060, 803067320, 3354307461, 2067530913, 466995748, 1177075098, 4068208859, 1923629384, 3949860772, 3830641659, 1017128665, 2972991357, 4238922926, 575232160, 101928834, 2405116568], "pos": 5}, "exits": {}, "failures": {}, "shutdown": false, "trials_closed": {}, "trial_progress": {"bb239c3a-c6b5-4ef3-ba7b-118038bd0d06": 0}, "trials_created": {"bb239c3a-c6b5-4ef3-ba7b-118038bd0d06": true}, "trials_requested": 1, "search_method_state": {"created_trials": 1, "pending_trials": 1, "search_method_type": "single"}, "completed_operations": {}}, "trial_searcher_state": {"bb239c3a-c6b5-4ef3-ba7b-118038bd0d06": {"Op": {"Length": 937, "RequestID": "bb239c3a-c6b5-4ef3-ba7b-118038bd0d06"}, "Closed": true, "Create": {"hparams": {"dropout1": 0.25, "dropout2": 0.5, "n_filters1": 32, "n_filters2": 64, "learning_rate": 1, "global_batch_size": 64}, "checkpoint": null, "request_id": "bb239c3a-c6b5-4ef3-ba7b-118038bd0d06", "trial_seed": 1757369869, "workload_sequencer_type": "TRIAL_WORKLOAD_SEQUENCER"}, "Complete": false}}}`), + snapshot: []byte(`{"searcher_state": {"rand": {"key": [2458791418, 3116467974, 2738037530, 1223607127, 3286445278, 625897174, 1172106481, 2572993093, 2198533759, 3837319279, 89569234, 2840266316, 3281728377, 2462842803, 223147636, 1426784029, 117006397, 2636154877, 1449436706, 2927825553, 669826405, 4158003124, 3051998482, 3863884161, 4115545788, 3969262691, 1993859905, 526793478, 2862121258, 2585948009, 363338570, 2434066506, 2697236630, 3741159605, 4276159879, 2540043166, 1254547227, 281417924, 3524428006, 1717738180, 2788608558, 3067505898, 1428442263, 2529107117, 2141185930, 2917242155, 2718794995, 1103873205, 2335184097, 1489619463, 1859666767, 643801413, 3529330080, 1352766355, 2547570226, 2073329677, 2610277209, 435342246, 2505497031, 798656570, 3382074658, 4189489635, 28016400, 3296078941, 1945638788, 2826985424, 4265599268, 304154357, 932685664, 2758841683, 2455326991, 3980298044, 2314318323, 595095528, 4177520052, 3112160276, 1257280438, 1207886472, 2677553884, 3265570879, 2457548809, 2111634072, 3972389638, 665045651, 3994591990, 2042170295, 3594420958, 1611936557, 819891391, 3473558321, 426840388, 1314012383, 2561237332, 2272262342, 3134027931, 4294175632, 3732369287, 4128031240, 4147187676, 3149809049, 4144063892, 2345560497, 3738737588, 3840299650, 280313741, 101675464, 1393579054, 1505545589, 3277145020, 2188655472, 1421937881, 854289081, 2431554199, 1539816964, 83937993, 3224697403, 1610045703, 2816848158, 792619755, 934271931, 2710423316, 3868916984, 3853694682, 580166127, 3034428351, 1242892688, 859119565, 3802529696, 1000657147, 1905689709, 2988202998, 4019138435, 3698964560, 781687417, 2634708534, 1790387080, 3083588514, 2067517328, 3813894530, 1697102895, 1736661565, 1635070572, 2451931176, 2607824278, 2710884525, 3333385014, 2122182769, 2392051351, 2906588340, 4103192869, 1168352864, 1983143924, 710890467, 2699295676, 3119053614, 1518994382, 2466039509, 572476338, 4181949791, 3735635392, 742157348, 1157573401, 1972168146, 272316361, 1510876922, 1556419901, 941326744, 2682202561, 3173709824, 1829453176, 2251191850, 3456122487, 3173338821, 330462762, 1754886765, 250991513, 1585255942, 253192177, 3012909793, 3530990056, 2004128557, 514254595, 2445947804, 2613994303, 1602501425, 2012372011, 3584430451, 2144992891, 453151729, 418148689, 3200328899, 1320919208, 3350564767, 1251183192, 3290155726, 3287744578, 1064686988, 3961170913, 1980153684, 1339419052, 3783807526, 2183864424, 2594654516, 1832017181, 3373700328, 248263337, 4269409531, 3608097081, 2124157490, 3498209804, 3263200819, 3127724750, 1312779256, 1734613689, 3719011732, 3243076830, 335603414, 1448836682, 2863567879, 3171222479, 953744612, 1854108551, 396912697, 1471435093, 2536166045, 2432755996, 50855133, 1127984088, 2445499632, 682533110, 692746935, 1651664519, 2705150811, 4151653551, 1934542959, 3007387232, 2198726288, 3698559570, 3010880745, 3239971816, 168338881, 2794239687, 2901311416, 1552192572, 1131246424, 2788592201, 805537073, 2327010547, 1261208610, 3515399927, 3639687688, 151099323, 1666670090, 3955849057, 484700107, 3470434774, 2223492086, 90444150, 1572468487, 697895815, 2621785173, 4028029284, 1631608044, 1794520030, 1369705039, 4286067898, 1548006208, 1859332416, 3062951908, 343535845, 1027995635, 847936898, 1002776644, 1787482428, 1761964089, 1010518086, 2242475935, 3301831376, 3685964202, 2354379073, 3765363997, 3338328111, 2354888011, 3883114207, 936858297, 1841452303, 2180987975, 3833453400, 4253783174, 187816314, 1065449290, 828211291, 1227118119, 3312227455, 857392088, 3989703709, 2722876258, 2826056677, 2487224771, 1620253778, 4235150726, 1357697909, 2030894361, 3540876365, 2001105143, 203686501, 3114867920, 2041712100, 2614268670, 3094356533, 3358484761, 4080319699, 3980200285, 1516323641, 1169578821, 2585305656, 1638623928, 2071984054, 1179146259, 3142029384, 2910525849, 430984842, 1809790632, 1593515823, 2878348267, 4018175470, 1966255236, 1939592619, 2197525510, 4123890540, 4027942930, 2056710998, 1057194760, 3276845032, 618524839, 912386042, 1369305871, 4213981309, 4157357229, 1998049094, 4068346859, 283997778, 433297801, 2185564962, 2668307743, 3547263851, 2971812172, 145864931, 38002940, 2929551979, 2152716980, 2278120571, 1172785191, 1888308402, 930353176, 3711016582, 2308015795, 2305204206, 4068791645, 2473454079, 2501692185, 3674392733, 2375917712, 2557427355, 1990988406, 2226078296, 600963582, 347505141, 705969876, 3222111902, 237273550, 472225245, 1376144940, 3812100488, 4130033222, 1834636053, 2200817839, 3623132254, 2048566977, 872938564, 157673410, 2667298410, 1593933448, 2550146443, 3424352086, 3520633431, 3318249052, 1535053463, 1320135181, 1072090846, 3748951794, 3105511402, 2155614440, 3343066249, 1664809838, 2354003702, 577588884, 2764537545, 4260035710, 160815645, 3551374291, 2754442860, 1118743534, 1286742203, 3375424472, 1193358286, 1523416396, 2286745682, 3818424205, 3229692865, 1250603769, 3376938675, 1934656537, 3002173085, 1421535763, 2329166244, 1539829771, 3071809424, 3722361831, 1807071764, 2928747228, 3466233204, 1681608616, 1707233748, 532399055, 1655975870, 2192615935, 1181907517, 502315677, 195435504, 1068351185, 4253522108, 401515526, 2322634116, 3817476766, 2545684199, 367003362, 939004561, 1720493014, 3282529078, 916836710, 3271777100, 1051347323, 903520982, 1843390427, 903860254, 3163311743, 1498324049, 1602796155, 3993830274, 2222947632, 1594870809, 2262149984, 706505490, 1878448793, 297783572, 1922520975, 1809817994, 3378836142, 4283471362, 724529390, 3871255903, 2675377541, 625442236, 1396594762, 584050984, 3160502847, 3713812500, 3070779976, 2415860718, 2087253424, 2150659256, 2846376499, 642246206, 954003700, 790693801, 2852862948, 2069560311, 2716657131, 1777996881, 23021516, 394280992, 151827274, 554425914, 3093163480, 500362788, 4063471109, 1518232184, 3051194338, 1564237119, 1784240953, 1728907606, 2423906971, 3245861425, 1403272208, 705601793, 809561192, 2967681376, 1107800719, 3654973939, 3529963233, 1141430012, 171161121, 1222606508, 3325794587, 2933848069, 3319288855, 4076047601, 1422688234, 750895463, 723415161, 4110917996, 172378373, 614310717, 464190258, 2719742080, 2882190533, 260241577, 872187166, 51392255, 690678529, 1032247122, 177848862, 2631860912, 3876348701, 3971628423, 1411241180, 1525225113, 3429234786, 2902950094, 3480849833, 1067421076, 4184135694, 2905699039, 3922251469, 556665325, 1008854626, 1749376673, 1706526652, 1670818738, 1459937672, 1236710049, 2044381071, 28380869, 1088464802, 2500621248, 3319166637, 303350363, 4233871691, 462899686, 1864082083, 1633855571, 3234480516, 2411053733, 3986263103, 1992510755, 1986246923, 1418621100, 3772785621, 3118288321, 3069082678, 3575509950, 2661593874, 189268839, 1690530432, 1411957573, 4198485847, 1214217760, 4032838120, 647907711, 4245255521, 1940716058, 242619953, 2097141622, 1218025765, 1007790786, 3748834757, 8175890, 2456318442, 219038421, 4142868721, 2303296422, 465336410, 237135465, 2525403966, 2972194247, 672480967, 2623034704, 1120730289, 1018451191, 1318162784, 2171567834, 1931677349, 1300941830, 1091032044, 344455379, 3648398142, 1660289918, 1476172654, 1382274791, 376429449, 3753352837, 884844415, 3994051274, 3011131147, 3092988955, 1358346546, 3519843272, 3978487789, 3830946161, 1625675712, 3511881408, 2031293928, 2203163940, 3362955715, 4279271535, 4220311393, 2051555050, 996884877, 1530801618, 3353569995, 1741998127, 728272117, 724195433], "pos": 5}, "exits": {}, "cancels": {}, "failures": {}, "trials_closed": {}, "trial_progress": {"4244f6bb-a0b0-4876-9e75-819b7cbdcd95": 0}, "trials_created": {"4244f6bb-a0b0-4876-9e75-819b7cbdcd95": true}, "trials_requested": 1, "search_method_state": {"created_trials": 1, "pending_trials": 1, "search_method_type": "single"}}, "trial_searcher_state": {"4244f6bb-a0b0-4876-9e75-819b7cbdcd95": {"Closed": false, "Create": {"hparams": {"dropout1": 0.25, "dropout2": 0.5, "n_filters1": 32, "n_filters2": 64, "learning_rate": 1}, "request_id": "4244f6bb-a0b0-4876-9e75-819b7cbdcd95", "trial_seed": 1557607182}, "EarlyStoppedBySearcher": false}}}`), }, { name: "asha", @@ -58,22 +101,20 @@ func TestDeserExperimentSnapshotIntoCurrent(t *testing.T) { asha := expconf.SearcherConfig{ //nolint:exhaustruct RawAsyncHalvingConfig: &expconf.AsyncHalvingConfig{ - RawNumRungs: ptrs.Ptr(4), - RawStopOnce: ptrs.Ptr(false), - RawMaxLength: &expconf.Length{ - Unit: expconf.Batches, - Units: 937, - }, - RawDivisor: ptrs.Ptr[float64](4), + RawNumRungs: ptrs.Ptr(4), + RawMaxTime: ptrs.Ptr(937), + RawTimeMetric: ptrs.Ptr("batches"), + RawDivisor: ptrs.Ptr[float64](4), }, RawSmallerIsBetter: ptrs.Ptr(true), + RawMetric: ptrs.Ptr("loss"), } sm := searcher.NewSearchMethod(asha) e.searcher = searcher.NewSearcher(0, sm, expconf.Hyperparameters{}) return e }, //nolint:lll - snapshot: []byte(`{"searcher_state": {"rand": {"key": [1858898874, 3136713915, 1245759550, 1290970761, 517441830, 3965724506, 3061357240, 170072766, 1693351111, 4075410292, 78151388, 3493175442, 451004311, 1812236155, 355118118, 2341554586, 3606578038, 1939667144, 2365456278, 4240753950, 3356447601, 2335017104, 2040721097, 3800998272, 2728642800, 2131115629, 492734985, 2819783747, 3003215375, 4204037864, 4126695697, 549891058, 734125689, 31409235, 985387360, 3038862999, 954240667, 2621873052, 1848091170, 133174459, 3856584872, 124845556, 460434664, 46791082, 4152381650, 4173967752, 1247237436, 174835189, 1742324874, 2007241764, 1640712598, 2310183041, 3511885426, 578000042, 706952488, 3805905197, 1542555822, 1232482704, 113479587, 2075606872, 1810250676, 4154209381, 4104047634, 2798916177, 1973933529, 2923443846, 448405122, 3166808065, 2849599398, 468126651, 2661086816, 1991933754, 682920740, 2267700733, 4109317397, 1130169565, 3047640000, 3252889878, 3054398763, 2806904880, 2855243624, 3038966390, 1241065902, 2729169001, 815645434, 1162355224, 704271692, 536316416, 2670268274, 4163301069, 4085384831, 291065618, 1546534018, 1328639211, 2136729346, 1649093914, 1950870460, 3051014222, 4026077910, 2922618107, 2381221227, 1814305839, 2087985637, 4161848789, 128854472, 2943029177, 728532032, 3490598714, 669817953, 2440038690, 3555401139, 2712280152, 384719815, 2442045992, 613477876, 1377552755, 2711705298, 3724688321, 4067433634, 1545601253, 646274014, 1377898108, 240985165, 371546231, 3046184253, 2117921649, 2247989848, 2061520991, 956268219, 97033364, 175144327, 1353593441, 1476979375, 454979253, 762796608, 2575191972, 107577675, 1028450605, 1873422748, 2542366410, 1866320548, 3589705747, 345592247, 2892036283, 3170978743, 3572903695, 3028956090, 2207819198, 1950852740, 1313260327, 3971819579, 1504837033, 4070108073, 359514223, 1902603031, 3271335576, 3157069217, 2191931698, 1256929557, 421027589, 4212835066, 4062519531, 296513181, 2106428114, 549774238, 393306693, 1174412267, 3955316363, 2441986612, 1523312170, 982121345, 1785954095, 4080295649, 299131628, 393188011, 3374589678, 4024606822, 1686014117, 1197296932, 2797619434, 3649470375, 2716003058, 86300908, 3697562572, 3237922847, 1160822791, 3374522836, 3979525117, 2212718065, 1148293750, 1919573177, 3469185829, 1712413795, 380808588, 471126594, 948789273, 3532936578, 296952917, 2753628341, 2057043750, 920211988, 600572981, 2774709890, 1819324959, 2259319642, 574572799, 3073898101, 3230354882, 2678617999, 3641549837, 2132486047, 2765479196, 1727422873, 3997112925, 471821958, 92070175, 619429912, 1550462124, 2347682773, 1567881080, 1842852436, 3821419603, 3591165196, 2535124603, 3618924873, 2628843537, 1170033179, 348211080, 3822572799, 749527186, 1398111115, 727305496, 3543922887, 1648448468, 917342585, 3626786227, 795087139, 3407120437, 3355836138, 1410515106, 2720308312, 3240564741, 3959711847, 1124409684, 552560209, 3461944805, 678032564, 1103992569, 2470254354, 1641770510, 1556963167, 2379222670, 442642977, 2742252027, 1950349383, 2063177700, 2732950260, 4185258652, 2168049423, 1082372904, 1362620822, 1741310903, 901367838, 3508554445, 1348790639, 2317903621, 3373951142, 1337343925, 3840044638, 1485094307, 4248249804, 4093032023, 2136011, 1334492457, 2503063800, 2538115118, 2466171479, 1639514097, 464469616, 1764593097, 4213907472, 3313535300, 1076926430, 1725535044, 1011006056, 3988586498, 2745469741, 2906100946, 3576331813, 82093349, 429491095, 2114758070, 933184772, 3917920972, 1890858429, 531920684, 2456575097, 420691463, 748296319, 1079814478, 4024156541, 813238500, 1077678765, 2465759496, 3962103589, 2065972465, 3456128976, 2094509478, 2645396852, 2925520141, 1161978458, 2035040231, 4099469514, 1972456981, 3786725514, 3058419285, 3818057337, 2430215031, 1937931725, 2515478709, 1847103536, 4155632999, 1328666540, 4158737985, 2754971121, 2211399759, 2853435464, 3261493685, 4150017738, 3342340355, 775224247, 2961220606, 614027509, 3160064455, 3414222376, 2091154502, 280950518, 925950753, 566408845, 2596638294, 3135099147, 2706786311, 1388405449, 2816679388, 1623101484, 3178371694, 3291179795, 4116991785, 67859698, 845643000, 760347068, 1455100406, 2912358999, 2105127, 702301878, 1202736257, 1342813496, 575012872, 1865672624, 2657430022, 644622144, 2333997711, 1488359185, 4066857739, 3973656318, 2828413897, 3450795915, 1884785401, 1930824357, 1051770001, 3572092020, 1862280154, 3461196456, 3800146563, 1633638066, 1452346239, 534900790, 3121061430, 2916315663, 4255721426, 2887119093, 1871569903, 1332276432, 654698363, 3249099829, 863175434, 136189354, 3475430455, 2548971290, 2855268040, 3442424632, 804140083, 3514822970, 3941128355, 1313049612, 2892025674, 704113086, 2375636355, 3010471523, 1804040519, 682699732, 3588164492, 2899686231, 1835583680, 1773430431, 1098795657, 593519929, 125730791, 3274894119, 3824589328, 378493220, 1626526963, 3831253813, 3392887684, 1584863613, 3559778327, 2663823706, 3923032836, 3600949771, 1780734363, 3234463836, 3308378262, 1986979169, 2249473811, 2903716227, 3677580360, 1437072767, 3381143691, 3635683729, 3252548987, 1423467992, 1159796525, 126415660, 1073122400, 4213681404, 1155308855, 2805178822, 2953363327, 2150137973, 3240821181, 1698262113, 3278041312, 1513835831, 594252990, 2719832498, 2171241681, 3185876351, 1036634921, 1136285885, 4293927994, 2470722457, 4240063557, 2637317952, 161087848, 4147589514, 3912003562, 3965637509, 2793578416, 2291662622, 3749525828, 4282881950, 3280195163, 484269631, 2746175909, 3446976909, 2867607820, 395936285, 3321864120, 4273485048, 1086468822, 2838328381, 1644558158, 1834777629, 483686628, 1256155081, 2330382284, 2755082552, 4194834958, 391405988, 3209673909, 1870065414, 3052141802, 775165415, 4149272392, 194931648, 2367265709, 3891336490, 1229306547, 3012235584, 3643629589, 832189420, 469287166, 1525326108, 267712853, 410832606, 3637702001, 1293947797, 2676100704, 3127780282, 1522722013, 2987243003, 4223119695, 2471565775, 3741486091, 2517288835, 1421919913, 2170944291, 3677925623, 2433847083, 2055300628, 1274730272, 4194354236, 3445728648, 1457079160, 530533758, 2650792665, 3386458153, 2521874364, 1903579212, 761566114, 1632275601, 3215030020, 14846682, 3063448941, 726016154, 2228712970, 4195077512, 4196716184, 2434034522, 1909349859, 3301955567, 2777200877, 4292042460, 1199269432, 685902762, 2499240636, 1122335115, 439154079, 4168729086, 496834902, 2182051026, 4131884045, 1839561063, 39520120, 866177038, 827878713, 1764316880, 206451689, 2373512346, 2172390722, 3933252186, 3127637834, 2071541478, 3960855872, 1444619351, 1938511308, 3721638501, 1259236365, 4226881135, 610757816, 838627681, 1588715647, 1972104498, 2054289333, 419545885, 2474876715, 1254095866, 634279673, 3761069612, 1633198557, 2273078632, 3566317607, 1787529366, 1415431672, 3275737983, 2504574445, 430721643, 2335224241, 4010159314, 3152956030, 3456999186, 3382584293, 2802928282, 1378465088, 4251009334, 2588165208, 243375617, 3821815786, 4098223378, 4161610756, 1935733261, 2023780182, 3371959456, 2139155484, 2612371408, 338927672, 147433055, 1797581704, 3653488639, 2425792923, 888477626, 3057409903, 1694686933, 4004894142, 1992613109, 370443325, 227467078, 3544764567, 491152413, 2983604173, 2939234591, 157854636, 464880009, 2362768353, 2772766616, 932280116, 1106584051, 2213427382, 2869551831, 1522462666, 2481374380, 2796644796, 753754775, 1586605155], "pos": 112}, "exits": {}, "failures": {}, "shutdown": false, "trials_closed": {}, "trial_progress": {"3a77d780-1fb2-4af0-ae41-180c346cfcf8": 0}, "trials_created": {"3a77d780-1fb2-4af0-ae41-180c346cfcf8": true}, "trials_requested": 16, "search_method_state": {"trial_table": {"1648b3b2-74e9-439f-8e94-bfae278ea93d": 1, "19e57eb7-eb24-46fa-b740-21caab57a9d1": 1, "2bd1e7ff-5103-41df-9fa3-1ff4fc7acc14": 0, "306cd910-0605-4c85-84f3-5e6757ea1d01": 1, "3a77d780-1fb2-4af0-ae41-180c346cfcf8": 1, "45e6cca0-42a6-44be-8b58-88b20b00357e": 0, "65b30b87-e141-41fe-8816-7e1ee1a0ff36": 0, "69129b1e-843a-4879-8d8a-d0d02b696649": 0, "b383f52f-10be-4c55-adb7-2d0f75d6686d": 0, "beda8c08-e5b9-4861-b39c-c4e866297709": 0, "bf2b30d7-301f-48bb-8b35-a20b341291cb": 1, "cea43969-c629-4d82-830f-dcb47319e025": 1, "d8342b8c-3925-4ae8-b142-0b552b98223d": 1, "eee2da18-1bd9-449e-b247-906917d06878": 0, "fad2b882-e6cd-4f57-b285-613e3e699933": 0, "fafa3fd7-6bf5-447f-8ef8-f5001060b767": 1}, "sub_search_states": [{"rungs": [{"metrics": null, "start_trials": 0, "units_needed": 40, "promote_trials": 0, "outstanding_trials": 0}, {"metrics": null, "start_trials": 0, "units_needed": 200, "promote_trials": 0, "outstanding_trials": 0}, {"metrics": null, "start_trials": 0, "units_needed": 840, "promote_trials": 0, "outstanding_trials": 0}], "trial_rungs": {"2bd1e7ff-5103-41df-9fa3-1ff4fc7acc14": 0, "45e6cca0-42a6-44be-8b58-88b20b00357e": 0, "65b30b87-e141-41fe-8816-7e1ee1a0ff36": 0, "69129b1e-843a-4879-8d8a-d0d02b696649": 0, "b383f52f-10be-4c55-adb7-2d0f75d6686d": 0, "beda8c08-e5b9-4861-b39c-c4e866297709": 0, "eee2da18-1bd9-449e-b247-906917d06878": 0, "fad2b882-e6cd-4f57-b285-613e3e699933": 0}, "closed_trials": {}, "invalid_trials": 0, "pending_trials": 8, "trials_completed": 0, "early_exit_trials": {}, "search_method_type": "asha"}, {"rungs": [{"metrics": null, "start_trials": 0, "units_needed": 160, "promote_trials": 0, "outstanding_trials": 1}, {"metrics": null, "start_trials": 0, "units_needed": 800, "promote_trials": 0, "outstanding_trials": 0}], "trial_rungs": {"1648b3b2-74e9-439f-8e94-bfae278ea93d": 0, "19e57eb7-eb24-46fa-b740-21caab57a9d1": 0, "306cd910-0605-4c85-84f3-5e6757ea1d01": 0, "3a77d780-1fb2-4af0-ae41-180c346cfcf8": 0, "bf2b30d7-301f-48bb-8b35-a20b341291cb": 0, "cea43969-c629-4d82-830f-dcb47319e025": 0, "d8342b8c-3925-4ae8-b142-0b552b98223d": 0, "fafa3fd7-6bf5-447f-8ef8-f5001060b767": 0}, "closed_trials": {}, "invalid_trials": 0, "pending_trials": 8, "trials_completed": 0, "early_exit_trials": {}, "search_method_type": "asha"}], "search_method_type": "adaptive_asha", "sub_search_units_completed": [0, 0]}, "completed_operations": {}}, "trial_searcher_state": {"1648b3b2-74e9-439f-8e94-bfae278ea93d": {"Op": {"Length": 160, "RequestID": "1648b3b2-74e9-439f-8e94-bfae278ea93d"}, "Closed": false, "Create": {"hparams": {"metrics_base": 0.5750164976379073, "metrics_sigma": 0, "global_batch_size": 32, "metrics_progression": "decreasing"}, "checkpoint": null, "request_id": "1648b3b2-74e9-439f-8e94-bfae278ea93d", "trial_seed": 2143610236, "workload_sequencer_type": "TRIAL_WORKLOAD_SEQUENCER"}, "Complete": false}, "19e57eb7-eb24-46fa-b740-21caab57a9d1": {"Op": {"Length": 160, "RequestID": "19e57eb7-eb24-46fa-b740-21caab57a9d1"}, "Closed": false, "Create": {"hparams": {"metrics_base": 0.5816375181302468, "metrics_sigma": 0, "global_batch_size": 32, "metrics_progression": "decreasing"}, "checkpoint": null, "request_id": "19e57eb7-eb24-46fa-b740-21caab57a9d1", "trial_seed": 421951454, "workload_sequencer_type": "TRIAL_WORKLOAD_SEQUENCER"}, "Complete": false}, "2bd1e7ff-5103-41df-9fa3-1ff4fc7acc14": {"Op": {"Length": 40, "RequestID": "2bd1e7ff-5103-41df-9fa3-1ff4fc7acc14"}, "Closed": false, "Create": {"hparams": {"metrics_base": 0.7684118342109852, "metrics_sigma": 0, "global_batch_size": 32, "metrics_progression": "decreasing"}, "checkpoint": null, "request_id": "2bd1e7ff-5103-41df-9fa3-1ff4fc7acc14", "trial_seed": 212852662, "workload_sequencer_type": "TRIAL_WORKLOAD_SEQUENCER"}, "Complete": false}, "306cd910-0605-4c85-84f3-5e6757ea1d01": {"Op": {"Length": 160, "RequestID": "306cd910-0605-4c85-84f3-5e6757ea1d01"}, "Closed": false, "Create": {"hparams": {"metrics_base": 0.8331292167743722, "metrics_sigma": 0, "global_batch_size": 32, "metrics_progression": "decreasing"}, "checkpoint": null, "request_id": "306cd910-0605-4c85-84f3-5e6757ea1d01", "trial_seed": 1609725763, "workload_sequencer_type": "TRIAL_WORKLOAD_SEQUENCER"}, "Complete": false}, "3a77d780-1fb2-4af0-ae41-180c346cfcf8": {"Op": {"Length": 160, "RequestID": "3a77d780-1fb2-4af0-ae41-180c346cfcf8"}, "Closed": false, "Create": {"hparams": {"metrics_base": 0.8471874828144244, "metrics_sigma": 0, "global_batch_size": 32, "metrics_progression": "decreasing"}, "checkpoint": null, "request_id": "3a77d780-1fb2-4af0-ae41-180c346cfcf8", "trial_seed": 2105514700, "workload_sequencer_type": "TRIAL_WORKLOAD_SEQUENCER"}, "Complete": false}, "45e6cca0-42a6-44be-8b58-88b20b00357e": {"Op": {"Length": 40, "RequestID": "45e6cca0-42a6-44be-8b58-88b20b00357e"}, "Closed": false, "Create": {"hparams": {"metrics_base": 0.7217543249722611, "metrics_sigma": 0, "global_batch_size": 32, "metrics_progression": "decreasing"}, "checkpoint": null, "request_id": "45e6cca0-42a6-44be-8b58-88b20b00357e", "trial_seed": 1245324437, "workload_sequencer_type": "TRIAL_WORKLOAD_SEQUENCER"}, "Complete": false}, "65b30b87-e141-41fe-8816-7e1ee1a0ff36": {"Op": {"Length": 40, "RequestID": "65b30b87-e141-41fe-8816-7e1ee1a0ff36"}, "Closed": false, "Create": {"hparams": {"metrics_base": 0.8213712160318751, "metrics_sigma": 0, "global_batch_size": 32, "metrics_progression": "decreasing"}, "checkpoint": null, "request_id": "65b30b87-e141-41fe-8816-7e1ee1a0ff36", "trial_seed": 599927259, "workload_sequencer_type": "TRIAL_WORKLOAD_SEQUENCER"}, "Complete": false}, "69129b1e-843a-4879-8d8a-d0d02b696649": {"Op": {"Length": 40, "RequestID": "69129b1e-843a-4879-8d8a-d0d02b696649"}, "Closed": false, "Create": {"hparams": {"metrics_base": 0.8661962084783181, "metrics_sigma": 0, "global_batch_size": 32, "metrics_progression": "decreasing"}, "checkpoint": null, "request_id": "69129b1e-843a-4879-8d8a-d0d02b696649", "trial_seed": 1023174740, "workload_sequencer_type": "TRIAL_WORKLOAD_SEQUENCER"}, "Complete": false}, "b383f52f-10be-4c55-adb7-2d0f75d6686d": {"Op": {"Length": 40, "RequestID": "b383f52f-10be-4c55-adb7-2d0f75d6686d"}, "Closed": false, "Create": {"hparams": {"metrics_base": 0.5359870317307868, "metrics_sigma": 0, "global_batch_size": 32, "metrics_progression": "decreasing"}, "checkpoint": null, "request_id": "b383f52f-10be-4c55-adb7-2d0f75d6686d", "trial_seed": 1099784968, "workload_sequencer_type": "TRIAL_WORKLOAD_SEQUENCER"}, "Complete": false}, "beda8c08-e5b9-4861-b39c-c4e866297709": {"Op": {"Length": 40, "RequestID": "beda8c08-e5b9-4861-b39c-c4e866297709"}, "Closed": false, "Create": {"hparams": {"metrics_base": 0.7829296920408424, "metrics_sigma": 0, "global_batch_size": 32, "metrics_progression": "decreasing"}, "checkpoint": null, "request_id": "beda8c08-e5b9-4861-b39c-c4e866297709", "trial_seed": 1740151796, "workload_sequencer_type": "TRIAL_WORKLOAD_SEQUENCER"}, "Complete": false}, "bf2b30d7-301f-48bb-8b35-a20b341291cb": {"Op": {"Length": 160, "RequestID": "bf2b30d7-301f-48bb-8b35-a20b341291cb"}, "Closed": false, "Create": {"hparams": {"metrics_base": 0.6767256393511032, "metrics_sigma": 0, "global_batch_size": 32, "metrics_progression": "decreasing"}, "checkpoint": null, "request_id": "bf2b30d7-301f-48bb-8b35-a20b341291cb", "trial_seed": 981345028, "workload_sequencer_type": "TRIAL_WORKLOAD_SEQUENCER"}, "Complete": false}, "cea43969-c629-4d82-830f-dcb47319e025": {"Op": {"Length": 160, "RequestID": "cea43969-c629-4d82-830f-dcb47319e025"}, "Closed": false, "Create": {"hparams": {"metrics_base": 0.854520819444287, "metrics_sigma": 0, "global_batch_size": 32, "metrics_progression": "decreasing"}, "checkpoint": null, "request_id": "cea43969-c629-4d82-830f-dcb47319e025", "trial_seed": 575167432, "workload_sequencer_type": "TRIAL_WORKLOAD_SEQUENCER"}, "Complete": false}, "d8342b8c-3925-4ae8-b142-0b552b98223d": {"Op": {"Length": 160, "RequestID": "d8342b8c-3925-4ae8-b142-0b552b98223d"}, "Closed": false, "Create": {"hparams": {"metrics_base": 0.7231724660254734, "metrics_sigma": 0, "global_batch_size": 32, "metrics_progression": "decreasing"}, "checkpoint": null, "request_id": "d8342b8c-3925-4ae8-b142-0b552b98223d", "trial_seed": 2032674049, "workload_sequencer_type": "TRIAL_WORKLOAD_SEQUENCER"}, "Complete": false}, "eee2da18-1bd9-449e-b247-906917d06878": {"Op": {"Length": 40, "RequestID": "eee2da18-1bd9-449e-b247-906917d06878"}, "Closed": false, "Create": {"hparams": {"metrics_base": 0.8005169955058526, "metrics_sigma": 0, "global_batch_size": 32, "metrics_progression": "decreasing"}, "checkpoint": null, "request_id": "eee2da18-1bd9-449e-b247-906917d06878", "trial_seed": 1209630132, "workload_sequencer_type": "TRIAL_WORKLOAD_SEQUENCER"}, "Complete": false}, "fad2b882-e6cd-4f57-b285-613e3e699933": {"Op": {"Length": 40, "RequestID": "fad2b882-e6cd-4f57-b285-613e3e699933"}, "Closed": false, "Create": {"hparams": {"metrics_base": 0.6903862897512029, "metrics_sigma": 0, "global_batch_size": 32, "metrics_progression": "decreasing"}, "checkpoint": null, "request_id": "fad2b882-e6cd-4f57-b285-613e3e699933", "trial_seed": 86866247, "workload_sequencer_type": "TRIAL_WORKLOAD_SEQUENCER"}, "Complete": false}, "fafa3fd7-6bf5-447f-8ef8-f5001060b767": {"Op": {"Length": 160, "RequestID": "fafa3fd7-6bf5-447f-8ef8-f5001060b767"}, "Closed": false, "Create": {"hparams": {"metrics_base": 0.5785735939157665, "metrics_sigma": 0, "global_batch_size": 32, "metrics_progression": "decreasing"}, "checkpoint": null, "request_id": "fafa3fd7-6bf5-447f-8ef8-f5001060b767", "trial_seed": 1213847999, "workload_sequencer_type": "TRIAL_WORKLOAD_SEQUENCER"}, "Complete": false}}}`), + snapshot: []byte(`{"searcher_state": {"rand": {"key": [3912824393, 130359580, 1945038949, 683827302, 3610818264, 2838364317, 4181320722, 221936505, 3819130084, 2131487376, 3365651765, 920499941, 1270817540, 2246276083, 1885603276, 573420793, 583261746, 1955039589, 1811914457, 2837955835, 2508148103, 3642464710, 3458020788, 2195814166, 1386387121, 1852638509, 3909111177, 1735517492, 3892314499, 3707856689, 879481158, 1802646466, 116991403, 298788384, 1221574842, 3804792269, 2328567293, 3014632505, 2320044096, 1413171799, 1368251971, 4083772998, 2986640645, 1030832729, 2241679027, 2849293624, 1296881709, 2785559672, 1135156955, 2796434558, 1925391449, 1829990763, 3217631629, 3067008047, 3612907438, 1326801960, 2558425453, 4068998963, 3070803605, 2403866132, 3753353239, 3296906854, 464488764, 785445029, 141537657, 2465881066, 3135233249, 4043279941, 1278274609, 2486847651, 3954374223, 129103232, 1720962446, 3816310415, 1806398817, 1250007987, 969936705, 2336561042, 2031384508, 1285947737, 39769194, 1138144666, 993968588, 101542063, 3128650165, 2003415882, 140441857, 3432886469, 3489699300, 575264191, 3842163036, 2887600812, 2487983973, 1955108072, 1173170507, 4219701607, 1171912093, 2695607515, 1710258012, 2435659755, 2071371892, 3067907555, 544251540, 2010934296, 3117136901, 4036472942, 3059696740, 3788372660, 566955155, 358758290, 2882442037, 4169179923, 1160213813, 387790744, 594608752, 3632049166, 3841361487, 1667303281, 1460402258, 2438875135, 4278318681, 2281109706, 221829418, 3933060752, 4221766281, 3185559353, 70745549, 224504629, 2249432201, 156645497, 1758678503, 3375329816, 4241049874, 3873245768, 19688418, 1201216527, 3912743911, 216273652, 1882335310, 2852029249, 3415491706, 913228854, 2437734945, 3245489646, 1027437138, 1302726420, 1927993221, 118233091, 2085121918, 3536455915, 1939892139, 4010137323, 2303981927, 3716018075, 2265312706, 2949632039, 2447406304, 624960780, 4071396684, 3488406087, 3938075266, 757291956, 582212968, 3048403818, 2464054892, 4190155310, 292733183, 3094168255, 3626519501, 1066765321, 2996573857, 1225379640, 3774880867, 758941699, 116163606, 3942208179, 1042100701, 461769728, 3164917034, 1857854599, 2167218317, 4253533681, 3818830317, 3399752715, 2419846770, 4069369592, 2150336629, 1810772325, 3026418855, 4177626623, 4092551316, 3523400346, 4234812158, 713640060, 2681330789, 1467890933, 2993052514, 3162989269, 113784382, 3852222819, 4292344413, 1633710857, 3912716682, 6098202, 2978849649, 4160324727, 2926842096, 1514591911, 1081000496, 372322291, 3857309830, 4178509955, 669675724, 1711300651, 3716229730, 143240009, 4059287817, 3731622530, 496463408, 4088097703, 1794430305, 3973039871, 2939992759, 2712680772, 3534376933, 2786410924, 278436436, 3437770869, 2024532266, 2579067234, 4007664809, 664954588, 1789677620, 1276521848, 3683816415, 1548063198, 4095783552, 1459425342, 1799248342, 1609271699, 1512067210, 3465440954, 4269839847, 4233980673, 375224779, 3138146110, 3370149200, 579857917, 760314921, 3775887833, 4176789722, 2301024383, 1686052811, 2967171371, 4180069250, 1516946627, 705831480, 1611848722, 2806858519, 3504667135, 764796204, 4065414292, 2092577162, 1913641905, 2109472504, 1540233965, 1722162277, 3692095086, 2929243742, 3283756641, 1377766218, 699611861, 1212571847, 78849955, 698623272, 2796609094, 241440821, 426090509, 1958846217, 2444318395, 3236802781, 133530479, 1092024126, 4072615628, 2198341714, 282741196, 1398198553, 242935369, 2651474853, 677947580, 1708783244, 1906140963, 3673969094, 2279101198, 3916495828, 1917458424, 2415715383, 2959221265, 2938458868, 1115625347, 2902621834, 1038437194, 3503923682, 2953499342, 3582273746, 325962114, 707955326, 3027365402, 4165229536, 1899485827, 3813891089, 2681883220, 1184425310, 2323719917, 2457851599, 2520494097, 3105387047, 1535476439, 1564705883, 3867748520, 1195731676, 2921591976, 333612810, 1635420993, 4259291997, 1241352669, 3675691581, 352453891, 2248815017, 3137177282, 729048856, 3439156494, 4008821214, 4014395416, 440152211, 768552979, 3657002944, 2120438653, 4230634748, 1758565824, 782079134, 3695862927, 394672225, 1006358179, 183399528, 2883538266, 1469129531, 320543524, 892791693, 870595213, 3437448149, 808212320, 3050276488, 492820811, 4000264990, 3105300607, 133021706, 1603329269, 4075016600, 1820213035, 2998065090, 1632804974, 675976454, 1681042893, 1653548377, 3253586235, 1201835952, 2918720057, 1931586371, 3328503644, 403841940, 2742827784, 586210709, 3159629495, 246459628, 1721247097, 1312377931, 3239372286, 1679600923, 794123331, 1817944193, 3905102033, 3378552747, 2071157036, 2394010092, 2765024891, 966846871, 2805363146, 2459053889, 618201661, 1749046098, 568312106, 484146509, 541311246, 3718232983, 326564023, 284954267, 538365293, 1950605761, 943434222, 1820312121, 121100221, 2064278572, 2697100734, 102308777, 3114797650, 2729269456, 537270132, 1206836708, 2583134729, 4017452671, 1912421380, 1901938780, 2087021833, 3135764995, 1186775394, 26323543, 2703782091, 30908834, 901812067, 2118851464, 3029690482, 375825601, 2213919290, 2396825578, 3057698947, 2058071747, 1632052001, 544954234, 2757964379, 4281338914, 2452188298, 2530548058, 2140337096, 8135661, 1073358913, 1887116366, 1039964896, 3596599560, 3114419318, 1914007107, 3160591144, 554429067, 857043881, 3862692398, 2322654082, 4293541172, 703438619, 1995234082, 1658759338, 3008231258, 2598610424, 1965795485, 3566144875, 2921648686, 3345004424, 2078323309, 3324845404, 1698658304, 455090979, 361487935, 3873466677, 1296591922, 3439118724, 2030854673, 4221505583, 529693036, 1194187791, 2108178438, 3512528157, 3889045825, 1912433754, 3786020945, 3365110760, 4106573256, 3223175585, 1666849351, 1260427075, 2012436050, 1441286049, 3811351537, 3360856433, 260097501, 632249840, 3297189399, 4276078413, 200516346, 3427560446, 1180083550, 570419283, 3289677820, 557135629, 3537527213, 3760168670, 3350252912, 2643371920, 3193156919, 4182398732, 3395828475, 495148909, 3660199181, 3822738986, 2696324761, 1126852263, 2413756996, 2224145812, 763779427, 4236225820, 1512096030, 936890385, 3871755505, 3056164931, 1971858895, 2437295547, 1083615465, 2711940487, 1939973176, 611444976, 2577021608, 2690806091, 4001225725, 1295618973, 2739332287, 1543112445, 3309369740, 1403192430, 2634062471, 472690292, 2308009112, 2156438823, 568667603, 1171167051, 551717849, 2679282693, 1893880801, 1266209741, 2206656951, 3095140942, 3825923, 3559702505, 3008679460, 554453022, 395553371, 1655473374, 828559601, 4044140613, 1728590401, 3047261812, 1003908978, 2447662187, 1442086653, 1489454960, 2524559089, 2050535748, 3390209316, 1931858035, 720671506, 3793532296, 2486505886, 3776873526, 1029777346, 3472059923, 2093494001, 4001189252, 1295318492, 256747214, 2561538344, 820561233, 1492703895, 4281534291, 2679759074, 1296226481, 2068854739, 4204858459, 3826654958, 814564580, 2703682783, 511237125, 1886845754, 4054856321, 2444144699, 95504058, 685255946, 4196035992, 4009208239, 2844476321, 3882335574, 1524669498, 3071049109, 3939141547, 53659124, 2986167924, 4062365550, 3944273163, 858877051, 3811779921, 2800310021, 3471182814, 3978499914, 184095907, 3143372954, 978190594, 428108114, 3979094339, 504541629, 3012747542, 1929522274, 641465681, 389762475, 3863437281, 131374935, 695264856, 934006733, 4014903973, 2845408941, 2293374102, 3231336512, 516967894, 51360470, 3991740459, 2740043234, 2010491456, 3737653932, 4251501920], "pos": 142}, "exits": {}, "cancels": {}, "failures": {}, "trials_closed": {"1241288f-ef58-4d42-9287-31cfa20244c5": true, "5d32d462-b831-4437-84d5-9bf8b820e09e": true, "5d804239-7e15-49e3-b710-fd7c24509d44": true, "763e2853-d3ec-4ea1-a325-050f5aef5b1a": true, "872fb3e1-39b9-40f8-b5ce-12ca53c225ed": true, "98a8c8cd-2aaf-4fbe-a358-7735337e11cd": true, "a3e23956-04ee-489d-9135-7186a7941d28": true, "b6f8eca1-9c4a-4f9d-9730-ded009bf51d8": true, "cea8e3d2-d430-4c69-a8c8-e7d2c35dc5ec": true, "d882fdea-d2be-453d-bb93-831af2dc317b": true}, "trial_progress": {"1241288f-ef58-4d42-9287-31cfa20244c5": 1, "5d32d462-b831-4437-84d5-9bf8b820e09e": 0.3333333333333333, "5d804239-7e15-49e3-b710-fd7c24509d44": 1, "763e2853-d3ec-4ea1-a325-050f5aef5b1a": 1, "872fb3e1-39b9-40f8-b5ce-12ca53c225ed": 0.3333333333333333, "98a8c8cd-2aaf-4fbe-a358-7735337e11cd": 1, "a3e23956-04ee-489d-9135-7186a7941d28": 0.3333333333333333, "b6f8eca1-9c4a-4f9d-9730-ded009bf51d8": 1, "cea8e3d2-d430-4c69-a8c8-e7d2c35dc5ec": 0.3333333333333333, "d882fdea-d2be-453d-bb93-831af2dc317b": 1}, "trials_created": {"1241288f-ef58-4d42-9287-31cfa20244c5": true, "5d32d462-b831-4437-84d5-9bf8b820e09e": true, "5d804239-7e15-49e3-b710-fd7c24509d44": true, "763e2853-d3ec-4ea1-a325-050f5aef5b1a": true, "872fb3e1-39b9-40f8-b5ce-12ca53c225ed": true, "98a8c8cd-2aaf-4fbe-a358-7735337e11cd": true, "a3e23956-04ee-489d-9135-7186a7941d28": true, "b6f8eca1-9c4a-4f9d-9730-ded009bf51d8": true, "cea8e3d2-d430-4c69-a8c8-e7d2c35dc5ec": true, "d882fdea-d2be-453d-bb93-831af2dc317b": true}, "trials_requested": 10, "search_method_state": {"trial_table": {"1241288f-ef58-4d42-9287-31cfa20244c5": 1, "5d32d462-b831-4437-84d5-9bf8b820e09e": 0, "5d804239-7e15-49e3-b710-fd7c24509d44": 1, "763e2853-d3ec-4ea1-a325-050f5aef5b1a": 1, "872fb3e1-39b9-40f8-b5ce-12ca53c225ed": 0, "98a8c8cd-2aaf-4fbe-a358-7735337e11cd": 0, "a3e23956-04ee-489d-9135-7186a7941d28": 0, "b6f8eca1-9c4a-4f9d-9730-ded009bf51d8": 0, "cea8e3d2-d430-4c69-a8c8-e7d2c35dc5ec": 0, "d882fdea-d2be-453d-bb93-831af2dc317b": 0}, "sub_search_states": [{"rungs": [{"metrics": [{"metric": 0.09292137346198563, "request_id": "d882fdea-d2be-453d-bb93-831af2dc317b"}, {"metric": 0.14394278251002454, "request_id": "cea8e3d2-d430-4c69-a8c8-e7d2c35dc5ec"}, {"metric": 0.169617546004522, "request_id": "98a8c8cd-2aaf-4fbe-a358-7735337e11cd"}, {"metric": 0.17236693170526698, "request_id": "a3e23956-04ee-489d-9135-7186a7941d28"}, {"metric": 0.1814434532918463, "request_id": "b6f8eca1-9c4a-4f9d-9730-ded009bf51d8"}, {"metric": 0.19748935120026018, "request_id": "5d32d462-b831-4437-84d5-9bf8b820e09e"}, {"metric": 0.21989905646153887, "request_id": "872fb3e1-39b9-40f8-b5ce-12ca53c225ed"}], "units_needed": 225}, {"metrics": [{"metric": 0.05633601378523418, "request_id": "d882fdea-d2be-453d-bb93-831af2dc317b"}, {"metric": 0.09834996045635314, "request_id": "b6f8eca1-9c4a-4f9d-9730-ded009bf51d8"}, {"metric": 0.10896268058612729, "request_id": "98a8c8cd-2aaf-4fbe-a358-7735337e11cd"}], "units_needed": 900}], "trial_rungs": {"5d32d462-b831-4437-84d5-9bf8b820e09e": 0, "872fb3e1-39b9-40f8-b5ce-12ca53c225ed": 0, "98a8c8cd-2aaf-4fbe-a358-7735337e11cd": 1, "a3e23956-04ee-489d-9135-7186a7941d28": 0, "b6f8eca1-9c4a-4f9d-9730-ded009bf51d8": 1, "cea8e3d2-d430-4c69-a8c8-e7d2c35dc5ec": 0, "d882fdea-d2be-453d-bb93-831af2dc317b": 1}, "invalid_trials": 0, "trials_completed": 7, "early_exit_trials": {}, "search_method_type": "asha"}, {"rungs": [{"metrics": [{"metric": 0.05686319968148741, "request_id": "763e2853-d3ec-4ea1-a325-050f5aef5b1a"}, {"metric": 0.09348985412422045, "request_id": "1241288f-ef58-4d42-9287-31cfa20244c5"}, {"metric": 0.21061281618442695, "request_id": "5d804239-7e15-49e3-b710-fd7c24509d44"}], "units_needed": 900}], "trial_rungs": {"1241288f-ef58-4d42-9287-31cfa20244c5": 0, "5d804239-7e15-49e3-b710-fd7c24509d44": 0, "763e2853-d3ec-4ea1-a325-050f5aef5b1a": 0}, "invalid_trials": 0, "trials_completed": 3, "early_exit_trials": {}, "search_method_type": "asha"}], "search_method_type": "adaptive_asha"}}, "trial_searcher_state": {"1241288f-ef58-4d42-9287-31cfa20244c5": {"Closed": false, "Create": {"hparams": {"dropout1": 0.7856465654579918, "dropout2": 0.33680938844866426, "n_filters1": 28, "n_filters2": 40, "learning_rate": 0.33873975920890426}, "request_id": "1241288f-ef58-4d42-9287-31cfa20244c5", "trial_seed": 1053878338}, "EarlyStoppedBySearcher": true}, "5d32d462-b831-4437-84d5-9bf8b820e09e": {"Closed": false, "Create": {"hparams": {"dropout1": 0.3005242490944792, "dropout2": 0.39344632915675276, "n_filters1": 44, "n_filters2": 9, "learning_rate": 0.22728733457505207}, "request_id": "5d32d462-b831-4437-84d5-9bf8b820e09e", "trial_seed": 40569440}, "EarlyStoppedBySearcher": true}, "5d804239-7e15-49e3-b710-fd7c24509d44": {"Closed": false, "Create": {"hparams": {"dropout1": 0.7197975421763473, "dropout2": 0.718622666277326, "n_filters1": 62, "n_filters2": 9, "learning_rate": 0.3138344243067023}, "request_id": "5d804239-7e15-49e3-b710-fd7c24509d44", "trial_seed": 821874082}, "EarlyStoppedBySearcher": true}, "763e2853-d3ec-4ea1-a325-050f5aef5b1a": {"Closed": false, "Create": {"hparams": {"dropout1": 0.3797289112689676, "dropout2": 0.4162215587922132, "n_filters1": 12, "n_filters2": 42, "learning_rate": 0.8267899122382265}, "request_id": "763e2853-d3ec-4ea1-a325-050f5aef5b1a", "trial_seed": 905057022}, "EarlyStoppedBySearcher": true}, "872fb3e1-39b9-40f8-b5ce-12ca53c225ed": {"Closed": false, "Create": {"hparams": {"dropout1": 0.5123726165661315, "dropout2": 0.6988116676979639, "n_filters1": 57, "n_filters2": 41, "learning_rate": 0.14624646818760453}, "request_id": "872fb3e1-39b9-40f8-b5ce-12ca53c225ed", "trial_seed": 1848480811}, "EarlyStoppedBySearcher": true}, "98a8c8cd-2aaf-4fbe-a358-7735337e11cd": {"Closed": false, "Create": {"hparams": {"dropout1": 0.7577371982482018, "dropout2": 0.757670158564443, "n_filters1": 35, "n_filters2": 50, "learning_rate": 0.5384580326825059}, "request_id": "98a8c8cd-2aaf-4fbe-a358-7735337e11cd", "trial_seed": 1457684159}, "EarlyStoppedBySearcher": true}, "a3e23956-04ee-489d-9135-7186a7941d28": {"Closed": false, "Create": {"hparams": {"dropout1": 0.6717156035663113, "dropout2": 0.6826566058616935, "n_filters1": 34, "n_filters2": 38, "learning_rate": 0.36435611193497525}, "request_id": "a3e23956-04ee-489d-9135-7186a7941d28", "trial_seed": 283271754}, "EarlyStoppedBySearcher": true}, "b6f8eca1-9c4a-4f9d-9730-ded009bf51d8": {"Closed": false, "Create": {"hparams": {"dropout1": 0.5394040074405178, "dropout2": 0.5360441307525875, "n_filters1": 22, "n_filters2": 20, "learning_rate": 0.30396393450440096}, "request_id": "b6f8eca1-9c4a-4f9d-9730-ded009bf51d8", "trial_seed": 1659377654}, "EarlyStoppedBySearcher": true}, "cea8e3d2-d430-4c69-a8c8-e7d2c35dc5ec": {"Closed": false, "Create": {"hparams": {"dropout1": 0.2799191061847552, "dropout2": 0.394202279713878, "n_filters1": 40, "n_filters2": 65, "learning_rate": 0.14238235074491137}, "request_id": "cea8e3d2-d430-4c69-a8c8-e7d2c35dc5ec", "trial_seed": 229268612}, "EarlyStoppedBySearcher": true}, "d882fdea-d2be-453d-bb93-831af2dc317b": {"Closed": false, "Create": {"hparams": {"dropout1": 0.39100440701610784, "dropout2": 0.49850349318825615, "n_filters1": 59, "n_filters2": 71, "learning_rate": 0.33813778901495695}, "request_id": "d882fdea-d2be-453d-bb93-831af2dc317b", "trial_seed": 491875076}, "EarlyStoppedBySearcher": true}}}`), }, } for _, tt := range tests { diff --git a/master/internal/telemetry/telemetry_test.go b/master/internal/telemetry/telemetry_test.go index 3351b178717..552ed2a2f1b 100644 --- a/master/internal/telemetry/telemetry_test.go +++ b/master/internal/telemetry/telemetry_test.go @@ -138,14 +138,11 @@ func initMockedTelemetry(t *testing.T) (*mockClient, *mocks.ResourceManager) { // Helper function for ReportExperimentCreated. func createExpConfig() expconf.ExperimentConfig { - maxLength := expconf.NewLengthInBatches(100) //nolint:exhaustruct activeConfig := expconf.ExperimentConfig{ RawSearcher: &expconf.SearcherConfig{ - RawMetric: ptrs.Ptr("loss"), - RawSingleConfig: &expconf.SingleConfig{ - RawMaxLength: &maxLength, - }, + RawMetric: ptrs.Ptr("loss"), + RawSingleConfig: &expconf.SingleConfig{}, }, RawEntrypoint: &expconf.Entrypoint{RawEntrypoint: "model_def:SomeTrialClass"}, RawHyperparameters: expconf.Hyperparameters{}, diff --git a/master/internal/templates/service_intg_test.go b/master/internal/templates/service_intg_test.go index 137bd03736f..499f2e8728c 100644 --- a/master/internal/templates/service_intg_test.go +++ b/master/internal/templates/service_intg_test.go @@ -70,13 +70,8 @@ func TestUnmarshalTemplateConfig(t *testing.T) { }, }, RawSearcher: &expconf.SearcherConfigV0{ - RawSingleConfig: &expconf.SingleConfigV0{ - RawMaxLength: &expconf.LengthV0{ - Unit: expconf.Batches, - Units: 1, - }, - }, - RawMetric: ptrs.Ptr("loss_of_something"), + RawSingleConfig: &expconf.SingleConfigV0{}, + RawMetric: ptrs.Ptr("loss_of_something"), }, }) err = UnmarshalTemplateConfig(ctx, input.Name, u, &fakeConfig, false) diff --git a/master/internal/trial.go b/master/internal/trial.go index 90711bd73d5..ec29205bef8 100644 --- a/master/internal/trial.go +++ b/master/internal/trial.go @@ -88,7 +88,7 @@ type trial struct { // restarts is a failure count, it increments when the trial fails and we retry it. restarts int // runID is a count of how many times the task container(s) have stopped and restarted, which - // could be due to a failure or due to normal pausing and continuing. When RunID increments, + // could be due to a failure or due to normal pausing and continuing. When TrialID increments, // it effectively invalidates many outstanding messages associated with the previous run. runID int @@ -230,15 +230,21 @@ func (t *trial) PatchSearcherState(req experiment.TrialSearcherState) error { t.searcher = req switch { - case !t.searcher.Complete: - return t.maybeAllocateTask() - case t.searcher.Complete && t.searcher.Closed: + case t.searcher.EarlyStoppedBySearcher: + return t.patchState( + model.StateWithReason{ + State: model.StoppingCompletedState, + InformationalReason: "searcher decided to early stop trial", + }, + ) + case t.searcher.EarlyExitedByUserCode: return t.patchState(model.StateWithReason{ - State: model.StoppingCompletedState, - InformationalReason: "hp search is finished", + State: model.StoppingCanceledState, + InformationalReason: "trial received early exit signal", }) + default: + return t.maybeAllocateTask() } - return nil } func (t *trial) PatchRP(rp string) { @@ -366,19 +372,20 @@ func (t *trial) maybeAllocateTask() error { // Only allocate for active trials, or trials that have been restored and are stopping. // We need to allocate for stopping because we need to reattach the allocation. shouldAllocateState := t.state == model.ActiveState || (t.restored && model.StoppingStates[t.state]) - if t.allocationID != nil || t.searcher.Complete || !shouldAllocateState { + searcherStop := t.searcher.EarlyExitedByUserCode || t.searcher.EarlyStoppedBySearcher + if t.allocationID != nil || searcherStop || !shouldAllocateState { t.syslog.WithFields(logrus.Fields{ - "allocation-id": t.allocationID, - "sercher-complete": t.searcher.Complete, - "trial-state": t.state, - "restored": t.restored, + "allocation-id": t.allocationID, + "trial-early-exited": t.searcher.EarlyExitedByUserCode, + "searcher-early-stopped": t.searcher.EarlyStoppedBySearcher, + "trial-state": t.state, + "restored": t.restored, }).Trace("decided not to allocate trial") return nil } name := fmt.Sprintf("Trial %d (Experiment %d)", t.id, t.experimentID) t.syslog.Info("decided to allocate trial") - blockedNodes, err := logpattern.GetBlockedNodes(context.TODO(), t.taskID) if err != nil { return err @@ -581,7 +588,7 @@ func (t *trial) handleAllocationExit(exit *task.AllocationExited) error { State: model.StoppingToTerminalStates[t.state], InformationalReason: "trial stopped", }) - case t.searcher.Complete && t.searcher.Closed: + case t.searcher.EarlyStoppedBySearcher: if exit.Err != nil { return t.transition(model.StateWithReason{ State: model.ErrorState, diff --git a/master/internal/trial_intg_test.go b/master/internal/trial_intg_test.go index 5681a1065ab..5a7fa0792d5 100644 --- a/master/internal/trial_intg_test.go +++ b/master/internal/trial_intg_test.go @@ -6,7 +6,6 @@ package internal import ( "context" - "crypto/rand" "fmt" "testing" "time" @@ -32,30 +31,22 @@ import ( ) func TestTrial(t *testing.T) { - _, rID, tr, alloc, done := setup(t) - + _, tr, alloc, done := setup(t) + // xxx: fix this test // Pre-scheduled stage. require.NoError(t, tr.PatchState( model.StateWithReason{State: model.ActiveState})) require.NoError(t, tr.PatchSearcherState(experiment.TrialSearcherState{ - Create: searcher.Create{RequestID: rID}, - Op: searcher.ValidateAfter{ - RequestID: rID, - Length: 10, - }, - Complete: false, - Closed: true, + Create: searcher.Create{}, + EarlyStoppedBySearcher: false, + EarlyExitedByUserCode: false, })) // Running stage. require.NoError(t, tr.PatchSearcherState(experiment.TrialSearcherState{ - Create: searcher.Create{RequestID: rID}, - Op: searcher.ValidateAfter{ - RequestID: rID, - Length: 10, - }, - Complete: true, - Closed: true, + Create: searcher.Create{}, + EarlyStoppedBySearcher: true, + EarlyExitedByUserCode: false, })) require.True(t, alloc.AssertExpectations(t)) require.NotNil(t, tr.allocationID) @@ -79,18 +70,14 @@ func TestTrial(t *testing.T) { } func TestTrialRestarts(t *testing.T) { - pgDB, rID, tr, _, done := setup(t) + pgDB, tr, _, done := setup(t) // Pre-scheduled stage. require.NoError(t, tr.PatchState( model.StateWithReason{State: model.ActiveState})) require.NoError(t, tr.PatchSearcherState(experiment.TrialSearcherState{ - Create: searcher.Create{RequestID: rID}, - Op: searcher.ValidateAfter{ - RequestID: rID, - Length: 10, - }, - Complete: false, - Closed: true, + Create: searcher.Create{}, + EarlyStoppedBySearcher: false, + EarlyExitedByUserCode: false, })) for i := 0; i <= tr.config.MaxRestarts(); i++ { @@ -120,7 +107,6 @@ func TestTrialRestarts(t *testing.T) { func setup(t *testing.T) ( *internaldb.PgDB, - model.RequestID, *trial, *allocationmocks.AllocationService, chan bool, @@ -143,9 +129,9 @@ func setup(t *testing.T) ( j := &model.Job{JobID: model.NewJobID(), JobType: model.JobTypeExperiment} require.NoError(t, internaldb.AddJob(j)) + eID := 1 // instantiate the trial - rID := model.NewRequestID(rand.Reader) - taskID := model.TaskID(fmt.Sprintf("%s-%s", model.TaskTypeTrial, rID)) + taskID := model.TaskID(fmt.Sprintf("%d.%s", eID, model.NewTaskID())) done := make(chan bool) // create expconf merged with task container defaults @@ -163,9 +149,9 @@ func setup(t *testing.T) ( taskID, j.JobID, time.Now(), - 1, + eID, model.PausedState, - experiment.TrialSearcherState{Create: searcher.Create{RequestID: rID}, Complete: true}, + experiment.TrialSearcherState{Create: searcher.Create{}, EarlyExitedByUserCode: true}, rmImpl, a.m.db, expConf, @@ -176,12 +162,11 @@ func setup(t *testing.T) ( }, ssh.PrivateAndPublicKeys{}, false, - nil, nil, func(ri model.RequestID, reason *model.ExitedReason) { - require.Equal(t, rID, ri) + nil, nil, func(rID model.RequestID, reason *model.ExitedReason) { done <- true close(done) }, ) require.NoError(t, err) - return a.m.db, rID, tr, &as, done + return a.m.db, tr, &as, done } diff --git a/master/internal/trials/postgres_trials.go b/master/internal/trials/postgres_trials.go index 985879f5080..147bdee57d5 100644 --- a/master/internal/trials/postgres_trials.go +++ b/master/internal/trials/postgres_trials.go @@ -73,7 +73,7 @@ func generateMetricToColumn(metric string) string { // trial. func MetricsTimeSeries(trialID int32, startTime time.Time, metricNames []string, - startBatches int, endBatches int, xAxisMetricLabels []string, + startBatches int, endBatches int, maxDatapoints int, timeSeriesColumn string, timeSeriesFilter *commonv1.PolymorphicFilter, metricGroup model.MetricGroup) ( metricMeasurements []db.MetricMeasurements, err error, @@ -106,7 +106,7 @@ func MetricsTimeSeries(trialID int32, startTime time.Time, return nil, fmt.Errorf("getting summary metrics for trial %d: %w", trialID, err) } - for _, metricName := range append(metricNames, "epoch") { + for _, metricName := range append(metricNames, "epoch", "epochs") { metricType := db.MetricTypeString if curSummary, ok := summaryMetrics.Metrics[metricName].(map[string]any); ok { if m, ok := curSummary["type"].(string); ok { @@ -163,14 +163,23 @@ func MetricsTimeSeries(trialID int32, startTime time.Time, valuesMap[selectMetrics[mName]] = mVal } } - var epoch *float64 - if results[i]["epoch"] != nil { + var epochs *float64 + // "epoch" is the legacy metric name for epoch x-axis metric, it was renamed to "epochs" + // but we fallback to "epoch" for backwards compatibility. + if results[i]["epochs"] != nil { + e, ok := results[i]["epochs"].(float64) + if !ok { + return nil, fmt.Errorf( + "metric 'epochs' has nonnumeric value reported value='%v'", results[i]["epochs"]) + } + epochs = &e + } else if results[i]["epoch"] != nil { e, ok := results[i]["epoch"].(float64) if !ok { return nil, fmt.Errorf( "metric 'epoch' has nonnumeric value reported value='%v'", results[i]["epoch"]) } - epoch = &e + epochs = &e } var endTime time.Time if results[i]["time"] == nil { @@ -181,7 +190,7 @@ func MetricsTimeSeries(trialID int32, startTime time.Time, metricM := db.MetricMeasurements{ Batches: uint(results[i]["batches"].(int64)), Time: endTime, - Epoch: epoch, + Epoch: epochs, TrialID: int32(results[i]["trial_id"].(int64)), Values: valuesMap, } diff --git a/master/pkg/model/searcher.go b/master/pkg/model/searcher.go index 6538942118c..106cb4bcf25 100644 --- a/master/pkg/model/searcher.go +++ b/master/pkg/model/searcher.go @@ -62,24 +62,6 @@ func (r RequestID) String() string { return uuid.UUID(r).String() } -// ParseRequestID decodes s into a request id or returns an error. -func ParseRequestID(s string) (RequestID, error) { - parsed, err := uuid.Parse(s) - if err != nil { - return RequestID{}, err - } - return RequestID(parsed), nil -} - -// MustParseRequestID decodes s into a request id or panics. -func MustParseRequestID(s string) RequestID { - parsed, err := ParseRequestID(s) - if err != nil { - panic(err) - } - return parsed -} - // Value implements the sql.Driver interface. func (r RequestID) Value() (driver.Value, error) { return r.String(), nil diff --git a/master/pkg/model/test_utils.go b/master/pkg/model/test_utils.go index d73d05cde18..7ab17de3c2a 100644 --- a/master/pkg/model/test_utils.go +++ b/master/pkg/model/test_utils.go @@ -28,13 +28,10 @@ func (f ExperimentModelOptionFunc) apply(experiment *Experiment) { // ExperimentModel returns a new experiment with the specified options. // nolint: exhaustruct func ExperimentModel(opts ...ExperimentModelOption) (*Experiment, expconf.ExperimentConfig) { - maxLength := expconf.NewLengthInBatches(100) activeConfig := expconf.ExperimentConfig{ RawSearcher: &expconf.SearcherConfig{ - RawMetric: ptrs.Ptr("loss"), - RawSingleConfig: &expconf.SingleConfig{ - RawMaxLength: &maxLength, - }, + RawMetric: ptrs.Ptr("loss"), + RawSingleConfig: &expconf.SingleConfig{}, }, RawEntrypoint: &expconf.Entrypoint{RawEntrypoint: "model_def:SomeTrialClass"}, RawHyperparameters: expconf.Hyperparameters{}, diff --git a/master/pkg/schemas/expconf/experiment_config.go b/master/pkg/schemas/expconf/experiment_config.go index b2701d846e4..52133a0c9d6 100644 --- a/master/pkg/schemas/expconf/experiment_config.go +++ b/master/pkg/schemas/expconf/experiment_config.go @@ -52,11 +52,6 @@ type ExperimentConfigV0 struct { RawPreemptionTimeout *int `json:"preemption_timeout"` } -// Unit implements the model.InUnits interface. -func (e *ExperimentConfigV0) Unit() Unit { - return e.RawSearcher.Unit() -} - // Value implements the driver.Valuer interface. func (e ExperimentConfigV0) Value() (driver.Value, error) { // Validate the object before passing it to the database. @@ -146,11 +141,6 @@ func (d *Name) UnmarshalJSON(bytes []byte) error { return json.Unmarshal(bytes, &d.RawString) } -// InUnits is describes a type that is in terms of a specific unit. -type InUnits interface { - Unit() Unit -} - // LabelsV0 holds the set of labels on the experiment. type LabelsV0 map[string]bool diff --git a/master/pkg/schemas/expconf/latest.go b/master/pkg/schemas/expconf/latest.go index 9751934700f..1352815c0a1 100644 --- a/master/pkg/schemas/expconf/latest.go +++ b/master/pkg/schemas/expconf/latest.go @@ -11,7 +11,6 @@ type ( CategoricalHyperparameter = CategoricalHyperparameterV0 CheckpointStorageConfig = CheckpointStorageConfigV0 ConstHyperparameter = ConstHyperparameterV0 - CustomConfig = CustomConfigV0 Device = DeviceV0 DevicesConfig = DevicesConfigV0 DirectoryConfig = DirectoryConfigV0 diff --git a/master/pkg/schemas/expconf/searcher_config.go b/master/pkg/schemas/expconf/searcher_config.go index d799fb56970..bdf7fefba64 100644 --- a/master/pkg/schemas/expconf/searcher_config.go +++ b/master/pkg/schemas/expconf/searcher_config.go @@ -3,6 +3,8 @@ package expconf import ( "encoding/json" + log "github.com/sirupsen/logrus" + "github.com/pkg/errors" "github.com/determined-ai/determined/master/pkg/schemas" @@ -18,7 +20,6 @@ type SearcherConfigV0 struct { RawGridConfig *GridConfigV0 `union:"name,grid" json:"-"` RawAsyncHalvingConfig *AsyncHalvingConfigV0 `union:"name,async_halving" json:"-"` RawAdaptiveASHAConfig *AdaptiveASHAConfigV0 `union:"name,adaptive_asha" json:"-"` - RawCustomConfig *CustomConfigV0 `union:"name,custom" json:"-"` // TODO(DET-8577): There should not be a need to parse EOL searchers if we get rid of parsing // active experiment configs unnecessarily. @@ -26,6 +27,7 @@ type SearcherConfigV0 struct { RawSyncHalvingConfig *SyncHalvingConfigV0 `union:"name,sync_halving" json:"-"` RawAdaptiveConfig *AdaptiveConfigV0 `union:"name,adaptive" json:"-"` RawAdaptiveSimpleConfig *AdaptiveSimpleConfigV0 `union:"name,adaptive_simple" json:"-"` + RawCustomConfig *CustomConfigV0 `union:"name,custom" json:"-"` RawMetric *string `json:"metric"` RawSmallerIsBetter *bool `json:"smaller_is_better"` @@ -52,32 +54,6 @@ func (s *SearcherConfigV0) UnmarshalJSON(data []byte) error { return errors.Wrap(json.Unmarshal(data, DefaultParser(s)), "failed to parse searcher config") } -// Unit implements the model.InUnits interface. -func (s SearcherConfigV0) Unit() Unit { - switch { - case s.RawSingleConfig != nil: - return s.RawSingleConfig.Unit() - case s.RawRandomConfig != nil: - return s.RawRandomConfig.Unit() - case s.RawGridConfig != nil: - return s.RawGridConfig.Unit() - case s.RawAsyncHalvingConfig != nil: - return s.RawAsyncHalvingConfig.Unit() - case s.RawAdaptiveASHAConfig != nil: - return s.RawAdaptiveASHAConfig.Unit() - case s.RawCustomConfig != nil: - panic("custom searcher config does not provide Unit()") - case s.RawSyncHalvingConfig != nil: - panic("cannot get unit of EOL searcher class") - case s.RawAdaptiveConfig != nil: - panic("cannot get unit of EOL searcher class") - case s.RawAdaptiveSimpleConfig != nil: - panic("cannot get unit of EOL searcher class") - default: - panic("no searcher type specified") - } -} - // AsLegacy converts a current ExperimentConfig to a (limited capacity) LegacySearcher. func (s SearcherConfigV0) AsLegacy() LegacySearcher { var name string @@ -110,67 +86,53 @@ func (s SearcherConfigV0) AsLegacy() LegacySearcher { } } -// CustomConfigV0 configures a custom search. -// -//go:generate ../gen.sh -type CustomConfigV0 struct { - RawUnit *Unit `json:"unit"` -} - // SingleConfigV0 configures a single trial. // //go:generate ../gen.sh type SingleConfigV0 struct { - RawMaxLength *LengthV0 `json:"max_length"` -} - -// Unit implements the model.InUnits interface. -func (s SingleConfigV0) Unit() Unit { - return s.RawMaxLength.Unit + RawMaxLength *LengthV0 `json:"max_length,omitempty"` } // RandomConfigV0 configures a random search. // //go:generate ../gen.sh type RandomConfigV0 struct { - RawMaxLength *LengthV0 `json:"max_length"` + RawMaxLength *LengthV0 `json:"max_length,omitempty"` RawMaxTrials *int `json:"max_trials"` RawMaxConcurrentTrials *int `json:"max_concurrent_trials"` } -// Unit implements the model.InUnits interface. -func (r RandomConfigV0) Unit() Unit { - return r.RawMaxLength.Unit -} - // GridConfigV0 configures a grid search. // //go:generate ../gen.sh type GridConfigV0 struct { - RawMaxLength *LengthV0 `json:"max_length"` + RawMaxLength *LengthV0 `json:"max_length,omitempty"` RawMaxConcurrentTrials *int `json:"max_concurrent_trials"` } -// Unit implements the model.InUnits interface. -func (g GridConfigV0) Unit() Unit { - return g.RawMaxLength.Unit -} - // AsyncHalvingConfigV0 configures asynchronous successive halving. // //go:generate ../gen.sh type AsyncHalvingConfigV0 struct { - RawNumRungs *int `json:"num_rungs"` - RawMaxLength *LengthV0 `json:"max_length"` - RawMaxTrials *int `json:"max_trials"` - RawDivisor *float64 `json:"divisor"` - RawMaxConcurrentTrials *int `json:"max_concurrent_trials"` - RawStopOnce *bool `json:"stop_once"` + RawNumRungs *int `json:"num_rungs"` + RawMaxTrials *int `json:"max_trials"` + RawDivisor *float64 `json:"divisor"` + RawMaxConcurrentTrials *int `json:"max_concurrent_trials"` + RawMaxTime *int `json:"max_time"` + RawTimeMetric *string `json:"time_metric"` + // These config options are deprecated and should not be used. + // They exist to help parse legacy exp configs. + RawMaxLength *LengthV0 `json:"max_length,omitempty"` + RawStopOnce *bool `json:"stop_once,omitempty"` } -// Unit implements the model.InUnits interface. -func (a AsyncHalvingConfigV0) Unit() Unit { - return a.RawMaxLength.Unit +// Length returns the maximum training length. +func (a AsyncHalvingConfigV0) Length() Length { + if a.RawMaxTime != nil && a.RawTimeMetric != nil { + return Length{Unit: Unit(*a.RawTimeMetric), Units: uint64(*a.RawMaxTime)} + } + // Parse legacy expconfs for backwards compat. + return *a.RawMaxLength } // AdaptiveMode specifies how aggressively to perform early stopping. @@ -179,37 +141,39 @@ type AdaptiveMode string const ( // AggressiveMode quickly stops underperforming trials, which enables the searcher to explore // more hyperparameter configurations. - AggressiveMode = "aggressive" + AggressiveMode AdaptiveMode = "aggressive" // StandardMode provides a balance between downsampling and hyperparameter exploration. - StandardMode = "standard" + StandardMode AdaptiveMode = "standard" // ConservativeMode performs minimal downsampling at the cost of not exploring as many // configurations. - ConservativeMode = "conservative" + ConservativeMode AdaptiveMode = "conservative" ) -// AdaptiveModePtr is like &AdaptiveMode("standard"), except it works. -func AdaptiveModePtr(mode string) *AdaptiveMode { - tmp := AdaptiveMode(mode) - return &tmp -} - // AdaptiveASHAConfigV0 configures an adaptive searcher for use with ASHA. // //go:generate ../gen.sh type AdaptiveASHAConfigV0 struct { - RawMaxLength *LengthV0 `json:"max_length"` RawMaxTrials *int `json:"max_trials"` RawBracketRungs []int `json:"bracket_rungs"` RawDivisor *float64 `json:"divisor"` RawMode *AdaptiveMode `json:"mode"` RawMaxRungs *int `json:"max_rungs"` RawMaxConcurrentTrials *int `json:"max_concurrent_trials"` - RawStopOnce *bool `json:"stop_once"` + RawMaxTime *int `json:"max_time"` + RawTimeMetric *string `json:"time_metric"` + // These config options are deprecated and should not be used. + // They exist to help parse legacy exp configs. + RawMaxLength *LengthV0 `json:"max_length,omitempty"` + RawStopOnce *bool `json:"stop_once,omitempty"` } -// Unit implements the model.InUnits interface. -func (a AdaptiveASHAConfigV0) Unit() Unit { - return a.RawMaxLength.Unit +// Length returns the maximum training length. +func (a AdaptiveASHAConfigV0) Length() Length { + if a.RawMaxTime != nil && a.RawTimeMetric != nil { + return Length{Unit: Unit(*a.RawTimeMetric), Units: uint64(*a.RawMaxTime)} + } + // Parse legacy expconfs for backwards compat. + return *a.RawMaxLength } // SyncHalvingConfigV0 is a legacy config. @@ -217,7 +181,7 @@ func (a AdaptiveASHAConfigV0) Unit() Unit { //go:generate ../gen.sh type SyncHalvingConfigV0 struct { RawNumRungs *int `json:"num_rungs"` - RawMaxLength *LengthV0 `json:"max_length"` + RawMaxLength *LengthV0 `json:"max_length,omitempty"` RawBudget *LengthV0 `json:"budget"` RawDivisor *float64 `json:"divisor"` RawTrainStragglers *bool `json:"train_stragglers"` @@ -227,7 +191,7 @@ type SyncHalvingConfigV0 struct { // //go:generate ../gen.sh type AdaptiveConfigV0 struct { - RawMaxLength *LengthV0 `json:"max_length"` + RawMaxLength *LengthV0 `json:"max_length,omitempty"` RawBudget *LengthV0 `json:"budget"` RawBracketRungs []int `json:"bracket_rungs"` RawDivisor *float64 `json:"divisor"` @@ -240,16 +204,64 @@ type AdaptiveConfigV0 struct { // //go:generate ../gen.sh type AdaptiveSimpleConfigV0 struct { - RawMaxLength *LengthV0 `json:"max_length"` + RawMaxLength *LengthV0 `json:"max_length,omitempty"` RawMaxTrials *int `json:"max_trials"` RawDivisor *float64 `json:"divisor"` RawMode *AdaptiveMode `json:"mode"` RawMaxRungs *int `json:"max_rungs"` } -// AssertCurrent distinguishes configs which are only parsable from those that are runnable. +// CustomConfigV0 configures a custom search. +// +//go:generate ../gen.sh +type CustomConfigV0 struct { + RawUnit *Unit `json:"unit"` +} + +// AssertCurrent distinguishes configs which are only parsable from those that are runnable and logs deprecation +// warnings for legacy fields. func (s SearcherConfig) AssertCurrent() error { switch { + case s.RawAdaptiveASHAConfig != nil: + if s.RawAdaptiveASHAConfig.RawMaxLength != nil { + log.Warn( + "the `max_length` field of the searcher config has been deprecated and will be removed in a " + + "future release.") + } + if s.RawAdaptiveASHAConfig.RawStopOnce != nil { + log.Warn("the `stop_once` field of the searcher config has been deprecated and will be removed in " + + "a future release.") + } + if s.RawAdaptiveASHAConfig.RawMaxTime == nil || s.RawAdaptiveASHAConfig.RawTimeMetric == nil { + return errors.New("the `adaptive_asha` searcher requires `max_time` and `time_metric` to be set") + } + case s.RawAsyncHalvingConfig != nil: + if s.RawAsyncHalvingConfig.RawMaxLength != nil { + log.Warn("the `max_length` field of the searcher config has been deprecated and will be removed in " + + "a future release.") + } + if s.RawAsyncHalvingConfig.RawStopOnce != nil { + log.Warn("the `stop_once` field of the searcher config has been deprecated and will be removed in " + + "a future release.") + } + if s.RawAsyncHalvingConfig.RawMaxTime == nil || s.RawAsyncHalvingConfig.RawTimeMetric == nil { + return errors.New("the `async_halving` searcher requires `max_time` and `time_metric` to be set") + } + case s.RawGridConfig != nil: + if s.RawGridConfig.RawMaxLength != nil { + log.Warn("the `max_length` field of the searcher config has been deprecated and will be removed in " + + "a future release.") + } + case s.RawSingleConfig != nil: + if s.RawSingleConfig.RawMaxLength != nil { + log.Warn("the `max_length` field of the searcher config has been deprecated and will be removed in " + + "a future release.") + } + case s.RawRandomConfig != nil: + if s.RawRandomConfig.RawMaxLength != nil { + log.Warn("the `max_length` field of the searcher config has been deprecated and will be removed in " + + "a future release.") + } case s.RawSyncHalvingConfig != nil: return errors.New( "the 'sync_halving' searcher has been removed and is not valid for new experiments", @@ -262,6 +274,10 @@ func (s SearcherConfig) AssertCurrent() error { return errors.New( "the 'adaptive_simple' searcher has been removed and is not valid for new experiments", ) + case s.RawCustomConfig != nil: + return errors.New( + "the 'custom' searcher has been removed and is not valid for new experiments", + ) } return nil } diff --git a/master/pkg/schemas/expconf/zgen_adaptive_asha_config_v0.go b/master/pkg/schemas/expconf/zgen_adaptive_asha_config_v0.go index 779f7b8b9f0..f6fc57fd7e5 100644 --- a/master/pkg/schemas/expconf/zgen_adaptive_asha_config_v0.go +++ b/master/pkg/schemas/expconf/zgen_adaptive_asha_config_v0.go @@ -8,17 +8,6 @@ import ( "github.com/determined-ai/determined/master/pkg/schemas" ) -func (a AdaptiveASHAConfigV0) MaxLength() LengthV0 { - if a.RawMaxLength == nil { - panic("You must call WithDefaults on AdaptiveASHAConfigV0 before .MaxLength") - } - return *a.RawMaxLength -} - -func (a *AdaptiveASHAConfigV0) SetMaxLength(val LengthV0) { - a.RawMaxLength = &val -} - func (a AdaptiveASHAConfigV0) MaxTrials() int { if a.RawMaxTrials == nil { panic("You must call WithDefaults on AdaptiveASHAConfigV0 before .MaxTrials") @@ -82,15 +71,36 @@ func (a *AdaptiveASHAConfigV0) SetMaxConcurrentTrials(val int) { a.RawMaxConcurrentTrials = &val } -func (a AdaptiveASHAConfigV0) StopOnce() bool { - if a.RawStopOnce == nil { - panic("You must call WithDefaults on AdaptiveASHAConfigV0 before .StopOnce") - } - return *a.RawStopOnce +func (a AdaptiveASHAConfigV0) MaxTime() *int { + return a.RawMaxTime +} + +func (a *AdaptiveASHAConfigV0) SetMaxTime(val *int) { + a.RawMaxTime = val +} + +func (a AdaptiveASHAConfigV0) TimeMetric() *string { + return a.RawTimeMetric +} + +func (a *AdaptiveASHAConfigV0) SetTimeMetric(val *string) { + a.RawTimeMetric = val +} + +func (a AdaptiveASHAConfigV0) MaxLength() *LengthV0 { + return a.RawMaxLength +} + +func (a *AdaptiveASHAConfigV0) SetMaxLength(val *LengthV0) { + a.RawMaxLength = val +} + +func (a AdaptiveASHAConfigV0) StopOnce() *bool { + return a.RawStopOnce } -func (a *AdaptiveASHAConfigV0) SetStopOnce(val bool) { - a.RawStopOnce = &val +func (a *AdaptiveASHAConfigV0) SetStopOnce(val *bool) { + a.RawStopOnce = val } func (a AdaptiveASHAConfigV0) ParsedSchema() interface{} { diff --git a/master/pkg/schemas/expconf/zgen_adaptive_config_v0.go b/master/pkg/schemas/expconf/zgen_adaptive_config_v0.go index cf3c484b815..6816c834aed 100644 --- a/master/pkg/schemas/expconf/zgen_adaptive_config_v0.go +++ b/master/pkg/schemas/expconf/zgen_adaptive_config_v0.go @@ -8,15 +8,12 @@ import ( "github.com/determined-ai/determined/master/pkg/schemas" ) -func (a AdaptiveConfigV0) MaxLength() LengthV0 { - if a.RawMaxLength == nil { - panic("You must call WithDefaults on AdaptiveConfigV0 before .MaxLength") - } - return *a.RawMaxLength +func (a AdaptiveConfigV0) MaxLength() *LengthV0 { + return a.RawMaxLength } -func (a *AdaptiveConfigV0) SetMaxLength(val LengthV0) { - a.RawMaxLength = &val +func (a *AdaptiveConfigV0) SetMaxLength(val *LengthV0) { + a.RawMaxLength = val } func (a AdaptiveConfigV0) Budget() LengthV0 { diff --git a/master/pkg/schemas/expconf/zgen_adaptive_simple_config_v0.go b/master/pkg/schemas/expconf/zgen_adaptive_simple_config_v0.go index ad4613973e6..63fff4c4780 100644 --- a/master/pkg/schemas/expconf/zgen_adaptive_simple_config_v0.go +++ b/master/pkg/schemas/expconf/zgen_adaptive_simple_config_v0.go @@ -8,15 +8,12 @@ import ( "github.com/determined-ai/determined/master/pkg/schemas" ) -func (a AdaptiveSimpleConfigV0) MaxLength() LengthV0 { - if a.RawMaxLength == nil { - panic("You must call WithDefaults on AdaptiveSimpleConfigV0 before .MaxLength") - } - return *a.RawMaxLength +func (a AdaptiveSimpleConfigV0) MaxLength() *LengthV0 { + return a.RawMaxLength } -func (a *AdaptiveSimpleConfigV0) SetMaxLength(val LengthV0) { - a.RawMaxLength = &val +func (a *AdaptiveSimpleConfigV0) SetMaxLength(val *LengthV0) { + a.RawMaxLength = val } func (a AdaptiveSimpleConfigV0) MaxTrials() int { diff --git a/master/pkg/schemas/expconf/zgen_async_halving_config_v0.go b/master/pkg/schemas/expconf/zgen_async_halving_config_v0.go index 9bf52ee3afc..03ad968feb9 100644 --- a/master/pkg/schemas/expconf/zgen_async_halving_config_v0.go +++ b/master/pkg/schemas/expconf/zgen_async_halving_config_v0.go @@ -19,17 +19,6 @@ func (a *AsyncHalvingConfigV0) SetNumRungs(val int) { a.RawNumRungs = &val } -func (a AsyncHalvingConfigV0) MaxLength() LengthV0 { - if a.RawMaxLength == nil { - panic("You must call WithDefaults on AsyncHalvingConfigV0 before .MaxLength") - } - return *a.RawMaxLength -} - -func (a *AsyncHalvingConfigV0) SetMaxLength(val LengthV0) { - a.RawMaxLength = &val -} - func (a AsyncHalvingConfigV0) MaxTrials() int { if a.RawMaxTrials == nil { panic("You must call WithDefaults on AsyncHalvingConfigV0 before .MaxTrials") @@ -63,15 +52,36 @@ func (a *AsyncHalvingConfigV0) SetMaxConcurrentTrials(val int) { a.RawMaxConcurrentTrials = &val } -func (a AsyncHalvingConfigV0) StopOnce() bool { - if a.RawStopOnce == nil { - panic("You must call WithDefaults on AsyncHalvingConfigV0 before .StopOnce") - } - return *a.RawStopOnce +func (a AsyncHalvingConfigV0) MaxTime() *int { + return a.RawMaxTime +} + +func (a *AsyncHalvingConfigV0) SetMaxTime(val *int) { + a.RawMaxTime = val +} + +func (a AsyncHalvingConfigV0) TimeMetric() *string { + return a.RawTimeMetric +} + +func (a *AsyncHalvingConfigV0) SetTimeMetric(val *string) { + a.RawTimeMetric = val +} + +func (a AsyncHalvingConfigV0) MaxLength() *LengthV0 { + return a.RawMaxLength +} + +func (a *AsyncHalvingConfigV0) SetMaxLength(val *LengthV0) { + a.RawMaxLength = val +} + +func (a AsyncHalvingConfigV0) StopOnce() *bool { + return a.RawStopOnce } -func (a *AsyncHalvingConfigV0) SetStopOnce(val bool) { - a.RawStopOnce = &val +func (a *AsyncHalvingConfigV0) SetStopOnce(val *bool) { + a.RawStopOnce = val } func (a AsyncHalvingConfigV0) ParsedSchema() interface{} { diff --git a/master/pkg/schemas/expconf/zgen_grid_config_v0.go b/master/pkg/schemas/expconf/zgen_grid_config_v0.go index 45a561ebdb5..a372484fdb8 100644 --- a/master/pkg/schemas/expconf/zgen_grid_config_v0.go +++ b/master/pkg/schemas/expconf/zgen_grid_config_v0.go @@ -8,15 +8,12 @@ import ( "github.com/determined-ai/determined/master/pkg/schemas" ) -func (g GridConfigV0) MaxLength() LengthV0 { - if g.RawMaxLength == nil { - panic("You must call WithDefaults on GridConfigV0 before .MaxLength") - } - return *g.RawMaxLength +func (g GridConfigV0) MaxLength() *LengthV0 { + return g.RawMaxLength } -func (g *GridConfigV0) SetMaxLength(val LengthV0) { - g.RawMaxLength = &val +func (g *GridConfigV0) SetMaxLength(val *LengthV0) { + g.RawMaxLength = val } func (g GridConfigV0) MaxConcurrentTrials() int { diff --git a/master/pkg/schemas/expconf/zgen_random_config_v0.go b/master/pkg/schemas/expconf/zgen_random_config_v0.go index 3cbc34a706b..4d27e3df352 100644 --- a/master/pkg/schemas/expconf/zgen_random_config_v0.go +++ b/master/pkg/schemas/expconf/zgen_random_config_v0.go @@ -8,15 +8,12 @@ import ( "github.com/determined-ai/determined/master/pkg/schemas" ) -func (r RandomConfigV0) MaxLength() LengthV0 { - if r.RawMaxLength == nil { - panic("You must call WithDefaults on RandomConfigV0 before .MaxLength") - } - return *r.RawMaxLength +func (r RandomConfigV0) MaxLength() *LengthV0 { + return r.RawMaxLength } -func (r *RandomConfigV0) SetMaxLength(val LengthV0) { - r.RawMaxLength = &val +func (r *RandomConfigV0) SetMaxLength(val *LengthV0) { + r.RawMaxLength = val } func (r RandomConfigV0) MaxTrials() int { diff --git a/master/pkg/schemas/expconf/zgen_searcher_config_v0.go b/master/pkg/schemas/expconf/zgen_searcher_config_v0.go index d64660335ed..fae3468c56f 100644 --- a/master/pkg/schemas/expconf/zgen_searcher_config_v0.go +++ b/master/pkg/schemas/expconf/zgen_searcher_config_v0.go @@ -62,9 +62,6 @@ func (s SearcherConfigV0) GetUnionMember() interface{} { if s.RawAdaptiveASHAConfig != nil { return *s.RawAdaptiveASHAConfig } - if s.RawCustomConfig != nil { - return *s.RawCustomConfig - } if s.RawSyncHalvingConfig != nil { return *s.RawSyncHalvingConfig } @@ -74,6 +71,9 @@ func (s SearcherConfigV0) GetUnionMember() interface{} { if s.RawAdaptiveSimpleConfig != nil { return *s.RawAdaptiveSimpleConfig } + if s.RawCustomConfig != nil { + return *s.RawCustomConfig + } panic("no union member defined") } diff --git a/master/pkg/schemas/expconf/zgen_single_config_v0.go b/master/pkg/schemas/expconf/zgen_single_config_v0.go index 09ae94649d2..0b8c01d5312 100644 --- a/master/pkg/schemas/expconf/zgen_single_config_v0.go +++ b/master/pkg/schemas/expconf/zgen_single_config_v0.go @@ -8,15 +8,12 @@ import ( "github.com/determined-ai/determined/master/pkg/schemas" ) -func (s SingleConfigV0) MaxLength() LengthV0 { - if s.RawMaxLength == nil { - panic("You must call WithDefaults on SingleConfigV0 before .MaxLength") - } - return *s.RawMaxLength +func (s SingleConfigV0) MaxLength() *LengthV0 { + return s.RawMaxLength } -func (s *SingleConfigV0) SetMaxLength(val LengthV0) { - s.RawMaxLength = &val +func (s *SingleConfigV0) SetMaxLength(val *LengthV0) { + s.RawMaxLength = val } func (s SingleConfigV0) ParsedSchema() interface{} { diff --git a/master/pkg/schemas/expconf/zgen_sync_halving_config_v0.go b/master/pkg/schemas/expconf/zgen_sync_halving_config_v0.go index 6526bc4abf5..a145c24a5e5 100644 --- a/master/pkg/schemas/expconf/zgen_sync_halving_config_v0.go +++ b/master/pkg/schemas/expconf/zgen_sync_halving_config_v0.go @@ -19,15 +19,12 @@ func (s *SyncHalvingConfigV0) SetNumRungs(val int) { s.RawNumRungs = &val } -func (s SyncHalvingConfigV0) MaxLength() LengthV0 { - if s.RawMaxLength == nil { - panic("You must call WithDefaults on SyncHalvingConfigV0 before .MaxLength") - } - return *s.RawMaxLength +func (s SyncHalvingConfigV0) MaxLength() *LengthV0 { + return s.RawMaxLength } -func (s *SyncHalvingConfigV0) SetMaxLength(val LengthV0) { - s.RawMaxLength = &val +func (s *SyncHalvingConfigV0) SetMaxLength(val *LengthV0) { + s.RawMaxLength = val } func (s SyncHalvingConfigV0) Budget() LengthV0 { diff --git a/master/pkg/schemas/zgen_schemas.go b/master/pkg/schemas/zgen_schemas.go index b887031a82a..075f645f6fa 100644 --- a/master/pkg/schemas/zgen_schemas.go +++ b/master/pkg/schemas/zgen_schemas.go @@ -2242,7 +2242,6 @@ var ( "name" ], "eventuallyRequired": [ - "max_length", "max_trials", "metric" ], @@ -2268,6 +2267,20 @@ var ( "default": null, "minimum": 1 }, + "time_metric": { + "type": [ + "string", + "null" + ], + "default": null + }, + "max_time": { + "type": [ + "integer", + "null" + ], + "default": null + }, "mode": { "enum": [ null, @@ -2315,7 +2328,7 @@ var ( "boolean", "null" ], - "default": false + "default": null }, "metric": { "type": [ @@ -2360,7 +2373,6 @@ var ( ], "eventuallyRequired": [ "max_trials", - "max_length", "metric" ], "properties": { @@ -2452,7 +2464,6 @@ var ( ], "eventuallyRequired": [ "budget", - "max_length", "metric" ], "properties": { @@ -2559,7 +2570,6 @@ var ( ], "eventuallyRequired": [ "num_rungs", - "max_length", "max_trials", "metric" ], @@ -2612,7 +2622,7 @@ var ( "boolean", "null" ], - "default": false + "default": null }, "metric": { "type": [ @@ -2621,6 +2631,20 @@ var ( ], "default": null }, + "time_metric": { + "type": [ + "string", + "null" + ], + "default": null + }, + "max_time": { + "type": [ + "integer", + "null" + ], + "default": null + }, "smaller_is_better": { "type": [ "boolean", @@ -2647,6 +2671,7 @@ var ( `) textCustomConfigV0 = []byte(`{ "$schema": "http://json-schema.org/draft-07/schema#", + "$comment": "this is an EOL searcher, not to be used in new experiments", "$id": "http://determined.ai/schemas/expconf/v0/searcher-custom.json", "title": "CustomConfig", "type": "object", @@ -2697,7 +2722,6 @@ var ( "name" ], "eventuallyRequired": [ - "max_length", "metric" ], "properties": { @@ -2783,7 +2807,6 @@ var ( ], "eventuallyRequired": [ "max_trials", - "max_length", "metric" ], "properties": { @@ -2856,7 +2879,6 @@ var ( "name" ], "eventuallyRequired": [ - "max_length", "metric" ], "properties": { @@ -2915,7 +2937,6 @@ var ( ], "eventuallyRequired": [ "num_rungs", - "max_length", "budget", "metric" ], @@ -3019,10 +3040,6 @@ var ( "unionKey": "const:name=grid", "$ref": "http://determined.ai/schemas/expconf/v0/searcher-grid.json" }, - { - "unionKey": "const:name=custom", - "$ref": "http://determined.ai/schemas/expconf/v0/searcher-custom.json" - }, { "unionKey": "const:name=adaptive_asha", "$ref": "http://determined.ai/schemas/expconf/v0/searcher-adaptive-asha.json" @@ -3031,6 +3048,11 @@ var ( "unionKey": "const:name=async_halving", "$ref": "http://determined.ai/schemas/expconf/v0/searcher-async-halving.json" }, + { + "$comment": "this is an EOL searcher, not to be used in new experiments", + "unionKey": "const:name=custom", + "$ref": "http://determined.ai/schemas/expconf/v0/searcher-custom.json" + }, { "$comment": "this is an EOL searcher, not to be used in new experiments", "unionKey": "const:name=adaptive", @@ -3060,6 +3082,8 @@ var ( "max_concurrent_trials": true, "max_length": true, "max_rungs": true, + "max_time": true, + "time_metric": true, "max_trials": true, "mode": true, "name": true, diff --git a/master/pkg/searcher/actions.go b/master/pkg/searcher/actions.go new file mode 100644 index 00000000000..d2aa7c9f857 --- /dev/null +++ b/master/pkg/searcher/actions.go @@ -0,0 +1,73 @@ +package searcher + +import ( + "fmt" + + "github.com/determined-ai/determined/master/pkg/model" + + "github.com/determined-ai/determined/master/pkg/nprand" +) + +// Action is an action that a searcher would like to perform. +type Action interface { + searcherAction() +} + +// Create is a directive from the searcher to create a new run. +type Create struct { + RequestID model.RequestID `json:"request_id"` + // TrialSeed must be a value between 0 and 2**31 - 1. + TrialSeed uint32 `json:"trial_seed"` + Hparams HParamSample `json:"hparams"` +} + +// searcherAction (Create) implements SearcherAction. +func (Create) searcherAction() {} + +func (action Create) String() string { + return fmt.Sprintf( + "Create{TrialSeed: %d, Hparams: %v, RequestID: %d}", + action.TrialSeed, action.Hparams, action.RequestID, + ) +} + +// NewCreate initializes a new Create operation with a new request ID and the given hyperparameters. +func NewCreate( + rand *nprand.State, s HParamSample, +) Create { + return Create{ + RequestID: model.NewRequestID(rand), + TrialSeed: uint32(rand.Int64n(1 << 31)), + Hparams: s, + } +} + +// Stop is a directive from the searcher to stop a run. +type Stop struct { + RequestID model.RequestID `json:"request_id"` +} + +// SearcherAction (Stop) implements SearcherAction. +func (Stop) searcherAction() {} + +// NewStop initializes a new Stop action with the given Run ID. +func NewStop(requestID model.RequestID) Stop { + return Stop{RequestID: requestID} +} + +func (action Stop) String() string { + return fmt.Sprintf("Stop{RequestID: %d}", action.RequestID) +} + +// Shutdown marks the searcher as completed. +type Shutdown struct { + Cancel bool + Failure bool +} + +// SearcherAction (Shutdown) implements SearcherAction. +func (Shutdown) searcherAction() {} + +func (shutdown Shutdown) String() string { + return fmt.Sprintf("{Shutdown Cancel: %v Failure: %v}", shutdown.Cancel, shutdown.Failure) +} diff --git a/master/pkg/searcher/adaptive_asha.go b/master/pkg/searcher/adaptive_asha.go index 5c7e288a7c5..8df3e4873f4 100644 --- a/master/pkg/searcher/adaptive_asha.go +++ b/master/pkg/searcher/adaptive_asha.go @@ -10,6 +10,19 @@ import ( "github.com/determined-ai/determined/master/pkg/schemas/expconf" ) +type bracket struct { + numRungs int + maxTrials int + maxConcurrentTrials int +} + +func (b *bracket) String() string { + return fmt.Sprintf( + "Bracket{numRungs: %d, maxTrials: %d, maxConcurrentTrials: %d}", + b.numRungs, b.maxTrials, b.maxConcurrentTrials, + ) +} + func getBracketMaxTrials( maxTrials int, divisor float64, brackets []int, ) []int { @@ -52,7 +65,7 @@ func getBracketMaxConcurrentTrials( if maxConcurrentTrials == 0 { minTrials = mathx.Max(maxTrials[numBrackets-1], int(divisor)) } else { - // Without this, the remainder will be less than numBrackets and later brackets willgit pu + // Without this, the remainder will be less than numBrackets and later brackets will // not receive a constraint on bracketMaxConcurrentTrials. maxConcurrentTrials = mathx.Max(maxConcurrentTrials, numBrackets) minTrials = maxConcurrentTrials / numBrackets @@ -68,40 +81,51 @@ func getBracketMaxConcurrentTrials( return bracketMaxConcurrentTrials } -func newAdaptiveASHASearch(config expconf.AdaptiveASHAConfig, smallerIsBetter bool) SearchMethod { +func makeBrackets(config expconf.AdaptiveASHAConfig) []bracket { modeFunc := parseAdaptiveMode(config.Mode()) - brackets := config.BracketRungs() - if len(brackets) == 0 { + bracketRungs := config.BracketRungs() + if len(bracketRungs) == 0 { maxRungs := config.MaxRungs() + // Ensure that the top rung will contain at least one run. maxRungs = mathx.Min( maxRungs, - int(math.Log(float64(config.MaxLength().Units))/math.Log(config.Divisor()))+1, + int(math.Log(float64(config.Length().Units))/math.Log(config.Divisor()))+1, int(math.Log(float64(config.MaxTrials()))/math.Log(config.Divisor()))+1) - brackets = modeFunc(maxRungs) + bracketRungs = modeFunc(maxRungs) } // We prioritize brackets that perform more early stopping to try to max speedups early on. - sort.Sort(sort.Reverse(sort.IntSlice(brackets))) + sort.Sort(sort.Reverse(sort.IntSlice(bracketRungs))) bracketMaxTrials := getBracketMaxTrials( - config.MaxTrials(), config.Divisor(), brackets) + config.MaxTrials(), config.Divisor(), bracketRungs) bracketMaxConcurrentTrials := getBracketMaxConcurrentTrials( config.MaxConcurrentTrials(), config.Divisor(), bracketMaxTrials) + brackets := make([]bracket, len(bracketRungs)) + for i, bracketRung := range bracketRungs { + brackets[i] = bracket{ + numRungs: bracketRung, + maxTrials: bracketMaxTrials[i], + maxConcurrentTrials: bracketMaxConcurrentTrials[i], + } + } + return brackets +} + +func newAdaptiveASHASearch(config expconf.AdaptiveASHAConfig, smallerIsBetter bool, metric string) SearchMethod { + brackets := makeBrackets(config) methods := make([]SearchMethod, 0, len(brackets)) - for i, numRungs := range brackets { + for _, bracket := range brackets { c := expconf.AsyncHalvingConfig{ - RawNumRungs: ptrs.Ptr(numRungs), - RawMaxLength: ptrs.Ptr(config.MaxLength()), - RawMaxTrials: &bracketMaxTrials[i], + RawNumRungs: ptrs.Ptr(bracket.numRungs), + RawMaxLength: config.RawMaxLength, + RawMaxTrials: &bracket.maxTrials, RawDivisor: ptrs.Ptr(config.Divisor()), - RawMaxConcurrentTrials: ptrs.Ptr(bracketMaxConcurrentTrials[i]), - RawStopOnce: ptrs.Ptr(config.StopOnce()), - } - if config.StopOnce() { - methods = append(methods, newAsyncHalvingStoppingSearch(c, smallerIsBetter)) - } else { - methods = append(methods, newAsyncHalvingSearch(c, smallerIsBetter)) + RawMaxConcurrentTrials: ptrs.Ptr(bracket.maxConcurrentTrials), + RawTimeMetric: config.RawTimeMetric, + RawMaxTime: config.RawMaxTime, } + methods = append(methods, newAsyncHalvingStoppingSearch(c, smallerIsBetter, metric)) } return newTournamentSearch(AdaptiveASHASearch, methods...) diff --git a/master/pkg/searcher/adaptive_asha_test.go b/master/pkg/searcher/adaptive_asha_test.go index 901668bd083..e91e3966284 100644 --- a/master/pkg/searcher/adaptive_asha_test.go +++ b/master/pkg/searcher/adaptive_asha_test.go @@ -4,6 +4,8 @@ package searcher import ( "testing" + "github.com/stretchr/testify/require" + "gotest.tools/assert" "github.com/determined-ai/determined/master/pkg/ptrs" @@ -25,194 +27,63 @@ func TestBracketMaxConcurrentTrials(t *testing.T) { assert.DeepEqual(t, getBracketMaxConcurrentTrials(0, 4., []int{40, 10}), []int{10, 10}) } -func modePtr(x expconf.AdaptiveMode) *expconf.AdaptiveMode { - return &x -} - -func TestAdaptiveASHASearcherReproducibility(t *testing.T) { - conf := expconf.AdaptiveASHAConfig{ - RawMaxLength: ptrs.Ptr(expconf.NewLengthInBatches(6400)), - RawMaxTrials: ptrs.Ptr(128), - } - conf = schemas.WithDefaults(conf) - gen := func() SearchMethod { return newAdaptiveASHASearch(conf, true) } - checkReproducibility(t, gen, nil, defaultMetric) -} - -func TestAdaptiveASHASearchMethod(t *testing.T) { - testCases := []valueSimulationTestCase{ +func TestMakeBrackets(t *testing.T) { + cases := []struct { + conf expconf.AdaptiveASHAConfig + expBrackets []bracket + }{ { - name: "smaller is better", - expectedTrials: []predefinedTrial{ - newConstantPredefinedTrial(toOps("300B 900B"), 0.1), - newConstantPredefinedTrial(toOps("300B"), 0.2), - newConstantPredefinedTrial(toOps("300B"), 0.3), - newConstantPredefinedTrial(toOps("900B"), 0.4), - newConstantPredefinedTrial(toOps("900B"), 0.5), - }, - config: expconf.SearcherConfig{ - RawSmallerIsBetter: ptrs.Ptr(true), - RawAdaptiveASHAConfig: &expconf.AdaptiveASHAConfig{ - RawMaxLength: ptrs.Ptr(expconf.NewLengthInBatches(900)), - RawMaxTrials: ptrs.Ptr(5), - RawMode: modePtr(expconf.StandardMode), - RawMaxRungs: ptrs.Ptr(2), - RawDivisor: ptrs.Ptr[float64](3), + conf: expconf.AdaptiveASHAConfig{ + RawMode: ptrs.Ptr(expconf.StandardMode), + RawMaxTime: ptrs.Ptr(100), + RawTimeMetric: ptrs.Ptr("batches"), + RawMaxConcurrentTrials: ptrs.Ptr(2), + RawMaxTrials: ptrs.Ptr(10), + }, + expBrackets: []bracket{ + { + numRungs: 2, + maxTrials: 7, + maxConcurrentTrials: 1, }, - }, - }, - { - name: "early exit -- smaller is better", - expectedTrials: []predefinedTrial{ - newConstantPredefinedTrial(toOps("300B 900B"), 0.1), - newEarlyExitPredefinedTrial(toOps("300B"), 0.2), - newConstantPredefinedTrial(toOps("300B"), 0.3), - newConstantPredefinedTrial(toOps("900B"), 0.4), - newConstantPredefinedTrial(toOps("900B"), 0.5), - }, - config: expconf.SearcherConfig{ - RawSmallerIsBetter: ptrs.Ptr(true), - RawAdaptiveASHAConfig: &expconf.AdaptiveASHAConfig{ - RawMaxLength: ptrs.Ptr(expconf.NewLengthInBatches(900)), - RawMaxTrials: ptrs.Ptr(5), - RawMode: modePtr(expconf.StandardMode), - RawMaxRungs: ptrs.Ptr(2), - RawDivisor: ptrs.Ptr[float64](3), + { + numRungs: 1, + maxTrials: 3, + maxConcurrentTrials: 1, }, }, }, { - name: "smaller is not better", - expectedTrials: []predefinedTrial{ - newConstantPredefinedTrial(toOps("300B 900B"), 0.5), - newConstantPredefinedTrial(toOps("300B"), 0.4), - newConstantPredefinedTrial(toOps("300B"), 0.3), - newConstantPredefinedTrial(toOps("900B"), 0.2), - newConstantPredefinedTrial(toOps("900B"), 0.1), - }, - config: expconf.SearcherConfig{ - RawSmallerIsBetter: ptrs.Ptr(false), - RawAdaptiveASHAConfig: &expconf.AdaptiveASHAConfig{ - RawMaxLength: ptrs.Ptr(expconf.NewLengthInBatches(900)), - RawMaxTrials: ptrs.Ptr(5), - RawMode: modePtr(expconf.StandardMode), - RawMaxRungs: ptrs.Ptr(2), - RawDivisor: ptrs.Ptr[float64](3), + conf: expconf.AdaptiveASHAConfig{ + RawMode: ptrs.Ptr(expconf.ConservativeMode), + RawMaxTime: ptrs.Ptr(1000), + RawTimeMetric: ptrs.Ptr("batches"), + RawDivisor: ptrs.Ptr(3.0), + RawMaxConcurrentTrials: ptrs.Ptr(5), + RawMaxTrials: ptrs.Ptr(10), + }, + expBrackets: []bracket{ + { + numRungs: 3, + maxTrials: 7, + maxConcurrentTrials: 2, }, - }, - }, - { - name: "early exit -- smaller is not better", - expectedTrials: []predefinedTrial{ - newConstantPredefinedTrial(toOps("300B 900B"), 0.5), - newEarlyExitPredefinedTrial(toOps("300B"), 0.4), - newConstantPredefinedTrial(toOps("300B"), 0.3), - newConstantPredefinedTrial(toOps("900B"), 0.2), - newConstantPredefinedTrial(toOps("900B"), 0.1), - }, - config: expconf.SearcherConfig{ - RawSmallerIsBetter: ptrs.Ptr(false), - RawAdaptiveASHAConfig: &expconf.AdaptiveASHAConfig{ - RawMaxLength: ptrs.Ptr(expconf.NewLengthInBatches(900)), - RawMaxTrials: ptrs.Ptr(5), - RawMode: modePtr(expconf.StandardMode), - RawMaxRungs: ptrs.Ptr(2), - RawDivisor: ptrs.Ptr[float64](3), + { + numRungs: 2, + maxTrials: 2, + maxConcurrentTrials: 2, }, - }, - }, - } - - runValueSimulationTestCases(t, testCases) -} - -func TestAdaptiveASHAStoppingSearchMethod(t *testing.T) { - testCases := []valueSimulationTestCase{ - { - name: "smaller is better", - expectedTrials: []predefinedTrial{ - newConstantPredefinedTrial(toOps("300B 900B"), 0.1), - newConstantPredefinedTrial(toOps("300B"), 0.2), - newConstantPredefinedTrial(toOps("300B"), 0.3), - newConstantPredefinedTrial(toOps("900B"), 0.4), - newConstantPredefinedTrial(toOps("900B"), 0.5), - }, - config: expconf.SearcherConfig{ - RawSmallerIsBetter: ptrs.Ptr(true), - RawAdaptiveASHAConfig: &expconf.AdaptiveASHAConfig{ - RawMaxLength: ptrs.Ptr(expconf.NewLengthInBatches(900)), - RawMaxTrials: ptrs.Ptr(5), - RawMode: modePtr(expconf.StandardMode), - RawMaxRungs: ptrs.Ptr(2), - RawDivisor: ptrs.Ptr[float64](3), - RawStopOnce: ptrs.Ptr(true), - }, - }, - }, - { - name: "early exit -- smaller is better", - expectedTrials: []predefinedTrial{ - newConstantPredefinedTrial(toOps("300B 900B"), 0.1), - newEarlyExitPredefinedTrial(toOps("300B"), 0.2), - newConstantPredefinedTrial(toOps("300B"), 0.3), - newConstantPredefinedTrial(toOps("900B"), 0.4), - newConstantPredefinedTrial(toOps("900B"), 0.5), - }, - config: expconf.SearcherConfig{ - RawSmallerIsBetter: ptrs.Ptr(true), - RawAdaptiveASHAConfig: &expconf.AdaptiveASHAConfig{ - RawMaxLength: ptrs.Ptr(expconf.NewLengthInBatches(900)), - RawMaxTrials: ptrs.Ptr(5), - RawMode: modePtr(expconf.StandardMode), - RawMaxRungs: ptrs.Ptr(2), - RawDivisor: ptrs.Ptr[float64](3), - RawStopOnce: ptrs.Ptr(true), - }, - }, - }, - { - name: "smaller is not better", - expectedTrials: []predefinedTrial{ - newConstantPredefinedTrial(toOps("300B 900B"), 0.1), - newConstantPredefinedTrial(toOps("300B 900B"), 0.2), - newConstantPredefinedTrial(toOps("300B 900B"), 0.3), - newConstantPredefinedTrial(toOps("900B"), 0.4), - newConstantPredefinedTrial(toOps("900B"), 0.5), - }, - config: expconf.SearcherConfig{ - RawSmallerIsBetter: ptrs.Ptr(false), - RawAdaptiveASHAConfig: &expconf.AdaptiveASHAConfig{ - RawMaxLength: ptrs.Ptr(expconf.NewLengthInBatches(900)), - RawMaxTrials: ptrs.Ptr(5), - RawMode: modePtr(expconf.StandardMode), - RawMaxRungs: ptrs.Ptr(2), - RawDivisor: ptrs.Ptr[float64](3), - RawStopOnce: ptrs.Ptr(true), - }, - }, - }, - { - name: "early exit -- smaller is not better", - expectedTrials: []predefinedTrial{ - newConstantPredefinedTrial(toOps("300B 900B"), 0.1), - newEarlyExitPredefinedTrial(toOps("300B"), 0.2), - newConstantPredefinedTrial(toOps("300B 900B"), 0.3), - newConstantPredefinedTrial(toOps("900B"), 0.4), - newConstantPredefinedTrial(toOps("900B"), 0.5), - }, - config: expconf.SearcherConfig{ - RawSmallerIsBetter: ptrs.Ptr(false), - RawAdaptiveASHAConfig: &expconf.AdaptiveASHAConfig{ - RawMaxLength: ptrs.Ptr(expconf.NewLengthInBatches(900)), - RawMaxTrials: ptrs.Ptr(5), - RawMode: modePtr(expconf.StandardMode), - RawMaxRungs: ptrs.Ptr(2), - RawDivisor: ptrs.Ptr[float64](3), - RawStopOnce: ptrs.Ptr(true), + { + numRungs: 1, + maxTrials: 1, + maxConcurrentTrials: 1, }, }, }, } - - runValueSimulationTestCases(t, testCases) + for _, c := range cases { + brackets := makeBrackets(schemas.WithDefaults(c.conf)) + require.Equal(t, len(c.expBrackets), len(brackets)) + require.Equal(t, c.expBrackets, brackets) + } } diff --git a/master/pkg/searcher/asha.go b/master/pkg/searcher/asha.go deleted file mode 100644 index 5f70f62f5c5..00000000000 --- a/master/pkg/searcher/asha.go +++ /dev/null @@ -1,327 +0,0 @@ -package searcher - -import ( - "encoding/json" - "fmt" - "math" - "sort" - - "github.com/determined-ai/determined/master/pkg/mathx" - "github.com/determined-ai/determined/master/pkg/model" - "github.com/determined-ai/determined/master/pkg/schemas/expconf" -) - -// AsyncHalvingSearch implements a search using the asynchronous successive halving algorithm -// (ASHA). The experiment will run until the target number of trials have been completed -// in the bottom rung and no further promotions can be made to higher rungs. -type ( - asyncHalvingSearchState struct { - Rungs []*rung `json:"rungs"` - TrialRungs map[model.RequestID]int `json:"trial_rungs"` - // EarlyExitTrials contains trials that exited early that are still considered in the search. - EarlyExitTrials map[model.RequestID]bool `json:"early_exit_trials"` - ClosedTrials map[model.RequestID]bool `json:"closed_trials"` - TrialsCompleted int `json:"trials_completed"` - InvalidTrials int `json:"invalid_trials"` - PendingTrials int `json:"pending_trials"` - SearchMethodType SearchMethodType `json:"search_method_type"` - } - - asyncHalvingSearch struct { - expconf.AsyncHalvingConfig - SmallerIsBetter bool - asyncHalvingSearchState - } - - trialMetric struct { - RequestID model.RequestID `json:"request_id"` - Metric model.ExtendedFloat64 `json:"metric"` - // fields below used by asha.go. - Promoted bool `json:"promoted"` - } - - // rung describes a set of trials that are to be trained for the same number of units. - rung struct { - UnitsNeeded uint64 `json:"units_needed"` - Metrics []trialMetric `json:"metrics"` - StartTrials int `json:"start_trials"` - PromoteTrials int `json:"promote_trials"` - // field below used by asha.go. - OutstandingTrials int `json:"outstanding_trials"` - } -) - -const ashaExitedMetricValue = math.MaxFloat64 - -func newAsyncHalvingSearch(config expconf.AsyncHalvingConfig, smallerIsBetter bool) SearchMethod { - rungs := make([]*rung, 0, config.NumRungs()) - var unitsNeeded uint64 - for id := 0; id < config.NumRungs(); id++ { - // We divide the MaxLength by downsampling rate to get the target units - // for a rung. - downsamplingRate := math.Pow(config.Divisor(), float64(config.NumRungs()-id-1)) - unitsNeeded += mathx.Max(uint64(float64(config.MaxLength().Units)/downsamplingRate), 1) - rungs = append(rungs, &rung{UnitsNeeded: unitsNeeded}) - } - - return &asyncHalvingSearch{ - AsyncHalvingConfig: config, - SmallerIsBetter: smallerIsBetter, - asyncHalvingSearchState: asyncHalvingSearchState{ - Rungs: rungs, - TrialRungs: make(map[model.RequestID]int), - EarlyExitTrials: make(map[model.RequestID]bool), - ClosedTrials: make(map[model.RequestID]bool), - SearchMethodType: ASHASearch, - }, - } -} - -func (s *asyncHalvingSearch) Snapshot() (json.RawMessage, error) { - return json.Marshal(s.asyncHalvingSearchState) -} - -func (s *asyncHalvingSearch) Restore(state json.RawMessage) error { - return json.Unmarshal(state, &s.asyncHalvingSearchState) -} - -// promotions handles bookkeeping of validation metrics and returns a RequestID to promote if -// appropriate. -func (r *rung) promotionsAsync( - requestID model.RequestID, metric float64, divisor float64, -) []model.RequestID { - // See if there is a trial to promote. We are increasing the total number of trials seen by 1; the - // number of best trials that definitely should have been promoted so far (numPromote) can only - // stay the same or increase by 1. - oldNumPromote := int(float64(len(r.Metrics)) / divisor) - numPromote := int(float64(len(r.Metrics)+1) / divisor) - - // Insert the new trial result in the appropriate place in the sorted list. - insertIndex := sort.Search( - len(r.Metrics), - func(i int) bool { return float64(r.Metrics[i].Metric) > metric }, - ) - promoteNow := insertIndex < numPromote - - r.Metrics = append(r.Metrics, trialMetric{}) - copy(r.Metrics[insertIndex+1:], r.Metrics[insertIndex:]) - r.Metrics[insertIndex] = trialMetric{ - RequestID: requestID, - Metric: model.ExtendedFloat64(metric), - Promoted: promoteNow, - } - - // If the new trial is good enough, it should be promoted immediately (whether or not numPromote - // changes). Otherwise, if numPromote changes, there is some other trial that should be promoted, - // unless it has been promoted already. - switch { - case promoteNow: - return []model.RequestID{requestID} - case numPromote != oldNumPromote && !r.Metrics[oldNumPromote].Promoted: - t := &r.Metrics[oldNumPromote] - t.Promoted = true - return []model.RequestID{t.RequestID} - default: - return nil - } -} - -func (s *asyncHalvingSearch) initialOperations(ctx context) ([]Operation, error) { - // The number of initialOperations will control the degree of parallelism - // of the search experiment since we guarantee that each validationComplete - // call will return a new train workload until we reach MaxTrials. - - // We will use searcher config field if available. - // Otherwise we will default to a number of trials that will - // guarantee at least one trial at the top rung. - var ops []Operation - var maxConcurrentTrials int - - if s.MaxConcurrentTrials() > 0 { - maxConcurrentTrials = mathx.Min(s.MaxConcurrentTrials(), s.MaxTrials()) - } else { - maxConcurrentTrials = mathx.Clamp( - 1, - int(math.Pow(s.Divisor(), float64(s.NumRungs()-1))), - s.MaxTrials(), - ) - } - - for trial := 0; trial < maxConcurrentTrials; trial++ { - create := NewCreate( - ctx.rand, sampleAll(ctx.hparams, ctx.rand), model.TrialWorkloadSequencerType) - s.TrialRungs[create.RequestID] = 0 - ops = append(ops, create) - ops = append(ops, NewValidateAfter(create.RequestID, s.Rungs[0].UnitsNeeded)) - s.PendingTrials++ - } - return ops, nil -} - -func (s *asyncHalvingSearch) trialCreated( - ctx context, requestID model.RequestID, -) ([]Operation, error) { - s.Rungs[0].OutstandingTrials++ - s.TrialRungs[requestID] = 0 - return nil, nil -} - -func (s *asyncHalvingSearch) trialClosed( - ctx context, requestID model.RequestID, -) ([]Operation, error) { - s.TrialsCompleted++ - s.ClosedTrials[requestID] = true - return nil, nil -} - -func (s *asyncHalvingSearch) validationCompleted( - ctx context, requestID model.RequestID, metric interface{}, op ValidateAfter, -) ([]Operation, error) { - s.PendingTrials-- - value, ok := metric.(float64) - if !ok { - return nil, fmt.Errorf("unexpected metric type for ASHA built-in search method %v", value) - } - if !s.SmallerIsBetter { - value *= -1 - } - return s.promoteAsync(ctx, requestID, value), nil -} - -func (s *asyncHalvingSearch) promoteAsync( - ctx context, requestID model.RequestID, metric float64, -) []Operation { - // Upon a validation complete, we should return at least one more train&val workload - // unless the bracket of successive halving is finished. - rungIndex := s.TrialRungs[requestID] - rung := s.Rungs[rungIndex] - rung.OutstandingTrials-- - addedTrainWorkload := false - - var ops []Operation - // If the trial has completed the top rung's validation, close the trial. - if rungIndex == s.NumRungs()-1 { - rung.Metrics = append(rung.Metrics, - trialMetric{ - RequestID: requestID, - Metric: model.ExtendedFloat64(metric), - }, - ) - - if !s.EarlyExitTrials[requestID] { - ops = append(ops, NewClose(requestID)) - s.ClosedTrials[requestID] = true - } - } else { - // This is not the top rung, so do promotions to the next rung. - nextRung := s.Rungs[rungIndex+1] - for _, promotionID := range rung.promotionsAsync( - requestID, - metric, - s.Divisor(), - ) { - s.TrialRungs[promotionID] = rungIndex + 1 - nextRung.OutstandingTrials++ - if s.EarlyExitTrials[promotionID] { - // We make a recursive call that will behave the same - // as if we'd actually run the promoted job and received - // the worse possible result in return. - return s.promoteAsync(ctx, promotionID, ashaExitedMetricValue) - } - unitsNeeded := mathx.Max(nextRung.UnitsNeeded-rung.UnitsNeeded, 1) - ops = append(ops, NewValidateAfter(promotionID, unitsNeeded)) - addedTrainWorkload = true - s.PendingTrials++ - } - } - - allTrials := len(s.TrialRungs) - s.InvalidTrials - if !addedTrainWorkload && allTrials < s.MaxTrials() { - s.PendingTrials++ - create := NewCreate( - ctx.rand, sampleAll(ctx.hparams, ctx.rand), model.TrialWorkloadSequencerType) - s.TrialRungs[create.RequestID] = 0 - ops = append(ops, create) - ops = append(ops, NewValidateAfter(create.RequestID, s.Rungs[0].UnitsNeeded)) - } - - // Only close out trials once we have reached the MaxTrials for the searcher. - if len(s.Rungs[0].Metrics) == s.MaxTrials() { - ops = append(ops, s.closeOutRungs()...) - } - return ops -} - -// closeOutRungs closes all remaining unpromoted trials in any rungs that have no more outstanding -// trials. -func (s *asyncHalvingSearch) closeOutRungs() []Operation { - var ops []Operation - for _, rung := range s.Rungs { - if rung.OutstandingTrials > 0 { - break - } - for _, trialMetric := range rung.Metrics { - if !trialMetric.Promoted && !s.ClosedTrials[trialMetric.RequestID] { - if !s.EarlyExitTrials[trialMetric.RequestID] { - ops = append(ops, NewClose(trialMetric.RequestID)) - s.ClosedTrials[trialMetric.RequestID] = true - } - } - } - } - return ops -} - -func (s *asyncHalvingSearch) progress( - map[model.RequestID]PartialUnits, map[model.RequestID]bool, -) float64 { - if s.MaxConcurrentTrials() > 0 && s.PendingTrials > s.MaxConcurrentTrials() { - panic("pending trials is greater than max_concurrent_trials") - } - allTrials := len(s.Rungs[0].Metrics) - // Give ourselves an overhead of 20% of MaxTrials when calculating progress. - progress := float64(allTrials) / (1.2 * float64(s.MaxTrials())) - if allTrials == s.MaxTrials() { - numValidTrials := float64(s.TrialsCompleted) - float64(s.InvalidTrials) - progressNoOverhead := numValidTrials / float64(s.MaxTrials()) - progress = math.Max(progressNoOverhead, progress) - } - return progress -} - -func (s *asyncHalvingSearch) trialExitedEarly( - ctx context, requestID model.RequestID, exitedReason model.ExitedReason, -) ([]Operation, error) { - s.PendingTrials-- - if exitedReason == model.InvalidHP || exitedReason == model.InitInvalidHP { - var ops []Operation - s.EarlyExitTrials[requestID] = true - ops = append(ops, NewClose(requestID)) - s.ClosedTrials[requestID] = true - s.InvalidTrials++ - // Remove metrics associated with InvalidHP trial across all rungs - highestRungIndex := s.TrialRungs[requestID] - rung := s.Rungs[highestRungIndex] - rung.OutstandingTrials-- - for rungIndex := 0; rungIndex <= highestRungIndex; rungIndex++ { - rung := s.Rungs[rungIndex] - for i, trialMetric := range rung.Metrics { - if trialMetric.RequestID == requestID { - rung.Metrics = append(rung.Metrics[:i], rung.Metrics[i+1:]...) - break - } - } - } - // Add new trial to searcher queue - create := NewCreate( - ctx.rand, sampleAll(ctx.hparams, ctx.rand), model.TrialWorkloadSequencerType) - s.TrialRungs[create.RequestID] = 0 - ops = append(ops, create) - ops = append(ops, NewValidateAfter(create.RequestID, s.Rungs[0].UnitsNeeded)) - s.PendingTrials++ - return ops, nil - } - s.EarlyExitTrials[requestID] = true - s.ClosedTrials[requestID] = true - return s.promoteAsync(ctx, requestID, ashaExitedMetricValue), nil -} diff --git a/master/pkg/searcher/asha_stopping.go b/master/pkg/searcher/asha_stopping.go index 87f9f645f2b..6c3e103182e 100644 --- a/master/pkg/searcher/asha_stopping.go +++ b/master/pkg/searcher/asha_stopping.go @@ -6,50 +6,79 @@ import ( "math" "sort" + "github.com/determined-ai/determined/master/pkg/ptrs" + "github.com/determined-ai/determined/master/pkg/mathx" "github.com/determined-ai/determined/master/pkg/model" "github.com/determined-ai/determined/master/pkg/schemas/expconf" ) -// AsyncHalvingStoppingSearch implements a modified version of the asynchronous successive -// halving algorithm (ASHA) that does not require fault tolerance to perform early-stopping. -// For each trial, after a train and validation workload, the algorithm will decide whether -// to stop or continue training the trial based on the ranking of the validation metric -// compared to other trials in a particular rung. Once a trial has been stopped, it will not -// be resumed later; this is why the algorithm does not require fault tolerance. -// The searcher state and config match that of AsyncHalvingSearch but we will only run -// the stopping based version if StopOnce is true. +// AsyncHalvingStoppingSearch implements a version of the asynchronous successive halving +// algorithm (ASHA) that early-stops worse performing trials rather than actively promoting better +// performing trials. When a new validation metric is reported, the searcher decides if the run +// should be stopped based on the ranking of the metric compared to other trials' metrics in the +// same rung. type asyncHalvingStoppingSearch struct { expconf.AsyncHalvingConfig SmallerIsBetter bool + Metric string asyncHalvingSearchState } +type ( + asyncHalvingSearchState struct { + Rungs []*rung `json:"rungs"` + TrialRungs map[model.RequestID]int `json:"trial_rungs"` + // EarlyExitTrials contains trials that exited early that are still considered in the search. + EarlyExitTrials map[model.RequestID]bool `json:"early_exit_trials"` + TrialsCompleted int `json:"trials_completed"` + InvalidTrials int `json:"invalid_trials"` + SearchMethodType SearchMethodType `json:"search_method_type"` + } -func newAsyncHalvingStoppingSearch( - config expconf.AsyncHalvingConfig, smallerIsBetter bool, -) SearchMethod { - rungs := make([]*rung, 0, config.NumRungs()) - var unitsNeeded uint64 - for id := 0; id < config.NumRungs(); id++ { + runMetric struct { + RequestID model.RequestID `json:"request_id"` + Metric model.ExtendedFloat64 `json:"metric"` + } + rung struct { + UnitsNeeded uint64 `json:"units_needed"` + Metrics []runMetric `json:"metrics"` + } +) + +func (r *rung) String() string { + return fmt.Sprintf("Rung{UnitsNeeded: %d, Metrics: %v}", r.UnitsNeeded, r.Metrics) +} + +const ashaExitedMetricValue = math.MaxFloat64 + +func makeRungs(numRungs int, divisor float64, maxLength uint64) []*rung { + rungs := make([]*rung, 0, numRungs) + for i := 0; i < numRungs; i++ { // We divide the MaxLength by downsampling rate to get the target units - // for a rung. - downsamplingRate := math.Pow(config.Divisor(), float64(config.NumRungs()-id-1)) - unitsNeeded += mathx.Max(uint64(float64(config.MaxLength().Units)/downsamplingRate), 1) + // for a bracketRung. + downsamplingRate := math.Pow(divisor, float64(numRungs-i-1)) + unitsNeeded := mathx.Max(uint64(float64(maxLength)/downsamplingRate), 1) rungs = append(rungs, &rung{ - UnitsNeeded: unitsNeeded, - OutstandingTrials: 0, + UnitsNeeded: unitsNeeded, }) } + return rungs +} + +func newAsyncHalvingStoppingSearch( + config expconf.AsyncHalvingConfig, smallerIsBetter bool, metric string, +) SearchMethod { + rungs := makeRungs(config.NumRungs(), config.Divisor(), config.Length().Units) return &asyncHalvingStoppingSearch{ AsyncHalvingConfig: config, SmallerIsBetter: smallerIsBetter, + Metric: metric, asyncHalvingSearchState: asyncHalvingSearchState{ Rungs: rungs, TrialRungs: make(map[model.RequestID]int), EarlyExitTrials: make(map[model.RequestID]bool), - ClosedTrials: make(map[model.RequestID]bool), SearchMethodType: ASHASearch, }, } @@ -63,43 +92,34 @@ func (s *asyncHalvingStoppingSearch) Restore(state json.RawMessage) error { return json.Unmarshal(state, &s.asyncHalvingSearchState) } -// promotions handles bookkeeping of validation metrics and decides whether to continue -// training the current trial. -func (r *rung) continueTraining(requestID model.RequestID, metric float64, divisor float64) bool { - // Compute cutoff for promotion to next rung to continue training. - numPromote := mathx.Max(int(float64(len(r.Metrics)+1)/divisor), 1) - - // Insert the new trial result in the appropriate place in the sorted list. +// insertMetric adds a completed validation metric to the rung in the appropriate order of all +// the metrics in the rung thus far and returns the insert index. +func (r *rung) insertMetric(requestID model.RequestID, metric float64) int { insertIndex := sort.Search( len(r.Metrics), func(i int) bool { return float64(r.Metrics[i].Metric) >= metric }, ) - // We will continue training if trial ranked in top 1/divisor for the rung or - // if there are fewere than divisor trials in the rung. - promoteNow := insertIndex < numPromote - r.Metrics = append(r.Metrics, trialMetric{}) + // Add metrics to state. + r.Metrics = append(r.Metrics, runMetric{}) copy(r.Metrics[insertIndex+1:], r.Metrics[insertIndex:]) - r.Metrics[insertIndex] = trialMetric{ + r.Metrics[insertIndex] = runMetric{ RequestID: requestID, Metric: model.ExtendedFloat64(metric), - Promoted: promoteNow, } - - return promoteNow + return insertIndex } -func (s *asyncHalvingStoppingSearch) initialOperations(ctx context) ([]Operation, error) { - // The number of initialOperations will control the degree of parallelism - // of the search experiment since we guarantee that each validationComplete - // call will return a new train workload until we reach MaxTrials. - - // We will use searcher config field if available. - // Otherwise we will default to a number of trials that will - // guarantee at least one trial at the top rung. - var ops []Operation +// initialTrials specifies the initial trials that the search will create. +// Since each run can only stop and create a new run, this effectively controls the degree of +// parallelism of the search. +func (s *asyncHalvingStoppingSearch) initialTrials(ctx context) ([]Action, error) { + var actions []Action var maxConcurrentTrials int + // Use searcher config fields to determine number of trials if set. + // Otherwise, default to a number of trials that guarantees at least one run will continue + // to the top rung. if s.MaxConcurrentTrials() > 0 { maxConcurrentTrials = mathx.Min(s.MaxConcurrentTrials(), s.MaxTrials()) } else { @@ -112,111 +132,113 @@ func (s *asyncHalvingStoppingSearch) initialOperations(ctx context) ([]Operation for trial := 0; trial < maxConcurrentTrials; trial++ { create := NewCreate( - ctx.rand, sampleAll(ctx.hparams, ctx.rand), model.TrialWorkloadSequencerType) - s.TrialRungs[create.RequestID] = 0 - ops = append(ops, create) - ops = append(ops, NewValidateAfter(create.RequestID, s.Rungs[0].UnitsNeeded)) + ctx.rand, sampleAll(ctx.hparams, ctx.rand)) + actions = append(actions, create) } - return ops, nil + return actions, nil } func (s *asyncHalvingStoppingSearch) trialCreated( ctx context, requestID model.RequestID, -) ([]Operation, error) { - s.Rungs[0].OutstandingTrials++ +) ([]Action, error) { s.TrialRungs[requestID] = 0 return nil, nil } -func (s *asyncHalvingStoppingSearch) trialClosed( +func (s *asyncHalvingStoppingSearch) trialExited( ctx context, requestID model.RequestID, -) ([]Operation, error) { +) ([]Action, error) { s.TrialsCompleted++ - s.ClosedTrials[requestID] = true return nil, nil } +// validationCompleted handles every validation metric reported by a run and returns any resulting +// actions the searcher would like to take. func (s *asyncHalvingStoppingSearch) validationCompleted( - ctx context, requestID model.RequestID, metric interface{}, op ValidateAfter, -) ([]Operation, error) { - value, ok := metric.(float64) + ctx context, requestID model.RequestID, metrics map[string]interface{}, +) ([]Action, error) { + timeStep, value, err := s.getMetric(metrics) + if err != nil { + return nil, err + } + + ops := s.doEarlyStopping(requestID, *timeStep, *value) + allTrials := len(s.TrialRungs) - s.InvalidTrials + if len(ops) > 0 && allTrials < s.MaxTrials() { + create := NewCreate(ctx.rand, sampleAll(ctx.hparams, ctx.rand)) + ops = append(ops, create) + } + return ops, nil +} + +// getMetric reads the searcher metric and time step value from the reported validation metrics. +func (s *asyncHalvingStoppingSearch) getMetric(metrics map[string]interface{}) (*uint64, *float64, error) { + searcherMetric, ok := metrics[s.Metric].(float64) + if !ok { - return nil, fmt.Errorf("unexpected metric type for ASHA built-in search method %v", value) + return nil, nil, fmt.Errorf("error parsing searcher metric (%s) from validation metrics: %v", s.Metric, metrics) } if !s.SmallerIsBetter { - value *= -1 + searcherMetric *= -1 + } + + unit := string(s.Length().Unit) + stepNum, ok := metrics[unit].(float64) + if !ok { + return nil, nil, fmt.Errorf("error parsing searcher time metric (%s) in validation metrics: %v", unit, metrics) } - return s.promoteAsync(ctx, requestID, value), nil + + return ptrs.Ptr(uint64(stepNum)), &searcherMetric, nil } -func (s *asyncHalvingStoppingSearch) promoteAsync( - ctx context, requestID model.RequestID, metric float64, -) []Operation { - // Upon a validation complete, we should return at least one more train&val workload - // unless the bracket of successive halving is finished. +// doEarlyStopping handles early-stopping and record-keeping logic for a validation metric reported to the +// searcher. +// If the metric qualifies the run for a rung but is not in the top 1/divisor trials for that rung, +// doEarlyStopping will return a single `searcher.Stop` action. Otherwise, no actions will be returned. +func (s *asyncHalvingStoppingSearch) doEarlyStopping( + requestID model.RequestID, timeStep uint64, metric float64, +) []Action { rungIndex := s.TrialRungs[requestID] - rung := s.Rungs[rungIndex] - rung.OutstandingTrials-- - addedTrainWorkload := false - - var ops []Operation - // If the trial has completed the top rung's validation, close the trial. - if rungIndex == s.NumRungs()-1 { - rung.Metrics = append(rung.Metrics, - trialMetric{ - RequestID: requestID, - Metric: model.ExtendedFloat64(metric), - }, - ) + var actions []Action + + // Starting at current rung, check if run should continue to next rung or early-stop. + // Since validations aren't controlled by searcher, they could complete > 1 rungs at a time. + for r := rungIndex; r < s.NumRungs(); r++ { + rung := s.Rungs[r] + s.TrialRungs[requestID] = r - if !s.EarlyExitTrials[requestID] { - ops = append(ops, NewClose(requestID)) - s.ClosedTrials[requestID] = true + // If run has not completed enough steps to qualify for this rung, exit. + if timeStep < rung.UnitsNeeded { + return actions } - } else { - // This is not the top rung, so do promotions to the next rung. - nextRung := s.Rungs[rungIndex+1] - // We need to run continueTraining even if the trial was terminated early so that we - // can add the metric to the rung. - promoteTrial := rung.continueTraining( - requestID, - metric, - s.Divisor(), - ) - // In contrast to promotion-based ASHA, we will not let early-exited trials add - // -/+inf metrics to higher rungs even if portion of terminated trials in bottom rung - // is greater than 1 - 1 / divisor. - if !s.EarlyExitTrials[requestID] { - if promoteTrial { - s.TrialRungs[requestID] = rungIndex + 1 - nextRung.OutstandingTrials++ - unitsNeeded := mathx.Max(nextRung.UnitsNeeded-rung.UnitsNeeded, 1) - ops = append(ops, NewValidateAfter(requestID, unitsNeeded)) - addedTrainWorkload = true - } else { - ops = append(ops, NewClose(requestID)) - s.ClosedTrials[requestID] = true - } + + insertIndex := rung.insertMetric(requestID, metric) + + // If this is the top rung, close the run and exit. + if r == s.NumRungs()-1 { + actions = append(actions, NewStop(requestID)) + return actions } - } - allTrials := len(s.TrialRungs) - s.InvalidTrials - if !addedTrainWorkload && allTrials < s.MaxTrials() { - create := NewCreate( - ctx.rand, sampleAll(ctx.hparams, ctx.rand), model.TrialWorkloadSequencerType) - s.TrialRungs[create.RequestID] = 0 - ops = append(ops, create) - ops = append(ops, NewValidateAfter(create.RequestID, s.Rungs[0].UnitsNeeded)) - } + // Top 1/divisor trials should continue, 1 - 1/divisor trials should be stopped. + // If trials < divisor, continue only if this is the best performing run so far. + numContinue := mathx.Max(int(float64(len(rung.Metrics))/s.Divisor()), 1) + + if insertIndex >= numContinue { + actions = append(actions, NewStop(requestID)) + return actions + } - return ops + // Continue to next rung. + } + return actions } func (s *asyncHalvingStoppingSearch) progress( - map[model.RequestID]PartialUnits, map[model.RequestID]bool, + map[model.RequestID]float64, map[model.RequestID]bool, ) float64 { allTrials := len(s.Rungs[0].Metrics) - // Give ourselves an overhead of 20% of maxTrials when calculating progress. + // Give ourselves an overhead of 20% of max trials when calculating progress. progress := float64(allTrials) / (1.2 * float64(s.MaxTrials())) if allTrials == s.MaxTrials() { numValidTrials := float64(s.TrialsCompleted) - float64(s.InvalidTrials) @@ -228,12 +250,11 @@ func (s *asyncHalvingStoppingSearch) progress( func (s *asyncHalvingStoppingSearch) trialExitedEarly( ctx context, requestID model.RequestID, exitedReason model.ExitedReason, -) ([]Operation, error) { +) ([]Action, error) { if exitedReason == model.InvalidHP || exitedReason == model.InitInvalidHP { - var ops []Operation + var actions []Action s.EarlyExitTrials[requestID] = true - ops = append(ops, NewClose(requestID)) - s.ClosedTrials[requestID] = true + actions = append(actions, NewStop(requestID)) s.InvalidTrials++ // Remove metrics associated with InvalidHP trial across all rungs highestRungIndex := s.TrialRungs[requestID] @@ -247,14 +268,26 @@ func (s *asyncHalvingStoppingSearch) trialExitedEarly( } } // Add new trial to searcher queue - create := NewCreate( - ctx.rand, sampleAll(ctx.hparams, ctx.rand), model.TrialWorkloadSequencerType) - s.TrialRungs[create.RequestID] = 0 - ops = append(ops, create) - ops = append(ops, NewValidateAfter(create.RequestID, s.Rungs[0].UnitsNeeded)) - return ops, nil + create := NewCreate(ctx.rand, sampleAll(ctx.hparams, ctx.rand)) + actions = append(actions, create) + return actions, nil } s.EarlyExitTrials[requestID] = true - s.ClosedTrials[requestID] = true - return s.promoteAsync(ctx, requestID, ashaExitedMetricValue), nil + + var actions []Action + rungIndex := s.TrialRungs[requestID] + rung := s.Rungs[rungIndex] + + rung.insertMetric(requestID, ashaExitedMetricValue) + + allTrials := len(s.TrialRungs) - s.InvalidTrials + if allTrials < s.MaxTrials() { + create := NewCreate(ctx.rand, sampleAll(ctx.hparams, ctx.rand)) + actions = append(actions, create) + } + return actions, nil +} + +func (s *asyncHalvingStoppingSearch) Type() SearchMethodType { + return s.SearchMethodType } diff --git a/master/pkg/searcher/asha_stopping_test.go b/master/pkg/searcher/asha_stopping_test.go index 0228351d9bb..87cd2a1ecd7 100644 --- a/master/pkg/searcher/asha_stopping_test.go +++ b/master/pkg/searcher/asha_stopping_test.go @@ -4,262 +4,448 @@ package searcher import ( "testing" + "github.com/stretchr/testify/require" + + "github.com/determined-ai/determined/master/pkg/model" "github.com/determined-ai/determined/master/pkg/ptrs" "github.com/determined-ai/determined/master/pkg/schemas" "github.com/determined-ai/determined/master/pkg/schemas/expconf" ) -func TestASHAStoppingSearcherRecords(t *testing.T) { - actual := expconf.AsyncHalvingConfig{ - RawNumRungs: ptrs.Ptr(3), - RawMaxLength: ptrs.Ptr(expconf.NewLengthInRecords(576000)), - RawDivisor: ptrs.Ptr[float64](3), - RawMaxTrials: ptrs.Ptr(12), - RawStopOnce: ptrs.Ptr(true), - RawMaxConcurrentTrials: ptrs.Ptr(2), +func TestMakeRungs(t *testing.T) { + cases := []struct { + numRungs int + maxTime uint64 + divisor float64 + expectedRungs []*rung + }{ + { + numRungs: 3, + maxTime: 9, + divisor: float64(3), + expectedRungs: []*rung{ + { + UnitsNeeded: 1, + }, + { + UnitsNeeded: 3, + }, + { + UnitsNeeded: 9, + }, + }, + }, + { + numRungs: 4, + maxTime: 10, + divisor: float64(2), + expectedRungs: []*rung{ + { + UnitsNeeded: 1, + }, + { + UnitsNeeded: 2, + }, + { + UnitsNeeded: 5, + }, + { + UnitsNeeded: 10, + }, + }, + }, + { + numRungs: 1, + maxTime: 9, + divisor: float64(3), + expectedRungs: []*rung{ + { + UnitsNeeded: 9, + }, + }, + }, + { + numRungs: 3, + maxTime: 900, + divisor: float64(3), + expectedRungs: []*rung{ + { + UnitsNeeded: 100, + }, + { + UnitsNeeded: 300, + }, + { + UnitsNeeded: 900, + }, + }, + }, } - actual = schemas.WithDefaults(actual) - // Stopping-based ASHA will only promote if a trial is in top 1/3 of trials in the rung or if - // there have been no promotions so far. Since trials cannot be restarted and metrics increase - // for later trials, only the first trial will be promoted and all others will be stopped on - // the first rung. See continueTraining method in asha_stopping.go for the logic. - expected := [][]ValidateAfter{ - toOps("64000R 192000R 576000R"), - toOps("64000R"), toOps("64000R"), toOps("64000R"), - toOps("64000R"), toOps("64000R"), toOps("64000R"), - toOps("64000R"), toOps("64000R"), toOps("64000R"), - toOps("64000R"), toOps("64000R"), + for _, c := range cases { + rungs := makeRungs(c.numRungs, c.divisor, c.maxTime) + require.Equal(t, c.expectedRungs, rungs) } - checkSimulation(t, newAsyncHalvingStoppingSearch(actual, true), nil, TrialIDMetric, expected) } -func TestASHAStoppingSearcherBatches(t *testing.T) { - actual := expconf.AsyncHalvingConfig{ - RawNumRungs: ptrs.Ptr(3), - RawMaxLength: ptrs.Ptr(expconf.NewLengthInBatches(9000)), - RawDivisor: ptrs.Ptr[float64](3), - RawMaxTrials: ptrs.Ptr(12), - RawStopOnce: ptrs.Ptr(true), - RawMaxConcurrentTrials: ptrs.Ptr(2), +func TestInsertCompletedMetric(t *testing.T) { + cases := []struct { + newMetric float64 + existingMetrics []float64 + expectedInsertIndex int + expectedMetrics []float64 + }{ + { + newMetric: 1.2, + existingMetrics: []float64{0.0, 1.5, 2.1}, + expectedInsertIndex: 1, + expectedMetrics: []float64{0.0, 1.2, 1.5, 2.1}, + }, + { + newMetric: 3.0, + existingMetrics: []float64{0.0, 1.5, 2.0}, + expectedInsertIndex: 3, + expectedMetrics: []float64{0.0, 1.5, 2.0, 3.0}, + }, + { + newMetric: -3.4, + existingMetrics: []float64{-3.0, -2.0, -1.0}, + expectedInsertIndex: 0, + expectedMetrics: []float64{-3.4, -3.0, -2.0, -1.0}, + }, + { + newMetric: 1.2, + existingMetrics: []float64{}, + expectedInsertIndex: 0, + expectedMetrics: []float64{1.2}, + }, + } + rung := rung{ + UnitsNeeded: 0, + Metrics: []runMetric{}, } - actual = schemas.WithDefaults(actual) - expected := [][]ValidateAfter{ - toOps("1000B 3000B 9000B"), - toOps("1000B"), toOps("1000B"), toOps("1000B"), - toOps("1000B"), toOps("1000B"), toOps("1000B"), - toOps("1000B"), toOps("1000B"), toOps("1000B"), - toOps("1000B"), toOps("1000B"), + for _, c := range cases { + var currentMetrics []runMetric + for _, m := range c.existingMetrics { + currentMetrics = append(currentMetrics, runMetric{ + Metric: model.ExtendedFloat64(m), + }) + } + rung.Metrics = currentMetrics + insertIndex := rung.insertMetric(model.RequestID{}, c.newMetric) + var addedMetrics []float64 + for _, m := range rung.Metrics { + addedMetrics = append(addedMetrics, float64(m.Metric)) + } + require.Equal(t, c.expectedInsertIndex, insertIndex) + require.Equal(t, c.expectedMetrics, addedMetrics) } - checkSimulation(t, newAsyncHalvingStoppingSearch(actual, true), nil, TrialIDMetric, expected) } -func TestASHAStoppingSearcherEpochs(t *testing.T) { - actual := expconf.AsyncHalvingConfig{ - RawNumRungs: ptrs.Ptr(3), - RawMaxLength: ptrs.Ptr(expconf.NewLengthInEpochs(12)), - RawDivisor: ptrs.Ptr[float64](3), - RawMaxTrials: ptrs.Ptr(12), - RawStopOnce: ptrs.Ptr(true), - RawMaxConcurrentTrials: ptrs.Ptr(2), +func TestGetMetric(t *testing.T) { + cases := []struct { + metrics map[string]interface{} + metricName string + timeMetricName string + timeMetric int + smallerIsBetter bool + expectedTimeStep int + expectedMetric float64 + expectedError string + }{ + { + metrics: map[string]interface{}{"loss": 0.25, "batches": 2.0}, + metricName: "loss", + timeMetricName: "batches", + smallerIsBetter: true, + expectedTimeStep: 2, + expectedMetric: 0.25, + }, + { + metrics: map[string]interface{}{"loss": 0.2, "batches": 3.0}, + metricName: "loss", + timeMetricName: "batches", + smallerIsBetter: false, + expectedTimeStep: 3, + expectedMetric: -0.2, + }, + { + metrics: map[string]interface{}{"loss": 1.2, "custom_time_step": 5.0}, + metricName: "loss", + timeMetricName: "custom_time_step", + smallerIsBetter: true, + expectedTimeStep: 5, + expectedMetric: 1.2, + }, + { + metrics: model.JSONObj{"batches": 2.0}, + metricName: "loss", + timeMetricName: "batches", + smallerIsBetter: true, + expectedError: "error parsing searcher metric", + }, } - actual = schemas.WithDefaults(actual) - expected := [][]ValidateAfter{ - toOps("1E 4E 12E"), - toOps("1E"), toOps("1E"), toOps("1E"), - toOps("1E"), toOps("1E"), toOps("1E"), - toOps("1E"), toOps("1E"), toOps("1E"), - toOps("1E"), toOps("1E"), + + searcher := &asyncHalvingStoppingSearch{} + for _, c := range cases { + searcher.Metric = c.metricName + searcher.RawTimeMetric = &c.timeMetricName + searcher.SmallerIsBetter = c.smallerIsBetter + searcher.RawMaxTime = ptrs.Ptr(10) + stepNum, searcherMetric, err := searcher.getMetric(c.metrics) + if c.expectedError != "" { + require.ErrorContains(t, err, c.expectedError) + } else { + require.NoError(t, err, "got unexpected error %v: %v", err, c) + require.Equal(t, uint64(c.expectedTimeStep), *stepNum, "time step does not match") + require.InEpsilon(t, c.expectedMetric, *searcherMetric, 0.001, "searcher metric value doesn't match") + } } - checkSimulation(t, newAsyncHalvingStoppingSearch(actual, true), nil, TrialIDMetric, expected) } -func TestASHAStoppingSearchMethod(t *testing.T) { - testCases := []valueSimulationTestCase{ +func TestStopTrials(t *testing.T) { + type testMetric struct { + rID model.RequestID + timeStep uint64 + metric float64 + } + + cases := []struct { + name string + rungs []*rung + runRungs map[model.RequestID]int + divisor float64 + metric testMetric + expectedOps []Action + expectedRunRungs map[model.RequestID]int + expectedRungs []*rung + }{ { - name: "smaller is better", - expectedTrials: []predefinedTrial{ - newConstantPredefinedTrial(toOps("1000B 3000B 9000B"), 0.01), - newConstantPredefinedTrial(toOps("1000B"), 0.02), - newConstantPredefinedTrial(toOps("1000B"), 0.03), - newConstantPredefinedTrial(toOps("1000B"), 0.04), - newConstantPredefinedTrial(toOps("1000B"), 0.05), - newConstantPredefinedTrial(toOps("1000B"), 0.06), - newConstantPredefinedTrial(toOps("1000B"), 0.07), - newConstantPredefinedTrial(toOps("1000B"), 0.08), - newConstantPredefinedTrial(toOps("1000B"), 0.09), - newConstantPredefinedTrial(toOps("1000B"), 0.10), - newConstantPredefinedTrial(toOps("1000B"), 0.11), - newConstantPredefinedTrial(toOps("1000B"), 0.12), - }, - config: expconf.SearcherConfig{ - RawSmallerIsBetter: ptrs.Ptr(true), - RawAsyncHalvingConfig: &expconf.AsyncHalvingConfig{ - RawNumRungs: ptrs.Ptr(3), - RawMaxLength: ptrs.Ptr(expconf.NewLengthInBatches(9000)), - RawMaxTrials: ptrs.Ptr(12), - RawDivisor: ptrs.Ptr[float64](3), - RawStopOnce: ptrs.Ptr(true), + name: "first validation", + rungs: []*rung{ + { + UnitsNeeded: 1, + }, + { + UnitsNeeded: 3, + }, + { + UnitsNeeded: 9, }, }, - }, - { - name: "smaller is better (round robin)", - expectedTrials: []predefinedTrial{ - newConstantPredefinedTrial(toOps("1000B 3000B 9000B"), 0.01), - newConstantPredefinedTrial(toOps("1000B"), 0.02), - newConstantPredefinedTrial(toOps("1000B"), 0.03), - newConstantPredefinedTrial(toOps("1000B"), 0.04), - newConstantPredefinedTrial(toOps("1000B 3000B 9000B"), 0.01), - newConstantPredefinedTrial(toOps("1000B"), 0.02), - newConstantPredefinedTrial(toOps("1000B"), 0.03), - newConstantPredefinedTrial(toOps("1000B"), 0.04), - newConstantPredefinedTrial(toOps("1000B 3000B 9000B"), 0.01), - newConstantPredefinedTrial(toOps("1000B"), 0.02), - newConstantPredefinedTrial(toOps("1000B"), 0.03), - newConstantPredefinedTrial(toOps("1000B"), 0.04), + runRungs: map[model.RequestID]int{ + mockRequestID(1): 0, }, - config: expconf.SearcherConfig{ - RawSmallerIsBetter: ptrs.Ptr(true), - RawAsyncHalvingConfig: &expconf.AsyncHalvingConfig{ - RawNumRungs: ptrs.Ptr(3), - RawMaxLength: ptrs.Ptr(expconf.NewLengthInBatches(9000)), - RawMaxTrials: ptrs.Ptr(12), - RawDivisor: ptrs.Ptr[float64](3), - RawStopOnce: ptrs.Ptr(true), - }, + divisor: 3.0, + metric: testMetric{ + rID: mockRequestID(1), + timeStep: 1, + metric: 0.5, }, - }, - { - name: "smaller is not better", - expectedTrials: []predefinedTrial{ - newConstantPredefinedTrial(toOps("1000B 3000B 9000B"), 0.01), - newConstantPredefinedTrial(toOps("1000B 3000B 9000B"), 0.02), - newConstantPredefinedTrial(toOps("1000B 3000B 9000B"), 0.03), - newConstantPredefinedTrial(toOps("1000B 3000B 9000B"), 0.04), - newConstantPredefinedTrial(toOps("1000B 3000B 9000B"), 0.05), - newConstantPredefinedTrial(toOps("1000B 3000B 9000B"), 0.06), - newConstantPredefinedTrial(toOps("1000B 3000B 9000B"), 0.07), - newConstantPredefinedTrial(toOps("1000B 3000B 9000B"), 0.08), - newConstantPredefinedTrial(toOps("1000B 3000B 9000B"), 0.09), - newConstantPredefinedTrial(toOps("1000B 3000B 9000B"), 0.10), - newConstantPredefinedTrial(toOps("1000B 3000B 9000B"), 0.11), - newConstantPredefinedTrial(toOps("1000B 3000B 9000B"), 0.12), + expectedRunRungs: map[model.RequestID]int{ + mockRequestID(1): 1, }, - config: expconf.SearcherConfig{ - RawSmallerIsBetter: ptrs.Ptr(false), - RawAsyncHalvingConfig: &expconf.AsyncHalvingConfig{ - RawNumRungs: ptrs.Ptr(3), - RawMaxLength: ptrs.Ptr(expconf.NewLengthInBatches(9000)), - RawMaxTrials: ptrs.Ptr(12), - RawDivisor: ptrs.Ptr[float64](3), - RawStopOnce: ptrs.Ptr(true), + expectedRungs: []*rung{ + { + UnitsNeeded: 1, + Metrics: []runMetric{ + { + RequestID: mockRequestID(1), + Metric: model.ExtendedFloat64(0.5), + }, + }, + }, + { + UnitsNeeded: 3, + }, + { + UnitsNeeded: 9, }, }, + expectedOps: []Action(nil), }, { - name: "smaller is not better (round robin)", - expectedTrials: []predefinedTrial{ - newConstantPredefinedTrial(toOps("1000B 3000B 9000B"), 0.01), - newConstantPredefinedTrial(toOps("1000B 3000B 9000B"), 0.02), - newConstantPredefinedTrial(toOps("1000B 3000B 9000B"), 0.03), - newConstantPredefinedTrial(toOps("1000B 3000B 9000B"), 0.04), - newConstantPredefinedTrial(toOps("1000B"), 0.01), - newConstantPredefinedTrial(toOps("1000B"), 0.02), - newConstantPredefinedTrial(toOps("1000B 3000B"), 0.03), - newConstantPredefinedTrial(toOps("1000B 3000B 9000B"), 0.04), - newConstantPredefinedTrial(toOps("1000B"), 0.01), - newConstantPredefinedTrial(toOps("1000B"), 0.02), - newConstantPredefinedTrial(toOps("1000B 3000B"), 0.03), - newConstantPredefinedTrial(toOps("1000B 3000B 9000B"), 0.04), - }, - config: expconf.SearcherConfig{ - RawSmallerIsBetter: ptrs.Ptr(false), - RawAsyncHalvingConfig: &expconf.AsyncHalvingConfig{ - RawNumRungs: ptrs.Ptr(3), - RawMaxLength: ptrs.Ptr(expconf.NewLengthInBatches(9000)), - RawMaxTrials: ptrs.Ptr(12), - RawDivisor: ptrs.Ptr[float64](3), - RawStopOnce: ptrs.Ptr(true), + name: "second validation better than first", + rungs: []*rung{ + { + UnitsNeeded: 1, + Metrics: []runMetric{ + { + RequestID: mockRequestID(1), + Metric: model.ExtendedFloat64(0.5), + }, + }, + }, + { + UnitsNeeded: 3, + }, + { + UnitsNeeded: 9, }, }, - }, - { - name: "early exit -- smaller is better (round robin)", - expectedTrials: []predefinedTrial{ - newConstantPredefinedTrial(toOps("1000B 3000B 9000B"), 0.01), - newConstantPredefinedTrial(toOps("1000B"), 0.02), - newConstantPredefinedTrial(toOps("1000B"), 0.03), - newConstantPredefinedTrial(toOps("1000B"), 0.04), - newEarlyExitPredefinedTrial(toOps("1000B 3000B"), 0.01), - newConstantPredefinedTrial(toOps("1000B"), 0.02), - newConstantPredefinedTrial(toOps("1000B"), 0.03), - newConstantPredefinedTrial(toOps("1000B"), 0.04), - newConstantPredefinedTrial(toOps("1000B 3000B 9000B"), 0.01), - newConstantPredefinedTrial(toOps("1000B"), 0.02), - newConstantPredefinedTrial(toOps("1000B"), 0.03), - newConstantPredefinedTrial(toOps("1000B"), 0.04), + runRungs: map[model.RequestID]int{ + mockRequestID(1): 1, + mockRequestID(2): 0, }, - config: expconf.SearcherConfig{ - RawSmallerIsBetter: ptrs.Ptr(true), - RawAsyncHalvingConfig: &expconf.AsyncHalvingConfig{ - RawNumRungs: ptrs.Ptr(3), - RawMaxLength: ptrs.Ptr(expconf.NewLengthInBatches(9000)), - RawMaxTrials: ptrs.Ptr(12), - RawDivisor: ptrs.Ptr[float64](3), - RawStopOnce: ptrs.Ptr(true), - }, + divisor: 3.0, + metric: testMetric{ + rID: mockRequestID(2), + timeStep: 1, + metric: 0.4, }, - }, - { - name: "early exit -- smaller is not better (round robin)", - expectedTrials: []predefinedTrial{ - newConstantPredefinedTrial(toOps("1000B 3000B 9000B"), 0.01), - newConstantPredefinedTrial(toOps("1000B 3000B 9000B"), 0.02), - newConstantPredefinedTrial(toOps("1000B 3000B 9000B"), 0.03), - newEarlyExitPredefinedTrial(toOps("1000B 3000B"), 0.04), - newConstantPredefinedTrial(toOps("1000B"), 0.01), - newConstantPredefinedTrial(toOps("1000B"), 0.02), - newConstantPredefinedTrial(toOps("1000B 3000B 9000B"), 0.03), - newConstantPredefinedTrial(toOps("1000B 3000B 9000B"), 0.04), - newConstantPredefinedTrial(toOps("1000B"), 0.01), - newConstantPredefinedTrial(toOps("1000B"), 0.02), - newConstantPredefinedTrial(toOps("1000B 3000B 9000B"), 0.03), - newConstantPredefinedTrial(toOps("1000B 3000B 9000B"), 0.04), + expectedRunRungs: map[model.RequestID]int{ + mockRequestID(1): 1, + mockRequestID(2): 1, }, - config: expconf.SearcherConfig{ - RawSmallerIsBetter: ptrs.Ptr(false), - RawAsyncHalvingConfig: &expconf.AsyncHalvingConfig{ - RawNumRungs: ptrs.Ptr(3), - RawMaxLength: ptrs.Ptr(expconf.NewLengthInBatches(9000)), - RawMaxTrials: ptrs.Ptr(12), - RawDivisor: ptrs.Ptr[float64](3), - RawStopOnce: ptrs.Ptr(true), + expectedRungs: []*rung{ + { + UnitsNeeded: 1, + Metrics: []runMetric{ + { + RequestID: mockRequestID(2), + Metric: model.ExtendedFloat64(0.4), + }, + { + RequestID: mockRequestID(1), + Metric: model.ExtendedFloat64(0.5), + }, + }, + }, + { + UnitsNeeded: 3, + }, + { + UnitsNeeded: 9, }, }, + expectedOps: []Action(nil), }, { - name: "single rung bracket", - expectedTrials: []predefinedTrial{ - // The first trial is promoted due to asynchronous - // promotions despite being below top 1/3 of trials in - // base rung. - newConstantPredefinedTrial(toOps("9000B"), 0.05), - newConstantPredefinedTrial(toOps("9000B"), 0.06), - newConstantPredefinedTrial(toOps("9000B"), 0.07), - newConstantPredefinedTrial(toOps("9000B"), 0.08), + name: "second validation worse than first", + rungs: []*rung{ + { + UnitsNeeded: 1, + Metrics: []runMetric{ + { + RequestID: mockRequestID(1), + Metric: model.ExtendedFloat64(0.5), + }, + }, + }, + { + UnitsNeeded: 3, + }, + { + UnitsNeeded: 9, + }, + }, + runRungs: map[model.RequestID]int{ + mockRequestID(1): 1, + mockRequestID(2): 0, }, - config: expconf.SearcherConfig{ - RawSmallerIsBetter: ptrs.Ptr(true), - RawAsyncHalvingConfig: &expconf.AsyncHalvingConfig{ - RawNumRungs: ptrs.Ptr(1), - RawMaxLength: ptrs.Ptr(expconf.NewLengthInBatches(9000)), - RawMaxTrials: ptrs.Ptr(4), - RawDivisor: ptrs.Ptr[float64](3), - RawStopOnce: ptrs.Ptr(true), + divisor: 3.0, + metric: testMetric{ + rID: mockRequestID(2), + timeStep: 1, + metric: 0.6, + }, + expectedRunRungs: map[model.RequestID]int{ + mockRequestID(1): 1, + mockRequestID(2): 0, + }, + expectedRungs: []*rung{ + { + UnitsNeeded: 1, + Metrics: []runMetric{ + { + RequestID: mockRequestID(1), + Metric: model.ExtendedFloat64(0.5), + }, + { + RequestID: mockRequestID(2), + Metric: model.ExtendedFloat64(0.6), + }, + }, + }, + { + UnitsNeeded: 3, + }, + { + UnitsNeeded: 9, }, }, + expectedOps: []Action{Stop{RequestID: mockRequestID(2)}}, }, } - runValueSimulationTestCases(t, testCases) + searcher := &asyncHalvingStoppingSearch{} + for _, c := range cases { + searcher.TrialRungs = c.runRungs + searcher.Rungs = c.rungs + searcher.AsyncHalvingConfig.RawDivisor = &c.divisor + numRungs := len(c.rungs) + searcher.AsyncHalvingConfig.RawNumRungs = &numRungs + ops := searcher.doEarlyStopping(c.metric.rID, c.metric.timeStep, c.metric.metric) + require.Equal(t, c.expectedOps, ops) + require.Equal(t, c.expectedRungs, searcher.Rungs) + require.Equal(t, c.expectedRunRungs, searcher.TrialRungs) + } +} + +func TestASHAStoppingSearchMethod(t *testing.T) { + maxConcurrentTrials := 3 + maxTrials := 10 + divisor := 3.0 + maxTime := 900 + metric := "val_loss" + config := expconf.AsyncHalvingConfig{ + RawMaxTime: &maxTime, + RawDivisor: &divisor, + RawNumRungs: ptrs.Ptr(3), + RawMaxConcurrentTrials: &maxConcurrentTrials, + RawMaxTrials: &maxTrials, + RawTimeMetric: ptrs.Ptr("batches"), + } + searcherConfig := expconf.SearcherConfig{ + RawAsyncHalvingConfig: &config, + RawSmallerIsBetter: ptrs.Ptr(true), + RawMetric: ptrs.Ptr(metric), + } + config = schemas.WithDefaults(config) + searcherConfig = schemas.WithDefaults(searcherConfig) + + intHparam := &expconf.IntHyperparameter{RawMaxval: 10, RawCount: ptrs.Ptr(3)} + hparams := expconf.Hyperparameters{ + "x": expconf.Hyperparameter{RawIntHyperparameter: intHparam}, + } + + // Create a new test searcher and verify brackets/rungs. + testSearchRunner := NewTestSearchRunner(t, searcherConfig, hparams) + search := testSearchRunner.method.(*asyncHalvingStoppingSearch) + + expectedRungs := []*rung{ + {UnitsNeeded: uint64(100)}, + {UnitsNeeded: uint64(300)}, + {UnitsNeeded: uint64(900)}, + } + + require.Equal(t, expectedRungs, search.Rungs) + + // Simulate the search. + testSearchRunner.run(900, 100, true) + + // Expect 10 total trials. + // Since we reported progressively worse metrics, only one trial should continue. + require.Len(t, testSearchRunner.trials, maxTrials) + stoppedAt900 := 0 + stoppedAt100 := 0 + for _, tr := range testSearchRunner.trials { + if tr.stoppedAt == 900 { + stoppedAt900++ + } + if tr.stoppedAt == 100 { + stoppedAt100++ + } + } + require.Equal(t, 1, stoppedAt900) + require.Equal(t, 9, stoppedAt100) } diff --git a/master/pkg/searcher/asha_test.go b/master/pkg/searcher/asha_test.go deleted file mode 100644 index 393dc214a2a..00000000000 --- a/master/pkg/searcher/asha_test.go +++ /dev/null @@ -1,231 +0,0 @@ -//nolint:exhaustruct -package searcher - -import ( - "testing" - - "github.com/determined-ai/determined/master/pkg/ptrs" - "github.com/determined-ai/determined/master/pkg/schemas" - "github.com/determined-ai/determined/master/pkg/schemas/expconf" -) - -func TestASHASearcherRecords(t *testing.T) { - actual := expconf.AsyncHalvingConfig{ - RawNumRungs: ptrs.Ptr(3), - RawMaxLength: ptrs.Ptr(expconf.NewLengthInRecords(576000)), - RawDivisor: ptrs.Ptr[float64](3), - RawMaxTrials: ptrs.Ptr(12), - } - actual = schemas.WithDefaults(actual) - expected := [][]ValidateAfter{ - toOps("64000R"), toOps("64000R"), toOps("64000R"), - toOps("64000R"), toOps("64000R"), toOps("64000R"), - toOps("64000R"), toOps("64000R"), - toOps("64000R 192000R"), - toOps("64000R 192000R"), - toOps("64000R 192000R"), - toOps("64000R 192000R 576000R"), - } - checkSimulation(t, newAsyncHalvingSearch(actual, true), nil, ConstantValidation, expected) -} - -func TestASHASearcherBatches(t *testing.T) { - actual := expconf.AsyncHalvingConfig{ - RawNumRungs: ptrs.Ptr(3), - RawMaxLength: ptrs.Ptr(expconf.NewLengthInBatches(9000)), - RawDivisor: ptrs.Ptr[float64](3), - RawMaxTrials: ptrs.Ptr(12), - } - actual = schemas.WithDefaults(actual) - expected := [][]ValidateAfter{ - toOps("1000B"), toOps("1000B"), toOps("1000B"), - toOps("1000B"), toOps("1000B"), toOps("1000B"), - toOps("1000B"), toOps("1000B"), - toOps("1000B 3000B"), - toOps("1000B 3000B"), - toOps("1000B 3000B"), - toOps("1000B 3000B 9000B"), - } - checkSimulation(t, newAsyncHalvingSearch(actual, true), nil, ConstantValidation, expected) -} - -func TestASHASearcherEpochs(t *testing.T) { - actual := expconf.AsyncHalvingConfig{ - RawNumRungs: ptrs.Ptr(3), - RawMaxLength: ptrs.Ptr(expconf.NewLengthInEpochs(12)), - RawDivisor: ptrs.Ptr[float64](3), - RawMaxTrials: ptrs.Ptr(12), - } - actual = schemas.WithDefaults(actual) - expected := [][]ValidateAfter{ - toOps("1E"), toOps("1E"), toOps("1E"), - toOps("1E"), toOps("1E"), toOps("1E"), - toOps("1E"), toOps("1E"), - toOps("1E 4E"), - toOps("1E 4E"), - toOps("1E 4E"), - toOps("1E 4E 12E"), - } - checkSimulation(t, newAsyncHalvingSearch(actual, true), nil, ConstantValidation, expected) -} - -func TestASHASearchMethod(t *testing.T) { - testCases := []valueSimulationTestCase{ - { - name: "smaller is better", - expectedTrials: []predefinedTrial{ - newConstantPredefinedTrial(toOps("1000B 3000B 9000B"), 0.01), - newConstantPredefinedTrial(toOps("1000B 3000B"), 0.02), - newConstantPredefinedTrial(toOps("1000B 3000B"), 0.03), - newConstantPredefinedTrial(toOps("1000B 3000B"), 0.04), - newConstantPredefinedTrial(toOps("1000B"), 0.05), - newConstantPredefinedTrial(toOps("1000B"), 0.06), - newConstantPredefinedTrial(toOps("1000B"), 0.07), - newConstantPredefinedTrial(toOps("1000B"), 0.08), - newConstantPredefinedTrial(toOps("1000B"), 0.09), - newConstantPredefinedTrial(toOps("1000B"), 0.10), - newConstantPredefinedTrial(toOps("1000B"), 0.11), - newConstantPredefinedTrial(toOps("1000B"), 0.12), - }, - config: expconf.SearcherConfig{ - RawSmallerIsBetter: ptrs.Ptr(true), - RawAsyncHalvingConfig: &expconf.AsyncHalvingConfig{ - RawNumRungs: ptrs.Ptr(3), - RawMaxLength: ptrs.Ptr(expconf.NewLengthInBatches(9000)), - RawMaxTrials: ptrs.Ptr(12), - RawDivisor: ptrs.Ptr[float64](3), - }, - }, - }, - { - name: "early exit -- smaller is better", - expectedTrials: []predefinedTrial{ - newConstantPredefinedTrial(toOps("1000B 3000B 9000B"), 0.01), - newConstantPredefinedTrial(toOps("1000B 3000B"), 0.02), - newEarlyExitPredefinedTrial(toOps("1000B 3000B"), 0.03), - newConstantPredefinedTrial(toOps("1000B 3000B"), 0.04), - newConstantPredefinedTrial(toOps("1000B"), 0.05), - newConstantPredefinedTrial(toOps("1000B"), 0.06), - newConstantPredefinedTrial(toOps("1000B"), 0.07), - newConstantPredefinedTrial(toOps("1000B"), 0.08), - newConstantPredefinedTrial(toOps("1000B"), 0.09), - newConstantPredefinedTrial(toOps("1000B"), 0.10), - newConstantPredefinedTrial(toOps("1000B"), 0.11), - newConstantPredefinedTrial(toOps("1000B"), 0.12), - }, - config: expconf.SearcherConfig{ - RawSmallerIsBetter: ptrs.Ptr(true), - RawAsyncHalvingConfig: &expconf.AsyncHalvingConfig{ - RawNumRungs: ptrs.Ptr(3), - RawMaxLength: ptrs.Ptr(expconf.NewLengthInBatches(9000)), - RawMaxTrials: ptrs.Ptr(12), - RawDivisor: ptrs.Ptr[float64](3), - }, - }, - }, - { - name: "smaller is not better", - expectedTrials: []predefinedTrial{ - newConstantPredefinedTrial(toOps("1000B 3000B 9000B"), 0.12), - newConstantPredefinedTrial(toOps("1000B 3000B"), 0.11), - newConstantPredefinedTrial(toOps("1000B 3000B"), 0.10), - newConstantPredefinedTrial(toOps("1000B 3000B"), 0.09), - newConstantPredefinedTrial(toOps("1000B"), 0.08), - newConstantPredefinedTrial(toOps("1000B"), 0.07), - newConstantPredefinedTrial(toOps("1000B"), 0.06), - newConstantPredefinedTrial(toOps("1000B"), 0.05), - newConstantPredefinedTrial(toOps("1000B"), 0.04), - newConstantPredefinedTrial(toOps("1000B"), 0.03), - newConstantPredefinedTrial(toOps("1000B"), 0.02), - newConstantPredefinedTrial(toOps("1000B"), 0.01), - }, - config: expconf.SearcherConfig{ - RawSmallerIsBetter: ptrs.Ptr(false), - RawAsyncHalvingConfig: &expconf.AsyncHalvingConfig{ - RawNumRungs: ptrs.Ptr(3), - RawMaxLength: ptrs.Ptr(expconf.NewLengthInBatches(9000)), - RawMaxTrials: ptrs.Ptr(12), - RawDivisor: ptrs.Ptr[float64](3), - }, - }, - }, - { - name: "early exit -- smaller is not better", - expectedTrials: []predefinedTrial{ - newConstantPredefinedTrial(toOps("1000B 3000B 9000B"), 0.12), - newConstantPredefinedTrial(toOps("1000B 3000B"), 0.11), - newEarlyExitPredefinedTrial(toOps("1000B 3000B"), 0.10), - newConstantPredefinedTrial(toOps("1000B 3000B"), 0.09), - newConstantPredefinedTrial(toOps("1000B"), 0.08), - newConstantPredefinedTrial(toOps("1000B"), 0.07), - newConstantPredefinedTrial(toOps("1000B"), 0.06), - newConstantPredefinedTrial(toOps("1000B"), 0.05), - newConstantPredefinedTrial(toOps("1000B"), 0.04), - newConstantPredefinedTrial(toOps("1000B"), 0.03), - newConstantPredefinedTrial(toOps("1000B"), 0.02), - newConstantPredefinedTrial(toOps("1000B"), 0.01), - }, - config: expconf.SearcherConfig{ - RawSmallerIsBetter: ptrs.Ptr(false), - RawAsyncHalvingConfig: &expconf.AsyncHalvingConfig{ - RawNumRungs: ptrs.Ptr(3), - RawMaxLength: ptrs.Ptr(expconf.NewLengthInBatches(9000)), - RawMaxTrials: ptrs.Ptr(12), - RawDivisor: ptrs.Ptr[float64](3), - }, - }, - }, - { - name: "async promotions", - expectedTrials: []predefinedTrial{ - // The first trial is promoted due to asynchronous - // promotions despite being below top 1/3 of trials in - // base rung. - newConstantPredefinedTrial(toOps("1000B 3000B"), 0.10), - newConstantPredefinedTrial(toOps("1000B"), 0.11), - newEarlyExitPredefinedTrial(toOps("1000B"), 0.12), - newConstantPredefinedTrial(toOps("1000B 3000B 9000B"), 0.01), - newConstantPredefinedTrial(toOps("1000B 3000B"), 0.02), - newConstantPredefinedTrial(toOps("1000B 3000B"), 0.03), - newConstantPredefinedTrial(toOps("1000B 3000B"), 0.04), - newConstantPredefinedTrial(toOps("1000B"), 0.05), - newConstantPredefinedTrial(toOps("1000B"), 0.06), - newConstantPredefinedTrial(toOps("1000B"), 0.07), - newConstantPredefinedTrial(toOps("1000B"), 0.08), - newConstantPredefinedTrial(toOps("1000B"), 0.09), - }, - config: expconf.SearcherConfig{ - RawSmallerIsBetter: ptrs.Ptr(true), - RawAsyncHalvingConfig: &expconf.AsyncHalvingConfig{ - RawNumRungs: ptrs.Ptr(3), - RawMaxLength: ptrs.Ptr(expconf.NewLengthInBatches(9000)), - RawMaxTrials: ptrs.Ptr(12), - RawDivisor: ptrs.Ptr[float64](3), - }, - }, - }, - { - name: "single rung bracket", - expectedTrials: []predefinedTrial{ - // The first trial is promoted due to asynchronous - // promotions despite being below top 1/3 of trials in - // base rung. - newConstantPredefinedTrial(toOps("9000B"), 0.05), - newConstantPredefinedTrial(toOps("9000B"), 0.06), - newConstantPredefinedTrial(toOps("9000B"), 0.07), - newConstantPredefinedTrial(toOps("9000B"), 0.08), - }, - config: expconf.SearcherConfig{ - RawSmallerIsBetter: ptrs.Ptr(true), - RawAsyncHalvingConfig: &expconf.AsyncHalvingConfig{ - RawNumRungs: ptrs.Ptr(1), - RawMaxLength: ptrs.Ptr(expconf.NewLengthInBatches(9000)), - RawMaxTrials: ptrs.Ptr(4), - RawDivisor: ptrs.Ptr[float64](3), - }, - }, - }, - } - - runValueSimulationTestCases(t, testCases) -} diff --git a/master/pkg/searcher/custom_search.go b/master/pkg/searcher/custom_search.go deleted file mode 100644 index aaf37062f1f..00000000000 --- a/master/pkg/searcher/custom_search.go +++ /dev/null @@ -1,147 +0,0 @@ -package searcher - -import ( - "encoding/json" - - "github.com/pkg/errors" - "google.golang.org/protobuf/types/known/structpb" - - "github.com/determined-ai/determined/master/pkg/model" - "github.com/determined-ai/determined/master/pkg/schemas/expconf" - "github.com/determined-ai/determined/proto/pkg/experimentv1" -) - -type ( - customSearchState struct { - SearchMethodType SearchMethodType `json:"search_method_type"` - SearcherEventQueue *SearcherEventQueue - CustomSearchProgress float64 - } - - customSearch struct { - expconf.CustomConfig - customSearchState - } -) - -func newCustomSearch(config expconf.CustomConfig) SearchMethod { - return &customSearch{ - CustomConfig: config, - customSearchState: customSearchState{ - SearchMethodType: CustomSearch, - SearcherEventQueue: newSearcherEventQueue(), - }, - } -} - -func (s *customSearch) initialOperations(ctx context) ([]Operation, error) { - // For this method and all the other methods in customSearch, the ID will be set in Enqueue. - s.SearcherEventQueue.Enqueue(&experimentv1.SearcherEvent{ - Event: &experimentv1.SearcherEvent_InitialOperations{ - InitialOperations: &experimentv1.InitialOperations{}, - }, - }) - - return nil, nil -} - -func (s *customSearch) getSearcherEventQueue() *SearcherEventQueue { - return s.SearcherEventQueue -} - -func (s *customSearch) setCustomSearcherProgress(progress float64) { - s.customSearchState.CustomSearchProgress = progress -} - -func (s *customSearch) trialProgress( - ctx context, - requestID model.RequestID, - progress PartialUnits, -) { - s.SearcherEventQueue.Enqueue(&experimentv1.SearcherEvent{ - Event: &experimentv1.SearcherEvent_TrialProgress{ - TrialProgress: &experimentv1.TrialProgress{ - RequestId: requestID.String(), - PartialUnits: float64(progress), - }, - }, - }) -} - -func (s *customSearch) trialCreated(ctx context, requestID model.RequestID) ([]Operation, error) { - s.SearcherEventQueue.Enqueue(&experimentv1.SearcherEvent{ - Event: &experimentv1.SearcherEvent_TrialCreated{ - TrialCreated: &experimentv1.TrialCreated{ - RequestId: requestID.String(), - }, - }, - }) - return nil, nil -} - -func (s *customSearch) progress( - trialProgress map[model.RequestID]PartialUnits, - trialsClosed map[model.RequestID]bool, -) float64 { - return s.customSearchState.CustomSearchProgress -} - -func (s *customSearch) validationCompleted( - ctx context, requestID model.RequestID, metric interface{}, op ValidateAfter, -) ([]Operation, error) { - protoMetric, err := structpb.NewValue(metric) - if err != nil { - return nil, errors.Wrapf(err, "illegal type for metric=%v", metric) - } - s.SearcherEventQueue.Enqueue(&experimentv1.SearcherEvent{ - Event: &experimentv1.SearcherEvent_ValidationCompleted{ - ValidationCompleted: &experimentv1.ValidationCompleted{ - RequestId: requestID.String(), - ValidateAfterLength: op.ToProto().Length, - Metric: protoMetric, - }, - }, - }) - return nil, nil -} - -func (s *customSearch) trialExitedEarly( - ctx context, requestID model.RequestID, exitedReason model.ExitedReason, -) ([]Operation, error) { - s.SearcherEventQueue.Enqueue(&experimentv1.SearcherEvent{ - Event: &experimentv1.SearcherEvent_TrialExitedEarly{ - TrialExitedEarly: &experimentv1.TrialExitedEarly{ - RequestId: requestID.String(), - ExitedReason: exitedReason.ToSearcherProto(), - }, - }, - }) - return nil, nil -} - -func (s *customSearch) trialClosed(ctx context, requestID model.RequestID) ([]Operation, error) { - s.SearcherEventQueue.Enqueue(&experimentv1.SearcherEvent{ - Event: &experimentv1.SearcherEvent_TrialClosed{ - TrialClosed: &experimentv1.TrialClosed{ - RequestId: requestID.String(), - }, - }, - }) - return nil, nil -} - -func (s *customSearch) Snapshot() (json.RawMessage, error) { - return json.Marshal(s.customSearchState) -} - -func (s *customSearch) Restore(state json.RawMessage) error { - if state == nil { - return nil - } - return json.Unmarshal(state, &s.customSearchState) -} - -func (s *customSearch) Unit() expconf.Unit { - // TODO: Does unit make sense for custom search? - return expconf.Batches -} diff --git a/master/pkg/searcher/custom_search_test.go b/master/pkg/searcher/custom_search_test.go deleted file mode 100644 index 93a91779fb0..00000000000 --- a/master/pkg/searcher/custom_search_test.go +++ /dev/null @@ -1,198 +0,0 @@ -//nolint:exhaustruct -package searcher - -import ( - "testing" - - "github.com/stretchr/testify/require" - "google.golang.org/protobuf/types/known/structpb" - - "github.com/determined-ai/determined/master/pkg/model" - "github.com/determined-ai/determined/master/pkg/nprand" - "github.com/determined-ai/determined/master/pkg/schemas/expconf" - "github.com/determined-ai/determined/proto/pkg/experimentv1" -) - -type idMaker struct { - num int32 -} - -func (m *idMaker) next() int32 { - m.num++ - return m.num -} - -// Test a few methods (not all because they are similar) from CustomSearchMethod and the queue. -func TestCustomSearchMethod(t *testing.T) { - config := expconf.SearcherConfig{ - RawCustomConfig: &expconf.CustomConfig{}, - } - - customSearchMethod := NewSearchMethod(config) - rand := nprand.New(0) - ctx := context{rand: rand} - - queue := customSearchMethod.(CustomSearchMethod).getSearcherEventQueue() - require.Zero(t, len(queue.events)) - - var expEvents []*experimentv1.SearcherEvent - var ids idMaker - - // Add initialOperations. - _, err := customSearchMethod.initialOperations(ctx) - require.NoError(t, err) - - expEvents = append(expEvents, &experimentv1.SearcherEvent{ - Event: &experimentv1.SearcherEvent_InitialOperations{ - InitialOperations: &experimentv1.InitialOperations{}, - }, - Id: ids.next(), - }) - require.Equal(t, expEvents, queue.GetEvents()) - - // Add trialExitedEarly. - requestID := model.NewRequestID(rand) - exitedReason := model.Errored - _, err = customSearchMethod.trialExitedEarly(ctx, requestID, exitedReason) - require.NoError(t, err) - trialExitedEarlyEvent := experimentv1.SearcherEvent_TrialExitedEarly{ - TrialExitedEarly: &experimentv1.TrialExitedEarly{ - RequestId: requestID.String(), - ExitedReason: experimentv1.TrialExitedEarly_EXITED_REASON_UNSPECIFIED, - }, - } - expEvents = append(expEvents, &experimentv1.SearcherEvent{ - Event: &trialExitedEarlyEvent, - Id: ids.next(), - }) - require.Equal(t, expEvents, queue.GetEvents()) - - // Add validationAfter. - validateAfterOp := ValidateAfter{requestID, uint64(200)} - metric := float64(10.3) - _, err = customSearchMethod.validationCompleted(ctx, requestID, metric, validateAfterOp) - require.NoError(t, err) - protoMetric, err := structpb.NewValue(metric) - require.NoError(t, err) - validationCompletedEvent := experimentv1.SearcherEvent_ValidationCompleted{ - ValidationCompleted: &experimentv1.ValidationCompleted{ - RequestId: requestID.String(), - Metric: protoMetric, - ValidateAfterLength: validateAfterOp.ToProto().Length, - }, - } - expEvents = append(expEvents, &experimentv1.SearcherEvent{ - Event: &validationCompletedEvent, - Id: ids.next(), - }) - require.Equal(t, expEvents, queue.GetEvents()) - - // Add ValidationCompleted with a dictionary of all metrics. - validateAfterOp2 := ValidateAfter{requestID, uint64(300)} - allMetrics := map[string]interface{}{ - "themetric": float64(10.3), - } - _, err = customSearchMethod.validationCompleted(ctx, requestID, allMetrics, validateAfterOp2) - require.NoError(t, err) - - protoAllMetrics, err := structpb.NewValue(allMetrics) - require.NoError(t, err) - validationCompletedEvent2 := experimentv1.SearcherEvent_ValidationCompleted{ - ValidationCompleted: &experimentv1.ValidationCompleted{ - RequestId: requestID.String(), - Metric: protoAllMetrics, - ValidateAfterLength: validateAfterOp2.ToProto().Length, - }, - } - expEvents = append(expEvents, &experimentv1.SearcherEvent{ - Event: &validationCompletedEvent2, - Id: ids.next(), - }) - require.Equal(t, expEvents, queue.GetEvents()) - - // Add trialProgress. - trialProgress := 0.02 - customSearchMethod.(CustomSearchMethod).trialProgress(ctx, requestID, PartialUnits(trialProgress)) - require.NoError(t, err) - - expEvents = append(expEvents, &experimentv1.SearcherEvent{ - Event: &experimentv1.SearcherEvent_TrialProgress{ - TrialProgress: &experimentv1.TrialProgress{ - RequestId: requestID.String(), - PartialUnits: trialProgress, - }, - }, - Id: ids.next(), - }) - require.Equal(t, expEvents, queue.GetEvents()) - - // Set customSearcherProgress. - searcherProgress := 0.4 - customSearchMethod.(CustomSearchMethod).setCustomSearcherProgress(searcherProgress) - require.InEpsilon(t, searcherProgress, customSearchMethod.progress(nil, nil), 0.01) - - // Check removeUpto. - err = queue.RemoveUpTo(2) - require.NoError(t, err) - require.Equal(t, expEvents[2:], queue.events) -} - -func TestCustomSearchWatcher(t *testing.T) { - config := expconf.SearcherConfig{ - RawCustomConfig: &expconf.CustomConfig{}, - } - - customSearchMethod := NewSearchMethod(config) - rand := nprand.New(0) - ctx := context{rand: rand} - - queue := customSearchMethod.(CustomSearchMethod).getSearcherEventQueue() - w, err := queue.Watch() - require.NoError(t, err) - - // Should immediately receive initial status. - select { - case <-w.C: - t.Fatal("received a non-empty channel but should not have") - default: - } - - var expEvents []*experimentv1.SearcherEvent - var ids idMaker - - // Add initialOperations. - _, err = customSearchMethod.initialOperations(ctx) - require.NoError(t, err) - - expEvents = append(expEvents, &experimentv1.SearcherEvent{ - Event: &experimentv1.SearcherEvent_InitialOperations{ - InitialOperations: &experimentv1.InitialOperations{}, - }, - Id: ids.next(), - }) - require.Equal(t, expEvents, queue.GetEvents()) - - // Receive events in the watcher channel after it's added. - select { - case eventsInWatcher := <-w.C: - require.Equal(t, queue.GetEvents(), eventsInWatcher) - default: - t.Fatal("did not receive events") - } - - // Unwatching should work. - queue.Unwatch(w.ID) - - // Receive events when you create a new watcher after events exist. - w2, err := queue.Watch() - require.NoError(t, err) - select { - case eventsInWatcher2 := <-w2.C: - require.Equal(t, queue.GetEvents(), eventsInWatcher2) - default: - t.Fatal("did not receive events") - } - - // Unwatching should work. - queue.Unwatch(w2.ID) -} diff --git a/master/pkg/searcher/custom_searcher_events_queue.go b/master/pkg/searcher/custom_searcher_events_queue.go deleted file mode 100644 index 43de67fd5b9..00000000000 --- a/master/pkg/searcher/custom_searcher_events_queue.go +++ /dev/null @@ -1,158 +0,0 @@ -package searcher - -import ( - "encoding/json" - "fmt" - - "github.com/google/uuid" - "google.golang.org/protobuf/encoding/protojson" - - "github.com/pkg/errors" - - "github.com/determined-ai/determined/proto/pkg/experimentv1" -) - -type ( - // SearcherEventQueue stores the list of custom searcher events and the event that was event that - // was processed last by client and acknowledged by master. - SearcherEventQueue struct { - events []*experimentv1.SearcherEvent - eventCount int32 - watchers map[uuid.UUID]chan<- []*experimentv1.SearcherEvent - } - - // searcherEventQueueJSON is used internally for JSON marshaling purposes. - searcherEventQueueJSON struct { - Events []json.RawMessage `json:"custom_searcher_events"` - EventCount int32 `json:"custom_searcher_event_count"` - } - - // EventsWatcher has a channel which allows communication to the GET searcher events API. - EventsWatcher struct { - ID uuid.UUID - C <-chan []*experimentv1.SearcherEvent - } -) - -func newSearcherEventQueue() *SearcherEventQueue { - return &SearcherEventQueue{ - events: nil, - eventCount: 0, - watchers: map[uuid.UUID]chan<- []*experimentv1.SearcherEvent{}, - } -} - -func (q *SearcherEventQueue) sendEventsToWatcher( - id uuid.UUID, - w chan<- []*experimentv1.SearcherEvent, -) { - events := make([]*experimentv1.SearcherEvent, len(q.events)) - copy(events, q.events) - w <- events - close(w) - delete(q.watchers, id) -} - -// Watch creates an eventsWatcher. If any events are currently in the queue, they are immediately -// sent; otherwise, the channel in the result will block until an event comes in. -func (q *SearcherEventQueue) Watch() (EventsWatcher, error) { - // Buffer size is 1 because we don't want to block until another goroutine receives from this - // channel and only one event list can be sent to a channel. - w := make(chan []*experimentv1.SearcherEvent, 1) - id := uuid.New() - q.watchers[id] = w - - if len(q.events) > 0 { - q.sendEventsToWatcher(id, w) - } - return EventsWatcher{ID: id, C: w}, nil -} - -// Unwatch unregisters an eventsWatcher. -func (q *SearcherEventQueue) Unwatch(id uuid.UUID) { - if q == nil { - return - } - delete(q.watchers, id) -} - -// Enqueue adds an event to the queue, setting its ID automatically. -func (q *SearcherEventQueue) Enqueue(event *experimentv1.SearcherEvent) { - q.eventCount++ - event.Id = q.eventCount - q.events = append(q.events, event) - - // Add events to all watcher channels. - for id, w := range q.watchers { - q.sendEventsToWatcher(id, w) - } -} - -// GetEvents returns all the events. -func (q *SearcherEventQueue) GetEvents() []*experimentv1.SearcherEvent { - return q.events -} - -// RemoveUpTo removes all events up to and including the one with the given event ID. -func (q *SearcherEventQueue) RemoveUpTo(eventID int) error { - maxID := int(q.eventCount) - minID := maxID - (len(q.events) - 1) - if !(minID <= eventID && eventID <= maxID) { - return fmt.Errorf("event %d not found", eventID) - } - q.events = q.events[eventID-minID+1:] - return nil -} - -// MarshalJSON implements the json.Marshaler interface. -func (q *SearcherEventQueue) MarshalJSON() ([]byte, error) { - events, err := marshalEvents(q.events) - if err != nil { - return nil, err - } - - return json.Marshal(searcherEventQueueJSON{ - Events: events, - EventCount: q.eventCount, - }) -} - -// UnmarshalJSON implements the json.Unmarshaler interface. -func (q *SearcherEventQueue) UnmarshalJSON(data []byte) error { - var js searcherEventQueueJSON - if err := json.Unmarshal(data, &js); err != nil { - return err - } - events, err := unmarshalEvents(js.Events) - if err != nil { - return err - } - q.events = events - q.eventCount = js.EventCount - q.watchers = map[uuid.UUID]chan<- []*experimentv1.SearcherEvent{} - return nil -} - -func marshalEvents(pbEvents []*experimentv1.SearcherEvent) ([]json.RawMessage, error) { - var events []json.RawMessage - for _, pbEvent := range pbEvents { - event, err := protojson.Marshal(pbEvent) - if err != nil { - return nil, errors.Wrap(err, "failed to marshal searcher event") - } - events = append(events, event) - } - return events, nil -} - -func unmarshalEvents(events []json.RawMessage) ([]*experimentv1.SearcherEvent, error) { - var pbEvents []*experimentv1.SearcherEvent - for _, event := range events { - var pbEvent experimentv1.SearcherEvent - if err := protojson.Unmarshal(event, &pbEvent); err != nil { - return nil, errors.Wrap(err, "failed to unmarshal searcher event") - } - pbEvents = append(pbEvents, &pbEvent) - } - return pbEvents, nil -} diff --git a/master/pkg/searcher/grid.go b/master/pkg/searcher/grid.go index a83c9cc6e03..e6f9c2a469a 100644 --- a/master/pkg/searcher/grid.go +++ b/master/pkg/searcher/grid.go @@ -12,8 +12,9 @@ import ( type ( // gridSearchState stores the state for grid. The state will track the remaining hp settings - // that have yet to be created for evaluation. PendingTrials tracks how many trials have - // active workloads and is used to check max_concurrent_trials for the searcher is respected. + // that have yet to be created for evaluation. RemainingTrials tracks how many trials are + // currently in progress and is used to check max_concurrent_trials for the searcher is + // respected. // Tracking searcher type on restart gives us the ability to differentiate grid searches // in a shim if needed. gridSearchState struct { @@ -27,7 +28,6 @@ type ( defaultSearchMethod expconf.GridConfig gridSearchState - trials int } ) @@ -41,29 +41,26 @@ func newGridSearch(config expconf.GridConfig) SearchMethod { } } -func (s *gridSearch) initialOperations(ctx context) ([]Operation, error) { +func (s *gridSearch) initialTrials(ctx context) ([]Action, error) { grid := newHyperparameterGrid(ctx.hparams) - s.trials = len(grid) s.RemainingTrials = append(s.RemainingTrials, grid...) - initialTrials := s.trials + initialTrials := len(grid) if s.MaxConcurrentTrials() > 0 { - initialTrials = mathx.Min(s.trials, s.MaxConcurrentTrials()) + initialTrials = mathx.Min(initialTrials, s.MaxConcurrentTrials()) } - var ops []Operation + var actions []Action for trial := 0; trial < initialTrials; trial++ { params := s.RemainingTrials[len(s.RemainingTrials)-1] s.RemainingTrials = s.RemainingTrials[:len(s.RemainingTrials)-1] - create := NewCreate(ctx.rand, params, model.TrialWorkloadSequencerType) - ops = append(ops, create) - ops = append(ops, NewValidateAfter(create.RequestID, s.MaxLength().Units)) - ops = append(ops, NewClose(create.RequestID)) + create := NewCreate(ctx.rand, params) + actions = append(actions, create) s.PendingTrials++ } - return ops, nil + return actions, nil } func (s *gridSearch) progress( - trialProgress map[model.RequestID]PartialUnits, + trialProgress map[model.RequestID]float64, trialsClosed map[model.RequestID]bool, ) float64 { if s.MaxConcurrentTrials() > 0 && s.PendingTrials > s.MaxConcurrentTrials() { @@ -74,43 +71,38 @@ func (s *gridSearch) progress( // and are not replaced with a new config as with random search // - Other early-exit trials contribute max_length units // - In progress trials contribute units trained - unitsCompleted := 0. - // trialsClosed includes InvalidHP trials and other exited trials - for range trialsClosed { - unitsCompleted += float64(s.MaxLength().Units) - } - // trialProgress records units trained for all trials except for InvalidHP trials. - // This can overlap with trialsClosed so we need to be sure to not double count. + trialProgresses := 0. + for k, v := range trialProgress { - if !trialsClosed[k] { - unitsCompleted += float64(v) + if trialsClosed[k] { + trialProgresses += 1.0 + } else { + trialProgresses += v } } - unitsExpected := s.MaxLength().Units * uint64(s.trials) - return unitsCompleted / float64(unitsExpected) + + return trialProgresses / float64(len(trialProgress)) } // trialExitedEarly does nothing since grid does not take actions based on // search status or progress. func (s *gridSearch) trialExitedEarly( ctx context, requestID model.RequestID, exitedReason model.ExitedReason, -) ([]Operation, error) { +) ([]Action, error) { return nil, nil } -func (s *gridSearch) trialClosed(ctx context, _ model.RequestID) ([]Operation, error) { +func (s *gridSearch) trialExited(ctx context, _ model.RequestID) ([]Action, error) { s.PendingTrials-- - var ops []Operation + var actions []Action if len(s.RemainingTrials) > 0 { params := s.RemainingTrials[len(s.RemainingTrials)-1] s.RemainingTrials = s.RemainingTrials[:len(s.RemainingTrials)-1] - create := NewCreate(ctx.rand, params, model.TrialWorkloadSequencerType) - ops = append(ops, create) - ops = append(ops, NewValidateAfter(create.RequestID, s.MaxLength().Units)) - ops = append(ops, NewClose(create.RequestID)) + create := NewCreate(ctx.rand, params) + actions = append(actions, create) s.PendingTrials++ } - return ops, nil + return actions, nil } func newHyperparameterGrid(params expconf.Hyperparameters) []HParamSample { @@ -291,3 +283,7 @@ func (s *gridSearch) Restore(state json.RawMessage) error { } return json.Unmarshal(state, &s.gridSearchState) } + +func (s *gridSearch) Type() SearchMethodType { + return s.SearchMethodType +} diff --git a/master/pkg/searcher/grid_test.go b/master/pkg/searcher/grid_test.go index cffb52f314d..705124660fc 100644 --- a/master/pkg/searcher/grid_test.go +++ b/master/pkg/searcher/grid_test.go @@ -3,14 +3,16 @@ package searcher import ( "encoding/json" + "slices" "strconv" "testing" + "github.com/stretchr/testify/require" + "github.com/pkg/errors" "gotest.tools/assert" "github.com/determined-ai/determined/master/pkg/ptrs" - "github.com/determined-ai/determined/master/pkg/schemas" "github.com/determined-ai/determined/master/pkg/schemas/expconf" ) @@ -256,63 +258,36 @@ func TestGridIntCountNegative(t *testing.T) { assert.DeepEqual(t, actual, expected) } -func TestGridSearcherRecords(t *testing.T) { - actual := expconf.GridConfig{RawMaxLength: ptrs.Ptr(expconf.NewLengthInRecords(19200))} - actual = schemas.WithDefaults(actual) - params := generateHyperparameters([]int{2, 1, 3}) - expected := [][]ValidateAfter{ - toOps("19200R"), toOps("19200R"), toOps("19200R"), - toOps("19200R"), toOps("19200R"), toOps("19200R"), +func TestGridSearchMethod(t *testing.T) { + maxConcurrentTrials := 2 + gridConfig := expconf.GridConfig{ + RawMaxConcurrentTrials: ptrs.Ptr(maxConcurrentTrials), } - searchMethod := newGridSearch(actual) - checkSimulation(t, searchMethod, params, ConstantValidation, expected) -} - -func TestGridSearcherBatches(t *testing.T) { - actual := expconf.GridConfig{RawMaxLength: ptrs.Ptr(expconf.NewLengthInBatches(300))} - actual = schemas.WithDefaults(actual) - params := generateHyperparameters([]int{2, 1, 3}) - expected := [][]ValidateAfter{ - toOps("300B"), toOps("300B"), toOps("300B"), - toOps("300B"), toOps("300B"), toOps("300B"), + searcherConfig := expconf.SearcherConfig{ + RawGridConfig: &gridConfig, + RawMetric: ptrs.Ptr("loss"), } - searchMethod := newGridSearch(actual) - checkSimulation(t, searchMethod, params, ConstantValidation, expected) -} - -func TestGridSearcherEpochs(t *testing.T) { - actual := expconf.GridConfig{RawMaxLength: ptrs.Ptr(expconf.NewLengthInEpochs(3))} - actual = schemas.WithDefaults(actual) - params := generateHyperparameters([]int{2, 1, 3}) - expected := [][]ValidateAfter{ - toOps("3E"), toOps("3E"), toOps("3E"), - toOps("3E"), toOps("3E"), toOps("3E"), - } - searchMethod := newGridSearch(actual) - checkSimulation(t, searchMethod, params, ConstantValidation, expected) -} - -func TestGridSearchMethod(t *testing.T) { - testCases := []valueSimulationTestCase{ - { - name: "test grid search method", - expectedTrials: []predefinedTrial{ - newConstantPredefinedTrial(toOps("300B"), 0.1), - newConstantPredefinedTrial(toOps("300B"), 0.1), - newConstantPredefinedTrial(toOps("300B"), 0.1), - newConstantPredefinedTrial(toOps("300B"), 0.1), - newConstantPredefinedTrial(toOps("300B"), 0.1), - newEarlyExitPredefinedTrial(toOps("300B"), .1), - }, - config: expconf.SearcherConfig{ - RawGridConfig: &expconf.GridConfig{ - RawMaxLength: ptrs.Ptr(expconf.NewLengthInBatches(300)), - RawMaxConcurrentTrials: ptrs.Ptr(2), - }, + hparams := expconf.Hyperparameters{ + "a": expconf.Hyperparameter{ + RawIntHyperparameter: &expconf.IntHyperparameter{ + RawMinval: 0, RawMaxval: 3, RawCount: ptrs.Ptr(4), }, - hparams: generateHyperparameters([]int{2, 1, 3}), }, } + allHparams := []int{0, 1, 2, 3} + + testSearchRunner := NewTestSearchRunner(t, searcherConfig, hparams) - runValueSimulationTestCases(t, testCases) + // Simulate the search and check resulting trials. + testSearchRunner.run(100, 10, false) + + // 4 total trials for each hparam in space, all should run to completion. + var runHparams []int + require.Len(t, testSearchRunner.trials, len(allHparams)) + for _, tr := range testSearchRunner.trials { + require.False(t, tr.stopped) + runHparams = append(runHparams, tr.hparams["a"].(int)) + } + slices.Sort(runHparams) + require.Equal(t, allHparams, runHparams) } diff --git a/master/pkg/searcher/operations.go b/master/pkg/searcher/operations.go deleted file mode 100644 index 7fb91d8f960..00000000000 --- a/master/pkg/searcher/operations.go +++ /dev/null @@ -1,295 +0,0 @@ -package searcher - -import ( - "encoding/json" - "fmt" - - "github.com/google/uuid" - - "github.com/determined-ai/determined/proto/pkg/experimentv1" - - "github.com/determined-ai/determined/master/pkg/model" - "github.com/determined-ai/determined/master/pkg/nprand" -) - -// Operation represents the base interface for possible operations that a search method can return. -type Operation interface{} - -type ( - // OperationType encodes the underlying type of an Operation for serialization. - OperationType int - - // OperationWithType is an operation with a serializable repr of its underlying type. - OperationWithType struct { - OperationType - Operation - } - - // OperationList is []Operation that handles marshaling and unmarshaling heterogeneous - // operations to and from their correct underlying types. - OperationList []Operation -) - -// All the operation types that support serialization. -const ( - CreateOperation OperationType = 0 - TrainOperation OperationType = 1 - ValidateOperation OperationType = 2 - CloseOperation OperationType = 4 - ValidateAfterOperation OperationType = 5 - SetSearcherProgressOperation OperationType = 6 -) - -// MarshalJSON implements json.Marshaler. -func (l OperationList) MarshalJSON() ([]byte, error) { - var typedOps []OperationWithType - for _, op := range l { - typedOp := OperationWithType{Operation: op} - switch op.(type) { - case Create: - typedOp.OperationType = CreateOperation - case ValidateAfter: - typedOp.OperationType = ValidateAfterOperation - case Close: - typedOp.OperationType = CloseOperation - case SetSearcherProgress: - typedOp.OperationType = SetSearcherProgressOperation - default: - return nil, fmt.Errorf("unable to serialize %T as operation", op) - } - typedOps = append(typedOps, typedOp) - } - return json.Marshal(typedOps) -} - -// UnmarshalJSON implements json.Unmarshaler. -func (l *OperationList) UnmarshalJSON(b []byte) error { - var typedOps []OperationWithType - if err := json.Unmarshal(b, &typedOps); err != nil { - return err - } - var ops OperationList - for _, typedOp := range typedOps { - b, err := json.Marshal(typedOp.Operation) - if err != nil { - return err - } - switch typedOp.OperationType { - case CreateOperation: - var op Create - if err := json.Unmarshal(b, &op); err != nil { - return err - } - ops = append(ops, op) - case ValidateAfterOperation: - var op ValidateAfter - if err := json.Unmarshal(b, &op); err != nil { - return err - } - ops = append(ops, op) - case CloseOperation: - var op Close - if err := json.Unmarshal(b, &op); err != nil { - return err - } - ops = append(ops, op) - default: - return fmt.Errorf("unable to deserialize %d as operation", typedOp.OperationType) - } - } - *l = ops - return nil -} - -// Requested is a convenience interface for operations that were requested by a searcher method -// for a specific trial. -type Requested interface { - GetRequestID() model.RequestID -} - -// Create a new trial for the search method. -type Create struct { - RequestID model.RequestID `json:"request_id"` - // TrialSeed must be a value between 0 and 2**31 - 1. - TrialSeed uint32 `json:"trial_seed"` - Hparams HParamSample `json:"hparams"` - Checkpoint *Checkpoint `json:"checkpoint"` - WorkloadSequencerType model.WorkloadSequencerType `json:"workload_sequencer_type"` -} - -// NewCreate initializes a new Create operation with a new request ID and the given hyperparameters. -func NewCreate( - rand *nprand.State, s HParamSample, sequencerType model.WorkloadSequencerType, -) Create { - return Create{ - RequestID: model.NewRequestID(rand), - TrialSeed: uint32(rand.Int64n(1 << 31)), - Hparams: s, - WorkloadSequencerType: sequencerType, - } -} - -// NewCreateFromCheckpoint initializes a new Create operation with a new request ID and the given -// hyperparameters and checkpoint to initially load from. -func NewCreateFromCheckpoint( - rand *nprand.State, s HParamSample, parentID model.RequestID, - sequencerType model.WorkloadSequencerType, -) Create { - create := NewCreate(rand, s, sequencerType) - create.Checkpoint = &Checkpoint{parentID} - return create -} - -// CreateFromProto initializes a new Create operation from an -// experimentv1.SearcherOperation_CreateTrial. -func CreateFromProto( - protoSearcherOp *experimentv1.SearcherOperation_CreateTrial, - sequencerType model.WorkloadSequencerType, -) (*Create, error) { - requestID, err := uuid.Parse(protoSearcherOp.CreateTrial.RequestId) - if err != nil { - return nil, fmt.Errorf("unparseable trial ID %s", protoSearcherOp.CreateTrial.RequestId) - } - // TODO: Determine whether trial seed is set on client or on master. - trialSeed := uint32(42) - var hparams HParamSample - if err = json.Unmarshal([]byte(protoSearcherOp.CreateTrial.Hyperparams), &hparams); err != nil { - // TODO: Should we return this err instead? - return nil, fmt.Errorf("unparseable hyperparams %s", protoSearcherOp.CreateTrial.Hyperparams) - } - return &Create{ - RequestID: model.RequestID(requestID), - TrialSeed: trialSeed, - Hparams: hparams, - WorkloadSequencerType: sequencerType, - }, nil -} - -func (create Create) String() string { - if create.Checkpoint == nil { - return fmt.Sprintf("{Create %s, seed %d}", create.RequestID, create.TrialSeed) - } - return fmt.Sprintf( - "{Create %s, seed %d, parent %v}", create.RequestID, create.TrialSeed, - create.Checkpoint.RequestID, - ) -} - -// GetRequestID implemented Requested. -func (create Create) GetRequestID() model.RequestID { return create.RequestID } - -// Checkpoint indicates which trial the trial created by a Create should inherit from. -type Checkpoint struct { - RequestID model.RequestID -} - -func (c Checkpoint) String() string { - return fmt.Sprintf("{Checkpoint %s}", c.RequestID) -} - -// ValidateAfter is an operation emitted by search methods to signal the trial train until -// its total batches trained equals the specified length. -type ValidateAfter struct { - RequestID model.RequestID - Length uint64 -} - -// NewValidateAfter returns a new train operation. -func NewValidateAfter(requestID model.RequestID, length uint64) ValidateAfter { - return ValidateAfter{requestID, length} -} - -// ValidateAfterFromProto creates a ValidateAfter operation from its protobuf representation. -func ValidateAfterFromProto( - op *experimentv1.TrialOperation_ValidateAfter, -) (*ValidateAfter, error) { - requestID, err := uuid.Parse(op.ValidateAfter.RequestId) - if err != nil { - return nil, fmt.Errorf("unparseable trial ID %s", op.ValidateAfter.RequestId) - } - return &ValidateAfter{ - RequestID: model.RequestID(requestID), - Length: op.ValidateAfter.Length, - }, nil -} - -func (t ValidateAfter) String() string { - return fmt.Sprintf("{ValidateAfter %s, %v}", t.RequestID, t.Length) -} - -// GetRequestID implemented Requested. -func (t ValidateAfter) GetRequestID() model.RequestID { return t.RequestID } - -// ToProto converts a searcher.ValidateAfter to its protobuf representation. -func (t ValidateAfter) ToProto() *experimentv1.ValidateAfterOperation { - return &experimentv1.ValidateAfterOperation{Length: t.Length} -} - -// SetSearcherProgress sets the progress of the custom searcher. -type SetSearcherProgress struct { - Progress float64 -} - -// SetSearcherProgressFromProto creates a SetSearcherProgress from its protobuf representation. -func SetSearcherProgressFromProto( - op *experimentv1.SearcherOperation_SetSearcherProgress, -) SetSearcherProgress { - return SetSearcherProgress{Progress: op.SetSearcherProgress.Progress} -} - -// Close the trial with the given trial ID. -type Close struct { - RequestID model.RequestID `json:"request_id"` -} - -// NewClose initializes a new Close operation for the request ID. -func NewClose(requestID model.RequestID) Close { - return Close{ - RequestID: requestID, - } -} - -// CloseFromProto returns a Close operation from its protobuf representation. -func CloseFromProto( - op *experimentv1.SearcherOperation_CloseTrial, -) (*Close, error) { - requestID, err := uuid.Parse(op.CloseTrial.RequestId) - if err != nil { - return nil, fmt.Errorf("unparseable trial ID %s", op.CloseTrial.RequestId) - } - return &Close{ - RequestID: model.RequestID(requestID), - }, nil -} - -func (close Close) String() string { - return fmt.Sprintf("{Close %s}", close.RequestID) -} - -// GetRequestID implemented Requested. -func (close Close) GetRequestID() model.RequestID { return close.RequestID } - -// Shutdown marks the searcher as completed. -type Shutdown struct { - Cancel bool - Failure bool -} - -// NewShutdown initializes a Shutdown operation for the searcher. -func NewShutdown() Shutdown { - return Shutdown{} -} - -// ShutdownFromProto creates a Shutdown from its protobuf representation. -func ShutdownFromProto( - op *experimentv1.SearcherOperation_ShutDown, -) (*Shutdown, error) { - return &Shutdown{ - Cancel: op.ShutDown.Cancel, - Failure: op.ShutDown.Failure, - }, nil -} - -func (shutdown Shutdown) String() string { - return fmt.Sprintf("{Shutdown Cancel: %v Failure: %v}", shutdown.Cancel, shutdown.Failure) -} diff --git a/master/pkg/searcher/random.go b/master/pkg/searcher/random.go index 3d6a065ea3a..f626c98e0bb 100644 --- a/master/pkg/searcher/random.go +++ b/master/pkg/searcher/random.go @@ -14,7 +14,7 @@ type ( // randomSearchState stores the state for random. Since not all trials are always created at // initialization, we need to track CreatedTrials so we know whether we need to create more // trials when workloads complete so that we reach MaxTrials. PendingTrials tracks active - // workloads and is used to check max_concurrent_trials for the searcher is respected. + // trials and is used to check max_concurrent_trials for the searcher is respected. // Tracking searcher type on restart gives us the ability to differentiate random searches // in a shim if needed. randomSearchState struct { @@ -44,7 +44,6 @@ func newSingleSearch(config expconf.SingleConfig) SearchMethod { return &randomSearch{ RandomConfig: schemas.WithDefaults(expconf.RandomConfig{ RawMaxTrials: ptrs.Ptr(1), - RawMaxLength: ptrs.Ptr(config.MaxLength()), RawMaxConcurrentTrials: ptrs.Ptr(1), }), randomSearchState: randomSearchState{ @@ -53,25 +52,23 @@ func newSingleSearch(config expconf.SingleConfig) SearchMethod { } } -func (s *randomSearch) initialOperations(ctx context) ([]Operation, error) { - var ops []Operation +func (s *randomSearch) initialTrials(ctx context) ([]Action, error) { + var actions []Action initialTrials := s.MaxTrials() if s.MaxConcurrentTrials() > 0 { initialTrials = mathx.Min(s.MaxTrials(), s.MaxConcurrentTrials()) } for trial := 0; trial < initialTrials; trial++ { - create := NewCreate(ctx.rand, sampleAll(ctx.hparams, ctx.rand), model.TrialWorkloadSequencerType) - ops = append(ops, create) - ops = append(ops, NewValidateAfter(create.RequestID, s.MaxLength().Units)) - ops = append(ops, NewClose(create.RequestID)) + create := NewCreate(ctx.rand, sampleAll(ctx.hparams, ctx.rand)) + actions = append(actions, create) s.CreatedTrials++ s.PendingTrials++ } - return ops, nil + return actions, nil } func (s *randomSearch) progress( - trialProgress map[model.RequestID]PartialUnits, + trialProgress map[model.RequestID]float64, trialsClosed map[model.RequestID]bool, ) float64 { if s.MaxConcurrentTrials() > 0 && s.PendingTrials > s.MaxConcurrentTrials() { @@ -82,24 +79,25 @@ func (s *randomSearch) progress( // replaced with another randomly sampled config // - Other early-exit trials contribute max_length units // - In progress trials contribute units trained - unitsCompleted := 0. - // trialProgress records units trained for all trials except for InvalidHP trials. + // trialsProgress records units trained for all trials except for InvalidHP trials. + trialProgresses := 0. + for k, v := range trialProgress { if trialsClosed[k] { - unitsCompleted += float64(s.MaxLength().Units) + trialProgresses += 1.0 } else { - unitsCompleted += float64(v) + trialProgresses += v } } - unitsExpected := s.MaxLength().Units * uint64(s.MaxTrials()) - return unitsCompleted / float64(unitsExpected) + + return trialProgresses / float64(len(trialProgress)) } // trialExitedEarly creates a new trial upon receiving an InvalidHP workload. // Otherwise, it does nothing since actions are not taken based on search status. func (s *randomSearch) trialExitedEarly( ctx context, requestID model.RequestID, exitedReason model.ExitedReason, -) ([]Operation, error) { +) ([]Action, error) { s.PendingTrials-- if s.SearchMethodType == RandomSearch { if exitedReason == model.InvalidHP || exitedReason == model.InitInvalidHP { @@ -112,18 +110,16 @@ func (s *randomSearch) trialExitedEarly( return nil, nil } -func (s *randomSearch) trialClosed(ctx context, requestID model.RequestID) ([]Operation, error) { +func (s *randomSearch) trialExited(ctx context, requestID model.RequestID) ([]Action, error) { s.PendingTrials-- - var ops []Operation + var actions []Action if s.CreatedTrials < s.MaxTrials() { - create := NewCreate(ctx.rand, sampleAll(ctx.hparams, ctx.rand), model.TrialWorkloadSequencerType) - ops = append(ops, create) - ops = append(ops, NewValidateAfter(create.RequestID, s.MaxLength().Units)) - ops = append(ops, NewClose(create.RequestID)) + create := NewCreate(ctx.rand, sampleAll(ctx.hparams, ctx.rand)) + actions = append(actions, create) s.CreatedTrials++ s.PendingTrials++ } - return ops, nil + return actions, nil } func (s *randomSearch) Snapshot() (json.RawMessage, error) { @@ -136,3 +132,7 @@ func (s *randomSearch) Restore(state json.RawMessage) error { } return json.Unmarshal(state, &s.randomSearchState) } + +func (s *randomSearch) Type() SearchMethodType { + return s.SearchMethodType +} diff --git a/master/pkg/searcher/random_test.go b/master/pkg/searcher/random_test.go index 2fdf6fb199e..caac50c158a 100644 --- a/master/pkg/searcher/random_test.go +++ b/master/pkg/searcher/random_test.go @@ -4,117 +4,49 @@ package searcher import ( "testing" + "github.com/stretchr/testify/require" + "github.com/determined-ai/determined/master/pkg/ptrs" - "github.com/determined-ai/determined/master/pkg/schemas" "github.com/determined-ai/determined/master/pkg/schemas/expconf" ) -func TestRandomSearcherRecords(t *testing.T) { - actual := expconf.RandomConfig{ - RawMaxTrials: ptrs.Ptr(4), RawMaxLength: ptrs.Ptr(expconf.NewLengthInRecords(19200)), - } - actual = schemas.WithDefaults(actual) - expected := [][]ValidateAfter{ - toOps("19200R"), - toOps("19200R"), - toOps("19200R"), - toOps("19200R"), - } - search := newRandomSearch(actual) - checkSimulation(t, search, nil, ConstantValidation, expected) -} - -func TestRandomSearcherBatches(t *testing.T) { - actual := expconf.RandomConfig{ - RawMaxTrials: ptrs.Ptr(4), RawMaxLength: ptrs.Ptr(expconf.NewLengthInBatches(300)), - } - actual = schemas.WithDefaults(actual) - expected := [][]ValidateAfter{ - toOps("300B"), - toOps("300B"), - toOps("300B"), - toOps("300B"), - } - search := newRandomSearch(actual) - checkSimulation(t, search, nil, ConstantValidation, expected) -} - -func TestRandomSearcherReproducibility(t *testing.T) { - conf := expconf.RandomConfig{ - RawMaxTrials: ptrs.Ptr(4), RawMaxLength: ptrs.Ptr(expconf.NewLengthInBatches(300)), - } - conf = schemas.WithDefaults(conf) - gen := func() SearchMethod { return newRandomSearch(conf) } - checkReproducibility(t, gen, nil, defaultMetric) -} - func TestRandomSearchMethod(t *testing.T) { - testCases := []valueSimulationTestCase{ - { - name: "test random search method", - expectedTrials: []predefinedTrial{ - newConstantPredefinedTrial(toOps("500B"), .1), - newConstantPredefinedTrial(toOps("500B"), .1), - newConstantPredefinedTrial(toOps("500B"), .1), - newEarlyExitPredefinedTrial(toOps("500B"), .1), - }, - config: expconf.SearcherConfig{ - RawRandomConfig: &expconf.RandomConfig{ - RawMaxLength: ptrs.Ptr(expconf.NewLengthInBatches(500)), - RawMaxTrials: ptrs.Ptr(4), - RawMaxConcurrentTrials: ptrs.Ptr(2), - }, - }, - }, - { - name: "test random search method with records", - expectedTrials: []predefinedTrial{ - newConstantPredefinedTrial(toOps("32017R"), .1), - newConstantPredefinedTrial(toOps("32017R"), .1), - newConstantPredefinedTrial(toOps("32017R"), .1), - newConstantPredefinedTrial(toOps("32017R"), .1), - }, - config: expconf.SearcherConfig{ - RawRandomConfig: &expconf.RandomConfig{ - RawMaxLength: ptrs.Ptr(expconf.NewLengthInRecords(32017)), - RawMaxTrials: ptrs.Ptr(4), - }, - }, + conf := expconf.SearcherConfig{ + RawMetric: ptrs.Ptr("loss"), + RawRandomConfig: &expconf.RandomConfig{ + RawMaxTrials: ptrs.Ptr(4), + RawMaxConcurrentTrials: ptrs.Ptr(2), }, } + intHparam := &expconf.IntHyperparameter{RawMaxval: 10, RawCount: ptrs.Ptr(4)} + hparams := expconf.Hyperparameters{ + "x": expconf.Hyperparameter{RawIntHyperparameter: intHparam}, + } + testSearchRunner := NewTestSearchRunner(t, conf, hparams) - runValueSimulationTestCases(t, testCases) + // Simulate a search and verify expected run states. + testSearchRunner.run(100, 10, false) + // 4 total trials created, each with hparam in space and run to completion. + require.Len(t, testSearchRunner.trials, 4) + for _, tr := range testSearchRunner.trials { + hparam := tr.hparams["x"].(int) + require.True(t, hparam <= 10 && hparam >= 0) + require.False(t, tr.stopped) + } } func TestSingleSearchMethod(t *testing.T) { - testCases := []valueSimulationTestCase{ - { - name: "test single search method", - expectedTrials: []predefinedTrial{ - newConstantPredefinedTrial(toOps("500B"), .1), - }, - config: expconf.SearcherConfig{ - RawSingleConfig: &expconf.SingleConfig{ - RawMaxLength: ptrs.Ptr(expconf.NewLengthInBatches(500)), - }, - }, - }, + conf := expconf.SearcherConfig{ + RawMetric: ptrs.Ptr("loss"), + RawSingleConfig: &expconf.SingleConfig{}, } - runValueSimulationTestCases(t, testCases) -} + testSearchRunner := NewTestSearchRunner(t, conf, expconf.Hyperparameters{}) -func TestRandomSearcherSingleConcurrent(t *testing.T) { - actual := expconf.RandomConfig{ - RawMaxTrials: ptrs.Ptr(2), - RawMaxLength: ptrs.Ptr(expconf.NewLengthInRecords(100)), - RawMaxConcurrentTrials: ptrs.Ptr(1), - } - actual = schemas.WithDefaults(actual) - expected := [][]ValidateAfter{ - toOps("100R"), - toOps("100R"), - } - search := newRandomSearch(actual) - checkSimulation(t, search, nil, ConstantValidation, expected) + // Simulate a search and verify expected run states. + testSearchRunner.run(100, 10, false) + + // Single search should create exactly one run. + require.Len(t, testSearchRunner.trials, 1) + require.False(t, testSearchRunner.trials[0].stopped) } diff --git a/master/pkg/searcher/search_method.go b/master/pkg/searcher/search_method.go index 92ba1575691..fd03096e22c 100644 --- a/master/pkg/searcher/search_method.go +++ b/master/pkg/searcher/search_method.go @@ -11,41 +11,32 @@ type context struct { hparams expconf.Hyperparameters } -// SearchMethod is the interface for hyper-parameter tuning methods. Implementations of this +// SearchMethod is the interface for hyperparameter tuning methods. Implementations of this // interface should use pointer receivers to ensure interface equality is calculated through pointer // equality. type SearchMethod interface { - // initialOperations returns a set of initial operations that the searcher would like to take. + // initialTrials returns a set of initial trials the searcher would like to create. // This should be called only once after the searcher has been created. - initialOperations(ctx context) ([]Operation, error) + initialTrials(ctx context) ([]Action, error) // trialCreated informs the searcher that a trial has been created as a result of a Create - // operation. - trialCreated(ctx context, requestID model.RequestID) ([]Operation, error) - // validationCompleted informs the searcher that the validation workload initiated by the same - // searcher has completed. It returns any new operations as a result of this workload - // completing. + // action and returns any additional Actions to perform. + trialCreated(ctx context, requestID model.RequestID) ([]Action, error) + // validationCompleted informs the searcher that a validation metric has been reported + // and returns any resulting actions. validationCompleted(ctx context, requestID model.RequestID, - metric interface{}, op ValidateAfter) ([]Operation, error) - // trialClosed informs the searcher that the trial has been closed as a result of a Close - // operation. - trialClosed(ctx context, requestID model.RequestID) ([]Operation, error) - // progress returns experiment progress as a float between 0.0 and 1.0. - progress(map[model.RequestID]PartialUnits, map[model.RequestID]bool) float64 + metrics map[string]interface{}) ([]Action, error) + // trialExited informs the searcher that the trial has exited. + trialExited(ctx context, requestID model.RequestID) ([]Action, error) + // progress returns search progress as a float between 0.0 and 1.0. + progress(map[model.RequestID]float64, map[model.RequestID]bool) float64 // trialExitedEarly informs the searcher that the trial has exited earlier than expected. trialExitedEarly( ctx context, requestID model.RequestID, exitedReason model.ExitedReason, - ) ([]Operation, error) + ) ([]Action, error) // TODO: refactor as model.Snapshotter interface or something model.Snapshotter - expconf.InUnits -} - -// CustomSearchMethod is the interface for the custom search method. -type CustomSearchMethod interface { - getSearcherEventQueue() *SearcherEventQueue - setCustomSearcherProgress(progress float64) - trialProgress(ctx context, requestID model.RequestID, progress PartialUnits) + Type() SearchMethodType } // SearchMethodType is the type of a SearchMethod. It is saved in snapshots to be used @@ -59,14 +50,10 @@ const ( RandomSearch SearchMethodType = "random" // GridSearch is the SearchMethodType for a grid searcher. GridSearch SearchMethodType = "grid" - // AdaptiveSearch is the SearchMethodType for an adaptive searcher. - AdaptiveSearch SearchMethodType = "adaptive" // ASHASearch is the SearchMethodType for an ASHA searcher. ASHASearch SearchMethodType = "asha" // AdaptiveASHASearch is the SearchMethodType for an adaptive ASHA searcher. AdaptiveASHASearch SearchMethodType = "adaptive_asha" - // CustomSearch is the SearchMethodType for a custom searcher. - CustomSearch SearchMethodType = "custom_search" ) // NewSearchMethod returns a new search method for the provided searcher configuration. @@ -79,14 +66,9 @@ func NewSearchMethod(c expconf.SearcherConfig) SearchMethod { case c.RawGridConfig != nil: return newGridSearch(*c.RawGridConfig) case c.RawAsyncHalvingConfig != nil: - if c.RawAsyncHalvingConfig.StopOnce() { - return newAsyncHalvingStoppingSearch(*c.RawAsyncHalvingConfig, c.SmallerIsBetter()) - } - return newAsyncHalvingSearch(*c.RawAsyncHalvingConfig, c.SmallerIsBetter()) + return newAsyncHalvingStoppingSearch(*c.RawAsyncHalvingConfig, c.SmallerIsBetter(), c.Metric()) case c.RawAdaptiveASHAConfig != nil: - return newAdaptiveASHASearch(*c.RawAdaptiveASHAConfig, c.SmallerIsBetter()) - case c.RawCustomConfig != nil: - return newCustomSearch(*c.RawCustomConfig) + return newAdaptiveASHASearch(*c.RawAdaptiveASHAConfig, c.SmallerIsBetter(), c.Metric()) default: panic("no searcher type specified") } @@ -94,24 +76,22 @@ func NewSearchMethod(c expconf.SearcherConfig) SearchMethod { type defaultSearchMethod struct{} -func (defaultSearchMethod) trialCreated(context, model.RequestID) ([]Operation, error) { +func (defaultSearchMethod) trialCreated(context, model.RequestID) ([]Action, error) { return nil, nil } -func (defaultSearchMethod) validationCompleted( - context, model.RequestID, interface{}, ValidateAfter, -) ([]Operation, error) { +func (defaultSearchMethod) validationCompleted(context, model.RequestID, map[string]interface{}) ([]Action, error) { return nil, nil } // nolint:unused -func (defaultSearchMethod) trialClosed(context, model.RequestID) ([]Operation, error) { +func (defaultSearchMethod) trialExited(context, model.RequestID) ([]Action, error) { return nil, nil } // nolint:unused func (defaultSearchMethod) trialExitedEarly( context, model.RequestID, model.ExitedReason, -) ([]Operation, error) { - return []Operation{Shutdown{Failure: true}}, nil +) ([]Action, error) { + return []Action{Shutdown{Failure: true}}, nil } diff --git a/master/pkg/searcher/searcher.go b/master/pkg/searcher/searcher.go index 698e6c3a3f9..8b282b369f7 100644 --- a/master/pkg/searcher/searcher.go +++ b/master/pkg/searcher/searcher.go @@ -2,7 +2,6 @@ package searcher import ( "encoding/json" - "fmt" "math" "sync" @@ -19,15 +18,13 @@ type PartialUnits float64 type ( // SearcherState encapsulates all persisted searcher state. SearcherState struct { - TrialsRequested int `json:"trials_requested"` - TrialsCreated map[model.RequestID]bool `json:"trials_created"` - TrialsClosed map[model.RequestID]bool `json:"trials_closed"` - Exits map[model.RequestID]bool `json:"exits"` - Cancels map[model.RequestID]bool `json:"cancels"` - Failures map[model.RequestID]bool `json:"failures"` - TrialProgress map[model.RequestID]PartialUnits `json:"trial_progress"` - Shutdown bool `json:"shutdown"` - CompletedOperations map[string]ValidateAfter `json:"completed_operations"` + TrialsRequested int `json:"trials_requested"` + TrialsCreated map[model.RequestID]bool `json:"trials_created"` + TrialsClosed map[model.RequestID]bool `json:"trials_closed"` + Exits map[model.RequestID]bool `json:"exits"` + Cancels map[model.RequestID]bool `json:"cancels"` + Failures map[model.RequestID]bool `json:"failures"` + TrialProgress map[model.RequestID]float64 `json:"trial_progress"` Rand *nprand.State `json:"rand"` @@ -50,43 +47,37 @@ func NewSearcher(seed uint32, method SearchMethod, hparams expconf.Hyperparamete hparams: hparams, method: method, state: SearcherState{ - Rand: nprand.New(seed), - TrialsCreated: map[model.RequestID]bool{}, - TrialsClosed: map[model.RequestID]bool{}, - Exits: map[model.RequestID]bool{}, - Cancels: map[model.RequestID]bool{}, - Failures: map[model.RequestID]bool{}, - TrialProgress: map[model.RequestID]PartialUnits{}, - CompletedOperations: map[string]ValidateAfter{}, + Rand: nprand.New(seed), + TrialsCreated: map[model.RequestID]bool{}, + TrialsClosed: map[model.RequestID]bool{}, + Exits: map[model.RequestID]bool{}, + Cancels: map[model.RequestID]bool{}, + Failures: map[model.RequestID]bool{}, + TrialProgress: map[model.RequestID]float64{}, }, } } -func unsupportedMethodError(method SearchMethod, unsupportedOp string) error { - return fmt.Errorf("%T search method does not support %s", method, unsupportedOp) -} - func (s *Searcher) context() context { return context{rand: s.state.Rand, hparams: s.hparams} } -// InitialOperations return a set of initial operations that the searcher would like to take. +// InitialTrials returns the initial trials the searcher intends to create at the start of a search. // This should be called only once after the searcher has been created. -func (s *Searcher) InitialOperations() ([]Operation, error) { +func (s *Searcher) InitialTrials() ([]Action, error) { s.mu.Lock() defer s.mu.Unlock() - operations, err := s.method.initialOperations(s.context()) + creates, err := s.method.initialTrials(s.context()) if err != nil { return nil, errors.Wrap(err, "error while fetching initial operations of search method") } - s.record(operations) - return operations, nil + s.record(creates) + return creates, nil } -// TrialCreated informs the searcher that a trial has been created as a result of a Create -// operation. -func (s *Searcher) TrialCreated(requestID model.RequestID) ([]Operation, error) { +// TrialCreated informs the searcher that a new trial has been created. +func (s *Searcher) TrialCreated(requestID model.RequestID) ([]Action, error) { s.mu.Lock() defer s.mu.Unlock() @@ -95,7 +86,7 @@ func (s *Searcher) TrialCreated(requestID model.RequestID) ([]Operation, error) operations, err := s.method.trialCreated(s.context(), requestID) if err != nil { return nil, errors.Wrapf(err, - "error while handling a trial created event: %s", requestID) + "error while handling a trial created event: %d", requestID) } s.record(operations) return operations, nil @@ -109,10 +100,10 @@ func (s *Searcher) TrialIsCreated(requestID model.RequestID) bool { return s.state.TrialsCreated[requestID] } -// TrialExitedEarly indicates to the searcher that the trial with the given trialID exited early. +// TrialExitedEarly informs the searcher that a trial has exited early. func (s *Searcher) TrialExitedEarly( requestID model.RequestID, exitedReason model.ExitedReason, -) ([]Operation, error) { +) ([]Action, error) { s.mu.Lock() defer s.mu.Unlock() @@ -139,11 +130,9 @@ func (s *Searcher) TrialExitedEarly( s.state.Exits[requestID] = true s.record(operations) - _, isCustom := s.method.(*customSearch) - // For non-custom-search methods, you can assume that trials will be created immediately. - if s.state.TrialsRequested == len(s.state.TrialsClosed) && !isCustom { + if s.state.TrialsRequested == len(s.state.TrialsClosed) { shutdown := Shutdown{Failure: len(s.state.Failures) >= s.state.TrialsRequested} - s.record([]Operation{shutdown}) + s.record([]Action{shutdown}) operations = append(operations, shutdown) } @@ -151,63 +140,53 @@ func (s *Searcher) TrialExitedEarly( } // SetTrialProgress informs the searcher of the progress of a given trial. -func (s *Searcher) SetTrialProgress(requestID model.RequestID, progress PartialUnits) { +func (s *Searcher) SetTrialProgress(requestID model.RequestID, progress float64) { s.mu.Lock() defer s.mu.Unlock() - if sMethod, ok := s.method.(*customSearch); ok { - sMethod.trialProgress(s.context(), requestID, progress) - } s.state.TrialProgress[requestID] = progress } // ValidationCompleted informs the searcher that a validation for the trial was completed. func (s *Searcher) ValidationCompleted( - requestID model.RequestID, metric interface{}, op ValidateAfter, -) ([]Operation, error) { + requestID model.RequestID, metrics map[string]interface{}, +) ([]Action, error) { s.mu.Lock() defer s.mu.Unlock() - if _, ok := s.state.CompletedOperations[op.String()]; ok { - return nil, fmt.Errorf("operation %v was already completed", op) - } - - operations, err := s.method.validationCompleted(s.context(), requestID, metric, op) + operations, err := s.method.validationCompleted(s.context(), requestID, metrics) if err != nil { - return nil, errors.Wrapf(err, "error while handling a workload completed event: %s", requestID) + return nil, errors.Wrapf(err, "error while handling a validation completed event: %d", requestID) } - s.state.CompletedOperations[op.String()] = op s.record(operations) return operations, nil } -// TrialClosed informs the searcher that the trial has been closed as a result of a Close operation. -func (s *Searcher) TrialClosed(requestID model.RequestID) ([]Operation, error) { +// TrialExited informs the searcher that a trial has exited. +func (s *Searcher) TrialExited(requestID model.RequestID) ([]Action, error) { s.mu.Lock() defer s.mu.Unlock() s.state.TrialsClosed[requestID] = true - operations, err := s.method.trialClosed(s.context(), requestID) + actions, err := s.method.trialExited(s.context(), requestID) if err != nil { - return nil, errors.Wrapf(err, "error while handling a trial closed event: %s", requestID) + return nil, errors.Wrapf(err, "error while handling a trial closed event: %d", requestID) } - s.record(operations) + s.record(actions) - _, isCustom := s.method.(*customSearch) - // For non-custom-search methods, you can assume that trials will be created immediately. - if s.state.TrialsRequested == len(s.state.TrialsClosed) && !isCustom { + if s.state.TrialsRequested == len(s.state.TrialsClosed) { shutdown := Shutdown{ Cancel: len(s.state.Cancels) >= s.state.TrialsRequested, Failure: len(s.state.Failures) >= s.state.TrialsRequested, } - s.record([]Operation{shutdown}) - operations = append(operations, shutdown) + s.record([]Action{shutdown}) + actions = append(actions, shutdown) } - return operations, nil + return actions, nil } -// TrialIsClosed returns true if the close has been recorded with a TrialClosed call. +// TrialIsClosed returns true if the close has been recorded with a TrialExited call. func (s *Searcher) TrialIsClosed(requestID model.RequestID) bool { s.mu.Lock() defer s.mu.Unlock() @@ -227,45 +206,18 @@ func (s *Searcher) Progress() float64 { return progress } -// GetCustomSearcherEventQueue returns the searcher's custom searcher event queue. It returns an -// error if the search method is not a custom searcher. -func (s *Searcher) GetCustomSearcherEventQueue() (*SearcherEventQueue, error) { - s.mu.Lock() - defer s.mu.Unlock() - - if sMethod, ok := s.method.(*customSearch); ok { - return sMethod.getSearcherEventQueue(), nil - } - return nil, unsupportedMethodError(s.method, "GetCustomSearcherEventQueue") -} - -// SetCustomSearcherProgress sets the custom searcher progress. -func (s *Searcher) SetCustomSearcherProgress(progress float64) error { - s.mu.Lock() - defer s.mu.Unlock() - - if sMethod, ok := s.method.(*customSearch); ok { - sMethod.setCustomSearcherProgress(progress) - return nil - } - return unsupportedMethodError(s.method, "SetCustomSearcherProgress") -} - -// Record records operations that were requested by the searcher for a specific trial. -func (s *Searcher) Record(ops []Operation) { +// Record records actions that were requested by the searcher for a specific trial. +func (s *Searcher) Record(ops []Action) { s.mu.Lock() defer s.mu.Unlock() s.record(ops) } -func (s *Searcher) record(ops []Operation) { +func (s *Searcher) record(ops []Action) { for _, op := range ops { - switch op.(type) { - case Create: + if _, ok := op.(Create); ok { s.state.TrialsRequested++ - case Shutdown: - s.state.Shutdown = true } } } diff --git a/master/pkg/searcher/simulate.go b/master/pkg/searcher/simulate.go index 0d89dc110c4..f64d3ba4cb4 100644 --- a/master/pkg/searcher/simulate.go +++ b/master/pkg/searcher/simulate.go @@ -1,222 +1,121 @@ package searcher import ( - "encoding/json" - "math/rand" "sort" - "strconv" - "strings" - "time" + + "github.com/determined-ai/determined/master/pkg/ptrs" "github.com/pkg/errors" - "github.com/determined-ai/determined/master/pkg/model" + "github.com/determined-ai/determined/master/pkg/mathx" + "github.com/determined-ai/determined/master/pkg/protoutils" + "github.com/determined-ai/determined/master/pkg/schemas/expconf" + "github.com/determined-ai/determined/proto/pkg/experimentv1" ) -// ValidationFunction calculates the validation metric for the validation step. -type ValidationFunction func(random *rand.Rand, trialID, idx int) float64 - -// ConstantValidation returns the same validation metric for all validation steps. -func ConstantValidation(_ *rand.Rand, _, _ int) float64 { return 1 } - -// RandomValidation returns a random validation metric for each validation step. -func RandomValidation(rand *rand.Rand, _, _ int) float64 { return rand.Float64() } - -// TrialIDMetric returns the trialID as the metric for all validation steps. -func TrialIDMetric(_ *rand.Rand, trialID, _ int) float64 { - return float64(trialID) +// SearchSummary describes a summary of planned trials and the associated expconf.SearcherConfig. +type SearchSummary struct { + Trials []TrialSummary + Config expconf.SearcherConfig } -// SimulationResults holds all created trials and all executed workloads for each trial. -type SimulationResults map[model.RequestID][]ValidateAfter - -// MarshalJSON implements the json.Marshaler interface. -func (s SimulationResults) MarshalJSON() ([]byte, error) { - summary := make(map[string]int) +// SearchUnit is a length unit. If MaxLength is true, Name and Value will be ignored. +type SearchUnit struct { + Name *string + Value *int32 + MaxLength bool +} - for _, ops := range s { - var keyParts []string - for _, op := range ops { - keyParts = append(keyParts, strconv.FormatUint(op.Length, 10)) - } - summary[strings.Join(keyParts, " ")]++ +// Proto converts the SearchUnit to its protobuf representation. +func (su SearchUnit) Proto() *experimentv1.SearchUnit { + return &experimentv1.SearchUnit{ + Name: su.Name, + Value: su.Value, + MaxLength: su.MaxLength, } - - return json.Marshal(summary) } -// Simulation holds the configuration and results of simulated run of a searcher. -type Simulation struct { - Results SimulationResults `json:"results"` - Seed int64 `json:"seed"` +// TrialSummary is a summary of the number of trials that will train for Unit length. +type TrialSummary struct { + Count int + Unit SearchUnit } -// Simulate simulates the searcher. -func Simulate( - s *Searcher, seed *int64, valFunc ValidationFunction, randomOrder bool, metricName string, -) (Simulation, error) { - simulation := Simulation{ - Results: make(SimulationResults), - Seed: time.Now().Unix(), - } - //nolint:gosec // Weak RNG doesn't matter here. - random := rand.New(rand.NewSource(simulation.Seed)) - if seed != nil { - simulation.Seed = *seed - //nolint:gosec // Weak RNG doesn't matter here. - random = rand.New(rand.NewSource(*seed)) - } - - lengthCompleted := make(map[model.RequestID]PartialUnits) - pending := make(map[model.RequestID][]Operation) - trialIDs := make(map[model.RequestID]int) - var requestIDs []model.RequestID - ops, err := s.InitialOperations() - if err != nil { - return simulation, err - } - - lastProgress := s.Progress() - if lastProgress != 0.0 { - return simulation, errors.Errorf("Initial searcher progress started at %f", lastProgress) - } - - shutdown, err := handleOperations(pending, &requestIDs, ops) - if err != nil { - return simulation, err - } - - nextTrialID := 1 - trialOpIdxs := map[model.RequestID]int{} - for !shutdown { - requestID, err := pickTrial(random, pending, requestIDs, randomOrder) - if err != nil { - return simulation, err - } - operation := pending[requestID][0] - pending[requestID] = pending[requestID][1:] - - switch operation := operation.(type) { - case Create: - simulation.Results[requestID] = []ValidateAfter{} - trialIDs[requestID] = nextTrialID - ops, err := s.TrialCreated(operation.RequestID) - if err != nil { - return simulation, err - } - trialOpIdxs[requestID] = 0 - lengthCompleted[requestID] = 0 - shutdown, err = handleOperations(pending, &requestIDs, ops) - if err != nil { - return simulation, err - } - nextTrialID++ - case ValidateAfter: - simulation.Results[requestID] = append(simulation.Results[requestID], operation) - s.SetTrialProgress(requestID, PartialUnits(operation.Length)) - - metric := valFunc(random, trialIDs[requestID], trialOpIdxs[requestID]) - ops, err := s.ValidationCompleted(requestID, metric, operation) - if err != nil { - return simulation, err - } - trialOpIdxs[requestID]++ - - shutdown, err = handleOperations(pending, &requestIDs, ops) - if err != nil { - return simulation, err - } - case Close: - delete(pending, requestID) - ops, err := s.TrialClosed(requestID) - if err != nil { - return simulation, err - } - shutdown, err = handleOperations(pending, &requestIDs, ops) - if err != nil { - return simulation, err - } - default: - return simulation, errors.Errorf("unexpected searcher operation: %T", operation) - } - if shutdown { - if len(pending) != 0 { - return simulation, errors.New("searcher shutdown prematurely") - } - break - } - - progress := s.Progress() - if progress < lastProgress { - return simulation, errors.Errorf( - "searcher progress dropped from %f%% to %f%%", lastProgress*100, progress*100) - } - lastProgress = progress +// Proto converts the TrialSummary to its protobuf representation. +func (rs TrialSummary) Proto() *experimentv1.TrialSummary { + return &experimentv1.TrialSummary{ + Count: int32(rs.Count), + Unit: rs.Unit.Proto(), } +} - lastProgress = s.Progress() - if lastProgress != 1.0 { - return simulation, errors.Errorf( - "searcher progress was not equal to 100%%: %f%%", lastProgress*100) +// Proto converts the SearchSummary to its protobuf representation. +func (s SearchSummary) Proto() *experimentv1.SearchSummary { + var trialSummaries []*experimentv1.TrialSummary + for _, v := range s.Trials { + trialSummaries = append(trialSummaries, v.Proto()) } - if len(simulation.Results) != len(requestIDs) { - return simulation, errors.New("more trials created than completed") + return &experimentv1.SearchSummary{ + Config: protoutils.ToStruct(s.Config), + Trials: trialSummaries, } - return simulation, nil } -func handleOperations( - pending map[model.RequestID][]Operation, requestIDs *[]model.RequestID, operations []Operation, -) (bool, error) { - for _, operation := range operations { - switch op := operation.(type) { - case Create: - *requestIDs = append(*requestIDs, op.RequestID) - pending[op.RequestID] = []Operation{op} - case Requested: - pending[op.GetRequestID()] = append(pending[op.GetRequestID()], op) - case Shutdown: - return true, nil - default: - return false, errors.Errorf("unexpected operation: %T", operation) - } +// Simulate generates the intended training plan for the searcher. +func Simulate(conf expconf.SearcherConfig, hparams expconf.Hyperparameters) (SearchSummary, error) { + searchSummary := SearchSummary{ + Trials: []TrialSummary{}, + Config: conf, } - return false, nil -} - -func pickTrial( - random *rand.Rand, pending map[model.RequestID][]Operation, requestIDs []model.RequestID, - randomOrder bool, -) (model.RequestID, error) { - // If randomOrder is false, then return the first id from requestIDs that has any operations - // pending. - if !randomOrder { - for _, requestID := range requestIDs { - operations := pending[requestID] - if len(operations) > 0 { - return requestID, nil + switch { + case conf.RawSingleConfig != nil: + searchSummary.Trials = append(searchSummary.Trials, TrialSummary{Count: 1, Unit: SearchUnit{MaxLength: true}}) + return searchSummary, nil + case conf.RawRandomConfig != nil: + maxTrials := conf.RawRandomConfig.MaxTrials() + searchSummary.Trials = append(searchSummary.Trials, TrialSummary{Count: maxTrials, Unit: SearchUnit{MaxLength: true}}) + return searchSummary, nil + case conf.RawGridConfig != nil: + hparamGrid := newHyperparameterGrid(hparams) + searchSummary.Trials = append( + searchSummary.Trials, TrialSummary{Count: len(hparamGrid), Unit: SearchUnit{MaxLength: true}}, + ) + return searchSummary, nil + case conf.RawAdaptiveASHAConfig != nil: + ashaConfig := conf.RawAdaptiveASHAConfig + brackets := makeBrackets(*ashaConfig) + unitsPerTrial := make(map[int32]int) + for _, bracket := range brackets { + rungs := makeRungs(bracket.numRungs, ashaConfig.Divisor(), ashaConfig.Length().Units) + rungTrials := bracket.maxTrials + // For each rung, calculate number of trials that will be stopped before next rung + // to determine the number of trials that will only train to the current rung. + for i, rung := range rungs { + rungUnits := int(rung.UnitsNeeded) + trialsContinued := mathx.Max(int(float64(rungTrials)/ashaConfig.Divisor()), 1) + trialsStopped := rungTrials - trialsContinued + if i == len(rungs)-1 { + trialsStopped = rungTrials + } + unitsPerTrial[int32(rungUnits)] += trialsStopped + rungTrials = trialsContinued } } - return model.RequestID{}, errors.New("tried to pick a trial when no trial had pending operations") - } - - // If randomOrder is true, pseudo-randomly select a trial that has pending operations. - var candidates []model.RequestID - for requestID, operations := range pending { - if len(operations) > 0 { - candidates = append(candidates, requestID) + for units, numTrials := range unitsPerTrial { + searchSummary.Trials = append(searchSummary.Trials, TrialSummary{ + Count: numTrials, + Unit: SearchUnit{ + Name: ptrs.Ptr(string(ashaConfig.Length().Unit)), + Value: &units, + }, + }) } + // Sort by target units for consistency in output. + sort.Slice(searchSummary.Trials, func(i, j int) bool { + return *searchSummary.Trials[i].Unit.Value < *searchSummary.Trials[j].Unit.Value + }) + return searchSummary, nil + default: + return SearchSummary{}, errors.New("invalid searcher configuration") } - if len(candidates) == 0 { - return model.RequestID{}, errors.New("tried to pick a trial when no trial had pending operations") - } - - // Map iteration order is nondeterministic, even for identical maps in the same run, so sort the - // candidates before selecting one. - sort.Slice(candidates, func(i, j int) bool { - return candidates[i].Before(candidates[j]) - }) - - choice := random.Intn(len(candidates)) - return candidates[choice], nil } diff --git a/master/pkg/searcher/simulate_test.go b/master/pkg/searcher/simulate_test.go new file mode 100644 index 00000000000..dacba6b3a1f --- /dev/null +++ b/master/pkg/searcher/simulate_test.go @@ -0,0 +1,83 @@ +package searcher + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/determined-ai/determined/master/pkg/ptrs" + "github.com/determined-ai/determined/master/pkg/schemas/expconf" +) + +func TestSimulateASHA(t *testing.T) { + maxConcurrentTrials := 5 + maxTrials := 10 + divisor := 3.0 + maxTime := 900 + timeMetric := ptrs.Ptr("batches") + config := expconf.SearcherConfig{ + RawAdaptiveASHAConfig: &expconf.AdaptiveASHAConfig{ + RawMaxRungs: ptrs.Ptr(10), + RawMaxTime: &maxTime, + RawDivisor: &divisor, + RawMaxConcurrentTrials: &maxConcurrentTrials, + RawMaxTrials: &maxTrials, + RawTimeMetric: timeMetric, + RawMode: ptrs.Ptr(expconf.StandardMode), + }, + RawMetric: ptrs.Ptr("loss"), + RawSmallerIsBetter: ptrs.Ptr(true), + } + intHparam := &expconf.IntHyperparameter{RawMaxval: 10, RawCount: ptrs.Ptr(3)} + hparams := expconf.Hyperparameters{ + "x": expconf.Hyperparameter{RawIntHyperparameter: intHparam}, + } + + res, err := Simulate(config, hparams) + require.NoError(t, err) + // Bracket #1: 7 total trials + // Rungs: [100, 300, 900] + // - 7 at 100 -> 2 at 300 -> 1 at 900 + // => 5 for 100, 1 for 300, 1 for 900 + // + // Bracket #2: 3 total trials + // Rungs: [300, 900] + // - 3 at 300 -> 1 at 900 + // => 2 for 300, 1 for 900 + require.Equal(t, config, res.Config) + expectedRunSummary := []TrialSummary{ + {Count: 5, Unit: SearchUnit{Name: timeMetric, Value: ptrs.Ptr(int32(100))}}, + {Count: 3, Unit: SearchUnit{Name: timeMetric, Value: ptrs.Ptr(int32(300))}}, + {Count: 2, Unit: SearchUnit{Name: timeMetric, Value: ptrs.Ptr(int32(900))}}, + } + require.Equal(t, expectedRunSummary, res.Trials) +} + +func TestSimulateGrid(t *testing.T) { + maxConcurrentTrials := 2 + numHparams := 4 + gridConfig := expconf.GridConfig{ + RawMaxConcurrentTrials: ptrs.Ptr(maxConcurrentTrials), + } + searcherConfig := expconf.SearcherConfig{ + RawGridConfig: &gridConfig, + RawMetric: ptrs.Ptr("loss"), + } + hparams := expconf.Hyperparameters{ + "a": expconf.Hyperparameter{ + RawIntHyperparameter: &expconf.IntHyperparameter{ + RawMinval: 0, RawMaxval: 10, RawCount: ptrs.Ptr(numHparams), + }, + }, + } + + res, err := Simulate(searcherConfig, hparams) + require.NoError(t, err) + + // Expect all configured hparams in space = 4 trials at max length. + require.Equal(t, searcherConfig, res.Config) + expectedRunSummary := []TrialSummary{ + {Count: numHparams, Unit: SearchUnit{MaxLength: true}}, + } + require.Equal(t, expectedRunSummary, res.Trials) +} diff --git a/master/pkg/searcher/tournament.go b/master/pkg/searcher/tournament.go index 9feee17d60c..e99752207c6 100644 --- a/master/pkg/searcher/tournament.go +++ b/master/pkg/searcher/tournament.go @@ -6,17 +6,15 @@ import ( "github.com/pkg/errors" "github.com/determined-ai/determined/master/pkg/model" - "github.com/determined-ai/determined/master/pkg/schemas/expconf" ) -// tournamentSearch runs multiple search methods in tandem. Callbacks for completed operations -// are sent to the originating search method that created the corresponding operation. +// tournamentSearch trial multiple search methods in tandem. Callbacks for completed actions +// are sent to the originating search method that initiated the corresponding action. type ( tournamentSearchState struct { - SubSearchUnitsCompleted []float64 `json:"sub_search_units_completed"` - TrialTable map[model.RequestID]int `json:"trial_table"` - SubSearchStates []json.RawMessage `json:"sub_search_states"` - SearchMethodType SearchMethodType `json:"search_method_type"` + TrialTable map[model.RequestID]int `json:"trial_table"` + SubSearchStates []json.RawMessage `json:"sub_search_states"` + SearchMethodType SearchMethodType `json:"search_method_type"` } tournamentSearch struct { subSearches []SearchMethod @@ -28,10 +26,9 @@ func newTournamentSearch(mt SearchMethodType, subSearches ...SearchMethod) *tour return &tournamentSearch{ subSearches: subSearches, tournamentSearchState: tournamentSearchState{ - SubSearchUnitsCompleted: make([]float64, len(subSearches)), - TrialTable: make(map[model.RequestID]int), - SubSearchStates: make([]json.RawMessage, len(subSearches)), - SearchMethodType: mt, + TrialTable: make(map[model.RequestID]int), + SubSearchStates: make([]json.RawMessage, len(subSearches)), + SearchMethodType: mt, }, } } @@ -60,22 +57,22 @@ func (s *tournamentSearch) Restore(state json.RawMessage) error { return nil } -func (s *tournamentSearch) initialOperations(ctx context) ([]Operation, error) { - var operations []Operation +func (s *tournamentSearch) initialTrials(ctx context) ([]Action, error) { + var actions []Action for i, subSearch := range s.subSearches { - ops, err := subSearch.initialOperations(ctx) + creates, err := subSearch.initialTrials(ctx) if err != nil { return nil, err } - s.markCreates(i, ops) - operations = append(operations, ops...) + s.markCreates(i, creates) + actions = append(actions, creates...) } - return operations, nil + return actions, nil } func (s *tournamentSearch) trialCreated( ctx context, requestID model.RequestID, -) ([]Operation, error) { +) ([]Action, error) { subSearchID := s.TrialTable[requestID] subSearch := s.subSearches[subSearchID] ops, err := subSearch.trialCreated(ctx, requestID) @@ -83,27 +80,27 @@ func (s *tournamentSearch) trialCreated( } func (s *tournamentSearch) validationCompleted( - ctx context, requestID model.RequestID, metric interface{}, op ValidateAfter, -) ([]Operation, error) { + ctx context, requestID model.RequestID, metrics map[string]interface{}, +) ([]Action, error) { subSearchID := s.TrialTable[requestID] subSearch := s.subSearches[subSearchID] - ops, err := subSearch.validationCompleted(ctx, requestID, metric, op) + ops, err := subSearch.validationCompleted(ctx, requestID, metrics) return s.markCreates(subSearchID, ops), err } -// trialClosed informs the searcher that the trial has been closed as a result of a Close operation. -func (s *tournamentSearch) trialClosed( +// runExited informs the searcher that the run has exited. +func (s *tournamentSearch) trialExited( ctx context, requestID model.RequestID, -) ([]Operation, error) { +) ([]Action, error) { subSearchID := s.TrialTable[requestID] subSearch := s.subSearches[subSearchID] - ops, err := subSearch.trialClosed(ctx, requestID) + ops, err := subSearch.trialExited(ctx, requestID) return s.markCreates(subSearchID, ops), err } func (s *tournamentSearch) trialExitedEarly( ctx context, requestID model.RequestID, exitedReason model.ExitedReason, -) ([]Operation, error) { +) ([]Action, error) { subSearchID := s.TrialTable[requestID] subSearch := s.subSearches[subSearchID] ops, err := subSearch.trialExitedEarly(ctx, requestID, exitedReason) @@ -112,12 +109,12 @@ func (s *tournamentSearch) trialExitedEarly( // progress returns experiment progress as a float between 0.0 and 1.0. func (s *tournamentSearch) progress( - trialProgress map[model.RequestID]PartialUnits, + trialProgress map[model.RequestID]float64, trialsClosed map[model.RequestID]bool, ) float64 { sum := 0.0 for subSearchID, subSearch := range s.subSearches { - subSearchTrialProgress := map[model.RequestID]PartialUnits{} + subSearchTrialProgress := map[model.RequestID]float64{} for rID, p := range trialProgress { if subSearchID == s.TrialTable[rID] { subSearchTrialProgress[rID] = p @@ -134,15 +131,16 @@ func (s *tournamentSearch) progress( return sum / float64(len(s.subSearches)) } -func (s *tournamentSearch) Unit() expconf.Unit { - return s.subSearches[0].Unit() -} - -func (s *tournamentSearch) markCreates(subSearchID int, operations []Operation) []Operation { - for _, operation := range operations { - if operation, ok := operation.(Create); ok { - s.TrialTable[operation.RequestID] = subSearchID +func (s *tournamentSearch) markCreates(subSearchID int, actions []Action) []Action { + for _, action := range actions { + if _, ok := action.(Create); ok { + create := action.(Create) + s.TrialTable[create.RequestID] = subSearchID } } - return operations + return actions +} + +func (s *tournamentSearch) Type() SearchMethodType { + return s.SearchMethodType } diff --git a/master/pkg/searcher/tournament_test.go b/master/pkg/searcher/tournament_test.go index ae51010aca1..5a4c80c0c6c 100644 --- a/master/pkg/searcher/tournament_test.go +++ b/master/pkg/searcher/tournament_test.go @@ -4,89 +4,103 @@ package searcher import ( "testing" - "gotest.tools/assert" + "github.com/determined-ai/determined/master/pkg/model" + + "github.com/stretchr/testify/require" "github.com/determined-ai/determined/master/pkg/ptrs" - "github.com/determined-ai/determined/master/pkg/schemas" "github.com/determined-ai/determined/master/pkg/schemas/expconf" ) -const RandomTournamentSearch SearchMethodType = "random_tournament" - -func TestRandomTournamentSearcher(t *testing.T) { - actual := newTournamentSearch( - RandomTournamentSearch, - newRandomSearch(schemas.WithDefaults(expconf.RandomConfig{ - RawMaxTrials: ptrs.Ptr(2), - RawMaxLength: ptrs.Ptr(expconf.NewLengthInBatches(300)), - })), - newRandomSearch(schemas.WithDefaults(expconf.RandomConfig{ - RawMaxTrials: ptrs.Ptr(3), - RawMaxLength: ptrs.Ptr(expconf.NewLengthInBatches(200)), - })), - ) - expected := [][]ValidateAfter{ - toOps("300B"), - toOps("300B"), - toOps("200B"), - toOps("200B"), - toOps("200B"), +func TestAdaptiveASHASearchMethod(t *testing.T) { + maxConcurrentTrials := 3 + maxTrials := 9 + maxRungs := 5 + divisor := 3.0 + maxTime := 90 + metric := "loss" + config := expconf.AdaptiveASHAConfig{ + RawMaxTime: &maxTime, + RawDivisor: &divisor, + RawMaxRungs: &maxRungs, + RawMaxConcurrentTrials: &maxConcurrentTrials, + RawMaxTrials: &maxTrials, + RawTimeMetric: ptrs.Ptr("batches"), + RawMode: ptrs.Ptr(expconf.StandardMode), + } + searcherConfig := expconf.SearcherConfig{ + RawAdaptiveASHAConfig: &config, + RawSmallerIsBetter: ptrs.Ptr(true), + RawMetric: ptrs.Ptr(metric), + } + intHparam := &expconf.IntHyperparameter{RawMaxval: 10, RawCount: ptrs.Ptr(3)} + hparams := expconf.Hyperparameters{ + "x": expconf.Hyperparameter{RawIntHyperparameter: intHparam}, } - checkSimulation(t, actual, nil, ConstantValidation, expected) -} -func TestRandomTournamentSearcherReproducibility(t *testing.T) { - conf := expconf.RandomConfig{ - RawMaxTrials: ptrs.Ptr(5), RawMaxLength: ptrs.Ptr(expconf.NewLengthInBatches(800)), + // Create a new test searcher and verify correct brackets/rungs initialized. + testSearchRunner := NewTestSearchRunner(t, searcherConfig, hparams) + search := testSearchRunner.method.(*tournamentSearch) + expectedRungs := []*rung{ + {UnitsNeeded: uint64(10)}, + {UnitsNeeded: uint64(30)}, + {UnitsNeeded: uint64(90)}, } - conf = schemas.WithDefaults(conf) - gen := func() SearchMethod { - return newTournamentSearch( - RandomTournamentSearch, - newRandomSearch(conf), - newRandomSearch(conf), - ) + for i, s := range search.subSearches { + ashaSearch := s.(*asyncHalvingStoppingSearch) + require.Equal(t, expectedRungs[i:], ashaSearch.Rungs) } - checkReproducibility(t, gen, nil, defaultMetric) -} -func TestTournamentSearchMethod(t *testing.T) { - expectedTrials := []predefinedTrial{ - newConstantPredefinedTrial(toOps("1000B 3000B"), 0.1), - newConstantPredefinedTrial(toOps("1000B"), 0.2), - newConstantPredefinedTrial(toOps("1000B"), 0.3), + // Simulate running the search. + testSearchRunner.run(90, 10, true) - newConstantPredefinedTrial(toOps("1000B"), 0.3), - newConstantPredefinedTrial(toOps("1000B"), 0.2), - newConstantPredefinedTrial(toOps("1000B 3000B"), 0.1), - } + // Expect 2 brackets and 9 total trials. + require.Len(t, search.subSearches, 2) + require.Len(t, search.TrialTable, maxTrials) - adaptiveConfig1 := expconf.SearcherConfig{ - RawAsyncHalvingConfig: &expconf.AsyncHalvingConfig{ - RawNumRungs: ptrs.Ptr(3), - RawMaxLength: ptrs.Ptr(expconf.NewLengthInBatches(9000)), - RawMaxTrials: ptrs.Ptr(3), - RawDivisor: ptrs.Ptr[float64](3), - }, - } - adaptiveConfig1 = schemas.WithDefaults(adaptiveConfig1) - adaptiveMethod1 := NewSearchMethod(adaptiveConfig1) + bracket1 := make(map[model.RequestID]*testTrial) + bracket2 := make(map[model.RequestID]*testTrial) - adaptiveConfig2 := expconf.SearcherConfig{ - RawAsyncHalvingConfig: &expconf.AsyncHalvingConfig{ - RawNumRungs: ptrs.Ptr(3), - RawMaxLength: ptrs.Ptr(expconf.NewLengthInBatches(9000)), - RawMaxTrials: ptrs.Ptr(3), - RawDivisor: ptrs.Ptr[float64](3), - }, + for _, tr := range testSearchRunner.trials { + if search.TrialTable[tr.requestID] == 0 { + bracket1[tr.requestID] = tr + } else { + bracket2[tr.requestID] = tr + } } - adaptiveConfig2 = schemas.WithDefaults(adaptiveConfig2) - adaptiveMethod2 := NewSearchMethod(adaptiveConfig2) - - params := expconf.Hyperparameters{} - method := newTournamentSearch(AdaptiveSearch, adaptiveMethod1, adaptiveMethod2) + // Bracket #1: 6 total trials + // Rungs: [10, 30, 90] + // Since we reported progressively worse metrics, only one run continues to top rung. + // All others are stopped at first rung. + require.Len(t, bracket1, 6) + stoppedAt90 := 0 + stoppedAt10 := 0 + for _, tr := range bracket1 { + if tr.stoppedAt == 90 { + stoppedAt90++ + } + if tr.stoppedAt == 10 { + stoppedAt10++ + } + } + require.Equal(t, 5, stoppedAt10) + require.Equal(t, 1, stoppedAt90) - err := checkValueSimulation(t, method, params, expectedTrials) - assert.NilError(t, err) + // Bracket #2: 3 total trials + // Rungs: [30, 90] + // First run (run #3 from initialTrials) continues to top rung, two will stop at first rung. + require.Len(t, bracket2, 3) + stoppedAt90 = 0 + stoppedAt30 := 0 + for _, tr := range bracket2 { + if tr.stoppedAt == 90 { + stoppedAt90++ + } + if tr.stoppedAt == 30 { + stoppedAt30++ + } + } + require.Equal(t, 1, stoppedAt90) + require.Equal(t, 2, stoppedAt30) } diff --git a/master/pkg/searcher/util_test.go b/master/pkg/searcher/util_test.go index 0a6c1fd23b5..637532d841a 100644 --- a/master/pkg/searcher/util_test.go +++ b/master/pkg/searcher/util_test.go @@ -1,317 +1,143 @@ package searcher import ( - "bytes" "fmt" - "strconv" - "strings" "testing" - "github.com/pkg/errors" + "github.com/determined-ai/determined/master/pkg/model" + "gotest.tools/assert" - "github.com/determined-ai/determined/master/pkg/check" - "github.com/determined-ai/determined/master/pkg/model" - "github.com/determined-ai/determined/master/pkg/nprand" - "github.com/determined-ai/determined/master/pkg/schemas" "github.com/determined-ai/determined/master/pkg/schemas/expconf" ) -const defaultMetric = "metric" - -func isExpected(actual, expected []ValidateAfter) bool { - if len(actual) != len(expected) { - return false - } - for i, act := range actual { - if expected[i].Length != act.Length { - return false - } - } - return true +type TestSearchRunner struct { + config expconf.SearcherConfig + searcher *Searcher + method SearchMethod + trials []*testTrial + t *testing.T } -func checkSimulation( - t *testing.T, - method SearchMethod, - params expconf.Hyperparameters, - validation ValidationFunction, - expected [][]ValidateAfter, -) { - search := NewSearcher(0, method, params) - actual, err := Simulate(search, new(int64), validation, true, defaultMetric) - assert.NilError(t, err) - - assert.Equal(t, len(actual.Results), len(expected)) - for _, actualTrial := range actual.Results { - found := false - for i, expectedTrial := range expected { - if isExpected(actualTrial, expectedTrial) { - expected = append(expected[:i], expected[i+1:]...) - found = true - break - } - } - if !found { - t.Errorf("unexpected trial %+v not in %+v", actualTrial, expected) - } - } +type testTrial struct { + requestID model.RequestID + hparams HParamSample + stopped bool + stoppedAt int + completed bool } -// checkReproducibility creates two searchers with the same seed and the given config, simulates -// them, and checks that they produce the same trials and the same sequence of workloads for each -// trial. -func checkReproducibility( - t assert.TestingT, methodGen func() SearchMethod, hparams expconf.Hyperparameters, metric string, -) { - hparams = schemas.WithDefaults(hparams) - seed := int64(17) - searcher1 := NewSearcher(uint32(seed), methodGen(), hparams) - searcher2 := NewSearcher(uint32(seed), methodGen(), hparams) - - results1, err1 := Simulate(searcher1, &seed, ConstantValidation, true, metric) - assert.NilError(t, err1) - results2, err2 := Simulate(searcher2, &seed, ConstantValidation, true, metric) - assert.NilError(t, err2) - - assert.Equal(t, len(results1.Results), len(results2.Results), - "searchers had different number of trials") - for requestID := range results1.Results { - w1 := results1.Results[requestID] - w2 := results2.Results[requestID] - - assert.Equal(t, len(w1), len(w2), "trial had different numbers of workloads between searchers") - for i := range w1 { - // We want to ignore the start and end time fields, so check the rest individually. - assert.Equal(t, w1[i], w2[i], "workload differed between searchers") - } - } -} - -func toOps(types string) (ops []ValidateAfter) { - for _, unparsed := range strings.Fields(types) { - count, err := strconv.ParseUint(unparsed[:len(unparsed)-1], 10, 64) - if err != nil { - panic(err) - } - switch unit := string(unparsed[len(unparsed)-1]); unit { - case "R": - ops = append(ops, ValidateAfter{Length: count}) - case "B": - ops = append(ops, ValidateAfter{Length: count}) - case "E": - ops = append(ops, ValidateAfter{Length: count}) - } - } - return ops +func (tr testTrial) String() string { + return fmt.Sprintf( + "testTrial{requestID: %v, hparams: %v, stopped: %v, stoppedAt: %v, completed: %v}", + tr.requestID, tr.hparams, tr.stopped, tr.stoppedAt, tr.completed, + ) } -type predefinedTrial struct { - Ops []ValidateAfter - ValMetrics []float64 - EarlyExit *int +func mockRequestID(id int) model.RequestID { + return model.RequestID{byte(id)} } -func newPredefinedTrial(ops []ValidateAfter, earlyExit *int, valMetrics []float64) predefinedTrial { - return predefinedTrial{ - Ops: ops, - EarlyExit: earlyExit, - ValMetrics: valMetrics, +func (sr *TestSearchRunner) run(maxUnits int, valPeriod int, increasing bool) { + metric := 0.0 + sr.initialRuns() + for i := 0; i < len(sr.trials); i++ { + trial := sr.trials[i] + for j := 0; j <= maxUnits; j += valPeriod { + if increasing { + metric++ + } else { + metric-- + } + sr.reportValidationMetric(trial.requestID, j, metric) + if trial.stopped { + trial.stoppedAt = j + break + } + } + sr.closeRun(trial.requestID) } } -func newEarlyExitPredefinedTrial(ops []ValidateAfter, valMetric float64) predefinedTrial { - var valMetrics []float64 - for range ops { - valMetrics = append(valMetrics, valMetric) +func NewTestSearchRunner( + t *testing.T, config expconf.SearcherConfig, hparams expconf.Hyperparameters, +) *TestSearchRunner { + expSeed := uint32(102932948) + method := NewSearchMethod(config) + searcher := NewSearcher(expSeed, method, hparams) + return &TestSearchRunner{ + t: t, + config: config, + searcher: searcher, + method: method, + trials: []*testTrial{}, } - exitEarly := len(ops) - 1 - return newPredefinedTrial(ops, &exitEarly, valMetrics) } -func newConstantPredefinedTrial(ops []ValidateAfter, valMetric float64) predefinedTrial { - var valMetrics []float64 - for range ops { - valMetrics = append(valMetrics, valMetric) - } - return newPredefinedTrial(ops, nil, valMetrics) +func (sr *TestSearchRunner) initialRuns() ([]testTrial, []testTrial) { + creates, err := sr.searcher.InitialTrials() + assert.NilError(sr.t, err, "error getting initial trials") + created, stopped := sr.handleActions(creates) + return created, stopped } -func (t *predefinedTrial) Train(length uint64, opIndex int) error { - if opIndex >= len(t.Ops) { - return errors.Errorf("ran out of expected ops trying to train") - } - op := t.Ops[opIndex] - if op.Length != length { - return errors.Errorf("wanted %v got %v", op.Length, length) +func (sr *TestSearchRunner) reportValidationMetric( + requestID model.RequestID, stepNum int, metricVal float64, +) ([]testTrial, []testTrial) { + metrics := map[string]interface{}{ + sr.config.Metric(): metricVal, } - return nil -} - -func (t *predefinedTrial) CheckComplete(opIndex int) error { - return check.Equal(len(t.Ops), opIndex, "had ops %s left", t.Ops[opIndex:]) -} - -// checkValueSimulation will run a SearchMethod until completion, using predefinedTrials. -func checkValueSimulation( - t *testing.T, - method SearchMethod, - params expconf.Hyperparameters, - expectedTrials []predefinedTrial, -) error { - // Create requests are assigned a predefinedTrial in order. - var nextTrialID int - var pending []Operation - - trialIDs := map[model.RequestID]int{} - trialOpIdx := map[model.RequestID]int{} - trialEarlyExits := map[model.RequestID]bool{} - - ctx := context{ - rand: nprand.New(0), - hparams: params, + if sr.config.RawAdaptiveASHAConfig != nil { + timeMetric := string(sr.config.RawAdaptiveASHAConfig.Length().Unit) + metrics[timeMetric] = float64(stepNum) } - - ops, err := method.initialOperations(ctx) - if err != nil { - return errors.Wrap(err, "initialOperations") + if sr.config.RawAsyncHalvingConfig != nil { + timeMetric := string(sr.config.RawAsyncHalvingConfig.Length().Unit) + metrics[timeMetric] = float64(stepNum) } + actions, err := sr.searcher.ValidationCompleted(requestID, metrics) + assert.NilError(sr.t, err, "error completing validation") - pending = append(pending, ops...) - - for len(pending) > 0 { - var requestID model.RequestID - operation := pending[0] - pending = pending[1:] - - switch operation := operation.(type) { - case Create: - requestID = operation.RequestID - if nextTrialID >= len(expectedTrials) { - return errors.Errorf("search method created too many trials") - } - trialIDs[requestID] = nextTrialID - trialOpIdx[requestID] = 0 - - ops, err = method.trialCreated(ctx, requestID) - if err != nil { - return errors.Wrap(err, "trialCreated") - } - nextTrialID++ + created, stopped := sr.handleActions(actions) - case ValidateAfter: - requestID = operation.GetRequestID() - if trialEarlyExits[requestID] { - continue - } - - trialID := trialIDs[requestID] - trial := expectedTrials[trialID] - if trial.EarlyExit != nil && trialOpIdx[requestID] == *trial.EarlyExit { - trialEarlyExits[requestID] = true - } - ops, err = simulateOperationComplete(ctx, method, trial, operation, trialOpIdx[requestID]) - if err != nil { - return errors.Wrapf(err, "simulateOperationComplete for trial %v", trialID+1) - } - trialOpIdx[requestID]++ - if err = saveAndReload(method); err != nil { - return errors.Wrap(err, "snapshot failed") - } - - case Close: - requestID = operation.RequestID - trialID := trialIDs[requestID] - trial := expectedTrials[trialID] - err = trial.CheckComplete(trialOpIdx[requestID]) - if err != nil { - return errors.Wrapf(err, "trial %v closed before completion", trialID+1) - } - - ops, err = method.trialClosed(ctx, requestID) - if err != nil { - return errors.Wrap(err, "trialClosed") - } - - default: - return errors.Errorf("unexpected searcher operation: %T", operation) - } + return created, stopped +} - pending = append(pending, ops...) - } +// closeRun simulates a run completing its train loop and exiting. +func (sr *TestSearchRunner) closeRun(requestID model.RequestID) ([]testTrial, []testTrial) { + run := sr.getTrialByRequestID(requestID) + run.completed = true + actions, err := sr.searcher.TrialExited(requestID) + assert.NilError(sr.t, err, "error closing run") + return sr.handleActions(actions) +} - for requestID, trialID := range trialIDs { - if err = expectedTrials[trialID].CheckComplete(trialOpIdx[requestID]); err != nil { - return errors.Wrapf(err, "incomplete trial %v", trialID+1) +func (sr *TestSearchRunner) getTrialByRequestID(requestID model.RequestID) *testTrial { + for i, t := range sr.trials { + if t.requestID == requestID { + return sr.trials[i] } } - return nil } -func runValueSimulationTestCases(t *testing.T, testCases []valueSimulationTestCase) { - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - // Apply WithDefaults in one place to make tests easyto write. - config := schemas.WithDefaults(tc.config) - hparams := schemas.WithDefaults(tc.hparams) - method := NewSearchMethod(config) - err := checkValueSimulation(t, method, hparams, tc.expectedTrials) - assert.NilError(t, err) - }) - } -} - -type valueSimulationTestCase struct { - name string - expectedTrials []predefinedTrial - hparams expconf.Hyperparameters - config expconf.SearcherConfig -} - -func simulateOperationComplete( - ctx context, - method SearchMethod, - trial predefinedTrial, - operation ValidateAfter, - opIndex int, -) ([]Operation, error) { - if err := trial.Train(operation.Length, opIndex); err != nil { - return nil, errors.Wrap(err, "error checking ValidateAfter with predefinedTrial") - } +func (sr *TestSearchRunner) handleActions(actions []Action) ([]testTrial, []testTrial) { + var trialsCreated []testTrial + var trialsStopped []testTrial - if trial.EarlyExit != nil && opIndex == *trial.EarlyExit { - ops, err := method.trialExitedEarly(ctx, operation.RequestID, model.UserRequestedStop) - if err != nil { - return nil, errors.Wrap(err, "trainCompleted") + for _, action := range actions { + switch action := action.(type) { + case Create: + run := testTrial{requestID: action.RequestID, hparams: action.Hparams} + _, err := sr.searcher.TrialCreated(action.RequestID) + assert.NilError(sr.t, err, "error creating run") + sr.trials = append(sr.trials, &run) + trialsCreated = append(trialsCreated, run) + case Stop: + trial := sr.getTrialByRequestID(action.RequestID) + trial.stopped = true + trialsStopped = append(trialsStopped, *trial) } - return ops, nil } - - ops, err := method.validationCompleted( - ctx, operation.RequestID, trial.ValMetrics[opIndex], operation, - ) - if err != nil { - return nil, errors.Wrap(err, "validationCompleted") - } - - return ops, nil -} - -func saveAndReload(method SearchMethod) error { - // take the state back and forth through a round of serialization to test. - if state, err := method.Snapshot(); err != nil { - return err - } else if err := method.Restore(state); err != nil { - return err - } else if state2, err := method.Snapshot(); err != nil { // Test restore is correct. - return err - } else if !bytes.Equal(state, state2) { - unmarshaledState := method.Restore(state) - unmarshaledState2 := method.Restore(state2) - fmt.Printf("%+v\n", unmarshaledState) //nolint: forbidigo - fmt.Printf("%+v\n", unmarshaledState2) //nolint: forbidigo - return errors.New("successive snapshots were not identical") - } - return nil + return trialsCreated, trialsStopped } diff --git a/performance/daist/daist/metrics/base.py b/performance/daist/daist/metrics/base.py index b11fcdc214b..a60b9421813 100644 --- a/performance/daist/daist/metrics/base.py +++ b/performance/daist/daist/metrics/base.py @@ -32,6 +32,7 @@ class BaseMetricsTest(TestCase): hyperparameters: checkpoint_size: 4096 + num_batches: {num_batches} metric_count: {metric_count} {param1}: type: int @@ -44,14 +45,13 @@ class BaseMetricsTest(TestCase): resources: slots_per_trial: 1 - + # This may or may not be needed - # resource_pool: compute_pool + # resource_pool: compute_pool searcher: name: grid metric: "{param1}" - max_length: {searcher_max_length} max_concurrent_trials: {searcher_max_concurrent_trials} max_restarts: 0""" diff --git a/performance/daist/daist/metrics/config.yaml b/performance/daist/daist/metrics/config.yaml index a0c1ffda5ac..bf850d5dbc6 100644 --- a/performance/daist/daist/metrics/config.yaml +++ b/performance/daist/daist/metrics/config.yaml @@ -15,6 +15,10 @@ hyperparameters: delay: 0 # Seconds to sleep during each batch checkpoint_size: 4096 # Size of checkpoint data + # 2 batches are required because we are measuring the thing that is + # storing the metrics in the first pass + num_batches: 2 + batch_metric_count: type: int minval: 1 @@ -26,9 +30,6 @@ hyperparameters: searcher: name: grid metric: "metrics_count" - # this is the number of batches. 2 are required because we are measuring the thing that is - # storing the metrics in the first pass - max_length: 2 max_concurrent_trials: 1 -max_restarts: 0 \ No newline at end of file +max_restarts: 0 diff --git a/performance/daist/daist/metrics/model_def.py b/performance/daist/daist/metrics/model_def.py index 3217026e50d..6743d3e984a 100644 --- a/performance/daist/daist/metrics/model_def.py +++ b/performance/daist/daist/metrics/model_def.py @@ -26,10 +26,7 @@ class Param: METRIC_COUNT: type_ = 'metric_count' -def worker_main_in_context(core_context: Context, metric_count: int) -> int: - op = next(core_context.searcher.operations()) - num_batches = op.length - +def worker_main_in_context(core_context: Context, num_batches: int, metric_count: int) -> int: write_latencies = list() for batch in range(num_batches): @@ -48,9 +45,6 @@ def worker_main_in_context(core_context: Context, metric_count: int) -> int: core_context.train.report_validation_metrics(steps_completed=num_batches, metrics={MetricKey.WRITE: write_latencies}) - if core_context.distributed.rank == 0: - op.report_completed(None) - return 0 @@ -133,7 +127,9 @@ def worker_main(info: det.ClusterInfo): with det.core.init(distributed=distributed) as core_context: exit_code = worker_main_in_context( core_context=core_context, - metric_count=info.trial.hparams['metric_count']) + num_batches=info.trial.hparams['num_batches'], + metric_count=info.trial.hparams['metric_count'], + ) return exit_code diff --git a/performance/daist/daist/metrics/test_latency.py b/performance/daist/daist/metrics/test_latency.py index 9868dc657e6..cbf5f26398f 100644 --- a/performance/daist/daist/metrics/test_latency.py +++ b/performance/daist/daist/metrics/test_latency.py @@ -25,12 +25,12 @@ def _run_latency_experiment(self, samples: int) -> Experiment: cfg = self._cfg.format(name=f'{self.id()}', + num_batches=samples, metric_count=metric_count, param1=model_def.Param.CONCURRENCY, param1_minval=1, param1_maxval=concurrency+1, param1_count=concurrency, - searcher_max_length=samples, searcher_max_concurrent_trials=concurrency) return self._run_experiment(cfg) diff --git a/proto/buf.image.bin b/proto/buf.image.bin index 390d6dbde20a5f548d20bdb96192b2ab21235720..49c522ee0ec0209f8a4586f716f0267ebb04ac23 100644 GIT binary patch delta 29693 zcmbWgcbrs3_V+)1yZZL@?HqxT(12u6aR_Dv%pw?ezq`+~`*oMqaYbQ37|~VNCTBqq zrciQ_5d=X5Mg$o^Km{ZxK|w&01OY*k!1Fm()u-M4{rCHW*K5w_o?G>TR(Ydd?!UO?x!&$x zp#d((d8(*m_h+8}V{xYv*{Y#ckg4K$dUovma#3X~?W*s2$n~T)2f19aYL3&TCqBi-C{Wq0vtr92jhwhy^ zc7E#RE?qiy|5q&Gc%JF{a&bvz7iAcFQd_ax9Px5-C%ONSNXp=A?vR|f&28UGoKMn) zk~2iX)NG5&xYG3!Q`S&<>31uo@`o*Kti`O;RpVSw6+>spB zI_5;Sxa?=T7C+bNpGdI#%f&t9TTR`qoY6Nq-k2vZ?{k-Yd$=nhe`@M(B#U43`O44t zbBCN@H^&m*n0Pj?ldIrad3vwBvRa^(2 zujN;9Zb$MCCI5zOJ=Wb+r3xFVzG2x46Bc~m@CLIMQhg&5b#7I6;?kRP)>wDL=2TUV zMa6xS@1)|s=?$4Xskm=OW6tB?@yF~wuDszg@UHvOl;0H-eO#b<1Zn#7^I~>Cm-UwH z@~*pW5+<(pbL9t}r|2qJ#Q?ebU3Z5<{B6TiapMqI-FnWG>PA?@WSw#D_Raj38#fMfiKOv#Tz!XP^B1~Gu_-N`;QmtvexKqd6Zw_88aHwN7x@Ka_BdDG`z}{OhjebtOFvuG6d#HJ)8c=ztTOrb}35tP-A3K_a5nzuIVlvi#z_aC|%M$)3H-=r{aI!NgH(vgF6)MZ!gPBKkLj+@rAwYK`VWHo zot{Hp+2JraxkGr1mz|70m*mW55JJ2L*ukHr-{V3R}PpDlW>Q!#3zQ*rv4uFn>6`7LCJBy5u% zwbJP>Z?^arvYYNoB%Obt@@vtWoVH-BI1?=el-BLwTmCdkHMZ)R$=iApp#D=kAsVP*B4RHgV&Ry*(1rQVRq#wY^B)r` zDxsK4Pm$6=GBHx@>5jz+*MH(DpGFTD z2;d@@^(g|Fp*|E=5j4s}D1VqD_?gSPvEYvMA3&1rq{N5$JN;buu3gZUIz98e0_~xf z8`O{fSyWuqo$A#D8Yv+#NJ@{CFah$JtBM}C!suv-%x5%kJ*M)-)uk@0c0sFj>vZQ% zJxZwNs7TaEi~p-GRnvLCs1xerzk@Q`X8@t!#7Gjbv9kS-pJVC+-O~ z%Fk|dzv&KKJn7PFn-$&$a{m;ouzc<=_cyiwP!~l_sbR{pnio`7A?{w(NP_QOx#%NjozU0fr)%dz13?t?Ev3_286L* zdHFHB?QJw-LcrZ)y_VOwQ~Mw5&ByvpU&Y;Hy_mY|tGh8kz2}uLwYO`u)Yoaec`tW2 zCI6lmvy!=ryWjJoZ$6>!hBaO`Y%d-|8Qdtt_*@y>NN2p)jMKP^bjEwlI1Q8T^kSH# zQKEx|TPJ$unSORvBqWr@!4ODjVlE+pgeH1px)H+}2~G6gP}R8~;eI@3Pxa=_l2tm0 zdkY_j<5X`xtsvauarCokUh93Yo~j;K-OLCLf$%t0N0}VfLDVh8->By@ucg-zFp*N3 zx2~>dn8>osTiDE@{`oisXr?^WLEP62f8)lPUQ5>$OuFW+T2}=RCT^VRZO{_zBUX*%OL>f!7S2FzK2%sJkdk z+_=D7aD&qV^iRg@PrZ2+vh)ejwD3tde(KGq4Bg^M&@c8{OAP&!irxscCVZ0Um&wad zhzFYDZ&dEGY-@vwl$Lqb*ytfoWVy^+xr$SWET4*r74oGgMN`jH@LDmj>`8%tp7OFM z?cRm%(JL&ud6Fwp-fDo|^TzT~v+Z4ef7uVF*Z*eA$RnP+<-WTtfBEO1+IRff|1gql z9MIC!|EjB{} zJP>TS+^C`pQ{7SL=B$YVjPcxB=-iFmTA<^}-bwGB$rkV$>&b8IjBr73iA9qq?*q9x z=^awo5|_q%Osj2)SWfU*ldu}6$)Tl+r4eofU`vYS6uCd)P0QhNuia#d$8>vKK|)hJ z`FfcG6C_hS@r2V!-G__Q<^GhnsIa9PFQ$92@t`galdgGudcK5-8>f3PUuM*e7%$4? zzbbiOZ`#tLp`y%VmL@Q9W0?n=3!M%vP0BnNDw?5;x5dOPS*No1*2Hb_n&ru_t#)p= zqZ4@030|!1EzG|iSLS%~Ysy6lZ%wBjvb$Zibh){*xAtxLo7%O<(pv(UaG&p~-B4u* z6WPu8G`R^|61ou=7kI3AK^xh$P{d`cD&A&Ip@JB)E-Fmr<Oct zVFCuWjx0exv&+kZ1UvO*yk3lK(yIoZ8b!#6p<0AA-5u;?W!Rc;&0T__H0XriIlcyTQW@G zY^Rn?%fnVw(L3d~s@_IT@i)q`(_`tK5lp(~X?UwD2Ta_!)6=|#s@{8JVz<1dnzw24 zUU=>HBHPgL*fV1n*|r&@K@(_vyObfx!$by0J*IWTLXSXWHU+p3rwflhbY62n!GjHbQy+rj|D)hq){#N%o<;4qQ$)oj3Fg5fF! zoaVFpVsX$2)%8qudT=6`xJr^C=B-zLVx)~+FP|P{suvR z+f+GBM4`WJq5u;K_qRZneFl5DK6t z+GeW+Q+ZTLtL6(6w@6pL4_cTLC(E+h-iMGc&54t3Gbh5tjgxJ!o`qrJ#>sX( zsm}2pirLd`tX6B^=xx~SAvjLAaY#t^bN#yU<_MEGG$xOUzrh;S(V_{%r}_>b^Dj63Jp`IVeT z$mC(lWUjn+qqklm{ze{iZHu=39=Zh2xpq>Iv@ns!TpMRlwbYHcIA8v`j<fBpVQzNNPzoATW{V5(-Fl}8r!cIn=sWab@U1dCemAD zV+BNmce|Lq*3R36kkrJZfamLZ|k5T@+R9xHscVKZ1fyG4~SeQt8S2nOPLA8qlOFO&AD6qR_X?<_q=J*>o?#^~k zm~_oHOIDb;akq^nD@Cw9rh7Z@s4dUb_ttOL9-e!%!D|oreYUmVz_(ZMMyQ57+Y|gj zTRzsnTe%qwWO^`53lm8lwABfkhg^~AK^u<_sPD8V+C#Q{w}H1>TNtSSL)rR=iG&W> zW={=M`Ns5#HB2OQ$i|5^4ax0kR61hIOAWj?G>3sZkK_<2mm}2CsW;;)?mS{+Ye!2U zoLL;V^Um1v;fCH?&FIkLcy=D>fO4O(ty8Al9aOoEKtl%3CSu~WEhjZZuwkINp0>?4 zS|4Maw(GW3nd)PV)Ak*As-+`#?&s`0Y|}0`^xo8rw(jR_v!SG|`+3{CU^3NPcOy{G zQCs&*w)|@&;KD#jFJ()r_wJXnO+)Y9FJ=3j+Phz|<@K99wd%saoma9WmEOExv3>fW zL2ur#*aI&&@|It2l>f-v?(5e*HSJ-7h3w9SCn^q&DYH3dcXYr*M0xnAvD3|2s_W}6tkE4tS{yG z^}ey0P6)y>pRE?Z%xCz`)_Nw)))eg1^a$8+?`JhHrfBGcSNVFy;3GXx91) z)OwCg(5&-?8ycDyK(pQ_HZsi%Lem&g(~#FDpV-1QF9@PB(w`K~i!pn<&)Vyg<2Lx7 z$h-)M?LNDjfp`%R+kMsV%9?=K;S1knAYMfBJAGm|L%e7pjHn^VYmZNS&k!#v2qWQ% zCQ8>iX8+)`j``$^8+|v|?F@(?e70&AbP*6g_^PN)2AP1^?+Z7OZ8rEG4dE#Wv@YmT zKI^t)f6N3i@8=)jKW{?iQhOx#U?~X z)FC3Tt3Gj^LsYCnWF(CUAJNMZSV?NeVqzqLR#;%{8ngR4)-btnv#(~RD;WDawt8ej z*TC4wVko?H%kHWejdYh{=!W&J zV~usBv(;Bq_T1u2X1al7q+_ci3y^dJ$w((u)sS=p$w;SO3q#Tkd5m(zXeQ}qNQ|gS z$ZL!vB$ISgBu2VPk-UVbbdEJmk<`z;1c-5tjkgL^*Ln#MlGVn(h` zx&xvNbs|6B;(IjMK;7RXW`FEhi{&%hd^gtZf!jZJY_&W@+C6al$4*kMOvuFTA3Mz- zm4Duf+c7IcD~Z`l9qS9}+3u^8DZ%Ya9a}y8rE9o-sS{V{rexywrB2iPO$bU*xX&F? z!67IyAuys20eLNV#7Yi9i3)*{ux~=bFURcFjwx2yXL@?8$hF&jPTSqNntKJ9yK^=73NUx)YVH*P>~X|i z#(c$K?va&t_!@@rNCE+VHD>?lSSKC%^-f={%&UmQj}A{fuOfpV9iDhzMFxi)o_JnG z28SJSj5B!EWMD)c2lD#K5hpkfuc|m0sjhn4Im)GXvR*e`}M%j=58NV;>q z5%bx@{ni-2>0EL7x81&I<_*vc_jBiZ12n_^+_~NW&2T?=t~Zd+2)`J~G;bIhBWfD* z8s!(GndS{eW28b=AODWo(r-=j%b$1qnq>YB2$>7XzX2h0A^A5TWG*ECM)G6*VjM&K z+dvpmLy*^aznI7n|5gx2(&PS{F?)*N`oQmU%g6uVsVQIB?Q=3yavX1hV@i(WO>j)f zalDBfruxNn=6KU^7*TVO*ZY1^#vE@d4kO*If|0>fd%yJwYNknAe!S0DOIJ+>QJLlE zu9gAKEI)U(3}`;|b63kCpV@vfmuWJF#)z7Ryyp4Ee5T1L8Y9(F?XwroF&Fx+&u#D| z((UA#|MA2#dKuA+wm1v@m{+nkm=^jmv1DyPE%c-BMAXu#7dA}xBEMLS{1T3X#t%g= zJ&pQ1zt8+)Dg0tizV@TVr_t)`p3@s-%l%l8&-orBQE!4-?#D!}Gw+S}8en6l&f0)k z?#EQEGwDq*EBxXMhJm7CjQTsjRerICVc@6M9NC zjTKd2y!&Tw@mt@P?(@~n^d*)pxd`3hHUftjFaCZqn&Z@*t0WG46#lhJA^rvB(t_F=zu3bhtZx9j+FF<#W@nC|g> zQRmJ$)6@w43GpzEiCG(nhiOpE+CV%^qoU5bKOr9Ri=P+~ibgbQ&aaVu+%HZtBOHm* zXblzP03e?6TNnM(+3%~J89+2=ax?=dU}thP1E>r(VMWY#2qba{@e$gwC9S)!nn$d1jXoIM&_YGJ>SyR(VbM3h zwKj-g`Ubex22q^)2DsMnj$+L27Z3xGi5{DvXc(hv809xGAO++;q4U}PXspKZX%K;c6M1}|Y^h7E{u3miL7?W*C;cLU}gt+J{v|bvjaSz z4Wl+QJHYc9-ZI63eL&1*7$_RXs2WE3%?pV63T$Sj8|lO-jPhF@5MMG36oN5Y zeFZauVAck#t+MqIU#%7+2xe^n>wVpOMi9)}02cgN8!+nv;ra$<1i`GAw;c8*UF+oo zhkcDhe^Keajbyh6tiAHs5#JvhyiLh&57;$KGH+9|+XLx)*+#NE0^wWakB5C7LVAWB zi6njqSV!fmqnKewQW8Jp!ZebS_#qdjkyK~IyeP0_qVG`|L{s{xt*$=5XV4$=G?z-f-o9WPgtjpLlG z4VqsAIM30!yhAh>0^$-2+ay|#Br=3#}L(!9Mu@28j_MLDo&kbe^P)c;~1bNk6N_6i8 zdDUG?bngUt)r}YB@zP~bNT!3L>5QuBl;7B(c#r9z5S`H)>Mei--Grd^zTAGocVk8p z%!D8>KqSFT2=W3%63m1kFF+*0Obm)C3rBQH%m+a+i(#N>7^D8q@57*&!!Yn8 z7^CSq|6PKa7qrS{%~P23-zAuNL7v&(C75|Zp4r}|$jl4!%=Rw9d>j;?Fbot8V^j^J z{1yboB8GuNFh;vYb-woqW^vG3Ay=RBHO#z6FpG0-<~@R0oNF`h5zOLTn|Y65J`0M^ z83u}mF{*}9e#?SlIm1997^CTtU|bY1Uj(hSY9zR?sB=+?yy=uLmH8sbi^Xw7^hJ;t zi{pssiy$u+#}U!0pjg91P&AQIHIefBGAO=gA}B;;G(EkI2hjSUu(k%>r4dho=~?3` zO6zk`!lxW$bI%%2OzU$|8c$5$1jR;Xf})v>s+pAErl8ouOi+l)Xzg^*ngFQnLF+rY z?6mKJ%mjkjo@+l72xfb({Y)U3?YZ_dfnat7#V&?{qG61xVU*vuL9vHnpb(7F3RU}= zNV|@`LF+K8td`z&{H?ynBWIlPMW2{RXnTXaf}BWbdxN}!oJeT?A zLD6Vhl^sv};vT}wa~4aZNtFN5An#x&QT|7Ryn~%Y z`5z1N4t5gd|5H$$;QXQJ{EezZqx?<=#m^iXCB(XW_48|(WaZ3Xu!5aJ zFxP^-PMbn7*MhuGn?f+xgS<|gLNL8TqHicWpF+_vM%6INuU|+EfS;}=CF93nH7V1|Tv3x-dI$mT8BRDu~A;w{)zf*BSP!x;vOhB2y!QGO#rVkE;r zAsD0mQNc_jm@y%1f)wX{9okJJm@y%{wwVj25zLs-jSsO6n9@+VnVB1=5sVCpcNqqX zhLN)7IbY+L9%!c{m&qZkEab)qjOq~T^m$)AGo7-V9O4}xK3F1~cYM>S?2|*haG7Ja)lo8Cz5ckS5g83rELqr+DtO|)Q83u}mF{*}9eqV*eT84o_Fh;AZU_Kz24Iyi% zoO98a&U`>H8$#SGKOmS5A?}qQ5X{C9_sS0lW>ZLPWf&+L#;6)b`E3h{9Sj47V2t*p zf|*G}`RX>>i)CgK)$Wj8)AWFuM729q>pr$YwI>w5)%1v&MD<-r z>}4t_n#!n}O8M;ziTz9kg{X|CsbLe`~_9R8c{xy*-@ z?U`IZ{gARflk2A+QnqJv*?vgbo(-iMnQT9#Y<~%f^PDXdovl$dp7Q%OBrbALpp43| zg#DTp|E}m}gYIg`>Mw^}_5I(>Y=XO*i^*()yPAv1Y=XO%gPToo*K#qLO>oykqIWnu zxb{z1Xb0o1h~2yRfA7u$0P zZcw;sOVga@5Zs{fZS73t<`CTAuo%j4P&C|N`O*~(GP+}bgj`33t@q^9*L|%Td_=j9 z3Uj}vPbpz@zy64Fc_)kq^SAtl6N7sc_*^70A#A-r2H#j@=28k1!aNbor4%NFc_Nrg zDNGFWL@<{)Cxyin=7gd-jjDsF{HBJ*bPggEYW7Cc=O6QEWcVO#eLQBr^T*6QV)-C! z-(X^dk2J}ydK=qdnHdge8~Hq9nH3hZnFWewF{)-!esjWNF0()(7NhCy(7gIcPRtAA zQFd6p`~;_(MtE8U?qk3%2wO|#_b%tp4e%W;J_iZo1uT94$3(jzj5n~dHfR@y!-djy z-Pa+aPe|q?iHfkbChV5cX84q4ujSa8`IJ&cuIU0{mw3Bi07zM%!% zfcZKc&UX1v2xe_qtY;V~8pfy^M)`dc78@A`3c(mnpJpwfXEa;F*0%^y!9Dra%~oUq zL2U{1`e6YqMB8cO-?2#Z||1%;rDRyV0av=C5x z!q&mCyC6wN5VDuYshn9zPko+wo*hhj~v_PBf>(yeBFrnls^W_Ef!` zXwHVkIi`W4X^g6Al;8QVxWF_}h{kC3RDFF4n9Jdi)hFUENTpljmt8vbERrkoon+=y z3e)A>i1{gjUCxb|pAy&=jF@Iz#+TQ4-6AZmF&GpLW>gKP{H}*Z??`sifI?tK(_4bY z)Ytk&@S@v`UZ-|uF~RhU@RndP!SsvpmS8c#^pEhCU@^fAh=@VRB%vN(L(wos)iBC$ za6}Ac7$^i|wAw04pW(+E?BNk>45AcEKQ1fsol2R{2x)kP=P`T_PBzbDpApjV2#-ad z5z>f=7|BRbG?Gy@lJXlB5u+Ii3LzOyKk8pX@sSa0qAbrxn_faNGQw?o3Bkw+x9KGW zBO~0Vmk`X@h1Pfn%?v;C73A@JSh6B&uNlbN-$F*JX7Lh zb+UP;TuLxgBRo?sC75XuF@s^CXc(hv80A+M5i=PE3c(mnuQom>m^l$^kqkRd`wV`P z0H0z+cn$qI!OV&98v1jB`6$9`=+6meZbZyy7$_RXs2WE3eG(B183qc$7)_t7FN*+X zal~4ImQyd?4j=!^ZcpXcmYeYVMww-Vwm4#EpH43$w8fD+kC+K$8KHd^2{$)AY#E^~ ziHOe`4T?rHszy_O%OYYqqd_4wqxDdYx`N`gI$~{*9B8EfviQ zqS+g<4oBo)g3c|O6-2W)VrS2TR}jtKNWJ!kW(Cpgi-cR6*sLI$A0pxa(?HQQM%6UR z?_fk6Vj3t!W3>NKG%JbbSj0LLk+G1|A+wTbjzxI?wUTI#MR@9LD8Ew?ahhqM5RK9FYXx7>uJByMx`GyMUMsj3av}}BAgXf_KG6SysLnrJ(Iu~iuN;VBR9YXpPY8BGD8o{<@bl7Q}Sw%@)jqp^sijue*;i+&HC2=jnQ{gHq z*!76$9nB8%P;|kJstcz4`b0%P_~~s56e^g}+NvP@xUrgG21R+?SWPg4qhct-K+!Np)iBC$SX2yW7$^i|G(DoOA()X->s^V@$D3r< z5X{IZkEm-1W@MB{)HMV%D#|138iIKzD#kDj6b)lk4Ws-@qhc(>Kp_~T=?4^F63qCh zHC--^I@L2@63qB0pTK`fFyo_q0{M0w)-ieP3$dE)zuU}i*l;`@qV%A#Tx!$8q6M%6IN z@588=!!S??#%OvL{+eLsMe**oKaN@WYl4{<L_nl))D0DC~sHR5#;J9Z&%h4=X(7HqB#&1 zhnNP6rZK9fQGSP`;waNVAsVAKSA+9LqB$P5eu>J9Dd(olMxr^MtEr7db39j58;Rz4 zuBJ8;&55Y^nQ5SC8l!3&<##$N&N2-YqA{922-rk4=cCrOsGL*DX`b0cH0N{ueiPB0 z&-MFFMDuH|-)|zC3sG^IX`pBtqiP!EcO@#WG7S`>F`9m^w3$uqA7I%3LiJPrq$1q!hkO^=w{XvCZvvu3J@@re1~5!*(PQ)4`0 zZX?L4F&;6u5#-bukC@vCa#~Eh&yY|wq)|2GHdkXf&cHDuQ^Jubji%q6*bd0qF+3n@ zaD!7dvz=gO=OVV9U}oncww+*R=OVV9VCLWpcZPwYVT`I_l;6CVn9nRwh{b3vRPF7c z32#vh&ym*MfK~bqf>{*f^SB)Zvna;raXSd6JjUm7I|$}e%z+F8MZ*|X!zjO{_-2%0 zpb(7F^i;BwU{=PguVwG5SPJYUn3XYpL1`z!tc>vsN;?T=WsF}?+DR~9V6I~rC>qA7 z8b>`@=F@8Z|7tySb@vA?( zh~^vo>B7bm<9^b80}d_ z^F4K`zHyuZE>1g7xBH%8`o`_-CoA6*Oux9@)a=5)Cz$^6@SSEC{yo7Ah>Jn-?1TnI z!wirwR(BfXqhk81YcFyc5x2(52{oNN8|*p{( z5KLJdpMhs>z%Of4-ia6oX5)p1hYIYzF-(A8pfy^M)|FZ zi!}@bgmECn3)OdWBFJk=P#M#1T! zJ8d#22xdVJ<^;hk$ibW-n1wl*69ls;AwFdoC>qA78bX**Vt zu)a*t_lEQp0)2sCek?ls8Nx|ItVr;d<0K(gBzVhlk`R|Cc*}8;5LYC`7mNr+BN|mB zl3ycxRYI&`Mo7eHG`(d!1;n)pYimNj+r+7sIYl&U6MX!5ifGm*`1tV@(X31G@#86? zS)ULanFflcF{-9fewz|v3&TJm7^CTR_0KdI?nqdB6Y}nxoazmJCYl`y9-4n9njHxq zntvvmoe8{gv$>Jep+KL3okkk_6V@>VJ(d1rakq~6U41#FiIZ$_nv&a};3vSRDY^X# zegb@&k~@&V@5Q{^*y(^W($ek>k~@*G&L!mFo1Gf%&QKaB68wzf45e`*!OtkpP#Px^ z{L#S~3ed@fIL!fqq62hNKHbD=9ImFaIg4yAB&_QRdGjq!Dsz^yxsc$!-&xA$LW1{x zXDOQt3EumirED%H#1+m4iq6KUx*n9@ZwYaY>j4VYgV8k1FVt}RB&{Jy{J3}e(N6ew zhJL4>k6eBsx;{y!`-SNGB$@6PqU)1nx?hN{Z&D0MW`|-Zn$D=2PWcT?ioqE;>h&-r z!Xpt{4Hc_%pnEH6jR9RKdoKJReCLSott3~{IpTXO$(3}D_})r#C7mO_;Yl$v%a^er zX+R@uK;<|pDMm9N6yh^lTJfC+-`J!z8GPZic_HFI19+YQ$0m6_cAfyoCV4%UIZuRR zle{21PlWF##rQ0t1w{iIRRb!&2}v=D0ih6}(e%LdD*&e@tq%bhMHiv}jR)^+{0=dL zhW{%uPD}DB(XYffEy<@uzY^lKB%c!fN{G{wqKpxtXhfrGMCJEEQp{pRD1>MI8Z0s$+Or&~EqQi>c9Ez) zON!5#3W}yOs-{wY%aUR_Q$Zmrqv@gh63un1l6dWSV>28=Tq2lNNq&rViC|VG`7zcd zf?1X1$5@vLW_410$uLkfj8Qd=^7|?&)-ntff-#yNbuZJuU;jq=-;6g!v(3egy?NcF%gM6)Mp9ZX70 zq))fILNt4lymq`oG<%Y~cDzC~-z9nNc!g-bPl|m^14Yw(FEh#Zfe%(VR{4LhLHhoK5mV>?)P{mn1L5t`g0;q`1H| zP&AEEHI4GSm=u?p1`5#_O&_dYqsDnHY2jCVZ^OarHG;X83(Ylxxt0sfHG;XG3(Ylx z>6H?FQrQ^;iiR<&hEaZfQ=&io^n?M0V2q|8L|qR9W^l?HnZg%Ev=Eb$dol6 zQ1#U`Ch_*i%}qzA_!)gK3qYe&{EWUA{Wj#76hEWyh5xI;E|pzdI!RZl?B5dSQ9Am) zk@TdLRVF9i;oNe2Z~BGFNhxlZy(!sADQ=g&)i2?uxLx+9WT&LWG*FlBw7a?8mjcG5iU z>Pt`yQ#^q7C8&id9zgpN)WQ@GpnZvHQA#XkCMcT8sG3RneU=hSnF$Ip8Lg&%+SLzG z%Tv}@Dr{=a_5TT5KO$V73tK-TT%HSCKO$VN!$wi-M}#XQ6N5a`Ea6BTXm+HmeJT0<-A*;x6^|`$+>zo5X8_UcNb!U-fM|B6!gb}? zJ9%+G5NRApS;ted_dR&xGLTAiAjQ+fKuY64il>Qzl*YjnPZI;FL_em)5iSuFT_U6E z5-Gow8$5;^M!Afu#9QKFl*_nEyd@q+xs0#Gr+!7sKf} zWRQ#?MN)zU1$T6PpRRh>Q@OgQnYgXm zp8nOl_Vn~LKKK6oI{W%hc2!=xg12-@ufgU{r(?cDB_yw1&D=g~X7@v{HV zcJd4Ii}Txdc%iUUUgx&O`Go~VrMyW`wZJ`haL?bpxs963>#g+V$1~bQ)lF&9dCCqJ zccygObZ^|l@I+lx$JXjj;=Q-?gd(I;}R zC35Gj-}PFJ0M*pJjAUZz+qCRFVawECXxE{z2&tGNmCBUL?RA<84Yvj)fiRu_u^ow= z_hvpaegU!R@Kz~H+*pHdtWgcOM1)T>;nIveoYr%x5|jPg7bnM%X<@H%%JTkiiPx^vTwOvEPWld4 zZD_yW#Ik+cJw|@}hw(?=>qw5j^R~Cm;|=XH&*T-mmF#A$YstyAGc~gZ*?QV&yTZpF zyPO+N-t}#7j_)DcDD1@VeAjzlvLteql#jlhy;K|x3t3p?g9TxTY zOn%2M9You7p0w>k-TaZVSND!Tx+!abHn+GsNjp&3u3eYTo%0IXz_QJDmnlx`2@t{#;|Ng=snzU&CNRwxqMGQW$w)a8* z;0J8W@nvNERMNJq_jx-t%3~YMa#;I}w>`N7JW029no&bc<`*T~COe{6w0$`*`BLY? zj>(Q~i;D9)CyNVlp=oxJ8e6Q?lm%e z*SxTTAv;x((#_0p*sBJH%S7^7B-5^tQZN3e^fEJ$!mF7_%X3hzZW7@Gj+=NNR|u^lh#Qq#>Si zx7(vVB57P5<}qsK_V?vq|g#P9X| z_9$D?D}`M;vm>h*l?Pi##R8`mX#>hseqi6;k zlY278dU$An_wF)OzX7;&Pi7hWP82LaF&pJEO64Z`gau}TcO2l&5{(;IP-(a%D-9;@ z8|5}`mkehtC+`t=E9U8!)JGk4fK6kYhalxPZh}+&#vVQ_8Lh#w^Di z=gFAp;W-1n6-p6f7QDxKGH&FfM22%bm8;lI#ErN(&%gv)U~~F~&yz=b_#Qv%^SWnI|*behk=Qs?rr6<10RUsP~}+Ok7<-Rr+VT zN>#d=``-6HpTge&S?yuDB4?QJT{9S*LE#$(;!fI8hGSQ4S(n#A( zB#4WfJw|SB{bcHy{Fe*bb^tjocIC25vPf#IY85zC;n6m~K<7eD7YT`mgj~?WVU{Xw zw-UL)+UhaN<~B+G`%I$g3}`?{Somw>{iRb#joLHr7u|$Rk=$)%?7yN#L>zOcCu1)v zqK&SIbff5;*9p5^ER|{EEW*r$=un|Vn9|jf*D=3XRYTiNkhavXfYA0 zgWWWQSHcuSYoOoPHPyRLSj{202Q3RX$>9hlQrztc%asdEw5;7Uf+AYT02lXoj5wmz z{3SFCw;?rdoP54jnSD3X&eQpDhh zB<0)!Q{7AY2AQg4WYnT60eml&QuXUl_ zA!YYQVdqz|E@@A}Kofr>ud@g=rdsTZ#B^(`HuVM;F?z<9t(-*x^qnU&$9@(8h*E?P z(d_>yBAm)gcF4z;;?{2T?=$%~dC=&FWHEptT^SUZsFOpUKpD}!VS?e1r>vaUVInAp zFs-NTmjn~89;I6UqgYs^<6Mj*5yc}aLUtVNqF%~>xr;oN659f`nQotK-=#CHkWj5? zX?ioPT1r$@6*96El_w=K&Xi3Bmw|PZno2(E%5lsSo{XPRSAV``xp+%`y;9gg)QH&D z-2y}lM)Zp=9ndcjScQRch`M&+OEz^DbzD+iRgnXZ6SR)G8{u%MihlGMm2xYHaEJu> zTM1USE-oRixF#nxn7HvrPiZ-)z(fIlq&bDw8V^=YrOGdwn9_<1>)00#7&iB0Td|Ts+}k`{Bj3(x-8t{|{Jb|FZ!HfK z+T;~=!bxPF)6%xn`IqwW7Tr%`Q4v-qa`m+^%mY9 zf6y-VSb75@jwycJ*eJT=d8j zpS@Ofet3dWr)1E2@1}~+S4W%4rdB)NFdoDqz_Vx_vIkISOxKg@hj2y1fbU(RA;3ht z7;oI&&^6I6#v6?vD4BKE`$3C8H9}`><(R{*3~qU_A8JD*BL*!@r7x=`k=sDsh{*FA zHO;eDNPf7LnY4+s6>*NRFe~52&p&LG=F3}ntrS#r56E(8fQfrXSU5VO(G4b$Mp&3~ z?*&Y2#~f{Cdkqt<3VGleZ>}v6k_u=;Yz_I4o+6-`&NV zNWlmF2~M&``T7lXE3MLX9Yu(etHNiO6dZd~CBp9D|%<)#nR4_Ky0mEd0k*A_DL{>hdC15jt_~4Cn?K7u=4x zOtJ#^ia@{=$<=G_n#gOC)#53!CViULH&ZO*q1>k=yf~W_)2i+jG=1PeGDRnnB6Gxy zw4z%UADER2g`Yf_&}1~^A^B=5Zz|aot9m2X1j-cafd}mdsG>Fwn`UKH7cDy7h|>M= zu3I{Lk=xExTWW)Jr&%~Gpl#tZ*m7AJ3wf(XtY+*PT$w?Qy(S|23{}-EKBy6^))apu zYO~y^!9)$svH}so1`{{VvZC_P047jpSy;-p5b5CJ9E*>>-)t(jgwKfSWR6?EFoi>= zT!X>H-E%B#3F*l38H)Kli$@!qx7Vls1Oa)^OW#dlnrC6#pswQXc@`$4r^Vf{=JWOU zv!_r7FUl}KT?Q}GnQsN;;R~)Jo%t5d@MSt)y3;D1YnK=4;MNk04=pmwAR)|;uv8O* ziG)hh3DGvG#B$^`i>pYe#KNMu8N&UnV=l5XR`D+yvxn+D3&%xPChbhU>{-h(7hA?s zi`iim^I1A%c7+;1c$TVTIj{Kut5OGlqp2*n47nkNiIkRGRqht-Fp=eQt4>|}ad9Iq zuH?NQV2{*&Ry6mOmfUn8O$z-=t4tMjBZYpYRlNol)CPObG1piboB3}Kuv+z=gX0<+ zfQa=uuzqbB>lEvAf>jCF4zcHmb%SNp<6RGUYdj1S0fR0398Az}umX}ECeq$uRgpsz zOh9a~>fS5PC7z?|-NZLO$m-R_-?(v;Wn>9lm~_nw%H1_g+_=fgy$z?72K$3!ZnHAJ z<@Fw7wd?!=j@xJqBKkjoe!FFSqv-!2=#@Zc4Syi|-F);z?9tlz8IEm9hPLl`oziWwuER6}EfqDoxdM(ce+&!}$V%Rocn7V>+Gm!lRVI zgEhU$Og9a2_!w{FXji@9;M*pdi3i+3=5#(U^HJ4pe;^!+oCIJ20&wn3yhUgMjf` z)`DJIr)M}dWZsMV9OTQqi$`ym#i^!9te#_z@MVlaV)y>@JpLc+R4qW{5pz8xJHnTF zw|y6yT|H`cBYj5g-11^}`L|FhwYjG|&>wJ%JP3h_^hWw>)D+zVCK~5RA0CO&y1X8B zq|rVjq1NT}u&ayMWKaa$YWQ8rX&}UH=m5A+p8mn~R3_V!LqBX}GTIl5+l7cqeVkzW zGCo06+K!mxZ}KOsbqdrAxl8<|4o41L_2M_=_$){EeFp=COU#8@P3BF0hM}ug61i_r* zGp717b8}^s-FE=~Ps!nn=)$>y2q#P=JjK@#FC`Ed3f>k6pFY#F@l2cDhdI@kaZhgD z-?fJS+LT3fWMjAw$xlNYvzsHi2AGz88T0v9znPC@Ho%n`zD&G3VOZq@6om#9g}E{c zVpfEJD9m-E02BAjb)x_iQJ70nNQ%30@pC@-hFPX|15wJ)eMTiwESPl7S6MEvVdBQm zefQL{8-uDLW-`A2hS@055MCv|%$oLKdT!eYvzjmCYhLrF*{^ORTw3H)vsxpX)qKWM zRh^CGtR@92NFyr9@>>c50R>s^76c}eTkaMFCU}-pLGs1jxVY+r`>nh>jYK9cwX09+=z=icz45U zQM<8-!45YDFzK4lFBeuYapMjj7FIM_J>ali{H9^mF7*Jsc2R{jf&8Fj?(tWIbV`%D_Klw5)^S}G7yXrQ@m7jcAQ&WJNg6W*k zxS*Jt3MM5`qi;$~7jJ3w5YXr^x{V$tlDp_OdYAyaNR6K6iKf&-FY&uf>yEnk8zs8r z&Hym!ny{9I3ZVf$W^ISM1&s20>aD~Zt_28TGi@2iqr;~nN$WZdlV5GY#KvNIc+^E zBBTTwav!BCAI2NmRvG?Xz%m~g=EeqBP=&+HfSk2qBHdwTt-7KLVS;j)S^qv!M~^us zH#0tzbyW8;cyd#XFpq&?q-l&+1dj;sWiV1i*R)ekMKAIBVTW=!V)8?b8Ec^r=8Of}d)4utWhF;O8rE)bMJL3o@HCW#<4 z#or*9qua2f?SN zFFfQcZ?aDxC7DbH|&022r^+#o#Rn6u4{1u_Uvz;U*T zHw6gc2_Vcdjd=>;34x#l3c?eFFkc1%e}iDYO8^r=nC}JwCK8_S1_34r=2H-EM-ZNL z%rDH0<-B~Bl^c5!j$cq)q#^G~%3%?2nPt`FXYXTXtwp9<9N`Ka7n$m43MNuqWMcb5 zFOxh;1IrRVBg?8#2Y&-)iKz~YVbV3Tw0sl@6E`j~@fHg0@SmhT`Z8_=t;g%)Z$xpK zsn&}yapN*mEf-+JefsmcJ0Tl6CPnSl60rQwtL*tu@QYqhgrIa;;gZvRwvQK83m3%-GKF3RyX^ zr{K8J#M2FWkLxK)ax?!^$f}yc-$-(^X~^gFFyXn`Ovr^cOeDG4#1lt)+2|=M;x-cx z|Kd38$6IN5Sf?I$VUs+sv}Ji}wIpns{_(X6)m>uvN8A zOL*=yGwC?N%UU8XyG-L-6_=JGE=r)dw4_4p6>&-7Zxmv$8yA>JX|Ed>n8|Y>`U-f5i3@Qj!;j!<8^D4;78LqtxC{Zlt2j- zM{&e@R<&6i#GCu3Z5of0t%#2_8U!qp&I&}1T!42Be$ns~?_(f&; zw8&Bk6f|)(dRfLy9)VtVVD;h-iH#mm>E@2hz`%8RgbNdicegTScY+Cw?iOBqq5ji`>bz$P2cPIbE%2T<&gbzy zdAnCT73LSX=eBS{W9aFgA;CljJ>5e$n3RU)$TursB7>e5-mIV@w+)R?y)C0mu9JMG zIIq+1A9GS;g2z2b4?Zq1G4-}EG^oo+uQwen(Mkztrv0ppAr_w$vntf3L(_hi8vUO^ z-1=L_Koz%VMBJ1>gAC3_9X8nFzs68mFwmk0yT^F)7wAJUigjMofBJ+tM-k}4{ zl8KWuXL*fHr%RssG;{u=Rlf7s^r+)M9+5Z=`OwVF?8^VqNu8tQa?j4poU;GXm2bR| zS=MlB^5~w-+v=-FWq0X^7|yRA`}Y5`QFj{fZfi5+<^IdTSmPs^Wghr1N$KOOzrNqZ zwlkX-$O)tSj_-#VPgSTVAn$b{+qtI`=>xWg|LWAu+2ygz&{d`r8!JL z*1E2W8hUD!oG9(D`e@+OpvDOve;j>CG|}*(b$*x!}V$NKqhm{rGB>D zZ`v{2FNE3aA6i^Nt7*3tezpp3*x3j-+Sw>6PvHIyXX}1rD{!B~8Ct{C%i^teHx~9k zpHm@`zk&2?zbTH>$pq=w{wggL>EA%Q)*r5?NdE@Xb$+%{lm3m(UQeJ(l%V9b$IKZ#X_K!@x3W{&YUk2z^P<-RB_^4}wVy8b`M^U_r z^mqB$9!>G8qEMowAg{fCwqJ+qRRN)-Lc!3%F%S8TlYZ{Z_2*aV0Ek0=leVo?Egb-H z$e%0rQe*<+us>Xb-~PG($>{Uq{*I1$%5PjmxF6%SZnF zPWj8-@0v*Gv_E_&--C35vXvD$<|V&zldqfSzdKcc+b{V|F)||G0^EMdUs~+S$;9oK z{o(2=76mBf6+gSCV^N@Dp+p%A^1AM4H*_otL@bnqml=_8p<{Nljs7-&^K*ZrR3RX` z*``>s(?vjZv&%oM%1{W1?smAgf+$4tJ#5z7b{lY^f>5G_Ag?|)dk zhxM}=c8ej~qkiH=0gh9AxW8(+Rl` zw%Pldv6Es{qGTkmVK)0fGj_m(t&_)Q1uQ}!@+nC7X3;g9%uYqxt zZHfnjbPbH7>>N6-rY80p7)RUTsw$4JA@?yh`$#jsrWlnd8Odv$%_eBZ*94=I?iUQ5 z9doK}%(QuaiN8GGxxk;O(itREZBsnzMAn@_GSv==_oT=K$yB>)J^r_aXq9-Y8CnrG zqqgw{f2G7opJSWi#d^Ai+vnI3w|x}h_Br+)_w%eTa68@!hD$MqPTN?`!wdcQ zriyX<;&dE}ar@$Q9Ex%Kl5`x3QNpD*Tdw0!tm2?V83*!OVY5{_4#gr4O2R7_VCaJJ z&^ET%eAPmK#Z(tSd~KUzAxRej@wJ^0djK*4@wI*Lld3YhAo;a6+n^!3C5o44 z259!#rg)xBUZB}$J7Tv)CTRBAcRj3V-atP4ZT6j}c|+0c=jRsrYXtG52i)|gV}5TN zXZVaS{STzxMEc*S(|;4`f1ghOO{D*QI{i11{xO@K(CNRa(pRD^EqR@^*(qJxH$`cc zB!dheb{btXsm7Bxsjrft`s!OaY_*LmwwJf~t5v$Dyy-2lo=exvTVOqxu9>&MdJZ+i z!%O^6hVYsJiu!lQ{M9yk1bDT@{tBtTl=6xzsy=7#h$W?j1l34d@Q{4j_gGbccHf$-f_9!!^V^3POnzg1kNmFs>oq5fDm} z9qu0(Mgzvg06JW?20pjL{R1##0=mQf12AI(y2JegFk=F`!~FxyrSRrfJ?4G)k%?>L-O~g8^fHfVW)cua%)A!;N|BY_LWb`$-XP? z!%GOqwZT(DJgyC(5(0{d#X(o>>dl1#_9gO)+csJs6uBN%>MeeY0&EHV96M9`(SlTI z56Euc4OGhm#yUQJx&I!8)Qymq2hf)#0)D4PHafFw18GG7y;)|}jgVFb*jE|}ibPWC zEq-ePY^_FuA0a8Na!drOJCHU6j2&n@INjhwoB3kA(+aP}o6sAoxz?!>i$x|VB$ zYXkKynO}F}+8AJ4G#3=frPN#ewg%XC%>_T=Qd%{E)dN2PG&u>W?r<Jl&0qT~avi5sZ$djCz&ce{ zS#Lr+AHYJ@wSjgafR(DuvNxgq9AK9;8Wf49)LZ;62iR4O20ubmS_OgD2jh@=Jz(_A z^5S=`E!uV|5J$-Py={jMI6{^PeF*V7jhLwZ>m9)sNycDjgsEaFuKwpMDgdQR(>fBbd=yY^;WXB4Ly&VT9j@S!|q! zfkH4!Yb;>;(*QLw%b3pp^0mKy1N`I{Z~bOrOD&s9e}b8qg*~-v1Lor_Y^r4@{Rw7L z7Msd%U*}KUHx+r^a^aIKHcjUaSIS*!l|=3Xkok-(V?NJai#5;y%6vwa8B>`Jpv-4v z#j9x>na{|o^SH`m0A)Tii_OuIfFdKIRLLUz=4P?aH47ACQJSolffR{_S;lg{bS3=ZTHe?w)c-i&-nk@zs%!YKJ1{2JNbf5+k%*J$}1{2Jt zEVf0%K#?$;_>B$znxQ5l1VfO^&Mf0QzGc0?QS~8|%g!vbj0!(~cujWMhT29hyRyQy z_=1i8R@iLN#&IZ;IFx0a`WrQPpRi74 znYSy9_X+D%R`MZj1M75F_-=*uK4G27V&{1I&FFIHkgMbn7tUv~pEU;*;!v8LX@>#h zQkHRp$2MW29Y&d7N{1D{^Cnw|br@CNWrS7bF^n?5lEr@2F@YjuqEyKu{H|rO-!uyp zVo{o0-43V0xm(b9FUXf~LZ8F0@Nsw+)N9D$MAI#(*O0@Brh8DYA%_!9k09$6bSEw- zl18bLM)>s(vcB+>^A;4MQQCc?hCU#g0YPI}kQZ+D*G_#vGy{S-2a@gM1ELua#6gg2 zgJwXmYCRR34~S-9kPX%}P$Z2~C5`YK5@hdd8Yo1gwDN*x1kLm#g2snIFP;_CfqL5x zc;4N%=vDg7j3gLAgd>7_%{GDvM+EhnZ3M+@L{P8UMi3znvQe54iX>F3BouzNMhDqg zjR=Vll_uAd9E9V8##9ih<>a}Yeji`9%kNKdLK~k(!%LN9Ycx)1Fm>LS#z2M+9%A1GG+q38_i0eC!$%U`v>=7YnDrAj*CH#5j)YdR=Ir?eV!038Loc|rWjzS1^-C47?y zF!O?XHG+>IkgZoEqX=eRP_ITt5zOa7R-$2`NEoF`7~!`t$iCDtPzXk8az#0sroAOW zoIY;e<}Z^PO)yJ>dS^MBV3q{+&T=%tED7qJ zlqS~$V+dwV(AdN~ZTFYs2X|mkvL>kKIeg#&uZafrJU515)&%uDH-=!o4zhI`28x7H zs)P}K>w|2ghJivbO1oP$qOoAv5;Tn6JifzUBQ=&HvnAbT#uChybekDVFk8}XW-P&M z4YC~?28x7Hs)P}K-vrq%4FiQ>lqQFR4{12q6EqHq;ouQ`TM3^i!pE8l+vjy~hl3A^ za8FRLv_B-mJwd&a{E!Iu1ocY$Ln7Q8WcxKC6iKL5Nhtj8I1pq9H6kQJRGOS^KLX*A zpm8$j<>PO#T-DV+BB&$juzf^ON77;Yh@g(7!}bwD9SyQ$8VZVpQmTX!e#e9Cgl2+5 zOiF7ZyV^KFoeCNkcyOox(bPDCIhAfu;|S(dxZLA}Anhfm1X z8{7#*b2F$nxD$w`Ylw9ZxsxyyNuyLrBm8=VSTFd=Nf-*zC{1o2ClXEHkTEpGTkXc~ zaU#+54QZN*MAJ8pN%nGn)@ut}85#1NaLbAcjrQK~Gm@S7ZB zQ+1J{P?44PoM4$uMV=lq=J99t_*=D@Ofb_!W<|B+m`pI!LzN!aHefyrh3l$C&}4#{ z5n{763=|17gSXh{uj$CabqaD>5Hgm8yt%hiSBUe6_WGUF6v}QvNN)kBP<9JKdJ8y( z%Dy0^w}4Y9yOI$5LT3j>W~WqHcH#GBh%MG-heBmnn%n~7x41ai4B=q&+r3zhP9>P- zA-(6I@Akmfi{q&Tvm&Gy$5RPrWr%&HVW3DDrAip#wtTmFndGnfQEr0VU#Ligx_}|c1XiO zAsD4q5irvT=KByHm|frRPqvswFyE(p+i#j_fGCQTpvJ1aoLhQ0GI}|Fr(pm_p&uATREoAfv$E_YUjsIfK}HhjsUyLF|3Px_izbq`qO+A4u{c2ozZnefgQg{+eM~?=u1Q ze%Kft=3Tz?KiOa=W%GVmFG6QhHt&b^B6KFbFFPz8b`OeXQntgw*){p#BmP$5j-sHm zkmg5WV@jAmf6)I@Y8IvWQCJUi_~aJZdYGFbn9H9z?K)nhmplB3yNB)KeXe9fyK!C*F*l3W(n zON_acBFt9l@rYV@LL^bYjpXbP!~|z1EP5IXfRtJ zHn#KANBsAt<`K;Lu%2({5zP9qo^R(7%!aU@Z|4!r#xUEgVW3DDrAip#w=sm}>!XV|<=1@m))*%>Z#pSA(BD;#!f{&Rxa9cFtp3=|2Y zR0$*e_J-Mh4FiQ>lqScf&#NIhb|8#b5W?auJ~)+E!vBic%?It_FwXa;eDD8bYCh2( z4&yB+dH;N(JsiesPOc5wBVoMfBq`<-?a?qhrfH!_TBY9NcRb8aXj=FYtlzITp((9qNU$vg+Rd=hJHpQ$_g79WB$}J)F>xW$ z+)R&&3n@@H!+PnrkZ8I_SoerKoI{Z`N|iLiuSbOSf}h-XKp`5X$u-3nV&4%lhD7+- zAN;qcz95>u5j{eGK{S0MdW8OhX!=Gp%@;)TUW5(MG*BdsQYDS>8yI1OH4PM^QJSRr zGJ@C)ix^`f^jYHPI<;@h&zkP>4opRYmhz44A1A!*H8U~7lQL2OyexF6yObr8tV3a0za7*H#nG-R-K!ue` zK370*VCTQi&!52dehHDyiRk^@5+a)u(fhe2L^d~~_j5~#Y+i)T*JMy6nNlU0@LLdJ z3pE)OB2!vL5vHXyq%VpXEBS#R{bf^231(45&!0;PW>G}XpGyg5QA7_%O9^Iige}!D zP$Y~}C5-S}7GWzi3>1PU48nMlh?>&3PHY zd=+6|YZxdJMyV1;_^pkw^%@2W!6;20$1Eq9%@LgMRX&Y3e3lc;=7^qomlMq9h@N?u z6U>%~o_UuO%+?6op<$p%7^O-W;rC61?b0w%2u5jgm9v6i_C}1O{M*z1mZ=p4vp1qo z8t?&6vh_*B3WC`e(I*Wn2xfnT9n>&TB#crejPN@YVMjC!6oOHjd=j#fV2($Ov;4C& zc)@2S!5oj6?)w8P3FdgD@)K&FSV=HHM8fx~iFGBxoQSX=H4GF9qf`ka{7yyK84Ux4 zV3bxQ+RQ3~xezg~@%Cr^9jmV*mGYg zh!}mN{Et8Ro26D$5;r2I`^v^@O5$e3yjLZ;nv&=m4ZDy3S5tYqMOn|NJJUds>fhnLvDBAWhD`~pHQQ@$da{!#n{!nHv& zAd25W$Xvc6nt@R^Skpj}G)le2Z%CBAuW8^%G)nt}pjksSBcjGfQNHh-zf}sKmc=8Z zsNUVKA(|0Uy}MmQvEfm@yIn&xBcp7zrhy`9lqzY2-9J(nCVeH-K`~<&!TLmhJhkslqz9_ z->fK`qhX*BjMAEl(5xev&!gA}g@5)pPOT%D&!c*PSVu6QNA&=)j$r0T^#HMsU=~E# zLJb2&!YEb32){3)Y>|e6LNH2`L+5&eSsFFg@a;ci=v+@QOQU+|Tu(4dqk8CEPcX}( zdgxqFFw3KCrG|kbVU#Ligx{(t`%1$=AsD5}hejJ{jk_*tY>RqxWAw0(ce{jXcLPDK zi|VH#8whG$R6h;bKv3(V`f11pf?6MCn=}*@38hpCCHyu=*;Wk&g`kuspN(#$21P}MZzdm!U(@JQTCH&fkG@wlPp`P{w_p~tGw)GJc<12 zD(de-RBs2i5X^(zSGx+Db^*qwI!8f+CTWDv^ZW%_!^U zxU(1(LQtj;;R47)DEH<>FCbC zgJ?!Nx~6sz%}7Ug{vAX!3f}_JG*BdsQYDS>`w$0%ng$BdC`}##d=moAM8}wp#&bvV zsY1Mo!E*SpU;U+0-w@bDN6%>A5ZFXV&uHHe*vF2Z(Y_(DN!V6tFenmCsS-^1eS*nR zgFzuMrO8KqJE@D!bny0H?Q3`qVkg1OOxM{?f|;4Fvz-JpGhJsp31${{QW^$|gi)%5 z5q@(qIcgXv1fw+B-**wrLI-dE?YoA-XBWXNboAh}i(nQyF?W%+i(nQydhpprFkj%8 zS{eq5gi)%55q?WB8EP0P1f#T8q5?$MzWh$M?)`fRaZfs4dkAr_!}e=LC=yYr5|R9Bm+% zj7qB~YHu&%b<{CV^WitJyvLXM@x^jSAF}Nwn4^w9WZO$H-#hw{Z7;zb!_1~(phy^{ zN*LjH5^pPL7AVA`G&ytZqssfqF)r~ZZeqi|k6?as^b2GA2<9h8zm>I*V19D+3uF5T z<{V}=4Fg5OC{@A;zn^h6kpsY66HDW)GT;D06{TR&VlL{yVw`r*PM zqM8z8pXh)?kx)vNwJ-cWjj`z(3JO6ft-XLcOdVr(48Pg!?zR6nb(mmg$IKdP(|wp= z=ETg}+6K(rSh$hegdZlDc`-I$!$6TRN|i9eZ$XSL)G$y8Mrra@=OYBOIEMeD>DiiAT?55zL{O ze(-dRU=GFfgQsH@qQf!$;OQ8_9Eq`G8U~7lQL2Oye#c|%goc4aFiLAFV2%^a>6mem zx3}!aQ^yJBbW9Jb#|h?iOb@EZ3FdT6530ur=1h#8(=bpZj8Y|x@H-!4KWi8$1fw*0 z()I(vT#n%b4fkZ)kEVVgn9J#r=LdqhoE~|8AehVPk>>}3xe{a7Gz=68qf`ka{I18? z4GjZ@V3hW}h|CEZ;k(C;{&9XY({7zQK{Vat`po16(R7dNGm{fU(<83WOimC@&p7K7 zcZWeJl18bLM)>uOvwrZC=YCL#Mrm@ndy;4d#SI?kyZm;=)JdWl6xTbZlSDHpu6Il) ziDqzI@0d;!&5$@7rfHx^8l_4a;Ws?aMrax+M58qM{?(5}GdgZejPtIxT|M<9(TtAk zUm@T>FdUx4ng)ubQL3a7e&gb7f~J8&G)j~2AfBRTGdXU| z?%!{)H8U~7lQL2OyekF1Cg@%Db zFiNW{+VB~ISrW%d+s9erNE^G06|FAmyeQfG;4U0e@=_#bA-)&f9Kg4rF%hh?6M+O0y5i!{z5iNkRW>6b(H!`07G5{J`CoTDTTr;|8G zNgSahE>aS5Ep{GBoQ`84UKX~?v^Y;moQ~^-|9MK{bX+g|&r=d-;(FnKoyhOGYuy_5=Q?7#|pY!>SvPxD3V60l1BLTO|X9OlS3^OqEVVWoVtkrRMZ@l zFh(T2d~*w{G%pvm&D2F=8kEq_VJ{NXpoD%7dy$w1C-igJi^Mb}!QR(QP$ZL5C6n+Q zmS7)fCMd+DH2LuS7uvawN*EJRZwdK(9{nGC;iG=(ANBk~n4=QA2Fm{269QYB_eIDV91<25G~;#68$(GM;Gb5g?i447FsEy8x6xZo0X=!Omwp(-QVadNtX$4RvPaz;muC4c{+GdB%M-a@Pyyz3ARAf zK_NP&-7aEx1$18~@Tk3Z+%BEELNH&ZEAk4#e3`DuE5!0;x+1R-%c2BZs#%~&7NtrS z;kPWoR%jL|#G1P zfc!uO+yAHEHTG-7xi6vD($|P{UqY{?uMy|IgkDQuBhLK^c2ILdk(~Sapwf2D((>rx zIslF*jI#i!gO^1*{SG1ZJA~_ma6H|Jt`oxXbR)V>2*=Zn=sF?%kYGP*1Sk?gsWMc; z?^J@F(V>Dup*j_pqMU4g1L^sMafOe}!KnTl!JJR%gY4f3=6pgQWdBAm=M(xM`!|BQ zkYE=z3=|2YR0$*eeo3&)8U_l%C{2#)H)xKzmN0r|^Tws@WaCI5>4N1y-d7GH1B2W&BaZk>6gt0Xc{PzMyZlU_zlcvgEb8lqEVXMk#@~A zFyCYwW3&0X(sqLuUFnOW!?N*JI=LV2O5bK3mW|KSxi)BqXXCqcGMBCfXgGJkPlayFZ)O9_QasWf@A+8w!1&&E$vn%!pKo$5}RPfwSzJ7qpS zUB>Q|`Df`ecBjl|WV2bij8J45l`2_;-|TEQSF=DN7NyBYPCYEZEXX#N@@x2iqfFhVU#Ligx|Jo zwnM`}AsD5}-%a9M!+0e*+c?ParNrB+_a>O#+4^X$H^J=A)<n6 z`2CR0PU`YNq4FqAmZvY3=S;To3vYY7eH*_E^K0r%x;%Xe=1jUgeF^3a%A@*OUxGQC z&CY8WC=y1g5=Qu4$YvKc3>1Pq?@y%!Tqg3??)%1e)Jqp;Bbim#t z#w+Q7y+@2!(gAyq7_Vlt>zWaYWK^nT6n?*Dvzj+GBqV}_NSd6&`T?>>jxhj?F+4ke zy1ZQ~)sK*RIm&5wQPfi_B zh)8L2E6|??>%lq3$Q=GedHePj{fTC9j-ERD6V2cpJ$3XanjtxQ>gZ22Lvz@0O#?;J z4CS9!!eITV7!?K}n=v`Y$2t65dHccC0Lo@ej$U;QplrtE=vCJM%4STCUUdziY{urW zaXK3)G8?7JLJPm~Ic%aXG!!bd(&YYnpaYsIImS%XL;ZhV^yH=n65Ny=-Kz!?+>{*M zs|FI>lpNiw1`^!V95zkEL6LAum2kpudJdbR;h+$l(&RgNgQ!={$-$A&KPuX_Q-cU* zPP%yvBA7Yp<}rw1=BAs+AcC2f!xm^5C=y1g5=Quy1P determined.api.v1.LoginRequest @@ -3721,449 +3658,441 @@ var file_determined_api_v1_api_proto_depIdxs = []int32{ 102, // 102: determined.api.v1.Determined.PostAllocationAcceleratorData:input_type -> determined.api.v1.PostAllocationAcceleratorDataRequest 103, // 103: determined.api.v1.Determined.AllocationAllGather:input_type -> determined.api.v1.AllocationAllGatherRequest 104, // 104: determined.api.v1.Determined.NotifyContainerRunning:input_type -> determined.api.v1.NotifyContainerRunningRequest - 105, // 105: determined.api.v1.Determined.GetCurrentTrialSearcherOperation:input_type -> determined.api.v1.GetCurrentTrialSearcherOperationRequest - 106, // 106: determined.api.v1.Determined.CompleteTrialSearcherValidation:input_type -> determined.api.v1.CompleteTrialSearcherValidationRequest - 107, // 107: determined.api.v1.Determined.ReportTrialSearcherEarlyExit:input_type -> determined.api.v1.ReportTrialSearcherEarlyExitRequest - 108, // 108: determined.api.v1.Determined.ReportTrialProgress:input_type -> determined.api.v1.ReportTrialProgressRequest - 109, // 109: determined.api.v1.Determined.PostTrialRunnerMetadata:input_type -> determined.api.v1.PostTrialRunnerMetadataRequest - 110, // 110: determined.api.v1.Determined.ReportTrialMetrics:input_type -> determined.api.v1.ReportTrialMetricsRequest - 111, // 111: determined.api.v1.Determined.ReportTrialTrainingMetrics:input_type -> determined.api.v1.ReportTrialTrainingMetricsRequest - 112, // 112: determined.api.v1.Determined.ReportTrialValidationMetrics:input_type -> determined.api.v1.ReportTrialValidationMetricsRequest - 113, // 113: determined.api.v1.Determined.ReportCheckpoint:input_type -> determined.api.v1.ReportCheckpointRequest - 114, // 114: determined.api.v1.Determined.GetJobs:input_type -> determined.api.v1.GetJobsRequest - 115, // 115: determined.api.v1.Determined.GetJobsV2:input_type -> determined.api.v1.GetJobsV2Request - 116, // 116: determined.api.v1.Determined.GetJobQueueStats:input_type -> determined.api.v1.GetJobQueueStatsRequest - 117, // 117: determined.api.v1.Determined.UpdateJobQueue:input_type -> determined.api.v1.UpdateJobQueueRequest - 118, // 118: determined.api.v1.Determined.GetTemplates:input_type -> determined.api.v1.GetTemplatesRequest - 119, // 119: determined.api.v1.Determined.GetTemplate:input_type -> determined.api.v1.GetTemplateRequest - 120, // 120: determined.api.v1.Determined.PutTemplate:input_type -> determined.api.v1.PutTemplateRequest - 121, // 121: determined.api.v1.Determined.PostTemplate:input_type -> determined.api.v1.PostTemplateRequest - 122, // 122: determined.api.v1.Determined.PatchTemplateConfig:input_type -> determined.api.v1.PatchTemplateConfigRequest - 123, // 123: determined.api.v1.Determined.PatchTemplateName:input_type -> determined.api.v1.PatchTemplateNameRequest - 124, // 124: determined.api.v1.Determined.DeleteTemplate:input_type -> determined.api.v1.DeleteTemplateRequest - 125, // 125: determined.api.v1.Determined.GetNotebooks:input_type -> determined.api.v1.GetNotebooksRequest - 126, // 126: determined.api.v1.Determined.GetNotebook:input_type -> determined.api.v1.GetNotebookRequest - 127, // 127: determined.api.v1.Determined.IdleNotebook:input_type -> determined.api.v1.IdleNotebookRequest - 128, // 128: determined.api.v1.Determined.KillNotebook:input_type -> determined.api.v1.KillNotebookRequest - 129, // 129: determined.api.v1.Determined.SetNotebookPriority:input_type -> determined.api.v1.SetNotebookPriorityRequest - 130, // 130: determined.api.v1.Determined.LaunchNotebook:input_type -> determined.api.v1.LaunchNotebookRequest - 131, // 131: determined.api.v1.Determined.GetShells:input_type -> determined.api.v1.GetShellsRequest - 132, // 132: determined.api.v1.Determined.GetShell:input_type -> determined.api.v1.GetShellRequest - 133, // 133: determined.api.v1.Determined.KillShell:input_type -> determined.api.v1.KillShellRequest - 134, // 134: determined.api.v1.Determined.SetShellPriority:input_type -> determined.api.v1.SetShellPriorityRequest - 135, // 135: determined.api.v1.Determined.LaunchShell:input_type -> determined.api.v1.LaunchShellRequest - 136, // 136: determined.api.v1.Determined.GetCommands:input_type -> determined.api.v1.GetCommandsRequest - 137, // 137: determined.api.v1.Determined.GetCommand:input_type -> determined.api.v1.GetCommandRequest - 138, // 138: determined.api.v1.Determined.KillCommand:input_type -> determined.api.v1.KillCommandRequest - 139, // 139: determined.api.v1.Determined.SetCommandPriority:input_type -> determined.api.v1.SetCommandPriorityRequest - 140, // 140: determined.api.v1.Determined.LaunchCommand:input_type -> determined.api.v1.LaunchCommandRequest - 141, // 141: determined.api.v1.Determined.GetTensorboards:input_type -> determined.api.v1.GetTensorboardsRequest - 142, // 142: determined.api.v1.Determined.GetTensorboard:input_type -> determined.api.v1.GetTensorboardRequest - 143, // 143: determined.api.v1.Determined.KillTensorboard:input_type -> determined.api.v1.KillTensorboardRequest - 144, // 144: determined.api.v1.Determined.SetTensorboardPriority:input_type -> determined.api.v1.SetTensorboardPriorityRequest - 145, // 145: determined.api.v1.Determined.LaunchTensorboard:input_type -> determined.api.v1.LaunchTensorboardRequest - 146, // 146: determined.api.v1.Determined.LaunchTensorboardSearches:input_type -> determined.api.v1.LaunchTensorboardSearchesRequest - 147, // 147: determined.api.v1.Determined.DeleteTensorboardFiles:input_type -> determined.api.v1.DeleteTensorboardFilesRequest - 148, // 148: determined.api.v1.Determined.GetActiveTasksCount:input_type -> determined.api.v1.GetActiveTasksCountRequest - 149, // 149: determined.api.v1.Determined.GetTask:input_type -> determined.api.v1.GetTaskRequest - 150, // 150: determined.api.v1.Determined.GetTasks:input_type -> determined.api.v1.GetTasksRequest - 151, // 151: determined.api.v1.Determined.GetModel:input_type -> determined.api.v1.GetModelRequest - 152, // 152: determined.api.v1.Determined.PostModel:input_type -> determined.api.v1.PostModelRequest - 153, // 153: determined.api.v1.Determined.PatchModel:input_type -> determined.api.v1.PatchModelRequest - 154, // 154: determined.api.v1.Determined.ArchiveModel:input_type -> determined.api.v1.ArchiveModelRequest - 155, // 155: determined.api.v1.Determined.UnarchiveModel:input_type -> determined.api.v1.UnarchiveModelRequest - 156, // 156: determined.api.v1.Determined.MoveModel:input_type -> determined.api.v1.MoveModelRequest - 157, // 157: determined.api.v1.Determined.DeleteModel:input_type -> determined.api.v1.DeleteModelRequest - 158, // 158: determined.api.v1.Determined.GetModels:input_type -> determined.api.v1.GetModelsRequest - 159, // 159: determined.api.v1.Determined.GetModelLabels:input_type -> determined.api.v1.GetModelLabelsRequest - 160, // 160: determined.api.v1.Determined.GetModelVersion:input_type -> determined.api.v1.GetModelVersionRequest - 161, // 161: determined.api.v1.Determined.GetModelVersions:input_type -> determined.api.v1.GetModelVersionsRequest - 162, // 162: determined.api.v1.Determined.PostModelVersion:input_type -> determined.api.v1.PostModelVersionRequest - 163, // 163: determined.api.v1.Determined.PatchModelVersion:input_type -> determined.api.v1.PatchModelVersionRequest - 164, // 164: determined.api.v1.Determined.DeleteModelVersion:input_type -> determined.api.v1.DeleteModelVersionRequest - 165, // 165: determined.api.v1.Determined.GetTrialMetricsByModelVersion:input_type -> determined.api.v1.GetTrialMetricsByModelVersionRequest - 166, // 166: determined.api.v1.Determined.GetCheckpoint:input_type -> determined.api.v1.GetCheckpointRequest - 167, // 167: determined.api.v1.Determined.PostCheckpointMetadata:input_type -> determined.api.v1.PostCheckpointMetadataRequest - 168, // 168: determined.api.v1.Determined.CheckpointsRemoveFiles:input_type -> determined.api.v1.CheckpointsRemoveFilesRequest - 169, // 169: determined.api.v1.Determined.PatchCheckpoints:input_type -> determined.api.v1.PatchCheckpointsRequest - 170, // 170: determined.api.v1.Determined.DeleteCheckpoints:input_type -> determined.api.v1.DeleteCheckpointsRequest - 171, // 171: determined.api.v1.Determined.GetTrialMetricsByCheckpoint:input_type -> determined.api.v1.GetTrialMetricsByCheckpointRequest - 172, // 172: determined.api.v1.Determined.GetSearcherEvents:input_type -> determined.api.v1.GetSearcherEventsRequest - 173, // 173: determined.api.v1.Determined.PostSearcherOperations:input_type -> determined.api.v1.PostSearcherOperationsRequest - 174, // 174: determined.api.v1.Determined.ExpMetricNames:input_type -> determined.api.v1.ExpMetricNamesRequest - 175, // 175: determined.api.v1.Determined.MetricBatches:input_type -> determined.api.v1.MetricBatchesRequest - 176, // 176: determined.api.v1.Determined.TrialsSnapshot:input_type -> determined.api.v1.TrialsSnapshotRequest - 177, // 177: determined.api.v1.Determined.TrialsSample:input_type -> determined.api.v1.TrialsSampleRequest - 178, // 178: determined.api.v1.Determined.GetResourcePools:input_type -> determined.api.v1.GetResourcePoolsRequest - 179, // 179: determined.api.v1.Determined.GetKubernetesResourceManagers:input_type -> determined.api.v1.GetKubernetesResourceManagersRequest - 180, // 180: determined.api.v1.Determined.ResourceAllocationRaw:input_type -> determined.api.v1.ResourceAllocationRawRequest - 181, // 181: determined.api.v1.Determined.ResourceAllocationAggregated:input_type -> determined.api.v1.ResourceAllocationAggregatedRequest - 182, // 182: determined.api.v1.Determined.GetWorkspace:input_type -> determined.api.v1.GetWorkspaceRequest - 183, // 183: determined.api.v1.Determined.GetWorkspaceProjects:input_type -> determined.api.v1.GetWorkspaceProjectsRequest - 184, // 184: determined.api.v1.Determined.GetWorkspaces:input_type -> determined.api.v1.GetWorkspacesRequest - 185, // 185: determined.api.v1.Determined.PostWorkspace:input_type -> determined.api.v1.PostWorkspaceRequest - 186, // 186: determined.api.v1.Determined.PatchWorkspace:input_type -> determined.api.v1.PatchWorkspaceRequest - 187, // 187: determined.api.v1.Determined.DeleteWorkspace:input_type -> determined.api.v1.DeleteWorkspaceRequest - 188, // 188: determined.api.v1.Determined.ArchiveWorkspace:input_type -> determined.api.v1.ArchiveWorkspaceRequest - 189, // 189: determined.api.v1.Determined.UnarchiveWorkspace:input_type -> determined.api.v1.UnarchiveWorkspaceRequest - 190, // 190: determined.api.v1.Determined.PinWorkspace:input_type -> determined.api.v1.PinWorkspaceRequest - 191, // 191: determined.api.v1.Determined.UnpinWorkspace:input_type -> determined.api.v1.UnpinWorkspaceRequest - 192, // 192: determined.api.v1.Determined.SetWorkspaceNamespaceBindings:input_type -> determined.api.v1.SetWorkspaceNamespaceBindingsRequest - 193, // 193: determined.api.v1.Determined.SetResourceQuotas:input_type -> determined.api.v1.SetResourceQuotasRequest - 194, // 194: determined.api.v1.Determined.ListWorkspaceNamespaceBindings:input_type -> determined.api.v1.ListWorkspaceNamespaceBindingsRequest - 195, // 195: determined.api.v1.Determined.GetWorkspacesWithDefaultNamespaceBindings:input_type -> determined.api.v1.GetWorkspacesWithDefaultNamespaceBindingsRequest - 196, // 196: determined.api.v1.Determined.BulkAutoCreateWorkspaceNamespaceBindings:input_type -> determined.api.v1.BulkAutoCreateWorkspaceNamespaceBindingsRequest - 197, // 197: determined.api.v1.Determined.DeleteWorkspaceNamespaceBindings:input_type -> determined.api.v1.DeleteWorkspaceNamespaceBindingsRequest - 198, // 198: determined.api.v1.Determined.GetKubernetesResourceQuotas:input_type -> determined.api.v1.GetKubernetesResourceQuotasRequest - 199, // 199: determined.api.v1.Determined.GetProject:input_type -> determined.api.v1.GetProjectRequest - 200, // 200: determined.api.v1.Determined.GetProjectByKey:input_type -> determined.api.v1.GetProjectByKeyRequest - 201, // 201: determined.api.v1.Determined.GetProjectColumns:input_type -> determined.api.v1.GetProjectColumnsRequest - 202, // 202: determined.api.v1.Determined.GetProjectNumericMetricsRange:input_type -> determined.api.v1.GetProjectNumericMetricsRangeRequest - 203, // 203: determined.api.v1.Determined.PostProject:input_type -> determined.api.v1.PostProjectRequest - 204, // 204: determined.api.v1.Determined.AddProjectNote:input_type -> determined.api.v1.AddProjectNoteRequest - 205, // 205: determined.api.v1.Determined.PutProjectNotes:input_type -> determined.api.v1.PutProjectNotesRequest - 206, // 206: determined.api.v1.Determined.PatchProject:input_type -> determined.api.v1.PatchProjectRequest - 207, // 207: determined.api.v1.Determined.DeleteProject:input_type -> determined.api.v1.DeleteProjectRequest - 208, // 208: determined.api.v1.Determined.ArchiveProject:input_type -> determined.api.v1.ArchiveProjectRequest - 209, // 209: determined.api.v1.Determined.UnarchiveProject:input_type -> determined.api.v1.UnarchiveProjectRequest - 210, // 210: determined.api.v1.Determined.MoveProject:input_type -> determined.api.v1.MoveProjectRequest - 211, // 211: determined.api.v1.Determined.MoveExperiment:input_type -> determined.api.v1.MoveExperimentRequest - 212, // 212: determined.api.v1.Determined.MoveExperiments:input_type -> determined.api.v1.MoveExperimentsRequest - 213, // 213: determined.api.v1.Determined.GetWebhooks:input_type -> determined.api.v1.GetWebhooksRequest - 214, // 214: determined.api.v1.Determined.PatchWebhook:input_type -> determined.api.v1.PatchWebhookRequest - 215, // 215: determined.api.v1.Determined.PostWebhook:input_type -> determined.api.v1.PostWebhookRequest - 216, // 216: determined.api.v1.Determined.DeleteWebhook:input_type -> determined.api.v1.DeleteWebhookRequest - 217, // 217: determined.api.v1.Determined.TestWebhook:input_type -> determined.api.v1.TestWebhookRequest - 218, // 218: determined.api.v1.Determined.PostWebhookEventData:input_type -> determined.api.v1.PostWebhookEventDataRequest - 219, // 219: determined.api.v1.Determined.GetGroup:input_type -> determined.api.v1.GetGroupRequest - 220, // 220: determined.api.v1.Determined.GetGroups:input_type -> determined.api.v1.GetGroupsRequest - 221, // 221: determined.api.v1.Determined.CreateGroup:input_type -> determined.api.v1.CreateGroupRequest - 222, // 222: determined.api.v1.Determined.UpdateGroup:input_type -> determined.api.v1.UpdateGroupRequest - 223, // 223: determined.api.v1.Determined.DeleteGroup:input_type -> determined.api.v1.DeleteGroupRequest - 224, // 224: determined.api.v1.Determined.GetPermissionsSummary:input_type -> determined.api.v1.GetPermissionsSummaryRequest - 225, // 225: determined.api.v1.Determined.GetGroupsAndUsersAssignedToWorkspace:input_type -> determined.api.v1.GetGroupsAndUsersAssignedToWorkspaceRequest - 226, // 226: determined.api.v1.Determined.GetRolesByID:input_type -> determined.api.v1.GetRolesByIDRequest - 227, // 227: determined.api.v1.Determined.GetRolesAssignedToUser:input_type -> determined.api.v1.GetRolesAssignedToUserRequest - 228, // 228: determined.api.v1.Determined.GetRolesAssignedToGroup:input_type -> determined.api.v1.GetRolesAssignedToGroupRequest - 229, // 229: determined.api.v1.Determined.SearchRolesAssignableToScope:input_type -> determined.api.v1.SearchRolesAssignableToScopeRequest - 230, // 230: determined.api.v1.Determined.ListRoles:input_type -> determined.api.v1.ListRolesRequest - 231, // 231: determined.api.v1.Determined.AssignRoles:input_type -> determined.api.v1.AssignRolesRequest - 232, // 232: determined.api.v1.Determined.RemoveAssignments:input_type -> determined.api.v1.RemoveAssignmentsRequest - 233, // 233: determined.api.v1.Determined.PostUserActivity:input_type -> determined.api.v1.PostUserActivityRequest - 234, // 234: determined.api.v1.Determined.GetProjectsByUserActivity:input_type -> determined.api.v1.GetProjectsByUserActivityRequest - 235, // 235: determined.api.v1.Determined.SearchExperiments:input_type -> determined.api.v1.SearchExperimentsRequest - 236, // 236: determined.api.v1.Determined.BindRPToWorkspace:input_type -> determined.api.v1.BindRPToWorkspaceRequest - 237, // 237: determined.api.v1.Determined.UnbindRPFromWorkspace:input_type -> determined.api.v1.UnbindRPFromWorkspaceRequest - 238, // 238: determined.api.v1.Determined.OverwriteRPWorkspaceBindings:input_type -> determined.api.v1.OverwriteRPWorkspaceBindingsRequest - 239, // 239: determined.api.v1.Determined.ListRPsBoundToWorkspace:input_type -> determined.api.v1.ListRPsBoundToWorkspaceRequest - 240, // 240: determined.api.v1.Determined.ListWorkspacesBoundToRP:input_type -> determined.api.v1.ListWorkspacesBoundToRPRequest - 241, // 241: determined.api.v1.Determined.GetGenericTaskConfig:input_type -> determined.api.v1.GetGenericTaskConfigRequest - 242, // 242: determined.api.v1.Determined.KillGenericTask:input_type -> determined.api.v1.KillGenericTaskRequest - 243, // 243: determined.api.v1.Determined.PauseGenericTask:input_type -> determined.api.v1.PauseGenericTaskRequest - 244, // 244: determined.api.v1.Determined.UnpauseGenericTask:input_type -> determined.api.v1.UnpauseGenericTaskRequest - 245, // 245: determined.api.v1.Determined.SearchRuns:input_type -> determined.api.v1.SearchRunsRequest - 246, // 246: determined.api.v1.Determined.MoveRuns:input_type -> determined.api.v1.MoveRunsRequest - 247, // 247: determined.api.v1.Determined.KillRuns:input_type -> determined.api.v1.KillRunsRequest - 248, // 248: determined.api.v1.Determined.DeleteRuns:input_type -> determined.api.v1.DeleteRunsRequest - 249, // 249: determined.api.v1.Determined.ArchiveRuns:input_type -> determined.api.v1.ArchiveRunsRequest - 250, // 250: determined.api.v1.Determined.UnarchiveRuns:input_type -> determined.api.v1.UnarchiveRunsRequest - 251, // 251: determined.api.v1.Determined.PauseRuns:input_type -> determined.api.v1.PauseRunsRequest - 252, // 252: determined.api.v1.Determined.ResumeRuns:input_type -> determined.api.v1.ResumeRunsRequest - 253, // 253: determined.api.v1.Determined.GetRunMetadata:input_type -> determined.api.v1.GetRunMetadataRequest - 254, // 254: determined.api.v1.Determined.PostRunMetadata:input_type -> determined.api.v1.PostRunMetadataRequest - 255, // 255: determined.api.v1.Determined.GetMetadataValues:input_type -> determined.api.v1.GetMetadataValuesRequest - 256, // 256: determined.api.v1.Determined.PutWorkspaceConfigPolicies:input_type -> determined.api.v1.PutWorkspaceConfigPoliciesRequest - 257, // 257: determined.api.v1.Determined.PutGlobalConfigPolicies:input_type -> determined.api.v1.PutGlobalConfigPoliciesRequest - 258, // 258: determined.api.v1.Determined.GetWorkspaceConfigPolicies:input_type -> determined.api.v1.GetWorkspaceConfigPoliciesRequest - 259, // 259: determined.api.v1.Determined.GetGlobalConfigPolicies:input_type -> determined.api.v1.GetGlobalConfigPoliciesRequest - 260, // 260: determined.api.v1.Determined.DeleteWorkspaceConfigPolicies:input_type -> determined.api.v1.DeleteWorkspaceConfigPoliciesRequest - 261, // 261: determined.api.v1.Determined.DeleteGlobalConfigPolicies:input_type -> determined.api.v1.DeleteGlobalConfigPoliciesRequest - 262, // 262: determined.api.v1.Determined.MoveSearches:input_type -> determined.api.v1.MoveSearchesRequest - 263, // 263: determined.api.v1.Determined.CancelSearches:input_type -> determined.api.v1.CancelSearchesRequest - 264, // 264: determined.api.v1.Determined.KillSearches:input_type -> determined.api.v1.KillSearchesRequest - 265, // 265: determined.api.v1.Determined.DeleteSearches:input_type -> determined.api.v1.DeleteSearchesRequest - 266, // 266: determined.api.v1.Determined.ArchiveSearches:input_type -> determined.api.v1.ArchiveSearchesRequest - 267, // 267: determined.api.v1.Determined.UnarchiveSearches:input_type -> determined.api.v1.UnarchiveSearchesRequest - 268, // 268: determined.api.v1.Determined.PauseSearches:input_type -> determined.api.v1.PauseSearchesRequest - 269, // 269: determined.api.v1.Determined.ResumeSearches:input_type -> determined.api.v1.ResumeSearchesRequest - 270, // 270: determined.api.v1.Determined.PostAccessToken:input_type -> determined.api.v1.PostAccessTokenRequest - 271, // 271: determined.api.v1.Determined.GetAccessTokens:input_type -> determined.api.v1.GetAccessTokensRequest - 272, // 272: determined.api.v1.Determined.PatchAccessToken:input_type -> determined.api.v1.PatchAccessTokenRequest - 273, // 273: determined.api.v1.Determined.Login:output_type -> determined.api.v1.LoginResponse - 274, // 274: determined.api.v1.Determined.CurrentUser:output_type -> determined.api.v1.CurrentUserResponse - 275, // 275: determined.api.v1.Determined.Logout:output_type -> determined.api.v1.LogoutResponse - 276, // 276: determined.api.v1.Determined.GetUsers:output_type -> determined.api.v1.GetUsersResponse - 277, // 277: determined.api.v1.Determined.GetUserSetting:output_type -> determined.api.v1.GetUserSettingResponse - 278, // 278: determined.api.v1.Determined.ResetUserSetting:output_type -> determined.api.v1.ResetUserSettingResponse - 279, // 279: determined.api.v1.Determined.PostUserSetting:output_type -> determined.api.v1.PostUserSettingResponse - 280, // 280: determined.api.v1.Determined.GetUser:output_type -> determined.api.v1.GetUserResponse - 281, // 281: determined.api.v1.Determined.GetUserByUsername:output_type -> determined.api.v1.GetUserByUsernameResponse - 282, // 282: determined.api.v1.Determined.GetMe:output_type -> determined.api.v1.GetMeResponse - 283, // 283: determined.api.v1.Determined.PostUser:output_type -> determined.api.v1.PostUserResponse - 284, // 284: determined.api.v1.Determined.SetUserPassword:output_type -> determined.api.v1.SetUserPasswordResponse - 285, // 285: determined.api.v1.Determined.AssignMultipleGroups:output_type -> determined.api.v1.AssignMultipleGroupsResponse - 286, // 286: determined.api.v1.Determined.PatchUser:output_type -> determined.api.v1.PatchUserResponse - 287, // 287: determined.api.v1.Determined.PatchUsers:output_type -> determined.api.v1.PatchUsersResponse - 288, // 288: determined.api.v1.Determined.GetTelemetry:output_type -> determined.api.v1.GetTelemetryResponse - 289, // 289: determined.api.v1.Determined.GetMaster:output_type -> determined.api.v1.GetMasterResponse - 290, // 290: determined.api.v1.Determined.GetMasterConfig:output_type -> determined.api.v1.GetMasterConfigResponse - 291, // 291: determined.api.v1.Determined.PatchMasterConfig:output_type -> determined.api.v1.PatchMasterConfigResponse - 292, // 292: determined.api.v1.Determined.MasterLogs:output_type -> determined.api.v1.MasterLogsResponse - 293, // 293: determined.api.v1.Determined.GetClusterMessage:output_type -> determined.api.v1.GetClusterMessageResponse - 294, // 294: determined.api.v1.Determined.SetClusterMessage:output_type -> determined.api.v1.SetClusterMessageResponse - 295, // 295: determined.api.v1.Determined.DeleteClusterMessage:output_type -> determined.api.v1.DeleteClusterMessageResponse - 296, // 296: determined.api.v1.Determined.GetAgents:output_type -> determined.api.v1.GetAgentsResponse - 297, // 297: determined.api.v1.Determined.GetAgent:output_type -> determined.api.v1.GetAgentResponse - 298, // 298: determined.api.v1.Determined.GetSlots:output_type -> determined.api.v1.GetSlotsResponse - 299, // 299: determined.api.v1.Determined.GetSlot:output_type -> determined.api.v1.GetSlotResponse - 300, // 300: determined.api.v1.Determined.EnableAgent:output_type -> determined.api.v1.EnableAgentResponse - 301, // 301: determined.api.v1.Determined.DisableAgent:output_type -> determined.api.v1.DisableAgentResponse - 302, // 302: determined.api.v1.Determined.EnableSlot:output_type -> determined.api.v1.EnableSlotResponse - 303, // 303: determined.api.v1.Determined.DisableSlot:output_type -> determined.api.v1.DisableSlotResponse - 304, // 304: determined.api.v1.Determined.CreateGenericTask:output_type -> determined.api.v1.CreateGenericTaskResponse - 305, // 305: determined.api.v1.Determined.CreateExperiment:output_type -> determined.api.v1.CreateExperimentResponse - 306, // 306: determined.api.v1.Determined.PutExperiment:output_type -> determined.api.v1.PutExperimentResponse - 307, // 307: determined.api.v1.Determined.ContinueExperiment:output_type -> determined.api.v1.ContinueExperimentResponse - 308, // 308: determined.api.v1.Determined.GetExperiment:output_type -> determined.api.v1.GetExperimentResponse - 309, // 309: determined.api.v1.Determined.GetExperiments:output_type -> determined.api.v1.GetExperimentsResponse - 310, // 310: determined.api.v1.Determined.PutExperimentRetainLogs:output_type -> determined.api.v1.PutExperimentRetainLogsResponse - 311, // 311: determined.api.v1.Determined.PutExperimentsRetainLogs:output_type -> determined.api.v1.PutExperimentsRetainLogsResponse - 312, // 312: determined.api.v1.Determined.PutTrialRetainLogs:output_type -> determined.api.v1.PutTrialRetainLogsResponse - 313, // 313: determined.api.v1.Determined.GetModelDef:output_type -> determined.api.v1.GetModelDefResponse - 314, // 314: determined.api.v1.Determined.GetTaskContextDirectory:output_type -> determined.api.v1.GetTaskContextDirectoryResponse - 315, // 315: determined.api.v1.Determined.GetModelDefTree:output_type -> determined.api.v1.GetModelDefTreeResponse - 316, // 316: determined.api.v1.Determined.GetModelDefFile:output_type -> determined.api.v1.GetModelDefFileResponse - 317, // 317: determined.api.v1.Determined.GetExperimentLabels:output_type -> determined.api.v1.GetExperimentLabelsResponse - 318, // 318: determined.api.v1.Determined.GetExperimentValidationHistory:output_type -> determined.api.v1.GetExperimentValidationHistoryResponse - 319, // 319: determined.api.v1.Determined.ActivateExperiment:output_type -> determined.api.v1.ActivateExperimentResponse - 320, // 320: determined.api.v1.Determined.ActivateExperiments:output_type -> determined.api.v1.ActivateExperimentsResponse - 321, // 321: determined.api.v1.Determined.PauseExperiment:output_type -> determined.api.v1.PauseExperimentResponse - 322, // 322: determined.api.v1.Determined.PauseExperiments:output_type -> determined.api.v1.PauseExperimentsResponse - 323, // 323: determined.api.v1.Determined.CancelExperiment:output_type -> determined.api.v1.CancelExperimentResponse - 324, // 324: determined.api.v1.Determined.CancelExperiments:output_type -> determined.api.v1.CancelExperimentsResponse - 325, // 325: determined.api.v1.Determined.KillExperiment:output_type -> determined.api.v1.KillExperimentResponse - 326, // 326: determined.api.v1.Determined.KillExperiments:output_type -> determined.api.v1.KillExperimentsResponse - 327, // 327: determined.api.v1.Determined.ArchiveExperiment:output_type -> determined.api.v1.ArchiveExperimentResponse - 328, // 328: determined.api.v1.Determined.ArchiveExperiments:output_type -> determined.api.v1.ArchiveExperimentsResponse - 329, // 329: determined.api.v1.Determined.UnarchiveExperiment:output_type -> determined.api.v1.UnarchiveExperimentResponse - 330, // 330: determined.api.v1.Determined.UnarchiveExperiments:output_type -> determined.api.v1.UnarchiveExperimentsResponse - 331, // 331: determined.api.v1.Determined.PatchExperiment:output_type -> determined.api.v1.PatchExperimentResponse - 332, // 332: determined.api.v1.Determined.DeleteExperiments:output_type -> determined.api.v1.DeleteExperimentsResponse - 333, // 333: determined.api.v1.Determined.DeleteExperiment:output_type -> determined.api.v1.DeleteExperimentResponse - 334, // 334: determined.api.v1.Determined.GetBestSearcherValidationMetric:output_type -> determined.api.v1.GetBestSearcherValidationMetricResponse - 335, // 335: determined.api.v1.Determined.GetExperimentCheckpoints:output_type -> determined.api.v1.GetExperimentCheckpointsResponse - 336, // 336: determined.api.v1.Determined.PutExperimentLabel:output_type -> determined.api.v1.PutExperimentLabelResponse - 337, // 337: determined.api.v1.Determined.DeleteExperimentLabel:output_type -> determined.api.v1.DeleteExperimentLabelResponse - 338, // 338: determined.api.v1.Determined.PreviewHPSearch:output_type -> determined.api.v1.PreviewHPSearchResponse - 339, // 339: determined.api.v1.Determined.GetExperimentTrials:output_type -> determined.api.v1.GetExperimentTrialsResponse - 340, // 340: determined.api.v1.Determined.GetTrialRemainingLogRetentionDays:output_type -> determined.api.v1.GetTrialRemainingLogRetentionDaysResponse - 341, // 341: determined.api.v1.Determined.CompareTrials:output_type -> determined.api.v1.CompareTrialsResponse - 342, // 342: determined.api.v1.Determined.ReportTrialSourceInfo:output_type -> determined.api.v1.ReportTrialSourceInfoResponse - 343, // 343: determined.api.v1.Determined.CreateTrial:output_type -> determined.api.v1.CreateTrialResponse - 344, // 344: determined.api.v1.Determined.PutTrial:output_type -> determined.api.v1.PutTrialResponse - 345, // 345: determined.api.v1.Determined.PatchTrial:output_type -> determined.api.v1.PatchTrialResponse - 346, // 346: determined.api.v1.Determined.StartTrial:output_type -> determined.api.v1.StartTrialResponse - 347, // 347: determined.api.v1.Determined.RunPrepareForReporting:output_type -> determined.api.v1.RunPrepareForReportingResponse - 348, // 348: determined.api.v1.Determined.GetTrial:output_type -> determined.api.v1.GetTrialResponse - 349, // 349: determined.api.v1.Determined.GetTrialByExternalID:output_type -> determined.api.v1.GetTrialByExternalIDResponse - 350, // 350: determined.api.v1.Determined.GetTrialWorkloads:output_type -> determined.api.v1.GetTrialWorkloadsResponse - 351, // 351: determined.api.v1.Determined.TrialLogs:output_type -> determined.api.v1.TrialLogsResponse - 352, // 352: determined.api.v1.Determined.TrialLogsFields:output_type -> determined.api.v1.TrialLogsFieldsResponse - 353, // 353: determined.api.v1.Determined.AllocationReady:output_type -> determined.api.v1.AllocationReadyResponse - 354, // 354: determined.api.v1.Determined.GetAllocation:output_type -> determined.api.v1.GetAllocationResponse - 355, // 355: determined.api.v1.Determined.AllocationWaiting:output_type -> determined.api.v1.AllocationWaitingResponse - 356, // 356: determined.api.v1.Determined.PostTaskLogs:output_type -> determined.api.v1.PostTaskLogsResponse - 357, // 357: determined.api.v1.Determined.TaskLogs:output_type -> determined.api.v1.TaskLogsResponse - 358, // 358: determined.api.v1.Determined.TaskLogsFields:output_type -> determined.api.v1.TaskLogsFieldsResponse - 359, // 359: determined.api.v1.Determined.GetTrialProfilerMetrics:output_type -> determined.api.v1.GetTrialProfilerMetricsResponse - 360, // 360: determined.api.v1.Determined.GetTrialProfilerAvailableSeries:output_type -> determined.api.v1.GetTrialProfilerAvailableSeriesResponse - 361, // 361: determined.api.v1.Determined.PostTrialProfilerMetricsBatch:output_type -> determined.api.v1.PostTrialProfilerMetricsBatchResponse - 362, // 362: determined.api.v1.Determined.GetMetrics:output_type -> determined.api.v1.GetMetricsResponse - 363, // 363: determined.api.v1.Determined.GetTrainingMetrics:output_type -> determined.api.v1.GetTrainingMetricsResponse - 364, // 364: determined.api.v1.Determined.GetValidationMetrics:output_type -> determined.api.v1.GetValidationMetricsResponse - 365, // 365: determined.api.v1.Determined.KillTrial:output_type -> determined.api.v1.KillTrialResponse - 366, // 366: determined.api.v1.Determined.GetTrialCheckpoints:output_type -> determined.api.v1.GetTrialCheckpointsResponse - 367, // 367: determined.api.v1.Determined.CleanupLogs:output_type -> determined.api.v1.CleanupLogsResponse - 368, // 368: determined.api.v1.Determined.AllocationPreemptionSignal:output_type -> determined.api.v1.AllocationPreemptionSignalResponse - 369, // 369: determined.api.v1.Determined.AllocationPendingPreemptionSignal:output_type -> determined.api.v1.AllocationPendingPreemptionSignalResponse - 370, // 370: determined.api.v1.Determined.AckAllocationPreemptionSignal:output_type -> determined.api.v1.AckAllocationPreemptionSignalResponse - 371, // 371: determined.api.v1.Determined.MarkAllocationResourcesDaemon:output_type -> determined.api.v1.MarkAllocationResourcesDaemonResponse - 372, // 372: determined.api.v1.Determined.AllocationRendezvousInfo:output_type -> determined.api.v1.AllocationRendezvousInfoResponse - 373, // 373: determined.api.v1.Determined.PostAllocationProxyAddress:output_type -> determined.api.v1.PostAllocationProxyAddressResponse - 374, // 374: determined.api.v1.Determined.GetTaskAcceleratorData:output_type -> determined.api.v1.GetTaskAcceleratorDataResponse - 375, // 375: determined.api.v1.Determined.PostAllocationAcceleratorData:output_type -> determined.api.v1.PostAllocationAcceleratorDataResponse - 376, // 376: determined.api.v1.Determined.AllocationAllGather:output_type -> determined.api.v1.AllocationAllGatherResponse - 377, // 377: determined.api.v1.Determined.NotifyContainerRunning:output_type -> determined.api.v1.NotifyContainerRunningResponse - 378, // 378: determined.api.v1.Determined.GetCurrentTrialSearcherOperation:output_type -> determined.api.v1.GetCurrentTrialSearcherOperationResponse - 379, // 379: determined.api.v1.Determined.CompleteTrialSearcherValidation:output_type -> determined.api.v1.CompleteTrialSearcherValidationResponse - 380, // 380: determined.api.v1.Determined.ReportTrialSearcherEarlyExit:output_type -> determined.api.v1.ReportTrialSearcherEarlyExitResponse - 381, // 381: determined.api.v1.Determined.ReportTrialProgress:output_type -> determined.api.v1.ReportTrialProgressResponse - 382, // 382: determined.api.v1.Determined.PostTrialRunnerMetadata:output_type -> determined.api.v1.PostTrialRunnerMetadataResponse - 383, // 383: determined.api.v1.Determined.ReportTrialMetrics:output_type -> determined.api.v1.ReportTrialMetricsResponse - 384, // 384: determined.api.v1.Determined.ReportTrialTrainingMetrics:output_type -> determined.api.v1.ReportTrialTrainingMetricsResponse - 385, // 385: determined.api.v1.Determined.ReportTrialValidationMetrics:output_type -> determined.api.v1.ReportTrialValidationMetricsResponse - 386, // 386: determined.api.v1.Determined.ReportCheckpoint:output_type -> determined.api.v1.ReportCheckpointResponse - 387, // 387: determined.api.v1.Determined.GetJobs:output_type -> determined.api.v1.GetJobsResponse - 388, // 388: determined.api.v1.Determined.GetJobsV2:output_type -> determined.api.v1.GetJobsV2Response - 389, // 389: determined.api.v1.Determined.GetJobQueueStats:output_type -> determined.api.v1.GetJobQueueStatsResponse - 390, // 390: determined.api.v1.Determined.UpdateJobQueue:output_type -> determined.api.v1.UpdateJobQueueResponse - 391, // 391: determined.api.v1.Determined.GetTemplates:output_type -> determined.api.v1.GetTemplatesResponse - 392, // 392: determined.api.v1.Determined.GetTemplate:output_type -> determined.api.v1.GetTemplateResponse - 393, // 393: determined.api.v1.Determined.PutTemplate:output_type -> determined.api.v1.PutTemplateResponse - 394, // 394: determined.api.v1.Determined.PostTemplate:output_type -> determined.api.v1.PostTemplateResponse - 395, // 395: determined.api.v1.Determined.PatchTemplateConfig:output_type -> determined.api.v1.PatchTemplateConfigResponse - 396, // 396: determined.api.v1.Determined.PatchTemplateName:output_type -> determined.api.v1.PatchTemplateNameResponse - 397, // 397: determined.api.v1.Determined.DeleteTemplate:output_type -> determined.api.v1.DeleteTemplateResponse - 398, // 398: determined.api.v1.Determined.GetNotebooks:output_type -> determined.api.v1.GetNotebooksResponse - 399, // 399: determined.api.v1.Determined.GetNotebook:output_type -> determined.api.v1.GetNotebookResponse - 400, // 400: determined.api.v1.Determined.IdleNotebook:output_type -> determined.api.v1.IdleNotebookResponse - 401, // 401: determined.api.v1.Determined.KillNotebook:output_type -> determined.api.v1.KillNotebookResponse - 402, // 402: determined.api.v1.Determined.SetNotebookPriority:output_type -> determined.api.v1.SetNotebookPriorityResponse - 403, // 403: determined.api.v1.Determined.LaunchNotebook:output_type -> determined.api.v1.LaunchNotebookResponse - 404, // 404: determined.api.v1.Determined.GetShells:output_type -> determined.api.v1.GetShellsResponse - 405, // 405: determined.api.v1.Determined.GetShell:output_type -> determined.api.v1.GetShellResponse - 406, // 406: determined.api.v1.Determined.KillShell:output_type -> determined.api.v1.KillShellResponse - 407, // 407: determined.api.v1.Determined.SetShellPriority:output_type -> determined.api.v1.SetShellPriorityResponse - 408, // 408: determined.api.v1.Determined.LaunchShell:output_type -> determined.api.v1.LaunchShellResponse - 409, // 409: determined.api.v1.Determined.GetCommands:output_type -> determined.api.v1.GetCommandsResponse - 410, // 410: determined.api.v1.Determined.GetCommand:output_type -> determined.api.v1.GetCommandResponse - 411, // 411: determined.api.v1.Determined.KillCommand:output_type -> determined.api.v1.KillCommandResponse - 412, // 412: determined.api.v1.Determined.SetCommandPriority:output_type -> determined.api.v1.SetCommandPriorityResponse - 413, // 413: determined.api.v1.Determined.LaunchCommand:output_type -> determined.api.v1.LaunchCommandResponse - 414, // 414: determined.api.v1.Determined.GetTensorboards:output_type -> determined.api.v1.GetTensorboardsResponse - 415, // 415: determined.api.v1.Determined.GetTensorboard:output_type -> determined.api.v1.GetTensorboardResponse - 416, // 416: determined.api.v1.Determined.KillTensorboard:output_type -> determined.api.v1.KillTensorboardResponse - 417, // 417: determined.api.v1.Determined.SetTensorboardPriority:output_type -> determined.api.v1.SetTensorboardPriorityResponse - 418, // 418: determined.api.v1.Determined.LaunchTensorboard:output_type -> determined.api.v1.LaunchTensorboardResponse - 419, // 419: determined.api.v1.Determined.LaunchTensorboardSearches:output_type -> determined.api.v1.LaunchTensorboardSearchesResponse - 420, // 420: determined.api.v1.Determined.DeleteTensorboardFiles:output_type -> determined.api.v1.DeleteTensorboardFilesResponse - 421, // 421: determined.api.v1.Determined.GetActiveTasksCount:output_type -> determined.api.v1.GetActiveTasksCountResponse - 422, // 422: determined.api.v1.Determined.GetTask:output_type -> determined.api.v1.GetTaskResponse - 423, // 423: determined.api.v1.Determined.GetTasks:output_type -> determined.api.v1.GetTasksResponse - 424, // 424: determined.api.v1.Determined.GetModel:output_type -> determined.api.v1.GetModelResponse - 425, // 425: determined.api.v1.Determined.PostModel:output_type -> determined.api.v1.PostModelResponse - 426, // 426: determined.api.v1.Determined.PatchModel:output_type -> determined.api.v1.PatchModelResponse - 427, // 427: determined.api.v1.Determined.ArchiveModel:output_type -> determined.api.v1.ArchiveModelResponse - 428, // 428: determined.api.v1.Determined.UnarchiveModel:output_type -> determined.api.v1.UnarchiveModelResponse - 429, // 429: determined.api.v1.Determined.MoveModel:output_type -> determined.api.v1.MoveModelResponse - 430, // 430: determined.api.v1.Determined.DeleteModel:output_type -> determined.api.v1.DeleteModelResponse - 431, // 431: determined.api.v1.Determined.GetModels:output_type -> determined.api.v1.GetModelsResponse - 432, // 432: determined.api.v1.Determined.GetModelLabels:output_type -> determined.api.v1.GetModelLabelsResponse - 433, // 433: determined.api.v1.Determined.GetModelVersion:output_type -> determined.api.v1.GetModelVersionResponse - 434, // 434: determined.api.v1.Determined.GetModelVersions:output_type -> determined.api.v1.GetModelVersionsResponse - 435, // 435: determined.api.v1.Determined.PostModelVersion:output_type -> determined.api.v1.PostModelVersionResponse - 436, // 436: determined.api.v1.Determined.PatchModelVersion:output_type -> determined.api.v1.PatchModelVersionResponse - 437, // 437: determined.api.v1.Determined.DeleteModelVersion:output_type -> determined.api.v1.DeleteModelVersionResponse - 438, // 438: determined.api.v1.Determined.GetTrialMetricsByModelVersion:output_type -> determined.api.v1.GetTrialMetricsByModelVersionResponse - 439, // 439: determined.api.v1.Determined.GetCheckpoint:output_type -> determined.api.v1.GetCheckpointResponse - 440, // 440: determined.api.v1.Determined.PostCheckpointMetadata:output_type -> determined.api.v1.PostCheckpointMetadataResponse - 441, // 441: determined.api.v1.Determined.CheckpointsRemoveFiles:output_type -> determined.api.v1.CheckpointsRemoveFilesResponse - 442, // 442: determined.api.v1.Determined.PatchCheckpoints:output_type -> determined.api.v1.PatchCheckpointsResponse - 443, // 443: determined.api.v1.Determined.DeleteCheckpoints:output_type -> determined.api.v1.DeleteCheckpointsResponse - 444, // 444: determined.api.v1.Determined.GetTrialMetricsByCheckpoint:output_type -> determined.api.v1.GetTrialMetricsByCheckpointResponse - 445, // 445: determined.api.v1.Determined.GetSearcherEvents:output_type -> determined.api.v1.GetSearcherEventsResponse - 446, // 446: determined.api.v1.Determined.PostSearcherOperations:output_type -> determined.api.v1.PostSearcherOperationsResponse - 447, // 447: determined.api.v1.Determined.ExpMetricNames:output_type -> determined.api.v1.ExpMetricNamesResponse - 448, // 448: determined.api.v1.Determined.MetricBatches:output_type -> determined.api.v1.MetricBatchesResponse - 449, // 449: determined.api.v1.Determined.TrialsSnapshot:output_type -> determined.api.v1.TrialsSnapshotResponse - 450, // 450: determined.api.v1.Determined.TrialsSample:output_type -> determined.api.v1.TrialsSampleResponse - 451, // 451: determined.api.v1.Determined.GetResourcePools:output_type -> determined.api.v1.GetResourcePoolsResponse - 452, // 452: determined.api.v1.Determined.GetKubernetesResourceManagers:output_type -> determined.api.v1.GetKubernetesResourceManagersResponse - 453, // 453: determined.api.v1.Determined.ResourceAllocationRaw:output_type -> determined.api.v1.ResourceAllocationRawResponse - 454, // 454: determined.api.v1.Determined.ResourceAllocationAggregated:output_type -> determined.api.v1.ResourceAllocationAggregatedResponse - 455, // 455: determined.api.v1.Determined.GetWorkspace:output_type -> determined.api.v1.GetWorkspaceResponse - 456, // 456: determined.api.v1.Determined.GetWorkspaceProjects:output_type -> determined.api.v1.GetWorkspaceProjectsResponse - 457, // 457: determined.api.v1.Determined.GetWorkspaces:output_type -> determined.api.v1.GetWorkspacesResponse - 458, // 458: determined.api.v1.Determined.PostWorkspace:output_type -> determined.api.v1.PostWorkspaceResponse - 459, // 459: determined.api.v1.Determined.PatchWorkspace:output_type -> determined.api.v1.PatchWorkspaceResponse - 460, // 460: determined.api.v1.Determined.DeleteWorkspace:output_type -> determined.api.v1.DeleteWorkspaceResponse - 461, // 461: determined.api.v1.Determined.ArchiveWorkspace:output_type -> determined.api.v1.ArchiveWorkspaceResponse - 462, // 462: determined.api.v1.Determined.UnarchiveWorkspace:output_type -> determined.api.v1.UnarchiveWorkspaceResponse - 463, // 463: determined.api.v1.Determined.PinWorkspace:output_type -> determined.api.v1.PinWorkspaceResponse - 464, // 464: determined.api.v1.Determined.UnpinWorkspace:output_type -> determined.api.v1.UnpinWorkspaceResponse - 465, // 465: determined.api.v1.Determined.SetWorkspaceNamespaceBindings:output_type -> determined.api.v1.SetWorkspaceNamespaceBindingsResponse - 466, // 466: determined.api.v1.Determined.SetResourceQuotas:output_type -> determined.api.v1.SetResourceQuotasResponse - 467, // 467: determined.api.v1.Determined.ListWorkspaceNamespaceBindings:output_type -> determined.api.v1.ListWorkspaceNamespaceBindingsResponse - 468, // 468: determined.api.v1.Determined.GetWorkspacesWithDefaultNamespaceBindings:output_type -> determined.api.v1.GetWorkspacesWithDefaultNamespaceBindingsResponse - 469, // 469: determined.api.v1.Determined.BulkAutoCreateWorkspaceNamespaceBindings:output_type -> determined.api.v1.BulkAutoCreateWorkspaceNamespaceBindingsResponse - 470, // 470: determined.api.v1.Determined.DeleteWorkspaceNamespaceBindings:output_type -> determined.api.v1.DeleteWorkspaceNamespaceBindingsResponse - 471, // 471: determined.api.v1.Determined.GetKubernetesResourceQuotas:output_type -> determined.api.v1.GetKubernetesResourceQuotasResponse - 472, // 472: determined.api.v1.Determined.GetProject:output_type -> determined.api.v1.GetProjectResponse - 473, // 473: determined.api.v1.Determined.GetProjectByKey:output_type -> determined.api.v1.GetProjectByKeyResponse - 474, // 474: determined.api.v1.Determined.GetProjectColumns:output_type -> determined.api.v1.GetProjectColumnsResponse - 475, // 475: determined.api.v1.Determined.GetProjectNumericMetricsRange:output_type -> determined.api.v1.GetProjectNumericMetricsRangeResponse - 476, // 476: determined.api.v1.Determined.PostProject:output_type -> determined.api.v1.PostProjectResponse - 477, // 477: determined.api.v1.Determined.AddProjectNote:output_type -> determined.api.v1.AddProjectNoteResponse - 478, // 478: determined.api.v1.Determined.PutProjectNotes:output_type -> determined.api.v1.PutProjectNotesResponse - 479, // 479: determined.api.v1.Determined.PatchProject:output_type -> determined.api.v1.PatchProjectResponse - 480, // 480: determined.api.v1.Determined.DeleteProject:output_type -> determined.api.v1.DeleteProjectResponse - 481, // 481: determined.api.v1.Determined.ArchiveProject:output_type -> determined.api.v1.ArchiveProjectResponse - 482, // 482: determined.api.v1.Determined.UnarchiveProject:output_type -> determined.api.v1.UnarchiveProjectResponse - 483, // 483: determined.api.v1.Determined.MoveProject:output_type -> determined.api.v1.MoveProjectResponse - 484, // 484: determined.api.v1.Determined.MoveExperiment:output_type -> determined.api.v1.MoveExperimentResponse - 485, // 485: determined.api.v1.Determined.MoveExperiments:output_type -> determined.api.v1.MoveExperimentsResponse - 486, // 486: determined.api.v1.Determined.GetWebhooks:output_type -> determined.api.v1.GetWebhooksResponse - 487, // 487: determined.api.v1.Determined.PatchWebhook:output_type -> determined.api.v1.PatchWebhookResponse - 488, // 488: determined.api.v1.Determined.PostWebhook:output_type -> determined.api.v1.PostWebhookResponse - 489, // 489: determined.api.v1.Determined.DeleteWebhook:output_type -> determined.api.v1.DeleteWebhookResponse - 490, // 490: determined.api.v1.Determined.TestWebhook:output_type -> determined.api.v1.TestWebhookResponse - 491, // 491: determined.api.v1.Determined.PostWebhookEventData:output_type -> determined.api.v1.PostWebhookEventDataResponse - 492, // 492: determined.api.v1.Determined.GetGroup:output_type -> determined.api.v1.GetGroupResponse - 493, // 493: determined.api.v1.Determined.GetGroups:output_type -> determined.api.v1.GetGroupsResponse - 494, // 494: determined.api.v1.Determined.CreateGroup:output_type -> determined.api.v1.CreateGroupResponse - 495, // 495: determined.api.v1.Determined.UpdateGroup:output_type -> determined.api.v1.UpdateGroupResponse - 496, // 496: determined.api.v1.Determined.DeleteGroup:output_type -> determined.api.v1.DeleteGroupResponse - 497, // 497: determined.api.v1.Determined.GetPermissionsSummary:output_type -> determined.api.v1.GetPermissionsSummaryResponse - 498, // 498: determined.api.v1.Determined.GetGroupsAndUsersAssignedToWorkspace:output_type -> determined.api.v1.GetGroupsAndUsersAssignedToWorkspaceResponse - 499, // 499: determined.api.v1.Determined.GetRolesByID:output_type -> determined.api.v1.GetRolesByIDResponse - 500, // 500: determined.api.v1.Determined.GetRolesAssignedToUser:output_type -> determined.api.v1.GetRolesAssignedToUserResponse - 501, // 501: determined.api.v1.Determined.GetRolesAssignedToGroup:output_type -> determined.api.v1.GetRolesAssignedToGroupResponse - 502, // 502: determined.api.v1.Determined.SearchRolesAssignableToScope:output_type -> determined.api.v1.SearchRolesAssignableToScopeResponse - 503, // 503: determined.api.v1.Determined.ListRoles:output_type -> determined.api.v1.ListRolesResponse - 504, // 504: determined.api.v1.Determined.AssignRoles:output_type -> determined.api.v1.AssignRolesResponse - 505, // 505: determined.api.v1.Determined.RemoveAssignments:output_type -> determined.api.v1.RemoveAssignmentsResponse - 506, // 506: determined.api.v1.Determined.PostUserActivity:output_type -> determined.api.v1.PostUserActivityResponse - 507, // 507: determined.api.v1.Determined.GetProjectsByUserActivity:output_type -> determined.api.v1.GetProjectsByUserActivityResponse - 508, // 508: determined.api.v1.Determined.SearchExperiments:output_type -> determined.api.v1.SearchExperimentsResponse - 509, // 509: determined.api.v1.Determined.BindRPToWorkspace:output_type -> determined.api.v1.BindRPToWorkspaceResponse - 510, // 510: determined.api.v1.Determined.UnbindRPFromWorkspace:output_type -> determined.api.v1.UnbindRPFromWorkspaceResponse - 511, // 511: determined.api.v1.Determined.OverwriteRPWorkspaceBindings:output_type -> determined.api.v1.OverwriteRPWorkspaceBindingsResponse - 512, // 512: determined.api.v1.Determined.ListRPsBoundToWorkspace:output_type -> determined.api.v1.ListRPsBoundToWorkspaceResponse - 513, // 513: determined.api.v1.Determined.ListWorkspacesBoundToRP:output_type -> determined.api.v1.ListWorkspacesBoundToRPResponse - 514, // 514: determined.api.v1.Determined.GetGenericTaskConfig:output_type -> determined.api.v1.GetGenericTaskConfigResponse - 515, // 515: determined.api.v1.Determined.KillGenericTask:output_type -> determined.api.v1.KillGenericTaskResponse - 516, // 516: determined.api.v1.Determined.PauseGenericTask:output_type -> determined.api.v1.PauseGenericTaskResponse - 517, // 517: determined.api.v1.Determined.UnpauseGenericTask:output_type -> determined.api.v1.UnpauseGenericTaskResponse - 518, // 518: determined.api.v1.Determined.SearchRuns:output_type -> determined.api.v1.SearchRunsResponse - 519, // 519: determined.api.v1.Determined.MoveRuns:output_type -> determined.api.v1.MoveRunsResponse - 520, // 520: determined.api.v1.Determined.KillRuns:output_type -> determined.api.v1.KillRunsResponse - 521, // 521: determined.api.v1.Determined.DeleteRuns:output_type -> determined.api.v1.DeleteRunsResponse - 522, // 522: determined.api.v1.Determined.ArchiveRuns:output_type -> determined.api.v1.ArchiveRunsResponse - 523, // 523: determined.api.v1.Determined.UnarchiveRuns:output_type -> determined.api.v1.UnarchiveRunsResponse - 524, // 524: determined.api.v1.Determined.PauseRuns:output_type -> determined.api.v1.PauseRunsResponse - 525, // 525: determined.api.v1.Determined.ResumeRuns:output_type -> determined.api.v1.ResumeRunsResponse - 526, // 526: determined.api.v1.Determined.GetRunMetadata:output_type -> determined.api.v1.GetRunMetadataResponse - 527, // 527: determined.api.v1.Determined.PostRunMetadata:output_type -> determined.api.v1.PostRunMetadataResponse - 528, // 528: determined.api.v1.Determined.GetMetadataValues:output_type -> determined.api.v1.GetMetadataValuesResponse - 529, // 529: determined.api.v1.Determined.PutWorkspaceConfigPolicies:output_type -> determined.api.v1.PutWorkspaceConfigPoliciesResponse - 530, // 530: determined.api.v1.Determined.PutGlobalConfigPolicies:output_type -> determined.api.v1.PutGlobalConfigPoliciesResponse - 531, // 531: determined.api.v1.Determined.GetWorkspaceConfigPolicies:output_type -> determined.api.v1.GetWorkspaceConfigPoliciesResponse - 532, // 532: determined.api.v1.Determined.GetGlobalConfigPolicies:output_type -> determined.api.v1.GetGlobalConfigPoliciesResponse - 533, // 533: determined.api.v1.Determined.DeleteWorkspaceConfigPolicies:output_type -> determined.api.v1.DeleteWorkspaceConfigPoliciesResponse - 534, // 534: determined.api.v1.Determined.DeleteGlobalConfigPolicies:output_type -> determined.api.v1.DeleteGlobalConfigPoliciesResponse - 535, // 535: determined.api.v1.Determined.MoveSearches:output_type -> determined.api.v1.MoveSearchesResponse - 536, // 536: determined.api.v1.Determined.CancelSearches:output_type -> determined.api.v1.CancelSearchesResponse - 537, // 537: determined.api.v1.Determined.KillSearches:output_type -> determined.api.v1.KillSearchesResponse - 538, // 538: determined.api.v1.Determined.DeleteSearches:output_type -> determined.api.v1.DeleteSearchesResponse - 539, // 539: determined.api.v1.Determined.ArchiveSearches:output_type -> determined.api.v1.ArchiveSearchesResponse - 540, // 540: determined.api.v1.Determined.UnarchiveSearches:output_type -> determined.api.v1.UnarchiveSearchesResponse - 541, // 541: determined.api.v1.Determined.PauseSearches:output_type -> determined.api.v1.PauseSearchesResponse - 542, // 542: determined.api.v1.Determined.ResumeSearches:output_type -> determined.api.v1.ResumeSearchesResponse - 543, // 543: determined.api.v1.Determined.PostAccessToken:output_type -> determined.api.v1.PostAccessTokenResponse - 544, // 544: determined.api.v1.Determined.GetAccessTokens:output_type -> determined.api.v1.GetAccessTokensResponse - 545, // 545: determined.api.v1.Determined.PatchAccessToken:output_type -> determined.api.v1.PatchAccessTokenResponse - 273, // [273:546] is the sub-list for method output_type - 0, // [0:273] is the sub-list for method input_type + 105, // 105: determined.api.v1.Determined.ReportTrialSearcherEarlyExit:input_type -> determined.api.v1.ReportTrialSearcherEarlyExitRequest + 106, // 106: determined.api.v1.Determined.ReportTrialProgress:input_type -> determined.api.v1.ReportTrialProgressRequest + 107, // 107: determined.api.v1.Determined.PostTrialRunnerMetadata:input_type -> determined.api.v1.PostTrialRunnerMetadataRequest + 108, // 108: determined.api.v1.Determined.ReportTrialMetrics:input_type -> determined.api.v1.ReportTrialMetricsRequest + 109, // 109: determined.api.v1.Determined.ReportTrialTrainingMetrics:input_type -> determined.api.v1.ReportTrialTrainingMetricsRequest + 110, // 110: determined.api.v1.Determined.ReportTrialValidationMetrics:input_type -> determined.api.v1.ReportTrialValidationMetricsRequest + 111, // 111: determined.api.v1.Determined.ReportCheckpoint:input_type -> determined.api.v1.ReportCheckpointRequest + 112, // 112: determined.api.v1.Determined.GetJobs:input_type -> determined.api.v1.GetJobsRequest + 113, // 113: determined.api.v1.Determined.GetJobsV2:input_type -> determined.api.v1.GetJobsV2Request + 114, // 114: determined.api.v1.Determined.GetJobQueueStats:input_type -> determined.api.v1.GetJobQueueStatsRequest + 115, // 115: determined.api.v1.Determined.UpdateJobQueue:input_type -> determined.api.v1.UpdateJobQueueRequest + 116, // 116: determined.api.v1.Determined.GetTemplates:input_type -> determined.api.v1.GetTemplatesRequest + 117, // 117: determined.api.v1.Determined.GetTemplate:input_type -> determined.api.v1.GetTemplateRequest + 118, // 118: determined.api.v1.Determined.PutTemplate:input_type -> determined.api.v1.PutTemplateRequest + 119, // 119: determined.api.v1.Determined.PostTemplate:input_type -> determined.api.v1.PostTemplateRequest + 120, // 120: determined.api.v1.Determined.PatchTemplateConfig:input_type -> determined.api.v1.PatchTemplateConfigRequest + 121, // 121: determined.api.v1.Determined.PatchTemplateName:input_type -> determined.api.v1.PatchTemplateNameRequest + 122, // 122: determined.api.v1.Determined.DeleteTemplate:input_type -> determined.api.v1.DeleteTemplateRequest + 123, // 123: determined.api.v1.Determined.GetNotebooks:input_type -> determined.api.v1.GetNotebooksRequest + 124, // 124: determined.api.v1.Determined.GetNotebook:input_type -> determined.api.v1.GetNotebookRequest + 125, // 125: determined.api.v1.Determined.IdleNotebook:input_type -> determined.api.v1.IdleNotebookRequest + 126, // 126: determined.api.v1.Determined.KillNotebook:input_type -> determined.api.v1.KillNotebookRequest + 127, // 127: determined.api.v1.Determined.SetNotebookPriority:input_type -> determined.api.v1.SetNotebookPriorityRequest + 128, // 128: determined.api.v1.Determined.LaunchNotebook:input_type -> determined.api.v1.LaunchNotebookRequest + 129, // 129: determined.api.v1.Determined.GetShells:input_type -> determined.api.v1.GetShellsRequest + 130, // 130: determined.api.v1.Determined.GetShell:input_type -> determined.api.v1.GetShellRequest + 131, // 131: determined.api.v1.Determined.KillShell:input_type -> determined.api.v1.KillShellRequest + 132, // 132: determined.api.v1.Determined.SetShellPriority:input_type -> determined.api.v1.SetShellPriorityRequest + 133, // 133: determined.api.v1.Determined.LaunchShell:input_type -> determined.api.v1.LaunchShellRequest + 134, // 134: determined.api.v1.Determined.GetCommands:input_type -> determined.api.v1.GetCommandsRequest + 135, // 135: determined.api.v1.Determined.GetCommand:input_type -> determined.api.v1.GetCommandRequest + 136, // 136: determined.api.v1.Determined.KillCommand:input_type -> determined.api.v1.KillCommandRequest + 137, // 137: determined.api.v1.Determined.SetCommandPriority:input_type -> determined.api.v1.SetCommandPriorityRequest + 138, // 138: determined.api.v1.Determined.LaunchCommand:input_type -> determined.api.v1.LaunchCommandRequest + 139, // 139: determined.api.v1.Determined.GetTensorboards:input_type -> determined.api.v1.GetTensorboardsRequest + 140, // 140: determined.api.v1.Determined.GetTensorboard:input_type -> determined.api.v1.GetTensorboardRequest + 141, // 141: determined.api.v1.Determined.KillTensorboard:input_type -> determined.api.v1.KillTensorboardRequest + 142, // 142: determined.api.v1.Determined.SetTensorboardPriority:input_type -> determined.api.v1.SetTensorboardPriorityRequest + 143, // 143: determined.api.v1.Determined.LaunchTensorboard:input_type -> determined.api.v1.LaunchTensorboardRequest + 144, // 144: determined.api.v1.Determined.LaunchTensorboardSearches:input_type -> determined.api.v1.LaunchTensorboardSearchesRequest + 145, // 145: determined.api.v1.Determined.DeleteTensorboardFiles:input_type -> determined.api.v1.DeleteTensorboardFilesRequest + 146, // 146: determined.api.v1.Determined.GetActiveTasksCount:input_type -> determined.api.v1.GetActiveTasksCountRequest + 147, // 147: determined.api.v1.Determined.GetTask:input_type -> determined.api.v1.GetTaskRequest + 148, // 148: determined.api.v1.Determined.GetTasks:input_type -> determined.api.v1.GetTasksRequest + 149, // 149: determined.api.v1.Determined.GetModel:input_type -> determined.api.v1.GetModelRequest + 150, // 150: determined.api.v1.Determined.PostModel:input_type -> determined.api.v1.PostModelRequest + 151, // 151: determined.api.v1.Determined.PatchModel:input_type -> determined.api.v1.PatchModelRequest + 152, // 152: determined.api.v1.Determined.ArchiveModel:input_type -> determined.api.v1.ArchiveModelRequest + 153, // 153: determined.api.v1.Determined.UnarchiveModel:input_type -> determined.api.v1.UnarchiveModelRequest + 154, // 154: determined.api.v1.Determined.MoveModel:input_type -> determined.api.v1.MoveModelRequest + 155, // 155: determined.api.v1.Determined.DeleteModel:input_type -> determined.api.v1.DeleteModelRequest + 156, // 156: determined.api.v1.Determined.GetModels:input_type -> determined.api.v1.GetModelsRequest + 157, // 157: determined.api.v1.Determined.GetModelLabels:input_type -> determined.api.v1.GetModelLabelsRequest + 158, // 158: determined.api.v1.Determined.GetModelVersion:input_type -> determined.api.v1.GetModelVersionRequest + 159, // 159: determined.api.v1.Determined.GetModelVersions:input_type -> determined.api.v1.GetModelVersionsRequest + 160, // 160: determined.api.v1.Determined.PostModelVersion:input_type -> determined.api.v1.PostModelVersionRequest + 161, // 161: determined.api.v1.Determined.PatchModelVersion:input_type -> determined.api.v1.PatchModelVersionRequest + 162, // 162: determined.api.v1.Determined.DeleteModelVersion:input_type -> determined.api.v1.DeleteModelVersionRequest + 163, // 163: determined.api.v1.Determined.GetTrialMetricsByModelVersion:input_type -> determined.api.v1.GetTrialMetricsByModelVersionRequest + 164, // 164: determined.api.v1.Determined.GetCheckpoint:input_type -> determined.api.v1.GetCheckpointRequest + 165, // 165: determined.api.v1.Determined.PostCheckpointMetadata:input_type -> determined.api.v1.PostCheckpointMetadataRequest + 166, // 166: determined.api.v1.Determined.CheckpointsRemoveFiles:input_type -> determined.api.v1.CheckpointsRemoveFilesRequest + 167, // 167: determined.api.v1.Determined.PatchCheckpoints:input_type -> determined.api.v1.PatchCheckpointsRequest + 168, // 168: determined.api.v1.Determined.DeleteCheckpoints:input_type -> determined.api.v1.DeleteCheckpointsRequest + 169, // 169: determined.api.v1.Determined.GetTrialMetricsByCheckpoint:input_type -> determined.api.v1.GetTrialMetricsByCheckpointRequest + 170, // 170: determined.api.v1.Determined.ExpMetricNames:input_type -> determined.api.v1.ExpMetricNamesRequest + 171, // 171: determined.api.v1.Determined.MetricBatches:input_type -> determined.api.v1.MetricBatchesRequest + 172, // 172: determined.api.v1.Determined.TrialsSnapshot:input_type -> determined.api.v1.TrialsSnapshotRequest + 173, // 173: determined.api.v1.Determined.TrialsSample:input_type -> determined.api.v1.TrialsSampleRequest + 174, // 174: determined.api.v1.Determined.GetResourcePools:input_type -> determined.api.v1.GetResourcePoolsRequest + 175, // 175: determined.api.v1.Determined.GetKubernetesResourceManagers:input_type -> determined.api.v1.GetKubernetesResourceManagersRequest + 176, // 176: determined.api.v1.Determined.ResourceAllocationRaw:input_type -> determined.api.v1.ResourceAllocationRawRequest + 177, // 177: determined.api.v1.Determined.ResourceAllocationAggregated:input_type -> determined.api.v1.ResourceAllocationAggregatedRequest + 178, // 178: determined.api.v1.Determined.GetWorkspace:input_type -> determined.api.v1.GetWorkspaceRequest + 179, // 179: determined.api.v1.Determined.GetWorkspaceProjects:input_type -> determined.api.v1.GetWorkspaceProjectsRequest + 180, // 180: determined.api.v1.Determined.GetWorkspaces:input_type -> determined.api.v1.GetWorkspacesRequest + 181, // 181: determined.api.v1.Determined.PostWorkspace:input_type -> determined.api.v1.PostWorkspaceRequest + 182, // 182: determined.api.v1.Determined.PatchWorkspace:input_type -> determined.api.v1.PatchWorkspaceRequest + 183, // 183: determined.api.v1.Determined.DeleteWorkspace:input_type -> determined.api.v1.DeleteWorkspaceRequest + 184, // 184: determined.api.v1.Determined.ArchiveWorkspace:input_type -> determined.api.v1.ArchiveWorkspaceRequest + 185, // 185: determined.api.v1.Determined.UnarchiveWorkspace:input_type -> determined.api.v1.UnarchiveWorkspaceRequest + 186, // 186: determined.api.v1.Determined.PinWorkspace:input_type -> determined.api.v1.PinWorkspaceRequest + 187, // 187: determined.api.v1.Determined.UnpinWorkspace:input_type -> determined.api.v1.UnpinWorkspaceRequest + 188, // 188: determined.api.v1.Determined.SetWorkspaceNamespaceBindings:input_type -> determined.api.v1.SetWorkspaceNamespaceBindingsRequest + 189, // 189: determined.api.v1.Determined.SetResourceQuotas:input_type -> determined.api.v1.SetResourceQuotasRequest + 190, // 190: determined.api.v1.Determined.ListWorkspaceNamespaceBindings:input_type -> determined.api.v1.ListWorkspaceNamespaceBindingsRequest + 191, // 191: determined.api.v1.Determined.GetWorkspacesWithDefaultNamespaceBindings:input_type -> determined.api.v1.GetWorkspacesWithDefaultNamespaceBindingsRequest + 192, // 192: determined.api.v1.Determined.BulkAutoCreateWorkspaceNamespaceBindings:input_type -> determined.api.v1.BulkAutoCreateWorkspaceNamespaceBindingsRequest + 193, // 193: determined.api.v1.Determined.DeleteWorkspaceNamespaceBindings:input_type -> determined.api.v1.DeleteWorkspaceNamespaceBindingsRequest + 194, // 194: determined.api.v1.Determined.GetKubernetesResourceQuotas:input_type -> determined.api.v1.GetKubernetesResourceQuotasRequest + 195, // 195: determined.api.v1.Determined.GetProject:input_type -> determined.api.v1.GetProjectRequest + 196, // 196: determined.api.v1.Determined.GetProjectByKey:input_type -> determined.api.v1.GetProjectByKeyRequest + 197, // 197: determined.api.v1.Determined.GetProjectColumns:input_type -> determined.api.v1.GetProjectColumnsRequest + 198, // 198: determined.api.v1.Determined.GetProjectNumericMetricsRange:input_type -> determined.api.v1.GetProjectNumericMetricsRangeRequest + 199, // 199: determined.api.v1.Determined.PostProject:input_type -> determined.api.v1.PostProjectRequest + 200, // 200: determined.api.v1.Determined.AddProjectNote:input_type -> determined.api.v1.AddProjectNoteRequest + 201, // 201: determined.api.v1.Determined.PutProjectNotes:input_type -> determined.api.v1.PutProjectNotesRequest + 202, // 202: determined.api.v1.Determined.PatchProject:input_type -> determined.api.v1.PatchProjectRequest + 203, // 203: determined.api.v1.Determined.DeleteProject:input_type -> determined.api.v1.DeleteProjectRequest + 204, // 204: determined.api.v1.Determined.ArchiveProject:input_type -> determined.api.v1.ArchiveProjectRequest + 205, // 205: determined.api.v1.Determined.UnarchiveProject:input_type -> determined.api.v1.UnarchiveProjectRequest + 206, // 206: determined.api.v1.Determined.MoveProject:input_type -> determined.api.v1.MoveProjectRequest + 207, // 207: determined.api.v1.Determined.MoveExperiment:input_type -> determined.api.v1.MoveExperimentRequest + 208, // 208: determined.api.v1.Determined.MoveExperiments:input_type -> determined.api.v1.MoveExperimentsRequest + 209, // 209: determined.api.v1.Determined.GetWebhooks:input_type -> determined.api.v1.GetWebhooksRequest + 210, // 210: determined.api.v1.Determined.PatchWebhook:input_type -> determined.api.v1.PatchWebhookRequest + 211, // 211: determined.api.v1.Determined.PostWebhook:input_type -> determined.api.v1.PostWebhookRequest + 212, // 212: determined.api.v1.Determined.DeleteWebhook:input_type -> determined.api.v1.DeleteWebhookRequest + 213, // 213: determined.api.v1.Determined.TestWebhook:input_type -> determined.api.v1.TestWebhookRequest + 214, // 214: determined.api.v1.Determined.PostWebhookEventData:input_type -> determined.api.v1.PostWebhookEventDataRequest + 215, // 215: determined.api.v1.Determined.GetGroup:input_type -> determined.api.v1.GetGroupRequest + 216, // 216: determined.api.v1.Determined.GetGroups:input_type -> determined.api.v1.GetGroupsRequest + 217, // 217: determined.api.v1.Determined.CreateGroup:input_type -> determined.api.v1.CreateGroupRequest + 218, // 218: determined.api.v1.Determined.UpdateGroup:input_type -> determined.api.v1.UpdateGroupRequest + 219, // 219: determined.api.v1.Determined.DeleteGroup:input_type -> determined.api.v1.DeleteGroupRequest + 220, // 220: determined.api.v1.Determined.GetPermissionsSummary:input_type -> determined.api.v1.GetPermissionsSummaryRequest + 221, // 221: determined.api.v1.Determined.GetGroupsAndUsersAssignedToWorkspace:input_type -> determined.api.v1.GetGroupsAndUsersAssignedToWorkspaceRequest + 222, // 222: determined.api.v1.Determined.GetRolesByID:input_type -> determined.api.v1.GetRolesByIDRequest + 223, // 223: determined.api.v1.Determined.GetRolesAssignedToUser:input_type -> determined.api.v1.GetRolesAssignedToUserRequest + 224, // 224: determined.api.v1.Determined.GetRolesAssignedToGroup:input_type -> determined.api.v1.GetRolesAssignedToGroupRequest + 225, // 225: determined.api.v1.Determined.SearchRolesAssignableToScope:input_type -> determined.api.v1.SearchRolesAssignableToScopeRequest + 226, // 226: determined.api.v1.Determined.ListRoles:input_type -> determined.api.v1.ListRolesRequest + 227, // 227: determined.api.v1.Determined.AssignRoles:input_type -> determined.api.v1.AssignRolesRequest + 228, // 228: determined.api.v1.Determined.RemoveAssignments:input_type -> determined.api.v1.RemoveAssignmentsRequest + 229, // 229: determined.api.v1.Determined.PostUserActivity:input_type -> determined.api.v1.PostUserActivityRequest + 230, // 230: determined.api.v1.Determined.GetProjectsByUserActivity:input_type -> determined.api.v1.GetProjectsByUserActivityRequest + 231, // 231: determined.api.v1.Determined.SearchExperiments:input_type -> determined.api.v1.SearchExperimentsRequest + 232, // 232: determined.api.v1.Determined.BindRPToWorkspace:input_type -> determined.api.v1.BindRPToWorkspaceRequest + 233, // 233: determined.api.v1.Determined.UnbindRPFromWorkspace:input_type -> determined.api.v1.UnbindRPFromWorkspaceRequest + 234, // 234: determined.api.v1.Determined.OverwriteRPWorkspaceBindings:input_type -> determined.api.v1.OverwriteRPWorkspaceBindingsRequest + 235, // 235: determined.api.v1.Determined.ListRPsBoundToWorkspace:input_type -> determined.api.v1.ListRPsBoundToWorkspaceRequest + 236, // 236: determined.api.v1.Determined.ListWorkspacesBoundToRP:input_type -> determined.api.v1.ListWorkspacesBoundToRPRequest + 237, // 237: determined.api.v1.Determined.GetGenericTaskConfig:input_type -> determined.api.v1.GetGenericTaskConfigRequest + 238, // 238: determined.api.v1.Determined.KillGenericTask:input_type -> determined.api.v1.KillGenericTaskRequest + 239, // 239: determined.api.v1.Determined.PauseGenericTask:input_type -> determined.api.v1.PauseGenericTaskRequest + 240, // 240: determined.api.v1.Determined.UnpauseGenericTask:input_type -> determined.api.v1.UnpauseGenericTaskRequest + 241, // 241: determined.api.v1.Determined.SearchRuns:input_type -> determined.api.v1.SearchRunsRequest + 242, // 242: determined.api.v1.Determined.MoveRuns:input_type -> determined.api.v1.MoveRunsRequest + 243, // 243: determined.api.v1.Determined.KillRuns:input_type -> determined.api.v1.KillRunsRequest + 244, // 244: determined.api.v1.Determined.DeleteRuns:input_type -> determined.api.v1.DeleteRunsRequest + 245, // 245: determined.api.v1.Determined.ArchiveRuns:input_type -> determined.api.v1.ArchiveRunsRequest + 246, // 246: determined.api.v1.Determined.UnarchiveRuns:input_type -> determined.api.v1.UnarchiveRunsRequest + 247, // 247: determined.api.v1.Determined.PauseRuns:input_type -> determined.api.v1.PauseRunsRequest + 248, // 248: determined.api.v1.Determined.ResumeRuns:input_type -> determined.api.v1.ResumeRunsRequest + 249, // 249: determined.api.v1.Determined.GetRunMetadata:input_type -> determined.api.v1.GetRunMetadataRequest + 250, // 250: determined.api.v1.Determined.PostRunMetadata:input_type -> determined.api.v1.PostRunMetadataRequest + 251, // 251: determined.api.v1.Determined.GetMetadataValues:input_type -> determined.api.v1.GetMetadataValuesRequest + 252, // 252: determined.api.v1.Determined.PutWorkspaceConfigPolicies:input_type -> determined.api.v1.PutWorkspaceConfigPoliciesRequest + 253, // 253: determined.api.v1.Determined.PutGlobalConfigPolicies:input_type -> determined.api.v1.PutGlobalConfigPoliciesRequest + 254, // 254: determined.api.v1.Determined.GetWorkspaceConfigPolicies:input_type -> determined.api.v1.GetWorkspaceConfigPoliciesRequest + 255, // 255: determined.api.v1.Determined.GetGlobalConfigPolicies:input_type -> determined.api.v1.GetGlobalConfigPoliciesRequest + 256, // 256: determined.api.v1.Determined.DeleteWorkspaceConfigPolicies:input_type -> determined.api.v1.DeleteWorkspaceConfigPoliciesRequest + 257, // 257: determined.api.v1.Determined.DeleteGlobalConfigPolicies:input_type -> determined.api.v1.DeleteGlobalConfigPoliciesRequest + 258, // 258: determined.api.v1.Determined.MoveSearches:input_type -> determined.api.v1.MoveSearchesRequest + 259, // 259: determined.api.v1.Determined.CancelSearches:input_type -> determined.api.v1.CancelSearchesRequest + 260, // 260: determined.api.v1.Determined.KillSearches:input_type -> determined.api.v1.KillSearchesRequest + 261, // 261: determined.api.v1.Determined.DeleteSearches:input_type -> determined.api.v1.DeleteSearchesRequest + 262, // 262: determined.api.v1.Determined.ArchiveSearches:input_type -> determined.api.v1.ArchiveSearchesRequest + 263, // 263: determined.api.v1.Determined.UnarchiveSearches:input_type -> determined.api.v1.UnarchiveSearchesRequest + 264, // 264: determined.api.v1.Determined.PauseSearches:input_type -> determined.api.v1.PauseSearchesRequest + 265, // 265: determined.api.v1.Determined.ResumeSearches:input_type -> determined.api.v1.ResumeSearchesRequest + 266, // 266: determined.api.v1.Determined.PostAccessToken:input_type -> determined.api.v1.PostAccessTokenRequest + 267, // 267: determined.api.v1.Determined.GetAccessTokens:input_type -> determined.api.v1.GetAccessTokensRequest + 268, // 268: determined.api.v1.Determined.PatchAccessToken:input_type -> determined.api.v1.PatchAccessTokenRequest + 269, // 269: determined.api.v1.Determined.Login:output_type -> determined.api.v1.LoginResponse + 270, // 270: determined.api.v1.Determined.CurrentUser:output_type -> determined.api.v1.CurrentUserResponse + 271, // 271: determined.api.v1.Determined.Logout:output_type -> determined.api.v1.LogoutResponse + 272, // 272: determined.api.v1.Determined.GetUsers:output_type -> determined.api.v1.GetUsersResponse + 273, // 273: determined.api.v1.Determined.GetUserSetting:output_type -> determined.api.v1.GetUserSettingResponse + 274, // 274: determined.api.v1.Determined.ResetUserSetting:output_type -> determined.api.v1.ResetUserSettingResponse + 275, // 275: determined.api.v1.Determined.PostUserSetting:output_type -> determined.api.v1.PostUserSettingResponse + 276, // 276: determined.api.v1.Determined.GetUser:output_type -> determined.api.v1.GetUserResponse + 277, // 277: determined.api.v1.Determined.GetUserByUsername:output_type -> determined.api.v1.GetUserByUsernameResponse + 278, // 278: determined.api.v1.Determined.GetMe:output_type -> determined.api.v1.GetMeResponse + 279, // 279: determined.api.v1.Determined.PostUser:output_type -> determined.api.v1.PostUserResponse + 280, // 280: determined.api.v1.Determined.SetUserPassword:output_type -> determined.api.v1.SetUserPasswordResponse + 281, // 281: determined.api.v1.Determined.AssignMultipleGroups:output_type -> determined.api.v1.AssignMultipleGroupsResponse + 282, // 282: determined.api.v1.Determined.PatchUser:output_type -> determined.api.v1.PatchUserResponse + 283, // 283: determined.api.v1.Determined.PatchUsers:output_type -> determined.api.v1.PatchUsersResponse + 284, // 284: determined.api.v1.Determined.GetTelemetry:output_type -> determined.api.v1.GetTelemetryResponse + 285, // 285: determined.api.v1.Determined.GetMaster:output_type -> determined.api.v1.GetMasterResponse + 286, // 286: determined.api.v1.Determined.GetMasterConfig:output_type -> determined.api.v1.GetMasterConfigResponse + 287, // 287: determined.api.v1.Determined.PatchMasterConfig:output_type -> determined.api.v1.PatchMasterConfigResponse + 288, // 288: determined.api.v1.Determined.MasterLogs:output_type -> determined.api.v1.MasterLogsResponse + 289, // 289: determined.api.v1.Determined.GetClusterMessage:output_type -> determined.api.v1.GetClusterMessageResponse + 290, // 290: determined.api.v1.Determined.SetClusterMessage:output_type -> determined.api.v1.SetClusterMessageResponse + 291, // 291: determined.api.v1.Determined.DeleteClusterMessage:output_type -> determined.api.v1.DeleteClusterMessageResponse + 292, // 292: determined.api.v1.Determined.GetAgents:output_type -> determined.api.v1.GetAgentsResponse + 293, // 293: determined.api.v1.Determined.GetAgent:output_type -> determined.api.v1.GetAgentResponse + 294, // 294: determined.api.v1.Determined.GetSlots:output_type -> determined.api.v1.GetSlotsResponse + 295, // 295: determined.api.v1.Determined.GetSlot:output_type -> determined.api.v1.GetSlotResponse + 296, // 296: determined.api.v1.Determined.EnableAgent:output_type -> determined.api.v1.EnableAgentResponse + 297, // 297: determined.api.v1.Determined.DisableAgent:output_type -> determined.api.v1.DisableAgentResponse + 298, // 298: determined.api.v1.Determined.EnableSlot:output_type -> determined.api.v1.EnableSlotResponse + 299, // 299: determined.api.v1.Determined.DisableSlot:output_type -> determined.api.v1.DisableSlotResponse + 300, // 300: determined.api.v1.Determined.CreateGenericTask:output_type -> determined.api.v1.CreateGenericTaskResponse + 301, // 301: determined.api.v1.Determined.CreateExperiment:output_type -> determined.api.v1.CreateExperimentResponse + 302, // 302: determined.api.v1.Determined.PutExperiment:output_type -> determined.api.v1.PutExperimentResponse + 303, // 303: determined.api.v1.Determined.ContinueExperiment:output_type -> determined.api.v1.ContinueExperimentResponse + 304, // 304: determined.api.v1.Determined.GetExperiment:output_type -> determined.api.v1.GetExperimentResponse + 305, // 305: determined.api.v1.Determined.GetExperiments:output_type -> determined.api.v1.GetExperimentsResponse + 306, // 306: determined.api.v1.Determined.PutExperimentRetainLogs:output_type -> determined.api.v1.PutExperimentRetainLogsResponse + 307, // 307: determined.api.v1.Determined.PutExperimentsRetainLogs:output_type -> determined.api.v1.PutExperimentsRetainLogsResponse + 308, // 308: determined.api.v1.Determined.PutTrialRetainLogs:output_type -> determined.api.v1.PutTrialRetainLogsResponse + 309, // 309: determined.api.v1.Determined.GetModelDef:output_type -> determined.api.v1.GetModelDefResponse + 310, // 310: determined.api.v1.Determined.GetTaskContextDirectory:output_type -> determined.api.v1.GetTaskContextDirectoryResponse + 311, // 311: determined.api.v1.Determined.GetModelDefTree:output_type -> determined.api.v1.GetModelDefTreeResponse + 312, // 312: determined.api.v1.Determined.GetModelDefFile:output_type -> determined.api.v1.GetModelDefFileResponse + 313, // 313: determined.api.v1.Determined.GetExperimentLabels:output_type -> determined.api.v1.GetExperimentLabelsResponse + 314, // 314: determined.api.v1.Determined.GetExperimentValidationHistory:output_type -> determined.api.v1.GetExperimentValidationHistoryResponse + 315, // 315: determined.api.v1.Determined.ActivateExperiment:output_type -> determined.api.v1.ActivateExperimentResponse + 316, // 316: determined.api.v1.Determined.ActivateExperiments:output_type -> determined.api.v1.ActivateExperimentsResponse + 317, // 317: determined.api.v1.Determined.PauseExperiment:output_type -> determined.api.v1.PauseExperimentResponse + 318, // 318: determined.api.v1.Determined.PauseExperiments:output_type -> determined.api.v1.PauseExperimentsResponse + 319, // 319: determined.api.v1.Determined.CancelExperiment:output_type -> determined.api.v1.CancelExperimentResponse + 320, // 320: determined.api.v1.Determined.CancelExperiments:output_type -> determined.api.v1.CancelExperimentsResponse + 321, // 321: determined.api.v1.Determined.KillExperiment:output_type -> determined.api.v1.KillExperimentResponse + 322, // 322: determined.api.v1.Determined.KillExperiments:output_type -> determined.api.v1.KillExperimentsResponse + 323, // 323: determined.api.v1.Determined.ArchiveExperiment:output_type -> determined.api.v1.ArchiveExperimentResponse + 324, // 324: determined.api.v1.Determined.ArchiveExperiments:output_type -> determined.api.v1.ArchiveExperimentsResponse + 325, // 325: determined.api.v1.Determined.UnarchiveExperiment:output_type -> determined.api.v1.UnarchiveExperimentResponse + 326, // 326: determined.api.v1.Determined.UnarchiveExperiments:output_type -> determined.api.v1.UnarchiveExperimentsResponse + 327, // 327: determined.api.v1.Determined.PatchExperiment:output_type -> determined.api.v1.PatchExperimentResponse + 328, // 328: determined.api.v1.Determined.DeleteExperiments:output_type -> determined.api.v1.DeleteExperimentsResponse + 329, // 329: determined.api.v1.Determined.DeleteExperiment:output_type -> determined.api.v1.DeleteExperimentResponse + 330, // 330: determined.api.v1.Determined.GetBestSearcherValidationMetric:output_type -> determined.api.v1.GetBestSearcherValidationMetricResponse + 331, // 331: determined.api.v1.Determined.GetExperimentCheckpoints:output_type -> determined.api.v1.GetExperimentCheckpointsResponse + 332, // 332: determined.api.v1.Determined.PutExperimentLabel:output_type -> determined.api.v1.PutExperimentLabelResponse + 333, // 333: determined.api.v1.Determined.DeleteExperimentLabel:output_type -> determined.api.v1.DeleteExperimentLabelResponse + 334, // 334: determined.api.v1.Determined.PreviewHPSearch:output_type -> determined.api.v1.PreviewHPSearchResponse + 335, // 335: determined.api.v1.Determined.GetExperimentTrials:output_type -> determined.api.v1.GetExperimentTrialsResponse + 336, // 336: determined.api.v1.Determined.GetTrialRemainingLogRetentionDays:output_type -> determined.api.v1.GetTrialRemainingLogRetentionDaysResponse + 337, // 337: determined.api.v1.Determined.CompareTrials:output_type -> determined.api.v1.CompareTrialsResponse + 338, // 338: determined.api.v1.Determined.ReportTrialSourceInfo:output_type -> determined.api.v1.ReportTrialSourceInfoResponse + 339, // 339: determined.api.v1.Determined.CreateTrial:output_type -> determined.api.v1.CreateTrialResponse + 340, // 340: determined.api.v1.Determined.PutTrial:output_type -> determined.api.v1.PutTrialResponse + 341, // 341: determined.api.v1.Determined.PatchTrial:output_type -> determined.api.v1.PatchTrialResponse + 342, // 342: determined.api.v1.Determined.StartTrial:output_type -> determined.api.v1.StartTrialResponse + 343, // 343: determined.api.v1.Determined.RunPrepareForReporting:output_type -> determined.api.v1.RunPrepareForReportingResponse + 344, // 344: determined.api.v1.Determined.GetTrial:output_type -> determined.api.v1.GetTrialResponse + 345, // 345: determined.api.v1.Determined.GetTrialByExternalID:output_type -> determined.api.v1.GetTrialByExternalIDResponse + 346, // 346: determined.api.v1.Determined.GetTrialWorkloads:output_type -> determined.api.v1.GetTrialWorkloadsResponse + 347, // 347: determined.api.v1.Determined.TrialLogs:output_type -> determined.api.v1.TrialLogsResponse + 348, // 348: determined.api.v1.Determined.TrialLogsFields:output_type -> determined.api.v1.TrialLogsFieldsResponse + 349, // 349: determined.api.v1.Determined.AllocationReady:output_type -> determined.api.v1.AllocationReadyResponse + 350, // 350: determined.api.v1.Determined.GetAllocation:output_type -> determined.api.v1.GetAllocationResponse + 351, // 351: determined.api.v1.Determined.AllocationWaiting:output_type -> determined.api.v1.AllocationWaitingResponse + 352, // 352: determined.api.v1.Determined.PostTaskLogs:output_type -> determined.api.v1.PostTaskLogsResponse + 353, // 353: determined.api.v1.Determined.TaskLogs:output_type -> determined.api.v1.TaskLogsResponse + 354, // 354: determined.api.v1.Determined.TaskLogsFields:output_type -> determined.api.v1.TaskLogsFieldsResponse + 355, // 355: determined.api.v1.Determined.GetTrialProfilerMetrics:output_type -> determined.api.v1.GetTrialProfilerMetricsResponse + 356, // 356: determined.api.v1.Determined.GetTrialProfilerAvailableSeries:output_type -> determined.api.v1.GetTrialProfilerAvailableSeriesResponse + 357, // 357: determined.api.v1.Determined.PostTrialProfilerMetricsBatch:output_type -> determined.api.v1.PostTrialProfilerMetricsBatchResponse + 358, // 358: determined.api.v1.Determined.GetMetrics:output_type -> determined.api.v1.GetMetricsResponse + 359, // 359: determined.api.v1.Determined.GetTrainingMetrics:output_type -> determined.api.v1.GetTrainingMetricsResponse + 360, // 360: determined.api.v1.Determined.GetValidationMetrics:output_type -> determined.api.v1.GetValidationMetricsResponse + 361, // 361: determined.api.v1.Determined.KillTrial:output_type -> determined.api.v1.KillTrialResponse + 362, // 362: determined.api.v1.Determined.GetTrialCheckpoints:output_type -> determined.api.v1.GetTrialCheckpointsResponse + 363, // 363: determined.api.v1.Determined.CleanupLogs:output_type -> determined.api.v1.CleanupLogsResponse + 364, // 364: determined.api.v1.Determined.AllocationPreemptionSignal:output_type -> determined.api.v1.AllocationPreemptionSignalResponse + 365, // 365: determined.api.v1.Determined.AllocationPendingPreemptionSignal:output_type -> determined.api.v1.AllocationPendingPreemptionSignalResponse + 366, // 366: determined.api.v1.Determined.AckAllocationPreemptionSignal:output_type -> determined.api.v1.AckAllocationPreemptionSignalResponse + 367, // 367: determined.api.v1.Determined.MarkAllocationResourcesDaemon:output_type -> determined.api.v1.MarkAllocationResourcesDaemonResponse + 368, // 368: determined.api.v1.Determined.AllocationRendezvousInfo:output_type -> determined.api.v1.AllocationRendezvousInfoResponse + 369, // 369: determined.api.v1.Determined.PostAllocationProxyAddress:output_type -> determined.api.v1.PostAllocationProxyAddressResponse + 370, // 370: determined.api.v1.Determined.GetTaskAcceleratorData:output_type -> determined.api.v1.GetTaskAcceleratorDataResponse + 371, // 371: determined.api.v1.Determined.PostAllocationAcceleratorData:output_type -> determined.api.v1.PostAllocationAcceleratorDataResponse + 372, // 372: determined.api.v1.Determined.AllocationAllGather:output_type -> determined.api.v1.AllocationAllGatherResponse + 373, // 373: determined.api.v1.Determined.NotifyContainerRunning:output_type -> determined.api.v1.NotifyContainerRunningResponse + 374, // 374: determined.api.v1.Determined.ReportTrialSearcherEarlyExit:output_type -> determined.api.v1.ReportTrialSearcherEarlyExitResponse + 375, // 375: determined.api.v1.Determined.ReportTrialProgress:output_type -> determined.api.v1.ReportTrialProgressResponse + 376, // 376: determined.api.v1.Determined.PostTrialRunnerMetadata:output_type -> determined.api.v1.PostTrialRunnerMetadataResponse + 377, // 377: determined.api.v1.Determined.ReportTrialMetrics:output_type -> determined.api.v1.ReportTrialMetricsResponse + 378, // 378: determined.api.v1.Determined.ReportTrialTrainingMetrics:output_type -> determined.api.v1.ReportTrialTrainingMetricsResponse + 379, // 379: determined.api.v1.Determined.ReportTrialValidationMetrics:output_type -> determined.api.v1.ReportTrialValidationMetricsResponse + 380, // 380: determined.api.v1.Determined.ReportCheckpoint:output_type -> determined.api.v1.ReportCheckpointResponse + 381, // 381: determined.api.v1.Determined.GetJobs:output_type -> determined.api.v1.GetJobsResponse + 382, // 382: determined.api.v1.Determined.GetJobsV2:output_type -> determined.api.v1.GetJobsV2Response + 383, // 383: determined.api.v1.Determined.GetJobQueueStats:output_type -> determined.api.v1.GetJobQueueStatsResponse + 384, // 384: determined.api.v1.Determined.UpdateJobQueue:output_type -> determined.api.v1.UpdateJobQueueResponse + 385, // 385: determined.api.v1.Determined.GetTemplates:output_type -> determined.api.v1.GetTemplatesResponse + 386, // 386: determined.api.v1.Determined.GetTemplate:output_type -> determined.api.v1.GetTemplateResponse + 387, // 387: determined.api.v1.Determined.PutTemplate:output_type -> determined.api.v1.PutTemplateResponse + 388, // 388: determined.api.v1.Determined.PostTemplate:output_type -> determined.api.v1.PostTemplateResponse + 389, // 389: determined.api.v1.Determined.PatchTemplateConfig:output_type -> determined.api.v1.PatchTemplateConfigResponse + 390, // 390: determined.api.v1.Determined.PatchTemplateName:output_type -> determined.api.v1.PatchTemplateNameResponse + 391, // 391: determined.api.v1.Determined.DeleteTemplate:output_type -> determined.api.v1.DeleteTemplateResponse + 392, // 392: determined.api.v1.Determined.GetNotebooks:output_type -> determined.api.v1.GetNotebooksResponse + 393, // 393: determined.api.v1.Determined.GetNotebook:output_type -> determined.api.v1.GetNotebookResponse + 394, // 394: determined.api.v1.Determined.IdleNotebook:output_type -> determined.api.v1.IdleNotebookResponse + 395, // 395: determined.api.v1.Determined.KillNotebook:output_type -> determined.api.v1.KillNotebookResponse + 396, // 396: determined.api.v1.Determined.SetNotebookPriority:output_type -> determined.api.v1.SetNotebookPriorityResponse + 397, // 397: determined.api.v1.Determined.LaunchNotebook:output_type -> determined.api.v1.LaunchNotebookResponse + 398, // 398: determined.api.v1.Determined.GetShells:output_type -> determined.api.v1.GetShellsResponse + 399, // 399: determined.api.v1.Determined.GetShell:output_type -> determined.api.v1.GetShellResponse + 400, // 400: determined.api.v1.Determined.KillShell:output_type -> determined.api.v1.KillShellResponse + 401, // 401: determined.api.v1.Determined.SetShellPriority:output_type -> determined.api.v1.SetShellPriorityResponse + 402, // 402: determined.api.v1.Determined.LaunchShell:output_type -> determined.api.v1.LaunchShellResponse + 403, // 403: determined.api.v1.Determined.GetCommands:output_type -> determined.api.v1.GetCommandsResponse + 404, // 404: determined.api.v1.Determined.GetCommand:output_type -> determined.api.v1.GetCommandResponse + 405, // 405: determined.api.v1.Determined.KillCommand:output_type -> determined.api.v1.KillCommandResponse + 406, // 406: determined.api.v1.Determined.SetCommandPriority:output_type -> determined.api.v1.SetCommandPriorityResponse + 407, // 407: determined.api.v1.Determined.LaunchCommand:output_type -> determined.api.v1.LaunchCommandResponse + 408, // 408: determined.api.v1.Determined.GetTensorboards:output_type -> determined.api.v1.GetTensorboardsResponse + 409, // 409: determined.api.v1.Determined.GetTensorboard:output_type -> determined.api.v1.GetTensorboardResponse + 410, // 410: determined.api.v1.Determined.KillTensorboard:output_type -> determined.api.v1.KillTensorboardResponse + 411, // 411: determined.api.v1.Determined.SetTensorboardPriority:output_type -> determined.api.v1.SetTensorboardPriorityResponse + 412, // 412: determined.api.v1.Determined.LaunchTensorboard:output_type -> determined.api.v1.LaunchTensorboardResponse + 413, // 413: determined.api.v1.Determined.LaunchTensorboardSearches:output_type -> determined.api.v1.LaunchTensorboardSearchesResponse + 414, // 414: determined.api.v1.Determined.DeleteTensorboardFiles:output_type -> determined.api.v1.DeleteTensorboardFilesResponse + 415, // 415: determined.api.v1.Determined.GetActiveTasksCount:output_type -> determined.api.v1.GetActiveTasksCountResponse + 416, // 416: determined.api.v1.Determined.GetTask:output_type -> determined.api.v1.GetTaskResponse + 417, // 417: determined.api.v1.Determined.GetTasks:output_type -> determined.api.v1.GetTasksResponse + 418, // 418: determined.api.v1.Determined.GetModel:output_type -> determined.api.v1.GetModelResponse + 419, // 419: determined.api.v1.Determined.PostModel:output_type -> determined.api.v1.PostModelResponse + 420, // 420: determined.api.v1.Determined.PatchModel:output_type -> determined.api.v1.PatchModelResponse + 421, // 421: determined.api.v1.Determined.ArchiveModel:output_type -> determined.api.v1.ArchiveModelResponse + 422, // 422: determined.api.v1.Determined.UnarchiveModel:output_type -> determined.api.v1.UnarchiveModelResponse + 423, // 423: determined.api.v1.Determined.MoveModel:output_type -> determined.api.v1.MoveModelResponse + 424, // 424: determined.api.v1.Determined.DeleteModel:output_type -> determined.api.v1.DeleteModelResponse + 425, // 425: determined.api.v1.Determined.GetModels:output_type -> determined.api.v1.GetModelsResponse + 426, // 426: determined.api.v1.Determined.GetModelLabels:output_type -> determined.api.v1.GetModelLabelsResponse + 427, // 427: determined.api.v1.Determined.GetModelVersion:output_type -> determined.api.v1.GetModelVersionResponse + 428, // 428: determined.api.v1.Determined.GetModelVersions:output_type -> determined.api.v1.GetModelVersionsResponse + 429, // 429: determined.api.v1.Determined.PostModelVersion:output_type -> determined.api.v1.PostModelVersionResponse + 430, // 430: determined.api.v1.Determined.PatchModelVersion:output_type -> determined.api.v1.PatchModelVersionResponse + 431, // 431: determined.api.v1.Determined.DeleteModelVersion:output_type -> determined.api.v1.DeleteModelVersionResponse + 432, // 432: determined.api.v1.Determined.GetTrialMetricsByModelVersion:output_type -> determined.api.v1.GetTrialMetricsByModelVersionResponse + 433, // 433: determined.api.v1.Determined.GetCheckpoint:output_type -> determined.api.v1.GetCheckpointResponse + 434, // 434: determined.api.v1.Determined.PostCheckpointMetadata:output_type -> determined.api.v1.PostCheckpointMetadataResponse + 435, // 435: determined.api.v1.Determined.CheckpointsRemoveFiles:output_type -> determined.api.v1.CheckpointsRemoveFilesResponse + 436, // 436: determined.api.v1.Determined.PatchCheckpoints:output_type -> determined.api.v1.PatchCheckpointsResponse + 437, // 437: determined.api.v1.Determined.DeleteCheckpoints:output_type -> determined.api.v1.DeleteCheckpointsResponse + 438, // 438: determined.api.v1.Determined.GetTrialMetricsByCheckpoint:output_type -> determined.api.v1.GetTrialMetricsByCheckpointResponse + 439, // 439: determined.api.v1.Determined.ExpMetricNames:output_type -> determined.api.v1.ExpMetricNamesResponse + 440, // 440: determined.api.v1.Determined.MetricBatches:output_type -> determined.api.v1.MetricBatchesResponse + 441, // 441: determined.api.v1.Determined.TrialsSnapshot:output_type -> determined.api.v1.TrialsSnapshotResponse + 442, // 442: determined.api.v1.Determined.TrialsSample:output_type -> determined.api.v1.TrialsSampleResponse + 443, // 443: determined.api.v1.Determined.GetResourcePools:output_type -> determined.api.v1.GetResourcePoolsResponse + 444, // 444: determined.api.v1.Determined.GetKubernetesResourceManagers:output_type -> determined.api.v1.GetKubernetesResourceManagersResponse + 445, // 445: determined.api.v1.Determined.ResourceAllocationRaw:output_type -> determined.api.v1.ResourceAllocationRawResponse + 446, // 446: determined.api.v1.Determined.ResourceAllocationAggregated:output_type -> determined.api.v1.ResourceAllocationAggregatedResponse + 447, // 447: determined.api.v1.Determined.GetWorkspace:output_type -> determined.api.v1.GetWorkspaceResponse + 448, // 448: determined.api.v1.Determined.GetWorkspaceProjects:output_type -> determined.api.v1.GetWorkspaceProjectsResponse + 449, // 449: determined.api.v1.Determined.GetWorkspaces:output_type -> determined.api.v1.GetWorkspacesResponse + 450, // 450: determined.api.v1.Determined.PostWorkspace:output_type -> determined.api.v1.PostWorkspaceResponse + 451, // 451: determined.api.v1.Determined.PatchWorkspace:output_type -> determined.api.v1.PatchWorkspaceResponse + 452, // 452: determined.api.v1.Determined.DeleteWorkspace:output_type -> determined.api.v1.DeleteWorkspaceResponse + 453, // 453: determined.api.v1.Determined.ArchiveWorkspace:output_type -> determined.api.v1.ArchiveWorkspaceResponse + 454, // 454: determined.api.v1.Determined.UnarchiveWorkspace:output_type -> determined.api.v1.UnarchiveWorkspaceResponse + 455, // 455: determined.api.v1.Determined.PinWorkspace:output_type -> determined.api.v1.PinWorkspaceResponse + 456, // 456: determined.api.v1.Determined.UnpinWorkspace:output_type -> determined.api.v1.UnpinWorkspaceResponse + 457, // 457: determined.api.v1.Determined.SetWorkspaceNamespaceBindings:output_type -> determined.api.v1.SetWorkspaceNamespaceBindingsResponse + 458, // 458: determined.api.v1.Determined.SetResourceQuotas:output_type -> determined.api.v1.SetResourceQuotasResponse + 459, // 459: determined.api.v1.Determined.ListWorkspaceNamespaceBindings:output_type -> determined.api.v1.ListWorkspaceNamespaceBindingsResponse + 460, // 460: determined.api.v1.Determined.GetWorkspacesWithDefaultNamespaceBindings:output_type -> determined.api.v1.GetWorkspacesWithDefaultNamespaceBindingsResponse + 461, // 461: determined.api.v1.Determined.BulkAutoCreateWorkspaceNamespaceBindings:output_type -> determined.api.v1.BulkAutoCreateWorkspaceNamespaceBindingsResponse + 462, // 462: determined.api.v1.Determined.DeleteWorkspaceNamespaceBindings:output_type -> determined.api.v1.DeleteWorkspaceNamespaceBindingsResponse + 463, // 463: determined.api.v1.Determined.GetKubernetesResourceQuotas:output_type -> determined.api.v1.GetKubernetesResourceQuotasResponse + 464, // 464: determined.api.v1.Determined.GetProject:output_type -> determined.api.v1.GetProjectResponse + 465, // 465: determined.api.v1.Determined.GetProjectByKey:output_type -> determined.api.v1.GetProjectByKeyResponse + 466, // 466: determined.api.v1.Determined.GetProjectColumns:output_type -> determined.api.v1.GetProjectColumnsResponse + 467, // 467: determined.api.v1.Determined.GetProjectNumericMetricsRange:output_type -> determined.api.v1.GetProjectNumericMetricsRangeResponse + 468, // 468: determined.api.v1.Determined.PostProject:output_type -> determined.api.v1.PostProjectResponse + 469, // 469: determined.api.v1.Determined.AddProjectNote:output_type -> determined.api.v1.AddProjectNoteResponse + 470, // 470: determined.api.v1.Determined.PutProjectNotes:output_type -> determined.api.v1.PutProjectNotesResponse + 471, // 471: determined.api.v1.Determined.PatchProject:output_type -> determined.api.v1.PatchProjectResponse + 472, // 472: determined.api.v1.Determined.DeleteProject:output_type -> determined.api.v1.DeleteProjectResponse + 473, // 473: determined.api.v1.Determined.ArchiveProject:output_type -> determined.api.v1.ArchiveProjectResponse + 474, // 474: determined.api.v1.Determined.UnarchiveProject:output_type -> determined.api.v1.UnarchiveProjectResponse + 475, // 475: determined.api.v1.Determined.MoveProject:output_type -> determined.api.v1.MoveProjectResponse + 476, // 476: determined.api.v1.Determined.MoveExperiment:output_type -> determined.api.v1.MoveExperimentResponse + 477, // 477: determined.api.v1.Determined.MoveExperiments:output_type -> determined.api.v1.MoveExperimentsResponse + 478, // 478: determined.api.v1.Determined.GetWebhooks:output_type -> determined.api.v1.GetWebhooksResponse + 479, // 479: determined.api.v1.Determined.PatchWebhook:output_type -> determined.api.v1.PatchWebhookResponse + 480, // 480: determined.api.v1.Determined.PostWebhook:output_type -> determined.api.v1.PostWebhookResponse + 481, // 481: determined.api.v1.Determined.DeleteWebhook:output_type -> determined.api.v1.DeleteWebhookResponse + 482, // 482: determined.api.v1.Determined.TestWebhook:output_type -> determined.api.v1.TestWebhookResponse + 483, // 483: determined.api.v1.Determined.PostWebhookEventData:output_type -> determined.api.v1.PostWebhookEventDataResponse + 484, // 484: determined.api.v1.Determined.GetGroup:output_type -> determined.api.v1.GetGroupResponse + 485, // 485: determined.api.v1.Determined.GetGroups:output_type -> determined.api.v1.GetGroupsResponse + 486, // 486: determined.api.v1.Determined.CreateGroup:output_type -> determined.api.v1.CreateGroupResponse + 487, // 487: determined.api.v1.Determined.UpdateGroup:output_type -> determined.api.v1.UpdateGroupResponse + 488, // 488: determined.api.v1.Determined.DeleteGroup:output_type -> determined.api.v1.DeleteGroupResponse + 489, // 489: determined.api.v1.Determined.GetPermissionsSummary:output_type -> determined.api.v1.GetPermissionsSummaryResponse + 490, // 490: determined.api.v1.Determined.GetGroupsAndUsersAssignedToWorkspace:output_type -> determined.api.v1.GetGroupsAndUsersAssignedToWorkspaceResponse + 491, // 491: determined.api.v1.Determined.GetRolesByID:output_type -> determined.api.v1.GetRolesByIDResponse + 492, // 492: determined.api.v1.Determined.GetRolesAssignedToUser:output_type -> determined.api.v1.GetRolesAssignedToUserResponse + 493, // 493: determined.api.v1.Determined.GetRolesAssignedToGroup:output_type -> determined.api.v1.GetRolesAssignedToGroupResponse + 494, // 494: determined.api.v1.Determined.SearchRolesAssignableToScope:output_type -> determined.api.v1.SearchRolesAssignableToScopeResponse + 495, // 495: determined.api.v1.Determined.ListRoles:output_type -> determined.api.v1.ListRolesResponse + 496, // 496: determined.api.v1.Determined.AssignRoles:output_type -> determined.api.v1.AssignRolesResponse + 497, // 497: determined.api.v1.Determined.RemoveAssignments:output_type -> determined.api.v1.RemoveAssignmentsResponse + 498, // 498: determined.api.v1.Determined.PostUserActivity:output_type -> determined.api.v1.PostUserActivityResponse + 499, // 499: determined.api.v1.Determined.GetProjectsByUserActivity:output_type -> determined.api.v1.GetProjectsByUserActivityResponse + 500, // 500: determined.api.v1.Determined.SearchExperiments:output_type -> determined.api.v1.SearchExperimentsResponse + 501, // 501: determined.api.v1.Determined.BindRPToWorkspace:output_type -> determined.api.v1.BindRPToWorkspaceResponse + 502, // 502: determined.api.v1.Determined.UnbindRPFromWorkspace:output_type -> determined.api.v1.UnbindRPFromWorkspaceResponse + 503, // 503: determined.api.v1.Determined.OverwriteRPWorkspaceBindings:output_type -> determined.api.v1.OverwriteRPWorkspaceBindingsResponse + 504, // 504: determined.api.v1.Determined.ListRPsBoundToWorkspace:output_type -> determined.api.v1.ListRPsBoundToWorkspaceResponse + 505, // 505: determined.api.v1.Determined.ListWorkspacesBoundToRP:output_type -> determined.api.v1.ListWorkspacesBoundToRPResponse + 506, // 506: determined.api.v1.Determined.GetGenericTaskConfig:output_type -> determined.api.v1.GetGenericTaskConfigResponse + 507, // 507: determined.api.v1.Determined.KillGenericTask:output_type -> determined.api.v1.KillGenericTaskResponse + 508, // 508: determined.api.v1.Determined.PauseGenericTask:output_type -> determined.api.v1.PauseGenericTaskResponse + 509, // 509: determined.api.v1.Determined.UnpauseGenericTask:output_type -> determined.api.v1.UnpauseGenericTaskResponse + 510, // 510: determined.api.v1.Determined.SearchRuns:output_type -> determined.api.v1.SearchRunsResponse + 511, // 511: determined.api.v1.Determined.MoveRuns:output_type -> determined.api.v1.MoveRunsResponse + 512, // 512: determined.api.v1.Determined.KillRuns:output_type -> determined.api.v1.KillRunsResponse + 513, // 513: determined.api.v1.Determined.DeleteRuns:output_type -> determined.api.v1.DeleteRunsResponse + 514, // 514: determined.api.v1.Determined.ArchiveRuns:output_type -> determined.api.v1.ArchiveRunsResponse + 515, // 515: determined.api.v1.Determined.UnarchiveRuns:output_type -> determined.api.v1.UnarchiveRunsResponse + 516, // 516: determined.api.v1.Determined.PauseRuns:output_type -> determined.api.v1.PauseRunsResponse + 517, // 517: determined.api.v1.Determined.ResumeRuns:output_type -> determined.api.v1.ResumeRunsResponse + 518, // 518: determined.api.v1.Determined.GetRunMetadata:output_type -> determined.api.v1.GetRunMetadataResponse + 519, // 519: determined.api.v1.Determined.PostRunMetadata:output_type -> determined.api.v1.PostRunMetadataResponse + 520, // 520: determined.api.v1.Determined.GetMetadataValues:output_type -> determined.api.v1.GetMetadataValuesResponse + 521, // 521: determined.api.v1.Determined.PutWorkspaceConfigPolicies:output_type -> determined.api.v1.PutWorkspaceConfigPoliciesResponse + 522, // 522: determined.api.v1.Determined.PutGlobalConfigPolicies:output_type -> determined.api.v1.PutGlobalConfigPoliciesResponse + 523, // 523: determined.api.v1.Determined.GetWorkspaceConfigPolicies:output_type -> determined.api.v1.GetWorkspaceConfigPoliciesResponse + 524, // 524: determined.api.v1.Determined.GetGlobalConfigPolicies:output_type -> determined.api.v1.GetGlobalConfigPoliciesResponse + 525, // 525: determined.api.v1.Determined.DeleteWorkspaceConfigPolicies:output_type -> determined.api.v1.DeleteWorkspaceConfigPoliciesResponse + 526, // 526: determined.api.v1.Determined.DeleteGlobalConfigPolicies:output_type -> determined.api.v1.DeleteGlobalConfigPoliciesResponse + 527, // 527: determined.api.v1.Determined.MoveSearches:output_type -> determined.api.v1.MoveSearchesResponse + 528, // 528: determined.api.v1.Determined.CancelSearches:output_type -> determined.api.v1.CancelSearchesResponse + 529, // 529: determined.api.v1.Determined.KillSearches:output_type -> determined.api.v1.KillSearchesResponse + 530, // 530: determined.api.v1.Determined.DeleteSearches:output_type -> determined.api.v1.DeleteSearchesResponse + 531, // 531: determined.api.v1.Determined.ArchiveSearches:output_type -> determined.api.v1.ArchiveSearchesResponse + 532, // 532: determined.api.v1.Determined.UnarchiveSearches:output_type -> determined.api.v1.UnarchiveSearchesResponse + 533, // 533: determined.api.v1.Determined.PauseSearches:output_type -> determined.api.v1.PauseSearchesResponse + 534, // 534: determined.api.v1.Determined.ResumeSearches:output_type -> determined.api.v1.ResumeSearchesResponse + 535, // 535: determined.api.v1.Determined.PostAccessToken:output_type -> determined.api.v1.PostAccessTokenResponse + 536, // 536: determined.api.v1.Determined.GetAccessTokens:output_type -> determined.api.v1.GetAccessTokensResponse + 537, // 537: determined.api.v1.Determined.PatchAccessToken:output_type -> determined.api.v1.PatchAccessTokenResponse + 269, // [269:538] is the sub-list for method output_type + 0, // [0:269] is the sub-list for method input_type 0, // [0:0] is the sub-list for extension type_name 0, // [0:0] is the sub-list for extension extendee 0, // [0:0] is the sub-list for field type_name @@ -4471,11 +4400,6 @@ type DeterminedClient interface { // really considered to be in a "Running" state until all the containers // that are part of the experiment are running and not being pulled. NotifyContainerRunning(ctx context.Context, in *NotifyContainerRunningRequest, opts ...grpc.CallOption) (*NotifyContainerRunningResponse, error) - // Get the current searcher operation. - GetCurrentTrialSearcherOperation(ctx context.Context, in *GetCurrentTrialSearcherOperationRequest, opts ...grpc.CallOption) (*GetCurrentTrialSearcherOperationResponse, error) - // Reports to the searcher that the trial has completed the given searcher - // operation. - CompleteTrialSearcherValidation(ctx context.Context, in *CompleteTrialSearcherValidationRequest, opts ...grpc.CallOption) (*CompleteTrialSearcherValidationResponse, error) // Reports to the searcher that the trial has completed the current // requested amount of training with the given searcher validation // metric. @@ -4612,10 +4536,6 @@ type DeterminedClient interface { DeleteCheckpoints(ctx context.Context, in *DeleteCheckpointsRequest, opts ...grpc.CallOption) (*DeleteCheckpointsResponse, error) // Gets the metrics for all trials associated with this checkpoint GetTrialMetricsByCheckpoint(ctx context.Context, in *GetTrialMetricsByCheckpointRequest, opts ...grpc.CallOption) (*GetTrialMetricsByCheckpointResponse, error) - // Get the list of custom searcher events with long polling. - GetSearcherEvents(ctx context.Context, in *GetSearcherEventsRequest, opts ...grpc.CallOption) (*GetSearcherEventsResponse, error) - // Submit operations to a custom searcher. - PostSearcherOperations(ctx context.Context, in *PostSearcherOperationsRequest, opts ...grpc.CallOption) (*PostSearcherOperationsResponse, error) // Get the set of metric names recorded for a list of experiments. ExpMetricNames(ctx context.Context, in *ExpMetricNamesRequest, opts ...grpc.CallOption) (Determined_ExpMetricNamesClient, error) // Get the milestones (in batches processed) at which a metric is recorded by @@ -6010,24 +5930,6 @@ func (c *determinedClient) NotifyContainerRunning(ctx context.Context, in *Notif return out, nil } -func (c *determinedClient) GetCurrentTrialSearcherOperation(ctx context.Context, in *GetCurrentTrialSearcherOperationRequest, opts ...grpc.CallOption) (*GetCurrentTrialSearcherOperationResponse, error) { - out := new(GetCurrentTrialSearcherOperationResponse) - err := c.cc.Invoke(ctx, "/determined.api.v1.Determined/GetCurrentTrialSearcherOperation", in, out, opts...) - if err != nil { - return nil, err - } - return out, nil -} - -func (c *determinedClient) CompleteTrialSearcherValidation(ctx context.Context, in *CompleteTrialSearcherValidationRequest, opts ...grpc.CallOption) (*CompleteTrialSearcherValidationResponse, error) { - out := new(CompleteTrialSearcherValidationResponse) - err := c.cc.Invoke(ctx, "/determined.api.v1.Determined/CompleteTrialSearcherValidation", in, out, opts...) - if err != nil { - return nil, err - } - return out, nil -} - func (c *determinedClient) ReportTrialSearcherEarlyExit(ctx context.Context, in *ReportTrialSearcherEarlyExitRequest, opts ...grpc.CallOption) (*ReportTrialSearcherEarlyExitResponse, error) { out := new(ReportTrialSearcherEarlyExitResponse) err := c.cc.Invoke(ctx, "/determined.api.v1.Determined/ReportTrialSearcherEarlyExit", in, out, opts...) @@ -6616,24 +6518,6 @@ func (c *determinedClient) GetTrialMetricsByCheckpoint(ctx context.Context, in * return out, nil } -func (c *determinedClient) GetSearcherEvents(ctx context.Context, in *GetSearcherEventsRequest, opts ...grpc.CallOption) (*GetSearcherEventsResponse, error) { - out := new(GetSearcherEventsResponse) - err := c.cc.Invoke(ctx, "/determined.api.v1.Determined/GetSearcherEvents", in, out, opts...) - if err != nil { - return nil, err - } - return out, nil -} - -func (c *determinedClient) PostSearcherOperations(ctx context.Context, in *PostSearcherOperationsRequest, opts ...grpc.CallOption) (*PostSearcherOperationsResponse, error) { - out := new(PostSearcherOperationsResponse) - err := c.cc.Invoke(ctx, "/determined.api.v1.Determined/PostSearcherOperations", in, out, opts...) - if err != nil { - return nil, err - } - return out, nil -} - func (c *determinedClient) ExpMetricNames(ctx context.Context, in *ExpMetricNamesRequest, opts ...grpc.CallOption) (Determined_ExpMetricNamesClient, error) { stream, err := c.cc.NewStream(ctx, &_Determined_serviceDesc.Streams[10], "/determined.api.v1.Determined/ExpMetricNames", opts...) if err != nil { @@ -7861,11 +7745,6 @@ type DeterminedServer interface { // really considered to be in a "Running" state until all the containers // that are part of the experiment are running and not being pulled. NotifyContainerRunning(context.Context, *NotifyContainerRunningRequest) (*NotifyContainerRunningResponse, error) - // Get the current searcher operation. - GetCurrentTrialSearcherOperation(context.Context, *GetCurrentTrialSearcherOperationRequest) (*GetCurrentTrialSearcherOperationResponse, error) - // Reports to the searcher that the trial has completed the given searcher - // operation. - CompleteTrialSearcherValidation(context.Context, *CompleteTrialSearcherValidationRequest) (*CompleteTrialSearcherValidationResponse, error) // Reports to the searcher that the trial has completed the current // requested amount of training with the given searcher validation // metric. @@ -8002,10 +7881,6 @@ type DeterminedServer interface { DeleteCheckpoints(context.Context, *DeleteCheckpointsRequest) (*DeleteCheckpointsResponse, error) // Gets the metrics for all trials associated with this checkpoint GetTrialMetricsByCheckpoint(context.Context, *GetTrialMetricsByCheckpointRequest) (*GetTrialMetricsByCheckpointResponse, error) - // Get the list of custom searcher events with long polling. - GetSearcherEvents(context.Context, *GetSearcherEventsRequest) (*GetSearcherEventsResponse, error) - // Submit operations to a custom searcher. - PostSearcherOperations(context.Context, *PostSearcherOperationsRequest) (*PostSearcherOperationsResponse, error) // Get the set of metric names recorded for a list of experiments. ExpMetricNames(*ExpMetricNamesRequest, Determined_ExpMetricNamesServer) error // Get the milestones (in batches processed) at which a metric is recorded by @@ -8533,12 +8408,6 @@ func (*UnimplementedDeterminedServer) AllocationAllGather(context.Context, *Allo func (*UnimplementedDeterminedServer) NotifyContainerRunning(context.Context, *NotifyContainerRunningRequest) (*NotifyContainerRunningResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method NotifyContainerRunning not implemented") } -func (*UnimplementedDeterminedServer) GetCurrentTrialSearcherOperation(context.Context, *GetCurrentTrialSearcherOperationRequest) (*GetCurrentTrialSearcherOperationResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method GetCurrentTrialSearcherOperation not implemented") -} -func (*UnimplementedDeterminedServer) CompleteTrialSearcherValidation(context.Context, *CompleteTrialSearcherValidationRequest) (*CompleteTrialSearcherValidationResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method CompleteTrialSearcherValidation not implemented") -} func (*UnimplementedDeterminedServer) ReportTrialSearcherEarlyExit(context.Context, *ReportTrialSearcherEarlyExitRequest) (*ReportTrialSearcherEarlyExitResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method ReportTrialSearcherEarlyExit not implemented") } @@ -8734,12 +8603,6 @@ func (*UnimplementedDeterminedServer) DeleteCheckpoints(context.Context, *Delete func (*UnimplementedDeterminedServer) GetTrialMetricsByCheckpoint(context.Context, *GetTrialMetricsByCheckpointRequest) (*GetTrialMetricsByCheckpointResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method GetTrialMetricsByCheckpoint not implemented") } -func (*UnimplementedDeterminedServer) GetSearcherEvents(context.Context, *GetSearcherEventsRequest) (*GetSearcherEventsResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method GetSearcherEvents not implemented") -} -func (*UnimplementedDeterminedServer) PostSearcherOperations(context.Context, *PostSearcherOperationsRequest) (*PostSearcherOperationsResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method PostSearcherOperations not implemented") -} func (*UnimplementedDeterminedServer) ExpMetricNames(*ExpMetricNamesRequest, Determined_ExpMetricNamesServer) error { return status.Errorf(codes.Unimplemented, "method ExpMetricNames not implemented") } @@ -10962,42 +10825,6 @@ func _Determined_NotifyContainerRunning_Handler(srv interface{}, ctx context.Con return interceptor(ctx, in, info, handler) } -func _Determined_GetCurrentTrialSearcherOperation_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(GetCurrentTrialSearcherOperationRequest) - if err := dec(in); err != nil { - return nil, err - } - if interceptor == nil { - return srv.(DeterminedServer).GetCurrentTrialSearcherOperation(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: "/determined.api.v1.Determined/GetCurrentTrialSearcherOperation", - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(DeterminedServer).GetCurrentTrialSearcherOperation(ctx, req.(*GetCurrentTrialSearcherOperationRequest)) - } - return interceptor(ctx, in, info, handler) -} - -func _Determined_CompleteTrialSearcherValidation_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(CompleteTrialSearcherValidationRequest) - if err := dec(in); err != nil { - return nil, err - } - if interceptor == nil { - return srv.(DeterminedServer).CompleteTrialSearcherValidation(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: "/determined.api.v1.Determined/CompleteTrialSearcherValidation", - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(DeterminedServer).CompleteTrialSearcherValidation(ctx, req.(*CompleteTrialSearcherValidationRequest)) - } - return interceptor(ctx, in, info, handler) -} - func _Determined_ReportTrialSearcherEarlyExit_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(ReportTrialSearcherEarlyExitRequest) if err := dec(in); err != nil { @@ -12168,42 +11995,6 @@ func _Determined_GetTrialMetricsByCheckpoint_Handler(srv interface{}, ctx contex return interceptor(ctx, in, info, handler) } -func _Determined_GetSearcherEvents_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(GetSearcherEventsRequest) - if err := dec(in); err != nil { - return nil, err - } - if interceptor == nil { - return srv.(DeterminedServer).GetSearcherEvents(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: "/determined.api.v1.Determined/GetSearcherEvents", - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(DeterminedServer).GetSearcherEvents(ctx, req.(*GetSearcherEventsRequest)) - } - return interceptor(ctx, in, info, handler) -} - -func _Determined_PostSearcherOperations_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(PostSearcherOperationsRequest) - if err := dec(in); err != nil { - return nil, err - } - if interceptor == nil { - return srv.(DeterminedServer).PostSearcherOperations(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: "/determined.api.v1.Determined/PostSearcherOperations", - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(DeterminedServer).PostSearcherOperations(ctx, req.(*PostSearcherOperationsRequest)) - } - return interceptor(ctx, in, info, handler) -} - func _Determined_ExpMetricNames_Handler(srv interface{}, stream grpc.ServerStream) error { m := new(ExpMetricNamesRequest) if err := stream.RecvMsg(m); err != nil { @@ -14382,14 +14173,6 @@ var _Determined_serviceDesc = grpc.ServiceDesc{ MethodName: "NotifyContainerRunning", Handler: _Determined_NotifyContainerRunning_Handler, }, - { - MethodName: "GetCurrentTrialSearcherOperation", - Handler: _Determined_GetCurrentTrialSearcherOperation_Handler, - }, - { - MethodName: "CompleteTrialSearcherValidation", - Handler: _Determined_CompleteTrialSearcherValidation_Handler, - }, { MethodName: "ReportTrialSearcherEarlyExit", Handler: _Determined_ReportTrialSearcherEarlyExit_Handler, @@ -14650,14 +14433,6 @@ var _Determined_serviceDesc = grpc.ServiceDesc{ MethodName: "GetTrialMetricsByCheckpoint", Handler: _Determined_GetTrialMetricsByCheckpoint_Handler, }, - { - MethodName: "GetSearcherEvents", - Handler: _Determined_GetSearcherEvents_Handler, - }, - { - MethodName: "PostSearcherOperations", - Handler: _Determined_PostSearcherOperations_Handler, - }, { MethodName: "GetResourcePools", Handler: _Determined_GetResourcePools_Handler, diff --git a/proto/pkg/apiv1/api.pb.gw.go b/proto/pkg/apiv1/api.pb.gw.go index 0053f245473..16d93eff5f8 100644 --- a/proto/pkg/apiv1/api.pb.gw.go +++ b/proto/pkg/apiv1/api.pb.gw.go @@ -5437,130 +5437,6 @@ func local_request_Determined_NotifyContainerRunning_0(ctx context.Context, mars } -func request_Determined_GetCurrentTrialSearcherOperation_0(ctx context.Context, marshaler runtime.Marshaler, client DeterminedClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq GetCurrentTrialSearcherOperationRequest - var metadata runtime.ServerMetadata - - var ( - val string - ok bool - err error - _ = err - ) - - val, ok = pathParams["trial_id"] - if !ok { - return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "trial_id") - } - - protoReq.TrialId, err = runtime.Int32(val) - - if err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "trial_id", err) - } - - msg, err := client.GetCurrentTrialSearcherOperation(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) - return msg, metadata, err - -} - -func local_request_Determined_GetCurrentTrialSearcherOperation_0(ctx context.Context, marshaler runtime.Marshaler, server DeterminedServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq GetCurrentTrialSearcherOperationRequest - var metadata runtime.ServerMetadata - - var ( - val string - ok bool - err error - _ = err - ) - - val, ok = pathParams["trial_id"] - if !ok { - return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "trial_id") - } - - protoReq.TrialId, err = runtime.Int32(val) - - if err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "trial_id", err) - } - - msg, err := server.GetCurrentTrialSearcherOperation(ctx, &protoReq) - return msg, metadata, err - -} - -func request_Determined_CompleteTrialSearcherValidation_0(ctx context.Context, marshaler runtime.Marshaler, client DeterminedClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq CompleteTrialSearcherValidationRequest - var metadata runtime.ServerMetadata - - newReader, berr := utilities.IOReaderFactory(req.Body) - if berr != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) - } - if err := marshaler.NewDecoder(newReader()).Decode(&protoReq.CompletedOperation); err != nil && err != io.EOF { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } - - var ( - val string - ok bool - err error - _ = err - ) - - val, ok = pathParams["trial_id"] - if !ok { - return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "trial_id") - } - - protoReq.TrialId, err = runtime.Int32(val) - - if err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "trial_id", err) - } - - msg, err := client.CompleteTrialSearcherValidation(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) - return msg, metadata, err - -} - -func local_request_Determined_CompleteTrialSearcherValidation_0(ctx context.Context, marshaler runtime.Marshaler, server DeterminedServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq CompleteTrialSearcherValidationRequest - var metadata runtime.ServerMetadata - - newReader, berr := utilities.IOReaderFactory(req.Body) - if berr != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) - } - if err := marshaler.NewDecoder(newReader()).Decode(&protoReq.CompletedOperation); err != nil && err != io.EOF { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } - - var ( - val string - ok bool - err error - _ = err - ) - - val, ok = pathParams["trial_id"] - if !ok { - return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "trial_id") - } - - protoReq.TrialId, err = runtime.Int32(val) - - if err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "trial_id", err) - } - - msg, err := server.CompleteTrialSearcherValidation(ctx, &protoReq) - return msg, metadata, err - -} - func request_Determined_ReportTrialSearcherEarlyExit_0(ctx context.Context, marshaler runtime.Marshaler, client DeterminedClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { var protoReq ReportTrialSearcherEarlyExitRequest var metadata runtime.ServerMetadata @@ -9025,130 +8901,6 @@ func local_request_Determined_GetTrialMetricsByCheckpoint_0(ctx context.Context, } -func request_Determined_GetSearcherEvents_0(ctx context.Context, marshaler runtime.Marshaler, client DeterminedClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq GetSearcherEventsRequest - var metadata runtime.ServerMetadata - - var ( - val string - ok bool - err error - _ = err - ) - - val, ok = pathParams["experiment_id"] - if !ok { - return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "experiment_id") - } - - protoReq.ExperimentId, err = runtime.Int32(val) - - if err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "experiment_id", err) - } - - msg, err := client.GetSearcherEvents(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) - return msg, metadata, err - -} - -func local_request_Determined_GetSearcherEvents_0(ctx context.Context, marshaler runtime.Marshaler, server DeterminedServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq GetSearcherEventsRequest - var metadata runtime.ServerMetadata - - var ( - val string - ok bool - err error - _ = err - ) - - val, ok = pathParams["experiment_id"] - if !ok { - return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "experiment_id") - } - - protoReq.ExperimentId, err = runtime.Int32(val) - - if err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "experiment_id", err) - } - - msg, err := server.GetSearcherEvents(ctx, &protoReq) - return msg, metadata, err - -} - -func request_Determined_PostSearcherOperations_0(ctx context.Context, marshaler runtime.Marshaler, client DeterminedClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq PostSearcherOperationsRequest - var metadata runtime.ServerMetadata - - newReader, berr := utilities.IOReaderFactory(req.Body) - if berr != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) - } - if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } - - var ( - val string - ok bool - err error - _ = err - ) - - val, ok = pathParams["experiment_id"] - if !ok { - return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "experiment_id") - } - - protoReq.ExperimentId, err = runtime.Int32(val) - - if err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "experiment_id", err) - } - - msg, err := client.PostSearcherOperations(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) - return msg, metadata, err - -} - -func local_request_Determined_PostSearcherOperations_0(ctx context.Context, marshaler runtime.Marshaler, server DeterminedServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq PostSearcherOperationsRequest - var metadata runtime.ServerMetadata - - newReader, berr := utilities.IOReaderFactory(req.Body) - if berr != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) - } - if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } - - var ( - val string - ok bool - err error - _ = err - ) - - val, ok = pathParams["experiment_id"] - if !ok { - return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "experiment_id") - } - - protoReq.ExperimentId, err = runtime.Int32(val) - - if err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "experiment_id", err) - } - - msg, err := server.PostSearcherOperations(ctx, &protoReq) - return msg, metadata, err - -} - var ( filter_Determined_ExpMetricNames_0 = &utilities.DoubleArray{Encoding: map[string]int{}, Base: []int(nil), Check: []int(nil)} ) @@ -16084,46 +15836,6 @@ func RegisterDeterminedHandlerServer(ctx context.Context, mux *runtime.ServeMux, }) - mux.Handle("GET", pattern_Determined_GetCurrentTrialSearcherOperation_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - ctx, cancel := context.WithCancel(req.Context()) - defer cancel() - inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - rctx, err := runtime.AnnotateIncomingContext(ctx, mux, req) - if err != nil { - runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) - return - } - resp, md, err := local_request_Determined_GetCurrentTrialSearcherOperation_0(rctx, inboundMarshaler, server, req, pathParams) - ctx = runtime.NewServerMetadataContext(ctx, md) - if err != nil { - runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) - return - } - - forward_Determined_GetCurrentTrialSearcherOperation_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - - }) - - mux.Handle("POST", pattern_Determined_CompleteTrialSearcherValidation_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - ctx, cancel := context.WithCancel(req.Context()) - defer cancel() - inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - rctx, err := runtime.AnnotateIncomingContext(ctx, mux, req) - if err != nil { - runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) - return - } - resp, md, err := local_request_Determined_CompleteTrialSearcherValidation_0(rctx, inboundMarshaler, server, req, pathParams) - ctx = runtime.NewServerMetadataContext(ctx, md) - if err != nil { - runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) - return - } - - forward_Determined_CompleteTrialSearcherValidation_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - - }) - mux.Handle("POST", pattern_Determined_ReportTrialSearcherEarlyExit_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() @@ -17424,46 +17136,6 @@ func RegisterDeterminedHandlerServer(ctx context.Context, mux *runtime.ServeMux, }) - mux.Handle("GET", pattern_Determined_GetSearcherEvents_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - ctx, cancel := context.WithCancel(req.Context()) - defer cancel() - inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - rctx, err := runtime.AnnotateIncomingContext(ctx, mux, req) - if err != nil { - runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) - return - } - resp, md, err := local_request_Determined_GetSearcherEvents_0(rctx, inboundMarshaler, server, req, pathParams) - ctx = runtime.NewServerMetadataContext(ctx, md) - if err != nil { - runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) - return - } - - forward_Determined_GetSearcherEvents_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - - }) - - mux.Handle("POST", pattern_Determined_PostSearcherOperations_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - ctx, cancel := context.WithCancel(req.Context()) - defer cancel() - inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - rctx, err := runtime.AnnotateIncomingContext(ctx, mux, req) - if err != nil { - runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) - return - } - resp, md, err := local_request_Determined_PostSearcherOperations_0(rctx, inboundMarshaler, server, req, pathParams) - ctx = runtime.NewServerMetadataContext(ctx, md) - if err != nil { - runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) - return - } - - forward_Determined_PostSearcherOperations_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - - }) - mux.Handle("GET", pattern_Determined_ExpMetricNames_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { err := status.Error(codes.Unimplemented, "streaming calls are not yet supported in the in-process transport") _, outboundMarshaler := runtime.MarshalerForRequest(mux, req) @@ -21533,46 +21205,6 @@ func RegisterDeterminedHandlerClient(ctx context.Context, mux *runtime.ServeMux, }) - mux.Handle("GET", pattern_Determined_GetCurrentTrialSearcherOperation_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - ctx, cancel := context.WithCancel(req.Context()) - defer cancel() - inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - rctx, err := runtime.AnnotateContext(ctx, mux, req) - if err != nil { - runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) - return - } - resp, md, err := request_Determined_GetCurrentTrialSearcherOperation_0(rctx, inboundMarshaler, client, req, pathParams) - ctx = runtime.NewServerMetadataContext(ctx, md) - if err != nil { - runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) - return - } - - forward_Determined_GetCurrentTrialSearcherOperation_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - - }) - - mux.Handle("POST", pattern_Determined_CompleteTrialSearcherValidation_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - ctx, cancel := context.WithCancel(req.Context()) - defer cancel() - inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - rctx, err := runtime.AnnotateContext(ctx, mux, req) - if err != nil { - runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) - return - } - resp, md, err := request_Determined_CompleteTrialSearcherValidation_0(rctx, inboundMarshaler, client, req, pathParams) - ctx = runtime.NewServerMetadataContext(ctx, md) - if err != nil { - runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) - return - } - - forward_Determined_CompleteTrialSearcherValidation_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - - }) - mux.Handle("POST", pattern_Determined_ReportTrialSearcherEarlyExit_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() @@ -22873,46 +22505,6 @@ func RegisterDeterminedHandlerClient(ctx context.Context, mux *runtime.ServeMux, }) - mux.Handle("GET", pattern_Determined_GetSearcherEvents_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - ctx, cancel := context.WithCancel(req.Context()) - defer cancel() - inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - rctx, err := runtime.AnnotateContext(ctx, mux, req) - if err != nil { - runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) - return - } - resp, md, err := request_Determined_GetSearcherEvents_0(rctx, inboundMarshaler, client, req, pathParams) - ctx = runtime.NewServerMetadataContext(ctx, md) - if err != nil { - runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) - return - } - - forward_Determined_GetSearcherEvents_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - - }) - - mux.Handle("POST", pattern_Determined_PostSearcherOperations_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - ctx, cancel := context.WithCancel(req.Context()) - defer cancel() - inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - rctx, err := runtime.AnnotateContext(ctx, mux, req) - if err != nil { - runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) - return - } - resp, md, err := request_Determined_PostSearcherOperations_0(rctx, inboundMarshaler, client, req, pathParams) - ctx = runtime.NewServerMetadataContext(ctx, md) - if err != nil { - runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) - return - } - - forward_Determined_PostSearcherOperations_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - - }) - mux.Handle("GET", pattern_Determined_ExpMetricNames_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() @@ -25107,10 +24699,6 @@ var ( pattern_Determined_NotifyContainerRunning_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3, 2, 4}, []string{"api", "v1", "allocations", "allocation_id", "notify_container_running"}, "", runtime.AssumeColonVerbOpt(true))) - pattern_Determined_GetCurrentTrialSearcherOperation_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3, 2, 4, 2, 5}, []string{"api", "v1", "trials", "trial_id", "searcher", "operation"}, "", runtime.AssumeColonVerbOpt(true))) - - pattern_Determined_CompleteTrialSearcherValidation_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3, 2, 4, 2, 5}, []string{"api", "v1", "trials", "trial_id", "searcher", "completed_operation"}, "", runtime.AssumeColonVerbOpt(true))) - pattern_Determined_ReportTrialSearcherEarlyExit_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3, 2, 4}, []string{"api", "v1", "trials", "trial_id", "early_exit"}, "", runtime.AssumeColonVerbOpt(true))) pattern_Determined_ReportTrialProgress_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3, 2, 4}, []string{"api", "v1", "trials", "trial_id", "progress"}, "", runtime.AssumeColonVerbOpt(true))) @@ -25241,10 +24829,6 @@ var ( pattern_Determined_GetTrialMetricsByCheckpoint_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3, 2, 4}, []string{"api", "v1", "checkpoints", "checkpoint_uuid", "metrics"}, "", runtime.AssumeColonVerbOpt(true))) - pattern_Determined_GetSearcherEvents_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3, 2, 4}, []string{"api", "v1", "experiments", "experiment_id", "searcher_events"}, "", runtime.AssumeColonVerbOpt(true))) - - pattern_Determined_PostSearcherOperations_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3, 2, 4}, []string{"api", "v1", "experiments", "experiment_id", "searcher_operations"}, "", runtime.AssumeColonVerbOpt(true))) - pattern_Determined_ExpMetricNames_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3, 2, 4}, []string{"api", "v1", "experiments", "metrics-stream", "metric-names"}, "", runtime.AssumeColonVerbOpt(true))) pattern_Determined_MetricBatches_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3, 2, 4, 2, 5}, []string{"api", "v1", "experiments", "experiment_id", "metrics-stream", "batches"}, "", runtime.AssumeColonVerbOpt(true))) @@ -25655,10 +25239,6 @@ var ( forward_Determined_NotifyContainerRunning_0 = runtime.ForwardResponseMessage - forward_Determined_GetCurrentTrialSearcherOperation_0 = runtime.ForwardResponseMessage - - forward_Determined_CompleteTrialSearcherValidation_0 = runtime.ForwardResponseMessage - forward_Determined_ReportTrialSearcherEarlyExit_0 = runtime.ForwardResponseMessage forward_Determined_ReportTrialProgress_0 = runtime.ForwardResponseMessage @@ -25789,10 +25369,6 @@ var ( forward_Determined_GetTrialMetricsByCheckpoint_0 = runtime.ForwardResponseMessage - forward_Determined_GetSearcherEvents_0 = runtime.ForwardResponseMessage - - forward_Determined_PostSearcherOperations_0 = runtime.ForwardResponseMessage - forward_Determined_ExpMetricNames_0 = runtime.ForwardResponseStream forward_Determined_MetricBatches_0 = runtime.ForwardResponseStream diff --git a/proto/pkg/apiv1/experiment.pb.go b/proto/pkg/apiv1/experiment.pb.go index ef0a1286f87..e83c48213da 100644 --- a/proto/pkg/apiv1/experiment.pb.go +++ b/proto/pkg/apiv1/experiment.pb.go @@ -1515,8 +1515,8 @@ type PreviewHPSearchResponse struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - // The resulting simulation. - Simulation *experimentv1.ExperimentSimulation `protobuf:"bytes,1,opt,name=simulation,proto3" json:"simulation,omitempty"` + // The resulting summary. + Summary *experimentv1.SearchSummary `protobuf:"bytes,1,opt,name=summary,proto3" json:"summary,omitempty"` } func (x *PreviewHPSearchResponse) Reset() { @@ -1551,9 +1551,9 @@ func (*PreviewHPSearchResponse) Descriptor() ([]byte, []int) { return file_determined_api_v1_experiment_proto_rawDescGZIP(), []int{22} } -func (x *PreviewHPSearchResponse) GetSimulation() *experimentv1.ExperimentSimulation { +func (x *PreviewHPSearchResponse) GetSummary() *experimentv1.SearchSummary { if x != nil { - return x.Simulation + return x.Summary } return nil } @@ -4908,210 +4908,6 @@ func (x *GetModelDefFileResponse) GetFile() []byte { return nil } -// Request to get the list of searcher events. -type GetSearcherEventsRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - // The ID of the experiment. - ExperimentId int32 `protobuf:"varint,1,opt,name=experiment_id,json=experimentId,proto3" json:"experiment_id,omitempty"` -} - -func (x *GetSearcherEventsRequest) Reset() { - *x = GetSearcherEventsRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_experiment_proto_msgTypes[79] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *GetSearcherEventsRequest) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*GetSearcherEventsRequest) ProtoMessage() {} - -func (x *GetSearcherEventsRequest) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_experiment_proto_msgTypes[79] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use GetSearcherEventsRequest.ProtoReflect.Descriptor instead. -func (*GetSearcherEventsRequest) Descriptor() ([]byte, []int) { - return file_determined_api_v1_experiment_proto_rawDescGZIP(), []int{79} -} - -func (x *GetSearcherEventsRequest) GetExperimentId() int32 { - if x != nil { - return x.ExperimentId - } - return 0 -} - -// Response to GetSearcherEventsRequest. -type GetSearcherEventsResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - // The list of events in the queue. - SearcherEvents []*experimentv1.SearcherEvent `protobuf:"bytes,1,rep,name=searcher_events,json=searcherEvents,proto3" json:"searcher_events,omitempty"` -} - -func (x *GetSearcherEventsResponse) Reset() { - *x = GetSearcherEventsResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_experiment_proto_msgTypes[80] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *GetSearcherEventsResponse) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*GetSearcherEventsResponse) ProtoMessage() {} - -func (x *GetSearcherEventsResponse) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_experiment_proto_msgTypes[80] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use GetSearcherEventsResponse.ProtoReflect.Descriptor instead. -func (*GetSearcherEventsResponse) Descriptor() ([]byte, []int) { - return file_determined_api_v1_experiment_proto_rawDescGZIP(), []int{80} -} - -func (x *GetSearcherEventsResponse) GetSearcherEvents() []*experimentv1.SearcherEvent { - if x != nil { - return x.SearcherEvents - } - return nil -} - -// Request for sending operations from a custom search method. -type PostSearcherOperationsRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - // The experiment ID. - ExperimentId int32 `protobuf:"varint,1,opt,name=experiment_id,json=experimentId,proto3" json:"experiment_id,omitempty"` - // List of operations to submit. - SearcherOperations []*experimentv1.SearcherOperation `protobuf:"bytes,2,rep,name=searcher_operations,json=searcherOperations,proto3" json:"searcher_operations,omitempty"` - // The event that triggered the client to send these operations to the master. - TriggeredByEvent *experimentv1.SearcherEvent `protobuf:"bytes,3,opt,name=triggered_by_event,json=triggeredByEvent,proto3" json:"triggered_by_event,omitempty"` -} - -func (x *PostSearcherOperationsRequest) Reset() { - *x = PostSearcherOperationsRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_experiment_proto_msgTypes[81] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *PostSearcherOperationsRequest) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*PostSearcherOperationsRequest) ProtoMessage() {} - -func (x *PostSearcherOperationsRequest) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_experiment_proto_msgTypes[81] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use PostSearcherOperationsRequest.ProtoReflect.Descriptor instead. -func (*PostSearcherOperationsRequest) Descriptor() ([]byte, []int) { - return file_determined_api_v1_experiment_proto_rawDescGZIP(), []int{81} -} - -func (x *PostSearcherOperationsRequest) GetExperimentId() int32 { - if x != nil { - return x.ExperimentId - } - return 0 -} - -func (x *PostSearcherOperationsRequest) GetSearcherOperations() []*experimentv1.SearcherOperation { - if x != nil { - return x.SearcherOperations - } - return nil -} - -func (x *PostSearcherOperationsRequest) GetTriggeredByEvent() *experimentv1.SearcherEvent { - if x != nil { - return x.TriggeredByEvent - } - return nil -} - -// Response to PostSearcherOperationsResponse. -type PostSearcherOperationsResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields -} - -func (x *PostSearcherOperationsResponse) Reset() { - *x = PostSearcherOperationsResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_experiment_proto_msgTypes[82] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *PostSearcherOperationsResponse) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*PostSearcherOperationsResponse) ProtoMessage() {} - -func (x *PostSearcherOperationsResponse) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_experiment_proto_msgTypes[82] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use PostSearcherOperationsResponse.ProtoReflect.Descriptor instead. -func (*PostSearcherOperationsResponse) Descriptor() ([]byte, []int) { - return file_determined_api_v1_experiment_proto_rawDescGZIP(), []int{82} -} - // Request for searching experiments type SearchExperimentsRequest struct { state protoimpl.MessageState @@ -5133,7 +4929,7 @@ type SearchExperimentsRequest struct { func (x *SearchExperimentsRequest) Reset() { *x = SearchExperimentsRequest{} if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_experiment_proto_msgTypes[83] + mi := &file_determined_api_v1_experiment_proto_msgTypes[79] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5146,7 +4942,7 @@ func (x *SearchExperimentsRequest) String() string { func (*SearchExperimentsRequest) ProtoMessage() {} func (x *SearchExperimentsRequest) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_experiment_proto_msgTypes[83] + mi := &file_determined_api_v1_experiment_proto_msgTypes[79] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5159,7 +4955,7 @@ func (x *SearchExperimentsRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SearchExperimentsRequest.ProtoReflect.Descriptor instead. func (*SearchExperimentsRequest) Descriptor() ([]byte, []int) { - return file_determined_api_v1_experiment_proto_rawDescGZIP(), []int{83} + return file_determined_api_v1_experiment_proto_rawDescGZIP(), []int{79} } func (x *SearchExperimentsRequest) GetProjectId() int32 { @@ -5212,7 +5008,7 @@ type SearchExperimentExperiment struct { func (x *SearchExperimentExperiment) Reset() { *x = SearchExperimentExperiment{} if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_experiment_proto_msgTypes[84] + mi := &file_determined_api_v1_experiment_proto_msgTypes[80] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5225,7 +5021,7 @@ func (x *SearchExperimentExperiment) String() string { func (*SearchExperimentExperiment) ProtoMessage() {} func (x *SearchExperimentExperiment) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_experiment_proto_msgTypes[84] + mi := &file_determined_api_v1_experiment_proto_msgTypes[80] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5238,7 +5034,7 @@ func (x *SearchExperimentExperiment) ProtoReflect() protoreflect.Message { // Deprecated: Use SearchExperimentExperiment.ProtoReflect.Descriptor instead. func (*SearchExperimentExperiment) Descriptor() ([]byte, []int) { - return file_determined_api_v1_experiment_proto_rawDescGZIP(), []int{84} + return file_determined_api_v1_experiment_proto_rawDescGZIP(), []int{80} } func (x *SearchExperimentExperiment) GetExperiment() *experimentv1.Experiment { @@ -5270,7 +5066,7 @@ type SearchExperimentsResponse struct { func (x *SearchExperimentsResponse) Reset() { *x = SearchExperimentsResponse{} if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_experiment_proto_msgTypes[85] + mi := &file_determined_api_v1_experiment_proto_msgTypes[81] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5283,7 +5079,7 @@ func (x *SearchExperimentsResponse) String() string { func (*SearchExperimentsResponse) ProtoMessage() {} func (x *SearchExperimentsResponse) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_experiment_proto_msgTypes[85] + mi := &file_determined_api_v1_experiment_proto_msgTypes[81] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5296,7 +5092,7 @@ func (x *SearchExperimentsResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SearchExperimentsResponse.ProtoReflect.Descriptor instead. func (*SearchExperimentsResponse) Descriptor() ([]byte, []int) { - return file_determined_api_v1_experiment_proto_rawDescGZIP(), []int{85} + return file_determined_api_v1_experiment_proto_rawDescGZIP(), []int{81} } func (x *SearchExperimentsResponse) GetExperiments() []*SearchExperimentExperiment { @@ -5326,7 +5122,7 @@ type DeleteTensorboardFilesRequest struct { func (x *DeleteTensorboardFilesRequest) Reset() { *x = DeleteTensorboardFilesRequest{} if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_experiment_proto_msgTypes[86] + mi := &file_determined_api_v1_experiment_proto_msgTypes[82] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5339,7 +5135,7 @@ func (x *DeleteTensorboardFilesRequest) String() string { func (*DeleteTensorboardFilesRequest) ProtoMessage() {} func (x *DeleteTensorboardFilesRequest) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_experiment_proto_msgTypes[86] + mi := &file_determined_api_v1_experiment_proto_msgTypes[82] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5352,7 +5148,7 @@ func (x *DeleteTensorboardFilesRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use DeleteTensorboardFilesRequest.ProtoReflect.Descriptor instead. func (*DeleteTensorboardFilesRequest) Descriptor() ([]byte, []int) { - return file_determined_api_v1_experiment_proto_rawDescGZIP(), []int{86} + return file_determined_api_v1_experiment_proto_rawDescGZIP(), []int{82} } func (x *DeleteTensorboardFilesRequest) GetExperimentId() int32 { @@ -5372,7 +5168,7 @@ type DeleteTensorboardFilesResponse struct { func (x *DeleteTensorboardFilesResponse) Reset() { *x = DeleteTensorboardFilesResponse{} if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_experiment_proto_msgTypes[87] + mi := &file_determined_api_v1_experiment_proto_msgTypes[83] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5385,7 +5181,7 @@ func (x *DeleteTensorboardFilesResponse) String() string { func (*DeleteTensorboardFilesResponse) ProtoMessage() {} func (x *DeleteTensorboardFilesResponse) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_experiment_proto_msgTypes[87] + mi := &file_determined_api_v1_experiment_proto_msgTypes[83] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5398,7 +5194,7 @@ func (x *DeleteTensorboardFilesResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use DeleteTensorboardFilesResponse.ProtoReflect.Descriptor instead. func (*DeleteTensorboardFilesResponse) Descriptor() ([]byte, []int) { - return file_determined_api_v1_experiment_proto_rawDescGZIP(), []int{87} + return file_determined_api_v1_experiment_proto_rawDescGZIP(), []int{83} } // Metric value and metadata for a trial that has progress this far. @@ -5420,7 +5216,7 @@ type TrialsSnapshotResponse_Trial struct { func (x *TrialsSnapshotResponse_Trial) Reset() { *x = TrialsSnapshotResponse_Trial{} if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_experiment_proto_msgTypes[88] + mi := &file_determined_api_v1_experiment_proto_msgTypes[84] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5433,7 +5229,7 @@ func (x *TrialsSnapshotResponse_Trial) String() string { func (*TrialsSnapshotResponse_Trial) ProtoMessage() {} func (x *TrialsSnapshotResponse_Trial) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_experiment_proto_msgTypes[88] + mi := &file_determined_api_v1_experiment_proto_msgTypes[84] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5495,7 +5291,7 @@ type TrialsSampleResponse_Trial struct { func (x *TrialsSampleResponse_Trial) Reset() { *x = TrialsSampleResponse_Trial{} if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_experiment_proto_msgTypes[89] + mi := &file_determined_api_v1_experiment_proto_msgTypes[85] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5508,7 +5304,7 @@ func (x *TrialsSampleResponse_Trial) String() string { func (*TrialsSampleResponse_Trial) ProtoMessage() {} func (x *TrialsSampleResponse_Trial) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_experiment_proto_msgTypes[89] + mi := &file_determined_api_v1_experiment_proto_msgTypes[85] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5793,295 +5589,268 @@ var file_determined_api_v1_experiment_proto_rawDesc = []byte{ 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x53, 0x74, 0x72, 0x75, 0x63, 0x74, 0x52, 0x06, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x73, 0x65, 0x65, 0x64, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x04, 0x73, 0x65, 0x65, 0x64, 0x22, 0x69, 0x0a, 0x17, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x04, 0x73, 0x65, 0x65, 0x64, 0x22, 0x5c, 0x0a, 0x17, 0x50, 0x72, 0x65, 0x76, 0x69, 0x65, 0x77, 0x48, 0x50, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x4e, 0x0a, 0x0a, 0x73, 0x69, 0x6d, 0x75, 0x6c, - 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x2e, 0x2e, 0x64, 0x65, - 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, - 0x74, 0x53, 0x69, 0x6d, 0x75, 0x6c, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x0a, 0x73, 0x69, 0x6d, - 0x75, 0x6c, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x2b, 0x0a, 0x19, 0x41, 0x63, 0x74, 0x69, 0x76, - 0x61, 0x74, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x65, 0x71, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x41, 0x0a, 0x07, 0x73, 0x75, 0x6d, 0x6d, 0x61, + 0x72, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x27, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, + 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x76, 0x31, 0x2e, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x53, 0x75, 0x6d, 0x6d, 0x61, 0x72, + 0x79, 0x52, 0x07, 0x73, 0x75, 0x6d, 0x6d, 0x61, 0x72, 0x79, 0x22, 0x2b, 0x0a, 0x19, 0x41, 0x63, + 0x74, 0x69, 0x76, 0x61, 0x74, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x05, 0x52, 0x02, 0x69, 0x64, 0x22, 0x1c, 0x0a, 0x1a, 0x41, 0x63, 0x74, 0x69, 0x76, + 0x61, 0x74, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x52, 0x0a, 0x16, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, + 0x65, 0x6e, 0x74, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x12, + 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, + 0x65, 0x72, 0x72, 0x6f, 0x72, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x05, 0x52, 0x02, 0x69, 0x64, 0x3a, 0x12, 0x92, 0x41, 0x0f, 0x0a, 0x0d, 0xd2, 0x01, 0x05, 0x65, + 0x72, 0x72, 0x6f, 0x72, 0xd2, 0x01, 0x02, 0x69, 0x64, 0x22, 0xcb, 0x01, 0x0a, 0x1a, 0x41, 0x63, + 0x74, 0x69, 0x76, 0x61, 0x74, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, + 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x25, 0x0a, 0x0e, 0x65, 0x78, 0x70, 0x65, + 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x05, + 0x52, 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x73, 0x12, + 0x42, 0x0a, 0x07, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x28, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x61, 0x70, + 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x42, 0x75, 0x6c, 0x6b, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, + 0x65, 0x6e, 0x74, 0x46, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x73, 0x52, 0x07, 0x66, 0x69, 0x6c, 0x74, + 0x65, 0x72, 0x73, 0x12, 0x1d, 0x0a, 0x0a, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x5f, 0x69, + 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x09, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, + 0x49, 0x64, 0x3a, 0x23, 0x92, 0x41, 0x20, 0x0a, 0x1e, 0xd2, 0x01, 0x0a, 0x70, 0x72, 0x6f, 0x6a, + 0x65, 0x63, 0x74, 0x5f, 0x69, 0x64, 0xd2, 0x01, 0x0e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, + 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x73, 0x22, 0x73, 0x0a, 0x1b, 0x41, 0x63, 0x74, 0x69, 0x76, + 0x61, 0x74, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x43, 0x0a, 0x07, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, + 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x29, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, + 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x45, 0x78, 0x70, 0x65, + 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x75, + 0x6c, 0x74, 0x52, 0x07, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x3a, 0x0f, 0x92, 0x41, 0x0c, + 0x0a, 0x0a, 0xd2, 0x01, 0x07, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x22, 0x28, 0x0a, 0x16, + 0x50, 0x61, 0x75, 0x73, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x05, 0x52, 0x02, 0x69, 0x64, 0x22, 0x19, 0x0a, 0x17, 0x50, 0x61, 0x75, 0x73, 0x65, 0x45, + 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x22, 0xc8, 0x02, 0x0a, 0x15, 0x42, 0x75, 0x6c, 0x6b, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, + 0x6d, 0x65, 0x6e, 0x74, 0x46, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x73, 0x12, 0x20, 0x0a, 0x0b, 0x64, + 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x0b, 0x64, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x12, 0x0a, + 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, + 0x65, 0x12, 0x16, 0x0a, 0x06, 0x6c, 0x61, 0x62, 0x65, 0x6c, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, + 0x09, 0x52, 0x06, 0x6c, 0x61, 0x62, 0x65, 0x6c, 0x73, 0x12, 0x36, 0x0a, 0x08, 0x61, 0x72, 0x63, + 0x68, 0x69, 0x76, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, + 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x42, 0x6f, + 0x6f, 0x6c, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x52, 0x08, 0x61, 0x72, 0x63, 0x68, 0x69, 0x76, 0x65, + 0x64, 0x12, 0x37, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, + 0x0e, 0x32, 0x1f, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x65, + 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x53, 0x74, 0x61, + 0x74, 0x65, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, 0x12, 0x19, 0x0a, 0x08, 0x75, 0x73, + 0x65, 0x72, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x05, 0x52, 0x07, 0x75, 0x73, + 0x65, 0x72, 0x49, 0x64, 0x73, 0x12, 0x1d, 0x0a, 0x0a, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, + 0x5f, 0x69, 0x64, 0x18, 0x07, 0x20, 0x01, 0x28, 0x05, 0x52, 0x09, 0x70, 0x72, 0x6f, 0x6a, 0x65, + 0x63, 0x74, 0x49, 0x64, 0x12, 0x36, 0x0a, 0x17, 0x65, 0x78, 0x63, 0x6c, 0x75, 0x64, 0x65, 0x64, + 0x5f, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x73, 0x18, + 0x08, 0x20, 0x03, 0x28, 0x05, 0x52, 0x15, 0x65, 0x78, 0x63, 0x6c, 0x75, 0x64, 0x65, 0x64, 0x45, + 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x73, 0x22, 0xc8, 0x01, 0x0a, + 0x17, 0x50, 0x61, 0x75, 0x73, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, + 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x25, 0x0a, 0x0e, 0x65, 0x78, 0x70, 0x65, + 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x05, + 0x52, 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x73, 0x12, + 0x42, 0x0a, 0x07, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x28, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x61, 0x70, + 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x42, 0x75, 0x6c, 0x6b, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, + 0x65, 0x6e, 0x74, 0x46, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x73, 0x52, 0x07, 0x66, 0x69, 0x6c, 0x74, + 0x65, 0x72, 0x73, 0x12, 0x1d, 0x0a, 0x0a, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x5f, 0x69, + 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x09, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, + 0x49, 0x64, 0x3a, 0x23, 0x92, 0x41, 0x20, 0x0a, 0x1e, 0xd2, 0x01, 0x0a, 0x70, 0x72, 0x6f, 0x6a, + 0x65, 0x63, 0x74, 0x5f, 0x69, 0x64, 0xd2, 0x01, 0x0e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, + 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x73, 0x22, 0x70, 0x0a, 0x18, 0x50, 0x61, 0x75, 0x73, 0x65, + 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x12, 0x43, 0x0a, 0x07, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x18, 0x01, + 0x20, 0x03, 0x28, 0x0b, 0x32, 0x29, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, + 0x64, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, + 0x65, 0x6e, 0x74, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x52, + 0x07, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x3a, 0x0f, 0x92, 0x41, 0x0c, 0x0a, 0x0a, 0xd2, + 0x01, 0x07, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x22, 0x29, 0x0a, 0x17, 0x43, 0x61, 0x6e, + 0x63, 0x65, 0x6c, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, - 0x52, 0x02, 0x69, 0x64, 0x22, 0x1c, 0x0a, 0x1a, 0x41, 0x63, 0x74, 0x69, 0x76, 0x61, 0x74, 0x65, + 0x52, 0x02, 0x69, 0x64, 0x22, 0x1a, 0x0a, 0x18, 0x43, 0x61, 0x6e, 0x63, 0x65, 0x6c, 0x45, 0x78, + 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x22, 0xc9, 0x01, 0x0a, 0x18, 0x43, 0x61, 0x6e, 0x63, 0x65, 0x6c, 0x45, 0x78, 0x70, 0x65, 0x72, + 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x25, 0x0a, + 0x0e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x73, 0x18, + 0x01, 0x20, 0x03, 0x28, 0x05, 0x52, 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, + 0x74, 0x49, 0x64, 0x73, 0x12, 0x42, 0x0a, 0x07, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x73, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x28, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, + 0x65, 0x64, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x42, 0x75, 0x6c, 0x6b, 0x45, 0x78, + 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x46, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x73, 0x52, + 0x07, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x73, 0x12, 0x1d, 0x0a, 0x0a, 0x70, 0x72, 0x6f, 0x6a, + 0x65, 0x63, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x09, 0x70, 0x72, + 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x49, 0x64, 0x3a, 0x23, 0x92, 0x41, 0x20, 0x0a, 0x1e, 0xd2, 0x01, + 0x0a, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x5f, 0x69, 0x64, 0xd2, 0x01, 0x0e, 0x65, 0x78, + 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x73, 0x22, 0x71, 0x0a, 0x19, + 0x43, 0x61, 0x6e, 0x63, 0x65, 0x6c, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, + 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x43, 0x0a, 0x07, 0x72, 0x65, 0x73, + 0x75, 0x6c, 0x74, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x29, 0x2e, 0x64, 0x65, 0x74, + 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x45, + 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, + 0x65, 0x73, 0x75, 0x6c, 0x74, 0x52, 0x07, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x3a, 0x0f, + 0x92, 0x41, 0x0c, 0x0a, 0x0a, 0xd2, 0x01, 0x07, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x22, + 0x27, 0x0a, 0x15, 0x4b, 0x69, 0x6c, 0x6c, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, + 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x05, 0x52, 0x02, 0x69, 0x64, 0x22, 0x18, 0x0a, 0x16, 0x4b, 0x69, 0x6c, 0x6c, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, - 0x73, 0x65, 0x22, 0x52, 0x0a, 0x16, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, - 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x12, 0x14, 0x0a, 0x05, - 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, - 0x6f, 0x72, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x02, - 0x69, 0x64, 0x3a, 0x12, 0x92, 0x41, 0x0f, 0x0a, 0x0d, 0xd2, 0x01, 0x05, 0x65, 0x72, 0x72, 0x6f, - 0x72, 0xd2, 0x01, 0x02, 0x69, 0x64, 0x22, 0xcb, 0x01, 0x0a, 0x1a, 0x41, 0x63, 0x74, 0x69, 0x76, - 0x61, 0x74, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x25, 0x0a, 0x0e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, - 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x05, 0x52, 0x0d, 0x65, - 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x73, 0x12, 0x42, 0x0a, 0x07, - 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x28, 0x2e, - 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, - 0x31, 0x2e, 0x42, 0x75, 0x6c, 0x6b, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, - 0x46, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x73, 0x52, 0x07, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x73, - 0x12, 0x1d, 0x0a, 0x0a, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x03, - 0x20, 0x01, 0x28, 0x05, 0x52, 0x09, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x49, 0x64, 0x3a, - 0x23, 0x92, 0x41, 0x20, 0x0a, 0x1e, 0xd2, 0x01, 0x0a, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, - 0x5f, 0x69, 0x64, 0xd2, 0x01, 0x0e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, - 0x5f, 0x69, 0x64, 0x73, 0x22, 0x73, 0x0a, 0x1b, 0x41, 0x63, 0x74, 0x69, 0x76, 0x61, 0x74, 0x65, + 0x73, 0x65, 0x22, 0xc7, 0x01, 0x0a, 0x16, 0x4b, 0x69, 0x6c, 0x6c, 0x45, 0x78, 0x70, 0x65, 0x72, + 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x25, 0x0a, + 0x0e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x73, 0x18, + 0x01, 0x20, 0x03, 0x28, 0x05, 0x52, 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, + 0x74, 0x49, 0x64, 0x73, 0x12, 0x42, 0x0a, 0x07, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x73, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x28, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, + 0x65, 0x64, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x42, 0x75, 0x6c, 0x6b, 0x45, 0x78, + 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x46, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x73, 0x52, + 0x07, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x73, 0x12, 0x1d, 0x0a, 0x0a, 0x70, 0x72, 0x6f, 0x6a, + 0x65, 0x63, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x09, 0x70, 0x72, + 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x49, 0x64, 0x3a, 0x23, 0x92, 0x41, 0x20, 0x0a, 0x1e, 0xd2, 0x01, + 0x0a, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x5f, 0x69, 0x64, 0xd2, 0x01, 0x0e, 0x65, 0x78, + 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x73, 0x22, 0x6f, 0x0a, 0x17, + 0x4b, 0x69, 0x6c, 0x6c, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x43, 0x0a, 0x07, 0x72, 0x65, 0x73, 0x75, 0x6c, + 0x74, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x29, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, + 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x45, 0x78, 0x70, + 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, + 0x75, 0x6c, 0x74, 0x52, 0x07, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x3a, 0x0f, 0x92, 0x41, + 0x0c, 0x0a, 0x0a, 0xd2, 0x01, 0x07, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x22, 0x2a, 0x0a, + 0x18, 0x41, 0x72, 0x63, 0x68, 0x69, 0x76, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, + 0x6e, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x02, 0x69, 0x64, 0x22, 0x1b, 0x0a, 0x19, 0x41, 0x72, 0x63, + 0x68, 0x69, 0x76, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0xca, 0x01, 0x0a, 0x19, 0x41, 0x72, 0x63, 0x68, 0x69, + 0x76, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x12, 0x25, 0x0a, 0x0e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, + 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x05, 0x52, 0x0d, 0x65, 0x78, + 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x73, 0x12, 0x42, 0x0a, 0x07, 0x66, + 0x69, 0x6c, 0x74, 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x28, 0x2e, 0x64, + 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, + 0x2e, 0x42, 0x75, 0x6c, 0x6b, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x46, + 0x69, 0x6c, 0x74, 0x65, 0x72, 0x73, 0x52, 0x07, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x73, 0x12, + 0x1d, 0x0a, 0x0a, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x05, 0x52, 0x09, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x49, 0x64, 0x3a, 0x23, + 0x92, 0x41, 0x20, 0x0a, 0x1e, 0xd2, 0x01, 0x0a, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x5f, + 0x69, 0x64, 0xd2, 0x01, 0x0e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, + 0x69, 0x64, 0x73, 0x22, 0x72, 0x0a, 0x1a, 0x41, 0x72, 0x63, 0x68, 0x69, 0x76, 0x65, 0x45, 0x78, + 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x12, 0x43, 0x0a, 0x07, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x18, 0x01, 0x20, 0x03, + 0x28, 0x0b, 0x32, 0x29, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, + 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, + 0x74, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x52, 0x07, 0x72, + 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x3a, 0x0f, 0x92, 0x41, 0x0c, 0x0a, 0x0a, 0xd2, 0x01, 0x07, + 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x22, 0x2c, 0x0a, 0x1a, 0x55, 0x6e, 0x61, 0x72, 0x63, + 0x68, 0x69, 0x76, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x05, 0x52, 0x02, 0x69, 0x64, 0x22, 0x1d, 0x0a, 0x1b, 0x55, 0x6e, 0x61, 0x72, 0x63, 0x68, 0x69, + 0x76, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x22, 0xcc, 0x01, 0x0a, 0x1b, 0x55, 0x6e, 0x61, 0x72, 0x63, 0x68, 0x69, + 0x76, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x12, 0x25, 0x0a, 0x0e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, + 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x05, 0x52, 0x0d, 0x65, 0x78, + 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x73, 0x12, 0x42, 0x0a, 0x07, 0x66, + 0x69, 0x6c, 0x74, 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x28, 0x2e, 0x64, + 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, + 0x2e, 0x42, 0x75, 0x6c, 0x6b, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x46, + 0x69, 0x6c, 0x74, 0x65, 0x72, 0x73, 0x52, 0x07, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x73, 0x12, + 0x1d, 0x0a, 0x0a, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x05, 0x52, 0x09, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x49, 0x64, 0x3a, 0x23, + 0x92, 0x41, 0x20, 0x0a, 0x1e, 0xd2, 0x01, 0x0a, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x5f, + 0x69, 0x64, 0xd2, 0x01, 0x0e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, + 0x69, 0x64, 0x73, 0x22, 0x74, 0x0a, 0x1c, 0x55, 0x6e, 0x61, 0x72, 0x63, 0x68, 0x69, 0x76, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x43, 0x0a, 0x07, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x29, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x52, 0x07, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x3a, 0x0f, 0x92, 0x41, 0x0c, 0x0a, 0x0a, 0xd2, - 0x01, 0x07, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x22, 0x28, 0x0a, 0x16, 0x50, 0x61, 0x75, - 0x73, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x65, 0x71, 0x75, + 0x01, 0x07, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x22, 0x63, 0x0a, 0x16, 0x50, 0x61, 0x74, + 0x63, 0x68, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x12, 0x49, 0x0a, 0x0a, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, + 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x29, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, + 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x76, 0x31, 0x2e, 0x50, 0x61, 0x74, 0x63, 0x68, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, + 0x6e, 0x74, 0x52, 0x0a, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x22, 0x5f, + 0x0a, 0x17, 0x50, 0x61, 0x74, 0x63, 0x68, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, + 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x44, 0x0a, 0x0a, 0x65, 0x78, 0x70, + 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x24, 0x2e, + 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x65, 0x78, 0x70, 0x65, 0x72, + 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, + 0x65, 0x6e, 0x74, 0x52, 0x0a, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x22, + 0xc8, 0x02, 0x0a, 0x1f, 0x47, 0x65, 0x74, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, + 0x74, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, - 0x02, 0x69, 0x64, 0x22, 0x19, 0x0a, 0x17, 0x50, 0x61, 0x75, 0x73, 0x65, 0x45, 0x78, 0x70, 0x65, - 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0xc8, - 0x02, 0x0a, 0x15, 0x42, 0x75, 0x6c, 0x6b, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, - 0x74, 0x46, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x73, 0x12, 0x20, 0x0a, 0x0b, 0x64, 0x65, 0x73, 0x63, - 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x64, - 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, - 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x16, - 0x0a, 0x06, 0x6c, 0x61, 0x62, 0x65, 0x6c, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, - 0x6c, 0x61, 0x62, 0x65, 0x6c, 0x73, 0x12, 0x36, 0x0a, 0x08, 0x61, 0x72, 0x63, 0x68, 0x69, 0x76, - 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, - 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x42, 0x6f, 0x6f, 0x6c, 0x56, - 0x61, 0x6c, 0x75, 0x65, 0x52, 0x08, 0x61, 0x72, 0x63, 0x68, 0x69, 0x76, 0x65, 0x64, 0x12, 0x37, - 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0e, 0x32, 0x1f, - 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x65, 0x78, 0x70, 0x65, - 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, - 0x06, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, 0x12, 0x19, 0x0a, 0x08, 0x75, 0x73, 0x65, 0x72, 0x5f, - 0x69, 0x64, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x05, 0x52, 0x07, 0x75, 0x73, 0x65, 0x72, 0x49, - 0x64, 0x73, 0x12, 0x1d, 0x0a, 0x0a, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x5f, 0x69, 0x64, - 0x18, 0x07, 0x20, 0x01, 0x28, 0x05, 0x52, 0x09, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x49, - 0x64, 0x12, 0x36, 0x0a, 0x17, 0x65, 0x78, 0x63, 0x6c, 0x75, 0x64, 0x65, 0x64, 0x5f, 0x65, 0x78, - 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x08, 0x20, 0x03, - 0x28, 0x05, 0x52, 0x15, 0x65, 0x78, 0x63, 0x6c, 0x75, 0x64, 0x65, 0x64, 0x45, 0x78, 0x70, 0x65, - 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x73, 0x22, 0xc8, 0x01, 0x0a, 0x17, 0x50, 0x61, - 0x75, 0x73, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x25, 0x0a, 0x0e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, - 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x05, 0x52, 0x0d, 0x65, - 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x73, 0x12, 0x42, 0x0a, 0x07, - 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x28, 0x2e, - 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, - 0x31, 0x2e, 0x42, 0x75, 0x6c, 0x6b, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, - 0x46, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x73, 0x52, 0x07, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x73, - 0x12, 0x1d, 0x0a, 0x0a, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x03, - 0x20, 0x01, 0x28, 0x05, 0x52, 0x09, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x49, 0x64, 0x3a, - 0x23, 0x92, 0x41, 0x20, 0x0a, 0x1e, 0xd2, 0x01, 0x0a, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, - 0x5f, 0x69, 0x64, 0xd2, 0x01, 0x0e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, - 0x5f, 0x69, 0x64, 0x73, 0x22, 0x70, 0x0a, 0x18, 0x50, 0x61, 0x75, 0x73, 0x65, 0x45, 0x78, 0x70, - 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x12, 0x43, 0x0a, 0x07, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, - 0x0b, 0x32, 0x29, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x61, - 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, - 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x52, 0x07, 0x72, 0x65, - 0x73, 0x75, 0x6c, 0x74, 0x73, 0x3a, 0x0f, 0x92, 0x41, 0x0c, 0x0a, 0x0a, 0xd2, 0x01, 0x07, 0x72, - 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x22, 0x29, 0x0a, 0x17, 0x43, 0x61, 0x6e, 0x63, 0x65, 0x6c, - 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x02, 0x69, - 0x64, 0x22, 0x1a, 0x0a, 0x18, 0x43, 0x61, 0x6e, 0x63, 0x65, 0x6c, 0x45, 0x78, 0x70, 0x65, 0x72, - 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0xc9, 0x01, - 0x0a, 0x18, 0x43, 0x61, 0x6e, 0x63, 0x65, 0x6c, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, - 0x6e, 0x74, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x25, 0x0a, 0x0e, 0x65, 0x78, - 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x01, 0x20, 0x03, - 0x28, 0x05, 0x52, 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x64, - 0x73, 0x12, 0x42, 0x0a, 0x07, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x01, - 0x28, 0x0b, 0x32, 0x28, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, - 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x42, 0x75, 0x6c, 0x6b, 0x45, 0x78, 0x70, 0x65, 0x72, - 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x46, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x73, 0x52, 0x07, 0x66, 0x69, - 0x6c, 0x74, 0x65, 0x72, 0x73, 0x12, 0x1d, 0x0a, 0x0a, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, - 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x09, 0x70, 0x72, 0x6f, 0x6a, 0x65, - 0x63, 0x74, 0x49, 0x64, 0x3a, 0x23, 0x92, 0x41, 0x20, 0x0a, 0x1e, 0xd2, 0x01, 0x0a, 0x70, 0x72, - 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x5f, 0x69, 0x64, 0xd2, 0x01, 0x0e, 0x65, 0x78, 0x70, 0x65, 0x72, - 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x73, 0x22, 0x71, 0x0a, 0x19, 0x43, 0x61, 0x6e, - 0x63, 0x65, 0x6c, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x43, 0x0a, 0x07, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, - 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x29, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, - 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x45, 0x78, 0x70, 0x65, - 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x75, - 0x6c, 0x74, 0x52, 0x07, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x3a, 0x0f, 0x92, 0x41, 0x0c, - 0x0a, 0x0a, 0xd2, 0x01, 0x07, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x22, 0x27, 0x0a, 0x15, - 0x4b, 0x69, 0x6c, 0x6c, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x05, 0x52, 0x02, 0x69, 0x64, 0x22, 0x18, 0x0a, 0x16, 0x4b, 0x69, 0x6c, 0x6c, 0x45, 0x78, 0x70, - 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, - 0xc7, 0x01, 0x0a, 0x16, 0x4b, 0x69, 0x6c, 0x6c, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, - 0x6e, 0x74, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x25, 0x0a, 0x0e, 0x65, 0x78, - 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x01, 0x20, 0x03, - 0x28, 0x05, 0x52, 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x64, - 0x73, 0x12, 0x42, 0x0a, 0x07, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x01, - 0x28, 0x0b, 0x32, 0x28, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, - 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x42, 0x75, 0x6c, 0x6b, 0x45, 0x78, 0x70, 0x65, 0x72, - 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x46, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x73, 0x52, 0x07, 0x66, 0x69, - 0x6c, 0x74, 0x65, 0x72, 0x73, 0x12, 0x1d, 0x0a, 0x0a, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, - 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x09, 0x70, 0x72, 0x6f, 0x6a, 0x65, - 0x63, 0x74, 0x49, 0x64, 0x3a, 0x23, 0x92, 0x41, 0x20, 0x0a, 0x1e, 0xd2, 0x01, 0x0a, 0x70, 0x72, - 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x5f, 0x69, 0x64, 0xd2, 0x01, 0x0e, 0x65, 0x78, 0x70, 0x65, 0x72, - 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x73, 0x22, 0x6f, 0x0a, 0x17, 0x4b, 0x69, 0x6c, - 0x6c, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x43, 0x0a, 0x07, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x18, - 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x29, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, - 0x65, 0x64, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, - 0x6d, 0x65, 0x6e, 0x74, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, - 0x52, 0x07, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x3a, 0x0f, 0x92, 0x41, 0x0c, 0x0a, 0x0a, - 0xd2, 0x01, 0x07, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x22, 0x2a, 0x0a, 0x18, 0x41, 0x72, - 0x63, 0x68, 0x69, 0x76, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x05, 0x52, 0x02, 0x69, 0x64, 0x22, 0x1b, 0x0a, 0x19, 0x41, 0x72, 0x63, 0x68, 0x69, 0x76, - 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0xca, 0x01, 0x0a, 0x19, 0x41, 0x72, 0x63, 0x68, 0x69, 0x76, 0x65, 0x45, - 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x12, 0x25, 0x0a, 0x0e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, - 0x69, 0x64, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x05, 0x52, 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, - 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x73, 0x12, 0x42, 0x0a, 0x07, 0x66, 0x69, 0x6c, 0x74, - 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x28, 0x2e, 0x64, 0x65, 0x74, 0x65, - 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x42, 0x75, - 0x6c, 0x6b, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x46, 0x69, 0x6c, 0x74, - 0x65, 0x72, 0x73, 0x52, 0x07, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x73, 0x12, 0x1d, 0x0a, 0x0a, - 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, - 0x52, 0x09, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x49, 0x64, 0x3a, 0x23, 0x92, 0x41, 0x20, - 0x0a, 0x1e, 0xd2, 0x01, 0x0a, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x5f, 0x69, 0x64, 0xd2, - 0x01, 0x0e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x73, - 0x22, 0x72, 0x0a, 0x1a, 0x41, 0x72, 0x63, 0x68, 0x69, 0x76, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, - 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x43, - 0x0a, 0x07, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, - 0x29, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x61, 0x70, 0x69, - 0x2e, 0x76, 0x31, 0x2e, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x41, 0x63, - 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x52, 0x07, 0x72, 0x65, 0x73, 0x75, - 0x6c, 0x74, 0x73, 0x3a, 0x0f, 0x92, 0x41, 0x0c, 0x0a, 0x0a, 0xd2, 0x01, 0x07, 0x72, 0x65, 0x73, - 0x75, 0x6c, 0x74, 0x73, 0x22, 0x2c, 0x0a, 0x1a, 0x55, 0x6e, 0x61, 0x72, 0x63, 0x68, 0x69, 0x76, - 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x02, - 0x69, 0x64, 0x22, 0x1d, 0x0a, 0x1b, 0x55, 0x6e, 0x61, 0x72, 0x63, 0x68, 0x69, 0x76, 0x65, 0x45, - 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, - 0x65, 0x22, 0xcc, 0x01, 0x0a, 0x1b, 0x55, 0x6e, 0x61, 0x72, 0x63, 0x68, 0x69, 0x76, 0x65, 0x45, - 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x12, 0x25, 0x0a, 0x0e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, - 0x69, 0x64, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x05, 0x52, 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, - 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x73, 0x12, 0x42, 0x0a, 0x07, 0x66, 0x69, 0x6c, 0x74, - 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x28, 0x2e, 0x64, 0x65, 0x74, 0x65, - 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x42, 0x75, - 0x6c, 0x6b, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x46, 0x69, 0x6c, 0x74, - 0x65, 0x72, 0x73, 0x52, 0x07, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x73, 0x12, 0x1d, 0x0a, 0x0a, - 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, - 0x52, 0x09, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x49, 0x64, 0x3a, 0x23, 0x92, 0x41, 0x20, - 0x0a, 0x1e, 0xd2, 0x01, 0x0a, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x5f, 0x69, 0x64, 0xd2, - 0x01, 0x0e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x73, - 0x22, 0x74, 0x0a, 0x1c, 0x55, 0x6e, 0x61, 0x72, 0x63, 0x68, 0x69, 0x76, 0x65, 0x45, 0x78, 0x70, - 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x12, 0x43, 0x0a, 0x07, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, - 0x0b, 0x32, 0x29, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x61, - 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, - 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x52, 0x07, 0x72, 0x65, - 0x73, 0x75, 0x6c, 0x74, 0x73, 0x3a, 0x0f, 0x92, 0x41, 0x0c, 0x0a, 0x0a, 0xd2, 0x01, 0x07, 0x72, - 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x22, 0x63, 0x0a, 0x16, 0x50, 0x61, 0x74, 0x63, 0x68, 0x45, - 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, - 0x12, 0x49, 0x0a, 0x0a, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x0b, 0x32, 0x29, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, - 0x64, 0x2e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x76, 0x31, 0x2e, - 0x50, 0x61, 0x74, 0x63, 0x68, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, - 0x0a, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x22, 0x5f, 0x0a, 0x17, 0x50, - 0x61, 0x74, 0x63, 0x68, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x44, 0x0a, 0x0a, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, - 0x6d, 0x65, 0x6e, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x24, 0x2e, 0x64, 0x65, 0x74, - 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, - 0x52, 0x0a, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x22, 0xc8, 0x02, 0x0a, - 0x1f, 0x47, 0x65, 0x74, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x43, 0x68, - 0x65, 0x63, 0x6b, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, - 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x02, 0x69, 0x64, - 0x12, 0x44, 0x0a, 0x0c, 0x73, 0x6f, 0x72, 0x74, 0x5f, 0x62, 0x79, 0x5f, 0x61, 0x74, 0x74, 0x72, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x20, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, + 0x02, 0x69, 0x64, 0x12, 0x44, 0x0a, 0x0c, 0x73, 0x6f, 0x72, 0x74, 0x5f, 0x62, 0x79, 0x5f, 0x61, + 0x74, 0x74, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x20, 0x2e, 0x64, 0x65, 0x74, 0x65, + 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x63, 0x68, 0x65, 0x63, 0x6b, 0x70, 0x6f, 0x69, 0x6e, + 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x53, 0x6f, 0x72, 0x74, 0x42, 0x79, 0x48, 0x00, 0x52, 0x0a, 0x73, + 0x6f, 0x72, 0x74, 0x42, 0x79, 0x41, 0x74, 0x74, 0x72, 0x12, 0x26, 0x0a, 0x0e, 0x73, 0x6f, 0x72, + 0x74, 0x5f, 0x62, 0x79, 0x5f, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x09, 0x48, 0x00, 0x52, 0x0c, 0x73, 0x6f, 0x72, 0x74, 0x42, 0x79, 0x4d, 0x65, 0x74, 0x72, 0x69, + 0x63, 0x12, 0x35, 0x0a, 0x08, 0x6f, 0x72, 0x64, 0x65, 0x72, 0x5f, 0x62, 0x79, 0x18, 0x04, 0x20, + 0x01, 0x28, 0x0e, 0x32, 0x1a, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, + 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x4f, 0x72, 0x64, 0x65, 0x72, 0x42, 0x79, 0x52, + 0x07, 0x6f, 0x72, 0x64, 0x65, 0x72, 0x42, 0x79, 0x12, 0x16, 0x0a, 0x06, 0x6f, 0x66, 0x66, 0x73, + 0x65, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x05, 0x52, 0x06, 0x6f, 0x66, 0x66, 0x73, 0x65, 0x74, + 0x12, 0x14, 0x0a, 0x05, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x05, 0x52, + 0x05, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x12, 0x37, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, + 0x18, 0x07, 0x20, 0x03, 0x28, 0x0e, 0x32, 0x1f, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x63, 0x68, 0x65, 0x63, 0x6b, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x2e, 0x76, - 0x31, 0x2e, 0x53, 0x6f, 0x72, 0x74, 0x42, 0x79, 0x48, 0x00, 0x52, 0x0a, 0x73, 0x6f, 0x72, 0x74, - 0x42, 0x79, 0x41, 0x74, 0x74, 0x72, 0x12, 0x26, 0x0a, 0x0e, 0x73, 0x6f, 0x72, 0x74, 0x5f, 0x62, - 0x79, 0x5f, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, - 0x52, 0x0c, 0x73, 0x6f, 0x72, 0x74, 0x42, 0x79, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x12, 0x35, - 0x0a, 0x08, 0x6f, 0x72, 0x64, 0x65, 0x72, 0x5f, 0x62, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, - 0x32, 0x1a, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x61, 0x70, - 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x4f, 0x72, 0x64, 0x65, 0x72, 0x42, 0x79, 0x52, 0x07, 0x6f, 0x72, - 0x64, 0x65, 0x72, 0x42, 0x79, 0x12, 0x16, 0x0a, 0x06, 0x6f, 0x66, 0x66, 0x73, 0x65, 0x74, 0x18, - 0x05, 0x20, 0x01, 0x28, 0x05, 0x52, 0x06, 0x6f, 0x66, 0x66, 0x73, 0x65, 0x74, 0x12, 0x14, 0x0a, - 0x05, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x05, 0x52, 0x05, 0x6c, 0x69, - 0x6d, 0x69, 0x74, 0x12, 0x37, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, 0x18, 0x07, 0x20, - 0x03, 0x28, 0x0e, 0x32, 0x1f, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, - 0x2e, 0x63, 0x68, 0x65, 0x63, 0x6b, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x53, - 0x74, 0x61, 0x74, 0x65, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, 0x42, 0x09, 0x0a, 0x07, - 0x73, 0x6f, 0x72, 0x74, 0x5f, 0x62, 0x79, 0x22, 0xcb, 0x01, 0x0a, 0x20, 0x47, 0x65, 0x74, 0x45, - 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x70, 0x6f, - 0x69, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x46, 0x0a, 0x0b, - 0x63, 0x68, 0x65, 0x63, 0x6b, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, - 0x0b, 0x32, 0x24, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x63, - 0x68, 0x65, 0x63, 0x6b, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x43, 0x68, 0x65, - 0x63, 0x6b, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x52, 0x0b, 0x63, 0x68, 0x65, 0x63, 0x6b, 0x70, 0x6f, - 0x69, 0x6e, 0x74, 0x73, 0x12, 0x3d, 0x0a, 0x0a, 0x70, 0x61, 0x67, 0x69, 0x6e, 0x61, 0x74, 0x69, - 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, - 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x50, 0x61, 0x67, - 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x0a, 0x70, 0x61, 0x67, 0x69, 0x6e, 0x61, 0x74, - 0x69, 0x6f, 0x6e, 0x3a, 0x20, 0x92, 0x41, 0x1d, 0x0a, 0x1b, 0xd2, 0x01, 0x0b, 0x63, 0x68, 0x65, - 0x63, 0x6b, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x73, 0xd2, 0x01, 0x0a, 0x70, 0x61, 0x67, 0x69, 0x6e, - 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x4c, 0x0a, 0x25, 0x47, 0x65, 0x74, 0x45, 0x78, 0x70, 0x65, - 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x69, 0x6f, 0x6e, - 0x48, 0x69, 0x73, 0x74, 0x6f, 0x72, 0x79, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x23, - 0x0a, 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0c, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, - 0x74, 0x49, 0x64, 0x22, 0x89, 0x01, 0x0a, 0x26, 0x47, 0x65, 0x74, 0x45, 0x78, 0x70, 0x65, 0x72, - 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x48, - 0x69, 0x73, 0x74, 0x6f, 0x72, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x5f, - 0x0a, 0x12, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x68, 0x69, 0x73, - 0x74, 0x6f, 0x72, 0x79, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x30, 0x2e, 0x64, 0x65, 0x74, - 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x69, 0x6f, 0x6e, - 0x48, 0x69, 0x73, 0x74, 0x6f, 0x72, 0x79, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x11, 0x76, 0x61, - 0x6c, 0x69, 0x64, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x48, 0x69, 0x73, 0x74, 0x6f, 0x72, 0x79, 0x22, - 0xd2, 0x02, 0x0a, 0x17, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, - 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x43, 0x0a, 0x10, 0x6d, - 0x6f, 0x64, 0x65, 0x6c, 0x5f, 0x64, 0x65, 0x66, 0x69, 0x6e, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x18, - 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, - 0x65, 0x64, 0x2e, 0x75, 0x74, 0x69, 0x6c, 0x2e, 0x76, 0x31, 0x2e, 0x46, 0x69, 0x6c, 0x65, 0x52, - 0x0f, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x44, 0x65, 0x66, 0x69, 0x6e, 0x69, 0x74, 0x69, 0x6f, 0x6e, - 0x12, 0x16, 0x0a, 0x06, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x06, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x23, 0x0a, 0x0d, 0x76, 0x61, 0x6c, 0x69, - 0x64, 0x61, 0x74, 0x65, 0x5f, 0x6f, 0x6e, 0x6c, 0x79, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, - 0x0c, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x4f, 0x6e, 0x6c, 0x79, 0x12, 0x1b, 0x0a, - 0x09, 0x70, 0x61, 0x72, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x05, - 0x52, 0x08, 0x70, 0x61, 0x72, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x1a, 0x0a, 0x08, 0x61, 0x63, - 0x74, 0x69, 0x76, 0x61, 0x74, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x61, 0x63, - 0x74, 0x69, 0x76, 0x61, 0x74, 0x65, 0x12, 0x1d, 0x0a, 0x0a, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, - 0x74, 0x5f, 0x69, 0x64, 0x18, 0x06, 0x20, 0x01, 0x28, 0x05, 0x52, 0x09, 0x70, 0x72, 0x6f, 0x6a, - 0x65, 0x63, 0x74, 0x49, 0x64, 0x12, 0x1f, 0x0a, 0x08, 0x74, 0x65, 0x6d, 0x70, 0x6c, 0x61, 0x74, - 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x08, 0x74, 0x65, 0x6d, 0x70, 0x6c, - 0x61, 0x74, 0x65, 0x88, 0x01, 0x01, 0x12, 0x21, 0x0a, 0x09, 0x75, 0x6e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x64, 0x18, 0x28, 0x20, 0x01, 0x28, 0x08, 0x48, 0x01, 0x52, 0x09, 0x75, 0x6e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x64, 0x88, 0x01, 0x01, 0x42, 0x0b, 0x0a, 0x09, 0x5f, 0x74, 0x65, - 0x6d, 0x70, 0x6c, 0x61, 0x74, 0x65, 0x42, 0x0c, 0x0a, 0x0a, 0x5f, 0x75, 0x6e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x64, 0x22, 0xec, 0x01, 0x0a, 0x18, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x45, - 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, - 0x65, 0x12, 0x44, 0x0a, 0x0a, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x24, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, - 0x65, 0x64, 0x2e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x76, 0x31, - 0x2e, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x0a, 0x65, 0x78, 0x70, - 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x12, 0x2f, 0x0a, 0x06, 0x63, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, - 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x53, 0x74, 0x72, 0x75, 0x63, 0x74, - 0x52, 0x06, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x3c, 0x0a, 0x08, 0x77, 0x61, 0x72, 0x6e, - 0x69, 0x6e, 0x67, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0e, 0x32, 0x20, 0x2e, 0x64, 0x65, 0x74, - 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x4c, - 0x61, 0x75, 0x6e, 0x63, 0x68, 0x57, 0x61, 0x72, 0x6e, 0x69, 0x6e, 0x67, 0x52, 0x08, 0x77, 0x61, - 0x72, 0x6e, 0x69, 0x6e, 0x67, 0x73, 0x3a, 0x1b, 0x92, 0x41, 0x18, 0x0a, 0x16, 0xd2, 0x01, 0x0a, - 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0xd2, 0x01, 0x06, 0x63, 0x6f, 0x6e, - 0x66, 0x69, 0x67, 0x22, 0xb4, 0x01, 0x0a, 0x14, 0x50, 0x75, 0x74, 0x45, 0x78, 0x70, 0x65, 0x72, - 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x66, 0x0a, 0x19, - 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x5f, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, - 0x74, 0x5f, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, - 0x2a, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x61, 0x70, 0x69, - 0x2e, 0x76, 0x31, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, - 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x52, 0x17, 0x63, 0x72, 0x65, - 0x61, 0x74, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x12, 0x34, 0x0a, 0x16, 0x65, 0x78, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, - 0x5f, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x29, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x14, 0x65, 0x78, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x45, 0x78, - 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x22, 0xab, 0x01, 0x0a, 0x15, 0x50, - 0x75, 0x74, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x65, 0x73, 0x70, + 0x31, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, 0x42, + 0x09, 0x0a, 0x07, 0x73, 0x6f, 0x72, 0x74, 0x5f, 0x62, 0x79, 0x22, 0xcb, 0x01, 0x0a, 0x20, 0x47, + 0x65, 0x74, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x43, 0x68, 0x65, 0x63, + 0x6b, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, + 0x46, 0x0a, 0x0b, 0x63, 0x68, 0x65, 0x63, 0x6b, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x73, 0x18, 0x01, + 0x20, 0x03, 0x28, 0x0b, 0x32, 0x24, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, + 0x64, 0x2e, 0x63, 0x68, 0x65, 0x63, 0x6b, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x2e, 0x76, 0x31, 0x2e, + 0x43, 0x68, 0x65, 0x63, 0x6b, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x52, 0x0b, 0x63, 0x68, 0x65, 0x63, + 0x6b, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x73, 0x12, 0x3d, 0x0a, 0x0a, 0x70, 0x61, 0x67, 0x69, 0x6e, + 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x64, 0x65, + 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, + 0x50, 0x61, 0x67, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x0a, 0x70, 0x61, 0x67, 0x69, + 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x3a, 0x20, 0x92, 0x41, 0x1d, 0x0a, 0x1b, 0xd2, 0x01, 0x0b, + 0x63, 0x68, 0x65, 0x63, 0x6b, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x73, 0xd2, 0x01, 0x0a, 0x70, 0x61, + 0x67, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x4c, 0x0a, 0x25, 0x47, 0x65, 0x74, 0x45, + 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x48, 0x69, 0x73, 0x74, 0x6f, 0x72, 0x79, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x12, 0x23, 0x0a, 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, + 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0c, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, + 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x22, 0x89, 0x01, 0x0a, 0x26, 0x47, 0x65, 0x74, 0x45, 0x78, + 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x69, + 0x6f, 0x6e, 0x48, 0x69, 0x73, 0x74, 0x6f, 0x72, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x12, 0x5f, 0x0a, 0x12, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, + 0x68, 0x69, 0x73, 0x74, 0x6f, 0x72, 0x79, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x30, 0x2e, + 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x65, 0x78, 0x70, 0x65, 0x72, + 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x48, 0x69, 0x73, 0x74, 0x6f, 0x72, 0x79, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, + 0x11, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x48, 0x69, 0x73, 0x74, 0x6f, + 0x72, 0x79, 0x22, 0xd2, 0x02, 0x0a, 0x17, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x45, 0x78, 0x70, + 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x43, + 0x0a, 0x10, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x5f, 0x64, 0x65, 0x66, 0x69, 0x6e, 0x69, 0x74, 0x69, + 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, + 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x75, 0x74, 0x69, 0x6c, 0x2e, 0x76, 0x31, 0x2e, 0x46, 0x69, + 0x6c, 0x65, 0x52, 0x0f, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x44, 0x65, 0x66, 0x69, 0x6e, 0x69, 0x74, + 0x69, 0x6f, 0x6e, 0x12, 0x16, 0x0a, 0x06, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x06, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x23, 0x0a, 0x0d, 0x76, + 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x5f, 0x6f, 0x6e, 0x6c, 0x79, 0x18, 0x03, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x0c, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x4f, 0x6e, 0x6c, 0x79, + 0x12, 0x1b, 0x0a, 0x09, 0x70, 0x61, 0x72, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x04, 0x20, + 0x01, 0x28, 0x05, 0x52, 0x08, 0x70, 0x61, 0x72, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x1a, 0x0a, + 0x08, 0x61, 0x63, 0x74, 0x69, 0x76, 0x61, 0x74, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x08, 0x61, 0x63, 0x74, 0x69, 0x76, 0x61, 0x74, 0x65, 0x12, 0x1d, 0x0a, 0x0a, 0x70, 0x72, 0x6f, + 0x6a, 0x65, 0x63, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x06, 0x20, 0x01, 0x28, 0x05, 0x52, 0x09, 0x70, + 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x49, 0x64, 0x12, 0x1f, 0x0a, 0x08, 0x74, 0x65, 0x6d, 0x70, + 0x6c, 0x61, 0x74, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x08, 0x74, 0x65, + 0x6d, 0x70, 0x6c, 0x61, 0x74, 0x65, 0x88, 0x01, 0x01, 0x12, 0x21, 0x0a, 0x09, 0x75, 0x6e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x64, 0x18, 0x28, 0x20, 0x01, 0x28, 0x08, 0x48, 0x01, 0x52, 0x09, + 0x75, 0x6e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x64, 0x88, 0x01, 0x01, 0x42, 0x0b, 0x0a, 0x09, + 0x5f, 0x74, 0x65, 0x6d, 0x70, 0x6c, 0x61, 0x74, 0x65, 0x42, 0x0c, 0x0a, 0x0a, 0x5f, 0x75, 0x6e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x64, 0x22, 0xec, 0x01, 0x0a, 0x18, 0x43, 0x72, 0x65, 0x61, + 0x74, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x44, 0x0a, 0x0a, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x24, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, @@ -6089,309 +5858,306 @@ var file_determined_api_v1_experiment_proto_rawDesc = []byte{ 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x12, 0x2f, 0x0a, 0x06, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x53, 0x74, 0x72, - 0x75, 0x63, 0x74, 0x52, 0x06, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x3a, 0x1b, 0x92, 0x41, 0x18, - 0x0a, 0x16, 0xd2, 0x01, 0x0a, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0xd2, - 0x01, 0x06, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x60, 0x0a, 0x19, 0x43, 0x6f, 0x6e, 0x74, - 0x69, 0x6e, 0x75, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x05, 0x52, 0x02, 0x69, 0x64, 0x12, 0x27, 0x0a, 0x0f, 0x6f, 0x76, 0x65, 0x72, 0x72, 0x69, 0x64, - 0x65, 0x5f, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, - 0x6f, 0x76, 0x65, 0x72, 0x72, 0x69, 0x64, 0x65, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x3a, 0x0a, - 0x92, 0x41, 0x07, 0x0a, 0x05, 0xd2, 0x01, 0x02, 0x69, 0x64, 0x22, 0xbd, 0x01, 0x0a, 0x1a, 0x43, + 0x75, 0x63, 0x74, 0x52, 0x06, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x3c, 0x0a, 0x08, 0x77, + 0x61, 0x72, 0x6e, 0x69, 0x6e, 0x67, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0e, 0x32, 0x20, 0x2e, + 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, + 0x31, 0x2e, 0x4c, 0x61, 0x75, 0x6e, 0x63, 0x68, 0x57, 0x61, 0x72, 0x6e, 0x69, 0x6e, 0x67, 0x52, + 0x08, 0x77, 0x61, 0x72, 0x6e, 0x69, 0x6e, 0x67, 0x73, 0x3a, 0x1b, 0x92, 0x41, 0x18, 0x0a, 0x16, + 0xd2, 0x01, 0x0a, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0xd2, 0x01, 0x06, + 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0xb4, 0x01, 0x0a, 0x14, 0x50, 0x75, 0x74, 0x45, 0x78, + 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, + 0x66, 0x0a, 0x19, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x5f, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, + 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x2a, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, + 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x45, 0x78, 0x70, + 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x52, 0x17, + 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x34, 0x0a, 0x16, 0x65, 0x78, 0x74, 0x65, 0x72, + 0x6e, 0x61, 0x6c, 0x5f, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, + 0x64, 0x18, 0x29, 0x20, 0x01, 0x28, 0x09, 0x52, 0x14, 0x65, 0x78, 0x74, 0x65, 0x72, 0x6e, 0x61, + 0x6c, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x22, 0xab, 0x01, + 0x0a, 0x15, 0x50, 0x75, 0x74, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x44, 0x0a, 0x0a, 0x65, 0x78, 0x70, 0x65, 0x72, + 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x24, 0x2e, 0x64, 0x65, + 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, + 0x74, 0x52, 0x0a, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x12, 0x2f, 0x0a, + 0x06, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, + 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, + 0x53, 0x74, 0x72, 0x75, 0x63, 0x74, 0x52, 0x06, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x3a, 0x1b, + 0x92, 0x41, 0x18, 0x0a, 0x16, 0xd2, 0x01, 0x0a, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, + 0x6e, 0x74, 0xd2, 0x01, 0x06, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x60, 0x0a, 0x19, 0x43, 0x6f, 0x6e, 0x74, 0x69, 0x6e, 0x75, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, - 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x44, 0x0a, 0x0a, 0x65, 0x78, 0x70, - 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x24, 0x2e, - 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x65, 0x78, 0x70, 0x65, 0x72, - 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, - 0x65, 0x6e, 0x74, 0x52, 0x0a, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x12, - 0x3c, 0x0a, 0x08, 0x77, 0x61, 0x72, 0x6e, 0x69, 0x6e, 0x67, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, - 0x0e, 0x32, 0x20, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x61, - 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x4c, 0x61, 0x75, 0x6e, 0x63, 0x68, 0x57, 0x61, 0x72, 0x6e, - 0x69, 0x6e, 0x67, 0x52, 0x08, 0x77, 0x61, 0x72, 0x6e, 0x69, 0x6e, 0x67, 0x73, 0x3a, 0x1b, 0x92, - 0x41, 0x18, 0x0a, 0x16, 0xd2, 0x01, 0x0a, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, - 0x74, 0xd2, 0x01, 0x06, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x5b, 0x0a, 0x15, 0x45, 0x78, - 0x70, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x12, 0x1b, 0x0a, 0x03, 0x69, 0x64, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x05, - 0x42, 0x09, 0x92, 0x41, 0x06, 0xd2, 0x01, 0x03, 0x69, 0x64, 0x73, 0x52, 0x03, 0x69, 0x64, 0x73, - 0x12, 0x25, 0x0a, 0x0e, 0x70, 0x65, 0x72, 0x69, 0x6f, 0x64, 0x5f, 0x73, 0x65, 0x63, 0x6f, 0x6e, - 0x64, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0d, 0x70, 0x65, 0x72, 0x69, 0x6f, 0x64, - 0x53, 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x22, 0xf0, 0x01, 0x0a, 0x16, 0x45, 0x78, 0x70, 0x4d, - 0x65, 0x74, 0x72, 0x69, 0x63, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, - 0x73, 0x65, 0x12, 0x29, 0x0a, 0x10, 0x73, 0x65, 0x61, 0x72, 0x63, 0x68, 0x65, 0x72, 0x5f, 0x6d, - 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0f, 0x73, 0x65, - 0x61, 0x72, 0x63, 0x68, 0x65, 0x72, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x12, 0x2d, 0x0a, - 0x10, 0x74, 0x72, 0x61, 0x69, 0x6e, 0x69, 0x6e, 0x67, 0x5f, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, - 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x42, 0x02, 0x18, 0x01, 0x52, 0x0f, 0x74, 0x72, 0x61, - 0x69, 0x6e, 0x69, 0x6e, 0x67, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x12, 0x31, 0x0a, 0x12, - 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x6d, 0x65, 0x74, 0x72, 0x69, - 0x63, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x42, 0x02, 0x18, 0x01, 0x52, 0x11, 0x76, 0x61, - 0x6c, 0x69, 0x64, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x12, - 0x49, 0x0a, 0x0c, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x18, - 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x26, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, - 0x65, 0x64, 0x2e, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x2e, 0x76, 0x31, 0x2e, 0x4d, 0x65, 0x74, - 0x72, 0x69, 0x63, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x66, 0x69, 0x65, 0x72, 0x52, 0x0b, 0x6d, - 0x65, 0x74, 0x72, 0x69, 0x63, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x22, 0x85, 0x02, 0x0a, 0x14, 0x4d, - 0x65, 0x74, 0x72, 0x69, 0x63, 0x42, 0x61, 0x74, 0x63, 0x68, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x12, 0x38, 0x0a, 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, - 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x42, 0x13, 0x92, 0x41, 0x10, 0xd2, - 0x01, 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x52, - 0x0c, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x32, 0x0a, - 0x0b, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, - 0x28, 0x09, 0x42, 0x11, 0x92, 0x41, 0x0e, 0xd2, 0x01, 0x0b, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, - 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x52, 0x0a, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x4e, 0x61, 0x6d, - 0x65, 0x12, 0x42, 0x0a, 0x0b, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x5f, 0x74, 0x79, 0x70, 0x65, - 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1d, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, - 0x6e, 0x65, 0x64, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x4d, 0x65, 0x74, 0x72, 0x69, - 0x63, 0x54, 0x79, 0x70, 0x65, 0x42, 0x02, 0x18, 0x01, 0x52, 0x0a, 0x6d, 0x65, 0x74, 0x72, 0x69, - 0x63, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x18, 0x05, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x25, 0x0a, 0x0e, 0x70, - 0x65, 0x72, 0x69, 0x6f, 0x64, 0x5f, 0x73, 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x18, 0x04, 0x20, - 0x01, 0x28, 0x05, 0x52, 0x0d, 0x70, 0x65, 0x72, 0x69, 0x6f, 0x64, 0x53, 0x65, 0x63, 0x6f, 0x6e, - 0x64, 0x73, 0x22, 0x31, 0x0a, 0x15, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x42, 0x61, 0x74, 0x63, - 0x68, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x62, - 0x61, 0x74, 0x63, 0x68, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x05, 0x52, 0x07, 0x62, 0x61, - 0x74, 0x63, 0x68, 0x65, 0x73, 0x22, 0xf3, 0x02, 0x0a, 0x15, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x73, - 0x53, 0x6e, 0x61, 0x70, 0x73, 0x68, 0x6f, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, - 0x38, 0x0a, 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x42, 0x13, 0x92, 0x41, 0x10, 0xd2, 0x01, 0x0d, 0x65, 0x78, - 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x52, 0x0c, 0x65, 0x78, 0x70, - 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x32, 0x0a, 0x0b, 0x6d, 0x65, 0x74, - 0x72, 0x69, 0x63, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x42, 0x11, - 0x92, 0x41, 0x0e, 0xd2, 0x01, 0x0b, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x5f, 0x6e, 0x61, 0x6d, - 0x65, 0x52, 0x0a, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x42, 0x0a, - 0x0b, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, - 0x28, 0x0e, 0x32, 0x1d, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, - 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x54, 0x79, 0x70, - 0x65, 0x42, 0x02, 0x18, 0x01, 0x52, 0x0a, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x54, 0x79, 0x70, - 0x65, 0x12, 0x14, 0x0a, 0x05, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x05, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x44, 0x0a, 0x11, 0x62, 0x61, 0x74, 0x63, 0x68, - 0x65, 0x73, 0x5f, 0x70, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, - 0x28, 0x05, 0x42, 0x17, 0x92, 0x41, 0x14, 0xd2, 0x01, 0x11, 0x62, 0x61, 0x74, 0x63, 0x68, 0x65, - 0x73, 0x5f, 0x70, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x65, 0x64, 0x52, 0x10, 0x62, 0x61, 0x74, - 0x63, 0x68, 0x65, 0x73, 0x50, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x65, 0x64, 0x12, 0x25, 0x0a, - 0x0e, 0x62, 0x61, 0x74, 0x63, 0x68, 0x65, 0x73, 0x5f, 0x6d, 0x61, 0x72, 0x67, 0x69, 0x6e, 0x18, - 0x05, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0d, 0x62, 0x61, 0x74, 0x63, 0x68, 0x65, 0x73, 0x4d, 0x61, - 0x72, 0x67, 0x69, 0x6e, 0x12, 0x25, 0x0a, 0x0e, 0x70, 0x65, 0x72, 0x69, 0x6f, 0x64, 0x5f, 0x73, - 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x18, 0x06, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0d, 0x70, 0x65, - 0x72, 0x69, 0x6f, 0x64, 0x53, 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x22, 0xc7, 0x02, 0x0a, 0x16, - 0x54, 0x72, 0x69, 0x61, 0x6c, 0x73, 0x53, 0x6e, 0x61, 0x70, 0x73, 0x68, 0x6f, 0x74, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x47, 0x0a, 0x06, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x73, - 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x2f, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, - 0x6e, 0x65, 0x64, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x54, 0x72, 0x69, 0x61, 0x6c, - 0x73, 0x53, 0x6e, 0x61, 0x70, 0x73, 0x68, 0x6f, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, - 0x65, 0x2e, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x52, 0x06, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x73, 0x1a, - 0xd3, 0x01, 0x0a, 0x05, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x12, 0x19, 0x0a, 0x08, 0x74, 0x72, 0x69, - 0x61, 0x6c, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x74, 0x72, 0x69, - 0x61, 0x6c, 0x49, 0x64, 0x12, 0x31, 0x0a, 0x07, 0x68, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x53, 0x74, 0x72, 0x75, 0x63, 0x74, 0x52, 0x07, - 0x68, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x74, 0x72, 0x69, - 0x63, 0x18, 0x03, 0x20, 0x01, 0x28, 0x01, 0x52, 0x06, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x12, - 0x2b, 0x0a, 0x11, 0x62, 0x61, 0x74, 0x63, 0x68, 0x65, 0x73, 0x5f, 0x70, 0x72, 0x6f, 0x63, 0x65, - 0x73, 0x73, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x05, 0x52, 0x10, 0x62, 0x61, 0x74, 0x63, - 0x68, 0x65, 0x73, 0x50, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x65, 0x64, 0x3a, 0x37, 0x92, 0x41, - 0x34, 0x0a, 0x32, 0xd2, 0x01, 0x08, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x69, 0x64, 0xd2, 0x01, - 0x07, 0x68, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x73, 0xd2, 0x01, 0x06, 0x6d, 0x65, 0x74, 0x72, 0x69, - 0x63, 0xd2, 0x01, 0x11, 0x62, 0x61, 0x74, 0x63, 0x68, 0x65, 0x73, 0x5f, 0x70, 0x72, 0x6f, 0x63, - 0x65, 0x73, 0x73, 0x65, 0x64, 0x3a, 0x0e, 0x92, 0x41, 0x0b, 0x0a, 0x09, 0xd2, 0x01, 0x06, 0x74, - 0x72, 0x69, 0x61, 0x6c, 0x73, 0x22, 0x90, 0x03, 0x0a, 0x13, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x73, - 0x53, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x38, 0x0a, - 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x05, 0x42, 0x13, 0x92, 0x41, 0x10, 0xd2, 0x01, 0x0d, 0x65, 0x78, 0x70, 0x65, - 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x52, 0x0c, 0x65, 0x78, 0x70, 0x65, 0x72, - 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x32, 0x0a, 0x0b, 0x6d, 0x65, 0x74, 0x72, 0x69, - 0x63, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x42, 0x11, 0x92, 0x41, - 0x0e, 0xd2, 0x01, 0x0b, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x52, - 0x0a, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x42, 0x0a, 0x0b, 0x6d, - 0x65, 0x74, 0x72, 0x69, 0x63, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, - 0x32, 0x1d, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x61, 0x70, - 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x54, 0x79, 0x70, 0x65, 0x42, - 0x02, 0x18, 0x01, 0x52, 0x0a, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x54, 0x79, 0x70, 0x65, 0x12, - 0x14, 0x0a, 0x05, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, - 0x67, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x1d, 0x0a, 0x0a, 0x6d, 0x61, 0x78, 0x5f, 0x74, 0x72, 0x69, - 0x61, 0x6c, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x05, 0x52, 0x09, 0x6d, 0x61, 0x78, 0x54, 0x72, - 0x69, 0x61, 0x6c, 0x73, 0x12, 0x25, 0x0a, 0x0e, 0x6d, 0x61, 0x78, 0x5f, 0x64, 0x61, 0x74, 0x61, - 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x73, 0x18, 0x05, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0d, 0x6d, 0x61, - 0x78, 0x44, 0x61, 0x74, 0x61, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x73, 0x12, 0x23, 0x0a, 0x0d, 0x73, - 0x74, 0x61, 0x72, 0x74, 0x5f, 0x62, 0x61, 0x74, 0x63, 0x68, 0x65, 0x73, 0x18, 0x06, 0x20, 0x01, - 0x28, 0x05, 0x52, 0x0c, 0x73, 0x74, 0x61, 0x72, 0x74, 0x42, 0x61, 0x74, 0x63, 0x68, 0x65, 0x73, - 0x12, 0x1f, 0x0a, 0x0b, 0x65, 0x6e, 0x64, 0x5f, 0x62, 0x61, 0x74, 0x63, 0x68, 0x65, 0x73, 0x18, - 0x07, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0a, 0x65, 0x6e, 0x64, 0x42, 0x61, 0x74, 0x63, 0x68, 0x65, - 0x73, 0x12, 0x25, 0x0a, 0x0e, 0x70, 0x65, 0x72, 0x69, 0x6f, 0x64, 0x5f, 0x73, 0x65, 0x63, 0x6f, - 0x6e, 0x64, 0x73, 0x18, 0x08, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0d, 0x70, 0x65, 0x72, 0x69, 0x6f, - 0x64, 0x53, 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x22, 0x8d, 0x03, 0x0a, 0x14, 0x54, 0x72, 0x69, - 0x61, 0x6c, 0x73, 0x53, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, - 0x65, 0x12, 0x45, 0x0a, 0x06, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, - 0x0b, 0x32, 0x2d, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x61, - 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x73, 0x53, 0x61, 0x6d, 0x70, - 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2e, 0x54, 0x72, 0x69, 0x61, 0x6c, - 0x52, 0x06, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x73, 0x12, 0x27, 0x0a, 0x0f, 0x70, 0x72, 0x6f, 0x6d, - 0x6f, 0x74, 0x65, 0x64, 0x5f, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, - 0x05, 0x52, 0x0e, 0x70, 0x72, 0x6f, 0x6d, 0x6f, 0x74, 0x65, 0x64, 0x54, 0x72, 0x69, 0x61, 0x6c, - 0x73, 0x12, 0x25, 0x0a, 0x0e, 0x64, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x64, 0x5f, 0x74, 0x72, 0x69, - 0x61, 0x6c, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x05, 0x52, 0x0d, 0x64, 0x65, 0x6d, 0x6f, 0x74, - 0x65, 0x64, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x73, 0x1a, 0xaa, 0x01, 0x0a, 0x05, 0x54, 0x72, 0x69, - 0x61, 0x6c, 0x12, 0x19, 0x0a, 0x08, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x69, 0x64, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x49, 0x64, 0x12, 0x31, 0x0a, - 0x07, 0x68, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, - 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, - 0x2e, 0x53, 0x74, 0x72, 0x75, 0x63, 0x74, 0x52, 0x07, 0x68, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x73, - 0x12, 0x30, 0x0a, 0x04, 0x64, 0x61, 0x74, 0x61, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, - 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x61, 0x70, 0x69, 0x2e, - 0x76, 0x31, 0x2e, 0x44, 0x61, 0x74, 0x61, 0x50, 0x6f, 0x69, 0x6e, 0x74, 0x52, 0x04, 0x64, 0x61, - 0x74, 0x61, 0x3a, 0x21, 0x92, 0x41, 0x1e, 0x0a, 0x1c, 0xd2, 0x01, 0x08, 0x74, 0x72, 0x69, 0x61, - 0x6c, 0x5f, 0x69, 0x64, 0xd2, 0x01, 0x07, 0x68, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x73, 0xd2, 0x01, - 0x04, 0x64, 0x61, 0x74, 0x61, 0x3a, 0x31, 0x92, 0x41, 0x2e, 0x0a, 0x2c, 0xd2, 0x01, 0x06, 0x74, - 0x72, 0x69, 0x61, 0x6c, 0x73, 0xd2, 0x01, 0x0f, 0x70, 0x72, 0x6f, 0x6d, 0x6f, 0x74, 0x65, 0x64, - 0x5f, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x73, 0xd2, 0x01, 0x0e, 0x64, 0x65, 0x6d, 0x6f, 0x74, 0x65, - 0x64, 0x5f, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x73, 0x22, 0x39, 0x0a, 0x12, 0x47, 0x65, 0x74, 0x4d, - 0x6f, 0x64, 0x65, 0x6c, 0x44, 0x65, 0x66, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x23, - 0x0a, 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0c, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, - 0x74, 0x49, 0x64, 0x22, 0x3f, 0x0a, 0x13, 0x47, 0x65, 0x74, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x44, - 0x65, 0x66, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x17, 0x0a, 0x07, 0x62, 0x36, - 0x34, 0x5f, 0x74, 0x67, 0x7a, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x62, 0x36, 0x34, - 0x54, 0x67, 0x7a, 0x3a, 0x0f, 0x92, 0x41, 0x0c, 0x0a, 0x0a, 0xd2, 0x01, 0x07, 0x62, 0x36, 0x34, - 0x5f, 0x74, 0x67, 0x7a, 0x22, 0xa2, 0x01, 0x0a, 0x15, 0x4d, 0x6f, 0x76, 0x65, 0x45, 0x78, 0x70, - 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x23, - 0x0a, 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0c, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, - 0x74, 0x49, 0x64, 0x12, 0x34, 0x0a, 0x16, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, - 0x6f, 0x6e, 0x5f, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x05, 0x52, 0x14, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, - 0x50, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x49, 0x64, 0x3a, 0x2e, 0x92, 0x41, 0x2b, 0x0a, 0x29, - 0xd2, 0x01, 0x16, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x70, - 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x5f, 0x69, 0x64, 0xd2, 0x01, 0x0d, 0x65, 0x78, 0x70, 0x65, - 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x22, 0x18, 0x0a, 0x16, 0x4d, 0x6f, 0x76, - 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x96, 0x02, 0x0a, 0x16, 0x4d, 0x6f, 0x76, 0x65, 0x45, 0x78, 0x70, 0x65, - 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x25, - 0x0a, 0x0e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x73, - 0x18, 0x01, 0x20, 0x03, 0x28, 0x05, 0x52, 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, - 0x6e, 0x74, 0x49, 0x64, 0x73, 0x12, 0x34, 0x0a, 0x16, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, - 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x5f, 0x69, 0x64, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x14, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, - 0x6f, 0x6e, 0x50, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x49, 0x64, 0x12, 0x42, 0x0a, 0x07, 0x66, - 0x69, 0x6c, 0x74, 0x65, 0x72, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x28, 0x2e, 0x64, - 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, - 0x2e, 0x42, 0x75, 0x6c, 0x6b, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x46, - 0x69, 0x6c, 0x74, 0x65, 0x72, 0x73, 0x52, 0x07, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x73, 0x12, - 0x1d, 0x0a, 0x0a, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x04, 0x20, - 0x01, 0x28, 0x05, 0x52, 0x09, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x49, 0x64, 0x3a, 0x3c, - 0x92, 0x41, 0x39, 0x0a, 0x37, 0xd2, 0x01, 0x0a, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x5f, - 0x69, 0x64, 0xd2, 0x01, 0x16, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, - 0x5f, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x5f, 0x69, 0x64, 0xd2, 0x01, 0x0e, 0x65, 0x78, - 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x73, 0x22, 0x6f, 0x0a, 0x17, - 0x4d, 0x6f, 0x76, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x43, 0x0a, 0x07, 0x72, 0x65, 0x73, 0x75, 0x6c, - 0x74, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x29, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, - 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x45, 0x78, 0x70, - 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, - 0x75, 0x6c, 0x74, 0x52, 0x07, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x3a, 0x0f, 0x92, 0x41, - 0x0c, 0x0a, 0x0a, 0xd2, 0x01, 0x07, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x22, 0x3d, 0x0a, - 0x16, 0x47, 0x65, 0x74, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x44, 0x65, 0x66, 0x54, 0x72, 0x65, 0x65, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x23, 0x0a, 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, - 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0c, - 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x22, 0x53, 0x0a, 0x17, - 0x47, 0x65, 0x74, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x44, 0x65, 0x66, 0x54, 0x72, 0x65, 0x65, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x38, 0x0a, 0x05, 0x66, 0x69, 0x6c, 0x65, 0x73, - 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x22, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, - 0x6e, 0x65, 0x64, 0x2e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x76, - 0x31, 0x2e, 0x46, 0x69, 0x6c, 0x65, 0x4e, 0x6f, 0x64, 0x65, 0x52, 0x05, 0x66, 0x69, 0x6c, 0x65, - 0x73, 0x22, 0x51, 0x0a, 0x16, 0x47, 0x65, 0x74, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x44, 0x65, 0x66, - 0x46, 0x69, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x23, 0x0a, 0x0d, 0x65, - 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x05, 0x52, 0x0c, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x64, - 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, - 0x70, 0x61, 0x74, 0x68, 0x22, 0x2d, 0x0a, 0x17, 0x47, 0x65, 0x74, 0x4d, 0x6f, 0x64, 0x65, 0x6c, - 0x44, 0x65, 0x66, 0x46, 0x69, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, - 0x12, 0x0a, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x66, - 0x69, 0x6c, 0x65, 0x22, 0x3f, 0x0a, 0x18, 0x47, 0x65, 0x74, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, - 0x65, 0x72, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, - 0x23, 0x0a, 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0c, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, - 0x6e, 0x74, 0x49, 0x64, 0x22, 0x6d, 0x0a, 0x19, 0x47, 0x65, 0x74, 0x53, 0x65, 0x61, 0x72, 0x63, - 0x68, 0x65, 0x72, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, - 0x65, 0x12, 0x50, 0x0a, 0x0f, 0x73, 0x65, 0x61, 0x72, 0x63, 0x68, 0x65, 0x72, 0x5f, 0x65, 0x76, - 0x65, 0x6e, 0x74, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x27, 0x2e, 0x64, 0x65, 0x74, - 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x65, 0x72, 0x45, 0x76, - 0x65, 0x6e, 0x74, 0x52, 0x0e, 0x73, 0x65, 0x61, 0x72, 0x63, 0x68, 0x65, 0x72, 0x45, 0x76, 0x65, - 0x6e, 0x74, 0x73, 0x22, 0xf9, 0x01, 0x0a, 0x1d, 0x50, 0x6f, 0x73, 0x74, 0x53, 0x65, 0x61, 0x72, - 0x63, 0x68, 0x65, 0x72, 0x4f, 0x70, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x23, 0x0a, 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, - 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0c, 0x65, 0x78, - 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x5c, 0x0a, 0x13, 0x73, 0x65, - 0x61, 0x72, 0x63, 0x68, 0x65, 0x72, 0x5f, 0x6f, 0x70, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, - 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x2b, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, - 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x76, 0x31, 0x2e, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x65, 0x72, 0x4f, 0x70, 0x65, 0x72, 0x61, - 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x12, 0x73, 0x65, 0x61, 0x72, 0x63, 0x68, 0x65, 0x72, 0x4f, 0x70, - 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x55, 0x0a, 0x12, 0x74, 0x72, 0x69, 0x67, - 0x67, 0x65, 0x72, 0x65, 0x64, 0x5f, 0x62, 0x79, 0x5f, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x18, 0x03, - 0x20, 0x01, 0x28, 0x0b, 0x32, 0x27, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, - 0x64, 0x2e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x76, 0x31, 0x2e, - 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x65, 0x72, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x52, 0x10, 0x74, - 0x72, 0x69, 0x67, 0x67, 0x65, 0x72, 0x65, 0x64, 0x42, 0x79, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x22, - 0x20, 0x0a, 0x1e, 0x50, 0x6f, 0x73, 0x74, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x65, 0x72, 0x4f, - 0x70, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, - 0x65, 0x22, 0xc5, 0x01, 0x0a, 0x18, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x45, 0x78, 0x70, 0x65, - 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x22, - 0x0a, 0x0a, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x05, 0x48, 0x00, 0x52, 0x09, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x49, 0x64, 0x88, - 0x01, 0x01, 0x12, 0x16, 0x0a, 0x06, 0x6f, 0x66, 0x66, 0x73, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, - 0x28, 0x05, 0x52, 0x06, 0x6f, 0x66, 0x66, 0x73, 0x65, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x6c, 0x69, - 0x6d, 0x69, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x05, 0x6c, 0x69, 0x6d, 0x69, 0x74, - 0x12, 0x17, 0x0a, 0x04, 0x73, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x48, 0x01, - 0x52, 0x04, 0x73, 0x6f, 0x72, 0x74, 0x88, 0x01, 0x01, 0x12, 0x1b, 0x0a, 0x06, 0x66, 0x69, 0x6c, - 0x74, 0x65, 0x72, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x48, 0x02, 0x52, 0x06, 0x66, 0x69, 0x6c, - 0x74, 0x65, 0x72, 0x88, 0x01, 0x01, 0x42, 0x0d, 0x0a, 0x0b, 0x5f, 0x70, 0x72, 0x6f, 0x6a, 0x65, - 0x63, 0x74, 0x5f, 0x69, 0x64, 0x42, 0x07, 0x0a, 0x05, 0x5f, 0x73, 0x6f, 0x72, 0x74, 0x42, 0x09, - 0x0a, 0x07, 0x5f, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x22, 0xb1, 0x01, 0x0a, 0x1a, 0x53, 0x65, - 0x61, 0x72, 0x63, 0x68, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x45, 0x78, - 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x12, 0x44, 0x0a, 0x0a, 0x65, 0x78, 0x70, 0x65, - 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x24, 0x2e, 0x64, - 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, - 0x6e, 0x74, 0x52, 0x0a, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x12, 0x39, - 0x0a, 0x0a, 0x62, 0x65, 0x73, 0x74, 0x5f, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x18, 0x02, 0x20, 0x01, - 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, - 0x74, 0x72, 0x69, 0x61, 0x6c, 0x2e, 0x76, 0x31, 0x2e, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x52, 0x09, - 0x62, 0x65, 0x73, 0x74, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x3a, 0x12, 0x92, 0x41, 0x0f, 0x0a, 0x0d, - 0xd2, 0x01, 0x0a, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x22, 0xcd, 0x01, - 0x0a, 0x19, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, - 0x6e, 0x74, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x4f, 0x0a, 0x0b, 0x65, - 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, - 0x32, 0x2d, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x61, 0x70, - 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x45, 0x78, 0x70, 0x65, 0x72, - 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, - 0x0b, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x12, 0x3d, 0x0a, 0x0a, - 0x70, 0x61, 0x67, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, - 0x32, 0x1d, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x61, 0x70, - 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x50, 0x61, 0x67, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, - 0x0a, 0x70, 0x61, 0x67, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x3a, 0x20, 0x92, 0x41, 0x1d, - 0x0a, 0x1b, 0xd2, 0x01, 0x0b, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, - 0xd2, 0x01, 0x0a, 0x70, 0x61, 0x67, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x5b, 0x0a, - 0x1d, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x54, 0x65, 0x6e, 0x73, 0x6f, 0x72, 0x62, 0x6f, 0x61, - 0x72, 0x64, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x23, + 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x05, 0x52, 0x02, 0x69, 0x64, 0x12, 0x27, 0x0a, 0x0f, 0x6f, 0x76, 0x65, 0x72, + 0x72, 0x69, 0x64, 0x65, 0x5f, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x0e, 0x6f, 0x76, 0x65, 0x72, 0x72, 0x69, 0x64, 0x65, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x3a, 0x0a, 0x92, 0x41, 0x07, 0x0a, 0x05, 0xd2, 0x01, 0x02, 0x69, 0x64, 0x22, 0xbd, 0x01, + 0x0a, 0x1a, 0x43, 0x6f, 0x6e, 0x74, 0x69, 0x6e, 0x75, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, + 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x44, 0x0a, 0x0a, + 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x24, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x65, 0x78, + 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x45, 0x78, 0x70, 0x65, + 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x0a, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, + 0x6e, 0x74, 0x12, 0x3c, 0x0a, 0x08, 0x77, 0x61, 0x72, 0x6e, 0x69, 0x6e, 0x67, 0x73, 0x18, 0x02, + 0x20, 0x03, 0x28, 0x0e, 0x32, 0x20, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, + 0x64, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x4c, 0x61, 0x75, 0x6e, 0x63, 0x68, 0x57, + 0x61, 0x72, 0x6e, 0x69, 0x6e, 0x67, 0x52, 0x08, 0x77, 0x61, 0x72, 0x6e, 0x69, 0x6e, 0x67, 0x73, + 0x3a, 0x1b, 0x92, 0x41, 0x18, 0x0a, 0x16, 0xd2, 0x01, 0x0a, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, + 0x6d, 0x65, 0x6e, 0x74, 0xd2, 0x01, 0x06, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x5b, 0x0a, + 0x15, 0x45, 0x78, 0x70, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1b, 0x0a, 0x03, 0x69, 0x64, 0x73, 0x18, 0x01, 0x20, + 0x03, 0x28, 0x05, 0x42, 0x09, 0x92, 0x41, 0x06, 0xd2, 0x01, 0x03, 0x69, 0x64, 0x73, 0x52, 0x03, + 0x69, 0x64, 0x73, 0x12, 0x25, 0x0a, 0x0e, 0x70, 0x65, 0x72, 0x69, 0x6f, 0x64, 0x5f, 0x73, 0x65, + 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0d, 0x70, 0x65, 0x72, + 0x69, 0x6f, 0x64, 0x53, 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x22, 0xf0, 0x01, 0x0a, 0x16, 0x45, + 0x78, 0x70, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x29, 0x0a, 0x10, 0x73, 0x65, 0x61, 0x72, 0x63, 0x68, 0x65, + 0x72, 0x5f, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, + 0x0f, 0x73, 0x65, 0x61, 0x72, 0x63, 0x68, 0x65, 0x72, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, + 0x12, 0x2d, 0x0a, 0x10, 0x74, 0x72, 0x61, 0x69, 0x6e, 0x69, 0x6e, 0x67, 0x5f, 0x6d, 0x65, 0x74, + 0x72, 0x69, 0x63, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x42, 0x02, 0x18, 0x01, 0x52, 0x0f, + 0x74, 0x72, 0x61, 0x69, 0x6e, 0x69, 0x6e, 0x67, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x12, + 0x31, 0x0a, 0x12, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x6d, 0x65, + 0x74, 0x72, 0x69, 0x63, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x42, 0x02, 0x18, 0x01, 0x52, + 0x11, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x72, 0x69, + 0x63, 0x73, 0x12, 0x49, 0x0a, 0x0c, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x5f, 0x6e, 0x61, 0x6d, + 0x65, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x26, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, + 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x2e, 0x76, 0x31, 0x2e, + 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x66, 0x69, 0x65, 0x72, + 0x52, 0x0b, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x22, 0x85, 0x02, + 0x0a, 0x14, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x42, 0x61, 0x74, 0x63, 0x68, 0x65, 0x73, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x38, 0x0a, 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, + 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x42, 0x13, 0x92, + 0x41, 0x10, 0xd2, 0x01, 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, + 0x69, 0x64, 0x52, 0x0c, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x64, + 0x12, 0x32, 0x0a, 0x0b, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x09, 0x42, 0x11, 0x92, 0x41, 0x0e, 0xd2, 0x01, 0x0b, 0x6d, 0x65, 0x74, + 0x72, 0x69, 0x63, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x52, 0x0a, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, + 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x42, 0x0a, 0x0b, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x5f, 0x74, + 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1d, 0x2e, 0x64, 0x65, 0x74, 0x65, + 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x4d, 0x65, + 0x74, 0x72, 0x69, 0x63, 0x54, 0x79, 0x70, 0x65, 0x42, 0x02, 0x18, 0x01, 0x52, 0x0a, 0x6d, 0x65, + 0x74, 0x72, 0x69, 0x63, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x67, 0x72, 0x6f, 0x75, + 0x70, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x25, + 0x0a, 0x0e, 0x70, 0x65, 0x72, 0x69, 0x6f, 0x64, 0x5f, 0x73, 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, + 0x18, 0x04, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0d, 0x70, 0x65, 0x72, 0x69, 0x6f, 0x64, 0x53, 0x65, + 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x22, 0x31, 0x0a, 0x15, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x42, + 0x61, 0x74, 0x63, 0x68, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x18, + 0x0a, 0x07, 0x62, 0x61, 0x74, 0x63, 0x68, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x05, 0x52, + 0x07, 0x62, 0x61, 0x74, 0x63, 0x68, 0x65, 0x73, 0x22, 0xf3, 0x02, 0x0a, 0x15, 0x54, 0x72, 0x69, + 0x61, 0x6c, 0x73, 0x53, 0x6e, 0x61, 0x70, 0x73, 0x68, 0x6f, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x12, 0x38, 0x0a, 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, + 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x42, 0x13, 0x92, 0x41, 0x10, 0xd2, 0x01, + 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x52, 0x0c, + 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x32, 0x0a, 0x0b, + 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x09, 0x42, 0x11, 0x92, 0x41, 0x0e, 0xd2, 0x01, 0x0b, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x5f, + 0x6e, 0x61, 0x6d, 0x65, 0x52, 0x0a, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x4e, 0x61, 0x6d, 0x65, + 0x12, 0x42, 0x0a, 0x0b, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1d, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, + 0x65, 0x64, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, + 0x54, 0x79, 0x70, 0x65, 0x42, 0x02, 0x18, 0x01, 0x52, 0x0a, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, + 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x18, 0x07, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x05, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x44, 0x0a, 0x11, 0x62, 0x61, + 0x74, 0x63, 0x68, 0x65, 0x73, 0x5f, 0x70, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x65, 0x64, 0x18, + 0x04, 0x20, 0x01, 0x28, 0x05, 0x42, 0x17, 0x92, 0x41, 0x14, 0xd2, 0x01, 0x11, 0x62, 0x61, 0x74, + 0x63, 0x68, 0x65, 0x73, 0x5f, 0x70, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x65, 0x64, 0x52, 0x10, + 0x62, 0x61, 0x74, 0x63, 0x68, 0x65, 0x73, 0x50, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x65, 0x64, + 0x12, 0x25, 0x0a, 0x0e, 0x62, 0x61, 0x74, 0x63, 0x68, 0x65, 0x73, 0x5f, 0x6d, 0x61, 0x72, 0x67, + 0x69, 0x6e, 0x18, 0x05, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0d, 0x62, 0x61, 0x74, 0x63, 0x68, 0x65, + 0x73, 0x4d, 0x61, 0x72, 0x67, 0x69, 0x6e, 0x12, 0x25, 0x0a, 0x0e, 0x70, 0x65, 0x72, 0x69, 0x6f, + 0x64, 0x5f, 0x73, 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x18, 0x06, 0x20, 0x01, 0x28, 0x05, 0x52, + 0x0d, 0x70, 0x65, 0x72, 0x69, 0x6f, 0x64, 0x53, 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x22, 0xc7, + 0x02, 0x0a, 0x16, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x73, 0x53, 0x6e, 0x61, 0x70, 0x73, 0x68, 0x6f, + 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x47, 0x0a, 0x06, 0x74, 0x72, 0x69, + 0x61, 0x6c, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x2f, 0x2e, 0x64, 0x65, 0x74, 0x65, + 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x54, 0x72, + 0x69, 0x61, 0x6c, 0x73, 0x53, 0x6e, 0x61, 0x70, 0x73, 0x68, 0x6f, 0x74, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x2e, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x52, 0x06, 0x74, 0x72, 0x69, 0x61, + 0x6c, 0x73, 0x1a, 0xd3, 0x01, 0x0a, 0x05, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x12, 0x19, 0x0a, 0x08, + 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, + 0x74, 0x72, 0x69, 0x61, 0x6c, 0x49, 0x64, 0x12, 0x31, 0x0a, 0x07, 0x68, 0x70, 0x61, 0x72, 0x61, + 0x6d, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, + 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x53, 0x74, 0x72, 0x75, 0x63, + 0x74, 0x52, 0x07, 0x68, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, + 0x74, 0x72, 0x69, 0x63, 0x18, 0x03, 0x20, 0x01, 0x28, 0x01, 0x52, 0x06, 0x6d, 0x65, 0x74, 0x72, + 0x69, 0x63, 0x12, 0x2b, 0x0a, 0x11, 0x62, 0x61, 0x74, 0x63, 0x68, 0x65, 0x73, 0x5f, 0x70, 0x72, + 0x6f, 0x63, 0x65, 0x73, 0x73, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x05, 0x52, 0x10, 0x62, + 0x61, 0x74, 0x63, 0x68, 0x65, 0x73, 0x50, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x65, 0x64, 0x3a, + 0x37, 0x92, 0x41, 0x34, 0x0a, 0x32, 0xd2, 0x01, 0x08, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x69, + 0x64, 0xd2, 0x01, 0x07, 0x68, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x73, 0xd2, 0x01, 0x06, 0x6d, 0x65, + 0x74, 0x72, 0x69, 0x63, 0xd2, 0x01, 0x11, 0x62, 0x61, 0x74, 0x63, 0x68, 0x65, 0x73, 0x5f, 0x70, + 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x65, 0x64, 0x3a, 0x0e, 0x92, 0x41, 0x0b, 0x0a, 0x09, 0xd2, + 0x01, 0x06, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x73, 0x22, 0x90, 0x03, 0x0a, 0x13, 0x54, 0x72, 0x69, + 0x61, 0x6c, 0x73, 0x53, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x12, 0x38, 0x0a, 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, + 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x42, 0x13, 0x92, 0x41, 0x10, 0xd2, 0x01, 0x0d, 0x65, + 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x52, 0x0c, 0x65, 0x78, + 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x32, 0x0a, 0x0b, 0x6d, 0x65, + 0x74, 0x72, 0x69, 0x63, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x42, + 0x11, 0x92, 0x41, 0x0e, 0xd2, 0x01, 0x0b, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x5f, 0x6e, 0x61, + 0x6d, 0x65, 0x52, 0x0a, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x42, + 0x0a, 0x0b, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x0e, 0x32, 0x1d, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, + 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x54, 0x79, + 0x70, 0x65, 0x42, 0x02, 0x18, 0x01, 0x52, 0x0a, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x54, 0x79, + 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x18, 0x09, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x05, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x1d, 0x0a, 0x0a, 0x6d, 0x61, 0x78, 0x5f, + 0x74, 0x72, 0x69, 0x61, 0x6c, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x05, 0x52, 0x09, 0x6d, 0x61, + 0x78, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x73, 0x12, 0x25, 0x0a, 0x0e, 0x6d, 0x61, 0x78, 0x5f, 0x64, + 0x61, 0x74, 0x61, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x73, 0x18, 0x05, 0x20, 0x01, 0x28, 0x05, 0x52, + 0x0d, 0x6d, 0x61, 0x78, 0x44, 0x61, 0x74, 0x61, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x73, 0x12, 0x23, + 0x0a, 0x0d, 0x73, 0x74, 0x61, 0x72, 0x74, 0x5f, 0x62, 0x61, 0x74, 0x63, 0x68, 0x65, 0x73, 0x18, + 0x06, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0c, 0x73, 0x74, 0x61, 0x72, 0x74, 0x42, 0x61, 0x74, 0x63, + 0x68, 0x65, 0x73, 0x12, 0x1f, 0x0a, 0x0b, 0x65, 0x6e, 0x64, 0x5f, 0x62, 0x61, 0x74, 0x63, 0x68, + 0x65, 0x73, 0x18, 0x07, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0a, 0x65, 0x6e, 0x64, 0x42, 0x61, 0x74, + 0x63, 0x68, 0x65, 0x73, 0x12, 0x25, 0x0a, 0x0e, 0x70, 0x65, 0x72, 0x69, 0x6f, 0x64, 0x5f, 0x73, + 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x18, 0x08, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0d, 0x70, 0x65, + 0x72, 0x69, 0x6f, 0x64, 0x53, 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x22, 0x8d, 0x03, 0x0a, 0x14, + 0x54, 0x72, 0x69, 0x61, 0x6c, 0x73, 0x53, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x45, 0x0a, 0x06, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x73, 0x18, 0x01, + 0x20, 0x03, 0x28, 0x0b, 0x32, 0x2d, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, + 0x64, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x73, 0x53, + 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2e, 0x54, 0x72, + 0x69, 0x61, 0x6c, 0x52, 0x06, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x73, 0x12, 0x27, 0x0a, 0x0f, 0x70, + 0x72, 0x6f, 0x6d, 0x6f, 0x74, 0x65, 0x64, 0x5f, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x73, 0x18, 0x02, + 0x20, 0x03, 0x28, 0x05, 0x52, 0x0e, 0x70, 0x72, 0x6f, 0x6d, 0x6f, 0x74, 0x65, 0x64, 0x54, 0x72, + 0x69, 0x61, 0x6c, 0x73, 0x12, 0x25, 0x0a, 0x0e, 0x64, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x64, 0x5f, + 0x74, 0x72, 0x69, 0x61, 0x6c, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x05, 0x52, 0x0d, 0x64, 0x65, + 0x6d, 0x6f, 0x74, 0x65, 0x64, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x73, 0x1a, 0xaa, 0x01, 0x0a, 0x05, + 0x54, 0x72, 0x69, 0x61, 0x6c, 0x12, 0x19, 0x0a, 0x08, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x69, + 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x49, 0x64, + 0x12, 0x31, 0x0a, 0x07, 0x68, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x17, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x62, 0x75, 0x66, 0x2e, 0x53, 0x74, 0x72, 0x75, 0x63, 0x74, 0x52, 0x07, 0x68, 0x70, 0x61, 0x72, + 0x61, 0x6d, 0x73, 0x12, 0x30, 0x0a, 0x04, 0x64, 0x61, 0x74, 0x61, 0x18, 0x03, 0x20, 0x03, 0x28, + 0x0b, 0x32, 0x1c, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x61, + 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x44, 0x61, 0x74, 0x61, 0x50, 0x6f, 0x69, 0x6e, 0x74, 0x52, + 0x04, 0x64, 0x61, 0x74, 0x61, 0x3a, 0x21, 0x92, 0x41, 0x1e, 0x0a, 0x1c, 0xd2, 0x01, 0x08, 0x74, + 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x69, 0x64, 0xd2, 0x01, 0x07, 0x68, 0x70, 0x61, 0x72, 0x61, 0x6d, + 0x73, 0xd2, 0x01, 0x04, 0x64, 0x61, 0x74, 0x61, 0x3a, 0x31, 0x92, 0x41, 0x2e, 0x0a, 0x2c, 0xd2, + 0x01, 0x06, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x73, 0xd2, 0x01, 0x0f, 0x70, 0x72, 0x6f, 0x6d, 0x6f, + 0x74, 0x65, 0x64, 0x5f, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x73, 0xd2, 0x01, 0x0e, 0x64, 0x65, 0x6d, + 0x6f, 0x74, 0x65, 0x64, 0x5f, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x73, 0x22, 0x39, 0x0a, 0x12, 0x47, + 0x65, 0x74, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x44, 0x65, 0x66, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x12, 0x23, 0x0a, 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, + 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0c, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, + 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x22, 0x3f, 0x0a, 0x13, 0x47, 0x65, 0x74, 0x4d, 0x6f, 0x64, + 0x65, 0x6c, 0x44, 0x65, 0x66, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x17, 0x0a, + 0x07, 0x62, 0x36, 0x34, 0x5f, 0x74, 0x67, 0x7a, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, + 0x62, 0x36, 0x34, 0x54, 0x67, 0x7a, 0x3a, 0x0f, 0x92, 0x41, 0x0c, 0x0a, 0x0a, 0xd2, 0x01, 0x07, + 0x62, 0x36, 0x34, 0x5f, 0x74, 0x67, 0x7a, 0x22, 0xa2, 0x01, 0x0a, 0x15, 0x4d, 0x6f, 0x76, 0x65, + 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x12, 0x23, 0x0a, 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, + 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0c, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, + 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x34, 0x0a, 0x16, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, + 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x5f, 0x69, 0x64, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x14, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x50, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x49, 0x64, 0x3a, 0x2e, 0x92, 0x41, + 0x2b, 0x0a, 0x29, 0xd2, 0x01, 0x16, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, + 0x6e, 0x5f, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x5f, 0x69, 0x64, 0xd2, 0x01, 0x0d, 0x65, + 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x22, 0x18, 0x0a, 0x16, + 0x4d, 0x6f, 0x76, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x96, 0x02, 0x0a, 0x16, 0x4d, 0x6f, 0x76, 0x65, 0x45, + 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x12, 0x25, 0x0a, 0x0e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, + 0x69, 0x64, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x05, 0x52, 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, + 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x73, 0x12, 0x34, 0x0a, 0x16, 0x64, 0x65, 0x73, 0x74, + 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x5f, + 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x14, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, + 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x49, 0x64, 0x12, 0x42, + 0x0a, 0x07, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x28, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x61, 0x70, 0x69, + 0x2e, 0x76, 0x31, 0x2e, 0x42, 0x75, 0x6c, 0x6b, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, + 0x6e, 0x74, 0x46, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x73, 0x52, 0x07, 0x66, 0x69, 0x6c, 0x74, 0x65, + 0x72, 0x73, 0x12, 0x1d, 0x0a, 0x0a, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x5f, 0x69, 0x64, + 0x18, 0x04, 0x20, 0x01, 0x28, 0x05, 0x52, 0x09, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x49, + 0x64, 0x3a, 0x3c, 0x92, 0x41, 0x39, 0x0a, 0x37, 0xd2, 0x01, 0x0a, 0x70, 0x72, 0x6f, 0x6a, 0x65, + 0x63, 0x74, 0x5f, 0x69, 0x64, 0xd2, 0x01, 0x16, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x5f, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x5f, 0x69, 0x64, 0xd2, 0x01, + 0x0e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x73, 0x22, + 0x6f, 0x0a, 0x17, 0x4d, 0x6f, 0x76, 0x65, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, + 0x74, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x43, 0x0a, 0x07, 0x72, 0x65, + 0x73, 0x75, 0x6c, 0x74, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x29, 0x2e, 0x64, 0x65, + 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, + 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, + 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x52, 0x07, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x3a, + 0x0f, 0x92, 0x41, 0x0c, 0x0a, 0x0a, 0xd2, 0x01, 0x07, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, + 0x22, 0x3d, 0x0a, 0x16, 0x47, 0x65, 0x74, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x44, 0x65, 0x66, 0x54, + 0x72, 0x65, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x23, 0x0a, 0x0d, 0x65, 0x78, + 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x05, 0x52, 0x0c, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x22, + 0x53, 0x0a, 0x17, 0x47, 0x65, 0x74, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x44, 0x65, 0x66, 0x54, 0x72, + 0x65, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x38, 0x0a, 0x05, 0x66, 0x69, + 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x22, 0x2e, 0x64, 0x65, 0x74, 0x65, + 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x46, 0x69, 0x6c, 0x65, 0x4e, 0x6f, 0x64, 0x65, 0x52, 0x05, 0x66, + 0x69, 0x6c, 0x65, 0x73, 0x22, 0x51, 0x0a, 0x16, 0x47, 0x65, 0x74, 0x4d, 0x6f, 0x64, 0x65, 0x6c, + 0x44, 0x65, 0x66, 0x46, 0x69, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x23, 0x0a, 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0c, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, - 0x74, 0x49, 0x64, 0x3a, 0x15, 0x92, 0x41, 0x12, 0x0a, 0x10, 0xd2, 0x01, 0x0d, 0x65, 0x78, 0x70, - 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x22, 0x20, 0x0a, 0x1e, 0x44, 0x65, - 0x6c, 0x65, 0x74, 0x65, 0x54, 0x65, 0x6e, 0x73, 0x6f, 0x72, 0x62, 0x6f, 0x61, 0x72, 0x64, 0x46, - 0x69, 0x6c, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2a, 0x7a, 0x0a, 0x0a, - 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x54, 0x79, 0x70, 0x65, 0x12, 0x1b, 0x0a, 0x17, 0x4d, 0x45, - 0x54, 0x52, 0x49, 0x43, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x55, 0x4e, 0x53, 0x50, 0x45, 0x43, - 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x00, 0x12, 0x18, 0x0a, 0x14, 0x4d, 0x45, 0x54, 0x52, 0x49, - 0x43, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x54, 0x52, 0x41, 0x49, 0x4e, 0x49, 0x4e, 0x47, 0x10, - 0x01, 0x12, 0x1a, 0x0a, 0x16, 0x4d, 0x45, 0x54, 0x52, 0x49, 0x43, 0x5f, 0x54, 0x59, 0x50, 0x45, - 0x5f, 0x56, 0x41, 0x4c, 0x49, 0x44, 0x41, 0x54, 0x49, 0x4f, 0x4e, 0x10, 0x02, 0x12, 0x19, 0x0a, - 0x15, 0x4d, 0x45, 0x54, 0x52, 0x49, 0x43, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x50, 0x52, 0x4f, - 0x46, 0x49, 0x4c, 0x49, 0x4e, 0x47, 0x10, 0x03, 0x42, 0x35, 0x5a, 0x33, 0x67, 0x69, 0x74, 0x68, - 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, - 0x64, 0x2d, 0x61, 0x69, 0x2f, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2f, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x61, 0x70, 0x69, 0x76, 0x31, 0x62, - 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x74, 0x49, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x22, 0x2d, 0x0a, 0x17, 0x47, 0x65, 0x74, 0x4d, 0x6f, + 0x64, 0x65, 0x6c, 0x44, 0x65, 0x66, 0x46, 0x69, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, + 0x52, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x22, 0xc5, 0x01, 0x0a, 0x18, 0x53, 0x65, 0x61, 0x72, 0x63, + 0x68, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x12, 0x22, 0x0a, 0x0a, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x5f, 0x69, + 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x48, 0x00, 0x52, 0x09, 0x70, 0x72, 0x6f, 0x6a, 0x65, + 0x63, 0x74, 0x49, 0x64, 0x88, 0x01, 0x01, 0x12, 0x16, 0x0a, 0x06, 0x6f, 0x66, 0x66, 0x73, 0x65, + 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x06, 0x6f, 0x66, 0x66, 0x73, 0x65, 0x74, 0x12, + 0x14, 0x0a, 0x05, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x05, + 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x12, 0x17, 0x0a, 0x04, 0x73, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, + 0x01, 0x28, 0x09, 0x48, 0x01, 0x52, 0x04, 0x73, 0x6f, 0x72, 0x74, 0x88, 0x01, 0x01, 0x12, 0x1b, + 0x0a, 0x06, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x48, 0x02, + 0x52, 0x06, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x88, 0x01, 0x01, 0x42, 0x0d, 0x0a, 0x0b, 0x5f, + 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x5f, 0x69, 0x64, 0x42, 0x07, 0x0a, 0x05, 0x5f, 0x73, + 0x6f, 0x72, 0x74, 0x42, 0x09, 0x0a, 0x07, 0x5f, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x22, 0xb1, + 0x01, 0x0a, 0x1a, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, + 0x65, 0x6e, 0x74, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x12, 0x44, 0x0a, + 0x0a, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x24, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x65, + 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x45, 0x78, 0x70, + 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x0a, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, + 0x65, 0x6e, 0x74, 0x12, 0x39, 0x0a, 0x0a, 0x62, 0x65, 0x73, 0x74, 0x5f, 0x74, 0x72, 0x69, 0x61, + 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, + 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x2e, 0x76, 0x31, 0x2e, 0x54, 0x72, + 0x69, 0x61, 0x6c, 0x52, 0x09, 0x62, 0x65, 0x73, 0x74, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x3a, 0x12, + 0x92, 0x41, 0x0f, 0x0a, 0x0d, 0xd2, 0x01, 0x0a, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, + 0x6e, 0x74, 0x22, 0xcd, 0x01, 0x0a, 0x19, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x45, 0x78, 0x70, + 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x12, 0x4f, 0x0a, 0x0b, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x18, + 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x2d, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, + 0x65, 0x64, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, + 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x45, 0x78, 0x70, 0x65, 0x72, 0x69, + 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x0b, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, + 0x73, 0x12, 0x3d, 0x0a, 0x0a, 0x70, 0x61, 0x67, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, + 0x65, 0x64, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x50, 0x61, 0x67, 0x69, 0x6e, 0x61, + 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x0a, 0x70, 0x61, 0x67, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x3a, 0x20, 0x92, 0x41, 0x1d, 0x0a, 0x1b, 0xd2, 0x01, 0x0b, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, + 0x6d, 0x65, 0x6e, 0x74, 0x73, 0xd2, 0x01, 0x0a, 0x70, 0x61, 0x67, 0x69, 0x6e, 0x61, 0x74, 0x69, + 0x6f, 0x6e, 0x22, 0x5b, 0x0a, 0x1d, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x54, 0x65, 0x6e, 0x73, + 0x6f, 0x72, 0x62, 0x6f, 0x61, 0x72, 0x64, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x12, 0x23, 0x0a, 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, + 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0c, 0x65, 0x78, 0x70, 0x65, + 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x3a, 0x15, 0x92, 0x41, 0x12, 0x0a, 0x10, 0xd2, + 0x01, 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x22, + 0x20, 0x0a, 0x1e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x54, 0x65, 0x6e, 0x73, 0x6f, 0x72, 0x62, + 0x6f, 0x61, 0x72, 0x64, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x2a, 0x7a, 0x0a, 0x0a, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x54, 0x79, 0x70, 0x65, 0x12, + 0x1b, 0x0a, 0x17, 0x4d, 0x45, 0x54, 0x52, 0x49, 0x43, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x55, + 0x4e, 0x53, 0x50, 0x45, 0x43, 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x00, 0x12, 0x18, 0x0a, 0x14, + 0x4d, 0x45, 0x54, 0x52, 0x49, 0x43, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x54, 0x52, 0x41, 0x49, + 0x4e, 0x49, 0x4e, 0x47, 0x10, 0x01, 0x12, 0x1a, 0x0a, 0x16, 0x4d, 0x45, 0x54, 0x52, 0x49, 0x43, + 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x56, 0x41, 0x4c, 0x49, 0x44, 0x41, 0x54, 0x49, 0x4f, 0x4e, + 0x10, 0x02, 0x12, 0x19, 0x0a, 0x15, 0x4d, 0x45, 0x54, 0x52, 0x49, 0x43, 0x5f, 0x54, 0x59, 0x50, + 0x45, 0x5f, 0x50, 0x52, 0x4f, 0x46, 0x49, 0x4c, 0x49, 0x4e, 0x47, 0x10, 0x03, 0x42, 0x35, 0x5a, + 0x33, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x64, 0x65, 0x74, 0x65, + 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2d, 0x61, 0x69, 0x2f, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, + 0x69, 0x6e, 0x65, 0x64, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x61, + 0x70, 0x69, 0x76, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -6407,7 +6173,7 @@ func file_determined_api_v1_experiment_proto_rawDescGZIP() []byte { } var file_determined_api_v1_experiment_proto_enumTypes = make([]protoimpl.EnumInfo, 2) -var file_determined_api_v1_experiment_proto_msgTypes = make([]protoimpl.MessageInfo, 90) +var file_determined_api_v1_experiment_proto_msgTypes = make([]protoimpl.MessageInfo, 86) var file_determined_api_v1_experiment_proto_goTypes = []interface{}{ (MetricType)(0), // 0: determined.api.v1.MetricType (GetExperimentsRequest_SortBy)(0), // 1: determined.api.v1.GetExperimentsRequest.SortBy @@ -6490,63 +6256,57 @@ var file_determined_api_v1_experiment_proto_goTypes = []interface{}{ (*GetModelDefTreeResponse)(nil), // 78: determined.api.v1.GetModelDefTreeResponse (*GetModelDefFileRequest)(nil), // 79: determined.api.v1.GetModelDefFileRequest (*GetModelDefFileResponse)(nil), // 80: determined.api.v1.GetModelDefFileResponse - (*GetSearcherEventsRequest)(nil), // 81: determined.api.v1.GetSearcherEventsRequest - (*GetSearcherEventsResponse)(nil), // 82: determined.api.v1.GetSearcherEventsResponse - (*PostSearcherOperationsRequest)(nil), // 83: determined.api.v1.PostSearcherOperationsRequest - (*PostSearcherOperationsResponse)(nil), // 84: determined.api.v1.PostSearcherOperationsResponse - (*SearchExperimentsRequest)(nil), // 85: determined.api.v1.SearchExperimentsRequest - (*SearchExperimentExperiment)(nil), // 86: determined.api.v1.SearchExperimentExperiment - (*SearchExperimentsResponse)(nil), // 87: determined.api.v1.SearchExperimentsResponse - (*DeleteTensorboardFilesRequest)(nil), // 88: determined.api.v1.DeleteTensorboardFilesRequest - (*DeleteTensorboardFilesResponse)(nil), // 89: determined.api.v1.DeleteTensorboardFilesResponse - (*TrialsSnapshotResponse_Trial)(nil), // 90: determined.api.v1.TrialsSnapshotResponse.Trial - (*TrialsSampleResponse_Trial)(nil), // 91: determined.api.v1.TrialsSampleResponse.Trial - (*_struct.Struct)(nil), // 92: google.protobuf.Struct - (*timestamp.Timestamp)(nil), // 93: google.protobuf.Timestamp - (*experimentv1.Experiment)(nil), // 94: determined.experiment.v1.Experiment - (*jobv1.JobSummary)(nil), // 95: determined.job.v1.JobSummary - (OrderBy)(0), // 96: determined.api.v1.OrderBy - (*wrappers.BoolValue)(nil), // 97: google.protobuf.BoolValue - (experimentv1.State)(0), // 98: determined.experiment.v1.State - (*commonv1.Int32FieldFilter)(nil), // 99: determined.common.v1.Int32FieldFilter - (*Pagination)(nil), // 100: determined.api.v1.Pagination - (*experimentv1.ExperimentSimulation)(nil), // 101: determined.experiment.v1.ExperimentSimulation - (*experimentv1.PatchExperiment)(nil), // 102: determined.experiment.v1.PatchExperiment - (checkpointv1.SortBy)(0), // 103: determined.checkpoint.v1.SortBy - (checkpointv1.State)(0), // 104: determined.checkpoint.v1.State - (*checkpointv1.Checkpoint)(nil), // 105: determined.checkpoint.v1.Checkpoint - (*experimentv1.ValidationHistoryEntry)(nil), // 106: determined.experiment.v1.ValidationHistoryEntry - (*utilv1.File)(nil), // 107: determined.util.v1.File - (LaunchWarning)(0), // 108: determined.api.v1.LaunchWarning - (*metricv1.MetricIdentifier)(nil), // 109: determined.metric.v1.MetricIdentifier - (*experimentv1.FileNode)(nil), // 110: determined.experiment.v1.FileNode - (*experimentv1.SearcherEvent)(nil), // 111: determined.experiment.v1.SearcherEvent - (*experimentv1.SearcherOperation)(nil), // 112: determined.experiment.v1.SearcherOperation - (*trialv1.Trial)(nil), // 113: determined.trial.v1.Trial + (*SearchExperimentsRequest)(nil), // 81: determined.api.v1.SearchExperimentsRequest + (*SearchExperimentExperiment)(nil), // 82: determined.api.v1.SearchExperimentExperiment + (*SearchExperimentsResponse)(nil), // 83: determined.api.v1.SearchExperimentsResponse + (*DeleteTensorboardFilesRequest)(nil), // 84: determined.api.v1.DeleteTensorboardFilesRequest + (*DeleteTensorboardFilesResponse)(nil), // 85: determined.api.v1.DeleteTensorboardFilesResponse + (*TrialsSnapshotResponse_Trial)(nil), // 86: determined.api.v1.TrialsSnapshotResponse.Trial + (*TrialsSampleResponse_Trial)(nil), // 87: determined.api.v1.TrialsSampleResponse.Trial + (*_struct.Struct)(nil), // 88: google.protobuf.Struct + (*timestamp.Timestamp)(nil), // 89: google.protobuf.Timestamp + (*experimentv1.Experiment)(nil), // 90: determined.experiment.v1.Experiment + (*jobv1.JobSummary)(nil), // 91: determined.job.v1.JobSummary + (OrderBy)(0), // 92: determined.api.v1.OrderBy + (*wrappers.BoolValue)(nil), // 93: google.protobuf.BoolValue + (experimentv1.State)(0), // 94: determined.experiment.v1.State + (*commonv1.Int32FieldFilter)(nil), // 95: determined.common.v1.Int32FieldFilter + (*Pagination)(nil), // 96: determined.api.v1.Pagination + (*experimentv1.SearchSummary)(nil), // 97: determined.experiment.v1.SearchSummary + (*experimentv1.PatchExperiment)(nil), // 98: determined.experiment.v1.PatchExperiment + (checkpointv1.SortBy)(0), // 99: determined.checkpoint.v1.SortBy + (checkpointv1.State)(0), // 100: determined.checkpoint.v1.State + (*checkpointv1.Checkpoint)(nil), // 101: determined.checkpoint.v1.Checkpoint + (*experimentv1.ValidationHistoryEntry)(nil), // 102: determined.experiment.v1.ValidationHistoryEntry + (*utilv1.File)(nil), // 103: determined.util.v1.File + (LaunchWarning)(0), // 104: determined.api.v1.LaunchWarning + (*metricv1.MetricIdentifier)(nil), // 105: determined.metric.v1.MetricIdentifier + (*experimentv1.FileNode)(nil), // 106: determined.experiment.v1.FileNode + (*trialv1.Trial)(nil), // 107: determined.trial.v1.Trial } var file_determined_api_v1_experiment_proto_depIdxs = []int32{ - 92, // 0: determined.api.v1.DataPoint.values:type_name -> google.protobuf.Struct - 93, // 1: determined.api.v1.DataPoint.time:type_name -> google.protobuf.Timestamp - 94, // 2: determined.api.v1.GetExperimentResponse.experiment:type_name -> determined.experiment.v1.Experiment - 95, // 3: determined.api.v1.GetExperimentResponse.job_summary:type_name -> determined.job.v1.JobSummary - 92, // 4: determined.api.v1.GetExperimentResponse.config:type_name -> google.protobuf.Struct + 88, // 0: determined.api.v1.DataPoint.values:type_name -> google.protobuf.Struct + 89, // 1: determined.api.v1.DataPoint.time:type_name -> google.protobuf.Timestamp + 90, // 2: determined.api.v1.GetExperimentResponse.experiment:type_name -> determined.experiment.v1.Experiment + 91, // 3: determined.api.v1.GetExperimentResponse.job_summary:type_name -> determined.job.v1.JobSummary + 88, // 4: determined.api.v1.GetExperimentResponse.config:type_name -> google.protobuf.Struct 1, // 5: determined.api.v1.GetExperimentsRequest.sort_by:type_name -> determined.api.v1.GetExperimentsRequest.SortBy - 96, // 6: determined.api.v1.GetExperimentsRequest.order_by:type_name -> determined.api.v1.OrderBy - 97, // 7: determined.api.v1.GetExperimentsRequest.archived:type_name -> google.protobuf.BoolValue - 98, // 8: determined.api.v1.GetExperimentsRequest.states:type_name -> determined.experiment.v1.State - 99, // 9: determined.api.v1.GetExperimentsRequest.experiment_id_filter:type_name -> determined.common.v1.Int32FieldFilter - 94, // 10: determined.api.v1.GetExperimentsResponse.experiments:type_name -> determined.experiment.v1.Experiment - 100, // 11: determined.api.v1.GetExperimentsResponse.pagination:type_name -> determined.api.v1.Pagination + 92, // 6: determined.api.v1.GetExperimentsRequest.order_by:type_name -> determined.api.v1.OrderBy + 93, // 7: determined.api.v1.GetExperimentsRequest.archived:type_name -> google.protobuf.BoolValue + 94, // 8: determined.api.v1.GetExperimentsRequest.states:type_name -> determined.experiment.v1.State + 95, // 9: determined.api.v1.GetExperimentsRequest.experiment_id_filter:type_name -> determined.common.v1.Int32FieldFilter + 90, // 10: determined.api.v1.GetExperimentsResponse.experiments:type_name -> determined.experiment.v1.Experiment + 96, // 11: determined.api.v1.GetExperimentsResponse.pagination:type_name -> determined.api.v1.Pagination 32, // 12: determined.api.v1.PutExperimentsRetainLogsRequest.filters:type_name -> determined.api.v1.BulkExperimentFilters 27, // 13: determined.api.v1.PutExperimentsRetainLogsResponse.results:type_name -> determined.api.v1.ExperimentActionResult 32, // 14: determined.api.v1.DeleteExperimentsRequest.filters:type_name -> determined.api.v1.BulkExperimentFilters 27, // 15: determined.api.v1.DeleteExperimentsResponse.results:type_name -> determined.api.v1.ExperimentActionResult - 92, // 16: determined.api.v1.PreviewHPSearchRequest.config:type_name -> google.protobuf.Struct - 101, // 17: determined.api.v1.PreviewHPSearchResponse.simulation:type_name -> determined.experiment.v1.ExperimentSimulation + 88, // 16: determined.api.v1.PreviewHPSearchRequest.config:type_name -> google.protobuf.Struct + 97, // 17: determined.api.v1.PreviewHPSearchResponse.summary:type_name -> determined.experiment.v1.SearchSummary 32, // 18: determined.api.v1.ActivateExperimentsRequest.filters:type_name -> determined.api.v1.BulkExperimentFilters 27, // 19: determined.api.v1.ActivateExperimentsResponse.results:type_name -> determined.api.v1.ExperimentActionResult - 97, // 20: determined.api.v1.BulkExperimentFilters.archived:type_name -> google.protobuf.BoolValue - 98, // 21: determined.api.v1.BulkExperimentFilters.states:type_name -> determined.experiment.v1.State + 93, // 20: determined.api.v1.BulkExperimentFilters.archived:type_name -> google.protobuf.BoolValue + 94, // 21: determined.api.v1.BulkExperimentFilters.states:type_name -> determined.experiment.v1.State 32, // 22: determined.api.v1.PauseExperimentsRequest.filters:type_name -> determined.api.v1.BulkExperimentFilters 27, // 23: determined.api.v1.PauseExperimentsResponse.results:type_name -> determined.api.v1.ExperimentActionResult 32, // 24: determined.api.v1.CancelExperimentsRequest.filters:type_name -> determined.api.v1.BulkExperimentFilters @@ -6557,47 +6317,44 @@ var file_determined_api_v1_experiment_proto_depIdxs = []int32{ 27, // 29: determined.api.v1.ArchiveExperimentsResponse.results:type_name -> determined.api.v1.ExperimentActionResult 32, // 30: determined.api.v1.UnarchiveExperimentsRequest.filters:type_name -> determined.api.v1.BulkExperimentFilters 27, // 31: determined.api.v1.UnarchiveExperimentsResponse.results:type_name -> determined.api.v1.ExperimentActionResult - 102, // 32: determined.api.v1.PatchExperimentRequest.experiment:type_name -> determined.experiment.v1.PatchExperiment - 94, // 33: determined.api.v1.PatchExperimentResponse.experiment:type_name -> determined.experiment.v1.Experiment - 103, // 34: determined.api.v1.GetExperimentCheckpointsRequest.sort_by_attr:type_name -> determined.checkpoint.v1.SortBy - 96, // 35: determined.api.v1.GetExperimentCheckpointsRequest.order_by:type_name -> determined.api.v1.OrderBy - 104, // 36: determined.api.v1.GetExperimentCheckpointsRequest.states:type_name -> determined.checkpoint.v1.State - 105, // 37: determined.api.v1.GetExperimentCheckpointsResponse.checkpoints:type_name -> determined.checkpoint.v1.Checkpoint - 100, // 38: determined.api.v1.GetExperimentCheckpointsResponse.pagination:type_name -> determined.api.v1.Pagination - 106, // 39: determined.api.v1.GetExperimentValidationHistoryResponse.validation_history:type_name -> determined.experiment.v1.ValidationHistoryEntry - 107, // 40: determined.api.v1.CreateExperimentRequest.model_definition:type_name -> determined.util.v1.File - 94, // 41: determined.api.v1.CreateExperimentResponse.experiment:type_name -> determined.experiment.v1.Experiment - 92, // 42: determined.api.v1.CreateExperimentResponse.config:type_name -> google.protobuf.Struct - 108, // 43: determined.api.v1.CreateExperimentResponse.warnings:type_name -> determined.api.v1.LaunchWarning + 98, // 32: determined.api.v1.PatchExperimentRequest.experiment:type_name -> determined.experiment.v1.PatchExperiment + 90, // 33: determined.api.v1.PatchExperimentResponse.experiment:type_name -> determined.experiment.v1.Experiment + 99, // 34: determined.api.v1.GetExperimentCheckpointsRequest.sort_by_attr:type_name -> determined.checkpoint.v1.SortBy + 92, // 35: determined.api.v1.GetExperimentCheckpointsRequest.order_by:type_name -> determined.api.v1.OrderBy + 100, // 36: determined.api.v1.GetExperimentCheckpointsRequest.states:type_name -> determined.checkpoint.v1.State + 101, // 37: determined.api.v1.GetExperimentCheckpointsResponse.checkpoints:type_name -> determined.checkpoint.v1.Checkpoint + 96, // 38: determined.api.v1.GetExperimentCheckpointsResponse.pagination:type_name -> determined.api.v1.Pagination + 102, // 39: determined.api.v1.GetExperimentValidationHistoryResponse.validation_history:type_name -> determined.experiment.v1.ValidationHistoryEntry + 103, // 40: determined.api.v1.CreateExperimentRequest.model_definition:type_name -> determined.util.v1.File + 90, // 41: determined.api.v1.CreateExperimentResponse.experiment:type_name -> determined.experiment.v1.Experiment + 88, // 42: determined.api.v1.CreateExperimentResponse.config:type_name -> google.protobuf.Struct + 104, // 43: determined.api.v1.CreateExperimentResponse.warnings:type_name -> determined.api.v1.LaunchWarning 57, // 44: determined.api.v1.PutExperimentRequest.create_experiment_request:type_name -> determined.api.v1.CreateExperimentRequest - 94, // 45: determined.api.v1.PutExperimentResponse.experiment:type_name -> determined.experiment.v1.Experiment - 92, // 46: determined.api.v1.PutExperimentResponse.config:type_name -> google.protobuf.Struct - 94, // 47: determined.api.v1.ContinueExperimentResponse.experiment:type_name -> determined.experiment.v1.Experiment - 108, // 48: determined.api.v1.ContinueExperimentResponse.warnings:type_name -> determined.api.v1.LaunchWarning - 109, // 49: determined.api.v1.ExpMetricNamesResponse.metric_names:type_name -> determined.metric.v1.MetricIdentifier + 90, // 45: determined.api.v1.PutExperimentResponse.experiment:type_name -> determined.experiment.v1.Experiment + 88, // 46: determined.api.v1.PutExperimentResponse.config:type_name -> google.protobuf.Struct + 90, // 47: determined.api.v1.ContinueExperimentResponse.experiment:type_name -> determined.experiment.v1.Experiment + 104, // 48: determined.api.v1.ContinueExperimentResponse.warnings:type_name -> determined.api.v1.LaunchWarning + 105, // 49: determined.api.v1.ExpMetricNamesResponse.metric_names:type_name -> determined.metric.v1.MetricIdentifier 0, // 50: determined.api.v1.MetricBatchesRequest.metric_type:type_name -> determined.api.v1.MetricType 0, // 51: determined.api.v1.TrialsSnapshotRequest.metric_type:type_name -> determined.api.v1.MetricType - 90, // 52: determined.api.v1.TrialsSnapshotResponse.trials:type_name -> determined.api.v1.TrialsSnapshotResponse.Trial + 86, // 52: determined.api.v1.TrialsSnapshotResponse.trials:type_name -> determined.api.v1.TrialsSnapshotResponse.Trial 0, // 53: determined.api.v1.TrialsSampleRequest.metric_type:type_name -> determined.api.v1.MetricType - 91, // 54: determined.api.v1.TrialsSampleResponse.trials:type_name -> determined.api.v1.TrialsSampleResponse.Trial + 87, // 54: determined.api.v1.TrialsSampleResponse.trials:type_name -> determined.api.v1.TrialsSampleResponse.Trial 32, // 55: determined.api.v1.MoveExperimentsRequest.filters:type_name -> determined.api.v1.BulkExperimentFilters 27, // 56: determined.api.v1.MoveExperimentsResponse.results:type_name -> determined.api.v1.ExperimentActionResult - 110, // 57: determined.api.v1.GetModelDefTreeResponse.files:type_name -> determined.experiment.v1.FileNode - 111, // 58: determined.api.v1.GetSearcherEventsResponse.searcher_events:type_name -> determined.experiment.v1.SearcherEvent - 112, // 59: determined.api.v1.PostSearcherOperationsRequest.searcher_operations:type_name -> determined.experiment.v1.SearcherOperation - 111, // 60: determined.api.v1.PostSearcherOperationsRequest.triggered_by_event:type_name -> determined.experiment.v1.SearcherEvent - 94, // 61: determined.api.v1.SearchExperimentExperiment.experiment:type_name -> determined.experiment.v1.Experiment - 113, // 62: determined.api.v1.SearchExperimentExperiment.best_trial:type_name -> determined.trial.v1.Trial - 86, // 63: determined.api.v1.SearchExperimentsResponse.experiments:type_name -> determined.api.v1.SearchExperimentExperiment - 100, // 64: determined.api.v1.SearchExperimentsResponse.pagination:type_name -> determined.api.v1.Pagination - 92, // 65: determined.api.v1.TrialsSnapshotResponse.Trial.hparams:type_name -> google.protobuf.Struct - 92, // 66: determined.api.v1.TrialsSampleResponse.Trial.hparams:type_name -> google.protobuf.Struct - 2, // 67: determined.api.v1.TrialsSampleResponse.Trial.data:type_name -> determined.api.v1.DataPoint - 68, // [68:68] is the sub-list for method output_type - 68, // [68:68] is the sub-list for method input_type - 68, // [68:68] is the sub-list for extension type_name - 68, // [68:68] is the sub-list for extension extendee - 0, // [0:68] is the sub-list for field type_name + 106, // 57: determined.api.v1.GetModelDefTreeResponse.files:type_name -> determined.experiment.v1.FileNode + 90, // 58: determined.api.v1.SearchExperimentExperiment.experiment:type_name -> determined.experiment.v1.Experiment + 107, // 59: determined.api.v1.SearchExperimentExperiment.best_trial:type_name -> determined.trial.v1.Trial + 82, // 60: determined.api.v1.SearchExperimentsResponse.experiments:type_name -> determined.api.v1.SearchExperimentExperiment + 96, // 61: determined.api.v1.SearchExperimentsResponse.pagination:type_name -> determined.api.v1.Pagination + 88, // 62: determined.api.v1.TrialsSnapshotResponse.Trial.hparams:type_name -> google.protobuf.Struct + 88, // 63: determined.api.v1.TrialsSampleResponse.Trial.hparams:type_name -> google.protobuf.Struct + 2, // 64: determined.api.v1.TrialsSampleResponse.Trial.data:type_name -> determined.api.v1.DataPoint + 65, // [65:65] is the sub-list for method output_type + 65, // [65:65] is the sub-list for method input_type + 65, // [65:65] is the sub-list for extension type_name + 65, // [65:65] is the sub-list for extension extendee + 0, // [0:65] is the sub-list for field type_name } func init() { file_determined_api_v1_experiment_proto_init() } @@ -7557,54 +7314,6 @@ func file_determined_api_v1_experiment_proto_init() { } } file_determined_api_v1_experiment_proto_msgTypes[79].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetSearcherEventsRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_determined_api_v1_experiment_proto_msgTypes[80].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetSearcherEventsResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_determined_api_v1_experiment_proto_msgTypes[81].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PostSearcherOperationsRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_determined_api_v1_experiment_proto_msgTypes[82].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PostSearcherOperationsResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_determined_api_v1_experiment_proto_msgTypes[83].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*SearchExperimentsRequest); i { case 0: return &v.state @@ -7616,7 +7325,7 @@ func file_determined_api_v1_experiment_proto_init() { return nil } } - file_determined_api_v1_experiment_proto_msgTypes[84].Exporter = func(v interface{}, i int) interface{} { + file_determined_api_v1_experiment_proto_msgTypes[80].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*SearchExperimentExperiment); i { case 0: return &v.state @@ -7628,7 +7337,7 @@ func file_determined_api_v1_experiment_proto_init() { return nil } } - file_determined_api_v1_experiment_proto_msgTypes[85].Exporter = func(v interface{}, i int) interface{} { + file_determined_api_v1_experiment_proto_msgTypes[81].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*SearchExperimentsResponse); i { case 0: return &v.state @@ -7640,7 +7349,7 @@ func file_determined_api_v1_experiment_proto_init() { return nil } } - file_determined_api_v1_experiment_proto_msgTypes[86].Exporter = func(v interface{}, i int) interface{} { + file_determined_api_v1_experiment_proto_msgTypes[82].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*DeleteTensorboardFilesRequest); i { case 0: return &v.state @@ -7652,7 +7361,7 @@ func file_determined_api_v1_experiment_proto_init() { return nil } } - file_determined_api_v1_experiment_proto_msgTypes[87].Exporter = func(v interface{}, i int) interface{} { + file_determined_api_v1_experiment_proto_msgTypes[83].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*DeleteTensorboardFilesResponse); i { case 0: return &v.state @@ -7664,7 +7373,7 @@ func file_determined_api_v1_experiment_proto_init() { return nil } } - file_determined_api_v1_experiment_proto_msgTypes[88].Exporter = func(v interface{}, i int) interface{} { + file_determined_api_v1_experiment_proto_msgTypes[84].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*TrialsSnapshotResponse_Trial); i { case 0: return &v.state @@ -7676,7 +7385,7 @@ func file_determined_api_v1_experiment_proto_init() { return nil } } - file_determined_api_v1_experiment_proto_msgTypes[89].Exporter = func(v interface{}, i int) interface{} { + file_determined_api_v1_experiment_proto_msgTypes[85].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*TrialsSampleResponse_Trial); i { case 0: return &v.state @@ -7695,14 +7404,14 @@ func file_determined_api_v1_experiment_proto_init() { (*GetExperimentCheckpointsRequest_SortByMetric)(nil), } file_determined_api_v1_experiment_proto_msgTypes[55].OneofWrappers = []interface{}{} - file_determined_api_v1_experiment_proto_msgTypes[83].OneofWrappers = []interface{}{} + file_determined_api_v1_experiment_proto_msgTypes[79].OneofWrappers = []interface{}{} type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_determined_api_v1_experiment_proto_rawDesc, NumEnums: 2, - NumMessages: 90, + NumMessages: 86, NumExtensions: 0, NumServices: 0, }, diff --git a/proto/pkg/apiv1/trial.pb.go b/proto/pkg/apiv1/trial.pb.go index 9e3a30c408d..28b61a7e3e3 100644 --- a/proto/pkg/apiv1/trial.pb.go +++ b/proto/pkg/apiv1/trial.pb.go @@ -3644,211 +3644,6 @@ func (x *NotifyContainerRunningResponse) GetData() []*_struct.Struct { return nil } -// Retrieves the current searcher operation. -type GetCurrentTrialSearcherOperationRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - // The id of the trial. - TrialId int32 `protobuf:"varint,1,opt,name=trial_id,json=trialId,proto3" json:"trial_id,omitempty"` -} - -func (x *GetCurrentTrialSearcherOperationRequest) Reset() { - *x = GetCurrentTrialSearcherOperationRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_trial_proto_msgTypes[53] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *GetCurrentTrialSearcherOperationRequest) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*GetCurrentTrialSearcherOperationRequest) ProtoMessage() {} - -func (x *GetCurrentTrialSearcherOperationRequest) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_trial_proto_msgTypes[53] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use GetCurrentTrialSearcherOperationRequest.ProtoReflect.Descriptor instead. -func (*GetCurrentTrialSearcherOperationRequest) Descriptor() ([]byte, []int) { - return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{53} -} - -func (x *GetCurrentTrialSearcherOperationRequest) GetTrialId() int32 { - if x != nil { - return x.TrialId - } - return 0 -} - -// Response to GetCurrentTrialSearcherOperationRequest -type GetCurrentTrialSearcherOperationResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - // The current searcher operation. - Op *experimentv1.TrialOperation `protobuf:"bytes,1,opt,name=op,proto3" json:"op,omitempty"` - // The status of the searcher operation. - Completed bool `protobuf:"varint,2,opt,name=completed,proto3" json:"completed,omitempty"` -} - -func (x *GetCurrentTrialSearcherOperationResponse) Reset() { - *x = GetCurrentTrialSearcherOperationResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_trial_proto_msgTypes[54] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *GetCurrentTrialSearcherOperationResponse) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*GetCurrentTrialSearcherOperationResponse) ProtoMessage() {} - -func (x *GetCurrentTrialSearcherOperationResponse) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_trial_proto_msgTypes[54] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use GetCurrentTrialSearcherOperationResponse.ProtoReflect.Descriptor instead. -func (*GetCurrentTrialSearcherOperationResponse) Descriptor() ([]byte, []int) { - return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{54} -} - -func (x *GetCurrentTrialSearcherOperationResponse) GetOp() *experimentv1.TrialOperation { - if x != nil { - return x.Op - } - return nil -} - -func (x *GetCurrentTrialSearcherOperationResponse) GetCompleted() bool { - if x != nil { - return x.Completed - } - return false -} - -// Reports to the searcher that the trial has completed the current requested -// amount of training. -type CompleteTrialSearcherValidationRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - // The id of the trial. - TrialId int32 `protobuf:"varint,1,opt,name=trial_id,json=trialId,proto3" json:"trial_id,omitempty"` - // The completed operation. - CompletedOperation *experimentv1.CompleteValidateAfterOperation `protobuf:"bytes,2,opt,name=completed_operation,json=completedOperation,proto3" json:"completed_operation,omitempty"` -} - -func (x *CompleteTrialSearcherValidationRequest) Reset() { - *x = CompleteTrialSearcherValidationRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_trial_proto_msgTypes[55] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *CompleteTrialSearcherValidationRequest) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*CompleteTrialSearcherValidationRequest) ProtoMessage() {} - -func (x *CompleteTrialSearcherValidationRequest) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_trial_proto_msgTypes[55] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use CompleteTrialSearcherValidationRequest.ProtoReflect.Descriptor instead. -func (*CompleteTrialSearcherValidationRequest) Descriptor() ([]byte, []int) { - return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{55} -} - -func (x *CompleteTrialSearcherValidationRequest) GetTrialId() int32 { - if x != nil { - return x.TrialId - } - return 0 -} - -func (x *CompleteTrialSearcherValidationRequest) GetCompletedOperation() *experimentv1.CompleteValidateAfterOperation { - if x != nil { - return x.CompletedOperation - } - return nil -} - -// Response to CompleteTrialSearcherValidationRequest -type CompleteTrialSearcherValidationResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields -} - -func (x *CompleteTrialSearcherValidationResponse) Reset() { - *x = CompleteTrialSearcherValidationResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_trial_proto_msgTypes[56] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *CompleteTrialSearcherValidationResponse) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*CompleteTrialSearcherValidationResponse) ProtoMessage() {} - -func (x *CompleteTrialSearcherValidationResponse) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_trial_proto_msgTypes[56] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use CompleteTrialSearcherValidationResponse.ProtoReflect.Descriptor instead. -func (*CompleteTrialSearcherValidationResponse) Descriptor() ([]byte, []int) { - return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{56} -} - // Report a voluntary, permanent early exit to the searcher. type ReportTrialSearcherEarlyExitRequest struct { state protoimpl.MessageState @@ -3864,7 +3659,7 @@ type ReportTrialSearcherEarlyExitRequest struct { func (x *ReportTrialSearcherEarlyExitRequest) Reset() { *x = ReportTrialSearcherEarlyExitRequest{} if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_trial_proto_msgTypes[57] + mi := &file_determined_api_v1_trial_proto_msgTypes[53] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3877,7 +3672,7 @@ func (x *ReportTrialSearcherEarlyExitRequest) String() string { func (*ReportTrialSearcherEarlyExitRequest) ProtoMessage() {} func (x *ReportTrialSearcherEarlyExitRequest) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_trial_proto_msgTypes[57] + mi := &file_determined_api_v1_trial_proto_msgTypes[53] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3890,7 +3685,7 @@ func (x *ReportTrialSearcherEarlyExitRequest) ProtoReflect() protoreflect.Messag // Deprecated: Use ReportTrialSearcherEarlyExitRequest.ProtoReflect.Descriptor instead. func (*ReportTrialSearcherEarlyExitRequest) Descriptor() ([]byte, []int) { - return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{57} + return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{53} } func (x *ReportTrialSearcherEarlyExitRequest) GetTrialId() int32 { @@ -3917,7 +3712,7 @@ type ReportTrialSearcherEarlyExitResponse struct { func (x *ReportTrialSearcherEarlyExitResponse) Reset() { *x = ReportTrialSearcherEarlyExitResponse{} if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_trial_proto_msgTypes[58] + mi := &file_determined_api_v1_trial_proto_msgTypes[54] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3930,7 +3725,7 @@ func (x *ReportTrialSearcherEarlyExitResponse) String() string { func (*ReportTrialSearcherEarlyExitResponse) ProtoMessage() {} func (x *ReportTrialSearcherEarlyExitResponse) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_trial_proto_msgTypes[58] + mi := &file_determined_api_v1_trial_proto_msgTypes[54] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3943,7 +3738,7 @@ func (x *ReportTrialSearcherEarlyExitResponse) ProtoReflect() protoreflect.Messa // Deprecated: Use ReportTrialSearcherEarlyExitResponse.ProtoReflect.Descriptor instead. func (*ReportTrialSearcherEarlyExitResponse) Descriptor() ([]byte, []int) { - return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{58} + return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{54} } // For bookkeeping, updates the progress of the trial as a percent torwards @@ -3966,7 +3761,7 @@ type ReportTrialProgressRequest struct { func (x *ReportTrialProgressRequest) Reset() { *x = ReportTrialProgressRequest{} if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_trial_proto_msgTypes[59] + mi := &file_determined_api_v1_trial_proto_msgTypes[55] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3979,7 +3774,7 @@ func (x *ReportTrialProgressRequest) String() string { func (*ReportTrialProgressRequest) ProtoMessage() {} func (x *ReportTrialProgressRequest) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_trial_proto_msgTypes[59] + mi := &file_determined_api_v1_trial_proto_msgTypes[55] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3992,7 +3787,7 @@ func (x *ReportTrialProgressRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ReportTrialProgressRequest.ProtoReflect.Descriptor instead. func (*ReportTrialProgressRequest) Descriptor() ([]byte, []int) { - return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{59} + return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{55} } func (x *ReportTrialProgressRequest) GetTrialId() int32 { @@ -4026,7 +3821,7 @@ type ReportTrialProgressResponse struct { func (x *ReportTrialProgressResponse) Reset() { *x = ReportTrialProgressResponse{} if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_trial_proto_msgTypes[60] + mi := &file_determined_api_v1_trial_proto_msgTypes[56] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4039,7 +3834,7 @@ func (x *ReportTrialProgressResponse) String() string { func (*ReportTrialProgressResponse) ProtoMessage() {} func (x *ReportTrialProgressResponse) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_trial_proto_msgTypes[60] + mi := &file_determined_api_v1_trial_proto_msgTypes[56] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4052,7 +3847,7 @@ func (x *ReportTrialProgressResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ReportTrialProgressResponse.ProtoReflect.Descriptor instead. func (*ReportTrialProgressResponse) Descriptor() ([]byte, []int) { - return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{60} + return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{56} } // Persist the given metrics for the trial. @@ -4070,7 +3865,7 @@ type ReportTrialMetricsRequest struct { func (x *ReportTrialMetricsRequest) Reset() { *x = ReportTrialMetricsRequest{} if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_trial_proto_msgTypes[61] + mi := &file_determined_api_v1_trial_proto_msgTypes[57] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4083,7 +3878,7 @@ func (x *ReportTrialMetricsRequest) String() string { func (*ReportTrialMetricsRequest) ProtoMessage() {} func (x *ReportTrialMetricsRequest) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_trial_proto_msgTypes[61] + mi := &file_determined_api_v1_trial_proto_msgTypes[57] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4096,7 +3891,7 @@ func (x *ReportTrialMetricsRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ReportTrialMetricsRequest.ProtoReflect.Descriptor instead. func (*ReportTrialMetricsRequest) Descriptor() ([]byte, []int) { - return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{61} + return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{57} } func (x *ReportTrialMetricsRequest) GetMetrics() *trialv1.TrialMetrics { @@ -4123,7 +3918,7 @@ type ReportTrialMetricsResponse struct { func (x *ReportTrialMetricsResponse) Reset() { *x = ReportTrialMetricsResponse{} if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_trial_proto_msgTypes[62] + mi := &file_determined_api_v1_trial_proto_msgTypes[58] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4136,7 +3931,7 @@ func (x *ReportTrialMetricsResponse) String() string { func (*ReportTrialMetricsResponse) ProtoMessage() {} func (x *ReportTrialMetricsResponse) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_trial_proto_msgTypes[62] + mi := &file_determined_api_v1_trial_proto_msgTypes[58] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4149,7 +3944,7 @@ func (x *ReportTrialMetricsResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ReportTrialMetricsResponse.ProtoReflect.Descriptor instead. func (*ReportTrialMetricsResponse) Descriptor() ([]byte, []int) { - return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{62} + return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{58} } // Persist the given training metrics for the trial. @@ -4165,7 +3960,7 @@ type ReportTrialTrainingMetricsRequest struct { func (x *ReportTrialTrainingMetricsRequest) Reset() { *x = ReportTrialTrainingMetricsRequest{} if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_trial_proto_msgTypes[63] + mi := &file_determined_api_v1_trial_proto_msgTypes[59] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4178,7 +3973,7 @@ func (x *ReportTrialTrainingMetricsRequest) String() string { func (*ReportTrialTrainingMetricsRequest) ProtoMessage() {} func (x *ReportTrialTrainingMetricsRequest) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_trial_proto_msgTypes[63] + mi := &file_determined_api_v1_trial_proto_msgTypes[59] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4191,7 +3986,7 @@ func (x *ReportTrialTrainingMetricsRequest) ProtoReflect() protoreflect.Message // Deprecated: Use ReportTrialTrainingMetricsRequest.ProtoReflect.Descriptor instead. func (*ReportTrialTrainingMetricsRequest) Descriptor() ([]byte, []int) { - return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{63} + return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{59} } func (x *ReportTrialTrainingMetricsRequest) GetTrainingMetrics() *trialv1.TrialMetrics { @@ -4211,7 +4006,7 @@ type ReportTrialTrainingMetricsResponse struct { func (x *ReportTrialTrainingMetricsResponse) Reset() { *x = ReportTrialTrainingMetricsResponse{} if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_trial_proto_msgTypes[64] + mi := &file_determined_api_v1_trial_proto_msgTypes[60] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4224,7 +4019,7 @@ func (x *ReportTrialTrainingMetricsResponse) String() string { func (*ReportTrialTrainingMetricsResponse) ProtoMessage() {} func (x *ReportTrialTrainingMetricsResponse) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_trial_proto_msgTypes[64] + mi := &file_determined_api_v1_trial_proto_msgTypes[60] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4237,7 +4032,7 @@ func (x *ReportTrialTrainingMetricsResponse) ProtoReflect() protoreflect.Message // Deprecated: Use ReportTrialTrainingMetricsResponse.ProtoReflect.Descriptor instead. func (*ReportTrialTrainingMetricsResponse) Descriptor() ([]byte, []int) { - return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{64} + return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{60} } // Persist the given validation metrics for the trial. @@ -4253,7 +4048,7 @@ type ReportTrialValidationMetricsRequest struct { func (x *ReportTrialValidationMetricsRequest) Reset() { *x = ReportTrialValidationMetricsRequest{} if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_trial_proto_msgTypes[65] + mi := &file_determined_api_v1_trial_proto_msgTypes[61] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4266,7 +4061,7 @@ func (x *ReportTrialValidationMetricsRequest) String() string { func (*ReportTrialValidationMetricsRequest) ProtoMessage() {} func (x *ReportTrialValidationMetricsRequest) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_trial_proto_msgTypes[65] + mi := &file_determined_api_v1_trial_proto_msgTypes[61] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4279,7 +4074,7 @@ func (x *ReportTrialValidationMetricsRequest) ProtoReflect() protoreflect.Messag // Deprecated: Use ReportTrialValidationMetricsRequest.ProtoReflect.Descriptor instead. func (*ReportTrialValidationMetricsRequest) Descriptor() ([]byte, []int) { - return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{65} + return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{61} } func (x *ReportTrialValidationMetricsRequest) GetValidationMetrics() *trialv1.TrialMetrics { @@ -4299,7 +4094,7 @@ type ReportTrialValidationMetricsResponse struct { func (x *ReportTrialValidationMetricsResponse) Reset() { *x = ReportTrialValidationMetricsResponse{} if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_trial_proto_msgTypes[66] + mi := &file_determined_api_v1_trial_proto_msgTypes[62] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4312,7 +4107,7 @@ func (x *ReportTrialValidationMetricsResponse) String() string { func (*ReportTrialValidationMetricsResponse) ProtoMessage() {} func (x *ReportTrialValidationMetricsResponse) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_trial_proto_msgTypes[66] + mi := &file_determined_api_v1_trial_proto_msgTypes[62] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4325,7 +4120,7 @@ func (x *ReportTrialValidationMetricsResponse) ProtoReflect() protoreflect.Messa // Deprecated: Use ReportTrialValidationMetricsResponse.ProtoReflect.Descriptor instead. func (*ReportTrialValidationMetricsResponse) Descriptor() ([]byte, []int) { - return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{66} + return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{62} } // Partially update the trial metadata. @@ -4343,7 +4138,7 @@ type PostTrialRunnerMetadataRequest struct { func (x *PostTrialRunnerMetadataRequest) Reset() { *x = PostTrialRunnerMetadataRequest{} if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_trial_proto_msgTypes[67] + mi := &file_determined_api_v1_trial_proto_msgTypes[63] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4356,7 +4151,7 @@ func (x *PostTrialRunnerMetadataRequest) String() string { func (*PostTrialRunnerMetadataRequest) ProtoMessage() {} func (x *PostTrialRunnerMetadataRequest) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_trial_proto_msgTypes[67] + mi := &file_determined_api_v1_trial_proto_msgTypes[63] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4369,7 +4164,7 @@ func (x *PostTrialRunnerMetadataRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use PostTrialRunnerMetadataRequest.ProtoReflect.Descriptor instead. func (*PostTrialRunnerMetadataRequest) Descriptor() ([]byte, []int) { - return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{67} + return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{63} } func (x *PostTrialRunnerMetadataRequest) GetTrialId() int32 { @@ -4396,7 +4191,7 @@ type PostTrialRunnerMetadataResponse struct { func (x *PostTrialRunnerMetadataResponse) Reset() { *x = PostTrialRunnerMetadataResponse{} if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_trial_proto_msgTypes[68] + mi := &file_determined_api_v1_trial_proto_msgTypes[64] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4409,7 +4204,7 @@ func (x *PostTrialRunnerMetadataResponse) String() string { func (*PostTrialRunnerMetadataResponse) ProtoMessage() {} func (x *PostTrialRunnerMetadataResponse) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_trial_proto_msgTypes[68] + mi := &file_determined_api_v1_trial_proto_msgTypes[64] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4422,7 +4217,7 @@ func (x *PostTrialRunnerMetadataResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use PostTrialRunnerMetadataResponse.ProtoReflect.Descriptor instead. func (*PostTrialRunnerMetadataResponse) Descriptor() ([]byte, []int) { - return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{68} + return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{64} } // Stream training metrics. @@ -4440,7 +4235,7 @@ type GetMetricsRequest struct { func (x *GetMetricsRequest) Reset() { *x = GetMetricsRequest{} if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_trial_proto_msgTypes[69] + mi := &file_determined_api_v1_trial_proto_msgTypes[65] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4453,7 +4248,7 @@ func (x *GetMetricsRequest) String() string { func (*GetMetricsRequest) ProtoMessage() {} func (x *GetMetricsRequest) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_trial_proto_msgTypes[69] + mi := &file_determined_api_v1_trial_proto_msgTypes[65] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4466,7 +4261,7 @@ func (x *GetMetricsRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetMetricsRequest.ProtoReflect.Descriptor instead. func (*GetMetricsRequest) Descriptor() ([]byte, []int) { - return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{69} + return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{65} } func (x *GetMetricsRequest) GetTrialIds() []int32 { @@ -4496,7 +4291,7 @@ type GetMetricsResponse struct { func (x *GetMetricsResponse) Reset() { *x = GetMetricsResponse{} if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_trial_proto_msgTypes[70] + mi := &file_determined_api_v1_trial_proto_msgTypes[66] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4509,7 +4304,7 @@ func (x *GetMetricsResponse) String() string { func (*GetMetricsResponse) ProtoMessage() {} func (x *GetMetricsResponse) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_trial_proto_msgTypes[70] + mi := &file_determined_api_v1_trial_proto_msgTypes[66] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4522,7 +4317,7 @@ func (x *GetMetricsResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetMetricsResponse.ProtoReflect.Descriptor instead. func (*GetMetricsResponse) Descriptor() ([]byte, []int) { - return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{70} + return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{66} } func (x *GetMetricsResponse) GetMetrics() []*trialv1.MetricsReport { @@ -4545,7 +4340,7 @@ type GetTrainingMetricsRequest struct { func (x *GetTrainingMetricsRequest) Reset() { *x = GetTrainingMetricsRequest{} if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_trial_proto_msgTypes[71] + mi := &file_determined_api_v1_trial_proto_msgTypes[67] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4558,7 +4353,7 @@ func (x *GetTrainingMetricsRequest) String() string { func (*GetTrainingMetricsRequest) ProtoMessage() {} func (x *GetTrainingMetricsRequest) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_trial_proto_msgTypes[71] + mi := &file_determined_api_v1_trial_proto_msgTypes[67] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4571,7 +4366,7 @@ func (x *GetTrainingMetricsRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetTrainingMetricsRequest.ProtoReflect.Descriptor instead. func (*GetTrainingMetricsRequest) Descriptor() ([]byte, []int) { - return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{71} + return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{67} } func (x *GetTrainingMetricsRequest) GetTrialIds() []int32 { @@ -4594,7 +4389,7 @@ type GetTrainingMetricsResponse struct { func (x *GetTrainingMetricsResponse) Reset() { *x = GetTrainingMetricsResponse{} if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_trial_proto_msgTypes[72] + mi := &file_determined_api_v1_trial_proto_msgTypes[68] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4607,7 +4402,7 @@ func (x *GetTrainingMetricsResponse) String() string { func (*GetTrainingMetricsResponse) ProtoMessage() {} func (x *GetTrainingMetricsResponse) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_trial_proto_msgTypes[72] + mi := &file_determined_api_v1_trial_proto_msgTypes[68] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4620,7 +4415,7 @@ func (x *GetTrainingMetricsResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetTrainingMetricsResponse.ProtoReflect.Descriptor instead. func (*GetTrainingMetricsResponse) Descriptor() ([]byte, []int) { - return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{72} + return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{68} } func (x *GetTrainingMetricsResponse) GetMetrics() []*trialv1.MetricsReport { @@ -4643,7 +4438,7 @@ type GetValidationMetricsRequest struct { func (x *GetValidationMetricsRequest) Reset() { *x = GetValidationMetricsRequest{} if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_trial_proto_msgTypes[73] + mi := &file_determined_api_v1_trial_proto_msgTypes[69] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4656,7 +4451,7 @@ func (x *GetValidationMetricsRequest) String() string { func (*GetValidationMetricsRequest) ProtoMessage() {} func (x *GetValidationMetricsRequest) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_trial_proto_msgTypes[73] + mi := &file_determined_api_v1_trial_proto_msgTypes[69] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4669,7 +4464,7 @@ func (x *GetValidationMetricsRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetValidationMetricsRequest.ProtoReflect.Descriptor instead. func (*GetValidationMetricsRequest) Descriptor() ([]byte, []int) { - return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{73} + return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{69} } func (x *GetValidationMetricsRequest) GetTrialIds() []int32 { @@ -4692,7 +4487,7 @@ type GetValidationMetricsResponse struct { func (x *GetValidationMetricsResponse) Reset() { *x = GetValidationMetricsResponse{} if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_trial_proto_msgTypes[74] + mi := &file_determined_api_v1_trial_proto_msgTypes[70] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4705,7 +4500,7 @@ func (x *GetValidationMetricsResponse) String() string { func (*GetValidationMetricsResponse) ProtoMessage() {} func (x *GetValidationMetricsResponse) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_trial_proto_msgTypes[74] + mi := &file_determined_api_v1_trial_proto_msgTypes[70] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4718,7 +4513,7 @@ func (x *GetValidationMetricsResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetValidationMetricsResponse.ProtoReflect.Descriptor instead. func (*GetValidationMetricsResponse) Descriptor() ([]byte, []int) { - return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{74} + return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{70} } func (x *GetValidationMetricsResponse) GetMetrics() []*trialv1.MetricsReport { @@ -4745,7 +4540,7 @@ type CreateTrialRequest struct { func (x *CreateTrialRequest) Reset() { *x = CreateTrialRequest{} if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_trial_proto_msgTypes[75] + mi := &file_determined_api_v1_trial_proto_msgTypes[71] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4758,7 +4553,7 @@ func (x *CreateTrialRequest) String() string { func (*CreateTrialRequest) ProtoMessage() {} func (x *CreateTrialRequest) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_trial_proto_msgTypes[75] + mi := &file_determined_api_v1_trial_proto_msgTypes[71] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4771,7 +4566,7 @@ func (x *CreateTrialRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use CreateTrialRequest.ProtoReflect.Descriptor instead. func (*CreateTrialRequest) Descriptor() ([]byte, []int) { - return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{75} + return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{71} } func (x *CreateTrialRequest) GetExperimentId() int32 { @@ -4808,7 +4603,7 @@ type CreateTrialResponse struct { func (x *CreateTrialResponse) Reset() { *x = CreateTrialResponse{} if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_trial_proto_msgTypes[76] + mi := &file_determined_api_v1_trial_proto_msgTypes[72] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4821,7 +4616,7 @@ func (x *CreateTrialResponse) String() string { func (*CreateTrialResponse) ProtoMessage() {} func (x *CreateTrialResponse) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_trial_proto_msgTypes[76] + mi := &file_determined_api_v1_trial_proto_msgTypes[72] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4834,7 +4629,7 @@ func (x *CreateTrialResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use CreateTrialResponse.ProtoReflect.Descriptor instead. func (*CreateTrialResponse) Descriptor() ([]byte, []int) { - return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{76} + return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{72} } func (x *CreateTrialResponse) GetTrial() *trialv1.Trial { @@ -4859,7 +4654,7 @@ type PutTrialRequest struct { func (x *PutTrialRequest) Reset() { *x = PutTrialRequest{} if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_trial_proto_msgTypes[77] + mi := &file_determined_api_v1_trial_proto_msgTypes[73] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4872,7 +4667,7 @@ func (x *PutTrialRequest) String() string { func (*PutTrialRequest) ProtoMessage() {} func (x *PutTrialRequest) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_trial_proto_msgTypes[77] + mi := &file_determined_api_v1_trial_proto_msgTypes[73] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4885,7 +4680,7 @@ func (x *PutTrialRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use PutTrialRequest.ProtoReflect.Descriptor instead. func (*PutTrialRequest) Descriptor() ([]byte, []int) { - return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{77} + return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{73} } func (x *PutTrialRequest) GetCreateTrialRequest() *CreateTrialRequest { @@ -4915,7 +4710,7 @@ type PutTrialResponse struct { func (x *PutTrialResponse) Reset() { *x = PutTrialResponse{} if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_trial_proto_msgTypes[78] + mi := &file_determined_api_v1_trial_proto_msgTypes[74] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4928,7 +4723,7 @@ func (x *PutTrialResponse) String() string { func (*PutTrialResponse) ProtoMessage() {} func (x *PutTrialResponse) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_trial_proto_msgTypes[78] + mi := &file_determined_api_v1_trial_proto_msgTypes[74] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4941,7 +4736,7 @@ func (x *PutTrialResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use PutTrialResponse.ProtoReflect.Descriptor instead. func (*PutTrialResponse) Descriptor() ([]byte, []int) { - return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{78} + return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{74} } func (x *PutTrialResponse) GetTrial() *trialv1.Trial { @@ -4966,7 +4761,7 @@ type PatchTrialRequest struct { func (x *PatchTrialRequest) Reset() { *x = PatchTrialRequest{} if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_trial_proto_msgTypes[79] + mi := &file_determined_api_v1_trial_proto_msgTypes[75] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4979,7 +4774,7 @@ func (x *PatchTrialRequest) String() string { func (*PatchTrialRequest) ProtoMessage() {} func (x *PatchTrialRequest) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_trial_proto_msgTypes[79] + mi := &file_determined_api_v1_trial_proto_msgTypes[75] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4992,7 +4787,7 @@ func (x *PatchTrialRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use PatchTrialRequest.ProtoReflect.Descriptor instead. func (*PatchTrialRequest) Descriptor() ([]byte, []int) { - return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{79} + return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{75} } func (x *PatchTrialRequest) GetTrialId() int32 { @@ -5022,7 +4817,7 @@ type PatchTrialResponse struct { func (x *PatchTrialResponse) Reset() { *x = PatchTrialResponse{} if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_trial_proto_msgTypes[80] + mi := &file_determined_api_v1_trial_proto_msgTypes[76] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5035,7 +4830,7 @@ func (x *PatchTrialResponse) String() string { func (*PatchTrialResponse) ProtoMessage() {} func (x *PatchTrialResponse) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_trial_proto_msgTypes[80] + mi := &file_determined_api_v1_trial_proto_msgTypes[76] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5048,7 +4843,7 @@ func (x *PatchTrialResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use PatchTrialResponse.ProtoReflect.Descriptor instead. func (*PatchTrialResponse) Descriptor() ([]byte, []int) { - return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{80} + return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{76} } func (x *PatchTrialResponse) GetTrial() *trialv1.Trial { @@ -5073,7 +4868,7 @@ type StartTrialRequest struct { func (x *StartTrialRequest) Reset() { *x = StartTrialRequest{} if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_trial_proto_msgTypes[81] + mi := &file_determined_api_v1_trial_proto_msgTypes[77] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5086,7 +4881,7 @@ func (x *StartTrialRequest) String() string { func (*StartTrialRequest) ProtoMessage() {} func (x *StartTrialRequest) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_trial_proto_msgTypes[81] + mi := &file_determined_api_v1_trial_proto_msgTypes[77] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5099,7 +4894,7 @@ func (x *StartTrialRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use StartTrialRequest.ProtoReflect.Descriptor instead. func (*StartTrialRequest) Descriptor() ([]byte, []int) { - return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{81} + return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{77} } func (x *StartTrialRequest) GetTrialId() int32 { @@ -5133,7 +4928,7 @@ type StartTrialResponse struct { func (x *StartTrialResponse) Reset() { *x = StartTrialResponse{} if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_trial_proto_msgTypes[82] + mi := &file_determined_api_v1_trial_proto_msgTypes[78] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5146,7 +4941,7 @@ func (x *StartTrialResponse) String() string { func (*StartTrialResponse) ProtoMessage() {} func (x *StartTrialResponse) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_trial_proto_msgTypes[82] + mi := &file_determined_api_v1_trial_proto_msgTypes[78] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5159,7 +4954,7 @@ func (x *StartTrialResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use StartTrialResponse.ProtoReflect.Descriptor instead. func (*StartTrialResponse) Descriptor() ([]byte, []int) { - return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{82} + return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{78} } func (x *StartTrialResponse) GetTrialRunId() int32 { @@ -5196,7 +4991,7 @@ type ReportTrialSourceInfoRequest struct { func (x *ReportTrialSourceInfoRequest) Reset() { *x = ReportTrialSourceInfoRequest{} if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_trial_proto_msgTypes[83] + mi := &file_determined_api_v1_trial_proto_msgTypes[79] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5209,7 +5004,7 @@ func (x *ReportTrialSourceInfoRequest) String() string { func (*ReportTrialSourceInfoRequest) ProtoMessage() {} func (x *ReportTrialSourceInfoRequest) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_trial_proto_msgTypes[83] + mi := &file_determined_api_v1_trial_proto_msgTypes[79] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5222,7 +5017,7 @@ func (x *ReportTrialSourceInfoRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ReportTrialSourceInfoRequest.ProtoReflect.Descriptor instead. func (*ReportTrialSourceInfoRequest) Descriptor() ([]byte, []int) { - return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{83} + return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{79} } func (x *ReportTrialSourceInfoRequest) GetTrialSourceInfo() *trialv1.TrialSourceInfo { @@ -5247,7 +5042,7 @@ type ReportTrialSourceInfoResponse struct { func (x *ReportTrialSourceInfoResponse) Reset() { *x = ReportTrialSourceInfoResponse{} if protoimpl.UnsafeEnabled { - mi := &file_determined_api_v1_trial_proto_msgTypes[84] + mi := &file_determined_api_v1_trial_proto_msgTypes[80] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5260,7 +5055,7 @@ func (x *ReportTrialSourceInfoResponse) String() string { func (*ReportTrialSourceInfoResponse) ProtoMessage() {} func (x *ReportTrialSourceInfoResponse) ProtoReflect() protoreflect.Message { - mi := &file_determined_api_v1_trial_proto_msgTypes[84] + mi := &file_determined_api_v1_trial_proto_msgTypes[80] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5273,7 +5068,7 @@ func (x *ReportTrialSourceInfoResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ReportTrialSourceInfoResponse.ProtoReflect.Descriptor instead. func (*ReportTrialSourceInfoResponse) Descriptor() ([]byte, []int) { - return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{84} + return file_determined_api_v1_trial_proto_rawDescGZIP(), []int{80} } func (x *ReportTrialSourceInfoResponse) GetTrialId() int32 { @@ -5877,227 +5672,197 @@ var file_determined_api_v1_trial_proto_rawDesc = []byte{ 0x0a, 0x04, 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x53, 0x74, 0x72, 0x75, 0x63, 0x74, 0x52, 0x04, 0x64, 0x61, 0x74, 0x61, 0x3a, 0x0c, 0x92, 0x41, 0x09, - 0x0a, 0x07, 0xd2, 0x01, 0x04, 0x64, 0x61, 0x74, 0x61, 0x22, 0x56, 0x0a, 0x27, 0x47, 0x65, 0x74, - 0x43, 0x75, 0x72, 0x72, 0x65, 0x6e, 0x74, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x53, 0x65, 0x61, 0x72, - 0x63, 0x68, 0x65, 0x72, 0x4f, 0x70, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x12, 0x19, 0x0a, 0x08, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x69, 0x64, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x49, 0x64, 0x3a, - 0x10, 0x92, 0x41, 0x0d, 0x0a, 0x0b, 0xd2, 0x01, 0x08, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x69, - 0x64, 0x22, 0x82, 0x01, 0x0a, 0x28, 0x47, 0x65, 0x74, 0x43, 0x75, 0x72, 0x72, 0x65, 0x6e, 0x74, - 0x54, 0x72, 0x69, 0x61, 0x6c, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x65, 0x72, 0x4f, 0x70, 0x65, - 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x38, - 0x0a, 0x02, 0x6f, 0x70, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x28, 0x2e, 0x64, 0x65, 0x74, - 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x4f, 0x70, 0x65, 0x72, 0x61, - 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x02, 0x6f, 0x70, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x6f, 0x6d, 0x70, - 0x6c, 0x65, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6d, - 0x70, 0x6c, 0x65, 0x74, 0x65, 0x64, 0x22, 0xd2, 0x01, 0x0a, 0x26, 0x43, 0x6f, 0x6d, 0x70, 0x6c, - 0x65, 0x74, 0x65, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x65, 0x72, - 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x0a, 0x07, 0xd2, 0x01, 0x04, 0x64, 0x61, 0x74, 0x61, 0x22, 0x9f, 0x01, 0x0a, 0x23, 0x52, 0x65, + 0x70, 0x6f, 0x72, 0x74, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x65, + 0x72, 0x45, 0x61, 0x72, 0x6c, 0x79, 0x45, 0x78, 0x69, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x19, 0x0a, 0x08, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x05, 0x52, 0x07, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x49, 0x64, 0x12, 0x69, 0x0a, 0x13, - 0x63, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x64, 0x5f, 0x6f, 0x70, 0x65, 0x72, 0x61, 0x74, - 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x38, 0x2e, 0x64, 0x65, 0x74, 0x65, - 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x56, 0x61, 0x6c, - 0x69, 0x64, 0x61, 0x74, 0x65, 0x41, 0x66, 0x74, 0x65, 0x72, 0x4f, 0x70, 0x65, 0x72, 0x61, 0x74, - 0x69, 0x6f, 0x6e, 0x52, 0x12, 0x63, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x64, 0x4f, 0x70, - 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x3a, 0x22, 0x92, 0x41, 0x1f, 0x0a, 0x1d, 0xd2, 0x01, - 0x08, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x69, 0x64, 0xd2, 0x01, 0x0f, 0x73, 0x65, 0x61, 0x72, - 0x63, 0x68, 0x65, 0x72, 0x5f, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x22, 0x29, 0x0a, 0x27, 0x43, - 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x53, 0x65, 0x61, 0x72, - 0x63, 0x68, 0x65, 0x72, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x9f, 0x01, 0x0a, 0x23, 0x52, 0x65, 0x70, 0x6f, 0x72, - 0x74, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x65, 0x72, 0x45, 0x61, - 0x72, 0x6c, 0x79, 0x45, 0x78, 0x69, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x19, - 0x0a, 0x08, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, - 0x52, 0x07, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x49, 0x64, 0x12, 0x42, 0x0a, 0x0a, 0x65, 0x61, 0x72, - 0x6c, 0x79, 0x5f, 0x65, 0x78, 0x69, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x23, 0x2e, - 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x74, 0x72, 0x69, 0x61, 0x6c, - 0x2e, 0x76, 0x31, 0x2e, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x45, 0x61, 0x72, 0x6c, 0x79, 0x45, 0x78, - 0x69, 0x74, 0x52, 0x09, 0x65, 0x61, 0x72, 0x6c, 0x79, 0x45, 0x78, 0x69, 0x74, 0x3a, 0x19, 0x92, - 0x41, 0x16, 0x0a, 0x14, 0xd2, 0x01, 0x08, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x69, 0x64, 0xd2, - 0x01, 0x06, 0x72, 0x65, 0x61, 0x73, 0x6f, 0x6e, 0x22, 0x26, 0x0a, 0x24, 0x52, 0x65, 0x70, 0x6f, - 0x72, 0x74, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x65, 0x72, 0x45, - 0x61, 0x72, 0x6c, 0x79, 0x45, 0x78, 0x69, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x87, 0x01, 0x0a, 0x1a, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x54, 0x72, 0x69, 0x61, 0x6c, - 0x50, 0x72, 0x6f, 0x67, 0x72, 0x65, 0x73, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, - 0x19, 0x0a, 0x08, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x05, 0x52, 0x07, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x49, 0x64, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x72, - 0x6f, 0x67, 0x72, 0x65, 0x73, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x01, 0x52, 0x08, 0x70, 0x72, - 0x6f, 0x67, 0x72, 0x65, 0x73, 0x73, 0x12, 0x15, 0x0a, 0x06, 0x69, 0x73, 0x5f, 0x72, 0x61, 0x77, - 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x69, 0x73, 0x52, 0x61, 0x77, 0x3a, 0x1b, 0x92, - 0x41, 0x18, 0x0a, 0x16, 0xd2, 0x01, 0x08, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x69, 0x64, 0xd2, - 0x01, 0x08, 0x70, 0x72, 0x6f, 0x67, 0x72, 0x65, 0x73, 0x73, 0x22, 0x1d, 0x0a, 0x1b, 0x52, 0x65, - 0x70, 0x6f, 0x72, 0x74, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x50, 0x72, 0x6f, 0x67, 0x72, 0x65, 0x73, - 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x87, 0x01, 0x0a, 0x19, 0x52, 0x65, - 0x70, 0x6f, 0x72, 0x74, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x3b, 0x0a, 0x07, 0x6d, 0x65, 0x74, 0x72, 0x69, - 0x63, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, - 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x2e, 0x76, 0x31, 0x2e, 0x54, - 0x72, 0x69, 0x61, 0x6c, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x52, 0x07, 0x6d, 0x65, 0x74, - 0x72, 0x69, 0x63, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x05, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x3a, 0x17, 0x92, 0x41, 0x14, 0x0a, - 0x12, 0xd2, 0x01, 0x07, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0xd2, 0x01, 0x05, 0x67, 0x72, - 0x6f, 0x75, 0x70, 0x22, 0x1c, 0x0a, 0x1a, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x54, 0x72, 0x69, - 0x61, 0x6c, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, - 0x65, 0x22, 0x8b, 0x01, 0x0a, 0x21, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x54, 0x72, 0x69, 0x61, + 0x01, 0x28, 0x05, 0x52, 0x07, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x49, 0x64, 0x12, 0x42, 0x0a, 0x0a, + 0x65, 0x61, 0x72, 0x6c, 0x79, 0x5f, 0x65, 0x78, 0x69, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x23, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x74, 0x72, + 0x69, 0x61, 0x6c, 0x2e, 0x76, 0x31, 0x2e, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x45, 0x61, 0x72, 0x6c, + 0x79, 0x45, 0x78, 0x69, 0x74, 0x52, 0x09, 0x65, 0x61, 0x72, 0x6c, 0x79, 0x45, 0x78, 0x69, 0x74, + 0x3a, 0x19, 0x92, 0x41, 0x16, 0x0a, 0x14, 0xd2, 0x01, 0x08, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, + 0x69, 0x64, 0xd2, 0x01, 0x06, 0x72, 0x65, 0x61, 0x73, 0x6f, 0x6e, 0x22, 0x26, 0x0a, 0x24, 0x52, + 0x65, 0x70, 0x6f, 0x72, 0x74, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, + 0x65, 0x72, 0x45, 0x61, 0x72, 0x6c, 0x79, 0x45, 0x78, 0x69, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x22, 0x87, 0x01, 0x0a, 0x1a, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x54, 0x72, + 0x69, 0x61, 0x6c, 0x50, 0x72, 0x6f, 0x67, 0x72, 0x65, 0x73, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x12, 0x19, 0x0a, 0x08, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x69, 0x64, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x49, 0x64, 0x12, 0x1a, 0x0a, + 0x08, 0x70, 0x72, 0x6f, 0x67, 0x72, 0x65, 0x73, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x01, 0x52, + 0x08, 0x70, 0x72, 0x6f, 0x67, 0x72, 0x65, 0x73, 0x73, 0x12, 0x15, 0x0a, 0x06, 0x69, 0x73, 0x5f, + 0x72, 0x61, 0x77, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x69, 0x73, 0x52, 0x61, 0x77, + 0x3a, 0x1b, 0x92, 0x41, 0x18, 0x0a, 0x16, 0xd2, 0x01, 0x08, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, + 0x69, 0x64, 0xd2, 0x01, 0x08, 0x70, 0x72, 0x6f, 0x67, 0x72, 0x65, 0x73, 0x73, 0x22, 0x1d, 0x0a, + 0x1b, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x50, 0x72, 0x6f, 0x67, + 0x72, 0x65, 0x73, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x87, 0x01, 0x0a, + 0x19, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x4d, 0x65, 0x74, 0x72, + 0x69, 0x63, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x3b, 0x0a, 0x07, 0x6d, 0x65, + 0x74, 0x72, 0x69, 0x63, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x64, 0x65, + 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x2e, 0x76, + 0x31, 0x2e, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x52, 0x07, + 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x67, 0x72, 0x6f, 0x75, 0x70, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x3a, 0x17, 0x92, + 0x41, 0x14, 0x0a, 0x12, 0xd2, 0x01, 0x07, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0xd2, 0x01, + 0x05, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x22, 0x1c, 0x0a, 0x1a, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, + 0x54, 0x72, 0x69, 0x61, 0x6c, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x8b, 0x01, 0x0a, 0x21, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x54, + 0x72, 0x69, 0x61, 0x6c, 0x54, 0x72, 0x61, 0x69, 0x6e, 0x69, 0x6e, 0x67, 0x4d, 0x65, 0x74, 0x72, + 0x69, 0x63, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x4c, 0x0a, 0x10, 0x74, 0x72, + 0x61, 0x69, 0x6e, 0x69, 0x6e, 0x67, 0x5f, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, + 0x64, 0x2e, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x2e, 0x76, 0x31, 0x2e, 0x54, 0x72, 0x69, 0x61, 0x6c, + 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x52, 0x0f, 0x74, 0x72, 0x61, 0x69, 0x6e, 0x69, 0x6e, + 0x67, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x3a, 0x18, 0x92, 0x41, 0x15, 0x0a, 0x13, 0xd2, + 0x01, 0x10, 0x74, 0x72, 0x61, 0x69, 0x6e, 0x69, 0x6e, 0x67, 0x5f, 0x6d, 0x65, 0x74, 0x72, 0x69, + 0x63, 0x73, 0x22, 0x24, 0x0a, 0x22, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x54, 0x72, 0x61, 0x69, 0x6e, 0x69, 0x6e, 0x67, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x4c, 0x0a, 0x10, 0x74, 0x72, 0x61, 0x69, 0x6e, - 0x69, 0x6e, 0x67, 0x5f, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x0b, 0x32, 0x21, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x74, - 0x72, 0x69, 0x61, 0x6c, 0x2e, 0x76, 0x31, 0x2e, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x4d, 0x65, 0x74, - 0x72, 0x69, 0x63, 0x73, 0x52, 0x0f, 0x74, 0x72, 0x61, 0x69, 0x6e, 0x69, 0x6e, 0x67, 0x4d, 0x65, - 0x74, 0x72, 0x69, 0x63, 0x73, 0x3a, 0x18, 0x92, 0x41, 0x15, 0x0a, 0x13, 0xd2, 0x01, 0x10, 0x74, - 0x72, 0x61, 0x69, 0x6e, 0x69, 0x6e, 0x67, 0x5f, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x22, - 0x24, 0x0a, 0x22, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x54, 0x72, - 0x61, 0x69, 0x6e, 0x69, 0x6e, 0x67, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x93, 0x01, 0x0a, 0x23, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, - 0x54, 0x72, 0x69, 0x61, 0x6c, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x4d, - 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x50, 0x0a, - 0x12, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x6d, 0x65, 0x74, 0x72, - 0x69, 0x63, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x64, 0x65, 0x74, 0x65, - 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x2e, 0x76, 0x31, 0x2e, - 0x54, 0x72, 0x69, 0x61, 0x6c, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x52, 0x11, 0x76, 0x61, - 0x6c, 0x69, 0x64, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x3a, - 0x1a, 0x92, 0x41, 0x17, 0x0a, 0x15, 0xd2, 0x01, 0x12, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, - 0x69, 0x6f, 0x6e, 0x5f, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x22, 0x26, 0x0a, 0x24, 0x52, - 0x65, 0x70, 0x6f, 0x72, 0x74, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, - 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x9e, 0x01, 0x0a, 0x1e, 0x50, 0x6f, 0x73, 0x74, 0x54, 0x72, 0x69, 0x61, - 0x6c, 0x52, 0x75, 0x6e, 0x6e, 0x65, 0x72, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x19, 0x0a, 0x08, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, - 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x49, - 0x64, 0x12, 0x44, 0x0a, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x0b, 0x32, 0x28, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, - 0x2e, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x2e, 0x76, 0x31, 0x2e, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x52, - 0x75, 0x6e, 0x6e, 0x65, 0x72, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x52, 0x08, 0x6d, - 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x3a, 0x1b, 0x92, 0x41, 0x18, 0x0a, 0x16, 0xd2, 0x01, - 0x08, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x69, 0x64, 0xd2, 0x01, 0x08, 0x6d, 0x65, 0x74, 0x61, - 0x64, 0x61, 0x74, 0x61, 0x22, 0x21, 0x0a, 0x1f, 0x50, 0x6f, 0x73, 0x74, 0x54, 0x72, 0x69, 0x61, - 0x6c, 0x52, 0x75, 0x6e, 0x6e, 0x65, 0x72, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x7f, 0x0a, 0x11, 0x47, 0x65, 0x74, 0x4d, 0x65, - 0x74, 0x72, 0x69, 0x63, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x2c, 0x0a, 0x09, - 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x05, 0x42, - 0x0f, 0x92, 0x41, 0x0c, 0xd2, 0x01, 0x09, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x69, 0x64, 0x73, - 0x52, 0x08, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x49, 0x64, 0x73, 0x12, 0x21, 0x0a, 0x05, 0x67, 0x72, - 0x6f, 0x75, 0x70, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x42, 0x0b, 0x92, 0x41, 0x08, 0xd2, 0x01, - 0x05, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x52, 0x05, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x3a, 0x19, 0x92, - 0x41, 0x16, 0x0a, 0x14, 0xd2, 0x01, 0x09, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x69, 0x64, 0x73, - 0xd2, 0x01, 0x05, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x22, 0x63, 0x0a, 0x12, 0x47, 0x65, 0x74, 0x4d, - 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x3c, - 0x0a, 0x07, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, - 0x22, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x74, 0x72, 0x69, - 0x61, 0x6c, 0x2e, 0x76, 0x31, 0x2e, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x52, 0x65, 0x70, - 0x6f, 0x72, 0x74, 0x52, 0x07, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x3a, 0x0f, 0x92, 0x41, - 0x0c, 0x0a, 0x0a, 0xd2, 0x01, 0x07, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x22, 0x4b, 0x0a, - 0x19, 0x47, 0x65, 0x74, 0x54, 0x72, 0x61, 0x69, 0x6e, 0x69, 0x6e, 0x67, 0x4d, 0x65, 0x74, 0x72, - 0x69, 0x63, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1b, 0x0a, 0x09, 0x74, 0x72, - 0x69, 0x61, 0x6c, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x05, 0x52, 0x08, 0x74, - 0x72, 0x69, 0x61, 0x6c, 0x49, 0x64, 0x73, 0x3a, 0x11, 0x92, 0x41, 0x0e, 0x0a, 0x0c, 0xd2, 0x01, - 0x09, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x69, 0x64, 0x73, 0x22, 0x6b, 0x0a, 0x1a, 0x47, 0x65, - 0x74, 0x54, 0x72, 0x61, 0x69, 0x6e, 0x69, 0x6e, 0x67, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x3c, 0x0a, 0x07, 0x6d, 0x65, 0x74, 0x72, - 0x69, 0x63, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x22, 0x2e, 0x64, 0x65, 0x74, 0x65, - 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x2e, 0x76, 0x31, 0x2e, - 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x52, 0x07, 0x6d, - 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x3a, 0x0f, 0x92, 0x41, 0x0c, 0x0a, 0x0a, 0xd2, 0x01, 0x07, - 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x22, 0x4d, 0x0a, 0x1b, 0x47, 0x65, 0x74, 0x56, 0x61, - 0x6c, 0x69, 0x64, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1b, 0x0a, 0x09, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, - 0x69, 0x64, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x05, 0x52, 0x08, 0x74, 0x72, 0x69, 0x61, 0x6c, - 0x49, 0x64, 0x73, 0x3a, 0x11, 0x92, 0x41, 0x0e, 0x0a, 0x0c, 0xd2, 0x01, 0x09, 0x74, 0x72, 0x69, - 0x61, 0x6c, 0x5f, 0x69, 0x64, 0x73, 0x22, 0x6d, 0x0a, 0x1c, 0x47, 0x65, 0x74, 0x56, 0x61, 0x6c, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x93, 0x01, 0x0a, 0x23, 0x52, 0x65, 0x70, + 0x6f, 0x72, 0x74, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x69, + 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x12, 0x50, 0x0a, 0x12, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x6d, + 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x64, + 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x2e, + 0x76, 0x31, 0x2e, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x52, + 0x11, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x72, 0x69, + 0x63, 0x73, 0x3a, 0x1a, 0x92, 0x41, 0x17, 0x0a, 0x15, 0xd2, 0x01, 0x12, 0x76, 0x61, 0x6c, 0x69, + 0x64, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x22, 0x26, + 0x0a, 0x24, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x3c, 0x0a, 0x07, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, - 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x22, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, - 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x2e, 0x76, 0x31, 0x2e, 0x4d, 0x65, - 0x74, 0x72, 0x69, 0x63, 0x73, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x52, 0x07, 0x6d, 0x65, 0x74, - 0x72, 0x69, 0x63, 0x73, 0x3a, 0x0f, 0x92, 0x41, 0x0c, 0x0a, 0x0a, 0xd2, 0x01, 0x07, 0x6d, 0x65, - 0x74, 0x72, 0x69, 0x63, 0x73, 0x22, 0x8a, 0x01, 0x0a, 0x12, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, - 0x54, 0x72, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x23, 0x0a, 0x0d, - 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x05, 0x52, 0x0c, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x49, - 0x64, 0x12, 0x31, 0x0a, 0x07, 0x68, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x18, 0x06, 0x20, 0x01, - 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x53, 0x74, 0x72, 0x75, 0x63, 0x74, 0x52, 0x07, 0x68, 0x70, 0x61, - 0x72, 0x61, 0x6d, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x75, 0x6e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x64, 0x18, 0x28, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x75, 0x6e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x64, 0x22, 0x56, 0x0a, 0x13, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x54, 0x72, 0x69, 0x61, - 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x30, 0x0a, 0x05, 0x74, 0x72, 0x69, - 0x61, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, - 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x2e, 0x76, 0x31, 0x2e, 0x54, - 0x72, 0x69, 0x61, 0x6c, 0x52, 0x05, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x3a, 0x0d, 0x92, 0x41, 0x0a, - 0x0a, 0x08, 0xd2, 0x01, 0x05, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x22, 0x96, 0x01, 0x0a, 0x0f, 0x50, - 0x75, 0x74, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x57, - 0x0a, 0x14, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x5f, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x72, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x25, 0x2e, 0x64, - 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, - 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x52, 0x12, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x54, 0x72, 0x69, 0x61, 0x6c, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x2a, 0x0a, 0x11, 0x65, 0x78, 0x74, 0x65, 0x72, - 0x6e, 0x61, 0x6c, 0x5f, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x69, 0x64, 0x18, 0x29, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x0f, 0x65, 0x78, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x54, 0x72, 0x69, 0x61, - 0x6c, 0x49, 0x64, 0x22, 0x53, 0x0a, 0x10, 0x50, 0x75, 0x74, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x30, 0x0a, 0x05, 0x74, 0x72, 0x69, 0x61, 0x6c, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x9e, 0x01, 0x0a, 0x1e, 0x50, 0x6f, 0x73, 0x74, 0x54, + 0x72, 0x69, 0x61, 0x6c, 0x52, 0x75, 0x6e, 0x6e, 0x65, 0x72, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, + 0x74, 0x61, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x19, 0x0a, 0x08, 0x74, 0x72, 0x69, + 0x61, 0x6c, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x74, 0x72, 0x69, + 0x61, 0x6c, 0x49, 0x64, 0x12, 0x44, 0x0a, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x28, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x2e, 0x76, 0x31, 0x2e, 0x54, 0x72, 0x69, - 0x61, 0x6c, 0x52, 0x05, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x3a, 0x0d, 0x92, 0x41, 0x0a, 0x0a, 0x08, - 0xd2, 0x01, 0x05, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x22, 0x81, 0x01, 0x0a, 0x11, 0x50, 0x61, 0x74, - 0x63, 0x68, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x19, - 0x0a, 0x08, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, - 0x52, 0x07, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x49, 0x64, 0x12, 0x35, 0x0a, 0x05, 0x73, 0x74, 0x61, - 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1a, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, - 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x2e, 0x76, 0x31, 0x2e, 0x53, - 0x74, 0x61, 0x74, 0x65, 0x48, 0x00, 0x52, 0x05, 0x73, 0x74, 0x61, 0x74, 0x65, 0x88, 0x01, 0x01, - 0x3a, 0x10, 0x92, 0x41, 0x0d, 0x0a, 0x0b, 0xd2, 0x01, 0x08, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, - 0x69, 0x64, 0x42, 0x08, 0x0a, 0x06, 0x5f, 0x73, 0x74, 0x61, 0x74, 0x65, 0x22, 0x55, 0x0a, 0x12, - 0x50, 0x61, 0x74, 0x63, 0x68, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, - 0x73, 0x65, 0x12, 0x30, 0x0a, 0x05, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x0b, 0x32, 0x1a, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x74, - 0x72, 0x69, 0x61, 0x6c, 0x2e, 0x76, 0x31, 0x2e, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x52, 0x05, 0x74, - 0x72, 0x69, 0x61, 0x6c, 0x3a, 0x0d, 0x92, 0x41, 0x0a, 0x0a, 0x08, 0xd2, 0x01, 0x05, 0x74, 0x72, - 0x69, 0x61, 0x6c, 0x22, 0x58, 0x0a, 0x11, 0x53, 0x74, 0x61, 0x72, 0x74, 0x54, 0x72, 0x69, 0x61, - 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x19, 0x0a, 0x08, 0x74, 0x72, 0x69, 0x61, - 0x6c, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x74, 0x72, 0x69, 0x61, - 0x6c, 0x49, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x65, 0x73, 0x75, 0x6d, 0x65, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x08, 0x52, 0x06, 0x72, 0x65, 0x73, 0x75, 0x6d, 0x65, 0x3a, 0x10, 0x92, 0x41, 0x0d, - 0x0a, 0x0b, 0xd2, 0x01, 0x08, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x69, 0x64, 0x22, 0xcf, 0x01, - 0x0a, 0x12, 0x53, 0x74, 0x61, 0x72, 0x74, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x20, 0x0a, 0x0c, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x72, 0x75, - 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0a, 0x74, 0x72, 0x69, 0x61, - 0x6c, 0x52, 0x75, 0x6e, 0x49, 0x64, 0x12, 0x30, 0x0a, 0x11, 0x6c, 0x61, 0x74, 0x65, 0x73, 0x74, - 0x5f, 0x63, 0x68, 0x65, 0x63, 0x6b, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, - 0x09, 0x48, 0x00, 0x52, 0x10, 0x6c, 0x61, 0x74, 0x65, 0x73, 0x74, 0x43, 0x68, 0x65, 0x63, 0x6b, - 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x88, 0x01, 0x01, 0x12, 0x27, 0x0a, 0x0f, 0x73, 0x74, 0x65, 0x70, - 0x73, 0x5f, 0x63, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, - 0x05, 0x52, 0x0e, 0x73, 0x74, 0x65, 0x70, 0x73, 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, - 0x64, 0x3a, 0x26, 0x92, 0x41, 0x23, 0x0a, 0x21, 0xd2, 0x01, 0x0c, 0x74, 0x72, 0x69, 0x61, 0x6c, - 0x5f, 0x72, 0x75, 0x6e, 0x5f, 0x69, 0x64, 0xd2, 0x01, 0x0f, 0x73, 0x74, 0x65, 0x70, 0x73, 0x5f, - 0x63, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x64, 0x42, 0x14, 0x0a, 0x12, 0x5f, 0x6c, 0x61, - 0x74, 0x65, 0x73, 0x74, 0x5f, 0x63, 0x68, 0x65, 0x63, 0x6b, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x22, - 0x8b, 0x01, 0x0a, 0x1c, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x53, - 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, - 0x12, 0x50, 0x0a, 0x11, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, - 0x5f, 0x69, 0x6e, 0x66, 0x6f, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x24, 0x2e, 0x64, 0x65, + 0x61, 0x6c, 0x52, 0x75, 0x6e, 0x6e, 0x65, 0x72, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, + 0x52, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x3a, 0x1b, 0x92, 0x41, 0x18, 0x0a, + 0x16, 0xd2, 0x01, 0x08, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x69, 0x64, 0xd2, 0x01, 0x08, 0x6d, + 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x22, 0x21, 0x0a, 0x1f, 0x50, 0x6f, 0x73, 0x74, 0x54, + 0x72, 0x69, 0x61, 0x6c, 0x52, 0x75, 0x6e, 0x6e, 0x65, 0x72, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, + 0x74, 0x61, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x7f, 0x0a, 0x11, 0x47, 0x65, + 0x74, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, + 0x2c, 0x0a, 0x09, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x01, 0x20, 0x03, + 0x28, 0x05, 0x42, 0x0f, 0x92, 0x41, 0x0c, 0xd2, 0x01, 0x09, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, + 0x69, 0x64, 0x73, 0x52, 0x08, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x49, 0x64, 0x73, 0x12, 0x21, 0x0a, + 0x05, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x42, 0x0b, 0x92, 0x41, + 0x08, 0xd2, 0x01, 0x05, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x52, 0x05, 0x67, 0x72, 0x6f, 0x75, 0x70, + 0x3a, 0x19, 0x92, 0x41, 0x16, 0x0a, 0x14, 0xd2, 0x01, 0x09, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, + 0x69, 0x64, 0x73, 0xd2, 0x01, 0x05, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x22, 0x63, 0x0a, 0x12, 0x47, + 0x65, 0x74, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x12, 0x3c, 0x0a, 0x07, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x18, 0x01, 0x20, 0x03, + 0x28, 0x0b, 0x32, 0x22, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, + 0x74, 0x72, 0x69, 0x61, 0x6c, 0x2e, 0x76, 0x31, 0x2e, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, + 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x52, 0x07, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x3a, + 0x0f, 0x92, 0x41, 0x0c, 0x0a, 0x0a, 0xd2, 0x01, 0x07, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, + 0x22, 0x4b, 0x0a, 0x19, 0x47, 0x65, 0x74, 0x54, 0x72, 0x61, 0x69, 0x6e, 0x69, 0x6e, 0x67, 0x4d, + 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1b, 0x0a, + 0x09, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x05, + 0x52, 0x08, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x49, 0x64, 0x73, 0x3a, 0x11, 0x92, 0x41, 0x0e, 0x0a, + 0x0c, 0xd2, 0x01, 0x09, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x69, 0x64, 0x73, 0x22, 0x6b, 0x0a, + 0x1a, 0x47, 0x65, 0x74, 0x54, 0x72, 0x61, 0x69, 0x6e, 0x69, 0x6e, 0x67, 0x4d, 0x65, 0x74, 0x72, + 0x69, 0x63, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x3c, 0x0a, 0x07, 0x6d, + 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x22, 0x2e, 0x64, + 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x2e, + 0x76, 0x31, 0x2e, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, + 0x52, 0x07, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x3a, 0x0f, 0x92, 0x41, 0x0c, 0x0a, 0x0a, + 0xd2, 0x01, 0x07, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x22, 0x4d, 0x0a, 0x1b, 0x47, 0x65, + 0x74, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x72, 0x69, + 0x63, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1b, 0x0a, 0x09, 0x74, 0x72, 0x69, + 0x61, 0x6c, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x05, 0x52, 0x08, 0x74, 0x72, + 0x69, 0x61, 0x6c, 0x49, 0x64, 0x73, 0x3a, 0x11, 0x92, 0x41, 0x0e, 0x0a, 0x0c, 0xd2, 0x01, 0x09, + 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x69, 0x64, 0x73, 0x22, 0x6d, 0x0a, 0x1c, 0x47, 0x65, 0x74, + 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, + 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x3c, 0x0a, 0x07, 0x6d, 0x65, 0x74, + 0x72, 0x69, 0x63, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x22, 0x2e, 0x64, 0x65, 0x74, + 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x2e, 0x76, 0x31, + 0x2e, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x52, 0x07, + 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x3a, 0x0f, 0x92, 0x41, 0x0c, 0x0a, 0x0a, 0xd2, 0x01, + 0x07, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x22, 0x8a, 0x01, 0x0a, 0x12, 0x43, 0x72, 0x65, + 0x61, 0x74, 0x65, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, + 0x23, 0x0a, 0x0d, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0c, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, + 0x6e, 0x74, 0x49, 0x64, 0x12, 0x31, 0x0a, 0x07, 0x68, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x18, + 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x53, 0x74, 0x72, 0x75, 0x63, 0x74, 0x52, 0x07, + 0x68, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x75, 0x6e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x64, 0x18, 0x28, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x75, 0x6e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x64, 0x22, 0x56, 0x0a, 0x13, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x54, + 0x72, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x30, 0x0a, 0x05, + 0x74, 0x72, 0x69, 0x61, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x2e, 0x76, - 0x31, 0x2e, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x53, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x6e, 0x66, - 0x6f, 0x52, 0x0f, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x53, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x6e, - 0x66, 0x6f, 0x3a, 0x19, 0x92, 0x41, 0x16, 0x0a, 0x14, 0xd2, 0x01, 0x11, 0x74, 0x72, 0x69, 0x61, - 0x6c, 0x5f, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69, 0x6e, 0x66, 0x6f, 0x22, 0x87, 0x01, - 0x0a, 0x1d, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x53, 0x6f, 0x75, - 0x72, 0x63, 0x65, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, - 0x19, 0x0a, 0x08, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x05, 0x52, 0x07, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x49, 0x64, 0x12, 0x27, 0x0a, 0x0f, 0x63, 0x68, - 0x65, 0x63, 0x6b, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x5f, 0x75, 0x75, 0x69, 0x64, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x0e, 0x63, 0x68, 0x65, 0x63, 0x6b, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x55, - 0x75, 0x69, 0x64, 0x3a, 0x22, 0x92, 0x41, 0x1f, 0x0a, 0x1d, 0xd2, 0x01, 0x08, 0x74, 0x72, 0x69, - 0x61, 0x6c, 0x5f, 0x69, 0x64, 0xd2, 0x01, 0x0f, 0x63, 0x68, 0x65, 0x63, 0x6b, 0x70, 0x6f, 0x69, - 0x6e, 0x74, 0x5f, 0x75, 0x75, 0x69, 0x64, 0x42, 0x35, 0x5a, 0x33, 0x67, 0x69, 0x74, 0x68, 0x75, - 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, - 0x2d, 0x61, 0x69, 0x2f, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2f, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x61, 0x70, 0x69, 0x76, 0x31, 0x62, 0x06, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x31, 0x2e, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x52, 0x05, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x3a, 0x0d, + 0x92, 0x41, 0x0a, 0x0a, 0x08, 0xd2, 0x01, 0x05, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x22, 0x96, 0x01, + 0x0a, 0x0f, 0x50, 0x75, 0x74, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x12, 0x57, 0x0a, 0x14, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x5f, 0x74, 0x72, 0x69, 0x61, + 0x6c, 0x5f, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x25, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x61, 0x70, 0x69, + 0x2e, 0x76, 0x31, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x52, 0x12, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x54, 0x72, + 0x69, 0x61, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x2a, 0x0a, 0x11, 0x65, 0x78, + 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x5f, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x69, 0x64, 0x18, + 0x29, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x65, 0x78, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x54, + 0x72, 0x69, 0x61, 0x6c, 0x49, 0x64, 0x22, 0x53, 0x0a, 0x10, 0x50, 0x75, 0x74, 0x54, 0x72, 0x69, + 0x61, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x30, 0x0a, 0x05, 0x74, 0x72, + 0x69, 0x61, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x64, 0x65, 0x74, 0x65, + 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x2e, 0x76, 0x31, 0x2e, + 0x54, 0x72, 0x69, 0x61, 0x6c, 0x52, 0x05, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x3a, 0x0d, 0x92, 0x41, + 0x0a, 0x0a, 0x08, 0xd2, 0x01, 0x05, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x22, 0x81, 0x01, 0x0a, 0x11, + 0x50, 0x61, 0x74, 0x63, 0x68, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x12, 0x19, 0x0a, 0x08, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x05, 0x52, 0x07, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x49, 0x64, 0x12, 0x35, 0x0a, 0x05, + 0x73, 0x74, 0x61, 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1a, 0x2e, 0x64, 0x65, + 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x2e, 0x76, + 0x31, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x48, 0x00, 0x52, 0x05, 0x73, 0x74, 0x61, 0x74, 0x65, + 0x88, 0x01, 0x01, 0x3a, 0x10, 0x92, 0x41, 0x0d, 0x0a, 0x0b, 0xd2, 0x01, 0x08, 0x74, 0x72, 0x69, + 0x61, 0x6c, 0x5f, 0x69, 0x64, 0x42, 0x08, 0x0a, 0x06, 0x5f, 0x73, 0x74, 0x61, 0x74, 0x65, 0x22, + 0x55, 0x0a, 0x12, 0x50, 0x61, 0x74, 0x63, 0x68, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x30, 0x0a, 0x05, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, + 0x64, 0x2e, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x2e, 0x76, 0x31, 0x2e, 0x54, 0x72, 0x69, 0x61, 0x6c, + 0x52, 0x05, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x3a, 0x0d, 0x92, 0x41, 0x0a, 0x0a, 0x08, 0xd2, 0x01, + 0x05, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x22, 0x58, 0x0a, 0x11, 0x53, 0x74, 0x61, 0x72, 0x74, 0x54, + 0x72, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x19, 0x0a, 0x08, 0x74, + 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x74, + 0x72, 0x69, 0x61, 0x6c, 0x49, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x65, 0x73, 0x75, 0x6d, 0x65, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x06, 0x72, 0x65, 0x73, 0x75, 0x6d, 0x65, 0x3a, 0x10, + 0x92, 0x41, 0x0d, 0x0a, 0x0b, 0xd2, 0x01, 0x08, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x69, 0x64, + 0x22, 0xcf, 0x01, 0x0a, 0x12, 0x53, 0x74, 0x61, 0x72, 0x74, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x20, 0x0a, 0x0c, 0x74, 0x72, 0x69, 0x61, 0x6c, + 0x5f, 0x72, 0x75, 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0a, 0x74, + 0x72, 0x69, 0x61, 0x6c, 0x52, 0x75, 0x6e, 0x49, 0x64, 0x12, 0x30, 0x0a, 0x11, 0x6c, 0x61, 0x74, + 0x65, 0x73, 0x74, 0x5f, 0x63, 0x68, 0x65, 0x63, 0x6b, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x10, 0x6c, 0x61, 0x74, 0x65, 0x73, 0x74, 0x43, 0x68, + 0x65, 0x63, 0x6b, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x88, 0x01, 0x01, 0x12, 0x27, 0x0a, 0x0f, 0x73, + 0x74, 0x65, 0x70, 0x73, 0x5f, 0x63, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x64, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x05, 0x52, 0x0e, 0x73, 0x74, 0x65, 0x70, 0x73, 0x43, 0x6f, 0x6d, 0x70, 0x6c, + 0x65, 0x74, 0x65, 0x64, 0x3a, 0x26, 0x92, 0x41, 0x23, 0x0a, 0x21, 0xd2, 0x01, 0x0c, 0x74, 0x72, + 0x69, 0x61, 0x6c, 0x5f, 0x72, 0x75, 0x6e, 0x5f, 0x69, 0x64, 0xd2, 0x01, 0x0f, 0x73, 0x74, 0x65, + 0x70, 0x73, 0x5f, 0x63, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x64, 0x42, 0x14, 0x0a, 0x12, + 0x5f, 0x6c, 0x61, 0x74, 0x65, 0x73, 0x74, 0x5f, 0x63, 0x68, 0x65, 0x63, 0x6b, 0x70, 0x6f, 0x69, + 0x6e, 0x74, 0x22, 0x8b, 0x01, 0x0a, 0x1c, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x54, 0x72, 0x69, + 0x61, 0x6c, 0x53, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x12, 0x50, 0x0a, 0x11, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x73, 0x6f, 0x75, + 0x72, 0x63, 0x65, 0x5f, 0x69, 0x6e, 0x66, 0x6f, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x24, + 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x74, 0x72, 0x69, 0x61, + 0x6c, 0x2e, 0x76, 0x31, 0x2e, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x53, 0x6f, 0x75, 0x72, 0x63, 0x65, + 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0f, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x53, 0x6f, 0x75, 0x72, 0x63, + 0x65, 0x49, 0x6e, 0x66, 0x6f, 0x3a, 0x19, 0x92, 0x41, 0x16, 0x0a, 0x14, 0xd2, 0x01, 0x11, 0x74, + 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69, 0x6e, 0x66, 0x6f, + 0x22, 0x87, 0x01, 0x0a, 0x1d, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x54, 0x72, 0x69, 0x61, 0x6c, + 0x53, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x12, 0x19, 0x0a, 0x08, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x69, 0x64, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x49, 0x64, 0x12, 0x27, 0x0a, + 0x0f, 0x63, 0x68, 0x65, 0x63, 0x6b, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x5f, 0x75, 0x75, 0x69, 0x64, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x63, 0x68, 0x65, 0x63, 0x6b, 0x70, 0x6f, 0x69, + 0x6e, 0x74, 0x55, 0x75, 0x69, 0x64, 0x3a, 0x22, 0x92, 0x41, 0x1f, 0x0a, 0x1d, 0xd2, 0x01, 0x08, + 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x69, 0x64, 0xd2, 0x01, 0x0f, 0x63, 0x68, 0x65, 0x63, 0x6b, + 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x5f, 0x75, 0x75, 0x69, 0x64, 0x42, 0x35, 0x5a, 0x33, 0x67, 0x69, + 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, + 0x6e, 0x65, 0x64, 0x2d, 0x61, 0x69, 0x2f, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, + 0x64, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x61, 0x70, 0x69, 0x76, + 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -6113,193 +5878,185 @@ func file_determined_api_v1_trial_proto_rawDescGZIP() []byte { } var file_determined_api_v1_trial_proto_enumTypes = make([]protoimpl.EnumInfo, 3) -var file_determined_api_v1_trial_proto_msgTypes = make([]protoimpl.MessageInfo, 85) +var file_determined_api_v1_trial_proto_msgTypes = make([]protoimpl.MessageInfo, 81) var file_determined_api_v1_trial_proto_goTypes = []interface{}{ - (TrialSorter_Namespace)(0), // 0: determined.api.v1.TrialSorter.Namespace - (GetExperimentTrialsRequest_SortBy)(0), // 1: determined.api.v1.GetExperimentTrialsRequest.SortBy - (GetTrialWorkloadsRequest_FilterOption)(0), // 2: determined.api.v1.GetTrialWorkloadsRequest.FilterOption - (*DownsampledMetrics)(nil), // 3: determined.api.v1.DownsampledMetrics - (*WorkloadContainer)(nil), // 4: determined.api.v1.WorkloadContainer - (*ColumnFilter)(nil), // 5: determined.api.v1.ColumnFilter - (*TrialTag)(nil), // 6: determined.api.v1.TrialTag - (*TrialSorter)(nil), // 7: determined.api.v1.TrialSorter - (*TrialLogsRequest)(nil), // 8: determined.api.v1.TrialLogsRequest - (*TrialLogsResponse)(nil), // 9: determined.api.v1.TrialLogsResponse - (*TrialLogsFieldsRequest)(nil), // 10: determined.api.v1.TrialLogsFieldsRequest - (*TrialLogsFieldsResponse)(nil), // 11: determined.api.v1.TrialLogsFieldsResponse - (*GetTrialCheckpointsRequest)(nil), // 12: determined.api.v1.GetTrialCheckpointsRequest - (*GetTrialCheckpointsResponse)(nil), // 13: determined.api.v1.GetTrialCheckpointsResponse - (*KillTrialRequest)(nil), // 14: determined.api.v1.KillTrialRequest - (*KillTrialResponse)(nil), // 15: determined.api.v1.KillTrialResponse - (*GetExperimentTrialsRequest)(nil), // 16: determined.api.v1.GetExperimentTrialsRequest - (*GetExperimentTrialsResponse)(nil), // 17: determined.api.v1.GetExperimentTrialsResponse - (*GetTrialRemainingLogRetentionDaysRequest)(nil), // 18: determined.api.v1.GetTrialRemainingLogRetentionDaysRequest - (*GetTrialRemainingLogRetentionDaysResponse)(nil), // 19: determined.api.v1.GetTrialRemainingLogRetentionDaysResponse - (*GetTrialRequest)(nil), // 20: determined.api.v1.GetTrialRequest - (*GetTrialResponse)(nil), // 21: determined.api.v1.GetTrialResponse - (*GetTrialByExternalIDRequest)(nil), // 22: determined.api.v1.GetTrialByExternalIDRequest - (*GetTrialByExternalIDResponse)(nil), // 23: determined.api.v1.GetTrialByExternalIDResponse - (*GetTrialWorkloadsRequest)(nil), // 24: determined.api.v1.GetTrialWorkloadsRequest - (*GetTrialWorkloadsResponse)(nil), // 25: determined.api.v1.GetTrialWorkloadsResponse - (*GetTrialProfilerMetricsRequest)(nil), // 26: determined.api.v1.GetTrialProfilerMetricsRequest - (*GetTrialProfilerMetricsResponse)(nil), // 27: determined.api.v1.GetTrialProfilerMetricsResponse - (*GetTrialProfilerAvailableSeriesRequest)(nil), // 28: determined.api.v1.GetTrialProfilerAvailableSeriesRequest - (*GetTrialProfilerAvailableSeriesResponse)(nil), // 29: determined.api.v1.GetTrialProfilerAvailableSeriesResponse - (*PostTrialProfilerMetricsBatchRequest)(nil), // 30: determined.api.v1.PostTrialProfilerMetricsBatchRequest - (*PostTrialProfilerMetricsBatchResponse)(nil), // 31: determined.api.v1.PostTrialProfilerMetricsBatchResponse - (*ComparableTrial)(nil), // 32: determined.api.v1.ComparableTrial - (*CompareTrialsRequest)(nil), // 33: determined.api.v1.CompareTrialsRequest - (*PutTrialRetainLogsRequest)(nil), // 34: determined.api.v1.PutTrialRetainLogsRequest - (*PutTrialRetainLogsResponse)(nil), // 35: determined.api.v1.PutTrialRetainLogsResponse - (*CompareTrialsResponse)(nil), // 36: determined.api.v1.CompareTrialsResponse - (*AllocationPreemptionSignalRequest)(nil), // 37: determined.api.v1.AllocationPreemptionSignalRequest - (*AllocationPreemptionSignalResponse)(nil), // 38: determined.api.v1.AllocationPreemptionSignalResponse - (*AckAllocationPreemptionSignalRequest)(nil), // 39: determined.api.v1.AckAllocationPreemptionSignalRequest - (*AllocationPendingPreemptionSignalRequest)(nil), // 40: determined.api.v1.AllocationPendingPreemptionSignalRequest - (*AllocationPendingPreemptionSignalResponse)(nil), // 41: determined.api.v1.AllocationPendingPreemptionSignalResponse - (*AckAllocationPreemptionSignalResponse)(nil), // 42: determined.api.v1.AckAllocationPreemptionSignalResponse - (*MarkAllocationResourcesDaemonRequest)(nil), // 43: determined.api.v1.MarkAllocationResourcesDaemonRequest - (*MarkAllocationResourcesDaemonResponse)(nil), // 44: determined.api.v1.MarkAllocationResourcesDaemonResponse - (*AllocationRendezvousInfoRequest)(nil), // 45: determined.api.v1.AllocationRendezvousInfoRequest - (*AllocationRendezvousInfoResponse)(nil), // 46: determined.api.v1.AllocationRendezvousInfoResponse - (*PostAllocationProxyAddressRequest)(nil), // 47: determined.api.v1.PostAllocationProxyAddressRequest - (*PostAllocationProxyAddressResponse)(nil), // 48: determined.api.v1.PostAllocationProxyAddressResponse - (*PostAllocationAcceleratorDataRequest)(nil), // 49: determined.api.v1.PostAllocationAcceleratorDataRequest - (*PostAllocationAcceleratorDataResponse)(nil), // 50: determined.api.v1.PostAllocationAcceleratorDataResponse - (*AcceleratorData)(nil), // 51: determined.api.v1.AcceleratorData - (*AllocationAllGatherRequest)(nil), // 52: determined.api.v1.AllocationAllGatherRequest - (*AllocationAllGatherResponse)(nil), // 53: determined.api.v1.AllocationAllGatherResponse - (*NotifyContainerRunningRequest)(nil), // 54: determined.api.v1.NotifyContainerRunningRequest - (*NotifyContainerRunningResponse)(nil), // 55: determined.api.v1.NotifyContainerRunningResponse - (*GetCurrentTrialSearcherOperationRequest)(nil), // 56: determined.api.v1.GetCurrentTrialSearcherOperationRequest - (*GetCurrentTrialSearcherOperationResponse)(nil), // 57: determined.api.v1.GetCurrentTrialSearcherOperationResponse - (*CompleteTrialSearcherValidationRequest)(nil), // 58: determined.api.v1.CompleteTrialSearcherValidationRequest - (*CompleteTrialSearcherValidationResponse)(nil), // 59: determined.api.v1.CompleteTrialSearcherValidationResponse - (*ReportTrialSearcherEarlyExitRequest)(nil), // 60: determined.api.v1.ReportTrialSearcherEarlyExitRequest - (*ReportTrialSearcherEarlyExitResponse)(nil), // 61: determined.api.v1.ReportTrialSearcherEarlyExitResponse - (*ReportTrialProgressRequest)(nil), // 62: determined.api.v1.ReportTrialProgressRequest - (*ReportTrialProgressResponse)(nil), // 63: determined.api.v1.ReportTrialProgressResponse - (*ReportTrialMetricsRequest)(nil), // 64: determined.api.v1.ReportTrialMetricsRequest - (*ReportTrialMetricsResponse)(nil), // 65: determined.api.v1.ReportTrialMetricsResponse - (*ReportTrialTrainingMetricsRequest)(nil), // 66: determined.api.v1.ReportTrialTrainingMetricsRequest - (*ReportTrialTrainingMetricsResponse)(nil), // 67: determined.api.v1.ReportTrialTrainingMetricsResponse - (*ReportTrialValidationMetricsRequest)(nil), // 68: determined.api.v1.ReportTrialValidationMetricsRequest - (*ReportTrialValidationMetricsResponse)(nil), // 69: determined.api.v1.ReportTrialValidationMetricsResponse - (*PostTrialRunnerMetadataRequest)(nil), // 70: determined.api.v1.PostTrialRunnerMetadataRequest - (*PostTrialRunnerMetadataResponse)(nil), // 71: determined.api.v1.PostTrialRunnerMetadataResponse - (*GetMetricsRequest)(nil), // 72: determined.api.v1.GetMetricsRequest - (*GetMetricsResponse)(nil), // 73: determined.api.v1.GetMetricsResponse - (*GetTrainingMetricsRequest)(nil), // 74: determined.api.v1.GetTrainingMetricsRequest - (*GetTrainingMetricsResponse)(nil), // 75: determined.api.v1.GetTrainingMetricsResponse - (*GetValidationMetricsRequest)(nil), // 76: determined.api.v1.GetValidationMetricsRequest - (*GetValidationMetricsResponse)(nil), // 77: determined.api.v1.GetValidationMetricsResponse - (*CreateTrialRequest)(nil), // 78: determined.api.v1.CreateTrialRequest - (*CreateTrialResponse)(nil), // 79: determined.api.v1.CreateTrialResponse - (*PutTrialRequest)(nil), // 80: determined.api.v1.PutTrialRequest - (*PutTrialResponse)(nil), // 81: determined.api.v1.PutTrialResponse - (*PatchTrialRequest)(nil), // 82: determined.api.v1.PatchTrialRequest - (*PatchTrialResponse)(nil), // 83: determined.api.v1.PatchTrialResponse - (*StartTrialRequest)(nil), // 84: determined.api.v1.StartTrialRequest - (*StartTrialResponse)(nil), // 85: determined.api.v1.StartTrialResponse - (*ReportTrialSourceInfoRequest)(nil), // 86: determined.api.v1.ReportTrialSourceInfoRequest - (*ReportTrialSourceInfoResponse)(nil), // 87: determined.api.v1.ReportTrialSourceInfoResponse - (*DataPoint)(nil), // 88: determined.api.v1.DataPoint - (MetricType)(0), // 89: determined.api.v1.MetricType - (*trialv1.MetricsWorkload)(nil), // 90: determined.trial.v1.MetricsWorkload - (*trialv1.CheckpointWorkload)(nil), // 91: determined.trial.v1.CheckpointWorkload - (*commonv1.DoubleFieldFilter)(nil), // 92: determined.common.v1.DoubleFieldFilter - (OrderBy)(0), // 93: determined.api.v1.OrderBy - (logv1.LogLevel)(0), // 94: determined.log.v1.LogLevel - (*timestamp.Timestamp)(nil), // 95: google.protobuf.Timestamp - (checkpointv1.SortBy)(0), // 96: determined.checkpoint.v1.SortBy - (checkpointv1.State)(0), // 97: determined.checkpoint.v1.State - (*checkpointv1.Checkpoint)(nil), // 98: determined.checkpoint.v1.Checkpoint - (*Pagination)(nil), // 99: determined.api.v1.Pagination - (experimentv1.State)(0), // 100: determined.experiment.v1.State - (*trialv1.Trial)(nil), // 101: determined.trial.v1.Trial - (*trialv1.TrialProfilerMetricLabels)(nil), // 102: determined.trial.v1.TrialProfilerMetricLabels - (*trialv1.TrialProfilerMetricsBatch)(nil), // 103: determined.trial.v1.TrialProfilerMetricsBatch - (*commonv1.PolymorphicFilter)(nil), // 104: determined.common.v1.PolymorphicFilter - (*trialv1.RendezvousInfo)(nil), // 105: determined.trial.v1.RendezvousInfo - (*_struct.Struct)(nil), // 106: google.protobuf.Struct - (*experimentv1.TrialOperation)(nil), // 107: determined.experiment.v1.TrialOperation - (*experimentv1.CompleteValidateAfterOperation)(nil), // 108: determined.experiment.v1.CompleteValidateAfterOperation - (*trialv1.TrialEarlyExit)(nil), // 109: determined.trial.v1.TrialEarlyExit - (*trialv1.TrialMetrics)(nil), // 110: determined.trial.v1.TrialMetrics - (*trialv1.TrialRunnerMetadata)(nil), // 111: determined.trial.v1.TrialRunnerMetadata - (*trialv1.MetricsReport)(nil), // 112: determined.trial.v1.MetricsReport - (trialv1.State)(0), // 113: determined.trial.v1.State - (*trialv1.TrialSourceInfo)(nil), // 114: determined.trial.v1.TrialSourceInfo + (TrialSorter_Namespace)(0), // 0: determined.api.v1.TrialSorter.Namespace + (GetExperimentTrialsRequest_SortBy)(0), // 1: determined.api.v1.GetExperimentTrialsRequest.SortBy + (GetTrialWorkloadsRequest_FilterOption)(0), // 2: determined.api.v1.GetTrialWorkloadsRequest.FilterOption + (*DownsampledMetrics)(nil), // 3: determined.api.v1.DownsampledMetrics + (*WorkloadContainer)(nil), // 4: determined.api.v1.WorkloadContainer + (*ColumnFilter)(nil), // 5: determined.api.v1.ColumnFilter + (*TrialTag)(nil), // 6: determined.api.v1.TrialTag + (*TrialSorter)(nil), // 7: determined.api.v1.TrialSorter + (*TrialLogsRequest)(nil), // 8: determined.api.v1.TrialLogsRequest + (*TrialLogsResponse)(nil), // 9: determined.api.v1.TrialLogsResponse + (*TrialLogsFieldsRequest)(nil), // 10: determined.api.v1.TrialLogsFieldsRequest + (*TrialLogsFieldsResponse)(nil), // 11: determined.api.v1.TrialLogsFieldsResponse + (*GetTrialCheckpointsRequest)(nil), // 12: determined.api.v1.GetTrialCheckpointsRequest + (*GetTrialCheckpointsResponse)(nil), // 13: determined.api.v1.GetTrialCheckpointsResponse + (*KillTrialRequest)(nil), // 14: determined.api.v1.KillTrialRequest + (*KillTrialResponse)(nil), // 15: determined.api.v1.KillTrialResponse + (*GetExperimentTrialsRequest)(nil), // 16: determined.api.v1.GetExperimentTrialsRequest + (*GetExperimentTrialsResponse)(nil), // 17: determined.api.v1.GetExperimentTrialsResponse + (*GetTrialRemainingLogRetentionDaysRequest)(nil), // 18: determined.api.v1.GetTrialRemainingLogRetentionDaysRequest + (*GetTrialRemainingLogRetentionDaysResponse)(nil), // 19: determined.api.v1.GetTrialRemainingLogRetentionDaysResponse + (*GetTrialRequest)(nil), // 20: determined.api.v1.GetTrialRequest + (*GetTrialResponse)(nil), // 21: determined.api.v1.GetTrialResponse + (*GetTrialByExternalIDRequest)(nil), // 22: determined.api.v1.GetTrialByExternalIDRequest + (*GetTrialByExternalIDResponse)(nil), // 23: determined.api.v1.GetTrialByExternalIDResponse + (*GetTrialWorkloadsRequest)(nil), // 24: determined.api.v1.GetTrialWorkloadsRequest + (*GetTrialWorkloadsResponse)(nil), // 25: determined.api.v1.GetTrialWorkloadsResponse + (*GetTrialProfilerMetricsRequest)(nil), // 26: determined.api.v1.GetTrialProfilerMetricsRequest + (*GetTrialProfilerMetricsResponse)(nil), // 27: determined.api.v1.GetTrialProfilerMetricsResponse + (*GetTrialProfilerAvailableSeriesRequest)(nil), // 28: determined.api.v1.GetTrialProfilerAvailableSeriesRequest + (*GetTrialProfilerAvailableSeriesResponse)(nil), // 29: determined.api.v1.GetTrialProfilerAvailableSeriesResponse + (*PostTrialProfilerMetricsBatchRequest)(nil), // 30: determined.api.v1.PostTrialProfilerMetricsBatchRequest + (*PostTrialProfilerMetricsBatchResponse)(nil), // 31: determined.api.v1.PostTrialProfilerMetricsBatchResponse + (*ComparableTrial)(nil), // 32: determined.api.v1.ComparableTrial + (*CompareTrialsRequest)(nil), // 33: determined.api.v1.CompareTrialsRequest + (*PutTrialRetainLogsRequest)(nil), // 34: determined.api.v1.PutTrialRetainLogsRequest + (*PutTrialRetainLogsResponse)(nil), // 35: determined.api.v1.PutTrialRetainLogsResponse + (*CompareTrialsResponse)(nil), // 36: determined.api.v1.CompareTrialsResponse + (*AllocationPreemptionSignalRequest)(nil), // 37: determined.api.v1.AllocationPreemptionSignalRequest + (*AllocationPreemptionSignalResponse)(nil), // 38: determined.api.v1.AllocationPreemptionSignalResponse + (*AckAllocationPreemptionSignalRequest)(nil), // 39: determined.api.v1.AckAllocationPreemptionSignalRequest + (*AllocationPendingPreemptionSignalRequest)(nil), // 40: determined.api.v1.AllocationPendingPreemptionSignalRequest + (*AllocationPendingPreemptionSignalResponse)(nil), // 41: determined.api.v1.AllocationPendingPreemptionSignalResponse + (*AckAllocationPreemptionSignalResponse)(nil), // 42: determined.api.v1.AckAllocationPreemptionSignalResponse + (*MarkAllocationResourcesDaemonRequest)(nil), // 43: determined.api.v1.MarkAllocationResourcesDaemonRequest + (*MarkAllocationResourcesDaemonResponse)(nil), // 44: determined.api.v1.MarkAllocationResourcesDaemonResponse + (*AllocationRendezvousInfoRequest)(nil), // 45: determined.api.v1.AllocationRendezvousInfoRequest + (*AllocationRendezvousInfoResponse)(nil), // 46: determined.api.v1.AllocationRendezvousInfoResponse + (*PostAllocationProxyAddressRequest)(nil), // 47: determined.api.v1.PostAllocationProxyAddressRequest + (*PostAllocationProxyAddressResponse)(nil), // 48: determined.api.v1.PostAllocationProxyAddressResponse + (*PostAllocationAcceleratorDataRequest)(nil), // 49: determined.api.v1.PostAllocationAcceleratorDataRequest + (*PostAllocationAcceleratorDataResponse)(nil), // 50: determined.api.v1.PostAllocationAcceleratorDataResponse + (*AcceleratorData)(nil), // 51: determined.api.v1.AcceleratorData + (*AllocationAllGatherRequest)(nil), // 52: determined.api.v1.AllocationAllGatherRequest + (*AllocationAllGatherResponse)(nil), // 53: determined.api.v1.AllocationAllGatherResponse + (*NotifyContainerRunningRequest)(nil), // 54: determined.api.v1.NotifyContainerRunningRequest + (*NotifyContainerRunningResponse)(nil), // 55: determined.api.v1.NotifyContainerRunningResponse + (*ReportTrialSearcherEarlyExitRequest)(nil), // 56: determined.api.v1.ReportTrialSearcherEarlyExitRequest + (*ReportTrialSearcherEarlyExitResponse)(nil), // 57: determined.api.v1.ReportTrialSearcherEarlyExitResponse + (*ReportTrialProgressRequest)(nil), // 58: determined.api.v1.ReportTrialProgressRequest + (*ReportTrialProgressResponse)(nil), // 59: determined.api.v1.ReportTrialProgressResponse + (*ReportTrialMetricsRequest)(nil), // 60: determined.api.v1.ReportTrialMetricsRequest + (*ReportTrialMetricsResponse)(nil), // 61: determined.api.v1.ReportTrialMetricsResponse + (*ReportTrialTrainingMetricsRequest)(nil), // 62: determined.api.v1.ReportTrialTrainingMetricsRequest + (*ReportTrialTrainingMetricsResponse)(nil), // 63: determined.api.v1.ReportTrialTrainingMetricsResponse + (*ReportTrialValidationMetricsRequest)(nil), // 64: determined.api.v1.ReportTrialValidationMetricsRequest + (*ReportTrialValidationMetricsResponse)(nil), // 65: determined.api.v1.ReportTrialValidationMetricsResponse + (*PostTrialRunnerMetadataRequest)(nil), // 66: determined.api.v1.PostTrialRunnerMetadataRequest + (*PostTrialRunnerMetadataResponse)(nil), // 67: determined.api.v1.PostTrialRunnerMetadataResponse + (*GetMetricsRequest)(nil), // 68: determined.api.v1.GetMetricsRequest + (*GetMetricsResponse)(nil), // 69: determined.api.v1.GetMetricsResponse + (*GetTrainingMetricsRequest)(nil), // 70: determined.api.v1.GetTrainingMetricsRequest + (*GetTrainingMetricsResponse)(nil), // 71: determined.api.v1.GetTrainingMetricsResponse + (*GetValidationMetricsRequest)(nil), // 72: determined.api.v1.GetValidationMetricsRequest + (*GetValidationMetricsResponse)(nil), // 73: determined.api.v1.GetValidationMetricsResponse + (*CreateTrialRequest)(nil), // 74: determined.api.v1.CreateTrialRequest + (*CreateTrialResponse)(nil), // 75: determined.api.v1.CreateTrialResponse + (*PutTrialRequest)(nil), // 76: determined.api.v1.PutTrialRequest + (*PutTrialResponse)(nil), // 77: determined.api.v1.PutTrialResponse + (*PatchTrialRequest)(nil), // 78: determined.api.v1.PatchTrialRequest + (*PatchTrialResponse)(nil), // 79: determined.api.v1.PatchTrialResponse + (*StartTrialRequest)(nil), // 80: determined.api.v1.StartTrialRequest + (*StartTrialResponse)(nil), // 81: determined.api.v1.StartTrialResponse + (*ReportTrialSourceInfoRequest)(nil), // 82: determined.api.v1.ReportTrialSourceInfoRequest + (*ReportTrialSourceInfoResponse)(nil), // 83: determined.api.v1.ReportTrialSourceInfoResponse + (*DataPoint)(nil), // 84: determined.api.v1.DataPoint + (MetricType)(0), // 85: determined.api.v1.MetricType + (*trialv1.MetricsWorkload)(nil), // 86: determined.trial.v1.MetricsWorkload + (*trialv1.CheckpointWorkload)(nil), // 87: determined.trial.v1.CheckpointWorkload + (*commonv1.DoubleFieldFilter)(nil), // 88: determined.common.v1.DoubleFieldFilter + (OrderBy)(0), // 89: determined.api.v1.OrderBy + (logv1.LogLevel)(0), // 90: determined.log.v1.LogLevel + (*timestamp.Timestamp)(nil), // 91: google.protobuf.Timestamp + (checkpointv1.SortBy)(0), // 92: determined.checkpoint.v1.SortBy + (checkpointv1.State)(0), // 93: determined.checkpoint.v1.State + (*checkpointv1.Checkpoint)(nil), // 94: determined.checkpoint.v1.Checkpoint + (*Pagination)(nil), // 95: determined.api.v1.Pagination + (experimentv1.State)(0), // 96: determined.experiment.v1.State + (*trialv1.Trial)(nil), // 97: determined.trial.v1.Trial + (*trialv1.TrialProfilerMetricLabels)(nil), // 98: determined.trial.v1.TrialProfilerMetricLabels + (*trialv1.TrialProfilerMetricsBatch)(nil), // 99: determined.trial.v1.TrialProfilerMetricsBatch + (*commonv1.PolymorphicFilter)(nil), // 100: determined.common.v1.PolymorphicFilter + (*trialv1.RendezvousInfo)(nil), // 101: determined.trial.v1.RendezvousInfo + (*_struct.Struct)(nil), // 102: google.protobuf.Struct + (*trialv1.TrialEarlyExit)(nil), // 103: determined.trial.v1.TrialEarlyExit + (*trialv1.TrialMetrics)(nil), // 104: determined.trial.v1.TrialMetrics + (*trialv1.TrialRunnerMetadata)(nil), // 105: determined.trial.v1.TrialRunnerMetadata + (*trialv1.MetricsReport)(nil), // 106: determined.trial.v1.MetricsReport + (trialv1.State)(0), // 107: determined.trial.v1.State + (*trialv1.TrialSourceInfo)(nil), // 108: determined.trial.v1.TrialSourceInfo } var file_determined_api_v1_trial_proto_depIdxs = []int32{ - 88, // 0: determined.api.v1.DownsampledMetrics.data:type_name -> determined.api.v1.DataPoint - 89, // 1: determined.api.v1.DownsampledMetrics.type:type_name -> determined.api.v1.MetricType - 90, // 2: determined.api.v1.WorkloadContainer.training:type_name -> determined.trial.v1.MetricsWorkload - 90, // 3: determined.api.v1.WorkloadContainer.validation:type_name -> determined.trial.v1.MetricsWorkload - 91, // 4: determined.api.v1.WorkloadContainer.checkpoint:type_name -> determined.trial.v1.CheckpointWorkload - 92, // 5: determined.api.v1.ColumnFilter.filter:type_name -> determined.common.v1.DoubleFieldFilter + 84, // 0: determined.api.v1.DownsampledMetrics.data:type_name -> determined.api.v1.DataPoint + 85, // 1: determined.api.v1.DownsampledMetrics.type:type_name -> determined.api.v1.MetricType + 86, // 2: determined.api.v1.WorkloadContainer.training:type_name -> determined.trial.v1.MetricsWorkload + 86, // 3: determined.api.v1.WorkloadContainer.validation:type_name -> determined.trial.v1.MetricsWorkload + 87, // 4: determined.api.v1.WorkloadContainer.checkpoint:type_name -> determined.trial.v1.CheckpointWorkload + 88, // 5: determined.api.v1.ColumnFilter.filter:type_name -> determined.common.v1.DoubleFieldFilter 0, // 6: determined.api.v1.TrialSorter.namespace:type_name -> determined.api.v1.TrialSorter.Namespace - 93, // 7: determined.api.v1.TrialSorter.order_by:type_name -> determined.api.v1.OrderBy - 94, // 8: determined.api.v1.TrialLogsRequest.levels:type_name -> determined.log.v1.LogLevel - 95, // 9: determined.api.v1.TrialLogsRequest.timestamp_before:type_name -> google.protobuf.Timestamp - 95, // 10: determined.api.v1.TrialLogsRequest.timestamp_after:type_name -> google.protobuf.Timestamp - 93, // 11: determined.api.v1.TrialLogsRequest.order_by:type_name -> determined.api.v1.OrderBy - 95, // 12: determined.api.v1.TrialLogsResponse.timestamp:type_name -> google.protobuf.Timestamp - 94, // 13: determined.api.v1.TrialLogsResponse.level:type_name -> determined.log.v1.LogLevel - 96, // 14: determined.api.v1.GetTrialCheckpointsRequest.sort_by_attr:type_name -> determined.checkpoint.v1.SortBy - 93, // 15: determined.api.v1.GetTrialCheckpointsRequest.order_by:type_name -> determined.api.v1.OrderBy - 97, // 16: determined.api.v1.GetTrialCheckpointsRequest.states:type_name -> determined.checkpoint.v1.State - 98, // 17: determined.api.v1.GetTrialCheckpointsResponse.checkpoints:type_name -> determined.checkpoint.v1.Checkpoint - 99, // 18: determined.api.v1.GetTrialCheckpointsResponse.pagination:type_name -> determined.api.v1.Pagination + 89, // 7: determined.api.v1.TrialSorter.order_by:type_name -> determined.api.v1.OrderBy + 90, // 8: determined.api.v1.TrialLogsRequest.levels:type_name -> determined.log.v1.LogLevel + 91, // 9: determined.api.v1.TrialLogsRequest.timestamp_before:type_name -> google.protobuf.Timestamp + 91, // 10: determined.api.v1.TrialLogsRequest.timestamp_after:type_name -> google.protobuf.Timestamp + 89, // 11: determined.api.v1.TrialLogsRequest.order_by:type_name -> determined.api.v1.OrderBy + 91, // 12: determined.api.v1.TrialLogsResponse.timestamp:type_name -> google.protobuf.Timestamp + 90, // 13: determined.api.v1.TrialLogsResponse.level:type_name -> determined.log.v1.LogLevel + 92, // 14: determined.api.v1.GetTrialCheckpointsRequest.sort_by_attr:type_name -> determined.checkpoint.v1.SortBy + 89, // 15: determined.api.v1.GetTrialCheckpointsRequest.order_by:type_name -> determined.api.v1.OrderBy + 93, // 16: determined.api.v1.GetTrialCheckpointsRequest.states:type_name -> determined.checkpoint.v1.State + 94, // 17: determined.api.v1.GetTrialCheckpointsResponse.checkpoints:type_name -> determined.checkpoint.v1.Checkpoint + 95, // 18: determined.api.v1.GetTrialCheckpointsResponse.pagination:type_name -> determined.api.v1.Pagination 1, // 19: determined.api.v1.GetExperimentTrialsRequest.sort_by:type_name -> determined.api.v1.GetExperimentTrialsRequest.SortBy - 93, // 20: determined.api.v1.GetExperimentTrialsRequest.order_by:type_name -> determined.api.v1.OrderBy - 100, // 21: determined.api.v1.GetExperimentTrialsRequest.states:type_name -> determined.experiment.v1.State - 101, // 22: determined.api.v1.GetExperimentTrialsResponse.trials:type_name -> determined.trial.v1.Trial - 99, // 23: determined.api.v1.GetExperimentTrialsResponse.pagination:type_name -> determined.api.v1.Pagination - 101, // 24: determined.api.v1.GetTrialResponse.trial:type_name -> determined.trial.v1.Trial - 101, // 25: determined.api.v1.GetTrialByExternalIDResponse.trial:type_name -> determined.trial.v1.Trial - 93, // 26: determined.api.v1.GetTrialWorkloadsRequest.order_by:type_name -> determined.api.v1.OrderBy + 89, // 20: determined.api.v1.GetExperimentTrialsRequest.order_by:type_name -> determined.api.v1.OrderBy + 96, // 21: determined.api.v1.GetExperimentTrialsRequest.states:type_name -> determined.experiment.v1.State + 97, // 22: determined.api.v1.GetExperimentTrialsResponse.trials:type_name -> determined.trial.v1.Trial + 95, // 23: determined.api.v1.GetExperimentTrialsResponse.pagination:type_name -> determined.api.v1.Pagination + 97, // 24: determined.api.v1.GetTrialResponse.trial:type_name -> determined.trial.v1.Trial + 97, // 25: determined.api.v1.GetTrialByExternalIDResponse.trial:type_name -> determined.trial.v1.Trial + 89, // 26: determined.api.v1.GetTrialWorkloadsRequest.order_by:type_name -> determined.api.v1.OrderBy 2, // 27: determined.api.v1.GetTrialWorkloadsRequest.filter:type_name -> determined.api.v1.GetTrialWorkloadsRequest.FilterOption - 89, // 28: determined.api.v1.GetTrialWorkloadsRequest.metric_type:type_name -> determined.api.v1.MetricType + 85, // 28: determined.api.v1.GetTrialWorkloadsRequest.metric_type:type_name -> determined.api.v1.MetricType 4, // 29: determined.api.v1.GetTrialWorkloadsResponse.workloads:type_name -> determined.api.v1.WorkloadContainer - 99, // 30: determined.api.v1.GetTrialWorkloadsResponse.pagination:type_name -> determined.api.v1.Pagination - 102, // 31: determined.api.v1.GetTrialProfilerMetricsRequest.labels:type_name -> determined.trial.v1.TrialProfilerMetricLabels - 103, // 32: determined.api.v1.GetTrialProfilerMetricsResponse.batch:type_name -> determined.trial.v1.TrialProfilerMetricsBatch - 102, // 33: determined.api.v1.GetTrialProfilerAvailableSeriesResponse.labels:type_name -> determined.trial.v1.TrialProfilerMetricLabels - 103, // 34: determined.api.v1.PostTrialProfilerMetricsBatchRequest.batches:type_name -> determined.trial.v1.TrialProfilerMetricsBatch - 101, // 35: determined.api.v1.ComparableTrial.trial:type_name -> determined.trial.v1.Trial + 95, // 30: determined.api.v1.GetTrialWorkloadsResponse.pagination:type_name -> determined.api.v1.Pagination + 98, // 31: determined.api.v1.GetTrialProfilerMetricsRequest.labels:type_name -> determined.trial.v1.TrialProfilerMetricLabels + 99, // 32: determined.api.v1.GetTrialProfilerMetricsResponse.batch:type_name -> determined.trial.v1.TrialProfilerMetricsBatch + 98, // 33: determined.api.v1.GetTrialProfilerAvailableSeriesResponse.labels:type_name -> determined.trial.v1.TrialProfilerMetricLabels + 99, // 34: determined.api.v1.PostTrialProfilerMetricsBatchRequest.batches:type_name -> determined.trial.v1.TrialProfilerMetricsBatch + 97, // 35: determined.api.v1.ComparableTrial.trial:type_name -> determined.trial.v1.Trial 3, // 36: determined.api.v1.ComparableTrial.metrics:type_name -> determined.api.v1.DownsampledMetrics - 89, // 37: determined.api.v1.CompareTrialsRequest.metric_type:type_name -> determined.api.v1.MetricType - 104, // 38: determined.api.v1.CompareTrialsRequest.time_series_filter:type_name -> determined.common.v1.PolymorphicFilter + 85, // 37: determined.api.v1.CompareTrialsRequest.metric_type:type_name -> determined.api.v1.MetricType + 100, // 38: determined.api.v1.CompareTrialsRequest.time_series_filter:type_name -> determined.common.v1.PolymorphicFilter 32, // 39: determined.api.v1.CompareTrialsResponse.trials:type_name -> determined.api.v1.ComparableTrial - 105, // 40: determined.api.v1.AllocationRendezvousInfoResponse.rendezvous_info:type_name -> determined.trial.v1.RendezvousInfo + 101, // 40: determined.api.v1.AllocationRendezvousInfoResponse.rendezvous_info:type_name -> determined.trial.v1.RendezvousInfo 51, // 41: determined.api.v1.PostAllocationAcceleratorDataRequest.accelerator_data:type_name -> determined.api.v1.AcceleratorData - 106, // 42: determined.api.v1.AllocationAllGatherRequest.data:type_name -> google.protobuf.Struct - 106, // 43: determined.api.v1.AllocationAllGatherResponse.data:type_name -> google.protobuf.Struct - 106, // 44: determined.api.v1.NotifyContainerRunningRequest.data:type_name -> google.protobuf.Struct - 106, // 45: determined.api.v1.NotifyContainerRunningResponse.data:type_name -> google.protobuf.Struct - 107, // 46: determined.api.v1.GetCurrentTrialSearcherOperationResponse.op:type_name -> determined.experiment.v1.TrialOperation - 108, // 47: determined.api.v1.CompleteTrialSearcherValidationRequest.completed_operation:type_name -> determined.experiment.v1.CompleteValidateAfterOperation - 109, // 48: determined.api.v1.ReportTrialSearcherEarlyExitRequest.early_exit:type_name -> determined.trial.v1.TrialEarlyExit - 110, // 49: determined.api.v1.ReportTrialMetricsRequest.metrics:type_name -> determined.trial.v1.TrialMetrics - 110, // 50: determined.api.v1.ReportTrialTrainingMetricsRequest.training_metrics:type_name -> determined.trial.v1.TrialMetrics - 110, // 51: determined.api.v1.ReportTrialValidationMetricsRequest.validation_metrics:type_name -> determined.trial.v1.TrialMetrics - 111, // 52: determined.api.v1.PostTrialRunnerMetadataRequest.metadata:type_name -> determined.trial.v1.TrialRunnerMetadata - 112, // 53: determined.api.v1.GetMetricsResponse.metrics:type_name -> determined.trial.v1.MetricsReport - 112, // 54: determined.api.v1.GetTrainingMetricsResponse.metrics:type_name -> determined.trial.v1.MetricsReport - 112, // 55: determined.api.v1.GetValidationMetricsResponse.metrics:type_name -> determined.trial.v1.MetricsReport - 106, // 56: determined.api.v1.CreateTrialRequest.hparams:type_name -> google.protobuf.Struct - 101, // 57: determined.api.v1.CreateTrialResponse.trial:type_name -> determined.trial.v1.Trial - 78, // 58: determined.api.v1.PutTrialRequest.create_trial_request:type_name -> determined.api.v1.CreateTrialRequest - 101, // 59: determined.api.v1.PutTrialResponse.trial:type_name -> determined.trial.v1.Trial - 113, // 60: determined.api.v1.PatchTrialRequest.state:type_name -> determined.trial.v1.State - 101, // 61: determined.api.v1.PatchTrialResponse.trial:type_name -> determined.trial.v1.Trial - 114, // 62: determined.api.v1.ReportTrialSourceInfoRequest.trial_source_info:type_name -> determined.trial.v1.TrialSourceInfo - 63, // [63:63] is the sub-list for method output_type - 63, // [63:63] is the sub-list for method input_type - 63, // [63:63] is the sub-list for extension type_name - 63, // [63:63] is the sub-list for extension extendee - 0, // [0:63] is the sub-list for field type_name + 102, // 42: determined.api.v1.AllocationAllGatherRequest.data:type_name -> google.protobuf.Struct + 102, // 43: determined.api.v1.AllocationAllGatherResponse.data:type_name -> google.protobuf.Struct + 102, // 44: determined.api.v1.NotifyContainerRunningRequest.data:type_name -> google.protobuf.Struct + 102, // 45: determined.api.v1.NotifyContainerRunningResponse.data:type_name -> google.protobuf.Struct + 103, // 46: determined.api.v1.ReportTrialSearcherEarlyExitRequest.early_exit:type_name -> determined.trial.v1.TrialEarlyExit + 104, // 47: determined.api.v1.ReportTrialMetricsRequest.metrics:type_name -> determined.trial.v1.TrialMetrics + 104, // 48: determined.api.v1.ReportTrialTrainingMetricsRequest.training_metrics:type_name -> determined.trial.v1.TrialMetrics + 104, // 49: determined.api.v1.ReportTrialValidationMetricsRequest.validation_metrics:type_name -> determined.trial.v1.TrialMetrics + 105, // 50: determined.api.v1.PostTrialRunnerMetadataRequest.metadata:type_name -> determined.trial.v1.TrialRunnerMetadata + 106, // 51: determined.api.v1.GetMetricsResponse.metrics:type_name -> determined.trial.v1.MetricsReport + 106, // 52: determined.api.v1.GetTrainingMetricsResponse.metrics:type_name -> determined.trial.v1.MetricsReport + 106, // 53: determined.api.v1.GetValidationMetricsResponse.metrics:type_name -> determined.trial.v1.MetricsReport + 102, // 54: determined.api.v1.CreateTrialRequest.hparams:type_name -> google.protobuf.Struct + 97, // 55: determined.api.v1.CreateTrialResponse.trial:type_name -> determined.trial.v1.Trial + 74, // 56: determined.api.v1.PutTrialRequest.create_trial_request:type_name -> determined.api.v1.CreateTrialRequest + 97, // 57: determined.api.v1.PutTrialResponse.trial:type_name -> determined.trial.v1.Trial + 107, // 58: determined.api.v1.PatchTrialRequest.state:type_name -> determined.trial.v1.State + 97, // 59: determined.api.v1.PatchTrialResponse.trial:type_name -> determined.trial.v1.Trial + 108, // 60: determined.api.v1.ReportTrialSourceInfoRequest.trial_source_info:type_name -> determined.trial.v1.TrialSourceInfo + 61, // [61:61] is the sub-list for method output_type + 61, // [61:61] is the sub-list for method input_type + 61, // [61:61] is the sub-list for extension type_name + 61, // [61:61] is the sub-list for extension extendee + 0, // [0:61] is the sub-list for field type_name } func init() { file_determined_api_v1_trial_proto_init() } @@ -6947,54 +6704,6 @@ func file_determined_api_v1_trial_proto_init() { } } file_determined_api_v1_trial_proto_msgTypes[53].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetCurrentTrialSearcherOperationRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_determined_api_v1_trial_proto_msgTypes[54].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetCurrentTrialSearcherOperationResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_determined_api_v1_trial_proto_msgTypes[55].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*CompleteTrialSearcherValidationRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_determined_api_v1_trial_proto_msgTypes[56].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*CompleteTrialSearcherValidationResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_determined_api_v1_trial_proto_msgTypes[57].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*ReportTrialSearcherEarlyExitRequest); i { case 0: return &v.state @@ -7006,7 +6715,7 @@ func file_determined_api_v1_trial_proto_init() { return nil } } - file_determined_api_v1_trial_proto_msgTypes[58].Exporter = func(v interface{}, i int) interface{} { + file_determined_api_v1_trial_proto_msgTypes[54].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*ReportTrialSearcherEarlyExitResponse); i { case 0: return &v.state @@ -7018,7 +6727,7 @@ func file_determined_api_v1_trial_proto_init() { return nil } } - file_determined_api_v1_trial_proto_msgTypes[59].Exporter = func(v interface{}, i int) interface{} { + file_determined_api_v1_trial_proto_msgTypes[55].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*ReportTrialProgressRequest); i { case 0: return &v.state @@ -7030,7 +6739,7 @@ func file_determined_api_v1_trial_proto_init() { return nil } } - file_determined_api_v1_trial_proto_msgTypes[60].Exporter = func(v interface{}, i int) interface{} { + file_determined_api_v1_trial_proto_msgTypes[56].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*ReportTrialProgressResponse); i { case 0: return &v.state @@ -7042,7 +6751,7 @@ func file_determined_api_v1_trial_proto_init() { return nil } } - file_determined_api_v1_trial_proto_msgTypes[61].Exporter = func(v interface{}, i int) interface{} { + file_determined_api_v1_trial_proto_msgTypes[57].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*ReportTrialMetricsRequest); i { case 0: return &v.state @@ -7054,7 +6763,7 @@ func file_determined_api_v1_trial_proto_init() { return nil } } - file_determined_api_v1_trial_proto_msgTypes[62].Exporter = func(v interface{}, i int) interface{} { + file_determined_api_v1_trial_proto_msgTypes[58].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*ReportTrialMetricsResponse); i { case 0: return &v.state @@ -7066,7 +6775,7 @@ func file_determined_api_v1_trial_proto_init() { return nil } } - file_determined_api_v1_trial_proto_msgTypes[63].Exporter = func(v interface{}, i int) interface{} { + file_determined_api_v1_trial_proto_msgTypes[59].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*ReportTrialTrainingMetricsRequest); i { case 0: return &v.state @@ -7078,7 +6787,7 @@ func file_determined_api_v1_trial_proto_init() { return nil } } - file_determined_api_v1_trial_proto_msgTypes[64].Exporter = func(v interface{}, i int) interface{} { + file_determined_api_v1_trial_proto_msgTypes[60].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*ReportTrialTrainingMetricsResponse); i { case 0: return &v.state @@ -7090,7 +6799,7 @@ func file_determined_api_v1_trial_proto_init() { return nil } } - file_determined_api_v1_trial_proto_msgTypes[65].Exporter = func(v interface{}, i int) interface{} { + file_determined_api_v1_trial_proto_msgTypes[61].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*ReportTrialValidationMetricsRequest); i { case 0: return &v.state @@ -7102,7 +6811,7 @@ func file_determined_api_v1_trial_proto_init() { return nil } } - file_determined_api_v1_trial_proto_msgTypes[66].Exporter = func(v interface{}, i int) interface{} { + file_determined_api_v1_trial_proto_msgTypes[62].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*ReportTrialValidationMetricsResponse); i { case 0: return &v.state @@ -7114,7 +6823,7 @@ func file_determined_api_v1_trial_proto_init() { return nil } } - file_determined_api_v1_trial_proto_msgTypes[67].Exporter = func(v interface{}, i int) interface{} { + file_determined_api_v1_trial_proto_msgTypes[63].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*PostTrialRunnerMetadataRequest); i { case 0: return &v.state @@ -7126,7 +6835,7 @@ func file_determined_api_v1_trial_proto_init() { return nil } } - file_determined_api_v1_trial_proto_msgTypes[68].Exporter = func(v interface{}, i int) interface{} { + file_determined_api_v1_trial_proto_msgTypes[64].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*PostTrialRunnerMetadataResponse); i { case 0: return &v.state @@ -7138,7 +6847,7 @@ func file_determined_api_v1_trial_proto_init() { return nil } } - file_determined_api_v1_trial_proto_msgTypes[69].Exporter = func(v interface{}, i int) interface{} { + file_determined_api_v1_trial_proto_msgTypes[65].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*GetMetricsRequest); i { case 0: return &v.state @@ -7150,7 +6859,7 @@ func file_determined_api_v1_trial_proto_init() { return nil } } - file_determined_api_v1_trial_proto_msgTypes[70].Exporter = func(v interface{}, i int) interface{} { + file_determined_api_v1_trial_proto_msgTypes[66].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*GetMetricsResponse); i { case 0: return &v.state @@ -7162,7 +6871,7 @@ func file_determined_api_v1_trial_proto_init() { return nil } } - file_determined_api_v1_trial_proto_msgTypes[71].Exporter = func(v interface{}, i int) interface{} { + file_determined_api_v1_trial_proto_msgTypes[67].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*GetTrainingMetricsRequest); i { case 0: return &v.state @@ -7174,7 +6883,7 @@ func file_determined_api_v1_trial_proto_init() { return nil } } - file_determined_api_v1_trial_proto_msgTypes[72].Exporter = func(v interface{}, i int) interface{} { + file_determined_api_v1_trial_proto_msgTypes[68].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*GetTrainingMetricsResponse); i { case 0: return &v.state @@ -7186,7 +6895,7 @@ func file_determined_api_v1_trial_proto_init() { return nil } } - file_determined_api_v1_trial_proto_msgTypes[73].Exporter = func(v interface{}, i int) interface{} { + file_determined_api_v1_trial_proto_msgTypes[69].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*GetValidationMetricsRequest); i { case 0: return &v.state @@ -7198,7 +6907,7 @@ func file_determined_api_v1_trial_proto_init() { return nil } } - file_determined_api_v1_trial_proto_msgTypes[74].Exporter = func(v interface{}, i int) interface{} { + file_determined_api_v1_trial_proto_msgTypes[70].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*GetValidationMetricsResponse); i { case 0: return &v.state @@ -7210,7 +6919,7 @@ func file_determined_api_v1_trial_proto_init() { return nil } } - file_determined_api_v1_trial_proto_msgTypes[75].Exporter = func(v interface{}, i int) interface{} { + file_determined_api_v1_trial_proto_msgTypes[71].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*CreateTrialRequest); i { case 0: return &v.state @@ -7222,7 +6931,7 @@ func file_determined_api_v1_trial_proto_init() { return nil } } - file_determined_api_v1_trial_proto_msgTypes[76].Exporter = func(v interface{}, i int) interface{} { + file_determined_api_v1_trial_proto_msgTypes[72].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*CreateTrialResponse); i { case 0: return &v.state @@ -7234,7 +6943,7 @@ func file_determined_api_v1_trial_proto_init() { return nil } } - file_determined_api_v1_trial_proto_msgTypes[77].Exporter = func(v interface{}, i int) interface{} { + file_determined_api_v1_trial_proto_msgTypes[73].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*PutTrialRequest); i { case 0: return &v.state @@ -7246,7 +6955,7 @@ func file_determined_api_v1_trial_proto_init() { return nil } } - file_determined_api_v1_trial_proto_msgTypes[78].Exporter = func(v interface{}, i int) interface{} { + file_determined_api_v1_trial_proto_msgTypes[74].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*PutTrialResponse); i { case 0: return &v.state @@ -7258,7 +6967,7 @@ func file_determined_api_v1_trial_proto_init() { return nil } } - file_determined_api_v1_trial_proto_msgTypes[79].Exporter = func(v interface{}, i int) interface{} { + file_determined_api_v1_trial_proto_msgTypes[75].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*PatchTrialRequest); i { case 0: return &v.state @@ -7270,7 +6979,7 @@ func file_determined_api_v1_trial_proto_init() { return nil } } - file_determined_api_v1_trial_proto_msgTypes[80].Exporter = func(v interface{}, i int) interface{} { + file_determined_api_v1_trial_proto_msgTypes[76].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*PatchTrialResponse); i { case 0: return &v.state @@ -7282,7 +6991,7 @@ func file_determined_api_v1_trial_proto_init() { return nil } } - file_determined_api_v1_trial_proto_msgTypes[81].Exporter = func(v interface{}, i int) interface{} { + file_determined_api_v1_trial_proto_msgTypes[77].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*StartTrialRequest); i { case 0: return &v.state @@ -7294,7 +7003,7 @@ func file_determined_api_v1_trial_proto_init() { return nil } } - file_determined_api_v1_trial_proto_msgTypes[82].Exporter = func(v interface{}, i int) interface{} { + file_determined_api_v1_trial_proto_msgTypes[78].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*StartTrialResponse); i { case 0: return &v.state @@ -7306,7 +7015,7 @@ func file_determined_api_v1_trial_proto_init() { return nil } } - file_determined_api_v1_trial_proto_msgTypes[83].Exporter = func(v interface{}, i int) interface{} { + file_determined_api_v1_trial_proto_msgTypes[79].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*ReportTrialSourceInfoRequest); i { case 0: return &v.state @@ -7318,7 +7027,7 @@ func file_determined_api_v1_trial_proto_init() { return nil } } - file_determined_api_v1_trial_proto_msgTypes[84].Exporter = func(v interface{}, i int) interface{} { + file_determined_api_v1_trial_proto_msgTypes[80].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*ReportTrialSourceInfoResponse); i { case 0: return &v.state @@ -7342,15 +7051,15 @@ func file_determined_api_v1_trial_proto_init() { (*GetTrialCheckpointsRequest_SortByMetric)(nil), } file_determined_api_v1_trial_proto_msgTypes[16].OneofWrappers = []interface{}{} - file_determined_api_v1_trial_proto_msgTypes[79].OneofWrappers = []interface{}{} - file_determined_api_v1_trial_proto_msgTypes[82].OneofWrappers = []interface{}{} + file_determined_api_v1_trial_proto_msgTypes[75].OneofWrappers = []interface{}{} + file_determined_api_v1_trial_proto_msgTypes[78].OneofWrappers = []interface{}{} type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_determined_api_v1_trial_proto_rawDesc, NumEnums: 3, - NumMessages: 85, + NumMessages: 81, NumExtensions: 0, NumServices: 0, }, diff --git a/proto/pkg/experimentv1/searcher.pb.go b/proto/pkg/experimentv1/searcher.pb.go index df799911696..9d5bc7baa24 100644 --- a/proto/pkg/experimentv1/searcher.pb.go +++ b/proto/pkg/experimentv1/searcher.pb.go @@ -20,60 +20,6 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) -// RunnableType defines the type of operation that should be executed by trial -// runners. -type RunnableType int32 - -const ( - // Denotes an unknown runnable type. - RunnableType_RUNNABLE_TYPE_UNSPECIFIED RunnableType = 0 - // Signals to a trial runner that it should run a train. - RunnableType_RUNNABLE_TYPE_TRAIN RunnableType = 1 - // Signals to a trial runner it should compute validation metrics. - RunnableType_RUNNABLE_TYPE_VALIDATE RunnableType = 2 -) - -// Enum value maps for RunnableType. -var ( - RunnableType_name = map[int32]string{ - 0: "RUNNABLE_TYPE_UNSPECIFIED", - 1: "RUNNABLE_TYPE_TRAIN", - 2: "RUNNABLE_TYPE_VALIDATE", - } - RunnableType_value = map[string]int32{ - "RUNNABLE_TYPE_UNSPECIFIED": 0, - "RUNNABLE_TYPE_TRAIN": 1, - "RUNNABLE_TYPE_VALIDATE": 2, - } -) - -func (x RunnableType) Enum() *RunnableType { - p := new(RunnableType) - *p = x - return p -} - -func (x RunnableType) String() string { - return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) -} - -func (RunnableType) Descriptor() protoreflect.EnumDescriptor { - return file_determined_experiment_v1_searcher_proto_enumTypes[0].Descriptor() -} - -func (RunnableType) Type() protoreflect.EnumType { - return &file_determined_experiment_v1_searcher_proto_enumTypes[0] -} - -func (x RunnableType) Number() protoreflect.EnumNumber { - return protoreflect.EnumNumber(x) -} - -// Deprecated: Use RunnableType.Descriptor instead. -func (RunnableType) EnumDescriptor() ([]byte, []int) { - return file_determined_experiment_v1_searcher_proto_rawDescGZIP(), []int{0} -} - // The reason for an early exit. type TrialExitedEarly_ExitedReason int32 @@ -116,11 +62,11 @@ func (x TrialExitedEarly_ExitedReason) String() string { } func (TrialExitedEarly_ExitedReason) Descriptor() protoreflect.EnumDescriptor { - return file_determined_experiment_v1_searcher_proto_enumTypes[1].Descriptor() + return file_determined_experiment_v1_searcher_proto_enumTypes[0].Descriptor() } func (TrialExitedEarly_ExitedReason) Type() protoreflect.EnumType { - return &file_determined_experiment_v1_searcher_proto_enumTypes[1] + return &file_determined_experiment_v1_searcher_proto_enumTypes[0] } func (x TrialExitedEarly_ExitedReason) Number() protoreflect.EnumNumber { @@ -518,177 +464,6 @@ func (x *ExperimentInactive) GetExperimentState() State { return State_STATE_UNSPECIFIED } -// SearcherEvent is a message from master to a client-driven custom searcher -// informing it of relevant changes in the state of an experiment. -type SearcherEvent struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - // Incremental ID of the event. - Id int32 `protobuf:"varint,1,opt,name=id,proto3" json:"id,omitempty"` - // The concrete event. - // - // Types that are assignable to Event: - // - // *SearcherEvent_InitialOperations - // *SearcherEvent_TrialCreated - // *SearcherEvent_ValidationCompleted - // *SearcherEvent_TrialClosed - // *SearcherEvent_TrialExitedEarly - // *SearcherEvent_TrialProgress - // *SearcherEvent_ExperimentInactive - Event isSearcherEvent_Event `protobuf_oneof:"event"` -} - -func (x *SearcherEvent) Reset() { - *x = SearcherEvent{} - if protoimpl.UnsafeEnabled { - mi := &file_determined_experiment_v1_searcher_proto_msgTypes[7] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *SearcherEvent) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*SearcherEvent) ProtoMessage() {} - -func (x *SearcherEvent) ProtoReflect() protoreflect.Message { - mi := &file_determined_experiment_v1_searcher_proto_msgTypes[7] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use SearcherEvent.ProtoReflect.Descriptor instead. -func (*SearcherEvent) Descriptor() ([]byte, []int) { - return file_determined_experiment_v1_searcher_proto_rawDescGZIP(), []int{7} -} - -func (x *SearcherEvent) GetId() int32 { - if x != nil { - return x.Id - } - return 0 -} - -func (m *SearcherEvent) GetEvent() isSearcherEvent_Event { - if m != nil { - return m.Event - } - return nil -} - -func (x *SearcherEvent) GetInitialOperations() *InitialOperations { - if x, ok := x.GetEvent().(*SearcherEvent_InitialOperations); ok { - return x.InitialOperations - } - return nil -} - -func (x *SearcherEvent) GetTrialCreated() *TrialCreated { - if x, ok := x.GetEvent().(*SearcherEvent_TrialCreated); ok { - return x.TrialCreated - } - return nil -} - -func (x *SearcherEvent) GetValidationCompleted() *ValidationCompleted { - if x, ok := x.GetEvent().(*SearcherEvent_ValidationCompleted); ok { - return x.ValidationCompleted - } - return nil -} - -func (x *SearcherEvent) GetTrialClosed() *TrialClosed { - if x, ok := x.GetEvent().(*SearcherEvent_TrialClosed); ok { - return x.TrialClosed - } - return nil -} - -func (x *SearcherEvent) GetTrialExitedEarly() *TrialExitedEarly { - if x, ok := x.GetEvent().(*SearcherEvent_TrialExitedEarly); ok { - return x.TrialExitedEarly - } - return nil -} - -func (x *SearcherEvent) GetTrialProgress() *TrialProgress { - if x, ok := x.GetEvent().(*SearcherEvent_TrialProgress); ok { - return x.TrialProgress - } - return nil -} - -func (x *SearcherEvent) GetExperimentInactive() *ExperimentInactive { - if x, ok := x.GetEvent().(*SearcherEvent_ExperimentInactive); ok { - return x.ExperimentInactive - } - return nil -} - -type isSearcherEvent_Event interface { - isSearcherEvent_Event() -} - -type SearcherEvent_InitialOperations struct { - // An experiment has just been created. - InitialOperations *InitialOperations `protobuf:"bytes,3,opt,name=initial_operations,json=initialOperations,proto3,oneof"` -} - -type SearcherEvent_TrialCreated struct { - // A trial has been created. - TrialCreated *TrialCreated `protobuf:"bytes,4,opt,name=trial_created,json=trialCreated,proto3,oneof"` -} - -type SearcherEvent_ValidationCompleted struct { - // Validation has completed. - ValidationCompleted *ValidationCompleted `protobuf:"bytes,5,opt,name=validation_completed,json=validationCompleted,proto3,oneof"` -} - -type SearcherEvent_TrialClosed struct { - // Trial has finished. - TrialClosed *TrialClosed `protobuf:"bytes,6,opt,name=trial_closed,json=trialClosed,proto3,oneof"` -} - -type SearcherEvent_TrialExitedEarly struct { - // Trial exited early. - TrialExitedEarly *TrialExitedEarly `protobuf:"bytes,7,opt,name=trial_exited_early,json=trialExitedEarly,proto3,oneof"` -} - -type SearcherEvent_TrialProgress struct { - // Trial progress. - TrialProgress *TrialProgress `protobuf:"bytes,8,opt,name=trial_progress,json=trialProgress,proto3,oneof"` -} - -type SearcherEvent_ExperimentInactive struct { - // Experiment is inactive. - ExperimentInactive *ExperimentInactive `protobuf:"bytes,9,opt,name=experiment_inactive,json=experimentInactive,proto3,oneof"` -} - -func (*SearcherEvent_InitialOperations) isSearcherEvent_Event() {} - -func (*SearcherEvent_TrialCreated) isSearcherEvent_Event() {} - -func (*SearcherEvent_ValidationCompleted) isSearcherEvent_Event() {} - -func (*SearcherEvent_TrialClosed) isSearcherEvent_Event() {} - -func (*SearcherEvent_TrialExitedEarly) isSearcherEvent_Event() {} - -func (*SearcherEvent_TrialProgress) isSearcherEvent_Event() {} - -func (*SearcherEvent_ExperimentInactive) isSearcherEvent_Event() {} - // ValidateAfterOperation means the trial should train and validate after // training the given length. type ValidateAfterOperation struct { @@ -705,7 +480,7 @@ type ValidateAfterOperation struct { func (x *ValidateAfterOperation) Reset() { *x = ValidateAfterOperation{} if protoimpl.UnsafeEnabled { - mi := &file_determined_experiment_v1_searcher_proto_msgTypes[8] + mi := &file_determined_experiment_v1_searcher_proto_msgTypes[7] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -718,7 +493,7 @@ func (x *ValidateAfterOperation) String() string { func (*ValidateAfterOperation) ProtoMessage() {} func (x *ValidateAfterOperation) ProtoReflect() protoreflect.Message { - mi := &file_determined_experiment_v1_searcher_proto_msgTypes[8] + mi := &file_determined_experiment_v1_searcher_proto_msgTypes[7] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -731,7 +506,7 @@ func (x *ValidateAfterOperation) ProtoReflect() protoreflect.Message { // Deprecated: Use ValidateAfterOperation.ProtoReflect.Descriptor instead. func (*ValidateAfterOperation) Descriptor() ([]byte, []int) { - return file_determined_experiment_v1_searcher_proto_rawDescGZIP(), []int{8} + return file_determined_experiment_v1_searcher_proto_rawDescGZIP(), []int{7} } func (x *ValidateAfterOperation) GetRequestId() string { @@ -748,115 +523,6 @@ func (x *ValidateAfterOperation) GetLength() uint64 { return 0 } -// SetSearcherProgressOperation informs the master of the progress of the custom -// searcher. -type SetSearcherProgressOperation struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - // Experiment progress as a float between 0.0 and 1.0. - Progress float64 `protobuf:"fixed64,1,opt,name=progress,proto3" json:"progress,omitempty"` -} - -func (x *SetSearcherProgressOperation) Reset() { - *x = SetSearcherProgressOperation{} - if protoimpl.UnsafeEnabled { - mi := &file_determined_experiment_v1_searcher_proto_msgTypes[9] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *SetSearcherProgressOperation) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*SetSearcherProgressOperation) ProtoMessage() {} - -func (x *SetSearcherProgressOperation) ProtoReflect() protoreflect.Message { - mi := &file_determined_experiment_v1_searcher_proto_msgTypes[9] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use SetSearcherProgressOperation.ProtoReflect.Descriptor instead. -func (*SetSearcherProgressOperation) Descriptor() ([]byte, []int) { - return file_determined_experiment_v1_searcher_proto_rawDescGZIP(), []int{9} -} - -func (x *SetSearcherProgressOperation) GetProgress() float64 { - if x != nil { - return x.Progress - } - return 0 -} - -// Used to complete a ValidateAfterOperation. -type CompleteValidateAfterOperation struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - // The ValidateAfterOperation being completed. - Op *ValidateAfterOperation `protobuf:"bytes,1,opt,name=op,proto3" json:"op,omitempty"` - // The value of searcher metric associated with this completed operation. - // The metric provided should be the metric used to guide HP search. - SearcherMetric *_struct.Value `protobuf:"bytes,2,opt,name=searcher_metric,json=searcherMetric,proto3" json:"searcher_metric,omitempty"` -} - -func (x *CompleteValidateAfterOperation) Reset() { - *x = CompleteValidateAfterOperation{} - if protoimpl.UnsafeEnabled { - mi := &file_determined_experiment_v1_searcher_proto_msgTypes[10] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *CompleteValidateAfterOperation) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*CompleteValidateAfterOperation) ProtoMessage() {} - -func (x *CompleteValidateAfterOperation) ProtoReflect() protoreflect.Message { - mi := &file_determined_experiment_v1_searcher_proto_msgTypes[10] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use CompleteValidateAfterOperation.ProtoReflect.Descriptor instead. -func (*CompleteValidateAfterOperation) Descriptor() ([]byte, []int) { - return file_determined_experiment_v1_searcher_proto_rawDescGZIP(), []int{10} -} - -func (x *CompleteValidateAfterOperation) GetOp() *ValidateAfterOperation { - if x != nil { - return x.Op - } - return nil -} - -func (x *CompleteValidateAfterOperation) GetSearcherMetric() *_struct.Value { - if x != nil { - return x.SearcherMetric - } - return nil -} - // Create a trial with given hyperparameters. type CreateTrialOperation struct { state protoimpl.MessageState @@ -872,7 +538,7 @@ type CreateTrialOperation struct { func (x *CreateTrialOperation) Reset() { *x = CreateTrialOperation{} if protoimpl.UnsafeEnabled { - mi := &file_determined_experiment_v1_searcher_proto_msgTypes[11] + mi := &file_determined_experiment_v1_searcher_proto_msgTypes[8] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -885,7 +551,7 @@ func (x *CreateTrialOperation) String() string { func (*CreateTrialOperation) ProtoMessage() {} func (x *CreateTrialOperation) ProtoReflect() protoreflect.Message { - mi := &file_determined_experiment_v1_searcher_proto_msgTypes[11] + mi := &file_determined_experiment_v1_searcher_proto_msgTypes[8] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -898,7 +564,7 @@ func (x *CreateTrialOperation) ProtoReflect() protoreflect.Message { // Deprecated: Use CreateTrialOperation.ProtoReflect.Descriptor instead. func (*CreateTrialOperation) Descriptor() ([]byte, []int) { - return file_determined_experiment_v1_searcher_proto_rawDescGZIP(), []int{11} + return file_determined_experiment_v1_searcher_proto_rawDescGZIP(), []int{8} } func (x *CreateTrialOperation) GetRequestId() string { @@ -928,7 +594,7 @@ type CloseTrialOperation struct { func (x *CloseTrialOperation) Reset() { *x = CloseTrialOperation{} if protoimpl.UnsafeEnabled { - mi := &file_determined_experiment_v1_searcher_proto_msgTypes[12] + mi := &file_determined_experiment_v1_searcher_proto_msgTypes[9] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -941,7 +607,7 @@ func (x *CloseTrialOperation) String() string { func (*CloseTrialOperation) ProtoMessage() {} func (x *CloseTrialOperation) ProtoReflect() protoreflect.Message { - mi := &file_determined_experiment_v1_searcher_proto_msgTypes[12] + mi := &file_determined_experiment_v1_searcher_proto_msgTypes[9] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -954,7 +620,7 @@ func (x *CloseTrialOperation) ProtoReflect() protoreflect.Message { // Deprecated: Use CloseTrialOperation.ProtoReflect.Descriptor instead. func (*CloseTrialOperation) Descriptor() ([]byte, []int) { - return file_determined_experiment_v1_searcher_proto_rawDescGZIP(), []int{12} + return file_determined_experiment_v1_searcher_proto_rawDescGZIP(), []int{9} } func (x *CloseTrialOperation) GetRequestId() string { @@ -979,7 +645,7 @@ type ShutDownOperation struct { func (x *ShutDownOperation) Reset() { *x = ShutDownOperation{} if protoimpl.UnsafeEnabled { - mi := &file_determined_experiment_v1_searcher_proto_msgTypes[13] + mi := &file_determined_experiment_v1_searcher_proto_msgTypes[10] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -992,7 +658,7 @@ func (x *ShutDownOperation) String() string { func (*ShutDownOperation) ProtoMessage() {} func (x *ShutDownOperation) ProtoReflect() protoreflect.Message { - mi := &file_determined_experiment_v1_searcher_proto_msgTypes[13] + mi := &file_determined_experiment_v1_searcher_proto_msgTypes[10] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1005,7 +671,7 @@ func (x *ShutDownOperation) ProtoReflect() protoreflect.Message { // Deprecated: Use ShutDownOperation.ProtoReflect.Descriptor instead. func (*ShutDownOperation) Descriptor() ([]byte, []int) { - return file_determined_experiment_v1_searcher_proto_rawDescGZIP(), []int{13} + return file_determined_experiment_v1_searcher_proto_rawDescGZIP(), []int{10} } func (x *ShutDownOperation) GetCancel() bool { @@ -1036,14 +702,13 @@ type SearcherOperation struct { // *SearcherOperation_CreateTrial // *SearcherOperation_CloseTrial // *SearcherOperation_ShutDown - // *SearcherOperation_SetSearcherProgress Union isSearcherOperation_Union `protobuf_oneof:"union"` } func (x *SearcherOperation) Reset() { *x = SearcherOperation{} if protoimpl.UnsafeEnabled { - mi := &file_determined_experiment_v1_searcher_proto_msgTypes[14] + mi := &file_determined_experiment_v1_searcher_proto_msgTypes[11] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1056,7 +721,7 @@ func (x *SearcherOperation) String() string { func (*SearcherOperation) ProtoMessage() {} func (x *SearcherOperation) ProtoReflect() protoreflect.Message { - mi := &file_determined_experiment_v1_searcher_proto_msgTypes[14] + mi := &file_determined_experiment_v1_searcher_proto_msgTypes[11] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1069,7 +734,7 @@ func (x *SearcherOperation) ProtoReflect() protoreflect.Message { // Deprecated: Use SearcherOperation.ProtoReflect.Descriptor instead. func (*SearcherOperation) Descriptor() ([]byte, []int) { - return file_determined_experiment_v1_searcher_proto_rawDescGZIP(), []int{14} + return file_determined_experiment_v1_searcher_proto_rawDescGZIP(), []int{11} } func (m *SearcherOperation) GetUnion() isSearcherOperation_Union { @@ -1107,13 +772,6 @@ func (x *SearcherOperation) GetShutDown() *ShutDownOperation { return nil } -func (x *SearcherOperation) GetSetSearcherProgress() *SetSearcherProgressOperation { - if x, ok := x.GetUnion().(*SearcherOperation_SetSearcherProgress); ok { - return x.SetSearcherProgress - } - return nil -} - type isSearcherOperation_Union interface { isSearcherOperation_Union() } @@ -1138,12 +796,6 @@ type SearcherOperation_ShutDown struct { ShutDown *ShutDownOperation `protobuf:"bytes,4,opt,name=shut_down,json=shutDown,proto3,oneof"` } -type SearcherOperation_SetSearcherProgress struct { - // SetSearcherProgressOperation is issued to set the progress of the custom - // search method. - SetSearcherProgress *SetSearcherProgressOperation `protobuf:"bytes,5,opt,name=set_searcher_progress,json=setSearcherProgress,proto3,oneof"` -} - func (*SearcherOperation_TrialOperation) isSearcherOperation_Union() {} func (*SearcherOperation_CreateTrial) isSearcherOperation_Union() {} @@ -1152,8 +804,6 @@ func (*SearcherOperation_CloseTrial) isSearcherOperation_Union() {} func (*SearcherOperation_ShutDown) isSearcherOperation_Union() {} -func (*SearcherOperation_SetSearcherProgress) isSearcherOperation_Union() {} - // TrialOperation is any operation that a trial can perform while it is active. type TrialOperation struct { state protoimpl.MessageState @@ -1171,7 +821,7 @@ type TrialOperation struct { func (x *TrialOperation) Reset() { *x = TrialOperation{} if protoimpl.UnsafeEnabled { - mi := &file_determined_experiment_v1_searcher_proto_msgTypes[15] + mi := &file_determined_experiment_v1_searcher_proto_msgTypes[12] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1184,7 +834,7 @@ func (x *TrialOperation) String() string { func (*TrialOperation) ProtoMessage() {} func (x *TrialOperation) ProtoReflect() protoreflect.Message { - mi := &file_determined_experiment_v1_searcher_proto_msgTypes[15] + mi := &file_determined_experiment_v1_searcher_proto_msgTypes[12] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1197,7 +847,7 @@ func (x *TrialOperation) ProtoReflect() protoreflect.Message { // Deprecated: Use TrialOperation.ProtoReflect.Descriptor instead. func (*TrialOperation) Descriptor() ([]byte, []int) { - return file_determined_experiment_v1_searcher_proto_rawDescGZIP(), []int{15} + return file_determined_experiment_v1_searcher_proto_rawDescGZIP(), []int{12} } func (m *TrialOperation) GetUnion() isTrialOperation_Union { @@ -1226,36 +876,37 @@ type TrialOperation_ValidateAfter struct { func (*TrialOperation_ValidateAfter) isTrialOperation_Union() {} -// RunnableOperation represents a single runnable operation emitted by a -// searcher. -type RunnableOperation struct { +// SearchUnit describes a length unit used by some searchers to manage training. +type SearchUnit struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - // This is the type of the operation. - Type RunnableType `protobuf:"varint,1,opt,name=type,proto3,enum=determined.experiment.v1.RunnableType" json:"type,omitempty"` - // If the type == WORKLOAD_KIND_TRAIN, this is the number of units - Length uint64 `protobuf:"varint,2,opt,name=length,proto3" json:"length,omitempty"` + // Name of the length unit (if max_length is false). + Name *string `protobuf:"bytes,1,opt,name=name,proto3,oneof" json:"name,omitempty"` + // Value of the length unit (if max_length is false). + Value *int32 `protobuf:"varint,2,opt,name=value,proto3,oneof" json:"value,omitempty"` + // Bool indicating whether the training length is defined in code. + MaxLength bool `protobuf:"varint,3,opt,name=max_length,json=maxLength,proto3" json:"max_length,omitempty"` } -func (x *RunnableOperation) Reset() { - *x = RunnableOperation{} +func (x *SearchUnit) Reset() { + *x = SearchUnit{} if protoimpl.UnsafeEnabled { - mi := &file_determined_experiment_v1_searcher_proto_msgTypes[16] + mi := &file_determined_experiment_v1_searcher_proto_msgTypes[13] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } } -func (x *RunnableOperation) String() string { +func (x *SearchUnit) String() string { return protoimpl.X.MessageStringOf(x) } -func (*RunnableOperation) ProtoMessage() {} +func (*SearchUnit) ProtoMessage() {} -func (x *RunnableOperation) ProtoReflect() protoreflect.Message { - mi := &file_determined_experiment_v1_searcher_proto_msgTypes[16] +func (x *SearchUnit) ProtoReflect() protoreflect.Message { + mi := &file_determined_experiment_v1_searcher_proto_msgTypes[13] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1266,56 +917,62 @@ func (x *RunnableOperation) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use RunnableOperation.ProtoReflect.Descriptor instead. -func (*RunnableOperation) Descriptor() ([]byte, []int) { - return file_determined_experiment_v1_searcher_proto_rawDescGZIP(), []int{16} +// Deprecated: Use SearchUnit.ProtoReflect.Descriptor instead. +func (*SearchUnit) Descriptor() ([]byte, []int) { + return file_determined_experiment_v1_searcher_proto_rawDescGZIP(), []int{13} } -func (x *RunnableOperation) GetType() RunnableType { - if x != nil { - return x.Type +func (x *SearchUnit) GetName() string { + if x != nil && x.Name != nil { + return *x.Name } - return RunnableType_RUNNABLE_TYPE_UNSPECIFIED + return "" } -func (x *RunnableOperation) GetLength() uint64 { - if x != nil { - return x.Length +func (x *SearchUnit) GetValue() int32 { + if x != nil && x.Value != nil { + return *x.Value } return 0 } -// TrialSimulation is a specific sequence of workloads that were run before the -// trial was completed. -type TrialSimulation struct { +func (x *SearchUnit) GetMaxLength() bool { + if x != nil { + return x.MaxLength + } + return false +} + +// TrialSummary describes the runs that are estimated to train for a certain +// length. +type TrialSummary struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - // The list of operations that were run before the trial was completed. - Operations []*RunnableOperation `protobuf:"bytes,1,rep,name=operations,proto3" json:"operations,omitempty"` - // The number of times that this trial configuration has occurred during the - // simulation. - Occurrences int32 `protobuf:"varint,2,opt,name=occurrences,proto3" json:"occurrences,omitempty"` + // Number of trials. + Count int32 `protobuf:"varint,1,opt,name=count,proto3" json:"count,omitempty"` + // Training length for the trials. + Unit *SearchUnit `protobuf:"bytes,2,opt,name=unit,proto3" json:"unit,omitempty"` } -func (x *TrialSimulation) Reset() { - *x = TrialSimulation{} +func (x *TrialSummary) Reset() { + *x = TrialSummary{} if protoimpl.UnsafeEnabled { - mi := &file_determined_experiment_v1_searcher_proto_msgTypes[17] + mi := &file_determined_experiment_v1_searcher_proto_msgTypes[14] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } } -func (x *TrialSimulation) String() string { +func (x *TrialSummary) String() string { return protoimpl.X.MessageStringOf(x) } -func (*TrialSimulation) ProtoMessage() {} +func (*TrialSummary) ProtoMessage() {} -func (x *TrialSimulation) ProtoReflect() protoreflect.Message { - mi := &file_determined_experiment_v1_searcher_proto_msgTypes[17] +func (x *TrialSummary) ProtoReflect() protoreflect.Message { + mi := &file_determined_experiment_v1_searcher_proto_msgTypes[14] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1326,57 +983,55 @@ func (x *TrialSimulation) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use TrialSimulation.ProtoReflect.Descriptor instead. -func (*TrialSimulation) Descriptor() ([]byte, []int) { - return file_determined_experiment_v1_searcher_proto_rawDescGZIP(), []int{17} +// Deprecated: Use TrialSummary.ProtoReflect.Descriptor instead. +func (*TrialSummary) Descriptor() ([]byte, []int) { + return file_determined_experiment_v1_searcher_proto_rawDescGZIP(), []int{14} } -func (x *TrialSimulation) GetOperations() []*RunnableOperation { +func (x *TrialSummary) GetCount() int32 { if x != nil { - return x.Operations + return x.Count } - return nil + return 0 } -func (x *TrialSimulation) GetOccurrences() int32 { +func (x *TrialSummary) GetUnit() *SearchUnit { if x != nil { - return x.Occurrences + return x.Unit } - return 0 + return nil } -// ExperimentSimulation holds the configuration and results of simulated run of -// a searcher. -type ExperimentSimulation struct { +// SearchSummary contains the estimated trials and training lengths that a +// search plans to execute. +type SearchSummary struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - // The simulated experiment config. + // The searcher config from which the summary is generated. Config *_struct.Struct `protobuf:"bytes,1,opt,name=config,proto3" json:"config,omitempty"` - // The searcher simulation seed. - Seed uint32 `protobuf:"varint,2,opt,name=seed,proto3" json:"seed,omitempty"` - // The list of trials in the simulation. - Trials []*TrialSimulation `protobuf:"bytes,3,rep,name=trials,proto3" json:"trials,omitempty"` + // A list of planned number of trials to their training lengths. + Trials []*TrialSummary `protobuf:"bytes,2,rep,name=trials,proto3" json:"trials,omitempty"` } -func (x *ExperimentSimulation) Reset() { - *x = ExperimentSimulation{} +func (x *SearchSummary) Reset() { + *x = SearchSummary{} if protoimpl.UnsafeEnabled { - mi := &file_determined_experiment_v1_searcher_proto_msgTypes[18] + mi := &file_determined_experiment_v1_searcher_proto_msgTypes[15] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } } -func (x *ExperimentSimulation) String() string { +func (x *SearchSummary) String() string { return protoimpl.X.MessageStringOf(x) } -func (*ExperimentSimulation) ProtoMessage() {} +func (*SearchSummary) ProtoMessage() {} -func (x *ExperimentSimulation) ProtoReflect() protoreflect.Message { - mi := &file_determined_experiment_v1_searcher_proto_msgTypes[18] +func (x *SearchSummary) ProtoReflect() protoreflect.Message { + mi := &file_determined_experiment_v1_searcher_proto_msgTypes[15] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1387,26 +1042,19 @@ func (x *ExperimentSimulation) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use ExperimentSimulation.ProtoReflect.Descriptor instead. -func (*ExperimentSimulation) Descriptor() ([]byte, []int) { - return file_determined_experiment_v1_searcher_proto_rawDescGZIP(), []int{18} +// Deprecated: Use SearchSummary.ProtoReflect.Descriptor instead. +func (*SearchSummary) Descriptor() ([]byte, []int) { + return file_determined_experiment_v1_searcher_proto_rawDescGZIP(), []int{15} } -func (x *ExperimentSimulation) GetConfig() *_struct.Struct { +func (x *SearchSummary) GetConfig() *_struct.Struct { if x != nil { return x.Config } return nil } -func (x *ExperimentSimulation) GetSeed() uint32 { - if x != nil { - return x.Seed - } - return 0 -} - -func (x *ExperimentSimulation) GetTrials() []*TrialSimulation { +func (x *SearchSummary) GetTrials() []*TrialSummary { if x != nil { return x.Trials } @@ -1488,154 +1136,85 @@ var file_determined_experiment_v1_searcher_proto_rawDesc = []byte{ 0x65, 0x6e, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0f, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x3a, 0x18, 0x92, 0x41, 0x15, 0x0a, 0x13, 0xd2, 0x01, 0x10, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, - 0x74, 0x5f, 0x73, 0x74, 0x61, 0x74, 0x65, 0x22, 0xa0, 0x05, 0x0a, 0x0d, 0x53, 0x65, 0x61, 0x72, - 0x63, 0x68, 0x65, 0x72, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x02, 0x69, 0x64, 0x12, 0x5c, 0x0a, 0x12, 0x69, 0x6e, 0x69, - 0x74, 0x69, 0x61, 0x6c, 0x5f, 0x6f, 0x70, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, - 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x2b, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, - 0x65, 0x64, 0x2e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x76, 0x31, - 0x2e, 0x49, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x6c, 0x4f, 0x70, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, - 0x6e, 0x73, 0x48, 0x00, 0x52, 0x11, 0x69, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x6c, 0x4f, 0x70, 0x65, - 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x4d, 0x0a, 0x0d, 0x74, 0x72, 0x69, 0x61, 0x6c, - 0x5f, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x26, - 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x65, 0x78, 0x70, 0x65, - 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x43, - 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 0x48, 0x00, 0x52, 0x0c, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x43, - 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 0x12, 0x62, 0x0a, 0x14, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, - 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x63, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x64, 0x18, 0x05, - 0x20, 0x01, 0x28, 0x0b, 0x32, 0x2d, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, - 0x64, 0x2e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x76, 0x31, 0x2e, - 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65, - 0x74, 0x65, 0x64, 0x48, 0x00, 0x52, 0x13, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x69, 0x6f, - 0x6e, 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x64, 0x12, 0x4a, 0x0a, 0x0c, 0x74, 0x72, - 0x69, 0x61, 0x6c, 0x5f, 0x63, 0x6c, 0x6f, 0x73, 0x65, 0x64, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, - 0x32, 0x25, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x65, 0x78, - 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x54, 0x72, 0x69, 0x61, - 0x6c, 0x43, 0x6c, 0x6f, 0x73, 0x65, 0x64, 0x48, 0x00, 0x52, 0x0b, 0x74, 0x72, 0x69, 0x61, 0x6c, - 0x43, 0x6c, 0x6f, 0x73, 0x65, 0x64, 0x12, 0x5a, 0x0a, 0x12, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, - 0x65, 0x78, 0x69, 0x74, 0x65, 0x64, 0x5f, 0x65, 0x61, 0x72, 0x6c, 0x79, 0x18, 0x07, 0x20, 0x01, - 0x28, 0x0b, 0x32, 0x2a, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, - 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x54, 0x72, - 0x69, 0x61, 0x6c, 0x45, 0x78, 0x69, 0x74, 0x65, 0x64, 0x45, 0x61, 0x72, 0x6c, 0x79, 0x48, 0x00, - 0x52, 0x10, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x45, 0x78, 0x69, 0x74, 0x65, 0x64, 0x45, 0x61, 0x72, - 0x6c, 0x79, 0x12, 0x50, 0x0a, 0x0e, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x70, 0x72, 0x6f, 0x67, - 0x72, 0x65, 0x73, 0x73, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x27, 0x2e, 0x64, 0x65, 0x74, - 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x50, 0x72, 0x6f, 0x67, 0x72, - 0x65, 0x73, 0x73, 0x48, 0x00, 0x52, 0x0d, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x50, 0x72, 0x6f, 0x67, - 0x72, 0x65, 0x73, 0x73, 0x12, 0x5f, 0x0a, 0x13, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, - 0x6e, 0x74, 0x5f, 0x69, 0x6e, 0x61, 0x63, 0x74, 0x69, 0x76, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, - 0x0b, 0x32, 0x2c, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x65, - 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x45, 0x78, 0x70, - 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x6e, 0x61, 0x63, 0x74, 0x69, 0x76, 0x65, 0x48, - 0x00, 0x52, 0x12, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x6e, 0x61, - 0x63, 0x74, 0x69, 0x76, 0x65, 0x3a, 0x0a, 0x92, 0x41, 0x07, 0x0a, 0x05, 0xd2, 0x01, 0x02, 0x69, - 0x64, 0x42, 0x07, 0x0a, 0x05, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x22, 0x4f, 0x0a, 0x16, 0x56, 0x61, - 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x41, 0x66, 0x74, 0x65, 0x72, 0x4f, 0x70, 0x65, 0x72, 0x61, - 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x1d, 0x0a, 0x0a, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, - 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x49, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x04, 0x52, 0x06, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, 0x22, 0x3a, 0x0a, 0x1c, 0x53, - 0x65, 0x74, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x65, 0x72, 0x50, 0x72, 0x6f, 0x67, 0x72, 0x65, - 0x73, 0x73, 0x4f, 0x70, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x70, - 0x72, 0x6f, 0x67, 0x72, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x01, 0x52, 0x08, 0x70, - 0x72, 0x6f, 0x67, 0x72, 0x65, 0x73, 0x73, 0x22, 0xa3, 0x01, 0x0a, 0x1e, 0x43, 0x6f, 0x6d, 0x70, - 0x6c, 0x65, 0x74, 0x65, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x41, 0x66, 0x74, 0x65, - 0x72, 0x4f, 0x70, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x40, 0x0a, 0x02, 0x6f, 0x70, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x30, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, - 0x6e, 0x65, 0x64, 0x2e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x76, - 0x31, 0x2e, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x41, 0x66, 0x74, 0x65, 0x72, 0x4f, - 0x70, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x02, 0x6f, 0x70, 0x12, 0x3f, 0x0a, 0x0f, - 0x73, 0x65, 0x61, 0x72, 0x63, 0x68, 0x65, 0x72, 0x5f, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x52, 0x0e, 0x73, - 0x65, 0x61, 0x72, 0x63, 0x68, 0x65, 0x72, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x22, 0x57, 0x0a, - 0x14, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x4f, 0x70, 0x65, 0x72, - 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x1d, 0x0a, 0x0a, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, - 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x72, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x49, 0x64, 0x12, 0x20, 0x0a, 0x0b, 0x68, 0x79, 0x70, 0x65, 0x72, 0x70, 0x61, 0x72, - 0x61, 0x6d, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x68, 0x79, 0x70, 0x65, 0x72, - 0x70, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x22, 0x34, 0x0a, 0x13, 0x43, 0x6c, 0x6f, 0x73, 0x65, 0x54, - 0x72, 0x69, 0x61, 0x6c, 0x4f, 0x70, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x1d, 0x0a, - 0x0a, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x09, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x49, 0x64, 0x22, 0x45, 0x0a, 0x11, - 0x53, 0x68, 0x75, 0x74, 0x44, 0x6f, 0x77, 0x6e, 0x4f, 0x70, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, - 0x6e, 0x12, 0x16, 0x0a, 0x06, 0x63, 0x61, 0x6e, 0x63, 0x65, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x06, 0x63, 0x61, 0x6e, 0x63, 0x65, 0x6c, 0x12, 0x18, 0x0a, 0x07, 0x66, 0x61, 0x69, - 0x6c, 0x75, 0x72, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x66, 0x61, 0x69, 0x6c, - 0x75, 0x72, 0x65, 0x22, 0xd2, 0x03, 0x0a, 0x11, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x65, 0x72, - 0x4f, 0x70, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x53, 0x0a, 0x0f, 0x74, 0x72, 0x69, - 0x61, 0x6c, 0x5f, 0x6f, 0x70, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x0b, 0x32, 0x28, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, - 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x54, 0x72, - 0x69, 0x61, 0x6c, 0x4f, 0x70, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x48, 0x00, 0x52, 0x0e, - 0x74, 0x72, 0x69, 0x61, 0x6c, 0x4f, 0x70, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x53, - 0x0a, 0x0c, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x5f, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x0b, 0x32, 0x2e, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, - 0x64, 0x2e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x76, 0x31, 0x2e, - 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x4f, 0x70, 0x65, 0x72, 0x61, - 0x74, 0x69, 0x6f, 0x6e, 0x48, 0x00, 0x52, 0x0b, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x54, 0x72, - 0x69, 0x61, 0x6c, 0x12, 0x50, 0x0a, 0x0b, 0x63, 0x6c, 0x6f, 0x73, 0x65, 0x5f, 0x74, 0x72, 0x69, - 0x61, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x2d, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, - 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x76, 0x31, 0x2e, 0x43, 0x6c, 0x6f, 0x73, 0x65, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x4f, 0x70, - 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x48, 0x00, 0x52, 0x0a, 0x63, 0x6c, 0x6f, 0x73, 0x65, - 0x54, 0x72, 0x69, 0x61, 0x6c, 0x12, 0x4a, 0x0a, 0x09, 0x73, 0x68, 0x75, 0x74, 0x5f, 0x64, 0x6f, - 0x77, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x2b, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, - 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x76, 0x31, 0x2e, 0x53, 0x68, 0x75, 0x74, 0x44, 0x6f, 0x77, 0x6e, 0x4f, 0x70, 0x65, 0x72, - 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x48, 0x00, 0x52, 0x08, 0x73, 0x68, 0x75, 0x74, 0x44, 0x6f, 0x77, - 0x6e, 0x12, 0x6c, 0x0a, 0x15, 0x73, 0x65, 0x74, 0x5f, 0x73, 0x65, 0x61, 0x72, 0x63, 0x68, 0x65, - 0x72, 0x5f, 0x70, 0x72, 0x6f, 0x67, 0x72, 0x65, 0x73, 0x73, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, - 0x32, 0x36, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x65, 0x78, - 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x53, 0x65, 0x74, 0x53, - 0x65, 0x61, 0x72, 0x63, 0x68, 0x65, 0x72, 0x50, 0x72, 0x6f, 0x67, 0x72, 0x65, 0x73, 0x73, 0x4f, - 0x70, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x48, 0x00, 0x52, 0x13, 0x73, 0x65, 0x74, 0x53, - 0x65, 0x61, 0x72, 0x63, 0x68, 0x65, 0x72, 0x50, 0x72, 0x6f, 0x67, 0x72, 0x65, 0x73, 0x73, 0x42, - 0x07, 0x0a, 0x05, 0x75, 0x6e, 0x69, 0x6f, 0x6e, 0x22, 0x74, 0x0a, 0x0e, 0x54, 0x72, 0x69, 0x61, - 0x6c, 0x4f, 0x70, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x59, 0x0a, 0x0e, 0x76, 0x61, - 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x5f, 0x61, 0x66, 0x74, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x0b, 0x32, 0x30, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, - 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x56, 0x61, - 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x41, 0x66, 0x74, 0x65, 0x72, 0x4f, 0x70, 0x65, 0x72, 0x61, - 0x74, 0x69, 0x6f, 0x6e, 0x48, 0x00, 0x52, 0x0d, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, - 0x41, 0x66, 0x74, 0x65, 0x72, 0x42, 0x07, 0x0a, 0x05, 0x75, 0x6e, 0x69, 0x6f, 0x6e, 0x22, 0x67, - 0x0a, 0x11, 0x52, 0x75, 0x6e, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x4f, 0x70, 0x65, 0x72, 0x61, 0x74, - 0x69, 0x6f, 0x6e, 0x12, 0x3a, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x0e, 0x32, 0x26, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x65, - 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x52, 0x75, 0x6e, - 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, - 0x16, 0x0a, 0x06, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, 0x18, 0x02, 0x20, 0x01, 0x28, 0x04, 0x52, - 0x06, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, 0x22, 0x80, 0x01, 0x0a, 0x0f, 0x54, 0x72, 0x69, 0x61, - 0x6c, 0x53, 0x69, 0x6d, 0x75, 0x6c, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x4b, 0x0a, 0x0a, 0x6f, - 0x70, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, - 0x2b, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x65, 0x78, 0x70, - 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x52, 0x75, 0x6e, 0x6e, 0x61, - 0x62, 0x6c, 0x65, 0x4f, 0x70, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x0a, 0x6f, 0x70, - 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x20, 0x0a, 0x0b, 0x6f, 0x63, 0x63, 0x75, - 0x72, 0x72, 0x65, 0x6e, 0x63, 0x65, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0b, 0x6f, - 0x63, 0x63, 0x75, 0x72, 0x72, 0x65, 0x6e, 0x63, 0x65, 0x73, 0x22, 0x9e, 0x01, 0x0a, 0x14, 0x45, - 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x69, 0x6d, 0x75, 0x6c, 0x61, 0x74, - 0x69, 0x6f, 0x6e, 0x12, 0x2f, 0x0a, 0x06, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x53, 0x74, 0x72, 0x75, 0x63, 0x74, 0x52, 0x06, 0x63, 0x6f, - 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x73, 0x65, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, - 0x28, 0x0d, 0x52, 0x04, 0x73, 0x65, 0x65, 0x64, 0x12, 0x41, 0x0a, 0x06, 0x74, 0x72, 0x69, 0x61, - 0x6c, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x29, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, + 0x74, 0x5f, 0x73, 0x74, 0x61, 0x74, 0x65, 0x22, 0x4f, 0x0a, 0x16, 0x56, 0x61, 0x6c, 0x69, 0x64, + 0x61, 0x74, 0x65, 0x41, 0x66, 0x74, 0x65, 0x72, 0x4f, 0x70, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, + 0x6e, 0x12, 0x1d, 0x0a, 0x0a, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x69, 0x64, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x49, 0x64, + 0x12, 0x16, 0x0a, 0x06, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, 0x18, 0x02, 0x20, 0x01, 0x28, 0x04, + 0x52, 0x06, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, 0x22, 0x57, 0x0a, 0x14, 0x43, 0x72, 0x65, 0x61, + 0x74, 0x65, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x4f, 0x70, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x12, 0x1d, 0x0a, 0x0a, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x49, 0x64, 0x12, + 0x20, 0x0a, 0x0b, 0x68, 0x79, 0x70, 0x65, 0x72, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x68, 0x79, 0x70, 0x65, 0x72, 0x70, 0x61, 0x72, 0x61, 0x6d, + 0x73, 0x22, 0x34, 0x0a, 0x13, 0x43, 0x6c, 0x6f, 0x73, 0x65, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x4f, + 0x70, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x1d, 0x0a, 0x0a, 0x72, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x72, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x49, 0x64, 0x22, 0x45, 0x0a, 0x11, 0x53, 0x68, 0x75, 0x74, 0x44, + 0x6f, 0x77, 0x6e, 0x4f, 0x70, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x16, 0x0a, 0x06, + 0x63, 0x61, 0x6e, 0x63, 0x65, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x06, 0x63, 0x61, + 0x6e, 0x63, 0x65, 0x6c, 0x12, 0x18, 0x0a, 0x07, 0x66, 0x61, 0x69, 0x6c, 0x75, 0x72, 0x65, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x66, 0x61, 0x69, 0x6c, 0x75, 0x72, 0x65, 0x22, 0xe4, + 0x02, 0x0a, 0x11, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x65, 0x72, 0x4f, 0x70, 0x65, 0x72, 0x61, + 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x53, 0x0a, 0x0f, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x5f, 0x6f, 0x70, + 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x28, 0x2e, + 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x65, 0x78, 0x70, 0x65, 0x72, + 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x4f, 0x70, + 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x48, 0x00, 0x52, 0x0e, 0x74, 0x72, 0x69, 0x61, 0x6c, + 0x4f, 0x70, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x53, 0x0a, 0x0c, 0x63, 0x72, 0x65, + 0x61, 0x74, 0x65, 0x5f, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x2e, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x65, 0x78, 0x70, + 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, + 0x65, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x4f, 0x70, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x48, + 0x00, 0x52, 0x0b, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x12, 0x50, + 0x0a, 0x0b, 0x63, 0x6c, 0x6f, 0x73, 0x65, 0x5f, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x2d, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, + 0x2e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x43, + 0x6c, 0x6f, 0x73, 0x65, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x4f, 0x70, 0x65, 0x72, 0x61, 0x74, 0x69, + 0x6f, 0x6e, 0x48, 0x00, 0x52, 0x0a, 0x63, 0x6c, 0x6f, 0x73, 0x65, 0x54, 0x72, 0x69, 0x61, 0x6c, + 0x12, 0x4a, 0x0a, 0x09, 0x73, 0x68, 0x75, 0x74, 0x5f, 0x64, 0x6f, 0x77, 0x6e, 0x18, 0x04, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x2b, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, + 0x2e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x53, + 0x68, 0x75, 0x74, 0x44, 0x6f, 0x77, 0x6e, 0x4f, 0x70, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x48, 0x00, 0x52, 0x08, 0x73, 0x68, 0x75, 0x74, 0x44, 0x6f, 0x77, 0x6e, 0x42, 0x07, 0x0a, 0x05, + 0x75, 0x6e, 0x69, 0x6f, 0x6e, 0x22, 0x74, 0x0a, 0x0e, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x4f, 0x70, + 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x59, 0x0a, 0x0e, 0x76, 0x61, 0x6c, 0x69, 0x64, + 0x61, 0x74, 0x65, 0x5f, 0x61, 0x66, 0x74, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x30, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x65, 0x78, 0x70, + 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x56, 0x61, 0x6c, 0x69, 0x64, + 0x61, 0x74, 0x65, 0x41, 0x66, 0x74, 0x65, 0x72, 0x4f, 0x70, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, + 0x6e, 0x48, 0x00, 0x52, 0x0d, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x41, 0x66, 0x74, + 0x65, 0x72, 0x42, 0x07, 0x0a, 0x05, 0x75, 0x6e, 0x69, 0x6f, 0x6e, 0x22, 0x86, 0x01, 0x0a, 0x0a, + 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x55, 0x6e, 0x69, 0x74, 0x12, 0x17, 0x0a, 0x04, 0x6e, 0x61, + 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, + 0x88, 0x01, 0x01, 0x12, 0x19, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x05, 0x48, 0x01, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x88, 0x01, 0x01, 0x12, 0x1d, + 0x0a, 0x0a, 0x6d, 0x61, 0x78, 0x5f, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, 0x18, 0x03, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x09, 0x6d, 0x61, 0x78, 0x4c, 0x65, 0x6e, 0x67, 0x74, 0x68, 0x3a, 0x12, 0x92, + 0x41, 0x0f, 0x0a, 0x0d, 0xd2, 0x01, 0x0a, 0x6d, 0x61, 0x78, 0x5f, 0x6c, 0x65, 0x6e, 0x67, 0x74, + 0x68, 0x42, 0x07, 0x0a, 0x05, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x42, 0x08, 0x0a, 0x06, 0x5f, 0x76, + 0x61, 0x6c, 0x75, 0x65, 0x22, 0x74, 0x0a, 0x0c, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x53, 0x75, 0x6d, + 0x6d, 0x61, 0x72, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x05, 0x52, 0x05, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x12, 0x38, 0x0a, 0x04, 0x75, 0x6e, + 0x69, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x24, 0x2e, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x76, 0x31, 0x2e, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x53, 0x69, 0x6d, 0x75, 0x6c, 0x61, 0x74, - 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x73, 0x2a, 0x62, 0x0a, 0x0c, 0x52, - 0x75, 0x6e, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x1d, 0x0a, 0x19, 0x52, - 0x55, 0x4e, 0x4e, 0x41, 0x42, 0x4c, 0x45, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x55, 0x4e, 0x53, - 0x50, 0x45, 0x43, 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x00, 0x12, 0x17, 0x0a, 0x13, 0x52, 0x55, - 0x4e, 0x4e, 0x41, 0x42, 0x4c, 0x45, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x54, 0x52, 0x41, 0x49, - 0x4e, 0x10, 0x01, 0x12, 0x1a, 0x0a, 0x16, 0x52, 0x55, 0x4e, 0x4e, 0x41, 0x42, 0x4c, 0x45, 0x5f, - 0x54, 0x59, 0x50, 0x45, 0x5f, 0x56, 0x41, 0x4c, 0x49, 0x44, 0x41, 0x54, 0x45, 0x10, 0x02, 0x42, - 0x3c, 0x5a, 0x3a, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x64, 0x65, - 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2d, 0x61, 0x69, 0x2f, 0x64, 0x65, 0x74, 0x65, - 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x70, 0x6b, 0x67, - 0x2f, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x76, 0x31, 0x62, 0x06, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x2e, 0x76, 0x31, 0x2e, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x55, 0x6e, 0x69, 0x74, 0x52, 0x04, + 0x75, 0x6e, 0x69, 0x74, 0x3a, 0x14, 0x92, 0x41, 0x11, 0x0a, 0x0f, 0xd2, 0x01, 0x05, 0x63, 0x6f, + 0x75, 0x6e, 0x74, 0xd2, 0x01, 0x04, 0x75, 0x6e, 0x69, 0x74, 0x22, 0x97, 0x01, 0x0a, 0x0d, 0x53, + 0x65, 0x61, 0x72, 0x63, 0x68, 0x53, 0x75, 0x6d, 0x6d, 0x61, 0x72, 0x79, 0x12, 0x2f, 0x0a, 0x06, + 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x67, + 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x53, + 0x74, 0x72, 0x75, 0x63, 0x74, 0x52, 0x06, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x3e, 0x0a, + 0x06, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x26, 0x2e, + 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2e, 0x65, 0x78, 0x70, 0x65, 0x72, + 0x69, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x53, 0x75, + 0x6d, 0x6d, 0x61, 0x72, 0x79, 0x52, 0x06, 0x74, 0x72, 0x69, 0x61, 0x6c, 0x73, 0x3a, 0x15, 0x92, + 0x41, 0x12, 0x0a, 0x10, 0xd2, 0x01, 0x06, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0xd2, 0x01, 0x04, + 0x72, 0x75, 0x6e, 0x73, 0x42, 0x3c, 0x5a, 0x3a, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, + 0x6f, 0x6d, 0x2f, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2d, 0x61, 0x69, + 0x2f, 0x64, 0x65, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6e, 0x65, 0x64, 0x2f, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x6d, 0x65, 0x6e, 0x74, + 0x76, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -1650,62 +1229,47 @@ func file_determined_experiment_v1_searcher_proto_rawDescGZIP() []byte { return file_determined_experiment_v1_searcher_proto_rawDescData } -var file_determined_experiment_v1_searcher_proto_enumTypes = make([]protoimpl.EnumInfo, 2) -var file_determined_experiment_v1_searcher_proto_msgTypes = make([]protoimpl.MessageInfo, 19) +var file_determined_experiment_v1_searcher_proto_enumTypes = make([]protoimpl.EnumInfo, 1) +var file_determined_experiment_v1_searcher_proto_msgTypes = make([]protoimpl.MessageInfo, 16) var file_determined_experiment_v1_searcher_proto_goTypes = []interface{}{ - (RunnableType)(0), // 0: determined.experiment.v1.RunnableType - (TrialExitedEarly_ExitedReason)(0), // 1: determined.experiment.v1.TrialExitedEarly.ExitedReason - (*InitialOperations)(nil), // 2: determined.experiment.v1.InitialOperations - (*TrialCreated)(nil), // 3: determined.experiment.v1.TrialCreated - (*TrialProgress)(nil), // 4: determined.experiment.v1.TrialProgress - (*ValidationCompleted)(nil), // 5: determined.experiment.v1.ValidationCompleted - (*TrialClosed)(nil), // 6: determined.experiment.v1.TrialClosed - (*TrialExitedEarly)(nil), // 7: determined.experiment.v1.TrialExitedEarly - (*ExperimentInactive)(nil), // 8: determined.experiment.v1.ExperimentInactive - (*SearcherEvent)(nil), // 9: determined.experiment.v1.SearcherEvent - (*ValidateAfterOperation)(nil), // 10: determined.experiment.v1.ValidateAfterOperation - (*SetSearcherProgressOperation)(nil), // 11: determined.experiment.v1.SetSearcherProgressOperation - (*CompleteValidateAfterOperation)(nil), // 12: determined.experiment.v1.CompleteValidateAfterOperation - (*CreateTrialOperation)(nil), // 13: determined.experiment.v1.CreateTrialOperation - (*CloseTrialOperation)(nil), // 14: determined.experiment.v1.CloseTrialOperation - (*ShutDownOperation)(nil), // 15: determined.experiment.v1.ShutDownOperation - (*SearcherOperation)(nil), // 16: determined.experiment.v1.SearcherOperation - (*TrialOperation)(nil), // 17: determined.experiment.v1.TrialOperation - (*RunnableOperation)(nil), // 18: determined.experiment.v1.RunnableOperation - (*TrialSimulation)(nil), // 19: determined.experiment.v1.TrialSimulation - (*ExperimentSimulation)(nil), // 20: determined.experiment.v1.ExperimentSimulation - (*_struct.Value)(nil), // 21: google.protobuf.Value - (State)(0), // 22: determined.experiment.v1.State - (*_struct.Struct)(nil), // 23: google.protobuf.Struct + (TrialExitedEarly_ExitedReason)(0), // 0: determined.experiment.v1.TrialExitedEarly.ExitedReason + (*InitialOperations)(nil), // 1: determined.experiment.v1.InitialOperations + (*TrialCreated)(nil), // 2: determined.experiment.v1.TrialCreated + (*TrialProgress)(nil), // 3: determined.experiment.v1.TrialProgress + (*ValidationCompleted)(nil), // 4: determined.experiment.v1.ValidationCompleted + (*TrialClosed)(nil), // 5: determined.experiment.v1.TrialClosed + (*TrialExitedEarly)(nil), // 6: determined.experiment.v1.TrialExitedEarly + (*ExperimentInactive)(nil), // 7: determined.experiment.v1.ExperimentInactive + (*ValidateAfterOperation)(nil), // 8: determined.experiment.v1.ValidateAfterOperation + (*CreateTrialOperation)(nil), // 9: determined.experiment.v1.CreateTrialOperation + (*CloseTrialOperation)(nil), // 10: determined.experiment.v1.CloseTrialOperation + (*ShutDownOperation)(nil), // 11: determined.experiment.v1.ShutDownOperation + (*SearcherOperation)(nil), // 12: determined.experiment.v1.SearcherOperation + (*TrialOperation)(nil), // 13: determined.experiment.v1.TrialOperation + (*SearchUnit)(nil), // 14: determined.experiment.v1.SearchUnit + (*TrialSummary)(nil), // 15: determined.experiment.v1.TrialSummary + (*SearchSummary)(nil), // 16: determined.experiment.v1.SearchSummary + (*_struct.Value)(nil), // 17: google.protobuf.Value + (State)(0), // 18: determined.experiment.v1.State + (*_struct.Struct)(nil), // 19: google.protobuf.Struct } var file_determined_experiment_v1_searcher_proto_depIdxs = []int32{ - 21, // 0: determined.experiment.v1.ValidationCompleted.metric:type_name -> google.protobuf.Value - 1, // 1: determined.experiment.v1.TrialExitedEarly.exited_reason:type_name -> determined.experiment.v1.TrialExitedEarly.ExitedReason - 22, // 2: determined.experiment.v1.ExperimentInactive.experiment_state:type_name -> determined.experiment.v1.State - 2, // 3: determined.experiment.v1.SearcherEvent.initial_operations:type_name -> determined.experiment.v1.InitialOperations - 3, // 4: determined.experiment.v1.SearcherEvent.trial_created:type_name -> determined.experiment.v1.TrialCreated - 5, // 5: determined.experiment.v1.SearcherEvent.validation_completed:type_name -> determined.experiment.v1.ValidationCompleted - 6, // 6: determined.experiment.v1.SearcherEvent.trial_closed:type_name -> determined.experiment.v1.TrialClosed - 7, // 7: determined.experiment.v1.SearcherEvent.trial_exited_early:type_name -> determined.experiment.v1.TrialExitedEarly - 4, // 8: determined.experiment.v1.SearcherEvent.trial_progress:type_name -> determined.experiment.v1.TrialProgress - 8, // 9: determined.experiment.v1.SearcherEvent.experiment_inactive:type_name -> determined.experiment.v1.ExperimentInactive - 10, // 10: determined.experiment.v1.CompleteValidateAfterOperation.op:type_name -> determined.experiment.v1.ValidateAfterOperation - 21, // 11: determined.experiment.v1.CompleteValidateAfterOperation.searcher_metric:type_name -> google.protobuf.Value - 17, // 12: determined.experiment.v1.SearcherOperation.trial_operation:type_name -> determined.experiment.v1.TrialOperation - 13, // 13: determined.experiment.v1.SearcherOperation.create_trial:type_name -> determined.experiment.v1.CreateTrialOperation - 14, // 14: determined.experiment.v1.SearcherOperation.close_trial:type_name -> determined.experiment.v1.CloseTrialOperation - 15, // 15: determined.experiment.v1.SearcherOperation.shut_down:type_name -> determined.experiment.v1.ShutDownOperation - 11, // 16: determined.experiment.v1.SearcherOperation.set_searcher_progress:type_name -> determined.experiment.v1.SetSearcherProgressOperation - 10, // 17: determined.experiment.v1.TrialOperation.validate_after:type_name -> determined.experiment.v1.ValidateAfterOperation - 0, // 18: determined.experiment.v1.RunnableOperation.type:type_name -> determined.experiment.v1.RunnableType - 18, // 19: determined.experiment.v1.TrialSimulation.operations:type_name -> determined.experiment.v1.RunnableOperation - 23, // 20: determined.experiment.v1.ExperimentSimulation.config:type_name -> google.protobuf.Struct - 19, // 21: determined.experiment.v1.ExperimentSimulation.trials:type_name -> determined.experiment.v1.TrialSimulation - 22, // [22:22] is the sub-list for method output_type - 22, // [22:22] is the sub-list for method input_type - 22, // [22:22] is the sub-list for extension type_name - 22, // [22:22] is the sub-list for extension extendee - 0, // [0:22] is the sub-list for field type_name + 17, // 0: determined.experiment.v1.ValidationCompleted.metric:type_name -> google.protobuf.Value + 0, // 1: determined.experiment.v1.TrialExitedEarly.exited_reason:type_name -> determined.experiment.v1.TrialExitedEarly.ExitedReason + 18, // 2: determined.experiment.v1.ExperimentInactive.experiment_state:type_name -> determined.experiment.v1.State + 13, // 3: determined.experiment.v1.SearcherOperation.trial_operation:type_name -> determined.experiment.v1.TrialOperation + 9, // 4: determined.experiment.v1.SearcherOperation.create_trial:type_name -> determined.experiment.v1.CreateTrialOperation + 10, // 5: determined.experiment.v1.SearcherOperation.close_trial:type_name -> determined.experiment.v1.CloseTrialOperation + 11, // 6: determined.experiment.v1.SearcherOperation.shut_down:type_name -> determined.experiment.v1.ShutDownOperation + 8, // 7: determined.experiment.v1.TrialOperation.validate_after:type_name -> determined.experiment.v1.ValidateAfterOperation + 14, // 8: determined.experiment.v1.TrialSummary.unit:type_name -> determined.experiment.v1.SearchUnit + 19, // 9: determined.experiment.v1.SearchSummary.config:type_name -> google.protobuf.Struct + 15, // 10: determined.experiment.v1.SearchSummary.trials:type_name -> determined.experiment.v1.TrialSummary + 11, // [11:11] is the sub-list for method output_type + 11, // [11:11] is the sub-list for method input_type + 11, // [11:11] is the sub-list for extension type_name + 11, // [11:11] is the sub-list for extension extendee + 0, // [0:11] is the sub-list for field type_name } func init() { file_determined_experiment_v1_searcher_proto_init() } @@ -1800,7 +1364,7 @@ func file_determined_experiment_v1_searcher_proto_init() { } } file_determined_experiment_v1_searcher_proto_msgTypes[7].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SearcherEvent); i { + switch v := v.(*ValidateAfterOperation); i { case 0: return &v.state case 1: @@ -1812,7 +1376,7 @@ func file_determined_experiment_v1_searcher_proto_init() { } } file_determined_experiment_v1_searcher_proto_msgTypes[8].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ValidateAfterOperation); i { + switch v := v.(*CreateTrialOperation); i { case 0: return &v.state case 1: @@ -1824,7 +1388,7 @@ func file_determined_experiment_v1_searcher_proto_init() { } } file_determined_experiment_v1_searcher_proto_msgTypes[9].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SetSearcherProgressOperation); i { + switch v := v.(*CloseTrialOperation); i { case 0: return &v.state case 1: @@ -1836,7 +1400,7 @@ func file_determined_experiment_v1_searcher_proto_init() { } } file_determined_experiment_v1_searcher_proto_msgTypes[10].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*CompleteValidateAfterOperation); i { + switch v := v.(*ShutDownOperation); i { case 0: return &v.state case 1: @@ -1848,7 +1412,7 @@ func file_determined_experiment_v1_searcher_proto_init() { } } file_determined_experiment_v1_searcher_proto_msgTypes[11].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*CreateTrialOperation); i { + switch v := v.(*SearcherOperation); i { case 0: return &v.state case 1: @@ -1860,7 +1424,7 @@ func file_determined_experiment_v1_searcher_proto_init() { } } file_determined_experiment_v1_searcher_proto_msgTypes[12].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*CloseTrialOperation); i { + switch v := v.(*TrialOperation); i { case 0: return &v.state case 1: @@ -1872,7 +1436,7 @@ func file_determined_experiment_v1_searcher_proto_init() { } } file_determined_experiment_v1_searcher_proto_msgTypes[13].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ShutDownOperation); i { + switch v := v.(*SearchUnit); i { case 0: return &v.state case 1: @@ -1884,7 +1448,7 @@ func file_determined_experiment_v1_searcher_proto_init() { } } file_determined_experiment_v1_searcher_proto_msgTypes[14].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SearcherOperation); i { + switch v := v.(*TrialSummary); i { case 0: return &v.state case 1: @@ -1896,19 +1460,7 @@ func file_determined_experiment_v1_searcher_proto_init() { } } file_determined_experiment_v1_searcher_proto_msgTypes[15].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*TrialOperation); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_determined_experiment_v1_searcher_proto_msgTypes[16].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*RunnableOperation); i { + switch v := v.(*SearchSummary); i { case 0: return &v.state case 1: @@ -1919,57 +1471,24 @@ func file_determined_experiment_v1_searcher_proto_init() { return nil } } - file_determined_experiment_v1_searcher_proto_msgTypes[17].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*TrialSimulation); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_determined_experiment_v1_searcher_proto_msgTypes[18].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ExperimentSimulation); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - } - file_determined_experiment_v1_searcher_proto_msgTypes[7].OneofWrappers = []interface{}{ - (*SearcherEvent_InitialOperations)(nil), - (*SearcherEvent_TrialCreated)(nil), - (*SearcherEvent_ValidationCompleted)(nil), - (*SearcherEvent_TrialClosed)(nil), - (*SearcherEvent_TrialExitedEarly)(nil), - (*SearcherEvent_TrialProgress)(nil), - (*SearcherEvent_ExperimentInactive)(nil), } - file_determined_experiment_v1_searcher_proto_msgTypes[14].OneofWrappers = []interface{}{ + file_determined_experiment_v1_searcher_proto_msgTypes[11].OneofWrappers = []interface{}{ (*SearcherOperation_TrialOperation)(nil), (*SearcherOperation_CreateTrial)(nil), (*SearcherOperation_CloseTrial)(nil), (*SearcherOperation_ShutDown)(nil), - (*SearcherOperation_SetSearcherProgress)(nil), } - file_determined_experiment_v1_searcher_proto_msgTypes[15].OneofWrappers = []interface{}{ + file_determined_experiment_v1_searcher_proto_msgTypes[12].OneofWrappers = []interface{}{ (*TrialOperation_ValidateAfter)(nil), } + file_determined_experiment_v1_searcher_proto_msgTypes[13].OneofWrappers = []interface{}{} type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_determined_experiment_v1_searcher_proto_rawDesc, - NumEnums: 2, - NumMessages: 19, + NumEnums: 1, + NumMessages: 16, NumExtensions: 0, NumServices: 0, }, diff --git a/proto/src/determined/api/v1/api.proto b/proto/src/determined/api/v1/api.proto index d5ec7f4134f..69e5008ae9c 100644 --- a/proto/src/determined/api/v1/api.proto +++ b/proto/src/determined/api/v1/api.proto @@ -1209,28 +1209,6 @@ service Determined { }; } - // Get the current searcher operation. - rpc GetCurrentTrialSearcherOperation(GetCurrentTrialSearcherOperationRequest) - returns (GetCurrentTrialSearcherOperationResponse) { - option (google.api.http) = { - get: "/api/v1/trials/{trial_id}/searcher/operation" - }; - option (grpc.gateway.protoc_gen_swagger.options.openapiv2_operation) = { - tags: "Internal" - }; - } - // Reports to the searcher that the trial has completed the given searcher - // operation. - rpc CompleteTrialSearcherValidation(CompleteTrialSearcherValidationRequest) - returns (CompleteTrialSearcherValidationResponse) { - option (google.api.http) = { - post: "/api/v1/trials/{trial_id}/searcher/completed_operation" - body: "completed_operation" - }; - option (grpc.gateway.protoc_gen_swagger.options.openapiv2_operation) = { - tags: "Internal" - }; - } // Reports to the searcher that the trial has completed the current // requested amount of training with the given searcher validation // metric. @@ -1912,29 +1890,6 @@ service Determined { }; } - // Get the list of custom searcher events with long polling. - rpc GetSearcherEvents(GetSearcherEventsRequest) - returns (GetSearcherEventsResponse) { - option (google.api.http) = { - get: "/api/v1/experiments/{experiment_id}/searcher_events" - }; - option (grpc.gateway.protoc_gen_swagger.options.openapiv2_operation) = { - tags: "Experiments" - }; - } - - // Submit operations to a custom searcher. - rpc PostSearcherOperations(PostSearcherOperationsRequest) - returns (PostSearcherOperationsResponse) { - option (google.api.http) = { - post: "/api/v1/experiments/{experiment_id}/searcher_operations" - body: "*" - }; - option (grpc.gateway.protoc_gen_swagger.options.openapiv2_operation) = { - tags: "Experiments" - }; - } - // Get the set of metric names recorded for a list of experiments. rpc ExpMetricNames(ExpMetricNamesRequest) returns (stream ExpMetricNamesResponse) { diff --git a/proto/src/determined/api/v1/experiment.proto b/proto/src/determined/api/v1/experiment.proto index 0b74b2ddfac..6c536ec4c4b 100644 --- a/proto/src/determined/api/v1/experiment.proto +++ b/proto/src/determined/api/v1/experiment.proto @@ -285,8 +285,8 @@ message PreviewHPSearchRequest { } // Response to PreviewSearchRequest. message PreviewHPSearchResponse { - // The resulting simulation. - determined.experiment.v1.ExperimentSimulation simulation = 1; + // The resulting summary. + determined.experiment.v1.SearchSummary summary = 1; } // Activate an experiment. @@ -913,31 +913,6 @@ message GetModelDefFileResponse { bytes file = 1; } -// Request to get the list of searcher events. -message GetSearcherEventsRequest { - // The ID of the experiment. - int32 experiment_id = 1; -} - -// Response to GetSearcherEventsRequest. -message GetSearcherEventsResponse { - // The list of events in the queue. - repeated determined.experiment.v1.SearcherEvent searcher_events = 1; -} - -// Request for sending operations from a custom search method. -message PostSearcherOperationsRequest { - // The experiment ID. - int32 experiment_id = 1; - // List of operations to submit. - repeated determined.experiment.v1.SearcherOperation searcher_operations = 2; - // The event that triggered the client to send these operations to the master. - determined.experiment.v1.SearcherEvent triggered_by_event = 3; -} - -// Response to PostSearcherOperationsResponse. -message PostSearcherOperationsResponse {} - // Request for searching experiments message SearchExperimentsRequest { // ID of the project to look at diff --git a/proto/src/determined/api/v1/trial.proto b/proto/src/determined/api/v1/trial.proto index 199fcd4679b..84b96588e73 100644 --- a/proto/src/determined/api/v1/trial.proto +++ b/proto/src/determined/api/v1/trial.proto @@ -651,37 +651,6 @@ message NotifyContainerRunningResponse { repeated google.protobuf.Struct data = 1; } -// Retrieves the current searcher operation. -message GetCurrentTrialSearcherOperationRequest { - option (grpc.gateway.protoc_gen_swagger.options.openapiv2_schema) = { - json_schema: { required: [ "trial_id" ] } - }; - // The id of the trial. - int32 trial_id = 1; -} -// Response to GetCurrentTrialSearcherOperationRequest -message GetCurrentTrialSearcherOperationResponse { - // The current searcher operation. - determined.experiment.v1.TrialOperation op = 1; - // The status of the searcher operation. - bool completed = 2; -} - -// Reports to the searcher that the trial has completed the current requested -// amount of training. -message CompleteTrialSearcherValidationRequest { - option (grpc.gateway.protoc_gen_swagger.options.openapiv2_schema) = { - json_schema: { required: [ "trial_id", "searcher_metric" ] } - }; - // The id of the trial. - int32 trial_id = 1; - // The completed operation. - determined.experiment.v1.CompleteValidateAfterOperation completed_operation = - 2; -} -// Response to CompleteTrialSearcherValidationRequest -message CompleteTrialSearcherValidationResponse {} - // Report a voluntary, permanent early exit to the searcher. message ReportTrialSearcherEarlyExitRequest { option (grpc.gateway.protoc_gen_swagger.options.openapiv2_schema) = { diff --git a/proto/src/determined/experiment/v1/searcher.proto b/proto/src/determined/experiment/v1/searcher.proto index f75f17debc3..7fb0b39f03e 100644 --- a/proto/src/determined/experiment/v1/searcher.proto +++ b/proto/src/determined/experiment/v1/searcher.proto @@ -96,33 +96,6 @@ message ExperimentInactive { State experiment_state = 1; } -// SearcherEvent is a message from master to a client-driven custom searcher -// informing it of relevant changes in the state of an experiment. -message SearcherEvent { - option (grpc.gateway.protoc_gen_swagger.options.openapiv2_schema) = { - json_schema: { required: [ "id" ] } - }; - // Incremental ID of the event. - int32 id = 1; - // The concrete event. - oneof event { - // An experiment has just been created. - InitialOperations initial_operations = 3; - // A trial has been created. - TrialCreated trial_created = 4; - // Validation has completed. - ValidationCompleted validation_completed = 5; - // Trial has finished. - TrialClosed trial_closed = 6; - // Trial exited early. - TrialExitedEarly trial_exited_early = 7; - // Trial progress. - TrialProgress trial_progress = 8; - // Experiment is inactive. - ExperimentInactive experiment_inactive = 9; - } -} - // ValidateAfterOperation means the trial should train and validate after // training the given length. message ValidateAfterOperation { @@ -132,22 +105,6 @@ message ValidateAfterOperation { uint64 length = 2; } -// SetSearcherProgressOperation informs the master of the progress of the custom -// searcher. -message SetSearcherProgressOperation { - // Experiment progress as a float between 0.0 and 1.0. - double progress = 1; -} - -// Used to complete a ValidateAfterOperation. -message CompleteValidateAfterOperation { - // The ValidateAfterOperation being completed. - ValidateAfterOperation op = 1; - // The value of searcher metric associated with this completed operation. - // The metric provided should be the metric used to guide HP search. - google.protobuf.Value searcher_metric = 2; -} - // Create a trial with given hyperparameters. message CreateTrialOperation { // The ID of the trial to create. @@ -182,9 +139,6 @@ message SearcherOperation { CloseTrialOperation close_trial = 3; // ShutDownOperation is issued to shut down the custom search method. ShutDownOperation shut_down = 4; - // SetSearcherProgressOperation is issued to set the progress of the custom - // search method. - SetSearcherProgressOperation set_searcher_progress = 5; } } @@ -198,43 +152,39 @@ message TrialOperation { } } -// RunnableType defines the type of operation that should be executed by trial -// runners. -enum RunnableType { - // Denotes an unknown runnable type. - RUNNABLE_TYPE_UNSPECIFIED = 0; - // Signals to a trial runner that it should run a train. - RUNNABLE_TYPE_TRAIN = 1; - // Signals to a trial runner it should compute validation metrics. - RUNNABLE_TYPE_VALIDATE = 2; -} - -// RunnableOperation represents a single runnable operation emitted by a -// searcher. -message RunnableOperation { - // This is the type of the operation. - RunnableType type = 1; - // If the type == WORKLOAD_KIND_TRAIN, this is the number of units - uint64 length = 2; +// SearchUnit describes a length unit used by some searchers to manage training. +message SearchUnit { + option (grpc.gateway.protoc_gen_swagger.options.openapiv2_schema) = { + json_schema: { required: [ "max_length" ] } + }; + // Name of the length unit (if max_length is false). + optional string name = 1; + // Value of the length unit (if max_length is false). + optional int32 value = 2; + // Bool indicating whether the training length is defined in code. + bool max_length = 3; } -// TrialSimulation is a specific sequence of workloads that were run before the -// trial was completed. -message TrialSimulation { - // The list of operations that were run before the trial was completed. - repeated RunnableOperation operations = 1; - // The number of times that this trial configuration has occurred during the - // simulation. - int32 occurrences = 2; +// TrialSummary describes the runs that are estimated to train for a certain +// length. +message TrialSummary { + option (grpc.gateway.protoc_gen_swagger.options.openapiv2_schema) = { + json_schema: { required: [ "count", "unit" ] } + }; + // Number of trials. + int32 count = 1; + // Training length for the trials. + SearchUnit unit = 2; } -// ExperimentSimulation holds the configuration and results of simulated run of -// a searcher. -message ExperimentSimulation { - // The simulated experiment config. +// SearchSummary contains the estimated trials and training lengths that a +// search plans to execute. +message SearchSummary { + option (grpc.gateway.protoc_gen_swagger.options.openapiv2_schema) = { + json_schema: { required: [ "config", "runs" ] } + }; + // The searcher config from which the summary is generated. google.protobuf.Struct config = 1; - // The searcher simulation seed. - uint32 seed = 2; - // The list of trials in the simulation. - repeated TrialSimulation trials = 3; + // A list of planned number of trials to their training lengths. + repeated TrialSummary trials = 2; } diff --git a/schemas/expconf/v0/searcher-adaptive-asha.json b/schemas/expconf/v0/searcher-adaptive-asha.json index ad3417bceca..fe7c020b4d3 100644 --- a/schemas/expconf/v0/searcher-adaptive-asha.json +++ b/schemas/expconf/v0/searcher-adaptive-asha.json @@ -8,7 +8,6 @@ "name" ], "eventuallyRequired": [ - "max_length", "max_trials", "metric" ], @@ -34,6 +33,20 @@ "default": null, "minimum": 1 }, + "time_metric": { + "type": [ + "string", + "null" + ], + "default": null + }, + "max_time": { + "type": [ + "integer", + "null" + ], + "default": null + }, "mode": { "enum": [ null, @@ -81,7 +94,7 @@ "boolean", "null" ], - "default": false + "default": null }, "metric": { "type": [ diff --git a/schemas/expconf/v0/searcher-adaptive-simple.json b/schemas/expconf/v0/searcher-adaptive-simple.json index 458c41ab656..9a5aab58fb5 100644 --- a/schemas/expconf/v0/searcher-adaptive-simple.json +++ b/schemas/expconf/v0/searcher-adaptive-simple.json @@ -10,7 +10,6 @@ ], "eventuallyRequired": [ "max_trials", - "max_length", "metric" ], "properties": { diff --git a/schemas/expconf/v0/searcher-adaptive.json b/schemas/expconf/v0/searcher-adaptive.json index 4365330892d..123d335b07c 100644 --- a/schemas/expconf/v0/searcher-adaptive.json +++ b/schemas/expconf/v0/searcher-adaptive.json @@ -10,7 +10,6 @@ ], "eventuallyRequired": [ "budget", - "max_length", "metric" ], "properties": { diff --git a/schemas/expconf/v0/searcher-async-halving.json b/schemas/expconf/v0/searcher-async-halving.json index 86dcfaa5843..43bb00ffd58 100644 --- a/schemas/expconf/v0/searcher-async-halving.json +++ b/schemas/expconf/v0/searcher-async-halving.json @@ -9,7 +9,6 @@ ], "eventuallyRequired": [ "num_rungs", - "max_length", "max_trials", "metric" ], @@ -62,7 +61,7 @@ "boolean", "null" ], - "default": false + "default": null }, "metric": { "type": [ @@ -71,6 +70,20 @@ ], "default": null }, + "time_metric": { + "type": [ + "string", + "null" + ], + "default": null + }, + "max_time": { + "type": [ + "integer", + "null" + ], + "default": null + }, "smaller_is_better": { "type": [ "boolean", diff --git a/schemas/expconf/v0/searcher-custom.json b/schemas/expconf/v0/searcher-custom.json index 47dcb1efd3a..1b724ab2d19 100644 --- a/schemas/expconf/v0/searcher-custom.json +++ b/schemas/expconf/v0/searcher-custom.json @@ -1,5 +1,6 @@ { "$schema": "http://json-schema.org/draft-07/schema#", + "$comment": "this is an EOL searcher, not to be used in new experiments", "$id": "http://determined.ai/schemas/expconf/v0/searcher-custom.json", "title": "CustomConfig", "type": "object", diff --git a/schemas/expconf/v0/searcher-grid.json b/schemas/expconf/v0/searcher-grid.json index 374f2830d53..b97363a55ee 100644 --- a/schemas/expconf/v0/searcher-grid.json +++ b/schemas/expconf/v0/searcher-grid.json @@ -8,7 +8,6 @@ "name" ], "eventuallyRequired": [ - "max_length", "metric" ], "properties": { diff --git a/schemas/expconf/v0/searcher-random.json b/schemas/expconf/v0/searcher-random.json index 063797e9160..f3498ac7692 100644 --- a/schemas/expconf/v0/searcher-random.json +++ b/schemas/expconf/v0/searcher-random.json @@ -9,7 +9,6 @@ ], "eventuallyRequired": [ "max_trials", - "max_length", "metric" ], "properties": { diff --git a/schemas/expconf/v0/searcher-single.json b/schemas/expconf/v0/searcher-single.json index e3630f94d17..60f328d251b 100644 --- a/schemas/expconf/v0/searcher-single.json +++ b/schemas/expconf/v0/searcher-single.json @@ -8,7 +8,6 @@ "name" ], "eventuallyRequired": [ - "max_length", "metric" ], "properties": { diff --git a/schemas/expconf/v0/searcher-sync-halving.json b/schemas/expconf/v0/searcher-sync-halving.json index 8c80766b84a..0cd7b4d871c 100644 --- a/schemas/expconf/v0/searcher-sync-halving.json +++ b/schemas/expconf/v0/searcher-sync-halving.json @@ -10,7 +10,6 @@ ], "eventuallyRequired": [ "num_rungs", - "max_length", "budget", "metric" ], diff --git a/schemas/expconf/v0/searcher.json b/schemas/expconf/v0/searcher.json index ee16e3a8ad3..470038f0c93 100644 --- a/schemas/expconf/v0/searcher.json +++ b/schemas/expconf/v0/searcher.json @@ -24,10 +24,6 @@ "unionKey": "const:name=grid", "$ref": "http://determined.ai/schemas/expconf/v0/searcher-grid.json" }, - { - "unionKey": "const:name=custom", - "$ref": "http://determined.ai/schemas/expconf/v0/searcher-custom.json" - }, { "unionKey": "const:name=adaptive_asha", "$ref": "http://determined.ai/schemas/expconf/v0/searcher-adaptive-asha.json" @@ -36,6 +32,11 @@ "unionKey": "const:name=async_halving", "$ref": "http://determined.ai/schemas/expconf/v0/searcher-async-halving.json" }, + { + "$comment": "this is an EOL searcher, not to be used in new experiments", + "unionKey": "const:name=custom", + "$ref": "http://determined.ai/schemas/expconf/v0/searcher-custom.json" + }, { "$comment": "this is an EOL searcher, not to be used in new experiments", "unionKey": "const:name=adaptive", @@ -65,6 +66,8 @@ "max_concurrent_trials": true, "max_length": true, "max_rungs": true, + "max_time": true, + "time_metric": true, "max_trials": true, "mode": true, "name": true, diff --git a/schemas/test_cases/v0/defaults.yaml b/schemas/test_cases/v0/defaults.yaml index 05c7c703a79..c3ed5e46afc 100644 --- a/schemas/test_cases/v0/defaults.yaml +++ b/schemas/test_cases/v0/defaults.yaml @@ -67,13 +67,9 @@ http://determined.ai/schemas/expconf/v0/searcher.json case: name: single - max_length: - batches: 1000 metric: loss defaulted: name: single - max_length: - batches: 1000 metric: loss smaller_is_better: true source_trial_id: null @@ -87,16 +83,12 @@ http://determined.ai/schemas/expconf/v0/searcher.json case: name: random - max_length: - batches: 1000 max_trials: 1000 metric: loss source_checkpoint_uuid: "asdf" defaulted: name: random max_concurrent_trials: 16 - max_length: - batches: 1000 max_trials: 1000 metric: loss smaller_is_better: true @@ -111,15 +103,11 @@ http://determined.ai/schemas/expconf/v0/searcher.json case: name: grid - max_length: - batches: 1000 metric: loss source_trial_id: 15 defaulted: name: grid max_concurrent_trials: 16 - max_length: - batches: 1000 metric: loss smaller_is_better: true source_trial_id: 15 @@ -134,15 +122,11 @@ case: name: async_halving num_rungs: 5 - max_length: - batches: 1000 max_trials: 100 metric: loss defaulted: name: async_halving num_rungs: 5 - max_length: - batches: 1000 max_trials: 100 divisor: 4 max_concurrent_trials: 16 @@ -150,7 +134,10 @@ smaller_is_better: true source_trial_id: null source_checkpoint_uuid: null - stop_once: false + # The master asserts these are not null in searcher.AssertCurrent(), + # but the json-schema layer doesn't know about that. + time_metric: null + max_time: null - name: adaptive_asha searcher defaults sane_as: @@ -160,14 +147,10 @@ http://determined.ai/schemas/expconf/v0/searcher.json case: name: adaptive_asha - max_length: - batches: 1000 max_trials: 100 metric: loss defaulted: name: adaptive_asha - max_length: - batches: 1000 max_trials: 100 bracket_rungs: [] divisor: 4 @@ -178,7 +161,10 @@ smaller_is_better: true source_trial_id: null source_checkpoint_uuid: null - stop_once: false + # The master asserts these are not null in searcher.AssertCurrent(), + # but the json-schema layer doesn't know about that. + time_metric: null + max_time: null - name: devices defaults, in string and map forms sane_as: diff --git a/schemas/test_cases/v0/experiment.yaml b/schemas/test_cases/v0/experiment.yaml index 48542e28d07..6f12c8e6631 100644 --- a/schemas/test_cases/v0/experiment.yaml +++ b/schemas/test_cases/v0/experiment.yaml @@ -22,7 +22,7 @@ data: any: thing integrations: - pachyderm: + pachyderm: pachd: host: localhost port: 80 @@ -166,8 +166,6 @@ searcher: name: single metric: loss - max_length: - batches: 1000 entrypoint: model_def:MyTrial ##### defaulted: @@ -241,8 +239,6 @@ is_single_node: null scheduling_unit: 100 searcher: - max_length: - batches: 1000 metric: loss name: single smaller_is_better: true @@ -286,8 +282,6 @@ searcher: name: grid metric: loss - max_length: - batches: 1000 entrypoint: model_def:MyTrial - name: check grid conditional (invalid) @@ -324,5 +318,3 @@ searcher: name: grid metric: loss - max_length: - batches: 1000 diff --git a/schemas/test_cases/v0/merging.yaml b/schemas/test_cases/v0/merging.yaml index c1aa5b88807..796d553f248 100644 --- a/schemas/test_cases/v0/merging.yaml +++ b/schemas/test_cases/v0/merging.yaml @@ -63,8 +63,6 @@ case: name: random max_trials: 10 - max_length: - epochs: 1 merge_src: metric: sae smaller_is_better: true @@ -73,8 +71,6 @@ merged: name: random max_trials: 10 - max_length: - epochs: 1 metric: sae smaller_is_better: true source_trial_id: 1 diff --git a/webui/react/src/components/CompareHyperparameters.test.mock.tsx b/webui/react/src/components/CompareHyperparameters.test.mock.tsx index 75847a6542d..3667d1d574e 100644 --- a/webui/react/src/components/CompareHyperparameters.test.mock.tsx +++ b/webui/react/src/components/CompareHyperparameters.test.mock.tsx @@ -300,9 +300,6 @@ export const SELECTED_EXPERIMENTS = [ }, scheduling_unit: 100, searcher: { - max_length: { - batches: 2, - }, metric: 'validation_loss', name: 'single', smaller_is_better: true, diff --git a/webui/react/src/components/ComparisonView.test.mock.tsx b/webui/react/src/components/ComparisonView.test.mock.tsx index 9d5f5a54fd8..800af38149d 100644 --- a/webui/react/src/components/ComparisonView.test.mock.tsx +++ b/webui/react/src/components/ComparisonView.test.mock.tsx @@ -174,7 +174,6 @@ export const SELECTED_EXPERIMENTS: ExperimentWithTrial[] = [ }, resources: {}, searcher: { - max_length: undefined, metric: 'validation_loss', name: 'single', smallerIsBetter: true, diff --git a/webui/react/src/components/ExperimentContinueModal.module.scss b/webui/react/src/components/ExperimentContinueModal.module.scss deleted file mode 100644 index cf7479bd86f..00000000000 --- a/webui/react/src/components/ExperimentContinueModal.module.scss +++ /dev/null @@ -1,3 +0,0 @@ -.fullWidth { - width: 100%; -} diff --git a/webui/react/src/components/ExperimentContinueModal.tsx b/webui/react/src/components/ExperimentContinueModal.tsx index b653a0f295c..d4e732f1f14 100644 --- a/webui/react/src/components/ExperimentContinueModal.tsx +++ b/webui/react/src/components/ExperimentContinueModal.tsx @@ -1,17 +1,14 @@ import Alert from 'hew/Alert'; import Button from 'hew/Button'; import Form, { hasErrors } from 'hew/Form'; -import Icon from 'hew/Icon'; import Input from 'hew/Input'; -import InputNumber from 'hew/InputNumber'; import { Modal } from 'hew/Modal'; -import Row from 'hew/Row'; import Spinner from 'hew/Spinner'; import { Body } from 'hew/Typography'; import { Loaded } from 'hew/utils/loadable'; import yaml from 'js-yaml'; import _ from 'lodash'; -import React, { useCallback, useEffect, useId, useMemo, useState } from 'react'; +import React, { useCallback, useEffect, useId, useState } from 'react'; import useFeature from 'hooks/useFeature'; import { paths } from 'routes/utils'; @@ -29,16 +26,13 @@ import handleError, { import { FULL_CONFIG_BUTTON_TEXT, getExperimentName, - getMaxLengthType, - getMaxLengthValue, SIMPLE_CONFIG_BUTTON_TEXT, trialContinueConfig, upgradeConfig, } from 'utils/experiment'; import { routeToReactUrl } from 'utils/routes'; -import { capitalize, capitalizeWord } from 'utils/string'; +import { capitalize } from 'utils/string'; -import css from './ExperimentContinueModal.module.scss'; const FORM_ID = 'continue-experiment-form'; export const ContinueExperimentType = { @@ -50,12 +44,12 @@ export type ContinueExperimentType = ValueOf; const ExperimentCopyMapping: Record = { [ContinueExperimentType.Continue]: 'Continue Trial in New Experiment', - [ContinueExperimentType.Reactivate]: 'Reactivate Current Trial', + [ContinueExperimentType.Reactivate]: 'Resume Current Trial', } satisfies Record; const SearchCopyMapping: Record = { [ContinueExperimentType.Continue]: 'Continue as New Run', - [ContinueExperimentType.Reactivate]: 'Reactivate Current Run', + [ContinueExperimentType.Reactivate]: 'Resume Current Run', }; type EntityCopyMap = { @@ -74,8 +68,6 @@ const flatRunsEntityCopyMap: EntityCopyMap = { }; const EXPERIMENT_NAME = 'name'; -const MAX_LENGTH = 'maxLength'; -const ADDITIONAL_LENGTH = 'additionalLength'; export interface Props { onClose?: () => void; @@ -116,7 +108,6 @@ const ExperimentContinueModalComponent = ({ const [registryCredentials, setRegistryCredentials] = useState(); const [modalState, setModalState] = useState(DEFAULT_MODAL_STATE); const [disabled, setDisabled] = useState(true); - const [originalConfig, setOriginalConfig] = useState(experiment.configRaw); const f_flat_runs = useFeature().isOn('flat_runs'); const isReactivate = type === ContinueExperimentType.Reactivate; @@ -125,10 +116,6 @@ const ExperimentContinueModalComponent = ({ : [ExperimentCopyMapping, experimentEntityCopyMap]; const actionCopy = actionCopyMap[modalState.type]; - useEffect(() => setOriginalConfig(experiment.configRaw), [experiment]); - - const requiredFields = useMemo(() => [EXPERIMENT_NAME, MAX_LENGTH], []); - const handleModalClose = () => { setModalState(DEFAULT_MODAL_STATE); onClose?.(); @@ -143,24 +130,6 @@ const ExperimentContinueModalComponent = ({ if (!prev.isAdvancedMode) { prev.config.name = values[EXPERIMENT_NAME]; } - if (!isReactivate && values[MAX_LENGTH]) { - const maxLengthType = getMaxLengthType(prev.config); - if (maxLengthType) { - prev.config.searcher.max_length[maxLengthType] = parseInt(values[MAX_LENGTH]); - } else { - prev.config.searcher.max_length = parseInt(values[MAX_LENGTH]); - } - } - if (isReactivate && values[ADDITIONAL_LENGTH] && parseInt(values[ADDITIONAL_LENGTH]) >= 0) { - const maxLengthType = getMaxLengthType(prev.config); - if (maxLengthType) { - prev.config.searcher.max_length[maxLengthType] = - originalConfig.searcher.max_length[maxLengthType] + parseInt(values[ADDITIONAL_LENGTH]); - } else { - prev.config.searcher.max_length = - originalConfig.searcher.max_length + parseInt(values[ADDITIONAL_LENGTH]); - } - } prev.configString = yaml.dump(prev.config); return prev; }); @@ -168,7 +137,7 @@ const ExperimentContinueModalComponent = ({ const hasError = hasErrors(form); const values = form.getFieldsValue(); const missingRequiredFields = Object.entries(values).some(([key, value]) => { - return requiredFields.includes(key) && !value; + return EXPERIMENT_NAME === key && !value; }); setDisabled(hasError || missingRequiredFields); }; @@ -208,38 +177,8 @@ const ExperimentContinueModalComponent = ({ if (modalState.isAdvancedMode && form) { try { const newConfig = (yaml.load(modalState.configString) || {}) as RawJson; - const maxLengthType = getMaxLengthType(newConfig); - const isReactivate = modalState.type === ContinueExperimentType.Reactivate; - const originalLength = maxLengthType - ? originalConfig.searcher.max_length[maxLengthType] - : originalConfig.searcher.max_length; - let additionalLength; - try { - const newLength = maxLengthType - ? newConfig.searcher.max_length[maxLengthType] - : newConfig.searcher.max_length; - const lengthDifference = newLength - originalLength; - if ( - originalLength && - lengthDifference && - Number.isInteger(originalLength) && - Number.isInteger(lengthDifference) && - lengthDifference > 0 - ) { - additionalLength = lengthDifference; - } - } catch { - additionalLength = undefined; - } - form.setFields([ - { name: EXPERIMENT_NAME, value: getExperimentName(newConfig) }, - { - name: MAX_LENGTH, - value: !isReactivate ? getMaxLengthValue(newConfig) : undefined, - }, - { name: ADDITIONAL_LENGTH, value: additionalLength }, - ]); + form.setFields([{ name: EXPERIMENT_NAME, value: getExperimentName(newConfig) }]); await form.validateFields(); } catch (e) { handleError(e, { publicMessage: 'failed to load previous yaml config' }); @@ -247,24 +186,12 @@ const ExperimentContinueModalComponent = ({ } else { setDisabled(false); } - }, [form, modalState, originalConfig.searcher.max_length]); + }, [form, modalState]); const getConfigFromForm = useCallback( (config: RawJson) => { if (!form) return yaml.dump(config); - - const formValues = form.getFieldsValue(); const newConfig = structuredClone(config); - - if (formValues[MAX_LENGTH]) { - const maxLengthType = getMaxLengthType(newConfig); - if (maxLengthType === undefined) { - // Unitless searcher config. - newConfig.searcher.max_length = parseInt(formValues[MAX_LENGTH]); - } else { - newConfig.searcher.max_length = { [maxLengthType]: parseInt(formValues[MAX_LENGTH]) }; - } - } return yaml.dump(newConfig); }, [form], @@ -401,14 +328,13 @@ const ExperimentContinueModalComponent = ({ }; return _.isEqual(prev, newModalState) ? prev : newModalState; }); - form.validateFields(requiredFields); // initial disabled state set here, gets updated later in handleFieldsChange - }, [entityCopyMap, experiment, trial, type, isReactivate, form, requiredFields]); + form.validateFields([EXPERIMENT_NAME]); // initial disabled state set here, gets updated later in handleFieldsChange + }, [entityCopyMap, experiment, trial, type, isReactivate, form]); if (!experiment || (!isReactivate && !trial)) return <>; const hideSimpleConfig = isReactivate && experiment.state !== RunState.Completed; - const maxLengthType = capitalizeWord(getMaxLengthType(modalState.config) || 'batches'); const modalIsInAdvancedMode = modalState.isAdvancedMode || hideSimpleConfig; return ( {isReactivate - ? `Reactivate and continue the current ${entityCopyMap.trial} from the latest checkpoint` + ? `Resume the current ${entityCopyMap.trial} from the latest checkpoint` : f_flat_runs ? "Start a new run from the current run's latest checkpoint" : "Start a new experiment from the current trial's latest checkpoint"} @@ -479,52 +405,6 @@ const ExperimentContinueModalComponent = ({ )} - {!isReactivate && ( - { - let errorMessage = ''; - if (!value) errorMessage = `Please provide a max ${maxLengthType}.`; - if (value < 1) errorMessage = `Max ${maxLengthType} must be at least 1.`; - return errorMessage ? Promise.reject(errorMessage) : Promise.resolve(); - }, - }, - ]}> - - - )} - {isReactivate && !hideSimpleConfig && ( - - {`Additional ${maxLengthType}`} - - - } - name={ADDITIONAL_LENGTH} - rules={[ - { - required: false, - validator: (_rule, value) => { - let errorMessage = ''; - if (value < 0) errorMessage = `Additional ${maxLengthType} must be at least 0.`; - if (value && !Number.isInteger(value)) - errorMessage = `Additional ${maxLengthType} must be an integer.`; - return errorMessage ? Promise.reject(errorMessage) : Promise.resolve(); - }, - }, - ]}> - - - )}
{!hideSimpleConfig && ( diff --git a/webui/react/src/components/ExperimentCreateModal.tsx b/webui/react/src/components/ExperimentCreateModal.tsx index 59ada006054..8f05d77e1f6 100644 --- a/webui/react/src/components/ExperimentCreateModal.tsx +++ b/webui/react/src/components/ExperimentCreateModal.tsx @@ -25,8 +25,6 @@ import handleError, { import { FULL_CONFIG_BUTTON_TEXT, getExperimentName, - getMaxLengthType, - getMaxLengthValue, SIMPLE_CONFIG_BUTTON_TEXT, trialContinueConfig, upgradeConfig, @@ -64,7 +62,6 @@ const RunEntityCopyMap = { export type CreateExperimentType = ValueOf; const EXPERIMENT_NAME = 'name'; -const MAX_LENGTH = 'maxLength'; interface Props { onClose?: () => void; @@ -120,8 +117,6 @@ const ExperimentCreateModalComponent = ({ ? `Fork ${capitalize(entityCopy.experiment)} ${experiment.id}` : `Continue ${capitalize(entityCopy.trial)} ${trial?.id}`; - const requiredFields = useMemo(() => [EXPERIMENT_NAME, MAX_LENGTH], []); - const handleModalClose = () => { setModalState(DEFAULT_MODAL_STATE); onClose?.(); @@ -136,14 +131,6 @@ const ExperimentCreateModalComponent = ({ if (!prev.isAdvancedMode) { prev.config.name = values[EXPERIMENT_NAME]; } - if (values[MAX_LENGTH]) { - const maxLengthType = getMaxLengthType(prev.config); - if (maxLengthType) { - prev.config.searcher.max_length[maxLengthType] = parseInt(values[MAX_LENGTH]); - } else { - prev.config.searcher.max_length = parseInt(values[MAX_LENGTH]); - } - } prev.configString = yaml.dump(prev.config); return prev; }); @@ -151,7 +138,7 @@ const ExperimentCreateModalComponent = ({ const hasError = hasErrors(form); const values = form.getFieldsValue(); const missingRequiredFields = Object.entries(values).some(([key, value]) => { - return requiredFields.includes(key) && !value; + return EXPERIMENT_NAME === key && !value; }); setDisabled(hasError || missingRequiredFields); }; @@ -191,15 +178,7 @@ const ExperimentCreateModalComponent = ({ if (modalState.isAdvancedMode && form) { try { const newConfig = (yaml.load(modalState.configString) || {}) as RawJson; - const isFork = modalState.type === CreateExperimentType.Fork; - - form.setFields([ - { name: 'name', value: getExperimentName(newConfig) }, - { - name: 'maxLength', - value: !isFork ? getMaxLengthValue(newConfig) : undefined, - }, - ]); + form.setFields([{ name: 'name', value: getExperimentName(newConfig) }]); } catch (e) { handleError(e, { publicMessage: 'failed to load previous yaml config' }); } @@ -212,19 +191,7 @@ const ExperimentCreateModalComponent = ({ const getConfigFromForm = useCallback( (config: RawJson) => { if (!form) return yaml.dump(config); - - const formValues = form.getFieldsValue(); const newConfig = structuredClone(config); - - if (formValues[MAX_LENGTH]) { - const maxLengthType = getMaxLengthType(newConfig); - if (maxLengthType === undefined) { - // Unitless searcher config. - newConfig.searcher.max_length = parseInt(formValues[MAX_LENGTH]); - } else { - newConfig.searcher.max_length = { [maxLengthType]: parseInt(formValues[MAX_LENGTH]) }; - } - } return yaml.dump(newConfig); }, [form], @@ -346,8 +313,8 @@ const ExperimentCreateModalComponent = ({ }; return _.isEqual(prev, newModalState) ? prev : newModalState; }); - form.validateFields(requiredFields); // initial disabled state set here, gets updated later in handleFieldsChange - }, [entityCopy, experiment, trial, type, isFork, form, requiredFields]); + form.validateFields([EXPERIMENT_NAME]); // initial disabled state set here, gets updated later in handleFieldsChange + }, [entityCopy, experiment, trial, type, isFork, form]); if (!experiment || (!isFork && !trial)) return <>; @@ -397,24 +364,6 @@ const ExperimentCreateModalComponent = ({ ]}> - {!isFork && ( - { - let errorMessage = ''; - if (!value) errorMessage = 'Please provide a max length.'; - if (value < 1) errorMessage = 'Max length must be at least 1.'; - return errorMessage ? Promise.reject(errorMessage) : Promise.resolve(); - }, - }, - ]}> - - - )}

Configure Trials

- - - - - - { scheduling_unit: 100, searcher: { max_concurrent_trials: 16, - max_length: { - epochs: 1, - }, max_trials: 2, metric: 'accuracy', name: 'random', diff --git a/webui/react/src/pages/ExperimentDetails/ExperimentDetails.test.mock.ts b/webui/react/src/pages/ExperimentDetails/ExperimentDetails.test.mock.ts index 07c71e1a55f..19f119b0821 100644 --- a/webui/react/src/pages/ExperimentDetails/ExperimentDetails.test.mock.ts +++ b/webui/react/src/pages/ExperimentDetails/ExperimentDetails.test.mock.ts @@ -26,7 +26,6 @@ const RESPONSES = { bracket_rungs: [], divisor: 4, max_concurrent_trials: 16, - max_length: { batches: 937, epochs: 1, records: 1 }, max_rungs: 5, max_trials: 16, metric: 'val_loss', @@ -118,7 +117,6 @@ const RESPONSES = { bracket_rungs: [], divisor: 4, max_concurrent_trials: 16, - max_length: { batches: 937, epochs: 1, records: 1 }, max_rungs: 5, max_trials: 16, metric: 'val_loss', @@ -723,7 +721,6 @@ const RESPONSES = { profiling: { enabled: false }, resources: {}, searcher: { - max_length: { batches: 937, epochs: 1, records: 1 }, metric: 'validation_loss', name: 'single' as const, smaller_is_better: true, @@ -823,7 +820,6 @@ const RESPONSES = { }, scheduling_unit: 100, searcher: { - max_length: { batches: 937, epochs: 1, records: 1 }, metric: 'validation_loss', name: 'single' as const, smaller_is_better: true, diff --git a/webui/react/src/pages/ExperimentDetails/ExperimentDetailsHeader.tsx b/webui/react/src/pages/ExperimentDetails/ExperimentDetailsHeader.tsx index 4e6ff2c2f0b..6ca0fd8550d 100644 --- a/webui/react/src/pages/ExperimentDetails/ExperimentDetailsHeader.tsx +++ b/webui/react/src/pages/ExperimentDetails/ExperimentDetailsHeader.tsx @@ -314,13 +314,13 @@ const ExperimentDetailsHeader: React.FC = ({ : 'Create New Experiment...', }, { - key: 'Reactivate Current Trial', - label: `Reactivate Current ${capitalize(copyMap.trial)}...`, + key: 'Resume Current Trial', + label: `Resume Current ${capitalize(copyMap.trial)}...`, }, ]} onClick={(key: string) => { if (key === 'Create New Experiment') ContinueExperimentModal.open(); - if (key === 'Reactivate Current Trial') ReactivateExperimentModal.open(); + if (key === 'Resume Current Trial') ReactivateExperimentModal.open(); }}> @@ -337,9 +337,9 @@ const ExperimentDetailsHeader: React.FC = ({ { key: 'reactivate-current-trial', label: experiment.unmanaged ? ( - Reactivate Current {capitalize(copyMap.trial)} + Resume Current {capitalize(copyMap.trial)} ) : ( - `Reactivate Current ${capitalize(copyMap.trial)}` + `Resume Current ${capitalize(copyMap.trial)}` ), onClick: ReactivateExperimentModal.open, }, diff --git a/webui/react/src/pages/TrialDetails/TrialInfoBox.test.tsx b/webui/react/src/pages/TrialDetails/TrialInfoBox.test.tsx index b069304bff6..bebc96891fb 100644 --- a/webui/react/src/pages/TrialDetails/TrialInfoBox.test.tsx +++ b/webui/react/src/pages/TrialDetails/TrialInfoBox.test.tsx @@ -45,7 +45,6 @@ const mockExperiment: ExperimentBase = { profiling: { enabled: false }, resources: {}, searcher: { - max_length: { batches: 937, epochs: 1, records: 1 }, max_trials: 16, metric: 'val_loss', name: 'adaptive_asha' as const, diff --git a/webui/react/src/services/api-ts-sdk/api.ts b/webui/react/src/services/api-ts-sdk/api.ts index ff48702c2d8..f6a2db5e811 100644 --- a/webui/react/src/services/api-ts-sdk/api.ts +++ b/webui/react/src/services/api-ts-sdk/api.ts @@ -2084,19 +2084,6 @@ export interface V1CleanupLogsResponse { */ removedCount: string; } -/** - * Close a trial with given ID. - * @export - * @interface V1CloseTrialOperation - */ -export interface V1CloseTrialOperation { - /** - * The ID of the trial to close. - * @type {string} - * @memberof V1CloseTrialOperation - */ - requestId?: string; -} /** * Active notice from the server admin. * @export @@ -2252,32 +2239,6 @@ export interface V1CompareTrialsResponse { */ trials: Array; } -/** - * - * @export - * @interface V1CompleteTrialSearcherValidationResponse - */ -export interface V1CompleteTrialSearcherValidationResponse { -} -/** - * Used to complete a ValidateAfterOperation. - * @export - * @interface V1CompleteValidateAfterOperation - */ -export interface V1CompleteValidateAfterOperation { - /** - * The ValidateAfterOperation being completed. - * @type {V1ValidateAfterOperation} - * @memberof V1CompleteValidateAfterOperation - */ - op?: V1ValidateAfterOperation; - /** - * The value of searcher metric associated with this completed operation. The metric provided should be the metric used to guide HP search. - * @type {any} - * @memberof V1CompleteValidateAfterOperation - */ - searcherMetric?: any; -} /** * The config to be patched into Master Config. * @export @@ -2546,25 +2507,6 @@ export interface V1CreateGroupResponse { */ group: V1GroupDetails; } -/** - * Create a trial with given hyperparameters. - * @export - * @interface V1CreateTrialOperation - */ -export interface V1CreateTrialOperation { - /** - * The ID of the trial to create. - * @type {string} - * @memberof V1CreateTrialOperation - */ - requestId?: string; - /** - * A JSON object representing the hyperparameters. - * @type {string} - * @memberof V1CreateTrialOperation - */ - hyperparams?: string; -} /** * Create a trial. * @export @@ -3406,44 +3348,6 @@ export interface V1ExperimentActionResult { */ id: number; } -/** - * ExperimentInactive is a searcher event triggered when an experiment is no longer active. - * @export - * @interface V1ExperimentInactive - */ -export interface V1ExperimentInactive { - /** - * Current state of the experiment. - * @type {Experimentv1State} - * @memberof V1ExperimentInactive - */ - experimentState: Experimentv1State; -} -/** - * ExperimentSimulation holds the configuration and results of simulated run of a searcher. - * @export - * @interface V1ExperimentSimulation - */ -export interface V1ExperimentSimulation { - /** - * The simulated experiment config. - * @type {any} - * @memberof V1ExperimentSimulation - */ - config?: any; - /** - * The searcher simulation seed. - * @type {number} - * @memberof V1ExperimentSimulation - */ - seed?: number; - /** - * The list of trials in the simulation. - * @type {Array} - * @memberof V1ExperimentSimulation - */ - trials?: Array; -} /** * Response to ExpMetricNamesRequest. * @export @@ -4058,25 +3962,6 @@ export interface V1GetCommandsResponse { */ pagination?: V1Pagination; } -/** - * - * @export - * @interface V1GetCurrentTrialSearcherOperationResponse - */ -export interface V1GetCurrentTrialSearcherOperationResponse { - /** - * The current searcher operation. - * @type {V1TrialOperation} - * @memberof V1GetCurrentTrialSearcherOperationResponse - */ - op?: V1TrialOperation; - /** - * The status of the searcher operation. - * @type {boolean} - * @memberof V1GetCurrentTrialSearcherOperationResponse - */ - completed?: boolean; -} /** * Response to GetExperimentCheckpointsRequest. * @export @@ -4967,19 +4852,6 @@ export interface V1GetRunMetadataResponse { */ metadata?: any; } -/** - * Response to GetSearcherEventsRequest. - * @export - * @interface V1GetSearcherEventsResponse - */ -export interface V1GetSearcherEventsResponse { - /** - * The list of events in the queue. - * @type {Array} - * @memberof V1GetSearcherEventsResponse - */ - searcherEvents?: Array; -} /** * Response to GetShellRequest. * @export @@ -5674,19 +5546,6 @@ export interface V1IdleNotebookRequest { */ export interface V1IdleNotebookResponse { } -/** - * InitialOperations is a searcher event signaling the creation of an experiment. - * @export - * @interface V1InitialOperations - */ -export interface V1InitialOperations { - /** - * Cannot have an empty message type. - * @type {number} - * @memberof V1InitialOperations - */ - placeholder?: number; -} /** * Int32 filters. * @export @@ -8729,38 +8588,6 @@ export interface V1PostRunMetadataResponse { */ metadata?: any; } -/** - * Request for sending operations from a custom search method. - * @export - * @interface V1PostSearcherOperationsRequest - */ -export interface V1PostSearcherOperationsRequest { - /** - * The experiment ID. - * @type {number} - * @memberof V1PostSearcherOperationsRequest - */ - experimentId?: number; - /** - * List of operations to submit. - * @type {Array} - * @memberof V1PostSearcherOperationsRequest - */ - searcherOperations?: Array; - /** - * The event that triggered the client to send these operations to the master. - * @type {V1SearcherEvent} - * @memberof V1PostSearcherOperationsRequest - */ - triggeredByEvent?: V1SearcherEvent; -} -/** - * Response to PostSearcherOperationsResponse. - * @export - * @interface V1PostSearcherOperationsResponse - */ -export interface V1PostSearcherOperationsResponse { -} /** * Request to PostTaskLogs. * @export @@ -9050,11 +8877,11 @@ export interface V1PreviewHPSearchRequest { */ export interface V1PreviewHPSearchResponse { /** - * The resulting simulation. - * @type {V1ExperimentSimulation} + * The resulting summary. + * @type {V1SearchSummary} * @memberof V1PreviewHPSearchResponse */ - simulation?: V1ExperimentSimulation; + summary?: V1SearchSummary; } /** * Project is a named collection of experiments. @@ -10742,36 +10569,6 @@ export interface V1RunActionResult { */ id: number; } -/** - * RunnableOperation represents a single runnable operation emitted by a searcher. - * @export - * @interface V1RunnableOperation - */ -export interface V1RunnableOperation { - /** - * This is the type of the operation. - * @type {V1RunnableType} - * @memberof V1RunnableOperation - */ - type?: V1RunnableType; - /** - * If the type == WORKLOAD_KIND_TRAIN, this is the number of units - * @type {string} - * @memberof V1RunnableOperation - */ - length?: string; -} -/** - * RunnableType defines the type of operation that should be executed by trial runners. - RUNNABLE_TYPE_UNSPECIFIED: Denotes an unknown runnable type. - RUNNABLE_TYPE_TRAIN: Signals to a trial runner that it should run a train. - RUNNABLE_TYPE_VALIDATE: Signals to a trial runner it should compute validation metrics. - * @export - * @enum {string} - */ -export const V1RunnableType = { - UNSPECIFIED: 'RUNNABLE_TYPE_UNSPECIFIED', - TRAIN: 'RUNNABLE_TYPE_TRAIN', - VALIDATE: 'RUNNABLE_TYPE_VALIDATE', -} as const -export type V1RunnableType = ValueOf /** * Request to prepare to start reporting to a run. * @export @@ -10857,98 +10654,6 @@ export interface V1SearchActionResult { */ id: number; } -/** - * SearcherEvent is a message from master to a client-driven custom searcher informing it of relevant changes in the state of an experiment. - * @export - * @interface V1SearcherEvent - */ -export interface V1SearcherEvent { - /** - * Incremental ID of the event. - * @type {number} - * @memberof V1SearcherEvent - */ - id: number; - /** - * An experiment has just been created. - * @type {V1InitialOperations} - * @memberof V1SearcherEvent - */ - initialOperations?: V1InitialOperations; - /** - * A trial has been created. - * @type {V1TrialCreated} - * @memberof V1SearcherEvent - */ - trialCreated?: V1TrialCreated; - /** - * Validation has completed. - * @type {V1ValidationCompleted} - * @memberof V1SearcherEvent - */ - validationCompleted?: V1ValidationCompleted; - /** - * Trial has finished. - * @type {V1TrialClosed} - * @memberof V1SearcherEvent - */ - trialClosed?: V1TrialClosed; - /** - * Trial exited early. - * @type {V1TrialExitedEarly} - * @memberof V1SearcherEvent - */ - trialExitedEarly?: V1TrialExitedEarly; - /** - * Trial progress. - * @type {V1TrialProgress} - * @memberof V1SearcherEvent - */ - trialProgress?: V1TrialProgress; - /** - * Experiment is inactive. - * @type {V1ExperimentInactive} - * @memberof V1SearcherEvent - */ - experimentInactive?: V1ExperimentInactive; -} -/** - * SearcherOperation is an operation issued by the custom searcher. - * @export - * @interface V1SearcherOperation - */ -export interface V1SearcherOperation { - /** - * TrialOperation is issued to tell an existing trial to do something. - * @type {V1TrialOperation} - * @memberof V1SearcherOperation - */ - trialOperation?: V1TrialOperation; - /** - * CreateTrialOperation is issued to create a trial. - * @type {V1CreateTrialOperation} - * @memberof V1SearcherOperation - */ - createTrial?: V1CreateTrialOperation; - /** - * CloseTrialOperation is issued to close a trial. - * @type {V1CloseTrialOperation} - * @memberof V1SearcherOperation - */ - closeTrial?: V1CloseTrialOperation; - /** - * ShutDownOperation is issued to shut down the custom search method. - * @type {V1ShutDownOperation} - * @memberof V1SearcherOperation - */ - shutDown?: V1ShutDownOperation; - /** - * SetSearcherProgressOperation is issued to set the progress of the custom search method. - * @type {V1SetSearcherProgressOperation} - * @memberof V1SearcherOperation - */ - setSearcherProgress?: V1SetSearcherProgressOperation; -} /** * * @export @@ -11124,6 +10829,50 @@ export interface V1SearchRunsResponse { */ pagination: V1Pagination; } +/** + * SearchSummary contains the estimated trials and training lengths that a search plans to execute. + * @export + * @interface V1SearchSummary + */ +export interface V1SearchSummary { + /** + * The searcher config from which the summary is generated. + * @type {any} + * @memberof V1SearchSummary + */ + config: any; + /** + * A list of planned number of trials to their training lengths. + * @type {Array} + * @memberof V1SearchSummary + */ + trials?: Array; +} +/** + * SearchUnit describes a length unit used by some searchers to manage training. + * @export + * @interface V1SearchUnit + */ +export interface V1SearchUnit { + /** + * Name of the length unit (if max_length is false). + * @type {string} + * @memberof V1SearchUnit + */ + name?: string; + /** + * Value of the length unit (if max_length is false). + * @type {number} + * @memberof V1SearchUnit + */ + value?: number; + /** + * Bool indicating whether the training length is defined in code. + * @type {boolean} + * @memberof V1SearchUnit + */ + maxLength: boolean; +} /** * Set the cluster-wide message. * @export @@ -11252,19 +11001,6 @@ export interface V1SetResourceQuotasRequest { */ export interface V1SetResourceQuotasResponse { } -/** - * SetSearcherProgressOperation informs the master of the progress of the custom searcher. - * @export - * @interface V1SetSearcherProgressOperation - */ -export interface V1SetSearcherProgressOperation { - /** - * Experiment progress as a float between 0.0 and 1.0. - * @type {number} - * @memberof V1SetSearcherProgressOperation - */ - progress?: number; -} /** * Set the priority of the requested shell. * @export @@ -11477,25 +11213,6 @@ export interface V1Shell { */ workspaceId: number; } -/** - * Shut down custom searcher method. - * @export - * @interface V1ShutDownOperation - */ -export interface V1ShutDownOperation { - /** - * Defines whether the Searcher was cancelled - * @type {boolean} - * @memberof V1ShutDownOperation - */ - cancel?: boolean; - /** - * Defines whether the Searcher failed - * @type {boolean} - * @memberof V1ShutDownOperation - */ - failure?: boolean; -} /** * Slot wraps a single device on the agent. * @export @@ -12142,32 +11859,6 @@ export const V1TokenType = { ACCESSTOKEN: 'TOKEN_TYPE_ACCESS_TOKEN', } as const export type V1TokenType = ValueOf -/** - * TrialClosed is a searcher event triggered when a trial has successfully finished. - * @export - * @interface V1TrialClosed - */ -export interface V1TrialClosed { - /** - * UUID identifying the trial to the searcher. - * @type {string} - * @memberof V1TrialClosed - */ - requestId: string; -} -/** - * TrialCreated is a searcher event signaling the creation of a trial. - * @export - * @interface V1TrialCreated - */ -export interface V1TrialCreated { - /** - * UUID identifying the trial to the searcher. - * @type {string} - * @memberof V1TrialCreated - */ - requestId: string; -} /** * Signals to the experiment the trial early exited. * @export @@ -12192,37 +11883,6 @@ export const V1TrialEarlyExitExitedReason = { INITINVALIDHP: 'EXITED_REASON_INIT_INVALID_HP', } as const export type V1TrialEarlyExitExitedReason = ValueOf -/** - * TrialExitedEarly is a searcher event triggered when a trial exited prematurely. - * @export - * @interface V1TrialExitedEarly - */ -export interface V1TrialExitedEarly { - /** - * UUID identifying the trial to the searcher. - * @type {string} - * @memberof V1TrialExitedEarly - */ - requestId: string; - /** - * The reason for the exit. - * @type {V1TrialExitedEarlyExitedReason} - * @memberof V1TrialExitedEarly - */ - exitedReason: V1TrialExitedEarlyExitedReason; -} -/** - * The reason for an early exit. - EXITED_REASON_UNSPECIFIED: Zero-value (not allowed). - EXITED_REASON_INVALID_HP: Indicates the trial exited due to an invalid hyperparameter. - EXITED_REASON_USER_REQUESTED_STOP: Indicates the trial exited due to a user requested stop, from code. - EXITED_REASON_USER_CANCELED: Indicates the trial exited due to a user requested stop, from the CLI or UI. - * @export - * @enum {string} - */ -export const V1TrialExitedEarlyExitedReason = { - UNSPECIFIED: 'EXITED_REASON_UNSPECIFIED', - INVALIDHP: 'EXITED_REASON_INVALID_HP', - USERREQUESTEDSTOP: 'EXITED_REASON_USER_REQUESTED_STOP', - USERCANCELED: 'EXITED_REASON_USER_CANCELED', -} as const -export type V1TrialExitedEarlyExitedReason = ValueOf /** * Response to TrialLogFieldsRequest. * @export @@ -12370,19 +12030,6 @@ export interface V1TrialMetrics { */ metrics: V1Metrics; } -/** - * TrialOperation is any operation that a trial can perform while it is active. - * @export - * @interface V1TrialOperation - */ -export interface V1TrialOperation { - /** - * ValidateAfter means a trial is currently training and will later validate. - * @type {V1ValidateAfterOperation} - * @memberof V1TrialOperation - */ - validateAfter?: V1ValidateAfterOperation; -} /** * * @export @@ -12451,25 +12098,6 @@ export interface V1TrialProfilerMetricsBatch { */ labels: V1TrialProfilerMetricLabels; } -/** - * TrialProgress is a searcher event that tells you the number of batches completed in the trial. - * @export - * @interface V1TrialProgress - */ -export interface V1TrialProgress { - /** - * UUID identifying the trial to the searcher. - * @type {string} - * @memberof V1TrialProgress - */ - requestId: string; - /** - * partial_units represent partial epochs, batches or records where the Unit is implied. - * @type {number} - * @memberof V1TrialProgress - */ - partialUnits: number; -} /** * The metadata pertaining to the current running task for a trial. * @export @@ -12483,25 +12111,6 @@ export interface V1TrialRunnerMetadata { */ state: string; } -/** - * TrialSimulation is a specific sequence of workloads that were run before the trial was completed. - * @export - * @interface V1TrialSimulation - */ -export interface V1TrialSimulation { - /** - * The list of operations that were run before the trial was completed. - * @type {Array} - * @memberof V1TrialSimulation - */ - operations?: Array; - /** - * The number of times that this trial configuration has occurred during the simulation. - * @type {number} - * @memberof V1TrialSimulation - */ - occurrences?: number; -} /** * * @export @@ -12644,6 +12253,25 @@ export interface V1TrialsSnapshotResponseTrial { */ batchesProcessed: number; } +/** + * TrialSummary describes the runs that are estimated to train for a certain length. + * @export + * @interface V1TrialSummary + */ +export interface V1TrialSummary { + /** + * Number of trials. + * @type {number} + * @memberof V1TrialSummary + */ + count: number; + /** + * Training length for the trials. + * @type {V1SearchUnit} + * @memberof V1TrialSummary + */ + unit: V1SearchUnit; +} /** * * @export @@ -13084,59 +12712,15 @@ export interface V1UserWebSetting { value?: string; } /** - * ValidateAfterOperation means the trial should train and validate after training the given length. + * ValidationHistoryEntry is a single entry for a validation history for an experiment. * @export - * @interface V1ValidateAfterOperation + * @interface V1ValidationHistoryEntry */ -export interface V1ValidateAfterOperation { +export interface V1ValidationHistoryEntry { /** - * The ID of the trial that should train. - * @type {string} - * @memberof V1ValidateAfterOperation - */ - requestId?: string; - /** - * The length to train before reporting a validation. - * @type {string} - * @memberof V1ValidateAfterOperation - */ - length?: string; -} -/** - * ValidationCompleted is a searcher event triggered when a validation has been completed. - * @export - * @interface V1ValidationCompleted - */ -export interface V1ValidationCompleted { - /** - * UUID identifying the trial to the searcher. - * @type {string} - * @memberof V1ValidationCompleted - */ - requestId: string; - /** - * Value of the validation metric used to direct the search. - * @type {any} - * @memberof V1ValidationCompleted - */ - metric: any; - /** - * Length from ValidateAfterOperation. - * @type {string} - * @memberof V1ValidationCompleted - */ - validateAfterLength: string; -} -/** - * ValidationHistoryEntry is a single entry for a validation history for an experiment. - * @export - * @interface V1ValidationHistoryEntry - */ -export interface V1ValidationHistoryEntry { - /** - * The id for the trial associated with this validation entry. - * @type {number} - * @memberof V1ValidationHistoryEntry + * The id for the trial associated with this validation entry. + * @type {number} + * @memberof V1ValidationHistoryEntry */ trialId: number; /** @@ -17609,42 +17193,6 @@ export const ExperimentsApiFetchParamCreator = function (configuration?: Configu options: localVarRequestOptions, }; }, - /** - * - * @summary Get the list of custom searcher events with long polling. - * @param {number} experimentId The ID of the experiment. - * @param {*} [options] Override http request option. - * @throws {RequiredError} - */ - getSearcherEvents(experimentId: number, options: any = {}): FetchArgs { - // verify required parameter 'experimentId' is not null or undefined - if (experimentId === null || experimentId === undefined) { - throw new RequiredError('experimentId','Required parameter experimentId was null or undefined when calling getSearcherEvents.'); - } - const localVarPath = `/api/v1/experiments/{experimentId}/searcher_events` - .replace(`{${"experimentId"}}`, encodeURIComponent(String(experimentId))); - const localVarUrlObj = new URL(localVarPath, BASE_PATH); - const localVarRequestOptions = { method: 'GET', ...options }; - const localVarHeaderParameter = {} as any; - const localVarQueryParameter = {} as any; - - // authentication BearerToken required - if (configuration && configuration.apiKey) { - const localVarApiKeyValue = typeof configuration.apiKey === 'function' - ? configuration.apiKey("Authorization") - : configuration.apiKey; - localVarHeaderParameter["Authorization"] = localVarApiKeyValue; - } - - objToSearchParams(localVarQueryParameter, localVarUrlObj.searchParams); - objToSearchParams(options.query || {}, localVarUrlObj.searchParams); - localVarRequestOptions.headers = { ...localVarHeaderParameter, ...options.headers }; - - return { - url: `${localVarUrlObj.pathname}${localVarUrlObj.search}`, - options: localVarRequestOptions, - }; - }, /** * * @summary Get a single trial. @@ -18075,50 +17623,6 @@ export const ExperimentsApiFetchParamCreator = function (configuration?: Configu options: localVarRequestOptions, }; }, - /** - * - * @summary Submit operations to a custom searcher. - * @param {number} experimentId The experiment ID. - * @param {V1PostSearcherOperationsRequest} body - * @param {*} [options] Override http request option. - * @throws {RequiredError} - */ - postSearcherOperations(experimentId: number, body: V1PostSearcherOperationsRequest, options: any = {}): FetchArgs { - // verify required parameter 'experimentId' is not null or undefined - if (experimentId === null || experimentId === undefined) { - throw new RequiredError('experimentId','Required parameter experimentId was null or undefined when calling postSearcherOperations.'); - } - // verify required parameter 'body' is not null or undefined - if (body === null || body === undefined) { - throw new RequiredError('body','Required parameter body was null or undefined when calling postSearcherOperations.'); - } - const localVarPath = `/api/v1/experiments/{experimentId}/searcher_operations` - .replace(`{${"experimentId"}}`, encodeURIComponent(String(experimentId))); - const localVarUrlObj = new URL(localVarPath, BASE_PATH); - const localVarRequestOptions = { method: 'POST', ...options }; - const localVarHeaderParameter = {} as any; - const localVarQueryParameter = {} as any; - - // authentication BearerToken required - if (configuration && configuration.apiKey) { - const localVarApiKeyValue = typeof configuration.apiKey === 'function' - ? configuration.apiKey("Authorization") - : configuration.apiKey; - localVarHeaderParameter["Authorization"] = localVarApiKeyValue; - } - - localVarHeaderParameter['Content-Type'] = 'application/json'; - - objToSearchParams(localVarQueryParameter, localVarUrlObj.searchParams); - objToSearchParams(options.query || {}, localVarUrlObj.searchParams); - localVarRequestOptions.headers = { ...localVarHeaderParameter, ...options.headers }; - localVarRequestOptions.body = JSON.stringify(body) - - return { - url: `${localVarUrlObj.pathname}${localVarUrlObj.search}`, - options: localVarRequestOptions, - }; - }, /** * * @summary Preview hyperparameter search. @@ -18993,25 +18497,6 @@ export const ExperimentsApiFp = function (configuration?: Configuration) { }); }; }, - /** - * - * @summary Get the list of custom searcher events with long polling. - * @param {number} experimentId The ID of the experiment. - * @param {*} [options] Override http request option. - * @throws {RequiredError} - */ - getSearcherEvents(experimentId: number, options?: any): (fetch?: FetchAPI, basePath?: string) => Promise { - const localVarFetchArgs = ExperimentsApiFetchParamCreator(configuration).getSearcherEvents(experimentId, options); - return (fetch: FetchAPI = window.fetch, basePath: string = BASE_PATH) => { - return fetch(basePath + localVarFetchArgs.url, localVarFetchArgs.options).then((response) => { - if (response.status >= 200 && response.status < 300) { - return response.json(); - } else { - throw response; - } - }); - }; - }, /** * * @summary Get a single trial. @@ -19213,26 +18698,6 @@ export const ExperimentsApiFp = function (configuration?: Configuration) { }); }; }, - /** - * - * @summary Submit operations to a custom searcher. - * @param {number} experimentId The experiment ID. - * @param {V1PostSearcherOperationsRequest} body - * @param {*} [options] Override http request option. - * @throws {RequiredError} - */ - postSearcherOperations(experimentId: number, body: V1PostSearcherOperationsRequest, options?: any): (fetch?: FetchAPI, basePath?: string) => Promise { - const localVarFetchArgs = ExperimentsApiFetchParamCreator(configuration).postSearcherOperations(experimentId, body, options); - return (fetch: FetchAPI = window.fetch, basePath: string = BASE_PATH) => { - return fetch(basePath + localVarFetchArgs.url, localVarFetchArgs.options).then((response) => { - if (response.status >= 200 && response.status < 300) { - return response.json(); - } else { - throw response; - } - }); - }; - }, /** * * @summary Preview hyperparameter search. @@ -19688,16 +19153,6 @@ export const ExperimentsApiFactory = function (configuration?: Configuration, fe getModelDefTree(experimentId: number, options?: any) { return ExperimentsApiFp(configuration).getModelDefTree(experimentId, options)(fetch, basePath); }, - /** - * - * @summary Get the list of custom searcher events with long polling. - * @param {number} experimentId The ID of the experiment. - * @param {*} [options] Override http request option. - * @throws {RequiredError} - */ - getSearcherEvents(experimentId: number, options?: any) { - return ExperimentsApiFp(configuration).getSearcherEvents(experimentId, options)(fetch, basePath); - }, /** * * @summary Get a single trial. @@ -19809,17 +19264,6 @@ export const ExperimentsApiFactory = function (configuration?: Configuration, fe pauseExperiments(projectId: number, body: V1PauseExperimentsRequest, options?: any) { return ExperimentsApiFp(configuration).pauseExperiments(projectId, body, options)(fetch, basePath); }, - /** - * - * @summary Submit operations to a custom searcher. - * @param {number} experimentId The experiment ID. - * @param {V1PostSearcherOperationsRequest} body - * @param {*} [options] Override http request option. - * @throws {RequiredError} - */ - postSearcherOperations(experimentId: number, body: V1PostSearcherOperationsRequest, options?: any) { - return ExperimentsApiFp(configuration).postSearcherOperations(experimentId, body, options)(fetch, basePath); - }, /** * * @summary Preview hyperparameter search. @@ -20235,18 +19679,6 @@ export class ExperimentsApi extends BaseAPI { return ExperimentsApiFp(this.configuration).getModelDefTree(experimentId, options)(this.fetch, this.basePath) } - /** - * - * @summary Get the list of custom searcher events with long polling. - * @param {number} experimentId The ID of the experiment. - * @param {*} [options] Override http request option. - * @throws {RequiredError} - * @memberof ExperimentsApi - */ - public getSearcherEvents(experimentId: number, options?: any) { - return ExperimentsApiFp(this.configuration).getSearcherEvents(experimentId, options)(this.fetch, this.basePath) - } - /** * * @summary Get a single trial. @@ -20378,19 +19810,6 @@ export class ExperimentsApi extends BaseAPI { return ExperimentsApiFp(this.configuration).pauseExperiments(projectId, body, options)(this.fetch, this.basePath) } - /** - * - * @summary Submit operations to a custom searcher. - * @param {number} experimentId The experiment ID. - * @param {V1PostSearcherOperationsRequest} body - * @param {*} [options] Override http request option. - * @throws {RequiredError} - * @memberof ExperimentsApi - */ - public postSearcherOperations(experimentId: number, body: V1PostSearcherOperationsRequest, options?: any) { - return ExperimentsApiFp(this.configuration).postSearcherOperations(experimentId, body, options)(this.fetch, this.basePath) - } - /** * * @summary Preview hyperparameter search. @@ -21054,50 +20473,6 @@ export const InternalApiFetchParamCreator = function (configuration?: Configurat options: localVarRequestOptions, }; }, - /** - * - * @summary Reports to the searcher that the trial has completed the given searcher operation. - * @param {number} trialId The id of the trial. - * @param {V1CompleteValidateAfterOperation} body The completed operation. - * @param {*} [options] Override http request option. - * @throws {RequiredError} - */ - completeTrialSearcherValidation(trialId: number, body: V1CompleteValidateAfterOperation, options: any = {}): FetchArgs { - // verify required parameter 'trialId' is not null or undefined - if (trialId === null || trialId === undefined) { - throw new RequiredError('trialId','Required parameter trialId was null or undefined when calling completeTrialSearcherValidation.'); - } - // verify required parameter 'body' is not null or undefined - if (body === null || body === undefined) { - throw new RequiredError('body','Required parameter body was null or undefined when calling completeTrialSearcherValidation.'); - } - const localVarPath = `/api/v1/trials/{trialId}/searcher/completed_operation` - .replace(`{${"trialId"}}`, encodeURIComponent(String(trialId))); - const localVarUrlObj = new URL(localVarPath, BASE_PATH); - const localVarRequestOptions = { method: 'POST', ...options }; - const localVarHeaderParameter = {} as any; - const localVarQueryParameter = {} as any; - - // authentication BearerToken required - if (configuration && configuration.apiKey) { - const localVarApiKeyValue = typeof configuration.apiKey === 'function' - ? configuration.apiKey("Authorization") - : configuration.apiKey; - localVarHeaderParameter["Authorization"] = localVarApiKeyValue; - } - - localVarHeaderParameter['Content-Type'] = 'application/json'; - - objToSearchParams(localVarQueryParameter, localVarUrlObj.searchParams); - objToSearchParams(options.query || {}, localVarUrlObj.searchParams); - localVarRequestOptions.headers = { ...localVarHeaderParameter, ...options.headers }; - localVarRequestOptions.body = JSON.stringify(body) - - return { - url: `${localVarUrlObj.pathname}${localVarUrlObj.search}`, - options: localVarRequestOptions, - }; - }, /** * * @summary Continues an experiment either to make the existing experiment train longer or to retry it. @@ -21516,42 +20891,6 @@ export const InternalApiFetchParamCreator = function (configuration?: Configurat options: localVarRequestOptions, }; }, - /** - * - * @summary Get the current searcher operation. - * @param {number} trialId The id of the trial. - * @param {*} [options] Override http request option. - * @throws {RequiredError} - */ - getCurrentTrialSearcherOperation(trialId: number, options: any = {}): FetchArgs { - // verify required parameter 'trialId' is not null or undefined - if (trialId === null || trialId === undefined) { - throw new RequiredError('trialId','Required parameter trialId was null or undefined when calling getCurrentTrialSearcherOperation.'); - } - const localVarPath = `/api/v1/trials/{trialId}/searcher/operation` - .replace(`{${"trialId"}}`, encodeURIComponent(String(trialId))); - const localVarUrlObj = new URL(localVarPath, BASE_PATH); - const localVarRequestOptions = { method: 'GET', ...options }; - const localVarHeaderParameter = {} as any; - const localVarQueryParameter = {} as any; - - // authentication BearerToken required - if (configuration && configuration.apiKey) { - const localVarApiKeyValue = typeof configuration.apiKey === 'function' - ? configuration.apiKey("Authorization") - : configuration.apiKey; - localVarHeaderParameter["Authorization"] = localVarApiKeyValue; - } - - objToSearchParams(localVarQueryParameter, localVarUrlObj.searchParams); - objToSearchParams(options.query || {}, localVarUrlObj.searchParams); - localVarRequestOptions.headers = { ...localVarHeaderParameter, ...options.headers }; - - return { - url: `${localVarUrlObj.pathname}${localVarUrlObj.search}`, - options: localVarRequestOptions, - }; - }, /** * * @summary Get task config @@ -24671,26 +24010,6 @@ export const InternalApiFp = function (configuration?: Configuration) { }); }; }, - /** - * - * @summary Reports to the searcher that the trial has completed the given searcher operation. - * @param {number} trialId The id of the trial. - * @param {V1CompleteValidateAfterOperation} body The completed operation. - * @param {*} [options] Override http request option. - * @throws {RequiredError} - */ - completeTrialSearcherValidation(trialId: number, body: V1CompleteValidateAfterOperation, options?: any): (fetch?: FetchAPI, basePath?: string) => Promise { - const localVarFetchArgs = InternalApiFetchParamCreator(configuration).completeTrialSearcherValidation(trialId, body, options); - return (fetch: FetchAPI = window.fetch, basePath: string = BASE_PATH) => { - return fetch(basePath + localVarFetchArgs.url, localVarFetchArgs.options).then((response) => { - if (response.status >= 200 && response.status < 300) { - return response.json(); - } else { - throw response; - } - }); - }; - }, /** * * @summary Continues an experiment either to make the existing experiment train longer or to retry it. @@ -24901,25 +24220,6 @@ export const InternalApiFp = function (configuration?: Configuration) { }); }; }, - /** - * - * @summary Get the current searcher operation. - * @param {number} trialId The id of the trial. - * @param {*} [options] Override http request option. - * @throws {RequiredError} - */ - getCurrentTrialSearcherOperation(trialId: number, options?: any): (fetch?: FetchAPI, basePath?: string) => Promise { - const localVarFetchArgs = InternalApiFetchParamCreator(configuration).getCurrentTrialSearcherOperation(trialId, options); - return (fetch: FetchAPI = window.fetch, basePath: string = BASE_PATH) => { - return fetch(basePath + localVarFetchArgs.url, localVarFetchArgs.options).then((response) => { - if (response.status >= 200 && response.status < 300) { - return response.json(); - } else { - throw response; - } - }); - }; - }, /** * * @summary Get task config @@ -26406,17 +25706,6 @@ export const InternalApiFactory = function (configuration?: Configuration, fetch cleanupLogs(options?: any) { return InternalApiFp(configuration).cleanupLogs(options)(fetch, basePath); }, - /** - * - * @summary Reports to the searcher that the trial has completed the given searcher operation. - * @param {number} trialId The id of the trial. - * @param {V1CompleteValidateAfterOperation} body The completed operation. - * @param {*} [options] Override http request option. - * @throws {RequiredError} - */ - completeTrialSearcherValidation(trialId: number, body: V1CompleteValidateAfterOperation, options?: any) { - return InternalApiFp(configuration).completeTrialSearcherValidation(trialId, body, options)(fetch, basePath); - }, /** * * @summary Continues an experiment either to make the existing experiment train longer or to retry it. @@ -26528,16 +25817,6 @@ export const InternalApiFactory = function (configuration?: Configuration, fetch getBestSearcherValidationMetric(experimentId: number, options?: any) { return InternalApiFp(configuration).getBestSearcherValidationMetric(experimentId, options)(fetch, basePath); }, - /** - * - * @summary Get the current searcher operation. - * @param {number} trialId The id of the trial. - * @param {*} [options] Override http request option. - * @throws {RequiredError} - */ - getCurrentTrialSearcherOperation(trialId: number, options?: any) { - return InternalApiFp(configuration).getCurrentTrialSearcherOperation(trialId, options)(fetch, basePath); - }, /** * * @summary Get task config @@ -27448,19 +26727,6 @@ export class InternalApi extends BaseAPI { return InternalApiFp(this.configuration).cleanupLogs(options)(this.fetch, this.basePath) } - /** - * - * @summary Reports to the searcher that the trial has completed the given searcher operation. - * @param {number} trialId The id of the trial. - * @param {V1CompleteValidateAfterOperation} body The completed operation. - * @param {*} [options] Override http request option. - * @throws {RequiredError} - * @memberof InternalApi - */ - public completeTrialSearcherValidation(trialId: number, body: V1CompleteValidateAfterOperation, options?: any) { - return InternalApiFp(this.configuration).completeTrialSearcherValidation(trialId, body, options)(this.fetch, this.basePath) - } - /** * * @summary Continues an experiment either to make the existing experiment train longer or to retry it. @@ -27594,18 +26860,6 @@ export class InternalApi extends BaseAPI { return InternalApiFp(this.configuration).getBestSearcherValidationMetric(experimentId, options)(this.fetch, this.basePath) } - /** - * - * @summary Get the current searcher operation. - * @param {number} trialId The id of the trial. - * @param {*} [options] Override http request option. - * @throws {RequiredError} - * @memberof InternalApi - */ - public getCurrentTrialSearcherOperation(trialId: number, options?: any) { - return InternalApiFp(this.configuration).getCurrentTrialSearcherOperation(trialId, options)(this.fetch, this.basePath) - } - /** * * @summary Get task config diff --git a/webui/react/src/types.ts b/webui/react/src/types.ts index 5ef5b089d0b..c39abd6159b 100644 --- a/webui/react/src/types.ts +++ b/webui/react/src/types.ts @@ -456,10 +456,6 @@ export const ContinuableNonSingleSearcherName = new Set( const Searcher = t.intersection([ t.partial({ - max_length: t.record( - t.union([t.literal('batches'), t.literal('records'), t.literal('epochs')]), - t.number, - ), max_trials: t.number, sourceTrialId: t.number, }), diff --git a/webui/react/src/utils/experiment.test.ts b/webui/react/src/utils/experiment.test.ts index aeeb9dc3127..e15d713e850 100644 --- a/webui/react/src/utils/experiment.test.ts +++ b/webui/react/src/utils/experiment.test.ts @@ -114,10 +114,6 @@ describe('Experiment Utilities', () => { input: { min_validation_period: 32 }, output: { min_validation_period: { batches: 3200 } }, }, - { - input: { searcher: { max_steps: 10 } }, - output: { searcher: { max_length: { batches: 1000 } } }, - }, { input: { searcher: { step_budget: 100 } }, output: { searcher: { budget: { batches: 10000 } } }, @@ -126,10 +122,6 @@ describe('Experiment Utilities', () => { input: { searcher: { steps_per_round: 2 } }, output: { searcher: { length_per_round: { batches: 200 } } }, }, - { - input: { searcher: { target_trial_steps: 10 } }, - output: { searcher: { max_length: { batches: 1000 } } }, - }, ]; tests.forEach((test) => { expect(utils.upgradeConfig(test.input)).toStrictEqual(test.output); diff --git a/webui/react/src/utils/experiment.ts b/webui/react/src/utils/experiment.ts index 032124abfa6..3827e64a013 100644 --- a/webui/react/src/utils/experiment.ts +++ b/webui/react/src/utils/experiment.ts @@ -86,8 +86,6 @@ const stepRemovalTranslations = [ { oldName: 'min_validation_period' }, { newName: 'searcher.budget', oldName: 'searcher.step_budget' }, { newName: 'searcher.length_per_round', oldName: 'searcher.steps_per_round' }, - { newName: 'searcher.max_length', oldName: 'searcher.max_steps' }, - { newName: 'searcher.max_length', oldName: 'searcher.target_trial_steps' }, ]; const getLengthFromStepCount = (config: RawJson, stepCount: number): [string, number] => { @@ -326,18 +324,6 @@ export const getExperimentName = (config: RawJson): string => { return config.name || ''; }; -// For unitless searchers, this will return undefined. -export const getMaxLengthType = (config: RawJson): string | undefined => { - return (Object.keys(config.searcher?.max_length || {}) || [])[0]; -}; - -export const getMaxLengthValue = (config: RawJson): number => { - const value = (Object.keys(config.searcher?.max_length || {}) || [])[0]; - return value - ? parseInt(config.searcher?.max_length[value]) - : parseInt(config.searcher?.max_length); -}; - export const trialContinueConfig = ( experimentConfig: RawJson, trialHparams: TrialHyperparameters, @@ -351,7 +337,6 @@ export const trialContinueConfig = ( hyperparameters: trialHParamsToExperimentHParams(trialHparams), project: projectName, searcher: { - max_length: experimentConfig.searcher.max_length, metric: experimentConfig.searcher.metric, name: 'single', smaller_is_better: experimentConfig.searcher.smaller_is_better, diff --git a/webui/react/src/utils/tests/generateTestData.ts b/webui/react/src/utils/tests/generateTestData.ts index d9b11d6452e..b6bcdb4adfc 100644 --- a/webui/react/src/utils/tests/generateTestData.ts +++ b/webui/react/src/utils/tests/generateTestData.ts @@ -177,7 +177,6 @@ export const generateTestExperimentData = (): { bracket_rungs: [], divisor: 4, max_concurrent_trials: 16, - max_length: { batches: 937 }, max_rungs: 5, max_trials: 16, metric: 'validation_loss', @@ -225,7 +224,6 @@ export const generateTestExperimentData = (): { name: mnist_pytorch_adaptive_search records_per_epoch: 10 searcher: - max_length: {batches: 937} max_trials: 16 metric: validation_loss name: adaptive_asha