CS 533: Natural Language Processing Backpropagation, Self-Attention, Text Representations Through Language Modeling Karl Stratos Rutgers University Karl Stratos CS 533: Natural Language Processing 1/52
Dropout (Slide credit Danqi Chen & Karthik Narasimhan) Karl Stratos CS 533: Natural Language Processing 2/52
Unidirectional vs Bidirectional RNN [Diagram contrasting a unidirectional RNN with a bidirectional RNN over inputs x and hidden states h.] Karl Stratos CS 533: Natural Language Processing 3/52
Agenda 1. Backpropagation 2. Self-attention in NLP 3. Representation learning through language modeling Karl Stratos CS 533: Natural Language Processing 4/52
Backpropagation: Input and Output ◮ A technique to automatically calculate ∇J(θ) for any definition of scalar-valued loss function J(θ) ∈ R. Input: loss function J(θ) ∈ R, parameter value θ̂. Output: ∇J(θ̂), the gradient of J(θ) at θ = θ̂. ◮ Calculates the gradient of an arbitrary differentiable function of the parameter θ, including neural networks. Karl Stratos CS 533: Natural Language Processing 5/52
Notation ◮ For the most part, we will consider a (differentiable) function f : R → R with a single 1-dimensional parameter x ∈ R. ◮ The gradient of f with respect to x is a function of x: ∂f(x)/∂x : R → R. ◮ The gradient of f with respect to x evaluated at x = a is written as ∂f(x)/∂x |_{x=a} ∈ R. Karl Stratos CS 533: Natural Language Processing 6/52
Chain Rule ◮ Given any differentiable functions f, g from R to R, ∂g(f(x))/∂x = ∂g(f(x))/∂f(x) × ∂f(x)/∂x, where the local derivative ∂g(f(x))/∂f(x) is easy to calculate. ◮ “Proof”: linearization of g(z) around f(a), composed with the linearization of f(x) around a: g(f(x)) ≈ g(f(a)) + g′(f(a)) f′(a) (x − a), so the slope g′(f(a)) f′(a) is ∂g(f(x))/∂x |_{x=a}. Karl Stratos CS 533: Natural Language Processing 7/52
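As a quick sanity check (not part of the slides), the chain-rule product can be compared against a finite-difference estimate; the functions g, f and the point a below are arbitrary choices.

```python
import math

# Chain rule check for g(f(x)) with g(z) = sin(z), f(x) = x**2 (arbitrary choices).
def f(x):  return x * x
def df(x): return 2 * x                      # ∂f(x)/∂x
def g(z):  return math.sin(z)
def dg(z): return math.cos(z)                # ∂g(z)/∂z, the easy local derivative

a = 1.3
chain = dg(f(a)) * df(a)                     # ∂g(f(x))/∂x evaluated at x = a
h = 1e-6
numeric = (g(f(a + h)) - g(f(a - h))) / (2 * h)
print(chain, numeric)                        # the two values should agree closely
```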
Exercises At x = 42, ◮ What is the value of the gradient of f(x) := 7? ◮ What is the value of the gradient of f(x) := 2x? ◮ What is the value of the gradient of f(x) := 2x + 99999? ◮ What is the value of the gradient of f(x) := x³? ◮ What is the value of the gradient of f(x) := exp(x)? ◮ What is the value of the gradient of f(x) := exp(2x³ + 10)? ◮ What is the value of the gradient of f(x) := log(exp(2x³ + 10))? Karl Stratos CS 533: Natural Language Processing 8/52
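One possible way to check answers to these exercises (not from the slides) is to compare the analytic derivative against a central finite difference; exp(2x³ + 10) itself is skipped because it overflows double precision at x = 42.

```python
import math

a = 42.0
h = 1e-5

# (analytic derivative at x = 42, function) pairs for the exercises that fit in floats
cases = [
    (0.0,          lambda x: 7.0),              # constant -> 0
    (2.0,          lambda x: 2 * x),            # 2x -> 2
    (2.0,          lambda x: 2 * x + 99999),    # the additive shift does not matter
    (3 * a ** 2,   lambda x: x ** 3),           # x^3 -> 3x^2
    (math.exp(a),  lambda x: math.exp(x)),      # exp(x) -> exp(x)
    (6 * a ** 2,   lambda x: 2 * x ** 3 + 10),  # log(exp(2x^3+10)) = 2x^3+10 -> 6x^2
]
for analytic, f in cases:
    numeric = (f(a + h) - f(a - h)) / (2 * h)
    print(analytic, numeric)
# exp(2x^3 + 10) has gradient 6x^2 * exp(2x^3 + 10), which overflows at x = 42.
```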
Chain Rule for a Function of Multiple Input Variables ◮ Let f_1 . . . f_m denote any differentiable functions from R to R. ◮ If g : R^m → R is a differentiable function from R^m to R, ∂g(f_1(x), . . . , f_m(x))/∂x = ∑_{i=1}^m ∂g(f_1(x), . . . , f_m(x))/∂f_i(x) × ∂f_i(x)/∂x, where each factor ∂g(f_1(x), . . . , f_m(x))/∂f_i(x) is easy to calculate. ◮ Exercise: calculate the gradient of x + x² + yx with respect to x using the chain rule (a check in code follows below). Karl Stratos CS 533: Natural Language Processing 9/52
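A quick check of this exercise (not from the slides): view x + x² + yx as g(f_1(x), f_2(x), f_3(x)) with g summing its arguments, f_1(x) = x, f_2(x) = x², f_3(x) = yx, so the gradient is 1 + 2x + y. The evaluation point below is an arbitrary choice.

```python
# Multivariable chain rule check for h(x) = x + x**2 + y*x (y held fixed).
y = 2.0
def h(x): return x + x ** 2 + y * x

x = 1.5
analytic = 1 + 2 * x + y          # sum over i of dg/df_i * df_i/dx = 1 + 2x + y
eps = 1e-6
numeric = (h(x + eps) - h(x - eps)) / (2 * eps)
print(analytic, numeric)          # both approximately 6.0
```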
DAG A directed acyclic graph (DAG) is a directed graph G = (V, A) with a topological ordering: a sequence π of V such that for every arc (i, j) ∈ A, i comes before j in π. [Example DAG over nodes 1–6.] For backpropagation: we usually assume many roots and a single leaf. Karl Stratos CS 533: Natural Language Processing 10/52
Notation [The same example DAG over nodes 1–6.] V = {1, 2, 3, 4, 5, 6}, V_I = {1, 2}, V_N = {3, 4, 5, 6}, A = {(1, 3), (1, 5), (2, 4), (3, 4), (4, 6), (5, 6)}, pa(4) = {2, 3}, ch(1) = {3, 5}, Π_G = {(1, 2, 3, 4, 5, 6), (2, 1, 3, 4, 5, 6)} Karl Stratos CS 533: Natural Language Processing 11/52
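One way to obtain a topological ordering such as those in Π_G is Kahn's algorithm; the sketch below is an illustration (not part of the slides) applied to this example DAG.

```python
import heapq

# Kahn's algorithm with a min-heap: repeatedly emit a node whose parents have
# all been emitted, preferring the smallest node id.
V = [1, 2, 3, 4, 5, 6]
A = [(1, 3), (1, 5), (2, 4), (3, 4), (4, 6), (5, 6)]

indegree = {i: 0 for i in V}
children = {i: [] for i in V}
for i, j in A:
    indegree[j] += 1
    children[i].append(j)

ready = [i for i in V if indegree[i] == 0]   # the roots (input nodes 1 and 2)
heapq.heapify(ready)
pi = []
while ready:
    i = heapq.heappop(ready)
    pi.append(i)
    for j in children[i]:
        indegree[j] -= 1
        if indegree[j] == 0:
            heapq.heappush(ready, j)

print(pi)   # [1, 2, 3, 4, 5, 6], one of the orderings in Pi_G
```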
Computation Graph ◮ A DAG G = (V, A) with a single output node ω ∈ V. ◮ Every node i ∈ V is equipped with a value x_i ∈ R: 1. For an input node i ∈ V_I, we assume x_i = a_i is given. 2. For a non-input node i ∈ V_N, we assume a differentiable function f_i : R^|pa(i)| → R and compute x_i = f_i((x_j)_{j ∈ pa(i)}). ◮ Thus G represents a function: it receives multiple values x_i = a_i for i ∈ V_I and calculates a scalar x_ω ∈ R. ◮ We can calculate x_ω by a forward pass. Karl Stratos CS 533: Natural Language Processing 12/52
Forward Pass Input: computation graph G = (V, A) with output node ω ∈ V Result: populates x_i = a_i for every i ∈ V 1. Pick some topological ordering π of V. 2. For i in order of π, if i ∈ V_N is a non-input node, set x_i ← a_i := f_i((a_j)_{j ∈ pa(i)}). Why do we need a topological ordering? (A sketch in code follows below.) Karl Stratos CS 533: Natural Language Processing 13/52
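Below is a minimal sketch of this procedure; the dictionary-based graph representation and the example node functions are assumptions for illustration, not from the slides. The topological ordering guarantees that a node's parents are already populated when the node is reached.

```python
def forward_pass(parents, functions, input_values, order):
    """Populate the value slot x_i = a_i of every node.

    parents[i]      : tuple of parent node ids of node i (empty for input nodes)
    functions[i]    : f_i, taking the parent values, for each non-input node
    input_values[i] : the given value a_i for each input node i
    order           : a topological ordering of all nodes
    """
    a = dict(input_values)
    for i in order:
        if parents[i]:                              # non-input node
            a[i] = functions[i](*(a[j] for j in parents[i]))
    return a                                        # a[omega] is the output value

# Example usage on the 6-node DAG from the Notation slide, with arbitrary
# choices for the node functions (an assumption for illustration):
parents   = {1: (), 2: (), 3: (1,), 4: (2, 3), 5: (1,), 6: (4, 5)}
functions = {3: lambda x1: x1 ** 2,
             4: lambda x2, x3: x2 + x3,
             5: lambda x1: 3 * x1,
             6: lambda x4, x5: x4 * x5}
values = forward_pass(parents, functions, {1: 2.0, 2: 1.0}, [1, 2, 3, 4, 5, 6])
print(values[6])   # (x2 + x1**2) * (3*x1) = 5 * 6 = 30.0
```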
Exercise Construct the computation graph associated with the function f(x, y) := (x + y)xy². Compute its output value at x = 1 and y = 2 by performing a forward pass. Karl Stratos CS 533: Natural Language Processing 14/52
For Notational Convenience . . . ◮ Collectively refer to all input slots by x_I = (x_i)_{i ∈ V_I}. ◮ Collectively refer to all input values by a_I = (a_i)_{i ∈ V_I}. ◮ At i ∈ V: refer to its parental slots by x_I^i = (x_j)_{j ∈ pa(i)} and to its parental values by a_I^i = (a_j)_{j ∈ pa(i)}. Two equally valid ways of viewing any a_i ∈ R as a function: ◮ A “global” function of x_I evaluated at a_I. ◮ A “local” function of x_I^i evaluated at a_I^i. Karl Stratos CS 533: Natural Language Processing 15/52
Computation Graph: Gradients ◮ Now for every node i ∈ V, we introduce an additional slot z_i ∈ R defined as z_i := ∂x_ω/∂x_i |_{x_I = a_I}. ◮ The goal of backpropagation is to calculate z_i for every i ∈ V. ◮ Why are we done if we achieve this goal? Karl Stratos CS 533: Natural Language Processing 16/52
Key Ideas of Backpropagation ◮ Chain rule on the DAG structure: z_i := ∂x_ω/∂x_i |_{x_I = a_I} = ∑_{j ∈ ch(i)} ∂x_ω/∂x_j |_{x_I = a_I} × ∂x_j/∂x_i |_{x_I^j = a_I^j} = ∑_{j ∈ ch(i)} z_j × ∂f_j(x_I^j)/∂x_i |_{x_I^j = a_I^j}, where the local derivative ∂f_j(x_I^j)/∂x_i is easy to calculate. ◮ If we compute z_i in a reverse topological ordering, then we will have already computed z_j for all j ∈ ch(i). ◮ What’s the base case z_ω? Karl Stratos CS 533: Natural Language Processing 17/52
Backpropagation Input: computation graph G = (V, A) with output node ω ∈ V whose value slots x_i = a_i are already populated for every i ∈ V Result: populates z_i for every i ∈ V 1. Set z_ω ← 1. 2. Pick some topological ordering π of V. 3. For i in reverse order of π, set z_i ← ∑_{j ∈ ch(i)} z_j × ∂f_j(x_I^j)/∂x_i |_{x_I^j = a_I^j}. (A sketch in code follows below.) Karl Stratos CS 533: Natural Language Processing 18/52
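Below is a minimal sketch of this procedure, continuing the dictionary-based representation assumed in the forward-pass sketch; partials[j][i], holding the local derivative ∂f_j/∂x_i, is likewise an assumed representation rather than anything from the slides.

```python
def backward_pass(parents, partials, values, order, omega):
    """Populate the gradient slot z_i of every node.

    partials[j][i] : callable giving the local derivative of f_j with respect
                     to parent i, evaluated at the parent values of j
    values         : node values a_i produced by the forward pass
    omega          : the output node
    """
    children = {i: [] for i in parents}
    for j, ps in parents.items():
        for i in ps:
            children[i].append(j)

    z = {i: 0.0 for i in parents}
    z[omega] = 1.0                              # base case: dx_omega/dx_omega = 1
    for i in reversed(order):                   # reverse topological ordering
        for j in children[i]:                   # z_j is already final here
            parent_vals = [values[p] for p in parents[j]]
            z[i] += z[j] * partials[j][i](*parent_vals)
    return z
```

Combined with the forward-pass sketch, the entries of z at the input nodes give the gradient of x_ω with respect to the inputs, which is exactly what gradient-based training needs.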
Exercise Calculate the gradient of f(x, y) := (x + y)xy² with respect to x at x = 1 and y = 2 by performing backpropagation. That is, calculate the scalar ∂f(x, y)/∂x |_{(x,y)=(1,2)}. Karl Stratos CS 533: Natural Language Processing 19/52
Answer [Computation graph for f(x, y) = (x + y)xy² annotated with the forward values (output 12) and the gradient slots from backpropagation: ∂f/∂x = 16 and ∂f/∂y = 16 at (x, y) = (1, 2).] Karl Stratos CS 533: Natural Language Processing 20/52
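A hand trace of this answer in code, assuming one particular decomposition of f into elementary operations (the slide's graph may group the operations differently):

```python
# Forward pass for f(x, y) = (x + y) * x * y**2 at (1, 2),
# decomposed (as an assumption) into s = x + y, t = x*y, u = t*y, f = s*u.
x, y = 1.0, 2.0
s = x + y                            # 3
t = x * y                            # 2
u = t * y                            # 4  (= x * y**2)
f = s * u                            # 12

# Backward pass: z_i = df/dx_i, filled in reverse topological order.
z_f = 1.0                            # base case
z_s = z_f * u                        # df/ds = u            -> 4
z_u = z_f * s                        # df/du = s            -> 3
z_t = z_u * y                        # du/dt = y            -> 6
z_x = z_s * 1 + z_t * y              # via s and t          -> 4 + 12 = 16
z_y = z_s * 1 + z_t * x + z_u * t    # via s, t, and u      -> 4 + 6 + 6 = 16
print(z_x, z_y)                      # 16.0 16.0
```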
Implementation ◮ Each type of function f creates a child node from parent nodes and initializes its gradient to zero. The “Add” function, for example, creates a child node c with two parents (a, b) and sets c.z ← 0. ◮ Each node has an associated forward function. Calling forward at c populates c.x = a.x + b.x (assumes parents have their values). ◮ Each node also has an associated backward function. Calling backward at c “broadcasts” its gradient c.z (assumes it’s already calculated) to its parents: a.z ← a.z + c.z and b.z ← b.z + c.z. (A sketch in code follows below.) Karl Stratos CS 533: Natural Language Processing 21/52
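A minimal sketch of this interface; the class and attribute names (Node, Add, Mul, .x, .z) are assumptions for illustration rather than any particular library's API.

```python
class Node:
    def __init__(self, parents=()):
        self.parents = list(parents)
        self.x = None        # value slot, filled by forward()
        self.z = 0.0         # gradient slot, initialized to zero

    def forward(self):       # input nodes: the value is set directly, nothing to do
        pass

    def backward(self):      # input nodes: nothing to broadcast
        pass


class Add(Node):
    def forward(self):
        a, b = self.parents
        self.x = a.x + b.x

    def backward(self):      # broadcast c.z to both parents (local derivatives are 1)
        a, b = self.parents
        a.z += self.z
        b.z += self.z


class Mul(Node):
    def forward(self):
        a, b = self.parents
        self.x = a.x * b.x

    def backward(self):      # d(ab)/da = b and d(ab)/db = a
        a, b = self.parents
        a.z += self.z * b.x
        b.z += self.z * a.x
```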
Implementation (Cont.) ◮ Express your loss J_B(θ) on minibatch B at θ = θ̂ as a computation graph. ◮ Forward pass: for each node a in a topological ordering, call a.forward(). ◮ Backward pass: for each node a in a reverse topological ordering, call a.backward(). ◮ The gradient of J_B(θ) at θ = θ̂ is then stored in the input nodes of the computation graph. (An end-to-end sketch follows below.) Karl Stratos CS 533: Natural Language Processing 22/52
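An end-to-end sketch on the earlier exercise, using the hypothetical Node/Add/Mul classes from the previous sketch (again an assumption, not a real library):

```python
# f(x, y) = (x + y) * x * y**2 at (1, 2), one possible decomposition.
x, y = Node(), Node()
x.x, y.x = 1.0, 2.0

s = Add((x, y))              # x + y
t = Mul((x, y))              # x * y
u = Mul((t, y))              # x * y**2
f = Mul((s, u))              # (x + y) * x * y**2

order = [x, y, s, t, u, f]   # a topological ordering of the graph
for node in order:           # forward pass
    node.forward()

f.z = 1.0                    # base case at the output node
for node in reversed(order): # backward pass
    node.backward()

print(f.x, x.z, y.z)         # 12.0 16.0 16.0: output value and input gradients
```

In a real system the same pattern runs over the computation graph of the minibatch loss J_B(θ), with the parameter values as input nodes, so the gradient needed for the update is read off from those input nodes.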