{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Copyright 2016 The TensorFlow Authors. All Rights Reserved.\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License.\n",
    "# ==============================================================================\n",
    "def int64_feature(values):\n",
    "    \"\"\"Wraps an int, or a tuple/list of ints, in an int64 tf.train.Feature.\"\"\"\n",
    "    # A bare scalar becomes a one-element list; tuples and lists pass through as-is.\n",
    "    int_values = values if isinstance(values, (tuple, list)) else [values]\n",
    "    return tf.train.Feature(int64_list=tf.train.Int64List(value=int_values))\n",
    "\n",
    "def bytes_feature(values):\n",
    "    \"\"\"Wraps one value in a single-element bytes-list tf.train.Feature.\"\"\"\n",
    "    payload = tf.train.BytesList(value=[values])\n",
    "    return tf.train.Feature(bytes_list=payload)\n",
    "\n",
    "def image_to_tfexample(image_data, image_format, height, width, class_id):\n",
    "    \"\"\"\n",
    "        Builds the tf.train.Example entry for a single image.\n",
    "    Args:\n",
    "        image_data: The image content, stored as a bytes feature\n",
    "        image_format: The image format type, stored as a bytes feature e.g. b'jpg'\n",
    "        height and width: The image sizes, stored as int64 features\n",
    "        class_id: The image label, stored as an int64 feature\n",
    "    Returns:\n",
    "        A tf.train.Example holding the five features listed below.\n",
    "    \"\"\"\n",
    "    feature_map = {\n",
    "        'image/encoded': bytes_feature(image_data),\n",
    "        'image/format': bytes_feature(image_format),\n",
    "        'image/class/label': int64_feature(class_id),\n",
    "        'image/height': int64_feature(height),\n",
    "        'image/width': int64_feature(width),\n",
    "    }\n",
    "    return tf.train.Example(features=tf.train.Features(feature=feature_map))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def write_tfrecord(images_by_id, top_dir, output_filename, image_height, image_width):\n",
    "    \"\"\"\n",
    "        Creates a TFRecord file from a map of class ids to image file paths\n",
    "    Args:\n",
    "        images_by_id: Map of class id to the collection of image file paths of that class\n",
    "        top_dir: where to put the file\n",
    "        output_filename: how to name the file\n",
    "        image_height: height recorded in every example written to the file\n",
    "        image_width: width recorded in every example written to the file\n",
    "\n",
    "    Files that cannot be read or serialized are skipped; they are counted and\n",
    "    listed on stdout once the conversion is finished.\n",
    "    \"\"\"\n",
    "    # NOTE: this clears the default graph of the running kernel.\n",
    "    tf.reset_default_graph()\n",
    "\n",
    "    # The raw file content is written as-is; the image is not decoded here.\n",
    "    image_path = tf.placeholder(dtype=tf.string)\n",
    "    image_raw = tf.read_file(image_path)\n",
    "\n",
    "    errors = []\n",
    "    num_images = sum(len(images) for images in images_by_id.values())\n",
    "    with tf.Session() as sess:\n",
    "        with tf.python_io.TFRecordWriter(\"/\".join([top_dir, output_filename])) as tfrecord_writer:\n",
    "            processed_count = 0\n",
    "            for class_id, images in images_by_id.items():\n",
    "                for image in images:\n",
    "                    # Show progress\n",
    "                    processed_count += 1\n",
    "                    print('>> Converting image %d/%d' % (processed_count, num_images), end=\"\\r\")\n",
    "                    try:\n",
    "                        image_data = sess.run(image_raw, feed_dict={image_path: image})\n",
    "                        # The format is always recorded as b'jpg', whatever the file is.\n",
    "                        example = image_to_tfexample(\n",
    "                            image_data,\n",
    "                            b'jpg',\n",
    "                            image_height,\n",
    "                            image_width,\n",
    "                            class_id)\n",
    "                        tfrecord_writer.write(example.SerializeToString())\n",
    "                    except Exception as exc:\n",
    "                        # Best effort: remember the failure and keep converting.\n",
    "                        # Catching Exception (not a bare except) lets KeyboardInterrupt\n",
    "                        # and SystemExit stop a long conversion.\n",
    "                        errors.append((image, \"{}: {}\".format(type(exc).__name__, exc)))\n",
    "\n",
    "    print()\n",
    "    print(\"Errors: {}\".format(len(errors)))\n",
    "    for (error_file, info) in errors:\n",
    "        print(\"Error one file: {} because: {}\".format(error_file, info))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def read_sample(example_raw):\n",
    "    \"\"\"\n",
    "        Parses a single serialized TFRecord example into an (image, label) pair\n",
    "    Args:\n",
    "        example_raw: The serialized TFRecord example that represents an image\n",
    "    Returns:\n",
    "        The decoded image reshaped to (height, width) and its int64 label\n",
    "    \"\"\"\n",
    "    feature_spec = {\n",
    "        'image/encoded': tf.FixedLenFeature([], tf.string),\n",
    "        'image/format': tf.FixedLenFeature([], tf.string),\n",
    "        'image/class/label': tf.FixedLenFeature([], tf.int64),\n",
    "        'image/height': tf.FixedLenFeature([], tf.int64),\n",
    "        'image/width': tf.FixedLenFeature([], tf.int64),\n",
    "    }\n",
    "    parsed = tf.parse_single_example(example_raw, features=feature_spec)\n",
    "    decoded = tf.image.decode_image(parsed['image/encoded'])\n",
    "    rows = tf.cast(parsed['image/height'], tf.int32)\n",
    "    cols = tf.cast(parsed['image/width'], tf.int32)\n",
    "    label = parsed[\"image/class/label\"]\n",
    "    # NOTE(review): the target shape has no channel axis, so this presumably\n",
    "    # expects single-channel images - confirm against the dataset.\n",
    "    return tf.reshape(decoded, (rows, cols)), label\n",
    "\n",
    "def make_dataset(tfrecord_filepath):\n",
    "    \"\"\"\n",
    "        Builds a dataset that reads a TFRecord file and parses every record\n",
    "        with read_sample\n",
    "    Args:\n",
    "        tfrecord_filepath: Passed unchanged to tf.data.TFRecordDataset\n",
    "    \"\"\"\n",
    "    records = tf.data.TFRecordDataset(tfrecord_filepath)\n",
    "    return records.map(read_sample)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Copyright 2016 The TensorFlow Authors. All Rights Reserved.\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License.\n",
    "# ==============================================================================\n",
    "# Create mapping files for class ids to class names\n",
    "def write_label_file(labels_to_class_names, dataset_dir,\n",
    "                     filename=LABELS_FILENAME):\n",
    "    \"\"\"\n",
    "        Writes the class names to a file, one \"label:class_name\" row per label.\n",
    "    Args:\n",
    "        labels_to_class_names: A map of (integer) labels to class names.\n",
    "        dataset_dir: The directory in which the labels file should be written.\n",
    "        filename: The filename where the class names are written.\n",
    "    \"\"\"\n",
    "    target_path = os.path.join(dataset_dir, filename)\n",
    "    with tf.gfile.Open(target_path, 'w') as label_file:\n",
    "        for label_id in labels_to_class_names:\n",
    "            row = '%d:%s\\n' % (label_id, labels_to_class_names[label_id])\n",
    "            label_file.write(row)\n",
    "            \n",
    "def read_label_file(dataset_dir, filename=LABELS_FILENAME):\n",
    "    \"\"\"Loads the labels file into a mapping from label ID to class name.\n",
    "    Args:\n",
    "        dataset_dir: The directory in which the labels file is found.\n",
    "        filename: The filename where the class names are written.\n",
    "    Returns:\n",
    "        A map from a label (integer) to class name.\n",
    "    Raises:\n",
    "        ValueError: If a non-empty line has no ':' or its label part is not an integer.\n",
    "    \"\"\"\n",
    "    source_path = os.path.join(dataset_dir, filename)\n",
    "    with tf.gfile.Open(source_path, 'rb') as label_file:\n",
    "        content = label_file.read().decode()\n",
    "\n",
    "    mapping = {}\n",
    "    for row in content.split('\\n'):\n",
    "        # Empty rows (e.g. after the final newline) carry no label.\n",
    "        if not row:\n",
    "            continue\n",
    "        # Split on the first ':' only, so class names may contain ':' themselves.\n",
    "        separator = row.index(':')\n",
    "        mapping[int(row[:separator])] = row[separator + 1:]\n",
    "    return mapping"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:tf]",
   "language": "python",
   "name": "conda-env-tf-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}