130 lines
4.5 KiB
Python
130 lines
4.5 KiB
Python
import sys, os, json, traceback, uvicorn, argparse
|
||
|
||
# Append the directory two levels above this file to the import search path.
# NOTE(review): presumably this is what lets the `coagent` / `configs` imports
# below resolve when the file is run as a script -- confirm against the repo layout.
src_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(src_dir)
|
||
|
||
from loguru import logger
|
||
|
||
from fastapi import FastAPI
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from fastapi import File, UploadFile
|
||
|
||
from coagent.utils.server_utils import BaseResponse, ListResponse, DataResponse
|
||
from configs.server_config import OPEN_CROSS_DOMAIN, SDFILE_API_SERVER
|
||
from configs.model_config import JUPYTER_WORK_PATH
|
||
|
||
|
||
# Version string used as the default ``version`` argument of ``create_app``.
VERSION = "v0.1.0"
|
||
|
||
async def sd_upload_file(file: UploadFile = File(...), work_dir: str = JUPYTER_WORK_PATH):
    """Save an uploaded file into ``work_dir``.

    The client-supplied file name is untrusted, so the destination is resolved
    first and must stay inside ``work_dir``. Names such as ``../../x`` or an
    absolute path are rejected instead of being written elsewhere.

    Args:
        file: The uploaded file; its ``filename`` decides the destination name.
        work_dir: Directory the file is written into.

    Returns:
        ``{"data": True}`` when the file was written, ``{"data": False}`` when
        the name is missing or unsafe, or when reading/writing failed.
    """
    if not file.filename:
        # Without a name there is no destination to write to.
        return {"data": False}

    try:
        base_dir = os.path.realpath(work_dir)
        # ``open`` follows symlinks, so compare the fully resolved target.
        target_path = os.path.realpath(os.path.join(base_dir, file.filename))
        if os.path.commonpath([base_dir, target_path]) != base_dir:
            logger.warning("Rejected upload outside of work dir: {!r}", file.filename)
            return {"data": False}

        content = await file.read()
        with open(target_path, "wb") as f:
            f.write(content)
        return {"data": True}
    except Exception:
        # Keep the best-effort contract (report failure as data=False), but do
        # not use a bare ``except``: that would also swallow task cancellation
        # and KeyboardInterrupt. Log the cause so failures are diagnosable.
        logger.exception("Failed to save uploaded file {!r}", file.filename)
        return {"data": False}
|
||
|
||
|
||
async def sd_download_file(filename: str, save_filename: str = "filename_to_download.ext", work_dir: str = JUPYTER_WORK_PATH):
    """Report where a file sits inside ``work_dir`` and the name to save it under.

    No file content is read or streamed here; the response only carries the
    joined path and the suggested download name.

    Args:
        filename: Name of the file, joined onto ``work_dir``.
        save_filename: Name echoed back under the ``"filename"`` key.
        work_dir: Directory the file is looked up in.

    Returns:
        ``{"data": <joined path>, "filename": save_filename}``.
    """
    target_path = os.path.join(work_dir, filename)
    logger.debug(f"{target_path}")
    # A FileResponse-based variant existed only as commented-out code; the
    # path itself is what gets returned.
    return {"data": target_path, "filename": save_filename}
|
||
|
||
|
||
async def sd_list_files(work_dir: str = JUPYTER_WORK_PATH):
    """Return the names of the entries found in ``work_dir``.

    NOTE(review): the previous comment here spoke of excluding directories,
    but the ``os.listdir`` result is returned unfiltered -- confirm which
    behaviour callers expect.

    Args:
        work_dir: Directory whose entries are listed.

    Returns:
        ``{"data": [<entry name>, ...]}``.
    """
    entry_names = os.listdir(work_dir)
    return {"data": entry_names}
|
||
|
||
|
||
async def sd_delete_file(filename: str, work_dir: str = JUPYTER_WORK_PATH):
    """Delete a file from ``work_dir``.

    ``filename`` comes from the request and is untrusted, so the directory
    that would contain the file is resolved and must lie inside ``work_dir``.
    Traversal attempts (``../x``, absolute paths) are rejected.

    Args:
        filename: Name of the file to remove, relative to ``work_dir``.
        work_dir: Directory the file is removed from.

    Returns:
        ``{"data": True}`` when the file was removed, ``{"data": False}`` when
        the name is unsafe or the removal failed.
    """
    try:
        candidate = os.path.join(work_dir, filename)
        base_dir = os.path.realpath(work_dir)
        # Resolve only the parent directory: ``os.remove`` unlinks a symlink
        # itself rather than its target, so the final component is left as is.
        parent_dir = os.path.realpath(os.path.dirname(candidate))
        if os.path.commonpath([base_dir, parent_dir]) != base_dir:
            logger.warning("Rejected delete outside of work dir: {!r}", filename)
            return {"data": False}

        os.remove(candidate)
        return {"data": True}
    except Exception:
        # Best-effort contract is kept (failure -> data=False) without a bare
        # ``except`` that would also swallow task cancellation; log the cause.
        logger.exception("Failed to delete file {!r}", filename)
        return {"data": False}
