Skip to content

[ModelZoo] unify optimizer loss names #1174

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
@@ -89,7 +89,7 @@ ARG WITH_SQLFLOW_MODELS="ON"
RUN if [ "${WITH_SQLFLOW_MODELS:-ON}" = "ON" ]; then \
git clone https://github.com/sql-machine-learning/models.git && \
cd models && \
git checkout 58f4c137129e2bc749320bafcc8fddb7c737fed9 && \
git checkout 91d63b581fad5686a2132635f0aa000b0699a1da && \
bash -c "source activate sqlflow-dev && python setup.py install" && \
cd .. && \
rm -rf models; \
3 changes: 2 additions & 1 deletion python/sqlflow_submitter/tensorflow/predict.py
Original file line number Diff line number Diff line change
@@ -83,6 +83,7 @@ def pred(is_keras_model,
classifier = estimator(**feature_columns, **model_params, model_dir=save)
else:
classifier = estimator(**feature_columns, **model_params)
classifier_pkg = sys.modules[estimator.__module__]


if is_keras_model:
@@ -123,7 +124,7 @@ def eval_input_fn(batch_size):
except tf.errors.OutOfRangeError:
break
result = classifier.predict_on_batch(features[0])
result = classifier.prepare_prediction_column(result[0])
result = classifier_pkg.prepare_prediction_column(result[0])
row = []
for idx, name in enumerate(feature_column_names):
val = features[0][name].numpy()[0]
6 changes: 4 additions & 2 deletions python/sqlflow_submitter/tensorflow/train.py
Original file line number Diff line number Diff line change
@@ -18,6 +18,7 @@
import sys, json
import tensorflow as tf
import functools
import sys
try:
import sqlflow_models
except:
@@ -68,6 +69,7 @@ def train(is_keras_model,
classifier = estimator(**feature_columns, **model_params, model_dir=save)
else:
classifier = estimator(**feature_columns, **model_params)
classifier_pkg = sys.modules[estimator.__module__]

def input_fn(datasetStr):
feature_types = []
@@ -96,8 +98,8 @@ def validate_input_fn(batch_size):
return dataset.batch(batch_size)

if is_keras_model:
classifier.compile(optimizer=classifier.default_optimizer(),
loss=classifier.default_loss(),
classifier.compile(optimizer=classifier_pkg.optimizer(),
loss=classifier_pkg.loss(),
metrics=["accuracy"])
if hasattr(classifier, 'sqlflow_train_loop'):
classifier.sqlflow_train_loop(train_input_fn(batch_size))