Source code for kwutil.util_eval

"""
Defines a safer eval function. See references for related work. It is possible
that one of those libraries may be better suited.

We currently vendor code from [evalidate]_.

References:
    .. [pure-eval] https://pypi.org/project/pure-eval/
    .. [safeeval] https://pypi.org/project/safeeval/
    .. [evalidate] https://pypi.org/project/evalidate/
    .. [PythonSafeEval] https://pypi.org/project/PythonSafeEval/
    .. [numexpr] https://pypi.org/project/numexpr/2.6.1/
"""
import ast


[docs] class RestrictedSyntaxError(Exception): """ An exception raised by restricted_eval if a disallowed expression is given """
[docs] def restricted_eval(expr, max_chars=32, local_dict=None, builtins_passlist=None): """ A restricted form of Python's eval that is meant to be slightly safer Args: expr (str): the expression to evaluate max_char (int): expression cannot be more than this many characters local_dict (Dict[str, Any]): a list of variables allowed to be used builtins_passlist (List[str] | None) : if specified, only allow use of certain builtins References: https://realpython.com/python-eval-function/#minimizing-the-security-issues-of-eval Notes: This function may not be safe, but it has as many mitigation measures that I know about. This function should be audited and possibly made even more restricted. The idea is that this should just be used to evaluate numeric expressions. Example: >>> from kwutil.util_eval import * # NOQA >>> builtins_passlist = ['min', 'max', 'round', 'sum'] >>> local_dict = {} >>> max_chars = 32 >>> expr = 'max(3 + 2, 9)' >>> result = restricted_eval(expr, max_chars, local_dict, builtins_passlist) >>> expr = '3 + 2' >>> result = restricted_eval(expr, max_chars, local_dict, builtins_passlist) >>> expr = '3 + 2' >>> result = restricted_eval(expr, max_chars) >>> import pytest >>> with pytest.raises(RestrictedSyntaxError): >>> expr = 'max(a + 2, 3)' >>> result = restricted_eval(expr, max_chars, dict(a=3)) """ import builtins if type(expr) is not str: raise TypeError(f'expr must be a pure str. Got {type(expr)}') if len(expr) > max_chars: raise RestrictedSyntaxError( 'num-workers-hueristic should be small text. ' 'We want to disallow attempts at crashing python ' 'by feeding nasty input into eval. But this may still ' 'be dangerous.' ) if local_dict is None: local_dict = {} if builtins_passlist is None: builtins_passlist = [] _builtins_passlist = set(builtins_passlist) allowed_builtins = {k: v for k, v in builtins.__dict__.items() if k in _builtins_passlist} local_dict['__builtins__'] = allowed_builtins allowed_names = list(allowed_builtins.keys()) + list(local_dict.keys()) code = compile(expr, "<string>", "eval") # Step 3 for name in code.co_names: if name not in allowed_names: raise RestrictedSyntaxError(f"Use of {name} not allowed") # result = eval(code, local_dict, local_dict) result = safeeval(expr, local_dict, addnodes=['Call'], funcs=builtins_passlist) return result
[docs] class EvalException(Exception): ...
[docs] class ValidationException(EvalException): ...
[docs] class CompilationException(EvalException): exc = None def __init__(self, exc): super().__init__(exc) self.exc = exc
[docs] class ExecutionException(EvalException): exc = None def __init__(self, exc): super().__init__(exc) self.exc = exc
[docs] class SafeAST(ast.NodeVisitor): """AST-tree walker class.""" allowed = {} def __init__(self, safenodes=None, addnodes=None, funcs=None, attrs=None): """create whitelist of allowed operations.""" if safenodes is not None: self.allowed = safenodes else: """ To generate these: # Get all subclasses of ast.AST recursively nodes = set() def recurse(cls): nodes.add(cls.__name__) for subclass in cls.__subclasses__(): recurse(subclass) blocklist = { # Base classes / abstract or special typing constructs not actual AST nodes 'AST', # base class, not a node type itself 'EnhancedAST', # not a standard ast node (likely custom or invalid) # Typing / parameter constructs, not AST nodes 'ParamSpec', 'TypeAlias', 'TypeIgnore', 'TypeVar', 'TypeVarTuple', 'Param', 'type_param', 'type_ignore', # Possibly ambiguous or internal helpers (not node classes) 'AugLoad', 'AugStore', # Special pseudo nodes or deprecated 'FunctionType', # not an AST node 'Suite', # does not exist in Python ast # Others that are unlikely to be AST nodes 'Div', # should be 'Div' operator but actually no class named Div; it's 'Div' operator type } recurse(ast.AST) final_list = [n for n in nodes if n not in blocklist and n.lower() != n] print(f'final_list = {ub.urepr(final_list, nl=1)}') [ 'Not', 'YieldFrom', 'Pass', 'BitAnd', 'DictComp', 'TryStar', 'LShift', 'Dict', 'AsyncWith', 'Num', 'MatMult', 'BinOp', 'AsyncFor', 'Ellipsis', 'MatchClass', 'NotIn', 'Del', 'MatchValue', 'GtE', 'ClassDef', 'Slice', 'NotEq', 'Constant', 'Starred', 'UAdd', 'Compare', 'Break', 'FloorDiv', 'GeneratorExp', 'Assert', 'Global', 'RShift', 'Set', 'Match', 'Str', 'UnaryOp', 'While', 'Add', 'MatchStar', 'AugAssign', 'Interactive', 'MatchSingleton', 'Is', 'Return', 'Expression', 'Lt', 'Attribute', 'Name', 'USub', 'Call', 'Continue', 'Await', 'Import', 'Nonlocal', 'BitOr', 'ListComp', 'Raise', 'Bytes', 'Mod', 'Lambda', 'If', 'Sub', 'Index', 'BoolOp', 'Module', 'MatchAs', 'AnnAssign', 'And', 'MatchSequence', 'IsNot', 'Delete', 'Load', 'LtE', 'FunctionDef', 'Subscript', 'List', 'FormattedValue', 'Invert', 'Yield', 'ImportFrom', 'Expr', 'BitXor', 'SetComp', 'Try', 'NameConstant', 'Pow', 'IfExp', 'With', 'Mult', 'ExtSlice', 'NamedExpr', 'MatchOr', 'ExceptHandler', 'For', 'Or', 'MatchMapping', 'In', 'Assign', 'Store', 'Gt', 'AsyncFunctionDef', 'Tuple', 'Eq', 'JoinedStr', ] """ # 123, 'asdf' values = ['Num', 'Str'] # any expression or constant expression = ['Expression'] constant = ['Constant', 'NameConstant'] # == ... compare = ['Compare', 'Eq', 'NotEq', 'Gt', 'GtE', 'Lt', 'LtE'] # variable name variables = ['Name', 'Load'] binop = ['BinOp'] arithmetics = ['Add', 'Sub', 'Mult', 'Div', 'FloorDiv', 'Mod', 'Pow', 'MatMult'] # added common arith ops, including MatMult and FloorDiv subscript = ['Subscript', 'Index', 'Slice'] # person['name'] boolop = ['BoolOp', 'And', 'Or', 'UnaryOp', 'Not', 'Invert', 'UAdd', 'USub'] # added Invert and unary +/- inop = ["In", "NotIn"] # "aaa" in i['list'] ifop = ["IfExp"] # for if expressions, like: expr1 if expr2 else expr3 div = ["Div", "Mod", "FloorDiv"] self.allowed = expression + constant + values + compare + \ variables + binop + arithmetics + subscript + boolop + \ inop + ifop + div self.allowed_funcs = funcs or list() self.allowed_attrs = attrs or list() if addnodes is not None: self.allowed = self.allowed + addnodes
[docs] def generic_visit(self, node): """Check node, raise exception if node is not in whitelist.""" if type(node).__name__ in self.allowed: if isinstance(node, ast.Attribute): if node.attr not in self.allowed_attrs: raise ValidationException( "Attribute {aname} is not allowed".format( aname=node.attr)) if isinstance(node, ast.Call): if isinstance(node.func, ast.Name): if node.func.id not in self.allowed_funcs: raise ValidationException( "Call to function {fname}() is not allowed".format( fname=node.func.id)) else: # Call to allowed function. good. No exception pass elif isinstance(node.func, ast.Attribute): pass # print("attr:", node.func.attr) else: raise ValidationException('Indirect function call') ast.NodeVisitor.generic_visit(self, node) else: raise ValidationException( "Operation type {optype} is not allowed".format( optype=type(node).__name__))
[docs] def evalidate(expression, safenodes=None, addnodes=None, funcs=None, attrs=None): """Validate expression. return node if it passes our checks or pass exception from SafeAST visit. """ try: node = ast.parse(expression, '<usercode>', 'eval') except SyntaxError as e: raise CompilationException(e) v = SafeAST(safenodes, addnodes, funcs, attrs) v.visit(node) return node
[docs] def safeeval(expression, context={}, safenodes=None, addnodes=None, funcs=None, attrs=None): """C-style simplified wrapper, eval() replacement. Args: expr (str): the expression to evaluate context (dict): Optional dictionary of variables to make available during evaluation. safenodes (List[str] | None): Specify the name of allowed AST nodes, if unspecified a default list is used. addnodes (List[str] | None): List of additional AST node names to allow in addition to safenodes. funcs (List[str]): list of allowed function names. attrs (List[str]): list of allowed attribute names. Returns: Any: the result of the expression Raises: ExecutionException - if the expression fails to execute CompilationException - if the expression fails to parse ValidationException - if the expression fails safety checks Example: >>> from kwutil.util_eval import safeeval >>> safeeval('3 + 2') 5 >>> safeeval('max(3, 2)', addnodes=['Call'], funcs=['max']) 3 >>> safeeval('x * 2', context={'x': 5}) 10 >>> safeeval('len(lst)', context={'lst': [1, 2, 3]}, addnodes=['Call'], funcs=['len']) 3 >>> safeeval('obj.value', context={'obj': type("O", (), {"value": 42})()}, addnodes=['Attribute'], attrs=['value']) 42 >>> import pytest >>> with pytest.raises(ValidationException): ... safeeval('exec("import os")') >>> with pytest.raises(ValidationException): ... safeeval('os.system("ls")', context={'os': __import__('os')}, addnodes=['Call', 'Attribute'], funcs=[]) """ # ValidationException thrown here node = evalidate(expression, safenodes, addnodes, funcs, attrs) code = compile(node, '<usercode>', 'eval') wcontext = context.copy() try: result = eval(code, wcontext) except Exception as e: raise ExecutionException(e) return result