@@ -519,10 +519,70 @@ class _Replacer(ast.NodeTransformer):
519
519
"""Replace literal nodes in an AST."""
520
520
521
521
def __init__ (self , mapping ):
522
- self .mapping = mapping
522
+ self .mapping = collections .defaultdict (list )
523
+ for key , value in mapping .items ():
524
+ key_ast = ast .parse (key , mode = "eval" ).body
525
+ self .mapping [key_ast .__class__ .__name__ ].append ((key_ast , value ))
526
+
527
+ def _no_match (self , left , right ):
528
+ return False
529
+
530
+ def _match (self , left , right ):
531
+ same_ctx = getattr (left , "ctx" , None ).__class__ is getattr (right , "ctx" , None ).__class__
532
+ cname = left .__class__ .__name__
533
+ same_type = right .__class__ .__name__ == cname
534
+ matcher = getattr (self , "_match_" + cname , self ._no_match )
535
+ return matcher (left , right ) if same_type and same_ctx else False
536
+
537
+ def _match_Constant (self , left , right ):
538
+ # we don't care about kind for u-strings
539
+ return type (left .value ) is type (right .value ) and left .value == right .value
540
+
541
+ def _match_Num (self , left , right ):
542
+ # Dreprecated, for Python <3.8
543
+ return type (left .n ) is type (right .n ) and left .n == right .n
544
+
545
+ def _match_Str (self , left , right ):
546
+ # Deprecated, for Python <3.8
547
+ return left .s == right .s
548
+
549
+ def _match_Name (self , left , right ):
550
+ return left .id == right .id
551
+
552
+ def _match_Attribute (self , left , right ):
553
+ return left .attr == right .attr and self ._match (left .value , right .value )
554
+
555
+ def _match_List (self , left , right ):
556
+ return len (left .elts ) == len (right .elts ) and all (
557
+ self ._match (left_ , right_ ) for left_ , right_ in zip (left .elts , right .elts )
558
+ )
559
+
560
+ def _match_Tuple (self , left , right ):
561
+ return self ._match_List (left , right )
562
+
563
+ def _match_ListComp (self , left , right ):
564
+ return (
565
+ self ._match (left .elt , right .elt )
566
+ and len (left .generators ) == len (right .generators )
567
+ and all (self ._match (left_ , right_ ) for left_ , right_ in zip (left .generators , right .generators ))
568
+ )
569
+
570
+ def _match_comprehension (self , left , right ):
571
+ return (
572
+ # async is not expected in our use cases, just for completeness
573
+ getattr (left , "is_async" , 0 ) == getattr (right , "is_async" , 0 )
574
+ and len (left .ifs ) == len (right .ifs )
575
+ and self ._match (left .target , right .target )
576
+ and self ._match (left .iter , right .iter )
577
+ and all (self ._match (left_ , right_ ) for left_ , right_ in zip (left .ifs , right .ifs ))
578
+ )
523
579
524
- def visit_Name (self , node ):
525
- return ast .Name (id = self .mapping [node .id ], ctx = ast .Load ()) if node .id in self .mapping else node
580
+ def visit (self , node ):
581
+ if node .__class__ .__name__ in self .mapping :
582
+ for target_ast , new in self .mapping [node .__class__ .__name__ ]:
583
+ if self ._match (node , target_ast ):
584
+ return ast .parse (new , mode = "eval" ).body
585
+ return super (_Replacer , self ).visit (node )
526
586
527
587
528
588
def literal_replace (expr , mapping ):
0 commit comments