From 530890621359f19ea0d596836365188c8a4bc942 Mon Sep 17 00:00:00 2001
From: Christopher Doris <github.com/cjdoris>
Date: Wed, 24 Jul 2024 18:03:24 +0100
Subject: [PATCH] simplify GIL handling

---
 pysrc/juliacall/__init__.py |   2 +-
 src/C/C.jl                  |   1 -
 src/C/context.jl            | 120 +++++++++++++++++-------------------
 src/C/gil.jl                |  24 --------
 src/C/pointers.jl           |   2 +
 src/Compat/Compat.jl        |   8 +--
 src/Convert/Convert.jl      |  10 ++-
 src/Core/Core.jl            |  10 ++-
 src/GC/GC.jl                |  20 +++---
 src/JlWrap/C.jl             |   4 +-
 src/JlWrap/JlWrap.jl        |  42 ++++++-------
 11 files changed, 99 insertions(+), 144 deletions(-)
 delete mode 100644 src/C/gil.jl

diff --git a/pysrc/juliacall/__init__.py b/pysrc/juliacall/__init__.py
index 6a67df2a..c9c369a8 100644
--- a/pysrc/juliacall/__init__.py
+++ b/pysrc/juliacall/__init__.py
@@ -178,7 +178,7 @@ def args_from_config():
             os.environ['PATH'] = libdir
 
     # Open the library
-    CONFIG['lib'] = lib = c.CDLL(libpath, mode=c.RTLD_GLOBAL)
+    CONFIG['lib'] = lib = c.PyDLL(libpath, mode=c.RTLD_GLOBAL)
 
     # parse options
     argc, argv = args_from_config()
diff --git a/src/C/C.jl b/src/C/C.jl
index 64207763..4fb055a8 100644
--- a/src/C/C.jl
+++ b/src/C/C.jl
@@ -17,7 +17,6 @@ include("consts.jl")
 include("pointers.jl")
 include("extras.jl")
 include("context.jl")
-include("gil.jl")
 include("api.jl")
 
 function __init__()
diff --git a/src/C/context.jl b/src/C/context.jl
index b6e99949..57c7d3ad 100644
--- a/src/C/context.jl
+++ b/src/C/context.jl
@@ -145,63 +145,61 @@ function init_context()
         @require PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" init_pycall(PyCall)
 
         # Initialize the interpreter
-        with_gil() do
-            CTX.is_preinitialized = Py_IsInitialized() != 0
-            if CTX.is_preinitialized
-                @assert CTX.which == :PyCall || CTX.matches_pycall isa Bool
+        CTX.is_preinitialized = Py_IsInitialized() != 0
+        if CTX.is_preinitialized
+            @assert CTX.which == :PyCall || CTX.matches_pycall isa Bool
+        else
+            @assert CTX.which != :PyCall
+            # Find ProgramName and PythonHome
+            script = if Sys.iswindows()
+                """
+                import sys
+                print(sys.executable)
+                if hasattr(sys, "base_exec_prefix"):
+                    sys.stdout.write(sys.base_exec_prefix)
+                else:
+                    sys.stdout.write(sys.exec_prefix)
+                """
             else
-                @assert CTX.which != :PyCall
-                # Find ProgramName and PythonHome
-                script = if Sys.iswindows()
-                    """
-                    import sys
-                    print(sys.executable)
-                    if hasattr(sys, "base_exec_prefix"):
-                        sys.stdout.write(sys.base_exec_prefix)
-                    else:
-                        sys.stdout.write(sys.exec_prefix)
-                    """
-                else
-                    """
-                    import sys
-                    print(sys.executable)
-                    if hasattr(sys, "base_exec_prefix"):
-                        sys.stdout.write(sys.base_prefix)
-                        sys.stdout.write(":")
-                        sys.stdout.write(sys.base_exec_prefix)
-                    else:
-                        sys.stdout.write(sys.prefix)
-                        sys.stdout.write(":")
-                        sys.stdout.write(sys.exec_prefix)
-                    """
-                end
-                CTX.pyprogname, CTX.pyhome = readlines(python_cmd(["-c", script]))
+                """
+                import sys
+                print(sys.executable)
+                if hasattr(sys, "base_exec_prefix"):
+                    sys.stdout.write(sys.base_prefix)
+                    sys.stdout.write(":")
+                    sys.stdout.write(sys.base_exec_prefix)
+                else:
+                    sys.stdout.write(sys.prefix)
+                    sys.stdout.write(":")
+                    sys.stdout.write(sys.exec_prefix)
+                """
+            end
+            CTX.pyprogname, CTX.pyhome = readlines(python_cmd(["-c", script]))
 
-                # Set PythonHome
-                CTX.pyhome_w = Base.cconvert(Cwstring, CTX.pyhome)
-                Py_SetPythonHome(pointer(CTX.pyhome_w))
+            # Set PythonHome
+            CTX.pyhome_w = Base.cconvert(Cwstring, CTX.pyhome)
+            Py_SetPythonHome(pointer(CTX.pyhome_w))
 
