From 0ab161c11967c9a85faa2a5a5898263276221877 Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Thu, 26 Sep 2024 16:00:28 -0700 Subject: [PATCH] remove format_state and override behavior for bare (#3979) * remove format_state and override behavior for bare * pass the test cases * only do one level of dicting dataclasses * remove dict and replace list with set * delete unnecessary serialize calls * remove serialize for mutable proxy * dang it darglint --- reflex/compiler/utils.py | 2 +- reflex/components/base/bare.py | 4 +- reflex/middleware/hydrate_middleware.py | 3 +- reflex/state.py | 21 +++------- reflex/utils/format.py | 44 +------------------- reflex/utils/serializers.py | 53 ++++-------------------- reflex/vars/base.py | 2 +- tests/integration/test_var_operations.py | 4 +- tests/units/test_app.py | 3 +- tests/units/utils/test_format.py | 3 +- tests/units/utils/test_serializers.py | 8 ++-- 11 files changed, 30 insertions(+), 117 deletions(-) diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index 1808f787a7..443e1984fc 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -155,7 +155,7 @@ def compile_state(state: Type[BaseState]) -> dict: initial_state = state(_reflex_internal_init=True).dict( initial=True, include_computed=False ) - return format.format_state(initial_state) + return initial_state def _compile_client_storage_field( diff --git a/reflex/components/base/bare.py b/reflex/components/base/bare.py index 970cdfc840..ada511ef2b 100644 --- a/reflex/components/base/bare.py +++ b/reflex/components/base/bare.py @@ -7,7 +7,7 @@ from reflex.components.component import Component from reflex.components.tags import Tag from reflex.components.tags.tagless import Tagless -from reflex.vars.base import Var +from reflex.vars import ArrayVar, BooleanVar, ObjectVar, Var class Bare(Component): @@ -33,6 +33,8 @@ def create(cls, contents: Any) -> Component: def _render(self) -> Tag: if isinstance(self.contents, Var): + if isinstance(self.contents, (BooleanVar, ObjectVar, ArrayVar)): + return Tagless(contents=f"{{{str(self.contents.to_string())}}}") return Tagless(contents=f"{{{str(self.contents)}}}") return Tagless(contents=str(self.contents)) diff --git a/reflex/middleware/hydrate_middleware.py b/reflex/middleware/hydrate_middleware.py index 46b524cd76..2198b82c2c 100644 --- a/reflex/middleware/hydrate_middleware.py +++ b/reflex/middleware/hydrate_middleware.py @@ -9,7 +9,6 @@ from reflex.event import Event, get_hydrate_event from reflex.middleware.middleware import Middleware from reflex.state import BaseState, StateUpdate -from reflex.utils import format if TYPE_CHECKING: from reflex.app import App @@ -43,7 +42,7 @@ async def preprocess( setattr(state, constants.CompileVars.IS_HYDRATED, False) # Get the initial state. - delta = format.format_state(state.dict()) + delta = state.dict() # since a full dict was captured, clean any dirtiness state._clean() diff --git a/reflex/state.py b/reflex/state.py index 8ab6a90a24..c16b37b69c 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -73,7 +73,7 @@ LockExpiredError, ) from reflex.utils.exec import is_testing_env -from reflex.utils.serializers import SerializedType, serialize, serializer +from reflex.utils.serializers import serializer from reflex.utils.types import override from reflex.vars import VarData @@ -1790,9 +1790,6 @@ def get_delta(self) -> Delta: for substate in self.dirty_substates.union(self._always_dirty_substates): delta.update(substates[substate].get_delta()) - # Format the delta. - delta = format.format_state(delta) - # Return the delta. return delta @@ -2433,7 +2430,7 @@ def json(self) -> str: Returns: The state update as a JSON string. """ - return format.json_dumps(dataclasses.asdict(self)) + return format.json_dumps(self) class StateManager(Base, ABC): @@ -3660,22 +3657,16 @@ def __reduce_ex__(self, protocol_version): @serializer -def serialize_mutable_proxy(mp: MutableProxy) -> SerializedType: - """Serialize the wrapped value of a MutableProxy. +def serialize_mutable_proxy(mp: MutableProxy): + """Return the wrapped value of a MutableProxy. Args: mp: The MutableProxy to serialize. Returns: - The serialized wrapped object. - - Raises: - ValueError: when the wrapped object is not serializable. + The wrapped object. """ - value = serialize(mp.__wrapped__) - if value is None: - raise ValueError(f"Cannot serialize {type(mp.__wrapped__)}") - return value + return mp.__wrapped__ class ImmutableMutableProxy(MutableProxy): diff --git a/reflex/utils/format.py b/reflex/utils/format.py index 86b4d96c94..ae985a8719 100644 --- a/reflex/utils/format.py +++ b/reflex/utils/format.py @@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union from reflex import constants -from reflex.utils import exceptions, types +from reflex.utils import exceptions from reflex.utils.console import deprecate if TYPE_CHECKING: @@ -624,48 +624,6 @@ def format_query_params(router_data: dict[str, Any]) -> dict[str, str]: return {k.replace("-", "_"): v for k, v in params.items()} -def format_state(value: Any, key: Optional[str] = None) -> Any: - """Recursively format values in the given state. - - Args: - value: The state to format. - key: The key associated with the value (optional). - - Returns: - The formatted state. - - Raises: - TypeError: If the given value is not a valid state. - """ - from reflex.utils import serializers - - # Handle dicts. - if isinstance(value, dict): - return {k: format_state(v, k) for k, v in value.items()} - - # Handle lists, sets, typles. - if isinstance(value, types.StateIterBases): - return [format_state(v) for v in value] - - # Return state vars as is. - if isinstance(value, types.StateBases): - return value - - # Serialize the value. - serialized = serializers.serialize(value) - if serialized is not None: - return serialized - - if key is None: - raise TypeError( - f"No JSON serializer found for var {value} of type {type(value)}." - ) - else: - raise TypeError( - f"No JSON serializer found for State Var '{key}' of value {value} of type {type(value)}." - ) - - def format_state_name(state_name: str) -> str: """Format a state name, replacing dots with double underscore. diff --git a/reflex/utils/serializers.py b/reflex/utils/serializers.py index 42fb82916e..614257181a 100644 --- a/reflex/utils/serializers.py +++ b/reflex/utils/serializers.py @@ -12,7 +12,6 @@ from typing import ( Any, Callable, - Dict, List, Literal, Optional, @@ -126,7 +125,8 @@ def serialize( # If there is no serializer, return None. if serializer is None: if dataclasses.is_dataclass(value) and not isinstance(value, type): - return serialize(dataclasses.asdict(value)) + return {k.name: getattr(value, k.name) for k in dataclasses.fields(value)} + if get_type: return None, None return None @@ -214,32 +214,6 @@ def serialize_type(value: type) -> str: return value.__name__ -@serializer -def serialize_str(value: str) -> str: - """Serialize a string. - - Args: - value: The string to serialize. - - Returns: - The serialized string. - """ - return value - - -@serializer -def serialize_primitive(value: Union[bool, int, float, None]): - """Serialize a primitive type. - - Args: - value: The number/bool/None to serialize. - - Returns: - The serialized number/bool/None. - """ - return value - - @serializer def serialize_base(value: Base) -> dict: """Serialize a Base instance. @@ -250,33 +224,20 @@ def serialize_base(value: Base) -> dict: Returns: The serialized Base. """ - return {k: serialize(v) for k, v in value.dict().items() if not callable(v)} + return {k: v for k, v in value.dict().items() if not callable(v)} @serializer -def serialize_list(value: Union[List, Tuple, Set]) -> list: - """Serialize a list to a JSON string. +def serialize_set(value: Set) -> list: + """Serialize a set to a JSON serializable list. Args: - value: The list to serialize. + value: The set to serialize. Returns: The serialized list. """ - return [serialize(item) for item in value] - - -@serializer -def serialize_dict(prop: Dict[str, Any]) -> dict: - """Serialize a dictionary to a JSON string. - - Args: - prop: The dictionary to serialize. - - Returns: - The serialized dictionary. - """ - return {k: serialize(v) for k, v in prop.items()} + return list(value) @serializer(to=str) diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 4faa38be74..afbc56a55e 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -1141,7 +1141,7 @@ def serialize_literal(value: LiteralVar): Returns: The serialized Literal. """ - return serializers.serialize(value._var_value) + return value._var_value P = ParamSpec("P") diff --git a/tests/integration/test_var_operations.py b/tests/integration/test_var_operations.py index cae56e1a8d..919a39f3b3 100644 --- a/tests/integration/test_var_operations.py +++ b/tests/integration/test_var_operations.py @@ -793,8 +793,8 @@ def test_var_operations(driver, var_operations: AppHarness): ("foreach_list_ix", "1\n2"), ("foreach_list_nested", "1\n1\n2"), # rx.memo component with state - ("memo_comp", "1210"), - ("memo_comp_nested", "345"), + ("memo_comp", "[1,2]10"), + ("memo_comp_nested", "[3,4]5"), # foreach in a match ("foreach_in_match", "first\nsecond\nthird"), ] diff --git a/tests/units/test_app.py b/tests/units/test_app.py index 88655d7de3..0c22c38e35 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -1,6 +1,5 @@ from __future__ import annotations -import dataclasses import functools import io import json @@ -1053,7 +1052,7 @@ def _dynamic_state_event(name, val, **kwargs): f"comp_{arg_name}": exp_val, constants.CompileVars.IS_HYDRATED: False, # "side_effect_counter": exp_index, - "router": dataclasses.asdict(exp_router), + "router": exp_router, } }, events=[ diff --git a/tests/units/utils/test_format.py b/tests/units/utils/test_format.py index 4ec5099f51..042c3f3231 100644 --- a/tests/units/utils/test_format.py +++ b/tests/units/utils/test_format.py @@ -1,6 +1,7 @@ from __future__ import annotations import datetime +import json from typing import Any, List import plotly.graph_objects as go @@ -621,7 +622,7 @@ def test_format_state(input, output): input: The state to format. output: The expected formatted state. """ - assert format.format_state(input) == output + assert json.loads(format.json_dumps(input)) == json.loads(format.json_dumps(output)) @pytest.mark.parametrize( diff --git a/tests/units/utils/test_serializers.py b/tests/units/utils/test_serializers.py index 97da987928..6301873097 100644 --- a/tests/units/utils/test_serializers.py +++ b/tests/units/utils/test_serializers.py @@ -1,19 +1,21 @@ import datetime +import json from enum import Enum from pathlib import Path -from typing import Any, Dict, Type +from typing import Any, Type import pytest from reflex.base import Base from reflex.components.core.colors import Color from reflex.utils import serializers +from reflex.utils.format import json_dumps from reflex.vars.base import LiteralVar @pytest.mark.parametrize( "type_,expected", - [(str, True), (dict, True), (Dict[int, int], True), (Enum, True)], + [(Enum, True)], ) def test_has_serializer(type_: Type, expected: bool): """Test that has_serializer returns the correct value. @@ -198,7 +200,7 @@ def test_serialize(value: Any, expected: str): value: The value to serialize. expected: The expected result. """ - assert serializers.serialize(value) == expected + assert json.loads(json_dumps(value)) == json.loads(json_dumps(expected)) @pytest.mark.parametrize(