-
Notifications
You must be signed in to change notification settings - Fork 0
/
bilstmTrain.py
209 lines (174 loc) · 7.02 KB
/
bilstmTrain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import numpy as np
from torch import optim
from part3_parser import Parser
from utils import *
import matplotlib.pyplot as plt
import collections
def save_data_to_file(data_name, epochs, loss, acu, choice, with_pretrain=False):
    """Append a run summary (hyper-parameters plus accuracy/loss lists) to
    ``<data_name>_model_result.txt`` under ``ARTIFACTS_PATH``.

    Reads the module-level hyper-parameters ``batch_size``, ``lr``,
    ``embedding_len`` and ``lstm_h_dim``.

    :param data_name: dataset name used to build the output file name
    :param epochs: number of epochs the run trained for
    :param loss: per-epoch dev loss list (stringified into the file)
    :param acu: per-epoch dev accuracy list (stringified into the file)
    :param choice: embedding-representation choice ('a'/'b'/'c'/'d')
    :param with_pretrain: recorded verbatim in the summary line
    """
    file_dir = os.path.join(ARTIFACTS_PATH, "{0}_model_result.txt".format(data_name))
    # Append mode: results of successive runs accumulate in one file.
    with open(file_dir, "a") as output:
        output.write(
            "Parameters - Choice \'{0}\' Batch size: {1}, epochs: {2}, lr: {3}, embedding length: {4}, lstm hidden dim: {5}\n".format(
                choice, batch_size, epochs, lr, embedding_len, lstm_h_dim))
        output.write(
            "With pre train: {0}, Epochs: {1}\nAccuracy: {2}\nLoss: {3}\n".format(str(with_pretrain), epochs, str(acu),
                                                                                  str(loss)))
    # BUG FIX: removed redundant output.close() - the with-statement already
    # closes the file on exit.
def plot_graphs(dev_acc_list, dev_loss_list, epochs, name):
    """Show two figures: dev accuracy per epoch and dev loss per epoch.

    Roughly every tenth point is annotated with its value so the curves can
    be read off the plot.

    :param dev_acc_list: accuracy value per epoch (len == epochs)
    :param dev_loss_list: loss value per epoch (len == epochs)
    :param epochs: number of epochs plotted on the x axis
    :param name: title prefix (e.g. dataset name)
    """
    # Annotate ~10 points regardless of the number of epochs; at least every point.
    ticks = max(1, int(epochs / 10))

    # Fresh figure so the accuracy plot does not inherit state from earlier plots.
    plt.figure()
    plt.plot(range(epochs), dev_acc_list)
    plt.xticks(np.arange(0, epochs, step=1))
    plt.yticks(np.arange(0, 110, step=10))
    plt.xlabel('epochs')
    plt.ylabel('accuracy')
    plt.title('{} accuracy'.format(name))
    for i in range(0, len(dev_acc_list), ticks):
        # BUG FIX: the original passed "" as the annotation text, which drew
        # nothing at all; annotate with the actual value instead.
        plt.annotate("{:.2f}".format(dev_acc_list[i]), (i, dev_acc_list[i]))
    plt.show()

    plt.figure()
    plt.plot(range(epochs), dev_loss_list)
    plt.xticks(np.arange(0, epochs, step=1))
    plt.yticks(np.arange(0, 4, step=0.5))
    plt.xlabel('epochs')
    plt.ylabel('loss')
    plt.title('{} loss'.format(name))
    for i in range(0, len(dev_loss_list), ticks):
        plt.annotate("{:.2f}".format(dev_loss_list[i]), (i, dev_loss_list[i]))
    plt.show()
def iterate_model(model, train_data_loader, optimizer, criterion, epoch):
percentages_show = 5
limit_to_print = round(len(train_data_loader) * (percentages_show / 100))
limit_to_print = max(1, limit_to_print)
for index, batch in enumerate(train_data_loader):
sentences, tags = batch
sentences = sentences
tags = tags
optimizer.zero_grad()
output = model(sentences)
loss = criterion(output, tags)
loss.backward()
optimizer.step()
# Information printing:
if index % limit_to_print == 0 and index != 0:
percentages = (index / len(train_data_loader)) * 100
print("Train | Epoch: {0} | {1:.2f}% sentences finished".format(epoch + 1, percentages))
print('\n------ Train | Finished epoch {0} ------\n'.format(epoch + 1))
def save_500_acc_dev_to_file(data_name, choice):
    """Append the accumulated per-batch dev accuracies (module-level
    ``dev_500_acc``) to ``<data_name>_dev_500_acc_result.txt`` under
    ``ARTIFACTS_PATH``.

    :param data_name: dataset name used to build the output file name
    :param choice: embedding-representation choice, recorded in the header
    """
    file_dir = os.path.join(ARTIFACTS_PATH, "{0}_dev_500_acc_result.txt".format(data_name))
    with open(file_dir, "a") as output:
        output.write("Data name: {0}, Choice: '{1}'\nDev list: {2}\n\n".format(data_name, choice, str(dev_500_acc)))
    # BUG FIX: removed redundant output.close() - the with-statement closes the file.
