Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix phototour dataset #8733

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 14 additions & 32 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1305,38 +1305,29 @@ def test_not_found_or_corrupted(self):
class PhotoTourTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.PhotoTour

# The PhotoTour dataset returns examples with different features with respect to the 'train' parameter. Thus,
# we overwrite 'FEATURE_TYPES' with a dummy value to satisfy the initial checks of the base class. Furthermore, we
# overwrite the 'test_feature_types()' method to select the correct feature types before the test is run.
FEATURE_TYPES = ()
_TRAIN_FEATURE_TYPES = (torch.Tensor,)
_TEST_FEATURE_TYPES = (torch.Tensor, torch.Tensor, torch.Tensor)

combinations_grid(train=(True, False))
# The PhotoTour dataset returns only a single feature type.
FEATURE_TYPES = (torch.Tensor,)

_NAME = "liberty"
_NAME = "notredame"

def dataset_args(self, tmpdir, config):
return tmpdir, self._NAME

def inject_fake_data(self, tmpdir, config):
tmpdir = pathlib.Path(tmpdir)

# In contrast to the original data, the fake images injected here comprise only a single patch. Thus,
# num_images == num_patches.
# Simulate fake data
num_patches = 5

image_files = self._create_images(tmpdir, self._NAME, num_patches)
point_ids, info_file = self._create_info_file(tmpdir / self._NAME, num_patches)
num_matches, matches_file = self._create_matches_file(tmpdir / self._NAME, num_patches, point_ids)

self._create_archive(tmpdir, self._NAME, *image_files, info_file, matches_file)
self._create_archive(tmpdir, self._NAME, *image_files, info_file)

return num_patches if config["train"] else num_matches
return num_patches

def _create_images(self, root, name, num_images):
# The images in the PhotoTour dataset comprises of multiple grayscale patches of 64 x 64 pixels. Thus, the
# smallest fake image is 64 x 64 pixels and comprises a single patch.
# Generate images
return datasets_utils.create_image_folder(
root, name, lambda idx: f"patches{idx:04d}.bmp", num_images, size=(1, 64, 64)
)
Expand All @@ -1350,18 +1341,6 @@ def _create_info_file(self, root, num_images):

return point_ids, file

def _create_matches_file(self, root, num_patches, point_ids):
lines = [
f"{patch_id1} {point_ids[patch_id1]} 0 {patch_id2} {point_ids[patch_id2]} 0\n"
for patch_id1, patch_id2 in itertools.combinations(range(num_patches), 2)
]

file = root / "m50_100000_100000_0.txt"
with open(file, "w") as fh:
fh.writelines(lines)

return len(lines), file

def _create_archive(self, root, name, *files):
archive = root / f"{name}.zip"
with zipfile.ZipFile(archive, "w") as zip:
Expand All @@ -1372,12 +1351,10 @@ def _create_archive(self, root, name, *files):

@datasets_utils.test_all_configs
def test_feature_types(self, config):
feature_types = self.FEATURE_TYPES
self.FEATURE_TYPES = self._TRAIN_FEATURE_TYPES if config["train"] else self._TEST_FEATURE_TYPES
try:
super().test_feature_types.__wrapped__(self, config)
finally:
self.FEATURE_TYPES = feature_types
except KeyError as e:
pytest.fail(f"KeyError during test_feature_types: {e}")


class Flickr8kTestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -1869,6 +1846,11 @@ def test_class_to_idx(self):
with self.create_dataset() as (dataset, _):
assert dataset.class_to_idx == class_to_idx

def test_images_download_preexisting(self):
with pytest.raises(RuntimeError):
with self.create_dataset({"download": True}):
pass


class INaturalistTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.INaturalist
Expand Down
4 changes: 1 addition & 3 deletions test/test_datasets_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,7 @@ def phototour():
return itertools.chain.from_iterable(
[
collect_urls(datasets.PhotoTour, ROOT, name=name, download=True)
# The names postfixed with '_harris' point to the domain 'matthewalunbrown.com'. For some reason all
# requests timeout from within CI. They are disabled until this is resolved.
for name in ("notredame", "yosemite", "liberty") # "notredame_harris", "yosemite_harris", "liberty_harris"
for name in ("notredame", "trevi", "halfdome")
]
)

