Skip to content

Commit

Permalink
Force push the diff so far
Browse files Browse the repository at this point in the history
  • Loading branch information
chhwang committed Jun 19, 2023
1 parent cd7797f commit 21bd57b
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 15 deletions.
77 changes: 69 additions & 8 deletions include/mscclpp/channel.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef MSCCLPP_CHANNEL_HPP_
#define MSCCLPP_CHANNEL_HPP_

#include <mscclpp/concurrency.hpp>
#include <mscclpp/core.hpp>
#include <mscclpp/epoch.hpp>
#include <mscclpp/fifo.hpp>
Expand Down Expand Up @@ -175,18 +176,57 @@ struct DeviceChannel {
}

// Waits on this channel's epoch handle by forwarding to epoch_.wait().
__forceinline__ __device__ void wait() {
  this->epoch_.wait();
}

#endif // __CUDACC__
protected:
// Builds a channel from its identifier and its proxy FIFO.
// Only channelId_ and fifo_ appear in the initializer list; epoch_ is not set here.
DeviceChannel(ChannelId channelId, DeviceProxyFifo fifo)
    : channelId_(channelId),
      fifo_(fifo) {
}

ChannelId channelId_;

DeviceEpoch::DeviceHandle epoch_;

// This is a concurrent FIFO: multiple threads on the device can produce into it,
// and the sole proxy thread consumes from it.
DeviceProxyFifo fifo_;
};

// A DeviceChannel that additionally carries a DirectEpoch device handle and a
// DeviceSyncer, and offers a packet-based put on top of the base channel's put().
struct DirectDeviceChannel : public DeviceChannel {
  DirectDeviceChannel() = default;

  // Forwards channelId and fifo to the DeviceChannel base and stores the epoch handle.
  DirectDeviceChannel(ChannelId channelId, DirectEpoch::DeviceHandle epoch, DeviceProxyFifo fifo)
      : DeviceChannel(channelId, fifo),
        epoch_(epoch) {
  }

  DirectDeviceChannel(const DirectDeviceChannel& other) = default;

  // NOTE(review): the parameter is a non-const reference, so const objects and
  // temporaries cannot be assigned from — confirm this is intended.
  DirectDeviceChannel& operator=(DirectDeviceChannel& other) = default;

#ifdef __CUDACC__
  // Converts source data into packets, synchronizes, then has thread 0 issue a put().
  //
  // Each calling thread handles packet indices threadId, threadId + numThreads, ...
  // For index i it writes packet i of (putPacketBuffer + dstOffset) from the two
  // consecutive 32-bit words 2*i and 2*i+1 of (srcPtr + srcOffset), tagged with flag.
  // After the loop every caller goes through devSyncer_.sync(numBlocks), and only the
  // caller with threadId == 0 calls put() with a byte count of size * 2.
  //
  // NOTE(review): put() is given src/srcOffset rather than anything derived from
  // putPacketBuffer — confirm that src refers to the memory holding the packets.
  __forceinline__ __device__ void putPacket(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset,
                                            uint64_t size, void* srcPtr, void* putPacketBuffer, uint32_t threadId,
                                            uint32_t numThreads, uint32_t numBlocks, uint32_t flag) {
    // Offsets should be aligned to 8 bytes & size should be a multiple of 8 bytes
    uint32_t* srcWords = reinterpret_cast<uint32_t*>(static_cast<char*>(srcPtr) + srcOffset);
    ChannelPacket* packets = reinterpret_cast<ChannelPacket*>(static_cast<char*>(putPacketBuffer) + dstOffset);
    size_t packetCount = size / sizeof(uint64_t);

    size_t idx = threadId;
    while (idx < packetCount) {
      packets[idx].write(srcWords[2 * idx], srcWords[2 * idx + 1], flag);
      idx += numThreads;
    }

    // All packet writes are issued before the sync; the put below comes after it.
    devSyncer_.sync(numBlocks);

    if (threadId != 0) {
      return;
    }
    // Send data from the local putPacketBuffer to the remote getPacketBuffer
    put(dst, dstOffset, src, srcOffset, size * 2);
  }

  // Forwards to epoch_.epochIncrement() on the DirectEpoch handle.
  __forceinline__ __device__ void epochIncrement() {
    epoch_.epochIncrement();
  }

