Calibration and model selection#
Probabilistic numerical finite differences use the formalism of Gaussian process regression to derive the schemes. This brings with it the advantage of uncertainty quantification, but also the burden of choosing a useful prior model.
In this notebook, we will discuss the very basics of model selection and uncertainty quantification.
[1]:
import functools
import jax
import jax.numpy as jnp
import jax.scipy.stats
import probfindiff
from probfindiff.utils import kernel, kernel_zoo
First, a baseline: with a badly chosen scale parameter, both the error and the uncertainty quantification are off.
[2]:
def k(*, input_scale):
    """Fix the input scale of an exponentiated quadratic kernel."""
    return functools.partial(
        kernel_zoo.exponentiated_quadratic, input_scale=input_scale
    )

dx = 0.1
# an incorrect scale messes up the result
scale = 100.0
scheme, xs = probfindiff.central(dx=dx, kernel=k(input_scale=scale))
f = lambda x: jnp.cos((x - 1.0) ** 2)
fx = f(xs)
dfx, variance = probfindiff.differentiate(fx, scheme=scheme)
dfx_true = jax.grad(f)(0.0)
error, std = jnp.abs(dfx - dfx_true), jnp.sqrt(variance)
print("Scale:\n\t", scale)
print("Error:\n\t", error)
print("Standard deviation:\n\t", std)
Scale:
100.0
Error:
0.67733574
Standard deviation:
3.8611143
We can tune the prior kernel to alleviate this issue. For example, we can compute the maximum-likelihood estimate of the input scale \(\theta\). The goal is to find

\[
\theta^* = \arg\max_\theta \, p\bigl(f_\theta(x_1) = f_1, \dots, f_\theta(x_N) = f_N\bigr),
\]

where \(f_\theta\) is the prior Gaussian process, \(f_n\) are the observations of the to-be-differentiated function, and \(x_n\) are the finite-difference grid points.
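For a zero-mean Gaussian process prior (the code below uses a zero prior mean) with Gram matrix \(K_\theta = \bigl(k_\theta(x_i, x_j)\bigr)_{i,j}\), this objective is the standard Gaussian marginal log-likelihood,

\[
\log p_\theta(f_1, \dots, f_N \mid x_1, \dots, x_N)
= -\tfrac{1}{2} f^\top K_\theta^{-1} f - \tfrac{1}{2} \log\det K_\theta - \tfrac{N}{2} \log(2\pi),
\]

with \(f = (f_1, \dots, f_N)\). Up to the \(\arg\max\), this is exactly what ``jax.scipy.stats.multivariate_normal.logpdf`` evaluates in the helper below.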
The problem is small, so let us be lazy and compute the maximum with a grid search over a logarithmically spaced set of candidate scales.
[3]:
@functools.partial(jax.jit, static_argnames=("kernel_from_scale",))
def mle_input_scale(*, xs_data, fx_data, kernel_from_scale, input_scale_trials):
    """Compute the maximum-likelihood estimate for the input scale."""
    # Fix all non-varying parameters, vectorise, and JIT.
    scale_to_logpdf = functools.partial(
        input_scale_to_logpdf,
        fx_data=fx_data,
        kernel_from_scale=kernel_from_scale,
        xs_data=xs_data,
    )
    scale_to_logpdf_optimised = jax.jit(jax.vmap(scale_to_logpdf))

    # Compute all logpdf values for some trial inputs.
    logpdf_values = scale_to_logpdf_optimised(input_scale=input_scale_trials)

    # Truly terrible input scales lead to NaN values.
    # They are obviously not good candidates for the optimum.
    logpdf_values_filtered = jnp.nan_to_num(logpdf_values, nan=-jnp.inf)

    # Select the optimum.
    index_max = jnp.argmax(logpdf_values_filtered)
    return input_scale_trials[index_max]

@functools.partial(jax.jit, static_argnames=("kernel_from_scale",))
def input_scale_to_logpdf(*, input_scale, xs_data, fx_data, kernel_from_scale):
    """Compute the logpdf of some data given an input scale."""
    # Select a kernel with the correct input scale.
    k_scale = kernel_from_scale(input_scale=input_scale)
    k_batch = kernel.batch_gram(k_scale)[0]

    # Compute the Gram matrix and evaluate the logpdf.
    K = k_batch(xs_data, xs_data.T)
    return jax.scipy.stats.multivariate_normal.logpdf(
        fx_data, mean=jnp.zeros_like(fx_data), cov=K
    )
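As a quick aside (a sketch, not one of the executed cells above), the helper can be evaluated at a handful of trial scales to see how strongly the marginal log-likelihood depends on \(\theta\):

# Sketch only: probe the log-likelihood landscape at a few trial scales,
# reusing xs, fx, and k from the cells above. The printed values depend on the data.
for trial_scale in (0.1, 1.0, 10.0, 100.0):
    logpdf_value = input_scale_to_logpdf(
        input_scale=trial_scale, xs_data=xs, fx_data=fx, kernel_from_scale=k
    )
    print(trial_scale, logpdf_value)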
[4]:
scale = mle_input_scale(
    xs_data=xs,
    fx_data=fx,
    kernel_from_scale=k,
    input_scale_trials=jnp.logspace(-3, 4, num=1_000, endpoint=True),
)
print("The optimised input scale is:\n\ts =", scale)
The optimised input scale is:
s = 2.3462286
The resulting parameter estimate improves the calibration significantly.
[5]:
scheme, xs = probfindiff.central(dx=dx, kernel=k(input_scale=scale))
dfx, variance = probfindiff.differentiate(f(xs), scheme=scheme)
error, std = jnp.abs(dfx - dfx_true), jnp.sqrt(variance)
print("Scale:\n\t", scale)
print("Error:\n\t", error)
print("Standard deviation:\n\t", std)
Scale:
2.3462286
Error:
0.019150138
Standard deviation:
0.0146484375
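To make the claim about calibration concrete, one can compare the error to the reported standard deviation: a ratio close to one means the returned uncertainty is a useful estimate of the actual error. A minimal sketch, reusing the variables from the cell above:

# Sketch: ratio of actual error to reported standard deviation.
# For the tuned scale, error and standard deviation are of comparable magnitude,
# whereas the miscalibrated baseline reported a standard deviation far larger than its error.
calibration_ratio = error / std
print("Error-to-std ratio:\n\t", calibration_ratio)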