Skip to content

Commit eb6d4f4

Browse files
Funth0masebroecker
authored andcommitted
attr.ibify CompareResult in compare module (#327)
Fix accessing protected memebers.
1 parent bfdb02b commit eb6d4f4

File tree

1 file changed

+28
-31
lines changed

1 file changed

+28
-31
lines changed

src/canmatrix/compare.py

+28-31
Original file line numberDiff line numberDiff line change
@@ -29,26 +29,23 @@
2929
import sys
3030
import typing
3131

32+
import attr
33+
3234
import canmatrix
3335

3436
logger = logging.getLogger(__name__)
3537
ConfigDict = typing.Optional[typing.Mapping[str, typing.Union[str, bool]]]
3638
WithAttribute = typing.TypeVar("WithAttribute", canmatrix.CanMatrix, canmatrix.Ecu, canmatrix.Frame, canmatrix.Signal)
3739

3840

41+
@attr.s
3942
class CompareResult(object):
4043
"""Hold comparison results in logical tree."""
41-
42-
def __init__(self, result=None, mtype=None, ref=None, changes=None):
43-
# type: (str, str, typing.Any, typing.List) -> None
44-
# any of equal, added, deleted, changed
45-
self.result = result
46-
# db, ecu, frame, signal, attribute
47-
self._type = mtype
48-
# reference to related object
49-
self._ref = ref
50-
self._changes = changes
51-
self._children = [] # type: typing.List[CompareResult]
44+
result = attr.ib(default=None) # type: typing.Optional[str] # any of equal, added, deleted, changed
45+
type = attr.ib(default=None) # type: typing.Optional[str] # db, ecu, frame, signal, signalGroup or attribute
46+
ref = attr.ib(default=None) # type: typing.Any # reference to related object
47+
changes = attr.ib(default=None) # type: typing.Optional[typing.List]
48+
_children = attr.ib(factory=list) # type: typing.List[CompareResult] # nested CompareResults
5249

5350
def add_child(self, child):
5451
# type: (CompareResult) -> None
@@ -112,15 +109,15 @@ def compare_db(db1, db2, ignore=None):
112109
db2.global_defines))
113110

114111
temp = compare_define_list(db1.ecu_defines, db2.ecu_defines)
115-
temp._type = "ECU Defines"
112+
temp.type = "ECU Defines"
116113
result.add_child(temp)
117114

118115
temp = compare_define_list(db1.frame_defines, db2.frame_defines)
119-
temp._type = "Frame Defines"
116+
temp.type = "Frame Defines"
120117
result.add_child(temp)
121118

122119
temp = compare_define_list(db1.signal_defines, db2.signal_defines)
123-
temp._type = "Signal Defines"
120+
temp.type = "Signal Defines"
124121
result.add_child(temp)
125122

126123
if "VALUETABLES" in ignore and ignore["VALUETABLES"]:
@@ -475,33 +472,33 @@ def compare_signal(s1, s2, ignore=None):
475472

476473
def dump_result(res, depth=0):
477474
# type: (CompareResult, int) -> None
478-
if res._type is not None and res.result != "equal":
475+
if res.type is not None and res.result != "equal":
479476
for _ in range(0, depth):
480477
print(" ", end=' ')
481-
print(res._type + " " + res.result + " ", end=' ')
482-
if hasattr(res._ref, 'name'):
483-
print(res._ref.name)
478+
print(res.type + " " + res.result + " ", end=' ')
479+
if hasattr(res.ref, 'name'):
480+
print(res.ref.name)
484481
else:
485482
print(" ")
486-
if res._changes is not None and res._changes[0] is not None and res._changes[1] is not None:
483+
if res.changes is not None and res.changes[0] is not None and res.changes[1] is not None:
487484
for _ in range(0, depth):
488485
print(" ", end=' ')
489-
print(type(res._changes[0]))
486+
print(type(res.changes[0]))
490487
if sys.version_info[0] < 3:
491-
if isinstance(res._changes[0], type(u'')):
492-
res._changes[0] = res._changes[0].encode('ascii', 'ignore')
493-
if isinstance(res._changes[1], type(u'')):
494-
res._changes[1] = res._changes[1].encode('ascii', 'ignore')
488+
if isinstance(res.changes[0], type(u'')):
489+
res.changes[0] = res.changes[0].encode('ascii', 'ignore')
490+
if isinstance(res.changes[1], type(u'')):
491+
res.changes[1] = res.changes[1].encode('ascii', 'ignore')
495492
else:
496-
if type(res._changes[0]) == str:
497-
res._changes[0] = res._changes[0].encode('ascii', 'ignore')
498-
if type(res._changes[1]) == str:
499-
res._changes[1] = res._changes[1].encode('ascii', 'ignore')
493+
if type(res.changes[0]) == str:
494+
res.changes[0] = res.changes[0].encode('ascii', 'ignore')
495+
if type(res.changes[1]) == str:
496+
res.changes[1] = res.changes[1].encode('ascii', 'ignore')
500497
print("old: " +
501-
str(res._changes[0]) +
498+
str(res.changes[0]) +
502499
" new: " +
503-
str(res._changes[1]))
504-
for child in res._children:
500+
str(res.changes[1]))
501+
for child in res.children:
505502
dump_result(child, depth + 1)
506503

507504

0 commit comments

Comments
 (0)