# encoding: utf-8
'''
@author: 温进
@file: cb_query_tool.py
@time: 2023/11/2 16:41
@desc: Tool that retrieves related code from a local code knowledge base.
'''
import json
import os
import re
from pydantic import BaseModel, Field
from typing import List, Dict, Optional
import requests
import numpy as np
from loguru import logger

from configs.model_config import (
    CODE_SEARCH_TOP_K)
from .base_tool import BaseToolModel

from dev_opsgpt.service.cb_api import search_code


class CodeRetrieval(BaseToolModel):
    """Tool that searches a local code knowledge base through ``search_code``."""

    name = "CodeRetrieval"
    # "Retrieve related code from the local code knowledge base using a knowledge graph."
    description = "采用知识图谱从本地代码知识库获取相关代码"

    class ToolInputArgs(BaseModel):
        # query: "keywords or question to search for"
        query: str = Field(..., description="检索的关键字或问题")
        # code_base_name: "knowledge base name"
        code_base_name: str = Field(..., description="知识库名称", examples=["samples"])
        # code_limit: "number of results to return"
        code_limit: int = Field(CODE_SEARCH_TOP_K, description="检索返回的数量")

    class ToolOutputArgs(BaseModel):
        """Output for CodeRetrieval."""
        # code: "retrieved code"
        code: str = Field(..., description="检索代码")

    @classmethod
    def run(cls, code_base_name, query, code_limit=CODE_SEARCH_TOP_K,
            history_node_list: Optional[List] = None) -> List[Dict]:
        """Execute the tool: search ``code_base_name`` for code related to ``query``.

        Args:
            code_base_name: name of the code knowledge base to search.
            query: keywords or question to search for.
            code_limit: number of results requested; forwarded to ``search_code``.
            history_node_list: forwarded to ``search_code`` unchanged. ``None``
                (the default) means an empty list. Presumably the nodes seen
                in earlier searches -- TODO confirm against ``search_code``.

        Returns:
            One dict per item of ``search_code(...)['related_code']``, with keys
            ``'index'`` (position of the item), ``'code'`` (the item itself) and
            ``'related_nodes'`` (the whole ``'related_node'`` value, the same
            object shared by every returned dict).

        Raises:
            KeyError: if the ``search_code`` result lacks ``'related_code'`` or
                ``'related_node'``.
        """
        # A fresh list per call: a mutable default ([]) would be shared across
        # calls and could accumulate state if search_code mutates it.
        if history_node_list is None:
            history_node_list = []

        codes = search_code(code_base_name, query, code_limit, history_node_list=history_node_list)
        related_code = codes['related_code']
        related_nodes = codes['related_node']

        # NOTE(review): this list-of-dicts shape differs from ToolOutputArgs
        # (a single ``code`` string) -- confirm which one callers rely on.
        return [
            {'index': idx, 'code': code, "related_nodes": related_nodes}
            for idx, code in enumerate(related_code)
        ]