"""REST endpoints of the ``Action`` namespace.

Exposes three resources:

* ``/image_summit``   - store an uploaded image in the working directory.
* ``/image_anal``     - run the rain classifier on a stored image and record
  the resulting action in the database.
* ``/action_display`` - list the actions known to the database for "now".
"""
import functools
import os
from datetime import datetime

import torch
from flask import request
from flask_restx import Resource, Namespace
from torchvision.transforms import ToTensor
from werkzeug.utils import secure_filename

from database.database import DB
from model.AttentiveRNN import AttentiveRNN
from model.Classifier import Resnet as Classifier
from subfuction.image_crop import crop_image

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
paths = os.getcwd()

# Region of the stored image that is handed to the networks.
CROP_SIZE = (512, 512)
CROP_START = (750, 450)

Action = Namespace(
    name="Action",
    description="노드 분석을 위해 사용하는 api.",
)


@functools.lru_cache(maxsize=1)
def _load_models():
    """Build both networks and load their weights exactly once per process.

    Loading the checkpoints from disk on every request is expensive, so the
    result is cached. ``map_location`` lets checkpoints saved on a GPU be
    loaded on a CPU-only host.

    Returns:
        A ``(arnn, classifier)`` tuple, both moved to ``device`` and switched
        to evaluation mode.
    """
    arnn = AttentiveRNN(6, 3, 2)
    arnn.load_state_dict(
        torch.load("weights/ARNN_trained_weight_6_3_2.pt", map_location=device)
    )
    arnn.to(device=device)
    # Inference must run in eval mode so dropout / batch-norm layers use
    # their trained statistics.
    # NOTE(review): assumes neither forward() changes its return structure
    # depending on ``self.training`` - confirm against the model code.
    arnn.eval()

    classifier = Classifier()
    classifier.load_state_dict(
        torch.load("weights/Classifier_512.pt", map_location=device)
    )
    classifier.to(device=device)
    classifier.eval()
    return arnn, classifier


def _predict_rain(image):
    """Return ``True`` when the classifier predicts any class other than 0.

    Args:
        image: The cropped image as returned by ``crop_image``; it is
            converted with torchvision's ``ToTensor``.
    """
    arnn, classifier = _load_models()
    batch = ToTensor()(image).unsqueeze(0).to(device)
    # No gradients are needed for inference; skipping autograd saves memory.
    with torch.no_grad():
        features = arnn(batch)
        scores = classifier(features['x']).to("cpu")
        _, predicted = torch.max(scores, 1)
    # Class 0 means "no rain"; every other class is reported as rain.
    return predicted.item() != 0


# NOTE: all three resources below share the class name ``fileUpload``.
# Flask-RESTX derives endpoint names from the class name, so the names are
# kept unchanged for backward compatibility.
@Action.route('/image_summit')
class fileUpload(Resource):
    @Action.doc(responses={
        200: 'Success',
        400: 'Invalid filename',
        500: 'Register Failed',
    })
    def post(self):
        """Save the uploaded ``file`` part into the working directory."""
        upload = request.files['file']
        safe_name = secure_filename(upload.filename or '')
        if not safe_name:
            # secure_filename() strips everything unsafe and may leave an
            # empty string, which cannot be used as a target path.
            return {'message': 'invalid filename'}, 400
        upload.save(safe_name)
        return {
            'save': 'done'  # returned as a string
        }, 200


@Action.route('/image_anal')
class fileUpload(Resource):
    @Action.doc(responses={
        200: 'Success',
        400: 'Invalid request body',
        500: 'Register Failed',
    })
    def post(self):
        """Classify a previously stored image and record the action.

        Expects a JSON body with ``gps_x``, ``gps_y``, ``filename`` and
        ``file_type``. Responds with the node coordinates and the ``rain``
        flag; ``rain`` is ``None`` with status 500 when the image could not
        be cropped.
        """
        payload = request.json
        try:
            lat = float(payload['gps_x'])
            lon = float(payload['gps_y'])
            requested_name = payload['filename'] + payload['file_type']
        except (KeyError, TypeError, ValueError):
            return {
                'message': 'gps_x, gps_y, filename and file_type are required'
            }, 400

        # The name comes from the client: sanitise it the same way the upload
        # endpoint does so it cannot escape the working directory.
        safe_name = secure_filename(requested_name)
        if not safe_name:
            return {'message': 'invalid filename'}, 400
        total_path = os.path.join(os.getcwd(), safe_name)

        image = crop_image(total_path, CROP_SIZE, CROP_START)
        # NOTE(review): presumably crop_image returns a falsy value when the
        # file cannot be read or cropped - verify against the helper.
        if not image:
            return {
                'node': (lat, lon),
                'rain': None,
            }, 500

        rain = _predict_rain(image)

        # Placeholder identifiers, kept exactly as in the original code.
        user_id = 'test'
        action_success = True
        action_id = 'test'
        db = DB()
        db.db_add_action(action_id, lat, lon, user_id, action_success)
        return {
            'node': (lat, lon),
            'rain': rain,
        }, 200


@Action.route('/action_display')
class fileUpload(Resource):
    def _current_report(self):
        """Query the database with the current timestamp and wrap the rows."""
        db = DB()
        timestamp = datetime.now().strftime('%Y-%m-%d %X')
        value = db.db_display_action(timestamp)
        return {
            'report': list(value)
        }, 200

    @Action.doc(responses={200: 'Success', 500: 'Register Failed'})
    def get(self):
        """Return the action report for the current time."""
        return self._current_report()

    @Action.doc(responses={200: 'Success', 500: 'Register Failed'})
    def post(self):
        """Return the action report for the current time.

        The original handler only answered when ``request.method == 'GET'``,
        which can never be true inside ``post``; it therefore always returned
        nothing. POST is kept for existing clients and now yields the report.
        """
        return self._current_report()