I want to share a neat way to perform automatic differentiation in Haskell. It generalizes dual numbers to handle arbitrarily high derivatives, via the power of laziness. I call them Taylor numbers* because they’re related to Taylor series.
Here’s the data type for Taylor numbers:
```haskell
data Taylor t = Taylor t (Taylor t)
```
These numbers have two parts, the standard part and the differential part:
```haskell
standardPart :: Taylor t -> t
standardPart (Taylor x _) = x

differentialPart :: Taylor t -> Taylor t
differentialPart (Taylor _ dx) = dx
```
The important thing is that we can extend pre-existing functions on numbers to Taylor numbers, and these extensions encode the rules of differentiation. To facilitate this, we define the extend and extend2 combinators:
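```haskell
-- Lift a one-argument function; df says how the differential part transforms.
extend :: (t -> t) -> (Taylor t -> Taylor t) -> Taylor t -> Taylor t
extend f df (Taylor x dx) = Taylor (f x) (df dx)

-- Lift a two-argument function; df combines both differential parts.
extend2 :: (t -> t -> t) -> (Taylor t -> Taylor t -> Taylor t)
        -> Taylor t -> Taylor t -> Taylor t
extend2 f df (Taylor x dx) (Taylor y dy) = Taylor (f x y) (df dx dy)
```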
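The combinators put all of the calculus into `df`. Here is a sketch of the contract `df` has to satisfy (my reading of the code, not something the post states explicitly): it is the chain rule, with `x` and `y` taken to be the whole Taylor numbers rather than just their standard parts.

```latex
\begin{aligned}
\texttt{extend}:&\quad df(dx) = f'(x)\,dx \\
\texttt{extend2}:&\quad df(dx, dy) = \partial_x f(x, y)\,dx + \partial_y f(x, y)\,dy
\end{aligned}
```

For example, `sin` in the Floating instance passes `(* cos x)`, i.e. df(dx) = dx · cos x, where `cos x` is itself a Taylor number with its own differential part; that is what lets the same rule go on to produce the second and later derivatives.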
Now we can extend Num, Fractional, and Floating operations:
```haskell
instance Num t => Num (Taylor t) where
  fromInteger z = Taylor (fromInteger z) 0
  negate = extend negate negate
  (+) = extend2 (+) (+)
  (-) = extend2 (-) (-)
  x * y = extend2 (*) (\dx dy -> dx * y + x * dy) x y
  -- not differentiable at x=0
  abs x = extend abs (* signum x) x
  signum = extend signum (const 0)
```
Notice how the rule for (*) follows the product rule.
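To see the correspondence term by term, read `x` and `y` as functions u(t) and v(t) of the input, and `dx` and `dy` as their derivatives u′ and v′:

```latex
\frac{d}{dt}(u\,v) = \underbrace{u'\,v}_{\texttt{dx * y}} + \underbrace{u\,v'}_{\texttt{x * dy}}
```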
```haskell
instance Fractional t => Fractional (Taylor t) where
  fromRational q = Taylor (fromRational q) 0
  x / y = z
    where z = extend2 (/) (\dx dy -> (dx - z * dy) / y) x y

instance Floating t => Floating (Taylor t) where
  pi = Taylor pi 0
  exp x = y
    where y = extend exp (* y) x
  log x = extend log (/ x) x
  sqrt x = extend sqrt (/ (2 * sqrt x)) x
  sin x = extend sin (* cos x) x
  cos x = extend cos (* (- sin x)) x
  tan x = extend tan (/ (cos x * cos x)) x
  -- ...
```
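The `(/)` rule is the usual quotient rule, rearranged to reuse the result z = x/y so that `y` never has to be squared. Reading `x`, `y` as functions u, v and `dx`, `dy` as u′, v′:

```latex
\frac{d}{dt}\left(\frac{u}{v}\right)
  = \frac{u'\,v - u\,v'}{v^2}
  = \frac{u' - \frac{u}{v}\,v'}{v}
  = \frac{u' - z\,v'}{v},
  \qquad z = \frac{u}{v}
```

The last form is exactly `(dx - z*dy)/y`.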
(You can look up the rest of the derivatives here.) Notice how (/) and exp refer to their own result in their derivative rules. I thought that was cute (but it’s not necessary).
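The post elides the remaining Floating methods. As a sketch, here is one way to fill them in using standard derivative identities; everything from `asin` onward is my completion, not the author's code. `tanh` borrows the self-referential trick, since tanh′ = 1 − tanh².

```haskell
-- Self-contained sketch: the post's definitions plus a completion of the
-- Floating instance (asin .. atanh are my additions, not from the post).
data Taylor t = Taylor t (Taylor t)

standardPart :: Taylor t -> t
standardPart (Taylor x _) = x

differentialPart :: Taylor t -> Taylor t
differentialPart (Taylor _ dx) = dx

extend :: (t -> t) -> (Taylor t -> Taylor t) -> Taylor t -> Taylor t
extend f df (Taylor x dx) = Taylor (f x) (df dx)

extend2 :: (t -> t -> t) -> (Taylor t -> Taylor t -> Taylor t)
        -> Taylor t -> Taylor t -> Taylor t
extend2 f df (Taylor x dx) (Taylor y dy) = Taylor (f x y) (df dx dy)

instance Num t => Num (Taylor t) where
  fromInteger z = Taylor (fromInteger z) 0
  negate = extend negate negate
  (+) = extend2 (+) (+)
  (-) = extend2 (-) (-)
  x * y = extend2 (*) (\dx dy -> dx * y + x * dy) x y
  abs x = extend abs (* signum x) x
  signum = extend signum (const 0)

instance Fractional t => Fractional (Taylor t) where
  fromRational q = Taylor (fromRational q) 0
  x / y = z where z = extend2 (/) (\dx dy -> (dx - z * dy) / y) x y

instance Floating t => Floating (Taylor t) where
  pi = Taylor pi 0
  exp x = y where y = extend exp (* y) x
  log x = extend log (/ x) x
  sqrt x = extend sqrt (/ (2 * sqrt x)) x
  sin x = extend sin (* cos x) x
  cos x = extend cos (* (- sin x)) x
  tan x = extend tan (/ (cos x * cos x)) x
  -- Completion (standard identities):
  asin x = extend asin (/ sqrt (1 - x * x)) x                   -- 1/sqrt(1-x^2)
  acos x = extend acos (\dx -> negate dx / sqrt (1 - x * x)) x  -- -1/sqrt(1-x^2)
  atan x = extend atan (/ (1 + x * x)) x                        -- 1/(1+x^2)
  sinh x = extend sinh (* cosh x) x
  cosh x = extend cosh (* sinh x) x
  -- tanh' = 1 - tanh^2, so tanh reuses its own result like exp does.
  tanh x = y where y = extend tanh (* (1 - y * y)) x
  asinh x = extend asinh (/ sqrt (x * x + 1)) x                 -- 1/sqrt(x^2+1)
  acosh x = extend acosh (/ sqrt (x * x - 1)) x                 -- 1/sqrt(x^2-1), x > 1
  atanh x = extend atanh (/ (1 - x * x)) x                      -- 1/(1-x^2)

variable :: Num t => t -> Taylor t
variable x = Taylor x 1

derivative :: Num t => (Taylor t -> Taylor t) -> t -> t
derivative f = standardPart . differentialPart . f . variable
```

For instance, `derivative atan 1` should give 0.5 and `derivative tanh 0` should give 1.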
We represent differentiable functions as functions Taylor t -> Taylor t. To calculate the derivatives of a differentiable function, we define variable and derivative:
```haskell
variable :: Num t => t -> Taylor t
variable x = Taylor x 1

derivative :: Num t => (Taylor t -> Taylor t) -> t -> t
derivative f = standardPart . differentialPart . f . variable
```
For instance, we have:
```haskell
derivative (\x -> x*x) x   == x*2
derivative (\x -> x*x*x) x == x*x*3
derivative exp x           == exp x
```
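A runnable check of these examples at concrete points, as a self-contained sketch: it repeats the post's Num and Fractional definitions, leaves out Floating (so `exp` is not covered), and adds a hypothetical `examples` list of my own. The reciprocal case exercises the `(/)` rule.

```haskell
-- Sketch: the post's Num/Fractional definitions, checked at concrete points.
data Taylor t = Taylor t (Taylor t)

standardPart :: Taylor t -> t
standardPart (Taylor x _) = x

differentialPart :: Taylor t -> Taylor t
differentialPart (Taylor _ dx) = dx

extend :: (t -> t) -> (Taylor t -> Taylor t) -> Taylor t -> Taylor t
extend f df (Taylor x dx) = Taylor (f x) (df dx)

extend2 :: (t -> t -> t) -> (Taylor t -> Taylor t -> Taylor t)
        -> Taylor t -> Taylor t -> Taylor t
extend2 f df (Taylor x dx) (Taylor y dy) = Taylor (f x y) (df dx dy)

instance Num t => Num (Taylor t) where
  fromInteger z = Taylor (fromInteger z) 0
  negate = extend negate negate
  (+) = extend2 (+) (+)
  (-) = extend2 (-) (-)
  x * y = extend2 (*) (\dx dy -> dx * y + x * dy) x y
  abs x = extend abs (* signum x) x
  signum = extend signum (const 0)

instance Fractional t => Fractional (Taylor t) where
  fromRational q = Taylor (fromRational q) 0
  x / y = z where z = extend2 (/) (\dx dy -> (dx - z * dy) / y) x y

variable :: Num t => t -> Taylor t
variable x = Taylor x 1

derivative :: Num t => (Taylor t -> Taylor t) -> t -> t
derivative f = standardPart . differentialPart . f . variable

-- Hypothetical example list: first derivatives at concrete points.
examples :: [Double]
examples =
  [ derivative (\x -> x * x) 3      -- 2x at x = 3
  , derivative (\x -> x * x * x) 3  -- 3x^2 at x = 3
  , derivative (\x -> 1 / x) 2      -- -1/x^2 at x = 2
  ]
```

Loaded in GHCi, `examples` evaluates to `[6.0,27.0,-0.25]`.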
More generally, we can extract the values f(x), f'(x), f''(x), etc. at a point (these are the Taylor series coefficients before dividing by n!) and use them to calculate nth derivatives:
```haskell
series :: Taylor t -> [t]
series (Taylor x dx) = x : series dx

nthDerivative :: Num t => Int -> (Taylor t -> Taylor t) -> t -> t
nthDerivative n f = (!! n) . series . f . variable
```
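Since `series` yields derivatives rather than Taylor coefficients, dividing the n-th entry by n! recovers the coefficients. Below is a self-contained sketch that repeats only the Num part of the post; `taylorCoefficients` is a hypothetical helper of mine, not from the post. Because `series` is an infinite list, `take` decides how many terms actually get computed.

```haskell
-- Sketch: derivative towers via `series`, plus a hypothetical helper that
-- turns them into Taylor-series coefficients.
data Taylor t = Taylor t (Taylor t)

extend :: (t -> t) -> (Taylor t -> Taylor t) -> Taylor t -> Taylor t
extend f df (Taylor x dx) = Taylor (f x) (df dx)

extend2 :: (t -> t -> t) -> (Taylor t -> Taylor t -> Taylor t)
        -> Taylor t -> Taylor t -> Taylor t
extend2 f df (Taylor x dx) (Taylor y dy) = Taylor (f x y) (df dx dy)

instance Num t => Num (Taylor t) where
  fromInteger z = Taylor (fromInteger z) 0
  negate = extend negate negate
  (+) = extend2 (+) (+)
  (-) = extend2 (-) (-)
  x * y = extend2 (*) (\dx dy -> dx * y + x * dy) x y
  abs x = extend abs (* signum x) x
  signum = extend signum (const 0)

variable :: Num t => t -> Taylor t
variable x = Taylor x 1

series :: Taylor t -> [t]
series (Taylor x dx) = x : series dx

nthDerivative :: Num t => Int -> (Taylor t -> Taylor t) -> t -> t
nthDerivative n f = (!! n) . series . f . variable

-- Hypothetical helper: the first k coefficients f^(n)(x) / n! of the
-- Taylor series of f at x. Only k entries of the infinite series are forced.
taylorCoefficients :: Fractional t => Int -> (Taylor t -> Taylor t) -> t -> [t]
taylorCoefficients k f x =
  zipWith (/) (take k (series (f (variable x)))) factorials
  where
    factorials = map fromInteger (scanl (*) 1 [1 ..])  -- 0!, 1!, 2!, ...
```

For x^3 at 2 the series begins 8, 12, 12, 6, 0, so `taylorCoefficients 5 (\x -> x ^ 3) 2` gives `[8.0,12.0,6.0,1.0,0.0]`, matching (2 + h)^3 = 8 + 12h + 6h^2 + h^3.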
Why does it work? For the same reason dual numbers work. Taylor numbers are dual numbers where the ε-coefficient is itself a dual number, where the ε-coefficient is itself a dual number, where the ε-coefficient is itself a dual number … to infinity. Due to laziness, we only calculate as many derivatives as we need, without needing to know ahead of time how many derivatives that is.
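Here is one way to make that argument precise, as a sketch in notation of my own (T[u] does not appear in the post): fix the point x and write T[u] for the Taylor number whose `series` is u(x), u′(x), u″(x), …. The definitions then give

```latex
\begin{aligned}
T[u] &= \texttt{Taylor}\ u(x)\ T[u'] \\
\texttt{variable}\ x &= \texttt{Taylor}\ x\ T[1] = T[\mathrm{id}] \\
T[u] \cdot T[v] &= \texttt{Taylor}\ \bigl(u(x)\,v(x)\bigr)\ \bigl(T[u'] \cdot T[v] + T[u] \cdot T[v']\bigr) = T[u\,v]
\end{aligned}
```

The last step is coinductive: assuming sums and products of such towers are again towers one level down, the differential part is T[u′v + uv′] = T[(uv)′]. The other instances follow the same pattern via the chain rule, so the n-th element of `series` is the n-th derivative, and laziness means only the elements we inspect are ever computed.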
In the next post I’ll write about calculating gradients (i.e., multivariate derivatives) and doing optimisation via gradient descent, based on Taylor numbers.
Exercise. When we’re working with Taylor numbers, the derivative is very often zero. To handle this efficiently, we can redefine Taylor numbers as follows:
```haskell
data Taylor t = Taylor t (Maybe (Taylor t))
```
where a differential part of Nothing means the derivative is zero. Redefine the rest of the module to take this into account. What simplifying assumptions can you make for extend and extend2?
* But note that “Taylor number” is already taken in a different context.