Scalable Exact Inference in Multi-Output Gaussian Processes
Wessel P. Bruinsma¹,², Eric Perim², Will Tebbutt¹, J. Scott Hosking³,⁴, Arno Solin⁵, Richard E. Turner¹,⁶
¹University of Cambridge, ²Invenia Labs, ³British Antarctic Survey, ⁴Alan Turing Institute, ⁵Aalto University, ⁶Microsoft Research
International Conference on Machine Learning 2020
Collaborators: Wessel P. Bruinsma, Eric Perim, Will Tebbutt, J. Scott Hosking, Arno Solin, Richard E. Turner
Introduction and Motivation
Introduction 1/17
• Gaussian processes are a powerful and popular probabilistic modelling framework for nonlinear functions.
• Central modelling choice: the multi-output kernel, e.g. for two outputs f₁ and f₂,
  K(t, t′) = [ cov(f₁(t), f₁(t′))  cov(f₁(t), f₂(t′)) ;  cov(f₂(t), f₁(t′))  cov(f₂(t), f₂(t′)) ].
• Inference and learning: O(n³p³) time and O(n²p²) memory, where p is the number of outputs (see the sketch below).
• Often alleviated by exploiting structure in K.
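To make the O(n³p³) cost concrete, here is a small numpy sketch (not from the talk; function and variable names are made up) that assembles the stacked np × np covariance of a two-output GP and Cholesky-factorises it, which is the operation that dominates exact inference.

```python
import numpy as np

def mogp_covariance(ts, cross_kernels):
    """Assemble the stacked (n*p) x (n*p) covariance from the p x p grid of
    cross-covariance functions K_ij(t, t') = cov(f_i(t), f_j(t'))."""
    p, n = len(cross_kernels), len(ts)
    K = np.zeros((n * p, n * p))
    for i in range(p):
        for j in range(p):
            K[i*n:(i+1)*n, j*n:(j+1)*n] = cross_kernels[i][j](ts[:, None], ts[None, :])
    return K

# Toy example: two outputs sharing an EQ kernel, coupled by a 2x2 output covariance B.
eq = lambda t, u: np.exp(-0.5 * (t - u) ** 2)
B = np.array([[1.0, 0.8],
              [0.8, 1.5]])
cross_kernels = [[lambda t, u, b=B[i, j]: b * eq(t, u) for j in range(2)] for i in range(2)]

ts = np.linspace(0, 10, 100)
K = mogp_covariance(ts, cross_kernels)                 # 200 x 200 for n = 100, p = 2
L = np.linalg.cholesky(K + 1e-8 * np.eye(len(K)))      # O(n^3 p^3): the bottleneck
```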
Instantaneous Linear Mixing Model (ILMM) 2/17
x ∼ GP(0, K(t, t′)) with K(t, t) = I_m,   x: "latent processes",
f(t) = Hx(t),                              H: "basis" or "mixing matrix",
y(t) ∼ N(f(t), Σ).
[Figure: f(t) = h₁x₁(t) + h₂x₂(t), illustrating the data lying near the span of the basis vectors h₁ and h₂.]
• Use m ≪ p basis vectors: data lives in a "pancake" around col(H).
• Generalisation of FA to the time-series setting.
• Captures many existing MOGPs from the literature.
• Inference and learning: O(m³n³) instead of O(p³n³).
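As an illustration (not from the talk), a minimal numpy sketch of the ILMM generative story with m = 2 EQ-kernel latent processes and made-up values for H and the noise variance:

```python
import numpy as np

rng = np.random.default_rng(0)
n, p, m = 200, 5, 2
ts = np.linspace(0, 10, n)

# Independent latent processes x_1, ..., x_m, here each with an EQ kernel (so K(t, t) = I_m).
K_lat = np.exp(-0.5 * (ts[:, None] - ts[None, :]) ** 2) + 1e-8 * np.eye(n)
x = np.linalg.cholesky(K_lat) @ rng.standard_normal((n, m))   # n x m draws of x(t)

H = rng.standard_normal((p, m))      # mixing matrix / "basis" (hypothetical values)
sigma2 = 0.05                        # observation noise variance (hypothetical)

f = x @ H.T                                              # f(t) = H x(t), one row per time point
y = f + np.sqrt(sigma2) * rng.standard_normal((n, p))    # y(t) ~ N(f(t), sigma^2 I_p)
```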
Inside the ILMM
Key Result 3/17
For x ∼ GP(0, K(t, t′)), map the high-dimensional observation y ∈ ℝᵖ to the "projected observation" y_proj = Ty ∈ ℝᵐ (m ≪ p):
  ✗ inference in p(y), with noise Σ   →   ✓ inference in p(x), with projected noise Σ_T.
Proposition: This is exact!
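The slide does not spell out the projection T. As a sketch, one natural choice consistent with the picture above is the noise-weighted pseudo-inverse of H, which satisfies TH = I_m so that Ty = x + Tε with projected noise Σ_T = TΣTᵀ; treat the specific formula below as an assumption on my part rather than the talk's definition.

```python
import numpy as np

def ilmm_projection(H, Sigma):
    """Projection T with T @ H = I_m (noise-weighted pseudo-inverse of H) and the
    projected noise Sigma_T = T Sigma T^T. The choice of T is an assumption here;
    the slide only states y_proj = T y."""
    Sinv_H = np.linalg.solve(Sigma, H)                   # Sigma^{-1} H
    T = np.linalg.solve(H.T @ Sinv_H, Sinv_H.T)          # (H^T Sigma^{-1} H)^{-1} H^T Sigma^{-1}
    return T, T @ Sigma @ T.T

# Sanity check with made-up sizes: projecting y = H x + eps gives x plus noise Sigma_T.
rng = np.random.default_rng(1)
p, m = 8, 3
H = rng.standard_normal((p, m))
Sigma = 0.1 * np.eye(p)
T, Sigma_T = ilmm_projection(H, Sigma)
assert np.allclose(T @ H, np.eye(m))                     # so T y = x + T eps, T eps ~ N(0, Sigma_T)
```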
Key Result (2) 4/17
[Diagram: direct inference p(f | Y) from Y costs O(n³p³) ✗. Instead: project Y to TY in O(nmp), perform inference p(x | TY) in O(n³m³) ✓, and reconstruct in O(nmp).]
Key Result (3) 5/17
log p(Y) = log ∫ p(x) ∏ᵢ₌₁ⁿ N(Tyᵢ | xᵢ, Σ_T) dx   (likelihood of projected observations under projected noise)
           − (n/2) log(|Σ| / |Σ_T|)                 (noise "lost" by projection)
           − (1/2) ∑ᵢ₌₁ⁿ ‖yᵢ − HTyᵢ‖²_Σ            (data "lost" by projection: reconstruction error)
           + const.
• Learning H ⇔ learning T ⇔ learning a transform of the data!
• "Regularisation terms" prevent underfitting.
Key Insight 6/17
• Inference in the ILMM: condition x on Y_proj under noise Σ_T.
• Hence, if x are independent under the prior and the projected noise Σ_T is diagonal, then x remain independent upon observing data.
• Treat the latent processes independently: condition xᵢ on (Y_proj)ᵢ: under noise (Σ_T)ᵢᵢ!
• Decouples inference into independent single-output problems.
“Decoupling” the ILMM
Orthogonal ILMM (OILMM) 7/17
x ∼ GP(0, K(t, t′)),
f(t) = Hx(t) = U S^{1/2} x(t),   U: orthogonal,   S: diagonal scaling,
y(t) ∼ N(f(t), Σ).
[Figure: as before, f(t) = h₁x₁(t) + h₂x₂(t), now with orthogonal basis vectors.]
Key property: Σ_T is diagonal! (See the sketch below.)
Analogy: adding the orthogonality constraint takes the ILMM to the OILMM, just as it takes FA to PPCA; the (O)ILMM is the time-varying counterpart.
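A small sketch (made-up sizes, heteroscedastic term D set to zero) of the OILMM parameterisation: build H = U S^{1/2} from an orthonormal U and a diagonal S, and check that projecting with T = S^{−1/2}Uᵀ (the projection used in the algorithm two slides below) makes the projected noise diagonal.

```python
import numpy as np

rng = np.random.default_rng(2)
p, m = 8, 3
sigma2 = 0.1

U, _ = np.linalg.qr(rng.standard_normal((p, m)))   # orthonormal columns: U^T U = I_m
S = np.diag([2.0, 1.0, 0.5])                       # diagonal scaling (hypothetical values)
H = U @ np.sqrt(S)                                  # f(t) = U S^{1/2} x(t)

Sigma = sigma2 * np.eye(p)                          # observation noise
T = np.diag(np.diag(S) ** -0.5) @ U.T               # projection T = S^{-1/2} U^T
Sigma_T = T @ Sigma @ T.T                           # = sigma^2 S^{-1}: diagonal
assert np.allclose(Sigma_T, np.diag(np.diag(Sigma_T)))
```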
Benefits of Orthogonality 8/17
[Diagram: direct inference p(f | Y) costs O(n³p³) ✗. Instead: project Y to TY in O(nmp); then, for each i = 1, …, m, infer p(xᵢ | (TY)ᵢ:) independently in O(n³) ✓; reconstruct in O(nmp).]
• Linear scaling in m!
• Trivially compatible with single-output scaling techniques!
(A sketch of the decoupled inference step follows below.)
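A sketch of the decoupled inference step (made-up helper names, EQ-kernel latent processes, D = 0): each latent process xᵢ is conditioned on its own slice of the projected data under the scalar noise (Σ_T)ᵢᵢ, i.e. m independent single-output GP regressions.

```python
import numpy as np

eq = lambda a, b: np.exp(-0.5 * (a - b) ** 2)

def gp_posterior_mean(ts, y, noise_var, kernel=eq):
    """Posterior mean of a single-output GP at its training inputs -- an O(n^3) solve."""
    K = kernel(ts[:, None], ts[None, :])
    return K @ np.linalg.solve(K + noise_var * np.eye(len(ts)), y)

def oilmm_posterior_means(ts, Y, U, S, sigma2):
    """Decoupled OILMM inference: one independent single-output problem per latent process.
    Y is n x p here, so the i-th latent process sees column i of the projected data."""
    s = np.diag(S)
    Y_proj = Y @ U @ np.diag(s ** -0.5)     # rows are S^{-1/2} U^T y(t)
    Sigma_T = sigma2 / s                    # diagonal of sigma^2 S^{-1} (D = 0)
    return np.stack([gp_posterior_mean(ts, Y_proj[:, i], Sigma_T[i])
                     for i in range(len(s))], axis=1)
```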
Benefits of Orthogonality (2) 9/17
1. Project the data and compute the projected noise: Y_proj = S^{−1/2} Uᵀ Y,  Σ_T = σ² S^{−1} + D.
2. For i = 1, …, m, compute the log-probability LMLᵢ of (Y_proj)ᵢ: under latent process xᵢ and observation noise (Σ_T)ᵢᵢ.
3. Compute the "regularisation term": reg. = −(n/2) log |S| − (n(p − m)/2) log 2πσ² − (1/(2σ²)) ‖(I_p − UUᵀ)Y‖²_F.
4. Construct the log-probability of the data Y under the OILMM: log p(Y) = ∑ᵢ₌₁ᵐ LMLᵢ + reg.
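A numpy sketch of the four steps (EQ-kernel latent processes, D = 0, hypothetical helper names), building log p(Y) from m single-output log marginal likelihoods plus the regularisation term:

```python
import numpy as np

def gp_lml(ts, y, noise_var, kernel):
    """Standard single-output GP log marginal likelihood -- the LML_i of step 2."""
    n = len(ts)
    K = kernel(ts[:, None], ts[None, :]) + noise_var * np.eye(n)
    L = np.linalg.cholesky(K)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))
    return -0.5 * y @ alpha - np.sum(np.log(np.diag(L))) - 0.5 * n * np.log(2 * np.pi)

def oilmm_lml(ts, Y, U, S, sigma2, kernel):
    """log p(Y) under the OILMM with D = 0, following steps 1-4 above. Y is n x p."""
    n, p = Y.shape
    s, m = np.diag(S), U.shape[1]
    # 1. Project the data and compute the projected noise.
    Y_proj = Y @ U @ np.diag(s ** -0.5)          # rows are S^{-1/2} U^T y(t)
    Sigma_T = sigma2 / s                          # diagonal of sigma^2 S^{-1}
    # 2. Per-latent-process log marginal likelihoods.
    lmls = [gp_lml(ts, Y_proj[:, i], Sigma_T[i], kernel) for i in range(m)]
    # 3. Regularisation term.
    reg = (-0.5 * n * np.sum(np.log(s))
           - 0.5 * n * (p - m) * np.log(2 * np.pi * sigma2)
           - 0.5 / sigma2 * np.sum((Y - Y @ U @ U.T) ** 2))
    # 4. Combine.
    return sum(lmls) + reg

# Usage with the quantities from the earlier sketches:
# lml = oilmm_lml(ts, y, U, S, sigma2, kernel=lambda a, b: np.exp(-0.5 * (a - b) ** 2))
```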
Complexities of MOGPs 10/17

Class (top to bottom: increasingly restrictive)        Complexity
MOGP                                                    O(p³n³)
ILMM                                                    O(m³n³)
OILMM                                                   O(mn³)
OILMM + inducing points (r inducing points)             O(mnr²)
OILMM + state-space approximation (d-dim.)              O(mnd³)

• Use single-output scaling techniques to also bring down the complexity in n.
• Orthogonality gives excellent computational benefits. But how restrictive is it?
Generality of the OILMM 11/17
Definition: An (O)ILMM is separable if K(t, t′) = k(t, t′) I_m. Example: the ICM.
ILMM versus OILMM:
• Separable case: without loss of generality.
• Non-separable case: only affects correlations through time.
• An ILMM can be approximated by an OILMM (in KL) if the right singular vectors of H are close to unit vectors (in ‖·‖_F).
• A separable spatio–temporal GP is an OILMM.
• The OILMM gives a non-separable relaxation of separable models whilst retaining efficient inference.
Missing Data 12/17 • Missing data is troublesome: it breaks orthogonality of H . • In the paper, we derive a simple and effective approximation.
The OILMM in Practice
Demonstration of Scalability 13/17
[Figure: time (s) and memory (GB) versus the number of latent processes m (1 to 25), comparing the ILMM and the OILMM.]
Demonstration of Generality 14/17

            EEG              FX
         PPLP   SMSE     PPLP   SMSE
ILMM    −2.11   0.49     3.39   0.19
OILMM   −2.11   0.49     3.39   0.19

• Near-identical performance on two real-world data sets.
• Demonstrates that the missing-data approximation works well.
Case Study: Climate Simulators 15/17
[Figure: simulated temperature (K) over 1979–2004 for a subset of the simulators (ACCESS1-0, ACCESS1-3, BNU-ESM, CCSM4, CMCC-CM, CNRM-CM5, CSIRO-Mk3-6-0, CanAM4, EC-EARTH, FGOALS-g2).]
• Jointly model p_s = 28 climate simulators at p_r = 247 spatial locations and n = 10 000 points in time.
• Equals p = p_s p_r ≈ 7k outputs and pn ≈ 70M observations.
• Goal: learn the covariance between simulators with H = H_s ⊗ H_r (see the toy sketch below).
• Use m = 50 and inducing points to scale the decoupled problems.
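As a toy illustration of the Kronecker-structured mixing matrix only (tiny made-up sizes and values; the talk uses p_s = 28, p_r = 247, m = 50):

```python
import numpy as np

rng = np.random.default_rng(3)
p_s, p_r = 4, 6      # toy numbers of simulators and spatial locations
m_s, m_r = 2, 3      # latent dimensions per factor, so m = m_s * m_r

H_s = rng.standard_normal((p_s, m_s))   # mixing over simulators (hypothetical values)
H_r = rng.standard_normal((p_r, m_r))   # mixing over spatial locations (hypothetical values)
H = np.kron(H_s, H_r)                   # joint mixing matrix: (p_s * p_r) x (m_s * m_r)
assert H.shape == (p_s * p_r, m_s * m_r)
```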
Case Study: Climate Simulators (2) 16/17
[Figure: correlation matrices over the 28 climate simulators — left: empirical correlations, right: correlations learned by the OILMM.]
Conclusion 17/17
Use a projection of the data to accelerate inference in MOGPs with orthogonal bases:
✓ Linear scaling in m.
✓ Simple to implement.
✓ Trivially compatible with single-output scaling techniques.
✓ Does not sacrifice significant expressivity.