526 lines
19 KiB
Python
526 lines
19 KiB
Python
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
|
||
|
# TODO(shivaniagrawal): Merge with core nest
|
||
|
"""## Functions for working with arbitrarily nested sequences of elements.
|
||
|
|
||
|
NOTE(mrry): This fork of the `tensorflow.python.util.nest` module
|
||
|
makes two changes:
|
||
|
|
||
|
1. It removes support for lists as a level of nesting in nested structures.
|
||
|
2. It adds support for `SparseTensorValue` as an atomic element.
|
||
|
|
||
|
The motivation for this change is twofold:
|
||
|
|
||
|
1. It seems more natural for lists to be treated (e.g. in Dataset constructors)
|
||
|
as tensors, rather than lists of (lists of...) tensors.
|
||
|
2. This is needed because `SparseTensorValue` is implemented as a `namedtuple`
|
||
|
that would normally be flattened and we want to be able to create sparse
|
||
|
tensors from `SparseTensorValue`s similarly to creating tensors from numpy
|
||
|
arrays.
|
||
|
"""
|
||
|
|
||
|
from __future__ import absolute_import
|
||
|
from __future__ import division
|
||
|
from __future__ import print_function
|
||
|
|
||
|
import collections as _collections
|
||
|
|
||
|
import six as _six
|
||
|
|
||
|
from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
|
||
|
from tensorflow.python.framework import sparse_tensor as _sparse_tensor
|
||
|
|
||
|
|
||
|
def _sorted(dict_):
  """Returns the keys of `dict_` as a sorted list.

  Args:
    dict_: the dictionary whose keys should be ordered.

  Returns:
    A new list holding the keys of `dict_` in ascending order.

  Raises:
    TypeError: if the keys cannot be compared with each other.
  """
  try:
    keys = list(_six.iterkeys(dict_))
    keys.sort()
  except TypeError:
    raise TypeError("nest only supports dicts with sortable keys.")
  return keys
|
||
|
|
||
|
|
||
|
def _sequence_like(instance, args):
  """Converts the sequence `args` to the same type as `instance`.

  Args:
    instance: an instance of `dict`, `tuple`, or a `namedtuple` class.
    args: elements to be converted to a sequence. For a `dict` instance the
      elements are matched to the keys of `instance` in sorted-key order.

  Returns:
    `args` with the type of `instance`.
  """
  if isinstance(instance, dict):
    # Values are assigned to keys in a deterministic order by sorting the keys.
    # Notice this means that the original order of `OrderedDict` instances is
    # ignored when pairing keys with values. This is intentional, to avoid
    # potential bugs caused by mixing ordered and plain dicts (e.g., flattening
    # a dict but using a corresponding `OrderedDict` to pack it back). The
    # returned mapping is still populated in the iteration order of `instance`.
    result = dict(zip(_sorted(instance), args))
    return type(instance)((key, result[key]) for key in _six.iterkeys(instance))

  if isinstance(instance, tuple) and hasattr(instance, "_fields"):
    # The `Sequence` ABC lives in `collections.abc` on Python 3; the alias on
    # `collections` itself was removed in Python 3.10. Fall back to the
    # `collections` module only where `collections.abc` does not exist
    # (Python 2).
    try:
      from collections import abc as _collections_abc  # pylint: disable=g-import-not-at-top
    except ImportError:
      _collections_abc = _collections
    if (isinstance(instance._fields, _collections_abc.Sequence) and
        all(isinstance(f, _six.string_types) for f in instance._fields)):
      # This is a namedtuple: its constructor takes one argument per field.
      return type(instance)(*args)

  # Not a namedtuple: the constructor takes a single iterable.
  return type(instance)(args)
|
||
|
|
||
|
|
||
|
def _yield_value(iterable):
  """Yields the immediate children of `iterable`.

  Args:
    iterable: a dict, a `SparseTensorValue`, or any other iterable.

  Yields:
    For a dict, its values in sorted-key order; for a `SparseTensorValue`, the
    value itself as a single item; otherwise each item of `iterable` in
    iteration order.
  """
  if isinstance(iterable, dict):
    # Walk dictionaries in a deterministic order by sorting the keys. This
    # deliberately ignores the original order of `OrderedDict` instances, to
    # avoid potential bugs caused by mixing ordered and plain dicts (e.g.,
    # flattening a dict but using a corresponding `OrderedDict` to pack it
    # back).
    children = (iterable[key] for key in _sorted(iterable))
  elif isinstance(iterable, _sparse_tensor.SparseTensorValue):
    # A `SparseTensorValue` is treated as one atomic element.
    children = (iterable,)
  else:
    children = iterable

  for child in children:
    yield child
