diff --git a/README.md b/README.md index 2194b04..d933261 100644 --- a/README.md +++ b/README.md @@ -4,24 +4,35 @@ This repo contains the official code for running `LLM persona` experiments and subsequent analyses in the PersonaLLM paper. +## Setup + +```bash +pip install -r requirements.txt +``` + +Set your API keys as environment variables before running any scripts: + +```bash +export OPENAI_API_KEY="your-openai-api-key" +export REPLICATE_API_TOKEN="your-replicate-token" # only needed for llama-2 +``` ## Simulate `LLM personas` We first create 10 personas for each of 32 personality types. ```bash -conda activate audiencenlp -python3.9 run_bfi.py --model "GPT-3.5-turbo-0613" -python3.9 run_bfi.py --model "GPT-4-0613" -python3.9 run_bfi.py --model "llama-2" +python run_bfi.py --model "gpt-3.5-turbo-0613" +python run_bfi.py --model "gpt-4-0613" +python run_bfi.py --model "llama-2" ``` ## Generate stories with `LLM personas` ```bash -python3.9 run_creative_writing.py --model "GPT-3.5-turbo-0613" -python3.9 run_creative_writing.py --model "GPT-4-0613" -python3.9 run_creative_writing.py --model "llama-2" +python run_creative_writing.py --model "gpt-3.5-turbo-0613" +python run_creative_writing.py --model "gpt-4-0613" +python run_creative_writing.py --model "llama-2" ``` ## References diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..adae79b --- /dev/null +++ b/requirements.txt @@ -0,0 +1,7 @@ +openai>=1.0.0 +replicate +tqdm +pandas +numpy +scipy +tenacity diff --git a/run_bfi.py b/run_bfi.py index 7236bd8..2a355f5 100644 --- a/run_bfi.py +++ b/run_bfi.py @@ -1,7 +1,6 @@ import argparse import random import json -import openai import os import pandas as pd import sys @@ -11,13 +10,13 @@ import json import numpy as np import itertools +from openai import OpenAI from gpt import is_answer_in_valid_form from tenacity import retry, stop_after_attempt, wait_random_exponential import multiprocessing import replicate -openai.organization = "" -openai.api_key = "" +client = OpenAI() # reads OPENAI_API_KEY from environment def construct_big_five_words(persona_type): @@ -31,7 +30,7 @@ def construct_big_five_words(persona_type): return ", ".join(options) def run_gpt_query(model_name, temperature, system_prompt, user_prompt): - response = openai.ChatCompletion.create( + response = client.chat.completions.create( model=model_name, messages=[ {"role": "system", "content": system_prompt}, @@ -80,7 +79,7 @@ def generate_bfi_response(model_name, temperature, persona_type, prompt_file): while True: if model_name.lower().startswith("gpt"): response = run_gpt_query(model_name, temperature, system_prompt, user_prompt) - response = response["choices"][0]['message']['content'].strip("\n") + response = response.choices[0].message.content.strip("\n") elif model_name.lower().startswith("llama"): # print("llama-2 generating") # print(system_prompt, user_prompt) diff --git a/run_creative_writing.py b/run_creative_writing.py index dd9a3f3..05ba5a4 100644 --- a/run_creative_writing.py +++ b/run_creative_writing.py @@ -1,7 +1,6 @@ import argparse import random import json -import openai import os import pandas as pd import sys @@ -11,13 +10,13 @@ import json import numpy as np import itertools +from openai import OpenAI # from gpt import run_completion_query from tenacity import retry, stop_after_attempt, wait_random_exponential import multiprocessing import replicate -openai.organization = "" -openai.api_key = "" +client = OpenAI() # reads OPENAI_API_KEY from environment def construct_big_five_words(persona_type: list): """Construct the list of personality traits @@ -30,7 +29,7 @@ def construct_big_five_words(persona_type: list): return ", ".join(options) def run_gpt_query(model_name, temperature, system_prompt, prev_user_prompt, prev_assistant_prompt, user_prompt): - response = openai.ChatCompletion.create( + response = client.chat.completions.create( model=model_name, messages=[ {"role": "system", "content": system_prompt}, @@ -85,7 +84,7 @@ def generate_bfi_story(model_name, temperature, persona_type, prompt_file, json_ if model_name.lower().startswith("gpt"): response = run_gpt_query(model_name, temperature, system_prompt, prev_user_prompt, prev_assistant_prompt, user_prompt) - response = response["choices"][0]['message']['content'].strip("\n") + response = response.choices[0].message.content.strip("\n") elif model_name.lower().startswith("llama"): response = run_llama2_query(temperature, system_prompt, prev_user_prompt, prev_assistant_prompt, user_prompt) response = "".join([each for each in response]).strip("\n").strip()