|
"""Side-by-side benchmark. Run on main, then on the perf branch."""
import os
import sys
import time

# Keep atexit from flushing spans after the benchmark finishes, and make the
# in-repo package importable ahead of any installed copy.
os.environ["BRAINTRUST_DISABLE_ATEXIT_FLUSH"] = "true"
sys.path.insert(0, "src")

from braintrust.logger import (
    BraintrustState,
    SpanImpl,
    _MemoryBackgroundLogger,
    SpanObjectTypeV3,
    stringify_with_overflow_meta,
)
from braintrust.merge_row_batch import merge_row_batch
from braintrust.bt_json import bt_safe_deep_copy
from braintrust.util import LazyValue

# Parent-object id, resolved once up front so span construction never pays
# the lazy-resolution cost inside the timed loops.
PID = LazyValue(lambda: "test", use_mutex=False)
PID.get()
# Payload fixtures: MED mimics a typical chat-completion log row; LARGE is a
# deliberately bulky row to stress copying, merging, and serialization.
MED = {
    "input": {
        "messages": [
            {"role": "user", "content": "What is 2+2?"},
            {"role": "assistant", "content": "4"},
        ]
    },
    "output": {"result": "The answer is 4", "confidence": 0.95},
    "scores": {"accuracy": 0.9, "relevance": 0.85},
    "metadata": {"model": "gpt-4", "temperature": 0.7, "max_tokens": 100},
}

_large_messages = [{"role": "user", "content": f"msg {i}" * 20} for i in range(10)]
LARGE = {
    "input": {"messages": _large_messages},
    "output": {
        "result": "x" * 5000,
        "details": {f"k{i}": f"v{i}" * 10 for i in range(50)},
    },
    "scores": {f"s{i}": i / 100.0 for i in range(20)},
    "metadata": {f"m{i}": f"v{i}" * 5 for i in range(30)},
}
| 31 | + |
def fresh():
    """Return a (state, memory_logger) pair: a new BraintrustState whose
    background logger is replaced by an in-memory capture, so benchmarks
    never touch the network."""
    state = BraintrustState()
    capture = _MemoryBackgroundLogger()
    state._override_bg_logger.logger = capture
    return state, capture
| 37 | + |
def bench(label, fn, N):
    """Time ``fn`` over ``N`` iterations and print/return the mean microseconds per op.

    One untimed warmup call runs first to absorb first-call setup costs.
    """
    fn()  # warmup, excluded from the measurement
    start = time.perf_counter()
    for _ in range(N):
        fn()
    total = time.perf_counter() - start
    per_op_us = total * 1e6 / N
    print(f" {label:40s} {per_op_us:8.1f} us/op")
    return per_op_us
| 47 | + |
N = 5000
print(f"N={N}\n")

def _span(state, **extra):
    """Construct a SpanImpl against the fixed fake project; ``extra`` passes
    through name/event kwargs so each call site mirrors the SDK call path."""
    return SpanImpl(
        parent_object_type=SpanObjectTypeV3.PROJECT_LOGS,
        parent_object_id=PID,
        parent_compute_object_metadata_args=None,
        parent_span_ids=None,
        state=state,
        lookup_span_parent=False,
        **extra,
    )

# -- user thread --
print("User thread:")

s, ml = fresh()
bench("start_span (medium)", lambda: _span(s, name="b", event=dict(MED)), N)

s, ml = fresh()
bench("start_span (large)", lambda: _span(s, name="b", event=dict(LARGE)), N)

s, ml = fresh()
def _tree():
    # Root span plus one child, both ended, as a minimal realistic trace.
    root = _span(s, name="root", event=dict(MED))
    child = root.start_span(name="child", input="x", output="y")
    child.end()
    root.end()
bench("root + child + end (medium)", _tree, N)

s, ml = fresh()
span = _span(s, name="b")
bench("span.log (medium)", lambda: span.log(**MED), N)

# -- deep copy --
print("\nDeep copy:")
bench("bt_safe_deep_copy (medium)", lambda: bt_safe_deep_copy(MED), N)
bench("bt_safe_deep_copy (large)", lambda: bt_safe_deep_copy(LARGE), N)

# -- flush thread --
print("\nFlush thread:")
for count in (1000, 5000):
    s2, ml2 = fresh()
    for _ in range(count):
        _span(s2, name="b", event=dict(MED))
    pending = ml2.logs[:]

    def _flush(batch=pending):
        # Simulate the background flush: unwrap lazies, merge rows, serialize.
        rows = [item.get() for item in batch]
        merged = merge_row_batch(rows)
        _ = [stringify_with_overflow_meta(row) for row in merged]

    t0 = time.perf_counter()
    _flush()
    elapsed = time.perf_counter() - t0
    print(f" flush {count} items (medium) {elapsed/count*1e6:8.1f} us/item")
0 commit comments