"""Print, for every predicted label, the shots with the highest and lowest scores.

The prediction file is read as pairs of lines:

* a label line ``<shot_id> <label>, <label>, ...``
* followed by a score line whose first token starts with ``score`` and whose
  remainder is ``<score>, <score>, ...``. The i-th score belongs to the i-th
  label of the preceding label line.

Only shots whose id appears in the reference file are considered. For every
label, the ``top_k`` highest-scoring shots (descending) and the ``top_k``
lowest-scoring shots (ascending) are printed.
"""
import numpy as np
from collections import defaultdict
from sklearn.metrics import f1_score, precision_score, recall_score, precision_recall_fscore_support

# NOTE(review): numpy and the sklearn metrics above are not used by this
# script; they are kept in case other tooling relies on them being imported.

base_dir = "/vulcanscratch/lzhao/data/IncidentsDataset/"
prediction_file = base_dir + "predictions/clip_place_preds.txt"
reference_file = "/cfarhomes/lzhao/data/IncidentsDataset/val_place_reference.txt"

# Alternative label lists kept for reference. The first appears to be the
# union of the incident list (second) and the place list (active one below).
#features = ['railroad track', 'beach', 'train accident', 'truck accident', 'flooded', 'on fire', 'derecho',
#            'construction site', 'house', 'power line', 'sinkhole', 'burned', 'rockslide rockfall',
#            'building outdoor', 'village', 'industrial area', 'desert road', 'blocked', 'coast',
#            'nuclear power plant', 'snowslide avalanche', 'field', 'junkyard', 'ice storm', 'ocean',
#            'traffic jam', 'under construction', 'pier', 'residential neighborhood', 'valley', 'mountain',
#            'damaged', 'forest road', 'dust devil', 'volcanic eruption', 'glacier', 'thunderstorm', 'street',
#            'cabin outdoor', 'excavation', 'bus accident', 'sports field', 'airplane accident', 'oil spill',
#            'dam', 'dirty contamined', 'building facade', 'mudslide mudflow', 'highway', 'religious building',
#            'earthquake', 'landfill', 'fire whirl', 'van accident', 'collapsed', 'snow covered', 'sky',
#            'badlands', 'forest', 'drought', 'heavy rainfall', 'tornado', 'ship boat accident', 'oil rig',
#            'bicycle accident', 'river', 'wildfire', 'parking lot', 'volcano', 'nuclear explosion', 'desert',
#            'hailstorm', 'storm surge', 'fog', 'car accident', 'fire station', 'with smoke', 'lighthouse',
#            'downtown', 'dust sand storm', 'motorcycle accident', 'lake natural', 'tropical cyclone', 'park',
#            'slum', 'snowfield', 'gas station', 'landslide', 'farm', 'skyscraper', 'bridge', 'port']
#features = ['burned', 'blocked', 'drought', 'airplane accident', 'rockslide rockfall', 'flooded', 'collapsed',
#            'nuclear explosion', 'heavy rainfall', 'car accident', 'van accident', 'ship boat accident',
#            'truck accident', 'snow covered', 'fog', 'hailstorm', 'snowslide avalanche', 'motorcycle accident',
#            'oil spill', 'train accident', 'storm surge', 'damaged', 'mudslide mudflow', 'wildfire',
#            'with smoke', 'dust sand storm', 'bicycle accident', 'tornado', 'ice storm', 'on fire', 'sinkhole',
#            'traffic jam', 'tropical cyclone', 'bus accident', 'thunderstorm', 'landslide',
#            'under construction', 'fire whirl', 'earthquake', 'volcanic eruption', 'dust devil',
#            'dirty contamined', 'derecho']

# NOTE(review): `features` is not read anywhere below -- the report covers
# every label that appears in the prediction file. Confirm whether the report
# was meant to be restricted to / ordered by this list.
features = ['sky', 'coast', 'snowfield', 'ocean', 'religious building', 'bridge', 'building facade',
            'desert road', 'lighthouse', 'house', 'fire station', 'port', 'forest', 'mountain', 'desert',
            'landfill', 'beach', 'power line', 'village', 'street', 'farm', 'badlands', 'valley',
            'residential neighborhood', 'downtown', 'oil rig', 'nuclear power plant', 'field', 'river',
            'parking lot', 'volcano', 'building outdoor', 'gas station', 'highway', 'glacier', 'lake natural',
            'skyscraper', 'excavation', 'park', 'cabin outdoor', 'forest road', 'construction site',
            'industrial area', 'sports field', 'railroad track', 'junkyard', 'slum', 'dam', 'pier']

# Number of highest- and lowest-scoring shots to print per label.
top_k = 10

# Reference file: one shot per line, "<shot_name> <label>, <label>, ...".
# The shot id is the shot name up to its first "."; only membership in this
# dict is used below.
shot2ref = {}
with open(reference_file) as f_ref:
    for line in f_ref:
        if not line.strip():
            continue  # tolerate blank lines instead of failing to unpack
        shot_name, ref_text = line.split(" ", 1)
        shot_id = shot_name.split(".")[0]
        ref_list = ref_text.strip().split(", ")
        shot2ref[shot_id] = ref_list

# label -> list of (score, shot_id) pairs collected from the prediction file.
feature2score = defaultdict(list)
count = 0  # label lines whose shot is in the reference set
curr_shot_id = None  # shot whose label line is waiting for its score line
pred_list = []
with open(prediction_file) as f_pred:
    for line in f_pred:
        # Strip first: otherwise the last label on each line keeps its "\n"
        # and its scores end up under a separate key such as "pier\n".
        line = line.strip()
        if not line:
            continue
        shot_id, text = line.split(" ", 1)

        if not shot_id.startswith("score"):
            # Label line for a new shot.
            if shot_id not in shot2ref:
                # Drop any pending shot so the upcoming score line for this
                # skipped shot cannot be attributed to an earlier one.
                curr_shot_id = None
                continue
            pred_list = text.split(", ")
            count += 1
            curr_shot_id = shot_id
            continue

        # Score line: only used when it follows a label line we kept.
        if curr_shot_id is None:
            continue
        score_list = text.split(", ")
        if len(score_list) != len(pred_list):
            raise ValueError(
                "Shot {}: {} labels but {} scores".format(
                    curr_shot_id, len(pred_list), len(score_list)
                )
            )
        for label, score in zip(pred_list, score_list):
            feature2score[label].append((float(score), curr_shot_id))
        curr_shot_id = None

print("Number of outputs: ", count)

for feature, entries in feature2score.items():
    sorted_list = sorted(entries, key=lambda x: x[0])
    low_k = sorted_list[:top_k]
    # max(..., 0) keeps top_k == 0 from selecting the whole list via [-0:].
    high_k = sorted_list[max(len(sorted_list) - top_k, 0):]

    print("\nFeature {} highest predictions:".format(feature))
    for score, name in reversed(high_k):
        print("{}: {}".format(name, str(score)))

    print("\nFeature {} lowest predictions:".format(feature))
    for score, name in low_k:
        print("{}: {}".format(name, str(score)))