# Source code for funsor.ops.program

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from funsor.util import get_backend, set_backend


class OpProgram:
    """
    Backend program for evaluating a symbolic funsor expression.

    Programs depend on the funsor library only via ``funsor.ops`` and op
    registrations; program evaluation does not involve funsor interpretation
    or rewriting. Programs can be pickled and unpickled.

    :param iterable constants: A list of built-in constants (leaves).
    :param iterable inputs: A list of string names of program inputs (leaves).
    :param iterable operations: A list of program operations defining non-leaf
        nodes in the program dag. Each operation is a tuple ``(op, arg_ids)``
        where op is a funsor op and ``arg_ids`` is a tuple of positions of
        values, starting from zero and counting: constants, inputs, and
        operation outputs.
    """

    def __init__(self, constants, inputs, operations):
        super().__init__()
        self.constants = tuple(constants)
        self.inputs = tuple(inputs)
        self.operations = tuple(operations)
        # Remember the backend active at construction time; both __call__ and
        # the code emitted by as_code() re-select it before evaluating.
        self.backend = get_backend()

    def __call__(self, **kwargs):
        """
        Evaluate the program on the given named inputs.

        Calls ``set_backend(self.backend)`` before evaluating.

        :param kwargs: One value per name in ``self.inputs``.
        :returns: The last value computed: the output of the final operation,
            or the last leaf if there are no operations.
        :raises ValueError: If an input is missing (a value passed explicitly
            as ``None`` is also treated as missing), or if an unrecognized
            keyword argument is passed.
        """
        set_backend(self.backend)

        # Initialize environment with constants.
        env = list(self.constants)

        # Read inputs from kwargs.
        for name in self.inputs:
            value = kwargs.pop(name, None)
            if value is None:
                raise ValueError(f"Missing kwarg: {repr(name)}")
            env.append(value)
        if kwargs:
            raise ValueError(f"Unrecognized kwargs: {set(kwargs)}")

        # Sequentially compute ops; each output is appended so later
        # operations can refer to it by position.
        for op, arg_ids in self.operations:
            args = tuple(env[i] for i in arg_ids)
            env.append(op(*args))

        return env[-1]

    def as_code(self, name="program"):
        """
        Returns Python code text defining a straight-line function equivalent
        to this program.

        :param str name: Optional name for the function, defaults to "program".
        :returns: A string defining a python function equivalent to this program.
        :rtype: str
        """
        lines = [
            "# Automatically generated by funsor.compiler.FunsorProgram.as_code().",
            "def {}({}):".format(name, ", ".join(self.inputs)),
            "    from funsor import set_backend, ops",
            f"    set_backend({repr(self.backend)})",
        ]

        # Every leaf and operation output is bound to a local v0, v1, ...,
        # numbered in the same order as the positions used by ``arg_ids``.
        num_values = 0

        def let(body):
            nonlocal num_values
            lines.append(f"    v{num_values} = {body}")
            num_values += 1

        for constant in self.constants:
            # repr() rather than str() so that e.g. string constants are
            # emitted as quoted literals instead of bare identifiers.
            let(repr(constant))
        for input_name in self.inputs:
            let(input_name)
        for op, arg_ids in self.operations:
            op_text = _print_op(op)
            if arg_ids:
                args = ", ".join(f"v{arg_id}" for arg_id in arg_ids)
                # The trailing comma keeps a single-argument tuple a tuple
                # when op_text is empty (make_tuple).
                let(f"{op_text}({args},)")
            else:
                # "f(,)" is a syntax error; emit an empty call / empty tuple.
                let(f"{op_text}()")
        lines.append(f"    return v{num_values - 1}")
        return "\n".join(lines)
def make_tuple(*args):
    """
    Return the positional arguments as a tuple.

    ``_print_op`` renders this function as the empty string, so a generated
    call to it reads as a plain parenthesized tuple literal.
    """
    return args


def _print_op(op):
    """
    Return Python source text that evaluates to ``op`` in generated code.

    :param op: A funsor op, or :func:`make_tuple`.
    :returns: ``""`` for :func:`make_tuple`; ``ops.<OpType>(<defaults>)`` when
        ``op`` has non-empty ``defaults`` that differ from those of
        ``type(op)()``; otherwise ``repr(op)``.
    :rtype: str
    """
    if op is make_tuple:
        return ""
    if op.defaults and op.defaults != type(op)().defaults:
        # repr() rather than str() so that string-valued parameters are
        # emitted as quoted literals and the generated code stays valid.
        args = ", ".join(repr(value) for value in op.defaults.values())
        return f"ops.{type(op).__name__}({args})"
    return repr(op)


__all__ = [
    "OpProgram",
]