GRU units

Dimitri Fichou

2023-04-21

Feed forward pass

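$$r_t = \mathrm{sigmoid}(h_{t-1} W_r + x_t U_r)$$

$$z_t = \mathrm{sigmoid}(h_{t-1} W_z + x_t U_z)$$

$$g_t = \tanh((h_{t-1} \odot r_t) W_g + x_t U_g)$$

$$h_t = y_t = h_{t-1} \odot (1 - z_t) + (z_t \odot g_t)$$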
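The four equations above can be sketched as a single feed-forward step. This is a minimal pure-Python sketch, assuming row vectors stored as plain lists and matrices as lists of rows; the helper names (`gru_step`, `vecmat`, `mul`, ...) are hypothetical and chosen only for illustration. It returns the states $r_t$, $z_t$, $g_t$ alongside $h_t$, because the backward pass needs them.

```python
import math

def sigmoid(v):
    """Element-wise logistic function."""
    return [1.0 / (1.0 + math.exp(-a)) for a in v]

def vecmat(v, M):
    """Row vector v (length n) times matrix M (n rows, m columns)."""
    return [sum(v[i] * M[i][j] for i in range(len(v))) for j in range(len(M[0]))]

def add(a, b):
    """Element-wise sum."""
    return [p + q for p, q in zip(a, b)]

def mul(a, b):
    """Element-wise (Hadamard) product."""
    return [p * q for p, q in zip(a, b)]

def gru_step(x_t, h_prev, W_r, U_r, W_z, U_z, W_g, U_g):
    """One GRU feed-forward step; returns h_t and the stored states (r_t, z_t, g_t)."""
    r_t = sigmoid(add(vecmat(h_prev, W_r), vecmat(x_t, U_r)))  # reset gate
    z_t = sigmoid(add(vecmat(h_prev, W_z), vecmat(x_t, U_z)))  # update gate
    pre_g = add(vecmat(mul(h_prev, r_t), W_g), vecmat(x_t, U_g))
    g_t = [math.tanh(a) for a in pre_g]                        # candidate state
    h_t = add(mul(h_prev, [1.0 - z for z in z_t]), mul(z_t, g_t))  # h_t = y_t
    return h_t, (r_t, z_t, g_t)

# With all weights at zero: r_t = z_t = 0.5 and g_t = 0, so h_t = 0.5 * h_{t-1}.
W = [[0.0, 0.0], [0.0, 0.0]]
U = [[0.0, 0.0]]
h_t, states = gru_step([3.0], [1.0, -2.0], W, U, W, U, W, U)
print(h_t)  # [0.5, -1.0]
```

The zero-weight case is a quick sanity check: both gates sit at 0.5, so the new state is an even mix of the previous state and a zero candidate.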

Back propagation pass

To perform backpropagation through time (BPTT) with a GRU unit, we have the error coming from the top layer ($\delta_1$) and the error coming from the future hidden states ($\delta_2$). We have also stored the states computed at each step of the feed-forward pass. For the future hidden states, this error is simply set to zero if it has not been calculated yet. By convention, $\odot$ denotes element-wise multiplication, writing two terms side by side (as in $h_{t-1} W_r$) denotes matrix multiplication, and $^T$ denotes the transpose.

The rules on how to backpropagate come from this post.
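For reference, every step below combines a few local derivatives that follow directly from the feed-forward equations. In this sketch, $a_z$ and $a_g$ are names introduced here (they are not in the original) for the arguments of the sigmoid in $z_t$ and of the tanh in $g_t$, and $v$, $M$, $y$ stand for any row vector, matrix and product $y = vM$:

$$\frac{\partial h_t}{\partial h_{t-1}}\Big|_{\text{direct}} = 1 - z_t, \qquad \frac{\partial h_t}{\partial z_t} = g_t - h_{t-1}, \qquad \frac{\partial h_t}{\partial g_t} = z_t$$

$$\frac{\partial z_t}{\partial a_z} = z_t \odot (1 - z_t), \qquad \frac{\partial g_t}{\partial a_g} = 1 - g_t \odot g_t$$

$$y = vM \;\Rightarrow\; \frac{\partial L}{\partial v} = \frac{\partial L}{\partial y} M^T, \qquad \frac{\partial L}{\partial M} = v^T \frac{\partial L}{\partial y}$$

The gate derivatives act element-wise, so they enter the chain rule as $\odot$ products. Because $z_t$ reaches $h_t$ through two terms, its error is $\delta_3 \odot (g_t - h_{t-1})$.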

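$$\delta_{3} = \delta_{1} + \delta_{2}$$

$$\delta_{4} = (1 - z_t) \odot \delta_{3}$$

$$\delta_{5} = \delta_{3} \odot h_{t-1}$$

$$\delta_{6} = -\delta_{5} \quad \text{(the derivative of } 1 - z_t \text{ with respect to } z_t \text{ is } -1\text{)}$$

$$\delta_{7} = \delta_{3} \odot g_t$$

$$\delta_{8} = \delta_{3} \odot z_t$$

$$\delta_{9} = \delta_{6} + \delta_{7} \quad \text{(total error on } z_t\text{)}$$

$$\delta_{10} = \delta_{8} \odot (1 - g_t \odot g_t)$$

$$\delta_{11} = \delta_{9} \odot z_t \odot (1 - z_t)$$

$$\delta_{12} = \delta_{10} W_g^T \quad \delta_{13} = \delta_{10} U_g^T \quad \delta_{14} = \delta_{11} W_z^T \quad \delta_{15} = \delta_{11} U_z^T$$

$$\delta_{16} = \delta_{12} \odot h_{t-1} \quad \delta_{17} = \delta_{12} \odot r_t$$

$$\delta_{18} = \delta_{16} \odot r_t \odot (1 - r_t)$$

$$\delta_{19} = \delta_{17} + \delta_{4}$$

$$\delta_{20} = \delta_{18} W_r^T \quad \delta_{21} = \delta_{18} U_r^T$$

$$\delta_{22} = \delta_{21} + \delta_{15}$$

$$\delta_{23} = \delta_{19} + \delta_{14} + \delta_{20} \quad \text{(error for } h_{t-1}\text{, sent to the previous time step)}$$

$$\delta_{24} = \delta_{13} + \delta_{22} \quad \text{(error for } x_t\text{, sent to the layer below)}$$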
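A derivation like this is easy to get subtly wrong, so it helps to check it numerically. The following minimal pure-Python sketch (plain lists, no matrix library) runs one feed-forward step, computes $\delta_3$ to $\delta_{24}$ by the chain rule, and compares $\delta_{23}$ (the error for $h_{t-1}$) and $\delta_{24}$ (the error for $x_t$) with central finite differences. The toy loss $L = c \cdot h_t$ (which makes $\delta_1 = c$), the parameter dictionary `P`, and all helper names are hypothetical, chosen only for this illustration; $\delta_2$ is zero, as at the last time step.

```python
import math
import random

def vecmat(v, M):
    """Row vector v (length n) times matrix M (n x m); returns a length-m list."""
    return [sum(v[i] * M[i][j] for i in range(len(v))) for j in range(len(M[0]))]

def T(M):
    """Matrix transpose, so vecmat(d, T(W)) is d W^T."""
    return [list(col) for col in zip(*M)]

def add(*vs):
    """Element-wise sum of equal-length vectors."""
    return [sum(t) for t in zip(*vs)]

def mul(a, b):
    """Element-wise (Hadamard) product."""
    return [p * q for p, q in zip(a, b)]

def sig(v):
    return [1 / (1 + math.exp(-a)) for a in v]

def forward(x, h, P):
    r = sig(add(vecmat(h, P["Wr"]), vecmat(x, P["Ur"])))
    z = sig(add(vecmat(h, P["Wz"]), vecmat(x, P["Uz"])))
    g = [math.tanh(a) for a in add(vecmat(mul(h, r), P["Wg"]), vecmat(x, P["Ug"]))]
    return add(mul(h, [1 - v for v in z]), mul(z, g)), r, z, g

def backward(d1, d2, h, r, z, g, P):
    """delta_3 ... delta_24 for one step; returns (delta_23, delta_24)."""
    d3 = add(d1, d2)
    d4 = mul([1 - v for v in z], d3)
    d5 = mul(d3, h)
    d6 = [-v for v in d5]
    d7 = mul(d3, g)
    d8 = mul(d3, z)
    d9 = add(d6, d7)
    d10 = mul(d8, [1 - v * v for v in g])          # tanh' written with g_t
    d11 = mul(d9, [v * (1 - v) for v in z])        # sigmoid' written with z_t
    d12, d13 = vecmat(d10, T(P["Wg"])), vecmat(d10, T(P["Ug"]))
    d14, d15 = vecmat(d11, T(P["Wz"])), vecmat(d11, T(P["Uz"]))
    d16, d17 = mul(d12, h), mul(d12, r)
    d18 = mul(d16, [v * (1 - v) for v in r])       # sigmoid' written with r_t
    d19 = add(d17, d4)
    d20, d21 = vecmat(d18, T(P["Wr"])), vecmat(d18, T(P["Ur"]))
    d22 = add(d21, d15)
    d23 = add(d19, d14, d20)                       # error for h_{t-1}
    d24 = add(d13, d22)                            # error for x_t
    return d23, d24

def num_grad(f, v, eps=1e-6):
    """Central finite differences of the scalar function f at vector v."""
    out = []
    for j in range(len(v)):
        up, dn = list(v), list(v)
        up[j] += eps
        dn[j] -= eps
        out.append((f(up) - f(dn)) / (2 * eps))
    return out

def rnd(n, m):
    return [[random.uniform(-1, 1) for _ in range(m)] for _ in range(n)]

random.seed(0)
n_in, n_h = 3, 4
P = {"Wr": rnd(n_h, n_h), "Ur": rnd(n_in, n_h), "Wz": rnd(n_h, n_h),
     "Uz": rnd(n_in, n_h), "Wg": rnd(n_h, n_h), "Ug": rnd(n_in, n_h)}
x, h, c = rnd(1, n_in)[0], rnd(1, n_h)[0], rnd(1, n_h)[0]

def loss(x_, h_):
    """Toy loss L = c . h_t, so delta_1 = c."""
    return sum(ci * hi for ci, hi in zip(c, forward(x_, h_, P)[0]))

_, r, z, g = forward(x, h, P)
d23, d24 = backward(c, [0.0] * n_h, h, r, z, g, P)   # delta_2 = 0: last time step
err_h = max(abs(a - b) for a, b in zip(d23, num_grad(lambda v: loss(x, v), h)))
err_x = max(abs(a - b) for a, b in zip(d24, num_grad(lambda v: loss(v, h), x)))
print(err_h, err_x)
```

Both printed errors should be tiny (on the order of finite-difference noise); a sign or index slip in any single step typically shows up as an error of order 0.1 or larger.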

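The errors $\delta_{23}$ and $\delta_{24}$ are passed on to the next layers of the backward pass. Once all these errors are available, the weight updates can be calculated. Each update is added to its running total, so the gradients are accumulated over all time steps.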

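$$\delta W_r = \delta W_r + h_{t-1}^T \delta_{18} \quad \delta U_r = \delta U_r + x_t^T \delta_{18}$$

$$\delta W_z = \delta W_z + h_{t-1}^T \delta_{11} \quad \delta U_z = \delta U_z + x_t^T \delta_{11}$$

$$\delta W_g = \delta W_g + (h_{t-1} \odot r_t)^T \delta_{10} \quad \delta U_g = \delta U_g + x_t^T \delta_{10}$$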
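Putting the pieces together, the following minimal pure-Python sketch runs BPTT over a short sequence. It stores the states during the feed-forward pass, walks back through time with $\delta_2$ set to zero at the last step, pairs each weight matrix with the error of its own gate ($\delta_{18}$ for $W_r$ and $U_r$, $\delta_{11}$ for $W_z$ and $U_z$, $\delta_{10}$ for $W_g$ and $U_g$), and accumulates the updates over time. It then compares every accumulated entry with a finite-difference estimate. The toy loss $\sum_t c_t \cdot h_t$, the names `bptt`, `loss` and `P`, and the compressed backward step are assumptions for illustration, not the author's implementation.

```python
import math
import random

def vecmat(v, M):
    """Row vector v (length n) times matrix M (n x m)."""
    return [sum(v[i] * M[i][j] for i in range(len(v))) for j in range(len(M[0]))]

def T(M):
    """Matrix transpose."""
    return [list(col) for col in zip(*M)]

def add(*vs):
    """Element-wise sum of equal-length vectors."""
    return [sum(t) for t in zip(*vs)]

def mul(a, b):
    """Element-wise (Hadamard) product."""
    return [p * q for p, q in zip(a, b)]

def sig(v):
    return [1 / (1 + math.exp(-a)) for a in v]

def forward(x, h, P):
    r = sig(add(vecmat(h, P["Wr"]), vecmat(x, P["Ur"])))
    z = sig(add(vecmat(h, P["Wz"]), vecmat(x, P["Uz"])))
    g = [math.tanh(a) for a in add(vecmat(mul(h, r), P["Wg"]), vecmat(x, P["Ug"]))]
    return add(mul(h, [1 - v for v in z]), mul(z, g)), r, z, g

def backward(d3, h, r, z, g, P):
    """Compressed delta chain; returns (delta_10, delta_11, delta_18, delta_23)."""
    d4 = mul([1 - v for v in z], d3)
    d9 = mul(d3, [gi - hi for gi, hi in zip(g, h)])        # delta_6 + delta_7
    d10 = mul(mul(d3, z), [1 - v * v for v in g])          # delta_8 through tanh'
    d11 = mul(d9, [v * (1 - v) for v in z])
    d12 = vecmat(d10, T(P["Wg"]))
    d18 = mul(mul(d12, h), [v * (1 - v) for v in r])       # delta_16 through sigmoid'
    d23 = add(mul(d12, r), d4, vecmat(d11, T(P["Wz"])), vecmat(d18, T(P["Wr"])))
    return d10, d11, d18, d23

def bptt(xs, h0, cs, P):
    """Accumulated weight updates for the toy loss sum_t c_t . h_t."""
    hs, cache = [h0], []
    for x in xs:                                # feed-forward, storing every state
        h_t, r, z, g = forward(x, hs[-1], P)
        cache.append((r, z, g))
        hs.append(h_t)
    grads = {k: [[0.0] * len(M[0]) for _ in M] for k, M in P.items()}
    d2 = [0.0] * len(h0)                        # no future error at the last step
    for t in reversed(range(len(xs))):
        r, z, g = cache[t]
        h, x = hs[t], xs[t]
        d10, d11, d18, d2 = backward(add(cs[t], d2), h, r, z, g, P)
        pairs = {"Wr": (h, d18), "Ur": (x, d18), "Wz": (h, d11),
                 "Uz": (x, d11), "Wg": (mul(h, r), d10), "Ug": (x, d10)}
        for name, (col, d) in pairs.items():    # dW += col^T d (outer product)
            for i in range(len(col)):
                for j in range(len(d)):
                    grads[name][i][j] += col[i] * d[j]
    return grads

def loss(xs, h0, cs, P):
    h, total = h0, 0.0
    for x, c in zip(xs, cs):
        h = forward(x, h, P)[0]
        total += sum(ci * hi for ci, hi in zip(c, h))
    return total

def rnd(n, m):
    return [[random.uniform(-1, 1) for _ in range(m)] for _ in range(n)]

random.seed(1)
n_in, n_h, steps = 2, 3, 4
P = {"Wr": rnd(n_h, n_h), "Ur": rnd(n_in, n_h), "Wz": rnd(n_h, n_h),
     "Uz": rnd(n_in, n_h), "Wg": rnd(n_h, n_h), "Ug": rnd(n_in, n_h)}
xs, cs, h0 = rnd(steps, n_in), rnd(steps, n_h), rnd(1, n_h)[0]
grads = bptt(xs, h0, cs, P)

eps, max_err = 1e-6, 0.0
for name, M in P.items():
    for i in range(len(M)):
        for j in range(len(M[0])):
            orig = M[i][j]
            M[i][j] = orig + eps
            up = loss(xs, h0, cs, P)
            M[i][j] = orig - eps
            dn = loss(xs, h0, cs, P)
            M[i][j] = orig
            max_err = max(max_err, abs((up - dn) / (2 * eps) - grads[name][i][j]))
print(max_err)
```

Swapping the gate errors between two weight pairs (for example using $\delta_{10}$ for $W_r$) makes `max_err` jump well above finite-difference noise, which makes this a cheap guard against that kind of mix-up.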