|
||
|
|
||
|
|
||
|
def is_sequence(seq):
  """Returns true if `seq` is a Sequence or dict (except strings/lists).

  The check is delegated to the native `IsSequenceForData` helper.

  NOTE(mrry): This differs from `tensorflow.python.util.nest.is_sequence()`,
  which *does* treat a Python list as a sequence. For ergonomic
  reasons, `tf.data` users would prefer to treat lists as
  implicit `tf.Tensor` objects, and dicts as (nested) sequences.

  Args:
    seq: the object to test.

  Returns:
    True if `seq` is a `collections.Sequence` or a dict and is neither a
    string nor a list.
  """
  is_nested = _pywrap_tensorflow.IsSequenceForData(seq)
  return is_nested
|
||
|
|
||
|
|
||
|
def flatten(nest):
  """Returns a flat list of the leaves of a given nested structure.

  The work is delegated to the native `FlattenForData` helper. If `nest` is
  not a sequence, this returns a single-element list: `[nest]`.

  Args:
    nest: an arbitrarily nested structure or a scalar object.
      Note, numpy arrays are considered scalars.

  Returns:
    A Python list, the flattened version of the input.
  """
  flat_list = _pywrap_tensorflow.FlattenForData(nest)
  return flat_list
|
||
|
|
||
|
|
||
|
def _recursive_assert_same_structure(nest1, nest2, check_types):
  """Recursively checks that `nest1` and `nest2` are nested identically.

  Args:
    nest1: an arbitrarily nested structure.
    nest2: an arbitrarily nested structure.
    check_types: if true, the types of corresponding sequences must be equal.

  Raises:
    ValueError: if exactly one of the two is a sequence, or if two
      corresponding sequences have a different number of children.
    TypeError: if `check_types` is true and two corresponding sequences have
      different types.
  """
  is_sequence_nest1 = is_sequence(nest1)
  if is_sequence_nest1 != is_sequence(nest2):
    raise ValueError(
        "The two structures don't have the same nested structure. "
        "First structure: %s, second structure: %s." % (nest1, nest2))

  if not is_sequence_nest1:
    # Two leaves always match.
    return

  type_nest1 = type(nest1)
  type_nest2 = type(nest2)
  if check_types and type_nest1 != type_nest2:
    raise TypeError(
        "The two structures don't have the same sequence type. First "
        "structure has type %s, while second structure has type %s."
        % (type_nest1, type_nest2))

  children1 = list(_yield_value(nest1))
  children2 = list(_yield_value(nest2))
  # `zip` stops at the shorter input, so the lengths must be compared
  # explicitly. Otherwise structures such as `((1,), 2, 3)` and `((1, 2), 3)`,
  # which have the same total number of leaves, would be accepted.
  if len(children1) != len(children2):
    raise ValueError(
        "The two structures don't have the same nested structure. "
        "First structure: %s, second structure: %s." % (nest1, nest2))

  for n1, n2 in zip(children1, children2):
    _recursive_assert_same_structure(n1, n2, check_types)
|
||
|
|
||
|
|
||
|
def assert_same_structure(nest1, nest2, check_types=True):
  """Asserts that two structures are nested in the same way.

  Args:
    nest1: an arbitrarily nested structure.
    nest2: an arbitrarily nested structure.
    check_types: if `True` (default) the types of corresponding sequences are
      compared as well. If set to `False`, sequences of different types are
      not rejected on account of their type.

  Raises:
    ValueError: If the two structures do not have the same number of elements or
      if the two structures are not nested in the same way.
    TypeError: If the two structures differ in the type of sequence in any of
      their substructures. Only possible if `check_types` is `True`.
  """
  # A non-sequence counts as a single element.
  leaf_counts = [
      len(flatten(nest)) if is_sequence(nest) else 1
      for nest in (nest1, nest2)
  ]
  if leaf_counts[0] != leaf_counts[1]:
    raise ValueError("The two structures don't have the same number of "
                     "elements. First structure: %s, second structure: %s."
                     % (nest1, nest2))
  _recursive_assert_same_structure(nest1, nest2, check_types)
