Adding Llama.cpp support with quantized models #16

Draft · wants to merge 3 commits into main

104 changes: 104 additions & 0 deletions rank_llm/rank_vicuna_q.py
@@ -0,0 +1,104 @@
import re
from typing import Any, Dict, Tuple

from fastchat.model import get_conversation_template
from ftfy import fix_text
import torch
from llama_cpp import Llama, LlamaCache

from rank_llm import RankLLM, PromptMode


def replace_number(s: str) -> str:
    # Rewrite bracketed numbers such as "[3]" inside passage text to "(3)" so they
    # are not mistaken for the passage identifiers used in the ranking output.
    return re.sub(r"\[(\d+)\]", r"(\1)", s)


class RankVicunaQ(RankLLM):
    def __init__(
        self,
        model: str,
        context_size: int,
        top_k_candidates: int,
        dataset: str,
        prompt_mode: PromptMode,
        device: str,
        num_gpus: int = 1,  # Note: multi-GPU support in llama.cpp is currently limited.
    ) -> None:
        super().__init__(model, context_size, top_k_candidates, dataset, prompt_mode)
        self._device = device
        if self._device == "cuda":
            assert torch.cuda.is_available()
        if prompt_mode != PromptMode.RANK_GPT:
            raise ValueError(
                f"Unsupported prompt mode: {prompt_mode}. The only prompt mode currently supported by RankVicunaQ is a slight variation of the RankGPT prompt."
            )
        self._llm = Llama(
            model_path=model,
            n_ctx=context_size,
            n_gpu_layers=-1,  # Offload all layers to the GPU when built with GPU support.
            verbose=False,
        )
        # Cache evaluated prompt prefixes so repeated sliding-window prompts are not re-processed.
        self._llm.set_cache(LlamaCache())

    def run_llm(self, prompt: str) -> Tuple[str, int]:
        output: Dict[str, Any] = self._llm(
            prompt, max_tokens=self.max_tokens(), temperature=0.9, top_p=0.6
        )  # type: ignore
        text = output["choices"][0]["text"]
        n_tokens = output["usage"]["completion_tokens"]
        return text, n_tokens

    def num_output_tokens(self) -> int:
        return 200

    def _add_prefix_prompt(self, query: str, num: int) -> str:
        return f"I will provide you with {num} passages, each indicated by a numerical identifier []. Rank the passages based on their relevance to the search query: {query}.\n"

    def _add_post_prompt(self, query: str, num: int) -> str:
        return f"Search Query: {query}.\nRank the {num} passages above based on their relevance to the search query. All the passages should be included and listed using identifiers, in descending order of relevance. The output format should be [] > [], e.g., [4] > [2], Only respond with the ranking results, do not say any word or explain."

    def create_prompt(
        self, retrieved_result: Dict[str, Any], rank_start: int, rank_end: int
    ) -> Tuple[str, int]:
        query = retrieved_result["query"]
        num = len(retrieved_result["hits"][rank_start:rank_end])
        max_length = 300
        while True:
            conv = get_conversation_template(self._model)
            # conv.set_system_message(
            #     "You are RankVicuna, an intelligent assistant that can rank passages based on their relevancy to the query."
            # )
            prefix = self._add_prefix_prompt(query, num)
            rank = 0
            input_context = f"{prefix}\n"
            for hit in retrieved_result["hits"][rank_start:rank_end]:
                rank += 1
                content = hit["content"]
                content = content.replace("Title: Content: ", "")
                content = content.strip()
                # For Japanese, truncate by character instead: content = content[:int(max_length)]
                content = " ".join(content.split()[: int(max_length)])
                input_context += f"[{rank}] {replace_number(content)}\n"

            input_context += self._add_post_prompt(query, num)
            conv.append_message(conv.roles[0], input_context)
            prompt = conv.get_prompt() + " ASSISTANT:"
            prompt = fix_text(prompt)
            num_tokens = self.get_num_tokens(prompt)
            if num_tokens <= self.max_tokens() - self.num_output_tokens():
                break
            else:
                # Shrink the per-passage word budget proportionally to the token overflow.
                max_length -= max(
                    1,
                    (num_tokens - self.max_tokens() + self.num_output_tokens())
                    // (rank_end - rank_start),
                )
        return prompt, self.get_num_tokens(prompt)

    def get_num_tokens(self, prompt: str) -> int:
        return len(self._llm.tokenize(prompt.encode()))

    def cost_per_1k_token(self, input_token: bool) -> float:
        return 0.0
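
For reviewers, a minimal usage sketch of the new class; the GGUF path, dataset name, and retrieved_result payload below are illustrative placeholders, not files or data shipped with this PR:

# Hypothetical usage sketch for RankVicunaQ (placeholder model path and data).
from rank_llm import PromptMode
from rank_vicuna_q import RankVicunaQ

agent = RankVicunaQ(
    model="models/rank_vicuna_7b_v1.q4_k_m.gguf",  # placeholder quantized checkpoint
    context_size=4096,
    top_k_candidates=100,
    dataset="dl19",  # placeholder dataset name
    prompt_mode=PromptMode.RANK_GPT,
    device="cuda",
)

retrieved_result = {
    "query": "how does quantization affect reranking quality?",
    "hits": [
        {"content": "Title: Content: Quantization compresses model weights with little quality loss."},
        {"content": "Title: Content: Full-precision inference requires substantially more memory."},
    ],
}

# Build one ranking prompt over the two hits and run the quantized model.
prompt, n_prompt_tokens = agent.create_prompt(retrieved_result, rank_start=0, rank_end=2)
ranking, n_completion_tokens = agent.run_llm(prompt)
print(ranking)  # expected to look like "[1] > [2]" or "[2] > [1]"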
11 changes: 11 additions & 0 deletions rank_llm/run_rank_llm.py
@@ -9,6 +9,7 @@
from rank_gpt import SafeOpenai
from rank_llm import PromptMode
from rank_vicuna import RankVicuna
from rank_vicuna_q import RankVicunaQ
from topics_dict import TOPICS
from trec_eval import EvalFunction

@@ -42,6 +43,16 @@ def main(args):
            prompt_mode=prompt_mode,
            keys=openai_keys,
        )
    elif "q4" in model_path or "q8" in model_path:
        agent = RankVicunaQ(
            model=model_path,
            context_size=context_size,
            top_k_candidates=top_k_candidates,
            dataset=dataset,
            prompt_mode=prompt_mode,
            device=device,
            num_gpus=num_gpus,
        )
    else:
        agent = RankVicuna(
            model=model_path,
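
A side note on the quantized-model dispatch above: matching "q4"/"q8" substrings in the path is brittle. A possible alternative, assuming quantized checkpoints are local GGUF/GGML files; the helper name and extension list are illustrative, not part of this PR:

from pathlib import Path

# Hypothetical helper: decide whether model_path points at a llama.cpp
# quantized checkpoint. The extension list and the q4/q8 fallback are assumptions.
_QUANTIZED_SUFFIXES = {".gguf", ".ggml"}

def is_quantized_model(model_path: str) -> bool:
    path = Path(model_path)
    if path.suffix.lower() in _QUANTIZED_SUFFIXES:
        return True
    # Fall back to the naming convention used in this PR ("q4"/"q8" in the filename).
    return any(tag in path.name.lower() for tag in ("q4", "q8"))

With such a helper, the elif above would test is_quantized_model(model_path) instead, keeping Hugging Face model IDs on the existing RankVicuna path.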
3 changes: 2 additions & 1 deletion requirements.txt
@@ -7,4 +7,5 @@ python-dotenv==1.0.0
faiss-gpu==1.7.2
ftfy==6.1.1
fschat==0.2.28
accelerate==0.23.0
accelerate==0.23.0
llama-cpp-python==0.2.7
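
Once llama-cpp-python is installed, a quick sanity check that the binding can load a quantized model and generate; the model path is a placeholder:

# Hypothetical smoke test for the new llama-cpp-python dependency (placeholder model path).
from llama_cpp import Llama

llm = Llama(
    model_path="models/rank_vicuna_7b_v1.q4_k_m.gguf",  # placeholder quantized checkpoint
    n_ctx=4096,
    n_gpu_layers=-1,  # offload all layers to the GPU if the wheel was built with GPU support
    verbose=False,
)
out = llm("Respond with the single word: ready.", max_tokens=8)
print(out["choices"][0]["text"])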