from typing import Generator, List, Tuple
import copy
import os

from loguru import logger
from langchain.schema import BaseRetriever

from coagent.connector.agents import BaseAgent
from coagent.connector.schema import (
    Memory, Role, Message, ActionStatus, ChainConfig,
    load_role_configs
)
from coagent.connector.memory_manager import BaseMemoryManager
from coagent.connector.message_process import MessageUtils
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH
# NOTE: "AGETN_CONFIGS" is the constant's actual (misspelled) name upstream,
# so the import must keep that spelling.
from coagent.connector.configs.agent_config import AGETN_CONFIGS

role_configs = load_role_configs(AGETN_CONFIGS)

class BaseChain:

    def __init__(
            self,
            chainConfig: ChainConfig,
            agents: List[BaseAgent],
            # chat_turn and do_checker are now read from chainConfig
            sandbox_server: dict = None,
            jupyter_work_path: str = JUPYTER_WORK_PATH,
            kb_root_path: str = KB_ROOT_PATH,
            llm_config: LLMConfig = LLMConfig(),
            embed_config: EmbedConfig = None,
            doc_retrieval: BaseRetriever = None,
            code_retrieval=None,
            search_retrieval=None,
            log_verbose: str = "0"
    ) -> None:
        self.chainConfig = chainConfig
        self.agents: List[BaseAgent] = agents
        self.chat_turn = chainConfig.chat_turn
        self.do_checker = chainConfig.do_checker
        # use a fresh dict per instance instead of a shared mutable default
        self.sandbox_server = sandbox_server if sandbox_server is not None else {}
        self.jupyter_work_path = jupyter_work_path
        self.llm_config = llm_config
        # values are digit strings, so lexicographic max picks the higher level
        self.log_verbose = max(os.environ.get("log_verbose", "0"), log_verbose)
        # dedicated checker agent used to decide whether the chain can stop
        self.checker = BaseAgent(
            role=role_configs["checker"].role,
            prompt_config=role_configs["checker"].prompt_config,
            task=None, memory=None,
            llm_config=llm_config, embed_config=embed_config,
            sandbox_server=self.sandbox_server, jupyter_work_path=jupyter_work_path,
            kb_root_path=kb_root_path,
            doc_retrieval=doc_retrieval, code_retrieval=code_retrieval,
            search_retrieval=search_retrieval
        )
        self.messageUtils = MessageUtils(
            None, self.sandbox_server, self.jupyter_work_path, embed_config, llm_config,
            kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose
        )
        # all memory created by the agents, kept until the instance is deleted
        self.global_memory = Memory(messages=[])

    def step(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager = None) -> Tuple[Message, Memory]:
        '''Execute the chain to completion and return the final message and local memory.'''
        for output_message, local_memory in self.astep(query, history, background, memory_manager):
            pass
        return output_message, local_memory
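
    # Usage sketch (hypothetical, for illustration): `chain` is a configured
    # BaseChain and `query` a coagent Message built elsewhere; `step` blocks
    # until the chain finishes, then returns the final message plus the local
    # memory accumulated during this run.
    #
    #     output_message, local_memory = chain.step(query)
    #     print(output_message.role_content)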

    def pre_print(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager = None) -> None:
        '''Preview each agent's prompt for this query without executing the chain.'''
        for agent in self.agents:
            agent.pre_print(query, history, background=background, memory_manager=memory_manager)

    def astep(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager = None) -> Generator[Tuple[Message, Memory], None, None]:
        '''Execute the chain step by step, yielding intermediate results.'''
        local_memory = Memory(messages=[])
        input_message = copy.deepcopy(query)
        step_nums = self.chat_turn
        check_message = None

        # if input_message not in memory_manager:
        #     memory_manager.append(input_message)

        self.global_memory.append(input_message)
        # local_memory.append(input_message)
        while step_nums > 0:
            for agent in self.agents:
                for output_message in agent.astep(input_message, history, background=background, memory_manager=memory_manager):
                    # logger.debug(f"local_memory {local_memory + output_message}")
                    yield output_message, local_memory + output_message
                output_message = self.messageUtils.inherit_extrainfo(input_message, output_message)
                # according to the output, choose one action for code_content or tool_content
                # output_message = self.messageUtils.parser(output_message)
                yield output_message, local_memory + output_message
                # output_message = self.step_router(output_message)
                input_message = output_message
                self.global_memory.append(output_message)

                local_memory.append(output_message)
                # stop early once a finished/stopped signal arrives
                if output_message.action_status in (ActionStatus.FINISHED, ActionStatus.STOPPED):
                    break
            if output_message.action_status == ActionStatus.FINISHED:
                break

            if self.do_checker and self.chat_turn > 1:
                for check_message in self.checker.astep(query, background=local_memory, memory_manager=memory_manager):
                    pass
                check_message = self.messageUtils.parser(check_message)
                check_message = self.messageUtils.inherit_extrainfo(output_message, check_message)
                # logger.debug(f"{self.checker.role.role_name}: {check_message.role_content}")

                if check_message.action_status == ActionStatus.FINISHED:
                    self.global_memory.append(check_message)
                    break
            step_nums -= 1

        # return the checker's verdict when present, otherwise the chain's last output
        output_message = check_message or output_message
        # inter-chain message passing must not alter the original query
        output_message.input_query = query.input_query
        yield output_message, local_memory
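
    # Streaming usage sketch (hypothetical): astep is a generator, so callers
    # can render intermediate agent outputs as they arrive; the final yield
    # carries the chain's (or checker's) concluding message.
    #
    #     for output_message, local_memory in chain.astep(query):
    #         print(output_message.role_content)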

    def get_memory(self, content_key="role_content") -> List[Tuple[str, str]]:
        return self.global_memory.to_tuple_messages(content_key=content_key)

    def get_memory_str(self, content_key="role_content") -> str:
        return "\n".join([": ".join(i) for i in self.global_memory.to_tuple_messages(content_key=content_key)])

    def get_agents_memory(self, content_key="role_content"):
        return [agent.get_memory(content_key=content_key) for agent in self.agents]

    def get_agents_memory_str(self, content_key="role_content") -> str:
        return "************".join([f"{agent.role.role_name}\n" + agent.get_memory_str(content_key=content_key) for agent in self.agents])