Paper Reproduction: Fourier-Feature INRs for Recovering High-Frequency Image Detail

Abstract: Notes on INRs and Fourier feature enhancement, with reproduction code (PyTorch version).



Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains

Basic Idea

Background

Kernel Regression

One of the classic nonlinear regression methods. For any input x, an estimate is built from the kernel evaluated against the training points: f̂(x) = Σ_i (K⁻¹y)_i k(x_i, x)
K is the n×n kernel (Gram) matrix of pairwise inner products, a symmetric matrix with K_ij = k(x_i, x_j)
This resembles self-attention: the similarity between each x_i and x weights the corresponding y_i, and the weighted contributions are summed (a toy sketch follows below)
The derivation that comes after this belongs to NTK theory: https://www.cnblogs.com/manuscript-of-nomad/p/17243296.html
Taking a shortcut here: the linked blog explains NTK; come back to it when something stops making sense. As an analysis tool for network outputs, NTK does look rather good in a paper.
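As a toy example, here is a minimal numpy sketch of kernel regression; the Gaussian kernel, its bandwidth, and the small ridge term are my own illustrative choices, not from the paper:

PYTHON
import numpy as np

# Toy 1D kernel regression: f_hat(x) = sum_i (K^-1 y)_i k(x_i, x)
rng = np.random.default_rng(0)
xi = np.linspace(0, 1, 20)                                   # training inputs
yi = np.sin(2 * np.pi * xi) + 0.1 * rng.standard_normal(20)  # noisy targets

k = lambda a, b: np.exp(-(a[:, None] - b[None, :]) ** 2 / 0.01)  # Gaussian kernel
K = k(xi, xi)                                        # Gram matrix, K_ij = k(x_i, x_j)
alpha = np.linalg.solve(K + 1e-6 * np.eye(20), yi)   # (K^-1 y), tiny ridge for stability

x = np.linspace(0, 1, 200)
f_hat = k(x, xi) @ alpha                             # similarity-weighted sum over targets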

Fourier Feature Enhancement

The paper proposes a simple but effective method: map the input into a high-dimensional Fourier feature space so the network can capture high-frequency detail. The mapping uses random Fourier features, passing the input through sine and cosine functions.
Effect: better learning of high-frequency signal. Conventional networks often underperform on high-frequency content because they are biased toward learning low-frequency components first. With Fourier features, the network preserves the original signal structure while capturing and representing high-frequency detail far better, improving overall quality.
In short, the random Fourier mapping in front of the network boosts the image's high-frequency content, letting the MLP (otherwise a poor fit for low-dimensional regression) learn much faster.
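To make the mapping concrete, a minimal sketch of γ(v) on a single 2D coordinate (the scale of 10 is an arbitrary choice here):

PYTHON
import torch

v = torch.tensor([0.25, 0.75])          # one (x, y) coordinate in the unit square
B = torch.randn(128, 2) * 10.0          # random Gaussian B at scale 10
angles = 2 * torch.pi * (B @ v)
gamma = torch.cat([torch.sin(angles), torch.cos(angles)])
print(gamma.shape)                      # torch.Size([256]): a 2D input becomes a 256D feature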

Code Reproduction

Basic Code

I skip the data-loading part for now (handling of MRI data is added later).
First, a walkthrough of the basic model structure. Most of the code comes from the original authors' jupyter notebook; the original is fairly old jax, which I ported to PyTorch, here with a CNN architecture (the original uses an MLP); the final results are similar. Bless everyone who publishes a jupyter notebook.

The code below implements γ(v). B is sampled from a Gaussian distribution (Gaussian noise) with shape m×d.
x and y are, at this point, just coordinate ramps over the image plane; n.b. without the B transform, the raw x, y ramps render as nothing more than a diagonal gradient pattern. γ(v) = [cos(2πBv), sin(2πBv)]

PYTHON
# Fourier feature mapping: gamma(v) = [sin(2*pi*Bv), cos(2*pi*Bv)]
class GaussianFourierFeatures(nn.Module):
    def __init__(self, num_input_channels, mapping_size=256, scale=10, B=None):
        super().__init__()
        self._num_input_channels = num_input_channels
        self._mapping_size = mapping_size
        # Sample B from a Gaussian unless a pre-scaled matrix is supplied
        if B is None:
            self._B = torch.randn((num_input_channels, mapping_size)) * scale
        else:
            self._B = B.t().float()  # passed in as (mapping_size, num_input_channels)

    def forward(self, x):
        assert x.dim() == 4, 'Expected 4D input (got {}D input)'.format(x.dim())

        batches, channels, width, height = x.shape

        assert channels == self._num_input_channels, \
            "Expected input to have {} channels (got {} channels)".format(self._num_input_channels, channels)

        # Make shape compatible for matmul with _B:
        # from [B, C, W, H] to [(B*W*H), C]
        x = x.permute(0, 2, 3, 1).reshape(batches * width * height, channels)
        x = x @ self._B.to(x.device)

        # From [(B*W*H), mapping_size] back to [B, mapping_size, W, H]
        x = x.view(batches, width, height, self._mapping_size)
        x = x.permute(0, 3, 1, 2)

        x = 2 * np.pi * x
        return torch.cat([torch.sin(x), torch.cos(x)], dim=1)
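A quick shape check under assumed sizes: a batch of 2-channel coordinate grids comes out with 2 × mapping_size feature channels.

PYTHON
# Hypothetical sizes, just to verify the shape contract
ff = GaussianFourierFeatures(num_input_channels=2, mapping_size=128, scale=10)
coords = torch.rand(1, 2, 64, 64)   # [B, C, W, H] grid of (x, y) coordinates
print(ff(coords).shape)             # torch.Size([1, 256, 64, 64])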

Next, set up training.
First, build the network that sits behind the INR encoding and hand it to model.
The code here is heavily reworked: I set up a no-encoding baseline for comparison, run the two network variants, and compare convergence under different B scales. The curves are recorded in the psnr lists, generated images go into pred_imgs, and an axis of iteration indices is kept in xs.

PYTHON
# train in torch Model
def train_model(network_size, learning_rate, iters, B, train_data, test_data):
    # 'none' and 'basic' use the plain network; a Gaussian B uses the Fourier-feature path.
    # 'B is np.eye(2)' is always False (fresh array identity), so compare by value instead.
    if B is None or (isinstance(B, np.ndarray) and np.array_equal(B, np.eye(2))):
        model = basic_network(*network_size).to(device)
        x = normal_CNN(2, 128, B)(train_data[0])
    else:
        model = make_network(*network_size).to(device)
        x = GaussianFourierFeatures(2, 128, B=B)(train_data[0])  # reuse the pre-scaled B for this run
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    train_psnrs = []
    test_psnrs = []
    pred_imgs = []
    xs = []
    start_time = time.time()

    for epoch in tqdm(range(iters), desc='train iter', leave=False):
        optimizer.zero_grad()
        generated = model(x)
        loss = nn.functional.mse_loss(generated, train_data[1])
        loss.backward()
        optimizer.step()

        if epoch % 30 == 0:
            test_loss = nn.functional.mse_loss(generated, test_data[1])
            # PSNR for images in [0, 1]: -10 * log10(MSE)
            train_psnrs.append(-10. * np.log10(loss.item()))
            test_psnrs.append(-10. * np.log10(test_loss.item()))
            pred_imgs.append(generated[0].detach())
            print('\nEpoch %d, loss = %.03f' % (epoch, float(loss)))
            xs.append(epoch)

    return {
        'state': model.state_dict(),
        'train_psnrs': train_psnrs,
        'test_psnrs': test_psnrs,
        'pred_imgs': torch.stack(pred_imgs),
        'xs': xs,  # the comma here was missing and caused a SyntaxError
        'time': time.time() - start_time
    }

Then, before launching training and, importantly, before displaying anything, define a helper that converts tensors to CPU numpy images (matplotlib renders from CPU arrays).

PYTHON
# Convert a [C, W, H] tensor in [0, 1] to a uint8 HWC numpy image for display
def tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
    tensor = (tensor * 255).clamp(0, 255)
    return tensor.type(torch.uint8).permute(1, 2, 0).cpu().numpy()

# This should take about 10 minutes
# start your model
outputs = {}
for k in tqdm(B_dict):
  outputs[k] = train_model(network_size, learning_rate, iters, B_dict[k], train_data, test_data)

# Show final network outputs

plt.figure(figsize=(24, 4))
N = len(outputs)
for i, k in enumerate(outputs):
    img = tensor_to_numpy(outputs[k]['pred_imgs'][-1])
    print(img.shape)
    plt.subplot(1, N+1, i+1)
    plt.imshow(img)
    plt.title(k)
plt.subplot(1, N+1, N+1)
plt.imshow(tensor_to_numpy(test_data[1][0]))  # the actual target image, not the last prediction
plt.title('GT')
plt.show()

# Plot train/test error curves

plt.figure(figsize=(16, 6))

plt.subplot(121)
for i, k in enumerate(outputs):
    plt.plot(outputs[k]['xs'], outputs[k]['train_psnrs'], label=k)
plt.title('Train error')
plt.ylabel('PSNR')
plt.xlabel('Training iter')
plt.legend()

plt.subplot(122)
for i, k in enumerate(outputs):
    plt.plot(outputs[k]['xs'], outputs[k]['test_psnrs'], label=k)
plt.title('Test error')
plt.ylabel('PSNR')
plt.xlabel('Training iter')
plt.legend()

plt.show()

The remaining setup goes at the very top of the script, handling initialization and image loading:

PYTHON
# Get an image that will be the target for our model.
img = torch.tensor(get_image()).unsqueeze(0).permute(0, 3, 1, 2).float().contiguous().to(device)
# Download image, take a square crop from the center
# Create input pixel coordinates in the unit square
# (In the original demo the training grid is the test grid downsampled with [::2]; here they share the full grid)
coords = np.linspace(0, 1, img.shape[2], endpoint=False)
xy_grid = np.stack(np.meshgrid(coords, coords), -1)
x_test_tensor = torch.tensor(xy_grid).unsqueeze(0).permute(0, 3, 1, 2).float().contiguous().to(device)
test_data = [x_test_tensor, img]
train_data = [x_test_tensor, img]

learning_rate = 1e-4
iters = 600
network_size = (4, 256)

mapping_size = 128
# set B matrix
B_dict = {}
# Standard network - no mapping
B_dict['none'] = None
# Basic mapping
B_dict['basic'] = np.eye(2)
B_gauss = torch.randn(mapping_size, 2)
for scale in [1., 10., 100.]:
# for scale in [100.]:
    B_dict[f'gauss_{scale}'] = B_gauss * scale
  
# (run after training) compare wall-clock time per model
for i, k in enumerate(outputs):
    plt.bar(k, outputs[k]['time'], label=k)
plt.title('Time taken for each operation')
plt.ylabel('time')
plt.xlabel('model')
plt.legend()

plt.show()

MLP Source Code

The code above fits with a convolutional network; even with the kernel size set to 1, which makes each layer act like a per-pixel linear map, the mechanics still differ a little from the paper's MLP. Below is a quick check of that claim, then the PyTorch MLP version of the INR.
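A minimal sketch (arbitrary sizes) showing that a 1×1 convolution with copied weights reproduces a per-pixel nn.Linear exactly:

PYTHON
import torch
import torch.nn as nn

x = torch.rand(1, 256, 32, 32)                    # [B, C, W, H] feature map
conv = nn.Conv2d(256, 3, kernel_size=1)           # 1x1 convolution
lin = nn.Linear(256, 3)
with torch.no_grad():
    lin.weight.copy_(conv.weight.view(3, 256))    # share the same weights and bias
    lin.bias.copy_(conv.bias)

out_conv = conv(x)
out_lin = lin(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
print(torch.allclose(out_conv, out_lin, atol=1e-5))  # True: identical per-pixel linear maps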

PYTHON
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import imageio
import time


# Setup the input and target for our model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Convert a [W, H, 3] tensor in [0, 1] to a uint8 numpy image for display
def tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
    tensor = (tensor * 255).clamp(0, 255)
    return tensor.type(torch.uint8).cpu().numpy()

def get_image():
    image_url = 'https://live.staticflickr.com/7492/15677707699_d9d67acf9d_b.jpg'
    img = imageio.v3.imread(image_url)[..., :3] / 255.
    print('img_init size:', img.shape)
    c = [img.shape[0] // 2, img.shape[1] // 2]
    r = 128
    img = img[c[0] - r:c[0] + r, c[1] - r:c[1] + r]

    return img


# Get an image that will be the target for our model.
img = torch.tensor(get_image()).float().to(device)
# Download image, take a square crop from the center
# Create input pixel coordinates in the unit square
# (In the original demo the training grid is the test grid downsampled with [::2]; here they share the full grid)
coords = torch.linspace(0, 1, img.shape[0], dtype=torch.float)
xy_grid = torch.stack(torch.meshgrid(coords, coords, indexing='ij'), -1).to(device)  # explicit indexing silences the warning
# x_test_tensor = torch.tensor(xy_grid).unsqueeze(0).permute(0, 3, 1, 2).float().contiguous().to(device)
test_data = [xy_grid, img]
train_data = [xy_grid, img]
print('train_data:', train_data[0].shape)


# Fourier feature mapping
class GaussianFourierFeatures(nn.Module):
    def __init__(self, num_input_channels, mapping_size=256, scale=10, B=None):
        super().__init__()
        self._num_input_channels = num_input_channels
        self._mapping_size = mapping_size
        # Sample B from a Gaussian unless a pre-scaled matrix is supplied
        if B is None:
            self._B = torch.randn((num_input_channels, mapping_size)) * scale
        else:
            self._B = B.t().float()  # passed in as (mapping_size, num_input_channels)

    def forward(self, x):
        assert x.dim() == 3, 'Expected 3D input (got {}D input)'.format(x.dim())
        width, height, channels = x.shape

        assert channels == self._num_input_channels, \
            "Expected input to have {} channels (got {} channels)".format(self._num_input_channels, channels)

        # Make shape compatible for matmul with _B.
        # print(x.shape, self._B.shape)
        x = x @ self._B.to(x.device)
        x = 2 * torch.pi * x
        return torch.cat([torch.sin(x), torch.cos(x)], dim=-1)


# Despite the name, this is just the plain/basic coordinate encoding (no convolution involved)
class normal_CNN(nn.Module):
    def __init__(self, num_input_channels, mapping_size, B):
        super().__init__()
        # 'basic' mapping uses the identity B; 'none' (B=None) skips the encoding entirely
        self.B = torch.eye(2) if B is not None else None
        self._num_input_channels = num_input_channels
        self._mapping_size = mapping_size

    def forward(self, x):
        if self.B is not None:
            assert x.dim() == 3, 'Expected 3D input (got {}D input)'.format(x.dim())
            width, height, channels = x.shape
            assert channels == self._num_input_channels, \
                "Expected input to have {} channels (got {} channels)".format(self._num_input_channels, channels)

            x = x @ self.B.to(x.device)  # was self._B, an attribute that doesn't exist on this class
            x = 2 * torch.pi * x
            return torch.cat([torch.sin(x), torch.cos(x)], dim=-1)  # cat on the feature dim, not dim=1
        else:
            return x


learning_rate = 1e-4
iters = 500
network_size = (4, 256)

mapping_size = 128
# set B matrix
B_dict = {}
# Standard network - no mapping
B_dict['none'] = None
# Basic mapping
B_dict['basic'] = np.eye(2)
B_gauss = torch.randn(mapping_size, 2)
for scale in [1., 10., 100.]:
# for scale in [100.]:
    B_dict[f'gauss_{scale}'] = B_gauss * scale

def make_network(num_layers, num_channels):
    # Input width must match the encoder output (2 * mapping_size = num_channels here)
    layers = []
    for i in range(num_layers - 1):
        layers.append(nn.Linear(num_channels, num_channels))
        layers.append(nn.ReLU())
    layers.append(nn.Linear(num_channels, 3))
    layers.append(nn.Sigmoid())
    return nn.Sequential(*layers).to(device)


def basic_network(num_layers, num_channels, in_features=2):
    # in_features: 2 for raw coordinates ('none'), 4 for the identity-B 'basic' mapping
    layers = []
    for i in range(num_layers - 1):
        layers.append(nn.Linear(num_channels if i > 0 else in_features, num_channels))
        layers.append(nn.ReLU())
    layers.append(nn.Linear(num_channels, 3))
    layers.append(nn.Sigmoid())
    return nn.Sequential(*layers).to(device)

def train_model(network_size, learning_rate, iters, B, train_data, test_data):
    # Encode the coordinates first, then size the network's first layer to match.
    # Note 'B is np.eye(2)' is always False (fresh array identity), so compare by value.
    if B is None or (isinstance(B, np.ndarray) and np.array_equal(B, np.eye(2))):
        x = normal_CNN(2, 128, B)(train_data[0])
        model = basic_network(*network_size, in_features=x.shape[-1])
    else:
        # Reuse the pre-scaled B for this run instead of resampling at a fixed scale of 10
        x = GaussianFourierFeatures(2, 128, B=B)(train_data[0])
        model = make_network(*network_size)

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    train_psnrs = []
    test_psnrs = []
    pred_imgs = []
    xs = []
    start_time = time.time()

    for epoch in tqdm(range(iters), desc='train iter', leave=False):
        optimizer.zero_grad()
        generated = model(x)
        loss = nn.functional.mse_loss(generated, train_data[1])
        loss.backward()
        optimizer.step()

        if epoch % 25 == 0:
            test_loss = nn.functional.mse_loss(generated, test_data[1])
            # PSNR for images in [0, 1]: -10 * log10(MSE)
            train_psnrs.append(-10. * np.log10(loss.item()))
            test_psnrs.append(-10. * np.log10(test_loss.item()))
            pred_imgs.append(generated.detach())
            print('\nEpoch %d, loss = %.03f' % (epoch, float(loss)))
            xs.append(epoch)

    return {
        'state': model.state_dict(),
        'train_psnrs': train_psnrs,
        'test_psnrs': test_psnrs,
        'pred_imgs': torch.stack(pred_imgs),
        'xs': xs,
        'time': time.time() - start_time
    }

# This should take about 2-3 minutes
outputs = {}
for k in tqdm(B_dict):
  outputs[k] = train_model(network_size, learning_rate, iters, B_dict[k], train_data, test_data)

# Show final network outputs

plt.figure(figsize=(24, 4))
N = len(outputs)
for i, k in enumerate(outputs):
    img = tensor_to_numpy(outputs[k]['pred_imgs'][-1])
    print(img.shape)
    plt.subplot(1, N+1, i+1)
    plt.imshow(img)
    plt.title(k)
plt.subplot(1, N+1, N+1)
plt.imshow(tensor_to_numpy(test_data[1]))  # the actual target image, not the last prediction
plt.title('GT')
plt.show()

# Plot train/test error curves

plt.figure(figsize=(16, 6))

plt.subplot(121)
for i, k in enumerate(outputs):
    plt.plot(outputs[k]['xs'], outputs[k]['train_psnrs'], label=k)
plt.title('Train error')
plt.ylabel('PSNR')
plt.xlabel('Training iter')
plt.legend()

plt.subplot(122)
for i, k in enumerate(outputs):
    plt.plot(outputs[k]['xs'], outputs[k]['test_psnrs'], label=k)
plt.title('Test error')
plt.ylabel('PSNR')
plt.xlabel('Training iter')
plt.legend()

plt.show()


for i, k in enumerate(outputs):
    plt.bar(k, outputs[k]['time'], label=k)
plt.title('Time taken for each operation')
plt.ylabel('time')
plt.xlabel('model')
plt.legend()

plt.show()

This runs much faster than the CNN version, using nothing but a multilayer perceptron, and the results are similar to the authors' published ones. It does support the argument that this Fourier-feature kernel suits high-frequency recovery in INRs, most visibly with a Gaussian scale around 10.


Compared against the code the authors posted, behavior is essentially the same, with slight differences in the curves.
The original authors' jupyter: https://colab.research.google.com/github/tancik/fourier-feature-networks/blob/master/Demo.ipynb#scrollTo=BwbEtsn0gB2h

3D MRI Code Reproduction

Two of these packages are for medical imaging; install them (the usual deep-learning packages are assumed to be installed already):

pip install livelossplot
pip install phantominator
# also install ipython and ipywidgets
pip install ipython
pip install ipywidgets

First, a pass over the helper functions:
The first function, according to its docstring, computes cell boundaries under reflecting boundary conditions (in practice it simply builds an array of breakpoints).
The second builds on the first. Copilot's explanation of its arguments:
first_breaks: boundaries of the first set of intervals, shape (N+1,); must be a non-decreasing sequence marking where each interval starts and ends.
second_breaks: boundaries of the second set of intervals, shape (M+1,); likewise non-decreasing.
It returns an array of shape (N, M) giving the size of the overlap between every pair of intervals.
The third function builds the local-mean weights for resizing:
old_size: the original size.
new_size: the new size.
reflect: whether to apply reflecting boundary conditions; if True, mirror at the borders, otherwise do not.

PYTHON
from typing import Iterable, Tuple

def _reflect_breaks(size: int) -> np.ndarray:
  """Calculate cell boundaries with reflecting boundary conditions."""
  result = np.concatenate([[0], 0.5 + np.arange(size - 1), [size - 1]])
  assert len(result) == size + 1
  return result

# 计算两组间隔(或区间)之间的重叠距离
def _interval_overlap(first_breaks: np.ndarray,
                      second_breaks: np.ndarray) -> np.ndarray:
  """Return the overlap distance between all pairs of intervals.

  Args:
    first_breaks: breaks between entries in the first set of intervals, with
      shape (N+1,). Must be a non-decreasing sequence.
    second_breaks: breaks between entries in the second set of intervals, with
      shape (M+1,). Must be a non-decreasing sequence.

  Returns:
    Array with shape (N, M) giving the size of the overlapping region between
    each pair of intervals.
  """
  first_upper = first_breaks[1:]
  second_upper = second_breaks[1:]
  upper = np.minimum(first_upper[:, np.newaxis], second_upper[np.newaxis, :])

  first_lower = first_breaks[:-1]
  second_lower = second_breaks[:-1]
  lower = np.maximum(first_lower[:, np.newaxis], second_lower[np.newaxis, :])

  return np.maximum(upper - lower, 0)

  
def _resize_weights(
    old_size: int, new_size: int, reflect: bool = False) -> np.ndarray:
  """Create a weight matrix for resizing with the local mean along an axis.

  Args:
    old_size: old size.
    new_size: new size.
    reflect: whether or not there are reflecting boundary conditions.

  Returns:
    NumPy array with shape (new_size, old_size). Rows sum to 1.
  """
  if not reflect:
    old_breaks = np.linspace(0, old_size, num=old_size + 1)
    new_breaks = np.linspace(0, old_size, num=new_size + 1)
  else:
    old_breaks = _reflect_breaks(old_size)
    new_breaks = (old_size - 1) / (new_size - 1) * _reflect_breaks(new_size)

  weights = _interval_overlap(new_breaks, old_breaks)
  weights /= np.sum(weights, axis=1, keepdims=True)
  assert weights.shape == (new_size, old_size)
  return weights

def resize(array: np.ndarray,
           shape: Tuple[int, ...],
           reflect_axes: Iterable[int] = ()) -> np.ndarray:
  """Resize an array with the local mean / bilinear scaling.

  Works for both upsampling and downsampling in a fashion equivalent to
  block_mean and zoom, but allows for resizing by non-integer multiples. Prefer
  block_mean and zoom when possible, as this implementation is probably slower.

  Args:
    array: array to resize.
    shape: shape of the resized array.
    reflect_axes: iterable of axis numbers with reflecting boundary conditions,
      mirrored over the center of the first and last cell.

  Returns:
    Array resized to shape.

  Raises:
    ValueError: if any values in reflect_axes fall outside the interval
      [-array.ndim, array.ndim).
  """
  reflect_axes_set = set()
  for axis in reflect_axes:
    if not -array.ndim <= axis < array.ndim:
      raise ValueError('invalid axis: {}'.format(axis))
    reflect_axes_set.add(axis % array.ndim)

  output = array
  for axis, (old_size, new_size) in enumerate(zip(array.shape, shape)):
    reflect = axis in reflect_axes_set
    weights = _resize_weights(old_size, new_size, reflect=reflect)
    product = np.tensordot(output, weights, [[axis], [-1]])
    output = np.moveaxis(product, -1, axis)
  return output
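For instance, a quick sanity check of resize on a small array: downsampling 4×4 to 2×2 should average each 2×2 block.

PYTHON
a = np.arange(16, dtype=float).reshape(4, 4)
print(resize(a, (2, 2)))
# [[ 2.5  4.5]
#  [10.5 12.5]]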

This part needs essentially no modification; it is all image preprocessing.
Next comes the model: a very simple MLP, whose depth and width are later passed in as the layer count and channel width.
The core processing is the same as before: the input coordinates are mapped through sin/cos features built from a Gaussian-sampled B (a random Fourier encoding, not an FFT), which speeds up the network's convergence.

PYTHON
# Model and training code
network_depth = 4 #@param
network_width = 96 #@param

input_encoder = lambda x, a, b: np.concatenate([a * np.sin((2.*np.pi*x) @ b.T), 
                                                a * np.cos((2.*np.pi*x) @ b.T)], axis=-1) #/ np.linalg.norm(a) * np.sqrt(a.shape[0])


def make_network(num_layers, num_channels):
  layers = []
  for i in range(num_layers - 1):
    layers.append(nn.Linear(num_channels, num_channels))
    layers.append(nn.ReLU())
  layers.append(nn.Linear(num_channels, 1))
  layers.append(nn.Sigmoid())
  return nn.Sequential(*layers).to(device)

# run_model keeps the expression style of the author's original jax code (apply_fn, params);
# a literal torch rewrite of it reads poorly, and the same goes for the fft transform.
# Refer to the original author's train_model function for context.
def run_model(params, x, avals, bvals):
    if avals is not None:
        x = input_encoder(x, avals, bvals)
    return np.reshape(apply_fn(params, np.reshape(x, (-1, x.shape[-1]))), (x.shape[0], x.shape[1], x.shape[2]))
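A quick numpy check of what input_encoder produces for a batch of 3D coordinates (the sizes of avals/bvals here are my own placeholder choices):

PYTHON
m = 32
avals = np.ones(m)                    # per-frequency amplitudes a_i
bvals = np.random.randn(m, 3)         # per-frequency directions b_i
pts = np.random.rand(10, 3)           # 10 coordinates in the unit cube
print(input_encoder(pts, avals, bvals).shape)   # (10, 64): 2*m features per point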

The paper's formulation, roughly:
Input encoding: γ(v) = (a_1 sin(2π b_1ᵀ v), a_1 cos(2π b_1ᵀ v), …, a_m sin(2π b_mᵀ v), a_m cos(2π b_mᵀ v))
Induced kernel: k_γ(v1, v2) = Σ_{i=1..m} a_i² cos(2π b_iᵀ (v1 − v2))
This kernel is then plugged into NTK theory for analysis; that is essentially the whole core idea. A quick numerical check of the kernel identity follows below.
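The identity is easy to verify numerically, reusing input_encoder from above: the dot product of two encodings collapses via sin·sin + cos·cos = cos(difference) into exactly this stationary kernel.

PYTHON
# Check: gamma(v1) . gamma(v2) == sum_i a_i^2 cos(2*pi*b_i^T (v1 - v2))
m = 64
a = np.random.rand(m)
b = np.random.randn(m, 3)
v1, v2 = np.random.rand(3), np.random.rand(3)

g1 = input_encoder(v1[None], a, b)[0]
g2 = input_encoder(v2[None], a, b)[0]
lhs = g1 @ g2
rhs = np.sum(a ** 2 * np.cos(2 * np.pi * b @ (v1 - v2)))
print(np.allclose(lhs, rhs))   # True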

Reference

https://github.com/tancik/fourier-feature-networks
https://www.cnblogs.com/manuscript-of-nomad/p/17243296.html