codefuse-chatbot/examples/api.py

214 lines
6.8 KiB
Python
Raw Permalink Normal View History

2023-09-28 10:58:58 +08:00
import nltk
import argparse
import uvicorn, os, sys
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import RedirectResponse
from typing import List
# Make the project's packages importable when this example is run directly as
# a script: add the directory three levels above this file and the directory
# two levels above it to sys.path (presumably one of them contains the
# `configs` and `muagent` packages imported below — confirm against the repo
# layout).
src_dir = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
sys.path.append(src_dir)
# Resolve through abspath: on Python < 3.9 a script launched as `python api.py`
# has a relative __file__, and dirname(dirname("api.py")) == "" (the CWD)
# rather than the intended parent directory.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from configs.model_config import VERSION
# from configs.model_config import NLTK_DATA_PATH
# from configs.server_config import OPEN_CROSS_DOMAIN
from muagent.chat import LLMChat, SearchChat, KnowledgeChat
from muagent.service.kb_api import *
from muagent.service.cb_api import *
from muagent.utils.server_utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline
# nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
2023-09-28 10:58:58 +08:00
from muagent.chat import LLMChat, SearchChat, KnowledgeChat, CodeChat
from configs.model_config import KB_ROOT_PATH
2023-09-28 10:58:58 +08:00
# Module-level chat handlers, created once at import time. create_app()
# registers their bound `.chat` methods as POST endpoints, so every request
# handled by this process goes through these same instances.
llmChat = LLMChat()
searchChat = SearchChat()
# The knowledge-base handler is given KB_ROOT_PATH (from configs.model_config)
# as its knowledge-base root directory.
knowledgeChat = KnowledgeChat(kb_root_path=KB_ROOT_PATH)
codeChat = CodeChat()
2023-09-28 10:58:58 +08:00
async def document():
    """Send clients that hit the site root to the Swagger UI.

    Registered on ``GET /`` by ``create_app``.

    Returns:
        A redirect response pointing at ``/docs``.
    """
    swagger_path = "/docs"
    response = RedirectResponse(url=swagger_path)
    return response
def create_app(open_cross_domain: bool = False):
    """Build the DevOps-ChatBot FastAPI application and register its routes.

    Registered route groups:
      * ``GET /`` redirects to the Swagger UI (``document``).
      * ``/chat/*`` serves LLM, knowledge-base, search-engine and code-base
        chat through the module-level chat handler instances.
      * ``/knowledge_base/*`` covers knowledge-base management: create, delete
        and list; document list, search, upload, delete, update and download;
        vector-store rebuild.
      * ``/code_base/*`` covers code-base management: create, delete, list
        and chat.

    Args:
        open_cross_domain: When True, install a CORS middleware that accepts
            every origin, method and header with credentials allowed. Defaults
            to False, which installs no CORS middleware. Callers can wire this
            to the ``OPEN_CROSS_DOMAIN`` setting in ``configs.server_config``.

    Returns:
        The configured FastAPI application.
    """
    app = FastAPI(
        title="DevOps-ChatBot API Server",
        version=VERSION,
    )
    # Presumably serves the Swagger/ReDoc assets locally rather than from a
    # CDN — confirm in muagent.utils.server_utils.
    MakeFastAPIOffline(app)

    if open_cross_domain:
        # NOTE(review): per Starlette's CORSMiddleware, allow_origins=["*"]
        # together with allow_credentials=True reflects the caller's Origin.
        # That lets any site make credentialed requests, so only enable this
        # on trusted networks.
        app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

    # Root: "swagger docs".
    app.get("/",
            response_model=BaseResponse,
            summary="swagger 文档")(document)

    # Tag: Chat
    # Chat with the LLM (via LLMChain).
    app.post("/chat/chat",
             tags=["Chat"],
             summary="与llm模型对话(通过LLMChain)")(llmChat.chat)
    # Chat against a knowledge base.
    app.post("/chat/knowledge_base_chat",
             tags=["Chat"],
             summary="与知识库对话")(knowledgeChat.chat)
    # Chat backed by a search engine.
    app.post("/chat/search_engine_chat",
             tags=["Chat"],
             summary="与搜索引擎对话")(searchChat.chat)
    # Chat against a code base.
    app.post("/chat/code_chat",
             tags=["Chat"],
             summary="与代码库对话")(codeChat.chat)

    # Tag: Knowledge Base Management
    # List knowledge bases.
    app.get("/knowledge_base/list_knowledge_bases",
            tags=["Knowledge Base Management"],
            response_model=ListResponse,
            summary="获取知识库列表")(list_kbs)
    # Create a knowledge base.
    app.post("/knowledge_base/create_knowledge_base",
             tags=["Knowledge Base Management"],
             response_model=BaseResponse,
             summary="创建知识库"
             )(create_kb)
    # Delete a knowledge base.
    app.post("/knowledge_base/delete_knowledge_base",
             tags=["Knowledge Base Management"],
             response_model=BaseResponse,
             summary="删除知识库"
             )(delete_kb)
    # List the files inside a knowledge base.
    app.get("/knowledge_base/list_files",
            tags=["Knowledge Base Management"],
            response_model=ListResponse,
            summary="获取知识库内的文件列表"
            )(list_docs)
    # Search a knowledge base.
    app.post("/knowledge_base/search_docs",
             tags=["Knowledge Base Management"],
             response_model=List[DocumentWithScore],
             summary="搜索知识库"
             )(search_docs)
    # Upload files to a knowledge base and/or vectorize them.
    app.post("/knowledge_base/upload_docs",
             tags=["Knowledge Base Management"],
             response_model=BaseResponse,
             summary="上传文件到知识库,并/或进行向量化"
             )(upload_doc)
    # Delete the given files from a knowledge base.
    app.post("/knowledge_base/delete_docs",
             tags=["Knowledge Base Management"],
             response_model=BaseResponse,
             summary="删除知识库内指定文件"
             )(delete_doc)
    # Update existing files in a knowledge base.
    app.post("/knowledge_base/update_docs",
             tags=["Knowledge Base Management"],
             response_model=BaseResponse,
             summary="更新现有文件到知识库"
             )(update_doc)
    # Download the corresponding knowledge file.
    app.get("/knowledge_base/download_doc",
            tags=["Knowledge Base Management"],
            summary="下载对应的知识文件")(download_doc)
    # Rebuild the vector store from the documents in content, streaming progress.
    app.post("/knowledge_base/recreate_vector_store",
             tags=["Knowledge Base Management"],
             summary="根据content中文档重建向量库流式输出处理进度。"
             )(recreate_vector_store)

    # Tag: Code Base Management (create / delete / chat / list).
    app.post("/code_base/create_code_base",
             tags=["Code Base Management"],
             summary="新建 code_base"
             )(create_cb)
    app.post("/code_base/delete_code_base",
             tags=["Code Base Management"],
             summary="删除 code_base"
             )(delete_cb)
    app.post("/code_base/code_base_chat",
             tags=["Code Base Management"],
             summary="code_base 对话"
             )(search_code)
    app.get("/code_base/list_code_bases",
            tags=["Code Base Management"],
            summary="列举 code_base",
            response_model=ListResponse
            )(list_cbs)

    # NOTE: this example does not register the fastchat passthrough
    # (/chat/fastchat) or the LLM model management routes
    # (/llm_model/list_models, /llm_model/stop, /llm_model/change).
    return app
