diff --git a/.gitignore b/.gitignore index 4a6f60ff3..22d68693c 100644 --- a/.gitignore +++ b/.gitignore @@ -119,7 +119,7 @@ venv_py2/ # TensorLayer Directories checkpoints -data/ +raw_data/ lib_win/ # Custom Scripts diff --git a/examples/basic_tutorials/tutorial_cifar10_cnn_static_dataloader.py b/examples/basic_tutorials/tutorial_cifar10_cnn_static_dataloader.py new file mode 100644 index 000000000..7269aae31 --- /dev/null +++ b/examples/basic_tutorials/tutorial_cifar10_cnn_static_dataloader.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import multiprocessing +import time + +import numpy as np + +import tensorflow as tf +import tensorlayer as tl +from tensorlayer.layers import (BatchNorm, Conv2d, Dense, Flatten, Input, LocalResponseNorm, MaxPool2d) +from tensorlayer.models import Model + +# enable debug logging +tl.logging.set_verbosity(tl.logging.DEBUG) +tl.logging.set_verbosity(tl.logging.DEBUG) + + +# define the network +def get_model(inputs_shape): + # self defined initialization + W_init = tl.initializers.truncated_normal(stddev=5e-2) + W_init2 = tl.initializers.truncated_normal(stddev=0.04) + b_init2 = tl.initializers.constant(value=0.1) + + # build network + ni = Input(inputs_shape) + nn = Conv2d(64, (5, 5), (1, 1), padding='SAME', act=tf.nn.relu, W_init=W_init, b_init=None, name='conv1')(ni) + nn = MaxPool2d((3, 3), (2, 2), padding='SAME', name='pool1')(nn) + nn = LocalResponseNorm(depth_radius=4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name="norm1")(nn) + + nn = Conv2d(64, (5, 5), (1, 1), padding='SAME', act=tf.nn.relu, W_init=W_init, b_init=None, name='conv2')(nn) + nn = LocalResponseNorm(depth_radius=4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name="norm2")(nn) + nn = MaxPool2d((3, 3), (2, 2), padding='SAME', name='pool2')(nn) + + nn = Flatten(name='flatten')(nn) + nn = Dense(384, act=tf.nn.relu, W_init=W_init2, b_init=b_init2, name='dense1relu')(nn) + nn = Dense(192, act=tf.nn.relu, W_init=W_init2, b_init=b_init2, name='dense2relu')(nn) + nn = Dense(10, act=None, W_init=W_init2, name='output')(nn) + + M = Model(inputs=ni, outputs=nn, name='cnn') + return M + + +def get_model_batchnorm(inputs_shape): + # self defined initialization + W_init = tl.initializers.truncated_normal(stddev=5e-2) + W_init2 = tl.initializers.truncated_normal(stddev=0.04) + b_init2 = tl.initializers.constant(value=0.1) + + # build network + ni = Input(inputs_shape) + nn = Conv2d(64, (5, 5), (1, 1), padding='SAME', W_init=W_init, b_init=None, name='conv1')(ni) + nn = BatchNorm(decay=0.99, act=tf.nn.relu, name='batch1')(nn) + nn = MaxPool2d((3, 3), (2, 2), padding='SAME', name='pool1')(nn) + + nn = Conv2d(64, (5, 5), (1, 1), padding='SAME', W_init=W_init, b_init=None, name='conv2')(nn) + nn = BatchNorm(decay=0.99, act=tf.nn.relu, name='batch2')(nn) + nn = MaxPool2d((3, 3), (2, 2), padding='SAME', name='pool2')(nn) + + nn = Flatten(name='flatten')(nn) + nn = Dense(384, act=tf.nn.relu, W_init=W_init2, b_init=b_init2, name='dense1relu')(nn) + nn = Dense(192, act=tf.nn.relu, W_init=W_init2, b_init=b_init2, name='dense2relu')(nn) + nn = Dense(10, act=None, W_init=W_init2, name='output')(nn) + + M = Model(inputs=ni, outputs=nn, name='cnn') + return M + + +# get the network +net = get_model([None, 24, 24, 3]) + +# training settings +batch_size = 128 +n_epoch = 50000 +learning_rate = 0.0001 +print_freq = 5 +# init_learning_rate = 0.1 +# learning_rate_decay_factor = 0.1 +# num_epoch_decay = 350 + +train_weights = net.trainable_weights +# learning_rate = tf.Variable(init_learning_rate) 
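+ # NOTE: the learning rate is kept as a plain float here; the commented-out tf.Variable
+ # form above would only be needed to reassign (decay) it during training, see the FIXME
+ # near the end of the training loop.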
+optimizer = tf.optimizers.Adam(learning_rate) + + +def _fn_train(img, target): + # 1. Randomly crop a [height, width] section of the image. + img = tl.prepro.crop(img, 24, 24, False) + # 2. Randomly flip the image horizontally. + img = tl.prepro.flip_axis(img, is_random=True) + # 3. Subtract off the mean and divide by the variance of the pixels. + img = tl.prepro.samplewise_norm(img) + target = np.reshape(target, ()) + return img, target + + +def _fn_test(img, target): + # 1. Crop the central [height, width] of the image. + img = tl.prepro.crop(img, 24, 24) + # 2. Subtract off the mean and divide by the variance of the pixels. + img = tl.prepro.samplewise_norm(img) + img = np.reshape(img, (24, 24, 3)) + target = np.reshape(target, ()) + return img, target + + +# dataset API and augmentation +train_ds = tl.data.CIFAR10(train_or_test='train', shape=(-1, 32, 32, 3)) +train_dl = tl.data.Dataloader(train_ds, transforms=[_fn_train], shuffle=True, + batch_size=batch_size, output_types=(np.float32, np.int32)) +test_ds = tl.data.CIFAR10(train_or_test='test', shape=(-1, 32, 32, 3)) +test_dl = tl.data.Dataloader(test_ds, transforms=[_fn_test], batch_size=batch_size) + +for epoch in range(n_epoch): + start_time = time.time() + + train_loss, train_acc, n_iter = 0, 0, 0 + for X_batch, y_batch in train_dl: + net.train() + + with tf.GradientTape() as tape: + # compute outputs + _logits = net(X_batch) + # compute loss and update model + _loss_ce = tl.cost.cross_entropy(_logits, y_batch, name='train_loss') + _loss_L2 = 0 + # for p in tl.layers.get_variables_with_name('relu/W', True, True): + # _loss_L2 += tl.cost.lo_regularizer(1.0)(p) + _loss = _loss_ce + _loss_L2 + # print(_loss) + + grad = tape.gradient(_loss, train_weights) + optimizer.apply_gradients(zip(grad, train_weights)) + + train_loss += _loss + train_acc += np.mean(np.equal(np.argmax(_logits, 1), y_batch)) + n_iter += 1 + + # use training and evaluation sets to evaluate the model every print_freq epoch + if epoch + 1 == 1 or (epoch + 1) % print_freq == 0: + + print("Epoch {} of {} took {}".format(epoch + 1, n_epoch, time.time() - start_time)) + + print(" train loss: {}".format(train_loss / n_iter)) + print(" train acc: {}".format(train_acc / n_iter)) + + net.eval() + + val_loss, val_acc, n_iter = 0, 0, 0 + for X_batch, y_batch in test_dl: + _logits = net(X_batch) # is_train=False, disable dropout + val_loss += tl.cost.cross_entropy(_logits, y_batch, name='eval_loss') + val_acc += np.mean(np.equal(np.argmax(_logits, 1), y_batch)) + n_iter += 1 + print(" val loss: {}".format(val_loss / n_iter)) + print(" val acc: {}".format(val_acc / n_iter)) + + # FIXME : how to apply lr decay in eager mode? 
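+ # A possible sketch (not used in this tutorial): build the decay into the optimizer by
+ # passing a schedule as the learning rate instead of reassigning it by hand. Note that
+ # ExponentialDecay counts optimizer updates (batches) rather than epochs, so the decay
+ # period would be scaled by the number of batches per epoch, e.g.
+ #   lr_schedule = tf.optimizers.schedules.ExponentialDecay(
+ #       init_learning_rate, decay_steps=num_epoch_decay * len(train_dl),
+ #       decay_rate=learning_rate_decay_factor)
+ #   optimizer = tf.optimizers.Adam(lr_schedule)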
+ # learning_rate.assign(tf.train.exponential_decay(init_learning_rate, epoch, num_epoch_decay, + # learning_rate_decay_factor)) + +# use testing data to evaluate the model +net.eval() +test_loss, test_acc, n_iter = 0, 0, 0 +for X_batch, y_batch in test_dl: + _logits = net(X_batch) + test_loss += tl.cost.cross_entropy(_logits, y_batch, name='test_loss') + test_acc += np.mean(np.equal(np.argmax(_logits, 1), y_batch)) + n_iter += 1 +print(" test loss: {}".format(test_loss / n_iter)) +print(" test acc: {}".format(test_acc / n_iter)) diff --git a/tensorlayer/__init__.py b/tensorlayer/__init__.py index f89eebfff..8c7078733 100644 --- a/tensorlayer/__init__.py +++ b/tensorlayer/__init__.py @@ -44,6 +44,7 @@ from tensorlayer import optimizers from tensorlayer import rein from tensorlayer import utils + from tensorlayer import data from tensorlayer.lazy_imports import LazyImport @@ -54,6 +55,7 @@ prepro = LazyImport("tensorlayer.prepro") utils = LazyImport("tensorlayer.utils") visualize = LazyImport("tensorlayer.visualize") + data = LazyImport("tensorlayer.data") # alias act = activation diff --git a/tensorlayer/data/__init__.py b/tensorlayer/data/__init__.py new file mode 100644 index 000000000..cd53ae92b --- /dev/null +++ b/tensorlayer/data/__init__.py @@ -0,0 +1,3 @@ +from .base import Dataset +from .common import Dataloader +from .dataset import * diff --git a/tensorlayer/data/base.py b/tensorlayer/data/base.py new file mode 100644 index 000000000..6912f2cd3 --- /dev/null +++ b/tensorlayer/data/base.py @@ -0,0 +1,65 @@ +class Dataset(object): + + def __getitem__(self, index): + raise NotImplementedError("A Dataset must implement __getitem__(index) method.") + + def __len__(self): + raise NotImplementedError("A Dataset must implement __len__() method.") + + def __iter__(self): + for i in range(self.__len__()): + yield self.__getitem__(i) + + def __call__(self, *args, **kwargs): + return self.__iter__() + + +class DatasetWrapper(object): + def __init__(self, ds): + self.ds = ds + self.ds_len = len(ds) + + def __len__(self): + return len(self.ds) + + def __iter__(self): + for dp in self.ds: + yield dp + + def __call__(self, *args, **kwargs): + return self.__iter__() + + +class IndexableDatasetWrapper(object): + def __init__(self, ds): + self.ds = ds + self.ds_len = len(ds) + + def __getitem__(self, index): + return self.ds.__getitem__(index) + + def __len__(self): + return len(self.ds) + + def __call__(self, *args, **kwargs): + return self + + +class Transform(object): + def __call__(self, *args, **kwargs): + raise NotImplementedError("Transform must implement __call__() method.") + + +class _Transforms_for_tf_dataset(object): + """ + This class aggregate Transforms into one object in order to use tf.data.Dataset.map API + """ + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, *args): + data_list = list(args) + for transform in self.transforms: + data_list = transform(*data_list) + return data_list diff --git a/tensorlayer/data/common.py b/tensorlayer/data/common.py new file mode 100644 index 000000000..ed5a0f518 --- /dev/null +++ b/tensorlayer/data/common.py @@ -0,0 +1,348 @@ +import math +import multiprocessing +import os + +import tensorflow as tf +import zmq + +import numpy as np + +from .base import IndexableDatasetWrapper, DatasetWrapper, _Transforms_for_tf_dataset +from .parallel import _get_pipe_name, ZMQMultiprocessDataset, MultiprocessDataset +from .utils import clean_up_socket_files +from .serialize import convert_to_bytes, load_from_bytes + 
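+ # Wrapper composition used by Dataloader below:
+ #   AugmentedDataset -> TransformedDataset -> (ZMQ)MultiprocessDataset or ShuffledDataset
+ #   -> PrefetchBatchedDataset or BatchedDataset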
+__all__ = ['PrefetchBatchedDataset', 'TransformedDataset', 'ShuffledDataset', + 'AugmentedDataset'] + + +class BatchedDataset(DatasetWrapper): + def __init__(self, + ds, + batch_size, + drop_remainder=True, + return_numpy=True, + output_types=None): + super(BatchedDataset, self).__init__(ds) + self.batch_size = batch_size + self.drop_remainder = drop_remainder + self.return_numpy = return_numpy + self.output_types = output_types + + def __iter__(self): + dp_buffer = [] + for dp in self.ds: + dp_buffer.append(dp) + if len(dp_buffer) == self.batch_size: + yield self._batch_datapoints(dp_buffer, self.return_numpy, self.output_types) + del dp_buffer[:] + if not self.drop_remainder: + self._batch_datapoints(dp_buffer, self.return_numpy, self.output_types) + + def __len__(self): + ds_len = len(self.ds) + if self.drop_remainder: + return ds_len // self.batch_size + else: + return math.ceil(ds_len / self.batch_size) + + @staticmethod + def _batch_datapoints(dp_buffer, return_numpy, output_types): + """ + + :param dp_buffer: a list of datapoints + :return: + """ + first_dp = dp_buffer[0] + if isinstance(first_dp, (tuple, list)): + dp_batch = [None] * len(first_dp) + for i in range(len(first_dp)): + dp_element_batch = [] + for j in range(len(dp_buffer)): + dp_element_batch.append(dp_buffer[j][i]) + if return_numpy: + dp_batch[i] = BatchedDataset._batch_ndarray(dp_element_batch, + dtype=BatchedDataset._get_element_dtype(output_types, + i)) + else: + dp_batch[i] = dp_element_batch + return dp_batch + elif isinstance(first_dp, dict): + dp_batch = {} + for key in first_dp.keys(): + dp_element_batch = [] + for j in range(len(dp_buffer)): + dp_element_batch.append(dp_buffer[j][key]) + if return_numpy: + dp_batch[key] = BatchedDataset._batch_ndarray(dp_element_batch, dtype=None) + else: + dp_batch[key] = dp_element_batch + return dp_batch + elif isinstance(first_dp, np.ndarray): + return BatchedDataset._batch_ndarray(dp_buffer) + # single elements + else: + if return_numpy: + return BatchedDataset._batch_ndarray(dp_buffer, + dtype=BatchedDataset._get_element_dtype(output_types, 0)) + else: + return dp_buffer + + @staticmethod + def _batch_ndarray(dp_element_batch, dtype): + """ + + :param dp_element_batch: a list of datapoint element, an element can be np.ndarray / list + :return: np.ndarray, type is the same as input + """ + try: + if dtype is not None: + ret = np.asarray(dp_element_batch, dtype=dtype) + else: + ret = np.asarray(dp_element_batch) + return ret + except: + raise ValueError("Unsupported type for batching.") + + @staticmethod + def _get_element_dtype(output_types, i): + if output_types is None: + return None + if not isinstance(output_types, (tuple, list)): + return output_types + if len(output_types) == 1: + return output_types[0] + return output_types[i] + + +class PrefetchBatchedDataset(DatasetWrapper): + def __init__(self, + ds, + batch_size, + drop_remainder=True, + return_numpy=True, + output_types=None, + use_zmq=True): + super(PrefetchBatchedDataset, self).__init__(ds) + self.batch_size = batch_size + self.drop_remainder = drop_remainder + self.return_numpy = return_numpy + self.output_types = output_types + self.use_zmq = use_zmq + + if self.use_zmq: + self.data_pipename = _get_pipe_name('batch_prefetch') + context = zmq.Context() + self.fetch_data_socket = context.socket(zmq.PULL) + self.fetch_data_socket.set_hwm(1) + self.fetch_data_socket.bind(self.data_pipename) + self.worker = multiprocessing.Process(target=self._ZMQ_BatchedDataset_worker, + args=(self.ds,)) + 
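+ # The worker is a daemon process that keeps iterating the wrapped dataset, groups
+ # datapoints into batches, serializes each batch and pushes it through the ZMQ PUSH
+ # socket; set_hwm(1) caps how many serialized batches can pile up in the socket queues.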
self.worker.daemon = True + self.worker.start() + clean_up_socket_files(self.data_pipename) + else: + pipe_output, pipe_input = multiprocessing.Pipe() + self.worker = multiprocessing.Process(target=self._BatchedDataset_worker, + args=(self.ds, (pipe_output, pipe_input))) + self.worker.daemon = True + self.worker.start() + # main process only reads (gets output) + pipe_input.close() + self.pipe_output = pipe_output + + def _ZMQ_BatchedDataset_worker(self, ds): + context = zmq.Context() + prepare_data_socket = context.socket(zmq.PUSH) + prepare_data_socket.set_hwm(1) + prepare_data_socket.connect(self.data_pipename) + while True: + dp_buffer = [] + for dp in ds: + dp_buffer.append(dp) + if len(dp_buffer) == self.batch_size: + prepare_data_socket.send(convert_to_bytes( + BatchedDataset._batch_datapoints(dp_buffer, self.return_numpy, self.output_types)), copy=False) + del dp_buffer[:] + if not self.drop_remainder: + prepare_data_socket.send( + convert_to_bytes(BatchedDataset._batch_datapoints(dp_buffer, self.return_numpy, self.output_types)), + copy=False) + + def _BatchedDataset_worker(self, ds, pipe): + pipe_output, pipe_input = pipe + # worker process only writes (puts input) + pipe_output.close() + while True: + dp_buffer = [] + for dp in ds: + dp_buffer.append(dp) + if len(dp_buffer) == self.batch_size: + pipe_input.send(BatchedDataset._batch_datapoints(dp_buffer, self.return_numpy, self.output_types)) + del dp_buffer[:] + if not self.drop_remainder: + pipe_input.send(BatchedDataset._batch_datapoints(dp_buffer, self.return_numpy, self.output_types)) + + def __iter__(self): + for _ in range(self.__len__()): + # yield self.q.get() + if self.use_zmq: + yield load_from_bytes(self.fetch_data_socket.recv(copy=False)) + else: + yield self.pipe_output.recv() + + def __len__(self): + ds_len = len(self.ds) + if self.drop_remainder: + return ds_len // self.batch_size + else: + return math.ceil(ds_len / self.batch_size) + + +class ShuffledDataset(DatasetWrapper): + def __init__(self, ds): + super(ShuffledDataset, self).__init__(ds) + + def __iter__(self): + self.shuffled_idxs = np.random.permutation(len(self.ds)) + for index, data in enumerate(self.ds): + yield self.ds[self.shuffled_idxs[index]] + + +class TransformedDataset(IndexableDatasetWrapper): + """ + + """ + + def __init__(self, ds, transforms): + super(TransformedDataset, self).__init__(ds) + self.transforms = transforms + + def __getitem__(self, index): + dp = self.ds[index] + for transform in self.transforms: + assert callable(transform) + if isinstance(dp, (list, tuple)): + dp = transform(*dp) + else: + dp = transform(dp) + return dp + + +class AugmentedDataset(IndexableDatasetWrapper): + def __init__(self, ds, augmentations): + super(AugmentedDataset, self).__init__(ds) + self.augmentations = augmentations + self.num_augmentations = len(self.augmentations) + + def __getitem__(self, index): + if index >= self.__len__(): + raise IndexError + dp = self.ds[index % self.ds_len] + if index < self.ds_len: + return dp + augmentation = self.augmentations[(index // self.ds_len) - 1] + assert callable(augmentation) + if isinstance(dp, (list, tuple)): + return augmentation(*dp) + else: + return augmentation(dp) + + def __len__(self): + # every augmentation gives one more duplication of dataset + return self.ds_len * (1 + self.num_augmentations) + + +class Dataloader(DatasetWrapper): + def __init__(self, + ds, + augmentations=None, + shuffle=False, + batch_size=1, + drop_remainder=True, + output_types=None, + num_worker=os.cpu_count(), + 
use_zmq=True, + prefetch_batch=True, + num_prefetch=None, + transforms=None): + + super(Dataloader, self).__init__(ds) + self.augmentations = augmentations + self.shuffle = shuffle + self.batch_size = batch_size + self.drop_remainder = drop_remainder + self.output_types = output_types + self.num_worker = num_worker + self.use_zmq = use_zmq + self.prefetch_batch = prefetch_batch + self.num_prefetch = num_worker if num_prefetch is None else num_prefetch + self.transforms = transforms + + if self.augmentations is not None: + self.ds = AugmentedDataset(self.ds, self.augmentations) + + if self.transforms is not None: + self.ds = TransformedDataset(self.ds, self.transforms) + # self.tfds = self.tfds.map(map_func=_Transforms(self.transforms), num_parallel_calls=num_map_worker) + + # TODO: auto adjust num_prefetch + if self.num_worker > 1: + if self.use_zmq: + self.ds = ZMQMultiprocessDataset(self.ds, num_worker=self.num_worker, hwm=self.num_prefetch, + shuffle=self.shuffle) + else: + self.ds = MultiprocessDataset(self.ds, num_worker=self.num_worker, num_prefetch=self.num_prefetch, + shuffle=self.shuffle) + elif self.shuffle: + self.ds = ShuffledDataset(self.ds) + + if self.prefetch_batch: + self.ds = PrefetchBatchedDataset(self.ds, self.batch_size, drop_remainder=self.drop_remainder, + output_types=self.output_types, use_zmq=self.use_zmq) + else: + self.ds = BatchedDataset(self.ds, self.batch_size, drop_remainder=self.drop_remainder, + output_types=self.output_types) + + def __iter__(self): + for dp in self.ds: + yield dp + + +class TFDataloader(DatasetWrapper): + def __init__(self, + ds, + output_types, + augmentations=None, + shuffle=False, + shuffle_buffer_size=None, + batch_size=1, + drop_remainder=True, + num_worker=tf.data.experimental.AUTOTUNE, + transforms=None): + + super(TFDataloader, self).__init__(ds) + self.augmentations = augmentations + self.shuffle = shuffle + self.batch_size = batch_size + self.shuffle_buffer_size = 2 * batch_size if shuffle_buffer_size is None else shuffle_buffer_size + self.drop_remainder = drop_remainder + self.transforms = transforms + + self.ds = tf.data.Dataset.from_generator(self.ds, output_types=output_types) + + if self.shuffle: + self.ds = self.ds.shuffle(buffer_size=self.shuffle_buffer_size) + + if self.transforms is not None: + self.ds = self.ds.map(map_func=_Transforms_for_tf_dataset(self.transforms), + num_parallel_calls=num_worker) + + if self.batch_size > 1: + self.ds = self.ds.batch(batch_size=self.batch_size, drop_remainder=self.drop_remainder) + + self.ds = self.ds.prefetch(tf.data.experimental.AUTOTUNE) + + def __iter__(self): + for dp in self.ds: + yield dp diff --git a/tensorlayer/data/dataset/__init__.py b/tensorlayer/data/dataset/__init__.py new file mode 100644 index 000000000..9b2c3166d --- /dev/null +++ b/tensorlayer/data/dataset/__init__.py @@ -0,0 +1,15 @@ +from .celebA import * +from .cifar10 import * +from .cyclegan import * +from .flickr_25k import * +from .flickr_1M import * +from .ilsvrc import * +from .imdb import * +from .matt_mahoney import * +from .mnist import * +from .mnist_fashion import * +from .mpii import * +from .nietzsche import * +from .ptb import * +from .voc import * +from .wmt_en_fr import * diff --git a/tensorlayer/data/dataset/celebA.py b/tensorlayer/data/dataset/celebA.py new file mode 100644 index 000000000..cb8c6a392 --- /dev/null +++ b/tensorlayer/data/dataset/celebA.py @@ -0,0 +1,100 @@ +#! 
/usr/bin/python +# -*- coding: utf-8 -*- + +import os +import zipfile + +import logging + +import cv2 +import numpy as np + +from ..base import Dataset +from ..utils import download_file_from_google_drive, exists_or_mkdir, load_file_list + +__all__ = ['load_celebA_dataset', 'CelebAFiles', 'CelebA'] + + +def load_celebA_dataset(name='celebA', path='raw_data'): + """ + Load CelebA dataset + + Return a list of image path. + + Parameters + ----------- + name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is ``raw_data/celebA/``. + """ + data_dir = name + filename, drive_id = "img_align_celeba.zip", "0B7EVK8r0v71pZjFTYXZWM3FlRnM" + save_path = os.path.join(path, filename) + image_path = os.path.join(path, data_dir) + if os.path.exists(image_path): + logging.info('[*] {} already exists'.format(save_path)) + else: + exists_or_mkdir(path) + download_file_from_google_drive(drive_id, save_path) + zip_dir = '' + with zipfile.ZipFile(save_path) as zf: + zip_dir = zf.namelist()[0] + zf.extractall(path) + os.remove(save_path) + os.rename(os.path.join(path, zip_dir), image_path) + + data_files = load_file_list(path=image_path, regx='\\.jpg', printable=False) + for i, _v in enumerate(data_files): + data_files[i] = os.path.join(image_path, data_files[i]) + return data_files + + +class CelebAFiles(Dataset): + """ + Load CelebA dataset. Produce filenames of images. + + Parameters + ----------- + name : str + The name of the dataset + path : str + The path that the data is downloaded to, defaults is ``raw_data/celebA/``. + """ + def __init__(self, name='celebA', path='raw_data'): + self.data_files = load_celebA_dataset(name=name, path=path) + + def __getitem__(self, index): + return self.data_files[index] + + def __len__(self): + return len(self.data_files) + + +class CelebA(CelebAFiles): + """ + Load CelebA dataset. Produce nparrays of images. + + Parameters + ----------- + name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is ``raw_data/celebA/``. + shape : tuple + The shape of digit images. + """ + def __init__(self, shape=None, name='celebA', path='raw_data'): + super(CelebA, self).__init__(name=name, path=path) + self.shape = shape + + def __getitem__(self, index): + file_path = self.data_files[index] + img = cv2.imread(file_path) + if self.shape: + img = cv2.resize(img, self.shape) + img = np.array(img, dtype=np.float32) + return img + + def __len__(self): + return len(self.data_files) diff --git a/tensorlayer/data/dataset/cifar10.py b/tensorlayer/data/dataset/cifar10.py new file mode 100644 index 000000000..dda876efb --- /dev/null +++ b/tensorlayer/data/dataset/cifar10.py @@ -0,0 +1,190 @@ +import logging +import os +import pickle +import sys +import numpy as np + +from ..base import Dataset +from ..utils import maybe_download_and_extract + +__all__ = ['load_cifar10_dataset', 'CIFAR10'] + +CIFAR10_BASE_URL = 'https://www.cs.toronto.edu/~kriz/' +CIFAR10_FILENAME = 'cifar-10-python.tar.gz' + + +# Helper function to unpickle the data +def unpickle(file): + fp = open(file, 'rb') + if sys.version_info.major == 2: + data = pickle.load(fp) + elif sys.version_info.major == 3: + data = pickle.load(fp, encoding='latin-1') + else: + raise RuntimeError("Sys Version Unsupported") + fp.close() + return data + + +def load_cifar10_dataset(shape=(-1, 32, 32, 3), path='raw_data', name='cifar10', plotable=False): + """ + Load CIFAR-10 dataset. 
+ + It consists of 60000 32x32 colour images in 10 classes, with + 6000 images per class. There are 50000 training images and 10000 test images. + + The dataset is divided into five training batches and one test batch, each with + 10000 images. The test batch contains exactly 1000 randomly-selected images from + each class. The training batches contain the remaining images in random order, + but some training batches may contain more images from one class than another. + Between them, the training batches contain exactly 5000 images from each class. + + Parameters + ---------- + shape : tuple + The shape of digit images e.g. (-1, 3, 32, 32) and (-1, 32, 32, 3). + name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is ``raw_data/cifar10/``. + plotable : boolean + Whether to plot some image examples, False as default. + + Examples + -------- + >>> X_train, y_train, X_test, y_test = load_cifar10_dataset(shape=(-1, 32, 32, 3)) + + References + ---------- + - `CIFAR website `__ + - `Data download link `__ + - ``__ + + """ + path = os.path.join(path, name) + logging.info("Load or Download cifar10 > {}".format(path)) + + # Download and uncompress file + maybe_download_and_extract(CIFAR10_FILENAME, path, CIFAR10_BASE_URL, extract=True) + + # Unpickle file and fill in data + X_train = None + y_train = [] + for i in range(1, 6): + data_dic = unpickle(os.path.join(path, 'cifar-10-batches-py/', "data_batch_{}".format(i))) + if i == 1: + X_train = data_dic['data'] + else: + X_train = np.vstack((X_train, data_dic['data'])) + y_train += data_dic['labels'] + + test_data_dic = unpickle(os.path.join(path, 'cifar-10-batches-py/', "test_batch")) + X_test = test_data_dic['data'] + y_test = np.array(test_data_dic['labels']) + + if shape == (-1, 3, 32, 32): + X_test = X_test.reshape(shape) + X_train = X_train.reshape(shape) + elif shape == (-1, 32, 32, 3): + X_test = X_test.reshape(shape, order='F') + X_train = X_train.reshape(shape, order='F') + X_test = np.transpose(X_test, (0, 2, 1, 3)) + X_train = np.transpose(X_train, (0, 2, 1, 3)) + else: + X_test = X_test.reshape(shape) + X_train = X_train.reshape(shape) + + y_train = np.array(y_train) + + if plotable: + logging.info('\nCIFAR-10') + import matplotlib.pyplot as plt + fig = plt.figure(1) + + logging.info('Shape of a training image: X_train[0] %s' % X_train[0].shape) + + plt.ion() # interactive mode + count = 1 + for _ in range(10): # each row + for _ in range(10): # each column + _ = fig.add_subplot(10, 10, count) + if shape == (-1, 3, 32, 32): + # plt.imshow(X_train[count-1], interpolation='nearest') + plt.imshow(np.transpose(X_train[count - 1], (1, 2, 0)), interpolation='nearest') + # plt.imshow(np.transpose(X_train[count-1], (2, 1, 0)), interpolation='nearest') + elif shape == (-1, 32, 32, 3): + plt.imshow(X_train[count - 1], interpolation='nearest') + # plt.imshow(np.transpose(X_train[count-1], (1, 0, 2)), interpolation='nearest') + else: + raise Exception("Do not support the given 'shape' to plot the image examples") + plt.gca().xaxis.set_major_locator(plt.NullLocator()) + plt.gca().yaxis.set_major_locator(plt.NullLocator()) + count = count + 1 + plt.draw() # interactive mode + plt.pause(3) # interactive mode + + logging.info("X_train: %s" % X_train.shape) + logging.info("y_train: %s" % y_train.shape) + logging.info("X_test: %s" % X_test.shape) + logging.info("y_test: %s" % y_test.shape) + + X_train = np.asarray(X_train, dtype=np.float32) + X_test = np.asarray(X_test, dtype=np.float32) + y_train = 
np.asarray(y_train, dtype=np.int32) + y_test = np.asarray(y_test, dtype=np.int32) + + return X_train, y_train, X_test, y_test + + +class CIFAR10(Dataset): + """ + Load CIFAR-10 dataset. + + It consists of 60000 32x32 colour images in 10 classes, with + 6000 images per class. There are 50000 training images and 10000 test images. + + Parameters + ---------- + train_or_test : str + Must be either 'train' or 'test'. Choose the training or test dataset. + name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is ``raw_data/cifar10/``. + shape : tuple + The shape of digit images e.g. (-1, 3, 32, 32) and (-1, 32, 32, 3). + """ + def __init__(self, train_or_test, path='raw_data', name='cifar10', shape=(-1, 32, 32, 3)): + self.path = os.path.join(path, name) + + # Download and read the training and test set images and labels. + logging.info("Load or Download {0} > {1}".format(name.upper(), self.path)) + + maybe_download_and_extract(CIFAR10_FILENAME, path, CIFAR10_BASE_URL, extract=True) + + assert train_or_test in ['train', 'test'] + if train_or_test == 'train': + # Unpickle file and fill in data + self.images = None + self.labels = [] + for i in range(1, 6): + data_dic = unpickle(os.path.join(path, 'cifar-10-batches-py/', "data_batch_{}".format(i))) + if i == 1: + self.images = data_dic['data'] + else: + self.images = np.vstack((self.images, data_dic['data'])) + self.labels += data_dic['labels'] + else: + test_data_dic = unpickle(os.path.join(path, 'cifar-10-batches-py/', "test_batch")) + self.images = test_data_dic['data'] + self.labels = np.array(test_data_dic['labels']) + + self.images = np.reshape(self.images, shape) + + def __getitem__(self, index): + img = np.array(self.images[index], dtype=np.float32) + label = np.array(self.labels[index], dtype=np.int32) + return img, label + + def __len__(self): + return self.images.shape[0] diff --git a/tensorlayer/data/dataset/cyclegan.py b/tensorlayer/data/dataset/cyclegan.py new file mode 100644 index 000000000..f67da5f7b --- /dev/null +++ b/tensorlayer/data/dataset/cyclegan.py @@ -0,0 +1,133 @@ +#! /usr/bin/python +# -*- coding: utf-8 -*- + +import os + +import cv2 +import numpy as np + +from tensorlayer import logging, visualize + +from ..base import Dataset +from ..utils import maybe_download_and_extract, folder_exists, del_file, load_file_list + +__all__ = ['load_cyclegan_dataset', 'CycleGAN', 'CycleGANFiles'] + +CYCLEGAN_BASE_URL = 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/' + + +def load_cyclegan_dataset(name='cyclegan', path='raw_data', filename='summer2winter_yosemite'): + """ + Load images from CycleGAN's database, see `this link `. + + Parameters + ------------ + name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is `raw_data/cyclegan`. + filename : str + The dataset you want, see `this link `. 
+ + Examples + --------- + >>> im_train_A, im_train_B, im_test_A, im_test_B = load_cyclegan_dataset(filename='summer2winter_yosemite') + + """ + path = os.path.join(path, name) + + if folder_exists(os.path.join(path, filename)) is False: + logging.info("[*] {} is nonexistent in {}".format(filename, path)) + filepath = maybe_download_and_extract(filename=filename + '.zip', working_directory=path, + url_source=CYCLEGAN_BASE_URL, extract=True) + del_file(filepath) + + def load_image_from_folder(path): + path_imgs = load_file_list(path=path, regx='\\.jpg', printable=False) + return visualize.read_images(path_imgs, path=path, n_threads=10, printable=False) + + im_train_A = load_image_from_folder(os.path.join(path, filename, "trainA")) + im_train_B = load_image_from_folder(os.path.join(path, filename, "trainB")) + im_test_A = load_image_from_folder(os.path.join(path, filename, "testA")) + im_test_B = load_image_from_folder(os.path.join(path, filename, "testB")) + + def if_2d_to_3d(images): # [h, w] --> [h, w, 3] + for i, _v in enumerate(images): + if len(images[i].shape) == 2: + images[i] = images[i][:, :, np.newaxis] + images[i] = np.tile(images[i], (1, 1, 3)) + return images + + im_train_A = if_2d_to_3d(im_train_A) + im_train_B = if_2d_to_3d(im_train_B) + im_test_A = if_2d_to_3d(im_test_A) + im_test_B = if_2d_to_3d(im_test_B) + + return im_train_A, im_train_B, im_test_A, im_test_B + + +class CycleGANFiles(Dataset): + """ + Load image file names from CycleGAN's database, see `this link `. + + Parameters + ------------ + train_or_test : str + Must be either 'train' or 'test'. Choose the training or test dataset. + name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is `raw_data/cyclegan` + filename : str + The dataset you want, see `this link `. + """ + def __init__(self, train_or_test, name='cyclegan', path='raw_data', filename='summer2winter_yosemite'): + self.path = os.path.join(path, name) + self.train_or_test = train_or_test + + if folder_exists(os.path.join(path, filename)) is False: + logging.info("[*] {} is nonexistent in {}".format(filename, path)) + filepath = maybe_download_and_extract(filename=filename + '.zip', working_directory=path, + url_source=CYCLEGAN_BASE_URL, extract=True) + del_file(filepath) + + assert self.train_or_test in ['train', 'test'] + if self.train_or_test == 'train': + self.im_A_path = load_file_list(path=os.path.join(path, filename, "trainA"), regx='\\.jpg', printable=False) + self.im_B_path = load_file_list(path=os.path.join(path, filename, "trainB"), regx='\\.jpg', printable=False) + else: + self.im_A_path = load_file_list(path=os.path.join(path, filename, "testA"), regx='\\.jpg', printable=False) + self.im_B_path = load_file_list(path=os.path.join(path, filename, "testB"), regx='\\.jpg', printable=False) + + def __getitem__(self, index): + return self.im_A_path[index], self.im_B_path[index] + + def __len__(self): + assert len(self.im_A_path) == len(self.im_B_path) + return len(self.im_A_path) + + +class CycleGAN(CycleGANFiles): + """ + Load images from CycleGAN's database, see `this link `. + + Parameters + ------------ + train_or_test : str + Must be either 'train' or 'test'. Choose the training or test dataset. + name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is `raw_data/cyclegan` + filename : str + The dataset you want, see `this link `. 
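+ Examples
+ ---------
+ A minimal usage sketch; the first call downloads the default ``summer2winter_yosemite`` set.
+
+ >>> ds = CycleGANFiles(train_or_test='train')
+ >>> fname_A, fname_B = ds[0]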
+ """ + def __init__(self, train_or_test, name='cyclegan', path='raw_data', filename='summer2winter_yosemite'): + super(CycleGAN, self).__init__(train_or_test, name, path, filename) + + def __getitem__(self, index): + imA = cv2.imread(self.im_A_path) + imB = cv2.imread(self.im_B_path) + imA = np.array(imA, dtype=np.float32) + imB = np.array(imB, dtype=np.float32) + return imA, imB diff --git a/tensorlayer/data/dataset/flickr_1M.py b/tensorlayer/data/dataset/flickr_1M.py new file mode 100644 index 000000000..0b9387656 --- /dev/null +++ b/tensorlayer/data/dataset/flickr_1M.py @@ -0,0 +1,128 @@ +#! /usr/bin/python +# -*- coding: utf-8 -*- + +import os + +from tensorlayer import logging, visualize + +from ..utils import del_file, folder_exists, load_file_list, load_folder_list, maybe_download_and_extract, read_file + +__all__ = ['load_flickr1M_dataset'] + +IMAGES_ZIP = [ + 'images0.zip', 'images1.zip', 'images2.zip', 'images3.zip', 'images4.zip', 'images5.zip', 'images6.zip', + 'images7.zip', 'images8.zip', 'images9.zip' +] +TAG_ZIP = 'tags.zip' +FLICKR1M_BASE_URL = 'http://press.liacs.nl/mirflickr/mirflickr1m/' + + +def load_flickr1M_dataset(tag='sky', size=10, path="data", n_threads=50, printable=False): + """Load Flick1M dataset. + + Returns a list of images by a given tag from Flickr1M dataset, + it will download Flickr1M from `the official website `__ + at the first time you use it. + + Parameters + ------------ + tag : str or None + What images to return. + - If you want to get images with tag, use string like 'dog', 'red', see `Flickr Search `__. + - If you want to get all images, set to ``None``. + + size : int + integer between 1 to 10. 1 means 100k images ... 5 means 500k images, 10 means all 1 million images. Default is 10. + path : str + The path that the data is downloaded to, defaults is ``data/flickr25k/``. + n_threads : int + The number of thread to read image. + printable : boolean + Whether to print infomation when reading images, default is ``False``. + + Examples + ---------- + Use 200k images + + >>> images = tl.files.load_flickr1M_dataset(tag='zebra', size=2) + + Use 1 Million images + + >>> images = tl.files.load_flickr1M_dataset(tag='zebra') + + """ + import shutil + + path = os.path.join(path, 'flickr1M') + logging.info("[Flickr1M] using {}% of images = {}".format(size * 10, size * 100000)) + + # download dataset + for image_zip in IMAGES_ZIP[0:size]: + image_folder = image_zip.split(".")[0] + # logging.info(path+"/"+image_folder) + if folder_exists(os.path.join(path, image_folder)) is False: + # logging.info(image_zip) + logging.info("[Flickr1M] {} is missing in {}".format(image_folder, path)) + maybe_download_and_extract(image_zip, path, FLICKR1M_BASE_URL+image_zip, extract=True) + del_file(os.path.join(path, image_zip)) + # os.system("mv {} {}".format(os.path.join(path, 'images'), os.path.join(path, image_folder))) + shutil.move(os.path.join(path, 'images'), os.path.join(path, image_folder)) + else: + logging.info("[Flickr1M] {} exists in {}".format(image_folder, path)) + + # download tag + if folder_exists(os.path.join(path, "tags")) is False: + logging.info("[Flickr1M] tag files is nonexistent in {}".format(path)) + maybe_download_and_extract(TAG_ZIP, path, FLICKR1M_BASE_URL+TAG_ZIP, extract=True) + del_file(os.path.join(path, TAG_ZIP)) + else: + logging.info("[Flickr1M] tags exists in {}".format(path)) + + # 1. 
image path list + images_list = [] + images_folder_list = [] + for i in range(0, size): + images_folder_list += load_folder_list(path=os.path.join(path, 'images%d' % i)) + images_folder_list.sort(key=lambda s: int(s.split('/')[-1])) # folder/images/ddd + + for folder in images_folder_list[0:size * 10]: + tmp = load_file_list(path=folder, regx='\\.jpg', printable=False) + tmp.sort(key=lambda s: int(s.split('.')[-2])) # ddd.jpg + images_list.extend([os.path.join(folder, x) for x in tmp]) + + # 2. tag path list + tag_list = [] + tag_folder_list = load_folder_list(os.path.join(path, "tags")) + + # tag_folder_list.sort(key=lambda s: int(s.split("/")[-1])) # folder/images/ddd + tag_folder_list.sort(key=lambda s: int(os.path.basename(s))) + + for folder in tag_folder_list[0:size * 10]: + tmp = load_file_list(path=folder, regx='\\.txt', printable=False) + tmp.sort(key=lambda s: int(s.split('.')[-2])) # ddd.txt + tmp = [os.path.join(folder, s) for s in tmp] + tag_list += tmp + + # 3. select images + logging.info("[Flickr1M] searching tag: {}".format(tag)) + select_images_list = [] + for idx, _val in enumerate(tag_list): + tags = read_file(tag_list[idx]).split('\n') + if tag in tags: + select_images_list.append(images_list[idx]) + + logging.info("[Flickr1M] reading images with tag: {}".format(tag)) + images = visualize.read_images(select_images_list, '', n_threads=n_threads, printable=printable) + return images + + +# class Flickr1M(Dataset): +# +# def __init__(self): +# pass +# +# def __getitem__(self, index): +# pass +# +# def __len__(self): +# pass diff --git a/tensorlayer/data/dataset/flickr_25k.py b/tensorlayer/data/dataset/flickr_25k.py new file mode 100644 index 000000000..8049a0653 --- /dev/null +++ b/tensorlayer/data/dataset/flickr_25k.py @@ -0,0 +1,81 @@ +#! /usr/bin/python +# -*- coding: utf-8 -*- + +import os + +from tensorlayer import logging, visualize +from tensorlayer.files.utils import ( + del_file, folder_exists, load_file_list, maybe_download_and_extract, natural_keys, read_file +) + +__all__ = ['load_flickr25k_dataset'] + + +def load_flickr25k_dataset(tag='sky', path="data", n_threads=50, printable=False): + """Load Flickr25K dataset. + + Returns a list of images by a given tag from Flick25k dataset, + it will download Flickr25k from `the official website `__ + at the first time you use it. + + Parameters + ------------ + tag : str or None + What images to return. + - If you want to get images with tag, use string like 'dog', 'red', see `Flickr Search `__. + - If you want to get all images, set to ``None``. + + path : str + The path that the data is downloaded to, defaults is ``data/flickr25k/``. + n_threads : int + The number of thread to read image. + printable : boolean + Whether to print infomation when reading images, default is ``False``. + + Examples + ----------- + Get images with tag of sky + + >>> images = tl.files.load_flickr25k_dataset(tag='sky') + + Get all images + + >>> images = tl.files.load_flickr25k_dataset(tag=None, n_threads=100, printable=True) + + """ + path = os.path.join(path, 'flickr25k') + + filename = 'mirflickr25k.zip' + url = 'http://press.liacs.nl/mirflickr/mirflickr25k/' + + # download dataset + if folder_exists(os.path.join(path, "mirflickr")) is False: + logging.info("[*] Flickr25k is nonexistent in {}".format(path)) + maybe_download_and_extract(filename, path, url, extract=True) + del_file(os.path.join(path, filename)) + + # return images by the given tag. + # 1. 
image path list + folder_imgs = os.path.join(path, "mirflickr") + path_imgs = load_file_list(path=folder_imgs, regx='\\.jpg', printable=False) + path_imgs.sort(key=natural_keys) + + # 2. tag path list + folder_tags = os.path.join(path, "mirflickr", "meta", "tags") + path_tags = load_file_list(path=folder_tags, regx='\\.txt', printable=False) + path_tags.sort(key=natural_keys) + + # 3. select images + if tag is None: + logging.info("[Flickr25k] reading all images") + else: + logging.info("[Flickr25k] reading images with tag: {}".format(tag)) + images_list = [] + for idx, _v in enumerate(path_tags): + tags = read_file(os.path.join(folder_tags, path_tags[idx])).split('\n') + # logging.info(idx+1, tags) + if tag is None or tag in tags: + images_list.append(path_imgs[idx]) + + images = visualize.read_images(images_list, folder_imgs, n_threads=n_threads, printable=printable) + return images diff --git a/tensorlayer/data/dataset/ilsvrc.py b/tensorlayer/data/dataset/ilsvrc.py new file mode 100644 index 000000000..5cd1b0ec3 --- /dev/null +++ b/tensorlayer/data/dataset/ilsvrc.py @@ -0,0 +1,191 @@ +import numpy as np +import os +import logging +import cv2 + +from ..base import Dataset +from ..utils import maybe_download_and_extract + +__all__ = ['ILSVRCMeta', 'ILSVRC12', 'ILSVRC12Files'] + +CAFFE_ILSVRC12_META_BASE_URL = 'http://dl.caffe.berkeleyvision.org/' +CAFFE_ILSVRC12_META_FILENAME = 'caffe_ilsvrc12.tar.gz' + + +class ILSVRCMeta(object): + """ + Provide methods to access metadata for ILSVRC dataset. + Metadata is supposed to be found at/will be downloaded to 'path/name/' + + Parameters + ---------- + name : str + The name of the dataset + path : str + The path that the data is downloaded to, defaults is `raw_data/ilsvrc` + + Examples + -------- + >>> meta = ILSVRCMeta(path='raw_data', name='ilsvrc') + >>> imglist = meta.get_image_list(train_or_val_or_test, dir_structure) + + """ + + def __init__(self, name='ilsvrc', path='raw_data'): + path = os.path.expanduser(path) + self.path = os.path.join(path, name) + logging.info("Load or Download {0} > {1}".format(name.upper(), self.path)) + self.filepath = maybe_download_and_extract(CAFFE_ILSVRC12_META_FILENAME, self.path, CAFFE_ILSVRC12_META_BASE_URL, extract=True) + self.caffepb = None + + def get_synset_words_1000(self): + """ + Returns: + dict: {cls_number: cls_name} + """ + fname = os.path.join(self.path, 'synset_words.txt') + assert os.path.isfile(fname), fname + lines = [x.strip() for x in open(fname).readlines()] + return dict(enumerate(lines)) + + def get_synset_1000(self): + """ + Returns: + dict: {cls_number: synset_id} + """ + fname = os.path.join(self.path, 'synsets.txt') + assert os.path.isfile(fname) + lines = [x.strip() for x in open(fname).readlines()] + return dict(enumerate(lines)) + + def get_image_list(self, name): + """ + Args: + name (str): 'train' or 'val' or 'test' + Returns: + list: list of (image filename, label) + """ + assert name in ['train', 'val', 'test'] + + fname = os.path.join(self.path, name + '.txt') + assert os.path.isfile(fname), fname + with open(fname) as f: + ret = [] + for line in f.readlines(): + name, cls = line.strip().split() + cls = int(cls) + ret.append((name.strip(), cls)) + assert len(ret), fname + return ret + + +class ILSVRC12Files(Dataset): + """ + Load ILSVRC12 dataset. Produce filenames of images and their corresponding labels. + Labels are between [0, 999]. + + Parameters + ----------- + train_or_test_or_val : str + Must be either 'train' or 'test' or 'val'. 
Choose the training or test or validation dataset. + meta_dir : str + The path that the metadata is located. Will automatically download and extract if it is not found. + path : str + The path of the ILSVRC12 dataset. + + + The dataset should have the structure: + --------------------------------------- + path/ + train/ + n02134418/ + n02134418_198.JPEG + ... + ... + val/ + ILSVRC2012_val_00000001.JPEG + ... + test/ + ILSVRC2012_test_00000001.JPEG + ... + --------------------------------------- + With the downloaded ILSVRC12_img_*.tar, you can use the following + command to build the above structure: + + mkdir val && tar xvf ILSVRC12_img_val.tar -C val + mkdir test && tar xvf ILSVRC12_img_test.tar -C test + mkdir train && tar xvf ILSVRC12_img_train.tar -C train && cd train + find -type f -name '*.tar' | parallel -P 10 'echo {} && mkdir -p {/.} && tar xf {} -C {/.}' + """ + def __init__(self, train_or_test_or_val, meta_dir, path): + """ + Same as in :class:`ILSVRC12`. + """ + assert train_or_test_or_val in ['train', 'test', 'val'] + path = os.path.expanduser(path) + assert os.path.isdir(path) + self.full_path = os.path.join(path, train_or_test_or_val) + self.path = train_or_test_or_val + + meta = ILSVRCMeta(path=meta_dir) + self.imglist = meta.get_image_list(train_or_test_or_val) + + def __len__(self): + return len(self.imglist) + + def __getitem__(self, index): + fname, label = self.imglist[index] + fname = os.path.join(self.full_path, fname) + return fname, label + + +class ILSVRC12(ILSVRC12Files): + """ + Load ILSVRC12 dataset. Produce images and a label between [0, 999]. + + Parameters + ----------- + train_or_test_or_val : str + Must be either 'train' or 'test' or 'val'. Choose the training or test or validation dataset. + meta_dir : str + The path that the metadata is located. Will automatically download and extract if it is not found. + path : str + The path of the ILSVRC12 dataset. + shape : tuple + When shape is None, return the original image. If set, return the resized image. + + + The dataset should have the structure: + --------------------------------------- + path/ + train/ + n02134418/ + n02134418_198.JPEG + ... + ... + val/ + ILSVRC2012_val_00000001.JPEG + ... + test/ + ILSVRC2012_test_00000001.JPEG + ... + --------------------------------------- + With the downloaded ILSVRC12_img_*.tar, you can use the following + command to build the above structure: + + mkdir val && tar xvf ILSVRC12_img_val.tar -C val + mkdir test && tar xvf ILSVRC12_img_test.tar -C test + mkdir train && tar xvf ILSVRC12_img_train.tar -C train && cd train + find -type f -name '*.tar' | parallel -P 10 'echo {} && mkdir -p {/.} && tar xf {} -C {/.}' + """ + def __init__(self, train_or_test_or_val, meta_dir, path, shape=None): + super(ILSVRC12, self).__init__(train_or_test_or_val, meta_dir, path) + self.shape = shape + + def __getitem__(self, index): + fname, label = super(ILSVRC12, self).__getitem__(index) + img = cv2.imread(fname, cv2.IMREAD_COLOR) + if self.shape is not None: + img = cv2.resize(img, self.shape) + img = np.array(img, dtype=np.float32) + return img, label diff --git a/tensorlayer/data/dataset/imdb.py b/tensorlayer/data/dataset/imdb.py new file mode 100644 index 000000000..aa1fe31f1 --- /dev/null +++ b/tensorlayer/data/dataset/imdb.py @@ -0,0 +1,159 @@ +#! 
/usr/bin/python +# -*- coding: utf-8 -*- + +import os + +import numpy as np +import six.moves.cPickle as pickle + +from ..base import Dataset +from ..utils import maybe_download_and_extract + +__all__ = ['load_imdb_dataset', 'IMDB'] + +IMDB_BASE_URL = 'https://s3.amazonaws.com/text-datasets/' +IMDB_FILENAME = 'imdb.pkl' + + +def load_imdb_dataset( + name='imdb', path='raw_data', nb_words=None, skip_top=0, maxlen=None, test_split=0.2, seed=113, start_char=1, + oov_char=2, index_from=3): + """ + Load IMDB dataset. + + Parameters + ---------- + name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is ``raw_data/imdb/``. + nb_words : int + Number of words to get. + skip_top : int + Top most frequent words to ignore (they will appear as oov_char value in the sequence data). + maxlen : int + Maximum sequence length. Any longer sequence will be truncated. + test_split : float + Split of train / test dataset. + seed : int + Seed for reproducible data shuffling. + start_char : int + The start of a sequence will be marked with this character. Set to 1 because 0 is usually the padding character. + oov_char : int + Words that were cut out because of the num_words or skip_top limit will be replaced with this character. + index_from : int + Index actual words with this index and higher. + + Examples + -------- + >>> X_train, y_train, X_test, y_test = load_imdb_dataset(nb_words=20000, test_split=0.2) + >>> print('X_train.shape', X_train.shape) + (20000,) [[1, 62, 74, ... 1033, 507, 27],[1, 60, 33, ... 13, 1053, 7]..] + >>> print('y_train.shape', y_train.shape) + (20000,) [1 0 0 ..., 1 0 1] + + References + ----------- + - `Modified from keras. `__ + + """ + X, labels = _load_raw_imdb(path, name) + + np.random.seed(seed) + np.random.shuffle(X) + np.random.seed(seed) + np.random.shuffle(labels) + + X, labels = _preprocess_imdb(X, index_from, labels, maxlen, nb_words, oov_char, skip_top, start_char) + + X_train = np.array(X[:int(len(X) * (1 - test_split))]) + y_train = np.array(labels[:int(len(X) * (1 - test_split))]) + + X_test = np.array(X[int(len(X) * (1 - test_split)):]) + y_test = np.array(labels[int(len(X) * (1 - test_split)):]) + + return X_train, y_train, X_test, y_test + + +def _preprocess_imdb(X, index_from, labels, maxlen, nb_words, oov_char, skip_top, start_char): + if start_char is not None: + X = [[start_char] + [w + index_from for w in x] for x in X] + elif index_from: + X = [[w + index_from for w in x] for x in X] + if maxlen: + new_X = [] + new_labels = [] + for x, y in zip(X, labels): + if len(x) < maxlen: + new_X.append(x) + new_labels.append(y) + X = new_X + labels = new_labels + if not X: + raise Exception( + 'After filtering for sequences shorter than maxlen=' + str(maxlen) + ', no sequence was kept. ' + 'Increase maxlen.' 
+ ) + if not nb_words: + nb_words = max([max(x) for x in X]) + # by convention, use 2 as OOV word + # reserve 'index_from' (=3 by default) characters: 0 (padding), 1 (start), 2 (OOV) + if oov_char is not None: + X = [[oov_char if (w >= nb_words or w < skip_top) else w for w in x] for x in X] + else: + nX = [] + for x in X: + nx = [] + for w in x: + if (w >= nb_words or w < skip_top): + nx.append(w) + nX.append(nx) + X = nX + return X, labels + + +def _load_raw_imdb(path, name): + path = os.path.join(path, name) + maybe_download_and_extract(IMDB_FILENAME, path, IMDB_BASE_URL) + f = open(os.path.join(path, IMDB_FILENAME), 'rb') + X, labels = pickle.load(f) + f.close() + return X, labels + + +class IMDB(Dataset): + """ + Load IMDB dataset. + + Parameters + ---------- + name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is ``raw_data/imdb/``. + nb_words : int + Number of words to get. + skip_top : int + Top most frequent words to ignore (they will appear as oov_char value in the sequence data). + maxlen : int + Maximum sequence length. Any longer sequence will be truncated. + start_char : int + The start of a sequence will be marked with this character. Set to 1 because 0 is usually the padding character. + oov_char : int + Words that were cut out because of the num_words or skip_top limit will be replaced with this character. + index_from : int + Index actual words with this index and higher. + """ + + def __init__(self, name='imdb', path='raw_data', nb_words=None, skip_top=0, maxlen=None, start_char=1, oov_char=2, + index_from=3): + self.X, self.labels = _load_raw_imdb(path=path, name=name) + self.X, self.labels = _preprocess_imdb(self.X, index_from, self.labels, maxlen, nb_words, oov_char, skip_top, + start_char) + + def __getitem__(self, index): + return self.X[index], self.labels[index] + + def __len__(self): + assert len(self.X) == len(self.labels) + return len(self.labels) diff --git a/tensorlayer/data/dataset/matt_mahoney.py b/tensorlayer/data/dataset/matt_mahoney.py new file mode 100644 index 000000000..7f03b9ef3 --- /dev/null +++ b/tensorlayer/data/dataset/matt_mahoney.py @@ -0,0 +1,76 @@ +#! /usr/bin/python +# -*- coding: utf-8 -*- + +import os +import zipfile + +from tensorlayer import logging + +from ..base import Dataset +from ..utils import maybe_download_and_extract + +__all__ = ['load_matt_mahoney_text8_dataset', 'MattMahoney'] + +MATT_MAHONEY_BASE_URL = 'http://mattmahoney.net/dc/' +MATT_MAHONEY_FILENAME = 'text8.zip' + + +def load_matt_mahoney_text8_dataset(name='mm_test8', path='raw_data'): + """ + Load Matt Mahoney's dataset. + + Download a text file from Matt Mahoney's website + if not present, and make sure it's the right size. + Extract the first file enclosed in a zip file as a list of words. + This dataset can be used for Word Embedding. + + Parameters + ---------- + name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is ``raw_data/mm_test8/``. + + Returns + -------- + list of str + The raw text data e.g. [.... 'their', 'families', 'who', 'were', 'expelled', 'from', 'jerusalem', ...] 
+ + Examples + -------- + >>> words = load_matt_mahoney_text8_dataset() + >>> print('Data size', len(words)) + + """ + path = os.path.join(path, name) + logging.info("Load or Download matt_mahoney_text8 Dataset> {}".format(path)) + + maybe_download_and_extract(MATT_MAHONEY_FILENAME, path, MATT_MAHONEY_BASE_URL, expected_bytes=31344016) + + with zipfile.ZipFile(os.path.join(path, MATT_MAHONEY_FILENAME)) as f: + word_list = f.read(f.namelist()[0]).split() + for idx, _ in enumerate(word_list): + word_list[idx] = word_list[idx].decode() + return word_list + + +class MattMahoney(Dataset): + """ + Load Matt Mahoney's dataset. + + Parameters + ---------- + name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is ``raw_data/mm_test8/``. + """ + + def __init__(self, name='mm_test8', path='raw_data'): + self.word_list = load_matt_mahoney_text8_dataset(path=path, name=name) + + def __getitem__(self, index): + return self.word_list[index] + + def __len__(self): + return len(self.word_list) diff --git a/tensorlayer/data/dataset/mnist.py b/tensorlayer/data/dataset/mnist.py new file mode 100644 index 000000000..02df262ff --- /dev/null +++ b/tensorlayer/data/dataset/mnist.py @@ -0,0 +1,115 @@ +import gzip +import logging +import os +import numpy as np + +from ..base import Dataset +from ..utils import maybe_download_and_extract + +__all__ = ['MNIST', 'load_mnist_dataset'] + +MNIST_BASE_URL = 'http://yann.lecun.com/exdb/mnist/' +MNIST_TRAIN_IMAGE_FILENAME = 'train-images-idx3-ubyte.gz' +MNIST_TRAIN_LABEL_FILENAME = 'train-labels-idx1-ubyte.gz' +MNIST_TEST_IMAGE_FILENAME = 't10k-images-idx3-ubyte.gz' +MNIST_TEST_LABEL_FILENAME = 't10k-labels-idx1-ubyte.gz' + + +def _load_mnist_images(name, url, path, shape): + filepath = maybe_download_and_extract(name, path, url) + + logging.info(filepath) + # Read the inputs in Yann LeCun's binary format. + with gzip.open(filepath, 'rb') as f: + data = np.frombuffer(f.read(), np.uint8, offset=16) + # The inputs are vectors now, we reshape them to monochrome 2D images, + # following the shape convention: (examples, channels, rows, columns) + data = data.reshape(shape) + # The inputs come as bytes, we convert them to float32 in range [0,1]. + # (Actually to range [0, 255/256], for compatibility to the version + # provided at http://deeplearning.net/data/mnist/mnist.pkl.gz.) + return data / np.float32(256) + + +def _load_mnist_labels(name, url, path): + filepath = maybe_download_and_extract(name, path, url) + + # Read the labels in Yann LeCun's binary format. + with gzip.open(filepath, 'rb') as f: + data = np.frombuffer(f.read(), np.uint8, offset=8) + # The labels are vectors of integers now, that's exactly what we want. + return data + + +def load_mnist_dataset(shape=(-1, 784), name='mnist', path='raw_data'): + """ + A generic function to load mnist-like dataset. + + Parameters: + ---------- + shape : tuple + The shape of digit images. + name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is ``raw_data/mnist/``. + """ + path = os.path.join(path, name) + + # Download and read the training and test set images and labels. 
+ logging.info("Load or Download {0} > {1}".format(name.upper(), path)) + X_train = _load_mnist_images(name=MNIST_TRAIN_IMAGE_FILENAME, url=MNIST_BASE_URL, path=path, shape=shape) + y_train = _load_mnist_labels(name=MNIST_TRAIN_LABEL_FILENAME, url=MNIST_BASE_URL, path=path) + X_test = _load_mnist_images(name=MNIST_TEST_IMAGE_FILENAME, url=MNIST_BASE_URL, path=path, shape=shape) + y_test = _load_mnist_labels(name=MNIST_TEST_LABEL_FILENAME, url=MNIST_BASE_URL, path=path) + + # We reserve the last 10000 training examples for validation. + X_train, X_val = X_train[:-10000], X_train[-10000:] + y_train, y_val = y_train[:-10000], y_train[-10000:] + + # We just return all the arrays in order, as expected in main(). + # (It doesn't matter how we do this as long as we can read them again.) + X_train = np.asarray(X_train, dtype=np.float32) + y_train = np.asarray(y_train, dtype=np.int32) + X_val = np.asarray(X_val, dtype=np.float32) + y_val = np.asarray(y_val, dtype=np.int32) + X_test = np.asarray(X_test, dtype=np.float32) + y_test = np.asarray(y_test, dtype=np.int32) + return X_train, y_train, X_val, y_val, X_test, y_test + + +class MNIST(Dataset): + """ + Load MNIST dataset. + + Parameters: + ---------- + train_or_test : str + Must be either 'train' or 'test'. Choose the training or test dataset. + shape : tuple + The shape of digit images. + name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is ``raw_data/mnist/``. + """ + def __init__(self, train_or_test, path='raw_data', name='mnist', shape=(-1, 784)): + path = os.path.expanduser(path) + self.path = os.path.join(path, name) + + assert train_or_test in ['train', 'test'] + self.train_or_test = train_or_test + if train_or_test == 'train': + self.images = _load_mnist_images(name=MNIST_TRAIN_IMAGE_FILENAME, url=MNIST_BASE_URL, path=path, + shape=shape) + self.labels = _load_mnist_labels(name=MNIST_TRAIN_LABEL_FILENAME, url=MNIST_BASE_URL, path=path) + else: + self.images = _load_mnist_images(name=MNIST_TEST_IMAGE_FILENAME, url=MNIST_BASE_URL, path=path, + shape=shape) + self.labels = _load_mnist_labels(name=MNIST_TEST_LABEL_FILENAME, url=MNIST_BASE_URL, path=path) + + def __len__(self): + return self.images.shape[0] + + def __getitem__(self, index): + return self.images[index], self.labels[index] diff --git a/tensorlayer/data/dataset/mnist_fashion.py b/tensorlayer/data/dataset/mnist_fashion.py new file mode 100644 index 000000000..d0025563e --- /dev/null +++ b/tensorlayer/data/dataset/mnist_fashion.py @@ -0,0 +1,108 @@ +import logging +import os +import numpy as np + +from ..base import Dataset +from .mnist import _load_mnist_images, _load_mnist_labels + +__all__ = ['load_fashion_mnist_dataset', 'FASHION_MNIST'] + +FASHION_MNIST_BASE_URL = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/' +FASHION_MNIST_TRAIN_IMAGE_FILENAME = 'train-images-idx3-ubyte.gz' +FASHION_MNIST_TRAIN_LABEL_FILENAME = 'train-labels-idx1-ubyte.gz' +FASHION_MNIST_TEST_IMAGE_FILENAME = 't10k-images-idx3-ubyte.gz' +FASHION_MNIST_TEST_LABEL_FILENAME = 't10k-labels-idx1-ubyte.gz' + + +def load_fashion_mnist_dataset(shape=(-1, 784), name='fashion_mnist', path='raw_data'): + """ + Load the fashion mnist. + + Automatically download fashion-MNIST dataset and return the training, validation and test set with 50000, 10000 and 10000 fashion images respectively, `examples `__. + + Parameters + ---------- + shape : tuple + The shape of digit images (the default is (-1, 784), alternatively (-1, 28, 28, 1)). 
+ name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is ``raw_data/fashion_mnist/``. + + Returns + ------- + X_train, y_train, X_val, y_val, X_test, y_test: tuple + Return splitted training/validation/test set respectively. + + Examples + -------- + >>> X_train, y_train, X_val, y_val, X_test, y_test = load_fashion_mnist_dataset(shape=(-1,784), path='datasets') + >>> X_train, y_train, X_val, y_val, X_test, y_test = load_fashion_mnist_dataset(shape=(-1, 28, 28, 1)) + """ + path = os.path.join(path, name) + + # Download and read the training and test set images and labels. + logging.info("Load or Download {0} > {1}".format(name.upper(), path)) + X_train = _load_mnist_images(name=FASHION_MNIST_TRAIN_IMAGE_FILENAME, url=FASHION_MNIST_BASE_URL, path=path, + shape=shape) + y_train = _load_mnist_labels(name=FASHION_MNIST_TRAIN_LABEL_FILENAME, url=FASHION_MNIST_BASE_URL, path=path) + X_test = _load_mnist_images(name=FASHION_MNIST_TEST_IMAGE_FILENAME, url=FASHION_MNIST_BASE_URL, path=path, + shape=shape) + y_test = _load_mnist_labels(name=FASHION_MNIST_TEST_LABEL_FILENAME, url=FASHION_MNIST_BASE_URL, path=path) + + # We reserve the last 10000 training examples for validation. + X_train, X_val = X_train[:-10000], X_train[-10000:] + y_train, y_val = y_train[:-10000], y_train[-10000:] + + # We just return all the arrays in order, as expected in main(). + # (It doesn't matter how we do this as long as we can read them again.) + X_train = np.asarray(X_train, dtype=np.float32) + y_train = np.asarray(y_train, dtype=np.int32) + X_val = np.asarray(X_val, dtype=np.float32) + y_val = np.asarray(y_val, dtype=np.int32) + X_test = np.asarray(X_test, dtype=np.float32) + y_test = np.asarray(y_test, dtype=np.int32) + return X_train, y_train, X_val, y_val, X_test, y_test + + +class FASHION_MNIST(Dataset): + """ + Load the fashion mnist. + + Automatically download fashion-MNIST dataset and return the training, validation and test set with 50000, 10000 and 10000 fashion images respectively, `examples `__. + + Parameters + ---------- + train_or_test : str + Must be either 'train' or 'test'. Choose the training or test dataset. + shape : tuple + The shape of digit images (the default is (-1, 784), alternatively (-1, 28, 28, 1)). + name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is ``raw_data/fashion_mnist/``. + """ + + def __init__(self, train_or_test, name='fashion_mnist', path='raw_data', shape=(-1, 784)): + path = os.path.expanduser(path) + self.path = os.path.join(path, name) + + assert train_or_test in ['train', 'test'] + if train_or_test == 'train': + self.images = _load_mnist_images(name=FASHION_MNIST_TRAIN_IMAGE_FILENAME, url=FASHION_MNIST_BASE_URL, + path=path, + shape=shape) + self.labels = _load_mnist_labels(name=FASHION_MNIST_TRAIN_LABEL_FILENAME, url=FASHION_MNIST_BASE_URL, + path=path) + else: + self.images = _load_mnist_images(name=FASHION_MNIST_TEST_IMAGE_FILENAME, url=FASHION_MNIST_BASE_URL, + path=path, + shape=shape) + self.labels = _load_mnist_labels(name=FASHION_MNIST_TEST_LABEL_FILENAME, url=FASHION_MNIST_BASE_URL, + path=path) + + def __len__(self): + return self.images.shape[0] + + def __getitem__(self, index): + return self.images[index], self.labels[index] diff --git a/tensorlayer/data/dataset/mpii.py b/tensorlayer/data/dataset/mpii.py new file mode 100644 index 000000000..85faec09d --- /dev/null +++ b/tensorlayer/data/dataset/mpii.py @@ -0,0 +1,297 @@ +#! 
/usr/bin/python +# -*- coding: utf-8 -*- + +import os + +from tensorlayer import logging + +from ..base import Dataset +from ..utils import del_file, folder_exists, load_file_list, maybe_download_and_extract + +__all__ = ['load_mpii_pose_dataset', 'MPII'] + +MPII_BASE_URL = "http://datasets.d2.mpi-inf.mpg.de/andriluka14cvpr/" + + +def load_mpii_pose_dataset(name='mpii_human_pose', path='raw_data', is_16_pos_only=False): + """ + Load MPII Human Pose Dataset. + + Parameters + ----------- + name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is `raw_data/mpii_human_pose`. + is_16_pos_only : boolean + If True, only return the peoples contain 16 pose keypoints. (Usually be used for single person pose estimation) + + Returns + ---------- + img_train_list : list of str + The image directories of training data. + ann_train_list : list of dict + The annotations of training data. + img_test_list : list of str + The image directories of testing data. + ann_test_list : list of dict + The annotations of testing data. + + Examples + -------- + >>> import pprint + >>> import tensorlayer as tl + >>> img_train_list, ann_train_list, img_test_list, ann_test_list = load_mpii_pose_dataset() + >>> image = tl.vis.read_image(img_train_list[0]) + >>> tl.vis.draw_mpii_pose_to_image(image, ann_train_list[0], 'image.png') + >>> pprint.pprint(ann_train_list[0]) + + References + ----------- + - `MPII Human Pose Dataset. CVPR 14 `__ + - `MPII Human Pose Models. CVPR 16 `__ + - `MPII Human Shape, Poselet Conditioned Pictorial Structures and etc `__ + - `MPII Keyponts and ID `__ + """ + path = os.path.join(path, name) + logging.info("Load or Download MPII Human Pose > {}".format(path)) + + # annotation + tar_filename = "mpii_human_pose_v1_u12_2.zip" + extracted_filename = "mpii_human_pose_v1_u12_2" + if folder_exists(os.path.join(path, extracted_filename)) is False: + logging.info("[MPII] (annotation) {} is nonexistent in {}".format(extracted_filename, path)) + maybe_download_and_extract(tar_filename, path, MPII_BASE_URL+tar_filename, extract=True) + del_file(os.path.join(path, tar_filename)) + + # images + tar_filename = "mpii_human_pose_v1.tar.gz" + extracted_filename2 = "images" + if folder_exists(os.path.join(path, extracted_filename2)) is False: + logging.info("[MPII] (images) {} is nonexistent in {}".format(extracted_filename, path)) + maybe_download_and_extract(tar_filename, path, MPII_BASE_URL+tar_filename, extract=True) + del_file(os.path.join(path, tar_filename)) + + # parse annotation, format see http://human-pose.mpi-inf.mpg.de/#download + import scipy.io as sio + logging.info("reading annotations from mat file ...") + # mat = sio.loadmat(os.path.join(path, extracted_filename, "mpii_human_pose_v1_u12_1.mat")) + + # def fix_wrong_joints(joint): # https://github.com/mitmul/deeppose/blob/master/datasets/mpii_dataset.py + # if '12' in joint and '13' in joint and '2' in joint and '3' in joint: + # if ((joint['12'][0] < joint['13'][0]) and + # (joint['3'][0] < joint['2'][0])): + # joint['2'], joint['3'] = joint['3'], joint['2'] + # if ((joint['12'][0] > joint['13'][0]) and + # (joint['3'][0] > joint['2'][0])): + # joint['2'], joint['3'] = joint['3'], joint['2'] + # return joint + + ann_train_list = [] + ann_test_list = [] + img_train_list = [] + img_test_list = [] + + def save_joints(): + # joint_data_fn = os.path.join(path, 'data.json') + # fp = open(joint_data_fn, 'w') + mat = sio.loadmat(os.path.join(path, extracted_filename, 
"mpii_human_pose_v1_u12_1.mat")) + + for _, (anno, train_flag) in enumerate( # all images + zip(mat['RELEASE']['annolist'][0, 0][0], mat['RELEASE']['img_train'][0, 0][0])): + + img_fn = anno['image']['name'][0, 0][0] + train_flag = int(train_flag) + + # print(i, img_fn, train_flag) # DEBUG print all images + + if train_flag: + img_train_list.append(img_fn) + ann_train_list.append([]) + else: + img_test_list.append(img_fn) + ann_test_list.append([]) + + head_rect = [] + if 'x1' in str(anno['annorect'].dtype): + head_rect = zip( + [x1[0, 0] for x1 in anno['annorect']['x1'][0]], [y1[0, 0] for y1 in anno['annorect']['y1'][0]], + [x2[0, 0] for x2 in anno['annorect']['x2'][0]], [y2[0, 0] for y2 in anno['annorect']['y2'][0]] + ) + else: + head_rect = [] # TODO + + if 'annopoints' in str(anno['annorect'].dtype): + annopoints = anno['annorect']['annopoints'][0] + head_x1s = anno['annorect']['x1'][0] + head_y1s = anno['annorect']['y1'][0] + head_x2s = anno['annorect']['x2'][0] + head_y2s = anno['annorect']['y2'][0] + + for annopoint, head_x1, head_y1, head_x2, head_y2 in zip(annopoints, head_x1s, head_y1s, head_x2s, + head_y2s): + # if annopoint != []: + # if len(annopoint) != 0: + if annopoint.size: + head_rect = [ + float(head_x1[0, 0]), + float(head_y1[0, 0]), + float(head_x2[0, 0]), + float(head_y2[0, 0]) + ] + + # joint coordinates + annopoint = annopoint['point'][0, 0] + j_id = [str(j_i[0, 0]) for j_i in annopoint['id'][0]] + x = [x[0, 0] for x in annopoint['x'][0]] + y = [y[0, 0] for y in annopoint['y'][0]] + joint_pos = {} + for _j_id, (_x, _y) in zip(j_id, zip(x, y)): + joint_pos[int(_j_id)] = [float(_x), float(_y)] + # joint_pos = fix_wrong_joints(joint_pos) + + # visibility list + if 'is_visible' in str(annopoint.dtype): + vis = [v[0] if v.size > 0 else [0] for v in annopoint['is_visible'][0]] + vis = dict([(k, int(v[0])) if len(v) > 0 else v for k, v in zip(j_id, vis)]) + else: + vis = None + + # if len(joint_pos) == 16: + if ((is_16_pos_only ==True) and (len(joint_pos) == 16)) or (is_16_pos_only == False): + # only use image with 16 key points / or use all + data = { + 'filename': img_fn, + 'train': train_flag, + 'head_rect': head_rect, + 'is_visible': vis, + 'joint_pos': joint_pos + } + # print(json.dumps(data), file=fp) # py3 + if train_flag: + ann_train_list[-1].append(data) + else: + ann_test_list[-1].append(data) + + # def write_line(datum, fp): + # joints = sorted([[int(k), v] for k, v in datum['joint_pos'].items()]) + # joints = np.array([j for i, j in joints]).flatten() + # + # out = [datum['filename']] + # out.extend(joints) + # out = [str(o) for o in out] + # out = ','.join(out) + # + # print(out, file=fp) + + # def split_train_test(): + # # fp_test = open('data/mpii/test_joints.csv', 'w') + # fp_test = open(os.path.join(path, 'test_joints.csv'), 'w') + # # fp_train = open('data/mpii/train_joints.csv', 'w') + # fp_train = open(os.path.join(path, 'train_joints.csv'), 'w') + # # all_data = open('data/mpii/data.json').readlines() + # all_data = open(os.path.join(path, 'data.json')).readlines() + # N = len(all_data) + # N_test = int(N * 0.1) + # N_train = N - N_test + # + # print('N:{}'.format(N)) + # print('N_train:{}'.format(N_train)) + # print('N_test:{}'.format(N_test)) + # + # np.random.seed(1701) + # perm = np.random.permutation(N) + # test_indices = perm[:N_test] + # train_indices = perm[N_test:] + # + # print('train_indices:{}'.format(len(train_indices))) + # print('test_indices:{}'.format(len(test_indices))) + # + # for i in train_indices: + # datum = 
json.loads(all_data[i].strip()) + # write_line(datum, fp_train) + # + # for i in test_indices: + # datum = json.loads(all_data[i].strip()) + # write_line(datum, fp_test) + + save_joints() + # split_train_test() # + + ## read images dir + logging.info("reading images list ...") + img_dir = os.path.join(path, extracted_filename2) + _img_list = load_file_list(path=os.path.join(path, extracted_filename2), regx='\\.jpg', printable=False) + # ann_list = json.load(open(os.path.join(path, 'data.json'))) + for i, im in enumerate(img_train_list): + if im not in _img_list: + print('missing training image {} in {} (remove from img(ann)_train_list)'.format(im, img_dir)) + # img_train_list.remove(im) + del img_train_list[i] + del ann_train_list[i] + for i, im in enumerate(img_test_list): + if im not in _img_list: + print('missing testing image {} in {} (remove from img(ann)_test_list)'.format(im, img_dir)) + # img_test_list.remove(im) + del img_train_list[i] + del ann_train_list[i] + + ## check annotation and images + n_train_images = len(img_train_list) + n_test_images = len(img_test_list) + n_images = n_train_images + n_test_images + logging.info("n_images: {} n_train_images: {} n_test_images: {}".format(n_images, n_train_images, n_test_images)) + n_train_ann = len(ann_train_list) + n_test_ann = len(ann_test_list) + n_ann = n_train_ann + n_test_ann + logging.info("n_ann: {} n_train_ann: {} n_test_ann: {}".format(n_ann, n_train_ann, n_test_ann)) + n_train_people = len(sum(ann_train_list, [])) + n_test_people = len(sum(ann_test_list, [])) + n_people = n_train_people + n_test_people + logging.info("n_people: {} n_train_people: {} n_test_people: {}".format(n_people, n_train_people, n_test_people)) + # add path to all image file name + for i, value in enumerate(img_train_list): + img_train_list[i] = os.path.join(img_dir, value) + for i, value in enumerate(img_test_list): + img_test_list[i] = os.path.join(img_dir, value) + return img_train_list, ann_train_list, img_test_list, ann_test_list + + +class MPII(Dataset): + """ + Load MPII Human Pose Dataset. + + Parameters + ----------- + train_or_test : str + Must be either 'train' or 'test'. Choose the training or test dataset. + name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is `raw_data/mpii_human_pose`. + is_16_pos_only : boolean + If True, only return the peoples contain 16 pose keypoints. (Usually be used for single person pose estimation) + """ + + def __init__(self, train_or_test, path='raw_data', name='mpii_human_pose', is_16_pos_only=False): + self.path = os.path.join(path, name) + img_train_list, ann_train_list, img_test_list, ann_test_list = load_mpii_pose_dataset(name=name, path=path, is_16_pos_only=is_16_pos_only) + assert train_or_test in ['train', 'test'] + self.train_or_test = train_or_test + if train_or_test == 'train': + self.img_list = img_train_list + self.ann_list = ann_train_list + del img_test_list + del ann_test_list + else: + self.img_list = img_test_list + self.ann_list = ann_test_list + del img_train_list + del ann_train_list + + def __getitem__(self, index): + return self.img_list[index], self.ann_list[index] + + def __len__(self): + assert len(self.img_list) == len(self.ann_list) + return len(self.img_list) diff --git a/tensorlayer/data/dataset/nietzsche.py b/tensorlayer/data/dataset/nietzsche.py new file mode 100644 index 000000000..5e00e32e4 --- /dev/null +++ b/tensorlayer/data/dataset/nietzsche.py @@ -0,0 +1,70 @@ +#! 
/usr/bin/python +# -*- coding: utf-8 -*- + +import os + +from tensorlayer import logging + +from ..base import Dataset +from ..utils import maybe_download_and_extract + +__all__ = ['load_nietzsche_dataset', 'NIETZSCHE'] + +NIETZSCHE_BASE_URL = 'https://s3.amazonaws.com/text-datasets/' +NIETZSCHE_FILENAME = 'nietzsche.txt' + + +def load_nietzsche_dataset(name='nietzsche', path='raw_data'): + """ + Load Nietzsche dataset. + + Parameters + ---------- + name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is ``raw_data/nietzsche/``. + + Returns + -------- + str + The content. + + Examples + -------- + >>> see tutorial_generate_text.py + >>> words = tl.files.load_nietzsche_dataset() + >>> words = basic_clean_str(words) + >>> words = words.split() + + """ + logging.info("Load or Download nietzsche dataset > {}".format(path)) + path = os.path.join(path, name) + + filepath = maybe_download_and_extract(NIETZSCHE_FILENAME, path, NIETZSCHE_BASE_URL) + + with open(filepath, "r") as f: + words = f.read() + return words + + +class NIETZSCHE(Dataset): + """ + Load Nietzsche dataset. + + Parameters + ---------- + name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is ``raw_data/nietzsche/``. + """ + + def __init__(self, name='nietzsche', path='raw_data'): + self.words = load_nietzsche_dataset(name=name, path=path) + + def __getitem__(self, index): + return self.words[index] + + def __len__(self): + return len(self.words) diff --git a/tensorlayer/data/dataset/ptb.py b/tensorlayer/data/dataset/ptb.py new file mode 100644 index 000000000..96183c35d --- /dev/null +++ b/tensorlayer/data/dataset/ptb.py @@ -0,0 +1,138 @@ +#! /usr/bin/python +# -*- coding: utf-8 -*- + +import os + +from tensorlayer import logging, nlp +import numpy as np + +from ..base import Dataset +from ..utils import maybe_download_and_extract + +__all__ = ['load_ptb_dataset', 'PTB'] + +PTB_URL = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/' +PTB_FILENAME = 'simple-examples.tgz' + + +def load_ptb_dataset(name='ptb', path='raw_data'): + """ + Load Penn TreeBank (PTB) dataset. + + It is used in many LANGUAGE MODELING papers, + including "Empirical Evaluation and Combination of Advanced Language + Modeling Techniques", "Recurrent Neural Network Regularization". + It consists of 929k training words, 73k validation words, and 82k test + words. It has 10k words in its vocabulary. + + Parameters + ---------- + name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is ``raw_data/ptb/``. + + Returns + -------- + train_data, valid_data, test_data : list of int + The training, validating and testing data in integer format. + vocab_size : int + The vocabulary size. + + Examples + -------- + >>> train_data, valid_data, test_data, vocab_size = tl.files.load_ptb_dataset() + + References + --------------- + - ``tensorflow.models.rnn.ptb import reader`` + - `Manual download `__ + + Notes + ------ + - If you want to get the raw data, see the source code. 
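The PTB class defined further down in this file windows the word-id stream into num_steps-long sequences: __getitem__ returns an input slice and the same slice shifted by one word as the next-word target. A hedged sketch of the intended use (the import path follows this diff's layout and is an assumption):

# Sketch only; PTB is defined below in this module.
from tensorlayer.data.dataset.ptb import PTB

train_ds = PTB(train_or_test_or_valid='train', num_steps=20)   # downloads into raw_data/ptb/
x, y = train_ds[0]                 # x: word ids [0, 20), y: word ids [1, 21)
assert (x[1:] == y[:-1]).all()     # targets are the inputs shifted by one step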
+ + """ + path = os.path.join(path, name) + logging.info("Load or Download Penn TreeBank (PTB) dataset > {}".format(path)) + + # Maybe dowload and uncompress tar, or load exsisting files + maybe_download_and_extract(PTB_FILENAME, path, PTB_URL, extract=True) + + data_path = os.path.join(path, 'simple-examples', 'data') + train_path = os.path.join(data_path, "ptb.train.txt") + valid_path = os.path.join(data_path, "ptb.valid.txt") + test_path = os.path.join(data_path, "ptb.test.txt") + + word_to_id = nlp.build_vocab(nlp.read_words(train_path)) + + train_data = nlp.words_to_word_ids(nlp.read_words(train_path), word_to_id) + valid_data = nlp.words_to_word_ids(nlp.read_words(valid_path), word_to_id) + test_data = nlp.words_to_word_ids(nlp.read_words(test_path), word_to_id) + vocab_size = len(word_to_id) + + # logging.info(nlp.read_words(train_path)) # ... 'according', 'to', 'mr.', '', ''] + # logging.info(train_data) # ... 214, 5, 23, 1, 2] + # logging.info(word_to_id) # ... 'beyond': 1295, 'anti-nuclear': 9599, 'trouble': 1520, '': 2 ... } + # logging.info(vocabulary) # 10000 + # exit() + return train_data, valid_data, test_data, vocab_size + + +class PTB(Dataset): + """ + Load Penn TreeBank (PTB) dataset. + + It is used in many LANGUAGE MODELING papers, + including "Empirical Evaluation and Combination of Advanced Language + Modeling Techniques", "Recurrent Neural Network Regularization". + It consists of 929k training words, 73k validation words, and 82k test + words. It has 10k words in its vocabulary. + + Parameters + ---------- + train_or_test_or_valid : str + Must be either 'train' or 'test' or 'valid'. Choose the training or test or validation dataset. + num_steps : int + The number of unrolls. i.e. sequence_length + name : str + The name of the dataset. + path : str + The path that the data is downloaded to, defaults is ``raw_data/ptb/``. + """ + + def __init__(self, train_or_test_or_valid, num_steps, name='ptb', path='raw_data'): + path = os.path.expanduser(path) + self.path = os.path.join(path, name) + logging.info("Load or Download Penn TreeBank (PTB) dataset > {}".format(self.path)) + + maybe_download_and_extract(PTB_FILENAME, self.path, PTB_URL, extract=True) + + self.num_steps = num_steps + self.path = os.path.join(self.path, 'simple-examples', 'data') + assert train_or_test_or_valid in ['train', 'test', 'valid'] + self.train_or_test_or_valid = train_or_test_or_valid + train_path = os.path.join(self.path, "ptb.train.txt") + if train_or_test_or_valid == 'train': + data_path = train_path + elif train_or_test_or_valid == 'valid': + data_path = os.path.join(self.path, "ptb.valid.txt") + else: + data_path = os.path.join(self.path, "ptb.test.txt") + + # use training data to build vocab + self.word_to_id = nlp.build_vocab(nlp.read_words(train_path)) + self.vocav_size = len(self.word_to_id) + self.data = nlp.words_to_word_ids(nlp.read_words(data_path), self.word_to_id) + + self.data = np.array(self.data, dtype=np.int32) + self.data_len = (len(self.data) - 1) // self.num_steps + self.data = self.data[:self.data_len * self.num_steps + 1] + + def __getitem__(self, index): + x = self.data[index * self.num_steps:(index + 1) * self.num_steps] + y = self.data[index * self.num_steps + 1:(index + 1) * self.num_steps + 1] + return x, y + + def __len__(self): + return self.data_len diff --git a/tensorlayer/data/dataset/voc.py b/tensorlayer/data/dataset/voc.py new file mode 100644 index 000000000..8fd6dfe8c --- /dev/null +++ b/tensorlayer/data/dataset/voc.py @@ -0,0 +1,334 @@ +#! 
/usr/bin/python +# -*- coding: utf-8 -*- + +import os + +import tensorflow as tf +from tensorlayer import logging, utils +from ..utils import del_file, del_folder, folder_exists, load_file_list, maybe_download_and_extract + +__all__ = ['load_voc_dataset'] + + +def load_voc_dataset(path='data', dataset='2012', contain_classes_in_person=False): + """Pascal VOC 2007/2012 Dataset. + + It has 20 objects: + aeroplane, bicycle, bird, boat, bottle, bus, car, cat, chair, cow, diningtable, dog, horse, motorbike, person, pottedplant, sheep, sofa, train, tvmonitor + and additional 3 classes : head, hand, foot for person. + + Parameters + ----------- + path : str + The path that the data is downloaded to, defaults is ``data/VOC``. + dataset : str + The VOC dataset version, `2012`, `2007`, `2007test` or `2012test`. We usually train model on `2007+2012` and test it on `2007test`. + contain_classes_in_person : boolean + Whether include head, hand and foot annotation, default is False. + + Returns + --------- + imgs_file_list : list of str + Full paths of all images. + imgs_semseg_file_list : list of str + Full paths of all maps for semantic segmentation. Note that not all images have this map! + imgs_insseg_file_list : list of str + Full paths of all maps for instance segmentation. Note that not all images have this map! + imgs_ann_file_list : list of str + Full paths of all annotations for bounding box and object class, all images have this annotations. + classes : list of str + Classes in order. + classes_in_person : list of str + Classes in person. + classes_dict : dictionary + Class label to integer. + n_objs_list : list of int + Number of objects in all images in ``imgs_file_list`` in order. + objs_info_list : list of str + Darknet format for the annotation of all images in ``imgs_file_list`` in order. ``[class_id x_centre y_centre width height]`` in ratio format. + objs_info_dicts : dictionary + The annotation of all images in ``imgs_file_list``, ``{imgs_file_list : dictionary for annotation}``, + format from `TensorFlow/Models/object-detection `__. 
+ + Examples + ---------- + >>> imgs_file_list, imgs_semseg_file_list, imgs_insseg_file_list, imgs_ann_file_list, + >>> classes, classes_in_person, classes_dict, + >>> n_objs_list, objs_info_list, objs_info_dicts = tl.files.load_voc_dataset(dataset="2012", contain_classes_in_person=False) + >>> idx = 26 + >>> print(classes) + ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'] + >>> print(classes_dict) + {'sheep': 16, 'horse': 12, 'bicycle': 1, 'bottle': 4, 'cow': 9, 'sofa': 17, 'car': 6, 'dog': 11, 'cat': 7, 'person': 14, 'train': 18, 'diningtable': 10, 'aeroplane': 0, 'bus': 5, 'pottedplant': 15, 'tvmonitor': 19, 'chair': 8, 'bird': 2, 'boat': 3, 'motorbike': 13} + >>> print(imgs_file_list[idx]) + data/VOC/VOC2012/JPEGImages/2007_000423.jpg + >>> print(n_objs_list[idx]) + 2 + >>> print(imgs_ann_file_list[idx]) + data/VOC/VOC2012/Annotations/2007_000423.xml + >>> print(objs_info_list[idx]) + 14 0.173 0.461333333333 0.142 0.496 + 14 0.828 0.542666666667 0.188 0.594666666667 + >>> ann = tl.prepro.parse_darknet_ann_str_to_list(objs_info_list[idx]) + >>> print(ann) + [[14, 0.173, 0.461333333333, 0.142, 0.496], [14, 0.828, 0.542666666667, 0.188, 0.594666666667]] + >>> c, b = tl.prepro.parse_darknet_ann_list_to_cls_box(ann) + >>> print(c, b) + [14, 14] [[0.173, 0.461333333333, 0.142, 0.496], [0.828, 0.542666666667, 0.188, 0.594666666667]] + + References + ------------- + - `Pascal VOC2012 Website `__. + - `Pascal VOC2007 Website `__. + + """ + try: + import lxml.etree as etree + except ImportError as e: + print(e) + raise ImportError("Module lxml not found. Please install lxml via pip or other package managers.") + + path = os.path.join(path, 'VOC') + + def _recursive_parse_xml_to_dict(xml): + """Recursively parses XML contents to python dict. + + We assume that `object` tags are the only ones that can appear + multiple times at the same level of a tree. + + Args: + xml: xml tree obtained by parsing XML file contents using lxml.etree + + Returns: + Python dictionary holding XML contents. 
+ + """ + if xml is not None: + return {xml.tag: xml.text} + result = {} + for child in xml: + child_result = _recursive_parse_xml_to_dict(child) + if child.tag != 'object': + result[child.tag] = child_result[child.tag] + else: + if child.tag not in result: + result[child.tag] = [] + result[child.tag].append(child_result[child.tag]) + return {xml.tag: result} + + import xml.etree.ElementTree as ET + + if dataset == "2012": + url = "http://pjreddie.com/media/files/" + tar_filename = "VOCtrainval_11-May-2012.tar" + extracted_filename = "VOC2012" #"VOCdevkit/VOC2012" + logging.info(" [============= VOC 2012 =============]") + elif dataset == "2012test": + extracted_filename = "VOC2012test" #"VOCdevkit/VOC2012" + logging.info(" [============= VOC 2012 Test Set =============]") + logging.info( + " \nAuthor: 2012test only have person annotation, so 2007test is highly recommended for testing !\n" + ) + import time + time.sleep(3) + if os.path.isdir(os.path.join(path, extracted_filename)) is False: + logging.info("For VOC 2012 Test data - online registration required") + logging.info( + " Please download VOC2012test.tar from: \n register: http://host.robots.ox.ac.uk:8080 \n voc2012 : http://host.robots.ox.ac.uk:8080/eval/challenges/voc2012/ \ndownload: http://host.robots.ox.ac.uk:8080/eval/downloads/VOC2012test.tar" + ) + logging.info(" unzip VOC2012test.tar,rename the folder to VOC2012test and put it into %s" % path) + exit() + # # http://host.robots.ox.ac.uk:8080/eval/downloads/VOC2012test.tar + # url = "http://host.robots.ox.ac.uk:8080/eval/downloads/" + # tar_filename = "VOC2012test.tar" + elif dataset == "2007": + url = "http://pjreddie.com/media/files/" + tar_filename = "VOCtrainval_06-Nov-2007.tar" + extracted_filename = "VOC2007" + logging.info(" [============= VOC 2007 =============]") + elif dataset == "2007test": + # http://host.robots.ox.ac.uk/pascal/VOC/voc2007/index.html#testdata + # http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar + url = "http://pjreddie.com/media/files/" + tar_filename = "VOCtest_06-Nov-2007.tar" + extracted_filename = "VOC2007test" + logging.info(" [============= VOC 2007 Test Set =============]") + else: + raise Exception("Please set the dataset aug to 2012, 2012test or 2007.") + + # download dataset + if dataset != "2012test": + from sys import platform as _platform + if folder_exists(os.path.join(path, extracted_filename)) is False: + logging.info("[VOC] {} is nonexistent in {}".format(extracted_filename, path)) + maybe_download_and_extract(tar_filename, path, url+tar_filename, extract=True) + del_file(os.path.join(path, tar_filename)) + if dataset == "2012": + if _platform == "win32": + os.system("move {}\VOCdevkit\VOC2012 {}\VOC2012".format(path, path)) + else: + os.system("mv {}/VOCdevkit/VOC2012 {}/VOC2012".format(path, path)) + elif dataset == "2007": + if _platform == "win32": + os.system("move {}\VOCdevkit\VOC2007 {}\VOC2007".format(path, path)) + else: + os.system("mv {}/VOCdevkit/VOC2007 {}/VOC2007".format(path, path)) + elif dataset == "2007test": + if _platform == "win32": + os.system("move {}\VOCdevkit\VOC2007 {}\VOC2007test".format(path, path)) + else: + os.system("mv {}/VOCdevkit/VOC2007 {}/VOC2007test".format(path, path)) + del_folder(os.path.join(path, 'VOCdevkit')) + # object classes(labels) NOTE: YOU CAN CUSTOMIZE THIS LIST + classes = [ + "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", + "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", 
"tvmonitor" + ] + if contain_classes_in_person: + classes_in_person = ["head", "hand", "foot"] + else: + classes_in_person = [] + + classes += classes_in_person # use extra 3 classes for person + + classes_dict = utils.list_string_to_dict(classes) + logging.info("[VOC] object classes {}".format(classes_dict)) + + # 1. image path list + # folder_imgs = path+"/"+extracted_filename+"/JPEGImages/" + folder_imgs = os.path.join(path, extracted_filename, "JPEGImages") + imgs_file_list = load_file_list(path=folder_imgs, regx='\\.jpg', printable=False) + logging.info("[VOC] {} images found".format(len(imgs_file_list))) + + imgs_file_list.sort( + key=lambda s: int(s.replace('.', ' ').replace('_', '').split(' ')[-2]) + ) # 2007_000027.jpg --> 2007000027 + + imgs_file_list = [os.path.join(folder_imgs, s) for s in imgs_file_list] + # logging.info('IM',imgs_file_list[0::3333], imgs_file_list[-1]) + if dataset != "2012test": + ##======== 2. semantic segmentation maps path list + # folder_semseg = path+"/"+extracted_filename+"/SegmentationClass/" + folder_semseg = os.path.join(path, extracted_filename, "SegmentationClass") + imgs_semseg_file_list = load_file_list(path=folder_semseg, regx='\\.png', printable=False) + logging.info("[VOC] {} maps for semantic segmentation found".format(len(imgs_semseg_file_list))) + imgs_semseg_file_list.sort( + key=lambda s: int(s.replace('.', ' ').replace('_', '').split(' ')[-2]) + ) # 2007_000032.png --> 2007000032 + imgs_semseg_file_list = [os.path.join(folder_semseg, s) for s in imgs_semseg_file_list] + # logging.info('Semantic Seg IM',imgs_semseg_file_list[0::333], imgs_semseg_file_list[-1]) + ##======== 3. instance segmentation maps path list + # folder_insseg = path+"/"+extracted_filename+"/SegmentationObject/" + folder_insseg = os.path.join(path, extracted_filename, "SegmentationObject") + imgs_insseg_file_list = load_file_list(path=folder_insseg, regx='\\.png', printable=False) + logging.info("[VOC] {} maps for instance segmentation found".format(len(imgs_semseg_file_list))) + imgs_insseg_file_list.sort( + key=lambda s: int(s.replace('.', ' ').replace('_', '').split(' ')[-2]) + ) # 2007_000032.png --> 2007000032 + imgs_insseg_file_list = [os.path.join(folder_insseg, s) for s in imgs_insseg_file_list] + # logging.info('Instance Seg IM',imgs_insseg_file_list[0::333], imgs_insseg_file_list[-1]) + else: + imgs_semseg_file_list = [] + imgs_insseg_file_list = [] + # 4. annotations for bounding box and object class + # folder_ann = path+"/"+extracted_filename+"/Annotations/" + folder_ann = os.path.join(path, extracted_filename, "Annotations") + imgs_ann_file_list = load_file_list(path=folder_ann, regx='\\.xml', printable=False) + logging.info( + "[VOC] {} XML annotation files for bounding box and object class found".format(len(imgs_ann_file_list)) + ) + imgs_ann_file_list.sort( + key=lambda s: int(s.replace('.', ' ').replace('_', '').split(' ')[-2]) + ) # 2007_000027.xml --> 2007000027 + imgs_ann_file_list = [os.path.join(folder_ann, s) for s in imgs_ann_file_list] + # logging.info('ANN',imgs_ann_file_list[0::3333], imgs_ann_file_list[-1]) + + if dataset == "2012test": # remove unused images in JPEG folder + imgs_file_list_new = [] + for ann in imgs_ann_file_list: + ann = os.path.split(ann)[-1].split('.')[0] + for im in imgs_file_list: + if ann in im: + imgs_file_list_new.append(im) + break + imgs_file_list = imgs_file_list_new + logging.info("[VOC] keep %d images" % len(imgs_file_list_new)) + + # parse XML annotations + def convert(size, box): + dw = 1. 
/ size[0] + dh = 1. / size[1] + x = (box[0] + box[1]) / 2.0 + y = (box[2] + box[3]) / 2.0 + w = box[1] - box[0] + h = box[3] - box[2] + x = x * dw + w = w * dw + y = y * dh + h = h * dh + return x, y, w, h + + def convert_annotation(file_name): + """Given VOC2012 XML Annotations, returns number of objects and info.""" + in_file = open(file_name) + out_file = "" + tree = ET.parse(in_file) + root = tree.getroot() + size = root.find('size') + w = int(size.find('width').text) + h = int(size.find('height').text) + n_objs = 0 + + for obj in root.iter('object'): + if dataset != "2012test": + difficult = obj.find('difficult').text + cls = obj.find('name').text + if cls not in classes or int(difficult) == 1: + continue + else: + cls = obj.find('name').text + if cls not in classes: + continue + cls_id = classes.index(cls) + xmlbox = obj.find('bndbox') + b = ( + float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text), + float(xmlbox.find('ymax').text) + ) + bb = convert((w, h), b) + + out_file += str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n' + n_objs += 1 + if cls in "person": + for part in obj.iter('part'): + cls = part.find('name').text + if cls not in classes_in_person: + continue + cls_id = classes.index(cls) + xmlbox = part.find('bndbox') + b = ( + float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), + float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text) + ) + bb = convert((w, h), b) + # out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n') + out_file += str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n' + n_objs += 1 + in_file.close() + return n_objs, out_file + + logging.info("[VOC] Parsing xml annotations files") + n_objs_list = [] + objs_info_list = [] # Darknet Format list of string + objs_info_dicts = {} + for idx, ann_file in enumerate(imgs_ann_file_list): + n_objs, objs_info = convert_annotation(ann_file) + n_objs_list.append(n_objs) + objs_info_list.append(objs_info) + with tf.io.gfile.GFile(ann_file, 'r') as fid: + xml_str = fid.read() + xml = etree.fromstring(xml_str) + data = _recursive_parse_xml_to_dict(xml)['annotation'] + objs_info_dicts.update({imgs_file_list[idx]: data}) + + return imgs_file_list, imgs_semseg_file_list, imgs_insseg_file_list, imgs_ann_file_list, classes, classes_in_person, classes_dict, n_objs_list, objs_info_list, objs_info_dicts diff --git a/tensorlayer/data/dataset/wmt_en_fr.py b/tensorlayer/data/dataset/wmt_en_fr.py new file mode 100644 index 000000000..5a28e279b --- /dev/null +++ b/tensorlayer/data/dataset/wmt_en_fr.py @@ -0,0 +1,80 @@ +#! /usr/bin/python +# -*- coding: utf-8 -*- + +import gzip +import os +import tarfile + +from tensorflow.python.platform import gfile +from tensorlayer import logging +from ..utils import maybe_download_and_extract + +__all__ = ['load_wmt_en_fr_dataset'] + +WMT_ENFR_TRAIN_URL = "http://www.statmt.org/wmt10/training-giga-fren.tar" +WMT_ENFR_DEV_URL = "http://www.statmt.org/wmt15/dev-v2.tgz" + + +def load_wmt_en_fr_dataset(path='data', name='wmt_en_fr'): + """Load WMT'15 English-to-French translation dataset. + + It will download the data from the WMT'15 Website (10^9-French-English corpus), and the 2013 news test from the same site as development set. + Returns the directories of training data and test data. + + Parameters + ---------- + path : str + The path that the data is downloaded to, defaults is ``data/wmt_en_fr/``. 
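Unlike the other loaders, load_wmt_en_fr_dataset returns path prefixes of the prepared corpora rather than arrays; the ".en" and ".fr" files next to those prefixes hold the parallel text. A hedged sketch (import path assumed from this diff's layout):

# Sketch only; downloading the full corpus is slow, see the Notes below.
from tensorlayer.data.dataset.wmt_en_fr import load_wmt_en_fr_dataset

train_path, dev_path = load_wmt_en_fr_dataset(path='data')
# e.g. train_path + '.en' and train_path + '.fr' are the unpacked training files
print(train_path, dev_path)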
+ + References + ---------- + - Code modified from /tensorflow/models/rnn/translation/data_utils.py + + Notes + ----- + Usually, it will take a long time to download this dataset. + + """ + path = os.path.join(path, name) + # URLs for WMT data. + + def gunzip_file(gz_path, new_path): + """Unzips from gz_path into new_path.""" + logging.info("Unpacking %s to %s" % (gz_path, new_path)) + with gzip.open(gz_path, "rb") as gz_file: + with open(new_path, "wb") as new_file: + for line in gz_file: + new_file.write(line) + + def get_wmt_enfr_train_set(path): + """Download the WMT en-fr training corpus to directory unless it's there.""" + filename = "training-giga-fren.tar" + maybe_download_and_extract(filename, path, WMT_ENFR_TRAIN_URL, extract=True) + train_path = os.path.join(path, "giga-fren.release2.fixed") + gunzip_file(train_path + ".fr.gz", train_path + ".fr") + gunzip_file(train_path + ".en.gz", train_path + ".en") + return train_path + + def get_wmt_enfr_dev_set(path): + """Download the WMT en-fr training corpus to directory unless it's there.""" + filename = "dev-v2.tgz" + dev_file = maybe_download_and_extract(filename, path, WMT_ENFR_DEV_URL, extract=False) + dev_name = "newstest2013" + dev_path = os.path.join(path, "newstest2013") + if not (gfile.Exists(dev_path + ".fr") and gfile.Exists(dev_path + ".en")): + logging.info("Extracting tgz file %s" % dev_file) + with tarfile.open(dev_file, "r:gz") as dev_tar: + fr_dev_file = dev_tar.getmember("dev/" + dev_name + ".fr") + en_dev_file = dev_tar.getmember("dev/" + dev_name + ".en") + fr_dev_file.name = dev_name + ".fr" # Extract without "dev/" prefix. + en_dev_file.name = dev_name + ".en" + dev_tar.extract(fr_dev_file, path) + dev_tar.extract(en_dev_file, path) + return dev_path + + logging.info("Load or Download WMT English-to-French translation > {}".format(path)) + + train_path = get_wmt_enfr_train_set(path) + dev_path = get_wmt_enfr_dev_set(path) + + return train_path, dev_path diff --git a/tensorlayer/data/parallel.py b/tensorlayer/data/parallel.py new file mode 100644 index 000000000..e0abc0f14 --- /dev/null +++ b/tensorlayer/data/parallel.py @@ -0,0 +1,194 @@ +import multiprocessing +import os +import sys +import uuid + +import zmq +import numpy as np + +from .base import DatasetWrapper +from .serialize import * +from .utils import clean_up_socket_files + + +class MultiprocessDataset(DatasetWrapper): + def __init__(self, + ds, + num_worker, + num_prefetch, + shuffle=False): + + super(MultiprocessDataset, self).__init__(ds) + self.num_worker = num_worker + self.num_prefetch = num_prefetch + self.shuffle = shuffle + + self.index_queue = multiprocessing.Queue(self.num_worker) + self.data_queue = multiprocessing.Queue(self.num_prefetch) + self.put_idx_worker = None + for _ in range(num_worker): + worker = multiprocessing.Process(target=self._data_worker, + args=(self.ds, self.index_queue, self.data_queue)) + worker.daemon = True + worker.start() + + def _data_worker(self, ds, index_q, data_q): + while True: + idx = index_q.get() + data_q.put((idx, ds[idx])) + + def _put_idx(self, index_q, idx): + index_q.put(idx) + + def __iter__(self): + # clear queues from previous epoch + while not self.index_queue.empty(): + self.index_queue.get() + while not self.data_queue.empty(): + self.data_queue.get() + + # shuffle at the start of every epoch + if self.shuffle: + self.idxs = np.random.permutation(self.ds_len) + else: + self.idxs = np.arange(self.ds_len) + + send_idx_cnt = 0 + for _ in range(self.num_worker * 2): + self._put_idx(self.index_queue, 
self.idxs[send_idx_cnt]) + send_idx_cnt += 1 + + data_buffer = {} + for return_idx in self.idxs: + if return_idx in data_buffer: + yield data_buffer.pop(return_idx) + else: + while True: + idx, dp = self.data_queue.get() + # put new idx after collecting data + if send_idx_cnt < len(self.idxs): + self._put_idx(self.index_queue, self.idxs[send_idx_cnt]) + send_idx_cnt += 1 + if idx == return_idx: + yield dp + break + else: + data_buffer[idx] = dp + + +class ZMQMultiprocessDataset(DatasetWrapper): + def __init__(self, + ds, + num_worker, + hwm=50, + shuffle=False): + + super(ZMQMultiprocessDataset, self).__init__(ds) + self.num_worker = num_worker + self.shuffle = shuffle + self._hwm = hwm + + self.idx_pipename = _get_pipe_name('put_idx') + self.data_pipename = _get_pipe_name('collect_data') + + for i in range(num_worker): + # first worker bind the socket, others connect to the socket + # however, zmq sockets using ipc do not care about the order of bind / connect + if i == 0: + worker = multiprocessing.Process(target=self._data_worker, + args=(True,)) + else: + worker = multiprocessing.Process(target=self._data_worker, + args=()) + worker.daemon = True + worker.start() + + clean_up_socket_files([self.idx_pipename, self.data_pipename]) + + def _data_worker(self, bind=False): + context = zmq.Context() + worker_receive_index_socket = context.socket(zmq.PULL) + worker_receive_index_socket.set_hwm(self._hwm) + if bind: + worker_receive_index_socket.bind(self.idx_pipename) + else: + worker_receive_index_socket.connect(self.idx_pipename) + + worker_send_data_socket = context.socket(zmq.PUSH) + worker_send_data_socket.set_hwm(self._hwm) + if bind: + worker_send_data_socket.bind(self.data_pipename) + else: + worker_send_data_socket.connect(self.data_pipename) + + while True: + recv_msg = worker_receive_index_socket.recv(copy=False) + idx = load_from_bytes(recv_msg) + send_msg = convert_to_bytes({'idx': idx, 'data': self.ds[idx]}) + worker_send_data_socket.send(send_msg, copy=False) + + def _put_idx(self, put_idx_socket, idx): + send_msg = convert_to_bytes(idx) + put_idx_socket.send(send_msg, copy=False) + + def __iter__(self): + context = zmq.Context() + collect_data_socket = context.socket(zmq.PULL) + collect_data_socket.set_hwm(self._hwm) + collect_data_socket.connect(self.data_pipename) + + put_idx_socket = context.socket(zmq.PUSH) + put_idx_socket.set_hwm(self._hwm) + put_idx_socket.connect(self.idx_pipename) + + # shutdown put_idx_worker and clear queues from previous epoch + try: + while True: + collect_data_socket.recv(flags=zmq.NOBLOCK) + except zmq.ZMQError: + pass + + # shuffle at the start of every epoch + if self.shuffle: + self.idxs = np.random.permutation(self.ds_len) + else: + self.idxs = np.arange(self.ds_len) + + send_idx_cnt = 0 + for _ in range(self.num_worker * 2): + self._put_idx(put_idx_socket, self.idxs[send_idx_cnt]) + send_idx_cnt += 1 + + data_buffer = {} + for return_idx in self.idxs: + if return_idx in data_buffer: + yield data_buffer.pop(return_idx) + else: + while True: + recv_msg = collect_data_socket.recv(copy=False) + recv_msg = load_from_bytes(recv_msg) + idx, dp = recv_msg['idx'], recv_msg['data'] + # put new idx after collecting data + if send_idx_cnt < len(self.idxs): + self._put_idx(put_idx_socket, self.idxs[send_idx_cnt]) + send_idx_cnt += 1 + if idx == return_idx: + yield dp + break + else: + data_buffer[idx] = dp + + +def _get_pipe_name(name): + if sys.platform.startswith('linux'): + # linux supports abstract sockets: http://api.zeromq.org/4-1:zmq-ipc + 
pipename = "ipc://@{}-pipe-{}".format(name, str(uuid.uuid1())[:8]) + else: + pipedir = '.' + assert os.path.isdir(pipedir), pipedir + filename = '{}/{}-pipe-{}'.format(pipedir.rstrip('/'), name, str(uuid.uuid1())[:6]) + assert not os.path.exists(filename), "Pipe {} exists! You may be unlucky.".format(filename) + pipename = "ipc://{}".format(filename) + # register in environment variable, used for cleaning up ipc socket files + # os.environ[name] = pipename + return pipename diff --git a/tensorlayer/data/serialize.py b/tensorlayer/data/serialize.py new file mode 100644 index 000000000..aa272f4f4 --- /dev/null +++ b/tensorlayer/data/serialize.py @@ -0,0 +1,27 @@ +import msgpack_numpy + +MAX_MSGPACK_LEN = 1000000000 + + +def convert_to_bytes(obj): + """ + Serialize an object. + + Returns: + Implementation-dependent bytes-like object. + """ + return msgpack_numpy.dumps(obj, use_bin_type=True) + + +def load_from_bytes(buf): + """ + Args: + buf: the output of `dumps`. + """ + # Since 0.6, the default max size was set to 1MB. + # We change it to approximately 1G. + return msgpack_numpy.loads(buf, raw=False, + max_bin_len=MAX_MSGPACK_LEN, + max_array_len=MAX_MSGPACK_LEN, + max_map_len=MAX_MSGPACK_LEN, + max_str_len=MAX_MSGPACK_LEN) diff --git a/tensorlayer/data/utils.py b/tensorlayer/data/utils.py new file mode 100644 index 000000000..f896b28e7 --- /dev/null +++ b/tensorlayer/data/utils.py @@ -0,0 +1,409 @@ +import atexit +import logging +import math +import multiprocessing +import os +import platform +import re +import resource +import shutil +import weakref + +import psutil +import tarfile +import time +import zipfile +import progressbar +from urllib.request import urlretrieve + + +def load_folder_list(path=""): + """Return a folder list in a folder by given a folder path. + + Parameters + ---------- + path : str + A folder path. + + """ + return [os.path.join(path, o) for o in os.listdir(path) if os.path.isdir(os.path.join(path, o))] + + +def exists_or_mkdir(path, verbose=True): + """ + Check a folder by given name, if not exist, create the folder and return False, + if directory exists, return True. + + Parameters + ---------- + path : str + A folder path. + verbose : boolean + If True (default), prints results. + + Returns + -------- + boolean + True if folder already exist, otherwise, returns False and create the folder. + + Examples + -------- + >>> exists_or_mkdir("checkpoints/train") + + """ + if not os.path.exists(path): + if verbose: + logging.info("[*] creates %s ..." % path) + os.makedirs(path) + return False + else: + if verbose: + logging.info("[!] %s exists ..." % path) + return True + + +def download(filename, working_directory, url_source): + """ + Download file from url_source to the working_directory with given filename. + + Parameters + ---------- + filename : str + The name of the downloaded file. + working_directory : str + A folder path download the file to + url_source : str + The URL to download the file from + + Examples + -------- + >>> download(filename='train.gz', + ... working_directory='data/', + ... 
url_source='http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz') + + """ + working_directory = os.path.expanduser(working_directory) + + progress_bar = progressbar.ProgressBar() + + def _dlProgress(count, blockSize, totalSize, pbar=progress_bar): + if (totalSize != 0): + + if not pbar.max_value: + totalBlocks = math.ceil(float(totalSize) / float(blockSize)) + pbar.max_value = int(totalBlocks) + + pbar.update(count, force=True) + + filepath = os.path.join(working_directory, filename) + + logging.info('Downloading %s...\n' % filename) + + urlretrieve(url_source + filename, filepath, reporthook=_dlProgress) + + +def maybe_download_and_extract(filename, working_directory, url_source, extract=False, expected_bytes=None): + """ + Checks if file exists in working_directory otherwise tries to dowload the file, + and optionally also tries to extract the file if format is ".zip" or ".tar" + + Parameters + ----------- + filename : str + The name of the (to be) dowloaded file. + working_directory : str + A folder path to search for the file in and dowload the file to + url_source : str + The URL to download the file from + extract : boolean + If True, tries to uncompress the dowloaded file is ".tar.gz/.tar.bz2" or ".zip" file, default is False. + expected_bytes : int or None + If set tries to verify that the downloaded file is of the specified size, otherwise raises an Exception, defaults is None which corresponds to no check being performed. + + Returns + ---------- + str + File path of the dowloaded (uncompressed) file. + + Examples + -------- + >>> down_file = maybe_download_and_extract(filename='train.gz', + ... working_directory='data/', + ... url_source='http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz') + >>> maybe_download_and_extract(filename='ADEChallengeData2016.zip', + ... working_directory='data/', + ... url_source='http://sceneparsing.csail.mit.edu/data/ADEChallengeData2016.zip', + ... extract=True) + + """ + working_directory = os.path.expanduser(working_directory) + exists_or_mkdir(working_directory, verbose=False) + filepath = os.path.join(working_directory, filename) + + if not os.path.exists(filepath): + download(filename, working_directory, url_source) + statinfo = os.stat(filepath) + logging.info('Succesfully downloaded %s %s bytes.' % (filename, statinfo.st_size)) # , 'bytes.') + if not (expected_bytes is None) and (expected_bytes != statinfo.st_size): + raise Exception('Failed to verify ' + filename + '. Can you get to it with a browser?') + if extract: + if tarfile.is_tarfile(filepath): + logging.info('Trying to extract tar file') + tarfile.open(filepath, 'r').extractall(working_directory) + logging.info('... Success!') + elif zipfile.is_zipfile(filepath): + logging.info('Trying to extract zip file') + with zipfile.ZipFile(filepath) as zf: + zf.extractall(working_directory) + logging.info('... Success!') + else: + logging.info("Unknown compression_format only .tar.gz/.tar.bz2/.tar and .zip supported") + return filepath + + +def natural_keys(text): + """Sort list of string with number in human order. 
+ + Examples + ---------- + >>> l = ['im1.jpg', 'im31.jpg', 'im11.jpg', 'im21.jpg', 'im03.jpg', 'im05.jpg'] + >>> l.sort(key=tl.files.natural_keys) + ['im1.jpg', 'im03.jpg', 'im05', 'im11.jpg', 'im21.jpg', 'im31.jpg'] + >>> l.sort() # that is what we dont want + ['im03.jpg', 'im05', 'im1.jpg', 'im11.jpg', 'im21.jpg', 'im31.jpg'] + + References + ---------- + - `link `__ + + """ + + # - alist.sort(key=natural_keys) sorts in human order + # http://nedbatchelder.com/blog/200712/human_sorting.html + # (See Toothy's implementation in the comments) + def atoi(text): + return int(text) if text.isdigit() else text + + return [atoi(c) for c in re.split('(\d+)', text)] + + +def get_dataloader_speed(dl, num_steps): + cnt = 0 + start = time.time() + end = start + for _ in dl: + cnt += 1 + if cnt == num_steps: + end = time.time() + break + return (end - start) / num_steps + + +def format_bytes(bytes): + if abs(bytes) < 1000: + return str(bytes) + "B" + elif abs(bytes) < 1e6: + return str(round(bytes / 1e3, 2)) + "kB" + elif abs(bytes) < 1e9: + return str(round(bytes / 1e6, 2)) + "MB" + else: + return str(round(bytes / 1e9, 2)) + "GB" + + +def get_process_memory(): + process = psutil.Process(os.getpid()) + mi = process.memory_info() + return mi.rss, mi.vms, mi.shared + + +def get_peak_memory_usage(): + # peak memory usage (bytes on OS X, kilobytes on Linux) + rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + + platform_name = platform.system().lower() + + # If we are on linux + if platform_name== "linux" or platform_name == "linux2": + return format_bytes(rss * 1024) + + # If we are on Mac OS X + elif platform_name == "darwin": + return format_bytes(rss) + + # We don't support Windows + elif platform_name == "win32": + raise EnvironmentError("The Windows operating system is not supported") + + # Unrecognized platform + else: + raise EnvironmentError("Unrecognized platform") + + +def shutdown_proc(proc): + if proc is None: + return + if proc.is_alive(): + proc.terminate() + proc.join() + + +def shutdown_proc_by_weakref(ref): + proc = ref() + if proc is None: + return + if proc.is_alive(): + proc.terminate() + proc.join() + + +def ensure_subprocess_terminate(proc): + """ + Make sure subprocesses terminate when main process exit. + + Args: + proc (multiprocessing.Process or list) + """ + if isinstance(proc, list): + for p in proc: + ensure_subprocess_terminate(p) + return + + assert isinstance(proc, multiprocessing.Process) + atexit.register(shutdown_proc_by_weakref, weakref.ref(proc)) + + +def clean_up_socket_files(pipe_names): + if isinstance(pipe_names, list): + for pipe_name in pipe_names: + clean_up_socket_files(pipe_name) + return + + def remove_socket_files(pipe_name): + # remove all ipc socket files + # the environment variable starts with 'ipc://', so file name starts from 6 + try: + os.remove(pipe_name[6:]) + except (FileNotFoundError, KeyError): + pass + + atexit.register(remove_socket_files, pipe_names) + + +def download_file_from_google_drive(ID, destination): + """Download file from Google Drive. + + See ``tl.files.load_celebA_dataset`` for example. + + Parameters + -------------- + ID : str + The driver ID. + destination : str + The destination for save file. + + """ + try: + from tqdm import tqdm + except ImportError as e: + print(e) + raise ImportError("Module tqdm not found. Please install tqdm via pip or other package managers.") + + try: + import requests + except ImportError as e: + print(e) + raise ImportError("Module requests not found. 
Please install requests via pip or other package managers.") + + def save_response_content(response, destination, chunk_size=32 * 1024): + + total_size = int(response.headers.get('content-length', 0)) + with open(destination, "wb") as f: + for chunk in tqdm(response.iter_content(chunk_size), total=total_size, unit='B', unit_scale=True, + desc=destination): + if chunk: # filter out keep-alive new chunks + f.write(chunk) + + def get_confirm_token(response): + for key, value in response.cookies.items(): + if key.startswith('download_warning'): + return value + return None + + URL = "https://docs.google.com/uc?export=download" + session = requests.Session() + + response = session.get(URL, params={'id': ID}, stream=True) + token = get_confirm_token(response) + + if token: + params = {'id': ID, 'confirm': token} + response = session.get(URL, params=params, stream=True) + save_response_content(response, destination) + + +def load_file_list(path=None, regx='\.jpg', printable=True, keep_prefix=False): + r"""Return a file list in a folder by given a path and regular expression. + + Parameters + ---------- + path : str or None + A folder path, if `None`, use the current directory. + regx : str + The regx of file name. + printable : boolean + Whether to print the files infomation. + keep_prefix : boolean + Whether to keep path in the file name. + + Examples + ---------- + >>> file_list = load_file_list(path=None, regx='w1pre_[0-9]+\.(npz)') + + """ + if path is None: + path = os.getcwd() + file_list = os.listdir(path) + return_list = [] + for _, f in enumerate(file_list): + if re.search(regx, f): + return_list.append(f) + # return_list.sort() + if keep_prefix: + for i, f in enumerate(return_list): + return_list[i] = os.path.join(path, f) + + if printable: + logging.info('Match file list = %s' % return_list) + logging.info('Number of files = %d' % len(return_list)) + return return_list + + +def file_exists(filepath): + """Check whether a file exists by given file path.""" + return os.path.isfile(filepath) + + +def folder_exists(folderpath): + """Check whether a folder exists by given folder path.""" + return os.path.isdir(folderpath) + + +def del_folder(folderpath): + """Delete a folder by given folder path.""" + shutil.rmtree(folderpath) + + +def del_file(filepath): + """Delete a file by given file path.""" + os.remove(filepath) + + +def read_file(filepath): + """Read a file and return a string. + + Examples + --------- + >>> data = read_file('data.txt') + """ + with open(filepath, 'r') as afile: + return afile.read() diff --git a/tensorlayer/files/dataset_loaders/celebA_dataset.py b/tensorlayer/files/dataset_loaders/celebA_dataset.py index d5dc5755f..73ebcbced 100644 --- a/tensorlayer/files/dataset_loaders/celebA_dataset.py +++ b/tensorlayer/files/dataset_loaders/celebA_dataset.py @@ -10,7 +10,7 @@ __all__ = ['load_celebA_dataset'] -def load_celebA_dataset(path='data'): +def load_celebA_dataset(path='raw_data'): """Load CelebA dataset Return a list of image path. diff --git a/tensorlayer/files/dataset_loaders/cifar10_dataset.py b/tensorlayer/files/dataset_loaders/cifar10_dataset.py index 9af3f615d..bc36ff838 100644 --- a/tensorlayer/files/dataset_loaders/cifar10_dataset.py +++ b/tensorlayer/files/dataset_loaders/cifar10_dataset.py @@ -13,7 +13,7 @@ __all__ = ['load_cifar10_dataset'] -def load_cifar10_dataset(shape=(-1, 32, 32, 3), path='data', plotable=False): +def load_cifar10_dataset(shape=(-1, 32, 32, 3), path='raw_data', plotable=False): """Load CIFAR-10 dataset. 
It consists of 60000 32x32 colour images in 10 classes, with diff --git a/tensorlayer/files/dataset_loaders/cyclegan_dataset.py b/tensorlayer/files/dataset_loaders/cyclegan_dataset.py index e327b3b4c..9f77b6710 100644 --- a/tensorlayer/files/dataset_loaders/cyclegan_dataset.py +++ b/tensorlayer/files/dataset_loaders/cyclegan_dataset.py @@ -11,7 +11,7 @@ __all__ = ['load_cyclegan_dataset'] -def load_cyclegan_dataset(filename='summer2winter_yosemite', path='data'): +def load_cyclegan_dataset(filename='summer2winter_yosemite', path='raw_data'): """Load images from CycleGAN's database, see `this link `__. Parameters diff --git a/tensorlayer/files/dataset_loaders/flickr_1M_dataset.py b/tensorlayer/files/dataset_loaders/flickr_1M_dataset.py index f2e582ae5..a81cfeac5 100644 --- a/tensorlayer/files/dataset_loaders/flickr_1M_dataset.py +++ b/tensorlayer/files/dataset_loaders/flickr_1M_dataset.py @@ -11,7 +11,7 @@ __all__ = ['load_flickr1M_dataset'] -def load_flickr1M_dataset(tag='sky', size=10, path="data", n_threads=50, printable=False): +def load_flickr1M_dataset(tag='sky', size=10, path="raw_data", n_threads=50, printable=False): """Load Flick1M dataset. Returns a list of images by a given tag from Flickr1M dataset, diff --git a/tensorlayer/files/dataset_loaders/flickr_25k_dataset.py b/tensorlayer/files/dataset_loaders/flickr_25k_dataset.py index 8049a0653..4c695b4e5 100644 --- a/tensorlayer/files/dataset_loaders/flickr_25k_dataset.py +++ b/tensorlayer/files/dataset_loaders/flickr_25k_dataset.py @@ -11,7 +11,7 @@ __all__ = ['load_flickr25k_dataset'] -def load_flickr25k_dataset(tag='sky', path="data", n_threads=50, printable=False): +def load_flickr25k_dataset(tag='sky', path="raw_data", n_threads=50, printable=False): """Load Flickr25K dataset. Returns a list of images by a given tag from Flick25k dataset, diff --git a/tensorlayer/files/dataset_loaders/imdb_dataset.py b/tensorlayer/files/dataset_loaders/imdb_dataset.py index 34b4dffe0..686e6f906 100644 --- a/tensorlayer/files/dataset_loaders/imdb_dataset.py +++ b/tensorlayer/files/dataset_loaders/imdb_dataset.py @@ -13,7 +13,7 @@ def load_imdb_dataset( - path='data', nb_words=None, skip_top=0, maxlen=None, test_split=0.2, seed=113, start_char=1, oov_char=2, + path='raw_data', nb_words=None, skip_top=0, maxlen=None, test_split=0.2, seed=113, start_char=1, oov_char=2, index_from=3 ): """Load IMDB dataset. diff --git a/tensorlayer/files/dataset_loaders/matt_mahoney_dataset.py b/tensorlayer/files/dataset_loaders/matt_mahoney_dataset.py index 17a3e0833..f5a419f08 100644 --- a/tensorlayer/files/dataset_loaders/matt_mahoney_dataset.py +++ b/tensorlayer/files/dataset_loaders/matt_mahoney_dataset.py @@ -10,7 +10,7 @@ __all__ = ['load_matt_mahoney_text8_dataset'] -def load_matt_mahoney_text8_dataset(path='data'): +def load_matt_mahoney_text8_dataset(path='raw_data'): """Load Matt Mahoney's dataset. Download a text file from Matt Mahoney's website diff --git a/tensorlayer/files/dataset_loaders/mnist_dataset.py b/tensorlayer/files/dataset_loaders/mnist_dataset.py index 4e1346d5e..5e297dbc1 100644 --- a/tensorlayer/files/dataset_loaders/mnist_dataset.py +++ b/tensorlayer/files/dataset_loaders/mnist_dataset.py @@ -6,7 +6,7 @@ __all__ = ['load_mnist_dataset'] -def load_mnist_dataset(shape=(-1, 784), path='data'): +def load_mnist_dataset(shape=(-1, 784), path='raw_data'): """Load the original mnist. Automatically download MNIST dataset and return the training, validation and test set with 50000, 10000 and 10000 digit images respectively. 
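The hunks in tensorlayer/files/dataset_loaders change only the default download directory of the legacy tl.files loaders from data/ to raw_data/; behaviour is otherwise unchanged. A short sketch using the signature shown in the hunk above:

import tensorlayer as tl

# downloads into raw_data/mnist/ under the new default
X_train, y_train, X_val, y_val, X_test, y_test = tl.files.load_mnist_dataset(shape=(-1, 784))
# pass path explicitly to keep the previous data/ location
X_train, y_train, X_val, y_val, X_test, y_test = tl.files.load_mnist_dataset(shape=(-1, 784), path='data')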
diff --git a/tensorlayer/files/dataset_loaders/mnist_fashion_dataset.py b/tensorlayer/files/dataset_loaders/mnist_fashion_dataset.py index c7f1bb964..011d8688e 100644 --- a/tensorlayer/files/dataset_loaders/mnist_fashion_dataset.py +++ b/tensorlayer/files/dataset_loaders/mnist_fashion_dataset.py @@ -6,7 +6,7 @@ __all__ = ['load_fashion_mnist_dataset'] -def load_fashion_mnist_dataset(shape=(-1, 784), path='data'): +def load_fashion_mnist_dataset(shape=(-1, 784), path='raw_data'): """Load the fashion mnist. Automatically download fashion-MNIST dataset and return the training, validation and test set with 50000, 10000 and 10000 fashion images respectively, `examples `__. diff --git a/tensorlayer/files/dataset_loaders/mpii_dataset.py b/tensorlayer/files/dataset_loaders/mpii_dataset.py index a6f88f609..bbbbe233f 100644 --- a/tensorlayer/files/dataset_loaders/mpii_dataset.py +++ b/tensorlayer/files/dataset_loaders/mpii_dataset.py @@ -9,7 +9,7 @@ __all__ = ['load_mpii_pose_dataset'] -def load_mpii_pose_dataset(path='data', is_16_pos_only=False): +def load_mpii_pose_dataset(path='raw_data', is_16_pos_only=False): """Load MPII Human Pose Dataset. Parameters diff --git a/tensorlayer/files/dataset_loaders/nietzsche_dataset.py b/tensorlayer/files/dataset_loaders/nietzsche_dataset.py index 3cd0e27f2..2bfb02113 100644 --- a/tensorlayer/files/dataset_loaders/nietzsche_dataset.py +++ b/tensorlayer/files/dataset_loaders/nietzsche_dataset.py @@ -9,7 +9,7 @@ __all__ = ['load_nietzsche_dataset'] -def load_nietzsche_dataset(path='data'): +def load_nietzsche_dataset(path='raw_data'): """Load Nietzsche dataset. Parameters diff --git a/tensorlayer/files/dataset_loaders/ptb_dataset.py b/tensorlayer/files/dataset_loaders/ptb_dataset.py index 30746fd87..023896828 100644 --- a/tensorlayer/files/dataset_loaders/ptb_dataset.py +++ b/tensorlayer/files/dataset_loaders/ptb_dataset.py @@ -9,7 +9,7 @@ __all__ = ['load_ptb_dataset'] -def load_ptb_dataset(path='data'): +def load_ptb_dataset(path='raw_data'): """Load Penn TreeBank (PTB) dataset. It is used in many LANGUAGE MODELING papers, diff --git a/tensorlayer/files/dataset_loaders/voc_dataset.py b/tensorlayer/files/dataset_loaders/voc_dataset.py index e5124b4df..b17650f14 100644 --- a/tensorlayer/files/dataset_loaders/voc_dataset.py +++ b/tensorlayer/files/dataset_loaders/voc_dataset.py @@ -10,7 +10,7 @@ __all__ = ['load_voc_dataset'] -def load_voc_dataset(path='data', dataset='2012', contain_classes_in_person=False): +def load_voc_dataset(path='raw_data', dataset='2012', contain_classes_in_person=False): """Pascal VOC 2007/2012 Dataset. It has 20 objects: diff --git a/tensorlayer/files/dataset_loaders/wmt_en_fr_dataset.py b/tensorlayer/files/dataset_loaders/wmt_en_fr_dataset.py index 77c1f93f9..572afdb46 100644 --- a/tensorlayer/files/dataset_loaders/wmt_en_fr_dataset.py +++ b/tensorlayer/files/dataset_loaders/wmt_en_fr_dataset.py @@ -12,7 +12,7 @@ __all__ = ['load_wmt_en_fr_dataset'] -def load_wmt_en_fr_dataset(path='data'): +def load_wmt_en_fr_dataset(path='raw_data'): """Load WMT'15 English-to-French translation dataset. It will download the data from the WMT'15 Website (10^9-French-English corpus), and the 2013 news test from the same site as development set.
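Taken together, the new tensorlayer/data package separates loading (the Dataset subclasses such as MNIST, FASHION_MNIST, PTB and MPII) from iteration (the multiprocess wrappers in parallel.py). A closing sketch of how the pieces compose; the import paths follow the module layout added in this diff, and the re-export names under tensorlayer.data are an assumption:

from tensorlayer.data.dataset.mnist import MNIST            # assumed import path
from tensorlayer.data.parallel import MultiprocessDataset   # assumed import path

train_ds = MNIST(train_or_test='train', shape=(-1, 28, 28, 1))   # downloads into raw_data/mnist/
img, label = train_ds[0]                                         # __getitem__ -> (image, label)

# prefetch samples in worker processes; indices are reshuffled every epoch
par_ds = MultiprocessDataset(train_ds, num_worker=4, num_prefetch=32, shuffle=True)
for img, label in par_ds:
    # feed each (image, label) pair into a batching / training loop
    break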