forked from yangbisheng2009/nsfw-resnet
-
Notifications
You must be signed in to change notification settings - Fork 1
/
1_predict_image.py
77 lines (63 loc) · 2.62 KB
/
1_predict_image.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from __future__ import print_function
import datetime
import os
import time
import sys
import traceback
import torch
import torch.utils.data
from torch import nn
import torchvision
from torchvision import transforms
from PIL import Image
import numpy as np
def _build_transform():
    """Return the preprocessing pipeline applied to every input image."""
    # Commonly used ImageNet per-channel mean/std; presumably matches the
    # normalisation used when the checkpoint was trained -- confirm.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    return transforms.Compose([
        # Resize straight to 224x224 (aspect ratio is not preserved and no
        # crop is taken).
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])


def _load_model(args):
    """Build the classifier described by *args* and load its checkpoint.

    Returns a tuple ``(model, classes, device)``: the eval-mode
    DataParallel-wrapped model, the ``'classes'`` entry stored in the
    checkpoint, and the primary CUDA device the model was moved to.
    """
    # Read the checkpoint from disk once; keep it on the CPU and let
    # load_state_dict copy the tensors onto the model's GPU parameters.
    checkpoint = torch.load(args.checkpoint, map_location='cpu')
    classes = checkpoint['classes']

    model = torchvision.models.__dict__[args.model](pretrained=False)
    # Swap the final fully-connected layer so its output width equals the
    # number of classes recorded in the checkpoint.
    model.fc = nn.Linear(model.fc.in_features, len(classes))

    device_ids = args.device
    model = nn.DataParallel(model, device_ids=device_ids)
    # DataParallel requires the module's parameters to live on
    # device_ids[0]; a bare .cuda() would use the current device instead
    # and fail for e.g. device_ids=[1]. None falls back to .cuda()'s default.
    device = device_ids[0] if device_ids else None
    model.cuda(device)

    # Loaded after wrapping, so the checkpoint keys presumably carry
    # DataParallel's "module." prefix.
    model.load_state_dict(checkpoint['model'])
    model.eval()
    return model, classes, device


def main(args):
    """Classify every entry of ``args.test_path`` and print one line each.

    ``args`` must provide:
      * ``test_path`` -- directory whose entries are classified one by one;
      * ``model`` -- name of a ``torchvision.models`` constructor whose model
        has an ``fc`` layer;
      * ``device`` -- CUDA device ids passed to ``nn.DataParallel``;
      * ``checkpoint`` -- path to a file holding ``'classes'`` and
        ``'model'`` (state dict) entries.

    For each entry classified successfully, a tab-separated line
    ``name, class label, class index`` is printed to stdout. If an entry
    fails (e.g. it is not a readable image), its name is printed and the
    traceback goes to stderr. Processing then continues with the next entry.
    """
    transformation = _build_transform()
    model, classes, device = _load_model(args)

    for image in os.listdir(args.test_path):
        try:
            image_path = os.path.join(args.test_path, image)
            # Use a context manager so the file handle is released promptly.
            # convert('RGB') gives the 3 channels the normalisation expects.
            with Image.open(image_path) as img:
                image_tensor = transformation(img.convert('RGB')).float()
            # Add a leading batch dimension of size 1.
            batch = image_tensor.unsqueeze(0).cuda(device)
            # Inference only: don't record operations for autograd.
            with torch.no_grad():
                output = model(batch)
            index = output.argmax(dim=1).item()
            print('{}\t{}\t{}'.format(image, classes[index], index))
            sys.stdout.flush()
        except Exception:
            # Best-effort per file: report and move on. Unlike the previous
            # bare `except:`, Ctrl-C / SystemExit still stop the run.
            print(image)
            traceback.print_exc()
if __name__ == "__main__":
    # Command-line entry point: parse options and run prediction.
    import argparse
    parser = argparse.ArgumentParser(
        description='PyTorch Classification Prediction')
    parser.add_argument('--test-path', default='./data/beauty', help='dataset')
    parser.add_argument('--model', default='resnet101', help='model')
    # type=int + nargs='+' so "--device 0 1" parses to [0, 1]. Without them,
    # any value given on the command line arrived as a raw string, which
    # DataParallel would iterate character by character.
    parser.add_argument('--device', default=[0], type=int, nargs='+',
                        help='device')
    # NOTE: --batch-size and --workers are not read by main(). They are kept
    # so existing invocations that pass them continue to parse.
    parser.add_argument('-b', '--batch-size', default=32, type=int)
    parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
                        help='number of data loading workers (default: 16)')
    parser.add_argument('--checkpoint', default='./checkpoints/model_2_600.pth',
                        help='checkpoint')
    args = parser.parse_args()
    print(args)
    main(args)