Skip to content

Commit 2afdea1

Browse files
committed
fixup N support
1 parent d7833d4 commit 2afdea1

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

src/KernelAbstractions.jl

+6-1
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ macro kernel(ex...)
8282
elseif ex[i] isa Expr && ex[i].head == :(=) &&
8383
ex[i].args[1] == :inbounds && ex[i].args[2] isa Bool
8484
force_inbounds = ex[i].args[2]
85+
elseif ex[i] isa Int
86+
N = StaticSize(ex[i])
8587
else
8688
error(
8789
"Configuration should be of form:\n" *
@@ -659,7 +661,10 @@ Partition a kernel for the given ndrange and workgroupsize.
659661
end
660662

661663
if static_ndims <: StaticSize
662-
@assert get(static_ndims) == length(ndrange)
664+
N = only(get(static_ndims))
665+
if N !== length(ndrange)
666+
error("Mismatch between static kernel dimension (N=$N) and ndrange=$ndrange")
667+
end
663668
end
664669

665670
# TODO: Add static_ndims

test/runtests.jl

+11
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,17 @@ kern_static(CPU(static = true), (1,))(A, ndrange = length(A))
2121
end
2222
@test_throws ErrorException("This kernel is unavailable for backend CPU") my_no_cpu_kernel(CPU())
2323

24+
@kernel 1 function OneD()
25+
end
26+
27+
@kernel 2 function TwoD()
28+
end
29+
30+
@test OneD(CPU())(ndrange=1024) === nothing
31+
@test_throws ErrorException("Mismatch between static kernel dimension (N=1) and ndrange=(1024, 1)") OneD(CPU())(ndrange=(1024, 1))
32+
@test_throws ErrorException("Mismatch between static kernel dimension (N=2) and ndrange=(1024, 1)") TwoD(CPU())(ndrange=1024)
33+
@test TwoD(CPU())(ndrange=(1024,1)) === nothing
34+
2435
# testing multiple configurations at the same time
2536
@kernel cpu = false inbounds = false function my_no_cpu_kernel2(a)
2637
end

0 commit comments

Comments
 (0)