Source code for mepylome.analysis.methyl_clf

"""Contains methods for supervised learning.

Non supervised classifiers (random forest, k-nearest neighbors, neural
networks) for predicting the methylation class.
"""

import ast
import hashlib
import logging
import pickle
import re
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import numpy as np
import pandas as pd
from numpy.typing import ArrayLike
from sklearn.base import BaseEstimator, ClassifierMixin, TransformerMixin
from sklearn.decomposition import PCA
from sklearn.discriminant_analysis import (
    LinearDiscriminantAnalysis,
    QuadraticDiscriminantAnalysis,
)
from sklearn.ensemble import (
    AdaBoostClassifier,
    BaggingClassifier,
    ExtraTreesClassifier,
    GradientBoostingClassifier,
    HistGradientBoostingClassifier,
    RandomForestClassifier,
)
from sklearn.feature_selection import (
    SelectKBest,
)
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.linear_model import (
    LogisticRegression,
    Perceptron,
    RidgeClassifier,
    SGDClassifier,
)
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
)
from sklearn.model_selection import (
    cross_val_predict,
)
from sklearn.multiclass import OneVsOneClassifier, OneVsRestClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import (
    MinMaxScaler,
    PowerTransformer,
    QuantileTransformer,
    RobustScaler,
    StandardScaler,
)
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier

logger = logging.getLogger(__name__)


class TrainedClassifier(ABC):
    """Abstract base class for a trained classifier."""

    @abstractmethod
    def predict_proba(
        self,
        betas: np.ndarray | pd.DataFrame,
        id_: pd.Series | np.ndarray | Sequence | None = None,
    ) -> np.ndarray:
        """Predicts the probability of the given input samples (betas).

        If `id_` is provided, identifiers corresponding to samples seen during
        cross-validation are matched to cached probabilities. Probabilities for
        new samples are computed using the trained classifier.

        Args:
            betas (array-like): Input samples to predict probabilities.
            id_ (sequence of str, optional): Sample identifiers used to reuse
                cross-validation probabilities when available.

        Returns:
            np.ndarray: Predicted probabilities for each sample.
        """

    @abstractmethod
    def classes(self) -> np.ndarray:
        """Returns the list of classes for which predictions can be made.

        Returns:
            array-like: The classes the classifier can predict.
        """

    def info(self, output_format: str = "txt") -> str:
        """Provides additional information of the classifier or its pipeline.

        The info will be printed on the top of the classifiers report. Should
        contain classifier name and classifier metrics.

        Returns:
            str: A description of the classifier or its components.
        """
        return ""

    def model(self) -> Any:
        """Returns the classifier model (usually sklearn object)."""
        return

    def metrics(self) -> dict[str, Any]:
        """Returns the metric statistics (usually sklearn object)."""
        return {}

    def __str__(self) -> str:
        hash_str = hashlib.blake2b(
            self.info().encode(), digest_size=16
        ).hexdigest()
        return f"TrainedClassifier_{hash_str}"

    def __repr__(self) -> str:
        return str(self)


def _get_pipeline_description(clf: Any, output_format: str = "txt") -> str:
    """Generates a detailed summary string of the pipeline structure."""

    def format_non_default_params(step: Any) -> list[str]:
        """Formats non-default parameters for a given step or classifier."""
        default_params = step.__class__().get_params()
        current_params = step.get_params()
        non_default_params = {
            k: v
            for k, v in current_params.items()
            if default_params.get(k) != v
        }
        if output_format == "html":
            return [
                (
                    f"<tr><td class='clf-label'>- {param}</td>"
                    f"<td class='clf-value'>{value}</td></tr>"
                )
                for param, value in non_default_params.items()
            ]
        return [
            f"- {param}: {value}"
            for param, value in non_default_params.items()
        ]

    if output_format == "html":
        lines = ["<h2>Classifier Structure</h2>", "<table>"]
    else:
        lines = ["Classifier Structure:"]
    if hasattr(clf, "steps"):
        step_name_len = max(len(name) for name, _ in clf.steps)
        for name, step in clf.steps:
            step_name = name.capitalize()
            step_type = (
                step if isinstance(step, str) else step.__class__.__name__
            )
            if output_format == "html":
                lines.append(
                    f"<tr><td class='clf-label'>{step_name}</td>"
                    f"<td class='clf-value'>{step_type}</td></tr>"
                )
            else:
                lines.append(f"{step_name:<{step_name_len}} : {step_type}")
            if not isinstance(step, str):
                lines.extend(format_non_default_params(step))
    else:
        clf_name = clf.__class__.__name__
        if output_format == "html":
            lines.append(
                f"<tr><td class='clf-label'>Classifier</td>"
                f"<td class='clf-value'>{clf_name}</td></tr>"
            )
        else:
            lines.append(f"Classifier: {clf_name}")
        lines.extend(format_non_default_params(clf))

    if output_format == "html":
        lines.append("</table>")

    return "\n".join(lines)


