import json
import random
import argparse
from collections import defaultdict
from statistics import mean

import numpy as np
import torch
import torch.nn as nn
import torchtext
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer
from nltk.tokenize import word_tokenize
from transformers import GPT2LMHeadModel, GPT2TokenizerFast


def read_categories(json_file):
    """Load the mapping from WordNet supersense names to word lists."""
    with open(json_file) as f:
        categories = json.load(f)
    return categories


def sample_neighbor(word, word_list, neighbor_k):
    """Return the k-th nearest neighbor of `word` among `word_list`,
    ranked by cosine similarity of FastText embeddings."""
    center = torch.stack([embedding_model[word]])
    neighbor_array = torch.stack([embedding_model[neighbor] for neighbor in word_list])
    similarity = cos(center, neighbor_array)
    similarity_array = similarity.numpy()
    sorted_indices = np.argsort(-similarity_array)  # descending similarity
    # Alternative: sample uniformly from a band [left, right) of neighbors, e.g.
    #   ranged = sorted_indices[neighbor_range[0]:min(neighbor_range[1], len(word_list))]
    #   sampled_neighbor = random.choice(ranged)
    sampled_neighbor = sorted_indices[min(neighbor_k - 1, len(word_list) - 1)]
    return word_list[sampled_neighbor]


def produce_negative_sample(tokens, stopwords_list, wordnet_categories,
                            selected_wordnet_categories, neighbor_k):
    """For every non-stopword token whose stem falls in a selected WordNet
    category, substitute a semantic neighbor; return the candidate sentence
    with the lowest perplexity, or None if no token qualified."""
    sent_negative_samples = []
    for i in range(len(tokens)):
        if tokens[i] in stopwords_list:
            continue
        word_stem = ps.stem(tokens[i])
        for lexname in selected_wordnet_categories:
            word_list = wordnet_categories[lexname]
            if word_stem in word_list:
                sample_token = sample_neighbor(word_stem, list(set(word_list) - {word_stem}), neighbor_k)
                sent_negative_samples.append(tokens[:i] + [sample_token] + tokens[i + 1:])
    if sent_negative_samples:
        ppl_scores = [sent_perplexity(" ".join(sent)) for sent in sent_negative_samples]
        index = np.argmin(np.array(ppl_scores))
        return sent_negative_samples[index]
    return None


def sent_perplexity(sent):
    """GPT-2 perplexity of a sentence, computed with a sliding window so that
    inputs longer than the model's context window still get scored."""
    encodings = ppl_tokenizer(sent, return_tensors='pt')
    max_length = ppl_model.config.n_positions
    stride = 512
    lls = []
    for i in range(0, encodings.input_ids.size(1), stride):
        begin_loc = max(i + stride - max_length, 0)
        end_loc = min(i + stride, encodings.input_ids.size(1))
        trg_len = end_loc - i  # may be shorter than stride on the last window
        input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
        target_ids = input_ids.clone()
        target_ids[:, :-trg_len] = -100  # score only the final trg_len tokens
        with torch.no_grad():
            outputs = ppl_model(input_ids, labels=target_ids)
            log_likelihood = outputs[0] * trg_len  # loss is mean NLL; recover the sum
        lls.append(log_likelihood)
    ppl = torch.exp(torch.stack(lls).sum() / end_loc)
    return ppl.tolist()
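# Illustration of the sliding window in sent_perplexity (assuming GPT-2's
# n_positions == 1024, stride == 512, and a hypothetical 1,536-token input):
# the windows are [0:512], [0:1024], and [512:1536]. In each window only the
# final trg_len tokens are scored (earlier labels are masked with -100), so
# every token contributes to the loss exactly once while keeping at least
# 512 tokens of left context after the first window.
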
def negative_sampling(input_json_file, output_json_file, output_pos_text_file,
                      output_neg_text_file, wordnet_categories,
                      selected_wordnet_categories, neighbor_k):
    """Build a negative caption per image, rank images by the mean perplexity
    of their negative captions, and write the JSON and text outputs."""
    negative_samples, original_samples = defaultdict(list), defaultdict(list)
    img2names = defaultdict(str)
    stopwords_list = stopwords.words('english')
    with open(input_json_file) as f:
        data = json.load(f)
    samples = data['samples']
    for sample in samples[:1000]:  # TODO: delete this cap
        img_id, file_name, captions = sample['imgId'], sample['imgName'], sample['captions']
        modified = False
        original_sents, negative_sents = [], []
        for sent in captions[:1]:  # TODO: text processing; only the first caption for now
            tokens = word_tokenize(sent.strip())
            sent_negative_sample = produce_negative_sample(
                tokens, stopwords_list, wordnet_categories,
                selected_wordnet_categories, neighbor_k)
            original_sents.append(tokens)
            if sent_negative_sample:
                modified = True
                negative_sents.append(sent_negative_sample)
            else:
                negative_sents.append(tokens)
        if modified:
            img2names[img_id] = file_name
            negative_samples[img_id] = [" ".join(sent) for sent in negative_sents]
            original_samples[img_id] = [" ".join(sent) for sent in original_sents]

    # Rank modified images by the mean perplexity of their negative captions.
    img_ids = list(img2names.keys())
    ppl_scores = []
    for img_id in img_ids:
        captions = negative_samples[img_id]
        captions_ppl_scores = [sent_perplexity(sent) for sent in captions if sent]
        ppl_scores.append(mean(captions_ppl_scores))
    best_indices = np.argsort(np.array(ppl_scores))
    best_ids = [img_ids[index] for index in best_indices]

    ranked_samples = {'samples': [{'imgId': img_id,
                                   'imgName': img2names[img_id],
                                   'pos_captions': original_samples[img_id],
                                   'neg_captions': negative_samples[img_id]}
                                  for img_id in best_ids]}
    with open(output_json_file, 'w') as fp:
        json.dump(ranked_samples, fp)

    with open(output_neg_text_file, "w") as f_out:
        for sample in ranked_samples['samples']:
            f_out.write("\n\n" + str(sample['imgId']) + ": \n")
            f_out.write("\n".join(sample['neg_captions']))
    with open(output_pos_text_file, "w") as f_out:
        for sample in ranked_samples['samples']:
            f_out.write("\n\n" + str(sample['imgId']) + ": \n")
            f_out.write("\n".join(sample['pos_captions']))
    print("Number of images with captions modified: ", len(img_ids))
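# Expected input schema, inferred from the field accesses in negative_sampling
# above (an assumption about the data file, not verified against it):
#   {"samples": [{"imgId": ..., "imgName": ..., "captions": ["...", ...]}, ...]}
# The output JSON has the same shape, with "captions" replaced by "pos_captions"
# and "neg_captions", and samples sorted by ascending mean negative-caption
# perplexity.
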
if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--neighbor_k", help="substitute with the semantically k-th nearest neighbor",
                           type=int, default=1)
    argparser.add_argument("--selected_wordnet_categories", help="WordNet supersense categories",
                           nargs="+", required=True)
    argparser.add_argument("--input_json_file", help="input JSON of images and positive captions",
                           type=str, default="val2017_positive.json")
    argparser.add_argument("--wordnet_categories_file", help="JSON mapping WordNet supersenses to word lists",
                           type=str, default="test_wordnet_categories.json")
    argparser.add_argument("--output_json_file", help="ranked positive/negative captions (JSON)",
                           type=str, default="outputs/img_captions.json")
    argparser.add_argument("--output_pos_text_file", help="positive captions (plain text)",
                           type=str, default="outputs/captions_pos.txt")
    argparser.add_argument("--output_neg_text_file", help="negative captions (plain text)",
                           type=str, default="outputs/captions_neg.txt")
    args = argparser.parse_args()

    # Module-level globals used by the helper functions above.
    embedding_model = torchtext.vocab.FastText('simple')
    cos = nn.CosineSimilarity(dim=1, eps=1e-6)
    ps = PorterStemmer()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    ppl_model = GPT2LMHeadModel.from_pretrained('gpt2').to(device)
    ppl_tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')

    wordnet_categories = read_categories(args.wordnet_categories_file)
    negative_sampling(args.input_json_file, args.output_json_file,
                      args.output_pos_text_file, args.output_neg_text_file,
                      wordnet_categories, args.selected_wordnet_categories,
                      args.neighbor_k)
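# Example invocation (a sketch: the script name is a placeholder, and
# noun.person / noun.animal are examples of WordNet supersense names,
# noun.person being the one used in earlier experiments with this script):
#   python negative_sampling.py --neighbor_k 1 \
#       --selected_wordnet_categories noun.person noun.animal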