Fixed a critical mistake: misaligned image color channels (BGR -> RGB) on model input
@f19e8cf81a43a0854454868e577f7e4ec9722d99
--- action.py
+++ action.py
... | ... | @@ -62,6 +62,8 @@ |
62 | 62 |
|
63 | 63 |
image = crop_image(image, crop_size, start_point) |
64 | 64 |
|
65 |
+ image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB) |
|
66 |
+ |
|
65 | 67 |
image_tensor = tf_toTensor(image) |
66 | 68 |
image_tensor = image_tensor.unsqueeze(0) |
67 | 69 |
image_tensor = image_tensor.to(device) |
--- demonstration.py
+++ demonstration.py
... | ... | @@ -28,8 +28,8 @@ |
28 | 28 |
classifier.load_state_dict(torch.load("weights/classifier_e19_weight_1080p_512512_fixed_wrong_resolution_and_ch.pt")) |
29 | 29 |
classifier.to(device=device) |
30 | 30 |
classifier.eval() |
31 |
- rainy_data_path = glob.glob("/home/takensoft/Pictures/화창한날, 비오는날 프레임2000장/SUNNY/**/**/*.png") |
|
32 |
- # rainy_data_path = glob.glob("/home/takensoft/Pictures/폭우 빗방울 (475개)/*.png") |
|
31 |
+ # rainy_data_path = glob.glob("/home/takensoft/Pictures/화창한날, 비오는날 프레임2000장/SUNNY/**/**/*.png") |
|
32 |
+ rainy_data_path = glob.glob("/home/takensoft/Pictures/폭우 빗방울 (475개)/*.png") |
|
33 | 33 |
img_path = rainy_data_path |
34 | 34 |
# clean_data_path = glob.glob("/home/takensoft/Documents/AttentiveRNNClassifier/output/original/*.png") |
35 | 35 |
|
... | ... | @@ -46,7 +46,8 @@ |
46 | 46 |
image_tensor = tf_toTensor(image) |
47 | 47 |
image_tensor = image_tensor.unsqueeze(0) |
48 | 48 |
image_tensor = image_tensor.to(device) |
49 |
- image_arnn = arnn(image_tensor) |
|
49 |
+ with torch.no_grad(): |
|
50 |
+ image_arnn = arnn(image_tensor) |
|
50 | 51 |
|
51 | 52 |
input_win = 'input_window' |
52 | 53 |
attention_map_wins = [f'attention_map_{i}' for i in range(6)] |
... | ... | @@ -66,7 +67,8 @@ |
66 | 67 |
win=attention_map_wins[idx] |
67 | 68 |
) |
68 | 69 |
# arnn_result = normalize(image_arnn['x']) |
69 |
- result = classifier(image_arnn['attention_map_list'][-1]) |
|
70 |
+ with torch.no_grad(): |
|
71 |
+ result = classifier(image_arnn['attention_map_list'][-1]) |
|
70 | 72 |
result = result.to("cpu") |
71 | 73 |
_, predicted = torch.max(result.data, 1) |
72 | 74 |
print(result.data) |
+++ demonstration_nogui.py
... | ... | @@ -0,0 +1,114 @@ |
1 | +import random | |
2 | +import time | |
3 | +import visdom | |
4 | +import glob | |
5 | +import torch | |
6 | +import cv2 | |
7 | +from torchvision.transforms import ToTensor, Compose, Normalize | |
8 | +from flask import request | |
9 | +from model.AttentiveRNN import AttentiveRNN | |
10 | +from model.Classifier import Resnet as Classifier | |
11 | +from subfuction.image_crop import crop_image | |
12 | + | |
13 | +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
14 | + | |
15 | +# execute visdom instance first | |
16 | +# to do that, install visdom via pip and execute in terminal | |
17 | + | |
def process_image():
    """Classify cropped frames from a folder of images as clear vs. rainy.

    Loads the AttentiveRNN attention network and a ResNet classifier from
    local weight files, then for every image matched by the glob below:
    crops a fixed 512x512 region, converts BGR->RGB (OpenCV decodes BGR,
    the models were trained on RGB input), runs the attention network,
    feeds its last attention map to the classifier, and prints the raw
    logits, the max logit, and the predicted class index.

    Side effects only (console output); returns None.
    """
    arnn = AttentiveRNN(6, 3, 2)
    # map_location keeps weight loading working on CPU-only machines even
    # when the checkpoint was saved from a CUDA device.
    arnn.load_state_dict(
        torch.load("weights/ARNN_trained_weight_6_3_2.pt", map_location=device)
    )
    arnn.to(device=device)
    arnn.eval()

    crop_size = (512, 512)
    start_point = (750, 450)
    tf_toTensor = ToTensor()

    classifier = Classifier(in_ch=1)
    classifier.load_state_dict(
        torch.load(
            "weights/classifier_e19_weight_1080p_512512_fixed_wrong_resolution_and_ch.pt",
            map_location=device,
        )
    )
    classifier.to(device=device)
    classifier.eval()

    # NOTE(review): dataset path is hard-coded; folder name "정상" means
    # "normal" — presumably the clear-weather test set. Confirm with owner.
    img_path = glob.glob("/home/takensoft/Desktop/KOLAS_TEST/정상/*.png")
    random.shuffle(img_path)

    index = []  # accumulated per-image predicted class labels

    for path in img_path:
        ori_img = cv2.imread(path)
        image = crop_image(ori_img, crop_size, start_point)
        # Skip frames whose crop came back all zeros (e.g. black frame or
        # out-of-bounds crop region).
        if not image.any():
            continue
        # OpenCV loads images as BGR; the networks expect RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image_tensor = tf_toTensor(image).unsqueeze(0).to(device)
        # Inference only: disable autograd to cut memory use and latency.
        with torch.no_grad():
            image_arnn = arnn(image_tensor)
            result = classifier(image_arnn['attention_map_list'][-1])
        result = result.to("cpu")
        confidence, predicted = torch.max(result.data, 1)
        print(result.data)
        print(confidence)
        print(predicted)
        # .tolist() stores plain ints rather than 0-d tensor elements.
        index += predicted.tolist()
    print(index)


if __name__ == "__main__":
    process_image()
Add a comment
Delete comment
Once you delete this comment, you won't be able to recover it. Are you sure you want to delete this comment?