import importlib
import os
from typing import List, Union, Dict, Any

from loguru import logger

from langchain.embeddings.base import Embeddings
from langchain.agents import Tool
from langchain.llms.base import BaseLLM, LLM

from coagent.retrieval.base_retrieval import IMRertrieval
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
from coagent.connector.phase import BasePhase
from coagent.connector.agents import BaseAgent
from coagent.connector.chains import BaseChain
from coagent.connector.schema import Message, Role, PromptField, ChainConfig
from coagent.tools import toLangchainTools, TOOL_DICT, TOOL_SETS

class AgentFlow:
    """Holds the configuration of a single agent role and builds the
    underlying agent (a ``BaseAgent`` subclass) via ``build_agent``."""

    def __init__(
        self,
        role_name: str,
        agent_type: str,
        role_type: str = "assistant",
        agent_index: int = 0,
        role_prompt: str = "",
        prompt_config: List[Dict[str, Any]] = [],
        prompt_manager_type: str = "PromptManager",
        chat_turn: int = 3,
        focus_agents: List[str] = [],
        focus_messages: List[str] = [],
        embeddings: Embeddings = None,
        llm: BaseLLM = None,
        doc_retrieval: IMRertrieval = None,
        code_retrieval: IMRertrieval = None,
        search_retrieval: IMRertrieval = None,
        **kwargs
    ):
        self.role_type = role_type
        self.role_name = role_name
        self.agent_type = agent_type
        self.role_prompt = role_prompt
        self.agent_index = agent_index

        self.prompt_config = prompt_config
        self.prompt_manager_type = prompt_manager_type

        self.chat_turn = chat_turn
        self.focus_agents = focus_agents
        self.focus_messages = focus_messages

        self.embeddings = embeddings
        self.llm = llm
        self.doc_retrieval = doc_retrieval
        self.code_retrieval = code_retrieval
        self.search_retrieval = search_retrieval
        # self.build_config()
        # self.build_agent()

    def build_config(self, embeddings: Embeddings = None, llm: BaseLLM = None):
        self.llm_config = LLMConfig(model_name="test", llm=self.llm or llm)
        self.embed_config = EmbedConfig(embed_model="test", langchain_embeddings=self.embeddings or embeddings)

    def build_agent(
        self,
        embeddings: Embeddings = None, llm: BaseLLM = None,
        doc_retrieval: IMRertrieval = None,
        code_retrieval: IMRertrieval = None,
        search_retrieval: IMRertrieval = None,
    ):
        # A customized agent can be registered here, simply by overriding
        # start_action_step and end_action_step:
        # class ExtraAgent(BaseAgent):
        #     def start_action_step(self, message: Message) -> Message:
        #         pass
        #
        #     def end_action_step(self, message: Message) -> Message:
        #         pass
        # agent_module = importlib.import_module("coagent.connector.agents")
        # setattr(agent_module, 'extraAgent', ExtraAgent)

        # A customized prompt-assembly strategy can be registered in the same way:
        # class CodeRetrievalPM(PromptManager):
        #     def handle_code_packages(self, **kwargs) -> str:
        #         if 'previous_agent_message' not in kwargs:
        #             return ""
        #         previous_agent_message: Message = kwargs['previous_agent_message']
        #         # temporary workaround: the two agents share the same manager
        #         vertices = previous_agent_message.customed_kargs.get("RelatedVerticesRetrivalRes", {}).get("vertices", [])
        #         return ", ".join([str(v) for v in vertices])
        #
        # prompt_manager_module = importlib.import_module("coagent.connector.prompt_manager")
        # setattr(prompt_manager_module, 'CodeRetrievalPM', CodeRetrievalPM)

        # instantiate the agent
        agent_module = importlib.import_module("coagent.connector.agents")
        baseAgent: BaseAgent = getattr(agent_module, self.agent_type)
        role = Role(
            role_type=self.agent_type, role_name=self.role_name,
            agent_type=self.agent_type, role_prompt=self.role_prompt,
        )

        self.build_config(embeddings, llm)
        self.agent = baseAgent(
            role=role,
            prompt_config=[PromptField(**config) for config in self.prompt_config],
            prompt_manager_type=self.prompt_manager_type,
            chat_turn=self.chat_turn,
            focus_agents=self.focus_agents,
            focus_message_keys=self.focus_messages,
            llm_config=self.llm_config,
            embed_config=self.embed_config,
            doc_retrieval=doc_retrieval or self.doc_retrieval,
            code_retrieval=code_retrieval or self.code_retrieval,
            search_retrieval=search_retrieval or self.search_retrieval,
        )
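

# --- Illustrative sketch (editor's addition, not part of the original module):
# a minimal AgentFlow, assuming `my_llm` is any langchain BaseLLM and
# `my_embeddings` any langchain Embeddings instance you supply yourself.
#
# qa_agent = AgentFlow(
#     role_name="qa_assistant",      # hypothetical role name
#     agent_type="BaseAgent",        # resolved from coagent.connector.agents
#     role_prompt="You are a helpful assistant; answer the user's question.",
#     chat_turn=1,
#     llm=my_llm,
#     embeddings=my_embeddings,
# )
# qa_agent.build_agent()             # materializes qa_agent.agent

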
class ChainFlow:
    """Groups ``AgentFlow`` instances (ordered by ``agent_index``) into one
    chain and builds the corresponding ``BaseChain`` via ``build_chain``."""

    def __init__(
        self,
        chain_name: str,
        chain_index: int = 0,
        agent_flows: List[AgentFlow] = [],
        chat_turn: int = 5,
        do_checker: bool = False,
        embeddings: Embeddings = None,
        llm: BaseLLM = None,
        doc_retrieval: IMRertrieval = None,
        code_retrieval: IMRertrieval = None,
        search_retrieval: IMRertrieval = None,
        # chain_type: str = "BaseChain",
        **kwargs
    ):
        self.agent_flows = sorted(agent_flows, key=lambda x: x.agent_index)
        self.chat_turn = chat_turn
        self.do_checker = do_checker
        self.chain_name = chain_name
        self.chain_index = chain_index
        self.chain_type = "BaseChain"

        self.embeddings = embeddings
        self.llm = llm

        self.doc_retrieval = doc_retrieval
        self.code_retrieval = code_retrieval
        self.search_retrieval = search_retrieval
        # self.build_config()
        # self.build_chain()

    def build_config(self, embeddings: Embeddings = None, llm: BaseLLM = None):
        self.llm_config = LLMConfig(model_name="test", llm=self.llm or llm)
        self.embed_config = EmbedConfig(embed_model="test", langchain_embeddings=self.embeddings or embeddings)

    def build_chain(
        self,
        embeddings: Embeddings = None, llm: BaseLLM = None,
        doc_retrieval: IMRertrieval = None,
        code_retrieval: IMRertrieval = None,
        search_retrieval: IMRertrieval = None,
    ):
        # instantiate the chain
        chain_module = importlib.import_module("coagent.connector.chains")
        baseChain: BaseChain = getattr(chain_module, self.chain_type)

        agent_names = [agent_flow.role_name for agent_flow in self.agent_flows]
        chain_config = ChainConfig(chain_name=self.chain_name, agents=agent_names, do_checker=self.do_checker, chat_turn=self.chat_turn)

        # instantiate the agents
        self.build_config(embeddings, llm)
        for agent_flow in self.agent_flows:
            agent_flow.build_agent(embeddings, llm)

        self.chain = baseChain(
            chain_config,
            [agent_flow.agent for agent_flow in self.agent_flows],
            embed_config=self.embed_config,
            llm_config=self.llm_config,
            doc_retrieval=doc_retrieval or self.doc_retrieval,
            code_retrieval=code_retrieval or self.code_retrieval,
            search_retrieval=search_retrieval or self.search_retrieval,
        )


class PhaseFlow:
    """Groups ``ChainFlow`` instances (ordered by ``chain_index``) into a
    ``BasePhase`` and exposes it as a callable that takes a params dict
    (expected to carry a "query" or "input" key)."""

    def __init__(
        self,
        phase_name: str,
        chain_flows: List[ChainFlow],
        embeddings: Embeddings = None,
        llm: BaseLLM = None,
        tools: List[Tool] = [],
        doc_retrieval: IMRertrieval = None,
        code_retrieval: IMRertrieval = None,
        search_retrieval: IMRertrieval = None,
        **kwargs
    ):
        self.phase_name = phase_name
        self.chain_flows = sorted(chain_flows, key=lambda x: x.chain_index)
        self.phase_type = "BasePhase"
        self.tools = tools

        self.embeddings = embeddings
        self.llm = llm

        self.doc_retrieval = doc_retrieval
        self.code_retrieval = code_retrieval
        self.search_retrieval = search_retrieval
        # self.build_config()
        self.build_phase()
    def __call__(self, params: dict) -> str:

        # tools = toLangchainTools([TOOL_DICT[i] for i in TOOL_SETS if i in TOOL_DICT])
        # query_content = "Please check whether the server 127.0.0.1 had any anomalies at 10 o'clock and help me assess it."
        try:
            logger.info(f"params: {params}")
            query_content = params.get("query") or params.get("input")
            search_type = params.get("search_type")
            query = Message(
                role_name="human", role_type="user", tools=self.tools,
                role_content=query_content, input_query=query_content, origin_query=query_content,
                cb_search_type=search_type,
            )
            # phase.pre_print(query)
            output_message, output_memory = self.phase.step(query)
            output_content = "\n\n".join(output_memory.to_str_messages(return_all=True, content_key="parsed_output_list").split("\n\n")[1:]) or output_message.role_content
            return output_content
        except Exception as e:
            logger.exception(e)
            return f"Error {e}"

    def build_config(self, embeddings: Embeddings = None, llm: BaseLLM = None):
        self.llm_config = LLMConfig(model_name="test", llm=self.llm or llm)
        self.embed_config = EmbedConfig(embed_model="test", langchain_embeddings=self.embeddings or embeddings)

    def build_phase(self, embeddings: Embeddings = None, llm: BaseLLM = None):
        # instantiate the phase
        phase_module = importlib.import_module("coagent.connector.phase")
        basePhase: BasePhase = getattr(phase_module, self.phase_type)

        # instantiate the chains
        self.build_config(self.embeddings or embeddings, self.llm or llm)
        os.environ["log_verbose"] = "2"
        for chain_flow in self.chain_flows:
            chain_flow.build_chain(
                self.embeddings or embeddings, self.llm or llm,
                self.doc_retrieval, self.code_retrieval, self.search_retrieval
            )

        self.phase: BasePhase = basePhase(
            phase_name=self.phase_name,
            chains=[chain_flow.chain for chain_flow in self.chain_flows],
            embed_config=self.embed_config,
            llm_config=self.llm_config,
            doc_retrieval=self.doc_retrieval,
            code_retrieval=self.code_retrieval,
            search_retrieval=self.search_retrieval,
        )
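

# --- Illustrative end-to-end sketch (editor's addition, not part of the
# original module): composing the flows. `qa_agent` is the AgentFlow sketched
# above; `my_llm` / `my_embeddings` are again placeholders for langchain
# BaseLLM / Embeddings instances you provide yourself.
#
# chain = ChainFlow(
#     chain_name="qa_chain",
#     agent_flows=[qa_agent],
#     llm=my_llm, embeddings=my_embeddings,
# )
# phase = PhaseFlow(
#     phase_name="qa_phase",
#     chain_flows=[chain],
#     llm=my_llm, embeddings=my_embeddings,
# )
# answer = phase({"query": "What does the AgentFlow class do?"})
# print(answer)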