122 lines
3.8 KiB
Python
122 lines
3.8 KiB
Python
from typing import List
|
|
|
|
from torch import nn
|
|
|
|
import jsparse
|
|
from jsparse import SparseTensor
|
|
from jsparse import nn as spnn
|
|
|
|
from jsparse.nn import SparseConvBlock, SparseConvTransposeBlock, SparseResBlock
|
|
|
|
# Public API: only the SparseResUNet42 preset is exported by ``import *``;
# the generic SparseResUNet base class is not listed here.
__all__ = ['SparseResUNet42']
|
|
|
|
|
|
class SparseResUNet(nn.Module):
    """Sparse residual U-Net assembled from ``jsparse`` building blocks.

    The network consists of a two-convolution stem followed by ``N`` encoder
    stages and ``N`` decoder stages, where ``N == len(encoder_channels)``.
    Each encoder stage is a stride-2 ``SparseConvBlock`` followed by two
    ``SparseResBlock`` modules. Each decoder stage is a stride-2
    ``SparseConvTransposeBlock`` (``'upsample'``) whose output is
    concatenated with the matching encoder-side feature and passed through
    two ``SparseResBlock`` modules (``'fuse'``).

    Args:
        stem_channels: Output channels of the stem (before scaling).
        encoder_channels: Output channels of each encoder stage, outermost
            stage first (before scaling).
        decoder_channels: Output channels of each decoder stage, innermost
            stage first (before scaling). Must have the same length as
            ``encoder_channels``.
        in_channels: Number of input feature channels.
        width_multiplier: Factor applied to every channel count; the result
            is truncated with ``int()``.

    Raises:
        ValueError: If ``encoder_channels`` and ``decoder_channels`` differ
            in length.
    """

    def __init__(
        self,
        stem_channels: int,
        encoder_channels: List[int],
        decoder_channels: List[int],
        *,
        in_channels: int = 4,
        width_multiplier: float = 1.0,
    ) -> None:
        super().__init__()

        # Every decoder stage consumes exactly one skip connection, so the
        # two stacks must be equally deep. Fail early with a clear message
        # instead of an IndexError (or a silently mis-wired network) later.
        if len(encoder_channels) != len(decoder_channels):
            raise ValueError(
                'encoder_channels and decoder_channels must have the same '
                f'length, got {len(encoder_channels)} and '
                f'{len(decoder_channels)}')

        self.stem_channels = stem_channels
        self.encoder_channels = encoder_channels
        self.decoder_channels = decoder_channels
        self.in_channels = in_channels
        self.width_multiplier = width_multiplier

        num_stages = len(encoder_channels)

        # Flat layout: [stem, enc_0 .. enc_{N-1}, dec_0 .. dec_{N-1}].
        # Unpacking (rather than ``+``) also accepts tuples.
        num_channels = [stem_channels, *encoder_channels, *decoder_channels]
        num_channels = [int(width_multiplier * nc) for nc in num_channels]

        # NOTE(review): the positional ``3`` / ``2`` arguments below are
        # presumably kernel sizes -- confirm against the jsparse API.
        self.stem = nn.Sequential(
            spnn.Conv3d(in_channels, num_channels[0], 3),
            spnn.BatchNorm(num_channels[0]),
            spnn.ReLU(),
            spnn.Conv3d(num_channels[0], num_channels[0], 3),
            spnn.BatchNorm(num_channels[0]),
            spnn.ReLU(),
        )

        # Encoder stage k maps num_channels[k] -> num_channels[k + 1].
        self.encoders = nn.ModuleList()
        for k in range(num_stages):
            self.encoders.append(
                nn.Sequential(
                    SparseConvBlock(
                        num_channels[k],
                        num_channels[k],
                        2,
                        stride=2,
                    ),
                    SparseResBlock(num_channels[k], num_channels[k + 1], 3),
                    SparseResBlock(num_channels[k + 1], num_channels[k + 1],
                                   3),
                ))

        # Decoder stage k maps num_channels[N + k] -> num_channels[N + k + 1]
        # and is fused with the feature that has num_channels[N - 1 - k]
        # channels (the input of the mirrored encoder stage).
        self.decoders = nn.ModuleList()
        for k in range(num_stages):
            up_in = num_channels[num_stages + k]
            up_out = num_channels[num_stages + k + 1]
            skip = num_channels[num_stages - 1 - k]
            self.decoders.append(
                nn.ModuleDict({
                    'upsample':
                        SparseConvTransposeBlock(
                            up_in,
                            up_out,
                            2,
                            stride=2,
                        ),
                    'fuse':
                        nn.Sequential(
                            SparseResBlock(up_out + skip, up_out, 3),
                            SparseResBlock(up_out, up_out, 3),
                        ),
                }))

    def _unet_forward(
        self,
        x: SparseTensor,
        encoders: nn.ModuleList,
        decoders: nn.ModuleList,
    ) -> List[SparseTensor]:
        """Run one U-Net level recursively.

        The first remaining encoder and the last remaining decoder form the
        current level; everything in between is handled by the recursive
        call.

        Args:
            x: Input feature of this level.
            encoders: Encoder stages still to apply, outermost first.
            decoders: Decoder stages still to apply, innermost first.

        Returns:
            ``[x]``, followed by the outputs of the inner levels, followed
            by the fused output of this level. When both ``encoders`` and
            ``decoders`` are empty the result is just ``[x]``.
        """
        if not encoders and not decoders:
            return [x]

        # downsample
        xd = encoders[0](x)

        # inner recursion
        outputs = self._unet_forward(xd, encoders[1:], decoders[:-1])
        yd = outputs[-1]

        # upsample and fuse with this level's input as the skip connection
        u = decoders[-1]['upsample'](yd)
        y = decoders[-1]['fuse'](jsparse.cat([u, x]))

        return [x] + outputs + [y]

    def forward(self, x: SparseTensor) -> List[SparseTensor]:
        """Apply the stem and the full encoder/decoder stack.

        Returns:
            A list of ``2 * N + 1`` features: the stem output, the output of
            every encoder stage in order, then the output of every decoder
            stage in order (``self.decoders[0]`` first).
        """
        return self._unet_forward(self.stem(x), self.encoders, self.decoders)
|
|
|
|
|
|
class SparseResUNet42(SparseResUNet):
    """``SparseResUNet`` preset with a fixed channel layout.

    Uses 32 stem channels, encoder widths ``[32, 64, 128, 256]`` and decoder
    widths ``[256, 128, 96, 96]``. All keyword arguments are forwarded
    unchanged to the parent initializer.
    """

    def __init__(self, **kwargs) -> None:
        # Fresh lists are built on every call so instances never share them.
        encoder_widths = [32, 64, 128, 256]
        decoder_widths = [256, 128, 96, 96]
        super().__init__(
            stem_channels=32,
            encoder_channels=encoder_widths,
            decoder_channels=decoder_widths,
            **kwargs,
        )