Skip to content

Commit

Permalink
[air/release] Fix golden_notebook_torch_tune_serve_test release + smoke tests (ray-project#38651)
Browse files Browse the repository at this point in the history

Signed-off-by: Justin Yu <[email protected]>
  • Loading branch information
justinvyu authored Aug 22, 2023
1 parent 992d99b commit 543e096
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 25 deletions.
1 change: 1 addition & 0 deletions release/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ py_test(
srcs = test_srcs,
env = {
"IS_SMOKE_TEST": "1",
"RAY_AIR_NEW_PERSISTENCE_MODE": "1",
},
main = "golden_notebook_tests/workloads/torch_tune_serve_test.py",
tags = [
Expand Down
39 changes: 15 additions & 24 deletions release/golden_notebook_tests/workloads/torch_tune_serve_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import ray
from ray.train import ScalingConfig, RunConfig
from ray.train._checkpoint import Checkpoint
from ray.air.util.node import _force_on_current_node
from ray.tune.tune_config import TuneConfig
import requests
import torch
Expand Down Expand Up @@ -127,33 +126,25 @@ def train_mnist(test_mode=False, num_workers=1, use_gpu=False):
),
run_config=RunConfig(
verbose=1,
storage_path="/mnt/cluster_storage",
storage_path=(
"/mnt/cluster_storage"
if os.path.exists("/mnt/cluster_storage")
else None
),
),
)

return tuner.fit()


def get_remote_model(remote_model_checkpoint_path):
if ray.util.client.ray.is_connected():
remote_load = ray.remote(get_model)
remote_load = _force_on_current_node(remote_load)
return ray.get(remote_load.remote(remote_model_checkpoint_path))
else:
get_best_model_remote = ray.remote(get_model)
return ray.get(get_best_model_remote.remote(remote_model_checkpoint_path))


def get_model(model_checkpoint_path):
def get_model(checkpoint_dir: str):
model = resnet18()
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=1, padding=3, bias=False)

checkpoint = Checkpoint(path=model_checkpoint_path)
with checkpoint.as_directory() as checkpoint_dir:
model_state_dict = torch.load(
os.path.join(checkpoint_dir, "model.pt"), map_location="cpu"
)
model.load_state_dict(model_state_dict)
model_state_dict = torch.load(
os.path.join(checkpoint_dir, "model.pt"), map_location="cpu"
)
model.load_state_dict(model_state_dict)

return model

Expand Down Expand Up @@ -275,13 +266,13 @@ def stop_ray():
use_gpu = not args.smoke_test

print("Training model.")
analysis = train_mnist(args.smoke_test, num_workers, use_gpu)._experiment_analysis
result_grid = train_mnist(args.smoke_test, num_workers, use_gpu)

print("Retrieving best model.")
best_checkpoint_path = analysis.get_best_checkpoint(
analysis.best_trial, return_path=True
)
model = get_remote_model(best_checkpoint_path)
best_result = result_grid.get_best_result()
best_checkpoint = best_result.get_best_checkpoint(metric="val_loss", mode="min")
with best_checkpoint.as_directory() as checkpoint_dir:
model = get_model(checkpoint_dir)

print("Setting up Serve.")
setup_serve(model, use_gpu)
Expand Down
2 changes: 1 addition & 1 deletion release/release_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2173,7 +2173,7 @@

run:
timeout: 600
script: python workloads/torch_tune_serve_test.py
script: RAY_AIR_NEW_PERSISTENCE_MODE=1 python workloads/torch_tune_serve_test.py
wait_for_nodes:
num_nodes: 2

Expand Down

0 comments on commit 543e096

Please sign in to comment.