
Commit

tried to fix parallelization
vganapati committed Jun 14, 2023
1 parent f4cbebb commit 5510f0e
Showing 2 changed files with 28 additions and 6 deletions.
24 changes: 19 additions & 5 deletions main.py
@@ -2,6 +2,8 @@
This must be run in the shell/SLURM before running this script:
For NERSC:
export MASTER_ADDR=$(hostname)
+For interactive session:
+export SLURM_NTASKS=4
For other servers:
export MASTER_ADDR=localhost
@@ -90,8 +92,14 @@ def get_args():
return args


-def setup(rank, world_size, fn, args, backend='gloo'):
+def setup(rank, world_size, fn, args, backend='nccl'):
os.environ['MASTER_PORT'] = '29500'

+# Get the SLURM_PROCID for the current process
+proc_id = int(os.environ['SLURM_PROCID'])

+print("Hello from " + str(proc_id))
+print(get_rank(args.use_dist))
# initialize the process group
dist.init_process_group(backend, rank=rank, world_size=world_size)
fn(rank,world_size, args) # this will be the run function
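
For context, a minimal sketch of the setup path these changes point toward, assuming the job is launched with one SLURM task per rank and that MASTER_ADDR has been exported as described in the docstring above (setup_sketch and its prints are illustrative, not the committed code):

import os
import torch.distributed as dist

def setup_sketch(rank, world_size, fn, args, backend='nccl'):
    # MASTER_ADDR is expected to be exported beforehand; the port is fixed here
    # for simplicity.
    os.environ.setdefault('MASTER_PORT', '29500')

    # Under SLURM, each task carries its own process ID; comparing it with the
    # rank that was passed in is a quick sanity check that the launcher and the
    # script agree about who is who.
    proc_id = int(os.environ.get('SLURM_PROCID', rank))
    print("rank " + str(rank) + " / SLURM_PROCID " + str(proc_id))

    # Join the default process group; 'nccl' expects one GPU per rank,
    # 'gloo' is the CPU/debug fallback.
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    fn(rank, world_size, args)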
@@ -180,6 +188,11 @@ def run(rank, world_size, args,
if args.use_dist:
print("Running on rank " + str(rank) + ".")

+proc_id = int(os.environ['SLURM_PROCID'])

+print("Hello from " + str(proc_id))
+print(get_rank(args.use_dist))

train_set, train_set_2, test_set, bsz = partition_dataset(args, world_size)

# Force num_basis = 1 if not using pde-cl
@@ -407,12 +420,13 @@ def loss_fn(data, u_scatter, data_2):

args = get_args()

-print(str(torch.cuda.device_count()) + " GPUs detected!")
-if (torch.cuda.device_count() > 1) and args.use_dist:
-world_size = torch.cuda.device_count()
+# print(str(torch.cuda.device_count()) + " GPUs detected!")
+if args.use_dist:
+# world_size = torch.cuda.device_count()
+world_size = int(os.environ['SLURM_NTASKS'])
else:
world_size = 1
args.use_dist = False
print('world_size is: ' + str(world_size))

start = time.time()
if args.use_dist:
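A minimal sketch of the world-size logic the new main block relies on, assuming the script is started with one SLURM task per process; use_dist here is a stand-in for args.use_dist, the environment guard is an added safety check, and the fallback mirrors the else branch above:

import os
import torch

use_dist = True  # stand-in for args.use_dist

if use_dist and 'SLURM_NTASKS' in os.environ:
    # Trust the scheduler: SLURM_NTASKS is the number of tasks the job was
    # launched with, so one task maps to one rank regardless of how many GPUs
    # torch.cuda.device_count() happens to see from this process.
    world_size = int(os.environ['SLURM_NTASKS'])
else:
    world_size = 1
    use_dist = False

print('world_size is: ' + str(world_size)
      + ', GPUs visible here: ' + str(torch.cuda.device_count()))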
10 changes: 9 additions & 1 deletion notes.md
@@ -23,7 +23,15 @@ Submitted batch job 10212615
Submitted batch job 10212616
(PINN) vidyagan@login31:/pscratch/sd/v/vidyagan/output_PINN> sbatch /pscratch/sd/v/vidyagan/PINN/slurm_train.sh 2048
Submitted batch job 10212617

+Memory error for these:
+(PINN) vidyagan@login31:/pscratch/sd/v/vidyagan/output_PINN> sbatch /pscratch/sd/v/vidyagan/PINN/slurm_train.sh 4096
+Submitted batch job 10212619
+(PINN) vidyagan@login31:/pscratch/sd/v/vidyagan/output_PINN> sbatch /pscratch/sd/v/vidyagan/PINN/slurm_train.sh 8192
+Submitted batch job 10212621

+All of the above jobs could only detect 1 GPU for some reason.

+Modified code to distribute with NERSC:
+> sbatch /pscratch/sd/v/vidyagan/PINN/slurm_train.sh 8192
+Submitted batch job 10217467
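
One way to narrow down the single-GPU symptom noted above is a per-task visibility check; this is a diagnostic sketch, not part of the commit, and it only reads standard SLURM and CUDA environment variables:

import os
import torch

# If the batch script binds each task to a single device, CUDA_VISIBLE_DEVICES
# is narrowed before Python starts and device_count() reports 1 even on a
# multi-GPU node, which would explain the behaviour logged above.
print("SLURM_PROCID:", os.environ.get("SLURM_PROCID"))
print("CUDA_VISIBLE_DEVICES:", os.environ.get("CUDA_VISIBLE_DEVICES"))
print("GPUs visible to this task:", torch.cuda.device_count())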
