# File name | Commit message | Commit date
# 2023-10-24
# 2023-10-24
# 2023-10-24
from flask_restx import Resource, Namespace
from flask import request, jsonify
import os
import json
from database.database import DB
import torch
from torchvision.transforms import ToTensor
from model.AttentiveRNN import AttentiveRNN
from model.Classifier import Resnet as Classifier
from subfuction.image_crop import crop_image
import numpy as np
import cv2
db = DB()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Pre-load the models once at import time so each request only runs inference.
# map_location=device is required for the CPU fallback above to work: without it
# torch.load restores every tensor to the device it was saved from, which raises
# on a CPU-only host when the checkpoint was written from a CUDA run.
# NOTE: torch.load unpickles the file, so these weight files must be trusted.
arnn = AttentiveRNN(6, 3, 2)
arnn.eval()
arnn.load_state_dict(
    torch.load("weights/ARNN_trained_weight_6_3_2.pt", map_location=device)
)
arnn.to(device=device)

classifier = Classifier(in_ch=1)
classifier.eval()
classifier.load_state_dict(
    torch.load(
        "weights/classifier_e19_weight_1080p_512512_fixed_wrong_resolution_and_ch.pt",
        map_location=device,
    )
)
classifier.to(device=device)

tf_toTensor = ToTensor()
# Fixed crop window passed to crop_image for every uploaded image.
# NOTE(review): the axis order of these pairs depends on crop_image — confirm.
crop_size = (512, 512)
start_point = (750, 450)
root_dir = os.getcwd()

# Namespace description (Korean) reads: "API used for node analysis."
Action = Namespace(
    name="Action",
    description="노드 분석을 위해 사용하는 api.",
)
def _parse_gps(json_data):
    """Parse the ``data`` form field and return ``(lat, lon)`` as floats.

    ``gps_x`` becomes ``lat`` and ``gps_y`` becomes ``lon``; both are later
    passed to ``db.db_add_action`` and echoed back in the response.
    NOTE(review): x -> latitude is an unusual mapping — confirm with the client.

    Raises:
        ValueError: the payload is not valid JSON or a coordinate is not numeric.
        KeyError: ``gps_x`` or ``gps_y`` is missing.
        TypeError: the payload is not a JSON object or a coordinate is null.
    """
    data = json.loads(json_data)
    return float(data['gps_x']), float(data['gps_y'])


def _predict_rain(image):
    """Classify a decoded image and return True when the predicted class is not 0.

    The image (as returned by ``cv2.imdecode``, BGR channel order) is cropped
    with ``crop_size``/``start_point``, passed through the attentive RNN, and the
    last entry of its ``attention_map_list`` is fed to the classifier.
    """
    image = crop_image(image, crop_size, start_point)
    # Assumes crop_image returns an HxWxC ndarray; ToTensor makes it CxHxW and
    # unsqueeze(0) adds the batch dimension of 1 the models expect.
    image_tensor = tf_toTensor(image).unsqueeze(0).to(device)
    with torch.no_grad():
        image_arnn = arnn(image_tensor)
        # Release the input before the classifier runs to lower peak device memory.
        del image_tensor
        logits = classifier(image_arnn['attention_map_list'][-1])
        del image_arnn
    _, predicted = torch.max(logits.cpu(), 1)
    # Batch size is 1, so predicted holds exactly one class index.
    return predicted.item() != 0


@Action.route('/image_anal')
class fileUpload(Resource):
    """Accept an uploaded image with GPS metadata and classify whether it shows rain."""

    @Action.doc(responses={200: 'Success'})
    @Action.doc(responses={400: 'Bad Request'})
    @Action.doc(responses={500: 'Register Failed'})
    def post(self):
        """Classify an uploaded image and record the action in the database.

        Expects multipart form data containing a ``file`` part (an encoded image)
        and a ``data`` field holding a JSON object with ``gps_x`` and ``gps_y``.

        Returns:
            ``({'node': (lat, lon), 'rain': bool}, 200)`` on success, or
            ``({'message': ...}, 400)`` when the file or its metadata is invalid.
        """
        uploaded_file = request.files.get('file')
        if not uploaded_file:
            return {"message": "No file uploaded"}, 400
        json_data = request.form.get('data')
        if not json_data:
            return {"message": "Missing JSON data"}, 400
        try:
            lat, lon = _parse_gps(json_data)
        except (ValueError, KeyError, TypeError):
            # Client error: report it as 400 instead of letting it surface as a 500.
            return {"message": "Invalid JSON data"}, 400

        raw_bytes = uploaded_file.read()
        image = None
        # Guard empty uploads explicitly; for non-empty but undecodable bytes
        # cv2.imdecode returns None rather than raising.
        if raw_bytes:
            file_bytes = np.asarray(bytearray(raw_bytes), dtype=np.uint8)
            image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
        if image is None:
            return {"message": "Invalid image file"}, 400

        rain = _predict_rain(image)

        # TODO: placeholder identifiers until real user/action ids are available.
        user_id = 'test'
        action_success = True
        action_id = 'test'
        db.db_add_action(action_id, lat, lon, user_id, action_success)
        return {
            'node': (lat, lon),
            'rain': rain,
        }, 200
@Action.route('/action_display')
class fileUpload(Resource):
    """Expose the stored actions as a JSON report.

    NOTE: this class reuses the name of the /image_anal resource class. Each
    route is registered at decoration time, so both endpoints exist, but the
    module-level name ``fileUpload`` ends up bound to this class.
    """

    @Action.doc(responses={200: 'Success'})
    @Action.doc(responses={500: 'Register Failed'})
    def get(self):
        """Return ``({'report': [...]}, 200)`` with every row from ``db_display_action``.

        Any other method dispatched to this handler (Flask routes HEAD to ``get``)
        yields ``None``.
        NOTE(review): a new ``DB()`` is created per request instead of reusing the
        module-level ``db`` — confirm whether that is intentional.
        """
        if request.method != 'GET':
            return None
        connection = DB()
        rows = connection.db_display_action()
        return {'report': list(rows)}, 200