From a51a80bb8d09d1a5283a4f29c26f2d1b9b19daa2 Mon Sep 17 00:00:00 2001 From: Peter Stokes Date: Fri, 5 Jul 2024 10:25:14 +0100 Subject: [PATCH] Address review comments --- tornado/netutil.py | 3 ++ tornado/test/httpserver_test.py | 59 ++++++++++++++++++--------------- 2 files changed, 35 insertions(+), 27 deletions(-) diff --git a/tornado/netutil.py b/tornado/netutil.py index b0cf21c5a3..b0cb7a444c 100644 --- a/tornado/netutil.py +++ b/tornado/netutil.py @@ -209,6 +209,9 @@ def bind_unix_socket( # Hurd doesn't support SO_REUSEADDR raise sock.setblocking(False) + # File names comprising of an initial null-byte denote an abstract + # namespace, on Linux, and therefore are not subject to file system + # orientated processing. if not file.startswith("\0"): try: st = os.stat(file) diff --git a/tornado/test/httpserver_test.py b/tornado/test/httpserver_test.py index 1e32729bab..36b8d12208 100644 --- a/tornado/test/httpserver_test.py +++ b/tornado/test/httpserver_test.py @@ -44,6 +44,7 @@ import textwrap import unittest import urllib.parse +import uuid from io import BytesIO import typing @@ -815,11 +816,7 @@ def test_manual_protocol(self): self.assertEqual(self.fetch_json("/")["protocol"], "https") -@unittest.skipIf( - not hasattr(socket, "AF_UNIX") or sys.platform == "cygwin", - "unix sockets not supported on this platform", -) -class UnixSocketTest(AsyncTestCase): +class UnixSocketTest: """HTTPServers can listen on Unix sockets too. Why would you want to do this? Nginx can proxy to backends listening @@ -832,38 +829,19 @@ class UnixSocketTest(AsyncTestCase): def setUp(self): super().setUp() - self.tmpdir = tempfile.mkdtemp() - self.sockfile = os.path.join(self.tmpdir, "test.sock") app = Application([("/hello", HelloWorldRequestHandler)]) self.server = HTTPServer(app) - if sys.platform.startswith("linux"): - self.sockabstract = "\0" + os.path.basename(self.tmpdir) - self.server.add_socket(netutil.bind_unix_socket(self.sockabstract)) - self.server.add_socket(netutil.bind_unix_socket(self.sockfile)) + self.server.add_socket(netutil.bind_unix_socket(self.address)) def tearDown(self): self.io_loop.run_sync(self.server.close_all_connections) self.server.stop() - shutil.rmtree(self.tmpdir) super().tearDown() @gen_test def test_unix_socket(self): with closing(IOStream(socket.socket(socket.AF_UNIX))) as stream: - stream.connect(self.sockfile) - stream.write(b"GET /hello HTTP/1.0\r\n\r\n") - response = yield stream.read_until(b"\r\n") - self.assertEqual(response, b"HTTP/1.1 200 OK\r\n") - header_data = yield stream.read_until(b"\r\n\r\n") - headers = HTTPHeaders.parse(header_data.decode("latin1")) - body = yield stream.read_bytes(int(headers["Content-Length"])) - self.assertEqual(body, b"Hello world") - - @unittest.skipUnless(sys.platform.startswith("linux"), "requires Linux") - @gen_test - def test_unix_socket_abstract(self): - with closing(IOStream(socket.socket(socket.AF_UNIX))) as stream: - stream.connect(self.sockabstract) + stream.connect(self.address) stream.write(b"GET /hello HTTP/1.0\r\n\r\n") response = yield stream.read_until(b"\r\n") self.assertEqual(response, b"HTTP/1.1 200 OK\r\n") @@ -878,12 +856,39 @@ def test_unix_socket_bad_request(self): # empty string. with ExpectLog(gen_log, "Malformed HTTP message from", level=logging.INFO): with closing(IOStream(socket.socket(socket.AF_UNIX))) as stream: - stream.connect(self.sockfile) + stream.connect(self.address) stream.write(b"garbage\r\n\r\n") response = yield stream.read_until_close() self.assertEqual(response, b"HTTP/1.1 400 Bad Request\r\n\r\n") +@unittest.skipIf( + not hasattr(socket, "AF_UNIX") or sys.platform == "cygwin", + "unix sockets not supported on this platform" +) +class UnixSocketTestAbstract(UnixSocketTest, AsyncTestCase): + + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + self.address = os.path.join(self.tmpdir, "test.sock") + super().setUp() + + def tearDown(self): + super().tearDown() + shutil.rmtree(self.tmpdir) + + +@unittest.skipIf( + not (hasattr(socket, "AF_UNIX") and sys.platform.startswith("linux")), + "abstract namespace unix sockets not supported on this platform" +) +class UnixSocketTestFile(UnixSocketTest, AsyncTestCase): + + def setUp(self): + self.address = "\0" + uuid.uuid4().hex + super().setUp() + + class KeepAliveTest(AsyncHTTPTestCase): """Tests various scenarios for HTTP 1.1 keep-alive support.