Skip to content

Commit

Permalink
feat(tracing): add nvtx provider
Browse files Browse the repository at this point in the history
Hook nvtx on existing lttng macros.

We figured out how to structure this in a way that
aligns the required usages of nvtx with cases
like NCCL_OFI_TRACE_SEND_WRITE_SEG_START/_COMPLETE. We use the NVTX
start/end API for ranges, and mark API for events.

Only supports RDMA protocol for now, SENDRECV protocol NVTX support will
be added in the future.

Signed-off-by: Eric Raut <[email protected]>
  • Loading branch information
rauteric committed Apr 11, 2024
1 parent 774a14d commit 04d5d66
Show file tree
Hide file tree
Showing 5 changed files with 313 additions and 18 deletions.
1 change: 1 addition & 0 deletions include/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ noinst_HEADERS = \
nccl_ofi_ofiutils.h \
nccl_ofi_tracepoint.h \
tracing_impl/lttng.h \
tracing_impl/nvtx.h \
nccl-headers/net.h \
nccl-headers/error.h \
nccl-headers/nvidia/err.h \
Expand Down
24 changes: 24 additions & 0 deletions include/nccl_ofi_rdma.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ extern "C" {
#include "nccl_ofi_deque.h"
#include "nccl_ofi_freelist.h"
#include "nccl_ofi_idpool.h"
#include "nccl_ofi_tracepoint.h"

/* Maximum number of rails supported. This defines the size of
* messages exchanged during connection establishment (linear
Expand Down Expand Up @@ -170,6 +171,10 @@ typedef struct {
/* Total number of completions. Expect one completion for receiving the
* control message and one completion for each send segment. */
int total_num_compls;
#if HAVE_NVTX_TRACING
nvtxRangeId_t trace_id;
nvtxRangeId_t seg_trace_id[MAX_NUM_RAILS];
#endif
} rdma_req_send_data_t;

/*
Expand All @@ -184,6 +189,9 @@ typedef struct {
nccl_net_ofi_schedule_t *ctrl_schedule;
/* Pointer to recv parent request */
nccl_net_ofi_rdma_req_t *recv_req;
#if HAVE_NVTX_TRACING
nvtxRangeId_t trace_id;
#endif
} rdma_req_send_ctrl_data_t;

typedef struct {
Expand Down Expand Up @@ -224,6 +232,9 @@ typedef struct {
* For eager messages, the second completion will be received
* when the local read into the destination buffer is complete */
int total_num_compls;
#if HAVE_NVTX_TRACING
nvtxRangeId_t trace_id;
#endif
} rdma_req_recv_data_t;

/*
Expand Down Expand Up @@ -403,8 +414,13 @@ typedef struct nccl_net_ofi_rdma_send_comm {
* and `num_init_rails' is adjusted. */
int num_init_rails;

#if HAVE_NVTX_TRACING
nvtxDomainHandle_t nvtx_domain[NCCL_OFI_N_NVTX_DOMAIN_PER_COMM];
#endif

/* Array of `num_rails` communicator rails */
nccl_net_ofi_rdma_send_comm_rail_t rails[];

} nccl_net_ofi_rdma_send_comm_t;

/*
Expand Down Expand Up @@ -465,6 +481,10 @@ typedef struct nccl_net_ofi_rdma_recv_comm {
/* Free list to track control buffers, for sending RDMA control messages */
nccl_ofi_freelist_t *ctrl_buff_fl;

#if HAVE_NVTX_TRACING
nvtxDomainHandle_t nvtx_domain[NCCL_OFI_N_NVTX_DOMAIN_PER_COMM];
#endif

/* Number of rails */
int num_rails;

Expand Down Expand Up @@ -659,6 +679,10 @@ typedef struct nccl_net_ofi_rdma_device {

/* Memory registration key pool */
nccl_ofi_idpool_t key_pool;

#if HAVE_NVTX_TRACING
nvtxDomainHandle_t nvtx_domain[MAX_NUM_RAILS];
#endif
} nccl_net_ofi_rdma_device_t;

/*
Expand Down
62 changes: 45 additions & 17 deletions include/nccl_ofi_tracepoint.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#define NCCL_OFI_TRACEPOINT_H_

#include "config.h"
#include "tracing_impl/nvtx.h"
#include "tracing_impl/lttng.h"

/***** SENDRECV PROTOCOL *****/
Expand All @@ -27,52 +28,79 @@
} while(0)

