26 March 2016

A dagger by any other name: scheduled sampling

Scheduled Sampling was at NIPS last year; the reviews are also online. (Note: I did not review this paper.) This is actually the third time I've tried to make my way through this paper, and to force myself to not give up again, I'm writing my impressions here. Given that this paper is about two things I know a fair amount about (imitation learning and neural networks), I kept getting frustrated at how hard it was for me to actually understand what was going on and how it related to things we've known for a long time. So this post is trying to ease entry for anyone else in that position.

What is the problem this paper is trying to solve?

Neural networks are now often used to generate text. As such, they are often trained to maximize a product of probabilities: the probability of the tth word in the output given all previous words. Unfortunately, the distribution on which they are trained is typically the gold standard distribution of previous words. This means that they never learn to recover from their own mistakes: at test time, if they make an error (which they will), they could end up in some distribution they've never seen at training time and behave arbitrarily badly. Matti Kaariainen proved over ten years ago that you can construct worst case examples on which, even with vanishingly small probability of error, the sequence predictor makes T/2 mistakes on a length T binary sequence prediction problem. We've also known for about as long how to fix this (for models that are more general than neural networks), but more on that later.
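To make the mismatch concrete, here is one way to write it down. The notation is an assumption of this note, not the paper's: x is the input, y* is the gold-standard output of length T, and y-hat is the model's own output.

```latex
% Training objective: each factor conditions on the GOLD previous words.
\log p_\theta(y^* \mid x) \;=\; \sum_{t=1}^{T} \log p_\theta\!\left(y^*_t \,\middle|\, y^*_{<t},\, x\right)

% Test time: each prediction conditions on the model's OWN previous words.
\hat{y}_t \;=\; \arg\max_{y} \; p_\theta\!\left(y \,\middle|\, \hat{y}_{<t},\, x\right)
```

The conditioning context differs between the two lines, which is exactly where the "distribution never seen at training time" comes from.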

How do they solve this problem?

On the ith minibatch, for each output word, they flip an (independent) coin. If it comes up heads, they use the true ith word when training (the training data one); if it comes up tails, they use the predicted ith word. If the coin probability is 1, then this amounts to the standard (inconsistent) training paradigm. If the coin probability is 0, then this amounts to always using the predicted words.
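A minimal sketch of the coin flip, under stated assumptions: `predict_next` is a hypothetical stand-in for the model (it maps the words fed so far to a guess for the next word), and `p_truth` is the coin probability for the current minibatch. This is an illustration of the idea, not the paper's implementation.

```python
import random


def scheduled_roll_in(gold, predict_next, p_truth, rng):
    """Return (fed, targets) for one training sequence.

    gold         -- list of gold-standard words
    predict_next -- hypothetical model stand-in: history (list) -> word
    p_truth      -- probability that the coin comes up heads (feed the gold word)
    rng          -- a random.Random instance, so runs are reproducible
    """
    fed = []      # words actually used as conditioning context
    targets = []  # words the model is trained to produce
    for gold_word in gold:
        predicted = predict_next(fed)   # model's guess at this position
        targets.append(gold_word)       # the target is always the gold word
        heads = rng.random() < p_truth  # independent coin per output word
        fed.append(gold_word if heads else predicted)
    return fed, targets


gold = "The little bird flew into the mailbox".split()
always_bird = lambda history: "bird"    # toy model, purely for illustration
fed, targets = scheduled_roll_in(gold, always_bird, 0.5, random.Random(0))
print(fed)
```

With `p_truth = 1.0` the fed words are exactly the gold words (standard training); with `p_truth = 0.0` they are exactly the model's predictions.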

How does this relate to existing stuff?

To put this in terminology that I understand better, they're playing with the roll-in distribution, adjusting it from "Ref" (when the probability is one) to "Learn" (when the probability is zero) and, as is natural, "Mix" when the probability is somewhere in the middle.

I think (but please correct me if I'm wrong!) that Searn was the first algorithm that considered mixture policies for roll-in. (Previous algorithms, like incremental perceptron, LaSO, etc., used either Ref or Learn.) Of course no one should be actually using Searn anymore (algorithms like DAgger, AggreVaTe and LOLS completely dominate it as far as I can tell).

For what it's worth, I don't think the "related work" section of the paper is really particularly accurate. Searn was never a reinforcement learning algorithm (where do you get the oracle in RL?), and the paper completely dismisses DAgger, referring to it only as "a related idea" and highlighting the fact that the "scheduled sampling" approach is online. (Side note: the sense of "online" used in the scheduled sampling paper is the "stochastic" sense.) Of course, DAgger can be trained online, and in fact that's how we do it in the vw implementation... and it works incredibly well! You can also do a straight-up online analysis.

The only real difference to DAgger is the exact form of the schedule. DAgger uses a schedule of something like P(use truth) = 0.99^i, where i is the round/minibatch/whatever. Scheduled sampling considers two other rates, one linear and one inverse sigmoid. They're shown in Figure 2 of the scheduled sampling paper, where the red curve is the DAgger schedule.
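For readers without the figure at hand, the three schedules can be sketched as below. The functional forms follow the scheduled sampling paper as best I recall them; the constants (`k`, `c`, `eps`) are illustrative assumptions, not values taken from the paper.

```python
import math


def exponential_decay(i, k=0.99):
    """DAgger-style schedule: P(use truth) = k**i, with k < 1."""
    return k ** i


def linear_decay(i, k=1.0, c=0.001, eps=0.1):
    """Linear schedule with a floor: P(use truth) = max(eps, k - c*i)."""
    return max(eps, k - c * i)


def inverse_sigmoid_decay(i, k=100.0):
    """Inverse sigmoid schedule: P(use truth) = k / (k + exp(i/k)), k >= 1.

    The exponent is clamped so math.exp does not overflow for large i.
    """
    return k / (k + math.exp(min(i / k, 700.0)))


for i in (0, 100, 500, 1000, 5000):
    print(i,
          round(exponential_decay(i), 4),
          round(linear_decay(i), 4),
          round(inverse_sigmoid_decay(i), 4))
```

All three start near one and decrease with the round i; only the linear schedule stops at a floor (`eps`) instead of continuing towards zero.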

What are the results like?

There are basically two novelties in this paper: (1) applying DAgger to training neural networks, and (2) trying different schedules.

I would expect results that analyse the importance of these two things, but I only see a partial analysis of (1). In particular, the results show that training neural networks with DAgger is better than training them with supervised learning. This is nice to know, and is a nice replication of the results from Ross, Gordon and Bagnell. It would also be nice to see the other side: that going neural here is beneficial, but I suspect that's just taken for granted.

On (2), nothing is said! For two of the three experiments (image captioning and constituency parsing), the paper says that they use inverse sigmoid decay. For the third, the paper uses the linear decay schedule. Presumably the others performed worse, but I would really have appreciated knowing whether there's any benefit to doing something other than the DAgger schedule. And if there's benefit, how big?

In particular, the DAgger analysis works for any schedule in which the average of the probabilities over rounds goes to zero as the number of rounds goes to infinity. This does not happen for the linear schedule (the average goes to some lower bound, epsilon). I'm pretty sure it happens for the sigmoid schedule but didn't bother to verify. It would have been nice if the paper had verified this. At any rate, if I'm going to switch from something that I understand theoretically (like DAgger with exponential decay) to something that's heuristic, I would like to know how much empirical benefit I'm getting in exchange for losing my understanding of what's going on.
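A quick numerical sanity check (not a proof) is easy to run. The DAgger paper states its condition on the average of the per-round probabilities, (1/N) times the sum of beta_i for i = 1..N, which should tend to zero as N grows. The sketch below computes that average for each schedule; the schedule constants are illustrative assumptions, not values from the scheduled sampling paper.

```python
import math


def exponential_decay(i, k=0.99):
    return k ** i


def linear_decay(i, k=1.0, c=0.001, eps=0.1):
    return max(eps, k - c * i)


def inverse_sigmoid_decay(i, k=100.0):
    # Clamp the exponent so math.exp does not overflow for large i.
    return k / (k + math.exp(min(i / k, 700.0)))


def average_truth_probability(schedule, num_rounds):
    """(1/N) * sum of schedule(i) for i = 1..N."""
    total = sum(schedule(i) for i in range(1, num_rounds + 1))
    return total / num_rounds


schedules = {
    "exponential": exponential_decay,
    "linear": linear_decay,
    "inverse sigmoid": inverse_sigmoid_decay,
}
for name, schedule in schedules.items():
    averages = [average_truth_probability(schedule, n)
                for n in (10 ** 3, 10 ** 4, 10 ** 5)]
    print(name, ["%.4f" % a for a in averages])
```

With these constants, the averages for the exponential and inverse sigmoid schedules keep shrinking as N grows, while the linear schedule's average levels off just above its floor `eps`. That is consistent with the claim above, though it only checks one choice of constants.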

What's missing

One major difference to DAgger is that this paper makes a weird assumption that the correct next thing to do is independent of what was done in the past. For example, suppose the gold standard output is "The little bird flew into the mailbox." Suppose that for the first word, the schedule uses the true output. We now have an output of "The". For the second word, maybe the schedule uses the predicted output. Maybe this is "bird". For any reasonable evaluation criterion, probably the best completion of "The bird" with respect to this gold standard is "The bird flew into the mailbox." But the approach proposed here would insist on producing "The bird bird flew into the mailbox."

This is the question of computing the oracle completion, or the "minimum cost action" at any given point. Yoav Goldberg refers to this as a "dynamic oracle." We know how to do this for many loss functions, and I'm surprised that (a) this isn't an issue in the experiments here, and (b) that it's not mentioned.
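As an illustration of what computing the minimum cost action can look like, here is a small sketch for one particular loss. The loss is an assumption of this sketch: word-level edit distance with insertions and deletions only (no substitutions). It is not the loss used in the paper's experiments, and the set of optimal actions depends on which loss you pick. The idea is that the best achievable final cost after a given prefix is the smallest distance between that prefix and any prefix of the gold output, and the optimal next words are the gold words that follow the best-matching gold prefixes.

```python
END = "</s>"  # hypothetical end-of-sequence marker


def prefix_distances(prefix, gold):
    """Last row of the edit-distance table (insertions and deletions only).

    Entry j is the cost of turning `prefix` into gold[:j].
    """
    row = list(range(len(gold) + 1))      # empty prefix vs. gold[:j]
    for i, word in enumerate(prefix, start=1):
        new_row = [i]                      # prefix[:i] vs. empty gold
        for j, gold_word in enumerate(gold, start=1):
            cost = min(row[j] + 1,         # delete `word` from the prefix
                       new_row[j - 1] + 1) # insert `gold_word`
            if word == gold_word:
                cost = min(cost, row[j - 1])  # free match
            new_row.append(cost)
        row = new_row
    return row


def oracle_next_words(prefix, gold):
    """Set of next words that keep the best achievable final cost."""
    row = prefix_distances(prefix, gold)
    best = min(row)
    return {gold[j] if j < len(gold) else END
            for j, cost in enumerate(row) if cost == best}


gold = "The little bird flew into the mailbox".split()
print(oracle_next_words(["The", "bird"], gold))
```

For the prefix "The bird", this returns "little" and "flew" (each leads to a final cost of one under this loss), and not "bird", which is what training against the gold word at that position would ask for.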

Reviewing the Reviews

The reviewers seem to mostly have liked the paper, which makes sense: it is interesting. I'm really really surprised that none of the reviewers took the authors to task on the connections to DAgger, and instead concentrated on Searn. I suspect that the reviewers were neural networks experts and not imitation learning experts, and given that the paper talked only about Searn (which at this point is an easy strawman), the reviews tended to focus on that.

The biggest surprise to me in the reviews was the comment that the paper had "comprehensive empirical testing," also echoed by reviewer 3. (Reviewer 2 was basically a no-op.) This paper had impressive experiments on big hard problems, but the experiments were anything but comprehensive from the perspective of understanding what is going on.

Some last thoughts

Let me say that although what's written above is critical, I of course don't dislike this work. It's something that needed to be done, and it's great someone did it. I wish that it were just done better and that the paper made some attempt at situating itself in the literature.

For comparison, take a gander at the recent arxiv paper on training an LSTM using imitation learning/learning-to-search technology by Ballesteros, Goldberg, Dyer and Smith. Not only does this paper get really impressive empirical results, they are also remarkably straightforward about connections to previous work.


Justin said...

Some discussion at reddit: https://www.reddit.com/r/MachineLearning/comments/4c35fd/a_dagger_by_any_other_name_scheduled_sampling/

Unknown said...

> "(where do you get the oracle in RL?)"

Oracle here just means labeled examples, right?

It's been my impression that learning from demonstrations in the RL setting is either: (1) behavioral cloning / imitation learning, where the goal is to learn some policy from a policy class, or (2) inverse reinforcement learning where the goal is to learn the reward function that best explains the choices.

So, then, an oracle in the RL sense is just one of these two things.

Also, I've always been curious why the comparison to the IRL literature is not really present in your papers. For example, [1] compares various IRL algorithms for the task of parsing.

I also found this statement from the Scheduled Sampling paper a bit off:

> "Furthermore, SEARN has been proposed in the context of reinforcement learning"

I've always taken your approach as performing extra exploration when one has a supervised search signal. The phrase "context of reinforcement learning" would indicate to me that the base approach is to explore and the contribution is the incorporation of supervised signals. This seems off.

But, aside from the inaccuracy I see in that statement for describing your paper, I also disagree with how it partitions the research space. In the context of inverse reinforcement learning and imitation learning, using only the supervised signal is a standard approach. (For example, [2] is the reinforcement learning context that comes to mind.)

[1] Neu, Gergely, and Csaba Szepesvári. "Training parsers by inverse reinforcement learning." Machine learning 77.2-3 (2009): 303-337.

[2] MacGlashan, James, and Michael L. Littman. "Between imitation and intention learning." Proceedings of the 24th International Conference on Artificial Intelligence. AAAI Press, 2015.

Unknown said...

I agree with your comments. We recently had a paper on the same topic that made similar considerations, see http://arxiv.org/pdf/1511.06732.pdf
In this work, we tackle the issue of text generation using incremental learning and REINFORCE to a) use the model predictions at training time and b) optimize for the metric of interest (e.g., BLEU).
Please, let us know if you have suggestions regarding how we relate to previous work.

Two notes:
1) Bengio et al NIPS 2015 is perhaps even more related to Venkatraman, A., Hebert, M., and Bagnell, J.A. Improving multi-step prediction of learned time series models. In AAAI, 2015, because here they also predict the next ground truth action.
2) We actually tried in our work to use a "dynamic oracle" but it did not work very well in our case (and therefore, it was not reported in the paper). When you optimize BLEU and try to find the optimal sequence completing your given prefix, you get sentences that often are not grammatical. Also, there are tokens that are rather frequent like "the" and the oracle keeps predicting them if you miss them (because that would increase your unigram counts, for instance). However, this skews the distribution of labels and makes convergence of sgd really hard.