Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bootstrap now takes interface #113

Merged
merged 2 commits into from
Jun 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion include/mscclpp/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ class Bootstrap : public BaseBootstrap {
UniqueId getUniqueId() const;

void initialize(UniqueId uniqueId);
void initialize(std::string ipPortPair);
// the acceptable formats are "ip:port" or "interface:ip:port"
void initialize(std::string ifIpPortTrio);
int getRank() override;
int getNranks() override;
void send(void* data, int size, int peer, int tag) override;
Expand Down
133 changes: 78 additions & 55 deletions src/bootstrap/bootstrap.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#include "socket.h"
#include "utils_internal.hpp"

using namespace mscclpp;
namespace mscclpp {

static void setFilesLimit() {
rlimit filesLimit;
Expand All @@ -32,8 +32,8 @@ enum bootstrapInterface_t { findSubnetIf = -1, dontCareIf = -2 };
struct ExtInfo {
int rank;
int nRanks;
mscclppSocketAddress extAddressListenRoot;
mscclppSocketAddress extAddressListen;
SocketAddress extAddressListenRoot;
SocketAddress extAddressListen;
};

MSCCLPP_API_CPP void BaseBootstrap::send(const std::vector<char>& data, int peer, int tag) {
Expand All @@ -51,7 +51,7 @@ MSCCLPP_API_CPP void BaseBootstrap::recv(std::vector<char>& data, int peer, int

struct UniqueIdInternal {
uint64_t magic;
union mscclppSocketAddress addr;
union SocketAddress addr;
};
static_assert(sizeof(UniqueIdInternal) <= sizeof(UniqueId), "UniqueIdInternal is too large to fit into UniqueId");

Expand All @@ -60,7 +60,7 @@ class Bootstrap::Impl {
Impl(int rank, int nRanks);
~Impl();
void initialize(const UniqueId uniqueId);
void initialize(std::string ipPortPair);
void initialize(std::string ifIpPortTrio);
void establishConnections();
UniqueId createUniqueId();
UniqueId getUniqueId() const;
Expand All @@ -81,13 +81,13 @@ class Bootstrap::Impl {
std::unique_ptr<Socket> listenSock_;
std::unique_ptr<Socket> ringRecvSocket_;
std::unique_ptr<Socket> ringSendSocket_;
std::vector<mscclppSocketAddress> peerCommAddresses_;
std::vector<SocketAddress> peerCommAddresses_;
std::vector<int> barrierArr_;
std::unique_ptr<uint32_t> abortFlagStorage_;
volatile uint32_t* abortFlag_;
std::thread rootThread_;
char netIfName_[MAX_IF_NAME_SIZE + 1];
mscclppSocketAddress netIfAddr_;
SocketAddress netIfAddr_;
std::unordered_map<std::pair<int, int>, std::shared_ptr<Socket>, PairHash> peerSendSockets_;
std::unordered_map<std::pair<int, int>, std::shared_ptr<Socket>, PairHash> peerRecvSockets_;

Expand All @@ -99,18 +99,18 @@ class Bootstrap::Impl {

void bootstrapCreateRoot();
void bootstrapRoot();
void getRemoteAddresses(Socket* listenSock, std::vector<mscclppSocketAddress>& rankAddresses,
std::vector<mscclppSocketAddress>& rankAddressesRoot, int& rank);
void sendHandleToPeer(int peer, const std::vector<mscclppSocketAddress>& rankAddresses,
const std::vector<mscclppSocketAddress>& rankAddressesRoot);
void netInit(std::string ipPortPair);
void getRemoteAddresses(Socket* listenSock, std::vector<SocketAddress>& rankAddresses,
std::vector<SocketAddress>& rankAddressesRoot, int& rank);
void sendHandleToPeer(int peer, const std::vector<SocketAddress>& rankAddresses,
const std::vector<SocketAddress>& rankAddressesRoot);
void netInit(std::string ipPortPair, std::string interface);
};

Bootstrap::Impl::Impl(int rank, int nRanks)
: rank_(rank),
nRanks_(nRanks),
netInitialized(false),
peerCommAddresses_(nRanks, mscclppSocketAddress()),
peerCommAddresses_(nRanks, SocketAddress()),
barrierArr_(nRanks, 0),
abortFlagStorage_(new uint32_t(0)),
abortFlag_(abortFlagStorage_.get()) {}
Expand All @@ -122,9 +122,9 @@ UniqueId Bootstrap::Impl::getUniqueId() const {
}

UniqueId Bootstrap::Impl::createUniqueId() {
netInit("");
netInit("", "");
getRandomData(&uniqueId_.magic, sizeof(uniqueId_.magic));
std::memcpy(&uniqueId_.addr, &netIfAddr_, sizeof(mscclppSocketAddress));
std::memcpy(&uniqueId_.addr, &netIfAddr_, sizeof(SocketAddress));
bootstrapCreateRoot();
return getUniqueId();
}
Expand All @@ -134,19 +134,34 @@ int Bootstrap::Impl::getRank() { return rank_; }
int Bootstrap::Impl::getNranks() { return nRanks_; }

void Bootstrap::Impl::initialize(const UniqueId uniqueId) {
netInit("");
netInit("", "");

std::memcpy(&uniqueId_, &uniqueId, sizeof(uniqueId_));

establishConnections();
}

void Bootstrap::Impl::initialize(std::string ipPortPair) {
netInit(ipPortPair);
void Bootstrap::Impl::initialize(std::string ifIpPortTrio) {
// first check whether the input is an "interface:ip:port" trio, i.e. it contains exactly two colons
int nColons = 0;
for (auto c : ifIpPortTrio) {
if (c == ':') {
nColons++;
}
}
std::string ipPortPair = ifIpPortTrio;
std::string interface = "";
if (nColons == 2) {
// we know the <interface>
interface = ifIpPortTrio.substr(0, ipPortPair.find_first_of(':'));
ipPortPair = ifIpPortTrio.substr(ipPortPair.find_first_of(':') + 1);
}

netInit(ipPortPair, interface);

uniqueId_.magic = 0xdeadbeef;
std::memcpy(&uniqueId_.addr, &netIfAddr_, sizeof(mscclppSocketAddress));
mscclppSocketGetAddrFromString(&uniqueId_.addr, ipPortPair.c_str());
std::memcpy(&uniqueId_.addr, &netIfAddr_, sizeof(SocketAddress));
SocketGetAddrFromString(&uniqueId_.addr, ipPortPair.c_str());

if (rank_ == 0) {
bootstrapCreateRoot();
Expand All @@ -164,14 +179,14 @@ Bootstrap::Impl::~Impl() {
}
}

void Bootstrap::Impl::getRemoteAddresses(Socket* listenSock, std::vector<mscclppSocketAddress>& rankAddresses,
std::vector<mscclppSocketAddress>& rankAddressesRoot, int& rank) {
void Bootstrap::Impl::getRemoteAddresses(Socket* listenSock, std::vector<SocketAddress>& rankAddresses,
std::vector<SocketAddress>& rankAddressesRoot, int& rank) {
ExtInfo info;
mscclppSocketAddress zero;
std::memset(&zero, 0, sizeof(mscclppSocketAddress));
SocketAddress zero;
std::memset(&zero, 0, sizeof(SocketAddress));

{
Socket sock(nullptr, MSCCLPP_SOCKET_MAGIC, mscclppSocketTypeUnknown, abortFlag_);
Socket sock(nullptr, MSCCLPP_SOCKET_MAGIC, SocketTypeUnknown, abortFlag_);
sock.accept(listenSock);
netRecv(&sock, &info, sizeof(info));
}
Expand All @@ -182,7 +197,7 @@ void Bootstrap::Impl::getRemoteAddresses(Socket* listenSock, std::vector<mscclpp
ErrorCode::InternalError);
}

if (std::memcmp(&zero, &rankAddressesRoot[info.rank], sizeof(mscclppSocketAddress)) != 0) {
if (std::memcmp(&zero, &rankAddressesRoot[info.rank], sizeof(SocketAddress)) != 0) {
throw Error("Bootstrap Root : rank " + std::to_string(info.rank) + " of " + std::to_string(this->nRanks_) +
" has already checked in",
ErrorCode::InternalError);
Expand All @@ -194,17 +209,16 @@ void Bootstrap::Impl::getRemoteAddresses(Socket* listenSock, std::vector<mscclpp
rank = info.rank;
}

void Bootstrap::Impl::sendHandleToPeer(int peer, const std::vector<mscclppSocketAddress>& rankAddresses,
const std::vector<mscclppSocketAddress>& rankAddressesRoot) {
void Bootstrap::Impl::sendHandleToPeer(int peer, const std::vector<SocketAddress>& rankAddresses,
const std::vector<SocketAddress>& rankAddressesRoot) {
int next = (peer + 1) % nRanks_;
Socket sock(&rankAddressesRoot[peer], uniqueId_.magic, mscclppSocketTypeBootstrap, abortFlag_);
Socket sock(&rankAddressesRoot[peer], uniqueId_.magic, SocketTypeBootstrap, abortFlag_);
sock.connect();
netSend(&sock, &rankAddresses[next], sizeof(mscclppSocketAddress));
netSend(&sock, &rankAddresses[next], sizeof(SocketAddress));
}

void Bootstrap::Impl::bootstrapCreateRoot() {
listenSockRoot_ =
std::make_unique<Socket>(&uniqueId_.addr, uniqueId_.magic, mscclppSocketTypeBootstrap, abortFlag_, 0);
listenSockRoot_ = std::make_unique<Socket>(&uniqueId_.addr, uniqueId_.magic, SocketTypeBootstrap, abortFlag_, 0);
listenSockRoot_->listen();
uniqueId_.addr = listenSockRoot_->getAddr();

Expand All @@ -220,12 +234,12 @@ void Bootstrap::Impl::bootstrapCreateRoot() {

void Bootstrap::Impl::bootstrapRoot() {
int numCollected = 0;
std::vector<mscclppSocketAddress> rankAddresses(nRanks_, mscclppSocketAddress());
std::vector<SocketAddress> rankAddresses(nRanks_, SocketAddress());
// for initial rank <-> root information exchange
std::vector<mscclppSocketAddress> rankAddressesRoot(nRanks_, mscclppSocketAddress());
std::vector<SocketAddress> rankAddressesRoot(nRanks_, SocketAddress());

std::memset(rankAddresses.data(), 0, sizeof(mscclppSocketAddress) * nRanks_);
std::memset(rankAddressesRoot.data(), 0, sizeof(mscclppSocketAddress) * nRanks_);
std::memset(rankAddresses.data(), 0, sizeof(SocketAddress) * nRanks_);
std::memset(rankAddressesRoot.data(), 0, sizeof(SocketAddress) * nRanks_);
setFilesLimit();

TRACE(MSCCLPP_INIT, "BEGIN");
Expand All @@ -252,24 +266,32 @@ void Bootstrap::Impl::bootstrapRoot() {
TRACE(MSCCLPP_INIT, "DONE");
}

void Bootstrap::Impl::netInit(std::string ipPortPair) {
void Bootstrap::Impl::netInit(std::string ipPortPair, std::string interface) {
if (netInitialized) return;
if (!ipPortPair.empty()) {
mscclppSocketAddress remoteAddr;
mscclppSocketGetAddrFromString(&remoteAddr, ipPortPair.c_str());
if (mscclppFindInterfaceMatchSubnet(netIfName_, &netIfAddr_, &remoteAddr, MAX_IF_NAME_SIZE, 1) <= 0) {
throw Error("NET/Socket : No usable listening interface found", ErrorCode::InternalError);
if (interface != "") {
// we know the <interface>
int ret = FindInterfaces(netIfName_, &netIfAddr_, MAX_IF_NAME_SIZE, 1, interface.c_str());
if (ret <= 0) throw Error("NET/Socket : No interface named " + interface + " found.", ErrorCode::InternalError);
} else {
// we do not know the <interface>, so try to find one that matches the given ip:port
SocketAddress remoteAddr;
SocketGetAddrFromString(&remoteAddr, ipPortPair.c_str());
if (FindInterfaceMatchSubnet(netIfName_, &netIfAddr_, &remoteAddr, MAX_IF_NAME_SIZE, 1) <= 0) {
throw Error("NET/Socket : No usable listening interface found", ErrorCode::InternalError);
}
}

} else {
int ret = mscclppFindInterfaces(netIfName_, &netIfAddr_, MAX_IF_NAME_SIZE, 1);
int ret = FindInterfaces(netIfName_, &netIfAddr_, MAX_IF_NAME_SIZE, 1);
if (ret <= 0) {
throw Error("Bootstrap : no socket interface found", ErrorCode::InternalError);
}
}

char line[SOCKET_NAME_MAXLEN + MAX_IF_NAME_SIZE + 2];
std::sprintf(line, " %s:", netIfName_);
mscclppSocketToString(&netIfAddr_, line + strlen(line));
SocketToString(&netIfAddr_, line + strlen(line));
INFO(MSCCLPP_INIT, "Bootstrap : Using%s", line);
netInitialized = true;
}
Expand All @@ -289,7 +311,7 @@ void Bootstrap::Impl::netInit(std::string ipPortPair) {
void Bootstrap::Impl::establishConnections() {
const int64_t connectionTimeoutUs = (int64_t)Config::getInstance()->getBootstrapConnectionTimeoutConfig() * 1000000;
Timer timer;
mscclppSocketAddress nextAddr;
SocketAddress nextAddr;
ExtInfo info;

TRACE(MSCCLPP_INIT, "rank %d nranks %d", rank_, nRanks_);
Expand All @@ -305,13 +327,13 @@ void Bootstrap::Impl::establishConnections() {

uint64_t magic = uniqueId_.magic;
// Create socket for other ranks to contact me
listenSock_ = std::make_unique<Socket>(&netIfAddr_, magic, mscclppSocketTypeBootstrap, abortFlag_);
listenSock_ = std::make_unique<Socket>(&netIfAddr_, magic, SocketTypeBootstrap, abortFlag_);
listenSock_->listen();
info.extAddressListen = listenSock_->getAddr();

{
// Create socket for root to contact me
Socket lsock(&netIfAddr_, magic, mscclppSocketTypeBootstrap, abortFlag_);
Socket lsock(&netIfAddr_, magic, SocketTypeBootstrap, abortFlag_);
lsock.listen();
info.extAddressListenRoot = lsock.getAddr();

Expand All @@ -329,28 +351,28 @@ void Bootstrap::Impl::establishConnections() {

// send info on my listening socket to root
{
Socket sock(&uniqueId_.addr, magic, mscclppSocketTypeBootstrap, abortFlag_);
Socket sock(&uniqueId_.addr, magic, SocketTypeBootstrap, abortFlag_);
TIMEOUT(sock.connect(getLeftTime()));
netSend(&sock, &info, sizeof(info));
}

// get info on my "next" rank in the bootstrap ring from root
{
Socket sock(nullptr, MSCCLPP_SOCKET_MAGIC, mscclppSocketTypeUnknown, abortFlag_);
Socket sock(nullptr, MSCCLPP_SOCKET_MAGIC, SocketTypeUnknown, abortFlag_);
TIMEOUT(sock.accept(&lsock, getLeftTime()));
netRecv(&sock, &nextAddr, sizeof(mscclppSocketAddress));
netRecv(&sock, &nextAddr, sizeof(SocketAddress));
}
}

ringSendSocket_ = std::make_unique<Socket>(&nextAddr, magic, mscclppSocketTypeBootstrap, abortFlag_);
ringSendSocket_ = std::make_unique<Socket>(&nextAddr, magic, SocketTypeBootstrap, abortFlag_);
TIMEOUT(ringSendSocket_->connect(getLeftTime()));
// Accept the connect request from the previous rank in the AllGather ring
ringRecvSocket_ = std::make_unique<Socket>(nullptr, MSCCLPP_SOCKET_MAGIC, mscclppSocketTypeUnknown, abortFlag_);
ringRecvSocket_ = std::make_unique<Socket>(nullptr, MSCCLPP_SOCKET_MAGIC, SocketTypeUnknown, abortFlag_);
TIMEOUT(ringRecvSocket_->accept(listenSock_.get(), getLeftTime()));

// AllGather all listen handlers
peerCommAddresses_[rank_] = listenSock_->getAddr();
allGather(peerCommAddresses_.data(), sizeof(mscclppSocketAddress));
allGather(peerCommAddresses_.data(), sizeof(SocketAddress));

TRACE(MSCCLPP_INIT, "rank %d nranks %d - DONE", rank_, nRanks_);
}
Expand Down Expand Up @@ -384,8 +406,7 @@ std::shared_ptr<Socket> Bootstrap::Impl::getPeerSendSocket(int peer, int tag) {
if (it != peerSendSockets_.end()) {
return it->second;
}
auto sock =
std::make_shared<Socket>(&peerCommAddresses_[peer], uniqueId_.magic, mscclppSocketTypeBootstrap, abortFlag_);
auto sock = std::make_shared<Socket>(&peerCommAddresses_[peer], uniqueId_.magic, SocketTypeBootstrap, abortFlag_);
sock->connect();
netSend(sock.get(), &rank_, sizeof(int));
netSend(sock.get(), &tag, sizeof(int));
Expand All @@ -399,7 +420,7 @@ std::shared_ptr<Socket> Bootstrap::Impl::getPeerRecvSocket(int peer, int tag) {
return it->second;
}
for (;;) {
auto sock = std::make_shared<Socket>(nullptr, MSCCLPP_SOCKET_MAGIC, mscclppSocketTypeUnknown, abortFlag_);
auto sock = std::make_shared<Socket>(nullptr, MSCCLPP_SOCKET_MAGIC, SocketTypeUnknown, abortFlag_);
sock->accept(listenSock_.get());
int recvPeer, recvTag;
netRecv(sock.get(), &recvPeer, sizeof(int));
Expand Down Expand Up @@ -471,3 +492,5 @@ MSCCLPP_API_CPP void Bootstrap::initialize(std::string ipPortPair) { pimpl_->ini
MSCCLPP_API_CPP void Bootstrap::barrier() { pimpl_->barrier(); }

MSCCLPP_API_CPP Bootstrap::~Bootstrap() { pimpl_->close(); }

} // namespace mscclpp
Loading