# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
import re
from collections import OrderedDict, defaultdict
from functools import reduce
from math import gcd
import funsor
import funsor.ops as ops
from funsor.cnf import Contraction
from funsor.domains import Bint, Reals
from funsor.interpreter import gensym
from funsor.ops import PRODUCT_TO_POWER, UNITS, AssociativeOp
from funsor.terms import (
Cat,
Funsor,
FunsorMeta,
Number,
Slice,
Stack,
Subs,
Variable,
eager,
substitute,
to_funsor,
)
from funsor.util import quote
def _partition(terms, sum_vars):
# Construct a bipartite graph between terms and the vars
neighbors = OrderedDict([(t, []) for t in terms])
for term in terms:
for dim in term.inputs.keys():
if dim in sum_vars:
neighbors[term].append(dim)
neighbors.setdefault(dim, []).append(term)
# Partition the bipartite graph into connected components for contraction.
components = []
while neighbors:
v, pending = neighbors.popitem()
component = OrderedDict([(v, None)]) # used as an OrderedSet
for v in pending:
component[v] = None
while pending:
v = pending.pop()
for v in neighbors.pop(v):
if v not in component:
component[v] = None
pending.append(v)
# Split this connected component into tensors and dims.
component_terms = tuple(v for v in component if isinstance(v, Funsor))
if component_terms:
component_dims = frozenset(
v for v in component if not isinstance(v, Funsor)
)
components.append((component_terms, component_dims))
return components
def _unroll_plate(factors, var_to_ordinal, sum_vars, plate, step):
# size of the plate
size = next(iter(f.inputs[plate].size for f in factors if plate in f.inputs))
# history of the plate
history = 1 if step else 0
# replicated variables
plate_vars = set()
for var, ordinal in var_to_ordinal.items():
if plate in ordinal:
plate_vars.add(var)
# make sure that all vars in the plate are being unrolled
assert plate_vars.issubset(sum_vars)
# unroll variables
for var in plate_vars:
sum_vars -= frozenset({var})
if var in step.keys():
new_var = frozenset(
{"{}_{}".format(var.split("_")[0], i) for i in range(size)}
)
elif var in step.values():
new_var = frozenset(
{"{}_{}".format(var.split("_")[0], i + history) for i in range(size)}
)
else:
new_var = frozenset({"{}_{}".format(var, i + history) for i in range(size)})
sum_vars |= new_var
ordinal = var_to_ordinal.pop(var)
new_ordinal = ordinal.difference({plate})
var_to_ordinal.update({v: new_ordinal for v in new_var})
# unroll factors
unrolled_factors = []
for factor in factors:
if plate in factor.inputs:
f_vars = plate_vars.intersection(factor.inputs)
prev_to_var = {key: key.split("_")[0] for key in step.keys()}
curr_to_var = {value: value.split("_")[0] for value in step.values()}
nonmarkov_vars = f_vars - set(step.keys()) - set(step.values())
unrolled_factors.extend(
[
factor(
**{plate: i},
**{
var: "{}_{}".format(var, i + history)
for var in nonmarkov_vars
},
**{
curr: "{}_{}".format(var, i + history)
for curr, var in curr_to_var.items()
},
**{
prev: "{}_{}".format(var, i)
for prev, var in prev_to_var.items()
},
)
for i in range(size)
]
)
else:
unrolled_factors.append(factor)
return unrolled_factors, var_to_ordinal, sum_vars
def partial_unroll(factors, eliminate=frozenset(), plate_to_step=None):
    """
    Performs partial unrolling of plated factor graphs to standard factor graphs.
    Only plates with history={0, 1} are supported.

    For plates (history=0) unrolling operation appends ``_{i}`` suffix
    to variable names for index ``i`` in the plate (e.g., "x"->"x_0" for i=0).
    For markov dimensions (history=1) unrolling operation renames the suffixes
    ``var_prev`` to ``var_{i}`` and ``var_curr`` to ``var_{i+1}`` for index ``i``
    (e.g., "x_prev"->"x_0" and "x_curr"->"x_1" for i=0).
    Markov vars are assumed to have names that follow ``var_suffix`` formatting
    and specifically ``var_0`` for the initial factor (e.g.,
    ``("x_0", "x_prev", "x_curr")`` for history=1).

    :param factors: A collection of funsors.
    :type factors: tuple or list
    :param frozenset eliminate: A set of free variables to unroll,
        including both sum variables and product variables.
    :param dict plate_to_step: A dict mapping markov dimensions to
        ``step`` collections that contain ordered sequences of Markov variable names
        (e.g., ``{"time": frozenset({("x_0", "x_prev", "x_curr")})}``).
        Plates are passed with an empty ``step``. Defaults to no plates.
    :return: a list of partially unrolled Funsors,
        a frozenset of partially unrolled variable names,
        and a dict mapping each remaining (not unrolled) plate to its
        ``{prev: curr}`` step dict.
    """
    if plate_to_step is None:
        plate_to_step = {}
    assert isinstance(factors, (tuple, list))
    assert all(isinstance(f, Funsor) for f in factors)
    assert isinstance(eliminate, frozenset)
    assert isinstance(plate_to_step, dict)
    assert all(
        len(set(var.split("_")[0] for var in chain)) == 1 and chain[0].endswith("_0")
        for step in plate_to_step.values()
        if step
        for chain in step
    )

    # Convert each step to a dict e.g. {"x_prev": "x_curr"}; this is specific
    # to history=1 chains of the form (init, prev, curr).
    plate_to_step = {
        plate: {chain[1]: chain[2] for chain in step}
        for plate, step in plate_to_step.items()
    }
    plates = frozenset(plate_to_step.keys())
    sum_vars = eliminate - plates
    unrolled_plates = {k: v for (k, v) in plate_to_step.items() if k in eliminate}
    remaining_plates = {k: v for (k, v) in plate_to_step.items() if k not in eliminate}

    # Each non-plate var lives in the plates shared by all factors using it.
    var_to_ordinal = {}
    for f in factors:
        ordinal = plates.intersection(f.inputs)
        for var in set(f.inputs) - plates:
            var_to_ordinal[var] = var_to_ordinal.get(var, ordinal) & ordinal

    # first unroll plates with history=1 and highest ordinal
    # then unroll plates with history=0
    plate_to_order = {}
    for plate, step in unrolled_plates.items():
        if step:
            plate_to_order[plate] = max(len(var_to_ordinal[s]) for s in step)
        else:
            plate_to_order[plate] = 0

    # unroll one plate at a time
    for plate in sorted(
        unrolled_plates, key=lambda p: plate_to_order[p], reverse=True
    ):
        factors, var_to_ordinal, sum_vars = _unroll_plate(
            factors, var_to_ordinal, sum_vars, plate, unrolled_plates[plate]
        )

    return factors, sum_vars, remaining_plates
