Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,16 @@ SOURCES=$(wildcard src/*.c)

OBJS=$(SOURCES:%=$(BUILD)/%.o)



ifeq ($(UNAME),Darwin)
FRAMEWORKS_DIR=/Library/Developer/CommandLineTools/SDKs/MacOSX15.2.sdk/System/Library/Frameworks/
ACCELERATE_HEADERS=$(FRAMEWORKS_DIR)/Accelerate.framework/Versions/A/Headers
CFLAGS+=-I$(ACCELERATE_HEADERS) -DACCELERATE_NEW_LAPACK
LDFLAGS+=-dynamiclib -framework Accelerate
else
CFLAGS+=-fopenmp
CFLAGS+=-fopenmp -mavx2 -mfma
LDFLAGS+=-shared -fopenmp
endif


ifeq ($(UNAME),Darwin)
LIBRARY_NAME=libeel.dylib
else
Expand Down
41 changes: 13 additions & 28 deletions eel/eel.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import ctypes
import json
import platform
from pathlib import Path
from typing import Any

import numpy as np
Expand All @@ -9,8 +11,12 @@
from ._cdefs import Config, InferState, LayerWeights, Model, Weights


# TODO EDF don't hardcode
LibPath = "/Users/eric/Development/eel/build/libeel.dylib"
RootDir = Path(__file__).parent.parent

if platform.system() == "Darwin":
LibPath = RootDir / "build" / "libeel.dylib"
else:
LibPath = RootDir / "build" / "libeel.so"

lib = ctypes.CDLL(LibPath)

Expand Down Expand Up @@ -274,34 +280,13 @@ def load_model(config: Config, weights: Weights) -> Model:


def init_state(config: Config) -> InferState:
state = InferState()

state.x1 = (ctypes.c_float * config.size)()
state.x2 = (ctypes.c_float * config.size)()
state.x3 = (ctypes.c_float * config.size)()

state.h1 = (ctypes.c_float * config.ffn_hidden_size)()
state.h2 = (ctypes.c_float * config.ffn_hidden_size)()
state.h3 = (ctypes.c_float * config.ffn_hidden_size)()

kv_cache_size = (
config.num_layers * config.max_seq_len * config.num_kv_heads * config.head_size
lib.make_state.argtypes = (
ctypes.POINTER(Config),
)
state.k_cache = (ctypes.c_float * kv_cache_size)()
state.v_cache = (ctypes.c_float * kv_cache_size)()

q_size = config.num_q_heads * config.head_size
kv_size = config.num_kv_heads * config.head_size
state.q = (ctypes.c_float * q_size)()
state.k = (ctypes.c_float * kv_size)()
state.v = (ctypes.c_float * kv_size)()

state.score = (ctypes.c_float * config.max_seq_len)()
state.mha_out = (ctypes.c_float * q_size)()

state.logits = (ctypes.c_float * config.vocab_size)()
lib.make_state.restype = ctypes.POINTER(InferState)

return state
state_ptr = lib.make_state(ctypes.pointer(config))
return _deref(state_ptr)


def forward(model: Model, state: InferState, token: int, pos: int) -> np.array:
Expand Down
31 changes: 31 additions & 0 deletions src/eel.c
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,37 @@

#define EEL_DEBUG_STATE 0 // very slow!

/**
 * Allocate `count` floats, 32-byte aligned for AVX loads.
 *
 * C11 aligned_alloc() requires the byte size to be a multiple of the
 * alignment, so round the request up to the next multiple of 32;
 * abort on exhaustion rather than letting a NULL pointer propagate
 * into the math kernels.
 */
static float *eel_alloc_f32(size_t count) {
    size_t bytes = count * sizeof(float);
    bytes = (bytes + 31) & ~(size_t)31;  /* round up to a multiple of 32 */
    float *p = aligned_alloc(32, bytes);
    if (p == NULL) {
        abort();  /* out of memory: no sane recovery mid-inference */
    }
    return p;
}

/**
 * Allocate the per-inference scratch buffers and KV cache for `config`.
 *
 * All float buffers are 32-byte aligned so the AVX2 matmul kernels may
 * assume alignment. Buffers are left uninitialized. The caller owns the
 * returned struct and every buffer hanging off it.
 *
 * NOTE(review): eel.h declares this function as `init_state`, but the
 * definition and the Python bindings both use `make_state` -- the header
 * prototype appears stale; confirm and fix it.
 */
struct InferState *make_state(struct Config *config) {
    struct InferState *state = malloc(sizeof(struct InferState));
    if (state == NULL) {
        abort();
    }

    /* residual-stream scratch, one model-dim vector each */
    state->x1 = eel_alloc_f32(config->size);
    state->x2 = eel_alloc_f32(config->size);
    state->x3 = eel_alloc_f32(config->size);

    /* FFN hidden activations */
    state->h1 = eel_alloc_f32(config->ffn_hidden_size);
    state->h2 = eel_alloc_f32(config->ffn_hidden_size);
    state->h3 = eel_alloc_f32(config->ffn_hidden_size);

    /* cast first so the 4-way product is done in size_t, not int */
    size_t kv_cache_size = (size_t)config->num_layers * config->max_seq_len
                         * config->num_kv_heads * config->head_size;
    state->k_cache = eel_alloc_f32(kv_cache_size);
    state->v_cache = eel_alloc_f32(kv_cache_size);

    size_t q_size = (size_t)config->num_q_heads * config->head_size;
    size_t kv_size = (size_t)config->num_kv_heads * config->head_size;
    state->q = eel_alloc_f32(q_size);
    state->k = eel_alloc_f32(kv_size);
    state->v = eel_alloc_f32(kv_size);

    /* one attention score per cached position, and the attention output */
    state->score = eel_alloc_f32(config->max_seq_len);
    state->mha_out = eel_alloc_f32(q_size);

    state->logits = eel_alloc_f32(config->vocab_size);

    return state;
}

/**
* Forward one layer.
*
Expand Down
2 changes: 2 additions & 0 deletions src/eel.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ struct Model {
struct Weights *weights;
};

struct InferState *init_state(struct Config *config);

/**
* Forward pass of the model by one token.
*
Expand Down
66 changes: 66 additions & 0 deletions src/matmul.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#include "matmul.h"

#include <assert.h>
#include <stdint.h>

#ifdef __APPLE__
#include <Accelerate/Accelerate.h>
#endif
Expand All @@ -8,6 +11,10 @@
#include <omp.h>
#endif

#ifdef __AVX2__
#include <immintrin.h>
#endif

#if defined(__APPLE__)
void mva(const float *restrict A, const float *restrict x, const float *restrict b, float *restrict y, int M, int N)
{
Expand All @@ -20,6 +27,65 @@ void mv(const float *restrict A, const float *restrict x, float *restrict y, int
cblas_sgemv(CblasRowMajor, CblasNoTrans, M, N, 1.0f, A, N, x, 1, 0.0f, y, 1);
}

#elif defined(__AVX2__)

// see https://stackoverflow.com/questions/6996764/fastest-way-to-do-horizontal-sse-vector-sum-or-other-reduction
float hsum_ps_sse3(__m128 v) {
__m128 shuf = _mm_movehdup_ps(v); // broadcast elements 3,1 to 2,0
__m128 sums = _mm_add_ps(v, shuf);
shuf = _mm_movehl_ps(shuf, sums); // high half -> low half
sums = _mm_add_ss(sums, shuf);
return _mm_cvtss_f32(sums);
}

// Reduce an 8-lane float vector to the scalar sum of its lanes by
// folding the high 128 bits onto the low half, then reusing the
// 4-lane SSE3 reduction (optimal under AVX -- no wasted instructions).
float hsum256_ps_avx(__m256 v) {
    __m128 lo = _mm256_castps256_ps128(v);      // low 128 bits, free cast
    __m128 hi = _mm256_extractf128_ps(v, 1);    // high 128 bits
    return hsum_ps_sse3(_mm_add_ps(lo, hi));
}

/**
 * y = A @ x (+ b): row-major matrix-vector product with optional bias.
 *
 * Each of the M row dot-products runs 8 lanes at a time with FMA, and a
 * scalar tail loop handles N not divisible by 8 (the previous version
 * read past the end of the row in that case). FMA vs. separate
 * multiply/add measured the same here -- the kernel is presumably memory
 * bound on loading A and x.
 *
 * A and x must be 32-byte aligned (see make_state); rows of A are loaded
 * unaligned since &A[i*N] is only 32-aligned when N is a multiple of 8.
 * b may be NULL to skip the bias add. Rows are parallelized with OpenMP.
 */
void mva(const float *restrict A, const float *restrict x, const float *restrict b, float *restrict y, int M, int N)
{
    assert((uintptr_t)A % 32 == 0);
    assert((uintptr_t)x % 32 == 0);

#pragma omp parallel for
    for (int i = 0; i < M; ++i)
    {
        /* cast so the row offset is computed in size_t, not int */
        const float *rowA = &A[(size_t)i * (size_t)N];
        __m256 acc = _mm256_setzero_ps();

        int j = 0;
        for (; j + 8 <= N; j += 8)
        {
            __m256 vA = _mm256_loadu_ps(&rowA[j]);
            __m256 vx = _mm256_loadu_ps(&x[j]);
            acc = _mm256_fmadd_ps(vA, vx, acc);  /* acc += A[i,j:j+8] * x[j:j+8] */
        }

        float sum = hsum256_ps_avx(acc);

        /* scalar tail for the last N % 8 elements */
        for (; j < N; ++j)
            sum += rowA[j] * x[j];

        y[i] = b ? sum + b[i] : sum;
    }
}

/**
 * y = A @ x: row-major matrix-vector product with no bias term.
 * Thin wrapper over mva() with the bias pointer left NULL.
 */
void mv(const float *restrict A, const float *restrict x, float *restrict y, int M, int N)
{
    const float *no_bias = NULL;
    mva(A, x, no_bias, y, M, N);
}

#else

void mva(const float *restrict A, const float *restrict x, const float *restrict b, float *restrict y, int M, int N)
Expand Down