Simultaneous Inference for Massive Data: Distributed Bootstrap
Yang Yu (Purdue University), Shih-Kang Chao (University of Missouri), Guang Cheng (Purdue University)
ICML 2020
We have N i.i.d. data points: Z_1, ..., Z_N

Estimation: Fit a model that has an unknown parameter $\theta \in \mathbb{R}^d$ by minimizing the empirical risk
  $\hat\theta := \arg\min_{\theta \in \mathbb{R}^d} \frac{1}{N} \sum_{i=1}^{N} \mathcal{L}(\theta; Z_i)$

Ideally, we want $\hat\theta$ to be close to the expected risk minimizer
  $\theta^* := \arg\min_{\theta \in \mathbb{R}^d} \mathbb{E}_Z[\mathcal{L}(\theta; Z)]$

Examples:
◮ Linear regression: Z = (X, Y), $\mathcal{L}(\theta; Z) = (Y - X^\top \theta)^2 / 2$
◮ Logistic regression: Z = (X, Y), $\mathcal{L}(\theta; Z) = -Y X^\top \theta + \log(1 + \exp[X^\top \theta])$
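A minimal sketch of the empirical risk minimization above for the logistic-regression loss, using NumPy and a plain gradient-descent loop. The data-generating code, step size, and iteration count are illustrative assumptions and not part of the slides.

import numpy as np

def logistic_loss_grad(theta, X, Y):
    """Gradient of the empirical risk (1/N) * sum_i [-Y_i X_i'theta + log(1 + exp(X_i'theta))]."""
    p = 1.0 / (1.0 + np.exp(-X @ theta))      # predicted probabilities
    return X.T @ (p - Y) / len(Y)             # average gradient over the N points

def fit_erm(X, Y, lr=0.1, n_iter=500):
    """Empirical risk minimizer (hat theta) via gradient descent; lr and n_iter are illustrative."""
    theta = np.zeros(X.shape[1])
    for _ in range(n_iter):
        theta -= lr * logistic_loss_grad(theta, X, Y)
    return theta

# Hypothetical data: N = 10000 points, d = 5, true parameter theta_star.
rng = np.random.default_rng(0)
N, d = 10_000, 5
theta_star = rng.normal(size=d)
X = rng.normal(size=(N, d))
Y = rng.binomial(1, 1.0 / (1.0 + np.exp(-X @ theta_star)))
theta_hat = fit_erm(X, Y)                      # should be close to theta_star for large N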
Inference: Find an interval [L, U] s.t. $P(\theta^*_1 \in [L, U]) \approx 95\%$

Simultaneous Inference: Find intervals [L_1, U_1], ..., [L_d, U_d] s.t.
  $P(\theta^*_1 \in [L_1, U_1], \ldots, \theta^*_d \in [L_d, U_d]) \approx 95\%$

How to perform Simultaneous Inference:
Step 1: Compute the point estimator $\hat\theta$
Step 2: Estimate the 0.95-quantile c(0.95) of $\| \sqrt{N}(\hat\theta - \theta^*) \|_\infty$ (by bootstrap)
Step 3: For l = 1, ..., d,
  $L_l = \hat\theta_l - \frac{c(0.95)}{\sqrt{N}}, \quad U_l = \hat\theta_l + \frac{c(0.95)}{\sqrt{N}}$
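To make Step 3 concrete, here is a small sketch that turns a point estimate and a bootstrapped quantile into the d simultaneous intervals; theta_hat and c95 are assumed to come from Steps 1 and 2 (e.g. the fit above), and the numeric values in the commented call are placeholders.

import numpy as np

def simultaneous_intervals(theta_hat, c95, N):
    """Step 3: [L_l, U_l] = theta_hat_l -/+ c(0.95)/sqrt(N), the same half-width for every coordinate."""
    half_width = c95 / np.sqrt(N)
    return theta_hat - half_width, theta_hat + half_width

# Hypothetical usage with a quantile estimated by bootstrap in Step 2:
# lower, upper = simultaneous_intervals(theta_hat, c95=2.8, N=10_000)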
Distributed framework: Distribute the N data points evenly across k machines so that each machine stores n = N/k data points
◮ 1 master node M_1
◮ k − 1 worker nodes M_2, M_3, ..., M_k
◮ Z_ij: the i-th data point at machine M_j
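Purely for illustration, the single-process simulations later in these notes lay the data out as k equal blocks, with block j playing the role of machine M_{j+1}; a real deployment would shard data across nodes rather than split an in-memory array.

import numpy as np

def partition(Z, k):
    """Split N = n*k rows evenly into k blocks; block 0 is the master M_1's local data."""
    assert Z.shape[0] % k == 0, "assumes N is divisible by k"
    return np.split(Z, k)          # list of k arrays, each with n = N/k rows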
Distributed Simultaneous Inference

Step 1: Compute $\hat\theta$
◮ Can be approximated by existing efficient distributed estimation methods

Step 2: Bootstrap c(0.95)
◮ The traditional bootstrap cannot be efficiently applied in the distributed framework
◮ BLB [1] and SDB [2] are computationally expensive due to repeated resampling, and are not suitable for large k

[1] Kleiner et al. "A scalable bootstrap for massive data." JRSS-B (2014)
[2] Sengupta et al. "A subsampled double bootstrap for massive data." JASA (2016)
Question: How can we efficiently do Step 2 in a distributed manner?

Our contributions:
◮ We propose communication-efficient and computation-efficient distributed bootstrap methods: k-grad and n+k-1-grad
◮ We prove a sufficient number of communication rounds that guarantees statistical accuracy and efficiency
Approximate by sample average:
  $\left\| \sqrt{N}(\hat\theta - \theta^*) \right\|_\infty \approx \left\| \mathbb{E}[\nabla^2 \mathcal{L}(\theta^*; Z)]^{-1} \frac{1}{\sqrt{N}} \sum_{i=1}^{n} \sum_{j=1}^{k} \nabla\mathcal{L}(\theta^*; Z_{ij}) \right\|_\infty$

Multiplier bootstrap: $\epsilon_{ij} \overset{\text{iid}}{\sim} N(0, 1)$ for i = 1, ..., n and j = 1, ..., k
  $\left\| \sqrt{N}(\hat\theta - \theta^*) \right\|_\infty \overset{D}{\approx} \left\| \mathbb{E}[\nabla^2 \mathcal{L}(\theta^*; Z)]^{-1} \frac{1}{\sqrt{N}} \sum_{i=1}^{n} \sum_{j=1}^{k} \epsilon_{ij} \nabla\mathcal{L}(\theta^*; Z_{ij}) \right\|_\infty \;\Big|\; \{Z_{ij}\}_{i,j}$

k-grad (computed at M_1): $\epsilon_j \overset{\text{iid}}{\sim} N(0, 1)$ for j = 1, ..., k
  $\left\| \sqrt{N}(\hat\theta - \theta^*) \right\|_\infty \overset{D}{\approx} W := \left\| \widehat\Theta \, \frac{1}{\sqrt{k}} \sum_{j=1}^{k} \epsilon_j \sqrt{n}\,(g_j - \bar{g}) \right\|_\infty \;\Big|\; \{Z_{ij}\}_{i,j}$

where $g_j = \frac{1}{n} \sum_{i=1}^{n} \nabla\mathcal{L}(\bar\theta; Z_{ij})$ is computed at M_j and transmitted to M_1,
$\bar{g} = \frac{1}{k} \sum_{j=1}^{k} g_j$ is averaged at M_1, and
$\widehat\Theta = \left[ \frac{1}{n} \sum_{i=1}^{n} \nabla^2 \mathcal{L}(\bar\theta; Z_{i1}) \right]^{-1}$ is computed at M_1
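A single-process simulation of the k-grad statistic W, written for the linear-regression loss so that gradients and the Hessian have closed forms. The function name, the number of draws B, and the quantile level are illustrative assumptions; in a real deployment each g_j would be computed on machine M_j and only the k node-averaged gradients would be shipped to M_1.

import numpy as np

def k_grad_quantile(theta_bar, X_blocks, Y_blocks, B=500, level=0.95, rng=None):
    """k-grad multiplier bootstrap sketch for the loss L = (Y - X'theta)^2 / 2.

    X_blocks[j], Y_blocks[j] hold machine M_{j+1}'s local data (n points each).
    Returns the bootstrapped quantile of
      W = || Theta_hat * (1/sqrt(k)) * sum_j eps_j * sqrt(n) * (g_j - g_bar) ||_inf.
    """
    rng = np.random.default_rng() if rng is None else rng
    k = len(X_blocks)
    n = X_blocks[0].shape[0]

    # Node-wise averaged gradients g_j at theta_bar (would be transmitted to M_1).
    g = np.stack([Xj.T @ (Xj @ theta_bar - Yj) / n
                  for Xj, Yj in zip(X_blocks, Y_blocks)])        # shape (k, d)
    g_bar = g.mean(axis=0)

    # Theta_hat: inverse of the local Hessian on the master machine M_1.
    X1 = X_blocks[0]
    Theta_hat = np.linalg.inv(X1.T @ X1 / n)

    W = np.empty(B)
    for b in range(B):
        eps = rng.standard_normal(k)                              # one multiplier per machine
        S = np.sqrt(n) * (eps[:, None] * (g - g_bar)).sum(axis=0) / np.sqrt(k)
        W[b] = np.abs(Theta_hat @ S).max()
    return np.quantile(W, level)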
k-grad fails for small k!

Solution: n+k-1-grad (computed at M_1): $\epsilon_{i1}, \epsilon_j \overset{\text{iid}}{\sim} N(0, 1)$ for i = 1, ..., n and j = 2, ..., k
  $W := \left\| \widehat\Theta \, \frac{1}{\sqrt{n + k - 1}} \left( \sum_{i=1}^{n} \epsilon_{i1} (g_{i1} - \bar{g}) + \sum_{j=2}^{k} \epsilon_j \sqrt{n}\,(g_j - \bar{g}) \right) \right\|_\infty \;\Big|\; \{Z_{ij}\}_{i,j}$

where $g_{i1} = \nabla\mathcal{L}(\bar\theta; Z_{i1})$ is computed at M_1
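The same kind of sketch for n+k-1-grad, which additionally uses the n per-point gradients available on the master machine. It reuses the block layout and linear-regression loss of the previous snippet; the function name and defaults are illustrative.

import numpy as np

def n_plus_k_minus_1_grad_quantile(theta_bar, X_blocks, Y_blocks, B=500, level=0.95, rng=None):
    """n+k-1-grad multiplier bootstrap sketch (linear-regression loss), run entirely at M_1."""
    rng = np.random.default_rng() if rng is None else rng
    k = len(X_blocks)
    n = X_blocks[0].shape[0]

    X1, Y1 = X_blocks[0], Y_blocks[0]
    g_i1 = X1 * (X1 @ theta_bar - Y1)[:, None]                    # per-point gradients at M_1, shape (n, d)
    g = np.stack([Xj.T @ (Xj @ theta_bar - Yj) / n
                  for Xj, Yj in zip(X_blocks, Y_blocks)])         # node-averaged gradients, shape (k, d)
    g_bar = g.mean(axis=0)
    Theta_hat = np.linalg.inv(X1.T @ X1 / n)

    W = np.empty(B)
    for b in range(B):
        eps_i1 = rng.standard_normal(n)                           # one multiplier per master data point
        eps_j = rng.standard_normal(k - 1)                        # one multiplier per worker machine
        S = (eps_i1[:, None] * (g_i1 - g_bar)).sum(axis=0) \
            + np.sqrt(n) * (eps_j[:, None] * (g[1:] - g_bar)).sum(axis=0)
        W[b] = np.abs(Theta_hat @ S / np.sqrt(n + k - 1)).max()
    return np.quantile(W, level)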
An example algorithm: apply k-grad / n+k-1-grad with the CSL estimator [3]

Step 1: Compute the point estimator $\hat\theta$ (τ rounds of communication)
 1: $\hat\theta^{(0)} \leftarrow \arg\min_\theta \mathcal{L}_1(\theta)$ at M_1
 2: for t = 1, ..., τ do
 3:   Transmit $\hat\theta^{(t-1)}$ to $\{M_j\}_{j=2}^{k}$
 4:   Compute $\nabla\mathcal{L}_1(\hat\theta^{(t-1)})$ and $\nabla^2\mathcal{L}_1(\hat\theta^{(t-1)})^{-1}$ at M_1
 5:   for j = 2, ..., k do
 6:     Compute $\nabla\mathcal{L}_j(\hat\theta^{(t-1)})$ at M_j
 7:     Transmit $\nabla\mathcal{L}_j(\hat\theta^{(t-1)})$ to M_1
 8:   $\nabla\mathcal{L}_N(\hat\theta^{(t-1)}) \leftarrow k^{-1} \sum_{j=1}^{k} \nabla\mathcal{L}_j(\hat\theta^{(t-1)})$ at M_1
 9:   $\hat\theta^{(t)} \leftarrow \hat\theta^{(t-1)} - \nabla^2\mathcal{L}_1(\hat\theta^{(t-1)})^{-1} \nabla\mathcal{L}_N(\hat\theta^{(t-1)})$ at M_1

Step 2: Bootstrap the quantile c(0.95) (0 rounds of communication)
10: Run k-grad / n+k-1-grad with $\bar\theta = \hat\theta^{(\tau-1)}$ at M_1, reusing the gradients already transmitted in the last iteration

[3] Jordan et al. "Communication-efficient distributed statistical inference." JASA (2019)
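A compact single-process sketch of Step 1 (the CSL iteration) for the linear-regression loss. In practice the gradient computations in the inner loop run on the worker machines and only d-dimensional gradients travel over the network; here everything is simulated locally and the function name and defaults are illustrative.

import numpy as np

def csl_estimate(X_blocks, Y_blocks, tau=3):
    """CSL-style Newton iterations: globally averaged gradient, Hessian from the master block only."""
    X1, Y1 = X_blocks[0], Y_blocks[0]
    n = X1.shape[0]

    # Line 1: local minimizer on M_1 (closed form for linear regression).
    theta = np.linalg.solve(X1.T @ X1, X1.T @ Y1)
    # Line 4: local Hessian inverse at M_1; for linear regression it does not depend on theta,
    # so it is computed once instead of every round.
    H1_inv = np.linalg.inv(X1.T @ X1 / n)

    for _ in range(tau):
        # Lines 5-8: each machine computes its local average gradient at the current iterate; M_1 averages them.
        grad_N = sum(Xj.T @ (Xj @ theta - Yj) / n
                     for Xj, Yj in zip(X_blocks, Y_blocks)) / len(X_blocks)
        # Line 9: Newton-type update with the master-only Hessian.
        theta = theta - H1_inv @ grad_N
    return theta

# The final iterate plays the role of theta_hat; Step 2 then calls k_grad_quantile /
# n_plus_k_minus_1_grad_quantile with theta_bar set to the previous iterate.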