| OLD | NEW |
| 1 # copyright 2003-2013 LOGILAB S.A. (Paris, FRANCE), all rights reserved. | 1 # copyright 2003-2013 LOGILAB S.A. (Paris, FRANCE), all rights reserved. |
| 2 # contact http://www.logilab.fr/ -- mailto:contact@logilab.fr | 2 # contact http://www.logilab.fr/ -- mailto:contact@logilab.fr |
| 3 # | 3 # |
| 4 # This file is part of astroid. | 4 # This file is part of astroid. |
| 5 # | 5 # |
| 6 # astroid is free software: you can redistribute it and/or modify it | 6 # astroid is free software: you can redistribute it and/or modify it |
| 7 # under the terms of the GNU Lesser General Public License as published by the | 7 # under the terms of the GNU Lesser General Public License as published by the |
| 8 # Free Software Foundation, either version 2.1 of the License, or (at your | 8 # Free Software Foundation, either version 2.1 of the License, or (at your |
| 9 # option) any later version. | 9 # option) any later version. |
| 10 # | 10 # |
| 11 # astroid is distributed in the hope that it will be useful, but | 11 # astroid is distributed in the hope that it will be useful, but |
| 12 # WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or | 12 # WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or |
| 13 # FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License | 13 # FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License |
| 14 # for more details. | 14 # for more details. |
| 15 # | 15 # |
| 16 # You should have received a copy of the GNU Lesser General Public License along | 16 # You should have received a copy of the GNU Lesser General Public License along |
| 17 # with astroid. If not, see <http://www.gnu.org/licenses/>. | 17 # with astroid. If not, see <http://www.gnu.org/licenses/>. |
| 18 """This module contains base classes and functions for the nodes and some | 18 """This module contains base classes and functions for the nodes and some |
| 19 inference utils. | 19 inference utils. |
| 20 """ | 20 """ |
| 21 | 21 |
| 22 __docformat__ = "restructuredtext en" | 22 __docformat__ = "restructuredtext en" |
| 23 | 23 |
| 24 import sys | 24 import sys |
| 25 from contextlib import contextmanager | 25 from contextlib import contextmanager |
| 26 | 26 |
| 27 from logilab.common.decorators import cachedproperty |
| 28 |
| 27 from astroid.exceptions import (InferenceError, AstroidError, NotFoundError, | 29 from astroid.exceptions import (InferenceError, AstroidError, NotFoundError, |
| 28 UnresolvableName, UseInferenceDefault) | 30 UnresolvableName, UseInferenceDefault) |
| 29 | 31 |
| 30 | 32 |
| 31 if sys.version_info >= (3, 0): | 33 if sys.version_info >= (3, 0): |
| 32 BUILTINS = 'builtins' | 34 BUILTINS = 'builtins' |
| 33 else: | 35 else: |
| 34 BUILTINS = '__builtin__' | 36 BUILTINS = '__builtin__' |
| 35 | 37 |
| 36 | 38 |
| (...skipping 12 matching lines...) Expand all Loading... |
| 49 if name in self.__dict__: | 51 if name in self.__dict__: |
| 50 return self.__dict__[name] | 52 return self.__dict__[name] |
| 51 return getattr(self._proxied, name) | 53 return getattr(self._proxied, name) |
| 52 | 54 |
| 53 def infer(self, context=None): | 55 def infer(self, context=None): |
| 54 yield self | 56 yield self |
| 55 | 57 |
| 56 | 58 |
| 57 # Inference ################################################################## | 59 # Inference ################################################################## |
| 58 | 60 |
| 61 MISSING = object() |
| 62 |
| 63 |
| 59 class InferenceContext(object): | 64 class InferenceContext(object): |
| 60 __slots__ = ('path', 'lookupname', 'callcontext', 'boundnode') | 65 __slots__ = ('path', 'callcontext', 'boundnode', 'infered') |
| 61 | 66 |
| 62 def __init__(self, path=None): | 67 def __init__(self, |
| 68 path=None, callcontext=None, boundnode=None, infered=None): |
| 63 if path is None: | 69 if path is None: |
| 64 self.path = set() | 70 self.path = frozenset() |
| 65 else: | 71 else: |
| 66 self.path = path | 72 self.path = path |
| 67 self.lookupname = None | 73 self.callcontext = callcontext |
| 68 self.callcontext = None | 74 self.boundnode = boundnode |
| 69 self.boundnode = None | 75 if infered is None: |
| 76 self.infered = {} |
| 77 else: |
| 78 self.infered = infered |
| 70 | 79 |
| 71 def push(self, node): | 80 def push(self, key): |
| 72 name = self.lookupname | 81 # This returns a NEW context with the same attributes, but a new key |
| 73 if (node, name) in self.path: | 82 # added to `path`. The intention is that it's only passed to callees |
| 74 raise StopIteration() | 83 # and then destroyed; otherwise scope() may not work correctly. |
| 75 self.path.add((node, name)) | 84 # The cache will be shared, since it's the same exact dict. |
| 85 if key in self.path: |
| 86 # End the containing generator |
| 87 raise StopIteration |
| 76 | 88 |
| 77 def clone(self): | 89 return InferenceContext( |
| 78 # XXX copy lookupname/callcontext ? | 90 self.path.union([key]), |
| 79 clone = InferenceContext(self.path) | 91 self.callcontext, |
| 80 clone.callcontext = self.callcontext | 92 self.boundnode, |
| 81 clone.boundnode = self.boundnode | 93 self.infered, |
| 82 return clone | 94 ) |
| 83 | 95 |
| 84 @contextmanager | 96 @contextmanager |
| 85 def restore_path(self): | 97 def scope(self, callcontext=MISSING, boundnode=MISSING): |
| 86 path = set(self.path) | 98 try: |
| 87 yield | 99 orig = self.callcontext, self.boundnode |
| 88 self.path = path | 100 if callcontext is not MISSING: |
| 101 self.callcontext = callcontext |
| 102 if boundnode is not MISSING: |
| 103 self.boundnode = boundnode |
| 104 yield |
| 105 finally: |
| 106 self.callcontext, self.boundnode = orig |
| 89 | 107 |
| 90 def copy_context(context): | 108 def cache_generator(self, key, generator): |
| 91 if context is not None: | 109 results = [] |
| 92 return context.clone() | 110 for result in generator: |
| 93 else: | 111 results.append(result) |
| 94 return InferenceContext() | 112 yield result |
| 113 |
| 114 self.infered[key] = tuple(results) |
| 115 return |
| 95 | 116 |
| 96 | 117 |
| 97 def _infer_stmts(stmts, context, frame=None): | 118 def _infer_stmts(stmts, context, frame=None, lookupname=None): |
| 98 """return an iterator on statements inferred by each statement in <stmts> | 119 """return an iterator on statements inferred by each statement in <stmts> |
| 99 """ | 120 """ |
| 100 stmt = None | 121 stmt = None |
| 101 infered = False | 122 infered = False |
| 102 if context is not None: | 123 if context is None: |
| 103 name = context.lookupname | |
| 104 context = context.clone() | |
| 105 else: | |
| 106 name = None | |
| 107 context = InferenceContext() | 124 context = InferenceContext() |
| 108 for stmt in stmts: | 125 for stmt in stmts: |
| 109 if stmt is YES: | 126 if stmt is YES: |
| 110 yield stmt | 127 yield stmt |
| 111 infered = True | 128 infered = True |
| 112 continue | 129 continue |
| 113 context.lookupname = stmt._infer_name(frame, name) | 130 |
| 131 kw = {} |
| 132 infered_name = stmt._infer_name(frame, lookupname) |
| 133 if infered_name is not None: |
| 134 # only returns not None if .infer() accepts a lookupname kwarg |
| 135 kw['lookupname'] = infered_name |
| 136 |
| 114 try: | 137 try: |
| 115 for infered in stmt.infer(context): | 138 for infered in stmt.infer(context, **kw): |
| 116 yield infered | 139 yield infered |
| 117 infered = True | 140 infered = True |
| 118 except UnresolvableName: | 141 except UnresolvableName: |
| 119 continue | 142 continue |
| 120 except InferenceError: | 143 except InferenceError: |
| 121 yield YES | 144 yield YES |
| 122 infered = True | 145 infered = True |
| 123 if not infered: | 146 if not infered: |
| 124 raise InferenceError(str(stmt)) | 147 raise InferenceError(str(stmt)) |
| 125 | 148 |
| (...skipping 37 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 163 # well | 186 # well |
| 164 if lookupclass: | 187 if lookupclass: |
| 165 try: | 188 try: |
| 166 return values + self._proxied.getattr(name, context) | 189 return values + self._proxied.getattr(name, context) |
| 167 except NotFoundError: | 190 except NotFoundError: |
| 168 pass | 191 pass |
| 169 return values | 192 return values |
| 170 | 193 |
| 171 def igetattr(self, name, context=None): | 194 def igetattr(self, name, context=None): |
| 172 """inferred getattr""" | 195 """inferred getattr""" |
| 196 if not context: |
| 197 context = InferenceContext() |
| 173 try: | 198 try: |
| 174 # avoid recursively inferring the same attr on the same class | 199 # avoid recursively inferring the same attr on the same class |
| 175 if context: | 200 new_context = context.push((self._proxied, name)) |
| 176 context.push((self._proxied, name)) | |
| 177 # XXX frame should be self._proxied, or not ? | 201 # XXX frame should be self._proxied, or not ? |
| 178 get_attr = self.getattr(name, context, lookupclass=False) | 202 get_attr = self.getattr(name, new_context, lookupclass=False) |
| 179 return _infer_stmts(self._wrap_attr(get_attr, context), context, | 203 return _infer_stmts( |
| 180 frame=self) | 204 self._wrap_attr(get_attr, new_context), |
| 205 new_context, |
| 206 frame=self, |
| 207 ) |
| 181 except NotFoundError: | 208 except NotFoundError: |
| 182 try: | 209 try: |
| 183 # fallback to class'igetattr since it has some logic to handle | 210 # fallback to class'igetattr since it has some logic to handle |
| 184 # descriptors | 211 # descriptors |
| 185 return self._wrap_attr(self._proxied.igetattr(name, context), | 212 return self._wrap_attr(self._proxied.igetattr(name, context), |
| 186 context) | 213 context) |
| 187 except NotFoundError: | 214 except NotFoundError: |
| 188 raise InferenceError(name) | 215 raise InferenceError(name) |
| 189 | 216 |
| 190 def _wrap_attr(self, attrs, context=None): | 217 def _wrap_attr(self, attrs, context=None): |
| 191 """wrap bound methods of attrs in a InstanceMethod proxies""" | 218 """wrap bound methods of attrs in a InstanceMethod proxies""" |
| 192 for attr in attrs: | 219 for attr in attrs: |
| 193 if isinstance(attr, UnboundMethod): | 220 if isinstance(attr, UnboundMethod): |
| 194 if BUILTINS + '.property' in attr.decoratornames(): | 221 if BUILTINS + '.property' in attr.decoratornames(): |
| 195 for infered in attr.infer_call_result(self, context): | 222 for infered in attr.infer_call_result(self, context): |
| 196 yield infered | 223 yield infered |
| (...skipping 70 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 267 class BoundMethod(UnboundMethod): | 294 class BoundMethod(UnboundMethod): |
| 268 """a special node representing a method bound to an instance""" | 295 """a special node representing a method bound to an instance""" |
| 269 def __init__(self, proxy, bound): | 296 def __init__(self, proxy, bound): |
| 270 UnboundMethod.__init__(self, proxy) | 297 UnboundMethod.__init__(self, proxy) |
| 271 self.bound = bound | 298 self.bound = bound |
| 272 | 299 |
| 273 def is_bound(self): | 300 def is_bound(self): |
| 274 return True | 301 return True |
| 275 | 302 |
| 276 def infer_call_result(self, caller, context): | 303 def infer_call_result(self, caller, context): |
| 277 context = context.clone() | 304 with context.scope(boundnode=self.bound): |
| 278 context.boundnode = self.bound | 305 for infered in self._proxied.infer_call_result(caller, context): |
| 279 return self._proxied.infer_call_result(caller, context) | 306 yield infered |
| 280 | 307 |
| 281 | 308 |
| 282 class Generator(Instance): | 309 class Generator(Instance): |
| 283 """a special node representing a generator. | 310 """a special node representing a generator. |
| 284 | 311 |
| 285 Proxied class is set once for all in raw_building. | 312 Proxied class is set once for all in raw_building. |
| 286 """ | 313 """ |
| 287 def callable(self): | 314 def callable(self): |
| 288 return False | 315 return False |
| 289 | 316 |
| (...skipping 11 matching lines...) Expand all Loading... |
| 301 | 328 |
| 302 | 329 |
| 303 # decorators ################################################################## | 330 # decorators ################################################################## |
| 304 | 331 |
| 305 def path_wrapper(func): | 332 def path_wrapper(func): |
| 306 """return the given infer function wrapped to handle the path""" | 333 """return the given infer function wrapped to handle the path""" |
| 307 def wrapped(node, context=None, _func=func, **kwargs): | 334 def wrapped(node, context=None, _func=func, **kwargs): |
| 308 """wrapper function handling context""" | 335 """wrapper function handling context""" |
| 309 if context is None: | 336 if context is None: |
| 310 context = InferenceContext() | 337 context = InferenceContext() |
| 311 context.push(node) | 338 context = context.push((node, kwargs.get('lookupname'))) |
| 339 |
| 312 yielded = set() | 340 yielded = set() |
| 313 for res in _func(node, context, **kwargs): | 341 for res in _func(node, context, **kwargs): |
| 314 # unproxy only true instance, not const, tuple, dict... | 342 # unproxy only true instance, not const, tuple, dict... |
| 315 if res.__class__ is Instance: | 343 if res.__class__ is Instance: |
| 316 ares = res._proxied | 344 ares = res._proxied |
| 317 else: | 345 else: |
| 318 ares = res | 346 ares = res |
| 319 if not ares in yielded: | 347 if not ares in yielded: |
| 320 yield res | 348 yield res |
| 321 yielded.add(ares) | 349 yielded.add(ares) |
| (...skipping 48 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 370 | 398 |
| 371 If the instance has some explicit inference function set, it will be | 399 If the instance has some explicit inference function set, it will be |
| 372 called instead of the default interface. | 400 called instead of the default interface. |
| 373 """ | 401 """ |
| 374 if self._explicit_inference is not None: | 402 if self._explicit_inference is not None: |
| 375 # explicit_inference is not bound, give it self explicitly | 403 # explicit_inference is not bound, give it self explicitly |
| 376 try: | 404 try: |
| 377 return self._explicit_inference(self, context, **kwargs) | 405 return self._explicit_inference(self, context, **kwargs) |
| 378 except UseInferenceDefault: | 406 except UseInferenceDefault: |
| 379 pass | 407 pass |
| 380 return self._infer(context, **kwargs) | 408 |
| 409 if not context: |
| 410 return self._infer(context, **kwargs) |
| 411 |
| 412 key = (self, kwargs.get('lookupname'), context.callcontext, context.boun
dnode) |
| 413 if key in context.infered: |
| 414 return iter(context.infered[key]) |
| 415 |
| 416 return context.cache_generator(key, self._infer(context, **kwargs)) |
| 381 | 417 |
| 382 def _repr_name(self): | 418 def _repr_name(self): |
| 383 """return self.name or self.attrname or '' for nice representation""" | 419 """return self.name or self.attrname or '' for nice representation""" |
| 384 return getattr(self, 'name', getattr(self, 'attrname', '')) | 420 return getattr(self, 'name', getattr(self, 'attrname', '')) |
| 385 | 421 |
| 386 def __str__(self): | 422 def __str__(self): |
| 387 return '%s(%s)' % (self.__class__.__name__, self._repr_name()) | 423 return '%s(%s)' % (self.__class__.__name__, self._repr_name()) |
| 388 | 424 |
| 389 def __repr__(self): | 425 def __repr__(self): |
| 390 return '<%s(%s) l.%s [%s] at 0x%x>' % (self.__class__.__name__, | 426 return '<%s(%s) l.%s [%s] at 0x%x>' % (self.__class__.__name__, |
| (...skipping 17 matching lines...) Expand all Loading... |
| 408 yield elt | 444 yield elt |
| 409 else: | 445 else: |
| 410 yield attr | 446 yield attr |
| 411 | 447 |
| 412 def last_child(self): | 448 def last_child(self): |
| 413 """an optimized version of list(get_children())[-1]""" | 449 """an optimized version of list(get_children())[-1]""" |
| 414 for field in self._astroid_fields[::-1]: | 450 for field in self._astroid_fields[::-1]: |
| 415 attr = getattr(self, field) | 451 attr = getattr(self, field) |
| 416 if not attr: # None or empty listy / tuple | 452 if not attr: # None or empty listy / tuple |
| 417 continue | 453 continue |
| 418 if isinstance(attr, (list, tuple)): | 454 if attr.__class__ in (list, tuple): |
| 419 return attr[-1] | 455 return attr[-1] |
| 420 else: | 456 else: |
| 421 return attr | 457 return attr |
| 422 return None | 458 return None |
| 423 | 459 |
| 424 def parent_of(self, node): | 460 def parent_of(self, node): |
| 425 """return true if i'm a parent of the given node""" | 461 """return true if i'm a parent of the given node""" |
| 426 parent = node.parent | 462 parent = node.parent |
| 427 while parent is not None: | 463 while parent is not None: |
| 428 if self is parent: | 464 if self is parent: |
| (...skipping 70 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 499 assert node.root() is myroot, \ | 535 assert node.root() is myroot, \ |
| 500 'nodes %s and %s are not from the same module' % (self, node) | 536 'nodes %s and %s are not from the same module' % (self, node) |
| 501 lineno = node.fromlineno | 537 lineno = node.fromlineno |
| 502 if node.fromlineno > mylineno: | 538 if node.fromlineno > mylineno: |
| 503 break | 539 break |
| 504 if lineno > nearest[1]: | 540 if lineno > nearest[1]: |
| 505 nearest = node, lineno | 541 nearest = node, lineno |
| 506 # FIXME: raise an exception if nearest is None ? | 542 # FIXME: raise an exception if nearest is None ? |
| 507 return nearest[0] | 543 return nearest[0] |
| 508 | 544 |
| 509 def set_line_info(self, lastchild): | 545 # these are lazy because they're relatively expensive to compute for every |
| 546 # single node, and they rarely get looked at |
| 547 |
| 548 @cachedproperty |
| 549 def fromlineno(self): |
| 510 if self.lineno is None: | 550 if self.lineno is None: |
| 511 self.fromlineno = self._fixed_source_line() | 551 return self._fixed_source_line() |
| 512 else: | 552 else: |
| 513 self.fromlineno = self.lineno | 553 return self.lineno |
| 554 |
| 555 @cachedproperty |
| 556 def tolineno(self): |
| 557 if not self._astroid_fields: |
| 558 # can't have children |
| 559 lastchild = None |
| 560 else: |
| 561 lastchild = self.last_child() |
| 514 if lastchild is None: | 562 if lastchild is None: |
| 515 self.tolineno = self.fromlineno | 563 return self.fromlineno |
| 516 else: | 564 else: |
| 517 self.tolineno = lastchild.tolineno | 565 return lastchild.tolineno |
| 518 return | 566 |
| 519 # TODO / FIXME: | 567 # TODO / FIXME: |
| 520 assert self.fromlineno is not None, self | 568 assert self.fromlineno is not None, self |
| 521 assert self.tolineno is not None, self | 569 assert self.tolineno is not None, self |
| 522 | 570 |
| 523 def _fixed_source_line(self): | 571 def _fixed_source_line(self): |
| 524 """return the line number where the given node appears | 572 """return the line number where the given node appears |
| 525 | 573 |
| 526 we need this method since not all nodes have the lineno attribute | 574 we need this method since not all nodes have the lineno attribute |
| 527 correctly set... | 575 correctly set... |
| 528 """ | 576 """ |
| 529 line = self.lineno | 577 line = self.lineno |
| 530 _node = self | 578 _node = self |
| 531 try: | 579 try: |
| 532 while line is None: | 580 while line is None: |
| 533 _node = _node.get_children().next() | 581 _node = next(_node.get_children()) |
| 534 line = _node.lineno | 582 line = _node.lineno |
| 535 except StopIteration: | 583 except StopIteration: |
| 536 _node = self.parent | 584 _node = self.parent |
| 537 while _node and line is None: | 585 while _node and line is None: |
| 538 line = _node.lineno | 586 line = _node.lineno |
| 539 _node = _node.parent | 587 _node = _node.parent |
| 540 return line | 588 return line |
| 541 | 589 |
| 542 def block_range(self, lineno): | 590 def block_range(self, lineno): |
| 543 """handle block line numbers range for non block opening statements | 591 """handle block line numbers range for non block opening statements |
| (...skipping 64 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 608 return stmts[index +1] | 656 return stmts[index +1] |
| 609 except IndexError: | 657 except IndexError: |
| 610 pass | 658 pass |
| 611 | 659 |
| 612 def previous_sibling(self): | 660 def previous_sibling(self): |
| 613 """return the previous sibling statement""" | 661 """return the previous sibling statement""" |
| 614 stmts = self.parent.child_sequence(self) | 662 stmts = self.parent.child_sequence(self) |
| 615 index = stmts.index(self) | 663 index = stmts.index(self) |
| 616 if index >= 1: | 664 if index >= 1: |
| 617 return stmts[index -1] | 665 return stmts[index -1] |
| OLD | NEW |