diff --git a/ds4_server.c b/ds4_server.c index 435491fe..aa5972a3 100644 --- a/ds4_server.c +++ b/ds4_server.c @@ -7617,6 +7617,8 @@ struct server { ds4_engine *engine; ds4_session *session; int default_tokens; + bool batching; + int batch_size; kv_disk_cache kv; tool_memory tool_mem; live_tool_state responses_live; @@ -7650,6 +7652,22 @@ struct job { job *next; }; +typedef struct { + job *j; + ds4_session *session; + char id[96]; + char err[160]; + buf text; + size_t plain_stream_pos; + size_t stop_scan_from; + const char *finish; + int prompt_tokens; + int completion; + int max_tokens; + uint64_t rng; + double t0; +} batch_decode_job; + /* ========================================================================= * Tool Call Text Memory. * ========================================================================= @@ -10582,6 +10600,13 @@ static bool enqueue(server *s, job *j) { return true; } +static void job_signal_done(job *j) { + pthread_mutex_lock(&j->mu); + j->done = true; + pthread_cond_signal(&j->cv); + pthread_mutex_unlock(&j->mu); +} + static job *dequeue(server *s) { pthread_mutex_lock(&s->mu); while (!s->head && !s->stopping) pthread_cond_wait(&s->cv, &s->mu); @@ -10597,16 +10622,244 @@ static job *dequeue(server *s) { return j; } +static job *dequeue_ready(server *s) { + pthread_mutex_lock(&s->mu); + job *j = s->head; + if (j) { + s->head = j->next; + if (!s->head) s->tail = NULL; + } + pthread_mutex_unlock(&s->mu); + if (j) j->next = NULL; + return j; +} + +static bool batch_job_supported(const job *j) { + const request *r = j ? &j->req : NULL; + return r && r->stream && r->api == API_OPENAI && //!r->has_tools && + //!ds4_think_mode_enabled(r->think_mode) && + (r->kind == REQ_CHAT || r->kind == REQ_COMPLETION); +} + +static bool batch_decode_start(server *s, batch_decode_job *b, job *j) { + memset(b, 0, sizeof(*b)); + b->j = j; + b->finish = "length"; + b->t0 = now_sec(); + b->prompt_tokens = j->req.prompt.len; + b->err[0] = '\0'; + if (ds4_session_create(&b->session, s->engine, ds4_session_ctx(s->session)) != 0) { + http_error(j->fd, s->enable_cors, 500, "failed to create batched session"); + return false; + } + j->req.cache_read_tokens = 0; + j->req.cache_write_tokens = b->prompt_tokens; + if (ds4_session_sync(b->session, &j->req.prompt, b->err, sizeof(b->err)) != 0) { + http_error(j->fd, s->enable_cors, 500, b->err[0] ? b->err : "prefill failed"); + return false; + } + snprintf(b->id, sizeof(b->id), "%s-%llu", + j->req.kind == REQ_CHAT ? "chatcmpl" : "cmpl", + (unsigned long long)++s->seq); + if (!sse_headers(j->fd, s->enable_cors)) { + snprintf(b->err, sizeof(b->err), "client stream write failed"); + return false; + } + if (j->req.kind == REQ_CHAT && !sse_chunk(j->fd, &j->req, b->id, NULL, NULL)) { + snprintf(b->err, sizeof(b->err), "client stream write failed"); + return false; + } + int room = ds4_session_ctx(b->session) - ds4_session_pos(b->session); + b->max_tokens = j->req.max_tokens; + if (b->max_tokens < 0) b->max_tokens = 0; + if (b->max_tokens > room) b->max_tokens = room; + b->rng = j->req.seed ? j->req.seed : + (((uint64_t)time(NULL) << 32) ^ ((uint64_t)s->seq << 1) ^ (uint64_t)(uintptr_t)j); + server_log(DS4_LOG_GENERATION, + "ds4-server: batching start %s prompt=%d max=%d active_limit=%d", + j->req.kind == REQ_CHAT ? "chat" : "completion", + b->prompt_tokens, + b->max_tokens, + s->batch_size); + return true; +} + +static bool batch_decode_step(server *s, batch_decode_job *b) { + job *j = b->j; + if (b->completion >= b->max_tokens || + ds4_session_pos(b->session) >= ds4_session_ctx(b->session)) { + b->finish = "length"; + return true; + } + + int token = ds4_session_sample(b->session, j->req.temperature, j->req.top_k, + j->req.top_p, j->req.min_p, &b->rng); + if (token == ds4_token_eos(s->engine)) { + b->finish = "stop"; + return true; + } + if (ds4_session_eval(b->session, token, b->err, sizeof(b->err)) != 0) { + b->finish = "error"; + return true; + } + + size_t piece_len = 0; + char *piece = ds4_token_text(s->engine, token, &piece_len); + b->completion++; + buf_append(&b->text, piece, piece_len); + free(piece); + + size_t stop_pos = 0, stop_len = 0; + bool hit_stop = stop_list_find_from(&j->req.stops, b->text.ptr, + b->stop_scan_from, + &stop_pos, &stop_len); + size_t stream_len = hit_stop ? + stop_pos : stop_list_stream_safe_len(&j->req.stops, b->text.len); + if (stream_len > b->text.len) stream_len = b->text.len; + stream_len = utf8_stream_safe_len(b->text.ptr, b->plain_stream_pos, + stream_len, hit_stop); + if (!hit_stop && j->req.stops.max_len > 1) { + const size_t hold = j->req.stops.max_len - 1; + b->stop_scan_from = b->text.len > hold ? b->text.len - hold : 0; + } + if (stream_len > b->plain_stream_pos) { + char *delta = xstrndup(b->text.ptr + b->plain_stream_pos, + stream_len - b->plain_stream_pos); + bool ok = sse_chunk(j->fd, &j->req, b->id, delta, NULL); + free(delta); + if (!ok) { + b->finish = "error"; + snprintf(b->err, sizeof(b->err), "client stream write failed"); + return true; + } + b->plain_stream_pos = stream_len; + } + if (hit_stop) { + (void)stop_len; + b->finish = "stop"; + b->text.len = stop_pos; + if (b->text.ptr) b->text.ptr[b->text.len] = '\0'; + ds4_session_invalidate(b->session); + return true; + } + return b->completion >= b->max_tokens; +} + +static void batch_decode_cleanup(batch_decode_job *b) { + ds4_session_free(b->session); + buf_free(&b->text); + memset(b, 0, sizeof(*b)); +} + +static void batch_decode_finish(server *s, batch_decode_job *b) { + job *j = b->j; + if (j->req.stream && b->text.len > b->plain_stream_pos) { + char *tail = xstrndup(b->text.ptr + b->plain_stream_pos, + b->text.len - b->plain_stream_pos); + if (!sse_chunk(j->fd, &j->req, b->id, tail, NULL)) b->finish = "error"; + free(tail); + } + if (j->req.stream) { + if (!sse_chunk(j->fd, &j->req, b->id, NULL, b->finish) || + !sse_done(j->fd, &j->req, b->id, b->prompt_tokens, b->completion)) { + server_log(DS4_LOG_DEFAULT, + "ds4-server: batching final stream failed"); + } + } else { + final_response(j->fd, s->enable_cors, &j->req, b->id, + b->text.ptr ? b->text.ptr : "", NULL, NULL, + b->finish, b->prompt_tokens, b->completion); + } + if (!strcmp(b->finish, "error") && b->err[0]) { + server_log(DS4_LOG_GENERATION, + "ds4-server: batching %s gen=%d finish=%s error=\"%s\" %.3fs", + j->req.kind == REQ_CHAT ? "chat" : "completion", + b->completion, + b->finish, + b->err, + now_sec() - b->t0); + } else { + server_log(DS4_LOG_GENERATION, + "ds4-server: batching %s gen=%d finish=%s %.3fs", + j->req.kind == REQ_CHAT ? "chat" : "completion", + b->completion, + b->finish, + now_sec() - b->t0); + } + batch_decode_cleanup(b); +} + +static void batch_decode_remove(batch_decode_job *active, int *nactive, int idx) { + for (int i = idx + 1; i < *nactive; i++) active[i - 1] = active[i]; + (*nactive)--; +} + +static bool worker_batch_admit(server *s, batch_decode_job *active, int *nactive, + int cap, bool block_if_empty) { + bool admitted = false; + while (*nactive < cap) { + job *j = (*nactive == 0 && block_if_empty) ? dequeue(s) : dequeue_ready(s); + if (!j) break; + if (!batch_job_supported(j)) { + generate_job(s, j); + job_signal_done(j); + admitted = true; + continue; + } + if (!batch_decode_start(s, &active[*nactive], j)) { + batch_decode_cleanup(&active[*nactive]); + job_signal_done(j); + admitted = true; + continue; + } + (*nactive)++; + admitted = true; + } + return admitted; +} + +static void *worker_main_batched(void *arg) { + server *s = arg; + int cap = s->batch_size > 0 ? s->batch_size : 2; + batch_decode_job *active = xmalloc((size_t)cap * sizeof(active[0])); + memset(active, 0, (size_t)cap * sizeof(active[0])); + int nactive = 0; + server_log(DS4_LOG_DEFAULT, + "ds4-server: continuous batching enabled batch_size=%d", cap); + for (;;) { + worker_batch_admit(s, active, &nactive, cap, nactive == 0); + if (nactive == 0) { + pthread_mutex_lock(&s->mu); + bool stopping = s->stopping && !s->head; + pthread_mutex_unlock(&s->mu); + if (stopping) break; + continue; + } + worker_batch_admit(s, active, &nactive, cap, false); + for (int i = 0; i < nactive;) { + bool done = batch_decode_step(s, &active[i]); + if (!done) { + i++; + continue; + } + job *j = active[i].j; + batch_decode_finish(s, &active[i]); + job_signal_done(j); + batch_decode_remove(active, &nactive, i); + } + } + free(active); + return NULL; +} + static void *worker_main(void *arg) { server *s = arg; + if (s->batching) return worker_main_batched(arg); for (;;) { job *j = dequeue(s); if (!j) break; generate_job(s, j); - pthread_mutex_lock(&j->mu); - j->done = true; - pthread_cond_signal(&j->cv); - pthread_mutex_unlock(&j->mu); + job_signal_done(j); } return NULL; } @@ -10914,6 +11167,8 @@ typedef struct { bool disable_exact_dsml_tool_replay; int tool_memory_max_ids; bool enable_cors; + bool batching; + int batch_size; } server_config; static int parse_int_arg(const char *s, const char *opt) { @@ -11029,6 +11284,10 @@ static void usage(FILE *fp) { " Add Access-Control-Allow-* headers for browser JS clients. Does not change --host.\n" " --trace FILE\n" " Write a human-readable session trace: prompts, cache decisions, output, tool calls.\n" + " --batching\n" + " Enable continuous batching for simple OpenAI streaming requests.\n" + " --batch-size N\n" + " Maximum concurrently active batched requests. Default: 2\n" "\n" "Thinking and sampling:\n" " DeepSeek-compatible chat requests default to thinking mode with high effort.\n" @@ -11110,6 +11369,7 @@ static server_config parse_options(int argc, char **argv) { .ctx_size = 32768, .default_tokens = 393216, .tool_memory_max_ids = DS4_TOOL_MEMORY_DEFAULT_MAX_IDS, + .batch_size = 2, }; c.kv_cache = kv_cache_default_options(); @@ -11143,6 +11403,10 @@ static server_config parse_options(int argc, char **argv) { c.enable_cors = true; } else if (!strcmp(arg, "--trace")) { c.trace_path = need_arg(&i, argc, argv, arg); + } else if (!strcmp(arg, "--batching")) { + c.batching = true; + } else if (!strcmp(arg, "--batch-size")) { + c.batch_size = parse_int_arg(need_arg(&i, argc, argv, arg), arg); } else if (!strcmp(arg, "--kv-disk-dir")) { c.kv_disk_dir = need_arg(&i, argc, argv, arg); } else if (!strcmp(arg, "--kv-disk-space-mb")) { @@ -11237,6 +11501,8 @@ int main(int argc, char **argv) { s.engine = engine; s.session = session; s.default_tokens = cfg.default_tokens; + s.batching = cfg.batching; + s.batch_size = cfg.batch_size; s.disable_exact_dsml_tool_replay = cfg.disable_exact_dsml_tool_replay; s.tool_mem.max_entries = cfg.tool_memory_max_ids; s.enable_cors = cfg.enable_cors; diff --git a/tests/ds4_test.c b/tests/ds4_test.c index 16d14593..f3f312cb 100644 --- a/tests/ds4_test.c +++ b/tests/ds4_test.c @@ -178,6 +178,450 @@ static char *test_read_file(const char *path) { return s; } +static void test_send_all_or_fail(int fd, const char *data, size_t len) { + size_t off = 0; + while (off < len) { + ssize_t n = send(fd, data + off, len - off, 0); + if (n < 0 && errno == EINTR) continue; + TEST_ASSERT(n > 0); + if (n <= 0) return; + off += (size_t)n; + } +} + +static void test_set_nonblocking_or_fail(int fd) { + int flags = fcntl(fd, F_GETFL, 0); + TEST_ASSERT(flags >= 0); + if (flags < 0) return; + TEST_ASSERT(fcntl(fd, F_SETFL, flags | O_NONBLOCK) == 0); +} + +static void test_server_cleanup_keep_engine(server *s) { + if (!s) return; + if (s->trace) { + fclose(s->trace); + s->trace = NULL; + } + kv_cache_close(&s->kv); + tool_memory_free(&s->tool_mem); + live_tool_state_free(&s->responses_live); + live_tool_state_free(&s->anthropic_live); + visible_live_free(&s->thinking_live); + pthread_mutex_destroy(&s->tool_mu); + pthread_mutex_destroy(&s->trace_mu); + pthread_cond_destroy(&s->clients_cv); + pthread_cond_destroy(&s->cv); + pthread_mutex_destroy(&s->mu); + ds4_session_free(s->session); + memset(s, 0, sizeof(*s)); +} + +static void test_server_init_live(server *s, ds4_engine *engine, int ctx_size, + const char *trace_path) { + memset(s, 0, sizeof(*s)); + s->engine = engine; + s->default_tokens = 256; + s->tool_mem.max_entries = DS4_TOOL_MEMORY_DEFAULT_MAX_IDS; + TEST_ASSERT(ds4_session_create(&s->session, engine, ctx_size) == 0); + if (!s->session) return; + pthread_mutex_init(&s->mu, NULL); + pthread_cond_init(&s->cv, NULL); + pthread_cond_init(&s->clients_cv, NULL); + pthread_mutex_init(&s->tool_mu, NULL); + pthread_mutex_init(&s->trace_mu, NULL); + if (trace_path) { + s->trace = fopen(trace_path, "w"); + TEST_ASSERT(s->trace != NULL); + if (!s->trace) return; + setvbuf(s->trace, NULL, _IONBF, 0); + server_log(DS4_LOG_DEFAULT, "ds4-server: tracing session to %s", trace_path); + } +} + +typedef struct { + int fd; + bool saw_bytes; + bool saw_done; + bool eof; + double send_done_at; + double first_byte_at; + double last_byte_at; + double done_at; + double eof_at; + buf raw; +} test_stream_capture; + +static double test_wall_sec(void) { + struct timespec ts; + clock_gettime(CLOCK_REALTIME, &ts); + return (double)ts.tv_sec + (double)ts.tv_nsec * 1e-9; +} + +static void test_sleep_ms(long ms) { + struct timespec req = { + .tv_sec = ms / 1000, + .tv_nsec = (ms % 1000) * 1000000L, + }; + while (nanosleep(&req, &req) != 0 && errno == EINTR) {} +} + +static void test_stream_capture_append(test_stream_capture *cap, + const char *data, size_t len) { + if (!cap || !data || !len) return; + if (!cap->saw_bytes) { + cap->saw_bytes = true; + cap->first_byte_at = test_wall_sec(); + } + cap->last_byte_at = test_wall_sec(); + buf_append(&cap->raw, data, len); + if (!cap->saw_done && cap->raw.ptr && strstr(cap->raw.ptr, "data: [DONE]")) { + cap->saw_done = true; + cap->done_at = test_wall_sec(); + } +} + +static void test_stream_capture_close(test_stream_capture *cap) { + if (!cap || cap->fd < 0) return; + cap->eof_at = test_wall_sec(); + close(cap->fd); + cap->fd = -1; + cap->eof = true; +} + +static size_t test_count_nonempty_lines(const char *text) { + size_t lines = 0; + const char *p = text ? text : ""; + while (*p) { + const char *line_end = strchr(p, '\n'); + size_t len = line_end ? (size_t)(line_end - p) : strlen(p); + while (len > 0 && (p[len - 1] == '\r' || p[len - 1] == '\n')) len--; + if (len > 0) lines++; + if (!line_end) break; + p = line_end + 1; + } + return lines; +} + +static void test_server_log_multiline(const char *prefix, const char *text) { + const char *p = text ? text : ""; + while (*p) { + const char *line_end = strchr(p, '\n'); + size_t len = line_end ? (size_t)(line_end - p) : strlen(p); + if (len > 0) { + char *line = xstrndup(p, len); + server_log(DS4_LOG_DEFAULT, "%s%s", prefix ? prefix : "", line); + free(line); + } + if (!line_end) break; + p = line_end + 1; + } +} + +static void test_log_stream_capture_server(const char *label, + const test_stream_capture *cap, + const char *text) { + server_log(DS4_LOG_DEFAULT, + "ds4-test: %s send_done_ts=%.6f first_byte_ts=%.6f done_marker_ts=%.6f last_byte_ts=%.6f eof_ts=%.6f", + label, + cap->send_done_at, + cap->first_byte_at, + cap->done_at > 0.0 ? cap->done_at : -1.0, + cap->last_byte_at, + cap->eof_at > 0.0 ? cap->eof_at : -1.0); + server_log(DS4_LOG_DEFAULT, + "ds4-test: %s raw_response_after_eof", + label); + test_server_log_multiline("ds4-test: ", text ? text : ""); +} + +static void test_log_stream_timing_summary_server(const char *label, + const test_stream_capture *cap) { + server_log(DS4_LOG_DEFAULT, + "ds4-test: final_timing_summary %s send_done_ts=%.6f first_byte_ts=%.6f done_marker_ts=%.6f last_byte_ts=%.6f eof_ts=%.6f", + label, + cap->send_done_at, + cap->first_byte_at, + cap->done_at > 0.0 ? cap->done_at : -1.0, + cap->last_byte_at, + cap->eof_at > 0.0 ? cap->eof_at : -1.0); +} + +static char *test_build_chat_http_request(const char *prompt, bool stream) { + buf user_prompt = {0}; + buf body = {0}; + buf req = {0}; + buf_puts(&user_prompt, prompt ? prompt : ""); + buf_puts(&user_prompt, + "\n\nFormato obbligatorio: restituisci tutte e sole le 50 parole " + "della lista che iniziano con c. L'ordine non importa. Puoi " + "separarle con spazi o nuove righe. Niente spiegazioni o altre parole."); + buf_puts(&body, + "{\"model\":\"deepseek-chat\"," + "\"messages\":[{\"role\":\"user\",\"content\":"); + json_escape(&body, user_prompt.ptr ? user_prompt.ptr : ""); + buf_printf(&body, + "}],\"max_tokens\":512,\"temperature\":0," + "\"stream\":%s,\"thinking\":false}", + stream ? "true" : "false"); + buf_printf(&req, + "POST /v1/chat/completions HTTP/1.1\r\n" + "Host: localhost\r\n" + "Content-Type: application/json\r\n" + "Content-Length: %zu\r\n" + "\r\n", + body.len); + buf_append(&req, body.ptr, body.len); + buf_free(&user_prompt); + buf_free(&body); + return req.ptr; +} + +static void test_drive_single_stream(test_stream_capture *cap) { + struct pollfd pfd; + double start = now_sec(); + + memset(&pfd, 0, sizeof(pfd)); + pfd.fd = cap->fd; + pfd.events = POLLIN | POLLHUP; + while (!cap->eof) { + int rc = poll(&pfd, 1, 1000); + if (rc < 0 && errno == EINTR) continue; + TEST_ASSERT(rc >= 0); + TEST_ASSERT(now_sec() - start < 180.0); + if (rc <= 0) continue; + if (!(pfd.revents & (POLLIN | POLLHUP))) continue; + for (;;) { + char tmp[4096]; + ssize_t n = recv(cap->fd, tmp, sizeof(tmp), 0); + if (n < 0 && errno == EINTR) continue; + if (n < 0 && (errno == EAGAIN || errno == EWOULDBLOCK)) break; + TEST_ASSERT(n >= 0); + if (n < 0) { + test_stream_capture_close(cap); + break; + } + if (n == 0) { + test_stream_capture_close(cap); + break; + } + test_stream_capture_append(cap, tmp, (size_t)n); + } + } +} + +static void test_drive_two_streams(test_stream_capture *a, + test_stream_capture *b) { + struct pollfd pfds[2]; + double start = now_sec(); + + memset(pfds, 0, sizeof(pfds)); + pfds[0].fd = a->fd; + pfds[0].events = POLLIN | POLLHUP; + pfds[1].fd = b->fd; + pfds[1].events = POLLIN | POLLHUP; + + while (!a->eof || !b->eof) { + int rc = poll(pfds, 2, 1000); + if (rc < 0 && errno == EINTR) continue; + TEST_ASSERT(rc >= 0); + TEST_ASSERT(now_sec() - start < 180.0); + if (rc <= 0) continue; + + for (int i = 0; i < 2; i++) { + test_stream_capture *cap = i == 0 ? a : b; + if (cap->eof) continue; + if (!(pfds[i].revents & (POLLIN | POLLHUP))) continue; + for (;;) { + char tmp[4096]; + ssize_t n = recv(cap->fd, tmp, sizeof(tmp), 0); + if (n < 0 && errno == EINTR) continue; + if (n < 0 && (errno == EAGAIN || errno == EWOULDBLOCK)) break; + TEST_ASSERT(n >= 0); + if (n < 0) { + test_stream_capture_close(cap); + break; + } + if (n == 0) { + test_stream_capture_close(cap); + break; + } + test_stream_capture_append(cap, tmp, (size_t)n); + } + } + } +} + +static const char test_server_word_filter_prompt[] = + "Ti passo una lista di parole. Di queste elencami le parole, solo quelle, senza altri frasi di spiegazione, che iniziano per il carattere c. una parola per ogni linea, mi aspetto 50 linee perche' abbimoa 50 parole che soddisfano il requisito.\n" + "cable cactus camera candle cannon canvas captain carbon castle catalog celery center ceremony champion channel chapter charity cheetah cherry chimney chorus circle citizen clarity classic climate closet cluster coastal coconut coffee college comfort comic compass concert condor control cookie corner cotton country courage cradle crystal culture curtain custom cyclone cylinder able about above absurd adapt admit adult afraid agent agree airport album alert alien alley almost alpha always amber amount anchor angel animal answer anyone apart april arena argue arise around artist aspect attack august author autumn avenue await banana barrel basket battle beauty behalf behind belief belong benefit beyond binary bishop blanket border borrow bottle bottom branch breeze bridge bright broken budget buffer bullet bundle button buyer damage danger daring debate decade defeat defend define degree demand depart depend desert design detail device dialog differ dinner direct disease display distant divide dollar domain dragon drawer dream driven during eager early earth easily editor effect effort eighth either elder elegant element elite embark emotion empire enable ending energy engine enjoy enough ensure entire envelope episode equal escape estate ethics evening fabric factor failure fairly family famous father feature fellow female fiction filter final finger finish fiscal flavor flight flower follow forest formal forward fragile freedom friday future galaxy gallery garden gather general gentle genuine gesture ginger global golden govern grammar harbor harmony hazard height hidden holiday honest hunger hybrid ideal ignore illegal imagine impact import improve include infant inform inherit initial inquiry inside inspire instead intense island jacket jungle kernel ladder language lawyer leader legend liberty light linear little magnet manager manual market master matter memory mental middle minute modern monkey mother mountain musical mystery narrow nation native nature nearby normal notice number object office online open opera option oral order organ origin output owner panel paper parent part party phase phone photo piano piece pilot place plain plane plant plate player point power press price prime print prior prize proof proud prove public punch pupil radio range rapid ratio ready realm reason reply report result retail review river round route royal rural scale share shift shirt shock short signal silver simple single sister skill sleep slide small smart smile solid solve sorry sound south space speak speed spend split sport staff stage stand start state steam steel stock stone store story style sugar suite super sweet table taste teach thank theme thick thing think third those throw tiger title today topic total touch tough tower trade train treat trend trial trust truth twice union unity value video virus visit vital voice waste watch water wheel where which while white whole woman world worry worth write wrong yield young youth"; + +static const char test_server_trace_path[] = "/tmp/ds4-trace.txt"; + +static char *test_server_word_filter_prompt_dup(void) { + return xstrdup(test_server_word_filter_prompt); +} + +static void test_server_single_request_word_filter(void) { + char *prompt = test_server_word_filter_prompt_dup(); + TEST_ASSERT(prompt != NULL); + if (!prompt) return; + + char *http_req = test_build_chat_http_request(prompt, true); + TEST_ASSERT(http_req != NULL); + if (!http_req) { + free(prompt); + return; + } + + ds4_engine *engine = test_get_engine(false); + server s; + pthread_t worker; + pthread_t client_thread; + int sv[2] = {-1, -1}; + test_stream_capture cap = {.fd = -1}; + + test_server_init_live(&s, engine, 4096, NULL); + if (!s.session) { + free(http_req); + free(prompt); + return; + } + TEST_ASSERT(pthread_create(&worker, NULL, worker_main, &s) == 0); + + TEST_ASSERT(socketpair(AF_UNIX, SOCK_STREAM, 0, sv) == 0); + if (sv[0] >= 0 && sv[1] >= 0) { + configure_client_socket(sv[0]); + test_send_all_or_fail(sv[1], http_req, strlen(http_req)); + shutdown(sv[1], SHUT_WR); + cap.send_done_at = test_wall_sec(); + test_set_nonblocking_or_fail(sv[1]); + cap.fd = sv[1]; + + client_arg *ca = xmalloc(sizeof(*ca)); + ca->srv = &s; + ca->fd = sv[0]; + TEST_ASSERT(pthread_create(&client_thread, NULL, client_main, ca) == 0); + + test_drive_single_stream(&cap); + pthread_join(client_thread, NULL); + } + + pthread_mutex_lock(&s.mu); + s.stopping = true; + pthread_cond_broadcast(&s.cv); + pthread_mutex_unlock(&s.mu); + pthread_join(worker, NULL); + + TEST_ASSERT(cap.raw.ptr != NULL); + TEST_ASSERT(cap.saw_bytes); + TEST_ASSERT(cap.saw_done); + TEST_ASSERT(cap.raw.ptr && strstr(cap.raw.ptr, "HTTP/1.1 200 OK") != NULL); + TEST_ASSERT(cap.last_byte_at > 0.0); + size_t line_count = test_count_nonempty_lines(cap.raw.ptr ? cap.raw.ptr : ""); + server_log(DS4_LOG_DEFAULT, + "ds4-test: smoke raw_response_nonempty_lines=%zu expected_lt=150", + line_count); + TEST_ASSERT(line_count < 150); + + test_log_stream_capture_server("smoke", &cap, cap.raw.ptr ? cap.raw.ptr : ""); + + buf_free(&cap.raw); + test_server_cleanup_keep_engine(&s); + free(http_req); + free(prompt); +} + +static void test_server_concurrent_requests_stream_sequentially(void) { + char *prompt = test_server_word_filter_prompt_dup(); + TEST_ASSERT(prompt != NULL); + if (!prompt) return; + + char *http_req = test_build_chat_http_request(prompt, true); + TEST_ASSERT(http_req != NULL); + if (!http_req) { + free(prompt); + return; + } + + ds4_engine *engine = test_get_engine(false); + server s; + pthread_t worker; + pthread_t client_threads[2]; + int sv[2][2] = {{-1, -1}, {-1, -1}}; + test_stream_capture caps[2] = {{.fd = -1}, {.fd = -1}}; + + test_server_init_live(&s, engine, 4096, test_server_trace_path); + if (!s.session) { + free(http_req); + free(prompt); + return; + } + s.batching = true; + s.batch_size = 2; + TEST_ASSERT(pthread_create(&worker, NULL, worker_main, &s) == 0); + + for (int i = 0; i < 2; i++) { + TEST_ASSERT(socketpair(AF_UNIX, SOCK_STREAM, 0, sv[i]) == 0); + if (sv[i][0] < 0 || sv[i][1] < 0) continue; + configure_client_socket(sv[i][0]); + test_send_all_or_fail(sv[i][1], http_req, strlen(http_req)); + shutdown(sv[i][1], SHUT_WR); + caps[i].send_done_at = test_wall_sec(); + test_set_nonblocking_or_fail(sv[i][1]); + caps[i].fd = sv[i][1]; + + client_arg *ca = xmalloc(sizeof(*ca)); + ca->srv = &s; + ca->fd = sv[i][0]; + TEST_ASSERT(pthread_create(&client_threads[i], NULL, client_main, ca) == 0); + + if (i == 0) test_sleep_ms(500); + } + + test_drive_two_streams(&caps[0], &caps[1]); + + for (int i = 0; i < 2; i++) { + pthread_join(client_threads[i], NULL); + } + pthread_mutex_lock(&s.mu); + s.stopping = true; + pthread_cond_broadcast(&s.cv); + pthread_mutex_unlock(&s.mu); + pthread_join(worker, NULL); + + for (int i = 0; i < 2; i++) { + TEST_ASSERT(caps[i].raw.ptr != NULL); + TEST_ASSERT(caps[i].saw_bytes); + TEST_ASSERT(caps[i].saw_done); + TEST_ASSERT(caps[i].raw.ptr && strstr(caps[i].raw.ptr, "HTTP/1.1 200 OK") != NULL); + TEST_ASSERT(caps[i].last_byte_at > 0.0); + size_t line_count = test_count_nonempty_lines(caps[i].raw.ptr ? caps[i].raw.ptr : ""); + server_log(DS4_LOG_DEFAULT, + "ds4-test: req%d raw_response_nonempty_lines=%zu expected_lt=150", + i + 1, + line_count); + TEST_ASSERT(line_count < 150); + } + + test_log_stream_capture_server("req1", &caps[0], caps[0].raw.ptr ? caps[0].raw.ptr : ""); + test_log_stream_capture_server("req2", &caps[1], caps[1].raw.ptr ? caps[1].raw.ptr : ""); + test_log_stream_timing_summary_server("req1", &caps[0]); + test_log_stream_timing_summary_server("req2", &caps[1]); + server_log(DS4_LOG_DEFAULT, + "ds4-test: concurrent compare req_gap_ms=500 req1_last_byte_ts=%.6f req2_first_byte_ts=%.6f sequential=%d", + caps[0].last_byte_at, + caps[1].first_byte_at, + caps[1].first_byte_at > caps[0].last_byte_at ? 1 : 0); + TEST_ASSERT(caps[1].first_byte_at > 0.0); + TEST_ASSERT(caps[0].last_byte_at > caps[1].first_byte_at); + + buf_free(&caps[0].raw); + buf_free(&caps[1].raw); + test_server_cleanup_keep_engine(&s); + free(http_req); + free(prompt); +} + typedef struct { const char *name; int number; @@ -650,6 +1094,10 @@ static void test_tool_call_quality(void) { static void test_server_unit_group(void) { ds4_server_unit_tests_run(); +#ifndef DS4_NO_GPU + test_server_single_request_word_filter(); + test_server_concurrent_requests_stream_sequentially(); +#endif } typedef void (*test_fn)(void); @@ -668,7 +1116,7 @@ static const ds4_test_entry test_entries[] = { {"--logprob-vectors", "logprob-vectors", "official API top-logprob vector comparison", test_official_logprob_vectors}, {"--metal-kernels", "metal-kernels", "isolated Metal kernel numeric regressions", test_metal_f16_matvec_fast_nr0_4}, #endif - {"--server", "server", "server parser/rendering/cache unit tests", test_server_unit_group}, + {"--server", "server", "server parser/rendering/cache unit tests plus concurrent inference smoke", test_server_unit_group}, }; static void test_print_help(const char *prog) {