Skip to content

Commit 01d0269

Browse files
Implement FD shmem
Try dont_limit on recursive resolve_shmem methods Fixes + more dont limit Matrix field fixes Matrix field fixes DivergenceF2C fix MatrixField fixes Qualify DivergenceF2C wip Refactor + fixed space bug. All seems good. More tests.. Fixes Test updates Fixes Allow disabling shmem using broadcast style Fix fused cuda operations in LG Revert some unwanted pieces More fixes Format
1 parent 17095b4 commit 01d0269

14 files changed

+1343
-28
lines changed

.buildkite/pipeline.yml

+10
Original file line numberDiff line numberDiff line change
@@ -607,6 +607,16 @@ steps:
607607
key: unit_spec_ops_plane
608608
command: "julia --color=yes --check-bounds=yes --project=.buildkite test/Operators/spectralelement/plane.jl"
609609

610+
- label: "Unit: FD operator (shmem)"
611+
key: unit_fd_operator_shmem
612+
command:
613+
- "julia --color=yes --check-bounds=yes --project=.buildkite test/Operators/finitedifference/unit_fd_ops_shared_memory.jl"
614+
- "julia --color=yes --project=.buildkite test/Operators/finitedifference/benchmark_fd_ops_shared_memory.jl"
615+
env:
616+
CLIMACOMMS_DEVICE: "CUDA"
617+
agents:
618+
slurm_gpus: 1
619+
610620
- label: "Unit: column"
611621
key: unit_column
612622
command:

ext/ClimaCoreCUDAExt.jl

+2
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ include(joinpath("cuda", "operators_integral.jl"))
3232
include(joinpath("cuda", "remapping_interpolate_array.jl"))
3333
include(joinpath("cuda", "limiters.jl"))
3434
include(joinpath("cuda", "operators_sem_shmem.jl"))
35+
include(joinpath("cuda", "operators_fd_shmem_common.jl"))
36+
include(joinpath("cuda", "operators_fd_shmem.jl"))
3537
include(joinpath("cuda", "matrix_fields_single_field_solve.jl"))
3638
include(joinpath("cuda", "matrix_fields_multiple_field_solve.jl"))
3739
include(joinpath("cuda", "operators_spectral_element.jl"))

ext/cuda/data_layouts_threadblock.jl

+30
Original file line numberDiff line numberDiff line change
@@ -289,3 +289,33 @@ end
289289
ij,
290290
slabidx,
291291
) = Operators.is_valid_index(space, ij, slabidx)
292+
293+
##### shmem fd kernel partition
294+
@inline function fd_stencil_partition(
295+
us::DataLayouts.UniversalSize,
296+
n_face_levels::Integer,
297+
n_max_threads::Integer = 256;
298+
)
299+
(Nq, _, _, Nv, Nh) = DataLayouts.universal_size(us)
300+
Nvthreads = n_face_levels
301+
@assert Nvthreads <= maximum_allowable_threads()[1] "Number of vertical face levels cannot exceed $(maximum_allowable_threads()[1])"
302+
Nvblocks = cld(Nv, Nvthreads) # +1 may be needed to guarantee that shared memory is populated at the last cell face
303+
return (;
304+
threads = (Nvthreads,),
305+
blocks = (Nh, Nvblocks, Nq * Nq),
306+
Nvthreads,
307+
)
308+
end
309+
@inline function fd_stencil_universal_index(space::Spaces.AbstractSpace, us)
310+
(tv,) = CUDA.threadIdx()
311+
(h, bv, ij) = CUDA.blockIdx()
312+
v = tv + (bv - 1) * CUDA.blockDim().x
313+
(Nq, _, _, _, _) = DataLayouts.universal_size(us)
314+
if Nq * Nq < ij
315+
return CartesianIndex((-1, -1, 1, -1, -1))
316+
end
317+
@inbounds (i, j) = CartesianIndices((Nq, Nq))[ij].I
318+
return CartesianIndex((i, j, 1, v, h))
319+
end
320+
@inline fd_stencil_is_valid_index(I::CI5, us::UniversalSize) =
321+
1 I[5] DataLayouts.get_Nh(us)

