You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Extends type and shape tracing with device (pytorch#9796)
Summary:
This PR extends the existing type and shape metadata tracing and verification done in autograd with device information. This expansion of tracing is required for pytorch#8354, is likely useful in other scenarios, and is a healthy sanity check, just like type and shape tracing.
The precise changes are:
- TypeAndShape -> InputMetadata, now includes device()
- Creating InputMetadata is simplified to just require a tensor, and callers were updated to use this simpler invocation wherever possible
- The gradient accumulator of a variable is now reset when set_data() is called if either the type or device changes, and this reset now locks to avoid contention with acquiring the gradient accumulator
- Mismatched devices during backward() will throw a runtime error, just like mismatched type and shape
- (Bonus!) Two uninitialized pointers in THCReduce are now initialized (to nullptr) to prevent build warnings
fyi colesbury
Pull Request resolved: pytorch#9796
Reviewed By: goldsborough
Differential Revision: D9119325
Pulled By: ezyang
fbshipit-source-id: 76d1861b8d4f74db0575ff1f3bd965e18f9463de
0 commit comments