272 lines
14 KiB
Python
272 lines
14 KiB
Python
from typing import List, Union, Dict, Tuple
|
||
import os
|
||
import json
|
||
import importlib
|
||
import copy
|
||
from loguru import logger
|
||
|
||
from langchain.schema import BaseRetriever
|
||
|
||
from coagent.connector.agents import BaseAgent
|
||
from coagent.connector.chains import BaseChain
|
||
from coagent.connector.schema import (
|
||
Memory, Task, Message, AgentConfig, ChainConfig, PhaseConfig, LogVerboseEnum,
|
||
CompletePhaseConfig,
|
||
load_chain_configs, load_phase_configs, load_role_configs
|
||
)
|
||
from coagent.connector.memory_manager import BaseMemoryManager, LocalMemoryManager
|
||
from coagent.connector.configs import AGETN_CONFIGS, CHAIN_CONFIGS, PHASE_CONFIGS
|
||
from coagent.connector.message_process import MessageUtils
|
||
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
|
||
from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH
|
||
|
||
|
||
# Module-level default configs, parsed once at import time from the
# AGETN_CONFIGS / CHAIN_CONFIGS / PHASE_CONFIGS defaults.
role_configs, chain_configs, phase_configs = (
    load_role_configs(AGETN_CONFIGS),
    load_chain_configs(CHAIN_CONFIGS),
    load_phase_configs(PHASE_CONFIGS),
)

