from fastapi import Body, Request
from fastapi.responses import StreamingResponse
from typing import List, Union, Dict
from loguru import logger
import importlib
import copy
import json
import os
from pathlib import Path

from coagent.tools import (
    toLangchainTools,
    TOOL_DICT, TOOL_SETS
)
from coagent.connector.phase import BasePhase
from coagent.connector.schema import Message
from coagent.connector.schema import Memory
from coagent.chat.utils import History, wrap_done
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
from coagent.connector.configs import PHASE_CONFIGS, AGETN_CONFIGS, CHAIN_CONFIGS

# Module that holds every phase class; the concrete class is resolved at
# runtime from the configured "phase_type" name via getattr.
PHASE_MODULE = importlib.import_module("coagent.connector.phase")


class AgentChat:
    """FastAPI-facing entry point that runs a multi-agent phase over a user
    query and streams the result back as server-sent JSON events.

    `chat` executes a phase synchronously (one `phase.step`), `achat`
    iterates `phase.astep` and additionally caches constructed phases in
    `self.chatPhase_dict` keyed by phase_type.
    """

    def __init__(
            self,
            engine_name: str = "",
            top_k: int = 1,
            stream: bool = False,
    ) -> None:
        self.top_k = top_k
        self.stream = stream
        # Phase cache used only by achat: phase_type -> constructed BasePhase.
        self.chatPhase_dict: Dict[str, BasePhase] = {}

    def chat(
            self,
            query: str = Body(..., description="用户输入", examples=["hello"]),
            phase_name: str = Body(..., description="执行场景名称", examples=["chatPhase"]),
            chain_name: str = Body(..., description="执行链的名称", examples=["chatChain"]),
            history: List[History] = Body(
                [], description="历史对话",
                examples=[[{"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"}]]
            ),
            doc_engine_name: str = Body(..., description="知识库名称", examples=["samples"]),
            search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]),
            code_engine_name: str = Body(..., description="代码引擎名称", examples=["samples"]),
            top_k: int = Body(5, description="匹配向量数"),
            score_threshold: float = Body(1, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1),
            stream: bool = Body(False, description="流式输出"),
            local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
            choose_tools: List[str] = Body([], description="选择tool的集合"),
            do_search: bool = Body(False, description="是否进行搜索"),
            do_doc_retrieval: bool = Body(False, description="是否进行知识库检索"),
            do_code_retrieval: bool = Body(False, description="是否执行代码检索"),
            do_tool_retrieval: bool = Body(False, description="是否执行工具检索"),
            custom_phase_configs: dict = Body({}, description="自定义phase配置"),
            custom_chain_configs: dict = Body({}, description="自定义chain配置"),
            custom_role_configs: dict = Body({}, description="自定义role配置"),
            history_node_list: List = Body([], description="代码历史相关节点"),
            isDetailed: bool = Body(False, description="是否输出完整的agent相关内容"),
            upload_file: Union[str, Path, bytes] = "",
            kb_root_path: str = Body("", description="知识库存储路径"),
            jupyter_work_path: str = Body("", description="sandbox执行环境"),
            sandbox_server: str = Body({}, description="代码历史相关节点"),
            api_key: str = Body(os.environ.get("OPENAI_API_KEY"), description=""),
            api_base_url: str = Body(os.environ.get("API_BASE_URL"),),
            embed_model: str = Body("", description="向量模型"),
            embed_model_path: str = Body("", description="向量模型路径"),
            model_device: str = Body("", description="模型加载设备"),
            embed_engine: str = Body("", description="向量模型类型"),
            model_name: str = Body("", description="llm模型名称"),
            temperature: float = Body(0.2, description=""),
            **kargs
    ) -> Message:
        """Run one phase synchronously and stream the result.

        Builds the phase from (possibly customized) configs, executes a single
        `phase.step`, then returns a `StreamingResponse` of JSON lines
        (one per token when `self.stream`, otherwise a single final payload).
        """
        # Merge caller-supplied phase/chain/role configs over the defaults.
        phase_configs, chain_configs, agent_configs = self.update_configs(
            custom_phase_configs, custom_chain_configs, custom_role_configs)

        # Feed the full local-argument dict into both config objects, matching
        # the original call pattern (each picks the keys it needs).
        params = locals()
        params.pop("self")
        embed_config: EmbedConfig = EmbedConfig(**params)
        llm_config: LLMConfig = LLMConfig(**params)

        logger.info('phase_configs={}'.format(phase_configs))
        logger.info('chain_configs={}'.format(chain_configs))
        logger.info('agent_configs={}'.format(agent_configs))
        # BUG FIX: previously logged the literal strings 'phase_name' and
        # 'chain_name' instead of their values.
        logger.info('phase_name={}'.format(phase_name))
        logger.info('chain_name={}'.format(chain_name))

        # Resolve the selected tool names into langchain tool objects.
        tools = toLangchainTools([TOOL_DICT[i] for i in choose_tools if i in TOOL_DICT])

        if upload_file:
            upload_file_name = upload_file if upload_file and isinstance(upload_file, str) else upload_file.name
            # Expand the longest partial mention of the file name inside the
            # query into the full file name (longest prefix wins).
            for _filename_idx in range(len(upload_file_name), 0, -1):
                if upload_file_name[:_filename_idx] in query:
                    query = query.replace(upload_file_name[:_filename_idx], upload_file_name)
                    break

        input_message = Message(
            role_content=query,
            role_type="user",
            role_name="human",
            input_query=query,
            origin_query=query,
            phase_name=phase_name,
            chain_name=chain_name,
            do_search=do_search,
            do_doc_retrieval=do_doc_retrieval,
            do_code_retrieval=do_code_retrieval,
            do_tool_retrieval=do_tool_retrieval,
            doc_engine_name=doc_engine_name,
            search_engine_name=search_engine_name,
            code_engine_name=code_engine_name,
            score_threshold=score_threshold,
            top_k=top_k,
            history_node_list=history_node_list,
            tools=tools
        )
        # History memory management: rebuild the dialogue history as a Memory.
        history = Memory(messages=[
            Message(role_name=i["role"], role_type=i["role"], role_content=i["content"])
            for i in history
        ])

        # Start to execute: look up the phase class by its configured type.
        phase_class = getattr(PHASE_MODULE, phase_configs[input_message.phase_name]["phase_type"])
        # TODO: supplement the remaining phase information here.
        phase = phase_class(input_message.phase_name,
                            task=input_message.task,
                            base_phase_config=phase_configs,
                            base_chain_config=chain_configs,
                            base_role_config=agent_configs,
                            phase_config=None,
                            kb_root_path=kb_root_path,
                            jupyter_work_path=jupyter_work_path,
                            sandbox_server=sandbox_server,
                            embed_config=embed_config,
                            llm_config=llm_config,
                            )
        output_message, local_memory = phase.step(input_message, history)

        def chat_iterator(message: Message, local_memory: Memory, isDetailed=False):
            # Collect the per-step content produced by the agents (user turns
            # filtered out) and the final answer text.
            step_content = local_memory.to_str_messages(content_key='step_content', filter_roles=["user"])
            final_content = message.role_content
            logger.debug(f"{step_content}")
            result = {
                "answer": "",
                "db_docs": [str(doc) for doc in message.db_docs],
                "search_docs": [str(doc) for doc in message.search_docs],
                "code_docs": [str(doc) for doc in message.code_docs],
                "related_nodes": [doc.get_related_node() for idx, doc in enumerate(message.code_docs) if idx == 0],
                "figures": message.figures,
                "step_content": step_content,
                "final_content": final_content,
            }

            # Flatten and de-duplicate related nodes, keeping first-seen order.
            # BUG FIX: has_nodes was never appended to, so the membership test
            # was always true and duplicates survived.
            related_nodes, has_nodes = [], []
            for nodes in result["related_nodes"]:
                for node in nodes:
                    if node not in has_nodes:
                        related_nodes.append(node)
                        has_nodes.append(node)
            result["related_nodes"] = related_nodes

            # logger.debug(f"{result['figures'].keys()}, isDetailed: {isDetailed}")
            message_str = step_content
            if self.stream:
                # Streaming: emit one JSON payload per character token.
                for token in message_str:
                    result["answer"] = token
                    yield json.dumps(result, ensure_ascii=False)
            else:
                # Non-streaming: accumulate, then emit a single payload.
                for token in message_str:
                    result["answer"] += token
                yield json.dumps(result, ensure_ascii=False)

        return StreamingResponse(chat_iterator(output_message, local_memory, isDetailed),
                                 media_type="text/event-stream")

    def achat(
            self,
            query: str = Body(..., description="用户输入", examples=["hello"]),
            phase_name: str = Body(..., description="执行场景名称", examples=["chatPhase"]),
            chain_name: str = Body(..., description="执行链的名称", examples=["chatChain"]),
            history: List[History] = Body(
                [], description="历史对话",
                examples=[[{"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"}]]
            ),
            doc_engine_name: str = Body(..., description="知识库名称", examples=["samples"]),
            search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]),
            code_engine_name: str = Body(..., description="代码引擎名称", examples=["samples"]),
            cb_search_type: str = Body(..., description="代码查询模式", examples=["tag"]),
            top_k: int = Body(5, description="匹配向量数"),
            score_threshold: float = Body(1, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1),
            stream: bool = Body(False, description="流式输出"),
            local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
            choose_tools: List[str] = Body([], description="选择tool的集合"),
            do_search: bool = Body(False, description="是否进行搜索"),
            do_doc_retrieval: bool = Body(False, description="是否进行知识库检索"),
            do_code_retrieval: bool = Body(False, description="是否执行代码检索"),
            do_tool_retrieval: bool = Body(False, description="是否执行工具检索"),
            custom_phase_configs: dict = Body({}, description="自定义phase配置"),
            custom_chain_configs: dict = Body({}, description="自定义chain配置"),
            custom_role_configs: dict = Body({}, description="自定义role配置"),
            history_node_list: List = Body([], description="代码历史相关节点"),
            isDetailed: bool = Body(False, description="是否输出完整的agent相关内容"),
            upload_file: Union[str, Path, bytes] = "",
            kb_root_path: str = Body("", description="知识库存储路径"),
            jupyter_work_path: str = Body("", description="sandbox执行环境"),
            sandbox_server: str = Body({}, description="代码历史相关节点"),
            # BUG FIX: os.environ["OPENAI_API_KEY"] raised KeyError at class
            # definition time when the variable was unset; .get matches chat().
            api_key: str = Body(os.environ.get("OPENAI_API_KEY"), description=""),
            api_base_url: str = Body(os.environ.get("API_BASE_URL"),),
            embed_model: str = Body("", description="向量模型"),
            embed_model_path: str = Body("", description="向量模型路径"),
            model_device: str = Body("", description="模型加载设备"),
            embed_engine: str = Body("", description="向量模型类型"),
            model_name: str = Body("", description="llm模型名称"),
            temperature: float = Body(0.2, description=""),
            **kargs
    ) -> Message:
        """Generator variant of `chat`: iterate `phase.astep` and yield JSON
        payload strings for every intermediate (output_message, local_memory)
        pair. Constructed phases are cached per phase_type in
        `self.chatPhase_dict` and reused across calls.
        """
        # Merge caller-supplied phase/chain/role configs over the defaults.
        phase_configs, chain_configs, agent_configs = self.update_configs(
            custom_phase_configs, custom_chain_configs, custom_role_configs)

        # Feed the full local-argument dict into both config objects, matching
        # the original call pattern (each picks the keys it needs).
        params = locals()
        params.pop("self")
        embed_config: EmbedConfig = EmbedConfig(**params)
        llm_config: LLMConfig = LLMConfig(**params)

        # Resolve the selected tool names into langchain tool objects.
        tools = toLangchainTools([TOOL_DICT[i] for i in choose_tools if i in TOOL_DICT])

        if upload_file:
            upload_file_name = upload_file if upload_file and isinstance(upload_file, str) else upload_file.name
            # Expand the longest partial mention of the file name inside the
            # query into the full file name (longest prefix wins).
            for _filename_idx in range(len(upload_file_name), 0, -1):
                if upload_file_name[:_filename_idx] in query:
                    query = query.replace(upload_file_name[:_filename_idx], upload_file_name)
                    break

        input_message = Message(
            role_content=query,
            role_type="user",
            role_name="human",
            input_query=query,
            origin_query=query,
            phase_name=phase_name,
            chain_name=chain_name,
            do_search=do_search,
            do_doc_retrieval=do_doc_retrieval,
            do_code_retrieval=do_code_retrieval,
            do_tool_retrieval=do_tool_retrieval,
            doc_engine_name=doc_engine_name,
            search_engine_name=search_engine_name,
            code_engine_name=code_engine_name,
            cb_search_type=cb_search_type,
            score_threshold=score_threshold,
            top_k=top_k,
            history_node_list=history_node_list,
            tools=tools
        )
        # History memory management: rebuild the dialogue history as a Memory.
        history = Memory(messages=[
            Message(role_name=i["role"], role_type=i["role"], role_content=i["content"])
            for i in history
        ])

        # Start to execute: build the phase once per phase_type and cache it.
        if phase_configs[input_message.phase_name]["phase_type"] not in self.chatPhase_dict:
            phase_class = getattr(PHASE_MODULE, phase_configs[input_message.phase_name]["phase_type"])
            phase = phase_class(input_message.phase_name,
                                task=input_message.task,
                                base_phase_config=phase_configs,
                                base_chain_config=chain_configs,
                                base_role_config=agent_configs,
                                phase_config=None,
                                kb_root_path=kb_root_path,
                                jupyter_work_path=jupyter_work_path,
                                sandbox_server=sandbox_server,
                                embed_config=embed_config,
                                llm_config=llm_config,
                                )
            self.chatPhase_dict[phase_configs[input_message.phase_name]["phase_type"]] = phase
        else:
            phase = self.chatPhase_dict[phase_configs[input_message.phase_name]["phase_type"]]

        def chat_iterator(message: Message, local_memory: Memory, isDetailed=False):
            # NOTE(review): a previous dead store called
            # local_memory.to_str_messages(...) here and was immediately
            # overwritten; removed. Join every parsed step output (skipping
            # the first entry, which echoes the user input), dropping the
            # "Action Status" bookkeeping key.
            step_content = "\n\n".join(
                [f"{v}"
                 for parsed_output in local_memory.get_parserd_output_list()[1:]
                 for k, v in parsed_output.items() if k not in ["Action Status"]])
            final_content = message.role_content
            result = {
                "answer": "",
                "db_docs": [str(doc) for doc in message.db_docs],
                "search_docs": [str(doc) for doc in message.search_docs],
                "code_docs": [str(doc) for doc in message.code_docs],
                "related_nodes": [doc.get_related_node() for idx, doc in enumerate(message.code_docs) if idx == 0],
                "figures": message.figures,
                "step_content": step_content or final_content,
                "final_content": final_content,
            }

            # Flatten and de-duplicate related nodes, keeping first-seen order.
            # BUG FIX: has_nodes was never appended to, so the membership test
            # was always true and duplicates survived.
            related_nodes, has_nodes = [], []
            for nodes in result["related_nodes"]:
                for node in nodes:
                    if node not in has_nodes:
                        related_nodes.append(node)
                        has_nodes.append(node)
            result["related_nodes"] = related_nodes

            # logger.debug(f"{result['figures'].keys()}, isDetailed: {isDetailed}")
            message_str = step_content
            if self.stream:
                # Streaming: emit one JSON payload per character token.
                for token in message_str:
                    result["answer"] = token
                    yield json.dumps(result, ensure_ascii=False)
            else:
                # Non-streaming: accumulate, then emit a single payload.
                for token in message_str:
                    result["answer"] += token
                yield json.dumps(result, ensure_ascii=False)

        for output_message, local_memory in phase.astep(input_message, history):
            # logger.debug(f"output_message: {output_message}")
            # output_message = Message(**output_message)
            # local_memory = Memory(**local_memory)
            for result in chat_iterator(output_message, local_memory, isDetailed):
                yield result

    def _chat(self, ):
        # Placeholder kept for interface compatibility; intentionally empty.
        pass

    def update_configs(self, custom_phase_configs, custom_chain_configs, custom_role_configs):
        '''Return deep copies of the default phase/chain/agent configs with
        the caller-supplied custom dicts merged on top (custom keys win).'''
        phase_configs = copy.deepcopy(PHASE_CONFIGS)
        phase_configs.update(custom_phase_configs)
        chain_configs = copy.deepcopy(CHAIN_CONFIGS)
        chain_configs.update(custom_chain_configs)
        agent_configs = copy.deepcopy(AGETN_CONFIGS)
        agent_configs.update(custom_role_configs)
        return phase_configs, chain_configs, agent_configs