from __future__ import annotations
from .base_vsf import BaseVSF
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Optional, Union
from dataclasses import dataclass, asdict, field
@dataclass
class NeuralVSFConfig:
    """Configuration for a neural VSF model.

    All fields have defaults. Mutable defaults are produced by factories, so
    each instance receives its own fresh containers.
    """

    #: The domain of the vsf, given as ``(lower_corner, upper_corner)``.
    #: Output values are 0 outside of this domain.
    aabb: Tuple[List[float], List[float]] = field(
        default_factory=lambda: ([-1.0] * 3, [1.0] * 3)
    )

    #: The names of the output fields. Each field can have multiple
    #: dimensions as given by the ``output_dims`` attribute.
    output_names: List[str] = field(default_factory=lambda: ["stiffness"])

    #: The dimensions of each output field. The length of this list should
    #: match the length of the ``output_names`` attribute.
    output_dims: List[int] = field(default_factory=lambda: [1])

    #: The number of layers in the neural network.
    num_layers: int = 8

    #: The hidden dimension of the neural network.
    hidden_dim: int = 64

    #: The layers to add skip connections to.
    skip_connection: List[int] = field(default_factory=lambda: [4])

    #: Multiply the network output by a constant factor.
    output_scale: float = 1e4
# borrowed (and modified) from https://github.com/ashawkey/torch-ngp
class FreqEncoder(nn.Module):
    """Frequency (sinusoidal) encoding of the last dimension of a tensor.

    Each input component is multiplied by every frequency band and passed
    through every periodic function. The results are concatenated along the
    last dimension, optionally preceded by the raw input.

    :param input_dim: Size of the last dimension of the tensors to encode.
    :param max_freq_log2: Base-2 logarithm of the highest frequency.
    :param N_freqs: Number of frequency bands.
    :param log_sampling: If True the bands are ``2 ** linspace(0, max_freq_log2)``;
        otherwise they are spaced linearly between ``1`` and ``2 ** max_freq_log2``.
    :param include_input: Whether the raw input is kept as the first block of
        the encoding.
    :param periodic_fns: Callables applied to ``input * freq`` for each band.
    """

    def __init__(self, input_dim, max_freq_log2, N_freqs,
                 log_sampling=True, include_input=True,
                 periodic_fns=(torch.sin, torch.cos)):
        super().__init__()

        self.input_dim = input_dim
        self.include_input = include_input
        self.periodic_fns = periodic_fns

        # Width of the encoding: optional passthrough block plus one block
        # per (frequency, periodic function) pair.
        passthrough_dim = self.input_dim if self.include_input else 0
        encoded_dim = self.input_dim * N_freqs * len(self.periodic_fns)
        self.output_dim = passthrough_dim + encoded_dim

        if log_sampling:
            bands = 2. ** torch.linspace(0., max_freq_log2, N_freqs)
        else:
            bands = torch.linspace(2. ** 0., 2. ** max_freq_log2, N_freqs)

        # Stored as a plain list of Python floats rather than a tensor.
        self.freq_bands = bands.numpy().tolist()

    def forward(self, input, **kwargs):
        """Encode ``input``; extra keyword arguments are accepted and ignored."""
        pieces = [input] if self.include_input else []
        pieces.extend(
            fn(input * freq)
            for freq in self.freq_bands
            for fn in self.periodic_fns
        )
        return torch.cat(pieces, dim=-1)
def get_encoder(encoding, input_dim=3,
                multires=6,
                degree=4,
                num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=2048, align_corners=False,
                **kwargs):
    """Build an input encoder and report the width of its output.

    :param encoding: ``'None'`` (or ``None``) for an identity mapping, or
        ``'frequency'`` for a :class:`FreqEncoder`.
    :param input_dim: Size of the last dimension of the tensors to encode.
    :param multires: Number of frequency bands for the ``'frequency'``
        encoding; the highest band is ``2 ** (multires - 1)``.
    :param degree: Unused by the encodings currently implemented.
    :param num_levels: Unused by the encodings currently implemented.
    :param level_dim: Unused by the encodings currently implemented.
    :param base_resolution: Unused by the encodings currently implemented.
    :param log2_hashmap_size: Unused by the encodings currently implemented.
    :param desired_resolution: Unused by the encodings currently implemented.
    :param align_corners: Unused by the encodings currently implemented.
    :param kwargs: Accepted and ignored.
    :returns: ``(encoder, output_dim)`` where ``encoder`` is a callable.
    :raises NotImplementedError: If ``encoding`` is not a supported mode.
    """
    if encoding is None or encoding == 'None':
        # Identity encoding: the callable tolerates (and ignores) keyword
        # arguments so it can be called like the module-based encoders.
        return lambda x, **kwargs: x, input_dim

    if encoding == 'frequency':
        encoder = FreqEncoder(input_dim=input_dim, max_freq_log2=multires - 1,
                              N_freqs=multires, log_sampling=True)
        return encoder, encoder.output_dim

    # NOTE: a hash-grid encoding would require the external ``gridencoder``
    # module and is not enabled here, so it is reported as unsupported.
    raise NotImplementedError(
        f"Unknown encoding mode {encoding!r}, choose from ['None', 'frequency']"
    )
