diff --git a/cmt.py b/cmt.py new file mode 100644 index 00000000..cf4ef611 --- /dev/null +++ b/cmt.py @@ -0,0 +1,209 @@ +# code adapted from https://github.com/Shengcao-Cao/CMT +import warnings +import torch +import torch.nn as nn +from torchvision.ops import roi_align + + +def data2boxes(data): + boxes = [] + for i in range(len(data)): + boxes_i = data[i]['instances'].gt_boxes.tensor + if boxes_i.shape[0]: + indices = i * torch.ones((boxes_i.shape[0], 1), dtype=boxes_i.dtype, device=boxes_i.device) + boxes_i = torch.cat([indices, boxes_i], dim=1) + boxes.append(boxes_i) + if len(boxes): + boxes = torch.cat(boxes, dim=0) + return boxes + else: + return None + + +def data2labels(data): + labels = [] + for i in range(len(data)): + labels_i = data[i]['instances'].gt_classes + if labels_i.shape[0]: + labels.append(labels_i) + labels = torch.cat(labels, dim=0) + return labels + + +def locate_feature_roialign(feature_map, boxes, image_width, image_height): + selected_features = [] + sx = feature_map.shape[3] / image_width + sy = feature_map.shape[2] / image_height + if len(boxes): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + boxes_level = torch.tensor(boxes, device=feature_map.device) + boxes_level[:, 1] *= sx + boxes_level[:, 2] *= sy + boxes_level[:, 3] *= sx + boxes_level[:, 4] *= sy + selected_features_level = roi_align(feature_map, boxes_level, output_size=1, aligned=True) + selected_features_level = torch.flatten(selected_features_level, start_dim=1) + selected_features = selected_features_level + else: + selected_features = None + return selected_features + + +class SupConLoss(nn.Module): + """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. + It also supports the unsupervised contrastive loss in SimCLR""" + def __init__(self, temperature=0.07, contrast_mode='all', + base_temperature=0.07): + super(SupConLoss, self).__init__() + self.temperature = temperature + self.contrast_mode = contrast_mode + self.base_temperature = base_temperature + + def forward(self, features, labels=None, mask=None, weights=None): + """Compute loss for model. If both `labels` and `mask` are None, + it degenerates to SimCLR unsupervised loss: + https://arxiv.org/pdf/2002.05709.pdf + + Args: + features: hidden vector of shape [bsz, n_views, ...]. + labels: ground truth of shape [bsz]. + mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j + has the same class as sample i. Can be asymmetric. + Returns: + A loss scalar. + """ + device = (torch.device('cuda') + if features.is_cuda + else torch.device('cpu')) + + if len(features.shape) < 3: + raise ValueError('`features` needs to be [bsz, n_views, ...],' + 'at least 3 dimensions are required') + if len(features.shape) > 3: + features = features.view(features.shape[0], features.shape[1], -1) + + batch_size = features.shape[0] + + if labels is not None and mask is not None: + raise ValueError('Cannot define both `labels` and `mask`') + elif labels is None and mask is None: + mask = torch.eye(batch_size, dtype=torch.float32).to(device) + elif labels is not None: + labels = labels.contiguous().view(-1, 1) + if labels.shape[0] != batch_size: + raise ValueError('Num of labels does not match num of features') + mask = torch.eq(labels, labels.T).float().to(device) + else: + mask = mask.float().to(device) + + contrast_count = features.shape[1] + contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) + if self.contrast_mode == 'one': + anchor_feature = features[:, 0] + anchor_count = 1 + elif self.contrast_mode == 'all': + anchor_feature = contrast_feature + anchor_count = contrast_count + else: + raise ValueError('Unknown mode: {}'.format(self.contrast_mode)) + + # compute logits + anchor_dot_contrast = torch.div( + torch.matmul(anchor_feature, contrast_feature.T), + self.temperature) + # for numerical stability + logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) + logits = anchor_dot_contrast - logits_max.detach() + + # tile mask + mask = mask.repeat(anchor_count, contrast_count) + # mask-out self-contrast cases + logits_mask = torch.scatter( + torch.ones_like(mask), + 1, + torch.arange(batch_size * anchor_count).view(-1, 1).to(device), + 0 + ) + mask = mask * logits_mask + + # compute log_prob + exp_logits = torch.exp(logits) * logits_mask + log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) + + # compute mean of log-likelihood over positive + mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) + + # loss + loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos + + if weights is not None: + loss = (loss.view(anchor_count, batch_size) * weights).sum() / weights.sum() + else: + loss = loss.view(anchor_count, batch_size).mean() + + return loss + + +def get_cmt_losses(unlabeled_strong, unlabeled_weak, pseudolabeled_data, features_student, features_teacher, cmt_loss_weight): + unlabel_data_q, unlabel_data_k, pslabel_data_q = unlabeled_strong, unlabeled_weak, pseudolabeled_data, + all_unlabel_data = pslabel_data_q + + feature_levels = ['p4', 'p5', 'p6'] # TODO: make configurable + + supconloss = SupConLoss(contrast_mode='one') + + # 7. CMT: object-level contrastive learning + record_dict = dict() + boxes = data2boxes(all_unlabel_data) + image_width = all_unlabel_data[0]['image'].shape[2] + image_height = all_unlabel_data[0]['image'].shape[1] + + # filter objects that are too different in two views + if boxes is not None: + flags = [] + for i in range(boxes.shape[0]): + box_i = boxes[i].to(torch.int) + image_index = box_i[0] + x1 = box_i[1] + y1 = box_i[2] + x2 = box_i[3] + y2 = box_i[4] + image_q_patch = unlabel_data_q[image_index]['image'][:, y1:y2, x1:x2].to(torch.float) + image_k_patch = unlabel_data_k[image_index]['image'][:, y1:y2, x1:x2].to(torch.float) + diff = (image_q_patch - image_k_patch).absolute().flatten() + ratio = (diff > 40).sum() / diff.numel() + if ratio > 0.5: + flags.append(0) + else: + flags.append(1) + else: + flags = [0] + + if sum(flags): + # build contrastive loss + for feature_level in feature_levels: + # get student and teacher features for objects + object_features_student = locate_feature_roialign(features_student[feature_level], boxes, image_width, image_height) + object_features_teacher = locate_feature_roialign(features_teacher[feature_level], boxes, image_width, image_height) + object_features_student = torch.nn.functional.normalize(object_features_student, dim=1) + object_features_teacher = torch.nn.functional.normalize(object_features_teacher, dim=1) + + # compute contrastive loss + object_features_all = torch.stack([object_features_student, object_features_teacher], dim=1) + object_labels = data2labels(all_unlabel_data) + + # exclude unused objects + flags = [bool(x) for x in flags] + object_features_all = object_features_all[flags] + object_labels = object_labels[flags] + loss_contrastive_object = supconloss(object_features_all, object_labels) + + # record contrastive loss + record_dict['loss_contrastive_object' + '_' + feature_level] = loss_contrastive_object * cmt_loss_weight + + else: + for feature_level in feature_levels: + record_dict['loss_contrastive_object' + '_' + feature_level] = torch.tensor(0.0) + + return record_dict \ No newline at end of file diff --git a/config.py b/config.py index 3ef3c0a4..b9729e89 100644 --- a/config.py +++ b/config.py @@ -61,6 +61,11 @@ def add_da_config(cfg): _C.GRCNN.EFL_LAMBDA = [0.5, 0.5] _C.GRCNN.MODEL_TYPE = "GAUSSIAN" + # Contrastive Mean Teacher (CMT) settings + _C.MODEL.CMT = CN() + _C.MODEL.CMT.ENABLED = False + _C.MODEL.CMT.CONTRASTIVE_LOSS_WEIGHT = 0.05 + # We interpret SOLVER.IMS_PER_BATCH as the total batch size on all GPUs, for # experimental consistency. Gradient accumulation is used according to # num_gradient_accum_steps = IMS_PER_BATCH / (NUM_GPUS * IMS_PER_GPU) diff --git a/configs/cfc/cfc_priorart/AT-CFC.yaml b/configs/cfc/cfc_priorart/AT-CFC.yaml new file mode 100644 index 00000000..ecf81d38 --- /dev/null +++ b/configs/cfc/cfc_priorart/AT-CFC.yaml @@ -0,0 +1,5 @@ +_BASE_: "../MeanTeacher-CFC.yaml" +MODEL: + DA: + ENABLED: True +OUTPUT_DIR: "output/cfc_adaptiveteacher_nobackwardatend/" \ No newline at end of file diff --git a/configs/cfc/cfc_priorart/CMT-AT-CFC.yaml b/configs/cfc/cfc_priorart/CMT-AT-CFC.yaml new file mode 100644 index 00000000..dab91b4c --- /dev/null +++ b/configs/cfc/cfc_priorart/CMT-AT-CFC.yaml @@ -0,0 +1,12 @@ +_BASE_: "./AT-CFC.yaml" +MODEL: + CMT: + ENABLED: True + CONTRASTIVE_LOSS_WEIGHT: 0.01 +SOLVER: + IMS_PER_BATCH: 48 + IMS_PER_GPU: 2 + BACKWARD_AT_END: True +INPUT: + MIN_SIZE_TRAIN: (1024,) # TODO: currently need identical image sizes inside a single batch to enable CMT +OUTPUT_DIR: "output/cfc_cmtat/" \ No newline at end of file diff --git a/configs/cityscapes/cityscapes_priorart/CMT-AT-Cityscapes.yaml b/configs/cityscapes/cityscapes_priorart/CMT-AT-Cityscapes.yaml new file mode 100644 index 00000000..f49d7eab --- /dev/null +++ b/configs/cityscapes/cityscapes_priorart/CMT-AT-Cityscapes.yaml @@ -0,0 +1,12 @@ +_BASE_: "./AT-Cityscapes.yaml" +MODEL: + CMT: + ENABLED: True + CONTRASTIVE_LOSS_WEIGHT: 0.01 +SOLVER: + IMS_PER_BATCH: 48 + IMS_PER_GPU: 2 + BACKWARD_AT_END: True +INPUT: + MIN_SIZE_TRAIN: (1024,) # TODO: currently need identical image sizes inside a single batch to enable CMT +OUTPUT_DIR: "output/cityscapes_cmtat/" \ No newline at end of file diff --git a/configs/sim10k/sim10k_priorart/AT-Sim10k.yaml b/configs/sim10k/sim10k_priorart/AT-Sim10k.yaml new file mode 100644 index 00000000..6e3a3395 --- /dev/null +++ b/configs/sim10k/sim10k_priorart/AT-Sim10k.yaml @@ -0,0 +1,5 @@ +_BASE_: "../MeanTeacher-Sim10k.yaml" +MODEL: + DA: + ENABLED: True +OUTPUT_DIR: "output/sim10k_adaptiveteacher_nobackwardatend/" \ No newline at end of file diff --git a/configs/sim10k/sim10k_priorart/CMT-AT-Sim10k.yaml b/configs/sim10k/sim10k_priorart/CMT-AT-Sim10k.yaml new file mode 100644 index 00000000..5893fd55 --- /dev/null +++ b/configs/sim10k/sim10k_priorart/CMT-AT-Sim10k.yaml @@ -0,0 +1,12 @@ +_BASE_: "./AT-Sim10k.yaml" +MODEL: + CMT: + ENABLED: True + CONTRASTIVE_LOSS_WEIGHT: 0.01 +SOLVER: + IMS_PER_BATCH: 48 + IMS_PER_GPU: 2 + BACKWARD_AT_END: True +INPUT: + MIN_SIZE_TRAIN: (1024,) # TODO: currently need identical image sizes inside a single batch to enable CMT +OUTPUT_DIR: "output/sim10k_cmtat/" \ No newline at end of file diff --git a/rcnn.py b/rcnn.py index 565b542e..f76bc154 100644 --- a/rcnn.py +++ b/rcnn.py @@ -33,6 +33,7 @@ def __init__( sada_heads: SADA = None, dis_type: str = None, dis_loss_weight: float = 0.0, + cmt_loss_weight: float = 0.0, **kwargs ): super(DARCNN, self).__init__(**kwargs) @@ -51,6 +52,8 @@ def __init__( if self.dis_type: assert sada_heads is None, "Can't have both SADA heads and DA heads" self.sada_heads = FCDiscriminator_img(self.backbone._out_feature_channels[self.dis_type]) + + self.cmt_loss_weight = cmt_loss_weight # register hooks so we can grab output of sub-modules self.backbone_io, self.rpn_io, self.roih_io, self.boxhead_io, self.boxpred_io = SaveIO(), SaveIO(), SaveIO(), SaveIO(), SaveIO() @@ -77,6 +80,9 @@ def from_config(cls, cfg): ret.update({"dis_type": cfg.MODEL.DA.DIS_TYPE, "dis_loss_weight": cfg.MODEL.DA.DIS_LOSS_WEIGHT, }) + + if cfg.MODEL.CMT.ENABLED: + ret.update({"cmt_loss_weight": cfg.MODEL.CMT.CONTRASTIVE_LOSS_WEIGHT}) return ret diff --git a/trainer.py b/trainer.py index 15d38fe6..65faaa9b 100644 --- a/trainer.py +++ b/trainer.py @@ -18,6 +18,7 @@ from dataloader import SaveWeakDatasetMapper, UnlabeledDatasetMapper, WeakStrongDataloader from ema import EMA from pseudolabeler import PseudoLabeler +from cmt import get_cmt_losses DEBUG = False @@ -46,8 +47,19 @@ def run_model_labeled_unlabeled(trainer, labeled_weak, labeled_strong, unlabeled do_weak = labeled_weak is not None do_strong = labeled_strong is not None do_unlabeled = unlabeled_weak is not None and pseudo_labeler is not None + do_cmt = _model.cmt_loss_weight != 0 total_batch_size = sum([len(s or []) for s in [labeled_weak, labeled_strong, unlabeled_weak]]) num_grad_accum_steps = total_batch_size // model_batch_size + features_student = {} + features_teacher = {} + + def merge_backbone_features(acc, val): + assert not acc or acc.keys() == val.keys() + for k in val.keys(): + acc[k] = torch.cat((acc[k], val[k])) if k in acc else val[k] + return acc + + assert not (do_cmt and not backward_at_end) if DEBUG: debug_dict['last_labeled_weak'] = copy.deepcopy(labeled_weak) @@ -76,13 +88,15 @@ def maybe_do_backward(losses, key_conditional=lambda k: True): losses = { k: v * 0 if not key_conditional(k) else v for k, v in losses.items() } trainer.do_backward(sum(losses.values()) / num_grad_accum_steps, override=True) - def do_training_step(data, name="", key_conditional=lambda k: True, **kwargs): + def do_training_step(data, name="", key_conditional=lambda k: True, do_cmt= False, **kwargs): """Helper method to do a forward pass: - Handle gradient accumulation and possible backward passes - Handle Detectron2's loss dictionary """ for batch_i in range(0, len(data), model_batch_size): loss = model(data[batch_i:batch_i+model_batch_size], **kwargs) + if do_cmt and name == "target_pseudolabeled": + merge_backbone_features(features_student, _model.backbone_io.output) maybe_do_backward(loss, key_conditional) add_to_loss_dict(loss, name, key_conditional) @@ -104,10 +118,15 @@ def do_training_step(data, name="", key_conditional=lambda k: True, **kwargs): for batch_i in range(0, len(unlabeled_weak), model_batch_size): pseudolabeled_data.extend(pseudo_labeler(unlabeled_weak[batch_i:batch_i+model_batch_size], unlabeled_strong[batch_i:batch_i+model_batch_size])) - do_training_step(pseudolabeled_data, "target_pseudolabeled", labeled=False, do_sada=False) + if do_cmt: + merge_backbone_features(features_teacher, pseudo_labeler.model.model.backbone_io.output) + do_training_step(pseudolabeled_data, "target_pseudolabeled", labeled=False, do_sada=False, do_cmt=do_cmt) if DEBUG: debug_dict['last_pseudolabeled'] = copy.deepcopy(pseudolabeled_data) + if do_cmt: + loss_dict.update(get_cmt_losses(unlabeled_strong, unlabeled_weak, pseudolabeled_data, features_student, features_teacher, _model.cmt_loss_weight)) + return loss_dict