# Module-level application, built once at import time with create_app()'s
# default arguments (no CORS middleware); run_api() serves this instance.
app = create_app()
def run_api(host, port, **kwargs):
    """Serve the module-level FastAPI ``app`` with uvicorn.

    Args:
        host: Interface address to bind, e.g. ``"0.0.0.0"``.
        port: TCP port to listen on.
        **kwargs: Optional ``ssl_keyfile`` and ``ssl_certfile`` paths. TLS is
            enabled when both are given (truthy). Other keyword arguments are
            ignored.

    Raises:
        ValueError: If exactly one of ``ssl_keyfile`` / ``ssl_certfile`` is
            given. This refuses to silently serve plain HTTP when the caller
            clearly asked for TLS.
    """
    ssl_keyfile = kwargs.get("ssl_keyfile")
    ssl_certfile = kwargs.get("ssl_certfile")

    # A half-configured TLS setup is almost certainly a mistake; fail loudly
    # instead of falling back to an unencrypted listener.
    if bool(ssl_keyfile) != bool(ssl_certfile):
        raise ValueError(
            "ssl_keyfile and ssl_certfile must be provided together "
            "(got ssl_keyfile=%r, ssl_certfile=%r)" % (ssl_keyfile, ssl_certfile)
        )

    if ssl_keyfile and ssl_certfile:
        uvicorn.run(app,
                    host=host,
                    port=port,
                    ssl_keyfile=ssl_keyfile,
                    ssl_certfile=ssl_certfile,
                    )
    else:
        uvicorn.run(app, host=host, port=port)
if __name__ == "__main__":
    # Command-line entry point: parse the bind address and optional TLS files,
    # then start the server.
    parser = argparse.ArgumentParser(
        prog='DevOps-ChatBot',
        description='About DevOps-ChatBot, local knowledge based LLM with langchain'
                    ' 基于本地知识库的 LLM 问答',
    )
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7861)
    # TLS is used only when both files are supplied (see run_api).
    parser.add_argument("--ssl_keyfile", type=str)
    parser.add_argument("--ssl_certfile", type=str)

    args = parser.parse_args()
    run_api(
        host=args.host,
        port=args.port,
        ssl_keyfile=args.ssl_keyfile,
        ssl_certfile=args.ssl_certfile,
    )