Skip to content

Commit

Permalink
Merge pull request #1174 from dedupeio/closure_branch_and_bound
Browse files Browse the repository at this point in the history
WIP: closure branch and bound
  • Loading branch information
fgregg authored Dec 19, 2023
2 parents 7c05b1c + 3b159a2 commit 3179777
Show file tree
Hide file tree
Showing 10 changed files with 143 additions and 146 deletions.
7 changes: 6 additions & 1 deletion .readthedocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
# Required
version: 2

# Set the OS, Python version and other tools you might need
build:
os: ubuntu-22.04
tools:
python: "3.12"

# Build documentation in the docs/ directory with Sphinx
sphinx:
configuration: docs/conf.py
Expand All @@ -16,7 +22,6 @@ formats: all

# Optionally set the version of Python and requirements required to build your docs
python:
version: 3.7
install:
- requirements: docs/requirements.txt
- method: pip
Expand Down
22 changes: 15 additions & 7 deletions dedupe/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
from pkgutil import extend_path

__path__ = extend_path(__path__, __name__)

from dedupe._init import * # noqa
from dedupe.api import ( # noqa: F401
Dedupe,
Gazetteer,
RecordLink,
StaticDedupe,
StaticGazetteer,
StaticRecordLink,
)
from dedupe.convenience import ( # noqa: F401
canonicalize,
console_label,
training_data_dedupe,
training_data_link,
)
from dedupe.serializer import read_training, write_training # noqa: F401
15 changes: 0 additions & 15 deletions dedupe/_init.py

This file was deleted.

113 changes: 113 additions & 0 deletions dedupe/branch_and_bound.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from __future__ import annotations

import functools
import warnings
from typing import Any, Iterable, Mapping, Sequence, Tuple

from ._typing import Cover
from .predicates import Predicate

Partial = Tuple[Predicate, ...]


def _reachable(dupe_cover: Mapping[Any, frozenset[int]]) -> int:
    """Return the number of distinct items covered by any entry of ``dupe_cover``.

    Args:
        dupe_cover: mapping whose values are the collections of covered items.

    Returns:
        The size of the union of all values; ``0`` for an empty mapping.
    """
    # ``union`` is called on an empty frozenset instance instead of the
    # unbound ``frozenset.union``: the unbound form needs at least one
    # argument and rejects a first argument that is not a frozenset (for
    # example a plain ``set``), whereas the bound form accepts zero or more
    # arbitrary iterables.
    return len(frozenset().union(*dupe_cover.values()))


def _remove_dominated(coverage: Cover, dominator: Predicate) -> Cover:
    """Return a copy of ``coverage`` without the entries dominated by ``dominator``.

    An entry is dominated when the dominator's ``cover_count`` is less than
    or equal to the entry's ``cover_count`` and the dominator's cover is a
    superset of the entry's cover. The dominator always dominates itself, so
    it is never part of the result. ``coverage`` is left untouched.
    """
    ceiling = coverage[dominator]

    survivors = {}
    for candidate, items in coverage.items():
        no_cheaper = dominator.cover_count <= candidate.cover_count
        if no_cheaper and ceiling >= items:
            # Dominated: costs at least as much and covers nothing extra.
            continue
        survivors[candidate] = items

    return survivors


def _uncovered_by(
    coverage: Mapping[Any, frozenset[int]], covered: frozenset[int]
) -> dict[Any, frozenset[int]]:
    """Subtract ``covered`` from every cover and drop the covers left empty.

    Returns a new dict in the iteration order of ``coverage``; the input
    mapping is not modified.
    """
    trimmed = ((key, cover - covered) for key, cover in coverage.items())
    return {key: leftover for key, leftover in trimmed if leftover}


def _order_by(
    candidates: Mapping[Predicate, Sequence[Any]], p: Predicate
) -> tuple[int, float]:
    """Build the sort key for predicate ``p``.

    The key is ``(size of p's cover, negated cover_count)``, so taking the
    ``max`` over keys prefers the largest cover and breaks ties in favour of
    the smaller ``cover_count``.
    """
    cover_size = len(candidates[p])
    negated_count = -p.cover_count
    return cover_size, negated_count


def _score(partial: Iterable[Predicate]) -> float:
    """Return the total ``cover_count`` of the predicates in ``partial``."""
    counts = [predicate.cover_count for predicate in partial]
    return sum(counts)


def _suppress_recursion_error(func):
    """Decorate ``func`` so that a ``RecursionError`` becomes a warning.

    The returned wrapper forwards all arguments to ``func`` and returns its
    result. If ``func`` raises ``RecursionError``, the wrapper emits a
    warning and returns ``None`` instead of propagating the exception. Any
    other exception propagates unchanged.
    """

    # Preserve the wrapped function's name and docstring on the wrapper.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except RecursionError:
            warnings.warn("Recursion limit reached while searching for predicates")
            return None

    return wrapper


