Skip to content

Commit eebcefb

Browse files
committed
initial commit
1 parent af04605 commit eebcefb

7 files changed

+670
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"Ok I have a generative function. What can I do with it?"
8+
]
9+
},
10+
{
11+
"cell_type": "code",
12+
"execution_count": null,
13+
"metadata": {},
14+
"outputs": [
15+
{
16+
"ename": "",
17+
"evalue": "",
18+
"output_type": "error",
19+
"traceback": [
20+
"\u001b[1;31mFailed to start the Kernel. \n",
21+
"\u001b[1;31mJupyter Server crashed. Unable to connect. \n",
22+
"\u001b[1;31mError code from Jupyter: 1\n",
23+
"\u001b[1;31mTraceback (most recent call last):\n",
24+
"\u001b[1;31m File \"/Users/matuot/repos/genjax-docs/.venv/bin/jupyter-notebook\", line 5, in <module>\n",
25+
"\u001b[1;31m from notebook.app import main\n",
26+
"\u001b[1;31m File \"/Users/matuot/repos/genjax-docs/.venv/lib/python3.11/site-packages/notebook/app.py\", line 12, in <module>\n",
27+
"\u001b[1;31m from jupyter_server.base.handlers import JupyterHandler\n",
28+
"\u001b[1;31m File \"/Users/matuot/repos/genjax-docs/.venv/lib/python3.11/site-packages/jupyter_server/base/handlers.py\", line 23, in <module>\n",
29+
"\u001b[1;31m from jupyter_events import EventLogger\n",
30+
"\u001b[1;31m File \"/Users/matuot/repos/genjax-docs/.venv/lib/python3.11/site-packages/jupyter_events/__init__.py\", line 3, in <module>\n",
31+
"\u001b[1;31m from .logger import EVENTS_METADATA_VERSION, EventLogger\n",
32+
"\u001b[1;31m File \"/Users/matuot/repos/genjax-docs/.venv/lib/python3.11/site-packages/jupyter_events/logger.py\", line 14, in <module>\n",
33+
"\u001b[1;31m from jsonschema import ValidationError\n",
34+
"\u001b[1;31m File \"/Users/matuot/repos/genjax-docs/.venv/lib/python3.11/site-packages/jsonschema/__init__.py\", line 13, in <module>\n",
35+
"\u001b[1;31m from jsonschema._format import FormatChecker\n",
36+
"\u001b[1;31m File \"/Users/matuot/repos/genjax-docs/.venv/lib/python3.11/site-packages/jsonschema/_format.py\", line 11, in <module>\n",
37+
"\u001b[1;31m from jsonschema.exceptions import FormatError\n",
38+
"\u001b[1;31m File \"/Users/matuot/repos/genjax-docs/.venv/lib/python3.11/site-packages/jsonschema/exceptions.py\", line 15, in <module>\n",
39+
"\u001b[1;31m from referencing.exceptions import Unresolvable as _Unresolvable\n",
40+
"\u001b[1;31m File \"/Users/matuot/repos/genjax-docs/.venv/lib/python3.11/site-packages/referencing/__init__.py\", line 5, in <module>\n",
41+
"\u001b[1;31m from referencing._core import Anchor, Registry, Resource, Specification\n",
42+
"\u001b[1;31m File \"/Users/matuot/repos/genjax-docs/.venv/lib/python3.11/site-packages/referencing/_core.py\", line 9, in <module>\n",
43+
"\u001b[1;31m from rpds import HashTrieMap, HashTrieSet, List\n",
44+
"\u001b[1;31m File \"/Users/matuot/repos/genjax-docs/.venv/lib/python3.11/site-packages/rpds/__init__.py\", line 1, in <module>\n",
45+
"\u001b[1;31m from .rpds import *\n",
46+
"\u001b[1;31mImportError: dlopen(/Users/matuot/repos/genjax-docs/.venv/lib/python3.11/site-packages/rpds/rpds.cpython-311-darwin.so, 0x0002): tried: '/Users/matuot/repos/genjax-docs/.venv/lib/python3.11/site-packages/rpds/rpds.cpython-311-darwin.so' (mach-o file, but is an incompatible architecture (have 'x86_64', need 'arm64')), '/System/Volumes/Preboot/Cryptexes/OS/Users/matuot/repos/genjax-docs/.venv/lib/python3.11/site-packages/rpds/rpds.cpython-311-darwin.so' (no such file), '/Users/matuot/repos/genjax-docs/.venv/lib/python3.11/site-packages/rpds/rpds.cpython-311-darwin.so' (mach-o file, but is an incompatible architecture (have 'x86_64', need 'arm64')). \n",
47+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
48+
]
49+
}
50+
],
51+
"source": [
52+
"import jax\n",
53+
"from genjax import flip\n",
54+
"from genjax import beta\n",
55+
"from genjax import bernoulli\n",
56+
"from genjax import static_gen_fn\n",
57+
"\n",
58+
"# Define a generative function\n",
59+
"@static_gen_fn\n",
60+
"def beta_bernoulli_process(u):\n",
61+
" p = beta(0.0, u) @ \"p\"\n",
62+
" v = bernoulli(p) @ \"v\" # sweet\n",
63+
" return v\n",
64+
"\n",
65+
"# We can:\n",
66+
"# 1] Generate a traced sample\n",
67+
"key = jax.random.PRNGKey(0)\n",
68+
"trace = jax.jit(beta_bernoulli_process.simulate)(key, (0.5,))\n",
69+
"# 1.1] Print the return value\n",
70+
"print(trace.get_retval())\n",
71+
"print()\n",
72+
"# 1.2] Print the choice_map, i.e. the list of internal random choices made during the execution\n",
73+
"print(trace.get_choices())\n",
74+
"print()\n",
75+
"print(trace.get_choices().get_submap(\"p\"))\n",
76+
"# 2] Compute log probabilities\n",
77+
"# 2.1] Print the log probability of the trace\n",
78+
"print(trace.get_score())\n",
79+
"print()\n",
80+
"# 2.2] Print the log probability of an observation under the model\n",
81+
"print(TODO)"
82+
]
83+
},
84+
{
85+
"cell_type": "code",
86+
"execution_count": null,
87+
"metadata": {},
88+
"outputs": [],
89+
"source": []
90+
},
91+
{
92+
"cell_type": "markdown",
93+
"metadata": {},
94+
"source": []
95+
},
96+
{
97+
"cell_type": "code",
98+
"execution_count": null,
99+
"metadata": {},
100+
"outputs": [],
101+
"source": []
102+
}
103+
],
104+
"metadata": {
105+
"kernelspec": {
106+
"display_name": "genjax-trials",
107+
"language": "python",
108+
"name": "python3"
109+
},
110+
"language_info": {
111+
"codemirror_mode": {
112+
"name": "ipython",
113+
"version": 3
114+
},
115+
"file_extension": ".py",
116+
"mimetype": "text/x-python",
117+
"name": "python",
118+
"nbconvert_exporter": "python",
119+
"pygments_lexer": "ipython3",
120+
"version": "3.11.6"
121+
}
122+
},
123+
"nbformat": 4,
124+
"nbformat_minor": 2
125+
}

0 commit comments

Comments
 (0)