Skip to content

Commit

Permalink
cache kwargs for function (#128)
Browse files Browse the repository at this point in the history
* cache kwargs for function

* qa, docs and tests

---------

Co-authored-by: Mattia Almansi <[email protected]>
  • Loading branch information
EddyCMWF and malmans2 authored Sep 2, 2024
1 parent 683d4b0 commit 07fc1b7
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 9 deletions.
21 changes: 18 additions & 3 deletions cacholote/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,21 @@ def _decode_and_update(
return result


def cacheable(func: F) -> F:
"""Make a function cacheable."""
def cacheable(func: F, **cache_kwargs: Any) -> F:
"""Make a function cacheable.
Parameters
----------
func: callable
Function to cache
**cache_kwargs: Any
Additional kwargs to use for hashing
Returns
-------
callable:
Cache functions
"""

@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
Expand All @@ -56,7 +69,9 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
return func(*args, **kwargs)

try:
hexdigest = encode._hexdigestify_python_call(func, *args, **kwargs)
hexdigest = encode._hexdigestify_python_call(
func, *args, **kwargs, **cache_kwargs
)
except encode.EncodeError as ex:
if settings.return_cache_entry:
raise ex
Expand Down
18 changes: 12 additions & 6 deletions tests/test_30_cache.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import datetime
import pathlib
import time
from typing import Any

Expand Down Expand Up @@ -31,11 +30,18 @@ def cached_error() -> None:
raise ValueError("test error")


def test_cacheable(tmp_path: pathlib.Path) -> None:
@pytest.mark.parametrize(
"cache_kwargs,expected_hash",
[
({}, "a8260ac3cdc1404aa64a6fb71e853049"),
({"foo": "bar"}, "ad4c1867757974cfccabc18c3a5078b9"),
],
)
def test_cacheable(cache_kwargs: dict[str, Any], expected_hash: str) -> None:
con = config.get().engine.raw_connection()
cur = con.cursor()

cfunc = cache.cacheable(func)
cfunc = cache.cacheable(func, **cache_kwargs)

for counter in range(1, 3):
before = datetime.datetime.now(tz=datetime.timezone.utc)
Expand All @@ -49,7 +55,7 @@ def test_cacheable(tmp_path: pathlib.Path) -> None:
assert cur.fetchall() == [
(
1,
"a8260ac3cdc1404aa64a6fb71e853049",
expected_hash,
"9999-12-31 00:00:00.000000",
'{"a": "test", "b": null, "args": [], "kwargs": {}}',
counter,
Expand All @@ -70,7 +76,7 @@ def test_cacheable(tmp_path: pathlib.Path) -> None:


@pytest.mark.parametrize("raise_all_encoding_errors", [True, False])
def test_encode_errors(tmp_path: pathlib.Path, raise_all_encoding_errors: bool) -> None:
def test_encode_errors(raise_all_encoding_errors: bool) -> None:
config.set(raise_all_encoding_errors=raise_all_encoding_errors)

cfunc = cache.cacheable(func)
Expand Down Expand Up @@ -167,7 +173,7 @@ def test_expiration_and_return_cache_entry() -> None:
assert third.expiration == datetime.datetime(9999, 12, 31)


def test_tag(tmp_path: pathlib.Path) -> None:
def test_tag() -> None:
con = config.get().engine.raw_connection()
cur = con.cursor()

Expand Down

0 comments on commit 07fc1b7

Please sign in to comment.