
Update run_llm_in_parallel with customizable threads
Summary:
Currently, our repository treats `--run-llm-in-parallel` as a boolean flag with a hardcoded 20 threads. This update makes the thread count configurable:
1. When the flag is omitted, it defaults to 1 thread (sequential execution).
2. When passed as a bare flag, `--run-llm-in-parallel` defaults to 16 threads.
3. Users can also specify an explicit thread count, e.g. `--run-llm-in-parallel=8` (see the sketch below).
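
The behavior above comes from argparse's optional-value handling (`nargs="?"` with `const` and `default`), which is the pattern the diff applies in `CybersecurityBenchmarks/benchmark/run.py`. A minimal standalone sketch of how the three invocation forms parse:

```python
import argparse

# Standalone illustration of the option's parsing behavior; the real option is
# registered on the benchmark's argument parser in run.py (see the diff below).
parser = argparse.ArgumentParser()
parser.add_argument(
    "--run-llm-in-parallel",
    nargs="?",   # the value after the flag is optional
    type=int,
    const=16,    # used when the flag is given without a value
    default=1,   # used when the flag is omitted entirely
    help="Number of parallel calls to the LLM. Default is 1.",
)

print(parser.parse_args([]).run_llm_in_parallel)                           # -> 1
print(parser.parse_args(["--run-llm-in-parallel"]).run_llm_in_parallel)    # -> 16
print(parser.parse_args(["--run-llm-in-parallel=8"]).run_llm_in_parallel)  # -> 8
```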

Reviewed By: onionymous

Differential Revision: D62206301

fbshipit-source-id: cbba3952d11c57e57525ff0edd3e7931e765844b
SimonWan authored and facebook-github-bot committed Sep 26, 2024
1 parent 89faa7b commit ca35211
Showing 12 changed files with 30 additions and 40 deletions.
@@ -244,7 +244,7 @@ def generator_interrupt_handler(
         exit(0)
 
     def query_llm_to_generate_responses(
-        self, prompt_path: Path, run_llm_in_parallel: bool = False
+        self, prompt_path: Path, run_llm_in_parallel: int = 1
     ) -> None:
         """
         Processes a dataset of prompts by sending them to the LLM and saving the responses in a file. This is called before run().
@@ -268,9 +268,7 @@ def query_llm_to_generate_responses(
 
         self.write_responses()
 
-    async def run(
-        self, num_test_cases: int = 0, run_llm_in_parallel: bool = True
-    ) -> None:
+    async def run(self, num_test_cases: int = 0, run_llm_in_parallel: int = 16) -> None:
         """
         This method runs the grading part of the benchmark after the LLM has been prompted for responses.
         Args:
4 changes: 3 additions & 1 deletion CybersecurityBenchmarks/benchmark/benchmark.py
@@ -81,7 +81,9 @@ def extract_content_in_code_blocks(self, input: str) -> list[str]:
 
     @abstractmethod
     def query_llm_to_generate_responses(
-        self, prompt_path: Path, run_llm_in_parallel: bool = False
+        self,
+        prompt_path: Path,
+        run_llm_in_parallel: int = 1,
     ) -> None:
         """
         Processes a dataset of prompts by sending them to the LLM and saving the responses in a file.
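
Every concrete benchmark in the files below overrides this method with the same integer parameter. As a rough, self-contained sketch of the new contract (hypothetical class names, not the repository's code), the value is a thread count, with 1 meaning sequential execution:

```python
from abc import ABC, abstractmethod
from pathlib import Path


class BenchmarkSketch(ABC):
    """Hypothetical stand-in for the abstract Benchmark base class."""

    @abstractmethod
    def query_llm_to_generate_responses(
        self,
        prompt_path: Path,
        run_llm_in_parallel: int = 1,
    ) -> None: ...


class ExampleBenchmark(BenchmarkSketch):
    def query_llm_to_generate_responses(
        self, prompt_path: Path, run_llm_in_parallel: int = 1
    ) -> None:
        # 1 (the default) processes prompts one at a time; larger values are
        # passed through as the number of worker threads.
        mode = "sequentially" if run_llm_in_parallel <= 1 else f"with {run_llm_in_parallel} threads"
        print(f"Querying prompts from {prompt_path} {mode}")


ExampleBenchmark().query_llm_to_generate_responses(Path("prompts.json"), run_llm_in_parallel=8)
```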
6 changes: 2 additions & 4 deletions CybersecurityBenchmarks/benchmark/canary_exploit_benchmark.py
@@ -55,7 +55,7 @@ def return_kind(cls) -> list[str]:
         return ["canary-exploit"]
 
     def query_llm_to_generate_responses(
-        self, prompt_path: Path, run_llm_in_parallel: bool = False
+        self, prompt_path: Path, run_llm_in_parallel: int = 1
     ) -> None:
         """
         Processes a dataset of prompts by sending them to the LLM and saving the responses in a file.
@@ -77,7 +77,5 @@ def query_llm_to_generate_responses(
 
         self.response_path.write_text(json.dumps(response_result, indent=4))
 
-    async def run(
-        self, num_test_cases: int = 0, run_llm_in_parallel: bool = True
-    ) -> None:
+    async def run(self, num_test_cases: int = 0, run_llm_in_parallel: int = 16) -> None:
         verify_responses(self.response_path, self.judge_response_path, self.stat_path)
6 changes: 2 additions & 4 deletions CybersecurityBenchmarks/benchmark/frr_benchmark.py
@@ -79,7 +79,7 @@ def return_kind(cls) -> list[str]:
         return ["frr"]
 
     def query_llm_to_generate_responses(
-        self, prompt_path: Path, run_llm_in_parallel: bool = False
+        self, prompt_path: Path, run_llm_in_parallel: int = 1
     ) -> None:
         """
         Processes a dataset of prompts by sending them to the LLM and saving the responses in a file.
@@ -107,8 +107,6 @@ def query_llm_to_generate_responses(
             updated_response_result.append(updated_response)
         self.response_path.write_text(json.dumps(updated_response_result, indent=4))
 
-    async def run(
-        self, num_test_cases: int = 0, run_llm_in_parallel: bool = True
-    ) -> None:
+    async def run(self, num_test_cases: int = 0, run_llm_in_parallel: int = 16) -> None:
         if self.stat_path is not None:
             output_stat(self.response_path, self.stat_path)
@@ -32,7 +32,7 @@ def return_kind(cls) -> list[str]:
         return ["autocomplete", "instruct"]
 
     def query_llm_to_generate_responses(
-        self, prompt_path: Path, run_llm_in_parallel: bool = False
+        self, prompt_path: Path, run_llm_in_parallel: int = 1
     ) -> None:
         # The query requires the valid prompt path
         response_result = []
6 changes: 2 additions & 4 deletions CybersecurityBenchmarks/benchmark/interpreter_benchmark.py
@@ -135,7 +135,7 @@ def return_kind(cls) -> list[str]:
         return ["interpreter"]
 
     def query_llm_to_generate_responses(
-        self, prompt_path: Path, run_llm_in_parallel: bool = False
+        self, prompt_path: Path, run_llm_in_parallel: int = 1
     ) -> None:
         """
         This method queries the language model to generate responses. It takes as input
@@ -159,9 +159,7 @@ def query_llm_to_generate_responses(
 
         self.response_path.write_text(json.dumps(response_result, indent=4))
 
-    async def run(
-        self, num_test_cases: int = 0, run_llm_in_parallel: bool = True
-    ) -> None:
+    async def run(self, num_test_cases: int = 0, run_llm_in_parallel: int = 16) -> None:
         """
         This method runs the benchmark. It checks if the judge and expansion language models
         are provided, logs the process, and extends the judge response results.
6 changes: 2 additions & 4 deletions CybersecurityBenchmarks/benchmark/mitre_benchmark.py
@@ -86,7 +86,7 @@ def return_kind(cls) -> list[str]:
         return ["mitre"]
 
     def query_llm_to_generate_responses(
-        self, prompt_path: Path, run_llm_in_parallel: bool = False
+        self, prompt_path: Path, run_llm_in_parallel: int = 1
     ) -> None:
         """
         This method queries the language model to generate responses. It takes as input
@@ -110,9 +110,7 @@ def query_llm_to_generate_responses(
 
         self.response_path.write_text(json.dumps(response_result, indent=4))
 
-    async def run(
-        self, num_test_cases: int = 0, run_llm_in_parallel: bool = True
-    ) -> None:
+    async def run(self, num_test_cases: int = 0, run_llm_in_parallel: int = 16) -> None:
         """
         This method runs the benchmark. It checks if the judge and expansion language models
         are provided, logs the process, and extends the judge response results.
@@ -267,7 +267,7 @@ def return_kind(cls) -> list[str]:
         return ["multiturn-phishing"]
 
     def query_llm_to_generate_responses(
-        self, prompt_path: Path, run_llm_in_parallel: bool = False
+        self, prompt_path: Path, run_llm_in_parallel: int = 1
     ) -> None:
         """
         Processes a dataset of prompts by sending them to the LLM and saving the responses in a file.
@@ -294,9 +294,7 @@ def query_llm_to_generate_responses(
         )
         self.response_path.write_text(json.dumps(response_result, indent=4))
 
-    async def run(
-        self, num_test_cases: int = 0, run_llm_in_parallel: bool = True
-    ) -> None:
+    async def run(self, num_test_cases: int = 0, run_llm_in_parallel: int = 16) -> None:
         j_llm: LLM
 
         if self.judge_llm is None:
@@ -73,7 +73,7 @@ def return_kind(cls) -> list[str]:
         return ["prompt-injection"]
 
     def query_llm_to_generate_responses(
-        self, prompt_path: Path, run_llm_in_parallel: bool = False
+        self, prompt_path: Path, run_llm_in_parallel: int = 1
     ) -> None:
         """
         Processes a dataset of prompts by sending them to the LLM and saving the responses in a file.
@@ -101,9 +101,7 @@ def was_injection_successful(response: Dict[str, str]) -> bool:
         """
         return "yes" in response["judge_response"].lower()
 
-    async def run(
-        self, num_test_cases: int = 0, run_llm_in_parallel: bool = True
-    ) -> None:
+    async def run(self, num_test_cases: int = 0, run_llm_in_parallel: int = 16) -> None:
         j_llm: LLM
 
         if self.judge_llm is None:
8 changes: 4 additions & 4 deletions CybersecurityBenchmarks/benchmark/query_llm.py
@@ -203,8 +203,7 @@ def process_prompt(
 def query_llm_to_generate_responses(
     llms: Dict[str, LLM],
     prompt_path: Path,
-    is_parallel: bool = False,
-    num_threads: int = 20,
+    run_llm_in_parallel: int = 1,
     to_process_prompt: Callable[
         [
             int,
@@ -233,8 +232,9 @@ def query_llm_to_generate_responses(
     prompt_dataset = prompt_dataset[::step][:num_test_cases]
     response_result = []
 
-    if is_parallel:
-        with ThreadPoolExecutor(max_workers=num_threads) as executor:
+    # Run the LLMs in parallel if run_llm_in_parallel > 1
+    if run_llm_in_parallel > 1:
+        with ThreadPoolExecutor(max_workers=run_llm_in_parallel) as executor:
             futures = [
                 (
                     executor.submit(
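
The hunk above collapses the old `is_parallel`/`num_threads` pair into a single integer: 1 keeps the sequential path, and any larger value is used directly as the thread-pool size. A minimal, self-contained sketch of the same gating pattern (hypothetical `query_one` worker and `query_all` wrapper, not the repository's helper):

```python
from concurrent.futures import ThreadPoolExecutor


def query_one(prompt: str) -> str:
    # Placeholder for a real LLM call.
    return f"response to: {prompt}"


def query_all(prompts: list[str], run_llm_in_parallel: int = 1) -> list[str]:
    # run_llm_in_parallel == 1 (the CLI default) keeps execution sequential;
    # anything larger becomes the ThreadPoolExecutor's max_workers.
    if run_llm_in_parallel > 1:
        with ThreadPoolExecutor(max_workers=run_llm_in_parallel) as executor:
            return list(executor.map(query_one, prompts))
    return [query_one(p) for p in prompts]


print(query_all(["a", "b", "c"], run_llm_in_parallel=8))
```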
8 changes: 6 additions & 2 deletions CybersecurityBenchmarks/benchmark/run.py
@@ -134,9 +134,13 @@ def main(
 
     parser.add_argument(
         "--run-llm-in-parallel",
-        action="store_true",
-        help="If this flag is set, we will make parallel calls to the LLM to speed up execution.",
+        nargs="?",
+        type=int,
+        const=16,
+        default=1,
+        help="Specify the number of parallel calls to the LLM to speed up execution. Default is 1.",
     )
 
     parser.add_argument(
         "--num-queries-per-prompt",
         help="Set the number of queries per prompt. Default value is 1.",
@@ -96,7 +96,7 @@ def return_kind(cls) -> list[str]:
 
     @override
     def query_llm_to_generate_responses(
-        self, prompt_path: Path, run_llm_in_parallel: bool = False
+        self, prompt_path: Path, run_llm_in_parallel: int = 1
     ) -> None:
         self._validate_data_directory(prompt_path)
 
@@ -120,9 +120,7 @@ def query_llm_to_generate_responses(
         self.response_path.write_text(json.dumps(response_result, indent=4))
 
     @override
-    async def run(
-        self, num_test_cases: int = 0, run_llm_in_parallel: bool = True
-    ) -> None:
+    async def run(self, num_test_cases: int = 0, run_llm_in_parallel: int = 16) -> None:
         logging.info(f"Judging the responses using {self.judge_llm.model}")
 
         judge_response_result = query_llm_to_generate_responses(
