import argparse
import json

import torch
import torch.nn as nn
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.autograd import Function
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

import clip


class COCODataset(Dataset):
    def __init__(self, data):
        self.n_samples = len(data)
        columns = list(zip(*data))
        self.label = columns[0]
        self.imgId = columns[1]
        self.imgName = columns[2]
        self.caption = columns[3]

    # support indexing such that dataset[i] can be used to get the i-th sample
    def __getitem__(self, index):
        return self.imgName[index], self.caption[index], self.label[index]

    # we can call len(dataset) to return the size
    def __len__(self):
        return self.n_samples


class GradientReversal(nn.Module):
    def __init__(self, lambda_=1):
        super(GradientReversal, self).__init__()
        self.lambda_ = lambda_

    def forward(self, x):
        return GradientReversalFunction.apply(x, self.lambda_)


class GradientReversalFunction(Function):
    """
    Gradient Reversal Layer from:
    Unsupervised Domain Adaptation by Backpropagation (Ganin & Lempitsky, 2015)

    Forward pass is the identity function. In the backward pass, the upstream
    gradients are multiplied by -lambda (i.e. the gradient is reversed).
    """

    @staticmethod
    def forward(ctx, x, lambda_):
        ctx.lambda_ = lambda_
        return x.clone()

    @staticmethod
    def backward(ctx, grads):
        lambda_ = grads.new_tensor(ctx.lambda_)
        dx = -lambda_ * grads
        return dx, None


def read_dataset(positive_json, negative_json):
    dataset = []
    for label, input_file in [(1, positive_json)]:  # TODO: add negative samples
    # for label, input_file in [(1.0, positive_json), (0.0, negative_json)]:
        with open(input_file) as f_in:
            data = json.load(f_in)
        for p in data['samples']:
            # sample = [label, p['imgId'], p['imgName'], " ".join(p['captions'])]
            sample = [label, p['imgId'], p['imgName'], p['captions'][0]]  # TODO: deal with multiple captions
            dataset.append(sample)
    return dataset


def evaluate(model, data):
    # TODO: the dataset yields (imgName, caption, label), so image/text features
    # must be encoded here first; `data.feature` below is a placeholder for that
    # step. See evaluate_retrieval() below for one possible implementation.
    with torch.no_grad():
        y_predicted = model(data.feature)  # placeholder: COCODataset has no .feature yet
        y_predicted_cls = y_predicted.round()
        labels = torch.tensor(data.label).to(device)
        acc = y_predicted_cls.eq(labels).sum() / float(labels.shape[0])
        return acc


def random_negative_sample_NCE_loss(x0, x1, K=32):
    N, C = x0.shape
    # 0/1 mask that rules out the positive (diagonal) pair
    candidate_mask = 1 - torch.eye(N).to(device)
    # convert the 0/1 mask to ids: row i holds the N-1 candidate negatives for sample i
    candidate_ids = torch.where(candidate_mask.bool())[1].reshape(N, N - 1)
    # randomly select K negatives per row (honor the K argument; a hardcoded
    # K = 3 previously overrode it)
    subsampled_ids = torch.randint(N - 1, (N, K), device=device)
    # repeat each row index K times so it lines up with the row-major flattening
    # of subsampled_ids (the previous expand(K, -1) paired the wrong rows)
    batch_ids = torch.arange(N)[:, None].expand(-1, K).reshape(-1)
    neg_ids = candidate_ids[batch_ids, subsampled_ids.reshape(-1)].reshape(N, K)
    # x1 : N x C -> neg_x1 : N x K x C
    neg_x1 = x1[neg_ids].reshape(N, K, -1)
    # all_x1 : N x (K+1) x C, with the positive in slot 0
    all_x1 = torch.cat([x1.unsqueeze(1), neg_x1], 1)
    # inner product between text and image : N x (K+1)
    product = (all_x1 @ x0[..., None])[..., 0]
    correct_pred = ((product[:, :1] > product[:, 1:]).sum(1) == K).sum()
    acc = correct_pred.float() / N
    # CrossEntropyLoss applies log-softmax internally, so pass the raw scores;
    # applying an explicit softmax first would squash the logits twice.
    # The positive pair sits at index 0, hence the all-zero labels.
    labels = torch.zeros(N).to(device).long()
    nce_loss = nn.CrossEntropyLoss()
    return nce_loss(product, labels), acc
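# A minimal sketch of how evaluate() above could be filled in, assuming the
# goal is top-1 image<->text retrieval accuracy over a held-out COCODataset.
# The function name, the batched encoding loop, and the batch_size default are
# illustrative assumptions, not part of the original training script.
def evaluate_retrieval(model, dataset, preprocess, device, batch_size=64):
    """Top-1 retrieval accuracy: does image i score highest with caption i?"""
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    correct, total = 0, 0
    with torch.no_grad():
        for img_names, captions, _ in loader:
            images = torch.cat(
                [preprocess(Image.open(img)).unsqueeze(0) for img in img_names]
            ).to(device)
            texts = clip.tokenize(list(captions)).to(device)
            image_features = F.normalize(model.encode_image(images).float())
            text_features = F.normalize(model.encode_text(texts).float())
            # N x N cosine-similarity matrix; the diagonal holds the true pairs
            sims = image_features @ text_features.t()
            preds = sims.argmax(dim=1)
            correct += (preds == torch.arange(len(img_names), device=device)).sum().item()
            total += len(img_names)
    return correct / total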
def step(epoch, ex, model, discriminator, optimizer, criterion, adv_criterion,
         images, captions, labels):
    """
    :param epoch: The current epoch
    :param ex: Which example / minibatch
    :param model: The CLIP model being fine-tuned
    :param discriminator: The domain discriminator (behind a gradient reversal layer)
    :param images: Paths of the images in the current minibatch
    :param captions: Captions paired with those images
    :param labels: The labels for those inputs
    """
    batch_size = len(images)
    images = torch.cat(
        [preprocess(Image.open(img)).unsqueeze(0) for img in images]
    ).to(device)
    texts = torch.cat([clip.tokenize(c) for c in captions]).to(device)

    # Use the softmax cross entropy with the correct pairing
    # (i.e. the identity matrix) as the target.
    logits_per_image, logits_per_text = model(images, texts)  # N x N
    logits_per_image = F.normalize(logits_per_image).float()
    logits_per_text = F.normalize(logits_per_text).float()
    loss_it, acc_it = random_negative_sample_NCE_loss(logits_per_image, logits_per_text)
    loss_ti, acc_ti = random_negative_sample_NCE_loss(logits_per_text, logits_per_image)
    label_loss = ((loss_it + loss_ti) / 2.0).float()
    mean_acc = ((acc_it + acc_ti) / 2.0).float()

    # Adversarial training
    # Prepare domain labels (image = 1, text = 0) and features
    image_labels = torch.ones(batch_size).to(device)
    text_labels = torch.zeros(batch_size).to(device)
    image_features = model.encode_image(images).float().to(device)
    text_features = model.encode_text(texts).float().to(device)

    # Compute predictions, loss, and accuracy
    image_preds = discriminator(image_features).squeeze()
    text_preds = discriminator(text_features).squeeze()
    adv_image_loss = F.binary_cross_entropy_with_logits(image_preds, image_labels)
    adv_text_loss = F.binary_cross_entropy_with_logits(text_preds, text_labels)
    adv_loss = adv_image_loss + adv_text_loss
    # The discriminator outputs raw logits, so squash through a sigmoid before
    # rounding (rounding the logits directly would misclassify most examples)
    image_predicted_cls = torch.sigmoid(image_preds).round()
    img_acc = image_predicted_cls.eq(image_labels).sum() / float(batch_size)
    text_predicted_cls = torch.sigmoid(text_preds).round()
    txt_acc = text_predicted_cls.eq(text_labels).sum() / float(batch_size)
    adv_acc = (img_acc + txt_acc) / 2.0

    # Backward pass and update, then clear gradients
    total_loss = (label_loss + adv_loss).float()
    # adv_loss.backward()
    total_loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if ex % 10 == 0:
        # acc_train = evaluate(model, train)
        # acc_test = evaluate(model, test)
        print(f'Epoch: {epoch+1}/{n_epochs}, Example {ex}, '
              f'total_loss = {total_loss.item():.4f}, acc = {mean_acc.item():.4f}, '
              f'adv_loss = {adv_loss.item():.4f}, img_acc = {img_acc.item():.4f}, '
              f'txt_acc = {txt_acc.item():.4f}')
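# Quick sanity check for the gradient reversal layer (illustrative only; not
# called by the training loop): the forward pass is the identity, while the
# gradient reaching x is the upstream gradient scaled by -lambda.
def _grl_sanity_check():
    grl = GradientReversal(lambda_=0.5)
    x = torch.ones(3, requires_grad=True)
    grl(x).sum().backward()  # upstream gradient is all ones
    assert torch.allclose(x.grad, torch.full((3,), -0.5))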
if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--n_epochs", help="Number of passes through train", type=int, default=1)
    argparser.add_argument("--learning_rate", help="Learning rate for SGD", type=float, default=0.001)
    argparser.add_argument("--batch", help="Number of items in each batch", type=int, default=64)
    argparser.add_argument("--positive_json_file", help="positive samples file", type=str, default="val2017_positive.json")
    argparser.add_argument("--negative_json_file", help="negative samples file", type=str, default="val2017_positive.json")  # TODO: prepare negative file
    args = argparser.parse_args()

    # Load dataset
    dataset = read_dataset(args.positive_json_file, args.negative_json_file)
    train_init, test_init = train_test_split(dataset, test_size=0.15, random_state=1234)
    train, test = COCODataset(train_init), COCODataset(test_init)

    # Load model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)
    # Note: on CUDA, clip.load returns fp16 weights; casting via model.float()
    # is commonly needed for stable fine-tuning with plain SGD.

    # 512 matches the ViT-B/32 embedding dimension
    discriminator = nn.Sequential(
        GradientReversal(),
        nn.Linear(512, 50),
        nn.ReLU(),
        nn.Linear(50, 20),
        nn.ReLU(),
        nn.Linear(20, 1),
    ).to(device)

    # Training config
    batch = args.batch
    criterion = nn.CrossEntropyLoss()
    # Optimize both the CLIP encoders and the discriminator: the gradient
    # reversal layer flips the domain gradient for the encoders, while the
    # discriminator itself must still be trained normally.
    optimizer = torch.optim.SGD(
        list(model.parameters()) + list(discriminator.parameters()),
        lr=args.learning_rate,
    )
    adv_criterion = nn.BCELoss()

    # Fine-tuning
    train_loader = DataLoader(dataset=train, batch_size=batch, shuffle=True, num_workers=0)
    n_epochs = args.n_epochs
    for epoch in range(n_epochs):
        for ex, (images, captions, labels) in enumerate(train_loader):
            # Run training step on this minibatch
            step(epoch, ex, model, discriminator, optimizer, criterion, adv_criterion,
                 images, captions, labels)
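# Example invocation (the script filename is an assumption):
#   python finetune_clip_adversarial.py --n_epochs 3 --batch 64 \
#       --positive_json_file val2017_positive.json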