import importlib
import json
import os
import sys

from loguru import logger

# Project root: used to locate local resources (e.g. the embedding model
# directory) and appended to sys.path so the top-level `configs` / `coagent`
# packages below can be imported when running this file as a script.
# NOTE(review): the original expression here was truncated; this assumes the
# script lives at <root>/<dir>/<subdir>/<script>.py — confirm.
src_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(src_dir)

from configs.model_config import KB_ROOT_PATH, JUPYTER_WORK_PATH, CB_ROOT_PATH
from configs.server_config import SANDBOX_SERVER
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
from coagent.connector.phase import BasePhase
from coagent.connector.agents import BaseAgent
from coagent.connector.schema import Message
from coagent.codechat.codebase_handler.codebase_handler import CodeBaseHandler
# NOTE(review): the module path of this import was lost; `coagent.tools` is
# assumed to be where these retrieval tools are exported — confirm.
from coagent.tools import CodeRetrievalSingle, RelatedVerticesRetrival, Vertex2Code
# Define a new agent class.
class CodeRetrieval(BaseAgent):
    """Agent that attaches retrieved code context to a message before predicting."""

    def start_action_step(self, message: Message) -> Message:
        """Do retrieval work before the agent predicts.

        Fills these keys of ``message.customed_kargs``:

        - ``"Code Snippet"``: the code CodeRetrievalSingle returns for the query.
        - ``"Current_Vertex"``: the vertex CodeRetrievalSingle returns for the query.
        - ``"Retrieval_Codes"``: non-empty code of each related vertex that is
          not in the same file as the current vertex.
        - ``"Relative_vertex"``: those related vertices, index-aligned with
          ``"Retrieval_Codes"``.

        Args:
            message: incoming message; ``code_engine_name``, ``origin_query``,
                ``local_graph_path`` and ``use_nh`` are read from it.

        Returns:
            The same ``message`` object, mutated in place.
        """
        # NOTE(review): the call targets in this method were lost from the
        # source and are restored as `<Tool>.run(code_base_name, ...)` — confirm
        # against the coagent tools API.

        # Fetch the code snippet and vertex for the question.
        snippet = CodeRetrievalSingle.run(
            message.code_engine_name, message.origin_query,
            llm_config=self.llm_config, embed_config=self.embed_config,
            search_type="tag", local_graph_path=message.local_graph_path,
            use_nh=message.use_nh,
        )
        current_vertex = snippet["vertex"]
        message.customed_kargs["Code Snippet"] = snippet["code"]
        message.customed_kargs["Current_Vertex"] = current_vertex

        # Fetch the vertices adjacent to the current one.
        related = RelatedVerticesRetrival.run(message.code_engine_name, current_vertex)

        # Code is retrieved at file granularity, so vertices from the current
        # vertex's own file are skipped. The file key does not change inside the
        # loop, so compute it once.
        if current_vertex.endswith(".java"):
            current_vertex_name = current_vertex.replace("#", "").replace(".java", "")
        else:
            current_vertex_name = current_vertex
        current_file = current_vertex_name.split("#")[0]

        # Collect the code of every adjacent vertex from a different file.
        relative_vertex = []
        retrieval_codes = []
        for vertex in related["vertices"]:
            if vertex.split("#")[0] == current_file:
                continue
            code = Vertex2Code.run(message.code_engine_name, vertex)["code"]
            if code:  # keep only vertices for which some code was found
                retrieval_codes.append(code)
                relative_vertex.append(vertex)

        message.customed_kargs["Retrieval_Codes"] = retrieval_codes
        message.customed_kargs["Relative_vertex"] = relative_vertex
        return message
