48 lines
1.5 KiB
Python
48 lines
1.5 KiB
Python
# encoding: utf-8
|
|
'''
|
|
@author: 温进
|
|
@file: cb_query_tool.py
|
|
@time: 2023/11/2 下午4:41
|
|
@desc: Tool for retrieving related code from a local code knowledge base.
|
|
'''
|
|
import json
|
|
import os
|
|
import re
|
|
from pydantic import BaseModel, Field
|
|
from typing import List, Dict
|
|
import requests
|
|
import numpy as np
|
|
from loguru import logger
|
|
|
|
from configs.model_config import (
|
|
CODE_SEARCH_TOP_K)
|
|
from .base_tool import BaseToolModel
|
|
|
|
from dev_opsgpt.service.cb_api import search_code
|
|
|
|
|
|
class CodeRetrieval(BaseToolModel):
    # Tool that retrieves code related to a query from a local code knowledge
    # base. The search itself is delegated to `search_code`; this class only
    # declares the tool's metadata/argument schemas and reshapes the result.

    name = "CodeRetrieval"
    description = "采用知识图谱从本地代码知识库获取相关代码"

    class ToolInputArgs(BaseModel):
        # Input arguments for CodeRetrieval.
        # (Kept as comments rather than a docstring: a docstring on a pydantic
        # model would become part of its generated schema.)

        # Keyword or question to search for.
        query: str = Field(..., description="检索的关键字或问题")
        # Name of the code knowledge base to search.
        code_base_name: str = Field(..., description="知识库名称", examples=["samples"])
        # Number of results to request from the search.
        code_limit: int = Field(CODE_SEARCH_TOP_K, description="检索返回的数量")

    class ToolOutputArgs(BaseModel):
        """Output for CodeRetrieval."""

        # Retrieved code.
        code: str = Field(..., description="检索代码")

    @classmethod
    def run(cls, code_base_name, query, code_limit=CODE_SEARCH_TOP_K, history_node_list=None):
        """Search a code knowledge base and return the matching code snippets.

        Args:
            code_base_name: Name of the code knowledge base to search.
            query: Keyword or question to search for.
            code_limit: Number of results requested from ``search_code``.
            history_node_list: Forwarded to ``search_code`` as
                ``history_node_list``. Defaults to a fresh empty list per call.

        Returns:
            A list of dicts, one per item in ``codes['related_code']`` (where
            ``codes`` is the ``search_code`` result), each with keys
            ``'index'`` (position in that list), ``'code'`` (the item) and
            ``'related_nodes'`` (the same ``codes['related_node']`` object,
            repeated for every entry).
        """
        # A `[]` default would be created once and shared by every call; since
        # the list is handed to `search_code`, any mutation there would leak
        # into later calls. Build a fresh list per call instead.
        if history_node_list is None:
            history_node_list = []

        codes = search_code(code_base_name, query, code_limit, history_node_list=history_node_list)
        related_code = codes['related_code']
        related_nodes = codes['related_node']

        return [
            {'index': idx, 'code': code, "related_nodes": related_nodes}
            for idx, code in enumerate(related_code)
        ]
|