Expand Down
127 changes: 39 additions & 88 deletions torchvision/datasets/phototour.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union
from typing import Callable, List, Optional, Union

import numpy as np
import torch
Expand All @@ -11,18 +11,7 @@


class PhotoTour(VisionDataset):
"""`Multi-view Stereo Correspondence <http://matthewalunbrown.com/patchdata/patchdata.html>`_ Dataset.

.. note::

We only provide the newer version of the dataset, since the authors state that it

is more suitable for training descriptors based on difference of Gaussian, or Harris corners, as the
patches are centred on real interest point detections, rather than being projections of 3D points as is the
case in the old dataset.

The original dataset is available under http://phototour.cs.washington.edu/patches/default.htm.

"""`Multi-view Stereo Correspondence <https://phototour.cs.washington.edu/patches/default.htm>`_ Dataset.

Args:
root (str or ``pathlib.Path``): Root directory where images are.
Expand All @@ -32,60 +21,42 @@ class PhotoTour(VisionDataset):
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.

"""

urls = {
"notredame_harris": [
"http://matthewalunbrown.com/patchdata/notredame_harris.zip",
"notredame_harris.zip",
"69f8c90f78e171349abdf0307afefe4d",
],
"yosemite_harris": [
"http://matthewalunbrown.com/patchdata/yosemite_harris.zip",
"yosemite_harris.zip",
"a73253d1c6fbd3ba2613c45065c00d46",
],
"liberty_harris": [
"http://matthewalunbrown.com/patchdata/liberty_harris.zip",
"liberty_harris.zip",
"c731fcfb3abb4091110d0ae8c7ba182c",
"trevi": [
"https://phototour.cs.washington.edu/patches/trevi.zip",
"trevi.zip",
"d49ab428f154554856f83dba8aa76539",
],
"notredame": [
"http://icvl.ee.ic.ac.uk/vbalnt/notredame.zip",
"https://phototour.cs.washington.edu/patches/notredame.zip",
"notredame.zip",
"509eda8535847b8c0a90bbb210c83484",
"0f801127085e405a61465605ea80c595",
],
"halfdome": [
"https://phototour.cs.washington.edu/patches/halfdome.zip",
"halfdome.zip",
"db871c5a86f4878c6754d0d12146440b",
],
"yosemite": ["http://icvl.ee.ic.ac.uk/vbalnt/yosemite.zip", "yosemite.zip", "533b2e8eb7ede31be40abc317b2fd4f0"],
"liberty": ["http://icvl.ee.ic.ac.uk/vbalnt/liberty.zip", "liberty.zip", "fdd9152f138ea5ef2091746689176414"],
}
means = {
"notredame": 0.4854,
"yosemite": 0.4844,
"liberty": 0.4437,
"notredame_harris": 0.4854,
"yosemite_harris": 0.4844,
"liberty_harris": 0.4437,
"trevi": 0.4832,
"notredame": 0.4757,
"halfdome": 0.4718,
}
stds = {
"notredame": 0.1864,
"yosemite": 0.1818,
"liberty": 0.2019,
"notredame_harris": 0.1864,
"yosemite_harris": 0.1818,
"liberty_harris": 0.2019,
"trevi": 0.1913,
"notredame": 0.1931,
"halfdome": 0.1791,
}
lens = {
"notredame": 468159,
"yosemite": 633587,
"liberty": 450092,
"liberty_harris": 379587,
"yosemite_harris": 450912,
"notredame_harris": 325295,
"trevi": 101120,
"notredame": 104196,
"halfdome": 107776,
}
image_ext = "bmp"
info_file = "info.txt"
matches_files = "m50_100000_100000_0.txt"

def __init__(
self,
Expand All @@ -112,30 +83,23 @@ def __init__(
self.cache()

# load the serialized data
self.data, self.labels, self.matches = torch.load(self.data_file, weights_only=True)
self.data, self.labels = torch.load(self.data_file, weights_only=True)

def __getitem__(self, index: int) -> Union[torch.Tensor, Tuple[Any, Any, torch.Tensor]]:
def __getitem__(self, index: int) -> torch.Tensor:
"""
Args:
index (int): Index

