Modelos de MRFs con Pytorch

Mariano Rivera

ver 1.1 Septiembre 2022


Paquetes

import numpy as np
from scipy.sparse import rand as sprand
import matplotlib.pyplot as plt
import PIL.Image as Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import ToTensor, ToPILImage

print("Disponibilidad de Cuda        :",torch.cuda.is_available())

if torch.cuda.is_available():
    device = 'cuda' 
    print("Numero de dispositivos Cuda   :", torch.cuda.device_count())
    print("Dispositivo Cuda actual       :", torch.cuda.current_device())
    print("Nombre del dispositivo actual :", torch.cuda.get_device_name(torch.cuda.current_device()))
    
else:
    device = 'cpu'
Disponibilidad de Cuda        : True
Numero de dispositivos Cuda   : 2
Dispositivo Cuda actual       : 0
Nombre del dispositivo actual : NVIDIA GeForce RTX 3090

Filtro de Membrana: Regularización quadrática de primer orden

Dado el modelo de observación

(1)
g(x)=f(x)+η(x) g(x) = f(x) + \eta(x)

para xLZ2x \in \mathcal{L} \subset \mathbb{Z}^2. Donde gg es la imagen observada; ff la imagen verdadera (desconocida); ηN(0,σ2)\eta \sim \mathcal{N}(0, \sigma^2) (ruido Gaussiano con media cero y varianza σ2\sigma^2) e i.i.d; y L\mathcal{L} es la retícula de pixeles de la imágen.

Luego, de acuerdo al marco de regularización Bayesiana, podemos estimar ff mediante la minimización de

(2)
L(f)=gf2+λf2 L(f) = \| g - f \|^2 + \lambda \| \nabla f \|^2

Definimos los parámetros de nuestra simulación

LAMBDA = 10

Cargamos la imagen a procesar

img = Image.open('guanajuato.jpg')
img

png

Parámetros sobre los que se realizará la optimización

g = ToTensor()(img).to(device)
f = g.clone().to(device).requires_grad_(True)
f = nn.Parameter(f)
f.is_leaf
True

Función de costo

Esta es la parte mas importante pues, define la función de costo (loss) que queremos minimizar. Para ello, primero definimos como primeras diferencias adelantadas:

(3)
fx(x,y)=f(x+1,y)f(x,y)fy(x,y)=f(x,y+1)f(x,y). f_x(x,y) = f(x+1,y) - f(x,y) \\ f_y(x,y) = f(x, y+1) - f(x,y).

Esto es:

def gradient(f):
    '''
    Entrada
    f:      (c,h,w), float32 or float64
    Resultados
    fx, fy: (c,h,w)
    '''
    # corrimientos
    f_10 = torch.nn.functional.pad(f, (0, 1, 0, 0))[:, :, 1:] # pad last dim by (0, 1)
    f_01 = torch.nn.functional.pad(f, (0, 0, 0, 1))[:, 1:, :] # pad 2nd to last dim by (0, 1) 
    # primeras diferencias adelantadas
    fx = f_10 - f 
    fy = f_01 - f 
    # derivadas en la frontera
    fx[:, :, -1] = 0     # fx will have zeros in the last column
    fy[:, -1, :] = 0     # fy will have zeros in the last row

    return fx, fy    

Luego la magnitude cuadrada de gradiente

(4)
f2=fx2+fy2 \| \nabla f \|^2 = f_x^2 + f_y^2

la implementamos mediante

def L2_gradient(f):
    '''
    Calcula el promedio de la magnitud del gradiente espacial de la imagen multicanal f  (c,h,w)
    '''
    # gradiente
    fx, fy = gradient(f)
    # promedio de la magnitud del gradiente
    return torch.mean(fx**2 + fy**2)
    
L2_data = torch.nn.MSELoss()

Optimizador

Usamos ADAM como optimizador de la funcón de costo. El optimizador recibe in iterador (lista en este caso) sobre los parámetros que debe optimizar.

optimizer = torch.optim.Adam([f])

Entrenamiento

Definimos el método fit que llamaremos para calcular la factorización (entrenar el modelo)

Este método itera los siguientes pasos

  1. Inicializa el gradiente (lo pone en cero).

  2. Obtienen la predicción (imagen ff).

  3. Calcula la funcíon de costo.

  4. Calcula el gradiente del costo con respecto a la imagen ff mediante retropropagación.

  5. Realiza el paso de actualización de los parámetros.

