-
Notifications
You must be signed in to change notification settings - Fork 0
/
params.py
129 lines (97 loc) · 2.8 KB
/
params.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import typing
import yaml
# -- DVC params --------------------------------------------------------------
# Module-wide settings that do not belong to a single parameter class.
# NOTE(review): this section is labelled "DVC params"; the values are
# presumably read statically, so keep them as plain literals -- TODO confirm.

# Boolean switch; how it is consumed is not visible from this file.
MULTIPROCESSING = True

# Batch size as a plain integer.
BATCH_SIZE = 32
class LocalizeParams:
    """Image settings of the "localize" parameter family.

    Used purely as a namespace for constants; it is not meant to be
    instantiated.
    """

    # Only square images are supported, hence the two equal trailing entries.
    # NOTE(review): presumably (channels, height, width) -- TODO confirm.
    IMG_SHAPE = (1, 84, 84)
class PrepareLocalizeParams:
    """Constants of the "prepare localize" parameter group.

    Namespace class only; the values are plain literals.
    """

    # Seed value; the same number is used by the other parameter groups.
    SEED = 25032023

    # Upper bound on the number of images.
    MAX_IMAGES = 100_000

    # Fraction assigned to the training split.
    # NOTE(review): presumably the remainder forms the validation/test
    # split -- TODO confirm against the consumer.
    TRAIN_SPLIT = 0.8
class LocalizeModelParams:
    """Architecture constants of the "localize" model parameter group.

    Namespace class only; the values are plain literals.
    """

    # Name string of the model.
    MODEL_NAME = "plate_localizer_v3"

    # Convolution stack description mixing integers and the marker "M".
    # NOTE(review): looks like a VGG-style config where integers are channel
    # counts and "M" is a max-pool step -- TODO confirm against the model code.
    CONV_LAYERS = (64, "M", 128, "M", 256, "M", 512, "M")

    # Two-dimensional size used for average pooling.
    AVGPOOL_SIZE = (5, 5)

    # Sizes of the hidden layers, one entry per layer.
    HIDDEN_LAYERS = (512, 512)

    # Dropout value.
    DROPOUT = 0.4
class TrainLocalizeParams:
    """Constants of the "train localize" parameter group.

    Namespace class only; the values are plain literals.
    """

    # Seed value; the same number is used by the other parameter groups.
    SEED = 25032023

    # Learning rate.
    LR = 0.0005

    # Number of epochs.
    EPOCHS = 3
class EvaluateLocalizeParams:
    """Constants of the "evaluate localize" parameter group.

    Namespace class only; it currently holds a single seed value.
    """

    # Seed value; the same number is used by the other parameter groups.
    SEED = 25032023
class OCRParams:
    """Image and label constants of the "OCR" parameter family.

    Namespace class only; the values are plain literals.
    """

    # Only square images are supported, hence the two equal trailing entries.
    # NOTE(review): presumably (channels, height, width) -- TODO confirm.
    IMG_SHAPE = (1, 84, 84)

    # Maximum length of a label.
    MAX_LABEL_LENGTH = 7

    # Class count: 26 cantons + 10 digits + 1 blank symbol = 37.
    GRU_NUM_CLASSES = 37

    # Class reserved for the blank symbol; equal to GRU_NUM_CLASSES - 1.
    # NOTE(review): presumably a zero-based class index -- TODO confirm.
    GRU_BLANK_CLASS = 36

    # The number of Swiss cantons could be derived as
    #     RNN_NUM_CLASSES = len(Canton)
    # but that form is intentionally left out: it is not for dvc.yaml.
class PrepareOCRParams:
    """Constants of the "prepare OCR" parameter group.

    Namespace class only; the values are plain literals.
    """

    # Seed value; the same number is used by the other parameter groups.
    SEED = 25032023

    # Upper bound on the number of images.
    MAX_IMAGES = 250_000

    # Fraction assigned to the training split.
    TRAIN_SPLIT = 0.8

    # Size of the image before cropping.
    IMG_SIZE = (256, 256)
class OCRModelParams:
    """Architecture constants of the "OCR" model parameter group.

    Namespace class only; the values are plain literals.
    """

    # Name string of the model.
    MODEL_NAME = "plate_ocr_v2"

    # Convolution stack description mixing integers and the marker "M".
    # NOTE(review): looks like a VGG-style config where integers are channel
    # counts and "M" is a max-pool step -- TODO confirm against the model code.
    CONV_LAYERS = (64, "M", 128, "M", 256, "M", 256, "M")

    # Two-dimensional size used for average pooling ahead of the GRU part.
    GRU_AVGPOOL_SIZE = (18, 18)

    # Hidden size of the GRU.
    GRU_HIDDEN_SIZE = 384

    # Number of GRU layers.
    GRU_NUM_LAYERS = 2

    # Dropout value of the GRU.
    GRU_DROPOUT = 0.5
