Fix Dockerfile build issue

Author: kyy
Date: 2025-03-18 16:41:12 +09:00
parent 6814230bfb
commit 9323aa254a
228 changed files with 467 additions and 3488 deletions

autorag/vectordb/__init__.py (new file)

@@ -0,0 +1,75 @@
import os
from typing import List
from autorag.support import dynamically_find_function
from autorag.utils.util import load_yaml_config
from autorag.vectordb.base import BaseVectorStore
def get_support_vectordb(vectordb_name: str):
support_vectordb = {
"chroma": ("autorag.vectordb.chroma", "Chroma"),
"Chroma": ("autorag.vectordb.chroma", "Chroma"),
"milvus": ("autorag.vectordb.milvus", "Milvus"),
"Milvus": ("autorag.vectordb.milvus", "Milvus"),
"weaviate": ("autorag.vectordb.weaviate", "Weaviate"),
"Weaviate": ("autorag.vectordb.weaviate", "Weaviate"),
"pinecone": ("autorag.vectordb.pinecone", "Pinecone"),
"Pinecone": ("autorag.vectordb.pinecone", "Pinecone"),
"couchbase": ("autorag.vectordb.couchbase", "Couchbase"),
"Couchbase": ("autorag.vectordb.couchbase", "Couchbase"),
"qdrant": ("autorag.vectordb.qdrant", "Qdrant"),
"Qdrant": ("autorag.vectordb.qdrant", "Qdrant"),
}
return dynamically_find_function(vectordb_name, support_vectordb)
def load_vectordb(vectordb_name: str, **kwargs):
vectordb = get_support_vectordb(vectordb_name)
return vectordb(**kwargs)
def load_vectordb_from_yaml(yaml_path: str, vectordb_name: str, project_dir: str):
config_dict = load_yaml_config(yaml_path)
vectordb_list = config_dict.get("vectordb", [])
if len(vectordb_list) == 0 or vectordb_name == "default":
chroma_path = os.path.join(project_dir, "resources", "chroma")
return load_vectordb(
"chroma",
client_type="persistent",
embedding_model="openai",
collection_name="openai",
path=chroma_path,
)
    target_dict = list(filter(lambda x: x["name"] == vectordb_name, vectordb_list))
    if not target_dict:
        raise ValueError(f"Vector DB {vectordb_name} is not found in the YAML config.")
    target_dict[0].pop("name")  # remove the name key
    target_vectordb_name = target_dict[0].pop("db_type")
    target_vectordb_params = target_dict[0]
    return load_vectordb(target_vectordb_name, **target_vectordb_params)
def load_all_vectordb_from_yaml(
yaml_path: str, project_dir: str
) -> List[BaseVectorStore]:
config_dict = load_yaml_config(yaml_path)
vectordb_list = config_dict.get("vectordb", [])
if len(vectordb_list) == 0:
chroma_path = os.path.join(project_dir, "resources", "chroma")
return [
load_vectordb(
"chroma",
client_type="persistent",
embedding_model="openai",
collection_name="openai",
path=chroma_path,
)
]
result_vectordbs = []
for vectordb_dict in vectordb_list:
_ = vectordb_dict.pop("name")
vectordb_type = vectordb_dict.pop("db_type")
vectordb = load_vectordb(vectordb_type, **vectordb_dict)
result_vectordbs.append(vectordb)
return result_vectordbs
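
The loaders above resolve the `db_type` key of each entry in a YAML `vectordb` list to a concrete store class. A minimal usage sketch follows; the config path, entry name, and collection values are hypothetical placeholders, not part of this commit:

import asyncio

from autorag.vectordb import load_vectordb_from_yaml

# Assumed contents of config.yaml (illustrative only):
# vectordb:
#   - name: my_chroma
#     db_type: chroma
#     client_type: persistent
#     embedding_model: openai
#     collection_name: demo
#     path: ./resources/chroma

store = load_vectordb_from_yaml("config.yaml", "my_chroma", project_dir=".")
asyncio.run(store.add(["doc-1"], ["hello vector world"]))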

autorag/vectordb/base.py (new file)

@@ -0,0 +1,66 @@
from abc import abstractmethod
from typing import List, Tuple, Union
from llama_index.embeddings.openai import OpenAIEmbedding
from autorag.utils.util import openai_truncate_by_token
from autorag.embedding.base import EmbeddingModel
class BaseVectorStore:
support_similarity_metrics = ["l2", "ip", "cosine"]
def __init__(
self,
embedding_model: Union[str, List[dict]],
similarity_metric: str = "cosine",
embedding_batch: int = 100,
):
self.embedding = EmbeddingModel.load(embedding_model)()
self.embedding_batch = embedding_batch
self.embedding.embed_batch_size = embedding_batch
        assert (
            similarity_metric in self.support_similarity_metrics
        ), f"similarity metric {similarity_metric} is not supported"
self.similarity_metric = similarity_metric
@abstractmethod
async def add(
self,
ids: List[str],
texts: List[str],
):
pass
@abstractmethod
async def query(
self, queries: List[str], top_k: int, **kwargs
) -> Tuple[List[List[str]], List[List[float]]]:
pass
@abstractmethod
async def fetch(self, ids: List[str]) -> List[List[float]]:
"""
Fetch the embeddings of the ids.
"""
pass
@abstractmethod
async def is_exist(self, ids: List[str]) -> List[bool]:
"""
Check if the ids exist in the Vector DB.
"""
pass
@abstractmethod
async def delete(self, ids: List[str]):
pass
def truncated_inputs(self, inputs: List[str]) -> List[str]:
if isinstance(self.embedding, OpenAIEmbedding):
openai_embedding_limit = 8000
results = openai_truncate_by_token(
inputs, openai_embedding_limit, self.embedding.model_name
)
return results
return inputs
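
Concrete backends subclass BaseVectorStore and implement the five async methods. A minimal in-memory sketch, assuming only the autorag modules above are importable (brute-force cosine ranking, illustration only, not part of this commit):

import math
from typing import List, Tuple

from autorag.vectordb.base import BaseVectorStore

class InMemoryVectorStore(BaseVectorStore):
    def __init__(self, embedding_model, similarity_metric: str = "cosine", embedding_batch: int = 100):
        super().__init__(embedding_model, similarity_metric, embedding_batch)
        self._store = {}  # id -> embedding

    async def add(self, ids: List[str], texts: List[str]):
        embeddings = await self.embedding.aget_text_embedding_batch(self.truncated_inputs(texts))
        self._store.update(zip(ids, embeddings))

    async def query(
        self, queries: List[str], top_k: int, **kwargs
    ) -> Tuple[List[List[str]], List[List[float]]]:
        embeddings = await self.embedding.aget_text_embedding_batch(self.truncated_inputs(queries))

        def cosine(a, b):
            dot = sum(x * y for x, y in zip(a, b))
            return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(x * x for x in b)))

        result_ids, scores = [], []
        for q in embeddings:
            ranked = sorted(self._store.items(), key=lambda kv: cosine(q, kv[1]), reverse=True)[:top_k]
            result_ids.append([k for k, _ in ranked])
            scores.append([cosine(q, v) for _, v in ranked])
        return result_ids, scores

    async def fetch(self, ids: List[str]) -> List[List[float]]:
        return [self._store[_id] for _id in ids]

    async def is_exist(self, ids: List[str]) -> List[bool]:
        return [_id in self._store for _id in ids]

    async def delete(self, ids: List[str]):
        for _id in ids:
            self._store.pop(_id, None)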

autorag/vectordb/chroma.py (new file)

@@ -0,0 +1,117 @@
from typing import List, Optional, Dict, Tuple, Union
from chromadb import (
EphemeralClient,
PersistentClient,
DEFAULT_TENANT,
DEFAULT_DATABASE,
CloudClient,
AsyncHttpClient,
)
from chromadb.api.models.AsyncCollection import AsyncCollection
from chromadb.api.types import IncludeEnum, QueryResult
from autorag.utils.util import apply_recursive
from autorag.vectordb.base import BaseVectorStore
class Chroma(BaseVectorStore):
def __init__(
self,
embedding_model: Union[str, List[dict]],
collection_name: str,
embedding_batch: int = 100,
client_type: str = "persistent",
similarity_metric: str = "cosine",
        path: Optional[str] = None,
host: str = "localhost",
port: int = 8000,
ssl: bool = False,
headers: Optional[Dict[str, str]] = None,
api_key: Optional[str] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
):
super().__init__(embedding_model, similarity_metric, embedding_batch)
if client_type == "ephemeral":
self.client = EphemeralClient(tenant=tenant, database=database)
elif client_type == "persistent":
assert path is not None, "path must be provided for persistent client"
self.client = PersistentClient(path=path, tenant=tenant, database=database)
elif client_type == "http":
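            # NOTE (assumption): in recent chromadb releases AsyncHttpClient is an
            # async factory, so the value assigned below may be a coroutine that
            # needs to be awaited before collection calls will work.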
self.client = AsyncHttpClient(
host=host,
port=port,
ssl=ssl,
headers=headers,
tenant=tenant,
database=database,
)
elif client_type == "cloud":
self.client = CloudClient(
tenant=tenant,
database=database,
api_key=api_key,
)
else:
raise ValueError(
f"client_type {client_type} is not supported\n"
"supported client types are: ephemeral, persistent, http, cloud"
)
self.collection = self.client.get_or_create_collection(
name=collection_name,
metadata={"hnsw:space": similarity_metric},
)
async def add(self, ids: List[str], texts: List[str]):
texts = self.truncated_inputs(texts)
text_embeddings = await self.embedding.aget_text_embedding_batch(texts)
if isinstance(self.collection, AsyncCollection):
await self.collection.add(ids=ids, embeddings=text_embeddings)
else:
self.collection.add(ids=ids, embeddings=text_embeddings)
async def fetch(self, ids: List[str]) -> List[List[float]]:
if isinstance(self.collection, AsyncCollection):
fetch_result = await self.collection.get(
ids, include=[IncludeEnum.embeddings]
)
else:
fetch_result = self.collection.get(ids, include=[IncludeEnum.embeddings])
fetch_embeddings = fetch_result["embeddings"]
return fetch_embeddings
async def is_exist(self, ids: List[str]) -> List[bool]:
if isinstance(self.collection, AsyncCollection):
fetched_result = await self.collection.get(ids, include=[])
else:
fetched_result = self.collection.get(ids, include=[])
existed_ids = fetched_result["ids"]
return list(map(lambda x: x in existed_ids, ids))
async def query(
self, queries: List[str], top_k: int, **kwargs
) -> Tuple[List[List[str]], List[List[float]]]:
queries = self.truncated_inputs(queries)
query_embeddings: List[
List[float]
] = await self.embedding.aget_text_embedding_batch(queries)
if isinstance(self.collection, AsyncCollection):
query_result: QueryResult = await self.collection.query(
query_embeddings=query_embeddings, n_results=top_k
)
else:
query_result: QueryResult = self.collection.query(
query_embeddings=query_embeddings, n_results=top_k
)
ids = query_result["ids"]
scores = query_result["distances"]
scores = apply_recursive(lambda x: 1 - x, scores)
return ids, scores
async def delete(self, ids: List[str]):
if isinstance(self.collection, AsyncCollection):
await self.collection.delete(ids)
else:
self.collection.delete(ids)
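
An end-to-end sketch with the ephemeral (in-memory) client; it assumes OPENAI_API_KEY is set for the "openai" embedding model, and the ids, texts, and collection name are placeholders:

import asyncio

from autorag.vectordb.chroma import Chroma

async def main():
    store = Chroma(
        embedding_model="openai",
        collection_name="demo",
        client_type="ephemeral",  # in-memory; no path required
    )
    await store.add(["a", "b"], ["first document", "second document"])
    ids, scores = await store.query(["first"], top_k=1)
    print(ids, scores)  # scores are 1 - distance, so higher means more similar

asyncio.run(main())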

autorag/vectordb/couchbase.py (new file)

@@ -0,0 +1,218 @@
import logging
from datetime import timedelta
from couchbase.auth import PasswordAuthenticator
from couchbase.cluster import Cluster
from couchbase.options import ClusterOptions
from typing import List, Tuple, Optional, Union
from autorag.utils.util import make_batch
from autorag.vectordb import BaseVectorStore
logger = logging.getLogger("AutoRAG")
class Couchbase(BaseVectorStore):
def __init__(
self,
embedding_model: Union[str, List[dict]],
bucket_name: str,
scope_name: str,
collection_name: str,
index_name: str,
embedding_batch: int = 100,
connection_string: str = "",
username: str = "",
password: str = "",
ingest_batch: int = 100,
text_key: Optional[str] = "text",
embedding_key: Optional[str] = "embedding",
scoped_index: bool = True,
):
super().__init__(
embedding_model=embedding_model,
similarity_metric="ip",
embedding_batch=embedding_batch,
)
self.index_name = index_name
self.bucket_name = bucket_name
self.scope_name = scope_name
self.collection_name = collection_name
self.scoped_index = scoped_index
self.text_key = text_key
self.embedding_key = embedding_key
self.ingest_batch = ingest_batch
auth = PasswordAuthenticator(username, password)
self.cluster = Cluster(connection_string, ClusterOptions(auth))
# Wait until the cluster is ready for use.
self.cluster.wait_until_ready(timedelta(seconds=5))
# Check if the bucket exists
if not self._check_bucket_exists():
raise ValueError(
f"Bucket {self.bucket_name} does not exist. "
                "Please create the bucket before searching."
)
try:
self.bucket = self.cluster.bucket(self.bucket_name)
self.scope = self.bucket.scope(self.scope_name)
self.collection = self.scope.collection(self.collection_name)
except Exception as e:
raise ValueError(
"Error connecting to couchbase. "
"Please check the connection and credentials."
) from e
        # Check if the index exists; raises ValueError if it doesn't
        self._check_index_exists()
# Reinitialize to ensure a consistent state
self.bucket = self.cluster.bucket(self.bucket_name)
self.scope = self.bucket.scope(self.scope_name)
self.collection = self.scope.collection(self.collection_name)
async def add(self, ids: List[str], texts: List[str]):
from couchbase.exceptions import DocumentExistsException
texts = self.truncated_inputs(texts)
text_embeddings: List[
List[float]
] = await self.embedding.aget_text_embedding_batch(texts)
documents_to_insert = []
for _id, text, embedding in zip(ids, texts, text_embeddings):
doc = {
self.text_key: text,
self.embedding_key: embedding,
}
documents_to_insert.append({_id: doc})
batch_documents_to_insert = make_batch(documents_to_insert, self.ingest_batch)
for batch in batch_documents_to_insert:
insert_batch = {}
for doc in batch:
insert_batch.update(doc)
try:
self.collection.upsert_multi(insert_batch)
except DocumentExistsException as e:
logger.debug(f"Document already exists: {e}")
async def fetch(self, ids: List[str]) -> List[List[float]]:
# Fetch vectors by IDs
fetched_result = self.collection.get_multi(ids)
        fetched_vectors = {
            k: v.value[self.embedding_key]
            for k, v in fetched_result.results.items()
        }
return list(map(lambda x: fetched_vectors[x], ids))
async def is_exist(self, ids: List[str]) -> List[bool]:
existed_result = self.collection.exists_multi(ids)
existed_ids = {k: v.exists for k, v in existed_result.results.items()}
return list(map(lambda x: existed_ids[x], ids))
async def query(
self, queries: List[str], top_k: int, **kwargs
) -> Tuple[List[List[str]], List[List[float]]]:
import couchbase.search as search
from couchbase.options import SearchOptions
from couchbase.vector_search import VectorQuery, VectorSearch
queries = self.truncated_inputs(queries)
query_embeddings: List[
List[float]
] = await self.embedding.aget_text_embedding_batch(queries)
ids, scores = [], []
for query_embedding in query_embeddings:
# Create Search Request
search_req = search.SearchRequest.create(
VectorSearch.from_vector_query(
VectorQuery(
self.embedding_key,
query_embedding,
top_k,
)
)
)
# Search
if self.scoped_index:
search_iter = self.scope.search(
self.index_name,
search_req,
SearchOptions(limit=top_k),
)
else:
search_iter = self.cluster.search(
self.index_name,
search_req,
SearchOptions(limit=top_k),
)
# Parse the search results
# search_iter.rows() can only be iterated once.
id_list, score_list = [], []
for result in search_iter.rows():
id_list.append(result.id)
score_list.append(result.score)
ids.append(id_list)
scores.append(score_list)
return ids, scores
async def delete(self, ids: List[str]):
self.collection.remove_multi(ids)
def _check_bucket_exists(self) -> bool:
"""Check if the bucket exists in the linked Couchbase cluster.
Returns:
True if the bucket exists
"""
bucket_manager = self.cluster.buckets()
try:
bucket_manager.get_bucket(self.bucket_name)
return True
        except Exception as e:
            logger.debug("Error checking if bucket exists: %s", e)
            return False
def _check_index_exists(self) -> bool:
"""Check if the Search index exists in the linked Couchbase cluster
Returns:
bool: True if the index exists, False otherwise.
Raises a ValueError if the index does not exist.
"""
if self.scoped_index:
all_indexes = [
index.name for index in self.scope.search_indexes().get_all_indexes()
]
if self.index_name not in all_indexes:
raise ValueError(
f"Index {self.index_name} does not exist. "
                    "Please create the index before searching."
)
else:
all_indexes = [
index.name for index in self.cluster.search_indexes().get_all_indexes()
]
if self.index_name not in all_indexes:
raise ValueError(
f"Index {self.index_name} does not exist. "
                    "Please create the index before searching."
)
return True
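
A construction sketch; every connection value below is a placeholder, and the named Search index (with a vector field under the embedding key) must already exist, since the constructor validates both the bucket and the index. Note that this class pins similarity_metric to "ip" when calling the base constructor:

from autorag.vectordb.couchbase import Couchbase

store = Couchbase(
    embedding_model="openai",
    bucket_name="autorag",                     # must already exist
    scope_name="_default",
    collection_name="_default",
    index_name="autorag-vector-index",         # existing Search index with a vector field
    connection_string="couchbase://localhost",
    username="Administrator",
    password="password",
)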

autorag/vectordb/milvus.py (new file)

@@ -0,0 +1,168 @@
import logging
from typing import Any, Dict, List, Tuple, Optional, Union
from pymilvus import (
DataType,
FieldSchema,
CollectionSchema,
connections,
Collection,
MilvusException,
)
from pymilvus.orm import utility
from autorag.utils.util import apply_recursive
from autorag.vectordb import BaseVectorStore
logger = logging.getLogger("AutoRAG")
class Milvus(BaseVectorStore):
def __init__(
self,
embedding_model: Union[str, List[dict]],
collection_name: str,
embedding_batch: int = 100,
similarity_metric: str = "cosine",
index_type: str = "IVF_FLAT",
uri: str = "http://localhost:19530",
db_name: str = "",
token: str = "",
user: str = "",
password: str = "",
timeout: Optional[float] = None,
        params: Optional[Dict[str, Any]] = None,
):
super().__init__(embedding_model, similarity_metric, embedding_batch)
# Connect to Milvus server
connections.connect(
"default",
uri=uri,
token=token,
db_name=db_name,
user=user,
password=password,
)
self.collection_name = collection_name
self.timeout = timeout
        self.params = params or {}
self.index_type = index_type
# Set Collection
if not utility.has_collection(collection_name, timeout=timeout):
# Get the dimension of the embeddings
test_embedding_result: List[float] = self.embedding.get_query_embedding(
"test"
)
dimension = len(test_embedding_result)
pk = FieldSchema(
name="id",
dtype=DataType.VARCHAR,
max_length=128,
is_primary=True,
auto_id=False,
)
field = FieldSchema(
name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension
)
schema = CollectionSchema(fields=[pk, field])
self.collection = Collection(name=self.collection_name, schema=schema)
index_params = {
"metric_type": self.similarity_metric.upper(),
"index_type": self.index_type.upper(),
"params": self.params,
}
self.collection.create_index(
field_name="vector", index_params=index_params, timeout=self.timeout
)
else:
self.collection = Collection(name=self.collection_name)
async def add(self, ids: List[str], texts: List[str]):
texts = self.truncated_inputs(texts)
text_embeddings: List[
List[float]
] = await self.embedding.aget_text_embedding_batch(texts)
# make data for insertion
data = list(
map(lambda _id, vector: {"id": _id, "vector": vector}, ids, text_embeddings)
)
# Insert data into the collection
res = self.collection.insert(data=data, timeout=self.timeout)
        assert res.insert_count == len(ids), (
            f"Insertion failed. Tried to insert {len(ids)} rows, "
            f"but only {res.insert_count} were inserted."
        )
self.collection.flush(timeout=self.timeout)
async def query(
self, queries: List[str], top_k: int, **kwargs
) -> Tuple[List[List[str]], List[List[float]]]:
queries = self.truncated_inputs(queries)
query_embeddings: List[
List[float]
] = await self.embedding.aget_text_embedding_batch(queries)
self.collection.load(timeout=self.timeout)
# Perform similarity search
results = self.collection.search(
data=query_embeddings,
limit=top_k,
anns_field="vector",
param={"metric_type": self.similarity_metric.upper()},
timeout=self.timeout,
**kwargs,
)
# Extract IDs and distances
ids = [[str(hit.id) for hit in result] for result in results]
distances = [[hit.distance for hit in result] for result in results]
if self.similarity_metric in ["l2"]:
distances = apply_recursive(lambda x: -x, distances)
return ids, distances
async def fetch(self, ids: List[str]) -> List[List[float]]:
try:
self.collection.load(timeout=self.timeout)
except MilvusException as e:
logger.warning(f"Failed to load collection: {e}")
return [[]] * len(ids)
# Fetch vectors by IDs
results = self.collection.query(
expr=f"id in {ids}", output_fields=["id", "vector"], timeout=self.timeout
)
id_vector_dict = {str(result["id"]): result["vector"] for result in results}
result = [id_vector_dict[_id] for _id in ids]
return result
async def is_exist(self, ids: List[str]) -> List[bool]:
try:
self.collection.load(timeout=self.timeout)
except MilvusException:
return [False] * len(ids)
# Check the existence of IDs
results = self.collection.query(
expr=f"id in {ids}", output_fields=["id"], timeout=self.timeout
)
# Determine existence
existing_ids = {str(result["id"]) for result in results}
return [str(_id) in existing_ids for _id in ids]
async def delete(self, ids: List[str]):
# Delete entries by IDs
self.collection.delete(expr=f"id in {ids}", timeout=self.timeout)
def delete_collection(self):
# Delete the collection
self.collection.release(timeout=self.timeout)
self.collection.drop_index(timeout=self.timeout)
self.collection.drop(timeout=self.timeout)
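
A usage sketch against a local Milvus at the default URI; the collection name and texts are made up. With "l2", raw distances are negated on the way out so that a larger score always means a better match:

import asyncio

from autorag.vectordb.milvus import Milvus

async def main():
    store = Milvus(
        embedding_model="openai",
        collection_name="demo",
        similarity_metric="l2",
    )
    await store.add(["1", "2"], ["alpha", "beta"])
    ids, scores = await store.query(["alpha"], top_k=1)
    print(ids, scores)  # scores are negated L2 distances

asyncio.run(main())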

autorag/vectordb/pinecone.py (new file)

@@ -0,0 +1,119 @@
import logging
from pinecone.grpc import PineconeGRPC as Pinecone_client
from pinecone import ServerlessSpec
from typing import List, Optional, Tuple, Union
from autorag.utils.util import make_batch, apply_recursive
from autorag.vectordb import BaseVectorStore
logger = logging.getLogger("AutoRAG")
class Pinecone(BaseVectorStore):
def __init__(
self,
embedding_model: Union[str, List[dict]],
index_name: str,
embedding_batch: int = 100,
dimension: int = 1536,
similarity_metric: str = "cosine", # "cosine", "dotproduct", "euclidean"
cloud: Optional[str] = "aws",
region: Optional[str] = "us-east-1",
api_key: Optional[str] = None,
deletion_protection: Optional[str] = "disabled", # "enabled" or "disabled"
namespace: Optional[str] = "default",
ingest_batch: int = 200,
):
super().__init__(embedding_model, similarity_metric, embedding_batch)
self.index_name = index_name
self.namespace = namespace
self.ingest_batch = ingest_batch
self.client = Pinecone_client(api_key=api_key)
if similarity_metric == "ip":
similarity_metric = "dotproduct"
elif similarity_metric == "l2":
similarity_metric = "euclidean"
if not self.client.has_index(index_name):
self.client.create_index(
name=index_name,
dimension=dimension,
metric=similarity_metric,
spec=ServerlessSpec(
cloud=cloud,
region=region,
),
deletion_protection=deletion_protection,
)
self.index = self.client.Index(index_name)
async def add(self, ids: List[str], texts: List[str]):
texts = self.truncated_inputs(texts)
text_embeddings: List[
List[float]
] = await self.embedding.aget_text_embedding_batch(texts)
vector_tuples = list(zip(ids, text_embeddings))
batch_vectors = make_batch(vector_tuples, self.ingest_batch)
async_res = [
self.index.upsert(
vectors=batch_vector_tuples,
namespace=self.namespace,
async_req=True,
)
for batch_vector_tuples in batch_vectors
]
# Wait for the async requests to finish
[async_result.result() for async_result in async_res]
async def fetch(self, ids: List[str]) -> List[List[float]]:
results = self.index.fetch(ids=ids, namespace=self.namespace)
id_vector_dict = {
str(key): val["values"] for key, val in results["vectors"].items()
}
result = [id_vector_dict[_id] for _id in ids]
return result
async def is_exist(self, ids: List[str]) -> List[bool]:
fetched_result = self.index.fetch(ids=ids, namespace=self.namespace)
existed_ids = list(map(str, fetched_result.get("vectors", {}).keys()))
return list(map(lambda x: x in existed_ids, ids))
async def query(
self, queries: List[str], top_k: int, **kwargs
) -> Tuple[List[List[str]], List[List[float]]]:
queries = self.truncated_inputs(queries)
query_embeddings: List[
List[float]
] = await self.embedding.aget_text_embedding_batch(queries)
ids, scores = [], []
for query_embedding in query_embeddings:
response = self.index.query(
vector=query_embedding,
top_k=top_k,
include_values=True,
namespace=self.namespace,
)
ids.append([o.id for o in response.matches])
scores.append([o.score for o in response.matches])
if self.similarity_metric in ["l2"]:
scores = apply_recursive(lambda x: -x, scores)
return ids, scores
async def delete(self, ids: List[str]):
# Delete entries by IDs
self.index.delete(ids=ids, namespace=self.namespace)
def delete_index(self):
# Delete the index
self.client.delete_index(self.index_name)
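
A construction sketch; the API key and index name are placeholders. The constructor maps the shared metric names onto Pinecone's: "ip" becomes "dotproduct" and "l2" becomes "euclidean", and a serverless index is created automatically when none exists:

from autorag.vectordb.pinecone import Pinecone

store = Pinecone(
    embedding_model="openai",
    index_name="autorag-demo",        # created on first use if absent
    similarity_metric="ip",           # stored as Pinecone's "dotproduct"
    api_key="YOUR_PINECONE_API_KEY",  # placeholder
)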

autorag/vectordb/qdrant.py (new file)

@@ -0,0 +1,153 @@
import logging
from qdrant_client import QdrantClient
from qdrant_client.models import (
Distance,
VectorParams,
PointStruct,
PointIdsList,
HasIdCondition,
Filter,
SearchRequest,
)
from typing import List, Tuple, Union
from autorag.vectordb import BaseVectorStore
logger = logging.getLogger("AutoRAG")
class Qdrant(BaseVectorStore):
def __init__(
self,
embedding_model: Union[str, List[dict]],
collection_name: str,
embedding_batch: int = 100,
similarity_metric: str = "cosine",
client_type: str = "docker",
url: str = "http://localhost:6333",
host: str = "",
api_key: str = "",
dimension: int = 1536,
ingest_batch: int = 64,
parallel: int = 1,
max_retries: int = 3,
):
super().__init__(embedding_model, similarity_metric, embedding_batch)
self.collection_name = collection_name
self.ingest_batch = ingest_batch
self.parallel = parallel
self.max_retries = max_retries
if similarity_metric == "cosine":
distance = Distance.COSINE
elif similarity_metric == "ip":
distance = Distance.DOT
elif similarity_metric == "l2":
distance = Distance.EUCLID
else:
raise ValueError(
f"similarity_metric {similarity_metric} is not supported\n"
"supported similarity metrics are: cosine, ip, l2"
)
if client_type == "docker":
self.client = QdrantClient(
url=url,
)
elif client_type == "cloud":
self.client = QdrantClient(
host=host,
api_key=api_key,
)
else:
raise ValueError(
f"client_type {client_type} is not supported\n"
"supported client types are: docker, cloud"
)
if not self.client.collection_exists(collection_name):
self.client.create_collection(
collection_name,
vectors_config=VectorParams(
size=dimension,
distance=distance,
),
)
self.collection = self.client.get_collection(collection_name)
async def add(self, ids: List[str], texts: List[str]):
texts = self.truncated_inputs(texts)
text_embeddings = await self.embedding.aget_text_embedding_batch(texts)
points = list(
map(lambda x: PointStruct(id=x[0], vector=x[1]), zip(ids, text_embeddings))
)
self.client.upload_points(
collection_name=self.collection_name,
points=points,
batch_size=self.ingest_batch,
parallel=self.parallel,
max_retries=self.max_retries,
wait=True,
)
async def fetch(self, ids: List[str]) -> List[List[float]]:
# Fetch vectors by IDs
fetched_results = self.client.retrieve(
collection_name=self.collection_name,
ids=ids,
with_vectors=True,
)
return list(map(lambda x: x.vector, fetched_results))
async def is_exist(self, ids: List[str]) -> List[bool]:
existed_result = self.client.scroll(
collection_name=self.collection_name,
scroll_filter=Filter(
must=[
HasIdCondition(has_id=ids),
],
),
)
# existed_result is tuple. So we use existed_result[0] to get list of Record
existed_ids = list(map(lambda x: x.id, existed_result[0]))
return list(map(lambda x: x in existed_ids, ids))
async def query(
self, queries: List[str], top_k: int, **kwargs
) -> Tuple[List[List[str]], List[List[float]]]:
queries = self.truncated_inputs(queries)
query_embeddings: List[
List[float]
] = await self.embedding.aget_text_embedding_batch(queries)
search_queries = list(
map(
lambda x: SearchRequest(vector=x, limit=top_k, with_vector=True),
query_embeddings,
)
)
search_result = self.client.search_batch(
collection_name=self.collection_name, requests=search_queries
)
# Extract IDs and distances
ids = [[str(hit.id) for hit in result] for result in search_result]
scores = [[hit.score for hit in result] for result in search_result]
return ids, scores
async def delete(self, ids: List[str]):
self.client.delete(
collection_name=self.collection_name,
points_selector=PointIdsList(points=ids),
)
def delete_collection(self):
# Delete the collection
self.client.delete_collection(self.collection_name)
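
A usage sketch against a local Qdrant with the default "docker" client type. Qdrant accepts only unsigned integers or UUID strings as point IDs, so the ids below are generated UUIDs:

import asyncio
import uuid

from autorag.vectordb.qdrant import Qdrant

async def main():
    store = Qdrant(embedding_model="openai", collection_name="demo")
    ids = [str(uuid.uuid4()) for _ in range(2)]
    await store.add(ids, ["alpha", "beta"])
    print(await store.is_exist(ids))  # [True, True]
    await store.delete(ids)

asyncio.run(main())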

autorag/vectordb/weaviate.py (new file)

@@ -0,0 +1,167 @@
import logging
import weaviate
from weaviate.classes.init import Auth
from weaviate.classes.config import Property, DataType
import weaviate.classes as wvc
from weaviate.classes.query import MetadataQuery
from typing import List, Optional, Tuple, Union
from autorag.vectordb import BaseVectorStore
logger = logging.getLogger("AutoRAG")
class Weaviate(BaseVectorStore):
def __init__(
self,
embedding_model: Union[str, List[dict]],
collection_name: str,
embedding_batch: int = 100,
similarity_metric: str = "cosine",
client_type: str = "docker",
host: str = "localhost",
port: int = 8080,
grpc_port: int = 50051,
url: Optional[str] = None,
api_key: Optional[str] = None,
text_key: str = "content",
):
super().__init__(embedding_model, similarity_metric, embedding_batch)
self.text_key = text_key
if client_type == "docker":
self.client = weaviate.connect_to_local(
host=host,
port=port,
grpc_port=grpc_port,
)
elif client_type == "cloud":
self.client = weaviate.connect_to_weaviate_cloud(
cluster_url=url,
auth_credentials=Auth.api_key(api_key),
)
else:
raise ValueError(
f"client_type {client_type} is not supported\n"
"supported client types are: docker, cloud"
)
if similarity_metric == "cosine":
distance_metric = wvc.config.VectorDistances.COSINE
elif similarity_metric == "ip":
distance_metric = wvc.config.VectorDistances.DOT
elif similarity_metric == "l2":
distance_metric = wvc.config.VectorDistances.L2_SQUARED
else:
raise ValueError(
f"similarity_metric {similarity_metric} is not supported\n"
"supported similarity metrics are: cosine, ip, l2"
)
if not self.client.collections.exists(collection_name):
self.client.collections.create(
collection_name,
properties=[
Property(
name="content", data_type=DataType.TEXT, skip_vectorization=True
),
],
vectorizer_config=wvc.config.Configure.Vectorizer.none(),
vector_index_config=wvc.config.Configure.VectorIndex.hnsw( # hnsw, flat, dynamic,
distance_metric=distance_metric
),
)
self.collection = self.client.collections.get(collection_name)
self.collection_name = collection_name
async def add(self, ids: List[str], texts: List[str]):
texts = self.truncated_inputs(texts)
text_embeddings = await self.embedding.aget_text_embedding_batch(texts)
with self.client.batch.dynamic() as batch:
for i, text in enumerate(texts):
data_properties = {self.text_key: text}
batch.add_object(
collection=self.collection_name,
properties=data_properties,
uuid=ids[i],
vector=text_embeddings[i],
)
failed_objs = self.client.batch.failed_objects
for obj in failed_objs:
err_message = (
f"Failed to add object: {obj.original_uuid}\nReason: {obj.message}"
)
logger.error(err_message)
async def fetch(self, ids: List[str]) -> List[List[float]]:
# Fetch vectors by IDs
results = self.collection.query.fetch_objects(
filters=wvc.query.Filter.by_property("_id").contains_any(ids),
include_vector=True,
)
        id_vector_dict = {
            str(obj.uuid): obj.vector["default"] for obj in results.objects
        }
result = [id_vector_dict[_id] for _id in ids]
return result
async def is_exist(self, ids: List[str]) -> List[bool]:
fetched_result = self.collection.query.fetch_objects(
filters=wvc.query.Filter.by_property("_id").contains_any(ids),
)
existed_ids = [str(result.uuid) for result in fetched_result.objects]
return list(map(lambda x: x in existed_ids, ids))
async def query(
self, queries: List[str], top_k: int, **kwargs
) -> Tuple[List[List[str]], List[List[float]]]:
queries = self.truncated_inputs(queries)
query_embeddings: List[
List[float]
] = await self.embedding.aget_text_embedding_batch(queries)
ids, scores = [], []
for query_embedding in query_embeddings:
response = self.collection.query.near_vector(
near_vector=query_embedding,
limit=top_k,
return_metadata=MetadataQuery(distance=True),
)
            ids.append([str(o.uuid) for o in response.objects])
scores.append(
[
distance_to_score(o.metadata.distance, self.similarity_metric)
for o in response.objects
]
)
return ids, scores
    async def delete(self, ids: List[str]):
        id_filter = wvc.query.Filter.by_id().contains_any(ids)
        self.collection.data.delete_many(where=id_filter)
def delete_collection(self):
# Delete the collection
self.client.collections.delete(self.collection_name)
def distance_to_score(distance: float, similarity_metric: str) -> float:
if similarity_metric == "cosine":
return 1 - distance
elif similarity_metric == "ip":
return -distance
elif similarity_metric == "l2":
return -distance
else:
raise ValueError(
f"similarity_metric {similarity_metric} is not supported\n"
"supported similarity metrics are: cosine, ip, l2"
)
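
A quick check of the conversion above (values are illustrative): cosine distance maps to similarity via 1 - d, while "ip" and "l2" distances are negated so ranking keeps a larger-is-better convention:

assert distance_to_score(0.2, "cosine") == 0.8
assert distance_to_score(1.5, "l2") == -1.5
assert distance_to_score(-3.0, "ip") == 3.0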