Skip to content

Commit 76a0a70

Browse files
committed
Added more functions to SparseVector
1 parent a619f7d commit 76a0a70

File tree

2 files changed

+56
-3
lines changed

2 files changed

+56
-3
lines changed

lib/pgvector/sparse_vector.ex

+40-2
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,51 @@ defmodule Pgvector.SparseVector do
6060
def from_binary(binary) when is_binary(binary) do
6161
%Pgvector.SparseVector{data: binary}
6262
end
63+
64+
@doc """
65+
Returns the dimensions
66+
"""
67+
def dimensions(vector) when is_struct(vector, Pgvector.SparseVector) do
68+
<<dim::signed-32, _::binary>> = vector.data
69+
dim
70+
end
71+
72+
@doc """
73+
Returns the indices
74+
"""
75+
def indices(vector) when is_struct(vector, Pgvector.SparseVector) do
76+
<<_::signed-32, nnz::signed-32, 0::signed-32, indices::binary-size(nnz)-unit(32),
77+
_::binary-size(nnz)-unit(32)>> = vector.data
78+
79+
for <<v::signed-32 <- indices>>, do: v
80+
end
81+
82+
@doc """
83+
Returns the values
84+
"""
85+
def values(vector) when is_struct(vector, Pgvector.SparseVector) do
86+
<<_::signed-32, nnz::signed-32, 0::signed-32, _::binary-size(nnz)-unit(32),
87+
values::binary-size(nnz)-unit(32)>> = vector.data
88+
89+
for <<v::float-32 <- values>>, do: v
90+
end
6391
end
6492

6593
defimpl Inspect, for: Pgvector.SparseVector do
6694
import Inspect.Algebra
6795

6896
def inspect(vector, opts) do
69-
# TODO improve
70-
concat(["Pgvector.SparseVector.new(", Inspect.List.inspect(Pgvector.to_list(vector), opts), ")"])
97+
dimensions = vector |> Pgvector.SparseVector.dimensions()
98+
indices = vector |> Pgvector.SparseVector.indices()
99+
values = vector |> Pgvector.SparseVector.values()
100+
elements = Enum.zip(indices, values) |> Enum.into(%{})
101+
102+
concat([
103+
"Pgvector.SparseVector.new(",
104+
Inspect.Map.inspect(elements, opts),
105+
", ",
106+
Inspect.Integer.inspect(dimensions, opts),
107+
")"
108+
])
71109
end
72110
end

test/sparse_vector_test.exs

+16-1
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,24 @@ defmodule SparseVectorTest do
2121
assert [1.0, 0.0, 2.0, 0.0, 3.0, 0.0] == map |> Pgvector.SparseVector.new(6) |> Pgvector.to_list()
2222
end
2323

24+
test "dimensions" do
25+
vector = Pgvector.SparseVector.new([1, 2, 3])
26+
assert 3 == vector |> Pgvector.SparseVector.dimensions()
27+
end
28+
29+
test "indices" do
30+
vector = Pgvector.SparseVector.new([1, 2, 3])
31+
assert [0, 1, 2] == vector |> Pgvector.SparseVector.indices()
32+
end
33+
34+
test "values" do
35+
vector = Pgvector.SparseVector.new([1, 2, 3])
36+
assert [1, 2, 3] == vector |> Pgvector.SparseVector.values()
37+
end
38+
2439
test "inspect" do
2540
vector = Pgvector.SparseVector.new([1, 2, 3])
26-
assert "Pgvector.SparseVector.new([1.0, 2.0, 3.0])" == inspect(vector)
41+
assert "Pgvector.SparseVector.new(%{0 => 1.0, 1 => 2.0, 2 => 3.0}, 3)" == inspect(vector)
2742
end
2843

2944
test "equals" do

0 commit comments

Comments
 (0)