Skip to content

Commit e37d57c

Browse files
committed
[IMP] util/misc: more literal replaces
We want to avoid string replacement for code. closes #222 Signed-off-by: Christophe Simonis (chs) <[email protected]>
1 parent f83f486 commit e37d57c

File tree

2 files changed

+84
-4
lines changed

2 files changed

+84
-4
lines changed

src/base/tests/test_util.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -1632,11 +1632,31 @@ def test_SelfPrint_failure(self, value):
16321632
"[('company_id','in', user.other.allowed_company_ids)]",
16331633
"[('company_id', 'in', user.other.allowed_company_ids)]",
16341634
),
1635+
(
1636+
"[('group_id','in', user.groups_id.ids)]",
1637+
"[('group_id', 'in', user.all_group_ids.ids)]",
1638+
),
1639+
(
1640+
"[('group_id','in', [g.id for g in user.groups_id])]",
1641+
"[('group_id', 'in', user.all_group_ids.ids)]",
1642+
),
1643+
(
1644+
"[(1, '=', 0), (1, '=', 1)]",
1645+
"[(0, '=', 1), (1, '=', 1)]",
1646+
),
16351647
]
16361648
)
16371649
@unittest.skipUnless(util.ast_unparse is not None, "`ast.unparse` available from Python3.9")
16381650
def test_literal_replace(self, orig, expected):
1639-
repl = util.literal_replace(orig, {"allowed_company_ids": "companies.active_ids"})
1651+
repl = util.literal_replace(
1652+
orig,
1653+
{
1654+
"allowed_company_ids": "companies.active_ids",
1655+
"user.groups_id.ids": "user.all_group_ids.ids",
1656+
"[g.id for g in user.groups_id]": "user.all_group_ids.ids",
1657+
"(1, '=', 0)": "(0, '=', 1)",
1658+
},
1659+
)
16401660
self.assertEqual(repl, expected)
16411661

16421662

src/util/misc.py

+63-3
Original file line numberDiff line numberDiff line change
@@ -519,10 +519,70 @@ class _Replacer(ast.NodeTransformer):
519519
"""Replace literal nodes in an AST."""
520520

521521
def __init__(self, mapping):
522-
self.mapping = mapping
522+
self.mapping = collections.defaultdict(list)
523+
for key, value in mapping.items():
524+
key_ast = ast.parse(key, mode="eval").body
525+
self.mapping[key_ast.__class__.__name__].append((key_ast, value))
526+
527+
def _no_match(self, left, right):
528+
return False
529+
530+
def _match(self, left, right):
531+
same_ctx = getattr(left, "ctx", None).__class__ is getattr(right, "ctx", None).__class__
532+
cname = left.__class__.__name__
533+
same_type = right.__class__.__name__ == cname
534+
matcher = getattr(self, "_match_" + cname, self._no_match)
535+
return matcher(left, right) if same_type and same_ctx else False
536+
537+
def _match_Constant(self, left, right):
538+
# we don't care about kind for u-strings
539+
return type(left.value) is type(right.value) and left.value == right.value
540+
541+
def _match_Num(self, left, right):
542+
# Dreprecated, for Python <3.8
543+
return type(left.n) is type(right.n) and left.n == right.n
544+
545+
def _match_Str(self, left, right):
546+
# Deprecated, for Python <3.8
547+
return left.s == right.s
548+
549+
def _match_Name(self, left, right):
550+
return left.id == right.id
551+
552+
def _match_Attribute(self, left, right):
553+
return left.attr == right.attr and self._match(left.value, right.value)
554+
555+
def _match_List(self, left, right):
556+
return len(left.elts) == len(right.elts) and all(
557+
self._match(left_, right_) for left_, right_ in zip(left.elts, right.elts)
558+
)
559+
560+
def _match_Tuple(self, left, right):
561+
return self._match_List(left, right)
562+
563+
def _match_ListComp(self, left, right):
564+
return (
565+
self._match(left.elt, right.elt)
566+
and len(left.generators) == len(right.generators)
567+
and all(self._match(left_, right_) for left_, right_ in zip(left.generators, right.generators))
568+
)
569+
570+
def _match_comprehension(self, left, right):
571+
return (
572+
# async is not expected in our use cases, just for completeness
573+
getattr(left, "is_async", 0) == getattr(right, "is_async", 0)
574+
and len(left.ifs) == len(right.ifs)
575+
and self._match(left.target, right.target)
576+
and self._match(left.iter, right.iter)
577+
and all(self._match(left_, right_) for left_, right_ in zip(left.ifs, right.ifs))
578+
)
523579

524-
def visit_Name(self, node):
525-
return ast.Name(id=self.mapping[node.id], ctx=ast.Load()) if node.id in self.mapping else node
580+
def visit(self, node):
581+
if node.__class__.__name__ in self.mapping:
582+
for target_ast, new in self.mapping[node.__class__.__name__]:
583+
if self._match(node, target_ast):
584+
return ast.parse(new, mode="eval").body
585+
return super(_Replacer, self).visit(node)
526586

527587

528588
def literal_replace(expr, mapping):

0 commit comments

Comments
 (0)