Skip to content

Commit

Permalink
Adds tests for default gpus
Browse files Browse the repository at this point in the history
  • Loading branch information
djperrefort committed Aug 19, 2024
1 parent acb915f commit 2c1ca3c
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 13 deletions.
24 changes: 12 additions & 12 deletions apps/crc_interactive.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""A simple wrapper around the Slurm ``srun`` command
"""A simple wrapper around the Slurm `srun` command.
The application launches users into an interactive Slurm session on a
user-selected cluster and (if specified) partition. Dedicated command line
Expand Down Expand Up @@ -87,7 +87,7 @@ def __init__(self) -> None:

@staticmethod
def parse_time(time_str: str) -> time:
"""Parse a string representation of time in 'HH:MM:SS' format and return a time object
"""Parse a string representation of time in 'HH:MM:SS' format and return a time object.
Args:
time_str: A string representing time in 'HH:MM:SS' format.
Expand All @@ -110,10 +110,14 @@ def parse_time(time_str: str) -> time:
raise ArgumentTypeError(f'Could not parse time value {time_str}')

def parse_args(self, args=None, namespace=None) -> Namespace:
"""Parse command line arguments"""
"""Parse command line arguments."""

args = super().parse_args(args, namespace)

# Set defaults that need to be determined dynamically
if not args.num_gpus:
args.num_gpus = 1 if args.gpu else 0

# Check wall time is between limits, enable both %H:%M format and integer hours
check_time = args.time.hour + args.time.minute / 60 + args.time.second / 3600
if not self.min_time <= check_time <= self.max_time:
Expand All @@ -137,13 +141,13 @@ def parse_args(self, args=None, namespace=None) -> Namespace:
return args

def create_srun_command(self, args: Namespace) -> str:
"""Create an ``srun`` command based on parsed command line arguments
"""Create an `srun` command based on parsed command line arguments.
Args:
args: A dictionary of parsed command line parsed_args
args: A dictionary of parsed command line parsed_args.
Return:
The equivalent ``srun`` command as a string
The equivalent `srun` command as a string.
"""

# Map arguments from the parent application to equivalent srun arguments
Expand Down Expand Up @@ -179,20 +183,16 @@ def create_srun_command(self, args: Namespace) -> str:
return f'srun -M {cluster_to_run} {srun_args} --pty bash'

def app_logic(self, args: Namespace) -> None:
"""Logic to evaluate when executing the application
"""Logic to evaluate when executing the application.
Args:
args: Parsed command line arguments
args: Parsed command line arguments.
"""

if not any(getattr(args, cluster, False) for cluster in Slurm.get_cluster_names()):
self.print_help()
self.exit()

# Set defaults that need to be determined dynamically
if not args.num_gpus:
args.num_gpus = 1 if args.gpu else 0

# Create the slurm command
srun_command = self.create_srun_command(args)

Expand Down
11 changes: 10 additions & 1 deletion tests/test_crc_interactive.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Tests for the ``crc-interactive`` application."""
"""Tests for the `crc-interactive` application."""

from argparse import ArgumentTypeError, Namespace
from datetime import time
Expand All @@ -25,6 +25,15 @@ def test_args_match_class_settings(self) -> None:
self.assertEqual(CrcInteractive.default_mem, args.mem)
self.assertEqual(CrcInteractive.default_gpus, args.num_gpus)

def test_default_gpus(self) -> None:
"""Test the default number of GPUs is determined dynamically by cluster."""

smp_args = self.app.parse_args(['--smp'])
self.assertEqual(0, smp_args.num_gpus)

gpu_args = self.app.parse_args(['--gpu'])
self.assertEqual(1, gpu_args.num_gpus)

def test_time_argument_out_of_range(self) -> None:
"""Test invalid time arguments raise an error."""

Expand Down

0 comments on commit 2c1ca3c

Please sign in to comment.