|
||
|
|
||
|
|
||
|
def _packed_nest_with_indices(structure, flat, index):
  """Helper function for `pack_sequence_as`.

  Args:
    structure: Substructure (tuple of elements and/or tuples) to mimic.
    flat: Flattened values to output substructure for.
    index: Index at which to start reading from `flat`.

  Returns:
    The tuple (new_index, packed), where:
      * new_index - the updated index into `flat` having processed `structure`.
      * packed - a list with one entry per direct child of `structure`: leaves
        are taken from `flat` starting at `index`, and nested children are
        rebuilt recursively into the same nested format.

  Note:
    No bounds check is done here; `flat[index]` fails if `flat` holds fewer
    values than `structure` has leaves. `pack_sequence_as` compares the
    element counts before calling this helper.
  """
  packed = []
  for element in _yield_value(structure):
    if not is_sequence(element):
      # Leaf: consume the next flat value.
      packed.append(flat[index])
      index += 1
      continue
    # Nested child: rebuild it recursively, then resume after what it consumed.
    index, children = _packed_nest_with_indices(element, flat, index)
    packed.append(_sequence_like(element, children))
  return index, packed
|
||
|
|
||
|
|
||
|
def pack_sequence_as(structure, flat_sequence):
  """Returns a given flattened sequence packed into a nest.

  If `structure` is a scalar, `flat_sequence` must be a single-element list;
  in this case the return value is `flat_sequence[0]`.

  Args:
    structure: a nested structure (as recognized by `is_sequence`) built from
      scalars and/or other nested structures, or a scalar.
      Note: numpy arrays are considered scalars.
    flat_sequence: flat sequence to pack.

  Returns:
    packed: `flat_sequence` converted to have the same recursive structure as
      `structure`.

  Raises:
    TypeError: If `flat_sequence` is neither a list nor a sequence.
    ValueError: If `flat_sequence` and `structure` have different element
      counts.
  """
  if not is_sequence(flat_sequence) and not isinstance(flat_sequence, list):
    raise TypeError("flat_sequence must be a sequence")

  if is_sequence(structure):
    expected_count = len(flatten(structure))
    actual_count = len(flat_sequence)
    if expected_count != actual_count:
      raise ValueError(
          "Could not pack sequence. Structure had %d elements, but flat_sequence "
          "had %d elements. Structure: %s, flat_sequence: %s."
          % (expected_count, actual_count, structure, flat_sequence))
    _, children = _packed_nest_with_indices(structure, flat_sequence, 0)
    return _sequence_like(structure, children)

  # Scalar structure: exactly one value is expected.
  actual_count = len(flat_sequence)
  if actual_count != 1:
    raise ValueError("Structure is a scalar but len(flat_sequence) == %d > 1"
                     % actual_count)
  return flat_sequence[0]
|
||
|
|
||
|
|
||
|
def map_structure(func, *structure, **check_types_dict):
  """Applies `func` to each entry in `structure` and returns a new structure.

  Applies `func(x[0], x[1], ...)` where x[i] is an entry in
  `structure[i]`. All structures in `structure` must have the same arity,
  and the return value will contain the results in the same structure.

  Args:
    func: A callable that accepts as many arguments as there are structures.
    *structure: scalar, or nested structure (as recognized by `is_sequence`)
      built from scalars and/or other nested structures. Note: numpy arrays
      are considered scalars.
    **check_types_dict: only valid keyword argument is `check_types`. If set to
      `True` (default) the types of sequences within the structures have to be
      the same (e.g. `map_structure(func, (1,), {"a": 1})` raises a `TypeError`
      exception). To allow this set this argument to `False`.

  Returns:
    A new structure with the same arity as `structure`, whose values correspond
    to `func(x[0], x[1], ...)` where `x[i]` is a value in the corresponding
    location in `structure[i]`. If there are different sequence types and
    `check_types` is `False` the sequence types of the first structure will be
    used.

  Raises:
    TypeError: If `func` is not callable or if the structures do not match
      each other by sequence type (only when `check_types` is `True`).
    ValueError: If no structure is provided or if the structures do not match
      each other by nesting.
    ValueError: If wrong keyword arguments are provided.
  """
  if not callable(func):
    # Wrap `func` in a 1-tuple: a bare tuple operand would be unpacked by `%`
    # and produce a misleading string-formatting error instead.
    raise TypeError("func must be callable, got: %s" % (func,))

  if not structure:
    raise ValueError("Must provide at least one structure")

  if check_types_dict:
    if "check_types" not in check_types_dict or len(check_types_dict) > 1:
      raise ValueError("Only valid keyword argument is check_types")
    check_types = check_types_dict["check_types"]
  else:
    check_types = True

  # Every structure must be nested like the first one.
  for other in structure[1:]:
    assert_same_structure(structure[0], other, check_types=check_types)

  flat_structure = [flatten(s) for s in structure]
  entries = zip(*flat_structure)

  return pack_sequence_as(
      structure[0], [func(*x) for x in entries])
