2023-09-28 10:58:58 +08:00
|
|
|
|
# 该文件包含webui通用工具,可以被不同的webui使用
|
|
|
|
|
from typing import *
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
from io import BytesIO
|
|
|
|
|
import httpx
|
|
|
|
|
import asyncio
|
|
|
|
|
from fastapi.responses import StreamingResponse
|
|
|
|
|
import contextlib
|
|
|
|
|
import json
|
|
|
|
|
import nltk
|
|
|
|
|
import traceback
|
|
|
|
|
from loguru import logger
|
2023-12-07 20:17:21 +08:00
|
|
|
|
|
2023-09-28 10:58:58 +08:00
|
|
|
|
|
|
|
|
|
from configs.model_config import (
|
|
|
|
|
EMBEDDING_MODEL,
|
|
|
|
|
DEFAULT_VS_TYPE,
|
|
|
|
|
KB_ROOT_PATH,
|
2023-11-07 19:44:47 +08:00
|
|
|
|
CB_ROOT_PATH,
|
2023-09-28 10:58:58 +08:00
|
|
|
|
LLM_MODEL,
|
|
|
|
|
SCORE_THRESHOLD,
|
|
|
|
|
VECTOR_SEARCH_TOP_K,
|
|
|
|
|
SEARCH_ENGINE_TOP_K,
|
|
|
|
|
NLTK_DATA_PATH,
|
2024-01-26 14:03:25 +08:00
|
|
|
|
JUPYTER_WORK_PATH,
|
2023-09-28 10:58:58 +08:00
|
|
|
|
)
|
|
|
|
|
from configs.server_config import SANDBOX_SERVER
|
|
|
|
|
|
2024-01-26 14:03:25 +08:00
|
|
|
|
# from configs.server_config import SANDBOX_SERVER
|
|
|
|
|
|
|
|
|
|
from coagent.utils.server_utils import run_async, iter_over_async
|
|
|
|
|
from coagent.service.kb_api import *
|
|
|
|
|
from coagent.service.cb_api import *
|
|
|
|
|
from coagent.chat import LLMChat, SearchChat, KnowledgeChat, CodeChat, AgentChat
|
|
|
|
|
from coagent.sandbox import PyCodeBox, CodeBoxResponse
|
|
|
|
|
from coagent.utils.common_utils import file_normalize, get_uploadfile
|
2023-09-28 10:58:58 +08:00
|
|
|
|
|
2024-01-26 14:03:25 +08:00
|
|
|
|
from coagent.codechat.code_crawler.zip_crawler import ZipCrawler
|
2023-12-07 20:17:21 +08:00
|
|
|
|
|
2023-09-28 10:58:58 +08:00
|
|
|
|
from web_crawler.utils.WebCrawler import WebCrawler
|
|
|
|
|
|
2024-01-26 14:03:25 +08:00
|
|
|
|
# nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
2023-09-28 10:58:58 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def set_httpx_timeout(timeout=60.0):
    """Override httpx's process-wide default timeouts.

    Writes ``timeout`` (seconds, 60 by default) into the connect, read and
    write fields of ``httpx._config.DEFAULT_TIMEOUT_CONFIG``.  The original
    author's note: httpx's stock default of 5 seconds is not enough when
    waiting for an LLM to answer.
    """
    default_cfg = httpx._config.DEFAULT_TIMEOUT_CONFIG
    # Same three phases, in the same order, as the explicit assignments.
    for phase in ("connect", "read", "write"):
        setattr(default_cfg, phase, timeout)
|
|
|
|
|
|
|
|
|
|
|
2024-01-26 14:03:25 +08:00
|
|
|
|
# KB_ROOT_PATH = Path(KB_ROOT_PATH)
|
2023-09-28 10:58:58 +08:00
|
|
|
|
set_httpx_timeout()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ApiRequest:
|
|
|
|
|
'''
|
|
|
|
|
api.py调用的封装,主要实现:
|
|
|
|
|
1. 简化api调用方式
|
|
|
|
|
2. 实现无api调用(直接运行server.chat.*中的视图函数获取结果),无需启动api.py
|
|
|
|
|
'''
|
|
|
|
|
def __init__(
    self,
    base_url: str = "http://127.0.0.1:7861",
    sandbox_file_url: str = "http://127.0.0.1:7862",
    timeout: float = 60.0,
    no_remote_api: bool = False,  # call api view function directly
    cb_root_path: str = "",
):
    """Remember the connection settings and create the chat backends.

    Args:
        base_url: Root URL prepended to relative API paths.
        sandbox_file_url: Root URL used instead of ``base_url`` for paths
            that contain ``"sdfiles"`` (see ``_parse_url``).
        timeout: Default timeout, in seconds, for the HTTP helpers.
        no_remote_api: When true, methods call the chat/service functions
            in-process instead of going through HTTP.
        cb_root_path: Stored on the instance as-is; not used in this method.
    """
    # Endpoint and request settings.
    self.base_url, self.sandbox_file_url = base_url, sandbox_file_url
    self.timeout = timeout
    self.no_remote_api = no_remote_api
    self.cb_root_path = cb_root_path

    # In-process chat backends, used when the HTTP API is bypassed.
    self.llmChat = LLMChat()
    self.searchChat = SearchChat()
    self.knowledgeChat = KnowledgeChat(kb_root_path=KB_ROOT_PATH)
    self.codeChat = CodeChat()
    self.agentChat = AgentChat()
|
2023-12-07 20:17:21 +08:00
|
|
|
|
|
2024-01-26 14:03:25 +08:00
|
|
|
|
# self.codebox = PyCodeBox(
|
|
|
|
|
# remote_url=self.sandbox_server["url"],
|
|
|
|
|
# remote_ip=self.sandbox_server["host"], # "http://localhost",
|
|
|
|
|
# remote_port=self.sandbox_server["port"],
|
|
|
|
|
# token="mytoken",
|
|
|
|
|
# do_code_exe=True,
|
|
|
|
|
# do_remote=self.sandbox_server["do_remote"]
|
|
|
|
|
# )
|
2023-09-28 10:58:58 +08:00
|
|
|
|
|
2024-01-26 14:03:25 +08:00
|
|
|
|
# def codebox_chat(self, text: str, file_path: str = None, do_code_exe: bool = None) -> CodeBoxResponse:
|
|
|
|
|
# return self.codebox.chat(text, file_path, do_code_exe=do_code_exe)
|
2023-09-28 10:58:58 +08:00
|
|
|
|
|
|
|
|
|
def _parse_url(self, url: str) -> str:
|
|
|
|
|
if (not url.startswith("http")
|
|
|
|
|
and self.base_url
|
|
|
|
|
):
|
2023-11-07 19:44:47 +08:00
|
|
|
|
part1 = self.sandbox_file_url.strip(" /") \
|
|
|
|
|
if "sdfiles" in url else self.base_url.strip(" /")
|
2023-09-28 10:58:58 +08:00
|
|
|
|
part2 = url.strip(" /")
|
|
|
|
|
return f"{part1}/{part2}"
|
|
|
|
|
else:
|
|
|
|
|
return url
|
|
|
|
|
|
|
|
|
|
def get(
    self,
    url: str,
    params: Union[Dict, List[Tuple], bytes] = None,
    retry: int = 3,
    stream: bool = False,
    **kwargs: Any,
) -> Union[httpx.Response, None]:
    """Send a GET request through httpx, retrying on any exception.

    Args:
        url: Absolute URL or a path resolved by ``_parse_url``.
        params: Query parameters forwarded to httpx.
        retry: Maximum number of attempts.
        stream: When true, return the result of ``httpx.stream`` (a context
            manager) instead of a completed response.
        **kwargs: Extra httpx arguments; ``timeout`` defaults to
            ``self.timeout``.

    Returns:
        The httpx result of the first successful attempt, or ``None`` when
        every attempt raised (each failure is logged).
    """
    target = self._parse_url(url)
    kwargs.setdefault("timeout", self.timeout)

    attempts_left = retry
    while attempts_left > 0:
        try:
            if stream:
                return httpx.stream("GET", target, params=params, **kwargs)
            return httpx.get(target, params=params, **kwargs)
        except Exception as err:
            logger.error(err)
            attempts_left -= 1
    return None
