|
1 | 1 | from contextlib import contextmanager
|
2 |
| -from typing import Dict, Iterator, List |
| 2 | +from typing import Dict, Iterator, List, Set |
3 | 3 | from typing_extensions import Final
|
4 | 4 |
|
5 | 5 | from mypy.nodes import (
|
6 | 6 | Block, AssignmentStmt, NameExpr, MypyFile, FuncDef, Lvalue, ListExpr, TupleExpr,
|
7 | 7 | WhileStmt, ForStmt, BreakStmt, ContinueStmt, TryStmt, WithStmt, MatchStmt, StarExpr,
|
8 |
| - ImportFrom, MemberExpr, IndexExpr, Import, ClassDef |
| 8 | + ImportFrom, MemberExpr, IndexExpr, Import, ImportAll, ClassDef |
9 | 9 | )
|
10 | 10 | from mypy.patterns import AsPattern
|
11 | 11 | from mypy.traverser import TraverserVisitor
|
@@ -262,15 +262,9 @@ def flush_refs(self) -> None:
|
262 | 262 | # as it will be publicly visible outside the module.
|
263 | 263 | to_rename = refs[:-1]
|
264 | 264 | for i, item in enumerate(to_rename):
|
265 |
| - self.rename_refs(item, i) |
| 265 | + rename_refs(item, i) |
266 | 266 | self.refs.pop()
|
267 | 267 |
|
268 |
| - def rename_refs(self, names: List[NameExpr], index: int) -> None: |
269 |
| - name = names[0].name |
270 |
| - new_name = name + "'" * (index + 1) |
271 |
| - for expr in names: |
272 |
| - expr.name = new_name |
273 |
| - |
274 | 268 | # Helpers for determining which assignments define new variables
|
275 | 269 |
|
276 | 270 | def clear(self) -> None:
|
@@ -392,3 +386,162 @@ def record_assignment(self, name: str, can_be_redefined: bool) -> bool:
|
392 | 386 | else:
|
393 | 387 | # Assigns to an existing variable.
|
394 | 388 | return False
|
| 389 | + |
| 390 | + |
| 391 | +class LimitedVariableRenameVisitor(TraverserVisitor): |
| 392 | + """Perform some limited variable renaming in with statements. |
| 393 | +
|
| 394 | + This allows reusing a variable in multiple with statements with |
| 395 | + different types. For example, the two instances of 'x' can have |
| 396 | + incompatible types: |
| 397 | +
|
| 398 | + with C() as x: |
| 399 | + f(x) |
| 400 | + with D() as x: |
| 401 | + g(x) |
| 402 | +
|
| 403 | + The above code gets renamed conceptually into this (not valid Python!): |
| 404 | +
|
| 405 | + with C() as x': |
| 406 | + f(x') |
| 407 | + with D() as x: |
| 408 | + g(x) |
| 409 | +
|
| 410 | + If there's a reference to a variable defined in 'with' outside the |
| 411 | + statement, or if there's any trickiness around variable visibility |
| 412 | + (e.g. function definitions), we give up and won't perform renaming. |
| 413 | +
|
| 414 | + The main use case is to allow binding both readable and writable |
| 415 | + binary files into the same variable. These have different types: |
| 416 | +
|
| 417 | + with open(fnam, 'rb') as f: ... |
| 418 | + with open(fnam, 'wb') as f: ... |
| 419 | + """ |
| 420 | + |
| 421 | + def __init__(self) -> None: |
| 422 | + # Short names of variables bound in with statements using "as" |
| 423 | + # in a surrounding scope |
| 424 | + self.bound_vars: List[str] = [] |
| 425 | + # Names that can't be safely renamed, per scope ('*' means that |
| 426 | + # no names can be renamed) |
| 427 | + self.skipped: List[Set[str]] = [] |
| 428 | + # References to variables that we may need to rename. List of |
| 429 | + # scopes; each scope is a mapping from name to list of collections |
| 430 | + # of names that refer to the same logical variable. |
| 431 | + self.refs: List[Dict[str, List[List[NameExpr]]]] = [] |
| 432 | + |
| 433 | + def visit_mypy_file(self, file_node: MypyFile) -> None: |
| 434 | + """Rename variables within a file. |
| 435 | +
|
| 436 | + This is the main entry point to this class. |
| 437 | + """ |
| 438 | + with self.enter_scope(): |
| 439 | + for d in file_node.defs: |
| 440 | + d.accept(self) |
| 441 | + |
| 442 | + def visit_func_def(self, fdef: FuncDef) -> None: |
| 443 | + self.reject_redefinition_of_vars_in_scope() |
| 444 | + with self.enter_scope(): |
| 445 | + for arg in fdef.arguments: |
| 446 | + self.record_skipped(arg.variable.name) |
| 447 | + super().visit_func_def(fdef) |
| 448 | + |
| 449 | + def visit_class_def(self, cdef: ClassDef) -> None: |
| 450 | + self.reject_redefinition_of_vars_in_scope() |
| 451 | + with self.enter_scope(): |
| 452 | + super().visit_class_def(cdef) |
| 453 | + |
| 454 | + def visit_with_stmt(self, stmt: WithStmt) -> None: |
| 455 | + for expr in stmt.expr: |
| 456 | + expr.accept(self) |
| 457 | + old_len = len(self.bound_vars) |
| 458 | + for target in stmt.target: |
| 459 | + if target is not None: |
| 460 | + self.analyze_lvalue(target) |
| 461 | + for target in stmt.target: |
| 462 | + if target: |
| 463 | + target.accept(self) |
| 464 | + stmt.body.accept(self) |
| 465 | + |
| 466 | + while len(self.bound_vars) > old_len: |
| 467 | + self.bound_vars.pop() |
| 468 | + |
| 469 | + def analyze_lvalue(self, lvalue: Lvalue) -> None: |
| 470 | + if isinstance(lvalue, NameExpr): |
| 471 | + name = lvalue.name |
| 472 | + if name in self.bound_vars: |
| 473 | + # Name bound in a surrounding with statement, so it can be renamed |
| 474 | + self.visit_name_expr(lvalue) |
| 475 | + else: |
| 476 | + var_info = self.refs[-1] |
| 477 | + if name not in var_info: |
| 478 | + var_info[name] = [] |
| 479 | + var_info[name].append([]) |
| 480 | + self.bound_vars.append(name) |
| 481 | + elif isinstance(lvalue, (ListExpr, TupleExpr)): |
| 482 | + for item in lvalue.items: |
| 483 | + self.analyze_lvalue(item) |
| 484 | + elif isinstance(lvalue, MemberExpr): |
| 485 | + lvalue.expr.accept(self) |
| 486 | + elif isinstance(lvalue, IndexExpr): |
| 487 | + lvalue.base.accept(self) |
| 488 | + lvalue.index.accept(self) |
| 489 | + elif isinstance(lvalue, StarExpr): |
| 490 | + self.analyze_lvalue(lvalue.expr) |
| 491 | + |
| 492 | + def visit_import(self, imp: Import) -> None: |
| 493 | + # We don't support renaming imports |
| 494 | + for id, as_id in imp.ids: |
| 495 | + self.record_skipped(as_id or id) |
| 496 | + |
| 497 | + def visit_import_from(self, imp: ImportFrom) -> None: |
| 498 | + # We don't support renaming imports |
| 499 | + for id, as_id in imp.names: |
| 500 | + self.record_skipped(as_id or id) |
| 501 | + |
| 502 | + def visit_import_all(self, imp: ImportAll) -> None: |
| 503 | + # Give up, since we don't know all imported names yet |
| 504 | + self.reject_redefinition_of_vars_in_scope() |
| 505 | + |
| 506 | + def visit_name_expr(self, expr: NameExpr) -> None: |
| 507 | + name = expr.name |
| 508 | + if name in self.bound_vars: |
| 509 | + # Record reference so that it can be renamed later |
| 510 | + for scope in reversed(self.refs): |
| 511 | + if name in scope: |
| 512 | + scope[name][-1].append(expr) |
| 513 | + else: |
| 514 | + self.record_skipped(name) |
| 515 | + |
| 516 | + @contextmanager |
| 517 | + def enter_scope(self) -> Iterator[None]: |
| 518 | + self.skipped.append(set()) |
| 519 | + self.refs.append({}) |
| 520 | + yield None |
| 521 | + self.flush_refs() |
| 522 | + |
| 523 | + def reject_redefinition_of_vars_in_scope(self) -> None: |
| 524 | + self.record_skipped('*') |
| 525 | + |
| 526 | + def record_skipped(self, name: str) -> None: |
| 527 | + self.skipped[-1].add(name) |
| 528 | + |
| 529 | + def flush_refs(self) -> None: |
| 530 | + ref_dict = self.refs.pop() |
| 531 | + skipped = self.skipped.pop() |
| 532 | + if '*' not in skipped: |
| 533 | + for name, refs in ref_dict.items(): |
| 534 | + if len(refs) <= 1 or name in skipped: |
| 535 | + continue |
| 536 | + # At module top level we must not rename the final definition, |
| 537 | + # as it may be publicly visible |
| 538 | + to_rename = refs[:-1] |
| 539 | + for i, item in enumerate(to_rename): |
| 540 | + rename_refs(item, i) |
| 541 | + |
| 542 | + |
| 543 | +def rename_refs(names: List[NameExpr], index: int) -> None: |
| 544 | + name = names[0].name |
| 545 | + new_name = name + "'" * (index + 1) |
| 546 | + for expr in names: |
| 547 | + expr.name = new_name |
0 commit comments