2023-11-07 19:44:47 +08:00
|
|
|
|
from typing import List, Union, Dict, Tuple
|
|
|
|
|
import os
|
|
|
|
|
import json
|
|
|
|
|
import importlib
|
|
|
|
|
import copy
|
|
|
|
|
from loguru import logger
|
|
|
|
|
|
2023-12-26 11:41:53 +08:00
|
|
|
|
from dev_opsgpt.connector.agents import BaseAgent, SelectorAgent
|
2023-11-07 19:44:47 +08:00
|
|
|
|
from dev_opsgpt.connector.chains import BaseChain
|
|
|
|
|
from dev_opsgpt.tools.base_tool import BaseTools, Tool
|
2023-12-07 20:17:21 +08:00
|
|
|
|
|
|
|
|
|
from dev_opsgpt.connector.schema import (
|
2023-12-26 11:41:53 +08:00
|
|
|
|
Memory, Task, Env, Role, Message, Doc, AgentConfig, ChainConfig, PhaseConfig, CodeDoc,
|
2023-11-07 19:44:47 +08:00
|
|
|
|
load_chain_configs, load_phase_configs, load_role_configs
|
|
|
|
|
)
|
|
|
|
|
from dev_opsgpt.connector.configs import AGETN_CONFIGS, CHAIN_CONFIGS, PHASE_CONFIGS
|
2023-12-07 20:17:21 +08:00
|
|
|
|
from dev_opsgpt.connector.message_process import MessageUtils
|
2023-11-07 19:44:47 +08:00
|
|
|
|
|
|
|
|
|
# Preload the default role / chain / phase configurations once at import time
# so BasePhase defaults and other module-level consumers share a single
# parsed copy. (The "AGETN" spelling is a typo inherited from the configs module.)
role_configs = load_role_configs(AGETN_CONFIGS)
chain_configs = load_chain_configs(CHAIN_CONFIGS)
phase_configs = load_phase_configs(PHASE_CONFIGS)

# Absolute directory of this source file — NOTE(review): unused in this
# module as visible here; presumably kept for resource lookups. TODO confirm.
CUR_DIR = os.path.dirname(os.path.abspath(__file__))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BasePhase:
    """A phase of execution: an ordered pipeline of chains (each chain an
    ordered group of agents), plus an optional conversation-summary agent
    and several shared memories that accumulate messages across the run.
    """

    def __init__(
            self,
            phase_name: str,
            task: Task = None,
            do_summary: bool = False,
            do_search: bool = False,
            do_doc_retrieval: bool = False,
            do_code_retrieval: bool = False,
            do_tool_retrieval: bool = False,
            phase_config: Union[dict, str] = PHASE_CONFIGS,
            chain_config: Union[dict, str] = CHAIN_CONFIGS,
            role_config: Union[dict, str] = AGETN_CONFIGS,
            ) -> None:
        """Build the phase's chains and supporting agents from config.

        Args:
            phase_name: key into the phase config naming this phase.
            task: optional Task forwarded to every agent.
            do_summary: if True, astep runs the conv-summary agent after
                each chain and yields its summary instead of the raw output.
            do_search / do_doc_retrieval / do_code_retrieval /
            do_tool_retrieval: enrichment flags forwarded to
                MessageUtils.get_extrainfo_step at the start of astep.
            phase_config / chain_config / role_config: config mappings (or
                paths — NOTE(review): the Union[dict, str] hint suggests a
                path is accepted; the load_* helpers must handle it. TODO confirm.)
        """
        # Dedicated agent used to summarize each chain's conversation when
        # do_summary is on (see astep); always built from the module-level
        # "conv_summary" role config, not from the role_config argument.
        self.conv_summary_agent = BaseAgent(role=role_configs["conv_summary"].role,
                              task = None,
                              memory = None,
                              do_search = role_configs["conv_summary"].do_search,
                              do_doc_retrieval = role_configs["conv_summary"].do_doc_retrieval,
                              do_tool_retrieval = role_configs["conv_summary"].do_tool_retrieval,
                              do_filter=False, do_use_self_memory=False)
        # Instantiate every chain declared for this phase.
        self.chains: List[BaseChain] = self.init_chains(
            phase_name,
            task=task,
            memory=None,
            phase_config = phase_config,
            chain_config = chain_config,
            role_config = role_config,
            )
        self.message_utils = MessageUtils()
        self.phase_name = phase_name
        self.do_summary = do_summary
        self.do_search = do_search
        self.do_code_retrieval = do_code_retrieval
        self.do_doc_retrieval = do_doc_retrieval
        self.do_tool_retrieval = do_tool_retrieval
        # Accumulates every input/output message across all chains of a run.
        self.global_memory = Memory(messages=[])
        # self.chain_message = Memory([])
        # One Memory per chain, appended at the end of astep.
        self.phase_memory: List[Memory] = []
        # memory_pool dont have specific order
        self.memory_pool = Memory(messages=[])
