[DeepSeek][kernels] index select permute, cuda #1083
Conversation
Do you intend to upload this binary?
Thanks for flagging... not really - I just forgot to remove it before check-in. Let me add a .gitignore so I don't have to do it manually.
}

// medium kernel - multiple tokens per block
template <typename scalar_t, int TOKENS_PER_BLOCK, int THREADS_PER_BLOCK>
Not sure why we'd need to template on THREADS_PER_BLOCK. Can you access it via blockDim.x in the kernel?
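As a rough sketch of that suggestion (hypothetical kernel and parameter names, not the PR's actual code), the block size becomes a launch-time value read from blockDim.x instead of a compile-time template parameter:

#include <cstdint>

// Sketch only: grid-stride over features, with the thread count taken from
// blockDim.x at launch instead of a THREADS_PER_BLOCK template parameter.
template <typename scalar_t, int TOKENS_PER_BLOCK>
__global__ void permute_medium_sketch(const scalar_t* __restrict__ input,
                                      scalar_t* __restrict__ output,
                                      const int64_t* __restrict__ src_indices,
                                      int64_t num_tokens,
                                      int64_t feature_size) {
  for (int t = 0; t < TOKENS_PER_BLOCK; ++t) {
    const int64_t token_idx = int64_t(blockIdx.x) * TOKENS_PER_BLOCK + t;
    if (token_idx >= num_tokens) return;
    const int64_t src_idx = src_indices[token_idx];
    // blockDim.x plays the role of the former THREADS_PER_BLOCK constant.
    for (int64_t i = threadIdx.x; i < feature_size; i += blockDim.x) {
      output[token_idx * feature_size + i] = input[src_idx * feature_size + i];
    }
  }
}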
// each thread loads features into smem
for (int i = thread_idx; i < feature_size; i += BLOCK_SIZE) {
  shared_features[i] = input[src_idx * feature_size + i];
}
You may be able to use cp.async.ca.shared.global to speed up the copy from global mem to shared mem (it skips the registers). Alternatively, you can use the Pipeline Primitive Interface: __pipeline_memcpy_async. Some examples here.
You may not see an instant speedup at the current small sizes, but one opportunity is extending the coverage of your small kernel to medium sizes by using the above instructions to form a load-chunk-store-chunk pipeline.
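For reference, a minimal sketch of what the __pipeline_memcpy_async route could look like for the shared-memory staging (hypothetical kernel name; assumes 4-byte elements, since the primitive only supports 4/8/16-byte copies, and an sm_80+ target for the hardware cp.async path). A chunked, double-buffered variant would extend this same pattern:

#include <cstdint>
#include <cuda_pipeline.h>

// Sketch only: stage one token's features into shared memory with
// async copies that bypass registers, then write them back out.
__global__ void permute_small_async_sketch(const float* __restrict__ input,
                                           float* __restrict__ output,
                                           const int64_t* __restrict__ src_indices,
                                           int64_t feature_size) {
  extern __shared__ float shared_features[];
  const int64_t token_idx = blockIdx.x;
  const int64_t src_idx = src_indices[token_idx];

  // each thread issues 4-byte global->shared copies asynchronously
  for (int64_t i = threadIdx.x; i < feature_size; i += blockDim.x) {
    __pipeline_memcpy_async(&shared_features[i],
                            &input[src_idx * feature_size + i],
                            sizeof(float));
  }
  __pipeline_commit();
  __pipeline_wait_prior(0);  // this thread's copies have landed
  __syncthreads();           // make the staged row visible to the whole block

  for (int64_t i = threadIdx.x; i < feature_size; i += blockDim.x) {
    output[token_idx * feature_size + i] = shared_features[i];
  }
}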
cc @ngimel. I wonder if index_select uses shared mem?
index_select in PyTorch has a very slow kernel that no one bothered to update (we should). However, the Triton code generated by torch.compile for index_select runs pretty much at the bandwidth limit, so I don't think async copy or TMA are really needed here. Conceptually, if you are permuting contiguous slices (like you do here), the indexing math is pretty light, so there's nothing to hide with async copy; just vectorized stores will do.
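To illustrate that point, a sketch of a plain row-copy kernel using 16-byte vectorized loads/stores (hypothetical names; assumes each row's byte size is a multiple of 16 and the pointers are 16-byte aligned):

#include <cstdint>

// Sketch only: copy whole contiguous rows with uint4 (16-byte) accesses.
// vecs_per_row = feature_size * sizeof(scalar) / 16.
__global__ void permute_rows_vec4_sketch(const uint4* __restrict__ input,
                                         uint4* __restrict__ output,
                                         const int64_t* __restrict__ src_indices,
                                         int64_t vecs_per_row) {
  const int64_t token_idx = blockIdx.x;
  const int64_t src_idx = src_indices[token_idx];
  const uint4* src_row = input + src_idx * vecs_per_row;
  uint4* dst_row = output + token_idx * vecs_per_row;
  // the indexing math above is cheap; the loop is just 16-byte loads/stores
  for (int64_t i = threadIdx.x; i < vecs_per_row; i += blockDim.x) {
    dst_row[i] = src_row[i];
  }
}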
// wait for everyone to load...
__syncthreads();

// each thread writes features to output
for (int i = thread_idx; i < feature_size; i += BLOCK_SIZE) {
  output[token_idx * feature_size + i] = shared_features[i];
}
It seems there is no need for __syncthreads: the thread that loads shared_features[i] also takes care of writing out shared_features[i]. General comment: maybe we can use unrolling to speed up the kernels too.
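A sketch of that simplification (hypothetical names): since each thread writes exactly the elements it loaded, the shared-memory round trip and the barrier can be dropped, and the strided loop can be unrolled:

#include <cstdint>

// Sketch only: each thread copies exactly the elements it reads,
// so no shared memory and no __syncthreads() are needed.
template <typename scalar_t>
__global__ void permute_small_direct_sketch(const scalar_t* __restrict__ input,
                                            scalar_t* __restrict__ output,
                                            const int64_t* __restrict__ src_indices,
                                            int64_t feature_size) {
  const int64_t token_idx = blockIdx.x;
  const int64_t src_idx = src_indices[token_idx];
#pragma unroll 4
  for (int64_t i = threadIdx.x; i < feature_size; i += blockDim.x) {
    output[token_idx * feature_size + i] = input[src_idx * feature_size + i];
  }
}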
This optimizes performance for an extremely common function, and as such should go into PyTorch core rather than torchtitan. Additionally, torch.compile already provides better performance than these kernels, so given that torchtitan already relies on torch.compile, the benefit of adding precompiled implementations is unclear.
import torch

def fn_index_select(x, indices):
    return torch.index_select(x, dim=0, index=indices)

def fn_gather(x, indices):
    return torch.gather(x, dim=0, index=indices.unsqueeze(1).expand(-1, x.shape[1]))

def fn_scatter(dst, src, indices):
    return dst.scatter_(0, indices, src)

def fn_index(x, indices):
    return x[indices]

def scatter_vals(M, N, K):
    src = torch.randn([M, K], device="cuda", dtype=torch.float16).abs().float()
    dst = torch.randn([N, K], device="cuda", dtype=torch.float16).abs().float()
    dst_eager = dst.clone()
    dst_compile = dst.clone()
    if M == N:
        indices = torch.randperm(N, device="cuda", dtype=torch.int64)
    else:
        indices = torch.randint(0, M, [N], device="cuda", dtype=torch.int64)
    indices = indices.unsqueeze(1).expand(-1, K)
    compiled_fn = torch.compile(fn_scatter)
    compiled_fn(dst_compile, src, indices)
    fn_scatter(dst_eager, src, indices)
    with torch.profiler.profile() as p:
        for _ in range(10):
            fn_scatter(dst_eager, src, indices)
            compiled_fn(dst_compile, src, indices)
    print("export trace")
    p.export_chrome_trace("/home/ngimel/sandbox/trace.json")
    #print(p.events())

def gather_vals(M, N, K):
    src = torch.randn([M, K], device="cuda", dtype=torch.bfloat16).abs()
    if M == N:
        indices = torch.randperm(N, device="cuda", dtype=torch.int64)
    else:
        indices = torch.randint(0, M, [N], device="cuda", dtype=torch.int64)
    compiled_fn = torch.compile(fn_gather)
    compiled_fn(src, indices)
    res_index_select = fn_index_select(src, indices)
    res_gather = fn_gather(src, indices)
    torch.testing.assert_close(res_index_select, res_gather)
    with torch.profiler.profile() as p:
        for _ in range(10):
            fn_gather(src, indices)
            fn_index(src, indices)
            compiled_fn(src, indices)
    from torch.utils.benchmark import Timer
    t = Timer(stmt="fn_index(src, indices)", globals={"src": src, "indices": indices, "fn_index": fn_index})
    print("Torch time", t.blocked_autorange().mean)
    t = Timer(stmt="compiled_fn(src, indices)", globals={"src": src, "indices": indices, "compiled_fn": compiled_fn})
    print("Compiled time", t.blocked_autorange().mean)
    p.export_chrome_trace(f"trace{M}_{N}_{K}.json")
    #print(p.events())

#gather_vals(4096, 4096, 5120)
gather_vals(1024, 512, 4096)
gather_vals(4096, 4096, 4096)
gather_vals(8192, 8192, 4096)
#scatter_vals(8192, 8192, 5120)
gather_vals(1024, 512, 4096):
Torch time 1.126700707245618e-05
Compiled time 2.5844732392579318e-05

gather_vals(4096, 4096, 4096):
Torch time 8.486244548112154e-05
Compiled time 3.984669018536806e-05

gather_vals(8192, 8192, 4096):
Torch time 0.0001635360140353441
Compiled time 6.696143606677652e-05
(Ignore the printed compile benchmark for the smallest shape; it's measuring benchmarking overhead. The kernel itself is 2 us, as seen in the profile.)
@ngimel - thanks for the info above! 2 - "Additionally, torch.compile already provides better performance than these kernels, so given that torchtitan already relies on torch.compile, the benefit of adding precompiled implementations is unclear." Does it make sense to just upgrade the PyTorch eager core code with the compile-generated Triton kernel as a generic fix, so that folks who don't use compile can get a faster experience?
We still try to avoid calling Triton kernels in eager, given that they come with unpredictable recompilation. However, the Triton kernel is proof that vectorized loads/stores with simple 1D blocking are enough to get good perf, so we can just write that simple kernel in eager.
This PR adds an index_select_permute operation termed fast_permute_tokens.
Basically, we do an index select on tokens to move them into contiguous memory grouped by expert.
I found that, based on the problem size, one kernel does not fit all if we want to beat PyTorch. So this features an adaptive kernel that is really four kernels, selected based on problem size.
Each kernel is small but progressively adds the improvements needed to scale with the problem size (a sketch of the selection logic follows the list):
Small = use shared memory for transfer, one thread per element
Medium = multiple tokens per thread block
Large = 2D grid, where each thread handles a tile of (tokens x features)
XL = 2D grid with templated vectorized memory
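As referenced above, a sketch of the size-based selection logic (the thresholds and names here are placeholders for illustration, not the PR's tuned values):

#include <cstdint>

// Sketch only: size-based kernel selection with placeholder thresholds.
enum class PermuteKernel { Small, Medium, Large, XL };

inline PermuteKernel pick_permute_kernel(int64_t num_tokens, int64_t feature_size) {
  const int64_t total_elems = num_tokens * feature_size;
  if (total_elems < (int64_t(1) << 16)) return PermuteKernel::Small;   // smem, one token per block
  if (total_elems < (int64_t(1) << 20)) return PermuteKernel::Medium;  // multiple tokens per block
  if (total_elems < (int64_t(1) << 24)) return PermuteKernel::Large;   // 2D grid of (token x feature) tiles
  return PermuteKernel::XL;                                            // 2D grid + vectorized memory access
}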
Initial perf and verification results: