Deep Transfer Learning with Joint Adaptation Networks

  1. Deep Transfer Learning with Joint Adaptation Networks
     Mingsheng Long¹, Han Zhu¹, Jianmin Wang¹, Michael I. Jordan²
     ¹ School of Software, Institute for Data Science, Tsinghua University
     ² Department of EECS, Department of Statistics, University of California, Berkeley
     https://github.com/thuml
     International Conference on Machine Learning, 2017

  2. Outline
     1. Motivation: Deep Transfer Learning; Related Work; Main Idea
     2. Method: Kernel Embedding; JMMD; JAN
     3. Experiments

  3. Motivation / Deep Transfer Learning: Deep Learning
     Distribution: $(\mathbf{x}, y) \sim P(\mathbf{x}, y)$ (e.g., classes fish, bird, mammal, tree, flower, ...)
     Learner: $f : \mathbf{x} \to y$
     Error bound: $\epsilon_{\text{test}} \le \hat{\epsilon}_{\text{train}} + \sqrt{\dfrac{\text{complexity}}{n}}$

  4. Motivation / Deep Transfer Learning: Deep Transfer Learning
     Deep learning across domains of different distributions, $P \ne Q$.
     Source domain: 2D renderings; target domain: real images; $P(\mathbf{x}, y) \ne Q(\mathbf{x}, y)$.
     The same model $f : \mathbf{x} \to y$ must transfer its representation from source to target.
     http://ai.bu.edu/visda-2017/

  5. Motivation / Deep Transfer Learning: Deep Transfer Learning: Why?
     Diagnosis flow over Training / Train-Dev / Dev / Test sets, relative to the optimal Bayes rate:
     - Training error high? Bias: use a deeper model, train longer.
     - Train-dev error high? Variance: use bigger data, regularization.
     - Dev error high? Dataset shift: use transfer learning, data generation.
     - Test error high? Overfit dev set: use bigger dev data.
     - Otherwise: done!
     Andrew Ng. The Nuts and Bolts of Building Applications using Deep Learning. NIPS 2016 Tutorial.

  6. Motivation / Related Work: Deep Transfer Learning: How?
     Learning predictive models on transferable features s.t. $P(\mathbf{x}) = Q(\mathbf{x})$.
     Distribution matching: MMD (ICML'15), GAN (ICML'15, JMLR'16).
     (Figure: a supervised predictive model fit on the source domain scores 99% on source but only 28% on target; after transferring and adapting features so that $P(\mathbf{x}) \approx Q(\mathbf{x})$, accuracies are 98% and 72%.)

  7. Motivation / Related Work: Deep Adaptation Network (DAN)
     Architecture: input, conv1-conv5, fc6-fc8; conv1-conv3 frozen, conv4-conv5 fine-tuned, fc6-fc8 learned, with MK-MMD penalties between source and target on fc6-fc8.
     Deep adaptation: match distributions in multiple domain-specific layers.
     Optimal matching: maximize two-sample test power by multiple kernels.
     $$d_k^2(P, Q) \triangleq \left\| \mathbb{E}_P[\phi(\mathbf{x}^s)] - \mathbb{E}_Q[\phi(\mathbf{x}^t)] \right\|_{\mathcal{H}_k}^2 \quad (1)$$
     $$\min_{\theta \in \Theta} \max_{k \in \mathcal{K}} \frac{1}{n_a} \sum_{i=1}^{n_a} J(\theta(\mathbf{x}_i^a), y_i^a) + \lambda \sum_{\ell = l_1}^{l_2} d_k^2(\mathcal{D}_s^\ell, \mathcal{D}_t^\ell) \quad (2)$$
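To make the MK-MMD penalty of Eqs. (1)-(2) concrete, here is a minimal NumPy sketch of a multi-kernel MMD estimate between two feature batches. The Gaussian bandwidths and the uniform averaging over the kernel family are illustrative assumptions (DAN learns the kernel weighting rather than averaging), so treat this as a sketch, not the paper's implementation.

    import numpy as np

    def gaussian_kernel(X, Y, sigma):
        """k(x, y) = exp(-||x - y||^2 / (2 sigma^2)) for all pairs."""
        sq = np.sum(X**2, 1)[:, None] + np.sum(Y**2, 1)[None, :] - 2 * X @ Y.T
        return np.exp(-sq / (2 * sigma**2))

    def mk_mmd2(Xs, Xt, sigmas=(1.0, 2.0, 4.0)):
        """Biased estimate of d_k^2(P, Q), averaged over a kernel family."""
        total = 0.0
        for s in sigmas:
            total += (gaussian_kernel(Xs, Xs, s).mean()
                      + gaussian_kernel(Xt, Xt, s).mean()
                      - 2 * gaussian_kernel(Xs, Xt, s).mean())
        return total / len(sigmas)

Applied to the fc6-fc8 activations of source and target batches, one such term per layer yields the sum over layers in Eq. (2).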

  8. Motivation / Related Work: Domain Adversarial Neural Network (DANN)
     Adversarial adaptation: learning features indistinguishable across domains.
     $$E(\theta_f, \theta_y, \theta_d) = \sum_{\mathbf{x}_i \in \mathcal{D}_s} L_y(G_y(G_f(\mathbf{x}_i)), y_i) - \lambda \sum_{\mathbf{x}_i \in \mathcal{D}_s \cup \mathcal{D}_t} L_d(G_d(G_f(\mathbf{x}_i)), d_i) \quad (3)$$
     $$(\hat{\theta}_f, \hat{\theta}_y) = \arg\min_{\theta_f, \theta_y} E(\theta_f, \theta_y, \theta_d), \qquad \hat{\theta}_d = \arg\max_{\theta_d} E(\theta_f, \theta_y, \theta_d) \quad (4)$$
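A minimal PyTorch sketch of the DANN objective in Eqs. (3)-(4), using the standard gradient-reversal trick so a single backward pass realizes the saddle point: the domain discriminator G_d descends its loss while the features G_f ascend it. The module names follow the slide, but their architectures and the lambda_ value are assumptions.

    import torch
    from torch import nn

    class GradReverse(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, lambda_):
            ctx.lambda_ = lambda_
            return x.view_as(x)

        @staticmethod
        def backward(ctx, grad):
            # Reverse (and scale) the gradient on its way back to G_f.
            return -ctx.lambda_ * grad, None

    def dann_loss(G_f, G_y, G_d, x_s, y_s, x_t, lambda_=1.0):
        """One-loss realization of the min-max in Eqs. (3)-(4)."""
        f_s, f_t = G_f(x_s), G_f(x_t)
        cls_loss = nn.functional.cross_entropy(G_y(f_s), y_s)
        feats = torch.cat([f_s, f_t])
        dom_logits = G_d(GradReverse.apply(feats, lambda_))
        dom = torch.cat([torch.zeros(len(x_s), dtype=torch.long),
                         torch.ones(len(x_t), dtype=torch.long)]).to(feats.device)
        return cls_loss + nn.functional.cross_entropy(dom_logits, dom)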

  9. Motivation / Main Idea: Behavior of Existing Work
     Adaptation of the marginal distributions $P(\mathbf{x})$ and $Q(\mathbf{x})$ is not sufficient.
     Before adaptation: $P(\mathbf{x}) \ne Q(\mathbf{x})$. After adaptation: $P(\mathbf{x}) \approx Q(\mathbf{x})$, yet the joint distributions $P(\mathbf{x}, y)$ and $Q(\mathbf{x}, y)$ can still differ.

  10. Motivation / Main Idea: Main Idea of This Work
      Directly model and match the joint distributions $P(\mathbf{x}, y)$ and $Q(\mathbf{x}, y)$.
      Prior work matches marginal distributions, $P(\mathbf{x}) \approx Q(\mathbf{x})$; this work matches joint distributions, $P(\mathbf{x}, y) \approx Q(\mathbf{x}, y)$.

  11. Outline
      1. Motivation: Deep Transfer Learning; Related Work; Main Idea
      2. Method: Kernel Embedding; JMMD; JAN
      3. Experiments

  12. Method / Kernel Embedding: Kernel Embedding of Distributions
      Le Song et al. Kernel Embeddings of Conditional Distributions. IEEE, 2013.

  13. Method / Kernel Embedding: Kernel Embedding of Joint Distributions
      $$\mathcal{C}_{X^{1:m}}(P) \triangleq \mathbb{E}_{X^{1:m}}\!\left[ \otimes_{\ell=1}^{m} \phi^\ell(X^\ell) \right] \approx \frac{1}{n} \sum_{i=1}^{n} \otimes_{\ell=1}^{m} \phi^\ell(\mathbf{x}_i^\ell) \quad (5)$$
      Le Song et al. Kernel Embeddings of Conditional Distributions. IEEE, 2013.
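As a concrete reading of Eq. (5), the following NumPy sketch forms the empirical joint embedding for m = 2 variables using explicit finite-dimensional feature maps, so the tensor product becomes an ordinary outer product. The crude random-feature maps phi1 and phi2 are assumptions standing in for the (possibly infinite-dimensional) RKHS maps.

    import numpy as np

    rng = np.random.default_rng(0)
    W1 = rng.normal(size=(5, 64))   # projection for the first variable (dim 5)
    W2 = rng.normal(size=(3, 64))   # projection for the second variable (dim 3)

    def phi1(X):
        # Crude random-feature map standing in for an RKHS feature map.
        return np.cos(X @ W1) / np.sqrt(W1.shape[1])

    def phi2(X):
        return np.cos(X @ W2) / np.sqrt(W2.shape[1])

    def joint_embedding(X1, X2):
        """Empirical C_{X^{1:2}}: (1/n) sum_i phi1(x1_i) (outer) phi2(x2_i)."""
        F1, F2 = phi1(X1), phi2(X2)              # both (n, 64)
        return np.einsum('ni,nj->ij', F1, F2) / len(X1)

    X1, X2 = rng.normal(size=(100, 5)), rng.normal(size=(100, 3))
    C_hat = joint_embedding(X1, X2)              # a (64, 64) embedding matrix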

  14. Method / JMMD: Joint Maximum Mean Discrepancy (JMMD)
      Distance between embeddings of $P(Z^{s1}, \ldots, Z^{s|\mathcal{L}|})$ and $Q(Z^{t1}, \ldots, Z^{t|\mathcal{L}|})$:
      $$D_{\mathcal{L}}(P, Q) \triangleq \left\| \mathcal{C}_{Z^{s,1:|\mathcal{L}|}}(P) - \mathcal{C}_{Z^{t,1:|\mathcal{L}|}}(Q) \right\|_{\otimes_{\ell=1}^{|\mathcal{L}|} \mathcal{H}^\ell}^2 \quad (6)$$
      $$\hat{D}_{\mathcal{L}}(P, Q) = \frac{1}{n_s^2} \sum_{i=1}^{n_s} \sum_{j=1}^{n_s} \prod_{\ell \in \mathcal{L}} k^\ell(\mathbf{z}_i^{s\ell}, \mathbf{z}_j^{s\ell}) + \frac{1}{n_t^2} \sum_{i=1}^{n_t} \sum_{j=1}^{n_t} \prod_{\ell \in \mathcal{L}} k^\ell(\mathbf{z}_i^{t\ell}, \mathbf{z}_j^{t\ell}) - \frac{2}{n_s n_t} \sum_{i=1}^{n_s} \sum_{j=1}^{n_t} \prod_{\ell \in \mathcal{L}} k^\ell(\mathbf{z}_i^{s\ell}, \mathbf{z}_j^{t\ell}) \quad (7)$$
      Theorem (Two-Sample Test, Gretton et al. 2012): $P = Q$ if and only if $D_{\mathcal{L}}(P, Q) = 0$ (in practice, $\hat{D}_{\mathcal{L}}(P, Q) < \varepsilon$).
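A NumPy sketch of the quadratic-time estimate in Eq. (7): the tensor-product kernel is just the elementwise product of the per-layer Gram matrices, followed by the usual MMD sums. The single shared Gaussian bandwidth is a simplifying assumption.

    import numpy as np

    def gaussian_gram(X, Y, sigma=1.0):
        """Gram matrix k(x_i, y_j) = exp(-||x_i - y_j||^2 / (2 sigma^2))."""
        sq = np.sum(X**2, 1)[:, None] + np.sum(Y**2, 1)[None, :] - 2 * X @ Y.T
        return np.exp(-sq / (2 * sigma**2))

    def jmmd2(Zs_layers, Zt_layers, sigma=1.0):
        """Eq. (7): multiply per-layer Grams elementwise, then average."""
        Kss = Ktt = Kst = 1.0
        for Zs, Zt in zip(Zs_layers, Zt_layers):
            Kss = Kss * gaussian_gram(Zs, Zs, sigma)
            Ktt = Ktt * gaussian_gram(Zt, Zt, sigma)
            Kst = Kst * gaussian_gram(Zs, Zt, sigma)
        return Kss.mean() + Ktt.mean() - 2 * Kst.mean()

With a single layer this reduces to ordinary MMD; with L = {features, predictions} it compares joint distributions, as in Eq. (6).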

  15. Method / JMMD: How to Understand JMMD?
      Set last-layer features $Z = Z^{L-1}$ and classifier predictions $Y = Z^{L} \in \mathbb{R}^C$.
      We can understand JMMD($Z$, $Y$) by simplifying it to a linear kernel; this interpretation assumes the classifier predictions $Y$ are one-hot vectors:
      $$\hat{D}_{\mathcal{L}}(P, Q) \triangleq \left\| \frac{1}{n_s} \sum_{i=1}^{n_s} \mathbf{z}_i^s \otimes \mathbf{y}_i^s - \frac{1}{n_t} \sum_{j=1}^{n_t} \mathbf{z}_j^t \otimes \mathbf{y}_j^t \right\|^2 = \sum_{c=1}^{C} \left\| \frac{1}{n_s} \sum_{i=1}^{n_s} y_{i,c}^s \mathbf{z}_i^s - \frac{1}{n_t} \sum_{j=1}^{n_t} y_{j,c}^t \mathbf{z}_j^t \right\|^2 \approx \sum_{c=1}^{C} D\!\left( P_{Z \mid y=c}, Q_{Z \mid y=c} \right) \quad (8)$$
      Equivalent to matching the distributions $P$ and $Q$ conditioned on each class!
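The exact equality in Eq. (8) can be verified numerically. This tiny NumPy check (with randomly drawn one-hot labels, an illustrative setup) confirms that the linear-kernel joint-embedding distance equals the sum of per-class mean discrepancies.

    import numpy as np

    rng = np.random.default_rng(1)
    C, d, n = 3, 4, 60
    zs, zt = rng.normal(size=(n, d)), rng.normal(size=(n, d))
    ys = np.eye(C)[rng.integers(0, C, n)]   # one-hot labels / predictions
    yt = np.eye(C)[rng.integers(0, C, n)]

    # Left side: Frobenius norm of the difference of mean outer products.
    lhs = np.linalg.norm(np.einsum('ni,nj->ij', zs, ys) / n
                         - np.einsum('ni,nj->ij', zt, yt) / n) ** 2
    # Right side: per-class discrepancy of weighted feature means.
    rhs = sum(np.linalg.norm(ys[:, c] @ zs / n - yt[:, c] @ zt / n) ** 2
              for c in range(C))
    assert np.isclose(lhs, rhs)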

  16. Method / JMMD: How to Understand JMMD?
      JMMD can process continuous softmax activations (probabilities).
      In practice, a Gaussian kernel is used, matching all orders of moments.

  17. Method / JAN: Joint Adaptation Network (JAN)
      Architecture: a tied (shared-weight) backbone (AlexNet, VGGnet, GoogLeNet, ResNet, ...) maps $X^s$ and $X^t$ through layers $\phi^1, \ldots, \phi^L$ to activations $Z^{s1}, \ldots, Z^{s|\mathcal{L}|}, Y^s$ and $Z^{t1}, \ldots, Z^{t|\mathcal{L}|}, Y^t$, with JMMD applied across the task-specific layers.
      Joint adaptation: match joint distributions of multiple task-specific layers.
      $$\min_f \frac{1}{n_s} \sum_{i=1}^{n_s} J(f(\mathbf{x}_i^s), y_i^s) + \lambda \hat{D}_{\mathcal{L}}(P, Q) \quad (9)$$
      $$D_{\mathcal{L}}(P, Q) \triangleq \left\| \mathcal{C}_{Z^{s,1:|\mathcal{L}|}}(P) - \mathcal{C}_{Z^{t,1:|\mathcal{L}|}}(Q) \right\|_{\otimes_{\ell=1}^{|\mathcal{L}|} \mathcal{H}^\ell}^2 \quad (10)$$
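A PyTorch sketch of one SGD step on the JAN objective of Eq. (9), adapting the layer set L = {features, softmax predictions}. The backbone and classifier modules and a torch analogue of the jmmd2 estimator sketched after slide 14 are assumed interfaces, not the paper's implementation.

    import torch
    from torch import nn

    def jan_step(backbone, classifier, opt, x_s, y_s, x_t, jmmd2, lam=1.0):
        """One SGD step on Eq. (9) with L = {features, softmax outputs}."""
        f_s, f_t = backbone(x_s), backbone(x_t)
        logits_s, logits_t = classifier(f_s), classifier(f_t)
        # JMMD over the joint distribution of the two task-specific layers.
        transfer = jmmd2([f_s, logits_s.softmax(dim=1)],
                         [f_t, logits_t.softmax(dim=1)])
        loss = nn.functional.cross_entropy(logits_s, y_s) + lam * transfer
        opt.zero_grad()
        loss.backward()
        opt.step()
        return loss.item()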

  18. Method / JAN: Learning Algorithm
      Linear-time $O(n)$ algorithm for JMMD (streaming algorithm):
      $$\hat{D}_{\mathcal{L}}(P, Q) = \frac{2}{n} \sum_{i=1}^{n/2} \left[ \prod_{\ell \in \mathcal{L}} k^\ell(\mathbf{z}_{2i-1}^{s\ell}, \mathbf{z}_{2i}^{s\ell}) + \prod_{\ell \in \mathcal{L}} k^\ell(\mathbf{z}_{2i-1}^{t\ell}, \mathbf{z}_{2i}^{t\ell}) - \prod_{\ell \in \mathcal{L}} k^\ell(\mathbf{z}_{2i-1}^{s\ell}, \mathbf{z}_{2i}^{t\ell}) - \prod_{\ell \in \mathcal{L}} k^\ell(\mathbf{z}_{2i-1}^{t\ell}, \mathbf{z}_{2i}^{s\ell}) \right] = \frac{2}{n} \sum_{i=1}^{n/2} d\!\left( \{\mathbf{z}_{2i-1}^{s\ell}, \mathbf{z}_{2i}^{s\ell}, \mathbf{z}_{2i-1}^{t\ell}, \mathbf{z}_{2i}^{t\ell}\}_{\ell \in \mathcal{L}} \right) \quad (11)$$
      SGD: for each layer $\ell$ and each quad-tuple $\{\mathbf{z}_{2i-1}^{s\ell}, \mathbf{z}_{2i}^{s\ell}, \mathbf{z}_{2i-1}^{t\ell}, \mathbf{z}_{2i}^{t\ell}\}$,
      $$\nabla_{W^\ell} = \frac{\partial J(\mathbf{z}_{2i-1}^s, y_{2i-1}^s; \mathbf{z}_{2i}^s, y_{2i}^s)}{\partial W^\ell} + \lambda \frac{\partial d\!\left( \{\mathbf{z}_{2i-1}^{s\ell}, \mathbf{z}_{2i}^{s\ell}, \mathbf{z}_{2i-1}^{t\ell}, \mathbf{z}_{2i}^{t\ell}\}_{\ell \in \mathcal{L}} \right)}{\partial W^\ell} \quad (12)$$
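A NumPy sketch of the linear-time streaming estimator in Eq. (11): samples are consumed in quad-tuples, so cost is O(n) rather than the O(n^2) of Eq. (7). The fixed Gaussian bandwidth is again an assumption.

    import numpy as np

    def gk(a, b, sigma=1.0):
        """Row-wise Gaussian kernel between paired samples."""
        return np.exp(-np.sum((a - b) ** 2, -1) / (2 * sigma**2))

    def jmmd2_linear(Zs_layers, Zt_layers):
        """Eq. (11); each list entry has shape (n, d_l) with n even."""
        prod_ss = prod_tt = prod_st = prod_ts = 1.0
        for Zs, Zt in zip(Zs_layers, Zt_layers):
            s1, s2 = Zs[0::2], Zs[1::2]      # z^s_{2i-1}, z^s_{2i}
            t1, t2 = Zt[0::2], Zt[1::2]      # z^t_{2i-1}, z^t_{2i}
            prod_ss = prod_ss * gk(s1, s2)   # product over layers l in L
            prod_tt = prod_tt * gk(t1, t2)
            prod_st = prod_st * gk(s1, t2)
            prod_ts = prod_ts * gk(t1, s2)
        # Mean over the n/2 quad-tuples equals (2/n) * sum in Eq. (11).
        return np.mean(prod_ss + prod_tt - prod_st - prod_ts)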

  19. Method / JAN: Adversarial Joint Adaptation Network (JAN-A)
      Architecture: the same tied backbone as JAN (AlexNet, VGGnet, GoogLeNet, ResNet, ...), with adversarial networks $\theta$ inserted on each adapted layer before the JMMD.
      Optimal matching: maximize JMMD as a semi-parametric domain adversary.
      $$\min_f \max_\theta \frac{1}{n_s} \sum_{i=1}^{n_s} J(f(\mathbf{x}_i^s), y_i^s) + \lambda \hat{D}_{\mathcal{L}}(P, Q; \theta) \quad (13)$$
      $$\hat{D}_{\mathcal{L}}(P, Q; \theta) = \frac{2}{n} \sum_{i=1}^{n/2} d\!\left( \{\theta^\ell(\mathbf{z}_{2i-1}^{s\ell}, \mathbf{z}_{2i}^{s\ell}, \mathbf{z}_{2i-1}^{t\ell}, \mathbf{z}_{2i}^{t\ell})\}_{\ell \in \mathcal{L}} \right) \quad (14)$$
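A PyTorch sketch of alternating updates for the min-max in Eqs. (13)-(14): the adversarial nets theta ascend JMMD on detached activations, then the backbone and classifier descend the full objective. The interface net(x) -> (logits, [activations in L]), the theta modules, and the two optimizers are assumptions about structure, not the paper's implementation.

    import torch
    from torch.nn import functional as F

    def jan_a_step(net, thetas, opt_net, opt_theta, x_s, y_s, x_t, jmmd2, lam=1.0):
        """One alternating step; `jmmd2` is a torch JMMD estimator."""
        # Step 1: gradient *ascent* on theta, feature network frozen.
        _, zs = net(x_s)
        _, zt = net(x_t)
        d = jmmd2([th(z.detach()) for th, z in zip(thetas, zs)],
                  [th(z.detach()) for th, z in zip(thetas, zt)])
        opt_theta.zero_grad()
        (-d).backward()                      # maximize JMMD w.r.t. theta
        opt_theta.step()

        # Step 2: gradient descent on the network: J + lambda * JMMD.
        logits_s, zs = net(x_s)
        _, zt = net(x_t)
        d = jmmd2([th(z) for th, z in zip(thetas, zs)],
                  [th(z) for th, z in zip(thetas, zt)])
        loss = F.cross_entropy(logits_s, y_s) + lam * d
        opt_net.zero_grad()
        opt_theta.zero_grad()                # discard theta grads from this pass
        loss.backward()
        opt_net.step()
        return loss.item()

Gradient reversal (as in DANN) would fold the two steps into one backward pass; the alternating form above keeps the two roles of Eq. (13) explicit.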

  20. Outline
      1. Motivation: Deep Transfer Learning; Related Work; Main Idea
      2. Method: Kernel Embedding; JMMD; JAN
      3. Experiments

  21. Experiments: Datasets
      Office-Caltech, ImageCLEF Challenge 2014, and VisDA Challenge 2017; models are pre-trained, then fine-tuned on each dataset.
