# codefuse-chatbot/tests/chains_test.py
#
# NOTE: the code viewer this file was copied from flagged it as containing
# Unicode characters that might be confused with other characters. The
# non-ASCII text in this file is the Chinese prompt text used as test input.

# Module prologue: put the parent of this file's directory on sys.path, then
# import the project packages. The order matters: the sys.path tweak has to run
# before the `dev_opsgpt` / `configs` imports below.
# NOTE(review): `requests` is not referenced anywhere in the visible part of
# this file.
import os, sys, requests
# Two dirname() calls on this file's absolute path: the directory that contains
# this file's directory.
src_dir = os.path.join(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
# Presumably this is what lets `dev_opsgpt` and `configs` resolve when the
# script is run directly -- confirm against the repository layout.
sys.path.append(src_dir)
# NOTE(review): get_tool_schema, DDGSTool and DocRetrieval are not used by the
# live code visible in this file.
from dev_opsgpt.tools import (
toLangchainTools, get_tool_schema, DDGSTool, DocRetrieval,
TOOL_DICT, TOOL_SETS
)
# Wildcard import: every public name of configs.model_config lands in this
# module's namespace.
from configs.model_config import *
from dev_opsgpt.connector.phase import BasePhase
from dev_opsgpt.connector.agents import BaseAgent
from dev_opsgpt.connector.chains import BaseChain
from dev_opsgpt.connector.connector_schema import (
Message, load_role_configs, load_phase_configs, load_chain_configs
)
# `AGETN_CONFIGS` (sic) is the spelling this import relies on; do not "fix" it
# here without renaming it in dev_opsgpt.connector.configs as well.
from dev_opsgpt.connector.configs import AGETN_CONFIGS, CHAIN_CONFIGS, PHASE_CONFIGS
# Used further down to look up agent classes by name.
import importlib
# Shared setup used by the agent, chain and phase sections of this script.

# Show the directory that was appended to sys.path, to make path problems visible.
print(src_dir)

# Hand toLangchainTools every tool named in TOOL_SETS that is also a key of
# TOOL_DICT; names missing from TOOL_DICT are skipped.
tools = toLangchainTools(
    [TOOL_DICT[tool_name] for tool_name in TOOL_SETS if tool_name in TOOL_DICT]
)

# Run the raw role / chain / phase config tables through their loaders.
role_configs = load_role_configs(AGETN_CONFIGS)
chain_configs = load_chain_configs(CHAIN_CONFIGS)
phase_configs = load_phase_configs(PHASE_CONFIGS)

# Agent classes are looked up on this module by name (via getattr) later on.
agent_module = importlib.import_module("dev_opsgpt.connector.agents")
# Agent test.
# NOTE: `query` is rebound by each statement below, so only the last Message
# is still bound once this section has run; the earlier ones serve as
# ready-made alternatives.

# Time-series question addressed to the "tool_react" role. The prompt is split
# into adjacent string literals, which Python joins into a single string.
query = Message(
    role_name="tool_react",
    role_type="human",
    role_content=(
        "我有一份时序数据,[0.857, 2.345, 1.234, 4.567, 3.456, 9.876, 5.678, 7.890, 6.789, 8.901, 10.987, 12.345, 11.234, 14.567, 13.456, 19.876, 15.678, 17.890, 16.789, "
        "18.901, 20.987, 22.345, 21.234, 24.567, 23.456, 29.876, 25.678, 27.890, 26.789, 28.901, 30.987, 32.345, 31.234, 34.567, 33.456, 39.876, 35.678, 37.890, 36.789, 38.901, 40.987]"
        "我不知道这份数据是否存在问题,请帮我判断一下"
    ),
    tools=tools,
)

# Server-anomaly question, also for "tool_react".
query = Message(
    role_name="tool_react",
    role_type="human",
    role_content="帮我确认下127.0.0.1这个服务器的在10点是否存在异常请帮我判断一下",
    tools=tools,
)

# File-listing question for the "code_react" role.
query = Message(
    role_name="code_react",
    role_type="human",
    role_content="帮我确认当前目录下有哪些文件",
    tools=tools,
)

# Alternative prompt ("give me bubble sort code"):
# "给我一份冒泡排序的代码"

# Data-analysis request for the "intention_recognizer" role.
query = Message(
    role_name="intention_recognizer",
    role_type="human",
    role_content="对employee_data.csv进行数据分析",
    tools=tools,
)

# Disabled: build a single agent from its role config and run it on `query`.
# role = role_configs["general_planner"]
# agent_class = getattr(agent_module, role.role.agent_type)
# agent = agent_class(role.role,
#     task = None,
#     memory = None,
#     chat_turn=role.chat_turn,
#     do_search = role.do_search,
#     do_doc_retrieval = role.do_doc_retrieval,
#     do_tool_retrieval = role.do_tool_retrieval,)
# message = agent.run(query)
# print(message.role_content)
# Chain test.
# Alternative inputs kept for manual experiments:
# query = Message(role_name="deveploer", role_type="human", role_content="编写冒泡排序,并生成测例")
# query = Message(role_name="general_planner", role_type="human", role_content="对employee_data.csv进行数据分析")
# query = Message(role_name="tool_react", role_type="human", role_content="我有一份时序数据,[0.857, 2.345, 1.234, 4.567, 3.456, 9.876, 5.678, 7.890, 6.789, 8.901, 10.987, 12.345, 11.234, 14.567, 13.456, 19.876, 15.678, 17.890, 16.789, 18.901, 20.987, 22.345, 21.234, 24.567, 23.456, 29.876, 25.678, 27.890, 26.789, 28.901, 30.987, 32.345, 31.234, 34.567, 33.456, 39.876, 35.678, 37.890, 36.789, 38.901, 40.987]\我不知道这份数据是否存在问题,请帮我判断一下", tools=tools)
# role = role_configs[query.role_name]

# Role configs of the two agents handed to the chain.
role1 = role_configs["planner"]
role2 = role_configs["code_react"]

# One agent per role config: the class is looked up on `agent_module` by the
# name stored in the role, and the retrieval/search switches are copied from
# the config. Built in the order (role1, role2).
agents = [
    getattr(agent_module, role_cfg.role.agent_type)(
        role_cfg.role,
        task=None,
        memory=None,
        do_search=role_cfg.do_search,
        do_doc_retrieval=role_cfg.do_doc_retrieval,
        do_tool_retrieval=role_cfg.do_tool_retrieval,
    )
    for role_cfg in (role1, role2)
]

# `query` is rebound twice; the second (shorter) request is the one left bound.
query = Message(
    role_name="user",
    role_type="human",
    role_content="确认本地是否存在employee_data.csv并查看它有哪些列和数据类型分析这份数据的内容根据这个数据预测未来走势",
    tools=tools,
)
query = Message(
    role_name="user",
    role_type="human",
    role_content="确认本地是否存在employee_data.csv并查看它有哪些列和数据类型",
    tools=tools,
)

# Build the chain from the "dataAnalystChain" config with do_code_exec=False.
chain = BaseChain(chain_configs["dataAnalystChain"], agents, do_code_exec=False)

# Disabled: run one chain step and dump the resulting memories.
# message = chain.step(query)
# print(message.role_content)
# print("\n".join("\n".join([": ".join(j) for j in i]) for i in chain.get_agents_memory()))
# print("\n".join(": ".join(i) for i in chain.get_memory()))
# print( chain.get_agents_memory_str())
# print( chain.get_memory_str())
# Phase test.
# Pick the phase to build; the other candidates are kept commented out.
phase_name = "toolReactPhase"
# phase_name = "codeReactPhase"
# phase_name = "chatPhase"

# The phase receives the raw config tables (not the loaded *_configs objects
# built earlier in this script); only document retrieval is switched on.
phase = BasePhase(
    phase_name,
    task=None,
    phase_config=PHASE_CONFIGS,
    chain_config=CHAIN_CONFIGS,
    role_config=AGETN_CONFIGS,
    do_summary=False,
    do_code_retrieval=False,
    do_doc_retrieval=True,
    do_search=False,
)

# `query` is rebound twice. The first Message is built without `tools`; the
# second one, which stays bound, carries them.
query = Message(
    role_name="user",
    role_type="human",
    role_content="确认本地是否存在employee_data.csv并查看它有哪些列和数据类型并选择合适的数值列画出折线图",
)
query = Message(
    role_name="user",
    role_type="human",
    role_content="判断下127.0.0.1这个服务器的在10点的监控数据是否存在异常",
    tools=tools,
)

# Alternative prompt ("based on the other similar classes, develop a new
# ExceptionComponent2 that inherits AbstractTrafficComponent"):
# 根据其他类似的类,新开发个 ExceptionComponent2继承 AbstractTrafficComponent
# query = Message(role_name="human", role_type="human", role_content="langchain有什么用")

# Disabled: run one phase step and print the chain memories / output message.
# output_message = phase.step(query)
# print(phase.get_chains_memory(content_key="step_content"))
# print(phase.get_chains_memory_str(content_key="step_content"))
# print(output_message.to_tuple_message(return_all=True))
# Scratch area for manual tool / AgentChat checks. Everything except the
# import below is disabled.
# NOTE(review): DDGSTool is already imported at the top of this file; only
# CodeRetrieval is new here.
from dev_opsgpt.tools import DDGSTool, CodeRetrieval
# Disabled: direct calls to the search tool (prompt: "what is langchain") and
# to the code-retrieval tool with the current `query` text.
# print(DDGSTool.run("langchain是什么", 3))
# print(CodeRetrieval.run("dsadsadsa", query.role_content, code_limit=3, history_node_list=[]))
# Disabled: end-to-end AgentChat call. `value` is the keyword-argument set
# passed to agentChat.chat(); the query asks whether server 127.0.0.1 shows an
# anomaly at 10 o'clock.
# from dev_opsgpt.chat.agent_chat import AgentChat
# agentChat = AgentChat()
# value = {
# "query": "帮我确认下127.0.0.1这个服务器的在10点是否存在异常请帮我判断一下",
# "phase_name": "toolReactPhase",
# "chain_name": "",
# "history": [],
# "doc_engine_name": "DSADSAD",
# "search_engine_name": "duckduckgo",
# "top_k": 3,
# "score_threshold": 1.0,
# "stream": False,
# "local_doc_url": False,
# "do_search": False,
# "do_doc_retrieval": False,
# "do_code_retrieval": False,
# "do_tool_retrieval": False,
# "custom_phase_configs": {},
# "custom_chain_configs": {},
# "custom_role_configs": {},
# "choose_tools": list(TOOL_SETS)
# }
# answer = agentChat.chat(**value)
# print(answer)