Skip to content

Commit 3cc4794

Browse files
committed
Propagating temperature parameter
1 parent a41ccc1 commit 3cc4794

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

examples/library-mm.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
T = 1000
1616

17-
f = np.load('/Users/mattjj/Dropbox/Test Data/TMT_50p_mixtures_and_data.npz')
17+
f = np.load("/Users/Alex/Dropbox/Science/Datta lab/Posture Tracking/Test Data/TMT_50p_mixtures_and_data.npz")
1818
mus = f['mu']
1919
sigmas = f['sigma']
2020
data = f['data'][:T]
@@ -87,12 +87,12 @@
8787
# for the GMMs described above. roughly, gamma controls the total number
8888
# of states while alpha controls the diversity of the transition
8989
# distributions.
90-
alpha=10.,gamma=10.,
90+
# alpha=10.,gamma=10.,
9191
# NOTE: as with a_0 and b_0 for the GMMs described above, we can also
9292
# put gamma priors over alpha and gamma by commenting out the direct
9393
# alpha= and gamma= lines and using these instead
94-
# alpha_a_0=1.,alpha_b_0=1./10,
95-
# gamma_a_0=1.,gamma_b_0=1./10,
94+
alpha_a_0=1.,alpha_b_0=1./10,
95+
gamma_a_0=1.,gamma_b_0=1./10,
9696
obs_distns=hsmm_obs_distns,
9797
dur_distns=dur_distns)
9898
model.trans_distn.max_likelihood([rle(labels)[0]])

models.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -104,15 +104,15 @@ def _clear_caches(self):
104104

105105
### Gibbs sampling
106106

107-
def resample_model(self):
108-
self.resample_obs_distns()
107+
def resample_model(self, temp=None):
108+
self.resample_obs_distns(temp=temp)
109109
self.resample_trans_distn()
110110
self.resample_init_state_distn()
111111
self.resample_states()
112112

113-
def resample_obs_distns(self):
113+
def resample_obs_distns(self, temp=None):
114114
for state, distn in enumerate(self.obs_distns):
115-
distn.resample([s.data[s.stateseq == state] for s in self.states_list])
115+
distn.resample([s.data[s.stateseq == state] for s in self.states_list], temp=temp)
116116
self._clear_caches()
117117

118118
def resample_trans_distn(self):
@@ -381,9 +381,9 @@ def generate(self,T,keep=True,**kwargs):
381381

382382
### Gibbs sampling
383383

384-
def resample_model(self):
384+
def resample_model(self, temp=None):
385385
self.resample_dur_distns()
386-
super(HSMM,self).resample_model()
386+
super(HSMM,self).resample_model(temp=temp)
387387

388388
def resample_dur_distns(self):
389389
for state, distn in enumerate(self.dur_distns):

0 commit comments

Comments
 (0)