8
8
import typing
9
9
from itertools import zip_longest
10
10
from typing import Any , Callable , Dict , List , Optional , Set , Type , Iterable
11
+ try :
12
+ import numpy as np
13
+ except ImportError :
14
+ import sys
15
+ print ("numpy is not installed" , file = sys .stderr )
11
16
12
17
from utbot_executor .deep_serialization .config import PICKLE_PROTO
13
18
from utbot_executor .deep_serialization .iterator_wrapper import IteratorWrapper
@@ -41,7 +46,7 @@ def __init__(self, obj: object) -> None:
41
46
self .id_ = PythonId (str (id (self .obj )))
42
47
43
48
def _initialize (
44
- self , deserialized_obj : object = None , comparable : bool = True
49
+ self , deserialized_obj : object = None , comparable : bool = True
45
50
) -> None :
46
51
self .deserialized_obj = deserialized_obj
47
52
self .comparable = comparable
@@ -111,14 +116,42 @@ def initialize(self) -> None:
111
116
elif self .typeinfo .fullname == "builtins.set" :
112
117
deserialized_obj = set (deserialized_obj )
113
118
119
+
114
120
comparable = all (serializer .get_by_id (elem ).comparable for elem in self .items )
115
121
116
122
super ()._initialize (deserialized_obj , comparable )
117
123
124
+ class NdarrayMemoryObject (MemoryObject ):
125
+ strategy : str = "ndarray"
126
+ items : List [PythonId ] = []
127
+ dimensions : List [int ] = []
128
+
129
+ def __init__ (self , ndarray_object : object ) -> None :
130
+ self .items : List [PythonId ] = []
131
+ super ().__init__ (ndarray_object )
132
+
133
+ def initialize (self ) -> None :
134
+ serializer = PythonSerializer ()
135
+ self .deserialized_obj = [] # for recursive collections
136
+ self .comparable = False # for recursive collections
137
+
138
+ temp_object = self .obj .copy ().flatten ()
139
+
140
+ self .dimensions = self .obj .shape
141
+ if temp_object .shape != (0 , ):
142
+ for elem in temp_object :
143
+ elem_id = serializer .write_object_to_memory (elem )
144
+ self .items .append (elem_id )
145
+ self .deserialized_obj .append (serializer [elem_id ])
146
+
147
+ deserialized_obj = self .deserialized_obj
148
+ comparable = all (serializer .get_by_id (elem ).comparable for elem in self .items ) if self .deserialized_obj != [] else True
149
+ super ()._initialize (deserialized_obj , comparable )
150
+
118
151
def __repr__ (self ) -> str :
119
152
if hasattr (self , "obj" ):
120
153
return str (self .obj )
121
- return f"{ self .typeinfo .kind } { self .items } "
154
+ return f"{ self .typeinfo .kind } { self .items } { self . dimensions } "
122
155
123
156
124
157
class DictMemoryObject (MemoryObject ):
@@ -264,10 +297,10 @@ def constructor_builder(self) -> typing.Tuple[typing.Any, typing.Callable]:
264
297
265
298
is_reconstructor = constructor_kind .qualname == "copyreg._reconstructor"
266
299
is_reduce_user_type = (
267
- len (self .reduce_value [1 ]) == 3
268
- and isinstance (self .reduce_value [1 ][0 ], type (self .obj ))
269
- and self .reduce_value [1 ][1 ] is object
270
- and self .reduce_value [1 ][2 ] is None
300
+ len (self .reduce_value [1 ]) == 3
301
+ and isinstance (self .reduce_value [1 ][0 ], type (self .obj ))
302
+ and self .reduce_value [1 ][1 ] is object
303
+ and self .reduce_value [1 ][2 ] is None
271
304
)
272
305
is_reduce_ex_user_type = len (self .reduce_value [1 ]) == 1 and isinstance (
273
306
self .reduce_value [1 ][0 ], type (self .obj )
@@ -294,8 +327,8 @@ def constructor_builder(self) -> typing.Tuple[typing.Any, typing.Callable]:
294
327
len (inspect .signature (init_method ).parameters ),
295
328
)
296
329
if (
297
- not init_from_object
298
- and len (inspect .signature (init_method ).parameters ) == 1
330
+ not init_from_object
331
+ and len (inspect .signature (init_method ).parameters ) == 1
299
332
) or init_from_object :
300
333
logging .debug ("init with one argument! %s" , init_method )
301
334
constructor_arguments = []
@@ -317,9 +350,9 @@ def constructor_builder(self) -> typing.Tuple[typing.Any, typing.Callable]:
317
350
if is_reconstructor and is_user_type :
318
351
constructor_arguments = self .reduce_value [1 ]
319
352
if (
320
- len (constructor_arguments ) == 3
321
- and constructor_arguments [- 1 ] is None
322
- and constructor_arguments [- 2 ] == object
353
+ len (constructor_arguments ) == 3
354
+ and constructor_arguments [- 1 ] is None
355
+ and constructor_arguments [- 2 ] == object
323
356
):
324
357
del constructor_arguments [1 :]
325
358
callable_constructor = object .__new__
@@ -392,6 +425,12 @@ def get_serializer(obj: object) -> Optional[Type[MemoryObject]]:
392
425
return ListMemoryObject
393
426
return None
394
427
428
+ class NdarrayMemoryObjectProvider (MemoryObjectProvider ):
429
+ @staticmethod
430
+ def get_serializer (obj : object ) -> Optional [Type [MemoryObject ]]:
431
+ if type (obj ) == np .ndarray :
432
+ return NdarrayMemoryObject
433
+ return None
395
434
396
435
class DictMemoryObjectProvider (MemoryObjectProvider ):
397
436
@staticmethod
@@ -425,6 +464,7 @@ def get_serializer(obj: object) -> Optional[Type[MemoryObject]]:
425
464
return None
426
465
427
466
467
+
428
468
class ReprMemoryObjectProvider (MemoryObjectProvider ):
429
469
@staticmethod
430
470
def get_serializer (obj : object ) -> Optional [Type [MemoryObject ]]:
@@ -450,6 +490,7 @@ class PythonSerializer:
450
490
visited : Set [PythonId ] = set ()
451
491
452
492
providers : List [MemoryObjectProvider ] = [
493
+ NdarrayMemoryObjectProvider ,
453
494
ListMemoryObjectProvider ,
454
495
DictMemoryObjectProvider ,
455
496
IteratorMemoryObjectProvider ,
0 commit comments