Partial Transfer Learning with Selective Adversarial Networks
Zhangjie Cao¹, Mingsheng Long¹, Jianmin Wang¹, and Michael I. Jordan²
¹ KLiss, MOE; School of Software, Tsinghua University, China; National Engineering Laboratory for Big Data Software
² University of California, Berkeley, CA, USA
IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018 (Spotlight)
Motivation: Deep Transfer Learning
Deep learning across domains of different distributions: P ≠ Q
- Source and target domains follow different joint distributions: P(x, y) ≠ Q(x, y)
- Goal: learn a single model f: x → y that transfers from source to target
[Figure: source domain of 2D renderings vs. target domain of real images, from http://ai.bu.edu/visda-2017/]
Motivation: Deep Transfer Learning: Why?
A diagnostic flowchart over four data splits (training set, train-dev set, dev set, test set), from Andrew Ng, The Nuts and Bolts of Building Applications using Deep Learning, NIPS 2016 Tutorial:
- Training error high? Bias: use a deeper model, train longer.
- Train-dev error high? Variance: use bigger data, add regularization.
- Dev error high? Dataset shift: use transfer learning or data generation.
- Test error high? Overfit dev set: use bigger dev data.
- Otherwise: done (near the optimal Bayes rate).
Transfer learning is the remedy when the remaining error comes from dataset shift.
Motivation: Partial Transfer Learning
Deep learning across domains with different label spaces: C_s ⊃ C_t
- Positive transfer across domains in the shared label space, where P_{C_t} ≠ Q_{C_t}
- Negative transfer across domains from the outlier label space, where P_{C_s \ C_t} has no counterpart in Q_{C_t}
[Figure: a large source domain (chair, mug, TV, ...) vs. a small target domain containing only chair and mug]
Method: Partial Transfer Learning: How?
Matching distributions across the source and target domains such that P ≈ Q
- Reduce marginal distribution mismatch: P(X) ≠ Q(X)
- Reduce conditional distribution mismatch: P(Y|X) ≠ Q(Y|X)
Two main tools:
- Kernel embedding: represent a distribution by its mean feature map, $\mu_X = \mathbb{E}[z(X)]$, with finite-sample estimate $\hat{\mu}_X = \frac{1}{m}\sum_{i=1}^{m} z(x_i)$ (Song et al., Kernel Embeddings of Conditional Distributions, IEEE Signal Processing Magazine, 2013)
- Adversarial learning (Goodfellow et al., Generative Adversarial Networks, NIPS 2014)
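To make the kernel-embedding idea concrete, here is a minimal PyTorch sketch of the empirical mean embeddings and their distance, an MMD-style statistic. This is only an illustration of the distribution-matching principle; SAN itself uses adversarial learning rather than kernel embeddings, and the Gaussian kernel, bandwidth, and function names below are our own assumptions.

```python
# Illustrative sketch (NOT SAN's method): compare two samples via the squared
# distance between their empirical kernel mean embeddings, computed with the
# kernel trick instead of an explicit feature map z.
import torch

def gaussian_gram(a: torch.Tensor, b: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    """Gram matrix k(a_i, b_j) = exp(-||a_i - b_j||^2 / (2 sigma^2))."""
    sq_dists = torch.cdist(a, b) ** 2
    return torch.exp(-sq_dists / (2 * sigma ** 2))

def mmd2(source: torch.Tensor, target: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    """||mu_P - mu_Q||^2 where mu_P ≈ (1/m) Σ z(x_i), mu_Q ≈ (1/n) Σ z(x_j).
    Expanding the square gives E[k(s,s')] - 2 E[k(s,t)] + E[k(t,t')]."""
    k_ss = gaussian_gram(source, source, sigma).mean()
    k_st = gaussian_gram(source, target, sigma).mean()
    k_tt = gaussian_gram(target, target, sigma).mean()
    return k_ss - 2 * k_st + k_tt
```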
Method: Selective Adversarial Networks
[Figure: SAN architecture. A CNN feature extractor G_f feeds a label predictor G_y (loss L_y) and, through a gradient reversal layer (GRL), a bank of domain discriminators G_d^1, ..., G_d^K (losses L_d^1, ..., L_d^K). Back-propagation carries ∂L_y/∂θ_y and ∂L_y/∂θ_f to the predictor and features, ∂L_d/∂θ_d to the discriminators, and the reversed gradient −∂L_d/∂θ_f to the features.]
- f = G_f(x): feature extractor
- G_y, L_y: label predictor and its loss; ŷ: predicted class label
- G_d^k, L_d^k: k-th domain discriminator and its loss; d̂: predicted domain label
- GRL: gradient reversal layer
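For concreteness, below is a minimal PyTorch sketch of this architecture, assuming a generic backbone. The layer sizes, the `GradReverse` helper, and the `SAN` class name are illustrative choices of ours, not the authors' released implementation (the paper fine-tunes an AlexNet backbone).

```python
# Minimal SAN architecture sketch: shared feature extractor G_f, label
# predictor G_y, and |C_s| class-wise domain discriminators G_d^k behind a
# gradient reversal layer (GRL).
import torch
import torch.nn as nn

class GradReverse(torch.autograd.Function):
    """Identity on the forward pass; multiplies the gradient by -lambda."""
    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -ctx.lambd * grad_output, None

class SAN(nn.Module):
    def __init__(self, backbone: nn.Module, feat_dim: int, num_source_classes: int):
        super().__init__()
        self.G_f = backbone                                 # feature extractor
        self.G_y = nn.Linear(feat_dim, num_source_classes)  # label predictor
        # one binary domain discriminator per source class
        self.G_d = nn.ModuleList(
            nn.Sequential(nn.Linear(feat_dim, 1024), nn.ReLU(), nn.Linear(1024, 1))
            for _ in range(num_source_classes)
        )

    def forward(self, x, lambd: float = 1.0):
        f = self.G_f(x)
        y_logits = self.G_y(f)                              # class logits
        f_rev = GradReverse.apply(f, lambd)                 # GRL before discriminators
        d_logits = torch.cat([d(f_rev) for d in self.G_d], dim=1)  # (batch, |C_s|)
        return y_logits, d_logits
```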
Method: Selective Adversarial Networks
Instance Weighting (IW): probability-weighted loss for G_d^k, k = 1, ..., |C_s|

$$ L_d' = \frac{1}{n_s + n_t} \sum_{k=1}^{|\mathcal{C}_s|} \sum_{x_i \in \mathcal{D}_s \cup \mathcal{D}_t} \hat{y}_i^k \, L_d^k\!\left(G_d^k(G_f(x_i)), d_i\right) \tag{1} $$

Each example x_i contributes to discriminator G_d^k in proportion to its predicted probability ŷ_i^k of belonging to class k, so examples unlikely to come from class k barely influence that discriminator.
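A hedged mini-batch sketch of Eq. (1), following the architecture sketch above. The function name, tensor layout, and the use of binary cross-entropy for L_d^k are our assumptions.

```python
# Instance-weighted discriminator loss, Eq. (1): discriminator k's loss on
# each example is weighted by that example's predicted probability for class k.
import torch
import torch.nn.functional as F

def instance_weighted_domain_loss(d_logits, y_probs, domain_labels):
    """d_logits: (B, K) raw outputs of the K domain discriminators.
    y_probs: (B, K) softmax class probabilities yhat from G_y.
    domain_labels: (B,) 1 for source examples, 0 for target examples."""
    targets = domain_labels.float().unsqueeze(1).expand_as(d_logits)
    # per-example, per-discriminator binary cross-entropy L_d^k
    per_k = F.binary_cross_entropy_with_logits(d_logits, targets, reduction="none")
    # weight discriminator k's loss on x_i by yhat_i^k, then average the batch
    return (y_probs.detach() * per_k).sum(dim=1).mean()
```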
Method: Selective Adversarial Networks
Class Weighting (CW): down-weight G_d^k, k = 1, ..., |C_s|, for outlier classes

$$ L_d = \frac{1}{n_s + n_t} \sum_{k=1}^{|\mathcal{C}_s|} \sum_{x_i \in \mathcal{D}_s \cup \mathcal{D}_t} \left( \frac{1}{n_t} \sum_{x_j \in \mathcal{D}_t} \hat{y}_j^k \right) \hat{y}_i^k \, L_d^k\!\left(G_d^k(G_f(x_i)), d_i\right) \tag{2} $$

The class weight, the average predicted probability of class k over target data, is small for outlier source classes that target examples rarely predict, so their discriminators are suppressed.
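A sketch of Eq. (2) under the same assumptions as the instance-weighted version. Note one simplification: in practice the class weights would likely be estimated over the whole target set or a running average, while this illustration uses only the target examples in the current mini-batch (and assumes each batch contains some).

```python
# Class- and instance-weighted discriminator loss, Eq. (2): discriminator k
# is additionally down-weighted by the average target probability of class k,
# which is small for outlier source classes.
import torch
import torch.nn.functional as F

def class_weighted_domain_loss(d_logits, y_probs, domain_labels):
    """Same arguments as instance_weighted_domain_loss; domain_labels == 0
    marks target examples, which are used to estimate the class weights."""
    targets = domain_labels.float().unsqueeze(1).expand_as(d_logits)
    per_k = F.binary_cross_entropy_with_logits(d_logits, targets, reduction="none")
    # class weight w_k: mean predicted probability of class k on target data
    target_mask = (domain_labels == 0)
    w = y_probs[target_mask].detach().mean(dim=0)        # shape (K,)
    # combine class weight w_k with instance weight yhat_i^k
    return (w * y_probs.detach() * per_k).sum(dim=1).mean()
```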
Method: Selective Adversarial Networks
Entropy (uncertainty) minimization on target predictions, with $H(G_y(G_f(x_i))) = -\sum_{k=1}^{|\mathcal{C}_s|} \hat{y}_i^k \log \hat{y}_i^k$:

$$ E = \frac{1}{n_t} \sum_{x_i \in \mathcal{D}_t} H\!\left(G_y(G_f(x_i))\right) \tag{3} $$
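A minimal sketch of Eq. (3); the `eps` clamp is a numerical-stability assumption of ours, not part of the formula.

```python
# Entropy penalty, Eq. (3): encourage low-entropy (confident) predictions
# on unlabeled target data.
import torch

def target_entropy(y_logits_target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """E = (1/n_t) Σ_i H(G_y(G_f(x_i))), with H(p) = -Σ_k p_k log p_k."""
    p = torch.softmax(y_logits_target, dim=1)
    return -(p * torch.log(p + eps)).sum(dim=1).mean()
```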
Method: Selective Adversarial Networks
Overall objective:

$$ C\!\left(\theta_f, \theta_y, \theta_d^k \big|_{k=1}^{|\mathcal{C}_s|}\right) = \frac{1}{n_s} \sum_{x_i \in \mathcal{D}_s} L_y\!\left(G_y(G_f(x_i)), y_i\right) + \frac{1}{n_t} \sum_{x_i \in \mathcal{D}_t} H\!\left(G_y(G_f(x_i))\right) - \frac{1}{n_s + n_t} \sum_{k=1}^{|\mathcal{C}_s|} \sum_{x_i \in \mathcal{D}_s \cup \mathcal{D}_t} \left( \frac{1}{n_t} \sum_{x_j \in \mathcal{D}_t} \hat{y}_j^k \right) \hat{y}_i^k \, L_d^k\!\left(G_d^k(G_f(x_i)), d_i\right) \tag{4} $$

Minimax optimization:

$$ (\hat{\theta}_f, \hat{\theta}_y) = \arg\min_{\theta_f, \theta_y} C\!\left(\theta_f, \theta_y, \theta_d^k \big|_{k=1}^{|\mathcal{C}_s|}\right), \qquad (\hat{\theta}_d^1, \ldots, \hat{\theta}_d^{|\mathcal{C}_s|}) = \arg\max_{\theta_d^1, \ldots, \theta_d^{|\mathcal{C}_s|}} C\!\left(\theta_f, \theta_y, \theta_d^k \big|_{k=1}^{|\mathcal{C}_s|}\right) \tag{5} $$
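Putting the pieces together, here is a hypothetical training step for Eqs. (4)-(5), reusing the `SAN` model, `target_entropy`, and `class_weighted_domain_loss` sketches above. Because the GRL negates the domain-loss gradient with respect to θ_f, a single backward pass realizes the min-max: the discriminators learn to tell domains apart while the features learn to confuse them. The function signature, hyperparameters, and the single shared optimizer are assumptions.

```python
# One SAN training step, Eqs. (4)-(5): supervised source loss + target
# entropy + class/instance-weighted adversarial loss, min-maxed via the GRL.
import torch
import torch.nn.functional as F

def san_training_step(model, optimizer, x_s, y_s, x_t, lambd=1.0, gamma=1.0):
    optimizer.zero_grad()
    # source pass: supervised label loss L_y
    y_logits_s, d_logits_s = model(x_s, lambd)
    loss_y = F.cross_entropy(y_logits_s, y_s)
    # target pass: entropy penalty on unlabeled predictions
    y_logits_t, d_logits_t = model(x_t, lambd)
    loss_ent = target_entropy(y_logits_t)
    # weighted adversarial loss over both domains (helpers sketched earlier)
    d_logits = torch.cat([d_logits_s, d_logits_t], dim=0)
    y_probs = torch.softmax(torch.cat([y_logits_s, y_logits_t]), dim=1)
    domains = torch.cat([torch.ones(len(x_s)), torch.zeros(len(x_t))]).to(x_s.device)
    loss_d = class_weighted_domain_loss(d_logits, y_probs, domains)
    # the GRL already flips the d-loss gradient for theta_f, so terms just add
    (loss_y + gamma * loss_ent + loss_d).backward()
    optimizer.step()
```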
Evaluation: Setup
[Figure: example images from Office-Caltech; Office-Home (Art, Clipart, Product, Real World domains; classes such as Alarm-Clock, Bed, Bike, Chair, Desk-Lamp, Fan, Hammer, Kettle, Keyboard, Knife, Mug, Pen, Sink, Spoon, TV); and the VisDA 2017 challenge (pre-train on renderings, fine-tune on real images)]
Transfer tasks: Office-31 (31 → 10), Caltech-Office (256 → 10), and ImageNet-Caltech (I1000 → C84 and C256 → I84)
Evaluation: Results (classification accuracy, %)

Office-31

| Method        | A31→W10 | D31→W10 | W31→D10 | A31→D10 | D31→A10 | W31→A10 | Avg   |
|---------------|---------|---------|---------|---------|---------|---------|-------|
| AlexNet [2]   | 58.51   | 95.05   | 98.08   | 71.23   | 70.60   | 67.74   | 76.87 |
| DAN [3]       | 56.52   | 71.86   | 86.78   | 51.86   | 50.42   | 52.29   | 61.62 |
| RevGrad [1]   | 49.49   | 93.55   | 90.44   | 49.68   | 46.72   | 48.81   | 63.11 |
| RTN [4]       | 66.78   | 86.77   | 99.36   | 70.06   | 73.52   | 76.41   | 78.82 |
| ADDA [5]      | 70.68   | 96.44   | 98.65   | 72.90   | 74.26   | 75.56   | 81.42 |
| SAN-selective | 71.51   | 98.31   | 100.00  | 78.34   | 77.87   | 76.32   | 83.73 |
| SAN-entropy   | 74.61   | 98.31   | 100.00  | 80.29   | 78.39   | 82.25   | 85.64 |
| SAN           | 80.02   | 98.64   | 100.00  | 81.28   | 80.58   | 83.09   | 87.27 |

Caltech-Office and ImageNet-Caltech

| Method        | C256→W10 | C256→A10 | C256→D10 | Avg   | I1000→C84 | C256→I84 | Avg   |
|---------------|----------|----------|----------|-------|-----------|----------|-------|
| AlexNet [2]   | 58.44    | 76.64    | 65.86    | 66.98 | 52.37     | 47.35    | 49.86 |
| DAN [3]       | 42.37    | 70.75    | 47.04    | 53.39 | 54.21     | 52.03    | 53.12 |
| RevGrad [1]   | 54.57    | 72.86    | 57.96    | 61.80 | 51.34     | 47.02    | 49.18 |
| RTN [4]       | 71.02    | 81.32    | 62.35    | 71.56 | 63.69     | 50.45    | 57.07 |
| ADDA [5]      | 73.66    | 78.35    | 74.80    | 75.60 | 64.20     | 51.55    | 57.88 |
| SAN-selective | 76.44    | 81.63    | 80.25    | 79.44 | 66.78     | 51.25    | 59.02 |
| SAN-entropy   | 72.54    | 78.95    | 76.43    | 75.97 | 55.27     | 52.31    | 53.79 |
| SAN           | 88.33    | 83.82    | 85.35    | 85.83 | 68.45     | 55.61    | 62.03 |
Evaluation: Analysis
[Figure: (a) accuracy w.r.t. the number of target classes (from 31 down to 10) for AlexNet, DAN, RevGrad, RTN, and SAN; (b) test error over training iterations for DAN, RevGrad, and SAN]
- SAN outperforms RevGrad by a larger margin as the gap between source and target label spaces grows (i.e., with fewer target classes)
- SAN converges faster and more stably to a lower test error than RevGrad
Evaluation: Visualization
[Figure: t-SNE embeddings of features learned by (c)/(g) DAN, (d)/(h) RevGrad, (e)/(i) RTN, and (f)/(j) SAN, colored by class information (top) and domain information (bottom)]
References
[1] Y. Ganin, E. Ustinova, H. Ajakan, P. Germain, H. Larochelle, F. Laviolette, M. Marchand, and V. S. Lempitsky. Domain-adversarial training of neural networks. Journal of Machine Learning Research, 17(59):1–35, 2016.
[2] A. Krizhevsky, I. Sutskever, and G. E. Hinton. ImageNet classification with deep convolutional neural networks. In NIPS, 2012.
[3] M. Long, Y. Cao, J. Wang, and M. I. Jordan. Learning transferable features with deep adaptation networks. In ICML, 2015.
[4] M. Long, H. Zhu, J. Wang, and M. I. Jordan. Unsupervised domain adaptation with residual transfer networks. In NIPS, pages 136–144, 2016.
[5] E. Tzeng, J. Hoffman, K. Saenko, and T. Darrell. Adversarial discriminative domain adaptation. In CVPR, 2017.
Summary
- A novel selective adversarial network (SAN) for partial transfer learning
- Circumvents negative transfer by filtering out the outlier source classes
- Promotes positive transfer by matching the distributions in the shared label space
Code will be available soon at: https://github.com/thuml/
A CVPR 2018 work already follows our arXiv version: how fast the field moves!