(Advance thanks to my former Ph.D. student Abhishek Kumar for help with this!)
The question I wanted to answer, which I felt must have a known answer though I'd never seen it, is the following. Suppose (e.g., at training time), I know that the correct label is i, but the model is (perhaps) not predicting i. I'd like to change A as little as possible so that it predicts i, perhaps with an added margin of 1. If I measure "as little as possible" by l2 norm, and assume wlog that i=1, then I get:
$\min_B \|B - A\|^2$ s.t. $(xB)_1 \ge (xB)_i + 1$ for all $i \ge 2$
This problem arises most specifically in Crammer et al., 2006, "Online Passive-Aggressive Algorithms", though it appears elsewhere too.
Below I'll show a very efficient solution to this problem, including Python code that uses just numpy. (If this appears elsewhere and I simply was unable to find it, please let me know.)
First, I'll make the following unproven assertion, though I'm pretty sure it'll go through (famous last words). The assertion is that any difference between A and B will be in the direction of x. In other words, the first row of A will likely move in the direction of x and the other rows of A will move away. Hand-wavy reason: because otherwise you increase the norm ||B-A|| without helping satisfy the constraints.
In particular, I'll assume that $b_i = a_i + d_i x$, where the $d_i$ are scalars.
Given this, we can do a bit of algebra:
$\|B - A\|^2 = \sum_i \|b_i - a_i\|^2 = \sum_i \|a_i + d_i x - a_i\|^2 = \|x\|^2 \sum_i d_i^2$
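A quick numerical sanity check of this identity, as a minimal sketch: the shapes, seed, and variable names are arbitrary choices for illustration, and the rows of A are treated as the per-class vectors a_i.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 5))   # one row a_i per class
x = rng.standard_normal(5)        # input vector
d = rng.standard_normal(3)        # arbitrary per-row step sizes d_i

B = A + np.outer(d, x)            # row i is b_i = a_i + d_i * x

lhs = np.sum((B - A) ** 2)        # squared norm ||B - A||^2
rhs = x.dot(x) * np.sum(d ** 2)   # ||x||^2 * sum_i d_i^2
print(np.isclose(lhs, rhs))
```

Since every row of B - A is a multiple of x, the squared norm factors row by row, which is all the identity says.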
Since x is a constant, we really only care about minimizing the norm of the deltas.
We can similarly rewrite the constraints to just say:
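$x b_1 \ge x b_i + 1$ for all $i \ge 2$
iff $x (a_1 + d_1 x) \ge x (a_i + d_i x) + 1$ for all $i \ge 2$
iff $x a_1 + d_1 \|x\|^2 \ge x a_i + d_i \|x\|^2 + 1$ for all $i \ge 2$
iff $d_1 \ge d_i + C_i$ for all $i \ge 2$, where $C_i = (x a_i - x a_1 + 1) / \|x\|^2$ is a constant that depends only on $A$ and $x$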
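To make the last step concrete, here is a small sketch that reads C_i off the previous line as (x·a_i − x·a_1 + 1)/||x||² and checks that the constraint on B and the constraint on d agree. The helper name `margin_offsets` and the random test data are made up for illustration; the correct label is taken to be row 0.

```python
import numpy as np

def margin_offsets(A, x, label=0):
    """C[i] = (x.a_i - x.a_label + 1) / ||x||^2, with C[label] set to 0."""
    scores = A.dot(x)
    C = (scores - scores[label] + 1.0) / x.dot(x)
    C[label] = 0.0
    return C

rng = np.random.default_rng(1)
A = rng.standard_normal((4, 6))
x = rng.standard_normal(6)
d = rng.standard_normal(4)          # arbitrary candidate deltas
C = margin_offsets(A, x, label=0)

B = A + np.outer(d, x)              # b_i = a_i + d_i * x
new_scores = B.dot(x)
others = np.arange(len(d)) != 0     # mask selecting every class except the label

margin_ok = new_scores[0] >= new_scores[others] + 1.0   # constraint stated on B
delta_ok = d[0] >= d[others] + C[others]                # constraint stated on d
print(np.array_equal(margin_ok, delta_ok))
```

The two boolean vectors should agree class by class, since one inequality is just the other divided through by ||x||².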
Now, we have a plausibly simpler optimization problem just over the d vector:
$\min_d \sum_i d_i^2$ s.t. $d_1 \ge d_i + C_i$ for all $i \ge 2$
This was the place I got stuck. I felt like there would be some algorithm for solving this that involves sorting and projecting and whatever, but couldn't figure it out for a few days. I then asked current and former advisees, at which point Abhishek Kumar came to my rescue :). He pointed me to the paper "Factoring Non-negative Matrices with Linear Programs" by Bittorf et al., 2012. It's maybe not obvious from the title that this is at all connected, but they solve a very similar problem in Algorithm 5. All of the following is due to Abhishek:
In particular, their Equation 11 has the form:
$\min_x \|z - x\|^2$ s.t. $0 \le x_i \le x_1$ for all $i$, $x_1 \le 1$
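My problem can be mapped to this by a change of variables: take their constant vector to be $z = D$ and their optimization variable (their $x$, not my input $x$) to be $d + D$, where $D = [0, C_2, C_3, \dots, C_k]$. The objective $\|z - (d + D)\|^2$ is then exactly $\sum_i d_i^2$, and the constraint $d_1 \ge d_i + C_i$ says that entry $i$ of $d + D$ is at most entry 1. We also need to remove the lower and upper bounds. This means that their Algorithm 5 can be used to solve my problem, but with all of the [0,1] projection steps removed. For completeness, here is their algorithm: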
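A minimal sketch of the unbounded variant, written from the sort-and-average idea rather than transcribed from Bittorf et al.'s Algorithm 5. The function name `squish_unbounded` is hypothetical, and index 0 plays the role of entry 1 in the text. For a fixed value t of the first entry, every other entry is best set to min(z_i, t), so the only work is finding t: it is the mean of z_0 and all entries that exceed that mean.

```python
import numpy as np

def squish_unbounded(z):
    """Solve min_y ||z - y||^2 s.t. y[i] <= y[0] for all i >= 1 (no box bounds)."""
    z = np.asarray(z, dtype=float)
    rest = np.sort(z[1:])[::-1]       # remaining entries, largest first
    total, count = z[0], 1            # running sum and size of the merged group
    for v in rest:
        if v <= total / count:        # v already sits at or below the level: stop
            break
        total += v                    # otherwise v is pulled down to the shared level
        count += 1
    t = total / count                 # common value of y[0] and the merged entries
    y = np.minimum(z, t)              # entries below t are left untouched
    y[0] = t
    return y

y = squish_unbounded([0.0, 3.0, 1.0])
print(y)  # y is [1.5, 1.5, 1.0]: entries 0 and 1 meet at their mean, entry 2 is unchanged
```

The sort dominates the cost, so this runs in O(k log k) for k entries.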
Putting this all together, we arrive at some python code in column_squishing.py for solving my multiclass problem. Here's an example of running it:
>>> A = np.random.randn(3,5)
>>> x = np.random.randn(5)
>>> A.dot(x)
array([ 0.90085352, 2.25573249, 0.25974194])
So currently label "1" is winning by a big margin. Let's make each label win by a margin of one, one at a time:
>>> multiclass_update(A, x, 0).dot(x)
array([ 2.078293 , 1.078293 , 0.25974194])
>>> multiclass_update(A, x, 1).dot(x)
array([ 0.90085352, 2.25573249, 0.25974194])
>>> multiclass_update(A, x, 2).dot(x)
array([ 0.80544265, 0.80544265, 1.80544265])
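For readers without column_squishing.py at hand, the whole pipeline can be sketched end to end as follows. This is an independent reconstruction from the derivation above, not the contents of that file, and `multiclass_update_sketch` is a made-up name. Because the random A and x from the session are not given, the example builds a hypothetical A and x whose scores match the ones printed above; the updated scores depend only on the original scores, so they should agree with the session's output.

```python
import numpy as np

def multiclass_update_sketch(A, x, label):
    """Smallest change to A (in squared norm) so `label` wins by a margin of 1."""
    scores = A.dot(x)
    # z plays the role of D: offsets C_i for the rivals, 0 for the correct label
    z = (scores - scores[label] + 1.0) / x.dot(x)
    z[label] = 0.0
    # Unbounded squish: the label and every rival above the level share one value t
    total, count = 0.0, 1
    for v in np.sort(np.delete(z, label))[::-1]:
        if v <= total / count:
            break
        total += v
        count += 1
    t = total / count
    y = np.minimum(z, t)
    y[label] = t
    d = y - z                         # undo the change of variables
    return A + np.outer(d, x)         # b_i = a_i + d_i * x

# Hypothetical A and x chosen only so that A.dot(x) equals the scores above
scores = np.array([0.90085352, 2.25573249, 0.25974194])
x = np.array([1.0, 0.0, 0.0, 0.0, 0.0])
A = np.zeros((3, 5))
A[:, 0] = scores

for label in range(3):
    print(label, np.round(multiclass_update_sketch(A, x, label).dot(x), 6))
```

When the correct label already wins by at least 1, every rival offset is non-positive, the loop exits immediately, and A is returned unchanged.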
Hopefully you find this helpful. If you end up using it, please make some sort of acknowledgement to this blog post and definitely please credit Abhishek.
The fast algorithm is over 100x faster than a numerical solution (and produces the same result).
https://gist.github.com/timvieira/4a4e7e700c34c04160b93aa03a14861c
Also check out these two papers (h/t Mathieu Blondel):
http://epubs.siam.org/doi/abs/10.1137/1.9781611972801.27
http://mblondel.org/publications/mblondel-icpr2014.pdf
You can solve the Lagrange / Fenchel dual of your setting exactly. It was re-derived multiple times starting with Kesler. Crammer and myself gave an explicit algorithm for both the separable and the non-separable case. We provided a d time exact algorithm as well as a fixed point algorithm (figures 2 & 3). Your approximation is very nice, though I suspect not nearly as fast when the number of classes is in the thousands.
Do you mean Figure 2 in http://jmlr.csail.mit.edu/papers/volume7/crammer06a/crammer06a.pdf ? This doesn't seem to be an exact solution to the problem I wrote above unless I'm missing something. It only updates two classes, which is going to be insufficient in general. Probably I'm looking at the wrong paper tho because Figure 3 in that paper is a graph :/. Can you point me in the right direction?
I'm not sure why you refer to this as an approximation. Which step is approximated?