class VSFNetwork(nn.Module):
    """MLP mapping 3D points to positive field values inside a bounding box.

    Query points are normalized so that the bounding box maps to
    ``[-1, 1]^3``. Points strictly inside the box (and, when an SDF volume is
    given, with a negative interpolated SDF value) are frequency-encoded and
    passed through a bias-free MLP; the result is exponentiated and multiplied
    by ``output_scale``. All other points evaluate to zero.

    The auxiliary tensors ``aabb``, ``center``, ``scale`` and ``sdf`` are
    registered as non-persistent buffers: they follow the module through
    ``to()``, ``cuda()`` and ``cpu()`` but are not part of ``state_dict()``.
    They are stored explicitly by :meth:`save`.

    :param num_layers: Number of linear layers.
    :param hidden_dim: Width of the hidden layers.
    :param skip_connection: Indices of layers whose input is concatenated
        with the encoded position.
    :param output_names: Names of the output fields (stored, not otherwise
        used by this class).
    :param output_dims: Dimension of each output field; their sum is the
        width of the network output.
    :param aabb: ``(lower_corner, upper_corner)`` of the domain. Defaults to
        ``[-1, 1]^3`` with a printed warning.
    :param sdf: Optional SDF tensor used as a geometry mask. It is sampled
        with ``grid_sample`` as a single-channel volume, so it is assumed to
        be a 3-D tensor.
    :param output_scale: Constant factor applied to the network output.
    :param kwargs: Accepted and ignored.
    """

    def __init__(self,
                 num_layers=8,
                 hidden_dim=64,
                 skip_connection=[4],
                 output_names=['stiffness'],
                 output_dims=[1],
                 aabb=None,
                 sdf=None,
                 output_scale=1e4,
                 **kwargs
                 ):
        super(VSFNetwork, self).__init__()

        if aabb is None:
            print("Warning: AABB not provided, using default AABB")
            aabb = [[-1., -1., -1.], [1., 1., 1.]]
        aabb = torch.tensor(aabb, dtype=torch.float32)

        # Non-persistent buffers move with the module on any device/dtype
        # change while leaving the checkpointed state_dict layout untouched.
        self.register_buffer('aabb', aabb, persistent=False)
        self.register_buffer('center', .5 * (aabb[0] + aabb[1]), persistent=False)
        self.register_buffer('scale', (aabb[1] - aabb[0]) / 2., persistent=False)
        self.register_buffer('sdf', sdf, persistent=False)

        # sigma network
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        # Copy the list arguments so the instance never aliases the shared
        # default lists in the signature (or the caller's lists).
        self.skip_connection = list(skip_connection)
        self.output_names = list(output_names)
        self.output_dims = list(output_dims)
        self.output_dim = sum(self.output_dims)

        self.encoder, self.in_dim = get_encoder('frequency', input_dim=3, multires=4)
        self.output_scale = output_scale

        layers = []
        for l in range(num_layers):
            in_dim = self.in_dim if l == 0 else hidden_dim
            if l in self.skip_connection:
                # Skip layers also receive the encoded position.
                in_dim += self.in_dim
            out_dim = self.output_dim if l == num_layers - 1 else hidden_dim
            layers.append(nn.Linear(in_dim, out_dim, bias=False))
        self.sigma_net = nn.ModuleList(layers)

    def _sample_sdf(self, x):
        """Interpolate the SDF volume at normalized coordinates ``x``.

        ``x`` has shape ``(..., 3)`` in the normalized ``[-1, 1]^3`` frame;
        the result has shape ``x.shape[:-1]``.
        """
        # grid_sample addresses the last spatial axis of the volume with the
        # first grid coordinate, so the coordinate order is reversed to make
        # the first component of x index the first axis of the SDF tensor.
        grid = x.flip(dims=(-1,)).reshape(1, -1, 1, 1, 3)
        values = F.grid_sample(self.sdf[None, None, ...], grid,
                               mode='bilinear', align_corners=True)
        return values.reshape(x.shape[:-1])

    def _inside_mask(self, x):
        """Boolean mask of normalized points that should be evaluated."""
        coords = x[..., :3]
        mask = ((coords > -1) & (coords < 1)).all(dim=-1)
        if self.sdf is not None:
            mask = mask & (self._sample_sdf(x) < 0)
        return mask

    def forward(self, x):
        """Evaluate the field at world-space points ``x`` of shape ``(..., 3)``.

        :returns: Tensor of shape ``(..., output_dim)``; rows for points
            outside the domain (or outside the SDF geometry) are zero.
        """
        x = (x - self.center) / self.scale
        inside = self._inside_mask(x)

        # Only the points inside the domain are run through the network.
        feats = self.encoder(x[inside])
        h = feats
        last = self.num_layers - 1
        for l, layer in enumerate(self.sigma_net):
            if l in self.skip_connection:
                h = torch.cat([h, feats], dim=-1)
            h = layer(h)
            if l != last:
                h = F.relu(h, inplace=True)
        sigma = torch.exp(h)

        # Allocate with the network's dtype/device and the full output width
        # so multi-dimensional outputs can be scattered back.
        out = sigma.new_zeros((*x.shape[:-1], self.output_dim))
        out[inside] = sigma
        return out * self.output_scale

    def get_sdf(self, x):
        """Return the interpolated SDF value at world-space points ``x``."""
        assert self.sdf is not None, "SDF not provided"
        return self._sample_sdf((x - self.center) / self.scale)

    # optimizer utils
    def get_params(self, lr):
        """Return optimizer parameter groups, all using learning rate ``lr``."""
        return [
            {'params': self.encoder.parameters(), 'lr': lr},
            {'params': self.sigma_net.parameters(), 'lr': lr},
        ]

    def save(self, path):
        """Write the weights and the auxiliary tensors to ``path``."""
        save_dict = {
            "save_dict": self.state_dict(),
            "aabb": self.aabb,
            "center": self.center,
            "scale": self.scale,
            "sdf": self.sdf,
            "output_scale": self.output_scale
        }
        torch.save(save_dict, path)

    def load(self, path):
        """Restore a checkpoint written by :meth:`save` and switch to eval mode."""
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints that come from a trusted source.
        load_dict = torch.load(path, map_location=self.device)
        self.load_state_dict(load_dict["save_dict"])
        self.aabb = load_dict["aabb"]
        self.center = load_dict["center"]
        self.scale = load_dict["scale"]
        self.sdf = load_dict["sdf"]
        self.output_scale = load_dict["output_scale"]
        self.eval()

    @property
    def device(self):
        """Device of the first network parameter."""
        return next(self.parameters()).device