|
|
|
|
|
|
|
|
|
|
async def aget(
    self,
    url: str,
    params: Union[Dict, List[Tuple], bytes] = None,
    retry: int = 3,
    stream: bool = False,
    **kwargs: Any,
) -> Union[httpx.Response, None]:
    """Asynchronous GET through httpx, retrying on any exception.

    Args:
        url: Absolute URL or a path resolved by ``_parse_url``.
        params: Query parameters forwarded to httpx.
        retry: Maximum number of attempts (non-stream mode only).
        stream: When true, return an async context manager that yields the
            streaming response; use it as
            ``async with await api.aget(..., stream=True) as r:``.
        **kwargs: Extra httpx arguments; ``timeout`` defaults to
            ``self.timeout``.

    Returns:
        In non-stream mode the ``httpx.Response`` of the first successful
        attempt, or ``None`` when every attempt raised (failures are logged).
        In stream mode an async context manager (no request is sent until it
        is entered, so ``retry`` does not apply).
    """
    url = self._parse_url(url)
    kwargs.setdefault("timeout", self.timeout)

    if stream:
        # ``AsyncClient.stream`` returns an async context manager, not an
        # awaitable, so it must not be awaited.  The client has to outlive
        # this call, so the returned manager owns (and closes) its own one.
        @contextlib.asynccontextmanager
        async def _open_stream():
            async with httpx.AsyncClient() as client:
                async with client.stream("GET", url, params=params, **kwargs) as response:
                    yield response

        return _open_stream()

    async with httpx.AsyncClient() as client:
        while retry > 0:
            try:
                return await client.get(url, params=params, **kwargs)
            except Exception as e:
                logger.error(e)
                retry -= 1
    return None
|
|
|
|
|
|
|
|
|
|
def post(
    self,
    url: str,
    data: Dict = None,
    json: Dict = None,
    retry: int = 3,
    stream: bool = False,
    **kwargs: Any
) -> Union[httpx.Response, None]:
    """Send a POST request through httpx, retrying on any exception.

    Args:
        url: Absolute URL or a path resolved by ``_parse_url``.
        data: Form payload forwarded to httpx.
        json: JSON payload forwarded to httpx.
        retry: Maximum number of attempts.
        stream: When true, return the result of ``httpx.stream`` (a context
            manager) instead of a completed response.
        **kwargs: Extra httpx arguments; ``timeout`` defaults to
            ``self.timeout``.

    Returns:
        The httpx result of the first successful attempt, or ``None`` when
        every attempt raised (each failure is logged).
    """
    target = self._parse_url(url)
    kwargs.setdefault("timeout", self.timeout)

    attempts_left = retry
    while attempts_left > 0:
        try:
            if stream:
                return httpx.stream("POST", target, data=data, json=json, **kwargs)
            return httpx.post(target, data=data, json=json, **kwargs)
        except Exception as err:
            logger.error(err)
            attempts_left -= 1
    return None
|
|
|
|
|
|
|
|
|
|
async def apost(
    self,
    url: str,
    data: Dict = None,
    json: Dict = None,
    retry: int = 3,
    stream: bool = False,
    **kwargs: Any
) -> Union[httpx.Response, None]:
    """Asynchronous POST through httpx, retrying on any exception.

    Args:
        url: Absolute URL or a path resolved by ``_parse_url``.
        data: Form payload forwarded to httpx.
        json: JSON payload forwarded to httpx.
        retry: Maximum number of attempts (non-stream mode only).
        stream: When true, return an async context manager that yields the
            streaming response; use it as
            ``async with await api.apost(..., stream=True) as r:``.
        **kwargs: Extra httpx arguments; ``timeout`` defaults to
            ``self.timeout``.

    Returns:
        In non-stream mode the ``httpx.Response`` of the first successful
        attempt, or ``None`` when every attempt raised (failures are logged).
        In stream mode an async context manager (no request is sent until it
        is entered, so ``retry`` does not apply).
    """
    url = self._parse_url(url)
    kwargs.setdefault("timeout", self.timeout)

    if stream:
        # ``AsyncClient.stream`` returns an async context manager, not an
        # awaitable, so it must not be awaited.  The client has to outlive
        # this call, so the returned manager owns (and closes) its own one.
        @contextlib.asynccontextmanager
        async def _open_stream():
            async with httpx.AsyncClient() as client:
                async with client.stream("POST", url, data=data, json=json, **kwargs) as response:
                    yield response

        return _open_stream()

    async with httpx.AsyncClient() as client:
        while retry > 0:
            try:
                return await client.post(url, data=data, json=json, **kwargs)
            except Exception as e:
                logger.error(e)
                retry -= 1
    return None
|
|
|
|
|
|
|
|
|
|
def delete(
    self,
    url: str,
    data: Dict = None,
    json: Dict = None,
    retry: int = 3,
    stream: bool = False,
    **kwargs: Any
) -> Union[httpx.Response, None]:
    """Send a DELETE request through httpx, retrying on any exception.

    Args:
        url: Absolute URL or a path resolved by ``_parse_url``.
        data: Form payload forwarded to httpx.
        json: JSON payload forwarded to httpx.
        retry: Maximum number of attempts.
        stream: When true, return the result of ``httpx.stream`` (a context
            manager) instead of a completed response.
        **kwargs: Extra httpx arguments; ``timeout`` defaults to
            ``self.timeout``.

    Returns:
        The httpx result of the first successful attempt, or ``None`` when
        every attempt raised (each failure is logged).
    """
    target = self._parse_url(url)
    kwargs.setdefault("timeout", self.timeout)

    attempts_left = retry
    while attempts_left > 0:
        try:
            if stream:
                return httpx.stream("DELETE", target, data=data, json=json, **kwargs)
            return httpx.delete(target, data=data, json=json, **kwargs)
        except Exception as err:
            logger.error(err)
            attempts_left -= 1
    return None
|
|
|
|
|
|
|
|
|
|
async def adelete(
    self,
    url: str,
    data: Dict = None,
    json: Dict = None,
    retry: int = 3,
    stream: bool = False,
    **kwargs: Any
) -> Union[httpx.Response, None]:
    """Asynchronous DELETE through httpx, retrying on any exception.

    Args:
        url: Absolute URL or a path resolved by ``_parse_url``.
        data: Form payload forwarded to httpx.
        json: JSON payload forwarded to httpx.
        retry: Maximum number of attempts (non-stream mode only).
        stream: When true, return an async context manager that yields the
            streaming response; use it as
            ``async with await api.adelete(..., stream=True) as r:``.
        **kwargs: Extra httpx arguments; ``timeout`` defaults to
            ``self.timeout``.

    Returns:
        In non-stream mode the ``httpx.Response`` of the first successful
        attempt, or ``None`` when every attempt raised (failures are logged).
        In stream mode an async context manager (no request is sent until it
        is entered, so ``retry`` does not apply).
    """
    url = self._parse_url(url)
    kwargs.setdefault("timeout", self.timeout)

    if stream:
        # ``AsyncClient.stream`` returns an async context manager, not an
        # awaitable, so it must not be awaited.  The client has to outlive
        # this call, so the returned manager owns (and closes) its own one.
        @contextlib.asynccontextmanager
        async def _open_stream():
            async with httpx.AsyncClient() as client:
                async with client.stream("DELETE", url, data=data, json=json, **kwargs) as response:
                    yield response

        return _open_stream()

    async with httpx.AsyncClient() as client:
        while retry > 0:
            try:
                return await client.delete(url, data=data, json=json, **kwargs)
            except Exception as e:
                logger.error(e)
                retry -= 1
    return None
|
|
|
|
|
|
|
|
|
|
def _fastapi_stream2generator(self, response: StreamingResponse, as_json: bool = False):
    """Expose a FastAPI ``StreamingResponse`` as a synchronous generator.

    Iterates ``response.body_iterator`` via ``iter_over_async`` on the
    current event loop (a new loop is created when none is available).

    Args:
        response: The streaming response returned by a view function.
        as_json: When true, every non-empty chunk is decoded with
            ``json.loads`` before being yielded; otherwise chunks that are
            not blank are yielded unchanged.

    Yields:
        Decoded objects or raw chunks.  Any exception raised while iterating
        is logged with its traceback and ends the generator silently.
    """
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # Raised when this thread has no current event loop.  Catching only
        # RuntimeError keeps KeyboardInterrupt/SystemExit from being eaten.
        loop = asyncio.new_event_loop()

    try:
        for chunk in iter_over_async(response.body_iterator, loop):
            if as_json and chunk:
                yield json.loads(chunk)
            elif chunk.strip():
                yield chunk
    except Exception:
        logger.error(traceback.format_exc())
