176 lines
8.2 KiB
Python
176 lines
8.2 KiB
Python
from fastapi import Body, Request
|
||
from fastapi.responses import StreamingResponse
|
||
from typing import List
|
||
from loguru import logger
|
||
import importlib
|
||
import copy
|
||
import json
|
||
|
||
from configs.model_config import (
|
||
llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
|
||
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
|
||
|
||
from dev_opsgpt.tools import (
|
||
toLangchainTools,
|
||
TOOL_DICT, TOOL_SETS
|
||
)
|
||
|
||
from dev_opsgpt.connector.phase import BasePhase
|
||
from dev_opsgpt.connector.agents import BaseAgent, ReactAgent
|
||
from dev_opsgpt.connector.chains import BaseChain
|
||
from dev_opsgpt.connector.connector_schema import (
|
||
Message,
|
||
load_phase_configs, load_chain_configs, load_role_configs
|
||
)
|
||
from dev_opsgpt.connector.shcema import Memory
|
||
|
||
from dev_opsgpt.chat.utils import History, wrap_done
|
||
from dev_opsgpt.connector.configs import PHASE_CONFIGS, AGETN_CONFIGS, CHAIN_CONFIGS
|
||
|
||
# Module that holds the phase classes. AgentChat looks up the class to run by
# getattr on this module, using the name stored in a phase config's "phase_type".
# ".phase" relative to the "dev_opsgpt.connector" package resolves to
# "dev_opsgpt.connector.phase".
PHASE_MODULE = importlib.import_module(".phase", package="dev_opsgpt.connector")
|
||
|
||
|
||
|
||
class AgentChat:
    """Runs a configured agent "phase" on a user query and streams the result as JSON.

    ``chat`` is written to be mounted as a FastAPI endpoint: its parameters are
    declared with ``fastapi.Body`` so they are read from the JSON request body,
    and it returns a ``StreamingResponse`` of JSON-encoded result chunks.
    """

    def __init__(
        self,
        engine_name: str = "",
        top_k: int = 1,
        stream: bool = False,
    ) -> None:
        """Store per-instance settings.

        Args:
            engine_name: Stored for reference; not read by any method in this class.
            top_k: Stored on the instance; ``chat`` uses its own ``top_k`` request
                field instead of this value.
            stream: When True, ``chat`` emits the answer one character per chunk;
                otherwise it emits a single chunk with the whole answer.
        """
        self.engine_name = engine_name
        self.top_k = top_k
        self.stream = stream

    def chat(
        self,
        query: str = Body(..., description="用户输入", examples=["hello"]),
        phase_name: str = Body(..., description="执行场景名称", examples=["chatPhase"]),
        chain_name: str = Body(..., description="执行链的名称", examples=["chatChain"]),
        history: List[History] = Body(
            [], description="历史对话",
            examples=[[{"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"}]]
        ),
        doc_engine_name: str = Body(..., description="知识库名称", examples=["samples"]),
        search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]),
        code_engine_name: str = Body(..., description="代码引擎名称", examples=["samples"]),
        top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
        score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1),
        stream: bool = Body(False, description="流式输出"),
        local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
        choose_tools: List[str] = Body([], description="选择tool的集合"),
        do_search: bool = Body(False, description="是否进行搜索"),
        do_doc_retrieval: bool = Body(False, description="是否进行知识库检索"),
        do_code_retrieval: bool = Body(False, description="是否执行代码检索"),
        do_tool_retrieval: bool = Body(False, description="是否执行工具检索"),
        custom_phase_configs: dict = Body({}, description="自定义phase配置"),
        custom_chain_configs: dict = Body({}, description="自定义chain配置"),
        custom_role_configs: dict = Body({}, description="自定义role配置"),
        history_node_list: List = Body([], description="代码历史相关节点"),
        isDetaild: bool = Body(False, description="是否输出完整的agent相关内容"),
        **kargs
    ) -> Message:
        """Run phase ``phase_name`` on ``query`` and stream the result.

        Request fields (English gloss of the ``description`` strings):
            query: user input.
            phase_name / chain_name: names of the phase and chain to execute.
            history: prior conversation turns, each with a role and content.
            doc_engine_name / search_engine_name / code_engine_name: knowledge-base,
                search-engine and code-engine names, forwarded on the input Message.
            top_k: number of vectors to match. score_threshold: relevance threshold
                in [0, 1]. Both are forwarded on the input Message.
            stream: NOTE(review): this request field is currently ignored; chunking
                is controlled by ``self.stream`` set in ``__init__``. Confirm intent.
            local_doc_url: accepted but not used by this method.
            choose_tools: tool names; unknown names are silently skipped.
            do_search / do_doc_retrieval / do_code_retrieval / do_tool_retrieval:
                retrieval switches forwarded on the input Message (the first three
                are also passed to the phase constructor).
            custom_phase_configs / custom_chain_configs / custom_role_configs:
                overrides merged over the defaults by ``update_configs``.
            history_node_list: code-history nodes forwarded on the input Message.
            isDetaild: when true, the answer is the phase's full step log instead
                of the final message content.
            kargs: accepted and ignored.

        Returns:
            A ``StreamingResponse`` (media type ``text/event-stream``) whose chunks
            are JSON objects with keys ``answer``, ``db_docs``, ``search_docs``,
            ``code_docs``, ``related_nodes`` and ``figures``. The ``-> Message``
            annotation is kept for compatibility but does not match this return type.

        Raises:
            KeyError: if ``phase_name`` is not a key of the merged phase configs, or
                that config lacks ``"phase_type"`` / ``"do_summary"``.
            AttributeError: if ``"phase_type"`` names no attribute of the phase module.
        """
        # Merge request-supplied overrides over the module-level default configs.
        phase_configs, chain_configs, agent_configs = self.update_configs(
            custom_phase_configs, custom_chain_configs, custom_role_configs)

        logger.info('phase_configs={}'.format(phase_configs))
        logger.info('chain_configs={}'.format(chain_configs))
        logger.info('agent_configs={}'.format(agent_configs))
        # Log the requested names (previously the literal words were logged).
        logger.info('phase_name={}'.format(phase_name))
        logger.info('chain_name={}'.format(chain_name))

        # Only tools registered in TOOL_DICT are used; unknown names are dropped.
        tools = toLangchainTools([TOOL_DICT[name] for name in choose_tools if name in TOOL_DICT])

        input_message = Message(
            role_content=query,
            role_type="human",
            role_name="user",
            input_query=query,
            phase_name=phase_name,
            chain_name=chain_name,
            do_search=do_search,
            do_doc_retrieval=do_doc_retrieval,
            do_code_retrieval=do_code_retrieval,
            do_tool_retrieval=do_tool_retrieval,
            doc_engine_name=doc_engine_name, search_engine_name=search_engine_name,
            code_engine_name=code_engine_name,
            score_threshold=score_threshold, top_k=top_k,
            history_node_list=history_node_list,
            tools=tools
        )

        # Wrap the prior conversation turns in a Memory for the phase.
        history_memory = Memory([self._history_to_message(item) for item in history])

        phase = self._build_phase(input_message, phase_configs, chain_configs, agent_configs)
        output_message, local_memory = phase.step(input_message, history_memory)

        # NOTE(review): chunks are bare JSON strings without SSE "data:" framing,
        # despite the text/event-stream media type — confirm clients expect this.
        return StreamingResponse(
            self._iter_result_chunks(output_message, local_memory, isDetaild),
            media_type="text/event-stream",
        )

    @staticmethod
    def _history_to_message(item) -> Message:
        """Convert one history entry into a Message using its role as both role_name and role_type."""
        try:
            # Dict-shaped entries (the form the original code expected).
            role, content = item["role"], item["content"]
        except TypeError:
            # Parsed History objects are not subscriptable; read attributes instead.
            # NOTE(review): assumes History exposes .role and .content — confirm.
            role, content = item.role, item.content
        return Message(role_name=role, role_type=role, role_content=content)

    def _build_phase(self, input_message, phase_configs, chain_configs, agent_configs):
        """Instantiate the phase class named by the requested phase config's "phase_type"."""
        phase_config = phase_configs[input_message.phase_name]
        phase_class = getattr(PHASE_MODULE, phase_config["phase_type"])
        return phase_class(
            input_message.phase_name,
            task=input_message.task,
            phase_config=phase_configs,
            chain_config=chain_configs,
            role_config=agent_configs,
            do_summary=phase_config["do_summary"],
            do_code_retrieval=input_message.do_code_retrieval,
            do_doc_retrieval=input_message.do_doc_retrieval,
            do_search=input_message.do_search,
        )

    def _iter_result_chunks(self, message, local_memory, isDetaild=False):
        """Yield JSON-encoded result chunks for ``message``.

        With ``self.stream`` true, each chunk's "answer" holds a single character of
        the answer (at least one chunk is always sent). Otherwise one chunk carrying
        the whole answer is sent.
        """
        result = {
            "answer": "",
            "db_docs": [str(doc) for doc in message.db_docs],
            "search_docs": [str(doc) for doc in message.search_docs],
            "code_docs": [str(doc) for doc in message.code_docs],
            # Only the first code doc contributes related nodes.
            "related_nodes": [doc.get_related_node() for doc in message.code_docs[:1]],
            "figures": message.figures,
        }

        # Flatten and de-duplicate related nodes, preserving first-seen order.
        # List membership is used because node objects are not known to be hashable.
        unique_nodes = []
        for nodes in result["related_nodes"]:
            for node in nodes:
                if node not in unique_nodes:
                    unique_nodes.append(node)
        result["related_nodes"] = unique_nodes

        if isDetaild:
            message_str = local_memory.to_str_messages(content_key='step_content')
        else:
            message_str = message.role_content

        if self.stream:
            emitted = False
            for token in message_str:
                result["answer"] = token
                emitted = True
                yield json.dumps(result, ensure_ascii=False)
            if not emitted:
                # Empty answer: still send one chunk so docs/figures reach the client.
                yield json.dumps(result, ensure_ascii=False)
        else:
            result["answer"] = "".join(message_str)
            yield json.dumps(result, ensure_ascii=False)

    def _chat(self, ):
        """Placeholder; not implemented."""
        pass

    def update_configs(self, custom_phase_configs, custom_chain_configs, custom_role_configs):
        """Return ``(phase_configs, chain_configs, agent_configs)`` with overrides applied.

        Each default (PHASE_CONFIGS, CHAIN_CONFIGS, AGETN_CONFIGS) is deep-copied, so
        the module-level defaults are never mutated, then overlaid with the matching
        custom dict. A custom entry replaces the default entry with the same key
        wholesale (top-level ``dict.update``, not a recursive merge). ``None`` is
        treated as "no overrides".
        """
        phase_configs = copy.deepcopy(PHASE_CONFIGS)
        phase_configs.update(custom_phase_configs or {})

        chain_configs = copy.deepcopy(CHAIN_CONFIGS)
        chain_configs.update(custom_chain_configs or {})

        agent_configs = copy.deepcopy(AGETN_CONFIGS)
        agent_configs.update(custom_role_configs or {})

        # NOTE: the configs are returned as raw dicts; load_*_configs are not applied.
        return phase_configs, chain_configs, agent_configs