allow getting disk usage from database (#131)
malmans2 authored Sep 4, 2024
1 parent 3a3f023 commit 7fa5a4b
Showing 3 changed files with 96 additions and 87 deletions.
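Before the file-by-file diff, here is a minimal usage sketch of the option this commit adds. It assumes a cacholote cache has already been configured elsewhere; the maxsize value is made up, while the `clean_cache_files` arguments (`maxsize`, `method`, `depth`, `use_database`) and the mutual exclusivity with `delete_unknown_files` are taken from the diff below.

```python
from cacholote import clean

# Size the cache from the entries recorded in the cacholote database
# instead of walking the cache directory with fs.du() (new in this commit).
clean.clean_cache_files(
    1_000_000,          # maxsize in bytes; hypothetical value
    method="LRU",
    depth=1,
    use_database=True,  # new flag, defaults to False
)

# use_database infers sizes from the database only, so combining it with
# delete_unknown_files raises a ValueError (see clean.py below).
```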
2 changes: 1 addition & 1 deletion cacholote/cache.py
@@ -97,7 +97,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
return _decode_and_update(session, cache_entry, settings)
except decode.DecodeError as ex:
warnings.warn(str(ex), UserWarning)
- clean._delete_cache_entry(session, cache_entry)
+ clean._delete_cache_entries(session, cache_entry)

result = func(*args, **kwargs)
cache_entry = database.CacheEntry(
175 changes: 91 additions & 84 deletions cacholote/clean.py
@@ -20,6 +20,7 @@
import posixpath
from typing import Any, Callable, Literal, Optional

+ import fsspec
import pydantic
import sqlalchemy as sa
import sqlalchemy.orm
@@ -35,7 +36,9 @@
)


- def _get_files_from_cache_entry(cache_entry: database.CacheEntry) -> dict[str, str]:
+ def _get_files_from_cache_entry(
+     cache_entry: database.CacheEntry, key: str | None
+ ) -> dict[str, Any]:
result = cache_entry.result
if not isinstance(result, (list, tuple, set)):
result = [result]
@@ -48,27 +51,57 @@ def _get_files_from_cache_entry(cache_entry: database.CacheEntry) -> dict[str, s
and obj["callable"] in FILE_RESULT_CALLABLES
):
fs, urlpath = extra_encoders._get_fs_and_urlpath(*obj["args"][:2])
- files[fs.unstrip_protocol(urlpath)] = obj["args"][0]["type"]
+ value = obj["args"][0]
+ if key is not None:
+     value = value[key]
+ files[fs.unstrip_protocol(urlpath)] = value
return files


- def _delete_cache_entry(
-     session: sa.orm.Session, cache_entry: database.CacheEntry
+ def _remove_files(
+     fs: fsspec.AbstractFileSystem,
+     files: list[str],
+     max_tries: int = 10,
+     **kwargs: Any,
) -> None:
-     fs, _ = utils.get_cache_files_fs_dirname()
-     files_to_delete = _get_files_from_cache_entry(cache_entry)
-     logger = config.get().logger
+     assert max_tries >= 1
+     if not files:
+         return

+     config.get().logger.info("deleting files", n_files_to_delete=len(files), **kwargs)

+     n_tries = 0
+     while files:
+         n_tries += 1
+         try:
+             fs.rm(files, **kwargs)
+             return
+         except FileNotFoundError:
+             # Another concurrent process might have deleted files
+             if n_tries >= max_tries:
+                 raise
+             files = [file for file in files if fs.exists(file)]

-     # First, delete database entry
-     logger.info("deleting cache entry", cache_entry=cache_entry)
-     session.delete(cache_entry)

+ def _delete_cache_entries(
+     session: sa.orm.Session, *cache_entries: database.CacheEntry
+ ) -> None:
+     fs, _ = utils.get_cache_files_fs_dirname()
+     files_to_delete = []
+     dirs_to_delete = []
+     for cache_entry in cache_entries:
+         session.delete(cache_entry)

+         files = _get_files_from_cache_entry(cache_entry, key="type")
+         for file, file_type in files.items():
+             if file_type == "application/vnd+zarr":
+                 dirs_to_delete.append(file)
+             else:
+                 files_to_delete.append(file)
    database._commit_or_rollback(session)

-     # Then, delete files
-     for urlpath, file_type in files_to_delete.items():
-         if fs.exists(urlpath):
-             logger.info("deleting cache file", urlpath=urlpath)
-             fs.rm(urlpath, recursive=file_type == "application/vnd+zarr")
+     _remove_files(fs, files_to_delete, recursive=False)
+     _remove_files(fs, dirs_to_delete, recursive=True)


def delete(func_to_del: str | Callable[..., Any], *args: Any, **kwargs: Any) -> None:
@@ -88,19 +121,20 @@ def delete(func_to_del: str | Callable[..., Any], *args: Any, **kwargs: Any) ->
for cache_entry in session.scalars(
sa.select(database.CacheEntry).filter(database.CacheEntry.key == hexdigest)
):
- _delete_cache_entry(session, cache_entry)
+ _delete_cache_entries(session, cache_entry)


class _Cleaner:
- def __init__(self, depth: int) -> None:
+ def __init__(self, depth: int, use_database: bool) -> None:
self.logger = config.get().logger
self.fs, self.dirname = utils.get_cache_files_fs_dirname()

self.urldir = self.fs.unstrip_protocol(self.dirname)

self.logger.info("getting disk usage")
self.file_sizes: dict[str, int] = collections.defaultdict(int)
- for path, size in self.fs.du(self.dirname, total=False).items():
+ du = self.known_files if use_database else self.fs.du(self.dirname, total=False)
+ for path, size in du.items():
# Group dirs
urlpath = self.fs.unstrip_protocol(path)
parts = urlpath.replace(self.urldir, "", 1).strip("/").split("/")
@@ -120,6 +154,16 @@ def log_disk_usage(self) -> None:
def stop_cleaning(self, maxsize: int) -> bool:
return self.disk_usage <= maxsize

+ @property
+ def known_files(self) -> dict[str, int]:
+     known_files: dict[str, int] = {}
+     with config.get().instantiated_sessionmaker() as session:
+         for cache_entry in session.scalars(sa.select(database.CacheEntry)):
+             known_files.update(
+                 _get_files_from_cache_entry(cache_entry, key="file:size")
+             )
+     return known_files

def get_unknown_files(self, lock_validity_period: float | None) -> set[str]:
self.logger.info("getting unknown files")

@@ -137,25 +181,15 @@ def get_unknown_files(self, lock_validity_period: float | None) -> set[str]:
locked_files.add(urlpath)
locked_files.add(urlpath.rsplit(".lock", 1)[0])

- if unknown_files := (set(self.file_sizes) - locked_files):
-     with config.get().instantiated_sessionmaker() as session:
-         for cache_entry in session.scalars(sa.select(database.CacheEntry)):
-             for known_file in _get_files_from_cache_entry(cache_entry):
-                 unknown_files.discard(known_file)
-             if not unknown_files:
-                 break
-     return unknown_files
+ return set(self.file_sizes) - locked_files - set(self.known_files)

def delete_unknown_files(
self, lock_validity_period: float | None, recursive: bool
) -> None:
unknown_files = self.get_unknown_files(lock_validity_period)
for urlpath in unknown_files:
self.pop_file_size(urlpath)
- self.remove_files(
-     list(unknown_files),
-     recursive=recursive,
- )
+ _remove_files(self.fs, list(unknown_files), recursive=recursive)
self.log_disk_usage()

@staticmethod
@@ -207,30 +241,6 @@ def _get_method_sorters(
sorters.append(database.CacheEntry.expiration)
return sorters

- def remove_files(
-     self,
-     files: list[str],
-     max_tries: int = 10,
-     **kwargs: Any,
- ) -> None:
-     assert max_tries >= 1
-     if not files:
-         return

-     self.logger.info("deleting files", n_files_to_delete=len(files), **kwargs)

-     n_tries = 0
-     while files:
-         n_tries += 1
-         try:
-             self.fs.rm(files, **kwargs)
-             return
-         except FileNotFoundError:
-             # Another concurrent process might have deleted files
-             if n_tries >= max_tries:
-                 raise
-             files = [file for file in files if self.fs.exists(file)]

def delete_cache_files(
self,
maxsize: int,
@@ -244,37 +254,27 @@ def delete_cache_files(
if self.stop_cleaning(maxsize):
return

- files_to_delete = []
- dirs_to_delete = []
+ entries_to_delete = []
self.logger.info("getting cache entries to delete")
- n_entries_to_delete = 0
with config.get().instantiated_sessionmaker() as session:
    for cache_entry in session.scalars(
        sa.select(database.CacheEntry).filter(*filters).order_by(*sorters)
    ):
-         files = _get_files_from_cache_entry(cache_entry)
+         files = _get_files_from_cache_entry(cache_entry, key="file:size")
        if any(file.startswith(self.urldir) for file in files):
-             n_entries_to_delete += 1
-             session.delete(cache_entry)

-             for file, file_type in files.items():
+             entries_to_delete.append(cache_entry)
+             for file in files:
                self.pop_file_size(file)
-                 if file_type == "application/vnd+zarr":
-                     dirs_to_delete.append(file)
-                 else:
-                     files_to_delete.append(file)

        if self.stop_cleaning(maxsize):
            break

-     if n_entries_to_delete:
+     if entries_to_delete:
        self.logger.info(
-             "deleting cache entries", n_entries_to_delete=n_entries_to_delete
+             "deleting cache entries", n_entries_to_delete=len(entries_to_delete)
        )
-         database._commit_or_rollback(session)
+         _delete_cache_entries(session, *entries_to_delete)

- self.remove_files(files_to_delete, recursive=False)
- self.remove_files(dirs_to_delete, recursive=True)
self.log_disk_usage()

if not self.stop_cleaning(maxsize):
Expand All @@ -296,6 +296,7 @@ def clean_cache_files(
tags_to_clean: list[str | None] | None = None,
tags_to_keep: list[str | None] | None = None,
depth: int = 1,
+ use_database: bool = False,
) -> None:
"""Clean cache files.
@@ -318,8 +319,15 @@
tags_to_clean and tags_to_keep are mutually exclusive.
depth: int, default: 1
depth for grouping cache files
+ use_database: bool, default: False
+     Whether to infer disk usage from the cacholote database
"""
- cleaner = _Cleaner(depth=depth)
+ if use_database and delete_unknown_files:
+     raise ValueError(
+         "'use_database' and 'delete_unknown_files' are mutually exclusive"
+     )

+ cleaner = _Cleaner(depth=depth, use_database=use_database)

if delete_unknown_files:
cleaner.delete_unknown_files(lock_validity_period, recursive)
@@ -352,15 +360,15 @@ def clean_invalid_cache_entries(
for cache_entry in session.scalars(
sa.select(database.CacheEntry).filter(*filters)
):
- _delete_cache_entry(session, cache_entry)
+ _delete_cache_entries(session, cache_entry)

if try_decode:
with config.get().instantiated_sessionmaker() as session:
for cache_entry in session.scalars(sa.select(database.CacheEntry)):
try:
decode.loads(cache_entry._result_as_string)
except decode.DecodeError:
- _delete_cache_entry(session, cache_entry)
+ _delete_cache_entries(session, cache_entry)


def expire_cache_entries(
@@ -379,15 +387,14 @@ def expire_cache_entries(
if after is not None:
filters.append(database.CacheEntry.created_at > after)

- count = 0
with config.get().instantiated_sessionmaker() as session:
-     for cache_entry in session.scalars(
-         sa.select(database.CacheEntry).filter(*filters)
-     ):
-         count += 1
-         if delete:
-             session.delete(cache_entry)
-         else:
+     cache_entries = list(
+         session.scalars(sa.select(database.CacheEntry).filter(*filters))
+     )
+     if delete:
+         _delete_cache_entries(session, *cache_entries)
+     else:
+         for cache_entry in cache_entries:
            cache_entry.expiration = now
-     database._commit_or_rollback(session)
- return count
+         database._commit_or_rollback(session)
+     return len(cache_entries)
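For reviewers, an illustrative sketch (not taken from the repository) of the data the cleaner works from when `use_database=True`: the new `known_files` property maps each urlpath, as returned by `fs.unstrip_protocol`, to the size stored in the cache entry under `"file:size"`, and disk usage becomes the sum of those sizes instead of an `fs.du` walk. The paths and sizes below are hypothetical.

```python
# Hypothetical shape of _Cleaner.known_files when use_database=True:
# {unstripped urlpath: size recorded in the cache entry under "file:size"}
known_files: dict[str, int] = {
    "file:///tmp/cache_files/06f2b.nc": 1,    # made-up entry
    "file:///tmp/cache_files/a1c3d.zarr": 1,  # made-up entry
}

# Disk usage is then derived from the database alone:
disk_usage = sum(known_files.values())  # replaces fs.du(dirname, total=False)
```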
6 changes: 4 additions & 2 deletions tests/test_60_clean.py
@@ -41,12 +41,14 @@ def cached_now() -> datetime.datetime:
@pytest.mark.parametrize("method", ["LRU", "LFU"])
@pytest.mark.parametrize("set_cache", ["file", "cads"], indirect=True)
@pytest.mark.parametrize("folder,depth", [("", 1), ("", 2), ("foo", 2)])
@pytest.mark.parametrize("use_database", [True, False])
def test_clean_cache_files(
tmp_path: pathlib.Path,
set_cache: str,
method: Literal["LRU", "LFU"],
folder: str,
depth: int,
+ use_database: bool,
) -> None:
con = config.get().engine.raw_connection()
cur = con.cursor()
@@ -66,12 +68,12 @@ def test_clean_cache_files(
assert set(fs.ls(dirname)) == {lru_path, lfu_path}

# Do not clean
- clean.clean_cache_files(2, method=method, depth=depth)
+ clean.clean_cache_files(2, method=method, depth=depth, use_database=use_database)
cur.execute("SELECT COUNT(*) FROM cache_entries", ())
assert cur.fetchone() == (fs.du(dirname),) == (2,)

# Delete one file
- clean.clean_cache_files(1, method=method, depth=depth)
+ clean.clean_cache_files(1, method=method, depth=depth, use_database=use_database)
cur.execute("SELECT COUNT(*) FROM cache_entries", ())
assert cur.fetchone() == (fs.du(dirname),) == (1,)
assert not fs.exists(lru_path if method == "LRU" else lfu_path)
