Skip to content

Commit 1ae977b

Browse files
authored
Add py.unchecked for casting PyObject to Pydust class (#230)
TODO(ngates): we should store PyType objects on module state and then auto-traverse them. See #229 Fixes #226, #227, #228
1 parent 42632fb commit 1ae977b

File tree

4 files changed

+47
-4
lines changed

4 files changed

+47
-4
lines changed

pydust/src/conversions.zig

+26
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
const py = @import("./pydust.zig");
1414
const tramp = @import("./trampoline.zig");
15+
const pytypes = @import("./pytypes.zig");
16+
const State = @import("./discovery.zig").State;
1517

1618
/// Zig PyObject-like -> ffi.PyObject. Convert a Zig PyObject-like value into a py.PyObject.
1719
/// e.g. py.PyObject, py.PyTuple, ffi.PyObject, etc.
@@ -38,6 +40,30 @@ pub inline fn as(comptime T: type, obj: anytype) py.PyError!T {
3840
return tramp.Trampoline(T).unwrap(object(obj));
3941
}
4042

43+
/// Python -> Pydust. Perform a checked cast from a PyObject to a given PyDust class type.
44+
pub inline fn checked(comptime T: type, obj: py.PyObject) py.PyError!T {
45+
const definition = State.getDefinition(@typeInfo(T).Pointer.child);
46+
if (definition.type != .class) {
47+
@compileError("Can only perform checked cast into a PyDust class type");
48+
}
49+
50+
// TODO(ngates): to perform fast type checking, we need to store our PyType on the parent module.
51+
// See how the Python JSON module did this: https://github.com/python/cpython/commit/33f15a16d40cb8010a8c758952cbf88d7912ee2d#diff-efe183ae0b85e5b8d9bbbc588452dd4de80b39fd5c5174ee499ba554217a39edR1814
52+
// For now, we perform a slow import/isinstance check by using the `as` conversion.
53+
return as(T, obj);
54+
}
55+
56+
/// Python -> Pydust. Perform an unchecked cast from a PyObject to a given PyDust class type.
57+
pub inline fn unchecked(comptime T: type, obj: py.PyObject) T {
58+
const Definition = @typeInfo(T).Pointer.child;
59+
const definition = State.getDefinition(Definition);
60+
if (definition.type != .class) {
61+
@compileError("Can only perform unchecked cast into a PyDust class type. Found " ++ @typeName(Definition));
62+
}
63+
const instance: *pytypes.PyTypeStruct(Definition) = @ptrCast(@alignCast(obj.py));
64+
return &instance.state;
65+
}
66+
4167
const testing = @import("std").testing;
4268
const expect = testing.expect;
4369

pydust/src/functions.zig

+19-2
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,13 @@ pub const Signature = struct {
3333
pub fn supportsKwargs(comptime self: @This()) bool {
3434
return self.nkwargs > 0 or self.varkwargsIdx != null;
3535
}
36+
37+
pub fn isModuleMethod(comptime self: @This()) bool {
38+
if (self.selfParam) |Self| {
39+
return State.getDefinition(@typeInfo(Self).Pointer.child).type == .module;
40+
}
41+
return false;
42+
}
3643
};
3744

3845
pub const BinaryOperators = std.ComptimeStringMap(c_int, .{
@@ -257,7 +264,8 @@ pub fn wrap(comptime definition: type, comptime func: anytype, comptime sig: Sig
257264
}
258265

259266
inline fn internal(pyself: py.PyObject, pyargs: []py.PyObject) PyError!py.PyObject {
260-
const self = if (sig.selfParam) |Self| try py.as(Self, pyself) else null;
267+
const self = if (sig.selfParam) |Self| try castSelf(Self, pyself) else null;
268+
261269
if (sig.argsParam) |Args| {
262270
const args = try unwrapArgs(Args, pyargs, py.Kwargs.init(py.allocator));
263271
const result = if (sig.selfParam) |_| func(self, args) else func(args);
@@ -302,10 +310,19 @@ pub fn wrap(comptime definition: type, comptime func: anytype, comptime sig: Sig
302310
pykwargs: py.Kwargs,
303311
) PyError!py.PyObject {
304312
const args = try unwrapArgs(sig.argsParam.?, pyargs, pykwargs);
305-
const self = if (sig.selfParam) |Self| try py.as(Self, pyself) else null;
313+
const self = if (sig.selfParam) |Self| try castSelf(Self, pyself) else null;
306314
const result = if (sig.selfParam) |_| func(self, args) else func(args);
307315
return py.createOwned(tramp.coerceError(result));
308316
}
317+
318+
inline fn castSelf(comptime Self: type, pyself: py.PyObject) !Self {
319+
if (comptime sig.isModuleMethod()) {
320+
const mod = py.PyModule{ .obj = pyself };
321+
return try mod.getState(@typeInfo(Self).Pointer.child);
322+
} else {
323+
return py.unchecked(Self, pyself);
324+
}
325+
}
309326
};
310327
}
311328

pydust/src/pytypes.zig

+1-1
Original file line numberDiff line numberDiff line change
@@ -666,7 +666,7 @@ fn RichCompare(comptime definition: type) type {
666666
const CompareOpArg = typeInfo.params[2].type.?;
667667
if (CompareOpArg != py.CompareOp) @compileError("Third parameter of __richcompare__ must be a py.CompareOp");
668668

669-
const self = py.as(Self, pyself) catch return null;
669+
const self = py.unchecked(Self, .{ .py = pyself });
670670
const otherArg = tramp.Trampoline(Other).unwrap(.{ .py = pyother }) catch return null;
671671
const opEnum: py.CompareOp = @enumFromInt(op);
672672

pydust/src/types/module.zig

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ pub const PyModule = extern struct {
3131
return .{ .obj = .{ .py = ffi.PyImport_ImportModule(name) orelse return PyError.PyRaised } };
3232
}
3333

34-
pub fn getState(self: PyModule, comptime state: type) !*state {
34+
pub fn getState(self: PyModule, comptime ModState: type) !*ModState {
3535
const statePtr = ffi.PyModule_GetState(self.obj.py) orelse return PyError.PyRaised;
3636
return @ptrCast(@alignCast(statePtr));
3737
}

0 commit comments

Comments
 (0)