"""Overlay thresholded predictions and reference labels onto video frames.

For every shot in ``prediction_file`` this script:

* opens the matching frame image from ``image_dir``;
* draws, in red at the top, each predicted label whose score is at least
  ``threshold``, followed by that score;
* draws, in green on the next row, the shot's reference labels from
  ``reference_file``;
* stamps ``<shot_id>, thres=<threshold>`` in red at the bottom-left corner;
* saves the annotated frame to ``output_image_dir``.

Input formats, as implied by the parsing below:

* ``reference_file``: one line per shot, ``<shot_id> <label>, <label>, ...``.
* ``prediction_file``: a header line whose first token starts with ``shot``
  (``<shot_id> <label>, <label>, ...``), followed by a score line
  (``<token> <score>, <score>, ...``). The scores align positionally with the
  header's labels. The score line's first token is ignored.
"""
import os
from collections import defaultdict

import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score, precision_recall_fscore_support
from PIL import Image, ImageDraw, ImageFont

threshold = 8  # minimum score for a predicted label to be drawn
font_size = 22

base_dir = "/vulcanscratch/lzhao/data/DSDI20/"
prediction_file = base_dir + "predictions/frames_top_preds.txt"
reference_file = base_dir + "references.txt"
image_dir = base_dir + "frames_v2/"
# Derived from `threshold` so the folder name cannot drift from the value used.
output_image_dir = base_dir + "frames_v2.visual.thres" + str(threshold) + "/"

# Label vocabulary; kept for reference, not used by the drawing code below.
features = ["damage", "flooding or water damage", "landslide", "road washout",
            "rubble or debris", "smoke or fire", "dirt", "grass", "lava", "rocks",
            "sand", "shrubs", "snow or ice", "trees", "bridge", "building",
            "dam or levee", "pipes", "utility or power lines or electric towers",
            "railway", "wireless or radio communication towers", "water tower",
            "aircraft", "boat", "car", "truck", "flooding", "lake or pond", "ocean",
            "puddle", "river or stream", "road"]


def _read_references(path):
    """Return a dict mapping each shot id in *path* to its list of reference labels.

    Blank lines are skipped. A shot id with no labels maps to an empty list.
    """
    refs = {}
    with open(path, encoding="utf-8") as f_ref:
        for line in f_ref:
            # Strip the newline, otherwise it ends up glued to the last label.
            line = line.strip()
            if not line:
                continue
            shot_id, _, ref_text = line.partition(" ")
            refs[shot_id] = ref_text.split(", ") if ref_text else []
    return refs


def _select_top_predictions(labels, scores, min_score):
    """Return a flat ``[label, score, label, score, ...]`` list of labels scoring >= *min_score*.

    *labels* and *scores* are paired positionally. Scores are kept as their
    original strings so they render exactly as written in the prediction file.
    """
    selected = []
    for label, score in zip(labels, scores):
        if float(score) >= min_score:
            selected.extend((label, score))
    return selected


shot2ref = _read_references(reference_file)
os.makedirs(output_image_dir, exist_ok=True)
font = ImageFont.truetype("/usr/share/fonts/gnu-free/FreeSerif.ttf", font_size)

count = 0
curr_shot_id = None
pred_list = []
with open(prediction_file, encoding="utf-8") as f_pred:
    for line in f_pred:
        # Strip the newline. A trailing "\n" inside the drawn text makes
        # ImageDraw.text render it as multi-line and garble the overlay.
        line = line.strip()
        if not line:
            continue
        shot_id, _, text = line.partition(" ")
        if shot_id.startswith("shot"):
            # Header line: remember the shot and its labels. Its scores
            # arrive on the next line.
            pred_list = text.split(", ")
            curr_shot_id = shot_id
            continue
        if curr_shot_id is None:
            # Score line with no preceding header: nothing to attach it to.
            continue

        # Frame files are named with the shot id capitalized.
        img_path = image_dir + curr_shot_id.capitalize() + ".jpg"
        print(img_path)
        try:
            with Image.open(img_path) as img:
                # Convert so the RGB fill tuples below are valid even for
                # grayscale/palette frames; this also loads the pixels so the
                # file can be closed right away.
                frame = img.convert("RGB")
        except OSError:
            # Missing or unreadable frame: skip this shot, as before, but
            # without swallowing unrelated errors.
            continue

        draw = ImageDraw.Draw(frame)
        height = frame.size[1]
        top_pred_score_list = _select_top_predictions(pred_list, text.split(", "), threshold)
        draw.text((3, 0), " ".join(top_pred_score_list), (255, 0, 0), font=font)
        draw.text((3, font_size + 5), " ".join(shot2ref.get(curr_shot_id, [])), (0, 255, 0), font=font)
        draw.text((3, height - font_size - 8), ", ".join([curr_shot_id, "thres=" + str(threshold)]),
                  (255, 0, 0), font=font)
        frame.save(output_image_dir + curr_shot_id + ".jpg")
        count += 1

print("Number of drawed outputs: ", count)