Calibration and model selection

Probabilistic numerical finite differences use the formalism of Gaussian process regression to derive the schemes. This brings with it the advantage of uncertainty quantification, but also the burden of choosing a useful prior model.
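
In a nutshell (and glossing over implementation details), for a zero-mean Gaussian process prior with covariance kernel \(k\), a grid \(X = (x_0, ..., x_N)\), and an evaluation point \(x\), conditioning the derivative of the process on noise-free function values on the grid yields finite difference weights and an error variance of roughly the form

\[W = \partial_1 k(x, X)\, K^{-1}, \qquad \sigma^2 = \partial_1\partial_2 k(x, x) - \partial_1 k(x, X)\, K^{-1}\, \partial_2 k(X, x), \qquad K = k(X, X),\]

where \(\partial_1\) and \(\partial_2\) denote differentiation of the kernel with respect to its first and second argument; the derivative estimate is then \(W f(X)\). Both the weights and the variance depend on the kernel, which is why the choice of prior matters.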

In this notebook, we will discuss the very basics of model selection and uncertainty quantification.

[1]:
import functools

import jax
import jax.numpy as jnp
import jax.scipy.stats

import probfindiff
from probfindiff.utils import kernel, kernel_zoo

First, a baseline: with a bad scale-parameter estimate, both the error and the uncertainty quantification are off.

[2]:
def k(*, input_scale):
    """Fix the input scale of an exponentiated quadratic kernel."""
    return functools.partial(
        kernel_zoo.exponentiated_quadratic, input_scale=input_scale
    )


dx = 0.1

# An incorrect scale messes up the result.
scale = 100.0
scheme, xs = probfindiff.central(dx=dx, kernel=k(input_scale=scale))

f = lambda x: jnp.cos((x - 1.0) ** 2)
fx = f(xs)
dfx, variance = probfindiff.differentiate(fx, scheme=scheme)

dfx_true = jax.grad(f)(0.0)
error, std = jnp.abs(dfx - dfx_true), jnp.sqrt(variance)
print("Scale:\n\t", scale)
print("Error:\n\t", error)
print("Standard deviation:\n\t", std)
Scale:
         100.0
Error:
         0.67733574
Standard deviation:
         3.8611143

We can tune the prior kernel to alleviate this issue, for example by computing the maximum-likelihood estimate of the input scale \(\theta\). The goal is to find

\[\arg\max_{\theta} p(f_{\theta}(x_n) = f_n, ~ n=0, ..., N \mid \theta)\]

where \(f_\theta\) is the prior Gaussian process, \(f_n\) are the observations of the to-be-differentiated function, and \(x_n\) are the finite difference grid points.
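
For a zero-mean Gaussian process prior, this marginal likelihood is available in closed form: with the Gram matrix \(K_\theta = (k_\theta(x_i, x_j))_{i,j}\) and the data vector \(f = (f_0, ..., f_N)\),

\[\log p(f \mid \theta) = -\frac{1}{2} f^\top K_\theta^{-1} f - \frac{1}{2} \log\det K_\theta - \frac{N+1}{2} \log(2\pi).\]

This is exactly the quantity that jax.scipy.stats.multivariate_normal.logpdf evaluates in the helper function below.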

The problem is small, so let us be lazy and compute the maximum with a grid search over a logarithmically spaced set of candidate scales.

[3]:
@functools.partial(jax.jit, static_argnames=("kernel_from_scale",))
def mle_input_scale(*, xs_data, fx_data, kernel_from_scale, input_scale_trials):
    """Compute the maximum-likelihood-estimate for the input scale."""

    # Fix all non-varying parameters, vectorise, and JIT.
    scale_to_logpdf = functools.partial(
        input_scale_to_logpdf,
        fx_data=fx_data,
        kernel_from_scale=kernel_from_scale,
        xs_data=xs_data,
    )
    scale_to_logpdf_optimised = jax.jit(jax.vmap(scale_to_logpdf))

    # Compute all logpdf values for some trial inputs.
    logpdf_values = scale_to_logpdf_optimised(input_scale=input_scale_trials)

    # Truly terrible input scales lead to NaN values.
    # They are obviously not good candidates for the optimum.
    logpdf_values_filtered = jnp.nan_to_num(logpdf_values, nan=-jnp.inf)

    # Select the optimum
    index_max = jnp.argmax(logpdf_values_filtered)
    return input_scale_trials[index_max]


@functools.partial(jax.jit, static_argnames=("kernel_from_scale",))
def input_scale_to_logpdf(*, input_scale, xs_data, fx_data, kernel_from_scale):
    """Compute the logpdf of some data given an input-scale."""

    # Select a kernel with the correct input-scale
    k_scale = kernel_from_scale(input_scale=input_scale)
    k_batch = kernel.batch_gram(k_scale)[0]

    # Compute the Gram matrix and evaluate the logpdf
    K = k_batch(xs_data, xs_data.T)
    return jax.scipy.stats.multivariate_normal.logpdf(
        fx_data, mean=jnp.zeros_like(fx_data), cov=K
    )

[4]:
scale = mle_input_scale(
    xs_data=xs,
    fx_data=fx,
    kernel_from_scale=k,
    input_scale_trials=jnp.logspace(-3, 4, num=1_000, endpoint=True),
)
print("The optimised input scale is:\n\ts =", scale)
The optimised input scale is:
        s = 2.3462286

The resulting parameter estimate improves the calibration significantly.

[5]:
scheme, xs = probfindiff.central(dx=dx, kernel=k(input_scale=scale))
dfx, variance = probfindiff.differentiate(f(xs), scheme=scheme)

error, std = jnp.abs(dfx - dfx_true), jnp.sqrt(variance)
print("Scale:\n\t", scale)
print("Error:\n\t", error)
print("Standard deviation:\n\t", std)
Scale:
         2.3462286
Error:
         0.019150138
Standard deviation:
         0.0146484375
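
The error decreased by more than an order of magnitude, and the reported standard deviation now matches the order of magnitude of the actual error.

Since everything above is written in JAX, the log-marginal-likelihood is also differentiable with respect to the input scale, so the grid search could be replaced by a gradient-based search. The following is only a rough sketch, not part of the probfindiff API: it reuses input_scale_to_logpdf from above, optimises the logarithm of the scale so that the scale stays positive, and the initial value, step size, and number of steps are arbitrary choices.

def mle_input_scale_gradient_ascent(
    *, xs_data, fx_data, kernel_from_scale, init_scale=1.0, step_size=0.1, num_steps=200
):
    """Sketch: maximise the log-marginal-likelihood with plain gradient ascent."""

    # Optimise the logarithm of the scale so that the scale itself remains positive.
    def objective(log_scale):
        return input_scale_to_logpdf(
            input_scale=jnp.exp(log_scale),
            xs_data=xs_data,
            fx_data=fx_data,
            kernel_from_scale=kernel_from_scale,
        )

    objective_value_and_grad = jax.jit(jax.value_and_grad(objective))

    # Plain gradient ascent; a fixed step size suffices for this sketch.
    log_scale = jnp.log(init_scale)
    for _ in range(num_steps):
        _, gradient = objective_value_and_grad(log_scale)
        log_scale = log_scale + step_size * gradient
    return jnp.exp(log_scale)

A call like mle_input_scale_gradient_ascent(xs_data=xs, fx_data=fx, kernel_from_scale=k) would then play the same role as mle_input_scale above, provided the objective is well-behaved around the initial scale.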