Fixing RegisterMemory Allocation for ProxyChannels (#353)
Co-authored-by: Binyang Li <[email protected]>
Co-authored-by: Changho Hwang <[email protected]>
3 people authored Sep 25, 2024
1 parent 8a330f9 commit 08a0cec
Showing 3 changed files with 71 additions and 47 deletions.
34 changes: 25 additions & 9 deletions python/test/executor_test.py
@@ -78,34 +78,47 @@ def dtype_to_mscclpp_dtype(dtype):


def main(
execution_paln_name: str,
execution_plan_name: str,
execution_plan_path: str,
size: int,
in_place: bool = True,
dtype: cp.dtype = cp.float16,
packet_type: PacketType = PacketType.LL16,
seed: int = 42 + MPI.COMM_WORLD.rank,
seed: int = 42,
):
mscclpp_group = mscclpp_comm.CommGroup(MPI.COMM_WORLD)
cp.cuda.Device(mscclpp_group.my_rank % mscclpp_group.nranks_per_node).use()
executor = Executor(mscclpp_group.communicator)
npkit_dump_dir = os.getenv("NPKIT_DUMP_DIR")
if npkit_dump_dir is not None:
npkit.init(mscclpp_group.my_rank)
execution_plan = ExecutionPlan(execution_paln_name, execution_plan_path)
execution_plan = ExecutionPlan(execution_plan_name, execution_plan_path)

cp.random.seed(seed)
nelems = size // cp.dtype(dtype).itemsize
sendbuf = cp.random.random(nelems).astype(dtype)
expected = cp.asnumpy(sendbuf)
expected = MPI.COMM_WORLD.allreduce(expected, op=MPI.SUM)
buffer = cp.random.random(nelems * mscclpp_group.nranks, dtype=cp.float32).astype(dtype)
sub_arrays = cp.split(buffer, MPI.COMM_WORLD.size)
sendbuf = cp.zeros(nelems, dtype=dtype)
for i in range(nelems):
sendbuf[i] = sub_arrays[MPI.COMM_WORLD.rank][i]

if "allgather" in execution_plan_name:
recvbuf = cp.zeros(nelems * mscclpp_group.nranks, dtype=dtype)
expected = buffer
else:
cp.random.seed(seed)
recvbuf = cp.zeros(nelems, dtype=dtype)
expected = cp.zeros_like(sendbuf, dtype=dtype)
for i in range(mscclpp_group.nranks):
expected += sub_arrays[i]
mscclpp_group.barrier()

executor_func = lambda stream: executor.execute(
MPI.COMM_WORLD.rank,
sendbuf.data.ptr,
sendbuf.data.ptr,
sendbuf.nbytes,
sendbuf.data.ptr if in_place else recvbuf.data.ptr,
sendbuf.nbytes,
sendbuf.nbytes if in_place else recvbuf.nbytes,
dtype_to_mscclpp_dtype(dtype),
execution_plan,
stream.ptr,
@@ -115,7 +128,8 @@ def main(
stream = cp.cuda.Stream(non_blocking=True)
executor_func(stream)
stream.synchronize()
assert cp.allclose(sendbuf, expected, atol=1e-2 * mscclpp_group.nranks)

assert cp.allclose(sendbuf if in_place else recvbuf, expected, atol=1e-2 * mscclpp_group.nranks)

mscclpp_group.barrier()
execution_time = bench_time(100, 10, executor_func)
@@ -136,6 +150,7 @@ def main(
parser.add_argument("-n", "--execution_plan_name", type=str, required=True)
parser.add_argument("-path", "--execution_plan_path", type=str, required=True)
parser.add_argument("--size", type=str, required=True)
parser.add_argument("--in_place", action="store_true", help="flag to define an in-place operation")
parser.add_argument("--dtype", type=str, default="float16", help="Choose from float16, float32, int32")
parser.add_argument("--packet_type", type=str, default="LL16", help="Choose from LL8, LL16")
parser.add_argument("--seed", type=int, default=42)
@@ -151,6 +166,7 @@ def main(
args.execution_plan_name,
args.execution_plan_path,
buffer_size,
args.in_place,
dtype,
packet_type,
args.seed,
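For reference, the updated test now selects in-place vs. out-of-place execution through the new `--in_place` flag. A hypothetical launch might look like the following; the launcher, rank count, plan name, and plan path are placeholders, while the flag names come from the argument parser shown above.

```bash
# Hypothetical invocation of the updated test (launcher, rank count, plan name,
# and plan path are placeholders; only the flags come from the script itself).
mpirun -np 8 python3 python/test/executor_test.py \
  -n allreduce_example \
  -path /path/to/execution_plans \
  --size 1048576 \
  --dtype float16 \
  --packet_type LL16 \
  --in_place
```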
17 changes: 10 additions & 7 deletions src/connection.cc
@@ -16,10 +16,13 @@

namespace mscclpp {

void validateTransport(RegisteredMemory mem, Transport transport) {
void validateTransport(RegisteredMemory mem, Transport transport, uint64_t offset = 0, uint64_t size = 0) {
if (!mem.transports().has(transport)) {
throw Error("RegisteredMemory does not support this transport", ErrorCode::InvalidUsage);
}
if (offset + size > mem.size()) {
throw Error("RegisteredMemory out of bounds", ErrorCode::InvalidUsage);
}
}

// Connection
@@ -59,8 +62,8 @@ Transport CudaIpcConnection::remoteTransport() { return Transport::CudaIpc; }

void CudaIpcConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
uint64_t size) {
validateTransport(dst, remoteTransport());
validateTransport(src, transport());
validateTransport(dst, remoteTransport(), dstOffset, size);
validateTransport(src, transport(), srcOffset, size);

char* dstPtr = (char*)dst.data();
char* srcPtr = (char*)src.data();
@@ -115,8 +118,8 @@ Transport IBConnection::remoteTransport() { return remoteTransport_; }

void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
uint64_t size) {
validateTransport(dst, remoteTransport());
validateTransport(src, transport());
validateTransport(dst, remoteTransport(), dstOffset, size);
validateTransport(src, transport(), srcOffset, size);

auto dstTransportInfo = getImpl(dst)->getTransportInfo(remoteTransport());
if (dstTransportInfo.ibLocal) {
@@ -231,8 +234,8 @@ Transport EthernetConnection::remoteTransport() { return Transport::Ethernet; }
void EthernetConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
uint64_t size) {
// Validating Transport Protocol
validateTransport(dst, remoteTransport());
validateTransport(src, transport());
validateTransport(dst, remoteTransport(), dstOffset, size);
validateTransport(src, transport(), srcOffset, size);

// Initializing Variables
char* srcPtr = reinterpret_cast<char*>(src.data()) + srcOffset / sizeof(char);
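With the added `offset` and `size` parameters, each connection's `write` now rejects transfers that would run past the end of a registered region before any data moves. Below is a minimal, self-contained sketch of that bounds rule using simplified stand-in types, not the actual mscclpp `RegisteredMemory` and `Error` classes.

```cpp
#include <cstdint>
#include <iostream>
#include <stdexcept>

// Simplified stand-in for a registered memory region; only its size matters here.
struct Region {
  uint64_t sizeBytes;
};

// Mirrors the added check: the span [offset, offset + size) must lie inside the region.
void checkBounds(const Region& mem, uint64_t offset, uint64_t size) {
  if (offset + size > mem.sizeBytes) {
    throw std::runtime_error("RegisteredMemory out of bounds");
  }
}

int main() {
  Region mem{1024};
  checkBounds(mem, 512, 512);    // ok: ends exactly at the region boundary
  try {
    checkBounds(mem, 768, 512);  // rejected: would run 256 bytes past the end
  } catch (const std::exception& e) {
    std::cout << e.what() << "\n";
  }
  return 0;
}
```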
67 changes: 36 additions & 31 deletions src/executor/executor.cc
@@ -108,7 +108,7 @@ struct Executor::Impl {
context.nthreadsPerBlock = plan.impl_->getNThreadsPerBlock();
this->setupConnections(context, rank, plan);
this->setupRegisteredMemories(context, sendbuff, recvbuff, sendBufferSize, recvBufferSize, rank, plan);
this->setupChannels(context, sendbuff, recvbuff, sendBufferSize, rank, plan);
this->setupChannels(context, sendbuff, recvbuff, sendBufferSize, recvBufferSize, rank, plan);
this->setupDeviceExecutionPlan(context, rank, plan);
context.deviceExecutionPlansBuffer =
allocExtSharedCuda<char>(context.deviceExecutionPlans.size() * sizeof(DeviceExecutionPlan));
@@ -119,6 +119,23 @@
return context;
}

TransportFlags getTransportFlags(std::vector<ChannelInfo>& infos, int rank) {
TransportFlags flags;
for (ChannelInfo& info : infos) {
if (info.channelType == ChannelType::SM) {
flags |= Transport::CudaIpc;
} else if (info.channelType == ChannelType::PROXY) {
for (int peer : info.connectedPeers) {
if (!inSameNode(rank, peer, this->nranksPerNode)) {
flags |= IBs[rank % this->nranksPerNode];
} else
flags |= Transport::CudaIpc;
}
}
}
return flags;
};

void setupConnections(ExecutionContext& context, int rank, const ExecutionPlan& plan) {
std::vector<int> connectedPeers = plan.impl_->getConnectedPeers(rank);
std::vector<mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>> connectionFutures;
@@ -135,22 +152,6 @@

void setupRegisteredMemories(ExecutionContext& context, void* sendbuff, void* recvbuff, size_t sendBufferSize,
size_t recvBufferSize, int rank, const ExecutionPlan& plan) {
auto getTransportFlags = [&](std::vector<ChannelInfo>& infos, int rank) {
TransportFlags flags;
for (ChannelInfo& info : infos) {
if (info.channelType == ChannelType::SM) {
flags |= Transport::CudaIpc;
} else if (info.channelType == ChannelType::PROXY) {
for (int peer : info.connectedPeers) {
if (!inSameNode(rank, peer, this->nranksPerNode)) {
flags |= IBs[rank % this->nranksPerNode];
} else
flags |= Transport::CudaIpc;
}
}
}
return flags;
};
auto getBufferInfo = [&](BufferType type) {
switch (type) {
case BufferType::INPUT:
@@ -192,22 +193,12 @@
comm->setup();
for (size_t i = 0; i < remoteRegMemoryFutures.size(); i++) {
context.registeredMemories[{bufferType, connectedPeers[i]}] = std::move(remoteRegMemoryFutures[i].get());
CUdeviceptr myRegBaseAdr, peerRegBaseAdr;
size_t temp;
MSCCLPP_CUTHROW(cuMemGetAddressRange(&myRegBaseAdr, &temp, (CUdeviceptr)(char*)memory.data()));
MSCCLPP_CUTHROW(cuMemGetAddressRange(
&peerRegBaseAdr, &temp,
(CUdeviceptr)(char*)context.registeredMemories[{bufferType, connectedPeers[i]}].data()));
size_t myRegOffset = (char*)memory.data() - (char*)myRegBaseAdr;
size_t peerRegOffset =
(char*)context.registeredMemories[{bufferType, connectedPeers[i]}].data() - (char*)peerRegBaseAdr;
if (myRegOffset != peerRegOffset) throw Error("Divergent data offset between peers", ErrorCode::ExecutorError);
}
}
}

void setupChannels(ExecutionContext& context, void* sendbuff, void* recvbuff, size_t sendBufferSize, int rank,
const ExecutionPlan& plan) {
void setupChannels(ExecutionContext& context, void* sendbuff, void* recvbuff, size_t sendBufferSize,
size_t recvBufferSize, int rank, const ExecutionPlan& plan) {
const auto channelTypes = {ChannelType::SM, ChannelType::PROXY};
std::vector<std::shared_ptr<SmDevice2DeviceSemaphore>> smSemaphores;
std::vector<mscclpp::SemaphoreId> proxySemaphores;
@@ -251,13 +242,27 @@
throw Error("Invalid buffer type", ErrorCode::ExecutorError);
}
};
auto getBufferSize = [&](BufferType type) {
switch (type) {
case BufferType::INPUT:
return sendBufferSize;
case BufferType::OUTPUT:
return recvBufferSize;
case BufferType::SCRATCH:
return context.scratchBufferSize;
default:
throw Error("Invalid buffer type", ErrorCode::ExecutorError);
}
};

for (ChannelType channelType : channelTypes) {
std::vector<ChannelInfo> channelInfos = plan.impl_->getChannelInfos(rank, channelType);
int index = 0;
for (ChannelInfo& info : channelInfos) {
void* src = getBuffer(info.srcBufferType);
TransportFlags transport = context.registeredMemories.begin()->second.transports();
RegisteredMemory localMemory = this->comm->registerMemory(src, sendBufferSize, transport);
size_t bufferSize = getBufferSize(info.srcBufferType);
TransportFlags transport = getTransportFlags(channelInfos, rank);
RegisteredMemory localMemory = this->comm->registerMemory(src, bufferSize, transport);
for (int peer : info.connectedPeers) {
if (channelType == ChannelType::SM) {
context.smChannels.emplace_back(context.smSemaphores[index++],
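Taken together, the executor changes make `setupChannels` register each channel's local buffer with the size that matches its source buffer type (input, output, or scratch) and with transport flags derived from the channel infos, instead of reusing `sendBufferSize` and the transports of an unrelated registered memory. Below is a simplified, self-contained sketch of that per-buffer selection; the enum, struct, and function names are illustrative, not the executor's actual types.

```cpp
#include <cstddef>
#include <stdexcept>
#include <utility>

enum class BufferType { INPUT, OUTPUT, SCRATCH };

// Hypothetical bundle of the pointers and sizes tracked per collective.
struct Buffers {
  void* sendbuff;
  void* recvbuff;
  void* scratch;
  size_t sendBufferSize;
  size_t recvBufferSize;
  size_t scratchBufferSize;
};

// Pick the (pointer, size) pair matching a channel's source buffer type,
// mirroring the getBuffer/getBufferSize lambdas used in setupChannels.
std::pair<void*, size_t> selectLocalBuffer(const Buffers& b, BufferType type) {
  switch (type) {
    case BufferType::INPUT:
      return {b.sendbuff, b.sendBufferSize};
    case BufferType::OUTPUT:
      return {b.recvbuff, b.recvBufferSize};
    case BufferType::SCRATCH:
      return {b.scratch, b.scratchBufferSize};
    default:
      throw std::invalid_argument("Invalid buffer type");
  }
}

int main() {
  Buffers b{nullptr, nullptr, nullptr, 1 << 20, 1 << 20, 1 << 16};
  auto [ptr, size] = selectLocalBuffer(b, BufferType::OUTPUT);
  (void)ptr;
  return size == static_cast<size_t>(1 << 20) ? 0 : 1;
}
```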