/***** RDMA PROTOCOL *****/

#define NCCL_OFI_TRACE_SEND(dev, size, comm, msg_seq_num, request, nccl_req) do { \
lttng_ust_tracepoint(nccl_ofi_plugin, Send, dev, size, comm, msg_seq_num, request, nccl_req); \
} while(0)
NCCL_OFI_TRACE_SEND_NVTX(dev, size, comm, msg_seq_num, request, nccl_req); \
} while(0)

/*
 * End-of-send tracepoint. Only forwards to the NVTX provider macro; no
 * LTTng event is emitted from here.
 */
#define NCCL_OFI_TRACE_SEND_END(request) \
	do { \
		NCCL_OFI_TRACE_SEND_END_NVTX(request); \
	} while (0)

#define NCCL_OFI_TRACE_SEND_CTRL_RECV(dev, rail_id, comm, msg_seq_num) do { \
lttng_ust_tracepoint(nccl_ofi_plugin, Send_ctrl_recv, dev, rail_id, comm, msg_seq_num); \
} while (0)
lttng_ust_tracepoint(nccl_ofi_plugin, Send_ctrl_recv, dev, rail_id, comm, msg_seq_num); \
NCCL_OFI_TRACE_SEND_CTRL_RECV_NVTX(dev, rail_id, comm, msg_seq_num); \
} while (0)

/*
 * Start-of-control-message-send tracepoint. Only forwards to the NVTX
 * provider macro; no LTTng event is emitted from here.
 *
 * The expansion deliberately has no trailing semicolon: the caller supplies
 * it, which keeps the macro usable as the body of an unbraced if/else.
 */
#define NCCL_OFI_TRACE_SEND_CTRL_START(dev, rail_id, comm, req, msg_seq_num) do { \
	NCCL_OFI_TRACE_SEND_CTRL_START_NVTX(dev, rail_id, comm, req, msg_seq_num); \
} while (0)

/*
 * End-of-control-message-send tracepoint. Only forwards to the NVTX
 * provider macro; no LTTng event is emitted from here.
 *
 * The expansion deliberately has no trailing semicolon: the caller supplies
 * it, which keeps the macro usable as the body of an unbraced if/else.
 */
#define NCCL_OFI_TRACE_SEND_CTRL_END(dev, rail_id, comm, req, msg_seq_num) do { \
	NCCL_OFI_TRACE_SEND_CTRL_END_NVTX(dev, rail_id, comm, req, msg_seq_num); \
} while (0)

#define NCCL_OFI_TRACE_SEND_WRITE_SEG_START(dev, rail_id, size, comm, msg_seq_num, request) do { \
lttng_ust_tracepoint(nccl_ofi_plugin, Send_write_segment_start, dev, rail_id, size, comm, msg_seq_num, request); \
} while(0)
lttng_ust_tracepoint(nccl_ofi_plugin, Send_write_segment_start, dev, rail_id, size, comm, msg_seq_num, request); \
NCCL_OFI_TRACE_SEND_WRITE_SEG_START_NVTX(dev, rail_id, size, comm, msg_seq_num, request); \
} while(0)

#define NCCL_OFI_TRACE_SEND_WRITE_SEG_COMPLETE(dev, rail_id, comm, msg_seq_num, request) do { \
lttng_ust_tracepoint(nccl_ofi_plugin, Send_write_segment_complete, dev, rail_id, comm, msg_seq_num, request); \
} while(0)
NCCL_OFI_TRACE_SEND_WRITE_SEG_COMPLETE_NVTX(dev, rail_id, comm, msg_seq_num, request); \
} while(0)

#define NCCL_OFI_TRACE_RECV(dev, tag, size, request, nccl_req) do { \
lttng_ust_tracepoint(nccl_ofi_plugin, Recv, dev, tag, size, request, nccl_req); \
} while(0)
NCCL_OFI_TRACE_RECV_NVTX(dev, tag, size, request, nccl_req); \
} while(0)

/*
 * End-of-receive tracepoint. Only forwards to the NVTX provider macro; no
 * LTTng event is emitted from here.
 */
#define NCCL_OFI_TRACE_RECV_END(request) \
	do { \
		NCCL_OFI_TRACE_RECV_END_NVTX(request); \
	} while (0)

