Demystifying Monads using Python

The subject of monads in functional programming is regarded by many as a scary and difficult topic. Often it is presented in a very abstract and mathematical way, which tends to have an eye-glazing effect on people who aren't hard-core mathematicians. This is an attempt to convey the basic ideas behind monads to someone who is a programmer rather than a mathematician. In particular, it's aimed at someone who has at least a basic familiarity with Python and the techniques of traditional, non-functional programming.

First of all, what is a monad and why would you want one? In the context of functional programming, a monad is a technique that allows a stateful process to be expressed by composing stateless functions. Functional programmers are interested in doing this because it opens up possibilities for optimisations, and also because it enables dealing with I/O – which involves an inherently stateful object, the outside world – in a functional way.

As an example of a computation carried out using a stateful process, I'm going to take the following task: Given a list of numbers, produce another list that contains (in no particular order) all the unique numbers in the input list.

A non-functional way to write this in Python would be:
def dedup(l):
    s = set()
    for x in l:
        s.add(x)
    return list(s)
However, you could never write that in any purely functional language, because it makes in-place changes to data, and such languages don't provide any way to do that. We'd like another way to write it that doesn't involve mutating data, and also doesn't assign to a given variable more than once – another thing that purely functional languages don't allow.

To do this, I'm going to build the program up out of functions, all of which follow a particular pattern. The pattern is that they all take two arguments: a state object, and another function that follows the same pattern. (They may also take other arguments, but they must take at least those two arguments.)

The function argument is called a continuation, and it represents "the rest of the computation". The task of each of these functions is to take the state object given to it, construct a new state object from it in some way, pass that new state to the continuation, and return whatever the continuation returns.
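As a minimal illustration of the pattern, here is about the simplest function that follows it. The name keep is hypothetical and is not used again in this article. The function builds a "new" state that is just a copy of the old one, passes it to the continuation, and returns whatever the continuation returns.

```python
def keep(f, s):
    # Hypothetical example of the pattern: construct a new state from s
    # (here simply a fresh copy), pass it to the continuation f, and
    # return whatever f returns.
    new_state = set(s)
    return f(new_state)

# What comes back is decided entirely by the continuation; this one
# returns the size of the state it receives.
print(keep(lambda s: len(s), {1, 2, 3}))  # 3
```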

We can define the term "monad" in another way now: A monad is the collection of all functions that follow this pattern, for some particular type of state object. (This is not quite the full definition of a monad, nor the most general one, but it will do for now.)

In our case, the state object is a set, so I'm going to call this kind of function a "set manipulating function", or SMF (which you can pronounce "smurf" if you feel so inclined).

When creating a monad, the first thing you need to do is write some functions for performing primitive operations on the state. For our purposes here, we only need one operation, that of adding an item to the set. So our first SMF will be:
def add(x, f, s):
    return f(s | {x})
Here, f is the continuation and s is the input state. The output state passed to the continuation is the input state with the element x added. Note that we build a new set instead of modifying the one we were given, so we are not breaking any functional rules.
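To check that add really leaves its input alone, we can call it with a continuation that simply returns the state it receives, then compare the sets before and after. The identity continuation is chosen only for this demonstration, and the sketch is not part of the final program:

```python
def add(x, f, s):
    # Build a new set containing x; the original s is left untouched.
    return f(s | {x})

original = {1, 2}
# A continuation that just hands back the state it is given.
new = add(3, lambda s2: s2, original)

print(sorted(original))  # [1, 2]  (unchanged)
print(sorted(new))       # [1, 2, 3]
```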

Another thing we will want to be able to do is turn the set into a list, because the problem specification said we were to produce the result as a list, not a set. So let's define another SMF to do that:
def listify(f, s):
    return f(list(s), s)
Note that this function does not return the list – instead, it passes it as an extra argument to the continuation. This is a general characteristic of functions belonging to a monad – if the function produces a result, it does so by passing it to the continuation.

Also, because it doesn't need to change the state in any way, listify just passes the state it was given on to the continuation.
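Here is a quick sketch of listify in isolation. For inspection purposes, it uses a continuation that returns both the result and the state it was handed:

```python
def listify(f, s):
    # Produce a result (the list) and pass it, together with the
    # unchanged state, to the continuation.
    return f(list(s), s)

start = {3, 1, 2}
result, state = listify(lambda r, s: (r, s), start)
print(sorted(result))  # [1, 2, 3]
print(state is start)  # True: listify passes the same state along
```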

Now we can express the dedup computation itself in the form of an SMF:
def dedup(l, f, s):
    if l:
        return add(l[0], lambda s2: dedup(l[1:], f, s2), s)
    else:
        return listify(f, s)
This may look rather convoluted, but it breaks down like this. If the list is not empty, we add the first element of the list to the set, giving a new set s2. We then recursively dedup the rest of the list, using s2 as the starting state.

If the list is empty, then we are finished, so we listify what we have and pass that to our continuation.
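Without a wrapper, calling dedup means supplying both a continuation and a starting state yourself. The following is a minimal self-contained sketch. The continuation lambda r, s: (r, s) is just one possible choice, which hands back both the result and the final state.

```python
def add(x, f, s):
    return f(s | {x})

def listify(f, s):
    return f(list(s), s)

def dedup(l, f, s):
    if l:
        # Add the head, then carry on with the tail from the new state s2.
        return add(l[0], lambda s2: dedup(l[1:], f, s2), s)
    else:
        # Nothing left: turn the accumulated set into a list and hand it on.
        return listify(f, s)

result, final_state = dedup([3, 1, 3, 2, 1], lambda r, s: (r, s), set())
print(sorted(result))    # [1, 2, 3]
print(len(final_state))  # 3
```

Each element adds a few nested calls (dedup, add and the lambda). As a result, this Python version will exceed the default recursion limit for lists longer than a few hundred elements. That is a consequence of Python's lack of tail-call elimination, not of the technique itself.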

This is well and good, and accomplishes the task, but it has a rather inconvenient API – instead of getting the result returned to us, we have to supply a function to which the result will be passed. Also, we need to supply a starting state. We can make it easier to use by writing a top-level wrapper:
def run_smf(f, a):
    return f(a, lambda r, s: r, set())
This function takes any SMF expecting a single additional argument a, and runs it with an empty set as the initial state. The continuation takes the final result r and "turns it around" so that it passes back up through all the nested continuation calls and emerges as the return value. We don't need the final state any more, so we ignore it.

Let's try it out:
>>> run_smf(dedup, [5,4,3,6,8,2,3,6,8,4,5])
[2, 3, 4, 5, 6, 8]

Why bother?

At this point, you may be wondering why we went to all this trouble to turn a simple, straightforward piece of code into something so convoluted. We could have written something like this:
def func_dedup(l, s=set()):
    if l:
        return func_dedup(l[1:], s | {l[0]})
    else:
        return list(s)
and it would have been just as functionally pure.

Well, if you look carefully at our SMFs, you'll notice a few things. First, none of them ever returns a state object: states are only passed deeper down the call chain, and only ever as the final argument to a continuation. Second, the only functions that do anything directly with a state object are the primitive operations add and listify. Third, the only one that creates a new state object is add, and the only thing it does with it is pass it to its continuation.

As a consequence, as soon as a new state object has been created, the old one is never seen again. So we could replace the implementation of add with this:
def add(x, f, s):
    s.add(x)
    return f(s)
and as long as all our SMFs manipulate the state solely through the primitive operations, nobody would ever notice the difference.

In other words, the implementation is free to use in-place mutations of the state object – which obviously allows considerable gains in efficiency, both in time and memory usage – without giving up any functional purity.
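We can check the program directly with the mutating add. In the second call of this sketch, the continuation reports whether the final state is the very set object we started with. This confirms that a single set is being updated in place rather than copied:

```python
def add(x, f, s):
    # In-place version: safe only because no SMF ever looks at the old state again.
    s.add(x)
    return f(s)


def listify(f, s):
    return f(list(s), s)


def dedup(l, f, s):
    if l:
        return add(l[0], lambda s2: dedup(l[1:], f, s2), s)
    else:
        return listify(f, s)


def run_smf(f, a):
    return f(a, lambda r, s: r, set())


print(sorted(run_smf(dedup, [5, 4, 3, 6, 8, 2, 3, 6, 8, 4, 5])))  # [2, 3, 4, 5, 6, 8]

# The continuation reports whether the final state is the very set we passed in.
start = set()
print(dedup([1, 2, 1], lambda r, s: s is start, start))  # True
print(sorted(start))  # [1, 2]  (the starting set was updated in place)
```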

Also, it's worth pointing out that a major reason the SMF version looks so horrible is that (a) we wrote it in Python instead of Haskell, and (b) we didn't make use of currying. (All right, two major reasons...)

To illustrate what I mean, I'm going to make a series of transformations to turn the code into something more like the way it would be written in a Haskell-like language.

Lambdification

The first step is to rewrite all our SMFs using lambdas instead of defs. This doesn't change anything much, but it makes the next couple of steps easier.
add = lambda x, f, s: f(s | {x})

listify = lambda f, s: f(list(s), s)

dedup = lambda l, f, s: (
    add(l[0], lambda s2: dedup(l[1:], f, s2), s)
    if l else
    listify(f, s)
)
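The lambdas are drop-in replacements for the defs, so the original run_smf still works unchanged. One detail is worth noting in this sketch: dedup can refer to itself inside its own lambda. This works because Python looks up the global name dedup when the body runs, and by then the assignment has already completed.

```python
add = lambda x, f, s: f(s | {x})

listify = lambda f, s: f(list(s), s)

# The body mentions dedup; the name is resolved at call time, not definition time.
dedup = lambda l, f, s: (
    add(l[0], lambda s2: dedup(l[1:], f, s2), s)
    if l else
    listify(f, s)
)


def run_smf(f, a):
    return f(a, lambda r, s: r, set())


print(sorted(run_smf(dedup, [5, 4, 3, 6, 8, 2, 3, 6, 8, 4, 5])))  # [2, 3, 4, 5, 6, 8]
```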

First currying

Next, we curry the s argument to our SMFs, and update the wrapper to match.
add = lambda x, f: lambda s: f(s | {x})

listify = lambda f: lambda s: f(list(s))(s)

dedup = lambda l, f: lambda s: (
    add(l[0], lambda s2: dedup(l[1:], f)(s2))(s)
    if l else
    listify(f)(s)
)

def run_smf(f, a):
    return f(a, lambda r: lambda s: r)(set())
At this point, we can make some simplifications. Note that dedup consists of two branches, each of which constructs a function and then immediately calls it with s as an argument. We can refactor that so that we conditionally construct one function or the other, then call it.
dedup = lambda l, f: lambda s: (
    (add(l[0], lambda s2: dedup(l[1:], f)(s2))
     if l else
     listify(f))(s)
)
The next thing to note is that the expression that constructs the function to be called with s (the parenthesized conditional that is applied to (s) above) only depends on the parameters l and f of the outer lambda. Switching back to def for a moment, we could write it as
def dedup(l, f):
    g = (add(l[0], lambda s2: dedup(l[1:], f)(s2))
         if l else
         listify(f))
    return lambda s: g(s)
It's not too hard to see that the final lambda in this isn't doing anything useful, and dedup might as well return g directly:
dedup = lambda l, f: (
    add(l[0], lambda s2: dedup(l[1:], f)(s2))
    if l else
    listify(f)
)
Look what's happened: We've removed all explicit references to s from this function!
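The following is a self-contained sketch of the first-curried version, with dedup returning g directly. Because s has disappeared from dedup, calling dedup(l, f) now gives back a function that is still waiting for its starting state. The variable name pending is just for illustration.

```python
add = lambda x, f: lambda s: f(s | {x})
listify = lambda f: lambda s: f(list(s))(s)
dedup = lambda l, f: (
    add(l[0], lambda s2: dedup(l[1:], f)(s2))
    if l else
    listify(f)
)


def run_smf(f, a):
    return f(a, lambda r: lambda s: r)(set())


# dedup(l, f) returns a function that is still waiting for the state.
pending = dedup([2, 1, 2], lambda r: lambda s: r)
print(callable(pending))     # True
print(sorted(pending({9})))  # [1, 2, 9]  (9 came from the starting state)

print(sorted(run_smf(dedup, [5, 4, 3, 6, 8, 2, 3, 6, 8, 4, 5])))  # [2, 3, 4, 5, 6, 8]
```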

What happens if we carry on and do some more currying? Let's find out.

Second currying

Currying all the remaining arguments to our SMFs leads to this (listify is already fully curried, so it stays as it was):
add = lambda x: lambda f: lambda s: f(s | {x})

dedup = lambda l: lambda f: (
    add(l[0])(lambda s2: dedup(l[1:])(f)(s2))
    if l else
    listify(f)
)

def run_smf(f, a):
    return f(a)(lambda r: lambda s: r)(set())
To progress further, we need to introduce a function that abstracts the notion of "sequencing". It will take two SMFs and combine them into a new SMF that performs one followed by the other. The output state from the first SMF becomes the input state to the second.
sequ = lambda f, g: lambda h: f(lambda s1: g(h)(s1))
Here, f and g are the two SMFs to be run in sequence, and h is the continuation to be run after both of them.
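Here is a small sketch of sequ on its own: sequencing add(1) and add(2) gives a single new SMF. The continuation lambda s: s is an illustrative choice that just returns the final state.

```python
add = lambda x: lambda f: lambda s: f(s | {x})
sequ = lambda f, g: lambda h: f(lambda s1: g(h)(s1))

# One SMF made from two: add 1, then add 2.
add_both = sequ(add(1), add(2))

# Supply a continuation (here: return the final state) and a starting state.
final = add_both(lambda s: s)({7})
print(sorted(final))  # [1, 2, 7]
```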

Armed with sequ, we can now write dedup as
dedup = lambda l: lambda f: (
    sequ(add(l[0]), dedup(l[1:]))(f)
    if l else
    listify(f)
)
Again, we can factor out construction of a function followed by calling it with f, and then drop the lambda f, leaving just
dedup = lambda l: (
    sequ(add(l[0]), dedup(l[1:]))
    if l else
    listify
)
Now, not only have we removed explicit mention of the state, but the continuation as well. What's more, it's starting to look a lot like a sequential program. We can read it fairly clearly as "add the first element to the set, then add the rest of the elements to the set, unless the list is empty, in which case turn the set into a list and stop."
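Putting the pieces together, here is a self-contained sketch of the final Python version. Calling dedup(l) only builds the chain of sequenced SMFs, eagerly and one level per list element. No set exists until the resulting function is given a continuation and then a starting state, which is exactly what run_smf does.

```python
add = lambda x: lambda f: lambda s: f(s | {x})
listify = lambda f: lambda s: f(list(s))(s)
sequ = lambda f, g: lambda h: f(lambda s1: g(h)(s1))

dedup = lambda l: (
    sequ(add(l[0]), dedup(l[1:]))
    if l else
    listify
)


def run_smf(f, a):
    return f(a)(lambda r: lambda s: r)(set())


# Building the computation touches no state at all...
program = dedup([1, 2, 1])
# ...until it is given a continuation and then a starting state.
print(sorted(program(lambda r: lambda s: r)(set())))  # [1, 2]

print(sorted(run_smf(dedup, [5, 4, 3, 6, 8, 2, 3, 6, 8, 4, 5])))  # [2, 3, 4, 5, 6, 8]
```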

From here, it shouldn't take much imagination to see that, once we have a few primitives such as add, sequ and listify, we can write everything else in the style of dedup, never explicitly referring to the states or continuations. Doing this not only makes the code simpler and clearer than it would otherwise be, but if the state is not explicit, you can't accidentally use it in a "wrong" way, such as keeping a reference to an old state before mutation. The implementation helps to keep you honest.

This is particularly important for implementations that want to use in-place mutation, because then it's vital that it be impossible to abuse the state object. This can be accomplished by having the primitives all be built-ins that never expose the state, so that you can't write code that abuses it even if you want to.

Haskell

While we've simplified things quite a lot, the code still suffers quite badly from being written in Python rather than Haskell. In this section I'm going to show you the Haskell version.

Just in case you're not familiar with Haskell and its ilk, here's a crash course. All functions take exactly one argument. Calling function f with argument x is written f x. Function application associates to the left, so f x y is equivalent to (f x) y.

Functions of more than one argument are normally handled by currying. Haskell has a particularly concise syntax for defining and using curried functions. The Haskell definition

f x y = x + y

is equivalent to the Python definition

f = lambda x: lambda y: x + y

and the Haskell expression f x y is equivalent to the Python expression f(x)(y).

So, our Python add function
add = lambda x: lambda f: lambda s: f(s | {x})
would be written in Haskell as
add x f s = f (insert x s)
where insert is the function from Haskell's Data.Set module for adding an element to a set. (toList, used below, is the same module's function for converting a set to a list.)

The rest of our SMFs could be translated as:
listify f s = f (toList s) s

sequ f g h = f (\s1 -> g h s1)
(\x -> y is Haskell's way of writing lambda x: y.)
dedup (x:t) = sequ (add x) (dedup t)
dedup [] = listify
This definition of dedup makes use of another Haskell feature: functions can be defined as a series of cases, with pattern matching used to deconstruct arguments on the left-hand side. (Lists in Haskell are linked lists, and x:t represents a list whose first element is x and whose remainder is t.)

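In practice, sequ would be defined as an infix operator. The name >> is already taken by the standard Prelude, so a real module would need import Prelude hiding ((>>)) before defining: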
(f >> g) h = f (\s1 -> g h s1)
and then the first case of dedup could be written
dedup (x:t) = (add x) >> (dedup t)
Using an operator makes it easier to sequence more than two operations, e.g.
add3 x y z = (add x) >> (add y) >> (add z)
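Python has no user-defined infix operators, but the built-in >> can be overloaded through the __rshift__ method. The following sketch wraps an SMF in a hypothetical SMF class, which is not used elsewhere in this article, so that add3 can be written much as in Haskell. It also checks that the grouping of a chain does not affect the result.

```python
class SMF:
    """Hypothetical wrapper: an SMF maps a continuation to a state function."""

    def __init__(self, run):
        self.run = run

    def __call__(self, h):
        return self.run(h)

    def __rshift__(self, other):
        # Same definition as sequ: run self, then other, then the continuation h.
        return SMF(lambda h: self.run(lambda s1: other.run(h)(s1)))


# add, wrapped so that it supports >>
add = lambda x: SMF(lambda f: lambda s: f(s | {x}))

# Python's >> is left-associative, so this means (add(x) >> add(y)) >> add(z).
add3 = lambda x, y, z: add(x) >> add(y) >> add(z)

print(sorted(add3(1, 2, 3)(lambda s: s)(set())))  # [1, 2, 3]

# Regrouping the chain gives the same final state.
left = (add(1) >> add(2)) >> add(3)
right = add(1) >> (add(2) >> add(3))
print(left(lambda s: s)(set()) == right(lambda s: s)(set()))  # True
```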

Why implicit currying is a good thing

The power of Haskell's notation starts to become apparent if you consider the version of dedup we had earlier before we eliminated the f and s arguments. In Haskell it would be
dedup (x:t) f s = add x (\s2 -> dedup t f s2) s
dedup [] f s = listify f s
In the Python version, it wasn't immediately obvious that the lambda s: and the final call could be eliminated, but here it's almost trivial: you just cancel the s off both sides of each case, leaving
dedup (x:t) f = add x (\s2 -> dedup t f s2)
dedup [] f = listify f
Then we notice that
add x (\s2 -> dedup t f s2)
  = (add x) (\s2 -> (dedup t) f s2)
  = ((add x) >> (dedup t)) f
so we can rewrite dedup further as
dedup (x:t) f = ((add x) >> (dedup t)) f
dedup [] f = listify f
and then cancel the f, giving the final version above.
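The same cancellations can be checked in Python, though only by testing rather than by algebra. This sketch compares two versions on a few sample lists: one with explicit f and s (named dedup_fs here purely for illustration) and the final version built with sequ. Both should produce the same results for the same inputs.

```python
add = lambda x: lambda f: lambda s: f(s | {x})
listify = lambda f: lambda s: f(list(s))(s)
sequ = lambda f, g: lambda h: f(lambda s1: g(h)(s1))


def dedup_fs(l):
    # Explicit f and s, as in: dedup (x:t) f s = add x (\s2 -> dedup t f s2) s
    return lambda f: lambda s: (
        add(l[0])(lambda s2: dedup_fs(l[1:])(f)(s2))(s)
        if l else
        listify(f)(s)
    )


def dedup(l):
    # After cancelling s and f: dedup (x:t) = (add x) >> (dedup t)
    return sequ(add(l[0]), dedup(l[1:])) if l else listify


# A continuation that returns the result sorted, so lists compare reliably.
k = lambda r: lambda s: sorted(r)

for xs in ([], [1], [3, 1, 3, 2], [5, 4, 3, 6, 8, 2, 3, 6, 8, 4, 5]):
    print(dedup_fs(xs)(k)(set()) == dedup(xs)(k)(set()))  # True each time
```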

This is why mathematicians like Haskell so much: it lets you perform algebra on your programs!

End of Part 1

This has been a longish journey, and I hope I haven't lost too many people along the way. In Part 2 I'll show how monads can be used to express I/O in a functional way.