-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathExample.java
69 lines (60 loc) · 2.95 KB
/
Example.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
package com.example;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import com.pgvector.PGvector;
import org.ankane.disco.Data;
import org.ankane.disco.Dataset;
import org.ankane.disco.Recommender;
/**
 * Fits a disco collaborative-filtering model on the MovieLens dataset, stores the learned user
 * and item factors in Postgres as pgvector columns, and then queries those vectors for item-based
 * and user-based recommendations.
 *
 * <p>Connects to {@code jdbc:postgresql://localhost:5432/pgvector_example}. The {@code users} and
 * {@code movies} tables are dropped and recreated on every run.
 */
public class Example {
    /**
     * Number of latent factors learned per user/item. It must equal the dimension of the
     * {@code vector(N)} columns, so both the model builder and the DDL read it from here.
     */
    private static final int FACTORS = 20;

    public static void main(String[] args) throws Exception {
        // try-with-resources guarantees the connection is closed even if a step below throws.
        try (Connection conn = DriverManager.getConnection("jdbc:postgresql://localhost:5432/pgvector_example")) {
            setupSchema(conn);

            Dataset<Integer, String> data = Data.loadMovieLens();
            Recommender<Integer, String> recommender = Recommender
                .builder()
                .factors(FACTORS)
                .fitExplicit(data);

            storeUserFactors(conn, recommender);
            storeItemFactors(conn, recommender);

            String movie = "Star Wars (1977)";
            System.out.printf("Item-based recommendations for %s\n", movie);
            printItemRecommendations(conn, movie);

            int userId = 123;
            System.out.printf("\nUser-based recommendations for user %d\n", userId);
            printUserRecommendations(conn, userId);
        }
    }

    /** Enables the vector extension, registers the vector type, and recreates both tables. */
    private static void setupSchema(Connection conn) throws SQLException {
        try (Statement stmt = conn.createStatement()) {
            stmt.executeUpdate("CREATE EXTENSION IF NOT EXISTS vector");
            // Registers the vector type on this connection so PGvector values can be bound as parameters.
            PGvector.addVectorType(conn);
            stmt.executeUpdate("DROP TABLE IF EXISTS users");
            stmt.executeUpdate("DROP TABLE IF EXISTS movies");
            stmt.executeUpdate("CREATE TABLE users (id integer PRIMARY KEY, factors vector(" + FACTORS + "))");
            stmt.executeUpdate("CREATE TABLE movies (name text PRIMARY KEY, factors vector(" + FACTORS + "))");
        }
    }

    /** Inserts one row per user id known to the fitted model, using a single batched statement. */
    private static void storeUserFactors(Connection conn, Recommender<Integer, String> recommender) throws SQLException {
        try (PreparedStatement insertStmt = conn.prepareStatement("INSERT INTO users (id, factors) VALUES (?, ?)")) {
            for (Integer userId : recommender.userIds()) {
                insertStmt.setInt(1, userId);
                // Ids come from the model itself, so their factors are expected to be present.
                insertStmt.setObject(2, new PGvector(recommender.userFactors(userId).get()));
                insertStmt.addBatch();
            }
            insertStmt.executeBatch();
        }
    }

    /** Inserts one row per item id known to the fitted model, using a single batched statement. */
    private static void storeItemFactors(Connection conn, Recommender<Integer, String> recommender) throws SQLException {
        try (PreparedStatement insertStmt = conn.prepareStatement("INSERT INTO movies (name, factors) VALUES (?, ?)")) {
            for (String itemId : recommender.itemIds()) {
                insertStmt.setString(1, itemId);
                // Ids come from the model itself, so their factors are expected to be present.
                insertStmt.setObject(2, new PGvector(recommender.itemFactors(itemId).get()));
                insertStmt.addBatch();
            }
            insertStmt.executeBatch();
        }
    }

    /**
     * Prints the 5 movies whose factors are nearest to {@code movie}'s, excluding the movie itself.
     * {@code <=>} is pgvector's cosine-distance operator.
     */
    private static void printItemRecommendations(Connection conn, String movie) throws SQLException {
        try (PreparedStatement neighborStmt = conn.prepareStatement("SELECT name FROM movies WHERE name != ? ORDER BY factors <=> (SELECT factors FROM movies WHERE name = ?) LIMIT 5")) {
            neighborStmt.setString(1, movie);
            neighborStmt.setString(2, movie);
            printNames(neighborStmt);
        }
    }

    /**
     * Prints the 5 movies ranked first by pgvector's {@code <#>} operator (negative inner product)
     * against the factors of user {@code userId}.
     */
    private static void printUserRecommendations(Connection conn, int userId) throws SQLException {
        try (PreparedStatement neighborStmt = conn.prepareStatement("SELECT name FROM movies ORDER BY factors <#> (SELECT factors FROM users WHERE id = ?) LIMIT 5")) {
            neighborStmt.setInt(1, userId);
            printNames(neighborStmt);
        }
    }

    /** Executes {@code stmt} and prints the {@code name} column of each row as a bullet line. */
    private static void printNames(PreparedStatement stmt) throws SQLException {
        try (ResultSet rs = stmt.executeQuery()) {
            while (rs.next()) {
                System.out.println("- " + rs.getString("name"));
            }
        }
    }
}