import torch
import argparse
import json

from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    pipeline,
    BitsAndBytesConfig,
)
from peft import LoraConfig, PeftModel
from trl import SFTTrainer


def finetune_model(train_data_file, val_data_file, model_folder, log_folder,
                   model_name, num_epochs, lr=2e-4):
    """Fine-tune a LLaMA-2-13B-chat model with QLoRA on a truthful/deceptive
    hotel-review classification dataset.

    Args:
        train_data_file: Path to a JSON/JSONL training file; must contain a
            "text" field (consumed via ``dataset_text_field="text"``).
        val_data_file: Path to a JSONL validation file with "review" and
            "truthful_label" fields.
        model_folder: Directory under which adapter checkpoints are saved
            (``<model_folder>/<model_name>/best`` and ``.../final``).
        log_folder: Directory for trainer output / TensorBoard logs.
        model_name: Subdirectory name for this run's checkpoints and logs.
        num_epochs: Number of outer training rounds; each round runs one
            trainer epoch followed by a validation pass.
        lr: Learning rate (constant schedule).
    """
    # Both splits come from standalone files, so each loads as its own
    # "train" split.
    train_data = load_dataset("json", data_files=train_data_file, split="train")
    val_data = load_dataset("json", data_files=val_data_file, split="train")
    print("Training data size: ", len(train_data))
    print("Validation data size: ", len(val_data))
    print("Example of training data:")
    print(train_data[0])

    # Base checkpoint: local HF-format LLaMA-2-13B-chat.
    llama_base_model_name = "/fs/clip-scratch/lzhao/repos/llama-hf/llama-2-13b-chat/"

    llama_tokenizer = AutoTokenizer.from_pretrained(llama_base_model_name)
    # LLaMA ships without a pad token; reuse EOS and right-pad for training.
    llama_tokenizer.pad_token = llama_tokenizer.eos_token
    llama_tokenizer.padding_side = "right"

    # 4-bit NF4 quantization (QLoRA) with fp16 compute.
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=False,
    )
    llama_base_model = AutoModelForCausalLM.from_pretrained(
        llama_base_model_name,
        quantization_config=quant_config,
        device_map={"": 0},  # place the whole model on GPU 0
    )
    llama_base_model.config.use_cache = False   # KV cache off during training
    llama_base_model.config.pretraining_tp = 1  # disable tensor-parallel slicing

    # LoRA adapter configuration.
    peft_config = LoraConfig(
        lora_alpha=16,
        lora_dropout=0.1,
        r=8,
        bias="none",
        task_type="CAUSAL_LM",
    )

    # One trainer "epoch" per outer-loop iteration below.
    training_params = TrainingArguments(
        output_dir="{}/{}".format(log_folder, model_name),
        num_train_epochs=1,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=1,
        optim="paged_adamw_32bit",
        save_steps=200,
        logging_steps=50,
        learning_rate=lr,
        weight_decay=0.001,
        fp16=False,
        bf16=False,
        max_grad_norm=0.3,
        max_steps=-1,
        warmup_ratio=0.03,
        group_by_length=True,
        lr_scheduler_type="constant",
        report_to="tensorboard",
    )

    # Trainer
    llama_fine_tuning = SFTTrainer(
        model=llama_base_model,
        train_dataset=train_data,
        eval_dataset=val_data,
        peft_config=peft_config,
        dataset_text_field="text",
        tokenizer=llama_tokenizer,
        args=training_params,
    )

    # Training loop: one trainer epoch + validation pass per round.
    # NOTE(review): each .train() call restarts the optimizer/LR-scheduler
    # state (model weights persist); with a constant schedule this
    # approximates multi-epoch training — confirm this is intended.
    best_val_accuracy = 0.0
    for epoch in range(num_epochs):
        llama_fine_tuning.train()

        # Trainer's own eval metrics (loss etc.).
        results = llama_fine_tuning.evaluate()
        print("Val results: ", results)

        # Generation-based validation accuracy on a sample of the val set.
        val_accuracy = validate_model(val_data_file, llama_base_model, llama_tokenizer)
        print(f"Epoch {epoch + 1}/{num_epochs} - Validation Accuracy: {val_accuracy}")

        # Keep the adapter with the best validation accuracy so far
        # (epoch 0 always saves, so "best" is never missing).
        if epoch == 0 or val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            llama_fine_tuning.model.save_pretrained("{}/{}/best".format(model_folder, model_name))

    # Always save the final adapter as well.
    llama_fine_tuning.model.save_pretrained("{}/{}/final".format(model_folder, model_name))