def _scale_by_plates(f, plates, plate_to_scale, pow_op):
    """
    Apply plate scaling to a factor that has just been product-reduced.

    ``f`` is raised via ``pow_op`` to the product of ``plate_to_scale[p]`` over
    the plates ``p`` in ``plates`` that have an entry in ``plate_to_scale``.
    ``f`` is returned unchanged if ``plate_to_scale`` is empty/None or none of
    ``plates`` is scaled.
    """
    if not plate_to_scale:
        return f
    f_scales = [plate_to_scale[plate] for plate in plates if plate in plate_to_scale]
    if f_scales:
        scale = reduce(ops.mul, f_scales)
        f = pow_op(f, scale)
    return f


def partial_sum_product(
    sum_op,
    prod_op,
    factors,
    eliminate=frozenset(),
    plates=frozenset(),
    pedantic=False,
    pow_op=None,
    plate_to_scale=None,  # dict
):
    """
    Performs partial sum-product contraction of a collection of factors.

    :param ~funsor.ops.AssociativeOp sum_op: A semiring sum operation.
    :param ~funsor.ops.AssociativeOp prod_op: A semiring product operation.
    :param factors: A collection of funsors.
    :type factors: tuple or list
    :param frozenset eliminate: A set of free variables to eliminate,
        including both sum variables and product variables (plates).
    :param frozenset plates: The set of input names that are plates.
    :param bool pedantic: If true, raise ``ValueError`` when an eliminated plate
        contains a variable that is not being eliminated.
    :param pow_op: Operation raising a factor to a plate scale. Defaults to
        ``PRODUCT_TO_POWER[prod_op]`` when ``plate_to_scale`` is given.
    :param dict plate_to_scale: Optional mapping from plate name to scale; every
        product-reduction over a scaled plate raises the reduced factor to that
        scale via ``pow_op``.
    :return: a list of partially contracted Funsors.
    :rtype: list
    :raises ValueError: if ``pedantic`` checking fails, or if a non-real
        variable would have to be eliminated across a plate ("intractable!").
    """
    assert callable(sum_op)
    assert callable(prod_op)
    assert isinstance(factors, (tuple, list))
    assert all(isinstance(f, Funsor) for f in factors)
    assert isinstance(eliminate, frozenset)
    assert isinstance(plates, frozenset)
    if plate_to_scale:
        if pow_op is None:
            pow_op = PRODUCT_TO_POWER[prod_op]

    if pedantic:
        # For each preserved var, keep the eliminated plates that contain it in
        # every factor where it appears; eliminating any of those is an error.
        var_to_errors = defaultdict(lambda: eliminate)
        for f in factors:
            ordinal = plates.intersection(f.inputs)
            for var in set(f.inputs) - plates - eliminate:
                var_to_errors[var] &= ordinal
        for var, errors in var_to_errors.items():
            for plate in errors:
                raise ValueError(
                    f"Cannot eliminate plate {plate} containing preserved var {var}"
                )

    plates &= eliminate
    sum_vars = eliminate - plates

    # Group factors by their plate set ("ordinal"); each sum var lives in the
    # plates shared by every factor that mentions it.
    var_to_ordinal = {}
    ordinal_to_factors = defaultdict(list)
    for f in factors:
        ordinal = plates.intersection(f.inputs)
        ordinal_to_factors[ordinal].append(f)
        for var in sum_vars.intersection(f.inputs):
            var_to_ordinal[var] = var_to_ordinal.get(var, ordinal) & ordinal

    ordinal_to_vars = defaultdict(set)
    for var, ordinal in var_to_ordinal.items():
        ordinal_to_vars[ordinal].add(var)

    results = []
    while ordinal_to_factors:
        leaf = max(ordinal_to_factors, key=len)  # CHOICE
        leaf_factors = ordinal_to_factors.pop(leaf)
        leaf_reduce_vars = ordinal_to_vars[leaf]
        for group_factors, group_vars in _partition(
            leaf_factors, leaf_reduce_vars
        ):  # CHOICE
            f = reduce(prod_op, group_factors).reduce(sum_op, group_vars & eliminate)
            remaining_sum_vars = sum_vars.intersection(f.inputs)
            if not remaining_sum_vars:
                f = f.reduce(prod_op, leaf & eliminate)
                f = _scale_by_plates(f, leaf & eliminate, plate_to_scale, pow_op)
                results.append(f)
            else:
                new_plates = frozenset().union(
                    *(var_to_ordinal[v] for v in remaining_sum_vars)
                )
                if new_plates == leaf:
                    # Choose the smallest plate to eliminate.
                    plate = min(
                        (f.inputs[plate].size, plate) for plate in leaf & eliminate
                    )[-1]
                    new_plates = leaf - {plate}
                    plate_shape = (f.inputs[plate].size,)
                    # Replace each real var living in ``plate`` by a fresh
                    # batched var indexed by ``plate``.
                    subs = {}
                    for v in remaining_sum_vars:
                        if plate in var_to_ordinal[v]:
                            if f.inputs[v].dtype != "real":
                                raise ValueError("intractable!")
                            v_ = Variable(
                                gensym(v), Reals[plate_shape + f.inputs[v].shape]
                            )
                            v_ordinal = var_to_ordinal[v] - {plate}
                            var_to_ordinal[v_.name] = v_ordinal
                            ordinal_to_vars[v_ordinal].add(v_.name)
                            sum_vars = sum_vars - {v} | {v_.name}
                            eliminate = eliminate - {v} | {v_.name}
                            subs[v] = v_[plate]
                    # This will only work for terms implementing substituting
                    # {var1: ops.getitem(var2, var3)}, e.g. Gaussian but not Tensor.
                    f = f(**subs)
                    for o, gs in list(ordinal_to_factors.items()):
                        if plate not in o:
                            assert all(set(g.inputs).isdisjoint(subs) for g in gs)
                            continue  # nothing to do below

                        remaining = []
                        for g in gs:
                            if set(subs).intersection(g.inputs):
                                g = g(**subs)
                                assert all(
                                    plate not in var_to_ordinal[u]
                                    for u in g.inputs
                                    if u in sum_vars
                                )
                                g = g.reduce(prod_op, plate)
                                # g has been product-reduced over ``plate``
                                # exactly like f below, so it needs the same
                                # plate scale (previously omitted).
                                g = _scale_by_plates(
                                    g, (plate,), plate_to_scale, pow_op
                                )
                                ordinal_to_factors[o - {plate}].append(g)
                            else:
                                remaining.append(g)
                        ordinal_to_factors[o] = remaining

                reduced_plates = leaf - new_plates
                assert reduced_plates.issubset(eliminate)
                f = f.reduce(prod_op, reduced_plates)
                f = _scale_by_plates(f, reduced_plates, plate_to_scale, pow_op)
                ordinal_to_factors[new_plates].append(f)

    return results
