  1. Learning Neural Causal Models From Unknown Interventions

  2. Summary
  • The relationship between each variable and its parents is modeled by a neural network, modulated by structural meta-parameters which capture the overall topology of a directed graphical model.
  • A random intervention is assumed to act on a single unknown variable of an unknown ground-truth causal model.
  • To disentangle the slow-changing aspects of each conditional from the fast-changing adaptations to each intervention, the neural network is parameterized into fast parameters and slow meta-parameters.

  3. Summary
  • A meta-learning objective favors solutions that are robust to frequent but sparse changes in the interventional distribution.
  • A challenging aspect of this setting is to not only learn the causal graph structure, but also to accurately predict which variable was intervened upon.

  4. Task Description
  • At most one intervention is performed at a time.
  • Soft/imperfect interventions: the conditional distribution of the variable on which the intervention is performed is changed.
  • Provided data:
    • data from the original ground-truth model;
    • data from a modified ground-truth model with a random intervention applied.
  • The learner is aware that an intervention occurred, but not which node it targets (each such run is an episode).
  • Over a large number of episodes, the learner experiences all nodes being intervened upon, and should be able to infer the SCM from these interventions.

  5. Objectives
  • Avoid an exponential search over all possible DAGs.
  • Handle unknown interventions.
  • Model the effect of interventions.
  • Model the underlying causal structure.

  6. Model
  • Function approximation: one neural network per variable.
  • Belief over an edge i → j: the drop-out probability for the i-th input of node j's network.
  • This represents all 2^(M^2) possible graphs.
  • Learning the drop-out probabilities avoids a discrete search over graphs.

  7. Model
  • SCM with M categorical random variables, each taking one of N categories.
  • Configuration C ∈ {0,1}^(M×M).
  • C^n counts, in element c_ij, the number of length-n walks from node i to node j of the graph.
  • Tr(C^n) counts the number of length-n cycles in the graph.
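To make the walk/cycle counting concrete, here is a minimal check on a made-up 3-node graph (this example is illustrative and not from the slides):

```python
import numpy as np

# Hypothetical adjacency matrix of the directed cycle 0 -> 1 -> 2 -> 0.
C = np.array([[0, 1, 0],
              [0, 0, 1],
              [1, 0, 0]])

C3 = np.linalg.matrix_power(C, 3)   # entry (i, j) = number of length-3 walks i -> j
print(np.trace(C3))                 # 3: one closed length-3 walk starting at each node
```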

  8. Model
  • Consider a two-node graph over A, B (M = 2).
  • Either c_AB = 1 or c_BA = 1 (versus 0).
  • P(c_AB = 1) = σ(γ_AB) and P(c_BA = 1) = σ(γ_BA).
  • The problem becomes simultaneously learning the structural meta-parameters γ and the functional meta-parameters θ.
  • θ: easily learned by maximum likelihood (backpropagation).
  • γ: more difficult (Bengio et al. 2019).
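A minimal sketch of the corresponding edge-belief parametrization and sampling (tensor names are illustrative assumptions):

```python
import torch

M = 2                                             # two nodes: A, B
gamma = torch.zeros(M, M, requires_grad=True)     # structural logits gamma_ij
edge_probs = torch.sigmoid(gamma)                 # P(c_ij = 1) = sigma(gamma_ij)
C = torch.bernoulli(edge_probs)                   # one sampled configuration
# Gradients reach gamma through a likelihood-weighted estimator (slide 18),
# not through the non-differentiable Bernoulli draw itself.
```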

  9. Problems
  • An SCM over M random variables X_i can induce a super-exponential number of adjacency matrices C.
  • The super-exponential explosion in the number of potential graph connectivity patterns, and the super-exponentially growing storage requirements of their defining conditional probability tables, make CPT-based parametrization of the structural assignments f_i increasingly unwieldy as M scales.

  10. Solution
  • Neural networks with c_ij-masked inputs provide a more manageable parametrization.

  11. Proposed Method
  • Disentangle θ into:
  • θ_slow: the slow-changing meta-parameters, which reflect the stationary properties discovered by the learner.
  • θ_fast: the fast-changing parameters, which adapt in response to interventional changes in distribution.

  12. Proposed Method
  • Two kinds of meta-parameters: the causal graph structure γ and the model's slow weights θ_slow.
  • The model's fast weights θ_fast.
  • θ = θ_slow + θ_fast: the sum of the slow, stationary meta-parameters and the fast, adaptational parameters.

  13. Optimization Problem
  • Considering all possible structural graphs as separate hypotheses is not feasible, because it would require maintaining O(2^(M^2)) models of the data.
  • Instead, each c_ij is sampled independently from a Bernoulli distribution.
  • We then only need to learn the M^2 coefficients γ_ij.
  • A slight dependency between the c_ij is induced if we require the causal graph to be acyclic, via a regularizer acting on γ.

  14. Optimization Problem
  • Each random variable: X_i = f_θi(c_i0 × X_0, c_i1 × X_1, …, c_iM × X_M, ϵ_i).
  • f_θi is a neural network (MLP) with parameters θ_i, and c_ij ∼ Bernoulli(σ(γ_ij)).
  • θ is optimized to maximize the likelihood of data under the model.
  • γ is optimized with respect to a meta-learning objective arising from changes in distribution caused by interventions.
  • This is analogous to an ensemble of neural nets differing by their binary input dropout masks, which select which variables are used as predictors of another variable.
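A minimal sketch of one such per-variable network with input masking, assuming categorical variables encoded as one-hot vectors (class and argument names are illustrative, not the paper's code):

```python
import torch
import torch.nn as nn

class NodeMLP(nn.Module):
    """Models P_i(X_i | masked inputs); one such network per variable."""
    def __init__(self, M, N, hidden=32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(M * N, hidden),
            nn.LeakyReLU(),
            nn.Linear(hidden, N),     # logits over the N categories of X_i
        )

    def forward(self, X_onehot, c_i):
        # X_onehot: (batch, M, N) one-hot values of all variables.
        # c_i: (M,) binary mask sampled as c_ij ~ Bernoulli(sigma(gamma_ij)).
        masked = X_onehot * c_i.view(1, -1, 1)    # drop non-parent inputs
        return self.net(masked.flatten(1))         # categorical logits for X_i
```

Row i of a sampled configuration C supplies the mask c_i, so each draw of C corresponds to one member of the implied ensemble of dropout-masked networks.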

  15. Learning
  • To disentangle:
  • the environment's stable, unchanging properties (the causal structure)
  • from its unstable, changing properties (the effects of an intervention).
  • MLP: P_i(X_i | X_pa(i); θ_i), where θ = θ_slow + θ_fast.

  16. Learning
  • θ_fast is reset after each episode of transfer-distribution adaptation, since an intervention is generally not persistent from one transfer distribution to another.
  • The meta-parameters (θ_slow, γ) are preserved, then updated after each episode.
  • The meta-objective for each meta-example over some intervention distribution D_int is the following:

  17. Learning
  • The meta-objective for each meta-example over some intervention distribution D_int is the following meta-transfer loss:
  • X is an example sampled from the intervention distribution D_int, and C is an adjacency matrix drawn from our belief distribution (parametrized by γ) over graph-structure configurations.
  • The loss uses the likelihood of the i-th variable X_i of the sample X when predicting it under the configuration C from the set of its putative parents.
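The loss formula itself is not reproduced above; a plausible reconstruction, following the meta-transfer objective of Bengio et al. (2019) and the per-variable notation of the previous slides (treat the exact form as an assumption), is:

```latex
L_{C,i}(X) \;=\; P_i\!\left(X_i \,\middle|\, X_{\mathrm{pa}(i;C)};\, \theta\right),
\qquad
\mathcal{R}(\gamma, \theta_{\text{slow}}) \;=\;
-\sum_{i=1}^{M} \log\,
\mathbb{E}_{C \sim \mathrm{Ber}(\sigma(\gamma))}
\!\left[\, L_{C,i}(X) \,\right],
\qquad X \sim D_{\text{int}} .
```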

  18. Learning
  • A discrete Bernoulli random sampling process is used to produce the configurations under which the log-likelihood of data samples is obtained.
  • A gradient estimator is required to propagate gradients through to the structural meta-parameters γ.
  • The (k) superscript indicates the values obtained for the k-th draw of C.
  • This gradient is estimated solely with θ_slow, because estimates employing θ have much greater variance.
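One estimator of this kind, a sketch of the likelihood-weighted form introduced in Bengio et al. (2019), with K configuration draws C^(1), …, C^(K) (again an assumption about the exact expression used here), is:

```latex
\frac{\partial \mathcal{R}}{\partial \gamma_{ij}} \;\approx\;
\frac{\sum_{k=1}^{K} \left(\sigma(\gamma_{ij}) - c_{ij}^{(k)}\right)\, L_{C^{(k)},\,i}(X)}
     {\sum_{k=1}^{K} L_{C^{(k)},\,i}(X)}
```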

  19. Acyclic Constraint

  20. Acyclic Constraint
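The constraint equations are not reproduced above. One common differentiable acyclicity penalty, consistent with the Tr(C^n) cycle-counting view from slide 7 (a sketch under that assumption, not necessarily the paper's exact regularizer), is:

```python
import torch

def acyclicity_penalty(gamma):
    """tr(exp(A)) - M, with A = sigma(gamma) and the diagonal masked out.
    Since tr(exp(A)) = M + sum_n tr(A^n)/n!, this sums weighted cycle counts
    of every length and is zero exactly when the soft adjacency has no cycles."""
    M = gamma.shape[0]
    A = torch.sigmoid(gamma) * (1.0 - torch.eye(M))
    return torch.trace(torch.matrix_exp(A)) - M
```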

  21. Predicting Interventions
  • After an intervention on X_i, the gradients into γ_i and the slow weights of the i-th conditional are false, because they do not bear the blame for X_i's outcome (which lies with the intervener).
  • The conditional likelihood of the intervened variable tends to be relatively poorer under D_int.
  • Hence, the variable with the greatest deterioration in likelihood is picked as a good guess for the intervention target.

  22. Model Description
  • The MLPs are identical in shape but do not share any parameters, since they model independent causal mechanisms.
  • Inputs: M one-hot vectors (one per node), each of length N.
  • X_i = f_θi(c_i0 × X_0, c_i1 × X_1, …, c_iM × X_M, ϵ_i), with c_ij ∼ Bernoulli(σ(γ_ij)).

  23. Stability of Training
  • The structural and the functional meta-parameters are trained simultaneously.
  • They are not independent and influence each other, which leads to instability in training.
  • Remedy: pre-train the model on observational data (the distribution of the data before interventions) using dropout on the inputs.
  • This way, the functional meta-parameters θ_slow are not too biased towards particular configurations of the structural meta-parameters γ.

  24. Regularizers
  • DAG constraint.
  • Sparsity: an L1 regularizer encouraging a sparse set of edges in the causal graph; gives slightly faster convergence.
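A minimal sketch of how these two terms might be combined over the edge beliefs (the weights are illustrative, and acyclicity_penalty is the hypothetical helper sketched above):

```python
def structural_regularizer(gamma, lambda_dag=0.5, lambda_sparse=0.1):
    dag_term = lambda_dag * acyclicity_penalty(gamma)          # DAG constraint
    sparse_term = lambda_sparse * torch.sigmoid(gamma).sum()   # L1 on edge beliefs
    return dag_term + sparse_term
```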

  25. Temperature
  • A temperature hyperparameter encourages the ground-truth model to generate some very rare events in the conditional probability tables (CPTs) more frequently.
  • The near-ground-truth MLP model's logit outputs are divided by the temperature before being used for sampling.
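A minimal sketch of temperature-scaled sampling from the generator's logits (the function name and the value of T are illustrative):

```python
import torch

def sample_with_temperature(logits, T=2.0):
    """Divide logits by T before sampling; T > 1 flattens the categorical
    distribution, so rare CPT events are generated more often."""
    probs = torch.softmax(logits / T, dim=-1)
    return torch.multinomial(probs, num_samples=1)
```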

  26. All in all: Algorithm
  • Pre-train on observational data.
  • Predict the intervened node.

  27. Simulations: Synthetic Data
  • The results are, however, sensitive to some hyperparameters, notably the DAG penalty and the sparsity penalty.

  28. Simulations: Real-World Data
  • BNLearn: Earthquake, Cancer, Asia.
  • M = 5, 5, 8 variables respectively, with a maximum of 2 parents per node.
  • A near-ground-truth MLP is learned from each dataset's CPTs and used as the ground-truth data generator.
  • Despite having the same causal graphs, the CPTs were different; hence different SCMs.

  29. Simulations: Comparison
  • Peters et al. (2016): ICP
  • Eaton & Murphy (2007a): uncertain interventions
  • Peters et al. (2016): unknown interventions
  • However, none of these attempt to predict the intervention.

  30. Importance of Dropout
  • Input dropout is used for the initial training on observational data.
  • The alternative is the fully connected off-diagonal configuration (the most degrees of freedom).
  • Pre-training cannot be carried out this way.

  31. Importance of Intervention Prediction
  • After the intervention has been performed, the learner draws data samples from the intervention distribution and computes the per-variable average log-probability under sampled adjacency matrices.
  • The variable consistently producing the least-likely outputs is predicted to be the intervened node.
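A minimal, self-contained sketch of this scoring rule (the tensor of per-variable log-likelihoods is assumed to have been computed beforehand from the model's conditionals):

```python
import torch

def predict_intervened_node(log_probs):
    """log_probs: (num_samples, num_configs, M) per-variable log-likelihoods of
    interventional samples under sampled adjacency matrices. The variable with
    the lowest average log-likelihood is returned as the predicted target."""
    avg = log_probs.mean(dim=(0, 1))    # average over samples and configurations
    return int(torch.argmin(avg))

# e.g. predict_intervened_node(torch.randn(64, 10, 5)) returns an index in [0, 5).
```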

  32. Importance of Intervention Prediction
