# Source code for kgcnn.ops.axis

def broadcast_shapes(shape1, shape2):
    """Compute the shape obtained by broadcasting two static shapes together.

    Shapes are aligned from the right; the shorter one is left-padded with 1s.
    For each aligned dimension:

    * a 1 on the first side takes the second side's value;
    * an unknown (`None`) first dimension takes the second side's value,
      unless that value is 1, in which case it stays `None`;
    * a concrete first dimension is kept when the second side is 1, `None`,
      or equal to it; any other combination cannot be broadcast.

    Args:
        shape1: A tuple or list of integers (entries may be `None`).
        shape2: A tuple or list of integers (entries may be `None`).

    Returns:
        list: The broadcasted shape, with integer or `None` entries.

    Raises:
        ValueError: If two concrete dimensions differ and neither is 1.

    Example:
        >>> broadcast_shapes((5, 3), (1, 3))
        [5, 3]
    """
    left = list(shape1)
    right = list(shape2)
    # Remember the un-padded list forms so the error message shows the inputs.
    original_left, original_right = left, right

    rank = max(len(left), len(right))
    left = [1] * (rank - len(left)) + left
    right = [1] * (rank - len(right)) + right

    merged = []
    for dim_a, dim_b in zip(left, right):
        if dim_a == 1:
            merged.append(dim_b)
        elif dim_a is None:
            merged.append(None if dim_b == 1 else dim_b)
        elif dim_b == 1 or dim_b is None or dim_b == dim_a:
            merged.append(dim_a)
        else:
            raise ValueError(
                "Cannot broadcast shape, the failure dim has value "
                f"{dim_a}, which cannot be broadcasted to {dim_b}. "
                f"Input shapes are: {original_left} and {original_right}."
            )
    return merged
# Originally found in `ops/array_ops`; copied here for static reference.
def get_positive_axis(axis, ndims, axis_name="axis", ndims_name="ndims"):
    """Validate an `axis` argument and normalize it to a non-negative index.

    When `ndims` is known (not `None`), `axis` must satisfy
    `-ndims <= axis < ndims`; a negative `axis` is shifted by `ndims`, a
    non-negative one is returned unchanged. When `ndims` is `None`, a
    non-negative `axis` is returned as-is and a negative one is rejected,
    because it cannot be resolved without the rank.

    Args:
        axis: An integer constant.
        ndims: An integer constant, or `None` if the rank is unknown.
        axis_name: Name used for `axis` in error messages.
        ndims_name: Name used for `ndims` in error messages.

    Returns:
        The normalized (non-negative) `axis` value.

    Raises:
        TypeError: If `axis` is not an `int`.
        ValueError: If `axis` is out of bounds for a known `ndims`, or if
            `axis` is negative while `ndims` is `None`.
    """
    if not isinstance(axis, int):
        raise TypeError("%s must be an int; got %s" % (axis_name, type(axis).__name__))

    if ndims is None:
        # Without a static rank, a negative axis has nothing to be offset by.
        if axis < 0:
            raise ValueError(
                "%s may only be negative if %s is statically known." % (axis_name, ndims_name))
        return axis

    if 0 <= axis < ndims:
        return axis
    if -ndims <= axis < 0:
        return axis + ndims
    raise ValueError(
        "%s=%s out of bounds: expected %s<=%s<%s" % (axis_name, axis, -ndims, axis_name, ndims))