-
Notifications
You must be signed in to change notification settings - Fork 17
[ModelZoo] unify with elastic dl model zoo #25
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ModelZoo] unify with elastic dl model zoo #25
Conversation
def loss(): | ||
"""Default loss function. Used in model.compile.""" | ||
return 'sparse_categorical_crossentropy' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure if string representation would work in ElasticDL yet.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will add unit tests on models repo when merging ElasticDL model_zoo/
dir.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you mean we'll test this later? I am quite sure this won't work in ElasticDL yet.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep. Will fix this later in another PR.
def loss(): | ||
"""Default loss function. Used in model.compile.""" | ||
return 'sparse_categorical_crossentropy' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you mean we'll test this later? I am quite sure this won't work in ElasticDL yet.
@@ -10,6 +12,7 @@ def __init__(self, feature_columns, stack_units=[32], hidden_size=64, n_classes= | |||
:param n_classes: Target number of classes. | |||
:type n_classes: int. | |||
""" | |||
global _loss |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are these global
keywords needed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please, check out the above comment:
_train_lr
may be configured by SQL statements when initializing the model instance. e.g.
SELECT * FROM mytrain_table TO TRAIN sqlflow_models.DNNClassifier WITH train_lr=0.0001 ...
In the above SQL statement, train_lr
will be passed to DeepEmbeddingClusterModel
to configure the learning rate.
The previous change #24 is not correct comparing to https://github.com/sql-machine-learning/elasticdl/blob/develop/model_zoo/cifar10_subclass/cifar10_subclass.py.
This PR changes how the customed model will be used in SQLFlow:
optimizer
,loss
function moved to module function, not a member of the model class to be the same as ElasticDL. And, if we putoptimizer
,loss
function inside the model class, it will cause keras model prediction error: https://travis-ci.com/sql-machine-learning/sqlflow/jobs/254846188