Commit 4573a558 authored by 李聪聪's avatar 李聪聪
Browse files

v4

parent ca97c1cd
Loading
Loading
Loading
Loading
+47 −0
Original line number Diff line number Diff line
MODEL:
  BACKBONE:
    NAME: 'vgg16'
  ROI_BOX_HEAD:
    NUM_CLASSES: 2
    BOX_PREDICTOR: 'vgg16_predictor'
    POOL_TYPE: 'pooling'
ADV:
  LAYERS: [False, False, True]
  DIS_MODEL:
    - in_channels: 512
      func_name: 'focal_loss'
      focal_loss_gamma: 3
      pool_type: 'avg'
      loss_weight: 1.0
      window_sizes: [3, 9, 15, 21, -1]
DATASETS:
  TRAINS: ['cityscapes_car_train']
  TARGETS: ['kitti_train']
  TESTS: ['kitti_train']
INPUT:
  TRANSFORMS_TRAIN:
    - name: 'random_flip'
    - name: 'resize'
      min_size: 600
    - name: 'normalize'
      mean: [0.5, 0.5, 0.5]
      std: [0.5, 0.5, 0.5]
      to_01: True
    - name: 'collect'
  TRANSFORMS_TEST:
    - name: 'resize'
      min_size: 600
    - name: 'normalize'
      mean: [0.5, 0.5, 0.5]
      std: [0.5, 0.5, 0.5]
      to_01: True
    - name: 'collect'
SOLVER:
  EPOCHS: 20
  STEPS: (16, 18)
  LR: 1e-5
  BATCH_SIZE: 1
TEST:
  EVAL_TYPES: ['voc']

WORK_DIR: './debug/adv_cityscapes_car_2_kitti'
 No newline at end of file
+1 −1
Original line number Diff line number Diff line
from .misc import FrozenBatchNorm2d, cat
from .losses import smooth_l1_loss, sigmoid_focal_loss, softmax_focal_loss
from .losses import smooth_l1_loss, sigmoid_focal_loss, softmax_focal_loss, l2_loss
from .grad_reverse import grad_reverse
from .style_pool2d import style_pool2d, StylePool2d
+5 −0
Original line number Diff line number Diff line
@@ -51,3 +51,8 @@ def softmax_focal_loss(inputs, targets, gamma=2, reduction='mean'):
    else:
        raise ValueError
    return loss


def l2_loss(inputs, targets, reduction='mean'):
    """Squared-error domain-classification loss on sigmoid probabilities.

    Converts raw logits to probabilities with a sigmoid, broadcasts the
    target tensor to the logits' shape and dtype, and returns their mean
    squared error.

    Args:
        inputs: logit tensor of arbitrary shape.
        targets: tensor expandable (via ``expand_as``) to ``inputs``' shape,
            e.g. a scalar domain label; cast to ``inputs.dtype`` before use.
        reduction: forwarded unchanged to ``F.mse_loss``
            (``'mean'``, ``'sum'`` or ``'none'``).

    Returns:
        The reduced MSE between ``inputs.sigmoid()`` and the expanded targets.
    """
    probs = inputs.sigmoid()
    expanded_targets = targets.to(inputs.dtype).expand_as(inputs)
    return F.mse_loss(probs, expanded_targets, reduction=reduction)
+54 −25
Original line number Diff line number Diff line
@@ -3,8 +3,9 @@ from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from terminaltables import AsciiTable