#define NCCL_OFI_TRACE_RECV_CTRL_SEND_COMPLETE(request) do { \
lttng_ust_tracepoint(nccl_ofi_plugin, Recv_ctrl_send_complete, request); \
} while(0)
lttng_ust_tracepoint(nccl_ofi_plugin, Recv_ctrl_send_complete, request); \
} while(0)

#define NCCL_OFI_TRACE_RECV_SEGMENT_COMPLETE(dev, rail_id, size, request) do { \
lttng_ust_tracepoint(nccl_ofi_plugin, Recv_segment_complete, dev, rail_id, size, request); \
} while(0)
NCCL_OFI_TRACE_RECV_SEGMENT_COMPLETE_NVTX(dev, rail_id, size, request); \
} while(0)

#define NCCL_OFI_TRACE_EAGER_RECV(dev, rail_id, comm, msg_seq_num) do { \
lttng_ust_tracepoint(nccl_ofi_plugin, Eager_recv, dev, rail_id, comm, msg_seq_num); \
} while(0)
lttng_ust_tracepoint(nccl_ofi_plugin, Eager_recv, dev, rail_id, comm, msg_seq_num); \
NCCL_OFI_TRACE_EAGER_RECV_NVTX(dev, rail_id, comm, msg_seq_num); \
} while(0)

#define NCCL_OFI_TRACE_COMPLETIONS(request,ctx) do { \
lttng_ust_tracepoint(nccl_ofi_plugin, ProcessCompletions, request,ctx); \
} while(0)
} while(0)

#define NCCL_OFI_TRACE_FLUSH(request, nccl_req) do { \
lttng_ust_tracepoint(nccl_ofi_plugin, Flush, request, nccl_req); \
} while(0)
NCCL_OFI_TRACE_FLUSH_NVTX(request, nccl_req); \
} while(0)

#define NCCL_OFI_TRACE_PENDING_INSERT(request) do { \
lttng_ust_tracepoint(nccl_ofi_plugin, Pending_queue_insert, request); \
} while(0)
NCCL_OFI_TRACE_PENDING_INSERT_NVTX(request); \
} while(0)

#define NCCL_OFI_TRACE_PENDING_REMOVE(request) do { \
lttng_ust_tracepoint(nccl_ofi_plugin, Pending_queue_remove, request); \
} while(0)
NCCL_OFI_TRACE_PENDING_REMOVE_NVTX(request); \
} while(0)

#endif /* NCCL_OFI_TRACEPOINT_H_ */
#endif /* NCCL_OFI_TRACEPOINT_H_ */
190 changes: 190 additions & 0 deletions include/tracing_impl/nvtx.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
/*
* Copyright (c) 2022-2024 Amazon.com, Inc. or its affiliates. All rights reserved.
*/

#ifndef NVTX_H
#define NVTX_H

#if HAVE_NVTX_TRACING
#include "nvToolsExt.h"

/*
 * Emit an instantaneous NVTX marker through nvtxDomainMarkEx().
 *
 * @param domain  NVTX domain handle passed through to nvtxDomainMarkEx()
 * @param name    ASCII message attached to the marker
 * @param color   color value, tagged as NVTX_COLOR_ARGB
 */
static inline void nvtx_mark_domain(nvtxDomainHandle_t domain, const char* name, uint32_t color)
{
	/* Fields not set below stay zero. */
	nvtxEventAttributes_t attr = {0};

	attr.version = NVTX_VERSION;
	attr.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE;
	attr.colorType = NVTX_COLOR_ARGB;
	attr.color = color;
	attr.messageType = NVTX_MESSAGE_TYPE_ASCII;
	attr.message.ascii = name;

	nvtxDomainMarkEx(domain, &attr);
}

/*
 * Open an NVTX range and return its id.
 *
 * @param have_domain  true: start the range in `domain` via
 *                     nvtxDomainRangeStartEx(); false: ignore `domain` and
 *                     use nvtxRangeStartEx()
 * @param domain       NVTX domain handle (only used when have_domain is true)
 * @param name         ASCII message attached to the range
 * @param color        color value, tagged as NVTX_COLOR_ARGB
 * @return             range id reported by the NVTX start call
 */
