# encoding: utf-8
'''
@author: 温进
@file: cypher_generator.py
@time: 2023/11/24 上午10:17
@desc: Generate graph (nGQL/Cypher) queries from natural-language questions with an LLM.
'''
from typing import Optional

from langchain.chains.graph_qa.prompts import NGQL_GENERATION_PROMPT
from langchain.schema import (
    HumanMessage,
)
from loguru import logger

from dev_opsgpt.llm_models.openai_model import getChatModel
from dev_opsgpt.utils.postprocess import replace_lt_gt


# Default graph schema rendered into NGQL_GENERATION_PROMPT by CypherGenerator:
# node tags (package / class / method), edge types (contain / depend) and the
# allowed (source)-[edge]->(target) relationships between them.
# NOTE(review): the 'depend' edge has no entry under "Relationships" — confirm
# whether that omission is intentional.
schema = '''
Node properties: [{'tag': 'package', 'properties': []}, {'tag': 'class', 'properties': []}, {'tag': 'method', 'properties': []}]
Edge properties: [{'edge': 'contain', 'properties': []}, {'edge': 'depend', 'properties': []}]
Relationships: ['(:package)-[:contain]->(:class)', '(:class)-[:contain]->(:method)', '(:package)-[:contain]->(:package)']
'''


class CypherGenerator:
    '''
    Translate natural-language questions into graph queries using an LLM.

    The question and a graph schema are rendered into langchain's
    NGQL_GENERATION_PROMPT, sent to the chat model returned by getChatModel(),
    and the reply is post-processed before being returned.
    '''

    def __init__(self, graph_schema: Optional[str] = None):
        '''
        @param graph_schema: schema text embedded in the prompt; when None the
            module-level default `schema` (package/class/method graph) is used.
        '''
        self.model = getChatModel()
        # Allow callers to target graphs other than the default code graph
        # without editing this module; the default keeps the old behavior.
        self.graph_schema = schema if graph_schema is None else graph_schema

    def get_cypher(self, query: str) -> str:
        '''
        Generate a graph query for a natural-language question.
        @param query: the user's question
        @return: the generated query, or '' if the model's reply fails the
            sanity check in post_process
        '''
        content = NGQL_GENERATION_PROMPT.format(schema=self.graph_schema, question=query)
        chat_res = self.model.predict_messages([HumanMessage(content=content)])

        # NOTE(review): replace_lt_gt presumably un-escapes '&lt;' / '&gt;'
        # entities in the model output — verify against its implementation.
        ans = replace_lt_gt(chat_res.content)
        return self.post_process(ans)

    def post_process(self, cypher_res: str) -> str:
        '''
        Cheap sanity check that the model output looks like a query.
        The output is accepted only if it contains both '(' and ')'; anything
        else (including empty or None output) is treated as unusable.
        @param cypher_res: raw query text produced by the model
        @return: cypher_res unchanged if it passes the check, otherwise ''
        '''
        if not cypher_res or '(' not in cypher_res or ')' not in cypher_res:
            return ''
        return cypher_res


if __name__ == '__main__':
    # Sample question: "How many classes are there in the code?"
    query = '代码中一共有多少个类'
    # CypherGenerator.__init__ accepts no `engine` argument; passing one raised TypeError.
    cg = CypherGenerator()

    cypher = cg.get_cypher(query)
    logger.info('query: {} -> cypher: {}', query, cypher)