Remove dependency on `interpax` in favour of `jax.numpy.interp`
I have recently learnt that there is a `jax.numpy.interp` function that performs 1-D linear interpolation. Where linear interpolation is suitable, it suffices to use this instead of `interpax`, which removes an additional dependency.
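A minimal sketch of the single-curve case, using made-up sample data. `xp` must be increasing. Query points outside `[xp[0], xp[-1]]` are clamped to `fp[0]` / `fp[-1]` by default, following `numpy.interp`.

```python
import jax.numpy as jnp

# Illustrative sample grid (increasing) and the function values on it
xp = jnp.array([0.0, 1.0, 2.0, 3.0])
fp = jnp.array([0.0, 10.0, 20.0, 30.0])

# Query points: linear interpolation between neighbouring samples,
# the last one lies outside the grid and is clamped to fp[-1]
x = jnp.array([0.5, 1.5, 2.25, 4.0])
y = jnp.interp(x, xp, fp)
```

Whether this is a drop-in replacement depends on which `interpax` options are currently used. Only the linear method has a direct equivalent here.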
The real issue arises when passing batched, array-like inputs (e.g. several curves at once), because `xp` and `fp` must be one-dimensional. This can easily be circumvented with `jax.vmap`.
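A sketch of the `jax.vmap` workaround on made-up data. The first case assumes every curve has its own grid, values and query points (all mapped over axis 0). The second case assumes a shared grid and shared query points, so only `fp` is batched (`in_axes=(None, None, 0)`).

```python
import jax
import jax.numpy as jnp

# Case 1: three independent curves, each with its own grid and query points
xp = jnp.array([[0.0, 1.0, 2.0],
                [0.0, 2.0, 4.0],
                [1.0, 2.0, 3.0]])
fp = jnp.array([[0.0, 1.0, 2.0],
                [0.0, 4.0, 8.0],
                [10.0, 20.0, 30.0]])
x = jnp.array([[0.5, 1.5],
               [1.0, 3.0],
               [1.5, 2.5]])

# jnp.interp handles one 1-D curve; vmap maps it over the leading (batch) axis
batched_interp = jax.vmap(jnp.interp, in_axes=(0, 0, 0))
y = batched_interp(x, xp, fp)  # shape (3, 2)

# Case 2: shared grid and query points, only the sampled values differ
grid = jnp.array([0.0, 1.0, 2.0])
values = jnp.array([[0.0, 1.0, 2.0],
                    [0.0, 4.0, 8.0]])
xq = jnp.array([0.5, 1.5])
y_shared = jax.vmap(jnp.interp, in_axes=(None, None, 0))(xq, grid, values)  # shape (2, 2)
```

Using `None` in `in_axes` for the shared arguments avoids broadcasting copies of the grid across the batch.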