Mariano Rivera
ver 1.1 Septiembre 2022
import numpy as np
from scipy.sparse import rand as sprand
import matplotlib.pyplot as plt
import PIL.Image as Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import ToTensor, ToPILImage
print("Disponibilidad de Cuda :",torch.cuda.is_available())
if torch.cuda.is_available():
device = 'cuda'
print("Numero de dispositivos Cuda :", torch.cuda.device_count())
print("Dispositivo Cuda actual :", torch.cuda.current_device())
print("Nombre del dispositivo actual :", torch.cuda.get_device_name(torch.cuda.current_device()))
else:
device = 'cpu'
Disponibilidad de Cuda : True
Numero de dispositivos Cuda : 2
Dispositivo Cuda actual : 0
Nombre del dispositivo actual : NVIDIA GeForce RTX 3090
Dado el modelo de observación
(1)
para . Donde es la imagen observada; la imagen verdadera (desconocida); (ruido Gaussiano con media cero y varianza ) e i.i.d; y es la retícula de pixeles de la imágen.
Luego, de acuerdo al marco de regularización Bayesiana, podemos estimar mediante la minimización de
(2)
Definimos los parámetros de nuestra simulación
LAMBDA = 10
Cargamos la imagen a procesar
img = Image.open('guanajuato.jpg')
img
g = ToTensor()(img).to(device)
f = g.clone().to(device).requires_grad_(True)
f = nn.Parameter(f)
f.is_leaf
True
Esta es la parte mas importante pues, define la función de costo (loss) que queremos minimizar. Para ello, primero definimos como primeras diferencias adelantadas:
(3)
Esto es:
def gradient(f):
'''
Entrada
f: (c,h,w), float32 or float64
Resultados
fx, fy: (c,h,w)
'''
# corrimientos
f_10 = torch.nn.functional.pad(f, (0, 1, 0, 0))[:, :, 1:] # pad last dim by (0, 1)
f_01 = torch.nn.functional.pad(f, (0, 0, 0, 1))[:, 1:, :] # pad 2nd to last dim by (0, 1)
# primeras diferencias adelantadas
fx = f_10 - f
fy = f_01 - f
# derivadas en la frontera
fx[:, :, -1] = 0 # fx will have zeros in the last column
fy[:, -1, :] = 0 # fy will have zeros in the last row
return fx, fy
Luego la magnitude cuadrada de gradiente
(4)
la implementamos mediante
def L2_gradient(f):
'''
Calcula el promedio de la magnitud del gradiente espacial de la imagen multicanal f (c,h,w)
'''
# gradiente
fx, fy = gradient(f)
# promedio de la magnitud del gradiente
return torch.mean(fx**2 + fy**2)
L2_data = torch.nn.MSELoss()
Usamos ADAM como optimizador de la funcón de costo. El optimizador recibe in iterador
(lista en este caso) sobre los parámetros que debe optimizar.
optimizer = torch.optim.Adam([f])
Definimos el método fit
que llamaremos para calcular la factorización (entrenar el modelo)
Este método itera los siguientes pasos
Inicializa el gradiente (lo pone en cero).
Obtienen la predicción (imagen ).
Calcula la funcíon de costo.
Calcula el gradiente del costo con respecto a la imagen mediante retropropagación.
Realiza el paso de actualización de los parámetros.
def fit(f,g, LAMBDA, epochs=100):
for t in range(epochs):
optimizer.zero_grad() # Inicializa gradiente
#f = f # Prediccion, no requiere procesamiento
loss_d = L2_data(f, g) # Término de datos
loss_r = L2_gradient(f) # Término de regularización
loss = loss_d + LAMBDA * loss_r # Costo total
loss.backward() # Retropropagación (gradiente)
optimizer.step() # Actualiza los parametros del modelo
if t%100==0:
print(t, end=' ')
Ahora si, realizamos la optimizaciòn por epochs
número de iteraciones.
fit(f, g, LAMBDA, epochs=300)
0 100 200
im_f = ToPILImage()(f)
im_g = ToPILImage()(g)
plt.figure(figsize=(10,5))
plt.subplot(121)
plt.imshow(im_g)
plt.subplot(122)
plt.imshow(im_f)
<matplotlib.image.AxesImage at 0x7fd4b455f490>
im_f.save('filtrada_membrana.png', mode='png')
¡Eso es todo! Bueno, no. Veamos que mas podemos hacer.
En este caso la imagen se obtiene mediante la minimización de
(5)
Note que ahora estamos regularizando sobre la variación total de la imagen: norma del gradiente.
Inicializamos nuestras variables
g = ToTensor()(img).to(device)
f = g.clone().to(device).requires_grad_(True)
f = nn.Parameter(f)
f.is_leaf
True
La variación total se calcula como
(6)
y la implementamos mediante
def L1_gradient(f):
'''
Calcula el promedio de la magnitud del gradiente espacial de la imagen multicanal f (c,h,w)
'''
# gradiente
fx, fy = gradient(f)
# promedio de la magnitud del gradiente
return torch.mean(torch.abs(fx) + torch.abs(fy))
Optimizador para TV
TV_optimizer = torch.optim.Adam([f])
L2_data = torch.nn.MSELoss(reduction='mean')
def fit(f,g, LAMBDA=1, epochs=100):
for t in range(epochs):
TV_optimizer.zero_grad() # Inicializa gradiente
#f = f # Prediccion, no requiere procesamiento
loss_d = L2_data(f, g) # Término de datos
loss_r = L1_gradient(f) # Término de regularización
loss = loss_d + LAMBDA * loss_r # Costo total
loss.backward() # Retropropagación (gradiente)
TV_optimizer.step() # Actualiza los parametros del modelo
if t%100==0:
print(t, end=' ')
fit(f, g, LAMBDA=.5, epochs=300)
0 100 200
im_f = ToPILImage()(f)
im_g = ToPILImage()(g)
plt.figure(figsize=(10,5))
plt.subplot(121)
plt.imshow(im_g)
plt.subplot(122)
plt.imshow(im_f)
<matplotlib.image.AxesImage at 0x7fd4b46d4e20>
plt.figure(figsize=(20,2))
plt.subplot(211)
plt.plot(g.to('cpu').numpy()[0,100,:], 'r')
plt.subplot(212)
plt.plot(f.to('cpu').detach().numpy()[0,100,:], 'g')
[<matplotlib.lines.Line2D at 0x7fd4b45e3a00>]
Note que la regularización usando la norma promueve respetar las discontinuidades grandes, mientras que las pequeñas son sobresuavizadas, creando un efecto de dibujo animado o cartoon. Este tipo de regularización que respeta bordes se conoce como edge-preserving regularization Para saber mas sobre estos modelos ver:
(Black and Rabgarajan, 1996) Black, M. J., & Rangarajan, A. (1996). On the unification of line processes, outlier rejection, and robust statistics with applications in early vision. IJCV, 19(1), 57-91.
im_f.save('filtrada_tv.png', mode='png')
Ahora si. ¡Esto es todo!