short circuit the flatten/unflatten between EBC and KTRegroupAsDict modules #2393
Status: Closed (+138 −3)

Conversation
This pull request was exported from Phabricator. Differential Revision: D62606738
TroyGarden added a commit to TroyGarden/torchrec that referenced this pull request on Sep 13, 2024:
…odules (pytorch#2393)

Summary:
X-link: pytorch/pytorch#136045
Pull Request resolved: pytorch#2393

# context
* For the root cause and background, please refer to this [post](https://fb.workplace.com/groups/1028545332188949/permalink/1042204770823005/).
* The basic idea of this diff is to **short-circuit the pytree flatten/unflatten function pairs** between two preserved modules, i.e., EBC/fpEBC and KTRegroupAsDict. NOTE: There could be multiple EBCs and a single KTRegroupAsDict, as shown in the [pic](https://fburl.com/gslide/lcyt8eh3) {F1864810545}
* Short-circuiting the EBC-KTRegroupAsDict pairs is a special case and a must in most cases, due to the EBC key-order issue with distributed table lookup.
* All the operations are hidden behind a control flag `short_circuit_pytree_ebc_regroup` on the torchrec main API call `decapsulate_ir_modules`, which should only be visible to the infra layer, not to users.

# details
* The `_short_circuit_pytree_ebc_regroup` function finds all the EBC/fpEBC and KTRegroupAsDict modules in an unflattened module, retrieves their fqns, and sorts them into in_fqns (regroup_fqns) and out_fqns (ebc_fqns). Because the fpEBC is currently swapped as a whole, some extra fqn logic filters out any EBC that belongs to an enclosing fpEBC.
* A util function `prune_pytree_flatten_unflatten` removes the incoming and outgoing pytree flatten/unflatten function calls in the graph module, based on the given fqns. WARNING: The flag `short_circuit_pytree_ebc_regroup` should be turned on if EBCs are used and EBC sharding is needed. Assertions are also added for when a `KTRegroupAsDict` module can't be found, or `finalize_interpreter_modules` is not `True`.

# additional changes
* Absorb the `finalize_interpreter_modules` process inside the torchrec main API `decapsulate_ir_modules`.
* Set `graph.owning_module` in export.unflatten, as required by the graph modification.
* Add one more layer of `sparse_module` to closely mimic the APF model structure.
Differential Revision: D62606738
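The fqn filtering described in the details above (dropping EBC fqns nested under an fpEBC, since the fpEBC is swapped as a whole) amounts to prefix matching on dotted module paths. The following is a hedged sketch; the function name and the example fqns are hypothetical, not taken from the torchrec implementation:

```python
# Hypothetical sketch: keep only EBC fqns that are NOT submodules of any
# fpEBC, because an fpEBC is swapped as one unit and already covers its
# inner EBC. A fqn "a.b.c" is under "a.b" iff it equals it or extends it
# with a "." separator (plain startswith would wrongly match "a.bb").
def filter_nested_ebc_fqns(ebc_fqns, fpebc_fqns):
    return [
        fqn
        for fqn in ebc_fqns
        if not any(fqn == p or fqn.startswith(p + ".") for p in fpebc_fqns)
    ]


ebc_fqns = ["sparse.ebc", "sparse.fpebc.ebc", "tower.ebc"]
fpebc_fqns = ["sparse.fpebc"]
# "sparse.fpebc.ebc" is filtered out; the other two stand alone.
print(filter_nested_ebc_fqns(ebc_fqns, fpebc_fqns))  # ['sparse.ebc', 'tower.ebc']
```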
18cdef6 to 0ce7346 (Compare)
TroyGarden added a commit to TroyGarden/pytorch that referenced this pull request on Sep 13, 2024:
…BC and KTRegroupAsDict modules (pytorch#136045)

Summary:
Pull Request resolved: pytorch#136045
X-link: pytorch/torchrec#2393
(The summary is identical to the torchrec commit message above.)

Test Plan:
# run test
* serializer
```
buck2 run fbcode//mode/opt fbcode//torchrec/ir/tests:test_serializer
```
* apf
```
buck2 run fbcode//mode/opt fbcode//aps_models/ads/gmp/tests/ne/e2e_deterministic_tests:gmp_e2e_ne_tests -- --filter-text 'test_mtml_instagram_model_562438350_single_gpu_with_ir'
```
* local mp run
```
==== Finished E2E deterministic test for mtml_instagram_model_gmp_474023725_non_kjt_unary ====
finished test_mtml_instagram_model_562438350_single_gpu_with_ir
Executed 1 example in 203.1s: Successful: 1, Failed: 0, Skipped: 0, Not executed: 8
https://testslide.readthedocs.io/
```

Differential Revision: D62606738
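The short-circuit itself can be illustrated on a toy graph: when an EBC output is flattened and the flat value is immediately unflattened for KTRegroupAsDict, the pair is a round-trip no-op, so the consumer can be rewired directly to the producer. This is a pure-Python sketch under a simplified tuple-based graph representation, not the FX-graph code behind `prune_pytree_flatten_unflatten`:

```python
# Toy illustration: nodes are (op, input_name, output_name) tuples in
# topological order. A flatten whose output feeds straight into an
# unflatten is a no-op pair; we drop both nodes and alias the unflatten
# output to the original (pre-flatten) value so consumers read it directly.
def short_circuit(nodes):
    alias = {}             # output name -> the value it aliases
    pending_flatten = {}   # flatten output -> flatten input
    kept = []
    for op, inp, out in nodes:
        inp = alias.get(inp, inp)  # follow aliases before processing
        if op == "flatten":
            pending_flatten[out] = inp
            continue
        if op == "unflatten" and inp in pending_flatten:
            # flatten -> unflatten round trip: skip both, record the alias
            alias[out] = pending_flatten[inp]
            continue
        kept.append((op, inp, out))
    return kept


graph = [
    ("ebc", "features", "kts"),
    ("flatten", "kts", "flat"),
    ("unflatten", "flat", "kts2"),
    ("regroup", "kts2", "out"),
]
print(short_circuit(graph))
# [('ebc', 'features', 'kts'), ('regroup', 'kts', 'out')]
```

The same idea applies per-fqn in the real pass: only the flatten/unflatten calls sitting between the preserved EBC and KTRegroupAsDict modules are pruned, which is why the fqn bookkeeping above matters.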
…odules (pytorch#2393)

Summary: X-link: pytorch/pytorch#136045. Pull Request resolved: pytorch#2393. (The summary is identical to the commit message above.)

Reviewed By: malaybag

Differential Revision: D62606738
0ce7346 to 4ea2aa5 (Compare)
Labels: CLA Signed, fb-exported
(The CLA Signed label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed.)