static inline nvtxRangeId_t nvtx_start_domain(bool have_domain, nvtxDomainHandle_t domain, const char* name, uint32_t color) {
	/* Fields not set below stay zero. */
	nvtxEventAttributes_t attr = {0};

	attr.version = NVTX_VERSION;
	attr.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE;
	attr.colorType = NVTX_COLOR_ARGB;
	attr.color = color;
	attr.messageType = NVTX_MESSAGE_TYPE_ASCII;
	attr.message.ascii = name;

	if (!have_domain) {
		return nvtxRangeStartEx(&attr);
	}
	return nvtxDomainRangeStartEx(domain, &attr);
}

/*
 * Open an NVTX range without a domain handle.
 *
 * Thin wrapper around nvtx_start_domain() with have_domain == false.
 *
 * @param name   ASCII message attached to the range
 * @param color  color value forwarded to nvtx_start_domain()
 * @return       range id returned by nvtx_start_domain()
 */
static inline nvtxRangeId_t nvtx_start(const char* name, uint32_t color) {
	const bool use_domain = false;

	/* The domain argument is ignored when use_domain is false. */
	return nvtx_start_domain(use_domain, 0, name, color);
}

/*
 * Close an NVTX range in `domain` via nvtxDomainRangeEnd().
 *
 * @param domain  NVTX domain handle passed to nvtxDomainRangeEnd()
 * @param id      range id to close (e.g. one returned by nvtx_start_domain())
 */
static inline void nvtx_end_domain(nvtxDomainHandle_t domain, nvtxRangeId_t id)
{
	nvtxDomainRangeEnd(domain, id);
}

/*
 * Close an NVTX range via nvtxRangeEnd() (no domain handle).
 *
 * @param id  range id to close (e.g. one returned by nvtx_start())
 */
static inline void nvtx_end(nvtxRangeId_t id)
{
	nvtxRangeEnd(id);
}

/*
 * Open the "Send" NVTX range for a send request.
 *
 * Active only when NCCL_OFI_NVTX_TRACE_PER_COMM is non-zero. The domain is
 * picked from the send communicator's nvtx_domain[] array, indexed by
 * msg_seq_num modulo NCCL_OFI_N_NVTX_DOMAIN_PER_COMM, and the resulting
 * range id is stored in get_send_data(request)->trace_id.
 *
 * @param comm         communicator pointer; cast to nccl_net_ofi_rdma_send_comm_t *
 * @param msg_seq_num  message sequence number (any integer expression)
 * @param request      request handed to get_send_data()
 * dev, size and nccl_req are accepted for signature parity and not used.
 *
 * Macro arguments are parenthesized so that expressions such as `seq + 1`
 * bind correctly against the cast and the `%` operator.
 */
#define NCCL_OFI_TRACE_SEND_NVTX(dev, size, comm, msg_seq_num, request, nccl_req) do { \
	if (NCCL_OFI_NVTX_TRACE_PER_COMM) { \
		nvtxDomainHandle_t handle = ((nccl_net_ofi_rdma_send_comm_t *)(comm)) \
			->nvtx_domain[(msg_seq_num) % NCCL_OFI_N_NVTX_DOMAIN_PER_COMM]; \
		get_send_data(request)->trace_id = nvtx_start_domain(true, handle, "Send", 0xeb9234); \
	} \
} while (0)

/*
 * Close the "Send" NVTX range of a send request.
 *
 * Active only when NCCL_OFI_NVTX_TRACE_PER_COMM is non-zero. The domain is
 * re-derived from request->comm (cast to nccl_net_ofi_rdma_send_comm_t *)
 * and request->msg_seq_num, and the range id is read back from
 * get_send_data(request)->trace_id.
 *
 * @param request  pointer expression to the request; it is expanded more
 *                 than once, so it must be free of side effects
 */
#define NCCL_OFI_TRACE_SEND_END_NVTX(request) do { \
	if (NCCL_OFI_NVTX_TRACE_PER_COMM) { \
		nvtxDomainHandle_t handle = ((nccl_net_ofi_rdma_send_comm_t *)((request)->comm)) \
			->nvtx_domain[(request)->msg_seq_num % NCCL_OFI_N_NVTX_DOMAIN_PER_COMM]; \
		nvtx_end_domain(handle, get_send_data(request)->trace_id); \
	} \
} while (0)

