File name
Commit message
Commit date
from flask_restx import Resource, Namespace
from flask import request
from werkzeug.utils import secure_filename
import os
from database.database import DB
import torch
from torchvision.transforms import ToTensor
from datetime import datetime
from model.AttentiveRNN import AttentiveRNN
from model.Classifier import Resnet as Classifier
from subfuction.image_crop import crop_image
# Inference device, chosen once at import time: CUDA when torch reports a
# usable GPU, otherwise the CPU.
device = (
    torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
)
# Working directory captured at import time.
paths = os.getcwd()
# Flask-RESTX namespace that groups the node-analysis endpoints below.
# (The description reads: "API used for node analysis.")
Action = Namespace(name="Action", description="노드 분석을 위해 사용하는 api.")
@Action.route('/image_summit')
class fileUpload(Resource):
    """Accept an uploaded image and store it in the current working directory."""

    @Action.doc(responses={200: 'Success'})
    @Action.doc(responses={400: 'Invalid file name'})
    @Action.doc(responses={500: 'Register Failed'})
    def post(self):
        """Save the multipart form field ``file`` under a sanitised name.

        Returns:
            ``({'save': 'done'}, 200)`` after the file is written, or
            ``({'save': 'failed'}, 400)`` when the client-supplied file name
            sanitises to an empty string.

        A missing ``file`` field makes ``request.files['file']`` raise
        Werkzeug's ``BadRequestKeyError``, which Flask turns into a 400.
        """
        # No need to check request.method: Flask-RESTX only dispatches POST here.
        f = request.files['file']
        # secure_filename drops path components and unsafe characters. It can
        # return '' (empty or all-unsafe names), and filename itself may be
        # None. Saving to '' would raise an OS error and surface as a 500.
        safe_name = secure_filename(f.filename or '')
        if not safe_name:
            return {'save': 'failed'}, 400
        f.save(safe_name)
        return {
            'save': 'done'  # status is reported as a string
        }, 200
# Lazily initialised (arnn, classifier) pair shared by all /image_anal requests.
_ANAL_MODELS = None


def _load_anal_models():
    """Return the cached ``(AttentiveRNN, Classifier)`` pair, loading it on first use.

    Both networks are moved to ``device`` and switched to eval mode. Caching
    avoids re-reading the weight files on every request.
    """
    global _ANAL_MODELS
    if _ANAL_MODELS is None:
        # map_location lets checkpoints saved on a GPU load on a CPU-only host.
        arnn = AttentiveRNN(6, 3, 2)
        arnn.load_state_dict(
            torch.load("weights/ARNN_trained_weight_6_3_2.pt", map_location=device))
        arnn.to(device=device)
        arnn.eval()  # inference mode: disables dropout / batch-norm updates
        clasifier = Classifier()
        clasifier.load_state_dict(
            torch.load("weights/Classifier_512.pt", map_location=device))
        clasifier.to(device=device)
        clasifier.eval()
        _ANAL_MODELS = (arnn, clasifier)
    return _ANAL_MODELS


@Action.route('/image_anal')
class fileUpload(Resource):
    """Run the rain classifier on a previously uploaded image and log the action."""

    @Action.doc(responses={200: 'Success'})
    @Action.doc(responses={500: 'Register Failed'})
    def post(self):
        """Classify the image named in the JSON body and record the action.

        JSON body keys:
            gps_x, gps_y: coordinates, converted to float (stored as lat, lon).
            filename, file_type: concatenated to form the stored file name
                (presumably file_type is the extension including the dot --
                confirm with the client).

        Returns:
            ``({'node': (lat, lon), 'rain': bool}, 200)`` on success, or
            ``({'node': (lat, lon), 'rain': 'rain'}, 500)`` when the image
            cannot be cropped.
        """
        payload = request.json
        lat = float(payload['gps_x'])
        lon = float(payload['gps_y'])
        filename = payload['filename']
        file_type = payload['file_type']

        # Sanitise the name the same way /image_summit does when saving.
        # That way the lookup finds the stored file, and a client-supplied
        # name cannot escape the working directory (e.g. "../../etc/passwd").
        total_path = os.path.join(os.getcwd(), secure_filename(filename + file_type))

        # NOTE(review): presumably (width, height) and the crop's top-left
        # corner -- confirm against crop_image.
        crop_size = (512, 512)
        start_point = (750, 450)
        image = crop_image(total_path, crop_size, start_point)
        # crop_image's failure value isn't visible here. Test for None/False
        # explicitly rather than truthiness, which is ambiguous (and raises)
        # for numpy arrays.
        if image is None or image is False:
            return {
                'node': (lat, lon),
                'rain': 'rain',
            }, 500

        arnn, clasifier = _load_anal_models()
        # Tensor.to() returns a new tensor instead of moving in place, so the
        # result must be rebound for the input to reach the models' device.
        # NOTE(review): ToTensor yields (C, H, W) with no batch dimension;
        # confirm the models accept unbatched input, else add unsqueeze(0).
        image_tensor = ToTensor()(image).to(device)
        with torch.no_grad():
            # Call the loaded model instances, not the classes. The classes
            # would construct fresh, untrained modules.
            image_arnn = arnn(image_tensor)
            result = clasifier(image_arnn).to("cpu")

        # NOTE(review): the classifier's output format isn't visible here.
        # Multi-element outputs are treated as per-class scores (argmax of the
        # first item); a single value is compared with 0 as before.
        if result.numel() > 1:
            label = int(result.argmax(dim=-1).flatten()[0])
        else:
            label = result.item()
        rain = label != 0  # class 0 means "no rain"

        # TODO: placeholder identifiers -- replace with the real user/action ids.
        user_id = 'test'
        action_success = True
        action_id = 'test'
        db = DB()
        db.db_add_action(action_id, lat, lon, user_id, action_success)
        return {
            'node': (lat, lon),
            'rain': rain,
        }, 200
@Action.route('/action_display')
class fileUpload(Resource):
    """Expose the action report produced by ``DB.db_display_action``."""

    @Action.doc(responses={200: 'Success'})
    @Action.doc(responses={500: 'Register Failed'})
    def get(self):
        """Return ``({'report': [...]}, 200)`` built from ``db.db_display_action``.

        The DB helper receives the current local time formatted as
        ``YYYY-MM-DD HH:MM:SS``. How it uses that timestamp is defined in
        ``DB`` and is not visible here.
        """
        db = DB()
        # '%H:%M:%S' instead of '%X': %X is the locale's time representation
        # and can change format if the process locale is ever set.
        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        value = db.db_display_action(timestamp)
        return {
            'report': list(value)
        }, 200

    @Action.doc(responses={200: 'Success'})
    @Action.doc(responses={500: 'Register Failed'})
    def post(self):
        """Backward-compatible alias for ``get``.

        The original handler was registered for POST but only did work when
        ``request.method == 'GET'``. That is never true inside ``post``, so it
        always answered with a null body. POST now returns the same report.
        """
        return self.get()