Add support for feature importances #166
Conversation
Codecov Report

|          | dev    | #166   | +/-    |
|----------|--------|--------|--------|
| Coverage | 89.51% | 88.17% | -1.34% |
| Files    | 10     | 10     |        |
| Lines    | 992    | 1201   | +209   |
| Hits     | 888    | 1059   | +171   |
| Misses   | 104    | 142    | +38    |

Continue to review the full report at Codecov.
@yufongpeng Thanks for this contribution. A cursory review has convinced me that this is a worthy and well-executed enhancement to DecisionTree.jl. I think I can make a good detailed review of this PR, but it will take a little time. In the meantime, can you please add document strings for the three new feature importance methods, and also a document string for your enhancement? Additionally, we will need to add tests to this PR. The current tests do not seem adequate, as far as I can tell.
There are some changes in this branch:
@ablaom I think this PR is almost ready, but there's one thing I need to fix. In the new version of the
src/classification/main.jl (outdated)

    Prune tree based on prediction accuracy of each nodes.
    * `purity_thresh`: If prediction accuracy of a stump is larger than this value, the node will be pruned and become a leaf.
    * `loss`: The loss function for computing node impurity. It can be either `util.entropy`, `util.gini` or other measures of impurity depending on the loss function used for building the tree.
I'm finding this one sentence a little confusing. Do you mean to say that the loss can be any of these, but would generally be chosen to match the loss used in building the tree (but doesn't have to be)?
Yes, you can change the loss function freely because the tree does not store the loss function; it has to be provided externally. I've changed the algorithm a little bit: if a tree is a `Root{S, T}` where `T` is `Float64`, the loss function will be `mean_squared_error`; otherwise, it will be `util.entropy`.
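To make that dispatch rule concrete, here is a minimal, self-contained sketch. The `Root` struct, the loss functions, and the `default_loss` helper below are simplified stand-ins for DecisionTree's internals, not the PR's actual code.

```julia
# Placeholder for DecisionTree's wrapped-tree type mentioned in this thread.
struct Root{S,T} end

# Simplified stand-ins for the impurity/loss measures named above.
mean_squared_error(y, ŷ) = sum(abs2, y .- ŷ) / length(y)
entropy(p) = -sum(x -> x > 0 ? x * log(x) : 0.0, p)

# The rule described above: Float64 labels (regression) default to MSE,
# any other label type (classification) defaults to entropy.
default_loss(::Root{S,Float64}) where {S} = mean_squared_error
default_loss(::Root{S,T}) where {S,T}     = entropy

default_loss(Root{Float64,Float64}())  # returns mean_squared_error
default_loss(Root{Float64,String}())   # returns entropy
```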
@yufongpeng I agree that an accuracy-based
src/classification/main.jl (outdated)

      Prune tree based on prediction accuracy of each nodes.
      * `purity_thresh`: If prediction accuracy of a stump is larger than this value, the node will be pruned and become a leaf.
    - * `loss`: The loss function for computing node impurity. It can be either `util.entropy`, `util.gini` or other measures of impurity depending on the loss function used for building the tree.
    + * `loss`: The loss function for computing node impurity. For classification tree, default function is `util.entropy`; for regression tree, it's `mean_squared_error`.
Two points:

- `util.entropy` is not in the `DecisionTree` namespace. So unless we decide to export this, we should write `DecisionTree.util.entropy`.
- We should mention `DecisionTree.util.gini` as an available option, as the user is unlikely to discover it otherwise (see the usage sketch below).
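For illustration, a hedged usage sketch of the docstring under discussion: `build_tree` and `prune_tree` are existing DecisionTree.jl functions, but passing `loss` as a keyword argument is an assumption about this PR's final signature, not confirmed API.

```julia
using DecisionTree

# Toy classification data (illustrative only).
features = rand(100, 3)
labels   = rand(["a", "b"], 100)

model = build_tree(labels, features)

# Passing an explicit impurity measure by its fully qualified name, since
# `util` is not exported. Whether `loss` is accepted as a keyword argument
# (as written here) is an assumption about this PR.
pruned = prune_tree(model, 0.9; loss = DecisionTree.util.gini)
```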
Apologies if this is orthogonal to the PR itself, and also if this is a naive question, but I'm anxious to get some results that make use of it. Are changes required to MLJDecisionTreeInterface in order for me to use this via MLJ? I've added @yufongpeng's fork and trained a machine using the MLJ interface. I tried using it. Is there some other way to do this, or do I need to use the DecisionTree API directly instead?
I'm currently working on the MLJ API for feature importance. I'll be done with it this week.
Got it, thanks for putting this together! I was able to use it on the
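For readers following along, here is a sketch of the MLJ route being discussed. It assumes the interface work mentioned above hooks into MLJ's generic `feature_importances` accessor; the data and hyperparameters are purely illustrative.

```julia
using MLJ

# Load the DecisionTree model through MLJ's model registry.
Tree = @load DecisionTreeClassifier pkg=DecisionTree

X = (x1 = rand(100), x2 = rand(100), x3 = rand(100))
y = coerce(rand(["a", "b"], 100), Multiclass)

mach = machine(Tree(), X, y)
fit!(mach)

# MLJ's generic accessor for models that report feature importances; this
# only returns useful values once the interface work above is in place.
feature_importances(mach)
```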
The diff doesn't only show what this PR changes but also contains the changes of #174 and other PRs. This is probably caused by a wrong merge of `dev` into this PR. I don't know how to fix it via Git, but you could copy the files that you've changed over to a new PR based on `dev` and then it should all be looking good again.

EDIT: Also, I would suggest next time to create a PR on a new branch in the fork. That way, upstream changes can be merged into the `dev` branch on the fork and then merged into the new branch. Currently, the PR is at the `dev` branch itself, so it isn't so easy to merge changes into it from the original `dev`.
    base = score(trees, labels, features)
    scores = Matrix{Float64}(undef, size(features, 2), n_iter)
    rng = mk_rng(rng)::Random.AbstractRNG
    for (i, col) in enumerate(eachcol(features))
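To make the idea of the quoted loop concrete, here is a self-contained, generic sketch of column-permutation importance that avoids DecisionTree's internal helpers (`score`, `mk_rng`); the function name and signature are illustrative only, not the PR's code.

```julia
using Random, Statistics

# Shuffle one feature column at a time and record how much a user-supplied
# scoring closure drops relative to the unshuffled baseline.
function permutation_importance(score, X::AbstractMatrix; n_iter::Int = 3,
                                rng::AbstractRNG = Random.default_rng())
    base  = score(X)                               # baseline on intact data
    drops = Matrix{Float64}(undef, size(X, 2), n_iter)
    Xbuf  = copy(X)
    for j in 1:size(X, 2), k in 1:n_iter
        Xbuf[:, j] = shuffle(rng, X[:, j])         # permute feature j
        drops[j, k] = base - score(Xbuf)           # importance = score decrease
        Xbuf[:, j] = X[:, j]                       # restore the column
    end
    return vec(mean(drops, dims = 2))
end

# Example: "score" is negative mean squared error of a fixed predictor.
X = rand(50, 3); y = X[:, 1]
permutation_importance(X̃ -> -sum(abs2, X̃[:, 1] .- y) / 50, X)
```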
`eachcol` requires at least Julia 1.1. I suggest we bump the minimum Julia compat for DecisionTree.jl to "1.1".
@ablaom @yufongpeng
It is already at 1.6 (#179), but this PR is in a bad state.
I just attempted a rebase, but there was a lot going on that I didn't totally understand. Someone is going to have to be careful to not undo stuff that's happened on
@yufongpeng Can you please try @rikhuijzer's suggestions to clean up the git issues?
Closed in favour of #182
edit (@ablaom) Closes #170.
Needs:
I've added three methods for calculating feature importances (a usage sketch follows the list).

- The default `feature_importances` is calculated by Mean Decrease in Impurity, which is computed simultaneously with model building.
- `permutation_importances` shuffles the columns multiple times and compares `R2` or `accuracy` with the original model.
- `dropcol_importances` deletes each column and uses cross validation instead.
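A hedged usage sketch of the three methods: the function names come from the description above, but the exact argument lists are assumptions for illustration, not confirmed API.

```julia
using DecisionTree

features = rand(200, 4)
labels   = rand(["a", "b"], 200)

model = build_forest(labels, features)

# Default, impurity-based importances (computed during training).
fi = feature_importances(model)

# The two resampling-based variants; the argument lists here are guesses
# based only on the descriptions above.
pi = permutation_importances(model, labels, features)
di = dropcol_importances(model, labels, features)
```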