# Stochastic
This notebook demonstrates the use of the stochastic wrappers (`Stochastic` and `Sampler`) in Clipppy.
[8]:
from clipppy import loads
[1]:
import torch
import pyro, pyro.distributions as d
from clipppy.stochastic import Stochastic, Sampler
from clipppy.utils.distributions.extra_independent import ExtraIndependent
# Seed Pyro's RNG so that the sampled values (and hence `b.std()`) are
# reproducible under Restart Kernel -> Run All.
SEED = 0
pyro.set_rng_seed(SEED)

# A Sampler named 'b' wrapping a Stochastic Normal whose 'loc' spec is a
# Sampler named 'a' (over Normal(0, 1)) and whose 'scale' spec is fixed at 1.
# NOTE(review): `fb` is not used anywhere below; it only illustrates construction.
fb = Sampler(
    Stochastic(d.Normal, specs={'loc': Sampler(d.Normal(0, 1), name='a'), 'scale': 1}),
    name='b', to_event=0,
)

# Record every sample site inside a plate of size 10.
with pyro.poutine.trace() as tracer, pyro.plate('sumplate', 10):
    a = pyro.sample('a', d.Uniform(0, 10))
    # ExtraIndependent adds 1000 extra draws per plate entry; the recorded
    # output shows b with shape (10, 1000).
    b = pyro.sample('b', ExtraIndependent(d.Normal(a, 1), (1000,)))
    # to_event(1) reinterprets the trailing (1000-long) dim as the event dim,
    # presumably leaving the plate dim as the only batch dim.
    c = pyro.sample('c', d.Normal(b, 1).to_event(1))
trace = tracer.trace

# Display site names with value shapes, the std of b, and the trace itself.
# Note: b.std() pools over all 10 plate entries, so it reflects the spread of
# `a` ~ Uniform(0, 10) as well as the unit scale of the Normal around each `a`.
[(k, v['value'].shape) for k, v in trace.nodes.items()], b.std(), trace
[1]:
([('sumplate', torch.Size([10])),
('a', torch.Size([10])),
('b', torch.Size([10, 1000])),
('c', torch.Size([10, 1000]))],
tensor(2.8549),
<pyro.poutine.trace_struct.Trace at 0x7f5c6cb271f0>)
[2]:
def func(*args, **kwargs):
    """Return the call's arguments unchanged, packed as ``(args, kwargs)``.

    Serves as the ``!py:func`` target of the YAML spec below, so the displayed
    result shows exactly which arguments the spec supplied.
    """
    packed = (args, kwargs)
    return packed
def get_stuff():
    """Announce the call on stdout, then return a fixed two-entry dict."""
    # Building the dict first is side-effect free, so the observable order
    # (print, then return) is unchanged.
    result = dict([('the answer', 42), ('a number', 26)])
    print('Getting stuff')
    return result
# Build a clipppy `!Stochastic` object from a YAML spec with `loads`, then call
# it with no arguments. Because `func` simply echoes `(args, kwargs)`, the
# displayed output shows which keyword arguments the spec resolved to.
# - `&name` / `*name` are standard YAML anchors/aliases. In the recorded output,
#   'Getting stuff' is printed once, while `a` and `b` hold equal dicts, so the
#   value is presumably reused rather than recomputed. TODO confirm.
# - `!AllEncapsulator`, `/`, `<`, `<<` and the `[a number]` key are
#   clipppy-specific. Judging only from the recorded output, `<: [&c, &d]`
#   binds the mapping's keys ('the answer', 'a number') and
#   `<<: {the answer: &e}` binds the value under 'the answer' (42).
#   TODO confirm against the clipppy documentation.
# NOTE(review): YAML nesting is indentation-sensitive, but the spec below shows
# no indentation in this copy. Confirm the intended nesting before editing it.
loads('''
!Stochastic
- !py:func
- a: !AllEncapsulator
/: !py:get_stuff
/: &b
<: [&c, &d]
<<: {the answer: &e}
b: *b
c: *c
d: *d
e: *e
[a number]: *b
''')()
Getting stuff
[2]:
((),
{'a': {'the answer': 42, 'a number': 26},
'b': {'the answer': 42, 'a number': 26},
'c': 'the answer',
'd': 'a number',
'e': 42,
'a number': 26})