186 lines
5.7 KiB
Python
186 lines
5.7 KiB
Python
|
# Copyright 2017 The Abseil Authors.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
|
||
|
"""Decorator and context manager for saving and restoring flag values.
|
||
|
|
||
|
There are many ways to save and restore. Always use the most convenient method
|
||
|
for a given use case.
|
||
|
|
||
|
Here are examples of each method. They all call do_stuff() while FLAGS.someflag
|
||
|
is temporarily set to 'foo'.
|
||
|
|
||
|
from absl.testing import flagsaver
|
||
|
|
||
|
# Use a decorator which can optionally override flags via arguments.
|
||
|
@flagsaver.flagsaver(someflag='foo')
|
||
|
def some_func():
|
||
|
do_stuff()
|
||
|
|
||
|
# Use a decorator which does not override flags itself.
|
||
|
@flagsaver.flagsaver
|
||
|
def some_func():
|
||
|
FLAGS.someflag = 'foo'
|
||
|
do_stuff()
|
||
|
|
||
|
# Use a context manager which can optionally override flags via arguments.
|
||
|
with flagsaver.flagsaver(someflag='foo'):
|
||
|
do_stuff()
|
||
|
|
||
|
# Save and restore the flag values yourself.
|
||
|
saved_flag_values = flagsaver.save_flag_values()
|
||
|
try:
|
||
|
FLAGS.someflag = 'foo'
|
||
|
do_stuff()
|
||
|
finally:
|
||
|
flagsaver.restore_flag_values(saved_flag_values)
|
||
|
|
||
|
We save and restore a shallow copy of each Flag object's __dict__ attribute.
|
||
|
This preserves all attributes of the flag, such as whether or not it was
|
||
|
overridden from its default value.
|
||
|
|
||
|
WARNING: Currently a flag that is saved and then deleted cannot be restored. An
|
||
|
exception will be raised. However if you *add* a flag after saving flag values,
|
||
|
and then restore flag values, the added flag will be deleted with no errors.
|
||
|
"""
|
||
|
|
||
|
from __future__ import absolute_import
|
||
|
from __future__ import division
|
||
|
from __future__ import print_function
|
||
|
|
||
|
import functools
|
||
|
import inspect
|
||
|
|
||
|
from absl import flags
|
||
|
import six
|
||
|
|
||
|
# Default FlagValues instance used by the save/restore helpers in this module.
FLAGS = flags.FLAGS
|
||
|
|
||
|
|
||
|
def flagsaver(*args, **kwargs):
  """Entry point for saving and restoring flags; see module doc for usage.

  Called with keyword arguments only (or no arguments), it returns a
  _FlagOverrider carrying those overrides. Called with a single positional
  argument and no keywords, it wraps that argument so flags are saved before
  and restored after every call.

  Args:
    *args: At most one positional argument, the function to decorate.
    **kwargs: Flag names mapped to the values they should be set to.

  Returns:
    A _FlagOverrider when no positional argument is given, otherwise the
    wrapped function.

  Raises:
    ValueError: If more than one positional argument is given, or if a
      positional argument is combined with keyword arguments.
    TypeError: If the positional argument is a class.
  """
  positional_count = len(args)
  if positional_count > 1:
    raise ValueError(
        "It's invalid to specify more than one positional parameters.")
  if positional_count == 0:
    return _FlagOverrider(**kwargs)

  # Exactly one positional argument: bare decorator usage.
  if kwargs:
    raise ValueError(
        "It's invalid to specify both positional and keyword parameters.")
  target = args[0]
  if inspect.isclass(target):
    raise TypeError('@flagsaver.flagsaver cannot be applied to a class.')
  return _wrap(target, {})
|
||
|
|
||
|
|
||
|
def save_flag_values(flag_values=FLAGS):
  """Takes a snapshot of every flag's state.

  Args:
    flag_values: FlagValues, the FlagValues instance whose flags are saved.
      This should almost never need to be overridden.

  Returns:
    A dict keyed by flag name. Each value is the copy of that flag's __dict__
    produced by _copy_flag_dict, e.g. {'flag_name': value_dict, ...}.
  """
  snapshot = {}
  for flag_name in flag_values:
    snapshot[flag_name] = _copy_flag_dict(flag_values[flag_name])
  return snapshot
