
Basics of Numerical Optimization: Computing Derivatives

Ju Sun
Computer Science & Engineering, University of Minnesota, Twin Cities
February 25, 2020

1 / 36

Derivatives for numerical optimization: gradient descent, Newton's ...


  1. Approximate the gradient

  Recall $f'(x) = \lim_{\delta \to 0} \frac{f(x + \delta) - f(x)}{\delta}$. For $f(x): \mathbb{R}^n \to \mathbb{R}$,
  $\frac{\partial f}{\partial x_i} \approx \frac{f(x + \delta e_i) - f(x)}{\delta}$  (forward)
  $\frac{\partial f}{\partial x_i} \approx \frac{f(x) - f(x - \delta e_i)}{\delta}$  (backward)
  $\frac{\partial f}{\partial x_i} \approx \frac{f(x + \delta e_i) - f(x - \delta e_i)}{2\delta}$  (central)
  (Credit: numex-blog.com)

  Similarly, to approximate the Jacobian for $f(x): \mathbb{R}^n \to \mathbb{R}^m$:
  $\frac{\partial f_j}{\partial x_i} \approx \frac{f_j(x + \delta e_i) - f_j(x)}{\delta}$  (one element each time)
  $\frac{\partial f}{\partial x_i} \approx \frac{f(x + \delta e_i) - f(x)}{\delta}$  (one column each time)
  $J p \approx \frac{f(x + \delta p) - f(x)}{\delta}$  (directional)
  Central schemes can also be derived.

  11 / 36
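  A minimal NumPy sketch of the forward and central schemes above (function and variable names are illustrative, not from the slides):

```python
import numpy as np

def grad_forward(f, x, delta=1e-6):
    """Forward-difference approximation of the gradient of f: R^n -> R."""
    fx = f(x)
    g = np.zeros_like(x, dtype=float)
    for i in range(x.size):
        e = np.zeros_like(x, dtype=float)
        e[i] = 1.0
        g[i] = (f(x + delta * e) - fx) / delta
    return g

def grad_central(f, x, delta=1e-6):
    """Central-difference approximation of the gradient of f: R^n -> R."""
    g = np.zeros_like(x, dtype=float)
    for i in range(x.size):
        e = np.zeros_like(x, dtype=float)
        e[i] = 1.0
        g[i] = (f(x + delta * e) - f(x - delta * e)) / (2 * delta)
    return g

# Example: f(x) = 0.5 * ||x||^2 has gradient x.
f = lambda x: 0.5 * np.dot(x, x)
x0 = np.array([1.0, -2.0, 3.0])
print(grad_forward(f, x0), grad_central(f, x0))
```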

  2. Why central? Stronger forms of Taylor's theorem

  – 1st order: If $f(x): \mathbb{R}^n \to \mathbb{R}$ is twice continuously differentiable,
    $f(x + \delta) = f(x) + \langle \nabla f(x), \delta \rangle + O(\|\delta\|^2)$
  – 2nd order: If $f(x): \mathbb{R}^n \to \mathbb{R}$ is three-times continuously differentiable,
    $f(x + \delta) = f(x) + \langle \nabla f(x), \delta \rangle + \tfrac{1}{2} \langle \delta, \nabla^2 f(x)\, \delta \rangle + O(\|\delta\|^3)$

  Why is the central scheme better?

  – Forward: by the 1st-order Taylor expansion,
    $\frac{1}{\delta}\big(f(x + \delta e_i) - f(x)\big) = \frac{1}{\delta}\Big(\delta \frac{\partial f}{\partial x_i} + O(\delta^2)\Big) = \frac{\partial f}{\partial x_i} + O(\delta)$
  – Central: by the 2nd-order Taylor expansion,
    $\frac{1}{2\delta}\big(f(x + \delta e_i) - f(x - \delta e_i)\big) = \frac{1}{2\delta}\Big(\delta \frac{\partial f}{\partial x_i} + \tfrac{1}{2}\delta^2 \frac{\partial^2 f}{\partial x_i^2} + \delta \frac{\partial f}{\partial x_i} - \tfrac{1}{2}\delta^2 \frac{\partial^2 f}{\partial x_i^2} + O(\delta^3)\Big) = \frac{\partial f}{\partial x_i} + O(\delta^2)$

  12 / 36
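  A quick numeric illustration of the two truncation orders (this example is not from the slides; it just checks the $O(\delta)$ vs. $O(\delta^2)$ claim):

```python
import numpy as np

# Compare forward vs. central differences for f(x) = exp(x) at x = 1,
# where the exact derivative is exp(1).
f, x0, true = np.exp, 1.0, np.exp(1.0)
for delta in [1e-1, 1e-2, 1e-3]:
    fwd = (f(x0 + delta) - f(x0)) / delta
    cen = (f(x0 + delta) - f(x0 - delta)) / (2 * delta)
    print(delta, abs(fwd - true), abs(cen - true))
# The forward error shrinks roughly like delta; the central error like delta**2.
```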

  3. Approximate the Hessian

  – Recall that for $f(x): \mathbb{R}^n \to \mathbb{R}$ that is 2nd-order differentiable, $\frac{\partial f}{\partial x_i}(x): \mathbb{R}^n \to \mathbb{R}$. So
    $\frac{\partial^2 f}{\partial x_j \partial x_i}(x) = \frac{\partial}{\partial x_j}\Big(\frac{\partial f}{\partial x_i}\Big)(x) \approx \frac{\frac{\partial f}{\partial x_i}(x + \delta e_j) - \frac{\partial f}{\partial x_i}(x)}{\delta}$
  – We can also compute one row of the Hessian each time by
    $\frac{\partial}{\partial x_j}\Big(\frac{\partial f}{\partial x}\Big)(x) \approx \frac{\frac{\partial f}{\partial x}(x + \delta e_j) - \frac{\partial f}{\partial x}(x)}{\delta}$,
    obtaining $\widetilde{H}$, which might not be symmetric. Return $\frac{1}{2}\big(\widetilde{H} + \widetilde{H}^\intercal\big)$ instead.
  – Most times (e.g., in TRM, Newton-CG), only $\nabla^2 f(x)\, v$ for certain $v$'s is needed (see, e.g., Manopt https://www.manopt.org/):
    $\nabla^2 f(x)\, v \approx \frac{\nabla f(x + \delta v) - \nabla f(x)}{\delta}$  (1)

  13 / 36
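  A minimal sketch of the Hessian-vector product approximation in Eq. (1), assuming a gradient routine grad_f is available (the quadratic test problem is made up for illustration):

```python
import numpy as np

def hessian_vector_product(grad_f, x, v, delta=1e-6):
    """Approximate H(x) v via (grad_f(x + delta*v) - grad_f(x)) / delta."""
    return (grad_f(x + delta * v) - grad_f(x)) / delta

# Example: f(x) = 0.5 x^T A x has gradient A x and Hessian A, so the result should be A v.
A = np.array([[2.0, 0.5], [0.5, 1.0]])
grad_f = lambda x: A @ x
x0 = np.array([1.0, -1.0])
v = np.array([0.3, 0.7])
print(hessian_vector_product(grad_f, x0, v), A @ v)
```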

  4. A few words

  – Can be used as a sanity check for the correctness of an analytic gradient
  – Finite-difference approximation of higher (i.e., ≥ 2)-order derivatives combined with high-order iterative methods can be very efficient (e.g., Manopt https://www.manopt.org/tutorial.html#costdescription)
  – Numerical stability can be an issue: truncation and round-off errors (finite $\delta$; accurate evaluation of the numerators)

  14 / 36
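  For the first point, a sketch of a finite-difference gradient check; the central scheme and the relative-error threshold are common choices here, not prescribed by the slides:

```python
import numpy as np

def gradient_check(f, grad_f, x, delta=1e-6, tol=1e-5):
    """Compare an analytic gradient against a central-difference approximation."""
    g_analytic = grad_f(x)
    g_numeric = np.zeros_like(x, dtype=float)
    for i in range(x.size):
        e = np.zeros_like(x, dtype=float)
        e[i] = 1.0
        g_numeric[i] = (f(x + delta * e) - f(x - delta * e)) / (2 * delta)
    rel_err = np.linalg.norm(g_analytic - g_numeric) / max(np.linalg.norm(g_numeric), 1e-12)
    return rel_err < tol, rel_err

# Example: f(x) = sum(x^3), whose analytic gradient is 3 x^2.
f = lambda x: np.sum(x ** 3)
grad_f = lambda x: 3 * x ** 2
print(gradient_check(f, grad_f, np.array([0.5, -1.5, 2.0])))
```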

  5. Outline

  Analytic differentiation; Finite-difference approximation; Automatic differentiation; Differentiable programming; Suggested reading

  15 / 36

  6. Four kinds of computing techniques

  (Figure omitted.) Credit: [Baydin et al., 2017]

  Misnomer: it should be automatic numerical differentiation

  16 / 36

  7. Forward mode in 1D

  Consider a univariate function $f_k \circ f_{k-1} \circ \cdots \circ f_2 \circ f_1(x): \mathbb{R} \to \mathbb{R}$. Write $y_0 = x$, $y_1 = f_1(x)$, $y_2 = f_2(y_1)$, ..., $y_k = f_k(y_{k-1})$, or in computational graph form: $y_0 \to y_1 \to \cdots \to y_k$.

  Chain rule: $\frac{df}{dx} = \frac{dy_k}{dy_0} = \frac{dy_k}{dy_{k-1}} \left( \frac{dy_{k-1}}{dy_{k-2}} \left( \cdots \left( \frac{dy_2}{dy_1} \left( \frac{dy_1}{dy_0} \right) \right) \right) \right)$

  Compute $\frac{df}{dx}\big|_{x_0}$ in one pass, from the innermost to the outermost parenthesis:

  Input: $x_0$; initialize $\frac{dy_0}{dy_0}\big|_{x_0} = 1$
  for $i = 1, \ldots, k$ do
      compute $y_i = f_i(y_{i-1})$
      compute $\frac{dy_i}{dy_0}\big|_{x_0} = \frac{dy_i}{dy_{i-1}}\big|_{y_{i-1}} \cdot \frac{dy_{i-1}}{dy_0}\big|_{x_0} = f_i'(y_{i-1}) \cdot \frac{dy_{i-1}}{dy_0}\big|_{x_0}$
  end for
  Output: $\frac{dy_k}{dy_0}\big|_{x_0}$

  17 / 36
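  A minimal Python sketch of this one-pass algorithm for a chain of univariate functions; each stage supplies its value and its local derivative (the stage list and test function are illustrative):

```python
import math

def forward_mode_1d(stages, x0):
    """Each stage is a pair (f_i, f_i_prime). Returns (y_k, dy_k/dy_0 at x0)."""
    y = x0
    dy_dx = 1.0                      # dy_0 / dy_0 = 1
    for f, fprime in stages:
        dy_dx = fprime(y) * dy_dx    # chain rule, innermost to outermost
        y = f(y)
    return y, dy_dx

# Example: f(x) = exp(sin(x^2)); true derivative is exp(sin(x^2)) * cos(x^2) * 2x.
stages = [
    (lambda t: t * t, lambda t: 2 * t),
    (math.sin, math.cos),
    (math.exp, math.exp),
]
print(forward_mode_1d(stages, 1.3))
```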

  8. Reverse mode in 1D

  Consider a univariate function $f_k \circ f_{k-1} \circ \cdots \circ f_2 \circ f_1(x): \mathbb{R} \to \mathbb{R}$. Write $y_0 = x$, $y_1 = f_1(x)$, $y_2 = f_2(y_1)$, ..., $y_k = f_k(y_{k-1})$, or in computational graph form: $y_0 \to y_1 \to \cdots \to y_k$.

  Chain rule: $\frac{df}{dx} = \frac{dy_k}{dy_0} = \left( \left( \cdots \left( \left( \frac{dy_k}{dy_{k-1}} \right) \frac{dy_{k-1}}{dy_{k-2}} \right) \cdots \right) \frac{dy_2}{dy_1} \right) \frac{dy_1}{dy_0}$

  Compute $\frac{df}{dx}\big|_{x_0}$ in two passes, from the innermost to the outermost parenthesis for the 2nd:

  Input: $x_0$; initialize $\frac{dy_k}{dy_k} = 1$
  for $i = 1, \ldots, k$ do
      compute $y_i = f_i(y_{i-1})$
  end for  // forward pass
  for $i = k-1, k-2, \ldots, 0$ do
      compute $\frac{dy_k}{dy_i}\big|_{y_i} = \frac{dy_k}{dy_{i+1}}\big|_{y_{i+1}} \cdot \frac{dy_{i+1}}{dy_i}\big|_{y_i} = \frac{dy_k}{dy_{i+1}}\big|_{y_{i+1}} \cdot f_{i+1}'(y_i)$
  end for  // backward pass
  Output: $\frac{dy_k}{dy_0}\big|_{x_0}$

  18 / 36
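  A matching two-pass sketch, in the same illustrative stage format as the forward-mode code above:

```python
import math

def reverse_mode_1d(stages, x0):
    """Each stage is a pair (f_i, f_i_prime). Returns (y_k, dy_k/dy_0 at x0)."""
    ys = [x0]
    for f, _ in stages:                         # forward pass: record intermediate values
        ys.append(f(ys[-1]))
    bar = 1.0                                   # dy_k / dy_k = 1
    for i in range(len(stages) - 1, -1, -1):    # backward pass
        _, fprime = stages[i]
        bar = bar * fprime(ys[i])               # now holds dy_k / dy_i
    return ys[-1], bar

stages = [
    (lambda t: t * t, lambda t: 2 * t),
    (math.sin, math.cos),
    (math.exp, math.exp),
]
print(reverse_mode_1d(stages, 1.3))   # should match the forward-mode result
```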

  9. Forward vs reverse modes

  – forward mode AD: one forward pass, computing the intermediate variable values and derivative values together
  – reverse mode AD: one forward pass to compute the intermediate variable values, one backward pass to compute the intermediate derivatives

  Effectively, these are two different ways of grouping the multiplicative differential terms:

  $\frac{df}{dx} = \frac{dy_k}{dy_0} = \frac{dy_k}{dy_{k-1}} \left( \frac{dy_{k-1}}{dy_{k-2}} \left( \cdots \left( \frac{dy_2}{dy_1} \left( \frac{dy_1}{dy_0} \right) \right) \right) \right)$, i.e., starting from the root: $\frac{dy_0}{dy_0} \mapsto \frac{dy_1}{dy_0} \mapsto \frac{dy_2}{dy_0} \mapsto \cdots \mapsto \frac{dy_k}{dy_0}$

  $\frac{df}{dx} = \frac{dy_k}{dy_0} = \left( \left( \cdots \left( \frac{dy_k}{dy_{k-1}} \right) \frac{dy_{k-1}}{dy_{k-2}} \cdots \right) \frac{dy_2}{dy_1} \right) \frac{dy_1}{dy_0}$, i.e., starting from the leaf: $\frac{dy_k}{dy_k} \mapsto \frac{dy_k}{dy_{k-1}} \mapsto \frac{dy_k}{dy_{k-2}} \mapsto \cdots \mapsto \frac{dy_k}{dy_0}$

  ...mixed forward and reverse modes are indeed possible!

  19 / 36

  10. Chain rule in computational graphs

  Chain rule: Let $f: \mathbb{R}^n \to \mathbb{R}^m$ and $h: \mathbb{R}^m \to \mathbb{R}^k$, where $f$ is differentiable at $x$ with $y = f(x)$ and $h$ is differentiable at $y$. Then $h \circ f: \mathbb{R}^n \to \mathbb{R}^k$ is differentiable at $x$, and (writing $z = h(y)$)
  $J_{h \circ f}(x) = J_h(f(x))\, J_f(x)$, or $\frac{\partial z_j}{\partial x_i} = \sum_{\ell=1}^{m} \frac{\partial z_j}{\partial y_\ell}\, \frac{\partial y_\ell}{\partial x_i} \quad \forall\, i, j$

  – Each node is a variable, as a function of all incoming variables
  – If node B is a descendant of node A, $\frac{\partial B}{\partial A}$ is the rate of change in B wrt change in A
  – Traveling along a path, rates of change are multiplied
  – Chain rule: sum up the rates over all connecting paths! (e.g., from $x_2$ to $z_j$ as shown)

  NB: this is a computational graph, not a NN

  20 / 36

  11. A multivariate example — forward mode

  $y = \left( \sin\frac{x_1}{x_2} + \frac{x_1}{x_2} - e^{x_2} \right) \left( \frac{x_1}{x_2} - e^{x_2} \right)$

  – interested in $\frac{\partial}{\partial x_1}$; for each variable $v_i$, write $\dot{v}_i \doteq \frac{\partial v_i}{\partial x_1}$
  – for each node, sum up partials over all incoming edges, e.g., $\dot{v}_4 = \frac{\partial v_4}{\partial v_1}\dot{v}_1 + \frac{\partial v_4}{\partial v_3}\dot{v}_3$
  – complexity: $O(\#\mathrm{edges} + \#\mathrm{nodes})$
  – for $f: \mathbb{R}^n \to \mathbb{R}^m$, make $n$ forward passes: $O(n(\#\mathrm{edges} + \#\mathrm{nodes}))$

  21 / 36

  12. A multivariate example — reverse mode

  – interested in $\frac{\partial y}{\partial \,\cdot\,}$; for each variable $v_i$, write $\bar{v}_i \doteq \frac{\partial y}{\partial v_i}$ (called the adjoint variable)
  – for each node, sum up partials over all outgoing edges, e.g., $\bar{v}_4 = \frac{\partial v_5}{\partial v_4}\bar{v}_5 + \frac{\partial v_6}{\partial v_4}\bar{v}_6$
  – complexity: $O(\#\mathrm{edges} + \#\mathrm{nodes})$
  – for $f: \mathbb{R}^n \to \mathbb{R}^m$, make $m$ backward passes: $O(m(\#\mathrm{edges} + \#\mathrm{nodes}))$

  Example from Ch. 1 of [Griewank and Walther, 2008]

  22 / 36

  13. Forward vs. reverse modes

  For a general function $f: \mathbb{R}^n \to \mathbb{R}^m$, suppose there is no loop in the computational graph, i.e., it is an acyclic graph. Define $E$: set of edges; $V$: set of nodes.

  |             | forward mode                                                                | reverse mode                                                                  |
  | start from  | roots                                                                       | leaves                                                                        |
  | end with    | leaves                                                                      | roots                                                                         |
  | invariants  | $\dot{v}_i \doteq \frac{\partial v_i}{\partial x}$ ($x$: root of interest)  | $\bar{v}_i \doteq \frac{\partial y}{\partial v_i}$ ($y$: leaf of interest)    |
  | rule        | sum over incoming edges                                                     | sum over outgoing edges                                                       |
  | complexity  | $O(n|E| + n|V|)$                                                            | $O(m|E| + m|V|)$                                                              |
  | better when | $m \gg n$                                                                   | $n \gg m$                                                                     |

  23 / 36

  14. Directional derivatives

  Consider $f(x): \mathbb{R}^n \to \mathbb{R}^m$. Let the $v_s$'s be the variables in its computational graph; in particular, $v_{n-1} = x_1, v_{n-2} = x_2, \ldots, v_0 = x_n$. $D_p(\cdot)$ means the directional derivative wrt $p$. In practical implementations,

  forward mode: compute $J_f\, p$, i.e., a Jacobian-vector product
  – Why? (1) Columns of $J_f$ can be obtained by setting $p = e_1, \ldots, e_n$. (2) When $J_f$ has special structures (e.g., sparsity), judicious choices of the $p$'s save computation. (3) The problem may only need $J_f\, p$ for a specific $p$, not $J_f$ itself, saving computation again
  – How? (1) Initialize $D_p v_{n-1} = p_1, \ldots, D_p v_0 = p_n$. (2) Apply the chain rule:
    $\nabla_x v_i = \sum_{j:\,\text{incoming}} \frac{\partial v_i}{\partial v_j} \nabla_x v_j \;\Longrightarrow\; D_p v_i = \sum_{j:\,\text{incoming}} \frac{\partial v_i}{\partial v_j} D_p v_j$

  reverse mode: compute $J_f^\intercal q = \nabla_x (f^\intercal q)$, i.e., a Jacobian-transpose-vector product
  – Why? Similar to the above
  – How? Track $\frac{d}{dv_i}(f^\intercal q)$: $\frac{d}{dv_i}(f^\intercal q) = \sum_{k:\,\text{outgoing}} \frac{\partial v_k}{\partial v_i}\, \frac{d}{dv_k}(f^\intercal q)$

  24 / 36
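  In Jax these two products correspond directly to jax.jvp and jax.vjp; a minimal sketch (the function g below is made up for illustration):

```python
import jax
import jax.numpy as jnp

# An illustrative function g: R^3 -> R^2.
def g(x):
    return jnp.stack([jnp.sin(x[0]) * x[1], x[1] * jnp.exp(x[2])])

x = jnp.array([0.5, 1.0, -0.2])
p = jnp.array([1.0, 0.0, 0.0])   # direction for the Jacobian-vector product
q = jnp.array([1.0, 2.0])        # weights for the Jacobian-transpose-vector product

# Forward mode: J_g(x) p
y, jp = jax.jvp(g, (x,), (p,))

# Reverse mode: J_g(x)^T q = grad_x <g(x), q>
y2, vjp_fn = jax.vjp(g, x)
(jtq,) = vjp_fn(q)

print(jp, jtq)
```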

  15. Tensor abstraction

  Tensors: multi-dimensional arrays. Each node in the computational graph can be a tensor (scalar, vector, matrix, 3-D tensor, ...), e.g., the computational graph for a DNN with loss
  $f(W) = \| Y - \sigma(W_k \sigma(W_{k-1} \sigma(\cdots(W_1 X)))) \|_F^2$

  25 / 36

  16. Tensor abstraction

  – Abstract out low-level details; operations are often simple (e.g., $*$, $\sigma$), so the partials are simple
  – Tensor (i.e., vector) chain rules apply, often via tensor-free computation
  – Basis of implementation for Tensorflow, Pytorch, Jax, etc. Jax: https://github.com/google/jax

  Good to know:
  – In practice, graphs are built automatically by software
  – Higher-order derivatives can also be done, particularly the Hessian-vector product $\nabla^2 f(x)\, v$ (check out Jax!)
  – Auto-diff in Tensorflow and Pytorch is specialized to DNNs and focuses on 1st order; Jax (in Python) is full-fledged and also supports GPU
  – General resources for autodiff: http://www.autodiff.org/, [Griewank and Walther, 2008]

  26 / 36
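  For the Hessian-vector-product point, a minimal Jax sketch using forward-over-reverse (a jvp of grad); the test function is made up for illustration:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative smooth function R^n -> R.
    return jnp.sum(jnp.sin(x) ** 2) + 0.5 * jnp.dot(x, x)

def hvp(f, x, v):
    """Hessian-vector product via d/dt grad f(x + t v) at t = 0."""
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

x = jnp.array([0.1, -0.3, 0.7])
v = jnp.array([1.0, 0.0, -1.0])
print(hvp(f, x, v))
# Sanity check against the full Hessian (fine for small n):
print(jax.hessian(f)(x) @ v)
```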

  17. Autodiff in Pytorch

  Solve least squares $f(x) = \frac{1}{2} \| y - Ax \|_2^2$ with $\nabla f(x) = -A^\intercal (y - Ax)$. (Plot of loss vs. iterate omitted.)

  27 / 36
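  The transcript does not include the slide's code, so here is a hedged PyTorch sketch of the same least-squares setup (random data and a hand-picked step size), with autograd supplying $\nabla f(x) = -A^\intercal(y - Ax)$:

```python
import torch

torch.manual_seed(0)
m, n = 50, 10
A = torch.randn(m, n)
y = torch.randn(m)

x = torch.zeros(n, requires_grad=True)
step = 1e-2

for it in range(200):
    loss = 0.5 * torch.sum((y - A @ x) ** 2)   # f(x) = 0.5 ||y - Ax||^2
    loss.backward()                            # x.grad now holds -A^T (y - Ax)
    with torch.no_grad():
        x -= step * x.grad                     # gradient descent step
    x.grad.zero_()
print(loss.item())
```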

  18. Autodiff in Pytorch

  Train a shallow neural network
  $f(W) = \sum_i \| y_i - W_2\, \sigma(W_1 x_i) \|_2^2$, where $\sigma(z) = \max(z, 0)$, i.e., ReLU.

  https://pytorch.org/tutorials/beginner/pytorch_with_examples.html
  – torch.mm
  – torch.clamp
  – torch.no_grad()

  Back propagation is reverse-mode auto-differentiation!

  28 / 36
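  A minimal sketch in the spirit of the linked tutorial, using the three calls named above and the tutorial's row-vector convention (dimensions and data are illustrative):

```python
import torch

torch.manual_seed(0)
N, D_in, H, D_out = 64, 100, 50, 10
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

W1 = torch.randn(D_in, H, requires_grad=True)
W2 = torch.randn(H, D_out, requires_grad=True)
lr = 1e-6

for it in range(500):
    # Forward: sigma(x W1) W2, with sigma = ReLU implemented via clamp
    y_pred = torch.mm(torch.clamp(torch.mm(x, W1), min=0), W2)
    loss = torch.sum((y - y_pred) ** 2)
    loss.backward()                      # reverse-mode AD fills W1.grad, W2.grad
    with torch.no_grad():                # update weights outside the graph
        W1 -= lr * W1.grad
        W2 -= lr * W2.grad
        W1.grad.zero_()
        W2.grad.zero_()
print(loss.item())
```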

  19. Outline

  Analytic differentiation; Finite-difference approximation; Automatic differentiation; Differentiable programming; Suggested reading

  29 / 36

  20. Example: image enhancement

  – Each stage applies a parameterized function to the image, i.e., $q_{w_k} \circ \cdots \circ h_{w_3} \circ g_{w_2} \circ f_{w_1}(X)$ ($X$ is the camera raw)
  – The parameterized functions may or may not be DNNs
  – Each function may be analytic, or simply a chunk of code dependent on the parameters
  – The $w_i$'s are the trainable parameters

  Credit: https://people.csail.mit.edu/tzumao/gradient_halide/

  30 / 36

  21. Example: image enhancement

  – The trainable parameters are learned by gradient descent based on auto-differentiation
  – This generalizes training DNNs with the classic feedforward structure to training general parameterized functions, using derivative-based methods

  Credit: https://people.csail.mit.edu/tzumao/gradient_halide/

  31 / 36
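  A hedged PyTorch sketch of this idea: two made-up differentiable stages (a per-channel gain and a gamma curve, chosen only for illustration; they are not the stages of the cited pipeline) composed and trained end-to-end by gradient descent through autodiff.

```python
import torch

torch.manual_seed(0)
X = torch.rand(3, 64, 64)          # stand-in "camera raw" image
target = X.clamp(0, 1) ** 0.5      # stand-in enhancement target

# Trainable parameters of two simple stages (illustrative, not the real pipeline).
gain = torch.ones(3, 1, 1, requires_grad=True)     # per-channel gain
gamma = torch.tensor(1.0, requires_grad=True)      # gamma-curve exponent

opt = torch.optim.Adam([gain, gamma], lr=1e-2)
for it in range(300):
    out = (gain * X).clamp(min=1e-6) ** gamma      # stage 2 of stage 1 of X
    loss = torch.mean((out - target) ** 2)
    opt.zero_grad()
    loss.backward()      # autodiff through the whole pipeline
    opt.step()
print(loss.item(), gamma.item())
```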

  22. Example: control a trebuchet

  https://fluxml.ai/2019/03/05/dp-vs-rl.html

  32 / 36
