Merge pull request #33 from tackhwa/main
Added test cases for the online and offline embedding models, corresponding to 天机-任务看板 No.7
sanbuphy authored Mar 5, 2024
2 parents 827e4fb + 706f513 commit 03fcb00
Showing 8 changed files with 238 additions and 0 deletions.
5 changes: 5 additions & 0 deletions README.md
@@ -105,12 +105,17 @@ pip install .

To make sure the project runs correctly, **create a `.env` file in the project and set your API keys in it**. You can fill in the corresponding keys following the example below; zhipuai is used by default, so setting only `ZHIPUAI_API_KEY` is enough to get started.

If downloading models from Hugging Face is extremely slow or fails, set `HF_ENDPOINT` to 'https://hf-mirror.com' in the `.env` file. Note that some Hugging Face repositories require access permission (for example, Jina AI); in that case, register a Hugging Face account and add `HF_TOKEN` to the `.env` file. You can find and create your token [here](https://huggingface.co/settings/tokens).

```
OPENAI_API_KEY=
OPENAI_API_BASE=
ZHIPUAI_API_KEY=
BAIDU_API_KEY=
OPENAI_API_MODEL=
HF_HOME='./cache/'
HF_ENDPOINT='https://hf-mirror.com'
HF_TOKEN=
```
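
A minimal sketch, assuming `python-dotenv` and `transformers` are installed, of how these settings take effect: `HF_ENDPOINT` is read by `huggingface_hub` when it is first imported, so call `load_dotenv()` before importing `transformers`.

```
# Load the .env file first so HF_ENDPOINT / HF_TOKEN / HF_HOME are in os.environ
from dotenv import load_dotenv

load_dotenv()

# Import transformers only after the environment variables are set
from transformers import AutoModel

# Downloads now go through the mirror configured via HF_ENDPOINT
model = AutoModel.from_pretrained("BAAI/bge-small-zh")
```

The test files added in this commit follow the same pattern by calling `load_dotenv()` at the top of each module.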

## File Directory Description
13 changes: 13 additions & 0 deletions test/embedding/local/BaseLocal.py
@@ -0,0 +1,13 @@
from typing import List


class BaseLocalEmbeddings:
    """
    Base class for local embeddings
    """

    def __init__(self, path: str) -> None:
        self.path = path

    def get_embedding(self, text: str, model: str) -> List[float]:
        raise NotImplementedError
44 changes: 44 additions & 0 deletions test/embedding/local/Bge.py
@@ -0,0 +1,44 @@
from BaseLocal import BaseLocalEmbeddings
from typing import List
import torch
from transformers import AutoModel, AutoTokenizer
from dotenv import load_dotenv

# Load the .env file
load_dotenv()


class BgeEmbedding(BaseLocalEmbeddings):
    """
    class for Bge embeddings
    """

    # path:str = TIANJI_PATH / "embedding/BAAI/bge-small-zh"
    def __init__(self, path: str = "BAAI/bge-small-zh") -> None:
        super().__init__(path)
        self._model, self._tokenizer, self._device = self.load_model()

    def get_embedding(self, text: str) -> List[float]:
        encoded_input = self._tokenizer(
            text, padding=True, truncation=False, return_tensors="pt"
        ).to(self._device)
        # Take the [CLS] token embedding and return it as a plain list of floats
        with torch.no_grad():
            return self._model(**encoded_input)[0][:, 0][0].tolist()

    def load_model(self):
        if torch.cuda.is_available():
            device = torch.device("cuda")
        else:
            device = torch.device("cpu")
        model = AutoModel.from_pretrained(self.path, trust_remote_code=True).to(device)
        # Load the tokenizer from the same path as the model
        tokenizer = AutoTokenizer.from_pretrained(self.path)
        return model, tokenizer, device


if __name__ == "__main__":
    Bge = BgeEmbedding()
    embedding_result = Bge.get_embedding("你好")
    print(
        f"Result of Bge Embedding: \n"
        f"\t Type of output: {type(embedding_result)}\n"
        f"\t Shape of output: {len(embedding_result)}"
    )
43 changes: 43 additions & 0 deletions test/embedding/local/Jina.py
@@ -0,0 +1,43 @@
from BaseLocal import BaseLocalEmbeddings
from typing import List
import os
import torch
from transformers import AutoModel
from dotenv import load_dotenv

# Load the .env file
load_dotenv()

# os.environ["HF_TOKEN"]=""


class JinaEmbedding(BaseLocalEmbeddings):
    """
    class for Jina embeddings
    """

    # path:str = TIANJI_PATH / "embedding/jinaai/jina-embeddings-v2-base-zh"
    def __init__(self, path: str = "jinaai/jina-embeddings-v2-base-zh") -> None:
        super().__init__(path)
        self._model = self.load_model()

    def get_embedding(self, text: str) -> List[float]:
        return self._model.encode([text])[0]

    def load_model(self):
        if torch.cuda.is_available():
            device = torch.device("cuda")
        else:
            device = torch.device("cpu")
        model = AutoModel.from_pretrained(self.path, trust_remote_code=True).to(device)
        return model


if __name__ == "__main__":
    Jina = JinaEmbedding()
    embedding_result = Jina.get_embedding("你好")
    print(
        f"Result of Jina Embedding: \n"
        f"\t Type of output: {type(embedding_result)}\n"
        f"\t Shape of output: {len(embedding_result)}"
    )
13 changes: 13 additions & 0 deletions test/embedding/online/BaseOnline.py
@@ -0,0 +1,13 @@
from typing import List


class BaseOnlineEmbeddings:
    """
    Base class for online embeddings
    """

    def __init__(self) -> None:
        pass

    def get_embedding(self, text: str, model: str) -> List[float]:
        raise NotImplementedError
38 changes: 38 additions & 0 deletions test/embedding/online/Ernie.py
@@ -0,0 +1,38 @@
from BaseOnline import BaseOnlineEmbeddings
from typing import List
import os
import erniebot
from dotenv import load_dotenv

# Load the .env file
load_dotenv()

# os.environ["BAIDU_API_KEY"]=""


class ErnieEmbedding(BaseOnlineEmbeddings):
    """
    class for Ernie embeddings
    """

    def __init__(self) -> None:
        super().__init__()
        erniebot.api_type = "aistudio"
        erniebot.access_token = os.getenv("BAIDU_API_KEY")
        self.client = erniebot.Embedding()

    def get_embedding(
        self, text: str, model: str = "ernie-text-embedding"
    ) -> List[float]:
        response = self.client.create(model=model, input=[text])
        return response.get_result()[0]


if __name__ == "__main__":
    Ernie = ErnieEmbedding()
    embedding_result = Ernie.get_embedding("你好")
    print(
        f"Result of Ernie Embedding: \n"
        f"\t Type of output: {type(embedding_result)}\n"
        f"\t Shape of output: {len(embedding_result)}"
    )
41 changes: 41 additions & 0 deletions test/embedding/online/OpenAI.py
@@ -0,0 +1,41 @@
from BaseOnline import BaseOnlineEmbeddings
from typing import List
import os
from openai import OpenAI
from dotenv import load_dotenv

# Load the .env file
load_dotenv()

# os.environ["OPENAI_API_BASE"]=""
# os.environ["OPENAI_API_KEY"]=""


class OpenAIEmbedding(BaseOnlineEmbeddings):
    """
    class for OpenAI embeddings
    """

    def __init__(self) -> None:
        super().__init__()
        # Pass the key and base URL to the constructor instead of patching the client afterwards
        self.client = OpenAI(
            api_key=os.getenv("OPENAI_API_KEY"),
            base_url=os.getenv("OPENAI_API_BASE"),
        )

    def get_embedding(
        self, text: str, model: str = "text-embedding-3-small"
    ) -> List[float]:
        text = text.replace("\n", " ")
        return (
            self.client.embeddings.create(input=[text], model=model).data[0].embedding
        )


if __name__ == "__main__":
    openai_embedding = OpenAIEmbedding()  # avoid shadowing the imported OpenAI class
    embedding_result = openai_embedding.get_embedding("你好")
    print(
        f"Result of OpenAI Embedding: \n"
        f"\t Type of output: {type(embedding_result)}\n"
        f"\t Shape of output: {len(embedding_result)}"
    )
41 changes: 41 additions & 0 deletions test/embedding/online/Zhipu.py
@@ -0,0 +1,41 @@
from BaseOnline import BaseOnlineEmbeddings
from typing import List
import os
from zhipuai import ZhipuAI
from dotenv import load_dotenv

# Load the .env file
load_dotenv()

# os.environ["ZHIPUAI_API_KEY"]=""


class ZhipuEmbedding(BaseOnlineEmbeddings):
    """
    class for Zhipu embeddings
    """

    def __init__(self) -> None:
        super().__init__()
        self.client = ZhipuAI(api_key=os.getenv("ZHIPUAI_API_KEY"))

    def get_embedding(
        self,
        text: str,
        model: str = "embedding-2",
    ) -> List[float]:
        response = self.client.embeddings.create(
            model=model,
            input=text,
        )
        return response.data[0].embedding


if __name__ == "__main__":
    Zhipu = ZhipuEmbedding()
    embedding_result = Zhipu.get_embedding("你好")
    print(
        f"Result of Zhipu Embedding: \n"
        f"\t Type of output: {type(embedding_result)}\n"
        f"\t Shape of output: {len(embedding_result)}"
    )
