Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 209 additions & 0 deletions cmt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
# code adapted from https://github.com/Shengcao-Cao/CMT
import warnings
import torch
import torch.nn as nn
from torchvision.ops import roi_align


def data2boxes(data):
    """Collect the ground-truth boxes of a batch into one RoI-style tensor.

    Each box row is prefixed with the index of the image (its position in
    ``data``) it belongs to, so a row has one more column than the per-image
    box tensor.

    Args:
        data: sequence of per-image dicts; ``sample['instances'].gt_boxes.tensor``
            must be a 2-D tensor with one row per box (presumably
            ``(x1, y1, x2, y2)`` columns -- confirm against the data mapper).

    Returns:
        A tensor with all boxes of the batch concatenated along dim 0, keeping
        the dtype and device of the per-image box tensors, or ``None`` when no
        image in ``data`` has any box.
    """
    boxes = []
    for image_index, sample in enumerate(data):
        boxes_i = sample['instances'].gt_boxes.tensor
        num_boxes = boxes_i.shape[0]
        if not num_boxes:
            # Images without boxes contribute nothing.
            continue
        # Column holding the batch index, same dtype/device as the boxes.
        indices = boxes_i.new_full((num_boxes, 1), image_index)
        boxes.append(torch.cat([indices, boxes_i], dim=1))

    if not boxes:
        return None
    return torch.cat(boxes, dim=0)


def data2labels(data):
    """Concatenate the ground-truth class labels of every image in a batch.

    Images are visited in order and images without instances are skipped.

    Args:
        data: sequence of per-image dicts; ``sample['instances'].gt_classes``
            must be a tensor with one entry per instance.

    Returns:
        A 1-D tensor of all labels concatenated along dim 0, or ``None`` when
        no image in ``data`` has any instance. (``torch.cat`` rejects an empty
        list, so the empty case is reported as ``None`` instead of raising.)
    """
    labels = []
    for sample in data:
        labels_i = sample['instances'].gt_classes
        if labels_i.shape[0]:
            labels.append(labels_i)

    if not labels:
        return None
    return torch.cat(labels, dim=0)


def locate_feature_roialign(feature_map, boxes, image_width, image_height):
    """Pool one feature vector per box from ``feature_map`` using RoIAlign.

    Box coordinates are given in image pixels and are rescaled to feature-map
    coordinates before pooling. The input ``boxes`` is never modified.

    Args:
        feature_map: 4-D tensor; dims 2 and 3 are read as the feature-map
            height and width (presumably ``(N, C, H, W)`` -- confirm).
        boxes: rows of ``(batch_index, x1, y1, x2, y2)`` in image pixels, or
            ``None`` / empty when there are no boxes.
        image_width: width in pixels of the images the boxes refer to.
        image_height: height in pixels of the images the boxes refer to.

    Returns:
        A 2-D tensor with one flattened ``output_size=1`` RoIAlign feature per
        box, or ``None`` when ``boxes`` is ``None`` or empty.
    """
    if boxes is None or len(boxes) == 0:
        return None

    # Feature-map cells per image pixel along each axis.
    sx = feature_map.shape[3] / image_width
    sy = feature_map.shape[2] / image_height

    # Detached private copy on the feature map's device: the in-place scaling
    # below must not leak into the caller's tensor. This replaces
    # torch.tensor(boxes), which copies too but warns for tensor inputs.
    boxes_level = torch.as_tensor(boxes, device=feature_map.device).detach().clone()
    boxes_level[:, 1] *= sx
    boxes_level[:, 2] *= sy
    boxes_level[:, 3] *= sx
    boxes_level[:, 4] *= sy

    pooled = roi_align(feature_map, boxes_level, output_size=1, aligned=True)
    return torch.flatten(pooled, start_dim=1)


