Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scale factors across plate dims in partial_sum_product #606

Merged
merged 16 commits into from
Aug 31, 2023

Conversation

ordabayevy
Copy link
Member

@ordabayevy ordabayevy commented Feb 6, 2023

One of the features not supported by TraceEnum_ELBO is that you cannot subsample a local variable when it depends on a global variable that is enumerated in the model because it requires a common scale:

@config_enumerate
def model(data):
    # Global variables.
    locs = torch.tensor([1., 10.])
    assignment = pyro.sample('assignment', dist.Categorical(torch.ones(2)))
    with pyro.plate('data', len(data), subsample_size=2) as ind:  # cannot subsample here
        # Local variables.
        pyro.sample('obs', dist.Normal(locs[assignment], 1.), obs=data[ind])

def guide(data):
    pass

This has been asked on the forum as well: https://forum.pyro.ai/t/enumeration-and-subsampling-expected-all-enumerated-sample-sites-to-share-common-poutine-scale/4938

A solution I'm proposing here is to perform plate-wise scaling inside the partial_sum_product by passing the plate_to_scale dictionary. Then whenever a plate is reduced we can scale the factor:

# inside partial_sum_product
f = f.reduce(prod_op, reduced_plates)
f_scales = [plate_to_scale[plate] for plate in reduced_plates if plate in plate_to_scale]
if f_scales:
    scale = reduce(ops.mul, f_scales)
    f = pow_op(f, scale)
  • Added tests

@ordabayevy
Copy link
Member Author

@fritzo @eb8680 can you have a look at this? Do changes to partial_sum_product look reasonable to you?

@ordabayevy
Copy link
Member Author

@eb8680 I opened a PR in NumPyro (pyro-ppl/numpyro#1572) where I tried to expand the math and check whether expectations match for a full- and a mini-batch log-likelihoods. Can you verify that the math is correct?

elif sum_op is ops.add and prod_op is ops.mul:
pow_op = ops.pow
else:
raise ValueError("should not be here!")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be NotImplementedError

Comment on lines 228 to 234
if plate_to_scale:
if sum_op is ops.logaddexp and prod_op is ops.add:
pow_op = ops.mul
elif sum_op is ops.add and prod_op is ops.mul:
pow_op = ops.pow
else:
raise ValueError("should not be here!")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move this out to a PROD_TO_POWER dict or similar, see ops

funsor/sum_product.py Show resolved Hide resolved
Comment on lines 274 to 282
f = f.reduce(prod_op, leaf & eliminate)
f_scales = [
plate_to_scale[plate]
for plate in leaf & eliminate
if plate in plate_to_scale
]
if f_scales:
scale = reduce(ops.mul, f_scales)
f = pow_op(f, scale)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we wrap this in an if plate_to_scale: guard to improve readability?

@@ -306,6 +330,14 @@ def partial_sum_product(
reduced_plates = leaf - new_plates
assert reduced_plates.issubset(eliminate)
f = f.reduce(prod_op, reduced_plates)
f_scales = [
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto: if plate_to_scale: ...

eliminate=frozenset(),
plates=frozenset(),
pedantic=False,
plate_to_scale={},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

default to None, maybe comment on datatype

@ordabayevy
Copy link
Member Author

@fritzo can you review the changes please? Also I believe the rule for jax tests needs to be changed in the Settings to use python 3.9 instead of 3.8. Jax doesn't support 3.8 anymore.

Copy link
Member

@fritzo fritzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for addressing my comments.

@fritzo fritzo merged commit 349c038 into master Aug 31, 2023
@fritzo fritzo deleted the sum-product-scale branch August 31, 2023 17:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants