Haskell optimisations through continuation passing style and defunctionalisation
(This post is also a literate Haskell program.)
Hello everyone :) This will be a short walkthrough of an interesting technique I saw in the Principles of Programming Languages course at Oxford (in the teaching materials by Prof. Mike Spivey). It’s not originally mine at all, but I like it a lot, so I wanted to write something about it.
Let’s start with an easy definition for reverse.
> import Prelude hiding ((++), reverse, sum)
>
> reverse1 :: [a] -> [a]
> reverse1 [] = []
> reverse1 (x:xs) = reverse1 xs ++ [x]
For completeness I’ll also define (++), which stands for concatenation.
> (++) :: [a] -> [a] -> [a]
> [] ++ ys = ys
> (x:xs) ++ ys = x:(xs ++ ys)
The definition from before is unfortunately quadratic, since each call to (++) takes time proportional to the length of its first argument; we’d like to make it linear. The first step in this approach to optimisation is to write it in continuation passing style. Briefly, this means making every recursive call “the last thing we do in a function call”, by passing a “continuation”: a function that tells the recursive call “what to do next”. This gives us the following definition for reverse, which also takes a continuation.
> reversek :: ([a] -> b) -> [a] -> b
> reversek k [] = k []
> reversek k (x:xs) = reversek (\rs -> k (rs ++ [x])) xs
>
> reverse2 = reversek id
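To see what the continuation accumulates, we can unfold reversek id on a small input. The block below is a standalone sketch (it reproduces reversek and uses the Prelude’s (++), so it is not part of the literate program above):

```haskell
reversek :: ([a] -> b) -> [a] -> b
reversek k []     = k []
reversek k (x:xs) = reversek (\rs -> k (rs ++ [x])) xs

-- Unfolding reversek id [1,2,3]:
--   reversek id [1,2,3]
-- = reversek (\rs -> id (rs ++ [1])) [2,3]
-- = reversek (\rs -> (rs ++ [2]) ++ [1]) [3]
-- = reversek (\rs -> ((rs ++ [3]) ++ [2]) ++ [1]) []
-- = (([] ++ [3]) ++ [2]) ++ [1]
-- = [3,2,1]
traced :: [Int]
traced = reversek id [1, 2, 3]
```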
First convince yourself that reverse1 and reverse2 are equivalent. Now, the next step will be to defunctionalise. We look at all the continuations used (viz. id and \rs -> k (rs ++ [x]), whose free variables are x and k), and create a type that represents these continuations, which we call K.
> data K a = Id | Lam a (K a)
(Note that in a Lam x k value, the x is the free x in \rs -> k (rs ++ [x]), and the k is the previous continuation.) Now we create a function that maps a K a value to an actual function.
> apply :: K a -> [a] -> [a]
> apply Id xs = xs
> apply (Lam x k) rs = apply k (rs ++ [x])
And now we rewrite reverse again, with this type.
> reversek2 :: K a -> [a] -> [a]
> reversek2 k [] = apply k []
> reversek2 k (x:xs) = reversek2 (Lam x k) xs
>
> reverse3 = reversek2 Id
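It may help to see the K value that reversek2 builds before apply unwinds it. This standalone sketch reproduces K, apply, and reversek2 (with derived Eq and Show instances purely for inspection, which the post’s own type does not need):

```haskell
data K a = Id | Lam a (K a) deriving (Eq, Show)

apply :: K a -> [a] -> [a]
apply Id        xs = xs
apply (Lam x k) rs = apply k (rs ++ [x])

reversek2 :: K a -> [a] -> [a]
reversek2 k []     = apply k []
reversek2 k (x:xs) = reversek2 (Lam x k) xs

-- reversek2 Id [1,2,3] pushes each element onto the continuation:
--   reversek2 (Lam 1 Id) [2,3]
-- = reversek2 (Lam 2 (Lam 1 Id)) [3]
-- = reversek2 (Lam 3 (Lam 2 (Lam 1 Id))) []
-- and apply then unwinds it:
--   apply (Lam 3 (Lam 2 (Lam 1 Id))) []  ==  [3,2,1]
```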
Now the crucial observations. First, we note that a K a is either an Id, or an a together with another K a. Just like the list type! These two types are isomorphic, with a value like Lam 1 (Lam 2 (Lam 3 Id)) corresponding to [1, 2, 3]. So we rewrite apply and reverse again, using this observation.
> apply2 :: [a] -> [a] -> [a]
> apply2 [] xs = xs
> apply2 (x:k) rs = apply2 k (rs ++ [x])
>
> reversek3 :: [a] -> [a] -> [a]
> reversek3 k [] = apply2 k []
> reversek3 k (x:xs) = reversek3 (x:k) xs
>
> reverse4 = reversek3 []
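The isomorphism between K a and [a] used above can also be made explicit with a pair of conversion functions. The names toListK and fromListK are mine, and the block is a standalone sketch rather than part of the derivation:

```haskell
data K a = Id | Lam a (K a)

-- One direction of the isomorphism: Lam becomes (:), Id becomes [].
toListK :: K a -> [a]
toListK Id        = []
toListK (Lam x k) = x : toListK k

-- The other direction: (:) becomes Lam, [] becomes Id.
fromListK :: [a] -> K a
fromListK []     = Id
fromListK (x:xs) = Lam x (fromListK xs)
```

Under this correspondence Lam 1 (Lam 2 (Lam 3 Id)) maps to [1, 2, 3], and the two functions are mutually inverse.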
Second, let’s investigate apply2. What happens if we reverse the order of its arguments, to get an rapply function? (We’ll also write the recursive call to rapply infix.)
> rapply :: [a] -> [a] -> [a]
> rapply xs [] = xs
> rapply rs (x:k) = (rs ++ [x]) `rapply` k
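Unfolding a small call suggests what rapply is computing. This standalone sketch reproduces rapply (with a type signature added) and uses the Prelude’s (++):

```haskell
rapply :: [a] -> [a] -> [a]
rapply xs []    = xs
rapply rs (x:k) = (rs ++ [x]) `rapply` k

-- Unfolding rapply [0] [1,2,3]:
--   ([0] ++ [1]) `rapply` [2,3]
-- = ([0,1] ++ [2]) `rapply` [3]
-- = ([0,1,2] ++ [3]) `rapply` []
-- = [0,1,2,3]
-- i.e. the same result as [0] ++ [1,2,3].
```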
What we notice is that rapply = (++)! (This is also not difficult to prove inductively.) This implies that apply2 = flip (++), or equivalently apply2 xs ys = ys ++ xs. So let’s use this in a fourth version of reversek.
> reversek4 :: [a] -> [a] -> [a]
> reversek4 k [] = [] ++ k
> reversek4 k (x:xs) = reversek4 (x:k) xs
>
> reverse5 = reversek4 []
Now all we need to do is note that [] ++ k = k (since [] is an identity element w.r.t. (++)). This lets us write the following definition.
> reversek5 :: [a] -> [a] -> [a]
> reversek5 k [] = k
> reversek5 k (x:xs) = reversek5 (x:k) xs
>
> reverse6 = reversek5 []
If we were to rewrite this as a single function, and rename reversek5, we get precisely the linear-time library definition of reverse.
> reverse :: [a] -> [a]
> reverse = go [] where
> go xs [] = xs
> go xs (y:ys) = go (y:xs) ys
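As a quick sanity check, the accumulator version can be compared against the naive quadratic one. The block below is standalone, with fresh names so it does not clash with the literate definitions above:

```haskell
-- The naive quadratic definition from the start of the post.
reverseNaive :: [a] -> [a]
reverseNaive []     = []
reverseNaive (x:xs) = reverseNaive xs ++ [x]

-- The linear accumulator version we derived.
reverseAcc :: [a] -> [a]
reverseAcc = go []
  where
    go xs []     = xs
    go xs (y:ys) = go (y:xs) ys
```

Both agree on every input; only their running time differs.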
This whole process brings to the forefront the fact that the extra parameter given to go is in some sense a representation of a continuation. I also rather like that the only part of the process which was not entirely mechanical was the observation that apply2 xs [] = xs.