Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Dadeos-Menlo committed Jul 8, 2024
1 parent e169be7 commit a51a80b
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 27 deletions.
3 changes: 3 additions & 0 deletions tornado/netutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,9 @@ def bind_unix_socket(
# Hurd doesn't support SO_REUSEADDR
raise
sock.setblocking(False)
# File names comprising of an initial null-byte denote an abstract
# namespace, on Linux, and therefore are not subject to file system
# orientated processing.
if not file.startswith("\0"):
try:
st = os.stat(file)
Expand Down
59 changes: 32 additions & 27 deletions tornado/test/httpserver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import textwrap
import unittest
import urllib.parse
import uuid
from io import BytesIO

import typing
Expand Down Expand Up @@ -815,11 +816,7 @@ def test_manual_protocol(self):
self.assertEqual(self.fetch_json("/")["protocol"], "https")


@unittest.skipIf(
not hasattr(socket, "AF_UNIX") or sys.platform == "cygwin",
"unix sockets not supported on this platform",
)
class UnixSocketTest(AsyncTestCase):
class UnixSocketTest:
"""HTTPServers can listen on Unix sockets too.
Why would you want to do this? Nginx can proxy to backends listening
Expand All @@ -832,38 +829,19 @@ class UnixSocketTest(AsyncTestCase):

def setUp(self):
super().setUp()
self.tmpdir = tempfile.mkdtemp()
self.sockfile = os.path.join(self.tmpdir, "test.sock")
app = Application([("/hello", HelloWorldRequestHandler)])
self.server = HTTPServer(app)
if sys.platform.startswith("linux"):
self.sockabstract = "\0" + os.path.basename(self.tmpdir)
self.server.add_socket(netutil.bind_unix_socket(self.sockabstract))
self.server.add_socket(netutil.bind_unix_socket(self.sockfile))
self.server.add_socket(netutil.bind_unix_socket(self.address))

def tearDown(self):
self.io_loop.run_sync(self.server.close_all_connections)
self.server.stop()
shutil.rmtree(self.tmpdir)
super().tearDown()

@gen_test
def test_unix_socket(self):
with closing(IOStream(socket.socket(socket.AF_UNIX))) as stream:
stream.connect(self.sockfile)
stream.write(b"GET /hello HTTP/1.0\r\n\r\n")
response = yield stream.read_until(b"\r\n")
self.assertEqual(response, b"HTTP/1.1 200 OK\r\n")
header_data = yield stream.read_until(b"\r\n\r\n")
headers = HTTPHeaders.parse(header_data.decode("latin1"))
body = yield stream.read_bytes(int(headers["Content-Length"]))
self.assertEqual(body, b"Hello world")

@unittest.skipUnless(sys.platform.startswith("linux"), "requires Linux")
@gen_test
def test_unix_socket_abstract(self):
with closing(IOStream(socket.socket(socket.AF_UNIX))) as stream:
stream.connect(self.sockabstract)
stream.connect(self.address)
stream.write(b"GET /hello HTTP/1.0\r\n\r\n")
response = yield stream.read_until(b"\r\n")
self.assertEqual(response, b"HTTP/1.1 200 OK\r\n")
Expand All @@ -878,12 +856,39 @@ def test_unix_socket_bad_request(self):
# empty string.
with ExpectLog(gen_log, "Malformed HTTP message from", level=logging.INFO):
with closing(IOStream(socket.socket(socket.AF_UNIX))) as stream:
stream.connect(self.sockfile)
stream.connect(self.address)
stream.write(b"garbage\r\n\r\n")
response = yield stream.read_until_close()
self.assertEqual(response, b"HTTP/1.1 400 Bad Request\r\n\r\n")


@unittest.skipIf(
not hasattr(socket, "AF_UNIX") or sys.platform == "cygwin",
"unix sockets not supported on this platform"
)
class UnixSocketTestAbstract(UnixSocketTest, AsyncTestCase):

def setUp(self):
self.tmpdir = tempfile.mkdtemp()
self.address = os.path.join(self.tmpdir, "test.sock")
super().setUp()

def tearDown(self):
super().tearDown()
shutil.rmtree(self.tmpdir)


@unittest.skipIf(
not (hasattr(socket, "AF_UNIX") and sys.platform.startswith("linux")),
"abstract namespace unix sockets not supported on this platform"
)
class UnixSocketTestFile(UnixSocketTest, AsyncTestCase):

def setUp(self):
self.address = "\0" + uuid.uuid4().hex
super().setUp()


class KeepAliveTest(AsyncHTTPTestCase):
"""Tests various scenarios for HTTP 1.1 keep-alive support.
Expand Down

0 comments on commit a51a80b

Please sign in to comment.