

  1. INFOCOM’20: Optimizing Federated Learning on Non-IID Data with Reinforcement Learning. Hao Wang*, Zakhary Kaplan*, Di Niu^, Baochun Li* (*University of Toronto, ^University of Alberta)

  2. [Slide: voice assistants such as Alexa and Siri as everyday examples of machine learning]

  3. Machine Learning

  4. Federated Learning

  5. Federated Averaging Algorithm (FedAvg)
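The FedAvg aggregation step can be sketched as follows. This is an illustrative sketch, not the paper's implementation: `local_update` is a hypothetical placeholder for each device's local SGD, while the size-weighted averaging is the core of FedAvg.

```python
import numpy as np

def local_update(weights, data, lr=0.01, epochs=1):
    """Placeholder for local SGD on one device (illustrative only).
    A real system would run `epochs` of SGD on the device's private
    data; here the weights are returned unchanged for brevity."""
    return weights

def fedavg_round(global_weights, selected_devices, device_data):
    """One FedAvg round: broadcast, local training, weighted averaging."""
    updates, sizes = [], []
    for dev in selected_devices:
        w = local_update(np.copy(global_weights), device_data[dev])
        updates.append(w)
        sizes.append(len(device_data[dev]))
    total = sum(sizes)
    # Average local models, weighted by each device's dataset size
    return sum((n / total) * w for n, w in zip(sizes, updates))
```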

  6. [Slide: random selection; the server randomly selects devices, each holding a local model and local data]


  8. [Slide: devices send local model updates back to the server ("Thank you for the feedback")]

  9. ML algorithms assume the training data is independent and identically distributed (IID)

  10. Federated Learning reuses existing ML algorithms, but on non-IID data


  13. Non-IID data introduces bias into training and leads to slow convergence and training failures

  14. MNIST dataset: http://yann.lecun.com/exdb/mnist/

  15. [Chart: accuracy (%, 91–100) vs. communication round (#, 1–154) for FedAvg-IID and FedAvg-non-IID on MNIST; the non-IID curve converges more slowly]

  16. Build IID training data? No, we don’t have any access to the data on your phone.

  17. [Figure: the data-sharing approach; each device combines α × shared data with its private data] Zhao, Yue, et al. "Federated Learning with Non-IID Data." arXiv preprint arXiv:1806.00582 (2018).

  18. Optimizing Federated Learning on Non-IID Data with Reinforcement Learning [INFOCOM’20]

  19. Build IID training data? No. Instead, peek into the data distribution on each device without violating data privacy, by probing the bias of non-IID data.

  20. Carefully select devices to balance the bias introduced by non-IID data

  21. Probing the data distribution

  22. Setup: 100 devices, each with 600 samples. Non-IID data: 80% of a device's data has the same label, e.g., “6”. Initial model: a two-layer CNN with 431,080 parameters, which each device refines into a local model.
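A minimal sketch of how such a non-IID partition might be generated. The function name, sampling scheme, and defaults below are assumptions for illustration, not the authors' code; each device gets one dominant label covering 80% of its samples.

```python
import numpy as np

def make_noniid_partition(labels, n_devices=100, samples_per_device=600,
                          dominant_frac=0.8, n_classes=10, seed=0):
    """Assign each device a dominant label covering `dominant_frac` of its
    samples; the remainder is drawn from the other classes.
    (Illustrative sketch, not the paper's partitioning code.)"""
    rng = np.random.default_rng(seed)
    by_class = {c: np.where(labels == c)[0] for c in range(n_classes)}
    partition = []
    for dev in range(n_devices):
        dominant = dev % n_classes
        n_dom = int(dominant_frac * samples_per_device)
        # Dominant-label samples
        idx = list(rng.choice(by_class[dominant], n_dom, replace=True))
        # Remaining samples from all other classes
        others = np.where(labels != dominant)[0]
        idx += list(rng.choice(others, samples_per_device - n_dom, replace=True))
        partition.append(np.array(idx))
    return partition
```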

  23. We apply Principal Component Analysis (PCA) to reduce the 431,080-dimensional model weights to a 2-dimensional space
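The projection can be sketched with an SVD-based PCA; this is a generic sketch (the slides do not specify the PCA implementation used), mapping one flattened weight vector per device onto the first two principal components.

```python
import numpy as np

def pca_2d(weight_matrix):
    """Project flattened model weights (one row per device's model)
    onto their first two principal components via SVD."""
    centered = weight_matrix - weight_matrix.mean(axis=0)
    # Rows of vt are the principal directions, ordered by variance
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:2].T  # shape: (n_models, 2)
```

For 100 devices this would take a (100, 431080) matrix to a (100, 2) matrix, one point per local model.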

  24. [Scatter plot: local model weights projected onto principal components C0 (−0.2 to 0.6) and C1 (−0.2 to 0.4)]

  25. There is an implicit connection between model weights and data distribution

  26. Outline: probing the data distribution; selecting devices for federated learning


  28. [Scatter plot: the same PCA projection of local model weights (C0 vs. C1), revisited for device selection]

  29. K-Center Clustering
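The classic greedy 2-approximation for k-center could look like this. It is a sketch under the assumption that device models are points in the 2-D PCA space, and is not necessarily the exact procedure used in the talk: repeatedly pick the point farthest from the current set of centers.

```python
import numpy as np

def k_center_greedy(points, k, seed=0):
    """Greedy k-center: start from a random point, then repeatedly
    add the point with the largest distance to its nearest center."""
    rng = np.random.default_rng(seed)
    centers = [int(rng.integers(len(points)))]
    dists = np.linalg.norm(points - points[centers[0]], axis=1)
    for _ in range(k - 1):
        nxt = int(np.argmax(dists))  # farthest-from-centers point
        centers.append(nxt)
        dists = np.minimum(dists, np.linalg.norm(points - points[nxt], axis=1))
    return centers
```

Each resulting center anchors one group of devices with similar (projected) model weights.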

  30. Random Selection from Groups

  31. [Chart: accuracy (%, 91–100) vs. communication round (#, 1–151) for FedAvg-IID, FedAvg-non-IID, and K-Center-non-IID]

  32. Probing the data distribution; selecting devices for federated learning. How do we select devices to speed up training?

  33. It is difficult to select the appropriate subset of devices: model weights must be mapped to a device-selection choice, and the problem is dynamic and nondeterministic. This motivates Reinforcement Learning (RL).

  34. [Diagram: the RL loop; the agent observes a state from the environment (the FL server), takes an action, and receives a reward. An episode is a sequence (…, state, action, reward, state’, action’, …, end)]

  35. The agent learns from many episodes (…, state, action, reward, state’, action’, …, end) to maximize sum(reward)

  36. States: the global model weights and the local model weights, compressed into a 100-dimension vector

  37. Actions: select K devices from a pool of N devices, a huge action space. Selecting 10 devices from a pool of 100 devices leads to about 1.73 × 10^13 possible actions.
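The size of this action space follows directly from the binomial coefficient C(100, 10):

```python
import math

# Number of ways to pick 10 devices out of 100
n_actions = math.comb(100, 10)
print(n_actions)  # 17310309456440, i.e. about 1.73e13
```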

  38. Modify the RL training algorithm

  39. Selecting the Top K Devices: only one device is selected per action during RL training, so the action space becomes {1, 2, …, N} instead of selecting K devices from N devices

  40. [Diagram: the agent assigns each device a score (e.g., 0.3, 0.5, 0.1, …, 0.2) and the top K devices are selected]
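Top-K selection from per-device scores is straightforward to sketch; the scores below are the illustrative values from the slide.

```python
import numpy as np

def select_top_k(scores, k):
    """Return indices of the k devices with the highest scores."""
    return np.argsort(scores)[::-1][:k]

scores = np.array([0.3, 0.5, 0.1, 0.2])
print(select_top_k(scores, 2))  # selects devices 1 and 0
```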

  41. Rewards: r_t = Ξ^(ω_t − Ω) − 1, where ω_t is the training accuracy after communication round t, Ω is the target accuracy, and Ξ is a positive constant, with 0 ⩽ ω_t ⩽ Ω ⩽ 1 so that r_t ∈ (−1, 0]. Accuracy increase: ω_t ⬆ makes r_t ⬆. More communication rounds: t ⬆ makes sum(r_t) ⬇.

  42. Training the DRL Agent: look for a function that points out the actions leading to the maximum cumulative return under a particular state: Max R = ∑_{t=1}^{T} γ^(t−1) r_t = ∑_{t=1}^{T} γ^(t−1) (Ξ^(ω_t − Ω) − 1), with discount factor γ ∈ (0, 1).
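The per-round reward and the discounted return can be written out directly. The target accuracy Ω and the constant Ξ below are placeholder values for illustration, not the settings used in the paper.

```python
def reward(acc, target=0.99, xi=64.0):
    """Per-round reward r_t = xi**(acc - target) - 1, in (-1, 0]
    when 0 <= acc <= target (xi > 1 assumed for monotonicity)."""
    return xi ** (acc - target) - 1

def discounted_return(accuracies, gamma=0.99, target=0.99, xi=64.0):
    """R = sum over rounds t of gamma**(t-1) * r_t (t starting at 1)."""
    return sum(gamma ** t * reward(a, target, xi)
               for t, a in enumerate(accuracies))
```

Reaching the target accuracy gives r_t = 0; every extra round adds a negative term, so faster convergence yields a larger return.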

  43. [Diagram: the DDQN agent takes state s_{t−1} from the environment (the FL server), extracts features, applies a softmax over per-device outputs, issues action a_t, and receives reward r_t]

  44. [Chart: cumulative discounted reward vs. training episode while training the DRL agent; the reward climbs from about −110 toward 0 over roughly 170 episodes]

  45. [Diagram: workflow: device check-in, probing, device selection by the DRL agent, weight updates, and DRL agent updates]

  46. Evaluating Our Solution. Benchmarks: MNIST, FashionMNIST, CIFAR-10. Non-IID levels: 1, half-and-half, 80%, 50%.

  47. [Bar chart: communication rounds (0–2200) on MNIST, FashionMNIST, and CIFAR-10 at non-IID level 1, comparing FedAvg, K-Center, and Favor]

  48. [Bar chart: communication rounds (0–1600) on MNIST, FashionMNIST, and CIFAR-10 at non-IID level half-and-half, comparing FedAvg, K-Center, and Favor]

  49. [Bar chart: communication rounds (0–240) on MNIST, FashionMNIST, and CIFAR-10 at non-IID level 80%, comparing FedAvg, K-Center, and Favor]

  50. [Bar chart: communication rounds (0–70) on MNIST, FashionMNIST, and CIFAR-10 at non-IID level 50%, comparing FedAvg, K-Center, and Favor]

  51. [Scatter plots: trajectories of local and global model weights in PCA space (C1 vs. C2) starting from w_init; FedAvg (w_1 … w_5) vs. Favor (w_1 … w_4)]

  52. Summary: indirect data distribution probing plus DRL-based device selection. Communication rounds can be reduced by up to 49% on MNIST, 23% on FashionMNIST, and 42% on CIFAR-10.