Returns:
tuple: (data1, data2, matches)
torch.Tensor: The image patch.
"""
if self.train:
data = self.data[index]
if self.transform is not None:
data = self.transform(data)
return data
m = self.matches[index]
data1, data2 = self.data[m[0]], self.data[m[1]]
data = self.data[index]
if self.transform is not None:
data1 = self.transform(data1)
data2 = self.transform(data2)
return data1, data2, m[2]
data = self.transform(data)
return data

def __len__(self) -> int:
return len(self.data if self.train else self.matches)
return len(self.data)

def _check_datafile_exists(self) -> bool:
return os.path.exists(self.data_file)
Expand Down Expand Up @@ -165,36 +129,35 @@ def download(self) -> None:

def cache(self) -> None:
# process and save as torch files

dataset = (
read_image_file(self.data_dir, self.image_ext, self.lens[self.name]),
read_info_file(self.data_dir, self.info_file),
read_matches_files(self.data_dir, self.matches_files),
)

with open(self.data_file, "wb") as f:
torch.save(dataset, f)

def extra_repr(self) -> str:
split = "Train" if self.train is True else "Test"
return f"Split: {split}"
return f"Dataset: {self.name}"


def read_image_file(data_dir: str, image_ext: str, n: int) -> torch.Tensor:
"""Return a Tensor containing the patches"""

def PIL2array(_img: Image.Image) -> np.ndarray:
"""Convert PIL image type to numpy 2D array"""
# Ensure the patch size is exactly 64x64
if _img.size != (64, 64):
raise ValueError(f"Invalid patch size: {_img.size}")
return np.array(_img.getdata(), dtype=np.uint8).reshape(64, 64)

def find_files(_data_dir: str, _image_ext: str) -> List[str]:
"""Return a list with the file names of the images containing the patches"""
files = []
# find those files with the specified extension
for file_dir in os.listdir(_data_dir):
if file_dir.endswith(_image_ext):
files.append(os.path.join(_data_dir, file_dir))
return sorted(files) # sort files in ascend order to keep relations
return sorted(files)

patches = []
list_files = find_files(data_dir, image_ext)
Expand All @@ -204,27 +167,15 @@ def find_files(_data_dir: str, _image_ext: str) -> List[str]:
for y in range(0, img.height, 64):
for x in range(0, img.width, 64):
patch = img.crop((x, y, x + 64, y + 64))
patches.append(PIL2array(patch))
try:
patches.append(PIL2array(patch))
except ValueError as e:
print(f"Skipping invalid patch at ({x}, {y}) in {fpath}: {e}")
return torch.ByteTensor(np.array(patches[:n]))


def read_info_file(data_dir: str, info_file: str) -> torch.Tensor:
"""Return a Tensor containing the list of labels
Read the file and keep only the ID of the 3D point.
"""
"""Return a Tensor containing the list of labels."""
with open(os.path.join(data_dir, info_file)) as f:
labels = [int(line.split()[0]) for line in f]
return torch.LongTensor(labels)


def read_matches_files(data_dir: str, matches_file: str) -> torch.Tensor:
"""Return a Tensor containing the ground truth matches
Read the file and keep only 3D point ID.
Matches are represented with a 1, non matches with a 0.
"""
matches = []
with open(os.path.join(data_dir, matches_file)) as f:
for line in f:
line_split = line.split()
matches.append([int(line_split[0]), int(line_split[3]), int(line_split[1] == line_split[4])])
return torch.LongTensor(matches)