-                # Set ProgramName
-                CTX.pyprogname_w = Base.cconvert(Cwstring, CTX.pyprogname)
-                Py_SetProgramName(pointer(CTX.pyprogname_w))
+            # Set ProgramName
+            CTX.pyprogname_w = Base.cconvert(Cwstring, CTX.pyprogname)
+            Py_SetProgramName(pointer(CTX.pyprogname_w))
 
-                # Start the interpreter and register exit hooks
-                Py_InitializeEx(0)
-                atexit() do
-                    CTX.is_initialized = false
-                    if CTX.version === missing || CTX.version < v"3.6"
-                        Py_Finalize()
-                    else
-                        if Py_FinalizeEx() == -1
-                            @warn "Py_FinalizeEx() error"
-                        end
+            # Start the interpreter and register exit hooks
+            Py_InitializeEx(0)
+            atexit() do
+                CTX.is_initialized = false
+                if CTX.version === missing || CTX.version < v"3.6"
+                    Py_Finalize()
+                else
+                    if Py_FinalizeEx() == -1
+                        @warn "Py_FinalizeEx() error"
                     end
                 end
             end
-            CTX.is_initialized = true
-            if Py_AtExit(@cfunction(_atpyexit, Cvoid, ())) == -1
-                @warn "Py_AtExit() error"
-            end
+        end
+        CTX.is_initialized = true
+        if Py_AtExit(@cfunction(_atpyexit, Cvoid, ())) == -1
+            @warn "Py_AtExit() error"
         end
     end
 
@@ -218,20 +216,16 @@ function init_context()
         ENV["JULIA_PYTHONCALL_EXE"] = CTX.exe_path::String
     end
 
-    with_gil() do
-
-        # Get the python version
-        verstr = Base.unsafe_string(Py_GetVersion())
-        vermatch = match(r"^[0-9.]+", verstr)
-        if vermatch === nothing
-            error("Cannot parse version from version string: $(repr(verstr))")
-        end
-        CTX.version = VersionNumber(vermatch.match)
-        v"3.5" ≤ CTX.version < v"4" || error(
-            "Only Python 3.5+ is supported, this is Python $(CTX.version) at $(CTX.exe_path===missing ? "unknown location" : CTX.exe_path).",
-        )
-
+    # Get the python version
+    verstr = Base.unsafe_string(Py_GetVersion())
+    vermatch = match(r"^[0-9.]+", verstr)
+    if vermatch === nothing
+        error("Cannot parse version from version string: $(repr(verstr))")
     end
+    CTX.version = VersionNumber(vermatch.match)
+    v"3.5" ≤ CTX.version < v"4" || error(
+        "Only Python 3.5+ is supported, this is Python $(CTX.version) at $(CTX.exe_path===missing ? "unknown location" : CTX.exe_path).",
+    )
 
     @debug "Initialized PythonCall.jl" CTX.is_embedded CTX.is_initialized CTX.exe_path CTX.lib_path CTX.lib_ptr CTX.pyprogname CTX.pyhome CTX.version
 
