Skip to content

torchrec support on kvzch emb lookup module #4035

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

duduyi2013
Copy link
Contributor

Summary:

Change logs

  1. add ZeroCollisionKeyValueEmbedding emb lookup
  2. address existing unit test missing for ssd offloading
  3. add new ut for kv zch embedding module
  4. add a temp hack solution for calculate bucket metadata
  5. embedding updates, details illustrated below

#######################################################################
########################### embedding.py updates ##########################
#######################################################################

  1. keep the original idea to init shardedTensor during training init
  2. for kv zch table, the shardeTensor will be init using virtual size for metadata calculation, and skip actual tensor size check for ST init, this is needed as during training init, the table has 0 rows
  3. the new tensor, weight_id will not be registered in the EC becuase its shape is changing in realtime, the weight_id tensor will be generated in post_state_dict hooks
  4. the new tensor, bucket could be registered and preserved, but in this diff we keep it the same way as weight_id
  5. in post state dict hook, we call get_named_split_embedding_weights_snapshot to get Tuple[table_name, weight(ST), weight_id(ST), bucket(ST)], all 3 tensors are return in the format of ST, and we will update destination with the returned ST directly
  6. in pre_load_state_dict_hook, which is called upon load_state_dict(), we will skip all 3 tensors update, because the tensor assignment is done on the nn.module side, which doesn't support updating KVT through PMT. This is fine for now because, checkpoint loading will be done outside of the load_state_dict call, but we need future plans to make it work cohesively with other type of tensors

Differential Revision: D73567631

Summary:
change list
1. add bucket concept into ssd tbe
2. update split_embedding_weights to make it return a tuple of 3 tensors(weight, weight_id, id_cnt_per_bucket)
3. add new ut for the key value embedding cases

Differential Revision: D73274786
Summary:
# Change logs
1. add ZeroCollisionKeyValueEmbedding emb lookup
2. address existing unit test missing for ssd offloading
3. add new ut for kv zch embedding module
4. add a temp hack solution for calculate bucket metadata
5. embedding updates, details illustrated below

#######################################################################
###########################  embedding.py updates ##########################
#######################################################################

1. keep the original idea to init shardedTensor during training init
2. for kv zch table, the shardeTensor will be init using virtual size for metadata calculation, and skip actual tensor size check for ST init, this is needed as during training init, the table has 0 rows
3. the new tensor, weight_id will not be registered in the EC becuase its shape is changing in realtime, the weight_id tensor will be generated in post_state_dict hooks
4. the new tensor, bucket could be registered and preserved, but in this diff we keep it the same way as weight_id
5. in post state dict hook, we call get_named_split_embedding_weights_snapshot to get Tuple[table_name, weight(ST), weight_id(ST), bucket(ST)], all 3 tensors are return in the format of ST, and we will update destination with the returned ST directly
6. in pre_load_state_dict_hook, which is called upon load_state_dict(), we will skip all 3 tensors update, because the tensor assignment is done [on the nn.module side](https://fburl.com/code/it5nior8), which doesn't support updating KVT through PMT. This is fine for now because, checkpoint loading will be done outside of the load_state_dict call, but we need future plans to make it work cohesively with other type of tensors

Differential Revision: D73567631
Copy link

netlify bot commented Apr 28, 2025

Deploy Preview for pytorch-fbgemm-docs ready!

Name Link
🔨 Latest commit 26b0b47
🔍 Latest deploy log https://app.netlify.com/sites/pytorch-fbgemm-docs/deploys/680fda70d6988b0008722fca
😎 Deploy Preview https://deploy-preview-4035--pytorch-fbgemm-docs.netlify.app
📱 Preview on mobile
Toggle QR Code...

QR Code

Use your smartphone camera to open QR code link.

To edit notification comments on pull requests, go to your Netlify site configuration.

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D73567631

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants