This repository was archived by the owner on Jan 12, 2026. It is now read-only.

add support for half precision gemm #32

Open
bjarthur wants to merge 1 commit into FluxML:master from bjarthur:bja/float16
Conversation

@bjarthur commented Nov 16, 2021

in conjunction with FluxML/NNlib.jl#363, add support for half-precision gemm, for which a special kernel is provided by Nvidia. see JuliaGPU/CUDA.jl#1080

@mcabbott
Member

Why do you say this is needed in addition? It looks like an alternative path. But the existing method NNlib._batched_gemm!(::Type{<:CuArray}, ...) ought to match Float16 (if NNlib.jl would let it be called).

What would be good to add here is tests using this precision, which I think should test the user-facing batched_mul, not the internal functions.

@DhairyaLGandhi
Member

Why would nnlib prevent it from getting called?

@bjarthur
Author

the current code actually works with Float16, but falls back to batched_mul_generic!, where a loop is performed over the last dimension, so it is painfully slow. i thought about tests, but couldn't come up with a way to test that the batched nvidia kernel is called instead.

@ToucheSir
Member

Yup, the overridden method in NNlib uses BlasFloat, which does not include Float16. Now, one hang-up I see with this PR is that _batched_try_gemm! also only accepts BlasFloat. @bjarthur can you confirm this works locally without any errors?
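
A minimal sketch of the dispatch behaviour described in this thread may help. It assumes only the standard library's `LinearAlgebra.BlasFloat` union. The `which_path` function is a hypothetical stand-in for illustration, not one of NNlib's actual methods:

```julia
using LinearAlgebra: BlasFloat

# Hypothetical stand-in: a fast path restricted to BLAS element types...
which_path(::AbstractArray{T}) where {T<:BlasFloat} = :blas
# ...and a catch-all fallback for every other element type.
which_path(::AbstractArray) = :generic

# Float32 is part of BlasFloat, so the restricted method is selected.
path32 = which_path(zeros(Float32, 2, 2, 3))

# Float16 is not part of BlasFloat, so dispatch silently takes the fallback.
path16 = which_path(zeros(Float16, 2, 2, 3))

println((path32, path16, Float16 <: BlasFloat))
```

Because the fallback is chosen without any error, a method gated on `BlasFloat` never sees `Float16` inputs, which is consistent with the slow generic path reported above.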

@bjarthur
Author

indeed, it does work locally without any errors, otherwise i would not have submitted it! ;)

@ToucheSir
Member

Great, I think per @mcabbott's comment a test for this would be good :)
