From 0eb387fdd073f1e63a12e891771fc7117829603f Mon Sep 17 00:00:00 2001 From: Eric Raut Date: Fri, 23 Feb 2024 00:09:08 +0000 Subject: [PATCH 1/6] rdma: defer connect completion after sending connect message In the current implementation of connect/accept, it is possible for `accept` to complete (i.e., return a non-NULL communicator) after the corresponding `connect` returned a NULL communicator (while waiting for a completion for the connection message). This is a strange semantic, and evidently causes NCCL to be unhappy, particularly in the multi-recv case (which is being added in a future commit). So, after sending the connect message, defer waiting for completion; block when closing the send comm if necessary. Signed-off-by: Eric Raut --- include/nccl_ofi_rdma.h | 6 ++++ src/nccl_ofi_rdma.c | 80 ++++++++++++++++++++--------------------- 2 files changed, 45 insertions(+), 41 deletions(-) diff --git a/include/nccl_ofi_rdma.h b/include/nccl_ofi_rdma.h index 9ec602f73..335ea8004 100644 --- a/include/nccl_ofi_rdma.h +++ b/include/nccl_ofi_rdma.h @@ -345,6 +345,12 @@ typedef struct nccl_net_ofi_rdma_send_comm { /* Comm ID provided by remote endpoint */ uint64_t remote_comm_id; + /* Request to send connect message */ + nccl_net_ofi_rdma_req_t *send_conn_req; + + /* Indicates if connect message was delivered (and req freed) */ + bool connect_msg_delivered; + /* Request to receive connect response message to finalize * connection establishment */ nccl_net_ofi_rdma_req_t *conn_resp_req; diff --git a/src/nccl_ofi_rdma.c b/src/nccl_ofi_rdma.c index 28cfd90e6..9afa5dd39 100644 --- a/src/nccl_ofi_rdma.c +++ b/src/nccl_ofi_rdma.c @@ -1407,7 +1407,18 @@ static inline int process_completions(struct fi_cq_tagged_entry *cq_entry, return ncclInternalError; } - if (IS_CONN_RESP_MSG_TYPE(cq_entry[comp_idx].tag) && (comp_flags & FI_RECV)) { + if (req->type == NCCL_OFI_RDMA_SEND_CONN) { + assert(req->comm->type == NCCL_NET_OFI_SEND_COMM); + nccl_net_ofi_rdma_send_comm_t *s_comm = + (nccl_net_ofi_rdma_send_comm_t *)req->comm; + assert(req == s_comm->send_conn_req); + /* Release connect message request */ + req->free(req, false); + req = NULL; + s_comm->send_conn_req = NULL; + __sync_synchronize(); + s_comm->connect_msg_delivered = true; + } else if (IS_CONN_RESP_MSG_TYPE(cq_entry[comp_idx].tag) && (comp_flags & FI_RECV)) { assert(req->comm->type == NCCL_NET_OFI_SEND_COMM); /* Complete send communicator */ nccl_net_ofi_rdma_send_comm_t *s_comm = @@ -4804,8 +4815,9 @@ static int blocked_send_close(nccl_net_ofi_send_comm_t *send_comm) return ncclInternalError; } - // TODO: We might want to use READ_ONCE to read variable `connected' - while (!s_comm->connected) { + // TODO: We might want to use READ_ONCE to read variables + // `connect_msg_delivered` and `connected' + while (!s_comm->connect_msg_delivered || !s_comm->connected) { __compiler_barrier(); int ret = 0; /* Progress our engine to get completions. If the @@ -5212,14 +5224,12 @@ static int connect(nccl_net_ofi_ep_t *base_ep, nccl_net_ofi_send_comm_t **send_comm) { int ret = 0; - nccl_net_ofi_rdma_req_state_t conn_msg_state; *send_comm = NULL; nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)base_ep; /* Extract connection state of the communicator */ save_comm_state_t *comm_state = &(handle->state); - nccl_net_ofi_rdma_req_t *req = (nccl_net_ofi_rdma_req_t *)comm_state->req; nccl_net_ofi_rdma_send_comm_t *s_comm = (nccl_net_ofi_rdma_send_comm_t *)comm_state->comm; @@ -5259,23 +5269,22 @@ static int connect(nccl_net_ofi_ep_t *base_ep, comm_state->comm = &s_comm->base.base; /* Prepare connect request to be sent to peer */ - req = prepare_send_conn_req(s_comm); - if (OFI_UNLIKELY(req == NULL)) { + s_comm->send_conn_req = prepare_send_conn_req(s_comm); + if (OFI_UNLIKELY(s_comm->send_conn_req == NULL)) { send_close(s_comm); return ncclSystemError; } - comm_state->req = &req->base; comm_state->stage = COMM_SEND_CONN; case COMM_SEND_CONN: /* COMM_SEND_CONN: Post a connect message to send peer connections */ - ret = post_send_conn(s_comm, device, ep, req); + ret = post_send_conn(s_comm, device, ep, s_comm->send_conn_req); if (ret == -FI_EAGAIN) { return 0; } else if (ret != 0) { - req->free(req, false); + s_comm->send_conn_req->free(s_comm->send_conn_req, false); send_close(s_comm); return ret; } @@ -5296,29 +5305,6 @@ static int connect(nccl_net_ofi_ep_t *base_ep, return ret; } - /* Check if the connect message is sent */ - ret = pthread_mutex_lock(&req->req_lock); - if (OFI_UNLIKELY(ret)) { - NCCL_OFI_WARN("Unable to acquire req_lock mutex"); - return ncclInternalError; - } - conn_msg_state = req->state; - ret = pthread_mutex_unlock(&req->req_lock); - if (OFI_UNLIKELY(ret)) { - NCCL_OFI_WARN("Failed to unlock req_lock mutex"); - return ncclInternalError; - } - - /* Wait until connect message is sent */ - if (conn_msg_state != NCCL_OFI_RDMA_REQ_COMPLETED) { - return 0; - } - - /* Release connect message request */ - req->free(req, false); - comm_state->req = NULL; - req = NULL; - /* Prepare request to receive connect response message */ s_comm->conn_resp_req = prepare_recv_conn_resp_req(s_comm); if (OFI_UNLIKELY(s_comm->conn_resp_req == NULL)) { @@ -5328,15 +5314,27 @@ static int connect(nccl_net_ofi_ep_t *base_ep, comm_state->stage = COMM_RECV_CONN; - case COMM_RECV_CONN: + case COMM_RECV_CONN: { /* COMM_RECV_CONN: Receive connect response message from remote */ - ret = post_recv_conn_resp(s_comm, device, ep); - if (ret == -FI_EAGAIN) { - return 0; - } else if (ret != 0) { - send_close(s_comm); - return ret; + bool recv_conn_resp_posted = false; + while (!recv_conn_resp_posted) { + ret = post_recv_conn_resp(s_comm, device, ep); + if (ret == -FI_EAGAIN) { + /* Block until we post the connection response request. + EAGAIN only involves waiting for local resources to free up, so it + should be safe to block. */ + ret = ofi_process_cq(ep); + if (OFI_UNLIKELY(ret != 0)) { + send_close(s_comm); + return ret; + } + } else if (ret != 0) { + send_close(s_comm); + return ret; + } else { + recv_conn_resp_posted = true; + } } /* Progress our engine to get completions. If the @@ -5350,7 +5348,7 @@ static int connect(nccl_net_ofi_ep_t *base_ep, comm_state->stage = COMM_CONN_RESP_REQ_PENDING; break; - + } case COMM_CONN_RESP_REQ_PENDING: case COMM_CONNECTED: default: From d20ba9e8c58c8e18e6df37722243bf348d535e0f Mon Sep 17 00:00:00 2001 From: Eric Raut Date: Thu, 22 Feb 2024 02:16:35 +0000 Subject: [PATCH 2/6] tests: set `nrecv=1` in functional tests These tests do not support multi-recv. Signed-off-by: Eric Raut --- tests/functional/nccl_message_transfer.c | 2 +- tests/functional/ring.c | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/functional/nccl_message_transfer.c b/tests/functional/nccl_message_transfer.c index 9daa3323f..10537ed75 100644 --- a/tests/functional/nccl_message_transfer.c +++ b/tests/functional/nccl_message_transfer.c @@ -38,7 +38,7 @@ int main(int argc, char* argv[]) /* For grouped recvs */ int tag = 1; - int nrecv = NCCL_OFI_MAX_RECVS; + int nrecv = 1; int *sizes = (int *)malloc(sizeof(int)*nrecv); int *tags = (int *)malloc(sizeof(int)*nrecv); int recv_n; diff --git a/tests/functional/ring.c b/tests/functional/ring.c index 8fc6a98c2..ad141cbcb 100644 --- a/tests/functional/ring.c +++ b/tests/functional/ring.c @@ -39,7 +39,7 @@ int main(int argc, char *argv[]) /* For grouped receives */ int tag = 1; - int nrecv = NCCL_OFI_MAX_RECVS; + int nrecv = 1; int *sizes = (int *)malloc(sizeof(int)*nrecv); int *tags = (int *)malloc(sizeof(int)*nrecv); int recv_n; From 9207e74b36ad6a492c0f29692ea381387f91a7f2 Mon Sep 17 00:00:00 2001 From: Eric Raut Date: Sun, 25 Feb 2024 20:02:11 +0000 Subject: [PATCH 3/6] tests/ring: Post receives before sends The previous implementation of the ring unit test posts sends before receives, which is incompatible with the current multi-recv behavior of rejecting sends until ctrl information is avaiable. Signed-off-by: Eric Raut --- tests/functional/ring.c | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/tests/functional/ring.c b/tests/functional/ring.c index ad141cbcb..42f46a6fb 100644 --- a/tests/functional/ring.c +++ b/tests/functional/ring.c @@ -158,6 +158,20 @@ int main(int argc, char *argv[]) OFINCCLCHECK(extNet->accept((void *)lComm, (void **)&rComm, &r_ignore)); NCCL_OFI_INFO(NCCL_NET, "Successfully accepted connection from rank %d", prev); + /* Receive NUM_REQUESTS from prev rank */ + NCCL_OFI_INFO(NCCL_NET, "Rank %d posting %d receive buffers", rank, NUM_REQUESTS); + for (idx = 0; idx < NUM_REQUESTS; idx++) { + OFINCCLCHECK(allocate_buff((void **)&recv_buf[idx], RECV_SIZE, buffer_type)); + OFINCCLCHECK(extNet->regMr((void *)rComm, (void *)recv_buf[idx], RECV_SIZE, + buffer_type, &recv_mhandle[idx])); + NCCL_OFI_TRACE(NCCL_NET, "Successfully registered receive memory for request %d of rank %d", idx, rank); + + while (recv_req[idx] == NULL) { + OFINCCLCHECK(extNet->irecv((void *)rComm, nrecv, (void **)&recv_buf[idx], + sizes, tags, &recv_mhandle[idx], (void **)&recv_req[idx])); + } + } + /* Send NUM_REQUESTS to next rank */ NCCL_OFI_INFO(NCCL_NET, "Sending %d requests to rank %d", NUM_REQUESTS, next); for (idx = 0; idx < NUM_REQUESTS; idx++) { @@ -174,20 +188,6 @@ int main(int argc, char *argv[]) } } - /* Receive NUM_REQUESTS from prev rank */ - NCCL_OFI_INFO(NCCL_NET, "Rank %d posting %d receive buffers", rank, NUM_REQUESTS); - for (idx = 0; idx < NUM_REQUESTS; idx++) { - OFINCCLCHECK(allocate_buff((void **)&recv_buf[idx], RECV_SIZE, buffer_type)); - OFINCCLCHECK(extNet->regMr((void *)rComm, (void *)recv_buf[idx], RECV_SIZE, - buffer_type, &recv_mhandle[idx])); - NCCL_OFI_TRACE(NCCL_NET, "Successfully registered receive memory for request %d of rank %d", idx, rank); - - while (recv_req[idx] == NULL) { - OFINCCLCHECK(extNet->irecv((void *)rComm, nrecv, (void **)&recv_buf[idx], - sizes, tags, &recv_mhandle[idx], (void **)&recv_req[idx])); - } - } - /* Allocate and populate expected buffer */ char *expected_buf = NULL; OFINCCLCHECK(allocate_buff((void **)&expected_buf, SEND_SIZE, NCCL_PTR_HOST)); From 8c770ce45958c530b168712b9138be8514602ea9 Mon Sep 17 00:00:00 2001 From: Eric Raut Date: Fri, 23 Feb 2024 22:08:53 +0000 Subject: [PATCH 4/6] msgbuff: support tags in preparation for multi-recv msgbuff functions accept tag and multi-recv information. RDMA protocol code is updated to pass dummy values for these fields. Multi-recv support for RDMA protocol will be an upcoming commit. Signed-off-by: Eric Raut --- include/nccl_ofi_msgbuff.h | 25 +++- src/nccl_ofi_msgbuff.c | 235 +++++++++++++++++++++++++++++++++---- src/nccl_ofi_rdma.c | 39 +++--- tests/unit/msgbuff.c | 19 +-- 4 files changed, 262 insertions(+), 56 deletions(-) diff --git a/include/nccl_ofi_msgbuff.h b/include/nccl_ofi_msgbuff.h index d53073058..51d2cabe7 100644 --- a/include/nccl_ofi_msgbuff.h +++ b/include/nccl_ofi_msgbuff.h @@ -68,6 +68,10 @@ typedef struct { // Type of element nccl_ofi_msgbuff_elemtype_t type; void *elem; + // Multi-recv information + uint16_t multi_recv_size; + uint16_t multi_recv_start; + int multi_recv_tag; } nccl_ofi_msgbuff_elem_t; typedef struct { @@ -110,9 +114,14 @@ bool nccl_ofi_msgbuff_destroy(nccl_ofi_msgbuff_t *msgbuff); * NCCL_OFI_MSGBUFF_ERROR, other error */ nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_insert(nccl_ofi_msgbuff_t *msgbuff, - uint16_t msg_index, void *elem, nccl_ofi_msgbuff_elemtype_t type, + uint16_t msg_index, uint16_t multi_recv_start, uint16_t multi_recv_size, int multi_recv_tag, + void *elem, nccl_ofi_msgbuff_elemtype_t type, nccl_ofi_msgbuff_status_t *msg_idx_status); +nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_insert_ctrl_multirecv(nccl_ofi_msgbuff_t *msgbuff, + uint16_t msg_base_index, uint16_t multi_recv_size, int *tags, void *elem, + nccl_ofi_msgbuff_elemtype_t type, nccl_ofi_msgbuff_status_t *msg_idx_status); + /** * Replace an existing message element * @@ -126,8 +135,9 @@ nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_insert(nccl_ofi_msgbuff_t *msgbuff, * NCCL_OFI_MSGBUFF_ERROR, other error */ nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_replace(nccl_ofi_msgbuff_t *msgbuff, - uint16_t msg_index, void *elem, nccl_ofi_msgbuff_elemtype_t type, - nccl_ofi_msgbuff_status_t *msg_idx_status); + uint16_t msg_index, uint16_t multi_recv_start, uint16_t multi_recv_size, + int multi_recv_tag, void *elem, nccl_ofi_msgbuff_elemtype_t type, + nccl_ofi_msgbuff_status_t *msg_idx_status, bool *multi_send_ready); /** * Retrieve message with given index @@ -142,6 +152,12 @@ nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_replace(nccl_ofi_msgbuff_t *msgbuff, * NCCL_OFI_MSGBUFF_ERROR, other error */ nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_retrieve(nccl_ofi_msgbuff_t *msgbuff, + uint16_t msg_index, uint16_t multi_recv_start, uint16_t multi_recv_size, + int multi_recv_tag, void **elem, nccl_ofi_msgbuff_elemtype_t *type, + nccl_ofi_msgbuff_status_t *msg_idx_status); + +/* As above, but with no tag */ +nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_retrieve_notag(nccl_ofi_msgbuff_t *msgbuff, uint16_t msg_index, void **elem, nccl_ofi_msgbuff_elemtype_t *type, nccl_ofi_msgbuff_status_t *msg_idx_status); @@ -156,7 +172,8 @@ nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_retrieve(nccl_ofi_msgbuff_t *msgbuff, * NCCL_OFI_MSGBUFF_ERROR, other error */ nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_complete(nccl_ofi_msgbuff_t *msgbuff, - uint16_t msg_index, nccl_ofi_msgbuff_status_t *msg_idx_status); + uint16_t msg_index, uint16_t multi_recv_start, uint16_t multi_recv_size, + int multi_recv_tag, nccl_ofi_msgbuff_status_t *msg_idx_status); #ifdef _cplusplus } // End extern "C" diff --git a/src/nccl_ofi_msgbuff.c b/src/nccl_ofi_msgbuff.c index 2e6d6eea5..afecb7f4e 100644 --- a/src/nccl_ofi_msgbuff.c +++ b/src/nccl_ofi_msgbuff.c @@ -25,7 +25,7 @@ nccl_ofi_msgbuff_t *nccl_ofi_msgbuff_init(uint16_t buffer_size) goto error; } msgbuff->buff_size = buffer_size; - if (!(msgbuff->buff = malloc(sizeof(nccl_ofi_msgbuff_elem_t)*buffer_size))) { + if (!(msgbuff->buff = calloc((4*buffer_size), sizeof(nccl_ofi_msgbuff_elem_t)))) { NCCL_OFI_WARN("Memory allocation (msgbuff->buff) failed"); goto error; } @@ -77,7 +77,7 @@ static uint16_t nccl_ofi_msgbuff_num_inflight(const nccl_ofi_msgbuff_t *msgbuff) static inline nccl_ofi_msgbuff_elem_t *buff_idx(const nccl_ofi_msgbuff_t *msgbuff, uint16_t idx) { - return &msgbuff->buff[idx % msgbuff->buff_size]; + return &msgbuff->buff[idx % (4*msgbuff->buff_size)]; } /** @@ -115,19 +115,11 @@ static nccl_ofi_msgbuff_status_t nccl_ofi_msgbuff_get_idx_status return NCCL_OFI_MSGBUFF_UNAVAILABLE; } -nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_insert(nccl_ofi_msgbuff_t *msgbuff, +static inline nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_insert_at_idx(nccl_ofi_msgbuff_t *msgbuff, uint16_t msg_index, void *elem, nccl_ofi_msgbuff_elemtype_t type, + uint16_t multi_recv_size, uint16_t multi_recv_start, int multi_recv_tag, nccl_ofi_msgbuff_status_t *msg_idx_status) { - if (!msgbuff) { - NCCL_OFI_WARN("msgbuff is NULL"); - return NCCL_OFI_MSGBUFF_ERROR; - } - if (pthread_mutex_lock(&msgbuff->lock)) { - NCCL_OFI_WARN("Error locking mutex"); - return NCCL_OFI_MSGBUFF_ERROR; - } - *msg_idx_status = nccl_ofi_msgbuff_get_idx_status(msgbuff, msg_index); nccl_ofi_msgbuff_result_t ret = NCCL_OFI_MSGBUFF_ERROR; @@ -135,6 +127,10 @@ nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_insert(nccl_ofi_msgbuff_t *msgbuff, buff_idx(msgbuff, msg_index)->stat = NCCL_OFI_MSGBUFF_INPROGRESS; buff_idx(msgbuff, msg_index)->elem = elem; buff_idx(msgbuff, msg_index)->type = type; + buff_idx(msgbuff, msg_index)->multi_recv_size = multi_recv_size; + if (multi_recv_size > 1) + buff_idx(msgbuff, msg_index)->multi_recv_start = multi_recv_start; + buff_idx(msgbuff, msg_index)->multi_recv_tag = multi_recv_tag; /* Update msg_next ptr */ while ((uint16_t)(msg_index - msgbuff->msg_next) <= msgbuff->buff_size) { if (msgbuff->msg_next != msg_index) { @@ -148,16 +144,99 @@ nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_insert(nccl_ofi_msgbuff_t *msgbuff, ret = NCCL_OFI_MSGBUFF_INVALID_IDX; } + return ret; +} + +static inline bool nccl_ofi_msgbuff_multirecv_search(nccl_ofi_msgbuff_t *msgbuff, + uint16_t multi_recv_start, uint16_t multi_recv_size, int multi_recv_tag, + uint16_t *match_index) +{ + for (uint16_t idx = multi_recv_start; idx != (uint16_t)(multi_recv_start+multi_recv_size); ++idx) { + nccl_ofi_msgbuff_status_t msg_idx_status = nccl_ofi_msgbuff_get_idx_status(msgbuff, idx); + if (msg_idx_status == NCCL_OFI_MSGBUFF_INPROGRESS) { + int present_tag = buff_idx(msgbuff, idx)->multi_recv_tag; + if (present_tag == multi_recv_tag) { + *match_index = idx; + return true; + } + } + } + return false; +} + +nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_insert(nccl_ofi_msgbuff_t *msgbuff, + uint16_t msg_index, uint16_t multi_recv_start, uint16_t multi_recv_size, int multi_recv_tag, + void *elem, nccl_ofi_msgbuff_elemtype_t type, + nccl_ofi_msgbuff_status_t *msg_idx_status) +{ + nccl_ofi_msgbuff_result_t ret = NCCL_OFI_MSGBUFF_ERROR; + + if (pthread_mutex_lock(&msgbuff->lock)) { + NCCL_OFI_WARN("Error locking mutex"); + return NCCL_OFI_MSGBUFF_ERROR; + } + + ret = nccl_ofi_msgbuff_insert_at_idx(msgbuff, msg_index, elem, type, + multi_recv_size, multi_recv_start, multi_recv_tag, msg_idx_status); + + if (pthread_mutex_unlock(&msgbuff->lock)) { + NCCL_OFI_WARN("Error unlocking mutex"); + return NCCL_OFI_MSGBUFF_ERROR; + } + return ret; +} + +nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_insert_ctrl_multirecv(nccl_ofi_msgbuff_t *msgbuff, + uint16_t msg_base_index, uint16_t multi_recv_size, int *tags, void *elem, + nccl_ofi_msgbuff_elemtype_t type, nccl_ofi_msgbuff_status_t *msg_idx_status) +{ + assert(type == NCCL_OFI_MSGBUFF_BUFF); + + nccl_ofi_msgbuff_result_t ret = NCCL_OFI_MSGBUFF_ERROR; + + if (pthread_mutex_lock(&msgbuff->lock)) { + NCCL_OFI_WARN("Error locking mutex"); + return NCCL_OFI_MSGBUFF_ERROR; + } + + for (uint16_t i = 0; i < multi_recv_size; ++i) { + uint16_t msg_index = msg_base_index + i; + ret = nccl_ofi_msgbuff_insert_at_idx(msgbuff, msg_index, elem, type, + multi_recv_size, msg_base_index, tags[i], + msg_idx_status); + if (ret != NCCL_OFI_MSGBUFF_SUCCESS) { + goto unlock; + } + } + +unlock: if (pthread_mutex_unlock(&msgbuff->lock)) { NCCL_OFI_WARN("Error unlocking mutex"); - ret = NCCL_OFI_MSGBUFF_ERROR; + return NCCL_OFI_MSGBUFF_ERROR; } return ret; } +static bool test_ms_ready(nccl_ofi_msgbuff_t *msgbuff, uint16_t multi_recv_start, + uint16_t multi_recv_size) +{ + for (uint16_t i = multi_recv_start; i != (uint16_t)(multi_recv_start + multi_recv_size); + ++i) { + nccl_ofi_msgbuff_status_t msg_idx_status = nccl_ofi_msgbuff_get_idx_status(msgbuff, i); + if (msg_idx_status != NCCL_OFI_MSGBUFF_INPROGRESS) { + return false; + } + if (buff_idx(msgbuff, i)->type != NCCL_OFI_MSGBUFF_REQ) { + return false; + } + } + return true; +} + nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_replace(nccl_ofi_msgbuff_t *msgbuff, - uint16_t msg_index, void *elem, nccl_ofi_msgbuff_elemtype_t type, - nccl_ofi_msgbuff_status_t *msg_idx_status) + uint16_t msg_index, uint16_t multi_recv_start, uint16_t multi_recv_size, + int multi_recv_tag, void *elem, nccl_ofi_msgbuff_elemtype_t type, + nccl_ofi_msgbuff_status_t *msg_idx_status, bool *multi_send_ready) { if (!msgbuff) { NCCL_OFI_WARN("msgbuff is NULL"); @@ -167,18 +246,32 @@ nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_replace(nccl_ofi_msgbuff_t *msgbuff, NCCL_OFI_WARN("Error locking mutex"); return NCCL_OFI_MSGBUFF_ERROR; } + if (multi_send_ready) *multi_send_ready = false; - *msg_idx_status = nccl_ofi_msgbuff_get_idx_status(msgbuff, msg_index); nccl_ofi_msgbuff_result_t ret = NCCL_OFI_MSGBUFF_ERROR; + bool match_found = nccl_ofi_msgbuff_multirecv_search(msgbuff, multi_recv_start, + multi_recv_size, multi_recv_tag, &msg_index); + if (!match_found) { + *msg_idx_status = NCCL_OFI_MSGBUFF_NOTSTARTED; + ret = NCCL_OFI_MSGBUFF_INVALID_IDX; + goto unlock; + } + + *msg_idx_status = nccl_ofi_msgbuff_get_idx_status(msgbuff, msg_index); + if (*msg_idx_status == NCCL_OFI_MSGBUFF_INPROGRESS) { buff_idx(msgbuff, msg_index)->elem = elem; buff_idx(msgbuff, msg_index)->type = type; + if (multi_send_ready) + *multi_send_ready = test_ms_ready(msgbuff, multi_recv_start, + multi_recv_size); ret = NCCL_OFI_MSGBUFF_SUCCESS; } else { ret = NCCL_OFI_MSGBUFF_INVALID_IDX; } +unlock: if (pthread_mutex_unlock(&msgbuff->lock)) { NCCL_OFI_WARN("Error unlocking mutex"); ret = NCCL_OFI_MSGBUFF_ERROR; @@ -186,7 +279,7 @@ nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_replace(nccl_ofi_msgbuff_t *msgbuff, return ret; } -nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_retrieve(nccl_ofi_msgbuff_t *msgbuff, +nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_retrieve_notag(nccl_ofi_msgbuff_t *msgbuff, uint16_t msg_index, void **elem, nccl_ofi_msgbuff_elemtype_t *type, nccl_ofi_msgbuff_status_t *msg_idx_status) { @@ -199,16 +292,17 @@ nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_retrieve(nccl_ofi_msgbuff_t *msgbuff, return NCCL_OFI_MSGBUFF_ERROR; } if (pthread_mutex_lock(&msgbuff->lock)) { - NCCL_OFI_WARN("Error locking mutex"); - return NCCL_OFI_MSGBUFF_ERROR; - } + NCCL_OFI_WARN("Error locking mutex"); + return NCCL_OFI_MSGBUFF_ERROR; + } - *msg_idx_status = nccl_ofi_msgbuff_get_idx_status(msgbuff, msg_index); nccl_ofi_msgbuff_result_t ret = NCCL_OFI_MSGBUFF_ERROR; + *msg_idx_status = nccl_ofi_msgbuff_get_idx_status(msgbuff, msg_index); if (*msg_idx_status == NCCL_OFI_MSGBUFF_INPROGRESS) { *elem = buff_idx(msgbuff, msg_index)->elem; *type = buff_idx(msgbuff, msg_index)->type; + assert(*type == NCCL_OFI_MSGBUFF_REQ); ret = NCCL_OFI_MSGBUFF_SUCCESS; } else { if (*msg_idx_status == NCCL_OFI_MSGBUFF_UNAVAILABLE) { @@ -225,21 +319,102 @@ nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_retrieve(nccl_ofi_msgbuff_t *msgbuff, return ret; } +nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_retrieve(nccl_ofi_msgbuff_t *msgbuff, + uint16_t msg_index, uint16_t multi_recv_start, uint16_t multi_recv_size, + int multi_recv_tag, void **elem, nccl_ofi_msgbuff_elemtype_t *type, + nccl_ofi_msgbuff_status_t *msg_idx_status) +{ + if (!msgbuff) { + NCCL_OFI_WARN("msgbuff is NULL"); + return NCCL_OFI_MSGBUFF_ERROR; + } + if (!elem) { + NCCL_OFI_WARN("elem is NULL"); + return NCCL_OFI_MSGBUFF_ERROR; + } + if (pthread_mutex_lock(&msgbuff->lock)) { + NCCL_OFI_WARN("Error locking mutex"); + return NCCL_OFI_MSGBUFF_ERROR; + } + + nccl_ofi_msgbuff_result_t ret = NCCL_OFI_MSGBUFF_ERROR; + + if (multi_recv_size <= 1) { + *msg_idx_status = nccl_ofi_msgbuff_get_idx_status(msgbuff, msg_index); + if (*msg_idx_status != NCCL_OFI_MSGBUFF_UNAVAILABLE) { + /* Check if this actually should be a multi-recv */ + if (buff_idx(msgbuff, msg_index)->multi_recv_size > 1) { + assert(multi_recv_size == 0); + multi_recv_start = buff_idx(msgbuff, msg_index)->multi_recv_start; + multi_recv_size = buff_idx(msgbuff, msg_index)->multi_recv_size; + } + } + } + + if (multi_recv_size <= 1) { + /* Ok so this actually isn't a multirecv (that we know of) */ + *msg_idx_status = nccl_ofi_msgbuff_get_idx_status(msgbuff, msg_index); + if (*msg_idx_status == NCCL_OFI_MSGBUFF_INPROGRESS) { + *elem = buff_idx(msgbuff, msg_index)->elem; + *type = buff_idx(msgbuff, msg_index)->type; + ret = NCCL_OFI_MSGBUFF_SUCCESS; + } else { + if (*msg_idx_status == NCCL_OFI_MSGBUFF_UNAVAILABLE) { + // UNAVAILABLE really only applies to insert, so return NOTSTARTED here + *msg_idx_status = NCCL_OFI_MSGBUFF_NOTSTARTED; + } + ret = NCCL_OFI_MSGBUFF_INVALID_IDX; + } + } else { + /* Multi-recv -- search the index space */ + bool match_found = nccl_ofi_msgbuff_multirecv_search(msgbuff, multi_recv_start, + multi_recv_size, multi_recv_tag, &msg_index); + if (!match_found) { + *msg_idx_status = NCCL_OFI_MSGBUFF_NOTSTARTED; + ret = NCCL_OFI_MSGBUFF_INVALID_IDX; + } else { + *msg_idx_status = NCCL_OFI_MSGBUFF_INPROGRESS; + *elem = buff_idx(msgbuff, msg_index)->elem; + *type = buff_idx(msgbuff, msg_index)->type; + + ret = NCCL_OFI_MSGBUFF_SUCCESS; + } + } + + if (pthread_mutex_unlock(&msgbuff->lock)) { + NCCL_OFI_WARN("Error unlocking mutex"); + ret = NCCL_OFI_MSGBUFF_ERROR; + } + return ret; +} + nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_complete(nccl_ofi_msgbuff_t *msgbuff, - uint16_t msg_index, nccl_ofi_msgbuff_status_t *msg_idx_status) + uint16_t msg_index, uint16_t multi_recv_start, uint16_t multi_recv_size, + int multi_recv_tag, nccl_ofi_msgbuff_status_t *msg_idx_status) { if (!msgbuff) { NCCL_OFI_WARN("msgbuff is null"); return NCCL_OFI_MSGBUFF_ERROR; } if (pthread_mutex_lock(&msgbuff->lock)) { - NCCL_OFI_WARN("Error locking mutex"); - return NCCL_OFI_MSGBUFF_ERROR; - } + NCCL_OFI_WARN("Error locking mutex"); + return NCCL_OFI_MSGBUFF_ERROR; + } - *msg_idx_status = nccl_ofi_msgbuff_get_idx_status(msgbuff, msg_index); nccl_ofi_msgbuff_result_t ret = NCCL_OFI_MSGBUFF_ERROR; + if (multi_recv_size > 1) { + bool match_found = nccl_ofi_msgbuff_multirecv_search(msgbuff, multi_recv_start, + multi_recv_size, multi_recv_tag, &msg_index); + if (!match_found) { + *msg_idx_status = NCCL_OFI_MSGBUFF_NOTSTARTED; + ret = NCCL_OFI_MSGBUFF_INVALID_IDX; + goto unlock; + } + } + + *msg_idx_status = nccl_ofi_msgbuff_get_idx_status(msgbuff, msg_index); + if (*msg_idx_status == NCCL_OFI_MSGBUFF_INPROGRESS) { buff_idx(msgbuff, msg_index)->stat = NCCL_OFI_MSGBUFF_COMPLETED; buff_idx(msgbuff, msg_index)->elem = NULL; @@ -247,6 +422,12 @@ nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_complete(nccl_ofi_msgbuff_t *msgbuff, while (msgbuff->msg_last_incomplete != msgbuff->msg_next && buff_idx(msgbuff, msgbuff->msg_last_incomplete)->stat == NCCL_OFI_MSGBUFF_COMPLETED) { + /* Clear out relevant info of the now-unavailable message */ + uint16_t unavail_index = msgbuff->msg_last_incomplete - msgbuff->buff_size; + buff_idx(msgbuff, unavail_index)->elem = NULL; + buff_idx(msgbuff, unavail_index)->multi_recv_size = 0; + buff_idx(msgbuff, unavail_index)->multi_recv_start = 0; + buff_idx(msgbuff, unavail_index)->multi_recv_tag = 0; ++(msgbuff->msg_last_incomplete); } ret = NCCL_OFI_MSGBUFF_SUCCESS; @@ -257,6 +438,8 @@ nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_complete(nccl_ofi_msgbuff_t *msgbuff, } ret = NCCL_OFI_MSGBUFF_INVALID_IDX; } + +unlock: if (pthread_mutex_unlock(&msgbuff->lock)) { NCCL_OFI_WARN("Error unlocking mutex"); ret = NCCL_OFI_MSGBUFF_ERROR; diff --git a/src/nccl_ofi_rdma.c b/src/nccl_ofi_rdma.c index 9afa5dd39..051a3844a 100644 --- a/src/nccl_ofi_rdma.c +++ b/src/nccl_ofi_rdma.c @@ -1013,7 +1013,7 @@ static inline int handle_ctrl_recv(nccl_net_ofi_rdma_send_comm_t *s_comm, nccl_ofi_msgbuff_status_t stat; nccl_ofi_msgbuff_result_t mb_res = nccl_ofi_msgbuff_insert(s_comm->msgbuff, msg_seq_num, - bounce_req, NCCL_OFI_MSGBUFF_BUFF, &stat); + msg_seq_num, 1, 0, bounce_req, NCCL_OFI_MSGBUFF_BUFF, &stat); if (mb_res == NCCL_OFI_MSGBUFF_SUCCESS) { /* Inserted! In this case sender has not yet called send() for this message, so @@ -1029,7 +1029,8 @@ static inline int handle_ctrl_recv(nccl_net_ofi_rdma_send_comm_t *s_comm, // Already a req entry here void *elem; nccl_ofi_msgbuff_elemtype_t type; - mb_res = nccl_ofi_msgbuff_retrieve(s_comm->msgbuff, msg_seq_num, &elem, &type, &stat); + mb_res = nccl_ofi_msgbuff_retrieve(s_comm->msgbuff, msg_seq_num, msg_seq_num, 1, 0, + &elem, &type, &stat); if (mb_res != NCCL_OFI_MSGBUFF_SUCCESS || type != NCCL_OFI_MSGBUFF_REQ) { NCCL_OFI_WARN("Invalid message retrieval result for msg %hu", msg_seq_num); return -EINVAL; @@ -1137,7 +1138,7 @@ static inline int handle_eager_recv(nccl_net_ofi_rdma_recv_comm_t *r_comm, nccl_ofi_msgbuff_status_t stat; nccl_ofi_msgbuff_result_t mb_res = nccl_ofi_msgbuff_insert(r_comm->msgbuff, msg_seq_num, - bounce_req, NCCL_OFI_MSGBUFF_BUFF, &stat); + msg_seq_num, 1, 0, bounce_req, NCCL_OFI_MSGBUFF_BUFF, &stat); if (mb_res == NCCL_OFI_MSGBUFF_SUCCESS) { /* Inserted! In this case receiver has not yet called recv() for this message, so @@ -1157,7 +1158,8 @@ static inline int handle_eager_recv(nccl_net_ofi_rdma_recv_comm_t *r_comm, // In this case, there is already a req entry here. Initiate eager copy. void *elem; nccl_ofi_msgbuff_elemtype_t type; - mb_res = nccl_ofi_msgbuff_retrieve(r_comm->msgbuff, msg_seq_num, &elem, &type, &stat); + mb_res = nccl_ofi_msgbuff_retrieve(r_comm->msgbuff, msg_seq_num, msg_seq_num, + 1, 0, &elem, &type, &stat); if (mb_res != NCCL_OFI_MSGBUFF_SUCCESS || type != NCCL_OFI_MSGBUFF_REQ) { NCCL_OFI_WARN("Invalid message retrieval result for msg %hu", msg_seq_num); return -EINVAL; @@ -1258,7 +1260,7 @@ static inline nccl_net_ofi_rdma_req_t *get_req_from_imm_data nccl_ofi_msgbuff_status_t stat; nccl_ofi_msgbuff_result_t mb_res = nccl_ofi_msgbuff_retrieve(r_comm->msgbuff, - msg_seq_num, &elem, &type, &stat); + msg_seq_num, msg_seq_num, 1, 0, &elem, &type, &stat); if (mb_res != NCCL_OFI_MSGBUFF_SUCCESS) { /* Unexpected: we don't have a msgbuff entry corresponding to this message*/ NCCL_OFI_WARN("Unexpected status (%d) for message %hu", (int)stat, msg_seq_num); @@ -2307,7 +2309,8 @@ static int test(nccl_net_ofi_req_t *base_req, int *done, int *size) } nccl_ofi_msgbuff_status_t stat; - nccl_ofi_msgbuff_result_t mb_res = nccl_ofi_msgbuff_complete(msgbuff, req->msg_seq_num, &stat); + nccl_ofi_msgbuff_result_t mb_res = nccl_ofi_msgbuff_complete(msgbuff, req->msg_seq_num, + req->msg_seq_num, 1, 0, &stat); if (mb_res != NCCL_OFI_MSGBUFF_SUCCESS) { NCCL_OFI_WARN("Invalid result of msgbuff_complete for msg %hu", req->msg_seq_num); ret = ncclSystemError; @@ -2980,9 +2983,9 @@ static inline int insert_rdma_recv_req_into_msgbuff(nccl_net_ofi_rdma_recv_comm_ * replace it with a request. */ mb_res = nccl_ofi_msgbuff_replace(r_comm->msgbuff, - req->msg_seq_num, req, + req->msg_seq_num, req->msg_seq_num, 1, 0, req, NCCL_OFI_MSGBUFF_REQ, - &msg_stat); + &msg_stat, NULL); if (mb_res != NCCL_OFI_MSGBUFF_SUCCESS) { NCCL_OFI_WARN("Unexpected result of nccl_ofi_msgbuff_replace for msg %hu", req->msg_seq_num); @@ -2990,8 +2993,8 @@ static inline int insert_rdma_recv_req_into_msgbuff(nccl_net_ofi_rdma_recv_comm_ } } else { /* Try inserting the new request */ - mb_res = nccl_ofi_msgbuff_insert(r_comm->msgbuff, req->msg_seq_num, req, - NCCL_OFI_MSGBUFF_REQ, &msg_stat); + mb_res = nccl_ofi_msgbuff_insert(r_comm->msgbuff, req->msg_seq_num, req->msg_seq_num, + 1, 0, req, NCCL_OFI_MSGBUFF_REQ, &msg_stat); if (OFI_UNLIKELY((mb_res == NCCL_OFI_MSGBUFF_INVALID_IDX) && (msg_stat == NCCL_OFI_MSGBUFF_INPROGRESS))) { @@ -3076,8 +3079,8 @@ static int recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers, nccl_ofi_msgbuff_status_t msg_stat; nccl_ofi_msgbuff_result_t mb_res; - mb_res = nccl_ofi_msgbuff_retrieve(r_comm->msgbuff, msg_seq_num, &elem, - &type, &msg_stat); + mb_res = nccl_ofi_msgbuff_retrieve(r_comm->msgbuff, msg_seq_num, msg_seq_num, + 1, 0, &elem, &type, &msg_stat); if (mb_res == NCCL_OFI_MSGBUFF_SUCCESS) { if (type == NCCL_OFI_MSGBUFF_REQ) { @@ -4194,9 +4197,10 @@ static int insert_rdma_send_req_into_msgbuff(nccl_net_ofi_rdma_send_comm_t *s_co * so replace it with a request. */ mb_res = nccl_ofi_msgbuff_replace(s_comm->msgbuff, - req->msg_seq_num, req, + req->msg_seq_num, req->msg_seq_num, + 1, 0, req, NCCL_OFI_MSGBUFF_REQ, - &msg_stat); + &msg_stat, NULL); if (mb_res != NCCL_OFI_MSGBUFF_SUCCESS) { NCCL_OFI_WARN("Unexpected result of nccl_ofi_msgbuff_replace for msg %hu", req->msg_seq_num); @@ -4205,7 +4209,8 @@ static int insert_rdma_send_req_into_msgbuff(nccl_net_ofi_rdma_send_comm_t *s_co } else { /* Try inserting the new request */ mb_res = nccl_ofi_msgbuff_insert(s_comm->msgbuff, - req->msg_seq_num, req, + req->msg_seq_num, req->msg_seq_num, + 1, 0, req, NCCL_OFI_MSGBUFF_REQ, &msg_stat); if (OFI_UNLIKELY((mb_res == NCCL_OFI_MSGBUFF_INVALID_IDX) && @@ -4636,8 +4641,8 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t nccl_ofi_msgbuff_result_t mb_res; /* Retrive entry from message buffer for msg_seq_num index */ - mb_res = nccl_ofi_msgbuff_retrieve(s_comm->msgbuff, msg_seq_num, &elem, - &type, &msg_stat); + mb_res = nccl_ofi_msgbuff_retrieve(s_comm->msgbuff, msg_seq_num, msg_seq_num, + 1, 0, &elem, &type, &msg_stat); if (mb_res == NCCL_OFI_MSGBUFF_SUCCESS) { if (type == NCCL_OFI_MSGBUFF_BUFF) { /* diff --git a/tests/unit/msgbuff.c b/tests/unit/msgbuff.c index 074dcb217..63dc1fdd0 100644 --- a/tests/unit/msgbuff.c +++ b/tests/unit/msgbuff.c @@ -26,17 +26,17 @@ int main(int argc, char *argv[]) /** Test insert new **/ for (uint16_t i = 0; i < buff_sz; ++i) { - if (nccl_ofi_msgbuff_insert(msgbuff, i, &buff_store[i], type, &stat) != NCCL_OFI_MSGBUFF_SUCCESS) { + if (nccl_ofi_msgbuff_insert(msgbuff, i, i, 1, 0, &buff_store[i], type, &stat) != NCCL_OFI_MSGBUFF_SUCCESS) { NCCL_OFI_WARN("nccl_ofi_msgbuff_insert failed when non-full"); return 1; } } - if (nccl_ofi_msgbuff_insert(msgbuff, buff_sz, NULL, type, &stat) != NCCL_OFI_MSGBUFF_INVALID_IDX || + if (nccl_ofi_msgbuff_insert(msgbuff, buff_sz, buff_sz, 1, 0, NULL, type, &stat) != NCCL_OFI_MSGBUFF_INVALID_IDX || stat != NCCL_OFI_MSGBUFF_UNAVAILABLE) { NCCL_OFI_WARN("nccl_ofi_msgbuff_insert did not return unavailable when full"); return 1; } - if (nccl_ofi_msgbuff_insert(msgbuff, buff_sz-1, NULL, type, &stat) != NCCL_OFI_MSGBUFF_INVALID_IDX || + if (nccl_ofi_msgbuff_insert(msgbuff, buff_sz-1, buff_sz-1, 1, 0, NULL, type, &stat) != NCCL_OFI_MSGBUFF_INVALID_IDX || stat != NCCL_OFI_MSGBUFF_INPROGRESS) { NCCL_OFI_WARN("nccl_ofi_msgbuff_insert did not return inprogress on duplicate insert"); return 1; @@ -45,7 +45,7 @@ int main(int argc, char *argv[]) /** Test retrieve **/ uint16_t *result; for (uint16_t i = 0; i < buff_sz; ++i) { - if (nccl_ofi_msgbuff_retrieve(msgbuff, i, (void**)&result, &type, &stat) != NCCL_OFI_MSGBUFF_SUCCESS) { + if (nccl_ofi_msgbuff_retrieve(msgbuff, i, i, 1, 0, (void**)&result, &type, &stat) != NCCL_OFI_MSGBUFF_SUCCESS) { NCCL_OFI_WARN("nccl_ofi_msgbuff_retrieve failed on valid index"); return 1; } @@ -54,12 +54,13 @@ int main(int argc, char *argv[]) return 1; } } - if (nccl_ofi_msgbuff_retrieve(msgbuff, buff_sz, (void**)&result, &type, &stat) != NCCL_OFI_MSGBUFF_INVALID_IDX || + if (nccl_ofi_msgbuff_retrieve(msgbuff, buff_sz, buff_sz, 1, 0, (void**)&result, &type, &stat) != NCCL_OFI_MSGBUFF_INVALID_IDX || stat != NCCL_OFI_MSGBUFF_NOTSTARTED) { NCCL_OFI_WARN("nccl_ofi_msgbuff_retrieve did not return notstarted"); return 1; } - if (nccl_ofi_msgbuff_retrieve(msgbuff, UINT16_C(0) - UINT16_C(1), (void**)&result, &type, &stat) != NCCL_OFI_MSGBUFF_INVALID_IDX || + if (nccl_ofi_msgbuff_retrieve(msgbuff, UINT16_C(0) - UINT16_C(1), UINT16_C(0) - UINT16_C(1), 1, 0, + (void**)&result, &type, &stat) != NCCL_OFI_MSGBUFF_INVALID_IDX || stat != NCCL_OFI_MSGBUFF_COMPLETED) { NCCL_OFI_WARN("nccl_ofi_msgbuff_retrieve did not return completed"); return 1; @@ -67,17 +68,17 @@ int main(int argc, char *argv[]) /** Test complete **/ for (uint16_t i = 0; i < buff_sz; ++i) { - if (nccl_ofi_msgbuff_complete(msgbuff, i, &stat) != NCCL_OFI_MSGBUFF_SUCCESS) { + if (nccl_ofi_msgbuff_complete(msgbuff, i, i, 1, 0, &stat) != NCCL_OFI_MSGBUFF_SUCCESS) { NCCL_OFI_WARN("nccl_ofi_msgbuff_complete failed"); return 1; } } - if (nccl_ofi_msgbuff_complete(msgbuff, buff_sz, &stat) != NCCL_OFI_MSGBUFF_INVALID_IDX || + if (nccl_ofi_msgbuff_complete(msgbuff, buff_sz, buff_sz, 1, 0, &stat) != NCCL_OFI_MSGBUFF_INVALID_IDX || stat != NCCL_OFI_MSGBUFF_NOTSTARTED) { NCCL_OFI_WARN("nccl_ofi_msgbuff_complete did not return notstarted"); return 1; } - if (nccl_ofi_msgbuff_complete(msgbuff, 0, &stat) != NCCL_OFI_MSGBUFF_INVALID_IDX || + if (nccl_ofi_msgbuff_complete(msgbuff, 0, 0, 1, 0, &stat) != NCCL_OFI_MSGBUFF_INVALID_IDX || stat != NCCL_OFI_MSGBUFF_COMPLETED) { NCCL_OFI_WARN("nccl_ofi_msgbuff_complete did not return completed"); return 1; From 1278cf6531eec4e7c6dcf78350ba2dfd96e5b001 Mon Sep 17 00:00:00 2001 From: Eric Raut Date: Sat, 24 Feb 2024 01:07:29 +0000 Subject: [PATCH 5/6] rdma: support NCCL multi-recv interface The multi-recv interface allows aggregating up to 8 receive requests in a single request. This commit does not yet advertise support for multi-recv to NCCL. * Temporarily disables eager; it will be re-enabled in a future commit. Signed-off-by: Eric Raut --- include/nccl_ofi_rdma.h | 33 ++- src/nccl_ofi_rdma.c | 642 ++++++++++++++++++++++++++-------------- 2 files changed, 456 insertions(+), 219 deletions(-) diff --git a/include/nccl_ofi_rdma.h b/include/nccl_ofi_rdma.h index 335ea8004..9c3e7c2ea 100644 --- a/include/nccl_ofi_rdma.h +++ b/include/nccl_ofi_rdma.h @@ -74,20 +74,32 @@ typedef struct nccl_net_ofi_rdma_mr_handle { struct fid_mr *mr[]; } nccl_net_ofi_rdma_mr_handle_t; -/* Contents of ctrl message sent from receiver to sender to advertise - destination buffer */ -typedef struct nccl_net_ofi_rdma_ctrl_msg { +typedef struct nccl_net_ofi_rdma_ctrl_msg_entry { + int multi_recv_tag; uint64_t buff_addr; uint64_t buff_len; uint64_t buff_mr_key[MAX_NUM_RAILS]; +} nccl_net_ofi_rdma_ctrl_msg_entry_t; + +/* Contents of ctrl message sent from receiver to sender to advertise + destination buffer */ +typedef struct nccl_net_ofi_rdma_ctrl_msg { + uint16_t msg_seq_num; + uint16_t multi_recv_size; + nccl_net_ofi_rdma_ctrl_msg_entry_t entries[]; } nccl_net_ofi_rdma_ctrl_msg_t; +#define RDMA_CTRL_MSG_ENTRIES_MAX_SIZE (NCCL_OFI_MAX_RECVS * sizeof(nccl_net_ofi_rdma_ctrl_msg_entry_t)) +#define RDMA_CTRL_MSG_MAX_SIZE (sizeof(nccl_net_ofi_rdma_ctrl_msg_t) + RDMA_CTRL_MSG_ENTRIES_MAX_SIZE) + /* Structure used to store control messages in a free list */ typedef struct nccl_net_ofi_rdma_ctrl_fl_item { nccl_ofi_freelist_reginfo_t fl_reginfo; nccl_net_ofi_rdma_ctrl_msg_t ctrl_msg; } nccl_net_ofi_rdma_ctrl_fl_item_t; +#define RDMA_CTRL_FL_ITEM_MAX_SIZE (sizeof(nccl_net_ofi_rdma_ctrl_fl_item_t) + RDMA_CTRL_MSG_ENTRIES_MAX_SIZE) + /* For LL/LL128 protocols, bounce buffers (source of RDMA read operations) need to be 128B aligned */ #define BOUNCE_BUFFER_ALIGNMENT 128 @@ -152,6 +164,13 @@ typedef struct { /* Total number of completions. Expect one completion for receiving the * control message and one completion for each send segment. */ int total_num_compls; + + /* Multi-recv information */ + uint16_t multi_recv_size; + uint16_t multi_recv_start; + int multi_recv_tag; + /* This may not match sender-side seq num with multi-recv */ + uint16_t recv_side_msg_seq_num; } rdma_req_send_data_t; /* @@ -166,6 +185,8 @@ typedef struct { nccl_net_ofi_schedule_t *ctrl_schedule; /* Pointer to recv parent request */ nccl_net_ofi_rdma_req_t *recv_req; + /* Size of ctrl message */ + size_t ctrl_msg_size; } rdma_req_send_ctrl_data_t; typedef struct { @@ -206,6 +227,12 @@ typedef struct { * For eager messages, the second completion will be received * when the local read into the destination buffer is complete */ int total_num_compls; + /* Multi-recv information */ + uint16_t multi_recv_size; + uint16_t multi_recv_start; + int multi_recv_tag; + /* Next req in sequence */ + nccl_net_ofi_rdma_req_t *multi_recv_next; } rdma_req_recv_data_t; /* diff --git a/src/nccl_ofi_rdma.c b/src/nccl_ofi_rdma.c index 051a3844a..aae931d62 100644 --- a/src/nccl_ofi_rdma.c +++ b/src/nccl_ofi_rdma.c @@ -914,18 +914,42 @@ static inline int inc_recv_seg_completion(nccl_net_ofi_rdma_req_t *req, return ret; } -static void copy_ctrl_data(nccl_net_ofi_rdma_req_t *bounce_req, nccl_net_ofi_rdma_req_t *req) +static void copy_ctrl_data(nccl_net_ofi_rdma_req_t *bounce_req, nccl_net_ofi_rdma_req_t *req, int tag) { rdma_req_send_data_t *send_data = get_send_data(req); rdma_req_bounce_data_t *bounce_data = get_bounce_data(bounce_req); nccl_net_ofi_rdma_ctrl_msg_t *ctrl_msg = get_bounce_ctrl_msg(bounce_data->bounce_fl_item); + /** Ctrl message size consistency check **/ + assert(bounce_data->recv_len == sizeof(nccl_net_ofi_rdma_ctrl_msg_t) + + ctrl_msg->multi_recv_size * sizeof(nccl_net_ofi_rdma_ctrl_msg_entry_t)); + + + uint16_t multi_recv_size = ctrl_msg->multi_recv_size; + + /* TODO remove an extra search */ + int ctrl_idx; + for (ctrl_idx = 0; ctrl_idx < multi_recv_size; ++ctrl_idx) { + nccl_net_ofi_rdma_ctrl_msg_entry_t *entry = &ctrl_msg->entries[ctrl_idx]; + if (entry->multi_recv_tag == tag) { + break; + } + } + if (ctrl_idx >= multi_recv_size) { + assert(false); abort(); + } + for (int rail_id = 0; rail_id != MAX_NUM_RAILS; ++rail_id) { - send_data->remote_mr_key[rail_id] = ctrl_msg->buff_mr_key[rail_id]; + send_data->remote_mr_key[rail_id] = ctrl_msg->entries[ctrl_idx].buff_mr_key[rail_id]; } - send_data->remote_buff = ctrl_msg->buff_addr; - send_data->remote_len = ctrl_msg->buff_len; + send_data->remote_buff = ctrl_msg->entries[ctrl_idx].buff_addr; + send_data->remote_len = ctrl_msg->entries[ctrl_idx].buff_len; + + send_data->multi_recv_size = ctrl_msg->multi_recv_size; + send_data->multi_recv_start = ctrl_msg->msg_seq_num; + assert(send_data->multi_recv_tag == ctrl_msg->entries[ctrl_idx].multi_recv_tag); + send_data->recv_side_msg_seq_num = ctrl_msg->msg_seq_num + (uint16_t)ctrl_idx; } /* @@ -1009,18 +1033,29 @@ static inline int handle_ctrl_recv(nccl_net_ofi_rdma_send_comm_t *s_comm, nccl_net_ofi_rdma_req_t *bounce_req, nccl_net_ofi_rdma_ep_t *ep) { - int ret; + nccl_net_ofi_rdma_ctrl_msg_t *ctrl_msg = get_bounce_ctrl_msg(get_bounce_data(bounce_req)->bounce_fl_item); + + /* Assert that imm data matches ctrl data for seq num */ + assert(msg_seq_num == ctrl_msg->msg_seq_num); + + uint16_t multi_recv_size = ctrl_msg->multi_recv_size; + int tags[multi_recv_size]; + for (uint16_t i = 0; i < multi_recv_size; ++i) { + tags[i] = ctrl_msg->entries[i].multi_recv_tag; + } nccl_ofi_msgbuff_status_t stat; - nccl_ofi_msgbuff_result_t mb_res = nccl_ofi_msgbuff_insert(s_comm->msgbuff, msg_seq_num, - msg_seq_num, 1, 0, bounce_req, NCCL_OFI_MSGBUFF_BUFF, &stat); + nccl_ofi_msgbuff_result_t mb_res = nccl_ofi_msgbuff_insert_ctrl_multirecv(s_comm->msgbuff, msg_seq_num, + ctrl_msg->multi_recv_size, tags, bounce_req, NCCL_OFI_MSGBUFF_BUFF, &stat); if (mb_res == NCCL_OFI_MSGBUFF_SUCCESS) { /* Inserted! In this case sender has not yet called send() for this message, so return success and initiate RDMA write when sender calls send(). */ return decrease_bounce_buff_cnt(ep, get_bounce_data(bounce_req)->rail); } + assert(false); abort(); /* TODO handle this case */ +#if 0 if (mb_res != NCCL_OFI_MSGBUFF_INVALID_IDX || stat != NCCL_OFI_MSGBUFF_INPROGRESS) { NCCL_OFI_WARN("Unexpected message insert result (%d) (ctrl recv)", (int)mb_res); return -EINVAL; @@ -1029,8 +1064,9 @@ static inline int handle_ctrl_recv(nccl_net_ofi_rdma_send_comm_t *s_comm, // Already a req entry here void *elem; nccl_ofi_msgbuff_elemtype_t type; - mb_res = nccl_ofi_msgbuff_retrieve(s_comm->msgbuff, msg_seq_num, msg_seq_num, 1, 0, - &elem, &type, &stat); + mb_res = nccl_ofi_msgbuff_retrieve(s_comm->msgbuff, msg_seq_num, + ctrl_msg->multi_recv_start, ctrl_msg->multi_recv_size, + ctrl_msg->multi_recv_tag, &elem, &type, &stat); if (mb_res != NCCL_OFI_MSGBUFF_SUCCESS || type != NCCL_OFI_MSGBUFF_REQ) { NCCL_OFI_WARN("Invalid message retrieval result for msg %hu", msg_seq_num); return -EINVAL; @@ -1039,7 +1075,8 @@ static inline int handle_ctrl_recv(nccl_net_ofi_rdma_send_comm_t *s_comm, rdma_req_send_data_t *send_data = get_send_data(req); if (!send_data->eager) { - copy_ctrl_data(bounce_req, req); + abort(); + copy_ctrl_data(bounce_req, req, -1); /* We need to initiate RDMA write here. */ if (send_data->buff_len > send_data->remote_len) { @@ -1080,7 +1117,7 @@ static inline int handle_ctrl_recv(nccl_net_ofi_rdma_send_comm_t *s_comm, NCCL_OFI_WARN("Failed to repost bounce buff"); return ret; } - +#endif return 0; } @@ -1227,10 +1264,12 @@ static inline int handle_bounce_recv(struct fi_cq_tagged_entry *cq_entry, int ra NCCL_OFI_TRACE_SEND_CTRL_RECV(comm->dev_id, rail_id, comm, msg_seq_num); nccl_net_ofi_rdma_send_comm_t *s_comm = (nccl_net_ofi_rdma_send_comm_t *)comm; assert(s_comm->local_comm_id == local_comm_id); - assert(bounce_data->recv_len == sizeof(nccl_net_ofi_rdma_ctrl_msg_t)); + assert(bounce_data->recv_len <= RDMA_CTRL_MSG_MAX_SIZE); return handle_ctrl_recv(s_comm, msg_seq_num, bounce_req, ep); } else if (comm->type == NCCL_NET_OFI_RECV_COMM) { + NCCL_OFI_WARN("Eager receive is not yet supported!"); + assert(false); abort(); /* Eager message */ NCCL_OFI_TRACE_EAGER_RECV(comm->dev_id, rail_id, comm, msg_seq_num); nccl_net_ofi_rdma_recv_comm_t *r_comm = (nccl_net_ofi_rdma_recv_comm_t *)comm; @@ -1259,8 +1298,10 @@ static inline nccl_net_ofi_rdma_req_t *get_req_from_imm_data nccl_ofi_msgbuff_elemtype_t type; nccl_ofi_msgbuff_status_t stat; - nccl_ofi_msgbuff_result_t mb_res = nccl_ofi_msgbuff_retrieve(r_comm->msgbuff, - msg_seq_num, msg_seq_num, 1, 0, &elem, &type, &stat); + /* We don't have a multi-recv tag here, so we rely on msg_seq_num matching + our seq num */ + nccl_ofi_msgbuff_result_t mb_res = nccl_ofi_msgbuff_retrieve_notag(r_comm->msgbuff, + msg_seq_num, &elem, &type, &stat); if (mb_res != NCCL_OFI_MSGBUFF_SUCCESS) { /* Unexpected: we don't have a msgbuff entry corresponding to this message*/ NCCL_OFI_WARN("Unexpected status (%d) for message %hu", (int)stat, msg_seq_num); @@ -1286,6 +1327,7 @@ static inline int handle_write_comp(struct fi_cq_tagged_entry *cq_entry, return ncclSystemError; } assert(req->type == NCCL_OFI_RDMA_RECV); + assert(req->msg_seq_num == GET_SEQ_NUM_FROM_IMM(cq_entry->data)); rdma_req_recv_data_t *recv_data = get_recv_data(req); nccl_net_ofi_rdma_req_t *recv_segms_req = recv_data->recv_segms_req; @@ -2247,34 +2289,13 @@ static int finish_connect(nccl_net_ofi_rdma_send_comm_t *s_comm) #define __compiler_barrier() do { asm volatile ("" : : : "memory"); } while(0) -static int test(nccl_net_ofi_req_t *base_req, int *done, int *size) +static int test_req(nccl_net_ofi_rdma_req_t *req, int *done, int *size) { - int ret = 0; - nccl_net_ofi_rdma_req_t *req = (nccl_net_ofi_rdma_req_t *)base_req; *done = 0; - assert(req->type == NCCL_OFI_RDMA_SEND || - req->type == NCCL_OFI_RDMA_RECV || - req->type == NCCL_OFI_RDMA_FLUSH); - - /* Retrieve and validate comm */ - nccl_net_ofi_comm_t *base_comm = req->comm; - assert(base_comm != NULL); - - /* Retrieve and validate endpoint */ - nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)base_comm->ep; - assert(ep != NULL); - - /* Process more completions unless the current request is - * completed */ - if (req->state != NCCL_OFI_RDMA_REQ_COMPLETED - && OFI_LIKELY(req->state != NCCL_OFI_RDMA_REQ_ERROR)) { - ret = ofi_process_cq(ep); - if (OFI_UNLIKELY(ret != 0)) - goto exit; - } + int ret = 0; /* Determine whether the request has finished without error and free if done */ - if (OFI_LIKELY(req->state == NCCL_OFI_RDMA_REQ_COMPLETED)) { + if (req->state == NCCL_OFI_RDMA_REQ_COMPLETED) { size_t req_size; if (pthread_mutex_lock(&req->req_lock)) { NCCL_OFI_WARN("Unable to acquire req_lock mutex"); @@ -2294,36 +2315,146 @@ static int test(nccl_net_ofi_req_t *base_req, int *done, int *size) *size = req_size; /* Mark as done */ *done = 1; + } else if (OFI_UNLIKELY(req->state == NCCL_OFI_RDMA_REQ_ERROR)) { + NCCL_OFI_WARN("Request completed with error"); + ret = ncclSystemError; + goto exit; + } +exit: + return ret; +} + +static int test_free_req(nccl_net_ofi_rdma_req_t *req) +{ + int ret = 0; + if (req->type != NCCL_OFI_RDMA_FLUSH) { + uint16_t multi_recv_start; + uint16_t multi_recv_size; + int multi_recv_tag; + + /* Retrieve and validate comm */ + nccl_net_ofi_comm_t *base_comm = req->comm; + assert(base_comm != NULL); + + /* Mark as complete in message buffer */ + nccl_ofi_msgbuff_t *msgbuff; + + if (req->type == NCCL_OFI_RDMA_SEND) { + msgbuff = ((nccl_net_ofi_rdma_send_comm_t *)base_comm)->msgbuff; + rdma_req_send_data_t *send_data = get_send_data(req); + multi_recv_start = send_data->multi_recv_start; + multi_recv_size = send_data->multi_recv_size; + multi_recv_tag = send_data->multi_recv_tag; + } else if (req->type == NCCL_OFI_RDMA_RECV) { + msgbuff = ((nccl_net_ofi_rdma_recv_comm_t *)base_comm)->msgbuff; + rdma_req_recv_data_t *recv_data = get_recv_data(req); + multi_recv_start = recv_data->multi_recv_start; + multi_recv_size = recv_data->multi_recv_size; + multi_recv_tag = recv_data->multi_recv_tag; + } else { + NCCL_OFI_WARN("Unexpected request type: %d", req->type); + ret = ncclSystemError; + goto exit; + } + + nccl_ofi_msgbuff_status_t stat; + nccl_ofi_msgbuff_result_t mb_res = nccl_ofi_msgbuff_complete(msgbuff, req->msg_seq_num, multi_recv_start, + multi_recv_size, multi_recv_tag, &stat); + if (mb_res != NCCL_OFI_MSGBUFF_SUCCESS) { + NCCL_OFI_WARN("Invalid result (%d,%d) of msgbuff_complete for msg %hu type %d", mb_res, stat, req->msg_seq_num, req->type); + ret = ncclSystemError; + goto exit; + } + } + assert(req->free); + req->free(req, true); + +exit: + return ret; +} - if (req->type != NCCL_OFI_RDMA_FLUSH) { - /* Mark as complete in message buffer */ - nccl_ofi_msgbuff_t *msgbuff; - if (req->type == NCCL_OFI_RDMA_SEND) { - msgbuff = ((nccl_net_ofi_rdma_send_comm_t *)base_comm)->msgbuff; - } else if (req->type == NCCL_OFI_RDMA_RECV) { - msgbuff = ((nccl_net_ofi_rdma_recv_comm_t *)base_comm)->msgbuff; +static int free_multirecv_req(nccl_net_ofi_rdma_req_t *req) +{ + int ret = 0; + while (req) { + nccl_net_ofi_rdma_req_t *next_req = get_recv_data(req)->multi_recv_next; + ret = test_free_req(req); + if (OFI_UNLIKELY(ret != 0)) { + return ret; + } + req = next_req; + } + return ret; +} + +static int test(nccl_net_ofi_req_t *base_req, int *done, int *size) +{ + int ret = 0; + nccl_net_ofi_rdma_req_t *req = (nccl_net_ofi_rdma_req_t *)base_req; + *done = 0; + assert(req->type == NCCL_OFI_RDMA_SEND || + req->type == NCCL_OFI_RDMA_RECV || + req->type == NCCL_OFI_RDMA_FLUSH); + + /* Retrieve and validate comm */ + nccl_net_ofi_comm_t *base_comm = req->comm; + assert(base_comm != NULL); + + /* Retrieve and validate endpoint */ + nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)base_comm->ep; + assert(ep != NULL); + + if (req->type == NCCL_OFI_RDMA_RECV && + get_recv_data(req)->multi_recv_size > 1) { +#ifndef NDEBUG + uint16_t multi_recv_size = get_recv_data(req)->multi_recv_size; +#endif + /* Multi-recv: test each request individually */ + bool processed_cq = false; + int i = 0; + while (req) { + ret = test_req(req, done, &size[i]); + if (OFI_UNLIKELY(ret != 0)) { + goto exit; + } + if (*done) { + req = get_recv_data(req)->multi_recv_next; + ++i; } else { - NCCL_OFI_WARN("Unexpected request type: %d", req->type); - ret = ncclSystemError; + if (!processed_cq) { + ret = ofi_process_cq(ep); + if (OFI_UNLIKELY(ret != 0)) { + goto exit; + } + processed_cq = true; + } else { + break; + } + } + } + if (*done) { + assert(i == multi_recv_size); + req = (nccl_net_ofi_rdma_req_t *)base_req; + ret = free_multirecv_req(req); + } + } else { + ret = test_req(req, done, size); + if (OFI_UNLIKELY(ret)) { + goto exit; + } + if (!(*done)) { + ret = ofi_process_cq(ep); + if (OFI_UNLIKELY(ret != 0)) { goto exit; } - - nccl_ofi_msgbuff_status_t stat; - nccl_ofi_msgbuff_result_t mb_res = nccl_ofi_msgbuff_complete(msgbuff, req->msg_seq_num, - req->msg_seq_num, 1, 0, &stat); - if (mb_res != NCCL_OFI_MSGBUFF_SUCCESS) { - NCCL_OFI_WARN("Invalid result of msgbuff_complete for msg %hu", req->msg_seq_num); - ret = ncclSystemError; + ret = test_req(req, done, size); + if (OFI_UNLIKELY(ret)) { goto exit; } } - - assert(req->free); - req->free(req, true); - } else if (OFI_UNLIKELY(req->state == NCCL_OFI_RDMA_REQ_ERROR)) { - NCCL_OFI_WARN("Request completed with error"); - ret = ncclSystemError; - goto exit; + if (*done) { + ret = test_free_req(req); + } } exit: @@ -2819,6 +2950,10 @@ static inline int insert_send_ctrl_req( return ncclSystemError; } + rdma_req_recv_data_t *recv_data = get_recv_data(recv_req); + + recv_data->total_num_compls = 2; + send_ctrl_req->comm = &r_comm->base.base; send_ctrl_req->dev_id = dev_id; send_ctrl_req->type = NCCL_OFI_RDMA_SEND_CTRL; @@ -2826,11 +2961,14 @@ static inline int insert_send_ctrl_req( send_ctrl_req->msg_seq_num = msg_seq_num; rdma_req_send_ctrl_data_t *send_ctrl_data = get_send_ctrl_data(send_ctrl_req); + send_ctrl_data->ctrl_msg_size = sizeof(nccl_net_ofi_rdma_ctrl_msg_t) + + recv_data->multi_recv_size * sizeof(nccl_net_ofi_rdma_ctrl_msg_entry_t); send_ctrl_data->recv_req = recv_req; send_ctrl_data->ctrl_fl_item = NULL; send_ctrl_data->ctrl_schedule = scheduler->get_schedule(scheduler, - sizeof(nccl_net_ofi_rdma_ctrl_msg_t), - device->num_rails); + sizeof(nccl_net_ofi_rdma_ctrl_msg_t) + + (recv_data->multi_recv_size * sizeof(nccl_net_ofi_rdma_ctrl_msg_entry_t)), + device->num_rails); if (OFI_UNLIKELY(!(send_ctrl_data->ctrl_schedule))) { return ncclInternalError; @@ -2861,13 +2999,16 @@ static inline int insert_send_ctrl_req( return ncclInternalError; } - ctrl_fl_item->ctrl_msg.buff_addr = (uint64_t)buff; - ctrl_fl_item->ctrl_msg.buff_len = size; + ctrl_fl_item->ctrl_msg.msg_seq_num = msg_seq_num; + ctrl_fl_item->ctrl_msg.multi_recv_size = recv_data->multi_recv_size; + ctrl_fl_item->ctrl_msg.entries[0].multi_recv_tag = recv_data->multi_recv_tag; + ctrl_fl_item->ctrl_msg.entries[0].buff_addr = (uint64_t)buff; + ctrl_fl_item->ctrl_msg.entries[0].buff_len = size; int rail_id = 0; for (; rail_id < r_comm->num_rails; rail_id++) { - ctrl_fl_item->ctrl_msg.buff_mr_key[rail_id] = fi_mr_key(buff_mr_handle->mr[rail_id]); + ctrl_fl_item->ctrl_msg.entries[0].buff_mr_key[rail_id] = fi_mr_key(buff_mr_handle->mr[rail_id]); - if (ctrl_fl_item->ctrl_msg.buff_mr_key[rail_id] == FI_KEY_NOTAVAIL) { + if (ctrl_fl_item->ctrl_msg.entries[0].buff_mr_key[rail_id] == FI_KEY_NOTAVAIL) { NCCL_OFI_WARN("RDMA write buffers should be pre-registered"); return ncclInternalError; } @@ -2875,7 +3016,6 @@ static inline int insert_send_ctrl_req( send_ctrl_data->ctrl_fl_item = ctrl_fl_item; - rdma_req_recv_data_t *recv_data = get_recv_data(recv_req); recv_data->send_ctrl_req = send_ctrl_req; return 0; @@ -2922,8 +3062,9 @@ static inline int insert_recv_segms_req( static inline int allocate_rdma_recv_req( nccl_net_ofi_rdma_recv_comm_t *r_comm, nccl_net_ofi_rdma_device_t *device, - int dev_id, uint16_t msg_seq_num, void *buff, - size_t size, + int dev_id, uint16_t msg_seq_num, + uint16_t multi_recv_start, uint16_t multi_recv_size, + int multi_recv_tag, void *buff, size_t size, nccl_net_ofi_rdma_mr_handle_t *buff_mr_handle, nccl_net_ofi_rdma_req_t **ret_req) { @@ -2946,18 +3087,19 @@ static inline int allocate_rdma_recv_req( req->msg_seq_num = msg_seq_num; recv_data = get_recv_data(req); - recv_data->total_num_compls = 2; + recv_data->total_num_compls = 1; recv_data->eager_copy_req = NULL; recv_data->dst_buff = buff; recv_data->dst_len = size; recv_data->dest_mr_handle = buff_mr_handle; - /* TODO consolidate arguments to insert_send_ctrl_req and insert_recv_segms_req */ - ret = insert_send_ctrl_req(r_comm, device, dev_id, msg_seq_num, buff, size, buff_mr_handle, req); - if (ret) { - NCCL_OFI_WARN("Failed to insert send ctrl request into recv request"); - return ret; - } + /* Populate multi-recv data */ + recv_data->multi_recv_size = multi_recv_size; + recv_data->multi_recv_start = multi_recv_start; + recv_data->multi_recv_tag = multi_recv_tag; + recv_data->multi_recv_next = NULL; + + recv_data->send_ctrl_req = NULL; ret = insert_recv_segms_req(r_comm, device, dev_id, msg_seq_num, buff, size, buff_mr_handle, req); if (ret) { @@ -2977,13 +3119,19 @@ static inline int insert_rdma_recv_req_into_msgbuff(nccl_net_ofi_rdma_recv_comm_ nccl_ofi_msgbuff_status_t msg_stat; nccl_ofi_msgbuff_result_t mb_res; + rdma_req_recv_data_t *recv_data = get_recv_data(req); + if (eager) { + assert(false); + assert(recv_data->multi_recv_size == 1); /* * There is already a buffer entry in the message buffer, so * replace it with a request. */ mb_res = nccl_ofi_msgbuff_replace(r_comm->msgbuff, - req->msg_seq_num, req->msg_seq_num, 1, 0, req, + req->msg_seq_num, recv_data->multi_recv_start, + recv_data->multi_recv_size, + recv_data->multi_recv_tag, req, NCCL_OFI_MSGBUFF_REQ, &msg_stat, NULL); if (mb_res != NCCL_OFI_MSGBUFF_SUCCESS) { @@ -2993,13 +3141,17 @@ static inline int insert_rdma_recv_req_into_msgbuff(nccl_net_ofi_rdma_recv_comm_ } } else { /* Try inserting the new request */ - mb_res = nccl_ofi_msgbuff_insert(r_comm->msgbuff, req->msg_seq_num, req->msg_seq_num, - 1, 0, req, NCCL_OFI_MSGBUFF_REQ, &msg_stat); + mb_res = nccl_ofi_msgbuff_insert(r_comm->msgbuff, req->msg_seq_num, + recv_data->multi_recv_start, + recv_data->multi_recv_size, + recv_data->multi_recv_tag, req, + NCCL_OFI_MSGBUFF_REQ, &msg_stat); if (OFI_UNLIKELY((mb_res == NCCL_OFI_MSGBUFF_INVALID_IDX) && (msg_stat == NCCL_OFI_MSGBUFF_INPROGRESS))) { /* Unlikely: an eager message was received on another thread. Return NULL and let NCCL call recv again. */ + assert(false); /* Reduce testing surface for now. TODO remove. */ req->free(req, false); *ret_req = NULL; } else if (OFI_UNLIKELY(mb_res != NCCL_OFI_MSGBUFF_SUCCESS)) { @@ -3046,7 +3198,7 @@ static int recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers, assert(r_comm != NULL); - if (OFI_UNLIKELY(r_comm->num_inflight_reqs == NCCL_OFI_MAX_REQUESTS)) { + if (OFI_UNLIKELY(r_comm->num_inflight_reqs + n > NCCL_OFI_MAX_REQUESTS)) { ret = -ENOSPC; NCCL_OFI_WARN("Can not support more than %d inflight requests", NCCL_OFI_MAX_REQUESTS); @@ -3071,109 +3223,92 @@ static int recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers, goto error; } - uint16_t msg_seq_num = r_comm->next_msg_seq_num; + uint16_t base_msg_seq_num = r_comm->next_msg_seq_num; - bool eager = false; - void *elem; - nccl_ofi_msgbuff_elemtype_t type; - nccl_ofi_msgbuff_status_t msg_stat; - nccl_ofi_msgbuff_result_t mb_res; + nccl_net_ofi_rdma_req_t *multirecv_base_req = NULL; + nccl_net_ofi_rdma_req_t *multirecv_prev_req = NULL; + rdma_req_recv_data_t *base_recv_data = NULL; - mb_res = nccl_ofi_msgbuff_retrieve(r_comm->msgbuff, msg_seq_num, msg_seq_num, - 1, 0, &elem, &type, &msg_stat); - if (mb_res == NCCL_OFI_MSGBUFF_SUCCESS) { + assert(n <= NCCL_OFI_MAX_RECVS); - if (type == NCCL_OFI_MSGBUFF_REQ) { - /* Shouldn't happen: duplicate request */ - NCCL_OFI_WARN("Duplicate request in message buffer for msg %hu", msg_seq_num); - ret = -EINVAL; - goto error; - } else if (type == NCCL_OFI_MSGBUFF_BUFF) { - /* This is an eager message */ - eager = true; - } else { - NCCL_OFI_WARN("Invalid type in msg buff"); - ret = -EINVAL; - goto error; - } - } else if ((mb_res == NCCL_OFI_MSGBUFF_INVALID_IDX) && - (msg_stat == NCCL_OFI_MSGBUFF_NOTSTARTED)) { - /* Allocate a new req */ - } else { - NCCL_OFI_WARN("Message %hu has invalid status.", msg_seq_num); - ret = -EINVAL; - goto error; - } + for (uint16_t i = 0; i < n; ++i) { + uint16_t msg_idx = base_msg_seq_num + i; + bool eager = false; - ret = allocate_rdma_recv_req(r_comm, device, dev_id, msg_seq_num, - buffers[0], sizes[0], - mr_handles[0], &req); - if (ret != 0) { - goto error; - } + /* Eager TODO: check for existing request */ - rdma_req_recv_data_t *recv_data = get_recv_data(req); + ret = allocate_rdma_recv_req(r_comm, device, dev_id, msg_idx, + base_msg_seq_num, n, tags[i], + buffers[i], sizes[i], + mr_handles[i], &req); + if (ret != 0) { + goto error; + } - if (eager) { - nccl_net_ofi_rdma_req_t *bounce_req = elem; - rdma_req_bounce_data_t *bounce_data = get_bounce_data(bounce_req); - if (bounce_data->recv_len == 0) { - /* Special case for zero-sized messages */ - ret = check_post_bounce_req(bounce_req); - if (ret != 0) { - NCCL_OFI_WARN("Failed call to check_post_bounce_req"); + if (i == 0) { + ret = insert_send_ctrl_req(r_comm, device, dev_id, msg_idx, buffers[i], + sizes[i], mr_handles[i], req); + if (ret) { + NCCL_OFI_WARN("Failed to insert send ctrl request into recv request"); return ret; } - recv_data->eager_copy_req = NULL; } else { - ret = alloc_eager_copy_req(req, r_comm, bounce_req); - if (ret != 0) { - goto error; + /* Fill in info for this req */ + assert(multirecv_base_req); + assert(n == base_recv_data->multi_recv_size); + nccl_net_ofi_rdma_ctrl_fl_item_t *ctrl_fl_item = + get_send_ctrl_data(base_recv_data->send_ctrl_req)->ctrl_fl_item; + ctrl_fl_item->ctrl_msg.entries[i].multi_recv_tag = tags[i]; + ctrl_fl_item->ctrl_msg.entries[i].buff_addr = (uint64_t)buffers[i]; + ctrl_fl_item->ctrl_msg.entries[i].buff_len = sizes[i]; + for (int rail_id = 0; rail_id < r_comm->num_rails; ++rail_id) { + ctrl_fl_item->ctrl_msg.entries[i].buff_mr_key[rail_id] = + fi_mr_key(mr_handles[i]->mr[rail_id]); + + if (ctrl_fl_item->ctrl_msg.entries[i].buff_mr_key[rail_id] == FI_KEY_NOTAVAIL) { + NCCL_OFI_WARN("RDMA write buffers should be pre-registered"); + return ncclInternalError; + } } } - } - ret = insert_rdma_recv_req_into_msgbuff(r_comm, eager, &req); - if (ret != 0) { - goto free_req; - } else if (req == NULL) { - ret = -ENOMEM; - goto free_req; - } + if (multirecv_prev_req != NULL) { + get_recv_data(multirecv_prev_req)->multi_recv_next = req; + } + multirecv_prev_req = req; + if (i == 0) { + multirecv_base_req = req; + base_recv_data = get_recv_data(multirecv_base_req); + } - /* At this point, we've successfully inserted a new request, so update the num inflight. */ - (r_comm->num_inflight_reqs)++; + /* Eager TODO: allocate eager copy req */ + + ret = insert_rdma_recv_req_into_msgbuff(r_comm, eager, &req); + if (ret != 0) { + goto free_req; + } else if (req == NULL) { + ret = -ENOMEM; + goto free_req; + } + + /* At this point, we've successfully inserted a new request, so update the num inflight. */ + (r_comm->num_inflight_reqs)++; - NCCL_OFI_TRACE_RECV(dev_id, r_comm->local_comm_id, sizes[0], req, base_req); + NCCL_OFI_TRACE_RECV(dev_id, r_comm->local_tag, sizes[0], req, base_req); - ret = receive_progress(recv_data->send_ctrl_req, true); + /* Eager TODO: post eager copy req */ + } + + ret = receive_progress(base_recv_data->send_ctrl_req, true); if (OFI_UNLIKELY(ret != 0)) { /* TODO: Remove req from message buffer */ goto error; } - if (eager) { - if (recv_data->eager_copy_req == NULL) { - /* If we don't need to do eager copy, this recv is already complete */ - ret = inc_req_completion(req, 0, recv_data->total_num_compls); - if (ret != 0) { - goto error; - } - } else { - /* Post eager copy */ - ret = receive_progress(recv_data->eager_copy_req, true); - if (ret != 0) { - NCCL_OFI_WARN("Failed to issue eager read"); - /* TODO: Remove req from message buffer */ - goto error; - } - } - } - /* Return request to NCCL */ - *base_req = (nccl_net_ofi_req_t *)req; + *base_req = (nccl_net_ofi_req_t *)multirecv_base_req; /* Increment next_msg_seq_num for next call */ - ++(r_comm->next_msg_seq_num); + r_comm->next_msg_seq_num += n; goto exit; @@ -3355,6 +3490,7 @@ static int flush(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers, int *sizes, nccl_net_ofi_mr_handle_t **mhandles, nccl_net_ofi_req_t **base_req) { + assert(n == 1); int ret = 0; nccl_net_ofi_rdma_recv_comm_t *r_comm = (nccl_net_ofi_rdma_recv_comm_t *)recv_comm; @@ -3647,7 +3783,7 @@ static nccl_net_ofi_rdma_recv_comm_t *prepare_recv_comm(nccl_net_ofi_rdma_listen return NULL; } - ret = nccl_ofi_freelist_init_mr(sizeof(nccl_net_ofi_rdma_ctrl_fl_item_t), 8, 8, + ret = nccl_ofi_freelist_init_mr(RDMA_CTRL_FL_ITEM_MAX_SIZE, 8, 8, NCCL_OFI_MAX_REQUESTS, freelist_regmr_host_fn, freelist_deregmr_host_fn, ep, 0, 1, &r_comm->ctrl_buff_fl); @@ -3908,7 +4044,7 @@ static int accept(nccl_net_ofi_listen_comm_t *listen_comm, ret = ncclInternalError; goto exit; } - + /* Set r_comm's (local) comm ID to be sent back to remote */ conn_msg->local_comm_id = r_comm->local_comm_id; @@ -4007,7 +4143,7 @@ static int listen_close(nccl_net_ofi_listen_comm_t *listen_comm) if (l_comm->req.state == NCCL_OFI_RDMA_REQ_PENDING) { NCCL_OFI_WARN("Unable to free request of listen communicator. Request is still pending. Leaking memory."); - return ncclInternalError; + return 0; } if (l_comm->r_comm) { @@ -4135,7 +4271,7 @@ static int dereg_mr_send_comm(nccl_net_ofi_send_comm_t *send_comm, } static int alloc_rdma_send_req(nccl_net_ofi_rdma_send_comm_t *s_comm, - uint16_t msg_seq_num, + uint16_t msg_seq_num, int multi_recv_tag, void *buff, size_t size, nccl_net_ofi_rdma_mr_handle_t *buff_mr_handle, bool eager, bool have_ctrl, @@ -4178,6 +4314,11 @@ static int alloc_rdma_send_req(nccl_net_ofi_rdma_send_comm_t *s_comm, req->msg_seq_num, send_data->schedule->num_xfer_infos); + /* Initialize for now. It will be populated later with correct info from receiver*/ + send_data->multi_recv_size = 0; + send_data->multi_recv_start = 0; + send_data->multi_recv_tag = multi_recv_tag; + *ret_req = req; return 0; @@ -4185,38 +4326,46 @@ static int alloc_rdma_send_req(nccl_net_ofi_rdma_send_comm_t *s_comm, static int insert_rdma_send_req_into_msgbuff(nccl_net_ofi_rdma_send_comm_t *s_comm, int dev_id, bool have_ctrl, - nccl_net_ofi_rdma_req_t **ret_req) + nccl_net_ofi_rdma_req_t **ret_req, + bool *multi_send_ready) { nccl_net_ofi_rdma_req_t *req = *ret_req; nccl_ofi_msgbuff_status_t msg_stat; nccl_ofi_msgbuff_result_t mb_res; + rdma_req_send_data_t *send_data = get_send_data(req); + if (have_ctrl) { /* * There is already a buffer entry in the message buffer, * so replace it with a request. */ mb_res = nccl_ofi_msgbuff_replace(s_comm->msgbuff, - req->msg_seq_num, req->msg_seq_num, - 1, 0, req, + req->msg_seq_num, send_data->multi_recv_start, + send_data->multi_recv_size, + send_data->multi_recv_tag, + req, NCCL_OFI_MSGBUFF_REQ, - &msg_stat, NULL); + &msg_stat, multi_send_ready); if (mb_res != NCCL_OFI_MSGBUFF_SUCCESS) { NCCL_OFI_WARN("Unexpected result of nccl_ofi_msgbuff_replace for msg %hu", req->msg_seq_num); return ncclSystemError; } } else { + assert(false); abort(); /* Try inserting the new request */ mb_res = nccl_ofi_msgbuff_insert(s_comm->msgbuff, - req->msg_seq_num, req->msg_seq_num, - 1, 0, req, + req->msg_seq_num, send_data->multi_recv_start, + send_data->multi_recv_size, + send_data->multi_recv_tag, req, NCCL_OFI_MSGBUFF_REQ, &msg_stat); if (OFI_UNLIKELY((mb_res == NCCL_OFI_MSGBUFF_INVALID_IDX) && (msg_stat == NCCL_OFI_MSGBUFF_INPROGRESS))) { /* Unlikely: a ctrl message was received on another thread. Return NULL and let NCCL call send again. */ + assert(false); req->free(req, false); *ret_req = NULL; } else if (OFI_UNLIKELY(mb_res != NCCL_OFI_MSGBUFF_SUCCESS)) { @@ -4238,6 +4387,13 @@ static int post_rdma_write(nccl_net_ofi_rdma_req_t *req, struct fid_mr *rail_mr_handle = send_data->buff_mr_handle->mr[rail_id]; void *desc = fi_mr_desc(rail_mr_handle); + /* For multi-recv, in wdata, we need to make sure we use the same msg_seq_num as + receiver has, so recompute the wdata */ + send_data->wdata = + GET_RDMA_WRITE_IMM_DATA(((nccl_net_ofi_rdma_send_comm_t*)(req->comm))->remote_comm_id, + send_data->recv_side_msg_seq_num, + send_data->schedule->num_xfer_infos); + ssize_t rc; /* Post RDMA write */ rc = fi_writedata(comm_rail->local_ep, send_data->buff + xfer_info->offset, @@ -4409,7 +4565,7 @@ static int post_rdma_ctrl(nccl_net_ofi_rdma_req_t *req) uint64_t data = GET_RDMA_WRITE_IMM_DATA(r_comm->remote_comm_id, req->msg_seq_num, 0); ssize_t rc = fi_tsenddata(comm_rail->local_ep, &ctrl_fl_item->ctrl_msg, - sizeof(nccl_net_ofi_rdma_ctrl_msg_t), desc, + send_ctrl_data->ctrl_msg_size, desc, data, comm_rail->remote_addr, RDMA_DATA_TAG, req); if ((rc != 0) && (rc != -FI_EAGAIN)) { @@ -4568,6 +4724,57 @@ static inline int check_post_bounce_req(nccl_net_ofi_rdma_req_t *bounce_req) return ret; } +static int rdma_post_multi_send(nccl_net_ofi_rdma_send_comm_t *s_comm, uint16_t multi_recv_start, + uint16_t multi_recv_size) +{ + int ret = 0; + + nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)s_comm->base.base.ep; + + for (uint16_t idx = multi_recv_start; idx != (uint16_t)(multi_recv_start+multi_recv_size); ++idx) { + void *elem; + nccl_ofi_msgbuff_elemtype_t type; + nccl_ofi_msgbuff_status_t stat; + nccl_ofi_msgbuff_result_t res = nccl_ofi_msgbuff_retrieve_notag(s_comm->msgbuff, + idx, &elem, &type, &stat); + if (res != NCCL_OFI_MSGBUFF_SUCCESS) { + assert(false); + abort(); + } + assert(elem); + assert(type == NCCL_OFI_MSGBUFF_REQ); + assert(stat == NCCL_OFI_MSGBUFF_INPROGRESS); + nccl_net_ofi_rdma_req_t *req = elem; + + ret = send_progress(req); + if (ret == -FI_EAGAIN) { + /* Add to pending reqs queue */ + assert(ep != NULL); + ret = nccl_ofi_deque_insert_back(ep->pending_reqs_queue, &req->pending_reqs_elem); + if (ret != 0) { + assert(false); + NCCL_OFI_WARN("Failed to nccl_ofi_deque_insert_back: %d", ret); + goto exit; + } + } else if (OFI_UNLIKELY(ret != 0)) { + /* TODO: Remove req from message buffer */ + ret = -ENOTSUP; + assert(false); + goto exit; + } + } + assert(!ret); + if (ret) goto exit; + ret = process_cq_if_pending(ep); + if (ret == -FI_EAGAIN) { + ret = 0; + } else if (ret != 0) { + assert(false); + } +exit: + return ret; +} + /** * @brief Send a message. This "interface function" is called, indirectly, from * the application @@ -4627,11 +4834,6 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t goto error; } - /* - * TODO: Use NCCL provided tags when using grouped receives aka - * props->maxRecvs > 1. - */ - bool have_ctrl = false; uint16_t msg_seq_num = s_comm->next_msg_seq_num; @@ -4640,9 +4842,10 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t nccl_ofi_msgbuff_status_t msg_stat; nccl_ofi_msgbuff_result_t mb_res; - /* Retrive entry from message buffer for msg_seq_num index */ - mb_res = nccl_ofi_msgbuff_retrieve(s_comm->msgbuff, msg_seq_num, msg_seq_num, - 1, 0, &elem, &type, &msg_stat); + /* Retrive entry from message buffer for msg_seq_num index. + At this point we don't have multi-recv info */ + mb_res = nccl_ofi_msgbuff_retrieve(s_comm->msgbuff, msg_seq_num, 0, 0, tag, + &elem, &type, &msg_stat); if (mb_res == NCCL_OFI_MSGBUFF_SUCCESS) { if (type == NCCL_OFI_MSGBUFF_BUFF) { /* @@ -4652,8 +4855,13 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t have_ctrl = true; } else if (type == NCCL_OFI_MSGBUFF_REQ) { /* Shouldn't happen: we already have a req in the message buffer */ - NCCL_OFI_WARN("Duplicate request in message buffer for msg %hu", msg_seq_num); - ret = ncclSystemError; + //NCCL_OFI_WARN("Duplicate request in message buffer for msg %hu", msg_seq_num); + //ret = ncclSystemError; + ret = ofi_process_cq(ep); + if (ret != 0) { + goto error; + } + ret = ncclSuccess; goto error; } else { NCCL_OFI_WARN("Unexpected type of buffer retrieved from message buffer: %d", @@ -4662,13 +4870,21 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t goto error; } } else if ((mb_res == NCCL_OFI_MSGBUFF_INVALID_IDX) && - (msg_stat == NCCL_OFI_MSGBUFF_NOTSTARTED)) { + (msg_stat == NCCL_OFI_MSGBUFF_NOTSTARTED || msg_stat == NCCL_OFI_MSGBUFF_UNAVAILABLE)) { /* * We haven't encountered this message sequence number. * Allocate a request so that we are able to send RDMA write * as soon as we receive the RDMA control message. */ have_ctrl = false; + /** Just return a NULL req here **/ + /** Eager TODO: this will be an eager message if small enough */ + ret = ofi_process_cq(ep); + if (ret != 0) { + goto error; + } + ret = ncclSuccess; + goto free_req; } else { NCCL_OFI_WARN("Message %hu has invalid status. res = %d and stat = %d", msg_seq_num, mb_res, msg_stat); @@ -4678,33 +4894,27 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t /* Determine if this should be sent eagerly. */ bool eager = false; - if ((!have_ctrl && size <= eager_max_size) || - (size == 0)) { - eager = true; - } + /* Eager TODO */ - ret = alloc_rdma_send_req(s_comm, msg_seq_num, data, + ret = alloc_rdma_send_req(s_comm, msg_seq_num, tag, data, size, mr_handle, eager, have_ctrl, &req); if (OFI_UNLIKELY(ret != 0)) { goto error; } + assert(have_ctrl); if (have_ctrl) { /* * For already received RDMA control message, populate * the RDMA write metadata from the bounce buffer */ nccl_net_ofi_rdma_req_t *bounce_req = elem; - copy_ctrl_data(bounce_req, req); - - /* Post if needed */ - ret = check_post_bounce_req(bounce_req); - if (OFI_UNLIKELY(ret != 0)) { - goto error; - } + copy_ctrl_data(bounce_req, req, tag); } - ret = insert_rdma_send_req_into_msgbuff(s_comm, dev_id, have_ctrl, &req); + bool multi_send_ready = false; + ret = insert_rdma_send_req_into_msgbuff(s_comm, dev_id, have_ctrl, &req, + &multi_send_ready); if (ret != 0 || req == NULL) { goto free_req; } @@ -4717,25 +4927,24 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t NCCL_OFI_TRACE_SEND(req->dev_id, size, s_comm, msg_seq_num, req, base_req); - /* Try posting RDMA write for received RDMA control messages */ - if (have_ctrl || eager) { + assert(!eager); - ret = send_progress(req); - if (ret == -FI_EAGAIN) { - /* Add to pending reqs queue */ - ret = nccl_ofi_deque_insert_back(ep->pending_reqs_queue, &req->pending_reqs_elem); - if (ret != 0) { - NCCL_OFI_WARN("Failed to nccl_ofi_deque_insert_back: %d", ret); - goto error; - } - NCCL_OFI_TRACE_PENDING_INSERT(req); - } else if (OFI_UNLIKELY(ret != 0)) { - /* TODO: Remove req from message buffer */ - ret = -ENOTSUP; + if (multi_send_ready) { + rdma_req_send_data_t *send_data = get_send_data(req); + + ret = rdma_post_multi_send(s_comm, send_data->multi_recv_start, + send_data->multi_recv_size); + + /* Re-post bounce buffer if needed */ + nccl_net_ofi_rdma_req_t *bounce_req = elem; + ret = check_post_bounce_req(bounce_req); + if (OFI_UNLIKELY(ret != 0)) { goto error; } } + /* Eager TODO: post eager message */ + /* Return request to NCCL */ *base_req = &req->base; /* Increment next_msg_seq_num for next call */ @@ -5305,8 +5514,8 @@ static int connect(nccl_net_ofi_ep_t *base_ep, ret = ofi_process_cq(ep); if (OFI_UNLIKELY(ret != 0)) { /* Send communicator cannot be closed since - * send request of send connect message is - * still pending */ + * send request of send connect message is + * still pending */ return ret; } @@ -5607,8 +5816,7 @@ static int get_ep(nccl_net_ofi_device_t *base_dev, /* Initialize reference count */ ep->ref_cnt = 0; - ep->bounce_buff_size = NCCL_OFI_MAX(sizeof(nccl_net_ofi_rdma_ctrl_msg_t), - eager_max_size); + ep->bounce_buff_size = NCCL_OFI_MAX(RDMA_CTRL_MSG_MAX_SIZE, eager_max_size); /* Store endpoint in thread-local variable */ pthread_setspecific(device->ep_key, (void *)ep); @@ -5983,7 +6191,9 @@ int nccl_net_ofi_rdma_init(const char *provider_filter, ret = ncclInvalidArgument; goto error; } - eager_max_size = (size_t) ofi_nccl_eager_max_size(); + /* Eager TODO: support eager_max_size */ + // eager_max_size = (size_t) ofi_nccl_eager_max_size(); + eager_max_size = 0; plugin = malloc(sizeof(nccl_net_ofi_plugin_t)); if (!plugin) { From 5fddb2e86a026f0a9ff085ce6830be0eca26360c Mon Sep 17 00:00:00 2001 From: Eric Raut Date: Sat, 24 Feb 2024 01:15:06 +0000 Subject: [PATCH 6/6] Advertise multi-recv support to NCCL for RDMA protocol RDMA protocol will now support up to 8 multi-recv buffers at a time. Signed-off-by: Eric Raut --- include/nccl_ofi.h | 3 ++- src/nccl_ofi_net.c | 2 +- src/nccl_ofi_rdma.c | 3 +++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/include/nccl_ofi.h b/include/nccl_ofi.h index 497deea02..7212f86cc 100644 --- a/include/nccl_ofi.h +++ b/include/nccl_ofi.h @@ -59,7 +59,8 @@ extern "C" { #define MIN_TAG_BITS_FOR_RING_ID (32 + 1) /* Maximum number of grouped receives */ -#define NCCL_OFI_MAX_RECVS 1 +#define NCCL_OFI_MAX_RECVS 8 +#define NCCL_OFI_MAX_RECVS_SENDRECV 1 /* * This defines a higher value than maximum inflight requests supported by NCCL diff --git a/src/nccl_ofi_net.c b/src/nccl_ofi_net.c index 9c6f57049..d4d7b6056 100644 --- a/src/nccl_ofi_net.c +++ b/src/nccl_ofi_net.c @@ -331,7 +331,7 @@ static int set_nic_props_default(int dev_id, struct fi_info *nic_prov, * impacted with this feature as NCCL doesn't aggregate receives from * same source. */ - props->max_group_receives = NCCL_OFI_MAX_RECVS; + props->max_group_receives = NCCL_OFI_MAX_RECVS_SENDRECV; if (support_gdr == GDR_SUPPORTED) { props->hmem_support = true; diff --git a/src/nccl_ofi_rdma.c b/src/nccl_ofi_rdma.c index aae931d62..337764b1e 100644 --- a/src/nccl_ofi_rdma.c +++ b/src/nccl_ofi_rdma.c @@ -597,6 +597,9 @@ static inline int get_properties(nccl_net_ofi_device_t *base_dev, struct fi_info *info = device->device_rails[0].info; int ret = nccl_net_ofi_info_properties(info, dev_id, base_dev->plugin->num_devs, props); + /* Multi-recv adjustment */ + props->max_group_receives = NCCL_OFI_MAX_RECVS; + /* Scale speed by the total number of rails. Assume that all * reails have the same speed. */ if (ret == 0) {