diff --git a/release/BUILD b/release/BUILD index 9f1d9f62d580..9811553de567 100644 --- a/release/BUILD +++ b/release/BUILD @@ -236,6 +236,7 @@ py_test( srcs = test_srcs, env = { "IS_SMOKE_TEST": "1", + "RAY_AIR_NEW_PERSISTENCE_MODE": "1", }, main = "golden_notebook_tests/workloads/torch_tune_serve_test.py", tags = [ diff --git a/release/golden_notebook_tests/workloads/torch_tune_serve_test.py b/release/golden_notebook_tests/workloads/torch_tune_serve_test.py index 8c931141f0c7..8d9c54ec7ed9 100644 --- a/release/golden_notebook_tests/workloads/torch_tune_serve_test.py +++ b/release/golden_notebook_tests/workloads/torch_tune_serve_test.py @@ -9,7 +9,6 @@ import ray from ray.train import ScalingConfig, RunConfig from ray.train._checkpoint import Checkpoint -from ray.air.util.node import _force_on_current_node from ray.tune.tune_config import TuneConfig import requests import torch @@ -127,33 +126,25 @@ def train_mnist(test_mode=False, num_workers=1, use_gpu=False): ), run_config=RunConfig( verbose=1, - storage_path="/mnt/cluster_storage", + storage_path=( + "/mnt/cluster_storage" + if os.path.exists("/mnt/cluster_storage") + else None + ), ), ) return tuner.fit() -def get_remote_model(remote_model_checkpoint_path): - if ray.util.client.ray.is_connected(): - remote_load = ray.remote(get_model) - remote_load = _force_on_current_node(remote_load) - return ray.get(remote_load.remote(remote_model_checkpoint_path)) - else: - get_best_model_remote = ray.remote(get_model) - return ray.get(get_best_model_remote.remote(remote_model_checkpoint_path)) - - -def get_model(model_checkpoint_path): +def get_model(checkpoint_dir: str): model = resnet18() model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=1, padding=3, bias=False) - checkpoint = Checkpoint(path=model_checkpoint_path) - with checkpoint.as_directory() as checkpoint_dir: - model_state_dict = torch.load( - os.path.join(checkpoint_dir, "model.pt"), map_location="cpu" - ) - model.load_state_dict(model_state_dict) + model_state_dict = torch.load( + os.path.join(checkpoint_dir, "model.pt"), map_location="cpu" + ) + model.load_state_dict(model_state_dict) return model @@ -275,13 +266,13 @@ def stop_ray(): use_gpu = not args.smoke_test print("Training model.") - analysis = train_mnist(args.smoke_test, num_workers, use_gpu)._experiment_analysis + result_grid = train_mnist(args.smoke_test, num_workers, use_gpu) print("Retrieving best model.") - best_checkpoint_path = analysis.get_best_checkpoint( - analysis.best_trial, return_path=True - ) - model = get_remote_model(best_checkpoint_path) + best_result = result_grid.get_best_result() + best_checkpoint = best_result.get_best_checkpoint(metric="val_loss", mode="min") + with best_checkpoint.as_directory() as checkpoint_dir: + model = get_model(checkpoint_dir) print("Setting up Serve.") setup_serve(model, use_gpu) diff --git a/release/release_tests.yaml b/release/release_tests.yaml index 9609a844b9ae..87acaa5a0a66 100644 --- a/release/release_tests.yaml +++ b/release/release_tests.yaml @@ -2173,7 +2173,7 @@ run: timeout: 600 - script: python workloads/torch_tune_serve_test.py + script: RAY_AIR_NEW_PERSISTENCE_MODE=1 python workloads/torch_tune_serve_test.py wait_for_nodes: num_nodes: 2