From f406b88f3237d4ecd7661b4b8911eaf441935973 Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Wed, 19 Mar 2025 15:06:05 -0700 Subject: [PATCH] Fixes #3303 --- advanced_source/semi_structured_sparse.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/advanced_source/semi_structured_sparse.py b/advanced_source/semi_structured_sparse.py index 38c2c6878b..919cde01a2 100644 --- a/advanced_source/semi_structured_sparse.py +++ b/advanced_source/semi_structured_sparse.py @@ -54,6 +54,8 @@ from torch.utils.benchmark import Timer SparseSemiStructuredTensor._FORCE_CUTLASS = True +torch.set_default_device("cuda:0") + # mask Linear weight to be 2:4 sparse mask = torch.Tensor([0, 0, 1, 1]).tile((3072, 2560)).cuda().bool() linear = torch.nn.Linear(10240, 3072).half().cuda().eval()