def validate_model(val_data_file, llama_base_model, llama_tokenizer, max_items=10):
    """Estimate classification accuracy by generating answers for up to
    ``max_items`` validation reviews.

    Args:
        val_data_file: Path to a JSONL file whose records carry "review"
            and "truthful_label" (1 = truthful, 0 = deceptive) fields.
        llama_base_model: Causal LM used for generation.
        llama_tokenizer: Tokenizer matching the model.
        max_items: Cap on examples scored (default 10, preserving the
            original hard-coded quick-check behavior).

    Returns:
        Accuracy in [0, 1]; 0.0 if no examples were scored.
    """
    # Load JSONL: one JSON object per line.
    with open(val_data_file, 'r', encoding='utf-8') as f:
        val_data = [json.loads(line) for line in f]

    # BUGFIX: build the generation pipeline ONCE — the original recreated
    # it inside the per-example loop, paying model-setup cost every item.
    text_gen = pipeline(
        task="text-generation",
        model=llama_base_model,
        tokenizer=llama_tokenizer,
        max_length=2000,
    )

    count = 0
    count_correct = 0
    for item in val_data:
        prompt = formatting_prompts_func_val(item)[0]
        print("Prompt: ", prompt)
        output = text_gen(f"[INST] {prompt} [/INST]")
        generated = output[0]['generated_text']
        print(generated)

        label = item['truthful_label']
        print("Label: ", label)

        # BUGFIX: the original indexed split(...)[1] unconditionally and
        # raised IndexError whenever the generation lacked the marker;
        # fall back to an empty (never-matching) prediction instead.
        parts = generated.split("### Answer:")
        predicted_label = parts[1].strip() if len(parts) > 1 else ""
        print("Predicted label: ", predicted_label)

        if label == 1 and predicted_label == "Truthful":
            count_correct += 1
        elif label == 0 and predicted_label == "Deceptive":
            count_correct += 1
        count += 1
        if count == max_items:
            break

    # BUGFIX: avoid ZeroDivisionError on an empty validation file.
    if count == 0:
        return 0.0
    return count_correct * 1.0 / count
def formatting_prompts_func_val(example):
    """Return a one-element list holding the validation prompt (question
    only, no gold answer) for a single example."""
    question = (
        "### Question: Is this hotel review truthful or deceptive?\n"
        f"{example['review']}\n\n### Answer: "
    )
    return [question]


def formatting_prompts_func_train(example):
    """Return a one-element list holding the training prompt (question
    plus gold answer) for a single example."""
    answer = "Truthful" if example["truthful_label"] == 1 else "Deceptive"
    prompt = (
        "### Question: Is this hotel review truthful or deceptive?\n"
        f"{example['review']}\n\n### Answer: {answer}\n"
    )
    return [prompt]


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name")
    parser.add_argument("--train_data_file")
    parser.add_argument("--val_data_file")
    args = parser.parse_args()

    # Fixed run configuration.
    num_epochs = 2
    model_folder = "trained_models/"
    log_folder = "logs/"

    print("Training data file: ", args.train_data_file)
    print("Validation data file: ", args.val_data_file)
    finetune_model(
        args.train_data_file,
        args.val_data_file,
        model_folder,
        log_folder,
        args.model_name,
        num_epochs,
    )