forked from ggml-org/llama.cpp
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathfix.patch
More file actions
125 lines (121 loc) · 5.15 KB
/
Copy pathfix.patch
File metadata and controls
125 lines (121 loc) · 5.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index f96739657..e6a919e1e 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -236,6 +236,17 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
}
} break;
+ case 640: {
+ // Padded turbo KV cache for GLM-4.7 Flash (K head_dim=576 zero-padded to 640).
+ // D=640 shared memory (Q storage = ncols*(DKQ/2+4)*4) exceeds hardware limit at ncols1>=4.
+ // Cap at ncols1=2 (ncols=32): Q=32*324*4=41KB + KV≈37KB = ~78KB total.
+ GGML_ASSERT(V->ne[0] == 512);
+ if (Q->ne[1] <= 1) {
+ ggml_cuda_flash_attn_ext_mma_f16_case<640, 512, 1, 16>(ctx, dst);
+ } else {
+ ggml_cuda_flash_attn_ext_mma_f16_case<640, 512, 2, 16>(ctx, dst);
+ }
+ } break;
default:
GGML_ABORT("fatal error");
break;
@@ -325,6 +336,51 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16)
#endif // GGML_CUDA_FA_ALL_QUANTS
+ // TurboQuant3 KV cache types (always enabled)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO3_0)
+
+ // Mixed turbo3/q8_0 KV cache types
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO3_0, GGML_TYPE_Q8_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_TURBO3_0)
+
+ // Mixed f16/turbo3 KV cache types
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_TURBO3_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO3_0, GGML_TYPE_F16)
+
+ // TurboQuant2 KV cache types (always enabled)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO2_0)
+
+ // Mixed turbo2/q8_0 KV cache types
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO2_0, GGML_TYPE_Q8_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_TURBO2_0)
+
+ // Mixed f16/turbo2 KV cache types
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_TURBO2_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO2_0, GGML_TYPE_F16)
+
+ // Mixed turbo3/turbo2 KV cache types
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO2_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO3_0)
+
+ // TurboQuant4 KV cache types (always enabled)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO4_0)
+
+ // Mixed turbo4/q8_0 KV cache types
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO4_0, GGML_TYPE_Q8_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_TURBO4_0)
+
+ // Mixed f16/turbo4 KV cache types
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_TURBO4_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO4_0, GGML_TYPE_F16)
+
+ // Mixed turbo4/turbo3 KV cache types
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO3_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO4_0)
+
+ // Mixed turbo4/turbo2 KV cache types
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO2_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO4_0)
+
GGML_ABORT("fatal error");
}
@@ -410,6 +466,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
}
break;
case 576:
+ case 640:
if (V->ne[0] != 512) {
return BEST_FATTN_KERNEL_NONE;
}
@@ -423,7 +480,16 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
#ifndef GGML_CUDA_FA_ALL_QUANTS
if (K->type != V->type) {
- return BEST_FATTN_KERNEL_NONE;
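+ // Allow mixed K/V types only when BOTH types are one of turbo2/3/4, q8_0, f16 or bf16.
+ // Any pairing inside that set passes this check (turbo/q8_0, turbo/f16, turbo/turbo, f16/q8_0, ...).
+ // NOTE(review): pairs such as bf16/turbo or f16/bf16 also pass - confirm an FA template instance exists for them.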
+ auto is_kv_compat = [](ggml_type t) {
+ return t == GGML_TYPE_TURBO2_0 || t == GGML_TYPE_TURBO3_0 || t == GGML_TYPE_TURBO4_0
+ || t == GGML_TYPE_Q8_0 || t == GGML_TYPE_F16 || t == GGML_TYPE_BF16;
+ };
+ if (!is_kv_compat(K->type) || !is_kv_compat(V->type)) {
+ return BEST_FATTN_KERNEL_NONE;
+ }
}
#endif // GGML_CUDA_FA_ALL_QUANTS
@@ -441,6 +507,24 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
case GGML_TYPE_Q8_0:
case GGML_TYPE_BF16:
break;
+ case GGML_TYPE_TURBO3_0:
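+ // turbo3: require K head size to be a multiple of 64. NOTE(review): VEC instances cover D in {64, 128, 256}, but this check also admits 192, 512, 640, ... - confirm intended.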
+ if (K->ne[0] % 64 != 0) {
+ return BEST_FATTN_KERNEL_NONE;
+ }
+ break;
+ case GGML_TYPE_TURBO2_0:
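+ // turbo2: require K head size to be a multiple of 64. NOTE(review): VEC instances cover D in {64, 128, 256}, but this check also admits 192, 512, 640, ... - confirm intended.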
+ if (K->ne[0] % 64 != 0) {
+ return BEST_FATTN_KERNEL_NONE;
+ }
+ break;
+ case GGML_TYPE_TURBO4_0:
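+ // turbo4: require K head size to be a multiple of 64. NOTE(review): VEC instances cover D in {64, 128, 256}, but this check also admits 192, 512, 640, ... - confirm intended.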
+ if (K->ne[0] % 64 != 0) {
+ return BEST_FATTN_KERNEL_NONE;
+ }
+ break;
default:
return BEST_FATTN_KERNEL_NONE;
}