codefuse-chatbot/dev_opsgpt/tools/codechat_tools.py

110 lines
3.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# encoding: utf-8
'''
@author: 温进
@file: codechat_tools.py
@time: 2023/12/14 上午10:24
@desc:
'''
import json
import os
import re
from pydantic import BaseModel, Field
from typing import List, Dict
import requests
import numpy as np
from loguru import logger
from configs.model_config import (
CODE_SEARCH_TOP_K)
from .base_tool import BaseToolModel
from dev_opsgpt.service.cb_api import search_code, search_related_vertices, search_code_by_vertex
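# Workflow when a question comes in:
# call function 0 - input: the question; output: code file name 1 and code file 1
#
# agent 1
# 1. LLM input: code + question; output: whether the question can be solved
#
# agent 2
# 1. call function 1 - input: code file name 1; output: list of code file names
# 2. LLM input: code file 1, the question, the list of code file names; output: code file name 2
# 3. call function 2 - input: code file name 2; output: code file 2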
class CodeRetrievalSingle(BaseToolModel):
    # Tool that takes a user question and returns one retrieved code snippet
    # together with the vertex it came from (via ``search_code``).
    name = "CodeRetrievalOneCode"
    description = "输入用户的问题,输出一个代码文件名和代码文件"

    class ToolInputArgs(BaseModel):
        # Arguments accepted by ``run``.
        query: str = Field(..., description="检索的问题")
        code_base_name: str = Field(..., description="代码库名称", examples=["samples"])

    class ToolOutputArgs(BaseModel):
        """Output for MetricsQuery."""
        # NOTE(review): the docstring above looks copy-pasted from another
        # tool; it is left unchanged because pydantic may surface it in the
        # generated schema - confirm before rewording.
        code: str = Field(..., description="检索代码")
        vertex: str = Field(..., description="代码对应 id")

    @classmethod
    def run(cls, code_base_name: str, query: str) -> dict:
        """Execute the tool: retrieve a single code snippet for ``query``.

        Args:
            code_base_name: Name of the code base to search.
            query: The question used for retrieval.

        Returns:
            A dict with two keys: ``'code'`` holds the ``'context'`` field of
            the search result and ``'vertex'`` holds the first entry of its
            ``'related_vertices'`` field, or ``''`` when that field is empty.
        """
        search_type = 'description'
        code_limit = 1  # this tool only ever asks for one result

        search_result = search_code(
            code_base_name,
            query,
            code_limit,
            search_type=search_type,
            history_node_list=[],
        )
        logger.debug(search_result)

        code = search_result['context']

        # A search that matches nothing can yield no related vertices; fall
        # back to an empty id instead of failing on ``[0]``.
        related_vertices = search_result['related_vertices']
        if related_vertices:
            vertex = related_vertices[0]
        else:
            logger.warning(f"no related vertices found for query: {query}")
            vertex = ''

        return {
            'code': code,
            'vertex': vertex,
        }
class RelatedVerticesRetrival(BaseToolModel):
    # Tool that takes a code vertex name and returns whatever
    # ``search_related_vertices`` reports as connected to it.
    name = "RelatedVerticesRetrival"
    description = "输入代码节点名,返回相连的节点名"

    class ToolInputArgs(BaseModel):
        # Arguments accepted by ``run``.
        code_base_name: str = Field(
            ...,
            description="代码库名称",
            examples=["samples"],
        )
        vertex: str = Field(
            ...,
            description="节点名",
            examples=["samples"],
        )

    class ToolOutputArgs(BaseModel):
        """Output for MetricsQuery."""
        vertices: list = Field(..., description="相连节点名")

    @classmethod
    def run(cls, code_base_name: str, vertex: str):
        """Look up the vertices related to ``vertex`` in a code base.

        Args:
            code_base_name: Name of the code base, forwarded as ``cb_name``.
            vertex: Name of the vertex whose neighbours are requested.

        Returns:
            The value returned by ``search_related_vertices``, unmodified.
        """
        neighbours = search_related_vertices(vertex=vertex, cb_name=code_base_name)
        logger.debug(f"related_vertices: {neighbours}")
        return neighbours
class Vertex2Code(BaseToolModel):
    # Tool that takes a code vertex name and returns the result of
    # ``search_code_by_vertex`` for it.
    name = "Vertex2Code"
    description = "输入代码节点名,返回对应的代码文件"

    class ToolInputArgs(BaseModel):
        # Arguments accepted by ``run``.
        code_base_name: str = Field(
            ...,
            description="代码库名称",
            examples=["samples"],
        )
        vertex: str = Field(
            ...,
            description="节点名",
            examples=["samples"],
        )

    class ToolOutputArgs(BaseModel):
        """Output for MetricsQuery."""
        code: str = Field(..., description="代码名")

    @classmethod
    def run(cls, code_base_name: str, vertex: str):
        """Fetch the code associated with ``vertex`` in a code base.

        Args:
            code_base_name: Name of the code base, forwarded as ``cb_name``.
            vertex: Name of the vertex to resolve.

        Returns:
            The value returned by ``search_code_by_vertex``, unmodified.
        """
        return search_code_by_vertex(vertex=vertex, cb_name=code_base_name)