def train(model, train_data_loader, dev_data_loader, criterion, optimizer, epochs, data_name):
    """Training driver: run ``epochs`` epochs over the train loader and,
    when a dev loader is supplied, evaluate after each epoch and persist
    the collected dev metrics to the result files.

    :param dev_data_loader: dev batches, or a falsy value to skip evaluation
    """
    dev_acc_list, dev_loss_list = [], []
    for epoch in range(epochs):
        # One full pass over the training data.
        iterate_model(model, train_data_loader, optimizer, criterion, epoch)
        if not dev_data_loader:
            continue  # no dev set configured - skip evaluation
        acc, loss = evaluate_accuracy(model, dev_data_loader, criterion, data_name, epoch)
        dev_acc_list.append(acc)
        dev_loss_list.append(loss)
    if dev_data_loader:
        print("\n\nTotal Accuracy: " + str(dev_acc_list))
        print("\nTotal Loss: " + str(dev_loss_list))
        save_data_to_file(data_name, epochs, dev_loss_list, dev_acc_list, model.choice, with_pretrain=False)
        save_500_acc_dev_to_file(data_name, model.choice)
def calculate_accuracy(y_hats, tags, data_name):
    """Token-level accuracy over a flattened batch.

    PAD positions are excluded; for NER, positions where both gold tag and
    prediction are 'O' are also excluded (otherwise the dominant 'O' class
    inflates the score).

    :param y_hats: predicted label indices (any shape; flattened here)
    :param tags: gold label indices (same shape as y_hats)
    :param data_name: "ner" enables the correct-'O' exclusion
    :return: float in [0, 1]; 0.0 when no token was counted
    """
    good = 0
    bad = 0
    y_hats = y_hats.view(-1)
    tags = tags.view(-1)
    for i in range(len(tags)):
        if tags[i] == L2I[PAD]:
            continue
        if data_name == "ner" and tags[i] == y_hats[i] == L2I['O']:
            continue
        if tags[i] == y_hats[i]:
            good += 1
        else:
            bad += 1
    total = good + bad
    # BUG FIX: the original divided by (good + bad) unconditionally and raised
    # ZeroDivisionError when every token was skipped (e.g. an all-PAD batch,
    # or an NER batch where everything is a correctly-predicted 'O').
    return good / total if total else 0.0
def evaluate_accuracy(model, dev_dataset_loader, criterion, data_name, epoch):
    """Evaluate ``model`` on the dev set.

    Also appends each batch's accuracy to the module-level ``dev_500_acc``
    list and prints progress roughly every 5% of the batches.

    :param dev_dataset_loader: iterable of (sentences, tags) batches
    :param epoch: zero-based epoch index (printed as epoch + 1)
    :return: (accuracy in percent, mean loss) averaged over batches;
             (0.0, 0.0) for an empty loader
    """
    percentages_show = 5
    n_batches = len(dev_dataset_loader)  # hoisted: loop-invariant
    limit_to_print = max(1, round(n_batches * (percentages_show / 100)))
    counter = 0
    avg_acc = 0
    avg_loss = 0
    # BUG FIX: evaluation previously ran with autograd enabled, building a
    # computation graph per batch for no reason. no_grad() saves memory and
    # changes no computed value. Also removed the dead `sentences = sentences`
    # / `tags = tags` no-ops.
    with torch.no_grad():
        for index, (sentences, tags) in enumerate(dev_dataset_loader):
            counter += 1
            y_scores = model(sentences)
            y_hats = torch.argmax(y_scores, dim=1)
            loss = criterion(y_scores, tags)
            current_accuracy = calculate_accuracy(y_hats, tags, data_name)
            avg_acc += current_accuracy
            dev_500_acc.append(current_accuracy)
            avg_loss += float(loss)
            # Progress printing every ~5% of batches (skip batch 0).
            if index % limit_to_print == 0 and index != 0:
                percentages = (index / n_batches) * 100
                print("Dev | Epoch: {0} | {1:.2f}% sentences finished".format(epoch + 1, percentages))
    print('\n------ Dev | Finished epoch {0} ------\n'.format(epoch + 1))
    # ROBUSTNESS: an empty dev loader would have divided by zero.
    if counter == 0:
        return 0.0, 0.0
    # Averages over the whole dev set.
    acc = (avg_acc / counter) * 100
    loss = avg_loss / counter
    print('***********************************************************************************************')
    print('\nEmbed choice: \'{0}\' Data name:{1} Epoch:{2}, Acc:{3}, Loss:{4}\n'.format(model.choice, data_name,
                                                                                        epoch + 1,
                                                                                        acc, loss))
    print('***********************************************************************************************')
    return acc, loss
