import json
import pickle
import argparse
from collections import defaultdict
from statistics import mean

import numpy as np
import torch
import torch.nn as nn
import torchtext
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer
from transformers import GPT2LMHeadModel, GPT2TokenizerFast


def read_categories(json_file):
    """Load the WordNet supersense category -> word list mapping."""
    with open(json_file) as f:
        categories = json.load(f)
    return categories


def sample_neighbor(word, word_list, neighbor_k):
    """Return the k-th nearest neighbor of `word` among `word_list`,
    ranked by cosine similarity of FastText embeddings."""
    center = torch.stack([embedding_model[word]])
    neighbor_array = torch.stack([embedding_model[neighbor] for neighbor in word_list])
    similarity = cos(center, neighbor_array)
    similarity_array = similarity.numpy()
    # Indices sorted by descending similarity; index 0 is the nearest neighbor.
    sorted_indices = np.argsort(-similarity_array)
    sampled_neighbor = sorted_indices[min(neighbor_k - 1, len(word_list) - 1)]
    return word_list[sampled_neighbor]


def produce_negative_sample(tokens, stopwords_list, wordnet_categories,
                            selected_wordnet_categories, neighbor_k):
    """Swap one content word for a semantic neighbor from the same WordNet
    supersense; among all candidates, return the sentence with the lowest
    perplexity (i.e. the most fluent negative)."""
    sent_negative_samples = []
    for i in range(len(tokens)):
        if tokens[i] in stopwords_list:
            continue
        for lexname in selected_wordnet_categories:
            word_list = wordnet_categories[lexname]
            word_stem = ps.stem(tokens[i])
            if word_stem in word_list:
                sample_token = sample_neighbor(
                    word_stem, list(set(word_list) - {word_stem}), neighbor_k)
                sent_negative_samples.append(tokens[:i] + [sample_token] + tokens[i + 1:])
    if sent_negative_samples:
        ppl_scores = [sent_perplexity(" ".join(sent)) for sent in sent_negative_samples]
        index = np.argmin(np.array(ppl_scores))
        return sent_negative_samples[index]
    # No substitutable word found: fall back to the original tokens.
    return tokens


def sent_perplexity(sent):
    """Compute GPT-2 perplexity of a sentence with a sliding window."""
    encodings = ppl_tokenizer(sent, return_tensors='pt')
    max_length = ppl_model.config.n_positions
    stride = 512
    lls = []
    for i in range(0, encodings.input_ids.size(1), stride):
        begin_loc = max(i + stride - max_length, 0)
        end_loc = min(i + stride, encodings.input_ids.size(1))
        trg_len = end_loc - i  # may be different from stride on last loop
        input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
        target_ids = input_ids.clone()
        # Mask out the overlapping context so each token is scored exactly once.
        target_ids[:, :-trg_len] = -100
        with torch.no_grad():
            outputs = ppl_model(input_ids, labels=target_ids)
            log_likelihood = outputs[0] * trg_len  # outputs[0] is the mean CE loss
        lls.append(log_likelihood)
    ppl = torch.exp(torch.stack(lls).sum() / end_loc)
    return ppl.item()


def negative_sampling(input_json_file, output_pos_file, output_neg_file,
                      output_pos_text_file, output_neg_text_file,
                      output_video_list_file, wordnet_categories,
                      selected_wordnet_categories, neighbor_k):
    video_categories = [1, 2, 3, 4]  # MSR-VTT category ids to keep
    num_samples = 1000               # number of videos to select
    negative_samples, original_samples = defaultdict(list), defaultdict(list)
    stopwords_list = stopwords.words('english')
    video_list = []

    with open(input_json_file) as f:
        data = json.load(f)

    for video in data['videos']:
        if video['category'] in video_categories:
            video_list.append(video['video_id'])

    for sent in data['sentences']:
        vid = sent['video_id']
        if vid in video_list:
            caption = sent['caption']
            tokens = caption.strip().split()  # TODO: text processing
            sent_negative_sample = produce_negative_sample(
                tokens, stopwords_list, wordnet_categories,
                selected_wordnet_categories, neighbor_k)
            negative_samples[vid].append(sent_negative_sample)
            original_samples[vid].append(tokens)

    # Rank videos by the mean perplexity of their negative captions and keep
    # the `num_samples` most fluent ones.
    ppl_scores = []
    for video_id in video_list:
        captions = negative_samples[video_id]
        captions_ppl_scores = [sent_perplexity(" ".join(sent)) for sent in captions if sent]
        # Videos without scorable captions are pushed to the end of the ranking.
        ppl_scores.append(mean(captions_ppl_scores) if captions_ppl_scores else float('inf'))
    best_indices = np.argsort(np.array(ppl_scores))[:num_samples]
    best_vids = [video_list[index] for index in best_indices]

    selected_negative_samples = {vid: negative_samples[vid] for vid in best_vids}
    selected_positive_samples = {vid: original_samples[vid] for vid in best_vids}

    with open(output_neg_file, 'wb') as fp:
        pickle.dump(selected_negative_samples, fp)
    with open(output_pos_file, 'wb') as fp:
        pickle.dump(selected_positive_samples, fp)

    with open(output_neg_text_file, "w") as f_out:
        for video, captions in selected_negative_samples.items():
            captions_text = [" ".join(caption) for caption in captions]
            f_out.write("\n\n" + video + ": \n")
            f_out.write("\n".join(captions_text))
    with open(output_pos_text_file, "w") as f_out:
        for video, captions in selected_positive_samples.items():
            captions_text = [" ".join(caption) for caption in captions]
            f_out.write("\n\n" + video + ": \n")
            f_out.write("\n".join(captions_text))

    with open(output_video_list_file, "w") as out_f:
        out_f.write("\n".join(selected_negative_samples.keys()))


if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--neighbor_k", help="semantic nearest k neighbor",
                           type=int, required=True)
    argparser.add_argument("--selected_wordnet_categories",
                           help="WordNet supersense category", nargs="+", required=True)
    argparser.add_argument("--input_json_file", help="input_json_file", type=str,
                           default="/vulcanscratch/lzhao/data/MSRVTT/test_videodatainfo.json")
    argparser.add_argument("--wordnet_categories_file", help="wordnet_categories_file", type=str,
                           default="/vulcanscratch/lzhao/exp/collaborative_experts_text_gen/test_wordnet_categories.json")
    argparser.add_argument("--output_pos_file", help="output_pos_file", type=str, required=True)
    argparser.add_argument("--output_neg_file", help="output_neg_file", type=str, required=True)
    argparser.add_argument("--output_pos_text_file", help="output_pos_text_file", type=str, required=True)
    argparser.add_argument("--output_neg_text_file", help="output_neg_text_file", type=str, required=True)
    argparser.add_argument("--output_video_list_file", help="output_video_list_file", type=str, required=True)
    args = argparser.parse_args()

    # Module-level globals used by the helper functions above.
    embedding_model = torchtext.vocab.FastText('simple')
    cos = nn.CosineSimilarity(dim=1, eps=1e-6)
    ps = PorterStemmer()
    device = "cuda"
    ppl_model = GPT2LMHeadModel.from_pretrained('gpt2-large').to(device)
    ppl_tokenizer = GPT2TokenizerFast.from_pretrained('gpt2-large')

    wordnet_categories = read_categories(args.wordnet_categories_file)
    negative_sampling(args.input_json_file, args.output_pos_file, args.output_neg_file,
                      args.output_pos_text_file, args.output_neg_text_file,
                      args.output_video_list_file, wordnet_categories,
                      args.selected_wordnet_categories, args.neighbor_k)
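
# Example invocation (a sketch, not prescribed by this script: the script
# filename, output paths, and supersense names below are illustrative;
# --selected_wordnet_categories takes WordNet lexnames such as noun.animal
# or noun.food, matching the keys in the wordnet_categories_file):
#
#   python negative_sampling.py \
#       --neighbor_k 3 \
#       --selected_wordnet_categories noun.animal noun.food \
#       --output_pos_file raw_captions_pos.pkl \
#       --output_neg_file raw_captions_neg.pkl \
#       --output_pos_text_file raw_captions_pos.txt \
#       --output_neg_text_file raw_captions_neg.txt \
#       --output_video_list_file video_list.txt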