Skip to content

Commit

Permalink
Implement lazy loading mechanism for QualifiedNameProvider (#720)
Browse files Browse the repository at this point in the history
* Implement lazy loading mechanism for expensive metadata providers
* Add support for lazy values in metadata matchers
* Fix type issues and implement lazy value support in base metadata provider too
* Add unit tests for BaseMetadataProvider

Co-authored-by: Zsolt Dollenstein <[email protected]>
  • Loading branch information
Chenguang-Zhu and zsol authored Jul 9, 2022
1 parent b3eda50 commit 7cb229d
Show file tree
Hide file tree
Showing 6 changed files with 135 additions and 21 deletions.
33 changes: 30 additions & 3 deletions libcst/_metadata_dependent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,17 @@
from abc import ABC
from contextlib import contextmanager
from typing import (
Callable,
cast,
ClassVar,
Collection,
Generic,
Iterator,
Mapping,
Type,
TYPE_CHECKING,
TypeVar,
Union,
)

if TYPE_CHECKING:
Expand All @@ -29,7 +32,28 @@

_T = TypeVar("_T")

_UNDEFINED_DEFAULT = object()

class _UNDEFINED_DEFAULT:
pass


class LazyValue(Generic[_T]):
"""
The class for implementing a lazy metadata loading mechanism that improves the
performance when retriving expensive metadata (e.g., qualified names). Providers
including :class:`~libcst.metadata.QualifiedNameProvider` use this class to load
the metadata of a certain node lazily when calling
:func:`~libcst.MetadataDependent.get_metadata`.
"""

def __init__(self, callable: Callable[[], _T]) -> None:
self.callable = callable
self.return_value: Union[_T, Type[_UNDEFINED_DEFAULT]] = _UNDEFINED_DEFAULT

def __call__(self) -> _T:
if self.return_value is _UNDEFINED_DEFAULT:
self.return_value = self.callable()
return cast(_T, self.return_value)


class MetadataDependent(ABC):
Expand Down Expand Up @@ -107,6 +131,9 @@ def get_metadata(
)

if default is not _UNDEFINED_DEFAULT:
return cast(_T, self.metadata[key].get(node, default))
value = self.metadata[key].get(node, default)
else:
return cast(_T, self.metadata[key][node])
value = self.metadata[key][node]
if isinstance(value, LazyValue):
value = value()
return cast(_T, value)
7 changes: 6 additions & 1 deletion libcst/matchers/_matcher_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import libcst
import libcst.metadata as meta
from libcst import FlattenSentinel, MaybeSentinel, RemovalSentinel
from libcst._metadata_dependent import LazyValue


class DoNotCareSentinel(Enum):
Expand Down Expand Up @@ -1544,7 +1545,11 @@ def _fetch(provider: meta.ProviderT, node: libcst.CSTNode) -> object:
if provider not in metadata:
metadata[provider] = wrapper.resolve(provider)

return metadata.get(provider, {}).get(node, _METADATA_MISSING_SENTINEL)
node_metadata = metadata.get(provider, {}).get(node, _METADATA_MISSING_SENTINEL)
if isinstance(node_metadata, LazyValue):
node_metadata = node_metadata()

return node_metadata

return _fetch

Expand Down
29 changes: 16 additions & 13 deletions libcst/metadata/base_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from types import MappingProxyType
from typing import (
Callable,
cast,
Generic,
List,
Mapping,
Expand All @@ -16,12 +15,14 @@
Type,
TYPE_CHECKING,
TypeVar,
Union,
)

from libcst._batched_visitor import BatchableCSTVisitor
from libcst._metadata_dependent import (
_T as _MetadataT,
_UNDEFINED_DEFAULT,
LazyValue,
MetadataDependent,
)
from libcst._visitors import CSTVisitor
Expand All @@ -36,6 +37,7 @@
# BaseMetadataProvider[int] would be a subtype of BaseMetadataProvider[object], so the
# typevar is covariant.
_ProvidedMetadataT = TypeVar("_ProvidedMetadataT", covariant=True)
MaybeLazyMetadataT = Union[LazyValue[_ProvidedMetadataT], _ProvidedMetadataT]


# We can't use an ABCMeta here, because of metaclass conflicts
Expand All @@ -52,16 +54,16 @@ class BaseMetadataProvider(MetadataDependent, Generic[_ProvidedMetadataT]):
#
# N.B. This has some typing variance problems. See `set_metadata` for an
# explanation.
_computed: MutableMapping["CSTNode", _ProvidedMetadataT]
_computed: MutableMapping["CSTNode", MaybeLazyMetadataT]

#: Implement gen_cache to indicate the matadata provider depends on cache from external
#: Implement gen_cache to indicate the metadata provider depends on cache from external
#: system. This function will be called by :class:`~libcst.metadata.FullRepoManager`
#: to compute required cache object per file path.
gen_cache: Optional[Callable[[Path, List[str], int], Mapping[str, object]]] = None

def __init__(self, cache: object = None) -> None:
super().__init__()
self._computed = {}
self._computed: MutableMapping["CSTNode", MaybeLazyMetadataT] = {}
if self.gen_cache and cache is None:
# The metadata provider implementation is responsible to store and use cache.
raise Exception(
Expand All @@ -71,7 +73,7 @@ def __init__(self, cache: object = None) -> None:

def _gen(
self, wrapper: "MetadataWrapper"
) -> Mapping["CSTNode", _ProvidedMetadataT]:
) -> Mapping["CSTNode", MaybeLazyMetadataT]:
"""
Resolves and returns metadata mapping for the module in ``wrapper``.
Expand All @@ -93,11 +95,7 @@ def _gen_impl(self, module: "Module") -> None:
"""
...

# pyre-ignore[46]: The covariant `value` isn't type-safe because we write it to
# pyre: `self._computed`, however we assume that only one subclass in the MRO chain
# pyre: will ever call `set_metadata`, so it's okay for our purposes. There's no
# pyre: sane way to redesign this API so that it doesn't have this problem.
def set_metadata(self, node: "CSTNode", value: _ProvidedMetadataT) -> None:
def set_metadata(self, node: "CSTNode", value: MaybeLazyMetadataT) -> None:
"""
Record a metadata value ``value`` for ``node``.
"""
Expand All @@ -107,7 +105,9 @@ def get_metadata(
self,
key: Type["BaseMetadataProvider[_MetadataT]"],
node: "CSTNode",
default: _MetadataT = _UNDEFINED_DEFAULT,
default: Union[
MaybeLazyMetadataT, Type[_UNDEFINED_DEFAULT]
] = _UNDEFINED_DEFAULT,
) -> _MetadataT:
"""
The same method as :func:`~libcst.MetadataDependent.get_metadata` except
Expand All @@ -116,9 +116,12 @@ def get_metadata(
"""
if key is type(self):
if default is not _UNDEFINED_DEFAULT:
return cast(_MetadataT, self._computed.get(node, default))
ret = self._computed.get(node, default)
else:
return cast(_MetadataT, self._computed[node])
ret = self._computed[node]
if isinstance(ret, LazyValue):
return ret()
return ret

return super().get_metadata(key, node, default)

Expand Down
6 changes: 4 additions & 2 deletions libcst/metadata/name_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import Collection, List, Mapping, Optional, Union

import libcst as cst
from libcst._metadata_dependent import MetadataDependent
from libcst._metadata_dependent import LazyValue, MetadataDependent
from libcst.helpers.module import calculate_module_and_package, ModuleNameAndPackage
from libcst.metadata.base_provider import BatchableMetadataProvider
from libcst.metadata.scope_provider import (
Expand Down Expand Up @@ -78,7 +78,9 @@ def __init__(self, provider: "QualifiedNameProvider") -> None:
def on_visit(self, node: cst.CSTNode) -> bool:
scope = self.provider.get_metadata(ScopeProvider, node, None)
if scope:
self.provider.set_metadata(node, scope.get_qualified_names_for(node))
self.provider.set_metadata(
node, LazyValue(lambda: scope.get_qualified_names_for(node))
)
else:
self.provider.set_metadata(node, set())
super().on_visit(node)
Expand Down
61 changes: 61 additions & 0 deletions libcst/metadata/tests/test_base_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import libcst as cst
from libcst import parse_module
from libcst._metadata_dependent import LazyValue
from libcst.metadata import (
BatchableMetadataProvider,
MetadataWrapper,
Expand Down Expand Up @@ -75,3 +76,63 @@ def visit_Return(self, node: cst.Return) -> None:
self.assertEqual(metadata[SimpleProvider][pass_], 1)
self.assertEqual(metadata[SimpleProvider][return_], 2)
self.assertEqual(metadata[SimpleProvider][pass_2], 1)

def test_lazy_visitor_provider(self) -> None:
class SimpleLazyProvider(VisitorMetadataProvider[int]):
"""
Sets metadata on every node to a callable that returns 1.
"""

def on_visit(self, node: cst.CSTNode) -> bool:
self.set_metadata(node, LazyValue(lambda: 1))
return True

wrapper = MetadataWrapper(parse_module("pass; return"))
module = wrapper.module
pass_ = cast(cst.SimpleStatementLine, module.body[0]).body[0]
return_ = cast(cst.SimpleStatementLine, module.body[0]).body[1]

provider = SimpleLazyProvider()
metadata = provider._gen(wrapper)

# Check access on provider
self.assertEqual(provider.get_metadata(SimpleLazyProvider, module), 1)
self.assertEqual(provider.get_metadata(SimpleLazyProvider, pass_), 1)
self.assertEqual(provider.get_metadata(SimpleLazyProvider, return_), 1)

# Check returned mapping
self.assertTrue(isinstance(metadata[module], LazyValue))
self.assertTrue(isinstance(metadata[pass_], LazyValue))
self.assertTrue(isinstance(metadata[return_], LazyValue))

def testlazy_batchable_provider(self) -> None:
class SimpleLazyProvider(BatchableMetadataProvider[int]):
"""
Sets metadata on every pass node to a callable that returns 1,
and every return node to a callable that returns 2.
"""

def visit_Pass(self, node: cst.Pass) -> None:
self.set_metadata(node, LazyValue(lambda: 1))

def visit_Return(self, node: cst.Return) -> None:
self.set_metadata(node, LazyValue(lambda: 2))

wrapper = MetadataWrapper(parse_module("pass; return; pass"))
module = wrapper.module
pass_ = cast(cst.SimpleStatementLine, module.body[0]).body[0]
return_ = cast(cst.SimpleStatementLine, module.body[0]).body[1]
pass_2 = cast(cst.SimpleStatementLine, module.body[0]).body[2]

provider = SimpleLazyProvider()
metadata = _gen_batchable(wrapper, [provider])

# Check access on provider
self.assertEqual(provider.get_metadata(SimpleLazyProvider, pass_), 1)
self.assertEqual(provider.get_metadata(SimpleLazyProvider, return_), 2)
self.assertEqual(provider.get_metadata(SimpleLazyProvider, pass_2), 1)

# Check returned mapping
self.assertTrue(isinstance(metadata[SimpleLazyProvider][pass_], LazyValue))
self.assertTrue(isinstance(metadata[SimpleLazyProvider][return_], LazyValue))
self.assertTrue(isinstance(metadata[SimpleLazyProvider][pass_2], LazyValue))
20 changes: 18 additions & 2 deletions libcst/metadata/tests/test_name_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import libcst as cst
from libcst import ensure_type
from libcst._nodes.base import CSTNode
from libcst.metadata import (
FullyQualifiedNameProvider,
MetadataWrapper,
Expand All @@ -22,11 +23,26 @@
from libcst.testing.utils import data_provider, UnitTest


class QNameVisitor(cst.CSTVisitor):

METADATA_DEPENDENCIES = (QualifiedNameProvider,)

def __init__(self) -> None:
self.qnames: Dict["CSTNode", Collection[QualifiedName]] = {}

def on_visit(self, node: cst.CSTNode) -> bool:
qname = self.get_metadata(QualifiedNameProvider, node)
self.qnames[node] = qname
return True


def get_qualified_name_metadata_provider(
module_str: str,
) -> Tuple[cst.Module, Mapping[cst.CSTNode, Collection[QualifiedName]]]:
wrapper = MetadataWrapper(cst.parse_module(dedent(module_str)))
return wrapper.module, wrapper.resolve(QualifiedNameProvider)
visitor = QNameVisitor()
wrapper.visit(visitor)
return wrapper.module, visitor.qnames


def get_qualified_names(module_str: str) -> Set[QualifiedName]:
Expand Down Expand Up @@ -358,7 +374,7 @@ def f(): pass
else:
import f
import a.b as f
f()
"""
)
Expand Down

0 comments on commit 7cb229d

Please sign in to comment.