The Non-IID Data Quagmire of Decentralized Machine Learning
ICML 2020
Kevin Hsieh, Amar Phanishayee, Onur Mutlu, Phillip Gibbons
ML Training with Decentralized Data
• Geo-Distributed Learning
• Federated Learning
• Data Sovereignty and Privacy
Major Challenges in Decentralized ML (Geo-Distributed Learning, Federated Learning)
• Challenge 1: Communication bottlenecks
• Solutions: Federated Averaging, Gaia, Deep Gradient Compression
Major Challenges in Decentralized ML (Geo-Distributed Learning, Federated Learning)
• Challenge 2: Data are often highly skewed (non-IID data)
• Solutions: understudied! Is it a real problem?
Our Work in a Nutshell
• Real-World Dataset
• Experimental Study
• Proposed Solution
Real-World Dataset
• Geographically tagged mammal images from Flickr
• 736K pictures in 42 mammal classes
• Highly skewed labels among geographic regions
Experimental Study
• Skewed data labels are a fundamental and pervasive problem
• The problem is even worse for DNNs with batch normalization
• The degree of skew determines the difficulty of the problem
Proposed Solution
• Replace batch normalization with group normalization
• SkewScout: communication-efficient decentralized learning over arbitrarily skewed data
Real-World Dataset
Flickr-Mammal Dataset
• 42 mammal classes from Open Images and ImageNet
• Clean images with PNAS [Liu et al., '18]; 40,000 clean images per class
• Reverse geocoding to country, subcontinent, and continent
• 736K pictures with labels and geographic information
• https://doi.org/10.5281/zenodo.3676081
Top-3 Mammals in Each Continent
• Each top-3 mammal takes a 44-92% share of its global images
Label Distribution Across Continents
[Chart: per-class label share across Africa, Americas, Asia, Europe, and Oceania for all 42 mammal classes, from alpaca to zebra]
• The vast majority of mammals are dominated by 2-3 continents
• The labels are even more skewed among subcontinents
Experimental Study
Scope of Experimental Study: decentralized learning algorithms × skewness of data label partitions × ML application
• Algorithms: Gaia [NSDI'17], FederatedAveraging [AISTATS'17], DeepGradientCompression [ICLR'18]
• Skewness: 2-5 label partitions (more partitions are worse)
• Applications: image classification (with various DNNs and datasets), face recognition
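Of the algorithms in this study, FederatedAveraging is the easiest to sketch: each partition trains locally for several steps, and only the averaged parameters are communicated. A minimal NumPy sketch on a toy quadratic objective (the function names and the objective are illustrative, not the experiments' actual setup):

```python
import numpy as np

def local_sgd(w, data, lr=0.1, steps=5):
    """A few local SGD steps on one partition's data (toy loss ||w - x||^2)."""
    w = w.copy()
    for _ in range(steps):
        x = data[np.random.randint(len(data))]
        grad = 2.0 * (w - x)            # gradient of ||w - x||^2 w.r.t. w
        w -= lr * grad
    return w

def federated_averaging(partitions, rounds=50, dim=2):
    """FedAvg: broadcast the global model, train locally, average the results."""
    w_global = np.zeros(dim)
    for _ in range(rounds):
        local_models = [local_sgd(w_global, part) for part in partitions]
        w_global = np.mean(local_models, axis=0)   # the only communication step
    return w_global
```

Averaging once every few local steps is what makes FedAvg communication-efficient, and also what makes it sensitive to skew: local models drift toward their partition's label distribution between averaging rounds.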
Results: GoogLeNet over CIFAR-10
[Chart: top-1 validation accuracy (0-80%), shuffled data vs. skewed data]
• BSP (Bulk Synchronous Parallel): no loss on skewed data
• Gaia (20X faster than BSP): -12% on skewed data
• FederatedAveraging (20X faster than BSP): -15%
• DeepGradientCompression (30X faster than BSP): -69%
• All decentralized learning algorithms lose significant accuracy
• Tight synchronization (BSP) is accurate but too slow
Similar Results across the Board
• Skewed data is a pervasive and fundamental problem
• Even BSP loses accuracy for DNNs with batch normalization layers
[Charts: top-1 validation accuracy, shuffled vs. skewed data, for BSP, Gaia, FederatedAveraging, and DeepGradientCompression across AlexNet, LeNet, and ResNet20 (image classification, CIFAR-10); GoogLeNet (image classification, Flickr-Mammal); face recognition (trained on CASIA, tested on LFW); and ResNet10 (image classification, ImageNet)]
Degree of Skew is a Key Factor
[Chart: top-1 validation accuracy (60-80%) on CIFAR-10 with GN-LeNet for BSP, Gaia, FederatedAveraging, and DeepGradientCompression under 20%, 40%, 60%, and 80% skewed data; accuracy losses range from -0.5% to -8.5% and grow with the degree of skew]
• Degree of skew can determine the difficulty of the problem
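To make "X% skewed" concrete, here is one way such controlled skew can be constructed. We read "X% skewed" as: X% of the samples are partitioned by label and the rest are shuffled uniformly; the helper below is a hypothetical illustration, not the paper's code:

```python
import numpy as np

def make_skewed_partitions(labels, n_partitions, skew_frac, seed=0):
    """Split sample indices into partitions with a controllable label skew.

    skew_frac of the samples are partitioned by label (label l goes to
    partition l % n_partitions); the rest are spread uniformly.
    skew_frac=0.0 gives IID partitions; skew_frac=1.0 gives fully
    label-partitioned (non-IID) data.
    """
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(labels))
    n_skewed = int(skew_frac * len(labels))
    skewed, uniform = idx[:n_skewed], idx[n_skewed:]

    parts = [[] for _ in range(n_partitions)]
    for i in skewed:                       # skewed share: route by label
        parts[labels[i] % n_partitions].append(i)
    for j, i in enumerate(uniform):        # uniform share: round-robin
        parts[j % n_partitions].append(i)
    return [np.array(p) for p in parts]
```

Sweeping `skew_frac` from 0.0 to 1.0 reproduces the kind of 20%-80% skew spectrum shown in the chart above.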
Batch Normalization ― Problem and Solution
Background: Batch Normalization [Ioffe & Szegedy, 2015]
• A BN layer between layers normalizes activations to a standard normal distribution (μ = 0, σ = 1) in each minibatch at training time
• At test time, it normalizes with the estimated global μ and σ
• Batch normalization enables larger learning rates and avoids sharp local minima (generalizes better)
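In NumPy terms, the training-time/test-time asymmetry looks roughly like this (a simplified per-channel sketch on (N, C) activations, ignoring backprop):

```python
import numpy as np

def batch_norm_train(x, gamma, beta, running_mean, running_var,
                     momentum=0.9, eps=1e-5):
    """Training time: normalize with the minibatch's own statistics.

    x: (N, C) activations. Also updates the running estimates of the
    global mean/variance that will be used at test time.
    """
    mu = x.mean(axis=0)
    var = x.var(axis=0)
    x_hat = (x - mu) / np.sqrt(var + eps)      # per-channel standardization
    running_mean[:] = momentum * running_mean + (1 - momentum) * mu
    running_var[:] = momentum * running_var + (1 - momentum) * var
    return gamma * x_hat + beta

def batch_norm_test(x, gamma, beta, running_mean, running_var, eps=1e-5):
    """Test time: normalize with the estimated global statistics."""
    x_hat = (x - running_mean) / np.sqrt(running_var + eps)
    return gamma * x_hat + beta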
Batch Normalization with Skewed Data Shuffled Data Skewed Data Minibatch Mean Divergence: 70% ||Mean 1 – Mean 2 || / AVG(Mean 1 , Mean 2 ) Minibatch Mean Divergence 35% 0% 0 2 4 6 8 10 12 14 16 18 20 22 24 26 28 30 Channel CIFAR-10 with BN-LeNet (2 Partitions) Minibatch μ and σ vary significantly among partitions Global μ and σ do not work for all partitions 20
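The slide's divergence metric can be sketched per channel as follows. We read AVG(Mean1, Mean2) as the mean of the two magnitudes; the exact formula is an assumption on our part:

```python
import numpy as np

def minibatch_mean_divergence(batch1, batch2):
    """Per-channel divergence between two partitions' minibatch means:
    |mean1 - mean2| / avg(|mean1|, |mean2|), on (N, C) activations."""
    m1 = batch1.mean(axis=0)
    m2 = batch2.mean(axis=0)
    return np.abs(m1 - m2) / ((np.abs(m1) + np.abs(m2)) / 2 + 1e-12)
```

A divergence near 70%, as in the skewed-data chart, means the two partitions' minibatch means differ by roughly 70% of their average magnitude, so no single global μ fits both.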
Solution: Use Group Normalization [Wu and He, ECCV'18]
• Group normalization computes statistics over groups of channels within each sample (over C, H, W), instead of across the minibatch (over N, H, W) as batch normalization does
• Designed for small minibatches; we apply it as a solution for skewed data
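A minimal NumPy version of group normalization on (N, C, H, W) activations, illustrating why its statistics are independent of which samples share the minibatch:

```python
import numpy as np

def group_norm(x, num_groups, gamma=1.0, beta=0.0, eps=1e-5):
    """GroupNorm: normalize each sample over groups of channels.

    x: (N, C, H, W). Statistics are computed per sample and per group,
    never across the batch dimension, so the output for one sample does
    not depend on the other samples in the minibatch.
    """
    n, c, h, w = x.shape
    g = x.reshape(n, num_groups, c // num_groups, h, w)
    mu = g.mean(axis=(2, 3, 4), keepdims=True)
    var = g.var(axis=(2, 3, 4), keepdims=True)
    g_hat = (g - mu) / np.sqrt(var + eps)
    return gamma * g_hat.reshape(n, c, h, w) + beta
```

Because nothing is averaged over N, skewed minibatch composition across partitions cannot perturb the normalization statistics, which is exactly the failure mode of batch normalization shown on the previous slide.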
Results with Group Normalization
[Chart: validation accuracy (0-80%), shuffled vs. skewed data, for BSP, Gaia, FederatedAveraging, and DeepGradientCompression with BatchNorm and with GroupNorm; skewed-data accuracy losses range from 0% to -70%, and GroupNorm consistently reduces them]
• GroupNorm recovers the accuracy loss for BSP and reduces accuracy losses for decentralized algorithms
SkewScout: Decentralized Learning over Arbitrarily Skewed Data
Overview of SkewScout
• Recall that the degree of data skew determines difficulty
• SkewScout adapts communication to the skew-induced accuracy loss: model travelling → accuracy loss estimation → communication control
• Minimizes communication when the accuracy loss is acceptable
• Works with different decentralized learning algorithms
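A heavily simplified sketch of that control loop: ship a remote partition's model over, estimate the skew-induced loss on local data, then tighten or relax a communication knob. The toy loss and the knob semantics are hypothetical; the real system tunes algorithm-specific parameters such as Gaia's significance threshold or FedAvg's local-epoch count:

```python
import numpy as np

def probe_loss(w, x):
    """Toy quadratic loss on a local probe batch (illustrative only)."""
    return float(np.mean((w - x) ** 2))

def estimate_accuracy_loss(w_local, w_remote, probe_batch):
    """'Model travelling': score a remote partition's model on local data
    and compare it against the local model's score."""
    return probe_loss(w_remote, probe_batch) - probe_loss(w_local, probe_batch)

def control_communication(est_loss, knob, target=0.05, factor=0.5):
    """Communication control: if the estimated skew-induced loss exceeds
    the target, tighten the knob (communicate more); otherwise relax it
    to save communication. knob in (0, 1], smaller = more communication."""
    if est_loss > target:
        return knob * factor            # too much loss: communicate more
    return min(knob / factor, 1.0)      # loss acceptable: communicate less
```

Because the controller only observes an estimated accuracy loss and adjusts one knob, the same loop can sit on top of different decentralized learning algorithms.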
Evaluation of SkewScout
• All data points achieve the same validation accuracy
[Chart: communication savings over BSP (times) for SkewScout vs. Oracle at 20%, 60%, and 100% skew, on CIFAR-10 with AlexNet and with GoogLeNet; savings range from 9.6X to 51.8X]
• Significant savings over BSP
• Only within 1.5X more communication than the Oracle
Key Takeaways
• Flickr-Mammal dataset: highly skewed label distribution in the real world
• Skewed data is a pervasive problem
• Batch normalization is particularly problematic
• SkewScout: adapts decentralized learning over arbitrarily skewed data
• Group normalization is a good alternative to batch normalization