Skip to content

Commit 88e385a

Browse files
committed
- added toy example from "Weight Uncertainty in Neural Networks" by Blundell et al.
1 parent c68a2f9 commit 88e385a

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

Diff for: examples/bayesian_neural_net.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def logprob(weights, inputs, targets):
4242

4343

4444
def build_toy_dataset(n_data=100, noise_std=0.1, toy_example='blackbox'):
45+
print(toy_example)
4546
if toy_example == "blackbox":
4647
D = 1
4748
rs = npr.RandomState(0)
@@ -51,7 +52,7 @@ def build_toy_dataset(n_data=100, noise_std=0.1, toy_example='blackbox'):
5152
inputs = (inputs - 4.0) / 4.0
5253
inputs = inputs.reshape((len(inputs), D))
5354
targets = targets.reshape((len(targets), D))
54-
elif toy_example == "wierstra":
55+
if toy_example == "Wierstra":
5556
noise_std = 0.02
5657
inputs = np.linspace(0, 0.5, n_data).reshape(-1,1)
5758
n_traces = 1
@@ -71,7 +72,7 @@ def build_toy_dataset(n_data=100, noise_std=0.1, toy_example='blackbox'):
7172

7273
if __name__ == '__main__':
7374

74-
toy_example = "blackbox" # blackbox or Wierstra
75+
toy_example = "Wierstra" # blackbox or Wierstra
7576

7677
# Specify inference problem by its unnormalized log-posterior.
7778
rbf = lambda x: np.exp(-x**2)
@@ -106,7 +107,7 @@ def callback(params, t, g):
106107
if toy_example == "blackbox":
107108
plot_inputs = np.linspace(-8, 8, num=400)
108109
y_lim = [-2, 3]
109-
elif toy_example == "wierstra":
110+
elif toy_example == "Wierstra":
110111
plot_inputs = np.linspace(-0.2, 1.2, num=400)
111112
y_lim = [-0.6, 1.2]
112113

@@ -115,6 +116,7 @@ def callback(params, t, g):
115116

116117
# Plot data and functions.
117118
plt.cla()
119+
plt.title(toy_example)
118120
ax.plot(inputs.ravel(), targets.ravel(), 'bx')
119121
ax.plot(plot_inputs, outputs[:, :, 0].T)
120122
ax.set_ylim(y_lim)

0 commit comments

Comments
 (0)