Kernels and Utilities API

This page documents the public kernel functions, utility functions, and probability-calibration class of torchkm.

Kernel functions

linear_kernel(x1, x2, **kwargs)

Compute the linear kernel between two input vectors.

Parameters:

- x1 (ndarray): First input vector. Required.
- x2 (ndarray): Second input vector. Required.
- **kwargs: Ignored keyword arguments accepted for compatibility with other kernel functions.

Returns:

- float or ndarray: The dot product np.dot(x1, x2).

Source code in torchkm/kernels.py
def linear_kernel(x1, x2, **kwargs):
    """Compute the linear kernel between two input vectors.

    Parameters
    ----------
    x1 : numpy.ndarray
        First input vector.
    x2 : numpy.ndarray
        Second input vector.
    **kwargs
        Ignored keyword arguments accepted for compatibility with other kernel
        functions.

    Returns
    -------
    float or numpy.ndarray
        The dot product ``np.dot(x1, x2)``.
    """
    return np.dot(x1, x2)
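
Because linear_kernel simply returns np.dot(x1, x2), it also accepts 2-D arrays: passing a data matrix and the transpose of another yields a full Gram matrix, consistent with the float-or-ndarray return type. The minimal sketch below uses a local copy of the one-line function (so it runs without torchkm installed) and also shows that keyword arguments meant for other kernels, such as gamma, are ignored.

```python
import numpy as np


def linear_kernel(x1, x2, **kwargs):
    # Local copy of the function documented above: extra kwargs are ignored.
    return np.dot(x1, x2)


x1 = np.array([1.0, 2.0, 3.0])
x2 = np.array([4.0, 5.0, 6.0])

# Vector inputs give a scalar: 1*4 + 2*5 + 3*6 = 32.
value = linear_kernel(x1, x2)

# Keyword arguments intended for other kernels have no effect.
same_value = linear_kernel(x1, x2, gamma=10.0, degree=2)

# Matrix inputs: X (3 x 2) against Y (2 x 2) transposed gives a 3 x 2 Gram matrix.
X = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
Y = np.array([[2.0, 3.0], [4.0, 5.0]])
gram = linear_kernel(X, Y.T)  # gram[i, j] = X[i] . Y[j]
print(value, same_value)
print(gram)
```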

polynomial_kernel(x1, x2, degree=3, coef0=1, gamma=1, **kwargs)

Compute the polynomial kernel between two input vectors.

The kernel is defined as (gamma * np.dot(x1, x2) + coef0) ** degree.

Parameters:

- x1 (ndarray): First input vector. Required.
- x2 (ndarray): Second input vector. Required.
- degree (int): Degree of the polynomial kernel. Default: 3.
- coef0 (float): Additive constant in the polynomial kernel. Default: 1.
- gamma (float): Multiplicative scale applied to the dot product. Default: 1.
- **kwargs: Ignored keyword arguments accepted for compatibility with other kernel functions.

Returns:

- float or ndarray: Polynomial kernel value.

Source code in torchkm/kernels.py
def polynomial_kernel(x1, x2, degree=3, coef0=1, gamma=1, **kwargs):
    """Compute the polynomial kernel between two input vectors.

    The kernel is defined as ``(gamma * np.dot(x1, x2) + coef0) ** degree``.

    Parameters
    ----------
    x1 : numpy.ndarray
        First input vector.
    x2 : numpy.ndarray
        Second input vector.
    degree : int, default=3
        Degree of the polynomial kernel.
    coef0 : float, default=1
        Additive constant in the polynomial kernel.
    gamma : float, default=1
        Multiplicative scale applied to the dot product.
    **kwargs
        Ignored keyword arguments accepted for compatibility with other kernel
        functions.

    Returns
    -------
    float or numpy.ndarray
        Polynomial kernel value.
    """
    return (gamma * np.dot(x1, x2) + coef0) ** degree
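
A quick way to see what the polynomial kernel computes: with degree=2, coef0=0 and gamma=1, the value (x1 · x2)^2 equals an ordinary dot product after mapping each 2-D input through the explicit feature map phi(x) = (x_1^2, sqrt(2) x_1 x_2, x_2^2). The sketch below checks this numerically with a local copy of the documented formula; phi is a hypothetical helper for illustration only. It also evaluates the defaults: for x1 = (1, 2) and x2 = (3, 4) the dot product is 11, so the default kernel value is (1 * 11 + 1)^3 = 1728.

```python
import numpy as np


def polynomial_kernel(x1, x2, degree=3, coef0=1, gamma=1, **kwargs):
    # Local copy of the documented formula.
    return (gamma * np.dot(x1, x2) + coef0) ** degree


def phi(x):
    # Hypothetical explicit feature map for the homogeneous degree-2 kernel in 2-D.
    return np.array([x[0] ** 2, np.sqrt(2.0) * x[0] * x[1], x[1] ** 2])


x1 = np.array([1.0, 2.0])
x2 = np.array([3.0, 4.0])

default_value = polynomial_kernel(x1, x2)  # (1 * 11 + 1) ** 3
kernel_value = polynomial_kernel(x1, x2, degree=2, coef0=0, gamma=1)  # 11 ** 2
explicit_value = np.dot(phi(x1), phi(x2))  # same number, computed in feature space
print(default_value, kernel_value, explicit_value)
```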

rbf_kernel(x1, x2, gamma=0.1, **kwargs)

Compute the radial basis function kernel between two input vectors.

The kernel is defined as exp(-gamma * ||x1 - x2||^2).

Parameters:

- x1 (ndarray): First input vector. Required.
- x2 (ndarray): Second input vector. Required.
- gamma (float): Positive scale parameter controlling the width of the RBF kernel; larger values give a narrower kernel. Default: 0.1.
- **kwargs: Ignored keyword arguments accepted for compatibility with other kernel functions.

Returns:

- float: RBF kernel value.

Source code in torchkm/kernels.py
def rbf_kernel(x1, x2, gamma=0.1, **kwargs):
    """Compute the radial basis function kernel between two input vectors.

    The kernel is defined as ``exp(-gamma * ||x1 - x2||^2)``.

    Parameters
    ----------
    x1 : numpy.ndarray
        First input vector.
    x2 : numpy.ndarray
        Second input vector.
    gamma : float, default=0.1
        Positive scale parameter controlling the width of the RBF kernel.
    **kwargs
        Ignored keyword arguments accepted for compatibility with other kernel
        functions.

    Returns
    -------
    float
        RBF kernel value.
    """
    return np.exp(-gamma * np.linalg.norm(x1 - x2) ** 2)
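
As a worked check with the default gamma=0.1: for x1 = (0, 0) and x2 = (3, 4), ||x1 - x2||^2 = 25, so the kernel value is exp(-2.5) ≈ 0.0821, and any vector paired with itself gives exp(0) = 1. Since the three kernel functions are pairwise and share the **kwargs convention, one helper can build a Gram matrix for any of them by forwarding the same parameters. gram_matrix below is a hypothetical helper, not part of torchkm, and the double loop favors clarity over speed.

```python
import math

import numpy as np


def rbf_kernel(x1, x2, gamma=0.1, **kwargs):
    # Local copy of the documented formula.
    return np.exp(-gamma * np.linalg.norm(x1 - x2) ** 2)


def linear_kernel(x1, x2, **kwargs):
    # Local copy of linear_kernel; ignores extra keyword arguments.
    return np.dot(x1, x2)


def gram_matrix(kernel, X, **params):
    # Hypothetical helper: evaluate a pairwise kernel on every pair of rows of X.
    n = X.shape[0]
    K = np.empty((n, n))
    for i in range(n):
        for j in range(n):
            K[i, j] = kernel(X[i], X[j], **params)
    return K


X = np.array([[0.0, 0.0], [3.0, 4.0], [1.0, 1.0]])
value = rbf_kernel(X[0], X[1])  # exp(-0.1 * 25)

# The same parameters can be forwarded to every kernel; linear_kernel ignores gamma.
K_rbf = gram_matrix(rbf_kernel, X, gamma=0.5)
K_lin = gram_matrix(linear_kernel, X, gamma=0.5)
print(round(float(value), 4))
```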

Utility functions

sigest(x, frac=0.5)

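PyTorch port of the sigest function from the R package kernlab, which estimates a reasonable value for the inverse-width (sigma) parameter of a Gaussian RBF kernel from the data.

Parameters:

- x (torch.Tensor): Input tensor of shape (m, n), where m is the number of samples and n is the number of features.
- frac (float): Fraction of the m samples to use; int(frac * m) random pairs of rows are drawn with replacement. Default: 0.5.

Returns:

- sigma_estimate (float): Mean of the reciprocals of the 0.9 and 0.1 quantiles of the non-zero squared distances between the sampled row pairs. Because the pairs are random, the result varies between calls unless the PyTorch random seed is fixed.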

Source code in torchkm/functions.py
def sigest(x, frac=0.5):
    """
    PyTorch equivalent of the R function sigest.

    Parameters:
    - x (torch.Tensor): Input tensor of shape (m, n), where m is the number of samples and n is the number of features.
    - frac (float): Fraction of samples to use for computing the distance.

    Returns:
    - sigma_estimate (float): Estimated sigma based on quantiles of squared distances.
    """

    # Number of samples (m)
    m = x.shape[0]

    # Number of random samples to take for the distance calculation
    n = int(frac * m)

    # Randomly sample `n` indices (two sets)
    index1 = torch.randint(0, m, (n,), dtype=torch.long)
    index2 = torch.randint(0, m, (n,), dtype=torch.long)

    # Compute the squared differences between the randomly paired rows
    temp = x[index1] - x[index2]
    dist = torch.sum(temp**2, dim=1)

    # Exclude zero distances (self-pairs and duplicate rows)
    non_zero_dist = dist[dist != 0]

    # Compute quantiles (0.9, 0.5, 0.1)
    q = torch.tensor(
        [0.9, 0.5, 0.1], dtype=non_zero_dist.dtype, device=non_zero_dist.device
    )
    srange = 1.0 / torch.quantile(non_zero_dist, q)

    # Return the mean of the reciprocals of the 0.9 and 0.1 quantiles
    sigma_estimate = torch.mean(srange[[0, 2]]).item()

    return sigma_estimate
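
The estimator can be read as: draw int(frac * m) random row pairs with replacement, take their squared Euclidean distances, drop the zeros, and average the reciprocals of the 0.9 and 0.1 quantiles. The NumPy sketch below mirrors those steps so it runs without PyTorch; sigest_np and its rng argument are hypothetical. Two consequences follow from the algorithm: the result is random unless the generator is seeded (for the PyTorch version, torch.manual_seed plays that role), and because it is an inverse squared distance, scaling the data by a factor c divides the estimate by c^2.

```python
import numpy as np


def sigest_np(x, frac=0.5, rng=None):
    # Hypothetical NumPy mirror of torchkm's sigest, for illustration only.
    rng = np.random.default_rng() if rng is None else rng
    m = x.shape[0]
    n = int(frac * m)

    # Two independent sets of row indices, sampled with replacement.
    index1 = rng.integers(0, m, size=n)
    index2 = rng.integers(0, m, size=n)

    # Squared Euclidean distance for each random pair; drop zero distances.
    dist = np.sum((x[index1] - x[index2]) ** 2, axis=1)
    non_zero = dist[dist != 0]

    # Reciprocals of the 0.9, 0.5 and 0.1 quantiles (linear interpolation).
    srange = 1.0 / np.quantile(non_zero, [0.9, 0.5, 0.1])

    # Average the entries that come from the 0.9 and 0.1 quantiles.
    return float(np.mean(srange[[0, 2]]))


data = np.random.default_rng(0).normal(size=(200, 5))
est = sigest_np(data, rng=np.random.default_rng(42))
# Same seed, data scaled by 3: the estimate shrinks by a factor of 9.
est_scaled = sigest_np(3.0 * data, rng=np.random.default_rng(42))
print(est, est_scaled)
```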

rbf_kernel(x, sigma)

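Compute the RBF (Gaussian) kernel matrix of x with itself in PyTorch. This is the rbf_kernel defined in torchkm/functions.py; it is distinct from the vector-wise rbf_kernel(x1, x2, gamma) in torchkm/kernels.py documented above.

Parameters:

- x (torch.Tensor): Input tensor of shape (n_samples, n_features).
- sigma (float): Kernel scale parameter. The code computes K[i, j] = exp(-2 * sigma * ||x_i - x_j||^2), so sigma acts as an inverse width (larger values give a narrower kernel) rather than a standard deviation.

Returns:

- K (torch.Tensor): RBF kernel matrix of shape (n_samples, n_samples).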

Source code in torchkm/functions.py
def rbf_kernel(x, sigma):
    """
    Compute the RBF (Gaussian) kernel matrix in PyTorch.

    Parameters:
    - x (torch.Tensor): Input tensor of shape (n_samples, n_features).
    - sigma (float): Kernel scale; K[i, j] = exp(-2 * sigma * ||x_i - x_j||^2), so larger sigma gives a narrower kernel.

    Returns:
    - K (torch.Tensor): RBF kernel matrix of shape (n_samples, n_samples).
    """
    # Compute pairwise squared Euclidean distances
    x_norm = torch.sum(x * x, dim=1, keepdim=True)
    pairwise_dists = x_norm + x_norm.t()
    pairwise_dists.addmm_(x, x.t(), beta=1.0, alpha=-2.0)
    pairwise_dists.clamp_min_(0.0)

    # Compute the RBF kernel matrix
    K = torch.exp(pairwise_dists.mul_(-2.0 * sigma))

    return K
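
The function avoids an explicit double loop by expanding ||x_i - x_j||^2 = ||x_i||^2 + ||x_j||^2 - 2 x_i · x_j, clamping small negative values caused by round-off to zero, and exponentiating -2 * sigma times the result. The NumPy sketch below (rbf_kernel_matrix_np is a hypothetical mirror, so it runs without PyTorch) checks that expansion against a brute-force loop. It also makes the parameter convention concrete: each entry equals the vector-wise rbf_kernel from torchkm/kernels.py evaluated with gamma = 2 * sigma.

```python
import numpy as np


def rbf_kernel_matrix_np(x, sigma):
    # Hypothetical NumPy mirror of rbf_kernel(x, sigma) from torchkm/functions.py.
    x_norm = np.sum(x * x, axis=1, keepdims=True)  # (n, 1) squared row norms
    d2 = x_norm + x_norm.T - 2.0 * (x @ x.T)       # expanded squared distances
    d2 = np.maximum(d2, 0.0)                       # clamp round-off negatives
    return np.exp(-2.0 * sigma * d2)


def pairwise_rbf(x1, x2, gamma=0.1):
    # Local copy of the vector-wise kernel from torchkm/kernels.py.
    return np.exp(-gamma * np.linalg.norm(x1 - x2) ** 2)


rng = np.random.default_rng(0)
x = rng.normal(size=(6, 3))
sigma = 0.25

K = rbf_kernel_matrix_np(x, sigma)
K_loop = np.array([[pairwise_rbf(a, b, gamma=2.0 * sigma) for b in x] for a in x])
print(np.max(np.abs(K - K_loop)))
```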

kernelMult(X, X_new, sigma)

Compute the RBF (Gaussian) kernel matrix between the rows of X and the rows of X_new in PyTorch.

Parameters:

- X (torch.Tensor): Input tensor of shape (n_samples_X, n_features).
- X_new (torch.Tensor): Input tensor of shape (n_samples_X_new, n_features).
- sigma (float): Kernel scale parameter, with the same convention as the rbf_kernel(x, sigma) utility above: K[i, j] = exp(-2 * sigma * ||X[i] - X_new[j]||^2), so larger values give a narrower kernel.

Returns:

- K (torch.Tensor): Kernel matrix of shape (n_samples_X, n_samples_X_new).

Source code in torchkm/functions.py
def kernelMult(X, X_new, sigma):
    """
    Compute the RBF (Gaussian) kernel matrix between X and X_new in PyTorch.

    Parameters:
    - X (torch.Tensor): Input tensor of shape (n_samples_X, n_features).
    - X_new (torch.Tensor): Input tensor of shape (n_samples_X_new, n_features).
    - sigma (float): Kernel scale; K[i, j] = exp(-2 * sigma * ||X[i] - X_new[j]||^2), so larger sigma gives a narrower kernel.

    Returns:
    - K (torch.Tensor): RBF kernel matrix of shape (n_samples_X, n_samples_X_new).
    """
    # Compute squared L2 norms
    X_norm = torch.sum(X * X, dim=1, keepdim=True)
    X_new_norm = torch.sum(X_new * X_new, dim=1).view(1, -1)

    # Compute pairwise squared Euclidean distances
    pairwise_dists = X_norm + X_new_norm
    pairwise_dists.addmm_(X, X_new.t(), beta=1.0, alpha=-2.0)
    pairwise_dists.clamp_min_(0.0)

    # Compute the RBF kernel matrix
    K = torch.exp(pairwise_dists.mul_(-2.0 * sigma))

    return K
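
kernelMult uses the same distance expansion as rbf_kernel but between two sets of rows, so the result is rectangular: row i corresponds to X[i] and column j to X_new[j]. With the same sigma, passing X twice reproduces the square kernel matrix, and swapping the arguments transposes the result. In a kernel expansion f(x) = sum_i alpha_i K(x_i, x), such a cross-kernel matrix maps coefficients on the rows of X to values at the rows of X_new. The sketch below is a hypothetical NumPy mirror (kernel_mult_np) with made-up coefficients alpha, for illustration only.

```python
import numpy as np


def kernel_mult_np(X, X_new, sigma):
    # Hypothetical NumPy mirror of kernelMult from torchkm/functions.py.
    X_norm = np.sum(X * X, axis=1, keepdims=True)        # (n, 1)
    X_new_norm = np.sum(X_new * X_new, axis=1)[None, :]  # (1, m)
    d2 = np.maximum(X_norm + X_new_norm - 2.0 * (X @ X_new.T), 0.0)
    return np.exp(-2.0 * sigma * d2)                     # (n, m)


rng = np.random.default_rng(1)
X = rng.normal(size=(5, 2))      # e.g. training rows
X_new = rng.normal(size=(3, 2))  # e.g. rows to evaluate
sigma = 0.5

K_cross = kernel_mult_np(X, X_new, sigma)
K_square = kernel_mult_np(X, X, sigma)

# Hypothetical expansion coefficients on the rows of X.
alpha = np.array([0.5, -1.0, 0.25, 0.0, 2.0])
f_new = K_cross.T @ alpha  # f_new[j] = sum_i alpha[i] * K(X[i], X_new[j])
print(K_cross.shape, f_new.shape)
```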

Probability calibration

PlattScalerTorch

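Platt scaling with P(y=1|f) = 1 / (1 + exp(-(A*f + B))), that is, the logistic sigmoid of A*f + B, where f is a raw decision value. A and B are fitted by L2-regularized logistic regression on decision values f and labels y ∈ {-1, 1}, using damped Newton updates and the target smoothing of Platt (1999). Platt's original paper writes the model as 1 / (1 + exp(A*f + B)); under the convention used here, the fitted A and B have the opposite signs.

Constructor parameters:

- max_iter (int): Maximum number of Newton iterations. Default: 100.
- tol (float): Stop once |dA| + |dB| for the full Newton step falls below this value. Default: 1e-8.
- reg (float): L2 penalty on A and B. Default: 1e-6.
- dtype (torch.dtype): Floating-point type used for fitting and prediction. Default: torch.double.
- device: Device used for computation. Default: "cuda" if available, otherwise "cpu".

After fit(f, y), predict_proba(f) returns an (n, 2) tensor of [P(y=-1), P(y=1)] and predict(f) returns labels in {-1, 1} by thresholding P(y=1) at 0.5. reliability_curve, expected_calibration_error, brier_score, and plot_calibration (which requires matplotlib) help evaluate the calibration.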
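
The fit implemented in the source below can be summarized as follows, with λ = reg, N+ and N- the numbers of positive and negative labels, and σ the logistic sigmoid. Targets are smoothed, the objective is the L2-regularized cross-entropy, and each Newton step solves a 2×2 linear system; the step is halved up to nine times (ten trials in total) until the objective does not increase, and if no trial is accepted the parameters stay unchanged for that iteration.

$$
p_i = \sigma(A f_i + B) = \frac{1}{1 + \exp\bigl(-(A f_i + B)\bigr)},
\qquad
t_i = \begin{cases} \dfrac{N_+ + 1}{N_+ + 2}, & y_i = 1,\\[6pt] \dfrac{1}{N_- + 2}, & y_i = -1, \end{cases}
$$

$$
L(A, B) = -\sum_i \bigl[t_i \log p_i + (1 - t_i)\log(1 - p_i)\bigr] + \frac{\lambda}{2}\left(A^2 + B^2\right)
$$

$$
\nabla L = \begin{pmatrix} \sum_i (p_i - t_i) f_i + \lambda A \\ \sum_i (p_i - t_i) + \lambda B \end{pmatrix},
\qquad
H = \begin{pmatrix} \sum_i w_i f_i^2 + \lambda & \sum_i w_i f_i \\ \sum_i w_i f_i & \sum_i w_i + \lambda \end{pmatrix},
\qquad w_i = p_i (1 - p_i)
$$

$$
\begin{pmatrix} \Delta A \\ \Delta B \end{pmatrix} = -H^{-1} \nabla L,
\qquad
(A, B) \leftarrow (A, B) + \eta\,(\Delta A, \Delta B),
\qquad \eta \in \{1, 2^{-1}, \dots, 2^{-9}\}
$$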

Source code in torchkm/platt.py
class PlattScalerTorch:
    """
    Platt scaling: P(y=1|f) = 1 / (1 + exp(-(A*f + B))) = sigmoid(A*f + B)
    Fits A,B with regularized logistic regression on decision values f and labels y∈{-1,1}.
    Uses Newton updates with damping and target smoothing per Platt (1999).
    """

    def __init__(
        self, max_iter=100, tol=1e-8, reg=1e-6, dtype=torch.double, device=None
    ):
        self.max_iter = max_iter
        self.tol = tol
        self.reg = reg
        self.dtype = dtype
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.A = None
        self.B = None

    @torch.no_grad()
    def fit(self, f, y):
        """
        f: (n,) decision values (raw scores), torch tensor
        y: (n,) labels in {-1,1}, torch tensor
        """
        f = f.reshape(-1).to(self.device, self.dtype)
        y = y.reshape(-1).to(self.device, self.dtype)

        # convert to targets t in {0,1} and apply target smoothing from Platt
        # t+ = (N+ + 1) / (N+ + 2), t- = 1 / (N- + 2)
        pos = y > 0
        npos = pos.sum().item()
        nneg = y.numel() - npos
        t_pos = (npos + 1.0) / (npos + 2.0)
        t_neg = 1.0 / (nneg + 2.0)
        t = torch.where(
            pos,
            torch.tensor(t_pos, dtype=self.dtype, device=self.device),
            torch.tensor(t_neg, dtype=self.dtype, device=self.device),
        )

        # initialize A, B so that sigmoid(B) equals the empirical positive rate
        A = torch.tensor(0.0, dtype=self.dtype, device=self.device)
        ratio = torch.tensor(
            (npos + 1e-12) / (nneg + 1e-12), dtype=self.dtype, device=self.device
        )
        B = torch.log(ratio)

        # Newton updates
        for _ in range(self.max_iter):
            # p = sigmoid(A*f + B)
            z = A * f + B
            # numerically stable sigmoid
            p = torch.where(
                z >= 0, 1.0 / (1.0 + torch.exp(-z)), torch.exp(z) / (1.0 + torch.exp(z))
            )

            # logloss with small L2 on A,B (reg)
            # gradient wrt A,B
            w = p * (1.0 - p)  # Hessian weights
            # g = X^T (p - t) + reg*theta
            gA = torch.sum((p - t) * f) + self.reg * A
            gB = torch.sum(p - t) + self.reg * B

            # Hessian (2x2)
            HAA = torch.sum(w * f * f) + self.reg
            HAB = torch.sum(w * f)  # == HBA
            HBB = torch.sum(w) + self.reg

            # Solve for step: H * [dA, dB]^T = [gA, gB]^T
            det = HAA * HBB - HAB * HAB
            if det.abs() < 1e-24:
                # fallback small step if nearly singular
                stepA = -gA / (HAA + 1e-12)
                stepB = -gB / (HBB + 1e-12)
            else:
                stepA = -(HBB * gA - HAB * gB) / det
                stepB = -(-HAB * gA + HAA * gB) / det

            # damped update to ensure progress
            damping = 1.0
            for _inner in range(10):
                A_new = A + damping * stepA
                B_new = B + damping * stepB

                # check improvement via approximate line-search on NLL
                z_new = A_new * f + B_new
                p_new = torch.where(
                    z_new >= 0,
                    1.0 / (1.0 + torch.exp(-z_new)),
                    torch.exp(z_new) / (1.0 + torch.exp(z_new)),
                )
                # NLL with reg
                eps = 1e-12
                nll_new = -torch.sum(
                    t * torch.log(p_new + eps)
                    + (1.0 - t) * torch.log(1.0 - p_new + eps)
                )
                nll_new = nll_new + 0.5 * self.reg * (A_new * A_new + B_new * B_new)

                p_old = p
                nll_old = -torch.sum(
                    t * torch.log(p_old + eps)
                    + (1.0 - t) * torch.log(1.0 - p_old + eps)
                )
                nll_old = nll_old + 0.5 * self.reg * (A * A + B * B)

                if nll_new <= nll_old + 1e-12:
                    A, B = A_new, B_new
                    break
                damping *= 0.5

            # convergence on parameter step
            if (torch.abs(stepA) + torch.abs(stepB)) < self.tol:
                break

        self.A = A
        self.B = B
        return self

    @torch.no_grad()
    def predict_proba(self, f):
        """Return an (n, 2) tensor of [P(y=-1), P(y=1)] for decision values f."""
        if self.A is None or self.B is None:
            raise RuntimeError("Call fit() before predict_proba().")

        f = torch.as_tensor(f, dtype=self.dtype, device=self.device).reshape(-1)
        z = self.A * f + self.B
        p1 = torch.where(
            z >= 0, 1.0 / (1.0 + torch.exp(-z)), torch.exp(z) / (1.0 + torch.exp(z))
        )
        # Return [P(y=-1), P(y=1)] per row to match sklearn style
        return torch.stack([1.0 - p1, p1], dim=1)

    @torch.no_grad()
    def predict(self, f):
        """Return labels in {-1, 1} by thresholding P(y=1) at 0.5."""
        proba = self.predict_proba(f)
        return torch.where(proba[:, 1] >= 0.5, 1.0, -1.0).to(self.dtype)

    @torch.no_grad()
    def reliability_curve(self, y_true, p_pred, n_bins=15):
        """
        y_true: tensor/array in {-1,1}
        p_pred: predicted prob P(y=1|x) in [0,1]
        returns: bin_centers, mean_pred, frac_pos, counts
        """
        y = torch.as_tensor(y_true).reshape(-1)
        if y.dtype != torch.float64 and y.dtype != torch.float32:
            y = y.double()
        y01 = (y > 0).double()

        p = torch.as_tensor(p_pred).reshape(-1).double()
        p = torch.clamp(p, 1e-8, 1 - 1e-8)

        edges = torch.linspace(0.0, 1.0, steps=n_bins + 1, dtype=torch.double)
        idx = torch.bucketize(p, edges, right=True) - 1
        idx = torch.clamp(idx, 0, n_bins - 1)

        mean_pred = torch.zeros(n_bins, dtype=torch.double)
        frac_pos = torch.zeros(n_bins, dtype=torch.double)
        counts = torch.zeros(n_bins, dtype=torch.long)

        for b in range(n_bins):
            mask = idx == b
            cnt = mask.sum()
            counts[b] = cnt
            if cnt > 0:
                mean_pred[b] = p[mask].mean()
                frac_pos[b] = y01[mask].mean()

        bin_centers = 0.5 * (edges[:-1] + edges[1:])
        return bin_centers.numpy(), mean_pred.numpy(), frac_pos.numpy(), counts.numpy()

    def expected_calibration_error(self, mean_pred, frac_pos, counts):
        """Count-weighted mean of |frac_pos - mean_pred| over the bins from reliability_curve."""
        n = counts.sum()
        w = counts / max(n, 1)
        return float(np.sum(w * np.abs(frac_pos - mean_pred)))

    def brier_score(self, y_true, p_pred):
        """Mean squared error between p_pred (clipped to [1e-8, 1 - 1e-8]) and 0/1 labels from y_true."""
        y01 = (np.array(y_true).reshape(-1) > 0).astype(float)
        p = np.clip(np.array(p_pred).reshape(-1), 1e-8, 1 - 1e-8)
        return float(np.mean((p - y01) ** 2))

    def plot_calibration(
        self, bin_centers, mean_pred, frac_pos, counts, label="Platt", show_counts=True
    ):
        """Plot a reliability diagram from the outputs of reliability_curve (requires matplotlib)."""
        import matplotlib.pyplot as plt

        plt.figure(figsize=(5.2, 5.2), dpi=140)
        plt.plot([0, 1], [0, 1], linestyle="--", linewidth=1.5, label="Perfect")
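        # plot only non-empty bins; empty bins hold placeholder zeros
        nonempty = np.asarray(counts) > 0
        plt.plot(
            np.asarray(mean_pred)[nonempty],
            np.asarray(frac_pos)[nonempty],
            marker="o",
            linewidth=2.0,
            label=label,
        )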
        plt.xlabel("Predicted probability (bin average)", fontsize=12)
        plt.ylabel("Observed frequency (empirical)", fontsize=12)
        plt.title("Calibration (Reliability) Curve", fontsize=13)
        plt.grid(True, alpha=0.3)
        plt.legend()
        if show_counts:
            for x, y, c in zip(mean_pred, frac_pos, counts):
                if c > 0:
                    plt.text(x, y, f"{int(c)}", fontsize=8, ha="center", va="bottom")
        plt.tight_layout()
        plt.show()

fit(f, y)

Fit A and B on decision values and labels. Returns self.

Parameters:

- f: (n,) torch tensor of decision values (raw scores).
- y: (n,) torch tensor of labels in {-1, 1}.
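
A compact NumPy re-expression of the same damped Newton iteration may help when reading or testing the method. platt_fit_np, sigmoid and objective below are hypothetical helpers, not part of torchkm. The sketch assumes reg > 0 (which keeps the 2×2 Hessian invertible, so the singular-Hessian fallback is omitted) and starts from sigmoid(B) equal to the empirical positive rate. At a converged solution the regularized gradient should be numerically zero, which the usage lines compute.

```python
import numpy as np


def sigmoid(z):
    # Numerically stable logistic function for a NumPy array.
    out = np.empty_like(z)
    nonneg = z >= 0
    out[nonneg] = 1.0 / (1.0 + np.exp(-z[nonneg]))
    ez = np.exp(z[~nonneg])
    out[~nonneg] = ez / (1.0 + ez)
    return out


def objective(A, B, f, t, reg, eps=1e-12):
    # Regularized negative log-likelihood minimized by the fit.
    p = sigmoid(A * f + B)
    nll = -np.sum(t * np.log(p + eps) + (1.0 - t) * np.log(1.0 - p + eps))
    return nll + 0.5 * reg * (A * A + B * B)


def platt_fit_np(f, y, max_iter=100, tol=1e-8, reg=1e-6):
    # Hypothetical NumPy sketch of PlattScalerTorch.fit; returns (A, B, t).
    pos = y > 0
    npos, nneg = int(pos.sum()), int((~pos).sum())
    # Smoothed targets from Platt (1999).
    t = np.where(pos, (npos + 1.0) / (npos + 2.0), 1.0 / (nneg + 2.0))
    # Start with sigmoid(B) equal to the empirical positive rate.
    A, B = 0.0, float(np.log((npos + 1e-12) / (nneg + 1e-12)))
    for _ in range(max_iter):
        p = sigmoid(A * f + B)
        w = p * (1.0 - p)
        g = np.array([np.sum((p - t) * f) + reg * A, np.sum(p - t) + reg * B])
        H = np.array([[np.sum(w * f * f) + reg, np.sum(w * f)],
                      [np.sum(w * f), np.sum(w) + reg]])
        step = -np.linalg.solve(H, g)  # Newton direction
        old = objective(A, B, f, t, reg)
        damping = 1.0
        for _inner in range(10):  # halve the step until the objective does not increase
            A_new, B_new = A + damping * step[0], B + damping * step[1]
            if objective(A_new, B_new, f, t, reg) <= old + 1e-12:
                A, B = A_new, B_new
                break
            damping *= 0.5
        if np.abs(step).sum() < tol:
            break
    return A, B, t


rng = np.random.default_rng(0)
y = np.where(rng.random(400) < 0.4, 1.0, -1.0)
f = 1.5 * y + rng.normal(size=400)  # noisy decision values
A, B, t = platt_fit_np(f, y)
p = sigmoid(A * f + B)
# Regularized gradient at the returned solution; should be close to zero.
grad = np.array([np.sum((p - t) * f) + 1e-6 * A, np.sum(p - t) + 1e-6 * B])
print(A, B, np.abs(grad).max())
```

Because larger decision values here go with positive labels, A comes out positive under the sigmoid(A*f + B) convention.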


reliability_curve(y_true, p_pred, n_bins=15)

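Bin predicted probabilities into n_bins equal-width bins on [0, 1] and summarize each bin.

Parameters:

- y_true: Tensor or array of labels in {-1, 1}.
- p_pred: Predicted probabilities P(y=1|x) in [0, 1].
- n_bins (int): Number of bins. Default: 15.

Returns:

- bin_centers, mean_pred, frac_pos, counts: NumPy arrays of length n_bins holding the bin centers, the mean predicted probability per bin, the fraction of positive labels per bin, and the number of samples per bin. Empty bins have mean_pred and frac_pos set to 0.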
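
To make the binning concrete, the NumPy sketch below mirrors reliability_curve (equal-width bins on [0, 1], with a probability that falls exactly on an interior edge assigned to the upper bin) and then applies the same formulas as expected_calibration_error and brier_score. The helper name reliability_curve_np is hypothetical. With four predictions and two bins, the lower bin holds 0.1 and 0.2 (mean 0.15, one positive out of two) and the upper bin holds 0.8 and 0.9 (mean 0.85, both positive), so the count-weighted ECE is 0.5 · 0.35 + 0.5 · 0.15 = 0.25 and the Brier score is (0.01 + 0.64 + 0.04 + 0.01) / 4 = 0.175.

```python
import numpy as np


def reliability_curve_np(y_true, p_pred, n_bins=15):
    # Hypothetical NumPy mirror of PlattScalerTorch.reliability_curve.
    y01 = (np.asarray(y_true).reshape(-1) > 0).astype(float)
    p = np.clip(np.asarray(p_pred, dtype=float).reshape(-1), 1e-8, 1 - 1e-8)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    # side="right" matches torch.bucketize(..., right=True): edges[i] <= p < edges[i + 1]
    idx = np.clip(np.searchsorted(edges, p, side="right") - 1, 0, n_bins - 1)
    mean_pred = np.zeros(n_bins)
    frac_pos = np.zeros(n_bins)
    counts = np.bincount(idx, minlength=n_bins)
    for b in range(n_bins):
        mask = idx == b
        if mask.any():
            mean_pred[b] = p[mask].mean()
            frac_pos[b] = y01[mask].mean()
    centers = 0.5 * (edges[:-1] + edges[1:])
    return centers, mean_pred, frac_pos, counts


y_true = np.array([-1, 1, 1, 1])
p_pred = np.array([0.1, 0.2, 0.8, 0.9])
centers, mean_pred, frac_pos, counts = reliability_curve_np(y_true, p_pred, n_bins=2)

# Same formulas as expected_calibration_error and brier_score.
ece = float(np.sum(counts / max(counts.sum(), 1) * np.abs(frac_pos - mean_pred)))
brier = float(np.mean((p_pred - (y_true > 0).astype(float)) ** 2))
print(ece, brier)
```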


Notes

The utility API is lower-level than the estimator API. Most users should begin with the high-level estimators and use these functions only when they need custom kernels, kernel matrices, or direct solver access.