-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix the test accuracy #13
base: dev_jax
Are you sure you want to change the base?
Conversation
src/jax_mask_efficient.py
Outdated
|
||
accs = [] | ||
num_splits = int(len(test_images)/batch_size) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This will break in some cases.
E.g., if len(test_images) = 100 and batch_size = 9 then 100/9=11.11 but int(11.11) = 11.
Then your actual batch_size will be larger than 9 in one of the parts (10) to be precise.
I think math.ceil(len(test_images)/batch_size)) is saver.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I now saw the lines above but I don't think they are necessary if you do the math.ceil()
src/jax_mask_efficient.py
Outdated
|
||
for pb, yb in zip(test_data, test_labels): | ||
for pb, yb in zip(splits_test, splits_test_labels): | ||
pb = jax.device_put(pb, jax.devices("gpu")[0]) | ||
yb = jax.device_put(yb, jax.devices("gpu")[0]) | ||
# TODO: This won't be correct when len(pb) not the same for all pb in test_data. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove TODO if fixed please!
src/example.py
Outdated
@@ -219,7 +220,7 @@ def body_fun(t, args): | |||
|
|||
print(actual_batch_size / duration, flush=True) | |||
|
|||
acc_iter = model_evaluation(state, splits_test, splits_labels) | |||
acc_iter = model_evaluation(state, test_images, test_labels,test_bs_size) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please use a formatter, there should be usually a space here!
src/example.py
Outdated
@@ -232,7 +233,7 @@ def body_fun(t, args): | |||
privacy_results = {"eps_rdp": epsilon, "delta_rdp": delta} | |||
print(privacy_results, flush=True) | |||
|
|||
acc_last = model_evaluation(state, splits_test, splits_labels) | |||
acc_last = model_evaluation(state, test_images, test_labels,test_bs_size) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same here (formatter)
@@ -111,7 +113,7 @@ def setup_physical_batches( | |||
return masks, n_physical_batches | |||
|
|||
|
|||
@jax.jit | |||
@partial(jax.jit,static_argnums=(3,)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this would be something to check with @jjalko . I have no idea if this is how to do it.
src/jax_mask_efficient.py
Outdated
if diff != 0: | ||
batch_size = batch_size - diff | ||
warnings.warn(f'The batch size does not divide the size of the test set, fixed the new batch size to {batch_size}') | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see the intent but I think switching to the math.ceil() above should yield the same results without anything bad happening.
The ceil is not correct because we are looking for the number of splits, and the test size must be divisible by it. With ceil we have the same problem. If we have math.ceil(100/9) = 12, but 100 % 12 != 0. |
… is no need for the exact division of the test set
Worked on: