codefuse-chatbot/examples/sdfile_api.py

130 lines
4.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import sys, os, json, traceback, uvicorn, argparse
# Put the parent of this file's directory on sys.path before the project
# imports below run (presumably so that the `configs` package resolves when
# this script is launched directly -- confirm against the repo layout).
# NOTE: os.path.join() with a single argument returns that argument unchanged.
src_dir = os.path.join(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
sys.path.append(src_dir)
from loguru import logger
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi import File, UploadFile
from muagent.utils.server_utils import BaseResponse, ListResponse, DataResponse
from configs.server_config import OPEN_CROSS_DOMAIN, SDFILE_API_SERVER
from configs.model_config import JUPYTER_WORK_PATH
# Default version string passed to FastAPI(version=...) by create_app().
VERSION = "v0.1.0"
async def sd_upload_file(file: UploadFile = File(...), work_dir: str = JUPYTER_WORK_PATH):
    """Save an uploaded file into the sandbox working directory.

    Args:
        file: The uploaded file. Its client-supplied ``filename`` is used as
            the destination path relative to ``work_dir``.
        work_dir: Directory that receives the file.

    Returns:
        ``{"data": True}`` when the file was written, ``{"data": False}`` when
        the filename is missing, resolves outside ``work_dir``, or the write
        fails for any reason.
    """
    # NOTE(review): work_dir is a plain str parameter of the endpoint, so it is
    # presumably request-controlled as well -- consider pinning it server-side.
    try:
        # The filename comes from the client: resolve it and refuse anything
        # that escapes work_dir (e.g. "../../etc/passwd" or an absolute path).
        base_dir = os.path.realpath(work_dir)
        target = os.path.realpath(os.path.join(base_dir, file.filename or ""))
        if target == base_dir or os.path.commonpath([base_dir, target]) != base_dir:
            logger.warning("Rejected upload with unsafe filename: {!r}", file.filename)
            return {"data": False}

        content = await file.read()
        with open(target, "wb") as f:
            f.write(content)
        return {"data": True}
    except Exception:
        # Keep the best-effort contract (report failure via the payload), but
        # do not swallow cancellation/system exits and leave a trace in the log.
        logger.exception("Failed to save uploaded file")
        return {"data": False}
async def sd_download_file(filename: str, save_filename: str = "filename_to_download.ext", work_dir: str = JUPYTER_WORK_PATH):
    """Return the server-side path of a sandbox file together with a download name.

    The file itself is not read or streamed here; only its joined path is
    logged at debug level and handed back to the caller.

    Args:
        filename: File name relative to ``work_dir``.
        save_filename: Name echoed back under the ``"filename"`` key.
        work_dir: Directory the file is looked up in.

    Returns:
        ``{"data": <joined path>, "filename": save_filename}``.
    """
    file_path = os.path.join(work_dir, filename)
    logger.debug(f"{file_path}")
    return dict(data=file_path, filename=save_filename)
async def sd_list_files(work_dir: str = JUPYTER_WORK_PATH):
    """List the entries of the sandbox working directory.

    Every name returned by ``os.listdir`` is included as-is; sub-directories
    are not filtered out.

    Args:
        work_dir: Directory to list.

    Returns:
        ``{"data": [<entry name>, ...]}``.
    """
    entries = os.listdir(work_dir)
    return dict(data=entries)
async def sd_delete_file(filename: str, work_dir: str = JUPYTER_WORK_PATH):
    """Delete a file from the sandbox working directory.

    Args:
        filename: File name relative to ``work_dir``.
        work_dir: Directory the file lives in.

    Returns:
        ``{"data": True}`` when the file was removed, ``{"data": False}`` when
        the name resolves outside ``work_dir`` or the removal fails (missing
        file, permission error, invalid path, ...).
    """
    try:
        # The filename is caller-supplied: resolve it and refuse anything that
        # escapes work_dir, otherwise "../x" could delete arbitrary files.
        base_dir = os.path.realpath(work_dir)
        target = os.path.realpath(os.path.join(base_dir, filename))
        if target == base_dir or os.path.commonpath([base_dir, target]) != base_dir:
            logger.warning("Rejected delete with unsafe filename: {!r}", filename)
            return {"data": False}

        os.remove(target)
        return {"data": True}
    except (OSError, ValueError) as err:
        # OSError: missing file / permissions / target is a directory.
        # ValueError: malformed path (e.g. embedded NUL, mixed drives).
        logger.warning("Failed to delete {!r}: {}", filename, err)
        return {"data": False}
def create_app(open_cross_domain, version=VERSION):
    """Build the FastAPI application exposing the sandbox file endpoints.

    Args:
        open_cross_domain: When truthy, install a CORS middleware that allows
            every origin, method and header (with credentials).
        version: Version string passed to the FastAPI constructor.

    Returns:
        The FastAPI app with the ``/sdfiles/*`` routes registered.
    """
    api = FastAPI(title="DevOps-ChatBot API Server", version=version)

    # Cross-domain access is opt-in via the caller's flag.
    if open_cross_domain:
        cors_options = {
            "allow_origins": ["*"],
            "allow_credentials": True,
            "allow_methods": ["*"],
            "allow_headers": ["*"],
        }
        api.add_middleware(CORSMiddleware, **cors_options)

    # (registrar, path, response model, summary, handler) -- registered in order.
    route_table = (
        (api.post, "/sdfiles/upload", BaseResponse, "上传文件到沙盒", sd_upload_file),
        (api.get, "/sdfiles/download", DataResponse, "从沙盒下载文件", sd_download_file),
        (api.get, "/sdfiles/list", ListResponse, "从沙盒工作目录展示文件", sd_list_files),
        (api.get, "/sdfiles/delete", BaseResponse, "从沙盒工作目录中删除文件", sd_delete_file),
    )
    for register, path, model, summary, handler in route_table:
        register(
            path,
            tags=["files upload and download"],
            response_model=model,
            summary=summary,
        )(handler)

    return api
def run_api(host, port, open_cross_domain, **kwargs):
    """Create the app and serve it with uvicorn.

    Args:
        host: Interface to bind.
        port: Port to listen on.
        open_cross_domain: Forwarded to ``create_app``.
        **kwargs: Optional ``ssl_keyfile`` / ``ssl_certfile``; TLS options are
            passed to uvicorn only when both are truthy.
    """
    application = create_app(open_cross_domain)

    keyfile = kwargs.get("ssl_keyfile")
    certfile = kwargs.get("ssl_certfile")

    # Only forward the TLS settings when the pair is complete.
    tls_options = {}
    if keyfile and certfile:
        tls_options = {"ssl_keyfile": keyfile, "ssl_certfile": certfile}

    uvicorn.run(application, host=host, port=port, **tls_options)
def str_to_bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``; this parser interprets the text instead.

    Args:
        value: The raw option value (a ``bool`` is returned unchanged).

    Returns:
        The parsed boolean. An empty string is treated as ``False``.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if normalized in {"0", "false", "f", "no", "n", "off", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(prog='DevOps-ChatBot',
                                     description='About DevOps-ChatBot, local knowledge based LLM with langchain'
                                                 ' 基于本地知识库的 LLM 问答')
    parser.add_argument("--host", type=str, default="0.0.0.0")
    # NOTE: the port could instead come from SDFILE_API_SERVER["port"].
    parser.add_argument("--port", type=int, default=7862)
    # str_to_bool (not bool) so that "--open_cross_domain False" really disables CORS.
    parser.add_argument("--open_cross_domain", type=str_to_bool, default=False)
    parser.add_argument("--ssl_keyfile", type=str)
    parser.add_argument("--ssl_certfile", type=str)

    # Parse the command line and start the server.
    args = parser.parse_args()
    run_api(host=args.host,
            port=args.port,
            open_cross_domain=args.open_cross_domain,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            )