-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathEntityFrameworkCoreTests.cs
144 lines (109 loc) · 6.92 KB
/
EntityFrameworkCoreTests.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Storage;
using Pgvector.EntityFrameworkCore;
using System.Collections;
using System.ComponentModel.DataAnnotations.Schema;
namespace Pgvector.Tests;
/// <summary>
/// EF Core context used by the tests in this file. It connects to the local
/// <c>pgvector_dotnet_test</c> database through Npgsql with the vector plugin enabled.
/// </summary>
public class ItemContext : DbContext
{
    /// <summary>Set of <see cref="Item"/> entities exposed by this context.</summary>
    public DbSet<Item> Items { get; set; }

    /// <summary>Points the context at the test database and enables vector support on the provider.</summary>
    protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder)
        => optionsBuilder.UseNpgsql(
            "Host=localhost;Database=pgvector_dotnet_test",
            npgsqlOptions => npgsqlOptions.UseVector());

    /// <summary>
    /// Declares the <c>vector</c> PostgreSQL extension and an index on
    /// <see cref="Item.Embedding"/> configured with the "hnsw" method.
    /// </summary>
    protected override void OnModelCreating(ModelBuilder modelBuilder)
    {
        modelBuilder.HasPostgresExtension("vector");

        var embeddingIndex = modelBuilder
            .Entity<Item>()
            .HasIndex(item => item.Embedding);

        // Index method, operator class and the two storage parameters, in the original order.
        embeddingIndex
            .HasMethod("hnsw")
            .HasOperators("vector_l2_ops")
            .HasStorageParameter("m", 16)
            .HasStorageParameter("ef_construction", 64);
    }
}
/// <summary>
/// Entity mapped to the <c>efcore_items</c> table. It carries one column for each
/// embedding representation exercised by the tests; every column type is declared with size 3.
/// </summary>
[Table("efcore_items")]
public class Item
{
/// <summary>
/// Row identifier. Presumably picked up as the primary key by EF Core's <c>Id</c> naming
/// convention, since no explicit key configuration is visible here. The tests expect
/// the values 1, 2, 3 in insertion order.
/// </summary>
public int Id { get; set; }
/// <summary>Embedding stored in the <c>embedding</c> column, declared as <c>vector(3)</c>.</summary>
[Column("embedding", TypeName = "vector(3)")]
public Vector? Embedding { get; set; }
/// <summary>Embedding stored in the <c>half_embedding</c> column, declared as <c>halfvec(3)</c>.</summary>
[Column("half_embedding", TypeName = "halfvec(3)")]
public HalfVector? HalfEmbedding { get; set; }
/// <summary>Bit string stored in the <c>binary_embedding</c> column, declared as <c>bit(3)</c>.</summary>
[Column("binary_embedding", TypeName = "bit(3)")]
public BitArray? BinaryEmbedding { get; set; }
/// <summary>Embedding stored in the <c>sparse_embedding</c> column, declared as <c>sparsevec(3)</c>.</summary>
[Column("sparse_embedding", TypeName = "sparsevec(3)")]
public SparseVector? SparseEmbedding { get; set; }
}
/// <summary>
/// Integration test for the EF Core mappings and distance-function translations
/// defined for <see cref="Item"/>. It runs against the database configured in
/// <see cref="ItemContext"/>, so a reachable PostgreSQL server is required.
/// </summary>
public class EntityFrameworkCoreTests
{
    /// <summary>
    /// Recreates the <c>efcore_items</c> table, inserts three rows and checks raw SQL
    /// ordering, round-tripping of each column type, and the LINQ distance functions
    /// for the vector, halfvec, sparsevec and bit columns.
    /// </summary>
    /// <remarks>
    /// All database work is awaited; the method no longer blocks on synchronous
    /// EF Core calls inside an async test.
    /// </remarks>
    [Fact]
    public async Task Main()
    {
        await using var ctx = new ItemContext();

        // Start from a clean table so the generated ids asserted below (1, 2, 3) are predictable.
        await ctx.Database.ExecuteSqlAsync($"DROP TABLE IF EXISTS efcore_items");
        var databaseCreator = ctx.GetService<IRelationalDatabaseCreator>();
        await databaseCreator.CreateTablesAsync();

        ctx.Items.Add(new Item { Embedding = new Vector(new float[] { 1, 1, 1 }), HalfEmbedding = new HalfVector(new Half[] { (Half)1, (Half)1, (Half)1 }), BinaryEmbedding = new BitArray(new bool[] { false, false, false }), SparseEmbedding = new SparseVector(new float[] { 1, 1, 1 }) });
        ctx.Items.Add(new Item { Embedding = new Vector(new float[] { 2, 2, 2 }), HalfEmbedding = new HalfVector(new Half[] { (Half)2, (Half)2, (Half)2 }), BinaryEmbedding = new BitArray(new bool[] { true, false, true }), SparseEmbedding = new SparseVector(new float[] { 2, 2, 2 }) });
        ctx.Items.Add(new Item { Embedding = new Vector(new float[] { 1, 1, 2 }), HalfEmbedding = new HalfVector(new Half[] { (Half)1, (Half)1, (Half)2 }), BinaryEmbedding = new BitArray(new bool[] { true, true, true }), SparseEmbedding = new SparseVector(new float[] { 1, 1, 2 }) });
        await ctx.SaveChangesAsync();

        // raw SQL: the interpolated value is passed by FromSql as a query parameter
        var embedding = new Vector(new float[] { 1, 1, 1 });
        var items = await ctx.Items.FromSql($"SELECT * FROM efcore_items ORDER BY embedding <-> {embedding} LIMIT 5").ToListAsync();
        Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());
        Assert.Equal(new float[] { 1, 1, 1 }, items[0].Embedding!.ToArray());
        Assert.Equal(new Half[] { (Half)1, (Half)1, (Half)1 }, items[0].HalfEmbedding!.ToArray());
        Assert.Equal(new BitArray(new bool[] { false, false, false }), items[0].BinaryEmbedding!);
        Assert.Equal(new float[] { 1, 1, 1 }, items[0].SparseEmbedding!.ToArray());

