Co-author here. Super excited to see this work posted on HN!
Happy to answer questions.
Majromax 18 days ago [-]
> Happy to answer questions.
Do you expect instability between successive macrobatch gradients? That is, why are you comparing microgradients within a single batch, adding a whole bunch of serialization headaches, rather than comparing with the macrogradient of the previous step?
Given your test setup of noisy labels, isn't the sequential accumulation of microgradients dangerous? Suppose we take this to the limit of a minibatch size of 1, and the first microgradient happens to come from an example with an incorrect label. If I understand this correctly, gradient filtering would seek to add all the gradients that are consistent with the bad example, rejecting the gradients that belong to good examples. The text only contemplates the perversity of rejecting all examples from a microbatch, not accidentally accepting only corrupted gradients.
The filtered gradients are used via SGD with momentum (although equation (6) looks like momentum-free SGD). Have you seen / do you expect different results when supplying the filtered gradients to Adam/AdamW or other, more sophisticated optimizers?
Your thresholding is binary, either accepting or rejecting an entire microgradient. Have you tested soft thresholding? Is there an information-theoretic way to explain this effect?
In figure 7, why does GAF with a large threshold result in a lower validation accuracy than the baseline? In GAF-terms, the baseline accepts every microgradient, so I'd expect GAF to converge to the baseline result as the threshold increases. What does the figure-7-but-0%-error curve look like?
isoprophlex 18 days ago [-]
Strong "reviewer #2" vibes in this comment...
Majromax 18 days ago [-]
> Strong "reviewer #2" vibes in this comment...
I've been a peer reviewer before, albeit not in the machine learning space. My comments above were general questions and not a review, since I'm now curious about applying this technique in my own work.
When I do review a paper, my guide star is "does the paper answer its own question using a methodology powerful enough to detect if the answer is 'no'?" A secondary question is "are any arbitrary choices of (hyper)parameters sufficiently justified?" Theoretical beauty would be nice, but that's secondary to a robust result.
If I were reviewing this paper, I'd be mostly satisfied with it. My questions above are honest matters of curiosity rather than strict demands for greater rigour, and "we haven't looked at this" is an acceptable answer to such a question.
ubj 18 days ago [-]
These are all fair questions. This is the whole point of peer review--to bring up ideas or concerns the authors may not have considered.
fchaubard 18 days ago [-]
I’ll do my best to answer here.
> Do you expect instability between successive macrobatch gradients? That is, why are you comparing microgradients within a single batch, adding a whole bunch of serialization headaches, rather than comparing with the macrogradient of the previous step?
>> I do. If you take a sufficiently large step, the path of steepest descent will surely change sometimes. If it doesn't, then you should just double or triple your step size. So you just don't know why the cosine distance is high: a change in the curvature of your loss curve, or gradient variance. Most large runs are already splitting up gradients across nodes, so if you are doing that anyway, just do GAF instead of averaging.
> Given your test setup of noisy labels, isn't the sequential accumulation of microgradients dangerous? Suppose we take this to the limit of a minibatch size of 1, and the first microgradient happens to come from an example with an incorrect label. If I understand this correctly, gradient filtering would seek to add all the gradients that are consistent with the bad example, rejecting the gradients that belong to good examples.
>> Yes, but "consistent with the bad example" is nearly impossible. The gradient directions in late stages of training without noisy labels are already orthogonal or worse. If you flip the label of any 2 samples and do a MB of size 1 on them, they will all be negatively correlated with each other, so you will practically always skip with GAF. However, in standard SGD you will ALWAYS average them in until you've completely memorized the noisy samples.
> The filtered gradients are used via SGD with momentum (although equation (6) looks like momentum-free SGD). Have you seen / do you expect different results when supplying the filtered gradients to Adam/AdamW or other, more sophisticated optimizers?
>> No, it won't change the results. I have been using GAF to train LLMs for my next paper and it stands. In fact, in much more expressive and large models like LLMs, the gradients sometimes hit a cosine distance of 2! So GAF really helps in LLM training. Think of SGD training like a tracker. The tracker is getting noisy distance/velocity signals at some frequency, and the tracker is only as good as the signal coming into it. If a bird flies in front of the radar, you will get a huge OOD signal into the estimate if you don't do some sort of validation gating (e.g. ||s_t - x_hat_t|| > thresh). Think of GAF as the validation gating.
> Your thresholding is binary, either accepting or rejecting an entire microgradient. Have you tested soft thresholding? Is there an information-theoretic way to explain this effect?
>> Great question. I tried to prove formally that if the gradients of two randomly selected batches are negatively correlated, then averaging them will result in overfitting, but I couldn't get the proof to a satisfactory spot. I do conjecture it, though. So no, I wouldn't expect taking any part of a memorization direction to be a good idea.
> In figure 7, why does GAF with a large threshold result in a lower validation accuracy than the baseline? In GAF-terms, the baseline accepts every microgradient, so I'd expect GAF to converge to the baseline result as the threshold increases. What does the figure-7-but-0%-error curve look like?
>> Good callout. Yes, that wasn't intuitive to me either. You are correct that when Tau hits 2 it does converge to baseline as expected. But at 1.05 it actually does worse than baseline in the presence of 5% noise. So as you increase Tau above 1, which I never recommend doing, it starts to underperform baseline in the presence of noise, and by 2 it matches. But for practical reasons I have found Tau=[0.92-0.999] to be the sweet spot. I wouldn't go outside that range.
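For readers skimming the thread, here is a minimal sketch of the accept/skip rule being described, written by me rather than taken from the authors' code: compute two microbatch gradients, measure their cosine distance, and either average them and step or skip the update entirely. The two-microbatch setup and the tau range follow the discussion above; the model, loss, and data handling are placeholders.

```python
import torch
import torch.nn.functional as F

def flat_grad(model, loss):
    """Return the gradient of `loss` w.r.t. all trainable parameters as one flat vector."""
    params = [p for p in model.parameters() if p.requires_grad]
    grads = torch.autograd.grad(loss, params)
    return torch.cat([g.reshape(-1) for g in grads])

def gaf_step(model, optimizer, loss_fn, microbatch_a, microbatch_b, tau=0.97):
    """One GAF update: average the two microgradients only if they agree.

    cosine distance = 1 - cosine similarity, so 0 means identical direction,
    1 means orthogonal, 2 means opposite. The authors report tau in roughly
    [0.92, 0.999] as the practical sweet spot.
    """
    xa, ya = microbatch_a
    xb, yb = microbatch_b

    g_a = flat_grad(model, loss_fn(model(xa), ya))
    g_b = flat_grad(model, loss_fn(model(xb), yb))

    cos_dist = 1.0 - F.cosine_similarity(g_a, g_b, dim=0)
    if cos_dist >= tau:
        return False  # gradients disagree: skip this step entirely

    # Agreement: write the averaged gradient back into .grad and step.
    g_avg = 0.5 * (g_a + g_b)
    offset = 0
    for p in model.parameters():
        if not p.requires_grad:
            continue
        n = p.numel()
        p.grad = g_avg[offset:offset + n].view_as(p).clone()
        offset += n
    optimizer.step()
    optimizer.zero_grad()
    return True
```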
Majromax 18 days ago [-]
> I do. If you take a sufficiently large step, the path of steepest descent will surely change sometimes. If it doesn't, then you should just double or triple your step size. So you just don't know why the cosine distance is high: a change in the curvature of your loss curve, or gradient variance. Most large runs are already splitting up gradients across nodes, so if you are doing that anyway, just do GAF instead of averaging.
I agree with normal SGD, but with-momentum optimizers depend on some consistency of gradients between optimizer steps. On the other hand, with-momentum optimizers try to maximize the effective learning rate subject to that momentum, so it could go the other way as well.
> No, it won't change the results. I have been using GAF to train LLMs for my next paper and it stands. In fact, in much more expressive and large models like LLMs, the gradients sometimes hit a cosine distance of 2! So GAF really helps in LLM training. Think of SGD training like a tracker. The tracker is getting noisy distance/velocity signals at some frequency, and the tracker is only as good as the signal coming into it. If a bird flies in front of the radar, you will get a huge OOD signal into the estimate if you don't do some sort of validation gating (e.g. ||s_t - x_hat_t|| > thresh). Think of GAF as the validation gating.
> Great question. I tried to prove formally that if the gradients of two randomly selected batches are negatively correlated, then averaging them will result in overfitting, but I couldn't get the proof to a satisfactory spot. I do conjecture it, though. So no, I wouldn't expect taking any part of a memorization direction to be a good idea.
Maybe the answer lies in asking what's optimized by averaging.
For learning, we're interested in the intractable problem of the gradient of parameters with respect to the whole distribution of data. In practice, we only compute the gradient of parameters with respect to samples drawn from the data distribution, leading to stochastic gradient descent. SGD-with-momentum makes the additional assumption that the steepest descent path has relatively low curvature, so the mean gradient of previous batches is still informative.
Overall, this approach is still optimal if you imagine that the computed sample gradients are corrupted with mean-zero Gaussian noise: averaging over many samples is the best way to eliminate that noise.
Your work identifies and rejects outlier gradients. In a very hand-wavy way, this is kind of like a median filter, and a median filter is great at rejecting shot noise. I speculate that this is why your technique is particularly good for your examples with corrupted labels, since that corruption process replaces single samples with something completely uninformative.
This is why I also wonder about soft thresholding. A soft threshold could be interpreted as an intermediate rather than binary belief about whether a sample is informative, or it could be interpreted as belief that a sample has more than zero but less than the typical amount of information.
> But for practical reasons I have found Tau=[0.92-0.999] to be the sweet spot. I wouldn’t go outside that range.
If it would be easy to add (that is, if you still have the data on hand), might I suggest adding a subpanel to figure 7 noting the fraction of minibatches that are accepted/rejected with each threshold? If you're hypothetically rejecting 80% of minibatches at the optimum threshold, it'd hint that your method is finding the golden kernel of most representative data to learn from; in contrast if you're hypothetically rejecting just a couple of percent then it'd hint that your method is more narrowly finding the corrupted samples. Either range (or anything in between) would be interesting.
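As a concrete version of that suggestion, a hypothetical diagnostic could log the cosine distance of every candidate update during training and then report the accepted fraction over a grid of thresholds. The file name and threshold grid below are made up for illustration.

```python
import numpy as np

# Hypothetical log: one cosine distance per candidate update recorded during a run.
cos_dists = np.load("cosine_distances.npy")

for tau in [0.90, 0.92, 0.95, 0.97, 0.999, 1.05, 2.0]:
    accepted = float(np.mean(cos_dists < tau))
    print(f"tau={tau:5.3f}  accepted {accepted:6.1%} of candidate steps")
```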
szvsw 18 days ago [-]
Just skimmed the abstract and it immediately has me wondering if there are ways that this research intersects with explainability research - specifically, the notion that certain inputs only activate certain portions of a network. I wonder if at some point that kind of information could be leveraged to provide better sorting of datasets into batches. Obviously this conflicts a little bit with the notion that you want your batches to be just completely random permutations.
In some sense, if two inputs activate the same parts of the network at high intensity, then they are more likely to result in conflicting gradients if they are in separate microbatches from the same macrobatch.
I’m curious if you think long term, there could be some utility in trying to periodically extract information about how samples activate the network and use that (or some other representation) to better sort the micro/macro batches to maximize your parallelism by making conflicts less likely.
Obviously there would be some sort of time penalty dealing with calculating whatever that representation is and sorting so that overhead might far outweigh any gains, but I could see it as at least plausible that there might be some contexts/scales where making that periodic investment could pay off.
fchaubard 18 days ago [-]
Yes! I think this is a great area of research. If you think of the gradient values as a blame score for why you got the answer wrong, then you can have a lot of fun exploring which weights light up for different problems. A note: in Ring All-Reduce they actually don't ever share the FULL gradient, but instead blocks of it. So to put this into practice you'd have to show that the thresholding still works on blocks of the gradient vs the full gradient, which you may never be able to fit in VRAM. Will the results still hold? I don't know. I believe they would, but that's for the next paper.
szvsw 18 days ago [-]
Very cool! Glad to hear my intuition is on the right track… I’m very much on the applied ML for engineering design side as opposed to the bleeding edge research side, so in terms of multi-node training I haven’t done much more than spin up a few GPUs and let PyTorch Lightning handle the parallelism, but cool to try to keep up with this stuff.
Thanks for the response and good luck with this!
timmg 18 days ago [-]
From a non-expert here (maybe a dumb question): does “batching” affect quality in a similar way as averaging parallel batches? (Like is there a difference between having a batch that is 10x the size versus averaging 10 batches that were calculated in parallel?)
deddy 18 days ago [-]
Great question! There are a few different aspects to this.
With gradient agreement filtering, having a greater number of batches (generally) increases the likelihood of finding another microbatch that agrees with the gradient, simply by virtue of having more gradient "samples" to compare. So having more batches increases the chance of success there. The algorithm as laid out in the paper is a simple approach to combining groups of batches, where a larger number of batches doesn't necessarily improve your chances of success if the batch you're comparing against is an outlier itself. There are almost certainly better ways of combining greater numbers of batches to get a successful update. This is one of the exciting areas of future work.
Increasing the batch size generally can be thought of as "averaging" out the noise in your samples to find a consistent update. This has an interesting effect, though: when using gradient agreement filtering, you want to lower the filter threshold as your batch size increases, i.e. become "stricter" about the level of agreement required to accept a gradient. This is Figure 9 of the paper. It is also consistent with what other researchers have found: simply increasing batch size isn't always better. There is a trade-off here, with diminishing returns from increasing batch size as well (https://arxiv.org/abs/1812.06162). One interesting finding from the work was that smaller batch sizes actually improved training accuracy for CIFAR-100N; very roughly speaking, this can be explained by having more "signal" in each batch with smaller batch sizes, at the cost of potentially throwing out batches/gradients if they disagree.
magicalhippo 18 days ago [-]
This is probably a dumb idea, but I'll air it anyway.
I just stumbled upon the model soup[1] paper where, as I understand it, they average weights of fine-tuned models and get a model that performs better. They have a more involved algorithm, but even the uniform soup (simple weight average) seems to perform well.
In your paper you mention that especially in late stages the gradients of microbatches are often not aligned, hence the agreement filtering.
What you're doing, from my brief glossing over, is effectively a k-means clustering pass with outlier rejection, with k = 1. You then use the cluster mean to update the model.
What I'm curious of, assuming the above is correct, is what would happen if you combined the approaches.
That is, do a k-means clustering of the microbatch gradients with k > 1, still rejecting outliers but perhaps with lower threshold, generate k updated models using the k cluster means, and then average the k models afterwards.
I've used something similar to k-means clustering with outlier rejection for noise filtering and it was quite effective, so curious how it would work out here.
[1]: https://arxiv.org/abs/2203.05482
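If anyone wants to prototype that, here is one rough sketch of the clustering idea, my own interpretation rather than anything from the paper: cluster the flattened microgradients with k-means, drop within-cluster outliers by cosine distance to their centroid, and keep the per-cluster means as k candidate update directions (the later model-averaging step is omitted). The k and tau values are arbitrary.

```python
import numpy as np
from sklearn.cluster import KMeans

def cluster_microgradients(grads, k=3, tau=0.97):
    """grads: (n_microbatches, n_params) array of flattened microgradients.

    Returns up to k cluster-mean gradients, after rejecting microgradients whose
    cosine distance to their assigned centroid exceeds tau. (Illustrative only.)
    """
    km = KMeans(n_clusters=k, n_init=10, random_state=0).fit(grads)
    kept_means = []
    for c in range(k):
        members = grads[km.labels_ == c]
        if len(members) == 0:
            continue
        centroid = km.cluster_centers_[c]
        # cosine similarity of each member to its centroid
        sims = (members @ centroid) / (
            np.linalg.norm(members, axis=1) * np.linalg.norm(centroid) + 1e-12
        )
        inliers = members[(1.0 - sims) < tau]
        if len(inliers) > 0:
            kept_means.append(inliers.mean(axis=0))
    return kept_means
```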
So, other ways of combining greater numbers of microbatch gradients in an effective/consistent manner for performing an update are one area of potential future work. I think your idea is an interesting way to approach it, though there are a bunch of potentially effective ways of doing it.
I think the idea of averaging the k models afterwards is at odds with the core concept of gradient agreement filtering, though, because you're back to combining two distinct directions of improvement without a guarantee that the combination is better (even though it does seem to be in practice). The core idea is that, philosophically, you only want to learn the patterns that agree across multiple specific examples, and build some algorithmic protections to ensure that is happening. Just averaging, while it might work and even yield improvement, doesn't necessarily lead to properly generalized learning.
ithkuil 18 days ago [-]
Do you mean training each model independently and only averaging at a late stage, in order to reduce communication overhead in the distributed scenario?
magicalhippo 18 days ago [-]
I was considering this but thought perhaps that would be too resource intensive or not efficient enough.
So in this case I was more thinking if there were perhaps a few directions that stood out, and instead of potentially rejecting those, consider each a microfine-tune, averaging the result at each step ala the uniform soup.
Though perhaps a stupid idea, I'm not a practitioner.
fchaubard 18 days ago [-]
I think about designing your ideal solver. What do you want in a solver? I want my solver to squeeze all the juice out of the training set that it possibly can, and no more. If your problem is complete noise, I don't want my solver getting 100% train accuracy, as all SGD methods I am aware of do and as the soup method likely would as well. With GAF I am not averaging in memorization thetas, so I want my solver to score 0% train accuracy, as GAF does. There may be other ways of getting there as well.
lassepe 15 days ago [-]
Can you elaborate on if/when $\theta$ is synchronized across nodes?
Algorithm 1 suggests that each node starts gradient aggregation from its local micro-gradient $g$. Since the order of aggregation matters, $\theta$ would likely diverge after applying the step with $g_{\mathrm{GAF}}$ --- even if the models on different nodes are initialized with the same weights. Hence, I would expect there to be a weight-synchronization step after each macro-gradient step. Do you have such a step? If so, how do you implement consensus? Simply via averaging?
igorkraw 18 days ago [-]
A few technical questions (I had a somewhat related work with friends here https://openreview.net/forum?id=I3HCE7Ro78H although we focused on gradient multiplicity in adversarial training, not massively parallel training)
1. Do you think this is a form of variance reduction or more a form of curriculum (focus first on the bulk, then on remaining errors)?
2. Did you observe any overfitting/additional adversarial risk?
3. Did you try this on just single-node minibatches as well? How did that perform?
deddy 18 days ago [-]
> 1. Do you think this is a form of variance reduction or more a form of curriculum (focus first on the bulk, then on remaining errors)?
I'd say generally more of a curriculum (using your terminology). Broadly speaking, the idea is to restrict stepping to "high-quality" directions where there is agreement/consistency in the direction of the update.
> 2. Did you observe any overfitting/additional adversarial risk?
No, actually one of the coolest findings of the work is that training with GAF prevents overfitting. It might slow down or stop training improvement, but it also prevents overfitting. Essentially, in late training when overfitting would otherwise occur, the gradient directions become orthogonal, and when that happens GAF means you just don't take a step. Training ends up plateauing as it becomes harder to find two minibatches that agree, so you end up having more no-op epochs, but you don't overfit. I think we still have one training run going (after months) on CIFAR-100N-Fine that has yet to overfit. It's still slowly improving; last time we checked, train and val were both around ~60%.
Adversarial risk is an interesting question, but this should help with that as well provided that the adversarial examples are a minority of the training data and that the adversarial attack comes from overfitting / memorizing the adversarial part of those examples.
> 3. Did you try this on just single-node minibatches as well? How did that perform?
The number of nodes is more of a performance/implementation detail in terms of to what extent/scale you parallelize. For the technique to work you just need 2+ macrobatches that you can compare to decide whether to take your step. CIFAR-100N is small enough that you can run multiple minibatches on a single GPU (node) and it all fits into VRAM. Even if it didn't fit into VRAM, you could theoretically save off the gradient to disk before taking a step and the technique would still apply/help/work; it would just be slower.
Yes
Scene_Cast2 18 days ago [-]
Couple of questions. First, would you happen to have a code demo?
Second, and this is more of a hypothetical question for my own understanding rather than a practical one - in a single-GPU scenario, could you compute the loss per-sample without averaging (i.e. "reduction='none'" in PyTorch), and improve (on a sample-efficiency basis) single-GPU training with your algorithm? Sorry if this was covered in the paper already.
cheald 17 days ago [-]
As an experiment, I tried implementing this for Stable Diffusion LoRA training, where I'm training on a single GPU with a batch size of 8, and it does actually seem to have an appreciable impact. In my case, I'm keeping a per-parameter grad EMA, then computing the cosine distance between the parameter's grad and its EMA, and multiplying the grad by 0 if (1.0 - cos_sim) > 0.99.
My loss metrics stay roughly the same (they're slightly lower, but SD loss is fraught to interpret because variance by timestep renders it more or less meaningless), but tracking the means of `param.grad.norm / param.numel` (which shows how big the grad updates are) shows the grads stabilizing significantly quicker than baseline. I'm tracking suppressed params / total params via tensorboard, and I show that it drops (as expected) but then stabilizes at around 7%, suggesting that there are model parameters which consistently don't agree. I'm gonna try tracking the variance from the mean, as well, and perhaps down-weight or eliminate grads for parameters which show high cos similarity variance over time (suggesting a generalized lack of agreement in the direction to move, further suggesting that the parameter cannot contribute meaningfully to the task).
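For anyone wanting to try the same trick, here is roughly what that gating could look like as a step run between loss.backward() and optimizer.step(). This is my reconstruction of the description above, not the commenter's actual code, and the EMA decay value is a guess.

```python
import torch
import torch.nn.functional as F

class GradEMAGate:
    """Zero out a parameter's gradient when it disagrees with its own gradient EMA."""

    def __init__(self, model, decay=0.9, max_cos_dist=0.99):
        self.model = model
        self.decay = decay               # EMA decay: a guess, not stated in the comment
        self.max_cos_dist = max_cos_dist
        self.ema = {}                    # param -> EMA of its past gradients

    @torch.no_grad()
    def apply(self):
        """Call after loss.backward() and before optimizer.step()."""
        suppressed, total = 0, 0
        for p in self.model.parameters():
            if p.grad is None:
                continue
            total += 1
            g = p.grad.reshape(-1)
            ema = self.ema.get(p)
            if ema is None:
                self.ema[p] = g.clone()  # first step: just seed the EMA
                continue
            cos_dist = 1.0 - F.cosine_similarity(g, ema, dim=0)
            # update the EMA with the raw (ungated) gradient before any gating
            ema.mul_(self.decay).add_(g, alpha=1.0 - self.decay)
            if cos_dist > self.max_cos_dist:
                p.grad.zero_()           # suppress this parameter's update
                suppressed += 1
        return suppressed, total         # handy for logging suppressed/total params

# usage sketch:
#   gate = GradEMAGate(model)
#   loss.backward(); gate.apply(); optimizer.step(); optimizer.zero_grad()
```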
Did you explore microbatch sizes below 100? Curious about how far this can be pushed and what happens when approaching the limit microbatch size of 1.
ithkuil 18 days ago [-]
My intuition is that with very small microbatch sizes you're very likely to end up in one of the two modes: either the vast majority of the samples are aligned and thus pruned away, or they are not aligned. Thus effectively you're dropping a fraction of the samples but without the advantage of removing the variance between samples that belong in different microbatches.
fchaubard 18 days ago [-]
Yes. It's more of a class-spanning thing. I wanted the batch composition across the two microbatches to be the same. So if you have classes 1, 2, 3 in batch one and classes 4, 5, 6 in batch two, I would fully expect the gradients to be orthogonal or worse, and it could still be a good update. But if you have classes 1, 2, 3 in batch one and classes 1, 2, 3 in batch two, I would fully expect the gradients to be positively correlated, and if not you should skip. So you could bring this down to a MB of size 5, for example, but just make sure you have the same batch composition. This poses a big challenge in LLM training, honestly, because technically the number of classes is the vocab size. So I would need one "a", one "b", etc., which is silly. This is why micro gradients in LLMs hit a cosine distance of 2. So when you are sampling, you kind of need to ensure the microbatches are at least of the same task.
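A small sketch of what composition-matched sampling could look like for a labeled classification dataset; this is just an illustration of the idea, not code from the paper, and the microbatch sizes and helper names are invented.

```python
import random
from collections import defaultdict

def paired_microbatches(labels, classes_per_batch=3, per_class=1, seed=None):
    """Return two lists of sample indices with identical class composition.

    labels: list of integer class labels, one per dataset sample.
    (Hypothetical helper to illustrate the idea.)
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)

    # only consider classes with enough samples to fill both microbatches
    eligible = [c for c, idxs in by_class.items() if len(idxs) >= 2 * per_class]
    chosen = rng.sample(eligible, classes_per_batch)

    batch_a, batch_b = [], []
    for c in chosen:
        picked = rng.sample(by_class[c], 2 * per_class)
        batch_a.extend(picked[:per_class])
        batch_b.extend(picked[per_class:])
    return batch_a, batch_b

# e.g. idx_a, idx_b = paired_microbatches(train_labels, classes_per_batch=3)
```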
fchaubard 18 days ago [-]
Here too!
eru 18 days ago [-]
> We show this technique consistently outperforms validation accuracy, in some cases by up to 18.2% compared to traditional training approaches while reducing the computation required nearly an order of magnitude because we can now rely on smaller microbatch sizes without destabilizing training.
The accuracy improvement is great, but I'm really looking forward to the reduction in computation required!
fchaubard 18 days ago [-]
Yes it will allow stable training at much smaller batch sizes. Test it out and let us know if it works for your use case!
woadwarrior01 18 days ago [-]
Very interesting idea! Reminiscent of this robust learning paper[1] from 2020, where they do something even simpler: zero out elements in gradients with inconsistent signs, before averaging them.
[1]: https://arxiv.org/abs/2009.00329
Hey thanks! Ya, we tried similar strategies to this and could not beat "cosine distance < tau: average, else skip". It was too much to put in the paper and we may put it in the arXiv version, but we tried sign-AND gating and zeroing out if the signs don't agree, we tried L2 < tau, etc., and nothing beat cosine distance.
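For comparison, the sign-agreement variant mentioned here (zero out gradient elements whose signs disagree, then average) could look something like this toy sketch; it illustrates the alternative that was tried, not the cosine-distance rule the authors settled on.

```python
import torch

def sign_gated_average(g_a: torch.Tensor, g_b: torch.Tensor) -> torch.Tensor:
    """Average two flattened gradients, keeping only elements whose signs agree."""
    agree = torch.sign(g_a) == torch.sign(g_b)
    return torch.where(agree, 0.5 * (g_a + g_b), torch.zeros_like(g_a))
```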
gwern 18 days ago [-]
I'm also reminded of top-_k_ training/sampling, where you throw away the inconsistent samples: https://arxiv.org/abs/2002.06224 "Top-k Training of GANs: Improving GAN Performance by Throwing Away Bad Samples", Sinha et al 2020.
Which seems like it would be roughly equivalent to throwing away gradients if they are inconsistent...?
SubiculumCode 18 days ago [-]
Is this the paper that the cohost of 'Last Week in AI' said was the paper of the quarter to read back to front?
cgdl 18 days ago [-]
Very interesting. A related paper from a couple of years ago proposed a similar idea to understand generalization in deep learning: https://arxiv.org/abs/2203.10036
As someone who only has a passing understanding of parallel training, I always found it wonderful that averaging gradients works at all. It seems non-intuitive to me.
Pretty cool, imho, that people are finding better ways to train in parallel.
d3m0t3p 18 days ago [-]
Well, it's just like stochastic gradient descent, if you think about it. Normal gradient descent is computed using the whole training set. The stochastic gradient is computed on a batch (a subset of the training set), and in the distributed case, we compute two batches at once by computing the gradient on each in parallel.
The intuition works IMO, but indeed, having the first batch update and then the second is not equal to having the mean update.
This is indeed super cool !
eru 18 days ago [-]
Does anyone actually use the 'normal gradient descent' with the whole training set? I only ever see it as a sort of straw man to make explanation easier.
jey 18 days ago [-]
Generally yes, vanilla gradient descent gets plenty of use. But for LLMs: no, it’s not really used, and stochastic gradient descent provides a form of regularization, so it probably works better in addition to being more practical.
bravura 18 days ago [-]
Full batch with L-BFGS, when possible, is wildly underappreciated.
amarcheschi 18 days ago [-]
Not entirely related: I've just ended my internship at the Italian research council, where I worked on implementing federated learning algorithms in a simulator to make benchmarking easy. I couldn't finish the job because an algorithm I was trying to port, developed by IBM, had hardcoded values that prevented it from working with models other than the ones already coded by them. There was also a comment like "this is hardcoded for now, will be changed later" - yeah, the last commit was 4 years ago.
https://github.com/IBM/FedMA