codefuse-chatbot/coagent/tools/codechat_tools.py

116 lines
4.3 KiB
Python
Raw Normal View History

# encoding: utf-8
'''
@author: 温进
@file: codechat_tools.py
@time: 2023/12/14 上午10:24
@desc:
'''
import os
from pydantic import BaseModel, Field
from loguru import logger
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
from .base_tool import BaseToolModel
from coagent.service.cb_api import search_code, search_related_vertices, search_code_by_vertex
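# Workflow: a question comes in.
# Call function 0: input the question; output code file name 1 and code file 1.
#
# agent 1
# 1. LLM input: the code + the question; output: whether the question can be answered.
#
# agent 2
# 1. Call function 1: input code file name 1; output a list of code file names.
# 2. LLM input: code file 1, the question and the list of code file names; output: code file name 2.
# 3. Call function 2: input code file name 2; output code file 2.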
class CodeRetrievalSingle(BaseToolModel):
    """Tool that retrieves a single code snippet and its vertex id for a question."""

    name = "CodeRetrievalOneCode"
    description = "输入用户的问题,输出一个代码文件名和代码文件"

    # Input schema: the question to search for and the code base to search in.
    class ToolInputArgs(BaseModel):
        query: str = Field(..., description="检索的问题")
        code_base_name: str = Field(..., description="代码库名称", examples=["samples"])

    # Output schema: the retrieved code and the id of the vertex it belongs to.
    # NOTE(review): the docstring below is kept verbatim because pydantic exposes
    # model docstrings as the schema description; the "MetricsQuery" wording
    # looks copied from another tool - confirm before rewording.
    class ToolOutputArgs(BaseModel):
        """Output for MetricsQuery."""
        code: str = Field(..., description="检索代码")
        vertex: str = Field(..., description="代码对应 id")

    @classmethod
    def run(cls, code_base_name, query, embed_config: EmbedConfig, llm_config: LLMConfig, search_type="description", **kargs):
        """Execute the tool: search ``code_base_name`` for the best match to ``query``.

        Args:
            code_base_name: Name of the code base passed to ``search_code``.
            query: The question to search with.
            embed_config: Embedding settings; ``embed_engine``, ``embed_model``,
                ``embed_model_path`` and ``model_device`` are forwarded to
                ``search_code`` together with the object itself.
            llm_config: LLM settings; ``model_name``, ``temperature``,
                ``api_base_url`` and ``api_key`` are forwarded to ``search_code``.
            search_type: Forwarded unchanged to ``search_code``.
            **kargs: Optional ``use_nh`` (default ``True``) and
                ``local_graph_path`` (default ``""``), both forwarded to
                ``search_code``; any other key is ignored.

        Returns:
            A dict with ``'code'`` (``search_result['context']``) and
            ``'vertex'`` (the first entry of ``search_result['related_vertices']``).

        Raises:
            KeyError, IndexError: Propagated when the search result lacks one of
                those keys or ``related_vertices`` is empty (presumably the
                result is a dict holding a list - confirm against ``search_code``).
        """
        # Only one result is requested; this tool returns a single snippet.
        code_limit = 1

        search_result = search_code(
            code_base_name, query, code_limit, search_type=search_type,
            history_node_list=[],
            embed_engine=embed_config.embed_engine,
            embed_model=embed_config.embed_model,
            embed_model_path=embed_config.embed_model_path,
            model_device=embed_config.model_device,
            model_name=llm_config.model_name,
            temperature=llm_config.temperature,
            api_base_url=llm_config.api_base_url,
            api_key=llm_config.api_key,
            embed_config=embed_config,
            use_nh=kargs.get("use_nh", True),
            local_graph_path=kargs.get("local_graph_path", ""),
        )

        # Compare the verbosity level numerically: as strings, "10" >= "3" is
        # False, which would silently disable logging for levels of 10 or more.
        # A value that is not an integer is treated as level 0.
        try:
            log_verbose = int(os.environ.get("log_verbose", "0"))
        except ValueError:
            log_verbose = 0
        if log_verbose >= 3:
            logger.debug(search_result)

        return {
            'code': search_result['context'],
            'vertex': search_result['related_vertices'][0],
        }
class RelatedVerticesRetrival(BaseToolModel):
    """Tool that looks up the vertices connected to a given code vertex."""

    name = "RelatedVerticesRetrival"
    description = "输入代码节点名,返回相连的节点名"

    # Input schema: which code base to query and the vertex to start from.
    class ToolInputArgs(BaseModel):
        code_base_name: str = Field(
            ...,
            description="代码库名称",
            examples=["samples"],
        )
        vertex: str = Field(
            ...,
            description="节点名",
            examples=["samples"],
        )

    # Output schema: the names of the connected vertices.
    # NOTE(review): docstring kept verbatim because pydantic exposes model
    # docstrings as the schema description; "MetricsQuery" looks copied from
    # another tool - confirm before rewording.
    class ToolOutputArgs(BaseModel):
        """Output for MetricsQuery."""
        vertices: list = Field(..., description="相连节点名")

    @classmethod
    def run(cls, code_base_name: str, vertex: str, **kargs):
        """Return whatever ``search_related_vertices`` yields for ``vertex``.

        Args:
            code_base_name: Code base name, passed on as ``cb_name``.
            vertex: Vertex name, passed on unchanged.
            **kargs: Accepted for interface compatibility; not used.
        """
        return search_related_vertices(cb_name=code_base_name, vertex=vertex)
class Vertex2Code(BaseToolModel):
    """Tool that fetches the code associated with a code vertex name."""

    name = "Vertex2Code"
    description = "输入代码节点名,返回对应的代码文件"

    # Input schema: which code base to query and the vertex to resolve.
    class ToolInputArgs(BaseModel):
        code_base_name: str = Field(
            ...,
            description="代码库名称",
            examples=["samples"],
        )
        vertex: str = Field(
            ...,
            description="节点名",
            examples=["samples"],
        )

    # Output schema: the code found for the vertex.
    # NOTE(review): docstring kept verbatim because pydantic exposes model
    # docstrings as the schema description; "MetricsQuery" looks copied from
    # another tool - confirm before rewording.
    class ToolOutputArgs(BaseModel):
        """Output for MetricsQuery."""
        code: str = Field(..., description="代码名")

    @classmethod
    def run(cls, code_base_name: str, vertex: str, **kargs):
        """Normalise ``vertex`` and return ``search_code_by_vertex``'s result for it.

        Only the text before the first comma is kept, and surrounding spaces
        and double quotes are stripped before the lookup.

        Args:
            code_base_name: Code base name, passed on as ``cb_name``.
            vertex: Vertex name, possibly quoted or a comma-separated list.
            **kargs: Accepted for interface compatibility; not used.
        """
        # partition() yields the whole string when no comma is present, so a
        # single expression covers both the list and the single-name case.
        first_entry, _, _ = vertex.partition(',')
        cleaned_vertex = first_entry.strip(' "')
        return search_code_by_vertex(cb_name=code_base_name, vertex=cleaned_vertex)