Most of the time, you use divide-n-conquer to solve one task, like sorting or selection. But in this section, you will learn to use divide-n-conquer in a more interesting way, to solve two tasks simultaneously. One task might be the main task while the other is a byproduct that tags along. In other words, your recursive function will return two solutions, one for each task. This would be hard to do in C/C++/Java, but is super easy in Python:
def solve(problem):
    # divide-conquer-combine
    ...
    return a, b
and you can use Python’s “pattern-matching” feature to receive two things from a function: a, b = solve(sub_problem). Here is a complete template:
def solve(problem):  # two tasks
    subp1, subp2 = ...     # divide problem into two subproblems
    a1, b1 = solve(subp1)  # conquer two tasks for subp1
    a2, b2 = solve(subp2)  # conquer two tasks for subp2
    a, b = ...             # combine (a1, b1) and (a2, b2) to (a, b)
    return a, b            # solutions of the two tasks for problem
We will demonstrate two classical examples of this paradigm.
Caveat: unlike previous sections, we need to distinguish “problem” and “task” in this section. For example, a problem might be an array like [4,3,1,2] and a task would be “sort the array” or “find the median in this array”. If a problem is a binary search tree, a task could be sorted(tree) or search(tree, query) or depth(tree).
First, let’s consider how to efficiently count the number of inversions in an unsorted array. We define an inversion (or “inverted pair”) to be \((a_i, a_j)\) where \(i<j\) but \(a_i>a_j\). For example, in [4, 1, 3, 2] there are 4 inversions: (4, 1), (4, 3), (4, 2), (3, 2). (As a special case, a sorted array has 0 inversions.)
Obviously you can do it in \(O(n^2)\) time with two nested loops that enumerate all pairs. But can you do it faster?
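For concreteness, here is a minimal sketch of that brute force (the function name is mine):

def count_inversions_brute(a):
    # O(n^2): enumerate all pairs (i, j) with i < j and count a[i] > a[j]
    n = len(a)
    return sum(1 for i in range(n) for j in range(i + 1, n) if a[i] > a[j])

count_inversions_brute([4, 1, 3, 2])  # => 4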
Well, the next faster complexity is \(O(n\log n)\). Can you do it that fast? Whenever you see \(O(n\log n)\) (esp. in interviews), you should think of sorting, because this complexity comes naturally from many sorting algorithms (quicksort, mergesort, heapsort) and we’ll see in later sections that it is the fastest (internal) sorting can ever be. So \(O(n\log n)\) is inherently related to sorting. Now, can you use sorting to count the number of inversions?
In fact you can! And not just with one sorting algorithm. In this section we’ll see how to use mergesort to solve it, but you should think about (as an exercise) how to use quicksort for it as well. The basic idea is to tag along the counting of inversions onto mergesort, so that the former becomes a byproduct of sorting.
The basic idea is very simple:

- divide a into left and right as in mergesort
- conquer left, which should return not only sorted_left, but also inv_left, the number of inversions within left
- conquer right, which should return not only sorted_right, but also inv_right, the number of inversions within right
- combine sorted_left and sorted_right into sorted_a, but during this combination, also count the number of crossing inversions between left and right along the way (save it in inv_cross)
- return sorted_a and inv_left + inv_right + inv_cross (the latter is the total number of inversions within a)

The only new thing is how to calculate the crossing inversions between left and right. Why crossing inversions only? Because by the principle of divide-n-conquer, the internal inversions within left and right should already be solved in those two subproblems, and your job at the current level (a) is just counting the remaining inversions (within a) that are beyond the scopes of left or right alone. Again, the principle of divide-conquer-combine is that the vast majority of the job is already done by the “conquer” steps, and you only need to do the “combination” step that your children can’t do by themselves.
Here is an example: let’s say left = [5, 1, 7] and right = [6, 4, 2], and after the two conquers, we get

sorted_left  = [1, 5, 7], inv_left  = 1 (only 1 pair: (5,1))
sorted_right = [2, 4, 6], inv_right = 3 (3 pairs: (6,4) (6,2) (4,2))
Now we combine them like in mergesort:
[1, 5, 7]   [2, 4, 6]
 ^           ^
 *
=> 1                      inv_cross = 0
Whenever you take a left number, there are obviously no inversions, but when you take a right number you have definitely encountered (at least one) inversion, e.g., this (5, 2) pair.
[1, 5, 7]   [2, 4, 6]
    ^        ^
             *
=> 1, 2                   inv_cross = 1: (5,2)

[1, 5, 7]   [2, 4, 6]
    ^           ^
                *
=> 1, 2, 4                inv_cross = 2: (5,2), (5,4)

[1, 5, 7]   [2, 4, 6]
    ^              ^
    *
=> 1, 2, 4, 5             inv_cross = 2: (5,2), (5,4)

[1, 5, 7]   [2, 4, 6]
       ^           ^
                   *
=> 1, 2, 4, 5, 6          inv_cross = 3: (5,2), (5,4), (7,6)

[1, 5, 7]   [2, 4, 6]
       ^
       *
=> 1, 2, 4, 5, 6, 7       inv_cross = 3: (5,2), (5,4), (7,6)
So we counted 3 crossing inversions ((5,2), (5,4), (7,6)), but did we miss anything? Clearly, there are two other crossing inversions that we didn’t count: (7,2), (7,4). What was the problem?
Here it is: in each step, when you take a number from the right, there is more than just a single inversion pair. In fact, all the remaining numbers in left are inverted with the current number in right:
[1, 5, 7]   [2, 4, 6]
    ^        ^
             *
=> 1, 2                   inv_cross = 2: (5,2) implies (7,2)

[1, 5, 7]   [2, 4, 6]
    ^           ^
                *
=> 1, 2, 4                inv_cross = 4: (5,2), (7,2), (5,4) implies (7,4)
Now we recovered the two missing (implied) inversions. In general, when left[i] > right[j],

[<<<< i >>>>>>>]   [..... j ......]
      ^                   ^
                          *

the current pair (left[i], right[j]) is obviously inverted, but that also implies that all remaining numbers left[i+1], left[i+2], … (those > numbers above) are also inverted with right[j], because they are even bigger than left[i]:

left[i'] >= left[i] > right[j]   for i' = i+1, i+2, ...

So you should add |left| - i to inv_cross whenever you take a number from right.
Now we have a complete method of tagging along the number of inversions while doing mergesort. The complexity stays the same, since this tagging along only costs \(O(1)\) per step, or \(O(n)\) total, in the “combine” part, which doesn’t change anything.
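Putting everything together, here is a minimal sketch of this method (the function name sort_and_count is mine):

def sort_and_count(a):
    # returns (sorted a, number of inversions in a)
    if len(a) <= 1:
        return a, 0
    mid = len(a) // 2                                  # divide
    sorted_left, inv_left = sort_and_count(a[:mid])    # conquer left
    sorted_right, inv_right = sort_and_count(a[mid:])  # conquer right
    merged, inv_cross = [], 0                          # combine: merge + count
    i, j = 0, 0
    while i < len(sorted_left) and j < len(sorted_right):
        if sorted_left[i] <= sorted_right[j]:
            merged.append(sorted_left[i])
            i += 1
        else:
            inv_cross += len(sorted_left) - i  # sorted_left[i:] are all > sorted_right[j]
            merged.append(sorted_right[j])
            j += 1
    merged += sorted_left[i:] + sorted_right[j:]
    return merged, inv_left + inv_right + inv_cross

sort_and_count([4, 1, 3, 2])  # => ([1, 2, 3, 4], 4)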
Caveat: for illustration purposes, we listed the inversion
pairs explicitly when increasing inv_cross
, but in reality
we can’t do that (otherwise it would cost \(O(n)\) per step instead of \(O(1)\)). In other words, we can count
the number of inversions in \(O(n\log
n)\) time, but we can’t collect all inversion pairs in that time;
the latter task has to be \(O(n^2)\)
because in the worst case you have that many inversion pairs (if the
input is inversely sorted)!
Exercise: How would you do the same problem with quicksort
instead of mergesort? Note that unlike mergesort, quicksort is
divide-heavy and combine-light, meaning most of the work is done in the
partition step, so you should also count the crossing inversions in the
partition step. (think about it: there are no crossing inversions
after partitioning, or at the combination step, because
everything in left
is smaller than everything in
right
).
A more interesting example in this “recursion with byproduct” paradigm is to use it to find the longest path in a binary tree (doesn’t need to be a binary search tree). For example, for this tree:
        4
      /   \
     2     6
    / \   / \
   1   3 5   7
  /           \
 0             9
              /
             8

the longest path is 0-1-2-4-6-7-9-8 with a length of 7 edges.
The first observation is we need to go as deep as possible on the two ends; obviously if you stop somewhere in the middle, it’s not optimal. The second “observation” that many students have when first looking at this problem is that the longest path has to go through the root, which is the case in the above example. But is it always the case?
Actually no! What if one side of the tree is tiny and the other side is huge? Then the longest path would be completely embedded in that bigger side. For example:
      2              2
     / \              \
    1   6              6
       / \            / \
      5   7          5   7
     /     \        /     \
    3       9      3       9
   /       /      /       /
  4       8      4       8
Clearly, for the left tree, the longest path would be 4-3-5-6-7-9-8, which does not go through the root. If we remove node 1, it would be even more obvious (the right tree).
So how do we solve this problem? Well, we can let each subtree return the longest path that is completely embedded in that subtree. We have three cases: the longest path of a tree is either (1) completely within the left subtree, (2) completely within the right subtree, or (3) going through the root, connecting the deepest node of the left subtree to the deepest node of the right subtree.
So you can (naively) write two recursive functions for this algorithm:
def depth(t):
    return 0 if t == [] else max(depth(t[0]), depth(t[2])) + 1

def longest(t):
    return 0 if t == [] else max(longest(t[0]), longest(t[2]), depth(t[0]) + depth(t[2]))
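(These two functions assume, judging from the indices t[0] and t[2], that a tree is represented as [left, value, right], with [] for the empty tree. For example:)

# a hypothetical three-node tree:    2
#                                   / \
#                                  1   3
tree = [[[], 1, []], 2, [[], 3, []]]
depth(tree)    # => 2
longest(tree)  # => 2 (the path 1-2-3 has 2 edges)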
However, this is a really bad solution, not just for aesthetic reasons. Its worst-case complexity is actually \(O(n^2)\) (see below)!
The better solution is to do these two tasks (longest and depth) together, just like sorting and counting inversions. Each node should return two numbers: not just the longest path, but also its depth. This way we guarantee the runtime is \(O(n)\) because it’s just a tree traversal (\(O(1)\) work for combining the subsolutions \((l_1, d_1)\) for left and \((l_2, d_2)\) for right to return \((l, d)\)).
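Here is a minimal sketch of that combined recursion, following the template from the beginning of this section (the function name longest_and_depth is mine):

def longest_and_depth(t):
    # returns (longest path in t, depth of t) in a single traversal
    if t == []:
        return 0, 0
    l1, d1 = longest_and_depth(t[0])  # conquer two tasks for the left subtree
    l2, d2 = longest_and_depth(t[2])  # conquer two tasks for the right subtree
    # combine in O(1): the longest path either avoids the root (is within
    # one subtree), or goes through it (deepest-left to deepest-right)
    return max(l1, l2, d1 + d2), max(d1, d2) + 1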
Now let’s see how bad the separate-recursion solution would be.
Clearly, depth()
is \(O(n)\), not \(O(1)\). So in the balanced case:
\[ T(n) = 2T(n/2) + O(n) = O(n\log n)\]
But in the worst-case (single chain):
\[ T(n) = T(n-1) + O(n) = O(n^2) \]
This is intuitive, because you call depth()
at each
node, so \(n + (n-1) + ... +
1=O(n^2)\). Note that these depth()
calls have many
redundant calculations, because depth(root)
would call
depth(root.left)
but depth(root.left)
was
already called when doing longest(root)
. If we could
“memorize” (or “memoize”) the work already done, then we can avoid all
these repetitions and get back to the \(O(n)\) total time, but that’s the topic for
Chapter 2 (Dynamic Programming).
In any case, recursion with a byproduct is simple, easy, and fast!
The counting inversions problem was taken from Roughgarden’s textbook (3.2). The longest path problem is a commonly asked interview question, but its analysis (for the bad solution) is quite interesting. This paradigm of “recursion with byproduct” was summarized by me.