Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 24 additions & 21 deletions src/ops/dispatch_combine/internode.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ __global__ void EpDispatchInterNodeKernel(EpDispatchCombineArgs<T> args) {
size_t scalesOffset = indicesOffset + sizeof(index_t) * numExpertPerToken;
size_t stagingOffset = scalesOffset + config.scaleTypeSize * config.scaleDim;

const uint32_t crossDeviceBarrierFlag = args.crossDeviceBarrierFlag[0];

extern __shared__ char sharedMem[];

int subWarpNumPerWarp = warpSize / numExpertPerToken;
Expand Down Expand Up @@ -140,10 +142,9 @@ __global__ void EpDispatchInterNodeKernel(EpDispatchCombineArgs<T> args) {
const int startIdx = localBlockId * baseChunk + min(localBlockId, remainder);
const int endIdx = startIdx + myChunkSize;
if (localBlockId == 0 && warpId == warpNum - 1) {
shmem::ShmemPutInt32ImmNbiWarp(
args.recvTokenNumMemObj,
(myPe + (args.crossDeviceBarrierFlag[0] & 1) * npes) * sizeof(index_t), totalTokens, destPe,
localBlockId);
shmem::ShmemPutInt32ImmNbiWarp(args.recvTokenNumMemObj,
(myPe + (crossDeviceBarrierFlag & 1) * npes) * sizeof(index_t),
totalTokens, destPe, localBlockId);
}

if (destNode == myNode) {
Expand Down Expand Up @@ -258,18 +259,17 @@ __global__ void EpDispatchInterNodeKernel(EpDispatchCombineArgs<T> args) {
__syncthreads();
if (warpId == warpNum - 1) {
shmem::ShmemAtomicTypeNonFetchWarp<int64_t>(
args.sendAtomicSignalMemObj,
(myPe + (args.crossDeviceBarrierFlag[0] & 1) * npes) * sizeof(int64_t), 1, core::AMO_ADD,
destPe, localBlockId);
args.sendAtomicSignalMemObj, (myPe + (crossDeviceBarrierFlag & 1) * npes) * sizeof(int64_t),
1, core::AMO_ADD, destPe, localBlockId);
}
if (thdId == 0) {
int64_t* signal = args.sendAtomicSignalMemObj->template GetAs<int64_t*>() + destPe +
(args.crossDeviceBarrierFlag[0] & 1) * npes;
(crossDeviceBarrierFlag & 1) * npes;
shmem::ShmemInt64WaitUntilGreaterThan(signal, numsBlockPerDestPe - 1);
recvTokenNum = atomicAdd(
&args.recvTokenNumMemObj
->template GetAs<index_t*>()[destPe + (args.crossDeviceBarrierFlag[0] & 1) * npes],
0);
recvTokenNum =
atomicAdd(&args.recvTokenNumMemObj
->template GetAs<index_t*>()[destPe + (crossDeviceBarrierFlag & 1) * npes],
0);
if (localBlockId == 0) {
atomicAdd(args.totalRecvTokenNum, recvTokenNum);
args.destPeTokenCounter[destPe] = 0;
Expand Down Expand Up @@ -331,8 +331,8 @@ __global__ void EpDispatchInterNodeKernel(EpDispatchCombineArgs<T> args) {
/* BarrierKernel */
/* ---------------------------------------------------------------------------------------------- */
template <typename T>
inline __device__ void CrossDeviceBarrierInterNodeKernel(EpDispatchCombineArgs<T> args,
int numQps) {
inline __device__ void CrossDeviceBarrierInterNodeKernel(EpDispatchCombineArgs<T> args, int numQps,
const uint32_t crossDeviceBarrierFlag) {
int thdId = threadIdx.x;
int laneId = threadIdx.x & (warpSize - 1);
int globalThdId = blockIdx.x * blockDim.x + threadIdx.x;
Expand All @@ -348,15 +348,16 @@ inline __device__ void CrossDeviceBarrierInterNodeKernel(EpDispatchCombineArgs<T
shmem::ShmemUint32WaitUntilEquals(args.combineGridBarrier, globalWarpNum);
}

volatile uint64_t* localBarrierPtr = args.crossDeviceBarrierMemObj->template GetAs<volatile uint64_t*>();
volatile uint64_t* localBarrierPtr =
args.crossDeviceBarrierMemObj->template GetAs<volatile uint64_t*>();
if (thdId < args.config.worldSize) {
uint64_t currentVal = core::AtomicLoadRelaxedSystem(localBarrierPtr + thdId);
#if DEBUG == 1
printf("Thread %d: localBarrierPtr[%d] = %lu, expected = %lu\n", thdId, thdId, currentVal,
(uint64_t)(args.crossDeviceBarrierFlag[0] * numQps));
(uint64_t)(crossDeviceBarrierFlag * numQps));
#endif

while (currentVal != args.crossDeviceBarrierFlag[0] * numQps) {
while (currentVal != crossDeviceBarrierFlag * numQps) {
currentVal = core::AtomicLoadRelaxedSystem(localBarrierPtr + thdId);
}
}
Expand Down Expand Up @@ -387,6 +388,8 @@ __global__ void EpCombineInterNodeKernel(EpDispatchCombineArgs<T> args) {
size_t MaxNumTokensToSendPerRank = config.MaxNumTokensToSendPerRank();
size_t MaxNumTokensToRecvPerRank = config.MaxNumTokensToRecvPerRank();

const uint32_t crossDeviceBarrierFlag = args.crossDeviceBarrierFlag[0];

// Phase 1: send token
// This phase is symmetric with the dispatch recv phase: tokens are first sent back to their
// source PE, in PE-sorted order
Expand All @@ -395,7 +398,7 @@ __global__ void EpCombineInterNodeKernel(EpDispatchCombineArgs<T> args) {
const int srcNode = srcPe / MAX_GPUS_PER_NODE;
const int localBlockId = blockIdx.x - srcPe * numsBlockPerSrcPe;
const int srcPeTokenNum = *(args.recvTokenNumMemObj->template GetAs<index_t*>() + srcPe +
(args.crossDeviceBarrierFlag[0] & 1) * npes);
(crossDeviceBarrierFlag & 1) * npes);
const int baseChunk = srcPeTokenNum / numsBlockPerSrcPe;
const int remainder = srcPeTokenNum % numsBlockPerSrcPe;

Expand Down Expand Up @@ -501,13 +504,13 @@ __global__ void EpCombineInterNodeKernel(EpDispatchCombineArgs<T> args) {
SyncIfDebugEnabled("Combine kernel: send token end");

// Make sure the copies on all GPUs have finished
CrossDeviceBarrierInterNodeKernel(args, numsBlockPerSrcPe);
CrossDeviceBarrierInterNodeKernel(args, numsBlockPerSrcPe, crossDeviceBarrierFlag);
shmem::ShmemQuietThread();
if (globalThdId < npes) {
args.recvTokenNumMemObj
->template GetAs<index_t*>()[globalThdId + (args.crossDeviceBarrierFlag[0] & 1) * npes] = 0;
->template GetAs<index_t*>()[globalThdId + (crossDeviceBarrierFlag & 1) * npes] = 0;
args.sendAtomicSignalMemObj
->template GetAs<int64_t*>()[globalThdId + (args.crossDeviceBarrierFlag[0] & 1) * npes] = 0;
->template GetAs<int64_t*>()[globalThdId + (crossDeviceBarrierFlag & 1) * npes] = 0;
}

if (globalThdId == 0) {
Expand Down
17 changes: 12 additions & 5 deletions src/ops/dispatch_combine/intranode.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ namespace moe {
/* BarrierKernel */
/* ---------------------------------------------------------------------------------------------- */
template <typename T>
inline __device__ void CrossDeviceBarrierIntraNodeKernel(EpDispatchCombineArgs<T> args) {
inline __device__ void CrossDeviceBarrierIntraNodeKernel(EpDispatchCombineArgs<T> args,
const uint32_t crossDeviceBarrierFlag) {
int thdId = threadIdx.x;
int laneId = threadIdx.x & (warpSize - 1);
int globalThdId = blockIdx.x * blockDim.x + threadIdx.x;
Expand All @@ -50,12 +51,12 @@ inline __device__ void CrossDeviceBarrierIntraNodeKernel(EpDispatchCombineArgs<T
args.combineGridBarrier[0] = 0;
core::AtomicStoreRelaxedSystem(
args.crossDeviceBarrierMemObj->template GetAs<uint32_t*>(globalThdId) + args.config.rank,
*args.crossDeviceBarrierFlag);
crossDeviceBarrierFlag);
}

uint32_t* localBarrierPtr = args.crossDeviceBarrierMemObj->template GetAs<uint32_t*>();
if (thdId < args.config.worldSize) {
while (core::AtomicLoadRelaxedSystem(localBarrierPtr + thdId) != *args.crossDeviceBarrierFlag) {
while (core::AtomicLoadRelaxedSystem(localBarrierPtr + thdId) != crossDeviceBarrierFlag) {
}
}
__syncthreads();
Expand Down Expand Up @@ -206,7 +207,8 @@ __global__ void EpCombineIntraNodeKernel(EpDispatchCombineArgs<T> args) {
int myPe = config.rank;
int npes = config.worldSize;

size_t maxNumTokensToSend = config.MaxNumTokensToSend();
const uint32_t crossDeviceBarrierFlag = args.crossDeviceBarrierFlag[0];
size_t maxNumOutTokenPerRank = config.MaxNumTokensToSend();
// Copy input to shmem registered buffer so that other GPUs can access directly
index_t totalRecvTokenNum = args.totalRecvTokenNum[0];
if (args.config.useExternalInpBuffer) {
Expand All @@ -225,7 +227,7 @@ __global__ void EpCombineIntraNodeKernel(EpDispatchCombineArgs<T> args) {
}

// Make sure the copies on all GPUs have finished
CrossDeviceBarrierIntraNodeKernel(args);
CrossDeviceBarrierIntraNodeKernel(args, crossDeviceBarrierFlag);
*args.totalRecvTokenNum = 0;
if (args.curRankNumToken == 0) return;

Expand Down Expand Up @@ -271,6 +273,11 @@ __global__ void EpCombineIntraNodeKernel(EpDispatchCombineArgs<T> args) {
srcWeightsPtr, nullptr, config.numExpertPerToken, config.numExpertPerToken);
}
}

if (globalThdId == 0) {
__hip_atomic_fetch_add(args.crossDeviceBarrierFlag, 1, __ATOMIC_RELEASE,
__HIP_MEMORY_SCOPE_SYSTEM);
}
}

} // namespace moe
Expand Down