
RuntimeError: expected scalar type Half but found Float during FP16 inference #29

@ray24777

Description

When running inference with DAT models using fp16 (half-precision), the model crashes during the spatial attention calculation. This happens because the dynamically generated attention mask does not automatically match the precision of the input tensors.

Steps to Reproduce

Load a DAT model at half precision with the following command, and the error is raised:

python ./test_code/inference.py  --model DAT --scale 4 --downsample_threshold 720 --weight_path pretrained/4x_APISR_DAT_GAN_generator.pth --float16_inference true --input_dir ...

Error Traceback

Traceback (most recent call last):
  File "/workspace/./test_code/inference.py", line 271, in <module>
    inner_loop(os.path.join(input_dir, filename))
  File "/workspace/./test_code/inference.py", line 254, in inner_loop
    super_resolve_img(generator, process_dir, output_path, weight_dtype, downsample_threshold, crop_for_4x=True)
  File "/workspace/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/./test_code/inference.py", line 68, in super_resolve_img
    super_resolved_img = generator(img_lr)
  File "/workspace/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workspace/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/architecture/dat.py", line 861, in forward
    x = self.conv_after_body(self.forward_features(x)) + x
  File "/workspace/architecture/dat.py", line 839, in forward_features
    x = layer(x, x_size)
  File "/workspace/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workspace/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/architecture/dat.py", line 652, in forward
    x = blk(x, x_size)
  File "/workspace/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workspace/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/architecture/dat.py", line 568, in forward
    x = x + self.drop_path(self.attn(self.norm1(x), H, W))
  File "/workspace/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workspace/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/architecture/dat.py", line 404, in forward
    x1_shift = self.attns[0](qkv_0, _H, _W, mask=mask_tmp[0].to(x.device))
  File "/workspace/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workspace/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/architecture/dat.py", line 244, in forward
    x = (attn @ v)
RuntimeError: expected scalar type Half but found Float

Cause

In architecture/dat.py, when the input resolution doesn't match the patches_resolution, a new mask is calculated via self.calculate_mask(_H, _W).
The mask is generated as a Float32 tensor by default. While it is correctly moved to the same device as the input, its dtype is not cast to Float16. When this mask is added to the attention scores and subsequently multiplied by the value tensor v (which is Half), PyTorch throws a type mismatch error.

Proposed Fix

The mask should be explicitly cast to x.dtype when it is moved to the device, so that the same code path supports both fp32 and fp16 workflows.
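A single `Tensor.to` call can handle both the device and the dtype in one step, which is the shape of the change being proposed (minimal sketch; variable names are generic, not the actual ones in architecture/dat.py):

```python
import torch

x = torch.randn(2, 3, dtype=torch.float16)  # stands in for the fp16 input
mask = torch.zeros(2, 3)                    # float32 by default

# Move the mask to the input's device AND cast it to the input's dtype
# in one call; under fp32 inference this line is a no-op cast.
mask = mask.to(device=x.device, dtype=x.dtype)
print(mask.dtype)  # torch.float16
```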

I have submitted a fix for this in PR #28.