def _format_metrics(metrics: dict[str, Any]) -> dict[str, str]:
    """Transforms classifier metrics to readable text."""

    def format_metric_scores(scores: Any) -> str:
        if isinstance(scores, list | np.ndarray):
            return f"{np.mean(scores):.4f} (SD {np.std(scores):.4f})"
        return str(scores)

    formatted_stats = {}

    metric_keys = {
        "Method": "cv",
        "Samples": "n_samples",
        "Features": "n_features",
        "Accuracy": "accuracy_scores",
        "AUC": "auc_scores",
        "F1-Score": "f1_scores",
        "Precision": "precision_scores",
        "Recall": "recall_scores",
    }

    for key, metric in metric_keys.items():
        if metric in metrics:
            if metric == "cv":
                formatted_stats[key] = (
                    f"{metrics[metric].n_splits}-fold cross validation"
                )
            else:
                formatted_stats[key] = format_metric_scores(metrics[metric])

    return formatted_stats


class TrainedSklearnClassifier(TrainedClassifier):
    """Trained classifier implementation using fitted scikit-learn objecs."""

    def __init__(
        self,
        clf: Any,
        X: pd.DataFrame,
        metrics: dict[str, Any] | None = None,
    ) -> None:
        self.clf = clf
        self._classes = clf.classes_
        self.ids = X.index
        self.shape = X.shape
        self._metrics = metrics or {
            "n_features": X.shape[1],
            "n_samples": X.shape[0],
        }

    def predict_proba(
        self,
        betas: np.ndarray | pd.DataFrame,
        id_: pd.Series | np.ndarray | Sequence | None = None,
    ) -> np.ndarray:
        return self.clf.predict_proba(betas)

    def classes(self) -> np.ndarray:
        return self._classes

    def info(self, output_format: str = "txt") -> str:
        description = _get_pipeline_description(self.clf, output_format)
        formatted_stats = _format_metrics(self._metrics)

        if output_format == "html":
            result = [f"{description}"]
            result.append("<h2>Metrics</h2>")
            result.append("<table>")
            for key, value in formatted_stats.items():
                result.append(
                    f"<tr><td class='metrics-label'>{key}</td>"
                    f"<td class='metrics-value'>{value}</td></tr>"
                )
            result.append("</table>")
        else:
            max_key_length = max(len(key) for key in formatted_stats)
            result = [description + "\n", "Metrics:"]
            for key, value in formatted_stats.items():
                result.append(f"{key:<{max_key_length}} : {value}")

        return "\n".join(result)

    def metrics(self) -> dict[str, Any]:
        return self._metrics

    def model(self) -> Any:
        return self.clf


