Add serialization registry #398

Open · wants to merge 1 commit into main
46 changes: 46 additions & 0 deletions README.rst
@@ -171,6 +171,7 @@ to 10, and all ``websocket.send!`` channels to 20:
If you want to enforce a matching order, use an ``OrderedDict`` as the
argument; channels will then be matched in the order the dict provides them.

.. _encryption:

``symmetric_encryption_keys``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -237,6 +238,51 @@ And then in your channels consumer, you can implement the handler:
async def redis_disconnect(self, *args):
    # Handle disconnect



``serializer_format``
~~~~~~~~~~~~~~~~~~~~~~
By default, every message sent to Redis is encoded using `msgpack <https://msgpack.org/>`_ (currently ``msgpack`` is a mandatory dependency of this package; it may become optional in a future release).
It is also possible to switch to `JSON <http://www.json.org/>`_:

.. code-block:: python

    CHANNEL_LAYERS = {
        "default": {
            "BACKEND": "channels_redis.core.RedisChannelLayer",
            "CONFIG": {
                "hosts": ["redis://:password@127.0.0.1:6379/0"],
                "serializer_format": "json",
            },
        },
    }


A custom serializer can be defined by either:

- extending ``channels_redis.serializers.BaseMessageSerializer`` and implementing the ``as_bytes`` and ``from_bytes`` methods, or
- using any class which accepts generic keyword arguments and provides ``serialize``/``deserialize`` methods

The serializer can then be registered (or used to override an existing format) via ``channels_redis.serializers.registry``:

.. code-block:: python

    from channels_redis.serializers import registry

    class MyFormatSerializer:
        def serialize(self, message):
            ...

        def deserialize(self, message):
            ...

    registry.register_serializer('myformat', MyFormatSerializer)
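
Once registered, the custom format can be selected through the ``serializer_format`` setting shown earlier; the name passed to ``register_serializer`` is the value to use:

.. code-block:: python

    CHANNEL_LAYERS = {
        "default": {
            "BACKEND": "channels_redis.core.RedisChannelLayer",
            "CONFIG": {
                "hosts": ["redis://:password@127.0.0.1:6379/0"],
                "serializer_format": "myformat",
            },
        },
    }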

**NOTE**: the registry allows the serializer class used for a given format to be overridden without any check or constraint, so pay attention to import order when using third-party serializers that may override a built-in format.


Serializers are also responsible for encryption when ``symmetric_encryption_keys`` (documented above) is used. When extending ``channels_redis.serializers.BaseMessageSerializer``, encryption is already handled by the base class, unless you override the ``serialize``/``deserialize`` methods: in that case you should call ``self.crypter.encrypt`` when serializing and ``self.crypter.decrypt`` when deserializing. A fully custom serializer should expect an optional sequence of keys to be passed via the ``symmetric_encryption_keys`` keyword argument.
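
As an illustration, here is a minimal sketch of the first approach (extending ``BaseMessageSerializer``). The ``PickleSerializer`` class and the ``"pickle"`` format name are hypothetical, and the exact ``as_bytes``/``from_bytes`` signatures are assumed; only the byte-level conversion is implemented, leaving encryption and prefix handling to the base class:

.. code-block:: python

    import pickle

    from channels_redis.serializers import BaseMessageSerializer, registry

    class PickleSerializer(BaseMessageSerializer):
        # Hypothetical example: only the raw byte conversion is defined
        # here; the base class is expected to apply encryption (when
        # ``symmetric_encryption_keys`` is configured) around these hooks.
        # Note that pickle should only be used with trusted data.
        def as_bytes(self, message, *args, **kwargs):
            return pickle.dumps(message)

        def from_bytes(self, message, *args, **kwargs):
            return pickle.loads(message)

    registry.register_serializer("pickle", PickleSerializer)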


Dependencies
------------

61 changes: 13 additions & 48 deletions channels_redis/core.py
@@ -1,20 +1,17 @@
import asyncio
import base64
import collections
import functools
import hashlib
import itertools
import logging
import random
import time
import uuid

import msgpack
from redis import asyncio as aioredis

from channels.exceptions import ChannelFull
from channels.layers import BaseChannelLayer

from .serializers import registry
from .utils import (
_close_redis,
_consistent_hash,
@@ -115,6 +112,8 @@ def __init__(
capacity=100,
channel_capacity=None,
symmetric_encryption_keys=None,
random_prefix_length=12,
serializer_format="msgpack",
):
# Store basic information
self.expiry = expiry
@@ -126,15 +125,21 @@ def __init__(
# Configure the host objects
self.hosts = decode_hosts(hosts)
self.ring_size = len(self.hosts)
# serialization
self._serializer = registry.get_serializer(
serializer_format,
# As we use a sorted set to expire messages we need to guarantee uniqueness, with 12 bytes.
random_prefix_length=random_prefix_length,
expiry=self.expiry,
symmetric_encryption_keys=symmetric_encryption_keys,
)
# Cached redis connection pools and the event loop they are from
self._layers = {}
# Normal channels choose a host index by cycling through the available hosts
self._receive_index_generator = itertools.cycle(range(len(self.hosts)))
self._send_index_generator = itertools.cycle(range(len(self.hosts)))
# Decide on a unique client prefix to use in ! sections
self.client_prefix = uuid.uuid4().hex
# Set up any encryption objects
self._setup_encryption(symmetric_encryption_keys)
# Number of coroutines trying to receive right now
self.receive_count = 0
# The receive lock
@@ -154,24 +159,6 @@ def __init__(
def create_pool(self, index):
return create_pool(self.hosts[index])

def _setup_encryption(self, symmetric_encryption_keys):
# See if we can do encryption if they asked
if symmetric_encryption_keys:
if isinstance(symmetric_encryption_keys, (str, bytes)):
raise ValueError(
"symmetric_encryption_keys must be a list of possible keys"
)
try:
from cryptography.fernet import MultiFernet
except ImportError:
raise ValueError(
"Cannot run with encryption without 'cryptography' installed."
)
sub_fernets = [self.make_fernet(key) for key in symmetric_encryption_keys]
self.crypter = MultiFernet(sub_fernets)
else:
self.crypter = None

### Channel layer API ###

extensions = ["groups", "flush"]
@@ -656,41 +643,19 @@ def serialize(self, message):
"""
Serializes message to a byte string.
"""
value = msgpack.packb(message, use_bin_type=True)
if self.crypter:
value = self.crypter.encrypt(value)

# As we use an sorted set to expire messages we need to guarantee uniqueness, with 12 bytes.
random_prefix = random.getrandbits(8 * 12).to_bytes(12, "big")
return random_prefix + value
return self._serializer.serialize(message)

def deserialize(self, message):
"""
Deserializes from a byte string.
"""
# Removes the random prefix
message = message[12:]

if self.crypter:
message = self.crypter.decrypt(message, self.expiry + 10)
return msgpack.unpackb(message, raw=False)
return self._serializer.deserialize(message)

### Internal functions ###

def consistent_hash(self, value):
return _consistent_hash(value, self.ring_size)

def make_fernet(self, key):
"""
Given a single encryption key, returns a Fernet instance using it.
"""
from cryptography.fernet import Fernet

if isinstance(key, str):
key = key.encode("utf8")
formatted_key = base64.urlsafe_b64encode(hashlib.sha256(key).digest())
return Fernet(formatted_key)

def __str__(self):
return f"{self.__class__.__name__}(hosts={self.hosts})"

19 changes: 15 additions & 4 deletions channels_redis/pubsub.py
@@ -3,9 +3,9 @@
import logging
import uuid

import msgpack
from redis import asyncio as aioredis

from .serializers import registry
from .utils import (
_close_redis,
_consistent_hash,
@@ -25,10 +25,21 @@ async def _async_proxy(obj, name, *args, **kwargs):


class RedisPubSubChannelLayer:
def __init__(self, *args, **kwargs) -> None:
def __init__(
self,
*args,
symmetric_encryption_keys=None,
serializer_format="msgpack",
**kwargs,
) -> None:
self._args = args
self._kwargs = kwargs
self._layers = {}
# serialization
self._serializer = registry.get_serializer(
serializer_format,
symmetric_encryption_keys=symmetric_encryption_keys,
)

def __getattr__(self, name):
if name in (
@@ -48,13 +59,13 @@ def serialize(self, message):
"""
Serializes message to a byte string.
"""
return msgpack.packb(message)
return self._serializer.serialize(message)

def deserialize(self, message):
"""
Deserializes from a byte string.
"""
return msgpack.unpackb(message)
return self._serializer.deserialize(message)

def _get_layer(self):
loop = asyncio.get_running_loop()