Trees and continuation passing style

For no reason in particular I decided to revisit tree traversal as a kind of programming kata. There are two main kinds of tree traversal:

  • Depth first – This is where you follow a branch all the way down before backing up to visit its siblings. With a tree like the one below, you’d reach c, the deepest node on the left, before ever touching d or e (assuming you iterate left first, then right).
         a
        / \
       b   e
     /  \
    c    d
    
  • Breadth first – This is where you hit all the nodes at the level you’re on before going further. So with the same tree, you’d hit a, then b, then e, then c and d.

Since I actually hate tree traversal, and having to think about it, I decided that whatever I wrote had better be extensible and clean.

Depth first

Here is a simple DFS traversal:

private List<T> DepthFirstFlatten<T>(T root, Func<T, List<T>> edgeFunction) where T : class
{
    // An empty list (rather than null) keeps callers free of null checks.
    if (root == null)
    {
        return new List<T>();
    }

    // Pre-order: record the current node before descending into its edges.
    var totalNodes = new List<T> { root };

    var edges = edgeFunction(root);

    if (edges != null)
    {
        foreach (var edge in edges)
        {
            // A null edge (a missing child) simply contributes nothing.
            totalNodes.AddRange(DepthFirstFlatten(edge, edgeFunction));
        }
    }

    return totalNodes;
}

In this case I’m just flattening the tree into a list and using a function to return each node’s edges. This way I can re-use the same depth first algorithm for any kind of graph, not just a tree (assuming it is acyclic). To handle cycles you would need to pass the set of already processed nodes along as an accumulator, and skip the current node if it is already in that set.
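
The cycle handling described above could look something like the sketch below. The method name DepthFirstFlattenCyclic, the HashSet<T> accumulator, and the string-keyed sample graph are assumptions made for illustration, not part of the original code. It also assumes that T's default equality is a suitable notion of "same node".

```csharp
using System;
using System.Collections.Generic;

public static class CycleSafeTraversal
{
    // Hypothetical variant of DepthFirstFlatten: 'visited' is the accumulator
    // of already processed nodes, shared across the whole recursion.
    public static List<T> DepthFirstFlattenCyclic<T>(
        T root, Func<T, List<T>> edgeFunction, HashSet<T> visited = null) where T : class
    {
        visited = visited ?? new HashSet<T>();

        // HashSet<T>.Add returns false when the node was already processed.
        if (root == null || !visited.Add(root))
        {
            return new List<T>();
        }

        var totalNodes = new List<T> { root };

        foreach (var edge in edgeFunction(root) ?? new List<T>())
        {
            totalNodes.AddRange(DepthFirstFlattenCyclic(edge, edgeFunction, visited));
        }

        return totalNodes;
    }

    public static void Main()
    {
        // a -> b -> c -> a is a cycle; without 'visited' the recursion would never end.
        var graph = new Dictionary<string, List<string>>
        {
            { "a", new List<string> { "b" } },
            { "b", new List<string> { "c" } },
            { "c", new List<string> { "a" } }
        };

        var order = DepthFirstFlattenCyclic("a", node => graph[node]);

        Console.WriteLine(string.Join(", ", order)); // a, b, c
    }
}
```

Sharing one mutable set keeps the check cheap, at the cost of the method no longer being a pure function of its first two arguments.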

Breadth first

The BFS is very similar, except that instead of recursion it uses the standard iterative approach with a queue:

private List<T> BreadthFlatten<T>(T root, Func<T, List<T>> edgeFunction) where T : class
{
    var queue = new Queue<T>();

    queue.Enqueue(root);

    var allNodes = new List<T>();

    while (queue.Any())
    {
        var head = queue.Dequeue();

        if (head == null)
        {
            continue;
        }

        allNodes.Add(head);

        // Children go to the back of the queue, so the rest of the current level is dequeued first.
        edgeFunction(head).ForEach(queue.Enqueue);
    }

    return allNodes;
}

Same kind of deal here. This one is nice because it’s not limited by stack depth: the pending nodes live in a heap-allocated queue rather than in call frames.

Also, for both traversals, if you wanted to you could pass in an action to do work each time a node was processed. Here is an example that flattens the following tree and prints each node:

     1
    / \
   2   3
 /  \   \
4    5   6

Below is a small class representing a binary tree:

class Node<T>
{
    public Node(T data, Node<T> left = null, Node<T> right = null)
    {
        Item = data;
        Left = left;
        Right = right;
    }

    public Node<T> Left { get; set; }
    public Node<T> Right { get; set; }

    public T Item { get; set; }
}

And our unit test to print out the different traversal types:

[Test]
public void DepthFlatten()
{
    var tree = new Node<int>(1,
                            new Node<int>(2, new Node<int>(4), new Node<int>(5)),
                            new Node<int>(3, null, new Node<int>(6)));

    Func<Node<int>, List<Node<int>>> extractor = node => new List<Node<int>> {node.Left, node.Right};

    Console.WriteLine("Depth");
    DepthFirstFlatten(tree, extractor).ForEach(n => Console.WriteLine(n.Item));

    Console.WriteLine("Breadth");
    BreadthFlatten(tree, extractor).ForEach(n => Console.WriteLine(n.Item));
}

Which prints out:

Depth
1
2
4
5
3
6

Breadth
1
2
3
4
5
6
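
The test above flattens first and prints afterwards. Passing in an action, as mentioned earlier, would skip the intermediate list. A minimal sketch, assuming a hypothetical BreadthFirstVisit method; string labels stand in for Node<int> only to keep the block self-contained:

```csharp
using System;
using System.Collections.Generic;

public static class VisitorTraversal
{
    // Hypothetical variant of BreadthFlatten: instead of collecting nodes,
    // it calls 'visit' once per node, in breadth first order.
    public static void BreadthFirstVisit<T>(
        T root, Func<T, List<T>> edgeFunction, Action<T> visit) where T : class
    {
        var queue = new Queue<T>();
        queue.Enqueue(root);

        while (queue.Count > 0)
        {
            var head = queue.Dequeue();
            if (head == null)
            {
                continue;
            }

            visit(head);

            foreach (var edge in edgeFunction(head))
            {
                queue.Enqueue(edge);
            }
        }
    }

    public static void Main()
    {
        // Same shape as the sample tree, keyed by label; leaves have no entry.
        var children = new Dictionary<string, List<string>>
        {
            { "1", new List<string> { "2", "3" } },
            { "2", new List<string> { "4", "5" } },
            { "3", new List<string> { null, "6" } }
        };

        Func<string, List<string>> extractor = node =>
            children.TryGetValue(node, out var edges) ? edges : new List<string>();

        var seen = new List<string>();
        BreadthFirstVisit("1", extractor, seen.Add);

        Console.WriteLine(string.Join(" ", seen)); // 1 2 3 4 5 6
    }
}
```

Any Action<T> works here, so the same traversal can print, count, or collect without changing the loop.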

DFS stack agnostic

We can even change the DFS to not use recursion, so that it no longer cares how deep the tree is. In this scenario, unlike the BFS, you’d use a stack instead of a queue. A stack is LIFO (last in, first out), so the children you just pushed are the next nodes you process, which keeps you descending. This contrasts with the queue, where you enqueue the children but process the queue FIFO (first in, first out), meaning you process all the nodes at the current depth before moving to the next depth.

private List<T> DepthFirstFlattenIterative<T>(T root, Func<T, List<T>> edgeFunction) where T : class
{
    var stack = new Stack<T>();

    stack.Push(root);

    var allNodes = new List<T>();

    while (stack.Any())
    {
        var head = stack.Pop();

        if (head == null)
        {
            continue;
        }

        allNodes.Add(head);

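        var edges = edgeFunction(head);

        // Push in reverse so the leftmost edge is popped first. Iterating
        // backwards avoids List<T>.Reverse(), which would mutate a list
        // that the caller may still own.
        for (var i = edges.Count - 1; i >= 0; i--)
        {
            stack.Push(edges[i]);
        }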
    }

    return allNodes;
} 

The edges are pushed in reverse order only to stay consistent with the left-first descent; otherwise the traversal goes down the right branch first. This prints:

Depth iterative
1
2
4
5
3
6

DFS with continuation passing

There is yet another way to do tree traversal that is common in functional languages: “continuation passing style”. Doing it this way, every recursive call is a tail call, even though you are iterating over multiple tree branches.

Below is some F# code to count the number of nodes in a tree. The tree I’m using as the sample looks like this:

       1
     /   \ 
   2      3
 /  \      \
4    5      6

The total number of nodes here is 6, which is what you get with the code below.

open System

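type Tree = 
    | Leaf of int
    | Node of int * Tree option * Tree option


let countNodes tree = 
    // cont is the continuation: what to do with the count of this subtree.
    let rec countNodes' treeOpt cont = 
        match treeOpt with 
            | Some tree -> 
                match tree with 
                | Leaf _ -> cont 1
                | Node (_, left, right) ->
                    // Count the left subtree; its continuation counts the right subtree,
                    // whose continuation combines both counts with the current node.
                    countNodes' left (fun leftCount ->
                                          countNodes' right (fun rightCount ->
                                                                 cont(1 + leftCount + rightCount)))
            | None -> cont 0
                    
    countNodes' tree id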


let leftBranch = Node(2, Some(Leaf(4)), Some(Leaf(5)))

let rightBranch = Node(3, None, Some(Leaf(6)))

let tree = Node(1, Some(leftBranch), Some(rightBranch))

let treeNodeCount = countNodes (Some(tree))

But what the hell is going on here? It’s really not apparent when you first look at it what executes what and when.

The trick here is to pass a function to each iteration that closes over what the next piece of work should be. To be fair, it’s hard to wrap your mind around what is happening, so let’s trace this out. I’ve highlighted each of the continuations and given them an alias so you can see how they are re-used elsewhere. Each time a continuation is called I also show the expanded form following the ->.

[Figure: continuation passing trace]

You can see how each iteration captures the work to do next. Eventually the very last work that needs to be done is the first function you passed in as the seed. In this case, it’s the built-in id function, which returns whatever value is given to it (here 6, the number of nodes in the tree). The ordering of the traversal is exactly the same as in the other DFS traversals earlier, except this time everything is tail recursive.
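
For comparison with the C# traversals earlier, the same continuation passing shape can be sketched in C#. The CpsCount class, its nested Node type, and the CountNodes signature are hypothetical names chosen for this illustration. One caveat: C# does not guarantee tail call elimination, so this shows the structure of the technique rather than its stack safety.

```csharp
using System;

public static class CpsCount
{
    // Minimal stand-in for the binary tree; null means "no child".
    public sealed class Node
    {
        public Node(int item, Node left = null, Node right = null)
        {
            Item = item;
            Left = left;
            Right = right;
        }

        public int Item { get; }
        public Node Left { get; }
        public Node Right { get; }
    }

    // 'cont' receives the count of the subtree rooted at 'node'.
    public static int CountNodes(Node node, Func<int, int> cont)
    {
        if (node == null)
        {
            return cont(0);
        }

        // Count the left subtree; its continuation counts the right subtree,
        // whose continuation combines both counts with the current node.
        return CountNodes(node.Left, leftCount =>
            CountNodes(node.Right, rightCount =>
                cont(1 + leftCount + rightCount)));
    }

    public static void Main()
    {
        var tree = new Node(1,
            new Node(2, new Node(4), new Node(5)),
            new Node(3, null, new Node(6)));

        // The identity function is the seed continuation, like id in the F# version.
        Console.WriteLine(CountNodes(tree, total => total)); // 6
    }
}
```

Each lambda closes over the continuation it was handed, which is how the pending work gets carried along in closures instead of in stack frames waiting for a return value.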
