import re
import traceback
import uuid
import copy
import json
import os
from typing import Optional, Union

from loguru import logger
from langchain.schema import BaseRetriever

from coagent.connector.schema import (
    Memory, Role, Message, ActionStatus, CodeDoc, Doc, LogVerboseEnum
)
from coagent.retrieval.base_retrieval import IMRertrieval
from coagent.connector.memory_manager import BaseMemoryManager
from coagent.tools import DDGSTool, DocRetrieval, CodeRetrieval
from coagent.sandbox import PyCodeBox, CodeBoxResponse
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
from coagent.base_configs.env_config import JUPYTER_WORK_PATH

from .utils import parse_dict_to_dict, parse_text_to_dict


class MessageUtils:
    """Helpers that enrich, route and parse ``Message`` objects.

    The class bundles three groups of operations:

    * copying context fields from one message to another
      (``inherit_extrainfo`` / ``inherit_baseparam``),
    * attaching retrieval results (search / doc / code) to a message,
    * executing the action requested by a message (code execution, tool
      call, saving code to a file) and parsing raw LLM output.
    """

    def __init__(
            self,
            role: Role = None,
            sandbox_server: Optional[dict] = None,
            jupyter_work_path: str = JUPYTER_WORK_PATH,
            embed_config: EmbedConfig = None,
            llm_config: LLMConfig = None,
            kb_root_path: str = "",
            doc_retrieval: Union[BaseRetriever, IMRertrieval] = None,
            code_retrieval: IMRertrieval = None,
            search_retrieval: IMRertrieval = None,
            log_verbose: str = "0"
    ) -> None:
        """Store the configuration and create the code sandbox client.

        Args:
            role: Role associated with this helper (stored only).
            sandbox_server: Sandbox connection settings; the keys ``url``,
                ``host``, ``port`` and ``do_remote`` are read. ``None``
                (the default) means "use the built-in defaults".
            jupyter_work_path: Working directory passed to the sandbox and
                used as the target directory by ``save_code2file``.
            embed_config: Embedding configuration forwarded to retrieval.
            llm_config: LLM configuration forwarded to code retrieval.
            kb_root_path: Knowledge-base root forwarded to ``DocRetrieval``.
            doc_retrieval: Optional custom document retriever.
            code_retrieval: Optional custom code retriever.
            search_retrieval: Optional custom search retriever (stored only).
            log_verbose: Log verbosity used when the ``log_verbose``
                environment variable is unset or empty.
        """
        # A fresh dict per instance avoids sharing one mutable default
        # object between every MessageUtils created without this argument.
        if sandbox_server is None:
            sandbox_server = {}

        self.role = role
        self.sandbox_server = sandbox_server
        self.jupyter_work_path = jupyter_work_path
        self.embed_config = embed_config
        self.llm_config = llm_config
        self.kb_root_path = kb_root_path
        self.doc_retrieval = doc_retrieval
        self.code_retrieval = code_retrieval
        self.search_retrieval = search_retrieval
        self.codebox = PyCodeBox(
            remote_url=self.sandbox_server.get("url", "http://127.0.0.1:5050"),
            remote_ip=self.sandbox_server.get("host", "http://127.0.0.1"),
            remote_port=self.sandbox_server.get("port", "5050"),
            jupyter_work_path=jupyter_work_path,
            token="mytoken",
            do_code_exe=True,
            do_remote=self.sandbox_server.get("do_remote", False),
            do_check_net=False
        )
        # The environment variable wins when it is set to a non-empty value;
        # otherwise fall back to the constructor argument. (Supplying "0" as
        # the environ default would make the argument unreachable.)
        self.log_verbose = os.environ.get("log_verbose") or log_verbose

    def inherit_extrainfo(self, input_message: Message, output_message: Message):
        """Copy user, document and retrieval context onto ``output_message``.

        ``figures`` are merged into the output's existing figures. For
        ``customed_kargs`` the keys already present on ``output_message`` are
        kept and only missing keys are added from ``input_message``.

        Returns:
            The mutated ``output_message``.
        """
        output_message.user_name = input_message.user_name
        output_message.db_docs = input_message.db_docs
        output_message.search_docs = input_message.search_docs
        output_message.code_docs = input_message.code_docs
        output_message.figures.update(input_message.figures)
        output_message.origin_query = input_message.origin_query
        output_message.code_engine_name = input_message.code_engine_name

        output_message.doc_engine_name = input_message.doc_engine_name
        output_message.search_engine_name = input_message.search_engine_name
        output_message.top_k = input_message.top_k
        output_message.score_threshold = input_message.score_threshold
        output_message.cb_search_type = input_message.cb_search_type
        output_message.do_doc_retrieval = input_message.do_doc_retrieval
        output_message.do_code_retrieval = input_message.do_code_retrieval
        output_message.do_tool_retrieval = input_message.do_tool_retrieval
        # NOTE: `tools` is not copied here (that assignment is disabled).
        output_message.agents = input_message.agents

        # Merge customed_kargs: start from a deep copy of the input's values,
        # then let the output's own values override them.
        customed_kargs = copy.deepcopy(input_message.customed_kargs)
        customed_kargs.update(output_message.customed_kargs)
        output_message.customed_kargs = customed_kargs
        return output_message

    def inherit_baseparam(self, input_message: Message, output_message: Message):
        """Copy only the retrieval parameters onto ``output_message``.

        Returns:
            The mutated ``output_message``.
        """
        output_message.doc_engine_name = input_message.doc_engine_name
        output_message.search_engine_name = input_message.search_engine_name
        output_message.top_k = input_message.top_k
        output_message.score_threshold = input_message.score_threshold
        output_message.cb_search_type = input_message.cb_search_type
        output_message.do_doc_retrieval = input_message.do_doc_retrieval
        output_message.do_code_retrieval = input_message.do_code_retrieval
        output_message.do_tool_retrieval = input_message.do_tool_retrieval
        # NOTE: `tools` is not copied here (that assignment is disabled).
        output_message.agents = input_message.agents
        # NOTE(review): flagged as a bug by the original author - keys that
        # exist on both messages are overwritten by the input's values here,
        # which is the opposite precedence of inherit_extrainfo.
        output_message.customed_kargs.update(input_message.customed_kargs)
        return output_message

    def get_extrainfo_step(self, message: Message, do_search, do_doc_retrieval, do_code_retrieval, do_tool_retrieval) -> Message:
        """Run every enabled retrieval step on ``message``.

        The steps run in a fixed order: search, document, code, tool.
        """
        if do_search:
            message = self.get_search_retrieval(message)

        if do_doc_retrieval:
            message = self.get_doc_retrieval(message)

        if do_code_retrieval:
            message = self.get_code_retrieval(message)

        if do_tool_retrieval:
            message = self.get_tool_retrieval(message)

        return message

    def get_search_retrieval(self, message: Message,) -> Message:
        """Attach DuckDuckGo search results for ``message.role_content``."""
        SEARCH_ENGINES = {"duckduckgo": DDGSTool}
        search_docs = []
        for idx, doc in enumerate(SEARCH_ENGINES["duckduckgo"].run(message.role_content, 3)):
            doc.update({"index": idx})
            search_docs.append(Doc(**doc))
        message.search_docs = search_docs
        return message

    def get_doc_retrieval(self, message: Message) -> Message:
        """Attach document-retrieval results to ``message.db_docs``.

        A retriever supplied at construction time takes precedence; otherwise
        ``DocRetrieval`` is queried, but only when the message names a
        knowledge base (``doc_engine_name``).
        """
        query = message.role_content
        knowledge_basename = message.doc_engine_name
        top_k = message.top_k
        score_threshold = message.score_threshold

        if self.doc_retrieval:
            if isinstance(self.doc_retrieval, BaseRetriever):
                docs = self.doc_retrieval.get_relevant_documents(query)
            else:
                # top_k / score_threshold are not forwarded to the custom
                # retriever; it is called with the query only.
                docs = self.doc_retrieval.run(query)
            docs = [
                {
                    "index": idx,
                    "snippet": doc.page_content,
                    "title": doc.metadata.get("title_prefix", ""),
                    "link": doc.metadata.get("url", ""),
                }
                for idx, doc in enumerate(docs)
            ]
            message.db_docs = [Doc(**doc) for doc in docs]
        elif knowledge_basename:
            docs = DocRetrieval.run(query, knowledge_basename, top_k, score_threshold, self.embed_config, self.kb_root_path)
            message.db_docs = [Doc(**doc) for doc in docs]
        return message

    def get_code_retrieval(self, message: Message) -> Message:
        """Attach code-retrieval results to ``message.code_docs``.

        Uses the retriever supplied at construction time when present,
        otherwise falls back to ``CodeRetrieval``.
        """
        query = message.role_content
        code_engine_name = message.code_engine_name
        history_node_list = message.history_node_list

        use_nh = message.use_nh
        local_graph_path = message.local_graph_path

        if self.code_retrieval:
            code_docs = self.code_retrieval.run(
                query, history_node_list=history_node_list, search_type=message.cb_search_type,
                code_limit=1
            )
        else:
            code_docs = CodeRetrieval.run(
                code_engine_name, query, code_limit=message.top_k,
                history_node_list=history_node_list, search_type=message.cb_search_type,
                llm_config=self.llm_config, embed_config=self.embed_config,
                use_nh=use_nh, local_graph_path=local_graph_path
            )

        message.code_docs = [CodeDoc(**doc) for doc in code_docs]
        return message

    def get_tool_retrieval(self, message: Message) -> Message:
        """Placeholder: tool retrieval currently returns ``message`` unchanged."""
        return message

    def step_router(self, message: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager=None) -> tuple[Message, ...]:
        """Dispatch on ``message.action_status`` and run the matching step.

        Returns:
            ``(message, observation_message)``; ``observation_message`` is
            ``None`` unless a code-execution or tool step produced one.
        """
        if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose):
            logger.info(f"message.action_status: {message.action_status}")

        observation_message = None
        if message.action_status == ActionStatus.CODE_EXECUTING:
            message, observation_message = self.code_step(message)
        elif message.action_status == ActionStatus.TOOL_USING:
            message, observation_message = self.tool_step(message)
        elif message.action_status == ActionStatus.CODING2FILE:
            self.save_code2file(message, self.jupyter_work_path)
        elif message.action_status == ActionStatus.CODE_RETRIEVAL:
            pass
        elif message.action_status == ActionStatus.CODING:
            pass

        return message, observation_message

    def code_step(self, message: Message) -> tuple[Message, Message]:
        """Execute ``message.code_content`` in the sandbox.

        The execution result is recorded on ``message`` (``code_answer``,
        ``observation``, ``step_content``; image results also go into
        ``figures``) and mirrored in a new "observation" message.

        Returns:
            ``(message, observation_message)``.
        """
        code_answer = self.codebox.chat('```python\n{}```'.format(message.code_content))
        if code_answer.code_exe_type == "error":
            code_prompt = f"The return error after executing the above code is {code_answer.code_exe_response},need to recover.\n"
        else:
            code_prompt = f"The return information after executing the above code is {code_answer.code_exe_response}.\n"

        observation_message = Message(
            user_name=message.user_name,
            role_name="observation",
            role_type="function",
            role_content="",
            step_content="",
            input_query=message.code_content,
        )
        uid = str(uuid.uuid1())
        if code_answer.code_exe_type == "image/png":
            # Image payloads are stored under a generated name; only that
            # name is reported in the observation text.
            figure_note = f"\n**Observation:**: The return figure name is {uid} after executing the above code.\n"
            message.figures[uid] = code_answer.code_exe_response
            message.code_answer = figure_note
            message.observation = figure_note
            message.step_content += figure_note
            observation_message.role_content = figure_note
            observation_message.parsed_output = {"Observation": f"The return figure name is {uid} after executing the above code.\n"}
        else:
            observation_text = f"\n**Observation:**: {code_prompt}\n"
            message.code_answer = code_answer.code_exe_response
            message.observation = code_answer.code_exe_response
            message.step_content += observation_text
            observation_message.role_content = observation_text
            observation_message.parsed_output = {"Observation": f"{code_prompt}\n"}

        if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose):
            logger.info(f"**Observation:** {message.action_status}, {message.observation}")
        return message, observation_message

    def tool_step(self, message: Message) -> tuple[Message, Message]:
        """Execute the tool requested by ``message``.

        The first tool in ``message.tools`` whose name equals
        ``message.tool_params["tool_name"]`` is called with
        ``message.tool_params["tool_params"]`` as keyword arguments. When no
        tool matches, a "no tool can execute" observation is reported.

        Returns:
            ``(message, observation_message)``.
        """
        observation_message = Message(
            user_name=message.user_name,
            role_name="observation",
            role_type="function",
            role_content="\n**Observation:** there is no tool can execute\n",
            step_content="",
            input_query=str(message.tool_params),
            tools=message.tools,
        )
        if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose):
            logger.info(f"message: {message.action_status}, {message.tool_params}")

        # NOTE(review): availability is checked against message.tool_name
        # while dispatch below uses tool_params["tool_name"] - confirm these
        # are always the same value.
        tool_names = [tool.name for tool in message.tools]
        if message.tool_name not in tool_names:
            no_tool_text = "\n**Observation:** there is no tool can execute.\n"
            message.tool_answer = no_tool_text
            message.observation = no_tool_text
            message.step_content += no_tool_text
            observation_message.role_content = no_tool_text
            observation_message.parsed_output = {"Observation": "there is no tool can execute.\n"}

        # parser() stores None in tool_params when no json block was parsed;
        # treat that as "nothing requested" instead of raising AttributeError.
        tool_params = message.tool_params or {}
        requested_tool = tool_params.get("tool_name", "")
        for tool in message.tools:
            if tool.name == requested_tool:
                tool_res = tool.func(**(tool_params.get("tool_params") or {}))
                message.tool_answer = tool_res
                message.observation = tool_res
                message.step_content += f"\n**Observation:** {tool_res}.\n"
                observation_message.role_content = f"\n**Observation:** {tool_res}.\n"
                observation_message.parsed_output = {"Observation": f"{tool_res}.\n"}
                break

        if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose):
            logger.info(f"**Observation:** {message.action_status}, {message.observation}")
        return message, observation_message

    def parser(self, message: Message) -> Message:
        """Parse the LLM output in ``message.role_content`` into dict fields.

        Sets ``action_status`` (lower-cased, ``"default"`` when absent),
        ``code_content``, ``tool_params`` (only for ``tool_using``),
        ``parsed_output`` and ``spec_parsed_output`` on ``message``.
        """
        content = message.role_content
        parsed_dict = parse_text_to_dict(content)
        spec_parsed_dict = parse_dict_to_dict(parsed_dict)

        action_value = parsed_dict.get('Action Status')
        if action_value:
            action_value = action_value.lower()

        code_content_value = spec_parsed_dict.get('code')
        # Tool parameters are only meaningful when a tool call was requested.
        if action_value == 'tool_using':
            tool_params_value = spec_parsed_dict.get('json')
        else:
            tool_params_value = None

        message.action_status = action_value or "default"
        message.code_content = code_content_value
        message.tool_params = tool_params_value
        message.parsed_output = parsed_dict
        message.spec_parsed_output = spec_parsed_dict
        return message

    def save_code2file(self, message: Message, project_dir="./"):
        """Write the parsed code of ``message`` to a file under ``project_dir``.

        The file name comes from ``message.parsed_output["SaveFileName"]``
        and the code from ``message.spec_parsed_output["code"]``. Missing
        parent directories are created; an existing file is overwritten.
        """
        filename = message.parsed_output.get("SaveFileName")
        code = message.spec_parsed_output.get("code")

        # NOTE(review): the ">" and "<" entries map a character to itself and
        # are therefore no-ops; presumably they were meant to be escaped
        # forms of these operators - confirm against the upstream source.
        for k, v in {">": ">", "≥": ">=", "<": "<", "≤": "<="}.items():
            code = code.replace(k, v)

        file_path = os.path.join(project_dir, filename)

        # dirname is empty when file_path has no directory component;
        # os.makedirs("") would raise, so only create a real parent.
        parent_dir = os.path.dirname(file_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        # Explicit UTF-8 so non-ASCII source text is written the same way on
        # every platform instead of depending on the locale encoding.
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(code)