{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Copyright 2016 The TensorFlow Authors. All Rights Reserved.\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License.\n",
    "# ==============================================================================\n",
    "import os\n",
    "\n",
    "import tensorflow as tf  # the cells below use the TensorFlow 1.x API\n",
    "\n",
    "# Default labels file name shared by write_label_file / read_label_file.\n",
    "# 'labels.txt' follows TF-Slim's dataset_utils; a value already defined by a\n",
    "# notebook that %run-s this one is kept.\n",
    "try:\n",
    "    LABELS_FILENAME\n",
    "except NameError:\n",
    "    LABELS_FILENAME = 'labels.txt'\n",
    "\n",
    "\n",
    "def int64_feature(values):\n",
    "    \"\"\"\n",
    "    Wraps an integer, or a tuple/list of integers, in an int64 tf.train.Feature.\n",
    "    Args:\n",
    "        values: A single value or a tuple/list of values\n",
    "    \"\"\"\n",
    "    if not isinstance(values, (tuple, list)):\n",
    "        values = [values]\n",
    "    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))\n",
    "\n",
    "\n",
    "def bytes_feature(values):\n",
    "    \"\"\"\n",
    "    Wraps a single bytes value in a bytes tf.train.Feature.\n",
    "    Args:\n",
    "        values: The bytes value; it is stored as a one-element list\n",
    "    \"\"\"\n",
    "    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))\n",
    "\n",
    "\n",
    "def image_to_tfexample(image_data, image_format, height, width, class_id):\n",
    "    \"\"\"\n",
    "    Creates a TFRecord example image entry.\n",
    "    Args:\n",
    "        image_data: The encoded image file contents as bytes\n",
    "        image_format: The image format type as bytes e.g. b'jpg'\n",
    "        height and width: The image sizes\n",
    "        class_id: The image label\n",
    "    Returns:\n",
    "        A tf.train.Example holding the five 'image/...' features\n",
    "    \"\"\"\n",
    "    return tf.train.Example(features=tf.train.Features(feature={\n",
    "        'image/encoded': bytes_feature(image_data),\n",
    "        'image/format': bytes_feature(image_format),\n",
    "        'image/class/label': int64_feature(class_id),\n",
    "        'image/height': int64_feature(height),\n",
    "        'image/width': int64_feature(width),\n",
    "    }))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def write_tfrecord(images_by_id, top_dir, output_filename, image_height, image_width):\n",
    "    \"\"\"\n",
    "    Creates a TFRecord file from a map of class ids to image file paths.\n",
    "    Files that cannot be read or serialized are skipped and listed at the end.\n",
    "    Args:\n",
    "        images_by_id: Map of class id to the list of image file paths of that class\n",
    "        top_dir: where to put the file\n",
    "        output_filename: how to name the file\n",
    "        image_height: height stored with every example (not checked against the file)\n",
    "        image_width: width stored with every example (not checked against the file)\n",
    "    \"\"\"\n",
    "    tf.reset_default_graph()\n",
    "\n",
    "    # Tiny graph that returns the raw bytes of whichever path is fed in.\n",
    "    image_path = tf.placeholder(dtype=tf.string)\n",
    "    image_raw = tf.read_file(image_path)\n",
    "\n",
    "    errors = []\n",
    "    num_images = sum(len(images) for images in images_by_id.values())\n",
    "    with tf.Session() as sess:\n",
    "        with tf.python_io.TFRecordWriter(os.path.join(top_dir, output_filename)) as tfrecord_writer:\n",
    "            processed_count = 0\n",
    "            for class_id, images in images_by_id.items():\n",
    "                for image in images:\n",
    "                    # Show progress\n",
    "                    processed_count += 1\n",
    "                    print('>> Converting image %d/%d' % (processed_count, num_images), end=\"\\r\")\n",
    "                    try:\n",
    "                        image_data = sess.run(image_raw, feed_dict={image_path: image})\n",
    "                        # The format is always recorded as b'jpg', whatever the file is.\n",
    "                        example = image_to_tfexample(\n",
    "                            image_data,\n",
    "                            b'jpg',\n",
    "                            image_height,\n",
    "                            image_width,\n",
    "                            class_id)\n",
    "                        tfrecord_writer.write(example.SerializeToString())\n",
    "                    except Exception as err:\n",
    "                        # Best effort: remember the failure and keep converting.\n",
    "                        # Catching Exception (not a bare except) still lets a\n",
    "                        # kernel interrupt stop the conversion.\n",
    "                        errors.append((image, repr(err)))\n",
    "\n",
    "    print()\n",
    "    print(\"Errors: {}\".format(len(errors)))\n",
    "    for error_file, info in errors:\n",
    "        print(\"Error on file: {} because: {}\".format(error_file, info))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def read_sample(example_raw):\n",
    "    \"\"\"\n",
    "    Reads a single TFRecord example and converts it into an image\n",
    "    Args:\n",
    "        example_raw: The serialized TFRecord example that represents an image\n",
    "    Returns:\n",
    "        The (image, label) pair: the decoded image reshaped to (height, width)\n",
    "        and its int64 class label\n",
    "    \"\"\"\n",
    "    example = tf.parse_single_example(\n",
    "        example_raw,\n",
    "        features={\n",
    "            'image/encoded': tf.FixedLenFeature([], tf.string),\n",
    "            'image/format': tf.FixedLenFeature([], tf.string),\n",
    "            'image/class/label': tf.FixedLenFeature([], tf.int64),\n",
    "            'image/height': tf.FixedLenFeature([], tf.int64),\n",
    "            'image/width': tf.FixedLenFeature([], tf.int64),\n",
    "        })\n",
    "    image = tf.image.decode_image(example['image/encoded'])\n",
    "    height = tf.cast(example['image/height'], tf.int32)\n",
    "    width = tf.cast(example['image/width'], tf.int32)\n",
    "    label = example[\"image/class/label\"]\n",
    "    # NOTE(review): the 2-D reshape drops the channel axis, so this presumably\n",
    "    # expects single-channel images whose stored height/width match the file\n",
    "    # - confirm before feeding RGB data.\n",
    "    image = tf.reshape(image, (height, width))\n",
    "    return image, label\n",
    "\n",
    "\n",
    "def make_dataset(tfrecord_filepath):\n",
    "    \"\"\"\n",
    "    Returns a dataset ready to process a TFRecord file\n",
    "    Args:\n",
    "        tfrecord_filepath: Path of the TFRecord file to read\n",
    "    Returns:\n",
    "        A tf.data dataset yielding the (image, label) pairs built by read_sample\n",
    "    \"\"\"\n",
    "    dataset = tf.data.TFRecordDataset(tfrecord_filepath)\n",
    "    dataset = dataset.map(read_sample)\n",
    "    return dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Copyright 2016 The TensorFlow Authors. All Rights Reserved.\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License.\n",
    "# ==============================================================================\n",
    "# Create mapping files for class ids to class names\n",
    "def write_label_file(labels_to_class_names, dataset_dir,\n",
    "                     filename=LABELS_FILENAME):\n",
    "    \"\"\"\n",
    "    Writes a file with the list of class names row-wise like \"1:class1\".\n",
    "    Args:\n",
    "        labels_to_class_names: A map of (integer) labels to class names.\n",
    "        dataset_dir: The directory in which the labels file should be written.\n",
    "        filename: The filename where the class names are written.\n",
    "    \"\"\"\n",
    "    labels_filename = os.path.join(dataset_dir, filename)\n",
    "    with tf.gfile.Open(labels_filename, 'w') as f:\n",
    "        for label in labels_to_class_names:\n",
    "            class_name = labels_to_class_names[label]\n",
    "            f.write('%d:%s\\n' % (label, class_name))\n",
    "\n",
    "\n",
    "def read_label_file(dataset_dir, filename=LABELS_FILENAME):\n",
    "    \"\"\"Reads the labels file and returns a mapping from ID to class name.\n",
    "    Args:\n",
    "        dataset_dir: The directory in which the labels file is found.\n",
    "        filename: The filename where the class names are written.\n",
    "    Returns:\n",
    "        A map from a label (integer) to class name.\n",
    "    \"\"\"\n",
    "    labels_filename = os.path.join(dataset_dir, filename)\n",
    "    with tf.gfile.Open(labels_filename, 'rb') as f:\n",
    "        lines = f.read().decode()\n",
    "    lines = lines.split('\\n')\n",
    "    # Drop empty rows such as the one after the trailing newline.\n",
    "    lines = filter(None, lines)\n",
    "\n",
    "    labels_to_class_names = {}\n",
    "    for line in lines:\n",
    "        # Split on the first ':' only, so class names may contain colons.\n",
    "        index = line.index(':')\n",
    "        labels_to_class_names[int(line[:index])] = line[index+1:]\n",
    "    return labels_to_class_names"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:tf]",
   "language": "python",
   "name": "conda-env-tf-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}