
LSTM units

Dimitri Fichou

2023-04-21

Feed forward pass

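$$f_t = \mathrm{sigmoid}(h_{t-1} W_f + x_t U_f)$$

$$i_t = \mathrm{sigmoid}(h_{t-1} W_i + x_t U_i)$$

$$g_t = \tanh(h_{t-1} W_g + x_t U_g)$$

$$o_t = \mathrm{sigmoid}(h_{t-1} W_o + x_t U_o)$$

$$C_t = C_{t-1} \odot f_t + i_t \odot g_t$$

$$h_t = y_t = \tanh(C_t) \odot o_t$$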
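As a minimal sketch of the feed forward pass, the step can be written with NumPy as follows. The function names (`lstm_forward`, `init_params`), the dictionary layout of the weights, and the row-vector convention (one sample per row) are assumptions made for illustration and are not taken from any package.

```python
import numpy as np

def sigmoid(z):
    # Logistic function, applied element-wise.
    return 1.0 / (1.0 + np.exp(-z))

def lstm_forward(x_t, h_prev, C_prev, W, U):
    """One LSTM time step.

    x_t:    (batch, n_in) input at time t
    h_prev: (batch, n_hidden) previous hidden state
    C_prev: (batch, n_hidden) previous cell state
    W, U:   dicts keyed by gate name 'f', 'i', 'g', 'o' (hypothetical layout)
    """
    f_t = sigmoid(h_prev @ W["f"] + x_t @ U["f"])   # forget gate
    i_t = sigmoid(h_prev @ W["i"] + x_t @ U["i"])   # input gate
    g_t = np.tanh(h_prev @ W["g"] + x_t @ U["g"])   # candidate values
    o_t = sigmoid(h_prev @ W["o"] + x_t @ U["o"])   # output gate
    C_t = C_prev * f_t + i_t * g_t                  # point-wise products
    h_t = np.tanh(C_t) * o_t
    return h_t, C_t, (f_t, i_t, g_t, o_t)

def init_params(n_in, n_hidden, seed=0):
    # Small random weights; the scale 0.1 is an arbitrary choice.
    rng = np.random.default_rng(seed)
    W = {k: rng.normal(scale=0.1, size=(n_hidden, n_hidden)) for k in "figo"}
    U = {k: rng.normal(scale=0.1, size=(n_in, n_hidden)) for k in "figo"}
    return W, U
```

The gate activations are returned alongside the new states because the backward pass needs them at every time step.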

Back propagation pass

To perform BPTT with an LSTM unit, we have the error coming from the top layer ($\delta_1$), from the future cell state ($\delta_4$) and from the future hidden state ($\delta_2$). We have also stored the states computed at each step of the feed forward pass. The errors coming from the future time step are simply set to zero if they have not been calculated yet. By convention, $\odot$ denotes point-wise multiplication, while juxtaposition (as in $h_{t-1} W_f$) denotes matrix multiplication.

The rules on how to backpropagate come from this post.

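$$\delta_3 = \delta_1 + \delta_2$$

$$\delta_5 = \delta_3 \odot o_t$$

$$\delta_6 = \delta_3 \odot \tanh(C_t)$$

$$\delta_7 = \delta_5 \odot \left(1 - \tanh^2(C_t)\right)$$

$$\delta_8 = \delta_7 + \delta_4$$

$$\delta_9 = \delta_8 \odot i_t$$

$$\delta_{10} = \delta_8 \odot g_t$$

$$\delta_{11} = \delta_8 \odot f_t$$

$$\delta_{12} = \delta_8 \odot C_{t-1}$$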

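$$\begin{aligned}
\delta_{13} &= \delta_6 \odot o_t \odot (1 - o_t) \\
\delta_{14} &= \delta_9 \odot (1 - g_t^2) \\
\delta_{15} &= \delta_{10} \odot i_t \odot (1 - i_t) \\
\delta_{16} &= \delta_{12} \odot f_t \odot (1 - f_t)
\end{aligned}$$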

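$$\begin{aligned}
\delta_{17} &= \delta_{13} U_o^T & \delta_{19} &= \delta_{14} U_g^T & \delta_{21} &= \delta_{15} U_i^T & \delta_{23} &= \delta_{16} U_f^T \\
\delta_{18} &= \delta_{13} W_o^T & \delta_{20} &= \delta_{14} W_g^T & \delta_{22} &= \delta_{15} W_i^T & \delta_{24} &= \delta_{16} W_f^T
\end{aligned}$$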

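$$\begin{aligned}
\delta_{25} &= \delta_{18} + \delta_{20} + \delta_{22} + \delta_{24} \\
\delta_{26} &= \delta_{17} + \delta_{19} + \delta_{21} + \delta_{23}
\end{aligned}$$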
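The backward step can be sketched in NumPy as follows, with one variable per error term so that the code can be read against the equations. The names `forward` and `backward`, the `cache` dictionary and the gate-keyed weight dictionaries are assumptions made for illustration. The derivatives of the activations are written in terms of the stored outputs, for example $o_t \odot (1 - o_t)$ for the sigmoid and $1 - g_t^2$ for the tanh.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def forward(x, h_prev, C_prev, W, U):
    # Forward step; the cache keeps what the backward step needs.
    f = sigmoid(h_prev @ W["f"] + x @ U["f"])
    i = sigmoid(h_prev @ W["i"] + x @ U["i"])
    g = np.tanh(h_prev @ W["g"] + x @ U["g"])
    o = sigmoid(h_prev @ W["o"] + x @ U["o"])
    C = C_prev * f + i * g
    h = np.tanh(C) * o
    return h, C, dict(f=f, i=i, g=g, o=o, C=C, C_prev=C_prev)

def backward(d1, d2, d4, cache, W, U):
    """d1: error from the top layer, d2: from the future hidden state,
    d4: from the future cell state (zeros at the last time step)."""
    f, i, g, o = cache["f"], cache["i"], cache["g"], cache["o"]
    tanh_C = np.tanh(cache["C"])
    d3 = d1 + d2
    d5 = d3 * o                    # error on tanh(C_t)
    d6 = d3 * tanh_C               # error on o_t
    d7 = d5 * (1 - tanh_C ** 2)    # through the tanh
    d8 = d7 + d4                   # total error on C_t
    d9 = d8 * i                    # error on g_t
    d10 = d8 * g                   # error on i_t
    d11 = d8 * f                   # error on C_{t-1}
    d12 = d8 * cache["C_prev"]     # error on f_t
    d13 = d6 * o * (1 - o)         # through the sigmoid of the output gate
    d14 = d9 * (1 - g ** 2)        # through the tanh of the candidate
    d15 = d10 * i * (1 - i)        # through the sigmoid of the input gate
    d16 = d12 * f * (1 - f)        # through the sigmoid of the forget gate
    gate = {"o": d13, "g": d14, "i": d15, "f": d16}
    d25 = sum(gate[k] @ W[k].T for k in "ogif")   # error on h_{t-1}
    d26 = sum(gate[k] @ U[k].T for k in "ogif")   # error on x_t
    return d11, d25, d26, gate
```

A finite-difference comparison on a small random example is a convenient way to check such a sketch: perturbing `x`, `h_prev` or `C_prev` should change the weighted sum of the outputs by the amounts given by `d26`, `d25` and `d11`.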

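The errors $\delta_{11}$, $\delta_{25}$ and $\delta_{26}$ are passed on to the next steps of the backward pass: $\delta_{11}$ is the error on the previous cell state $C_{t-1}$, $\delta_{25}$ the error on the previous hidden state $h_{t-1}$, and $\delta_{26}$ the error on the input $x_t$. Once all those errors are available, it is possible to calculate the weight update.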

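$$\delta W_f = \delta W_f + h_{t-1}^T \delta_{16} \qquad \delta U_f = \delta U_f + x_t^T \delta_{16}$$

$$\delta W_i = \delta W_i + h_{t-1}^T \delta_{15} \qquad \delta U_i = \delta U_i + x_t^T \delta_{15}$$

$$\delta W_g = \delta W_g + h_{t-1}^T \delta_{14} \qquad \delta U_g = \delta U_g + x_t^T \delta_{14}$$

$$\delta W_o = \delta W_o + h_{t-1}^T \delta_{13} \qquad \delta U_o = \delta U_o + x_t^T \delta_{13}$$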
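The accumulation of the weight updates over time steps could look like the sketch below. The function name `accumulate_gradients` and the gate-keyed dictionaries are assumptions made for illustration; `gate_deltas` is expected to hold $\delta_{13}$, $\delta_{14}$, $\delta_{15}$ and $\delta_{16}$ under the keys `'o'`, `'g'`, `'i'` and `'f'`.

```python
import numpy as np

def accumulate_gradients(dW, dU, h_prev, x_t, gate_deltas):
    """Add the contribution of one time step to the running weight updates.

    dW, dU:      dicts of accumulators, same shapes as the weights W and U
    h_prev:      (batch, n_hidden) hidden state fed into this step
    x_t:         (batch, n_in) input of this step
    gate_deltas: dict of (batch, n_hidden) errors, one per gate
    """
    for k, delta in gate_deltas.items():
        dW[k] += h_prev.T @ delta   # (n_hidden, n_hidden)
        dU[k] += x_t.T @ delta      # (n_in, n_hidden)
    return dW, dU
```

The accumulators are typically initialised to zero before the backward sweep and updated once per time step, so that the final value sums the contributions of the whole sequence.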