Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rdma: support NCCL multi-recv interface #348

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion include/nccl_ofi.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ extern "C" {
#define MIN_TAG_BITS_FOR_RING_ID (32 + 1)

/* Maximum number of grouped receives */
#define NCCL_OFI_MAX_RECVS 1
#define NCCL_OFI_MAX_RECVS 8
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we're going to break out the two protocols (which makes sense), then we should move these into their corresponding header files. And we really shouldn't leave something named as NCCL_OFI_MAX_RECVS when it's only sometimes right.

#define NCCL_OFI_MAX_RECVS_SENDRECV 1

/*
* This defines a higher value than maximum inflight requests supported by NCCL
Expand Down
25 changes: 21 additions & 4 deletions include/nccl_ofi_msgbuff.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ typedef struct {
// Type of element
nccl_ofi_msgbuff_elemtype_t type;
void *elem;
// Multi-recv information
uint16_t multi_recv_size;
uint16_t multi_recv_start;
int multi_recv_tag;
} nccl_ofi_msgbuff_elem_t;

typedef struct {
Expand Down Expand Up @@ -110,9 +114,14 @@ bool nccl_ofi_msgbuff_destroy(nccl_ofi_msgbuff_t *msgbuff);
* NCCL_OFI_MSGBUFF_ERROR, other error
*/
nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_insert(nccl_ofi_msgbuff_t *msgbuff,
uint16_t msg_index, void *elem, nccl_ofi_msgbuff_elemtype_t type,
uint16_t msg_index, uint16_t multi_recv_start, uint16_t multi_recv_size, int multi_recv_tag,
void *elem, nccl_ofi_msgbuff_elemtype_t type,
nccl_ofi_msgbuff_status_t *msg_idx_status);

nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_insert_ctrl_multirecv(nccl_ofi_msgbuff_t *msgbuff,
uint16_t msg_base_index, uint16_t multi_recv_size, int *tags, void *elem,
nccl_ofi_msgbuff_elemtype_t type, nccl_ofi_msgbuff_status_t *msg_idx_status);

/**
* Replace an existing message element
*
Expand All @@ -126,8 +135,9 @@ nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_insert(nccl_ofi_msgbuff_t *msgbuff,
* NCCL_OFI_MSGBUFF_ERROR, other error
*/
nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_replace(nccl_ofi_msgbuff_t *msgbuff,
uint16_t msg_index, void *elem, nccl_ofi_msgbuff_elemtype_t type,
nccl_ofi_msgbuff_status_t *msg_idx_status);
uint16_t msg_index, uint16_t multi_recv_start, uint16_t multi_recv_size,
int multi_recv_tag, void *elem, nccl_ofi_msgbuff_elemtype_t type,
nccl_ofi_msgbuff_status_t *msg_idx_status, bool *multi_send_ready);

/**
* Retrieve message with given index
Expand All @@ -142,6 +152,12 @@ nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_replace(nccl_ofi_msgbuff_t *msgbuff,
* NCCL_OFI_MSGBUFF_ERROR, other error
*/
nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_retrieve(nccl_ofi_msgbuff_t *msgbuff,
uint16_t msg_index, uint16_t multi_recv_start, uint16_t multi_recv_size,
int multi_recv_tag, void **elem, nccl_ofi_msgbuff_elemtype_t *type,
nccl_ofi_msgbuff_status_t *msg_idx_status);

/* As above, but with no tag */
nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_retrieve_notag(nccl_ofi_msgbuff_t *msgbuff,
uint16_t msg_index, void **elem, nccl_ofi_msgbuff_elemtype_t *type,
nccl_ofi_msgbuff_status_t *msg_idx_status);

Expand All @@ -156,7 +172,8 @@ nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_retrieve(nccl_ofi_msgbuff_t *msgbuff,
* NCCL_OFI_MSGBUFF_ERROR, other error
*/
nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_complete(nccl_ofi_msgbuff_t *msgbuff,
uint16_t msg_index, nccl_ofi_msgbuff_status_t *msg_idx_status);
uint16_t msg_index, uint16_t multi_recv_start, uint16_t multi_recv_size,
int multi_recv_tag, nccl_ofi_msgbuff_status_t *msg_idx_status);

