You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

308 lines
12 KiB

4 years ago
  1. #!/usr/bin/env python
  2. #
  3. # Author: Mike McKerns (mmckerns @caltech and @uqfoundation)
  4. # Copyright (c) 2008-2016 California Institute of Technology.
  5. # Copyright (c) 2016-2018 The Uncertainty Quantification Foundation.
  6. # License: 3-clause BSD. The full license text is available at:
  7. # - https://github.com/uqfoundation/dill/blob/master/LICENSE
  8. """
  9. Methods for detecting objects leading to pickling failures.
  10. """
  11. import dis
  12. from inspect import ismethod, isfunction, istraceback, isframe, iscode
  13. from .pointers import parent, reference, at, parents, children
  14. from ._dill import _trace as trace
  15. from ._dill import PY3
  16. __all__ = ['baditems','badobjects','badtypes','code','errors','freevars',
  17. 'getmodule','globalvars','nestedcode','nestedglobals','outermost',
  18. 'referredglobals','referrednested','trace','varnames']
  19. def getmodule(object, _filename=None, force=False):
  20. """get the module of the object"""
  21. from inspect import getmodule as getmod
  22. module = getmod(object, _filename)
  23. if module or not force: return module
  24. if PY3: builtins = 'builtins'
  25. else: builtins = '__builtin__'
  26. builtins = __import__(builtins)
  27. from .source import getname
  28. name = getname(object, force=True)
  29. return builtins if name in vars(builtins).keys() else None
  30. def outermost(func): # is analogous to getsource(func,enclosing=True)
  31. """get outermost enclosing object (i.e. the outer function in a closure)
  32. NOTE: this is the object-equivalent of getsource(func, enclosing=True)
  33. """
  34. if PY3:
  35. if ismethod(func):
  36. _globals = func.__func__.__globals__ or {}
  37. elif isfunction(func):
  38. _globals = func.__globals__ or {}
  39. else:
  40. return #XXX: or raise? no matches
  41. _globals = _globals.items()
  42. else:
  43. if ismethod(func):
  44. _globals = func.im_func.func_globals or {}
  45. elif isfunction(func):
  46. _globals = func.func_globals or {}
  47. else:
  48. return #XXX: or raise? no matches
  49. _globals = _globals.iteritems()
  50. # get the enclosing source
  51. from .source import getsourcelines
  52. try: lines,lnum = getsourcelines(func, enclosing=True)
  53. except: #TypeError, IOError
  54. lines,lnum = [],None
  55. code = ''.join(lines)
  56. # get all possible names,objects that are named in the enclosing source
  57. _locals = ((name,obj) for (name,obj) in _globals if name in code)
  58. # now only save the objects that generate the enclosing block
  59. for name,obj in _locals: #XXX: don't really need 'name'
  60. try:
  61. if getsourcelines(obj) == (lines,lnum): return obj
  62. except: #TypeError, IOError
  63. pass
  64. return #XXX: or raise? no matches
  65. def nestedcode(func, recurse=True): #XXX: or return dict of {co_name: co} ?
  66. """get the code objects for any nested functions (e.g. in a closure)"""
  67. func = code(func)
  68. if not iscode(func): return [] #XXX: or raise? no matches
  69. nested = set()
  70. for co in func.co_consts:
  71. if co is None: continue
  72. co = code(co)
  73. if co:
  74. nested.add(co)
  75. if recurse: nested |= set(nestedcode(co, recurse=True))
  76. return list(nested)
  77. def code(func):
  78. '''get the code object for the given function or method
  79. NOTE: use dill.source.getsource(CODEOBJ) to get the source code
  80. '''
  81. if PY3:
  82. im_func = '__func__'
  83. func_code = '__code__'
  84. else:
  85. im_func = 'im_func'
  86. func_code = 'func_code'
  87. if ismethod(func): func = getattr(func, im_func)
  88. if isfunction(func): func = getattr(func, func_code)
  89. if istraceback(func): func = func.tb_frame
  90. if isframe(func): func = func.f_code
  91. if iscode(func): return func
  92. return
  93. #XXX: ugly: parse dis.dis for name after "<code object" in line and in globals?
  94. def referrednested(func, recurse=True): #XXX: return dict of {__name__: obj} ?
  95. """get functions defined inside of func (e.g. inner functions in a closure)
  96. NOTE: results may differ if the function has been executed or not.
  97. If len(nestedcode(func)) > len(referrednested(func)), try calling func().
  98. If possible, python builds code objects, but delays building functions
  99. until func() is called.
  100. """
  101. if PY3:
  102. att1 = '__code__'
  103. att0 = '__func__'
  104. else:
  105. att1 = 'func_code' # functions
  106. att0 = 'im_func' # methods
  107. import gc
  108. funcs = set()
  109. # get the code objects, and try to track down by referrence
  110. for co in nestedcode(func, recurse):
  111. # look for function objects that refer to the code object
  112. for obj in gc.get_referrers(co):
  113. # get methods
  114. _ = getattr(obj, att0, None) # ismethod
  115. if getattr(_, att1, None) is co: funcs.add(obj)
  116. # get functions
  117. elif getattr(obj, att1, None) is co: funcs.add(obj)
  118. # get frame objects
  119. elif getattr(obj, 'f_code', None) is co: funcs.add(obj)
  120. # get code objects
  121. elif hasattr(obj, 'co_code') and obj is co: funcs.add(obj)
  122. # frameobjs => func.func_code.co_varnames not in func.func_code.co_cellvars
  123. # funcobjs => func.func_code.co_cellvars not in func.func_code.co_varnames
  124. # frameobjs are not found, however funcobjs are...
  125. # (see: test_mixins.quad ... and test_mixins.wtf)
  126. # after execution, code objects get compiled, and then may be found by gc
  127. return list(funcs)
  128. def freevars(func):
  129. """get objects defined in enclosing code that are referred to by func
  130. returns a dict of {name:object}"""
  131. if PY3:
  132. im_func = '__func__'
  133. func_code = '__code__'
  134. func_closure = '__closure__'
  135. else:
  136. im_func = 'im_func'
  137. func_code = 'func_code'
  138. func_closure = 'func_closure'
  139. if ismethod(func): func = getattr(func, im_func)
  140. if isfunction(func):
  141. closures = getattr(func, func_closure) or ()
  142. func = getattr(func, func_code).co_freevars # get freevars
  143. else:
  144. return {}
  145. return dict((name,c.cell_contents) for (name,c) in zip(func,closures))
  146. # thanks to Davies Liu for recursion of globals
  147. def nestedglobals(func, recurse=True):
  148. """get the names of any globals found within func"""
  149. func = code(func)
  150. if func is None: return list()
  151. from .temp import capture
  152. names = set()
  153. with capture('stdout') as out:
  154. dis.dis(func) #XXX: dis.dis(None) disassembles last traceback
  155. for line in out.getvalue().splitlines():
  156. if '_GLOBAL' in line:
  157. name = line.split('(')[-1].split(')')[0]
  158. names.add(name)
  159. for co in getattr(func, 'co_consts', tuple()):
  160. if co and recurse and iscode(co):
  161. names.update(nestedglobals(co, recurse=True))
  162. return list(names)
  163. def referredglobals(func, recurse=True, builtin=False):
  164. """get the names of objects in the global scope referred to by func"""
  165. return globalvars(func, recurse, builtin).keys()
  166. def globalvars(func, recurse=True, builtin=False):
  167. """get objects defined in global scope that are referred to by func
  168. return a dict of {name:object}"""
  169. if PY3:
  170. im_func = '__func__'
  171. func_code = '__code__'
  172. func_globals = '__globals__'
  173. func_closure = '__closure__'
  174. else:
  175. im_func = 'im_func'
  176. func_code = 'func_code'
  177. func_globals = 'func_globals'
  178. func_closure = 'func_closure'
  179. if ismethod(func): func = getattr(func, im_func)
  180. if isfunction(func):
  181. globs = vars(getmodule(sum)).copy() if builtin else {}
  182. # get references from within closure
  183. orig_func, func = func, set()
  184. for obj in getattr(orig_func, func_closure) or {}:
  185. _vars = globalvars(obj.cell_contents, recurse, builtin) or {}
  186. func.update(_vars) #XXX: (above) be wary of infinte recursion?
  187. globs.update(_vars)
  188. # get globals
  189. globs.update(getattr(orig_func, func_globals) or {})
  190. # get names of references
  191. if not recurse:
  192. func.update(getattr(orig_func, func_code).co_names)
  193. else:
  194. func.update(nestedglobals(getattr(orig_func, func_code)))
  195. # find globals for all entries of func
  196. for key in func.copy(): #XXX: unnecessary...?
  197. nested_func = globs.get(key)
  198. if nested_func == orig_func:
  199. #func.remove(key) if key in func else None
  200. continue #XXX: globalvars(func, False)?
  201. func.update(globalvars(nested_func, True, builtin))
  202. elif iscode(func):
  203. globs = vars(getmodule(sum)).copy() if builtin else {}
  204. #globs.update(globals())
  205. if not recurse:
  206. func = func.co_names # get names
  207. else:
  208. orig_func = func.co_name # to stop infinite recursion
  209. func = set(nestedglobals(func))
  210. # find globals for all entries of func
  211. for key in func.copy(): #XXX: unnecessary...?
  212. if key == orig_func:
  213. #func.remove(key) if key in func else None
  214. continue #XXX: globalvars(func, False)?
  215. nested_func = globs.get(key)
  216. func.update(globalvars(nested_func, True, builtin))
  217. else:
  218. return {}
  219. #NOTE: if name not in func_globals, then we skip it...
  220. return dict((name,globs[name]) for name in func if name in globs)
  221. def varnames(func):
  222. """get names of variables defined by func
  223. returns a tuple (local vars, local vars referrenced by nested functions)"""
  224. func = code(func)
  225. if not iscode(func):
  226. return () #XXX: better ((),())? or None?
  227. return func.co_varnames, func.co_cellvars
  228. def baditems(obj, exact=False, safe=False): #XXX: obj=globals() ?
  229. """get items in object that fail to pickle"""
  230. if not hasattr(obj,'__iter__'): # is not iterable
  231. return [j for j in (badobjects(obj,0,exact,safe),) if j is not None]
  232. obj = obj.values() if getattr(obj,'values',None) else obj
  233. _obj = [] # can't use a set, as items may be unhashable
  234. [_obj.append(badobjects(i,0,exact,safe)) for i in obj if i not in _obj]
  235. return [j for j in _obj if j is not None]
  236. def badobjects(obj, depth=0, exact=False, safe=False):
  237. """get objects that fail to pickle"""
  238. from dill import pickles
  239. if not depth:
  240. if pickles(obj,exact,safe): return None
  241. return obj
  242. return dict(((attr, badobjects(getattr(obj,attr),depth-1,exact,safe)) \
  243. for attr in dir(obj) if not pickles(getattr(obj,attr),exact,safe)))
  244. def badtypes(obj, depth=0, exact=False, safe=False):
  245. """get types for objects that fail to pickle"""
  246. from dill import pickles
  247. if not depth:
  248. if pickles(obj,exact,safe): return None
  249. return type(obj)
  250. return dict(((attr, badtypes(getattr(obj,attr),depth-1,exact,safe)) \
  251. for attr in dir(obj) if not pickles(getattr(obj,attr),exact,safe)))
  252. def errors(obj, depth=0, exact=False, safe=False):
  253. """get errors for objects that fail to pickle"""
  254. from dill import pickles, copy
  255. if not depth:
  256. try:
  257. pik = copy(obj)
  258. if exact:
  259. assert pik == obj, \
  260. "Unpickling produces %s instead of %s" % (pik,obj)
  261. assert type(pik) == type(obj), \
  262. "Unpickling produces %s instead of %s" % (type(pik),type(obj))
  263. return None
  264. except Exception:
  265. import sys
  266. return sys.exc_info()[1]
  267. _dict = {}
  268. for attr in dir(obj):
  269. try:
  270. _attr = getattr(obj,attr)
  271. except Exception:
  272. import sys
  273. _dict[attr] = sys.exc_info()[1]
  274. continue
  275. if not pickles(_attr,exact,safe):
  276. _dict[attr] = errors(_attr,depth-1,exact,safe)
  277. return _dict
  278. # EOF