from detection.layers import grad_reverse, softmax_focal_loss, sigmoid_focal_loss, style_pool2d
from detection.layers import grad_reverse, softmax_focal_loss, sigmoid_focal_loss, style_pool2d, l2_loss
from .backbone import build_backbone
from .roi_heads import BoxHead
from .rpn import RPN
@@ -14,6 +15,9 @@ class Dis(nn.Module):
    def __init__(self,
                 cfg,
                 in_channels,
                 embedding_kernel_size=3,
                 embedding_norm=True,
                 embedding_dropout=True,
                 func_name='focal_loss',
                 focal_loss_gamma=5,
                 pool_type='avg',
@@ -27,6 +31,9 @@ class Dis(nn.Module):
        num_anchors         = len(anchor_scales) * len(anchor_ratios)
        # fmt:on
        self.in_channels = in_channels
        self.embedding_kernel_size = embedding_kernel_size
        self.embedding_norm = embedding_norm
        self.embedding_dropout = embedding_dropout
        self.num_windows = len(window_sizes)
        self.num_anchors = num_anchors
        self.window_sizes = window_sizes
@@ -48,31 +55,42 @@ class Dis(nn.Module):
        self.pool_func = pool_func

        if func_name == 'focal_loss':
            num_domain_classes = 2
            loss_func = partial(softmax_focal_loss, gamma=focal_loss_gamma)
        elif func_name == 'cross_entropy':
            num_domain_classes = 2
            loss_func = F.cross_entropy
        elif func_name == 'l2':
            num_domain_classes = 1
            loss_func = l2_loss
        else:
            raise ValueError
        self.focal_loss_gamma = focal_loss_gamma
        self.func_name = func_name
        self.loss_func = loss_func
        self.loss_weight = loss_weight
        self.num_domain_classes = num_domain_classes

        NormModule = nn.BatchNorm2d if embedding_norm else nn.Identity
        DropoutModule = nn.Dropout if embedding_dropout else nn.Identity

        padding = (embedding_kernel_size - 1) // 2
        bias = not embedding_norm
        self.embedding = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.Conv2d(in_channels, in_channels, kernel_size=embedding_kernel_size, stride=1, padding=padding, bias=bias),
            NormModule(in_channels),
            nn.ReLU(True),
            nn.Dropout(),
            DropoutModule(),

            nn.Conv2d(in_channels, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.Conv2d(in_channels, 256, kernel_size=embedding_kernel_size, stride=1, padding=padding, bias=bias),
            NormModule(256),
            nn.ReLU(True),
            nn.Dropout(),
            DropoutModule(),

            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, kernel_size=embedding_kernel_size, stride=1, padding=padding, bias=bias),
            NormModule(256),
            nn.ReLU(True),
            nn.Dropout(),
            DropoutModule(),
        )

        self.shared_semantic = nn.Sequential(
@@ -109,7 +127,7 @@ class Dis(nn.Module):
            nn.Conv2d(128, self.num_windows * 256 * channel_multiply, 1, bias=False),
        )

        self.predictor = nn.Linear(256 * channel_multiply, 2)
        self.predictor = nn.Linear(256 * channel_multiply, num_domain_classes)

    def forward(self, feature, rpn_logits):
        if feature.shape != rpn_logits.shape:
@@ -159,13 +177,20 @@ class Dis(nn.Module):
    def __repr__(self):
        attrs = {
            'in_channels': self.in_channels,
            'embedding_kernel_size': self.embedding_kernel_size,
            'embedding_norm': self.embedding_norm,
            'embedding_dropout': self.embedding_dropout,
            'num_domain_classes': self.num_domain_classes,
            'func_name': self.func_name,
            'focal_loss_gamma': self.focal_loss_gamma,
            'pool_type': self.pool_type,
            'loss_weight': self.loss_weight,
            'window_strides': self.window_strides,
            'window_sizes': self.window_sizes,
        }
        return self.__class__.__name__ + str(attrs)
        table = AsciiTable(list(zip(attrs.keys(), attrs.values())))
        table.inner_heading_row_border = False
        return self.__class__.__name__ + '\n' + table.table


class FasterRCNN(nn.Module):
@@ -179,6 +204,9 @@ class FasterRCNN(nn.Module):
        self.rpn = RPN(cfg, in_channels)
        self.box_head = BoxHead(cfg, in_channels)

        self.enable_adaptation = len(cfg.DATASETS.TARGETS) > 0
        self.ada_layers = [False] * 3
        if self.enable_adaptation:
            self.ada_layers = cfg.ADV.LAYERS
            dis_model = cfg.ADV.DIS_MODEL

@@ -186,6 +214,7 @@ class FasterRCNN(nn.Module):

            # self.netD = netD()
            # self.netD = D(cfg, in_channels)

            self.dis_list = nn.ModuleList()
            for model_config in dis_model:
                dis = Dis(cfg, **model_config)
@@ -248,7 +277,7 @@ class FasterRCNN(nn.Module):
        proposals, rpn_losses, s_rpn_logits = self.rpn(images, features, img_metas, targets)
        dets, box_losses, s_proposals, box_features, roi_features = self.box_head(features, proposals, img_metas, targets)

        if self.training and t_images is not None:
        if self.enable_adaptation and self.training and t_images is not None:
            t_features, t_adaptation_feats = forward_func(t_images)

            t_proposals, _, t_rpn_logits = self.rpn(t_images, t_features, t_img_metas, targets=None)