/*
 * Emit a "Send_ctrl_recv" NVTX marker.
 *
 * - If NCCL_OFI_NVTX_TRACE_PER_COMM is non-zero, the marker goes to the send
 *   communicator's domain selected by msg_seq_num modulo
 *   NCCL_OFI_N_NVTX_DOMAIN_PER_COMM.
 * - If NCCL_OFI_NVTX_TRACE_PER_DEV is non-zero, a marker goes to the device's
 *   domain for `rail_id`, reached through comm->base.base.ep->device.
 * Both markers are emitted when both switches are on.
 *
 * @param rail_id      index into the device's nvtx_domain[] array
 * @param comm         communicator pointer; must expose base.base.ep->device
 * @param msg_seq_num  message sequence number (any integer expression)
 * dev is accepted for signature parity and not used.
 *
 * Macro arguments are parenthesized so arbitrary expressions bind correctly.
 */
#define NCCL_OFI_TRACE_SEND_CTRL_RECV_NVTX(dev, rail_id, comm, msg_seq_num) do { \
	nvtxDomainHandle_t handle; \
	if (NCCL_OFI_NVTX_TRACE_PER_COMM) { \
		handle = ((nccl_net_ofi_rdma_send_comm_t *)(comm))->nvtx_domain[(msg_seq_num) % NCCL_OFI_N_NVTX_DOMAIN_PER_COMM]; \
		nvtx_mark_domain(handle, "Send_ctrl_recv", 0x00ffff); \
	} \
	if (NCCL_OFI_NVTX_TRACE_PER_DEV) { \
		handle = ((nccl_net_ofi_rdma_device_t *)((comm)->base.base.ep->device))->nvtx_domain[rail_id]; \
		nvtx_mark_domain(handle, "Send_ctrl_recv", 0x00ffff); \
	} \
} while (0)

/*
 * Open a "Send_ctrl_start" NVTX range for a control-message request.
 *
 * - If NCCL_OFI_NVTX_TRACE_PER_COMM is non-zero, the range is started in the
 *   receive communicator's domain selected by msg_seq_num modulo
 *   NCCL_OFI_N_NVTX_DOMAIN_PER_COMM.
 * - If NCCL_OFI_NVTX_TRACE_PER_DEV is non-zero, the range is started in the
 *   device's domain for `rail_id`, reached through comm->ep->device.
 * In both cases the range id is written to get_send_ctrl_data(req)->trace_id.
 *
 * NOTE(review): if both switches are non-zero, the second assignment
 * overwrites the id of the first range, so only the per-device id survives.
 * Confirm the two modes are meant to be mutually exclusive.
 *
 * @param rail_id      index into the device's nvtx_domain[] array
 * @param comm         communicator pointer; must expose ep->device
 * @param req          request handed to get_send_ctrl_data()
 * @param msg_seq_num  message sequence number (any integer expression)
 * dev is accepted for signature parity and not used.
 *
 * Macro arguments are parenthesized so arbitrary expressions bind correctly.
 */
#define NCCL_OFI_TRACE_SEND_CTRL_START_NVTX(dev, rail_id, comm, req, msg_seq_num) do { \
	nvtxDomainHandle_t handle; \
	if (NCCL_OFI_NVTX_TRACE_PER_COMM) { \
		handle = ((nccl_net_ofi_rdma_recv_comm_t *)(comm))->nvtx_domain[(msg_seq_num) % NCCL_OFI_N_NVTX_DOMAIN_PER_COMM]; \
		get_send_ctrl_data(req)->trace_id = nvtx_start_domain(true, handle, "Send_ctrl_start", 0x00ffff); \
	} \
	if (NCCL_OFI_NVTX_TRACE_PER_DEV) { \
		handle = ((nccl_net_ofi_rdma_device_t *)((comm)->ep->device))->nvtx_domain[rail_id]; \
		get_send_ctrl_data(req)->trace_id = nvtx_start_domain(true, handle, "Send_ctrl_start", 0x00ffff); \
	} \
} while (0)

