Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions ggml/src/ggml-cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ if (CUDAToolkit_FOUND)
# Notably the Blackwell FP4 tensor core instructions are not forwards compatible and therefore need 12Xa.
# But while 12X vs. 12Xa can be checked in device code there is (to my knowledge) no easy way to do the same check in host code.
# So for now just replace all instances of 12X with 12Xa, this should be fine until Rubin is released.
#
# Set GGML_CUDA_BLACKWELL_CONSUMER=ON to skip this replacement for consumer Blackwell GPUs
# (e.g. RTX 5090, SM 12.0) that don't have FP4 tensor cores and will fault on 12Xa instructions.
option(GGML_CUDA_BLACKWELL_CONSUMER "Skip sm_12X→sm_12Xa replacement for consumer Blackwell" OFF)
if(NOT GGML_CUDA_BLACKWELL_CONSUMER)
foreach(ARCHS IN ITEMS CMAKE_CUDA_ARCHITECTURES CMAKE_CUDA_ARCHITECTURES_NATIVE)
set(FIXED_ARCHS "")
foreach(ARCH IN LISTS ${ARCHS})
Expand All @@ -89,6 +94,7 @@ if (CUDAToolkit_FOUND)
endforeach()
set(${ARCHS} ${FIXED_ARCHS})
endforeach()
endif() # NOT GGML_CUDA_BLACKWELL_CONSUMER

# If we try to compile a "native" build it will use the 12X architectures and fail.
# So we should instead use the native architectures as determined by CMake after replacing 12X with 12Xa.
Expand All @@ -111,6 +117,18 @@ if (CUDAToolkit_FOUND)
file(GLOB SRCS "template-instances/mmf*.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})

if(GGML_CUDA_BLACKWELL_CONSUMER)
# FP4 MMA kernels (mxfp4/nvfp4) require sm_120a instructions not present
# on consumer Blackwell (RTX 5090, SM 12.0).
list(REMOVE_ITEM GGML_SOURCES_CUDA
"${CMAKE_CURRENT_SOURCE_DIR}/template-instances/mmq-instance-mxfp4.cu"
"${CMAKE_CURRENT_SOURCE_DIR}/template-instances/mmq-instance-nvfp4.cu"
)
# Let dispatch code in mmq.cu know to skip FP4 cases.
add_compile_definitions(GGML_CUDA_BLACKWELL_CONSUMER)
message(STATUS "ggml-cuda: Excluding FP4 MMA kernels (GGML_CUDA_BLACKWELL_CONSUMER)")
endif()

if (GGML_CUDA_FA_ALL_QUANTS)
file(GLOB SRCS "template-instances/fattn-vec*.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})
Expand Down
16 changes: 13 additions & 3 deletions ggml/src/ggml-cuda/mmq.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,16 @@ static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, con
mul_mat_q_case<GGML_TYPE_Q8_0>(ctx, args, stream);
break;
case GGML_TYPE_MXFP4:
mul_mat_q_case<GGML_TYPE_MXFP4>(ctx, args, stream);
break;
case GGML_TYPE_NVFP4:
mul_mat_q_case<GGML_TYPE_NVFP4>(ctx, args, stream);
#ifndef GGML_CUDA_BLACKWELL_CONSUMER
if (args.type_x == GGML_TYPE_MXFP4) {
mul_mat_q_case<GGML_TYPE_MXFP4>(ctx, args, stream);
} else {
mul_mat_q_case<GGML_TYPE_NVFP4>(ctx, args, stream);
}
#else
GGML_ABORT("FP4 quantization requires sm_120a, not supported on consumer Blackwell (SM 12.0)");
#endif
break;
case GGML_TYPE_Q2_K:
mul_mat_q_case<GGML_TYPE_Q2_K>(ctx, args, stream);
Expand Down Expand Up @@ -277,6 +283,10 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
case GGML_TYPE_Q8_0:
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
#ifdef GGML_CUDA_BLACKWELL_CONSUMER
mmq_supported = false;
break;
#endif
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
Expand Down