
Commit 07947d4

Boris Sarana authored and facebook-github-bot committed
Cleanup rollout code for sharding optimization (#2921)
Summary:
Pull Request resolved: #2921

As per title, the optimization has been rolled out to production jobs for several months, so it is time to delete the rollout code and the old implementation.

Reviewed By: ilyas409

Differential Revision: D73693963

fbshipit-source-id: f31db8ada9eafb6c2df346ce3e60fc9798298f44
1 parent f9dd63c commit 07947d4

File tree

3 files changed: +24 -61 lines


torchrec/distributed/embedding_types.py

+1 -8

@@ -31,9 +31,7 @@
 from torch.distributed._tensor.placement_types import Placement
 from torch.nn.modules.module import _addindent
 from torch.nn.parallel import DistributedDataParallel
-from torchrec.distributed.global_settings import (
-    construct_sharded_tensor_from_metadata_enabled,
-)
+
 from torchrec.distributed.types import (
     get_tensor_size_bytes,
     ModuleSharder,
@@ -358,11 +356,6 @@ def __init__(
         self._lookups: List[nn.Module] = []
         self._output_dists: List[nn.Module] = []
 
-        # option to construct ShardedTensor from metadata avoiding expensive all-gather
-        self._construct_sharded_tensor_from_metadata: bool = (
-            construct_sharded_tensor_from_metadata_enabled()
-        )
-
     def prefetch(
         self,
         dist_input: KJTList,

torchrec/distributed/embeddingbag.py

+23 -37

@@ -1006,46 +1006,32 @@ def _initialize_torch_state(self, skip_registering: bool = False) -> None: # no
             # access is allowed on them.
 
             # create ShardedTensor from local shards and metadata avoding all_gather collective
-            if self._construct_sharded_tensor_from_metadata:
-                sharding_spec = none_throws(
-                    self.module_sharding_plan[table_name].sharding_spec
-                )
-
-                tensor_properties = TensorProperties(
-                    dtype=(
-                        data_type_to_dtype(
-                            self._table_name_to_config[table_name].data_type
-                        )
-                    ),
-                )
+            sharding_spec = none_throws(
+                self.module_sharding_plan[table_name].sharding_spec
+            )
 
-                self._model_parallel_name_to_sharded_tensor[table_name] = (
-                    ShardedTensor._init_from_local_shards_and_global_metadata(
-                        local_shards=local_shards,
-                        sharded_tensor_metadata=sharding_spec.build_metadata(
-                            tensor_sizes=self._name_to_table_size[table_name],
-                            tensor_properties=tensor_properties,
-                        ),
-                        process_group=(
-                            self._env.sharding_pg
-                            if isinstance(self._env, ShardingEnv2D)
-                            else self._env.process_group
-                        ),
-                    )
-                )
-            else:
-                # create ShardedTensor from local shards using all_gather collective
-                self._model_parallel_name_to_sharded_tensor[table_name] = (
-                    ShardedTensor._init_from_local_shards(
-                        local_shards,
-                        self._name_to_table_size[table_name],
-                        process_group=(
-                            self._env.sharding_pg
-                            if isinstance(self._env, ShardingEnv2D)
-                            else self._env.process_group
-                        ),
+            tensor_properties = TensorProperties(
+                dtype=(
+                    data_type_to_dtype(
+                        self._table_name_to_config[table_name].data_type
                     )
+                ),
+            )
+
+            self._model_parallel_name_to_sharded_tensor[table_name] = (
+                ShardedTensor._init_from_local_shards_and_global_metadata(
+                    local_shards=local_shards,
+                    sharded_tensor_metadata=sharding_spec.build_metadata(
+                        tensor_sizes=self._name_to_table_size[table_name],
+                        tensor_properties=tensor_properties,
+                    ),
+                    process_group=(
+                        self._env.sharding_pg
+                        if isinstance(self._env, ShardingEnv2D)
+                        else self._env.process_group
+                    ),
                 )
+            )
 
     def extract_sharded_kvtensors(
         module: ShardedEmbeddingBagCollection,
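The change in this file drops the all_gather fallback: the metadata-based path, which builds the ShardedTensor directly from the sharding plan, is now the only implementation. Below is a minimal, single-rank sketch (not TorchRec code; the toy (4, 8) table, the gloo group, and the rank:0/cpu placement are illustrative assumptions) contrasting the two constructors: _init_from_local_shards discovers the global layout with an all_gather collective, while _init_from_local_shards_and_global_metadata takes precomputed global metadata and skips that collective.

import os

import torch
import torch.distributed as dist
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharded_tensor import (
    Shard,
    ShardedTensor,
    ShardedTensorMetadata,
    TensorProperties,
)

# Single-process "distributed" setup so both constructors can run locally.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

# One local shard that covers the whole toy (4, 8) table on rank 0.
local_tensor = torch.randn(4, 8)
shard_md = ShardMetadata(
    shard_offsets=[0, 0], shard_sizes=[4, 8], placement="rank:0/cpu"
)
local_shards = [Shard(tensor=local_tensor, metadata=shard_md)]

# Removed fallback path: ranks exchange shard metadata via an all_gather
# collective to assemble the global view of the tensor.
st_all_gather = ShardedTensor._init_from_local_shards(local_shards, [4, 8])

# Retained optimized path: the global metadata is already known (TorchRec builds
# it from the sharding plan via sharding_spec.build_metadata), so no collective runs.
global_md = ShardedTensorMetadata(
    shards_metadata=[shard_md],
    size=torch.Size([4, 8]),
    tensor_properties=TensorProperties(dtype=local_tensor.dtype),
)
st_from_md = ShardedTensor._init_from_local_shards_and_global_metadata(
    local_shards=local_shards,
    sharded_tensor_metadata=global_md,
)

dist.destroy_process_group()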

torchrec/distributed/global_settings.py

-16

@@ -7,14 +7,8 @@
 
 # pyre-strict
 
-import os
-
 PROPOGATE_DEVICE: bool = False
 
-TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA_ENV = (
-    "TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA"
-)
-
 
 def set_propogate_device(val: bool) -> None:
     global PROPOGATE_DEVICE
@@ -24,13 +18,3 @@ def set_propogate_device(val: bool) -> None:
 def get_propogate_device() -> bool:
     global PROPOGATE_DEVICE
     return PROPOGATE_DEVICE
-
-
-def construct_sharded_tensor_from_metadata_enabled() -> bool:
-    return (
-        os.environ.get(TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA_ENV, "0") == "1"
-    )
-
-
-def enable_construct_sharded_tensor_from_metadata() -> None:
-    os.environ[TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA_ENV] = "1"
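For reference, this is how a job opted in to the optimization before this cleanup; the sketch simply sets the environment variable that the removed construct_sharded_tensor_from_metadata_enabled() helper used to read. After this commit the metadata-based construction is unconditional and no flag is needed.

import os

# Pre-cleanup opt-in (no longer required): the removed helper returned True only
# when this variable was set to "1", enabling the metadata-based ShardedTensor path.
os.environ["TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA"] = "1"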
