61 lines
2.3 KiB
Python
61 lines
2.3 KiB
Python
# encoding: utf-8
|
|
'''
|
|
@author: 温进
|
|
@file: cb_query_tool.py
|
|
@time: 2023/11/2 下午4:41
|
|
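@desc: CodeRetrieval tool - retrieves relevant code from a local code knowledge base via knowledge-graph search.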
|
|
'''
|
|
from pydantic import BaseModel, Field
|
|
from loguru import logger
|
|
|
|
from coagent.llm_models import LLMConfig, EmbedConfig
|
|
|
|
from .base_tool import BaseToolModel
|
|
|
|
from coagent.service.cb_api import search_code
|
|
|
|
|
|
# Maps user-facing search-mode labels (Chinese UI labels plus their English
# equivalents) to the ``search_type`` values forwarded to ``search_code``.
# Labels not listed here fall back to "tag" in ``CodeRetrieval.run``.
_SEARCH_TYPE_ALIASES = {
    '基于 cypher': 'cypher',    # "cypher-based"
    '基于标签': 'tag',          # "tag-based"
    '基于描述': 'description',  # "description-based"
    'tag': 'tag',
    'description': 'description',
    'cypher': 'cypher',
}


class CodeRetrieval(BaseToolModel):
    """Tool that retrieves relevant code from a local code knowledge base.

    The search itself is delegated to ``coagent.service.cb_api.search_code``;
    this tool normalises the search mode, forwards the LLM/embedding settings
    and wraps the result in a single-element list of dicts.
    """

    name = "CodeRetrieval"
    # "Retrieve relevant code from the local code knowledge base using a knowledge graph"
    description = "采用知识图谱从本地代码知识库获取相关代码"

    class ToolInputArgs(BaseModel):
        # Input schema for CodeRetrieval.
        # query: the search keyword or question.
        query: str = Field(..., description="检索的关键字或问题")
        # code_base_name: name of the code knowledge base to search.
        code_base_name: str = Field(..., description="知识库名称", examples=["samples"])
        # code_limit: number of results to return.
        code_limit: int = Field(1, description="检索返回的数量")

    class ToolOutputArgs(BaseModel):
        """Output for CodeRetrieval."""

        # code: the retrieved code.
        code: str = Field(..., description="检索代码")

    @classmethod
    def run(cls, code_base_name: str, query: str, code_limit: int = 1, history_node_list=None,
            search_type: str = "tag", llm_config: LLMConfig = None, embed_config: EmbedConfig = None) -> list:
        """Search a code knowledge base and return the retrieved context.

        Args:
            code_base_name: Name of the code knowledge base to search.
            query: Search keyword or natural-language question.
            code_limit: Number of results requested from ``search_code``.
            history_node_list: Forwarded unchanged to ``search_code``. When
                omitted (or ``None``), a fresh empty list is used for this call.
            search_type: Search-mode label; any key of ``_SEARCH_TYPE_ALIASES``.
                Unrecognised values fall back to ``"tag"``.
            llm_config: LLM settings forwarded to ``search_code``
                (``model_name``, ``temperature``, ``api_base_url``, ``api_key``).
                Required.
            embed_config: Embedding settings forwarded to ``search_code``
                (``embed_engine``, ``embed_model``, ``embed_model_path``,
                ``model_device``). Required.

        Returns:
            A one-element list ``[{'index': 0, 'code': ..., 'related_nodes': ...}]``
            built from the ``'context'`` and ``'related_vertices'`` entries of the
            ``search_code`` result.

        Raises:
            ValueError: If ``llm_config`` or ``embed_config`` is not provided.
        """
        # A ``[]`` default would be one list shared by every call (and handed to
        # search_code each time); build a fresh one per call instead.
        if history_node_list is None:
            history_node_list = []
        # Both configs are dereferenced below; fail with a clear message instead
        # of an AttributeError on ``None``.
        if llm_config is None or embed_config is None:
            raise ValueError("CodeRetrieval.run requires both llm_config and embed_config")

        search_type = _SEARCH_TYPE_ALIASES.get(search_type, 'tag')

        codes = search_code(
            code_base_name, query, code_limit,
            search_type=search_type,
            history_node_list=history_node_list,
            embed_engine=embed_config.embed_engine,
            embed_model=embed_config.embed_model,
            embed_model_path=embed_config.embed_model_path,
            model_device=embed_config.model_device,
            model_name=llm_config.model_name,
            temperature=llm_config.temperature,
            api_base_url=llm_config.api_base_url,
            api_key=llm_config.api_key,
        )
        context = codes['context']
        related_nodes = codes['related_vertices']

        # Values are passed as loguru args so the (possibly large) context is only
        # formatted into the message when a DEBUG record is actually emitted.
        logger.debug("{}, {}, {}, {}", code_base_name, query, code_limit, search_type)
        logger.debug("context: {}, related_nodes: {}", context, related_nodes)

        return [{'index': 0, 'code': context, "related_nodes": related_nodes}]
|
|
|
|
|