
Conversation

@WHoutstanding
Contributor

PR Category

other

Description

Add FP32_ONLY_FUNCS = {
torch.nn.functional.softmax,
torch.nn.functional.layer_norm,
torch.nn.functional.group_norm,
torch.nn.functional.batch_norm,
torch.nn.functional.embedding,
torch.exp,
torch.log,
torch.pow,
torch.sigmoid,
torch.tanh,
torch.conv_transpose2d,
} to fix the dtype generalization pass.
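
Below is a minimal, hypothetical sketch of how such a set might be consulted by a dtype-generalization pass. The helper name `force_fp32_inputs` and the `torch.fx`-based rewrite are assumptions for illustration only, not the PR's actual implementation:

```python
import torch
import torch.fx

# Ops that should always receive fp32 inputs (as listed in this PR).
FP32_ONLY_FUNCS = {
    torch.nn.functional.softmax,
    torch.nn.functional.layer_norm,
    torch.nn.functional.group_norm,
    torch.nn.functional.batch_norm,
    torch.nn.functional.embedding,
    torch.exp,
    torch.log,
    torch.pow,
    torch.sigmoid,
    torch.tanh,
    torch.conv_transpose2d,
}


def force_fp32_inputs(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Insert a .float() cast in front of every input to an FP32-only call,
    so fp16/bf16 values produced by upstream nodes cannot reach these ops."""
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target in FP32_ONLY_FUNCS:
            for arg in list(node.all_input_nodes):
                # A real pass would also verify that `arg` is a floating-point
                # tensor (e.g. skip the integer index input of embedding).
                with gm.graph.inserting_before(node):
                    cast = gm.graph.call_method("float", (arg,))
                node.replace_input_with(arg, cast)
    gm.graph.lint()
    gm.recompile()
    return gm
```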

@paddle-bot

paddle-bot bot commented Jan 27, 2026

Thanks for your contribution!

@paddle-bot paddle-bot bot added the contributor External developers label Jan 27, 2026
@WHoutstanding
Contributor Author

The previous code logic never changed the dtype of blacklisted ops such as torch.layer_norm. So the "expected scalar type Float but found Half" error raised by torch.layer_norm is not caused by whether the layer_norm node itself was rewritten, but by its input arguments possibly having been converted to fp16 or bf16 by upstream nodes.
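
A small illustrative reproduction of this failure mode (not code from the PR; the exact error message depends on backend and PyTorch version):

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 16)
weight = torch.ones(16)   # fp32 affine parameters
bias = torch.zeros(16)

x_half = x.half()         # upstream node rewritten to fp16 by the pass

try:
    # Mixed fp16 input / fp32 parameters typically raises a dtype mismatch,
    # e.g. "expected scalar type Float but found Half".
    F.layer_norm(x_half, (16,), weight, bias)
except RuntimeError as e:
    print(e)

# Casting the input back to fp32 before the FP32-only op avoids the mismatch:
out = F.layer_norm(x_half.float(), (16,), weight, bias)
```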