/*
 * Close the "Send_ctrl_start" NVTX range of a control-message request.
 *
 * - If NCCL_OFI_NVTX_TRACE_PER_COMM is non-zero, ends the range in the
 *   receive communicator's domain selected by msg_seq_num modulo
 *   NCCL_OFI_N_NVTX_DOMAIN_PER_COMM.
 * - If NCCL_OFI_NVTX_TRACE_PER_DEV is non-zero, ends the range in the
 *   device's domain for `rail_id`, reached through comm->ep->device.
 * Both branches pass the same id, get_send_ctrl_data(req)->trace_id.
 *
 * NOTE(review): with both switches non-zero, the single stored id is ended
 * in two different domains. Confirm the modes are mutually exclusive.
 *
 * @param rail_id      index into the device's nvtx_domain[] array
 * @param comm         communicator pointer; must expose ep->device
 * @param req          request handed to get_send_ctrl_data()
 * @param msg_seq_num  message sequence number (any integer expression)
 * dev is accepted for signature parity and not used.
 *
 * Macro arguments are parenthesized so arbitrary expressions bind correctly.
 */
#define NCCL_OFI_TRACE_SEND_CTRL_END_NVTX(dev, rail_id, comm, req, msg_seq_num) do { \
	nvtxDomainHandle_t handle; \
	if (NCCL_OFI_NVTX_TRACE_PER_COMM) { \
		handle = ((nccl_net_ofi_rdma_recv_comm_t *)(comm))->nvtx_domain[(msg_seq_num) % NCCL_OFI_N_NVTX_DOMAIN_PER_COMM]; \
		nvtx_end_domain(handle, get_send_ctrl_data(req)->trace_id); \
	} \
	if (NCCL_OFI_NVTX_TRACE_PER_DEV) { \
		handle = ((nccl_net_ofi_rdma_device_t *)((comm)->ep->device))->nvtx_domain[rail_id]; \
		nvtx_end_domain(handle, get_send_ctrl_data(req)->trace_id); \
	} \
} while (0)

/*
 * Open a "Send_write_seg" NVTX range for one segment of a send request.
 *
 * - If NCCL_OFI_NVTX_TRACE_PER_COMM is non-zero, the range is started in the
 *   send communicator's domain selected by msg_seq_num modulo
 *   NCCL_OFI_N_NVTX_DOMAIN_PER_COMM.
 * - If NCCL_OFI_NVTX_TRACE_PER_DEV is non-zero, the range is started in the
 *   device's domain for `rail_id`, reached through comm->ep->device.
 * In both cases the range id is written to
 * get_send_data(request)->seg_trace_id[rail_id].
 *
 * NOTE(review): if both switches are non-zero, the second assignment
 * overwrites the id of the first range. Confirm the two modes are meant to
 * be mutually exclusive.
 *
 * @param rail_id      rail index; selects both the device domain and the
 *                     seg_trace_id[] slot
 * @param comm         communicator pointer; must expose ep->device
 * @param msg_seq_num  message sequence number (any integer expression)
 * @param request      request handed to get_send_data()
 * dev and size are accepted for signature parity and not used.
 *
 * Macro arguments are parenthesized so arbitrary expressions bind correctly.
 */
#define NCCL_OFI_TRACE_SEND_WRITE_SEG_START_NVTX(dev, rail_id, size, comm, msg_seq_num, request) do { \
	nvtxDomainHandle_t handle; \
	if (NCCL_OFI_NVTX_TRACE_PER_COMM) { \
		handle = ((nccl_net_ofi_rdma_send_comm_t *)(comm))->nvtx_domain[(msg_seq_num) % NCCL_OFI_N_NVTX_DOMAIN_PER_COMM]; \
		get_send_data(request)->seg_trace_id[rail_id] = nvtx_start_domain(true, handle, "Send_write_seg", 0xff0000); \
	} \
	if (NCCL_OFI_NVTX_TRACE_PER_DEV) { \
		handle = ((nccl_net_ofi_rdma_device_t *)((comm)->ep->device))->nvtx_domain[rail_id]; \
		get_send_data(request)->seg_trace_id[rail_id] = nvtx_start_domain(true, handle, "Send_write_seg", 0xff0000); \
	} \
} while (0)

/*
 * Close the "Send_write_seg" NVTX range of one segment of a send request.
 *
 * - If NCCL_OFI_NVTX_TRACE_PER_COMM is non-zero, ends the range in the send
 *   communicator's domain selected by msg_seq_num modulo
 *   NCCL_OFI_N_NVTX_DOMAIN_PER_COMM.
 * - If NCCL_OFI_NVTX_TRACE_PER_DEV is non-zero, ends the range in the
 *   device's domain for `rail_id`, reached through comm->ep->device.
 * Both branches pass the same id,
 * get_send_data(request)->seg_trace_id[rail_id].
 *
 * NOTE(review): with both switches non-zero, the single stored id is ended
 * in two different domains. Confirm the modes are mutually exclusive.
 *
 * @param rail_id      rail index; selects both the device domain and the
 *                     seg_trace_id[] slot
 * @param comm         communicator pointer; must expose ep->device
 * @param msg_seq_num  message sequence number (any integer expression)
 * @param request      request handed to get_send_data()
 * dev is accepted for signature parity and not used.
 *
 * Macro arguments are parenthesized so arbitrary expressions bind correctly.
 */
