Commit ca97c1cd authored by 李聪聪's avatar 李聪聪
Browse files

add window strides

parent 6f85991f
Loading
Loading
Loading
Loading
+10 −1
Original line number Diff line number Diff line
@@ -18,6 +18,7 @@ class Dis(nn.Module):
                 focal_loss_gamma=5,
                 pool_type='avg',
                 loss_weight=1.0,
                 window_strides=None,
                 window_sizes=(3, 9, 15, 21, -1)):
        super().__init__()
        # fmt:off
@@ -29,6 +30,11 @@ class Dis(nn.Module):
        self.num_windows = len(window_sizes)
        self.num_anchors = num_anchors
        self.window_sizes = window_sizes
        if window_strides is None:
            self.window_strides = [None] * len(window_sizes)
        else:
            assert len(window_strides) == len(window_sizes), 'window_strides and window_sizes should has same len'
            self.window_strides = window_strides

        if pool_type == 'avg':
            channel_multiply = 1
@@ -120,7 +126,10 @@ class Dis(nn.Module):
            if k == -1:
                x = self.pool_func(feature, kernel_size=(H, W))
            else:
                x = self.pool_func(feature, kernel_size=k, stride=i + 1)
                stride = self.window_strides[i]
                if stride is None:
                    stride = i + 1  # default
                x = self.pool_func(feature, kernel_size=k, stride=stride)
            _, _, h, w = x.shape
            semantic_map_per_level = F.interpolate(semantic_map, size=(h, w), mode='bilinear', align_corners=True)
            semantic_map_per_level = self.semantic_list[i](semantic_map_per_level)