diff --git a/sapp/bulk_saver.py b/sapp/bulk_saver.py index 98a4a44a..6472ade0 100644 --- a/sapp/bulk_saver.py +++ b/sapp/bulk_saver.py @@ -7,7 +7,7 @@ """ import logging -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, List, Optional, Type from sqlalchemy.dialects.mysql import insert as mysql_insert from sqlalchemy.dialects.sqlite import insert as sqlite_insert @@ -38,7 +38,7 @@ class BulkSaver: """Stores new objects created within a run and bulk save them""" # order is significant, objects will be saved in this order. - SAVING_CLASSES_ORDER = [ + DEFAULT_SAVING_CLASSES_ORDER = [ SharedText, Issue, IssueInstanceFixInfo, @@ -62,18 +62,23 @@ class BulkSaver: BATCH_SIZE = 30000 def __init__( - self, primary_key_generator: Optional[PrimaryKeyGenerator] = None + self, + primary_key_generator: Optional[PrimaryKeyGenerator] = None, + extra_saving_classes: Optional[List[Type[object]]] = None, ) -> None: self.primary_key_generator: PrimaryKeyGenerator = ( primary_key_generator or PrimaryKeyGenerator() ) + self.saving_classes_order: List[Type[object]] = ( + extra_saving_classes or [] + ) + self.DEFAULT_SAVING_CLASSES_ORDER self.saving: Dict[str, Any] = {} - for cls in self.SAVING_CLASSES_ORDER: + for cls in self.saving_classes_order: self.saving[cls.__name__] = [] # pyre-fixme[2]: Parameter must be annotated. def add(self, item) -> None: - assert item.model in self.SAVING_CLASSES_ORDER, ( + assert item.model in self.saving_classes_order, ( "%s should be added with session.add()" % item.model.__name__ ) self.saving[item.model.__name__].append(item) @@ -81,7 +86,7 @@ def add(self, item) -> None: # pyre-fixme[2]: Parameter must be annotated. def add_all(self, items) -> None: if items: - assert items[0].model in self.SAVING_CLASSES_ORDER, ( + assert items[0].model in self.saving_classes_order, ( "%s should be added with session.add_all()" % items[0].model.__name__ ) self.saving[items[0].model.__name__].extend(items) @@ -96,7 +101,7 @@ def save_all( ) -> None: saving_classes = [ cls - for cls in self.SAVING_CLASSES_ORDER + for cls in self.saving_classes_order if len(self.saving[cls.__name__]) != 0 ] @@ -276,6 +281,6 @@ def add_trace_frame_annotation_trace_frame_assoc( def dump_stats(self) -> str: stat_str = "" - for cls in self.SAVING_CLASSES_ORDER: + for cls in self.saving_classes_order: stat_str += "%s: %d\n" % (cls.__name__, len(self.saving[cls.__name__])) return stat_str diff --git a/sapp/db_support.py b/sapp/db_support.py index d4f42b25..0499d2b8 100644 --- a/sapp/db_support.py +++ b/sapp/db_support.py @@ -361,7 +361,14 @@ def reserve( count = item_counts[cls.__name__] else: count = 1 - self._reserve_id_range(session, cls, count) + + if count > 0: + self._reserve_id_range(session, cls, count) + elif count == 0: + # Don't bother locking rows if there's nothing to reserve + pass + else: + raise ValueError(f"{cls.__name__} count must be >= 0") return self diff --git a/sapp/pipeline/database_saver.py b/sapp/pipeline/database_saver.py index c70190b2..0f8a3e0f 100644 --- a/sapp/pipeline/database_saver.py +++ b/sapp/pipeline/database_saver.py @@ -7,7 +7,7 @@ import collections import logging -from typing import Optional, Tuple +from typing import List, Optional, Tuple, Type from ..bulk_saver import BulkSaver from ..db import DB @@ -40,13 +40,16 @@ def __init__( database: DB, primary_key_generator: Optional[PrimaryKeyGenerator] = None, dry_run: bool = False, + extra_saving_classes: Optional[List[Type[object]]] = None, ) -> None: self.dbname: str = database.dbname self.database = database self.primary_key_generator: PrimaryKeyGenerator = ( primary_key_generator or PrimaryKeyGenerator() ) - self.bulk_saver = BulkSaver(self.primary_key_generator) + self.bulk_saver = BulkSaver( + self.primary_key_generator, extra_saving_classes=extra_saving_classes + ) self.dry_run = dry_run self.graph: TraceGraph self.summary: Summary @@ -116,7 +119,6 @@ def _save(self) -> RunSummary: run_label=self.summary.get("meta_run_child_label", None), ) ) - session.add_all(self.summary.get("run_attributes", [])) session.commit() run_id = self.summary["run"].id.resolved()