Skip to content

Commit

Permalink
Add tests for creating srun command
Browse files Browse the repository at this point in the history
  • Loading branch information
djperrefort authored Aug 19, 2024
1 parent 2125d2e commit 3cb4ae4
Showing 1 changed file with 178 additions and 2 deletions.
180 changes: 178 additions & 2 deletions tests/test_crc_interactive.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Tests for the ``crc-interactive`` application."""

import unittest
from argparse import ArgumentTypeError
from argparse import ArgumentTypeError, Namespace
from datetime import time
from unittest import TestCase

Expand Down Expand Up @@ -63,3 +62,180 @@ def test_empty_string(self) -> None:

with self.assertRaises(ArgumentTypeError):
CrcInteractive.parse_time('')


class TestCrcInteractive(TestCase):
"""Test the CrcInteractive class."""

def setUp(self) -> None:
"""Set up the test environment."""

self.parser = CrcInteractive()

def test_default_command(self) -> None:
"""Test the default srun command."""

args = Namespace(
print_command=False,
smp=False,
gpu=False,
mpi=False,
invest=False,
htc=False,
teach=False,
partition=None,
mem=1,
time=time(1, 0),
num_nodes=1,
num_cores=1,
num_gpus=0,
account=None,
reservation=None,
license=None,
feature=None,
openmp=False
)

expected_command = 'srun -M smp --export=ALL --mem=1g --time=01:00:00 --nodes=1 --ntasks-per-node=1 --pty bash'
actual_command = self.parser.create_srun_command(args)
self.assertEqual(expected_command, actual_command)

def test_gpu_command(self) -> None:
"""Test srun command for GPU."""

args = Namespace(
print_command=False,
smp=False,
gpu=True,
mpi=False,
invest=False,
htc=False,
teach=False,
partition=None,
mem=2,
time=time(2, 0),
num_nodes=2,
num_cores=4,
num_gpus=1,
account=None,
reservation=None,
license=None,
feature=None,
openmp=False
)

expected_command = 'srun -M gpu --export=ALL --mem=2g --time=02:00:00 --nodes=2 --ntasks-per-node=4 --gres=gpu:1 --pty bash'
actual_command = self.parser.create_srun_command(args)
self.assertEqual(expected_command, actual_command)

def test_mpi_command(self) -> None:
"""Test srun command for MPI."""

args = Namespace(
print_command=False,
smp=False,
gpu=False,
mpi=True,
invest=False,
htc=False,
teach=False,
partition='mpi',
mem=4,
time=time(3, 0),
num_nodes=3,
num_cores=48,
num_gpus=0,
account=None,
reservation=None,
license=None,
feature=None,
openmp=False
)

expected_command = 'srun -M mpi --export=ALL --mem=4g --time=03:00:00 --nodes=3 --ntasks-per-node=48 --pty bash'
actual_command = self.parser.create_srun_command(args)
self.assertEqual(expected_command, actual_command)

def test_invest_command(self) -> None:
"""Test srun command for the invest cluster."""

args = Namespace(
print_command=False,
smp=False,
gpu=False,
mpi=False,
invest=True,
htc=False,
teach=False,
partition='invest-partition',
mem=2,
time=time(1, 0),
num_nodes=1,
num_cores=4,
num_gpus=0,
account=None,
reservation=None,
license=None,
feature=None,
openmp=False
)

expected_command = 'srun -M invest --export=ALL --mem=2g --time=01:00:00 --nodes=1 --ntasks-per-node=4 --pty bash'
actual_command = self.parser.create_srun_command(args)
self.assertEqual(expected_command, actual_command)

def test_openmp_command(self) -> None:
"""Test srun command for OpenMP."""

args = Namespace(
print_command=False,
smp=False,
gpu=False,
mpi=False,
invest=False,
htc=False,
teach=False,
partition=None,
mem=1,
time=time(1, 0),
num_nodes=1,
num_cores=4,
num_gpus=0,
account=None,
reservation=None,
license=None,
feature=None,
openmp=True
)

expected_command = 'srun -M smp --export=ALL --mem=1g --time=01:00:00 --nodes=1 --cpus-per-task=4 --pty bash'
actual_command = self.parser.create_srun_command(args)
self.assertEqual(expected_command, actual_command)

def test_partition_specific_cores(self) -> None:
"""Test srun command with partition-specific core requirements."""

args = Namespace(
print_command=False,
smp=False,
gpu=False,
mpi=True,
invest=False,
htc=False,
teach=False,
partition='opa-high-mem',
mem=8,
time=time(2, 0),
num_nodes=2,
num_cores=28,
num_gpus=0,
account=None,
reservation=None,
license=None,
feature=None,
openmp=False
)

expected_command = 'srun -M mpi --export=ALL --mem=8g --time=02:00:00 --nodes=2 --ntasks-per-node=28 --pty bash'
actual_command = self.parser.create_srun_command(args)
self.assertEqual(expected_command, actual_command)

0 comments on commit 3cb4ae4

Please sign in to comment.