# rag_streamlit/main.py
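# Streamlit RAG demo; run with: streamlit run rag_streamlit/main.py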
import csv
import os
import streamlit as st
from dotenv import load_dotenv
from kiwipiepy import Kiwi
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain_community.document_loaders import (
    CSVLoader,
    DirectoryLoader,
    PyPDFLoader,
    TextLoader,
    UnstructuredHTMLLoader,
    UnstructuredMarkdownLoader,
)
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import FAISS
from langchain_core.messages import ChatMessage
from langchain_experimental.text_splitter import SemanticChunker
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
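# The stock LangChain CrossEncoderReranker (commented out below) returns plain
# Documents; this project swaps in a local custom_CrossEncoderReranker which,
# judging by the tuple unpacking later in the file, returns (Document, score) pairs.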
# from langchain.retrievers.document_compressors import CrossEncoderReranker
from custom_CrossEncoderReranker import CrossEncoderReranker
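# Raise the csv module's field size limit so CSVLoader can ingest rows with very large cells.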
csv.field_size_limit(10000000)
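# Expects credentials such as OPENAI_API_KEY (used by ChatOpenAI below) in a local .env file.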
load_dotenv()
EMBEDDING_MODEL = "nlpai-lab/KoE5"
RERANK_MODEL = "upskyy/ko-reranker-8k"
LOAD_DOCS = "data/fake_rag"
SAVE_VD = "fake_all_KCDS"
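
# Shared embedding model: KoE5 sentence embeddings on GPU, used both by the
# semantic chunker and by the FAISS vector store.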
def embedding():
    model_kwargs = {"device": "cuda"}
    model_name = EMBEDDING_MODEL
    embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)
    return embeddings
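
# Load every supported file type under directory_path, normalize each document's
# metadata to its bare filename, and split into semantically coherent chunks.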
def load_and_process_documents(directory_path):
    loaders = [
        DirectoryLoader(
            directory_path,
            glob="*.html",
            loader_cls=UnstructuredHTMLLoader,
            loader_kwargs={"encoding": "utf-8"},
        ),
        DirectoryLoader(
            directory_path,
            glob="*.csv",
            loader_cls=CSVLoader,
            loader_kwargs={"encoding": "utf-8"},
        ),
        DirectoryLoader(directory_path, glob="*.pdf", loader_cls=PyPDFLoader),
        DirectoryLoader(
            directory_path,
            glob="*.md",
            loader_cls=UnstructuredMarkdownLoader,
            loader_kwargs={"encoding": "utf-8"},
        ),
        DirectoryLoader(
            directory_path,
            glob="*.txt",
            loader_cls=TextLoader,
            loader_kwargs={"encoding": "utf-8"},
        ),
    ]
    all_documents = []
    for loader in loaders:
        loaded_docs = loader.load()
        all_documents.extend(loaded_docs)
    # Record the source filename (without extension) in the metadata.
    # Note: this replaces the whole metadata dict, dropping anything else
    # the loaders attached (e.g. PDF page numbers).
    for doc in all_documents:
        filename = os.path.basename(doc.metadata["source"])
        filename_without_ext = os.path.splitext(filename)[0]
        doc.metadata = {"source": filename_without_ext}
    text_splitter = SemanticChunker(embeddings=embedding())
    texts = text_splitter.split_documents(all_documents)
    return texts
def save_faiss_index(faiss_index, path="faiss_index"):
    faiss_index.save_local(path)
    print("-------------------------------------------")
    print("save_faiss_index")
    print("-------------------------------------------")
def load_faiss_index(path="faiss_index"):
    embeddings = embedding()
    print("-------------------------------------------")
    print("load_faiss_index")
    print("-------------------------------------------")
    return FAISS.load_local(path, embeddings, allow_dangerous_deserialization=True)
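
# Kiwi is a Korean morphological analyzer; BM25 tokenizes with it below because
# plain whitespace tokenization performs poorly on Korean text.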
kiwi = Kiwi()
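
# Retrieval pipeline: Kiwi-tokenized BM25 (lexical) and FAISS (dense KoE5 vectors)
# are combined 0.6/0.4 in an EnsembleRetriever, then the ko-reranker-8k
# cross-encoder re-scores the candidates and keeps the top 3.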
@st.cache_resource
def initialize_processing():
    try:
        # Reuse a previously built index and its stored chunks if available.
        faiss_index = load_faiss_index(SAVE_VD)
        docs = list(faiss_index.docstore._dict.values())
    except Exception:
        print("no load_faiss_index")
        docs = load_and_process_documents(LOAD_DOCS)
        embeddings = embedding()
        faiss_index = FAISS.from_documents(docs, embeddings)
        save_faiss_index(faiss_index, SAVE_VD)
    faiss = faiss_index.as_retriever(search_kwargs={"k": 3})
    kiwi_bm25 = BM25Retriever.from_documents(
        docs, preprocess_func=lambda text: [token.form for token in kiwi.tokenize(text)]
    )
    kiwi_bm25.k = 3
    st.session_state["retrievers"] = {
        "kiwi_bm25": kiwi_bm25,
        "faiss": faiss,
    }
    st.session_state["Ensembleretriever"] = EnsembleRetriever(
        retrievers=[kiwi_bm25, faiss],
        weights=[0.6, 0.4],
        search_type="mmr",  # may be ignored by EnsembleRetriever depending on langchain version
    )
    rerank_model = HuggingFaceCrossEncoder(model_name=RERANK_MODEL)
    compressor = CrossEncoderReranker(model=rerank_model, top_n=3)
    st.session_state["compression_retriever"] = ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=st.session_state["Ensembleretriever"]
    )
    return st.session_state["retrievers"], st.session_state["compression_retriever"]

# Configure the GPT model
gpt_model = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)

st.set_page_config(page_title="RAG 테스트", page_icon="💬")
st.title("RAG 테스트")
if "initialize_processing" not in st.session_state:
st.session_state["retrievers"], st.session_state["compression_retriever"] = (
initialize_processing()
)
retrievers = st.session_state["retrievers"]
reranker = st.session_state["compression_retriever"]
if "messages" not in st.session_state:
st.session_state["messages"] = [
ChatMessage(role="assistant", content="무엇을 도와드릴까요?")
]

# Print the chat history
def print_history(st, session_state):
    for msg in session_state.messages:
        st.chat_message(msg.role).write(msg.content)

# Append a message to the chat history
def add_history(session_state, role, content):
    session_state.messages.append(ChatMessage(role=role, content=content))

print_history(st, st.session_state)
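
# Main chat turn: retrieve and rerank documents for the query, build a grounded
# prompt, then render the model's answer followed by the supporting chunks.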
if user_input := st.chat_input():
    add_history(st.session_state, "user", user_input)
    st.chat_message("user").write(user_input)
    with st.chat_message("assistant"):
        chat_container = st.empty()  # placeholder container (unused in the current flow)
        with st.spinner("문서 검색 및 답변 생성 중..."):
            # The custom reranker returns (Document, score) pairs.
            docs = reranker.invoke(user_input)
            context_texts = "\n\n".join(
                [
                    f"- {doc[0].metadata['source']} (유사도: {doc[1]}): {doc[0].page_content}"
                    for doc in docs
                ]
            )
            prompt = f"""사용자의 질문: {user_input}
다음은 참고할 수 있는 문서입니다:
{context_texts}
위 문서의 내용을 기반으로 사용자의 질문에 대한 답변을 생성해주세요.
문서에서 관련 내용을 찾을 수 없다면, 일반적인 정보를 제공해주세요.
"""
            gpt_response = gpt_model.invoke(prompt)
    with st.chat_message("assistant"):
        add_history(st.session_state, "assistant", gpt_response.content)
        st.markdown(f"**Chat 응답:**\n\n{gpt_response.content}")
        st.markdown("---")
    # Display the retrieved documents
    for idx, (doc, score) in enumerate(docs):
        with st.chat_message("assistant"):
            st.markdown(
                f"**{idx + 1}번 문서**: {doc.metadata['source']} (유사도: {score})"
            )
            with st.expander("문서 내용 보기"):
                st.markdown(doc.page_content)
            st.markdown("---")