Versión simplificada de la oficial en SIREN
Mariano Rivera
version 1.0, diciembre 2022.
Aprendizaje Automático, Mariano Rivera, CIMAT © 2022
Estamos acostumbrado a representar señales e imágenes a través de vectores y matrices; respectivamente. Estas representaciones discretas son muestras de señales reales de mundo que de forma natural son continuas. representar en forma discreta estos datos tiene ventajas inheretes para su almacenamiento, procesamiento, y transmición a través de computadoras de naturaleza digital. Al mismo tiempo, al ser muestras de señales reales si se viola el criterio e Nyquist se tiene pérdida de información y falseamiento de la misma (aliasing).
En esta notas se presenta como usar la librería PyTorch para optimizar funciones de costo cuyos parámetros son directamente las variables que buscamos (imágenes); ver MRF. Ahora consideremos otro enfoque, asumamos que en vez de representar explícitamente la solución a nuestro problema, ésta se representa mediante una red neuronal. Es decir, asumimos que existe una función que realiza el mapeo:
(1)
donde es la variable independiente y la dependiente. Por ejemplo, considere una imagen en tonos de gris , cuyas posiciones de pixeles están dados por con , y son las dimensiones de la imagen.
Luego, a cada posición le corresponde un valor de gris .
Ahora, podemos ver cada par como una dato para entrenar nuestro modelo (parametrizado por ) que aproxima a la función verdadera . Luego tenemos varias opciones para elegir el modelo las primeras que se nos pueden ocurrir en froma natural son las Funciones Base Radial (RBFs) y, por supuesto el Perceptrón Multicapa (MLP); ambos son aproximadores universales de funciones.
Aquí analizaremos el caso del MLP, en el cual usamos los pares para entrenar el MLP. Recordemos que un MLP de una capa oculta corresponde a:
(2)
donde es la predicción de , son los parámetros del modelo y son las funciones de activación; generalmente o ReLU. Entonces para aprender los parámetros definimos una función de pérdida, digamos
(3)
y la optimizamos usando descenso estocástico seleccionando en cada iteración un subconjunto de los pixeles (batch) usando como gradiente
(4)
Para entrenar la red de representación implícita se se procede a sobreentrenar (overfit) el model.
Antes de pasar revisar otras posibles aplicaciones de la representación implícita, introduciremos una capa que ha mostrado tener propiedades interesantes en cuanto a estabilidad para su entrenamiento. Esta capa es denominada SIREN y toma su nombre de Periodic Activation Implicit Neural Representations (Sitzmann et al., 2020). Esta capa corresponde a una capa densa pero con activación senoidal:
(5)
donde es un factor de escala no entrenable. Los autores proponen usar para datos de , y los parámetros se inicialializan acorde a
(6)
para la primera capa y
(7)
para las subsecuentes capas; donde es la dimensión del vector de entrada a la capa. En PyTorch los pesos son inicializados por omisión a . En esta implementación usamos la inicialización por omisión de PyTorch para la primera etapa, dado que de acuerdo a nuestros experimentos produjo mejores resultados.
La intención de usar activación periódica senoidal es que para 's pequeñas. De ahi que al tener el factor de escala se procura que si en el proceso de entrenamiento resulta ser muy grande, la función periódica produce una respuesta equivalente a . Es decir a la producida en la rama principal. Esto es, la capa SIREN es ciega a factores que incremeten la respuesta de la parte lineal en la forma para entera.
(Sitzmann et al., 2020) V. Sitzmann et al. Implicit Neural Representations with Periodic Activation Functions, Proc. NIPS, 2020.
import numpy as np
from scipy.sparse import rand as sprand
import matplotlib.pyplot as plt
import PIL.Image as Image
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.transforms import ToTensor, ToPILImage
print("Disponibilidad de Cuda :",torch.cuda.is_available())
if torch.cuda.is_available():
device = 'cuda'
print("Numero de dispositivos Cuda :", torch.cuda.device_count())
print("Dispositivo Cuda actual :", torch.cuda.current_device())
print("Nombre del dispositivo actual :", torch.cuda.get_device_name(torch.cuda.current_device()))
else:
device = 'cpu'
device = 'cpu'
Disponibilidad de Cuda : True
Numero de dispositivos Cuda : 1
Dispositivo Cuda actual : 0
Nombre del dispositivo actual : NVIDIA GeForce RTX 3090
img = Image.open('cameraman.pgm')
img
f = ToTensor()(img)[0,:,:] # en el rango [0,1]
nrows, ncols = f.shape
nrows, ncols
(256, 256)
Creamos la base de datos, pares (posición de pixel, valor de pixel) = (X,Y)
ii, jj = np.meshgrid(np.arange(nrows), np.arange(ncols))
ii = ii.flatten().astype('float32')
ii = (ii/nrows)-0.5
jj = jj.flatten().astype('float32')
jj = (jj/ncols)-0.5
X = np.stack([ii,jj]).T
X = ToTensor()(X)[0]
X.min(),X.max(), X.shape
(tensor(-0.5000), tensor(0.4961), torch.Size([65536, 2]))
Y = f.flatten()
Y = (Y-0.5) # en el rango [-0.5,0.5]
Y.min(),Y.max()
(tensor(-0.4137), tensor(0.5000))
from torch.utils.data import TensorDataset, DataLoader
batch_size = nrows*(ncols//10)
dataloader = torch.utils.data.DataLoader(dataset = [*zip(X,Y)],
batch_size = batch_size,
shuffle = True)
Versión comentada de la capa SIREN
class SineLayer(nn.Module):
def __init__(self, in_features, out_features, bias=True,
is_first=False, is_last=False, omega_0=30):
'''
Implementa
\sin ( omega ( W x + b) )
in_features : (int) dimensión de entrada
out_features : (int) número de neuronas (dimensión de salida)
bias : -
is_fisrt : (boolean) se escala distinto la inicialización de la primera capa oculta y las restantes
is_last : (boolean) sn funcion de activacio en la capa de salida
'''
super().__init__()
self.omega_0 = omega_0
self.is_first = is_first
self.is_last = is_last
self.in_features = in_features
self.linear = nn.Linear(in_features, out_features, bias=bias)
self.init_weights()
def forward(self, input):
'''
y = Phi(omega0 W (x+b) )
'''
return torch.sin(self.omega_0 * self.linear(input))
def forward_with_intermediate(self, input):
'''
z = omega0 (W x + b)
y = \sin(z)
(y,z)
'''
# For visualization of activation distributions
z = self.omega_0 * self.linear(input)
return torch.sin(z), z if not self.last else z,z
def init_weights(self):
'''
Inicialización escalada para considerar a la activación periodica
Pretende mantener la respuesta de cada neurona dentro de una misma rama y
evitar saltos entre ramas al entrenar los pesos
'''
with torch.no_grad():
if self.is_first:
self.linear.weight.uniform_(-np.sqrt(1 / self.in_features),
np.sqrt(1 / self.in_features ))
else:
self.linear.weight.uniform_(-np.sqrt(6 /self.in_features) / self.omega_0,
np.sqrt(6 / self.in_features) / self.omega_0)
Ahora el modelo de la red es un perceptrón multicapa con capas SIREN (MLP-SIREN).
class SIRENnet(nn.Module):
def __init__(self, input_dim, hidden_dims, output_dim=1):
#super(MLPnet, self).__init__()
super().__init__()
'''
El método init define las capas de las cuales constará el modelo,
aunque no la forma en que se interconectan
'''
# Modelo inicialmente vacio
self.net = []
# Se agragan capas Seno ocultas
is_first=True
for i in range(len(hidden_dims)):
is_last = False if i < len(hidden_dims)-1 else True
self.net.append(SineLayer(in_features = input_dim,
out_features = hidden_dims[i],
is_first = is_first,
is_last = is_last))
input_dim = hidden_dims[i]
is_first=False
self.net = nn.Sequential(*self.net)
def forward(self, x):
out=self.net(x)
return out
def name(self):
return "MLP"
model = SIRENnet(input_dim=(2), hidden_dims=[128,128, 128, 128, 1], output_dim=1)
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
params, 256*256
(50049, 65536)
model = model.to(device)
rec_loss = torch.nn.MSELoss() # L1Loss() #
optimizer = torch.optim.Adam(model.parameters())
num_epochs = 100
loss_list = []
for epoch in range(num_epochs):
loss_epoch =0
# - - - - - - - - - - - - - - -
# Entrena la Red en lotes cada época
# - - - - - - - - - - - - - - -
for i, (x, y) in enumerate(dataloader):
x = x.to(device)
y = y.to(device)
x = Variable(x)
y = Variable(y)
optimizer.zero_grad() # Borra gradiente
y_pred = model(x) # Propagación
loss = rec_loss(y_pred[:,0],y) # Calcula error
loss.backward() # Retropropaga error
optimizer.step() # Actualiza parámetros
loss_epoch += loss.data
# - - - - - - - - - - - - - - -
# Despliega evaluación
# - - - - - - - - - - - - - - -
loss_list.append(loss_epoch/i)
print('Epoch: {:03}/{} Loss: {:.6f}'.format(epoch, num_epochs,loss_list[-1]))
Epoch: 000/100 Loss: 0.562069
Epoch: 001/100 Loss: 0.504904
Epoch: 002/100 Loss: 0.505939
...
Epoch: 098/100 Loss: 0.001744
Epoch: 099/100 Loss: 0.001795
torch.save(model, 'model_implicit.pt')
Y_pred=model(X.to(device))
y_pred = Y_pred.detach().cpu().numpy().reshape(nrows,ncols)
plt.figure(figsize=(10,5))
plt.subplot(121)
plt.imshow(y_pred, 'gray')
plt.axis('off')
plt.subplot(122)
plt.imshow(f, 'gray')
plt.axis('off')
(-0.5, 255.5, 255.5, -0.5)
history_loss = [v.item() for v in loss_list]
fig, ax = plt.subplots()
ax.plot(np.log(history_loss))
ax.set_title('Error de entrenamiento')
ax.set_xlabel('Iteración')
plt.show()