Partial derivatives and other axis magic#

Can probfindiff also do partial derivatives? Yes, it can do this and more!

[1]:
import jax.numpy as jnp

from probfindiff import central, differentiate_along_axis

Partial derivatives#

Consider a function \(f=f(x,y)\). To compute its partial derivative \(\partial f/\partial x\), we can use finite differences. To this end, we evaluate \(f\) on a meshgrid-style set of nodes (just as if we were making a contour plot with matplotlib) and differentiate the resulting two-dimensional array numerically along the \(x\)-axis.

[2]:
scheme, xs = central(dx=0.05)

# Evaluate f(x, y) = sin(x) * cos(y) at the nodes xs and the single point y = 0.
fx = jnp.sin(xs)[:, None] * jnp.cos(jnp.zeros(1))[None, :]

# Differentiating along axis 0 contracts the x-axis and yields df/dx(0, y).
dfdx_approx, _ = differentiate_along_axis(fx, axis=0, scheme=scheme)
print(dfdx_approx, jnp.cos(0.0) * jnp.cos(jnp.zeros(1)))
[0.9995756] [1.]
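
The axis argument is what makes this layout-agnostic: if the finite-difference nodes sit along a different axis, we simply differentiate along that axis instead. A minimal sketch of approximating \(\partial f/\partial y\), with the function \(f(x,y)=\cos(x)\sin(y)\) and all names chosen here for illustration:

# Sketch: nodes along axis 1, so we differentiate along axis 1.
scheme_y, ys = central(dx=0.05)
fy = jnp.cos(jnp.zeros(1))[:, None] * jnp.sin(ys)[None, :]

# axis=1 contracts the y-axis; the true value is cos(0) * cos(0) = 1.
dfdy_approx, _ = differentiate_along_axis(fy, axis=1, scheme=scheme_y)
print(dfdy_approx)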

Mixed derivatives#

It is just as easy to compute mixed derivatives. For example, to compute

\[\frac{\partial^3}{\partial x \partial y^2} f(x,y)\]

we chain finite difference schemes together.

[3]:
scheme_x, xs = central(dx=0.1, order_derivative=1)
scheme_y, ys = central(dx=0.05, order_derivative=2)

# Evaluate f(x, y) = sin(x) * cos(y) on the full grid of nodes.
fx = jnp.sin(xs)[:, None] * jnp.cos(ys)[None, :]

# The first differentiation contracts the x-axis (axis 0) ...
dfdx_approx, _ = differentiate_along_axis(fx, axis=0, scheme=scheme_x)
# ... which leaves the y-axis as the new axis 0 for the second scheme.
dfx_approx, _ = differentiate_along_axis(dfdx_approx, axis=0, scheme=scheme_y)
print(dfx_approx, -jnp.cos(0.0) ** 2)
-1.0064235 -1.0
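
As a sanity check, the same mixed derivative can be computed with JAX's automatic differentiation. A minimal sketch using jax.grad (plain JAX, not part of probfindiff):

# Sketch: cross-check the chained schemes against automatic differentiation.
import jax

def f(x, y):
    return jnp.sin(x) * jnp.cos(y)

d2f_dy2 = jax.grad(jax.grad(f, argnums=1), argnums=1)  # second derivative in y
d3f = jax.grad(d2f_dy2, argnums=0)  # then first derivative in x
print(d3f(0.0, 0.0))  # -1.0 up to floating-point precision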

If you’ve read the modelling tutorial, you will notice that this chaining of schemes implies a specific model. More specifically, the approach above is a good idea if the function \(f\) splits into the product

\[f(x,y) = f_1(x) f_2(y).\]
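
In that case, the mixed derivative factorises into a product of one-dimensional derivatives, which is exactly what the chained one-dimensional schemes approximate:

\[\frac{\partial^3}{\partial x \partial y^2} f(x,y) = f_1'(x) \, f_2''(y).\]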

If \(f\) does not factorise like this, there are better approaches; these are left for a different tutorial.

Batched derivatives#

Once a scheme is set up, the differentiate_along_axis() function also computes batched finite-difference evaluations. Because schemes are constructed independently of their application, we can pick a single scheme and apply it to a whole batch of function evaluations at once.

[4]:
scheme, xs = central(dx=0.01)

# A batch of 100 functions f_c(x) = c * sin(x), c in [0, 1], at the FD nodes.
fx_batch = jnp.sin(xs)[:, None] * jnp.linspace(0.0, 1.0, 100)[None, :]
dfx_batch, _ = differentiate_along_axis(fx_batch, axis=0, scheme=scheme)

# Compare with the true derivatives f_c'(0) = c and report the RMS error.
difference = dfx_batch - jnp.cos(0.0) * jnp.linspace(0.0, 1.0, 100)
print(difference.shape, jnp.linalg.norm(difference) / jnp.sqrt(difference.size))
(100,) 0.00021049284
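
The same scheme works for other memory layouts, too. If the batch dimension leads and the finite-difference nodes sit along the last axis, we differentiate along axis 1 instead; a minimal sketch reusing scheme and xs from above:

# Sketch: same scheme, transposed layout (batch first, FD nodes last).
fx_batch_t = jnp.linspace(0.0, 1.0, 100)[:, None] * jnp.sin(xs)[None, :]
dfx_batch_t, _ = differentiate_along_axis(fx_batch_t, axis=1, scheme=scheme)
print(dfx_batch_t.shape)  # one derivative estimate per batch element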