윤영준 윤영준 2023-11-08
fixed a critical mistake, misaligned neural network
@0106b7baecee1d51788a0c71bc312453e2ce828d
action.py
--- action.py
+++ action.py
@@ -20,9 +20,9 @@
 arnn.load_state_dict(torch.load("weights/ARNN_trained_weight_6_3_2.pt"))
 arnn.to(device=device)
 
-classifier = Classifier()
+classifier = Classifier(in_ch=1)
 classifier.eval()
-classifier.load_state_dict(torch.load("weights/Classifier_512.pt"))
+classifier.load_state_dict(torch.load("weights/classifier_e19_weight_1080p_512512_fixed_wrong_resolution_and_ch.pt"))
 classifier.to(device=device)
 
 tf_toTensor = ToTensor()
@@ -69,7 +69,7 @@
             image_arnn = arnn(image_tensor)
             image_tensor.cpu()
             del image_tensor
-            result = classifier(image_arnn['x'])
+            result = classifier(image_arnn['attention_map_list'][-1])
             image_arnn['x'].cpu()
             del image_arnn
 
app.py
--- app.py
+++ app.py
@@ -18,5 +18,5 @@
 api.add_namespace(Action, '/action')
 
 if __name__ == "__main__":
-    app.run(debug=False, host='0.0.0.0', port=8000)
+    app.run(debug=False, host='0.0.0.0', port=7700)
     print("Flask Start")
(파일 끝에 줄바꿈 문자 없음)
demonstration.py
--- demonstration.py
+++ demonstration.py
@@ -38,7 +38,8 @@
     random.shuffle(img_path)
 
     for i in iter(range(len(img_path))):
-        image = crop_image(img_path[i], crop_size, start_point)
+        ori_img = cv2.imread(img_path[i])
+        image = crop_image(ori_img, crop_size, start_point)
         if not image.any():
             continue
         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
model/AttentiveRNN.py
--- model/AttentiveRNN.py
+++ model/AttentiveRNN.py
@@ -217,8 +217,8 @@
             lstm_feats.append(lstm_feats_i)
         ret = {
             'x' : x,
-            # 'attention_map_list' : attention_map,
-            # 'lstm_feats' : lstm_feats
+            'attention_map_list' : attention_map,
+            'lstm_feats' : lstm_feats
         }
         return ret
 
Add a comment
List