"""Zero-shot CLIP labelling of IncidentsDataset validation images.

For every entry in ``data_dir`` the script scores each label in ``features``
with CLIP ViT-B/32, using the text prompt ``"a photo of a {label}"``. It
writes two lines per successfully opened image to ``output_file``::

    <image_name> <label_1>, <label_2>, ...
    score_<image_name> <score_1>, <score_2>, ...

``<image_name>`` is the lower-cased file name without its extension. Labels
are ordered by descending softmax probability. Scores are those
probabilities as percentages rounded to two decimals. Entries that cannot be
opened or preprocessed are reported on stdout and skipped.
"""
import os

import clip
import torch
from torchvision.datasets import CIFAR100
from PIL import Image

# Inputs
data_dir = "/vulcanscratch/lzhao/data/IncidentsDataset/images/val/"
output_file = "/vulcanscratch/lzhao/data/IncidentsDataset/predictions/clip_incident_preds.txt"

# Incident labels. This is the label set currently scored.
# ("dirty contamined" is kept verbatim; it is used both as prompt text and as
# the label written to the output file.)
INCIDENT_FEATURES = [
    'burned', 'blocked', 'drought', 'airplane accident', 'rockslide rockfall',
    'flooded', 'collapsed', 'nuclear explosion', 'heavy rainfall',
    'car accident', 'van accident', 'ship boat accident', 'truck accident',
    'snow covered', 'fog', 'hailstorm', 'snowslide avalanche',
    'motorcycle accident', 'oil spill', 'train accident', 'storm surge',
    'damaged', 'mudslide mudflow', 'wildfire', 'with smoke', 'dust sand storm',
    'bicycle accident', 'tornado', 'ice storm', 'on fire', 'sinkhole',
    'traffic jam', 'tropical cyclone', 'bus accident', 'thunderstorm',
    'landslide', 'under construction', 'fire whirl', 'earthquake',
    'volcanic eruption', 'dust devil', 'dirty contamined', 'derecho',
]

# Place labels. To score places instead, or both sets together, set
# ``features = PLACE_FEATURES`` or ``features = INCIDENT_FEATURES + PLACE_FEATURES``.
PLACE_FEATURES = [
    'sky', 'coast', 'snowfield', 'ocean', 'religious building', 'bridge',
    'building facade', 'desert road', 'lighthouse', 'house', 'fire station',
    'port', 'forest', 'mountain', 'desert', 'landfill', 'beach', 'power line',
    'village', 'street', 'farm', 'badlands', 'valley',
    'residential neighborhood', 'downtown', 'oil rig', 'nuclear power plant',
    'field', 'river', 'parking lot', 'volcano', 'building outdoor',
    'gas station', 'highway', 'glacier', 'lake natural', 'skyscraper',
    'excavation', 'park', 'cabin outdoor', 'forest road', 'construction site',
    'industrial area', 'sports field', 'railroad track', 'junkyard', 'slum',
    'dam', 'pier',
]

features = INCIDENT_FEATURES

# Maximum number of top-ranked labels written per image.
TOP_K = 32


def _encode_labels(model, labels, device):
    """Return L2-normalised CLIP text embeddings, one row per label in *labels*."""
    tokens = clip.tokenize([f"a photo of a {c}" for c in labels]).to(device)
    with torch.no_grad():
        text_features = model.encode_text(tokens)
    return text_features / text_features.norm(dim=-1, keepdim=True)


def _load_image(path, preprocess, device):
    """Open *path*, apply CLIP *preprocess*, and return a batch of one image.

    The file handle is closed before returning. Exceptions from PIL or from
    *preprocess* propagate to the caller.
    """
    with Image.open(path) as img:
        return preprocess(img).unsqueeze(0).to(device)


# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

# The prompts are the same for every image, so embed them once rather than
# once per image.
text_features = _encode_labels(model, features, device)

# topk() raises if k exceeds the number of labels, so cap it for short label sets.
top_k = min(TOP_K, len(features))

count = 0
with open(output_file, "w", encoding="utf-8") as f_out:
    # Sorted so the prediction file has a reproducible order across runs.
    for filename in sorted(os.listdir(data_dir)):
        image_name = os.path.splitext(filename)[0].lower()
        path = os.path.join(data_dir, filename)

        # Prepare the inputs. Best-effort: skip anything that cannot be opened
        # or preprocessed (non-images, sub-directories, corrupt files). Catching
        # Exception rather than using a bare except keeps Ctrl-C working.
        try:
            image_input = _load_image(path, preprocess, device)
        except Exception as err:
            print("Can't open file: ", path, err)
            continue

        # Calculate image features and rank labels by cosine similarity.
        with torch.no_grad():
            image_features = model.encode_image(image_input)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
        values, indices = similarity[0].topk(top_k)

        feature_list = [features[i] for i in indices.tolist()]
        feature_scores = [str(round(100 * v, 2)) for v in values.tolist()]

        f_out.write("{} {}\n".format(image_name, ", ".join(feature_list)))
        f_out.write("{} {}\n".format("score_" + image_name, ", ".join(feature_scores)))
        count += 1

print("Number of images predicted: ", count)