Scalable Multi-Class Gaussian Process Classification using Expectation Propagation

Carlos Villacampa-Calvo and Daniel Hernández-Lobato
Computer Science Department, Universidad Autónoma de Madrid
http://dhnzl.org, daniel.hernandez@uam.es
Introduction to Multi-class Classification with GPs

Given x_i we want to make predictions about y_i ∈ {1, ..., C}, with C > 2.

One can assume that (Kim & Ghahramani, 2006):

    y_i = arg max_k f_k(x_i),   for k ∈ {1, ..., C}.

[Figure: toy example with C = 3 latent functions f(x) on x ∈ [-4, 4] and the class labels they induce.]

Find p(f | y) = p(y | f) p(f) / p(y) under GP priors p(f_k) ~ GP(0, k(·, ·)).
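The labelling rule above can be illustrated with a short, self-contained Python sketch (not from the paper): draw C = 3 latent functions from a zero-mean GP prior with an RBF kernel on a 1-D grid and assign each input the arg max label, mimicking the figure.

import numpy as np

# Toy illustration of y(x) = arg max_k f_k(x) with C = 3 GP prior samples.
def rbf_kernel(x, lengthscale=1.0):
    d = x[:, None] - x[None, :]
    return np.exp(-0.5 * (d / lengthscale) ** 2)

rng = np.random.default_rng(0)
x = np.linspace(-4.0, 4.0, 200)
K = rbf_kernel(x) + 1e-8 * np.eye(len(x))                  # jitter for numerical stability
C = 3
f = rng.multivariate_normal(np.zeros(len(x)), K, size=C)   # one sample per latent function
labels = np.argmax(f, axis=0) + 1                          # classes 1, ..., C
print(labels[:10])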
Challenges in Multi-class Classification with GPs

Binary classification has received more attention than multi-class!

Challenges in the multi-class case:
1. Approximate inference is more difficult.
2. C > 2 latent functions instead of just one.
3. Deal with more complicated likelihood factors.
4. Computationally more expensive algorithms.

Most techniques do not scale to large datasets (Williams & Barber, 1998; Kim & Ghahramani, 2006; Girolami & Rogers, 2006; Chai, 2012; Riihimäki et al., 2013).

The best cost is O(CNM²), if sparse priors are used.
Stochastic Variational Inference for Multi-class GPs

Hensman et al. (2015) use a robust likelihood function:

    p(y_i | f_i) = (1 - ε) p_i + (ε / (C - 1)) (1 - p_i),   with p_i = 1 if y_i = arg max_k f_k(x_i) and p_i = 0 otherwise.

The posterior approximation is q(f) = ∫ p(f | f̄) q(f̄) df̄, where f̄ are the process values at the inducing points and

    q(f̄) = ∏_{k=1}^C N(f̄^k | μ_k, Σ_k),   f̄^k = (f_k(x̄^k_1), ..., f_k(x̄^k_M))^T,   X̄^k = (x̄^k_1, ..., x̄^k_M)^T.

The number of latent variables goes from CN to CM, with M ≪ N.

    L(q) = Σ_{i=1}^N E_q[log p(y_i | f_i)] - KL(q | p)

The cost is O(CM³) (uses quadratures)! Can we do that with EP?
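As an illustration of the expected log-likelihood term in L(q), here is a small Python sketch with assumed toy values: it evaluates the robust likelihood and estimates E_q[log p(y_i | f_i)] under a factorized Gaussian q over the C latent values at x_i. The method of Hensman et al. uses quadrature for this expectation; plain Monte Carlo is used here only to keep the sketch simple.

import numpy as np

def robust_likelihood(f_i, y_i, eps=1e-3):
    # p(y_i | f_i) = (1 - eps) if y_i = arg max_k f_k(x_i), and eps / (C - 1) otherwise.
    C = len(f_i)
    p_i = 1.0 if np.argmax(f_i) == y_i else 0.0
    return (1.0 - eps) * p_i + eps / (C - 1.0) * (1.0 - p_i)

rng = np.random.default_rng(1)
mu = np.array([1.0, -0.5, 0.2])     # assumed q means for the C latents at x_i
var = np.array([0.3, 0.3, 0.3])     # assumed q variances
samples = rng.normal(mu, np.sqrt(var), size=(5000, 3))
y_i = 0
ell = np.mean([np.log(robust_likelihood(s, y_i)) for s in samples])
print(ell)                          # Monte Carlo estimate of E_q[log p(y_i | f_i)]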
Expectation Propagation (EP)

Let θ summarize the latent variables of the model.

EP approximates p(θ) ∝ p_0(θ) ∏_{n=1}^N f_n(θ) with q(θ) ∝ p_0(θ) ∏_{n=1}^N f̃_n(θ).

The f̃_n are tuned by minimizing the KL divergence D_KL[p_n || q] for n = 1, ..., N, where

    p_n(θ) ∝ f_n(θ) ∏_{j ≠ n} f̃_j(θ),   q(θ) ∝ f̃_n(θ) ∏_{j ≠ n} f̃_j(θ).
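The generic EP cycle (compute the cavity, moment-match the tilted distribution, update the approximate factor) can be sketched on a toy model. The Python code below is only an assumed illustration, not the multi-class algorithm of the paper: it runs EP with one Gaussian site per factor on p(θ) ∝ N(θ | 0, v0) ∏_n Θ(s_n θ), for which the tilted moments are those of a truncated Gaussian.

import numpy as np
from scipy.stats import norm

def ep_step_functions(s, v0=1.0, n_iter=20):
    N = len(s)
    tau_site = np.zeros(N)              # site precisions
    nu_site = np.zeros(N)               # site precision-times-mean
    tau_q, nu_q = 1.0 / v0, 0.0         # start q from the prior
    for _ in range(n_iter):
        for n in range(N):
            # 1. cavity: remove site n from q
            tau_cav = tau_q - tau_site[n]
            nu_cav = nu_q - nu_site[n]
            m_cav, v_cav = nu_cav / tau_cav, 1.0 / tau_cav
            # 2. moment-match the tilted distribution Θ(s_n θ) N(θ | m_cav, v_cav)
            beta = s[n] * m_cav / np.sqrt(v_cav)
            lam = norm.pdf(beta) / norm.cdf(beta)
            m_new = m_cav + s[n] * np.sqrt(v_cav) * lam
            v_new = v_cav * (1.0 - beta * lam - lam ** 2)
            # 3. update site n so that cavity * site has the matched moments
            tau_q, nu_q = 1.0 / v_new, m_new / v_new
            tau_site[n] = tau_q - tau_cav
            nu_site[n] = nu_q - nu_cav
    return nu_q / tau_q, 1.0 / tau_q    # posterior mean and variance of q

print(ep_step_functions(np.array([1.0, 1.0, -1.0])))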
Model Specification

We consider that y_i = arg max_k f_k(x_i), which gives the likelihood:

    p(y | f) = ∏_{i=1}^N p(y_i | f_i) = ∏_{i=1}^N ∏_{k ≠ y_i} Θ(f_{y_i}(x_i) - f_k(x_i)).

The posterior approximation is also set to be q(f) = ∫ p(f | f̄) q(f̄) df̄.

The posterior over f̄ is:

    p(f̄ | y) = [∫ p(y | f) p(f | f̄) df] p(f̄) / p(y) ≈ [∏_{i=1}^N ∫ p(y_i | f_i) p(f_i | f̄) df_i] p(f̄) / p(y),

where we have used the FITC approximation p(f | f̄) ≈ ∏_{i=1}^N p(f_i | f̄).

The corresponding likelihood factors are:

    φ_i(f̄) = ∫ [∏_{k ≠ y_i} Θ(f^{y_i}_i - f^k_i)] ∏_{k=1}^C p(f^k_i | f̄^k) df_i.
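Because the latent functions are independent GPs and the FITC factorization is used, the C latent values at x_i are independent Gaussians given the (approximate) posterior, so the expectation of each likelihood factor reduces to a one-dimensional integral. The Python sketch below (illustrative only, with assumed toy means and variances) evaluates E[∏_{k ≠ y_i} Θ(f_{y_i} - f_k)] with Gauss-Hermite quadrature.

import numpy as np
from scipy.stats import norm

def expected_likelihood(m, v, y, n_points=30):
    # Gauss-Hermite estimate of Z_i = E[ prod_{k != y} Theta(f_y - f_k) ]
    # when the f_k are independent Gaussians N(m[k], v[k]).
    x, w = np.polynomial.hermite.hermgauss(n_points)
    u = m[y] + np.sqrt(2.0 * v[y]) * x          # quadrature nodes for f_y
    prod_cdfs = np.ones_like(u)
    for k in range(len(m)):
        if k != y:
            prod_cdfs *= norm.cdf((u - m[k]) / np.sqrt(v[k]))
    return np.sum(w * prod_cdfs) / np.sqrt(np.pi)

m = np.array([0.5, -0.3, 0.1])                  # assumed posterior means for C = 3
v = np.array([0.2, 0.4, 0.3])                   # assumed posterior variances
print(expected_likelihood(m, v, y=0))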