Skip to content

Commit e6842e5

Browse files
authored
Add support for unary functions (#235)
1 parent f88bfba commit e6842e5

File tree

6 files changed

+188
-11
lines changed

6 files changed

+188
-11
lines changed

docs/guide/classes.md

+10-1
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,8 @@ Also note the shorthand signatures:
167167

168168
```zig
169169
const binaryfunc = fn(*Self, object) !object;
170+
const unaryfunc = fn(*Self) !object;
171+
const inquiry = fn(*Self) !bool;
170172
```
171173

172174
### Type Methods
@@ -210,7 +212,6 @@ The remaining mapping methods are yet to be implemented.
210212
| `__gt__` | `#!zig fn(*Self, object) !bool` |
211213
| `__ge__` | `#!zig fn(*Self, object) !bool` |
212214

213-
214215
!!! note
215216

216217
By default, `__ne__` will delegate to the negation of `__eq__` if it is defined.
@@ -258,6 +259,14 @@ to implement the full comparison logic in a single `__richcompare__` function.
258259
| `__ifloordiv__` | `binaryfunc` |
259260
| `__matmul__` | `binaryfunc` |
260261
| `__imatmul__` | `binaryfunc` |
262+
| `__neg__` | `unaryfunc` |
263+
| `__pos__` | `unaryfunc` |
264+
| `__abs__` | `unaryfunc` |
265+
| `__invert__` | `unaryfunc` |
266+
| `__int__` | `unaryfunc` |
267+
| `__float__` | `unaryfunc` |
268+
| `__index__` | `unaryfunc` |
269+
| `__bool__` | `inquiry` |
261270

262271
!!! note
263272

example/operators.pyi

+45
Original file line numberDiff line numberDiff line change
@@ -326,3 +326,48 @@ class Ops:
326326
"""
327327
...
328328
def num(self, /): ...
329+
330+
class UnaryOps:
331+
def __init__(self, num, /):
332+
pass
333+
def __neg__(self, /):
334+
"""
335+
-self
336+
"""
337+
...
338+
def __pos__(self, /):
339+
"""
340+
+self
341+
"""
342+
...
343+
def __abs__(self, /):
344+
"""
345+
abs(self)
346+
"""
347+
...
348+
def __bool__(self, /):
349+
"""
350+
True if self else False
351+
"""
352+
...
353+
def __invert__(self, /):
354+
"""
355+
~self
356+
"""
357+
...
358+
def __int__(self, /):
359+
"""
360+
int(self)
361+
"""
362+
...
363+
def __float__(self, /):
364+
"""
365+
float(self)
366+
"""
367+
...
368+
def __index__(self, /):
369+
"""
370+
Return self converted to an integer, if self is suitable for use as an index into a list.
371+
"""
372+
...
373+
def num(self, /): ...

example/operators.zig

+47
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,53 @@ pub const Ops = py.class(struct {
165165
});
166166
// --8<-- [end:all]
167167

168+
pub const UnaryOps = py.class(struct {
169+
const Self = @This();
170+
171+
num: i64,
172+
173+
pub fn __init__(self: *Self, args: struct { num: i64 }) !void {
174+
self.num = args.num;
175+
}
176+
177+
pub fn num(self: *const Self) i64 {
178+
return self.num;
179+
}
180+
181+
pub fn __neg__(self: *Self) !py.PyLong {
182+
return py.PyLong.create(-self.num);
183+
}
184+
185+
pub fn __pos__(self: *Self) !*Self {
186+
py.incref(self);
187+
return self;
188+
}
189+
190+
pub fn __abs__(self: *Self) !*Self {
191+
return py.init(Self, .{ .num = @as(i64, @intCast(std.math.absCast(self.num))) });
192+
}
193+
194+
pub fn __invert__(self: *Self) !*Self {
195+
return py.init(Self, .{ .num = ~self.num });
196+
}
197+
198+
pub fn __int__(self: *Self) !py.PyLong {
199+
return py.PyLong.create(self.num);
200+
}
201+
202+
pub fn __float__(self: *Self) !py.PyFloat {
203+
return py.PyFloat.create(@as(f64, @floatFromInt(self.num)));
204+
}
205+
206+
pub fn __index__(self: *Self) !py.PyLong {
207+
return py.PyLong.create(self.num);
208+
}
209+
210+
pub fn __bool__(self: *Self) !bool {
211+
return self.num == 1;
212+
}
213+
});
214+
168215
// --8<-- [start:ops]
169216
pub const Operator = py.class(struct {
170217
const Self = @This();

pydust/src/functions.zig

+20-10
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,16 @@ pub const Signature = struct {
4242
}
4343
};
4444

45+
pub const UnaryOperators = std.ComptimeStringMap(c_int, .{
46+
.{ "__neg__", ffi.Py_nb_negative },
47+
.{ "__pos__", ffi.Py_nb_positive },
48+
.{ "__abs__", ffi.Py_nb_absolute },
49+
.{ "__invert__", ffi.Py_nb_invert },
50+
.{ "__int__", ffi.Py_nb_int },
51+
.{ "__float__", ffi.Py_nb_float },
52+
.{ "__index__", ffi.Py_nb_index },
53+
});
54+
4555
pub const BinaryOperators = std.ComptimeStringMap(c_int, .{
4656
.{ "__add__", ffi.Py_nb_add },
4757
.{ "__iadd__", ffi.Py_nb_inplace_add },
@@ -72,7 +82,6 @@ pub const BinaryOperators = std.ComptimeStringMap(c_int, .{
7282
.{ "__imatmul__", ffi.Py_nb_inplace_matrix_multiply },
7383
.{ "__getitem__", ffi.Py_mp_subscript },
7484
});
75-
pub const NBinaryOperators = BinaryOperators.kvs.len;
7685

7786
// TODO(marko): Move this somewhere.
7887
fn keys(comptime stringMap: type) [stringMap.kvs.len][]const u8 {
@@ -93,19 +102,20 @@ pub const compareFuncs = .{
93102
};
94103

95104
const reservedNames = .{
96-
"__new__",
97-
"__init__",
98-
"__len__",
99-
"__del__",
105+
"__bool__",
100106
"__buffer__",
101-
"__str__",
102-
"__repr__",
103-
"__release_buffer__",
107+
"__del__",
108+
"__hash__",
109+
"__init__",
104110
"__iter__",
111+
"__len__",
112+
"__new__",
105113
"__next__",
106-
"__hash__",
114+
"__release_buffer__",
115+
"__repr__",
107116
"__richcompare__",
108-
} ++ compareFuncs ++ keys(BinaryOperators);
117+
"__str__",
118+
} ++ compareFuncs ++ keys(BinaryOperators) ++ keys(UnaryOperators);
109119

110120
/// Parse the arguments of a Zig function into a Pydust function siganture.
111121
pub fn parseSignature(comptime name: []const u8, comptime func: Type.Fn, comptime SelfTypes: []const type) Signature {

pydust/src/pytypes.zig

+43
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,13 @@ fn Slots(comptime definition: type, comptime name: [:0]const u8) type {
215215
}};
216216
}
217217

218+
if (@hasDecl(definition, "__bool__")) {
219+
slots_ = slots_ ++ .{ffi.PyType_Slot{
220+
.slot = ffi.Py_nb_bool,
221+
.pfunc = @ptrCast(@constCast(&nb_bool)),
222+
}};
223+
}
224+
218225
if (richcmp.hasCompare) {
219226
slots_ = slots_ ++ .{ffi.PyType_Slot{
220227
.slot = ffi.Py_tp_richcompare,
@@ -232,6 +239,16 @@ fn Slots(comptime definition: type, comptime name: [:0]const u8) type {
232239
}
233240
}
234241

242+
for (funcs.UnaryOperators.kvs) |kv| {
243+
if (@hasDecl(definition, kv.key)) {
244+
const op = UnaryOperator(definition, kv.key);
245+
slots_ = slots_ ++ .{ffi.PyType_Slot{
246+
.slot = kv.value,
247+
.pfunc = @ptrCast(@constCast(&op.call)),
248+
}};
249+
}
250+
}
251+
235252
slots_ = slots_ ++ .{ffi.PyType_Slot{
236253
.slot = ffi.Py_tp_methods,
237254
.pfunc = @ptrCast(@constCast(&methods.pydefs)),
@@ -369,6 +386,12 @@ fn Slots(comptime definition: type, comptime name: [:0]const u8) type {
369386
const result = tramp.coerceError(definition.__call__(self, call_args.argsStruct)) catch return null;
370387
return (py.createOwned(result) catch return null).py;
371388
}
389+
390+
fn nb_bool(pyself: *ffi.PyObject) callconv(.C) c_int {
391+
const self: *PyTypeStruct(definition) = @ptrCast(pyself);
392+
const result = tramp.coerceError(definition.__bool__(&self.state)) catch return -1;
393+
return @intCast(@intFromBool(result));
394+
}
372395
};
373396
}
374397

@@ -593,6 +616,26 @@ fn BinaryOperator(
593616
};
594617
}
595618

619+
fn UnaryOperator(
620+
comptime definition: type,
621+
comptime op: []const u8,
622+
) type {
623+
return struct {
624+
fn call(pyself: *ffi.PyObject) callconv(.C) ?*ffi.PyObject {
625+
const func = @field(definition, op);
626+
const typeInfo = @typeInfo(@TypeOf(func)).Fn;
627+
628+
if (typeInfo.params.len != 1) @compileError(op ++ " must take exactly one parameter");
629+
630+
// TODO(ngates): do we want to trampoline the self argument?
631+
const self: *PyTypeStruct(definition) = @ptrCast(pyself);
632+
633+
const result = tramp.coerceError(func(&self.state)) catch return null;
634+
return (py.createOwned(result) catch return null).py;
635+
}
636+
};
637+
}
638+
596639
fn EqualsOperator(
597640
comptime definition: type,
598641
comptime op: []const u8,

test/test_operators.py

+23
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,29 @@ def test_iops(iop, expected):
7171
assert ops.num() == expected
7272

7373

74+
@pytest.mark.parametrize(
75+
"op,expected",
76+
[
77+
(operator.pos, -3),
78+
(operator.neg, 3),
79+
(operator.invert, 2),
80+
(operator.index, -3),
81+
(operator.abs, 3),
82+
(bool, False),
83+
(int, -3),
84+
(float, -3.0),
85+
],
86+
)
87+
def test_unaryops(op, expected):
88+
ops = operators.UnaryOps(-3)
89+
res = op(ops)
90+
91+
if isinstance(res, operators.UnaryOps):
92+
assert res.num() == expected
93+
else:
94+
assert res == expected
95+
96+
7497
def test_divmod():
7598
ops = operators.Ops(3)
7699
other = operators.Ops(2)

0 commit comments

Comments
 (0)