# Absolute path of the directory that contains this module.
CUR_DIR = os.path.dirname(os.path.abspath(__file__))
|
||
|
||
|
||
class BasePhase:
    """A phase: an ordered pipeline of chains that jointly process one query.

    Each chain receives the previous chain's output message as its input.
    When ``do_summary`` is enabled, a summary produced by
    ``conv_summary_agent`` after each chain is appended to the background
    memory handed to the following chains.

    Memory is tracked at several levels:
      * ``global_memory``  - every input message and chain memory seen by this instance.
      * ``phase_memory``   - list of ``chain.global_memory`` objects, appended
        (one per chain) at the end of every ``astep`` run.
      * a per-call local memory that ``astep`` yields to the caller.
    """

    def __init__(
        self,
        phase_name: str,
        phase_config: CompletePhaseConfig = None,
        kb_root_path: str = KB_ROOT_PATH,
        jupyter_work_path: str = JUPYTER_WORK_PATH,
        sandbox_server: dict = None,
        embed_config: EmbedConfig = None,
        llm_config: LLMConfig = None,
        task: Task = None,
        base_phase_config: Union[dict, str] = PHASE_CONFIGS,
        base_chain_config: Union[dict, str] = CHAIN_CONFIGS,
        base_role_config: Union[dict, str] = AGETN_CONFIGS,
        chains: List[BaseChain] = None,
        doc_retrieval: Union[BaseRetriever] = None,
        code_retrieval=None,
        search_retrieval=None,
        log_verbose: str = "0",
    ) -> None:
        """Create the phase and, unless ``chains`` is given, build its chains from config.

        Args:
            phase_name: key of this phase in the phase configs; also the
                memory manager's unique name.
            phase_config: forwarded to ``init_chains`` (currently unused there).
            kb_root_path / jupyter_work_path / sandbox_server: execution
                environment forwarded to every agent and chain.
            embed_config / llm_config: model configs; each agent gets a deep
                copy of ``llm_config`` with its own ``stop`` value.
            task: forwarded to every agent built by ``init_chains``.
            base_phase_config / base_chain_config / base_role_config: config
                sources passed to the ``load_*_configs`` helpers.
            chains: prebuilt chains; when non-empty, config-based building is
                skipped (and so are the do_* overrides it performs).
            doc_retrieval / code_retrieval / search_retrieval: optional
                retrievers; their presence sets the initial do_* flags.
            log_verbose: verbosity level string.
        """
        # A `{}` default would be one dict shared by every BasePhase instance.
        sandbox_server = {} if sandbox_server is None else sandbox_server

        self.phase_name = phase_name
        # Initial flags; init_chains overwrites all of them from the phase config.
        self.do_summary = False
        self.do_search = search_retrieval is not None
        self.do_code_retrieval = code_retrieval is not None
        self.do_doc_retrieval = doc_retrieval is not None
        self.do_tool_retrieval = False

        self.embed_config = embed_config
        self.llm_config = llm_config
        self.sandbox_server = sandbox_server
        self.jupyter_work_path = jupyter_work_path
        self.kb_root_path = kb_root_path
        # The environment variable can raise (never lower) the verbosity.
        # NOTE(review): max() compares strings lexicographically - fine for
        # single-digit levels, but "10" < "2".
        self.log_verbose = max(os.environ.get("log_verbose", "0"), log_verbose)

        # TODO: pass these through more explicitly.
        self.doc_retrieval = doc_retrieval
        self.code_retrieval = code_retrieval
        self.search_retrieval = search_retrieval
        # Use the env-adjusted level, as the chains and agents below do.
        self.message_utils = MessageUtils(
            None, sandbox_server, jupyter_work_path, embed_config, llm_config,
            kb_root_path, doc_retrieval, code_retrieval, search_retrieval,
            self.log_verbose,
        )
        self.global_memory = Memory(messages=[])
        self.phase_memory: List[Memory] = []

        # Build the chains declared for this phase unless they were supplied.
        self.chains: List[BaseChain] = chains if chains else self.init_chains(
            phase_name,
            phase_config,
            task=task,
            memory=None,
            base_phase_config=base_phase_config,
            base_chain_config=base_chain_config,
            base_role_config=base_role_config,
        )
        self.memory_manager: BaseMemoryManager = LocalMemoryManager(
            unique_name=phase_name, do_init=True, kb_root_path=kb_root_path,
            embed_config=embed_config, llm_config=llm_config,
        )
        # Always built from the module-level default role configs.
        self.conv_summary_agent = BaseAgent(
            role=role_configs["conv_summary"].role,
            prompt_config=role_configs["conv_summary"].prompt_config,
            task=None,
            memory=None,
            llm_config=self.llm_config,
            embed_config=self.embed_config,
            sandbox_server=sandbox_server,
            jupyter_work_path=jupyter_work_path,
            kb_root_path=kb_root_path,
        )

    def astep(self, query: Message, history: Memory = None, reinit_memory=False) -> Tuple[Message, Memory]:
        """Run ``query`` through every chain of the phase, streaming progress.

        This is a generator. It yields ``(message, memory)`` pairs:
          * every intermediate output of each chain, together with the
            phase-local memory plus that chain's local memory;
          * one pair after each chain finishes (its summary if ``do_summary``,
            otherwise its output);
          * a final pair whose message has ``role_name`` set to the phase name.

        Args:
            query: incoming message. It is first passed through
                ``message_utils.get_extrainfo_step`` with the do_* retrieval flags.
            history: forwarded to each chain's ``astep``.
            reinit_memory: when truthy, re-initialise the memory manager first.
        """
        if reinit_memory:
            self.memory_manager.re_init(reinit_memory)
        self.memory_manager.append(query)

        summary_message = None
        # Background handed to each chain; receives summaries when do_summary is on.
        chain_message = Memory(messages=[])
        local_phase_memory = Memory(messages=[])

        query = self.message_utils.get_extrainfo_step(
            query, self.do_search, self.do_doc_retrieval, self.do_code_retrieval, self.do_tool_retrieval
        )
        query.parsed_output = query.parsed_output if query.parsed_output else {"origin_query": query.input_query}
        query.parsed_output_list = query.parsed_output_list if query.parsed_output_list else [{"origin_query": query.input_query}]
        input_message = copy.deepcopy(query)

        self.global_memory.append(input_message)
        local_phase_memory.append(input_message)

        # Without any chain output the phase result falls back to the input;
        # previously an empty chain list raised UnboundLocalError at the end.
        output_message = input_message
        for chain in self.chains:
            # A chain whose generator yields nothing contributes an empty memory
            # instead of reusing (or failing on) an unbound variable.
            local_chain_memory = Memory(messages=[])
            # Each chain gets the previous chain's output as input and the
            # accumulated summaries as background.
            for output_message, local_chain_memory in chain.astep(
                input_message, history, background=chain_message, memory_manager=self.memory_manager
            ):
                yield output_message, local_phase_memory + local_chain_memory

            output_message = self.message_utils.inherit_extrainfo(input_message, output_message)
            input_message = output_message
            # NOTE: this section also has known issues (original author's note).
            self.global_memory.extend(local_chain_memory)
            local_phase_memory.extend(local_chain_memory)

            # Optionally summarise the conversation so far with the summary agent.
            if self.do_summary:
                if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose):
                    logger.info(f"{self.conv_summary_agent.role.role_name} input global memory: {local_phase_memory.to_str_messages(content_key='step_content')}")
                # Drain the summary agent's generator; keep only its last message.
                for summary_message in self.conv_summary_agent.astep(
                    query, background=local_phase_memory, memory_manager=self.memory_manager
                ):
                    pass
                summary_message.role_name = chain.chainConfig.chain_name
                summary_message = self.conv_summary_agent.message_utils.parser(summary_message)
                summary_message = self.message_utils.inherit_extrainfo(output_message, summary_message)
                chain_message.append(summary_message)

            message = summary_message or output_message
            yield message, local_phase_memory

        # Chains are never executed for multiple rounds, so keeping each
        # chain's memory as-is is enough.
        for chain in self.chains:
            self.phase_memory.append(chain.global_memory)
        # TODO: local_phase_memory is missing a step that adds the summary.
        message = summary_message or output_message
        message.role_name = self.phase_name
        yield message, local_phase_memory

    def step(self, query: Message, history: Memory = None, reinit_memory=False) -> Tuple[Message, Memory]:
        """Drive ``astep`` to completion and return its final ``(message, memory)`` pair."""
        for message, local_phase_memory in self.astep(query, history=history, reinit_memory=reinit_memory):
            pass
        return message, local_phase_memory

    def pre_print(self, query, history: Memory = None) -> None:
        """Call ``pre_print`` on every chain with an empty background memory.

        Returns None; any output is whatever the chains' ``pre_print`` produce.
        """
        chain_message = Memory(messages=[])
        for chain in self.chains:
            chain.pre_print(query, history, background=chain_message, memory_manager=self.memory_manager)

    def _build_agent(self, agent_config: AgentConfig, task=None, memory=None) -> BaseAgent:
        """Instantiate the agent class named by ``agent_config.role.agent_type``.

        Requires ``self.agent_module`` to have been set (done in ``init_chains``).
        """
        # Per-agent deep copy so each agent's stop words don't leak into the others.
        llm_config = copy.deepcopy(self.llm_config)
        llm_config.stop = agent_config.stop
        agent_cls = getattr(self.agent_module, agent_config.role.agent_type)
        return agent_cls(
            role=agent_config.role,
            prompt_config=agent_config.prompt_config,
            prompt_manager_type=agent_config.prompt_manager_type,
            task=task,
            memory=memory,
            chat_turn=agent_config.chat_turn,
            focus_agents=agent_config.focus_agents,
            focus_message_keys=agent_config.focus_message_keys,
            llm_config=llm_config,
            embed_config=self.embed_config,
            sandbox_server=self.sandbox_server,
            jupyter_work_path=self.jupyter_work_path,
            kb_root_path=self.kb_root_path,
            doc_retrieval=self.doc_retrieval,
            code_retrieval=self.code_retrieval,
            search_retrieval=self.search_retrieval,
            log_verbose=self.log_verbose,
        )

    def init_chains(self, phase_name: str, phase_config: CompletePhaseConfig, base_phase_config, base_chain_config,
                    base_role_config, task=None, memory=None) -> List[BaseChain]:
        """Build the chains (and their agents) declared for ``phase_name``.

        Side effect: overwrites this instance's do_* flags from the phase
        config. ``phase_config`` is accepted but currently unused.

        Raises:
            KeyError: if ``phase_name`` is not in the loaded phase configs.
                Chain and agent names are looked up by subscript and will
                presumably raise KeyError too when missing.
        """
        role_configs = load_role_configs(base_role_config)
        chain_configs = load_chain_configs(base_chain_config)
        phase_configs = load_phase_configs(base_phase_config)

        chains = []
        self.chain_module = importlib.import_module("coagent.connector.chains")
        self.agent_module = importlib.import_module("coagent.connector.agents")

        phase: PhaseConfig = phase_configs.get(phase_name)
        if phase is None:
            # Previously this surfaced as an AttributeError on None below.
            raise KeyError(f"phase_name {phase_name!r} is not defined in the phase configs")

        self.do_summary = phase.do_summary
        self.do_search = phase.do_search
        self.do_code_retrieval = phase.do_code_retrieval
        self.do_doc_retrieval = phase.do_doc_retrieval
        self.do_tool_retrieval = phase.do_tool_retrieval
        logger.info(f"start to init the phase, the phase_name is {phase_name}, it contains these chains such as {phase.chains}")

        for chain_name in phase.chains:
            chain_config: ChainConfig = chain_configs[chain_name]
            logger.info(f"start to init the chain, the chain_name is {chain_name}, it contains these agents such as {chain_config.agents}")

            agents = []
            for agent_name in chain_config.agents:
                agent_config: AgentConfig = role_configs[agent_name]
                base_agent = self._build_agent(agent_config, task=task, memory=memory)
                # A SelectorAgent additionally owns the agents in its group.
                if agent_config.role.agent_type == "SelectorAgent":
                    for group_agent_name in agent_config.group_agents:
                        group_agent_config = role_configs[group_agent_name]
                        base_agent.group_agents.append(
                            self._build_agent(group_agent_config, task=task, memory=memory)
                        )
                agents.append(base_agent)

            chain_instance = BaseChain(
                chain_config,
                agents,
                jupyter_work_path=self.jupyter_work_path,
                sandbox_server=self.sandbox_server,
                embed_config=self.embed_config,
                llm_config=self.llm_config,
                kb_root_path=self.kb_root_path,
                doc_retrieval=self.doc_retrieval,
                code_retrieval=self.code_retrieval,
                search_retrieval=self.search_retrieval,
                log_verbose=self.log_verbose,
            )
            chains.append(chain_instance)

        return chains

    def update(self) -> Memory:
        """Placeholder; currently does nothing and returns None."""
        pass

    def get_memory(self, ) -> Memory:
        """Merge every chain's ``get_memory()`` into a single Memory."""
        return Memory.from_memory_list(
            [chain.get_memory() for chain in self.chains]
        )

    def get_memory_str(self, do_all_memory=True, content_key="role_content") -> str:
        """Render memory as ``"role: content"`` lines joined by newlines.

        Args:
            do_all_memory: True renders ``global_memory``. False renders the
                per-chain ``phase_memory`` merged into one Memory.
            content_key: message field to render as the content.
        """
        # phase_memory is a List[Memory], so it must be merged before rendering
        # (previously this branch called to_tuple_messages on a list).
        memory = self.global_memory if do_all_memory else Memory.from_memory_list(self.phase_memory)
        return "\n".join(": ".join(item) for item in memory.to_tuple_messages(content_key=content_key))

    def get_chains_memory(self, content_key="role_content") -> List[Tuple]:
        """Return ``to_tuple_messages`` for each Memory stored in ``phase_memory``."""
        return [memory.to_tuple_messages(content_key=content_key) for memory in self.phase_memory]

    def get_chains_memory_str(self, content_key="role_content") -> str:
        """Render each chain's memory under its chain name, separated by asterisks."""
        return "************".join([f"{chain.chainConfig.chain_name}\n" + chain.get_memory_str(content_key=content_key) for chain in self.chains])