@@ -22,17 +22,9 @@ export type NearestEntry = {
22
22
index : number ;
23
23
dist : number ;
24
24
} ;
25
- /**
26
- * Optimal size for the height of the matrix when doing computation on the GPU
27
- * using WebGL. This was found experimentally.
28
- *
29
- * This also guarantees that for computing pair-wise distance for up to 10K
30
- * vectors, no more than 40MB will be allocated in the GPU. Without the
31
- * allocation limit, we can freeze the graphics of the whole OS.
32
- */
33
- const OPTIMAL_GPU_BLOCK_SIZE = 256 ;
34
- /** Id of message box used for knn gpu progress bar. */
35
- const KNN_GPU_MSG_ID = 'knn-gpu' ;
25
+
26
+ /** Id of message box used for knn. */
27
+ const KNN_MSG_ID = 'knn' ;
36
28
37
29
/**
38
30
* Returns the K nearest neighbors for each vector where the distance
@@ -52,105 +44,66 @@ export function findKNNGPUCosDistNorm<T>(
52
44
const N = dataPoints . length ;
53
45
const dim = accessor ( dataPoints [ 0 ] ) . length ;
54
46
// The goal is to compute a large matrix multiplication A*A.T where A is of
55
- // size NxD and A.T is its transpose. This results in a NxN matrix which
56
- // could be too big to store on the GPU memory. To avoid memory overflow, we
57
- // compute multiple A*partial_A.T where partial_A is of size BxD (B is much
58
- // smaller than N). This results in storing only NxB size matrices on the GPU
59
- // at a given time.
47
+ // size NxD and A.T is its transpose. This results in a NxN matrix.
60
48
// A*A.T will give us NxN matrix holding the cosine distance between every
61
49
// pair of points, which we sort using KMin data structure to obtain the
62
50
// K nearest neighbors for each point.
63
51
const nearest : NearestEntry [ ] [ ] = new Array ( N ) ;
64
- let numPieces = Math . ceil ( N / OPTIMAL_GPU_BLOCK_SIZE ) ;
65
- const actualPieceSize = Math . floor ( N / numPieces ) ;
66
- const modulo = N % actualPieceSize ;
67
- numPieces += modulo ? 1 : 0 ;
68
- let offset = 0 ;
69
- let progress = 0 ;
70
- let progressDiff = 1 / ( 2 * numPieces ) ;
71
- let piece = 0 ;
72
52
73
53
const typedArray = vector . toTypedArray ( dataPoints , accessor ) ;
74
54
const bigMatrix = tf . tensor ( typedArray , [ N , dim ] ) ;
75
55
const bigMatrixTransposed = tf . transpose ( bigMatrix ) ;
76
- // 1 - A * A^T.
77
- const bigMatrixSquared = tf . matMul ( bigMatrix , bigMatrixTransposed ) ;
78
- const cosDistMatrix = tf . sub ( 1 , bigMatrixSquared ) ;
79
-
80
- let maybePaddedCosDistMatrix = cosDistMatrix ;
81
- if ( actualPieceSize * numPieces > N ) {
82
- // Expect the input to be rank 2 (though it is not typed that way) so we
83
- // want to pad the first dimension so we split very evenly (all splitted
84
- // tensor have exactly the same dimesion).
85
- const padding : Array < [ number , number ] > = [
86
- [ 0 , actualPieceSize * numPieces - N ] ,
87
- [ 0 , 0 ] ,
88
- ] ;
89
- maybePaddedCosDistMatrix = tf . pad ( cosDistMatrix , padding ) ;
90
- }
91
- const splits = tf . split (
92
- maybePaddedCosDistMatrix ,
93
- new Array ( numPieces ) . fill ( actualPieceSize ) ,
94
- 0
95
- ) ;
96
56
97
57
function step ( resolve : ( result : NearestEntry [ ] [ ] ) => void ) {
98
- let progressMsg =
99
- 'Finding nearest neighbors: ' + ( progress * 100 ) . toFixed ( ) + '%' ;
100
58
util
101
59
. runAsyncTask (
102
- progressMsg ,
60
+ 'Finding nearest neighbors...' ,
103
61
async ( ) => {
62
+ // 1 - A * A^T.
63
+ const bigMatrixSquared = tf . matMul ( bigMatrix , bigMatrixTransposed ) ;
64
+ const cosDistMatrix = tf . sub ( 1 , bigMatrixSquared ) ;
104
65
// `.data()` returns flattened Float32Array of B * N dimension.
105
66
// For matrix of
106
67
// [ 1 2 ]
107
68
// [ 3 4 ],
108
69
// `.data()` returns [1, 2, 3, 4].
109
- const partial = await splits [ piece ] . data ( ) ;
110
- progress += progressDiff ;
111
- for ( let i = 0 ; i < actualPieceSize ; i ++ ) {
70
+ const partial = await cosDistMatrix . data ( ) ;
71
+ bigMatrixSquared . dispose ( ) ;
72
+ cosDistMatrix . dispose ( ) ;
73
+ for ( let i = 0 ; i < N ; i ++ ) {
112
74
let kMin = new KMin < NearestEntry > ( k ) ;
113
- let iReal = offset + i ;
114
- if ( iReal >= N ) break ;
115
75
for ( let j = 0 ; j < N ; j ++ ) {
116
76
// Skip diagonal entries.
117
- if ( j === iReal ) {
77
+ if ( j === i ) {
118
78
continue ;
119
79
}
120
80
// Access i * N's row at `j` column.
121
81
// Reach row has N entries and j-th index has cosine distance
122
- // between iReal vs. j-th vectors.
82
+ // between i-th vs. j-th vectors.
123
83
const cosDist = partial [ i * N + j ] ;
124
84
if ( cosDist >= 0 ) {
125
85
kMin . add ( cosDist , { index : j , dist : cosDist } ) ;
126
86
}
127
87
}
128
- nearest [ iReal ] = kMin . getMinKItems ( ) ;
88
+ nearest [ i ] = kMin . getMinKItems ( ) ;
129
89
}
130
- progress += progressDiff ;
131
- offset += actualPieceSize ;
132
- piece ++ ;
133
90
} ,
134
- KNN_GPU_MSG_ID
91
+ KNN_MSG_ID ,
135
92
)
136
93
. then (
137
94
( ) => {
138
- if ( piece < numPieces ) {
139
- step ( resolve ) ;
140
- } else {
141
- logging . setModalMessage ( null ! , KNN_GPU_MSG_ID ) ;
142
- // Discard all tensors and free up the memory.
143
- bigMatrix . dispose ( ) ;
144
- bigMatrixTransposed . dispose ( ) ;
145
- bigMatrixSquared . dispose ( ) ;
146
- cosDistMatrix . dispose ( ) ;
147
- splits . forEach ( ( split ) => split . dispose ( ) ) ;
148
- resolve ( nearest ) ;
149
- }
95
+ logging . setModalMessage ( null ! , KNN_MSG_ID ) ;
96
+ // Discard all tensors and free up the memory.
97
+ bigMatrix . dispose ( ) ;
98
+ bigMatrixTransposed . dispose ( ) ;
99
+ resolve ( nearest ) ;
150
100
} ,
151
101
( error ) => {
102
+ // Discard all tensors and free up the memory.
103
+ bigMatrix . dispose ( ) ;
104
+ bigMatrixTransposed . dispose ( ) ;
152
105
// GPU failed. Reverting back to CPU.
153
- logging . setModalMessage ( null ! , KNN_GPU_MSG_ID ) ;
106
+ logging . setModalMessage ( null ! , KNN_MSG_ID ) ;
154
107
let distFunc = ( a , b , limit ) => vector . cosDistNorm ( a , b ) ;
155
108
findKNN ( dataPoints , k , accessor , distFunc ) . then ( ( nearest ) => {
156
109
resolve ( nearest ) ;
@@ -212,47 +165,12 @@ export function findKNN<T>(
212
165
for ( let i = 0 ; i < N ; i ++ ) {
213
166
nearest [ i ] = kMin [ i ] . getMinKItems ( ) ;
214
167
}
168
+ logging . setModalMessage ( null ! , KNN_MSG_ID ) ;
215
169
return nearest ;
216
- }
170
+ } ,
171
+ KNN_MSG_ID ,
217
172
) ;
218
173
}
219
- /** Calculates the minimum distance between a search point and a rectangle. */
220
- function minDist (
221
- point : [ number , number ] ,
222
- x1 : number ,
223
- y1 : number ,
224
- x2 : number ,
225
- y2 : number
226
- ) {
227
- let x = point [ 0 ] ;
228
- let y = point [ 1 ] ;
229
- let dx1 = x - x1 ;
230
- let dx2 = x - x2 ;
231
- let dy1 = y - y1 ;
232
- let dy2 = y - y2 ;
233
- if ( dx1 * dx2 <= 0 ) {
234
- // x is between x1 and x2
235
- if ( dy1 * dy2 <= 0 ) {
236
- // (x,y) is inside the rectangle
237
- return 0 ; // return 0 as point is in rect
238
- }
239
- return Math . min ( Math . abs ( dy1 ) , Math . abs ( dy2 ) ) ;
240
- }
241
- if ( dy1 * dy2 <= 0 ) {
242
- // y is between y1 and y2
243
- // We know it is already inside the rectangle
244
- return Math . min ( Math . abs ( dx1 ) , Math . abs ( dx2 ) ) ;
245
- }
246
- let corner : [ number , number ] ;
247
- if ( x > x2 ) {
248
- // Upper-right vs lower-right.
249
- corner = y > y2 ? [ x2 , y2 ] : [ x2 , y1 ] ;
250
- } else {
251
- // Upper-left vs lower-left.
252
- corner = y > y2 ? [ x1 , y2 ] : [ x1 , y1 ] ;
253
- }
254
- return Math . sqrt ( vector . dist22D ( [ x , y ] , corner ) ) ;
255
- }
256
174
/**
257
175
* Returns the nearest neighbors of a particular point.
258
176
*
@@ -282,4 +200,3 @@ export function findKNNofPoint<T>(
282
200
return kMin . getMinKItems ( ) ;
283
201
}
284
202
285
- export const TEST_ONLY = { OPTIMAL_GPU_BLOCK_SIZE } ;
0 commit comments