Lightning-fast text classification with scikit-learn integration.
LightningText is an adapter for using FastText's Python module with scikit-learn, enabling easy use of scikit-learn's features (cross validation, various metrics, multi-output, ...) with FastText.
Please note that while this project strives for maximum possible compatibility with scikit-learn, it is not currently possible to pass all tests executed by check_estimator, mostly due to FastText's behavior.
While this project builds upon both FastText and scikit-learn, it is an independent project not associated with either of the two.
pip install lightning-text
These are thin wrappers of the APIs exposed by FastText's Python module, adding only explicit declarations of parameters, their default values, and documentation. They pass their arguments directly to the respective counterparts in the FastText module:
tokenize
load_model
train_supervised
train_unsupervised
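For example, a minimal sketch of calling two of these wrappers exactly as one would call FastText's own functions (the top-level import path and the training file name are assumptions):

from lightning_text import tokenize, train_supervised

# Arguments are forwarded unchanged to FastText's train_supervised.
model = train_supervised(input='data.train.txt', epoch=10, lr=0.5)
print(tokenize('Hello world!'))  # -> ['Hello', 'world!']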
The FastTextClassifier class wraps supervised learning of FastText as a scikit-learn classifier.
To ensure compatibility with scikit-learn, the classifier requires the targets (labels) to be integers instead of strings. One way to encode string labels to integers is to use scikit-learn's LabelEncoder or MultiLabelBinarizer; a more convenient way to handle a dataset in FastText format is to use the Dataset class (see below).
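For instance, a minimal sketch of encoding string labels with LabelEncoder (a multi-label dataset would use MultiLabelBinarizer analogously):

from sklearn.preprocessing import LabelEncoder

label_encoder = LabelEncoder()
y = label_encoder.fit_transform(['ham', 'spam', 'ham'])  # -> array([0, 1, 0])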
The classifier also adds support for pickling, including serialization of the underlying FastText model.
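For example, a fitted classifier can be round-tripped with the standard pickle module (a sketch; classifier stands for any fitted FastTextClassifier):

import pickle

with open('model.pkl', 'wb') as f:
    pickle.dump(classifier, f)

with open('model.pkl', 'rb') as f:
    classifier = pickle.load(f)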
Scikit-learn uses integers to represent classification targets (classes), and by default, these are used as label names when fitting the underlying FastText model.
If, however, the text representations (usually the original names) of the classes should be known to the FastText model (e.g. when deploying the final model in a stand-alone way), a label encoder can be passed to FastTextClassifier's fit() method using the label_encoder parameter. FastTextClassifier supports both LabelEncoder (for binary and multi-class classification) and MultiLabelBinarizer (for multi-label classification).
Note that if you use the label_encoder parameter, the class names must not contain whitespace; otherwise you'll encounter exceptions during inference.
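A sketch of this flow (X_train and the string labels in y_train_text are placeholders):

from sklearn.preprocessing import LabelEncoder

from lightning_text import FastTextClassifier

label_encoder = LabelEncoder()
y_train = label_encoder.fit_transform(y_train_text)

classifier = FastTextClassifier()
# Passing the encoder lets the underlying FastText model use the original
# class names, which therefore must not contain whitespace.
classifier.fit(X_train, y_train, label_encoder=label_encoder)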
The module provides the preprocess_text utility for basic preprocessing of raw text for FastText. The function also supports optional removal of HTTP(S) URLs from the text.
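For example (a sketch; this README does not show preprocess_text's full signature, so only the basic single-argument form is shown):

from lightning_text import preprocess_text

# Basic normalization of raw text before passing it to FastText.
text = preprocess_text('Breaking news: visit https://example.com TODAY!')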
The Dataset class provides a convenient way of handling existing FastText datasets while representing the dataset in a scikit-learn-compatible way. Loading a FastText dataset using Dataset.load_from or Dataset.from_file will automatically convert the string labels to integers, using either scikit-learn's LabelEncoder or MultiLabelBinarizer (if any sample has multiple labels assigned) to do the conversion. These class methods return the used label encoder along with the created dataset, so a fitted model's predictions can be converted back to text labels using the encoder's inverse_transform method.
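A sketch of the intended flow, assuming the class methods return the dataset together with its label encoder as described above (the exact return shape and the file name are assumptions):

from lightning_text import Dataset

dataset, label_encoder = Dataset.load_from('data.train.txt')
# ... fit a classifier and predict as usual, then map the integer
# predictions (y_pred) back to the original text labels:
labels = label_encoder.inverse_transform(y_pred)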
LightningText provides a scikit-learn-style, Optuna-powered hyperparameter search with cross-validation, matching the APIs of scikit-learn's other hyperparameter searches, e.g. RandomizedSearchCV. See the OptunaSearchCV class for details and the example below.
LightningText provides additional utilities for interpreting scores:
get_fold_scores_per_candidate - Takes the results of a hyperparameter search with cross-validation (the cv_results_ field) and returns a dictionary of metric names to 2D numpy arrays of shape (n_candidates, n_folds).
robust_cv_score - Returns the harmonic mean of the provided scores. The harmonic mean puts more weight on lower scores, naturally penalizing high variance.
penalized_cv_score - Explicitly penalizes the scores by their standard deviation and the provided penalty_weight.
stability_score - Measures how many scores are within threshold of the mean and returns the ratio.
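A sketch tying these utilities together; only the function and parameter names come from this README, while the import path, the dictionary key, and the fitted search object are assumptions:

from lightning_text import (
    get_fold_scores_per_candidate,
    penalized_cv_score,
    robust_cv_score,
    stability_score,
)

# search is a fitted hyperparameter search such as the OptunaSearchCV
# example shown later; fold_scores maps metric names to arrays of shape
# (n_candidates, n_folds).
fold_scores = get_fold_scores_per_candidate(search.cv_results_)
for candidate_scores in fold_scores['fbeta']:  # key name is an assumption
    print(robust_cv_score(candidate_scores))
    print(penalized_cv_score(candidate_scores, penalty_weight=1.0))
    print(stability_score(candidate_scores, threshold=0.01))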
This example demonstrates use with a binary or multi-class (single-label) classification problem.
from lightning_text import FastTextClassifier
from sklearn.metrics import classification_report
classifier = FastTextClassifier()
classifier.fit(X_train, y_train)
y_pred = classifier.predict(X_test)
print(classification_report(y_test, y_pred))
There are two options for multi-label classification. The first option is using FastText's built-in one-vs-all (ova) loss:
from lightning_text import FastTextClassifier
from sklearn.metrics import classification_report, hamming_loss

classifier = FastTextClassifier(
    loss='ova',
)
classifier.fit(X_train, Y_train)

Y_pred = classifier.predict(X_test)
print(f'Hamming loss: {hamming_loss(Y_test, Y_pred)}')
print(classification_report(Y_test, Y_pred))
This classifier will be faster to fit and will occupy less space when saved; however, it requires tuning the decision threshold of its predict() method to be useful (see TunedThresholdClassifierCV as one option for achieving this).
The second option is training a binary classifier for each class using scikit-learn's MultiOutputClassifier meta-estimator:
from lightning_text import FastTextClassifier
from sklearn.metrics import classification_report, hamming_loss
from sklearn.multioutput import MultiOutputClassifier

# Binarizes the problem and trains a FastTextClassifier for predicting the
# label for each individual class.
classifier = MultiOutputClassifier(
    FastTextClassifier(
        verbose=0,
    ),
    n_jobs=4,
)
classifier.fit(X_train, Y_train)

Y_pred = classifier.predict(X_test)
print(f'Hamming loss: {hamming_loss(Y_test, Y_pred)}')
print(classification_report(Y_test, Y_pred))
This will train a binary FastText classifier for detecting each class using the one-vs-all strategy, and the resulting classifier will be usable right after fitting. However, it will be slower to fit and predict, and will occupy more space when saved, which could make it impractical if the number of classes is large.
Note that the use of StratifiedKFold is likely beneficial for correct evaluation of a binary classification problem. Stratification ranges from hard or impractical to impossible for multi-label problems.
from lightning_text import FastTextClassifier
from sklearn.model_selection import cross_validate

classifier = FastTextClassifier()
cv = cross_validate(
    classifier,
    X,
    y,
    cv=5,
    scoring='f1',
    n_jobs=4,
)
print(cv)
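Note that when cv is an integer and the estimator is a classifier, scikit-learn already applies stratified folds (without shuffling); to control shuffling and the random seed, an explicit StratifiedKFold can be passed instead:

from sklearn.model_selection import StratifiedKFold, cross_validate

cv = cross_validate(
    classifier,
    X,
    y,
    cv=StratifiedKFold(n_splits=5, shuffle=True, random_state=0),
    scoring='f1',
    n_jobs=4,
)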
from typing import Any

import optuna
from sklearn.metrics import fbeta_score, make_scorer

from lightning_text import FastTextClassifier
from lightning_text.optuna import (
    OptunaSearchCV,
    DEFAULT_SUPERVISED_TRAINING_HYPERPARAMETERS_SPACE,
)

def metrics_to_optuna_goals(metrics: dict[str, Any]) -> float:
    # Report the latest candidate's mean F-beta score to Optuna.
    last_mean_fbeta = metrics['mean_test_fbeta'][-1]
    return last_mean_fbeta

tries = 128

estimator = FastTextClassifier(
    verbose=0,
)
study = optuna.create_study(direction='maximize')
search = OptunaSearchCV(
    estimator=estimator,
    study=study,
    hyperparameters_space=DEFAULT_SUPERVISED_TRAINING_HYPERPARAMETERS_SPACE,
    n_iter=tries,
    # Naming the metric keeps the 'mean_test_fbeta' key and refit='fbeta'
    # consistent with each other.
    scoring={'fbeta': make_scorer(fbeta_score, pos_label=1, beta=1)},
    optuna_metrics_exporter=metrics_to_optuna_goals,
    n_jobs=4,
    refit='fbeta',
    cv=5,
    show_progress_bar=True,
)
search.fit(X, y)

print(search.cv_results_)
print(search.best_params_)
print(search.best_score_)
best_estimator = search.best_estimator_
lightning_text is distributed under the terms of the MIT license.