Commit 7fb5ee48 authored by 李聪聪

add adv

parent 22881a82
import argparse
import datetime
import math
import os
import time
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel
from torch.utils.tensorboard import SummaryWriter
from detection.utils import dist_utils
from detection.config import cfg
from detection.data.build import build_data_loaders
from detection.engine.eval import evaluation
from detection.modeling.build import build_detectors
from detection import utils
global_step = 0
total_steps = 0
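# Cosine annealing from eta_max down to eta_min over total_steps (not referenced elsewhere in this file).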
def cosine_scheduler(eta_max, eta_min, current_step):
y = eta_min + 0.5 * (eta_max - eta_min) * (1 + math.cos(current_step / total_steps * math.pi))
return y
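# Linear warmup: scales the LR factor from warmup_factor up to 1 over warmup_iters steps.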
def warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor):
def f(x):
if x >= warmup_iters:
return 1
alpha = float(x) / warmup_iters
return warmup_factor * (1 - alpha) + alpha
return torch.optim.lr_scheduler.LambdaLR(optimizer, f)
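# Detach features so discriminator losses cannot backprop into the detector.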
def detach_features(features):
if isinstance(features, torch.Tensor):
return features.detach()
return tuple([f.detach() for f in features])
def train_one_epoch(model, optimizer, train_loader, target_loader, device, epoch, dis_model, dis_optimizer, print_freq=10, writer=None):
global global_step
model.train()
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.0e}'))
metric_logger.add_meter('lr_dis', utils.SmoothedValue(window_size=1, fmt='{value:.0e}'))
header = 'Epoch: [{}]'.format(epoch)
lr_schedulers = []
if epoch == 0:
warmup_factor = 1. / 500
warmup_iters = min(500, len(train_loader) - 1)
lr_schedulers = [
warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor),
warmup_lr_scheduler(dis_optimizer, warmup_iters, warmup_factor),
]
source_label = 1
target_label = 0
target_loader = iter(target_loader)
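    # NOTE: assumes the target loader yields at least as many batches as the source
    # loader per epoch; otherwise next(target_loader) below raises StopIteration.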
for images, img_metas, targets in metric_logger.log_every(train_loader, print_freq, header):
global_step += 1
images = images.to(device)
targets = [t.to(device) for t in targets]
t_images, t_img_metas, _ = next(target_loader)
t_images = t_images.to(device)
loss_dict, outputs = model(images, img_metas, targets, t_images, t_img_metas)
loss_dict_for_log = dict(loss_dict)
s_windows = outputs['s_windows']
t_windows = outputs['t_windows']
s_rpn_logits = outputs['s_rpn_logits']
t_rpn_logits = outputs['t_rpn_logits']
s_box_features = outputs['s_box_features']
t_box_features = outputs['t_box_features']
# -------------------------------------------------------------------
# -----------------------------1.Train D-----------------------------
# -------------------------------------------------------------------
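        # D step: classify source (label 1) vs target (label 0) window statistics;
        # inputs are detached so this update touches only the discriminator.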
s_dis_loss = dis_model(detach_features(s_windows), source_label, s_rpn_logits.detach(), s_box_features.detach())
t_dis_loss = dis_model(detach_features(t_windows), target_label, t_rpn_logits.detach(), t_box_features.detach())
dis_loss = s_dis_loss + t_dis_loss
loss_dict_for_log['s_dis_loss'] = s_dis_loss
loss_dict_for_log['t_dis_loss'] = t_dis_loss
dis_optimizer.zero_grad()
dis_loss.backward()
dis_optimizer.step()
# -------------------------------------------------------------------
# -----------------------------2.Train G-----------------------------
# -------------------------------------------------------------------
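        # G step: score *target* features under the *source* label so detector gradients
        # push toward domain-confusable features. dis_model parameters also accumulate
        # gradients here, but dis_optimizer.zero_grad() clears them next iteration.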
adv_loss = dis_model(t_windows, source_label, t_rpn_logits, t_box_features)
loss_dict_for_log['adv_loss'] = adv_loss
        gamma = 1e-2  # weight of the adversarial term relative to the detection loss
        det_loss = sum(loss_dict.values())
        losses = det_loss + adv_loss * gamma
optimizer.zero_grad()
losses.backward()
optimizer.step()
loss_dict_reduced = dist_utils.reduce_dict(loss_dict_for_log)
losses_reduced = sum(loss for loss in loss_dict_reduced.values())
for lr_scheduler in lr_schedulers:
lr_scheduler.step()
metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
metric_logger.update(lr=optimizer.param_groups[0]["lr"])
metric_logger.update(lr_dis=dis_optimizer.param_groups[0]["lr"])
metric_logger.update(gamma=gamma)
if global_step % print_freq == 0:
if writer:
for k, v in loss_dict_reduced.items():
writer.add_scalar('losses/{}'.format(k), v, global_step=global_step)
writer.add_scalar('losses/total_loss', losses_reduced, global_step=global_step)
writer.add_scalar('lr', optimizer.param_groups[0]['lr'], global_step=global_step)
writer.add_scalar('lr_dis', dis_optimizer.param_groups[0]['lr'], global_step=global_step)
class DisModelPerLevel(nn.Module):
def __init__(self, cfg, in_channels=512, window_sizes=(3, 7, 13, 21, 32)):
super().__init__()
# fmt:off
anchor_scales = cfg.MODEL.RPN.ANCHOR_SIZES
anchor_ratios = cfg.MODEL.RPN.ASPECT_RATIOS
num_anchors = len(anchor_scales) * len(anchor_ratios)
# fmt:on
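        # One small stack of 1x1 convs per window size; inputs carry in_channels * 2
        # channels because each window is summarized by concatenated (mean, std).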
self.window_sizes = window_sizes
self.model_list = nn.ModuleList()
# self.weight_list = nn.ModuleList()
for _ in range(len(self.window_sizes)):
self.model_list += [
nn.Sequential(
nn.Conv2d(in_channels * 2, in_channels, kernel_size=1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(in_channels, 256, kernel_size=1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(256, 128, kernel_size=1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(128, 1, kernel_size=1),
)
]
# self.weight_list += [
# nn.Sequential(
# nn.Conv2d(in_channels * 2 + 1, 128, 1),
# nn.ReLU(inplace=True),
# nn.Conv2d(128, 1, 1),
# nn.Sigmoid(),
# )
# ]
def forward(self, window_features, label, rpn_logits, box_features):
# rpn_semantic_map = torch.mean(rpn_logits, dim=1, keepdim=True)
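        # For every scale, subtract the window map's global average so each position is
        # scored on its local residual only. rpn_logits and box_features are currently
        # unused; the semantic-weighting branch that consumed them is commented out.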
logits = []
for i, x in enumerate(window_features):
_, _, window_h, window_w = x.shape
avg_x = torch.mean(x, dim=(2, 3), keepdim=True) # (N, C * 2, 1, 1)
avg_x_expanded = avg_x.expand(-1, -1, window_h, window_w) # (N, C * 2, h, w)
# rpn_semantic_map_per_level = F.interpolate(rpn_semantic_map, size=(window_h, window_w), mode='bilinear', align_corners=True)
# rpn_semantic_map_per_level = torch.cat((x, rpn_semantic_map_per_level), dim=1)
# weight = self.weight_list[i](rpn_semantic_map_per_level) # (N, 1, h, w)
# residual = weight * (x - avg_x_expanded)
residual = (x - avg_x_expanded)
x = self.model_list[i](residual)
logits.append(x)
losses = sum(F.binary_cross_entropy_with_logits(l, torch.full_like(l, label)) for l in logits)
return losses
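# One DisModelPerLevel per backbone feature level; the VGG16 backbone used by this
# config exposes a single 512-channel level.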
class Discriminator(nn.Module):
def __init__(self, cfg):
super().__init__()
self.layers = nn.ModuleList([
DisModelPerLevel(cfg, in_channels=512, window_sizes=(3, 7, 13, 21, 32)), # (1, 512, 32, 64)
])
def forward(self, window_features, label, rpn_logits, box_features):
if isinstance(window_features[0], torch.Tensor):
window_features = (window_features,)
losses = []
for i, (layer, feature) in enumerate(zip(self.layers, window_features)):
loss = layer(feature, label, rpn_logits, box_features)
losses.append(loss)
losses = sum(losses)
return losses
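# A minimal usage sketch (hypothetical tensors; rpn_logits and box_features are kept for
# interface compatibility but are unused while the weighting branch stays commented out):
#
#     dis = Discriminator(cfg)
#     windows = LocalWindowExtractor()(torch.randn(1, 512, 32, 64))  # 5 maps of (1, 1024, h, w)
#     d_loss = dis(windows, 1, rpn_logits=None, box_features=None)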
def main(cfg, args):
train_loader = build_data_loaders(cfg.DATASETS.TRAINS, transforms=cfg.INPUT.TRANSFORMS_TRAIN, is_train=True, distributed=args.distributed,
batch_size=cfg.SOLVER.BATCH_SIZE, num_workers=cfg.DATALOADER.NUM_WORKERS)
target_loader = build_data_loaders(cfg.DATASETS.TARGETS, transforms=cfg.INPUT.TRANSFORMS_TRAIN, is_train=True, distributed=args.distributed,
batch_size=cfg.SOLVER.BATCH_SIZE, num_workers=cfg.DATALOADER.NUM_WORKERS)
test_loaders = build_data_loaders(cfg.DATASETS.TESTS, transforms=cfg.INPUT.TRANSFORMS_TEST, is_train=False,
distributed=args.distributed, num_workers=cfg.DATALOADER.NUM_WORKERS)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = build_detectors(cfg)
model.to(device)
dis_model = Discriminator(cfg)
dis_model.to(device)
model_without_ddp = model
dis_model_without_ddp = dis_model
if args.distributed:
model = DistributedDataParallel(model, device_ids=[args.gpu])
dis_model = DistributedDataParallel(dis_model, device_ids=[args.gpu])
model_without_ddp = model.module
dis_model_without_ddp = dis_model.module
# optimizer = torch.optim.SGD([p for p in model.parameters() if p.requires_grad], cfg.SOLVER.LR, momentum=cfg.SOLVER.MOMENTUM, weight_decay=cfg.SOLVER.WEIGHT_DECAY)
optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], cfg.SOLVER.LR, betas=(0.9, 0.999), weight_decay=cfg.SOLVER.WEIGHT_DECAY)
dis_optimizer = torch.optim.Adam([p for p in dis_model.parameters() if p.requires_grad], cfg.SOLVER.LR, betas=(0.9, 0.999), weight_decay=cfg.SOLVER.WEIGHT_DECAY)
schedulers = [
torch.optim.lr_scheduler.MultiStepLR(optimizer, cfg.SOLVER.STEPS, gamma=cfg.SOLVER.GAMMA),
torch.optim.lr_scheduler.MultiStepLR(dis_optimizer, cfg.SOLVER.STEPS, gamma=cfg.SOLVER.GAMMA),
]
current_epoch = -1
if args.resume:
print('Loading from {} ...'.format(args.resume))
checkpoint = torch.load(args.resume, map_location='cpu')
model_without_ddp.load_state_dict(checkpoint['model'])
if 'current_epoch' in checkpoint:
current_epoch = int(checkpoint['current_epoch'])
if 'discriminator' in checkpoint:
dis_model_without_ddp.load_state_dict(checkpoint['discriminator'])
work_dir = cfg.WORK_DIR
if args.test_only:
evaluation(model, test_loaders, device, types=cfg.TEST.EVAL_TYPES, output_dir=work_dir)
return
losses_writer = None
if dist_utils.is_main_process():
losses_writer = SummaryWriter(os.path.join(work_dir, 'losses'))
losses_writer.add_text('config', '{}'.format(str(cfg).replace('\n', ' \n')))
losses_writer.add_text('args', str(args))
metrics_writers = {}
if dist_utils.is_main_process():
test_dataset_names = [loader.dataset.dataset_name for loader in test_loaders]
for dataset_name in test_dataset_names:
metrics_writers[dataset_name] = SummaryWriter(os.path.join(work_dir, 'metrics', dataset_name))
start_time = time.time()
epochs = cfg.SOLVER.EPOCHS
global total_steps
start_epoch = current_epoch + 1
total_steps = (epochs - start_epoch) * len(train_loader)
print("Start training, total epochs: {} ({} - {}), total steps: {}".format(epochs - start_epoch, start_epoch, epochs - 1, total_steps))
for epoch in range(start_epoch, epochs):
if args.distributed:
train_loader.batch_sampler.sampler.set_epoch(epoch)
target_loader.batch_sampler.sampler.set_epoch(epoch)
epoch_start = time.time()
train_one_epoch(model, optimizer, train_loader, target_loader, device, epoch,
dis_model=dis_model, dis_optimizer=dis_optimizer,
writer=losses_writer)
for scheduler in schedulers:
scheduler.step()
state_dict = {
'model': model_without_ddp.state_dict(),
'discriminator': dis_model_without_ddp.state_dict(),
'current_epoch': epoch,
}
save_path = os.path.join(work_dir, 'model_epoch_{:02d}.pth'.format(epoch))
dist_utils.save_on_master(state_dict, save_path)
print('Saved to {}.'.format(save_path))
metrics = evaluation(model, test_loaders, device, cfg.TEST.EVAL_TYPES, output_dir=work_dir, iteration=epoch)
if dist_utils.is_main_process() and losses_writer:
for dataset_name, metric in metrics.items():
for k, v in metric.items():
metrics_writers[dataset_name].add_scalar('metrics/' + k, v, global_step=global_step)
epoch_cost = time.time() - epoch_start
left = epochs - epoch - 1
print('Epoch {} ended, cost {}. Left {} epochs, may cost {}'.format(epoch,
str(datetime.timedelta(seconds=int(epoch_cost))),
left,
str(datetime.timedelta(seconds=int(left * epoch_cost)))))
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Total training time {}'.format(total_time_str))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
parser.add_argument("--config-file", help="path to config file", type=str)
parser.add_argument('--resume', default='', help='resume from checkpoint')
parser.add_argument("--test-only", help="Only test the model", action="store_true")
# distributed training parameters
parser.add_argument('--world-size', default=1, type=int, help='number of distributed processes')
parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
parser.add_argument("opts", help="Modify config options using the command-line", default=None, nargs=argparse.REMAINDER)
args = parser.parse_args()
cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts)
cfg.freeze()
dist_utils.init_distributed_mode(args)
print(args)
world_size = dist_utils.get_world_size()
    if world_size != 4:
        # Linearly scale the LR against the 4-GPU baseline this config was tuned for.
        lr = cfg.SOLVER.LR * (float(world_size) / 4)
        print('Change lr from {} to {}'.format(cfg.SOLVER.LR, lr))
        cfg.defrost()  # cfg was frozen above; unfreeze before overriding SOLVER.LR
        cfg.merge_from_list(['SOLVER.LR', lr])
        cfg.freeze()
print(cfg)
os.makedirs(cfg.WORK_DIR, exist_ok=True)
main(cfg, args)
MODEL:
BACKBONE:
NAME: 'vgg16'
RPN:
ANCHOR_SIZES: (32, 64, 128, 256, 512)
ROI_BOX_HEAD:
NUM_CLASSES: 9
BOX_PREDICTOR: 'vgg16_predictor'
POOL_TYPE: 'align'
DATASETS:
TRAINS: ['cityscapes_train']
TARGETS: ['foggy_cityscapes_train_0.02']
TESTS: ['foggy_cityscapes_val_0.02']
INPUT:
TRANSFORMS_TRAIN:
- name: 'random_flip'
- name: 'resize'
min_size: 600
- name: 'normalize'
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
to_01: True
- name: 'collect'
TRANSFORMS_TEST:
- name: 'resize'
min_size: 600
- name: 'normalize'
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
to_01: True
- name: 'collect'
SOLVER:
EPOCHS: 25
STEPS: (16, 22)
LR: 1e-5
BATCH_SIZE: 1
TEST:
EVAL_TYPES: ['voc']
WORK_DIR: './work_dir/adv_cityscapes_2_foggy'
\ No newline at end of file
MODEL:
BACKBONE:
NAME: 'vgg16'
RPN:
ANCHOR_SIZES: (32, 64, 128, 256, 512)
ROI_BOX_HEAD:
NUM_CLASSES: 9
BOX_PREDICTOR: 'vgg16_predictor'
-    POOL_TYPE: 'pooling'
+    POOL_TYPE: 'align'
DATASETS:
TRAINS: ['cityscapes_train']
TESTS: ['foggy_cityscapes_val_0.02']
@@ -34,4 +36,4 @@ SOLVER:
TEST:
EVAL_TYPES: ['voc']
-WORK_DIR: './work_dir/baseline_cityscapes_2_foggy'
\ No newline at end of file
+WORK_DIR: './work_dir/baseline_cityscapes_2_foggy_align'
\ No newline at end of file
@@ -46,7 +46,7 @@ _C.MODEL.RPN.PRE_NMS_TOP_N_TEST = 6000
_C.MODEL.RPN.POST_NMS_TOP_N_TRAIN = 2000
_C.MODEL.RPN.POST_NMS_TOP_N_TEST = 300
_C.MODEL.RPN.NMS_THRESH = 0.7
-_C.MODEL.RPN.MIN_SIZE = 0
+_C.MODEL.RPN.MIN_SIZE = 1
# ---------------------------------------------------------------------------- #
# ROI HEADS options
...
import torch
import torch.nn as nn
import torch.nn.functional as F
from .backbone import build_backbone
from .roi_heads import BoxHead
from .rpn import RPN
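# Slides multi-scale windows over a feature map via F.unfold and summarizes each window
# by its channel-wise mean and standard deviation (style-like first/second moments).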
class LocalWindowExtractor:
def __init__(self, window_sizes=(3, 7, 13, 21, 32)):
        assert 1 not in window_sizes, 'Window size 1 is not supported'
self.window_sizes = window_sizes
self.strides = (1, 3, 6, 10, 15)
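        # Each stride is (K - 1) // 2 of its window size, giving roughly 50% overlap.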
def __call__(self, feature):
N, C, H, W = feature.shape
windows = []
for i, K in enumerate(self.window_sizes):
# stride = max(1, (K - 1) // 2)
stride = self.strides[i]
NEW_H, NEW_W = int((H - K) / stride + 1), int((W - K) / stride + 1)
img_windows = F.unfold(feature, kernel_size=K, stride=stride)
img_windows = img_windows.view(N, C, K, K, -1)
var, mean = torch.var_mean(img_windows, dim=(2, 3), unbiased=False) # (N, C, NEW_H * NEW_W)
std = torch.sqrt(var + 1e-12)
x = torch.cat((mean, std), dim=1) # (N, C * 2, NEW_H * NEW_W)
x = x.view(N, C * 2, NEW_H, NEW_W)
windows.append(x)
return windows
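# A minimal shape sketch for the extractor (hypothetical input; 512 channels matches the
# VGG16 features used in this repo):
#
#     extractor = LocalWindowExtractor((3, 7, 13, 21, 32))
#     windows = extractor(torch.randn(1, 512, 32, 64))
#     for K, w in zip(extractor.window_sizes, windows):
#         print(K, tuple(w.shape))
#     # 3  -> (1, 1024, 30, 62)   # channels double: C * 2 = mean || std
#     # 7  -> (1, 1024, 9, 20)
#     # 13 -> (1, 1024, 4, 9)
#     # 21 -> (1, 1024, 2, 5)
#     # 32 -> (1, 1024, 1, 3)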
class FasterRCNN(nn.Module):
def __init__(self, cfg):
super(FasterRCNN, self).__init__()
@@ -12,16 +39,37 @@ class FasterRCNN(nn.Module):
self.backbone = backbone
self.rpn = RPN(cfg, in_channels)
self.box_head = BoxHead(cfg, in_channels)
window_sizes = (3, 7, 13, 21, 32)
self.local_window_extractor = LocalWindowExtractor(window_sizes)
-    def forward(self, images, img_metas, targets=None):
+    def forward(self, images, img_metas, targets=None, t_images=None, t_img_metas=None):
if self.training and targets is None:
raise ValueError("In training mode, targets should be passed")
outputs = dict()
loss_dict = dict()
features = self.backbone(images)
-        proposals, rpn_losses = self.rpn(images, features, img_metas, targets)
-        dets, box_losses = self.box_head(features, proposals, img_metas, targets)
-        loss_dict = {}
+        proposals, rpn_losses, s_rpn_logits = self.rpn(images, features, img_metas, targets)
+        dets, box_losses, box_features = self.box_head(features, proposals, img_metas, targets)
if self.training and t_images is not None:
s_windows = self.local_window_extractor(features)
t_features = self.backbone(t_images)
t_windows = self.local_window_extractor(t_features)
t_proposals, _, t_rpn_logits = self.rpn(t_images, t_features, t_img_metas, targets=None)
_, _, t_box_features = self.box_head(t_features, t_proposals, t_img_metas, targets=None)
outputs['s_windows'] = s_windows
outputs['t_windows'] = t_windows
outputs['s_rpn_logits'] = s_rpn_logits
outputs['t_rpn_logits'] = t_rpn_logits
outputs['s_box_features'] = box_features
outputs['t_box_features'] = t_box_features
if self.training:
loss_dict.update(rpn_losses)
loss_dict.update(box_losses)
-            return loss_dict
+            return loss_dict, outputs
return dets
@@ -35,9 +35,10 @@ class VGG16BoxPredictor(nn.Module):
def forward(self, box_features):
box_features = box_features.view(box_features.size(0), -1)
box_features = self.classifier(box_features)
class_logits = self.cls_score(box_features)
box_regression = self.bbox_pred(box_features)
-        return class_logits, box_regression
+        return class_logits, box_regression, box_features
class ResNetBoxPredictor(nn.Module):
@@ -60,9 +61,10 @@ class ResNetBoxPredictor(nn.Module):
def forward(self, box_features):
box_features = self.extractor(box_features)
box_features = torch.mean(box_features, dim=(2, 3))
class_logits = self.cls_score(box_features)
box_regression = self.bbox_pred(box_features)
-        return class_logits, box_regression
+        return class_logits, box_regression, box_features
BOX_PREDICTORS = {
@@ -129,15 +131,20 @@ class BoxHead(nn.Module):
self.fg_bg_sampler = BalancedPositiveNegativeSampler(batch_size, 0.25)
def forward(self, features, proposals, img_metas, targets=None):
-        if self.training:
+        if self.training and targets is not None:
with torch.no_grad():
proposals, labels, regression_targets = self.select_training_samples(proposals, targets)
is_target_domain = self.training and targets is None
box_features = self.pooler(features, proposals)
-        class_logits, box_regression = self.box_predictor(box_features)
+        class_logits, box_regression, box_features = self.box_predictor(box_features)
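        # Target-domain pass: skip detection losses and return only the pooled box
        # features for the discriminator.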
if is_target_domain:
return [], {}, box_features
-        if self.training:
+        if self.training and targets is not None:
classification_loss, box_loss = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
loss = {
'rcnn_cls_loss': classification_loss,
@@ -147,7 +154,7 @@
else:
loss = {}
dets = self.post_processor(class_logits, box_regression, proposals, img_metas)
-        return dets, loss
+        return dets, loss, box_features
def post_processor(self, class_logits, box_regression, proposals, img_metas):
num_classes = class_logits.shape[1]
...
@@ -14,6 +14,7 @@ from .anchor_generator import AnchorGenerator
class RPN(nn.Module):
def __init__(self, cfg, in_channels):
super().__init__()
self.cfg = cfg
# fmt:off
batch_size = cfg.MODEL.RPN.BATCH_SIZE_PER_IMAGE
anchor_stride = cfg.MODEL.RPN.ANCHOR_STRIDE
@@ -56,10 +57,12 @@ class RPN(nn.Module):
t = F.relu(self.conv(features))
logits = self.cls_logits(t)
bbox_reg = self.bbox_pred(t)
is_target_domain = self.training and targets is None
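        # Target-domain images arrive without targets: still produce proposals and
        # logits for the adversarial branch, but skip the supervised RPN losses.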
with torch.no_grad():
-            proposals = self.generate_proposals(anchors, logits, bbox_reg, img_metas)
+            proposals = self.generate_proposals(anchors, logits, bbox_reg, img_metas, is_target_domain)
-        if self.training:
+        if self.training and targets is not None:
objectness_loss, box_loss = self.losses(anchors, logits, bbox_reg, img_metas, targets)
loss = {
'rpn_cls_loss': objectness_loss,
@@ -68,19 +71,22 @@
else:
loss = {}
-        return proposals, loss
+        return proposals, loss, logits
-    def generate_proposals(self, anchors, objectness, box_regression, img_metas):
+    def generate_proposals(self, anchors, objectness, box_regression, img_metas, is_target_domain=False):
"""
Args:
anchors:
objectness: (N, A, H, W)
box_regression: (N, A * 4, H, W)
img_metas:
is_target_domain:
Returns:
"""
pre_nms_top_n = self.pre_nms_top_n[self.training]
post_nms_top_n = self.post_nms_top_n[self.training]