Skip to content

Commit

Permalink
Ensure text message exists before handling on WebsocketConsumer (#2097
Browse files Browse the repository at this point in the history
)

* fix(channels/generic): ensure text message exists before deciding to handle

* tests(channels/generic): regression test for double check of text message None

* refactor(channels/generic): short condition

* lint: fix flake8 errors
  • Loading branch information
cacosandon authored Sep 4, 2024
1 parent 3ea0817 commit 643d083
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 2 deletions.
4 changes: 2 additions & 2 deletions channels/generic/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def websocket_receive(self, message):
Called when a WebSocket frame is received. Decodes it and passes it
to receive().
"""
if "text" in message:
if message.get("text") is not None:
self.receive(text_data=message["text"])
else:
self.receive(bytes_data=message["bytes"])
Expand Down Expand Up @@ -200,7 +200,7 @@ async def websocket_receive(self, message):
Called when a WebSocket frame is received. Decodes it and passes it
to receive().
"""
if "text" in message:
if message.get("text") is not None:
await self.receive(text_data=message["text"])
else:
await self.receive(bytes_data=message["bytes"])
Expand Down
51 changes: 51 additions & 0 deletions tests/test_generic_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,3 +485,54 @@ async def connect(self):
assert msg["type"] == "websocket.close"
assert msg["code"] == 4007
assert msg["reason"] == "test reason"


@pytest.mark.django_db
@pytest.mark.asyncio
async def test_websocket_receive_with_none_text():
"""
Tests that the receive method handles messages with None text data correctly.
"""

class TestConsumer(WebsocketConsumer):
def receive(self, text_data=None, bytes_data=None):
if text_data:
self.send(text_data="Received text: " + text_data)
elif bytes_data:
self.send(text_data=f"Received bytes of length: {len(bytes_data)}")

app = TestConsumer()

# Open a connection
communicator = WebsocketCommunicator(app, "/testws/")
connected, _ = await communicator.connect()
assert connected

# Simulate Hypercorn behavior
# (both 'text' and 'bytes' keys present, but 'text' is None)
await communicator.send_input(
{
"type": "websocket.receive",
"text": None,
"bytes": b"test data",
}
)
response = await communicator.receive_output()
assert response["type"] == "websocket.send"
assert response["text"] == "Received bytes of length: 9"

# Test with only 'bytes' key (simulating uvicorn/daphne behavior)
await communicator.send_input({"type": "websocket.receive", "bytes": b"more data"})
response = await communicator.receive_output()
assert response["type"] == "websocket.send"
assert response["text"] == "Received bytes of length: 9"

# Test with valid text data
await communicator.send_input(
{"type": "websocket.receive", "text": "Hello, world!"}
)
response = await communicator.receive_output()
assert response["type"] == "websocket.send"
assert response["text"] == "Received text: Hello, world!"

await communicator.disconnect()

0 comments on commit 643d083

Please sign in to comment.