63 lines
1.9 KiB
Python
63 lines
1.9 KiB
Python
# encoding: utf-8
|
|
'''
|
|
@author: 温进
|
|
@file: cb_query_tool.py
|
|
@time: 2023/11/2 下午4:41
|
|
@desc:
|
|
'''
|
|
import json
|
|
import os
|
|
import re
|
|
from pydantic import BaseModel, Field
|
|
from typing import List, Dict
|
|
import requests
|
|
import numpy as np
|
|
from loguru import logger
|
|
|
|
from configs.model_config import (
|
|
CODE_SEARCH_TOP_K)
|
|
from .base_tool import BaseToolModel
|
|
|
|
from dev_opsgpt.service.cb_api import search_code
|
|
|
|
|
|
class CodeRetrieval(BaseToolModel):
    """Tool that retrieves related code from a local code knowledge base.

    The actual lookup is delegated to ``search_code``; this class only
    normalizes the caller's arguments and repackages the search result.
    """

    name = "CodeRetrieval"
    description = "采用知识图谱从本地代码知识库获取相关代码"

    class ToolInputArgs(BaseModel):
        # Input schema for the tool:
        #   query          -- keyword or question to search for
        #   code_base_name -- name of the code knowledge base to search
        #   code_limit     -- number of results to return
        query: str = Field(..., description="检索的关键字或问题")
        code_base_name: str = Field(..., description="知识库名称", examples=["samples"])
        code_limit: int = Field(CODE_SEARCH_TOP_K, description="检索返回的数量")

    class ToolOutputArgs(BaseModel):
        """Output for CodeRetrieval."""

        # The retrieved code.
        code: str = Field(..., description="检索代码")

    @classmethod
    def run(cls, code_base_name: str, query: str, code_limit: int = CODE_SEARCH_TOP_K,
            history_node_list=None, search_type: str = "tag") -> List[Dict]:
        """Search the code base and return the matching code context.

        Args:
            code_base_name: Name of the code knowledge base to search.
            query: Keyword or question to search for.
            code_limit: Number of results requested from ``search_code``.
            history_node_list: Forwarded unchanged to ``search_code``. When
                omitted (or ``None``) a fresh empty list is used for this
                call, so no list object is shared between calls.
            search_type: Search strategy. Accepts the canonical values
                ``'tag'``, ``'description'`` and ``'cypher'`` as well as the
                display labels mapped below; any unrecognized value falls
                back to ``'tag'``.

        Returns:
            A single-element list holding a dict with the keys ``'index'``
            (always ``0``), ``'code'`` (the ``'context'`` entry of the search
            result) and ``'related_nodes'`` (its ``'related_vertices'``
            entry).

        Raises:
            KeyError: If the search result has no ``'context'`` or
                ``'related_vertices'`` entry (assuming ``search_code``
                returns a dict -- confirm against its implementation).
        """
        # A mutable default ([]) would be one list object shared by every
        # call; build a new one per call instead.
        if history_node_list is None:
            history_node_list = []

        # Map display labels and canonical names onto the canonical
        # search-type identifiers; unknown values default to 'tag'.
        search_type = {
            '基于 cypher': 'cypher',
            '基于标签': 'tag',
            '基于描述': 'description',
            'tag': 'tag',
            'description': 'description',
            'cypher': 'cypher'
        }.get(search_type, 'tag')

        codes = search_code(code_base_name, query, code_limit, search_type=search_type, history_node_list=history_node_list)
        context = codes['context']
        related_nodes = codes['related_vertices']
        logger.debug(f"{code_base_name}, {query}, {code_limit}, {search_type}")
        logger.debug(f"context: {context}, related_nodes: {related_nodes}")

        return [{'index': 0, 'code': context, "related_nodes": related_nodes}]
|
|
|
|
|