Created
July 23, 2020 11:35
-
-
Save mathildecaron31/bcd03b8864f7ca1aeb89dfe76a118b14 to your computer and use it in GitHub Desktop.
Running DETR with SwAV RN50 backbone: modification to https://github.com/facebookresearch/detr/blob/master/models/backbone.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | |
""" | |
Backbone modules. | |
""" | |
from collections import OrderedDict | |
import torch | |
import torch.nn.functional as F | |
import torchvision | |
from torch import nn | |
from torchvision.models._utils import IntermediateLayerGetter | |
from typing import Dict, List | |
from util.misc import NestedTensor, is_main_process | |
from .position_encoding import build_position_encoding | |
class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.misc.ops with added eps before rsqrt,
    without which any other models than torchvision.models.resnet[18,34,50,101]
    produce nans.

    The four per-channel tensors are registered as buffers (not parameters),
    so they are stored in / restored from the state dict but are never
    updated by an optimizer.

    Args:
        n: number of channels; every buffer has shape ``(n,)``.
        eps: value added to ``running_var`` before ``rsqrt`` for numerical
            stability. Defaults to ``1e-5``, the value previously hard-coded
            here. It is a plain attribute, so it is not part of the state dict.
    """

    def __init__(self, n, eps=1e-5):
        super(FrozenBatchNorm2d, self).__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))
        self.eps = eps

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # Checkpoints written by a regular nn.BatchNorm2d carry a
        # ``num_batches_tracked`` entry that this frozen layer has no buffer
        # for; drop it so such checkpoints load without an unexpected-key error.
        num_batches_tracked_key = prefix + 'num_batches_tracked'
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]
        super(FrozenBatchNorm2d, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        """Apply ``x * scale + bias`` with per-channel constants.

        The buffers are reshaped to ``(1, C, 1, 1)``, so ``x`` is expected to
        have channels on dim 1 (presumably ``(N, C, H, W)``).
        """
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        # Folding normalization and affine transform into a single
        # scale/shift pair: (x - rm) / sqrt(rv + eps) * w + b.
        scale = w * (rv + self.eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias

    def extra_repr(self):
        """Show the channel count and eps when the module is printed."""
        return f"{self.weight.shape[0]}, eps={self.eps}"
class BackboneBase(nn.Module):
    """Wrap a ResNet-style backbone so it returns named, masked feature maps.

    Args:
        backbone: module whose children include ``layer1`` .. ``layer4``.
        train_backbone: if False every backbone parameter is frozen; if True
            only parameters whose name mentions ``layer2``, ``layer3`` or
            ``layer4`` stay trainable (everything else is frozen).
        num_channels: channel count reported to downstream modules.
        return_interm_layers: if True return the outputs of all four stages
            (keys "0".."3"); otherwise only ``layer4`` (key "0").
    """

    def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
        super().__init__()
        trainable_stages = ('layer2', 'layer3', 'layer4')
        for param_name, param in backbone.named_parameters():
            belongs_to_trainable_stage = any(stage in param_name for stage in trainable_stages)
            if not (train_backbone and belongs_to_trainable_stage):
                param.requires_grad_(False)

        if return_interm_layers:
            stages = ("layer1", "layer2", "layer3", "layer4")
            return_layers = {stage: str(idx) for idx, stage in enumerate(stages)}
        else:
            return_layers = {'layer4': "0"}
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.num_channels = num_channels

    def forward(self, tensor_list: NestedTensor):
        """Run the backbone and attach a resized padding mask to each output."""
        features = self.body(tensor_list.tensors)
        out: Dict[str, NestedTensor] = {}
        for key, feature in features.items():
            padding_mask = tensor_list.mask
            assert padding_mask is not None
            # Resample the mask to this feature map's spatial size.
            resized = F.interpolate(padding_mask[None].float(), size=feature.shape[-2:])
            out[key] = NestedTensor(feature, resized.to(torch.bool)[0])
        return out
class Backbone(BackboneBase):
    """ResNet-50 backbone with frozen BatchNorm, initialised from SwAV weights.

    The main process builds the torchvision model and then overwrites its
    weights with the SwAV checkpoint at ``_SWAV_URL``. That checkpoint holds
    ResNet-50 weights, so ``name`` must be ``'resnet50'``.

    Args:
        name: torchvision model constructor name; must be ``'resnet50'``.
        train_backbone: forwarded to ``BackboneBase`` (controls freezing).
        return_interm_layers: forwarded to ``BackboneBase``.
        dilation: replace the stride of the last stage with dilation.

    Raises:
        ValueError: if ``name`` is not ``'resnet50'``.
    """

    # SwAV ResNet-50 checkpoint (800-epoch pre-training).
    _SWAV_URL = "https://dl.fbaipublicfiles.com/deepcluster/swav_800ep_pretrain.pth.tar"

    def __init__(self, name: str,
                 train_backbone: bool,
                 return_interm_layers: bool,
                 dilation: bool):
        # Validate on every process and before building anything.
        # Previously this was an `assert` that ran only on the main process
        # (and is stripped under `python -O`).
        if name != 'resnet50':
            raise ValueError(
                f"SwAV weights are only available for 'resnet50', got {name!r}")
        backbone = getattr(torchvision.models, name)(
            replace_stride_with_dilation=[False, False, dilation],
            pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d)
        if is_main_process():
            self._load_swav_weights(backbone)
        num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
        super().__init__(backbone, train_backbone, num_channels, return_interm_layers)

    @staticmethod
    def _load_swav_weights(backbone: nn.Module) -> None:
        """Overwrite ``backbone``'s weights with the SwAV checkpoint.

        Warns (rather than failing) if the checkpoint leaves any backbone
        entry other than the ``fc`` classifier un-loaded.
        """
        import warnings

        state_dict = torch.hub.load_state_dict_from_url(
            url=Backbone._SWAV_URL,
            map_location="cpu",
        )
        # Strip a leading "module." (as added by DataParallel-style wrappers).
        # Only the prefix is removed, not occurrences elsewhere in the key.
        prefix = "module."
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }
        # strict=False tolerates keys present on only one side. ``fc`` is
        # ignored below because IntermediateLayerGetter stops at layer4, so
        # the classifier is never used.
        msg = backbone.load_state_dict(state_dict, strict=False)
        missing = [k for k in msg.missing_keys if not k.startswith("fc.")]
        if missing:
            warnings.warn(
                f"SwAV checkpoint did not provide {len(missing)} backbone "
                f"entries; they keep their previous values: {missing}",
                stacklevel=2,
            )
class Joiner(nn.Sequential):
    """Chain a backbone with a positional encoding.

    ``self[0]`` is the backbone, which returns a dict of feature maps.
    ``self[1]`` is the position embedding, applied to each feature map.
    """

    def __init__(self, backbone, position_embedding):
        super().__init__(backbone, position_embedding)

    def forward(self, tensor_list: NestedTensor):
        """Return ``(features, positions)`` as parallel lists.

        The lists follow the backbone's output order. Each positional
        encoding is cast to the dtype of its feature map's tensors.
        """
        backbone, position_embedding = self[0], self[1]
        out: List[NestedTensor] = list(backbone(tensor_list).values())
        pos = [position_embedding(feature).to(feature.tensors.dtype) for feature in out]
        return out, pos
def build_backbone(args):
    """Assemble the backbone + positional-encoding module from ``args``.

    Reads ``args.backbone``, ``args.lr_backbone`` (the backbone is trained
    only when it is > 0), ``args.masks`` (return intermediate layers) and
    ``args.dilation``. ``build_position_encoding`` also receives ``args``.
    The returned ``Joiner`` exposes the backbone's ``num_channels``.
    """
    position_embedding = build_position_encoding(args)
    resnet = Backbone(
        args.backbone,
        args.lr_backbone > 0,
        args.masks,
        args.dilation,
    )
    joined = Joiner(resnet, position_embedding)
    joined.num_channels = resnet.num_channels
    return joined
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment