I want to share a neat way to perform automatic differentiation in Haskell. It generalizes dual numbers to handle derivatives of arbitrarily high order, via the power of laziness. I call them Taylor numbers* because they’re related to Taylor series.

Here’s the data type for Taylor numbers:

```haskell
data Taylor t = Taylor t (Taylor t)
```

These numbers have two parts, the standard part and the differential part:

```haskell
standardPart :: Taylor t -> t
standardPart (Taylor x dx) = x

differentialPart :: Taylor t -> Taylor t
differentialPart (Taylor x dx) = dx
```

The important thing is that we can extend pre-existing functions on numbers to Taylor numbers, and these extensions encode the rules of differentiation. To facilitate this, we define the *extend* and *extend2* combinators:

```haskell
extend :: (t -> t) -> (Taylor t -> Taylor t) -> Taylor t -> Taylor t
extend f df (Taylor x dx) = Taylor (f x) (df dx)

extend2 :: (t -> t -> t) -> (Taylor t -> Taylor t -> Taylor t)
        -> Taylor t -> Taylor t -> Taylor t
extend2 f df (Taylor x dx) (Taylor y dy) = Taylor (f x y) (df dx dy)
```

Now we can extend *Num*, *Fractional*, and *Floating* operations:

```haskell
instance Num t => Num (Taylor t) where
  fromInteger z = Taylor (fromInteger z) 0
  negate = extend negate negate
  (+) = extend2 (+) (+)
  (-) = extend2 (-) (-)
  x * y = extend2 (*) (\dx dy -> dx * y + x * dy) x y
  -- not differentiable at x = 0
  abs x = extend abs (* signum x) x
  signum = extend signum (const 0)
```

Notice how the rule for (*) follows the product rule.
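To see why, expand a product of plain dual numbers (the first-order truncation of Taylor numbers, where ε² = 0), writing dx and dy for the derivative parts as in the code:

```latex
(x + dx\,\varepsilon)(y + dy\,\varepsilon)
  = xy + (dx \cdot y + x \cdot dy)\,\varepsilon + dx \cdot dy\,\varepsilon^2
  \equiv xy + (dx \cdot y + x \cdot dy)\,\varepsilon
```

The ε-coefficient is exactly the product rule, which is what the (*) rule above computes.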

```haskell
instance Fractional t => Fractional (Taylor t) where
  fromRational q = Taylor (fromRational q) 0
  x / y = z where z = extend2 (/) (\dx dy -> (dx - z*dy)/y) x y

instance Floating t => Floating (Taylor t) where
  pi = Taylor pi 0
  exp x = y where y = extend exp (* y) x
  log x = extend log (/ x) x
  sqrt x = extend sqrt (/ (2 * sqrt x)) x
  sin x = extend sin (* cos x) x
  cos x = extend cos (* (- sin x)) x
  tan x = extend tan (/ (cos x * cos x)) x
  ...
```

(You can look up the rest of the derivatives here.) Notice how (/) and exp refer to their own result inside their derivative. I thought that was cute (but it’s not necessary).

We represent differentiable functions as functions *Taylor t -> Taylor t*. To calculate the derivatives of a differentiable function, we define *variable* and *derivative*:

```haskell
variable :: Num t => t -> Taylor t
variable x = Taylor x 1

derivative :: Num t => (Taylor t -> Taylor t) -> t -> t
derivative f = standardPart . differentialPart . f . variable
```

For instance, we have:

```haskell
derivative (\x -> x*x) x == x*2
derivative (\x -> x*x*x) x == x*x*3
derivative exp x == exp x
```

More generally, we can extract the derivatives *f(x), f'(x), f''(x), etc.* at a point (the coefficients of the Taylor series there, up to factorial factors) and use them to calculate nth derivatives:

```haskell
series :: Taylor t -> [t]
series (Taylor x dx) = x : series dx

nthDerivative :: Num t => Int -> (Taylor t -> Taylor t) -> t -> t
nthDerivative n f = (!! n) . series . f . variable
```

Why does it work? For the same reason dual numbers work. Taylor numbers are dual numbers where the ε-coefficient is itself a dual number, where the ε-coefficient is itself a dual number, where the ε-coefficient is itself a dual number … to infinity. Due to laziness, we only calculate as many derivatives as we need, without needing to know ahead of time how many derivatives that is.
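To make the laziness concrete, here is a minimal self-contained sketch, restricted to the *Num* fragment above, that uses *take* to pull a finite prefix from the infinite tower of derivatives (the example function and evaluation point are my own):

```haskell
-- Minimal, self-contained sketch of the Num fragment of Taylor numbers,
-- demonstrating lazy extraction of finitely many derivatives.
data Taylor t = Taylor t (Taylor t)

extend :: (t -> t) -> (Taylor t -> Taylor t) -> Taylor t -> Taylor t
extend f df (Taylor x dx) = Taylor (f x) (df dx)

extend2 :: (t -> t -> t) -> (Taylor t -> Taylor t -> Taylor t)
        -> Taylor t -> Taylor t -> Taylor t
extend2 f df (Taylor x dx) (Taylor y dy) = Taylor (f x y) (df dx dy)

instance Num t => Num (Taylor t) where
  fromInteger z = Taylor (fromInteger z) 0
  negate = extend negate negate
  (+) = extend2 (+) (+)
  (-) = extend2 (-) (-)
  x * y = extend2 (*) (\dx dy -> dx * y + x * dy) x y
  abs x = extend abs (* signum x) x
  signum = extend signum (const 0)

variable :: Num t => t -> Taylor t
variable x = Taylor x 1

series :: Taylor t -> [t]
series (Taylor x dx) = x : series dx

main :: IO ()
main = do
  -- f(x) = x^3 at x = 2: the tower is f, f', f'', f''', 0, 0, ...
  print (take 6 (series ((\x -> x*x*x) (variable (2 :: Integer)))))
  -- [8,12,12,6,0,0]
```

Even though the tower of derivatives is infinite (eventually all zeros), *take 6* forces only the first six entries; nothing beyond them is ever computed.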

Next post, I’ll write about calculating gradients (i.e. multivariate derivatives) and doing optimisation via gradient descent, based on Taylor numbers.

**Exercise.** When we’re working with Taylor numbers, the derivative is very often zero. To handle this efficiently, we can redefine Taylor numbers as follows:

```haskell
data Taylor t = Taylor t (Maybe (Taylor t))
```

where a differential part of *Nothing* means the derivative is zero. Redefine the rest of the module to take this into account. What simplifying assumptions can you make for *extend* and *extend2*?

\* But note that “Taylor number” is already taken in a different context.