Demystifying Monads using Python
The subject of monads in functional programming is regarded by many as
a scary and difficult topic. Often it is presented in a very abstract and mathematical way, which
tends to have an eye-glazing effect on people who aren't hard-core
mathematicians. This is an attempt to convey the basic ideas behind
monads to someone who is a programmer rather than a mathematician. In
particular, it's aimed at someone who has at least a basic familiarity
with Python and the techniques of traditional, non-functional
programming.
First of all, what is a monad and why would you want one? In the
context of functional programming, a monad is a technique that allows a
stateful process to be expressed by composing stateless functions. Functional
programmers are interested in doing this because it opens up
possibilities for optimisations, and also because it enables dealing
with I/O – which involves an inherently stateful object, the outside
world – in a functional way.
As an example of a computation carried out using a stateful process,
I'm going to take the following task: Given a list of numbers, produce
another list that contains (in no particular order) all the unique
numbers in the input list.
A non-functional way to write this in Python would be:
def dedup(l):
    s = set()
    for x in l:
        s.add(x)  # mutates s in place
    return list(s)
However, you could never write that in any purely functional language,
because it makes in-place changes to data, and functional languages
don't provide any way to do that. We'd like another way to write it
that doesn't involve mutating data, and also doesn't assign to a given
variable more than once – another thing that functional languages don't
allow.
To do this, I'm going to build the program up out of functions, all of
which follow a particular pattern. The pattern is that they all take
two arguments: a state object, and another function that follows the
same pattern. (They may also take other arguments, but they must take
at least those two arguments.)
The function argument is called a continuation,
and it represents "the rest of the computation". The task of each of
these functions is to take the state object given to it, construct a new state object from it in some way, pass that new state to the continuation, and return whatever the continuation returns.
We can define the term "monad" in another way now: A monad is the
collection of all functions that follow this pattern, for some
particular type of state object. (This is not quite the full definition
of a monad, nor the most general one, but it will do for now.)
In our case, the state object is a
set, so I'm going to call this kind of function a "set manipulating
function", or SMF (which you can pronounce "smurf" if you feel so
inclined).
When creating a monad, the first thing you need to do is write some
functions for performing primitive operations on the state. For our
purposes here, we only need one operation, that of adding an item to
the set. So our first SMF will be:
def add(x, f, s):
    return f(s | {x})
Here, f is the continuation and s is the input state. The output state passed to the continuation is the input state with the element x added. Note that we build a new set instead of modifying the one we were given, so we are not breaking any functional rules.
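A quick way to convince yourself that add leaves its input alone is to call it with a trivial continuation that just hands back whatever state it receives. The identity lambda used as the continuation here is only an illustrative choice.

```python
def add(x, f, s):
    # Pass a *new* set (s plus x) to the continuation f; s itself is untouched.
    return f(s | {x})

original = {1, 2}
# Identity continuation: return the state add produced so we can inspect it.
result = add(3, lambda s2: s2, original)
print(sorted(original))  # [1, 2]
print(sorted(result))    # [1, 2, 3]
```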
Another thing we will want to be able to do is turn the set into a
list, because the problem specification said we were to produce the
result as a list, not a set. So let's define another SMF to do that:
def listify(f, s):
    return f(list(s), s)
Note that this function does not
return the list – instead, it passes it as an extra argument to the
continuation. This is a general characteristic of functions belonging
to a monad – if the function produces a result, it does so by passing
it to the continuation.
Also, because it doesn't need to change the state in any way, listify just passes the state it was given on to the continuation.
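To see both the result and the state that listify hands on, we can give it a continuation that simply returns the two of them as a tuple. That tuple-returning continuation is just for inspection; it is not part of the SMF pattern itself.

```python
def listify(f, s):
    # The result (a list) goes to the continuation as an extra argument,
    # followed by the unchanged state.
    return f(list(s), s)

r, s = listify(lambda result, state: (result, state), {7, 8})
print(sorted(r))     # [7, 8]
print(s == {7, 8})   # True
```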
Now we can express the dedup computation itself in the form of an SMF:
def dedup(l, f, s):
    if l:
        return add(l[0], lambda s2: dedup(l[1:], f, s2), s)
    else:
        return listify(f, s)
This may look rather convoluted, but it breaks down like this. If the list is not empty, we add the first element of the list to the set, giving a new set s2. We then recursively dedup the rest of the list, using s2 as the starting state.
If the list is empty, then we are finished, so we listify what we have and pass that to our continuation.
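Even without any wrapper, dedup can be driven by hand: give it an empty set as the initial state and a continuation that accepts the final result and state. The continuation below, which returns both as a tuple, is an illustrative choice, and the definitions are repeated so the snippet runs on its own.

```python
def add(x, f, s):
    return f(s | {x})

def listify(f, s):
    return f(list(s), s)

def dedup(l, f, s):
    if l:
        # Add the head, then continue with the tail using the new state s2.
        return add(l[0], lambda s2: dedup(l[1:], f, s2), s)
    else:
        # Nothing left: convert the accumulated set to a list.
        return listify(f, s)

result, final_state = dedup([3, 1, 3, 2, 1], lambda r, s: (r, s), set())
print(sorted(result))             # [1, 2, 3]
print(final_state == {1, 2, 3})   # True
```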
This is all well and good, and accomplishes the task, but it has a
rather inconvenient API – instead of getting the result returned to us, we
have to supply a function to which the result will be passed. Also, we
need to supply a starting state. We can make it easier to use by
writing a top-level wrapper:
def run_smf(f, a):
    return f(a, lambda r, s: r, set())
This function takes any SMF expecting a single additional argument a, and runs it with an empty set as the initial state. The continuation takes the final result r
and "turns it around" so that it passes back up through all the nested
continuation calls and emerges as the return value. We don't need the
final state any more, so we ignore it.
Let's try it out:
>>> run_smf(dedup, [5,4,3,6,8,2,3,6,8,4,5])
[2, 3, 4, 5, 6, 8]
Why bother?
At this point, you may be wondering why we went to all this trouble to
turn a simple, straightforward piece of code into something so
convoluted. We could have written something like this:
def func_dedup(l, s=set()):
    if l:
        return func_dedup(l[1:], s | {l[0]})
    else:
        return list(s)
and it would have been just as functionally pure.
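One detail worth checking in func_dedup is the set() default argument. Mutable defaults are a classic Python trap, but here the default set is never mutated, because s | {l[0]} always builds a new set, so repeated calls do not leak elements into each other. A short check:

```python
def func_dedup(l, s=set()):
    if l:
        # s | {...} creates a fresh set, so the shared default is never modified.
        return func_dedup(l[1:], s | {l[0]})
    else:
        return list(s)

first = func_dedup([1, 2, 2])
second = func_dedup([9])
print(sorted(first))            # [1, 2]
print(second)                   # [9]
print(func_dedup.__defaults__)  # (set(),)
```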
Well, if you look carefully at our SMFs, you'll notice a few things. First, none of them ever returns
a state object; states are only passed deeper down the call chain;
what's more, they're only passed as the final argument to a
continuation. Second, the only functions that do anything directly with
a state object are the primitive operations add and listify. Third, the only one that creates a new state object is add, and the only thing it does with it is pass it to its continuation.
As a consequence, as soon as a new state object has been created, the old one is never seen again. So we could replace the implementation of add with this:
def add(x, f, s):
    s.add(x)
    return f(s)
and as long as all our SMFs manipulate the state solely through the
primitive operations, nobody would ever notice the difference.
In other words, the implementation is free to use in-place mutations of
the state object – which obviously allows considerable gains in
efficiency, both in time and memory usage – without giving up any
functional purity.
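To check that the swap really is invisible, here is the complete SMF program with the mutating add plugged in; listify, dedup and run_smf are unchanged from their earlier definitions. Note that run_smf creates a fresh set() on every call, which is what keeps separate runs from sharing the mutated object.

```python
def add(x, f, s):
    # In-place version: mutate s, then hand the same object on.
    s.add(x)
    return f(s)

def listify(f, s):
    return f(list(s), s)

def dedup(l, f, s):
    if l:
        return add(l[0], lambda s2: dedup(l[1:], f, s2), s)
    else:
        return listify(f, s)

def run_smf(f, a):
    # A brand-new set per run, so mutations never leak between calls.
    return f(a, lambda r, s: r, set())

print(sorted(run_smf(dedup, [5, 4, 3, 6, 8, 2, 3, 6, 8, 4, 5])))  # [2, 3, 4, 5, 6, 8]
print(run_smf(dedup, [1]))  # [1]
```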
Also, it's worth pointing out that a major reason the SMF version
looks so horrible is that (a) we wrote it in Python instead of Haskell,
and (b) we didn't make use of currying. (All right, two major reasons...)
To illustrate what I mean, I'm going to make a series of
transformations to turn the code into something more like the way it
would be written in a Haskell-like language.
Lambdification
The first step is to rewrite all our SMFs using lambdas instead of
defs. This doesn't change anything much, but it makes the next couple
of steps easier.
add = lambda x, f, s: f(s | {x})
listify = lambda f, s: f(list(s), s)
dedup = lambda l, f, s: (
    add(l[0], lambda s2: dedup(l[1:], f, s2), s)
    if l else
    listify(f, s)
)
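Since the lambda versions take the same arguments in the same order as the def versions, the earlier run_smf works with them unchanged. A quick check, repeating the definitions so the snippet runs on its own:

```python
add = lambda x, f, s: f(s | {x})
listify = lambda f, s: f(list(s), s)
dedup = lambda l, f, s: (
    add(l[0], lambda s2: dedup(l[1:], f, s2), s)
    if l else
    listify(f, s)
)

def run_smf(f, a):
    # Unchanged from the def-based version.
    return f(a, lambda r, s: r, set())

print(sorted(run_smf(dedup, [5, 4, 3, 6, 8, 2, 3, 6, 8, 4, 5])))  # [2, 3, 4, 5, 6, 8]
```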
First currying
Next, we curry the s argument to our SMFs, and update the wrapper to match.
add = lambda x, f: lambda s: f(s | {x})
listify = lambda f: lambda s: f(list(s))(s)
dedup = lambda l, f: lambda s: (
    add(l[0], lambda s2: dedup(l[1:], f)(s2))(s)
    if l else
    listify(f)(s)
)
def run_smf(f, a):
    return f(a, lambda r: lambda s: r)(set())
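With s curried, an SMF given its other arguments returns a function still waiting for the state. The snippet below shows such a partially applied SMF (the name pending is just illustrative) and checks that the curried pieces still produce the same answer.

```python
add = lambda x, f: lambda s: f(s | {x})
listify = lambda f: lambda s: f(list(s))(s)
dedup = lambda l, f: lambda s: (
    add(l[0], lambda s2: dedup(l[1:], f)(s2))(s)
    if l else
    listify(f)(s)
)

def run_smf(f, a):
    # Build the SMF, then apply it to the initial state.
    return f(a, lambda r: lambda s: r)(set())

# Partially applied: still waiting for its state argument.
pending = dedup([2, 1, 2], lambda r: lambda s: r)
print(callable(pending))        # True
print(sorted(pending(set())))   # [1, 2]
print(sorted(run_smf(dedup, [5, 4, 3, 6, 8, 2, 3, 6, 8, 4, 5])))  # [2, 3, 4, 5, 6, 8]
```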
At this point, we can make some simplifications. Note that dedup consists of two branches, each of which constructs a function and then immediately calls it with s as an argument. We can refactor that so that we conditionally construct one function or the other, then call it.
dedup = lambda l, f: lambda s: (
    (add(l[0], lambda s2: dedup(l[1:], f)(s2))
     if l else
     listify(f))(s)
)
The next thing to note is that the expression that constructs the function to be called with s (the parenthesized conditional just before the final (s)) depends only on the parameters l and f of the outer lambda. Switching back to def for a moment, we could write it as
def dedup(l, f):
    g = (add(l[0], lambda s2: dedup(l[1:], f)(s2))
         if l else
         listify(f))
    return lambda s: g(s)
It's not too hard to see that the final lambda in this isn't doing anything useful, and dedup might as well return g directly:
dedup = lambda l, f: (
    add(l[0], lambda s2: dedup(l[1:], f)(s2))
    if l else
    listify(f)
)
Look what's happened: We've removed all explicit references to s from this function!
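Dropping the wrapper lambda does change when the conditional is evaluated (as soon as dedup(l, f) is called, rather than when the state arrives), but at that point add and listify merely return functions, so nothing observable changes. Checking it against the first-curried add, listify and run_smf:

```python
add = lambda x, f: lambda s: f(s | {x})
listify = lambda f: lambda s: f(list(s))(s)

# No mention of s: dedup(l, f) simply returns the SMF to run.
dedup = lambda l, f: (
    add(l[0], lambda s2: dedup(l[1:], f)(s2))
    if l else
    listify(f)
)

def run_smf(f, a):
    return f(a, lambda r: lambda s: r)(set())

print(sorted(run_smf(dedup, [5, 4, 3, 6, 8, 2, 3, 6, 8, 4, 5])))  # [2, 3, 4, 5, 6, 8]
```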
What happens if we carry on and do some more currying? Let's find out.
Second currying
Currying all the remaining arguments to our SMFs leads to this (listify is already fully curried, so it stays as it is):
add = lambda x: lambda f: lambda s: f(s | {x})
dedup = lambda l: lambda f: (
    add(l[0])(lambda s2: dedup(l[1:])(f)(s2))
    if l else
    listify(f)
)
def run_smf(f, a):
    return f(a)(lambda r: lambda s: r)(set())
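Now every SMF takes its arguments one at a time, so run_smf makes one call per argument: the list, the final continuation, then the initial state. Putting the fully curried pieces together:

```python
add = lambda x: lambda f: lambda s: f(s | {x})
listify = lambda f: lambda s: f(list(s))(s)  # already fully curried
dedup = lambda l: lambda f: (
    add(l[0])(lambda s2: dedup(l[1:])(f)(s2))
    if l else
    listify(f)
)

def run_smf(f, a):
    # List, then final continuation, then initial state.
    return f(a)(lambda r: lambda s: r)(set())

print(sorted(run_smf(dedup, [5, 4, 3, 6, 8, 2, 3, 6, 8, 4, 5])))  # [2, 3, 4, 5, 6, 8]
```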
To progress further, we need to introduce another SMF, one that
abstracts the notion of "sequencing". It will take two SMFs and perform
one followed by the other, with the output state from the first SMF
becoming the input state to the second.
sequ = lambda f, g: lambda h: f(lambda s1: g(h)(s1))
Here, f and g are the two SMFs to be run in sequence, and h is the continuation to be run after both of them.
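To get a feel for sequ before using it in dedup, here it is combining two adds, and an add followed by listify. The final continuations (one returning the state, one returning the result) are just for inspection. Note that add's continuation receives only the state, so the continuation after two adds takes a single argument.

```python
add = lambda x: lambda f: lambda s: f(s | {x})
listify = lambda f: lambda s: f(list(s))(s)
# Run f, then g, then the continuation h; f's output state feeds g.
sequ = lambda f, g: lambda h: f(lambda s1: g(h)(s1))

both = sequ(add(1), add(2))
print(sorted(both(lambda s: s)(set())))  # [1, 2]

to_list = sequ(add(1), listify)
print(sorted(to_list(lambda r: lambda s: r)({5})))  # [1, 5]
```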
Armed with sequ, we can now write dedup as
dedup = lambda l: lambda f: (
    sequ(add(l[0]), dedup(l[1:]))(f)
    if l else
    listify(f)
)
Again, we can factor out construction of a function followed by calling it with f, and then drop the lambda f, leaving just
dedup = lambda l: (
    sequ(add(l[0]), dedup(l[1:]))
    if l else
    listify
)
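Putting the final pieces together, the state-free and continuation-free dedup still works with the fully curried run_smf:

```python
add = lambda x: lambda f: lambda s: f(s | {x})
listify = lambda f: lambda s: f(list(s))(s)
sequ = lambda f, g: lambda h: f(lambda s1: g(h)(s1))

# Neither the state nor the continuation is mentioned here.
dedup = lambda l: (
    sequ(add(l[0]), dedup(l[1:]))
    if l else
    listify
)

def run_smf(f, a):
    return f(a)(lambda r: lambda s: r)(set())

print(sorted(run_smf(dedup, [5, 4, 3, 6, 8, 2, 3, 6, 8, 4, 5])))  # [2, 3, 4, 5, 6, 8]
```

As with every version in this article, the Python recursion depth grows with the length of the list, so very long inputs will run into Python's recursion limit.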
Now we have removed explicit mention not only of the state, but of the continuation
as well. What's more, it's starting to look a lot like a sequential
program. We can read it fairly clearly as "add the first element to the
set, then add the rest of the elements to the set, unless the list is
empty, in which case turn the set into a list and stop."
From here, it shouldn't take much imagination to see that, once we have a few primitives such as add, sequ and listify, we can write everything else in the style of dedup,
never explicitly referring to the states or continuations. Doing this
not only makes the code simpler and clearer than it would otherwise be,
but if the state is not explicit, you can't accidentally use it in a
"wrong" way, such as keeping a reference to an old state before
mutation. The implementation helps to keep you honest.
This is particularly important for implementations that want to use in-place mutation, because then it's vital that it be impossible to abuse the state object.
This can be accomplished by having the primitives all be built-ins that
never expose the state, so that you can't write code that abuses it
even if you want to.
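As an illustration of writing further SMFs purely in terms of the primitives, here is a hypothetical dedup2 that deduplicates the contents of two lists together. It relies on two hypothetical helpers that are not part of the article: done, an SMF that passes the state straight through, and add_all, which adds every element of a list without listifying. None of them mentions the state or a continuation in its definition's body beyond what the primitives need.

```python
add = lambda x: lambda f: lambda s: f(s | {x})
listify = lambda f: lambda s: f(list(s))(s)
sequ = lambda f, g: lambda h: f(lambda s1: g(h)(s1))
dedup = lambda l: sequ(add(l[0]), dedup(l[1:])) if l else listify

# Hypothetical helper: an SMF that does nothing and passes the state on.
done = lambda f: lambda s: f(s)

# Hypothetical helper: add every element of l, producing no result.
add_all = lambda l: sequ(add(l[0]), add_all(l[1:])) if l else done

# Hypothetical: dedup the combined contents of two lists.
dedup2 = lambda l1, l2: sequ(add_all(l1), dedup(l2))

result = dedup2([1, 2, 2], [3, 1])(lambda r: lambda s: r)(set())
print(sorted(result))  # [1, 2, 3]
```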
Haskell
While we've simplified things quite a lot, the code still suffers quite badly
from being in Python rather than Haskell. In this section I'm going to
show you the Haskell version.
Just in case you're not familiar with Haskell and its ilk, here's a
crash course. All functions take exactly one argument. Calling function
f with argument x is written f x. Function application associates to the left, so f x y is equivalent to (f x) y.
Functions of more than one argument are normally handled by
currying. Haskell has a particularly concise syntax for defining and
using curried functions. The Haskell definition
f x y = x + y
is equivalent to the Python definition
f = lambda x: lambda y: x + y
and the Haskell expression f x y is equivalent to the Python expression f(x)(y).
So, our Python add function
add = lambda x: lambda f: lambda s: f(s | {x})
would be written in Haskell as
add x f s = f (insert x s)
where insert is the function from Haskell's standard Data.Set module for adding an element to a set.
The rest of our SMFs could be
translated as:
listify f s = f (toList s) s
sequ f g h = f (\s1 -> g h s1)
(\x -> y is Haskell's way of writing lambda x: y.)
dedup (x:t) = sequ (add x) (dedup t)
dedup [] = listify
This definition of dedup
makes use of another Haskell feature: functions can be defined in terms
of a series of cases, with pattern matching used to deconstruct
arguments on the left hand side. (Lists in Haskell are linked lists,
and x:t represents a list whose first element is x and whose remaining elements form the list t.)
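In practice, sequ would be defined as an infix operator (since the Prelude already exports its own >>, a real module would first need import Prelude hiding ((>>)) to avoid a clash):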
(f >> g) h = f (\s1 -> g h s1)
and then the first case of dedup could be written
dedup (x:t) = (add x) >> (dedup t)
Using an operator makes it easier to sequence more than two operations, e.g.
add3 x y z = (add x) >> (add y) >> (add z)
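For comparison, the same add3 can be sketched in Python with sequ. Without an infix operator the nesting has to be written out; since running f and then (g then h) does the same thing as (f then g) and then h, either grouping gives the same result. The names add3_right and grab_state are illustrative only.

```python
add = lambda x: lambda f: lambda s: f(s | {x})
sequ = lambda f, g: lambda h: f(lambda s1: g(h)(s1))

# Python counterpart of: add3 x y z = (add x) >> (add y) >> (add z)
add3 = lambda x, y, z: sequ(sequ(add(x), add(y)), add(z))
add3_right = lambda x, y, z: sequ(add(x), sequ(add(y), add(z)))

grab_state = lambda s: s  # final continuation: return the state
print(sorted(add3(1, 2, 3)(grab_state)(set())))  # [1, 2, 3]
print(add3(1, 2, 3)(grab_state)(set()) == add3_right(1, 2, 3)(grab_state)(set()))  # True
```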
Why implicit currying is a good thing
The power of Haskell's notation starts to become apparent if you consider the version of dedup we had earlier before we eliminated the f and s arguments. In Haskell it would be
dedup (x:t) f s = add x (\s2 -> dedup t f s2) s
dedup [] f s = listify f s
In the Python version, it wasn't immediately obvious that the lambda s: and the final call could be eliminated, but here it's almost trivial: you just cancel the s off both sides of each case (a step known as eta reduction), leaving
dedup (x:t) f = add x (\s2 -> dedup t f s2)
dedup [] f = listify f
Then we notice that
add x (\s2 -> dedup t f s2)
= (add x) (\s2 -> dedup t f s2)
= ((add x) >> (dedup t)) f
so we can rewrite dedup further as
dedup (x:t) f = ((add x) >> (dedup t)) f
dedup [] f = listify f
and then cancel the f, giving the final version above.
This is why mathematicians like Haskell so much: it lets you perform algebra on your programs!
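The key step, from add x (\s2 -> dedup t f s2) to ((add x) >> (dedup t)) f, can also be spot-checked in Python with sequ standing in for >>. This only tests one input rather than proving anything, and the values of x, t and f are arbitrary choices.

```python
add = lambda x: lambda f: lambda s: f(s | {x})
listify = lambda f: lambda s: f(list(s))(s)
sequ = lambda f, g: lambda h: f(lambda s1: g(h)(s1))
dedup = lambda l: sequ(add(l[0]), dedup(l[1:])) if l else listify

x, t = 4, [1, 4, 2]
f = lambda r: lambda s: sorted(r)  # final continuation: sorted result

# add x (\s2 -> dedup t f s2)
lhs = add(x)(lambda s2: dedup(t)(f)(s2))
# ((add x) >> (dedup t)) f
rhs = sequ(add(x), dedup(t))(f)

print(lhs(set()), rhs(set()))  # [1, 2, 4] [1, 2, 4]
```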
End of Part 1
This has been a longish journey, and I hope I haven't lost too many
people along the way. In Part 2 I'll show how monads can be used to
express I/O in a functional way.