-
Notifications
You must be signed in to change notification settings - Fork 1
Open
Description
class ABF(nn.Module):
    """Attention-based fusion block.

    ``x`` is projected to ``mid_channel`` channels with a 1x1 conv + BN. When
    fusion is enabled, it is combined with residual features ``y``: two
    sigmoid attention maps are predicted from ``cat([x, y])`` and used to
    form ``x * a0 + y * a1``. A 3x3 conv + BN then maps the (possibly fused)
    features to ``out_channel`` channels.

    Args:
        in_channel: Number of channels of the input ``x``.
        out_channel: Number of channels of the first tensor returned by ``forward``.
        mid_channel: Number of channels of the fused tensor (second return value).
        is_fuse: If False, no attention fusion is built and ``y`` is ignored.
        res_channel: Number of channels of the residual features ``y``.
            ``None`` (default) keeps the original behaviour, where ``y`` must
            already have ``mid_channel`` channels. If given and different from
            ``mid_channel``, a 1x1 conv + BN projection maps ``y`` to
            ``mid_channel`` channels before fusion.
    """

    def __init__(self, in_channel, out_channel, mid_channel, is_fuse=True, res_channel=None):
        super().__init__()
        self.conv_first = nn.Sequential(
            nn.Conv2d(in_channel, mid_channel, kernel_size=(1, 1), bias=False),
            nn.BatchNorm2d(mid_channel),
        )
        self.conv_last = nn.Sequential(
            nn.Conv2d(mid_channel, out_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
            nn.BatchNorm2d(out_channel),
        )
        self.att_conv = None if not is_fuse else nn.Sequential(
            nn.Conv2d(mid_channel * 2, 2, kernel_size=(1, 1)),
            nn.Sigmoid(),
        )
        # Only create the projection when it is actually needed. This keeps the
        # default configuration's submodules (and state_dict keys) identical to
        # the original block.
        self.res_proj = None
        if is_fuse and res_channel is not None and res_channel != mid_channel:
            self.res_proj = nn.Sequential(
                nn.Conv2d(res_channel, mid_channel, kernel_size=(1, 1), bias=False),
                nn.BatchNorm2d(mid_channel),
            )
        self.__init_weights()

    def __init_weights(self):
        """Kaiming-uniform (a=1) init for the bias-free conv layers."""
        nn.init.kaiming_uniform_(self.conv_first[0].weight, a=1)
        nn.init.kaiming_uniform_(self.conv_last[0].weight, a=1)
        if self.res_proj is not None:
            nn.init.kaiming_uniform_(self.res_proj[0].weight, a=1)

    def forward(self, x, y=None, shape=None):
        """Run the block.

        Args:
            x: 4-D input tensor with ``in_channel`` channels.
            y: Residual features to fuse (required when fusion is enabled).
                They must have ``res_channel`` channels if that was given,
                otherwise ``mid_channel``.
            shape: Spatial size ``y`` is resized to (nearest). Defaults to the
                spatial size of ``x``, the only size the fusion can accept.

        Returns:
            Tuple ``(out, fused)``. ``out`` has ``out_channel`` channels.
            ``fused`` has ``mid_channel`` channels.

        Raises:
            ValueError: if ``x`` is not 4-D, if ``y`` is missing while fusion
                is enabled, or if the channels of ``y`` do not match after the
                optional projection.
        """
        if x.dim() != 4:
            raise ValueError(f"ABF expects a 4-D input, got shape {tuple(x.shape)}")
        n, _, h, w = x.shape
        x = self.conv_first(x)
        if self.att_conv is not None:
            if y is None:
                raise ValueError("ABF was built with is_fuse=True; residual features `y` are required")
            if shape is None:
                # The concat and attention weighting below require y to match x spatially.
                shape = (h, w)
            if self.res_proj is not None:
                # Project before upsampling: the 1x1 conv is cheaper at low resolution.
                y = self.res_proj(y)
            # up sample residual features
            y = F.interpolate(y, shape, mode="nearest")
            if y.shape[1] != x.shape[1]:
                raise ValueError(
                    f"residual features have {y.shape[1]} channels but mid_channel is {x.shape[1]}; "
                    "pass res_channel to ABF to add a projection"
                )
            # fusion: two per-pixel attention maps weight x and y respectively
            z = self.att_conv(torch.cat([x, y], dim=1))
            x = x * z[:, 0].view(n, 1, h, w) + y * z[:, 1].view(n, 1, h, w)
        y = self.conv_last(x)
        return y, x
In the `forward` function, the attention fusion (`self.att_conv`) only works if the number of channels of `y` equals `mid_channel`. However, `y` is passed in as `res_features`, and the channel count of `res_features` does not seem to be guaranteed to equal `mid_channel`.
Metadata
Metadata
Assignees
Labels
No labels