## Description
When running inference with DAT models using fp16 (half-precision), the model crashes during the spatial attention calculation. This happens because the dynamically generated attention mask does not automatically match the precision of the input tensors.
## Steps to Reproduce

Load a DAT model at half precision with the following command, and the error is raised:

```shell
python ./test_code/inference.py --model DAT --scale 4 --downsample_threshold 720 --weight_path pretrained/4x_APISR_DAT_GAN_generator.pth --float16_inference true --input_dir ...
```
## Error Traceback

```
Traceback (most recent call last):
  File "/workspace/./test_code/inference.py", line 271, in <module>
    inner_loop(os.path.join(input_dir, filename))
  File "/workspace/./test_code/inference.py", line 254, in inner_loop
    super_resolve_img(generator, process_dir, output_path, weight_dtype, downsample_threshold, crop_for_4x=True)
  File "/workspace/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/./test_code/inference.py", line 68, in super_resolve_img
    super_resolved_img = generator(img_lr)
  File "/workspace/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workspace/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/architecture/dat.py", line 861, in forward
    x = self.conv_after_body(self.forward_features(x)) + x
  File "/workspace/architecture/dat.py", line 839, in forward_features
    x = layer(x, x_size)
  File "/workspace/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workspace/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/architecture/dat.py", line 652, in forward
    x = blk(x, x_size)
  File "/workspace/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workspace/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/architecture/dat.py", line 568, in forward
    x = x + self.drop_path(self.attn(self.norm1(x), H, W))
  File "/workspace/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workspace/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/architecture/dat.py", line 404, in forward
    x1_shift = self.attns[0](qkv_0, _H, _W, mask=mask_tmp[0].to(x.device))
  File "/workspace/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workspace/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/architecture/dat.py", line 244, in forward
    x = (attn @ v)
RuntimeError: expected scalar type Half but found Float
```
## Cause

In `architecture/dat.py`, when the input resolution does not match `patches_resolution`, a new mask is computed via `self.calculate_mask(_H, _W)`.

The mask is generated as a `Float32` tensor by default. While it is correctly moved to the same device as the input, its dtype is never cast to `Float16`. When this mask is added to the attention scores and the result is then multiplied by the value tensor `v` (which is `Half`), PyTorch raises the dtype mismatch error above.
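The promotion behavior can be reproduced outside the model. The snippet below is an illustrative sketch, not the DAT code itself; the tensor names (`attn`, `v`, `mask`) and shapes are hypothetical:

```python
import torch

# A float32 mask added to float16 attention scores silently promotes the
# scores to float32, so the subsequent matmul with the float16 value
# tensor fails with a dtype mismatch.
attn = torch.randn(1, 4, 4, dtype=torch.float16)  # attention scores (Half)
v = torch.randn(1, 4, 8, dtype=torch.float16)     # value tensor (Half)
mask = torch.zeros(1, 4, 4)                       # Float32 by default

attn = attn + mask
print(attn.dtype)  # torch.float32 -- the scores were upcast by promotion

try:
    attn @ v       # Float32 @ Half -> dtype mismatch
except RuntimeError as e:
    print(f"RuntimeError: {e}")
```

This mirrors the failure at `x = (attn @ v)` in the traceback: the mismatch is introduced earlier, at the mask addition, but only surfaces at the matmul.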
## Proposed Fix

The mask should be explicitly cast to `x.dtype` at the same time it is moved to `x.device`, so the attention module works in both fp32 and fp16 inference.
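A minimal sketch of the cast (the helper name below is hypothetical; the actual change lives in the PR):

```python
import torch

def to_input_dtype(mask: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: a single .to() call moves the mask to the
    # input's device AND casts it to the input's dtype, so the same code
    # path works for both fp32 and fp16 inference.
    return mask.to(device=x.device, dtype=x.dtype)

x = torch.randn(1, 4, 4, dtype=torch.float16)
mask = torch.zeros(1, 4, 4)           # Float32 by default
print(to_input_dtype(mask, x).dtype)  # torch.float16
```

In the DAT code this amounts to replacing the `mask_tmp[0].to(x.device)` call seen in the traceback with an equivalent `.to(...)` that also passes the input's dtype.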
I have submitted a fix for this in PR #28.