Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 9abb33d

Browse files
committedApr 11, 2014
Fix get/setRNGState for gaussian state.
Fixes torch#8
1 parent 14b20df commit 9abb33d

File tree

6 files changed

+46
-29
lines changed

6 files changed

+46
-29
lines changed
 

‎lib/TH/THRandom.c

+12
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@ THGenerator* THGenerator_new()
1414
return self;
1515
}
1616

17+
THGenerator* THGenerator_copy(THGenerator *self, THGenerator *from)
18+
{
19+
memcpy(self, from, sizeof(THGenerator));
20+
return self;
21+
}
22+
1723
void THGenerator_free(THGenerator *self)
1824
{
1925
THFree(self);
@@ -89,6 +95,12 @@ unsigned long THRandom_seed(THGenerator *_generator)
8995
void THRandom_manualSeed(THGenerator *_generator, unsigned long the_seed_)
9096
{
9197
int j;
98+
99+
/* This ensures reseeding resets all of the state (i.e. state for Gaussian numbers) */
100+
THGenerator *blank = THGenerator_new();
101+
THGenerator_copy(_generator, blank);
102+
THGenerator_free(blank);
103+
92104
_generator->the_initial_seed = the_seed_;
93105
_generator->state[0] = _generator->the_initial_seed & 0xffffffffUL;
94106
for(j = 1; j < n; j++)

‎lib/TH/THRandom.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
#define _MERSENNE_STATE_N 624
77
#define _MERSENNE_STATE_M 397
8+
/* A THGenerator contains all the state required for a single random number stream */
89
typedef struct THGenerator {
910
/* The initial seed. */
1011
unsigned long the_initial_seed;
@@ -23,10 +24,9 @@ typedef struct THGenerator {
2324

2425
#define torch_Generator "torch.Generator"
2526

26-
/* Create a new random number generator stream */
27+
/* Manipulate THGenerator objects */
2728
TH_API THGenerator * THGenerator_new();
28-
29-
/* Free a random number generator stream */
29+
TH_API THGenerator * THGenerator_copy(THGenerator *self, THGenerator *from);
3030
TH_API void THGenerator_free(THGenerator *gen);
3131

3232
/* Initializes the random number generator with the current time (granularity: seconds) and returns the seed. */

‎lib/TH/generic/THTensorRandom.c

+12-21
Original file line numberDiff line numberDiff line change
@@ -210,33 +210,24 @@ TH_API void THTensor_(multinomial)(THLongTensor *self, THGenerator *_generator,
210210

211211
#endif
212212

213-
#if defined(TH_REAL_IS_LONG)
213+
#if defined(TH_REAL_IS_BYTE)
214214
TH_API void THTensor_(getRNGState)(THGenerator *_generator, THTensor *self)
215215
{
216-
unsigned long *data;
217-
long *offset;
218-
long *left;
219-
220-
THTensor_(resize1d)(self,626);
221-
data = (unsigned long *)THTensor_(data)(self);
222-
offset = (long *)data+624;
223-
left = (long *)data+625;
224-
225-
THRandom_getState(_generator, data, offset, left);
216+
static const size_t size = sizeof(THGenerator);
217+
THGenerator *state;
218+
THTensor_(resize1d)(self, size);
219+
state = (THGenerator *)THTensor_(data)(self);
220+
THGenerator_copy(state, _generator);
226221
}
227222

228223
TH_API void THTensor_(setRNGState)(THGenerator *_generator, THTensor *self)
229224
{
230-
unsigned long *data;
231-
long *offset;
232-
long *left;
233-
234-
THArgCheck(THTensor_(nElement)(self) == 626, 1, "state should have 626 elements");
235-
data = (unsigned long *)THTensor_(data)(self);
236-
offset = (long *)(data+624);
237-
left = (long *)(data+625);
238-
239-
THRandom_setState(_generator, data, *offset, *left);
225+
static const size_t size = sizeof(THGenerator);
226+
THGenerator *state;
227+
THArgCheck(THTensor_(nElement)(self) == size, 1, "RNG state is wrong size");
228+
THArgCheck(THTensor_(isContiguous)(self), 1, "RNG state needs to be contiguous");
229+
state = (THGenerator *)THTensor_(data)(self);
230+
THGenerator_copy(_generator, state);
240231
}
241232
#endif
242233

‎lib/TH/generic/THTensorRandom.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ TH_API void THTensor_(logNormal)(THTensor *self, THGenerator *_generator, double
1515
TH_API void THTensor_(multinomial)(THLongTensor *self, THGenerator *_generator, THTensor *prob_dist, int n_sample, int with_replacement);
1616
#endif
1717

18-
#if defined(TH_REAL_IS_LONG)
18+
#if defined(TH_REAL_IS_BYTE)
1919
TH_API void THTensor_(getRNGState)(THGenerator *_generator, THTensor *self);
2020
TH_API void THTensor_(setRNGState)(THGenerator *_generator, THTensor *self);
2121
#endif

‎random.lua

+4-4
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,15 @@ interface:wrap('manualSeed',
2626
{name="long"}})
2727

2828
interface:wrap('getRNGState',
29-
'THLongTensor_getRNGState',
29+
'THByteTensor_getRNGState',
3030
{{name='Generator', default=true},
31-
{name='LongTensor',default=true,returned=true,method={default='nil'}}
31+
{name='ByteTensor',default=true,returned=true,method={default='nil'}}
3232
})
3333

3434
interface:wrap('setRNGState',
35-
'THLongTensor_setRNGState',
35+
'THByteTensor_setRNGState',
3636
{{name='Generator', default=true},
37-
{name='LongTensor',default=true,returned=true,method={default='nil'}}
37+
{name='ByteTensor',default=true,returned=true,method={default='nil'}}
3838
})
3939

4040
interface:register("random__")

‎test/test.lua

+14
Original file line numberDiff line numberDiff line change
@@ -1105,6 +1105,20 @@ function torchtest.RNGState()
11051105
mytester:assertTensorEq(before, after, 1e-16, 'getRNGState/setRNGState not generating same sequence')
11061106
end
11071107

1108+
function torchtest.testBoxMullerState()
1109+
torch.manualSeed(123)
1110+
local odd_number = 101
1111+
local seeded = torch.randn(odd_number)
1112+
local state = torch.getRNGState()
1113+
local midstream = torch.randn(odd_number)
1114+
torch.setRNGState(state)
1115+
local repeat_midstream = torch.randn(odd_number)
1116+
torch.manualSeed(123)
1117+
local reseeded = torch.randn(odd_number)
1118+
mytester:assertTensorEq(midstream, repeat_midstream, 1e-16, 'getRNGState/setRNGState not generating same sequence of normally distributed numbers')
1119+
mytester:assertTensorEq(seeded, reseeded, 1e-16, 'repeated calls to manualSeed not generating same sequence of normally distributed numbers')
1120+
end
1121+
11081122
function torchtest.testCholesky()
11091123
local x = torch.rand(10,10)
11101124
local A = torch.mm(x, x:t())

0 commit comments

Comments
 (0)