"""Overlay thresholded incident predictions and reference labels on images.

Inputs (paths are configured below):

* ``reference_file``: one line per image, ``<image file name> <label>, <label>, ...``.
  The image id is the file name up to its first ``.``.
* ``prediction_file``: a label line ``<image id> <label>, <label>, ...`` and a
  score line whose first token starts with ``score`` and whose remainder is
  ``<score>, <score>, ...``. A score line is paired with the most recent
  label line.

For every predicted image that has a reference entry, the script writes a copy
of ``<image_dir>/<image id>.jpg`` to ``output_image_dir`` with three text
overlays: the predictions scoring at least ``threshold`` (with their scores),
the reference labels, and the image id plus the threshold used.
"""
import os
from collections import defaultdict

import numpy as np
from PIL import Image, ImageDraw, ImageFont
from sklearn.metrics import (
    f1_score,
    precision_recall_fscore_support,
    precision_score,
    recall_score,
)

base_dir = "/vulcanscratch/lzhao/data/IncidentsDataset/"
prediction_file = base_dir + "predictions/clip_incident_preds.txt"
reference_file = "/cfarhomes/lzhao/data/IncidentsDataset/val_incident_reference.txt"
image_dir = base_dir + "images/val/"
output_image_dir = base_dir + "clip.thres8.incident.visual/"
threshold = 8
#font_size = 22
# Font size is derived per image as this fraction of the image width.
text_fraction = 0.015
font_path = "/usr/share/fonts/gnu-free/FreeSerif.ttf"
text_color = (255, 0, 0)

# Alternative label sets, left commented out; only the last list is active.
#features = ['railroad track', 'beach', 'train accident', 'truck accident', 'flooded', 'on fire', 'derecho', 'construction site', 'house', 'power line', 'sinkhole', 'burned', 'rockslide rockfall', 'building outdoor', 'village', 'industrial area', 'desert road', 'blocked', 'coast', 'nuclear power plant', 'snowslide avalanche', 'field', 'junkyard', 'ice storm', 'ocean', 'traffic jam', 'under construction', 'pier', 'residential neighborhood', 'valley', 'mountain', 'damaged', 'forest road', 'dust devil', 'volcanic eruption', 'glacier', 'thunderstorm', 'street', 'cabin outdoor', 'excavation', 'bus accident', 'sports field', 'airplane accident', 'oil spill', 'dam', 'dirty contamined', 'building facade', 'mudslide mudflow', 'highway', 'religious building', 'earthquake', 'landfill', 'fire whirl', 'van accident', 'collapsed', 'snow covered', 'sky', 'badlands', 'forest', 'drought', 'heavy rainfall', 'tornado', 'ship boat accident', 'oil rig', 'bicycle accident', 'river', 'wildfire', 'parking lot', 'volcano', 'nuclear explosion', 'desert', 'hailstorm', 'storm surge', 'fog', 'car accident', 'fire station', 'with smoke', 'lighthouse', 'downtown', 'dust sand storm', 'motorcycle accident', 'lake natural', 'tropical cyclone', 'park', 'slum', 'snowfield', 'gas station', 'landslide', 'farm', 'skyscraper', 'bridge', 'port']
#features = ['sky', 'coast', 'snowfield', 'ocean', 'religious building', 'bridge', 'building facade', 'desert road', 'lighthouse', 'house', 'fire station', 'port', 'forest', 'mountain', 'desert', 'landfill', 'beach', 'power line', 'village', 'street', 'farm', 'badlands', 'valley', 'residential neighborhood', 'downtown', 'oil rig', 'nuclear power plant', 'field', 'river', 'parking lot', 'volcano', 'building outdoor', 'gas station', 'highway', 'glacier', 'lake natural', 'skyscraper', 'excavation', 'park', 'cabin outdoor', 'forest road', 'construction site', 'industrial area', 'sports field', 'railroad track', 'junkyard', 'slum', 'dam', 'pier']
features = [
    'burned', 'blocked', 'drought', 'airplane accident', 'rockslide rockfall',
    'flooded', 'collapsed', 'nuclear explosion', 'heavy rainfall',
    'car accident', 'van accident', 'ship boat accident', 'truck accident',
    'snow covered', 'fog', 'hailstorm', 'snowslide avalanche',
    'motorcycle accident', 'oil spill', 'train accident', 'storm surge',
    'damaged', 'mudslide mudflow', 'wildfire', 'with smoke', 'dust sand storm',
    'bicycle accident', 'tornado', 'ice storm', 'on fire', 'sinkhole',
    'traffic jam', 'tropical cyclone', 'bus accident', 'thunderstorm',
    'landslide', 'under construction', 'fire whirl', 'earthquake',
    'volcanic eruption', 'dust devil', 'dirty contamined', 'derecho',
]
print("Number of features: ", len(features))


