import torch import clip import argparse import json import os import numpy as np from PIL import Image from torch.utils.data import Dataset, DataLoader import torch.nn as nn from torch.nn import functional as F from torch.autograd import Function from sklearn.model_selection import train_test_split class COCODataset(Dataset): def __init__(self, data): self.n_samples = len(data) columns = list(zip(*data)) self.label = columns[0] self.imdId = columns[1] self.imgName = columns[2] self.caption = columns[3] # support indexing such that dataset[i] can be used to get i-th sample def __getitem__(self, index): return self.imgName[index], self.caption[index], self.label[index] # we can call len(dataset) to return the size def __len__(self): return self.n_samples class GradientReversal(torch.nn.Module): def __init__(self, lambda_=1): super(GradientReversal, self).__init__() self.lambda_ = lambda_ def forward(self, x): return GradientReversalFunction.apply(x, self.lambda_) class GradientReversalFunction(Function): """ Gradient Reversal Layer from: Unsupervised Domain Adaptation by Backpropagation (Ganin & Lempitsky, 2015) Forward pass is the identity function. In the backward pass, the upstream gradients are multiplied by -lambda (i.e. gradient is reversed) """ @staticmethod def forward(ctx, x, lambda_): ctx.lambda_ = lambda_ return x.clone() @staticmethod def backward(ctx, grads): lambda_ = ctx.lambda_ lambda_ = grads.new_tensor(lambda_) dx = -lambda_ * grads return dx, None def read_dataset(positive_json, negative_json): dataset = [] for label, input_file in [(1, positive_json)]: # TODO: comment #for label, input_file in [(1.0, positive_json), (0.0, negative_json)]: with open(positive_json) as f_in: data = json.load(f_in) for p in data['samples']: #sample = [label, p['imgId'], p['imgName'], " ".join(p['captions'])] sample = [label, p['imgId'], p['imgName'], p['captions'][0]] # TODO: deal with multiple captions dataset.append(sample) return dataset def compute_recall(sim, topks): results = [] for topk in topks: topk_v = torch.topk(sim, topk, dim=1, largest=True)[0][:,-1:] mask = (sim >= topk_v) * torch.eye(sim.shape[0]) results.append(1.0 * mask.sum() / sim.shape[0]) return results def evaluate(model, data): # TODO data_loader = DataLoader(data, batch_size=32) image_output_list = [] text_output_list = [] total_num = len(data.label) with torch.no_grad(): count = 0 for images, captions, labels in data_loader: images= torch.cat([preprocess(Image.open(img)).unsqueeze(0) for img in images]).to(device) texts = torch.cat([clip.tokenize(c) for c in captions]).to(device) image_feat, text_feat = model.encode_image(images), model.encode_text(texts) image_output_list.append(image_feat) text_output_list.append(text_feat) count += len(images) #print("{}/{} tokens have been processed.".format(count, total_num)) pairwise_similarity = torch.zeros(total_num, total_num) icount = 0 for img_output in image_output_list: jcount = 0 for txt_output in text_output_list: pairwise_similarity[icount:icount+len(img_output), jcount:jcount+len(txt_output)] = img_output @ txt_output.transpose(1, 0) jcount += len(txt_output) icount += len(img_output) #print("{}/{} pairs of similarity has been calculated.".format(icount * total_num, total_num**2)) # compute image: R@1, R@5, R@10, text: R@1, R@5, R@10 topks = [1, 5, 10] recall_image2text = compute_recall(pairwise_similarity, topks) recall_text2image = compute_recall(pairwise_similarity.transpose(1, 0), topks) print("Text Retrieval:") for topk, rec in zip(topks, recall_image2text): print("R@{} = {:0.4f}".format(topk, rec)) print("Image Retrieval :") for topk, rec in zip(topks, recall_text2image): print("R@{} = {:0.4f}".format(topk, rec)) def random_negative_sample_NCE_loss(x0, x1, K=32): N, C = x0.shape # 0/1 mask candidate_mask = 1 - torch.eye(N).to(x0.device) # convert to ids from 0/1 mask candidate_ids = torch.where(candidate_mask)[1].reshape(N, N-1) # randomly select K out of each column K = 3 subsampled_ids = torch.tensor(np.random.randint(N-1, size=(N, K))).to(candidate_ids.device) batch_ids = torch.arange(N)[None,...].expand(K, -1).reshape(-1) neg_ids = candidate_ids[batch_ids, subsampled_ids.reshape(-1)].reshape(N, K) # x1 : N x C -> neg_x1 : N x K x C neg_x1 = x1[neg_ids].reshape(N, K, -1) # all_x1 : N x (K+1) x C all_x1 = torch.cat([x1.unsqueeze(1), neg_x1], 1) # inner product between text and image : N x (K+1) product = (all_x1 @ x0[...,None])[..., 0] correct_pred = ((product[:, :1] > product[:, 1:]).sum(1) == K).sum() acc = correct_pred / N # softmax softmax = nn.Softmax(dim=1) product = softmax(product) labels = torch.zeros(N).to(product.device).long() nce_loss = nn.CrossEntropyLoss() return nce_loss(product, labels), acc def step(epoch, ex, model, discriminator, optimizer, criterion, adv_criterion, images, captions, labels): """ :param epoch: The current epoch :param ex: Which example / minibatch :param model: The model optimizing :param inputs: The current set of inputs :param labels: The labels for those inputs """ batch_size = len(images) images= torch.cat([preprocess(Image.open(img)).unsqueeze(0) for img in images]).to(device) texts = torch.cat([clip.tokenize(c) for c in captions]).to(device) # Use the softmax cross entropy with the correct pairing (i.e. the identity matrix) as the target. image_feat, text_feat = model.encode_image(images), model.encode_text(texts) # N, C image_feat, text_feat = F.normalize(image_feat), F.normalize(text_feat) loss_it, acc_it = random_negative_sample_NCE_loss(image_feat, text_feat) loss_ti, acc_ti = random_negative_sample_NCE_loss(text_feat, image_feat) label_loss = (loss_it + loss_ti) / 2 mean_acc = (acc_it + acc_ti) / 2 # Adversarial training # Prepare domain labels and features image_labels = torch.ones(batch_size).to(device) text_labels = torch.zeros(batch_size).to(device) image_features = model.encode_image(images).to(device).float() text_features = model.encode_text(texts).to(device).float() # Compute preds and loss and accuracy image_preds = discriminator(image_features).squeeze() text_preds = discriminator(text_features).squeeze() adv_image_loss = F.binary_cross_entropy_with_logits(image_preds, image_labels) adv_text_loss = F.binary_cross_entropy_with_logits(text_preds, text_labels) adv_loss = (adv_image_loss + adv_text_loss) / 2 image_predicted_cls = image_preds.round() img_acc = image_predicted_cls.eq(image_labels).sum() / float(batch_size) text_predicted_cls = text_preds.round() txt_acc = text_predicted_cls.eq(text_labels).sum() / float(batch_size) adv_acc = (img_acc + txt_acc) / 2.0 # Backward pass and update, zero grad total_loss = label_loss + adv_loss total_loss.backward() optimizer.step() optimizer.zero_grad() if ex % 10 == 0: #acc_train = evaluate(model, train) #acc_test = evaluate(model, test) #print(f'Epoch: {epoch+1}/{num_epochs}, Example {ex}, loss = {total_loss.item():.4f}, train_acc = {acc_train.item():.4f} test_acc = {acc_test.item():.4f}') #print(f'Epoch: {epoch+1}/{n_epochs}, Example {ex}, label_loss = {total_loss.item():.4f}, label_acc = {mean_acc.item():.4f}, adv_loss = {adv_loss.item():.4f}, img_acc = {img_acc.item():.4f}, txt_acc = {txt_acc.item():.4f}') print(f'Epoch: {epoch+1}/{n_epochs}, Example {ex}, label_loss = {total_loss.item():.4f}, label_acc = {mean_acc.item():.4f}, adv_loss = {adv_loss.item():.4f}') if __name__ == "__main__": argparser = argparse.ArgumentParser() argparser.add_argument("--n_epochs", help="Number of passes through train", type=int, default=20) argparser.add_argument("--learning_rate", help="Learning rate for SGD", type=float, default=0.001) argparser.add_argument("--lambda_adv", help="Lambda for gradient reversal layer", type=float, default=0.1) argparser.add_argument("--batch", help="Number of items in each batch", type=int, default=64) argparser.add_argument("--save_path", help="path to save the weights", type=str, default="checkpoints") argparser.add_argument("--positive_json_file", help="positive samples file", type=str, default="val2017_positive.json") argparser.add_argument("--negative_json_file", help="negative samples file", type=str, default="val2017_positive.json") # TODO: prepare negative file argparser.add_argument("--test_only", action='store_true', help='set to test only.') argparser.add_argument("--test_model", help='test model checkpoint suffix.', default="lr0.001-lbd0.1") argparser.add_argument("--use_original_weight", action='store_true', help='set to test only.') args = argparser.parse_args() # Load dataset dataset = read_dataset(args.positive_json_file, args.negative_json_file) train_init, test_init = train_test_split(dataset, test_size=0.15, random_state=1234) train, test = COCODataset(train_init), COCODataset(test_init) # Load model device = "cuda" if torch.cuda.is_available() else "cpu" model, preprocess = clip.load("ViT-B/32", device=device) discriminator = nn.Sequential( GradientReversal(lambda_=args.lambda_adv), nn.Linear(512, 50), nn.ReLU(), nn.Linear(50, 20), nn.ReLU(), nn.Linear(20, 1) ).to(device) if not args.test_only: # Training config batch = args.batch criterion = nn.CrossEntropyLoss() optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate) adv_criterion = nn.BCELoss() # Fine-tuning train_loader = DataLoader(dataset=train, batch_size=batch, shuffle=True, num_workers=0) dataiter = iter(train_loader) n_epochs = args.n_epochs if not os.path.exists(args.save_path): os.makedirs(args.save_path) for epoch in range(n_epochs): for ex, (images, captions, labels) in enumerate(train_loader): # Run training process step(epoch, ex, model, discriminator, optimizer, criterion, adv_criterion, images, captions, labels) if epoch % 5 == 4: torch.save(model.state_dict(), os.path.join(args.save_path, "checkpoint{:04d}-lr{}-lbd{}.pth".format(epoch, args.learning_rate, args.lambda_adv))) torch.save(model.state_dict(), os.path.join(args.save_path, "checkpoint-lr{}-lbd{}.pth".format(epoch, args.learning_rate, args.lambda_adv))) else: if not args.use_original_weight: model.load_state_dict(torch.load(os.path.join(args.save_path, "checkpoint-{}.pth".format(args.test_model)))) if not args.use_original_weight: print("CLIP with finetune:") else: print("CLIP without finetune:") evaluate(model, test)