  // Returns whatever epoch_.epochGetLocal() reports on the DirectEpoch handle.
  __forceinline__ __device__ uint64_t epochGetLocal() const {
    return epoch_.epochGetLocal();
  }

#endif  // __CUDACC__

  // Hides the DeviceEpoch::DeviceHandle member of the same name declared in DeviceChannel.
  DirectEpoch::DeviceHandle epoch_;
  // Used by putPacket() between writing packets and issuing the put.
  // TODO: the original author questioned whether this member is needed.
  DeviceSyncer devSyncer_;
};

class BaseChannelService {
public:
BaseChannelService() = default;
Expand Down Expand Up @@ -224,7 +264,14 @@ class DeviceChannelService : public BaseChannelService {
struct SimpleDeviceChannel {
SimpleDeviceChannel() = default;

// Builds a simple channel bound to a fixed destination and source memory.
//
// devChan          underlying device channel, copied into devChan_
// dst, src         memory ids stored in dst_ and src_
// srcPtr           stored in srcPtr_; defaults to nullptr
// putPacketBuffer  stored in putPacketBuffer_; defaults to nullptr
// getPacketBuffer  stored in getPacketBuffer_; defaults to nullptr
//
// This single constructor also serves the three-argument form
// (devChan, dst, src). A separate three-parameter overload must not coexist
// with it: default arguments do not take part in overload ranking, so a
// three-argument call would be ambiguous between the two.
SimpleDeviceChannel(DeviceChannel devChan, MemoryId dst, MemoryId src, void* srcPtr = nullptr,
                    void* putPacketBuffer = nullptr, void* getPacketBuffer = nullptr)
    : devChan_(devChan),
      dst_(dst),
      src_(src),
      srcPtr_(srcPtr),
      putPacketBuffer_(putPacketBuffer),
      getPacketBuffer_(getPacketBuffer) {}

// Wraps an existing device channel. Only devChan_ is initialized here;
// no other member appears in the initializer list.
SimpleDeviceChannel(DeviceChannel devChan)
    : devChan_(devChan) {
}

Expand Down Expand Up @@ -259,19 +306,33 @@ struct SimpleDeviceChannel {

// Forwards to the wrapped channel's wait().
__forceinline__ __device__ void wait() {
  devChan_.wait();
}

// Forwards to devChan_.putPacket(), supplying the stored dst_, src_, srcPtr_ and
// putPacketBuffer_ so the caller passes only offsets, size, thread layout and flag.
__forceinline__ __device__ void putPacket(uint64_t dstOffset, uint64_t srcOffset, uint64_t size, uint32_t threadId,
                                          uint32_t numThreads, uint32_t numBlocks, uint32_t flag) {
  devChan_.putPacket(dst_,
                     dstOffset,
                     src_,
                     srcOffset,
                     size,
                     srcPtr_,
                     putPacketBuffer_,
                     threadId,
                     numThreads,
                     numBlocks,
                     flag);
}

// Forwards to the wrapped channel's epochIncrement().
__forceinline__ __device__ void epochIncrement() {
  devChan_.epochIncrement();
}

// Returns the value reported by the wrapped channel's epochGetLocal().
__forceinline__ __device__ uint64_t epochGetLocal() const {
  return devChan_.epochGetLocal();
}
#endif // __CUDACC__

DeviceChannel devChan_;
MemoryId dst_;
MemoryId src_;

void* srcPtr_;
void* putPacketBuffer_;
void* getPacketBuffer_;
};

// A direct version of DeviceChannel only for CudaIpc
struct DirectChannel {
public:
DirectChannel() = default;
DirectChannel(DirectEpoch::DeviceHandle epoch, RegisteredMemory dst, void* src, void* tmp = nullptr)
: epoch_(epoch), src_(src), tmp_(tmp) {

DirectChannel(DirectEpoch::DeviceHandle epoch, RegisteredMemory dst, void* src, void* getPacketBuffer = nullptr)
: epoch_(epoch), src_(src), getPacketBuffer_(getPacketBuffer) {
if (!dst.transports().has(Transport::CudaIpc)) {
throw Error("DirectChannel: dst must be registered with CudaIpc", ErrorCode::InvalidUsage);
}
Expand Down Expand Up @@ -320,11 +381,11 @@ struct DirectChannel {
// Reads packets out of the get-packet buffer and stores their payloads into src_.
//
// Each calling thread handles element indices threadId, threadId + numThreads, ...
// For index i it calls read(flag) on packet i of (getPacketBuffer_ + srcOffset) and
// stores the returned value at element i of (src_ + dstOffset), viewed as uint2.
//
// dstOffset   byte offset applied to src_ (where payloads are written)
// srcOffset   byte offset applied to getPacketBuffer_ (where packets are read)
// size        payload size in bytes; size / sizeof(uint2) elements are processed
// threadId    index of the calling thread among numThreads participants
// numThreads  loop stride, i.e. the number of participating threads
// flag        value passed to ChannelPacket::read()
//
// The packet base pointer and pkt are each declared exactly once, from
// getPacketBuffer_; the superseded tmp_-based declarations redeclared pkt in
// the same scope and are gone.
__forceinline__ __device__ void getPacket(uint64_t dstOffset, uint64_t srcOffset, uint64_t size, uint32_t threadId,
                                          uint32_t numThreads, uint32_t flag) {
  // Offsets should be aligned to 8 bytes & size should be a multiple of 8 bytes
  ChannelPacket* getPacketBufferBase = (ChannelPacket*)((char*)getPacketBuffer_ + srcOffset);
  uint2* srcBase = (uint2*)((char*)src_ + dstOffset);
  size_t nElem = size / sizeof(uint2);
  for (size_t i = threadId; i < nElem; i += numThreads) {
    ChannelPacket* pkt = &getPacketBufferBase[i];
    srcBase[i] = pkt->read(flag);
  }
}
Expand All @@ -343,7 +404,7 @@ struct DirectChannel {
DirectEpoch::DeviceHandle epoch_;
void* src_;
void* dst_;
void* tmp_;
void* getPacketBuffer_;
};

} // namespace channel
Expand Down
20 changes: 13 additions & 7 deletions src/epoch.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <mscclpp/epoch.hpp>

#include "api.h"
#include "debug.h"

namespace mscclpp {

Expand Down Expand Up @@ -34,14 +35,19 @@ MSCCLPP_API_CPP DirectEpoch::DirectEpoch(Communicator& communicator, std::shared
: localInboundEpochId_(allocUniqueCuda<uint64_t>()),
expectedInboundEpochId_(allocUniqueCuda<uint64_t>()),
outboundEpochId_(allocUniqueCuda<uint64_t>()) {
if (connection->transport() != Transport::CudaIpc) {
throw Error("DirectEpoch can only be used with CudaIpc transport", ErrorCode::InvalidUsage);
if (connection->transport() == Transport::CudaIpc) {
auto localInboundEpochIdsRegMem =
communicator.registerMemory(localInboundEpochId_.get(), sizeof(uint64_t), connection->transport());

communicator.sendMemoryOnSetup(localInboundEpochIdsRegMem, connection->remoteRank(), connection->tag());
remoteInboundEpochIdsRegMem_ = communicator.recvMemoryOnSetup(connection->remoteRank(), connection->tag());
INFO(MSCCLPP_INIT, "Creating a direct epoch for CudaIPC transport from %d to %d",
communicator.bootstrapper()->getRank(), connection->remoteRank());
} else if (AllIBTransports.has(connection->transport())) {
// We don't really need to do anything for any of the IB transports, since the values will be local
INFO(MSCCLPP_INIT, "Creating a direct epoch for IB transport from %d to %d", communicator.bootstrapper()->getRank(),
connection->remoteRank());
}
auto localInboundEpochIdsRegMem =
communicator.registerMemory(localInboundEpochId_.get(), sizeof(uint64_t), connection->transport());

communicator.sendMemoryOnSetup(localInboundEpochIdsRegMem, connection->remoteRank(), connection->tag());
remoteInboundEpochIdsRegMem_ = communicator.recvMemoryOnSetup(connection->remoteRank(), connection->tag());
}

MSCCLPP_API_CPP DirectEpoch::DeviceHandle DirectEpoch::deviceHandle() {
Expand Down

0 comments on commit 21bd57b

Please sign in to comment.