|
||
|
|
||
|
|
||
|
def restore_flag_values(saved_flag_values, flag_values=FLAGS):
  """Restores flag values based on the dictionary of flag values.

  Flags present in flag_values but absent from saved_flag_values are treated
  as added after the snapshot and are deleted. Every other flag gets its
  value and then its whole __dict__ reset from the snapshot.

  The snapshot itself is left untouched: each flag receives a copy of its
  saved dict (and of the saved validator list). Installing the saved dict
  directly would make it the flag's live attribute dict, so later flag changes
  would overwrite the snapshot and a second restore from it would do nothing.

  Args:
    saved_flag_values: {'flag_name': value_dict, ...}, as returned by
      save_flag_values.
    flag_values: FlagValues, the FlagValues instance from which the flag will
      be restored. This should almost never need to be overridden.
  """
  # Materialize the names first because flags may be deleted during the loop.
  for name in list(flag_values):
    saved = saved_flag_values.get(name)
    if saved is None:
      # If __dict__ was not saved delete "new" flag.
      delattr(flag_values, name)
      continue

    flag = flag_values[name]
    if flag.value != saved['_value']:
      flag.value = saved['_value']  # Ensure C++ value is set.

    # Give the flag its own dict so the caller's snapshot stays reusable.
    restored = dict(saved)
    if 'validators' in restored:
      restored['validators'] = list(restored['validators'])
    flag.__dict__ = restored
|
||
|
|
||
|
|
||
|
def _wrap(func, overrides):
|
||
|
"""Creates a wrapper function that saves/restores flag values.
|
||
|
|
||
|
Args:
|
||
|
func: function object - This will be called between saving flags and
|
||
|
restoring flags.
|
||
|
overrides: {str: object} - Flag names mapped to their values. These flags
|
||
|
will be set after saving the original flag state.
|
||
|
|
||
|
Returns:
|
||
|
return value from func()
|
||
|
"""
|
||
|
@functools.wraps(func)
|
||
|
def _flagsaver_wrapper(*args, **kwargs):
|
||
|
"""Wrapper function that saves and restores flags."""
|
||
|
with _FlagOverrider(**overrides):
|
||
|
return func(*args, **kwargs)
|
||
|
return _flagsaver_wrapper
|
||
|
|
||
|
|
||
|
class _FlagOverrider(object):
|
||
|
"""Overrides flags for the duration of the decorated function call.
|
||
|
|
||
|
It also restores all original values of flags after decorated method
|
||
|
completes.
|
||
|
"""
|
||
|
|
||
|
def __init__(self, **overrides):
|
||
|
self._overrides = overrides
|
||
|
self._saved_flag_values = None
|
||
|
|
||
|
def __call__(self, func):
|
||
|
if inspect.isclass(func):
|
||
|
raise TypeError('flagsaver cannot be applied to a class.')
|
||
|
return _wrap(func, self._overrides)
|
||
|
|
||
|
def __enter__(self):
|
||
|
self._saved_flag_values = save_flag_values(FLAGS)
|
||
|
try:
|
||
|
for name, value in six.iteritems(self._overrides):
|
||
|
setattr(FLAGS, name, value)
|
||
|
except:
|
||
|
# It may fail because of flag validators.
|
||
|
restore_flag_values(self._saved_flag_values, FLAGS)
|
||
|
raise
|
||
|
|
||
|
def __exit__(self, exc_type, exc_value, traceback):
|
||
|
restore_flag_values(self._saved_flag_values, FLAGS)
|
||
|
|
||
|
|
||
|
def _copy_flag_dict(flag):
|
||
|
"""Returns a copy of the flag object's __dict__.
|
||
|
|
||
|
It's mostly a shallow copy of the __dict__, except it also does a shallow
|
||
|
copy of the validator list.
|
||
|
|
||
|
Args:
|
||
|
flag: flags.Flag, the flag to copy.
|
||
|
|
||
|
Returns:
|
||
|
A copy of the flag object's __dict__.
|
||
|
"""
|
||
|
copy = flag.__dict__.copy()
|
||
|
copy['_value'] = flag.value # Ensure correct restore for C++ flags.
|
||
|
copy['validators'] = list(flag.validators)
|
||
|
return copy
|