Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions scripts/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,19 @@

WARM_UP_FACTOR = 20
BENCH_ITERS = 1_000_000
STEPS = [2 ** i - 1 for i in range(2, 20)]

DEFAULT_PIVOT_PATH = Path(__file__).parent.parent / "build" / "pivot"
pivot_path = os.getenv("PIVOT_PATH", DEFAULT_PIVOT_PATH)


def warmup(dim: int, seed: int | None = None):
def _steps(max_power):
return [2 ** i - 1 for i in range(2, max_power + 1)]


def warmup(dim: int, max_power: int, seed: int | None = None):
steps_list = _steps(max_power)
print(f"Running warmup for dimension {dim}")
for steps in STEPS:
for steps in steps_list:
warm_up_iters = WARM_UP_FACTOR * steps
print(f"Warming up with {steps} steps for {warm_up_iters} iterations")

Expand All @@ -30,12 +34,13 @@ def warmup(dim: int, seed: int | None = None):
print(f"Checkpoint saved to {out_dir}/walk.csv")


def benchmark(dim: int, slow: bool, naive: bool = False, seed: int | None = None):
def benchmark(dim: int, slow: bool, max_power: int, naive: bool = False, seed: int | None = None):
if naive:
slow = True
print(f"Running benchmark for dimension {dim}")
times = {}
for steps in STEPS:
steps_list = _steps(max_power)
for steps in steps_list:
print(f"Running benchmark with {steps} steps for {BENCH_ITERS} iterations")

in_dir = Path(__file__).parent / "benchmark" / f"dim_{dim}" / f"warmup_{steps}/walk.csv"
Expand Down Expand Up @@ -72,7 +77,8 @@ def analyze(dim):
times = json.load(f)
times = {int(k): v for k, v in times.items()}

plt.plot(STEPS, [times[steps] for steps in STEPS], marker="o")
steps_list = sorted(times.keys())
plt.plot(steps_list, [times[steps] for steps in steps_list], marker="o")
plt.title(f"Dimension {dim}")
plt.xlabel("Number of steps")
plt.ylabel("Microseconds per pivot attempt")
Expand All @@ -84,6 +90,7 @@ def analyze(dim):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--dim", type=int, default=2)
parser.add_argument("--max-power", type=int, default=19)
subparsers = parser.add_subparsers(dest="command")

warmup_parser = subparsers.add_parser("warmup")
Expand All @@ -98,8 +105,8 @@ def analyze(dim):

args = parser.parse_args()
if args.command == "warmup":
warmup(args.dim, args.seed)
warmup(args.dim, args.max_power, args.seed)
elif args.command == "benchmark":
benchmark(args.dim, args.slow, naive=args.naive, seed=args.seed)
benchmark(args.dim, args.slow, args.max_power, naive=args.naive, seed=args.seed)
elif args.command == "analyze":
analyze(args.dim)