#ifdef _cplusplus
} // End extern "C"
Expand Down
39 changes: 36 additions & 3 deletions include/nccl_ofi_rdma.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,20 +74,32 @@ typedef struct nccl_net_ofi_rdma_mr_handle {
struct fid_mr *mr[];
} nccl_net_ofi_rdma_mr_handle_t;

/* Contents of ctrl message sent from receiver to sender to advertise
destination buffer */
typedef struct nccl_net_ofi_rdma_ctrl_msg {
typedef struct nccl_net_ofi_rdma_ctrl_msg_entry {
int multi_recv_tag;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

multirecv tag should go at the end for padding reasons.

uint64_t buff_addr;
uint64_t buff_len;
uint64_t buff_mr_key[MAX_NUM_RAILS];
} nccl_net_ofi_rdma_ctrl_msg_entry_t;

/* Contents of ctrl message sent from receiver to sender to advertise
destination buffer */
typedef struct nccl_net_ofi_rdma_ctrl_msg {
uint16_t msg_seq_num;
uint16_t multi_recv_size;
nccl_net_ofi_rdma_ctrl_msg_entry_t entries[];
} nccl_net_ofi_rdma_ctrl_msg_t;

#define RDMA_CTRL_MSG_ENTRIES_MAX_SIZE (NCCL_OFI_MAX_RECVS * sizeof(nccl_net_ofi_rdma_ctrl_msg_entry_t))
#define RDMA_CTRL_MSG_MAX_SIZE (sizeof(nccl_net_ofi_rdma_ctrl_msg_t) + RDMA_CTRL_MSG_ENTRIES_MAX_SIZE)

/* Structure used to store control messages in a free list */
typedef struct nccl_net_ofi_rdma_ctrl_fl_item {
nccl_ofi_freelist_reginfo_t fl_reginfo;
nccl_net_ofi_rdma_ctrl_msg_t ctrl_msg;
} nccl_net_ofi_rdma_ctrl_fl_item_t;

#define RDMA_CTRL_FL_ITEM_MAX_SIZE (sizeof(nccl_net_ofi_rdma_ctrl_fl_item_t) + RDMA_CTRL_MSG_ENTRIES_MAX_SIZE)

/* For LL/LL128 protocols, bounce buffers (source of RDMA read operations) need to be 128B aligned */
#define BOUNCE_BUFFER_ALIGNMENT 128

Expand Down Expand Up @@ -152,6 +164,13 @@ typedef struct {
/* Total number of completions. Expect one completion for receiving the
* control message and one completion for each send segment. */
int total_num_compls;

/* Multi-recv information */
uint16_t multi_recv_size;
uint16_t multi_recv_start;
int multi_recv_tag;
/* This may not match sender-side seq num with multi-recv */
uint16_t recv_side_msg_seq_num;
} rdma_req_send_data_t;

/*
Expand All @@ -166,6 +185,8 @@ typedef struct {
nccl_net_ofi_schedule_t *ctrl_schedule;
/* Pointer to recv parent request */
nccl_net_ofi_rdma_req_t *recv_req;
/* Size of ctrl message */
size_t ctrl_msg_size;
} rdma_req_send_ctrl_data_t;

typedef struct {
Expand Down Expand Up @@ -206,6 +227,12 @@ typedef struct {
* For eager messages, the second completion will be received
* when the local read into the destination buffer is complete */
int total_num_compls;
/* Multi-recv information */
uint16_t multi_recv_size;
uint16_t multi_recv_start;
int multi_recv_tag;
/* Next req in sequence */
nccl_net_ofi_rdma_req_t *multi_recv_next;
} rdma_req_recv_data_t;

/*
Expand Down Expand Up @@ -345,6 +372,12 @@ typedef struct nccl_net_ofi_rdma_send_comm {
/* Comm ID provided by remote endpoint */
uint64_t remote_comm_id;

/* Request to send connect message */
nccl_net_ofi_rdma_req_t *send_conn_req;

/* Indicates if connect message was delivered (and req freed) */
bool connect_msg_delivered;

/* Request to receive connect response message to finalize
* connection establishment */
nccl_net_ofi_rdma_req_t *conn_resp_req;
Expand Down
Loading