# Register the custom agent (or prompt-manager) class on coagent's agents
# module, presumably so code that resolves agent classes by name on that module
# can find CodeRetrieval — confirm against the phase/agent config loader.
agent_module = importlib.import_module("coagent.connector.agents")
agent_module.CodeRetrieval = CodeRetrieval
# LLM settings. OPENAI_API_KEY and API_BASE_URL must both be present in the
# environment; os.environ[...] raises KeyError otherwise.
llm_config = LLMConfig(
    model_name="gpt-4",
    api_key=os.environ["OPENAI_API_KEY"],
    api_base_url=os.environ["API_BASE_URL"],
    temperature=0.3,
)

# Embedding settings, pointing at a model directory under src_dir.
embed_config = EmbedConfig(
    embed_engine="model",
    embed_model="text2vec-base-chinese",
    embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese"),
)
## Initialize the code base.
# NOTE(review): code_path is a machine-specific hard-coded path; point it at
# the source tree to be indexed.
codebase_name = 'client_local'
code_path = "D://chromeDownloads/devopschat-bot/client_v2/client"

# Delete the code base.
# NOTE(review): this section builds a handler but nothing here calls a delete
# operation on it — a call appears to be missing; confirm.
use_nh = False
cbh = CodeBaseHandler(
    codebase_name, code_path, crawl_type='dir', use_nh=use_nh,
    local_graph_path=CB_ROOT_PATH, llm_config=llm_config, embed_config=embed_config,
)

# Load the code base.
# NOTE(review): do_interpret is set but nothing in this file reads it, which
# suggests a load/import step using it is missing here; confirm.
use_nh = True
do_interpret = True
cbh = CodeBaseHandler(
    codebase_name, code_path, crawl_type='dir', use_nh=use_nh,
    local_graph_path=CB_ROOT_PATH, llm_config=llm_config, embed_config=embed_config,
)

# Class vertices of the code base; a test case is generated for each below.
vertexes = cbh.search_vertices(vertex_type="class")
# Log level: controls whether prompts and LLM predictions are printed.
# NOTE(review): "0" presumably turns that output off — confirm.
os.environ["log_verbose"] = "0"

phase_name = "code2Tests"
phase = BasePhase(
    phase_name, sandbox_server=SANDBOX_SERVER, jupyter_work_path=JUPYTER_WORK_PATH,
    embed_config=embed_config, llm_config=llm_config, kb_root_path=KB_ROOT_PATH,
)
# Round 1: run the phase once per class vertex and write each generated test
# to <CB_ROOT_PATH>/tests/<SaveFileName>.
tests_dir = os.path.join(CB_ROOT_PATH, "tests")
os.makedirs(tests_dir, exist_ok=True)

test_cases = []
for vertex in vertexes:
    query_content = f"{vertex}生成可执行的测例 "
    query = Message(
        role_name="human", role_type="user",
        role_content=query_content, input_query=query_content, origin_query=query_content,
        code_engine_name=codebase_name, score_threshold=1.0, top_k=3, cb_search_type="tag",
        use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
    )
    _, output_memory = phase.step(query, reinit_memory=True)
    values = output_memory.get_spec_parserd_output()
    # Merge all parsed outputs, keeping only the keys we need (later entries win).
    test_code = {
        k: v for item in values for k, v in item.items()
        if k in ("SaveFileName", "Test Code")
    }
    if "SaveFileName" not in test_code or "Test Code" not in test_code:
        # Don't abort the whole batch because one LLM answer lacked a field.
        logger.warning(f"skipping {vertex}: parsed output keys were {sorted(test_code)}")
        continue
    test_cases.append(test_code)

    # SaveFileName is model output: keep only its final path component so it
    # cannot escape tests_dir via "../" segments or an absolute path.
    file_name = os.path.basename(str(test_code["SaveFileName"]).strip())
    if not file_name:
        logger.warning(f"skipping {vertex}: unusable SaveFileName {test_code['SaveFileName']!r}")
        continue
    # Explicit UTF-8 so non-ASCII generated code doesn't depend on the locale.
    with open(os.path.join(tests_dir, file_name), "w", encoding="utf-8") as f:
        f.write(test_code["Test Code"])