def search(candidates, target: int, max_calls: int) -> Partial:
    """Branch-and-bound search for a low-``cover_count`` covering set.

    Looks for a tuple of predicates whose combined cover contains at least
    ``target`` items while keeping the sum of their ``cover_count`` values
    as small as possible.

    Args:
        candidates: mapping from predicate to the frozenset of items it
            covers.
        target: minimum number of distinct items the result must cover.
        max_calls: budget on the number of search nodes that are expanded.

    Returns:
        The cheapest qualifying tuple found within the budget, or ``()``
        when none was found.
    """
    budget = max_calls

    best_score = float("inf")
    best_partial: Partial = ()

    # Covers get trimmed as the search descends, so keep the untouched
    # covers around to measure what a partial solution really covers.
    full_cover = candidates.copy()

    def _covered(partial: Partial) -> int:
        if not partial:
            return 0
        return len(frozenset.union(*(full_cover[p] for p in partial)))

    @_suppress_recursion_error
    def walk(candidates: Cover, partial: Partial = ()) -> None:
        nonlocal budget, best_partial, best_score

        if budget <= 0:
            return

        budget -= 1

        covered = _covered(partial)
        score = _score(partial)

        if covered >= target:
            # Complete solution: keep it only if it beats the incumbent.
            if score < best_score:
                best_partial = partial
                best_score = score
            return

        # Bound: a predicate whose cover_count alone would reach the
        # incumbent's score cannot be part of a better solution.
        window = best_score - score
        candidates = {
            p: cover for p, cover in candidates.items() if p.cover_count < window
        }

        # Give up on this node when nothing is left or the target is out of
        # reach even if every remaining candidate were used.
        if not candidates or _reachable(candidates) + covered < target:
            return

        best = max(candidates, key=functools.partial(_order_by, candidates))

        # Branch 1: take ``best`` and continue with what it leaves uncovered.
        walk(_uncovered_by(candidates, candidates[best]), partial + (best,))

        # Branch 2: skip ``best`` together with everything it dominates.
        walk(_remove_dominated(candidates, best), partial)

    walk(candidates)
    return best_partial
114 changes: 3 additions & 111 deletions dedupe/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
from typing import TYPE_CHECKING, overload
from warnings import warn

from . import blocking
from . import blocking, branch_and_bound

if TYPE_CHECKING:
from typing import Any, Iterable, Mapping, Sequence
from typing import Iterable, Sequence

