Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scale factors across plate dims in partial_sum_product #606

Merged
merged 16 commits into from
Aug 31, 2023
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.8]
python-version: [3.9]
env:
CI: 1
FUNSOR_BACKEND: jax
Expand Down
4 changes: 4 additions & 0 deletions funsor/ops/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .op import (
BINARY_INVERSES,
DISTRIBUTIVE_OPS,
PRODUCT_TO_POWER,
SAFE_BINARY_INVERSES,
UNARY_INVERSES,
UNITS,
Expand Down Expand Up @@ -287,6 +288,9 @@ def sigmoid_log_abs_det_jacobian(x, y):
UNARY_INVERSES[mul] = reciprocal
UNARY_INVERSES[add] = neg

PRODUCT_TO_POWER[add] = mul
PRODUCT_TO_POWER[mul] = pow

__all__ = [
"AssociativeOp",
"ComparisonOp",
Expand Down
2 changes: 2 additions & 0 deletions funsor/ops/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ def log_abs_det_jacobian(x, y, fn):
BINARY_INVERSES = {} # binary op -> inverse binary op
SAFE_BINARY_INVERSES = {} # binary op -> numerically safe inverse binary op
UNARY_INVERSES = {} # binary op -> inverse unary op
PRODUCT_TO_POWER = {} # product op -> power op

__all__ = [
"BINARY_INVERSES",
Expand All @@ -430,6 +431,7 @@ def log_abs_det_jacobian(x, y, fn):
"LogAbsDetJacobianOp",
"NullaryOp",
"Op",
"PRODUCT_TO_POWER",
"SAFE_BINARY_INVERSES",
"TernaryOp",
"TransformOp",
Expand Down
49 changes: 44 additions & 5 deletions funsor/sum_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from funsor.cnf import Contraction
from funsor.domains import Bint, Reals
from funsor.interpreter import gensym
from funsor.ops import UNITS, AssociativeOp
from funsor.ops import PRODUCT_TO_POWER, UNITS, AssociativeOp
from funsor.terms import (
Cat,
Funsor,
Expand Down Expand Up @@ -203,7 +203,14 @@ def partial_unroll(factors, eliminate=frozenset(), plate_to_step=dict()):


def partial_sum_product(
sum_op, prod_op, factors, eliminate=frozenset(), plates=frozenset(), pedantic=False
sum_op,
prod_op,
fritzo marked this conversation as resolved.
Show resolved Hide resolved
factors,
eliminate=frozenset(),
plates=frozenset(),
pedantic=False,
pow_op=None,
plate_to_scale=None, # dict
):
"""
Performs partial sum-product contraction of a collection of factors.
Expand All @@ -218,6 +225,10 @@ def partial_sum_product(
assert isinstance(eliminate, frozenset)
assert isinstance(plates, frozenset)

if plate_to_scale:
if pow_op is None:
pow_op = PRODUCT_TO_POWER[prod_op]

if pedantic:
var_to_errors = defaultdict(lambda: eliminate)
for f in factors:
Expand Down Expand Up @@ -256,7 +267,17 @@ def partial_sum_product(
f = reduce(prod_op, group_factors).reduce(sum_op, group_vars & eliminate)
remaining_sum_vars = sum_vars.intersection(f.inputs)
if not remaining_sum_vars:
results.append(f.reduce(prod_op, leaf & eliminate))
f = f.reduce(prod_op, leaf & eliminate)
if plate_to_scale:
f_scales = [
plate_to_scale[plate]
for plate in leaf & eliminate
if plate in plate_to_scale
]
if f_scales:
scale = reduce(ops.mul, f_scales)
f = pow_op(f, scale)
results.append(f)
else:
new_plates = frozenset().union(
*(var_to_ordinal[v] for v in remaining_sum_vars)
Expand Down Expand Up @@ -306,6 +327,15 @@ def partial_sum_product(
reduced_plates = leaf - new_plates
assert reduced_plates.issubset(eliminate)
f = f.reduce(prod_op, reduced_plates)
if plate_to_scale:
f_scales = [
plate_to_scale[plate]
for plate in reduced_plates
if plate in plate_to_scale
]
if f_scales:
scale = reduce(ops.mul, f_scales)
f = pow_op(f, scale)
ordinal_to_factors[new_plates].append(f)

return results
Expand Down Expand Up @@ -571,15 +601,24 @@ def modified_partial_sum_product(


def sum_product(
sum_op, prod_op, factors, eliminate=frozenset(), plates=frozenset(), pedantic=False
sum_op,
prod_op,
factors,
eliminate=frozenset(),
plates=frozenset(),
pedantic=False,
pow_op=None,
plate_to_scale=None, # dict
):
"""
Performs sum-product contraction of a collection of factors.

:return: a single contracted Funsor.
:rtype: :class:`~funsor.terms.Funsor`
"""
factors = partial_sum_product(sum_op, prod_op, factors, eliminate, plates, pedantic)
factors = partial_sum_product(
sum_op, prod_op, factors, eliminate, plates, pedantic, pow_op, plate_to_scale
)
return reduce(prod_op, factors, Number(UNITS[prod_op]))


Expand Down
4 changes: 2 additions & 2 deletions funsor/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def id_from_inputs(inputs):

@dispatch(object, object, Variadic[float])
def allclose(a, b, rtol=1e-05, atol=1e-08):
if type(a) != type(b):
if type(a) is not type(b):
return False
return ops.abs(a - b) < rtol + atol * ops.abs(b)

Expand Down Expand Up @@ -125,7 +125,7 @@ def assert_close(actual, expected, atol=1e-6, rtol=1e-6):
elif isinstance(actual, Gaussian):
assert isinstance(expected, Gaussian)
else:
assert type(actual) == type(expected), msg
assert type(actual) is type(expected), msg

if isinstance(actual, Funsor):
assert isinstance(expected, Funsor), msg
Expand Down
96 changes: 95 additions & 1 deletion test/test_sum_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
sum_product,
)
from funsor.tensor import Tensor, get_default_prototype
from funsor.terms import Variable
from funsor.terms import Cat, Variable
from funsor.testing import assert_close, random_gaussian, random_tensor
from funsor.util import get_backend

Expand Down Expand Up @@ -2899,3 +2899,97 @@ def test_mixed_sequential_sum_product(duration, num_segments):
)

assert_close(actual, expected)


@pytest.mark.parametrize(
"sum_op,prod_op",
[(ops.logaddexp, ops.add), (ops.add, ops.mul)],
)
@pytest.mark.parametrize("scale", [1, 2])
def test_partial_sum_product_scale_1(sum_op, prod_op, scale):
f1 = random_tensor(OrderedDict(a=Bint[2]))
f2 = random_tensor(OrderedDict(a=Bint[2], i=Bint[3]))

eliminate = frozenset("ai")
plates = frozenset("i")

# Actual result based on applying scaling
factors = [f1, f2]
scales = {"i": scale}
actual = sum_product(
sum_op, prod_op, factors, eliminate, plates, plate_to_scale=scales
)

# Expected result based on concatenating factors
f3 = Cat("i", (f2,) * scale)
factors = [f1, f3]
expected = sum_product(sum_op, prod_op, factors, eliminate, plates)

assert_close(actual, expected, atol=1e-4, rtol=1e-4)


@pytest.mark.parametrize(
"sum_op,prod_op",
[(ops.logaddexp, ops.add), (ops.add, ops.mul)],
)
@pytest.mark.parametrize("scale_i", [1, 2])
@pytest.mark.parametrize("scale_j", [1, 3])
def test_partial_sum_product_scale_2(sum_op, prod_op, scale_i, scale_j):
f1 = random_tensor(OrderedDict(a=Bint[2]))
f2 = random_tensor(OrderedDict(a=Bint[2], i=Bint[3]))
f3 = random_tensor(OrderedDict(a=Bint[2], j=Bint[4]))

eliminate = frozenset("aij")
plates = frozenset("ij")

# Actual result based on applying scaling
factors = [f1, f2, f3]
scales = {"i": scale_i, "j": scale_j}
actual = sum_product(
sum_op, prod_op, factors, eliminate, plates, plate_to_scale=scales
)

# Expected result based on concatenating factors
f4 = Cat("i", (f2,) * scale_i)
f5 = Cat("j", (f3,) * scale_j)
factors = [f1, f4, f5]
expected = sum_product(sum_op, prod_op, factors, eliminate, plates)

assert_close(actual, expected, atol=1e-4, rtol=1e-4)


@pytest.mark.parametrize(
"sum_op,prod_op",
[(ops.logaddexp, ops.add), (ops.add, ops.mul)],
)
@pytest.mark.parametrize("scale_i", [1, 2])
@pytest.mark.parametrize("scale_j", [1, 3])
@pytest.mark.parametrize("scale_k", [1, 4])
def test_partial_sum_product_scale_3(sum_op, prod_op, scale_i, scale_j, scale_k):
f1 = random_tensor(OrderedDict(a=Bint[2], i=Bint[2]))
f2 = random_tensor(OrderedDict(a=Bint[2], i=Bint[2], j=Bint[3]))
f3 = random_tensor(OrderedDict(a=Bint[2], i=Bint[2], j=Bint[3], k=Bint[3]))

eliminate = frozenset("aijk")
plates = frozenset("ijk")

# Actual result based on applying scaling
factors = [f1, f2, f3]
scales = {"i": scale_i, "j": scale_j, "k": scale_k}
actual = sum_product(
sum_op, prod_op, factors, eliminate, plates, plate_to_scale=scales
)

# Expected result based on concatenating factors
f4 = Cat("i", (f1,) * scale_i)
# concatenate across multiple dims
f5 = Cat("i", (f2,) * scale_i)
f5 = Cat("j", (f5,) * scale_j)
# concatenate across multiple dims
f6 = Cat("i", (f3,) * scale_i)
f6 = Cat("j", (f6,) * scale_j)
f6 = Cat("k", (f6,) * scale_k)
factors = [f4, f5, f6]
expected = sum_product(sum_op, prod_op, factors, eliminate, plates)

assert_close(actual, expected, atol=1e-4, rtol=1e-4)
4 changes: 2 additions & 2 deletions test/test_terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_to_funsor_error(x):
def test_to_data():
actual = to_data(Number(0.0))
expected = 0.0
assert type(actual) == type(expected)
assert type(actual) is type(expected)
assert actual == expected


Expand Down Expand Up @@ -569,7 +569,7 @@ def test_stack_slice(start, stop, step):
xs = tuple(map(Number, range(10)))
actual = Stack("i", xs)(i=Slice("j", start, stop, step, dtype=10))
expected = Stack("j", xs[start:stop:step])
assert type(actual) == type(expected)
assert type(actual) is type(expected)
assert actual.name == expected.name
assert actual.parts == expected.parts

Expand Down