"""Flask-RESTX endpoints for node image analysis.

Exposes two routes on the ``Action`` namespace:

* ``POST /image_anal``    -- runs an uploaded image through the pre-loaded
  AttentiveRNN + classifier pipeline and records the action in the database.
* ``GET /action_display`` -- returns the recorded actions.

The models are loaded once at import time so requests do not pay the
weight-loading cost.
"""
# Import order is kept as in the original file (one import per line) so the
# load order of native libraries (torch / cv2) does not change.
from flask_restx import Resource, Namespace
from flask import request, jsonify
import os
import json
from database.database import DB
import torch
from torchvision.transforms import ToTensor
from model.AttentiveRNN import AttentiveRNN
from model.Classifier import Resnet as Classifier
from subfuction.image_crop import crop_image
import numpy as np
import cv2

db = DB()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# pre-loading models
# map_location=device lets weights that were saved on a CUDA machine load on a
# CPU-only host (without it torch.load tries to restore onto the saved device).
arnn = AttentiveRNN(6, 3, 2)
arnn.eval()
arnn.load_state_dict(
    torch.load("weights/ARNN_trained_weight_6_3_2.pt", map_location=device)
)
arnn.to(device=device)

classifier = Classifier(in_ch=1)
classifier.eval()
classifier.load_state_dict(
    torch.load(
        "weights/classifier_e19_weight_1080p_512512_fixed_wrong_resolution_and_ch.pt",
        map_location=device,
    )
)
classifier.to(device=device)

tf_toTensor = ToTensor()

# Arguments forwarded to crop_image() for every uploaded image.
crop_size = (512, 512)
start_point = (750, 450)

root_dir = os.getcwd()

Action = Namespace(
    name="Action",
    description="노드 분석을 위해 사용하는 api.",
)


def _parse_coordinates(json_data):
    """Return ``(gps_x, gps_y)`` as floats parsed from a JSON string.

    Raises:
        ValueError: if the text is not valid JSON, is not an object, or the
            ``gps_x`` / ``gps_y`` entries are missing or not numeric.
    """
    try:
        data = json.loads(json_data)
        return float(data['gps_x']), float(data['gps_y'])
    except (KeyError, TypeError, ValueError) as err:
        # json.JSONDecodeError is a ValueError subclass, so it is covered here.
        raise ValueError("Invalid JSON data") from err


def _decode_image(raw_bytes):
    """Decode raw upload bytes into a colour image, or ``None`` on failure."""
    if not raw_bytes:
        # cv2.imdecode raises on an empty buffer; treat it as undecodable.
        return None
    file_bytes = np.frombuffer(raw_bytes, dtype=np.uint8)
    # imdecode returns None when the buffer is not a readable image.
    return cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)


def _predict_rain(image):
    """Run the ARNN + classifier pipeline and return True unless class 0 wins."""
    image_tensor = tf_toTensor(image).unsqueeze(0).to(device)

    # Inference only: no autograd graph is needed for either model.
    with torch.no_grad():
        arnn_out = arnn(image_tensor)
        # The classifier is fed the last entry of the ARNN attention map list.
        logits = classifier(arnn_out['attention_map_list'][-1])

    _, predicted = torch.max(logits.to("cpu"), 1)
    # Batch size is always 1 here, so .item() yields the single class index.
    return predicted.item() != 0


# NOTE(review): both resources below are named ``fileUpload``; the second
# definition rebinds the module-level name. The names are kept because
# Flask-RESTX derives endpoint names from the class name, so renaming would
# change them -- confirm no caller depends on those before renaming.
@Action.route('/image_anal')
class fileUpload(Resource):
    """Analyse an uploaded image and record the action for its GPS position."""

    @Action.doc(responses={
        200: 'Success',
        400: 'Bad Request',
        500: 'Register Failed',
    })
    def post(self):
        """Classify the uploaded image and store the resulting action.

        Expects multipart form data with a ``file`` part (the image) and a
        ``data`` part holding a JSON object with ``gps_x`` and ``gps_y``.

        Returns:
            ``({'node': (lat, lon), 'rain': bool}, 200)`` on success, or
            ``({'message': ...}, 400)`` when the request is malformed.
        """
        uploaded_file = request.files.get('file')
        if not uploaded_file:
            return {"message": "No file uploaded"}, 400

        json_data = request.form.get('data')
        if not json_data:
            return {"message": "Missing JSON data"}, 400

        try:
            lat, lon = _parse_coordinates(json_data)
        except ValueError:
            return {"message": "Invalid JSON data"}, 400

        image = _decode_image(uploaded_file.read())
        if image is None:
            return {"message": "Invalid image file"}, 400

        image = crop_image(image, crop_size, start_point)
        rain = _predict_rain(image)

        # Placeholder identifiers carried over from the original implementation.
        user_id = 'test'
        action_success = True
        action_id = 'test'

        db.db_add_action(action_id, lat, lon, user_id, action_success)

        return {
            'node': (lat, lon),
            'rain': rain,
        }, 200


@Action.route('/action_display')
class fileUpload(Resource):
    """List the actions stored in the database."""

    @Action.doc(responses={200: 'Success', 500: 'Register Failed'})
    def get(self):
        """Return every stored action as ``({'report': [...]}, 200)``."""
        # A fresh DB handle is created per request, as in the original code.
        action_db = DB()
        rows = action_db.db_display_action()
        return {
            'report': list(rows)
        }, 200