diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index 050f41435..cbafc838d 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -62,6 +62,13 @@ SupportsWrite, ) +if sys.version_info >= (3, 10): + from types import UnionType as _types_UnionType +else: + + class _types_UnionType: + ... + # Proto 3 data types TYPE_ENUM = "enum" @@ -148,6 +155,7 @@ def datetime_default_gen() -> datetime: DATETIME_ZERO = datetime_default_gen() + # Special protobuf json doubles INFINITY = "Infinity" NEG_INFINITY = "-Infinity" @@ -1166,30 +1174,29 @@ def _get_field_default(self, field_name: str) -> Any: def _get_field_default_gen(cls, field: dataclasses.Field) -> Any: t = cls._type_hint(field.name) - if hasattr(t, "__origin__"): - if t.__origin__ is dict: - # This is some kind of map (dict in Python). - return dict - elif t.__origin__ is list: - # This is some kind of list (repeated) field. - return list - elif t.__origin__ is Union and t.__args__[1] is type(None): + is_310_union = isinstance(t, _types_UnionType) + if hasattr(t, "__origin__") or is_310_union: + if is_310_union or t.__origin__ is Union: # This is an optional field (either wrapped, or using proto3 # field presence). For setting the default we really don't care # what kind of field it is. return type(None) - else: - return t - elif issubclass(t, Enum): + if t.__origin__ is list: + # This is some kind of list (repeated) field. + return list + if t.__origin__ is dict: + # This is some kind of map (dict in Python). + return dict + return t + if issubclass(t, Enum): # Enums always default to zero. return t.try_value - elif t is datetime: + if t is datetime: # Offsets are relative to 1970-01-01T00:00:00Z return datetime_default_gen - else: - # This is either a primitive scalar or another message type. Calling - # it should result in its zero value. - return t + # This is either a primitive scalar or another message type. Calling + # it should result in its zero value. + return t def _postprocess_single( self, wire_type: int, meta: FieldMetadata, field_name: str, value: Any diff --git a/src/betterproto/compile/importing.py b/src/betterproto/compile/importing.py index 4221122b9..b216dfc59 100644 --- a/src/betterproto/compile/importing.py +++ b/src/betterproto/compile/importing.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import os import re from typing import ( + TYPE_CHECKING, Dict, List, Set, @@ -13,6 +16,9 @@ from .naming import pythonize_class_name +if TYPE_CHECKING: + from ..plugin.typing_compiler import TypingCompiler + WRAPPER_TYPES: Dict[str, Type] = { ".google.protobuf.DoubleValue": google_protobuf.DoubleValue, ".google.protobuf.FloatValue": google_protobuf.FloatValue, @@ -47,7 +53,7 @@ def get_type_reference( package: str, imports: set, source_type: str, - typing_compiler: "TypingCompiler", + typing_compiler: TypingCompiler, unwrap: bool = True, pydantic: bool = False, ) -> str: diff --git a/src/betterproto/plugin/typing_compiler.py b/src/betterproto/plugin/typing_compiler.py index 937c7bfc1..eca3691f9 100644 --- a/src/betterproto/plugin/typing_compiler.py +++ b/src/betterproto/plugin/typing_compiler.py @@ -139,29 +139,35 @@ def imports(self) -> Dict[str, Optional[Set[str]]]: class NoTyping310TypingCompiler(TypingCompiler): _imports: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set)) + @staticmethod + def _fmt(type: str) -> str: # for now this is necessary till 3.14 + if type.startswith('"'): + return type[1:-1] + return type + def optional(self, type: str) -> str: - return f"{type} | None" + return f'"{self._fmt(type)} | None"' def list(self, type: str) -> str: - return f"list[{type}]" + return f'"list[{self._fmt(type)}]"' def dict(self, key: str, value: str) -> str: - return f"dict[{key}, {value}]" + return f'"dict[{key}, {self._fmt(value)}]"' def union(self, *types: str) -> str: - return " | ".join(types) + return f'"{" | ".join(map(self._fmt, types))}"' def iterable(self, type: str) -> str: - self._imports["typing"].add("Iterable") - return f"Iterable[{type}]" + self._imports["collections.abc"].add("Iterable") + return f'"Iterable[{type}]"' def async_iterable(self, type: str) -> str: - self._imports["typing"].add("AsyncIterable") - return f"AsyncIterable[{type}]" + self._imports["collections.abc"].add("AsyncIterable") + return f'"AsyncIterable[{type}]"' def async_iterator(self, type: str) -> str: - self._imports["typing"].add("AsyncIterator") - return f"AsyncIterator[{type}]" + self._imports["collections.abc"].add("AsyncIterator") + return f'"AsyncIterator[{type}]"' def imports(self) -> Dict[str, Optional[Set[str]]]: return {k: v if v else None for k, v in self._imports.items()} diff --git a/tests/test_typing_compiler.py b/tests/test_typing_compiler.py index 3d1083c72..ee17449b5 100644 --- a/tests/test_typing_compiler.py +++ b/tests/test_typing_compiler.py @@ -62,19 +62,17 @@ def test_typing_import_typing_compiler(): def test_no_typing_311_typing_compiler(): compiler = NoTyping310TypingCompiler() assert compiler.imports() == {} - assert compiler.optional("str") == "str | None" + assert compiler.optional("str") == '"str | None"' assert compiler.imports() == {} - assert compiler.list("str") == "list[str]" + assert compiler.list("str") == '"list[str]"' assert compiler.imports() == {} - assert compiler.dict("str", "int") == "dict[str, int]" + assert compiler.dict("str", "int") == '"dict[str, int]"' assert compiler.imports() == {} - assert compiler.union("str", "int") == "str | int" + assert compiler.union("str", "int") == '"str | int"' assert compiler.imports() == {} - assert compiler.iterable("str") == "Iterable[str]" - assert compiler.imports() == {"typing": {"Iterable"}} - assert compiler.async_iterable("str") == "AsyncIterable[str]" - assert compiler.imports() == {"typing": {"Iterable", "AsyncIterable"}} - assert compiler.async_iterator("str") == "AsyncIterator[str]" + assert compiler.iterable("str") == '"Iterable[str]"' + assert compiler.async_iterable("str") == '"AsyncIterable[str]"' + assert compiler.async_iterator("str") == '"AsyncIterator[str]"' assert compiler.imports() == { - "typing": {"Iterable", "AsyncIterable", "AsyncIterator"} + "collections.abc": {"Iterable", "AsyncIterable", "AsyncIterator"} }