Skip to content

Commit

Permalink
Trying to restrict samples per category
Browse files Browse the repository at this point in the history
  • Loading branch information
kks32 committed Jul 14, 2024
1 parent c091420 commit d241081
Showing 1 changed file with 33 additions and 32 deletions.
65 changes: 33 additions & 32 deletions gns/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import re
import sys
import random
from collections import defaultdict

import numpy as np
import torch
Expand Down Expand Up @@ -422,51 +423,44 @@ def prepare_data(example, device_id):

return position, particle_type, material_property, n_particles_per_example, labels

import torch
import numpy as np
from collections import defaultdict

def get_friction_groups(train_dl, device, max_samples_per_group=100):
    """Group examples by friction angle and summarise each group.

    Every batch yielded by ``train_dl`` is expected to expose its material
    property tensor at ``batch[0][2]``. Each distinct value found in that
    tensor is treated as one friction group.

    Args:
        train_dl (DataLoader): The training data loader (any iterable of
            batches works).
        device (torch.device): The device on which the unique-value
            computation is performed.
        max_samples_per_group (int): Maximum number of sample indices to
            store per group. The total count is always tracked in full.

    Returns:
        dict: Maps each friction value (Python scalar) to a two-item list
        ``[count, sample_indices]``. ``count`` is the total number of
        matching entries seen, and ``sample_indices`` is a list of
        ``(batch_idx, index_within_batch)`` tuples holding at most
        ``max_samples_per_group`` entries.
    """
    friction_groups = defaultdict(lambda: [0, []])

    for batch_idx, batch in enumerate(train_dl):
        # Move only the material_property to device
        material_property = batch[0][2].to(device)

        # One call yields the distinct values, their counts, and for every
        # element the position of its value in `frictions`. Using that
        # inverse index avoids re-finding members by float equality.
        frictions, inverse, counts = torch.unique(
            material_property, return_inverse=True, return_counts=True
        )

        # Transfer to host once per batch and work with native Python
        # scalars so no NumPy/torch scalar types leak into the result.
        frictions = frictions.cpu().tolist()
        counts = counts.cpu().tolist()
        inverse = inverse.cpu().numpy()

        for group_id, (friction, count) in enumerate(zip(frictions, counts)):
            group = friction_groups[friction]
            group[0] += count

            # Only collect sample indices while the group has room left.
            remaining = max_samples_per_group - len(group[1])
            if remaining <= 0:
                continue

            # Leading-axis indices of this friction value in the batch.
            indices = np.nonzero(inverse == group_id)[0]

            # Randomly select indices if there are too many
            if len(indices) > remaining:
                indices = np.random.choice(indices, remaining, replace=False)

            group[1].extend((batch_idx, int(idx)) for idx in indices)

    # Convert defaultdict to regular dict
    return dict(friction_groups)


def train(rank, cfg, world_size, device, verbose, use_dist):
"""Train the model using MAML for different friction angles.
Expand Down Expand Up @@ -561,7 +555,14 @@ def train(rank, cfg, world_size, device, verbose, use_dist):
writer = setup_tensorboard(cfg, metadata) if verbose else None

# Group examples by friction angles
friction_groups = group_friction_angles(train_dl, num_tasks, device_id)
friction_groups = get_friction_groups(train_dl, num_tasks, device_id)

friction_groups = get_friction_groups(train_dl, device)
num_tasks = len(friction_groups)

print(f"Number of unique friction values (tasks): {num_tasks}")
for friction, (count, samples) in friction_groups.items():
print(f"Friction {friction:.4f}: {count} examples, {len(samples)} sample indices")

try:
num_epochs = max(1, (cfg.training.steps + len(friction_groups) - 1) // len(friction_groups))
Expand Down

0 comments on commit d241081

Please sign in to comment.