-
Notifications
You must be signed in to change notification settings - Fork 0
/
omniglot_dataset.py
59 lines (43 loc) · 2.15 KB
/
omniglot_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from omniglot_loader import OmniglotLoader
import numpy as np
class OmniglotDataset:
    """Split Omniglot data into train/val/test sets and sample image pairs.

    Each split is stored as an ``(images, labels)`` tuple. ``images`` is a 5-D
    array indexed as ``images[class_index, image_index]`` followed by the
    ``(height, width, channel)`` image axes. This layout is the one unpacked in
    :meth:`load` and :meth:`get_image_pair`. ``labels`` holds one entry per
    class along axis 0, because it is sliced with the same index as ``images``.
    """

    def __init__(self, path="./data/"):
        """Create the dataset wrapper; the splits are populated by :meth:`load`.

        Args:
            path: Root directory handed to ``OmniglotLoader`` together with the
                ``"train_alphabets"`` and ``"test_alphabets"`` folder names.
        """
        self.__path = path
        self.__train_folder = "train_alphabets"
        self.__test_folder = "test_alphabets"
        self.__data_loader = OmniglotLoader(self.__path, self.__train_folder, self.__test_folder)
        # Each split becomes an (images, labels) tuple once load() has run.
        self.train_set = ()
        self.val_set = ()
        self.test_set = ()
        # (height, width, channel) of a single image, set by load().
        self.data_shape = ()

    def load(self, val_start_label='Latin/character01'):
        """Load all data and carve a validation split off the training alphabets.

        Classes keep the order returned by the loader. Every class from the
        first one labelled ``val_start_label`` onward goes to the validation
        set. All earlier classes stay in the training set.

        Args:
            val_start_label: Label of the first class that belongs to the
                validation split. Defaults to ``'Latin/character01'``.

        Raises:
            ValueError: If ``val_start_label`` is not among the training labels.
        """
        train_val_set, self.test_set = self.__data_loader.load_data()
        images, labels = train_val_set[0], train_val_set[1]
        _, _, height, width, channel = images.shape
        try:
            start_index_val = labels.tolist().index(val_start_label)
        except ValueError:
            raise ValueError(
                f"validation split label {val_start_label!r} not found in the training labels"
            ) from None
        self.train_set = (images[:start_index_val], labels[:start_index_val])
        self.val_set = (images[start_index_val:], labels[start_index_val:])
        self.data_shape = (height, width, channel)

    def get_data_classes(self, num_classes, data_type='train'):
        """Pick ``num_classes`` distinct class indices at random from one split.

        Args:
            num_classes: How many distinct classes to draw. It must not exceed
                the number of classes in the split, or NumPy raises ``ValueError``.
            data_type: ``'train'``, ``'val'`` or ``'test'``.

        Returns:
            A 1-D NumPy array of unique indices into axis 0 of the split's images.

        Raises:
            ValueError: If ``data_type`` is not one of the names above.
        """
        data = self.__get_data_type(data_type)
        return np.random.choice(data.shape[0], size=(num_classes,), replace=False)

    def get_image_pair(self, class_value, data_type='train', same_class=False):
        """Return one image from ``class_value`` and one from the same or another class.

        Args:
            class_value: Index of the first image's class (axis 0 of the split).
            data_type: ``'train'``, ``'val'`` or ``'test'``.
            same_class: If true, both images come from ``class_value`` and are
                guaranteed to be different images. Otherwise the second image
                comes from a class chosen uniformly among all other classes.

        Returns:
            ``(first_image, second_image)``, each of shape ``(height, width, channel)``.

        Raises:
            ValueError: On an unknown ``data_type``. Also raised for a
                different-class pair when the split has fewer than two classes.
        """
        data = self.__get_data_type(data_type)
        num_classes, num_images, _, _, _ = data.shape
        # Sample without replacement only for same-class pairs. That guarantees two
        # distinct images. Across classes the indices may coincide harmlessly.
        image_indices = np.random.choice(num_images, size=2, replace=(not same_class))
        first_image = data[class_value, image_indices[0]]
        if same_class:
            second_class = class_value
        else:
            second_class = self.__get_different_value(class_value, num_classes)
        second_image = data[second_class, image_indices[1]]
        return first_image, second_image

    def __get_different_value(self, value, max_values):
        """Return a uniformly random integer in ``[0, max_values)`` other than ``value``.

        Raises:
            ValueError: If ``value`` is outside ``[0, max_values)`` or no other
                value exists (``max_values < 2``).
        """
        if not 0 <= value < max_values:
            raise ValueError(f"value {value!r} is outside [0, {max_values})")
        if max_values < 2:
            raise ValueError("need at least two classes to pick a different one")
        # Draw among the max_values - 1 remaining candidates and shift past `value`.
        # This is the same distribution as choosing from range(max_values) with
        # `value` removed, without building that list on every call.
        draw = np.random.randint(max_values - 1)
        return draw if draw < value else draw + 1

    def __get_data_type(self, data_type):
        """Return the image array of the split named ``data_type``.

        Args:
            data_type: ``'train'``, ``'val'`` or ``'test'``.

        Raises:
            ValueError: For any other ``data_type``. Failing loudly keeps a
                typo from silently serving test data.
        """
        if data_type == 'train':
            return self.train_set[0]
        if data_type == 'val':
            return self.val_set[0]
        if data_type == 'test':
            return self.test_set[0]
        raise ValueError(f"unknown data_type {data_type!r}; expected 'train', 'val' or 'test'")