Skip to content

Commit 2301d17

Browse files
pradeepfnfacebook-github-bot
authored andcommitted
Move embedding_rocksdb_wrapper to its own header. (pytorch#700)
Summary: X-link: pytorch#3622 Pull Request resolved: facebookresearch/FBGEMM#700 because we can't instantiate the kvTensorWrapper without EmbeddingRDBWrapper in C++. $title. Reviewed By: jiayulu Differential Revision: D68727692 fbshipit-source-id: ca141c45ac2520323a55ab93a63ba617a19c8c57
1 parent c57de6f commit 2301d17

File tree

2 files changed

+144
-128
lines changed

2 files changed

+144
-128
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include "kv_tensor_wrapper.h"
12+
13+
namespace ssd {
14+
15+
class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
16+
public:
17+
EmbeddingRocksDBWrapper(
18+
std::string path,
19+
int64_t num_shards,
20+
int64_t num_threads,
21+
int64_t memtable_flush_period,
22+
int64_t memtable_flush_offset,
23+
int64_t l0_files_per_compact,
24+
int64_t max_D,
25+
int64_t rate_limit_mbps,
26+
int64_t size_ratio,
27+
int64_t compaction_ratio,
28+
int64_t write_buffer_size,
29+
int64_t max_write_buffer_num,
30+
double uniform_init_lower,
31+
double uniform_init_upper,
32+
int64_t row_storage_bitwidth = 32,
33+
int64_t cache_size = 0,
34+
bool use_passed_in_path = false,
35+
int64_t tbe_unique_id = 0,
36+
int64_t l2_cache_size_gb = 0,
37+
bool enable_async_update = false)
38+
: impl_(std::make_shared<ssd::EmbeddingRocksDB>(
39+
path,
40+
num_shards,
41+
num_threads,
42+
memtable_flush_period,
43+
memtable_flush_offset,
44+
l0_files_per_compact,
45+
max_D,
46+
rate_limit_mbps,
47+
size_ratio,
48+
compaction_ratio,
49+
write_buffer_size,
50+
max_write_buffer_num,
51+
uniform_init_lower,
52+
uniform_init_upper,
53+
row_storage_bitwidth,
54+
cache_size,
55+
use_passed_in_path,
56+
tbe_unique_id,
57+
l2_cache_size_gb,
58+
enable_async_update)) {}
59+
60+
void set_cuda(
61+
Tensor indices,
62+
Tensor weights,
63+
Tensor count,
64+
int64_t timestep,
65+
bool is_bwd) {
66+
return impl_->set_cuda(indices, weights, count, timestep, is_bwd);
67+
}
68+
69+
void get_cuda(Tensor indices, Tensor weights, Tensor count) {
70+
return impl_->get_cuda(indices, weights, count);
71+
}
72+
73+
void set(Tensor indices, Tensor weights, Tensor count) {
74+
return impl_->set(indices, weights, count);
75+
}
76+
77+
void set_range_to_storage(
78+
const Tensor& weights,
79+
const int64_t start,
80+
const int64_t length) {
81+
return impl_->set_range_to_storage(weights, start, length);
82+
}
83+
84+
void get(Tensor indices, Tensor weights, Tensor count, int64_t sleep_ms) {
85+
return impl_->get(indices, weights, count, sleep_ms);
86+
}
87+
88+
std::vector<int64_t> get_mem_usage() {
89+
return impl_->get_mem_usage();
90+
}
91+
92+
std::vector<double> get_rocksdb_io_duration(
93+
const int64_t step,
94+
const int64_t interval) {
95+
return impl_->get_rocksdb_io_duration(step, interval);
96+
}
97+
98+
std::vector<double> get_l2cache_perf(
99+
const int64_t step,
100+
const int64_t interval) {
101+
return impl_->get_l2cache_perf(step, interval);
102+
}
103+
104+
void compact() {
105+
return impl_->compact();
106+
}
107+
108+
void flush() {
109+
return impl_->flush();
110+
}
111+
112+
void reset_l2_cache() {
113+
return impl_->reset_l2_cache();
114+
}
115+
116+
void wait_util_filling_work_done() {
117+
return impl_->wait_util_filling_work_done();
118+
}
119+
120+
c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper> create_snapshot() {
121+
auto handle = impl_->create_snapshot();
122+
return c10::make_intrusive<EmbeddingSnapshotHandleWrapper>(handle, impl_);
123+
}
124+
125+
void release_snapshot(
126+
c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper> snapshot_handle) {
127+
auto handle = snapshot_handle->handle;
128+
CHECK_NE(handle, nullptr);
129+
impl_->release_snapshot(handle);
130+
}
131+
132+
int64_t get_snapshot_count() const {
133+
return impl_->get_snapshot_count();
134+
}
135+
136+
private:
137+
friend class KVTensorWrapper;
138+
139+
// shared pointer since we use shared_from_this() in callbacks.
140+
std::shared_ptr<ssd::EmbeddingRocksDB> impl_;
141+
};
142+
143+
} // namespace ssd

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp

