Skip to content

Commit

Permalink
remove format_state and override behavior for bare (#3979)
Browse files Browse the repository at this point in the history
* remove format_state and override behavior for bare

* pass the test cases

* only do one level of dicting dataclasses

* remove dict and replace list with set

* delete unnecessary serialize calls

* remove serialize for mutable proxy

* dang it darglint
  • Loading branch information
adhami3310 authored Sep 26, 2024
1 parent 70bd88c commit 0ab161c
Show file tree
Hide file tree
Showing 11 changed files with 30 additions and 117 deletions.
2 changes: 1 addition & 1 deletion reflex/compiler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def compile_state(state: Type[BaseState]) -> dict:
initial_state = state(_reflex_internal_init=True).dict(
initial=True, include_computed=False
)
return format.format_state(initial_state)
return initial_state


def _compile_client_storage_field(
Expand Down
4 changes: 3 additions & 1 deletion reflex/components/base/bare.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from reflex.components.component import Component
from reflex.components.tags import Tag
from reflex.components.tags.tagless import Tagless
from reflex.vars.base import Var
from reflex.vars import ArrayVar, BooleanVar, ObjectVar, Var


class Bare(Component):
Expand All @@ -33,6 +33,8 @@ def create(cls, contents: Any) -> Component:

def _render(self) -> Tag:
if isinstance(self.contents, Var):
if isinstance(self.contents, (BooleanVar, ObjectVar, ArrayVar)):
return Tagless(contents=f"{{{str(self.contents.to_string())}}}")
return Tagless(contents=f"{{{str(self.contents)}}}")
return Tagless(contents=str(self.contents))

Expand Down
3 changes: 1 addition & 2 deletions reflex/middleware/hydrate_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from reflex.event import Event, get_hydrate_event
from reflex.middleware.middleware import Middleware
from reflex.state import BaseState, StateUpdate
from reflex.utils import format

if TYPE_CHECKING:
from reflex.app import App
Expand Down Expand Up @@ -43,7 +42,7 @@ async def preprocess(
setattr(state, constants.CompileVars.IS_HYDRATED, False)

# Get the initial state.
delta = format.format_state(state.dict())
delta = state.dict()
# since a full dict was captured, clean any dirtiness
state._clean()

Expand Down
21 changes: 6 additions & 15 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
LockExpiredError,
)
from reflex.utils.exec import is_testing_env
from reflex.utils.serializers import SerializedType, serialize, serializer
from reflex.utils.serializers import serializer
from reflex.utils.types import override
from reflex.vars import VarData

Expand Down Expand Up @@ -1790,9 +1790,6 @@ def get_delta(self) -> Delta:
for substate in self.dirty_substates.union(self._always_dirty_substates):
delta.update(substates[substate].get_delta())

# Format the delta.
delta = format.format_state(delta)

# Return the delta.
return delta

Expand Down Expand Up @@ -2433,7 +2430,7 @@ def json(self) -> str:
Returns:
The state update as a JSON string.
"""
return format.json_dumps(dataclasses.asdict(self))
return format.json_dumps(self)


class StateManager(Base, ABC):
Expand Down Expand Up @@ -3660,22 +3657,16 @@ def __reduce_ex__(self, protocol_version):


@serializer
def serialize_mutable_proxy(mp: MutableProxy) -> SerializedType:
"""Serialize the wrapped value of a MutableProxy.
def serialize_mutable_proxy(mp: MutableProxy):
"""Return the wrapped value of a MutableProxy.
Args:
mp: The MutableProxy to serialize.
Returns:
The serialized wrapped object.
Raises:
ValueError: when the wrapped object is not serializable.
The wrapped object.
"""
value = serialize(mp.__wrapped__)
if value is None:
raise ValueError(f"Cannot serialize {type(mp.__wrapped__)}")
return value
return mp.__wrapped__


class ImmutableMutableProxy(MutableProxy):
Expand Down
44 changes: 1 addition & 43 deletions reflex/utils/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union

from reflex import constants
from reflex.utils import exceptions, types
from reflex.utils import exceptions
from reflex.utils.console import deprecate

if TYPE_CHECKING:
Expand Down Expand Up @@ -624,48 +624,6 @@ def format_query_params(router_data: dict[str, Any]) -> dict[str, str]:
return {k.replace("-", "_"): v for k, v in params.items()}


def format_state(value: Any, key: Optional[str] = None) -> Any:
"""Recursively format values in the given state.
Args:
value: The state to format.
key: The key associated with the value (optional).
Returns:
The formatted state.
Raises:
TypeError: If the given value is not a valid state.
"""
from reflex.utils import serializers

# Handle dicts.
if isinstance(value, dict):
return {k: format_state(v, k) for k, v in value.items()}

# Handle lists, sets, typles.
if isinstance(value, types.StateIterBases):
return [format_state(v) for v in value]

# Return state vars as is.
if isinstance(value, types.StateBases):
return value

# Serialize the value.
serialized = serializers.serialize(value)
if serialized is not None:
return serialized

if key is None:
raise TypeError(
f"No JSON serializer found for var {value} of type {type(value)}."
)
else:
raise TypeError(
f"No JSON serializer found for State Var '{key}' of value {value} of type {type(value)}."
)


def format_state_name(state_name: str) -> str:
"""Format a state name, replacing dots with double underscore.
Expand Down
53 changes: 7 additions & 46 deletions reflex/utils/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Expand Down Expand Up @@ -126,7 +125,8 @@ def serialize(
# If there is no serializer, return None.
if serializer is None:
if dataclasses.is_dataclass(value) and not isinstance(value, type):
return serialize(dataclasses.asdict(value))
return {k.name: getattr(value, k.name) for k in dataclasses.fields(value)}

if get_type:
return None, None
return None
Expand Down Expand Up @@ -214,32 +214,6 @@ def serialize_type(value: type) -> str:
return value.__name__


@serializer
def serialize_str(value: str) -> str:
"""Serialize a string.
Args:
value: The string to serialize.
Returns:
The serialized string.
"""
return value


@serializer
def serialize_primitive(value: Union[bool, int, float, None]):
"""Serialize a primitive type.
Args:
value: The number/bool/None to serialize.
Returns:
The serialized number/bool/None.
"""
return value


@serializer
def serialize_base(value: Base) -> dict:
"""Serialize a Base instance.
Expand All @@ -250,33 +224,20 @@ def serialize_base(value: Base) -> dict:
Returns:
The serialized Base.
"""
return {k: serialize(v) for k, v in value.dict().items() if not callable(v)}
return {k: v for k, v in value.dict().items() if not callable(v)}


@serializer
def serialize_list(value: Union[List, Tuple, Set]) -> list:
"""Serialize a list to a JSON string.
def serialize_set(value: Set) -> list:
"""Serialize a set to a JSON serializable list.
Args:
value: The list to serialize.
value: The set to serialize.
Returns:
The serialized list.
"""
return [serialize(item) for item in value]


@serializer
def serialize_dict(prop: Dict[str, Any]) -> dict:
"""Serialize a dictionary to a JSON string.
Args:
prop: The dictionary to serialize.
Returns:
The serialized dictionary.
"""
return {k: serialize(v) for k, v in prop.items()}
return list(value)


@serializer(to=str)
Expand Down
2 changes: 1 addition & 1 deletion reflex/vars/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1141,7 +1141,7 @@ def serialize_literal(value: LiteralVar):
Returns:
The serialized Literal.
"""
return serializers.serialize(value._var_value)
return value._var_value


P = ParamSpec("P")
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/test_var_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,8 +793,8 @@ def test_var_operations(driver, var_operations: AppHarness):
("foreach_list_ix", "1\n2"),
("foreach_list_nested", "1\n1\n2"),
# rx.memo component with state
("memo_comp", "1210"),
("memo_comp_nested", "345"),
("memo_comp", "[1,2]10"),
("memo_comp_nested", "[3,4]5"),
# foreach in a match
("foreach_in_match", "first\nsecond\nthird"),
]
Expand Down
3 changes: 1 addition & 2 deletions tests/units/test_app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import dataclasses
import functools
import io
import json
Expand Down Expand Up @@ -1053,7 +1052,7 @@ def _dynamic_state_event(name, val, **kwargs):
f"comp_{arg_name}": exp_val,
constants.CompileVars.IS_HYDRATED: False,
# "side_effect_counter": exp_index,
"router": dataclasses.asdict(exp_router),
"router": exp_router,
}
},
events=[
Expand Down
3 changes: 2 additions & 1 deletion tests/units/utils/test_format.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import datetime
import json
from typing import Any, List

import plotly.graph_objects as go
Expand Down Expand Up @@ -621,7 +622,7 @@ def test_format_state(input, output):
input: The state to format.
output: The expected formatted state.
"""
assert format.format_state(input) == output
assert json.loads(format.json_dumps(input)) == json.loads(format.json_dumps(output))


@pytest.mark.parametrize(
Expand Down
8 changes: 5 additions & 3 deletions tests/units/utils/test_serializers.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
import datetime
import json
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Type
from typing import Any, Type

import pytest

from reflex.base import Base
from reflex.components.core.colors import Color
from reflex.utils import serializers
from reflex.utils.format import json_dumps
from reflex.vars.base import LiteralVar


@pytest.mark.parametrize(
"type_,expected",
[(str, True), (dict, True), (Dict[int, int], True), (Enum, True)],
[(Enum, True)],
)
def test_has_serializer(type_: Type, expected: bool):
"""Test that has_serializer returns the correct value.
Expand Down Expand Up @@ -198,7 +200,7 @@ def test_serialize(value: Any, expected: str):
value: The value to serialize.
expected: The expected result.
"""
assert serializers.serialize(value) == expected
assert json.loads(json_dumps(value)) == json.loads(json_dumps(expected))


@pytest.mark.parametrize(
Expand Down

0 comments on commit 0ab161c

Please sign in to comment.