# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
from funsor.util import get_backend, set_backend
class OpProgram:
    """
    Backend program for evaluating a symbolic funsor expression.

    Programs depend on the funsor library only via ``funsor.ops`` and op
    registrations; program evaluation does not involve funsor interpretation or
    rewriting. Programs can be pickled and unpickled.

    :param iterable constants: A list of built-in constants (leaves).
    :param iterable inputs: A list of string names of program inputs (leaves).
    :param iterable operations: A list of program operations defining
        non-leaf nodes in the program dag. Each operation is a tuple ``(op,
        arg_ids)`` where op is a funsor op and ``arg_ids`` is a tuple of
        positions of values, starting from zero and counting: constants,
        inputs, and operation outputs.
    """

    def __init__(self, constants, inputs, operations):
        super().__init__()
        self.constants = tuple(constants)
        self.inputs = tuple(inputs)
        self.operations = tuple(operations)
        # Remember the backend active at construction time; __call__ and the
        # code produced by as_code() both switch back to it.
        self.backend = get_backend()

    def __call__(self, **kwargs):
        """
        Evaluate the program on the given inputs.

        This calls ``set_backend(self.backend)`` before evaluating, which
        changes the globally selected funsor backend as a side effect.

        :param kwargs: Exactly one keyword argument per name in ``self.inputs``.
        :returns: The last value computed: the output of the final operation,
            or the last leaf if there are no operations.
        :raises ValueError: If a program input is missing or an unrecognized
            keyword argument is passed.
        """
        set_backend(self.backend)
        # Initialize environment with constants.
        env = list(self.constants)
        # Read inputs from kwargs. Test membership instead of comparing the
        # popped value to None, so an argument explicitly passed as None is
        # not misreported as missing.
        for name in self.inputs:
            if name not in kwargs:
                raise ValueError(f"Missing kwarg: {repr(name)}")
            env.append(kwargs.pop(name))
        if kwargs:
            raise ValueError(f"Unrecognized kwargs: {set(kwargs)}")
        # Sequentially compute ops; each output is appended to env so later
        # operations can refer to it by position.
        for op, arg_ids in self.operations:
            args = tuple(env[i] for i in arg_ids)
            env.append(op(*args))
        return env[-1]

    def as_code(self, name="program"):
        """
        Returns Python code text defining a straight-line function equivalent
        to this program.

        String constants are emitted with ``repr()`` so they become quoted
        Python literals; all other constants are emitted with ``str()``.

        :param str name: Optional name for the function, defaults to "program".
        :returns: A string defining a python function equivalent to this program.
        :rtype: str
        """
        lines = [
            "# Automatically generated by funsor.compiler.OpProgram.as_code().",
            "def {}({}):".format(name, ", ".join(self.inputs)),
            "    from funsor import set_backend, ops",
            f"    set_backend({repr(self.backend)})",
        ]
        start = len(lines)

        def let(body):
            # Bind the next value as v{i}; i equals the value's position in
            # the environment built by __call__ (constants, inputs, outputs).
            i = len(lines) - start
            lines.append(f"    v{i} = {body}")

        for constant in self.constants:
            # An unquoted string would be parsed as code, so quote it.
            let(repr(constant) if isinstance(constant, str) else constant)
        for input_name in self.inputs:
            let(input_name)
        for op, arg_ids in self.operations:
            op_text = _print_op(op)
            if arg_ids:
                # The trailing comma makes a single-argument make_tuple
                # (printed as an empty op name) produce a 1-tuple.
                args = ", ".join(f"v{arg_id}" for arg_id in arg_ids)
                let(f"{op_text}({args},)")
            else:
                # "f(,)" is a syntax error; emit an empty call / empty tuple.
                let(f"{op_text}()")
        lines.append(f"    return v{len(lines) - start - 1}")
        return "\n".join(lines)
def make_tuple(*args):
    """
    Return the positional arguments, packed as a tuple.

    This is used as a program operation wherever a tuple of values is needed.
    """
    return tuple(args)
def _print_op(op):
    """
    Return Python source text for ``op``, for use in generated program code.

    :param op: Either :func:`make_tuple` or a funsor op (an object exposing a
        ``defaults`` mapping, whose type can be constructed with no arguments).
    :returns: ``""`` for :func:`make_tuple`. For an op whose non-empty
        ``defaults`` differ from those of a freshly constructed op of the same
        type, ``"ops.<TypeName>(<default values>)"``. Otherwise ``repr(op)``.
    :rtype: str
    """
    if op is make_tuple:
        # An empty op name turns "(v0, v1,)" into a plain tuple literal.
        return ""
    if op.defaults and op.defaults != type(op)().defaults:
        # String default values must be quoted to remain valid Python source;
        # other values keep their str() rendering.
        args = ", ".join(
            repr(value) if isinstance(value, str) else str(value)
            for value in op.defaults.values()
        )
        return f"ops.{type(op).__name__}({args})"
    return repr(op)
# Names exported by "from <this module> import *".
__all__ = ["OpProgram"]