def dynamic_partial_sum_product(
    sum_op, prod_op, factors, eliminate=frozenset(), plate_to_step=None
):
    """
    Generalization of the tensor variable elimination algorithm of
    :func:`funsor.sum_product.partial_sum_product` to handle higher-order markov
    dimensions in addition to plate dimensions. Markov dimensions in transition
    factors are eliminated efficiently using the parallel-scan algorithm in
    :func:`funsor.sum_product.sarkka_bilmes_product`. The resulting factors are then
    combined with the initial factors and final states are eliminated. Therefore,
    when Markov dimension is eliminated ``factors`` has to contain
    initial factors and transition factors.

    A Markov chain of history ``h`` is a sequence of ``2 * h + 1`` names:
    ``h`` initial names, then ``h`` lagged names from oldest to most recent,
    then the current name.

    :param ~funsor.ops.AssociativeOp sum_op: A semiring sum operation.
    :param ~funsor.ops.AssociativeOp prod_op: A semiring product operation.
    :param factors: A collection of funsors.
    :type factors: tuple or list
    :param frozenset eliminate: A set of free variables to eliminate,
        including both sum variables and product variables.
    :param dict plate_to_step: A dict mapping markov dimensions to
        ``step`` collections that contain ordered sequences of Markov variable names
        (e.g., ``{"time": frozenset({("x_0", "x_prev", "x_curr")})}``).
        Plates are passed with an empty ``step``. Defaults to no plates.
    :return: a list of partially contracted Funsors.
    :rtype: list
    :raises ValueError: ("intractable!") when a group cannot be eliminated.
    """
    if plate_to_step is None:
        plate_to_step = {}
    assert callable(sum_op)
    assert callable(prod_op)
    assert isinstance(factors, (tuple, list))
    assert all(isinstance(f, Funsor) for f in factors)
    assert isinstance(eliminate, frozenset)
    assert isinstance(plate_to_step, dict)

    # process plate_to_step (read-only, so no defensive copy is needed)
    prev_to_init = {}
    markov_to_sarkka = {}
    markov_sum_vars = set()
    for step in plate_to_step.values():
        for chain in step:
            # map old markov step names to sarkka_bilmes format step names
            # Case 1
            # x_slice(0, 5, None) -> _PREV__PREV_x_slice(2, 7, None)
            # x_slice(1, 6, None) -> _PREV_x_slice(2, 7, None)
            # x_slice(2, 7, None) -> x_slice(2, 7, None)
            # Case 2
            # x_prev - > _PREV_x_curr
            # x_curr -> x_curr
            history = len(chain) // 2
            base_name = chain[-1]
            for t, name in enumerate(reversed(chain[history:-1])):
                markov_to_sarkka[name] = _shift_name(base_name, t + 1)
            markov_sum_vars.add(base_name)
            # map prev to init; works for any history > 0
            init, prev = chain[:history], chain[history:-1]
            prev = tuple(markov_to_sarkka[name] for name in prev)
            prev_to_init.update(zip(prev, init))
    # original (pre-rename) lagged names are markov sum vars too
    markov_sum_vars.update(markov_to_sarkka)
    markov_sum_vars = frozenset(markov_sum_vars)

    plates = frozenset(plate_to_step.keys())
    sum_vars = eliminate - plates
    prod_vars = eliminate.intersection(plates)
    markov_prod_vars = frozenset(
        k for k, v in plate_to_step.items() if v and k in eliminate
    )
    markov_sum_to_prod = defaultdict(set)
    for markov_prod in markov_prod_vars:
        for chain in plate_to_step[markov_prod]:
            for name in chain[len(chain) // 2 :]:
                markov_sum_to_prod[name].add(markov_prod)

    var_to_ordinal = {}
    ordinal_to_factors = defaultdict(list)
    for f in factors:
        ordinal = plates.intersection(f.inputs)
        ordinal_to_factors[ordinal].append(f)
        for var in sum_vars.intersection(f.inputs):
            var_to_ordinal[var] = var_to_ordinal.get(var, ordinal) & ordinal

    ordinal_to_vars = defaultdict(set)
    for var, ordinal in var_to_ordinal.items():
        ordinal_to_vars[ordinal].add(var)

    results = []
    while ordinal_to_factors:
        leaf = max(ordinal_to_factors, key=len)
        leaf_factors = ordinal_to_factors.pop(leaf)
        leaf_reduce_vars = ordinal_to_vars[leaf]
        for group_factors, group_vars in _partition(
            leaf_factors, leaf_reduce_vars | markov_prod_vars
        ):
            # eliminate non markov vars
            nonmarkov_vars = group_vars - markov_sum_vars - markov_prod_vars
            f = reduce(prod_op, group_factors).reduce(sum_op, nonmarkov_vars)
            # eliminate markov vars
            markov_vars = group_vars.intersection(markov_sum_vars)
            if markov_vars:
                markov_prod_var = [markov_sum_to_prod[var] for var in markov_vars]
                assert all(p == markov_prod_var[0] for p in markov_prod_var)
                if len(markov_prod_var[0]) != 1:
                    raise ValueError("intractable!")
                time = next(iter(markov_prod_var[0]))
                for v in sum_vars.intersection(f.inputs):
                    if time in var_to_ordinal[v] and var_to_ordinal[v] < leaf:
                        raise ValueError("intractable!")
                time_var = Variable(time, f.inputs[time])
                # markov_to_sarkka renames variables in MarkovProduct format
                # to sarkka_bilmes_product format
                base_names = markov_vars.intersection(
                    _shift_name(name, -_get_shift(name))
                    for name in markov_to_sarkka.values()
                )
                f = f(**markov_to_sarkka)
                global_vars = frozenset(
                    set(f.inputs)
                    - {time_var.name}
                    - set(markov_to_sarkka.values())
                    - base_names
                )
                with funsor.terms.eager:
                    f = funsor.optimizer.apply_optimizer(f)
                f = sarkka_bilmes_product(sum_op, prod_op, f, time_var, global_vars)
                f = f.reduce(sum_op, base_names)
                f = f(**prev_to_init)

            remaining_sum_vars = sum_vars.intersection(f.inputs)
            if not remaining_sum_vars:
                results.append(f.reduce(prod_op, leaf & prod_vars - markov_prod_vars))
            else:
                new_plates = frozenset().union(
                    *(var_to_ordinal[v] for v in remaining_sum_vars)
                )
                if new_plates == leaf:
                    raise ValueError("intractable!")
                f = f.reduce(prod_op, leaf - new_plates - markov_prod_vars)
                ordinal_to_factors[new_plates].append(f)

    return results
def modified_partial_sum_product(
    sum_op, prod_op, factors, eliminate=frozenset(), plate_to_step=None
):
    """
    Generalization of the tensor variable elimination algorithm of
    :func:`funsor.sum_product.partial_sum_product` to handle markov dimensions
    in addition to plate dimensions. Markov dimensions in transition factors
    are eliminated efficiently using the parallel-scan algorithm in
    :func:`funsor.sum_product.sequential_sum_product`. The resulting factors are then
    combined with the initial factors and final states are eliminated. Therefore,
    when Markov dimension is eliminated ``factors`` has to contain pairs of
    initial factors and transition factors.

    Only history=1 Markov chains ``(init, prev, curr)`` are supported; use
    :func:`dynamic_partial_sum_product` for higher-order chains.

    :param ~funsor.ops.AssociativeOp sum_op: A semiring sum operation.
    :param ~funsor.ops.AssociativeOp prod_op: A semiring product operation.
    :param factors: A collection of funsors.
    :type factors: tuple or list
    :param frozenset eliminate: A set of free variables to eliminate,
        including both sum variables and product variables.
    :param dict plate_to_step: A dict mapping markov dimensions to
        ``step`` collections that contain ordered sequences of Markov variable names
        (e.g., ``{"time": frozenset({("x_0", "x_prev", "x_curr")})}``).
        Plates are passed with an empty ``step``. Defaults to no plates.
    :return: a list of partially contracted Funsors.
    :rtype: list
    :raises ValueError: if a Markov chain does not have exactly 3 names, or
        ("intractable!") when a group cannot be eliminated.
    """
    if plate_to_step is None:
        plate_to_step = {}
    assert callable(sum_op)
    assert callable(prod_op)
    assert isinstance(factors, (tuple, list))
    assert all(isinstance(f, Funsor) for f in factors)
    assert isinstance(eliminate, frozenset)
    assert isinstance(plate_to_step, dict)

    # process plate_to_step: map prev names to init names and convert each step
    # to the {prev: curr} dict format required by MarkovProduct.
    prev_to_init = {}
    markov_steps = {}
    for plate, step in plate_to_step.items():
        for chain in step:
            # {chain[1]: chain[2]} is only meaningful for (init, prev, curr);
            # longer chains would silently treat init names as prev names.
            if len(chain) != 3:
                raise ValueError(
                    "modified_partial_sum_product only supports history=1 Markov "
                    "chains of the form (init, prev, curr); got {}".format(chain)
                )
            prev_to_init[chain[1]] = chain[0]
        markov_steps[plate] = {chain[1]: chain[2] for chain in step}
    plate_to_step = markov_steps

    plates = frozenset(plate_to_step.keys())
    sum_vars = eliminate - plates
    prod_vars = eliminate.intersection(plates)
    markov_sum_vars = frozenset()
    for step in plate_to_step.values():
        markov_sum_vars |= frozenset(step.keys()) | frozenset(step.values())
    markov_sum_vars &= sum_vars
    markov_prod_vars = frozenset(
        k for k, v in plate_to_step.items() if v and k in eliminate
    )
    markov_sum_to_prod = defaultdict(set)
    for markov_prod in markov_prod_vars:
        for k, v in plate_to_step[markov_prod].items():
            markov_sum_to_prod[k].add(markov_prod)
            markov_sum_to_prod[v].add(markov_prod)

    var_to_ordinal = {}
    ordinal_to_factors = defaultdict(list)
    for f in factors:
        ordinal = plates.intersection(f.inputs)
        ordinal_to_factors[ordinal].append(f)
        for var in sum_vars.intersection(f.inputs):
            var_to_ordinal[var] = var_to_ordinal.get(var, ordinal) & ordinal

    ordinal_to_vars = defaultdict(set)
    for var, ordinal in var_to_ordinal.items():
        ordinal_to_vars[ordinal].add(var)

    results = []
    while ordinal_to_factors:
        leaf = max(ordinal_to_factors, key=len)
        leaf_factors = ordinal_to_factors.pop(leaf)
        leaf_reduce_vars = ordinal_to_vars[leaf]
        for group_factors, group_vars in _partition(
            leaf_factors, leaf_reduce_vars | markov_prod_vars
        ):
            # eliminate non markov vars
            nonmarkov_vars = group_vars - markov_sum_vars - markov_prod_vars
            f = reduce(prod_op, group_factors).reduce(sum_op, nonmarkov_vars)
            # eliminate markov vars
            markov_vars = group_vars.intersection(markov_sum_vars)
            if markov_vars:
                markov_prod_var = [markov_sum_to_prod[var] for var in markov_vars]
                assert all(p == markov_prod_var[0] for p in markov_prod_var)
                if len(markov_prod_var[0]) != 1:
                    raise ValueError("intractable!")
                time = next(iter(markov_prod_var[0]))
                for v in sum_vars.intersection(f.inputs):
                    if time in var_to_ordinal[v] and var_to_ordinal[v] < leaf:
                        raise ValueError("intractable!")
                time_var = Variable(time, f.inputs[time])
                group_step = {
                    k: v for (k, v) in plate_to_step[time].items() if v in markov_vars
                }
                f = MarkovProduct(sum_op, prod_op, f, time_var, group_step)
                f = f.reduce(sum_op, frozenset(group_step.values()))
                f = f(**prev_to_init)

            remaining_sum_vars = sum_vars.intersection(f.inputs)
            if not remaining_sum_vars:
                results.append(f.reduce(prod_op, leaf & prod_vars - markov_prod_vars))
            else:
                new_plates = frozenset().union(
                    *(var_to_ordinal[v] for v in remaining_sum_vars)
                )
                if new_plates == leaf:
                    raise ValueError("intractable!")
                f = f.reduce(prod_op, leaf - new_plates - markov_prod_vars)
                ordinal_to_factors[new_plates].append(f)

    return results
def sum_product(
    sum_op,
    prod_op,
    factors,
    eliminate=frozenset(),
    plates=frozenset(),
    pedantic=False,
    pow_op=None,
    plate_to_scale=None,  # dict
):
    """
    Performs sum-product contraction of a collection of factors.

    Runs :func:`partial_sum_product` with the same arguments, then multiplies
    (with ``prod_op``) all partial results together, starting from the unit
    of ``prod_op``.

    :return: a single contracted Funsor.
    :rtype: :class:`~funsor.terms.Funsor`
    """
    partial_results = partial_sum_product(
        sum_op,
        prod_op,
        factors,
        eliminate=eliminate,
        plates=plates,
        pedantic=pedantic,
        pow_op=pow_op,
        plate_to_scale=plate_to_scale,
    )
    unit = Number(UNITS[prod_op])
    return reduce(prod_op, partial_results, unit)
def naive_sequential_sum_product(sum_op, prod_op, trans, time, step):
    """
    Reference implementation of :func:`sequential_sum_product`.

    Splits ``trans`` into one factor per time step, then repeatedly contracts
    the last two factors (linking the earlier factor's "curr" inputs with the
    later factor's "prev" inputs) until a single factor remains.

    :param ~funsor.ops.AssociativeOp sum_op: A semiring sum operation.
    :param ~funsor.ops.AssociativeOp prod_op: A semiring product operation.
    :param ~funsor.terms.Funsor trans: A transition funsor.
    :param Variable time: The time input dimension.
    :param dict step: A dict mapping previous variables to current variables.
    """
    assert isinstance(sum_op, AssociativeOp)
    assert isinstance(prod_op, AssociativeOp)
    assert isinstance(trans, Funsor)
    assert isinstance(time, Variable)
    assert isinstance(step, dict)
    assert all(isinstance(k, str) for k in step.keys())
    assert all(isinstance(v, str) for v in step.values())
    if time.name in trans.inputs:
        assert time.output == trans.inputs[time.name]

    ordered_step = OrderedDict(sorted(step.items()))
    drop_names = tuple("_drop_{}".format(i) for i in range(len(ordered_step)))
    # Both sides of each prev/curr pair get renamed to a shared "_drop_i" name.
    prev_to_drop = dict(zip(ordered_step.keys(), drop_names))
    curr_to_drop = dict(zip(ordered_step.values(), drop_names))
    drop_set = frozenset(drop_names)

    time_name = time.name
    duration = time.output.size
    pending = [trans(**{time_name: t}) for t in range(duration)]
    while len(pending) > 1:
        later = pending.pop()(**prev_to_drop)
        earlier = pending.pop()(**curr_to_drop)
        pending.append(prod_op(earlier, later).reduce(sum_op, drop_set))
    return pending[0]
def sequential_sum_product(sum_op, prod_op, trans, time, step):
    """
    For a funsor ``trans`` with dimensions ``time``, ``prev`` and ``curr``,
    computes a recursion equivalent to::

        tail_time = 1 + arange("time", trans.inputs["time"].size - 1)
        tail = sequential_sum_product(sum_op, prod_op,
                                      trans(time=tail_time),
                                      time, {"prev": "curr"})
        return (prod_op(trans(time=0)(curr="drop"), tail(prev="drop"))
                .reduce(sum_op, "drop"))

    but does so efficiently in parallel in O(log(time)).

    :param ~funsor.ops.AssociativeOp sum_op: A semiring sum operation.
    :param ~funsor.ops.AssociativeOp prod_op: A semiring product operation.
    :param ~funsor.terms.Funsor trans: A transition funsor.
    :param Variable time: The time input dimension.
    :param dict step: A dict mapping previous variables to current variables.
        This can contain multiple pairs of prev->curr variable names.
    """
    assert isinstance(sum_op, AssociativeOp)
    assert isinstance(prod_op, AssociativeOp)
    assert isinstance(trans, Funsor)
    assert isinstance(time, Variable)
    assert isinstance(step, dict)
    assert all(isinstance(k, str) for k in step.keys())
    assert all(isinstance(v, str) for v in step.values())
    if time.name in trans.inputs:
        assert time.output == trans.inputs[time.name]

    ordered_step = OrderedDict(sorted(step.items()))
    drop_names = tuple("_drop_{}".format(i) for i in range(len(ordered_step)))
    prev_to_drop = dict(zip(ordered_step.keys(), drop_names))
    curr_to_drop = dict(zip(ordered_step.values(), drop_names))
    # The shared "_drop_i" variables take the domain of the matching curr input.
    drop_vars = frozenset(
        Variable(drop, trans.inputs[curr]) for curr, drop in curr_to_drop.items()
    )

    time_name = time.name
    duration = time.output.size
    while duration > 1:
        even_duration = duration - duration % 2
        # Pair step 2k (as "earlier") with step 2k+1 (as "later").
        earlier = trans(
            **{time_name: Slice(time_name, 0, even_duration, 2, duration)},
            **curr_to_drop,
        )
        later = trans(
            **{time_name: Slice(time_name, 1, even_duration, 2, duration)},
            **prev_to_drop,
        )
        halved = Contraction(sum_op, prod_op, drop_vars, earlier, later)
        if duration > even_duration:
            # Odd length: carry the unpaired final step through unchanged.
            leftover = trans(**{time_name: Slice(time_name, duration - 1, duration)})
            halved = Cat(time_name, (halved, leftover))
        trans = halved
        duration = (duration + 1) // 2
    return trans(**{time_name: 0})
def mixed_sequential_sum_product(sum_op, prod_op, trans, time, step, num_segments=None):
    """
    For a funsor ``trans`` with dimensions ``time``, ``prev`` and ``curr``,
    computes a recursion equivalent to::

        tail_time = 1 + arange("time", trans.inputs["time"].size - 1)
        tail = sequential_sum_product(sum_op, prod_op,
                                      trans(time=tail_time),
                                      time, {"prev": "curr"})
        return (prod_op(trans(time=0)(curr="drop"), tail(prev="drop"))
                .reduce(sum_op, "drop"))

    by mixing parallel and serial scan algorithms over ``num_segments`` segments:
    each segment is eliminated serially (all segments at once), then the
    per-segment results are combined with the parallel scan.

    :param ~funsor.ops.AssociativeOp sum_op: A semiring sum operation.
    :param ~funsor.ops.AssociativeOp prod_op: A semiring product operation.
    :param ~funsor.terms.Funsor trans: A transition funsor.
    :param Variable time: The time input dimension.
    :param dict step: A dict mapping previous variables to current variables.
        This can contain multiple pairs of prev->curr variable names.
    :param int num_segments: number of segments for the first stage
        (defaults to the duration).
    """
    time_var = time
    time_name = time.name
    duration = time.output.size
    if num_segments is None:
        num_segments = duration
    assert num_segments > 0 and duration > 0

    remainder_size = duration % num_segments
    evenly_divisible_size = duration - remainder_size
    # handle unevenly sized segments by chopping off the final segment and
    # calling mixed_sequential_sum_product again on the evenly divisible prefix
    if remainder_size and evenly_divisible_size > 0:
        remainder = trans(
            **{
                time_name: Slice(
                    time_name, evenly_divisible_size, duration, 1, duration
                )
            }
        )
        initial = trans(
            **{time_name: Slice(time_name, 0, evenly_divisible_size, 1, duration)}
        )
        initial_eliminated = mixed_sequential_sum_product(
            sum_op,
            prod_op,
            initial,
            Variable(time_name, Bint[evenly_divisible_size]),
            step,
            num_segments=num_segments,
        )
        # The eliminated prefix becomes a single leading time step.
        final = Cat(time_name, (Stack(time_name, (initial_eliminated,)), remainder))
        return naive_sequential_sum_product(
            sum_op,
            prod_op,
            final,
            Variable(time_name, Bint[1 + remainder_size]),
            step,
        )

    # handle degenerate cases that reduce to a single stage
    if num_segments == 1:
        return naive_sequential_sum_product(sum_op, prod_op, trans, time_var, step)
    if num_segments >= duration:
        return sequential_sum_product(sum_op, prod_op, trans, time_var, step)

    # break trans into num_segments segments of equal length
    segment_length = duration // num_segments
    segments = tuple(
        trans(
            **{
                time_name: Slice(
                    time_name,
                    k * segment_length,
                    (k + 1) * segment_length,
                    1,
                    duration,
                )
            }
        )
        for k in range(num_segments)
    )
    segment_name = time_name + "__SEGMENTED"
    # Stage 1: serial scan within every segment simultaneously.
    per_segment = naive_sequential_sum_product(
        sum_op,
        prod_op,
        Stack(segment_name, segments),
        Variable(time_name, Bint[segment_length]),
        step,
    )
    # Stage 2: parallel scan across the segment results.
    return sequential_sum_product(
        sum_op,
        prod_op,
        per_segment,
        Variable(segment_name, Bint[num_segments]),
        step,
    )
def _get_shift(name):
"""helper function used internally in sarkka_bilmes_product"""
return len(re.search(r"^(_PREV_)*", name).group(0)) // 6
def _shift_name(name, t):
"""helper function used internally in sarkka_bilmes_product"""
if t >= 0:
return t * "_PREV_" + name
return name.replace("_PREV_" * -t, "", 1)
def _shift_funsor(f, t, global_vars):
    """
    Rename every input of ``f`` that is not in ``global_vars`` via
    :func:`_shift_name` by ``t`` steps. A zero shift returns ``f`` itself.
    Used internally by sarkka_bilmes_product.
    """
    if t != 0:
        renames = {}
        for name in f.inputs:
            if name not in global_vars:
                renames[name] = _shift_name(name, t)
        f = f(**renames)
    return f
def naive_sarkka_bilmes_product(
    sum_op, prod_op, trans, time_var, global_vars=frozenset()
):
    """
    Sequential reference implementation of :func:`sarkka_bilmes_product`.

    Inputs of ``trans`` prefixed by ``k`` copies of ``"_PREV_"`` refer to the
    variable ``k`` steps earlier. Inputs in ``global_vars`` (and the time input)
    are never renamed. When no input carries a lag, this falls back to
    :func:`naive_sequential_sum_product` with an empty step. Otherwise, the
    per-step factors are folded from the last time step to the first: each
    earlier factor is shifted back by its distance from the end, multiplied in,
    and the shifted copies of the un-prefixed names are summed out. Finally all
    remaining inputs are shifted forward by ``duration - 1`` steps.
    """
    assert isinstance(global_vars, frozenset)
    time = time_var.name
    global_vars = global_vars | {time}

    lags = {_get_shift(name) for name in trans.inputs if name != time} - {0}
    if not lags:
        return naive_sequential_sum_product(sum_op, prod_op, trans, time_var, {})

    original_names = frozenset(
        name
        for name in trans.inputs
        if name not in global_vars and not name.startswith("_PREV_")
    )

    duration = trans.inputs[time].size
    result = trans(**{time: duration - 1})
    for t in reversed(range(duration - 1)):
        shift = duration - t - 1
        step_factor = _shift_funsor(trans(**{time: t}), shift, global_vars)
        result = prod_op(step_factor, result)
        eliminated = frozenset(_shift_name(name, shift) for name in original_names)
        result = result.reduce(sum_op, eliminated)

    return result(**{name: _shift_name(name, 1 - duration) for name in result.inputs})
def sarkka_bilmes_product(
    sum_op, prod_op, trans, time_var, global_vars=frozenset(), num_periods=1
):
    """
    Parallel-scan elimination of a time dimension with higher-order lags.

    Inputs of ``trans`` prefixed by ``k`` copies of ``"_PREV_"`` refer to the
    variable ``k`` steps earlier; inputs in ``global_vars`` (and the time input)
    are never renamed. With no lagged inputs this is
    :func:`sequential_sum_product` with an empty step. Otherwise ``period``
    consecutive time steps (``period`` = least common multiple of the lags)
    are grouped into blocks. The block-level problem is eliminated with
    :func:`mixed_sequential_sum_product` using
    ``num_segments = max(1, duration // (period * num_periods))``.

    If ``duration`` is not a multiple of ``period``, the trailing complete
    blocks are handled by a recursive call and the leading remainder steps
    are then folded in one at a time.
    """
    assert isinstance(global_vars, frozenset)
    time = time_var.name
    global_vars = global_vars | {time}

    lags = {_get_shift(name) for name in trans.inputs if name != time} - {0}
    if not lags:
        return sequential_sum_product(sum_op, prod_op, trans, time_var, {})

    # block length: least common multiple of every lag
    period = 1
    for lag in lags:
        period = period * lag // gcd(period, lag)

    original_names = frozenset(
        name
        for name in trans.inputs
        if name not in global_vars and not name.startswith("_PREV_")
    )

    duration = trans.inputs[time].size
    remaining_duration = duration % period
    if remaining_duration:
        truncated_duration = duration - remaining_duration
        if truncated_duration == 0:
            result = trans(**{time: remaining_duration - 1})
            remaining_duration -= 1
        else:
            # chop off the rightmost set of complete chunks from trans,
            # then recursively call sarkka_bilmes_product on truncated factor
            result = sarkka_bilmes_product(
                sum_op,
                prod_op,
                trans(**{time: Slice(time, remaining_duration, duration, 1, duration)}),
                Variable(time, Bint[truncated_duration]),
                global_vars - {time},
                num_periods,
            )
        # sequentially combine remaining pieces with result
        for t in reversed(range(remaining_duration)):
            shift = remaining_duration - t
            result = prod_op(
                _shift_funsor(trans(**{time: t}), shift, global_vars), result
            )
            eliminated = frozenset(_shift_name(name, shift) for name in original_names)
            result = result.reduce(sum_op, eliminated)
        return result(
            **{name: _shift_name(name, -remaining_duration) for name in result.inputs}
        )

    # Build one first-order "block" factor from period interleaved time slices.
    block_factors = []
    for t in range(period):
        time_slice = Slice(time, t, duration - period + t + 1, period, duration)
        shifted = _shift_funsor(trans, period - t - 1, global_vars)
        block_factors.append(shifted(**{time: time_slice}))
    block_trans = reduce(prod_op, block_factors)

    block_step = {
        _shift_name(name, period): name
        for name in block_trans.inputs
        if name not in global_vars and _get_shift(name) < period
    }
    block_time_var = Variable(time, Bint[duration // period])
    final_chunk = mixed_sequential_sum_product(
        sum_op,
        prod_op,
        block_trans,
        block_time_var,
        block_step,
        num_segments=max(1, duration // (period * num_periods)),
    )
    final_sum_vars = frozenset(
        _shift_name(name, t) for name in original_names for t in range(1, period)
    )
    result = final_chunk.reduce(sum_op, final_sum_vars)
    return result(**{name: _shift_name(name, 1 - period) for name in result.inputs})
class MarkovProductMeta(FunsorMeta):
    """
    Normalizes :class:`MarkovProduct` constructor arguments.

    * A ``str`` ``time`` is converted to a ``Variable`` over its
      ``trans`` input domain.
    * A ``dict`` ``step`` or ``step_names`` is converted to a frozenset of pairs.
    * Omitted ``step_names`` default to the identity renaming of every prev
      and curr name in ``step``.

    ``MarkovProduct.__init__`` requires frozensets for ``step`` and
    ``step_names``, while callers (e.g. ``modified_partial_sum_product``) pass
    a dict step and no ``step_names``.
    """

    def __call__(cls, sum_op, prod_op, trans, time, step, step_names=None):
        if isinstance(time, str):
            assert time in trans.inputs, "please pass Variable(time, ...)"
            time = Variable(time, trans.inputs[time])
        if isinstance(step, dict):
            step = frozenset(step.items())
        if step_names is None:
            step_names = frozenset((k, k) for pair in step for k in pair)
        if isinstance(step_names, dict):
            step_names = frozenset(step_names.items())
        return super().__call__(sum_op, prod_op, trans, time, step, step_names)


class MarkovProduct(Funsor, metaclass=MarkovProductMeta):
    """
    Lazy representation of :func:`sequential_sum_product` .

    :param AssociativeOp sum_op: A marginalization op.
    :param AssociativeOp prod_op: A Bayesian fusion op.
    :param Funsor trans: A sequence of transition factors,
        usually varying along the ``time`` input.
    :param time: A time dimension.
    :type time: str or Variable
    :param dict step: A str-to-str mapping of "previous" inputs of ``trans``
        to "current" inputs of ``trans``.
    :param dict step_names: Optional, for internal use by alpha conversion.
    """

    def __init__(self, sum_op, prod_op, trans, time, step, step_names):
        assert isinstance(sum_op, AssociativeOp)
        assert isinstance(prod_op, AssociativeOp)
        assert isinstance(trans, Funsor)
        assert isinstance(time, Variable)
        assert isinstance(step, frozenset)
        assert isinstance(step_names, frozenset)
        step = dict(step)
        step_names = dict(step_names)
        assert all(isinstance(k, str) for k in step_names.keys())
        assert all(isinstance(v, str) for v in step_names.values())
        assert set(step_names) == set(step).union(step.values())
        # Free inputs: all non-time inputs of trans, with step vars renamed.
        inputs = OrderedDict(
            (step_names.get(k, k), v) for k, v in trans.inputs.items() if k != time.name
        )
        output = trans.output
        fresh = frozenset(step_names.values())
        # The time var and the original step vars are bound by this term.
        bound = {k: trans.inputs[k] for k in step_names}
        bound[time.name] = time.output
        super().__init__(inputs, output, fresh, bound)
        self.sum_op = sum_op
        self.prod_op = prod_op
        self.trans = trans
        self.time = time
        self.step = step
        self.step_names = step_names

    def _alpha_convert(self, alpha_subs):
        """Rename bound variables (time and step names) per ``alpha_subs``."""
        assert set(alpha_subs).issubset(self.bound)
        time = Variable(
            alpha_subs.get(self.time.name, self.time.name), self.time.output
        )
        step = frozenset(
            (alpha_subs.get(k, k), alpha_subs.get(v, v)) for k, v in self.step.items()
        )
        step_names = frozenset(
            (alpha_subs.get(k, k), v) for k, v in self.step_names.items()
        )
        alpha_subs = {
            k: to_funsor(v, self.trans.inputs[k])
            for k, v in alpha_subs.items()
            if k in self.trans.inputs
        }
        trans = substitute(self.trans, alpha_subs)
        return self.sum_op, self.prod_op, trans, time, step, step_names

    def eager_subs(self, subs):
        """
        Eagerly apply variable-to-variable renamings by updating
        ``step_names``; any non-Variable substitutions are left as a lazy
        :class:`~funsor.terms.Subs`. Returns ``None`` when there is nothing
        to rename.
        """
        assert isinstance(subs, tuple)
        # Eagerly rename variables.
        rename = {k: v.name for k, v in subs if isinstance(v, Variable)}
        if not rename:
            return None
        step_names = frozenset(
            (k, rename.get(v, v)) for k, v in self.step_names.items()
        )
        result = MarkovProduct(
            self.sum_op, self.prod_op, self.trans, self.time, self.step, step_names
        )
        lazy = tuple((k, v) for k, v in subs if not isinstance(v, Variable))
        if lazy:
            result = Subs(result, lazy)
        return result
@quote.register(MarkovProduct)
def _(arg, indent, out):
    # Header: class name plus the two op arguments, left open for the rest.
    header = "{}({}, {},".format(
        type(arg).__name__, repr(arg.sum_op), repr(arg.prod_op)
    )
    out.append((indent, header))
    # Quote each remaining AST value one level deeper, comma-terminated.
    for value in arg._ast_values[2:]:
        quote.inplace(value, indent + 1, out)
        last_indent, last_line = out[-1]
        out[-1] = (last_indent, last_line + ",")
    # Swap the final trailing comma for the closing parenthesis.
    last_indent, last_line = out[-1]
    out[-1] = (last_indent, last_line[:-1] + ")")
@eager.register(
    MarkovProduct, AssociativeOp, AssociativeOp, Funsor, Variable, frozenset, frozenset
)
def eager_markov_product(sum_op, prod_op, trans, time, step, step_names):
    """
    Eagerly evaluate a :class:`MarkovProduct`.

    With a non-empty ``step`` this is :func:`sequential_sum_product`. Without
    a step, ``trans`` is product-reduced over ``time`` if it depends on it;
    otherwise the product over ``time`` of a constant factor is computed
    directly from the number of time steps (``ops.add`` -> multiply,
    ``ops.mul`` -> power). Finally step variables are renamed per
    ``step_names``.

    :raises NotImplementedError: for other ``prod_op`` with time-independent
        ``trans``.
    """
    # Number of time steps; ``time.output.size`` is how the rest of this file
    # reads the length of a time Variable (``time.size`` was inconsistent).
    duration = time.output.size
    if step:
        result = sequential_sum_product(sum_op, prod_op, trans, time, dict(step))
    elif time.name in trans.inputs:
        result = trans.reduce(prod_op, time.name)
    elif prod_op is ops.add:
        result = trans * duration
    elif prod_op is ops.mul:
        result = trans**duration
    else:
        raise NotImplementedError("https://github.com/pyro-ppl/funsor/issues/233")

    return Subs(result, step_names)