Distributed Training Across the World
Ligeng Zhu, Yao Lu, Hongzhou Lin, Yujun Lin, Song Han
NeurIPS '19 MLSys
[Figure: world map of inter-site links between California, Ohio, London, and Tokyo, with latencies of 35–277 ms and bandwidths of 13–63 Mbps]
Why Distributed Training?
• Model sizes: AlexNet (7 layers) -> VGG (16 layers) -> ResNet (152 layers)
• Dataset sizes: CIFAR (50k) -> ImageNet (1.2M) -> Google JFT (300M)
Even with modern GPUs, say an eight-V100 server, it still takes days or even weeks to train a model.
What is Distributed Training?
Conventional SGD:
1. Sample (X, y) from the dataset.
2. Forward to compute loss.
3. Backward to compute gradients.
4. Apply gradients to update the model.
Distributed SGD:
1. Sample (X, y) from the dataset.
2. Forward to compute loss.
3. Backward to compute gradients.
4. Synchronize gradients.
5. Apply gradients to update the model.
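For concreteness, a minimal sketch (not the authors' code) of one step of synchronous distributed SGD, where step 4 is an all-reduce over workers. It assumes PyTorch with `torch.distributed` already initialized (e.g. via `dist.init_process_group`); `model`, `loss_fn`, and `batch` are illustrative placeholders.

```python
# Minimal sketch of one synchronous distributed SGD step: gradients are
# all-reduced (averaged) across workers before the weight update.
import torch
import torch.distributed as dist

def distributed_sgd_step(model, loss_fn, batch, lr=0.1):
    x, y = batch                          # 1. sample (X, y) from the dataset
    loss = loss_fn(model(x), y)           # 2. forward to compute loss
    model.zero_grad()
    loss.backward()                       # 3. backward to compute gradients
    world = dist.get_world_size()
    for p in model.parameters():          # 4. synchronize (average) gradients
        dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
        p.grad /= world
    with torch.no_grad():
        for p in model.parameters():      # 5. apply gradients to update the model
            p -= lr * p.grad
```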
Why Learning across Geographical Locations?
"It's not who has the best algorithm that wins. It's who has the most data." —Andrew Ng
However, collecting all the data in one place is often difficult, and sometimes even illegal.
Why Learning across Geographical Locations?
Collaborative / federated learning: data never leaves the local device.
Communication Limits Scalability
Latency:
• InfiniBand: < 0.002 ms
• Mobile network: ~50 ms (4G) / ~10 ms (5G)
Bandwidth:
• InfiniBand: up to 100 Gb/s
• Mobile network: 100 Mb/s (4G), 1 Gb/s (5G)
Shanghai–Boston:
• ~10 Mb/s, with high variance
• 78 ms (ideal) / > 700 ms (real world)
What we need:
• Bandwidth as high as 900 Mb/s
• Latency as low as 1 ms
The ideal Shanghai–Boston latency is bounded by the speed of light: 11,725 km × 2 / (3 × 10^8 m/s) = 78.16 ms.
Bandwidth is easy to increase; latency is hard to improve :(
Latency is critical
Delayed Update: sync stale gradients
Conventional Distributed SGD at step i:
1. Sample (X, y) from the dataset.
2. Forward to compute loss.
3. Backward to compute gradients.
4. Synchronize step i's gradients.
5. Apply gradients to update the model.
Delayed Update at step i:
1. Sample (X, y) from the dataset.
2. Forward to compute loss.
3. Backward to compute gradients.
4. Synchronize step (i - t)'s gradients.
5. Apply gradients to update the model.
Delayed Update: put off the sync barrier
Normal distributed training: gradients from time stamp i are synced before the (i+1)-th update.
Delayed distributed training: gradients from time stamp (i - t) are synced before the (i+1)-th update.
[Timeline figure: per-step schedule of computation and gradient synchronization for both schemes]
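A minimal sketch of how the delayed sync barrier might look in code (again assuming PyTorch with `torch.distributed` initialized; `model`, `loss_fn`, `loader`, `lr`, and `t` are illustrative placeholders): each step launches a non-blocking all-reduce and only waits on the one issued t steps earlier, so the gradients that get applied are t steps stale.

```python
# Sketch of the delayed update loop: the all-reduce for step i is launched
# asynchronously and only waited on t steps later, so communication overlaps
# with t local iterations instead of blocking every step.
from collections import deque
import torch
import torch.distributed as dist

def train_with_delayed_updates(model, loss_fn, loader, lr=0.1, t=2):
    pending = deque()  # (async work handles, gradient buffers) still in flight
    for x, y in loader:
        loss = loss_fn(model(x), y)      # forward
        model.zero_grad()
        loss.backward()                  # backward

        # Launch a non-blocking all-reduce on a copy of this step's gradients.
        bufs = [p.grad.detach().clone() for p in model.parameters()]
        works = [dist.all_reduce(b, op=dist.ReduceOp.SUM, async_op=True)
                 for b in bufs]
        pending.append((works, bufs))

        # Only block on the all-reduce issued t steps ago (stale gradients).
        if len(pending) > t:
            old_works, old_bufs = pending.popleft()
            for w in old_works:
                w.wait()
            with torch.no_grad():
                for p, g in zip(model.parameters(), old_bufs):
                    p -= lr * g / dist.get_world_size()  # apply step (i - t)'s gradients
```

This way, the hundreds-of-milliseconds round trips seen earlier overlap with local computation rather than stalling every iteration.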
Preserve Accuracy by Compensation
Gradients from step (i - t) are synced at the step-i update. For example, if t = 2:

Vanilla SGD:
$$ w_n = w_0 - \gamma \sum_{i=0}^{n-1} v_i, \qquad w_3 = w_0 - \gamma (v_0 + v_1 + v_2) $$

Delayed update with compensation:
$$ w_3 = w_0 - \gamma (v_0 + v_1 + v_2) - \gamma (v_0 - v_0) = w_0 - \gamma (v_0 + v_1 + v_2) $$
Preserve Accuracy by Compensation
$$ w_{n,j} = w_0 + \underbrace{\sum_{i=0}^{n-1-t} \Delta w_i}_{\substack{\text{same as normal distributed training} \\ \text{(global information)}}} + \underbrace{\sum_{i=n-t}^{n-1} \Delta w_{i,j}}_{\substack{\text{difference caused by local update} \\ \text{(local gradients)}}} $$
Theoretical convergence of the delayed update: $O\!\left(\tfrac{1}{\sqrt{NJ}}\right) + O\!\left(\tfrac{t^2 J}{N}\right)$
Convergence of SGD: $O\!\left(\tfrac{1}{\sqrt{NJ}}\right)$
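The compensation rule can be illustrated with a toy single-process simulation, a sketch under the assumption that the averaged gradient simply "arrives" t steps late (all names here are illustrative, not the paper's implementation): each worker applies its own gradient immediately, and once the step-(i - t) average arrives it subtracts lr * (average - its own stale gradient), which reproduces the global/local decomposition above.

```python
# Toy single-process simulation of the compensated delayed update: every
# worker applies its own gradient right away, and when the globally averaged
# gradient from step (i - t) "arrives" it swaps that stale local gradient for
# the global average.
from collections import deque
import numpy as np

def delayed_update_with_compensation(w0, per_step_worker_grads, lr=0.1, t=2):
    """per_step_worker_grads[i][j] is worker j's gradient v_{i,j} at step i."""
    n_workers = len(per_step_worker_grads[0])
    w = [np.array(w0, dtype=float) for _ in range(n_workers)]  # one replica per worker
    in_flight = deque()  # gradients whose global average has not arrived yet

    for worker_grads in per_step_worker_grads:
        avg = sum(worker_grads) / n_workers          # would be an async all-reduce
        in_flight.append((worker_grads, avg))

        for j in range(n_workers):                   # local update, no waiting
            w[j] -= lr * worker_grads[j]

        if len(in_flight) > t:                       # step (i - t)'s average arrives
            old_grads, old_avg = in_flight.popleft()
            for j in range(n_workers):               # compensation: -lr * (avg - local)
                w[j] -= lr * (old_avg - old_grads[j])

    # After n steps, w[j] = w0 - lr * (averaged grads up to step n-1-t)
    #                          - lr * (worker j's own grads for the last t steps),
    # matching the decomposition on the slide above.
    return w
```

With a single worker the correction term is zero, which matches the t = 2 example on the previous slide.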