Fix Dockerfile build issue
75 autorag/vectordb/__init__.py Normal file
@@ -0,0 +1,75 @@
import os
from typing import List

from autorag.support import dynamically_find_function
from autorag.utils.util import load_yaml_config
from autorag.vectordb.base import BaseVectorStore


def get_support_vectordb(vectordb_name: str):
    support_vectordb = {
        "chroma": ("autorag.vectordb.chroma", "Chroma"),
        "Chroma": ("autorag.vectordb.chroma", "Chroma"),
        "milvus": ("autorag.vectordb.milvus", "Milvus"),
        "Milvus": ("autorag.vectordb.milvus", "Milvus"),
        "weaviate": ("autorag.vectordb.weaviate", "Weaviate"),
        "Weaviate": ("autorag.vectordb.weaviate", "Weaviate"),
        "pinecone": ("autorag.vectordb.pinecone", "Pinecone"),
        "Pinecone": ("autorag.vectordb.pinecone", "Pinecone"),
        "couchbase": ("autorag.vectordb.couchbase", "Couchbase"),
        "Couchbase": ("autorag.vectordb.couchbase", "Couchbase"),
        "qdrant": ("autorag.vectordb.qdrant", "Qdrant"),
        "Qdrant": ("autorag.vectordb.qdrant", "Qdrant"),
    }
    return dynamically_find_function(vectordb_name, support_vectordb)


def load_vectordb(vectordb_name: str, **kwargs):
    vectordb = get_support_vectordb(vectordb_name)
    return vectordb(**kwargs)


def load_vectordb_from_yaml(yaml_path: str, vectordb_name: str, project_dir: str):
    config_dict = load_yaml_config(yaml_path)
    vectordb_list = config_dict.get("vectordb", [])
    if len(vectordb_list) == 0 or vectordb_name == "default":
        chroma_path = os.path.join(project_dir, "resources", "chroma")
        return load_vectordb(
            "chroma",
            client_type="persistent",
            embedding_model="openai",
            collection_name="openai",
            path=chroma_path,
        )

    target_dict = list(filter(lambda x: x["name"] == vectordb_name, vectordb_list))
    target_dict[0].pop("name")  # delete the "name" key; only constructor params remain
    target_vectordb_name = target_dict[0].pop("db_type")
    target_vectordb_params = target_dict[0]
    return load_vectordb(target_vectordb_name, **target_vectordb_params)


def load_all_vectordb_from_yaml(
    yaml_path: str, project_dir: str
) -> List[BaseVectorStore]:
    config_dict = load_yaml_config(yaml_path)
    vectordb_list = config_dict.get("vectordb", [])
    if len(vectordb_list) == 0:
        chroma_path = os.path.join(project_dir, "resources", "chroma")
        return [
            load_vectordb(
                "chroma",
                client_type="persistent",
                embedding_model="openai",
                collection_name="openai",
                path=chroma_path,
            )
        ]

    result_vectordbs = []
    for vectordb_dict in vectordb_list:
        _ = vectordb_dict.pop("name")
        vectordb_type = vectordb_dict.pop("db_type")
        vectordb = load_vectordb(vectordb_type, **vectordb_dict)
        result_vectordbs.append(vectordb)
    return result_vectordbs
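For reference, a minimal usage sketch of the loader API above. The YAML path, the project directory, and the OPENAI_API_KEY requirement are illustrative assumptions, not part of this commit.

from autorag.vectordb import load_vectordb_from_yaml

# "default" (or an empty "vectordb" section) falls back to a persistent Chroma
# store under <project_dir>/resources/chroma with the "openai" embedding model,
# which assumes OPENAI_API_KEY is set in the environment.
vectordb = load_vectordb_from_yaml(
    yaml_path="config.yaml",    # hypothetical config file
    vectordb_name="default",
    project_dir="./project",    # hypothetical project directory
)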
66 autorag/vectordb/base.py Normal file
@@ -0,0 +1,66 @@
from abc import abstractmethod
from typing import List, Tuple, Union

from llama_index.embeddings.openai import OpenAIEmbedding

from autorag.utils.util import openai_truncate_by_token
from autorag.embedding.base import EmbeddingModel


class BaseVectorStore:
    support_similarity_metrics = ["l2", "ip", "cosine"]

    def __init__(
        self,
        embedding_model: Union[str, List[dict]],
        similarity_metric: str = "cosine",
        embedding_batch: int = 100,
    ):
        self.embedding = EmbeddingModel.load(embedding_model)()
        self.embedding_batch = embedding_batch
        self.embedding.embed_batch_size = embedding_batch
        assert (
            similarity_metric in self.support_similarity_metrics
        ), f"similarity metric {similarity_metric} is not supported"
        self.similarity_metric = similarity_metric

    @abstractmethod
    async def add(
        self,
        ids: List[str],
        texts: List[str],
    ):
        pass

    @abstractmethod
    async def query(
        self, queries: List[str], top_k: int, **kwargs
    ) -> Tuple[List[List[str]], List[List[float]]]:
        pass

    @abstractmethod
    async def fetch(self, ids: List[str]) -> List[List[float]]:
        """
        Fetch the embeddings of the ids.
        """
        pass

    @abstractmethod
    async def is_exist(self, ids: List[str]) -> List[bool]:
        """
        Check if the ids exist in the Vector DB.
        """
        pass

    @abstractmethod
    async def delete(self, ids: List[str]):
        pass

    def truncated_inputs(self, inputs: List[str]) -> List[str]:
        if isinstance(self.embedding, OpenAIEmbedding):
            openai_embedding_limit = 8000
            results = openai_truncate_by_token(
                inputs, openai_embedding_limit, self.embedding.model_name
            )
            return results
        return inputs
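To make the abstract contract concrete, here is a hypothetical in-memory subclass. It exists only to show which methods a concrete store must implement; the class and its dictionary-backed storage are illustrative, not part of this commit.

from typing import List, Tuple

from autorag.vectordb.base import BaseVectorStore


class InMemoryVectorStore(BaseVectorStore):  # hypothetical example class
    def __init__(self, embedding_model, similarity_metric: str = "cosine"):
        super().__init__(embedding_model, similarity_metric)
        self._vectors = {}  # id -> embedding

    async def add(self, ids: List[str], texts: List[str]):
        texts = self.truncated_inputs(texts)
        embeddings = await self.embedding.aget_text_embedding_batch(texts)
        self._vectors.update(dict(zip(ids, embeddings)))

    async def query(
        self, queries: List[str], top_k: int, **kwargs
    ) -> Tuple[List[List[str]], List[List[float]]]:
        raise NotImplementedError  # a real store runs a similarity search here

    async def fetch(self, ids: List[str]) -> List[List[float]]:
        return [self._vectors[_id] for _id in ids]

    async def is_exist(self, ids: List[str]) -> List[bool]:
        return [_id in self._vectors for _id in ids]

    async def delete(self, ids: List[str]):
        for _id in ids:
            self._vectors.pop(_id, None)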
117 autorag/vectordb/chroma.py Normal file
@@ -0,0 +1,117 @@
from typing import List, Optional, Dict, Tuple, Union

from chromadb import (
    EphemeralClient,
    PersistentClient,
    DEFAULT_TENANT,
    DEFAULT_DATABASE,
    CloudClient,
    AsyncHttpClient,
)
from chromadb.api.models.AsyncCollection import AsyncCollection
from chromadb.api.types import IncludeEnum, QueryResult

from autorag.utils.util import apply_recursive
from autorag.vectordb.base import BaseVectorStore


class Chroma(BaseVectorStore):
    def __init__(
        self,
        embedding_model: Union[str, List[dict]],
        collection_name: str,
        embedding_batch: int = 100,
        client_type: str = "persistent",
        similarity_metric: str = "cosine",
        path: Optional[str] = None,
        host: str = "localhost",
        port: int = 8000,
        ssl: bool = False,
        headers: Optional[Dict[str, str]] = None,
        api_key: Optional[str] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ):
        super().__init__(embedding_model, similarity_metric, embedding_batch)
        if client_type == "ephemeral":
            self.client = EphemeralClient(tenant=tenant, database=database)
        elif client_type == "persistent":
            assert path is not None, "path must be provided for persistent client"
            self.client = PersistentClient(path=path, tenant=tenant, database=database)
        elif client_type == "http":
            self.client = AsyncHttpClient(
                host=host,
                port=port,
                ssl=ssl,
                headers=headers,
                tenant=tenant,
                database=database,
            )
        elif client_type == "cloud":
            self.client = CloudClient(
                tenant=tenant,
                database=database,
                api_key=api_key,
            )
        else:
            raise ValueError(
                f"client_type {client_type} is not supported\n"
                "supported client types are: ephemeral, persistent, http, cloud"
            )

        self.collection = self.client.get_or_create_collection(
            name=collection_name,
            metadata={"hnsw:space": similarity_metric},
        )

    async def add(self, ids: List[str], texts: List[str]):
        texts = self.truncated_inputs(texts)
        text_embeddings = await self.embedding.aget_text_embedding_batch(texts)
        if isinstance(self.collection, AsyncCollection):
            await self.collection.add(ids=ids, embeddings=text_embeddings)
        else:
            self.collection.add(ids=ids, embeddings=text_embeddings)

    async def fetch(self, ids: List[str]) -> List[List[float]]:
        if isinstance(self.collection, AsyncCollection):
            fetch_result = await self.collection.get(
                ids, include=[IncludeEnum.embeddings]
            )
        else:
            fetch_result = self.collection.get(ids, include=[IncludeEnum.embeddings])
        fetch_embeddings = fetch_result["embeddings"]
        return fetch_embeddings

    async def is_exist(self, ids: List[str]) -> List[bool]:
        if isinstance(self.collection, AsyncCollection):
            fetched_result = await self.collection.get(ids, include=[])
        else:
            fetched_result = self.collection.get(ids, include=[])
        existed_ids = fetched_result["ids"]
        return list(map(lambda x: x in existed_ids, ids))

    async def query(
        self, queries: List[str], top_k: int, **kwargs
    ) -> Tuple[List[List[str]], List[List[float]]]:
        queries = self.truncated_inputs(queries)
        query_embeddings: List[
            List[float]
        ] = await self.embedding.aget_text_embedding_batch(queries)
        if isinstance(self.collection, AsyncCollection):
            query_result: QueryResult = await self.collection.query(
                query_embeddings=query_embeddings, n_results=top_k
            )
        else:
            query_result: QueryResult = self.collection.query(
                query_embeddings=query_embeddings, n_results=top_k
            )
        ids = query_result["ids"]
        scores = query_result["distances"]
        scores = apply_recursive(lambda x: 1 - x, scores)
        return ids, scores

    async def delete(self, ids: List[str]):
        if isinstance(self.collection, AsyncCollection):
            await self.collection.delete(ids)
        else:
            self.collection.delete(ids)
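A usage sketch for the class above. The collection name and path are illustrative, and the "openai" embedding model assumes OPENAI_API_KEY is set in the environment.

import asyncio

from autorag.vectordb.chroma import Chroma

chroma = Chroma(
    embedding_model="openai",
    collection_name="demo",        # illustrative name
    client_type="persistent",
    path="./resources/chroma",     # illustrative path
)
asyncio.run(chroma.add(ids=["doc-1"], texts=["hello world"]))
ids, scores = asyncio.run(chroma.query(queries=["hello"], top_k=1))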
218 autorag/vectordb/couchbase.py Normal file
@@ -0,0 +1,218 @@
import logging

from datetime import timedelta

from couchbase.auth import PasswordAuthenticator
from couchbase.cluster import Cluster
from couchbase.options import ClusterOptions

from typing import List, Tuple, Optional, Union

from autorag.utils.util import make_batch
from autorag.vectordb import BaseVectorStore

logger = logging.getLogger("AutoRAG")


class Couchbase(BaseVectorStore):
    def __init__(
        self,
        embedding_model: Union[str, List[dict]],
        bucket_name: str,
        scope_name: str,
        collection_name: str,
        index_name: str,
        embedding_batch: int = 100,
        connection_string: str = "",
        username: str = "",
        password: str = "",
        ingest_batch: int = 100,
        text_key: Optional[str] = "text",
        embedding_key: Optional[str] = "embedding",
        scoped_index: bool = True,
    ):
        super().__init__(
            embedding_model=embedding_model,
            similarity_metric="ip",
            embedding_batch=embedding_batch,
        )

        self.index_name = index_name
        self.bucket_name = bucket_name
        self.scope_name = scope_name
        self.collection_name = collection_name
        self.scoped_index = scoped_index
        self.text_key = text_key
        self.embedding_key = embedding_key
        self.ingest_batch = ingest_batch

        auth = PasswordAuthenticator(username, password)
        self.cluster = Cluster(connection_string, ClusterOptions(auth))

        # Wait until the cluster is ready for use.
        self.cluster.wait_until_ready(timedelta(seconds=5))

        # Check if the bucket exists
        if not self._check_bucket_exists():
            raise ValueError(
                f"Bucket {self.bucket_name} does not exist. "
                "Please create the bucket before searching."
            )

        try:
            self.bucket = self.cluster.bucket(self.bucket_name)
            self.scope = self.bucket.scope(self.scope_name)
            self.collection = self.scope.collection(self.collection_name)
        except Exception as e:
            raise ValueError(
                "Error connecting to couchbase. "
                "Please check the connection and credentials."
            ) from e

        # Check if the index exists. Raises ValueError if it doesn't.
        self._check_index_exists()

        # Reinitialize to ensure a consistent state
        self.bucket = self.cluster.bucket(self.bucket_name)
        self.scope = self.bucket.scope(self.scope_name)
        self.collection = self.scope.collection(self.collection_name)

    async def add(self, ids: List[str], texts: List[str]):
        from couchbase.exceptions import DocumentExistsException

        texts = self.truncated_inputs(texts)
        text_embeddings: List[
            List[float]
        ] = await self.embedding.aget_text_embedding_batch(texts)

        documents_to_insert = []
        for _id, text, embedding in zip(ids, texts, text_embeddings):
            doc = {
                self.text_key: text,
                self.embedding_key: embedding,
            }
            documents_to_insert.append({_id: doc})

        batch_documents_to_insert = make_batch(documents_to_insert, self.ingest_batch)

        for batch in batch_documents_to_insert:
            insert_batch = {}
            for doc in batch:
                insert_batch.update(doc)
            try:
                self.collection.upsert_multi(insert_batch)
            except DocumentExistsException as e:
                logger.debug(f"Document already exists: {e}")

    async def fetch(self, ids: List[str]) -> List[List[float]]:
        # Fetch vectors by IDs
        fetched_result = self.collection.get_multi(ids)
        fetched_vectors = {
            k: v.value[self.embedding_key]
            for k, v in fetched_result.results.items()
        }
        return list(map(lambda x: fetched_vectors[x], ids))

    async def is_exist(self, ids: List[str]) -> List[bool]:
        existed_result = self.collection.exists_multi(ids)
        existed_ids = {k: v.exists for k, v in existed_result.results.items()}
        return list(map(lambda x: existed_ids[x], ids))

    async def query(
        self, queries: List[str], top_k: int, **kwargs
    ) -> Tuple[List[List[str]], List[List[float]]]:
        import couchbase.search as search
        from couchbase.options import SearchOptions
        from couchbase.vector_search import VectorQuery, VectorSearch

        queries = self.truncated_inputs(queries)
        query_embeddings: List[
            List[float]
        ] = await self.embedding.aget_text_embedding_batch(queries)

        ids, scores = [], []
        for query_embedding in query_embeddings:
            # Create Search Request
            search_req = search.SearchRequest.create(
                VectorSearch.from_vector_query(
                    VectorQuery(
                        self.embedding_key,
                        query_embedding,
                        top_k,
                    )
                )
            )

            # Search
            if self.scoped_index:
                search_iter = self.scope.search(
                    self.index_name,
                    search_req,
                    SearchOptions(limit=top_k),
                )
            else:
                search_iter = self.cluster.search(
                    self.index_name,
                    search_req,
                    SearchOptions(limit=top_k),
                )

            # Parse the search results
            # search_iter.rows() can only be iterated once.
            id_list, score_list = [], []
            for result in search_iter.rows():
                id_list.append(result.id)
                score_list.append(result.score)

            ids.append(id_list)
            scores.append(score_list)

        return ids, scores

    async def delete(self, ids: List[str]):
        self.collection.remove_multi(ids)

    def _check_bucket_exists(self) -> bool:
        """Check if the bucket exists in the linked Couchbase cluster.

        Returns:
            True if the bucket exists
        """
        bucket_manager = self.cluster.buckets()
        try:
            bucket_manager.get_bucket(self.bucket_name)
            return True
        except Exception as e:
            logger.debug("Error checking if bucket exists: %s", e)
            return False

    def _check_index_exists(self) -> bool:
        """Check if the Search index exists in the linked Couchbase cluster.

        Returns:
            bool: True if the index exists.

        Raises a ValueError if the index does not exist.
        """
        if self.scoped_index:
            all_indexes = [
                index.name for index in self.scope.search_indexes().get_all_indexes()
            ]
        else:
            all_indexes = [
                index.name for index in self.cluster.search_indexes().get_all_indexes()
            ]
        if self.index_name not in all_indexes:
            raise ValueError(
                f"Index {self.index_name} does not exist. "
                "Please create the index before searching."
            )

        return True
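A hypothetical setup sketch for the class above. The connection string, credentials, and bucket/scope/collection/index names are placeholders, and the bucket and search index must already exist (the constructor raises otherwise).

import asyncio

from autorag.vectordb.couchbase import Couchbase

cb = Couchbase(
    embedding_model="openai",
    bucket_name="autorag",                      # placeholder; must already exist
    scope_name="_default",
    collection_name="_default",
    index_name="autorag-vector-index",          # placeholder; must already exist
    connection_string="couchbase://localhost",
    username="Administrator",                   # placeholder credentials
    password="password",
)
asyncio.run(cb.add(ids=["doc-1"], texts=["hello world"]))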
168 autorag/vectordb/milvus.py Normal file
@@ -0,0 +1,168 @@
import logging
from typing import Any, Dict, List, Tuple, Optional, Union

from pymilvus import (
    DataType,
    FieldSchema,
    CollectionSchema,
    connections,
    Collection,
    MilvusException,
)
from pymilvus.orm import utility

from autorag.utils.util import apply_recursive
from autorag.vectordb import BaseVectorStore


logger = logging.getLogger("AutoRAG")


class Milvus(BaseVectorStore):
    def __init__(
        self,
        embedding_model: Union[str, List[dict]],
        collection_name: str,
        embedding_batch: int = 100,
        similarity_metric: str = "cosine",
        index_type: str = "IVF_FLAT",
        uri: str = "http://localhost:19530",
        db_name: str = "",
        token: str = "",
        user: str = "",
        password: str = "",
        timeout: Optional[float] = None,
        params: Dict[str, Any] = {},
    ):
        super().__init__(embedding_model, similarity_metric, embedding_batch)

        # Connect to Milvus server
        connections.connect(
            "default",
            uri=uri,
            token=token,
            db_name=db_name,
            user=user,
            password=password,
        )
        self.collection_name = collection_name
        self.timeout = timeout
        self.params = params
        self.index_type = index_type

        # Set Collection
        if not utility.has_collection(collection_name, timeout=timeout):
            # Get the dimension of the embeddings
            test_embedding_result: List[float] = self.embedding.get_query_embedding(
                "test"
            )
            dimension = len(test_embedding_result)

            pk = FieldSchema(
                name="id",
                dtype=DataType.VARCHAR,
                max_length=128,
                is_primary=True,
                auto_id=False,
            )
            field = FieldSchema(
                name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension
            )
            schema = CollectionSchema(fields=[pk, field])

            self.collection = Collection(name=self.collection_name, schema=schema)
            index_params = {
                "metric_type": self.similarity_metric.upper(),
                "index_type": self.index_type.upper(),
                "params": self.params,
            }
            self.collection.create_index(
                field_name="vector", index_params=index_params, timeout=self.timeout
            )
        else:
            self.collection = Collection(name=self.collection_name)

    async def add(self, ids: List[str], texts: List[str]):
        texts = self.truncated_inputs(texts)
        text_embeddings: List[
            List[float]
        ] = await self.embedding.aget_text_embedding_batch(texts)

        # make data for insertion
        data = list(
            map(lambda _id, vector: {"id": _id, "vector": vector}, ids, text_embeddings)
        )

        # Insert data into the collection
        res = self.collection.insert(data=data, timeout=self.timeout)
        assert (
            res.insert_count == len(ids)
        ), f"Insertion failed. Tried to insert {len(ids)} but only {res.insert_count} inserted."

        self.collection.flush(timeout=self.timeout)

    async def query(
        self, queries: List[str], top_k: int, **kwargs
    ) -> Tuple[List[List[str]], List[List[float]]]:
        queries = self.truncated_inputs(queries)
        query_embeddings: List[
            List[float]
        ] = await self.embedding.aget_text_embedding_batch(queries)

        self.collection.load(timeout=self.timeout)

        # Perform similarity search
        results = self.collection.search(
            data=query_embeddings,
            limit=top_k,
            anns_field="vector",
            param={"metric_type": self.similarity_metric.upper()},
            timeout=self.timeout,
            **kwargs,
        )

        # Extract IDs and distances
        ids = [[str(hit.id) for hit in result] for result in results]
        distances = [[hit.distance for hit in result] for result in results]

        if self.similarity_metric in ["l2"]:
            distances = apply_recursive(lambda x: -x, distances)

        return ids, distances

    async def fetch(self, ids: List[str]) -> List[List[float]]:
        try:
            self.collection.load(timeout=self.timeout)
        except MilvusException as e:
            logger.warning(f"Failed to load collection: {e}")
            return [[]] * len(ids)
        # Fetch vectors by IDs
        results = self.collection.query(
            expr=f"id in {ids}", output_fields=["id", "vector"], timeout=self.timeout
        )
        id_vector_dict = {str(result["id"]): result["vector"] for result in results}
        result = [id_vector_dict[_id] for _id in ids]
        return result

    async def is_exist(self, ids: List[str]) -> List[bool]:
        try:
            self.collection.load(timeout=self.timeout)
        except MilvusException:
            return [False] * len(ids)
        # Check the existence of IDs
        results = self.collection.query(
            expr=f"id in {ids}", output_fields=["id"], timeout=self.timeout
        )
        # Determine existence
        existing_ids = {str(result["id"]) for result in results}
        return [str(_id) in existing_ids for _id in ids]

    async def delete(self, ids: List[str]):
        # Delete entries by IDs
        self.collection.delete(expr=f"id in {ids}", timeout=self.timeout)

    def delete_collection(self):
        # Delete the collection
        self.collection.release(timeout=self.timeout)
        self.collection.drop_index(timeout=self.timeout)
        self.collection.drop(timeout=self.timeout)
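A usage sketch for the class above; the URI assumes a locally running Milvus, and the collection name is illustrative (it is created on first use if absent).

import asyncio

from autorag.vectordb.milvus import Milvus

milvus = Milvus(
    embedding_model="openai",
    collection_name="demo",          # illustrative; created if absent
    uri="http://localhost:19530",    # assumes a local Milvus server
)
asyncio.run(milvus.add(ids=["doc-1"], texts=["hello world"]))
ids, distances = asyncio.run(milvus.query(queries=["hello"], top_k=1))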
119 autorag/vectordb/pinecone.py Normal file
@@ -0,0 +1,119 @@
import logging

from pinecone.grpc import PineconeGRPC as Pinecone_client
from pinecone import ServerlessSpec

from typing import List, Optional, Tuple, Union

from autorag.utils.util import make_batch, apply_recursive
from autorag.vectordb import BaseVectorStore

logger = logging.getLogger("AutoRAG")


class Pinecone(BaseVectorStore):
    def __init__(
        self,
        embedding_model: Union[str, List[dict]],
        index_name: str,
        embedding_batch: int = 100,
        dimension: int = 1536,
        similarity_metric: str = "cosine",  # "cosine", "dotproduct", "euclidean"
        cloud: Optional[str] = "aws",
        region: Optional[str] = "us-east-1",
        api_key: Optional[str] = None,
        deletion_protection: Optional[str] = "disabled",  # "enabled" or "disabled"
        namespace: Optional[str] = "default",
        ingest_batch: int = 200,
    ):
        super().__init__(embedding_model, similarity_metric, embedding_batch)

        self.index_name = index_name
        self.namespace = namespace
        self.ingest_batch = ingest_batch

        self.client = Pinecone_client(api_key=api_key)

        if similarity_metric == "ip":
            similarity_metric = "dotproduct"
        elif similarity_metric == "l2":
            similarity_metric = "euclidean"

        if not self.client.has_index(index_name):
            self.client.create_index(
                name=index_name,
                dimension=dimension,
                metric=similarity_metric,
                spec=ServerlessSpec(
                    cloud=cloud,
                    region=region,
                ),
                deletion_protection=deletion_protection,
            )
        self.index = self.client.Index(index_name)

    async def add(self, ids: List[str], texts: List[str]):
        texts = self.truncated_inputs(texts)
        text_embeddings: List[
            List[float]
        ] = await self.embedding.aget_text_embedding_batch(texts)

        vector_tuples = list(zip(ids, text_embeddings))
        batch_vectors = make_batch(vector_tuples, self.ingest_batch)

        async_res = [
            self.index.upsert(
                vectors=batch_vector_tuples,
                namespace=self.namespace,
                async_req=True,
            )
            for batch_vector_tuples in batch_vectors
        ]
        # Wait for the async requests to finish
        [async_result.result() for async_result in async_res]

    async def fetch(self, ids: List[str]) -> List[List[float]]:
        results = self.index.fetch(ids=ids, namespace=self.namespace)
        id_vector_dict = {
            str(key): val["values"] for key, val in results["vectors"].items()
        }
        result = [id_vector_dict[_id] for _id in ids]
        return result

    async def is_exist(self, ids: List[str]) -> List[bool]:
        fetched_result = self.index.fetch(ids=ids, namespace=self.namespace)
        existed_ids = list(map(str, fetched_result.get("vectors", {}).keys()))
        return list(map(lambda x: x in existed_ids, ids))

    async def query(
        self, queries: List[str], top_k: int, **kwargs
    ) -> Tuple[List[List[str]], List[List[float]]]:
        queries = self.truncated_inputs(queries)
        query_embeddings: List[
            List[float]
        ] = await self.embedding.aget_text_embedding_batch(queries)

        ids, scores = [], []
        for query_embedding in query_embeddings:
            response = self.index.query(
                vector=query_embedding,
                top_k=top_k,
                include_values=True,
                namespace=self.namespace,
            )

            ids.append([o.id for o in response.matches])
            scores.append([o.score for o in response.matches])

        if self.similarity_metric in ["l2"]:
            scores = apply_recursive(lambda x: -x, scores)

        return ids, scores

    async def delete(self, ids: List[str]):
        # Delete entries by IDs
        self.index.delete(ids=ids, namespace=self.namespace)

    def delete_index(self):
        # Delete the index
        self.client.delete_index(self.index_name)
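A usage sketch for the class above; the index name is illustrative (a serverless index is created if absent) and the API key is a placeholder.

import asyncio

from autorag.vectordb.pinecone import Pinecone

pc = Pinecone(
    embedding_model="openai",
    index_name="autorag-demo",          # illustrative; created if absent
    api_key="YOUR_PINECONE_API_KEY",    # placeholder
)
asyncio.run(pc.add(ids=["doc-1"], texts=["hello world"]))
ids, scores = asyncio.run(pc.query(queries=["hello"], top_k=1))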
153 autorag/vectordb/qdrant.py Normal file
@@ -0,0 +1,153 @@
import logging

from qdrant_client import QdrantClient
from qdrant_client.models import (
    Distance,
    VectorParams,
    PointStruct,
    PointIdsList,
    HasIdCondition,
    Filter,
    SearchRequest,
)

from typing import List, Tuple, Union

from autorag.vectordb import BaseVectorStore

logger = logging.getLogger("AutoRAG")


class Qdrant(BaseVectorStore):
    def __init__(
        self,
        embedding_model: Union[str, List[dict]],
        collection_name: str,
        embedding_batch: int = 100,
        similarity_metric: str = "cosine",
        client_type: str = "docker",
        url: str = "http://localhost:6333",
        host: str = "",
        api_key: str = "",
        dimension: int = 1536,
        ingest_batch: int = 64,
        parallel: int = 1,
        max_retries: int = 3,
    ):
        super().__init__(embedding_model, similarity_metric, embedding_batch)

        self.collection_name = collection_name
        self.ingest_batch = ingest_batch
        self.parallel = parallel
        self.max_retries = max_retries

        if similarity_metric == "cosine":
            distance = Distance.COSINE
        elif similarity_metric == "ip":
            distance = Distance.DOT
        elif similarity_metric == "l2":
            distance = Distance.EUCLID
        else:
            raise ValueError(
                f"similarity_metric {similarity_metric} is not supported\n"
                "supported similarity metrics are: cosine, ip, l2"
            )

        if client_type == "docker":
            self.client = QdrantClient(
                url=url,
            )
        elif client_type == "cloud":
            self.client = QdrantClient(
                host=host,
                api_key=api_key,
            )
        else:
            raise ValueError(
                f"client_type {client_type} is not supported\n"
                "supported client types are: docker, cloud"
            )

        if not self.client.collection_exists(collection_name):
            self.client.create_collection(
                collection_name,
                vectors_config=VectorParams(
                    size=dimension,
                    distance=distance,
                ),
            )
        self.collection = self.client.get_collection(collection_name)

    async def add(self, ids: List[str], texts: List[str]):
        texts = self.truncated_inputs(texts)
        text_embeddings = await self.embedding.aget_text_embedding_batch(texts)

        points = list(
            map(lambda x: PointStruct(id=x[0], vector=x[1]), zip(ids, text_embeddings))
        )

        self.client.upload_points(
            collection_name=self.collection_name,
            points=points,
            batch_size=self.ingest_batch,
            parallel=self.parallel,
            max_retries=self.max_retries,
            wait=True,
        )

    async def fetch(self, ids: List[str]) -> List[List[float]]:
        # Fetch vectors by IDs
        fetched_results = self.client.retrieve(
            collection_name=self.collection_name,
            ids=ids,
            with_vectors=True,
        )
        return list(map(lambda x: x.vector, fetched_results))

    async def is_exist(self, ids: List[str]) -> List[bool]:
        existed_result = self.client.scroll(
            collection_name=self.collection_name,
            scroll_filter=Filter(
                must=[
                    HasIdCondition(has_id=ids),
                ],
            ),
        )
        # existed_result is a tuple, so use existed_result[0] to get the list of Records
        existed_ids = list(map(lambda x: x.id, existed_result[0]))
        return list(map(lambda x: x in existed_ids, ids))

    async def query(
        self, queries: List[str], top_k: int, **kwargs
    ) -> Tuple[List[List[str]], List[List[float]]]:
        queries = self.truncated_inputs(queries)
        query_embeddings: List[
            List[float]
        ] = await self.embedding.aget_text_embedding_batch(queries)

        search_queries = list(
            map(
                lambda x: SearchRequest(vector=x, limit=top_k, with_vector=True),
                query_embeddings,
            )
        )

        search_result = self.client.search_batch(
            collection_name=self.collection_name, requests=search_queries
        )

        # Extract IDs and distances
        ids = [[str(hit.id) for hit in result] for result in search_result]
        scores = [[hit.score for hit in result] for result in search_result]

        return ids, scores

    async def delete(self, ids: List[str]):
        self.client.delete(
            collection_name=self.collection_name,
            points_selector=PointIdsList(points=ids),
        )

    def delete_collection(self):
        # Delete the collection
        self.client.delete_collection(self.collection_name)
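A usage sketch for the class above. Qdrant accepts only UUIDs or unsigned integers as point IDs, so the example uses a generated UUID; the URL assumes a local docker instance.

import asyncio
from uuid import uuid4

from autorag.vectordb.qdrant import Qdrant

qdrant = Qdrant(
    embedding_model="openai",
    collection_name="demo",         # illustrative; created if absent
    url="http://localhost:6333",    # assumes a local docker instance
)
doc_id = str(uuid4())  # Qdrant point IDs must be UUIDs or unsigned ints
asyncio.run(qdrant.add(ids=[doc_id], texts=["hello world"]))
ids, scores = asyncio.run(qdrant.query(queries=["hello"], top_k=1))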
167 autorag/vectordb/weaviate.py Normal file
@@ -0,0 +1,167 @@
import logging

import weaviate
from weaviate.classes.init import Auth
from weaviate.classes.config import Property, DataType
import weaviate.classes as wvc
from weaviate.classes.query import MetadataQuery

from typing import List, Optional, Tuple, Union

from autorag.vectordb import BaseVectorStore

logger = logging.getLogger("AutoRAG")


class Weaviate(BaseVectorStore):
    def __init__(
        self,
        embedding_model: Union[str, List[dict]],
        collection_name: str,
        embedding_batch: int = 100,
        similarity_metric: str = "cosine",
        client_type: str = "docker",
        host: str = "localhost",
        port: int = 8080,
        grpc_port: int = 50051,
        url: Optional[str] = None,
        api_key: Optional[str] = None,
        text_key: str = "content",
    ):
        super().__init__(embedding_model, similarity_metric, embedding_batch)

        self.text_key = text_key

        if client_type == "docker":
            self.client = weaviate.connect_to_local(
                host=host,
                port=port,
                grpc_port=grpc_port,
            )
        elif client_type == "cloud":
            self.client = weaviate.connect_to_weaviate_cloud(
                cluster_url=url,
                auth_credentials=Auth.api_key(api_key),
            )
        else:
            raise ValueError(
                f"client_type {client_type} is not supported\n"
                "supported client types are: docker, cloud"
            )
        if similarity_metric == "cosine":
            distance_metric = wvc.config.VectorDistances.COSINE
        elif similarity_metric == "ip":
            distance_metric = wvc.config.VectorDistances.DOT
        elif similarity_metric == "l2":
            distance_metric = wvc.config.VectorDistances.L2_SQUARED
        else:
            raise ValueError(
                f"similarity_metric {similarity_metric} is not supported\n"
                "supported similarity metrics are: cosine, ip, l2"
            )

        if not self.client.collections.exists(collection_name):
            self.client.collections.create(
                collection_name,
                properties=[
                    Property(
                        name="content", data_type=DataType.TEXT, skip_vectorization=True
                    ),
                ],
                vectorizer_config=wvc.config.Configure.Vectorizer.none(),
                vector_index_config=wvc.config.Configure.VectorIndex.hnsw(  # hnsw, flat, dynamic
                    distance_metric=distance_metric
                ),
            )
        self.collection = self.client.collections.get(collection_name)
        self.collection_name = collection_name

    async def add(self, ids: List[str], texts: List[str]):
        texts = self.truncated_inputs(texts)
        text_embeddings = await self.embedding.aget_text_embedding_batch(texts)

        with self.client.batch.dynamic() as batch:
            for i, text in enumerate(texts):
                data_properties = {self.text_key: text}

                batch.add_object(
                    collection=self.collection_name,
                    properties=data_properties,
                    uuid=ids[i],
                    vector=text_embeddings[i],
                )

        failed_objs = self.client.batch.failed_objects
        for obj in failed_objs:
            err_message = (
                f"Failed to add object: {obj.original_uuid}\nReason: {obj.message}"
            )
            logger.error(err_message)

    async def fetch(self, ids: List[str]) -> List[List[float]]:
        # Fetch vectors by IDs
        results = self.collection.query.fetch_objects(
            filters=wvc.query.Filter.by_property("_id").contains_any(ids),
            include_vector=True,
        )
        id_vector_dict = {
            str(obj.uuid): obj.vector["default"] for obj in results.objects
        }
        result = [id_vector_dict[_id] for _id in ids]
        return result

    async def is_exist(self, ids: List[str]) -> List[bool]:
        fetched_result = self.collection.query.fetch_objects(
            filters=wvc.query.Filter.by_property("_id").contains_any(ids),
        )
        existed_ids = [str(result.uuid) for result in fetched_result.objects]
        return list(map(lambda x: x in existed_ids, ids))

    async def query(
        self, queries: List[str], top_k: int, **kwargs
    ) -> Tuple[List[List[str]], List[List[float]]]:
        queries = self.truncated_inputs(queries)
        query_embeddings: List[
            List[float]
        ] = await self.embedding.aget_text_embedding_batch(queries)

        ids, scores = [], []
        for query_embedding in query_embeddings:
            response = self.collection.query.near_vector(
                near_vector=query_embedding,
                limit=top_k,
                return_metadata=MetadataQuery(distance=True),
            )

            ids.append([str(o.uuid) for o in response.objects])
            scores.append(
                [
                    distance_to_score(o.metadata.distance, self.similarity_metric)
                    for o in response.objects
                ]
            )

        return ids, scores

    async def delete(self, ids: List[str]):
        filter = wvc.query.Filter.by_id().contains_any(ids)
        self.collection.data.delete_many(where=filter)

    def delete_collection(self):
        # Delete the collection
        self.client.collections.delete(self.collection_name)


def distance_to_score(distance: float, similarity_metric) -> float:
    if similarity_metric == "cosine":
        return 1 - distance
    elif similarity_metric == "ip":
        return -distance
    elif similarity_metric == "l2":
        return -distance
    else:
        raise ValueError(
            f"similarity_metric {similarity_metric} is not supported\n"
            "supported similarity metrics are: cosine, ip, l2"
        )
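A usage sketch for the class above. Weaviate object IDs must be valid UUIDs, so the example generates one; the docker client assumes a local Weaviate instance, and the collection name is illustrative.

import asyncio
from uuid import uuid4

from autorag.vectordb.weaviate import Weaviate

wv = Weaviate(
    embedding_model="openai",
    collection_name="Demo",      # illustrative; created if absent
    client_type="docker",        # assumes a local Weaviate instance
)
doc_id = str(uuid4())  # Weaviate object IDs must be UUIDs
asyncio.run(wv.add(ids=[doc_id], texts=["hello world"]))
ids, scores = asyncio.run(wv.query(queries=["hello"], top_k=1))
wv.client.close()  # release the underlying connection when done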