+1-128
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <torch/custom_class.h>
1515

1616
#include "./ssd_table_batched_embeddings.h"
17+
#include "embedding_rocksdb_wrapper.h"
1718
#include "fbgemm_gpu/utils/ops_utils.h"
1819

1920
using namespace at;
@@ -258,134 +259,6 @@ void compact_indices_cuda(
258259

259260
namespace ssd {
260261

261-
class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
262-
public:
263-
EmbeddingRocksDBWrapper(
264-
std::string path,
265-
int64_t num_shards,
266-
int64_t num_threads,
267-
int64_t memtable_flush_period,
268-
int64_t memtable_flush_offset,
269-
int64_t l0_files_per_compact,
270-
int64_t max_D,
271-
int64_t rate_limit_mbps,
272-
int64_t size_ratio,
273-
int64_t compaction_ratio,
274-
int64_t write_buffer_size,
275-
int64_t max_write_buffer_num,
276-
double uniform_init_lower,
277-
double uniform_init_upper,
278-
int64_t row_storage_bitwidth = 32,
279-
int64_t cache_size = 0,
280-
bool use_passed_in_path = false,
281-
int64_t tbe_unique_id = 0,
282-
int64_t l2_cache_size_gb = 0,
283-
bool enable_async_update = false)
284-
: impl_(std::make_shared<ssd::EmbeddingRocksDB>(
285-
path,
286-
num_shards,
287-
num_threads,
288-
memtable_flush_period,
289-
memtable_flush_offset,
290-
l0_files_per_compact,
291-
max_D,
292-
rate_limit_mbps,
293-
size_ratio,
294-
compaction_ratio,
295-
write_buffer_size,
296-
max_write_buffer_num,
297-
uniform_init_lower,
298-
uniform_init_upper,
299-
row_storage_bitwidth,
300-
cache_size,
301-
use_passed_in_path,
302-
tbe_unique_id,
303-
l2_cache_size_gb,
304-
enable_async_update)) {}
305-
306-
void set_cuda(
307-
Tensor indices,
308-
Tensor weights,
309-
Tensor count,
310-
int64_t timestep,
311-
bool is_bwd) {
312-
return impl_->set_cuda(indices, weights, count, timestep, is_bwd);
313-
}
314-
315-
void get_cuda(Tensor indices, Tensor weights, Tensor count) {
316-
return impl_->get_cuda(indices, weights, count);
317-
}
318-
319-
void set(Tensor indices, Tensor weights, Tensor count) {
320-
return impl_->set(indices, weights, count);
321-
}
322-
323-
void set_range_to_storage(
324-
const Tensor& weights,
325-
const int64_t start,
326-
const int64_t length) {
327-
return impl_->set_range_to_storage(weights, start, length);
328-
}
329-
330-
void get(Tensor indices, Tensor weights, Tensor count, int64_t sleep_ms) {
331-
return impl_->get(indices, weights, count, sleep_ms);
332-
}
333-
334-
std::vector<int64_t> get_mem_usage() {
335-
return impl_->get_mem_usage();
336-
}
337-
338-
std::vector<double> get_rocksdb_io_duration(
339-
const int64_t step,
340-
const int64_t interval) {
341-
return impl_->get_rocksdb_io_duration(step, interval);
342-
}
343-
344-
std::vector<double> get_l2cache_perf(
345-
const int64_t step,
346-
const int64_t interval) {
347-
return impl_->get_l2cache_perf(step, interval);
348-
}
349-
350-
void compact() {
351-
return impl_->compact();
352-
}
353-
354-
void flush() {
355-
return impl_->flush();
356-
}
357-
358-
void reset_l2_cache() {
359-
return impl_->reset_l2_cache();
360-
}
361-
362-
void wait_util_filling_work_done() {
363-
return impl_->wait_util_filling_work_done();
364-
}
365-
366-
c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper> create_snapshot() {
367-
auto handle = impl_->create_snapshot();
368-
return c10::make_intrusive<EmbeddingSnapshotHandleWrapper>(handle, impl_);
369-
}
370-
371-
void release_snapshot(
372-
c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper> snapshot_handle) {
373-
auto handle = snapshot_handle->handle;
374-
CHECK_NE(handle, nullptr);
375-
impl_->release_snapshot(handle);
376-
}
377-
378-
int64_t get_snapshot_count() const {
379-
return impl_->get_snapshot_count();
380-
}
381-
382-
private:
383-
friend class KVTensorWrapper;
384-
385-
// shared pointer since we use shared_from_this() in callbacks.
386-
std::shared_ptr<ssd::EmbeddingRocksDB> impl_;
387-
};
388-
389262
SnapshotHandle::SnapshotHandle(EmbeddingRocksDB* db) : db_(db) {
390263
auto num_shards = db->num_shards();
391264
CHECK_GT(num_shards, 0);

0 commit comments

Comments
 (0)