116 lines
4.3 KiB
Python
116 lines
4.3 KiB
Python
# encoding: utf-8
|
||
'''
|
||
@author: 温进
|
||
@file: codechat_tools.py
|
||
@time: 2023/12/14 上午10:24
|
||
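@desc: Code-retrieval tools for the code-chat agents: retrieve one code snippet for a question, list vertices related to a code vertex, and fetch the code for a vertex.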
|
||
'''
|
||
import os
|
||
from pydantic import BaseModel, Field
|
||
from loguru import logger
|
||
|
||
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
|
||
from .base_tool import BaseToolModel
|
||
|
||
from coagent.service.cb_api import search_code, search_related_vertices, search_code_by_vertex
|
||
|
||
|
||
# 问题进来
|
||
# 调用函数 0:输入问题,输出代码文件名 1 和 代码文件 1
|
||
#
|
||
# agent 1
|
||
# 1. LLM:代码+问题 输出:是否能解决
|
||
#
|
||
# agent 2
|
||
# 1. 调用函数 1 :输入:代码文件名 1 输出:代码文件名列表
|
||
# 2. LLM:输入代码文件 1, 问题,代码文件名列表,输出:代码文件名 2
|
||
# 3. 调用函数 2: 输入 :代码文件名 2 输出:代码文件 2
|
||
|
||
|
||
class CodeRetrievalSingle(BaseToolModel):
    """Tool: answer a user question with a single retrieved code snippet and its vertex id."""

    name = "CodeRetrievalOneCode"
    description = "输入用户的问题,输出一个代码文件名和代码文件"

    class ToolInputArgs(BaseModel):
        query: str = Field(..., description="检索的问题")
        code_base_name: str = Field(..., description="代码库名称", examples=["samples"])

    class ToolOutputArgs(BaseModel):
        """Output for CodeRetrievalSingle: the retrieved code and its vertex id."""
        code: str = Field(..., description="检索代码")
        vertex: str = Field(..., description="代码对应 id")

    @classmethod
    def run(cls, code_base_name, query, embed_config: EmbedConfig, llm_config: LLMConfig, search_type="description", **kargs):
        """Execute the tool: search ``code_base_name`` for the one snippet best matching ``query``.

        Args:
            code_base_name: name of the code base to search.
            query: the user's question.
            embed_config: embedding settings; its fields are forwarded to ``search_code``.
            llm_config: LLM settings; its fields are forwarded to ``search_code``.
            search_type: search mode forwarded to ``search_code`` (default "description").
            **kargs: optional ``use_nh`` (default True) and ``local_graph_path``
                (default ""), both forwarded to ``search_code``.

        Returns:
            dict with ``code`` (the search result's ``context``) and ``vertex``
            (the first related vertex, or "" when the search returned none).
        """
        code_limit = 1  # only the single top hit is wanted

        search_result = search_code(
            code_base_name, query, code_limit, search_type=search_type,
            history_node_list=[],
            embed_engine=embed_config.embed_engine, embed_model=embed_config.embed_model,
            embed_model_path=embed_config.embed_model_path,
            model_device=embed_config.model_device, model_name=llm_config.model_name,
            temperature=llm_config.temperature,
            api_base_url=llm_config.api_base_url, api_key=llm_config.api_key,
            embed_config=embed_config, use_nh=kargs.get("use_nh", True),
            local_graph_path=kargs.get("local_graph_path", ""),
        )

        # log_verbose arrives as an env string; compare numerically, because a
        # string comparison would treat "10" as lower than "3".
        try:
            log_verbose = int(os.environ.get("log_verbose", "0"))
        except ValueError:
            log_verbose = 0
        if log_verbose >= 3:
            logger.debug(search_result)

        code = search_result['context']

        # An empty vertex list previously crashed with IndexError; degrade to "".
        related_vertices = search_result['related_vertices']
        if related_vertices:
            vertex = related_vertices[0]
        else:
            logger.warning(f"no related vertex found for query in code base {code_base_name}")
            vertex = ""

        return {
            'code': code,
            'vertex': vertex,
        }
|
||
|
||
|
||
class RelatedVerticesRetrival(BaseToolModel):
    """Tool: given a code vertex name, return the names of the vertices connected to it."""

    # NOTE: "Retrival" spelling is part of the public tool name; kept for compatibility.
    name = "RelatedVerticesRetrival"
    description = "输入代码节点名,返回相连的节点名"

    class ToolInputArgs(BaseModel):
        code_base_name: str = Field(..., description="代码库名称", examples=["samples"])
        vertex: str = Field(..., description="节点名", examples=["samples"])

    class ToolOutputArgs(BaseModel):
        """Output for RelatedVerticesRetrival: names of vertices connected to the input vertex."""
        vertices: list = Field(..., description="相连节点名")

    @classmethod
    def run(cls, code_base_name: str, vertex: str, **kargs):
        """Execute the tool: look up vertices related to ``vertex`` in ``code_base_name``.

        Args:
            code_base_name: name of the code base to query.
            vertex: the vertex (node) name to start from.
            **kargs: accepted but ignored.

        Returns:
            Whatever ``search_related_vertices`` returns for this vertex, passed through unchanged.
        """
        return search_related_vertices(cb_name=code_base_name, vertex=vertex)
|
||
|
||
|
||
class Vertex2Code(BaseToolModel):
    """Tool: given a code vertex name, return the corresponding code."""

    name = "Vertex2Code"
    description = "输入代码节点名,返回对应的代码文件"

    class ToolInputArgs(BaseModel):
        code_base_name: str = Field(..., description="代码库名称", examples=["samples"])
        vertex: str = Field(..., description="节点名", examples=["samples"])

    class ToolOutputArgs(BaseModel):
        """Output for Vertex2Code: the code associated with the input vertex."""
        code: str = Field(..., description="代码名")

    @classmethod
    def run(cls, code_base_name: str, vertex: str, **kargs):
        """Execute the tool: fetch the code for ``vertex`` from ``code_base_name``.

        Args:
            code_base_name: name of the code base to query.
            vertex: vertex name; if it contains several comma-separated names,
                only the first is used. Surrounding spaces and double quotes are stripped.
            **kargs: accepted but ignored.

        Returns:
            Whatever ``search_code_by_vertex`` returns for the normalized vertex name.
        """
        # split(',', 1)[0] yields the whole string when there is no comma, so a
        # single expression covers both the "list of names" and "single name" inputs.
        vertex = vertex.split(',', 1)[0].strip(' "')
        return search_code_by_vertex(cb_name=code_base_name, vertex=vertex)