Multivariate derivatives

Other notebooks have explained how to compute partial derivatives, but what if we want full gradients?

[1]:
import jax
import jax.numpy as jnp

from probfindiff import central, differentiate, from_grid, stencil

Let’s define a function \(f: \mathbb{R}^d \rightarrow \mathbb{R}\). We use \(f(z) = z^\top z\), whose gradient \(\nabla f(z) = 2z\) provides a reference solution.

[2]:
f = lambda z: jnp.dot(z, z)
d = 4

x = jnp.arange(1.0, 1.0 + d)
df = jax.jacfwd(f)

We have to extend the one-dimensional scheme to a multivariate scheme. A multivariate scheme comes with a new set of coefficients and a new grid, adapted to the input shape of the function.

[3]:
scheme, xs_1d = central(dx=0.01)
xs = stencil.multivariate(xs_1d=xs_1d, shape_input=(d,))
print(xs.shape)
(4, 4, 3)

The shape of xs is deliberate. The final axis of the grid corresponds to the finite-difference weights: it is the axis that will be multiplied with, and contracted against, the coefficients. The leading axes correspond to the input and output shapes of the function, because we aim to match the shapes of JAX’s automatic-differentiation output. This leaves axis=-2 as the “evaluation domain”: the axis that the scalar-valued function contracts to shape=().
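To make this axis layout concrete, here is a minimal sketch that labels each axis of the grid from the cell above (the variable names are ours, not part of probfindiff):

[ ]:
# Leading axis: one stencil per gradient component (matches jax.jacfwd's output shape).
n_output = xs.shape[0]
# axis=-2: the "evaluation domain", contracted to shape=() by the scalar-valued f.
n_domain = xs.shape[-2]
# axis=-1: one entry per finite-difference weight.
n_stencil = xs.shape[-1]
print(n_output, n_domain, n_stencil)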

Let us evaluate the gradient numerically now.

[4]:
# Inner vmap: batch over the finite-difference coefficients (axis=-1).
# Outer vmap: batch over the gradient components (axis=0).
f_batched = jax.vmap(jax.vmap(f, in_axes=-1), in_axes=0)
fx = f_batched(x[None, :, None] + xs)
dfx, _ = differentiate(fx, scheme=scheme)
print(dfx, df(x))
[2.0007117 4.0015507 6.002268  8.003107 ] [2. 4. 6. 8.]
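As a quick sanity check, a finer grid typically reduces the error. The following is a sketch; the step size dx=0.001 is an arbitrary choice:

[ ]:
scheme_fine, xs_1d_fine = central(dx=0.001)
xs_fine = stencil.multivariate(xs_1d=xs_1d_fine, shape_input=(d,))

fx_fine = f_batched(x[None, :, None] + xs_fine)
dfx_fine, _ = differentiate(fx_fine, scheme=scheme_fine)
print(jnp.linalg.norm(dfx_fine - df(x)))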

The same can be done for any one-dimensional scheme, for example one constructed from a custom grid with from_grid.

[5]:
xs_1d = jnp.array([-0.1, -0.01, 0.0, 0.01, 0.1])
scheme = from_grid(xs=xs_1d)
xs = stencil.multivariate(xs_1d=xs_1d, shape_input=(d,))

fx = f_batched(x[None, :, None] + xs)
dfx, _ = differentiate(fx, scheme=scheme)
print(xs.shape, dfx, df(x))
(4, 4, 5) [1.9995394 3.9993134 5.9990873 7.9988513] [2. 4. 6. 8.]
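The grid does not have to be symmetric, either. As a sketch (the one-sided grid points below are an arbitrary choice), a forward-difference-style grid plugs into the same pipeline:

[ ]:
xs_1d_forward = jnp.array([0.0, 0.01, 0.02, 0.03])
scheme_forward = from_grid(xs=xs_1d_forward)
xs_forward = stencil.multivariate(xs_1d=xs_1d_forward, shape_input=(d,))

fx_forward = f_batched(x[None, :, None] + xs_forward)
dfx_forward, _ = differentiate(fx_forward, scheme=scheme_forward)
print(dfx_forward.shape)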

The parameter shape_input already suggests that this mechanism extends to more complex schemes, such as Jacobians of truly multivariate, vector-valued functions. But that is content for a different tutorial.