Support embeddings via an OpenAI-like API server #3892

Closed · wants to merge 1 commit
18 changes: 18 additions & 0 deletions configs/model_config.py.example
@@ -8,6 +8,24 @@ MODEL_ROOT_PATH = ""
 # Name of the selected embedding model
 EMBEDDING_MODEL = "bge-large-zh-v1.5"
 
+# Whether embeddings are served through an OpenAI-like API endpoint
+EMBEDDING_MODEL_USE_OPENAI = True
+# With tiktoken_enabled set to False and tiktoken_model_name set to a local Hugging Face path, text is split with that model's tokenizer.
+# To use OpenAI's tiktoken instead, set tiktoken_enabled to True.
+
+# chunk_size sets the maximum batch size for a single request
+# embedding_ctx_length sets the maximum token length of a single batch; it must not exceed the model's limit
+EMBEDDING_MODEL_OPENAI = {
+    "model_name": "bge-large-zh-v1.5",
+    "api_base_url": "http://host:port/v1",
+    "api_key": "123",
+    "openai_proxy": "",
+    "tiktoken_enabled": False,
+    "tiktoken_model_name": "bge-large-zh-v1.5",
+    "chunk_size": 20,
+    "embedding_ctx_length": 256,
+}
+
 # Device on which the embedding model runs. "auto" detects it automatically (with a warning); it can also be set manually to one of "cuda", "mps", "cpu", or "xpu".
 EMBEDDING_DEVICE = "auto"
 
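Note: the backend referenced by `api_base_url` only needs to implement the standard OpenAI embeddings route. A minimal sketch of the request that will be issued against it, assuming a compatible server is listening locally (the localhost URL and the input texts are hypothetical stand-ins, not part of this PR):

```python
import requests

# Hypothetical stand-ins for the "http://host:port/v1" placeholder above.
API_BASE_URL = "http://localhost:8000/v1"
API_KEY = "123"  # matches the example api_key in EMBEDDING_MODEL_OPENAI

resp = requests.post(
    f"{API_BASE_URL}/embeddings",
    headers={"Authorization": f"Bearer {API_KEY}"},
    json={"model": "bge-large-zh-v1.5", "input": ["知识库", "hello world"]},
)
resp.raise_for_status()
# OpenAI-style response: {"data": [{"embedding": [...]}, ...], ...}
vectors = [item["embedding"] for item in resp.json()["data"]]
print(len(vectors), len(vectors[0]))  # 2 vectors; bge-large models emit 1024 dims
```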
21 changes: 19 additions & 2 deletions server/knowledge_base/kb_cache/base.py
@@ -2,8 +2,8 @@
 from langchain.vectorstores.faiss import FAISS
 import threading
 from configs import (EMBEDDING_MODEL, CHUNK_SIZE,
-                     logger, log_verbose)
-from server.utils import embedding_device, get_model_path, list_online_embed_models
+                     logger, log_verbose, EMBEDDING_MODEL_USE_OPENAI, EMBEDDING_MODEL_OPENAI)
+from server.utils import embedding_device, get_model_path, list_online_embed_models, get_absulute_model_path
 from contextlib import contextmanager
 from collections import OrderedDict
 from typing import List, Any, Union, Tuple
@@ -132,6 +132,23 @@ def load_embeddings(self, model: str = None, device: str = None) -> Embeddings:
            embeddings = OpenAIEmbeddings(model=model,
                                          openai_api_key=get_model_path(model),
                                          chunk_size=CHUNK_SIZE)
+        elif EMBEDDING_MODEL_USE_OPENAI:
+            from langchain.embeddings.openai import OpenAIEmbeddings
+            config = EMBEDDING_MODEL_OPENAI
+            model_name = config.get("model_name", model)
+            tiktoken_enabled = config.get("tiktoken_enabled", False)
+            tiktoken_model_name = config.get("tiktoken_model_name", "")
+            if not tiktoken_enabled:
+                tiktoken_model_name = get_absulute_model_path(model_name, tiktoken_model_name)
+            embeddings = OpenAIEmbeddings(model=model_name,
+                                          openai_api_base=config.get("api_base_url", ""),
+                                          openai_api_key=config.get("api_key", "123"),
+                                          openai_proxy=config.get("openai_proxy", ""),
+                                          tiktoken_enabled=tiktoken_enabled,
+                                          tiktoken_model_name=tiktoken_model_name,
+                                          chunk_size=config.get("chunk_size", 1),
+                                          embedding_ctx_length=config.get("embedding_ctx_length", 10),
+                                          )
        elif 'bge-' in model:
            from langchain.embeddings import HuggingFaceBgeEmbeddings
            if 'zh' in model:
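For reviewers, a self-contained sketch of what the new `elif` branch constructs, using the example values from `EMBEDDING_MODEL_OPENAI` (the URL and tokenizer path are hypothetical). With `tiktoken_enabled=False`, langchain's `OpenAIEmbeddings` tokenizes with a Hugging Face tokenizer loaded from `tiktoken_model_name`, which is why that value is resolved to a local path via `get_absulute_model_path` first; this assumes a langchain version whose `OpenAIEmbeddings` accepts the `tiktoken_enabled` flag:

```python
from langchain.embeddings.openai import OpenAIEmbeddings

embeddings = OpenAIEmbeddings(
    model="bge-large-zh-v1.5",
    openai_api_base="http://localhost:8000/v1",       # hypothetical local server
    openai_api_key="123",
    tiktoken_enabled=False,                           # tokenize with a local HF tokenizer
    tiktoken_model_name="/models/bge-large-zh-v1.5",  # as resolved by get_absulute_model_path
    chunk_size=20,                                    # at most 20 texts per request
    embedding_ctx_length=256,                         # per the config: max tokens per batch
)

vectors = embeddings.embed_documents(["支持OpenAI接口的embedding", "second document"])
print(len(vectors))  # 2
```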
33 changes: 17 additions & 16 deletions server/utils.py
@@ -321,6 +321,22 @@ def list_config_llm_models() -> Dict[str, Dict]:
        "worker": workers,
    }
 
+def get_absulute_model_path(model_name: str, path_str: str):
+    path = Path(path_str)
+    if path.is_dir():  # any absolute path
+        return str(path)
+    root_path = Path(MODEL_ROOT_PATH)
+    if root_path.is_dir():
+        path = root_path / model_name
+        if path.is_dir():  # use key, {MODEL_ROOT_PATH}/chatglm-6b
+            return str(path)
+        path = root_path / path_str
+        if path.is_dir():  # use value, {MODEL_ROOT_PATH}/THUDM/chatglm-6b-new
+            return str(path)
+        path = root_path / path_str.split("/")[-1]
+        if path.is_dir():  # use value split by "/", {MODEL_ROOT_PATH}/chatglm-6b-new
+            return str(path)
+    return path_str  # unchanged, e.g. THUDM/chatglm-6b-new
 
 def get_model_path(model_name: str, type: str = None) -> Optional[str]:
     if type in MODEL_PATH:
@@ -331,22 +347,7 @@ def get_model_path(model_name: str, type: str = None) -> Optional[str]:
        paths.update(v)
 
    if path_str := paths.get(model_name):  # e.g. "chatglm-6b": "THUDM/chatglm-6b-new"; all of the following layouts are supported
-        path = Path(path_str)
-        if path.is_dir():  # any absolute path
-            return str(path)
-
-        root_path = Path(MODEL_ROOT_PATH)
-        if root_path.is_dir():
-            path = root_path / model_name
-            if path.is_dir():  # use key, {MODEL_ROOT_PATH}/chatglm-6b
-                return str(path)
-            path = root_path / path_str
-            if path.is_dir():  # use value, {MODEL_ROOT_PATH}/THUDM/chatglm-6b-new
-                return str(path)
-            path = root_path / path_str.split("/")[-1]
-            if path.is_dir():  # use value split by "/", {MODEL_ROOT_PATH}/chatglm-6b-new
-                return str(path)
-        return path_str  # unchanged, e.g. THUDM/chatglm-6b-new
+        return get_absulute_model_path(model_name, path_str)
 
 
 # Get service information from server_config
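The helper simply factors the old inline lookup out of `get_model_path` so that `kb_cache/base.py` can reuse it for the tokenizer path. A usage sketch of the fallback chain, assuming `MODEL_ROOT_PATH = "/models"` and a hypothetical `"bge-large-zh-v1.5": "BAAI/bge-large-zh-v1.5"` mapping:

```python
from server.utils import get_absulute_model_path

# Tried in order; the first existing directory wins:
#   1. path_str itself ("BAAI/bge-large-zh-v1.5" as a directory)
#   2. {MODEL_ROOT_PATH}/{model_name}         -> /models/bge-large-zh-v1.5
#   3. {MODEL_ROOT_PATH}/{path_str}           -> /models/BAAI/bge-large-zh-v1.5
#   4. {MODEL_ROOT_PATH}/{path_str basename}  -> /models/bge-large-zh-v1.5
#   5. otherwise path_str is returned unchanged (e.g. a Hugging Face hub id)
print(get_absulute_model_path("bge-large-zh-v1.5", "BAAI/bge-large-zh-v1.5"))
```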