#define NCCL_OFI_TRACE_SEND_WRITE_SEG_COMPLETE_NVTX(dev, rail_id, comm, msg_seq_num, request) do { \
	nvtxDomainHandle_t handle; \
	if (NCCL_OFI_NVTX_TRACE_PER_COMM) { \
		handle = ((nccl_net_ofi_rdma_send_comm_t *)(comm))->nvtx_domain[(msg_seq_num) % NCCL_OFI_N_NVTX_DOMAIN_PER_COMM]; \
		nvtx_end_domain(handle, get_send_data(request)->seg_trace_id[rail_id]); \
	} \
	if (NCCL_OFI_NVTX_TRACE_PER_DEV) { \
		handle = ((nccl_net_ofi_rdma_device_t *)((comm)->ep->device))->nvtx_domain[rail_id]; \
		nvtx_end_domain(handle, get_send_data(request)->seg_trace_id[rail_id]); \
	} \
} while (0)

/*
 * Open the "Recv" NVTX range for a receive request.
 *
 * Active only when NCCL_OFI_NVTX_TRACE_PER_COMM is non-zero. The domain is
 * taken from request->comm (cast to nccl_net_ofi_rdma_recv_comm_t *),
 * indexed by request->msg_seq_num modulo NCCL_OFI_N_NVTX_DOMAIN_PER_COMM.
 * The range id is stored in get_recv_data(request)->trace_id.
 *
 * The sequence number is read from the request itself rather than from an
 * identifier that happens to be in the caller's scope, so the macro is
 * self-contained and selects the same domain as NCCL_OFI_TRACE_RECV_END_NVTX.
 * NOTE(review): assumes request->msg_seq_num is already populated when this
 * tracepoint fires - confirm at the call sites.
 *
 * @param request  pointer expression to the request; it is expanded more
 *                 than once, so it must be free of side effects
 * dev, tag, size and nccl_req are accepted for signature parity and not used.
 */
#define NCCL_OFI_TRACE_RECV_NVTX(dev, tag, size, request, nccl_req) do { \
	if (NCCL_OFI_NVTX_TRACE_PER_COMM) { \
		nvtxDomainHandle_t handle = ((nccl_net_ofi_rdma_recv_comm_t *)(request)->comm) \
			->nvtx_domain[(request)->msg_seq_num % NCCL_OFI_N_NVTX_DOMAIN_PER_COMM]; \
		get_recv_data(request)->trace_id = nvtx_start_domain(true, handle, "Recv", 0x34EB37); \
	} \
} while (0)

/*
 * Close the "Recv" NVTX range of a receive request.
 *
 * Active only when NCCL_OFI_NVTX_TRACE_PER_COMM is non-zero. The domain is
 * re-derived from request->comm (cast to nccl_net_ofi_rdma_recv_comm_t *)
 * and request->msg_seq_num, and the range id is read back from
 * get_recv_data(request)->trace_id.
 *
 * @param request  pointer expression to the request; it is expanded more
 *                 than once, so it must be free of side effects
 */
#define NCCL_OFI_TRACE_RECV_END_NVTX(request) do { \
	if (NCCL_OFI_NVTX_TRACE_PER_COMM) { \
		nvtxDomainHandle_t handle = ((nccl_net_ofi_rdma_recv_comm_t *)(request)->comm) \
			->nvtx_domain[(request)->msg_seq_num % NCCL_OFI_N_NVTX_DOMAIN_PER_COMM]; \
		nvtx_end_domain(handle, get_recv_data(request)->trace_id); \
	} \
} while (0)