|
2023-11-07 19:44:47 +08:00
|
|
|
|
|
2023-12-07 20:17:21 +08:00
|
|
|
|
    def astep(self, query: Message, history: Memory = None) -> Tuple[Message, Memory]:
        """Run the phase as a generator, streaming intermediate results.

        NOTE(review): despite the Tuple return annotation this is a
        generator — it *yields* (message, local_phase_memory) pairs after
        every chain step and once more at the end; the annotation should
        arguably be a Generator/Iterator type.

        Flow: enrich the query with the configured retrieval/search extras,
        feed it to each chain in order (each chain's output becomes the next
        chain's input), optionally summarize each chain's conversation, and
        finally yield the last (or summary) message tagged with the phase name.
        """
        summary_message = None
        chain_message = Memory(messages=[])
        local_phase_memory = Memory(messages=[])
        # Enrich the query via do_search / do_doc_retrieval / do_code_retrieval / do_tool_retrieval.
        query = self.message_utils.get_extrainfo_step(query, self.do_search, self.do_doc_retrieval, self.do_code_retrieval, self.do_tool_retrieval)
        # Deep copy so downstream mutation does not alter the caller's query.
        input_message = copy.deepcopy(query)

        self.global_memory.append(input_message)
        local_phase_memory.append(input_message)
        for chain in self.chains:
            # chain can supply background and query to next chain
            for output_message, local_chain_memory in chain.astep(input_message, history, background=chain_message, memory_pool=self.memory_pool):
                # logger.debug(f"local_memory: {local_memory + chain_memory}")
                yield output_message, local_phase_memory + local_chain_memory

            # Carry retrieval docs/figures from the input over to the output,
            # then feed the output into the next chain.
            output_message = self.message_utils.inherit_extrainfo(input_message, output_message)
            input_message = output_message
            logger.info(f"{chain.chainConfig.chain_name} phase_step: {output_message.role_content}")
            # NOTE(review): original author flagged this section as problematic ("这一段也有问题").
            self.global_memory.extend(local_chain_memory)
            local_phase_memory.extend(local_chain_memory)

            # whether to use summary_llm
            if self.do_summary:
                logger.info(f"{self.conv_summary_agent.role.role_name} input global memory: {local_phase_memory.to_str_messages(content_key='step_content')}")
                # Drain the summary agent's generator; keep only its final message.
                for summary_message in self.conv_summary_agent.arun(query, background=local_phase_memory, memory_pool=self.memory_pool):
                    pass
                # summary_message = Message(**summary_message)
                # NOTE(review): if arun yields nothing, summary_message stays
                # None and the attribute access below would fail — TODO confirm
                # arun always yields at least once.
                summary_message.role_name = chain.chainConfig.chain_name
                summary_message = self.conv_summary_agent.message_utils.parser(summary_message)
                summary_message = self.conv_summary_agent.message_utils.filter(summary_message)
                summary_message = self.message_utils.inherit_extrainfo(output_message, summary_message)
                chain_message.append(summary_message)

            # Prefer the summary (when produced) over the raw chain output.
            message = summary_message or output_message
            yield message, local_phase_memory

        # Chains do not run multiple rounds, so just keeping each chain's memory is enough.
        for chain in self.chains:
            self.phase_memory.append(chain.global_memory)
        # TODO: local_memory lacks the step that appends the summary.
        message = summary_message or output_message
        message.role_name = self.phase_name
        yield message, local_phase_memory
