Skip to content

Commit

Permalink
bootstrap now takes interface
Browse files Browse the repository at this point in the history
  • Loading branch information
Saeed Maleki committed Jun 28, 2023
1 parent 21eed72 commit 417f7dc
Show file tree
Hide file tree
Showing 5 changed files with 179 additions and 153 deletions.
3 changes: 2 additions & 1 deletion include/mscclpp/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ class Bootstrap : public BaseBootstrap {
UniqueId getUniqueId() const;

void initialize(UniqueId uniqueId);
void initialize(std::string ipPortPair);
// the acceptable formats are "ip:port" or "interface:ip:port"
void initialize(std::string ifIpPortTrio);
int getRank() override;
int getNranks() override;
void send(void* data, int size, int peer, int tag) override;
Expand Down
133 changes: 78 additions & 55 deletions src/bootstrap/bootstrap.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#include "socket.h"
#include "utils_internal.hpp"

using namespace mscclpp;
namespace mscclpp {

static void setFilesLimit() {
rlimit filesLimit;
Expand All @@ -32,8 +32,8 @@ enum bootstrapInterface_t { findSubnetIf = -1, dontCareIf = -2 };
struct ExtInfo {
int rank;
int nRanks;
mscclppSocketAddress extAddressListenRoot;
mscclppSocketAddress extAddressListen;
SocketAddress extAddressListenRoot;
SocketAddress extAddressListen;
};

MSCCLPP_API_CPP void BaseBootstrap::send(const std::vector<char>& data, int peer, int tag) {
Expand All @@ -51,7 +51,7 @@ MSCCLPP_API_CPP void BaseBootstrap::recv(std::vector<char>& data, int peer, int

struct UniqueIdInternal {
uint64_t magic;
union mscclppSocketAddress addr;
union SocketAddress addr;
};
static_assert(sizeof(UniqueIdInternal) <= sizeof(UniqueId), "UniqueIdInternal is too large to fit into UniqueId");

Expand All @@ -60,7 +60,7 @@ class Bootstrap::Impl {
Impl(int rank, int nRanks);
~Impl();
void initialize(const UniqueId uniqueId);
void initialize(std::string ipPortPair);
void initialize(std::string ifIpPortTrio);
void establishConnections();
UniqueId createUniqueId();
UniqueId getUniqueId() const;
Expand All @@ -81,13 +81,13 @@ class Bootstrap::Impl {
std::unique_ptr<Socket> listenSock_;
std::unique_ptr<Socket> ringRecvSocket_;
std::unique_ptr<Socket> ringSendSocket_;
std::vector<mscclppSocketAddress> peerCommAddresses_;
std::vector<SocketAddress> peerCommAddresses_;
std::vector<int> barrierArr_;
std::unique_ptr<uint32_t> abortFlagStorage_;
volatile uint32_t* abortFlag_;
std::thread rootThread_;
char netIfName_[MAX_IF_NAME_SIZE + 1];
mscclppSocketAddress netIfAddr_;
SocketAddress netIfAddr_;
std::unordered_map<std::pair<int, int>, std::shared_ptr<Socket>, PairHash> peerSendSockets_;
std::unordered_map<std::pair<int, int>, std::shared_ptr<Socket>, PairHash> peerRecvSockets_;

Expand All @@ -99,18 +99,18 @@ class Bootstrap::Impl {

void bootstrapCreateRoot();
void bootstrapRoot();
void getRemoteAddresses(Socket* listenSock, std::vector<mscclppSocketAddress>& rankAddresses,
std::vector<mscclppSocketAddress>& rankAddressesRoot, int& rank);
void sendHandleToPeer(int peer, const std::vector<mscclppSocketAddress>& rankAddresses,
const std::vector<mscclppSocketAddress>& rankAddressesRoot);
void netInit(std::string ipPortPair);
void getRemoteAddresses(Socket* listenSock, std::vector<SocketAddress>& rankAddresses,
std::vector<SocketAddress>& rankAddressesRoot, int& rank);
void sendHandleToPeer(int peer, const std::vector<SocketAddress>& rankAddresses,
const std::vector<SocketAddress>& rankAddressesRoot);
void netInit(std::string ipPortPair, std::string interface);
};

Bootstrap::Impl::Impl(int rank, int nRanks)
: rank_(rank),
nRanks_(nRanks),
netInitialized(false),
peerCommAddresses_(nRanks, mscclppSocketAddress()),
peerCommAddresses_(nRanks, SocketAddress()),
barrierArr_(nRanks, 0),
abortFlagStorage_(new uint32_t(0)),
abortFlag_(abortFlagStorage_.get()) {}
Expand All @@ -122,9 +122,9 @@ UniqueId Bootstrap::Impl::getUniqueId() const {
}

UniqueId Bootstrap::Impl::createUniqueId() {
netInit("");
netInit("", "");
getRandomData(&uniqueId_.magic, sizeof(uniqueId_.magic));
std::memcpy(&uniqueId_.addr, &netIfAddr_, sizeof(mscclppSocketAddress));
std::memcpy(&uniqueId_.addr, &netIfAddr_, sizeof(SocketAddress));
bootstrapCreateRoot();
return getUniqueId();
}
Expand All @@ -134,19 +134,34 @@ int Bootstrap::Impl::getRank() { return rank_; }
int Bootstrap::Impl::getNranks() { return nRanks_; }

void Bootstrap::Impl::initialize(const UniqueId uniqueId) {
netInit("");
netInit("", "");

std::memcpy(&uniqueId_, &uniqueId, sizeof(uniqueId_));

establishConnections();
}

void Bootstrap::Impl::initialize(std::string ipPortPair) {
netInit(ipPortPair);
void Bootstrap::Impl::initialize(std::string ifIpPortTrio) {
// first check if it is a trio
int nColons = 0;
for (auto c : ifIpPortTrio) {
if (c == ':') {
nColons++;
}
}
std::string ipPortPair = ifIpPortTrio;
std::string interface = "";
if (nColons == 2) {
// we know the <interface>
interface = ifIpPortTrio.substr(0, ipPortPair.find_first_of(':'));
ipPortPair = ifIpPortTrio.substr(ipPortPair.find_first_of(':') + 1);
}

netInit(ipPortPair, interface);

uniqueId_.magic = 0xdeadbeef;
std::memcpy(&uniqueId_.addr, &netIfAddr_, sizeof(mscclppSocketAddress));
mscclppSocketGetAddrFromString(&uniqueId_.addr, ipPortPair.c_str());
std::memcpy(&uniqueId_.addr, &netIfAddr_, sizeof(SocketAddress));
SocketGetAddrFromString(&uniqueId_.addr, ipPortPair.c_str());

if (rank_ == 0) {
bootstrapCreateRoot();
Expand All @@ -164,14 +179,14 @@ Bootstrap::Impl::~Impl() {
}
}

void Bootstrap::Impl::getRemoteAddresses(Socket* listenSock, std::vector<mscclppSocketAddress>& rankAddresses,
std::vector<mscclppSocketAddress>& rankAddressesRoot, int& rank) {
void Bootstrap::Impl::getRemoteAddresses(Socket* listenSock, std::vector<SocketAddress>& rankAddresses,
std::vector<SocketAddress>& rankAddressesRoot, int& rank) {
ExtInfo info;
mscclppSocketAddress zero;
std::memset(&zero, 0, sizeof(mscclppSocketAddress));
SocketAddress zero;
std::memset(&zero, 0, sizeof(SocketAddress));

{
Socket sock(nullptr, MSCCLPP_SOCKET_MAGIC, mscclppSocketTypeUnknown, abortFlag_);
Socket sock(nullptr, MSCCLPP_SOCKET_MAGIC, SocketTypeUnknown, abortFlag_);
sock.accept(listenSock);
netRecv(&sock, &info, sizeof(info));
}
Expand All @@ -182,7 +197,7 @@ void Bootstrap::Impl::getRemoteAddresses(Socket* listenSock, std::vector<mscclpp
ErrorCode::InternalError);
}

if (std::memcmp(&zero, &rankAddressesRoot[info.rank], sizeof(mscclppSocketAddress)) != 0) {
if (std::memcmp(&zero, &rankAddressesRoot[info.rank], sizeof(SocketAddress)) != 0) {
throw Error("Bootstrap Root : rank " + std::to_string(info.rank) + " of " + std::to_string(this->nRanks_) +
" has already checked in",
ErrorCode::InternalError);
Expand All @@ -194,17 +209,16 @@ void Bootstrap::Impl::getRemoteAddresses(Socket* listenSock, std::vector<mscclpp
rank = info.rank;
}

void Bootstrap::Impl::sendHandleToPeer(int peer, const std::vector<mscclppSocketAddress>& rankAddresses,
const std::vector<mscclppSocketAddress>& rankAddressesRoot) {
void Bootstrap::Impl::sendHandleToPeer(int peer, const std::vector<SocketAddress>& rankAddresses,
const std::vector<SocketAddress>& rankAddressesRoot) {
int next = (peer + 1) % nRanks_;
Socket sock(&rankAddressesRoot[peer], uniqueId_.magic, mscclppSocketTypeBootstrap, abortFlag_);
Socket sock(&rankAddressesRoot[peer], uniqueId_.magic, SocketTypeBootstrap, abortFlag_);
sock.connect();
netSend(&sock, &rankAddresses[next], sizeof(mscclppSocketAddress));
netSend(&sock, &rankAddresses[next], sizeof(SocketAddress));
}

void Bootstrap::Impl::bootstrapCreateRoot() {
listenSockRoot_ =
std::make_unique<Socket>(&uniqueId_.addr, uniqueId_.magic, mscclppSocketTypeBootstrap, abortFlag_, 0);
listenSockRoot_ = std::make_unique<Socket>(&uniqueId_.addr, uniqueId_.magic, SocketTypeBootstrap, abortFlag_, 0);
listenSockRoot_->listen();
uniqueId_.addr = listenSockRoot_->getAddr();

Expand All @@ -220,12 +234,12 @@ void Bootstrap::Impl::bootstrapCreateRoot() {

void Bootstrap::Impl::bootstrapRoot() {
int numCollected = 0;
std::vector<mscclppSocketAddress> rankAddresses(nRanks_, mscclppSocketAddress());
std::vector<SocketAddress> rankAddresses(nRanks_, SocketAddress());
// for initial rank <-> root information exchange
std::vector<mscclppSocketAddress> rankAddressesRoot(nRanks_, mscclppSocketAddress());
std::vector<SocketAddress> rankAddressesRoot(nRanks_, SocketAddress());

std::memset(rankAddresses.data(), 0, sizeof(mscclppSocketAddress) * nRanks_);
std::memset(rankAddressesRoot.data(), 0, sizeof(mscclppSocketAddress) * nRanks_);
std::memset(rankAddresses.data(), 0, sizeof(SocketAddress) * nRanks_);
std::memset(rankAddressesRoot.data(), 0, sizeof(SocketAddress) * nRanks_);
setFilesLimit();

TRACE(MSCCLPP_INIT, "BEGIN");
Expand All @@ -252,24 +266,32 @@ void Bootstrap::Impl::bootstrapRoot() {
TRACE(MSCCLPP_INIT, "DONE");
}

void Bootstrap::Impl::netInit(std::string ipPortPair) {
void Bootstrap::Impl::netInit(std::string ipPortPair, std::string interface) {
if (netInitialized) return;
if (!ipPortPair.empty()) {
mscclppSocketAddress remoteAddr;
mscclppSocketGetAddrFromString(&remoteAddr, ipPortPair.c_str());
if (mscclppFindInterfaceMatchSubnet(netIfName_, &netIfAddr_, &remoteAddr, MAX_IF_NAME_SIZE, 1) <= 0) {
throw Error("NET/Socket : No usable listening interface found", ErrorCode::InternalError);
if (interface != "") {
// we know the <interface>
int ret = FindInterfaces(netIfName_, &netIfAddr_, MAX_IF_NAME_SIZE, 1, interface.c_str());
if (ret <= 0) throw Error("NET/Socket : No interface named " + interface + " found.", ErrorCode::InternalError);
} else {
// we do not know the <interface> try to match it next
SocketAddress remoteAddr;
SocketGetAddrFromString(&remoteAddr, ipPortPair.c_str());
if (FindInterfaceMatchSubnet(netIfName_, &netIfAddr_, &remoteAddr, MAX_IF_NAME_SIZE, 1) <= 0) {
throw Error("NET/Socket : No usable listening interface found", ErrorCode::InternalError);
}
}

} else {
int ret = mscclppFindInterfaces(netIfName_, &netIfAddr_, MAX_IF_NAME_SIZE, 1);
int ret = FindInterfaces(netIfName_, &netIfAddr_, MAX_IF_NAME_SIZE, 1);
if (ret <= 0) {
throw Error("Bootstrap : no socket interface found", ErrorCode::InternalError);
}
}

char line[SOCKET_NAME_MAXLEN + MAX_IF_NAME_SIZE + 2];
std::sprintf(line, " %s:", netIfName_);
mscclppSocketToString(&netIfAddr_, line + strlen(line));
SocketToString(&netIfAddr_, line + strlen(line));
INFO(MSCCLPP_INIT, "Bootstrap : Using%s", line);
netInitialized = true;
}
Expand All @@ -289,7 +311,7 @@ void Bootstrap::Impl::netInit(std::string ipPortPair) {
void Bootstrap::Impl::establishConnections() {
const int64_t connectionTimeoutUs = (int64_t)Config::getInstance()->getBootstrapConnectionTimeoutConfig() * 1000000;
Timer timer;
mscclppSocketAddress nextAddr;
SocketAddress nextAddr;
ExtInfo info;

TRACE(MSCCLPP_INIT, "rank %d nranks %d", rank_, nRanks_);
Expand All @@ -305,13 +327,13 @@ void Bootstrap::Impl::establishConnections() {

uint64_t magic = uniqueId_.magic;
// Create socket for other ranks to contact me
listenSock_ = std::make_unique<Socket>(&netIfAddr_, magic, mscclppSocketTypeBootstrap, abortFlag_);
listenSock_ = std::make_unique<Socket>(&netIfAddr_, magic, SocketTypeBootstrap, abortFlag_);
listenSock_->listen();
info.extAddressListen = listenSock_->getAddr();

{
// Create socket for root to contact me
Socket lsock(&netIfAddr_, magic, mscclppSocketTypeBootstrap, abortFlag_);
Socket lsock(&netIfAddr_, magic, SocketTypeBootstrap, abortFlag_);
lsock.listen();
info.extAddressListenRoot = lsock.getAddr();

Expand All @@ -329,28 +351,28 @@ void Bootstrap::Impl::establishConnections() {

// send info on my listening socket to root
{
Socket sock(&uniqueId_.addr, magic, mscclppSocketTypeBootstrap, abortFlag_);
Socket sock(&uniqueId_.addr, magic, SocketTypeBootstrap, abortFlag_);
TIMEOUT(sock.connect(getLeftTime()));
netSend(&sock, &info, sizeof(info));
}

// get info on my "next" rank in the bootstrap ring from root
{
Socket sock(nullptr, MSCCLPP_SOCKET_MAGIC, mscclppSocketTypeUnknown, abortFlag_);
Socket sock(nullptr, MSCCLPP_SOCKET_MAGIC, SocketTypeUnknown, abortFlag_);
TIMEOUT(sock.accept(&lsock, getLeftTime()));
netRecv(&sock, &nextAddr, sizeof(mscclppSocketAddress));
netRecv(&sock, &nextAddr, sizeof(SocketAddress));
}
}

ringSendSocket_ = std::make_unique<Socket>(&nextAddr, magic, mscclppSocketTypeBootstrap, abortFlag_);
ringSendSocket_ = std::make_unique<Socket>(&nextAddr, magic, SocketTypeBootstrap, abortFlag_);
TIMEOUT(ringSendSocket_->connect(getLeftTime()));
// Accept the connect request from the previous rank in the AllGather ring
ringRecvSocket_ = std::make_unique<Socket>(nullptr, MSCCLPP_SOCKET_MAGIC, mscclppSocketTypeUnknown, abortFlag_);
ringRecvSocket_ = std::make_unique<Socket>(nullptr, MSCCLPP_SOCKET_MAGIC, SocketTypeUnknown, abortFlag_);
TIMEOUT(ringRecvSocket_->accept(listenSock_.get(), getLeftTime()));

// AllGather all listen handlers
peerCommAddresses_[rank_] = listenSock_->getAddr();
allGather(peerCommAddresses_.data(), sizeof(mscclppSocketAddress));
allGather(peerCommAddresses_.data(), sizeof(SocketAddress));

TRACE(MSCCLPP_INIT, "rank %d nranks %d - DONE", rank_, nRanks_);
}
Expand Down Expand Up @@ -384,8 +406,7 @@ std::shared_ptr<Socket> Bootstrap::Impl::getPeerSendSocket(int peer, int tag) {
if (it != peerSendSockets_.end()) {
return it->second;
}
auto sock =
std::make_shared<Socket>(&peerCommAddresses_[peer], uniqueId_.magic, mscclppSocketTypeBootstrap, abortFlag_);
auto sock = std::make_shared<Socket>(&peerCommAddresses_[peer], uniqueId_.magic, SocketTypeBootstrap, abortFlag_);
sock->connect();
netSend(sock.get(), &rank_, sizeof(int));
netSend(sock.get(), &tag, sizeof(int));
Expand All @@ -399,7 +420,7 @@ std::shared_ptr<Socket> Bootstrap::Impl::getPeerRecvSocket(int peer, int tag) {
return it->second;
}
for (;;) {
auto sock = std::make_shared<Socket>(nullptr, MSCCLPP_SOCKET_MAGIC, mscclppSocketTypeUnknown, abortFlag_);
auto sock = std::make_shared<Socket>(nullptr, MSCCLPP_SOCKET_MAGIC, SocketTypeUnknown, abortFlag_);
sock->accept(listenSock_.get());
int recvPeer, recvTag;
netRecv(sock.get(), &recvPeer, sizeof(int));
Expand Down Expand Up @@ -471,3 +492,5 @@ MSCCLPP_API_CPP void Bootstrap::initialize(std::string ipPortPair) { pimpl_->ini
MSCCLPP_API_CPP void Bootstrap::barrier() { pimpl_->barrier(); }

MSCCLPP_API_CPP Bootstrap::~Bootstrap() { pimpl_->close(); }

} // namespace mscclpp
Loading

0 comments on commit 417f7dc

Please sign in to comment.