ext/cuda/operators_fd_shmem.jl

+94
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import ClimaCore: DataLayouts, Spaces, Geometry, RecursiveApply, DataLayouts
2+
import CUDA
3+
import ClimaCore.Operators: return_eltype, get_local_geometry
4+
5+
Base.@propagate_inbounds function fd_operator_shmem(
6+
space,
7+
::Val{Nvt},
8+
op::Operators.DivergenceF2C,
9+
args...,
10+
) where {Nvt}
11+
# allocate temp output
12+
RT = return_eltype(op, args...)
13+
Ju³ = CUDA.CuStaticSharedArray(RT, (Nvt,))
14+
return Ju³
15+
end
16+
17+
Base.@propagate_inbounds function fd_operator_fill_shmem_interior!(
18+
op::Operators.DivergenceF2C,
19+
Ju³,
20+
loc, # can be any location
21+
space,
22+
idx::Utilities.PlusHalf,
23+
hidx,
24+
arg,
25+
)
26+
@inbounds begin
27+
vt = threadIdx().x
28+
lg = Geometry.LocalGeometry(space, idx, hidx)
29+
= Operators.getidx(space, arg, loc, idx, hidx)
30+
Ju³[vt] = Geometry.Jcontravariant3(u³, lg)
31+
end
32+
return nothing
33+
end
34+
35+
Base.@propagate_inbounds function fd_operator_fill_shmem_left_boundary!(
36+
op::Operators.DivergenceF2C,
37+
bc::Operators.SetValue,
38+
Ju³,
39+
loc,
40+
space,
41+
idx::Utilities.PlusHalf,
42+
hidx,
43+
arg,
44+
)
45+
idx == Operators.left_face_boundary_idx(space) ||
46+
error("Incorrect left idx")
47+
@inbounds begin
48+
vt = threadIdx().x
49+
lg = Geometry.LocalGeometry(space, idx, hidx)
50+
= Operators.getidx(space, bc.val, loc, nothing, hidx)
51+
Ju³[vt] = Geometry.Jcontravariant3(u³, lg)
52+
end
53+
return nothing
54+
end
55+
56+
Base.@propagate_inbounds function fd_operator_fill_shmem_right_boundary!(
57+
op::Operators.DivergenceF2C,
58+
bc::Operators.SetValue,
59+
Ju³,
60+
loc,
61+
space,
62+
idx::Utilities.PlusHalf,
63+
hidx,
64+
arg,
65+
)
66+
# The right boundary is called at `idx + 1`, so we need to subtract 1 from idx (shmem is loaded at vt+1)
67+
idx == Operators.right_face_boundary_idx(space) ||
68+
error("Incorrect right idx")
69+
@inbounds begin
70+
vt = threadIdx().x
71+
lg = Geometry.LocalGeometry(space, idx, hidx)
72+
= Operators.getidx(space, bc.val, loc, nothing, hidx)
73+
Ju³[vt] = Geometry.Jcontravariant3(u³, lg)
74+
end
75+
return nothing
76+
end
77+
78+
Base.@propagate_inbounds function fd_operator_evaluate(
79+
op::Operators.DivergenceF2C,
80+
Ju³,
81+
loc,
82+
space,
83+
idx::Integer,
84+
hidx,
85+
args...,
86+
)
87+
@inbounds begin
88+
vt = threadIdx().x
89+
local_geometry = Geometry.LocalGeometry(space, idx, hidx)
90+
Ju³₋ = Ju³[vt] # corresponds to idx - half
91+
Ju³₊ = Ju³[vt + 1] # corresponds to idx + half
92+
return (Ju³₊ Ju³₋) local_geometry.invJ
93+
end
94+
end

0 commit comments

Comments
 (0)