Skip to content

Commit 57b7d3b

Browse files
committed
Added test for vector type with SQLAlchemy and asyncpg - #114
1 parent 972b673 commit 57b7d3b

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

tests/test_sqlalchemy.py

+23
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,29 @@ def connect(dbapi_connection, connection_record):
527527

528528
await engine.dispose()
529529

530+
@pytest.mark.asyncio
531+
@pytest.mark.skipif(sqlalchemy_version == 1, reason='Requires SQLAlchemy 2+')
532+
async def test_asyncpg_vector(self):
533+
import asyncpg
534+
535+
engine = create_async_engine('postgresql+asyncpg://localhost/pgvector_python_test')
536+
async_session = async_sessionmaker(engine, expire_on_commit=False)
537+
538+
# TODO do not throw error when types are registered
539+
# @event.listens_for(engine.sync_engine, "connect")
540+
# def connect(dbapi_connection, connection_record):
541+
# from pgvector.asyncpg import register_vector
542+
# dbapi_connection.run_async(register_vector)
543+
544+
async with async_session() as session:
545+
async with session.begin():
546+
embedding = np.array([1, 2, 3])
547+
session.add(Item(id=1, embedding=embedding))
548+
item = await session.get(Item, 1)
549+
assert np.array_equal(item.embedding, embedding)
550+
551+
await engine.dispose()
552+
530553
@pytest.mark.asyncio
531554
@pytest.mark.skipif(sqlalchemy_version == 1, reason='Requires SQLAlchemy 2+')
532555
async def test_asyncpg_bit(self):

0 commit comments

Comments
 (0)