Generative Models and Optimal Transport (Marco Cuturi)

Generative Models and Optimal Transport. Marco Cuturi. Joint work / work in progress with G. Peyré, A. Genevay (ENS), F. Bach (INRIA), G. Montavon, K.-R. Müller (TU Berlin). Statistics 0.1: Density Fitting. We collect data N data = 1 X x


  1. Wasserstein Distances. Def. For $p \ge 1$, the $p$-Wasserstein distance between $\mu, \nu \in \mathcal{P}(\Omega)$, defined by a metric $D$ on $\Omega$, is (PRIMAL) $W_p^p(\mu, \nu) \overset{\text{def}}{=} \inf_{P \in \Pi(\mu, \nu)} \iint D(x, y)^p \, P(dx, dy)$.

  2. Wasserstein Distances. Def. For $p \ge 1$, the $p$-Wasserstein distance between $\mu, \nu \in \mathcal{P}(\Omega)$, defined by a metric $D$ on $\Omega$, is (PRIMAL) $W_p^p(\mu, \nu) \overset{\text{def}}{=} \inf_{P \in \Pi(\mu, \nu)} \iint D(x, y)^p \, P(dx, dy)$, or equivalently (DUAL) $W_p^p(\mu, \nu) = \sup \int \varphi \, d\mu + \int \psi \, d\nu$ over $\varphi \in L_1(\mu)$, $\psi \in L_1(\nu)$ such that $\varphi(x) + \psi(y) \le D^p(x, y)$.

  3. W is versatile: Discrete - Discrete, Discrete - Continuous, Continuous - Continuous.

  4. W is versatile. Discrete - Discrete: network flow solvers, entropic regularization. Discrete - Continuous: low dim. [M’11][KMB’16][L’15]. Continuous - Continuous: stochastic optimization [GCPB’16].

  5. Minimum Kantorovich Estimators: $\min_{\theta \in \Theta} W(\nu_{\text{data}}, f_{\theta\sharp}\mu)$. • [Bassetti’06] 1st reference discussing this approach. • [MMC’16] use regularization in a finite setting. • [ACB’17] (WGAN), [BJGR’17] (Wasserstein ABC). • Hot topics: approximate & differentiate W efficiently. • Today: ideas from our recent preprint [GPC’17].
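
To make the estimator on this slide concrete, here is a minimal sketch under strong simplifying assumptions that are not in the slides: the data live on the real line, the model is a hypothetical shift $f_\theta(z) = z + \theta$, and $\theta$ is found by grid search. In one dimension, for two uniform samples of equal size, $W_2^2$ is obtained by matching sorted points, which is what `w2_squared_1d` does.

```python
def w2_squared_1d(xs, ys):
    """W_2^2 between two uniform samples of equal size on the real line.

    In 1-D the optimal assignment matches sorted points to sorted points.
    """
    assert len(xs) == len(ys)
    xs, ys = sorted(xs), sorted(ys)
    return sum((x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def fit_shift(data, latent, thetas):
    """Grid search for the theta minimizing W_2^2(data, latent + theta)."""
    return min(thetas, key=lambda t: w2_squared_1d(data, [z + t for z in latent]))


# Hypothetical toy data: the latent sample shifted by 3.
data = [2.0, 3.0, 4.0]
latent = [-1.0, 0.0, 1.0]
thetas = [0.5 * k for k in range(11)]  # candidate shifts 0.0, 0.5, ..., 5.0

best_theta = fit_shift(data, latent, thetas)
print(best_theta)
```

A grid search is used only to keep the sketch short; the slides point instead to methods that approximate and differentiate W.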

  6. Wasserstein between 2 Diracs. For $\delta_x$ and $\delta_y$ on $(\Omega, D)$: $W_p^p(\delta_x, \delta_y) = D(x, y)^p$.

  7. Wasserstein on Uniform Measures. On $(\Omega, D)$: $\mu = \frac{1}{n} \sum_{i=1}^{n} \delta_{x_i}$, $\nu = \frac{1}{n} \sum_{j=1}^{n} \delta_{y_j}$.

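  8. Wasserstein on Uniform Measures. On $(\Omega, D)$: $\mu = \frac{1}{n} \sum_{i=1}^{n} \delta_{x_i}$, $\nu = \frac{1}{n} \sum_{j=1}^{n} \delta_{y_j}$. Cost of a permutation $\sigma$: $C(\sigma) = \frac{1}{n} \sum_{i=1}^{n} D(x_i, y_{\sigma_i})^p$.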

  9. Optimal Assignment ⊂ Wasserstein. On $(\Omega, D)$: $\mu = \frac{1}{n} \sum_{i=1}^{n} \delta_{x_i}$, $\nu = \frac{1}{n} \sum_{j=1}^{n} \delta_{y_j}$. Then $W_p^p(\mu, \nu) = \min_{\sigma \in S_n} C(\sigma)$.
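
The identity on this slide can be checked by brute force on a tiny instance. The sketch below assumes points on the real line with $D(x, y) = |x - y|$ (an illustrative choice, not from the slides) and enumerates all of $S_n$, which is only feasible for very small $n$.

```python
from itertools import permutations


def assignment_cost(xs, ys, sigma, p=2):
    """C(sigma) = (1/n) * sum_i D(x_i, y_sigma(i))^p with D(x, y) = |x - y|."""
    n = len(xs)
    return sum(abs(xs[i] - ys[sigma[i]]) ** p for i in range(n)) / n


def wasserstein_pp_uniform(xs, ys, p=2):
    """Brute-force W_p^p between two uniform measures on n points each.

    Enumerates all n! permutations, so this is for illustration only.
    """
    n = len(xs)
    best = min(permutations(range(n)),
               key=lambda s: assignment_cost(xs, ys, s, p))
    return assignment_cost(xs, ys, best, p), best


# Hypothetical toy points.
xs = [0.0, 1.0, 3.0]
ys = [3.5, 0.5, 1.0]
cost, sigma = wasserstein_pp_uniform(xs, ys, p=2)
print(cost, sigma)
```

Here `sigma[i]` is the index of the point of `ys` matched to `xs[i]`.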

  10. OT on Two Empirical Measures. On $(\Omega, D)$: $\mu = \sum_{i=1}^{n} a_i \delta_{x_i}$, $\nu = \sum_{j=1}^{m} b_j \delta_{y_j}$.

  12. Wasserstein on Empirical Measures. Consider $\mu = \sum_{i=1}^{n} a_i \delta_{x_i}$ and $\nu = \sum_{j=1}^{m} b_j \delta_{y_j}$. $M_{XY} \overset{\text{def}}{=} [D(x_i, y_j)^p]_{ij}$, $U(a, b) \overset{\text{def}}{=} \{ P \in \mathbb{R}_+^{n \times m} \mid P \mathbf{1}_m = a, \; P^T \mathbf{1}_n = b \}$. (Illustration: the $n \times m$ table with entries $D(x_i, y_j)^p$, rows labeled $x_1, \dots, x_n$ with weights $a_1, \dots, a_n$ and columns labeled $y_1, \dots, y_m$ with weights $b_1, \dots, b_m$; the row constraint $P \mathbf{1}_m = a$ is highlighted.)

  13. Wasserstein on Empirical Measures. Consider $\mu = \sum_{i=1}^{n} a_i \delta_{x_i}$ and $\nu = \sum_{j=1}^{m} b_j \delta_{y_j}$. $M_{XY} \overset{\text{def}}{=} [D(x_i, y_j)^p]_{ij}$, $U(a, b) \overset{\text{def}}{=} \{ P \in \mathbb{R}_+^{n \times m} \mid P \mathbf{1}_m = a, \; P^T \mathbf{1}_n = b \}$. (Illustration: the same $n \times m$ table with entries $D(x_i, y_j)^p$; the column constraint $P^T \mathbf{1}_n = b$ is highlighted.)

  14. Wasserstein on Empirical Measures. Consider $\mu = \sum_{i=1}^{n} a_i \delta_{x_i}$ and $\nu = \sum_{j=1}^{m} b_j \delta_{y_j}$. $M_{XY} \overset{\text{def}}{=} [D(x_i, y_j)^p]_{ij}$, $U(a, b) \overset{\text{def}}{=} \{ P \in \mathbb{R}_+^{n \times m} \mid P \mathbf{1}_m = a, \; P^T \mathbf{1}_n = b \}$. Def. Optimal Transport Problem: $W_p^p(\mu, \nu) = \min_{P \in U(a, b)} \langle P, M_{XY} \rangle$.
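
As a small illustration of these definitions, the sketch below builds $M_{XY}$ for hypothetical 1-D points with $D(x, y) = |x - y|$ and $p = 2$, checks membership in $U(a, b)$, and evaluates $\langle P, M_{XY} \rangle$ for the independent coupling $P = a b^T$. That coupling is always feasible, so its cost is an upper bound on $W_p^p(\mu, \nu)$; it is not the optimal plan in general.

```python
def cost_matrix(xs, ys, p=2):
    """M_XY[i][j] = |x_i - y_j|^p for points on the real line."""
    return [[abs(x - y) ** p for y in ys] for x in xs]


def in_polytope(P, a, b, tol=1e-9):
    """True if P >= 0, row sums equal a and column sums equal b (up to tol)."""
    n, m = len(a), len(b)
    rows_ok = all(abs(sum(P[i]) - a[i]) <= tol for i in range(n))
    cols_ok = all(abs(sum(P[i][j] for i in range(n)) - b[j]) <= tol
                  for j in range(m))
    nonneg = all(pij >= 0 for row in P for pij in row)
    return rows_ok and cols_ok and nonneg


def transport_cost(P, M):
    """Frobenius inner product <P, M>."""
    return sum(pij * mij for prow, mrow in zip(P, M)
               for pij, mij in zip(prow, mrow))


# Hypothetical toy measures.
a, xs = [0.5, 0.5], [0.0, 1.0]
b, ys = [0.25, 0.25, 0.5], [0.0, 1.0, 2.0]

M = cost_matrix(xs, ys)
P_indep = [[ai * bj for bj in b] for ai in a]   # independent coupling a b^T
upper_bound = transport_cost(P_indep, M)
print(in_polytope(P_indep, a, b), upper_bound)
```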

  15. Discrete OT Problem. (Figure: the polytope $U(a, b)$ and the cost $M_{XY}$.)

  16. Discrete OT Problem. (Figure: the polytope $U(a, b)$, the cost $M_{XY}$ and the optimal solution $P^\star$.)

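  17. Discrete OT Problem. (Figure: the polytope $U(a, b)$, the cost $M_{XY}$ and the optimal solution $P^\star$.) Def. Dual OT problem: $W_p^p(\mu, \nu) = \max \alpha^T a + \beta^T b$ over $\alpha \in \mathbb{R}^n$, $\beta \in \mathbb{R}^m$ such that $\alpha_i + \beta_j \le D(x_i, y_j)^p$.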
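
The dual gives lower bounds: any feasible pair $(\alpha, \beta)$ satisfies $\alpha^T a + \beta^T b \le \langle P, M_{XY} \rangle$ for every $P \in U(a, b)$. The sketch below checks this on a hypothetical toy instance, using the simple feasible choice $\alpha = 0$, $\beta_j = \min_i M_{ij}$ (an illustrative choice, generally not optimal) against the cost of the independent coupling $a b^T$.

```python
def dual_feasible(alpha, beta, M, tol=1e-12):
    """True if alpha_i + beta_j <= M_ij for all i, j (up to tol)."""
    return all(alpha[i] + beta[j] <= M[i][j] + tol
               for i in range(len(alpha)) for j in range(len(beta)))


def dual_value(alpha, beta, a, b):
    """Dual objective alpha^T a + beta^T b."""
    return (sum(al * ai for al, ai in zip(alpha, a))
            + sum(be * bj for be, bj in zip(beta, b)))


# Hypothetical toy instance: 1-D points, squared distance as cost.
a, xs = [0.5, 0.5], [0.0, 1.0]
b, ys = [0.25, 0.25, 0.5], [0.0, 1.0, 2.0]
M = [[(x - y) ** 2 for y in ys] for x in xs]

alpha = [0.0] * len(a)
beta = [min(M[i][j] for i in range(len(a))) for j in range(len(b))]
lower_bound = dual_value(alpha, beta, a, b)

# Primal cost of the (feasible, usually suboptimal) coupling a b^T.
upper_bound = sum(a[i] * b[j] * M[i][j]
                  for i in range(len(a)) for j in range(len(b)))
print(lower_bound, upper_bound)
```

The optimal value $W_p^p(\mu, \nu)$ lies between the two printed numbers.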

  18. Discrete OT Problem. Network flow solver used in practice, $O(n^3 \log(n))$. (Figure: the polytope $U(a, b)$, the cost $M_{XY}$ and the optimal solution $P^\star$.) Note: flow/PDE formulations [Beckman’61]/[Benamou’98] can be used for $p = 1$/$p = 2$ for a sparse-graph metric/Euclidean metric.

  19. Discrete OT Problem. Network flow solver used in practice, $O(n^3 \log(n))$. (Figure: the polytope $U(a, b)$, the cost $M_{XY}$ and the optimal solution $P^\star$.)

  20. Discrete OT Problem. Network flow solver used in practice, $O(n^3 \log(n))$. (Figure: the polytope $U(a, b)$, the cost $M_{XY}$ and the optimal solution $P^\star$.) Solution $P^\star$ unstable and not always unique.

  21. Discrete OT Problem. Network flow solver used in practice, $O(n^3 \log(n))$. (Figure: the polytope $U(a, b)$, the cost $M_{XY}$ and the optimal solution $P^\star$.) Solution $\{P^\star\}$ unstable and not always unique.

  23. Discrete OT Problem. Network flow solver used in practice, $O(n^3 \log(n))$. (Figure: the polytope $U(a, b)$, the cost $M_{XY}$ and the optimal solution $P^\star$.) Solution $P^\star$ unstable and not always unique.

  24. Discrete OT Problem. Network flow solver used in practice, $O(n^3 \log(n))$. (Figure: the polytope $U(a, b)$, the cost $M_{XY}$ and the optimal solution $P^\star$.) Solution $P^\star$ unstable and not always unique. $W_p^p(\mu, \nu)$ not differentiable.

  28. Solution: Modify OT Problem. (Figure: the polytope $U(a, b)$, the cost $M_{XY}$ and the optimal solution $P^\star$.) Wishlist: faster & scalable, more stable, differentiable.

  29. Entropic Regularization [Wilson’62]. Def. Regularized Wasserstein, $\gamma \ge 0$: $W_\gamma(\mu, \nu) \overset{\text{def}}{=} \min_{P \in U(a, b)} \langle P, M_{XY} \rangle - \gamma E(P)$, with $E(P) \overset{\text{def}}{=} -\sum_{i=1}^{n} \sum_{j=1}^{m} P_{ij} \log P_{ij}$. Note: unique optimal solution (for $\gamma > 0$) because of strong concavity of the entropy.

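  30. Entropic Regularization [Wilson’62]. Def. Regularized Wasserstein, $\gamma \ge 0$: $W_\gamma(\mu, \nu) \overset{\text{def}}{=} \min_{P \in U(a, b)} \langle P, M_{XY} \rangle - \gamma E(P)$. (Figure labels: $\mu$, $\nu$, $P_\gamma$, $\gamma$.) Note: unique optimal solution (for $\gamma > 0$) because of strong concavity of the entropy.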

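  31. Fast & Scalable Algorithm. Prop. If $P_\gamma \overset{\text{def}}{=} \arg\min_{P \in U(a, b)} \langle P, M_{XY} \rangle - \gamma E(P)$, then $\exists! \, u \in \mathbb{R}_+^n, v \in \mathbb{R}_+^m$ such that $P_\gamma = \mathrm{diag}(u) \, K \, \mathrm{diag}(v)$, $K \overset{\text{def}}{=} e^{-M_{XY}/\gamma}$.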

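  32. Fast & Scalable Algorithm. Prop. If $P_\gamma \overset{\text{def}}{=} \arg\min_{P \in U(a, b)} \langle P, M_{XY} \rangle - \gamma E(P)$, then $\exists! \, u \in \mathbb{R}_+^n, v \in \mathbb{R}_+^m$ such that $P_\gamma = \mathrm{diag}(u) \, K \, \mathrm{diag}(v)$, $K \overset{\text{def}}{=} e^{-M_{XY}/\gamma}$. Proof sketch: $L(P, \alpha, \beta) = \sum_{ij} \left( P_{ij} M_{ij} + \gamma P_{ij} \log P_{ij} \right) + \alpha^T (P \mathbf{1} - a) + \beta^T (P^T \mathbf{1} - b)$; $\partial L / \partial P_{ij} = M_{ij} + \gamma (\log P_{ij} + 1) + \alpha_i + \beta_j$; $(\partial L / \partial P_{ij} = 0) \Rightarrow P_{ij} = e^{-\frac{\alpha_i}{\gamma} - \frac{1}{2}} \, e^{-\frac{M_{ij}}{\gamma}} \, e^{-\frac{\beta_j}{\gamma} - \frac{1}{2}} = u_i K_{ij} v_j$.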

  33. Fast & Scalable Algorithm. Prop. If $P_\gamma \overset{\text{def}}{=} \arg\min_{P \in U(a, b)} \langle P, M_{XY} \rangle - \gamma E(P)$, then $\exists! \, u \in \mathbb{R}_+^n, v \in \mathbb{R}_+^m$ such that $P_\gamma = \mathrm{diag}(u) \, K \, \mathrm{diag}(v)$, $K \overset{\text{def}}{=} e^{-M_{XY}/\gamma}$. • [Sinkhorn’64] fixed-point iterations for $(u, v)$: $u \leftarrow a / K v$, $v \leftarrow b / K^T u$. • $O(nm)$ complexity, GPGPU parallel [C’13]. • $O(n^{d+1})$ if $\Omega = \{1, \dots, n\}^d$ and $D^p$ separable [S..C..’15].
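
The fixed-point iterations on this slide fit in a few lines. The following is a minimal dependency-free sketch, not the authors' implementation: the toy measures, the value of $\gamma$ and the fixed iteration count `n_iter` are assumptions chosen for illustration, and a practical version would use vectorized linear algebra, a convergence test, and log-domain updates when $\gamma$ is small.

```python
import math


def sinkhorn(a, b, M, gamma, n_iter=500):
    """Sinkhorn fixed-point iterations for entropy-regularized OT.

    Returns P = diag(u) K diag(v) with K = exp(-M / gamma), plus u and v.
    """
    n, m = len(a), len(b)
    K = [[math.exp(-M[i][j] / gamma) for j in range(m)] for i in range(n)]
    u = [1.0] * n
    v = [1.0] * m
    for _ in range(n_iter):
        # v <- b / (K^T u), elementwise division
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
        # u <- a / (K v), elementwise division
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
    P = [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]
    return P, u, v


# Hypothetical toy instance: 1-D points, squared distance as cost.
a, xs = [0.5, 0.5], [0.0, 1.0]
b, ys = [0.25, 0.25, 0.5], [0.0, 1.0, 2.0]
M = [[(x - y) ** 2 for y in ys] for x in xs]

P, u, v = sinkhorn(a, b, M, gamma=0.5)
cost = sum(P[i][j] * M[i][j] for i in range(len(a)) for j in range(len(b)))
print(cost)
```

Each iteration costs $O(nm)$ operations, matching the complexity quoted on the slide. Because the last update is on `u`, the row sums of `P` match `a` up to rounding, while the column sums match `b` only up to the remaining convergence error.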

  34. Very Fast EMD Approx. Solver. (Figure: average execution time per distance, in s., on a log scale from $10^{-6}$ to $10^{4}$, against histogram dimension from 64 to 4096. Legend: FastEMD, Rubner’s emd, CPU $\gamma = 0.02$, CPU $\gamma = 0.1$, GPU $\gamma = 0.02$, GPU $\gamma = 0.1$.) Note: $(\Omega, D)$ is a random graph with shortest path metric, histograms sampled uniformly on simplex, Sinkhorn tolerance $10^{-2}$.

  35. Regularization ⤑ Differentiability. $W_\gamma((a, X), (b, Y)) = \min_{P \in U(a, b)} \langle P, M_{XY} \rangle - \gamma E(P)$, with $\mu = \sum_{i=1}^{n} a_i \delta_{x_i}$ and $\nu = \sum_{j=1}^{m} b_j \delta_{y_j}$ on $(\Omega, D)$.

  36. Regularization ⤑ Differentiability. $W_\gamma((a + \Delta a, X), (b, Y)) = W_\gamma((a, X), (b, Y)) + ??$, with $\mu = \sum_{i=1}^{n} a_i \delta_{x_i}$ and $\nu = \sum_{j=1}^{m} b_j \delta_{y_j}$ on $(\Omega, D)$.

  37. Regularization ⤑ Differentiability. $W_\gamma((a + \Delta a, X), (b, Y)) = W_\gamma((a, X), (b, Y)) + ??$, with $\mu = \sum_{i=1}^{n} a_i \delta_{x_i}$ and $\nu = \sum_{j=1}^{m} b_j \delta_{y_j}$ on $(\Omega, D)$. Update of the weights: $a \leftarrow a + \Delta a$.

  38. Regularization ⤑ Differentiability. $W_\gamma((a, X + \Delta X), (b, Y)) = W_\gamma((a, X), (b, Y)) + ??$, with $\mu = \sum_{i=1}^{n} a_i \delta_{x_i}$ and $\nu = \sum_{j=1}^{m} b_j \delta_{y_j}$ on $(\Omega, D)$.

  39. Regularization ⤑ Differentiability. $W_\gamma((a, X + \Delta X), (b, Y)) = W_\gamma((a, X), (b, Y)) + ??$, with $\mu = \sum_{i=1}^{n} a_i \delta_{x_i}$ and $\nu = \sum_{j=1}^{m} b_j \delta_{y_j}$ on $(\Omega, D)$. Update of the locations: $X \leftarrow X + \Delta X$.

  40. 1. Differentiability of Regularized OT. Def. Dual regularized OT problem: $W_\gamma(\mu, \nu) = \max_{\alpha, \beta} \alpha^T a + \beta^T b - \gamma (e^{\alpha/\gamma})^T K e^{\beta/\gamma}$. Prop. [CD’14] $W_\gamma(\mu, \nu)$ is 1. convex w.r.t. $a$, with $\nabla_a W_\gamma = \alpha^\star = \gamma \log(u)$; 2. decreased, when $p = 2$, $\Omega = \mathbb{R}^d$, using $X \leftarrow Y P_\gamma^T D(a^{-1})$.
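
The location update in item 2 can be sketched as follows, reading it as $X \leftarrow Y P^T \mathrm{diag}(a^{-1})$, that is, each $x_i$ is replaced by the $P$-weighted average of the $y_j$. This reading, the storage of points as rows (so `Y[j]` is the point $y_j$), and the toy data are assumptions of the sketch; any coupling in $U(a, b)$ can be plugged in, for instance one returned by a Sinkhorn solver.

```python
def barycentric_update(P, a, Y):
    """x_i <- (1 / a_i) * sum_j P[i][j] * y_j, i.e. X <- Y P^T diag(1/a).

    P is an n x m coupling, a has length n, Y is a list of m points in R^d.
    Returns the n updated points as a list of lists.
    """
    n, m, d = len(a), len(Y), len(Y[0])
    return [[sum(P[i][j] * Y[j][k] for j in range(m)) / a[i]
             for k in range(d)]
            for i in range(n)]


# Hypothetical toy data: three target points in the plane.
a = [0.5, 0.5]
b = [0.25, 0.25, 0.5]
Y = [[0.0, 0.0], [2.0, 0.0], [0.0, 4.0]]

# With the independent coupling a b^T every x_i moves to the mean of nu.
P_indep = [[ai * bj for bj in b] for ai in a]
X_indep = barycentric_update(P_indep, a, Y)

# With a sparser coupling each x_i moves to the mean of the mass it sends.
P_sparse = [[0.25, 0.25, 0.0], [0.0, 0.0, 0.5]]
X_sparse = barycentric_update(P_sparse, a, Y)
print(X_indep, X_sparse)
```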
