Commit e83657bb authored by 李聪聪's avatar 李聪聪
Browse files

v5

parent 4573a558
Loading
Loading
Loading
Loading
+4 −2
Original line number Diff line number Diff line
@@ -11,10 +11,12 @@ ADV:
  LAYERS: [True, False, True]
  DIS_MODEL:
    - in_channels: 256
      func_name: 'cross_entropy'
      embedding_norm: False
      embedding_dropout: False
      func_name: 'l2'
      pool_type: 'avg'
      loss_weight: 1.0
      window_sizes: [3, 9, 15]
      window_sizes: [1, 3, 9, 15]
    - in_channels: 1024
      func_name: 'focal_loss'
      focal_loss_gamma: 5
+1 −1
Original line number Diff line number Diff line
@@ -15,7 +15,7 @@ def smooth_l1_loss(input, target, beta=1. / 9, size_average=True):
    return loss.sum()


def sigmoid_focal_loss(inputs, targets, alpha, gamma, reduction="mean"):
def sigmoid_focal_loss(inputs, targets, alpha=-1, gamma=2, reduction="mean"):
    p = torch.sigmoid(inputs)
    ce_loss = F.binary_cross_entropy_with_logits(
        inputs, targets, reduction="none"
+34 −17
Original line number Diff line number Diff line
@@ -94,12 +94,17 @@ class Dis(nn.Module):
        )

        self.shared_semantic = nn.Sequential(
            nn.Conv2d(in_channels + num_anchors, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.Conv2d(in_channels + num_anchors, in_channels, kernel_size=embedding_kernel_size, stride=1, padding=padding, bias=bias),
            NormModule(in_channels),
            nn.ReLU(True),

            nn.Conv2d(in_channels, 256, kernel_size=embedding_kernel_size, stride=1, padding=padding, bias=bias),
            NormModule(256),
            nn.ReLU(True),
            DropoutModule(),

            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, kernel_size=embedding_kernel_size, stride=1, padding=padding, bias=bias),
            NormModule(256),
            nn.ReLU(True),
        )

@@ -109,17 +114,16 @@ class Dis(nn.Module):
        for i in range(self.num_windows):
            self.semantic_list += [
                nn.Sequential(
                    nn.Conv2d(256, 128, 1),
                    nn.BatchNorm2d(128),
                    nn.Conv2d(256, 128, 1, bias=bias),
                    NormModule(128),
                    nn.ReLU(True),

                    nn.Conv2d(128, 1, 1),
                )
            ]

        self.fc = nn.Sequential(
            nn.Conv2d(256 * channel_multiply, 128, 1, bias=False),
            nn.BatchNorm2d(128),
            NormModule(128),
            nn.ReLU(inplace=True),
        )

@@ -140,22 +144,29 @@ class Dis(nn.Module):
        N, C, H, W = feature.shape

        pyramid_features = []
        domain_logits_list = []
        for i, k in enumerate(self.window_sizes):
            if k == -1:
                x = self.pool_func(feature, kernel_size=(H, W))
            elif k == 1:
                x = feature
            else:
                stride = self.window_strides[i]
                if stride is None:
                    stride = i + 1  # default
                    stride = 1  # default
                x = self.pool_func(feature, kernel_size=k, stride=stride)
            _, _, h, w = x.shape
            semantic_map_per_level = F.interpolate(semantic_map, size=(h, w), mode='bilinear', align_corners=True)
            semantic_map_per_level = self.semantic_list[i](semantic_map_per_level)
            semantic_map_per_level = semantic_map_per_level.view(N, -1)
            semantic_map_per_level = F.softmax(semantic_map_per_level, dim=1)
            semantic_map_per_level = semantic_map_per_level.view(N, 1, h, w)
            domain_logits = self.semantic_list[i](semantic_map_per_level)
            domain_logits_list.append(domain_logits)

            x = torch.sum(x * semantic_map_per_level, dim=(2, 3), keepdim=True)
            domain_probs = domain_logits.sigmoid()

            domain_uncertainty = - domain_probs * torch.log(domain_probs)

            w_spatial = 1 - domain_uncertainty
            x = x + x * w_spatial
            x = F.adaptive_avg_pool2d(x, output_size=1)
            pyramid_features.append(x)

        fuse = sum(pyramid_features)  # [N, 256, 1, 1]
@@ -172,7 +183,7 @@ class Dis(nn.Module):
        final_features = final_features.view(N, -1)

        logits = self.predictor(final_features)
        return logits
        return logits, domain_logits_list

    def __repr__(self):
        attrs = {
@@ -285,18 +296,24 @@ class FasterRCNN(nn.Module):

            device = features.device
            for i, (s_feat, t_feat, netD) in enumerate(zip(s_adaptation_feats, t_adaptation_feats, self.dis_list)):
                s_domain_logits = netD(grad_reverse(s_feat, 1.0), grad_reverse(s_rpn_logits, 1.0))
                t_domain_logits = netD(grad_reverse(t_feat, 1.0), grad_reverse(t_rpn_logits, 1.0))
                s_domain_logits, s_domain_logits_list = netD(grad_reverse(s_feat, 1.0), grad_reverse(s_rpn_logits, 1.0))
                t_domain_logits, t_domain_logits_list = netD(grad_reverse(t_feat, 1.0), grad_reverse(t_rpn_logits, 1.0))
                loss_func = netD.loss_func
                loss_weight = netD.loss_weight
                num_windows = netD.num_windows
                gamma = netD.focal_loss_gamma

                w = 0.5
                s_domain_loss = loss_func(s_domain_logits, torch.zeros(s_domain_logits.size(0), dtype=torch.long, device=device)) * w
                t_domain_loss = loss_func(t_domain_logits, torch.ones(t_domain_logits.size(0), dtype=torch.long, device=device)) * w

                list_weights = (1.0 / num_windows) * 0.5

                loss_dict.update({
                    's_domain_loss%d' % i: s_domain_loss * loss_weight,
                    't_domain_loss%d' % i: t_domain_loss * loss_weight,
                    's_domain_list_loss%d' % i: list_weights * sum(sigmoid_focal_loss(la, torch.zeros_like(la), gamma=gamma) for la in s_domain_logits_list) * loss_weight,
                    't_domain_list_loss%d' % i: list_weights * sum(sigmoid_focal_loss(la, torch.ones_like(la), gamma=gamma) for la in t_domain_logits_list) * loss_weight,
                })

            # outputs['s_features'] = s_adaptation_feats