This repository was archived by the owner on Feb 28, 2025. It is now read-only.

Silhouette plot improvements #213

Merged
merged 4 commits into from
Jan 25, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -6,6 +6,7 @@
- [API Change] A new figure and axes are created (via `plt.subplots()`) when calling a plotting method with `ax=None`. Previously, the current axes was used (via `plt.gca()`) ([#211](https://github.com/ploomber/sklearn-evaluation/pull/211))
- [Fix] Validate that the input elbow curve model has a "score" method [#146]
- [Fix] Adds class labels for multi class roc plot (#209)
- [API Change] `silhouette_analysis_from_results` function now accepts a list of cluster labels [#213](https://github.com/ploomber/sklearn-evaluation/pull/213)

## 0.9.0 (2023-01-13)

4 changes: 4 additions & 0 deletions docs/api/plot.rst
@@ -59,6 +59,8 @@ elbow_curve
-----------
.. autofunction:: sklearn_evaluation.plot.elbow_curve

.. _elbow-curve-from-results-label:

elbow_curve_from_results
------------------------
.. autofunction:: sklearn_evaluation.plot.elbow_curve_from_results
@@ -115,6 +117,8 @@ silhouette_analysis
-------------------
.. autofunction:: sklearn_evaluation.plot.silhouette_analysis

.. _silhouette-analysis-from-results-label:

silhouette_analysis_from_results
--------------------------------
.. autofunction:: sklearn_evaluation.plot.silhouette_analysis_from_results
32 changes: 8 additions & 24 deletions docs/clustering/clustering_evaluation.md
@@ -54,13 +54,11 @@ Elbow curve helps to identify the point at which the plot starts to become paral
plot.elbow_curve(X, kmeans, range_n_clusters=range(1, 30))
```

##### Elbow curve from results
```{eval-rst}
.. tip::

    If you want to train the models yourself, you can use :ref:`elbow-curve-from-results-label` to plot the results.

```{code-cell} ipython3
import numpy as np
n_clusters = range(1, 10, 2)
sum_of_squares = np.array([4572.2, 470.7, 389.9, 335.1, 305.5])
plot.elbow_curve_from_results(n_clusters, sum_of_squares, times=None)
```

##### Silhouette plot
@@ -71,23 +69,9 @@ The below plot shows that n_clusters value of 3, 5 and 6 are a bad pick for the
silhouette = plot.silhouette_analysis(X, kmeans)
```

##### Silhouette plot from cluster labels

```{code-cell} ipython3
X, y = datasets.make_blobs(
n_samples=500,
n_features=2,
centers=4,
cluster_std=1,
center_box=(-10.0, 10.0),
shuffle=True,
random_state=1,
)
```{eval-rst}
.. tip::

kmeans = KMeans(n_clusters=4, random_state=1, n_init=5)
cluster_labels = kmeans.fit_predict(X)
```
    If you want to train the models yourself, you can use :ref:`silhouette-analysis-from-results-label` to plot the results.

```{code-cell} ipython3
silhouette = plot.silhouette_analysis_from_results(X, cluster_labels)
```
```
21 changes: 21 additions & 0 deletions examples/elbow_curve_from_results.py
@@ -0,0 +1,21 @@
import time
import numpy as np

from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

from sklearn_evaluation import plot

X, _ = make_blobs(n_samples=100, centers=3, n_features=5, random_state=0)

n_clusters = range(1, 30)
sum_of_squares = []
cluster_times = []
for i in n_clusters:
    start = time.time()
    kmeans = KMeans(n_clusters=i, n_init=5)
    sum_of_squares.append(kmeans.fit(X).score(X))
    cluster_times.append(time.time() - start)

sum_of_squares = np.absolute(sum_of_squares)
plot.elbow_curve_from_results(n_clusters, sum_of_squares, cluster_times)
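
The example passes `n_clusters` and `sum_of_squares` as parallel sequences. According to the `clustering.py` hunk in this diff, `elbow_curve_from_results` reorders both with `np.argsort(n_clusters)`, so the inputs do not have to arrive sorted. Below is a minimal sketch of that reordering, using only the standard library in place of NumPy. `sort_results` is a hypothetical name, not part of sklearn-evaluation, and the scores are the illustrative values from the docs snippet touched by this PR.

```python
def sort_results(n_clusters, sum_of_squares):
    """Order both sequences by ascending number of clusters."""
    # Stdlib stand-in for np.argsort: the indices that sort n_clusters.
    order = sorted(range(len(n_clusters)), key=lambda i: n_clusters[i])
    return (
        [n_clusters[i] for i in order],
        [sum_of_squares[i] for i in order],
    )


# Results collected in arbitrary order, e.g. from parallel training runs.
ks, scores = sort_results(
    [5, 1, 9, 3, 7],
    [389.9, 4572.2, 305.5, 470.7, 335.1],
)
print(ks)      # [1, 3, 5, 7, 9]
print(scores)  # [4572.2, 470.7, 389.9, 335.1, 305.5]
```

Sorting by index rather than sorting each list separately keeps every score attached to its own cluster count.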
2 changes: 0 additions & 2 deletions examples/silhouette_plot_basic.py
@@ -1,6 +1,5 @@
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt

from sklearn_evaluation import plot

@@ -16,4 +15,3 @@

kmeans = KMeans(random_state=1, n_init=5)
plot.silhouette_analysis(X, kmeans, range_n_clusters=[3])
plt.show()
15 changes: 11 additions & 4 deletions examples/silhouette_plot_from_results.py
@@ -1,6 +1,5 @@
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt

from sklearn_evaluation import plot

@@ -14,7 +13,15 @@
random_state=1,
)

kmeans = KMeans(n_clusters=4, random_state=1, n_init=5)
cluster_labels = kmeans.fit_predict(X)
cluster_labels = []

# Cluster labels for four clusters
kmeans = KMeans(n_clusters=4, n_init=5)
cluster_labels.append(kmeans.fit_predict(X))

# Cluster labels for five clusters
kmeans = KMeans(n_clusters=5, n_init=5)
cluster_labels.append(kmeans.fit_predict(X))


plot.silhouette_analysis_from_results(X, cluster_labels)
plt.show()
19 changes: 13 additions & 6 deletions src/sklearn_evaluation/plot/__init__.py
@@ -21,15 +21,22 @@
silhouette_analysis,
silhouette_analysis_from_results,
)
from sklearn_evaluation.plot.regression \
import residuals, prediction_error, cooks_distance
from sklearn_evaluation.plot.regression import (
residuals,
prediction_error,
cooks_distance,
)
from sklearn_evaluation.plot.target_analysis import target_analysis
from sklearn_evaluation.plot.calibration import calibration_curve, scores_distribution
from sklearn_evaluation.plot.classification_report \
import classification_report, ClassificationReport
from sklearn_evaluation.plot.classification_report import (
classification_report,
ClassificationReport,
)
from sklearn_evaluation.plot.ks_statistics import ks_statistic
from sklearn_evaluation.plot.cumulative_gain_lift_curve \
import cumulative_gain, lift_curve
from sklearn_evaluation.plot.cumulative_gain_lift_curve import (
cumulative_gain,
lift_curve,
)
from sklearn_evaluation.plot.feature_ranking import Rank1D, Rank2D

__all__ = [
80 changes: 66 additions & 14 deletions src/sklearn_evaluation/plot/clustering.py
@@ -40,7 +40,23 @@
from ploomber_core.exceptions import modify_exceptions
from ploomber_core import deprecated

# TODO: add unit test

def _generate_axes(cluster, figsize, ax):
    """Return one axes per entry in `cluster`, creating them if needed."""
    if ax is not None:
        # Accept a single axes by wrapping it in a list
        if not isinstance(ax, list):
            ax = [ax]
        if len(cluster) != len(ax):
            raise ValueError(
                f"Received lengths, cluster: {len(cluster)}, "
                f"axes: {len(ax)}. "
                "Number of axes passed should match number of clusters"
            )
    else:
        # No axes passed: create a new figure and axes for each model
        ax = []
        for _ in range(len(cluster)):
            _, axes = plt.subplots(1, 1, figsize=figsize)
            ax.append(axes)
    return ax


@SKLearnEvaluationLogger.log(feature="plot")
@@ -139,9 +155,13 @@ def elbow_curve_from_results(n_clusters, sum_of_squares, times, ax=None):
"""
Same as `elbow_curve`, but it takes the number of clusters and sum of
squares as inputs. Useful if you want to train the models yourself.

Examples
--------
.. plot:: ../examples/elbow_curve_from_results.py

"""
# TODO: unit test this
# TODO: also test with unsorted input

idx = np.argsort(n_clusters)
n_clusters = np.array(n_clusters)[idx]
sum_of_squares = np.array(sum_of_squares)[idx]
@@ -183,6 +203,7 @@ def _clone_and_score_clusterer(clf, X, n_clusters):


@SKLearnEvaluationLogger.log(feature="plot")
@modify_exceptions
def silhouette_analysis(
X,
clf,
@@ -262,21 +283,24 @@ def silhouette_analysis(
"Cannot plot silhouette analysis."
)

for n_clusters in range_n_clusters:
_, ax = plt.subplots(1, 1, figsize=figsize)
# If the user passes no ax, create a new figure for each model

ax = _generate_axes(range_n_clusters, figsize, ax)

for ax, n_clusters in zip(ax, range_n_clusters):
clf = clone(clf)
setattr(clf, "n_clusters", n_clusters)
cluster_labels = clf.fit_predict(X)

ax = silhouette_analysis_from_results(
_silhouette_analysis_one_model(
X, cluster_labels, metric, figsize, cmap, text_fontsize, ax
)
return ax


@SKLearnEvaluationLogger.log(feature="plot")
@modify_exceptions
def silhouette_analysis_from_results(
def _silhouette_analysis_one_model(
X,
cluster_labels,
metric="euclidean",
@@ -286,12 +310,7 @@ def silhouette_analysis_from_results(
ax=None,
):
"""
Same as `silhouette_plot` but takes cluster_labels as input.
Useful if you want to train the model yourself

Notes
-----
.. versionadded:: 0.8.3
Generate a silhouette plot for one value of n_clusters.
"""
cluster_labels = np.asarray(cluster_labels)

@@ -365,3 +384,36 @@ def silhouette_analysis_from_results(
ax.tick_params(labelsize=text_fontsize)
ax.legend(loc="best", fontsize=text_fontsize)
return ax


@SKLearnEvaluationLogger.log(feature="plot")
@modify_exceptions
def silhouette_analysis_from_results(
    X,
    cluster_labels,
    metric="euclidean",
    figsize=None,
    cmap="nipy_spectral",
    text_fontsize="medium",
    ax=None,
):
    """
    Same as `silhouette_analysis`, but takes a list of cluster labels
    (one array per model) as input. Useful if you want to train the
    models yourself.

    Examples
    --------
    .. plot:: ../examples/silhouette_plot_from_results.py

    """

    # If the user passes no ax, create a new figure for each model
    ax = _generate_axes(cluster_labels, figsize, ax)

    for ax, label in zip(ax, cluster_labels):
        _silhouette_analysis_one_model(
            X, label, metric, figsize, cmap, text_fontsize, ax
        )
    return ax
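
For readers following the axes handling in this change, here is a minimal, dependency-free sketch of the rule `_generate_axes` appears to apply: one axes per entry, a single axes wrapped into a list, and a `ValueError` when the lengths differ. The names `pair_axes` and `make_axes` are hypothetical and not part of sklearn-evaluation; `make_axes` stands in for the `plt.subplots(1, 1, figsize=figsize)` call in the diff so the sketch runs without matplotlib.

```python
def pair_axes(clusters, ax=None, make_axes=object):
    """Return one axes-like object per entry in `clusters`.

    Hypothetical stand-in for the helper in the diff; `make_axes` replaces
    the matplotlib call so the sketch needs no plotting library.
    """
    if ax is None:
        # No axes supplied: create a fresh one for every model.
        return [make_axes() for _ in clusters]

    # A single axes is treated as a one-element list.
    axes = ax if isinstance(ax, list) else [ax]
    if len(axes) != len(clusters):
        raise ValueError(
            f"Received lengths, cluster: {len(clusters)}, axes: {len(axes)}. "
            "Number of axes passed should match number of clusters"
        )
    return axes


labels_per_model = [[0, 1, 1], [0, 1, 2]]  # cluster labels from two models
created = pair_axes(labels_per_model)  # two new axes-like objects
reused = pair_axes(labels_per_model, ax=["left", "right"])  # caller-supplied

try:
    pair_axes(labels_per_model, ax="only-one")  # one axes for two models
except ValueError as err:
    mismatch = str(err)
```

Under this rule, a caller who passes `ax` to `silhouette_analysis_from_results` would supply a list with one axes per label array.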