# File: codefuse-chatbot/coagent/connector/phase/base_phase.py
# Listing metadata: 255 lines, 13 KiB, Python

import copy
import importlib
import json
import os
from typing import Dict, Iterator, List, Optional, Tuple, Union

from loguru import logger

from coagent.connector.agents import BaseAgent
from coagent.connector.chains import BaseChain
from coagent.connector.schema import (
    Memory, Task, Message, AgentConfig, ChainConfig, PhaseConfig, LogVerboseEnum,
    CompletePhaseConfig,
    load_chain_configs, load_phase_configs, load_role_configs
)
from coagent.connector.memory_manager import BaseMemoryManager, LocalMemoryManager
from coagent.connector.configs import AGETN_CONFIGS, CHAIN_CONFIGS, PHASE_CONFIGS
from coagent.connector.message_process import MessageUtils
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH
# from configs.model_config import JUPYTER_WORK_PATH, KB_ROOT_PATH
# from configs.server_config import SANDBOX_SERVER
# Config registries loaded once, at import time, from the package's default
# config objects. `role_configs` is the one BasePhase.__init__ reads to build
# its "conv_summary" agent; BasePhase.init_chains loads its own local copies
# from the base_*_config arguments it is given.
role_configs = load_role_configs(AGETN_CONFIGS)
chain_configs = load_chain_configs(CHAIN_CONFIGS)
phase_configs = load_phase_configs(PHASE_CONFIGS)
# Absolute path of the directory that contains this module.
CUR_DIR = os.path.dirname(os.path.abspath(__file__))
class BasePhase:
    """Run the chains configured for one named phase, in order.

    On construction the phase looks up ``phase_name`` in the phase config,
    instantiates every listed chain together with its agents, and creates a
    local memory manager plus a conversation-summary agent. ``astep`` and
    ``step`` then feed a query through the chains one after another.
    """

    def __init__(
        self,
        phase_name: str,
        phase_config: CompletePhaseConfig = None,
        kb_root_path: str = KB_ROOT_PATH,
        jupyter_work_path: str = JUPYTER_WORK_PATH,
        sandbox_server: Optional[dict] = None,
        embed_config: Optional[EmbedConfig] = None,
        llm_config: Optional[LLMConfig] = None,
        task: Task = None,
        base_phase_config: Union[dict, str] = PHASE_CONFIGS,
        base_chain_config: Union[dict, str] = CHAIN_CONFIGS,
        base_role_config: Union[dict, str] = AGETN_CONFIGS,
        log_verbose: str = "0",
    ) -> None:
        """Build the phase, its chains and its helper agents.

        Args:
            phase_name: Key of the phase inside ``base_phase_config``; also
                used as the memory manager's unique name.
            phase_config: Forwarded to ``init_chains``, which currently does
                not read it.
            kb_root_path: Passed through to the agents, chains, memory
                manager and message utilities.
            jupyter_work_path: Passed through to the agents, chains and
                message utilities.
            sandbox_server: Sandbox settings passed through to the agents and
                chains. ``None`` (the default) means a fresh empty dict.
            embed_config: Embedding configuration. ``None`` (the default)
                means a fresh ``EmbedConfig()``.
            llm_config: LLM configuration. ``None`` (the default) means a
                fresh ``LLMConfig()``.
            task: Forwarded to every agent created by ``init_chains``.
            base_phase_config: Source handed to ``load_phase_configs``.
            base_chain_config: Source handed to ``load_chain_configs``.
            base_role_config: Source handed to ``load_role_configs``.
            log_verbose: Requested log level; the effective level is the
                higher of this value and the ``log_verbose`` environment
                variable.

        Raises:
            ValueError: If ``phase_name`` is not present in the phase config.
        """
        # Create the defaults per instance: a mutable default argument would
        # be shared by every BasePhase built without that argument.
        sandbox_server = {} if sandbox_server is None else sandbox_server
        embed_config = EmbedConfig() if embed_config is None else embed_config
        llm_config = LLMConfig() if llm_config is None else llm_config

        self.phase_name = phase_name
        # Phase switches; init_chains overwrites them from the phase config.
        self.do_summary = False
        self.do_search = False
        self.do_code_retrieval = False
        self.do_doc_retrieval = False
        self.do_tool_retrieval = False

        self.embed_config = embed_config
        self.llm_config = llm_config
        self.sandbox_server = sandbox_server
        self.jupyter_work_path = jupyter_work_path
        self.kb_root_path = kb_root_path
        self.log_verbose = self._resolve_log_verbose(log_verbose)
        # NOTE(review): MessageUtils receives the raw `log_verbose` argument,
        # not the environment-adjusted `self.log_verbose` that the chains and
        # agents get. Kept as in the original; confirm whether intended.
        self.message_utils = MessageUtils(
            None, sandbox_server, jupyter_work_path, embed_config, llm_config,
            kb_root_path, log_verbose,
        )
        self.global_memory = Memory(messages=[])
        self.phase_memory: List[Memory] = []

        # Instantiate the chains (and their agents) listed for this phase.
        # Must run after the attributes above are set: init_chains reads them.
        self.chains: List[BaseChain] = self.init_chains(
            phase_name,
            phase_config,
            task=task,
            memory=None,
            base_phase_config=base_phase_config,
            base_chain_config=base_chain_config,
            base_role_config=base_role_config,
        )
        self.memory_manager: BaseMemoryManager = LocalMemoryManager(
            unique_name=phase_name,
            do_init=True,
            kb_root_path=kb_root_path,
            embed_config=embed_config,
            llm_config=llm_config,
        )
        # The summary agent always comes from the module-level default role
        # registry, not from `base_role_config`.
        self.conv_summary_agent = BaseAgent(
            role=role_configs["conv_summary"].role,
            prompt_config=role_configs["conv_summary"].prompt_config,
            task=None,
            memory=None,
            llm_config=self.llm_config,
            embed_config=self.embed_config,
            sandbox_server=sandbox_server,
            jupyter_work_path=jupyter_work_path,
            kb_root_path=kb_root_path,
        )

    @staticmethod
    def _resolve_log_verbose(log_verbose) -> str:
        """Return the higher of the ``log_verbose`` env var and the argument, as a string."""
        env_level = os.environ.get("log_verbose", "0")
        arg_level = str(log_verbose)
        try:
            # Compare numerically so that e.g. "10" outranks "9".
            return max(env_level, arg_level, key=int)
        except ValueError:
            # Non-numeric level: fall back to plain string comparison.
            return max(env_level, arg_level)

    def astep(self, query: Message, history: Memory = None) -> Iterator[Tuple[Message, Memory]]:
        """Run the query through every chain, yielding progress as it happens.

        Yields ``(message, memory)`` pairs: each intermediate output of each
        chain, then one pair after each chain has finished (the chain summary
        when ``do_summary`` is set and a summary was produced, otherwise the
        chain output), and finally one pair whose message has ``role_name``
        set to the phase name.

        Args:
            query: The incoming message. It is appended to the memory manager
                and then enriched by ``get_extrainfo_step``.
            history: Forwarded unchanged to every chain.
        """
        self.memory_manager.append(query)
        summary_message = None
        output_message = None
        chain_message = Memory(messages=[])
        local_phase_memory = Memory(messages=[])

        # Optional enrichment: search / doc / code / tool retrieval.
        query = self.message_utils.get_extrainfo_step(
            query, self.do_search, self.do_doc_retrieval, self.do_code_retrieval, self.do_tool_retrieval
        )
        if not query.parsed_output:
            query.parsed_output = {"origin_query": query.input_query}
        if not query.parsed_output_list:
            query.parsed_output_list = [{"origin_query": query.input_query}]

        input_message = copy.deepcopy(query)
        self.global_memory.append(input_message)
        local_phase_memory.append(input_message)

        for chain in self.chains:
            # A chain's output becomes the next chain's input; `chain_message`
            # (the summaries collected so far) is passed along as background.
            chain_output = None
            local_chain_memory = None
            for chain_output, local_chain_memory in chain.astep(
                input_message, history, background=chain_message, memory_manager=self.memory_manager
            ):
                yield chain_output, local_phase_memory + local_chain_memory

            if chain_output is None:
                # The chain yielded nothing: skip it instead of reusing the
                # previous chain's output and memory (or hitting unbound names).
                logger.warning(f"chain {chain!r} produced no output in phase {self.phase_name}, skipping it")
                continue

            output_message = self.message_utils.inherit_extrainfo(input_message, chain_output)
            input_message = output_message
            # Original author's note: this section also has problems.
            self.global_memory.extend(local_chain_memory)
            local_phase_memory.extend(local_chain_memory)

            # Optionally summarise the conversation so far with the summary agent.
            if self.do_summary:
                if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose):
                    logger.info(f"{self.conv_summary_agent.role.role_name} input global memory: {local_phase_memory.to_str_messages(content_key='step_content')}")
                summary_message = None
                # Only the last message yielded by the summary agent is kept.
                for summary_message in self.conv_summary_agent.astep(
                    query, background=local_phase_memory, memory_manager=self.memory_manager
                ):
                    pass
                if summary_message is None:
                    logger.warning(f"summary agent produced no output in phase {self.phase_name}, using the chain output")
                else:
                    summary_message.role_name = chain.chainConfig.chain_name
                    summary_message = self.conv_summary_agent.message_utils.parser(summary_message)
                    summary_message = self.message_utils.inherit_extrainfo(output_message, summary_message)
                    chain_message.append(summary_message)

            message = summary_message or output_message
            yield message, local_phase_memory

        # Chains are not executed for multiple rounds, so their memories are
        # simply kept as they are. Note that every call appends them again.
        for chain in self.chains:
            self.phase_memory.append(chain.global_memory)

        # TODO: local_phase_memory does not yet include the summary messages.
        message = summary_message or output_message
        if message is None:
            # No chain produced anything: fall back to a copy of the input so
            # that relabelling it does not touch the message stored in memory.
            message = copy.deepcopy(input_message)
        message.role_name = self.phase_name
        yield message, local_phase_memory

    def step(self, query: Message, history: Memory = None) -> Tuple[Message, Memory]:
        """Drain ``astep`` and return only its last ``(message, memory)`` pair."""
        for message, local_phase_memory in self.astep(query, history=history):
            pass
        return message, local_phase_memory

    def pre_print(self, query, history: Memory = None) -> None:
        """Call ``pre_print`` on every chain, with an empty background memory."""
        chain_message = Memory(messages=[])
        for chain in self.chains:
            chain.pre_print(query, history, background=chain_message, memory_manager=self.memory_manager)

    def _build_agent(self, agent_config: AgentConfig, prompt_manager_type, task=None, memory=None) -> BaseAgent:
        """Instantiate the agent class named by ``agent_config.role.agent_type``."""
        # Every agent gets its own copy of the LLM config so that its stop
        # words do not leak into the shared configuration.
        llm_config = copy.deepcopy(self.llm_config)
        llm_config.stop = agent_config.stop
        agent_cls = getattr(self.agent_module, agent_config.role.agent_type)
        return agent_cls(
            role=agent_config.role,
            prompt_config=agent_config.prompt_config,
            prompt_manager_type=prompt_manager_type,
            task=task,
            memory=memory,
            chat_turn=agent_config.chat_turn,
            focus_agents=agent_config.focus_agents,
            focus_message_keys=agent_config.focus_message_keys,
            llm_config=llm_config,
            embed_config=self.embed_config,
            sandbox_server=self.sandbox_server,
            jupyter_work_path=self.jupyter_work_path,
            kb_root_path=self.kb_root_path,
            log_verbose=self.log_verbose,
        )

    def init_chains(self, phase_name: str, phase_config: CompletePhaseConfig, base_phase_config, base_chain_config,
                    base_role_config, task=None, memory=None) -> List[BaseChain]:
        """Create the chains, and their agents, that make up ``phase_name``.

        Also copies the phase's ``do_*`` switches onto ``self`` and stores the
        imported chain/agent modules on ``self.chain_module`` and
        ``self.agent_module``.

        Args:
            phase_name: Key of the phase to build.
            phase_config: Currently unused.
            base_phase_config: Source handed to ``load_phase_configs``.
            base_chain_config: Source handed to ``load_chain_configs``.
            base_role_config: Source handed to ``load_role_configs``.
            task: Forwarded to every agent.
            memory: Forwarded to every agent.

        Returns:
            One ``BaseChain`` per chain name listed in the phase config, in
            the listed order.

        Raises:
            ValueError: If ``phase_name`` is not present in the phase config.
        """
        # These locals intentionally shadow the module-level registries.
        role_configs = load_role_configs(base_role_config)
        chain_configs = load_chain_configs(base_chain_config)
        phase_configs = load_phase_configs(base_phase_config)

        self.chain_module = importlib.import_module("coagent.connector.chains")
        self.agent_module = importlib.import_module("coagent.connector.agents")

        phase: PhaseConfig = phase_configs.get(phase_name)
        if phase is None:
            raise ValueError(
                f"unknown phase_name {phase_name!r}; available phases: {sorted(phase_configs)}"
            )

        # Copy the phase-level switches onto the instance.
        self.do_summary = phase.do_summary
        self.do_search = phase.do_search
        self.do_code_retrieval = phase.do_code_retrieval
        self.do_doc_retrieval = phase.do_doc_retrieval
        self.do_tool_retrieval = phase.do_tool_retrieval
        logger.info(f"start to init the phase, the phase_name is {phase_name}, it contains these chains such as {phase.chains}")

        chains = []
        for chain_name in phase.chains:
            chain_config: ChainConfig = chain_configs[chain_name]
            logger.info(f"start to init the chain, the chain_name is {chain_name}, it contains these agents such as {chain_config.agents}")

            agents = []
            for agent_name in chain_config.agents:
                agent_config: AgentConfig = role_configs[agent_name]
                base_agent = self._build_agent(agent_config, agent_config.prompt_manager_type, task, memory)

                # A selector agent additionally owns the agents it selects from.
                if agent_config.role.agent_type == "SelectorAgent":
                    for group_agent_name in agent_config.group_agents:
                        group_agent_config = role_configs[group_agent_name]
                        # NOTE(review): group agents are given the selector's
                        # prompt_manager_type rather than their own, as in the
                        # original code. This may be a copy-paste slip; confirm.
                        group_base_agent = self._build_agent(
                            group_agent_config, agent_config.prompt_manager_type, task, memory
                        )
                        base_agent.group_agents.append(group_base_agent)

                agents.append(base_agent)

            chain_instance = BaseChain(
                agents,
                chain_config.chat_turn,
                do_checker=chain_config.do_checker,
                jupyter_work_path=self.jupyter_work_path,
                sandbox_server=self.sandbox_server,
                embed_config=self.embed_config,
                llm_config=self.llm_config,
                kb_root_path=self.kb_root_path,
                log_verbose=self.log_verbose,
            )
            chains.append(chain_instance)

        return chains

    def update(self) -> Memory:
        """Placeholder: not implemented, returns ``None``."""
        pass

    def get_memory(self, ) -> Memory:
        """Return one ``Memory`` merged from the memories of all chains."""
        return Memory.from_memory_list(
            [chain.get_memory() for chain in self.chains]
        )

    def get_memory_str(self, do_all_memory=True, content_key="role_content") -> str:
        """Render memory as newline-separated lines joined with ``": "``.

        Args:
            do_all_memory: When true, render ``self.global_memory``; otherwise
                render the per-chain memories collected in ``self.phase_memory``.
            content_key: Forwarded to ``Memory.to_tuple_messages``.
        """
        if do_all_memory:
            memory = self.global_memory
        else:
            # phase_memory is a plain list of Memory objects, so merge it into
            # a single Memory before rendering.
            memory = Memory.from_memory_list(self.phase_memory)
        return "\n".join([": ".join(i) for i in memory.to_tuple_messages(content_key=content_key)])

    def get_chains_memory(self, content_key="role_content") -> List[Tuple]:
        """Return the tuple-message form of every memory in ``self.phase_memory``."""
        return [memory.to_tuple_messages(content_key=content_key) for memory in self.phase_memory]

    def get_chains_memory_str(self, content_key="role_content") -> str:
        """Return each chain's name and memory string, joined by a row of asterisks."""
        return "************".join([f"{chain.chainConfig.chain_name}\n" + chain.get_memory_str(content_key=content_key) for chain in self.chains])