Skip to content

Commit 09d9200

Browse files
author
Alan Patterson
committed
Added support for Tensorflow Serving. Breaking change to remove
redundant output dimension.
1 parent 61d24cc commit 09d9200

File tree

6 files changed

+87
-8
lines changed

6 files changed

+87
-8
lines changed

README.md

+20
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,26 @@ close to 2x the speed.
123123

124124
The training script has this option added: `train_classifier.py`.
125125

126+
# Tensorflow Serving
127+
128+
As well as using `predictor.py` to run a saved model to provide
129+
predictions, it is easy to serve a saved model using Tensorflow
130+
Serving and gRPC. A simple gRPC client is also supplied that provides
131+
predictions by using a gRPC server.
132+
133+
First, make sure the TensorFlow Serving binaries are installed. Instructions are [here](https://github.com/tensorflow/serving/blob/master/tensorflow_serving/g3doc/setup.md#installing-the-modelserver).
134+
135+
You then serve the latest saved model by supplying the base export
136+
directory to which you exported the saved models. This directory will
137+
contain the numbered model directories:
138+
139+
tensorflow_model_server --port=9000 --model_base_path=model
140+
141+
Now you can make requests to the server using gRPC calls. An example
142+
simple client is provided in `predictor_client.py`:
143+
144+
predictor_client.py --text="Some text to classify"
145+
126146
# Facebook Examples
127147

128148
<< NOT IMPLEMENTED YET >>

classifier.py

-3
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,6 @@ def model_fn(features, labels, mode, params):
106106
-0.1, 0.1))
107107
text_embedding = tf.reduce_mean(tf.nn.embedding_lookup(
108108
text_embedding_w, text_ids), axis=-2)
109-
text_embedding = tf.expand_dims(text_embedding, -2)
110109
input_layer = text_embedding
111110
if FLAGS.use_ngrams:
112111
ngram_hash = tf.string_to_hash_bucket(features["ngrams"],
@@ -131,8 +130,6 @@ def model_fn(features, labels, mode, params):
131130
labels = label_lookup_table.lookup(labels)
132131
loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
133132
labels=labels, logits=logits))
134-
# Squeeze dimensions from labels and switch to 0-offset
135-
labels = tf.squeeze(labels, -1)
136133
opt = tf.train.AdamOptimizer(params["learning_rate"])
137134
if FLAGS.horovod:
138135
opt = hvd.DistributedOptimizer(opt)

inputs.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,15 @@ def BuildTextExample(text, ngrams=None, label=None):
1515
if ngrams is not None:
1616
ngrams = [tf.compat.as_bytes(x) for x in ngrams]
1717
record.features.feature["ngrams"].bytes_list.value.extend(ngrams)
18-
return record.SerializeToString()
18+
return record
1919

2020

