@@ -104,15 +104,15 @@ def _clear_caches(self):
104
104
105
105
### Gibbs sampling
106
106
107
- def resample_model (self ):
108
- self .resample_obs_distns ()
107
+ def resample_model (self , temp = None ):
108
+ self .resample_obs_distns (temp = temp )
109
109
self .resample_trans_distn ()
110
110
self .resample_init_state_distn ()
111
111
self .resample_states ()
112
112
113
- def resample_obs_distns (self ):
113
+ def resample_obs_distns (self , temp = None ):
114
114
for state , distn in enumerate (self .obs_distns ):
115
- distn .resample ([s .data [s .stateseq == state ] for s in self .states_list ])
115
+ distn .resample ([s .data [s .stateseq == state ] for s in self .states_list ], temp = temp )
116
116
self ._clear_caches ()
117
117
118
118
def resample_trans_distn (self ):
@@ -381,9 +381,9 @@ def generate(self,T,keep=True,**kwargs):
381
381
382
382
### Gibbs sampling
383
383
384
- def resample_model (self ):
384
+ def resample_model (self , temp = None ):
385
385
self .resample_dur_distns ()
386
- super (HSMM ,self ).resample_model ()
386
+ super (HSMM ,self ).resample_model (temp = temp )
387
387
388
388
def resample_dur_distns (self ):
389
389
for state , distn in enumerate (self .dur_distns ):
0 commit comments