import json
import os
import re
from pydantic import BaseModel, Field
from typing import List, Dict
import requests
import numpy as np
from loguru import logger

from configs.model_config import (
    VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
from .base_tool import BaseToolModel

from dev_opsgpt.service.kb_api import search_docs


class DocRetrieval(BaseToolModel):
    name = "DocRetrieval"
    description = "Retrieve from the local knowledge base using vector search"

    class ToolInputArgs(BaseModel):
        query: str = Field(..., description="Keyword or question to search for")
        knowledge_base_name: str = Field(..., description="Name of the knowledge base", examples=["samples"])
        search_top: int = Field(VECTOR_SEARCH_TOP_K, description="Number of results to return")
        score_threshold: float = Field(
            SCORE_THRESHOLD,
            description=(
                "Relevance threshold for knowledge-base matching, in the range 0-1; "
                "a lower score means higher relevance, and 1 is equivalent to no filtering. "
                "A value around 0.5 is recommended."
            ),
            ge=0, le=1)

    class ToolOutputArgs(BaseModel):
        """Output for DocRetrieval."""
        title: str = Field(..., description="Title of the retrieved document")
        snippet: str = Field(..., description="Snippet of the retrieved content")
        link: str = Field(..., description="Source link of the retrieved document")

    @classmethod
    def run(cls, query, knowledge_base_name, search_top=VECTOR_SEARCH_TOP_K, score_threshold=SCORE_THRESHOLD):
        """Execute the tool: query the knowledge base and format the hits."""
        docs = search_docs(query, knowledge_base_name, search_top, score_threshold)
        return_docs = []
        for idx, doc in enumerate(docs):
            # Each hit carries its text in page_content; the source path doubles
            # as both title and link, since local documents have no separate URL.
            return_docs.append({
                "index": idx,
                "snippet": doc.page_content,
                "title": doc.metadata.get("source"),
                "link": doc.metadata.get("source"),
            })
        return return_docs
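

# --- Usage sketch (illustrative, not part of the tool itself) ---
# A minimal example of calling the tool directly. It assumes a knowledge base
# named "samples" has already been created and indexed, and that the service
# behind `search_docs` is reachable; adjust the query and names to your setup.
if __name__ == "__main__":
    hits = DocRetrieval.run(
        query="how to create a knowledge base",
        knowledge_base_name="samples",
        search_top=3,
        score_threshold=0.5,
    )
    for hit in hits:
        logger.info(f"[{hit['index']}] {hit['title']}: {hit['snippet'][:80]}")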