Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions benchmarks/transport/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
if (GHEX_USE_UCP)
# Variable used for benchmarks that DO NOT require multithreading support
#set(_benchmarks ghex_p2p_bi_cb_avail ghex_p2p_bi_cb_wait ghex_p2p_bi_ft_avail ghex_p2p_bi_ft_wait mpi_p2p_bi_avail mpi_p2p_bi_wait)
set(_benchmarks ghex_p2p_bi_cb_avail)
set(_benchmarks ghex_p2p_bi_cb_avail ghex_p2p_bi_ft_avail)

# Variable used for benchmarks that require multithreading support
#set(_benchmarks_mt ghex_p2p_bi_cb_avail ghex_p2p_bi_cb_wait ghex_p2p_bi_ft_avail ghex_p2p_bi_ft_wait mpi_p2p_bi_avail mpi_p2p_bi_wait)
set(_benchmarks_mt ghex_p2p_bi_cb_avail)
set(_benchmarks_mt ghex_p2p_bi_cb_avail ghex_p2p_bi_ft_avail)

foreach (_t ${_benchmarks})
add_executable(${_t} ${_t}_mt.cpp )
Expand Down
130 changes: 62 additions & 68 deletions benchmarks/transport/ghex_p2p_bi_cb_avail_mt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,48 +2,45 @@
#include <vector>
#include <atomic>

//#include <ghex/transport_layer/ucx/threads.hpp>
#include <ghex/common/timer.hpp>
#include "utils.hpp"

namespace ghex = gridtools::ghex;

#ifdef USE_MPI
/* MPI backend */
#ifdef USE_OPENMP
#include <ghex/threads/atomic/primitives.hpp>
using threading = ghex::threads::atomic::primitives;
#else
#include <ghex/threads/none/primitives.hpp>
using threading = ghex::threads::none::primitives;
#endif
#include <ghex/transport_layer/mpi/context.hpp>
using transport = ghex::tl::mpi_tag;
/* MPI backend */
#ifdef USE_OPENMP
#include <ghex/threads/atomic/primitives.hpp>
using threading = ghex::threads::atomic::primitives;
#else
#include <ghex/threads/none/primitives.hpp>
using threading = ghex::threads::none::primitives;
#endif
#include <ghex/transport_layer/mpi/context.hpp>
using transport = ghex::tl::mpi_tag;
#else
/* UCX backend */
#ifdef USE_OPENMP
#include <ghex/threads/omp/primitives.hpp>
using threading = ghex::threads::omp::primitives;
#else
/* UCX backend */
#ifdef USE_OPENMP
#include <ghex/threads/omp/primitives.hpp>
using threading = ghex::threads::omp::primitives;
//#include <ghex/threads/atomic/primitives.hpp>
//using threading = ghex::threads::atomic::primitives;
#else
#include <ghex/threads/none/primitives.hpp>
using threading = ghex::threads::none::primitives;
#endif
#include <ghex/transport_layer/ucx/address_db_mpi.hpp>
#include <ghex/transport_layer/ucx/context.hpp>
using db_type = ghex::tl::ucx::address_db_mpi;
using transport = ghex::tl::ucx_tag;
#include <ghex/threads/none/primitives.hpp>
using threading = ghex::threads::none::primitives;
#endif
#include <ghex/transport_layer/ucx/address_db_mpi.hpp>
#include <ghex/transport_layer/ucx/context.hpp>
using db_type = ghex::tl::ucx::address_db_mpi;
using transport = ghex::tl::ucx_tag;
#endif /* USE_MPI */

#include <ghex/transport_layer/message_buffer.hpp>
#include <ghex/transport_layer/shared_message_buffer.hpp>
using context_type = ghex::tl::context<transport, threading>;
using communicator_type = typename context_type::communicator_type;
using future_type = typename communicator_type::request_cb_type;

using MsgType = gridtools::ghex::tl::message_buffer<>;
//using MsgType = gridtools::ghex::tl::shared_message_buffer<>;
//using MsgType = gridtools::ghex::tl::message_buffer<>;
using MsgType = gridtools::ghex::tl::shared_message_buffer<>;


std::atomic<int> sent(0);
Expand All @@ -52,46 +49,45 @@ std::atomic<int> tail_send(0);
std::atomic<int> tail_recv(0);
int last_received = 0;
int last_sent = 0;
int inflight;


