Skip to content

Commit 2717bf6

Browse files
committed
add group index
1 parent 95f837e commit 2717bf6

File tree

4 files changed

+67
-7
lines changed

4 files changed

+67
-7
lines changed

src/KernelAbstractions.jl

+32-5
Original file line numberDiff line numberDiff line change
@@ -113,13 +113,38 @@ macro synchronize()
113113
end
114114

115115
"""
116-
@index(Global)
117-
@index(Local)
118-
@index(Global, Cartesian)
116+
@index
117+
118+
The `@index` macro can be used to give you the index of a workitem within a kernel
119+
function. It supports producing either a linear index or a cartesian index.
120+
A cartesian index is a general N-dimensional index that is derived from the iteration space.
121+
122+
# Index granularity
123+
124+
- `Global`: Used to access global memory.
125+
- `Group`: The index of the `workgroup`.
126+
- `Local`: The index within the `workgroup`.
127+
128+
# Index kind
129+
130+
- `Linear`: Produces an `Int64` that can be used to linearly index into memory.
131+
- `Cartesian`: Produces a `CartesianIndex{N}` that can be used to index into memory.
132+
133+
If the index kind is not provided, it defaults to `Linear`; this is subject to change.
134+
135+
# Examples
136+
137+
```julia
138+
@index(Global, Linear)
139+
@index(Global, Cartesian)
140+
@index(Local, Cartesian)
141+
@index(Group, Linear)
142+
@index(Global)
143+
```
119144
"""
120145
macro index(locale, args...)
121-
if !(locale === :Global || locale === :Local)
122-
error("@index requires as first argument either :Global or :Local")
146+
if !(locale === :Global || locale === :Local || locale === :Group)
147+
error("@index requires as first argument either :Global, :Local or :Group")
123148
end
124149

125150
if length(args) >= 1
@@ -142,9 +167,11 @@ end
142167
###
143168

144169
function __index_Local_Linear end
170+
function __index_Group_Linear end
145171
function __index_Global_Linear end
146172

147173
function __index_Local_Cartesian end
174+
function __index_Group_Cartesian end
148175
function __index_Global_Cartesian end
149176

150177
struct ConstAdaptor end

src/backends/cpu.jl

+9
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,11 @@ end
8080
return @inbounds LinearIndices(indices)[idx]
8181
end
8282

83+
@inline function Cassette.overdub(ctx::CPUCtx, ::typeof(__index_Group_Linear), idx::CartesianIndex)
84+
indices = blocks(__iterspace(ctx.metadata))
85+
return @inbounds LinearIndices(indices)[__groupindex(ctx.metadata)]
86+
end
87+
8388
@inline function Cassette.overdub(ctx::CPUCtx, ::typeof(__index_Global_Linear), idx::CartesianIndex)
8489
I = @inbounds expand(__iterspace(ctx.metadata), __groupindex(ctx.metadata), idx)
8590
@inbounds LinearIndices(__ndrange(ctx.metadata))[I]
@@ -89,6 +94,10 @@ end
8994
return idx
9095
end
9196

97+
@inline function Cassette.overdub(ctx::CPUCtx, ::typeof(__index_Group_Cartesian), idx::CartesianIndex)
98+
__groupindex(ctx.metadata)
99+
end
100+
92101
@inline function Cassette.overdub(ctx::CPUCtx, ::typeof(__index_Global_Cartesian), idx::CartesianIndex)
93102
return @inbounds expand(__iterspace(ctx.metadata), __groupindex(ctx.metadata), idx)
94103
end

src/backends/cuda.jl

+8
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,10 @@ end
116116
return CUDAnative.threadIdx().x
117117
end
118118

119+
@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(__index_Group_Linear))
120+
return CUDAnative.blockIdx().x
121+
end
122+
119123
@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(__index_Global_Linear))
120124
I = @inbounds expand(__iterspace(ctx.metadata), CUDAnative.blockIdx().x, CUDAnative.threadIdx().x)
121125
# TODO: This is unfortunate, can we get the linear index cheaper
@@ -126,6 +130,10 @@ end
126130
@inbounds workitems(__iterspace(ctx.metadata))[CUDAnative.threadIdx().x]
127131
end
128132

133+
@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(__index_Group_Cartesian))
134+
@inbounds blocks(__iterspace(ctx.metadata))[CUDAnative.blockIdx().x]
135+
end
136+
129137
@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(__index_Global_Cartesian))
130138
return @inbounds expand(__iterspace(ctx.metadata), CUDAnative.blockIdx().x, CUDAnative.threadIdx().x)
131139
end

test/test.jl

+18-2
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,31 @@ end
4444
end
4545
@kernel function index_linear_local(A)
4646
I = @index(Global, Linear)
47-
li = @index(Local, Linear)
48-
A[I] = li
47+
i = @index(Local, Linear)
48+
A[I] = i
49+
end
50+
@kernel function index_linear_group(A)
51+
I = @index(Global, Linear)
52+
i = @index(Group, Linear)
53+
A[I] = i
4954
end
5055
@kernel function index_cartesian_global(A)
5156
I = @index(Global, Cartesian)
5257
A[I] = I
5358
end
59+
@kernel function index_cartesian_local(A)
60+
I = @index(Global, Cartesian)
61+
i = @index(Local, Cartesian)
62+
A[I] = i
63+
end
64+
@kernel function index_cartesian_group(A)
65+
I = @index(Global, Cartesian)
66+
i = @index(Group, Cartesian)
67+
A[I] = i
68+
end
5469

5570
function indextest(backend, ArrayT)
71+
# TODO: add test for _group and _local_cartesian
5672
A = ArrayT{Int}(undef, 16, 16)
5773
wait(index_linear_global(backend, 8)(A, ndrange=length(A)))
5874
@test all(A .== LinearIndices(A))

0 commit comments

Comments
 (0)