
Maximum Likelihood with Bias-Corrected Calibration is Hard-To-Beat at Label Shift Adaptation


  1. Maximum Likelihood with Bias-Corrected Calibration is Hard-To-Beat at Label Shift Adaptation
     Amr M. Alexandari*, Anshul Kundaje†, Avanti Shrikumar*† (*co-first authors, †co-corresponding authors)
     - Amr Alexandari: PhD Student, Dept. of Computer Science
     - Anshul Kundaje: Assistant Professor, Depts. of CS & Genetics

  2. Label Shift Illustrated: train model

  3. Label Shift Illustrated: original model under-predicts

  4. Label Shift Illustrated: update

  5. Label Shift Illustrated: We don’t have ground-truth labels for the new patients! How do we update our classifier?

  6. Main Contributions
     - An approach that achieves state-of-the-art results on label shift adaptation
     - Scales to datasets with high-dimensional inputs
     - Does not require model retraining
     - Combines maximum likelihood with specific types of calibration:
       - Calibration with Temperature Scaling (TS) was insufficient (and sometimes harmful!)
       - Achieved state-of-the-art results with extensions of TS (one of which we propose) that correct for systematic bias

  7. Formal Definition of Label Shift
     Let:
     - 𝑧 denote our labels (whether or not a person has the disease)
     - 𝒚 denote the observed symptoms
     - 𝑞(𝒚, 𝑧) denote the joint distribution of (𝒚, 𝑧) at the beginning of the outbreak (“source domain”)
     - 𝑟(𝒚, 𝑧) denote the joint distribution at the widespread stage (“target domain”), when we don’t know the labels
     - Goal: adapt a source-domain classifier that predicts 𝑞(𝑧|𝒚) to instead predict 𝑟(𝑧|𝒚) for the target domain
     Core assumption: the disease has the same symptoms irrespective of outbreak stage, i.e. 𝑞(𝒚|𝑧) = 𝑟(𝒚|𝑧).
     - Thus, the difference between the source and target domains is caused exclusively by the shift in label proportions from 𝑞(𝑧) to 𝑟(𝑧). Formally, 𝑟(𝒚, 𝑧) = 𝑞(𝒚|𝑧) 𝑟(𝑧).
     - Also called prior probability shift (Amos, 2008); corresponds to “anti-causal learning”, i.e. predicting the cause 𝑧 from its effects 𝒚 (Schölkopf, 2012).
     - Anti-causal learning is appropriate here because disease status 𝑧 causes the symptoms 𝒚.

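  8. Estimating 𝑟(𝑧|𝒚) with Bayes’ Rule
     - Although 𝑞(𝒚|𝑧) is preserved, computing it is hard when 𝒚 is high-dimensional.
     - It is much easier to estimate 𝑞(𝑧|𝒚) and 𝑞(𝑧) from the source domain, as 𝑧 is lower-dimensional.
     - If we know 𝑟(𝑧), we can retrieve 𝑟(𝑧|𝒚) without ever estimating 𝑞(𝒚|𝑧), using Bayes’ rule (first shown in Saerens et al., 2002).

     We first write (the terms marked in red on the slide are not explicitly known):

     $$ r(z \mid \boldsymbol{y}) = \frac{r(z, \boldsymbol{y})}{r(\boldsymbol{y})} = \frac{r(\boldsymbol{y} \mid z)\, r(z)}{\sum_{z^*} r(\boldsymbol{y} \mid z^*)\, r(z^*)} $$

     Substituting 𝑟(𝒚|𝑧) = 𝑞(𝒚|𝑧) (label shift assumption), we have:

     $$ r(z \mid \boldsymbol{y}) = \frac{q(\boldsymbol{y} \mid z)\, r(z)}{\sum_{z^*} q(\boldsymbol{y} \mid z^*)\, r(z^*)} $$

     Through Bayes’ rule, observe that:

     $$ q(\boldsymbol{y} \mid z) = \frac{q(z \mid \boldsymbol{y})\, q(\boldsymbol{y})}{q(z)} $$

     Substituting, we get:

     $$ r(z \mid \boldsymbol{y}) = \frac{\dfrac{q(z \mid \boldsymbol{y})\, q(\boldsymbol{y})}{q(z)}\, r(z)}{\sum_{z^*} \dfrac{q(z^* \mid \boldsymbol{y})\, q(\boldsymbol{y})}{q(z^*)}\, r(z^*)} $$

     𝑞(𝒚) cancels out, giving:

     $$ r(z \mid \boldsymbol{y}) = \frac{\dfrac{q(z \mid \boldsymbol{y})}{q(z)}\, r(z)}{\sum_{z^*} \dfrac{q(z^* \mid \boldsymbol{y})}{q(z^*)}\, r(z^*)} $$

     Reminders:
     - 𝒚 denotes features (e.g. symptoms)
     - 𝑧 denotes labels (e.g. disease status)
     - 𝑞 indicates source domain (labels known)
     - 𝑟 indicates target domain (labels unknown)
     - Label shift assumes 𝑟(𝒚|𝑧) = 𝑞(𝒚|𝑧)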
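The Bayes' rule adjustment on slide 8 can be sketched in a few lines of NumPy. This is a minimal illustration, not code from the presentation: the function name `adjust_posteriors` and the toy numbers are assumptions, and the sketch assumes the source-domain posteriors are already well calibrated.

```python
import numpy as np


def adjust_posteriors(source_posteriors, source_priors, target_priors):
    """Convert q(z|y) into r(z|y) by reweighting with r(z)/q(z) and renormalizing.

    source_posteriors: array of shape (n_examples, n_classes) holding q(z|y).
    source_priors:     array of shape (n_classes,) holding q(z).
    target_priors:     array of shape (n_classes,) holding r(z).
    """
    posteriors = np.asarray(source_posteriors, dtype=float)
    weights = np.asarray(target_priors, dtype=float) / np.asarray(source_priors, dtype=float)
    unnormalized = posteriors * weights  # q(z|y) * r(z) / q(z)
    # Dividing by the sum over classes implements the denominator of the final formula.
    return unnormalized / unnormalized.sum(axis=1, keepdims=True)


# Hypothetical toy numbers: the disease is rare in the source domain (10%)
# and common in the target domain (50%).
q_z_given_y = np.array([[0.9, 0.1],
                        [0.5, 0.5]])
q_z = np.array([0.9, 0.1])
r_z = np.array([0.5, 0.5])

r_z_given_y = adjust_posteriors(q_z_given_y, q_z, r_z)
print(r_z_given_y)
```

In this toy case, an example the source model scores at 10% disease probability moves to 50% once the higher target prevalence is accounted for, and an example scored at 50% moves to 90%.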

  9. Reminders:
     - 𝒚 denotes features (e.g. symptoms)
     - 𝑧 denotes labels (e.g. disease status)
     - 𝑞 indicates source domain (labels known)
     - 𝑟 indicates target domain (labels unknown)
     - Label shift assumes 𝑟(𝒚|𝑧) = 𝑞(𝒚|𝑧)
     - If we estimate 𝑞(𝑧|𝒚) and 𝑞(𝑧) from source data and are told 𝑟(𝑧), we can find 𝑟(𝑧|𝒚) using Bayes’ rule

  10. A Simple Iterative Approach to Label Shift
     In practice, we are not told 𝑟(𝑧). How can we estimate it?
     - We could use 𝑞(𝑧|𝒚) to predict on the test set and average the predictions to estimate 𝑟(𝑧).
     - We could then use this estimate of 𝑟(𝑧) to update 𝑞(𝑧|𝒚) via Bayes’ rule, and repeat the process until convergence!
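The loop described on slide 10 can be sketched as follows. This is an illustrative sketch under stated assumptions, not the authors' implementation: the function name `em_label_shift`, the stopping rule, and the synthetic 1-D Gaussian data are all hypothetical, and the source posteriors are computed exactly here so that calibration is not an issue. Procedures of this kind are commonly described as expectation maximization over the target label proportions.

```python
import numpy as np


def em_label_shift(source_posteriors, source_priors, max_iter=500, tol=1e-8):
    """Alternate between adjusting posteriors and re-estimating r(z).

    Returns the estimated target priors r(z) and the adjusted posteriors r(z|y).
    """
    posteriors = np.asarray(source_posteriors, dtype=float)
    q_z = np.asarray(source_priors, dtype=float)
    r_z = q_z.copy()  # start from the source label proportions

    def adjust(priors):
        # Bayes' rule update: q(z|y) * priors / q(z), renormalized over classes.
        unnormalized = posteriors * (priors / q_z)
        return unnormalized / unnormalized.sum(axis=1, keepdims=True)

    for _ in range(max_iter):
        new_r_z = adjust(r_z).mean(axis=0)  # average predictions to estimate r(z)
        change = np.abs(new_r_z - r_z).max()
        r_z = new_r_z
        if change < tol:
            break
    return r_z, adjust(r_z)


# Hypothetical target data: two classes with unit-variance Gaussian features
# centred at 0 and 2, and a true target prevalence of 50%.
rng = np.random.default_rng(0)
q_z = np.array([0.9, 0.1])
labels = np.repeat([0, 1], 2500)
y = rng.normal(loc=2.0 * labels, scale=1.0)

# Exact source posteriors q(z|y) for this toy model (shared constants cancel).
likelihoods = np.stack([np.exp(-0.5 * y**2), np.exp(-0.5 * (y - 2.0) ** 2)], axis=1)
q_z_given_y = likelihoods * q_z
q_z_given_y /= q_z_given_y.sum(axis=1, keepdims=True)

r_z_hat, r_z_given_y = em_label_shift(q_z_given_y, q_z)
print(r_z_hat)
```

With these synthetic inputs the estimate should move away from the 10% source prevalence toward the 50% target prevalence. With a real classifier, the quality of the estimate depends on how well calibrated its predicted probabilities are.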

