26 March 2016

A dagger by any other name: scheduled sampling

Scheduled Sampling was at NIPS last year; the reviews are also online. (Note: I did not review this paper.) This is actually the third time I've tried to make my way through this paper, and to force myself to not give up again, I'm writing my impressions here. Given that this paper is about two things I know a fair amount about (imitation learning and neural networks), I kept getting frustrated at how hard it was for me to actually understand what was going on and how it related to things we've known for a long time. So this post is trying to ease entry for anyone else in that position.

What is the problem this paper is trying to solve?

Neural networks are now often used to generate text. As such, they are often trained to maximize a product of probabilities: the probability of each word in the output (the 5th, say) given all previous words. Unfortunately, the distribution of previous words on which they are trained is typically the gold standard distribution. This means that they never learn to recover from their own mistakes: at test time, if they make an error (which they will), they could end up in some distribution they've never seen at training time and behave arbitrarily badly. Matti Kaariainen proved over ten years ago that you can construct worst-case examples on which, even with a vanishingly small probability of error, the sequence predictor makes T/2 mistakes on a length-T binary sequence prediction problem. We've also known for about as long how to fix this (for models that are more general than neural networks), but more on that later.
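
To get a feel for why this happens (a back-of-the-envelope cartoon of my own, not Kaariainen's actual construction): suppose the predictor errs with probability epsilon per step on gold-standard prefixes, but is reduced to guessing (error 1/2) on any prefix that already contains one of its own mistakes. If the first mistake happens at position t, roughly half of the remaining T - t predictions are wrong too, so the expected number of mistakes is about

    \sum_{t=1}^{T} (1-\epsilon)^{t-1}\,\epsilon \left( 1 + \frac{T-t}{2} \right) \;\approx\; \frac{T}{2} \qquad \text{when } \epsilon T \gg 1 ,

i.e., a tiny per-step error rate measured on gold prefixes tells you very little about test-time behavior on long sequences.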

How do they solve this problem?

On the ith minibatch, for each output position, they flip an (independent) coin. If it comes up heads, they condition on the true word at that position (the training-data one); if it comes up tails, they condition on the model's own predicted word. If the coin probability is 1, this amounts to the standard (inconsistent) training paradigm. If the coin probability is 0, this amounts to always using the predicted words.
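
Concretely, the roll-in step looks something like the following sketch (my own code, not the authors'; the decoder_step interface is invented for illustration):

    import random

    def scheduled_sampling_rollin(decoder_step, gold, p_truth, start="<s>"):
        # decoder_step(prev_word, state) -> (predicted_word, new_state)   [hypothetical interface]
        # gold: the gold-standard output words
        # p_truth: probability of conditioning on the truth (decays with the minibatch index)
        prev, state, inputs = start, None, []
        for t in range(len(gold)):
            inputs.append(prev)
            pred, state = decoder_step(prev, state)
            # the coin flip: heads -> condition the next prediction on the gold word,
            # tails -> condition it on the model's own prediction
            prev = gold[t] if random.random() < p_truth else pred
        return inputs   # p_truth = 1 recovers gold prefixes; p_truth = 0 gives pure model roll-in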

How does this relate to existing stuff?

To put this in terminology that I understand better, they're playing with the roll-in distribution, adjusting it between "Ref" (when the probability is one) and "Learn" (when the probability is zero) and, as is natural, "Mix" when the probability is somewhere in the middle.

I think (but please correct me if I'm wrong!) that Searn was the first algorithm that considered mixture policies for roll-in. (Previous algorithms, like the incremental perceptron, LaSO, etc., used either Ref or Learn.) Of course no one should actually be using Searn anymore (algorithms like DAgger, AggreVaTe and LOLS completely dominate it as far as I can tell).

For what it's worth, I don't think the "related work" section of the paper is really particularly accurate. Searn was never a reinforcement learning algorithm (where do you get the oracle in RL?), and the paper completely dismisses DAgger, referring to it only as "a related idea" and highlighting the fact that the "scheduled sampling" approach is online. (Side note: the sense of "online" used in the scheduled sampling paper is the "stochastic" sense.) Of course, DAgger can be trained online, and in fact that's how we do it in the vw implementation... and it works incredibly well! You can also do a straight-up online analysis.

The only real difference from DAgger is the exact form of the schedule. DAgger uses a schedule of something like P(use truth) = 0.99^i, where i is the round/minibatch/whatever. Scheduled sampling considers two other rates, one linear and one inverse sigmoid; all three are plotted in Figure 2 of the scheduled sampling paper, where the red curve is the DAgger-style exponential schedule.
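
For reference, here's a small sketch of the three decay schedules as I understand them from the paper; the specific constants below are illustrative guesses, not the paper's settings:

    import math

    def exponential_decay(i, k=0.99):
        # DAgger-style schedule: P(use truth) = k**i, with 0 < k < 1
        return k ** i

    def inverse_sigmoid_decay(i, k=10.0):
        # starts near 1, decays smoothly; k >= 1 controls how fast
        return k / (k + math.exp(i / k))

    def linear_decay(i, slope=0.01, floor=0.05):
        # linear decay that bottoms out at a nonzero floor (the paper's epsilon)
        return max(floor, 1.0 - slope * i)

The qualitative difference that matters below is that the first two eventually go to zero, while the linear schedule bottoms out at its floor.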

What are the results like?

There are basically two novelties in this paper: (1) applying DAgger to training neural networks, and (2) trying different schedules.

I would expect results that analyze the importance of these two things, but I only see a partial analysis of (1). In particular, the results show that training neural networks with DAgger is better than training them with supervised learning. This is nice to know, and is a nice replication of the results from Ross, Gordon and Bagnell. It would also be nice to see the other side: that going neural here is beneficial, but I suspect that's just taken for granted.

On (2), nothing is said! For two of the three experiments (image captioning and constituency parsing), the paper says that they use inverse sigmoid decay. For the third, the paper uses the linear decay schedule. Presumably the others performed worse, but I would really have appreciated knowing whether there's any benefit to doing something other than the DAgger schedule. And if there's benefit, how big?

In particular, the DAgger analysis works for any schedule in which the average of the per-round probabilities goes to zero as the number of rounds goes to infinity. This does not happen for the linear schedule (the average goes to some lower bound, epsilon). I'm pretty sure it happens for the sigmoid schedule but didn't bother to verify. It would have been nice if the paper had verified this. At any rate, if I'm going to switch from something that I understand theoretically (like DAgger with exponential decay) to something that's heuristic, I would like to know how much empirical benefit I'm getting in exchange for losing my understanding of what's going on.
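
For concreteness, the condition I have in mind from the DAgger analysis is (roughly) that the average mixing probability vanishes:

    \frac{1}{N} \sum_{i=1}^{N} \beta_i \;\longrightarrow\; 0 \qquad \text{as } N \to \infty ,

where beta_i is the probability of using the truth on round i. A schedule like beta_i = 0.99^i satisfies this; a linear schedule with a nonzero floor epsilon does not, since its average converges to (at least) epsilon.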

What's missing

One major difference from DAgger is that this paper makes a weird assumption that the correct next thing to do is independent of what was done in the past. For example, suppose the gold standard output is "The little bird flew into the mailbox." Suppose that for the first word, the schedule uses the true output. We now have an output of "The". For the second word, maybe the schedule uses the predicted output. Maybe this is "bird". For any reasonable evaluation criterion, probably the best completion of "The bird" with respect to this gold standard is "The bird flew into the mailbox." But the approach proposed here would insist on producing "The bird bird flew into the mailbox."

This is the question of computing the oracle completion, or the "minimum cost action" at any given point. Yoav Goldberg refers to this as a "dynamic oracle." We know how to do this for many loss functions, and I'm surprised (a) that this isn't an issue in the experiments here, and (b) that it's not mentioned.
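
To make "dynamic oracle" concrete, here's a minimal sketch of one possible oracle for an edit-distance-like loss; this is my own illustration (the helper names are made up), not anything from the scheduled sampling paper or from Goldberg's parsing oracles:

    def edit_distance(a, b):
        # standard Levenshtein distance between two token sequences
        d = list(range(len(b) + 1))
        for i in range(1, len(a) + 1):
            prev, d[0] = d[0], i
            for j in range(1, len(b) + 1):
                cur = min(d[j] + 1,                          # delete a[i-1]
                          d[j - 1] + 1,                      # insert b[j-1]
                          prev + (a[i - 1] != b[j - 1]))     # substitute (or match)
                prev, d[j] = d[j], cur
        return d[len(b)]

    def dynamic_oracle_next(prefix, gold):
        # align the produced prefix to the gold prefix it matches best
        # (ties broken toward the longest), then emit the next gold word
        best_j = max(range(len(gold) + 1),
                     key=lambda j: (-edit_distance(prefix, gold[:j]), j))
        return gold[best_j] if best_j < len(gold) else None

    def static_oracle_next(prefix, gold):
        # what scheduled sampling implicitly assumes: the right next word is
        # the gold word at this position, no matter what came before
        return gold[len(prefix)] if len(prefix) < len(gold) else None

    gold = "The little bird flew into the mailbox .".split()
    print(dynamic_oracle_next("The bird".split(), gold))  # 'flew'
    print(static_oracle_next("The bird".split(), gold))   # 'bird' -> "The bird bird flew ..."

The point is just that the best next action depends on the prefix actually produced; for real losses (BLEU, labeled attachment, etc.) the oracle is loss-specific.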

Reviewing the Reviews

The reviewers seem to mostly have liked the paper, which makes sense: it is interesting. I'm really, really surprised that none of the reviewers took the authors to task on the connections to DAgger, and instead concentrated on Searn. I suspect that the reviewers were neural networks experts and not imitation learning experts, and given that the paper talked only about Searn (which at this point is an easy strawman), the reviews tended to focus on that.

The biggest surprise to me in the reviews was the comment that the paper had "comprehensive empirical testing," also echoed by reviewer 3. (Reviewer 2 was basically a no-op.) This paper had impressive experiments on big, hard problems, but the experiments were anything but comprehensive from the perspective of understanding what is going on.

Some last thoughts

Let me say that although what's written above is critical, I of course don't dislike this work. It's something that needed to be done, and it's great someone did it. I wish that it were just done better and that the paper made some attempt at situating itself in the literature.

For comparison, take a gander at the recent arxiv paper on training an LSTM using imitation learning/learning-to-search technology by Ballesteros, Goldberg, Dyer and Smith. Not only does that paper get really impressive empirical results, it is also remarkably straightforward about connections to previous work.