Calibration and model selection

Probabilistic numerical finite differences use the formalism of Gaussian process regression to derive differentiation schemes. This brings the advantage of built-in uncertainty quantification, but also the burden of choosing a useful prior model.

In this notebook, we will discuss the very basics of model selection and uncertainty quantification.

[1]:
import functools

import jax
import jax.numpy as jnp
import jax.scipy.stats

import probfindiff
from probfindiff.utils import kernel, kernel_zoo

First, a baseline. With a poor estimate of the kernel's input scale, both the error and the uncertainty quantification are off.

[2]:
def k(*, input_scale):
    """Fix the input scale of an exponentiated quadratic kernel."""
    return functools.partial(
        kernel_zoo.exponentiated_quadratic, input_scale=input_scale
    )


dx = 0.1

# an incorrect scale messes up the result
scale = 100.0
scheme, xs = probfindiff.central(dx=dx, kernel=k(input_scale=scale))

f = lambda x: jnp.cos((x - 1.0) ** 2)
fx = f(xs)
dfx, variance = probfindiff.differentiate(fx, scheme=scheme)

dfx_true = jax.grad(f)(0.0)
error, std = jnp.abs(dfx - dfx_true), jnp.sqrt(variance)
print("Scale:\n\t", scale)
print("Error:\n\t", error)
print("Standard deviation:\n\t", std)
Scale:
         100.0
Error:
         0.67733574
Standard deviation:
         3.8611143

We can tune the prior kernel to alleviate this issue. For example, we can compute the maximum-likelihood estimate of the input-scale \(\theta\). The goal is to find

\[\arg\max_{\theta} p(f_{\theta}(x_n) = f_n, ~ n=0, ..., N \mid \theta)\]

where \(f_\theta\) is the prior Gaussian process, \(f_n\) are the observations of the to-be-differentiated function, and \(x_n\) are the finite difference grid points.
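Under the zero-mean Gaussian process prior with covariance kernel \(k_\theta\), this objective is the marginal likelihood of a multivariate normal distribution: with Gram matrix \(K_\theta = (k_\theta(x_i, x_j))_{i,j}\) and data vector \(f = (f_0, ..., f_N)\),

\[\log p(f \mid \theta) = -\frac{1}{2} f^\top K_\theta^{-1} f - \frac{1}{2} \log\det K_\theta - \frac{N+1}{2} \log(2\pi).\]

This is exactly the quantity that jax.scipy.stats.multivariate_normal.logpdf evaluates in the code below.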

The problem is small, so let us be lazy and compute the maximum with a grid search over a logarithmically spaced set of candidate scales.

[3]:
@functools.partial(jax.jit, static_argnames=("kernel_from_scale",))
def mle_input_scale(*, xs_data, fx_data, kernel_from_scale, input_scale_trials):
    """Compute the maximum-likelihood-estimate for the input scale."""

    # Fix all non-varying parameters, vectorise, and JIT.
    scale_to_logpdf = functools.partial(
        input_scale_to_logpdf,
        fx_data=fx_data,
        kernel_from_scale=kernel_from_scale,
        xs_data=xs_data,
    )
    scale_to_logpdf_optimised = jax.jit(jax.vmap(scale_to_logpdf))

    # Compute all logpdf values for some trial inputs.
    logpdf_values = scale_to_logpdf_optimised(input_scale=input_scale_trials)

    # Truly terrible input scales lead to NaN values.
    # They are obviously not good candidates for the optimum,
    # so replace them with -inf before taking the argmax.
    logpdf_values_filtered = jnp.nan_to_num(logpdf_values, nan=-jnp.inf)

    # Select the optimum
    index_max = jnp.argmax(logpdf_values_filtered)
    return input_scale_trials[index_max]


@functools.partial(jax.jit, static_argnames=("kernel_from_scale",))
def input_scale_to_logpdf(*, input_scale, xs_data, fx_data, kernel_from_scale):
    """Compute the logpdf of some data given an input-scale."""

    # Select a kernel with the correct input-scale
    k_scale = kernel_from_scale(input_scale=input_scale)
    k_batch = kernel.batch_gram(k_scale)[0]

    # Compute the Gram matrix and evaluate the logpdf
    K = k_batch(xs_data, xs_data.T)
    return jax.scipy.stats.multivariate_normal.logpdf(
        fx_data, mean=jnp.zeros_like(fx_data), cov=K
    )
[4]:
scale = mle_input_scale(
    xs_data=xs,
    fx_data=fx,
    kernel_from_scale=k,
    input_scale_trials=jnp.logspace(-3, 4, num=1_000, endpoint=True),
)
print("The optimised input scale is:\n\ts =", scale)
The optimised input scale is:
        s = 2.3462286

The resulting parameter estimate improves the calibration significantly.

[5]:
scheme, xs = probfindiff.central(dx=dx, kernel=k(input_scale=scale))
dfx, variance = probfindiff.differentiate(f(xs), scheme=scheme)

error, std = jnp.abs(dfx - dfx_true), jnp.sqrt(variance)
print("Scale:\n\t", scale)
print("Error:\n\t", error)
print("Standard deviation:\n\t", std)
Scale:
         2.3462286
Error:
         0.019150138
Standard deviation:
         0.0146484375
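
To read these numbers as a calibration check, one can compare the error to the reported standard deviation: for a well-calibrated model, the true derivative should lie within a few standard deviations of the estimate. Below is a minimal sketch of such a check (not part of probfindiff's API, just reusing the error and std computed above).

[ ]:
# Rough calibration check: the error should be on the order of the reported
# standard deviation; a ratio far above ~3 would suggest overconfidence.
calibration_ratio = error / std
print("Error-to-standard-deviation ratio:\n\t", calibration_ratio)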