From ef99d40b5d35f91721e23781769c388c6afdbc6e Mon Sep 17 00:00:00 2001 From: Marcel Hellkamp Date: Wed, 19 Feb 2025 18:32:11 +0100 Subject: [PATCH] feat: Allow subclassing MultipartSegment This patch includes refactoring and changes of internal APIs to allow stabilizing those APIs in the future. --- CHANGELOG.rst | 3 ++ multipart.py | 103 ++++++++++++++++++++++++++++---------------------- 2 files changed, 60 insertions(+), 46 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 56319f2..0c1b948 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -16,6 +16,9 @@ Release 1.3 **Not released yet** * feat: Nicer error messages when reading from a closed ``MultipartPart``. +* feat: Support custom ``MultipartSegment`` subclasses to be used and emitted by + ``PushMultipartParser``. However, the API between parser and segment is not + stable yet. Code overriding any of the ``_on_*`` methods may break between releases. Release 1.2 =========== diff --git a/multipart.py b/multipart.py index e7b677f..22e0fd6 100644 --- a/multipart.py +++ b/multipart.py @@ -20,7 +20,7 @@ import re from io import BytesIO -from typing import Iterator, Union, Optional, Tuple, List +from typing import Generic, Iterator, Type, TypeVar, Union, Optional, Tuple, List from urllib.parse import parse_qs from wsgiref.headers import Headers from collections.abc import MutableMapping as DictMixin @@ -280,8 +280,10 @@ def parse_options_header(header, options=None, unquote=header_unquote): _BODY = "BODY" _COMPLETE = "END" +t_segment = TypeVar('t_segment', bound="MultipartSegment") + +class PushMultipartParser(Generic[t_segment]): -class PushMultipartParser: def __init__( self, boundary: Union[str, bytes], @@ -292,6 +294,7 @@ def __init__( max_segment_count=inf, # unlimited header_charset="utf8", strict=False, + segment_class: Optional[Type[t_segment]] = None, ): """A push-based (incremental, non-blocking) parser for multipart/form-data. 
@@ -311,6 +314,8 @@ def __init__( :param max_segment_count: Maximum number of segments. :param header_charset: Charset for header names and values. :param strict: Enables additional format and sanity checks. + + :param segment_class: Class for emitted segments, defaults to :class:`MultipartSegment`. """ self.boundary = to_bytes(boundary) self.content_length = content_length @@ -321,13 +326,17 @@ def __init__( self.max_segment_count = max_segment_count self.strict = strict - self._delimiter = b"\r\n--" + self.boundary + if segment_class and issubclass(segment_class, MultipartSegment): + self.segment_class = segment_class + else: + self.segment_class = MultipartSegment # Internal parser state + self._delimiter = b"\r\n--" + self.boundary self._parsed = 0 - self._fieldcount = 0 self._buffer = bytearray() - self._current = None + self._segment_count = 0 + self._segment = None self._state = _PREAMBLE #: True if the parser reached the end of the multipart stream, stopped @@ -344,7 +353,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): def parse( self, chunk: Union[bytes, bytearray] - ) -> Iterator[Union["MultipartSegment", bytearray, None]]: + ) -> Iterator[Union[t_segment, bytearray, None]]: """Parse a chunk of data and yield as many result objects as possible with the data given. 
@@ -406,7 +415,7 @@ def parse( tail = buffer[next_start-2 : next_start] if tail == b"\r\n": # Normal delimiter found - self._current = MultipartSegment(self) + self._segment = self._new_segment() self._state = _HEADER offset = next_start continue @@ -433,12 +442,12 @@ def parse( nl = buffer.find(b"\r\n", offset) if nl > offset: # Non-empty header line - self._current._add_headerline(buffer[offset:nl]) + self._segment._on_headerline(buffer[offset:nl]) offset = nl + 2 continue elif nl == offset: # Empty header line -> End of header section - self._current._close_headers() - yield self._current + self._segment._on_header_complete() + yield self._segment self._state = _BODY offset += 2 continue @@ -463,18 +472,17 @@ def parse( if tail == b"\r\n" or tail == b"--": if index > offset: - self._current._update_size(index - offset) - yield buffer[offset:index] + yield self._segment._on_data(buffer[offset:index]) offset = next_start - self._current._mark_complete() + self._segment._on_data_complete() yield None # End of segment if tail == b"--": # Last delimiter self._state = _COMPLETE break else: # Normal delimiter - self._current = MultipartSegment(self) + self._segment = self._new_segment() self._state = _HEADER continue @@ -482,8 +490,7 @@ def parse( # the end, but emiot the rest. chunk_end = bufferlen - (d_len + 1) assert chunk_end > offset # Always true - self._current._update_size(chunk_end - offset) - yield buffer[offset:chunk_end] + yield self._segment._on_data(buffer[offset:chunk_end]) offset = chunk_end break # wait for more data @@ -501,6 +508,12 @@ def parse( self.close(check_complete=False) raise + def _new_segment(self) -> t_segment: + self._segment_count += 1 + if self._segment_count > self.max_segment_count: + raise ParserLimitReached("Maximum segment count exceeded") + return self.segment_class(self) + def close(self, check_complete=True): """ Close this parser if not already closed. 
@@ -510,7 +523,7 @@ def close(self, check_complete=True): """ self.closed = True - self._current = None + self._segment = None del self._buffer[:] if check_complete and self._state is not _COMPLETE: @@ -551,39 +564,34 @@ class MultipartSegment: def __init__(self, parser: PushMultipartParser): """ Private constructor, used by :class:`PushMultipartParser` """ self._parser = parser - - if parser._fieldcount+1 > parser.max_segment_count: - raise ParserLimitReached("Maximum segment count exceeded") - parser._fieldcount += 1 - self.headerlist = [] self.size = 0 - self.complete = 0 + self.complete = False - self.name = None + self.name = "" self.filename = None self.content_type = None self.charset = None + self._maxlen = parser.max_segment_size self._clen = -1 - self._size_limit = parser.max_segment_size - def _add_headerline(self, line: bytearray): - assert line and self.name is None - parser = self._parser + def _on_headerline(self, line: bytearray): + """ Called for each raw header line in a segment. 
""" - if line[0] in b" \t": # Multi-line header value - if not self.headerlist or parser.strict: + if line[0] in b" \t": # Continuation of last header line + if not self.headerlist or self._parser.strict: raise StrictParserError("Unexpected segment header continuation") prev = ": ".join(self.headerlist.pop()) - line = prev.encode(parser.header_charset) + b" " + line.strip() + line = prev.encode(self._parser.header_charset) + b" " + line.strip() - if len(line) > parser.max_header_size: + if len(line) > self._parser.max_header_size: raise ParserLimitReached("Maximum segment header length exceeded") - if len(self.headerlist) >= parser.max_header_count: + + if len(self.headerlist) >= self._parser.max_header_count: raise ParserLimitReached("Maximum segment header count exceeded") try: - name, col, value = line.decode(parser.header_charset).partition(":") + name, col, value = line.decode(self._parser.header_charset).partition(":") name = name.strip() if not col or not name: raise ParserError("Malformed segment header") @@ -594,9 +602,10 @@ def _add_headerline(self, line: bytearray): self.headerlist.append((name.title(), value.strip())) - def _close_headers(self): - assert self.name is None + def _on_header_complete(self): + """ Called after the last segment header. 
""" + dtype = False for h,v in self.headerlist: if h == "Content-Disposition": dtype, args = parse_options_header(v, unquote=content_disposition_unquote) @@ -611,21 +620,23 @@ def _close_headers(self): self.charset = args.get("charset") elif h == "Content-Length" and v.isdecimal(): self._clen = int(v) + self._maxlen = min(self._clen, self._maxlen) - if self.name is None: + if not dtype: raise ParserError("Missing Content-Disposition segment header") - def _update_size(self, bytecount: int): - assert self.name is not None and not self.complete - self.size += bytecount - if self._clen >= 0 and self.size > self._clen: - raise ParserError("Segment Content-Length exceeded") - if self.size > self._size_limit: + def _on_data(self, chunk: bytearray) -> bytearray: + """ Called for each chunk of segment data. Must return the chunk. """ + self.size += len(chunk) + if self.size > self._maxlen: + if self.size > self._clen > -1: + raise ParserError("Segment Content-Length exceeded") raise ParserLimitReached("Maximum segment size exceeded") + return chunk - def _mark_complete(self): - assert self.name is not None and not self.complete - if self._clen >= 0 and self.size != self._clen: + def _on_data_complete(self): + """ Called after the last chunk of segment data. """ + if self._clen > -1 and self.size != self._clen: raise ParserError("Segment size does not match Content-Length header") self.complete = True