윤영준 윤영준 2023-08-23
fixed inference code
@2ec958df0351f8ea302c1646c925bb0d5d733039
action.py
--- action.py
+++ action.py
@@ -33,8 +33,6 @@
             }, 200
 
 
-
-
 @Action.route('/image_anal')
 class fileUpload(Resource):
     @Action.doc(responses={200: 'Success'})
@@ -48,9 +46,9 @@
             crop_size = (512, 512)
             start_point = (750, 450)
             tf_toTensor = ToTensor()
-            clasifier = Classifier()
-            clasifier.load_state_dict(torch.load("weights/Classifier_512.pt"))
-            clasifier.to(device=device)
+            classifier = Classifier()
+            classifier.load_state_dict(torch.load("weights/Classifier_512.pt"))
+            classifier.to(device=device)
             dir = os.getcwd()
             lat = float(request.json['gps_x'])
             lon = float(request.json['gps_y'])
@@ -61,16 +59,18 @@
             if not image:
                 return {
                     'node': (lat, lon),
-                    'rain': 'rain',
+                    'rain': None,
                 }, 500
             image_tensor = tf_toTensor(image)
-            image_tensor.to(device)
-            image_arnn = AttentiveRNN(image_tensor)
-            result = Classifier(image_arnn)
+            image_tensor = image_tensor.unsqueeze(0)
+            image_tensor = image_tensor.to(device)
+            image_arnn = arnn(image_tensor)
+            result = classifier(image_arnn['x'])
             result = result.to("cpu")
-            if result == 0:
+            _, predicted = torch.max(result.data, 1)
+            if predicted == 0:
                 rain = False
-            else: # elif result == 1
+            else:  # elif result == 1
                 rain = True
             user_id = 'test'
             action_success = True
auth.py
--- auth.py
+++ auth.py
@@ -6,9 +6,6 @@
 import jwt
 
 
-
-
-
 users = {}
 
 Auth = Namespace(
@@ -31,7 +28,6 @@
     'password': fields.String(description='Password', required=True),'email': fields.String(description='email', required=True),'user_sex': fields.String(description='sex', required=True),'phone': fields.String(description='phone', required=True)
 
 })
-
 
 
 @Auth.route('/id')
model/Classifier.py
--- model/Classifier.py
+++ model/Classifier.py
@@ -68,7 +68,7 @@
         )
 
         self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
-        self.fc = nn.Linear(256, classes)  # assuming 10 classes for the classification
+        self.fc = nn.Linear(256, classes)
 
     def forward(self, x):
         x = self.firstconv(x)
 
requirements.txt (added)
+++ requirements.txt
@@ -0,0 +1,11 @@
+torch~=2.0.1+cu118
+flask~=2.3.3
+jwt~=1.3.1
+werkzeug~=2.3.7
+torchvision~=0.15.2+cu118
+networkx~=3.0
+geojson~=3.0.1
+haversine~=2.8.0
+opencv-python~=4.8.0.76
+joblib~=1.3.2
+pandas~=2.0.3(파일 끝에 줄바꿈 문자 없음)
subfuction/image_crop.py
--- subfuction/image_crop.py
+++ subfuction/image_crop.py
@@ -22,16 +22,16 @@
     # run the cropping function in parallel
     Parallel(n_jobs=-1)(delayed(crop_image)(image_path, output_directory, crop_size, start_point) for image_path in image_paths)
 
+if __name__ == "__main__":
+    output_directory = "/home/takensoft/Pictures/test512_512/rainy/"
 
-output_directory = "/home/takensoft/Pictures/test512_512/rainy/"
+    # get all image paths in the directory
+    # image_paths = glob.glob("/home/takensoft/Pictures/화창한날, 비오는날 프레임2000장/화창한날 프레임 추출/하드디스크 화창한날(17개)/**/*.png")
+    # image_paths += glob.glob("/home/takensoft/Pictures/화창한날, 비오는날 프레임2000장/화창한날 프레임 추출/7월19일 화창한날(8개)/**/*.png")
+    image_paths = glob.glob("/home/takensoft/Pictures/화창한날, 비오는날 프레임2000장/비오는날 프레임 추출/7월11일 폭우(3개)/**/*.png")
+    image_paths += glob.glob("/home/takensoft/Pictures/폭우 빗방울 (475개)/*.png")
 
-# get all image paths in the directory
-# image_paths = glob.glob("/home/takensoft/Pictures/화창한날, 비오는날 프레임2000장/화창한날 프레임 추출/하드디스크 화창한날(17개)/**/*.png")
-# image_paths += glob.glob("/home/takensoft/Pictures/화창한날, 비오는날 프레임2000장/화창한날 프레임 추출/7월19일 화창한날(8개)/**/*.png")
-image_paths = glob.glob("/home/takensoft/Pictures/화창한날, 비오는날 프레임2000장/비오는날 프레임 추출/7월11일 폭우(3개)/**/*.png")
-image_paths += glob.glob("/home/takensoft/Pictures/폭우 빗방울 (475개)/*.png")
+    crop_size = (512, 512) # width and height you want for your cropped images
+    start_point = (750, 450) # upper left point where the crop should start
 
-crop_size = (512, 512) # width and height you want for your cropped images
-start_point = (750, 450) # upper left point where the crop should start
-
-crop_images_parallel(image_paths, output_directory, crop_size, start_point)
+    crop_images_parallel(image_paths, output_directory, crop_size, start_point)
Add a comment
List