110 lines
3.4 KiB
Python
110 lines
3.4 KiB
Python
|
# encoding: utf-8
|
|||
|
'''
|
|||
|
@author: 温进
|
|||
|
@file: codechat_tools.py.py
|
|||
|
@time: 2023/12/14 上午10:24
|
|||
|
@desc:
|
|||
|
'''
|
|||
|
import json
|
|||
|
import os
|
|||
|
import re
|
|||
|
from pydantic import BaseModel, Field
|
|||
|
from typing import List, Dict
|
|||
|
import requests
|
|||
|
import numpy as np
|
|||
|
from loguru import logger
|
|||
|
|
|||
|
from configs.model_config import (
|
|||
|
CODE_SEARCH_TOP_K)
|
|||
|
from .base_tool import BaseToolModel
|
|||
|
|
|||
|
from dev_opsgpt.service.cb_api import search_code, search_related_vertices, search_code_by_vertex
|
|||
|
|
|||
|
|
|||
|
# 问题进来
|
|||
|
# 调用函数 0:输入问题,输出代码文件名 1 和 代码文件 1
|
|||
|
#
|
|||
|
# agent 1
|
|||
|
# 1. LLM:代码+问题 输出:是否能解决
|
|||
|
#
|
|||
|
# agent 2
|
|||
|
# 1. 调用函数 1 :输入:代码文件名 1 输出:代码文件名列表
|
|||
|
# 2. LLM:输入代码文件 1, 问题,代码文件名列表,输出:代码文件名 2
|
|||
|
# 3. 调用函数 2: 输入 :代码文件名 2 输出:代码文件 2
|
|||
|
|
|||
|
|
|||
|
class CodeRetrievalSingle(BaseToolModel):
|
|||
|
name = "CodeRetrievalOneCode"
|
|||
|
description = "输入用户的问题,输出一个代码文件名和代码文件"
|
|||
|
|
|||
|
class ToolInputArgs(BaseModel):
|
|||
|
query: str = Field(..., description="检索的问题")
|
|||
|
code_base_name: str = Field(..., description="代码库名称", examples=["samples"])
|
|||
|
|
|||
|
class ToolOutputArgs(BaseModel):
|
|||
|
"""Output for MetricsQuery."""
|
|||
|
code: str = Field(..., description="检索代码")
|
|||
|
vertex: str = Field(..., description="代码对应 id")
|
|||
|
|
|||
|
@classmethod
|
|||
|
def run(cls, code_base_name, query):
|
|||
|
"""excute your tool!"""
|
|||
|
|
|||
|
search_type = 'description'
|
|||
|
code_limit = 1
|
|||
|
|
|||
|
# default
|
|||
|
search_result = search_code(code_base_name, query, code_limit, search_type=search_type,
|
|||
|
history_node_list=[])
|
|||
|
|
|||
|
logger.debug(search_result)
|
|||
|
code = search_result['context']
|
|||
|
vertex = search_result['related_vertices'][0]
|
|||
|
# logger.debug(f"code: {code}, vertex: {vertex}")
|
|||
|
|
|||
|
res = {
|
|||
|
'code': code,
|
|||
|
'vertex': vertex
|
|||
|
}
|
|||
|
|
|||
|
return res
|
|||
|
|
|||
|
|
|||
|
class RelatedVerticesRetrival(BaseToolModel):
|
|||
|
name = "RelatedVerticesRetrival"
|
|||
|
description = "输入代码节点名,返回相连的节点名"
|
|||
|
|
|||
|
class ToolInputArgs(BaseModel):
|
|||
|
code_base_name: str = Field(..., description="代码库名称", examples=["samples"])
|
|||
|
vertex: str = Field(..., description="节点名", examples=["samples"])
|
|||
|
|
|||
|
class ToolOutputArgs(BaseModel):
|
|||
|
"""Output for MetricsQuery."""
|
|||
|
vertices: list = Field(..., description="相连节点名")
|
|||
|
|
|||
|
@classmethod
|
|||
|
def run(cls, code_base_name: str, vertex: str):
|
|||
|
"""execute your tool!"""
|
|||
|
related_vertices = search_related_vertices(cb_name=code_base_name, vertex=vertex)
|
|||
|
logger.debug(f"related_vertices: {related_vertices}")
|
|||
|
|
|||
|
return related_vertices
|
|||
|
|
|||
|
|
|||
|
class Vertex2Code(BaseToolModel):
|
|||
|
name = "Vertex2Code"
|
|||
|
description = "输入代码节点名,返回对应的代码文件"
|
|||
|
|
|||
|
class ToolInputArgs(BaseModel):
|
|||
|
code_base_name: str = Field(..., description="代码库名称", examples=["samples"])
|
|||
|
vertex: str = Field(..., description="节点名", examples=["samples"])
|
|||
|
|
|||
|
class ToolOutputArgs(BaseModel):
|
|||
|
"""Output for MetricsQuery."""
|
|||
|
code: str = Field(..., description="代码名")
|
|||
|
|
|||
|
@classmethod
|
|||
|
def run(cls, code_base_name: str, vertex: str):
|
|||
|
"""execute your tool!"""
|
|||
|
res = search_code_by_vertex(cb_name=code_base_name, vertex=vertex)
|
|||
|
return res
|