|
|
|
|
|
|
|
|
|
|
def step(self, query: Message, history: Memory = None) -> Tuple[Message, Memory]:
|
|
|
|
|
for message, local_phase_memory in self.astep(query, history=history):
|
|
|
|
|
pass
|
|
|
|
|
return message, local_phase_memory
|
2023-11-07 19:44:47 +08:00
|
|
|
|
|
|
|
|
|
def init_chains(self, phase_name, phase_config, chain_config,
|
|
|
|
|
role_config, task=None, memory=None) -> List[BaseChain]:
|
|
|
|
|
# load config
|
|
|
|
|
role_configs = load_role_configs(role_config)
|
|
|
|
|
chain_configs = load_chain_configs(chain_config)
|
|
|
|
|
phase_configs = load_phase_configs(phase_config)
|
|
|
|
|
|
|
|
|
|
chains = []
|
|
|
|
|
self.chain_module = importlib.import_module("dev_opsgpt.connector.chains")
|
|
|
|
|
self.agent_module = importlib.import_module("dev_opsgpt.connector.agents")
|
2023-12-07 20:17:21 +08:00
|
|
|
|
|
2023-11-07 19:44:47 +08:00
|
|
|
|
phase = phase_configs.get(phase_name)
|
2023-12-07 20:17:21 +08:00
|
|
|
|
logger.info(f"start to init the phase, the phase_name is {phase_name}, it contains these chains such as {phase.chains}")
|
|
|
|
|
|
2023-11-07 19:44:47 +08:00
|
|
|
|
for chain_name in phase.chains:
|
2023-12-07 20:17:21 +08:00
|
|
|
|
# logger.debug(f"{chain_configs.keys()}")
|
2023-11-07 19:44:47 +08:00
|
|
|
|
chain_config = chain_configs[chain_name]
|
2023-12-07 20:17:21 +08:00
|
|
|
|
logger.info(f"start to init the chain, the chain_name is {chain_name}, it contains these agents such as {chain_config.agents}")
|
|
|
|
|
|
|
|
|
|
agents = []
|
|
|
|
|
for agent_name in chain_config.agents:
|
|
|
|
|
agent_config = role_configs[agent_name]
|
|
|
|
|
baseAgent: BaseAgent = getattr(self.agent_module, agent_config.role.agent_type)
|
|
|
|
|
base_agent = baseAgent(
|
|
|
|
|
agent_config.role,
|
2023-11-07 19:44:47 +08:00
|
|
|
|
task = task,
|
|
|
|
|
memory = memory,
|
2023-12-07 20:17:21 +08:00
|
|
|
|
chat_turn=agent_config.chat_turn,
|
|
|
|
|
do_search = agent_config.do_search,
|
|
|
|
|
do_doc_retrieval = agent_config.do_doc_retrieval,
|
|
|
|
|
do_tool_retrieval = agent_config.do_tool_retrieval,
|
|
|
|
|
stop= agent_config.stop,
|
|
|
|
|
focus_agents=agent_config.focus_agents,
|
|
|
|
|
focus_message_keys=agent_config.focus_message_keys,
|
2023-11-07 19:44:47 +08:00
|
|
|
|
)
|
2023-12-26 11:41:53 +08:00
|
|
|
|
|
|
|
|
|
if agent_config.role.agent_type == "SelectorAgent":
|
|
|
|
|
for group_agent_name in agent_config.group_agents:
|
|
|
|
|
group_agent_config = role_configs[group_agent_name]
|
|
|
|
|
baseAgent: BaseAgent = getattr(self.agent_module, group_agent_config.role.agent_type)
|
|
|
|
|
group_base_agent = baseAgent(
|
|
|
|
|
group_agent_config.role,
|
|
|
|
|
task = task,
|
|
|
|
|
memory = memory,
|
|
|
|
|
chat_turn=group_agent_config.chat_turn,
|
|
|
|
|
do_search = group_agent_config.do_search,
|
|
|
|
|
do_doc_retrieval = group_agent_config.do_doc_retrieval,
|
|
|
|
|
do_tool_retrieval = group_agent_config.do_tool_retrieval,
|
|
|
|
|
stop= group_agent_config.stop,
|
|
|
|
|
focus_agents=group_agent_config.focus_agents,
|
|
|
|
|
focus_message_keys=group_agent_config.focus_message_keys,
|
|
|
|
|
)
|
|
|
|
|
base_agent.group_agents.append(group_base_agent)
|
|
|
|
|
|
2023-12-07 20:17:21 +08:00
|
|
|
|
agents.append(base_agent)
|
|
|
|
|
|
2023-11-07 19:44:47 +08:00
|
|
|
|
chain_instance = BaseChain(
|
|
|
|
|
chain_config, agents, chain_config.chat_turn,
|
|
|
|
|
do_checker=chain_configs[chain_name].do_checker,
|
2023-12-26 11:41:53 +08:00
|
|
|
|
)
|
2023-11-07 19:44:47 +08:00
|
|
|
|
chains.append(chain_instance)
|
|
|
|
|
|
|
|
|
|
return chains
|
|
|
|
|
|
2023-12-07 20:17:21 +08:00
|
|
|
|
# def get_extrainfo_step(self, input_message):
|
|
|
|
|
# if self.do_doc_retrieval:
|
|
|
|
|
# input_message = self.get_doc_retrieval(input_message)
|
2023-11-07 19:44:47 +08:00
|
|
|
|
|
2023-12-07 20:17:21 +08:00
|
|
|
|
# # logger.debug(F"self.do_code_retrieval: {self.do_code_retrieval}")
|
|
|
|
|
# if self.do_code_retrieval:
|
|
|
|
|
# input_message = self.get_code_retrieval(input_message)
|
2023-11-07 19:44:47 +08:00
|
|
|
|
|
2023-12-07 20:17:21 +08:00
|
|
|
|
# if self.do_search:
|
|
|
|
|
# input_message = self.get_search_retrieval(input_message)
|
2023-11-07 19:44:47 +08:00
|
|
|
|
|
2023-12-07 20:17:21 +08:00
|
|
|
|
# return input_message
|
2023-11-07 19:44:47 +08:00
|
|
|
|
|
2023-12-07 20:17:21 +08:00
|
|
|
|
# def inherit_extrainfo(self, input_message: Message, output_message: Message):
|
|
|
|
|
# output_message.db_docs = input_message.db_docs
|
|
|
|
|
# output_message.search_docs = input_message.search_docs
|
|
|
|
|
# output_message.code_docs = input_message.code_docs
|
|
|
|
|
# output_message.figures.update(input_message.figures)
|
|
|
|
|
# output_message.origin_query = input_message.origin_query
|
|
|
|
|
# return output_message
|
2023-11-07 19:44:47 +08:00
|
|
|
|
|
2023-12-07 20:17:21 +08:00
|
|
|
|
# def get_search_retrieval(self, message: Message,) -> Message:
|
|
|
|
|
# SEARCH_ENGINES = {"duckduckgo": DDGSTool}
|
|
|
|
|
# search_docs = []
|
|
|
|
|
# for idx, doc in enumerate(SEARCH_ENGINES["duckduckgo"].run(message.role_content, 3)):
|
|
|
|
|
# doc.update({"index": idx})
|
|
|
|
|
# search_docs.append(Doc(**doc))
|
|
|
|
|
# message.search_docs = search_docs
|
|
|
|
|
# return message
|
2023-11-07 19:44:47 +08:00
|
|
|
|
|
2023-12-07 20:17:21 +08:00
|
|
|
|
# def get_doc_retrieval(self, message: Message) -> Message:
|
|
|
|
|
# query = message.role_content
|
|
|
|
|
# knowledge_basename = message.doc_engine_name
|
|
|
|
|
# top_k = message.top_k
|
|
|
|
|
# score_threshold = message.score_threshold
|
|
|
|
|
# if knowledge_basename:
|
|
|
|
|
# docs = DocRetrieval.run(query, knowledge_basename, top_k, score_threshold)
|
|
|
|
|
# message.db_docs = [Doc(**doc) for doc in docs]
|
|
|
|
|
# return message
|
2023-11-07 19:44:47 +08:00
|
|
|
|
|
2023-12-07 20:17:21 +08:00
|
|
|
|
# def get_code_retrieval(self, message: Message) -> Message:
|
|
|
|
|
# # DocRetrieval.run("langchain是什么", "DSADSAD")
|
|
|
|
|
# query = message.input_query
|
|
|
|
|
# code_engine_name = message.code_engine_name
|
|
|
|
|
# history_node_list = message.history_node_list
|
|
|
|
|
# code_docs = CodeRetrieval.run(code_engine_name, query, code_limit=message.top_k, history_node_list=history_node_list)
|
|
|
|
|
# message.code_docs = [CodeDoc(**doc) for doc in code_docs]
|
|
|
|
|
# return message
|
|
|
|
|
|
|
|
|
|
# def get_tool_retrieval(self, message: Message) -> Message:
|
|
|
|
|
# return message
|
2023-11-07 19:44:47 +08:00
|
|
|
|
|
|
|
|
|
    def update(self) -> Memory:
        # Stub: not implemented yet. NOTE(review): currently returns None
        # despite the Memory annotation — callers must not rely on the result.
        pass
|
|
|
|
|
|
|
|
|
|
def get_memory(self, ) -> Memory:
|
|
|
|
|
return Memory.from_memory_list(
|
|
|
|
|
[chain.get_memory() for chain in self.chains]
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def get_memory_str(self, do_all_memory=True, content_key="role_content") -> str:
|
2023-12-07 20:17:21 +08:00
|
|
|
|
memory = self.global_memory if do_all_memory else self.phase_memory
|
2023-11-07 19:44:47 +08:00
|
|
|
|
return "\n".join([": ".join(i) for i in memory.to_tuple_messages(content_key=content_key)])
|
|
|
|
|
|
|
|
|
|
def get_chains_memory(self, content_key="role_content") -> List[Tuple]:
|
|
|
|
|
return [memory.to_tuple_messages(content_key=content_key) for memory in self.phase_memory]
|
|
|
|
|
|
|
|
|
|
def get_chains_memory_str(self, content_key="role_content") -> str:
|
|
|
|
|
return "************".join([f"{chain.chainConfig.chain_name}\n" + chain.get_memory_str(content_key=content_key) for chain in self.chains])
|