From 8472696fc62385ccc6122e901bba71f82050d2c9 Mon Sep 17 00:00:00 2001 From: 1FengL Date: Thu, 25 Jul 2019 14:44:01 +0100 Subject: [PATCH 1/6] add dataflow module --- tensorlayer/dataflow/__init__.py | 4 + tensorlayer/dataflow/base.py | 73 +++++ tensorlayer/dataflow/common.py | 351 +++++++++++++++++++++++ tensorlayer/dataflow/dataset/__init__.py | 5 + tensorlayer/dataflow/dataset/cifar10.py | 58 ++++ tensorlayer/dataflow/dataset/ilsvrc.py | 309 ++++++++++++++++++++ tensorlayer/dataflow/dataset/mnist.py | 62 ++++ tensorlayer/dataflow/parallel.py | 198 +++++++++++++ tensorlayer/dataflow/serialize.py | 27 ++ tensorlayer/dataflow/utils.py | 199 +++++++++++++ 10 files changed, 1286 insertions(+) create mode 100644 tensorlayer/dataflow/__init__.py create mode 100644 tensorlayer/dataflow/base.py create mode 100644 tensorlayer/dataflow/common.py create mode 100644 tensorlayer/dataflow/dataset/__init__.py create mode 100644 tensorlayer/dataflow/dataset/cifar10.py create mode 100644 tensorlayer/dataflow/dataset/ilsvrc.py create mode 100644 tensorlayer/dataflow/dataset/mnist.py create mode 100644 tensorlayer/dataflow/parallel.py create mode 100644 tensorlayer/dataflow/serialize.py create mode 100644 tensorlayer/dataflow/utils.py diff --git a/tensorlayer/dataflow/__init__.py b/tensorlayer/dataflow/__init__.py new file mode 100644 index 000000000..fec880c7f --- /dev/null +++ b/tensorlayer/dataflow/__init__.py @@ -0,0 +1,4 @@ +from .base import Dataset +from .base import Transform +from .common import Dataloader +from .common import TFDataloader diff --git a/tensorlayer/dataflow/base.py b/tensorlayer/dataflow/base.py new file mode 100644 index 000000000..3caacf6b7 --- /dev/null +++ b/tensorlayer/dataflow/base.py @@ -0,0 +1,73 @@ +class Dataset(object): + + def __getitem__(self, index): + raise NotImplementedError("A Dataset must implement __getitem__(index) method.") + + def __len__(self): + raise NotImplementedError("A Dataset must implement __len__() method.") + + def __iter__(self): + for i in range(self.__len__()): + yield self.__getitem__(i) + + def __call__(self, *args, **kwargs): + return self.__iter__() + + +class DatasetWrapper(object): + def __init__(self, ds): + self.ds = ds + self.ds_len = len(ds) + + def __len__(self): + return len(self.ds) + + def __iter__(self): + for dp in self.ds: + yield dp + + def __call__(self, *args, **kwargs): + return self.__iter__() + + +class IndexableDatasetWrapper(object): + def __init__(self, ds): + self.ds = ds + self.ds_len = len(ds) + + def __getitem__(self, index): + return self.ds.__getitem__(index) + + def __len__(self): + return len(self.ds) + + def __call__(self, *args, **kwargs): + return self + + +class Transform(object): + def __call__(self, *args, **kwargs): + raise NotImplementedError("Transform must implement __call__() method.") + + +class _Transforms_for_tf_dataset(object): + """ + This class aggregate Transforms into one object in order to use tf.data.Dataset.map API + """ + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, *args): + # assert len(args) == len(self.transforms) + # data_list = [None] * len(args) + # for i in range(len(args)): + # data = args[i] + # for transform in self.transforms[i]: + # data = transform(data) + # data_list[i] = data + # return data_list + data_list = list(args) + for transform in self.transforms: + data_list = transform(*data_list) + return data_list diff --git a/tensorlayer/dataflow/common.py b/tensorlayer/dataflow/common.py new file mode 100644 index 
000000000..a2e4c5adc --- /dev/null +++ b/tensorlayer/dataflow/common.py @@ -0,0 +1,351 @@ +import atexit +import math +import multiprocessing +import os + +import tensorflow as tf +import zmq + +import numpy as np + +from .base import IndexableDatasetWrapper, DatasetWrapper, _Transforms_for_tf_dataset +from .parallel import _get_pipe_name, ZMQMultiprocessDataset, MultiprocessDataset +from .utils import ensure_proc_terminate +from .serialize import convert_to_bytes, load_from_bytes + +__all__ = ['BatchedDataset', 'TransformedDataset', 'ShuffledDataset', + 'AugmentedDataset', 'Dataloader', 'TFDataloader'] + + +class BatchedDataset(DatasetWrapper): + def __init__(self, + ds, + batch_size, + drop_remainder=True, + return_numpy=True, + keep_dims=False, + output_types=None, + use_zmq=True): + super(BatchedDataset, self).__init__(ds) + self.batch_size = batch_size + self.drop_remainder = drop_remainder + self.return_numpy = return_numpy + self.keep_dims = keep_dims + self.output_types = output_types + self.use_zmq = use_zmq + + # self.q = multiprocessing.Queue(maxsize=1) + # self.worker = multiprocessing.Process(target=self._BatchedDataset_worker, + # args=(self.ds, self.q)) + # self.worker.start() + # ensure_proc_terminate(self.worker) + + if self.use_zmq: + self.data_pipename = _get_pipe_name('batch_prefetch') + context = zmq.Context() + self.fetch_data_socket = context.socket(zmq.PULL) + self.fetch_data_socket.bind(self.data_pipename) + self.worker = multiprocessing.Process(target=self._ZMQ_BatchedDataset_worker, + args=(self.ds,)) + self.worker.start() + else: + pipe_output, pipe_input = multiprocessing.Pipe() + self.worker = multiprocessing.Process(target=self._BatchedDataset_worker, + args=(self.ds, (pipe_output, pipe_input))) + self.worker.start() + # main process only reads (gets output) + pipe_input.close() + self.pipe_output = pipe_output + + ensure_proc_terminate(self.worker) + + def _ZMQ_BatchedDataset_worker(self, ds): + context = zmq.Context() + prepare_data_socket = context.socket(zmq.PUSH) + prepare_data_socket.connect(self.data_pipename) + while True: + dp_buffer = [] + for dp in ds: + dp_buffer.append(dp) + if len(dp_buffer) == self.batch_size: + # q.put(self._batch_datapoints(dp_buffer)) + prepare_data_socket.send(convert_to_bytes(self._batch_datapoints(dp_buffer)), copy=False) + del dp_buffer[:] + if not self.drop_remainder: + # q.put(self._batch_datapoints(dp_buffer)) + prepare_data_socket.send(convert_to_bytes(self._batch_datapoints(dp_buffer)), copy=False) + + def _BatchedDataset_worker(self, ds, pipe): + pipe_output, pipe_input = pipe + # worker process only writes (puts input) + pipe_output.close() + while True: + dp_buffer = [] + for dp in ds: + dp_buffer.append(dp) + if len(dp_buffer) == self.batch_size: + # q.put(self._batch_datapoints(dp_buffer)) + pipe_input.send(self._batch_datapoints(dp_buffer)) + del dp_buffer[:] + if not self.drop_remainder: + # q.put(self._batch_datapoints(dp_buffer)) + pipe_input.send(self._batch_datapoints(dp_buffer)) + + def __iter__(self): + for _ in range(self.__len__()): + # yield self.q.get() + if self.use_zmq: + yield load_from_bytes(self.fetch_data_socket.recv(copy=False)) + else: + yield self.pipe_output.recv() + + def __len__(self): + ds_len = len(self.ds) + if self.drop_remainder: + return ds_len // self.batch_size + else: + return math.ceil(ds_len / self.batch_size) + + def _batch_datapoints(self, dp_buffer): + """ + + :param dp_buffer: a list of datapoints + :return: + """ + first_dp = dp_buffer[0] + if isinstance(first_dp, (tuple, 
list)): + dp_batch = [None] * len(first_dp) + for i in range(len(first_dp)): + dp_element_batch = [] + for j in range(len(dp_buffer)): + dp_element_batch.append(dp_buffer[j][i]) + if self.return_numpy: + dp_batch[i] = self._batch_ndarray(dp_element_batch, dtype=self._get_element_dtype(i)) + else: + dp_batch[i] = dp_element_batch + return dp_batch + elif isinstance(first_dp, dict): + dp_batch = {} + for key in first_dp.keys(): + dp_element_batch = [] + for j in range(len(dp_buffer)): + dp_element_batch.append(dp_buffer[j][key]) + if self.return_numpy: + dp_batch[key] = self._batch_ndarray(dp_element_batch, dtype=None) + else: + dp_batch[key] = dp_element_batch + return dp_batch + elif isinstance(first_dp, np.ndarray): + return self._batch_ndarray(dp_buffer) + # single elements + else: + if self.return_numpy: + return self._batch_ndarray(dp_buffer, dtype=self._get_element_dtype(0)) + else: + return dp_buffer + + def _batch_ndarray(self, dp_element_batch, dtype): + """ + + :param dp_element_batch: a list of datapoint element, an element can be np.ndarray / list + :return: np.ndarray, type is the same as input + """ + try: + if dtype is not None: + ret = np.asarray(dp_element_batch, dtype=dtype) + else: + ret = np.asarray(dp_element_batch) + if self.keep_dims and len(ret.shape) == 1: + ret = np.expand_dims(ret, 1) + return ret + except: + raise ValueError("Unsupported type for batching.") + + def _get_element_dtype(self, i): + if self.output_types is None: + return None + if not isinstance(self.output_types, (tuple, list)): + return self.output_types + if len(self.output_types) == 1: + return self.output_types[0] + return self.output_types[i] + + +class ShuffledDataset(DatasetWrapper): + def __init__(self, ds): + super(ShuffledDataset, self).__init__(ds) + + def __iter__(self): + self.shuffled_idxs = np.random.permutation(len(self.ds)) + for index, data in enumerate(self.ds): + yield self.ds[self.shuffled_idxs[index]] + + +class TransformedDataset(IndexableDatasetWrapper): + """ + + """ + + def __init__(self, ds, transforms): + super(TransformedDataset, self).__init__(ds) + self.transforms = transforms + + def __getitem__(self, index): + dp = self.ds[index] + for transform in self.transforms: + assert callable(transform) + if isinstance(dp, (list, tuple)): + dp = transform(*dp) + else: + dp = transform(dp) + return dp + + +class AugmentedDataset(IndexableDatasetWrapper): + def __init__(self, ds, augmentations): + super(AugmentedDataset, self).__init__(ds) + self.augmentations = augmentations + self.num_augmentations = len(self.augmentations) + + def __getitem__(self, index): + if index >= self.__len__(): + raise IndexError + dp = self.ds[index % self.ds_len] + if index < self.ds_len: + return dp + augmentation = self.augmentations[(index // self.ds_len) - 1] + assert callable(augmentation) + if isinstance(dp, (list, tuple)): + return augmentation(*dp) + else: + return augmentation(dp) + + def __len__(self): + # every augmentation gives one more duplication of dataset + return self.ds_len * (1 + self.num_augmentations) + + +class Dataloader(DatasetWrapper): + def __init__(self, + ds, + augmentations=None, + shuffle=False, + batch_size=1, + drop_remainder=True, + batch_keep_dims=False, + output_types=None, + num_worker=os.cpu_count(), + use_zmq=True, + num_prefetch=None, + transforms=None): + + super(Dataloader, self).__init__(ds) + self.augmentations = augmentations + self.shuffle = shuffle + self.batch_size = batch_size + self.drop_remainder = drop_remainder + self.batch_keep_dims = 
batch_keep_dims + self.output_types = output_types + self.num_worker = num_worker + self.use_zmq = use_zmq + self.num_prefetch = num_worker if num_prefetch is None else num_prefetch + self.transforms = transforms + + if self.augmentations is not None: + self.ds = AugmentedDataset(self.ds, self.augmentations) + + if self.transforms is not None: + self.ds = TransformedDataset(self.ds, self.transforms) + # self.tfds = self.tfds.map(map_func=_Transforms(self.transforms), num_parallel_calls=num_map_worker) + + # TODO: auto adjust num_prefetch + if self.num_worker > 1: + if self.use_zmq: + self.ds = ZMQMultiprocessDataset(self.ds, num_worker=self.num_worker, hwm=self.num_prefetch, + shuffle=self.shuffle) + else: + self.ds = MultiprocessDataset(self.ds, num_worker=self.num_worker, num_prefetch=self.num_prefetch, + shuffle=self.shuffle) + elif self.shuffle: + self.ds = ShuffledDataset(self.ds) + + self.ds = BatchedDataset(self.ds, self.batch_size, drop_remainder=self.drop_remainder, + output_types=self.output_types, keep_dims=self.batch_keep_dims, + use_zmq=self.use_zmq) + + # self.tfds = tf.data.Dataset.from_generator(self.ds, output_types=output_types) + + # if self.num_prefetch > 1: + # self.tfds = self.tfds.prefetch(num_prefetch) + atexit.register(self._clean_up_socket_files) + + def __iter__(self): + for dp in self.ds: + yield dp + + def _clean_up_socket_files(self): + # remove all ipc socket files + # the environment variable starts with 'ipc://', so file name starts from 6 + try: + os.remove(os.environ['put_idx'][6:]) + except FileNotFoundError: + pass + try: + os.remove(os.environ['collect_data'][6:]) + except FileNotFoundError: + pass + try: + os.remove(os.environ['batch_prefetch'][6:]) + except FileNotFoundError: + pass + + +class TFDataloader(DatasetWrapper): + def __init__(self, + ds, + output_types, + augmentations=None, + shuffle=False, + shuffle_buffer_size=None, + batch_size=1, + drop_remainder=True, + # num_extract_worker=os.cpu_count(), + # num_map_worker=os.cpu_count(), + # num_prefetch=None, + transforms=None): + + super(TFDataloader, self).__init__(ds) + self.augmentations = augmentations + self.shuffle = shuffle + self.batch_size = batch_size + self.shuffle_buffer_size = 2 * batch_size if shuffle_buffer_size is None else shuffle_buffer_size + self.drop_remainder = drop_remainder + # self.num_map_worker = num_map_worker + # self.num_extract_worker = num_extract_worker + # self.num_prefetch = num_extract_worker if num_prefetch is None else num_prefetch + self.transforms = transforms + + self.ds = tf.data.Dataset.from_generator(self.ds, output_types=output_types) + + # if self.augmentations is not None: + # self.ds = AugmentedDataset(self.ds, self.augmentations) + + # if self.num_extract_worker > 1: + # self.ds = MultiProcessDataset(self.ds, num_worker=self.num_extract_worker, num_prefetch=self.num_prefetch) + + if self.shuffle: + self.ds = self.ds.shuffle(buffer_size=self.shuffle_buffer_size) + + if self.transforms is not None: + self.ds = self.ds.map(map_func=_Transforms_for_tf_dataset(self.transforms), + num_parallel_calls=tf.data.experimental.AUTOTUNE) + + if self.batch_size > 1: + self.ds = self.ds.batch(batch_size=self.batch_size, drop_remainder=self.drop_remainder) + + # if self.num_prefetch > 1: + self.ds = self.ds.prefetch(tf.data.experimental.AUTOTUNE) + + def __iter__(self): + for dp in self.ds: + yield dp diff --git a/tensorlayer/dataflow/dataset/__init__.py b/tensorlayer/dataflow/dataset/__init__.py new file mode 100644 index 000000000..d95f8afa6 --- /dev/null +++ 
b/tensorlayer/dataflow/dataset/__init__.py @@ -0,0 +1,5 @@ +from .mnist import MNIST +from .cifar10 import CIFAR10 +from .ilsvrc import ILSVRC12, ILSVRC12Files, ILSVRCMeta + +__all__ = ['MNIST', 'CIFAR10', 'ILSVRCMeta', 'ILSVRC12Files', 'ILSVRC12'] \ No newline at end of file diff --git a/tensorlayer/dataflow/dataset/cifar10.py b/tensorlayer/dataflow/dataset/cifar10.py new file mode 100644 index 000000000..f6ca6209f --- /dev/null +++ b/tensorlayer/dataflow/dataset/cifar10.py @@ -0,0 +1,58 @@ +import logging +import os +import pickle +import sys +import numpy as np + +from ..base import Dataset +from ..utils import maybe_download_and_extract + +CIFAR10_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' + + +class CIFAR10(Dataset): + def __init__(self, train_or_test, path='data', name='cifar10'): + self.path = os.path.join(path, name) + + # Helper function to unpickle the data + def unpickle(file): + fp = open(file, 'rb') + if sys.version_info.major == 2: + data = pickle.load(fp) + elif sys.version_info.major == 3: + data = pickle.load(fp, encoding='latin-1') + else: + raise RuntimeError("Sys Version Unsupported") + fp.close() + return data + + # Download and read the training and test set images and labels. + logging.info("Load or Download {0} > {1}".format(name.upper(), self.path)) + + filename = 'cifar-10-python.tar.gz' + maybe_download_and_extract(filename, path, CIFAR10_URL, extract=True) + + assert train_or_test in ['train', 'test'] + if train_or_test == 'train': + # Unpickle file and fill in data + self.images = None + self.labels = [] + for i in range(1, 6): + data_dic = unpickle(os.path.join(path, 'cifar-10-batches-py/', "data_batch_{}".format(i))) + if i == 1: + self.images = data_dic['data'] + else: + self.images = np.vstack((self.images, data_dic['data'])) + self.labels += data_dic['labels'] + else: + test_data_dic = unpickle(os.path.join(path, 'cifar-10-batches-py/', "test_batch")) + self.images = test_data_dic['data'] + self.labels = np.array(test_data_dic['labels']) + + self.images = self.images.reshape((-1, 32, 32, 3)) + + def __len__(self): + return self.images.shape[0] + + def __getitem__(self, index): + return self.images[index], self.labels[index] diff --git a/tensorlayer/dataflow/dataset/ilsvrc.py b/tensorlayer/dataflow/dataset/ilsvrc.py new file mode 100644 index 000000000..e75b973f7 --- /dev/null +++ b/tensorlayer/dataflow/dataset/ilsvrc.py @@ -0,0 +1,309 @@ +import os +import logging +import cv2 + +from ..base import Dataset +from ..utils import maybe_download_and_extract + +__all__ = ['ILSVRCMeta', 'ILSVRC12', 'ILSVRC12Files'] + +CAFFE_ILSVRC12_URL = "http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz" + + +class ILSVRCMeta(object): + """ + Provide methods to access metadata for ILSVRC dataset. 
+ Metadata is supposed to be found at/will be downloaded to 'path/name/' + + Parameters + ---------- + path : str + a folder path + name : str + name of the dataset + + Examples + -------- + >>> meta = ILSVRCMeta(path='data', name='ilsvrc') + >>> imglist = meta.get_image_list(train_or_val_or_test, dir_structure) + + """ + + def __init__(self, path='data', name='ilsvrc'): + path = os.path.expanduser(path) + self.path = os.path.join(path, name) + logging.info("Load or Download {0} > {1}".format(name.upper(), self.path)) + self.filepath = maybe_download_and_extract('ilsvrc_meta', self.path, CAFFE_ILSVRC12_URL, extract=True) + self.caffepb = None + + def get_synset_words_1000(self): + """ + Returns: + dict: {cls_number: cls_name} + """ + fname = os.path.join(self.path, 'synset_words.txt') + assert os.path.isfile(fname), fname + lines = [x.strip() for x in open(fname).readlines()] + return dict(enumerate(lines)) + + def get_synset_1000(self): + """ + Returns: + dict: {cls_number: synset_id} + """ + fname = os.path.join(self.path, 'synsets.txt') + assert os.path.isfile(fname) + lines = [x.strip() for x in open(fname).readlines()] + return dict(enumerate(lines)) + + def get_image_list(self, name, dir_structure='original'): + """ + Args: + name (str): 'train' or 'val' or 'test' + dir_structure (str): same as in :meth:`ILSVRC12.__init__()`. + Returns: + list: list of (image filename, label) + """ + assert name in ['train', 'val', 'test'] + assert dir_structure in ['original', 'train'] + add_label_to_fname = (name != 'train' and dir_structure != 'original') + if add_label_to_fname: + synset = self.get_synset_1000() + + fname = os.path.join(self.path, name + '.txt') + assert os.path.isfile(fname), fname + with open(fname) as f: + ret = [] + for line in f.readlines(): + name, cls = line.strip().split() + cls = int(cls) + + if add_label_to_fname: + name = os.path.join(synset[cls], name) + + ret.append((name.strip(), cls)) + assert len(ret), fname + return ret + + # def get_per_pixel_mean(self, size=None): + # """ + # Args: + # size (tuple): image size in (h, w). Defaults to (256, 256). + # Returns: + # np.ndarray: per-pixel mean of shape (h, w, 3 (BGR)) in range [0, 255]. + # """ + # if self.caffepb is None: + # self.caffepb = get_caffe_pb() + # obj = self.caffepb.BlobProto() + # + # mean_file = os.path.join(self.dir, 'imagenet_mean.binaryproto') + # with open(mean_file, 'rb') as f: + # obj.ParseFromString(f.read()) + # arr = np.array(obj.data).reshape((3, 256, 256)).astype('float32') + # arr = np.transpose(arr, [1, 2, 0]) + # if size is not None: + # arr = cv2.resize(arr, size[::-1]) + # return arr + + @staticmethod + def guess_dir_structure(dir): + """ + Return the directory structure of "dir". + + Args: + dir(str): something like '/path/to/imagenet/val' + + Returns: + either 'train' or 'original' + """ + subdir = os.listdir(dir)[0] + # find a subdir starting with 'n' + if subdir.startswith('n') and \ + os.path.isdir(os.path.join(dir, subdir)): + dir_structure = 'train' + else: + dir_structure = 'original' + logging.info( + "[ILSVRC12] Assuming directory {} has '{}' structure.".format( + dir, dir_structure)) + return dir_structure + + +class ILSVRC12Files(Dataset): + """ + Same as :class:`ILSVRC12`, but produces filenames of the images instead of nparrays. + This could be useful when ``cv2.imread`` is a bottleneck and you want to + decode it in smarter ways (e.g. in parallel). + """ + def __init__(self, path, train_or_val_or_test, meta_dir, + dir_structure=None): + """ + Same as in :class:`ILSVRC12`. 
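+        Each item produced by this dataset is a ``(filename, label)`` pair; image
+        decoding is intentionally left to the caller.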
+ """ + assert train_or_val_or_test in ['train', 'test', 'val'] + path = os.path.expanduser(path) + assert os.path.isdir(path) + self.full_path = os.path.join(path, train_or_val_or_test) + self.path = train_or_val_or_test + # assert os.path.isdir(self.full_path) + # assert os.path.isdir(meta_dir) + + if train_or_val_or_test == 'train': + dir_structure = 'train' + elif dir_structure is None: + dir_structure = ILSVRCMeta.guess_dir_structure(self.full_path) + + meta = ILSVRCMeta(meta_dir) + self.imglist = meta.get_image_list(train_or_val_or_test, dir_structure) + + # for fname, _ in self.imglist[:10]: + # fname = os.path.join(self.full_path, fname) + # assert os.path.isfile(fname), fname + + def __len__(self): + return len(self.imglist) + + def __getitem__(self, index): + fname, label = self.imglist[index] + fname = os.path.join(self.full_path, fname) + return fname, label + + # def __iter__(self): + # idxs = np.arange(len(self.imglist)) + # if self.shuffle: + # self.rng.shuffle(idxs) + # for k in idxs: + # fname, label = self.imglist[k] + # fname = os.path.join(self.full_dir, fname) + # yield [fname, label] + + +class ILSVRC12(ILSVRC12Files): + """ + Produces uint8 ILSVRC12 images of shape [h, w, 3(BGR)], and a label between [0, 999]. + """ + def __init__(self, path, train_or_test, meta_dir, + dir_structure=None, shape=None): + """ + Args: + dir (str): A directory containing a subdir named ``name``, + containing the images in a structure described below. + name (str): One of 'train' or 'val' or 'test'. + shuffle (bool): shuffle the dataset. + Defaults to True if name=='train'. + dir_structure (str): One of 'original' or 'train'. + The directory structure for the 'val' directory. + 'original' means the original decompressed directory, which only has list of image files (as below). + If set to 'train', it expects the same two-level directory structure similar to 'dir/train/'. + By default, it tries to automatically detect the structure. + You probably do not need to care about this option because 'original' is what people usually have. + + Example: + + When `dir_structure=='original'`, `dir` should have the following structure: + + .. code-block:: none + + dir/ + train/ + n02134418/ + n02134418_198.JPEG + ... + ... + val/ + ILSVRC2012_val_00000001.JPEG + ... + test/ + ILSVRC2012_test_00000001.JPEG + ... + + With the downloaded ILSVRC12_img_*.tar, you can use the following + command to build the above structure: + + .. code-block:: none + + mkdir val && tar xvf ILSVRC12_img_val.tar -C val + mkdir test && tar xvf ILSVRC12_img_test.tar -C test + mkdir train && tar xvf ILSVRC12_img_train.tar -C train && cd train + find -type f -name '*.tar' | parallel -P 10 'echo {} && mkdir -p {/.} && tar xf {} -C {/.}' + + When `dir_structure=='train'`, `dir` should have the following structure: + + .. code-block:: none + + dir/ + train/ + n02134418/ + n02134418_198.JPEG + ... + ... + val/ + n01440764/ + ILSVRC2012_val_00000293.JPEG + ... + ... + test/ + ILSVRC2012_test_00000001.JPEG + ... + """ + super(ILSVRC12, self).__init__( + path, train_or_test, meta_dir, dir_structure) + self.shape = shape + + """ + There are some CMYK / png images, but cv2 seems robust to them. 
+ https://github.com/tensorflow/models/blob/c0cd713f59cfe44fa049b3120c417cc4079c17e3/research/inception/inception/data/build_imagenet_data.py#L264-L300 + """ + # def __iter__(self): + # for fname, label in super(ILSVRC12, self).__iter__(): + # im = cv2.imread(fname, cv2.IMREAD_COLOR) + # assert im is not None, fname + # yield [im, label] + + def __getitem__(self, index): + fname, label = super(ILSVRC12, self).__getitem__(index) + img = cv2.imread(fname, cv2.IMREAD_COLOR) + if self.shape is not None: + img = cv2.resize(img, self.shape) + return img, label + + # @staticmethod + # def get_training_bbox(bbox_dir, imglist): + # import xml.etree.ElementTree as ET + # ret = [] + # + # def parse_bbox(fname): + # root = ET.parse(fname).getroot() + # size = root.find('size').getchildren() + # size = map(int, [size[0].text, size[1].text]) + # + # box = root.find('object').find('bndbox').getchildren() + # box = map(lambda x: float(x.text), box) + # return np.asarray(box, dtype='float32') + # + # with timed_operation('Loading Bounding Boxes ...'): + # cnt = 0 + # for k in tqdm.trange(len(imglist)): + # fname = imglist[k][0] + # fname = fname[:-4] + 'xml' + # fname = os.path.join(bbox_dir, fname) + # try: + # ret.append(parse_bbox(fname)) + # cnt += 1 + # except Exception: + # ret.append(None) + # logger.info("{}/{} images have bounding box.".format(cnt, len(imglist))) + # return ret + + +# if __name__ == '__main__': +# meta = ILSVRCMeta() +# # print(meta.get_synset_words_1000()) +# +# ds = ILSVRC12('/home/wyx/data/fake_ilsvrc/', 'train', shuffle=False) +# ds.reset_state() +# +# for k in ds: +# from IPython import embed +# embed() +# break diff --git a/tensorlayer/dataflow/dataset/mnist.py b/tensorlayer/dataflow/dataset/mnist.py new file mode 100644 index 000000000..15db8c951 --- /dev/null +++ b/tensorlayer/dataflow/dataset/mnist.py @@ -0,0 +1,62 @@ +import gzip +import logging +import os +import numpy as np + +from ..base import Dataset +from ..utils import maybe_download_and_extract + +MNIST_TRAIN_IMAGE_URL = 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz' +MNIST_TRAIN_LABEL_URL = 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz' +MNIST_TEST_IMAGE_URL = 'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz' +MNIST_TEST_LABEL_URL = 'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz' + + +class MNIST(Dataset): + def __init__(self, train_or_test, path='data', name='mnist'): + path = os.path.expanduser(path) + self.path = os.path.join(path, name) + + assert train_or_test in ['train', 'test'] + if train_or_test == 'train': + self.images = self.load_mnist_images(train_or_test=train_or_test) + self.labels = self.load_mnist_labels(train_or_test=train_or_test) + else: + self.images = self.load_mnist_images(train_or_test=train_or_test) + self.labels = self.load_mnist_labels(train_or_test=train_or_test) + + def load_mnist_images(self, train_or_test): + if train_or_test == 'train': + filepath = maybe_download_and_extract('train-images-idx3-ubyte.gz', self.path, MNIST_TRAIN_IMAGE_URL) + else: + filepath = maybe_download_and_extract('t10k-images-idx3-ubyte.gz', self.path, MNIST_TEST_IMAGE_URL) + + logging.info(filepath) + # Read the inputs in Yann LeCun's binary format. 
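+        # The IDX image file starts with a 16-byte header (magic number, number of
+        # images, rows, columns), which is why np.frombuffer below skips 16 bytes;
+        # the label files use an 8-byte header (magic number, number of labels).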
+ with gzip.open(filepath, 'rb') as f: + data = np.frombuffer(f.read(), np.uint8, offset=16) + # The inputs are vectors now, we reshape them to monochrome 2D images, + # following the shape convention: (examples, channels, rows, columns) + data = data.reshape((-1, 28, 28, 1)) + # The inputs come as bytes, we convert them to float32 in range [0,1]. + # (Actually to range [0, 255/256], for compatibility to the version + # provided at http://deeplearning.net/data/mnist/mnist.pkl.gz.) + return data / np.float32(256) + + def load_mnist_labels(self, train_or_test): + if train_or_test == 'train': + filepath = maybe_download_and_extract('train-labels-idx1-ubyte.gz', self.path, MNIST_TRAIN_LABEL_URL) + else: + filepath = maybe_download_and_extract('t10k-labels-idx1-ubyte.gz', self.path, MNIST_TEST_LABEL_URL) + + # Read the labels in Yann LeCun's binary format. + with gzip.open(filepath, 'rb') as f: + data = np.frombuffer(f.read(), np.uint8, offset=8) + # The labels are vectors of integers now, that's exactly what we want. + return data + + def __len__(self): + return self.images.shape[0] + + def __getitem__(self, index): + return self.images[index], self.labels[index] diff --git a/tensorlayer/dataflow/parallel.py b/tensorlayer/dataflow/parallel.py new file mode 100644 index 000000000..6eb250e21 --- /dev/null +++ b/tensorlayer/dataflow/parallel.py @@ -0,0 +1,198 @@ +import multiprocessing +import os +import sys +import uuid + +import zmq +import numpy as np + +from .base import DatasetWrapper +from .serialize import * + + +class MultiprocessDataset(DatasetWrapper): + def __init__(self, + ds, + num_worker, + num_prefetch, + shuffle=False): + + super(MultiprocessDataset, self).__init__(ds) + self.num_worker = num_worker + self.num_prefetch = num_prefetch + self.shuffle = shuffle + + self.index_queue = multiprocessing.Queue(self.num_worker) + self.data_queue = multiprocessing.Queue(self.num_prefetch) + self.put_idx_worker = None + for _ in range(num_worker): + worker = multiprocessing.Process(target=self._worker, + args=(self.ds, self.index_queue, self.data_queue)) + worker.daemon = True + worker.start() + + def _worker(self, ds, index_q, data_q): + while True: + idx = index_q.get() + data_q.put((idx, ds[idx])) + + def _put_idxs(self, idxs, index_q): + for idx in idxs: + index_q.put(idx) + + def __iter__(self): + # shutdown put_idx_worker and clear queues from previous epoch + _shutdown_proc(self.put_idx_worker) + while not self.index_queue.empty(): + self.index_queue.get() + while not self.data_queue.empty(): + self.data_queue.get() + + # shuffle at the start of every epoch + if self.shuffle: + self.idxs = np.random.permutation(self.ds_len) + else: + self.idxs = np.arange(self.ds_len) + + self.put_idx_worker = multiprocessing.Process(target=self._put_idxs, + args=(self.idxs, self.index_queue)) + self.put_idx_worker.daemon = True + self.put_idx_worker.start() + + data_buffer = {} + for return_idx in self.idxs: + if return_idx in data_buffer: + yield data_buffer.pop(return_idx) + else: + while True: + idx, dp = self.data_queue.get() + if idx == return_idx: + yield dp + break + else: + data_buffer[idx] = dp + _shutdown_proc(self.put_idx_worker) + + +def _shutdown_proc(proc): + if proc is None: + return + if proc.is_alive(): + proc.terminate() + proc.join() + + +class ZMQMultiprocessDataset(DatasetWrapper): + def __init__(self, + ds, + num_worker, + hwm=50, + shuffle=False): + + super(ZMQMultiprocessDataset, self).__init__(ds) + self.num_worker = num_worker + self.shuffle = shuffle + self._hwm = hwm + + 
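+        # Two ipc pipes are used: a dedicated feeder process PUSHes datapoint indices
+        # through 'put_idx', the workers PULL them, and the workers PUSH serialized
+        # datapoints back through 'collect_data', which __iter__ PULLs from;
+        # set_hwm bounds how far the workers can run ahead of the consumer.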
self.idx_pipename = _get_pipe_name('put_idx') + self.data_pipename = _get_pipe_name('collect_data') + + self.put_idx_worker = None + for i in range(num_worker): + # first worker bind the socket, others connect to the socket + # however, zmq sockets using ipc do not care about the order of bind / connect + if i == 0: + worker = multiprocessing.Process(target=self._worker, + args=(True,)) + else: + worker = multiprocessing.Process(target=self._worker, + args=()) + worker.daemon = True + worker.start() + + def _worker(self, bind=False): + context = zmq.Context() + worker_receive_index_socket = context.socket(zmq.PULL) + worker_receive_index_socket.set_hwm(self._hwm) + if bind: + worker_receive_index_socket.bind(self.idx_pipename) + else: + worker_receive_index_socket.connect(self.idx_pipename) + + worker_send_data_socket = context.socket(zmq.PUSH) + worker_send_data_socket.set_hwm(self._hwm) + if bind: + worker_send_data_socket.bind(self.data_pipename) + else: + worker_send_data_socket.connect(self.data_pipename) + + while True: + recv_msg = worker_receive_index_socket.recv(copy=False) + idx = load_from_bytes(recv_msg) + send_msg = convert_to_bytes({'idx': idx, 'data': self.ds[idx]}) + worker_send_data_socket.send(send_msg, copy=False) + + def _put_idxs(self): + context = zmq.Context() + put_idx_socket = context.socket(zmq.PUSH) + put_idx_socket.set_hwm(self._hwm) + put_idx_socket.connect(self.idx_pipename) + for idx in self.idxs: + send_msg = convert_to_bytes(idx) + put_idx_socket.send(send_msg, copy=False) + + def __iter__(self): + context = zmq.Context() + collect_data_socket = context.socket(zmq.PULL) + collect_data_socket.set_hwm(self._hwm) + collect_data_socket.connect(self.data_pipename) + + # shutdown put_idx_worker and clear queues from previous epoch + _shutdown_proc(self.put_idx_worker) + try: + while True: + collect_data_socket.recv(flags=zmq.NOBLOCK) + except zmq.ZMQError: + pass + + # shuffle at the start of every epoch + if self.shuffle: + self.idxs = np.random.permutation(self.ds_len) + else: + self.idxs = np.arange(self.ds_len) + + self.put_idx_worker = multiprocessing.Process(target=self._put_idxs, + args=()) + self.put_idx_worker.daemon = True + self.put_idx_worker.start() + + data_buffer = {} + for return_idx in self.idxs: + if return_idx in data_buffer: + yield data_buffer.pop(return_idx) + else: + while True: + recv_msg = collect_data_socket.recv(copy=False) + recv_msg = load_from_bytes(recv_msg) + idx, dp = recv_msg['idx'], recv_msg['data'] + if idx == return_idx: + yield dp + break + else: + data_buffer[idx] = dp + _shutdown_proc(self.put_idx_worker) + + +def _get_pipe_name(name): + if sys.platform.startswith('linux'): + # linux supports abstract sockets: http://api.zeromq.org/4-1:zmq-ipc + pipename = "ipc://@{}-pipe-{}".format(name, str(uuid.uuid1())[:8]) + else: + pipedir = '.' + assert os.path.isdir(pipedir), pipedir + filename = '{}/{}-pipe-{}'.format(pipedir.rstrip('/'), name, str(uuid.uuid1())[:6]) + assert not os.path.exists(filename), "Pipe {} exists! You may be unlucky.".format(filename) + pipename = "ipc://{}".format(filename) + # register in environment variable, used for cleaning up ipc socket files + os.environ[name] = pipename + return pipename diff --git a/tensorlayer/dataflow/serialize.py b/tensorlayer/dataflow/serialize.py new file mode 100644 index 000000000..aa272f4f4 --- /dev/null +++ b/tensorlayer/dataflow/serialize.py @@ -0,0 +1,27 @@ +import msgpack_numpy + +MAX_MSGPACK_LEN = 1000000000 + + +def convert_to_bytes(obj): + """ + Serialize an object. 
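+    Serialization uses ``msgpack_numpy``, so datapoints containing numpy arrays
+    can be restored with ``load_from_bytes``.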
+ + Returns: + Implementation-dependent bytes-like object. + """ + return msgpack_numpy.dumps(obj, use_bin_type=True) + + +def load_from_bytes(buf): + """ + Args: + buf: the output of `dumps`. + """ + # Since 0.6, the default max size was set to 1MB. + # We change it to approximately 1G. + return msgpack_numpy.loads(buf, raw=False, + max_bin_len=MAX_MSGPACK_LEN, + max_array_len=MAX_MSGPACK_LEN, + max_map_len=MAX_MSGPACK_LEN, + max_str_len=MAX_MSGPACK_LEN) diff --git a/tensorlayer/dataflow/utils.py b/tensorlayer/dataflow/utils.py new file mode 100644 index 000000000..a90016aeb --- /dev/null +++ b/tensorlayer/dataflow/utils.py @@ -0,0 +1,199 @@ +import atexit +import logging +import math +import multiprocessing +import os +import weakref + +import psutil +import tarfile +import time +import zipfile +import progressbar +from urllib.request import urlretrieve + + +def exists_or_mkdir(path, verbose=True): + """ + Check a folder by given name, if not exist, create the folder and return False, + if directory exists, return True. + + Parameters + ---------- + path : str + A folder path. + verbose : boolean + If True (default), prints results. + + Returns + -------- + boolean + True if folder already exist, otherwise, returns False and create the folder. + + Examples + -------- + >>> exists_or_mkdir("checkpoints/train") + + """ + if not os.path.exists(path): + if verbose: + logging.info("[*] creates %s ..." % path) + os.makedirs(path) + return False + else: + if verbose: + logging.info("[!] %s exists ..." % path) + return True + + +def download(filename, working_directory, url_source): + """ + Download file from url_source to the working_directory with given filename. + + Parameters + ---------- + filename : str + The name of the downloaded file. + working_directory : str + A folder path download the file to + url_source : str + The URL to download the file from + + Examples + -------- + >>> download(filename='train.gz', + ... working_directory='data/', + ... url_source='http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz') + + """ + working_directory = os.path.expanduser(working_directory) + + progress_bar = progressbar.ProgressBar() + + def _dlProgress(count, blockSize, totalSize, pbar=progress_bar): + if (totalSize != 0): + + if not pbar.max_value: + totalBlocks = math.ceil(float(totalSize) / float(blockSize)) + pbar.max_value = int(totalBlocks) + + pbar.update(count, force=True) + + filepath = os.path.join(working_directory, filename) + + logging.info('Downloading %s...\n' % filename) + + urlretrieve(url_source, filepath, reporthook=_dlProgress) + + +def maybe_download_and_extract(filename, working_directory, url_source, extract=False, expected_bytes=None): + """ + Checks if file exists in working_directory otherwise tries to dowload the file, + and optionally also tries to extract the file if format is ".zip" or ".tar" + + Parameters + ----------- + filename : str + The name of the (to be) dowloaded file. + working_directory : str + A folder path to search for the file in and dowload the file to + url_source : str + The URL to download the file from + extract : boolean + If True, tries to uncompress the dowloaded file is ".tar.gz/.tar.bz2" or ".zip" file, default is False. + expected_bytes : int or None + If set tries to verify that the downloaded file is of the specified size, otherwise raises an Exception, defaults is None which corresponds to no check being performed. + + Returns + ---------- + str + File path of the dowloaded (uncompressed) file. 
+ + Examples + -------- + >>> down_file = maybe_download_and_extract(filename='train.gz', + ... working_directory='data/', + ... url_source='http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz') + >>> maybe_download_and_extract(filename='ADEChallengeData2016.zip', + ... working_directory='data/', + ... url_source='http://sceneparsing.csail.mit.edu/data/ADEChallengeData2016.zip', + ... extract=True) + + """ + working_directory = os.path.expanduser(working_directory) + exists_or_mkdir(working_directory, verbose=False) + filepath = os.path.join(working_directory, filename) + + if not os.path.exists(filepath): + download(filename, working_directory, url_source) + statinfo = os.stat(filepath) + logging.info('Succesfully downloaded %s %s bytes.' % (filename, statinfo.st_size)) # , 'bytes.') + if not (expected_bytes is None) and (expected_bytes != statinfo.st_size): + raise Exception('Failed to verify ' + filename + '. Can you get to it with a browser?') + if extract: + if tarfile.is_tarfile(filepath): + logging.info('Trying to extract tar file') + tarfile.open(filepath, 'r').extractall(working_directory) + logging.info('... Success!') + elif zipfile.is_zipfile(filepath): + logging.info('Trying to extract zip file') + with zipfile.ZipFile(filepath) as zf: + zf.extractall(working_directory) + logging.info('... Success!') + else: + logging.info("Unknown compression_format only .tar.gz/.tar.bz2/.tar and .zip supported") + return filepath + + +def get_dataloader_speed(dl, num_steps): + cnt = 0 + start = time.time() + end = start + for _ in dl: + cnt += 1 + if cnt == num_steps: + end = time.time() + break + return (end - start) / num_steps + + +def format_bytes(bytes): + if abs(bytes) < 1000: + return str(bytes) + "B" + elif abs(bytes) < 1e6: + return str(round(bytes / 1e3, 2)) + "kB" + elif abs(bytes) < 1e9: + return str(round(bytes / 1e6, 2)) + "MB" + else: + return str(round(bytes / 1e9, 2)) + "GB" + + +def get_process_memory(): + process = psutil.Process(os.getpid()) + mi = process.memory_info() + return mi.rss, mi.vms, mi.vms + + +def ensure_proc_terminate(proc): + """ + Make sure processes terminate when main process exit. 
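+    An ``atexit`` hook holding only a weak reference to the process is registered,
+    so the hook itself does not keep the child process alive.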
+ + Args: + proc (multiprocessing.Process or list) + """ + if isinstance(proc, list): + for p in proc: + ensure_proc_terminate(p) + return + + def stop_proc_by_weak_ref(ref): + proc = ref() + if proc is None: + return + if not proc.is_alive(): + return + proc.terminate() + proc.join() + + assert isinstance(proc, multiprocessing.Process) + atexit.register(stop_proc_by_weak_ref, weakref.ref(proc)) From 0833a23703fa6aee1700d6a1c024792930c00f51 Mon Sep 17 00:00:00 2001 From: 1FengL Date: Sat, 7 Sep 2019 12:48:57 +0100 Subject: [PATCH 2/6] add data flow module as tl.data --- .gitignore | 2 +- ..._transformer_network_dynamic_tlDataflow.py | 167 +++++++ tensorlayer/__init__.py | 2 + tensorlayer/data/__init__.py | 3 + tensorlayer/data/base.py | 65 +++ tensorlayer/data/common.py | 348 +++++++++++++++ tensorlayer/data/dataset/__init__.py | 15 + tensorlayer/data/dataset/celebA.py | 100 +++++ tensorlayer/data/dataset/cifar10.py | 190 ++++++++ tensorlayer/data/dataset/cyclegan.py | 133 ++++++ tensorlayer/data/dataset/flickr_1M.py | 128 ++++++ tensorlayer/data/dataset/flickr_25k.py | 81 ++++ tensorlayer/data/dataset/ilsvrc.py | 191 ++++++++ tensorlayer/data/dataset/imdb.py | 159 +++++++ tensorlayer/data/dataset/matt_mahoney.py | 76 ++++ tensorlayer/data/dataset/mnist.py | 115 +++++ tensorlayer/data/dataset/mnist_fashion.py | 108 +++++ tensorlayer/data/dataset/mpii.py | 297 +++++++++++++ tensorlayer/data/dataset/nietzsche.py | 70 +++ tensorlayer/data/dataset/ptb.py | 138 ++++++ tensorlayer/data/dataset/voc.py | 334 ++++++++++++++ tensorlayer/data/dataset/wmt_en_fr.py | 80 ++++ tensorlayer/data/parallel.py | 194 +++++++++ tensorlayer/data/serialize.py | 27 ++ tensorlayer/data/utils.py | 409 ++++++++++++++++++ 25 files changed, 3431 insertions(+), 1 deletion(-) create mode 100644 examples/spatial_transformer_network/tutorial_spatial_transformer_network_dynamic_tlDataflow.py create mode 100644 tensorlayer/data/__init__.py create mode 100644 tensorlayer/data/base.py create mode 100644 tensorlayer/data/common.py create mode 100644 tensorlayer/data/dataset/__init__.py create mode 100644 tensorlayer/data/dataset/celebA.py create mode 100644 tensorlayer/data/dataset/cifar10.py create mode 100644 tensorlayer/data/dataset/cyclegan.py create mode 100644 tensorlayer/data/dataset/flickr_1M.py create mode 100644 tensorlayer/data/dataset/flickr_25k.py create mode 100644 tensorlayer/data/dataset/ilsvrc.py create mode 100644 tensorlayer/data/dataset/imdb.py create mode 100644 tensorlayer/data/dataset/matt_mahoney.py create mode 100644 tensorlayer/data/dataset/mnist.py create mode 100644 tensorlayer/data/dataset/mnist_fashion.py create mode 100644 tensorlayer/data/dataset/mpii.py create mode 100644 tensorlayer/data/dataset/nietzsche.py create mode 100644 tensorlayer/data/dataset/ptb.py create mode 100644 tensorlayer/data/dataset/voc.py create mode 100644 tensorlayer/data/dataset/wmt_en_fr.py create mode 100644 tensorlayer/data/parallel.py create mode 100644 tensorlayer/data/serialize.py create mode 100644 tensorlayer/data/utils.py diff --git a/.gitignore b/.gitignore index 4a6f60ff3..22d68693c 100644 --- a/.gitignore +++ b/.gitignore @@ -119,7 +119,7 @@ venv_py2/ # TensorLayer Directories checkpoints -data/ +raw_data/ lib_win/ # Custom Scripts diff --git a/examples/spatial_transformer_network/tutorial_spatial_transformer_network_dynamic_tlDataflow.py b/examples/spatial_transformer_network/tutorial_spatial_transformer_network_dynamic_tlDataflow.py new file mode 100644 index 000000000..bc0bae141 --- /dev/null +++ 
b/examples/spatial_transformer_network/tutorial_spatial_transformer_network_dynamic_tlDataflow.py @@ -0,0 +1,167 @@ +#! /usr/bin/python +# -*- coding: utf8 -*- +import time + +import numpy as np + +import tensorflow as tf +import tensorlayer as tl +from tensorlayer.layers import * +from tensorlayer.models import Model + +##================== PREPARE DATA ============================================## +X_train, y_train, X_val, y_val, X_test, y_test = \ + tl.files.load_mnist_dataset(shape=(-1, 28, 28, 1)) + + +def pad_distort_im_fn(x): + """ Zero pads an image to 40x40, and distort it. + + Examples + --------- + x = pad_distort_im_fn(X_train[0]) + print(x, x.shape, x.max()) + tl.vis.save_image(x, '_xd.png') + tl.vis.save_image(X_train[0], '_x.png') + """ + b = np.zeros((40, 40, 1), dtype=np.float32) + o = int((40 - 28) / 2) + b[o:o + 28, o:o + 28] = x + x = b + x = tl.prepro.rotation(x, rg=30, is_random=True, fill_mode='constant') + x = tl.prepro.shear(x, 0.05, is_random=True, fill_mode='constant') + x = tl.prepro.shift(x, wrg=0.25, hrg=0.25, is_random=True, fill_mode='constant') + x = tl.prepro.zoom(x, zoom_range=(0.95, 1.05)) + return x + + +def pad_distort_ims_fn(X): + """ Zero pads images to 40x40, and distort them. """ + X_40 = [] + for X_a, _ in tl.iterate.minibatches(X, X, 50, shuffle=False): + X_40.extend(tl.prepro.threading_data(X_a, fn=pad_distort_im_fn)) + X_40 = np.asarray(X_40) + return X_40 + + +# create dataset with size of 40x40 with distortion +X_train_40 = pad_distort_ims_fn(X_train) +X_val_40 = pad_distort_ims_fn(X_val) +X_test_40 = pad_distort_ims_fn(X_test) + +tl.vis.save_images(X_test[0:32], [4, 8], '_imgs_original.png') +tl.vis.save_images(X_test_40[0:32], [4, 8], '_imgs_distorted.png') + + +##================== DEFINE MODEL ============================================## +class Net(Model): + + def __init__(self): + super(Net, self).__init__() + + ## 1. Localisation network + # use MLP as the localisation net + self.flatten1 = Flatten() + self.dense1 = Dense(n_units=20, in_channels=1600, act=tf.nn.tanh) + self.dropout1 = Dropout(keep=0.8) + # you can also use CNN instead for MLP as the localisation net + + ## 2. Spatial transformer module (sampler) + self.stn = SpatialTransformer2dAffine(out_size=(40, 40), in_channels=20) + + ## 3. 
Classifier + self.conv1 = Conv2d(16, (3, 3), (2, 2), act=tf.nn.relu, padding='SAME', in_channels=1) + self.conv2 = Conv2d(16, (3, 3), (2, 2), act=tf.nn.relu, padding='SAME', in_channels=16) + self.flatten2 = Flatten() + self.dense2 = Dense(n_units=1024, in_channels=1600, act=tf.nn.relu) + self.dense3 = Dense(n_units=10, in_channels=1024, act=tf.identity) + + def forward(self, inputs): + theta_input = self.dropout1(self.dense1(self.flatten1(inputs))) + V = self.stn((theta_input, inputs)) + _logits = self.dense3(self.dense2(self.flatten2(self.conv2(self.conv1(V))))) + return _logits, V + + +net = Net() + +##================== DEFINE TRAIN OPS ========================================## +n_epoch = 100 +learning_rate = 0.0001 +print_freq = 10 +batch_size = 64 +train_weights = net.trainable_weights +optimizer = tf.optimizers.Adam(lr=learning_rate) + +##================== TRAINING ================================================## +print("Training ...") +for epoch in range(n_epoch): + start_time = time.time() + + net.train() # enable dropout + + for X_train_a, y_train_a in tl.iterate.minibatches(X_train_40, y_train, batch_size, shuffle=True): + # input_dim must be of length 4 + X_train_a = tf.expand_dims(X_train_a, 3) + + with tf.GradientTape() as tape: + ## compute outputs + _logits, _ = net(X_train_a) # alternatively, you can use MLP(x, is_train=True) and remove MLP.train() + ## compute loss and update model + _loss = tl.cost.cross_entropy(_logits, y_train_a, name='train_loss') + + grad = tape.gradient(_loss, train_weights) + optimizer.apply_gradients(zip(grad, train_weights)) + + ## use training and evaluation sets to evaluate the model every print_freq epoch + if epoch + 1 == 1 or (epoch + 1) % print_freq == 0: + + net.eval() # disable dropout + + print("Epoch %d of %d took %fs" % (epoch + 1, n_epoch, time.time() - start_time)) + + train_loss, train_acc, n_iter = 0, 0, 0 + for X_train_a, y_train_a in tl.iterate.minibatches(X_train_40, y_train, batch_size, shuffle=False): + # input_dim must be of length 4 + X_train_a = tf.expand_dims(X_train_a, 3) + + _logits, _ = net(X_train_a) # alternatively, you can use MLP(x, is_train=False) and remove MLP.eval() + train_loss += tl.cost.cross_entropy(_logits, y_train_a, name='eval_train_loss') + train_acc += np.mean(np.equal(np.argmax(_logits, 1), y_train_a)) + n_iter += 1 + print(" train loss: %f" % (train_loss / n_iter)) + print(" train acc: %f" % (train_acc / n_iter)) + + val_loss, val_acc, n_iter = 0, 0, 0 + for X_val_a, y_val_a in tl.iterate.minibatches(X_val_40, y_val, batch_size, shuffle=False): + # input_dim must be of length 4 + X_val_a = tf.expand_dims(X_val_a, 3) + + _logits, _ = net(X_val_a) # is_train=False, disable dropout + val_loss += tl.cost.cross_entropy(_logits, y_val_a, name='eval_loss') + val_acc += np.mean(np.equal(np.argmax(_logits, 1), y_val_a)) + n_iter += 1 + print(" val loss: %f" % (val_loss / n_iter)) + print(" val acc: %f" % (val_acc / n_iter)) + + print('save images') + _, trans_imgs = net(tf.expand_dims(X_test_40[0:64], 3)) + trans_imgs = trans_imgs.numpy() + tl.vis.save_images(trans_imgs[0:32], [4, 8], '_imgs_distorted_after_stn_%s.png' % epoch) + +##================== EVALUATION ==============================================## +print('Evaluation') + +net.eval() + +test_loss, test_acc, n_iter = 0, 0, 0 +for X_test_a, y_test_a in tl.iterate.minibatches(X_test_40, y_test, batch_size, shuffle=False): + # input_dim must be of length 4 + X_test_a = tf.expand_dims(X_test_a, 3) + + _logits, _ = net(X_test_a) + test_loss += 
tl.cost.cross_entropy(_logits, y_test_a, name='test_loss') + test_acc += np.mean(np.equal(np.argmax(_logits, 1), y_test_a)) + n_iter += 1 +print(" test loss: %f" % (test_loss / n_iter)) +print(" test acc: %f" % (test_acc / n_iter)) diff --git a/tensorlayer/__init__.py b/tensorlayer/__init__.py index f89eebfff..8c7078733 100644 --- a/tensorlayer/__init__.py +++ b/tensorlayer/__init__.py @@ -44,6 +44,7 @@ from tensorlayer import optimizers from tensorlayer import rein from tensorlayer import utils + from tensorlayer import data from tensorlayer.lazy_imports import LazyImport @@ -54,6 +55,7 @@ prepro = LazyImport("tensorlayer.prepro") utils = LazyImport("tensorlayer.utils") visualize = LazyImport("tensorlayer.visualize") + data = LazyImport("tensorlayer.data") # alias act = activation diff --git a/tensorlayer/data/__init__.py b/tensorlayer/data/__init__.py new file mode 100644 index 000000000..cd53ae92b --- /dev/null +++ b/tensorlayer/data/__init__.py @@ -0,0 +1,3 @@ +from .base import Dataset +from .common import Dataloader +from .dataset import * diff --git a/tensorlayer/data/base.py b/tensorlayer/data/base.py new file mode 100644 index 000000000..6912f2cd3 --- /dev/null +++ b/tensorlayer/data/base.py @@ -0,0 +1,65 @@ +class Dataset(object): + + def __getitem__(self, index): + raise NotImplementedError("A Dataset must implement __getitem__(index) method.") + + def __len__(self): + raise NotImplementedError("A Dataset must implement __len__() method.") + + def __iter__(self): + for i in range(self.__len__()): + yield self.__getitem__(i) + + def __call__(self, *args, **kwargs): + return self.__iter__() + + +class DatasetWrapper(object): + def __init__(self, ds): + self.ds = ds + self.ds_len = len(ds) + + def __len__(self): + return len(self.ds) + + def __iter__(self): + for dp in self.ds: + yield dp + + def __call__(self, *args, **kwargs): + return self.__iter__() + + +class IndexableDatasetWrapper(object): + def __init__(self, ds): + self.ds = ds + self.ds_len = len(ds) + + def __getitem__(self, index): + return self.ds.__getitem__(index) + + def __len__(self): + return len(self.ds) + + def __call__(self, *args, **kwargs): + return self + + +class Transform(object): + def __call__(self, *args, **kwargs): + raise NotImplementedError("Transform must implement __call__() method.") + + +class _Transforms_for_tf_dataset(object): + """ + This class aggregate Transforms into one object in order to use tf.data.Dataset.map API + """ + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, *args): + data_list = list(args) + for transform in self.transforms: + data_list = transform(*data_list) + return data_list diff --git a/tensorlayer/data/common.py b/tensorlayer/data/common.py new file mode 100644 index 000000000..ed5a0f518 --- /dev/null +++ b/tensorlayer/data/common.py @@ -0,0 +1,348 @@ +import math +import multiprocessing +import os + +import tensorflow as tf +import zmq + +import numpy as np + +from .base import IndexableDatasetWrapper, DatasetWrapper, _Transforms_for_tf_dataset +from .parallel import _get_pipe_name, ZMQMultiprocessDataset, MultiprocessDataset +from .utils import clean_up_socket_files +from .serialize import convert_to_bytes, load_from_bytes + +__all__ = ['PrefetchBatchedDataset', 'TransformedDataset', 'ShuffledDataset', + 'AugmentedDataset'] + + +class BatchedDataset(DatasetWrapper): + def __init__(self, + ds, + batch_size, + drop_remainder=True, + return_numpy=True, + output_types=None): + super(BatchedDataset, self).__init__(ds) + 
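+        # This variant batches datapoints synchronously in the consuming process;
+        # PrefetchBatchedDataset below runs the same batching loop in a background
+        # worker and streams ready batches back over ZMQ or a multiprocessing.Pipe.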
self.batch_size = batch_size + self.drop_remainder = drop_remainder + self.return_numpy = return_numpy + self.output_types = output_types + + def __iter__(self): + dp_buffer = [] + for dp in self.ds: + dp_buffer.append(dp) + if len(dp_buffer) == self.batch_size: + yield self._batch_datapoints(dp_buffer, self.return_numpy, self.output_types) + del dp_buffer[:] + if not self.drop_remainder: + self._batch_datapoints(dp_buffer, self.return_numpy, self.output_types) + + def __len__(self): + ds_len = len(self.ds) + if self.drop_remainder: + return ds_len // self.batch_size + else: + return math.ceil(ds_len / self.batch_size) + + @staticmethod + def _batch_datapoints(dp_buffer, return_numpy, output_types): + """ + + :param dp_buffer: a list of datapoints + :return: + """ + first_dp = dp_buffer[0] + if isinstance(first_dp, (tuple, list)): + dp_batch = [None] * len(first_dp) + for i in range(len(first_dp)): + dp_element_batch = [] + for j in range(len(dp_buffer)): + dp_element_batch.append(dp_buffer[j][i]) + if return_numpy: + dp_batch[i] = BatchedDataset._batch_ndarray(dp_element_batch, + dtype=BatchedDataset._get_element_dtype(output_types, + i)) + else: + dp_batch[i] = dp_element_batch + return dp_batch + elif isinstance(first_dp, dict): + dp_batch = {} + for key in first_dp.keys(): + dp_element_batch = [] + for j in range(len(dp_buffer)): + dp_element_batch.append(dp_buffer[j][key]) + if return_numpy: + dp_batch[key] = BatchedDataset._batch_ndarray(dp_element_batch, dtype=None) + else: + dp_batch[key] = dp_element_batch + return dp_batch + elif isinstance(first_dp, np.ndarray): + return BatchedDataset._batch_ndarray(dp_buffer) + # single elements + else: + if return_numpy: + return BatchedDataset._batch_ndarray(dp_buffer, + dtype=BatchedDataset._get_element_dtype(output_types, 0)) + else: + return dp_buffer + + @staticmethod + def _batch_ndarray(dp_element_batch, dtype): + """ + + :param dp_element_batch: a list of datapoint element, an element can be np.ndarray / list + :return: np.ndarray, type is the same as input + """ + try: + if dtype is not None: + ret = np.asarray(dp_element_batch, dtype=dtype) + else: + ret = np.asarray(dp_element_batch) + return ret + except: + raise ValueError("Unsupported type for batching.") + + @staticmethod + def _get_element_dtype(output_types, i): + if output_types is None: + return None + if not isinstance(output_types, (tuple, list)): + return output_types + if len(output_types) == 1: + return output_types[0] + return output_types[i] + + +class PrefetchBatchedDataset(DatasetWrapper): + def __init__(self, + ds, + batch_size, + drop_remainder=True, + return_numpy=True, + output_types=None, + use_zmq=True): + super(PrefetchBatchedDataset, self).__init__(ds) + self.batch_size = batch_size + self.drop_remainder = drop_remainder + self.return_numpy = return_numpy + self.output_types = output_types + self.use_zmq = use_zmq + + if self.use_zmq: + self.data_pipename = _get_pipe_name('batch_prefetch') + context = zmq.Context() + self.fetch_data_socket = context.socket(zmq.PULL) + self.fetch_data_socket.set_hwm(1) + self.fetch_data_socket.bind(self.data_pipename) + self.worker = multiprocessing.Process(target=self._ZMQ_BatchedDataset_worker, + args=(self.ds,)) + self.worker.daemon = True + self.worker.start() + clean_up_socket_files(self.data_pipename) + else: + pipe_output, pipe_input = multiprocessing.Pipe() + self.worker = multiprocessing.Process(target=self._BatchedDataset_worker, + args=(self.ds, (pipe_output, pipe_input))) + self.worker.daemon = True + 
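+            # the daemonized worker iterates the wrapped dataset, batches datapoints
+            # and sends each batch back through the pipe; daemon=True ensures it is
+            # terminated together with the main process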
self.worker.start() + # main process only reads (gets output) + pipe_input.close() + self.pipe_output = pipe_output + + def _ZMQ_BatchedDataset_worker(self, ds): + context = zmq.Context() + prepare_data_socket = context.socket(zmq.PUSH) + prepare_data_socket.set_hwm(1) + prepare_data_socket.connect(self.data_pipename) + while True: + dp_buffer = [] + for dp in ds: + dp_buffer.append(dp) + if len(dp_buffer) == self.batch_size: + prepare_data_socket.send(convert_to_bytes( + BatchedDataset._batch_datapoints(dp_buffer, self.return_numpy, self.output_types)), copy=False) + del dp_buffer[:] + if not self.drop_remainder: + prepare_data_socket.send( + convert_to_bytes(BatchedDataset._batch_datapoints(dp_buffer, self.return_numpy, self.output_types)), + copy=False) + + def _BatchedDataset_worker(self, ds, pipe): + pipe_output, pipe_input = pipe + # worker process only writes (puts input) + pipe_output.close() + while True: + dp_buffer = [] + for dp in ds: + dp_buffer.append(dp) + if len(dp_buffer) == self.batch_size: + pipe_input.send(BatchedDataset._batch_datapoints(dp_buffer, self.return_numpy, self.output_types)) + del dp_buffer[:] + if not self.drop_remainder: + pipe_input.send(BatchedDataset._batch_datapoints(dp_buffer, self.return_numpy, self.output_types)) + + def __iter__(self): + for _ in range(self.__len__()): + # yield self.q.get() + if self.use_zmq: + yield load_from_bytes(self.fetch_data_socket.recv(copy=False)) + else: + yield self.pipe_output.recv() + + def __len__(self): + ds_len = len(self.ds) + if self.drop_remainder: + return ds_len // self.batch_size + else: + return math.ceil(ds_len / self.batch_size) + + +class ShuffledDataset(DatasetWrapper): + def __init__(self, ds): + super(ShuffledDataset, self).__init__(ds) + + def __iter__(self): + self.shuffled_idxs = np.random.permutation(len(self.ds)) + for index, data in enumerate(self.ds): + yield self.ds[self.shuffled_idxs[index]] + + +class TransformedDataset(IndexableDatasetWrapper): + """ + + """ + + def __init__(self, ds, transforms): + super(TransformedDataset, self).__init__(ds) + self.transforms = transforms + + def __getitem__(self, index): + dp = self.ds[index] + for transform in self.transforms: + assert callable(transform) + if isinstance(dp, (list, tuple)): + dp = transform(*dp) + else: + dp = transform(dp) + return dp + + +class AugmentedDataset(IndexableDatasetWrapper): + def __init__(self, ds, augmentations): + super(AugmentedDataset, self).__init__(ds) + self.augmentations = augmentations + self.num_augmentations = len(self.augmentations) + + def __getitem__(self, index): + if index >= self.__len__(): + raise IndexError + dp = self.ds[index % self.ds_len] + if index < self.ds_len: + return dp + augmentation = self.augmentations[(index // self.ds_len) - 1] + assert callable(augmentation) + if isinstance(dp, (list, tuple)): + return augmentation(*dp) + else: + return augmentation(dp) + + def __len__(self): + # every augmentation gives one more duplication of dataset + return self.ds_len * (1 + self.num_augmentations) + + +class Dataloader(DatasetWrapper): + def __init__(self, + ds, + augmentations=None, + shuffle=False, + batch_size=1, + drop_remainder=True, + output_types=None, + num_worker=os.cpu_count(), + use_zmq=True, + prefetch_batch=True, + num_prefetch=None, + transforms=None): + + super(Dataloader, self).__init__(ds) + self.augmentations = augmentations + self.shuffle = shuffle + self.batch_size = batch_size + self.drop_remainder = drop_remainder + self.output_types = output_types + self.num_worker = 
num_worker + self.use_zmq = use_zmq + self.prefetch_batch = prefetch_batch + self.num_prefetch = num_worker if num_prefetch is None else num_prefetch + self.transforms = transforms + + if self.augmentations is not None: + self.ds = AugmentedDataset(self.ds, self.augmentations) + + if self.transforms is not None: + self.ds = TransformedDataset(self.ds, self.transforms) + # self.tfds = self.tfds.map(map_func=_Transforms(self.transforms), num_parallel_calls=num_map_worker) + + # TODO: auto adjust num_prefetch + if self.num_worker > 1: + if self.use_zmq: + self.ds = ZMQMultiprocessDataset(self.ds, num_worker=self.num_worker, hwm=self.num_prefetch, + shuffle=self.shuffle) + else: + self.ds = MultiprocessDataset(self.ds, num_worker=self.num_worker, num_prefetch=self.num_prefetch, + shuffle=self.shuffle) + elif self.shuffle: + self.ds = ShuffledDataset(self.ds) + + if self.prefetch_batch: + self.ds = PrefetchBatchedDataset(self.ds, self.batch_size, drop_remainder=self.drop_remainder, + output_types=self.output_types, use_zmq=self.use_zmq) + else: + self.ds = BatchedDataset(self.ds, self.batch_size, drop_remainder=self.drop_remainder, + output_types=self.output_types) + + def __iter__(self): + for dp in self.ds: + yield dp + + +class TFDataloader(DatasetWrapper): + def __init__(self, + ds, + output_types, + augmentations=None, + shuffle=False, + shuffle_buffer_size=None, + batch_size=1, + drop_remainder=True, + num_worker=tf.data.experimental.AUTOTUNE, + transforms=None): + + super(TFDataloader, self).__init__(ds) + self.augmentations = augmentations + self.shuffle = shuffle + self.batch_size = batch_size + self.shuffle_buffer_size = 2 * batch_size if shuffle_buffer_size is None else shuffle_buffer_size + self.drop_remainder = drop_remainder + self.transforms = transforms + + self.ds = tf.data.Dataset.from_generator(self.ds, output_types=output_types) + + if self.shuffle: + self.ds = self.ds.shuffle(buffer_size=self.shuffle_buffer_size) + + if self.transforms is not None: + self.ds = self.ds.map(map_func=_Transforms_for_tf_dataset(self.transforms), + num_parallel_calls=num_worker) + + if self.batch_size > 1: + self.ds = self.ds.batch(batch_size=self.batch_size, drop_remainder=self.drop_remainder) + + self.ds = self.ds.prefetch(tf.data.experimental.AUTOTUNE) + + def __iter__(self): + for dp in self.ds: + yield dp diff --git a/tensorlayer/data/dataset/__init__.py b/tensorlayer/data/dataset/__init__.py new file mode 100644 index 000000000..9b2c3166d --- /dev/null +++ b/tensorlayer/data/dataset/__init__.py @@ -0,0 +1,15 @@ +from .celebA import * +from .cifar10 import * +from .cyclegan import * +from .flickr_25k import * +from .flickr_1M import * +from .ilsvrc import * +from .imdb import * +from .matt_mahoney import * +from .mnist import * +from .mnist_fashion import * +from .mpii import * +from .nietzsche import * +from .ptb import * +from .voc import * +from .wmt_en_fr import * diff --git a/tensorlayer/data/dataset/celebA.py b/tensorlayer/data/dataset/celebA.py new file mode 100644 index 000000000..cb8c6a392 --- /dev/null +++ b/tensorlayer/data/dataset/celebA.py @@ -0,0 +1,100 @@ +#! /usr/bin/python +# -*- coding: utf-8 -*- + +import os +import zipfile + +import logging + +import cv2 +import numpy as np + +from ..base import Dataset +from ..utils import download_file_from_google_drive, exists_or_mkdir, load_file_list + +__all__ = ['load_celebA_dataset', 'CelebAFiles', 'CelebA'] + + +def load_celebA_dataset(name='celebA', path='raw_data'): + """ + Load CelebA dataset + + Return a list of image path. 
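
    A typical call (a sketch; the zip archive is fetched from Google Drive on the first run):

    >>> files = load_celebA_dataset(path='raw_data')
    >>> print(len(files), files[0])  # number of images and the first .jpg path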
+ + Parameters + ----------- + name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is ``raw_data/celebA/``. + """ + data_dir = name + filename, drive_id = "img_align_celeba.zip", "0B7EVK8r0v71pZjFTYXZWM3FlRnM" + save_path = os.path.join(path, filename) + image_path = os.path.join(path, data_dir) + if os.path.exists(image_path): + logging.info('[*] {} already exists'.format(save_path)) + else: + exists_or_mkdir(path) + download_file_from_google_drive(drive_id, save_path) + zip_dir = '' + with zipfile.ZipFile(save_path) as zf: + zip_dir = zf.namelist()[0] + zf.extractall(path) + os.remove(save_path) + os.rename(os.path.join(path, zip_dir), image_path) + + data_files = load_file_list(path=image_path, regx='\\.jpg', printable=False) + for i, _v in enumerate(data_files): + data_files[i] = os.path.join(image_path, data_files[i]) + return data_files + + +class CelebAFiles(Dataset): + """ + Load CelebA dataset. Produce filenames of images. + + Parameters + ----------- + name : str + The name of the dataset + path : str + The path that the data is downloaded to, defaults is ``raw_data/celebA/``. + """ + def __init__(self, name='celebA', path='raw_data'): + self.data_files = load_celebA_dataset(name=name, path=path) + + def __getitem__(self, index): + return self.data_files[index] + + def __len__(self): + return len(self.data_files) + + +class CelebA(CelebAFiles): + """ + Load CelebA dataset. Produce nparrays of images. + + Parameters + ----------- + name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is ``raw_data/celebA/``. + shape : tuple + The shape of digit images. + """ + def __init__(self, shape=None, name='celebA', path='raw_data'): + super(CelebA, self).__init__(name=name, path=path) + self.shape = shape + + def __getitem__(self, index): + file_path = self.data_files[index] + img = cv2.imread(file_path) + if self.shape: + img = cv2.resize(img, self.shape) + img = np.array(img, dtype=np.float32) + return img + + def __len__(self): + return len(self.data_files) diff --git a/tensorlayer/data/dataset/cifar10.py b/tensorlayer/data/dataset/cifar10.py new file mode 100644 index 000000000..dda876efb --- /dev/null +++ b/tensorlayer/data/dataset/cifar10.py @@ -0,0 +1,190 @@ +import logging +import os +import pickle +import sys +import numpy as np + +from ..base import Dataset +from ..utils import maybe_download_and_extract + +__all__ = ['load_cifar10_dataset', 'CIFAR10'] + +CIFAR10_BASE_URL = 'https://www.cs.toronto.edu/~kriz/' +CIFAR10_FILENAME = 'cifar-10-python.tar.gz' + + +# Helper function to unpickle the data +def unpickle(file): + fp = open(file, 'rb') + if sys.version_info.major == 2: + data = pickle.load(fp) + elif sys.version_info.major == 3: + data = pickle.load(fp, encoding='latin-1') + else: + raise RuntimeError("Sys Version Unsupported") + fp.close() + return data + + +def load_cifar10_dataset(shape=(-1, 32, 32, 3), path='raw_data', name='cifar10', plotable=False): + """ + Load CIFAR-10 dataset. + + It consists of 60000 32x32 colour images in 10 classes, with + 6000 images per class. There are 50000 training images and 10000 test images. + + The dataset is divided into five training batches and one test batch, each with + 10000 images. The test batch contains exactly 1000 randomly-selected images from + each class. The training batches contain the remaining images in random order, + but some training batches may contain more images from one class than another. 
+ Between them, the training batches contain exactly 5000 images from each class. + + Parameters + ---------- + shape : tuple + The shape of digit images e.g. (-1, 3, 32, 32) and (-1, 32, 32, 3). + name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is ``raw_data/cifar10/``. + plotable : boolean + Whether to plot some image examples, False as default. + + Examples + -------- + >>> X_train, y_train, X_test, y_test = load_cifar10_dataset(shape=(-1, 32, 32, 3)) + + References + ---------- + - `CIFAR website `__ + - `Data download link `__ + - ``__ + + """ + path = os.path.join(path, name) + logging.info("Load or Download cifar10 > {}".format(path)) + + # Download and uncompress file + maybe_download_and_extract(CIFAR10_FILENAME, path, CIFAR10_BASE_URL, extract=True) + + # Unpickle file and fill in data + X_train = None + y_train = [] + for i in range(1, 6): + data_dic = unpickle(os.path.join(path, 'cifar-10-batches-py/', "data_batch_{}".format(i))) + if i == 1: + X_train = data_dic['data'] + else: + X_train = np.vstack((X_train, data_dic['data'])) + y_train += data_dic['labels'] + + test_data_dic = unpickle(os.path.join(path, 'cifar-10-batches-py/', "test_batch")) + X_test = test_data_dic['data'] + y_test = np.array(test_data_dic['labels']) + + if shape == (-1, 3, 32, 32): + X_test = X_test.reshape(shape) + X_train = X_train.reshape(shape) + elif shape == (-1, 32, 32, 3): + X_test = X_test.reshape(shape, order='F') + X_train = X_train.reshape(shape, order='F') + X_test = np.transpose(X_test, (0, 2, 1, 3)) + X_train = np.transpose(X_train, (0, 2, 1, 3)) + else: + X_test = X_test.reshape(shape) + X_train = X_train.reshape(shape) + + y_train = np.array(y_train) + + if plotable: + logging.info('\nCIFAR-10') + import matplotlib.pyplot as plt + fig = plt.figure(1) + + logging.info('Shape of a training image: X_train[0] %s' % X_train[0].shape) + + plt.ion() # interactive mode + count = 1 + for _ in range(10): # each row + for _ in range(10): # each column + _ = fig.add_subplot(10, 10, count) + if shape == (-1, 3, 32, 32): + # plt.imshow(X_train[count-1], interpolation='nearest') + plt.imshow(np.transpose(X_train[count - 1], (1, 2, 0)), interpolation='nearest') + # plt.imshow(np.transpose(X_train[count-1], (2, 1, 0)), interpolation='nearest') + elif shape == (-1, 32, 32, 3): + plt.imshow(X_train[count - 1], interpolation='nearest') + # plt.imshow(np.transpose(X_train[count-1], (1, 0, 2)), interpolation='nearest') + else: + raise Exception("Do not support the given 'shape' to plot the image examples") + plt.gca().xaxis.set_major_locator(plt.NullLocator()) + plt.gca().yaxis.set_major_locator(plt.NullLocator()) + count = count + 1 + plt.draw() # interactive mode + plt.pause(3) # interactive mode + + logging.info("X_train: %s" % X_train.shape) + logging.info("y_train: %s" % y_train.shape) + logging.info("X_test: %s" % X_test.shape) + logging.info("y_test: %s" % y_test.shape) + + X_train = np.asarray(X_train, dtype=np.float32) + X_test = np.asarray(X_test, dtype=np.float32) + y_train = np.asarray(y_train, dtype=np.int32) + y_test = np.asarray(y_test, dtype=np.int32) + + return X_train, y_train, X_test, y_test + + +class CIFAR10(Dataset): + """ + Load CIFAR-10 dataset. + + It consists of 60000 32x32 colour images in 10 classes, with + 6000 images per class. There are 50000 training images and 10000 test images. + + Parameters + ---------- + train_or_test : str + Must be either 'train' or 'test'. Choose the training or test dataset. 
+ name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is ``raw_data/cifar10/``. + shape : tuple + The shape of digit images e.g. (-1, 3, 32, 32) and (-1, 32, 32, 3). + """ + def __init__(self, train_or_test, path='raw_data', name='cifar10', shape=(-1, 32, 32, 3)): + self.path = os.path.join(path, name) + + # Download and read the training and test set images and labels. + logging.info("Load or Download {0} > {1}".format(name.upper(), self.path)) + + maybe_download_and_extract(CIFAR10_FILENAME, path, CIFAR10_BASE_URL, extract=True) + + assert train_or_test in ['train', 'test'] + if train_or_test == 'train': + # Unpickle file and fill in data + self.images = None + self.labels = [] + for i in range(1, 6): + data_dic = unpickle(os.path.join(path, 'cifar-10-batches-py/', "data_batch_{}".format(i))) + if i == 1: + self.images = data_dic['data'] + else: + self.images = np.vstack((self.images, data_dic['data'])) + self.labels += data_dic['labels'] + else: + test_data_dic = unpickle(os.path.join(path, 'cifar-10-batches-py/', "test_batch")) + self.images = test_data_dic['data'] + self.labels = np.array(test_data_dic['labels']) + + self.images = np.reshape(self.images, shape) + + def __getitem__(self, index): + img = np.array(self.images[index], dtype=np.float32) + label = np.array(self.labels[index], dtype=np.int32) + return img, label + + def __len__(self): + return self.images.shape[0] diff --git a/tensorlayer/data/dataset/cyclegan.py b/tensorlayer/data/dataset/cyclegan.py new file mode 100644 index 000000000..f67da5f7b --- /dev/null +++ b/tensorlayer/data/dataset/cyclegan.py @@ -0,0 +1,133 @@ +#! /usr/bin/python +# -*- coding: utf-8 -*- + +import os + +import cv2 +import numpy as np + +from tensorlayer import logging, visualize + +from ..base import Dataset +from ..utils import maybe_download_and_extract, folder_exists, del_file, load_file_list + +__all__ = ['load_cyclegan_dataset', 'CycleGAN', 'CycleGANFiles'] + +CYCLEGAN_BASE_URL = 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/' + + +def load_cyclegan_dataset(name='cyclegan', path='raw_data', filename='summer2winter_yosemite'): + """ + Load images from CycleGAN's database, see `this link `. + + Parameters + ------------ + name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is `raw_data/cyclegan`. + filename : str + The dataset you want, see `this link `. 
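
    This function reads every image into memory; the ``CycleGANFiles`` and
    ``CycleGAN`` datasets defined below load one pair per index instead, e.g.
    (a sketch, assuming the archive has already been downloaded):

    >>> ds = CycleGANFiles('train', filename='summer2winter_yosemite')
    >>> fileA, fileB = ds[0]  # one file name from trainA and one from trainB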
+ + Examples + --------- + >>> im_train_A, im_train_B, im_test_A, im_test_B = load_cyclegan_dataset(filename='summer2winter_yosemite') + + """ + path = os.path.join(path, name) + + if folder_exists(os.path.join(path, filename)) is False: + logging.info("[*] {} is nonexistent in {}".format(filename, path)) + filepath = maybe_download_and_extract(filename=filename + '.zip', working_directory=path, + url_source=CYCLEGAN_BASE_URL, extract=True) + del_file(filepath) + + def load_image_from_folder(path): + path_imgs = load_file_list(path=path, regx='\\.jpg', printable=False) + return visualize.read_images(path_imgs, path=path, n_threads=10, printable=False) + + im_train_A = load_image_from_folder(os.path.join(path, filename, "trainA")) + im_train_B = load_image_from_folder(os.path.join(path, filename, "trainB")) + im_test_A = load_image_from_folder(os.path.join(path, filename, "testA")) + im_test_B = load_image_from_folder(os.path.join(path, filename, "testB")) + + def if_2d_to_3d(images): # [h, w] --> [h, w, 3] + for i, _v in enumerate(images): + if len(images[i].shape) == 2: + images[i] = images[i][:, :, np.newaxis] + images[i] = np.tile(images[i], (1, 1, 3)) + return images + + im_train_A = if_2d_to_3d(im_train_A) + im_train_B = if_2d_to_3d(im_train_B) + im_test_A = if_2d_to_3d(im_test_A) + im_test_B = if_2d_to_3d(im_test_B) + + return im_train_A, im_train_B, im_test_A, im_test_B + + +class CycleGANFiles(Dataset): + """ + Load image file names from CycleGAN's database, see `this link `. + + Parameters + ------------ + train_or_test : str + Must be either 'train' or 'test'. Choose the training or test dataset. + name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is `raw_data/cyclegan` + filename : str + The dataset you want, see `this link `. + """ + def __init__(self, train_or_test, name='cyclegan', path='raw_data', filename='summer2winter_yosemite'): + self.path = os.path.join(path, name) + self.train_or_test = train_or_test + + if folder_exists(os.path.join(path, filename)) is False: + logging.info("[*] {} is nonexistent in {}".format(filename, path)) + filepath = maybe_download_and_extract(filename=filename + '.zip', working_directory=path, + url_source=CYCLEGAN_BASE_URL, extract=True) + del_file(filepath) + + assert self.train_or_test in ['train', 'test'] + if self.train_or_test == 'train': + self.im_A_path = load_file_list(path=os.path.join(path, filename, "trainA"), regx='\\.jpg', printable=False) + self.im_B_path = load_file_list(path=os.path.join(path, filename, "trainB"), regx='\\.jpg', printable=False) + else: + self.im_A_path = load_file_list(path=os.path.join(path, filename, "testA"), regx='\\.jpg', printable=False) + self.im_B_path = load_file_list(path=os.path.join(path, filename, "testB"), regx='\\.jpg', printable=False) + + def __getitem__(self, index): + return self.im_A_path[index], self.im_B_path[index] + + def __len__(self): + assert len(self.im_A_path) == len(self.im_B_path) + return len(self.im_A_path) + + +class CycleGAN(CycleGANFiles): + """ + Load images from CycleGAN's database, see `this link `. + + Parameters + ------------ + train_or_test : str + Must be either 'train' or 'test'. Choose the training or test dataset. + name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is `raw_data/cyclegan` + filename : str + The dataset you want, see `this link `. 
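
    Examples
    ---------
    A sketch of the intended usage (assuming the chosen archive has already been downloaded):

    >>> ds = CycleGAN('train', filename='summer2winter_yosemite')
    >>> imA, imB = ds[0]  # one float32 image array from each domain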
+ """ + def __init__(self, train_or_test, name='cyclegan', path='raw_data', filename='summer2winter_yosemite'): + super(CycleGAN, self).__init__(train_or_test, name, path, filename) + + def __getitem__(self, index): + imA = cv2.imread(self.im_A_path) + imB = cv2.imread(self.im_B_path) + imA = np.array(imA, dtype=np.float32) + imB = np.array(imB, dtype=np.float32) + return imA, imB diff --git a/tensorlayer/data/dataset/flickr_1M.py b/tensorlayer/data/dataset/flickr_1M.py new file mode 100644 index 000000000..0b9387656 --- /dev/null +++ b/tensorlayer/data/dataset/flickr_1M.py @@ -0,0 +1,128 @@ +#! /usr/bin/python +# -*- coding: utf-8 -*- + +import os + +from tensorlayer import logging, visualize + +from ..utils import del_file, folder_exists, load_file_list, load_folder_list, maybe_download_and_extract, read_file + +__all__ = ['load_flickr1M_dataset'] + +IMAGES_ZIP = [ + 'images0.zip', 'images1.zip', 'images2.zip', 'images3.zip', 'images4.zip', 'images5.zip', 'images6.zip', + 'images7.zip', 'images8.zip', 'images9.zip' +] +TAG_ZIP = 'tags.zip' +FLICKR1M_BASE_URL = 'http://press.liacs.nl/mirflickr/mirflickr1m/' + + +def load_flickr1M_dataset(tag='sky', size=10, path="data", n_threads=50, printable=False): + """Load Flick1M dataset. + + Returns a list of images by a given tag from Flickr1M dataset, + it will download Flickr1M from `the official website `__ + at the first time you use it. + + Parameters + ------------ + tag : str or None + What images to return. + - If you want to get images with tag, use string like 'dog', 'red', see `Flickr Search `__. + - If you want to get all images, set to ``None``. + + size : int + integer between 1 to 10. 1 means 100k images ... 5 means 500k images, 10 means all 1 million images. Default is 10. + path : str + The path that the data is downloaded to, defaults is ``data/flickr25k/``. + n_threads : int + The number of thread to read image. + printable : boolean + Whether to print infomation when reading images, default is ``False``. + + Examples + ---------- + Use 200k images + + >>> images = tl.files.load_flickr1M_dataset(tag='zebra', size=2) + + Use 1 Million images + + >>> images = tl.files.load_flickr1M_dataset(tag='zebra') + + """ + import shutil + + path = os.path.join(path, 'flickr1M') + logging.info("[Flickr1M] using {}% of images = {}".format(size * 10, size * 100000)) + + # download dataset + for image_zip in IMAGES_ZIP[0:size]: + image_folder = image_zip.split(".")[0] + # logging.info(path+"/"+image_folder) + if folder_exists(os.path.join(path, image_folder)) is False: + # logging.info(image_zip) + logging.info("[Flickr1M] {} is missing in {}".format(image_folder, path)) + maybe_download_and_extract(image_zip, path, FLICKR1M_BASE_URL+image_zip, extract=True) + del_file(os.path.join(path, image_zip)) + # os.system("mv {} {}".format(os.path.join(path, 'images'), os.path.join(path, image_folder))) + shutil.move(os.path.join(path, 'images'), os.path.join(path, image_folder)) + else: + logging.info("[Flickr1M] {} exists in {}".format(image_folder, path)) + + # download tag + if folder_exists(os.path.join(path, "tags")) is False: + logging.info("[Flickr1M] tag files is nonexistent in {}".format(path)) + maybe_download_and_extract(TAG_ZIP, path, FLICKR1M_BASE_URL+TAG_ZIP, extract=True) + del_file(os.path.join(path, TAG_ZIP)) + else: + logging.info("[Flickr1M] tags exists in {}".format(path)) + + # 1. 
image path list + images_list = [] + images_folder_list = [] + for i in range(0, size): + images_folder_list += load_folder_list(path=os.path.join(path, 'images%d' % i)) + images_folder_list.sort(key=lambda s: int(s.split('/')[-1])) # folder/images/ddd + + for folder in images_folder_list[0:size * 10]: + tmp = load_file_list(path=folder, regx='\\.jpg', printable=False) + tmp.sort(key=lambda s: int(s.split('.')[-2])) # ddd.jpg + images_list.extend([os.path.join(folder, x) for x in tmp]) + + # 2. tag path list + tag_list = [] + tag_folder_list = load_folder_list(os.path.join(path, "tags")) + + # tag_folder_list.sort(key=lambda s: int(s.split("/")[-1])) # folder/images/ddd + tag_folder_list.sort(key=lambda s: int(os.path.basename(s))) + + for folder in tag_folder_list[0:size * 10]: + tmp = load_file_list(path=folder, regx='\\.txt', printable=False) + tmp.sort(key=lambda s: int(s.split('.')[-2])) # ddd.txt + tmp = [os.path.join(folder, s) for s in tmp] + tag_list += tmp + + # 3. select images + logging.info("[Flickr1M] searching tag: {}".format(tag)) + select_images_list = [] + for idx, _val in enumerate(tag_list): + tags = read_file(tag_list[idx]).split('\n') + if tag in tags: + select_images_list.append(images_list[idx]) + + logging.info("[Flickr1M] reading images with tag: {}".format(tag)) + images = visualize.read_images(select_images_list, '', n_threads=n_threads, printable=printable) + return images + + +# class Flickr1M(Dataset): +# +# def __init__(self): +# pass +# +# def __getitem__(self, index): +# pass +# +# def __len__(self): +# pass diff --git a/tensorlayer/data/dataset/flickr_25k.py b/tensorlayer/data/dataset/flickr_25k.py new file mode 100644 index 000000000..8049a0653 --- /dev/null +++ b/tensorlayer/data/dataset/flickr_25k.py @@ -0,0 +1,81 @@ +#! /usr/bin/python +# -*- coding: utf-8 -*- + +import os + +from tensorlayer import logging, visualize +from tensorlayer.files.utils import ( + del_file, folder_exists, load_file_list, maybe_download_and_extract, natural_keys, read_file +) + +__all__ = ['load_flickr25k_dataset'] + + +def load_flickr25k_dataset(tag='sky', path="data", n_threads=50, printable=False): + """Load Flickr25K dataset. + + Returns a list of images by a given tag from Flick25k dataset, + it will download Flickr25k from `the official website `__ + at the first time you use it. + + Parameters + ------------ + tag : str or None + What images to return. + - If you want to get images with tag, use string like 'dog', 'red', see `Flickr Search `__. + - If you want to get all images, set to ``None``. + + path : str + The path that the data is downloaded to, defaults is ``data/flickr25k/``. + n_threads : int + The number of thread to read image. + printable : boolean + Whether to print infomation when reading images, default is ``False``. + + Examples + ----------- + Get images with tag of sky + + >>> images = tl.files.load_flickr25k_dataset(tag='sky') + + Get all images + + >>> images = tl.files.load_flickr25k_dataset(tag=None, n_threads=100, printable=True) + + """ + path = os.path.join(path, 'flickr25k') + + filename = 'mirflickr25k.zip' + url = 'http://press.liacs.nl/mirflickr/mirflickr25k/' + + # download dataset + if folder_exists(os.path.join(path, "mirflickr")) is False: + logging.info("[*] Flickr25k is nonexistent in {}".format(path)) + maybe_download_and_extract(filename, path, url, extract=True) + del_file(os.path.join(path, filename)) + + # return images by the given tag. + # 1. 
image path list + folder_imgs = os.path.join(path, "mirflickr") + path_imgs = load_file_list(path=folder_imgs, regx='\\.jpg', printable=False) + path_imgs.sort(key=natural_keys) + + # 2. tag path list + folder_tags = os.path.join(path, "mirflickr", "meta", "tags") + path_tags = load_file_list(path=folder_tags, regx='\\.txt', printable=False) + path_tags.sort(key=natural_keys) + + # 3. select images + if tag is None: + logging.info("[Flickr25k] reading all images") + else: + logging.info("[Flickr25k] reading images with tag: {}".format(tag)) + images_list = [] + for idx, _v in enumerate(path_tags): + tags = read_file(os.path.join(folder_tags, path_tags[idx])).split('\n') + # logging.info(idx+1, tags) + if tag is None or tag in tags: + images_list.append(path_imgs[idx]) + + images = visualize.read_images(images_list, folder_imgs, n_threads=n_threads, printable=printable) + return images diff --git a/tensorlayer/data/dataset/ilsvrc.py b/tensorlayer/data/dataset/ilsvrc.py new file mode 100644 index 000000000..5cd1b0ec3 --- /dev/null +++ b/tensorlayer/data/dataset/ilsvrc.py @@ -0,0 +1,191 @@ +import numpy as np +import os +import logging +import cv2 + +from ..base import Dataset +from ..utils import maybe_download_and_extract + +__all__ = ['ILSVRCMeta', 'ILSVRC12', 'ILSVRC12Files'] + +CAFFE_ILSVRC12_META_BASE_URL = 'http://dl.caffe.berkeleyvision.org/' +CAFFE_ILSVRC12_META_FILENAME = 'caffe_ilsvrc12.tar.gz' + + +class ILSVRCMeta(object): + """ + Provide methods to access metadata for ILSVRC dataset. + Metadata is supposed to be found at/will be downloaded to 'path/name/' + + Parameters + ---------- + name : str + The name of the dataset + path : str + The path that the data is downloaded to, defaults is `raw_data/ilsvrc` + + Examples + -------- + >>> meta = ILSVRCMeta(path='raw_data', name='ilsvrc') + >>> imglist = meta.get_image_list(train_or_val_or_test, dir_structure) + + """ + + def __init__(self, name='ilsvrc', path='raw_data'): + path = os.path.expanduser(path) + self.path = os.path.join(path, name) + logging.info("Load or Download {0} > {1}".format(name.upper(), self.path)) + self.filepath = maybe_download_and_extract(CAFFE_ILSVRC12_META_FILENAME, self.path, CAFFE_ILSVRC12_META_BASE_URL, extract=True) + self.caffepb = None + + def get_synset_words_1000(self): + """ + Returns: + dict: {cls_number: cls_name} + """ + fname = os.path.join(self.path, 'synset_words.txt') + assert os.path.isfile(fname), fname + lines = [x.strip() for x in open(fname).readlines()] + return dict(enumerate(lines)) + + def get_synset_1000(self): + """ + Returns: + dict: {cls_number: synset_id} + """ + fname = os.path.join(self.path, 'synsets.txt') + assert os.path.isfile(fname) + lines = [x.strip() for x in open(fname).readlines()] + return dict(enumerate(lines)) + + def get_image_list(self, name): + """ + Args: + name (str): 'train' or 'val' or 'test' + Returns: + list: list of (image filename, label) + """ + assert name in ['train', 'val', 'test'] + + fname = os.path.join(self.path, name + '.txt') + assert os.path.isfile(fname), fname + with open(fname) as f: + ret = [] + for line in f.readlines(): + name, cls = line.strip().split() + cls = int(cls) + ret.append((name.strip(), cls)) + assert len(ret), fname + return ret + + +class ILSVRC12Files(Dataset): + """ + Load ILSVRC12 dataset. Produce filenames of images and their corresponding labels. + Labels are between [0, 999]. + + Parameters + ----------- + train_or_test_or_val : str + Must be either 'train' or 'test' or 'val'. 
Choose the training or test or validation dataset. + meta_dir : str + The path that the metadata is located. Will automatically download and extract if it is not found. + path : str + The path of the ILSVRC12 dataset. + + + The dataset should have the structure: + --------------------------------------- + path/ + train/ + n02134418/ + n02134418_198.JPEG + ... + ... + val/ + ILSVRC2012_val_00000001.JPEG + ... + test/ + ILSVRC2012_test_00000001.JPEG + ... + --------------------------------------- + With the downloaded ILSVRC12_img_*.tar, you can use the following + command to build the above structure: + + mkdir val && tar xvf ILSVRC12_img_val.tar -C val + mkdir test && tar xvf ILSVRC12_img_test.tar -C test + mkdir train && tar xvf ILSVRC12_img_train.tar -C train && cd train + find -type f -name '*.tar' | parallel -P 10 'echo {} && mkdir -p {/.} && tar xf {} -C {/.}' + """ + def __init__(self, train_or_test_or_val, meta_dir, path): + """ + Same as in :class:`ILSVRC12`. + """ + assert train_or_test_or_val in ['train', 'test', 'val'] + path = os.path.expanduser(path) + assert os.path.isdir(path) + self.full_path = os.path.join(path, train_or_test_or_val) + self.path = train_or_test_or_val + + meta = ILSVRCMeta(path=meta_dir) + self.imglist = meta.get_image_list(train_or_test_or_val) + + def __len__(self): + return len(self.imglist) + + def __getitem__(self, index): + fname, label = self.imglist[index] + fname = os.path.join(self.full_path, fname) + return fname, label + + +class ILSVRC12(ILSVRC12Files): + """ + Load ILSVRC12 dataset. Produce images and a label between [0, 999]. + + Parameters + ----------- + train_or_test_or_val : str + Must be either 'train' or 'test' or 'val'. Choose the training or test or validation dataset. + meta_dir : str + The path that the metadata is located. Will automatically download and extract if it is not found. + path : str + The path of the ILSVRC12 dataset. + shape : tuple + When shape is None, return the original image. If set, return the resized image. + + + The dataset should have the structure: + --------------------------------------- + path/ + train/ + n02134418/ + n02134418_198.JPEG + ... + ... + val/ + ILSVRC2012_val_00000001.JPEG + ... + test/ + ILSVRC2012_test_00000001.JPEG + ... + --------------------------------------- + With the downloaded ILSVRC12_img_*.tar, you can use the following + command to build the above structure: + + mkdir val && tar xvf ILSVRC12_img_val.tar -C val + mkdir test && tar xvf ILSVRC12_img_test.tar -C test + mkdir train && tar xvf ILSVRC12_img_train.tar -C train && cd train + find -type f -name '*.tar' | parallel -P 10 'echo {} && mkdir -p {/.} && tar xf {} -C {/.}' + """ + def __init__(self, train_or_test_or_val, meta_dir, path, shape=None): + super(ILSVRC12, self).__init__(train_or_test_or_val, meta_dir, path) + self.shape = shape + + def __getitem__(self, index): + fname, label = super(ILSVRC12, self).__getitem__(index) + img = cv2.imread(fname, cv2.IMREAD_COLOR) + if self.shape is not None: + img = cv2.resize(img, self.shape) + img = np.array(img, dtype=np.float32) + return img, label diff --git a/tensorlayer/data/dataset/imdb.py b/tensorlayer/data/dataset/imdb.py new file mode 100644 index 000000000..aa1fe31f1 --- /dev/null +++ b/tensorlayer/data/dataset/imdb.py @@ -0,0 +1,159 @@ +#! 
/usr/bin/python +# -*- coding: utf-8 -*- + +import os + +import numpy as np +import six.moves.cPickle as pickle + +from ..base import Dataset +from ..utils import maybe_download_and_extract + +__all__ = ['load_imdb_dataset', 'IMDB'] + +IMDB_BASE_URL = 'https://s3.amazonaws.com/text-datasets/' +IMDB_FILENAME = 'imdb.pkl' + + +def load_imdb_dataset( + name='imdb', path='raw_data', nb_words=None, skip_top=0, maxlen=None, test_split=0.2, seed=113, start_char=1, + oov_char=2, index_from=3): + """ + Load IMDB dataset. + + Parameters + ---------- + name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is ``raw_data/imdb/``. + nb_words : int + Number of words to get. + skip_top : int + Top most frequent words to ignore (they will appear as oov_char value in the sequence data). + maxlen : int + Maximum sequence length. Any longer sequence will be truncated. + test_split : float + Split of train / test dataset. + seed : int + Seed for reproducible data shuffling. + start_char : int + The start of a sequence will be marked with this character. Set to 1 because 0 is usually the padding character. + oov_char : int + Words that were cut out because of the num_words or skip_top limit will be replaced with this character. + index_from : int + Index actual words with this index and higher. + + Examples + -------- + >>> X_train, y_train, X_test, y_test = load_imdb_dataset(nb_words=20000, test_split=0.2) + >>> print('X_train.shape', X_train.shape) + (20000,) [[1, 62, 74, ... 1033, 507, 27],[1, 60, 33, ... 13, 1053, 7]..] + >>> print('y_train.shape', y_train.shape) + (20000,) [1 0 0 ..., 1 0 1] + + References + ----------- + - `Modified from keras. `__ + + """ + X, labels = _load_raw_imdb(path, name) + + np.random.seed(seed) + np.random.shuffle(X) + np.random.seed(seed) + np.random.shuffle(labels) + + X, labels = _preprocess_imdb(X, index_from, labels, maxlen, nb_words, oov_char, skip_top, start_char) + + X_train = np.array(X[:int(len(X) * (1 - test_split))]) + y_train = np.array(labels[:int(len(X) * (1 - test_split))]) + + X_test = np.array(X[int(len(X) * (1 - test_split)):]) + y_test = np.array(labels[int(len(X) * (1 - test_split)):]) + + return X_train, y_train, X_test, y_test + + +def _preprocess_imdb(X, index_from, labels, maxlen, nb_words, oov_char, skip_top, start_char): + if start_char is not None: + X = [[start_char] + [w + index_from for w in x] for x in X] + elif index_from: + X = [[w + index_from for w in x] for x in X] + if maxlen: + new_X = [] + new_labels = [] + for x, y in zip(X, labels): + if len(x) < maxlen: + new_X.append(x) + new_labels.append(y) + X = new_X + labels = new_labels + if not X: + raise Exception( + 'After filtering for sequences shorter than maxlen=' + str(maxlen) + ', no sequence was kept. ' + 'Increase maxlen.' 
+ ) + if not nb_words: + nb_words = max([max(x) for x in X]) + # by convention, use 2 as OOV word + # reserve 'index_from' (=3 by default) characters: 0 (padding), 1 (start), 2 (OOV) + if oov_char is not None: + X = [[oov_char if (w >= nb_words or w < skip_top) else w for w in x] for x in X] + else: + nX = [] + for x in X: + nx = [] + for w in x: + if (w >= nb_words or w < skip_top): + nx.append(w) + nX.append(nx) + X = nX + return X, labels + + +def _load_raw_imdb(path, name): + path = os.path.join(path, name) + maybe_download_and_extract(IMDB_FILENAME, path, IMDB_BASE_URL) + f = open(os.path.join(path, IMDB_FILENAME), 'rb') + X, labels = pickle.load(f) + f.close() + return X, labels + + +class IMDB(Dataset): + """ + Load IMDB dataset. + + Parameters + ---------- + name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is ``raw_data/imdb/``. + nb_words : int + Number of words to get. + skip_top : int + Top most frequent words to ignore (they will appear as oov_char value in the sequence data). + maxlen : int + Maximum sequence length. Any longer sequence will be truncated. + start_char : int + The start of a sequence will be marked with this character. Set to 1 because 0 is usually the padding character. + oov_char : int + Words that were cut out because of the num_words or skip_top limit will be replaced with this character. + index_from : int + Index actual words with this index and higher. + """ + + def __init__(self, name='imdb', path='raw_data', nb_words=None, skip_top=0, maxlen=None, start_char=1, oov_char=2, + index_from=3): + self.X, self.labels = _load_raw_imdb(path=path, name=name) + self.X, self.labels = _preprocess_imdb(self.X, index_from, self.labels, maxlen, nb_words, oov_char, skip_top, + start_char) + + def __getitem__(self, index): + return self.X[index], self.labels[index] + + def __len__(self): + assert len(self.X) == len(self.labels) + return len(self.labels) diff --git a/tensorlayer/data/dataset/matt_mahoney.py b/tensorlayer/data/dataset/matt_mahoney.py new file mode 100644 index 000000000..7f03b9ef3 --- /dev/null +++ b/tensorlayer/data/dataset/matt_mahoney.py @@ -0,0 +1,76 @@ +#! /usr/bin/python +# -*- coding: utf-8 -*- + +import os +import zipfile + +from tensorlayer import logging + +from ..base import Dataset +from ..utils import maybe_download_and_extract + +__all__ = ['load_matt_mahoney_text8_dataset', 'MattMahoney'] + +MATT_MAHONEY_BASE_URL = 'http://mattmahoney.net/dc/' +MATT_MAHONEY_FILENAME = 'text8.zip' + + +def load_matt_mahoney_text8_dataset(name='mm_test8', path='raw_data'): + """ + Load Matt Mahoney's dataset. + + Download a text file from Matt Mahoney's website + if not present, and make sure it's the right size. + Extract the first file enclosed in a zip file as a list of words. + This dataset can be used for Word Embedding. + + Parameters + ---------- + name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is ``raw_data/mm_test8/``. + + Returns + -------- + list of str + The raw text data e.g. [.... 'their', 'families', 'who', 'were', 'expelled', 'from', 'jerusalem', ...] 
+ + Examples + -------- + >>> words = load_matt_mahoney_text8_dataset() + >>> print('Data size', len(words)) + + """ + path = os.path.join(path, name) + logging.info("Load or Download matt_mahoney_text8 Dataset> {}".format(path)) + + maybe_download_and_extract(MATT_MAHONEY_FILENAME, path, MATT_MAHONEY_BASE_URL, expected_bytes=31344016) + + with zipfile.ZipFile(os.path.join(path, MATT_MAHONEY_FILENAME)) as f: + word_list = f.read(f.namelist()[0]).split() + for idx, _ in enumerate(word_list): + word_list[idx] = word_list[idx].decode() + return word_list + + +class MattMahoney(Dataset): + """ + Load Matt Mahoney's dataset. + + Parameters + ---------- + name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is ``raw_data/mm_test8/``. + """ + + def __init__(self, name='mm_test8', path='raw_data'): + self.word_list = load_matt_mahoney_text8_dataset(path=path, name=name) + + def __getitem__(self, index): + return self.word_list[index] + + def __len__(self): + return len(self.word_list) diff --git a/tensorlayer/data/dataset/mnist.py b/tensorlayer/data/dataset/mnist.py new file mode 100644 index 000000000..02df262ff --- /dev/null +++ b/tensorlayer/data/dataset/mnist.py @@ -0,0 +1,115 @@ +import gzip +import logging +import os +import numpy as np + +from ..base import Dataset +from ..utils import maybe_download_and_extract + +__all__ = ['MNIST', 'load_mnist_dataset'] + +MNIST_BASE_URL = 'http://yann.lecun.com/exdb/mnist/' +MNIST_TRAIN_IMAGE_FILENAME = 'train-images-idx3-ubyte.gz' +MNIST_TRAIN_LABEL_FILENAME = 'train-labels-idx1-ubyte.gz' +MNIST_TEST_IMAGE_FILENAME = 't10k-images-idx3-ubyte.gz' +MNIST_TEST_LABEL_FILENAME = 't10k-labels-idx1-ubyte.gz' + + +def _load_mnist_images(name, url, path, shape): + filepath = maybe_download_and_extract(name, path, url) + + logging.info(filepath) + # Read the inputs in Yann LeCun's binary format. + with gzip.open(filepath, 'rb') as f: + data = np.frombuffer(f.read(), np.uint8, offset=16) + # The inputs are vectors now, we reshape them to monochrome 2D images, + # following the shape convention: (examples, channels, rows, columns) + data = data.reshape(shape) + # The inputs come as bytes, we convert them to float32 in range [0,1]. + # (Actually to range [0, 255/256], for compatibility to the version + # provided at http://deeplearning.net/data/mnist/mnist.pkl.gz.) + return data / np.float32(256) + + +def _load_mnist_labels(name, url, path): + filepath = maybe_download_and_extract(name, path, url) + + # Read the labels in Yann LeCun's binary format. + with gzip.open(filepath, 'rb') as f: + data = np.frombuffer(f.read(), np.uint8, offset=8) + # The labels are vectors of integers now, that's exactly what we want. + return data + + +def load_mnist_dataset(shape=(-1, 784), name='mnist', path='raw_data'): + """ + A generic function to load mnist-like dataset. + + Parameters: + ---------- + shape : tuple + The shape of digit images. + name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is ``raw_data/mnist/``. + """ + path = os.path.join(path, name) + + # Download and read the training and test set images and labels. 
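    # For reference: with the default shape=(-1, 784) the splits returned below are
    # X_train (50000, 784), X_val (10000, 784), X_test (10000, 784), all float32 scaled
    # into [0, 1), with int32 label vectors of matching length (the last 10000 training
    # examples are held out as the validation set).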
+ logging.info("Load or Download {0} > {1}".format(name.upper(), path)) + X_train = _load_mnist_images(name=MNIST_TRAIN_IMAGE_FILENAME, url=MNIST_BASE_URL, path=path, shape=shape) + y_train = _load_mnist_labels(name=MNIST_TRAIN_LABEL_FILENAME, url=MNIST_BASE_URL, path=path) + X_test = _load_mnist_images(name=MNIST_TEST_IMAGE_FILENAME, url=MNIST_BASE_URL, path=path, shape=shape) + y_test = _load_mnist_labels(name=MNIST_TEST_LABEL_FILENAME, url=MNIST_BASE_URL, path=path) + + # We reserve the last 10000 training examples for validation. + X_train, X_val = X_train[:-10000], X_train[-10000:] + y_train, y_val = y_train[:-10000], y_train[-10000:] + + # We just return all the arrays in order, as expected in main(). + # (It doesn't matter how we do this as long as we can read them again.) + X_train = np.asarray(X_train, dtype=np.float32) + y_train = np.asarray(y_train, dtype=np.int32) + X_val = np.asarray(X_val, dtype=np.float32) + y_val = np.asarray(y_val, dtype=np.int32) + X_test = np.asarray(X_test, dtype=np.float32) + y_test = np.asarray(y_test, dtype=np.int32) + return X_train, y_train, X_val, y_val, X_test, y_test + + +class MNIST(Dataset): + """ + Load MNIST dataset. + + Parameters: + ---------- + train_or_test : str + Must be either 'train' or 'test'. Choose the training or test dataset. + shape : tuple + The shape of digit images. + name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is ``raw_data/mnist/``. + """ + def __init__(self, train_or_test, path='raw_data', name='mnist', shape=(-1, 784)): + path = os.path.expanduser(path) + self.path = os.path.join(path, name) + + assert train_or_test in ['train', 'test'] + self.train_or_test = train_or_test + if train_or_test == 'train': + self.images = _load_mnist_images(name=MNIST_TRAIN_IMAGE_FILENAME, url=MNIST_BASE_URL, path=path, + shape=shape) + self.labels = _load_mnist_labels(name=MNIST_TRAIN_LABEL_FILENAME, url=MNIST_BASE_URL, path=path) + else: + self.images = _load_mnist_images(name=MNIST_TEST_IMAGE_FILENAME, url=MNIST_BASE_URL, path=path, + shape=shape) + self.labels = _load_mnist_labels(name=MNIST_TEST_LABEL_FILENAME, url=MNIST_BASE_URL, path=path) + + def __len__(self): + return self.images.shape[0] + + def __getitem__(self, index): + return self.images[index], self.labels[index] diff --git a/tensorlayer/data/dataset/mnist_fashion.py b/tensorlayer/data/dataset/mnist_fashion.py new file mode 100644 index 000000000..d0025563e --- /dev/null +++ b/tensorlayer/data/dataset/mnist_fashion.py @@ -0,0 +1,108 @@ +import logging +import os +import numpy as np + +from ..base import Dataset +from .mnist import _load_mnist_images, _load_mnist_labels + +__all__ = ['load_fashion_mnist_dataset', 'FASHION_MNIST'] + +FASHION_MNIST_BASE_URL = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/' +FASHION_MNIST_TRAIN_IMAGE_FILENAME = 'train-images-idx3-ubyte.gz' +FASHION_MNIST_TRAIN_LABEL_FILENAME = 'train-labels-idx1-ubyte.gz' +FASHION_MNIST_TEST_IMAGE_FILENAME = 't10k-images-idx3-ubyte.gz' +FASHION_MNIST_TEST_LABEL_FILENAME = 't10k-labels-idx1-ubyte.gz' + + +def load_fashion_mnist_dataset(shape=(-1, 784), name='fashion_mnist', path='raw_data'): + """ + Load the fashion mnist. + + Automatically download fashion-MNIST dataset and return the training, validation and test set with 50000, 10000 and 10000 fashion images respectively, `examples `__. + + Parameters + ---------- + shape : tuple + The shape of digit images (the default is (-1, 784), alternatively (-1, 28, 28, 1)). 
+ name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is ``raw_data/fashion_mnist/``. + + Returns + ------- + X_train, y_train, X_val, y_val, X_test, y_test: tuple + Return splitted training/validation/test set respectively. + + Examples + -------- + >>> X_train, y_train, X_val, y_val, X_test, y_test = load_fashion_mnist_dataset(shape=(-1,784), path='datasets') + >>> X_train, y_train, X_val, y_val, X_test, y_test = load_fashion_mnist_dataset(shape=(-1, 28, 28, 1)) + """ + path = os.path.join(path, name) + + # Download and read the training and test set images and labels. + logging.info("Load or Download {0} > {1}".format(name.upper(), path)) + X_train = _load_mnist_images(name=FASHION_MNIST_TRAIN_IMAGE_FILENAME, url=FASHION_MNIST_BASE_URL, path=path, + shape=shape) + y_train = _load_mnist_labels(name=FASHION_MNIST_TRAIN_LABEL_FILENAME, url=FASHION_MNIST_BASE_URL, path=path) + X_test = _load_mnist_images(name=FASHION_MNIST_TEST_IMAGE_FILENAME, url=FASHION_MNIST_BASE_URL, path=path, + shape=shape) + y_test = _load_mnist_labels(name=FASHION_MNIST_TEST_LABEL_FILENAME, url=FASHION_MNIST_BASE_URL, path=path) + + # We reserve the last 10000 training examples for validation. + X_train, X_val = X_train[:-10000], X_train[-10000:] + y_train, y_val = y_train[:-10000], y_train[-10000:] + + # We just return all the arrays in order, as expected in main(). + # (It doesn't matter how we do this as long as we can read them again.) + X_train = np.asarray(X_train, dtype=np.float32) + y_train = np.asarray(y_train, dtype=np.int32) + X_val = np.asarray(X_val, dtype=np.float32) + y_val = np.asarray(y_val, dtype=np.int32) + X_test = np.asarray(X_test, dtype=np.float32) + y_test = np.asarray(y_test, dtype=np.int32) + return X_train, y_train, X_val, y_val, X_test, y_test + + +class FASHION_MNIST(Dataset): + """ + Load the fashion mnist. + + Automatically download fashion-MNIST dataset and return the training, validation and test set with 50000, 10000 and 10000 fashion images respectively, `examples `__. + + Parameters + ---------- + train_or_test : str + Must be either 'train' or 'test'. Choose the training or test dataset. + shape : tuple + The shape of digit images (the default is (-1, 784), alternatively (-1, 28, 28, 1)). + name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is ``raw_data/fashion_mnist/``. + """ + + def __init__(self, train_or_test, name='fashion_mnist', path='raw_data', shape=(-1, 784)): + path = os.path.expanduser(path) + self.path = os.path.join(path, name) + + assert train_or_test in ['train', 'test'] + if train_or_test == 'train': + self.images = _load_mnist_images(name=FASHION_MNIST_TRAIN_IMAGE_FILENAME, url=FASHION_MNIST_BASE_URL, + path=path, + shape=shape) + self.labels = _load_mnist_labels(name=FASHION_MNIST_TRAIN_LABEL_FILENAME, url=FASHION_MNIST_BASE_URL, + path=path) + else: + self.images = _load_mnist_images(name=FASHION_MNIST_TEST_IMAGE_FILENAME, url=FASHION_MNIST_BASE_URL, + path=path, + shape=shape) + self.labels = _load_mnist_labels(name=FASHION_MNIST_TEST_LABEL_FILENAME, url=FASHION_MNIST_BASE_URL, + path=path) + + def __len__(self): + return self.images.shape[0] + + def __getitem__(self, index): + return self.images[index], self.labels[index] diff --git a/tensorlayer/data/dataset/mpii.py b/tensorlayer/data/dataset/mpii.py new file mode 100644 index 000000000..85faec09d --- /dev/null +++ b/tensorlayer/data/dataset/mpii.py @@ -0,0 +1,297 @@ +#! 
/usr/bin/python +# -*- coding: utf-8 -*- + +import os + +from tensorlayer import logging + +from ..base import Dataset +from ..utils import del_file, folder_exists, load_file_list, maybe_download_and_extract + +__all__ = ['load_mpii_pose_dataset', 'MPII'] + +MPII_BASE_URL = "http://datasets.d2.mpi-inf.mpg.de/andriluka14cvpr/" + + +def load_mpii_pose_dataset(name='mpii_human_pose', path='raw_data', is_16_pos_only=False): + """ + Load MPII Human Pose Dataset. + + Parameters + ----------- + name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is `raw_data/mpii_human_pose`. + is_16_pos_only : boolean + If True, only return the peoples contain 16 pose keypoints. (Usually be used for single person pose estimation) + + Returns + ---------- + img_train_list : list of str + The image directories of training data. + ann_train_list : list of dict + The annotations of training data. + img_test_list : list of str + The image directories of testing data. + ann_test_list : list of dict + The annotations of testing data. + + Examples + -------- + >>> import pprint + >>> import tensorlayer as tl + >>> img_train_list, ann_train_list, img_test_list, ann_test_list = load_mpii_pose_dataset() + >>> image = tl.vis.read_image(img_train_list[0]) + >>> tl.vis.draw_mpii_pose_to_image(image, ann_train_list[0], 'image.png') + >>> pprint.pprint(ann_train_list[0]) + + References + ----------- + - `MPII Human Pose Dataset. CVPR 14 `__ + - `MPII Human Pose Models. CVPR 16 `__ + - `MPII Human Shape, Poselet Conditioned Pictorial Structures and etc `__ + - `MPII Keyponts and ID `__ + """ + path = os.path.join(path, name) + logging.info("Load or Download MPII Human Pose > {}".format(path)) + + # annotation + tar_filename = "mpii_human_pose_v1_u12_2.zip" + extracted_filename = "mpii_human_pose_v1_u12_2" + if folder_exists(os.path.join(path, extracted_filename)) is False: + logging.info("[MPII] (annotation) {} is nonexistent in {}".format(extracted_filename, path)) + maybe_download_and_extract(tar_filename, path, MPII_BASE_URL+tar_filename, extract=True) + del_file(os.path.join(path, tar_filename)) + + # images + tar_filename = "mpii_human_pose_v1.tar.gz" + extracted_filename2 = "images" + if folder_exists(os.path.join(path, extracted_filename2)) is False: + logging.info("[MPII] (images) {} is nonexistent in {}".format(extracted_filename, path)) + maybe_download_and_extract(tar_filename, path, MPII_BASE_URL+tar_filename, extract=True) + del_file(os.path.join(path, tar_filename)) + + # parse annotation, format see http://human-pose.mpi-inf.mpg.de/#download + import scipy.io as sio + logging.info("reading annotations from mat file ...") + # mat = sio.loadmat(os.path.join(path, extracted_filename, "mpii_human_pose_v1_u12_1.mat")) + + # def fix_wrong_joints(joint): # https://github.com/mitmul/deeppose/blob/master/datasets/mpii_dataset.py + # if '12' in joint and '13' in joint and '2' in joint and '3' in joint: + # if ((joint['12'][0] < joint['13'][0]) and + # (joint['3'][0] < joint['2'][0])): + # joint['2'], joint['3'] = joint['3'], joint['2'] + # if ((joint['12'][0] > joint['13'][0]) and + # (joint['3'][0] > joint['2'][0])): + # joint['2'], joint['3'] = joint['3'], joint['2'] + # return joint + + ann_train_list = [] + ann_test_list = [] + img_train_list = [] + img_test_list = [] + + def save_joints(): + # joint_data_fn = os.path.join(path, 'data.json') + # fp = open(joint_data_fn, 'w') + mat = sio.loadmat(os.path.join(path, extracted_filename, 
"mpii_human_pose_v1_u12_1.mat")) + + for _, (anno, train_flag) in enumerate( # all images + zip(mat['RELEASE']['annolist'][0, 0][0], mat['RELEASE']['img_train'][0, 0][0])): + + img_fn = anno['image']['name'][0, 0][0] + train_flag = int(train_flag) + + # print(i, img_fn, train_flag) # DEBUG print all images + + if train_flag: + img_train_list.append(img_fn) + ann_train_list.append([]) + else: + img_test_list.append(img_fn) + ann_test_list.append([]) + + head_rect = [] + if 'x1' in str(anno['annorect'].dtype): + head_rect = zip( + [x1[0, 0] for x1 in anno['annorect']['x1'][0]], [y1[0, 0] for y1 in anno['annorect']['y1'][0]], + [x2[0, 0] for x2 in anno['annorect']['x2'][0]], [y2[0, 0] for y2 in anno['annorect']['y2'][0]] + ) + else: + head_rect = [] # TODO + + if 'annopoints' in str(anno['annorect'].dtype): + annopoints = anno['annorect']['annopoints'][0] + head_x1s = anno['annorect']['x1'][0] + head_y1s = anno['annorect']['y1'][0] + head_x2s = anno['annorect']['x2'][0] + head_y2s = anno['annorect']['y2'][0] + + for annopoint, head_x1, head_y1, head_x2, head_y2 in zip(annopoints, head_x1s, head_y1s, head_x2s, + head_y2s): + # if annopoint != []: + # if len(annopoint) != 0: + if annopoint.size: + head_rect = [ + float(head_x1[0, 0]), + float(head_y1[0, 0]), + float(head_x2[0, 0]), + float(head_y2[0, 0]) + ] + + # joint coordinates + annopoint = annopoint['point'][0, 0] + j_id = [str(j_i[0, 0]) for j_i in annopoint['id'][0]] + x = [x[0, 0] for x in annopoint['x'][0]] + y = [y[0, 0] for y in annopoint['y'][0]] + joint_pos = {} + for _j_id, (_x, _y) in zip(j_id, zip(x, y)): + joint_pos[int(_j_id)] = [float(_x), float(_y)] + # joint_pos = fix_wrong_joints(joint_pos) + + # visibility list + if 'is_visible' in str(annopoint.dtype): + vis = [v[0] if v.size > 0 else [0] for v in annopoint['is_visible'][0]] + vis = dict([(k, int(v[0])) if len(v) > 0 else v for k, v in zip(j_id, vis)]) + else: + vis = None + + # if len(joint_pos) == 16: + if ((is_16_pos_only ==True) and (len(joint_pos) == 16)) or (is_16_pos_only == False): + # only use image with 16 key points / or use all + data = { + 'filename': img_fn, + 'train': train_flag, + 'head_rect': head_rect, + 'is_visible': vis, + 'joint_pos': joint_pos + } + # print(json.dumps(data), file=fp) # py3 + if train_flag: + ann_train_list[-1].append(data) + else: + ann_test_list[-1].append(data) + + # def write_line(datum, fp): + # joints = sorted([[int(k), v] for k, v in datum['joint_pos'].items()]) + # joints = np.array([j for i, j in joints]).flatten() + # + # out = [datum['filename']] + # out.extend(joints) + # out = [str(o) for o in out] + # out = ','.join(out) + # + # print(out, file=fp) + + # def split_train_test(): + # # fp_test = open('data/mpii/test_joints.csv', 'w') + # fp_test = open(os.path.join(path, 'test_joints.csv'), 'w') + # # fp_train = open('data/mpii/train_joints.csv', 'w') + # fp_train = open(os.path.join(path, 'train_joints.csv'), 'w') + # # all_data = open('data/mpii/data.json').readlines() + # all_data = open(os.path.join(path, 'data.json')).readlines() + # N = len(all_data) + # N_test = int(N * 0.1) + # N_train = N - N_test + # + # print('N:{}'.format(N)) + # print('N_train:{}'.format(N_train)) + # print('N_test:{}'.format(N_test)) + # + # np.random.seed(1701) + # perm = np.random.permutation(N) + # test_indices = perm[:N_test] + # train_indices = perm[N_test:] + # + # print('train_indices:{}'.format(len(train_indices))) + # print('test_indices:{}'.format(len(test_indices))) + # + # for i in train_indices: + # datum = 
json.loads(all_data[i].strip()) + # write_line(datum, fp_train) + # + # for i in test_indices: + # datum = json.loads(all_data[i].strip()) + # write_line(datum, fp_test) + + save_joints() + # split_train_test() # + + ## read images dir + logging.info("reading images list ...") + img_dir = os.path.join(path, extracted_filename2) + _img_list = load_file_list(path=os.path.join(path, extracted_filename2), regx='\\.jpg', printable=False) + # ann_list = json.load(open(os.path.join(path, 'data.json'))) + for i, im in enumerate(img_train_list): + if im not in _img_list: + print('missing training image {} in {} (remove from img(ann)_train_list)'.format(im, img_dir)) + # img_train_list.remove(im) + del img_train_list[i] + del ann_train_list[i] + for i, im in enumerate(img_test_list): + if im not in _img_list: + print('missing testing image {} in {} (remove from img(ann)_test_list)'.format(im, img_dir)) + # img_test_list.remove(im) + del img_train_list[i] + del ann_train_list[i] + + ## check annotation and images + n_train_images = len(img_train_list) + n_test_images = len(img_test_list) + n_images = n_train_images + n_test_images + logging.info("n_images: {} n_train_images: {} n_test_images: {}".format(n_images, n_train_images, n_test_images)) + n_train_ann = len(ann_train_list) + n_test_ann = len(ann_test_list) + n_ann = n_train_ann + n_test_ann + logging.info("n_ann: {} n_train_ann: {} n_test_ann: {}".format(n_ann, n_train_ann, n_test_ann)) + n_train_people = len(sum(ann_train_list, [])) + n_test_people = len(sum(ann_test_list, [])) + n_people = n_train_people + n_test_people + logging.info("n_people: {} n_train_people: {} n_test_people: {}".format(n_people, n_train_people, n_test_people)) + # add path to all image file name + for i, value in enumerate(img_train_list): + img_train_list[i] = os.path.join(img_dir, value) + for i, value in enumerate(img_test_list): + img_test_list[i] = os.path.join(img_dir, value) + return img_train_list, ann_train_list, img_test_list, ann_test_list + + +class MPII(Dataset): + """ + Load MPII Human Pose Dataset. + + Parameters + ----------- + train_or_test : str + Must be either 'train' or 'test'. Choose the training or test dataset. + name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is `raw_data/mpii_human_pose`. + is_16_pos_only : boolean + If True, only return the peoples contain 16 pose keypoints. (Usually be used for single person pose estimation) + """ + + def __init__(self, train_or_test, path='raw_data', name='mpii_human_pose', is_16_pos_only=False): + self.path = os.path.join(path, name) + img_train_list, ann_train_list, img_test_list, ann_test_list = load_mpii_pose_dataset(name=name, path=path, is_16_pos_only=is_16_pos_only) + assert train_or_test in ['train', 'test'] + self.train_or_test = train_or_test + if train_or_test == 'train': + self.img_list = img_train_list + self.ann_list = ann_train_list + del img_test_list + del ann_test_list + else: + self.img_list = img_test_list + self.ann_list = ann_test_list + del img_train_list + del ann_train_list + + def __getitem__(self, index): + return self.img_list[index], self.ann_list[index] + + def __len__(self): + assert len(self.img_list) == len(self.ann_list) + return len(self.img_list) diff --git a/tensorlayer/data/dataset/nietzsche.py b/tensorlayer/data/dataset/nietzsche.py new file mode 100644 index 000000000..5e00e32e4 --- /dev/null +++ b/tensorlayer/data/dataset/nietzsche.py @@ -0,0 +1,70 @@ +#! 
/usr/bin/python +# -*- coding: utf-8 -*- + +import os + +from tensorlayer import logging + +from ..base import Dataset +from ..utils import maybe_download_and_extract + +__all__ = ['load_nietzsche_dataset', 'NIETZSCHE'] + +NIETZSCHE_BASE_URL = 'https://s3.amazonaws.com/text-datasets/' +NIETZSCHE_FILENAME = 'nietzsche.txt' + + +def load_nietzsche_dataset(name='nietzsche', path='raw_data'): + """ + Load Nietzsche dataset. + + Parameters + ---------- + name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is ``raw_data/nietzsche/``. + + Returns + -------- + str + The content. + + Examples + -------- + >>> see tutorial_generate_text.py + >>> words = tl.files.load_nietzsche_dataset() + >>> words = basic_clean_str(words) + >>> words = words.split() + + """ + logging.info("Load or Download nietzsche dataset > {}".format(path)) + path = os.path.join(path, name) + + filepath = maybe_download_and_extract(NIETZSCHE_FILENAME, path, NIETZSCHE_BASE_URL) + + with open(filepath, "r") as f: + words = f.read() + return words + + +class NIETZSCHE(Dataset): + """ + Load Nietzsche dataset. + + Parameters + ---------- + name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is ``raw_data/nietzsche/``. + """ + + def __init__(self, name='nietzsche', path='raw_data'): + self.words = load_nietzsche_dataset(name=name, path=path) + + def __getitem__(self, index): + return self.words[index] + + def __len__(self): + return len(self.words) diff --git a/tensorlayer/data/dataset/ptb.py b/tensorlayer/data/dataset/ptb.py new file mode 100644 index 000000000..96183c35d --- /dev/null +++ b/tensorlayer/data/dataset/ptb.py @@ -0,0 +1,138 @@ +#! /usr/bin/python +# -*- coding: utf-8 -*- + +import os + +from tensorlayer import logging, nlp +import numpy as np + +from ..base import Dataset +from ..utils import maybe_download_and_extract + +__all__ = ['load_ptb_dataset', 'PTB'] + +PTB_URL = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/' +PTB_FILENAME = 'simple-examples.tgz' + + +def load_ptb_dataset(name='ptb', path='raw_data'): + """ + Load Penn TreeBank (PTB) dataset. + + It is used in many LANGUAGE MODELING papers, + including "Empirical Evaluation and Combination of Advanced Language + Modeling Techniques", "Recurrent Neural Network Regularization". + It consists of 929k training words, 73k validation words, and 82k test + words. It has 10k words in its vocabulary. + + Parameters + ---------- + name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is ``raw_data/ptb/``. + + Returns + -------- + train_data, valid_data, test_data : list of int + The training, validating and testing data in integer format. + vocab_size : int + The vocabulary size. + + Examples + -------- + >>> train_data, valid_data, test_data, vocab_size = tl.files.load_ptb_dataset() + + References + --------------- + - ``tensorflow.models.rnn.ptb import reader`` + - `Manual download `__ + + Notes + ------ + - If you want to get the raw data, see the source code. 
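The PTB Dataset class defined further down in this file slices the word-id stream into fixed-length (input, target) windows; as a rough sketch of that indexing (plain NumPy on a toy sequence, independent of the downloaded corpus):

import numpy as np

num_steps = 4                                  # sequence length per sample
data = np.arange(13, dtype=np.int32)           # stand-in for words_to_word_ids() output
data_len = (len(data) - 1) // num_steps        # 3 full (x, y) windows
data = data[:data_len * num_steps + 1]

for index in range(data_len):
    x = data[index * num_steps:(index + 1) * num_steps]
    y = data[index * num_steps + 1:(index + 1) * num_steps + 1]
    # first window: x = [0 1 2 3], y = [1 2 3 4]  (y is x shifted by one token)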
+ + """ + path = os.path.join(path, name) + logging.info("Load or Download Penn TreeBank (PTB) dataset > {}".format(path)) + + # Maybe dowload and uncompress tar, or load exsisting files + maybe_download_and_extract(PTB_FILENAME, path, PTB_URL, extract=True) + + data_path = os.path.join(path, 'simple-examples', 'data') + train_path = os.path.join(data_path, "ptb.train.txt") + valid_path = os.path.join(data_path, "ptb.valid.txt") + test_path = os.path.join(data_path, "ptb.test.txt") + + word_to_id = nlp.build_vocab(nlp.read_words(train_path)) + + train_data = nlp.words_to_word_ids(nlp.read_words(train_path), word_to_id) + valid_data = nlp.words_to_word_ids(nlp.read_words(valid_path), word_to_id) + test_data = nlp.words_to_word_ids(nlp.read_words(test_path), word_to_id) + vocab_size = len(word_to_id) + + # logging.info(nlp.read_words(train_path)) # ... 'according', 'to', 'mr.', '', ''] + # logging.info(train_data) # ... 214, 5, 23, 1, 2] + # logging.info(word_to_id) # ... 'beyond': 1295, 'anti-nuclear': 9599, 'trouble': 1520, '': 2 ... } + # logging.info(vocabulary) # 10000 + # exit() + return train_data, valid_data, test_data, vocab_size + + +class PTB(Dataset): + """ + Load Penn TreeBank (PTB) dataset. + + It is used in many LANGUAGE MODELING papers, + including "Empirical Evaluation and Combination of Advanced Language + Modeling Techniques", "Recurrent Neural Network Regularization". + It consists of 929k training words, 73k validation words, and 82k test + words. It has 10k words in its vocabulary. + + Parameters + ---------- + train_or_test_or_valid : str + Must be either 'train' or 'test' or 'valid'. Choose the training or test or validation dataset. + num_steps : int + The number of unrolls. i.e. sequence_length + name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is ``raw_data/ptb/``. + """ + + def __init__(self, train_or_test_or_valid, num_steps, name='ptb', path='raw_data'): + path = os.path.expanduser(path) + self.path = os.path.join(path, name) + logging.info("Load or Download Penn TreeBank (PTB) dataset > {}".format(self.path)) + + maybe_download_and_extract(PTB_FILENAME, self.path, PTB_URL, extract=True) + + self.num_steps = num_steps + self.path = os.path.join(self.path, 'simple-examples', 'data') + assert train_or_test_or_valid in ['train', 'test', 'valid'] + self.train_or_test_or_valid = train_or_test_or_valid + train_path = os.path.join(self.path, "ptb.train.txt") + if train_or_test_or_valid == 'train': + data_path = train_path + elif train_or_test_or_valid == 'valid': + data_path = os.path.join(self.path, "ptb.valid.txt") + else: + data_path = os.path.join(self.path, "ptb.test.txt") + + # use training data to build vocab + self.word_to_id = nlp.build_vocab(nlp.read_words(train_path)) + self.vocav_size = len(self.word_to_id) + self.data = nlp.words_to_word_ids(nlp.read_words(data_path), self.word_to_id) + + self.data = np.array(self.data, dtype=np.int32) + self.data_len = (len(self.data) - 1) // self.num_steps + self.data = self.data[:self.data_len * self.num_steps + 1] + + def __getitem__(self, index): + x = self.data[index * self.num_steps:(index + 1) * self.num_steps] + y = self.data[index * self.num_steps + 1:(index + 1) * self.num_steps + 1] + return x, y + + def __len__(self): + return self.data_len diff --git a/tensorlayer/data/dataset/voc.py b/tensorlayer/data/dataset/voc.py new file mode 100644 index 000000000..8fd6dfe8c --- /dev/null +++ b/tensorlayer/data/dataset/voc.py @@ -0,0 +1,334 @@ +#! 
/usr/bin/python +# -*- coding: utf-8 -*- + +import os + +import tensorflow as tf +from tensorlayer import logging, utils +from ..utils import del_file, del_folder, folder_exists, load_file_list, maybe_download_and_extract + +__all__ = ['load_voc_dataset'] + + +def load_voc_dataset(path='data', dataset='2012', contain_classes_in_person=False): + """Pascal VOC 2007/2012 Dataset. + + It has 20 objects: + aeroplane, bicycle, bird, boat, bottle, bus, car, cat, chair, cow, diningtable, dog, horse, motorbike, person, pottedplant, sheep, sofa, train, tvmonitor + and additional 3 classes : head, hand, foot for person. + + Parameters + ----------- + path : str + The path that the data is downloaded to, defaults is ``data/VOC``. + dataset : str + The VOC dataset version, `2012`, `2007`, `2007test` or `2012test`. We usually train model on `2007+2012` and test it on `2007test`. + contain_classes_in_person : boolean + Whether include head, hand and foot annotation, default is False. + + Returns + --------- + imgs_file_list : list of str + Full paths of all images. + imgs_semseg_file_list : list of str + Full paths of all maps for semantic segmentation. Note that not all images have this map! + imgs_insseg_file_list : list of str + Full paths of all maps for instance segmentation. Note that not all images have this map! + imgs_ann_file_list : list of str + Full paths of all annotations for bounding box and object class, all images have this annotations. + classes : list of str + Classes in order. + classes_in_person : list of str + Classes in person. + classes_dict : dictionary + Class label to integer. + n_objs_list : list of int + Number of objects in all images in ``imgs_file_list`` in order. + objs_info_list : list of str + Darknet format for the annotation of all images in ``imgs_file_list`` in order. ``[class_id x_centre y_centre width height]`` in ratio format. + objs_info_dicts : dictionary + The annotation of all images in ``imgs_file_list``, ``{imgs_file_list : dictionary for annotation}``, + format from `TensorFlow/Models/object-detection `__. 
+ + Examples + ---------- + >>> imgs_file_list, imgs_semseg_file_list, imgs_insseg_file_list, imgs_ann_file_list, + >>> classes, classes_in_person, classes_dict, + >>> n_objs_list, objs_info_list, objs_info_dicts = tl.files.load_voc_dataset(dataset="2012", contain_classes_in_person=False) + >>> idx = 26 + >>> print(classes) + ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'] + >>> print(classes_dict) + {'sheep': 16, 'horse': 12, 'bicycle': 1, 'bottle': 4, 'cow': 9, 'sofa': 17, 'car': 6, 'dog': 11, 'cat': 7, 'person': 14, 'train': 18, 'diningtable': 10, 'aeroplane': 0, 'bus': 5, 'pottedplant': 15, 'tvmonitor': 19, 'chair': 8, 'bird': 2, 'boat': 3, 'motorbike': 13} + >>> print(imgs_file_list[idx]) + data/VOC/VOC2012/JPEGImages/2007_000423.jpg + >>> print(n_objs_list[idx]) + 2 + >>> print(imgs_ann_file_list[idx]) + data/VOC/VOC2012/Annotations/2007_000423.xml + >>> print(objs_info_list[idx]) + 14 0.173 0.461333333333 0.142 0.496 + 14 0.828 0.542666666667 0.188 0.594666666667 + >>> ann = tl.prepro.parse_darknet_ann_str_to_list(objs_info_list[idx]) + >>> print(ann) + [[14, 0.173, 0.461333333333, 0.142, 0.496], [14, 0.828, 0.542666666667, 0.188, 0.594666666667]] + >>> c, b = tl.prepro.parse_darknet_ann_list_to_cls_box(ann) + >>> print(c, b) + [14, 14] [[0.173, 0.461333333333, 0.142, 0.496], [0.828, 0.542666666667, 0.188, 0.594666666667]] + + References + ------------- + - `Pascal VOC2012 Website `__. + - `Pascal VOC2007 Website `__. + + """ + try: + import lxml.etree as etree + except ImportError as e: + print(e) + raise ImportError("Module lxml not found. Please install lxml via pip or other package managers.") + + path = os.path.join(path, 'VOC') + + def _recursive_parse_xml_to_dict(xml): + """Recursively parses XML contents to python dict. + + We assume that `object` tags are the only ones that can appear + multiple times at the same level of a tree. + + Args: + xml: xml tree obtained by parsing XML file contents using lxml.etree + + Returns: + Python dictionary holding XML contents. 
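A self-contained sketch of the dictionary this helper is meant to produce for a tiny annotation; the base case tests for an element with no children (``len(xml) == 0``), since a parsed lxml element is never None:

import lxml.etree as etree

def parse_xml_to_dict(xml):
    # leaf element (no children): map tag -> text
    if len(xml) == 0:
        return {xml.tag: xml.text}
    result = {}
    for child in xml:
        child_result = parse_xml_to_dict(child)
        if child.tag != 'object':
            result[child.tag] = child_result[child.tag]
        else:
            # 'object' can appear several times at the same level, so collect a list
            result.setdefault(child.tag, []).append(child_result[child.tag])
    return {xml.tag: result}

xml = etree.fromstring(
    "<annotation><filename>x.jpg</filename>"
    "<object><name>person</name></object>"
    "<object><name>dog</name></object></annotation>"
)
print(parse_xml_to_dict(xml)['annotation'])
# {'filename': 'x.jpg', 'object': [{'name': 'person'}, {'name': 'dog'}]}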
+ + """ + if xml is not None: + return {xml.tag: xml.text} + result = {} + for child in xml: + child_result = _recursive_parse_xml_to_dict(child) + if child.tag != 'object': + result[child.tag] = child_result[child.tag] + else: + if child.tag not in result: + result[child.tag] = [] + result[child.tag].append(child_result[child.tag]) + return {xml.tag: result} + + import xml.etree.ElementTree as ET + + if dataset == "2012": + url = "http://pjreddie.com/media/files/" + tar_filename = "VOCtrainval_11-May-2012.tar" + extracted_filename = "VOC2012" #"VOCdevkit/VOC2012" + logging.info(" [============= VOC 2012 =============]") + elif dataset == "2012test": + extracted_filename = "VOC2012test" #"VOCdevkit/VOC2012" + logging.info(" [============= VOC 2012 Test Set =============]") + logging.info( + " \nAuthor: 2012test only have person annotation, so 2007test is highly recommended for testing !\n" + ) + import time + time.sleep(3) + if os.path.isdir(os.path.join(path, extracted_filename)) is False: + logging.info("For VOC 2012 Test data - online registration required") + logging.info( + " Please download VOC2012test.tar from: \n register: http://host.robots.ox.ac.uk:8080 \n voc2012 : http://host.robots.ox.ac.uk:8080/eval/challenges/voc2012/ \ndownload: http://host.robots.ox.ac.uk:8080/eval/downloads/VOC2012test.tar" + ) + logging.info(" unzip VOC2012test.tar,rename the folder to VOC2012test and put it into %s" % path) + exit() + # # http://host.robots.ox.ac.uk:8080/eval/downloads/VOC2012test.tar + # url = "http://host.robots.ox.ac.uk:8080/eval/downloads/" + # tar_filename = "VOC2012test.tar" + elif dataset == "2007": + url = "http://pjreddie.com/media/files/" + tar_filename = "VOCtrainval_06-Nov-2007.tar" + extracted_filename = "VOC2007" + logging.info(" [============= VOC 2007 =============]") + elif dataset == "2007test": + # http://host.robots.ox.ac.uk/pascal/VOC/voc2007/index.html#testdata + # http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar + url = "http://pjreddie.com/media/files/" + tar_filename = "VOCtest_06-Nov-2007.tar" + extracted_filename = "VOC2007test" + logging.info(" [============= VOC 2007 Test Set =============]") + else: + raise Exception("Please set the dataset aug to 2012, 2012test or 2007.") + + # download dataset + if dataset != "2012test": + from sys import platform as _platform + if folder_exists(os.path.join(path, extracted_filename)) is False: + logging.info("[VOC] {} is nonexistent in {}".format(extracted_filename, path)) + maybe_download_and_extract(tar_filename, path, url+tar_filename, extract=True) + del_file(os.path.join(path, tar_filename)) + if dataset == "2012": + if _platform == "win32": + os.system("move {}\VOCdevkit\VOC2012 {}\VOC2012".format(path, path)) + else: + os.system("mv {}/VOCdevkit/VOC2012 {}/VOC2012".format(path, path)) + elif dataset == "2007": + if _platform == "win32": + os.system("move {}\VOCdevkit\VOC2007 {}\VOC2007".format(path, path)) + else: + os.system("mv {}/VOCdevkit/VOC2007 {}/VOC2007".format(path, path)) + elif dataset == "2007test": + if _platform == "win32": + os.system("move {}\VOCdevkit\VOC2007 {}\VOC2007test".format(path, path)) + else: + os.system("mv {}/VOCdevkit/VOC2007 {}/VOC2007test".format(path, path)) + del_folder(os.path.join(path, 'VOCdevkit')) + # object classes(labels) NOTE: YOU CAN CUSTOMIZE THIS LIST + classes = [ + "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", + "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", 
"tvmonitor" + ] + if contain_classes_in_person: + classes_in_person = ["head", "hand", "foot"] + else: + classes_in_person = [] + + classes += classes_in_person # use extra 3 classes for person + + classes_dict = utils.list_string_to_dict(classes) + logging.info("[VOC] object classes {}".format(classes_dict)) + + # 1. image path list + # folder_imgs = path+"/"+extracted_filename+"/JPEGImages/" + folder_imgs = os.path.join(path, extracted_filename, "JPEGImages") + imgs_file_list = load_file_list(path=folder_imgs, regx='\\.jpg', printable=False) + logging.info("[VOC] {} images found".format(len(imgs_file_list))) + + imgs_file_list.sort( + key=lambda s: int(s.replace('.', ' ').replace('_', '').split(' ')[-2]) + ) # 2007_000027.jpg --> 2007000027 + + imgs_file_list = [os.path.join(folder_imgs, s) for s in imgs_file_list] + # logging.info('IM',imgs_file_list[0::3333], imgs_file_list[-1]) + if dataset != "2012test": + ##======== 2. semantic segmentation maps path list + # folder_semseg = path+"/"+extracted_filename+"/SegmentationClass/" + folder_semseg = os.path.join(path, extracted_filename, "SegmentationClass") + imgs_semseg_file_list = load_file_list(path=folder_semseg, regx='\\.png', printable=False) + logging.info("[VOC] {} maps for semantic segmentation found".format(len(imgs_semseg_file_list))) + imgs_semseg_file_list.sort( + key=lambda s: int(s.replace('.', ' ').replace('_', '').split(' ')[-2]) + ) # 2007_000032.png --> 2007000032 + imgs_semseg_file_list = [os.path.join(folder_semseg, s) for s in imgs_semseg_file_list] + # logging.info('Semantic Seg IM',imgs_semseg_file_list[0::333], imgs_semseg_file_list[-1]) + ##======== 3. instance segmentation maps path list + # folder_insseg = path+"/"+extracted_filename+"/SegmentationObject/" + folder_insseg = os.path.join(path, extracted_filename, "SegmentationObject") + imgs_insseg_file_list = load_file_list(path=folder_insseg, regx='\\.png', printable=False) + logging.info("[VOC] {} maps for instance segmentation found".format(len(imgs_semseg_file_list))) + imgs_insseg_file_list.sort( + key=lambda s: int(s.replace('.', ' ').replace('_', '').split(' ')[-2]) + ) # 2007_000032.png --> 2007000032 + imgs_insseg_file_list = [os.path.join(folder_insseg, s) for s in imgs_insseg_file_list] + # logging.info('Instance Seg IM',imgs_insseg_file_list[0::333], imgs_insseg_file_list[-1]) + else: + imgs_semseg_file_list = [] + imgs_insseg_file_list = [] + # 4. annotations for bounding box and object class + # folder_ann = path+"/"+extracted_filename+"/Annotations/" + folder_ann = os.path.join(path, extracted_filename, "Annotations") + imgs_ann_file_list = load_file_list(path=folder_ann, regx='\\.xml', printable=False) + logging.info( + "[VOC] {} XML annotation files for bounding box and object class found".format(len(imgs_ann_file_list)) + ) + imgs_ann_file_list.sort( + key=lambda s: int(s.replace('.', ' ').replace('_', '').split(' ')[-2]) + ) # 2007_000027.xml --> 2007000027 + imgs_ann_file_list = [os.path.join(folder_ann, s) for s in imgs_ann_file_list] + # logging.info('ANN',imgs_ann_file_list[0::3333], imgs_ann_file_list[-1]) + + if dataset == "2012test": # remove unused images in JPEG folder + imgs_file_list_new = [] + for ann in imgs_ann_file_list: + ann = os.path.split(ann)[-1].split('.')[0] + for im in imgs_file_list: + if ann in im: + imgs_file_list_new.append(im) + break + imgs_file_list = imgs_file_list_new + logging.info("[VOC] keep %d images" % len(imgs_file_list_new)) + + # parse XML annotations + def convert(size, box): + dw = 1. 
/ size[0] + dh = 1. / size[1] + x = (box[0] + box[1]) / 2.0 + y = (box[2] + box[3]) / 2.0 + w = box[1] - box[0] + h = box[3] - box[2] + x = x * dw + w = w * dw + y = y * dh + h = h * dh + return x, y, w, h + + def convert_annotation(file_name): + """Given VOC2012 XML Annotations, returns number of objects and info.""" + in_file = open(file_name) + out_file = "" + tree = ET.parse(in_file) + root = tree.getroot() + size = root.find('size') + w = int(size.find('width').text) + h = int(size.find('height').text) + n_objs = 0 + + for obj in root.iter('object'): + if dataset != "2012test": + difficult = obj.find('difficult').text + cls = obj.find('name').text + if cls not in classes or int(difficult) == 1: + continue + else: + cls = obj.find('name').text + if cls not in classes: + continue + cls_id = classes.index(cls) + xmlbox = obj.find('bndbox') + b = ( + float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text), + float(xmlbox.find('ymax').text) + ) + bb = convert((w, h), b) + + out_file += str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n' + n_objs += 1 + if cls in "person": + for part in obj.iter('part'): + cls = part.find('name').text + if cls not in classes_in_person: + continue + cls_id = classes.index(cls) + xmlbox = part.find('bndbox') + b = ( + float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), + float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text) + ) + bb = convert((w, h), b) + # out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n') + out_file += str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n' + n_objs += 1 + in_file.close() + return n_objs, out_file + + logging.info("[VOC] Parsing xml annotations files") + n_objs_list = [] + objs_info_list = [] # Darknet Format list of string + objs_info_dicts = {} + for idx, ann_file in enumerate(imgs_ann_file_list): + n_objs, objs_info = convert_annotation(ann_file) + n_objs_list.append(n_objs) + objs_info_list.append(objs_info) + with tf.io.gfile.GFile(ann_file, 'r') as fid: + xml_str = fid.read() + xml = etree.fromstring(xml_str) + data = _recursive_parse_xml_to_dict(xml)['annotation'] + objs_info_dicts.update({imgs_file_list[idx]: data}) + + return imgs_file_list, imgs_semseg_file_list, imgs_insseg_file_list, imgs_ann_file_list, classes, classes_in_person, classes_dict, n_objs_list, objs_info_list, objs_info_dicts diff --git a/tensorlayer/data/dataset/wmt_en_fr.py b/tensorlayer/data/dataset/wmt_en_fr.py new file mode 100644 index 000000000..5a28e279b --- /dev/null +++ b/tensorlayer/data/dataset/wmt_en_fr.py @@ -0,0 +1,80 @@ +#! /usr/bin/python +# -*- coding: utf-8 -*- + +import gzip +import os +import tarfile + +from tensorflow.python.platform import gfile +from tensorlayer import logging +from ..utils import maybe_download_and_extract + +__all__ = ['load_wmt_en_fr_dataset'] + +WMT_ENFR_TRAIN_URL = "http://www.statmt.org/wmt10/training-giga-fren.tar" +WMT_ENFR_DEV_URL = "http://www.statmt.org/wmt15/dev-v2.tgz" + + +def load_wmt_en_fr_dataset(path='data', name='wmt_en_fr'): + """Load WMT'15 English-to-French translation dataset. + + It will download the data from the WMT'15 Website (10^9-French-English corpus), and the 2013 news test from the same site as development set. + Returns the directories of training data and test data. + + Parameters + ---------- + path : str + The path that the data is downloaded to, defaults is ``data/wmt_en_fr/``. 
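To make the Darknet ratio format concrete, here is a standalone re-statement of the ``convert()`` helper from voc.py above, applied to a hypothetical 500x375 image with one box (pixel coordinates chosen for illustration only; the resulting ratios match the first row of the ``objs_info_list`` example in the docstring above):

def convert(size, box):
    # size = (width, height) in pixels; box = (xmin, xmax, ymin, ymax) in pixels
    dw, dh = 1. / size[0], 1. / size[1]
    x = (box[0] + box[1]) / 2.0 * dw    # box centre x as a ratio of image width
    y = (box[2] + box[3]) / 2.0 * dh    # box centre y as a ratio of image height
    w = (box[1] - box[0]) * dw          # box width ratio
    h = (box[3] - box[2]) * dh          # box height ratio
    return x, y, w, h

print(convert((500, 375), (51.0, 122.0, 80.0, 266.0)))
# (0.173, 0.461333..., 0.142, 0.496) -> written as "14 0.173 0.461333 0.142 0.496"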
+ + References + ---------- + - Code modified from /tensorflow/models/rnn/translation/data_utils.py + + Notes + ----- + Usually, it will take a long time to download this dataset. + + """ + path = os.path.join(path, name) + # URLs for WMT data. + + def gunzip_file(gz_path, new_path): + """Unzips from gz_path into new_path.""" + logging.info("Unpacking %s to %s" % (gz_path, new_path)) + with gzip.open(gz_path, "rb") as gz_file: + with open(new_path, "wb") as new_file: + for line in gz_file: + new_file.write(line) + + def get_wmt_enfr_train_set(path): + """Download the WMT en-fr training corpus to directory unless it's there.""" + filename = "training-giga-fren.tar" + maybe_download_and_extract(filename, path, WMT_ENFR_TRAIN_URL, extract=True) + train_path = os.path.join(path, "giga-fren.release2.fixed") + gunzip_file(train_path + ".fr.gz", train_path + ".fr") + gunzip_file(train_path + ".en.gz", train_path + ".en") + return train_path + + def get_wmt_enfr_dev_set(path): + """Download the WMT en-fr training corpus to directory unless it's there.""" + filename = "dev-v2.tgz" + dev_file = maybe_download_and_extract(filename, path, WMT_ENFR_DEV_URL, extract=False) + dev_name = "newstest2013" + dev_path = os.path.join(path, "newstest2013") + if not (gfile.Exists(dev_path + ".fr") and gfile.Exists(dev_path + ".en")): + logging.info("Extracting tgz file %s" % dev_file) + with tarfile.open(dev_file, "r:gz") as dev_tar: + fr_dev_file = dev_tar.getmember("dev/" + dev_name + ".fr") + en_dev_file = dev_tar.getmember("dev/" + dev_name + ".en") + fr_dev_file.name = dev_name + ".fr" # Extract without "dev/" prefix. + en_dev_file.name = dev_name + ".en" + dev_tar.extract(fr_dev_file, path) + dev_tar.extract(en_dev_file, path) + return dev_path + + logging.info("Load or Download WMT English-to-French translation > {}".format(path)) + + train_path = get_wmt_enfr_train_set(path) + dev_path = get_wmt_enfr_dev_set(path) + + return train_path, dev_path diff --git a/tensorlayer/data/parallel.py b/tensorlayer/data/parallel.py new file mode 100644 index 000000000..e0abc0f14 --- /dev/null +++ b/tensorlayer/data/parallel.py @@ -0,0 +1,194 @@ +import multiprocessing +import os +import sys +import uuid + +import zmq +import numpy as np + +from .base import DatasetWrapper +from .serialize import * +from .utils import clean_up_socket_files + + +class MultiprocessDataset(DatasetWrapper): + def __init__(self, + ds, + num_worker, + num_prefetch, + shuffle=False): + + super(MultiprocessDataset, self).__init__(ds) + self.num_worker = num_worker + self.num_prefetch = num_prefetch + self.shuffle = shuffle + + self.index_queue = multiprocessing.Queue(self.num_worker) + self.data_queue = multiprocessing.Queue(self.num_prefetch) + self.put_idx_worker = None + for _ in range(num_worker): + worker = multiprocessing.Process(target=self._data_worker, + args=(self.ds, self.index_queue, self.data_queue)) + worker.daemon = True + worker.start() + + def _data_worker(self, ds, index_q, data_q): + while True: + idx = index_q.get() + data_q.put((idx, ds[idx])) + + def _put_idx(self, index_q, idx): + index_q.put(idx) + + def __iter__(self): + # clear queues from previous epoch + while not self.index_queue.empty(): + self.index_queue.get() + while not self.data_queue.empty(): + self.data_queue.get() + + # shuffle at the start of every epoch + if self.shuffle: + self.idxs = np.random.permutation(self.ds_len) + else: + self.idxs = np.arange(self.ds_len) + + send_idx_cnt = 0 + for _ in range(self.num_worker * 2): + self._put_idx(self.index_queue, 
self.idxs[send_idx_cnt]) + send_idx_cnt += 1 + + data_buffer = {} + for return_idx in self.idxs: + if return_idx in data_buffer: + yield data_buffer.pop(return_idx) + else: + while True: + idx, dp = self.data_queue.get() + # put new idx after collecting data + if send_idx_cnt < len(self.idxs): + self._put_idx(self.index_queue, self.idxs[send_idx_cnt]) + send_idx_cnt += 1 + if idx == return_idx: + yield dp + break + else: + data_buffer[idx] = dp + + +class ZMQMultiprocessDataset(DatasetWrapper): + def __init__(self, + ds, + num_worker, + hwm=50, + shuffle=False): + + super(ZMQMultiprocessDataset, self).__init__(ds) + self.num_worker = num_worker + self.shuffle = shuffle + self._hwm = hwm + + self.idx_pipename = _get_pipe_name('put_idx') + self.data_pipename = _get_pipe_name('collect_data') + + for i in range(num_worker): + # first worker bind the socket, others connect to the socket + # however, zmq sockets using ipc do not care about the order of bind / connect + if i == 0: + worker = multiprocessing.Process(target=self._data_worker, + args=(True,)) + else: + worker = multiprocessing.Process(target=self._data_worker, + args=()) + worker.daemon = True + worker.start() + + clean_up_socket_files([self.idx_pipename, self.data_pipename]) + + def _data_worker(self, bind=False): + context = zmq.Context() + worker_receive_index_socket = context.socket(zmq.PULL) + worker_receive_index_socket.set_hwm(self._hwm) + if bind: + worker_receive_index_socket.bind(self.idx_pipename) + else: + worker_receive_index_socket.connect(self.idx_pipename) + + worker_send_data_socket = context.socket(zmq.PUSH) + worker_send_data_socket.set_hwm(self._hwm) + if bind: + worker_send_data_socket.bind(self.data_pipename) + else: + worker_send_data_socket.connect(self.data_pipename) + + while True: + recv_msg = worker_receive_index_socket.recv(copy=False) + idx = load_from_bytes(recv_msg) + send_msg = convert_to_bytes({'idx': idx, 'data': self.ds[idx]}) + worker_send_data_socket.send(send_msg, copy=False) + + def _put_idx(self, put_idx_socket, idx): + send_msg = convert_to_bytes(idx) + put_idx_socket.send(send_msg, copy=False) + + def __iter__(self): + context = zmq.Context() + collect_data_socket = context.socket(zmq.PULL) + collect_data_socket.set_hwm(self._hwm) + collect_data_socket.connect(self.data_pipename) + + put_idx_socket = context.socket(zmq.PUSH) + put_idx_socket.set_hwm(self._hwm) + put_idx_socket.connect(self.idx_pipename) + + # shutdown put_idx_worker and clear queues from previous epoch + try: + while True: + collect_data_socket.recv(flags=zmq.NOBLOCK) + except zmq.ZMQError: + pass + + # shuffle at the start of every epoch + if self.shuffle: + self.idxs = np.random.permutation(self.ds_len) + else: + self.idxs = np.arange(self.ds_len) + + send_idx_cnt = 0 + for _ in range(self.num_worker * 2): + self._put_idx(put_idx_socket, self.idxs[send_idx_cnt]) + send_idx_cnt += 1 + + data_buffer = {} + for return_idx in self.idxs: + if return_idx in data_buffer: + yield data_buffer.pop(return_idx) + else: + while True: + recv_msg = collect_data_socket.recv(copy=False) + recv_msg = load_from_bytes(recv_msg) + idx, dp = recv_msg['idx'], recv_msg['data'] + # put new idx after collecting data + if send_idx_cnt < len(self.idxs): + self._put_idx(put_idx_socket, self.idxs[send_idx_cnt]) + send_idx_cnt += 1 + if idx == return_idx: + yield dp + break + else: + data_buffer[idx] = dp + + +def _get_pipe_name(name): + if sys.platform.startswith('linux'): + # linux supports abstract sockets: http://api.zeromq.org/4-1:zmq-ipc + 
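# an abstract socket lives only in the kernel namespace (the leading '@' in the
# address, e.g. "ipc://@put_idx-pipe-1a2b3c4d"), so no file is created on disk
# and there is nothing to unlink afterwards; on other platforms the branch below
# creates a real socket file, which callers such as ZMQMultiprocessDataset
# register with clean_up_socket_files() so it is removed at exit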
pipename = "ipc://@{}-pipe-{}".format(name, str(uuid.uuid1())[:8]) + else: + pipedir = '.' + assert os.path.isdir(pipedir), pipedir + filename = '{}/{}-pipe-{}'.format(pipedir.rstrip('/'), name, str(uuid.uuid1())[:6]) + assert not os.path.exists(filename), "Pipe {} exists! You may be unlucky.".format(filename) + pipename = "ipc://{}".format(filename) + # register in environment variable, used for cleaning up ipc socket files + # os.environ[name] = pipename + return pipename diff --git a/tensorlayer/data/serialize.py b/tensorlayer/data/serialize.py new file mode 100644 index 000000000..aa272f4f4 --- /dev/null +++ b/tensorlayer/data/serialize.py @@ -0,0 +1,27 @@ +import msgpack_numpy + +MAX_MSGPACK_LEN = 1000000000 + + +def convert_to_bytes(obj): + """ + Serialize an object. + + Returns: + Implementation-dependent bytes-like object. + """ + return msgpack_numpy.dumps(obj, use_bin_type=True) + + +def load_from_bytes(buf): + """ + Args: + buf: the output of `dumps`. + """ + # Since 0.6, the default max size was set to 1MB. + # We change it to approximately 1G. + return msgpack_numpy.loads(buf, raw=False, + max_bin_len=MAX_MSGPACK_LEN, + max_array_len=MAX_MSGPACK_LEN, + max_map_len=MAX_MSGPACK_LEN, + max_str_len=MAX_MSGPACK_LEN) diff --git a/tensorlayer/data/utils.py b/tensorlayer/data/utils.py new file mode 100644 index 000000000..f896b28e7 --- /dev/null +++ b/tensorlayer/data/utils.py @@ -0,0 +1,409 @@ +import atexit +import logging +import math +import multiprocessing +import os +import platform +import re +import resource +import shutil +import weakref + +import psutil +import tarfile +import time +import zipfile +import progressbar +from urllib.request import urlretrieve + + +def load_folder_list(path=""): + """Return a folder list in a folder by given a folder path. + + Parameters + ---------- + path : str + A folder path. + + """ + return [os.path.join(path, o) for o in os.listdir(path) if os.path.isdir(os.path.join(path, o))] + + +def exists_or_mkdir(path, verbose=True): + """ + Check a folder by given name, if not exist, create the folder and return False, + if directory exists, return True. + + Parameters + ---------- + path : str + A folder path. + verbose : boolean + If True (default), prints results. + + Returns + -------- + boolean + True if folder already exist, otherwise, returns False and create the folder. + + Examples + -------- + >>> exists_or_mkdir("checkpoints/train") + + """ + if not os.path.exists(path): + if verbose: + logging.info("[*] creates %s ..." % path) + os.makedirs(path) + return False + else: + if verbose: + logging.info("[!] %s exists ..." % path) + return True + + +def download(filename, working_directory, url_source): + """ + Download file from url_source to the working_directory with given filename. + + Parameters + ---------- + filename : str + The name of the downloaded file. + working_directory : str + A folder path download the file to + url_source : str + The URL to download the file from + + Examples + -------- + >>> download(filename='train.gz', + ... working_directory='data/', + ... 
url_source='http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz') + + """ + working_directory = os.path.expanduser(working_directory) + + progress_bar = progressbar.ProgressBar() + + def _dlProgress(count, blockSize, totalSize, pbar=progress_bar): + if (totalSize != 0): + + if not pbar.max_value: + totalBlocks = math.ceil(float(totalSize) / float(blockSize)) + pbar.max_value = int(totalBlocks) + + pbar.update(count, force=True) + + filepath = os.path.join(working_directory, filename) + + logging.info('Downloading %s...\n' % filename) + + urlretrieve(url_source + filename, filepath, reporthook=_dlProgress) + + +def maybe_download_and_extract(filename, working_directory, url_source, extract=False, expected_bytes=None): + """ + Checks if file exists in working_directory otherwise tries to dowload the file, + and optionally also tries to extract the file if format is ".zip" or ".tar" + + Parameters + ----------- + filename : str + The name of the (to be) dowloaded file. + working_directory : str + A folder path to search for the file in and dowload the file to + url_source : str + The URL to download the file from + extract : boolean + If True, tries to uncompress the dowloaded file is ".tar.gz/.tar.bz2" or ".zip" file, default is False. + expected_bytes : int or None + If set tries to verify that the downloaded file is of the specified size, otherwise raises an Exception, defaults is None which corresponds to no check being performed. + + Returns + ---------- + str + File path of the dowloaded (uncompressed) file. + + Examples + -------- + >>> down_file = maybe_download_and_extract(filename='train.gz', + ... working_directory='data/', + ... url_source='http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz') + >>> maybe_download_and_extract(filename='ADEChallengeData2016.zip', + ... working_directory='data/', + ... url_source='http://sceneparsing.csail.mit.edu/data/ADEChallengeData2016.zip', + ... extract=True) + + """ + working_directory = os.path.expanduser(working_directory) + exists_or_mkdir(working_directory, verbose=False) + filepath = os.path.join(working_directory, filename) + + if not os.path.exists(filepath): + download(filename, working_directory, url_source) + statinfo = os.stat(filepath) + logging.info('Succesfully downloaded %s %s bytes.' % (filename, statinfo.st_size)) # , 'bytes.') + if not (expected_bytes is None) and (expected_bytes != statinfo.st_size): + raise Exception('Failed to verify ' + filename + '. Can you get to it with a browser?') + if extract: + if tarfile.is_tarfile(filepath): + logging.info('Trying to extract tar file') + tarfile.open(filepath, 'r').extractall(working_directory) + logging.info('... Success!') + elif zipfile.is_zipfile(filepath): + logging.info('Trying to extract zip file') + with zipfile.ZipFile(filepath) as zf: + zf.extractall(working_directory) + logging.info('... Success!') + else: + logging.info("Unknown compression_format only .tar.gz/.tar.bz2/.tar and .zip supported") + return filepath + + +def natural_keys(text): + """Sort list of string with number in human order. 
+ + Examples + ---------- + >>> l = ['im1.jpg', 'im31.jpg', 'im11.jpg', 'im21.jpg', 'im03.jpg', 'im05.jpg'] + >>> l.sort(key=tl.files.natural_keys) + ['im1.jpg', 'im03.jpg', 'im05', 'im11.jpg', 'im21.jpg', 'im31.jpg'] + >>> l.sort() # that is what we dont want + ['im03.jpg', 'im05', 'im1.jpg', 'im11.jpg', 'im21.jpg', 'im31.jpg'] + + References + ---------- + - `link `__ + + """ + + # - alist.sort(key=natural_keys) sorts in human order + # http://nedbatchelder.com/blog/200712/human_sorting.html + # (See Toothy's implementation in the comments) + def atoi(text): + return int(text) if text.isdigit() else text + + return [atoi(c) for c in re.split('(\d+)', text)] + + +def get_dataloader_speed(dl, num_steps): + cnt = 0 + start = time.time() + end = start + for _ in dl: + cnt += 1 + if cnt == num_steps: + end = time.time() + break + return (end - start) / num_steps + + +def format_bytes(bytes): + if abs(bytes) < 1000: + return str(bytes) + "B" + elif abs(bytes) < 1e6: + return str(round(bytes / 1e3, 2)) + "kB" + elif abs(bytes) < 1e9: + return str(round(bytes / 1e6, 2)) + "MB" + else: + return str(round(bytes / 1e9, 2)) + "GB" + + +def get_process_memory(): + process = psutil.Process(os.getpid()) + mi = process.memory_info() + return mi.rss, mi.vms, mi.shared + + +def get_peak_memory_usage(): + # peak memory usage (bytes on OS X, kilobytes on Linux) + rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + + platform_name = platform.system().lower() + + # If we are on linux + if platform_name== "linux" or platform_name == "linux2": + return format_bytes(rss * 1024) + + # If we are on Mac OS X + elif platform_name == "darwin": + return format_bytes(rss) + + # We don't support Windows + elif platform_name == "win32": + raise EnvironmentError("The Windows operating system is not supported") + + # Unrecognized platform + else: + raise EnvironmentError("Unrecognized platform") + + +def shutdown_proc(proc): + if proc is None: + return + if proc.is_alive(): + proc.terminate() + proc.join() + + +def shutdown_proc_by_weakref(ref): + proc = ref() + if proc is None: + return + if proc.is_alive(): + proc.terminate() + proc.join() + + +def ensure_subprocess_terminate(proc): + """ + Make sure subprocesses terminate when main process exit. + + Args: + proc (multiprocessing.Process or list) + """ + if isinstance(proc, list): + for p in proc: + ensure_subprocess_terminate(p) + return + + assert isinstance(proc, multiprocessing.Process) + atexit.register(shutdown_proc_by_weakref, weakref.ref(proc)) + + +def clean_up_socket_files(pipe_names): + if isinstance(pipe_names, list): + for pipe_name in pipe_names: + clean_up_socket_files(pipe_name) + return + + def remove_socket_files(pipe_name): + # remove all ipc socket files + # the environment variable starts with 'ipc://', so file name starts from 6 + try: + os.remove(pipe_name[6:]) + except (FileNotFoundError, KeyError): + pass + + atexit.register(remove_socket_files, pipe_names) + + +def download_file_from_google_drive(ID, destination): + """Download file from Google Drive. + + See ``tl.files.load_celebA_dataset`` for example. + + Parameters + -------------- + ID : str + The driver ID. + destination : str + The destination for save file. + + """ + try: + from tqdm import tqdm + except ImportError as e: + print(e) + raise ImportError("Module tqdm not found. Please install tqdm via pip or other package managers.") + + try: + import requests + except ImportError as e: + print(e) + raise ImportError("Module requests not found. 
Please install requests via pip or other package managers.") + + def save_response_content(response, destination, chunk_size=32 * 1024): + + total_size = int(response.headers.get('content-length', 0)) + with open(destination, "wb") as f: + for chunk in tqdm(response.iter_content(chunk_size), total=total_size, unit='B', unit_scale=True, + desc=destination): + if chunk: # filter out keep-alive new chunks + f.write(chunk) + + def get_confirm_token(response): + for key, value in response.cookies.items(): + if key.startswith('download_warning'): + return value + return None + + URL = "https://docs.google.com/uc?export=download" + session = requests.Session() + + response = session.get(URL, params={'id': ID}, stream=True) + token = get_confirm_token(response) + + if token: + params = {'id': ID, 'confirm': token} + response = session.get(URL, params=params, stream=True) + save_response_content(response, destination) + + +def load_file_list(path=None, regx='\.jpg', printable=True, keep_prefix=False): + r"""Return a file list in a folder by given a path and regular expression. + + Parameters + ---------- + path : str or None + A folder path, if `None`, use the current directory. + regx : str + The regx of file name. + printable : boolean + Whether to print the files infomation. + keep_prefix : boolean + Whether to keep path in the file name. + + Examples + ---------- + >>> file_list = load_file_list(path=None, regx='w1pre_[0-9]+\.(npz)') + + """ + if path is None: + path = os.getcwd() + file_list = os.listdir(path) + return_list = [] + for _, f in enumerate(file_list): + if re.search(regx, f): + return_list.append(f) + # return_list.sort() + if keep_prefix: + for i, f in enumerate(return_list): + return_list[i] = os.path.join(path, f) + + if printable: + logging.info('Match file list = %s' % return_list) + logging.info('Number of files = %d' % len(return_list)) + return return_list + + +def file_exists(filepath): + """Check whether a file exists by given file path.""" + return os.path.isfile(filepath) + + +def folder_exists(folderpath): + """Check whether a folder exists by given folder path.""" + return os.path.isdir(folderpath) + + +def del_folder(folderpath): + """Delete a folder by given folder path.""" + shutil.rmtree(folderpath) + + +def del_file(filepath): + """Delete a file by given file path.""" + os.remove(filepath) + + +def read_file(filepath): + """Read a file and return a string. 
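A small hedged sketch combining ``load_file_list()`` and ``natural_keys()`` from this module to build a naturally ordered list of full image paths (the directory name is hypothetical, and the import assumes this module is importable as ``tensorlayer.data.utils``):

from tensorlayer.data.utils import load_file_list, natural_keys

# collect full .jpg paths and sort them in human order (im1.jpg before im11.jpg)
images = load_file_list(path='raw_data/some_image_folder', regx='\\.jpg',
                        printable=False, keep_prefix=True)
images.sort(key=natural_keys)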
+ + Examples + --------- + >>> data = read_file('data.txt') + """ + with open(filepath, 'r') as afile: + return afile.read() From 7e4162206942926e3376b5830bacf500b3f315f0 Mon Sep 17 00:00:00 2001 From: 1FengL Date: Sat, 7 Sep 2019 12:50:42 +0100 Subject: [PATCH 3/6] remove tl.dataflow --- tensorlayer/dataflow/__init__.py | 4 - tensorlayer/dataflow/base.py | 73 ----- tensorlayer/dataflow/common.py | 351 ----------------------- tensorlayer/dataflow/dataset/__init__.py | 5 - tensorlayer/dataflow/dataset/cifar10.py | 58 ---- tensorlayer/dataflow/dataset/ilsvrc.py | 309 -------------------- tensorlayer/dataflow/dataset/mnist.py | 62 ---- tensorlayer/dataflow/parallel.py | 198 ------------- tensorlayer/dataflow/serialize.py | 27 -- tensorlayer/dataflow/utils.py | 199 ------------- 10 files changed, 1286 deletions(-) delete mode 100644 tensorlayer/dataflow/__init__.py delete mode 100644 tensorlayer/dataflow/base.py delete mode 100644 tensorlayer/dataflow/common.py delete mode 100644 tensorlayer/dataflow/dataset/__init__.py delete mode 100644 tensorlayer/dataflow/dataset/cifar10.py delete mode 100644 tensorlayer/dataflow/dataset/ilsvrc.py delete mode 100644 tensorlayer/dataflow/dataset/mnist.py delete mode 100644 tensorlayer/dataflow/parallel.py delete mode 100644 tensorlayer/dataflow/serialize.py delete mode 100644 tensorlayer/dataflow/utils.py diff --git a/tensorlayer/dataflow/__init__.py b/tensorlayer/dataflow/__init__.py deleted file mode 100644 index fec880c7f..000000000 --- a/tensorlayer/dataflow/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .base import Dataset -from .base import Transform -from .common import Dataloader -from .common import TFDataloader diff --git a/tensorlayer/dataflow/base.py b/tensorlayer/dataflow/base.py deleted file mode 100644 index 3caacf6b7..000000000 --- a/tensorlayer/dataflow/base.py +++ /dev/null @@ -1,73 +0,0 @@ -class Dataset(object): - - def __getitem__(self, index): - raise NotImplementedError("A Dataset must implement __getitem__(index) method.") - - def __len__(self): - raise NotImplementedError("A Dataset must implement __len__() method.") - - def __iter__(self): - for i in range(self.__len__()): - yield self.__getitem__(i) - - def __call__(self, *args, **kwargs): - return self.__iter__() - - -class DatasetWrapper(object): - def __init__(self, ds): - self.ds = ds - self.ds_len = len(ds) - - def __len__(self): - return len(self.ds) - - def __iter__(self): - for dp in self.ds: - yield dp - - def __call__(self, *args, **kwargs): - return self.__iter__() - - -class IndexableDatasetWrapper(object): - def __init__(self, ds): - self.ds = ds - self.ds_len = len(ds) - - def __getitem__(self, index): - return self.ds.__getitem__(index) - - def __len__(self): - return len(self.ds) - - def __call__(self, *args, **kwargs): - return self - - -class Transform(object): - def __call__(self, *args, **kwargs): - raise NotImplementedError("Transform must implement __call__() method.") - - -class _Transforms_for_tf_dataset(object): - """ - This class aggregate Transforms into one object in order to use tf.data.Dataset.map API - """ - - def __init__(self, transforms): - self.transforms = transforms - - def __call__(self, *args): - # assert len(args) == len(self.transforms) - # data_list = [None] * len(args) - # for i in range(len(args)): - # data = args[i] - # for transform in self.transforms[i]: - # data = transform(data) - # data_list[i] = data - # return data_list - data_list = list(args) - for transform in self.transforms: - data_list = transform(*data_list) - return 
data_list diff --git a/tensorlayer/dataflow/common.py b/tensorlayer/dataflow/common.py deleted file mode 100644 index a2e4c5adc..000000000 --- a/tensorlayer/dataflow/common.py +++ /dev/null @@ -1,351 +0,0 @@ -import atexit -import math -import multiprocessing -import os - -import tensorflow as tf -import zmq - -import numpy as np - -from .base import IndexableDatasetWrapper, DatasetWrapper, _Transforms_for_tf_dataset -from .parallel import _get_pipe_name, ZMQMultiprocessDataset, MultiprocessDataset -from .utils import ensure_proc_terminate -from .serialize import convert_to_bytes, load_from_bytes - -__all__ = ['BatchedDataset', 'TransformedDataset', 'ShuffledDataset', - 'AugmentedDataset', 'Dataloader', 'TFDataloader'] - - -class BatchedDataset(DatasetWrapper): - def __init__(self, - ds, - batch_size, - drop_remainder=True, - return_numpy=True, - keep_dims=False, - output_types=None, - use_zmq=True): - super(BatchedDataset, self).__init__(ds) - self.batch_size = batch_size - self.drop_remainder = drop_remainder - self.return_numpy = return_numpy - self.keep_dims = keep_dims - self.output_types = output_types - self.use_zmq = use_zmq - - # self.q = multiprocessing.Queue(maxsize=1) - # self.worker = multiprocessing.Process(target=self._BatchedDataset_worker, - # args=(self.ds, self.q)) - # self.worker.start() - # ensure_proc_terminate(self.worker) - - if self.use_zmq: - self.data_pipename = _get_pipe_name('batch_prefetch') - context = zmq.Context() - self.fetch_data_socket = context.socket(zmq.PULL) - self.fetch_data_socket.bind(self.data_pipename) - self.worker = multiprocessing.Process(target=self._ZMQ_BatchedDataset_worker, - args=(self.ds,)) - self.worker.start() - else: - pipe_output, pipe_input = multiprocessing.Pipe() - self.worker = multiprocessing.Process(target=self._BatchedDataset_worker, - args=(self.ds, (pipe_output, pipe_input))) - self.worker.start() - # main process only reads (gets output) - pipe_input.close() - self.pipe_output = pipe_output - - ensure_proc_terminate(self.worker) - - def _ZMQ_BatchedDataset_worker(self, ds): - context = zmq.Context() - prepare_data_socket = context.socket(zmq.PUSH) - prepare_data_socket.connect(self.data_pipename) - while True: - dp_buffer = [] - for dp in ds: - dp_buffer.append(dp) - if len(dp_buffer) == self.batch_size: - # q.put(self._batch_datapoints(dp_buffer)) - prepare_data_socket.send(convert_to_bytes(self._batch_datapoints(dp_buffer)), copy=False) - del dp_buffer[:] - if not self.drop_remainder: - # q.put(self._batch_datapoints(dp_buffer)) - prepare_data_socket.send(convert_to_bytes(self._batch_datapoints(dp_buffer)), copy=False) - - def _BatchedDataset_worker(self, ds, pipe): - pipe_output, pipe_input = pipe - # worker process only writes (puts input) - pipe_output.close() - while True: - dp_buffer = [] - for dp in ds: - dp_buffer.append(dp) - if len(dp_buffer) == self.batch_size: - # q.put(self._batch_datapoints(dp_buffer)) - pipe_input.send(self._batch_datapoints(dp_buffer)) - del dp_buffer[:] - if not self.drop_remainder: - # q.put(self._batch_datapoints(dp_buffer)) - pipe_input.send(self._batch_datapoints(dp_buffer)) - - def __iter__(self): - for _ in range(self.__len__()): - # yield self.q.get() - if self.use_zmq: - yield load_from_bytes(self.fetch_data_socket.recv(copy=False)) - else: - yield self.pipe_output.recv() - - def __len__(self): - ds_len = len(self.ds) - if self.drop_remainder: - return ds_len // self.batch_size - else: - return math.ceil(ds_len / self.batch_size) - - def _batch_datapoints(self, dp_buffer): - """ - 
- :param dp_buffer: a list of datapoints - :return: - """ - first_dp = dp_buffer[0] - if isinstance(first_dp, (tuple, list)): - dp_batch = [None] * len(first_dp) - for i in range(len(first_dp)): - dp_element_batch = [] - for j in range(len(dp_buffer)): - dp_element_batch.append(dp_buffer[j][i]) - if self.return_numpy: - dp_batch[i] = self._batch_ndarray(dp_element_batch, dtype=self._get_element_dtype(i)) - else: - dp_batch[i] = dp_element_batch - return dp_batch - elif isinstance(first_dp, dict): - dp_batch = {} - for key in first_dp.keys(): - dp_element_batch = [] - for j in range(len(dp_buffer)): - dp_element_batch.append(dp_buffer[j][key]) - if self.return_numpy: - dp_batch[key] = self._batch_ndarray(dp_element_batch, dtype=None) - else: - dp_batch[key] = dp_element_batch - return dp_batch - elif isinstance(first_dp, np.ndarray): - return self._batch_ndarray(dp_buffer) - # single elements - else: - if self.return_numpy: - return self._batch_ndarray(dp_buffer, dtype=self._get_element_dtype(0)) - else: - return dp_buffer - - def _batch_ndarray(self, dp_element_batch, dtype): - """ - - :param dp_element_batch: a list of datapoint element, an element can be np.ndarray / list - :return: np.ndarray, type is the same as input - """ - try: - if dtype is not None: - ret = np.asarray(dp_element_batch, dtype=dtype) - else: - ret = np.asarray(dp_element_batch) - if self.keep_dims and len(ret.shape) == 1: - ret = np.expand_dims(ret, 1) - return ret - except: - raise ValueError("Unsupported type for batching.") - - def _get_element_dtype(self, i): - if self.output_types is None: - return None - if not isinstance(self.output_types, (tuple, list)): - return self.output_types - if len(self.output_types) == 1: - return self.output_types[0] - return self.output_types[i] - - -class ShuffledDataset(DatasetWrapper): - def __init__(self, ds): - super(ShuffledDataset, self).__init__(ds) - - def __iter__(self): - self.shuffled_idxs = np.random.permutation(len(self.ds)) - for index, data in enumerate(self.ds): - yield self.ds[self.shuffled_idxs[index]] - - -class TransformedDataset(IndexableDatasetWrapper): - """ - - """ - - def __init__(self, ds, transforms): - super(TransformedDataset, self).__init__(ds) - self.transforms = transforms - - def __getitem__(self, index): - dp = self.ds[index] - for transform in self.transforms: - assert callable(transform) - if isinstance(dp, (list, tuple)): - dp = transform(*dp) - else: - dp = transform(dp) - return dp - - -class AugmentedDataset(IndexableDatasetWrapper): - def __init__(self, ds, augmentations): - super(AugmentedDataset, self).__init__(ds) - self.augmentations = augmentations - self.num_augmentations = len(self.augmentations) - - def __getitem__(self, index): - if index >= self.__len__(): - raise IndexError - dp = self.ds[index % self.ds_len] - if index < self.ds_len: - return dp - augmentation = self.augmentations[(index // self.ds_len) - 1] - assert callable(augmentation) - if isinstance(dp, (list, tuple)): - return augmentation(*dp) - else: - return augmentation(dp) - - def __len__(self): - # every augmentation gives one more duplication of dataset - return self.ds_len * (1 + self.num_augmentations) - - -class Dataloader(DatasetWrapper): - def __init__(self, - ds, - augmentations=None, - shuffle=False, - batch_size=1, - drop_remainder=True, - batch_keep_dims=False, - output_types=None, - num_worker=os.cpu_count(), - use_zmq=True, - num_prefetch=None, - transforms=None): - - super(Dataloader, self).__init__(ds) - self.augmentations = augmentations - 
self.shuffle = shuffle - self.batch_size = batch_size - self.drop_remainder = drop_remainder - self.batch_keep_dims = batch_keep_dims - self.output_types = output_types - self.num_worker = num_worker - self.use_zmq = use_zmq - self.num_prefetch = num_worker if num_prefetch is None else num_prefetch - self.transforms = transforms - - if self.augmentations is not None: - self.ds = AugmentedDataset(self.ds, self.augmentations) - - if self.transforms is not None: - self.ds = TransformedDataset(self.ds, self.transforms) - # self.tfds = self.tfds.map(map_func=_Transforms(self.transforms), num_parallel_calls=num_map_worker) - - # TODO: auto adjust num_prefetch - if self.num_worker > 1: - if self.use_zmq: - self.ds = ZMQMultiprocessDataset(self.ds, num_worker=self.num_worker, hwm=self.num_prefetch, - shuffle=self.shuffle) - else: - self.ds = MultiprocessDataset(self.ds, num_worker=self.num_worker, num_prefetch=self.num_prefetch, - shuffle=self.shuffle) - elif self.shuffle: - self.ds = ShuffledDataset(self.ds) - - self.ds = BatchedDataset(self.ds, self.batch_size, drop_remainder=self.drop_remainder, - output_types=self.output_types, keep_dims=self.batch_keep_dims, - use_zmq=self.use_zmq) - - # self.tfds = tf.data.Dataset.from_generator(self.ds, output_types=output_types) - - # if self.num_prefetch > 1: - # self.tfds = self.tfds.prefetch(num_prefetch) - atexit.register(self._clean_up_socket_files) - - def __iter__(self): - for dp in self.ds: - yield dp - - def _clean_up_socket_files(self): - # remove all ipc socket files - # the environment variable starts with 'ipc://', so file name starts from 6 - try: - os.remove(os.environ['put_idx'][6:]) - except FileNotFoundError: - pass - try: - os.remove(os.environ['collect_data'][6:]) - except FileNotFoundError: - pass - try: - os.remove(os.environ['batch_prefetch'][6:]) - except FileNotFoundError: - pass - - -class TFDataloader(DatasetWrapper): - def __init__(self, - ds, - output_types, - augmentations=None, - shuffle=False, - shuffle_buffer_size=None, - batch_size=1, - drop_remainder=True, - # num_extract_worker=os.cpu_count(), - # num_map_worker=os.cpu_count(), - # num_prefetch=None, - transforms=None): - - super(TFDataloader, self).__init__(ds) - self.augmentations = augmentations - self.shuffle = shuffle - self.batch_size = batch_size - self.shuffle_buffer_size = 2 * batch_size if shuffle_buffer_size is None else shuffle_buffer_size - self.drop_remainder = drop_remainder - # self.num_map_worker = num_map_worker - # self.num_extract_worker = num_extract_worker - # self.num_prefetch = num_extract_worker if num_prefetch is None else num_prefetch - self.transforms = transforms - - self.ds = tf.data.Dataset.from_generator(self.ds, output_types=output_types) - - # if self.augmentations is not None: - # self.ds = AugmentedDataset(self.ds, self.augmentations) - - # if self.num_extract_worker > 1: - # self.ds = MultiProcessDataset(self.ds, num_worker=self.num_extract_worker, num_prefetch=self.num_prefetch) - - if self.shuffle: - self.ds = self.ds.shuffle(buffer_size=self.shuffle_buffer_size) - - if self.transforms is not None: - self.ds = self.ds.map(map_func=_Transforms_for_tf_dataset(self.transforms), - num_parallel_calls=tf.data.experimental.AUTOTUNE) - - if self.batch_size > 1: - self.ds = self.ds.batch(batch_size=self.batch_size, drop_remainder=self.drop_remainder) - - # if self.num_prefetch > 1: - self.ds = self.ds.prefetch(tf.data.experimental.AUTOTUNE) - - def __iter__(self): - for dp in self.ds: - yield dp diff --git 
a/tensorlayer/dataflow/dataset/__init__.py b/tensorlayer/dataflow/dataset/__init__.py deleted file mode 100644 index d95f8afa6..000000000 --- a/tensorlayer/dataflow/dataset/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .mnist import MNIST -from .cifar10 import CIFAR10 -from .ilsvrc import ILSVRC12, ILSVRC12Files, ILSVRCMeta - -__all__ = ['MNIST', 'CIFAR10', 'ILSVRCMeta', 'ILSVRC12Files', 'ILSVRC12'] \ No newline at end of file diff --git a/tensorlayer/dataflow/dataset/cifar10.py b/tensorlayer/dataflow/dataset/cifar10.py deleted file mode 100644 index f6ca6209f..000000000 --- a/tensorlayer/dataflow/dataset/cifar10.py +++ /dev/null @@ -1,58 +0,0 @@ -import logging -import os -import pickle -import sys -import numpy as np - -from ..base import Dataset -from ..utils import maybe_download_and_extract - -CIFAR10_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' - - -class CIFAR10(Dataset): - def __init__(self, train_or_test, path='data', name='cifar10'): - self.path = os.path.join(path, name) - - # Helper function to unpickle the data - def unpickle(file): - fp = open(file, 'rb') - if sys.version_info.major == 2: - data = pickle.load(fp) - elif sys.version_info.major == 3: - data = pickle.load(fp, encoding='latin-1') - else: - raise RuntimeError("Sys Version Unsupported") - fp.close() - return data - - # Download and read the training and test set images and labels. - logging.info("Load or Download {0} > {1}".format(name.upper(), self.path)) - - filename = 'cifar-10-python.tar.gz' - maybe_download_and_extract(filename, path, CIFAR10_URL, extract=True) - - assert train_or_test in ['train', 'test'] - if train_or_test == 'train': - # Unpickle file and fill in data - self.images = None - self.labels = [] - for i in range(1, 6): - data_dic = unpickle(os.path.join(path, 'cifar-10-batches-py/', "data_batch_{}".format(i))) - if i == 1: - self.images = data_dic['data'] - else: - self.images = np.vstack((self.images, data_dic['data'])) - self.labels += data_dic['labels'] - else: - test_data_dic = unpickle(os.path.join(path, 'cifar-10-batches-py/', "test_batch")) - self.images = test_data_dic['data'] - self.labels = np.array(test_data_dic['labels']) - - self.images = self.images.reshape((-1, 32, 32, 3)) - - def __len__(self): - return self.images.shape[0] - - def __getitem__(self, index): - return self.images[index], self.labels[index] diff --git a/tensorlayer/dataflow/dataset/ilsvrc.py b/tensorlayer/dataflow/dataset/ilsvrc.py deleted file mode 100644 index e75b973f7..000000000 --- a/tensorlayer/dataflow/dataset/ilsvrc.py +++ /dev/null @@ -1,309 +0,0 @@ -import os -import logging -import cv2 - -from ..base import Dataset -from ..utils import maybe_download_and_extract - -__all__ = ['ILSVRCMeta', 'ILSVRC12', 'ILSVRC12Files'] - -CAFFE_ILSVRC12_URL = "http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz" - - -class ILSVRCMeta(object): - """ - Provide methods to access metadata for ILSVRC dataset. 
- Metadata is supposed to be found at/will be downloaded to 'path/name/' - - Parameters - ---------- - path : str - a folder path - name : str - name of the dataset - - Examples - -------- - >>> meta = ILSVRCMeta(path='data', name='ilsvrc') - >>> imglist = meta.get_image_list(train_or_val_or_test, dir_structure) - - """ - - def __init__(self, path='data', name='ilsvrc'): - path = os.path.expanduser(path) - self.path = os.path.join(path, name) - logging.info("Load or Download {0} > {1}".format(name.upper(), self.path)) - self.filepath = maybe_download_and_extract('ilsvrc_meta', self.path, CAFFE_ILSVRC12_URL, extract=True) - self.caffepb = None - - def get_synset_words_1000(self): - """ - Returns: - dict: {cls_number: cls_name} - """ - fname = os.path.join(self.path, 'synset_words.txt') - assert os.path.isfile(fname), fname - lines = [x.strip() for x in open(fname).readlines()] - return dict(enumerate(lines)) - - def get_synset_1000(self): - """ - Returns: - dict: {cls_number: synset_id} - """ - fname = os.path.join(self.path, 'synsets.txt') - assert os.path.isfile(fname) - lines = [x.strip() for x in open(fname).readlines()] - return dict(enumerate(lines)) - - def get_image_list(self, name, dir_structure='original'): - """ - Args: - name (str): 'train' or 'val' or 'test' - dir_structure (str): same as in :meth:`ILSVRC12.__init__()`. - Returns: - list: list of (image filename, label) - """ - assert name in ['train', 'val', 'test'] - assert dir_structure in ['original', 'train'] - add_label_to_fname = (name != 'train' and dir_structure != 'original') - if add_label_to_fname: - synset = self.get_synset_1000() - - fname = os.path.join(self.path, name + '.txt') - assert os.path.isfile(fname), fname - with open(fname) as f: - ret = [] - for line in f.readlines(): - name, cls = line.strip().split() - cls = int(cls) - - if add_label_to_fname: - name = os.path.join(synset[cls], name) - - ret.append((name.strip(), cls)) - assert len(ret), fname - return ret - - # def get_per_pixel_mean(self, size=None): - # """ - # Args: - # size (tuple): image size in (h, w). Defaults to (256, 256). - # Returns: - # np.ndarray: per-pixel mean of shape (h, w, 3 (BGR)) in range [0, 255]. - # """ - # if self.caffepb is None: - # self.caffepb = get_caffe_pb() - # obj = self.caffepb.BlobProto() - # - # mean_file = os.path.join(self.dir, 'imagenet_mean.binaryproto') - # with open(mean_file, 'rb') as f: - # obj.ParseFromString(f.read()) - # arr = np.array(obj.data).reshape((3, 256, 256)).astype('float32') - # arr = np.transpose(arr, [1, 2, 0]) - # if size is not None: - # arr = cv2.resize(arr, size[::-1]) - # return arr - - @staticmethod - def guess_dir_structure(dir): - """ - Return the directory structure of "dir". - - Args: - dir(str): something like '/path/to/imagenet/val' - - Returns: - either 'train' or 'original' - """ - subdir = os.listdir(dir)[0] - # find a subdir starting with 'n' - if subdir.startswith('n') and \ - os.path.isdir(os.path.join(dir, subdir)): - dir_structure = 'train' - else: - dir_structure = 'original' - logging.info( - "[ILSVRC12] Assuming directory {} has '{}' structure.".format( - dir, dir_structure)) - return dir_structure - - -class ILSVRC12Files(Dataset): - """ - Same as :class:`ILSVRC12`, but produces filenames of the images instead of nparrays. - This could be useful when ``cv2.imread`` is a bottleneck and you want to - decode it in smarter ways (e.g. in parallel). - """ - def __init__(self, path, train_or_val_or_test, meta_dir, - dir_structure=None): - """ - Same as in :class:`ILSVRC12`. 
- """ - assert train_or_val_or_test in ['train', 'test', 'val'] - path = os.path.expanduser(path) - assert os.path.isdir(path) - self.full_path = os.path.join(path, train_or_val_or_test) - self.path = train_or_val_or_test - # assert os.path.isdir(self.full_path) - # assert os.path.isdir(meta_dir) - - if train_or_val_or_test == 'train': - dir_structure = 'train' - elif dir_structure is None: - dir_structure = ILSVRCMeta.guess_dir_structure(self.full_path) - - meta = ILSVRCMeta(meta_dir) - self.imglist = meta.get_image_list(train_or_val_or_test, dir_structure) - - # for fname, _ in self.imglist[:10]: - # fname = os.path.join(self.full_path, fname) - # assert os.path.isfile(fname), fname - - def __len__(self): - return len(self.imglist) - - def __getitem__(self, index): - fname, label = self.imglist[index] - fname = os.path.join(self.full_path, fname) - return fname, label - - # def __iter__(self): - # idxs = np.arange(len(self.imglist)) - # if self.shuffle: - # self.rng.shuffle(idxs) - # for k in idxs: - # fname, label = self.imglist[k] - # fname = os.path.join(self.full_dir, fname) - # yield [fname, label] - - -class ILSVRC12(ILSVRC12Files): - """ - Produces uint8 ILSVRC12 images of shape [h, w, 3(BGR)], and a label between [0, 999]. - """ - def __init__(self, path, train_or_test, meta_dir, - dir_structure=None, shape=None): - """ - Args: - dir (str): A directory containing a subdir named ``name``, - containing the images in a structure described below. - name (str): One of 'train' or 'val' or 'test'. - shuffle (bool): shuffle the dataset. - Defaults to True if name=='train'. - dir_structure (str): One of 'original' or 'train'. - The directory structure for the 'val' directory. - 'original' means the original decompressed directory, which only has list of image files (as below). - If set to 'train', it expects the same two-level directory structure similar to 'dir/train/'. - By default, it tries to automatically detect the structure. - You probably do not need to care about this option because 'original' is what people usually have. - - Example: - - When `dir_structure=='original'`, `dir` should have the following structure: - - .. code-block:: none - - dir/ - train/ - n02134418/ - n02134418_198.JPEG - ... - ... - val/ - ILSVRC2012_val_00000001.JPEG - ... - test/ - ILSVRC2012_test_00000001.JPEG - ... - - With the downloaded ILSVRC12_img_*.tar, you can use the following - command to build the above structure: - - .. code-block:: none - - mkdir val && tar xvf ILSVRC12_img_val.tar -C val - mkdir test && tar xvf ILSVRC12_img_test.tar -C test - mkdir train && tar xvf ILSVRC12_img_train.tar -C train && cd train - find -type f -name '*.tar' | parallel -P 10 'echo {} && mkdir -p {/.} && tar xf {} -C {/.}' - - When `dir_structure=='train'`, `dir` should have the following structure: - - .. code-block:: none - - dir/ - train/ - n02134418/ - n02134418_198.JPEG - ... - ... - val/ - n01440764/ - ILSVRC2012_val_00000293.JPEG - ... - ... - test/ - ILSVRC2012_test_00000001.JPEG - ... - """ - super(ILSVRC12, self).__init__( - path, train_or_test, meta_dir, dir_structure) - self.shape = shape - - """ - There are some CMYK / png images, but cv2 seems robust to them. 
- https://github.com/tensorflow/models/blob/c0cd713f59cfe44fa049b3120c417cc4079c17e3/research/inception/inception/data/build_imagenet_data.py#L264-L300 - """ - # def __iter__(self): - # for fname, label in super(ILSVRC12, self).__iter__(): - # im = cv2.imread(fname, cv2.IMREAD_COLOR) - # assert im is not None, fname - # yield [im, label] - - def __getitem__(self, index): - fname, label = super(ILSVRC12, self).__getitem__(index) - img = cv2.imread(fname, cv2.IMREAD_COLOR) - if self.shape is not None: - img = cv2.resize(img, self.shape) - return img, label - - # @staticmethod - # def get_training_bbox(bbox_dir, imglist): - # import xml.etree.ElementTree as ET - # ret = [] - # - # def parse_bbox(fname): - # root = ET.parse(fname).getroot() - # size = root.find('size').getchildren() - # size = map(int, [size[0].text, size[1].text]) - # - # box = root.find('object').find('bndbox').getchildren() - # box = map(lambda x: float(x.text), box) - # return np.asarray(box, dtype='float32') - # - # with timed_operation('Loading Bounding Boxes ...'): - # cnt = 0 - # for k in tqdm.trange(len(imglist)): - # fname = imglist[k][0] - # fname = fname[:-4] + 'xml' - # fname = os.path.join(bbox_dir, fname) - # try: - # ret.append(parse_bbox(fname)) - # cnt += 1 - # except Exception: - # ret.append(None) - # logger.info("{}/{} images have bounding box.".format(cnt, len(imglist))) - # return ret - - -# if __name__ == '__main__': -# meta = ILSVRCMeta() -# # print(meta.get_synset_words_1000()) -# -# ds = ILSVRC12('/home/wyx/data/fake_ilsvrc/', 'train', shuffle=False) -# ds.reset_state() -# -# for k in ds: -# from IPython import embed -# embed() -# break diff --git a/tensorlayer/dataflow/dataset/mnist.py b/tensorlayer/dataflow/dataset/mnist.py deleted file mode 100644 index 15db8c951..000000000 --- a/tensorlayer/dataflow/dataset/mnist.py +++ /dev/null @@ -1,62 +0,0 @@ -import gzip -import logging -import os -import numpy as np - -from ..base import Dataset -from ..utils import maybe_download_and_extract - -MNIST_TRAIN_IMAGE_URL = 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz' -MNIST_TRAIN_LABEL_URL = 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz' -MNIST_TEST_IMAGE_URL = 'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz' -MNIST_TEST_LABEL_URL = 'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz' - - -class MNIST(Dataset): - def __init__(self, train_or_test, path='data', name='mnist'): - path = os.path.expanduser(path) - self.path = os.path.join(path, name) - - assert train_or_test in ['train', 'test'] - if train_or_test == 'train': - self.images = self.load_mnist_images(train_or_test=train_or_test) - self.labels = self.load_mnist_labels(train_or_test=train_or_test) - else: - self.images = self.load_mnist_images(train_or_test=train_or_test) - self.labels = self.load_mnist_labels(train_or_test=train_or_test) - - def load_mnist_images(self, train_or_test): - if train_or_test == 'train': - filepath = maybe_download_and_extract('train-images-idx3-ubyte.gz', self.path, MNIST_TRAIN_IMAGE_URL) - else: - filepath = maybe_download_and_extract('t10k-images-idx3-ubyte.gz', self.path, MNIST_TEST_IMAGE_URL) - - logging.info(filepath) - # Read the inputs in Yann LeCun's binary format. 
- with gzip.open(filepath, 'rb') as f: - data = np.frombuffer(f.read(), np.uint8, offset=16) - # The inputs are vectors now, we reshape them to monochrome 2D images, - # following the shape convention: (examples, channels, rows, columns) - data = data.reshape((-1, 28, 28, 1)) - # The inputs come as bytes, we convert them to float32 in range [0,1]. - # (Actually to range [0, 255/256], for compatibility to the version - # provided at http://deeplearning.net/data/mnist/mnist.pkl.gz.) - return data / np.float32(256) - - def load_mnist_labels(self, train_or_test): - if train_or_test == 'train': - filepath = maybe_download_and_extract('train-labels-idx1-ubyte.gz', self.path, MNIST_TRAIN_LABEL_URL) - else: - filepath = maybe_download_and_extract('t10k-labels-idx1-ubyte.gz', self.path, MNIST_TEST_LABEL_URL) - - # Read the labels in Yann LeCun's binary format. - with gzip.open(filepath, 'rb') as f: - data = np.frombuffer(f.read(), np.uint8, offset=8) - # The labels are vectors of integers now, that's exactly what we want. - return data - - def __len__(self): - return self.images.shape[0] - - def __getitem__(self, index): - return self.images[index], self.labels[index] diff --git a/tensorlayer/dataflow/parallel.py b/tensorlayer/dataflow/parallel.py deleted file mode 100644 index 6eb250e21..000000000 --- a/tensorlayer/dataflow/parallel.py +++ /dev/null @@ -1,198 +0,0 @@ -import multiprocessing -import os -import sys -import uuid - -import zmq -import numpy as np - -from .base import DatasetWrapper -from .serialize import * - - -class MultiprocessDataset(DatasetWrapper): - def __init__(self, - ds, - num_worker, - num_prefetch, - shuffle=False): - - super(MultiprocessDataset, self).__init__(ds) - self.num_worker = num_worker - self.num_prefetch = num_prefetch - self.shuffle = shuffle - - self.index_queue = multiprocessing.Queue(self.num_worker) - self.data_queue = multiprocessing.Queue(self.num_prefetch) - self.put_idx_worker = None - for _ in range(num_worker): - worker = multiprocessing.Process(target=self._worker, - args=(self.ds, self.index_queue, self.data_queue)) - worker.daemon = True - worker.start() - - def _worker(self, ds, index_q, data_q): - while True: - idx = index_q.get() - data_q.put((idx, ds[idx])) - - def _put_idxs(self, idxs, index_q): - for idx in idxs: - index_q.put(idx) - - def __iter__(self): - # shutdown put_idx_worker and clear queues from previous epoch - _shutdown_proc(self.put_idx_worker) - while not self.index_queue.empty(): - self.index_queue.get() - while not self.data_queue.empty(): - self.data_queue.get() - - # shuffle at the start of every epoch - if self.shuffle: - self.idxs = np.random.permutation(self.ds_len) - else: - self.idxs = np.arange(self.ds_len) - - self.put_idx_worker = multiprocessing.Process(target=self._put_idxs, - args=(self.idxs, self.index_queue)) - self.put_idx_worker.daemon = True - self.put_idx_worker.start() - - data_buffer = {} - for return_idx in self.idxs: - if return_idx in data_buffer: - yield data_buffer.pop(return_idx) - else: - while True: - idx, dp = self.data_queue.get() - if idx == return_idx: - yield dp - break - else: - data_buffer[idx] = dp - _shutdown_proc(self.put_idx_worker) - - -def _shutdown_proc(proc): - if proc is None: - return - if proc.is_alive(): - proc.terminate() - proc.join() - - -class ZMQMultiprocessDataset(DatasetWrapper): - def __init__(self, - ds, - num_worker, - hwm=50, - shuffle=False): - - super(ZMQMultiprocessDataset, self).__init__(ds) - self.num_worker = num_worker - self.shuffle = shuffle - self._hwm = hwm - 
- self.idx_pipename = _get_pipe_name('put_idx') - self.data_pipename = _get_pipe_name('collect_data') - - self.put_idx_worker = None - for i in range(num_worker): - # first worker bind the socket, others connect to the socket - # however, zmq sockets using ipc do not care about the order of bind / connect - if i == 0: - worker = multiprocessing.Process(target=self._worker, - args=(True,)) - else: - worker = multiprocessing.Process(target=self._worker, - args=()) - worker.daemon = True - worker.start() - - def _worker(self, bind=False): - context = zmq.Context() - worker_receive_index_socket = context.socket(zmq.PULL) - worker_receive_index_socket.set_hwm(self._hwm) - if bind: - worker_receive_index_socket.bind(self.idx_pipename) - else: - worker_receive_index_socket.connect(self.idx_pipename) - - worker_send_data_socket = context.socket(zmq.PUSH) - worker_send_data_socket.set_hwm(self._hwm) - if bind: - worker_send_data_socket.bind(self.data_pipename) - else: - worker_send_data_socket.connect(self.data_pipename) - - while True: - recv_msg = worker_receive_index_socket.recv(copy=False) - idx = load_from_bytes(recv_msg) - send_msg = convert_to_bytes({'idx': idx, 'data': self.ds[idx]}) - worker_send_data_socket.send(send_msg, copy=False) - - def _put_idxs(self): - context = zmq.Context() - put_idx_socket = context.socket(zmq.PUSH) - put_idx_socket.set_hwm(self._hwm) - put_idx_socket.connect(self.idx_pipename) - for idx in self.idxs: - send_msg = convert_to_bytes(idx) - put_idx_socket.send(send_msg, copy=False) - - def __iter__(self): - context = zmq.Context() - collect_data_socket = context.socket(zmq.PULL) - collect_data_socket.set_hwm(self._hwm) - collect_data_socket.connect(self.data_pipename) - - # shutdown put_idx_worker and clear queues from previous epoch - _shutdown_proc(self.put_idx_worker) - try: - while True: - collect_data_socket.recv(flags=zmq.NOBLOCK) - except zmq.ZMQError: - pass - - # shuffle at the start of every epoch - if self.shuffle: - self.idxs = np.random.permutation(self.ds_len) - else: - self.idxs = np.arange(self.ds_len) - - self.put_idx_worker = multiprocessing.Process(target=self._put_idxs, - args=()) - self.put_idx_worker.daemon = True - self.put_idx_worker.start() - - data_buffer = {} - for return_idx in self.idxs: - if return_idx in data_buffer: - yield data_buffer.pop(return_idx) - else: - while True: - recv_msg = collect_data_socket.recv(copy=False) - recv_msg = load_from_bytes(recv_msg) - idx, dp = recv_msg['idx'], recv_msg['data'] - if idx == return_idx: - yield dp - break - else: - data_buffer[idx] = dp - _shutdown_proc(self.put_idx_worker) - - -def _get_pipe_name(name): - if sys.platform.startswith('linux'): - # linux supports abstract sockets: http://api.zeromq.org/4-1:zmq-ipc - pipename = "ipc://@{}-pipe-{}".format(name, str(uuid.uuid1())[:8]) - else: - pipedir = '.' - assert os.path.isdir(pipedir), pipedir - filename = '{}/{}-pipe-{}'.format(pipedir.rstrip('/'), name, str(uuid.uuid1())[:6]) - assert not os.path.exists(filename), "Pipe {} exists! 
You may be unlucky.".format(filename) - pipename = "ipc://{}".format(filename) - # register in environment variable, used for cleaning up ipc socket files - os.environ[name] = pipename - return pipename diff --git a/tensorlayer/dataflow/serialize.py b/tensorlayer/dataflow/serialize.py deleted file mode 100644 index aa272f4f4..000000000 --- a/tensorlayer/dataflow/serialize.py +++ /dev/null @@ -1,27 +0,0 @@ -import msgpack_numpy - -MAX_MSGPACK_LEN = 1000000000 - - -def convert_to_bytes(obj): - """ - Serialize an object. - - Returns: - Implementation-dependent bytes-like object. - """ - return msgpack_numpy.dumps(obj, use_bin_type=True) - - -def load_from_bytes(buf): - """ - Args: - buf: the output of `dumps`. - """ - # Since 0.6, the default max size was set to 1MB. - # We change it to approximately 1G. - return msgpack_numpy.loads(buf, raw=False, - max_bin_len=MAX_MSGPACK_LEN, - max_array_len=MAX_MSGPACK_LEN, - max_map_len=MAX_MSGPACK_LEN, - max_str_len=MAX_MSGPACK_LEN) diff --git a/tensorlayer/dataflow/utils.py b/tensorlayer/dataflow/utils.py deleted file mode 100644 index a90016aeb..000000000 --- a/tensorlayer/dataflow/utils.py +++ /dev/null @@ -1,199 +0,0 @@ -import atexit -import logging -import math -import multiprocessing -import os -import weakref - -import psutil -import tarfile -import time -import zipfile -import progressbar -from urllib.request import urlretrieve - - -def exists_or_mkdir(path, verbose=True): - """ - Check a folder by given name, if not exist, create the folder and return False, - if directory exists, return True. - - Parameters - ---------- - path : str - A folder path. - verbose : boolean - If True (default), prints results. - - Returns - -------- - boolean - True if folder already exist, otherwise, returns False and create the folder. - - Examples - -------- - >>> exists_or_mkdir("checkpoints/train") - - """ - if not os.path.exists(path): - if verbose: - logging.info("[*] creates %s ..." % path) - os.makedirs(path) - return False - else: - if verbose: - logging.info("[!] %s exists ..." % path) - return True - - -def download(filename, working_directory, url_source): - """ - Download file from url_source to the working_directory with given filename. - - Parameters - ---------- - filename : str - The name of the downloaded file. - working_directory : str - A folder path download the file to - url_source : str - The URL to download the file from - - Examples - -------- - >>> download(filename='train.gz', - ... working_directory='data/', - ... url_source='http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz') - - """ - working_directory = os.path.expanduser(working_directory) - - progress_bar = progressbar.ProgressBar() - - def _dlProgress(count, blockSize, totalSize, pbar=progress_bar): - if (totalSize != 0): - - if not pbar.max_value: - totalBlocks = math.ceil(float(totalSize) / float(blockSize)) - pbar.max_value = int(totalBlocks) - - pbar.update(count, force=True) - - filepath = os.path.join(working_directory, filename) - - logging.info('Downloading %s...\n' % filename) - - urlretrieve(url_source, filepath, reporthook=_dlProgress) - - -def maybe_download_and_extract(filename, working_directory, url_source, extract=False, expected_bytes=None): - """ - Checks if file exists in working_directory otherwise tries to dowload the file, - and optionally also tries to extract the file if format is ".zip" or ".tar" - - Parameters - ----------- - filename : str - The name of the (to be) dowloaded file. 
- working_directory : str - A folder path to search for the file in and dowload the file to - url_source : str - The URL to download the file from - extract : boolean - If True, tries to uncompress the dowloaded file is ".tar.gz/.tar.bz2" or ".zip" file, default is False. - expected_bytes : int or None - If set tries to verify that the downloaded file is of the specified size, otherwise raises an Exception, defaults is None which corresponds to no check being performed. - - Returns - ---------- - str - File path of the dowloaded (uncompressed) file. - - Examples - -------- - >>> down_file = maybe_download_and_extract(filename='train.gz', - ... working_directory='data/', - ... url_source='http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz') - >>> maybe_download_and_extract(filename='ADEChallengeData2016.zip', - ... working_directory='data/', - ... url_source='http://sceneparsing.csail.mit.edu/data/ADEChallengeData2016.zip', - ... extract=True) - - """ - working_directory = os.path.expanduser(working_directory) - exists_or_mkdir(working_directory, verbose=False) - filepath = os.path.join(working_directory, filename) - - if not os.path.exists(filepath): - download(filename, working_directory, url_source) - statinfo = os.stat(filepath) - logging.info('Succesfully downloaded %s %s bytes.' % (filename, statinfo.st_size)) # , 'bytes.') - if not (expected_bytes is None) and (expected_bytes != statinfo.st_size): - raise Exception('Failed to verify ' + filename + '. Can you get to it with a browser?') - if extract: - if tarfile.is_tarfile(filepath): - logging.info('Trying to extract tar file') - tarfile.open(filepath, 'r').extractall(working_directory) - logging.info('... Success!') - elif zipfile.is_zipfile(filepath): - logging.info('Trying to extract zip file') - with zipfile.ZipFile(filepath) as zf: - zf.extractall(working_directory) - logging.info('... Success!') - else: - logging.info("Unknown compression_format only .tar.gz/.tar.bz2/.tar and .zip supported") - return filepath - - -def get_dataloader_speed(dl, num_steps): - cnt = 0 - start = time.time() - end = start - for _ in dl: - cnt += 1 - if cnt == num_steps: - end = time.time() - break - return (end - start) / num_steps - - -def format_bytes(bytes): - if abs(bytes) < 1000: - return str(bytes) + "B" - elif abs(bytes) < 1e6: - return str(round(bytes / 1e3, 2)) + "kB" - elif abs(bytes) < 1e9: - return str(round(bytes / 1e6, 2)) + "MB" - else: - return str(round(bytes / 1e9, 2)) + "GB" - - -def get_process_memory(): - process = psutil.Process(os.getpid()) - mi = process.memory_info() - return mi.rss, mi.vms, mi.vms - - -def ensure_proc_terminate(proc): - """ - Make sure processes terminate when main process exit. 
- - Args: - proc (multiprocessing.Process or list) - """ - if isinstance(proc, list): - for p in proc: - ensure_proc_terminate(p) - return - - def stop_proc_by_weak_ref(ref): - proc = ref() - if proc is None: - return - if not proc.is_alive(): - return - proc.terminate() - proc.join() - - assert isinstance(proc, multiprocessing.Process) - atexit.register(stop_proc_by_weak_ref, weakref.ref(proc)) From 13a43c16d1f3d4ffdcc96f841755c3c1a54394ce Mon Sep 17 00:00:00 2001 From: 1FengL Date: Sat, 7 Sep 2019 12:53:02 +0100 Subject: [PATCH 4/6] add cifar10 dataflow example --- .../tutorial_cifar10_cnn_static_dataloader.py | 176 ++++++++++++++++++ ..._transformer_network_dynamic_tlDataflow.py | 167 ----------------- 2 files changed, 176 insertions(+), 167 deletions(-) create mode 100644 examples/basic_tutorials/tutorial_cifar10_cnn_static_dataloader.py delete mode 100644 examples/spatial_transformer_network/tutorial_spatial_transformer_network_dynamic_tlDataflow.py diff --git a/examples/basic_tutorials/tutorial_cifar10_cnn_static_dataloader.py b/examples/basic_tutorials/tutorial_cifar10_cnn_static_dataloader.py new file mode 100644 index 000000000..226a4365d --- /dev/null +++ b/examples/basic_tutorials/tutorial_cifar10_cnn_static_dataloader.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import multiprocessing +import time + +import numpy as np + +import tensorflow as tf +import tensorlayer as tl +from tensorlayer.layers import (BatchNorm, Conv2d, Dense, Flatten, Input, LocalResponseNorm, MaxPool2d) +from tensorlayer.models import Model + +# enable debug logging +tl.logging.set_verbosity(tl.logging.DEBUG) +tl.logging.set_verbosity(tl.logging.DEBUG) + + +# define the network +def get_model(inputs_shape): + # self defined initialization + W_init = tl.initializers.truncated_normal(stddev=5e-2) + W_init2 = tl.initializers.truncated_normal(stddev=0.04) + b_init2 = tl.initializers.constant(value=0.1) + + # build network + ni = Input(inputs_shape) + nn = Conv2d(64, (5, 5), (1, 1), padding='SAME', act=tf.nn.relu, W_init=W_init, b_init=None, name='conv1')(ni) + nn = MaxPool2d((3, 3), (2, 2), padding='SAME', name='pool1')(nn) + nn = LocalResponseNorm(depth_radius=4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name="norm1")(nn) + + nn = Conv2d(64, (5, 5), (1, 1), padding='SAME', act=tf.nn.relu, W_init=W_init, b_init=None, name='conv2')(nn) + nn = LocalResponseNorm(depth_radius=4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name="norm2")(nn) + nn = MaxPool2d((3, 3), (2, 2), padding='SAME', name='pool2')(nn) + + nn = Flatten(name='flatten')(nn) + nn = Dense(384, act=tf.nn.relu, W_init=W_init2, b_init=b_init2, name='dense1relu')(nn) + nn = Dense(192, act=tf.nn.relu, W_init=W_init2, b_init=b_init2, name='dense2relu')(nn) + nn = Dense(10, act=None, W_init=W_init2, name='output')(nn) + + M = Model(inputs=ni, outputs=nn, name='cnn') + return M + + +def get_model_batchnorm(inputs_shape): + # self defined initialization + W_init = tl.initializers.truncated_normal(stddev=5e-2) + W_init2 = tl.initializers.truncated_normal(stddev=0.04) + b_init2 = tl.initializers.constant(value=0.1) + + # build network + ni = Input(inputs_shape) + nn = Conv2d(64, (5, 5), (1, 1), padding='SAME', W_init=W_init, b_init=None, name='conv1')(ni) + nn = BatchNorm(decay=0.99, act=tf.nn.relu, name='batch1')(nn) + nn = MaxPool2d((3, 3), (2, 2), padding='SAME', name='pool1')(nn) + + nn = Conv2d(64, (5, 5), (1, 1), padding='SAME', W_init=W_init, b_init=None, name='conv2')(nn) + nn = BatchNorm(decay=0.99, act=tf.nn.relu, 
name='batch2')(nn) + nn = MaxPool2d((3, 3), (2, 2), padding='SAME', name='pool2')(nn) + + nn = Flatten(name='flatten')(nn) + nn = Dense(384, act=tf.nn.relu, W_init=W_init2, b_init=b_init2, name='dense1relu')(nn) + nn = Dense(192, act=tf.nn.relu, W_init=W_init2, b_init=b_init2, name='dense2relu')(nn) + nn = Dense(10, act=None, W_init=W_init2, name='output')(nn) + + M = Model(inputs=ni, outputs=nn, name='cnn') + return M + + +# get the network +net = get_model([None, 24, 24, 3]) + +# training settings +batch_size = 128 +n_epoch = 50000 +learning_rate = 0.0001 +print_freq = 5 +# init_learning_rate = 0.1 +# learning_rate_decay_factor = 0.1 +# num_epoch_decay = 350 + +train_weights = net.trainable_weights +# learning_rate = tf.Variable(init_learning_rate) +optimizer = tf.optimizers.Adam(learning_rate) + + +def _fn_train(img, target): + # 1. Randomly crop a [height, width] section of the image. + img = tl.prepro.crop(img, 24, 24, False) + # # 2. Randomly flip the image horizontally. + img = tl.prepro.flip_axis(img, is_random=True) + # # 3. Randomly change brightness. + # # 4. Randomly change contrast. + # img = tl.prepro.brightness(img, is_random=True) + # # 5. Subtract off the mean and divide by the variance of the pixels. + img = tl.prepro.samplewise_norm(img) + # img = tl.prepro.featurewise_norm(img) + target = np.reshape(target, ()) + return img, target + + +def _fn_test(img, target): + # 1. Crop the central [height, width] of the image. + img = tl.prepro.crop(img, 24, 24) + # 2. Subtract off the mean and divide by the variance of the pixels. + img = tl.prepro.samplewise_norm(img) + img = np.reshape(img, (24, 24, 3)) + target = np.reshape(target, ()) + return img, target + + +# dataset API and augmentation +train_ds = tl.data.CIFAR10(train_or_test='train', shape=(-1, 32, 32, 3)) +train_dl = tl.data.Dataloader(train_ds, transforms=[_fn_train], shuffle=True, + batch_size=batch_size, output_types=(np.float32, np.int32)) +test_ds = tl.data.CIFAR10(train_or_test='test', shape=(-1, 32, 32, 3)) +test_dl = tl.data.Dataloader(test_ds, transforms=[_fn_test], batch_size=batch_size) + +for epoch in range(n_epoch): + start_time = time.time() + + train_loss, train_acc, n_iter = 0, 0, 0 + for X_batch, y_batch in train_dl: + net.train() + + with tf.GradientTape() as tape: + # compute outputs + _logits = net(X_batch) + # compute loss and update model + _loss_ce = tl.cost.cross_entropy(_logits, y_batch, name='train_loss') + _loss_L2 = 0 + # for p in tl.layers.get_variables_with_name('relu/W', True, True): + # _loss_L2 += tl.cost.lo_regularizer(1.0)(p) + _loss = _loss_ce + _loss_L2 + # print(_loss) + + grad = tape.gradient(_loss, train_weights) + optimizer.apply_gradients(zip(grad, train_weights)) + + train_loss += _loss + train_acc += np.mean(np.equal(np.argmax(_logits, 1), y_batch)) + n_iter += 1 + + # use training and evaluation sets to evaluate the model every print_freq epoch + if epoch + 1 == 1 or (epoch + 1) % print_freq == 0: + + print("Epoch {} of {} took {}".format(epoch + 1, n_epoch, time.time() - start_time)) + + print(" train loss: {}".format(train_loss / n_iter)) + print(" train acc: {}".format(train_acc / n_iter)) + + net.eval() + + val_loss, val_acc, n_iter = 0, 0, 0 + for X_batch, y_batch in test_dl: + _logits = net(X_batch) # is_train=False, disable dropout + val_loss += tl.cost.cross_entropy(_logits, y_batch, name='eval_loss') + val_acc += np.mean(np.equal(np.argmax(_logits, 1), y_batch)) + n_iter += 1 + print(" val loss: {}".format(val_loss / n_iter)) + print(" val acc: {}".format(val_acc / 
n_iter)) + + # FIXME : how to apply lr decay in eager mode? + # learning_rate.assign(tf.train.exponential_decay(init_learning_rate, epoch, num_epoch_decay, + # learning_rate_decay_factor)) + +# use testing data to evaluate the model +net.eval() +test_loss, test_acc, n_iter = 0, 0, 0 +for X_batch, y_batch in test_dl: + _logits = net(X_batch) + test_loss += tl.cost.cross_entropy(_logits, y_batch, name='test_loss') + test_acc += np.mean(np.equal(np.argmax(_logits, 1), y_batch)) + n_iter += 1 +print(" test loss: {}".format(test_loss / n_iter)) +print(" test acc: {}".format(test_acc / n_iter)) diff --git a/examples/spatial_transformer_network/tutorial_spatial_transformer_network_dynamic_tlDataflow.py b/examples/spatial_transformer_network/tutorial_spatial_transformer_network_dynamic_tlDataflow.py deleted file mode 100644 index bc0bae141..000000000 --- a/examples/spatial_transformer_network/tutorial_spatial_transformer_network_dynamic_tlDataflow.py +++ /dev/null @@ -1,167 +0,0 @@ -#! /usr/bin/python -# -*- coding: utf8 -*- -import time - -import numpy as np - -import tensorflow as tf -import tensorlayer as tl -from tensorlayer.layers import * -from tensorlayer.models import Model - -##================== PREPARE DATA ============================================## -X_train, y_train, X_val, y_val, X_test, y_test = \ - tl.files.load_mnist_dataset(shape=(-1, 28, 28, 1)) - - -def pad_distort_im_fn(x): - """ Zero pads an image to 40x40, and distort it. - - Examples - --------- - x = pad_distort_im_fn(X_train[0]) - print(x, x.shape, x.max()) - tl.vis.save_image(x, '_xd.png') - tl.vis.save_image(X_train[0], '_x.png') - """ - b = np.zeros((40, 40, 1), dtype=np.float32) - o = int((40 - 28) / 2) - b[o:o + 28, o:o + 28] = x - x = b - x = tl.prepro.rotation(x, rg=30, is_random=True, fill_mode='constant') - x = tl.prepro.shear(x, 0.05, is_random=True, fill_mode='constant') - x = tl.prepro.shift(x, wrg=0.25, hrg=0.25, is_random=True, fill_mode='constant') - x = tl.prepro.zoom(x, zoom_range=(0.95, 1.05)) - return x - - -def pad_distort_ims_fn(X): - """ Zero pads images to 40x40, and distort them. """ - X_40 = [] - for X_a, _ in tl.iterate.minibatches(X, X, 50, shuffle=False): - X_40.extend(tl.prepro.threading_data(X_a, fn=pad_distort_im_fn)) - X_40 = np.asarray(X_40) - return X_40 - - -# create dataset with size of 40x40 with distortion -X_train_40 = pad_distort_ims_fn(X_train) -X_val_40 = pad_distort_ims_fn(X_val) -X_test_40 = pad_distort_ims_fn(X_test) - -tl.vis.save_images(X_test[0:32], [4, 8], '_imgs_original.png') -tl.vis.save_images(X_test_40[0:32], [4, 8], '_imgs_distorted.png') - - -##================== DEFINE MODEL ============================================## -class Net(Model): - - def __init__(self): - super(Net, self).__init__() - - ## 1. Localisation network - # use MLP as the localisation net - self.flatten1 = Flatten() - self.dense1 = Dense(n_units=20, in_channels=1600, act=tf.nn.tanh) - self.dropout1 = Dropout(keep=0.8) - # you can also use CNN instead for MLP as the localisation net - - ## 2. Spatial transformer module (sampler) - self.stn = SpatialTransformer2dAffine(out_size=(40, 40), in_channels=20) - - ## 3. 
Classifier - self.conv1 = Conv2d(16, (3, 3), (2, 2), act=tf.nn.relu, padding='SAME', in_channels=1) - self.conv2 = Conv2d(16, (3, 3), (2, 2), act=tf.nn.relu, padding='SAME', in_channels=16) - self.flatten2 = Flatten() - self.dense2 = Dense(n_units=1024, in_channels=1600, act=tf.nn.relu) - self.dense3 = Dense(n_units=10, in_channels=1024, act=tf.identity) - - def forward(self, inputs): - theta_input = self.dropout1(self.dense1(self.flatten1(inputs))) - V = self.stn((theta_input, inputs)) - _logits = self.dense3(self.dense2(self.flatten2(self.conv2(self.conv1(V))))) - return _logits, V - - -net = Net() - -##================== DEFINE TRAIN OPS ========================================## -n_epoch = 100 -learning_rate = 0.0001 -print_freq = 10 -batch_size = 64 -train_weights = net.trainable_weights -optimizer = tf.optimizers.Adam(lr=learning_rate) - -##================== TRAINING ================================================## -print("Training ...") -for epoch in range(n_epoch): - start_time = time.time() - - net.train() # enable dropout - - for X_train_a, y_train_a in tl.iterate.minibatches(X_train_40, y_train, batch_size, shuffle=True): - # input_dim must be of length 4 - X_train_a = tf.expand_dims(X_train_a, 3) - - with tf.GradientTape() as tape: - ## compute outputs - _logits, _ = net(X_train_a) # alternatively, you can use MLP(x, is_train=True) and remove MLP.train() - ## compute loss and update model - _loss = tl.cost.cross_entropy(_logits, y_train_a, name='train_loss') - - grad = tape.gradient(_loss, train_weights) - optimizer.apply_gradients(zip(grad, train_weights)) - - ## use training and evaluation sets to evaluate the model every print_freq epoch - if epoch + 1 == 1 or (epoch + 1) % print_freq == 0: - - net.eval() # disable dropout - - print("Epoch %d of %d took %fs" % (epoch + 1, n_epoch, time.time() - start_time)) - - train_loss, train_acc, n_iter = 0, 0, 0 - for X_train_a, y_train_a in tl.iterate.minibatches(X_train_40, y_train, batch_size, shuffle=False): - # input_dim must be of length 4 - X_train_a = tf.expand_dims(X_train_a, 3) - - _logits, _ = net(X_train_a) # alternatively, you can use MLP(x, is_train=False) and remove MLP.eval() - train_loss += tl.cost.cross_entropy(_logits, y_train_a, name='eval_train_loss') - train_acc += np.mean(np.equal(np.argmax(_logits, 1), y_train_a)) - n_iter += 1 - print(" train loss: %f" % (train_loss / n_iter)) - print(" train acc: %f" % (train_acc / n_iter)) - - val_loss, val_acc, n_iter = 0, 0, 0 - for X_val_a, y_val_a in tl.iterate.minibatches(X_val_40, y_val, batch_size, shuffle=False): - # input_dim must be of length 4 - X_val_a = tf.expand_dims(X_val_a, 3) - - _logits, _ = net(X_val_a) # is_train=False, disable dropout - val_loss += tl.cost.cross_entropy(_logits, y_val_a, name='eval_loss') - val_acc += np.mean(np.equal(np.argmax(_logits, 1), y_val_a)) - n_iter += 1 - print(" val loss: %f" % (val_loss / n_iter)) - print(" val acc: %f" % (val_acc / n_iter)) - - print('save images') - _, trans_imgs = net(tf.expand_dims(X_test_40[0:64], 3)) - trans_imgs = trans_imgs.numpy() - tl.vis.save_images(trans_imgs[0:32], [4, 8], '_imgs_distorted_after_stn_%s.png' % epoch) - -##================== EVALUATION ==============================================## -print('Evaluation') - -net.eval() - -test_loss, test_acc, n_iter = 0, 0, 0 -for X_test_a, y_test_a in tl.iterate.minibatches(X_test_40, y_test, batch_size, shuffle=False): - # input_dim must be of length 4 - X_test_a = tf.expand_dims(X_test_a, 3) - - _logits, _ = net(X_test_a) - test_loss += 
tl.cost.cross_entropy(_logits, y_test_a, name='test_loss') - test_acc += np.mean(np.equal(np.argmax(_logits, 1), y_test_a)) - n_iter += 1 -print(" test loss: %f" % (test_loss / n_iter)) -print(" test acc: %f" % (test_acc / n_iter)) From 1761e854ec1ffdeaefa9bfc4aa137d4fc4c3359f Mon Sep 17 00:00:00 2001 From: 1FengL Date: Sat, 7 Sep 2019 12:56:57 +0100 Subject: [PATCH 5/6] fix typo --- .../tutorial_cifar10_cnn_static_dataloader.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/examples/basic_tutorials/tutorial_cifar10_cnn_static_dataloader.py b/examples/basic_tutorials/tutorial_cifar10_cnn_static_dataloader.py index 226a4365d..7269aae31 100644 --- a/examples/basic_tutorials/tutorial_cifar10_cnn_static_dataloader.py +++ b/examples/basic_tutorials/tutorial_cifar10_cnn_static_dataloader.py @@ -87,14 +87,10 @@ def get_model_batchnorm(inputs_shape): def _fn_train(img, target): # 1. Randomly crop a [height, width] section of the image. img = tl.prepro.crop(img, 24, 24, False) - # # 2. Randomly flip the image horizontally. + # 2. Randomly flip the image horizontally. img = tl.prepro.flip_axis(img, is_random=True) - # # 3. Randomly change brightness. - # # 4. Randomly change contrast. - # img = tl.prepro.brightness(img, is_random=True) - # # 5. Subtract off the mean and divide by the variance of the pixels. + # 3. Subtract off the mean and divide by the variance of the pixels. img = tl.prepro.samplewise_norm(img) - # img = tl.prepro.featurewise_norm(img) target = np.reshape(target, ()) return img, target From 39d8936d268ce9fc1ecfc912720c4ab70dbb27dd Mon Sep 17 00:00:00 2001 From: 1FengL Date: Sun, 8 Sep 2019 00:44:40 +0100 Subject: [PATCH 6/6] set default path for tl.files.dataset_loaders from 'data' to 'raw_data' --- tensorlayer/files/dataset_loaders/celebA_dataset.py | 2 +- tensorlayer/files/dataset_loaders/cifar10_dataset.py | 2 +- tensorlayer/files/dataset_loaders/cyclegan_dataset.py | 2 +- tensorlayer/files/dataset_loaders/flickr_1M_dataset.py | 2 +- tensorlayer/files/dataset_loaders/flickr_25k_dataset.py | 2 +- tensorlayer/files/dataset_loaders/imdb_dataset.py | 2 +- tensorlayer/files/dataset_loaders/matt_mahoney_dataset.py | 2 +- tensorlayer/files/dataset_loaders/mnist_dataset.py | 2 +- tensorlayer/files/dataset_loaders/mnist_fashion_dataset.py | 2 +- tensorlayer/files/dataset_loaders/mpii_dataset.py | 2 +- tensorlayer/files/dataset_loaders/nietzsche_dataset.py | 2 +- tensorlayer/files/dataset_loaders/ptb_dataset.py | 2 +- tensorlayer/files/dataset_loaders/voc_dataset.py | 2 +- tensorlayer/files/dataset_loaders/wmt_en_fr_dataset.py | 2 +- 14 files changed, 14 insertions(+), 14 deletions(-) diff --git a/tensorlayer/files/dataset_loaders/celebA_dataset.py b/tensorlayer/files/dataset_loaders/celebA_dataset.py index d5dc5755f..73ebcbced 100644 --- a/tensorlayer/files/dataset_loaders/celebA_dataset.py +++ b/tensorlayer/files/dataset_loaders/celebA_dataset.py @@ -10,7 +10,7 @@ __all__ = ['load_celebA_dataset'] -def load_celebA_dataset(path='data'): +def load_celebA_dataset(path='raw_data'): """Load CelebA dataset Return a list of image path. 
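The remaining loader diffs below all follow the same one-line pattern as the celebA change above: only the default value of the `path` argument moves from 'data' to 'raw_data'. As a rough illustration of the caller-side effect (a sketch only; the loader name and its list-of-paths return value are taken from this patch, everything else is assumption):

import tensorlayer as tl

# After this change, the tl.files loaders download to and read from './raw_data' by default.
img_paths = tl.files.load_celebA_dataset()

# Scripts that already keep the files under './data' can pin the old location
# explicitly instead of relying on the previous default.
img_paths = tl.files.load_celebA_dataset(path='data')
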
diff --git a/tensorlayer/files/dataset_loaders/cifar10_dataset.py b/tensorlayer/files/dataset_loaders/cifar10_dataset.py index 9af3f615d..bc36ff838 100644 --- a/tensorlayer/files/dataset_loaders/cifar10_dataset.py +++ b/tensorlayer/files/dataset_loaders/cifar10_dataset.py @@ -13,7 +13,7 @@ __all__ = ['load_cifar10_dataset'] -def load_cifar10_dataset(shape=(-1, 32, 32, 3), path='data', plotable=False): +def load_cifar10_dataset(shape=(-1, 32, 32, 3), path='raw_data', plotable=False): """Load CIFAR-10 dataset. It consists of 60000 32x32 colour images in 10 classes, with diff --git a/tensorlayer/files/dataset_loaders/cyclegan_dataset.py b/tensorlayer/files/dataset_loaders/cyclegan_dataset.py index e327b3b4c..9f77b6710 100644 --- a/tensorlayer/files/dataset_loaders/cyclegan_dataset.py +++ b/tensorlayer/files/dataset_loaders/cyclegan_dataset.py @@ -11,7 +11,7 @@ __all__ = ['load_cyclegan_dataset'] -def load_cyclegan_dataset(filename='summer2winter_yosemite', path='data'): +def load_cyclegan_dataset(filename='summer2winter_yosemite', path='raw_data'): """Load images from CycleGAN's database, see `this link `__. Parameters diff --git a/tensorlayer/files/dataset_loaders/flickr_1M_dataset.py b/tensorlayer/files/dataset_loaders/flickr_1M_dataset.py index f2e582ae5..a81cfeac5 100644 --- a/tensorlayer/files/dataset_loaders/flickr_1M_dataset.py +++ b/tensorlayer/files/dataset_loaders/flickr_1M_dataset.py @@ -11,7 +11,7 @@ __all__ = ['load_flickr1M_dataset'] -def load_flickr1M_dataset(tag='sky', size=10, path="data", n_threads=50, printable=False): +def load_flickr1M_dataset(tag='sky', size=10, path="raw_data", n_threads=50, printable=False): """Load Flick1M dataset. Returns a list of images by a given tag from Flickr1M dataset, diff --git a/tensorlayer/files/dataset_loaders/flickr_25k_dataset.py b/tensorlayer/files/dataset_loaders/flickr_25k_dataset.py index 8049a0653..4c695b4e5 100644 --- a/tensorlayer/files/dataset_loaders/flickr_25k_dataset.py +++ b/tensorlayer/files/dataset_loaders/flickr_25k_dataset.py @@ -11,7 +11,7 @@ __all__ = ['load_flickr25k_dataset'] -def load_flickr25k_dataset(tag='sky', path="data", n_threads=50, printable=False): +def load_flickr25k_dataset(tag='sky', path="raw_data", n_threads=50, printable=False): """Load Flickr25K dataset. Returns a list of images by a given tag from Flick25k dataset, diff --git a/tensorlayer/files/dataset_loaders/imdb_dataset.py b/tensorlayer/files/dataset_loaders/imdb_dataset.py index 34b4dffe0..686e6f906 100644 --- a/tensorlayer/files/dataset_loaders/imdb_dataset.py +++ b/tensorlayer/files/dataset_loaders/imdb_dataset.py @@ -13,7 +13,7 @@ def load_imdb_dataset( - path='data', nb_words=None, skip_top=0, maxlen=None, test_split=0.2, seed=113, start_char=1, oov_char=2, + path='raw_data', nb_words=None, skip_top=0, maxlen=None, test_split=0.2, seed=113, start_char=1, oov_char=2, index_from=3 ): """Load IMDB dataset. diff --git a/tensorlayer/files/dataset_loaders/matt_mahoney_dataset.py b/tensorlayer/files/dataset_loaders/matt_mahoney_dataset.py index 17a3e0833..f5a419f08 100644 --- a/tensorlayer/files/dataset_loaders/matt_mahoney_dataset.py +++ b/tensorlayer/files/dataset_loaders/matt_mahoney_dataset.py @@ -10,7 +10,7 @@ __all__ = ['load_matt_mahoney_text8_dataset'] -def load_matt_mahoney_text8_dataset(path='data'): +def load_matt_mahoney_text8_dataset(path='raw_data'): """Load Matt Mahoney's dataset. 
Download a text file from Matt Mahoney's website diff --git a/tensorlayer/files/dataset_loaders/mnist_dataset.py b/tensorlayer/files/dataset_loaders/mnist_dataset.py index 4e1346d5e..5e297dbc1 100644 --- a/tensorlayer/files/dataset_loaders/mnist_dataset.py +++ b/tensorlayer/files/dataset_loaders/mnist_dataset.py @@ -6,7 +6,7 @@ __all__ = ['load_mnist_dataset'] -def load_mnist_dataset(shape=(-1, 784), path='data'): +def load_mnist_dataset(shape=(-1, 784), path='raw_data'): """Load the original mnist. Automatically download MNIST dataset and return the training, validation and test set with 50000, 10000 and 10000 digit images respectively. diff --git a/tensorlayer/files/dataset_loaders/mnist_fashion_dataset.py b/tensorlayer/files/dataset_loaders/mnist_fashion_dataset.py index c7f1bb964..011d8688e 100644 --- a/tensorlayer/files/dataset_loaders/mnist_fashion_dataset.py +++ b/tensorlayer/files/dataset_loaders/mnist_fashion_dataset.py @@ -6,7 +6,7 @@ __all__ = ['load_fashion_mnist_dataset'] -def load_fashion_mnist_dataset(shape=(-1, 784), path='data'): +def load_fashion_mnist_dataset(shape=(-1, 784), path='raw_data'): """Load the fashion mnist. Automatically download fashion-MNIST dataset and return the training, validation and test set with 50000, 10000 and 10000 fashion images respectively, `examples `__. diff --git a/tensorlayer/files/dataset_loaders/mpii_dataset.py b/tensorlayer/files/dataset_loaders/mpii_dataset.py index a6f88f609..bbbbe233f 100644 --- a/tensorlayer/files/dataset_loaders/mpii_dataset.py +++ b/tensorlayer/files/dataset_loaders/mpii_dataset.py @@ -9,7 +9,7 @@ __all__ = ['load_mpii_pose_dataset'] -def load_mpii_pose_dataset(path='data', is_16_pos_only=False): +def load_mpii_pose_dataset(path='raw_data', is_16_pos_only=False): """Load MPII Human Pose Dataset. Parameters diff --git a/tensorlayer/files/dataset_loaders/nietzsche_dataset.py b/tensorlayer/files/dataset_loaders/nietzsche_dataset.py index 3cd0e27f2..2bfb02113 100644 --- a/tensorlayer/files/dataset_loaders/nietzsche_dataset.py +++ b/tensorlayer/files/dataset_loaders/nietzsche_dataset.py @@ -9,7 +9,7 @@ __all__ = ['load_nietzsche_dataset'] -def load_nietzsche_dataset(path='data'): +def load_nietzsche_dataset(path='raw_data'): """Load Nietzsche dataset. Parameters diff --git a/tensorlayer/files/dataset_loaders/ptb_dataset.py b/tensorlayer/files/dataset_loaders/ptb_dataset.py index 30746fd87..023896828 100644 --- a/tensorlayer/files/dataset_loaders/ptb_dataset.py +++ b/tensorlayer/files/dataset_loaders/ptb_dataset.py @@ -9,7 +9,7 @@ __all__ = ['load_ptb_dataset'] -def load_ptb_dataset(path='data'): +def load_ptb_dataset(path='raw_data'): """Load Penn TreeBank (PTB) dataset. It is used in many LANGUAGE MODELING papers, diff --git a/tensorlayer/files/dataset_loaders/voc_dataset.py b/tensorlayer/files/dataset_loaders/voc_dataset.py index e5124b4df..b17650f14 100644 --- a/tensorlayer/files/dataset_loaders/voc_dataset.py +++ b/tensorlayer/files/dataset_loaders/voc_dataset.py @@ -10,7 +10,7 @@ __all__ = ['load_voc_dataset'] -def load_voc_dataset(path='data', dataset='2012', contain_classes_in_person=False): +def load_voc_dataset(path='raw_data', dataset='2012', contain_classes_in_person=False): """Pascal VOC 2007/2012 Dataset. 
It has 20 objects: diff --git a/tensorlayer/files/dataset_loaders/wmt_en_fr_dataset.py b/tensorlayer/files/dataset_loaders/wmt_en_fr_dataset.py index 77c1f93f9..572afdb46 100644 --- a/tensorlayer/files/dataset_loaders/wmt_en_fr_dataset.py +++ b/tensorlayer/files/dataset_loaders/wmt_en_fr_dataset.py @@ -12,7 +12,7 @@ __all__ = ['load_wmt_en_fr_dataset'] -def load_wmt_en_fr_dataset(path='data'): +def load_wmt_en_fr_dataset(path='raw_data'): """Load WMT'15 English-to-French translation dataset. It will download the data from the WMT'15 Website (10^9-French-English corpus), and the 2013 news test from the same site as development set.
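
For reference, the data pipeline used by the CIFAR-10 tutorial in PATCH 4/6 is driven by plain indexable datasets: a class only has to provide __len__ and __getitem__, in the same way as the CIFAR10 and MNIST dataset classes removed earlier in this series, and can then be passed to the Dataloader. A minimal sketch follows; the tl.data.Dataset base-class location is an assumption and not confirmed by these diffs, while the Dataloader arguments mirror the tutorial above:

import numpy as np
import tensorlayer as tl

class RandomVectors(tl.data.Dataset):  # assumed base-class location, see note above
    """Toy dataset: 1000 random float vectors with integer labels."""

    def __init__(self, n=1000, dim=32):
        self.x = np.random.rand(n, dim).astype(np.float32)
        self.y = np.random.randint(0, 10, size=(n,)).astype(np.int32)

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, index):
        return self.x[index], self.y[index]

def _normalize(x, y):
    # per-sample zero-mean transform, in the spirit of _fn_train above
    return x - x.mean(), y

ds = RandomVectors()
dl = tl.data.Dataloader(ds, transforms=[_normalize], shuffle=True,
                        batch_size=32, output_types=(np.float32, np.int32))

for x_batch, y_batch in dl:
    pass  # feed batches to a model, as in the CIFAR-10 training loop above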