import argparse
import json

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from sklearn.model_selection import train_test_split

import clip


class COCODataset(Dataset):
    def __init__(self, data):
        self.n_samples = len(data)
        columns = list(zip(*data))
        self.label = columns[0]
        self.imgId = columns[1]
        self.imgName = columns[2]
        self.caption = columns[3]

    # support indexing so that dataset[i] returns the i-th sample
    def __getitem__(self, index):
        return self.imgName[index], self.caption[index], self.label[index]

    # len(dataset) returns the number of samples
    def __len__(self):
        return self.n_samples


def read_dataset(positive_json, negative_json):
    dataset = []
    for label, input_file in [(1, positive_json)]:  # TODO: also read the negative file
        # for label, input_file in [(1.0, positive_json), (0.0, negative_json)]:
        with open(input_file) as f_in:
            data = json.load(f_in)
            for p in data['samples']:
                # sample = [label, p['imgId'], p['imgName'], " ".join(p['captions'])]
                sample = [label, p['imgId'], p['imgName'], p['captions'][0]]  # TODO: deal with multiple captions
                dataset.append(sample)
    return dataset


def evaluate(model, data):
    # TODO: unfinished -- `data.feature` is not defined on COCODataset, so the
    # calls to this function in step() stay commented out until it is wired up.
    with torch.no_grad():
        y_predicted = model(data.feature)
        y_predicted_cls = y_predicted.round()
        acc = y_predicted_cls.eq(data.label).sum() / float(data.label.shape[0])
        return acc


def random_negative_sample_NCE_loss(x0, x1, K=32):
    """NCE loss over random in-batch negatives: each row of x0 is contrasted
    against its aligned row of x1 (the positive) plus K rows of x1 drawn
    uniformly, with replacement, from the other N - 1 rows of the batch."""
    N, C = x0.shape
    # boolean mask that excludes the diagonal (the positive pairing)
    candidate_mask = ~torch.eye(N, dtype=torch.bool, device=x0.device)
    # ids of the N - 1 candidate negatives for each row
    candidate_ids = torch.where(candidate_mask)[1].reshape(N, N - 1)
    # randomly select K candidates per row (a hard-coded K = 3 previously
    # shadowed and silently ignored the K argument)
    subsampled_ids = torch.randint(N - 1, size=(N, K), device=x0.device)
    # repeat each row index K times so it lines up with the row-major
    # flattening of subsampled_ids (expand(K, -1) produced a mismatched order)
    batch_ids = torch.arange(N, device=x0.device).repeat_interleave(K)
    neg_ids = candidate_ids[batch_ids, subsampled_ids.reshape(-1)].reshape(N, K)
    # x1 : N x C -> neg_x1 : N x K x C
    neg_x1 = x1[neg_ids]
    # all_x1 : N x (K+1) x C, with the positive in column 0
    all_x1 = torch.cat([x1.unsqueeze(1), neg_x1], 1)
    # inner product between each x0 row and its K+1 candidates : N x (K+1)
    product = (all_x1 @ x0[..., None])[..., 0]
    # accuracy: fraction of rows where the positive outscores every negative
    correct_pred = ((product[:, :1] > product[:, 1:]).sum(1) == K).sum()
    acc = correct_pred / N
    # CrossEntropyLoss applies log-softmax internally, so pass the raw logits
    # (an extra nn.Softmax was previously applied first, which skews the loss)
    labels = torch.zeros(N, dtype=torch.long, device=product.device)
    nce_loss = nn.CrossEntropyLoss()
    return nce_loss(product, labels), acc
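# Quick, self-contained sanity check for the sampler above (illustrative only;
# the sizes and K are arbitrary). With random unit vectors the positive is just
# one of K + 1 interchangeable candidates, so acc should hover near 1 / (K + 1):
#
#   x0 = F.normalize(torch.randn(8, 512), dim=-1)
#   x1 = F.normalize(torch.randn(8, 512), dim=-1)
#   loss, acc = random_negative_sample_NCE_loss(x0, x1, K=4)
#   print(loss.item(), acc.item())  # scalar NCE loss, top-1 accuracy in [0, 1]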
def step(epoch, ex, model, optimizer, criterion, images, captions, labels):
    """
    :param epoch: The current epoch
    :param ex: Which example / minibatch
    :param model: The model being optimized
    :param optimizer: The optimizer updating the model
    :param criterion: Loss criterion (currently unused; the NCE helper builds its own)
    :param images: Paths of the images in the minibatch
    :param captions: The captions paired with those images
    :param labels: The labels for those inputs (currently all positive)
    """
    images = torch.cat([preprocess(Image.open(img)).unsqueeze(0) for img in images]).to(device)
    texts = torch.cat([clip.tokenize(c) for c in captions]).to(device)

    # Contrastive objective: softmax cross entropy with the correct pairing as
    # the target, using randomly sampled in-batch negatives. Encode each
    # modality separately: the NCE helper expects per-sample features (N x C),
    # whereas model(images, texts) would return the N x N similarity logits.
    image_features = F.normalize(model.encode_image(images), dim=-1)
    text_features = F.normalize(model.encode_text(texts), dim=-1)
    loss_it, acc_it = random_negative_sample_NCE_loss(image_features, text_features)
    loss_ti, acc_ti = random_negative_sample_NCE_loss(text_features, image_features)
    total_loss = (loss_it + loss_ti) / 2
    mean_acc = (acc_it + acc_ti) / 2

    # Backward pass and update, then zero the gradients
    total_loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if ex % 10 == 0:
        # acc_train = evaluate(model, train)
        # acc_test = evaluate(model, test)
        # print(f'Epoch: {epoch+1}/{n_epochs}, Example {ex}, loss = {total_loss.item():.4f}, train_acc = {acc_train.item():.4f} test_acc = {acc_test.item():.4f}')
        print(f'Epoch: {epoch+1}/{n_epochs}, Example {ex}, loss = {total_loss.item():.4f}, acc = {mean_acc.item():.4f}')


if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--n_epochs", help="Number of passes through train", type=int, default=1)
    argparser.add_argument("--learning_rate", help="Learning rate for SGD", type=float, default=0.01)
    argparser.add_argument("--batch", help="Number of items in each batch", type=int, default=64)
    argparser.add_argument("--positive_json_file", help="positive samples file", type=str, default="val2017_positive.json")
    argparser.add_argument("--negative_json_file", help="negative samples file", type=str, default="val2017_positive.json")  # TODO: prepare a negative file
    args = argparser.parse_args()

    # Load dataset
    dataset = read_dataset(args.positive_json_file, args.negative_json_file)
    train_init, test_init = train_test_split(dataset, test_size=0.15, random_state=1234)
    train, test = COCODataset(train_init), COCODataset(test_init)

    # Load model. NOTE: the device is deliberately forced to CPU below; on CUDA,
    # clip.load returns fp16 weights, which plain SGD fine-tuning can destabilize.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    device = "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)

    # Training config
    batch = args.batch
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)

    # Fine-tuning
    train_loader = DataLoader(dataset=train, batch_size=batch, shuffle=True, num_workers=0)
    n_epochs = args.n_epochs
    for epoch in range(n_epochs):
        for ex, (images, captions, labels) in enumerate(train_loader):
            # Run one training step on the minibatch
            step(epoch, ex, model, optimizer, criterion, images, captions, labels)
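# Example invocation (illustrative; the script filename is hypothetical, and
# val2017_positive.json is assumed to be built from COCO val2017 captions):
#
#   python finetune_clip.py --n_epochs 2 --batch 32 --learning_rate 0.001 \
#       --positive_json_file val2017_positive.json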