Commit 22881a82 authored by 李聪聪's avatar 李聪聪
Browse files

optimize

parent d9049d1e
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
from .misc import FrozenBatchNorm2d
from .misc import FrozenBatchNorm2d, cat
from .losses import smooth_l1_loss
+10 −0
Original line number Diff line number Diff line
import torch


def cat(tensors, dim=0):
    """
    Efficient version of torch.cat that avoids a copy if there is only a single element in a list
    """
    assert isinstance(tensors, (list, tuple))
    if len(tensors) == 1:
        return tensors[0]
    return torch.cat(tensors, dim)


class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters
+4 −3
Original line number Diff line number Diff line
@@ -7,6 +7,7 @@ from torchvision import ops, models
from torchvision.ops import boxes as box_ops

from detection.layers import FrozenBatchNorm2d, smooth_l1_loss
from detection.layers import cat
from detection.modeling.utils import BalancedPositiveNegativeSampler, BoxCoder, Matcher


@@ -71,8 +72,8 @@ BOX_PREDICTORS = {


def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
    labels = torch.cat(labels, dim=0)
    regression_targets = torch.cat(regression_targets, dim=0)
    labels = cat(labels, dim=0)
    regression_targets = cat(regression_targets, dim=0)

    classification_loss = F.cross_entropy(class_logits, labels)

@@ -153,7 +154,7 @@ class BoxHead(nn.Module):
        device = class_logits.device

        boxes_per_image = [box.shape[0] for box in proposals]
        proposals = torch.cat([box for box in proposals])
        proposals = cat([box for box in proposals])
        pred_boxes = self.box_coder.decode(
            box_regression.view(sum(boxes_per_image), -1), proposals
        )
+6 −6
Original line number Diff line number Diff line
@@ -6,7 +6,7 @@ from torchvision import ops

from torchvision.ops import boxes as box_ops

from detection.layers import smooth_l1_loss
from detection.layers import smooth_l1_loss, cat
from .utils import BalancedPositiveNegativeSampler, Matcher, BoxCoder
from .anchor_generator import AnchorGenerator

@@ -90,7 +90,7 @@ class RPN(nn.Module):
        objectness = objectness.sigmoid()

        box_regression = box_regression.permute(0, 2, 3, 1).reshape(N, H * W * A, 4)
        concat_anchors = torch.cat(anchors, dim=0)
        concat_anchors = cat(anchors, dim=0)
        concat_anchors = concat_anchors.reshape(N, A * H * W, 4)

        num_anchors = A * H * W
@@ -162,15 +162,15 @@ class RPN(nn.Module):

        sampled_pos_inds, sampled_neg_inds = self.sampler(labels)

        sampled_pos_inds = torch.nonzero(torch.cat(sampled_pos_inds, dim=0)).squeeze(1)
        sampled_neg_inds = torch.nonzero(torch.cat(sampled_neg_inds, dim=0)).squeeze(1)
        sampled_pos_inds = torch.nonzero(cat(sampled_pos_inds, dim=0)).squeeze(1)
        sampled_neg_inds = torch.nonzero(cat(sampled_neg_inds, dim=0)).squeeze(1)
        sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)

        objectness = objectness.permute(0, 2, 3, 1).reshape(-1)
        box_regression = box_regression.permute(0, 2, 3, 1).reshape(-1, 4)

        labels = torch.cat(labels)
        regression_targets = torch.cat(regression_targets, dim=0)
        labels = cat(labels)
        regression_targets = cat(regression_targets, dim=0)

        box_loss = smooth_l1_loss(
            box_regression[sampled_pos_inds],