# codefuse-chatbot/coagent/connector/antflow/flow.py

import importlib
from typing import List, Union, Dict, Any
from loguru import logger
import os
from langchain.embeddings.base import Embeddings
from langchain.agents import Tool
from langchain.llms.base import BaseLLM, LLM
from coagent.retrieval.base_retrieval import IMRertrieval
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
from coagent.connector.phase import BasePhase
from coagent.connector.agents import BaseAgent
from coagent.connector.chains import BaseChain
from coagent.connector.schema import Message, Role, PromptField, ChainConfig
from coagent.tools import toLangchainTools, TOOL_DICT, TOOL_SETS
class AgentFlow:
    """Declarative description of one agent, turned into a live agent by ``build_agent``.

    The constructor only stores configuration. Nothing is instantiated until
    :meth:`build_agent` is called (``ChainFlow.build_chain`` does this for each
    of its agent flows).

    Resolution rule: for ``embeddings``, ``llm`` and the three retrievals, a value
    configured on this flow wins. The value passed to ``build_agent`` /
    ``build_config`` is only used when this flow's own value is falsy.
    """

    def __init__(
        self,
        role_name: str,
        agent_type: str,
        role_type: str = "assistant",
        agent_index: int = 0,
        role_prompt: str = "",
        prompt_config: Union[List[Dict[str, Any]], None] = None,
        prompt_manager_type: str = "PromptManager",
        chat_turn: int = 3,
        focus_agents: Union[List[str], None] = None,
        focus_messages: Union[List[str], None] = None,
        embeddings: Embeddings = None,
        llm: BaseLLM = None,
        doc_retrieval: IMRertrieval = None,
        code_retrieval: IMRertrieval = None,
        search_retrieval: IMRertrieval = None,
        **kwargs
    ):
        """Record the agent configuration.

        Args:
            role_name: ``Role.role_name`` of the agent. ``ChainFlow`` also uses it
                as the agent's name in its ``ChainConfig``.
            agent_type: name of a class looked up on ``coagent.connector.agents``.
            role_type: ``Role.role_type`` of the agent (default ``"assistant"``).
            agent_index: sort key ``ChainFlow`` uses to order its agents.
            role_prompt: ``Role.role_prompt`` of the agent.
            prompt_config: list of keyword dicts, each turned into a ``PromptField``.
            prompt_manager_type: forwarded to the agent as ``prompt_manager_type``.
            chat_turn: forwarded to the agent as ``chat_turn``.
            focus_agents: forwarded to the agent as ``focus_agents``.
            focus_messages: forwarded to the agent as ``focus_message_keys``.
            embeddings: langchain embeddings wrapped into an ``EmbedConfig``.
            llm: langchain LLM wrapped into an ``LLMConfig``.
            doc_retrieval, code_retrieval, search_retrieval: retrievers handed
                to the agent.
            **kwargs: accepted and ignored (presumably so whole config dicts can
                be splatted in -- TODO confirm).
        """
        self.role_type = role_type
        self.role_name = role_name
        self.agent_type = agent_type
        self.role_prompt = role_prompt
        self.agent_index = agent_index
        # ``None`` defaults give every instance its own list. The former ``[]``
        # defaults were a single list object shared by all instances.
        self.prompt_config = prompt_config if prompt_config is not None else []
        self.prompt_manager_type = prompt_manager_type
        self.chat_turn = chat_turn
        self.focus_agents = focus_agents if focus_agents is not None else []
        self.focus_messages = focus_messages if focus_messages is not None else []
        self.embeddings = embeddings
        self.llm = llm
        self.doc_retrieval = doc_retrieval
        self.code_retrieval = code_retrieval
        self.search_retrieval = search_retrieval
        # Building is deferred to build_config() / build_agent().

    def build_config(self, embeddings: Embeddings = None, llm: BaseLLM = None):
        """Set ``self.llm_config`` and ``self.embed_config``.

        This flow's own ``llm`` / ``embeddings`` take precedence and the arguments
        are fallbacks. The model names are fixed to ``"test"`` (presumably
        placeholders, because the langchain objects are supplied directly --
        TODO confirm).
        """
        self.llm_config = LLMConfig(model_name="test", llm=self.llm or llm)
        self.embed_config = EmbedConfig(embed_model="test", langchain_embeddings=self.embeddings or embeddings)

    def build_agent(
        self,
        embeddings: Embeddings = None,
        llm: BaseLLM = None,
        doc_retrieval: IMRertrieval = None,
        code_retrieval: IMRertrieval = None,
        search_retrieval: IMRertrieval = None,
    ):
        """Instantiate the agent and store it on ``self.agent``.

        Every argument is a fallback used only when the corresponding value is
        not configured on this flow. This matches how ``build_config`` already
        treats ``embeddings`` / ``llm``.

        Raises:
            AttributeError: if ``agent_type`` is not a name defined on
                ``coagent.connector.agents``.
        """
        # Extension point: a custom agent can be made available to
        # ``agent_type`` by registering it on the agents module, e.g.
        #
        #   class ExtraAgent(BaseAgent):
        #       def start_action_step(self, message: Message) -> Message: ...
        #       def end_action_step(self, message: Message) -> Message: ...
        #   agent_module = importlib.import_module("coagent.connector.agents")
        #   setattr(agent_module, "extraAgent", ExtraAgent)
        #
        # A custom prompt assembly can likewise be registered on
        # ``coagent.connector.prompt_manager`` (presumably then selected via
        # ``prompt_manager_type`` -- TODO confirm), e.g.
        #
        #   class CodeRetrievalPM(PromptManager):
        #       def handle_code_packages(self, **kwargs) -> str:
        #           if "previous_agent_message" not in kwargs:
        #               return ""
        #           previous_agent_message: Message = kwargs["previous_agent_message"]
        #           # Temporary workaround: two agents share the same manager.
        #           vertices = previous_agent_message.customed_kargs.get(
        #               "RelatedVerticesRetrivalRes", {}).get("vertices", [])
        #           return ", ".join([str(v) for v in vertices])
        #   prompt_manager_module = importlib.import_module("coagent.connector.prompt_manager")
        #   setattr(prompt_manager_module, "CodeRetrievalPM", CodeRetrievalPM)

        # Instantiate the agent class named by ``agent_type``.
        agent_module = importlib.import_module("coagent.connector.agents")
        agent_cls = getattr(agent_module, self.agent_type)
        role = Role(
            # Previously ``self.agent_type`` was passed here by mistake, which
            # silently ignored the ``role_type`` constructor argument.
            role_type=self.role_type,
            role_name=self.role_name,
            agent_type=self.agent_type,
            role_prompt=self.role_prompt,
        )
        self.build_config(embeddings, llm)
        self.agent = agent_cls(
            role=role,
            prompt_config=[PromptField(**config) for config in self.prompt_config],
            prompt_manager_type=self.prompt_manager_type,
            chat_turn=self.chat_turn,
            focus_agents=self.focus_agents,
            focus_message_keys=self.focus_messages,
            llm_config=self.llm_config,
            embed_config=self.embed_config,
            doc_retrieval=self.doc_retrieval or doc_retrieval,
            code_retrieval=self.code_retrieval or code_retrieval,
            search_retrieval=self.search_retrieval or search_retrieval,
        )
