codefuse-chatbot/dev_opsgpt/tools/docs_retrieval.py

43 lines
1.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import os
import re
from pydantic import BaseModel, Field
from typing import List, Dict
import requests
import numpy as np
from loguru import logger
from configs.model_config import (
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
from .base_tool import BaseToolModel
from dev_opsgpt.service.kb_api import search_docs
class DocRetrieval(BaseToolModel):
    """Tool that searches a local knowledge base via ``search_docs``.

    The tool forwards the query and search parameters unchanged to
    ``search_docs`` and reshapes every returned document into a plain dict.
    """

    name = "DocRetrieval"
    # Runtime description (Chinese): "Retrieve from the local knowledge base
    # using vectorization".
    description = "采用向量化对本地知识库进行检索"

    class ToolInputArgs(BaseModel):
        # Input schema for the tool. Documented with comments rather than a
        # docstring so the generated schema text stays exactly as before.
        # Field descriptions are runtime strings (Chinese); translations:
        #   query               - keyword or question to search for
        #   knowledge_base_name - name of the knowledge base (e.g. "samples")
        #   search_top          - number of results to return
        #   score_threshold     - relevance threshold in [0, 1]; a smaller SCORE
        #                         means higher relevance, 1 effectively disables
        #                         filtering, around 0.5 is suggested
        query: str = Field(..., description="检索的关键字或问题")
        knowledge_base_name: str = Field(..., description="知识库名称", examples=["samples"])
        search_top: int = Field(VECTOR_SEARCH_TOP_K, description="检索返回的数量")
        # ge/le enforce the documented [0, 1] range at validation time.
        score_threshold: float = Field(SCORE_THRESHOLD, description="知识库匹配相关度阈值取值范围在0-1之间SCORE越小相关度越高取到1相当于不筛选建议设置在0.5左右", ge=0, le=1)

    class ToolOutputArgs(BaseModel):
        """Output for DocRetrieval."""
        # Field descriptions (Chinese) translate roughly to:
        #   title   - title of the retrieved page
        #   snippet - retrieved content
        #   link    - address of the retrieved page
        # NOTE(review): run() also emits an "index" key that is not declared
        # here; confirm whether the schema should advertise it.
        title: str = Field(..., description="检索网页标题")
        snippet: str = Field(..., description="检索内容的判断")
        link: str = Field(..., description="检索网页地址")

    @classmethod
    def run(
        cls,
        query: str,
        knowledge_base_name: str,
        search_top: int = VECTOR_SEARCH_TOP_K,
        score_threshold: float = SCORE_THRESHOLD,
    ) -> List[Dict]:
        """Execute the tool: search the knowledge base and format the hits.

        Args:
            query: Keyword or question to search for.
            knowledge_base_name: Name of the knowledge base to search.
            search_top: Number of documents requested from ``search_docs``.
            score_threshold: Relevance threshold forwarded to ``search_docs``.

        Returns:
            One dict per document, in the order ``search_docs`` yields them,
            with keys ``index`` (0-based position), ``snippet``
            (``doc.page_content``), and ``title`` / ``link`` (both
            ``doc.metadata.get("source")``, so ``None`` when absent).
        """
        docs = search_docs(query, knowledge_base_name, search_top, score_threshold)
        # No separate title is read from metadata, so "title" reuses the
        # document source, same as "link".
        return [
            {
                "index": idx,
                "snippet": doc.page_content,
                "title": doc.metadata.get("source"),
                "link": doc.metadata.get("source"),
            }
            for idx, doc in enumerate(docs)
        ]