|
||
|
||
|
||
def create_app(open_cross_domain, version=VERSION):
    """Build the FastAPI application exposing the ``/sdfiles/*`` endpoints.

    Args:
        open_cross_domain: When truthy, a CORS middleware is installed that
            accepts every origin, method and header, with credentials allowed.
        version: Version string handed to ``FastAPI``.

    Returns:
        The ``FastAPI`` instance with the upload, download, list and delete
        routes registered.
    """
    app = FastAPI(title="DevOps-ChatBot API Server", version=version)
    # MakeFastAPIOffline(app)

    # Cross-domain access is opt-in through the ``open_cross_domain`` argument.
    if open_cross_domain:
        cors_options = {
            "allow_origins": ["*"],
            "allow_credentials": True,
            "allow_methods": ["*"],
            "allow_headers": ["*"],
        }
        app.add_middleware(CORSMiddleware, **cors_options)

    # (route decorator factory, path, response model, summary, handler),
    # registered in this order.
    route_table = (
        (app.post, "/sdfiles/upload", BaseResponse, "上传文件到沙盒", sd_upload_file),
        (app.get, "/sdfiles/download", DataResponse, "从沙盒下载文件", sd_download_file),
        (app.get, "/sdfiles/list", ListResponse, "从沙盒工作目录展示文件", sd_list_files),
        (app.get, "/sdfiles/delete", BaseResponse, "从沙盒工作目录中删除文件", sd_delete_file),
    )
    for register, path, model, summary, handler in route_table:
        register(
            path,
            tags=["files upload and download"],
            response_model=model,
            summary=summary,
        )(handler)

    return app
|
||
|
||
|
||
|
||
def run_api(host, port, open_cross_domain, **kwargs):
    """Create the application and serve it with uvicorn.

    Args:
        host: Interface uvicorn binds to.
        port: Port uvicorn listens on.
        open_cross_domain: Forwarded to ``create_app``.
        **kwargs: ``ssl_keyfile`` and ``ssl_certfile`` are read from here; they
            are passed to uvicorn only when both are truthy.
    """
    app = create_app(open_cross_domain)

    keyfile = kwargs.get("ssl_keyfile")
    certfile = kwargs.get("ssl_certfile")

    # TLS settings are forwarded only as a complete key/cert pair.
    tls_options = {}
    if keyfile and certfile:
        tls_options = {"ssl_keyfile": keyfile, "ssl_certfile": certfile}

    uvicorn.run(app, host=host, port=port, **tls_options)
|
||
|
||
|
||
def parse_bool_flag(value):
    """Convert a command-line value to ``bool``.

    ``argparse`` with ``type=bool`` treats every non-empty string as ``True``,
    so ``--open_cross_domain False`` used to enable the option. This parser
    reads the text instead.

    Args:
        value: A ``bool`` (returned unchanged) or a string such as
            ``"true"``, ``"False"``, ``"1"``, ``"0"``, ``"yes"``, ``"no"``.

    Returns:
        The boolean the value spells out; an empty string counts as ``False``.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog='DevOps-ChatBot',
        description='About DevOps-ChatBot, local knowledge based LLM with langchain'
                    ' | 基于本地知识库的 LLM 问答',
    )
    parser.add_argument("--host", type=str, default="0.0.0.0")
    # The port default is fixed here; SDFILE_API_SERVER["port"] from the
    # server config is not consulted.
    parser.add_argument("--port", type=int, default=7862)
    # Accepts `--open_cross_domain`, `--open_cross_domain True` and
    # `--open_cross_domain False`; the bare flag means True.
    parser.add_argument(
        "--open_cross_domain",
        type=parse_bool_flag,
        nargs="?",
        const=True,
        default=False,
    )
    parser.add_argument("--ssl_keyfile", type=str)
    parser.add_argument("--ssl_certfile", type=str)

    args = parser.parse_args()
    run_api(
        host=args.host,
        port=args.port,
        open_cross_domain=args.open_cross_domain,
        ssl_keyfile=args.ssl_keyfile,
        ssl_certfile=args.ssl_certfile,
    )