|
||
|
|
||
|
|
||
|
def _yield_flat_up_to(shallow_tree, input_tree):
  """Yields elements of `input_tree` partially flattened up to `shallow_tree`."""
  if not is_sequence(shallow_tree):
    # A leaf of the shallow tree: emit the matching input subtree unchanged.
    yield input_tree
    return

  paired_branches = zip(_yield_value(shallow_tree), _yield_value(input_tree))
  for shallow_child, input_child in paired_branches:
    for leaf in _yield_flat_up_to(shallow_child, input_child):
      yield leaf
|
||
|
|
||
|
|
||
|
def assert_shallow_structure(shallow_tree, input_tree, check_types=True):
  """Asserts that `shallow_tree` is a shallow structure of `input_tree`.

  That is, this function tests if the `input_tree` structure can be created from
  the `shallow_tree` structure by replacing its leaf nodes with deeper
  tree structures.

  Examples (tuples are used because this module treats lists as leaves):

  The following code will raise an exception:
  ```python
    shallow_tree = ("a", "b")
    input_tree = ("c", ("d", "e"), "f")
    assert_shallow_structure(shallow_tree, input_tree)
  ```

  The following code will not raise an exception:
  ```python
    shallow_tree = ("a", "b")
    input_tree = ("c", ("d", "e"))
    assert_shallow_structure(shallow_tree, input_tree)
  ```

  Args:
    shallow_tree: an arbitrarily nested structure.
    input_tree: an arbitrarily nested structure.
    check_types: if `True` (default) the sequence types of `shallow_tree` and
      `input_tree` have to be the same.

  Raises:
    TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
    TypeError: If the sequence types of `shallow_tree` are different from
      `input_tree`. Only raised if `check_types` is `True`.
    ValueError: If the sequence lengths of `shallow_tree` are different from
      `input_tree`.
    ValueError: If `shallow_tree` is a dict whose keys differ from those of
      `input_tree`. Only raised if `check_types` is `True`.
  """
  if not is_sequence(shallow_tree):
    # A leaf of the shallow tree accepts any input subtree.
    return

  if not is_sequence(input_tree):
    raise TypeError(
        "If shallow structure is a sequence, input must also be a sequence. "
        "Input has type: %s." % type(input_tree))

  if check_types and not isinstance(input_tree, type(shallow_tree)):
    raise TypeError(
        "The two structures don't have the same sequence type. Input "
        "structure has type %s, while shallow structure has type %s."
        % (type(input_tree), type(shallow_tree)))

  if len(input_tree) != len(shallow_tree):
    raise ValueError(
        "The two structures don't have the same sequence length. Input "
        "structure has length %s, while shallow structure has length %s."
        % (len(input_tree), len(shallow_tree)))

  if check_types and isinstance(shallow_tree, dict):
    if set(input_tree) != set(shallow_tree):
      raise ValueError(
          "The two structures don't have the same keys. Input "
          "structure has keys %s, while shallow structure has keys %s." %
          (list(_six.iterkeys(input_tree)),
           list(_six.iterkeys(shallow_tree))))

  # Pair children through `_yield_value` so that dicts contribute their values
  # in sorted-key order, exactly as `_yield_flat_up_to` traverses them.
  # Iterating the dicts directly would pair their keys, leaving the values
  # unchecked when `check_types` is `False`.
  for shallow_branch, input_branch in zip(_yield_value(shallow_tree),
                                          _yield_value(input_tree)):
    assert_shallow_structure(shallow_branch, input_branch,
                             check_types=check_types)
