El Problema del Gradiente Evanescente

Mariano Rivera

Nov 2018

[1] R. Pascanu et al., On the difficulty of training recurrent neural networks. In Proc. ICML’13, vol 28, pp III–1310–III–1318, (2013)


En problema conocido del entrenamiento de las redes recurrentes en el llamado Gradiente Evanescete (vanish gradient). [1]

Para analizar dicho problema, primero usaremos un ejemplo con la finalidad de motivar el problema.

Suponga que hemos entrenado una red para completar oraciones y la entrenamos con oraciones como la siguiente:

Al gato la gusta jugar con los niños, …, descansar en su cojín y comer pescado

Donde ‘…’ significan que hay varias palabras mas.

Luego le pedimos a la red predecir la siguiente palabra en la siguiente oración:

*Al perro la gusta jugar con los niños, …, descansar en su cojín y comer ________ *

Y la red predice pescado, cuando esperabamos huesos.

¿Que fué lo que pasó? ¿En que falló el entrenamiento? Bueno, pues la palabra que define de quién estamos hablando (gato o perro) está al inicio de la oración, y la red tiene memoria corta. Esto es, nuestra red no es muy buena para recordar datos importantes por mucho tiempo.

La razón es que en las redes recurrentes el efecto en la salida de los palaras se desvanece rápidamente. Es decir, modificar las palabras iniciales de la oración tienen muy poco efecto (si no es que nulo) en la última salida de la RNN.

Esto implica que el el cambio en el costo de la última predicción se desvanece conforme los términos son más tempranos: el gradiente se desvanece rápidamente en el tiempo. Matemáticamente, esto se escribe como

(1)
E(ot,Ot)xktk0 \frac{\partial E(o_t, O_t)}{\partial {x_k}}\bigg\uparrow_{t-k} \approx 0

Donde oto_t es la salida de la RNN para el tiempo tt y OtO_t es la salida esperada (usada en el entrenamiento supervisado) y xkx_k es una entrada muy distante en el tiempo.

Para analizar esto, asumamos una RNN muy simple, como la de la siguiente figura.

rrn2

Donde la celda esta definida por la simple función de transformación:

(2)
ot=ϕ(Wx xt+Wo ot1+b) o_t = \phi(W_x\, x_t + W_o \, o_{t-1} + b)

Donde ϕ\phi es una función de activación y definimos W=[Wx,Wo,b]W = [W_x, W_o, b]. Luego agregamos, gráficamente, la etapa de cálculo del error. Entonces la RNN se ve como

rrn2_error

Donde, EtE_t es el error entre la predicción oto_t y la salida esperada OtO_t considerando la tt-ésima entrada xtx_t y la retroalimentación ot1o_{t-1}.

Luego, el error total cometido esta dado por

(3)
E=t=0TEt E = \sum_{t=0}^T E_t

y el gradiente de la función de error respecto a los parámetros WW es

(4)
WE=defEW=t=0TEtW \nabla_W E \overset{def}{=}\frac{\partial E}{\partial W} = \sum_{t=0}^T \frac{\partial E_t}{\partial W}
que es la suma de las contribuciones de cada salida al gradiente.

Por lo pronto, consideremos sólo la salida final,

1 Notamos que un cambio en WW afectará directamente a o5o_5. Gráficamente: se ilustra en la figura siguiente.

grad1

2 Sin ambargo, la entrada o4o_4 a la última celda también se ve afectada por un cambio en WW.

grad2

3 Que a su vez se ve afectada por un cambio en o3o_3 inducido por el cambio en WW.

grad3

De hecho, por la naturaleza recursiva de nuestra red, un cambio en WW afecta directamente todas las salidas de la red, y este cambio se propaga a través del canal de memoria.

gradn

Entonces el gradiente en la salida tt-ésima esta dado por

(5)
EtW=EtototWCambio directamente en la celda final t+Etototot1ot1WCambio inducido por la celda previa t1+Etototot1ot1ot2ot2WCambio inducido por la celda ante-previa t2+Etototot1ot1ot2ot2ot3ot3WCambio inducido por la celda t3+Etototot1ot1ot2ot2ot3ot3ot4ot4WCambio inducido por la celda t4+Cambio inducido por el resto de las celdas \begin{matrix} \frac{\partial E_t}{\partial W} & = \frac{\partial E_t}{\partial o_t} \frac{\partial o_t}{\partial W} & \text{Cambio directamente en la celda final $t$} \\ & + \frac{\partial E_t}{\partial o_t} \frac{\partial o_t}{\partial o_{t-1}} \frac{\partial o_{t-1}}{\partial W} & \text{Cambio inducido por la celda previa $t-1$}\\ & + \frac{\partial E_t}{\partial o_t} \frac{\partial o_t}{\partial o_{t-1}} \frac{\partial o_{t-1}}{\partial o_{t-2}} \frac{\partial o_{t-2}}{\partial W} & \text{Cambio inducido por la celda ante-previa $t-2$} \\ & + \frac{\partial E_t}{\partial o_t} \frac{\partial o_t}{\partial o_{t-1}} \frac{\partial o_{t-1}}{\partial o_{t-2}} \frac{\partial o_{t-2}}{\partial o_{t-3}} \frac{\partial o_{t-3}}{\partial W} & \text{Cambio inducido por la celda $t-3$} \\ & + \frac{\partial E_t}{\partial o_t} \frac{\partial o_t}{\partial o_{t-1}} \frac{\partial o_{t-1}}{\partial o_{t-2}} \frac{\partial o_{t-2}}{\partial o_{t-3}} \frac{\partial o_{t-3}}{\partial o_{t-4}} \frac{\partial o_{t-4}}{\partial W} & \text{Cambio inducido por la celda $t-4$}\\ & + \ldots & \text{Cambio inducido por el resto de las celdas} \end{matrix}

