
[DeepSeek][kernels] index select permute, cuda #1083


Open. Wants to merge 9 commits into base: main

Conversation

lessw2020 (Contributor)

This PR adds an index_select_permute operation, termed fast_permute_tokens.
It performs an index select on the tokens to gather them into contiguous memory, grouped by expert.

I found that, depending on the problem size, one kernel does not fit all if we want to beat PyTorch. So this features an adaptive kernel that is really four kernels, selected based on problem size.

Each kernel is small, but each successive variant adds the optimizations needed to scale with the problem size (a rough sketch of the dispatch idea follows the list below):
Small = use shared memory for transfer, one thread per element
Medium = multiple tokens per thread block
Large = 2D grid, where each thread handles a tile of (tokens x features)
XL = 2D grid with templated vectorized memory
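
For illustration, here is a rough CUDA sketch of the adaptive idea: two of the four variants (small and large) plus a size-based dispatch. The kernel names, block sizes, and threshold are hypothetical placeholders, not the PR's actual code; the medium and XL paths would extend this with multiple tokens per block and vectorized (e.g. float4) loads/stores.

#include <cuda_runtime.h>
#include <cstdint>

// Hypothetical sketch of the adaptive scheme; names and thresholds are illustrative.

// Small: one token per block, staged through shared memory.
template <int BLOCK_SIZE>
__global__ void permute_small(const float* __restrict__ input,
                              float* __restrict__ output,
                              const int64_t* __restrict__ indices,
                              int feature_size) {
  extern __shared__ float smem[];
  const int token = blockIdx.x;
  const int64_t src = indices[token];
  for (int i = threadIdx.x; i < feature_size; i += BLOCK_SIZE)
    smem[i] = input[src * feature_size + i];
  for (int i = threadIdx.x; i < feature_size; i += BLOCK_SIZE)
    output[token * feature_size + i] = smem[i];
}

// Large: 2D grid, blockIdx.y selects the token, blockIdx.x tiles the features.
__global__ void permute_large(const float* __restrict__ input,
                              float* __restrict__ output,
                              const int64_t* __restrict__ indices,
                              int feature_size) {
  const int token = blockIdx.y;
  const int col = blockIdx.x * blockDim.x + threadIdx.x;
  if (col < feature_size)
    output[token * feature_size + col] =
        input[indices[token] * feature_size + col];
}

void fast_permute_tokens(const float* input, float* output,
                         const int64_t* indices, int n_indices,
                         int feature_size, cudaStream_t stream) {
  const size_t work = size_t(n_indices) * feature_size;
  if (work < (1u << 22)) {  // made-up threshold; the PR tunes this per size class
    permute_small<256><<<n_indices, 256, feature_size * sizeof(float), stream>>>(
        input, output, indices, feature_size);
  } else {
    dim3 grid((feature_size + 255) / 256, n_indices);
    permute_large<<<grid, 256, 0, stream>>>(input, output, indices, feature_size);
  }
}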

Initial perf and verification results:

=== Small Configuration ===
Benchmarking with batch_size=1024, hidden_dim=4096, n_indices=512
Verifying CUDA implementation matches PyTorch...
✓ Results match!
PyTorch:  0.020 ± 0.001 ms
CUDA:     0.011 ± 0.000 ms
Speedup:  1.84x

=== Medium Configuration ===
Benchmarking with batch_size=4096, hidden_dim=4096, n_indices=4096
Verifying CUDA implementation matches PyTorch...
✓ Results match!
PyTorch:  0.085 ± 0.002 ms
CUDA:     0.042 ± 0.001 ms
Speedup:  2.01x

=== Large Configuration ===
Benchmarking with batch_size=8192, hidden_dim=4096, n_indices=8192
Verifying CUDA implementation matches PyTorch...
✓ Results match!
PyTorch:  0.163 ± 0.001 ms
CUDA:     0.068 ± 0.001 ms
Speedup:  2.38x

facebook-github-bot added the "CLA Signed" label (managed by the Meta Open Source bot) on Apr 9, 2025
lessw2020 requested a review from kwen2501 on April 10, 2025 03:44
Contributor

Do you intend to upload this binary?

lessw2020 (Contributor Author)

Thanks for flagging... not really, I just forgot to remove it at check-in. Let me add a .gitignore so I don't have to remove it manually.

}

// medium kernel - multiple tokens per block
template <typename scalar_t, int TOKENS_PER_BLOCK, int THREADS_PER_BLOCK>
Contributor

Not sure why we'd need to template on THREADS_PER_BLOCK. Can you access it via blockDim.x in the kernel?
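
As a point of reference, a minimal sketch of that suggestion, with the stride taken from blockDim.x at runtime rather than a THREADS_PER_BLOCK template parameter (kernel name and signature here are hypothetical; the compile-time constant mainly helps if the loop is meant to be fully unrolled):

// Hypothetical variant of the medium kernel: stride taken from blockDim.x,
// so the kernel compiles once and works for any launch configuration.
template <typename scalar_t, int TOKENS_PER_BLOCK>
__global__ void medium_kernel_runtime_stride(const scalar_t* __restrict__ input,
                                             scalar_t* __restrict__ output,
                                             const int64_t* __restrict__ indices,
                                             int n_indices, int feature_size) {
  for (int t = 0; t < TOKENS_PER_BLOCK; ++t) {
    const int token = blockIdx.x * TOKENS_PER_BLOCK + t;
    if (token >= n_indices) return;
    const int64_t src = indices[token];
    for (int i = threadIdx.x; i < feature_size; i += blockDim.x)  // runtime stride
      output[token * feature_size + i] = input[src * feature_size + i];
  }
}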

Comment on lines +32 to +35
// each thread loads features into smem
for (int i = thread_idx; i < feature_size; i += BLOCK_SIZE) {
shared_features[i] = input[src_idx * feature_size + i];
}
Contributor

You may be able to use cp.async.ca.shared.global to speedup the copy from global mem to shared mem. (It skips the registers).
Alternatively, you can use the Pipeline Primitive Interface: __pipeline_memcpy_async.
Some examples here.
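
For concreteness, the quoted load loop could look roughly like the following with the pipeline primitives (requires <cuda_pipeline.h> and Ampere or newer). This reuses the reviewed snippet's names and assumes scalar_t is at least 4 bytes, since __pipeline_memcpy_async only accepts 4-, 8-, or 16-byte copies:

// Hypothetical rewrite of the quoted smem load: each thread issues async
// global->shared copies that bypass registers, then waits before using smem.
for (int i = thread_idx; i < feature_size; i += BLOCK_SIZE) {
  __pipeline_memcpy_async(&shared_features[i],
                          &input[src_idx * feature_size + i],
                          sizeof(scalar_t));  // size must be 4, 8, or 16 bytes
}
__pipeline_commit();
__pipeline_wait_prior(0);  // all issued copies complete before smem is read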

kwen2501 (Contributor) commented Apr 10, 2025

You may not see an instant speedup at the current small sizes, but one opportunity may be extending the coverage of your small kernel to medium sizes, using the above instructions to form a load-chunk/store-chunk pipeline.

Contributor

cc @ngimel.
I wonder if index_select uses shared mem?

ngimel

index_select in pytorch has a very slow kernel that no one bothered to update (we should). However, the triton code generated by torch.compile for index_select runs pretty much at the bandwidth limit, so I don't think async copy or TMA is really needed here. Conceptually, if you are permuting contiguous slices (like you do here), the indexing math is pretty light, so there's nothing to hide with async copy; vectorized stores alone will do.
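
As a minimal sketch of the vectorized-store point (illustrative kernel, not from the PR; assumes float data and feature_size divisible by 4):

// Illustrative float4 variant: each thread moves 16 bytes per iteration,
// with no shared-memory staging. Assumes feature_size % 4 == 0.
__global__ void permute_vec4(const float* __restrict__ input,
                             float* __restrict__ output,
                             const int64_t* __restrict__ indices,
                             int feature_size) {
  const int token = blockIdx.x;
  const float4* src =
      reinterpret_cast<const float4*>(input + indices[token] * feature_size);
  float4* dst = reinterpret_cast<float4*>(output + token * feature_size);
  const int vec_len = feature_size / 4;
  for (int i = threadIdx.x; i < vec_len; i += blockDim.x)
    dst[i] = src[i];  // 16-byte vectorized load and store
}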

Comment on lines +37 to +43
// wait for everyone to load...
__syncthreads();

// each thread writes features to output
for (int i = thread_idx; i < feature_size; i += BLOCK_SIZE) {
output[token_idx * feature_size + i] = shared_features[i];
}
Contributor

It seems there is no need for __syncthreads: the thread that loads shared_features[i] also takes care of writing shared_features[i] out.
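
One way to read that, reusing the snippet's variable names:

// each thread loads features into smem
for (int i = thread_idx; i < feature_size; i += BLOCK_SIZE) {
  shared_features[i] = input[src_idx * feature_size + i];
}
// no __syncthreads(): no thread reads an element written by another thread
for (int i = thread_idx; i < feature_size; i += BLOCK_SIZE) {
  output[token_idx * feature_size + i] = shared_features[i];
}
// arguably the shared-memory staging then adds nothing, and the two loops
// could be fused into a single global->global copy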

kwen2501 (Contributor)

General comment: maybe we can use unrolling to speed up the kernels too.
The reason is that these kernels mainly do memory access; unrolling can keep more requests in flight to cover the round-trip latency.
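
A hedged sketch of what such unrolling could look like (hypothetical kernel, float data): each thread issues UNROLL independent loads before any store, so more memory transactions are in flight per thread.

template <int UNROLL>
__global__ void permute_unrolled(const float* __restrict__ input,
                                 float* __restrict__ output,
                                 const int64_t* __restrict__ indices,
                                 int feature_size) {
  const int token = blockIdx.x;
  const int64_t src = indices[token];
  for (int base = 0; base < feature_size; base += blockDim.x * UNROLL) {
    float vals[UNROLL];
#pragma unroll
    for (int u = 0; u < UNROLL; ++u) {
      const int i = base + u * blockDim.x + threadIdx.x;  // coalesced per step
      if (i < feature_size) vals[u] = input[src * feature_size + i];
    }
#pragma unroll
    for (int u = 0; u < UNROLL; ++u) {
      const int i = base + u * blockDim.x + threadIdx.x;
      if (i < feature_size) output[token * feature_size + i] = vals[u];
    }
  }
}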

ngimel left a comment

This optimizes performance for an extremely common function, and as such it should go into PyTorch core, not torchtitan. Additionally, torch.compile already provides better performance than these kernels, so given that torchtitan already relies on torch.compile, the benefit of adding precompiled implementations is unclear:

import torch

def fn_index_select(x, indices):
    return torch.index_select(x, dim=0, index=indices)

def fn_gather(x, indices):
    return torch.gather(x, dim=0, index=indices.unsqueeze(1).expand(-1, x.shape[1]))

def fn_scatter(dst, src, indices):
    return dst.scatter_(0, indices, src)

def fn_index(x, indices):
    return x[indices]



def scatter_vals(M, N, K):
    src = torch.randn([M, K], device="cuda", dtype=torch.float16).abs().float()
    dst = torch.randn([N, K], device="cuda", dtype=torch.float16).abs().float()
    dst_eager = dst.clone()
    dst_compile = dst.clone()
    if M == N:
        indices = torch.randperm(N, device="cuda", dtype=torch.int64)
    else:
        indices = torch.randint(0, M, [N], device="cuda", dtype=torch.int64)
    indices = indices.unsqueeze(1).expand(-1, K)

    compiled_fn = torch.compile(fn_scatter)
    compiled_fn(dst_compile, src, indices)
    fn_scatter(dst_eager, src, indices)
    with torch.profiler.profile() as p:
        for _ in range(10):
            fn_scatter(dst_eager, src, indices)
            compiled_fn(dst_compile, src, indices)

    print("export trace")
    p.export_chrome_trace("/home/ngimel/sandbox/trace.json")

    #print(p.events())




def gather_vals(M, N, K):
    src = torch.randn([M, K], device="cuda", dtype=torch.bfloat16).abs()
    if M == N:
        indices = torch.randperm(N, device="cuda", dtype=torch.int64)
    else:
        indices = torch.randint(0, M, [N], device="cuda", dtype=torch.int64)

    compiled_fn = torch.compile(fn_gather)
    compiled_fn(src, indices)
    res_index_select = fn_index_select(src, indices)
    res_gather = fn_gather(src, indices)
    torch.testing.assert_close(res_index_select, res_gather)

    with torch.profiler.profile() as p:
        for _ in range(10):
            fn_gather(src, indices)
            fn_index(src, indices)
            compiled_fn(src, indices)

    from torch.utils.benchmark import Timer
    t = Timer(stmt="fn_index(src, indices)", globals={"src": src, "indices": indices, "fn_index": fn_index})
    print("Torch time", t.blocked_autorange().mean)
    t = Timer(stmt="compiled_fn(src, indices)", globals={"src": src, "indices": indices, "compiled_fn": compiled_fn})
    print("Compiled time", t.blocked_autorange().mean)

    p.export_chrome_trace(f"trace{M}_{N}_{K}.json")

    #print(p.events())

#gather_vals(4096, 4096, 5120)
gather_vals(1024, 512, 4096)
gather_vals(4096, 4096, 4096)
gather_vals(8192, 8192, 4096)
#scatter_vals(8192, 8192, 5120)
Torch time 1.126700707245618e-05
Compiled time 2.5844732392579318e-05
Torch time 8.486244548112154e-05
Compiled time 3.984669018536806e-05
Torch time 0.0001635360140353441
Compiled time 6.696143606677652e-05

(Ignore the printed compile benchmark for the smallest shape; that's benchmarking overhead. The kernel itself is 2 us, as seen in the profile.)

lessw2020 (Contributor Author)

@ngimel - thanks for the info above!
To your questions:
1 - "This optimizes performance for an extremely common function, and as such should go into pytorch core and not into torchtitan. "
The reason this was started is that PyTorch eager was slow and showed up as a hotspot in initial profiling, and it was easiest to develop right in place. If it proved out, we could then look at moving it into core; doing it in the reverse order would be too much overhead.

2 - "Additionally, torch.compile already provides performance better than this kernels, so given that torch.titan relies on torch.compile already, the benefit of adding precompiled implementations is unclear"
Correction - it is torchtitan main that relies on torch.compile. There you are right: we do use torch.compile for regional compilation of the transformer blocks of the 'main' models (llama3, etc).
However, we are not using torch.compile in experimental to start, at least not until things are in good enough shape to consider moving them into main; at that point we can certainly add compile.

Does it make sense, then, to upgrade the PyTorch eager core code with the compile-generated triton kernel as a generic fix, so that folks who don't use compile get a faster experience?

ngimel commented Apr 16, 2025

We still try to avoid calling triton kernels in eager, given that they come with unpredictable recompilation. However, the triton kernel is proof that vectorized loads/stores with simple 1D blocking are enough to get good perf, so we can just write that simple kernel in eager.

ngimel commented Apr 16, 2025

pytorch/pytorch#151490
