File name
Commit message
Commit date
2023-10-24
2023-10-24
2023-10-24
2023-10-24
2023-10-24
2023-10-24
import random
import time
import visdom
import glob
import torch
import cv2
from torchvision.transforms import ToTensor, Compose, Normalize
from flask import request
from model.AttentiveRNN import AttentiveRNN
from model.Classifier import Resnet as Classifier
from subfuction.image_crop import crop_image
# Pick the compute device once at import time: CUDA when a GPU is visible, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def _load_model(model, weight_path):
    """Load ``weight_path`` into ``model``, move it to ``device`` and set eval mode.

    ``map_location=device`` lets a checkpoint saved on a GPU be loaded on a
    CPU-only machine, matching the CPU fallback used when selecting ``device``.

    Returns:
        The same ``model`` instance, ready for inference.
    """
    state_dict = torch.load(weight_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device=device)
    model.eval()
    return model


def _load_icon_tensor(icon_path):
    """Read an icon with OpenCV and return it as a batched ``(1, C, H, W)`` tensor.

    OpenCV returns colour channels in BGR(A) order. They are reordered to
    RGB(A) so that Visdom shows the icon in its real colours.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``icon_path``.
    """
    icon = cv2.imread(icon_path, cv2.IMREAD_UNCHANGED)
    if icon is None:  # cv2.imread signals failure by returning None, not by raising
        raise FileNotFoundError(f"Could not read icon image: {icon_path}")
    if icon.ndim == 3 and icon.shape[2] == 4:
        icon = cv2.cvtColor(icon, cv2.COLOR_BGRA2RGBA)
    elif icon.ndim == 3 and icon.shape[2] == 3:
        icon = cv2.cvtColor(icon, cv2.COLOR_BGR2RGB)
    return ToTensor()(icon).unsqueeze(0)  # add batch dimension


# Before running, start a Visdom server: install it with `pip install visdom`
# and launch it in a separate terminal (e.g. `python -m visdom.server`).
def process_image():
    """Classify dataset frames as sunny/rainy and visualise every step in Visdom.

    For each PNG matched by the dataset glob, in random order, this:

    1. crops the frame with ``crop_image`` (size 512x512, start point (750, 450))
       and skips it if the crop is empty or all zeros;
    2. runs the AttentiveRNN and shows the input, the first attention map and
       the sixth attention map;
    3. classifies the last attention map with the ResNet classifier, prints
       the raw scores, max score and predicted class, and shows a sun icon for
       class 0 or a rain icon otherwise;
    4. waits one second before the next frame.

    A Visdom server must already be running (see the comment above). Returns ``None``.
    """
    vis = visdom.Visdom()

    arnn = _load_model(AttentiveRNN(6, 3, 2), "weights/ARNN_trained_weight_6_3_2.pt")
    classifier = _load_model(
        Classifier(in_ch=1),
        "weights/classifier_e19_weight_1080p_512512_fixed_wrong_resolution_and_ch.pt",
    )

    crop_size = (512, 512)
    start_point = (750, 450)
    tf_toTensor = ToTensor()

    # NOTE(review): despite the variable name, this globs the SUNNY folder.
    # Without recursive=True each "**" behaves like "*", so only PNGs exactly
    # two directory levels below SUNNY/ are matched. Other experiments used a
    # heavy-rain raindrop folder (475 PNGs) and the ARNN "output/original"
    # images; see version control history for those paths.
    rainy_data_path = glob.glob("/home/takensoft/Pictures/화창한날, 비오는날 프레임2000장/SUNNY/**/**/*.png")
    img_path = rainy_data_path
    random.shuffle(img_path)

    # Loop invariants: Visdom window ids and the two result icons (read once).
    input_win = 'input_window'
    attention_map_wins = [f'attention_map_{k}' for k in range(6)]
    prediction_win = 'prediction_window'
    shown_attention_maps = (0, 5)
    sun_icon = _load_icon_tensor('asset/sun-svgrepo-com.png')
    rain_icon = _load_icon_tensor('asset/rain-svgrepo-com.png')

    with torch.no_grad():  # inference only: don't build autograd graphs
        for path in img_path:
            image = crop_image(path, crop_size, start_point)
            # NOTE(review): assumes crop_image returns a NumPy array (or None on
            # failure) - confirm. Unreadable and all-zero crops are skipped.
            if image is None or not image.any():
                continue
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image_tensor = tf_toTensor(image).unsqueeze(0).to(device)

            image_arnn = arnn(image_tensor)
            attention_maps = image_arnn['attention_map_list']

            # Visualize input and selected attention maps using visdom
            vis.images(image_tensor, opts=dict(title="input"), win=input_win)
            for idx, attention_map in enumerate(attention_maps):
                if idx in shown_attention_maps:
                    vis.images(
                        attention_map.cpu(),  # Expected shape: (batch_size, C, H, W)
                        opts=dict(title=f'Attention Map {idx + 1}'),
                        win=attention_map_wins[idx],
                    )

            # The classifier takes the last attention map, not the ARNN output
            # image_arnn['x'] (an earlier variant did that and returned
            # ({'rain': bool}, 200) as a Flask response).
            result = classifier(attention_maps[-1]).to("cpu")
            max_score, predicted = torch.max(result, 1)
            print(result)
            print(max_score)
            print(predicted)

            # Class 0 -> sunny, anything else -> rainy.
            icon_tensor = sun_icon if predicted.item() == 0 else rain_icon
            vis.images(
                icon_tensor,
                opts=dict(title='Weather Prediction'),
                win=prediction_win,
            )
            time.sleep(1)
if __name__ == "__main__":
    # Script entry point: importing this module never starts the Visdom loop.
    process_image()