import csv
import os

import streamlit as st
from dotenv import load_dotenv
from kiwipiepy import Kiwi
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain_community.document_loaders import (
    CSVLoader,
    DirectoryLoader,
    PyPDFLoader,
    TextLoader,
    UnstructuredHTMLLoader,
    UnstructuredMarkdownLoader,
)
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import FAISS
from langchain_core.messages import ChatMessage
from langchain_experimental.text_splitter import SemanticChunker
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI

# Raise the CSV field-size limit so oversized CSV cells don't abort loading
csv.field_size_limit(10_000_000)
load_dotenv()
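
# Model and path settings: nlpai-lab/KoE5 (a Korean embedding model) and
# upskyy/ko-reranker-8k (a Korean cross-encoder reranker) from the Hugging
# Face Hub; LOAD_DOCS is the source-document directory and SAVE_VD the path
# where the FAISS index is saved.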
EMBEDDING_MODEL = "nlpai-lab/KoE5"
RERANK_MODEL = "upskyy/ko-reranker-8k"
LOAD_DOCS = "data/fake_rag"
SAVE_VD = "fake_all"

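
# Build the embedding model. "device": "cuda" assumes a GPU is available;
# switch to "cpu" on machines without one.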
def embedding():
    model_kwargs = {"device": "cuda"}
    model_name = EMBEDDING_MODEL
    embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)
    return embeddings

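
# Load every supported file type (HTML, CSV, PDF, Markdown, plain text) from
# the given directory and split the result into semantic chunks.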
def load_and_process_documents(directory_path):
    loaders = [
        DirectoryLoader(
            directory_path,
            glob="*.html",
            loader_cls=UnstructuredHTMLLoader,
            loader_kwargs={"encoding": "utf-8"},
        ),
        DirectoryLoader(
            directory_path,
            glob="*.csv",
            loader_cls=CSVLoader,
            loader_kwargs={"encoding": "utf-8"},
        ),
        DirectoryLoader(directory_path, glob="*.pdf", loader_cls=PyPDFLoader),
        DirectoryLoader(
            directory_path,
            glob="*.md",
            loader_cls=UnstructuredMarkdownLoader,
            loader_kwargs={"encoding": "utf-8"},
        ),
        DirectoryLoader(
            directory_path,
            glob="*.txt",
            loader_cls=TextLoader,
            loader_kwargs={"encoding": "utf-8"},
        ),
    ]

    all_documents = []
    for loader in loaders:
        loaded_docs = loader.load()
        all_documents.extend(loaded_docs)

    # Keep only the source filename (without extension) as metadata; note this
    # replaces any other loader metadata such as PDF page numbers
    for doc in all_documents:
        filename = os.path.basename(doc.metadata["source"])
        filename_without_ext = os.path.splitext(filename)[0]
        doc.metadata = {"source": filename_without_ext}

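    # SemanticChunker splits at embedding-similarity breakpoints rather than
    # fixed character counts; it embeds each sentence, so this pass can be
    # slow on large corpora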
    text_splitter = SemanticChunker(embeddings=embedding())
    texts = text_splitter.split_documents(all_documents)
    return texts

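
# FAISS persistence helpers. load_local() unpickles the stored index, so
# allow_dangerous_deserialization=True is required; only load indexes this
# app wrote itself.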
def save_faiss_index(faiss_index, path="faiss_index"):
    faiss_index.save_local(path)


def load_faiss_index(path="faiss_index"):
    embeddings = embedding()
    return FAISS.load_local(path, embeddings, allow_dangerous_deserialization=True)

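
# Korean morphological analyzer; used below to tokenize documents and
# queries for BM25.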
kiwi = Kiwi()

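
# Build the full retrieval pipeline once; st.cache_resource shares the loaded
# models and indexes across Streamlit reruns and sessions.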
@st.cache_resource
def initialize_processing():
    docs = load_and_process_documents(LOAD_DOCS)
    # Reuse a previously saved FAISS index if present; otherwise build one
    # from the documents and save it
    try:
        faiss_index = load_faiss_index(SAVE_VD)
    except Exception:
        embeddings = embedding()
        faiss_index = FAISS.from_documents(docs, embeddings)
        save_faiss_index(faiss_index, SAVE_VD)
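    # Two base retrievers: dense (FAISS over KoE5 embeddings) and sparse
    # (BM25 over Kiwi morpheme tokens), each returning its top 3 hits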
    faiss = faiss_index.as_retriever(search_kwargs={"k": 3})
    kiwi_bm25 = BM25Retriever.from_documents(
        docs, preprocess_func=lambda text: [token.form for token in kiwi.tokenize(text)]
    )
    kiwi_bm25.k = 3
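    # Fuse the two result lists with weighted Reciprocal Rank Fusion,
    # weighting BM25 at 0.6 and FAISS at 0.4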
    st.session_state["retrievers"] = {
        "kiwi_bm25": kiwi_bm25,
        "faiss": faiss,
    }
    st.session_state["ensemble_retriever"] = EnsembleRetriever(
        retrievers=[kiwi_bm25, faiss],
        weights=[0.6, 0.4],
    )
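    # Rerank the fused candidates with the cross-encoder and keep the top 3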
    rerank_model = HuggingFaceCrossEncoder(model_name=RERANK_MODEL)
    compressor = CrossEncoderReranker(model=rerank_model, top_n=3)
    st.session_state["compression_retriever"] = ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=st.session_state["ensemble_retriever"],
    )
    return st.session_state["retrievers"], st.session_state["compression_retriever"]

# Answer-generation model; temperature=0 for deterministic responses
gpt_model = ChatOpenAI(model_name="gpt-4", temperature=0)

st.set_page_config(page_title="RAG 테스트", page_icon="💬")
st.title("RAG 테스트")

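# Initialize the retrieval pipeline once per session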
if "initialize_processing" not in st.session_state:
|
|
st.session_state["retrievers"], st.session_state["compression_retriever"] = (
|
|
initialize_processing()
|
|
)
|
|
|
|
retrievers = st.session_state["retrievers"]
reranker = st.session_state["compression_retriever"]

if "messages" not in st.session_state:
|
|
st.session_state["messages"] = [
|
|
ChatMessage(role="assistant", content="무엇을 도와드릴까요?")
|
|
]
|
|
|
|
|
|
# Render the stored chat history
def print_history(session_state):
    for msg in session_state.messages:
        st.chat_message(msg.role).write(msg.content)


# Append a message to the chat history
def add_history(session_state, role, content):
    session_state.messages.append(ChatMessage(role=role, content=content))

print_history(st.session_state)

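# Main chat turn: retrieve documents for the question, build a grounded
# prompt, generate an answer, then show the sources.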
if user_input := st.chat_input():
    add_history(st.session_state, "user", user_input)
    st.chat_message("user").write(user_input)

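    # Run the full pipeline: BM25+FAISS ensemble, then cross-encoder rerank.
    # ContextualCompressionRetriever.invoke() returns a plain list of
    # Documents; the reranker does not attach similarity scores.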
    with st.spinner("문서 검색 및 답변 생성 중..."):
        docs = reranker.invoke(user_input)

        context_texts = "\n\n".join(
            f"- {doc.metadata['source']}: {doc.page_content}" for doc in docs
        )

        prompt = f"""사용자의 질문: {user_input}
다음은 참고할 수 있는 문서입니다:
{context_texts}
위 문서의 내용을 기반으로 사용자의 질문에 대한 답변을 생성해주세요.
문서에서 관련 내용을 찾을 수 없다면, 일반적인 정보를 제공해주세요.
"""
        gpt_response = gpt_model.invoke(prompt)

    with st.chat_message("assistant"):
        add_history(st.session_state, "assistant", gpt_response.content)
        st.markdown(f"**Chat 응답:**\n\n{gpt_response.content}")
        st.markdown("---")

    # Show the retrieved source documents beneath the answer
    for idx, doc in enumerate(docs):
        with st.chat_message("assistant"):
            st.markdown(f"**{idx + 1}번 문서**: {doc.metadata['source']}")
            with st.expander("문서 내용 보기"):
                st.markdown(doc.page_content)
            st.markdown("---")