|
|
|
|
|
|
2023-12-07 20:17:21 +08:00
|
|
|
|
def _stream2generator(self, response: str, as_json: bool =False):
|
|
|
|
|
'''
|
|
|
|
|
将api.py中视图函数返回的StreamingResponse转化为同步生成器
|
|
|
|
|
'''
|
|
|
|
|
try:
|
|
|
|
|
if as_json and response:
|
|
|
|
|
return json.loads(response)
|
|
|
|
|
elif response.strip():
|
|
|
|
|
return response
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(traceback.format_exc())
|
|
|
|
|
|
2023-09-28 10:58:58 +08:00
|
|
|
|
def _httpx_stream2generator(
|
|
|
|
|
self,
|
|
|
|
|
response: contextlib._GeneratorContextManager,
|
|
|
|
|
as_json: bool = False,
|
|
|
|
|
):
|
|
|
|
|
'''
|
|
|
|
|
将httpx.stream返回的GeneratorContextManager转化为普通生成器
|
|
|
|
|
'''
|
|
|
|
|
try:
|
|
|
|
|
with response as r:
|
|
|
|
|
for chunk in r.iter_text(None):
|
|
|
|
|
if as_json and chunk:
|
|
|
|
|
yield json.loads(chunk)
|
|
|
|
|
elif chunk.strip():
|
|
|
|
|
yield chunk
|
|
|
|
|
except httpx.ConnectError as e:
|
|
|
|
|
msg = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。"
|
|
|
|
|
logger.error(msg)
|
|
|
|
|
logger.error(e)
|
|
|
|
|
yield {"code": 500, "msg": msg}
|
|
|
|
|
except httpx.ReadTimeout as e:
|
|
|
|
|
msg = f"API通信超时,请确认已启动FastChat与API服务(详见RADME '5. 启动 API 服务或 Web UI')"
|
|
|
|
|
logger.error(msg)
|
|
|
|
|
logger.error(e)
|
|
|
|
|
yield {"code": 500, "msg": msg}
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(e)
|
|
|
|
|
yield {"code": 500, "msg": str(e)}
|
|
|
|
|
|
|
|
|
|
def chat_chat(
|
|
|
|
|
self,
|
|
|
|
|
query: str,
|
|
|
|
|
history: List[Dict] = [],
|
|
|
|
|
stream: bool = True,
|
|
|
|
|
no_remote_api: bool = None,
|
2024-01-26 14:03:25 +08:00
|
|
|
|
embed_model: str="", embed_model_path: str="", model_device: str="", embed_engine: str="",
|
|
|
|
|
llm_model: str ="", temperature: float= 0.2
|
2023-09-28 10:58:58 +08:00
|
|
|
|
):
|
|
|
|
|
'''
|
|
|
|
|
对应api.py/chat/chat接口
|
|
|
|
|
'''
|
|
|
|
|
if no_remote_api is None:
|
|
|
|
|
no_remote_api = self.no_remote_api
|
|
|
|
|
|
|
|
|
|
data = {
|
|
|
|
|
"query": query,
|
|
|
|
|
"history": history,
|
|
|
|
|
"stream": stream,
|
2024-01-26 14:03:25 +08:00
|
|
|
|
"api_key": os.environ["OPENAI_API_KEY"],
|
|
|
|
|
"api_base_url": os.environ["API_BASE_URL"],
|
|
|
|
|
"embed_model": embed_model,
|
|
|
|
|
"embed_model_path": embed_model_path,
|
|
|
|
|
"embed_engine": embed_engine,
|
|
|
|
|
"model_name": llm_model,
|
|
|
|
|
"temperature": temperature,
|
|
|
|
|
"model_device": model_device,
|
|
|
|
|
"temperature": temperature,
|
2023-09-28 10:58:58 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if no_remote_api:
|
|
|
|
|
response = self.llmChat.chat(**data)
|
|
|
|
|
return self._fastapi_stream2generator(response, as_json=True)
|
|
|
|
|
else:
|
|
|
|
|
response = self.post("/chat/chat", json=data, stream=True)
|
|
|
|
|
return self._httpx_stream2generator(response)
|
|
|
|
|
|
|
|
|
|
def knowledge_base_chat(
|
|
|
|
|
self,
|
|
|
|
|
query: str,
|
|
|
|
|
knowledge_base_name: str,
|
2024-01-26 14:03:25 +08:00
|
|
|
|
top_k: int = 5,
|
|
|
|
|
score_threshold: float = 1.0,
|
2023-09-28 10:58:58 +08:00
|
|
|
|
history: List[Dict] = [],
|
|
|
|
|
stream: bool = True,
|
|
|
|
|
no_remote_api: bool = None,
|
2024-01-26 14:03:25 +08:00
|
|
|
|
embed_model: str="", embed_model_path: str="", model_device: str="", embed_engine: str="",
|
|
|
|
|
llm_model: str ="", temperature: float= 0.2
|
2023-09-28 10:58:58 +08:00
|
|
|
|
):
|
|
|
|
|
'''
|
|
|
|
|
对应api.py/chat/knowledge_base_chat接口
|
|
|
|
|
'''
|
|
|
|
|
if no_remote_api is None:
|
|
|
|
|
no_remote_api = self.no_remote_api
|
|
|
|
|
|
|
|
|
|
data = {
|
|
|
|
|
"query": query,
|
|
|
|
|
"engine_name": knowledge_base_name,
|
|
|
|
|
"top_k": top_k,
|
|
|
|
|
"score_threshold": score_threshold,
|
|
|
|
|
"history": history,
|
|
|
|
|
"stream": stream,
|
|
|
|
|
"local_doc_url": no_remote_api,
|
2024-01-26 14:03:25 +08:00
|
|
|
|
"api_key": os.environ["OPENAI_API_KEY"],
|
|
|
|
|
"api_base_url": os.environ["API_BASE_URL"],
|
|
|
|
|
"embed_model": embed_model,
|
|
|
|
|
"embed_model_path": embed_model_path,
|
|
|
|
|
"embed_engine": embed_engine,
|
|
|
|
|
"model_name": llm_model,
|
|
|
|
|
"temperature": temperature,
|
|
|
|
|
"model_device": model_device,
|
|
|
|
|
"temperature": temperature,
|
2023-09-28 10:58:58 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if no_remote_api:
|
|
|
|
|
response = self.knowledgeChat.chat(**data)
|
|
|
|
|
return self._fastapi_stream2generator(response, as_json=True)
|
|
|
|
|
else:
|
|
|
|
|
response = self.post(
|
|
|
|
|
"/chat/knowledge_base_chat",
|
|
|
|
|
json=data,
|
|
|
|
|
stream=True,
|
|
|
|
|
)
|
|
|
|
|
return self._httpx_stream2generator(response, as_json=True)
|
|
|
|
|
|
|
|
|
|
def search_engine_chat(
|
|
|
|
|
self,
|
|
|
|
|
query: str,
|
|
|
|
|
search_engine_name: str,
|
2023-12-26 11:41:53 +08:00
|
|
|
|
top_k: int,
|
|
|
|
|
history: List[Dict] = [],
|
2023-09-28 10:58:58 +08:00
|
|
|
|
stream: bool = True,
|
|
|
|
|
no_remote_api: bool = None,
|
2024-01-26 14:03:25 +08:00
|
|
|
|
embed_model: str="", embed_model_path: str="", model_device: str="", embed_engine: str="",
|
|
|
|
|
llm_model: str ="", temperature: float= 0.2
|
2023-09-28 10:58:58 +08:00
|
|
|
|
):
|
|
|
|
|
'''
|
|
|
|
|
对应api.py/chat/search_engine_chat接口
|
|
|
|
|
'''
|
|
|
|
|
if no_remote_api is None:
|
|
|
|
|
no_remote_api = self.no_remote_api
|
|
|
|
|
|
|
|
|
|
data = {
|
|
|
|
|
"query": query,
|
|
|
|
|
"engine_name": search_engine_name,
|
2023-12-26 11:41:53 +08:00
|
|
|
|
"top_k": top_k,
|
|
|
|
|
"history": history,
|
2023-09-28 10:58:58 +08:00
|
|
|
|
"stream": stream,
|
2024-01-26 14:03:25 +08:00
|
|
|
|
"api_key": os.environ["OPENAI_API_KEY"],
|
|
|
|
|
"api_base_url": os.environ["API_BASE_URL"],
|
|
|
|
|
"embed_model": embed_model,
|
|
|
|
|
"embed_model_path": embed_model_path,
|
|
|
|
|
"embed_engine": embed_engine,
|
|
|
|
|
"model_name": llm_model,
|
|
|
|
|
"temperature": temperature,
|
|
|
|
|
"model_device": model_device,
|
|
|
|
|
"temperature": temperature,
|
2023-09-28 10:58:58 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if no_remote_api:
|
|
|
|
|
response = self.searchChat.chat(**data)
|
|
|
|
|
return self._fastapi_stream2generator(response, as_json=True)
|
|
|
|
|
else:
|
|
|
|
|
response = self.post(
|
|
|
|
|
"/chat/search_engine_chat",
|
|
|
|
|
json=data,
|
|
|
|
|
stream=True,
|
|
|
|
|
)
|
|
|
|
|
return self._httpx_stream2generator(response, as_json=True)
|
|
|
|
|
|
2023-12-07 20:17:21 +08:00
|
|
|
|
def code_base_chat(
    self,
    query: str,
    code_base_name: str,
    code_limit: int = 1,
    history: Optional[List[Dict]] = None,
    cb_search_type: str = 'tag',
    stream: bool = True,
    no_remote_api: bool = None,
    embed_model: str = "", embed_model_path: str = "", model_device: str = "", embed_engine: str = "",
    llm_model: str = "", temperature: float = 0.2
):
    """Code-base chat; remote mode posts to ``/chat/code_chat``.

    Args:
        query: User message.
        code_base_name: Sent as ``engine_name``.
        code_limit: Forwarded in the payload.
        history: Previous turns; ``None`` (the default) means an empty list.
        cb_search_type: Search mode.  Accepts ``'tag'``, ``'description'``,
            ``'cypher'`` or their Chinese UI labels; anything else falls
            back to ``'tag'``.
        stream: Forwarded in the payload.
        no_remote_api: Overrides ``self.no_remote_api`` when not ``None``;
            the resolved value is also sent as ``local_doc_url``.
        embed_model, embed_model_path, model_device, embed_engine:
            Forwarded in the payload unchanged.
        llm_model: Sent as ``model_name``.
        temperature: Forwarded in the payload.

    Returns:
        A generator of JSON-decoded chunks (both modes).

    Raises:
        KeyError: If ``OPENAI_API_KEY`` or ``API_BASE_URL`` is missing from
            the environment.
    """
    if no_remote_api is None:
        no_remote_api = self.no_remote_api
    # A fresh list per call: a ``[]`` default would be shared by all calls.
    if history is None:
        history = []

    # Normalise UI labels and canonical names to the canonical name.
    cb_search_type = {
        '基于 cypher': 'cypher',
        '基于标签': 'tag',
        '基于描述': 'description',
        'tag': 'tag',
        'description': 'description',
        'cypher': 'cypher'
    }.get(cb_search_type, 'tag')

    data = {
        "query": query,
        "history": history,
        "engine_name": code_base_name,
        "code_limit": code_limit,
        "cb_search_type": cb_search_type,
        "stream": stream,
        "local_doc_url": no_remote_api,
        "api_key": os.environ["OPENAI_API_KEY"],
        "api_base_url": os.environ["API_BASE_URL"],
        "embed_model": embed_model,
        "embed_model_path": embed_model_path,
        "embed_engine": embed_engine,
        "model_name": llm_model,
        "temperature": temperature,
        "model_device": model_device,
    }
    # Never write the API key to the log; log a redacted copy instead.
    logger.info('data={}'.format({**data, "api_key": "***"}))

    if no_remote_api:
        response = self.codeChat.chat(**data)
        return self._fastapi_stream2generator(response, as_json=True)
    else:
        response = self.post(
            "/chat/code_chat",
            json=data,
            stream=True,
        )
        return self._httpx_stream2generator(response, as_json=True)
|
2023-11-07 19:44:47 +08:00
|
|
|
|
|
2023-12-07 20:17:21 +08:00
|
|
|
|
def agent_chat(
    self,
    query: str,
    phase_name: str,
    doc_engine_name: str,
    code_engine_name: str,
    search_engine_name: str,
    top_k: int = 3,
    score_threshold: float = 1.0,
    history: Optional[List[Dict]] = None,
    stream: bool = True,
    local_doc_url: bool = False,
    do_search: bool = False,
    do_doc_retrieval: bool = False,
    do_code_retrieval: bool = False,
    do_tool_retrieval: bool = False,
    choose_tools: Optional[List[str]] = None,
    custom_phase_configs=None,
    custom_chain_configs=None,
    custom_role_configs=None,
    no_remote_api: bool = None,
    history_node_list: Optional[List[str]] = None,
    isDetailed: bool = False,
    upload_file: Union[str, Path, bytes] = "",
    kb_root_path: str = KB_ROOT_PATH,
    embed_model: str = "", embed_model_path: str = "",
    model_device: str = "", embed_engine: str = "",
    temperature: float = 0.2, model_name: str = "",
):
    """Agent chat; remote mode posts to ``/chat/data_chat``.

    Every argument is forwarded in the request payload under its own name,
    together with ``chain_name`` (always ``""``), the API key and base URL
    from the environment, ``JUPYTER_WORK_PATH`` and ``SANDBOX_SERVER``.
    List/dict arguments default to ``None``, which stands for a fresh empty
    list/dict.

    Args:
        no_remote_api: Overrides ``self.no_remote_api`` when not ``None``.

    Returns:
        A generator: JSON-decoded chunks from ``self.agentChat.chat`` in
        local mode, raw text chunks from the HTTP stream in remote mode.

    Raises:
        KeyError: If ``OPENAI_API_KEY`` or ``API_BASE_URL`` is missing from
            the environment.
    """
    if no_remote_api is None:
        no_remote_api = self.no_remote_api

    # Fresh containers per call instead of shared mutable defaults.
    history = [] if history is None else history
    choose_tools = [] if choose_tools is None else choose_tools
    history_node_list = [] if history_node_list is None else history_node_list
    custom_phase_configs = {} if custom_phase_configs is None else custom_phase_configs
    custom_chain_configs = {} if custom_chain_configs is None else custom_chain_configs
    custom_role_configs = {} if custom_role_configs is None else custom_role_configs

    data = {
        "query": query,
        "phase_name": phase_name,
        "chain_name": "",
        "history": history,
        "doc_engine_name": doc_engine_name,
        "code_engine_name": code_engine_name,
        "search_engine_name": search_engine_name,
        "top_k": top_k,
        "score_threshold": score_threshold,
        "stream": stream,
        "local_doc_url": local_doc_url,
        "do_search": do_search,
        "do_doc_retrieval": do_doc_retrieval,
        "do_code_retrieval": do_code_retrieval,
        "do_tool_retrieval": do_tool_retrieval,
        "custom_phase_configs": custom_phase_configs,
        # Fixed: this key used to be filled from custom_phase_configs.
        "custom_chain_configs": custom_chain_configs,
        "custom_role_configs": custom_role_configs,
        "choose_tools": choose_tools,
        "history_node_list": history_node_list,
        "isDetailed": isDetailed,
        "upload_file": upload_file,
        "kb_root_path": kb_root_path,
        "api_key": os.environ["OPENAI_API_KEY"],
        "api_base_url": os.environ["API_BASE_URL"],
        "embed_model": embed_model,
        "embed_model_path": embed_model_path,
        "embed_engine": embed_engine,
        "model_device": model_device,
        "model_name": model_name,
        "temperature": temperature,
        "jupyter_work_path": JUPYTER_WORK_PATH,
        "sandbox_server": SANDBOX_SERVER,
    }
    if no_remote_api:
        response = self.agentChat.chat(**data)
        return self._fastapi_stream2generator(response, as_json=True)
    else:
        response = self.post("/chat/data_chat", json=data, stream=True)
        return self._httpx_stream2generator(response)
|
|
|
|
|
|
2023-12-07 20:17:21 +08:00
|
|
|
|
def agent_achat(
    self,
    query: str,
    phase_name: str,
    doc_engine_name: str,
    code_engine_name: str,
    cb_search_type: str,
    search_engine_name: str,
    top_k: int = 3,
    score_threshold: float = 1.0,
    history: Optional[List[Dict]] = None,
    stream: bool = True,
    local_doc_url: bool = False,
    do_search: bool = False,
    do_doc_retrieval: bool = False,
    do_code_retrieval: bool = False,
    do_tool_retrieval: bool = False,
    choose_tools: Optional[List[str]] = None,
    custom_phase_configs=None,
    custom_chain_configs=None,
    custom_role_configs=None,
    no_remote_api: bool = None,
    history_node_list: Optional[List[str]] = None,
    isDetailed: bool = False,
    upload_file: Union[str, Path, bytes] = "",
    kb_root_path: str = KB_ROOT_PATH,
    embed_model: str = "", embed_model_path: str = "",
    model_device: str = "", embed_engine: str = "",
    temperature: float = 0.2, model_name: str = "",
):
    """Agent chat as a generator; remote mode posts to ``/chat/data_chat``.

    Every argument is forwarded in the request payload under its own name,
    together with ``chain_name`` (always ``""``), the API key and base URL
    from the environment, ``JUPYTER_WORK_PATH`` and ``SANDBOX_SERVER``.
    List/dict arguments default to ``None``, which stands for a fresh empty
    list/dict.

    Args:
        no_remote_api: Overrides ``self.no_remote_api`` when not ``None``.

    Yields:
        Local mode: each item produced by ``self.agentChat.achat``, passed
        through ``_stream2generator`` with JSON decoding.  Remote mode: each
        chunk of the HTTP stream.

    Raises:
        KeyError: If ``OPENAI_API_KEY`` or ``API_BASE_URL`` is missing from
            the environment (raised on first iteration).
    """
    if no_remote_api is None:
        no_remote_api = self.no_remote_api

    # Fresh containers per call instead of shared mutable defaults.
    history = [] if history is None else history
    choose_tools = [] if choose_tools is None else choose_tools
    history_node_list = [] if history_node_list is None else history_node_list
    custom_phase_configs = {} if custom_phase_configs is None else custom_phase_configs
    custom_chain_configs = {} if custom_chain_configs is None else custom_chain_configs
    custom_role_configs = {} if custom_role_configs is None else custom_role_configs

    data = {
        "query": query,
        "phase_name": phase_name,
        "chain_name": "",
        "history": history,
        "doc_engine_name": doc_engine_name,
        "code_engine_name": code_engine_name,
        "cb_search_type": cb_search_type,
        "search_engine_name": search_engine_name,
        "top_k": top_k,
        "score_threshold": score_threshold,
        "stream": stream,
        "local_doc_url": local_doc_url,
        "do_search": do_search,
        "do_doc_retrieval": do_doc_retrieval,
        "do_code_retrieval": do_code_retrieval,
        "do_tool_retrieval": do_tool_retrieval,
        "custom_phase_configs": custom_phase_configs,
        "custom_chain_configs": custom_chain_configs,
        "custom_role_configs": custom_role_configs,
        "choose_tools": choose_tools,
        "history_node_list": history_node_list,
        "isDetailed": isDetailed,
        "upload_file": upload_file,
        "kb_root_path": kb_root_path,
        "api_key": os.environ["OPENAI_API_KEY"],
        "api_base_url": os.environ["API_BASE_URL"],
        "embed_model": embed_model,
        "embed_model_path": embed_model_path,
        "embed_engine": embed_engine,
        "model_device": model_device,
        "model_name": model_name,
        "temperature": temperature,
        "jupyter_work_path": JUPYTER_WORK_PATH,
        "sandbox_server": SANDBOX_SERVER,
    }

    if no_remote_api:
        for response in self.agentChat.achat(**data):
            yield self._stream2generator(response, as_json=True)
    else:
        response = self.post("/chat/data_chat", json=data, stream=True)
        # Delegate to the inner generator so callers receive its chunks,
        # not the generator object itself as a single item.
        yield from self._httpx_stream2generator(response)
|
|
|
|
|
|
2023-09-28 10:58:58 +08:00
|
|
|
|
def _check_httpx_json_response(
|
|
|
|
|
self,
|
|
|
|
|
response: httpx.Response,
|
|
|
|
|
errorMsg: str = f"无法连接API服务器,请确认已执行python server\\api.py",
|
|
|
|
|
) -> Dict:
|
|
|
|
|
'''
|
|
|
|
|
check whether httpx returns correct data with normal Response.
|
|
|
|
|
error in api with streaming support was checked in _httpx_stream2enerator
|
|
|
|
|
'''
|
|
|
|
|
try:
|
|
|
|
|
return response.json()
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(e)
|
|
|
|
|
return {"code": 500, "msg": errorMsg or str(e)}
|
|
|
|
|
|
2023-11-07 19:44:47 +08:00
|
|
|
|
def _check_httpx_file_response(
|
|
|
|
|
self,
|
|
|
|
|
response: httpx.Response,
|
|
|
|
|
errorMsg: str = f"无法连接API服务器,请确认已执行python server\\api.py",
|
|
|
|
|
) -> Dict:
|
|
|
|
|
'''
|
|
|
|
|
check whether httpx returns correct data with normal Response.
|
|
|
|
|
error in api with streaming support was checked in _httpx_stream2enerator
|
|
|
|
|
'''
|
|
|
|
|
try:
|
|
|
|
|
return response.content
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(e)
|
|
|
|
|
return {"code": 500, "msg": errorMsg or str(e)}
|
|
|
|
|
|
2023-09-28 10:58:58 +08:00
|
|
|
|
def list_knowledge_bases(
|
|
|
|
|
self,
|
|
|
|
|
no_remote_api: bool = None,
|
|
|
|
|
):
|
|
|
|
|
'''
|
|
|
|
|
对应api.py/knowledge_base/list_knowledge_bases接口
|
|
|
|
|
'''
|
|
|
|
|
if no_remote_api is None:
|
|
|
|
|
no_remote_api = self.no_remote_api
|
|
|
|
|
|
|
|
|
|
if no_remote_api:
|
|
|
|
|
response = run_async(list_kbs())
|
|
|
|
|
return response.data
|
|
|
|
|
else:
|
|
|
|
|
response = self.get("/knowledge_base/list_knowledge_bases")
|
|
|
|
|
data = self._check_httpx_json_response(response)
|
|
|
|
|
return data.get("data", [])
|
|
|
|
|
|
|
|
|
|
def create_knowledge_base(
|
|
|
|
|
self,
|
|
|
|
|
knowledge_base_name: str,
|
|
|
|
|
vector_store_type: str = "faiss",
|
|
|
|
|
no_remote_api: bool = None,
|
2024-01-26 14:03:25 +08:00
|
|
|
|
kb_root_path: str =KB_ROOT_PATH,
|
|
|
|
|
embed_model: str="", embed_model_path: str="",
|
|
|
|
|
embedding_device: str="", embed_engine: str=""
|
2023-09-28 10:58:58 +08:00
|
|
|
|
):
|
|
|
|
|
'''
|
|
|
|
|
对应api.py/knowledge_base/create_knowledge_base接口
|
|
|
|
|
'''
|
|
|
|
|
if no_remote_api is None:
|
|
|
|
|
no_remote_api = self.no_remote_api
|
|
|
|
|
|
|
|
|
|
data = {
|
|
|
|
|
"knowledge_base_name": knowledge_base_name,
|
|
|
|
|
"vector_store_type": vector_store_type,
|
2024-01-26 14:03:25 +08:00
|
|
|
|
"kb_root_path": kb_root_path,
|
|
|
|
|
"api_key": os.environ["OPENAI_API_KEY"],
|
|
|
|
|
"api_base_url": os.environ["API_BASE_URL"],
|
2023-09-28 10:58:58 +08:00
|
|
|
|
"embed_model": embed_model,
|
2024-01-26 14:03:25 +08:00
|
|
|
|
"embed_model_path": embed_model_path,
|
|
|
|
|
"model_device": embedding_device,
|
|
|
|
|
"embed_engine": embed_engine
|
2023-09-28 10:58:58 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if no_remote_api:
|
|
|
|
|
response = run_async(create_kb(**data))
|
|
|
|
|
return response.dict()
|
|
|
|
|
else:
|
|
|
|
|
response = self.post(
|
|
|
|
|
"/knowledge_base/create_knowledge_base",
|
|
|
|
|
json=data,
|
|
|
|
|
)
|
|
|
|
|
return self._check_httpx_json_response(response)
|
|
|
|
|
|
|
|
|
|
def delete_knowledge_base(
|
|
|
|
|
self,
|
|
|
|
|
knowledge_base_name: str,
|
|
|
|
|
no_remote_api: bool = None,
|
2024-01-26 14:03:25 +08:00
|
|
|
|
kb_root_path: str =KB_ROOT_PATH,
|
2023-09-28 10:58:58 +08:00
|
|
|
|
):
|
|
|
|
|
'''
|
|
|
|
|
对应api.py/knowledge_base/delete_knowledge_base接口
|
|
|
|
|
'''
|
|
|
|
|
if no_remote_api is None:
|
|
|
|
|
no_remote_api = self.no_remote_api
|
2024-01-26 14:03:25 +08:00
|
|
|
|
data = {
|
|
|
|
|
"knowledge_base_name": knowledge_base_name,
|
|
|
|
|
"kb_root_path": kb_root_path,
|
|
|
|
|
}
|
2023-09-28 10:58:58 +08:00
|
|
|
|
if no_remote_api:
|
2024-01-26 14:03:25 +08:00
|
|
|
|
response = run_async(delete_kb(**data))
|
2023-09-28 10:58:58 +08:00
|
|
|
|
return response.dict()
|
|
|
|
|
else:
|
|
|
|
|
response = self.post(
|
|
|
|
|
"/knowledge_base/delete_knowledge_base",
|
|
|
|
|
json=f"{knowledge_base_name}",
|
|
|
|
|
)
|
|
|
|
|
return self._check_httpx_json_response(response)
|
|
|
|
|
|
|
|
|
|
def list_kb_docs(
|
|
|
|
|
self,
|
|
|
|
|
knowledge_base_name: str,
|
|
|
|
|
no_remote_api: bool = None,
|
|
|
|
|
):
|
|
|
|
|
'''
|
|
|
|
|
对应api.py/knowledge_base/list_docs接口
|
|
|
|
|
'''
|
|
|
|
|
if no_remote_api is None:
|
|
|
|
|
no_remote_api = self.no_remote_api
|
|
|
|
|
|
|
|
|
|
if no_remote_api:
|
2024-01-26 14:03:25 +08:00
|
|
|
|
response = run_async(list_docs(knowledge_base_name, kb_root_path=KB_ROOT_PATH))
|
2023-09-28 10:58:58 +08:00
|
|
|
|
return response.data
|
|
|
|
|
else:
|
|
|
|
|
response = self.get(
|
|
|
|
|
"/knowledge_base/list_docs",
|
|
|
|
|
params={"knowledge_base_name": knowledge_base_name}
|
|
|
|
|
)
|
|
|
|
|
data = self._check_httpx_json_response(response)
|
|
|
|
|
return data.get("data", [])
|
|
|
|
|
|
|
|
|
|
def upload_kb_doc(
    self,
    file: Union[str, Path, bytes],
    knowledge_base_name: str,
    filename: str = None,
    override: bool = False,
    not_refresh_vs_cache: bool = False,
    no_remote_api: bool = None,
    kb_root_path: str = KB_ROOT_PATH,
    embed_model: str = "", embed_model_path: str = "",
    model_device: str = "", embed_engine: str = ""
):
    '''
    Upload one document into a knowledge base.

    Counterpart of the api.py ``/knowledge_base/upload_docs`` endpoint.

    :param file: raw ``bytes``, an object exposing ``read()``, or a local
        path (``str`` / ``Path``).
    :param knowledge_base_name: target knowledge base.
    :param filename: name to upload under. Defaults to ``file.name`` for
        file-like objects and to the base name of the path for local paths;
        there is no default for raw bytes.
    :param override: forwarded unchanged to the upload API.
    :param not_refresh_vs_cache: forwarded unchanged to the upload API.
    :param no_remote_api: True calls ``upload_doc`` in-process, False goes
        through HTTP; None falls back to ``self.no_remote_api``.
    :param kb_root_path: knowledge-base root (only used in local mode).
    :param embed_model, embed_model_path, model_device, embed_engine:
        embedding settings (only used in local mode).
    :return: ``response.dict()`` in local mode, otherwise the checked JSON
        payload of the HTTP response.
    :raises KeyError: in local mode when ``OPENAI_API_KEY`` or
        ``API_BASE_URL`` is missing from the environment.
    '''
    if no_remote_api is None:
        no_remote_api = self.no_remote_api

    # Track whether this method opened the handle: only then is it ours to
    # close. Caller-supplied streams are left open for the caller.
    opened_here = False
    if isinstance(file, bytes):  # raw bytes
        file = BytesIO(file)
    elif hasattr(file, "read"):  # a file io like object
        filename = filename or file.name
    else:  # a local path
        local_path = Path(file).absolute()
        file = local_path.open("rb")
        opened_here = True
        # Use the base name: the handle's ``.name`` is the full absolute
        # path, which is not a sensible upload file name.
        filename = filename or local_path.name

    try:
        if no_remote_api:
            # ``os`` is not imported explicitly in the visible file head;
            # import it here rather than rely on a star-import providing it.
            import os
            from fastapi import UploadFile
            from tempfile import SpooledTemporaryFile

            # Copy the content into a spooled temp file so it can be wrapped
            # in an UploadFile like a real multipart upload.
            temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024)
            temp_file.write(file.read())
            temp_file.seek(0)
            response = run_async(upload_doc(
                UploadFile(file=temp_file, filename=filename),
                knowledge_base_name,
                override,
                not_refresh_vs_cache,
                kb_root_path=kb_root_path,
                api_key=os.environ["OPENAI_API_KEY"],
                api_base_url=os.environ["API_BASE_URL"],
                embed_model=embed_model,
                embed_model_path=embed_model_path,
                model_device=model_device,
                embed_engine=embed_engine
            ))
            return response.dict()
        else:
            response = self.post(
                "/knowledge_base/upload_doc",
                data={
                    "knowledge_base_name": knowledge_base_name,
                    "override": override,
                    "not_refresh_vs_cache": not_refresh_vs_cache,
                },
                files={"file": (filename, file)},
            )
            return self._check_httpx_json_response(response)
    finally:
        # Runs after the return value is computed, so the content has
        # already been consumed by the time the handle is released.
        if opened_here:
            file.close()
|
|
|
|
|
|
|
|
|
|
def delete_kb_doc(
    self,
    knowledge_base_name: str,
    doc_name: str,
    delete_content: bool = False,
    not_refresh_vs_cache: bool = False,
    no_remote_api: bool = None,
    kb_root_path: str = KB_ROOT_PATH,
    embed_model: str = "", embed_model_path: str = "",
    model_device: str = "", embed_engine: str = ""
):
    '''
    Delete one document from a knowledge base.

    Counterpart of the api.py ``/knowledge_base/delete_doc`` endpoint. The
    same argument set is used for both modes: as keyword arguments to
    ``delete_doc`` locally, or as the JSON body of the HTTP request.

    :param knowledge_base_name: knowledge base holding the document.
    :param doc_name: name of the document to delete.
    :param delete_content: forwarded unchanged to the delete API.
    :param not_refresh_vs_cache: forwarded unchanged to the delete API.
    :param no_remote_api: True runs in-process, False goes through HTTP;
        None falls back to ``self.no_remote_api``.
    :param kb_root_path: knowledge-base root forwarded to the API.
    :param embed_model, embed_model_path, model_device, embed_engine:
        embedding settings forwarded to the API.
    :return: ``response.dict()`` in local mode, otherwise the checked JSON
        payload of the HTTP response.
    :raises KeyError: when ``OPENAI_API_KEY`` or ``API_BASE_URL`` is missing
        from the environment (read in both modes).
    '''
    use_local = self.no_remote_api if no_remote_api is None else no_remote_api

    payload = dict(
        knowledge_base_name=knowledge_base_name,
        doc_name=doc_name,
        delete_content=delete_content,
        not_refresh_vs_cache=not_refresh_vs_cache,
        kb_root_path=kb_root_path,
        api_key=os.environ["OPENAI_API_KEY"],
        api_base_url=os.environ["API_BASE_URL"],
        embed_model=embed_model,
        embed_model_path=embed_model_path,
        model_device=model_device,
        embed_engine=embed_engine,
    )

    if use_local:
        local_result = run_async(delete_doc(**payload))
        return local_result.dict()

    http_response = self.post("/knowledge_base/delete_doc", json=payload)
    return self._check_httpx_json_response(http_response)
|
|
|
|
|
|
|
|
|
|
def update_kb_doc(
    self,
    knowledge_base_name: str,
    file_name: str,
    not_refresh_vs_cache: bool = False,
    no_remote_api: bool = None,
    embed_model: str = "", embed_model_path: str = "",
    model_device: str = "", embed_engine: str = ""
):
    '''
    Update a document that already belongs to a knowledge base.

    Counterpart of the api.py ``/knowledge_base/update_doc`` endpoint.

    :param knowledge_base_name: knowledge base holding the file.
    :param file_name: name of the file to update.
    :param not_refresh_vs_cache: forwarded unchanged to the update API.
    :param no_remote_api: True calls ``update_doc`` in-process, False goes
        through HTTP; None falls back to ``self.no_remote_api``.
    :param embed_model, embed_model_path, model_device, embed_engine:
        embedding settings; only the local call receives them, the HTTP
        request body does not carry them.
    :return: ``response.dict()`` in local mode, otherwise the checked JSON
        payload of the HTTP response.
    :raises KeyError: in local mode when ``OPENAI_API_KEY`` or
        ``API_BASE_URL`` is missing from the environment.
    '''
    use_local = self.no_remote_api if no_remote_api is None else no_remote_api

    if not use_local:
        request_body = {
            "knowledge_base_name": knowledge_base_name,
            "file_name": file_name,
            "not_refresh_vs_cache": not_refresh_vs_cache,
        }
        http_response = self.post("/knowledge_base/update_doc", json=request_body)
        return self._check_httpx_json_response(http_response)

    # Local mode always uses the module-level KB_ROOT_PATH.
    local_options = dict(
        kb_root_path=KB_ROOT_PATH,
        api_key=os.environ["OPENAI_API_KEY"],
        api_base_url=os.environ["API_BASE_URL"],
        embed_model=embed_model,
        embed_model_path=embed_model_path,
        model_device=model_device,
        embed_engine=embed_engine,
    )
    local_result = run_async(update_doc(
        knowledge_base_name, file_name, not_refresh_vs_cache, **local_options
    ))
    return local_result.dict()
|
|
|
|
|
|
|
|
|
|
def recreate_vector_store(
    self,
    knowledge_base_name: str,
    allow_empty_kb: bool = True,
    vs_type: str = "faiss",
    no_remote_api: bool = None,
    kb_root_path: str = KB_ROOT_PATH,
    embed_model: str = "", embed_model_path: str = "",
    embedding_device: str = "", embed_engine: str = ""
):
    '''
    Rebuild the vector store of a knowledge base and stream the progress.

    Counterpart of the api.py ``/knowledge_base/recreate_vector_store``
    endpoint.

    :param knowledge_base_name: knowledge base to rebuild.
    :param allow_empty_kb: forwarded unchanged to the API.
    :param vs_type: vector-store type forwarded to the API.
    :param no_remote_api: True runs in-process, False goes through HTTP;
        None falls back to ``self.no_remote_api``.
    :param kb_root_path: knowledge-base root forwarded to the API.
    :param embed_model, embed_model_path, embed_engine: embedding settings
        forwarded to the API.
    :param embedding_device: forwarded under the key ``model_device``.
    :return: a generator built by ``_fastapi_stream2generator`` (local) or
        ``_httpx_stream2generator`` (remote), both with ``as_json=True``.
    :raises KeyError: when ``OPENAI_API_KEY`` or ``API_BASE_URL`` is missing
        from the environment (read in both modes).
    '''
    use_local = self.no_remote_api if no_remote_api is None else no_remote_api

    payload = dict(
        knowledge_base_name=knowledge_base_name,
        allow_empty_kb=allow_empty_kb,
        vs_type=vs_type,
        kb_root_path=kb_root_path,
        api_key=os.environ["OPENAI_API_KEY"],
        api_base_url=os.environ["API_BASE_URL"],
        embed_model=embed_model,
        embed_model_path=embed_model_path,
        model_device=embedding_device,
        embed_engine=embed_engine,
    )

    if use_local:
        # The bare name below is looked up at module scope, so it refers to
        # the API-level function (presumably star-imported), not this method.
        local_stream = run_async(recreate_vector_store(**payload))
        return self._fastapi_stream2generator(local_stream, as_json=True)

    # No timeout on the streaming request.
    http_stream = self.post(
        "/knowledge_base/recreate_vector_store",
        json=payload,
        stream=True,
        timeout=None,
    )
    return self._httpx_stream2generator(http_stream, as_json=True)
|
|
|
|
|
|
|
|
|
|
def web_crawl(
    self,
    base_url: str,
    html_dir: str,
    text_dir: str,
    do_dfs: bool = False,
    reptile_lib: str = "requests",
    method: str = "get",
    time_sleep: float = 2,
    no_remote_api: bool = None
):
    '''
    Crawl ``base_url`` with ``WebCrawler``.

    Only the local (in-process) mode exists; there is no HTTP counterpart.

    :param base_url: URL to start crawling from.
    :param html_dir: forwarded to the crawler as ``html_dir``.
    :param text_dir: forwarded to the crawler as ``text_dir``.
    :param do_dfs: False uses ``webcrawler_single``; True uses
        ``webcrawler_1_degree``.
    :param reptile_lib: forwarded to the crawler as ``reptile_lib``.
    :param method: forwarded to the crawler as ``method``.
    :param time_sleep: forwarded to the crawler as ``time_sleep``.
    :param no_remote_api: must resolve to True; None falls back to
        ``self.no_remote_api``.
    :return: ``{"status": 200, "response": "success"}`` on success, or
        ``{"status": 500, "response": <error text>}`` when the crawl raised.
    :raises NotImplementedError: when remote mode is requested.
    '''
    async def _web_crawl(html_dir, text_dir, base_url, reptile_lib, method, time_sleep, do_dfs):
        wc = WebCrawler()
        try:
            # Both crawler entry points take the same keyword arguments, so
            # pick the one to call and invoke it once.
            crawl = wc.webcrawler_1_degree if do_dfs else wc.webcrawler_single
            crawl(
                html_dir=html_dir,
                text_dir=text_dir,
                base_url=base_url,
                reptile_lib=reptile_lib,
                method=method,
                time_sleep=time_sleep,
            )
            return {"status": 200, "response": "success"}
        except Exception as e:
            # Deliberate best-effort: report the failure to the caller as a
            # status dict instead of propagating it.
            return {"status": 500, "response": str(e)}

    if no_remote_api is None:
        no_remote_api = self.no_remote_api

    if not no_remote_api:
        # NotImplementedError is still an Exception subclass, so callers
        # catching Exception keep working.
        raise NotImplementedError("web_crawl is not implemented for the remote API")

    data = {
        "base_url": base_url,
        "html_dir": html_dir,
        "text_dir": text_dir,
        "do_dfs": do_dfs,
        "reptile_lib": reptile_lib,
        "method": method,
        "time_sleep": time_sleep,
    }
    return run_async(_web_crawl(**data))
|
|
|
|
|
|
2023-11-07 19:44:47 +08:00
|
|
|
|
def web_sd_upload(self, file: str = None, filename: str = None):
    '''
    Upload a file to the ``/sdfiles/upload`` endpoint
    (file_service/sd_upload_file).

    ``file`` and ``filename`` are first passed through ``file_normalize``;
    the normalized pair is sent as the multipart ``file`` field.

    :return: the checked JSON payload of the HTTP response.
    '''
    normalized_file, normalized_name = file_normalize(file, filename)
    multipart = {"file": (normalized_name, normalized_file)}
    http_response = self.post("/sdfiles/upload", files=multipart)
    return self._check_httpx_json_response(http_response)
|
|
|
|
|
|
|
|
|
|
def web_sd_download(self, filename: str, save_filename: str = None):
    '''
    Download a file from the ``/sdfiles/download`` endpoint
    (file_service/sd_download_file).

    :param filename: name of the file to fetch.
    :param save_filename: name to report back; defaults to ``filename``.
    :return: ``(content, save_filename)``. ``content`` is the ``"data"``
        field of the JSON response passed through ``file_normalize``, or
        ``""`` when ``filename`` is falsy.
    '''
    target_name = save_filename or filename
    query = {"filename": filename, "save_filename": target_name}
    # The request is sent even when ``filename`` is falsy.
    http_response = self.get("/sdfiles/download", params=query)

    if not filename:
        return "", target_name

    content, _ = file_normalize(http_response.json()["data"])
    return content, target_name
|
2023-11-07 19:44:47 +08:00
|
|
|
|
|
|
|
|
|
def web_sd_delete(self, filename: str):
    '''
    Delete a file through the ``/sdfiles/delete`` endpoint
    (file_service/sd_delete_file).

    :param filename: name of the file to delete, sent as a query parameter.
    :return: the checked JSON payload of the HTTP response.
    '''
    query = {"filename": filename}
    http_response = self.get("/sdfiles/delete", params=query)
    return self._check_httpx_json_response(http_response)
|
|
|
|
|
|
|
|
|
|
def web_sd_list_files(self):
    '''
    List files through the ``/sdfiles/list`` endpoint
    (file_service/sd_list_files).

    :return: the checked JSON payload of the HTTP response.
    '''
    http_response = self.get("/sdfiles/list")
    return self._check_httpx_json_response(http_response)
|
|
|
|
|
|
|
|
|
|
# code base 相关操作
|
2024-01-26 14:03:25 +08:00
|
|
|
|
def create_code_base(self, cb_name, zip_file, do_interpret: bool, no_remote_api: bool = None,
                     embed_model: str = "", embed_model_path: str = "", embedding_device: str = "", embed_engine: str = "",
                     llm_model: str = "", temperature: float = 0.2
                     ):
    '''
    Create a code base.

    Ensures ``<self.cb_root_path>/<cb_name>/raw_code`` exists, then creates
    the code base either in-process (``create_cb``) or through the
    ``/code_base/create_code_base`` endpoint.

    @param cb_name: name of the code base; also its directory name.
    @param zip_file: forwarded to the API as ``zip_file``.
    @param do_interpret: forwarded to the API as ``do_interpret``.
    @param no_remote_api: True runs in-process, False goes through HTTP;
        None falls back to ``self.no_remote_api``.
    @param embed_model, embed_model_path, embed_engine: embedding settings
        forwarded to the API.
    @param embedding_device: forwarded under the key ``model_device``.
    @param llm_model: forwarded under the key ``model_name``.
    @param temperature: forwarded to the API as ``temperature``.
    @return: ``response.dict()`` in local mode, otherwise the checked JSON
        payload of the HTTP response.
    @raise KeyError: when ``OPENAI_API_KEY`` or ``API_BASE_URL`` is missing
        from the environment.
    '''
    # ``os`` is not imported explicitly in the visible file head; import it
    # here rather than rely on a star-import providing it.
    import os

    if no_remote_api is None:
        no_remote_api = self.no_remote_api

    # mkdir: makedirs creates the intermediate directories (root and
    # root/cb_name) itself, so one call on the deepest path is enough.
    raw_code_path = os.sep.join([self.cb_root_path, cb_name, 'raw_code'])
    os.makedirs(raw_code_path, exist_ok=True)

    data = {
        "zip_file": zip_file,
        "cb_name": cb_name,
        "code_path": raw_code_path,
        "do_interpret": do_interpret,
        "api_key": os.environ["OPENAI_API_KEY"],
        "api_base_url": os.environ["API_BASE_URL"],
        "embed_model": embed_model,
        "embed_model_path": embed_model_path,
        "embed_engine": embed_engine,
        "model_name": llm_model,
        "temperature": temperature,
        "model_device": embedding_device,
    }
    # Never write the API key to the log: mask it in the logged copy only.
    loggable = {k: ("***" if k == "api_key" else v) for k, v in data.items()}
    logger.info('create cb data={}'.format(loggable))

    if no_remote_api:
        response = run_async(create_cb(**data))
        return response.dict()
    else:
        response = self.post(
            "/code_base/create_code_base",
            json=data,
        )
        # Log the checked result rather than calling response.json()
        # directly, so a failed or non-JSON response goes through the
        # checker first (assumes the checker handles those cases; confirm).
        result = self._check_httpx_json_response(response)
        logger.info('response={}'.format(result))
        return result
|
|
|
|
|
|
2024-01-26 14:03:25 +08:00
|
|
|
|
def delete_code_base(self, cb_name: str, no_remote_api: bool = None,
                     embed_model: str = "", embed_model_path: str = "", embedding_device: str = "", embed_engine: str = "",
                     llm_model: str = "", temperature: float = 0.2
                     ):
    '''
    Delete a code base.

    @param cb_name: name of the code base to delete.
    @param no_remote_api: True calls ``delete_cb`` in-process, False goes
        through HTTP; None falls back to ``self.no_remote_api``.
    @param embed_model, embed_model_path, embed_engine: embedding settings
        passed to the local call.
    @param embedding_device: passed under the key ``model_device``.
    @param llm_model: passed under the key ``model_name``.
    @param temperature: passed to the local call as ``temperature``.
    @return: ``response.dict()`` in local mode, otherwise the checked JSON
        payload of the HTTP response.
    @raise KeyError: when ``OPENAI_API_KEY`` or ``API_BASE_URL`` is missing
        from the environment (read in both modes).
    '''
    # ``os`` is not imported explicitly in the visible file head; import it
    # here rather than rely on a star-import providing it.
    import os

    if no_remote_api is None:
        no_remote_api = self.no_remote_api

    data = {
        "cb_name": cb_name,
        "api_key": os.environ["OPENAI_API_KEY"],
        "api_base_url": os.environ["API_BASE_URL"],
        "embed_model": embed_model,
        "embed_model_path": embed_model_path,
        "embed_engine": embed_engine,
        "model_name": llm_model,
        "temperature": temperature,
        "model_device": embedding_device
    }

    if no_remote_api:
        response = run_async(delete_cb(**data))
        return response.dict()
    else:
        # NOTE(review): only the bare name is posted here, while the local
        # call receives the whole ``data`` dict. Confirm this matches what
        # the endpoint expects.
        response = self.post(
            "/code_base/delete_code_base",
            json=cb_name
        )
        # Log the checked result rather than calling response.json()
        # directly, so a failed or non-JSON response goes through the
        # checker first (assumes the checker handles those cases; confirm).
        result = self._check_httpx_json_response(response)
        logger.info(result)
        return result
|
|
|
|
|
|
|
|
|
|
def list_cb(self, no_remote_api: bool = None):
    '''
    List the existing code bases.

    @param no_remote_api: True calls ``list_cbs`` in-process, False goes
        through HTTP; None falls back to ``self.no_remote_api``.
    @return: ``response.data`` in local mode; in remote mode the ``"data"``
        entry of the checked JSON payload, or ``[]`` when it is absent.
    '''
    use_local = self.no_remote_api if no_remote_api is None else no_remote_api

    if not use_local:
        http_response = self.get("/code_base/list_code_bases")
        payload = self._check_httpx_json_response(http_response)
        return payload.get("data", [])

    local_result = run_async(list_cbs())
    return local_result.data
|
|
|
|
|
|
|
|
|
|
|
2023-09-28 10:58:58 +08:00
|
|
|
|
|
|
|
|
|
def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str:
    '''
    Extract the error message from an API result, if there is one.

    For a dict, ``data[key]`` wins when ``key`` is present; otherwise
    ``data["msg"]`` is returned when a ``"code"`` entry exists and is not
    200. Everything else (non-dict input included) yields ``""``.
    '''
    if not isinstance(data, dict):
        return ""

    if key in data:
        return data[key]

    # A missing "code" is treated the same as a 200.
    if data.get("code", 200) != 200:
        return data["msg"]

    return ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_success_msg(data: Union[str, dict, list], key: str = "msg") -> str:
    '''
    Extract the success message from an API result, if there is one.

    Returns ``data[key]`` only when ``data`` is a dict that contains ``key``
    and whose ``"code"`` entry equals 200; otherwise returns ``""``.
    '''
    if not isinstance(data, dict):
        return ""

    if key not in data:
        return ""

    # A missing "code" never counts as success.
    if data.get("code") == 200:
        return data[key]

    return ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
api = ApiRequest(no_remote_api=True)
|
|
|
|
|
|
|
|
|
|
# print(api.chat_fastchat(
|
|
|
|
|
# messages=[{"role": "user", "content": "hello"}]
|
|
|
|
|
# ))
|
|
|
|
|
|
|
|
|
|
# with api.chat_chat("你好") as r:
|
|
|
|
|
# for t in r.iter_text(None):
|
|
|
|
|
# print(t)
|
|
|
|
|
|
|
|
|
|
# r = api.chat_chat("你好", no_remote_api=True)
|
|
|
|
|
# for t in r:
|
|
|
|
|
# print(t)
|
|
|
|
|
|
|
|
|
|
# r = api.duckduckgo_search_chat("室温超导最新研究进展", no_remote_api=True)
|
|
|
|
|
# for t in r:
|
|
|
|
|
# print(t)
|
|
|
|
|
|
|
|
|
|
# print(api.list_knowledge_bases())
|