Skip to content

Commit

Permalink
Fix a pytest bug (#196)
Browse files Browse the repository at this point in the history
  • Loading branch information
saeedmaleki authored Oct 13, 2023
1 parent 8c0f9e8 commit 148681b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
6 changes: 5 additions & 1 deletion python/test/test_mscclpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def test_connection_write(mpi_group: MpiGroup, transport: Transport, nelem: int)

@parametrize_mpi_groups(2, 4, 8, 16)
@pytest.mark.parametrize("transport", ["IB", "NVLink"])
-@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
+@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20, 27]])
@pytest.mark.parametrize("device", ["cuda", "cpu"])
def test_connection_write_and_signal(mpi_group: MpiGroup, transport: Transport, nelem: int, device: str):
# this test starts with a random tensor on rank 0 and rotates it all the way through all ranks
Expand All @@ -139,6 +139,8 @@ def test_connection_write_and_signal(mpi_group: MpiGroup, transport: Transport,
memory_expected = memory.copy()
else:
memory = xp.zeros(nelem, dtype=xp.float32)
+if device == "cuda":
+cp.cuda.runtime.deviceSynchronize()

signal_memory = xp.zeros(1, dtype=xp.int64)
all_reg_memories = group.register_tensor_with_connections(memory, connections)
Expand All @@ -156,6 +158,8 @@ def test_connection_write_and_signal(mpi_group: MpiGroup, transport: Transport,
connections[next_rank].flush()
if group.my_rank == 0:
memory[:] = 0
+if device == "cuda":
+cp.cuda.runtime.deviceSynchronize()
connections[next_rank].update_and_sync(
all_signal_memories[next_rank], 0, dummy_memory_on_cpu.ctypes.data, signal_val
)
Expand Down
4 changes: 2 additions & 2 deletions src/connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,8 @@ void IBConnection::flush(int64_t timeoutUsec) {
}

auto elapsed = timer.elapsed();
-if ((timeoutUsec >= 0) && (elapsed * 1e3 > timeoutUsec)) {
-throw Error("pollCq is stuck: waited for " + std::to_string(elapsed / 1e3) + " seconds. Expected " +
+if ((timeoutUsec >= 0) && (elapsed > timeoutUsec)) {
+throw Error("pollCq is stuck: waited for " + std::to_string(elapsed / 1e6) + " seconds. Expected " +
std::to_string(numSignaledSends) + " signals",
ErrorCode::InternalError);
}
Expand Down

0 comments on commit 148681b

Please sign in to comment.