from ._typing import (
ComparisonCover,
Expand Down Expand Up @@ -75,8 +75,7 @@ def learn(
else:
raise ValueError("candidate_type is not valid")

searcher = BranchBound(target_cover, 2500)
final_predicates = searcher.search(candidate_cover)
final_predicates = branch_and_bound.search(candidate_cover, target_cover, 2500)

logger.info("Final predicate set:")
for predicate in final_predicates:
Expand Down Expand Up @@ -329,113 +328,6 @@ def coveredPairs(self, blocker, records_1, records_2):
return pair_cover


class BranchBound(object):
    """Branch-and-bound search for a low-``cover_count`` set of predicates.

    The search looks for a tuple of predicates whose combined cover holds
    at least ``target`` items while minimising the sum of their
    ``cover_count`` values. An instance carries the search state
    (remaining call budget and best solution so far) across the recursive
    calls of :meth:`search`.
    """

    def __init__(self, target: int, max_calls: int) -> None:
        """Store the cover ``target`` and the budget of ``search`` calls."""
        self.target: int = target
        self.calls: int = max_calls

        self.cheapest_score: float = float("inf")
        self.original_cover: Cover = {}
        self.cheapest: tuple[Predicate, ...] = ()

    def search(
        self, candidates: Cover, partial: tuple[Predicate, ...] = ()
    ) -> tuple[Predicate, ...]:
        """Explore ``candidates`` and return the cheapest solution found so far.

        Args:
            candidates: mapping from predicate to the frozenset of items it
                covers (already trimmed on recursive calls).
            partial: predicates chosen on the path to this node.

        Returns:
            ``self.cheapest``: the best qualifying tuple seen so far, or
            ``()`` when none has been found.
        """
        if self.calls <= 0:
            return self.cheapest

        # The first call sees the untrimmed covers; remember them so that
        # ``covered`` can measure partial solutions against the full data.
        if not self.original_cover:
            self.original_cover = candidates.copy()

        self.calls -= 1

        covered = self.covered(partial)
        score = self.score(partial)

        if covered >= self.target:
            if score < self.cheapest_score:
                self.cheapest = partial
                self.cheapest_score = score
            return self.cheapest

        # Bound: drop predicates whose cover_count alone would reach the
        # incumbent's score.
        window = self.cheapest_score - score
        candidates = {
            p: cover for p, cover in candidates.items() if p.cover_count < window
        }

        reachable = self.reachable(candidates) + covered

        if candidates and reachable >= self.target:
            order_by = functools.partial(self.order_by, candidates)
            best = max(candidates, key=order_by)

            # Branch 1: take ``best`` and continue with what is still
            # uncovered.
            remaining = self.uncovered_by(candidates, candidates[best])
            try:
                self.search(remaining, partial + (best,))
            except RecursionError:
                return self.cheapest

            del remaining

            # Branch 2: skip ``best`` and everything it dominates.
            reduced = self.remove_dominated(candidates, best)
            try:
                self.search(reduced, partial)
            except RecursionError:
                return self.cheapest

            del reduced

        return self.cheapest

    @staticmethod
    def order_by(
        candidates: Mapping[Predicate, Sequence[Any]], p: Predicate
    ) -> tuple[int, float]:
        """Sort key for ``p``: ``(size of its cover, negated cover_count)``."""
        return (len(candidates[p]), -p.cover_count)

    @staticmethod
    def score(partial: Iterable[Predicate]) -> float:
        """Return the total ``cover_count`` of the predicates in ``partial``."""
        return sum(p.cover_count for p in partial)

    def covered(self, partial: tuple[Predicate, ...]) -> int:
        """Count the distinct items covered by ``partial`` in the original covers."""
        if partial:
            return len(frozenset.union(*(self.original_cover[p] for p in partial)))
        else:
            return 0

    @staticmethod
    def reachable(dupe_cover: Mapping[Any, frozenset[int]]) -> int:
        """Count the distinct items covered by any entry of ``dupe_cover``."""
        if dupe_cover:
            return len(frozenset.union(*dupe_cover.values()))
        else:
            return 0

    @staticmethod
    def remove_dominated(coverage: Cover, dominator: Predicate) -> Cover:
        """Return ``coverage`` without the entries dominated by ``dominator``.

        An entry is dominated when the dominator's ``cover_count`` is less
        than or equal to the entry's and the dominator's cover is a superset
        of the entry's cover; the dominator itself is therefore removed too.

        A new dict is returned. The mapping passed in is not modified, so
        callers that still hold a reference to it keep their data intact.
        """
        dominant_cover = coverage[dominator]

        return {
            pred: cover
            for pred, cover in coverage.items()
            if not (
                dominator.cover_count <= pred.cover_count and dominant_cover >= cover
            )
        }

    @staticmethod
    def uncovered_by(
        coverage: Mapping[Any, frozenset[int]], covered: frozenset[int]
    ) -> dict[Any, frozenset[int]]:
        """Subtract ``covered`` from every cover, dropping covers left empty."""
        remaining = {}
        for predicate, uncovered in coverage.items():
            still_uncovered = uncovered - covered
            if still_uncovered:
                remaining[predicate] = still_uncovered

        return remaining


class InfiniteSet(object):
def __and__(self, item):
return item
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ dependencies = [
"scikit-learn",
"affinegap>=1.3",
"categorical-distance>=1.9",
"dedupe-variable-datetime",
"numpy>=1.20",
"doublemetaphone",
"highered>=0.2.0",
Expand Down
8 changes: 3 additions & 5 deletions tests/test_blocking.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import unittest
from collections import defaultdict

from future.utils import viewitems, viewvalues

import dedupe


Expand Down Expand Up @@ -54,7 +52,7 @@ def test_unconstrained_inverted_index(self):
[dedupe.predicates.TfidfTextSearchPredicate(0.0, "name")]
)

blocker.index(set(record["name"] for record in viewvalues(self.data_d)), "name")
blocker.index(set(record["name"] for record in self.data_d.values()), "name")

blocks = defaultdict(set)

Expand Down Expand Up @@ -87,13 +85,13 @@ def setUp(self):

self.records_1 = dict(
(record_id, record)
for record_id, record in viewitems(data_d)
for record_id, record in data_d.items()
if record["dataset"] == 0
)

self.fields_2 = dict(
(record_id, record["name"])
for record_id, record in viewitems(data_d)
for record_id, record in data_d.items()
if record["dataset"] == 1
)

Expand Down
2 changes: 0 additions & 2 deletions tests/test_predicate_functions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import unittest

from future.builtins import str

from dedupe import predicate_functions as fn
from dedupe.cpredicates import ngrams

Expand Down
2 changes: 0 additions & 2 deletions tests/test_predicates.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import unittest

from future.builtins import str

from dedupe import predicates


Expand Down
5 changes: 3 additions & 2 deletions tests/test_training.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import unittest

import dedupe
import dedupe.branch_and_bound as branch_and_bound
import dedupe.training as training


Expand Down Expand Up @@ -67,8 +68,8 @@ def test_uncovered_by(self):

before_copy = before.copy()

assert training.BranchBound.uncovered_by(before, frozenset()) == before
assert training.BranchBound.uncovered_by(before, frozenset({3})) == after
assert branch_and_bound._uncovered_by(before, frozenset()) == before
assert branch_and_bound._uncovered_by(before, frozenset({3})) == after
assert before == before_copy

def test_covered_pairs(self):
Expand Down

0 comments on commit 3179777

Please sign in to comment.