@@ -285,6 +285,11 @@ def flatten(node):
285
285
return '' .join (acc )
286
286
287
287
288
+ def make_xml (text ):
289
+ xml = ET .XML ('<xml>%s</xml>' % text )
290
+ return xml
291
+
292
+
288
293
def normalize_xpath (path ):
289
294
path = path .replace ("{{channel}}" , channel )
290
295
if path .startswith ('//' ):
@@ -401,7 +406,7 @@ def get_tree_count(tree, path):
401
406
return len (tree .findall (path ))
402
407
403
408
404
- def check_snapshot (snapshot_name , tree , normalize_to_text ):
409
+ def check_snapshot (snapshot_name , actual_tree , normalize_to_text ):
405
410
assert rust_test_path .endswith ('.rs' )
406
411
snapshot_path = '{}.{}.{}' .format (rust_test_path [:- 3 ], snapshot_name , 'html' )
407
412
try :
@@ -414,11 +419,15 @@ def check_snapshot(snapshot_name, tree, normalize_to_text):
414
419
raise FailedCheck ('No saved snapshot value' )
415
420
416
421
if not normalize_to_text :
417
- actual_str = ET .tostring (tree ).decode ('utf-8' )
422
+ actual_str = ET .tostring (actual_tree ).decode ('utf-8' )
418
423
else :
419
- actual_str = flatten (tree )
424
+ actual_str = flatten (actual_tree )
425
+
426
+ if not expected_str \
427
+ or (not normalize_to_text and
428
+ not compare_tree (make_xml (actual_str ), make_xml (expected_str ), stderr )) \
429
+ or (normalize_to_text and actual_str != expected_str ):
420
430
421
- if expected_str != actual_str :
422
431
if bless :
423
432
with open (snapshot_path , 'w' ) as snapshot_file :
424
433
snapshot_file .write (actual_str )
@@ -430,6 +439,59 @@ def check_snapshot(snapshot_name, tree, normalize_to_text):
430
439
print ()
431
440
raise FailedCheck ('Actual snapshot value is different than expected' )
432
441
442
+
443
+ # Adapted from https://github.com/formencode/formencode/blob/3a1ba9de2fdd494dd945510a4568a3afeddb0b2e/formencode/doctest_xml_compare.py#L72-L120
444
+ def compare_tree (x1 , x2 , reporter = None ):
445
+ if x1 .tag != x2 .tag :
446
+ if reporter :
447
+ reporter ('Tags do not match: %s and %s' % (x1 .tag , x2 .tag ))
448
+ return False
449
+ for name , value in x1 .attrib .items ():
450
+ if x2 .attrib .get (name ) != value :
451
+ if reporter :
452
+ reporter ('Attributes do not match: %s=%r, %s=%r'
453
+ % (name , value , name , x2 .attrib .get (name )))
454
+ return False
455
+ for name in x2 .attrib :
456
+ if name not in x1 .attrib :
457
+ if reporter :
458
+ reporter ('x2 has an attribute x1 is missing: %s'
459
+ % name )
460
+ return False
461
+ if not text_compare (x1 .text , x2 .text ):
462
+ if reporter :
463
+ reporter ('text: %r != %r' % (x1 .text , x2 .text ))
464
+ return False
465
+ if not text_compare (x1 .tail , x2 .tail ):
466
+ if reporter :
467
+ reporter ('tail: %r != %r' % (x1 .tail , x2 .tail ))
468
+ return False
469
+ cl1 = list (x1 )
470
+ cl2 = list (x2 )
471
+ if len (cl1 ) != len (cl2 ):
472
+ if reporter :
473
+ reporter ('children length differs, %i != %i'
474
+ % (len (cl1 ), len (cl2 )))
475
+ return False
476
+ i = 0
477
+ for c1 , c2 in zip (cl1 , cl2 ):
478
+ i += 1
479
+ if not compare_tree (c1 , c2 , reporter = reporter ):
480
+ if reporter :
481
+ reporter ('children %i do not match: %s'
482
+ % (i , c1 .tag ))
483
+ return False
484
+ return True
485
+
486
+
487
+ def text_compare (t1 , t2 ):
488
+ if not t1 and not t2 :
489
+ return True
490
+ if t1 == '*' or t2 == '*' :
491
+ return True
492
+ return (t1 or '' ).strip () == (t2 or '' ).strip ()
493
+
494
+
433
495
def stderr (* args ):
434
496
if sys .version_info .major < 3 :
435
497
file = codecs .getwriter ('utf-8' )(sys .stderr )
0 commit comments