diff --git a/src/C/gil.jl b/src/C/gil.jl
deleted file mode 100644
index 50d46707..00000000
--- a/src/C/gil.jl
+++ /dev/null
@@ -1,24 +0,0 @@
-"""
-    with_gil(f, [c=true])
-
-Compute `f()` with the GIL enabled.
-
-This may need a `try-finally` block to ensure the GIL is released again. If you know that `f` cannot throw, pass `c=false` to avoid this overhead.
-"""
-@inline function with_gil(f, c::Bool = true)
-    if !CTX.is_embedded
-        f()
-    elseif c
-        g = PyGILState_Ensure()
-        try
-            f()
-        finally
-            PyGILState_Release(g)
-        end
-    else
-        g = PyGILState_Ensure()
-        r = f()
-        PyGILState_Release(g)
-        r
-    end
-end
diff --git a/src/C/pointers.jl b/src/C/pointers.jl
index dd0476fc..6faabb60 100644
--- a/src/C/pointers.jl
+++ b/src/C/pointers.jl
@@ -22,6 +22,8 @@ const CAPI_FUNC_SIGS = Dict{Symbol,Pair{Tuple,Type}}(
     :PyEval_RestoreThread => (Ptr{Cvoid},) => Cvoid,
     :PyGILState_Ensure => () => PyGILState_STATE,
     :PyGILState_Release => (PyGILState_STATE,) => Cvoid,
+    :PyGILState_GetThisThreadState => () => Ptr{Cvoid},
+    :PyGILState_Check => () => Cint,
     # IMPORT
     :PyImport_ImportModule => (Ptr{Cchar},) => PyPtr,
     :PyImport_Import => (PyPtr,) => PyPtr,
diff --git a/src/Compat/Compat.jl b/src/Compat/Compat.jl
index 487fe8a3..0efd24b7 100644
--- a/src/Compat/Compat.jl
+++ b/src/Compat/Compat.jl
@@ -34,11 +34,9 @@ include("tables.jl")
 include("pycall.jl")
 
 function __init__()
-    C.with_gil() do
-        init_gui()
-        init_pyshow()
-        init_tables()
-    end
+    init_gui()
+    init_pyshow()
+    init_tables()
     @require PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" init_pycall(PyCall)
 end
 end
diff --git a/src/Convert/Convert.jl b/src/Convert/Convert.jl
index 81b7260f..a97f1a48 100644
--- a/src/Convert/Convert.jl
+++ b/src/Convert/Convert.jl
@@ -55,12 +55,10 @@ include("numpy.jl")
 include("pandas.jl")
 
 function __init__()
-    C.with_gil() do
-        init_pyconvert()
-        init_ctypes()
-        init_numpy()
-        init_pandas()
-    end
+    init_pyconvert()
+    init_ctypes()
+    init_numpy()
+    init_pandas()
 end
 
 end
diff --git a/src/Core/Core.jl b/src/Core/Core.jl
index e84f24d1..30d185a7 100644
--- a/src/Core/Core.jl
+++ b/src/Core/Core.jl
@@ -39,12 +39,10 @@ include("juliacall.jl")
 include("pyconst_macro.jl")
 
 function __init__()
-    C.with_gil() do
-        init_consts()
-        init_datetime()
-        init_stdlib()
-        init_juliacall()
-    end
+    init_consts()
+    init_datetime()
+    init_stdlib()
+    init_juliacall()
 end
 
 end
diff --git a/src/GC/GC.jl b/src/GC/GC.jl
index 0d1fa9a8..7bccfadc 100644
--- a/src/GC/GC.jl
+++ b/src/GC/GC.jl
@@ -40,11 +40,9 @@ Like most PythonCall functions, you must only call this from the main thread.
 function enable()
     ENABLED[] = true
     if !isempty(QUEUE)
-        C.with_gil(false) do
-            for ptr in QUEUE
-                if ptr != C.PyNULL
-                    C.Py_DecRef(ptr)
-                end
+        for ptr in QUEUE
+            if ptr != C.PyNULL
+                C.Py_DecRef(ptr)
             end
         end
     end
@@ -55,9 +53,7 @@ end
 function enqueue(ptr::C.PyPtr)
     if ptr != C.PyNULL && C.CTX.is_initialized
         if ENABLED[]
-            C.with_gil(false) do
-                C.Py_DecRef(ptr)
-            end
+            C.Py_DecRef(ptr)
         else
             push!(QUEUE, ptr)
         end
@@ -68,11 +64,9 @@ end
 function enqueue_all(ptrs)
     if C.CTX.is_initialized
         if ENABLED[]
-            C.with_gil(false) do
-                for ptr in ptrs
-                    if ptr != C.PyNULL
-                        C.Py_DecRef(ptr)
-                    end
+            for ptr in ptrs
+                if ptr != C.PyNULL
+                    C.Py_DecRef(ptr)
                 end
             end
         else
diff --git a/src/JlWrap/C.jl b/src/JlWrap/C.jl
index ed1ab655..7b181881 100644
--- a/src/JlWrap/C.jl
+++ b/src/JlWrap/C.jl
@@ -334,9 +334,7 @@ function init_c()
 end
 
 function __init__()
-    C.with_gil() do
-        init_c()
-    end
+    init_c()
 end
 
 PyJuliaValue_IsNull(o::C.PyPtr) = UnsafePtr{PyJuliaValueObject}(o).value[] == 0
diff --git a/src/JlWrap/JlWrap.jl b/src/JlWrap/JlWrap.jl
index aff3e833..2a55123a 100644
--- a/src/JlWrap/JlWrap.jl
+++ b/src/JlWrap/JlWrap.jl
@@ -65,28 +65,26 @@ include("set.jl")
 include("callback.jl")
 
 function __init__()
-    Cjl.C.with_gil() do
-        init_base()
-        init_raw()
-        init_any()
-        init_iter()
-        init_type()
-        init_module()
-        init_io()
-        init_number()
-        init_array()
-        init_vector()
-        init_dict()
-        init_set()
-        init_callback()
-        # add packages to juliacall
-        jl = pyjuliacallmodule
-        jl.Core = Base.Core
-        jl.Base = Base
-        jl.Main = Main
-        jl.Pkg = Pkg
-        jl.PythonCall = PythonCall
-    end
+    init_base()
+    init_raw()
+    init_any()
+    init_iter()
+    init_type()
+    init_module()
+    init_io()
+    init_number()
+    init_array()
+    init_vector()
+    init_dict()
+    init_set()
+    init_callback()
+    # add packages to juliacall
+    jl = pyjuliacallmodule
+    jl.Core = Base.Core
+    jl.Base = Base
+    jl.Main = Main
+    jl.Pkg = Pkg
+    jl.PythonCall = PythonCall
 end
 
 end