commit 861aa4d (2022-02-25 15:15:26 -0500) Torsten Scholak: configure doctests
Posted on Jan 20, 2022
33 min read
Have you ever wanted to write a recursive function and wondered what would happen if someone took recursion away from Haskell? Say goodbye to recursive function calls, say goodbye to recursive data types. How sad Haskell would be without them! I'm sure that thought must have occurred to you. If not, what are you even doing here?! Well, this article has you covered should that day ever come. After reading it, you will know how to write a recursive function that doesn't recurse.
For this Literate Haskell essay, a few language extensions are required. Nothing extraordinarily fancy, just the usually fancy Haskell flavour.
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE LambdaCase #-}
A bunch of these are part of GHC2021, which was introduced in GHC 9.2, but this blog is boring and still running on GHC 8.10.
Unsurprisingly, we will also work with code that is not in the Prelude:
module Unrecurse where
import Control.Lens (zoom, _1, _2)
import Control.Monad.Cont (ContT (runContT))
import Control.Monad.State (MonadState (get, put), StateT (runStateT), evalStateT, modify)
import Control.Monad.Trans (MonadTrans (lift))
import Control.Monad.Writer (Writer, execWriter, tell)
import Data.Monoid (Sum (..))
import GHC.Generics (Generic)
import Prelude hiding (even, odd)
Now that this is settled, we can get on with the article.
The first part of this article is based on a 2019 post and talk by James Koppel about the equivalence of recursion and iteration.
James' talk is a great introduction to the topic, but it left out a few important details and did not explain how to generalize the example from the talk to more complicated cases. Also, the original talk used Java and only a little bit of Haskell, but I'm going to exclusively use Haskell here. This will make it clearer for people familiar with Haskell but not with Java, like me.
James' example is a recursive function that prints the content of a binary tree. The problem he chose to focus on is to convert that recursive function to an iterative one. We will reproduce the conversion process in Haskell code, and the first step is going to be to define an algebraic data type for binary trees:
data Tree a
  = Nil
  | Node
      { left :: Tree a,
        content :: a,
        right :: Tree a
      }
  deriving stock (Eq, Show, Functor, Foldable, Generic)
For illustration purposes, we will use the following balanced tree that carries consecutive integers at its seven nodes:
exampleTree :: Tree Int
exampleTree =
  Node
    ( Node
        (Node Nil 1 Nil)
        2
        (Node Nil 3 Nil)
    )
    4
    ( Node
        (Node Nil 5 Nil)
        6
        (Node Nil 7 Nil)
    )
Really, any tree will do, but we will use this one. We can print the contents of our tree using the Show instance of integers.
-- | Print the content of a `Tree a`.
-- >>> printTree exampleTree
-- 1
-- 2
-- 3
-- 4
-- 5
-- 6
-- 7
printTree :: forall a. Show a => Tree a -> IO ()
printTree Nil = pure ()
printTree Node {..} = do
  printTree left
  print content
  printTree right
This is the function we want to convert to an iterative one. In its current form, it contains two recursive calls to itself. To eliminate the recursive calls, we need to perform a number of transformations. Each transformation is potentially headache-inducing, and I am not going to promise pretty code. In fact, the code is going to get worse and worse.
With your expectations properly lowered, let's start with the first transformation. We need to rewrite the printTree function to use continuations, in what is called "continuation-passing style" or CPS. I'm not going to explain CPS in all its glory. In short, it is a style of programming in which functions do not return values but instead pass control to a function called a continuation that decides what to do next. The mind-bending implications of this style of programming are introduced and discussed in this article on wikibooks.org. Have a look at that article if you are interested in the fine details. It's not necessary to understand the details of CPS to understand the rest of this article, though.
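If CPS is entirely new to you, a tiny sketch may still help. The functions below are made up for this illustration and are not used anywhere else in the article; they only contrast a function that returns its result with one that hands the result to a continuation.

```haskell
-- Direct style: the function returns its result to the caller.
addDirect :: Int -> Int -> Int
addDirect x y = x + y

-- CPS: instead of returning, the function hands its result to a
-- continuation `k` that decides what happens next.
addCps :: Int -> Int -> (Int -> r) -> r
addCps x y k = k (x + y)

squareCps :: Int -> (Int -> r) -> r
squareCps x k = k (x * x)

-- Chaining CPS functions means nesting continuations:
-- first add, then square the sum, then pass the result on to `k`.
squareOfSumCps :: Int -> Int -> (Int -> r) -> r
squareOfSumCps x y k = addCps x y (\s -> squareCps s k)

-- >>> squareOfSumCps 1 2 id
-- 9
```

Passing id as the last continuation simply hands the final value back, which is how a CPS computation gets "run".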
Haskell is particularly good at handling CPS; rewriting to it is easy and mechanical. The tool we will use is called the ContT monad transformer. It is found in the transformers package. ContT will wrap the IO monad, and we will use the >>= operator to chain continuations. IO values need to be lifted to ContT values.
This gives us:
printTree' ::
  forall a r.
  Show a =>
  Tree a ->
  ContT r IO ()
printTree' Nil = pure ()
printTree' Node {..} = do
  printTree' left
  lift $ print content
  printTree' right
The code appears mostly the same, except for a few subtleties. We have changed the return type and added a new type variable, r, that we cannot touch since it is universally quantified. It is the type of the result of the continuation. r will only be of interest when we run the ContT monad transformer. This is done using the runContT :: ContT r m a -> (a -> m r) -> m r function. It runs the CPS computation encoded in ContT r m a and gets the result, but only if we seed it with one final continuation of the form a -> m r:
-- | Run the `ContT` computation for `printTree'`.
-- >>> runPrintTree' exampleTree
-- 1
-- 2
-- 3
-- 4
-- 5
-- 6
-- 7
runPrintTree' :: Show a => Tree a -> IO ()
runPrintTree' tree =
  runContT (printTree' tree) pure
The final continuation is pure, and it decides that the result of the continuation is r ~ (). Everything still works as expected, phew.
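To see that it really is the final continuation that picks r, here is a small sketch using the ContT type from the transformers package. The greet computation is a made-up example; the same computation is run twice with different final continuations and therefore with different result types.

```haskell
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Cont (ContT (runContT))

-- A made-up CPS computation that performs one effect and yields 42.
-- It is polymorphic in `r`, just like `printTree'`.
greet :: ContT r IO Int
greet = do
  lift (putStrLn "hello")
  pure 42

-- Here the final continuation turns the Int into a String, so r ~ String:
-- >>> runContT greet (pure . show)
-- hello
-- "42"

-- Here it keeps an Int, so r ~ Int:
-- >>> runContT greet (\n -> pure (n + 1))
-- hello
-- 43
```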
The ContT monad transformer is a great convenience and allows us to deal with continuations in the familiar monadic style. What's great for Haskell and its users is not so great for us in this article, though. Remember, we want less of that good, idiomatic Haskell, not more of it. The goodbye to recursion must hurt. ContT is very effective at hiding what is actually happening behind the scenes. We need to pull that curtain up and see what is going on. Below is the same code as before, but with the ContT monad transformer inlined into the printTree' function:
printTree'' ::
  forall a r.
  Show a =>
  Tree a ->
  (() -> IO r) ->
  IO r
printTree'' Nil = \c -> c ()
printTree'' Node {..} =
  let first = \c -> printTree'' left c
      second = \c -> print content >>= c
      third = \c -> printTree'' right c
      inner = \c -> second (\x -> (\() -> third) x c)
      outer = \c -> first (\x -> (\() -> inner) x c)
   in outer
This is starting to look nasty. Don't look at me, I warned you. printTree'' is now a higher-order function that takes a continuation function, c, as its second argument. Still, we can now clearly see that c takes a value of type () and returns a value of type IO r. The () type of the argument is a consequence of the fact that the original printTree function returned a value of type IO ().
Let's take a look at the Node case of printTree''. The do notation is desugared into nested continuations. inner chains the second and third functions, and outer chains the first and inner functions. first happens first, then second, and then third. We can convince ourselves that this painfully obfuscated printTree'' function is still computing the same result as the printTree function.
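If you are wondering where shapes like \c -> first (\x -> ... x c) come from, the following sketch may help. It defines a hand-rolled stand-in for ContT, here called K, together with hypothetical helpers pureK, bindK, and liftK. This is a simplified illustration written for this article, not the actual source of the transformers package, but the bodies have the same shapes as the lambdas in printTree''.

```haskell
-- A hand-rolled stand-in for `ContT`, named `K` to avoid confusion
-- with the real type from the transformers package.
newtype K r m a = K {runK :: (a -> m r) -> m r}

-- Inject a plain value: pass it straight to the continuation.
-- This is the shape of the `Nil` case, `\c -> c ()`.
pureK :: a -> K r m a
pureK a = K (\c -> c a)

-- Sequence two steps: run `m`, feed its result `x` to `k`, and run
-- whatever `k x` produces with the outer continuation `c`.
-- This is the shape of `inner` and `outer`.
bindK :: K r m a -> (a -> K r m b) -> K r m b
bindK m k = K (\c -> runK m (\x -> runK (k x) c))

-- Lift an action of the underlying monad.
-- This is the shape of `second`, `\c -> print content >>= c`.
liftK :: Monad m => m a -> K r m a
liftK m = K (\c -> m >>= c)

-- Two lifted steps chained with `bindK`, mirroring a two-line `do` block.
twoSteps :: K r IO ()
twoSteps = liftK (putStrLn "first") `bindK` \() -> liftK (putStrLn "second")

-- >>> runK twoSteps pure
-- first
-- second
```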
-- | Run the CPS computation for `printTree''`.
-- >>> runPrintTree'' exampleTree
-- 1
-- 2
-- 3
-- 4
-- 5
-- 6
-- 7
runPrintTree'' :: Show a => Tree a -> IO ()
runPrintTree'' tree =
  printTree'' tree (\() -> pure ())
What we have achieved so far is nice, but we still have not eliminated any recursive calls. In order to make progress, we need to convert the higher-order printTree'' function to a first-order function. This process is called "defunctionalization". Once again, there is a wiki article on the subject if you are interested. It's one of those rare and wondrous articles on Wikipedia that exclusively use Haskell to explain things. Not that there should be more of them, but I digress.
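Before we apply this to printTree'', here is defunctionalization in miniature. The example is made up for illustration: mapHO takes an arbitrary function, whereas mapFO takes a plain data value of the hypothetical type Fun, and a first-order interpreter, applyFun, gives each constructor its meaning.

```haskell
-- Higher-order version: the caller passes in an arbitrary function.
mapHO :: (Int -> Int) -> [Int] -> [Int]
mapHO _ [] = []
mapHO f (x : xs) = f x : mapHO f xs

-- Defunctionalized version: each function we actually need becomes a
-- constructor that records the function's free variables.
data Fun
  = AddN Int -- stands for \x -> x + n
  | Double -- stands for \x -> x * 2

-- A single first-order interpreter gives the constructors their meaning.
applyFun :: Fun -> Int -> Int
applyFun (AddN n) x = x + n
applyFun Double x = x * 2

mapFO :: Fun -> [Int] -> [Int]
mapFO _ [] = []
mapFO f (x : xs) = applyFun f x : mapFO f xs

-- >>> mapHO (\x -> x + 3) [1, 2, 3] == mapFO (AddN 3) [1, 2, 3]
-- True
```

The price is that the set of supported functions is now closed: only what Fun can describe can be passed in. For our continuations that is fine, since we know exactly which ones occur.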
Concretely, we defunctionalize the printTree'' function by replacing all the continuations, c :: () -> IO r, with a value of a new data type Kont (Next a). Kont and Next look like this:
data Kont next
  = Finished
  | More next (Kont next)
data Next a
  = First (Tree a)
  | Second a
  | Third (Tree a)
Kont is a recursive data type with two constructors, Finished and More. We use Finished to terminate the computation, and More to indicate that we need to continue. When that happens, we use Next to describe the details of the next step in the computation. Each constructor of Next corresponds to a different action that needs to be taken. The First constructor is named after the first function in printTree''. It takes a left subtree as argument. The Second constructor is named after the second function in printTree''. Its argument is the content of the current node that needs to be printed. The Third constructor is named after the third function in printTree'', and it takes a right subtree as argument. We need a new function, apply, that interprets a Kont (Next a) value and executes the corresponding action:
apply ::
  forall a.
  Show a =>
  Kont (Next a) ->
  IO ()
apply (More (First left) c) =
  printTree''' left c
apply (More (Second content) c) =
  print content >> apply c
apply (More (Third right) c) =
  printTree''' right c
apply Finished = pure ()
We can see here how different Next values correspond to different actions. When the computation is finished, apply returns (). When there is more work to do, apply either calls the yet-to-be-defined printTree''' function with the next subtree and the continuation value, c :: Kont (Next a), or it calls the print function on the content of the node and then itself to continue the computation. The purpose of the printTree''' function is to build continuation values and then call apply to execute the corresponding actions:
printTree''' ::
  forall a.
  Show a =>
  Tree a ->
  Kont (Next a) ->
  IO ()
printTree''' Nil c = apply c
printTree''' Node {..} c =
  apply
    ( More
        (First left)
        ( More
            (Second content)
            (More (Third right) c)
        )
    )
How do we know that this is all working correctly? We can run it:
-- | Run the defunctionalized CPS computation
-- for `printTree'''`.
-- >>> runPrintTree''' exampleTree
-- 1
-- 2
-- 3
-- 4
-- 5
-- 6
-- 7
runPrintTree''' :: Show a => Tree a -> IO ()
runPrintTree''' tree =
  printTree''' tree Finished
Great, we have successfully defunctionalized printTree'' and turned it into a first-order function. This has been a major step, but we are not done yet. We still have to eliminate the recursive calls. It only appears as if printTree''' doesn't call itself anymore. In fact, it still does, just indirectly through mutual recursion. This becomes more apparent once we inline apply into printTree''':
printTree'''' ::
  forall a.
  Show a =>
  Tree a ->
  Kont (Next a) ->
  IO ()
printTree'''' Nil (More (First left) c) =
  printTree'''' left c
printTree'''' Nil (More (Second content) c) =
  print content >> printTree'''' Nil c
printTree'''' Nil (More (Third right) c) =
  printTree'''' right c
printTree'''' Nil Finished = pure ()
printTree'''' Node {..} c =
  printTree''''
    Nil
    ( More
        (First left)
        ( More
            (Second content)
            (More (Third right) c)
        )
    )
At first glance, printTree'''' doesn't look better than what we had before. We went from two recursive calls to four. There is something interesting going on here, though. The recursive call to printTree'''' always appears in tail position, which means that we should be able to replace the calls with a loop. Wikipedia also has an article on that. However, we also notice that printTree'''' is called with different arguments in all four cases. We can't replace these calls with a loop without removing those arguments first, and we can't remove the arguments until we have a way to keep track of them. Or can we? Haskell has a type, called State, that allows for just that. On Hackage, State is available in Control.Monad.Trans.State.Lazy.
Rather than passing arguments, we are going to update the State with the arguments' values. In preparation for this, let us have a closer look at Kont. You may have already noticed, but our Kont is isomorphic to []: Finished is the empty list, and More is simply the list constructor. Let's acknowledge that fact by using the [] data type directly and by giving Kont a better name: Stack.
type Stack a = [a]
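If you want to see the isomorphism spelled out, the sketch below converts back and forth. The helper names kontToList and listToKont are hypothetical, and Kont is repeated so that the block compiles on its own.

```haskell
-- Copy of the article's continuation type.
data Kont next
  = Finished
  | More next (Kont next)

-- `Finished` maps to the empty list, `More` to the list constructor.
kontToList :: Kont next -> [next]
kontToList Finished = []
kontToList (More x k) = x : kontToList k

-- The inverse direction: replace `(:)` by `More` and `[]` by `Finished`.
listToKont :: [next] -> Kont next
listToKont = foldr More Finished

-- >>> kontToList (More 1 (More 2 (More 3 Finished)))
-- [1,2,3]
-- >>> kontToList (listToKont "abc")
-- "abc"
```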
The only operations on Stack we will need in the following are adding and removing single elements at its top, which is the head of the list. These operations are commonly called push and pop. They can be implemented as follows:
push ::
  forall a m.
  MonadState (Stack a) m =>
  a ->
  m ()
push x = modify (x :)
pop ::
  forall a m.
  MonadState (Stack a) m =>
  m (Maybe a)
pop =
  get >>= \case
    [] -> pure Nothing
    (x : xs) -> put xs >> pure (Just x)
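As a quick sanity check of these two helpers, here is a small sketch that runs them in a plain State monad from mtl. The demo computation is made up for illustration; push and pop are repeated, with shortened signatures, so that the block compiles on its own.

```haskell
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}

import Control.Monad.State (MonadState (get, put), State, modify, runState)

type Stack a = [a]

-- `push` and `pop` as defined in the article.
push :: MonadState (Stack a) m => a -> m ()
push x = modify (x :)

pop :: MonadState (Stack a) m => m (Maybe a)
pop =
  get >>= \case
    [] -> pure Nothing
    (x : xs) -> put xs >> pure (Just x)

-- Hypothetical demo: push three characters, then pop four times.
-- Elements come back in reverse order (last in, first out), and the
-- fourth `pop` finds the stack empty and yields `Nothing`.
demo :: State (Stack Char) [Maybe Char]
demo = do
  push 'a'
  push 'b'
  push 'c'
  sequence [pop, pop, pop, pop]

-- >>> runState demo []
-- ([Just 'c',Just 'b',Just 'a',Nothing],"")
```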
Notice how push and pop use get, put, and modify to update the Stack state. With push and pop, printTree'''' becomes:
printTree''''' ::
  forall a.
  Show a =>
  Tree a ->
  StateT (Stack (Next a)) IO ()
printTree''''' Nil = do
  next <- pop
  case next of
    Just (First left) ->
      printTree''''' left
    Just (Second content) ->
      lift (print content)
        >> printTree''''' Nil
    Just (Third right) ->
      printTree''''' right
    Nothing -> pure ()
printTree''''' Node {..} = do
  push (Third right)
  push (Second content)
  push (First left)
  printTree''''' Nil
One thing to note here is that we use StateT, not State, to represent the state of our computation. This is because StateT is a monad transformer, and we already have effects in IO that we need to lift into it.
Rather than building a value of type Stack (Next a) that we pass to printTree'''', in printTree''''', we use push and pop to update the Stack state. As always, we need to confirm that this is doing the right thing:
-- | Run the defunctionalized CPS computation
-- for `printTree'''''`.
-- >>> runPrintTree''''' exampleTree
-- 1
-- 2
-- 3
-- 4
-- 5
-- 6
-- 7
runPrintTree''''' :: Show a => Tree a -> IO ()
runPrintTree''''' tree =
  evalStateT (printTree''''' tree) []
The result is as expected. This takes care of the Kont (Next a) argument, but we still need to eliminate the Tree a argument. We can use the same trick again and add Tree a to the state:
printTree'''''' ::
  forall a.
  Show a =>
  StateT (Tree a, Stack (Next a)) IO ()
printTree'''''' = do
  tree <- zoom _1 get
  case tree of
    Nil -> do
      c <- zoom _2 pop
      case c of
        Just (First left) -> do
          zoom _1 $ put left
          printTree''''''
        Just (Second content) -> do
          lift (print content)
          zoom _1 $ put Nil
          printTree''''''
        Just (Third right) -> do
          zoom _1 $ put right
          printTree''''''
        Nothing -> pure ()
    Node {..} -> do
      zoom _2 $ push (Third right)
      zoom _2 $ push (Second content)
      zoom _2 $ push (First left)
      zoom _1 $ put Nil
      printTree''''''
In this new version, every recursive call to printTree'''''' appears in tail position and takes no arguments. We are now ready for the final magic step, which is to eliminate all recursive calls. Let us introduce a little helper, while:
data Continue = Continue | Break
while ::
  forall m.
  Monad m =>
  m Continue ->
  m ()
while m =
  m >>= \case
    Continue -> while m
    Break -> pure ()
This function just runs the given monadic computation, m :: m Continue, again and again until it returns Break. This is as close to a "while" loop as we can get in Haskell; we can't remove the recursion here. Let us pretend the hypothetical recursion-free Haskell has while as a primitive.
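To get a feel for while in isolation, here is a small sketch that sums the numbers from 1 to n the way an imperative loop would, with two mutable variables. The sumTo function and its use of IORef are assumptions made for this example; Continue and while are repeated from above so the block compiles on its own.

```haskell
{-# LANGUAGE LambdaCase #-}

import Data.IORef (modifyIORef', newIORef, readIORef)

data Continue = Continue | Break

-- The article's `while` helper.
while :: Monad m => m Continue -> m ()
while m =
  m >>= \case
    Continue -> while m
    Break -> pure ()

-- Sum the numbers from 1 to n with a mutable counter and accumulator.
sumTo :: Int -> IO Int
sumTo n = do
  counter <- newIORef 1
  total <- newIORef 0
  while $ do
    i <- readIORef counter
    if i > n
      then pure Break -- loop condition failed, leave the loop
      else do
        modifyIORef' total (+ i)
        modifyIORef' counter (+ 1)
        pure Continue -- go around again
  readIORef total

-- >>> sumTo 10
-- 55
```

The loop body never calls itself. All the looping is delegated to while, which is exactly the property we are after.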
With this, we finally have:
printTree''''''' ::
  forall a.
  Show a =>
  StateT (Tree a, Stack (Next a)) IO ()
printTree''''''' =
  while $ do
    tree <- zoom _1 get
    case tree of
      Nil -> do
        c <- zoom _2 pop
        case c of
          Just (First left) -> do
            zoom _1 $ put left
            pure Continue
          Just (Second content) -> do
            lift (print content)
            zoom _1 $ put Nil
            pure Continue
          Just (Third right) -> do
            zoom _1 $ put right
            pure Continue
          Nothing -> pure Break
      Node {..} -> do
        zoom _2 $ push (Third right)
        zoom _2 $ push (Second content)
        zoom _2 $ push (First left)
        zoom _1 $ put Nil
        pure Continue
There are no recursive calls left. Well, almost, because we are still using the recursive Tree type. We will worry about that in a follow-up piece.
This marvel of a function still computes the same result:
-- | Run the unrolled `printTree'''''''` program.
-- >>> runPrintTree''''''' exampleTree
-- 1
-- 2
-- 3
-- 4
-- 5
-- 6
-- 7
runPrintTree''''''' :: Show a => Tree a -> IO ()
runPrintTree''''''' tree =
  evalStateT printTree''''''' (tree, [])
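As an aside, the loop body can also be read as the transition function of a small state machine. The sketch below is a simplified variant written for illustration, not the code from above: it uses a positional Tree instead of records, avoids lens and mtl, and collects the contents in a list instead of printing them. The names Machine, step, run, and inorder are made up.

```haskell
data Tree a = Nil | Node (Tree a) a (Tree a)

data Next a = First (Tree a) | Second a | Third (Tree a)

-- Machine state: current tree, stack of pending work, and the output
-- collected so far (most recent element first).
type Machine a = (Tree a, [Next a], [a])

-- One loop iteration as a pure function. `Nothing` plays the role of `Break`.
step :: Machine a -> Maybe (Machine a)
step (Node l x r, stack, out) = Just (Nil, First l : Second x : Third r : stack, out)
step (Nil, First l : stack, out) = Just (l, stack, out)
step (Nil, Second x : stack, out) = Just (Nil, stack, x : out)
step (Nil, Third r : stack, out) = Just (r, stack, out)
step (Nil, [], _) = Nothing

-- The driver. Like `while`, it stands in for a loop primitive and is
-- the only place where recursion is left.
run :: Machine a -> [a]
run m@(_, _, out) = maybe (reverse out) run (step m)

inorder :: Tree a -> [a]
inorder t = run (t, [], [])

-- >>> inorder (Node (Node Nil 1 Nil) 2 (Node Nil 3 Nil))
-- [1,2,3]
```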
Fantastic! What's next?
The printTree example from James' 2019 talk is a bit too simple, because we run printTree only for its effects. We don't care about the result since it is just () and will be the same for all trees we pass to printTree. What if this was different? In many real applications, we might want to reduce the tree to a value, say, accumulate a result, like a sum. This common pattern is captured by the function below:
accumTree :: forall a. Monoid a => Tree a -> a
accumTree Nil = mempty
accumTree Node {..} =
  let lacc = accumTree left
      racc = accumTree right
   in lacc <> content <> racc
Using accumTree, summing up all the content values of our example tree is a simple matter:
-- | Sum the content values of a tree.
-- >>> sumTree exampleTree
-- 28
sumTree :: Num a => Tree a -> a
sumTree tree = getSum $ accumTree (Sum <$> tree)
Sum is a newtype wrapper whose only function is to select the Monoid instance for summation, where mempty is zero and mappend is addition. This is necessary since Haskell doesn't have named class instances.
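To illustrate why the wrapper matters, here is a short sketch that combines the same numbers under two different monoids from Data.Monoid. The function names totalOf and productOf are made up for this example.

```haskell
import Data.Monoid (Product (..), Sum (..))

-- Wrap every number in `Sum`, combine with `<>`, and unwrap again.
totalOf :: [Int] -> Int
totalOf xs = getSum (mconcat (map Sum xs))

-- Same numbers, different wrapper, different monoid.
productOf :: [Int] -> Int
productOf xs = getProduct (mconcat (map Product xs))

-- >>> totalOf [1, 2, 3, 4]
-- 10
-- >>> productOf [1, 2, 3, 4]
-- 24

-- The empty list yields each monoid's `mempty`:
-- >>> totalOf []
-- 0
-- >>> productOf []
-- 1
```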
By the way, since Tree has a Foldable instance, we could have achieved the same by using foldMap from Data.Foldable:
-- | Sum the content values of a tree.
-- >>> sumTree' exampleTree
-- 28
sumTree' :: (Num a, Foldable f) => f a -> a
sumTree' tree = getSum $ foldMap Sum tree
This certainly is convenient. Haskellers love foldMap, because they can use it on any data type that has a Foldable instance, not just Tree. But this is supposed to be a tutorial about inconvenient, recursion-free Haskell, and for that we need to understand how we can remove recursion from accumTree. The tool we are going to use here is the Writer monad.
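For readers who have not met Writer before, here is a minimal sketch using the mtl package. The steps computation is a made-up example: every tell appends to the accumulated output with <>, runWriter returns both the result and the output, and execWriter returns only the output.

```haskell
import Control.Monad.Writer (Writer, execWriter, runWriter, tell)

-- Log each step of a small computation in a list of strings.
steps :: Writer [String] Int
steps = do
  tell ["start"]
  let x = 20 + 1
  tell ["computed " <> show x]
  tell ["done"]
  pure (x * 2)

-- >>> runWriter steps
-- (42,["start","computed 21","done"])
-- >>> execWriter steps
-- ["start","computed 21","done"]
```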
accumTree'' ::
  forall a.
  Monoid a =>
  Tree a ->
  Writer a ()
accumTree'' Nil = pure ()
accumTree'' Node {..} = do
  accumTree'' left
  tell content
  accumTree'' right
This version of accumTree is a bit more complicated. Accumulation has been moved to the Writer monad, and we need to use tell to accumulate the result. We also are working with a monadic effect, Writer, whereas accumTree was pure. We can extract the result by using execWriter:
-- | Sum the content values of a tree.
-- >>> sumTree'' exampleTree
-- 28
sumTree'' :: Num a => Tree a -> a
sumTree'' tree =
  getSum $ execWriter $ accumTree'' (Sum <$> tree)
Great, this works. Now, sharp eyes will notice that accumTree'' and printTree from before are structurally equivalent. The only difference is that printTree uses IO, and accumTree'' uses Writer a. That's all. Based on what we have seen so far, we can therefore immediately see how accumTree'' can be made iterative:
accumTree''''''' ::
  forall a.
  Monoid a =>
  StateT (Tree a, Stack (Next a)) (Writer a) ()
accumTree''''''' =
  while $ do
    tree <- zoom _1 get
    case tree of
      Nil -> do
        c <- zoom _2 pop
        case c of
          Just (First left) -> do
            zoom _1 $ put left
            pure Continue
          Just (Second content) -> do
            lift (tell content)
            zoom _1 $ put Nil
            pure Continue
          Just (Third right) -> do
            zoom _1 $ put right
            pure Continue
          Nothing -> pure Break
      Node {..} -> do
        zoom _2 $ push (Third right)
        zoom _2 $ push (Second content)
        zoom _2 $ push (First left)
        zoom _1 $ put Nil
        pure Continue
Does this look familiar? This is exactly the same as printTree''''''' from above, except that print content is replaced by tell content. We run this function using its companion function sumTree''''''' that executes the Writer and State effects:
-- | Run the unrolled `accumTree'''''''` program
-- and calculate the sum of the content values of the tree.
-- >>> sumTree''''''' exampleTree
-- 28
sumTree''''''' :: Num a => Tree a -> a
sumTree''''''' tree =
  getSum $ execWriter $ runStateT accumTree''''''' (Sum <$> tree, [])
This settles the issue. Now you know how to write a recursive function that doesn't recurse.
Next time, we will take a closer look at the Tree type, and we will make sure it doesn't recurse either. This will be super ugly; I can't wait for that.