Skip to content

Commit cdcd3ec

Browse files
author
Taylor Robie
authored
set strip_default_attrs=True for SavedModel exports (tensorflow#5439)
* set strip_default_attrs=True for SavedModel exports * specify dtype in resnet export * another dtype fix * fix another dtype issue, and set --image_bytes_as_serving_input to default to False
1 parent 637f08e commit cdcd3ec

File tree

5 files changed

+18
-11
lines changed

5 files changed

+18
-11
lines changed

official/boosted_trees/train_higgs.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,8 @@ def train_boosted_trees(flags_obj):
249249
_make_csv_serving_input_receiver_fn(
250250
column_names=feature_names,
251251
# columns are all floats.
252-
column_defaults=[[0.0]] * len(feature_names)))
252+
column_defaults=[[0.0]] * len(feature_names)),
253+
strip_default_attrs=True)
253254

254255

255256
def main(_):

official/mnist/mnist.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,8 @@ def eval_input_fn():
222222
input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
223223
'image': image,
224224
})
225-
mnist_classifier.export_savedmodel(flags_obj.export_dir, input_fn)
225+
mnist_classifier.export_savedmodel(flags_obj.export_dir, input_fn,
226+
strip_default_attrs=True)
226227

227228

228229
def main(_):

official/resnet/resnet_run_loop.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -156,13 +156,13 @@ def input_fn(is_training, data_dir, batch_size, *args, **kwargs):
156156
return input_fn
157157

158158

159-
def image_bytes_serving_input_fn(image_shape):
159+
def image_bytes_serving_input_fn(image_shape, dtype=tf.float32):
160160
"""Serving input fn for raw jpeg images."""
161161

162162
def _preprocess_image(image_bytes):
163163
"""Preprocess a single raw image."""
164164
# Bounding box around the whole image.
165-
bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
165+
bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=dtype, shape=[1, 1, 4])
166166
height, width, num_channels = image_shape
167167
image = imagenet_preprocessing.preprocess_image(
168168
image_bytes, bbox, height, width, num_channels, is_training=False)
@@ -171,7 +171,7 @@ def _preprocess_image(image_bytes):
171171
image_bytes_list = tf.placeholder(
172172
shape=[None], dtype=tf.string, name='input_tensor')
173173
images = tf.map_fn(
174-
_preprocess_image, image_bytes_list, back_prop=False, dtype=tf.float32)
174+
_preprocess_image, image_bytes_list, back_prop=False, dtype=dtype)
175175
return tf.estimator.export.TensorServingInputReceiver(
176176
images, {'image_bytes': image_bytes_list})
177177

@@ -530,12 +530,15 @@ def input_fn_eval():
530530

531531
if flags_obj.export_dir is not None:
532532
# Exports a saved model for the given classifier.
533+
export_dtype = flags_core.get_tf_dtype(flags_obj)
533534
if flags_obj.image_bytes_as_serving_input:
534-
input_receiver_fn = functools.partial(image_bytes_serving_input_fn, shape)
535+
input_receiver_fn = functools.partial(
536+
image_bytes_serving_input_fn, shape, dtype=export_dtype)
535537
else:
536538
input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
537-
shape, batch_size=flags_obj.batch_size)
538-
classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn)
539+
shape, batch_size=flags_obj.batch_size, dtype=export_dtype)
540+
classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn,
541+
strip_default_attrs=True)
539542

540543

541544
def define_resnet_flags(resnet_size_choices=None):
@@ -565,7 +568,7 @@ def define_resnet_flags(resnet_size_choices=None):
565568
help=flags_core.help_wrap('Skip training and only perform evaluation on '
566569
'the latest checkpoint.'))
567570
flags.DEFINE_boolean(
568-
name="image_bytes_as_serving_input", default=True,
571+
name="image_bytes_as_serving_input", default=False,
569572
help=flags_core.help_wrap(
570573
'If True exports savedmodel with serving signature that accepts '
571574
'JPEG image bytes instead of a fixed size [HxWxC] tensor that '

official/transformer/transformer_main.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -621,7 +621,8 @@ def run_transformer(flags_obj):
621621
# an extra asset rather than a core asset.
622622
estimator.export_savedmodel(
623623
flags_obj.export_dir, serving_input_fn,
624-
assets_extra={"vocab.txt": flags_obj.vocab_file})
624+
assets_extra={"vocab.txt": flags_obj.vocab_file},
625+
strip_default_attrs=True)
625626

626627

627628
def main(_):

official/wide_deep/wide_deep_run_loop.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@ def export_model(model, model_type, export_dir, model_column_fn):
7373
feature_spec = tf.feature_column.make_parse_example_spec(columns)
7474
example_input_fn = (
7575
tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec))
76-
model.export_savedmodel(export_dir, example_input_fn)
76+
model.export_savedmodel(export_dir, example_input_fn,
77+
strip_default_attrs=True)
7778

7879

7980
def run_loop(name, train_input_fn, eval_input_fn, model_column_fn,

0 commit comments

Comments
 (0)