Batched differentiation#

Rarely does one need to compute the derivative at only a single grid point. More often than not, we want derivatives on a whole grid.

[1]:
import jax.numpy as jnp

from probfindiff import central, differentiate_along_axis, from_grid

Uniform grids#

Let’s say we need to compute the derivative of a function \(f\) at a whole selection of grid points. We can do this by treating the batch of derivatives as partial derivatives along one axis of a two-dimensional grid. The only important thing to remember is that the schemes provided by probfindiff assume that the desired derivative is evaluated at zero, so we need to shift the finite-difference grid to each evaluation point.

[2]:
scheme, xs = central(dx=0.01)

grid = jnp.linspace(0.0, 1.0, num=12)

# Shift the stencil to each (nonzero) differentiation point via broadcasting
grid_fd = grid[:, None] + xs[None, :]

dfxs, _ = differentiate_along_axis(jnp.sin(grid_fd), axis=1, scheme=scheme)
print(dfxs, jnp.cos(grid))
[1.0003637  0.9962324  0.9838748  0.9633885  0.9349485  0.8987875
 0.8552009  0.80455595 0.7472604  0.6837938  0.6146878  0.54050064] [1.         0.9958706  0.9835166  0.9630399  0.93460965 0.8984607
 0.85489154 0.8042621  0.7469903  0.6835494  0.6144632  0.5403023 ]
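
As a quick sanity check, we can compare the batched approximation with the true derivative and look at the worst-case error; this small sketch reuses dfxs and grid from the cell above.

# Worst-case error of the batched approximation against the true derivative
print(jnp.max(jnp.abs(dfxs - jnp.cos(grid))))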

Non-uniform grids#

It is not surprising that the above procedure works equally well with non-uniform grids.

[3]:
xs = jnp.array([0.0, 0.3, 0.4])
scheme = from_grid(xs=xs)

grid = jnp.linspace(0.0, 1.0, num=12)
grid_fd = grid[:, None] + xs[None, :]

dfxs, _ = differentiate_along_axis(jnp.sin(grid_fd), axis=1, scheme=scheme)
print(dfxs, jnp.cos(grid))
[1.0196     1.0150714  1.0021594  0.98097146 0.9516821  0.9145313
 0.86982924 0.81794256 0.75930154 0.69438916 0.6237417  0.5479433 ] [1.         0.9958706  0.9835166  0.9630399  0.93460965 0.8984607
 0.85489154 0.8042621  0.7469903  0.6835494  0.6144632  0.5403023 ]
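
The scheme object carries the finite-difference weights that get combined with the function evaluations (the same scheme.weights attribute appears in the convolution example below). As a small aside, we can inspect them; for a non-uniform stencil they differ from the familiar central-difference coefficients.

# Finite-difference weights of the non-uniform scheme
print(scheme.weights)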

Without redundant function evaluations#

The downside of the above approach is that the function \(f\) has to be evaluated at a few redundant points. Chances are that these function values are already part of the vector of grid evaluations: with a uniform grid whose spacing matches the stencil, every stencil point coincides with a grid point.

Can we be more efficient in this case? Yes, we can! While there is still much room for improvement (in terms of API and efficiency), the basics are accessible through convolutions.

[4]:
dx = 0.1
xs = jnp.arange(0.0, 2.0, step=dx)
scheme, _ = central(dx=dx)

# jax.numpy.convolve flips its second argument, so we flip the weights to compensate.
# mode="valid" discards the boundary points, where the central stencil does not fit.
dfx_approx = jnp.convolve(jnp.sin(xs), jnp.flip(scheme.weights), mode="valid")
dfx_true = jnp.cos(xs)

print(dfx_approx)
print(dfx_true[1:-1])  # eliminate values we cannot compute in the above way
[ 0.99335146  0.9784374   0.95374703  0.9195273   0.87611985  0.8239579
  0.7635641   0.6955409   0.62056756  0.5393946   0.45283175  0.36174345
  0.26704264  0.16967201  0.07060671 -0.02916431 -0.12864399 -0.22683764]
[ 0.9950042   0.9800666   0.9553365   0.921061    0.87758255  0.8253356
  0.7648422   0.6967067   0.6216099   0.5403023   0.4535961   0.3623577
  0.26749876  0.16996716  0.0707372  -0.02919955 -0.12884454 -0.22720216]

Since central coefficients are not well-defined on the boundary of the grid, we obtain the derivatives only in the interior. For the boundary points, we could use forward/backward differences or apply boundary conditions. Which solution is the correct one depends on the application and is left for a different tutorial.
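
As a small preview, one option for the left boundary is to build a one-sided stencil with from_grid and apply it in the same shift-and-differentiate fashion as above. The following is a minimal sketch, assuming a three-point forward stencil whose spacing matches the grid from the previous cell (dx and xs are reused from there).

# Minimal sketch: a one-sided (forward) stencil for the leftmost grid point.
xs_forward = jnp.array([0.0, dx, 2 * dx])
scheme_forward = from_grid(xs=xs_forward)

# Shift the stencil to the boundary point and differentiate along the stencil axis.
dfx_left, _ = differentiate_along_axis(
    jnp.sin(xs[0] + xs_forward[None, :]), axis=1, scheme=scheme_forward
)
print(dfx_left, jnp.cos(xs[0]))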