Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the test accuracy #13

Open
wants to merge 8 commits into
base: dev_jax
Choose a base branch
from
Open

Fix the test accuracy #13

wants to merge 8 commits into from

Conversation

sebasrb09
Copy link
Collaborator

Worked on:

  1. Fixing the accuracy test calculation
  2. Remove cyclical call in models.py
  3. Add partial jit to compile the per example gradients, so it can have the num_classes parameter
  4. Minor bugs


accs = []
num_splits = int(len(test_images)/batch_size)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will break in some cases.
E.g., if len(test_images) = 100 and batch_size = 9 then 100/9=11.11 but int(11.11) = 11.
Then your actual batch_size will be larger than 9 in one of the parts (10) to be precise.
I think math.ceil(len(test_images)/batch_size)) is saver.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I now saw the lines above but I don't think they are necessary if you do the math.ceil()


for pb, yb in zip(test_data, test_labels):
for pb, yb in zip(splits_test, splits_test_labels):
pb = jax.device_put(pb, jax.devices("gpu")[0])
yb = jax.device_put(yb, jax.devices("gpu")[0])
# TODO: This won't be correct when len(pb) not the same for all pb in test_data.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove TODO if fixed please!

src/example.py Outdated
@@ -219,7 +220,7 @@ def body_fun(t, args):

print(actual_batch_size / duration, flush=True)

acc_iter = model_evaluation(state, splits_test, splits_labels)
acc_iter = model_evaluation(state, test_images, test_labels,test_bs_size)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please use a formatter, there should be usually a space here!

src/example.py Outdated
@@ -232,7 +233,7 @@ def body_fun(t, args):
privacy_results = {"eps_rdp": epsilon, "delta_rdp": delta}
print(privacy_results, flush=True)

acc_last = model_evaluation(state, splits_test, splits_labels)
acc_last = model_evaluation(state, test_images, test_labels,test_bs_size)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here (formatter)

@@ -111,7 +113,7 @@ def setup_physical_batches(
return masks, n_physical_batches


@jax.jit
@partial(jax.jit,static_argnums=(3,))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this would be something to check with @jjalko . I have no idea if this is how to do it.

Comment on lines 306 to 309
if diff != 0:
batch_size = batch_size - diff
warnings.warn(f'The batch size does not divide the size of the test set, fixed the new batch size to {batch_size}')

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see the intent but I think switching to the math.ceil() above should yield the same results without anything bad happening.

@sebasrb09
Copy link
Collaborator Author

The ceil is not correct because we are looking for the number of splits, and the test size must be divisible by it. With ceil we have the same problem. If we have math.ceil(100/9) = 12, but 100 % 12 != 0.

… is no need for the exact division of the test set
@sebasrb09 sebasrb09 requested a review from Solosneros November 5, 2024 13:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants