Wasserstein Distances

Def. For $p \geq 1$, the $p$-Wasserstein distance between $\mu, \nu$ in $\mathcal{P}(\Omega)$, defined by a metric $D$ on $\Omega$:

PRIMAL
$$W_p^p(\mu, \nu) \overset{\text{def}}{=} \inf_{P \in \Pi(\mu, \nu)} \iint D(x, y)^p \, P(dx, dy).$$

DUAL
$$W_p^p(\mu, \nu) = \sup_{\substack{\varphi \in L_1(\mu),\ \psi \in L_1(\nu) \\ \varphi(x) + \psi(y) \leq D^p(x, y)}} \int \varphi \, d\mu + \int \psi \, d\nu.$$
W is versatile

• Discrete - Discrete: network flow solvers, entropic regularization.
• Discrete - Continuous: low dim. [M’11][KMB’16][L’15]
• Continuous - Continuous: stochastic optimization [GCPB’16]
Minimum Kantorovich Estimators

$$\min_{\theta \in \Theta} W(\nu_{\text{data}}, f_{\theta\sharp}\mu)$$

• [Bassetti’06] 1st reference discussing this approach.
• [MMC’16] use regularization in a finite setting.
• [ACB’17] (WGAN), [BJGR’17] (Wasserstein ABC).
• Hot topics: approximate & differentiate W efficiently.
• Today: ideas from our recent preprint [GPC’17].
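Wasserstein between 2 Diracs

For two Dirac masses $\delta_x$, $\delta_y$ on $(\Omega, D)$:
$$W_p^p(\delta_x, \delta_y) = D(x, y)^p, \quad \text{i.e.} \quad W_p(\delta_x, \delta_y) = D(x, y).$$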
Wasserstein on Uniform Measures

On $(\Omega, D)$, consider
$$\mu = \frac{1}{n}\sum_{i=1}^{n} \delta_{x_i}, \qquad \nu = \frac{1}{n}\sum_{j=1}^{n} \delta_{y_j}.$$
The cost of a permutation $\sigma$ is
$$C(\sigma) = \frac{1}{n}\sum_{i=1}^{n} D(x_i, y_{\sigma_i})^p.$$
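Optimal Assignment ⊂ Wasserstein

For the uniform measures $\mu = \frac{1}{n}\sum_{i=1}^{n} \delta_{x_i}$ and $\nu = \frac{1}{n}\sum_{j=1}^{n} \delta_{y_j}$ on $(\Omega, D)$, with $C(\sigma) = \frac{1}{n}\sum_{i=1}^{n} D(x_i, y_{\sigma_i})^p$:
$$W_p^p(\mu, \nu) = \min_{\sigma \in S_n} C(\sigma).$$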
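The assignment formulation above can be checked by brute force on a tiny example. This is only a sketch: it assumes points on the real line with $D(x, y) = |x - y|$, and the function names `assignment_cost` and `wasserstein_pp_uniform` are hypothetical. Enumerating all $n!$ permutations is only feasible for very small $n$.

```python
from itertools import permutations

def assignment_cost(x, y, sigma, p=2):
    """C(sigma) = (1/n) * sum_i D(x_i, y_sigma(i))^p, assuming D(x, y) = |x - y|."""
    n = len(x)
    return sum(abs(x[i] - y[sigma[i]]) ** p for i in range(n)) / n

def wasserstein_pp_uniform(x, y, p=2):
    """Brute-force W_p^p between two uniform measures on n points each (O(n!))."""
    n = len(x)
    best = min(permutations(range(n)),
               key=lambda s: assignment_cost(x, y, s, p))
    return assignment_cost(x, y, best, p), best

# Illustrative data (not from the slides).
x = [0.0, 1.0, 3.0]
y = [3.5, 0.5, 1.0]
cost, sigma = wasserstein_pp_uniform(x, y, p=2)
# sigma == (1, 2, 0): x_0 -> 0.5, x_1 -> 1.0, x_2 -> 3.5, and cost == 0.5 / 3
print(sigma, cost)
```

On the real line with a convex cost the best permutation matches sorted points to sorted points, which is what the search recovers here.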
OT on Two Empirical Measures

On $(\Omega, D)$:
$$\mu = \sum_{i=1}^{n} a_i \delta_{x_i}, \qquad \nu = \sum_{j=1}^{m} b_j \delta_{y_j}.$$
Wasserstein on Empirical Measures

Consider $\mu = \sum_{i=1}^{n} a_i \delta_{x_i}$ and $\nu = \sum_{j=1}^{m} b_j \delta_{y_j}$.
$$M_{XY} \overset{\text{def}}{=} [D(x_i, y_j)^p]_{ij}$$
$$U(a, b) \overset{\text{def}}{=} \{ P \in \mathbb{R}_+^{n \times m} \mid P \mathbf{1}_m = a,\ P^T \mathbf{1}_n = b \}$$

The rows of $P$ are indexed by $x_1, \dots, x_n$ and sum to $a_1, \dots, a_n$ ($P \mathbf{1}_m = a$); the columns are indexed by $y_1, \dots, y_m$ and sum to $b_1, \dots, b_m$ ($P^T \mathbf{1}_n = b$). Entry $(i, j)$ of $M_{XY}$ is $D(x_i, y_j)^p$.

Def. Optimal Transport Problem
$$W_p^p(\mu, \nu) = \min_{P \in U(a, b)} \langle P, M_{XY} \rangle$$
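To make the objects $M_{XY}$, $U(a, b)$ and $\langle P, M_{XY} \rangle$ concrete, here is a minimal sketch in plain Python. It assumes points on the real line with $D(x, y) = |x - y|$; the helper names and the data are hypothetical. It compares the independent coupling $a b^T$ with a cheaper feasible coupling.

```python
def cost_matrix(x, y, p=1):
    """M_XY = [D(x_i, y_j)^p]_ij, assuming D(x, y) = |x - y| on the real line."""
    return [[abs(xi - yj) ** p for yj in y] for xi in x]

def in_transport_polytope(P, a, b, tol=1e-9):
    """Check P >= 0, P 1_m = a and P^T 1_n = b, up to a tolerance."""
    n, m = len(a), len(b)
    nonneg = all(P[i][j] >= -tol for i in range(n) for j in range(m))
    rows_ok = all(abs(sum(P[i]) - a[i]) <= tol for i in range(n))
    cols_ok = all(abs(sum(P[i][j] for i in range(n)) - b[j]) <= tol
                  for j in range(m))
    return nonneg and rows_ok and cols_ok

def transport_cost(P, M):
    """Frobenius inner product <P, M>."""
    return sum(pij * mij for prow, mrow in zip(P, M)
               for pij, mij in zip(prow, mrow))

# Illustrative data (not from the slides).
x, a = [0.0, 1.0], [0.5, 0.5]
y, b = [0.0, 1.0, 2.0], [0.25, 0.5, 0.25]
M = cost_matrix(x, y, p=1)                      # [[0, 1, 2], [1, 0, 1]]

P_indep = [[ai * bj for bj in b] for ai in a]   # independent coupling a b^T
P_mono = [[0.25, 0.25, 0.0],
          [0.0, 0.25, 0.25]]                    # moves mass to nearby points

print(in_transport_polytope(P_indep, a, b), transport_cost(P_indep, M))  # True 0.75
print(in_transport_polytope(P_mono, a, b), transport_cost(P_mono, M))    # True 0.5
```

Both couplings lie in $U(a, b)$, so each value of $\langle P, M_{XY} \rangle$ is an upper bound on $W_1(\mu, \nu)$ for this toy problem.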
Discrete OT Problem

The optimal coupling $P^\star$ minimizes the linear cost $\langle P, M_{XY} \rangle$ over the polytope $U(a, b)$.

Def. Dual OT problem
$$W_p^p(\mu, \nu) = \max_{\substack{\alpha \in \mathbb{R}^n,\ \beta \in \mathbb{R}^m \\ \alpha_i + \beta_j \leq D(x_i, y_j)^p}} \alpha^T a + \beta^T b$$

Network flow solver used in practice: $O(n^3 \log(n))$.

Note: flow/PDE formulations [Beckman’61]/[Benamou’98] can be used for p=1/p=2 for a sparse-graph metric/Euclidean metric.
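The dual problem gives a certificate of optimality: any feasible pair $(\alpha, \beta)$ yields a lower bound $\alpha^T a + \beta^T b$ on the primal cost of any coupling in $U(a, b)$, and equality proves that both are optimal. The sketch below checks this on a toy instance; the cost matrix, weights, coupling and dual variables are illustrative values chosen by hand, and the helper names are hypothetical.

```python
def dual_feasible(alpha, beta, M, tol=1e-12):
    """Check alpha_i + beta_j <= M_ij for all i, j."""
    return all(alpha[i] + beta[j] <= M[i][j] + tol
               for i in range(len(alpha)) for j in range(len(beta)))

def dual_value(alpha, beta, a, b):
    """alpha^T a + beta^T b."""
    return (sum(ai * al for ai, al in zip(a, alpha))
            + sum(bj * be for bj, be in zip(b, beta)))

def primal_value(P, M):
    """<P, M>."""
    return sum(P[i][j] * M[i][j]
               for i in range(len(M)) for j in range(len(M[0])))

# Toy instance: x = (0, 1), y = (0, 1, 2), D(x, y) = |x - y|, p = 1.
M = [[0.0, 1.0, 2.0],
     [1.0, 0.0, 1.0]]
a = [0.5, 0.5]
b = [0.25, 0.5, 0.25]

P = [[0.25, 0.25, 0.0],
     [0.0, 0.25, 0.25]]          # a feasible coupling in U(a, b)
alpha = [0.0, -1.0]              # hand-picked dual variables
beta = [0.0, 1.0, 2.0]

print(dual_feasible(alpha, beta, M))     # True
print(dual_value(alpha, beta, a, b))     # 0.5
print(primal_value(P, M))                # 0.5, so P and (alpha, beta) are optimal
```

The dual variables were obtained by making the constraint $\alpha_i + \beta_j \leq M_{ij}$ tight on the support of $P$ (complementary slackness).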
Discrete OT Problem

With a network flow solver, $O(n^3 \log(n))$ in practice:
• The solution $P^\star$ is unstable and not always unique: the set of minimizers $\{P^\star\}$ can be a whole face of $U(a, b)$.
• $W_p^p(\mu, \nu)$ is not differentiable.
Solution: Modify OT Problem

Wishlist: faster & scalable, more stable, differentiable.
Entropic Regularization [Wilson’62]

Def. Regularized Wasserstein, $\gamma \geq 0$
$$W_\gamma(\mu, \nu) \overset{\text{def}}{=} \min_{P \in U(a, b)} \langle P, M_{XY} \rangle - \gamma E(P)$$
$$E(P) \overset{\text{def}}{=} -\sum_{i=1}^{n} \sum_{j=1}^{m} P_{ij} \log P_{ij}$$

Note: for $\gamma > 0$ the optimal solution $P_\gamma$ is unique because of the strong concavity of the entropy.
Fast & Scalable Algorithm

Prop. If $P_\gamma \overset{\text{def}}{=} \operatorname{argmin}_{P \in U(a, b)} \langle P, M_{XY} \rangle - \gamma E(P)$, then $\exists!\, u \in \mathbb{R}^n_+, v \in \mathbb{R}^m_+$ (unique up to the rescaling $u \to c\,u$, $v \to v / c$, $c > 0$) such that
$$P_\gamma = \operatorname{diag}(u)\, K \operatorname{diag}(v), \qquad K \overset{\text{def}}{=} e^{-M_{XY} / \gamma}.$$

Proof sketch, via the Lagrangian:
$$L(P, \alpha, \beta) = \sum_{ij} \left( P_{ij} M_{ij} + \gamma P_{ij} \log P_{ij} \right) + \alpha^T (P \mathbf{1} - a) + \beta^T (P^T \mathbf{1} - b)$$
$$\partial L / \partial P_{ij} = M_{ij} + \gamma (\log P_{ij} + 1) + \alpha_i + \beta_j$$
$$(\partial L / \partial P_{ij} = 0) \Rightarrow P_{ij} = e^{-\frac{\alpha_i}{\gamma} - \frac{1}{2}} \, e^{-\frac{M_{ij}}{\gamma}} \, e^{-\frac{\beta_j}{\gamma} - \frac{1}{2}} = u_i K_{ij} v_j$$

• [Sinkhorn’64] fixed-point iterations for $(u, v)$: $u \leftarrow a / K v$, $v \leftarrow b / K^T u$ (elementwise division).
• $O(nm)$ complexity per iteration, GPGPU parallel [C’13].
• $O(n^{d+1})$ if $\Omega = \{1, \dots, n\}^d$ and $D^p$ separable [S..C..’15].
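The fixed-point iterations above fit in a few lines. The following is a minimal pure-Python sketch, not the GPU implementation referred to on the slide: the function name `sinkhorn`, the fixed iteration count (instead of a tolerance-based stopping rule) and the toy data are assumptions made for illustration.

```python
import math

def sinkhorn(a, b, M, gamma, n_iter=500):
    """Alternate u <- a / (K v) and v <- b / (K^T u), with K = exp(-M / gamma).

    Returns the coupling P = diag(u) K diag(v) and the scalings u, v.
    """
    n, m = len(a), len(b)
    K = [[math.exp(-M[i][j] / gamma) for j in range(m)] for i in range(n)]
    u = [1.0] * n
    v = [1.0] * m
    for _ in range(n_iter):
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
    P = [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]
    return P, u, v

# Toy instance (illustrative values): x = (0, 1), y = (0, 1, 2), cost |x - y|.
M = [[0.0, 1.0, 2.0],
     [1.0, 0.0, 1.0]]
a = [0.5, 0.5]
b = [0.25, 0.5, 0.25]

P, u, v = sinkhorn(a, b, M, gamma=0.5)
row_sums = [sum(row) for row in P]
col_sums = [sum(P[i][j] for i in range(len(a))) for j in range(len(b))]
cost = sum(P[i][j] * M[i][j] for i in range(len(a)) for j in range(len(b)))
print(row_sums, col_sums, cost)
```

Each iteration only needs two matrix-vector products with $K$, which is where the $O(nm)$ cost per iteration comes from. For very small $\gamma$ the entries of $K$ underflow and a log-domain variant is usually preferred; that variant is not shown here.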
Very Fast EMD Approx. Solver

[Figure: average execution time per distance (in s., log scale from $10^{-6}$ to $10^{4}$) against histogram dimension (64 to 4096), for FastEMD, Rubner’s emd, and the regularized solver on CPU and GPU with $\gamma = 0.02$ and $\gamma = 0.1$.]

Note. $(\Omega, D)$ is a random graph with shortest path metric, histograms sampled uniformly on simplex, Sinkhorn tolerance $10^{-2}$.
Regularization ⤑ Differentiability

For $\mu = \sum_{i=1}^{n} a_i \delta_{x_i}$ and $\nu = \sum_{j=1}^{m} b_j \delta_{y_j}$ on $(\Omega, D)$:
$$W_\gamma((a, X), (b, Y)) = \min_{P \in U(a, b)} \langle P, M_{XY} \rangle - \gamma E(P)$$

How does $W_\gamma$ change when the weights move, $a \leftarrow a + \Delta a$?
$$W_\gamma((a + \Delta a, X), (b, Y)) = W_\gamma((a, X), (b, Y)) + \,??$$

How does it change when the locations move, $X \leftarrow X + \Delta X$?
$$W_\gamma((a, X + \Delta X), (b, Y)) = W_\gamma((a, X), (b, Y)) + \,??$$
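1. Differentiability of Regularized OT

Def. Dual regularized OT Problem
$$W_\gamma(\mu, \nu) = \max_{\alpha, \beta} \ \alpha^T a + \beta^T b - \gamma \, (e^{\alpha / \gamma})^T K e^{\beta / \gamma}$$
(written for the entropy convention $E(P) = -\sum_{ij} P_{ij} (\log P_{ij} - 1)$; with $E(P) = -\sum_{ij} P_{ij} \log P_{ij}$ the last term carries an extra factor $1/e$.)

Prop. [CD’14] $W_\gamma(\mu, \nu)$ is
1. convex w.r.t. $a$, with $\nabla_a W_\gamma = \alpha^\star = \gamma \log(u)$ (up to an additive constant);
2. decreased, when $p = 2$, $\Omega = \mathbb{R}^d$, using the update $X \leftarrow Y P_\gamma^T \mathbf{D}(a^{-1})$, where $\mathbf{D}(a^{-1})$ is the diagonal matrix with entries $1 / a_i$.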
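The gradient formula $\nabla_a W_\gamma = \gamma \log(u)$ can be sanity-checked numerically. Because $a$ must keep total mass 1, the sketch below moves a small mass $\varepsilon$ from $a_1$ to $a_0$ and compares the central finite difference of $W_\gamma$ with $\gamma (\log u_0 - \log u_1)$, a difference in which the additive constant cancels. The function names, the toy data and the fixed Sinkhorn iteration count are assumptions made for illustration.

```python
import math

def sinkhorn(a, b, M, gamma, n_iter=2000):
    """Sinkhorn scalings: P = diag(u) K diag(v) with K = exp(-M / gamma)."""
    n, m = len(a), len(b)
    K = [[math.exp(-M[i][j] / gamma) for j in range(m)] for i in range(n)]
    u, v = [1.0] * n, [1.0] * m
    for _ in range(n_iter):
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
    P = [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]
    return P, u, v

def regularized_ot(a, b, M, gamma):
    """W_gamma = <P, M> - gamma * E(P) at the Sinkhorn solution,
    with E(P) = -sum_ij P_ij log P_ij. Also returns the scaling u."""
    P, u, _ = sinkhorn(a, b, M, gamma)
    value = sum(P[i][j] * (M[i][j] + gamma * math.log(P[i][j]))
                for i in range(len(a)) for j in range(len(b)))
    return value, u

# Toy instance (illustrative values).
M = [[0.0, 1.0, 2.0],
     [1.0, 0.0, 1.0]]
a = [0.5, 0.5]
b = [0.25, 0.5, 0.25]
gamma = 0.5

# Directional derivative predicted by the dual variable alpha = gamma * log(u).
_, u = regularized_ot(a, b, M, gamma)
grad_diff = gamma * (math.log(u[0]) - math.log(u[1]))

# Central finite difference along e_0 - e_1, which keeps a on the simplex.
eps = 1e-5
w_plus, _ = regularized_ot([a[0] + eps, a[1] - eps], b, M, gamma)
w_minus, _ = regularized_ot([a[0] - eps, a[1] + eps], b, M, gamma)
fd_diff = (w_plus - w_minus) / (2 * eps)

print(grad_diff, fd_diff)  # the two numbers should agree closely
```

Only differences $\alpha_i - \alpha_k$ are tested because $\alpha^\star$ is defined up to an additive constant, so the rescaling ambiguity of $u$ does not matter.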