Stochastic

This notebook demonstrates the use of the stochastic wrappers (`Stochastic` and `Sampler`) in Clipppy.

[8]:
from clipppy import loads
[1]:
import torch
import pyro, pyro.distributions as d

from clipppy.stochastic import Stochastic, Sampler
from clipppy.utils.distributions.extra_independent import ExtraIndependent

# Nested wrapper example: a `Sampler` named 'b' around a `Stochastic` built on
# d.Normal, whose `loc` spec is itself a `Sampler` (named 'a') over Normal(0, 1).
# NOTE(review): `fb` is not referenced again in the cells visible here.
fb = Sampler(
    Stochastic(
        d.Normal,
        specs=dict(
            loc=Sampler(d.Normal(0, 1), name='a'),
            scale=1,
        ),
    ),
    name='b',
    to_event=0,
)


# Record every sample statement executed inside a plate named 'sumplate' of size 10.
with pyro.poutine.trace() as tracer:
    with pyro.plate('sumplate', 10):
        a = pyro.sample('a', d.Uniform(0, 10))
        # ExtraIndependent wraps Normal(a, 1) with an extra (1000,) shape argument.
        b = pyro.sample('b', ExtraIndependent(d.Normal(a, 1), (1000,)))
        c = pyro.sample('c', d.Normal(b, 1).to_event(1))
trace = tracer.trace

# Cell result: (site name, value shape) pairs, the spread of `b`, and the trace itself.
(
    [(site_name, site['value'].shape) for site_name, site in trace.nodes.items()],
    b.std(),
    trace,
)
[1]:
([('sumplate', torch.Size([10])),
  ('a', torch.Size([10])),
  ('b', torch.Size([10, 1000])),
  ('c', torch.Size([10, 1000]))],
 tensor(2.8549),
 <pyro.poutine.trace_struct.Trace at 0x7f5c6cb271f0>)
[2]:
def func(*args, **kwargs):
    """Echo the call back to the caller.

    Returns:
        A 2-tuple ``(args, kwargs)`` holding the positional arguments (as a
        tuple) and the keyword arguments (as a dict) exactly as received.
    """
    received = (args, kwargs)
    return received

def get_stuff():
    """Announce the call on stdout and return a small fixed mapping.

    The printed line makes it visible each time the function is invoked.

    Returns:
        A new dict with the entries ``'the answer': 42`` and ``'a number': 26``.
    """
    print('Getting stuff')
    stuff = dict([('the answer', 42), ('a number', 26)])
    return stuff

# Build an object from the `!Stochastic` YAML document below and call it at once.
# Going by the output recorded for this cell: 'Getting stuff' is printed a single
# time, and `func` receives no positional arguments plus the keyword arguments
# a, b, c, d, e and 'a number'.
# NOTE(review): presumably the `!AllEncapsulator` entries anchor the whole
# `get_stuff` result as &b, its keys as &c/&d, and the value under 'the answer'
# as &e, while `[a number]: *b` picks the 'a number' entry out of b -- confirm
# against the Clipppy YAML documentation.
loads('''
!Stochastic
- !py:func
- a: !AllEncapsulator
    /: !py:get_stuff
    /: &b
    <: [&c, &d]
    <<: {the answer: &e}
  b: *b
  c: *c
  d: *d
  e: *e
  [a number]: *b
''')()
Getting stuff
[2]:
((),
 {'a': {'the answer': 42, 'a number': 26},
  'b': {'the answer': 42, 'a number': 26},
  'c': 'the answer',
  'd': 'a number',
  'e': 42,
  'a number': 26})