"""Contains methods for supervised learning.
Non supervised classifiers (random forest, k-nearest neighbors, neural
networks) for predicting the methylation class.
"""
import ast
import hashlib
import logging
import pickle
import re
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import numpy as np
import pandas as pd
from numpy.typing import ArrayLike
from sklearn.base import BaseEstimator, ClassifierMixin, TransformerMixin
from sklearn.decomposition import PCA
from sklearn.discriminant_analysis import (
LinearDiscriminantAnalysis,
QuadraticDiscriminantAnalysis,
)
from sklearn.ensemble import (
AdaBoostClassifier,
BaggingClassifier,
ExtraTreesClassifier,
GradientBoostingClassifier,
HistGradientBoostingClassifier,
RandomForestClassifier,
)
from sklearn.feature_selection import (
SelectKBest,
)
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.linear_model import (
LogisticRegression,
Perceptron,
RidgeClassifier,
SGDClassifier,
)
from sklearn.metrics import (
accuracy_score,
f1_score,
precision_score,
recall_score,
roc_auc_score,
)
from sklearn.model_selection import (
cross_val_predict,
)
from sklearn.multiclass import OneVsOneClassifier, OneVsRestClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import (
MinMaxScaler,
PowerTransformer,
QuantileTransformer,
RobustScaler,
StandardScaler,
)
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
logger = logging.getLogger(__name__)
class TrainedClassifier(ABC):
"""Abstract base class for a trained classifier."""
@abstractmethod
def predict_proba(
self,
betas: np.ndarray | pd.DataFrame,
id_: pd.Series | np.ndarray | Sequence | None = None,
) -> np.ndarray:
"""Predicts the probability of the given input samples (betas).
If `id_` is provided, identifiers corresponding to samples seen during
cross-validation are matched to cached probabilities. Probabilities for
new samples are computed using the trained classifier.
Args:
betas (array-like): Input samples to predict probabilities.
id_ (sequence of str, optional): Sample identifiers used to reuse
cross-validation probabilities when available.
Returns:
np.ndarray: Predicted probabilities for each sample.
"""
@abstractmethod
def classes(self) -> np.ndarray:
"""Returns the list of classes for which predictions can be made.
Returns:
array-like: The classes the classifier can predict.
"""
def info(self, output_format: str = "txt") -> str:
"""Provides additional information of the classifier or its pipeline.
The info will be printed on the top of the classifiers report. Should
contain classifier name and classifier metrics.
Returns:
str: A description of the classifier or its components.
"""
return ""
def model(self) -> Any:
"""Returns the classifier model (usually sklearn object)."""
return
def metrics(self) -> dict[str, Any]:
"""Returns the metric statistics (usually sklearn object)."""
return {}
def __str__(self) -> str:
hash_str = hashlib.blake2b(
self.info().encode(), digest_size=16
).hexdigest()
return f"TrainedClassifier_{hash_str}"
def __repr__(self) -> str:
return str(self)
def _get_pipeline_description(clf: Any, output_format: str = "txt") -> str:
"""Generates a detailed summary string of the pipeline structure."""
def format_non_default_params(step: Any) -> list[str]:
"""Formats non-default parameters for a given step or classifier."""
default_params = step.__class__().get_params()
current_params = step.get_params()
non_default_params = {
k: v
for k, v in current_params.items()
if default_params.get(k) != v
}
if output_format == "html":
return [
(
f"<tr><td class='clf-label'>- {param}</td>"
f"<td class='clf-value'>{value}</td></tr>"
)
for param, value in non_default_params.items()
]
return [
f"- {param}: {value}"
for param, value in non_default_params.items()
]
if output_format == "html":
lines = ["<h2>Classifier Structure</h2>", "<table>"]
else:
lines = ["Classifier Structure:"]
if hasattr(clf, "steps"):
step_name_len = max(len(name) for name, _ in clf.steps)
for name, step in clf.steps:
step_name = name.capitalize()
step_type = (
step if isinstance(step, str) else step.__class__.__name__
)
if output_format == "html":
lines.append(
f"<tr><td class='clf-label'>{step_name}</td>"
f"<td class='clf-value'>{step_type}</td></tr>"
)
else:
lines.append(f"{step_name:<{step_name_len}} : {step_type}")
if not isinstance(step, str):
lines.extend(format_non_default_params(step))
else:
clf_name = clf.__class__.__name__
if output_format == "html":
lines.append(
f"<tr><td class='clf-label'>Classifier</td>"
f"<td class='clf-value'>{clf_name}</td></tr>"
)
else:
lines.append(f"Classifier: {clf_name}")
lines.extend(format_non_default_params(clf))
if output_format == "html":
lines.append("</table>")
return "\n".join(lines)
def _format_metrics(metrics: dict[str, Any]) -> dict[str, str]:
"""Transforms classifier metrics to readable text."""
def format_metric_scores(scores: Any) -> str:
if isinstance(scores, list | np.ndarray):
return f"{np.mean(scores):.4f} (SD {np.std(scores):.4f})"
return str(scores)
formatted_stats = {}
metric_keys = {
"Method": "cv",
"Samples": "n_samples",
"Features": "n_features",
"Accuracy": "accuracy_scores",
"AUC": "auc_scores",
"F1-Score": "f1_scores",
"Precision": "precision_scores",
"Recall": "recall_scores",
}
for key, metric in metric_keys.items():
if metric in metrics:
if metric == "cv":
formatted_stats[key] = (
f"{metrics[metric].n_splits}-fold cross validation"
)
else:
formatted_stats[key] = format_metric_scores(metrics[metric])
return formatted_stats
class TrainedSklearnClassifier(TrainedClassifier):
"""Trained classifier implementation using fitted scikit-learn objecs."""
def __init__(
self,
clf: Any,
X: pd.DataFrame,
metrics: dict[str, Any] | None = None,
) -> None:
self.clf = clf
self._classes = clf.classes_
self.ids = X.index
self.shape = X.shape
self._metrics = metrics or {
"n_features": X.shape[1],
"n_samples": X.shape[0],
}
def predict_proba(
self,
betas: np.ndarray | pd.DataFrame,
id_: pd.Series | np.ndarray | Sequence | None = None,
) -> np.ndarray:
return self.clf.predict_proba(betas)
def classes(self) -> np.ndarray:
return self._classes
def info(self, output_format: str = "txt") -> str:
description = _get_pipeline_description(self.clf, output_format)
formatted_stats = _format_metrics(self._metrics)
if output_format == "html":
result = [f"{description}"]
result.append("<h2>Metrics</h2>")
result.append("<table>")
for key, value in formatted_stats.items():
result.append(
f"<tr><td class='metrics-label'>{key}</td>"
f"<td class='metrics-value'>{value}</td></tr>"
)
result.append("</table>")
else:
max_key_length = max(len(key) for key in formatted_stats)
result = [description + "\n", "Metrics:"]
for key, value in formatted_stats.items():
result.append(f"{key:<{max_key_length}} : {value}")
return "\n".join(result)
def metrics(self) -> dict[str, Any]:
return self._metrics
def model(self) -> Any:
return self.clf
class TrainedSklearnCVClassifier(TrainedClassifier):
"""A trained sklearn classifier with support for cross-validation.
This class allows the use of precomputed cross-validation probabilities or
standard predictions from a trained classifier for a given sample. It also
provides additional statistics about the classifier.
"""
def __init__(
self,
clf: Any,
probabilities_cv: np.ndarray,
X: pd.DataFrame,
metrics: dict[str, Any] | None = None,
) -> None:
self.clf = clf
self.probabilities_cv = pd.DataFrame(probabilities_cv, index=X.index)
self.ids = X.index
self.shape = X.shape
self._metrics = metrics or {}
self._classes = clf.classes_
def predict_proba(
self,
betas: np.ndarray | pd.DataFrame,
id_: pd.Series | np.ndarray | Sequence | None = None,
) -> np.ndarray:
if id_ is not None:
id_ = np.ravel(id_)
probabilities = np.zeros((len(betas), len(self._classes)))
if len(betas) != len(id_):
msg = "Length of 'betas' and 'id_' must match."
raise ValueError(msg)
mask_known_ids = np.isin(id_, self.ids)
if any(mask_known_ids):
probabilities[mask_known_ids] = self.probabilities_cv.loc[
id_[mask_known_ids]
].values
unknown_betas = betas[~mask_known_ids]
if len(unknown_betas) > 0:
probabilities[~mask_known_ids] = self.clf.predict_proba(
unknown_betas
)
return probabilities
return self.clf.predict_proba(betas)
def classes(self) -> np.ndarray:
return self._classes
def info(self, output_format: str = "txt") -> str:
description = _get_pipeline_description(self.clf, output_format)
formatted_stats = _format_metrics(self._metrics)
if output_format == "html":
result = [f"{description}"]
result.append("<h2>Metrics</h2>")
result.append("<table>")
for key, value in formatted_stats.items():
result.append(
f"<tr><td class='metrics-label'>{key}</td>"
f"<td class='metrics-value'>{value}</td></tr>"
)
result.append("</table>")
else:
max_key_length = max(len(key) for key in formatted_stats)
result = [description + "\n", "Metrics:"]
for key, value in formatted_stats.items():
result.append(f"{key:<{max_key_length}} : {value}")
return "\n".join(result)
def metrics(self) -> dict[str, Any]:
return self._metrics
def model(self) -> Any:
return self.clf
class VarianceThresholdLite(BaseEstimator, TransformerMixin):
"""low memory version of sklearn.feature_selection.VarianceThreshold."""
def __init__(self, threshold: float = 1e-4) -> None:
self.threshold = threshold
def fit(
self,
X: np.ndarray | pd.DataFrame,
y: ArrayLike | None = None,
) -> "VarianceThresholdLite":
X_array = X.values if isinstance(X, pd.DataFrame) else np.asarray(X)
if X_array.ndim != 2:
msg = "Input X must be 2D."
raise ValueError(msg)
self.variances_ = np.var(X_array, axis=0)
self.support_mask_ = self.variances_ > self.threshold
if not np.any(self.support_mask_):
msg = f"No features meet the variance threshold {self.threshold}"
raise ValueError(msg)
return self
def transform(self, X: np.ndarray | pd.DataFrame) -> np.ndarray:
X_array = X.values if isinstance(X, pd.DataFrame) else np.asarray(X)
return X_array[:, self.support_mask_]
def _get_support_mask(self) -> np.ndarray:
return self.variances_ > self.threshold
class TopVarianceSelector(BaseEstimator, TransformerMixin):
"""Low-memory selector for top-N features by variance."""
def __init__(self, n_top: int = 30000) -> None:
self.n_top = n_top
def fit(
self,
X: np.ndarray | pd.DataFrame,
y: ArrayLike | None = None,
) -> "TopVarianceSelector":
X_array = X.values if isinstance(X, pd.DataFrame) else np.asarray(X)
if X_array.ndim != 2:
msg = "Input X must be 2D."
raise ValueError(msg)
# Compute variance per column
self.variances_ = np.var(X_array, axis=0)
# Get indices of top N variances
if self.n_top > X_array.shape[1]:
msg = (
f"n_top={self.n_top} exceeds number of features "
"{X_array.shape[1]}"
)
raise ValueError(msg)
self.top_indices_ = np.argsort(self.variances_)[-self.n_top :]
return self
def transform(self, X: np.ndarray | pd.DataFrame) -> np.ndarray:
X_array = X.values if isinstance(X, pd.DataFrame) else np.asarray(X)
return X_array[:, self.top_indices_]
def get_support(
self,
indices: bool = False,
) -> np.ndarray:
if indices:
return self.top_indices_
mask = np.zeros_like(self.variances_, dtype=bool)
mask[self.top_indices_] = True
return mask
def parse_component_key(key: str) -> tuple[str, list[Any], dict[str, Any]]:
"""Parse component name and kwargs like 'lr(n_iter=200)'."""
match = re.match(r"^(\w+)(?:\((.*)\))?$", key)
if not match:
raise ValueError(f"Invalid component key format: {key}")
name, kwargs_str = match.groups()
assert name is not None
args = []
kwargs = {}
if kwargs_str:
tree = ast.parse(f"f({kwargs_str})", mode="eval")
if not isinstance(tree.body, ast.Call):
raise ValueError(f"Malformed kwargs in: {key}")
args = [ast.literal_eval(arg) for arg in tree.body.args]
kwargs = {
kw.arg: ast.literal_eval(kw.value)
for kw in tree.body.keywords
if kw.arg is not None
}
return name, args, kwargs
def make_clf_pipeline(
step_keys: str | list[str],
X_shape: tuple[int, int] | None = None,
cv: int | Any | None = None,
) -> Pipeline:
"""Sklearn pipeline with scaling, feature selection, and classifier."""
scalers = {
"minmax": MinMaxScaler,
"power": PowerTransformer,
"quantile": QuantileTransformer,
"robust": RobustScaler,
"std": StandardScaler,
}
selectors = {
"kbest": lambda **kwargs: SelectKBest(**{"k": 10000, **kwargs}),
"pca": PCA,
"top": TopVarianceSelector,
"vtl": VarianceThresholdLite,
}
if X_shape is not None and cv:
n_splits = cv if isinstance(cv, int) else cv.n_splits
n_components_pca = min(((n_splits - 1) * X_shape[0] // n_splits), 50)
selectors["pca_auto"] = lambda **kwargs: PCA(
**{"n_components": n_components_pca, **kwargs}
)
models = {
"ada": AdaBoostClassifier,
"bag": BaggingClassifier,
"dt": DecisionTreeClassifier,
"et": ExtraTreesClassifier,
"gb": GradientBoostingClassifier,
"gp": GaussianProcessClassifier,
"hgb": HistGradientBoostingClassifier,
"knn": KNeighborsClassifier,
"lda": LinearDiscriminantAnalysis,
"lr": LogisticRegression,
"mlp": MLPClassifier,
"nb": GaussianNB,
"perceptron": Perceptron,
"qda": QuadraticDiscriminantAnalysis,
"rf": RandomForestClassifier,
"ridge": RidgeClassifier,
"sgd": SGDClassifier,
"svc": lambda **kwargs: SVC(**{"probability": True, **kwargs}),
}
components = {
**{key: ("scaler", val) for key, val in scalers.items()},
**{key: ("feature_selector", val) for key, val in selectors.items()},
**{key: ("classifier", val) for key, val in models.items()},
}
components_full_name = {
clf().__class__.__name__: (name, clf)
for name, clf in components.values()
}
components.update(components_full_name)
if isinstance(step_keys, str):
step_keys = step_keys.split("-")
pipeline_components = []
for i, full_key in enumerate(step_keys):
base_key, args, kwargs = parse_component_key(full_key)
if base_key not in components:
raise ValueError(f"Invalid step key: {base_key}")
type_name, constructor = components[base_key]
component = constructor(*args, **kwargs)
step_name = f"{i + 1}-{type_name}"
pipeline_components.append((step_name, component))
return Pipeline(pipeline_components)
def make_reports(
prediction: pd.DataFrame,
info: str,
output_format: str = "txt",
) -> list[str]:
"""Generates detailed reports from classifier predictions.
Args:
prediction (pd.DataFrame): DataFrame containing predicted probabilities
for each class, indexed by sample IDs.
info (str): A description of the classifier such as its name and
metrics. Will be printed before predictions.
output_format (str): The format of the report ('txt' or 'html').
Defaults to 'txt'.
Returns:
list[str]: A list of detailed string reports, one for each sample.
"""
reports = []
for sample_id, row in prediction.iterrows():
top_predictions = row[row > 0].nlargest(10) * 100
if output_format == "txt":
report_lines = [
str(sample_id),
"=" * len(str(sample_id)),
f"\n{info}\n",
"Classification Probability:",
]
label_len = max(len(str(label)) for label in top_predictions.index)
formatted_lines = [
f"{label:<{label_len}} : {probability:6.2f} %"
for label, probability in top_predictions.items()
]
n_dashes = (
max(len(line) for line in formatted_lines)
if formatted_lines
else 0
)
dashes = "-" * n_dashes
report_lines.extend([dashes, *formatted_lines, dashes])
if output_format == "html":
report_lines = [
f"<h1>{sample_id}</h1>",
f"{info}",
"<h2>Classification Probability</h2>",
"<div class='classification-result'>",
"<table>",
]
report_lines.extend(
(
f"<tr><td class='prob-label'>{label}</td>"
f"<td class='prob-value'>{probability:.2f}%</td></tr>"
)
for label, probability in top_predictions.items()
)
report_lines.extend(["</table>", "</div>"])
reports.append("\n".join(report_lines))
return reports
def make_classifier_report_page(
reports: Sequence[str],
title: str | None = None,
) -> str:
"""Generates an HTML page for multiple classifier reports.
Args:
reports (list of str): A list of HTML reports obtained from
'MethylAnalysis.classify'.
itle (str): Title of the report page. Defaults to 'None'.
Returns:
str: A formatted HTML string.
"""
title_str = f"<h1>{title}</h1><hr>" if title else ""
body_string = "\n<hr>\n".join(reports)
return f"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Classifier Report</title>
<style>
body {{
font-family: Arial, sans-serif;
background: #f4f7f9;
color: #333;
margin: 0 auto;
max-width: 1000px;
line-height: 1.4;
}}
h1 {{
text-align: center;
margin-top: 10px;
font-size: 1.5em;
}}
h2 {{
font-size: 1.2em;
margin-bottom: 5px;
}}
table {{
width: 100%;
max-width: 500px;
margin: 10px auto;
border-collapse: collapse;
}}
hr {{
margin: 20px 0;
border: none;
height: 2px;
background-color: #ccc;
}}
th, td {{
padding: 3px 8px;
text-align: left;
border-bottom: 1px solid #ddd;
font-size: 0.9em;
}}
td.clf-label, td.metrics-label, td.prob-label {{
font-weight: bold;
}}
td.clf-value, td.metrics-value, td.prob-value {{
text-align: right;
}}
.classification-result td {{
background: #ffebe6;
}}
.classification-result-k {{
background: #ffe4b5;
font-weight: bold;
color: #d85c5c;
}}
@media (max-width: 768px) {{
body {{
max-width: 95%;
}}
table {{
max-width: 100%;
}}
}}
</style>
</head>
<body>
{title_str}
{body_string}
</body>
</html>"""
def _is_ovr_or_ovo(clf: Any) -> str:
"""Check if classifier is using One-vs-Rest (OvR) or One-vs-One (OvO)."""
if isinstance(clf, OneVsRestClassifier):
return "ovr"
if isinstance(clf, OneVsOneClassifier):
return "ovo"
if hasattr(clf, "decision_function_shape"):
if clf.decision_function_shape == "ovr":
return "ovr"
if clf.decision_function_shape == "ovo":
return "ovo"
# Default for most classifiers
return "ovr"
def cross_val_metrics(
clf: Any,
X: pd.DataFrame | np.ndarray,
y: pd.Series | np.ndarray | Sequence,
probabilities_cv: np.ndarray,
cv: Any,
) -> dict[str, Any]:
"""Calculates cross-validation statistics for a classifier."""
y_pred_cv = np.array(
[clf.classes_[i] for i in np.argmax(probabilities_cv, axis=1)]
)
accuracy_scores = []
auc_scores = []
f1_scores = []
multi_class_strategy = _is_ovr_or_ovo(clf)
precision_scores = []
recall_scores = []
for _, test_idx in cv.split(X, y):
y_true = np.array(y)[test_idx]
probabilities_cv_test = probabilities_cv[test_idx, :]
y_pred = y_pred_cv[test_idx]
accuracy_scores.append(accuracy_score(y_true, y_pred))
f1_scores.append(f1_score(y_true, y_pred, average="weighted"))
precision_scores.append(
precision_score(
y_true, y_pred, average="weighted", zero_division=0
)
)
recall_scores.append(
recall_score(y_true, y_pred, average="weighted", zero_division=0)
)
# Calculate ROC AUC score
if len(np.unique(y_true)) == 2:
# Binary classification: use probabilities for the positive class
roc_auc = roc_auc_score(y_true, probabilities_cv_test[:, 1])
else:
# Multiclass classification: use the full probabilities matrix
roc_auc = roc_auc_score(
y_true,
probabilities_cv_test,
multi_class=multi_class_strategy,
average="weighted",
)
auc_scores.append(roc_auc)
return {
"cv": cv,
"n_samples": X.shape[0],
"n_features": X.shape[1],
"accuracy_scores": accuracy_scores,
"auc_scores": auc_scores,
"f1_scores": f1_scores,
"precision_scores": precision_scores,
"recall_scores": recall_scores,
}
def is_trained(clf: Any) -> bool:
"""Checks if the given classifier has already been trained."""
clf_ = clf
if hasattr(clf, "steps"):
clf_ = clf.steps[-1][1]
trained_attributes = ["classes_"]
return all(hasattr(clf_, attr) for attr in trained_attributes)
def train_clf(
clf: Any,
X: pd.DataFrame,
y: ArrayLike,
cv: int | Any,
n_jobs: int = 1,
) -> TrainedSklearnClassifier | TrainedSklearnCVClassifier:
"""Trains a classifier and stores the trained model to disk.
If the classifier has already been trained (and saved), it loads the
trained model. If not, it trains the classifier using the provided data and
saves it to disk.
Args:
clf (classifier): The classifier to train or load.
X (array-like): The feature matrix.
y (array-like): The target labels.
cv (int or cross-validation generator): Determines the cross-validation
splitting strategy.
n_jobs (int): Number of parallel processes to run.
Returns:
TrainedSklearnClassifier or TrainedSklearnCVClassifier: The trained
classifier object.
"""
n_splits = cv if isinstance(cv, int) else cv.n_splits
if is_trained(clf):
return TrainedSklearnClassifier(clf, X=X)
logger.info("Start training...")
counts = np.unique(y, return_counts=True)
unique_classes, counts_per_class = counts
if min(counts_per_class) < n_splits:
# Find the class with the minimum count
insufficient_idx = np.where(counts_per_class < n_splits)[0]
insufficient_list = [
f" - {unique_classes[i]} (n={counts_per_class[i]})"
for i in insufficient_idx
]
insufficient_str = "\n".join(insufficient_list)
raise ValueError(
f"The following classes have fewer samples than cross-validation "
f"splits (n={n_splits}):\n{insufficient_str}\n\n Please check "
"both your annotation file and that the corresponding samples "
"exist on disk.\n You can reduce 'cv' (e.g., set cv=1 to train "
"without cross-validation), or provide more samples for these "
"classes."
)
clf.fit(X, y)
logger.info("Start cross-validation...")
probabilities_cv = cross_val_predict(
clf, X, y, cv=cv, method="predict_proba", n_jobs=n_jobs
)
metrics = cross_val_metrics(clf, X, y, probabilities_cv, cv)
trained_clf = TrainedSklearnCVClassifier(
clf=clf, probabilities_cv=probabilities_cv, X=X, metrics=metrics
)
return trained_clf
@dataclass
class ClassifierResult:
"""Data container for evaluation of classifier."""
prediction: Any
model: Any
metrics: dict
reports: dict
[docs]
def fit_and_evaluate_clf(
X: pd.DataFrame | None,
y: ArrayLike | None,
X_test: pd.DataFrame | np.ndarray,
id_test: Sequence[str] | None,
save_path: Path,
clf: Any,
cv: int | Any,
n_jobs: int = 1,
) -> ClassifierResult:
"""Predicts the methylation class by supervised learning classifier.
Uses supervised machine learning classifiers (Random Forest, K-Nearest
Neighbors, Neural Networks, SVM, ...) to predict the methylation class of
the sample. Output will be written to disk.
Args:
X (pd.DataFrame): Feature matrix (rows as samples, columns as
features).
y (array-like): Class labels.
X_test (array-like): Value of the sample to be evaluated.
id_test (str): Unique identifiers for the test samples to be evaluated.
save_path (str or Path): Path where the classifiers and results
will be saved/cached.
clf (list): Classifier to use. Can be:
- A scikit-learn classifier object or pipeline (trained or
untrained).
- A string in the format "scaler-selector-classifier". Possible
values are:
- A pipeline string composed of arbitrary components joined by
dashes ("-"). Each component can be specified using either an
abbreviation or the full class name (e.g., "std" or
"StandardScaler").:
scaler:
- "std": Standard scaling (StandardScaler).
- "minmax": Min-max scaling (MinMaxScaler).
- "robust": Robust scaling (RobustScaler).
- "power": Power transformation (PowerTransformer).
- "quantile": Quantile transformation
(QuantileTransformer).
selector:
- "kbest": Select the best features (SelectKBest).
- "top": To varying features (TopVarianceSelector).
- "pca": Principal component analysis (PCA).
- "pca_auto": Principal component analysis (PCA). Number of
components is determined automatically.
- "lda": Linear Discriminant Analysis (LDA).
clf:
- "rf": RandomForestClassifier.
- "lr": LogisticRegression.
- "et": ExtraTreesClassifier.
- "knn": KNeighborsClassifier.
- "mlp": MLPClassifier.
- "svc": Support Vector Classifier (SVC).
- "ada": AdaBoostClassifier.
- "bag": BaggingClassifier.
- "dt": DecisionTreeClassifier.
- "gp": GaussianProcessClassifier.
- "hgb": HistGradientBoostingClassifier.
- "nb": GaussianNB.
- "perceptron": Perceptron.
- "qda": Quadratic Discriminant Analysis (QDA).
- "ridge": RidgeClassifier.
- "sgd": SGDClassifier.
Example: Using a feature selector and a classifier (SelectKBest
selection and Logistic Regression):
- "kbest-lr"
- A custom class, that inherits from `TrainedClassifier`.
cv (int or cross-validation generator): Determines the cross-validation
splitting strategy.
n_jobs (int): Number of parallel processes to run.
Returns:
ClassifierResult:
- prediction (DataFrame): DataFrame containing the predicted
probabilities for each class.
- model (object): The trained classifier object.
- metrics (dict): Dict containing classifier metrics.
- reports (dict): Dict of evaluation report (both 'txt' and 'html')
for each sample.
"""
if save_path.exists():
logger.info("Reading trained and cached classifier.")
with save_path.open("rb") as file:
trained_clf = pickle.load(file)
elif isinstance(clf, Pipeline | ClassifierMixin):
if X is None or y is None:
raise ValueError("Both X and y must be different from 'None'.")
trained_clf = train_clf(clf=clf, X=X, y=y, cv=cv, n_jobs=n_jobs)
elif isinstance(clf, str):
if X is None or y is None:
raise ValueError("Both X and y must be different from 'None'.")
pipeline = make_clf_pipeline(clf, X.shape, cv)
trained_clf = train_clf(clf=pipeline, X=X, y=y, cv=cv, n_jobs=n_jobs)
else:
trained_clf = clf
classes = trained_clf.classes()
probabilities = trained_clf.predict_proba(X_test, id_test)
info_txt = trained_clf.info(output_format="txt")
info_html = trained_clf.info(output_format="html")
prediction = pd.DataFrame(probabilities, index=id_test, columns=classes)
reports_txt = make_reports(prediction, info_txt, output_format="txt")
reports_html = make_reports(prediction, info_html, output_format="html")
result = ClassifierResult(
prediction,
trained_clf.model(),
trained_clf.metrics(),
{"txt": reports_txt, "html": reports_html},
)
if not save_path.exists():
with save_path.open("wb") as file:
pickle.dump(trained_clf, file)
return result