int main(int argc, char *argv[])
{
int niter, buff_size;
int inflight;
int mode;
gridtools::ghex::timer timer, ttimer;

if(argc != 4){
std::cerr << "Usage: bench [niter] [msg_size] [inflight]" << "\n";
std::terminate();
if(argc != 4)
{
std::cerr << "Usage: bench [niter] [msg_size] [inflight]" << "\n";
std::terminate();
}
niter = atoi(argv[1]);
buff_size = atoi(argv[2]);
inflight = atoi(argv[3]);
inflight = atoi(argv[3]);

gridtools::ghex::timer timer, ttimer;

int num_threads = 1;
#ifdef USE_OPENMP
MPI_Init_thread(NULL, NULL, MPI_THREAD_MULTIPLE, &mode);
if(mode != MPI_THREAD_MULTIPLE){
std::cerr << "MPI_THREAD_MULTIPLE not supported by MPI, aborting\n";
std::terminate();
std::cerr << "MPI_THREAD_MULTIPLE not supported by MPI, aborting\n";
std::terminate();
}
#pragma omp parallel
#pragma omp parallel
{
#pragma omp master
#pragma omp master
num_threads = omp_get_num_threads();
}
#else
MPI_Init_thread(NULL, NULL, MPI_THREAD_SINGLE, &mode);
#endif

{
auto context_ptr = ghex::tl::context_factory<transport,threading>::create(num_threads, MPI_COMM_WORLD);
auto& context = *context_ptr;

#ifdef USE_OPENMP
#pragma omp parallel
#pragma omp parallel
#endif
{
auto token = context.get_token();
Expand All @@ -101,7 +97,7 @@ int main(int argc, char *argv[])
const auto thread_id = token.id();
const auto num_threads = context.thread_primitives().size();
const auto peer_rank = (rank+1)%2;

int comm_cnt = 0, nlsend_cnt = 0, nlrecv_cnt = 0, submit_cnt = 0, submit_recv_cnt = 0;
int dbg = 0, sdbg = 0, rdbg = 0;

Expand All @@ -120,12 +116,12 @@ int main(int argc, char *argv[])
comm_cnt++;
received++;
};

if (thread_id==0 && rank==0)
{
std::cout << "\n\nrunning test " << __FILE__ << " with communicator " << typeid(comm).name() << "\n\n";
};

std::vector<MsgType> smsgs(inflight);
std::vector<MsgType> rmsgs(inflight);
std::vector<future_type> sreqs(inflight);
Expand All @@ -137,24 +133,24 @@ int main(int argc, char *argv[])
make_zero(smsgs[j]);
make_zero(rmsgs[j]);
}

context.barrier(token);

if (thread_id == 0)
{
{
timer.tic();
ttimer.tic();
if(rank == 1)
if(rank == 1)
std::cout << "number of threads: " << num_threads << ", multi-threaded: true\n";
};
}

// send/recv niter messages - as soon as a slot becomes free
while(sent < niter || received < niter)
{
if(thread_id == 0 && dbg >= (niter/10))
{
dbg = 0;
std::cout << rank << " total bwdt MB/s: "
std::cout << rank << " total bwdt MB/s: "
<< ((double)(received-last_received + sent-last_sent)*size*buff_size/2)/timer.stoc()
<< "\n";
timer.tic();
Expand Down Expand Up @@ -184,7 +180,7 @@ int main(int argc, char *argv[])
dbg += num_threads;
rreqs[j] = comm.recv(rmsgs[j], peer_rank, thread_id*inflight+j, recv_callback);
}
else
else
comm.progress();

//if(sent < niter && smsgs[j].use_count() == 1)
Expand All @@ -199,29 +195,28 @@ int main(int argc, char *argv[])
comm.progress();
}
}


context.barrier(token);

if(thread_id==0 && rank == 0)
{
const auto t = ttimer.stoc();
std::cout << "time: " << t/1000000 << "s\n";
std::cout << "final MB/s: " << ((double)niter*size*buff_size)/t << "\n";
}

context.barrier(token);
}

context.thread_primitives().critical(
[&]()
{
std::cout
<< "rank " << rank << " thread " << thread_id << " sends submitted " << submit_cnt/num_threads
<< " serviced " << comm_cnt << ", non-local sends " << nlsend_cnt << " non-local recvs " << nlrecv_cnt << "\n";
});
[&]()
{
std::cout
<< "rank " << rank << " thread " << thread_id << " sends submitted " << submit_cnt/num_threads
<< " serviced " << comm_cnt << ", non-local sends " << nlsend_cnt << " non-local recvs " << nlrecv_cnt << "\n";
});

// tail loops - submit RECV requests until
// all SEND requests have been finalized.
// This is because UCX cannot cancel SEND requests.
// https://github.com/openucx/ucx/issues/1162
//
{
int incomplete_sends = 0;
int send_complete = 0;
Expand All @@ -248,12 +243,12 @@ int main(int argc, char *argv[])
}
}
} while(tail_send!=num_threads);

// We have all completed the sends, but the peer might not have yet.
// Notify the peer and keep submitting recvs until we get his notification.
future_type sf, rf;
MsgType smsg(1), rmsg(1);
context.thread_primitives().master(token,
context.thread_primitives().master(token,
[&]() mutable
{
sf = comm.send(smsg, peer_rank, 0x800000, [](communicator_type::message_type, int, int){});
Expand All @@ -269,7 +264,7 @@ int main(int argc, char *argv[])
rreqs[j] = comm.recv(rmsgs[j], peer_rank, thread_id*inflight + j, recv_callback);
}
}
context.thread_primitives().master(token,
context.thread_primitives().master(token,
[&]()
{
if(rf.test()) tail_recv = 1;
Expand All @@ -286,4 +281,3 @@ int main(int argc, char *argv[])
MPI_Barrier(MPI_COMM_WORLD);
MPI_Finalize();
}

Loading