diff --git a/src/ops/dispatch_combine/internode.hpp b/src/ops/dispatch_combine/internode.hpp index 7dc26495..11a206ea 100644 --- a/src/ops/dispatch_combine/internode.hpp +++ b/src/ops/dispatch_combine/internode.hpp @@ -78,6 +78,8 @@ __global__ void EpDispatchInterNodeKernel(EpDispatchCombineArgs args) { size_t scalesOffset = indicesOffset + sizeof(index_t) * numExpertPerToken; size_t stagingOffset = scalesOffset + config.scaleTypeSize * config.scaleDim; + const uint32_t crossDeviceBarrierFlag = args.crossDeviceBarrierFlag[0]; + extern __shared__ char sharedMem[]; int subWarpNumPerWarp = warpSize / numExpertPerToken; @@ -140,10 +142,9 @@ __global__ void EpDispatchInterNodeKernel(EpDispatchCombineArgs args) { const int startIdx = localBlockId * baseChunk + min(localBlockId, remainder); const int endIdx = startIdx + myChunkSize; if (localBlockId == 0 && warpId == warpNum - 1) { - shmem::ShmemPutInt32ImmNbiWarp( - args.recvTokenNumMemObj, - (myPe + (args.crossDeviceBarrierFlag[0] & 1) * npes) * sizeof(index_t), totalTokens, destPe, - localBlockId); + shmem::ShmemPutInt32ImmNbiWarp(args.recvTokenNumMemObj, + (myPe + (crossDeviceBarrierFlag & 1) * npes) * sizeof(index_t), + totalTokens, destPe, localBlockId); } if (destNode == myNode) { @@ -258,18 +259,17 @@ __global__ void EpDispatchInterNodeKernel(EpDispatchCombineArgs args) { __syncthreads(); if (warpId == warpNum - 1) { shmem::ShmemAtomicTypeNonFetchWarp( - args.sendAtomicSignalMemObj, - (myPe + (args.crossDeviceBarrierFlag[0] & 1) * npes) * sizeof(int64_t), 1, core::AMO_ADD, - destPe, localBlockId); + args.sendAtomicSignalMemObj, (myPe + (crossDeviceBarrierFlag & 1) * npes) * sizeof(int64_t), + 1, core::AMO_ADD, destPe, localBlockId); } if (thdId == 0) { int64_t* signal = args.sendAtomicSignalMemObj->template GetAs() + destPe + - (args.crossDeviceBarrierFlag[0] & 1) * npes; + (crossDeviceBarrierFlag & 1) * npes; shmem::ShmemInt64WaitUntilGreaterThan(signal, numsBlockPerDestPe - 1); - recvTokenNum = atomicAdd( - &args.recvTokenNumMemObj - ->template GetAs()[destPe + (args.crossDeviceBarrierFlag[0] & 1) * npes], - 0); + recvTokenNum = + atomicAdd(&args.recvTokenNumMemObj + ->template GetAs()[destPe + (crossDeviceBarrierFlag & 1) * npes], + 0); if (localBlockId == 0) { atomicAdd(args.totalRecvTokenNum, recvTokenNum); args.destPeTokenCounter[destPe] = 0; @@ -331,8 +331,8 @@ __global__ void EpDispatchInterNodeKernel(EpDispatchCombineArgs args) { /* BarrierKernel */ /* ---------------------------------------------------------------------------------------------- */ template -inline __device__ void CrossDeviceBarrierInterNodeKernel(EpDispatchCombineArgs args, - int numQps) { +inline __device__ void CrossDeviceBarrierInterNodeKernel(EpDispatchCombineArgs args, int numQps, + const uint32_t crossDeviceBarrierFlag) { int thdId = threadIdx.x; int laneId = threadIdx.x & (warpSize - 1); int globalThdId = blockIdx.x * blockDim.x + threadIdx.x; @@ -348,15 +348,16 @@ inline __device__ void CrossDeviceBarrierInterNodeKernel(EpDispatchCombineArgstemplate GetAs(); + volatile uint64_t* localBarrierPtr = + args.crossDeviceBarrierMemObj->template GetAs(); if (thdId < args.config.worldSize) { uint64_t currentVal = core::AtomicLoadRelaxedSystem(localBarrierPtr + thdId); #if DEBUG == 1 printf("Thread %d: localBarrierPtr[%d] = %lu, expected = %lu\n", thdId, thdId, currentVal, - (uint64_t)(args.crossDeviceBarrierFlag[0] * numQps)); + (uint64_t)(crossDeviceBarrierFlag * numQps)); #endif - while (currentVal != args.crossDeviceBarrierFlag[0] * numQps) { + while (currentVal != crossDeviceBarrierFlag * numQps) { currentVal = core::AtomicLoadRelaxedSystem(localBarrierPtr + thdId); } } @@ -387,6 +388,8 @@ __global__ void EpCombineInterNodeKernel(EpDispatchCombineArgs args) { size_t MaxNumTokensToSendPerRank = config.MaxNumTokensToSendPerRank(); size_t MaxNumTokensToRecvPerRank = config.MaxNumTokensToRecvPerRank(); + const uint32_t crossDeviceBarrierFlag = args.crossDeviceBarrierFlag[0]; + // Phase 1: send token // This phase is symmetric with dispatch recv phase, where tokens are first sent back to its // source pe in pe sorted order @@ -395,7 +398,7 @@ __global__ void EpCombineInterNodeKernel(EpDispatchCombineArgs args) { const int srcNode = srcPe / MAX_GPUS_PER_NODE; const int localBlockId = blockIdx.x - srcPe * numsBlockPerSrcPe; const int srcPeTokenNum = *(args.recvTokenNumMemObj->template GetAs() + srcPe + - (args.crossDeviceBarrierFlag[0] & 1) * npes); + (crossDeviceBarrierFlag & 1) * npes); const int baseChunk = srcPeTokenNum / numsBlockPerSrcPe; const int remainder = srcPeTokenNum % numsBlockPerSrcPe; @@ -501,13 +504,13 @@ __global__ void EpCombineInterNodeKernel(EpDispatchCombineArgs args) { SyncIfDebugEnabled("Combine kernel: send token end"); // Make sure copy on all GPUs are finished - CrossDeviceBarrierInterNodeKernel(args, numsBlockPerSrcPe); + CrossDeviceBarrierInterNodeKernel(args, numsBlockPerSrcPe, crossDeviceBarrierFlag); shmem::ShmemQuietThread(); if (globalThdId < npes) { args.recvTokenNumMemObj - ->template GetAs()[globalThdId + (args.crossDeviceBarrierFlag[0] & 1) * npes] = 0; + ->template GetAs()[globalThdId + (crossDeviceBarrierFlag & 1) * npes] = 0; args.sendAtomicSignalMemObj - ->template GetAs()[globalThdId + (args.crossDeviceBarrierFlag[0] & 1) * npes] = 0; + ->template GetAs()[globalThdId + (crossDeviceBarrierFlag & 1) * npes] = 0; } if (globalThdId == 0) { diff --git a/src/ops/dispatch_combine/intranode.hpp b/src/ops/dispatch_combine/intranode.hpp index 70189ef1..e3e746ff 100644 --- a/src/ops/dispatch_combine/intranode.hpp +++ b/src/ops/dispatch_combine/intranode.hpp @@ -34,7 +34,8 @@ namespace moe { /* BarrierKernel */ /* ---------------------------------------------------------------------------------------------- */ template -inline __device__ void CrossDeviceBarrierIntraNodeKernel(EpDispatchCombineArgs args) { +inline __device__ void CrossDeviceBarrierIntraNodeKernel(EpDispatchCombineArgs args, + const uint32_t crossDeviceBarrierFlag) { int thdId = threadIdx.x; int laneId = threadIdx.x & (warpSize - 1); int globalThdId = blockIdx.x * blockDim.x + threadIdx.x; @@ -50,12 +51,12 @@ inline __device__ void CrossDeviceBarrierIntraNodeKernel(EpDispatchCombineArgstemplate GetAs(globalThdId) + args.config.rank, - *args.crossDeviceBarrierFlag); + crossDeviceBarrierFlag); } uint32_t* localBarrierPtr = args.crossDeviceBarrierMemObj->template GetAs(); if (thdId < args.config.worldSize) { - while (core::AtomicLoadRelaxedSystem(localBarrierPtr + thdId) != *args.crossDeviceBarrierFlag) { + while (core::AtomicLoadRelaxedSystem(localBarrierPtr + thdId) != crossDeviceBarrierFlag) { } } __syncthreads(); @@ -206,7 +207,8 @@ __global__ void EpCombineIntraNodeKernel(EpDispatchCombineArgs args) { int myPe = config.rank; int npes = config.worldSize; - size_t maxNumTokensToSend = config.MaxNumTokensToSend(); + const uint32_t crossDeviceBarrierFlag = args.crossDeviceBarrierFlag[0]; + size_t maxNumOutTokenPerRank = config.MaxNumTokensToSend(); // Copy input to shmem registered buffer so that other GPUs can access directly index_t totalRecvTokenNum = args.totalRecvTokenNum[0]; if (args.config.useExternalInpBuffer) { @@ -225,7 +227,7 @@ __global__ void EpCombineIntraNodeKernel(EpDispatchCombineArgs args) { } // Make sure copy on all GPUs are finished - CrossDeviceBarrierIntraNodeKernel(args); + CrossDeviceBarrierIntraNodeKernel(args, crossDeviceBarrierFlag); *args.totalRecvTokenNum = 0; if (args.curRankNumToken == 0) return; @@ -271,6 +273,11 @@ __global__ void EpCombineIntraNodeKernel(EpDispatchCombineArgs args) { srcWeightsPtr, nullptr, config.numExpertPerToken, config.numExpertPerToken); } } + + if (globalThdId == 0) { + __hip_atomic_fetch_add(args.crossDeviceBarrierFlag, 1, __ATOMIC_RELEASE, + __HIP_MEMORY_SCOPE_SYSTEM); + } } } // namespace moe