# Source code for funsor.montecarlo

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import OrderedDict

from funsor.integrate import Integrate
from funsor.interpreter import StatefulInterpretation
from funsor.terms import Funsor
from funsor.util import get_backend

class MonteCarlo(StatefulInterpretation):
    """
    Monte Carlo interpretation of :class:`~funsor.integrate.Integrate`.

    ``Integrate`` expressions are approximated by sampling from their log
    measure; every other expression falls back to the previous interpreter.

    :param rng_key: Optional random key. It is only consumed when the active
        backend is ``"jax"``: before each sampling call it is split, one half
        is handed to ``sample`` and the other half replaces the stored key.
    :param sample_inputs: Keyword arguments naming extra sample dimensions.
        They are kept in insertion order and forwarded unchanged to
        ``log_measure.sample``. The values are presumably funsor domains;
        confirm against callers.
    """

    def __init__(self, *, rng_key=None, **sample_inputs):
        # Preserve keyword order; this mapping is passed to ``sample`` as-is.
        self.sample_inputs = OrderedDict(sample_inputs.items())
        self.rng_key = rng_key
@MonteCarlo.register(Integrate, Funsor, Funsor, frozenset)
def monte_carlo_integrate(state, log_measure, integrand, reduced_vars):
    """
    Approximate ``Integrate(log_measure, integrand, reduced_vars)`` by sampling.

    :param state: The active :class:`MonteCarlo` interpretation, which
        supplies ``sample_inputs`` and, optionally, ``rng_key``.
    :param Funsor log_measure: Measure from which ``reduced_vars`` are sampled.
    :param Funsor integrand: Quantity integrated against the measure.
    :param frozenset reduced_vars: Variables being integrated out. These are
        presumably variable names, since sample-input names are merged in.
    :returns: A new ``Integrate`` over the sampled measure, or ``None`` when
        ``log_measure.sample`` hands back ``log_measure`` itself.
    """
    options = {}
    # Only the JAX backend receives an explicit key. ``get_backend`` is not
    # consulted at all when no key was provided (short-circuit ``and``).
    if state.rng_key is not None and get_backend() == "jax":
        import jax

        # First half of the split feeds this draw; the second half is kept
        # on the interpretation for subsequent calls.
        sample_key, next_key = jax.random.split(state.rng_key)
        options["rng_key"] = sample_key
        state.rng_key = next_key

    sampled = log_measure.sample(reduced_vars, state.sample_inputs, **options)
    if sampled is log_measure:
        # Sampling made no progress; None signals this rule does not apply.
        return None

    # Sample dimensions that actually appear in the result are reduced as well.
    drawn_dims = frozenset(state.sample_inputs).intersection(sampled.inputs)
    return Integrate(sampled, integrand, reduced_vars | drawn_dims)


__all__ = [
    'MonteCarlo',
]