{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Stochastic\n",
    "\n",
    "This notebook demonstrates the use of stochastic wrappers in Clipppy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import pyro\n",
    "import pyro.distributions as d\n",
    "\n",
    "from clipppy import loads\n",
    "from clipppy.stochastic import Stochastic, Sampler\n",
    "from clipppy.utils.distributions.extra_independent import ExtraIndependent"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Tracing a plated model\n",
    "\n",
    "The cell below samples the sites `a`, `b` and `c` inside a `pyro.plate` of size 10 while recording a trace, then displays the shape of every recorded value, the standard deviation of `b`, and the trace object itself. No random seed is set, so the value of `b.std()` changes from run to run."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": "([('sumplate', torch.Size([10])),\n ('a', torch.Size([10])),\n ('b', torch.Size([10, 1000])),\n ('c', torch.Size([10, 1000]))],\n tensor(2.8549),\n )"
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE(review): `fb` is constructed but never used later in this notebook --\n",
    "# confirm whether it is meant as a standalone example or can be removed.\n",
    "fb = Sampler(\n",
    "    Stochastic(d.Normal, specs={'loc': Sampler(d.Normal(0, 1), name='a'), 'scale': 1}),\n",
    "    name='b', to_event=0\n",
    ")\n",
    "\n",
    "\n",
    "# Record every sample site drawn inside a plate of size 10.\n",
    "with pyro.poutine.trace() as tracer, pyro.plate('sumplate', 10):\n",
    "    a = pyro.sample('a', d.Uniform(0, 10))\n",
    "    b = pyro.sample('b', ExtraIndependent(d.Normal(a, 1), (1000,)))\n",
    "    c = pyro.sample('c', d.Normal(b, 1).to_event(1))\n",
    "trace = tracer.trace\n",
    "\n",
    "[(k, v['value'].shape) for k, v in trace.nodes.items()], b.std(), trace"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Loading a `Stochastic` from YAML\n",
    "\n",
    "The YAML document below is parsed with `clipppy.loads` and the resulting object is called immediately. `func` returns its arguments unchanged, so the cell output shows the values that reach it; `get_stuff` prints a message each time it runs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Getting stuff\n"
     ]
    },
    {
     "data": {
      "text/plain": "((),\n {'a': {'the answer': 42, 'a number': 26},\n 'b': {'the answer': 42, 'a number': 26},\n 'c': 'the answer',\n 'd': 'a number',\n 'e': 42,\n 'a number': 26})"
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def func(*args, **kwargs):\n",
    "    \"\"\"Return the positional and keyword arguments exactly as received.\"\"\"\n",
    "    return args, kwargs\n",
    "\n",
    "\n",
    "def get_stuff():\n",
    "    \"\"\"Print a marker message and return a small example dictionary.\"\"\"\n",
    "    print('Getting stuff')\n",
    "    return {'the answer': 42, 'a number': 26}\n",
    "\n",
    "\n",
    "# NOTE(review): YAML nesting is indentation-sensitive and the text below is kept\n",
    "# verbatim -- confirm its indentation matches the intended nesting.\n",
    "loads('''\n",
    "!Stochastic\n",
    "- !py:func\n",
    "- a: !AllEncapsulator\n",
    " /: !py:get_stuff\n",
    " /: &b\n",
    " <: [&c, &d]\n",
    " <<: {the answer: &e}\n",
    " b: *b\n",
    " c: *c\n",
    " d: *d\n",
    " e: *e\n",
    " [a number]: *b\n",
    "''')()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}