codefuse-chatbot/coagent/tools/docs_retrieval.py

40 lines
2.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import Optional

from pydantic import BaseModel, Field
from loguru import logger

from coagent.llm_models.llm_config import EmbedConfig
from .base_tool import BaseToolModel
from coagent.service.kb_api import search_docs
class DocRetrieval(BaseToolModel):
    """Tool that searches a local knowledge base using vector retrieval."""

    name = "DocRetrieval"
    description = "采用向量化对本地知识库进行检索"

    # Input schema for DocRetrieval (no docstring on purpose, so the pydantic
    # schema description exposed for this model is unchanged).
    class ToolInputArgs(BaseModel):
        # Keywords or question to search for.
        query: str = Field(..., description="检索的关键字或问题")
        # Name of the knowledge base to search.
        knowledge_base_name: str = Field(..., description="知识库名称", examples=["samples"])
        # Number of documents to return.
        search_top: int = Field(5, description="检索返回的数量")
        # Relevance threshold constrained to [0, 1]. Per its description a smaller
        # score means higher relevance, 1 effectively disables filtering, and
        # about 0.5 is suggested.
        score_threshold: float = Field(1.0, description="知识库匹配相关度阈值取值范围在0-1之间SCORE越小相关度越高取到1相当于不筛选建议设置在0.5左右", ge=0, le=1)

    class ToolOutputArgs(BaseModel):
        """Output for DocRetrieval."""
        # Title of the retrieved item; run() fills it from doc.metadata["source"].
        title: str = Field(..., description="检索网页标题")
        # Retrieved text content (doc.page_content in run()).
        snippet: str = Field(..., description="检索内容的判断")
        # Location of the retrieved item; run() fills it from doc.metadata["source"].
        link: str = Field(..., description="检索网页地址")

    @classmethod
    def run(
        cls,
        query,
        knowledge_base_name,
        search_top=5,
        score_threshold=1.0,
        embed_config: Optional[EmbedConfig] = None,
        kb_root_path: str = "",
    ):
        """Execute the tool: search ``knowledge_base_name`` for ``query``.

        Args:
            query: Keywords or question to search for.
            knowledge_base_name: Name of the knowledge base to search.
            search_top: Number of results requested from ``search_docs``.
            score_threshold: Relevance threshold forwarded to ``search_docs``.
            embed_config: Embedding settings (engine, model, model path, device)
                forwarded to ``search_docs``. A fresh ``EmbedConfig()`` is used
                when omitted or ``None``.
            kb_root_path: Knowledge-base root path forwarded to ``search_docs``.

        Returns:
            A list of dicts with keys ``index``, ``snippet``, ``title`` and
            ``link``, one per returned document in order. ``title`` and
            ``link`` are both taken from ``doc.metadata.get("source")``. If
            ``search_docs`` raises, the exception is logged and an empty list
            is returned.
        """
        # Build the default per call rather than sharing a single instance
        # created when the class body was executed.
        if embed_config is None:
            embed_config = EmbedConfig()
        try:
            docs = search_docs(
                query, knowledge_base_name, search_top, score_threshold,
                kb_root_path=kb_root_path,
                embed_engine=embed_config.embed_engine,
                embed_model=embed_config.embed_model,
                embed_model_path=embed_config.embed_model_path,
                model_device=embed_config.model_device,
            )
        except Exception as e:
            # Best-effort tool: log the failure and report no results. `docs`
            # is unbound here, so we must not fall through to the loop below.
            logger.exception(e)
            return []
        return [
            {
                "index": idx,
                "snippet": doc.page_content,
                "title": doc.metadata.get("source"),
                "link": doc.metadata.get("source"),
            }
            for idx, doc in enumerate(docs)
        ]