From b02523d94d39f4439bc081d8097c054ca010d6ce Mon Sep 17 00:00:00 2001 From: Trenton Chang Date: Thu, 7 Jan 2021 22:57:01 -0800 Subject: [PATCH 1/3] Update _SimpleConsensus to use static autograd methods --- .../segmental_consensuses/simple_consensus.py | 33 ++++++++----------- 1 file changed, 14 insertions(+), 19 deletions(-) diff --git a/mmaction/models/tenons/segmental_consensuses/simple_consensus.py b/mmaction/models/tenons/segmental_consensuses/simple_consensus.py index 950fffb..2dddc48 100644 --- a/mmaction/models/tenons/segmental_consensuses/simple_consensus.py +++ b/mmaction/models/tenons/segmental_consensuses/simple_consensus.py @@ -1,33 +1,26 @@ import torch import torch.nn as nn -import torch.nn.functional as F from ...registry import SEGMENTAL_CONSENSUSES class _SimpleConsensus(torch.autograd.Function): """Simplest segmental consensus module""" - def __init__(self, - consensus_type='avg', - dim=1): - super(_SimpleConsensus, self).__init__() - - assert consensus_type in ['avg'] - self.consensus_type = consensus_type - self.dim = dim - self.shape = None - - def forward(self, x): - self.shape = x.size() - if self.consensus_type == 'avg': - output = x.mean(dim=self.dim, keepdim=True) + @staticmethod + def forward(ctx, x, dim, consensus_type): + ctx.save_for_backward(x, dim, consensus_type) + if consensus_type == 'avg': + output = x.mean(dim=dim, keepdim=True) else: output = None return output - def backward(self, grad_output): - if self.consensus_type == 'avg': - grad_in = grad_output.expand(self.shape) / float(self.shape[self.dim]) + @staticmethod + def backward(ctx, grad_output): + x, dim, consensus_type = ctx.saved_tensors + shape = x.size() + if consensus_type == 'avg': + grad_in = grad_output.expand(shape) / float(shape[dim]) else: grad_in = None return grad_in @@ -46,4 +39,6 @@ def init_weights(self): pass def forward(self, input): - return _SimpleConsensus(self.consensus_type, self.dim)(input) + return _SimpleConsensus.apply(input, + self.dim, + self.consensus_type) From c397a69cd71a8fff7d38a5b689f9795f35bb099a Mon Sep 17 00:00:00 2001 From: Shreyas Bhat Kera Date: Mon, 11 Jan 2021 15:29:28 +0530 Subject: [PATCH 2/3] Fix ctx-autograd problems for _SimpleConsensus Save dim and consensus_type as separate variables in ctx --- .../segmental_consensuses/simple_consensus.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/mmaction/models/tenons/segmental_consensuses/simple_consensus.py b/mmaction/models/tenons/segmental_consensuses/simple_consensus.py index 2dddc48..dfa5213 100644 --- a/mmaction/models/tenons/segmental_consensuses/simple_consensus.py +++ b/mmaction/models/tenons/segmental_consensuses/simple_consensus.py @@ -1,5 +1,6 @@ import torch import torch.nn as nn +import torch.nn.functional as F from ...registry import SEGMENTAL_CONSENSUSES @@ -7,23 +8,28 @@ class _SimpleConsensus(torch.autograd.Function): """Simplest segmental consensus module""" @staticmethod - def forward(ctx, x, dim, consensus_type): - ctx.save_for_backward(x, dim, consensus_type) + def forward(ctx,x,dim,consensus_type): + ctx.dim = dim + ctx.consensus_type=consensus_type + ctx.save_for_backward(x) if consensus_type == 'avg': output = x.mean(dim=dim, keepdim=True) else: output = None return output + @staticmethod - def backward(ctx, grad_output): - x, dim, consensus_type = ctx.saved_tensors + def backward( ctx,grad_output): + x, = ctx.saved_tensors + dim = ctx.dim + consensus_type=ctx.consensus_type shape = x.size() if consensus_type == 'avg': grad_in = grad_output.expand(shape) / float(shape[dim]) else: grad_in = None - return grad_in + return grad_in, None , None @SEGMENTAL_CONSENSUSES.register_module From f761af1e61086733a882cc37e0556cb47116f574 Mon Sep 17 00:00:00 2001 From: Shreyas Bhat Kera Date: Mon, 11 Jan 2021 15:31:39 +0530 Subject: [PATCH 3/3] Fix ctx-autograd problems for _SimpleConsensus Removed unnecessary import --- .../models/tenons/segmental_consensuses/simple_consensus.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmaction/models/tenons/segmental_consensuses/simple_consensus.py b/mmaction/models/tenons/segmental_consensuses/simple_consensus.py index dfa5213..ceaf1de 100644 --- a/mmaction/models/tenons/segmental_consensuses/simple_consensus.py +++ b/mmaction/models/tenons/segmental_consensuses/simple_consensus.py @@ -1,6 +1,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F + from ...registry import SEGMENTAL_CONSENSUSES