24 August 2016

Debugging machine learning

I've been thinking, mostly in the context of teaching, about how to specifically teach debugging of machine learning. Personally I find it very helpful to break things down in terms of the usual error terms: Bayes error (how much error the best possible classifier makes), approximation error (how much you pay for restricting to some hypothesis class), estimation error (how much you pay because you only have finite samples), and optimization error (how much you pay because you didn't find a global optimum of your optimization problem). I've generally found it useful to try to isolate errors to one of these pieces and then debug that piece in particular (e.g., pick a better optimizer versus pick a better hypothesis class).
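For reference, here is the standard way these four terms fit together. The notation is introduced only for this sketch: let $R(f)$ be the true (test-distribution) error of a classifier $f$, $f^*$ the Bayes-optimal classifier, $f^*_{\mathcal H}$ the best classifier in the hypothesis class $\mathcal H$, $f_n$ the empirical risk minimizer over $\mathcal H$ on $n$ training samples, and $\tilde f_n$ whatever your optimizer actually returns. The error of the model you ship then telescopes into exactly the four pieces:

```latex
R(\tilde f_n) \;=\; \underbrace{R(f^*)}_{\text{Bayes}}
  + \underbrace{\big[R(f^*_{\mathcal H}) - R(f^*)\big]}_{\text{approximation}}
  + \underbrace{\big[R(f_n) - R(f^*_{\mathcal H})\big]}_{\text{estimation}}
  + \underbrace{\big[R(\tilde f_n) - R(f_n)\big]}_{\text{optimization}}
```

The identity itself is trivial (everything cancels); its value is that each bracket points at a different fix: better features or information, a richer hypothesis class, more data, or a better optimizer.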

For instance, my general debugging strategy involves steps like the following:

  1. First, ensure that your optimizer isn't the problem. You can do this by adding a "cheating" feature: a feature that correlates perfectly with the label. Make sure you can successfully overfit the training data. If you can't, this is probably either an optimizer problem or a too-small-sample problem.
  2. Remove all the features except the cheating feature and make sure you can overfit then. Assuming that works, add features back in incrementally (usually at an exponential rate). If at some point things stop working, then you probably have too many features or too little data.
  3. Remove the cheating features and make your hypothesis class much bigger, e.g., by adding lots of quadratic features. Make sure you can overfit. If you can't, maybe you need a better hypothesis class.
  4. Cut the amount of training data in half. Test accuracy usually asymptotes as the training data size increases, so if cutting the training data in half has a huge effect, you haven't reached the asymptote yet and you might do better to get more data.
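Here is a minimal sketch of those four checks in Python. It's a sketch under assumptions, not a recipe: a plain perceptron stands in for whatever model you're actually debugging, labels are assumed to be +1/-1, and the helper names (`add_cheating_feature`, `feature_schedule`, `quadratic_expand`, `halving_check`) are hypothetical, made up for illustration.

```python
import random

# Hypothetical helpers for illustration; a perceptron stands in for "your model".
def train_perceptron(X, y, epochs=100):
    """Return weights (bias last) after at most `epochs` passes; labels are +1/-1."""
    w = [0.0] * (len(X[0]) + 1)
    for _ in range(epochs):
        mistakes = 0
        for x, label in zip(X, y):
            xb = x + [1.0]  # constant bias feature
            score = sum(wi * xi for wi, xi in zip(w, xb))
            if label * score <= 0:  # wrong side of (or on) the boundary
                w = [wi + label * xi for wi, xi in zip(w, xb)]
                mistakes += 1
        if mistakes == 0:  # a full pass with no mistakes: training data is fit
            break
    return w

def accuracy(w, X, y):
    """Fraction of examples whose predicted sign matches the +1/-1 label."""
    hits = 0
    for x, label in zip(X, y):
        score = sum(wi * xi for wi, xi in zip(w, x + [1.0]))
        hits += (1 if score > 0 else -1) == label
    return hits / len(y)

def add_cheating_feature(X, y):
    """Step 1: append a feature that is exactly the label."""
    return [x + [float(label)] for x, label in zip(X, y)]

def feature_schedule(num_features):
    """Step 2: how many real features to add back at each round (1, 2, 4, ..., all)."""
    sizes, k = [], 1
    while k < num_features:
        sizes.append(k)
        k *= 2
    return sizes + [num_features]

def quadratic_expand(x):
    """Step 3: enlarge the hypothesis class with all products x_i * x_j, i <= j."""
    n = len(x)
    return x + [x[i] * x[j] for i in range(n) for j in range(i, n)]

def halving_check(X_tr, y_tr, X_te, y_te, seed=0):
    """Step 4: test accuracy with all training data vs. with a random half of it."""
    idx = list(range(len(y_tr)))
    random.Random(seed).shuffle(idx)
    half = idx[: len(idx) // 2]
    full = accuracy(train_perceptron(X_tr, y_tr), X_te, y_te)
    halved = accuracy(train_perceptron([X_tr[i] for i in half],
                                       [y_tr[i] for i in half]), X_te, y_te)
    return full, halved

# Usage on synthetic data with random labels (no real signal at all):
rng = random.Random(0)
X = [[rng.uniform(-1, 1) for _ in range(5)] for _ in range(200)]
y = [rng.choice([-1, 1]) for _ in range(200)]
X_cheat = add_cheating_feature(X, y)
train_acc = accuracy(train_perceptron(X_cheat, y), X_cheat, y)
print(train_acc)  # if this is not 1.0, suspect the optimizer
```

With random labels there is nothing real to learn, so the only route to perfect training accuracy is the cheating feature; if even that fails, the optimizer is the first suspect. The perceptron is used only because it's short and provably converges on separable data like this.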

The problem is that this standard breakdown of error terms comes from theory land, and, well, sometimes theory misses things because of the particular abstraction it has adopted. Typically that abstraction is that the overall goal has already been reduced to an iid/PAC-style learning problem, so some types of error are invisible because the abstraction hides them.

In an effort to try to understand this better, I tried to make a flow chart of sorts that encompasses all the various types of error I could think of that can sneak into a machine learning system. This is shown below:
I've tried to give reasonable names to the steps (the left part of each box) and then give a grounded example in the context of ad placement (because it's easy to think about). I'll walk through the steps (1-11) and try to say something about what sort of error can arise at each step.
  1. In the first step, we take our real world goal of increasing revenue for our company and decide to solve it by improving our ad displays. This immediately upper bounds how much increased revenue we can hope for because, well, maybe ads are the wrong thing to target. Maybe I would do better by building a better product. This is sort of a "business" decision, but it's perhaps the most important question you can ask: am I even going after the right things?
  2. Once you have a real world mechanism (better ad placement) you need to turn it into a learning problem (or not). In this case, we've decided that the way we're going to do this is by trying to predict clickthrough, and then use those predictions to place better ads. Is clickthrough a good thing to use to predict increased revenue? This itself is an active research area. But once you decide that you're going to predict clickthrough, you suffer some loss because of a mismatch between that prediction task and the goal of better ad placement.
  3. Now you have to collect some data. You might do this by logging interactions with a currently deployed system. This introduces all sorts of biases because the data you're collecting is not from the final system you want to deploy (the one you're building now), and you will pay for this in terms of distribution drift.
  4. You cannot possibly log everything that the current system is doing, so you can only log a subset of things. Perhaps you log queries, ads, and clicks. This hides any information that you didn't log; for instance, time of day or day of week might be relevant, user information might be relevant, etc. Again, this upper bounds your best possible revenue.
  5. You then usually pick a data representation, for instance quadratic terms between a bag of words on the query side and a bag of words on the ad side, paired with a +/- label for whether the user clicked or not. We're now getting to the point where we can start using theory words: this choice basically limits the best achievable Bayes error. If you included more information, or represented it better, you might be able to get a lower Bayes error.
  6. You also have to choose a hypothesis class. I might choose decision trees. This is where my approximation error comes from.
  7. We have to pick some training data. The real world is basically never i.i.d., so any data we select is going to have some bias. It might not be identically distributed with the test data (because things change month to month, for instance). It might not be independent (because things don't change much second to second). You will pay for this.
  8. You now train your model on this data, probably tuning hyperparameters too. This is your usual estimation error.
  9. We now pick some test data on which to measure performance. Of course, this test data only tells you how well your system will do in the future if it is representative of that future data. In practice it won't be, if only because of concept drift over time.
  10. After we make predictions on this test data, we have to choose some method for evaluating success. We might use accuracy, F-measure, area under the ROC curve, etc. The degree to which these measures correlate with what we really care about (ad revenue) determines how well we're able to capture the overall task. If the measure anti-correlates, for instance, we'll head downhill rather than uphill.
(Minor note: although I put these in a specific order, that's not a prescriptive order, and many can be swapped. Also, of course there are lots of cycles and dependencies here as one continues to improve systems.)
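Step 10 is easy to underestimate. As a toy illustration (the click labels and scores below are invented, not real data), here's a Python sketch in which two standard metrics rank the same two predictors in opposite orders on imbalanced click data. The function names `accuracy_at` and `roc_auc` are hypothetical; the AUC is computed directly from its pairwise definition rather than with any particular library.

```python
def accuracy_at(labels, scores, threshold=0.5):
    """Fraction of examples where (score > threshold) matches the 0/1 click label."""
    return sum((s > threshold) == (y == 1) for y, s in zip(labels, scores)) / len(labels)

def roc_auc(labels, scores):
    """Probability that a random click outscores a random non-click (ties count 1/2)."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    total = 0.0
    for p in pos:
        for n in neg:
            total += 1.0 if p > n else (0.5 if p == n else 0.0)
    return total / (len(pos) * len(neg))

# Toy data: one click out of ten impressions (invented for illustration).
clicks = [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
never_clicks = [0.0] * 10  # always predicts "no click"
ranker = [0.8, 0.6, 0.6, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]  # ranks the click first

for name, scores in [("never_clicks", never_clicks), ("ranker", ranker)]:
    print(name, accuracy_at(clicks, scores), roc_auc(clicks, scores))
```

Accuracy prefers `never_clicks` (0.9 vs. 0.8) while AUC prefers `ranker` (0.5 vs. 1.0), so which model looks better depends entirely on the metric chosen, and neither metric is ad revenue.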

Some of these things are active research areas. Things like sample selection bias, domain adaptation, and covariate shift have to do with mismatch between train and test data. For instance, if I can overfit train but generalization is horrible, I'll often pool train and test, randomly reshuffle them into a new split, and see if generalization is better. If it is, there's probably an adaptation problem.
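The reshuffle trick is easy to automate. Below is a minimal Python sketch, assuming you can wrap your pipeline in a `fit_and_score(train, test)` function (a hypothetical interface, not from any library). The toy learner is a one-feature threshold classifier, and the toy data deliberately simulates concept drift: the "train" period labels x > 3 as positive, while the "test" period labels x > 7 as positive.

```python
import random

def reshuffle_check(train, test, fit_and_score, seed=0):
    """Score on the original split vs. on a random re-split of the pooled data.

    The re-split keeps the original train/test sizes. A much better score on the
    re-split suggests a train/test mismatch rather than plain overfitting.
    """
    pooled = train + test  # a new list, so the inputs are not mutated
    random.Random(seed).shuffle(pooled)
    new_train, new_test = pooled[:len(train)], pooled[len(train):]
    return fit_and_score(train, test), fit_and_score(new_train, new_test)

def fit_threshold_and_score(train, test):
    """Toy learner: predict 1 when x > t, with t picked from the training x's
    to minimize training errors; returns accuracy on `test`."""
    def errors(t, data):
        return sum((x > t) != (y == 1) for x, y in data)
    t = min(sorted(x for x, _ in train), key=lambda c: errors(c, train))
    return 1 - errors(t, test) / len(test)

# Toy drift: same input range, but the labeling rule moves between periods.
train = [(i / 20, int(i / 20 > 3)) for i in range(200)]
test = [(i / 20 + 0.025, int(i / 20 + 0.025 > 7)) for i in range(200)]

original, reshuffled = reshuffle_check(train, test, fit_threshold_and_score)
print(original, reshuffled)
```

Here the original split scores 0.6, because the threshold learned in the "train" period is wrong for the "test" period; the random re-split should score noticeably higher, since both periods now appear on each side of the split.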

When people develop new evaluation metrics (like BLEU for machine translation), they try to look at things like #10 (correlation with some goal, perhaps not exactly the end goal). Standard theory and the debugging steps above cover some of this too.

I'm very curious if y'all have topics/tricks that you like that aren't mentioned here.


1 comment:

UnknownPi said...

I've learned to always pitch my model against a random and an averaging predictor. If your regressor or classifier can't beat a simple average (or worse, total random guessing), well... no need to continue before finding more signal.

I also like using VW or Random Forest 500 to benchmark against. It can give estimates on the hardness of a problem and how well your optimization is doing vs. very standard modeling techniques.

I really dig the data-halving trick. Have to try that out soon.