Scalable Deep Generative Modeling for Sparse Graphs


  1. Scalable Deep Generative Modeling for Sparse Graphs. Hanjun Dai¹, Azade Nazi¹, Yujia Li², Bo Dai¹, Dale Schuurmans¹ (¹Google Brain, ²DeepMind).

  2. Graph generative models. Given a set of graphs {G_1, G_2, …, G_N}, fit a probabilistic model p(G) over graphs, so that we can:
     ● sample new graphs: G ~ p(G)
     ● complete a graph given a part of it: G_rest ~ p(G_rest | G_part)
     ● obtain graph representations
     Such models can also be used for structured prediction, p(G | z).

  3. Types of deep graph generative models.
     ● Modeling the adjacency matrix directly, like an image: VAE-based models (Kipf et al., 16; Simonovsky et al., 18), NetGAN (Bojchevski et al., 18).
     ● Leveraging the sparse structure of graphs: Junction-Tree VAE (Jin et al., 18), Deep GraphGen (Li et al., 18), autoregressive models (You et al., 18; Liao et al., 19).

  4. Autoregressive graph generative models. Time complexity per graph during inference (n nodes, m edges):
     ● Deep GraphGen (Li et al., 18): O((m + n)^2), scales to ~100 nodes
     ● GraphRNN (You et al., 18): O(n^2), ~2,000 nodes
     ● GRAN (Liao et al., 19): O(n^2), ~5,000 nodes
     ● BiGG (Dai et al., 20; this work): O((m + n) log n), or O(n^2) for a fully connected graph; ~100,000 nodes

  5. Autoregressive graph generative models. Time/memory complexity per graph during training:
     ● Deep GraphGen (Li et al., 18): O(m) synchronizations, O(m(m + n)) memory
     ● GraphRNN (You et al., 18): O(n^2) or O(n) synchronizations, O(n^2) memory
     ● GRAN (Liao et al., 19): O(n) synchronizations, O(n(m + n)) memory
     ● BiGG (Dai et al., 20; this work): O(log n) synchronizations, sublinear memory cost

  6. Saving computation for sparse graphs.
     ● GraphRNN: O(n^2)
     ● GRAN, GraphRNN-S: O(n^2)
     ● BiGG (this work): O((m + n) log n)

  7. Autoregressive generation of the adjacency matrix: 01 Generating one cell, 02 Generating one row, 03 Generating rows.

  8. Autoregressive generation of the adjacency matrix, part 01: Generating one cell.
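  Before the per-cell mechanics, it may help to spell out the factorization that these three parts implement; the notation below (A_u for the neighbor set / u-th adjacency row of node u, b_{u,k} for the k-th binary decision in row u's tree) is shorthand of mine, not taken from the slides:

      p(A) = \prod_{u=1}^{n} p(A_u \mid A_1, \ldots, A_{u-1}),
      \qquad
      p(A_u \mid A_{<u}) = \prod_{k} p\big(b_{u,k} \mid b_{u,<k},\, A_{<u}\big)

  Each row distribution is realized by the binary decisions of its row-binary tree (slides 12-16), and conditioning on A_{<u} is kept cheap by the Fenwick-tree summaries (slides 18-20).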

  9. O(log n) procedure for generating one edge.
     ● Naive approach: given node u, choose a neighbor v with a softmax over all n candidates; O(n) per edge.
     ● Efficient approach: recursively divide the range [1, n] into two halves and choose one; at most O(log n) binary decisions per edge.
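  A minimal sketch of the efficient approach, assuming a hypothetical predictor go_right(state, lo, hi) that returns the model's probability of descending into the right half (all names here are illustrative, not BiGG's actual code):

      import random

      def sample_neighbor(go_right, state, lo, hi):
          # Repeatedly halve the candidate range [lo, hi]; each halving is one
          # Bernoulli decision, so a neighbor is chosen after at most O(log n) steps.
          while lo < hi:
              mid = (lo + hi) // 2
              if random.random() < go_right(state, lo, hi):
                  lo = mid + 1      # descend into the right half [mid+1, hi]
              else:
                  hi = mid          # descend into the left half [lo, mid]
          return lo                 # the sampled neighbor v

      # Example with a constant stand-in predictor over n = 1000 candidates:
      v = sample_neighbor(lambda s, lo, hi: 0.5, state=None, lo=1, hi=1000)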

  10. Binary tree generation.
     ● Following a path from the root to one leaf: O(log n) per neighbor.
     ● Generating each of node u's neighbors separately: O(N_u log n), where N_u is the number of neighbors of node u.
     ● Generating the whole row tree via DFS: O(|T|), where |T| is the tree size and |T| < min{N_u log n, 2n}.

  11. Autoregressive generation of the adjacency matrix, part 02: Generating one row.

  12. Autoregressive row-binary tree generation. For a tree node t, we first decide whether to generate a left child:
     def generate_tree(t):
         # should we generate a left child?

  13. Autoregressive row-binary tree generation. The decision to generate a left child is conditioned on h_top(t), which summarizes the existing tree from the top down:
     Has-left ~ Bernoulli( · | h_top(t))

  14. Autoregressive row-binary tree generation. If the answer is yes, create the left child and recursively generate its subtree:
     def generate_tree(t):
         # should we generate a left child?  Has-left ~ Bernoulli( · | h_top(t))
         if yes:
             create left child
             generate_tree(lch(t))

  15. Autoregressive row-binary tree generation. After the left subtree is finished, decide whether to generate a right child, conditioned on h_top(t) and on h_bot(lch(t)), which summarizes the generated left subtree from the bottom up:
     def generate_tree(t):
         # should we generate a left child?  Has-left ~ Bernoulli( · | h_top(t))
         if yes:
             create left child
             generate_tree(lch(t))
         # should we generate a right child?  Has-right ~ Bernoulli( · | h_top(t), h_bot(lch(t)))
         if yes:
             create right child
             generate_tree(rch(t))
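  A runnable sketch of this recursion over a column interval [lo, hi]; p_left and p_right are hypothetical stand-ins for the model's Bernoulli heads, and the sketch omits the h_top / h_bot updates (slide 16) as well as the constraint that every created internal node must end up with at least one child:

      import random

      def generate_row_tree(h_top, p_left, p_right, lo, hi, edges):
          # Leaf: a single column index, i.e., one generated edge of this row.
          if lo == hi:
              edges.append(lo)
              return
          mid = (lo + hi) // 2
          # Has-left ~ Bernoulli( . | h_top )
          if random.random() < p_left(h_top):
              generate_row_tree(h_top, p_left, p_right, lo, mid, edges)
          # Has-right ~ Bernoulli( . | h_top, h_bot(left subtree) ); the bottom-up
          # summary is omitted in this sketch.
          if random.random() < p_right(h_top):
              generate_row_tree(h_top, p_left, p_right, mid + 1, hi, edges)

      # Example: toy heads that expand each half with probability 0.5.
      row_edges = []
      generate_row_tree(None, lambda h: 0.5, lambda h: 0.5, 1, 8, row_edges)
      print(sorted(row_edges))   # the sampled neighbor set for this row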

  16. Realize the top-down and bottom-up recursions.
     Bottom-up summary of the subtree rooted at t:
         h_bot(t) = TreeLSTMCell( h_bot(lch(t)), h_bot(rch(t)) )
     Top-down context for the left child:
         h_top(lch(t)) = LSTMCell( h_top(t), e_left )
     Top-down context for the right child, which also sees the finished left subtree:
         ĥ_top(rch(t)) = TreeLSTMCell( h_top(t), h_bot(lch(t)) )
         h_top(rch(t)) = LSTMCell( ĥ_top(rch(t)), e_right )
     where e_left and e_right denote the embeddings for branching left and right, respectively.
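  A data-flow sketch of these two recursions; the cells are passed in as plain callables (in BiGG they would be learned LSTM / Tree-LSTM cells and e_left / e_right their branch embeddings), so nothing here is a real library API:

      class TreeNode:
          # Minimal node of a row-binary tree; lch / rch are None for leaves.
          def __init__(self, lch=None, rch=None, emb=None):
              self.lch, self.rch, self.emb = lch, rch, emb

      def h_bot(t, tree_cell_bot):
          # Bottom-up: summarize the subtree rooted at t from its children.
          if t.lch is None and t.rch is None:
              return t.emb
          left = h_bot(t.lch, tree_cell_bot) if t.lch else None
          right = h_bot(t.rch, tree_cell_bot) if t.rch else None
          return tree_cell_bot(left, right)

      def child_contexts(h_top_t, h_bot_lch, lstm_cell, tree_cell_top, e_left, e_right):
          # Top-down: the left child sees only the parent context; the right child
          # additionally sees the bottom-up summary of the finished left subtree.
          h_top_lch = lstm_cell(h_top_t, e_left)
          h_hat_rch = tree_cell_top(h_top_t, h_bot_lch)
          h_top_rch = lstm_cell(h_hat_rch, e_right)
          return h_top_lch, h_top_rch

      # Example with string "cells" to show the data flow only:
      root = TreeNode(TreeNode(emb="L"), TreeNode(emb="R"))
      print(h_bot(root, lambda a, b: f"[{a}|{b}]"))     # -> [L|R]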

  17. Autoregressive generation of the adjacency matrix, part 03: Generating rows.

  18. Autoregressive conditioning between rows. To generate the neighbors of node u (i.e., the u-th row), how do we summarize row 0 to row u-1? Chaining an LSTM over h_row(0), h_row(1), …, h_row(u-1) is not efficient: it creates an O(n) dependency length.

  19. Fenwick tree for prefix summarization. The Fenwick tree is a data structure that supports prefix sums and single-element modifications, both in O(log n); here it is built over the row summaries h_row(0), h_row(1), …, h_row(n-1).

  20. Fenwick tree for prefix summarization. The prefix context for the current row u is obtained with a low-bit query: it is read off at most O(log n) Fenwick nodes instead of all u previous rows (examples on the slide: u = 3, u = 5, u = 6 over the leaves h_row(0), …, h_row(5)). At most O(log n) dependencies per row.
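  A minimal sketch of this bookkeeping with a generic merge(a, b) in place of BiGG's learned Tree-LSTM-style merge (class and method names are illustrative only):

      class FenwickSummarizer:
          # Fenwick (binary indexed) tree over per-row summaries.
          # tree[i] summarizes the last lowbit(i) rows ending at row i (1-indexed).
          # Both operations below touch O(log n) nodes.
          def __init__(self, n, merge):
              self.n, self.merge = n, merge
              self.tree = [None] * (n + 1)

          def add_row(self, i, h_row):
              # Insert the summary of row i once that row has been generated.
              while i <= self.n:
                  self.tree[i] = h_row if self.tree[i] is None else self.merge(self.tree[i], h_row)
                  i += i & (-i)          # jump to the next covering node

          def prefix_context(self, u):
              # Collect the O(log n) summaries that together cover rows 1..u,
              # i.e., the context needed before generating row u+1.
              out = []
              while u > 0:
                  out.append(self.tree[u])
                  u -= u & (-u)          # strip the lowest set bit
              return out

      # Example with string concatenation standing in for the learned merge:
      fs = FenwickSummarizer(8, merge=lambda a, b: a + b)
      for i in range(1, 6):
          fs.add_row(i, f"r{i}")
      print(fs.prefix_context(5))        # -> ['r5', 'r1r2r3r4']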

  21. Optimizing BiGG: 01 Training with O(log n) synchronizations, 02 Model parallelism & sublinear memory cost.

  22. Optimizing BiGG, part 01: Training with O(log n) synchronizations.

  23. Training with O(log n) synchronizations. Stage 1: compute all bottom-up summarizations for all rows, one synchronization per tree level (Sync 1, Sync 2, …): O(log n) steps.

  24. Training with O(log n) synchronizations. Stage 2: construct the entire Fenwick tree, again one synchronization per level: O(log n) steps.

  25. Training with O(log n) synchronizations. Stage 3: retrieve the prefix context (the required Fenwick summaries) for every row in parallel: O(log n) steps.

  26. Training with O(log n) synchronizations. Stage 4: compute the cross-entropy loss over all binary decisions, again in O(log n) synchronization steps.
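  Why these stages cost O(log n) synchronizations rather than O(n): tree nodes are grouped by depth, and each depth is processed in one batched call. A sketch of Stage 1 with a hypothetical batched_cell standing in for a vectorized Tree-LSTM cell (each call would be one synchronization point):

      class Node:
          def __init__(self, lch=None, rch=None):
              self.lch, self.rch = lch, rch

      def bottom_up_stage(nodes_by_depth, batched_cell):
          # One batched update per tree level, deepest level first. Balanced
          # row-binary trees have O(log n) levels => O(log n) sync points.
          summaries = {}
          for depth in sorted(nodes_by_depth, reverse=True):
              level = nodes_by_depth[depth]
              lefts = [summaries.get(t.lch) for t in level]    # children already computed
              rights = [summaries.get(t.rch) for t in level]
              for t, h in zip(level, batched_cell(lefts, rights)):   # one sync per level
                  summaries[t] = h
          return summaries

      # Tiny example: a root over two leaves, string concatenation as the "cell".
      a, b = Node(), Node()
      root = Node(a, b)
      out = bottom_up_stage({0: [root], 1: [a, b]},
                            lambda ls, rs: [f"({l}|{r})" for l, r in zip(ls, rs)])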

  27. Optimizing BiGG, part 02: Model parallelism & sublinear memory cost.

  28. Model parallelism

  29. Model parallelism: the computation is split across devices (figure: GPU 1 and GPU 2).

  30. Model parallelism: GPU 1 passes a boundary message to GPU 2 (figure: GPU 1 -> 2 message).

  31. Sublinear memory cost. Run 2x forward passes + 1x backward pass: intermediate states from pass 1 are recomputed in pass 2 instead of being stored, so the memory cost during training becomes sublinear in the graph size.
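  BiGG's recomputation scheme is specific to its tree states, but the "2x forward + 1x backward" idea is the generic gradient-checkpointing trade-off. A minimal PyTorch illustration (the two stand-in blocks and sizes are made up for the example, not BiGG's architecture):

      import torch
      from torch.utils.checkpoint import checkpoint

      # Two stand-in blocks; think of each as one chunk of BiGG's recursion.
      block1 = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU())
      block2 = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU())

      x = torch.randn(32, 64, requires_grad=True)

      # checkpoint() frees block1's intermediate activations after the forward
      # pass and recomputes them during backward: two forward passes, one
      # backward pass, in exchange for a lower peak memory footprint.
      h = checkpoint(block1, x, use_reentrant=False)
      loss = block2(h).sum()
      loss.backward()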

  32. Experiments

  33. Inference speed

  34. Training memory

  35. Training time. Main reason: the number of available GPU cores is limited.

  36. Sample quality on benchmark datasets

  37. Sample quality as graph size grows

  38. Summary.
     Advantages:
     ● Improves inference complexity to O(min{(m + n) log n, n^2})
     ● Enables parallelized training with sublinear memory cost
     ● Does not sacrifice sample quality
     Limitations:
     ● Limited by the parallelism of existing hardware
     ● Good capacity, but limited extrapolation ability

  39. Thank You Hanjun Dai Research Scientist hadai@google.com
