Skip to content

Commit 9be35cf

Browse files
authored
Merge branch 'google:main' into master
2 parents 0dabfc8 + 0da57b8 commit 9be35cf

8 files changed

Lines changed: 281 additions & 18 deletions

File tree

docs/user_guide.md

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -863,6 +863,46 @@ BENCHMARK(BM_test)->Range(8, 8<<10)->UseRealTime();
863863

864864
Without `UseRealTime`, CPU time is used by default.
865865

866+
### Manual Multithreaded Benchmarks
867+
868+
Google/benchmark uses `std::thread` as multithreading environment per default.
869+
If you want to use another multithreading environment (e.g. OpenMP), you can provide
870+
a factory function to your benchmark using the `ThreadRunner` function.
871+
The factory function takes the number of threads as argument and creates a custom class
872+
derived from `benchmark::ThreadRunnerBase`.
873+
This custom class must override the function
874+
`void RunThreads(const std::function<void(int)>& fn)`.
875+
`RunThreads` is called by the main thread and spawns the requested number of threads.
876+
Each spawned thread must call `fn(thread_index)`, where `thread_index` is its own
877+
thread index. Before `RunThreads` returns, all spawned threads must be joined.
878+
```c++
879+
class OpenMPThreadRunner : public benchmark::ThreadRunnerBase
880+
{
881+
OpenMPThreadRunner(int num_threads)
882+
: num_threads_(num_threads)
883+
{}
884+
885+
void RunThreads(const std::function<void(int)>& fn) final
886+
{
887+
#pragma omp parallel num_threads(num_threads_)
888+
fn(omp_get_thread_num());
889+
}
890+
891+
private:
892+
int num_threads_;
893+
};
894+
895+
BENCHMARK(BM_MultiThreaded)
896+
->ThreadRunner([](int num_threads) {
897+
return std::make_unique<OpenMPThreadRunner>(num_threads);
898+
})
899+
->Threads(1)->Threads(2)->Threads(4);
900+
```
901+
The above example creates a parallel OpenMP region before it enters `BM_MultiThreaded`.
902+
The actual benchmark code can remain the same and is therefore not tied to a specific
903+
thread runner. The measurement does not include the time for creating and joining the
904+
threads.
905+
866906
<a name="cpu-timers" />
867907
868908
## CPU Timers

include/benchmark/benchmark.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1093,8 +1093,18 @@ inline BENCHMARK_ALWAYS_INLINE State::StateIterator State::end() {
10931093
return StateIterator();
10941094
}
10951095