class TrainOCRParams:
    """Constants of the "train OCR" parameter group.

    Namespace class only; the values are plain literals.
    """

    # Seed value; the same number is used by the other parameter groups.
    SEED = 25032023

    # Learning rate.
    LR = 0.00025

    # Number of epochs.
    EPOCHS = 2
class EvaluateOCRParams:
    """Constants of the "evaluate OCR" parameter group.

    Namespace class only; it currently holds a single seed value.
    """

    # Seed value; the same number is used by the other parameter groups.
    SEED = 25032023
class PrepareStackParams:
    """Constants of the "prepare stack" parameter group.

    Namespace class only; the values are plain literals.
    """

    # Seed value; the same number is used by the other parameter groups.
    SEED = 25032023

    # Upper bound on the number of images.
    MAX_IMAGES = 5_000
class EvaluateStackParams:
    """Constants of the "evaluate stack" parameter group.

    Namespace class only; it currently holds a single seed value.
    """

    # Seed value; the same number is used by the other parameter groups.
    SEED = 25032023
# -- Python specific --------------------------------------------------------
# The following are typing extensions to make catching errors easier
# Closed set of key names allowed in the mapping loaded from params.yaml.
# The sync check at the bottom of this file inspects these entries, so keep
# them aligned (names and listing order) with the keys of params.yaml.
GlobParamsKeysType = typing.Literal[
    # base "*_path" entries
    "src_path",
    "out_path",
    "dataset_path",
    "prepared_data_path",
    # "out_*_folder" entries
    "out_prepared_folder",
    "out_log_folder",
    "out_checkpoint_folder",
    "out_save_folder",
    "out_evaluation_folder",
    # "*_model_folder" entries
    "localize_model_folder",
    "ocr_model_folder",
    "stack_model_folder",
    # "prepared_data_*_path" entries
    "prepared_data_localize_path",
    "prepared_data_ocr_path",
    "prepared_data_stack_path",
    # "out_prepared_*_path" entries
    "out_prepared_localize_path",
    "out_prepared_ocr_path",
    # "out_log_*_path" entries
    "out_log_localize_path",
    "out_log_ocr_path",
    # "out_save_*_path" entries
    "out_save_localize_path",
    "out_save_ocr_path",
    # "out_checkpoints_*_path" entries
    "out_checkpoints_localize_path",
    "out_checkpoints_ocr_path",
    # "out_evaluation_*_path" entries
    "out_evaluation_localize_path",
    "out_evaluation_ocr_path",
    "out_evaluation_stack_path",
]
# Mapping type of the loaded params.yaml content: one of the key names
# declared in GlobParamsKeysType mapped to a string value.
GlobParamsType = typing.Dict[GlobParamsKeysType, str]
# Load the global settings from params.yaml. The path is relative, so it is
# resolved against the current working directory at import time.
# The `with` block closes the file handle as soon as parsing is done (a bare
# `open()` call never closes it explicitly), and the explicit encoding keeps
# decoding independent of the platform's locale default.
with open("params.yaml", encoding="utf-8") as _params_file:
    glob_params: GlobParamsType = yaml.safe_load(_params_file)

# Verify that the params.yaml and params.py are in sync
err_msg = "params.yaml keys does not match with params.py keys"

# Key names declared in params.py: the members of the Literal that is used as
# the key type of GlobParamsType.
_declared_keys = set(typing.get_args(typing.get_args(GlobParamsType)[0]))

# An empty params.yaml is parsed as None, and a non-mapping document as a
# list or scalar; both are treated as "no keys" so that they are reported as
# a key mismatch instead of failing with an unrelated AttributeError.
_loaded_keys = set(glob_params) if isinstance(glob_params, dict) else set()

# The comparison ignores key order, because the order of keys in a YAML
# mapping carries no meaning. The error is raised explicitly rather than via
# an `assert` statement so the check still runs under `python -O`; the
# exception type stays AssertionError for backward compatibility.
if not isinstance(glob_params, dict) or _loaded_keys != _declared_keys:
    _missing = sorted(str(key) for key in _declared_keys - _loaded_keys)
    _unexpected = sorted(str(key) for key in _loaded_keys - _declared_keys)
    raise AssertionError(
        f"{err_msg} (missing: {_missing}, unexpected: {_unexpected})"
    )