que podemos reescribirlo como

(6)
EtW=k=1tEtototokokW \frac{\partial E_t}{\partial W} = \sum_{k=1}^t \frac{\partial E_t}{\partial o_t} \frac{\partial o_{t}}{\partial o_k} \frac{\partial o_{k}}{\partial W}

Donde hemos definido
(7)
otok=defi=kt1oioi1 \frac{\partial o_{t}}{\partial o_k} \overset{def}{=}\prod_{i=k}^{t-1} \frac{\partial o_{i}}{\partial o_{i-1}}

Ahora, veamos cada término

(8)
oioi1=oi1ϕ(Wx xi+Wo oi1+b)                        Wo ϕ(Wx xi+Wo oi1+b) \frac{\partial o_{i}}{\partial o_{i-1}} = \frac{\partial}{\partial o_{i-1}} \phi(W_x\, x_i + W_o \, o_{i-1} + b) \\ \;\;\;\;\;\;\;\;\;\;\;\; W_o \, \phi^\prime(W_x\, x_i + W_o \, o_{i-1} + b)

Sustituimos (8) en (7):

(9)
otok=i=kt1Wo ϕ(Wx xi+Wo oi1+b)                        =(Wo)tk1i=kt1ϕ(Wx xi+Wo oi1+b) \frac{\partial o_{t}}{\partial o_k} = \prod_{i=k}^{t-1} W_o \, \phi^\prime(W_x\, x_i + W_o \, o_{i-1} + b) \\ \;\;\;\;\;\;\;\;\;\;\;\; = (W_o)^{t-k-1} \prod_{i=k}^{t-1} \phi^\prime(W_x\, x_i + W_o \, o_{i-1} + b)

Ahora

(10)
okW=Wϕ(W [xk,ok1,1])                                    =[xk,ok1,1] ϕ(W [xk,ok1,1]) \frac{\partial o_{k}}{\partial W} = \frac{\partial}{\partial W} \phi( W \, [x_k, o_{k-1}, \mathbf{1}]^\top) \\ \;\;\;\;\;\;\;\;\;\;\;\;\;\;\;\;\;\; = [x_k, o_{k-1}, \mathbf{1}]^\top \, \phi^\prime(W \, [x_k, o_{k-1}, \mathbf{1}]^\top)
donde usamos W=[Wx,Wo,b]W = [W_x, W_o , b]

Y el primer término Etot\frac{\partial E_t}{\partial o_t} depende únicamente de la función de error DD que usemos.

Poniendo todo junto en (6):

(13)

EtW=k=1t(Etot  [xk,ok1,1] ϕ(W [xk,ok1,1])(Wo)tk1!i=kt1ϕ(Wx xi+Wo oi1+b)!) \boxed{ \frac{\partial E_t}{\partial W} = \sum_{k=1}^t \left( \frac{\partial E_t}{\partial o_t} \; [x_k, o_{k-1}, \mathbf{1}]^\top \, \phi^\prime(W \, [x_k, o_{k-1}, \mathbf{1}]^\top ) \underbrace{(W_o)^{t-k-1}}_{\mathbf !} \underbrace{ \prod_{i=k}^{t-1} \phi^\prime(W_x\, x_i + W_o \, o_{i-1} + b)}_{\mathbf !} \right) }

Lo importante es la aparición de los términos


Podemos factorizar W=U D VW = U \, D \, V^\top usando descomposición en valores singulares (Singular Value Decompossition, SVD). Donde DD es una matriz diagonal con los valores singulares y UU, VV som matrices unitarias (UU=UU=IU^\top U = U U^\top = I y VV=VVV=IV^\top V = V V^\top V = I). Luego Wn=Wn2WW=Wn2V D2 VW^n = W^{n-2} W^\top W = W^{n-2} V \, D^2 \, V ^\top. Note que este producto pueded cambiar si elegimos multiplicar por la izquierda y si nn es par o impar; pero lo escencial es que el producto es de la forma Wn=ADnBW^n = A D^n B (con A,B{U,U,V,V}A,B \in \{U,U^\top, V, V^\top\}. Donde DnD^n significa que cada valor singular es elevado a la potencia nn. Si n>>0n>>0 y Dii<0D_{ii}<0, entonces Diin0D_{ii}^n \approx 0. y Wn0W^n \approx 0.


Por ello, las RNN no pueden usar términos tempranos kk para construir respuestas a tiempos muy distantes tt (con k<<tk<<t).

El Problema de la Explosión del Gradiente

En redes profundas no recurrentes se ha visto otro problema asociado con el producto de gradientes.

Si la matriz WoW_o contienen únicamente términos [Wo]ij>1| [W_o]_{ij} | > 1 y el producto (potencia)de matrices crece más rapidamente que el producto de las parciales. Entonces, el gradiente explotará hacendose extremadamente grande y provocando el llamado problema de Explosión del Gradiente [2], que es el opuesto al de Gradiente Evanescente y generalmente se resuelve mediante:

  1. Regularizando (penalizando) los pesos WW con norma L1L_1 o L2L_2.

  2. Recortando gradientes excesivamente grandes.

  3. Truncando backpropagation cundo hay gradientes grandes.

[2] Y. Bengio et al., The Problem of Learning Long-Term Dependencies in Recurrent Networks, Neural Networks for Computing Conference, 1993.