diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 775c385e..9b65bf7e 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -9,6 +9,8 @@ dev - Support for Python 3.9 has been removed. - Support for PyPy 3.9 has been removed. - `Stream.end_stream()` now raises `NoSuchStreamError` or `StreamClosedError` exceptions, instead of a generic `KeyError`. +- Duplicate ``content-length`` headers with different values now raise ``ProtocolError``. + Previously, the first ``content-length`` header was accepted and later conflicting values were ignored. - **backfill from v4.3.0** Convert emitted events into Python `dataclass`, which introduces new constructors with required arguments. Instantiating these events without arguments, as previously commonly used API pattern, will no longer work. diff --git a/src/h2/stream.py b/src/h2/stream.py index 9b5bce78..f8dc2bf2 100644 --- a/src/h2/stream.py +++ b/src/h2/stream.py @@ -1364,15 +1364,23 @@ def _initialize_content_length(self, headers: Iterable[Header]) -> None: self._expected_content_length = 0 return + content_length = None + for n, v in headers: if n == b"content-length": try: - self._expected_content_length = int(v, 10) + parsed_content_length = int(v, 10) except ValueError as err: msg = f"Invalid content-length header: {v!r}" raise ProtocolError(msg) from err - return + if content_length is None: + content_length = parsed_content_length + elif parsed_content_length != content_length: + msg = f"Conflicting content-length headers: {content_length} and {parsed_content_length}" + raise ProtocolError(msg) + + self._expected_content_length = content_length def _track_content_length(self, length: int, end_stream: bool) -> None: """ diff --git a/tests/test_invalid_content_lengths.py b/tests/test_invalid_content_lengths.py index 39401ea2..4242c94e 100644 --- a/tests/test_invalid_content_lengths.py +++ b/tests/test_invalid_content_lengths.py @@ -19,18 +19,24 @@ class TestInvalidContentLengths: peer is not valid. """ - example_request_headers = [ + example_request_headers_without_content_length = [ (":authority", "example.com"), (":path", "/"), (":scheme", "https"), (":method", "POST"), + ] + example_request_headers = [ + *example_request_headers_without_content_length, ("content-length", "15"), ] - example_request_headers_bytes = [ + example_request_headers_bytes_without_content_length = [ (b":authority", b"example.com"), (b":path", b"/"), (b":scheme", b"https"), (b":method", b"POST"), + ] + example_request_headers_bytes = [ + *example_request_headers_bytes_without_content_length, (b"content-length", b"15"), ] example_response_headers = [ @@ -39,6 +45,78 @@ class TestInvalidContentLengths: ] server_config = h2.config.H2Configuration(client_side=False) + @pytest.mark.parametrize( + "request_headers", + [ + example_request_headers_without_content_length, + example_request_headers_bytes_without_content_length, + ], + ) + def test_duplicate_matching_content_lengths(self, frame_factory, request_headers) -> None: + """ + Remote peers sending duplicate matching content-length fields are + accepted. + """ + c = h2.connection.H2Connection(config=self.server_config) + c.initiate_connection() + c.receive_data(frame_factory.preamble()) + c.clear_outbound_data_buffer() + + headers = frame_factory.build_headers_frame( + headers=[ + *request_headers, + ("content-length", "15"), + ("content-length", "15"), + ], + ) + data = frame_factory.build_data_frame( + data=b"\x01"*15, + flags=["END_STREAM"], + ) + + events = c.receive_data(headers.serialize() + data.serialize()) + + assert isinstance(events[0], h2.events.RequestReceived) + assert isinstance(events[1], h2.events.DataReceived) + assert isinstance(events[2], h2.events.StreamEnded) + assert c.data_to_send() == b"" + + @pytest.mark.parametrize( + "request_headers", + [ + example_request_headers_without_content_length, + example_request_headers_bytes_without_content_length, + ], + ) + def test_duplicate_conflicting_content_lengths(self, frame_factory, request_headers) -> None: + """ + Remote peers sending duplicate conflicting content-length fields cause + Protocol Errors. + """ + c = h2.connection.H2Connection(config=self.server_config) + c.initiate_connection() + c.receive_data(frame_factory.preamble()) + c.clear_outbound_data_buffer() + + headers = frame_factory.build_headers_frame( + headers=[ + *request_headers, + ("content-length", "15"), + ("content-length", "16"), + ], + ) + with pytest.raises( + h2.exceptions.ProtocolError, + match="Conflicting content-length headers: 15 and 16", + ): + c.receive_data(headers.serialize()) + + expected_frame = frame_factory.build_goaway_frame( + last_stream_id=1, + error_code=h2.errors.ErrorCodes.PROTOCOL_ERROR, + ) + assert c.data_to_send() == expected_frame.serialize() + @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) def test_too_much_data(self, frame_factory, request_headers) -> None: """