Description
Hi! Thanks for the great work and for releasing the code.
My question is: is it possible to apply your temporal loss to other video tasks, such as video semantic segmentation and video depth estimation? In those areas, most temporal losses are based on optical-flow warping, which is quite time-consuming during training. Your temporal loss is applied to RGB outputs; could it be extended to semantic predictions or depth maps?
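As far as I can tell from the code below, nothing in the relation-based term is specific to RGB: `nn.AvgPool2d` pools each channel independently, so the same frame-difference comparison can be computed on a one-channel depth map or on per-class softmax probabilities. Here is a minimal sketch, assuming predictions of shape (B, C, H, W) for two frames and one-hot encoded segmentation labels. The helper `relation_term` and all shapes are hypothetical, and this only shows the loss is computable on such outputs, not that it improves temporal consistency for those tasks.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def relation_term(pred_t, pred_t1, gt_t, gt_t1, k_size=3):
    """Hypothetical helper: L1 distance between blurred frame differences.

    Works for any channel count C, because AvgPool2d pools each channel separately.
    Assumes an odd k_size so that the spatial size is preserved.
    """
    blur = nn.AvgPool2d(k_size, stride=1, padding=(k_size - 1) // 2)
    gt_diff = blur(gt_t) - blur(gt_t1)
    out_diff = blur(pred_t) - blur(pred_t1)
    return F.l1_loss(out_diff, gt_diff)


B, H, W = 2, 16, 16

# Depth estimation: one channel per frame (assumed shape B x 1 x H x W).
depth_t, depth_t1 = torch.rand(B, 1, H, W), torch.rand(B, 1, H, W)
gt_depth_t, gt_depth_t1 = torch.rand(B, 1, H, W), torch.rand(B, 1, H, W)
depth_loss = relation_term(depth_t, depth_t1, gt_depth_t, gt_depth_t1)

# Semantic segmentation: softmax probabilities vs. one-hot labels (C classes, assumed).
C = 5
logits_t, logits_t1 = torch.randn(B, C, H, W), torch.randn(B, C, H, W)
labels_t = torch.randint(0, C, (B, H, W))
labels_t1 = torch.randint(0, C, (B, H, W))
onehot_t = F.one_hot(labels_t, C).permute(0, 3, 1, 2).float()   # B x C x H x W
onehot_t1 = F.one_hot(labels_t1, C).permute(0, 3, 1, 2).float()
seg_loss = relation_term(logits_t.softmax(dim=1), logits_t1.softmax(dim=1),
                         onehot_t, onehot_t1)
```

Comparing probabilities against one-hot targets is just one possible choice; comparing against soft labels or feature maps would plug into the same helper.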
By the way, does `temporal_loss_mode == 2` perform worse than `temporal_loss_mode == 1` in your experiments? If so, what is the reason?
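Reading the two branches quoted below, let $A_k$ be $k \times k$ average pooling (stride 1, padding $(k-1)/2$), $\hat y, \hat y'$ the outputs `out_img`, `out_img_1`, $y, y'$ the targets `label`, `label_1`, $w_k$ the entries of `k_weights`, and $N$ the number of tensor elements (the default mean reduction of `F.l1_loss`). The two modes then compute:

$$
D_k = A_k(y) - A_k(y'), \qquad \hat D_k = A_k(\hat y) - A_k(\hat y')
$$

$$
\mathcal{L}_{\text{mode 1}} = \frac{1}{N}\sum_{p} \left| D_{k^*(p)}(p) - \hat D_{k^*(p)}(p) \right|, \qquad k^*(p) = \arg\min_{k} \left| \hat D_k(p) \right|
$$

$$
\mathcal{L}_{\text{mode 2}} = \sum_{k} w_k \, \frac{1}{N}\sum_{p} \left| D_k(p) - \hat D_k(p) \right|
$$

Here $p$ runs over every element (batch, channel, pixel). In mode 1 the scale is chosen per element from the output differences only, with ties going to the later entry of `k_sizes`; in mode 2 every scale contributes.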
```python
## use multi-scale relation-based loss
elif args.temporal_loss_mode == 1:
    # blur image/area statistics/intensity
    # k_sizes = [1, 3, 5, 7]
    k_sizes = args.k_sizes
    gt_errors = []
    out_errors = []
    for i in range(len(k_sizes)):
        k_size = k_sizes[i]
        avg_blur = nn.AvgPool2d(k_size, stride=1, padding=int((k_size - 1) / 2))
        gt_error = avg_blur(label) - avg_blur(label_1)
        out_error = avg_blur(out_img) - avg_blur(out_img_1)
        gt_errors.append(gt_error)
        out_errors.append(out_error)
    gt_error_rgb_pixel_min = gt_errors[0]
    out_error_rgb_pixel_min = out_errors[0]
    for j in range(1, len(k_sizes)):
        gt_error_rgb_pixel_min = torch.where(torch.abs(out_error_rgb_pixel_min) < torch.abs(out_errors[j]),
                                             gt_error_rgb_pixel_min, gt_errors[j])
        out_error_rgb_pixel_min = torch.where(torch.abs(out_error_rgb_pixel_min) < torch.abs(out_errors[j]),
                                              out_error_rgb_pixel_min, out_errors[j])
    loss_temporal = F.l1_loss(gt_error_rgb_pixel_min, out_error_rgb_pixel_min)
```
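To make sure I read mode 1 correctly, here is a self-contained sketch of that branch as a function. The name `multiscale_min_relation_loss` is hypothetical, and odd kernel sizes are assumed (as in the default `[1, 3, 5, 7]`) so that all scales keep the same spatial size.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def multiscale_min_relation_loss(out_img, out_img_1, label, label_1, k_sizes=(1, 3, 5, 7)):
    """Sketch of temporal_loss_mode == 1 (hypothetical standalone version).

    For each element, keep the scale whose blurred *output* difference has the
    smallest magnitude (ties go to the later scale), take the ground-truth
    difference at that same scale, and compare the two with L1.
    """
    gt_errors, out_errors = [], []
    for k in k_sizes:
        blur = nn.AvgPool2d(k, stride=1, padding=(k - 1) // 2)  # odd k keeps H and W
        gt_errors.append(blur(label) - blur(label_1))
        out_errors.append(blur(out_img) - blur(out_img_1))

    gt_min, out_min = gt_errors[0], out_errors[0]
    for gt_e, out_e in zip(gt_errors[1:], out_errors[1:]):
        # Compute the mask once so GT and output are always taken from the same scale.
        keep = out_min.abs() < out_e.abs()
        gt_min = torch.where(keep, gt_min, gt_e)
        out_min = torch.where(keep, out_min, out_e)
    return F.l1_loss(gt_min, out_min)
```

Computing `keep` once is equivalent to the original ordering, which updates the GT selection before overwriting `out_error_rgb_pixel_min`.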
```python
## Alternatively, combine relation-based loss at different scales with different weights
elif args.temporal_loss_mode == 2:
    # blur image/area statistics/intensity
    # k_sizes = [1, 3, 5, 7]
    k_sizes = args.k_sizes
    # k_weights = [0.25, 0.25, 0.25, 0.25]
    k_weights = args.k_weights
    loss_temporal = 0*loss
    for i in range(len(k_sizes)):
        k_size = k_sizes[i]
        k_weight = k_weights[i]
        avg_blur = nn.AvgPool2d(k_size, stride=1, padding=int((k_size - 1) / 2))
        gt_error = avg_blur(label) - avg_blur(label_1)
        out_error = avg_blur(out_img) - avg_blur(out_img_1)
        loss_temporal = loss_temporal + F.l1_loss(gt_error, out_error) * k_weight
```
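For comparison, a self-contained sketch of mode 2 under the same assumptions (odd kernel sizes; the function name is hypothetical). A zero scalar on the input's device and dtype stands in for `0*loss`, which in the original relies on `loss` from earlier in the training loop.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def weighted_multiscale_relation_loss(out_img, out_img_1, label, label_1,
                                      k_sizes=(1, 3, 5, 7),
                                      k_weights=(0.25, 0.25, 0.25, 0.25)):
    """Sketch of temporal_loss_mode == 2 (hypothetical standalone version).

    Every scale contributes its own L1 term, scaled by the matching weight.
    """
    loss = torch.zeros((), device=out_img.device, dtype=out_img.dtype)
    for k, w in zip(k_sizes, k_weights):
        blur = nn.AvgPool2d(k, stride=1, padding=(k - 1) // 2)  # odd k keeps H and W
        gt_error = blur(label) - blur(label_1)
        out_error = blur(out_img) - blur(out_img_1)
        loss = loss + w * F.l1_loss(gt_error, out_error)
    return loss
```

The structural difference from mode 1: here an element whose differences match at one scale but not at another is still penalized at the mismatched scale, whereas mode 1 penalizes only the single scale selected for that element.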