        // vector distance functions
        items = await ctx.Items.OrderBy(x => x.Embedding!.L2Distance(embedding)).Take(5).ToListAsync();
        Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());
        Assert.Equal(new float[] { 1, 1, 1 }, items[0].Embedding!.ToArray());

        items = await ctx.Items.OrderBy(x => x.Embedding!.MaxInnerProduct(embedding)).Take(5).ToListAsync();
        Assert.Equal(new int[] { 2, 3, 1 }, items.Select(v => v.Id).ToArray());

        // Rows 1 and 2 are both parallel to the query vector, so their relative order is a tie;
        // only the last position (row 3) is asserted.
        items = await ctx.Items.OrderBy(x => x.Embedding!.CosineDistance(embedding)).Take(5).ToListAsync();
        Assert.Equal(3, items[2].Id);

        items = await ctx.Items.OrderBy(x => x.Embedding!.L1Distance(embedding)).Take(5).ToListAsync();
        Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());

        // halfvec distance functions
        var halfEmbedding = new HalfVector(new Half[] { (Half)1, (Half)1, (Half)1 });
        items = await ctx.Items.OrderBy(x => x.HalfEmbedding!.L2Distance(halfEmbedding)).Take(5).ToListAsync();
        Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());

        items = await ctx.Items.OrderBy(x => x.HalfEmbedding!.MaxInnerProduct(halfEmbedding)).Take(5).ToListAsync();
        Assert.Equal(new int[] { 2, 3, 1 }, items.Select(v => v.Id).ToArray());

        // Same tie between rows 1 and 2 as for the vector column.
        items = await ctx.Items.OrderBy(x => x.HalfEmbedding!.CosineDistance(halfEmbedding)).Take(5).ToListAsync();
        Assert.Equal(3, items[2].Id);

        items = await ctx.Items.OrderBy(x => x.HalfEmbedding!.L1Distance(halfEmbedding)).Take(5).ToListAsync();
        Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());

        // sparsevec distance functions
        var sparseEmbedding = new SparseVector(new float[] { 1, 1, 1 });
        items = await ctx.Items.OrderBy(x => x.SparseEmbedding!.L2Distance(sparseEmbedding)).Take(5).ToListAsync();
        Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());

        items = await ctx.Items.OrderBy(x => x.SparseEmbedding!.MaxInnerProduct(sparseEmbedding)).Take(5).ToListAsync();
        Assert.Equal(new int[] { 2, 3, 1 }, items.Select(v => v.Id).ToArray());

        // Same tie between rows 1 and 2 as for the vector column.
        items = await ctx.Items.OrderBy(x => x.SparseEmbedding!.CosineDistance(sparseEmbedding)).Take(5).ToListAsync();
        Assert.Equal(3, items[2].Id);

        items = await ctx.Items.OrderBy(x => x.SparseEmbedding!.L1Distance(sparseEmbedding)).Take(5).ToListAsync();
        Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());

        // bit distance functions
        var binaryEmbedding = new BitArray(new bool[] { true, false, true });
        items = await ctx.Items.OrderBy(x => x.BinaryEmbedding!.HammingDistance(binaryEmbedding)).Take(5).ToListAsync();
        Assert.Equal(new int[] { 2, 3, 1 }, items.Select(v => v.Id).ToArray());

        items = await ctx.Items.OrderBy(x => x.BinaryEmbedding!.JaccardDistance(binaryEmbedding)).Take(5).ToListAsync();
        Assert.Equal(new int[] { 2, 3, 1 }, items.Select(v => v.Id).ToArray());

        // additional: distance used in a filter
        items = await ctx.Items
            .OrderBy(x => x.Id)
            .Where(x => x.Embedding!.L2Distance(embedding) < 1.5)
            .ToListAsync();
        Assert.Equal(new int[] { 1, 3 }, items.Select(v => v.Id).ToArray());

        // additional: distance projected alongside the entity
        var neighbors = await ctx.Items
            .OrderBy(x => x.Embedding!.L2Distance(embedding))
            .Select(x => new { Entity = x, Distance = x.Embedding!.L2Distance(embedding) })
            .ToListAsync();
        Assert.Equal(new int[] { 1, 3, 2 }, neighbors.Select(v => v.Entity.Id).ToArray());

        // Distances are computed floating-point values, so compare with a precision
        // instead of exact equality.
        var expectedDistances = new double[] { 0, 1, Math.Sqrt(3) };
        var actualDistances = neighbors.Select(v => v.Distance).ToArray();
        Assert.Equal(expectedDistances.Length, actualDistances.Length);
        for (var i = 0; i < expectedDistances.Length; i++)
        {
            Assert.Equal(expectedDistances[i], actualDistances[i], 6);
        }
    }
}