adaptive_weights = torch.fft.irfft2(torch.view_as_complex(DFT_map), dim=(1, 2)).reshape(batch_size, 1, RuntimeError: cuFFT error: CUFFT_INTERNAL_ERROR
adaptive_weights = torch.fft.irfft2(torch.view_as_complex(DFT_map), dim=(1, 2)).reshape(batch_size, 1,
RuntimeError: cuFFT error: CUFFT_INTERNAL_ERROR