def _load_references(path):
    """Return ``{image id: [reference label, ...]}`` parsed from *path*.

    Lines without a space (blank or malformed) are skipped instead of raising
    on tuple unpacking.
    """
    references = {}
    with open(path) as f_ref:
        for line in f_ref:
            shot_name, sep, ref_text = line.partition(" ")
            if not sep:
                continue
            references[shot_name.split(".")[0]] = ref_text.strip().split(", ")
    return references


def _select_above_threshold(labels, scores, min_score):
    """Return a flat ``[label, score, label, score, ...]`` list.

    Only pairs whose score parses to a float ``>= min_score`` are kept. The
    score strings are kept verbatim. ``zip`` stops at the shorter input, so a
    score line with fewer entries than the label line no longer raises
    ``IndexError``.
    """
    selected = []
    for label, score in zip(labels, scores):
        if float(score) >= min_score:
            selected.extend((label, score))
    return selected


# Fonts are keyed by pixel size so that images of equal width share one object.
_font_cache = {}


def _get_font(size):
    """Return the overlay font at *size*, loading it on first use."""
    font = _font_cache.get(size)
    if font is None:
        font = _font_cache[size] = ImageFont.truetype(font_path, size)
    return font


def _annotate_and_save(shot_id, labels, scores, ref_labels):
    """Draw the overlays for *shot_id* and save the result.

    Args:
        shot_id: image id; the input and output file are ``<shot_id>.jpg``.
        labels: predicted label strings.
        scores: score strings aligned with *labels*.
        ref_labels: reference label strings for this image.

    Returns:
        ``True`` if the annotated image was saved, ``False`` if the source
        image could not be opened (a message is printed and the image is
        skipped).
    """
    img_path = image_dir + shot_id + ".jpg"
    print(img_path)
    try:
        # The context manager closes the source file; convert() returns an
        # independent in-memory copy.
        with Image.open(img_path) as source:
            img = source.convert('RGB')
        draw = ImageDraw.Draw(img)
    except Exception:
        # Best-effort: skip unreadable images, but unlike a bare ``except``
        # let KeyboardInterrupt / SystemExit propagate.
        print("Can't open or draw on img: ", img_path)
        return False

    top_pred_score_list = _select_above_threshold(labels, scores, threshold)

    width, height = img.size
    # Clamp to 1 so very narrow images do not request a zero-sized font.
    font_size = max(1, int(width * text_fraction))
    font = _get_font(font_size)

    draw.text((3, 0), " ".join(top_pred_score_list), text_color, font=font)
    draw.text((3, font_size + 5), " ".join(ref_labels), text_color, font=font)
    draw.text(
        (3, height - font_size - 8),
        ", ".join([shot_id, "thres=" + str(threshold)]),
        text_color,
        font=font,
    )
    img.save(output_image_dir + shot_id + ".jpg")
    return True


shot2ref = _load_references(reference_file)

# Not referenced anywhere else in the visible script.
pred_dict, ref_dict = defaultdict(list), defaultdict(list)

count = 0
count_ref = 0
curr_shot_id = None  # image id waiting for its score line, if any
pred_list = []

# Saving fails if the target directory is missing, so create it up front.
os.makedirs(output_image_dir, exist_ok=True)

with open(prediction_file) as f_pred:
    for line in f_pred:
        stripped = line.strip()
        if not stripped:
            continue
        shot_id, sep, text = stripped.partition(" ")
        if not sep:
            # Malformed line: drop any pending image so it cannot be paired
            # with an unrelated score line later.
            curr_shot_id = None
            continue

        if not shot_id.startswith("score"):
            # Label line: remember the predictions until the score line arrives.
            count += 1
            pred_list = text.split(", ")
            if shot_id in shot2ref:
                count_ref += 1
                curr_shot_id = shot_id
            else:
                print("Image doesn't have reference", shot_id)
                # Clear any previous id so the upcoming score line is ignored.
                curr_shot_id = None
            continue

        # Score line: only meaningful directly after an accepted label line.
        if curr_shot_id is None:
            continue
        # Consume the pending id before drawing so a failure cannot leave it set.
        target_id, curr_shot_id = curr_shot_id, None
        _annotate_and_save(target_id, pred_list, text.split(", "), shot2ref[target_id])

print("Number of drawed outputs with reference: ", count_ref)
print("Number of predicted images: ", count)