class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR"""

    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07):
        """Store the loss hyper-parameters.

        Args:
            temperature: divisor applied to the feature dot products.
            contrast_mode: ``'all'`` uses every view as an anchor, ``'one'``
                uses only the first view of each sample.
            base_temperature: the loss is scaled by
                ``temperature / base_temperature``.
        """
        super().__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def forward(self, features, labels=None, mask=None, weights=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
            weights: optional weights multiplied with the per-anchor losses
                after they are viewed as [anchor_count, bsz]; the result is
                normalised by ``weights.sum()``.
        Returns:
            A loss scalar. Anchors without any positive contribute zero.
        Raises:
            ValueError: if `features` has fewer than 3 dims, if both `labels`
                and `mask` are given, if the number of labels differs from
                bsz, or if `contrast_mode` is unknown.
        """
        # Use the tensor's own device; torch.device('cuda') would mean the
        # *current* CUDA device, which may differ from where `features` lives.
        device = features.device

        if features.dim() < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if features.dim() > 3:
            # reshape (not view) so non-contiguous inputs are accepted too
            features = features.reshape(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]

        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32, device=device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        # stack the views along the batch dim: [n_views * bsz, dim]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        self_index = torch.arange(batch_size * anchor_count, device=device).view(-1, 1)
        logits_mask = torch.scatter(torch.ones_like(mask), 1, self_index, 0)
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # compute mean of log-likelihood over positive; an anchor without
        # positives has a zero numerator, so dividing by 1 instead of 0 gives
        # a zero loss for it rather than NaN
        pos_per_anchor = mask.sum(1)
        pos_per_anchor = torch.where(pos_per_anchor < 1e-6,
                                     torch.ones_like(pos_per_anchor),
                                     pos_per_anchor)
        mean_log_prob_pos = (mask * log_prob).sum(1) / pos_per_anchor

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos

        if weights is not None:
            loss = (loss.view(anchor_count, batch_size) * weights).sum() / weights.sum()
        else:
            loss = loss.view(anchor_count, batch_size).mean()

        return loss


def get_cmt_losses(unlabeled_strong, unlabeled_weak, pseudolabeled_data, features_student, features_teacher, cmt_loss_weight,
                   feature_levels=('p4', 'p5', 'p6'), pixel_diff_threshold=40, max_changed_ratio=0.5):
    """Compute the CMT object-level contrastive losses, one per feature level.

    For every pseudo-labeled box the strong and weak views of its image are
    compared pixel-wise; boxes whose content differs too much between the two
    views are dropped. The remaining boxes are pooled from the student and
    teacher feature maps and fed to a supervised contrastive loss, using the
    pseudo-labels as classes and (student, teacher) as the two views.

    Args:
        unlabeled_strong: per-image dicts whose ``'image'`` is the strongly
            augmented view (indexed as ``[:, y, x]``).
        unlabeled_weak: per-image dicts whose ``'image'`` is the weakly
            augmented view, aligned with ``unlabeled_strong``.
        pseudolabeled_data: per-image dicts with ``'image'`` and ``'instances'``
            holding the pseudo-label boxes and classes.
        features_student: mapping from feature-level name to the student's
            feature map (presumably batched in the same image order as
            ``pseudolabeled_data`` -- confirm against the caller).
        features_teacher: same as ``features_student`` for the teacher.
        cmt_loss_weight: multiplier applied to every contrastive loss.
        feature_levels: names of the feature levels to compute a loss for.
        pixel_diff_threshold: absolute per-pixel difference above which a
            pixel counts as changed between the two views.
        max_changed_ratio: boxes with a larger fraction of changed pixels are
            excluded from the loss.

    Returns:
        Dict mapping ``'loss_contrastive_object_<level>'`` to a scalar tensor
        for every entry of ``feature_levels``; all zeros when no box survives
        the filtering (including an empty batch).
    """
    record_dict = dict()
    boxes = data2boxes(pseudolabeled_data)

    # filter objects that are too different in two views
    flags = []
    if boxes is not None:
        # One transfer of the truncated integer coordinates instead of
        # reading tensor scalars box by box.
        for image_index, x1, y1, x2, y2 in boxes.to(torch.int).tolist():
            image_q_patch = unlabeled_strong[image_index]['image'][:, y1:y2, x1:x2].to(torch.float)
            image_k_patch = unlabeled_weak[image_index]['image'][:, y1:y2, x1:x2].to(torch.float)
            diff = (image_q_patch - image_k_patch).absolute().flatten()
            ratio = (diff > pixel_diff_threshold).sum() / diff.numel()
            # NOTE: an empty patch yields a NaN ratio, which compares False
            # and therefore keeps the box (same as the original behavior).
            flags.append(not bool(ratio > max_changed_ratio))

    if not any(flags):
        # Nothing to contrast: still report every loss key, as zero.
        for feature_level in feature_levels:
            record_dict['loss_contrastive_object' + '_' + feature_level] = torch.tensor(0.0)
        return record_dict

    # Only read the image size once we know the batch is non-empty.
    image_width = pseudolabeled_data[0]['image'].shape[2]
    image_height = pseudolabeled_data[0]['image'].shape[1]

    supconloss = SupConLoss(contrast_mode='one')
    # Labels do not depend on the feature level, so select them once.
    object_labels = data2labels(pseudolabeled_data)[flags]

    # build contrastive loss
    for feature_level in feature_levels:
        # get student and teacher features for objects
        object_features_student = locate_feature_roialign(features_student[feature_level], boxes, image_width, image_height)
        object_features_teacher = locate_feature_roialign(features_teacher[feature_level], boxes, image_width, image_height)
        object_features_student = torch.nn.functional.normalize(object_features_student, dim=1)
        object_features_teacher = torch.nn.functional.normalize(object_features_teacher, dim=1)

        # student and teacher features act as the two views of each object;
        # exclude unused objects before computing the loss
        object_features_all = torch.stack([object_features_student, object_features_teacher], dim=1)[flags]
        loss_contrastive_object = supconloss(object_features_all, object_labels)

        # record contrastive loss
        record_dict['loss_contrastive_object' + '_' + feature_level] = loss_contrastive_object * cmt_loss_weight

    return record_dict
