codefuse-chatbot/dev_opsgpt/tools/cb_query_tool.py

63 lines
1.9 KiB
Python

# encoding: utf-8
'''
@author: 温进
@file: cb_query_tool.py
@time: 2023/11/2 4:41 PM
@desc: Tool wrapper that retrieves related code from a local code knowledge base via search_code.
'''
import json
import os
import re
from pydantic import BaseModel, Field
from typing import List, Dict
import requests
import numpy as np
from loguru import logger
from configs.model_config import (
CODE_SEARCH_TOP_K)
from .base_tool import BaseToolModel
from dev_opsgpt.service.cb_api import search_code
class CodeRetrieval(BaseToolModel):
    """Tool that looks up related code in a local code knowledge base.

    The actual lookup is delegated to ``search_code``; this class only
    normalizes the arguments and reshapes the response.
    """

    name = "CodeRetrieval"
    description = "采用知识图谱从本地代码知识库获取相关代码"

    class ToolInputArgs(BaseModel):
        # query: keyword or question to search for.
        # code_base_name: name of the code knowledge base to search in.
        # code_limit: number of results to request (defaults to CODE_SEARCH_TOP_K).
        query: str = Field(..., description="检索的关键字或问题")
        code_base_name: str = Field(..., description="知识库名称", examples=["samples"])
        code_limit: int = Field(CODE_SEARCH_TOP_K, description="检索返回的数量")

    class ToolOutputArgs(BaseModel):
        """Output for CodeRetrieval."""

        # code: the retrieved code text.
        code: str = Field(..., description="检索代码")

    @classmethod
    def run(cls, code_base_name: str, query: str, code_limit: int = CODE_SEARCH_TOP_K,
            history_node_list=None, search_type: str = "tag") -> List[Dict]:
        """Search the code base and return the matching context.

        Args:
            code_base_name: Name of the code knowledge base passed to ``search_code``.
            query: Keyword or question to search for.
            code_limit: Number of results requested from ``search_code``.
            history_node_list: Node list forwarded to ``search_code``. ``None``
                (the default) is replaced by a fresh empty list on every call.
            search_type: Either an internal key (``tag``, ``description``,
                ``cypher``) or one of the Chinese display labels mapped below.
                Any unrecognized value falls back to ``tag``.

        Returns:
            A single-element list holding a dict with the keys ``index``
            (always 0), ``code`` (the ``context`` value of the search result)
            and ``related_nodes`` (its ``related_vertices`` value).

        Raises:
            KeyError: If the search result lacks ``context`` or
                ``related_vertices``.
        """
        # A mutable default ([]) would be one shared list for all calls, so any
        # mutation downstream would leak between queries; build a new one instead.
        if history_node_list is None:
            history_node_list = []

        # Accept both the display labels and the internal keys; unknown -> 'tag'.
        search_type = {
            '基于 cypher': 'cypher',
            '基于标签': 'tag',
            '基于描述': 'description',
            'tag': 'tag',
            'description': 'description',
            'cypher': 'cypher'
        }.get(search_type, 'tag')

        codes = search_code(code_base_name, query, code_limit, search_type=search_type, history_node_list=history_node_list)

        context = codes['context']
        related_nodes = codes['related_vertices']
        logger.debug(f"{code_base_name}, {query}, {code_limit}, {search_type}")
        logger.debug(f"context: {context}, related_nodes: {related_nodes}")

        return [{'index': 0, 'code': context, "related_nodes": related_nodes}]