Estimators API

This page documents the high-level scikit-learn-style estimators in TorchKM.

The classification estimators provide a familiar interface:

  • fit(X, y, *, low_rank=None, num_landmarks=None, nys_k=None)
  • predict(X)
  • decision_function(X)
  • predict_proba(X) when probability=True
  • platt_plot(X, y) when probability calibration is enabled

TorchKMKQR provides fit(X, y, *, low_rank=None, num_landmarks=None, nys_k=None) and predict(X) for continuous targets. Low-rank options are normally configured in the estimator constructor, but fit also accepts keyword-only convenience arguments low_rank, num_landmarks, and nys_k.
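
To make num_landmarks and nys_k concrete, here is a minimal sketch of a textbook Nyström feature map in NumPy. It is not confirmed to match TorchKM's low-rank backend: the uniform random landmark sampling, the kernel parameterisation exp(-sigma * ||x - z||^2), and the helper names rbf_kernel and nystrom_features are all assumptions made for illustration.

```python
import numpy as np


def rbf_kernel(A, B, sigma):
    """K[i, j] = exp(-sigma * ||A[i] - B[j]||^2) (assumed parameterisation)."""
    sq = (A ** 2).sum(axis=1)[:, None] + (B ** 2).sum(axis=1)[None, :] - 2.0 * A @ B.T
    return np.exp(-sigma * np.maximum(sq, 0.0))


def nystrom_features(X, num_landmarks, nys_k, sigma, seed=0):
    """Return an (n_samples, nys_k) feature map Z with Z @ Z.T approximating K."""
    rng = np.random.default_rng(seed)
    # Assumption: landmarks are sampled uniformly without replacement.
    idx = rng.choice(X.shape[0], size=num_landmarks, replace=False)
    landmarks = X[idx]
    K_mm = rbf_kernel(landmarks, landmarks, sigma)   # landmark-by-landmark
    K_nm = rbf_kernel(X, landmarks, sigma)           # sample-by-landmark
    evals, evecs = np.linalg.eigh(K_mm)              # eigenvalues in ascending order
    top = np.argsort(evals)[::-1][:nys_k]            # keep the nys_k largest
    evals, evecs = np.maximum(evals[top], 1e-12), evecs[:, top]
    Z = K_nm @ evecs / np.sqrt(evals)                # scale each column by 1/sqrt(eigenvalue)
    return Z, idx


rng = np.random.default_rng(1)
X = rng.normal(size=(200, 4))
Z, idx = nystrom_features(X, num_landmarks=40, nys_k=10, sigma=0.5)
K_approx = Z @ Z.T   # rank-10 approximation of the 200 x 200 kernel matrix
```

In this construction nys_k cannot exceed num_landmarks, and a linear model fitted on Z stands in for the kernel model fitted on the full kernel matrix.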

The classification wrappers accept NumPy arrays and torch tensors, map arbitrary binary labels to the low-level {-1, +1} convention internally, choose best_C_ through cross-validation, and return predictions in the original label space.
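
The label handling described above can be sketched with NumPy alone. This illustrates the idea rather than TorchKM's internal code: the helper names encode_binary_labels and decode_predictions are hypothetical, and taking the sorted order of the labels as "negative then positive" is an assumption.

```python
import numpy as np


def encode_binary_labels(y):
    """Map two arbitrary labels to -1/+1 (hypothetical helper)."""
    y = np.asarray(y)
    classes = np.unique(y)  # sorted: classes[0] -> -1, classes[1] -> +1
    if classes.size != 2:
        raise ValueError("expected exactly two distinct labels")
    y_pm = np.where(y == classes[1], 1.0, -1.0)
    return classes, y_pm


def decode_predictions(scores, classes):
    """Map signed decision scores back to the original label space."""
    return np.where(np.asarray(scores) > 0, classes[1], classes[0])


classes, y_pm = encode_binary_labels(np.array(["ham", "spam", "ham", "spam"]))
labels = decode_predictions(np.array([-0.7, 1.3, 0.2, -2.0]), classes)
# labels -> ['ham', 'spam', 'spam', 'ham']
```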

TorchKMSVC

Bases: _TorchKMBaseBinaryClassifier

Kernel support vector classifier with integrated model selection.

TorchKMSVC is the scikit-learn-style wrapper around torchkm.cvksvm.cvksvm. It builds a kernel matrix from feature input, fits a path of candidate regularization values, selects best_C_ by cross-validation, and exposes familiar prediction methods.
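
As a rough picture of the kernel matrices involved, the sketch below builds an RBF kernel by hand and shows the two shapes that matter for kernel="precomputed": a square train-by-train matrix for fitting and a test-by-train matrix for prediction. The parameterisation exp(-sigma * ||x - z||^2) and the helper name rbf_kernel are assumptions; TorchKM's own kernel construction may differ.

```python
import numpy as np


def rbf_kernel(A, B, sigma):
    """K[i, j] = exp(-sigma * ||A[i] - B[j]||^2) (assumed parameterisation)."""
    sq = (A ** 2).sum(axis=1)[:, None] + (B ** 2).sum(axis=1)[None, :] - 2.0 * A @ B.T
    return np.exp(-sigma * np.maximum(sq, 0.0))  # clip tiny negative round-off


rng = np.random.default_rng(0)
X_train = rng.normal(size=(6, 3))
X_test = rng.normal(size=(2, 3))

K_train = rbf_kernel(X_train, X_train, sigma=0.5)  # square: (n_train, n_train)
K_test = rbf_kernel(X_test, X_train, sigma=0.5)    # test-by-train: (n_test, n_train)
```

With kernel="precomputed", K_train would be the input to fit and K_test the input to decision_function or predict, as the kernel parameter description states.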

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| kernel | {"rbf", "linear", "poly", "precomputed"} | Kernel used by the estimator. "precomputed" expects a square training kernel matrix in fit and a test-by-train kernel matrix in decision_function or predict. | "rbf" |
| nC | int | Number of candidate C values when Cs is not provided. | 50 |
| Cs | array-like | Candidate regularization values under the scikit-learn/LIBSVM C convention. Internally these are converted to solver regularization values. | None |
| C_max | float | Upper endpoint of the log-spaced C grid used when Cs is omitted. | 1e3 |
| C_min | float | Lower endpoint of the log-spaced C grid used when Cs is omitted. | 1e-3 |
| cv | int | Number of cross-validation folds used to choose best_C_. | 5 |
| foldid | array-like | Optional fold assignment of length n_samples. Fold labels follow the low-level solver convention and are typically in 1, ..., cv. | None |
| tol | float | Solver convergence tolerance. | 1e-5 |
| max_iter | int | Maximum number of iterations used by the low-level solver. | 1000 |
| solver_gamma | float | Small numerical regularizer passed to the solver. | 1e-8 |
| is_exact | int | Solver option used by the exact SVM backend. | 0 |
| device | {"cpu", "cuda"} or torch.device | Device used for computation. If None, CUDA is used when available; otherwise CPU is used. Requests for CUDA fall back to CPU when CUDA is unavailable. | None |
| rbf_sigma | float | RBF kernel scale. If omitted, sigest estimates a scale from the training data. | None |
| sigest_frac | float | Fraction passed to sigest when estimating the RBF scale. | 0.5 |
| poly_degree, poly_coef0, poly_gamma | int or float | Polynomial-kernel parameters. | 3 |
| probability | bool | If True, fit a Platt scaler on the selected out-of-fold scores and enable predict_proba and platt_plot. | False |
| platt_device | {"cpu", "cuda"} or torch.device | Device used for Platt calibration. Defaults to the estimator device. | None |
| random_state | int | Seed used for deterministic fold construction. | None |
| store_path | bool | If True, keep the full coefficient and out-of-fold prediction path. | False |
| low_rank | bool | If True, use the Nyström SVM backend. The low-rank path currently supports raw-feature RBF-kernel workflows, not kernel="precomputed". | False |
| num_landmarks | int | Number of Nyström landmarks when low_rank=True. | 2000 |
| nys_k | int | Rank used by the Nyström feature map when low_rank=True. | 1000 |
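
The foldid and random_state options can be pictured with a small fold-assignment sketch. It only illustrates the 1, ..., cv labelling convention mentioned above: the helper make_foldid is hypothetical, and TorchKM's own fold construction (for example, whether it stratifies by class) is not shown on this page.

```python
import numpy as np


def make_foldid(n_samples, cv, random_state=None):
    """Return fold labels in 1..cv with near-equal fold sizes (hypothetical helper)."""
    rng = np.random.default_rng(random_state)
    foldid = np.arange(n_samples) % cv + 1   # balanced labels 1, ..., cv
    rng.shuffle(foldid)                      # seeded shuffle gives reproducible folds
    return foldid


foldid = make_foldid(n_samples=23, cv=5, random_state=0)
fold_sizes = np.bincount(foldid)[1:]         # number of samples in folds 1..cv
```

An array built this way has length n_samples and could be passed as foldid so that repeated fits compare candidate C values on identical splits.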

Attributes:

| Name | Type | Description |
|---|---|---|
| classes_ | ndarray of shape (2,) | Original binary class labels, ordered as negative then positive. |
| best_C_ | float | Regularization value selected by cross-validation. |
| best_ind_ | int | Index of the selected value in the candidate path. |
| cv_mis_ | ndarray of shape (nC,) | Cross-validation misclassification scores for the candidate path. |
| alpha_ | ndarray | Coefficients for the selected model. |
| intercept_ | float | Intercept for the selected model. |
| foldid_ | ndarray | Fold assignment used during fitting. |
| n_features_in_ | int | Number of input features seen during fitting. |
| n_samples_fit_ | int | Number of training samples. |
| kernel_state_ | dict | Kernel parameters needed for prediction, such as the fitted RBF scale. |
| low_rank_basis_dim_ | int | Effective low-rank feature dimension when low_rank=True. |
| low_rank_landmark_indices_ | ndarray | Landmark indices when exposed by the Nyström backend. |
| num_landmarks_ | int | Number of landmarks used by the fitted Nyström backend, when available. |
| nys_k_ | int | Effective Nyström rank, when available. |
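
The relationship between the candidate path and best_ind_, best_C_, and cv_mis_ can be sketched as follows. The grid runs from C_max down to C_min, mirroring the descending np.logspace call in the example on this page; that ordering, the first-minimum tie-breaking rule, and the made-up error rates are assumptions for illustration only.

```python
import numpy as np

# Log-spaced candidate grid between the two endpoints (assumed descending order).
C_max, C_min, nC = 1e3, 1e-3, 7
Cs = np.logspace(np.log10(C_max), np.log10(C_min), num=nC)

# Made-up cross-validation misclassification rates, one per candidate C.
cv_mis = np.array([0.12, 0.10, 0.08, 0.08, 0.11, 0.20, 0.35])

best_ind = int(np.argmin(cv_mis))   # first index attaining the minimum -> 2
best_C = float(Cs[best_ind])        # candidate at that index -> 10.0
```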

Notes

The high-level wrapper accepts any two distinct class labels and maps them internally to the {-1, +1} convention used by the low-level solvers. Predictions are mapped back to the original labels.

The methods decision_function and predict are available after fitting. predict_proba and platt_plot require probability=True at construction time.
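
For readers unfamiliar with Platt calibration, the sketch below fits a one-dimensional sigmoid that maps decision scores to probabilities. It is a simplified stand-in, not TorchKM's scaler: the helper fit_platt is hypothetical, it uses plain gradient descent, it omits the target smoothing of Platt's original method, and the scores are synthetic.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def fit_platt(scores, y01, n_iter=2000, lr=0.1):
    """Fit p = sigmoid(a * score + b) by gradient descent on the mean log loss."""
    a, b = 1.0, 0.0
    for _ in range(n_iter):
        resid = sigmoid(a * scores + b) - y01   # gradient of the log loss w.r.t. the logit
        a -= lr * np.mean(resid * scores)
        b -= lr * np.mean(resid)
    return a, b


def log_loss(scores, y01, a, b):
    p = np.clip(sigmoid(a * scores + b), 1e-12, 1 - 1e-12)
    return float(-np.mean(y01 * np.log(p) + (1 - y01) * np.log(1 - p)))


# Synthetic out-of-fold scores whose true calibration map is sigmoid(2 * s - 0.5).
rng = np.random.default_rng(0)
scores = rng.normal(size=400)
y01 = (rng.random(400) < sigmoid(2.0 * scores - 0.5)).astype(float)

a, b = fit_platt(scores, y01)
proba_pos = sigmoid(a * scores + b)   # calibrated probability of the positive class
```

Fitting on out-of-fold scores, as the probability parameter describes, keeps the calibration map from being tuned on scores the model produced for its own training points.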

Examples:

>>> import numpy as np
>>> import torch
>>> from sklearn.datasets import make_circles
>>> from sklearn.model_selection import train_test_split
>>> from sklearn.preprocessing import StandardScaler
>>> from torchkm.estimators import TorchKMSVC
>>> X, y = make_circles(n_samples=120, factor=0.4, noise=0.08,
...                     random_state=0)
>>> X = StandardScaler().fit_transform(X)
>>> Xtr, Xte, ytr, yte = train_test_split(X, y, test_size=0.25,
...                                       random_state=0)
>>> Cs = np.logspace(2, -2, num=4)
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> clf = TorchKMSVC(kernel="rbf", Cs=Cs, nC=len(Cs), cv=5,
...                  device=device, max_iter=40)
>>> clf.fit(Xtr, ytr)
TorchKMSVC(...)
>>> clf.best_C_ > 0
True
>>> clf.predict(Xte[:3]).shape
(3,)
Source code in torchkm/estimators.py
class TorchKMSVC(_TorchKMBaseBinaryClassifier):
    """Kernel support vector classifier with integrated model selection.

    ``TorchKMSVC`` is the scikit-learn-style wrapper around
    :class:`torchkm.cvksvm.cvksvm`. It builds a kernel matrix from feature
    input, fits a path of candidate regularization values, selects ``best_C_``
    by cross-validation, and exposes familiar prediction methods.

    Parameters
    ----------
    kernel : {"rbf", "linear", "poly", "precomputed"}, default="rbf"
        Kernel used by the estimator. ``"precomputed"`` expects a square
        training kernel matrix in ``fit`` and a test-by-train kernel matrix in
        ``decision_function`` or ``predict``.
    nC : int, default=50
        Number of candidate ``C`` values when ``Cs`` is not provided.
    Cs : array-like, optional
        Candidate regularization values under the scikit-learn/LIBSVM
        ``C`` convention. Internally these are converted to solver
        regularization values.
    C_max, C_min : float, default=1e3, 1e-3
        Endpoints for the log-spaced ``C`` grid used when ``Cs`` is omitted.
    cv : int, default=5
        Number of cross-validation folds used to choose ``best_C_``.
    foldid : array-like, optional
        Optional fold assignment of length ``n_samples``. Fold labels follow
        the low-level solver convention and are typically in ``1, ..., cv``.
    tol : float, default=1e-5
        Solver convergence tolerance.
    max_iter : int, default=1000
        Maximum number of iterations used by the low-level solver.
    solver_gamma : float, default=1e-8
        Small numerical regularizer passed to the solver.
    is_exact : int, default=0
        Solver option used by the exact SVM backend.
    device : {"cpu", "cuda"} or torch.device, optional
        Device used for computation. If ``None``, CUDA is used when available;
        otherwise CPU is used. Requests for CUDA fall back to CPU when CUDA is
        unavailable.
    rbf_sigma : float, optional
        RBF kernel scale. If omitted, ``sigest`` estimates a scale from the
        training data.
    sigest_frac : float, default=0.5
        Fraction passed to ``sigest`` when estimating the RBF scale.
    poly_degree, poly_coef0, poly_gamma : int or float
        Polynomial-kernel parameters.
    probability : bool, default=False
        If ``True``, fit a Platt scaler on the selected out-of-fold scores and
        enable ``predict_proba`` and ``platt_plot``.
    platt_device : {"cpu", "cuda"} or torch.device, optional
        Device used for Platt calibration. Defaults to the estimator device.
    random_state : int, optional
        Seed used for deterministic fold construction.
    store_path : bool, default=False
        If ``True``, keep the full coefficient and out-of-fold prediction path.
    low_rank : bool, default=False
        If ``True``, use the Nyström SVM backend. The low-rank path currently
        supports raw-feature RBF-kernel workflows, not ``kernel="precomputed"``.
    num_landmarks : int, default=2000
        Number of Nyström landmarks when ``low_rank=True``.
    nys_k : int, default=1000
        Rank used by the Nyström feature map when ``low_rank=True``.

    Attributes
    ----------
    classes_ : ndarray of shape (2,)
        Original binary class labels, ordered as negative then positive.
    best_C_ : float
        Regularization value selected by cross-validation.
    best_ind_ : int
        Index of the selected value in the candidate path.
    cv_mis_ : ndarray of shape (nC,)
        Cross-validation misclassification scores for the candidate path.
    alpha_ : ndarray
        Coefficients for the selected model.
    intercept_ : float
        Intercept for the selected model.
    foldid_ : ndarray
        Fold assignment used during fitting.
    n_features_in_ : int
        Number of input features seen during fitting.
    n_samples_fit_ : int
        Number of training samples.
    kernel_state_ : dict
        Kernel parameters needed for prediction, such as the fitted RBF scale.
    low_rank_basis_dim_ : int
        Effective low-rank feature dimension when ``low_rank=True``.
    low_rank_landmark_indices_ : ndarray
        Landmark indices when exposed by the Nyström backend.
    num_landmarks_ : int
        Number of landmarks used by the fitted Nyström backend, when available.
    nys_k_ : int
        Effective Nyström rank, when available.

    Notes
    -----
    The high-level wrapper accepts any two distinct class labels and maps them
    internally to the ``{-1, +1}`` convention used by the low-level solvers.
    Predictions are mapped back to the original labels.

    The methods ``decision_function`` and ``predict`` are available after
    fitting. ``predict_proba`` and ``platt_plot`` require
    ``probability=True`` at construction time.

    Examples
    --------
    >>> import numpy as np
    >>> import torch
    >>> from sklearn.datasets import make_circles
    >>> from sklearn.model_selection import train_test_split
    >>> from sklearn.preprocessing import StandardScaler
    >>> from torchkm.estimators import TorchKMSVC
    >>> X, y = make_circles(n_samples=120, factor=0.4, noise=0.08,
    ...                     random_state=0)
    >>> X = StandardScaler().fit_transform(X)
    >>> Xtr, Xte, ytr, yte = train_test_split(X, y, test_size=0.25,
    ...                                       random_state=0)
    >>> Cs = np.logspace(2, -2, num=4)
    >>> device = "cuda" if torch.cuda.is_available() else "cpu"
    >>> clf = TorchKMSVC(kernel="rbf", Cs=Cs, nC=len(Cs), cv=5,
    ...                  device=device, max_iter=40)
    >>> clf.fit(Xtr, ytr)
    TorchKMSVC(...)
    >>> clf.best_C_ > 0
    True
    >>> clf.predict(Xte[:3]).shape
    (3,)
    """

    _BACKEND: BackendName = "svm"

platt_plot(X=None, y=None, *, n_bins=15, strategy='uniform', annotate_counts=True, figsize=(5.2, 5.2), title='Calibration (Reliability) Curve', savepath=None, dpi=150, ax=None)

Plot a calibration / reliability curve for the fitted Platt scaler.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| X | array-like or None | If provided, compute predict_proba(X) and plot reliability against y. If omitted, use the stored training calibration scores from fit(). | None |
| y | array-like or None | True labels corresponding to X; required when X is provided. If X is None, the stored training labels from fit() are used. | None |
| n_bins | int | Number of bins used in the reliability curve. | 15 |
| strategy | {"uniform", "quantile"} | How to bin probabilities: equal-width bins on [0, 1], or bins at quantiles of the predicted probabilities. | "uniform" |
| annotate_counts | bool | If True, annotate each point with the number of samples in that bin. | True |
| figsize | tuple | Figure size when ax is None. | (5.2, 5.2) |
| title | str | Plot title. | 'Calibration (Reliability) Curve' |
| savepath | str or None | If provided, save the figure to this path. | None |
| dpi | int | Resolution (DPI) used when saving. | 150 |
| ax | matplotlib axis or None | Existing axis to draw on. | None |

Returns:

| Name | Type | Description |
|---|---|---|
| ax | matplotlib axis | Axis containing the reliability plot. |
| stats | dict | Contains ECE, Brier score, bin counts, and plotted points (keys: ece, brier, bin_avg_proba, bin_empirical_freq, bin_count). |
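
The difference between the two strategy values is easiest to see on a small example. The sketch below reproduces the edge construction used in the source listing for this method and counts samples per bin with np.histogram, whose bins are half-open except for the closed last bin, the same rule as the listing's masks. The probabilities are made up.

```python
import numpy as np

p_pos = np.array([0.02, 0.05, 0.07, 0.10, 0.15, 0.40, 0.80, 0.95])
n_bins = 4

# "uniform": equal-width bins on [0, 1].
uniform_edges = np.linspace(0.0, 1.0, n_bins + 1)

# "quantile": edges at quantiles of the predictions, de-duplicated.
quantile_edges = np.unique(np.quantile(p_pos, np.linspace(0.0, 1.0, n_bins + 1)))

uniform_counts, _ = np.histogram(p_pos, bins=uniform_edges)     # -> [5, 1, 0, 2]
quantile_counts, _ = np.histogram(p_pos, bins=quantile_edges)   # -> [2, 2, 2, 2]
```

Uniform bins can be empty or very uneven when predictions cluster near 0 or 1, while quantile bins hold roughly equal counts at the cost of unequal widths.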

Source code in torchkm/estimators.py
def platt_plot(
    self,
    X: Optional[Any] = None,
    y: Optional[Any] = None,
    *,
    n_bins: int = 15,
    strategy: str = "uniform",
    annotate_counts: bool = True,
    figsize: Tuple[float, float] = (5.2, 5.2),
    title: str = "Calibration (Reliability) Curve",
    savepath: Optional[str] = None,
    dpi: int = 150,
    ax=None,
):
    """
    Plot a calibration / reliability curve for the fitted Platt scaler.

    Parameters
    ----------
    X : array-like or None
        If provided, compute predict_proba(X) and plot reliability against y.
        If omitted, use the stored training calibration scores from fit().

    y : array-like or None
        True labels corresponding to X; required when X is provided.
        If X is None, the stored training labels from fit() are used and y is ignored.

    n_bins : int
        Number of bins used in the reliability curve.

    strategy : {"uniform", "quantile"}
        How to bin probabilities.

    annotate_counts : bool
        If True, annotate each point with the number of samples in that bin.

    figsize : tuple
        Figure size when ax is None.

    title : str
        Plot title.

    savepath : str or None
        If provided, save the plot.

    dpi : int
        Save DPI.

    ax : matplotlib axis or None
        Existing axis to draw on.

    Returns
    -------
    ax : matplotlib axis
    stats : dict
        Contains ECE, Brier score, bin counts, and plotted points.
    """
    check_is_fitted(self, ["classes_"])

    if self.platt_ is None:
        raise AttributeError(
            "Platt scaler is not fitted. Fit with probability=True before calling platt_plot()."
        )

    try:
        import matplotlib.pyplot as plt
    except Exception as e:
        raise ImportError(
            "platt_plot requires matplotlib. Install it with `pip install matplotlib` "
            "or add it to a visualization extra such as `torchkm[viz]`."
        ) from e

    # ------------------------------------------------------------
    # Get probabilities + labels
    # ------------------------------------------------------------
    if X is None:
        if self.platt_scores_ is None or self.platt_y_ is None:
            raise AttributeError(
                "Stored calibration data not found. Fit with probability=True first, "
                "or call platt_plot(X=..., y=...)."
            )

        scores_t = torch.as_tensor(
            self.platt_scores_,
            dtype=torch.double,
            device=getattr(self, "_platt_device_", "cpu"),
        )

        with torch.no_grad():
            proba_t = self.platt_.predict_proba(scores_t)

        proba = proba_t.detach().cpu().numpy()
        y_raw = np.asarray(self.platt_y_).reshape(-1)

    else:
        if y is None:
            raise ValueError("When X is provided, y must also be provided.")

        proba = self.predict_proba(X)
        y_raw = np.asarray(_as_numpy(y)).reshape(-1)

    if proba.ndim == 2:
        p_pos = proba[:, -1].astype(np.float64)
    else:
        p_pos = proba.reshape(-1).astype(np.float64)

    pos_label = self.classes_[1]
    y01 = (y_raw == pos_label).astype(np.float64)

    if p_pos.shape[0] != y01.shape[0]:
        raise ValueError(
            "Predicted probabilities and labels must have the same length."
        )

    # ------------------------------------------------------------
    # Metrics: Brier score (ECE is accumulated in the binning loop below)
    # ------------------------------------------------------------
    brier = float(np.mean((p_pos - y01) ** 2))

    # ------------------------------------------------------------
    # Binning
    # ------------------------------------------------------------
    if strategy not in {"uniform", "quantile"}:
        raise ValueError("strategy must be 'uniform' or 'quantile'.")

    if strategy == "uniform":
        edges = np.linspace(0.0, 1.0, int(n_bins) + 1)
    else:
        edges = np.quantile(p_pos, np.linspace(0.0, 1.0, int(n_bins) + 1))
        edges = np.unique(edges)
        if edges.size < 2:
            edges = np.array([0.0, 1.0], dtype=np.float64)

    bin_x = []
    bin_y = []
    bin_n = []

    n = p_pos.shape[0]
    ece = 0.0

    for i in range(len(edges) - 1):
        lo, hi = edges[i], edges[i + 1]

        if i == len(edges) - 2:
            mask = (p_pos >= lo) & (p_pos <= hi)
        else:
            mask = (p_pos >= lo) & (p_pos < hi)

        count = int(mask.sum())
        if count == 0:
            continue

        conf = float(p_pos[mask].mean())  # average predicted probability
        acc = float(y01[mask].mean())  # empirical positive frequency

        bin_x.append(conf)
        bin_y.append(acc)
        bin_n.append(count)

        ece += (count / n) * abs(acc - conf)

    bin_x = np.asarray(bin_x, dtype=np.float64)
    bin_y = np.asarray(bin_y, dtype=np.float64)
    bin_n = np.asarray(bin_n, dtype=np.int64)

    # ------------------------------------------------------------
    # Plot
    # ------------------------------------------------------------
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.figure

    # light grey figure and axes background
    fig.patch.set_facecolor("#EAEAF2")
    ax.set_facecolor("#EAEAF2")

    # perfect line
    ax.plot([0, 1], [0, 1], "--", linewidth=1.5, label="Perfect")

    # calibration curve
    label = f"Platt (ECE={ece:.3f}, Brier={brier:.3f})"
    ax.plot(bin_x, bin_y, marker="o", linewidth=1.8, label=label)

    # annotate counts
    if annotate_counts:
        for x_i, y_i, n_i in zip(bin_x, bin_y, bin_n):
            ax.text(
                x_i,
                y_i + 0.015,
                str(int(n_i)),
                ha="center",
                va="bottom",
                fontsize=9,
            )

    ax.set_xlim(-0.05, 1.05)
    ax.set_ylim(-0.05, 1.05)
    ax.set_xlabel("Predicted probability (bin average)")
    ax.set_ylabel("Observed frequency (empirical)")
    ax.set_title(title)
    ax.grid(True, alpha=0.35)
    ax.legend(loc="upper left")

    if savepath is not None:
        fig.savefig(savepath, dpi=dpi, bbox_inches="tight")

    stats = {
        "ece": float(ece),
        "brier": float(brier),
        "bin_avg_proba": bin_x,
        "bin_empirical_freq": bin_y,
        "bin_count": bin_n,
    }

    return ax, stats

TorchKMDWD

Bases: _TorchKMBaseBinaryClassifier

Kernel distance-weighted discrimination classifier.

TorchKMDWD uses the same scikit-learn-style interface and model selection machinery as TorchKMSVC, but delegates fitting to torchkm.cvkdwd.cvkdwd. It accepts binary labels, maps them to the solver's {-1, +1} convention internally, and returns predictions in the original label space.

Parameters are inherited from the shared binary-classifier wrapper. The most common options are kernel, Cs/nC, cv, device, probability, low_rank, num_landmarks, and nys_k.

Attributes include best_C_, cv_mis_, alpha_, intercept_, classes_, and foldid_ after fitting. predict_proba and platt_plot are available only when probability=True.

Source code in torchkm/estimators.py
class TorchKMDWD(_TorchKMBaseBinaryClassifier):
    """Kernel distance-weighted discrimination classifier.

    ``TorchKMDWD`` uses the same scikit-learn-style interface and model
    selection machinery as ``TorchKMSVC``, but delegates fitting to
    :class:`torchkm.cvkdwd.cvkdwd`. It accepts binary labels, maps them to the
    solver's ``{-1, +1}`` convention internally, and returns predictions in the
    original label space.

    Parameters are inherited from the shared binary-classifier wrapper. The
    most common options are ``kernel``, ``Cs``/``nC``, ``cv``, ``device``,
    ``probability``, ``low_rank``, ``num_landmarks``, and ``nys_k``.

    Attributes include ``best_C_``, ``cv_mis_``, ``alpha_``, ``intercept_``,
    ``classes_``, and ``foldid_`` after fitting. ``predict_proba`` and
    ``platt_plot`` are available only when ``probability=True``.
    """

    _BACKEND: BackendName = "dwd"


TorchKMLogit

Bases: _TorchKMBaseBinaryClassifier

Kernel logistic-regression classifier.

TorchKMLogit wraps torchkm.cvklogit.cvklogit with the same estimator interface used by the other TorchKM binary classifiers. It fits a path over candidate C values, chooses best_C_ by cross-validation, and supports CPU or CUDA execution through the device parameter.

The estimator accepts any two distinct class labels and maps them internally to the low-level solver convention. Use decision_function for fitted scores and predict for class labels. Set probability=True to fit Platt calibration and enable predict_proba.

Source code in torchkm/estimators.py
class TorchKMLogit(_TorchKMBaseBinaryClassifier):
    """Kernel logistic-regression classifier.

    ``TorchKMLogit`` wraps :class:`torchkm.cvklogit.cvklogit` with the same
    estimator interface used by the other TorchKM binary classifiers. It fits
    a path over candidate ``C`` values, chooses ``best_C_`` by cross-validation,
    and supports CPU or CUDA execution through the ``device`` parameter.

    The estimator accepts any two distinct class labels and maps them
    internally to the low-level solver convention. Use ``decision_function`` for
    fitted scores and ``predict`` for class labels. Set ``probability=True`` to
    fit Platt calibration and enable ``predict_proba``.
    """

    _BACKEND: BackendName = "logit"

platt_plot(X=None, y=None, *, n_bins=15, strategy='uniform', annotate_counts=True, figsize=(5.2, 5.2), title='Calibration (Reliability) Curve', savepath=None, dpi=150, ax=None)

Plot a calibration / reliability curve for the fitted Platt scaler.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| X | array-like or None | If provided, compute predict_proba(X) and plot reliability against y. If omitted, use the stored training calibration scores from fit(). | None |
| y | array-like or None | True labels corresponding to X. Required when X is provided. If X is None, the stored training labels from fit() are used. | None |
| n_bins | int | Number of bins used in the reliability curve. | 15 |
| strategy | {"uniform", "quantile"} | How to bin probabilities. | "uniform" |
| annotate_counts | bool | If True, annotate each point with the number of samples in that bin. | True |
| figsize | tuple | Figure size when ax is None. | (5.2, 5.2) |
| title | str | Plot title. | 'Calibration (Reliability) Curve' |
| savepath | str or None | If provided, save the plot to this path. | None |
| dpi | int | Resolution (DPI) used when saving. | 150 |
| ax | matplotlib axis or None | Existing axis to draw on. | None |

Returns:

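| Name | Type | Description |
| --- | --- | --- |
| ax | matplotlib axis | The axis the reliability curve was drawn on. |
| stats | dict | Contains the ECE ("ece"), the Brier score ("brier"), and the plotted points and bin counts ("bin_avg_proba", "bin_empirical_freq", "bin_count"). |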
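To make the two reported metrics concrete, here is a minimal pure-Python sketch of the Brier score and the expected calibration error (ECE) over uniform bins. It follows the same binning rule as the method (half-open bins, with the last bin closed on the right, and empty bins skipped), but `reliability_stats` is a hypothetical helper written for illustration, not part of TorchKM.

```python
def reliability_stats(p_pos, y01, n_bins=15):
    """Return (ece, brier, bins) for positive-class probabilities.

    p_pos : predicted probabilities of the positive class
    y01   : 0/1 indicators of the positive class
    bins  : list of (mean predicted probability, empirical frequency, count)
    """
    n = len(p_pos)
    # Brier score: mean squared error between probability and outcome.
    brier = sum((p - y) ** 2 for p, y in zip(p_pos, y01)) / n

    edges = [i / n_bins for i in range(n_bins + 1)]
    ece = 0.0
    bins = []
    for i in range(n_bins):
        lo, hi = edges[i], edges[i + 1]
        last = i == n_bins - 1
        # Half-open bins [lo, hi), except the last bin which includes hi.
        idx = [
            k for k, p in enumerate(p_pos)
            if lo <= p and (p < hi or (last and p <= hi))
        ]
        if not idx:
            continue  # empty bins contribute nothing and are not plotted
        conf = sum(p_pos[k] for k in idx) / len(idx)
        acc = sum(y01[k] for k in idx) / len(idx)
        # ECE: bin-size-weighted gap between confidence and frequency.
        ece += (len(idx) / n) * abs(acc - conf)
        bins.append((conf, acc, len(idx)))
    return ece, brier, bins


ece, brier, bins = reliability_stats([0.1, 0.3, 0.8, 0.9], [0, 0, 1, 1], n_bins=2)
```

A perfectly calibrated model would have every bin's empirical frequency equal to its mean predicted probability, giving an ECE of zero.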

Source code in torchkm/estimators.py
def platt_plot(
    self,
    X: Optional[Any] = None,
    y: Optional[Any] = None,
    *,
    n_bins: int = 15,
    strategy: str = "uniform",
    annotate_counts: bool = True,
    figsize: Tuple[float, float] = (5.2, 5.2),
    title: str = "Calibration (Reliability) Curve",
    savepath: Optional[str] = None,
    dpi: int = 150,
    ax=None,
):
    """
    Plot a calibration / reliability curve for the fitted Platt scaler.

    Parameters
    ----------
    X : array-like or None
        If provided, compute predict_proba(X) and plot reliability against y.
        If omitted, use the stored training calibration scores from fit().

    y : array-like or None
        True labels corresponding to X. Required when X is provided.
        If X is None, stored training labels from fit() are used.

    n_bins : int
        Number of bins used in the reliability curve.

    strategy : {"uniform", "quantile"}
        How to bin probabilities.

    annotate_counts : bool
        If True, annotate each point with the number of samples in that bin.

    figsize : tuple
        Figure size when ax is None.

    title : str
        Plot title.

    savepath : str or None
        If provided, save the plot.

    dpi : int
        Save DPI.

    ax : matplotlib axis or None
        Existing axis to draw on.

    Returns
    -------
    ax : matplotlib axis
    stats : dict
        Contains ECE, Brier score, bin counts, and plotted points.
    """
    check_is_fitted(self, ["classes_"])

    if self.platt_ is None:
        raise AttributeError(
            "Platt scaler is not fitted. Fit with probability=True before calling platt_plot()."
        )

    try:
        import matplotlib.pyplot as plt
    except Exception as e:
        raise ImportError(
            "platt_plot requires matplotlib. Install it with `pip install matplotlib` "
            "or add it to a visualization extra such as `torchkm[viz]`."
        ) from e

    # ------------------------------------------------------------
    # Get probabilities + labels
    # ------------------------------------------------------------
    if X is None:
        if self.platt_scores_ is None or self.platt_y_ is None:
            raise AttributeError(
                "Stored calibration data not found. Fit with probability=True first, "
                "or call platt_plot(X=..., y=...)."
            )

        scores_t = torch.as_tensor(
            self.platt_scores_,
            dtype=torch.double,
            device=getattr(self, "_platt_device_", "cpu"),
        )

        with torch.no_grad():
            proba_t = self.platt_.predict_proba(scores_t)

        proba = proba_t.detach().cpu().numpy()
        y_raw = np.asarray(self.platt_y_).reshape(-1)

    else:
        if y is None:
            raise ValueError("When X is provided, y must also be provided.")

        proba = self.predict_proba(X)
        y_raw = np.asarray(_as_numpy(y)).reshape(-1)

    if proba.ndim == 2:
        p_pos = proba[:, -1].astype(np.float64)
    else:
        p_pos = proba.reshape(-1).astype(np.float64)

    pos_label = self.classes_[1]
    y01 = (y_raw == pos_label).astype(np.float64)

    if p_pos.shape[0] != y01.shape[0]:
        raise ValueError(
            "Predicted probabilities and labels must have the same length."
        )

    # ------------------------------------------------------------
    # Metrics: Brier score (ECE is accumulated in the binning loop below)
    # ------------------------------------------------------------
    brier = float(np.mean((p_pos - y01) ** 2))

    # ------------------------------------------------------------
    # Binning
    # ------------------------------------------------------------
    if strategy not in {"uniform", "quantile"}:
        raise ValueError("strategy must be 'uniform' or 'quantile'.")

    if strategy == "uniform":
        edges = np.linspace(0.0, 1.0, int(n_bins) + 1)
    else:
        edges = np.quantile(p_pos, np.linspace(0.0, 1.0, int(n_bins) + 1))
        edges = np.unique(edges)
        if edges.size < 2:
            edges = np.array([0.0, 1.0], dtype=np.float64)

    bin_x = []
    bin_y = []
    bin_n = []

    n = p_pos.shape[0]
    ece = 0.0

    for i in range(len(edges) - 1):
        lo, hi = edges[i], edges[i + 1]

        if i == len(edges) - 2:
            mask = (p_pos >= lo) & (p_pos <= hi)
        else:
            mask = (p_pos >= lo) & (p_pos < hi)

        count = int(mask.sum())
        if count == 0:
            continue

        conf = float(p_pos[mask].mean())  # average predicted probability
        acc = float(y01[mask].mean())  # empirical positive frequency

        bin_x.append(conf)
        bin_y.append(acc)
        bin_n.append(count)

        ece += (count / n) * abs(acc - conf)

    bin_x = np.asarray(bin_x, dtype=np.float64)
    bin_y = np.asarray(bin_y, dtype=np.float64)
    bin_n = np.asarray(bin_n, dtype=np.int64)

    # ------------------------------------------------------------
    # Plot
    # ------------------------------------------------------------
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.figure

    # light grey background for the figure and axes
    fig.patch.set_facecolor("#EAEAF2")
    ax.set_facecolor("#EAEAF2")

    # perfect line
    ax.plot([0, 1], [0, 1], "--", linewidth=1.5, label="Perfect")

    # calibration curve
    label = f"Platt (ECE={ece:.3f}, Brier={brier:.3f})"
    ax.plot(bin_x, bin_y, marker="o", linewidth=1.8, label=label)

    # annotate counts
    if annotate_counts:
        for x_i, y_i, n_i in zip(bin_x, bin_y, bin_n):
            ax.text(
                x_i,
                y_i + 0.015,
                str(int(n_i)),
                ha="center",
                va="bottom",
                fontsize=9,
            )

    ax.set_xlim(-0.05, 1.05)
    ax.set_ylim(-0.05, 1.05)
    ax.set_xlabel("Predicted probability (bin average)")
    ax.set_ylabel("Observed frequency (empirical)")
    ax.set_title(title)
    ax.grid(True, alpha=0.35)
    ax.legend(loc="upper left")

    if savepath is not None:
        fig.savefig(savepath, dpi=dpi, bbox_inches="tight")

    stats = {
        "ece": float(ece),
        "brier": float(brier),
        "bin_avg_proba": bin_x,
        "bin_empirical_freq": bin_y,
        "bin_count": bin_n,
    }

    return ax, stats
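With strategy="quantile", the bin edges come from quantiles of the predicted probabilities, and duplicate edges are dropped, so the curve can end up with fewer than n_bins points when many predictions share the same value. The sketch below reproduces that edge construction in plain Python. It assumes linear interpolation between order statistics (the default behaviour of `np.quantile`), and `quantile_edges` is a hypothetical helper, not a TorchKM function.

```python
def quantile_edges(p_pos, n_bins):
    """Quantile bin edges with duplicates removed.

    Assumption: quantiles use linear interpolation between sorted values.
    """
    xs = sorted(p_pos)
    edges = []
    for i in range(n_bins + 1):
        pos = (i / n_bins) * (len(xs) - 1)  # fractional index into xs
        lo = int(pos)
        hi = min(lo + 1, len(xs) - 1)
        edges.append(xs[lo] + (pos - lo) * (xs[hi] - xs[lo]))
    edges = sorted(set(edges))  # drop repeated edges
    if len(edges) < 2:
        # All predictions identical: fall back to a single [0, 1] bin.
        edges = [0.0, 1.0]
    return edges


spread = quantile_edges([0.0, 0.5, 1.0], n_bins=2)
collapsed = quantile_edges([0.1, 0.9, 0.9, 0.9, 0.9], n_bins=4)
constant = quantile_edges([0.2, 0.2, 0.2], n_bins=4)
```

In the `collapsed` case four requested bins reduce to a single interval, which is why the number of plotted points can be smaller than n_bins.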

TorchKMKQR

Bases: _TorchKMBaseKernelQuantileRegressor

Kernel quantile regressor with integrated model selection.

TorchKMKQR uses torchkm.cvkqr.cvkqr when low_rank=False and torchkm.cvknyqr.cvknyqr when low_rank=True.
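For background, quantile regression is usually formulated with the pinball (check) loss, which penalizes under- and over-prediction asymmetrically according to the target quantile level. The sketch below shows that standard loss in plain Python. It is an illustration only: this page does not show the objective used by the TorchKM solvers, and the names `pinball_loss` and `tau` are assumptions.

```python
def pinball_loss(y_true, y_pred, tau):
    """Mean pinball loss at quantile level tau (0 < tau < 1)."""
    if not 0.0 < tau < 1.0:
        raise ValueError("tau must lie strictly between 0 and 1")
    total = 0.0
    for y, f in zip(y_true, y_pred):
        r = y - f  # residual: positive when the prediction is too low
        # Under-prediction costs tau * r; over-prediction costs (1 - tau) * |r|.
        total += tau * r if r >= 0 else (tau - 1.0) * r
    return total / len(y_true)


loss_low = pinball_loss([1.0, 2.0], [0.0, 2.5], tau=0.25)
loss_mid = pinball_loss([1.0, 2.0], [0.0, 2.5], tau=0.5)
```

At tau=0.5 the loss is half the mean absolute error, so minimizing it targets the conditional median. Lower values of tau make over-prediction more expensive than under-prediction.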

Source code in torchkm/estimators.py
class TorchKMKQR(_TorchKMBaseKernelQuantileRegressor):
    """Kernel quantile regressor with integrated model selection.

    ``TorchKMKQR`` uses :class:`torchkm.cvkqr.cvkqr` when ``low_rank=False``
    and :class:`torchkm.cvknyqr.cvknyqr` when ``low_rank=True``.
    """