class TrainedSklearnCVClassifier(TrainedClassifier):
    """A trained sklearn classifier with support for cross-validation.

    This class allows the use of precomputed cross-validation probabilities or
    standard predictions from a trained classifier for a given sample. It also
    provides additional statistics about the classifier.
    """

    def __init__(
        self,
        clf: Any,
        probabilities_cv: np.ndarray,
        X: pd.DataFrame,
        metrics: dict[str, Any] | None = None,
    ) -> None:
        self.clf = clf
        self.probabilities_cv = pd.DataFrame(probabilities_cv, index=X.index)
        self.ids = X.index
        self.shape = X.shape
        self._metrics = metrics or {}
        self._classes = clf.classes_

    def predict_proba(
        self,
        betas: np.ndarray | pd.DataFrame,
        id_: pd.Series | np.ndarray | Sequence | None = None,
    ) -> np.ndarray:
        if id_ is not None:
            id_ = np.ravel(id_)
            probabilities = np.zeros((len(betas), len(self._classes)))

            if len(betas) != len(id_):
                msg = "Length of 'betas' and 'id_' must match."
                raise ValueError(msg)

            mask_known_ids = np.isin(id_, self.ids)
            if any(mask_known_ids):
                probabilities[mask_known_ids] = self.probabilities_cv.loc[
                    id_[mask_known_ids]
                ].values
            unknown_betas = betas[~mask_known_ids]
            if len(unknown_betas) > 0:
                probabilities[~mask_known_ids] = self.clf.predict_proba(
                    unknown_betas
                )
            return probabilities

        return self.clf.predict_proba(betas)

    def classes(self) -> np.ndarray:
        return self._classes

    def info(self, output_format: str = "txt") -> str:
        description = _get_pipeline_description(self.clf, output_format)
        formatted_stats = _format_metrics(self._metrics)

        if output_format == "html":
            result = [f"{description}"]
            result.append("<h2>Metrics</h2>")
            result.append("<table>")
            for key, value in formatted_stats.items():
                result.append(
                    f"<tr><td class='metrics-label'>{key}</td>"
                    f"<td class='metrics-value'>{value}</td></tr>"
                )
            result.append("</table>")
        else:
            max_key_length = max(len(key) for key in formatted_stats)
            result = [description + "\n", "Metrics:"]
            for key, value in formatted_stats.items():
                result.append(f"{key:<{max_key_length}} : {value}")

        return "\n".join(result)

    def metrics(self) -> dict[str, Any]:
        return self._metrics

    def model(self) -> Any:
        return self.clf


class VarianceThresholdLite(BaseEstimator, TransformerMixin):
    """low memory version of sklearn.feature_selection.VarianceThreshold."""

    def __init__(self, threshold: float = 1e-4) -> None:
        self.threshold = threshold

    def fit(
        self,
        X: np.ndarray | pd.DataFrame,
        y: ArrayLike | None = None,
    ) -> "VarianceThresholdLite":
        X_array = X.values if isinstance(X, pd.DataFrame) else np.asarray(X)
        if X_array.ndim != 2:
            msg = "Input X must be 2D."
            raise ValueError(msg)
        self.variances_ = np.var(X_array, axis=0)
        self.support_mask_ = self.variances_ > self.threshold
        if not np.any(self.support_mask_):
            msg = f"No features meet the variance threshold {self.threshold}"
            raise ValueError(msg)
        return self

    def transform(self, X: np.ndarray | pd.DataFrame) -> np.ndarray:
        X_array = X.values if isinstance(X, pd.DataFrame) else np.asarray(X)
        return X_array[:, self.support_mask_]

    def _get_support_mask(self) -> np.ndarray:
        return self.variances_ > self.threshold


class TopVarianceSelector(BaseEstimator, TransformerMixin):
    """Low-memory selector for top-N features by variance."""

    def __init__(self, n_top: int = 30000) -> None:
        self.n_top = n_top

    def fit(
        self,
        X: np.ndarray | pd.DataFrame,
        y: ArrayLike | None = None,
    ) -> "TopVarianceSelector":
        X_array = X.values if isinstance(X, pd.DataFrame) else np.asarray(X)
        if X_array.ndim != 2:
            msg = "Input X must be 2D."
            raise ValueError(msg)

        # Compute variance per column
        self.variances_ = np.var(X_array, axis=0)

        # Get indices of top N variances
        if self.n_top > X_array.shape[1]:
            msg = (
                f"n_top={self.n_top} exceeds number of features "
                "{X_array.shape[1]}"
            )
            raise ValueError(msg)
        self.top_indices_ = np.argsort(self.variances_)[-self.n_top :]

        return self

    def transform(self, X: np.ndarray | pd.DataFrame) -> np.ndarray:
        X_array = X.values if isinstance(X, pd.DataFrame) else np.asarray(X)
        return X_array[:, self.top_indices_]

    def get_support(
        self,
        indices: bool = False,
    ) -> np.ndarray:
        if indices:
            return self.top_indices_
        mask = np.zeros_like(self.variances_, dtype=bool)
        mask[self.top_indices_] = True
        return mask


