diff --git a/apps/crc_interactive.py b/apps/crc_interactive.py index 84b382a..35fcb8e 100755 --- a/apps/crc_interactive.py +++ b/apps/crc_interactive.py @@ -116,7 +116,7 @@ def parse_args(self, args=None, namespace=None) -> Namespace: # Set defaults that need to be determined dynamically if not args.num_gpus: - args.num_gpus = 1 if args.gpu else 0 + args.num_gpus = 1 if (args.gpu or (args.teach and (args.partition == 'gpu'))) else 0 # Check wall time is between limits, enable both %H:%M format and integer hours check_time = args.time.hour + args.time.minute / 60 + args.time.second / 3600 @@ -171,7 +171,7 @@ def create_srun_command(self, args: Namespace) -> str: srun_args += ' ' + srun_arg_name.format(arg_value) # The --gres argument in srun needs some special handling so is missing from the above dict - if (args.gpu or args.invest) and args.num_gpus: + if (args.gpu or args.invest or (args.teach and (args.partition == 'gpu'))) and args.num_gpus: srun_args += ' ' + f'--gres=gpu:{args.num_gpus}' try: