OSL 2015
The Wasserstein Barycenter Problem
Marco Cuturi (mcuturi@i.kyoto-u.ac.jp)
Joint work with G. Peyré, G. Carlier, J.-D. Benamou, L. Nenna, A. Gramfort, J. Solomon, ...
13.1.15
Motivation
[Figure: 4 points in R^2: x_1, x_2, x_3, x_4]
Mean
[Figure] Their mean is (x_1 + x_2 + x_3 + x_4)/4.
Computing Means
Consider for each point the function ||· − x_i||_2^2.
Computing Means
The mean is the argmin of (1/4) Σ_{i=1}^4 ||· − x_i||_2^2.
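As a quick numerical check of this argmin characterization (the slide's four points are not given numerically, so the coordinates below are illustrative stand-ins):

```python
import numpy as np

# Illustrative stand-ins for the slide's x_1, ..., x_4 (not given numerically).
X = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])

def objective(z):
    # (1/4) * sum_i ||z - x_i||_2^2
    return np.mean(np.sum((X - z) ** 2, axis=1))

mean = X.mean(axis=0)  # the closed-form argmin

# Any perturbation of the mean can only increase the objective.
rng = np.random.default_rng(0)
assert all(objective(mean + 0.1 * rng.standard_normal(2)) >= objective(mean)
           for _ in range(100))
```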
Means in Metric Spaces
[Figure: points on a curved surface] Means can be defined using any distance/divergence/discrepancy.
Means in Metric Spaces
[Figure] Using e.g. geodesic distances. Here ∆(•, •) = 0.994.
Means in Metric Spaces
Consider the distance functions ∆(·, x_i), i = 1, 2, 3, 4.
Means in Metric Spaces
The mean is the argmin of (1/N) Σ_{i=1}^N ∆(·, x_i).
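With a general discrepancy ∆ there is usually no closed form; a minimal brute-force sketch is to evaluate the averaged cost over a grid of candidate locations. The points are illustrative, and ∆ is taken as squared Euclidean here only so the result is checkable; any geodesic distance could be swapped in.

```python
import numpy as np

pts = np.array([[0.2, 0.1], [0.9, 0.2], [0.1, 0.8], [0.8, 0.9]])

def delta(z, x):
    # Placeholder discrepancy: squared Euclidean. Replace with e.g. a geodesic
    # distance on a surface to get a genuinely metric-dependent mean.
    return np.sum((z - x) ** 2)

# Evaluate (1/N) sum_i delta(z, x_i) over a grid of candidate means z.
ticks = np.linspace(0, 1, 101)
grid = np.stack(np.meshgrid(ticks, ticks), axis=-1).reshape(-1, 2)
costs = np.array([np.mean([delta(z, x) for x in pts]) for z in grid])
frechet_mean = grid[costs.argmin()]
```

With the squared-Euclidean placeholder this recovers the arithmetic mean of the points, as expected.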
From points
[Figure: 4 points in R^2]
to Probability Measures
[Figure] Assume that each datum is now an empirical measure. What could be the mean of these 4 measures?
1. Naive Averaging
[Figure: × marks the naive mean of all observations] Mean of 4 measures = a point?
Averaging Probability Measures
[Figure] Same measures, in a 3D perspective.
2. Naive Averaging
[Figure] The Euclidean mean of the measures is their sum / N. Here, ∆(µ, ν) = ∫_{R^2} [dµ − dν]^2.
Focus on uncertainty
[Figure] ...but geometric knowledge is ignored.
Focus on geometry
[Figure] ...but uncertainty is lost.
Problem of interest
Given a discrepancy function ∆ between probabilities, compute their mean:
argmin Σ_i ∆(·, ν_i)
• The idea is useful, sometimes tractable, and appears in:
◦ Bregman clustering for histograms [Banerjee'05]
◦ Topic modeling [Blei et al.'03]
◦ Clustering problems (k-means)
• Our goal in this talk: study the case ∆ = Wasserstein.
Wasserstein Distances
Comparing Two Measures
[Figure] Two measures µ, ν ∈ P(Ω).
The Optimal Transport Approach
[Figure: points x, y in (Ω, D), cost D(x, y)]
Optimal transport distances rely on 2 key concepts:
• A metric D : Ω × Ω → R_+;
• Π(µ, ν): joint probabilities with marginals µ, ν.
Joint Probabilities of (µ, ν)
[Figure] Consider µ, ν two measures on the real line.
Joint Probabilities of (µ, ν)
[Figure: a joint density P(x, y) with marginals µ(x) and ν(y)]
Π(µ, ν) = probability measures on Ω^2 with marginals µ and ν.
Optimal Transport Distance
[Figure] The p-Wasserstein (or OT) distance, assuming p ≥ 1, is:
W_p(µ, ν) = inf_{P ∈ Π(µ,ν)} E_P[D(X, Y)^p]^{1/p}.
(Historical Parenthesis)
Monge–Kantorovich, Kantorovich–Rubinstein, Wasserstein, Earth Mover's Distance, Mallows
• Monge 1781: Mémoire sur la théorie des déblais et des remblais
• Optimization & Operations Research
◦ Kantorovich'42, Dantzig'47, Ford & Fulkerson'55, etc.
• Probability & Statistical Physics
◦ Rachev'92, Talagrand'96, Villani'09
• Computer Vision: Rubner et al.'98
OT Distance for Empirical Measures
µ = Σ_{i=1}^n a_i δ_{x_i}, ν = Σ_{j=1}^m b_j δ_{y_j} on (Ω, D).
W_p(µ, ν) = inf_{P ∈ Π(µ,ν)} E_P[D(X, Y)^p]^{1/p}. Algorithmically?
OT Distance for Empirical Measures
Suppose n = m and all weights are uniform: µ = (1/n) Σ_{i=1}^n δ_{x_i}, ν = (1/m) Σ_{j=1}^m δ_{y_j}.
OT Distance for Empirical Measures
Then the distance reduces to an optimal matching cost (solved for instance with the Hungarian algorithm):
W_p(µ, ν) = min_{σ ∈ S_n} ( (1/n) Σ_{i=1}^n D(x_i, y_{σ(i)})^p )^{1/p}.
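A sketch of this matching computation using SciPy's Hungarian-type assignment solver; the point clouds below are illustrative.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist

rng = np.random.default_rng(0)
n, p = 5, 2
X = rng.random((n, 2))                 # support of mu (uniform weights 1/n)
Y = rng.random((n, 2))                 # support of nu (uniform weights 1/n)

C = cdist(X, Y) ** p                   # C_ij = D(x_i, y_j)^p, D = Euclidean
row, col = linear_sum_assignment(C)    # optimal permutation sigma
W_p = (C[row, col].mean()) ** (1 / p)  # ((1/n) sum_i D(x_i, y_sigma(i))^p)^(1/p)
```

The optimal matching can never cost more than any fixed permutation, e.g. the identity.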
OT Distance for Empirical Measures
As soon as n ≠ m or the weights are non-uniform, optimal matching no longer makes sense.
Computing the OT Distance
µ = Σ_{i=1}^n a_i δ_{x_i}, ν = Σ_{j=1}^m b_j δ_{y_j}.
W_p^p(µ, ν) can be cast as a linear program in R^{n×m}:
1. M_XY := [D(x_i, y_j)^p]_{ij} ∈ R^{n×m} (metric information)
2. Transportation polytope (joint probabilities): U(a, b) = {P ∈ R_+^{n×m} | P 1_m = a, P^T 1_n = b}
Computing p-Wasserstein Distances
W_p^p(µ, ν) = primal(a, b, M_XY) := min_{T ∈ U(a,b)} ⟨T, M_XY⟩ = ⟨T⋆, M_XY⟩.
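A minimal sketch of this primal LP with SciPy's generic `linprog` (illustrative data; a dedicated network-flow solver would be faster in practice):

```python
import numpy as np
from scipy.optimize import linprog

rng = np.random.default_rng(1)
n, m = 3, 4
x, y = rng.random((n, 2)), rng.random((m, 2))
a, b = np.full(n, 1 / n), np.full(m, 1 / m)
M = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)   # M_XY for D = Euclidean, p = 2

# Equality constraints encoding U(a, b): row sums equal a, column sums equal b,
# with T vectorized row-major as a length-(n*m) nonnegative variable.
A_eq = np.zeros((n + m, n * m))
for i in range(n):
    A_eq[i, i * m:(i + 1) * m] = 1
for j in range(m):
    A_eq[n + j, j::m] = 1
res = linprog(M.ravel(), A_eq=A_eq, b_eq=np.r_[a, b], bounds=(0, None))
T_star = res.x.reshape(n, m)
W2_squared = res.fun                                  # = <T*, M_XY>
```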
[Kantorovich'42] Duality
• The primal problem primal(a, b, M_XY) := min_{T ∈ U(a,b)} ⟨T, M_XY⟩ has an equivalent dual LP:
W_p^p(µ, ν) = dual(a, b, M_XY) := max_{(α,β) ∈ C_{M_XY}} α^T a + β^T b,
where C_M = {(α, β) ∈ R^{n+m} | α_i + β_j ≤ M_ij}.
[Kantorovich'42] Duality
Both the primal and the dual require O(n^3 log(n)) operations, typically solved using the network simplex.
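Strong LP duality can be checked numerically on tiny illustrative data, solving both programs with SciPy's `linprog` as a stand-in for a network-simplex solver:

```python
import numpy as np
from scipy.optimize import linprog

# Illustrative data: a random cost matrix and uniform marginals.
rng = np.random.default_rng(2)
n, m = 3, 3
M = rng.random((n, m))
a, b = np.full(n, 1 / n), np.full(m, 1 / m)

# Primal: min <T, M> over U(a, b), T vectorized row-major.
A_eq = np.zeros((n + m, n * m))
for i in range(n):
    A_eq[i, i * m:(i + 1) * m] = 1
for j in range(m):
    A_eq[n + j, j::m] = 1
primal = linprog(M.ravel(), A_eq=A_eq, b_eq=np.r_[a, b], bounds=(0, None)).fun

# Dual: max a.alpha + b.beta  s.t.  alpha_i + beta_j <= M_ij, (alpha, beta) free.
A_ub = np.zeros((n * m, n + m))
for i in range(n):
    for j in range(m):
        A_ub[i * m + j, i] = 1
        A_ub[i * m + j, n + j] = 1
dual = -linprog(-np.r_[a, b], A_ub=A_ub, b_ub=M.ravel(), bounds=(None, None)).fun
```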
Wasserstein Barycenter Problem (WBP)
• [Agueh & Carlier'11] introduced the WBP:
argmin_{µ ∈ P(Ω)} C(µ) := Σ_{i=1}^N W_p^p(µ, ν_i)
• Can be solved with a multi-marginal OT problem.
• Intractable: an LP with Π_i card(supp(ν_i)) variables.
Differentiability w.r.t. X or a
µ = Σ_{i=1}^n a_i δ_{x_i}, ν = Σ_{j=1}^m b_j δ_{y_j}.
To solve it numerically, we must understand how f_ν(a, X) := W_p^p(µ, ν) varies when a and X vary.
Differentiability w.r.t. X or a
1. Infinitesimal variation in weights: what is f_ν(a′, X) if a′ ≈ a?
Differentiability w.r.t. X or a
2. Infinitesimal variation in locations: what is f_ν(a, X′) if X′ ≈ X?
Using the dual, ∂|_a
f_ν(a, X) = max_{(α,β) ∈ C_{M_XY}} α^T a + β^T b
[Figure: a pointwise maximum of linear functions]
Using the dual, ∂|_a
f_ν(a, X) = max_{(α,β) ∈ C_{M_XY}} α^T a + β^T b
[Figure] a ↦ f_ν(a, X) is a convex, non-smooth map. The dual optimum α⋆ is a subgradient of f_ν(·, X) at a.
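This subgradient claim can be sanity-checked numerically: since f_ν(·, X) is a maximum of linear functions, the maximizer α⋆ at a must satisfy f(a′) ≥ f(a) + ⟨α⋆, a′ − a⟩ for every other weight vector a′. A sketch on an illustrative cost matrix, obtaining α⋆ by solving the dual LP with SciPy's `linprog`:

```python
import numpy as np
from scipy.optimize import linprog

rng = np.random.default_rng(3)
n, m = 4, 4
M = rng.random((n, m))                 # illustrative stand-in for M_XY (X fixed)
b = np.full(m, 1 / m)

def f_and_alpha(a):
    """f_nu(a, X) and a dual optimum alpha*, via the dual LP."""
    A_ub = np.zeros((n * m, n + m))
    for i in range(n):
        for j in range(m):
            A_ub[i * m + j, i] = 1
            A_ub[i * m + j, n + j] = 1
    res = linprog(-np.r_[a, b], A_ub=A_ub, b_ub=M.ravel(), bounds=(None, None))
    return -res.fun, res.x[:n]

a = np.full(n, 1 / n)
f_a, alpha = f_and_alpha(a)
ok = True
for _ in range(10):                    # check the subgradient inequality
    a2 = rng.dirichlet(np.ones(n))     # another weight vector on the simplex
    f_a2, _ = f_and_alpha(a2)
    ok = ok and (f_a2 >= f_a + alpha @ (a2 - a) - 1e-9)
```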
Using the primal, ∂|_X
f_ν(a, X) = min_{T ∈ U(a,b)} ⟨T, M_XY⟩
• More involved computations; tractable when D is Euclidean and p = 2.
• The map is a convex quadratic plus a piecewise-linear concave function of X.
• ∂f_ν|_X = Y T⋆^T diag(a^{-1}): the optimal transport plan T⋆ yields a subgradient.
To sum up: (1) the WBP is challenging
C(a, X) := (1/N) Σ_{i=1}^N W_p^p(µ, ν_i) = (1/N) Σ_{i=1}^N f_{ν_i}(a, X)
• a ↦ C(a, X) is convex, non-smooth; computing one subgradient requires solving N OT problems!
• X ↦ C(a, X) is non-convex, non-smooth.
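To make the weight subproblem concrete, here is a small (and deliberately slow) projected-subgradient sketch over a, on a fixed support X: each iteration solves the N dual OT LPs and averages the resulting α⋆ vectors. Everything below (data sizes, step size, iteration count) is an illustrative assumption, not the talk's actual algorithm.

```python
import numpy as np
from scipy.optimize import linprog

rng = np.random.default_rng(4)
n, m, N = 6, 5, 3
X = rng.random((n, 2))                                  # fixed barycenter support
nus = [(np.full(m, 1 / m), rng.random((m, 2))) for _ in range(N)]
Ms = [((X[:, None, :] - Y[None, :, :]) ** 2).sum(-1) for _, Y in nus]

def dual_f_alpha(a, b, M):
    """Value of the OT LP and a dual optimum alpha (a subgradient in a)."""
    n_, m_ = M.shape
    A_ub = np.zeros((n_ * m_, n_ + m_))
    for i in range(n_):
        for j in range(m_):
            A_ub[i * m_ + j, i] = 1
            A_ub[i * m_ + j, n_ + j] = 1
    res = linprog(-np.r_[a, b], A_ub=A_ub, b_ub=M.ravel(), bounds=(None, None))
    return -res.fun, res.x[:n_]

def project_simplex(v):
    """Euclidean projection onto the probability simplex."""
    u = np.sort(v)[::-1]
    k = np.nonzero(u + (1 - np.cumsum(u)) / np.arange(1, len(v) + 1) > 0)[0][-1]
    return np.maximum(v + (1 - u[:k + 1].sum()) / (k + 1), 0)

a = np.full(n, 1 / n)
for t in range(30):
    # Subgradient of a -> (1/N) sum_i f_{nu_i}(a, X): average of the N alphas.
    grads = [dual_f_alpha(a, b, M)[1] for (b, _), M in zip(nus, Ms)]
    a = project_simplex(a - 0.1 / np.sqrt(t + 1) * np.mean(grads, axis=0))
cost = np.mean([dual_f_alpha(a, b, M)[0] for (b, _), M in zip(nus, Ms)])
```

Note how each subgradient step is N LP solves, illustrating the slide's point that even one subgradient of C is expensive.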