codefuse-chatbot/coagent/service/sdfile_api.py.bak

132 lines
4.6 KiB
Python
Raw Normal View History

import sys, os, json, traceback, uvicorn, argparse
src_dir = os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)
sys.path.append(src_dir)
from loguru import logger
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, FileResponse
from fastapi import File, UploadFile
from coagent.utils.server_utils import BaseResponse, ListResponse, DataResponse
# from configs.server_config import OPEN_CROSS_DOMAIN, SDFILE_API_SERVER
from configs.model_config import (
JUPYTER_WORK_PATH
)
VERSION = "v0.1.0"
async def sd_upload_file(file: UploadFile = File(...), work_dir: str = JUPYTER_WORK_PATH):
# 保存上传的文件到服务器
try:
content = await file.read()
with open(os.path.join(work_dir, file.filename), "wb") as f:
f.write(content)
return {"data": True}
except:
return {"data": False}
async def sd_download_file(filename: str, save_filename: str = "filename_to_download.ext", work_dir: str = JUPYTER_WORK_PATH):
# 从服务器下载文件
logger.debug(f"{os.path.join(work_dir, filename)}")
return {"data": os.path.join(work_dir, filename), "filename": save_filename}
# return {"data": FileResponse(os.path.join(work_dir, filename), filename=save_filename)}
async def sd_list_files(work_dir: str = JUPYTER_WORK_PATH):
# 去除目录
return {"data": os.listdir(work_dir)}
async def sd_delete_file(filename: str, work_dir: str = JUPYTER_WORK_PATH):
# 去除目录
try:
os.remove(os.path.join(work_dir, filename))
return {"data": True}
except:
return {"data": False}
def create_app(open_cross_domain, version=VERSION):
app = FastAPI(
title="DevOps-ChatBot API Server",
version=version
)
# MakeFastAPIOffline(app)
# Add CORS middleware to allow all origins
# 在config.py中设置OPEN_DOMAIN=True允许跨域
# set OPEN_DOMAIN=True in config.py to allow cross-domain
if open_cross_domain:
# if OPEN_CROSS_DOMAIN:
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.post("/sdfiles/upload",
tags=["files upload and download"],
response_model=BaseResponse,
summary="上传文件到沙盒"
)(sd_upload_file)
app.get("/sdfiles/download",
tags=["files upload and download"],
response_model=DataResponse,
summary="从沙盒下载文件"
)(sd_download_file)
app.get("/sdfiles/list",
tags=["files upload and download"],
response_model=ListResponse,
summary="从沙盒工作目录展示文件"
)(sd_list_files)
app.get("/sdfiles/delete",
tags=["files upload and download"],
response_model=BaseResponse,
summary="从沙盒工作目录中删除文件"
)(sd_delete_file)
return app
def run_api(host, port, open_cross_domain, **kwargs):
app = create_app(open_cross_domain)
if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"):
uvicorn.run(app,
host=host,
port=port,
ssl_keyfile=kwargs.get("ssl_keyfile"),
ssl_certfile=kwargs.get("ssl_certfile"),
)
else:
uvicorn.run(app, host=host, port=port)
if __name__ == "__main__":
parser = argparse.ArgumentParser(prog='DevOps-ChatBot',
description='About DevOps-ChatBot, local knowledge based LLM with langchain'
' 基于本地知识库的 LLM 问答')
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default="7862")
# parser.add_argument("--port", type=int, default=SDFILE_API_SERVER["port"])
parser.add_argument("--open_cross_domain", type=bool, default=False)
parser.add_argument("--ssl_keyfile", type=str)
parser.add_argument("--ssl_certfile", type=str)
# 初始化消息
args = parser.parse_args()
args_dict = vars(args)
run_api(host=args.host,
port=args.port,
open_cross_domain=args.open_cross_domain,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
)