class ChainFlow:
    """Declarative description of a chain of ``AgentFlow`` objects.

    ``build_chain`` builds every agent and then the chain object. The chain
    class is fixed to ``"BaseChain"``.

    Resolution rule for ``embeddings``, ``llm`` and the three retrievals: an
    agent's own value wins, then this chain's own value, then the value passed
    to ``build_chain`` (e.g. by ``PhaseFlow``).
    """

    def __init__(
        self,
        chain_name: str,
        chain_index: int = 0,
        agent_flows: Union[List[AgentFlow], None] = None,
        chat_turn: int = 5,
        do_checker: bool = False,
        embeddings: Embeddings = None,
        llm: BaseLLM = None,
        doc_retrieval: IMRertrieval = None,
        code_retrieval: IMRertrieval = None,
        search_retrieval: IMRertrieval = None,
        **kwargs
    ):
        """Record the chain configuration.

        Args:
            chain_name: ``ChainConfig.chain_name``.
            chain_index: sort key ``PhaseFlow`` uses to order its chains.
            agent_flows: agents of the chain. They are stored sorted by
                ``agent_index``, and that order is used both for the
                ``ChainConfig`` agent names and for the agent list given to the
                chain.
            chat_turn: ``ChainConfig.chat_turn``.
            do_checker: ``ChainConfig.do_checker``.
            embeddings, llm: chain-level langchain objects, also used as
                fallbacks for agents that do not set their own.
            doc_retrieval, code_retrieval, search_retrieval: chain-level
                retrievers, also used as fallbacks for agents.
            **kwargs: accepted and ignored (presumably so whole config dicts can
                be splatted in -- TODO confirm).
        """
        self.agent_flows = sorted(agent_flows or [], key=lambda flow: flow.agent_index)
        self.chat_turn = chat_turn
        self.do_checker = do_checker
        self.chain_name = chain_name
        self.chain_index = chain_index
        # The chain class is not configurable yet (a ``chain_type`` parameter
        # was left commented out in the signature).
        self.chain_type = "BaseChain"
        self.embeddings = embeddings
        self.llm = llm
        self.doc_retrieval = doc_retrieval
        self.code_retrieval = code_retrieval
        self.search_retrieval = search_retrieval
        # Building is deferred to build_config() / build_chain().

    def build_config(self, embeddings: Embeddings = None, llm: BaseLLM = None):
        """Set ``self.llm_config`` and ``self.embed_config``.

        This chain's own ``llm`` / ``embeddings`` take precedence and the
        arguments are fallbacks. The model names are fixed to ``"test"``
        (presumably placeholders -- TODO confirm).
        """
        self.llm_config = LLMConfig(model_name="test", llm=self.llm or llm)
        self.embed_config = EmbedConfig(embed_model="test", langchain_embeddings=self.embeddings or embeddings)

    def build_chain(
        self,
        embeddings: Embeddings = None,
        llm: BaseLLM = None,
        doc_retrieval: IMRertrieval = None,
        code_retrieval: IMRertrieval = None,
        search_retrieval: IMRertrieval = None,
    ):
        """Build all agents, then the chain, and store it on ``self.chain``.

        The arguments are fallbacks for values not configured on this chain.
        The resolved chain-level values are in turn fallbacks for each agent.

        Raises:
            AttributeError: if ``chain_type`` is not a name defined on
                ``coagent.connector.chains``.
        """
        chain_module = importlib.import_module("coagent.connector.chains")
        chain_cls = getattr(chain_module, self.chain_type)

        # Resolve this chain's resources: own configuration first, then the
        # caller's fallbacks. Previously the raw arguments were sent to the
        # agents, so a chain-level llm/embeddings never reached them.
        embeddings = self.embeddings or embeddings
        llm = self.llm or llm
        doc_retrieval = self.doc_retrieval or doc_retrieval
        code_retrieval = self.code_retrieval or code_retrieval
        search_retrieval = self.search_retrieval or search_retrieval

        agent_names = [flow.role_name for flow in self.agent_flows]
        chain_config = ChainConfig(
            chain_name=self.chain_name,
            agents=agent_names,
            do_checker=self.do_checker,
            chat_turn=self.chat_turn,
        )

        self.build_config(embeddings, llm)
        for agent_flow in self.agent_flows:
            # Agent-specific retrievers win. Chain-level ones fill the gaps;
            # previously retrievers were not forwarded to agents at all.
            # ``build_agent`` itself prefers the agent's own llm/embeddings.
            agent_flow.build_agent(
                embeddings,
                llm,
                agent_flow.doc_retrieval or doc_retrieval,
                agent_flow.code_retrieval or code_retrieval,
                agent_flow.search_retrieval or search_retrieval,
            )
        self.chain = chain_cls(
            chain_config,
            [agent_flow.agent for agent_flow in self.agent_flows],
            embed_config=self.embed_config,
            llm_config=self.llm_config,
            doc_retrieval=doc_retrieval,
            code_retrieval=code_retrieval,
            search_retrieval=search_retrieval,
        )