5 changes: 5 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ def add_da_config(cfg):
_C.GRCNN.EFL_LAMBDA = [0.5, 0.5]
_C.GRCNN.MODEL_TYPE = "GAUSSIAN"

# Contrastive Mean Teacher (CMT) settings
_C.MODEL.CMT = CN()
_C.MODEL.CMT.ENABLED = False
_C.MODEL.CMT.CONTRASTIVE_LOSS_WEIGHT = 0.05

# We interpret SOLVER.IMS_PER_BATCH as the total batch size on all GPUs, for
# experimental consistency. Gradient accumulation is used according to
# num_gradient_accum_steps = IMS_PER_BATCH / (NUM_GPUS * IMS_PER_GPU)
Expand Down
5 changes: 5 additions & 0 deletions configs/cfc/cfc_priorart/AT-CFC.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
_BASE_: "../MeanTeacher-CFC.yaml"
MODEL:
DA:
ENABLED: True
OUTPUT_DIR: "output/cfc_adaptiveteacher_nobackwardatend/"
12 changes: 12 additions & 0 deletions configs/cfc/cfc_priorart/CMT-AT-CFC.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
_BASE_: "./AT-CFC.yaml"
MODEL:
CMT:
ENABLED: True
CONTRASTIVE_LOSS_WEIGHT: 0.01
SOLVER:
IMS_PER_BATCH: 48
IMS_PER_GPU: 2
BACKWARD_AT_END: True
INPUT:
MIN_SIZE_TRAIN: (1024,) # TODO: currently need identical image sizes inside a single batch to enable CMT
OUTPUT_DIR: "output/cfc_cmtat/"
12 changes: 12 additions & 0 deletions configs/cityscapes/cityscapes_priorart/CMT-AT-Cityscapes.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
_BASE_: "./AT-Cityscapes.yaml"
MODEL:
CMT:
ENABLED: True
CONTRASTIVE_LOSS_WEIGHT: 0.01
SOLVER:
IMS_PER_BATCH: 48
IMS_PER_GPU: 2
BACKWARD_AT_END: True
INPUT:
MIN_SIZE_TRAIN: (1024,) # TODO: currently need identical image sizes inside a single batch to enable CMT
OUTPUT_DIR: "output/cityscapes_cmtat/"
5 changes: 5 additions & 0 deletions configs/sim10k/sim10k_priorart/AT-Sim10k.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
_BASE_: "../MeanTeacher-Sim10k.yaml"
MODEL:
DA:
ENABLED: True
OUTPUT_DIR: "output/sim10k_adaptiveteacher_nobackwardatend/"
12 changes: 12 additions & 0 deletions configs/sim10k/sim10k_priorart/CMT-AT-Sim10k.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
_BASE_: "./AT-Sim10k.yaml"
MODEL:
CMT:
ENABLED: True
CONTRASTIVE_LOSS_WEIGHT: 0.01
SOLVER:
IMS_PER_BATCH: 48
IMS_PER_GPU: 2
BACKWARD_AT_END: True
INPUT:
MIN_SIZE_TRAIN: (1024,) # TODO: currently need identical image sizes inside a single batch to enable CMT
OUTPUT_DIR: "output/sim10k_cmtat/"
6 changes: 6 additions & 0 deletions rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(
sada_heads: SADA = None,
dis_type: str = None,
dis_loss_weight: float = 0.0,
cmt_loss_weight: float = 0.0,
**kwargs
):
super(DARCNN, self).__init__(**kwargs)
Expand All @@ -51,6 +52,8 @@ def __init__(
if self.dis_type:
assert sada_heads is None, "Can't have both SADA heads and DA heads"
self.sada_heads = FCDiscriminator_img(self.backbone._out_feature_channels[self.dis_type])

self.cmt_loss_weight = cmt_loss_weight

# register hooks so we can grab output of sub-modules
self.backbone_io, self.rpn_io, self.roih_io, self.boxhead_io, self.boxpred_io = SaveIO(), SaveIO(), SaveIO(), SaveIO(), SaveIO()
Expand All @@ -77,6 +80,9 @@ def from_config(cls, cfg):
ret.update({"dis_type": cfg.MODEL.DA.DIS_TYPE,
"dis_loss_weight": cfg.MODEL.DA.DIS_LOSS_WEIGHT,
})

if cfg.MODEL.CMT.ENABLED:
ret.update({"cmt_loss_weight": cfg.MODEL.CMT.CONTRASTIVE_LOSS_WEIGHT})

return ret

Expand Down
23 changes: 21 additions & 2 deletions trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from dataloader import SaveWeakDatasetMapper, UnlabeledDatasetMapper, WeakStrongDataloader
from ema import EMA
from pseudolabeler import PseudoLabeler
from cmt import get_cmt_losses


DEBUG = False
Expand Down Expand Up @@ -46,8 +47,19 @@ def run_model_labeled_unlabeled(trainer, labeled_weak, labeled_strong, unlabeled
do_weak = labeled_weak is not None
do_strong = labeled_strong is not None
do_unlabeled = unlabeled_weak is not None and pseudo_labeler is not None
do_cmt = _model.cmt_loss_weight != 0
total_batch_size = sum([len(s or []) for s in [labeled_weak, labeled_strong, unlabeled_weak]])
num_grad_accum_steps = total_batch_size // model_batch_size
features_student = {}
features_teacher = {}

def merge_backbone_features(acc, val):
assert not acc or acc.keys() == val.keys()
for k in val.keys():
acc[k] = torch.cat((acc[k], val[k])) if k in acc else val[k]
return acc

assert not (do_cmt and not backward_at_end)

if DEBUG:
debug_dict['last_labeled_weak'] = copy.deepcopy(labeled_weak)
Expand Down Expand Up @@ -76,13 +88,15 @@ def maybe_do_backward(losses, key_conditional=lambda k: True):
losses = { k: v * 0 if not key_conditional(k) else v for k, v in losses.items() }
trainer.do_backward(sum(losses.values()) / num_grad_accum_steps, override=True)

def do_training_step(data, name="", key_conditional=lambda k: True, **kwargs):
def do_training_step(data, name="", key_conditional=lambda k: True, do_cmt= False, **kwargs):
"""Helper method to do a forward pass:
- Handle gradient accumulation and possible backward passes
- Handle Detectron2's loss dictionary
"""
for batch_i in range(0, len(data), model_batch_size):
loss = model(data[batch_i:batch_i+model_batch_size], **kwargs)
if do_cmt and name == "target_pseudolabeled":
merge_backbone_features(features_student, _model.backbone_io.output)
maybe_do_backward(loss, key_conditional)
add_to_loss_dict(loss, name, key_conditional)

Expand All @@ -104,10 +118,15 @@ def do_training_step(data, name="", key_conditional=lambda k: True, **kwargs):
for batch_i in range(0, len(unlabeled_weak), model_batch_size):
pseudolabeled_data.extend(pseudo_labeler(unlabeled_weak[batch_i:batch_i+model_batch_size],
unlabeled_strong[batch_i:batch_i+model_batch_size]))
do_training_step(pseudolabeled_data, "target_pseudolabeled", labeled=False, do_sada=False)
if do_cmt:
merge_backbone_features(features_teacher, pseudo_labeler.model.model.backbone_io.output)
do_training_step(pseudolabeled_data, "target_pseudolabeled", labeled=False, do_sada=False, do_cmt=do_cmt)
if DEBUG:
debug_dict['last_pseudolabeled'] = copy.deepcopy(pseudolabeled_data)

if do_cmt:
loss_dict.update(get_cmt_losses(unlabeled_strong, unlabeled_weak, pseudolabeled_data, features_student, features_teacher, _model.cmt_loss_weight))

return loss_dict


Expand Down