1096+
// Base class for user-defined multi-threading
1097+
struct ThreadRunnerBase {
1098+
virtual ~ThreadRunnerBase() {}
1099+
virtual void RunThreads(const std::function<void(int)>& fn) = 0;
1100+
};
1101+
10961102
namespace internal {
10971103

1104+
// Define alias of ThreadRunner factory function type
1105+
using threadrunner_factory =
1106+
std::function<std::unique_ptr<ThreadRunnerBase>(int)>;
1107+
10981108
typedef void(Function)(State&);
10991109

11001110
// ------------------------------------------------------
@@ -1299,6 +1309,9 @@ class BENCHMARK_EXPORT Benchmark {
12991309
// Equivalent to ThreadRange(NumCPUs(), NumCPUs())
13001310
Benchmark* ThreadPerCpu();
13011311

1312+
// Sets a user-defined threadrunner (see ThreadRunnerBase)
1313+
Benchmark* ThreadRunner(threadrunner_factory&& factory);
1314+
13021315
virtual void Run(State& state) = 0;
13031316

13041317
TimeUnit GetTimeUnit() const;
@@ -1340,6 +1353,8 @@ class BENCHMARK_EXPORT Benchmark {
13401353
callback_function setup_;
13411354
callback_function teardown_;
13421355

1356+
threadrunner_factory threadrunner_;
1357+
13431358
BENCHMARK_DISALLOW_COPY_AND_ASSIGN(Benchmark);
13441359
};
13451360

src/benchmark_api_internal.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ class BenchmarkInstance {
4141
int threads() const { return threads_; }
4242
void Setup() const;
4343
void Teardown() const;
44+
const auto& GetUserThreadRunnerFactory() const {
45+
return benchmark_.threadrunner_;
46+
}
4447

4548
State Run(IterationCount iters, int thread_id, internal::ThreadTimer* timer,
4649
internal::ThreadManager* manager,

src/benchmark_register.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,11 @@ Benchmark* Benchmark::ThreadPerCpu() {
484484
return this;
485485
}
486486

487+
Benchmark* Benchmark::ThreadRunner(threadrunner_factory&& factory) {
488+
threadrunner_ = std::move(factory);
489+
return this;
490+
}
491+
487492
void Benchmark::SetName(const std::string& name) { name_ = name; }
488493

489494
const char* Benchmark::GetName() const { return name_.c_str(); }

src/benchmark_runner.cc

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include <cstdio>
3535
#include <cstdlib>
3636
#include <fstream>
37+
#include <functional>
3738
#include <iostream>
3839
#include <limits>
3940
#include <memory>
@@ -182,6 +183,38 @@ IterationCount ComputeIters(const benchmark::internal::BenchmarkInstance& b,
182183
return iters_or_time.iters;
183184
}
184185

186+
class ThreadRunnerDefault : public ThreadRunnerBase {
187+
public:
188+
explicit ThreadRunnerDefault(int num_threads)
189+
: pool(static_cast<size_t>(num_threads - 1)) {}
190+
191+
void RunThreads(const std::function<void(int)>& fn) final {
192+
// Run all but one thread in separate threads
193+
for (std::size_t ti = 0; ti < pool.size(); ++ti) {
194+
pool[ti] = std::thread(fn, static_cast<int>(ti + 1));
195+
}
196+
// And run one thread here directly.
197+
// (If we were asked to run just one thread, we don't create new threads.)
198+
// Yes, we need to do this here *after* we start the separate threads.
199+
fn(0);
200+
201+
// The main thread has finished. Now let's wait for the other threads.
202+
for (std::thread& thread : pool) {
203+
thread.join();
204+
}
205+
}
206+
207+
private:
208+
std::vector<std::thread> pool;
209+
};
210+
211+
std::unique_ptr<ThreadRunnerBase> GetThreadRunner(
212+
const threadrunner_factory& userThreadRunnerFactory, int num_threads) {
213+
return userThreadRunnerFactory
214+
? userThreadRunnerFactory(num_threads)
215+
: std::make_unique<ThreadRunnerDefault>(num_threads);
216+
}
217+
185218
} // end namespace
186219

187220
BenchTimeType ParseBenchMinTime(const std::string& value) {
@@ -258,7 +291,8 @@ BenchmarkRunner::BenchmarkRunner(
258291
has_explicit_iteration_count(b.iterations() != 0 ||
259292
parsed_benchtime_flag.tag ==
260293
BenchTimeType::ITERS),
261-
pool(static_cast<size_t>(b.threads() - 1)),
294+
thread_runner(
295+
GetThreadRunner(b.GetUserThreadRunnerFactory(), b.threads())),
262296
iters(FLAGS_benchmark_dry_run
263297
? 1
264298
: (has_explicit_iteration_count
@@ -289,22 +323,10 @@ BenchmarkRunner::IterationResults BenchmarkRunner::DoNIterations() {
289323
std::unique_ptr<internal::ThreadManager> manager;
290324
manager.reset(new internal::ThreadManager(b.threads()));
291325

292-
// Run all but one thread in separate threads
293-
for (std::size_t ti = 0; ti < pool.size(); ++ti) {
294-
pool[ti] = std::thread(&RunInThread, &b, iters, static_cast<int>(ti + 1),
295-
manager.get(), perf_counters_measurement_ptr,
296-
/*profiler_manager=*/nullptr);
297-
}
298-
// And run one thread here directly.
299-
// (If we were asked to run just one thread, we don't create new threads.)
300-
// Yes, we need to do this here *after* we start the separate threads.
301-
RunInThread(&b, iters, 0, manager.get(), perf_counters_measurement_ptr,
302-
/*profiler_manager=*/nullptr);
303-
304-
// The main thread has finished. Now let's wait for the other threads.
305-
for (std::thread& thread : pool) {
306-
thread.join();
307-
}
326+
thread_runner->RunThreads([&](int thread_idx) {
327+
RunInThread(&b, iters, thread_idx, manager.get(),
328+
perf_counters_measurement_ptr, /*profiler_manager=*/nullptr);
329+
});
308330

309331
IterationResults i;
310332
// Acquire the measurements/counters from the manager, UNDER THE LOCK!

src/benchmark_runner.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#ifndef BENCHMARK_RUNNER_H_
1616
#define BENCHMARK_RUNNER_H_
1717

18+
#include <memory>
1819
#include <thread>
1920
#include <vector>
2021

@@ -89,7 +90,7 @@ class BenchmarkRunner {
8990

9091
int num_repetitions_done = 0;
9192

92-
std::vector<std::thread> pool;
93+
std::unique_ptr<ThreadRunnerBase> thread_runner;
9394

9495
IterationCount iters; // preserved between repetitions!
9596
// So only the first repetition has to find/calculate it,

test/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,9 @@ benchmark_add_test(NAME perf_counters_test COMMAND perf_counters_test --benchmar
189189
compile_output_test(internal_threading_test)
190190
benchmark_add_test(NAME internal_threading_test COMMAND internal_threading_test --benchmark_min_time=0.01s)
191191

192+
compile_output_test(manual_threading_test)
193+
benchmark_add_test(NAME manual_threading_test COMMAND manual_threading_test --benchmark_min_time=0.01s)
194+
192195
compile_output_test(report_aggregates_only_test)
193196
benchmark_add_test(NAME report_aggregates_only_test COMMAND report_aggregates_only_test --benchmark_min_time=0.01s)
194197

test/manual_threading_test.cc

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
2+
#include <memory>
3+
#undef NDEBUG
4+
5+
#include <chrono>
6+
#include <thread>
7+
8+
#include "../src/timers.h"
9+
#include "benchmark/benchmark.h"
10+
11+
namespace {
12+
13+
const std::chrono::duration<double, std::milli> time_frame(50);
14+
const double time_frame_in_sec(
15+
std::chrono::duration_cast<std::chrono::duration<double, std::ratio<1, 1>>>(
16+
time_frame)
17+
.count());
18+
19+
void MyBusySpinwait() {
20+
const auto start = benchmark::ChronoClockNow();
21+
22+
while (true) {
23+
const auto now = benchmark::ChronoClockNow();
24+
const auto elapsed = now - start;
25+
26+
if (std::chrono::duration<double, std::chrono::seconds::period>(elapsed) >=
27+
time_frame) {
28+
return;
29+
}
30+
}
31+
}
32+
33+
int numRunThreadsCalled_ = 0;
34+
35+
class ManualThreadRunner : public benchmark::ThreadRunnerBase {
36+
public:
37+
explicit ManualThreadRunner(int num_threads)
38+
: pool(static_cast<size_t>(num_threads - 1)) {}
39+
40+
void RunThreads(const std::function<void(int)>& fn) final {
41+
for (std::size_t ti = 0; ti < pool.size(); ++ti) {
42+
pool[ti] = std::thread(fn, static_cast<int>(ti + 1));
43+
}
44+
45+
fn(0);
46+
47+
for (std::thread& thread : pool) {
48+
thread.join();
49+
}
50+
51+
++numRunThreadsCalled_;
52+
}
53+
54+
private:
55+
std::vector<std::thread> pool;
56+
};
57+
58+
// ========================================================================= //
59+
// --------------------------- TEST CASES BEGIN ---------------------------- //
60+
// ========================================================================= //
61+
62+
// ========================================================================= //
63+
// BM_ManualThreading
64+
// Creation of threads is done before the start of the measurement,
65+
// joining after the finish of the measurement.
66+
void BM_ManualThreading(benchmark::State& state) {
67+
for (auto _ : state) {
68+
MyBusySpinwait();
69+
state.SetIterationTime(time_frame_in_sec);
70+
}
71+
state.counters["invtime"] =
72+
benchmark::Counter{1, benchmark::Counter::kIsRate};
73+
}
74+
75+
} // end namespace
76+
77+
BENCHMARK(BM_ManualThreading)
78+
->Iterations(1)
79+
->ThreadRunner([](int num_threads) {
80+
return std::make_unique<ManualThreadRunner>(num_threads);
81+
})
82+
->Threads(1);
83+
BENCHMARK(BM_ManualThreading)
84+
->Iterations(1)
85+
->ThreadRunner([](int num_threads) {
86+
return std::make_unique<ManualThreadRunner>(num_threads);
87+
})
88+
->Threads(1)
89+
->UseRealTime();
90+
BENCHMARK(BM_ManualThreading)
91+
->Iterations(1)
92+
->ThreadRunner([](int num_threads) {
93+
return std::make_unique<ManualThreadRunner>(num_threads);
94+
})
95+
->Threads(1)
96+
->UseManualTime();
97+
BENCHMARK(BM_ManualThreading)
98+
->Iterations(1)
99+
->ThreadRunner([](int num_threads) {
100+
return std::make_unique<ManualThreadRunner>(num_threads);
101+
})
102+
->Threads(1)
103+
->MeasureProcessCPUTime();
104+
BENCHMARK(BM_ManualThreading)
105+
->Iterations(1)
106+
->ThreadRunner([](int num_threads) {
107+
return std::make_unique<ManualThreadRunner>(num_threads);
108+
})
109+
->Threads(1)
110+
->MeasureProcessCPUTime()
111+
->UseRealTime();
112+
BENCHMARK(BM_ManualThreading)
113+
->Iterations(1)
114+
->ThreadRunner([](int num_threads) {
115+
return std::make_unique<ManualThreadRunner>(num_threads);
116+
})
117+
->Threads(1)
118+
->MeasureProcessCPUTime()
119+
->UseManualTime();
120+
121+
BENCHMARK(BM_ManualThreading)
122+
->Iterations(1)
123+
->ThreadRunner([](int num_threads) {
124+
return std::make_unique<ManualThreadRunner>(num_threads);
125+
})
126+
->Threads(2);
127+
BENCHMARK(BM_ManualThreading)
128+
->Iterations(1)
129+
->ThreadRunner([](int num_threads) {
130+
return std::make_unique<ManualThreadRunner>(num_threads);
131+
})
132+
->Threads(2)
133+
->UseRealTime();
134+
BENCHMARK(BM_ManualThreading)
135+
->Iterations(1)
136+
->ThreadRunner([](int num_threads) {
137+
return std::make_unique<ManualThreadRunner>(num_threads);
138+
})
139+
->Threads(2)
140+
->UseManualTime();
141+
BENCHMARK(BM_ManualThreading)
142+
->Iterations(1)
143+
->ThreadRunner([](int num_threads) {
144+
return std::make_unique<ManualThreadRunner>(num_threads);
145+
})
146+
->Threads(2)
147+
->MeasureProcessCPUTime();
148+
BENCHMARK(BM_ManualThreading)
149+
->Iterations(1)
150+
->ThreadRunner([](int num_threads) {
151+
return std::make_unique<ManualThreadRunner>(num_threads);
152+
})
153+
->Threads(2)
154+
->MeasureProcessCPUTime()
155+
->UseRealTime();
156+
BENCHMARK(BM_ManualThreading)
157+
->Iterations(1)
158+
->ThreadRunner([](int num_threads) {
159+
return std::make_unique<ManualThreadRunner>(num_threads);
160+
})
161+
->Threads(2)
162+
->MeasureProcessCPUTime()
163+
->UseManualTime();
164+
165+
// ========================================================================= //
166+
// ---------------------------- TEST CASES END ----------------------------- //
167+
// ========================================================================= //
168+
169+
int main(int argc, char* argv[]) {
170+
benchmark::Initialize(&argc, argv);
171+
benchmark::RunSpecifiedBenchmarks();
172+
benchmark::Shutdown();
173+
assert(numRunThreadsCalled_ > 0);
174+
}

0 commit comments

Comments
 (0)