## Tuesday, May 6, 2014

### Don't stop! I yield!

Riffing on the topic I touched on a few weeks ago, of questions on high-pressure exams, I wanted to share some thoughts about a fun programming problem that I'm not sure makes a great interview question -- but maybe that's an argument for why it does.

Consider the digits 123456789, in that order. Now consider the value of the expression resulting from inserting + or * between some of those digits: for example, you could insert + after 2 and 4 and * after 1 and 8, resulting in
```python
>>> 1*2+34+5678*9
51138
```
or insert + after 3, 6, and 8, resulting in
```python
>>> 123+456+78+9
666
```
Write a program which prints out all the expressions of this kind which evaluate to 2003. (Hint: there are exactly four.)

Now, the poster who referenced this problem said that he came up with it as a way to test whether candidates can represent a problem using appropriate library structures, and also whether they think to use recursion. I think the first goal is well-served here: there are many correct ways to implement the problem, and plenty of ways which are either truly wrong or result in programming headaches.

I'm less sure about the recursion aspect, but that's what prompted me to write the post, so more on that in due time.

Of course, the actual target value can vary; it is the responsibility of the interviewer/problem-poser to choose a value that can actually be expressed with these digits, and to know how many different such expressions there are. In particular, the phrasing of the problem in terms of "find all" instead of "find one" imposes significant programming constraints.
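For the problem-poser, that homework is itself a brute-force job. Below is a minimal sketch that evaluates all $$3^8 = 6561$$ expressions and tallies how many produce each value. The names all_expressions, DIGITS, and counts are made up for this illustration and are not part of the problem.

```python
from collections import Counter
from itertools import product

DIGITS = "123456789"  # the fixed digit string from the problem

def all_expressions(digits):
    """Yield every string formed by putting '+', '*', or nothing between the digits."""
    for ops in product(["+", "*", ""], repeat=len(digits) - 1):
        # Interleave: digits[0] ops[0] digits[1] ops[1] ... digits[-1]
        yield digits[0] + "".join(op + d for op, d in zip(ops, digits[1:]))

# Map each reachable value to the number of expressions that produce it.
counts = Counter(eval(expr) for expr in all_expressions(DIGITS))

print(len(counts), "distinct values from", sum(counts.values()), "expressions")
print("expressions evaluating to 2003:", counts[2003])
```

Any key of counts is a fair target. counts[target] tells the interviewer how many answers a candidate should be expected to find.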

OK, so what do you do if this problem gets pitched to you in an interview? Well, the first thing to do is represent the problem, and then implement a quick brute-force solution:
```python
digitlist = [str(x) for x in range(1,10)]
value = lambda L: eval(''.join(L))
```
It's simply easier to turn everything into strings, since I've chosen to evaluate an expression by concatenating its pieces into one string and handing that to eval().
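Calling eval() on strings you generated yourself is safe enough here. But since only + and * ever appear, evaluation is also easy to do by hand, and an interviewer may well ask you to. The sketch below is a hypothetical value_no_eval, which is not from the post. It assumes the same list-of-strings input that value takes, and a plain string works too.

```python
def value_no_eval(pieces):
    """Evaluate digits joined by '+' and '*' (usual precedence) without eval().

    `pieces` may be a list of strings, as passed to `value`, or a plain string.
    """
    expr = "".join(pieces)
    total = 0
    # '+' binds least tightly, so split on it first...
    for term in expr.split("+"):
        term_value = 1
        # ...then each term is a product of (possibly multi-digit) integers.
        for factor in term.split("*"):
            term_value *= int(factor)
        total += term_value
    return total

print(value_no_eval(["1", "*", "2", "+", "3", "4"]))  # 36
print(value_no_eval("123+456+78+9"))                  # 666
```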

Now we need some way to specify how we're inserting operations. I found it much easier to think of what we're doing not as possibly inserting one of two operators (+ and *), but as definitely inserting one of three (+, *, or a blank). Next, we need a quick function that interleaves a list (or other iterable) of nine elements with another list of eight, so that each operator lands between two digits:
```python
def longzip(long,short):
    retlist = list()
    longiter = iter(long)
    for item in short:
        retlist.append(next(longiter))
        retlist.append(item)
    retlist.append(next(longiter))
    return retlist
```
This approach avoids explicitly converting long and short to lists. Of course, if long doesn't have at least one more element than short (for instance, if it's empty), next() will raise StopIteration, which is fine.
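Here is longzip on a three-digit example. Next to it is a hypothetical alternative, interleave, built from the standard itertools.zip_longest and itertools.chain. Unlike longzip, interleave does not complain when the lengths don't match, so treat it as a sketch rather than a drop-in replacement.

```python
from itertools import chain, zip_longest

def longzip(long, short):
    """Interleave: long[0], short[0], long[1], short[1], ..., long[-1]."""
    retlist = list()
    longiter = iter(long)
    for item in short:
        retlist.append(next(longiter))
        retlist.append(item)
    retlist.append(next(longiter))
    return retlist

def interleave(long, short):
    """Hypothetical itertools variant: pad `short` with a sentinel, then drop it."""
    sentinel = object()
    pairs = zip_longest(long, short, fillvalue=sentinel)
    return [x for x in chain.from_iterable(pairs) if x is not sentinel]

digits = ["1", "2", "3"]
ops = ("+", "")  # '+' after the 1, nothing after the 2
print(longzip(digits, ops))     # ['1', '+', '2', '', '3']
print(interleave(digits, ops))  # ['1', '+', '2', '', '3']
```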
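With that helper in hand, the brute force simply tries every way of filling the eight gaps:
```python
from itertools import product as xp
OPS = xp(['*','+',''],repeat = 8)
for ops in OPS:
    t = longzip(digitlist,ops)
    if value(t) == 2003:
        print(''.join(t))
```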
This is a perfectly good brute-force solution. However, it obviously doesn't scale, and you should say so to the interviewer: if you have $$n$$ digits in your expression, there are $$3^{n-1}$$ candidate expressions to build and evaluate, so the running time grows exponentially in $$n$$. We definitely want to do better.
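Spelling out that count: each of the $$n-1$$ gaps between adjacent digits independently gets one of three symbols (+, *, or nothing). Each candidate then takes time proportional to $$n$$ to build and evaluate:

$$
\underbrace{3 \times 3 \times \cdots \times 3}_{n-1 \text{ gaps}} = 3^{n-1}, \qquad 3^{8} = 6561 \text{ for } n = 9, \qquad \text{total work } O\!\left(n \cdot 3^{n-1}\right).
$$

So every extra digit roughly triples the work.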

OK, if you're like me, you want to do two things: write a general function, and use recursion to make the problem simpler. Unfortunately, my first attempt didn't do so well:
```python
def express_wrong(L,v):
    for point in range(len(L)-1,-1,-1):
        if value(L[point:]) <= v:
            t1 = express_wrong(L[:point],v-value(L[point:]))
            if t1 is not None:
                return t1 + ['+'] + L[point:]

        if value(L[point:])%v == 0:
            t1 = express_wrong(L[:point],v//value(L[point:]))
            if t1 is not None:
                return t1 + ['*'] + L[point:]
```

There are a few problems with this function. The first is that it's ugly. That's not a fatal flaw, but it's not necessarily the easiest code to look at and see either what it's actually doing, or what it's trying to do. (Why the checks for None? Is there a way to cut out redundant calls to value()?)

A much deeper problem is that the function is incorrect: it doesn't return expressions with the desired value. (Exercise for the reader: figure out why not.)

And the third problem is that it doesn't even do what the problem asked for, which was to find all the expressions with the specified value. express_wrong() only finds one, and then exits.

#### Yield!

Enter the yield instruction.[1] It's almost like a replacement for the return instruction, except designed exactly for this situation of spitting back multiple values in sequence without kicking out of the function. In fact, after a function yields a value, it saves its entire state and goes into hibernation, waiting to be asked for another value to yield, at which point it picks right up where it left off.
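Here is a tiny illustration of that hibernation, plus the return-versus-yield contrast that limits express_wrong() to a single answer. The function names are made up for the example.

```python
def squares(limit):
    """Yield 0, 1, 4, ... one value per request."""
    for i in range(limit):
        yield i * i  # pause here; i and the loop position are kept for next time

gen = squares(4)
print(next(gen))  # 0
print(next(gen))  # 1, resumed right after the yield with i remembered
print(list(gen))  # [4, 9], the values that were still left


def first_match(items, pred):
    for x in items:
        if pred(x):
            return x  # leaves the function: later matches are never seen


def all_matches(items, pred):
    for x in items:
        if pred(x):
            yield x  # hands back x, then carries on at the next request


is_multiple_of_3 = lambda x: x % 3 == 0
print(first_match(range(10), is_multiple_of_3))        # 0
print(list(all_matches(range(10), is_multiple_of_3)))  # [0, 3, 6, 9]
```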

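Have you figured out why express_wrong() gave incorrect answers? Well, it's because not everything to the left of a multiplication sign gets multiplied by everything to the right. So we have to determine which operation binds least tightly, and use that operation as our splitting point. Here that's +: if we split at the last +, the value of the whole expression is exactly the value of everything to its left plus the value of the final term, which contains only * and concatenation.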
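Here is a small check of that reasoning, using a hypothetical helper split_at_last. Splitting 1+2*3 at its last * regroups it incorrectly. Splitting at the last + preserves the value, and the loop confirms this for every expression over 123456789 that contains a +.

```python
from itertools import product

def split_at_last(expr, op):
    """Split expr at the last occurrence of op; return (left, right), or None if op is absent."""
    i = expr.rfind(op)
    return None if i == -1 else (expr[:i], expr[i + 1:])

# Splitting at the last '*' regroups the expression as (1+2)*3.
left, right = split_at_last("1+2*3", "*")
print(eval(left) * eval(right))  # 9, but 1+2*3 is 7
# Splitting at the last '+' keeps the meaning: 1 + (2*3).
left, right = split_at_last("1+2*3", "+")
print(eval(left) + eval(right))  # 7

# Check the '+' split on every expression over 123456789 that contains a '+'.
digits = "123456789"
checked = 0
for ops in product(["+", "*", ""], repeat=len(digits) - 1):
    expr = digits[0] + "".join(op + d for op, d in zip(ops, digits[1:]))
    parts = split_at_last(expr, "+")
    if parts is not None:
        assert eval(parts[0]) + eval(parts[1]) == eval(expr)
        checked += 1
print("the '+' split held for", checked, "expressions")
```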

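```python
def express(L,v):
    if len(L) == 1 and value(L) == v:   # base case: one digit with the right value
        yield L
    elif len(L) > 1:
        for point in range(len(L)-1,-1,-1):   # point: where the final term (after the last '+') starts
            if eval('*'.join(L[point:])) > v:   # all-'*' is the smallest the final term can be (digits are 1-9)
                continue
            OPS = xp(['*',''],repeat = len(L)-1-point)
            for ops in OPS:   # every mix of '*' and concatenation inside the final term
                t2 = longzip(L[point:],ops)
                v2 = value(t2)
                if v2 > v:
                    continue
                elif v2 == v and point == 0:   # no '+' at all: the whole list is one term
                    yield t2
                else:   # solve the digits to the left for the remaining value
                    for t1 in express(L[:point],v-v2):
                        yield t1 + ['+'] + t2
```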

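Overall, the function is a triply-nested loop. The outer loop picks where the final term (everything after the last +) begins. The middle loop tries every mix of * and concatenation within that term. The inner loop runs over every expression the recursive call finds for the digits to its left. The yield instruction lets us step through all of these loops without needing complicated code to remember state. The base case of the recursion is in the second and third lines; the recursive call is in the second-to-last line.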

Finally, to view the expressions, we do:
```python
for t in express(digitlist,2003):
    print(''.join(t))
```
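Because the recursive version prunes so aggressively, it's worth convincing yourself (or the interviewer) that it still finds every answer. This sketch compares express() with an exhaustive search for a few targets. It repeats the definitions from above so that it runs on its own, and brute_force is a hypothetical helper name.

```python
from itertools import product as xp

digitlist = [str(x) for x in range(1, 10)]
value = lambda L: eval(''.join(L))

def longzip(long, short):
    retlist = list()
    longiter = iter(long)
    for item in short:
        retlist.append(next(longiter))
        retlist.append(item)
    retlist.append(next(longiter))
    return retlist

def express(L, v):
    # Same recursive generator as above.
    if len(L) == 1 and value(L) == v:
        yield L
    elif len(L) > 1:
        for point in range(len(L) - 1, -1, -1):
            if eval('*'.join(L[point:])) > v:
                continue
            for ops in xp(['*', ''], repeat=len(L) - 1 - point):
                t2 = longzip(L[point:], ops)
                v2 = value(t2)
                if v2 > v:
                    continue
                elif v2 == v and point == 0:
                    yield t2
                else:
                    for t1 in express(L[:point], v - v2):
                        yield t1 + ['+'] + t2

def brute_force(digits, v):
    """Hypothetical helper: every expression with value v, found by trying all operator choices."""
    found = []
    for ops in xp(['*', '+', ''], repeat=len(digits) - 1):
        t = longzip(digits, ops)
        if value(t) == v:
            found.append(''.join(t))
    return sorted(found)

for target in (666, 2003, 51138):
    fast = sorted(''.join(t) for t in express(digitlist, target))
    assert fast == brute_force(digitlist, target), target
    print(target, fast)
```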

[1] Yes, I know that in the Python standard, and a lot of other places in the CS world, these lines of code are known as "statements". That's fucking stupid. A "statement" is an utterance which is true or false. "Statement" should be a synonym for "boolean-typed expression". These are instructions.