 
Functional Space Variational Inference for Uncertainty Estimation in Computer Aided Diagnosis
Pranav Poduval, IIT Bombay
Medical Imaging and Deep Learning (MIDL) 2020, Canada
Co-authors: Hrushikesh Loya, Amit Sethi (Indian Institute of Technology, Bombay)
Motivation for Bayesian Inference
Deep learning is starting to show promise in multiple domains, e.g. radiology and cancer detection. But a number of issues remain to be solved:
- Does it make sense to pass uninterpretable values, such as a "0.1 chance of being positive", to doctors who must take medical decisions?
- What should be done when the model sees something it has never seen before?
Bayesian inference is the tool used to "know what we don't know".
Bayesian Inference in Neural Networks
Given a test point $(x^*, y^*)$, the object of interest is the predictive distribution $q(y^* \mid x^*, \mathcal{D})$, obtained by marginalizing out the parameters $\theta$:
$$q(y^* \mid x^*, \mathcal{D}) = \int q(y^* \mid x^*, \theta)\, p(\theta \mid \mathcal{D})\, d\theta \tag{1}$$
Since this integral cannot be computed exactly for neural networks, we approximate it with Monte Carlo sampling:
$$q(y^* \mid x^*, \mathcal{D}) \approx \frac{1}{N} \sum_{i=1}^{N} q(y^* \mid x^*, \theta_i), \qquad \theta_i \sim p(\theta \mid \mathcal{D}) \tag{2}$$
As seen from the figure, a Bayesian neural network can therefore be viewed as an ensemble of N networks with the same architecture but different parameters $\theta_1, \dots, \theta_N$.
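A minimal sketch of the Monte Carlo estimate in Eq. (2), assuming a hypothetical linear-softmax model with K = 3 classes and a factorised Gaussian standing in for the posterior $p(\theta \mid \mathcal{D})$. The names `predict`, `sample_theta` and `mc_predictive` and all numbers are illustrative, not from the paper. Each sample $\theta_i$ plays the role of one ensemble member, and the predictive distribution is the average of their softmax outputs.

```python
import math
import random

def softmax(logits):
    """Numerically stable softmax of a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def predict(x, theta):
    """Hypothetical toy model q(y | x, theta): logits = W x, with theta = W (K rows of length D)."""
    logits = [sum(w * xj for w, xj in zip(row, x)) for row in theta]
    return softmax(logits)

def sample_theta(mean, std, rng):
    """Draw one weight matrix from an assumed factorised Gaussian N(mean, std^2) posterior."""
    return [[rng.gauss(m, std) for m in row] for row in mean]

def mc_predictive(x, mean, std, n_samples=1000, seed=0):
    """Eq. (2): average q(y | x, theta_i) over N samples theta_i from the posterior."""
    rng = random.Random(seed)
    avg = [0.0] * len(mean)
    for _ in range(n_samples):
        probs = predict(x, sample_theta(mean, std, rng))
        for c, pc in enumerate(probs):
            avg[c] += pc / n_samples
    return avg

posterior_mean = [[1.0, -0.5], [0.0, 0.3], [-1.0, 0.2]]  # hypothetical K = 3, D = 2
x_star = [0.8, -1.2]
print(mc_predictive(x_star, posterior_mean, std=0.5))
```

Setting `std=0.0` collapses every sample onto the posterior mean, which recovers a single deterministic network.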
Variational Inference
Exact Bayesian inference requires defining a prior $p(\theta)$ on the weight space and computing the true posterior $p(\theta \mid \mathcal{D})$ with Bayes' rule:
$$p(\theta \mid \mathcal{D}) = \frac{p(\mathcal{D} \mid \theta)\, p(\theta)}{p(\mathcal{D})}, \qquad p(\mathcal{D}) = \int p(\mathcal{D} \mid \theta)\, p(\theta)\, d\theta \tag{3}$$
Because the evidence $p(\mathcal{D})$ is intractable, we instead define a surrogate posterior $q_\phi(\theta)$ (e.g. a Gaussian, in which case $\phi = (\mu, \Sigma)$) and bring it as close to the true posterior as possible. Expanding the KL divergence between the two gives
$$\mathrm{KL}[q_\phi(\theta) \,\|\, p(\theta \mid \mathcal{D})] = -\mathbb{E}_{\theta \sim q_\phi(\theta)}[\log p(\mathcal{D} \mid \theta)] + \mathrm{KL}[q_\phi(\theta) \,\|\, p(\theta)] + C \tag{4}$$
where $C = \log p(\mathcal{D})$ does not depend on $\phi$. Minimizing the first two terms, the negative ELBO used as the loss, therefore minimizes the divergence from the true posterior. The first term can be viewed as the expected cross-entropy for classification, or the expected MSE for regression (under a Gaussian likelihood), and the second term as a regularizer.
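The loss in Eq. (4) can be illustrated on a toy problem. The sketch below assumes a mean-field Gaussian surrogate $q_\phi(\theta) = \mathcal{N}(\mu, \mathrm{diag}(\sigma^2))$, a standard-normal prior and a linear regression model with a fixed-noise Gaussian likelihood; these are hypothetical choices for illustration only. It estimates the expected negative log-likelihood by reparameterised sampling and adds the closed-form Gaussian KL, dropping the constant C.

```python
import math
import random

def kl_diag_gauss_to_std_normal(mu, sigma):
    """Closed-form KL[N(mu, diag(sigma^2)) || N(0, I)], summed over parameters."""
    return sum(0.5 * (s * s + m * m - 1.0 - 2.0 * math.log(s)) for m, s in zip(mu, sigma))

def gaussian_nll(y, y_hat, noise_std):
    """-log N(y | y_hat, noise_std^2): squared error plus constants (the 'expected MSE' view)."""
    return 0.5 * ((y - y_hat) / noise_std) ** 2 + math.log(noise_std) + 0.5 * math.log(2 * math.pi)

def negative_elbo(mu, sigma, data, noise_std=0.1, n_samples=64, seed=0):
    """Eq. (4) without C: Monte Carlo expected NLL plus KL to the prior."""
    rng = random.Random(seed)
    expected_nll = 0.0
    for _ in range(n_samples):
        # Reparameterisation: theta = mu + sigma * eps, eps ~ N(0, 1).
        theta = [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, sigma)]
        nll = sum(gaussian_nll(y, sum(t * xi for t, xi in zip(theta, x)), noise_std) for x, y in data)
        expected_nll += nll / n_samples
    return expected_nll + kl_diag_gauss_to_std_normal(mu, sigma)

data = [([1.0, 0.0], 0.5), ([0.0, 1.0], -0.3), ([1.0, 1.0], 0.2)]  # hypothetical (x, y) pairs
sigma = [0.05, 0.05]
print(negative_elbo([0.5, -0.3], sigma, data))  # good fit: lower loss
print(negative_elbo([0.0, 0.0], sigma, data))   # poor fit: higher loss
```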
Priors and what meaning do they have?
For classification among $K$ classes, a deep neural network represents a function $f_\theta : X \to p \in [0, 1]^K$, where $X$ is the input and $p$ is a probability mass function, i.e. $\sum_{i=1}^{K} p_i = 1$. For making predictions we assume the output distribution $p(Y \mid X, \theta) = \mathrm{Cat}(Y \mid p)$, where $p = f_\theta(X)$ is the softmax output of the neural network.
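As a small illustration of this output model, assuming hypothetical logits for K = 3 classes: the softmax maps the logits to a probability mass function $p$, and $\log \mathrm{Cat}(y \mid p) = \sum_i y_i \log p_i$ is just the log-probability of the observed class when $y$ is one-hot.

```python
import math

def softmax(logits):
    """Numerically stable softmax: maps K logits to a point on the probability simplex."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def categorical_log_likelihood(y_one_hot, probs):
    """log Cat(y | p) = sum_i y_i log p_i for a one-hot label vector y (zero entries skipped)."""
    return sum(y * math.log(pi) for y, pi in zip(y_one_hot, probs) if y)

logits = [2.0, 0.5, -1.0]  # hypothetical network outputs before the softmax
p = softmax(logits)
print(p, categorical_log_likelihood([1, 0, 0], p))
```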
Priors and what meaning do they have?
There is a map $\theta \mapsto f(\cdot)$, so a prior on $\theta$ implicitly defines a prior measure $p(f)$ on the space of functions $f$. We therefore skip the weight-space prior and directly define, in function space, a uniform prior on the unit simplex of $K$-dimensional probability vectors:
$$p(f) = \mathrm{Dir}(p \mid \langle 1, \dots, 1 \rangle) \tag{5}$$
Figure: Ideal prior for making OOD samples uncertain. This is a completely uncertain prior: regardless of the input, we are always uncertain about the output.
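To see why Eq. (5) is "completely uncertain", the sketch below evaluates the Dirichlet log-density and samples from it by normalising independent Gamma draws (a standard construction; the helper names are our own). With all concentrations equal to 1 the density is the constant $(K-1)!$ everywhere on the simplex, so no class distribution is preferred.

```python
import math
import random

def dirichlet_log_pdf(p, alpha):
    """log Dir(p | alpha) = log Gamma(sum alpha) - sum log Gamma(alpha_i) + sum (alpha_i - 1) log p_i."""
    log_norm = math.lgamma(sum(alpha)) - sum(math.lgamma(a) for a in alpha)
    return log_norm + sum((a - 1.0) * math.log(pi) for a, pi in zip(alpha, p))

def sample_dirichlet(alpha, rng):
    """Sample Dir(alpha) by normalising independent Gamma(alpha_i, 1) draws."""
    g = [rng.gammavariate(a, 1.0) for a in alpha]
    total = sum(g)
    return [gi / total for gi in g]

K = 3
prior = [1.0] * K  # Eq. (5): uniform Dirichlet prior
rng = random.Random(0)
for _ in range(3):
    p = sample_dirichlet(prior, rng)
    # With alpha = 1 the density is the same at every point of the simplex: (K - 1)! = 2.
    print([round(v, 3) for v in p], math.exp(dirichlet_log_pdf(p, prior)))
```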
Functional Space Variational Inference
For analytical tractability we assume the marginal posterior is also a Dirichlet distribution. In other words, whereas a standard neural network outputs a point estimate $p = f_\theta(x)$, in our case $\mathrm{Dir}(p \mid \alpha) = q_\theta(f(x))$ is the marginal functional distribution, much as a Gaussian process has a multivariate Gaussian as its marginal distribution.
Figure: (a) (left) A case where the functional VI model is very confident that the input belongs equally to all three classes; (b) (right) a case where a regular Bayesian NN (e.g. dropout, ensembles) is confident that it belongs to one particular class.
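A sketch of what a Dirichlet output buys over a point estimate. The output head `logits_to_alpha` (softplus plus one) is a hypothetical choice, not necessarily the paper's. The Dirichlet mean $\alpha_i / S$ and variance $\alpha_i (S - \alpha_i) / (S^2 (S + 1))$, with $S = \sum_i \alpha_i$, separate a flat prior from a confident, evenly split prediction like Fig. (a), even though both have the same mean. A Dirichlet concentrated near one vertex is included only as a loose analogue of the one-class case in Fig. (b).

```python
import math

def softplus(z):
    """Numerically stable log(1 + exp(z))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))

def logits_to_alpha(logits):
    """Hypothetical output head: alpha_i = softplus(z_i) + 1, so every alpha_i >= 1."""
    return [softplus(z) + 1.0 for z in logits]

def dirichlet_mean_var(alpha):
    """Mean alpha_i / S and variance alpha_i (S - alpha_i) / (S^2 (S + 1)), with S = sum(alpha)."""
    s = sum(alpha)
    mean = [a / s for a in alpha]
    var = [a * (s - a) / (s * s * (s + 1.0)) for a in alpha]
    return mean, var

cases = {
    "uniform prior": [1.0, 1.0, 1.0],
    "concentrated at the centre (cf. Fig. a)": [100.0, 100.0, 100.0],
    "concentrated near one vertex": [100.0, 1.0, 1.0],
}
for name, alpha in cases.items():
    mean, var = dirichlet_mean_var(alpha)
    print(name, [round(m, 3) for m in mean], [round(v, 5) for v in var])

print(logits_to_alpha([3.0, -2.0, 0.0]))  # hypothetical network logits for one input
```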
Functional Space Variational Inference
Given the training data $\mathcal{D} = (X_\mathcal{D}, y_\mathcal{D})$ and a test point $(x^*, y^*)$, our model gives
$$p(y^* \mid x^*, \mathcal{D}) = \int p(y^* \mid p)\, p(p \mid x^*, \mathcal{D})\, dp \tag{6}$$
As usual $p(y^* \mid p) = \mathrm{Cat}(y^* \mid p)$; the difference is that the neural network estimates a Dirichlet distribution $p(p \mid x^*, \mathcal{D}) = \mathrm{Dir}(p \mid \alpha)$, i.e. the marginal functional distribution $q_\theta(f(x^*))$, rather than a point estimate $p = f_\theta(x^*)$. The true functional posterior $p(f \mid \mathcal{D})$ is intractable, but it can be approximated by minimizing the negative functional evidence lower bound (fELBO):
$$\mathcal{L}(q) = -\mathbb{E}_{q(f)}[\log p(y_\mathcal{D} \mid f(X_\mathcal{D}))] + \mathrm{KL}[q(f) \,\|\, p(f)] \tag{7}$$
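With a categorical likelihood and a Dirichlet marginal, the integral in Eq. (6) has the closed form $p(y^* = k \mid x^*, \mathcal{D}) = \alpha_k / \sum_j \alpha_j$, the Dirichlet mean. The sketch below checks this against a Monte Carlo estimate that samples $p \sim \mathrm{Dir}(\alpha)$; the $\alpha$ values stand in for a hypothetical network output at one test point.

```python
import random

def predictive_closed_form(alpha):
    """Eq. (6) with Cat(y* | p) and Dir(p | alpha): p(y* = k) = alpha_k / sum(alpha)."""
    s = sum(alpha)
    return [a / s for a in alpha]

def predictive_monte_carlo(alpha, n_samples=50_000, seed=0):
    """Estimate Eq. (6) by sampling p ~ Dir(alpha) and averaging Cat(y* = k | p) = p_k."""
    rng = random.Random(seed)
    acc = [0.0] * len(alpha)
    for _ in range(n_samples):
        g = [rng.gammavariate(a, 1.0) for a in alpha]  # Gamma draws, normalised below
        t = sum(g)
        for c, gc in enumerate(g):
            acc[c] += gc / t
    return [v / n_samples for v in acc]

alpha = [4.0, 2.0, 1.0]  # hypothetical network output for one test point
print(predictive_closed_form(alpha))  # approx [0.571, 0.286, 0.143]
print(predictive_monte_carlo(alpha))
```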
Functional Space Variational Inference
The second term in Equation 7 is the functional KL divergence, which is hard to estimate. We therefore switch to a more familiar quantity, the KL divergence between the marginal distributions of function values at finite sets of points $x_{1:n}$:
$$\mathcal{L}_2 = \mathrm{KL}[q(f) \,\|\, p(f)] = \sup_{x_{1:n}} \mathrm{KL}[q(f(x_{1:n})) \,\|\, p(f(x_{1:n}))] \tag{8}$$
A more relaxed alternative to the supremum is to sample these "measure points" $x_{1:n}$, taking $x_{1:k} \sim X_\mathcal{D}$ (the training distribution) and $x_{k+1:n} \sim c$, where $c$ is a distribution over the same input space as the training distribution (e.g. one producing OOD samples) on which the model can be forced to be more uncertain. Note that the KL divergence between two Dirichlet distributions can be computed in closed form.
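The closed-form Dirichlet KL mentioned above is
$\mathrm{KL}[\mathrm{Dir}(\alpha) \| \mathrm{Dir}(\beta)] = \log\Gamma(\alpha_0) - \sum_i \log\Gamma(\alpha_i) - \log\Gamma(\beta_0) + \sum_i \log\Gamma(\beta_i) + \sum_i (\alpha_i - \beta_i)(\psi(\alpha_i) - \psi(\alpha_0))$, with $\alpha_0 = \sum_i \alpha_i$ and $\beta_0 = \sum_i \beta_i$. A minimal sketch follows, with a small digamma so the block needs only the standard library. `measure_point_kl` is a hypothetical helper. It assumes $q$ and the prior factorise across measure points, so the joint KL in Eq. (8) becomes a sum of per-point KLs, and it evaluates that sum at sampled points instead of taking the supremum.

```python
import math

def digamma(x):
    """Digamma psi(x) for x > 0: recurrence up to x >= 6, then the asymptotic series."""
    result = 0.0
    while x < 6.0:
        result -= 1.0 / x
        x += 1.0
    inv2 = 1.0 / (x * x)
    series = inv2 * (1.0 / 12 - inv2 * (1.0 / 120 - inv2 * (1.0 / 252 - inv2 * (1.0 / 240 - inv2 / 132))))
    return result + math.log(x) - 0.5 / x - series

def kl_dirichlet(alpha, beta):
    """Closed-form KL[Dir(alpha) || Dir(beta)]."""
    a0, b0 = sum(alpha), sum(beta)
    kl = math.lgamma(a0) - sum(math.lgamma(a) for a in alpha)
    kl -= math.lgamma(b0) - sum(math.lgamma(b) for b in beta)
    kl += sum((a - b) * (digamma(a) - digamma(a0)) for a, b in zip(alpha, beta))
    return kl

def measure_point_kl(alphas):
    """Hypothetical relaxation of Eq. (8): sum of KL[Dir(alpha(x_j)) || Dir(1, ..., 1)] over sampled points."""
    return sum(kl_dirichlet(a, [1.0] * len(a)) for a in alphas)

prior = [1.0, 1.0, 1.0]
print(kl_dirichlet([1.0, 1.0, 1.0], prior))   # 0.0 (the prior itself)
print(kl_dirichlet([20.0, 1.0, 1.0], prior))  # confident output, far from the prior
measure_alphas = [[20.0, 1.0, 1.0], [1.2, 1.1, 1.0]]  # hypothetical alpha(x_j) at measure points
print(measure_point_kl(measure_alphas))
```

On OOD measure points this term pushes $\alpha$ towards $\langle 1, \dots, 1 \rangle$, which is what makes those inputs uncertain.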
Functional Space Variational Inference
Assuming $y$ is a one-hot vector, we also get a closed form for the first term of Equation 7:
$$\mathcal{L}_1 = \int -\log p(y \mid p)\, \frac{1}{B(\alpha)} \prod_{i=1}^{K} p_i^{\alpha_i - 1}\, dp \tag{9}$$
With $p(y \mid p) = \mathrm{Cat}(y \mid p) = \prod_{i=1}^{K} p_i^{y_i}$ this becomes
$$\mathcal{L}_1 = \int \Big(-\log \prod_{i=1}^{K} p_i^{y_i}\Big) \frac{1}{B(\alpha)} \prod_{i=1}^{K} p_i^{\alpha_i - 1}\, dp = \sum_{i=1}^{K} y_i \Big(\psi\Big(\sum_{j=1}^{K} \alpha_j\Big) - \psi(\alpha_i)\Big) \tag{10}$$
where $B(\alpha) = \prod_{i=1}^{K} \Gamma(\alpha_i) \,/\, \Gamma\big(\sum_{i=1}^{K} \alpha_i\big)$ is the multivariate Beta function and $\psi(\cdot)$ is the digamma function. Combining $\mathcal{L}_1 + \mathcal{L}_2$ yields the same loss function as Evidential Deep Learning (NeurIPS 2018), which has a simple closed form.
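Eq. (10) follows from the Dirichlet identity $\mathbb{E}[\log p_i] = \psi(\alpha_i) - \psi(\sum_j \alpha_j)$. The sketch below computes the closed form and checks it against a Monte Carlo estimate of Eq. (9); the $\alpha$ and label values are hypothetical, and the small `digamma` is repeated so the block is self-contained.

```python
import math
import random

def digamma(x):
    """Digamma psi(x) for x > 0: recurrence up to x >= 6, then the asymptotic series."""
    result = 0.0
    while x < 6.0:
        result -= 1.0 / x
        x += 1.0
    inv2 = 1.0 / (x * x)
    series = inv2 * (1.0 / 12 - inv2 * (1.0 / 120 - inv2 * (1.0 / 252 - inv2 * (1.0 / 240 - inv2 / 132))))
    return result + math.log(x) - 0.5 / x - series

def expected_cross_entropy(alpha, y_one_hot):
    """Eq. (10): L1 = sum_i y_i * (psi(sum_j alpha_j) - psi(alpha_i))."""
    s = sum(alpha)
    return sum(y * (digamma(s) - digamma(a)) for a, y in zip(alpha, y_one_hot))

def expected_cross_entropy_mc(alpha, y_one_hot, n_samples=20_000, seed=0):
    """Monte Carlo estimate of Eq. (9): E_{p ~ Dir(alpha)}[-sum_i y_i log p_i]."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        g = [rng.gammavariate(a, 1.0) for a in alpha]
        t = sum(g)
        total -= sum(y * math.log(gi / t) for gi, y in zip(g, y_one_hot) if y)
    return total / n_samples

alpha = [4.0, 2.0, 1.0]  # hypothetical Dirichlet output
y = [1, 0, 0]            # one-hot label: the true class is 0
print(expected_cross_entropy(alpha, y), expected_cross_entropy_mc(alpha, y))
```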
Expected Calibration Error
Suppose a weather model predicts a sunny day with 80% probability on each of 100 days. If the model is well calibrated, roughly 80 of those days should be sunny and 20 not; a substantial deviation from this split indicates a poorly calibrated model. Calibration is important for model interpretability.
Table: Comparison of classification accuracy and ECE on the HAM10000 dataset
| Method | Test Accuracy | ECE (M = 15) |
|---|---|---|
| Standard NN | 84.38% | 7.73% |
| Dropout | 86.32% | 6.39% |
| Ensembles | 85.21% | 3.12% |
| Functional VI | 84.84% | 1.17% |
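One common way to compute ECE with M equal-width confidence bins is sketched below. Predictions are grouped by confidence, and the gap between accuracy and mean confidence in each bin is weighted by the fraction of samples it holds. The bin-edge convention and the function name are assumptions, since the slide does not say how its M = 15 bins were formed. The usage replays the weather example.

```python
import math

def expected_calibration_error(confidences, predictions, labels, n_bins=15):
    """ECE = sum_m (|B_m| / n) * |acc(B_m) - conf(B_m)| over n_bins equal-width confidence bins."""
    n = len(confidences)
    bins = [[] for _ in range(n_bins)]
    for conf, pred, label in zip(confidences, predictions, labels):
        # Bin m (0-based) covers (m / M, (m + 1) / M]; a confidence of exactly 0 goes to bin 0.
        idx = min(max(math.ceil(conf * n_bins) - 1, 0), n_bins - 1)
        bins[idx].append((conf, pred == label))
    ece = 0.0
    for members in bins:
        if not members:
            continue
        avg_conf = sum(c for c, _ in members) / len(members)
        accuracy = sum(1 for _, correct in members if correct) / len(members)
        ece += len(members) / n * abs(accuracy - avg_conf)
    return ece

conf = [0.8] * 100            # "sunny" predicted with 80% confidence on 100 days
preds = [1] * 100
labels = [1] * 80 + [0] * 20  # 80 sunny days
print(expected_calibration_error(conf, preds, labels))  # close to 0: well calibrated
```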
Additional Experiment
We observe that our model is very confident on the Nevi (NV) class, which is expected since it makes up the majority of the dataset. We can also see that the OOD samples are clearly separated from the in-distribution samples. The OOD samples used for training and testing come from different distributions: for simplicity, we used a Gaussian distribution for the training OOD samples and a uniform distribution for the test OOD samples.
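The slide does not give the parameters of either noise distribution, so the mean, standard deviation, input size and clipping below are hypothetical. The sketch also scores inputs with the uncertainty mass $u = K / \sum_i \alpha_i$ used in Evidential Deep Learning. Treating this as the separation score is an assumption here, not something the slide states. A model trained to push OOD measure points towards the uniform prior should give them $u$ close to 1.

```python
import random

def gaussian_ood(n, dim, rng, mean=0.5, std=0.25):
    """Hypothetical training OOD measure points: Gaussian noise clipped to the [0, 1] pixel range."""
    return [[min(max(rng.gauss(mean, std), 0.0), 1.0) for _ in range(dim)] for _ in range(n)]

def uniform_ood(n, dim, rng):
    """Hypothetical test OOD samples: i.i.d. uniform noise in [0, 1)."""
    return [[rng.random() for _ in range(dim)] for _ in range(n)]

def dirichlet_uncertainty(alpha):
    """EDL-style uncertainty mass u = K / sum(alpha): 1 under the uniform prior, near 0 with strong evidence."""
    return len(alpha) / sum(alpha)

rng = random.Random(0)
train_ood = gaussian_ood(n=4, dim=3 * 32 * 32, rng=rng)  # hypothetical 32x32 RGB inputs
test_ood = uniform_ood(n=4, dim=3 * 32 * 32, rng=rng)
K = 7  # HAM10000 has 7 diagnostic classes
print(dirichlet_uncertainty([1.0] * K))                  # 1.0: no evidence at all
print(dirichlet_uncertainty([101.0] + [1.0] * (K - 1)))  # about 0.065: strong evidence for one class
```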