# Hyper parameters (an earlier configuration kept for reference):
# batch_size = 200
# epochs = 50
# lr = 0.001
# embedding_length = 150
# lstm_h_dim = 200
dev_500_acc = []  # per-batch dev accuracies, accumulated across the whole run by evaluate_accuracy
batch_size = 500  # sentences per batch
epochs = 1
lr = 0.005  # Adam learning rate
embedding_len = 300  # word-embedding dimension
char_embedding_len = 150  # character-embedding dimension (passed to BILSTMNet; exact use depends on `choice` - see model)
lstm_h_dim = 200  # LSTM hidden-state dimension
choice = 'a'  # default representation choice; overwritten from argv in __main__
save_model = True
load_model = False  # if True, __main__ reloads the saved dataset and model weights instead of parsing/training fresh
to_replace_rare_words = False  # gates a rare-word replacement step whose call site is commented out in __main__
check_on_dev = False  # if True, __main__ builds a dev loader and evaluates after each epoch
def replace_rare_words(data):
    # NOTE(review): incomplete implementation. The word-frequency counter and
    # the 10 rarest words are computed but never used; the function returns
    # None without modifying `data`. Its only call site in __main__ is
    # commented out (behind to_replace_rare_words). Finish or remove.
    counter = collections.Counter([tup[0] for sequence in data for tup in sequence])
    tup = counter.most_common()[-10:]
if __name__ == "__main__":
    # Usage: bilstmTrain.py <a/b/c/d> <train_file_path> <saved_model_path> <train_dataset_save_dir> <pos/ner>
    if len(sys.argv) != 6:
        # BUG FIX: len(sys.argv) == 6 means FIVE user arguments, but the
        # original message said "must get 4 parameters".
        raise ValueError("must get 5 parameters, Please run command: "
                         "'bilstmTrain.py <a/b/c/d> <train_file_path> <saved_model_path> <train_dataset_save_dir> <pos/ner>'")
    _, choice, train_file_path, model_file_path, train_dataset_save_dir, data_name = sys.argv

    # Data: either reload a previously saved dataset or parse the raw train file.
    dataTrain = load_dataset(TRAIN_DATASET_DIR) if load_model else Parser("train", data_name,
                                                                          dataset_path=train_file_path)
    dataDev = Parser("dev", data_name) if check_on_dev else ""
    dicts = Dictionaries(dataTrain)
    F2I, L2I = dicts.F2I, dicts.L2I  # feature->index and label->index maps
    #
    # if to_replace_rare_words:
    #     replace_rare_words(dataTrain.data)
    train_loader = make_loader(dataTrain.data, F2I, L2I, batch_size)
    dev_loader = make_loader(dataDev.data, F2I, L2I, batch_size) if check_on_dev else ""
    vocab_size = len(F2I)
    output_dim = len(L2I)

    # Model
    model = BILSTMNet(vocab_size, embedding_len, lstm_h_dim, output_dim, dicts, char_embedding_len, batch_size, choice)
    if load_model:
        model.load(MODEL_DIR)
    # NOTE(review): ignore_index uses F2I[PAD] (word-vocab index) while the
    # criterion's targets are label indices from L2I - this only works if PAD
    # maps to the same index in both dicts; confirm against Dictionaries.
    criterion = nn.CrossEntropyLoss(ignore_index=F2I[PAD])
    optimizer = optim.Adam(model.parameters(), lr)

    # Train, then persist the model and the parsed dataset.
    train(model, train_loader, dev_loader, criterion, optimizer, epochs, dataTrain.data_name)
    save_model_and_data_sets(model, dataTrain, model_file_path, train_dataset_save_dir)