2121
def ParseSpec(use_ngrams, include_target):
2222
parse_spec = {"text": tf.VarLenFeature(dtype=tf.string)}
2323
if use_ngrams:
2424
parse_spec["ngrams"] = tf.VarLenFeature(dtype=tf.string)
2525
if include_target:
26-
parse_spec["label"] = tf.FixedLenFeature(shape=(1,), dtype=tf.string,
26+
parse_spec["label"] = tf.FixedLenFeature(shape=(), dtype=tf.string,
2727
default_value=None)
2828
return parse_spec
2929

predictor.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def RunModel(saved_model_dir, signature_def_key, tag, text, ngrams_list=None):
4141
ngrams_list = text_utils.ParseNgramsOpts(ngrams_list)
4242
ngrams = text_utils.GenerateNgrams(text, ngrams_list)
4343
example = inputs.BuildTextExample(text, ngrams=ngrams)
44+
example = example.SerializeToString()
4445
inputs_feed_dict = {
4546
signature_def.inputs["inputs"].name: [example],
4647
}
@@ -64,8 +65,8 @@ def main(_):
6465
outputs = RunModel(FLAGS.saved_model, FLAGS.signature_def, FLAGS.tag,
6566
FLAGS.text, FLAGS.ngrams)
6667
if FLAGS.signature_def == "proba":
67-
print("Proba:", outputs[0])
68-
print("Class(1-N):", np.argmax(outputs[0]) + 1)
68+
print("Proba:", outputs)
69+
print("Class(1-N):", np.argmax(outputs) + 1)
6970
elif FLAGS.signature_def == "embedding":
7071
print(outputs[0])
7172

predictor_client.py

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
"""Predict classification on provided text.
2+
3+
Send request to a tensorflow_model_server.
4+
5+
tensorflow_model_server --port=9000 --model_base_path=$export_dir_base
6+
7+
Usage:
8+
9+
predictor_client.py --text='some text' --ngrams=1,2,4
10+
11+
"""
12+
from __future__ import absolute_import
13+
from __future__ import division
14+
from __future__ import print_function
15+
16+
import tensorflow as tf
17+
import inputs
18+
import text_utils
19+
20+
from grpc.beta import implementations
21+
from tensorflow_serving.apis import classification_pb2
22+
from tensorflow_serving.apis import prediction_service_pb2
23+
24+
25+
tf.flags.DEFINE_string('server', 'localhost:9000',
26+
'TensorflowService host:port')
27+
tf.flags.DEFINE_string("text", None, "Text to predict label of")
28+
tf.flags.DEFINE_string("ngrams", None, "List of ngram lengths, E.g. --ngrams=2,3,4")
29+
tf.flags.DEFINE_string("signature_def", "proba",
30+
"Stored signature key of method to call (proba|embedding)")
31+
FLAGS = tf.flags.FLAGS
32+
33+
34+
def Request(text, ngrams):
    """Build a ClassificationRequest carrying one example for `text`.

    Args:
      text: raw text; it is passed through text_utils.TokenizeText first.
      ngrams: ngram-lengths option string (e.g. "2,3,4"), or None to send
        the example without ngram features.

    Returns:
      A classification_pb2.ClassificationRequest for model 'default' and
      signature 'proba' whose example list holds the single built example.
    """
    text = text_utils.TokenizeText(text)
    # Keep the generated ngrams in their own local. Previously `ngrams` was
    # reset to None before being tested, so the option was silently ignored.
    ngram_features = None
    if ngrams is not None:
        ngrams_list = text_utils.ParseNgramsOpts(ngrams)
        ngram_features = text_utils.GenerateNgrams(text, ngrams_list)
    example = inputs.BuildTextExample(text, ngrams=ngram_features)
    request = classification_pb2.ClassificationRequest()
    request.model_spec.name = 'default'
    request.model_spec.signature_name = 'proba'
    request.input.example_list.examples.extend([example])
    return request
46+
47+
48+
def main(_):
    """Send --text to the model server and print the Classify response.

    Raises:
      ValueError: if --text is missing/empty, or --server cannot be split
        into a host and an integer port.
    """
    if not FLAGS.text:
        raise ValueError("No --text provided")
    # Split on the last colon only, so a host that itself contains colons
    # (e.g. a bracketed IPv6 literal) still yields a host/port pair.
    host, port = FLAGS.server.rsplit(':', 1)
    channel = implementations.insecure_channel(host, int(port))
    stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
    request = Request(FLAGS.text, FLAGS.ngrams)
    result = stub.Classify(request, 10.0)  # 10 secs timeout
    print(result)
57+
58+
59+
if __name__ == '__main__':
60+
tf.app.run()
61+

process_input.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def WriteExamples(examples, outputfile, num_shards):
9494
(shard, num_shards))
9595
record = inputs.BuildTextExample(
9696
example["text"], example.get("ngrams", None), example["label"])
97-
writer.write(record)
97+
writer.write(record.SerializeToString())
9898

9999

100100
def WriteVocab(examples, vocabfile, labelfile):

0 commit comments

Comments
 (0)