# Public wrapper exposing the network through the BaseVSF interface.
class NeuralVSF(BaseVSF):
    """
    A Neural VSF model that conforms to the BaseVSF base class.

    Optionally takes an SDF tensor that restricts the field to the region
    where the interpolated SDF value is negative.

    :param vsfConfig: Configuration for the Neural VSF model.
    :type vsfConfig: NeuralVSFConfig
    :param sdf: Optional SDF tensor to use as a geometry mask.
                The SDF values are sampled at the center of the grid cells.
                Cell (0,0,0) is at `BBox[0]` and cell `(N-1, N-1, N-1)` is at `BBox[1]`.
    :type sdf: torch.Tensor, optional
    """

    def __init__(self, vsfConfig : NeuralVSFConfig, sdf : Optional[torch.Tensor] = None):
        super().__init__()
        self.config = vsfConfig
        # The config fields are forwarded as keyword arguments to the network.
        self.vsfNetwork = VSFNetwork(sdf=sdf, **asdict(vsfConfig))

    def getBBox(self) -> torch.Tensor:
        """Return the configured bounding box as a ``(2, 3)`` tensor."""
        # NOTE(review): this reads the config, not the network; after load()
        # the network's own aabb may differ from the config -- confirm intent.
        return torch.tensor(self.config.aabb)

    def getStiffness(self, position: torch.Tensor) -> torch.Tensor:
        """Evaluate the stiffness field at ``position`` of shape ``(..., 3)``.

        :returns: Tensor of shape ``position.shape[:-1]``.
        """
        assert self.config.output_names[0] == 'stiffness',"We only support stiffness as the first output"
        assert self.config.output_dims[0] == 1, "The stiffness output must be one-dimensional"
        # Index the last axis so any number of leading batch dimensions
        # (including none) is handled, matching what the network accepts.
        return self.vsfNetwork(position)[..., 0]

    def save(self, path):
        """Save the underlying network to ``path``."""
        self.vsfNetwork.save(path)

    def load(self, path):
        """Load the underlying network from ``path``."""
        self.vsfNetwork.load(path)

    def to(self, device) -> NeuralVSF:
        """Converts the VSF to a given device or dtype"""
        self.vsfNetwork.to(device)
        return self

    @property
    def device(self):
        """Device of the first network parameter."""
        return next(self.vsfNetwork.parameters()).device

    @property
    def dtype(self):
        """Data type of the first network parameter."""
        return next(self.vsfNetwork.parameters()).dtype