|
14 | 14 | #include <torch/custom_class.h>
|
15 | 15 |
|
16 | 16 | #include "./ssd_table_batched_embeddings.h"
|
| 17 | +#include "embedding_rocksdb_wrapper.h" |
17 | 18 | #include "fbgemm_gpu/utils/ops_utils.h"
|
18 | 19 |
|
19 | 20 | using namespace at;
|
@@ -258,134 +259,6 @@ void compact_indices_cuda(
|
258 | 259 |
|
259 | 260 | namespace ssd {
|
260 | 261 |
|
261 |
| -class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder { |
262 |
| - public: |
263 |
| - EmbeddingRocksDBWrapper( |
264 |
| - std::string path, |
265 |
| - int64_t num_shards, |
266 |
| - int64_t num_threads, |
267 |
| - int64_t memtable_flush_period, |
268 |
| - int64_t memtable_flush_offset, |
269 |
| - int64_t l0_files_per_compact, |
270 |
| - int64_t max_D, |
271 |
| - int64_t rate_limit_mbps, |
272 |
| - int64_t size_ratio, |
273 |
| - int64_t compaction_ratio, |
274 |
| - int64_t write_buffer_size, |
275 |
| - int64_t max_write_buffer_num, |
276 |
| - double uniform_init_lower, |
277 |
| - double uniform_init_upper, |
278 |
| - int64_t row_storage_bitwidth = 32, |
279 |
| - int64_t cache_size = 0, |
280 |
| - bool use_passed_in_path = false, |
281 |
| - int64_t tbe_unique_id = 0, |
282 |
| - int64_t l2_cache_size_gb = 0, |
283 |
| - bool enable_async_update = false) |
284 |
| - : impl_(std::make_shared<ssd::EmbeddingRocksDB>( |
285 |
| - path, |
286 |
| - num_shards, |
287 |
| - num_threads, |
288 |
| - memtable_flush_period, |
289 |
| - memtable_flush_offset, |
290 |
| - l0_files_per_compact, |
291 |
| - max_D, |
292 |
| - rate_limit_mbps, |
293 |
| - size_ratio, |
294 |
| - compaction_ratio, |
295 |
| - write_buffer_size, |
296 |
| - max_write_buffer_num, |
297 |
| - uniform_init_lower, |
298 |
| - uniform_init_upper, |
299 |
| - row_storage_bitwidth, |
300 |
| - cache_size, |
301 |
| - use_passed_in_path, |
302 |
| - tbe_unique_id, |
303 |
| - l2_cache_size_gb, |
304 |
| - enable_async_update)) {} |
305 |
| - |
306 |
| - void set_cuda( |
307 |
| - Tensor indices, |
308 |
| - Tensor weights, |
309 |
| - Tensor count, |
310 |
| - int64_t timestep, |
311 |
| - bool is_bwd) { |
312 |
| - return impl_->set_cuda(indices, weights, count, timestep, is_bwd); |
313 |
| - } |
314 |
| - |
315 |
| - void get_cuda(Tensor indices, Tensor weights, Tensor count) { |
316 |
| - return impl_->get_cuda(indices, weights, count); |
317 |
| - } |
318 |
| - |
319 |
| - void set(Tensor indices, Tensor weights, Tensor count) { |
320 |
| - return impl_->set(indices, weights, count); |
321 |
| - } |
322 |
| - |
323 |
| - void set_range_to_storage( |
324 |
| - const Tensor& weights, |
325 |
| - const int64_t start, |
326 |
| - const int64_t length) { |
327 |
| - return impl_->set_range_to_storage(weights, start, length); |
328 |
| - } |
329 |
| - |
330 |
| - void get(Tensor indices, Tensor weights, Tensor count, int64_t sleep_ms) { |
331 |
| - return impl_->get(indices, weights, count, sleep_ms); |
332 |
| - } |
333 |
| - |
334 |
| - std::vector<int64_t> get_mem_usage() { |
335 |
| - return impl_->get_mem_usage(); |
336 |
| - } |
337 |
| - |
338 |
| - std::vector<double> get_rocksdb_io_duration( |
339 |
| - const int64_t step, |
340 |
| - const int64_t interval) { |
341 |
| - return impl_->get_rocksdb_io_duration(step, interval); |
342 |
| - } |
343 |
| - |
344 |
| - std::vector<double> get_l2cache_perf( |
345 |
| - const int64_t step, |
346 |
| - const int64_t interval) { |
347 |
| - return impl_->get_l2cache_perf(step, interval); |
348 |
| - } |
349 |
| - |
350 |
| - void compact() { |
351 |
| - return impl_->compact(); |
352 |
| - } |
353 |
| - |
354 |
| - void flush() { |
355 |
| - return impl_->flush(); |
356 |
| - } |
357 |
| - |
358 |
| - void reset_l2_cache() { |
359 |
| - return impl_->reset_l2_cache(); |
360 |
| - } |
361 |
| - |
362 |
| - void wait_util_filling_work_done() { |
363 |
| - return impl_->wait_util_filling_work_done(); |
364 |
| - } |
365 |
| - |
366 |
| - c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper> create_snapshot() { |
367 |
| - auto handle = impl_->create_snapshot(); |
368 |
| - return c10::make_intrusive<EmbeddingSnapshotHandleWrapper>(handle, impl_); |
369 |
| - } |
370 |
| - |
371 |
| - void release_snapshot( |
372 |
| - c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper> snapshot_handle) { |
373 |
| - auto handle = snapshot_handle->handle; |
374 |
| - CHECK_NE(handle, nullptr); |
375 |
| - impl_->release_snapshot(handle); |
376 |
| - } |
377 |
| - |
378 |
| - int64_t get_snapshot_count() const { |
379 |
| - return impl_->get_snapshot_count(); |
380 |
| - } |
381 |
| - |
382 |
| - private: |
383 |
| - friend class KVTensorWrapper; |
384 |
| - |
385 |
| - // shared pointer since we use shared_from_this() in callbacks. |
386 |
| - std::shared_ptr<ssd::EmbeddingRocksDB> impl_; |
387 |
| -}; |
388 |
| - |
389 | 262 | SnapshotHandle::SnapshotHandle(EmbeddingRocksDB* db) : db_(db) {
|
390 | 263 | auto num_shards = db->num_shards();
|
391 | 264 | CHECK_GT(num_shards, 0);
|
|
0 commit comments