def fit(f,g, LAMBDA, epochs=100):

    for t in range(epochs):
        optimizer.zero_grad()            # Inicializa gradiente
        #f = f                           # Prediccion, no requiere procesamiento
        loss_d = L2_data(f, g)         # Término de datos
        loss_r = L2_gradient(f)        # Término de regularización
        loss = loss_d + LAMBDA * loss_r  # Costo total
        loss.backward()                  # Retropropagación (gradiente)
        optimizer.step()                 # Actualiza los parametros del modelo 
                
        if t%100==0:
            print(t, end=' ')

Ahora si, realizamos la optimizaciòn por epochs número de iteraciones.

fit(f, g, LAMBDA, epochs=300)
0 100 200 

Despliege de resultados

im_f = ToPILImage()(f)
im_g = ToPILImage()(g)

plt.figure(figsize=(10,5))
plt.subplot(121)
plt.imshow(im_g)
plt.subplot(122)
plt.imshow(im_f)
<matplotlib.image.AxesImage at 0x7fd4b455f490>

png

im_f.save('filtrada_membrana.png', mode='png')

¡Eso es todo! Bueno, no. Veamos que mas podemos hacer.


Filtro de Membrana Robusta: Regularización medinate la penalización de la Variación Total

En este caso la imagen ff se obtiene mediante la minimización de

(5)
L(f)=gf2+λf1 L(f) = \| g - f \|^2 + \lambda \| \nabla f \|_1

Note que ahora estamos regularizando sobre la variación total de la imagen: norma L1L_1 del gradiente.

Inicializamos nuestras variables

g = ToTensor()(img).to(device)
f = g.clone().to(device).requires_grad_(True)
f = nn.Parameter(f)
f.is_leaf
True

La variación total se calcula como

(6)
f1=fx1+fy1 \| \nabla f \|_1 = \|f_x\|_1 + \|f_y\|_1

y la implementamos mediante

def L1_gradient(f):
    '''
    Calcula el promedio de la magnitud del gradiente espacial de la imagen multicanal f  (c,h,w)
    '''
    # gradiente
    fx, fy = gradient(f)
    # promedio de la magnitud del gradiente
    return torch.mean(torch.abs(fx) + torch.abs(fy))
    

Optimizador para TV

TV_optimizer = torch.optim.Adam([f])
L2_data = torch.nn.MSELoss(reduction='mean')
def fit(f,g, LAMBDA=1, epochs=100):

    for t in range(epochs):
        TV_optimizer.zero_grad()         # Inicializa gradiente
        #f = f                           # Prediccion, no requiere procesamiento
        loss_d = L2_data(f, g)           # Término de datos
        loss_r = L1_gradient(f)          # Término de regularización
        loss = loss_d + LAMBDA * loss_r  # Costo total
        loss.backward()                  # Retropropagación (gradiente)
        TV_optimizer.step()              # Actualiza los parametros del modelo 
                
        if t%100==0:
            print(t, end=' ')
fit(f, g, LAMBDA=.5, epochs=300)
0 100 200 
im_f = ToPILImage()(f)
im_g = ToPILImage()(g)

plt.figure(figsize=(10,5))
plt.subplot(121)
plt.imshow(im_g)
plt.subplot(122)
plt.imshow(im_f)
<matplotlib.image.AxesImage at 0x7fd4b46d4e20>

png

plt.figure(figsize=(20,2))
plt.subplot(211)
plt.plot(g.to('cpu').numpy()[0,100,:], 'r')
plt.subplot(212)
plt.plot(f.to('cpu').detach().numpy()[0,100,:], 'g')
[<matplotlib.lines.Line2D at 0x7fd4b45e3a00>]

png

Note que la regularización usando la norma L1L_1 promueve respetar las discontinuidades grandes, mientras que las pequeñas son sobresuavizadas, creando un efecto de dibujo animado o cartoon. Este tipo de regularización que respeta bordes se conoce como edge-preserving regularization Para saber mas sobre estos modelos ver:

(Black and Rabgarajan, 1996) Black, M. J., & Rangarajan, A. (1996). On the unification of line processes, outlier rejection, and robust statistics with applications in early vision. IJCV, 19(1), 57-91.

im_f.save('filtrada_tv.png', mode='png')

Ahora si. ¡Esto es todo!