class PhaseFlow:
    """Top-level flow that builds its ``ChainFlow`` objects into a phase and runs queries.

    The phase is built eagerly in ``__init__``. Calling the instance with a
    params mapping runs one query through ``self.phase.step``.
    """

    def __init__(
        self,
        phase_name: str,
        chain_flows: List[ChainFlow],
        embeddings: Embeddings = None,
        llm: BaseLLM = None,
        tools: Union[List[Tool], None] = None,
        doc_retrieval: IMRertrieval = None,
        code_retrieval: IMRertrieval = None,
        search_retrieval: IMRertrieval = None,
        **kwargs
    ):
        """Record the phase configuration and build the phase immediately.

        Args:
            phase_name: name given to the phase object.
            chain_flows: chains of the phase, stored sorted by ``chain_index``.
            embeddings, llm: phase-level fallbacks for chains/agents that do
                not set their own.
            tools: langchain tools attached to every query ``Message``.
            doc_retrieval, code_retrieval, search_retrieval: retrievers passed
                to every chain build and to the phase.
            **kwargs: accepted and ignored (presumably so whole config dicts can
                be splatted in -- TODO confirm).
        """
        self.phase_name = phase_name
        self.chain_flows = sorted(chain_flows, key=lambda flow: flow.chain_index)
        self.phase_type = "BasePhase"
        # ``None`` default gives every instance its own list. The former ``[]``
        # default was one list shared by all instances and all their Messages.
        self.tools = tools if tools is not None else []
        self.embeddings = embeddings
        self.llm = llm
        self.doc_retrieval = doc_retrieval
        self.code_retrieval = code_retrieval
        self.search_retrieval = search_retrieval
        self.build_phase()

    def __call__(self, params: dict) -> str:
        """Run one query through the phase and return the text result.

        Args:
            params: mapping read for ``"query"`` (falling back to ``"input"``)
                and ``"search_type"``.

        Returns:
            The output memory's ``parsed_output_list`` dump with its first
            blank-line-separated segment removed. If that is empty, the final
            message's ``role_content`` is returned instead. On any exception
            the error is logged and ``"Error <exception>"`` is returned rather
            than raised.
        """
        try:
            logger.info(f"params: {params}")
            query_content = params.get("query") or params.get("input")
            search_type = params.get("search_type")
            query = Message(
                role_name="human", role_type="user", tools=self.tools,
                role_content=query_content, input_query=query_content, origin_query=query_content,
                cb_search_type=search_type,
            )
            output_message, output_memory = self.phase.step(query)
            memory_text = output_memory.to_str_messages(return_all=True, content_key="parsed_output_list")
            # Skip the first segment (presumably the echoed user query --
            # TODO confirm) and keep the rest.
            output_content = "\n\n".join(memory_text.split("\n\n")[1:]) or output_message.role_content
            return output_content
        except Exception as e:
            # Errors are reported to the caller as text instead of propagating.
            logger.exception(e)
            return f"Error {e}"

    def build_config(self, embeddings: Embeddings = None, llm: BaseLLM = None):
        """Set ``self.llm_config`` and ``self.embed_config``.

        This phase's own ``llm`` / ``embeddings`` take precedence and the
        arguments are fallbacks. The model names are fixed to ``"test"``
        (presumably placeholders -- TODO confirm).
        """
        self.llm_config = LLMConfig(model_name="test", llm=self.llm or llm)
        self.embed_config = EmbedConfig(embed_model="test", langchain_embeddings=self.embeddings or embeddings)

    def build_phase(self, embeddings: Embeddings = None, llm: BaseLLM = None):
        """Build every chain and then the phase, and store it on ``self.phase``.

        Side effect: sets the process-wide environment variable
        ``log_verbose`` to ``"2"`` (presumably read by coagent to control log
        verbosity -- TODO confirm).

        Raises:
            AttributeError: if ``phase_type`` is not a name defined on
                ``coagent.connector.phase``.
        """
        phase_module = importlib.import_module("coagent.connector.phase")
        phase_cls = getattr(phase_module, self.phase_type)

        # Phase-level resources: own configuration first, then the arguments.
        embeddings = self.embeddings or embeddings
        llm = self.llm or llm
        self.build_config(embeddings, llm)
        os.environ["log_verbose"] = "2"
        for chain_flow in self.chain_flows:
            chain_flow.build_chain(
                embeddings, llm,
                self.doc_retrieval, self.code_retrieval, self.search_retrieval,
            )
        self.phase: BasePhase = phase_cls(
            phase_name=self.phase_name,
            chains=[chain_flow.chain for chain_flow in self.chain_flows],
            embed_config=self.embed_config,
            llm_config=self.llm_config,
            doc_retrieval=self.doc_retrieval,
            code_retrieval=self.code_retrieval,
            search_retrieval=self.search_retrieval,
        )