def parse_component_key(key: str) -> tuple[str, list[Any], dict[str, Any]]:
    """Parse component name and kwargs like 'lr(n_iter=200)'."""
    match = re.match(r"^(\w+)(?:\((.*)\))?$", key)
    if not match:
        raise ValueError(f"Invalid component key format: {key}")
    name, kwargs_str = match.groups()
    assert name is not None
    args = []
    kwargs = {}
    if kwargs_str:
        tree = ast.parse(f"f({kwargs_str})", mode="eval")
        if not isinstance(tree.body, ast.Call):
            raise ValueError(f"Malformed kwargs in: {key}")
        args = [ast.literal_eval(arg) for arg in tree.body.args]
        kwargs = {
            kw.arg: ast.literal_eval(kw.value)
            for kw in tree.body.keywords
            if kw.arg is not None
        }
    return name, args, kwargs


def make_clf_pipeline(
    step_keys: str | list[str],
    X_shape: tuple[int, int] | None = None,
    cv: int | Any | None = None,
) -> Pipeline:
    """Sklearn pipeline with scaling, feature selection, and classifier."""
    scalers = {
        "minmax": MinMaxScaler,
        "power": PowerTransformer,
        "quantile": QuantileTransformer,
        "robust": RobustScaler,
        "std": StandardScaler,
    }
    selectors = {
        "kbest": lambda **kwargs: SelectKBest(**{"k": 10000, **kwargs}),
        "pca": PCA,
        "top": TopVarianceSelector,
        "vtl": VarianceThresholdLite,
    }
    if X_shape is not None and cv:
        n_splits = cv if isinstance(cv, int) else cv.n_splits
        n_components_pca = min(((n_splits - 1) * X_shape[0] // n_splits), 50)
        selectors["pca_auto"] = lambda **kwargs: PCA(
            **{"n_components": n_components_pca, **kwargs}
        )
    models = {
        "ada": AdaBoostClassifier,
        "bag": BaggingClassifier,
        "dt": DecisionTreeClassifier,
        "et": ExtraTreesClassifier,
        "gb": GradientBoostingClassifier,
        "gp": GaussianProcessClassifier,
        "hgb": HistGradientBoostingClassifier,
        "knn": KNeighborsClassifier,
        "lda": LinearDiscriminantAnalysis,
        "lr": LogisticRegression,
        "mlp": MLPClassifier,
        "nb": GaussianNB,
        "perceptron": Perceptron,
        "qda": QuadraticDiscriminantAnalysis,
        "rf": RandomForestClassifier,
        "ridge": RidgeClassifier,
        "sgd": SGDClassifier,
        "svc": lambda **kwargs: SVC(**{"probability": True, **kwargs}),
    }
    components = {
        **{key: ("scaler", val) for key, val in scalers.items()},
        **{key: ("feature_selector", val) for key, val in selectors.items()},
        **{key: ("classifier", val) for key, val in models.items()},
    }
    components_full_name = {
        clf().__class__.__name__: (name, clf)
        for name, clf in components.values()
    }
    components.update(components_full_name)
    if isinstance(step_keys, str):
        step_keys = step_keys.split("-")
    pipeline_components = []
    for i, full_key in enumerate(step_keys):
        base_key, args, kwargs = parse_component_key(full_key)
        if base_key not in components:
            raise ValueError(f"Invalid step key: {base_key}")
        type_name, constructor = components[base_key]
        component = constructor(*args, **kwargs)
        step_name = f"{i + 1}-{type_name}"
        pipeline_components.append((step_name, component))

    return Pipeline(pipeline_components)


def make_reports(
    prediction: pd.DataFrame,
    info: str,
    output_format: str = "txt",
) -> list[str]:
    """Generates detailed reports from classifier predictions.

    Args:
        prediction (pd.DataFrame): DataFrame containing predicted probabilities
            for each class, indexed by sample IDs.
        info (str): A description of the classifier such as its name and
            metrics. Will be printed before predictions.
        output_format (str): The format of the report ('txt' or 'html').
            Defaults to 'txt'.

    Returns:
        list[str]: A list of detailed string reports, one for each sample.
    """
    reports = []
    for sample_id, row in prediction.iterrows():
        top_predictions = row[row > 0].nlargest(10) * 100
        if output_format == "txt":
            report_lines = [
                str(sample_id),
                "=" * len(str(sample_id)),
                f"\n{info}\n",
                "Classification Probability:",
            ]
            label_len = max(len(str(label)) for label in top_predictions.index)
            formatted_lines = [
                f"{label:<{label_len}} : {probability:6.2f} %"
                for label, probability in top_predictions.items()
            ]
            n_dashes = (
                max(len(line) for line in formatted_lines)
                if formatted_lines
                else 0
            )
            dashes = "-" * n_dashes
            report_lines.extend([dashes, *formatted_lines, dashes])
        if output_format == "html":
            report_lines = [
                f"<h1>{sample_id}</h1>",
                f"{info}",
                "<h2>Classification Probability</h2>",
                "<div class='classification-result'>",
                "<table>",
            ]
            report_lines.extend(
                (
                    f"<tr><td class='prob-label'>{label}</td>"
                    f"<td class='prob-value'>{probability:.2f}%</td></tr>"
                )
                for label, probability in top_predictions.items()
            )
            report_lines.extend(["</table>", "</div>"])
        reports.append("\n".join(report_lines))
    return reports


def make_classifier_report_page(
    reports: Sequence[str],
    title: str | None = None,
) -> str:
    """Generates an HTML page for multiple classifier reports.

    Args:
        reports (list of str): A list of HTML reports obtained from
            'MethylAnalysis.classify'.
        itle (str): Title of the report page. Defaults to 'None'.

    Returns:
        str: A formatted HTML string.
    """
    title_str = f"<h1>{title}</h1><hr>" if title else ""
    body_string = "\n<hr>\n".join(reports)
    return f"""<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Classifier Report</title>
    <style>
        body {{
            font-family: Arial, sans-serif;
            background: #f4f7f9;
            color: #333;
            margin: 0 auto;
            max-width: 1000px;
            line-height: 1.4;
        }}
        h1 {{
            text-align: center;
            margin-top: 10px;
            font-size: 1.5em;
        }}
        h2 {{
            font-size: 1.2em;
            margin-bottom: 5px;
        }}
        table {{
            width: 100%;
            max-width: 500px;
            margin: 10px auto;
            border-collapse: collapse;
        }}
        hr {{
            margin: 20px 0;
            border: none;
            height: 2px;
            background-color: #ccc;
        }}
        th, td {{
            padding: 3px 8px;
            text-align: left;
            border-bottom: 1px solid #ddd;
            font-size: 0.9em;
        }}
        td.clf-label, td.metrics-label, td.prob-label {{
            font-weight: bold;
        }}
        td.clf-value, td.metrics-value, td.prob-value {{
            text-align: right;
        }}
        .classification-result td {{
            background: #ffebe6;
        }}
        .classification-result-k {{
            background: #ffe4b5;
            font-weight: bold;
            color: #d85c5c;
        }}
        @media (max-width: 768px) {{
            body {{
                max-width: 95%;
            }}
            table {{
                max-width: 100%;
            }}
        }}
    </style>
</head>
<body>
{title_str}
{body_string}
</body>
</html>"""


def _is_ovr_or_ovo(clf: Any) -> str:
    """Check if classifier is using One-vs-Rest (OvR) or One-vs-One (OvO)."""
    if isinstance(clf, OneVsRestClassifier):
        return "ovr"
    if isinstance(clf, OneVsOneClassifier):
        return "ovo"
    if hasattr(clf, "decision_function_shape"):
        if clf.decision_function_shape == "ovr":
            return "ovr"
        if clf.decision_function_shape == "ovo":
            return "ovo"
    # Default for most classifiers
    return "ovr"


def cross_val_metrics(
    clf: Any,
    X: pd.DataFrame | np.ndarray,
    y: pd.Series | np.ndarray | Sequence,
    probabilities_cv: np.ndarray,
    cv: Any,
) -> dict[str, Any]:
    """Calculates cross-validation statistics for a classifier."""
    y_pred_cv = np.array(
        [clf.classes_[i] for i in np.argmax(probabilities_cv, axis=1)]
    )
    accuracy_scores = []
    auc_scores = []
    f1_scores = []
    multi_class_strategy = _is_ovr_or_ovo(clf)
    precision_scores = []
    recall_scores = []

    for _, test_idx in cv.split(X, y):
        y_true = np.array(y)[test_idx]
        probabilities_cv_test = probabilities_cv[test_idx, :]
        y_pred = y_pred_cv[test_idx]

        accuracy_scores.append(accuracy_score(y_true, y_pred))
        f1_scores.append(f1_score(y_true, y_pred, average="weighted"))
        precision_scores.append(
            precision_score(
                y_true, y_pred, average="weighted", zero_division=0
            )
        )
        recall_scores.append(
            recall_score(y_true, y_pred, average="weighted", zero_division=0)
        )

        # Calculate ROC AUC score
        if len(np.unique(y_true)) == 2:
            # Binary classification: use probabilities for the positive class
            roc_auc = roc_auc_score(y_true, probabilities_cv_test[:, 1])
        else:
            # Multiclass classification: use the full probabilities matrix
            roc_auc = roc_auc_score(
                y_true,
                probabilities_cv_test,
                multi_class=multi_class_strategy,
                average="weighted",
            )
        auc_scores.append(roc_auc)

    return {
        "cv": cv,
        "n_samples": X.shape[0],
        "n_features": X.shape[1],
        "accuracy_scores": accuracy_scores,
        "auc_scores": auc_scores,
        "f1_scores": f1_scores,
        "precision_scores": precision_scores,
        "recall_scores": recall_scores,
    }


def is_trained(clf: Any) -> bool:
    """Checks if the given classifier has already been trained."""
    clf_ = clf
    if hasattr(clf, "steps"):
        clf_ = clf.steps[-1][1]
    trained_attributes = ["classes_"]
    return all(hasattr(clf_, attr) for attr in trained_attributes)


def train_clf(
    clf: Any,
    X: pd.DataFrame,
    y: ArrayLike,
    cv: int | Any,
    n_jobs: int = 1,
) -> TrainedSklearnClassifier | TrainedSklearnCVClassifier:
    """Trains a classifier and stores the trained model to disk.

    If the classifier has already been trained (and saved), it loads the
    trained model. If not, it trains the classifier using the provided data and
    saves it to disk.

    Args:
        clf (classifier): The classifier to train or load.
        X (array-like): The feature matrix.
        y (array-like): The target labels.
        cv (int or cross-validation generator): Determines the cross-validation
            splitting strategy.
        n_jobs (int): Number of parallel processes to run.

    Returns:
        TrainedSklearnClassifier or TrainedSklearnCVClassifier: The trained
            classifier object.
    """
    n_splits = cv if isinstance(cv, int) else cv.n_splits

    if is_trained(clf):
        return TrainedSklearnClassifier(clf, X=X)

    logger.info("Start training...")

    counts = np.unique(y, return_counts=True)
    unique_classes, counts_per_class = counts

    if min(counts_per_class) < n_splits:
        # Find the class with the minimum count
        insufficient_idx = np.where(counts_per_class < n_splits)[0]
        insufficient_list = [
            f"  - {unique_classes[i]} (n={counts_per_class[i]})"
            for i in insufficient_idx
        ]
        insufficient_str = "\n".join(insufficient_list)

        raise ValueError(
            f"The following classes have fewer samples than cross-validation "
            f"splits (n={n_splits}):\n{insufficient_str}\n\n Please check "
            "both your annotation file and that the corresponding samples "
            "exist on disk.\n You can reduce 'cv' (e.g., set cv=1 to train "
            "without cross-validation), or provide more samples for these "
            "classes."
        )

    clf.fit(X, y)

    logger.info("Start cross-validation...")
    probabilities_cv = cross_val_predict(
        clf, X, y, cv=cv, method="predict_proba", n_jobs=n_jobs
    )
    metrics = cross_val_metrics(clf, X, y, probabilities_cv, cv)
    trained_clf = TrainedSklearnCVClassifier(
        clf=clf, probabilities_cv=probabilities_cv, X=X, metrics=metrics
    )

    return trained_clf


@dataclass
class ClassifierResult:
    """Data container for evaluation of classifier."""

    prediction: Any
    model: Any
    metrics: dict
    reports: dict


[docs] def fit_and_evaluate_clf( X: pd.DataFrame | None, y: ArrayLike | None, X_test: pd.DataFrame | np.ndarray, id_test: Sequence[str] | None, save_path: Path, clf: Any, cv: int | Any, n_jobs: int = 1, ) -> ClassifierResult: """Predicts the methylation class by supervised learning classifier. Uses supervised machine learning classifiers (Random Forest, K-Nearest Neighbors, Neural Networks, SVM, ...) to predict the methylation class of the sample. Output will be written to disk. Args: X (pd.DataFrame): Feature matrix (rows as samples, columns as features). y (array-like): Class labels. X_test (array-like): Value of the sample to be evaluated. id_test (str): Unique identifiers for the test samples to be evaluated. save_path (str or Path): Path where the classifiers and results will be saved/cached. clf (list): Classifier to use. Can be: - A scikit-learn classifier object or pipeline (trained or untrained). - A string in the format "scaler-selector-classifier". Possible values are: - A pipeline string composed of arbitrary components joined by dashes ("-"). Each component can be specified using either an abbreviation or the full class name (e.g., "std" or "StandardScaler").: scaler: - "std": Standard scaling (StandardScaler). - "minmax": Min-max scaling (MinMaxScaler). - "robust": Robust scaling (RobustScaler). - "power": Power transformation (PowerTransformer). - "quantile": Quantile transformation (QuantileTransformer). selector: - "kbest": Select the best features (SelectKBest). - "top": To varying features (TopVarianceSelector). - "pca": Principal component analysis (PCA). - "pca_auto": Principal component analysis (PCA). Number of components is determined automatically. - "lda": Linear Discriminant Analysis (LDA). clf: - "rf": RandomForestClassifier. - "lr": LogisticRegression. - "et": ExtraTreesClassifier. - "knn": KNeighborsClassifier. - "mlp": MLPClassifier. - "svc": Support Vector Classifier (SVC). - "ada": AdaBoostClassifier. - "bag": BaggingClassifier. - "dt": DecisionTreeClassifier. - "gp": GaussianProcessClassifier. - "hgb": HistGradientBoostingClassifier. - "nb": GaussianNB. - "perceptron": Perceptron. - "qda": Quadratic Discriminant Analysis (QDA). - "ridge": RidgeClassifier. - "sgd": SGDClassifier. Example: Using a feature selector and a classifier (SelectKBest selection and Logistic Regression): - "kbest-lr" - A custom class, that inherits from `TrainedClassifier`. cv (int or cross-validation generator): Determines the cross-validation splitting strategy. n_jobs (int): Number of parallel processes to run. Returns: ClassifierResult: - prediction (DataFrame): DataFrame containing the predicted probabilities for each class. - model (object): The trained classifier object. - metrics (dict): Dict containing classifier metrics. - reports (dict): Dict of evaluation report (both 'txt' and 'html') for each sample. """ if save_path.exists(): logger.info("Reading trained and cached classifier.") with save_path.open("rb") as file: trained_clf = pickle.load(file) elif isinstance(clf, Pipeline | ClassifierMixin): if X is None or y is None: raise ValueError("Both X and y must be different from 'None'.") trained_clf = train_clf(clf=clf, X=X, y=y, cv=cv, n_jobs=n_jobs) elif isinstance(clf, str): if X is None or y is None: raise ValueError("Both X and y must be different from 'None'.") pipeline = make_clf_pipeline(clf, X.shape, cv) trained_clf = train_clf(clf=pipeline, X=X, y=y, cv=cv, n_jobs=n_jobs) else: trained_clf = clf classes = trained_clf.classes() probabilities = trained_clf.predict_proba(X_test, id_test) info_txt = trained_clf.info(output_format="txt") info_html = trained_clf.info(output_format="html") prediction = pd.DataFrame(probabilities, index=id_test, columns=classes) reports_txt = make_reports(prediction, info_txt, output_format="txt") reports_html = make_reports(prediction, info_html, output_format="html") result = ClassifierResult( prediction, trained_clf.model(), trained_clf.metrics(), {"txt": reports_txt, "html": reports_html}, ) if not save_path.exists(): with save_path.open("wb") as file: pickle.dump(trained_clf, file) return result