diff --git a/advanced_source/semi_structured_sparse.py b/advanced_source/semi_structured_sparse.py index 38c2c6878b..919cde01a2 100644 --- a/advanced_source/semi_structured_sparse.py +++ b/advanced_source/semi_structured_sparse.py @@ -54,6 +54,8 @@ from torch.utils.benchmark import Timer SparseSemiStructuredTensor._FORCE_CUTLASS = True +torch.set_default_device("cuda:0") + # mask Linear weight to be 2:4 sparse mask = torch.Tensor([0, 0, 1, 1]).tile((3072, 2560)).cuda().bool() linear = torch.nn.Linear(10240, 3072).half().cuda().eval()