Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update main.py #1

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions nvas3d/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,26 @@ def setup_distributed_training(local_rank):
world_size=world_size)
device = torch.device(f'cuda:{local_rank}')
return device

def validate_args(args): # New Function to Validate Input Parameters
"""
Validates the input arguments to ensure the configuration file exists and the GPU ID is valid.

This function performs two primary checks:
1. It verifies the existence of the specified configuration file.
2. It checks whether the provided GPU ID is within the valid range of available GPUs.

Parameters:
args (argparse.Namespace): Parsed command-line arguments containing the paths and GPU settings.

Raises:
FileNotFoundError: If the specified configuration file does not exist.
ValueError: If the specified GPU ID is not within the valid range.
"""
if not os.path.isfile(args.config):
raise FileNotFoundError(f"Configuration file not found: {args.config}")
if args.gpu is not None and (args.gpu < 0 or args.gpu >= torch.cuda.device_count()):
raise ValueError(f"Invalid GPU ID {args.gpu}. Must be between 0 and {torch.cuda.device_count() - 1}.")


def main(local_rank, args):
Expand All @@ -44,6 +64,10 @@ def main(local_rank, args):
save_dir = os.path.join(config['save_dir'], f'{args.exp}')
os.makedirs(save_dir, exist_ok=True)
shutil.copy(args.config, f'{save_dir}/config.yaml')
# shutil.copy(args.config, os.path.join(save_dir, 'config.yaml')) # Modify Path Concatenation Method to Improve Readability???

logging.info(f"Configuration file loaded from {args.config}") # Log the successful loading of the configuration file.
logging.info(f"Experiment directory created at {save_dir}") # Log the creation of the experiment directory.

# Initialize DataLoader
data_loader = SSAVDataLoader(config['use_visual'], config['use_deconv'], is_ddp, **config['data_loader'])
Expand All @@ -53,6 +77,7 @@ def main(local_rank, args):
model = model.to(device)
if is_ddp:
model = DistributedDataParallel(model, device_ids=[local_rank])
logging.info("Model initialized and moved to device") # Log the successful initialization and deployment of the model to the specified device.

# Train the model
trainer = Trainer(model, data_loader, device, save_dir, config['use_deconv'], config['training'])
Expand Down Expand Up @@ -84,6 +109,16 @@ def main(local_rank, args):
help='GPU ID to use')

args = parser.parse_args()
"""
Call the validate_args Function within if __name__ == '__main__':,
Capture Exceptions during Parameter Validation, Log Errors, and Exit the Program
"""

try:
validate_args(args)
except (FileNotFoundError, ValueError) as e:
logging.error(e)
exit(1)

if args.gpu is not None:
main(0, args) # Single GPU mode
Expand Down