Loading configs/adv_resnet101_voc_2_watercolor.yaml +4 −2 Original line number Diff line number Diff line Loading @@ -11,10 +11,12 @@ ADV: LAYERS: [True, False, True] DIS_MODEL: - in_channels: 256 func_name: 'cross_entropy' embedding_norm: False embedding_dropout: False func_name: 'l2' pool_type: 'avg' loss_weight: 1.0 window_sizes: [3, 9, 15] window_sizes: [1, 3, 9, 15] - in_channels: 1024 func_name: 'focal_loss' focal_loss_gamma: 5 Loading detection/layers/losses.py +1 −1 Original line number Diff line number Diff line Loading @@ -15,7 +15,7 @@ def smooth_l1_loss(input, target, beta=1. / 9, size_average=True): return loss.sum() def sigmoid_focal_loss(inputs, targets, alpha, gamma, reduction="mean"): def sigmoid_focal_loss(inputs, targets, alpha=-1, gamma=2, reduction="mean"): p = torch.sigmoid(inputs) ce_loss = F.binary_cross_entropy_with_logits( inputs, targets, reduction="none" Loading detection/modeling/faster_rcnn.py +34 −17 Original line number Diff line number Diff line Loading @@ -94,12 +94,17 @@ class Dis(nn.Module): ) self.shared_semantic = nn.Sequential( nn.Conv2d(in_channels + num_anchors, 256, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(256), nn.Conv2d(in_channels + num_anchors, in_channels, kernel_size=embedding_kernel_size, stride=1, padding=padding, bias=bias), NormModule(in_channels), nn.ReLU(True), nn.Conv2d(in_channels, 256, kernel_size=embedding_kernel_size, stride=1, padding=padding, bias=bias), NormModule(256), nn.ReLU(True), DropoutModule(), nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(256), nn.Conv2d(256, 256, kernel_size=embedding_kernel_size, stride=1, padding=padding, bias=bias), NormModule(256), nn.ReLU(True), ) Loading @@ -109,17 +114,16 @@ class Dis(nn.Module): for i in range(self.num_windows): self.semantic_list += [ nn.Sequential( nn.Conv2d(256, 128, 1), nn.BatchNorm2d(128), nn.Conv2d(256, 128, 1, bias=bias), NormModule(128), nn.ReLU(True), nn.Conv2d(128, 1, 1), ) ] self.fc = nn.Sequential( nn.Conv2d(256 * channel_multiply, 128, 1, bias=False), nn.BatchNorm2d(128), NormModule(128), nn.ReLU(inplace=True), ) Loading @@ -140,22 +144,29 @@ class Dis(nn.Module): N, C, H, W = feature.shape pyramid_features = [] domain_logits_list = [] for i, k in enumerate(self.window_sizes): if k == -1: x = self.pool_func(feature, kernel_size=(H, W)) elif k == 1: x = feature else: stride = self.window_strides[i] if stride is None: stride = i + 1 # default stride = 1 # default x = self.pool_func(feature, kernel_size=k, stride=stride) _, _, h, w = x.shape semantic_map_per_level = F.interpolate(semantic_map, size=(h, w), mode='bilinear', align_corners=True) semantic_map_per_level = self.semantic_list[i](semantic_map_per_level) semantic_map_per_level = semantic_map_per_level.view(N, -1) semantic_map_per_level = F.softmax(semantic_map_per_level, dim=1) semantic_map_per_level = semantic_map_per_level.view(N, 1, h, w) domain_logits = self.semantic_list[i](semantic_map_per_level) domain_logits_list.append(domain_logits) x = torch.sum(x * semantic_map_per_level, dim=(2, 3), keepdim=True) domain_probs = domain_logits.sigmoid() domain_uncertainty = - domain_probs * torch.log(domain_probs) w_spatial = 1 - domain_uncertainty x = x + x * w_spatial x = F.adaptive_avg_pool2d(x, output_size=1) pyramid_features.append(x) fuse = sum(pyramid_features) # [N, 256, 1, 1] Loading @@ -172,7 +183,7 @@ class Dis(nn.Module): final_features = final_features.view(N, -1) logits = self.predictor(final_features) return logits return logits, domain_logits_list def __repr__(self): attrs = { Loading Loading @@ -285,18 +296,24 @@ class FasterRCNN(nn.Module): device = features.device for i, (s_feat, t_feat, netD) in enumerate(zip(s_adaptation_feats, t_adaptation_feats, self.dis_list)): s_domain_logits = netD(grad_reverse(s_feat, 1.0), grad_reverse(s_rpn_logits, 1.0)) t_domain_logits = netD(grad_reverse(t_feat, 1.0), grad_reverse(t_rpn_logits, 1.0)) s_domain_logits, s_domain_logits_list = netD(grad_reverse(s_feat, 1.0), grad_reverse(s_rpn_logits, 1.0)) t_domain_logits, t_domain_logits_list = netD(grad_reverse(t_feat, 1.0), grad_reverse(t_rpn_logits, 1.0)) loss_func = netD.loss_func loss_weight = netD.loss_weight num_windows = netD.num_windows gamma = netD.focal_loss_gamma w = 0.5 s_domain_loss = loss_func(s_domain_logits, torch.zeros(s_domain_logits.size(0), dtype=torch.long, device=device)) * w t_domain_loss = loss_func(t_domain_logits, torch.ones(t_domain_logits.size(0), dtype=torch.long, device=device)) * w list_weights = (1.0 / num_windows) * 0.5 loss_dict.update({ 's_domain_loss%d' % i: s_domain_loss * loss_weight, 't_domain_loss%d' % i: t_domain_loss * loss_weight, 's_domain_list_loss%d' % i: list_weights * sum(sigmoid_focal_loss(la, torch.zeros_like(la), gamma=gamma) for la in s_domain_logits_list) * loss_weight, 't_domain_list_loss%d' % i: list_weights * sum(sigmoid_focal_loss(la, torch.ones_like(la), gamma=gamma) for la in t_domain_logits_list) * loss_weight, }) # outputs['s_features'] = s_adaptation_feats Loading Loading
configs/adv_resnet101_voc_2_watercolor.yaml +4 −2 Original line number Diff line number Diff line Loading @@ -11,10 +11,12 @@ ADV: LAYERS: [True, False, True] DIS_MODEL: - in_channels: 256 func_name: 'cross_entropy' embedding_norm: False embedding_dropout: False func_name: 'l2' pool_type: 'avg' loss_weight: 1.0 window_sizes: [3, 9, 15] window_sizes: [1, 3, 9, 15] - in_channels: 1024 func_name: 'focal_loss' focal_loss_gamma: 5 Loading
detection/layers/losses.py +1 −1 Original line number Diff line number Diff line Loading @@ -15,7 +15,7 @@ def smooth_l1_loss(input, target, beta=1. / 9, size_average=True): return loss.sum() def sigmoid_focal_loss(inputs, targets, alpha, gamma, reduction="mean"): def sigmoid_focal_loss(inputs, targets, alpha=-1, gamma=2, reduction="mean"): p = torch.sigmoid(inputs) ce_loss = F.binary_cross_entropy_with_logits( inputs, targets, reduction="none" Loading
detection/modeling/faster_rcnn.py +34 −17 Original line number Diff line number Diff line Loading @@ -94,12 +94,17 @@ class Dis(nn.Module): ) self.shared_semantic = nn.Sequential( nn.Conv2d(in_channels + num_anchors, 256, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(256), nn.Conv2d(in_channels + num_anchors, in_channels, kernel_size=embedding_kernel_size, stride=1, padding=padding, bias=bias), NormModule(in_channels), nn.ReLU(True), nn.Conv2d(in_channels, 256, kernel_size=embedding_kernel_size, stride=1, padding=padding, bias=bias), NormModule(256), nn.ReLU(True), DropoutModule(), nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(256), nn.Conv2d(256, 256, kernel_size=embedding_kernel_size, stride=1, padding=padding, bias=bias), NormModule(256), nn.ReLU(True), ) Loading @@ -109,17 +114,16 @@ class Dis(nn.Module): for i in range(self.num_windows): self.semantic_list += [ nn.Sequential( nn.Conv2d(256, 128, 1), nn.BatchNorm2d(128), nn.Conv2d(256, 128, 1, bias=bias), NormModule(128), nn.ReLU(True), nn.Conv2d(128, 1, 1), ) ] self.fc = nn.Sequential( nn.Conv2d(256 * channel_multiply, 128, 1, bias=False), nn.BatchNorm2d(128), NormModule(128), nn.ReLU(inplace=True), ) Loading @@ -140,22 +144,29 @@ class Dis(nn.Module): N, C, H, W = feature.shape pyramid_features = [] domain_logits_list = [] for i, k in enumerate(self.window_sizes): if k == -1: x = self.pool_func(feature, kernel_size=(H, W)) elif k == 1: x = feature else: stride = self.window_strides[i] if stride is None: stride = i + 1 # default stride = 1 # default x = self.pool_func(feature, kernel_size=k, stride=stride) _, _, h, w = x.shape semantic_map_per_level = F.interpolate(semantic_map, size=(h, w), mode='bilinear', align_corners=True) semantic_map_per_level = self.semantic_list[i](semantic_map_per_level) semantic_map_per_level = semantic_map_per_level.view(N, -1) semantic_map_per_level = F.softmax(semantic_map_per_level, dim=1) semantic_map_per_level = semantic_map_per_level.view(N, 1, h, w) domain_logits = self.semantic_list[i](semantic_map_per_level) domain_logits_list.append(domain_logits) x = torch.sum(x * semantic_map_per_level, dim=(2, 3), keepdim=True) domain_probs = domain_logits.sigmoid() domain_uncertainty = - domain_probs * torch.log(domain_probs) w_spatial = 1 - domain_uncertainty x = x + x * w_spatial x = F.adaptive_avg_pool2d(x, output_size=1) pyramid_features.append(x) fuse = sum(pyramid_features) # [N, 256, 1, 1] Loading @@ -172,7 +183,7 @@ class Dis(nn.Module): final_features = final_features.view(N, -1) logits = self.predictor(final_features) return logits return logits, domain_logits_list def __repr__(self): attrs = { Loading Loading @@ -285,18 +296,24 @@ class FasterRCNN(nn.Module): device = features.device for i, (s_feat, t_feat, netD) in enumerate(zip(s_adaptation_feats, t_adaptation_feats, self.dis_list)): s_domain_logits = netD(grad_reverse(s_feat, 1.0), grad_reverse(s_rpn_logits, 1.0)) t_domain_logits = netD(grad_reverse(t_feat, 1.0), grad_reverse(t_rpn_logits, 1.0)) s_domain_logits, s_domain_logits_list = netD(grad_reverse(s_feat, 1.0), grad_reverse(s_rpn_logits, 1.0)) t_domain_logits, t_domain_logits_list = netD(grad_reverse(t_feat, 1.0), grad_reverse(t_rpn_logits, 1.0)) loss_func = netD.loss_func loss_weight = netD.loss_weight num_windows = netD.num_windows gamma = netD.focal_loss_gamma w = 0.5 s_domain_loss = loss_func(s_domain_logits, torch.zeros(s_domain_logits.size(0), dtype=torch.long, device=device)) * w t_domain_loss = loss_func(t_domain_logits, torch.ones(t_domain_logits.size(0), dtype=torch.long, device=device)) * w list_weights = (1.0 / num_windows) * 0.5 loss_dict.update({ 's_domain_loss%d' % i: s_domain_loss * loss_weight, 't_domain_loss%d' % i: t_domain_loss * loss_weight, 's_domain_list_loss%d' % i: list_weights * sum(sigmoid_focal_loss(la, torch.zeros_like(la), gamma=gamma) for la in s_domain_logits_list) * loss_weight, 't_domain_list_loss%d' % i: list_weights * sum(sigmoid_focal_loss(la, torch.ones_like(la), gamma=gamma) for la in t_domain_logits_list) * loss_weight, }) # outputs['s_features'] = s_adaptation_feats Loading