/*
 * Emit a "Recv_segment_complete" NVTX marker.
 *
 * - If NCCL_OFI_NVTX_TRACE_PER_COMM is non-zero, the marker goes to the
 *   receive communicator's domain selected by request->msg_seq_num modulo
 *   NCCL_OFI_N_NVTX_DOMAIN_PER_COMM.
 * - If NCCL_OFI_NVTX_TRACE_PER_DEV is non-zero, a marker goes to the device's
 *   domain for `rail_id`, reached through request->comm->ep->device.
 * Both markers are emitted when both switches are on.
 *
 * @param rail_id  index into the device's nvtx_domain[] array
 * @param request  pointer expression to the request; it is expanded more
 *                 than once, so it must be free of side effects
 * dev and size are accepted for signature parity and not used.
 */
#define NCCL_OFI_TRACE_RECV_SEGMENT_COMPLETE_NVTX(dev, rail_id, size, request) do { \
	nvtxDomainHandle_t handle; \
	if (NCCL_OFI_NVTX_TRACE_PER_COMM) { \
		handle = ((nccl_net_ofi_rdma_recv_comm_t *)(request)->comm)->nvtx_domain[(request)->msg_seq_num % NCCL_OFI_N_NVTX_DOMAIN_PER_COMM]; \
		nvtx_mark_domain(handle, "Recv_segment_complete", 0xff0000); \
	} \
	if (NCCL_OFI_NVTX_TRACE_PER_DEV) { \
		handle = ((nccl_net_ofi_rdma_device_t *)((request)->comm->ep->device))->nvtx_domain[rail_id]; \
		nvtx_mark_domain(handle, "Recv_segment_complete", 0xff0000); \
	} \
} while (0)

/*
 * Emit an "Eager_recv" NVTX marker with a NULL domain handle.
 * None of the macro arguments are used.
 */
#define NCCL_OFI_TRACE_EAGER_RECV_NVTX(dev, rail_id, comm, msg_seq_num) \
	do { \
		nvtx_mark_domain(NULL, "Eager_recv", 0x0000FF); \
	} while (0)

/*
 * Emit a "Flush" NVTX marker with a NULL domain handle.
 * None of the macro arguments are used.
 */
#define NCCL_OFI_TRACE_FLUSH_NVTX(request, nccl_req) \
	do { \
		nvtx_mark_domain(NULL, "Flush", 0xA52A2A); \
	} while (0)

/*
 * Emit a "Pending_insert" NVTX marker with a NULL domain handle.
 * The macro argument is not used.
 */
#define NCCL_OFI_TRACE_PENDING_INSERT_NVTX(request) \
	do { \
		nvtx_mark_domain(NULL, "Pending_insert", 0xFF8C00); \
	} while (0)

/*
 * Emit a "Pending_remove" NVTX marker with a NULL domain handle.
 * The macro argument is not used.
 */
#define NCCL_OFI_TRACE_PENDING_REMOVE_NVTX(request) \
	do { \
		nvtx_mark_domain(NULL, "Pending_remove", 0xFF8C00); \
	} while (0)

#else

/*
 * NVTX tracing is compiled out (HAVE_NVTX_TRACING is false): every provider
 * macro expands to nothing, so its arguments are never evaluated. One stub
 * per macro defined in the enabled branch above.
 */
#define NCCL_OFI_TRACE_SEND_NVTX(...)
#define NCCL_OFI_TRACE_SEND_END_NVTX(...)
#define NCCL_OFI_TRACE_SEND_CTRL_RECV_NVTX(...)
#define NCCL_OFI_TRACE_SEND_CTRL_START_NVTX(...)
#define NCCL_OFI_TRACE_SEND_CTRL_END_NVTX(...)
#define NCCL_OFI_TRACE_SEND_WRITE_SEG_START_NVTX(...)
#define NCCL_OFI_TRACE_SEND_WRITE_SEG_COMPLETE_NVTX(...)
#define NCCL_OFI_TRACE_RECV_NVTX(...)
#define NCCL_OFI_TRACE_RECV_END_NVTX(...)
#define NCCL_OFI_TRACE_RECV_SEGMENT_COMPLETE_NVTX(...)
#define NCCL_OFI_TRACE_EAGER_RECV_NVTX(...)
#define NCCL_OFI_TRACE_FLUSH_NVTX(...)
#define NCCL_OFI_TRACE_PENDING_INSERT_NVTX(...)
#define NCCL_OFI_TRACE_PENDING_REMOVE_NVTX(...)

#endif

#endif /* NVTX_H */
Loading

0 comments on commit 04d5d66

Please sign in to comment.