Skip to content

Commit 9ae767d

Browse files
committed
remove knn splitting
1 parent e5d1771 commit 9ae767d

File tree

2 files changed

+29
-145
lines changed

2 files changed

+29
-145
lines changed

tensorboard/plugins/projector/vz_projector/knn.ts

+28-111
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,9 @@ export type NearestEntry = {
2222
index: number;
2323
dist: number;
2424
};
25-
/**
26-
* Optimal size for the height of the matrix when doing computation on the GPU
27-
* using WebGL. This was found experimentally.
28-
*
29-
* This also guarantees that for computing pair-wise distance for up to 10K
30-
* vectors, no more than 40MB will be allocated in the GPU. Without the
31-
* allocation limit, we can freeze the graphics of the whole OS.
32-
*/
33-
const OPTIMAL_GPU_BLOCK_SIZE = 256;
34-
/** Id of message box used for knn gpu progress bar. */
35-
const KNN_GPU_MSG_ID = 'knn-gpu';
25+
26+
/** Id of message box used for knn. */
27+
const KNN_MSG_ID = 'knn';
3628

3729
/**
3830
* Returns the K nearest neighbors for each vector where the distance
@@ -52,105 +44,66 @@ export function findKNNGPUCosDistNorm<T>(
5244
const N = dataPoints.length;
5345
const dim = accessor(dataPoints[0]).length;
5446
// The goal is to compute a large matrix multiplication A*A.T where A is of
55-
// size NxD and A.T is its transpose. This results in a NxN matrix which
56-
// could be too big to store on the GPU memory. To avoid memory overflow, we
57-
// compute multiple A*partial_A.T where partial_A is of size BxD (B is much
58-
// smaller than N). This results in storing only NxB size matrices on the GPU
59-
// at a given time.
47+
// size NxD and A.T is its transpose. This results in a NxN matrix.
6048
// A*A.T will give us NxN matrix holding the cosine distance between every
6149
// pair of points, which we sort using KMin data structure to obtain the
6250
// K nearest neighbors for each point.
6351
const nearest: NearestEntry[][] = new Array(N);
64-
let numPieces = Math.ceil(N / OPTIMAL_GPU_BLOCK_SIZE);
65-
const actualPieceSize = Math.floor(N / numPieces);
66-
const modulo = N % actualPieceSize;
67-
numPieces += modulo ? 1 : 0;
68-
let offset = 0;
69-
let progress = 0;
70-
let progressDiff = 1 / (2 * numPieces);
71-
let piece = 0;
7252

7353
const typedArray = vector.toTypedArray(dataPoints, accessor);
7454
const bigMatrix = tf.tensor(typedArray, [N, dim]);
7555
const bigMatrixTransposed = tf.transpose(bigMatrix);
76-
// 1 - A * A^T.
77-
const bigMatrixSquared = tf.matMul(bigMatrix, bigMatrixTransposed);
78-
const cosDistMatrix = tf.sub(1, bigMatrixSquared);
79-
80-
let maybePaddedCosDistMatrix = cosDistMatrix;
81-
if (actualPieceSize * numPieces > N) {
82-
// Expect the input to be rank 2 (though it is not typed that way) so we
83-
// want to pad the first dimension so we split very evenly (all split
84-
// tensors have exactly the same dimension).
85-
const padding: Array<[number, number]> = [
86-
[0, actualPieceSize * numPieces - N],
87-
[0, 0],
88-
];
89-
maybePaddedCosDistMatrix = tf.pad(cosDistMatrix, padding);
90-
}
91-
const splits = tf.split(
92-
maybePaddedCosDistMatrix,
93-
new Array(numPieces).fill(actualPieceSize),
94-
0
95-
);
9656

9757
function step(resolve: (result: NearestEntry[][]) => void) {
98-
let progressMsg =
99-
'Finding nearest neighbors: ' + (progress * 100).toFixed() + '%';
10058
util
10159
.runAsyncTask(
102-
progressMsg,
60+
'Finding nearest neighbors...',
10361
async () => {
62+
// 1 - A * A^T.
63+
const bigMatrixSquared = tf.matMul(bigMatrix, bigMatrixTransposed);
64+
const cosDistMatrix = tf.sub(1, bigMatrixSquared);
10465
// `.data()` returns flattened Float32Array of B * N dimension.
10566
// For matrix of
10667
// [ 1 2 ]
10768
// [ 3 4 ],
10869
// `.data()` returns [1, 2, 3, 4].
109-
const partial = await splits[piece].data();
110-
progress += progressDiff;
111-
for (let i = 0; i < actualPieceSize; i++) {
70+
const partial = await cosDistMatrix.data();
71+
bigMatrixSquared.dispose();
72+
cosDistMatrix.dispose();
73+
for (let i = 0; i < N; i++) {
11274
let kMin = new KMin<NearestEntry>(k);
113-
let iReal = offset + i;
114-
if (iReal >= N) break;
11575
for (let j = 0; j < N; j++) {
11676
// Skip diagonal entries.
117-
if (j === iReal) {
77+
if (j === i) {
11878
continue;
11979
}
12080
// Access i * N's row at `j` column.
12181
// Each row has N entries and j-th index has cosine distance
122-
// between iReal vs. j-th vectors.
82+
// between i-th vs. j-th vectors.
12383
const cosDist = partial[i * N + j];
12484
if (cosDist >= 0) {
12585
kMin.add(cosDist, {index: j, dist: cosDist});
12686
}
12787
}
128-
nearest[iReal] = kMin.getMinKItems();
88+
nearest[i] = kMin.getMinKItems();
12989
}
130-
progress += progressDiff;
131-
offset += actualPieceSize;
132-
piece++;
13390
},
134-
KNN_GPU_MSG_ID
91+
KNN_MSG_ID,
13592
)
13693
.then(
13794
() => {
138-
if (piece < numPieces) {
139-
step(resolve);
140-
} else {
141-
logging.setModalMessage(null!, KNN_GPU_MSG_ID);
142-
// Discard all tensors and free up the memory.
143-
bigMatrix.dispose();
144-
bigMatrixTransposed.dispose();
145-
bigMatrixSquared.dispose();
146-
cosDistMatrix.dispose();
147-
splits.forEach((split) => split.dispose());
148-
resolve(nearest);
149-
}
95+
logging.setModalMessage(null!, KNN_MSG_ID);
96+
// Discard all tensors and free up the memory.
97+
bigMatrix.dispose();
98+
bigMatrixTransposed.dispose();
99+
resolve(nearest);
150100
},
151101
(error) => {
102+
// Discard all tensors and free up the memory.
103+
bigMatrix.dispose();
104+
bigMatrixTransposed.dispose();
152105
// GPU failed. Reverting back to CPU.
153-
logging.setModalMessage(null!, KNN_GPU_MSG_ID);
106+
logging.setModalMessage(null!, KNN_MSG_ID);
154107
let distFunc = (a, b, limit) => vector.cosDistNorm(a, b);
155108
findKNN(dataPoints, k, accessor, distFunc).then((nearest) => {
156109
resolve(nearest);
@@ -212,47 +165,12 @@ export function findKNN<T>(
212165
for (let i = 0; i < N; i++) {
213166
nearest[i] = kMin[i].getMinKItems();
214167
}
168+
logging.setModalMessage(null!, KNN_MSG_ID);
215169
return nearest;
216-
}
170+
},
171+
KNN_MSG_ID,
217172
);
218173
}
219-
/** Calculates the minimum distance between a search point and a rectangle. */
220-
function minDist(
221-
point: [number, number],
222-
x1: number,
223-
y1: number,
224-
x2: number,
225-
y2: number
226-
) {
227-
let x = point[0];
228-
let y = point[1];
229-
let dx1 = x - x1;
230-
let dx2 = x - x2;
231-
let dy1 = y - y1;
232-
let dy2 = y - y2;
233-
if (dx1 * dx2 <= 0) {
234-
// x is between x1 and x2
235-
if (dy1 * dy2 <= 0) {
236-
// (x,y) is inside the rectangle
237-
return 0; // return 0 as point is in rect
238-
}
239-
return Math.min(Math.abs(dy1), Math.abs(dy2));
240-
}
241-
if (dy1 * dy2 <= 0) {
242-
// y is between y1 and y2
243-
// We know it is already inside the rectangle
244-
return Math.min(Math.abs(dx1), Math.abs(dx2));
245-
}
246-
let corner: [number, number];
247-
if (x > x2) {
248-
// Upper-right vs lower-right.
249-
corner = y > y2 ? [x2, y2] : [x2, y1];
250-
} else {
251-
// Upper-left vs lower-left.
252-
corner = y > y2 ? [x1, y2] : [x1, y1];
253-
}
254-
return Math.sqrt(vector.dist22D([x, y], corner));
255-
}
256174
/**
257175
* Returns the nearest neighbors of a particular point.
258176
*
@@ -282,4 +200,3 @@ export function findKNNofPoint<T>(
282200
return kMin.getMinKItems();
283201
}
284202

285-
export const TEST_ONLY = {OPTIMAL_GPU_BLOCK_SIZE};

tensorboard/plugins/projector/vz_projector/knn_test.ts

+1-34
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
15-
import {findKNN, findKNNGPUCosDistNorm, NearestEntry, TEST_ONLY} from './knn';
15+
import {findKNN, findKNNGPUCosDistNorm, NearestEntry} from './knn';
1616
import {cosDistNorm, unit} from './vector';
1717

1818
describe('projector knn test', () => {
@@ -65,22 +65,6 @@ describe('projector knn test', () => {
6565

6666
expect(getIndices(values)).toEqual([[1], [0]]);
6767
});
68-
69-
it('splits a large data into one that would fit into GPU memory', async () => {
70-
const size = TEST_ONLY.OPTIMAL_GPU_BLOCK_SIZE + 5;
71-
const data = new Array(size).fill(
72-
unitVector(new Float32Array([1, 1, 1]))
73-
);
74-
const values = await findKNNGPUCosDistNorm(data, 1, (a) => a);
75-
76-
expect(getIndices(values)).toEqual([
77-
// Since distance to the diagonal entries (distance to self is 0) is
78-
// non-sensical, the diagonal entires are ignored. So for the first
79-
// item, the nearest neighbor should be 2nd item (index 1).
80-
[1],
81-
...new Array(size - 1).fill([0]),
82-
]);
83-
});
8468
});
8569

8670
describe('#findKNN', () => {
@@ -108,22 +92,5 @@ describe('projector knn test', () => {
10892
// Floating point precision makes it hard to test. Just assert indices.
10993
expect(getIndices(findKnnGpuCosVal)).toEqual(getIndices(findKnnVal));
11094
});
111-
112-
it('splits a large data without the result being wrong', async () => {
113-
const size = TEST_ONLY.OPTIMAL_GPU_BLOCK_SIZE + 5;
114-
const data = Array.from(new Array(size)).map((_, index) => {
115-
return unitVector(new Float32Array([index + 1, index + 1]));
116-
});
117-
118-
const findKnnGpuCosVal = await findKNNGPUCosDistNorm(data, 2, (a) => a);
119-
const findKnnVal = await findKNN(
120-
data,
121-
2,
122-
(a) => a,
123-
(a, b, limit) => cosDistNorm(a, b)
124-
);
125-
126-
expect(getIndices(findKnnGpuCosVal)).toEqual(getIndices(findKnnVal));
127-
});
12895
});
12996
});

0 commit comments

Comments
 (0)