256 lines
11 KiB
Python
256 lines
11 KiB
Python
|
import importlib
|
|||
|
from typing import List, Union, Dict, Any
|
|||
|
from loguru import logger
|
|||
|
import os
|
|||
|
from langchain.embeddings.base import Embeddings
|
|||
|
from langchain.agents import Tool
|
|||
|
from langchain.llms.base import BaseLLM, LLM
|
|||
|
|
|||
|
from coagent.retrieval.base_retrieval import IMRertrieval
|
|||
|
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
|
|||
|
from coagent.connector.phase import BasePhase
|
|||
|
from coagent.connector.agents import BaseAgent
|
|||
|
from coagent.connector.chains import BaseChain
|
|||
|
from coagent.connector.schema import Message, Role, PromptField, ChainConfig
|
|||
|
from coagent.tools import toLangchainTools, TOOL_DICT, TOOL_SETS
|
|||
|
|
|||
|
|
|||
|
class AgentFlow:
|
|||
|
def __init__(
|
|||
|
self,
|
|||
|
role_name: str,
|
|||
|
agent_type: str,
|
|||
|
role_type: str = "assistant",
|
|||
|
agent_index: int = 0,
|
|||
|
role_prompt: str = "",
|
|||
|
prompt_config: List[Dict[str, Any]] = [],
|
|||
|
prompt_manager_type: str = "PromptManager",
|
|||
|
chat_turn: int = 3,
|
|||
|
focus_agents: List[str] = [],
|
|||
|
focus_messages: List[str] = [],
|
|||
|
embeddings: Embeddings = None,
|
|||
|
llm: BaseLLM = None,
|
|||
|
doc_retrieval: IMRertrieval = None,
|
|||
|
code_retrieval: IMRertrieval = None,
|
|||
|
search_retrieval: IMRertrieval = None,
|
|||
|
**kwargs
|
|||
|
):
|
|||
|
self.role_type = role_type
|
|||
|
self.role_name = role_name
|
|||
|
self.agent_type = agent_type
|
|||
|
self.role_prompt = role_prompt
|
|||
|
self.agent_index = agent_index
|
|||
|
|
|||
|
self.prompt_config = prompt_config
|
|||
|
self.prompt_manager_type = prompt_manager_type
|
|||
|
|
|||
|
self.chat_turn = chat_turn
|
|||
|
self.focus_agents = focus_agents
|
|||
|
self.focus_messages = focus_messages
|
|||
|
|
|||
|
self.embeddings = embeddings
|
|||
|
self.llm = llm
|
|||
|
self.doc_retrieval = doc_retrieval
|
|||
|
self.code_retrieval = code_retrieval
|
|||
|
self.search_retrieval = search_retrieval
|
|||
|
# self.build_config()
|
|||
|
# self.build_agent()
|
|||
|
|
|||
|
def build_config(self, embeddings: Embeddings = None, llm: BaseLLM = None):
|
|||
|
self.llm_config = LLMConfig(model_name="test", llm=self.llm or llm)
|
|||
|
self.embed_config = EmbedConfig(embed_model="test", langchain_embeddings=self.embeddings or embeddings)
|
|||
|
|
|||
|
def build_agent(self,
|
|||
|
embeddings: Embeddings = None, llm: BaseLLM = None,
|
|||
|
doc_retrieval: IMRertrieval = None,
|
|||
|
code_retrieval: IMRertrieval = None,
|
|||
|
search_retrieval: IMRertrieval = None,
|
|||
|
):
|
|||
|
# 可注册个性化的agent,仅通过start_action和end_action来注册
|
|||
|
# class ExtraAgent(BaseAgent):
|
|||
|
# def start_action_step(self, message: Message) -> Message:
|
|||
|
# pass
|
|||
|
|
|||
|
# def end_action_step(self, message: Message) -> Message:
|
|||
|
# pass
|
|||
|
# agent_module = importlib.import_module("coagent.connector.agents")
|
|||
|
# setattr(agent_module, 'extraAgent', ExtraAgent)
|
|||
|
|
|||
|
# 可注册个性化的prompt组装方式,
|
|||
|
# class CodeRetrievalPM(PromptManager):
|
|||
|
# def handle_code_packages(self, **kwargs) -> str:
|
|||
|
# if 'previous_agent_message' not in kwargs:
|
|||
|
# return ""
|
|||
|
# previous_agent_message: Message = kwargs['previous_agent_message']
|
|||
|
# # 由于两个agent共用了同一个manager,所以临时性处理
|
|||
|
# vertices = previous_agent_message.customed_kargs.get("RelatedVerticesRetrivalRes", {}).get("vertices", [])
|
|||
|
# return ", ".join([str(v) for v in vertices])
|
|||
|
|
|||
|
# prompt_manager_module = importlib.import_module("coagent.connector.prompt_manager")
|
|||
|
# setattr(prompt_manager_module, 'CodeRetrievalPM', CodeRetrievalPM)
|
|||
|
|
|||
|
# agent实例化
|
|||
|
agent_module = importlib.import_module("coagent.connector.agents")
|
|||
|
baseAgent: BaseAgent = getattr(agent_module, self.agent_type)
|
|||
|
role = Role(
|
|||
|
role_type=self.agent_type, role_name=self.role_name,
|
|||
|
agent_type=self.agent_type, role_prompt=self.role_prompt,
|
|||
|
)
|
|||
|
|
|||
|
self.build_config(embeddings, llm)
|
|||
|
self.agent = baseAgent(
|
|||
|
role=role,
|
|||
|
prompt_config = [PromptField(**config) for config in self.prompt_config],
|
|||
|
prompt_manager_type=self.prompt_manager_type,
|
|||
|
chat_turn=self.chat_turn,
|
|||
|
focus_agents=self.focus_agents,
|
|||
|
focus_message_keys=self.focus_messages,
|
|||
|
llm_config=self.llm_config,
|
|||
|
embed_config=self.embed_config,
|
|||
|
doc_retrieval=doc_retrieval or self.doc_retrieval,
|
|||
|
code_retrieval=code_retrieval or self.code_retrieval,
|
|||
|
search_retrieval=search_retrieval or self.search_retrieval,
|
|||
|
)
|
|||
|
|
|||
|
|
|||
|
|
|||
|
class ChainFlow:
|
|||
|
def __init__(
|
|||
|
self,
|
|||
|
chain_name: str,
|
|||
|
chain_index: int = 0,
|
|||
|
agent_flows: List[AgentFlow] = [],
|
|||
|
chat_turn: int = 5,
|
|||
|
do_checker: bool = False,
|
|||
|
embeddings: Embeddings = None,
|
|||
|
llm: BaseLLM = None,
|
|||
|
doc_retrieval: IMRertrieval = None,
|
|||
|
code_retrieval: IMRertrieval = None,
|
|||
|
search_retrieval: IMRertrieval = None,
|
|||
|
# chain_type: str = "BaseChain",
|
|||
|
**kwargs
|
|||
|
):
|
|||
|
self.agent_flows = sorted(agent_flows, key=lambda x:x.agent_index)
|
|||
|
self.chat_turn = chat_turn
|
|||
|
self.do_checker = do_checker
|
|||
|
self.chain_name = chain_name
|
|||
|
self.chain_index = chain_index
|
|||
|
self.chain_type = "BaseChain"
|
|||
|
|
|||
|
self.embeddings = embeddings
|
|||
|
self.llm = llm
|
|||
|
|
|||
|
self.doc_retrieval = doc_retrieval
|
|||
|
self.code_retrieval = code_retrieval
|
|||
|
self.search_retrieval = search_retrieval
|
|||
|
# self.build_config()
|
|||
|
# self.build_chain()
|
|||
|
|
|||
|
def build_config(self, embeddings: Embeddings = None, llm: BaseLLM = None):
|
|||
|
self.llm_config = LLMConfig(model_name="test", llm=self.llm or llm)
|
|||
|
self.embed_config = EmbedConfig(embed_model="test", langchain_embeddings=self.embeddings or embeddings)
|
|||
|
|
|||
|
def build_chain(self,
|
|||
|
embeddings: Embeddings = None, llm: BaseLLM = None,
|
|||
|
doc_retrieval: IMRertrieval = None,
|
|||
|
code_retrieval: IMRertrieval = None,
|
|||
|
search_retrieval: IMRertrieval = None,
|
|||
|
):
|
|||
|
# chain 实例化
|
|||
|
chain_module = importlib.import_module("coagent.connector.chains")
|
|||
|
baseChain: BaseChain = getattr(chain_module, self.chain_type)
|
|||
|
|
|||
|
agent_names = [agent_flow.role_name for agent_flow in self.agent_flows]
|
|||
|
chain_config = ChainConfig(chain_name=self.chain_name, agents=agent_names, do_checker=self.do_checker, chat_turn=self.chat_turn)
|
|||
|
|
|||
|
# agent 实例化
|
|||
|
self.build_config(embeddings, llm)
|
|||
|
for agent_flow in self.agent_flows:
|
|||
|
agent_flow.build_agent(embeddings, llm)
|
|||
|
|
|||
|
self.chain = baseChain(
|
|||
|
chain_config,
|
|||
|
[agent_flow.agent for agent_flow in self.agent_flows],
|
|||
|
embed_config=self.embed_config,
|
|||
|
llm_config=self.llm_config,
|
|||
|
doc_retrieval=doc_retrieval or self.doc_retrieval,
|
|||
|
code_retrieval=code_retrieval or self.code_retrieval,
|
|||
|
search_retrieval=search_retrieval or self.search_retrieval,
|
|||
|
)
|
|||
|
|
|||
|
class PhaseFlow:
|
|||
|
def __init__(
|
|||
|
self,
|
|||
|
phase_name: str,
|
|||
|
chain_flows: List[ChainFlow],
|
|||
|
embeddings: Embeddings = None,
|
|||
|
llm: BaseLLM = None,
|
|||
|
tools: List[Tool] = [],
|
|||
|
doc_retrieval: IMRertrieval = None,
|
|||
|
code_retrieval: IMRertrieval = None,
|
|||
|
search_retrieval: IMRertrieval = None,
|
|||
|
**kwargs
|
|||
|
):
|
|||
|
self.phase_name = phase_name
|
|||
|
self.chain_flows = sorted(chain_flows, key=lambda x:x.chain_index)
|
|||
|
self.phase_type = "BasePhase"
|
|||
|
self.tools = tools
|
|||
|
|
|||
|
self.embeddings = embeddings
|
|||
|
self.llm = llm
|
|||
|
|
|||
|
self.doc_retrieval = doc_retrieval
|
|||
|
self.code_retrieval = code_retrieval
|
|||
|
self.search_retrieval = search_retrieval
|
|||
|
# self.build_config()
|
|||
|
self.build_phase()
|
|||
|
|
|||
|
def __call__(self, params: dict) -> str:
|
|||
|
|
|||
|
# tools = toLangchainTools([TOOL_DICT[i] for i in TOOL_SETS if i in TOOL_DICT])
|
|||
|
# query_content = "帮我确认下127.0.0.1这个服务器的在10点是否存在异常,请帮我判断一下"
|
|||
|
try:
|
|||
|
logger.info(f"params: {params}")
|
|||
|
query_content = params.get("query") or params.get("input")
|
|||
|
search_type = params.get("search_type")
|
|||
|
query = Message(
|
|||
|
role_name="human", role_type="user", tools=self.tools,
|
|||
|
role_content=query_content, input_query=query_content, origin_query=query_content,
|
|||
|
cb_search_type=search_type,
|
|||
|
)
|
|||
|
# phase.pre_print(query)
|
|||
|
output_message, output_memory = self.phase.step(query)
|
|||
|
output_content = "\n\n".join((output_memory.to_str_messages(return_all=True, content_key="parsed_output_list").split("\n\n")[1:])) or output_message.role_content
|
|||
|
return output_content
|
|||
|
except Exception as e:
|
|||
|
logger.exception(e)
|
|||
|
return f"Error {e}"
|
|||
|
|
|||
|
def build_config(self, embeddings: Embeddings = None, llm: BaseLLM = None):
|
|||
|
self.llm_config = LLMConfig(model_name="test", llm=self.llm or llm)
|
|||
|
self.embed_config = EmbedConfig(embed_model="test", langchain_embeddings=self.embeddings or embeddings)
|
|||
|
|
|||
|
def build_phase(self, embeddings: Embeddings = None, llm: BaseLLM = None):
|
|||
|
# phase 实例化
|
|||
|
phase_module = importlib.import_module("coagent.connector.phase")
|
|||
|
basePhase: BasePhase = getattr(phase_module, self.phase_type)
|
|||
|
|
|||
|
# chain 实例化
|
|||
|
self.build_config(self.embeddings or embeddings, self.llm or llm)
|
|||
|
os.environ["log_verbose"] = "2"
|
|||
|
for chain_flow in self.chain_flows:
|
|||
|
chain_flow.build_chain(
|
|||
|
self.embeddings or embeddings, self.llm or llm,
|
|||
|
self.doc_retrieval, self.code_retrieval, self.search_retrieval
|
|||
|
)
|
|||
|
|
|||
|
self.phase: BasePhase = basePhase(
|
|||
|
phase_name=self.phase_name,
|
|||
|
chains=[chain_flow.chain for chain_flow in self.chain_flows],
|
|||
|
embed_config=self.embed_config,
|
|||
|
llm_config=self.llm_config,
|
|||
|
doc_retrieval=self.doc_retrieval,
|
|||
|
code_retrieval=self.code_retrieval,
|
|||
|
search_retrieval=self.search_retrieval
|
|||
|
)
|