Finite differences on custom grids

This tutorial explains how to compute finite difference approximations on custom grids.

[1]:
import jax.numpy as jnp

from probfindiff import backward, differentiate, from_grid

Recall that the usual output of a finite difference constructor such as backward() is a pair: a scheme and a grid. When the scheme is applied, probfindiff assumes that the function has been evaluated on that grid.

[2]:
scheme, xs = backward(dx=0.1)
print(scheme)
print(xs)
FiniteDifferenceScheme(weights=DeviceArray([ 15.01597  , -20.031385 ,   5.0154157], dtype=float32), covs_marginal=DeviceArray(-5.543232e-05, dtype=float32), order_derivative=DeviceArray(1, dtype=int32, weak_type=True))
[ 0.  -0.1 -0.2]
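For comparison, the weights printed above are close to the classical second-order backward difference weights (3, -4, 1)/(2h) = (15, -20, 5) for h = 0.1. They are presumably not identical because probfindiff derives its weights from a probabilistic model. Conceptually, applying a scheme is a weighted sum of the function values on the grid. The NumPy sketch below illustrates this with the classical weights. It does not call probfindiff, and the helper `apply_scheme` is hypothetical.

```python
import numpy as np

# Classical second-order backward difference for f'(0) with step h:
#   f'(0) ≈ (3 f(0) - 4 f(-h) + f(-2h)) / (2h)
h = 0.1
xs = np.array([0.0, -h, -2 * h])                # same grid as backward(dx=0.1) above
weights = np.array([3.0, -4.0, 1.0]) / (2 * h)  # approximately [15, -20, 5]


def apply_scheme(weights, fx):
    """Hypothetical helper: a scheme is a weighted sum of the values f(xs)."""
    return np.dot(weights, fx)


# Approximate the derivative of exp at x=0 (exact value: 1).
dfx = apply_scheme(weights, np.exp(xs))
print(weights, dfx)
```

The printed approximation is close to 1 (about 0.997), as expected from a second-order scheme with h = 0.1.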

Custom schemes

Sometimes we already have a grid and want to compute a corresponding finite difference scheme. This happens, for example, when dealing with irregular geometries (circles, curves), or when certain function evaluations are readily available and further evaluations are costly.

For this purpose, probfindiff provides from_grid(), which computes a finite difference scheme for a given grid.

[3]:
xs = jnp.array([-0.01, 0.0, 2.0])
scheme = from_grid(xs=xs)
print(scheme)
FiniteDifferenceScheme(weights=DeviceArray([-9.9635269e+01,  9.9632912e+01,  2.3571502e-03], dtype=float32), covs_marginal=DeviceArray(-0.00106692, dtype=float32), order_derivative=DeviceArray(1, dtype=int32, weak_type=True))
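To see what a grid-based construction has to achieve, here is the classical polynomial (Taylor) approach. It chooses weights w_i such that the scheme is exact for polynomials of degree below the number of grid points, which amounts to solving sum_i w_i x_i^k / k! = δ_{k,d} for k = 0, ..., n-1, where d is the derivative order and the expansion is about the origin. This is a sketch of the textbook method, not of from_grid() itself, which presumably uses a probabilistic model. The function name `taylor_weights` is hypothetical.

```python
import math

import numpy as np


def taylor_weights(xs, order_derivative=1):
    """Hypothetical helper: classical finite difference weights on an arbitrary grid.

    The derivative is approximated at the origin x=0. The weights w solve
        sum_i w_i * xs[i]**k / k! = (1 if k == order_derivative else 0)
    for k = 0, ..., len(xs) - 1, so the scheme is exact for polynomials
    of degree < len(xs).
    """
    xs = np.asarray(xs, dtype=float)
    n = xs.shape[0]
    ks = np.arange(n)
    # Row k holds the k-th Taylor coefficient xs[i]**k / k! of every grid point.
    factorials = np.array([math.factorial(k) for k in ks], dtype=float)
    moments = xs[None, :] ** ks[:, None] / factorials[:, None]
    rhs = np.zeros(n)
    rhs[order_derivative] = 1.0
    return np.linalg.solve(moments, rhs)


xs = np.array([-0.01, 0.0, 2.0])  # the grid used with from_grid() above
weights = taylor_weights(xs)
print(weights)
```

The result, roughly (-99.50, 99.50, 0.0025), is close to, but not identical with, the weights that from_grid() printed above. The same helper recovers the classical backward weights (15, -20, 5) for the grid (0, -0.1, -0.2). For many points or widely spread grids, this moment system becomes ill-conditioned, so the sketch is meant for small stencils.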

Where is x?

For from_grid(), as well as its cousins forward(), backward(), etc., the derivative is always approximated at the origin x=0. For instance, the grid (-0.01, 0., 2.) yields something like an unevenly spaced central difference quotient, because the resulting scheme approximates f'(0).

[4]:
dfx, _ = differentiate(jnp.cos(xs), scheme=scheme)
print(dfx, -jnp.sin(0.0))
0.0016435911 -0.0

If you need the derivative at x=x_0 instead, shift the evaluation points accordingly: applying the scheme to the values f(x_0 + xs) approximates f'(x_0).

[5]:
dfx, _ = differentiate(jnp.cos(xs + 0.75), scheme=scheme)
print(dfx, -jnp.sin(0.75))
-0.6793945 -0.6816388
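Why does shifting work? Define g(x) = f(x_0 + x). By the chain rule, g'(0) = f'(x_0), so a scheme that approximates derivatives at the origin, applied to g(xs) = f(x_0 + xs), approximates f'(x_0). The NumPy sketch below repeats the experiment above with classical polynomial weights for the same grid instead of probfindiff's weights. The helper `derivative_at` is hypothetical.

```python
import numpy as np

# Grid from above; the scheme approximates derivatives at the origin.
xs = np.array([-0.01, 0.0, 2.0])

# Classical first-derivative weights for this grid: require exactness
# for the Taylor monomials 1, x and x**2 / 2 at the origin.
moments = np.vstack([np.ones_like(xs), xs, xs**2 / 2])
weights = np.linalg.solve(moments, np.array([0.0, 1.0, 0.0]))


def derivative_at(f, x0, xs, weights):
    """Hypothetical helper: approximate f'(x0) with a scheme defined at the origin.

    With g(x) = f(x0 + x) we have g'(0) = f'(x0), so the weights are
    applied to f evaluated on the shifted grid x0 + xs.
    """
    return np.dot(weights, f(x0 + xs))


x0 = 0.75
approx = derivative_at(np.cos, x0, xs, weights)
print(approx, -np.sin(x0))
```

The approximation agrees with -sin(0.75) to within a few thousandths, comparable to the probfindiff result above.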