Adapt code to use NestedTensor #313
Comments
Hello @bubas3000, Yes! This should be possible :) What would be the shape of `inputs`? Thanks,
Hello @cpuhrsch, Thanks for your quick reply! The shape of inputs is (12000, 100, 9); I will try to explain what each dimension means to be clearer. The linear layer transforms the 9 dimensions to 64 (it is applied pointwise). Thank you for your help,
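(For readers following the thread: "pointwise" here means the linear layer acts only on the last dimension. A quick illustration with the shapes quoted in this message — the variable names are just for the example:)

```python
import torch
import torch.nn as nn

# nn.Linear acts on the last dimension only, so it maps each of the
# 12000 * 100 points' 9 features to 64 features independently.
linear = nn.Linear(9, 64)
x = torch.randn(12000, 100, 9)
print(linear(x).shape)  # torch.Size([12000, 100, 64])
```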
I forgot to mention that norm is BatchNorm1d. One more question, if I may: how does the time performance of nested tensors compare to normal torch tensors? Thank you once again,
Hello @bubas3000, I wrote up a code snippet and am now working on adding the ops required to do this. For now this is without autograd support, which will follow in another PR. Here is the snippet from the PR referenced in this issue:
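(The snippet itself was not preserved in this copy of the thread. As a rough stand-in, here is a minimal sketch of the same computation done the conventional way in plain PyTorch, with padding and masking; the module names and sample lengths are invented for illustration. This is the padded baseline a NestedTensor version would replace.)

```python
import torch
import torch.nn as nn

# Stand-in for the model described in the issue: Linear(9 -> 64) applied
# pointwise, BatchNorm1d over the 64 channels, ReLU, then a max reduction.
linear = nn.Linear(9, 64)
norm = nn.BatchNorm1d(64)
relu = nn.ReLU()

# Variable-length point sets, padded to the longest one -- exactly the
# padding a NestedTensor avoids. The lengths here are made up.
samples = [torch.randn(n, 9) for n in (80, 100, 37)]
lengths = torch.tensor([s.shape[0] for s in samples])
padded = nn.utils.rnn.pad_sequence(samples, batch_first=True)  # (B, L_max, 9)

x = linear(padded)                            # (B, L_max, 64), pointwise
x = norm(x.transpose(1, 2)).transpose(1, 2)   # BatchNorm1d expects (B, C, L)
x = relu(x)

# Note: the padded positions still leak into the batch-norm statistics above;
# masking only protects the final reduction. NestedTensor sidesteps this.
mask = torch.arange(x.shape[1])[None, :] < lengths[:, None]   # (B, L_max)
x = x.masked_fill(~mask[..., None], float("-inf"))
out = x.max(dim=1).values                     # (B, 64), max over points
print(out.shape)
```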
Does this align with your goals? Thanks,
Hello @cpuhrsch, That's what I am looking for, thank you! Is autograd support expected soon, or should I try to do it "by hand"? Thank you once more,
Hello @bubas3000, Autograd is already supported, but I need to double check that all backward passes have been implemented. The forward PR was merged, so I'm doing that next. Regarding time performance, most of these kernels are currently still implemented as for-loops. However, let me trace through the ops you're using and see if we can implement a fast path for those shapes. As an aside, BatchNorm1d will be the least likely to match the performance of a regular torch.Tensor, because PyTorch calls into cuDNN's highly optimized version of it. To support irregular shapes, BatchNorm1d here is implemented via regular math operators. Thank you,
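(To make "implemented via regular math operators" concrete, here is a sketch of training-mode batch norm written out with elementary ops. This illustrates the decomposition only, not the library's actual code; the function name is made up.)

```python
import torch

def batch_norm_1d_decomposed(x, weight, bias, eps=1e-5):
    # x: (B, C, L). Training-mode batch norm: normalize per channel over the
    # batch and length dimensions, then scale and shift. Elementary ops like
    # these can be applied entry-by-entry to irregular shapes, unlike cuDNN's
    # fused kernel, which assumes one rectangular tensor.
    mean = x.mean(dim=(0, 2), keepdim=True)
    var = x.var(dim=(0, 2), unbiased=False, keepdim=True)
    x_hat = (x - mean) / torch.sqrt(var + eps)
    return x_hat * weight[None, :, None] + bias[None, :, None]

# Sanity check against nn.BatchNorm1d in training mode:
bn = torch.nn.BatchNorm1d(64)
x = torch.randn(8, 64, 100)
print(torch.allclose(bn(x),
                     batch_norm_1d_decomposed(x, bn.weight, bn.bias, bn.eps),
                     atol=1e-5))
```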
Hello @cpuhrsch, Using for-loops will definitely hurt my time performance... I can try to run without BatchNorm1d if you think it would help. Thank you,
PS: I tried to run this snippet and got the following error: Traceback (most recent call last):
Hello @bubas3000, Are you using the most recent commit? If you're using the binaries, make sure to force a clean reinstall to get the newest ones (they get automatically rebuilt overnight). You can check by printing the version+hash. Thanks,
Hi @cpuhrsch, Thank you,
Hello @bubas3000, I'm happy to hear that! It's enough to cite this Git repository; there is no paper yet. Would you be willing to share your solution? We could use it as a baseline for future performance improvements. Thank you,
I have a model where I would love to use NestedTensor. There is a lot of padding going on, and nested tensors would save a lot of memory. The net where I would like to use them is composed of a linear layer followed by batch norm and ReLU; finally, a max operation is done over the channels.
Forward looks like this:
def forward(self, inputs):
Is it possible to use NestedTensor? The project supports Python 3.6+ and PyTorch 0.4.1+.
Thank you in advance
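(The body of forward was not preserved above. Here is a minimal sketch of what it might look like, reconstructed purely from the description in this issue; the class name, the attributes self.linear, self.norm, and self.relu, and the reduction dimension are all assumptions.)

```python
import torch
import torch.nn as nn

class PointBlock(nn.Module):
    # Hypothetical module matching the description: Linear(9 -> 64),
    # BatchNorm1d, ReLU, then a max reduction. Names are invented.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(9, 64)
        self.norm = nn.BatchNorm1d(64)
        self.relu = nn.ReLU()

    def forward(self, inputs):                            # inputs: (B, L, 9)
        x = self.linear(inputs)                           # (B, L, 64), pointwise
        x = self.norm(x.transpose(1, 2)).transpose(1, 2)  # BatchNorm1d wants (B, C, L)
        x = self.relu(x)
        # The issue says "max over the channels"; dim=1 (over points) is the
        # common PointNet-style choice -- adjust the dim if channels was meant.
        return x.max(dim=1).values                        # (B, 64)

# Small batch for the example; the issue's full shape is (12000, 100, 9).
print(PointBlock()(torch.randn(4, 100, 9)).shape)  # torch.Size([4, 64])
```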