-
Notifications
You must be signed in to change notification settings - Fork 38
/
Copy pathbouncingMnist_originalTest.py
41 lines (30 loc) · 1.01 KB
/
bouncingMnist_originalTest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# adapted from https://github.com/emansim/unsupervised-videos
# DataHandler that serves a pre-generated Bouncing MNIST test set from a .npy file
from __future__ import division
import h5py
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
class DataHandler(object):
    """Serve batches from a pre-generated Bouncing MNIST test set.

    The whole ``.npy`` file is loaded into memory once, scaled from
    ``[0, 255]`` to ``[0, 1]`` as ``float32``, and rearranged into the layout
    ``(num_sequences, num_channels, height, width, seq_length)`` with a single
    channel. Batches are obtained by indexing that array.

    Args:
        dataset: Path to a 4-D ``.npy`` array. Axis 0 indexes sequences and
            axes 2/3 are the frame's spatial dimensions; axis 1 is presumably
            time (it is moved to the trailing axis) -- confirm against the
            producer of the file.
        batch_size: Optional batch size reported by ``GetBatchSize``. The
            handler itself never uses it; callers choose the indices passed
            to ``GetBatch``.

    Raises:
        ValueError: If the loaded array is not 4-D.
    """

    def __init__(self, dataset='/path/to/bouncing_mnist_test.npy',
                 batch_size=None):
        raw = np.load(dataset)
        if raw.ndim != 4:
            raise ValueError(
                'expected a 4-D array in %r, got shape %s'
                % (dataset, raw.shape))
        # Append a singleton channel axis, then move the original axis 1 to
        # the end: (N, T, H, W) -> (N, 1, H, W, T).
        data = raw[..., np.newaxis].transpose(0, 4, 2, 3, 1)
        self.data_ = data.astype(np.float32) / 255
        self.dataset_size_ = self.data_.shape[0]
        self.num_channels_ = self.data_.shape[1]
        self.image_size_ = self.data_.shape[2]
        # Pixels per frame (height * width); equals image_size_ ** 2 for
        # square frames, but stays correct for rectangular ones.
        self.frame_size_ = self.data_.shape[2] * self.data_.shape[3]
        # Length of the trailing (time) axis, so GetSeqLength has a value.
        self.seq_length_ = self.data_.shape[4]
        # Stored only so GetBatchSize has something to report.
        self.batch_size_ = batch_size

    def GetBatchSize(self):
        """Return the batch size given to the constructor (None if unset)."""
        return self.batch_size_

    def GetDims(self):
        """Return the number of pixels in one frame (height * width)."""
        return self.frame_size_

    def GetDatasetSize(self):
        """Return the number of sequences in the dataset."""
        return self.dataset_size_

    def GetSeqLength(self):
        """Return the number of frames per sequence."""
        return self.seq_length_

    def Reset(self):
        """No-op: the dataset is fixed and this handler keeps no cursor."""
        pass

    def GetBatch(self, ind):
        """Return ``data_[ind]``.

        Args:
            ind: Index passed straight to ``data_[...]``; typically a list or
                array of sequence indices.

        Returns:
            A float32 array; for a list/array ``ind`` its shape is
            ``(len(ind), num_channels, height, width, seq_length)``. Basic
            indexing (an int or slice) returns a view into the stored data,
            so callers should not modify the result in place.
        """
        return self.data_[ind]