|
||
|
|
||
|
|
||
|
def flatten_up_to(shallow_tree, input_tree):
  """Flattens `input_tree` up to `shallow_tree`.

  Any further depth in structure in `input_tree` is retained as elements in the
  partially flattened output.

  If `shallow_tree` and `input_tree` are not sequences, this returns a
  single-element list: `[input_tree]`.

  Use Case:

  Sometimes we may wish to partially flatten a nested sequence, retaining some
  of the nested structure. We achieve this by specifying a shallow structure,
  `shallow_tree`, we wish to flatten up to.

  The input, `input_tree`, can be thought of as having the same structure as
  `shallow_tree`, but with leaf nodes that are themselves tree structures.

  Examples (tuples are used because this module treats lists as leaves):

  ```python
  input_tree = (((2, 2), (3, 3)), ((4, 9), (5, 5)))
  shallow_tree = ((True, True), (False, True))

  flattened_input_tree = flatten_up_to(shallow_tree, input_tree)
  flattened_shallow_tree = flatten_up_to(shallow_tree, shallow_tree)

  # Output is:
  # [(2, 2), (3, 3), (4, 9), (5, 5)]
  # [True, True, False, True]
  ```

  ```python
  input_tree = ((("a", 1), (("b", 2), (("c", 3), (("d", 4),)))),)
  shallow_tree = (("level_1", ("level_2", ("level_3", ("level_4",)))),)

  input_tree_flattened_as_shallow_tree = flatten_up_to(shallow_tree, input_tree)
  input_tree_flattened = flatten(input_tree)

  # Output is:
  # [("a", 1), ("b", 2), ("c", 3), ("d", 4)]
  # ["a", 1, "b", 2, "c", 3, "d", 4]
  ```

  Non-Sequence Edge Cases:

  ```python
  flatten_up_to(0, 0)  # Output: [0]
  flatten_up_to(0, (0, 1, 2))  # Output: [(0, 1, 2)]
  flatten_up_to((0, 1, 2), 0)  # Output: TypeError
  flatten_up_to((0, 1, 2), (0, 1, 2))  # Output: [0, 1, 2]
  ```

  Args:
    shallow_tree: a possibly pruned structure of input_tree.
    input_tree: an arbitrarily nested structure or a scalar object.
      Note, numpy arrays are considered scalars.

  Returns:
    A Python list, the partially flattened version of `input_tree` according to
    the structure of `shallow_tree`.

  Raises:
    TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
    TypeError: If the sequence types of `shallow_tree` are different from
      `input_tree`.
    ValueError: If the sequence lengths of `shallow_tree` are different from
      `input_tree`.
  """
  # Validate first so that incompatible trees fail before any flattening.
  assert_shallow_structure(shallow_tree, input_tree)
  return [leaf for leaf in _yield_flat_up_to(shallow_tree, input_tree)]
|
||
|
|
||
|
|
||
|
def map_structure_up_to(shallow_tree, func, *inputs):
  """Applies a function or op to a number of partially flattened inputs.

  The `inputs` are flattened up to `shallow_tree` before being mapped.

  Use Case:

  Sometimes we wish to apply a function to a partially flattened
  sequence (for example when the function itself takes sequence inputs). We
  achieve this by specifying a shallow structure, `shallow_tree` we wish to
  flatten up to.

  The `inputs`, can be thought of as having the same structure as
  `shallow_tree`, but with leaf nodes that are themselves tree structures.

  This function, therefore, will return something with the same base structure
  as `shallow_tree`.

  Examples (tuples are used because this module treats lists as leaves):

  ```python
  ab_tuple = collections.namedtuple("ab_tuple", "a, b")
  op_tuple = collections.namedtuple("op_tuple", "add, mul")
  inp_val = ab_tuple(a=2, b=3)
  inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))
  out = map_structure_up_to(inp_val, lambda val, ops: (val + ops.add) * ops.mul,
                            inp_val, inp_ops)

  # Output is: ab_tuple(a=6, b=15)
  ```

  ```python
  data_tuple = ((2, 4, 6, 8), ((1, 3, 5, 7, 9), (3, 5, 7)))
  name_tuple = ("evens", ("odds", "primes"))
  out = map_structure_up_to(
      name_tuple,
      lambda name, sec: "first_{}_{}".format(len(sec), name),
      name_tuple, data_tuple)

  # Output is: ("first_4_evens", ("first_5_odds", "first_3_primes"))
  ```

  Args:
    shallow_tree: a shallow tree, common to all the inputs.
    func: callable which will be applied to each input individually.
    *inputs: arbitrarily nested combination of objects that are compatible with
      shallow_tree. The function `func` is applied to corresponding
      partially flattened elements of each input, so the function must support
      arity of `len(inputs)`.

  Raises:
    TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
    TypeError: If the sequence types of `shallow_tree` are different from
      `input_tree`.
    ValueError: If the sequence lengths of `shallow_tree` are different from
      `input_tree`.
    ValueError: If no inputs are given.

  Returns:
    result of repeatedly applying `func`, with same structure as
    `shallow_tree`.
  """
  if not inputs:
    raise ValueError("Cannot map over no sequences")

  # Flatten each input separately, apply the function to corresponding elements,
  # then repack based on the structure of `shallow_tree`. `flatten_up_to`
  # validates each input against `shallow_tree` before flattening it, so no
  # separate validation pass is needed.
  all_flattened_up_to = [flatten_up_to(shallow_tree, input_tree)
                         for input_tree in inputs]

  results = [func(*tensors) for tensors in zip(*all_flattened_up_to)]
  return pack_sequence_as(structure=shallow_tree, flat_sequence=results)
|