#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""iobase.RangeTracker implementations provided with Apache Beam.
"""
# pytype: skip-file
import codecs
import logging
import math
import threading
from typing import Union
from apache_beam.io import iobase
# Public names exported by this module (used by ``from ... import *``).
__all__ = [
'OffsetRangeTracker',
'LexicographicKeyRangeTracker',
'OrderedPositionRangeTracker',
'UnsplittableRangeTracker'
]
# Module-level logger, named after this module; used for the debug messages
# emitted when split requests are refused or accepted.
_LOGGER = logging.getLogger(__name__)
class OffsetRangeTracker(iobase.RangeTracker):
    """A 'RangeTracker' for non-negative positions of type 'long'.

    Positions are integer offsets. ``try_claim()`` accepts offsets strictly
    below the current stop offset, and ``try_split()`` may lower the stop
    offset while reading is in progress. ``try_claim()``,
    ``set_current_position()``, ``try_split()``, ``fraction_consumed()`` and
    ``split_points()`` run under an internal lock.
    """

    # Offset corresponding to infinity. This can only be used as the
    # upper-bound of a range, and indicates reading all of the records until
    # the end without specifying exactly what the end is.
    # Infinite ranges cannot be split because it is impossible to estimate
    # progress within them.
    OFFSET_INFINITY = float('inf')

    def __init__(self, start, end):
        """Initializes the tracker for the offset range ``[start, end)``.

        Args:
          start: integer start offset.
          end: integer end offset, or ``OFFSET_INFINITY``.

        Raises:
          ValueError: if ``start`` or ``end`` is ``None``.
          AssertionError: if ``start`` is not an ``int``, if a finite ``end``
            is not an ``int``, or if ``start > end``.
        """
        super().__init__()

        if start is None:
            raise ValueError('Start offset must not be \'None\'')
        if end is None:
            raise ValueError('End offset must not be \'None\'')
        assert isinstance(start, int)
        if end != self.OFFSET_INFINITY:
            assert isinstance(end, int)

        assert start <= end

        self._start_offset = start
        self._stop_offset = end

        # -1 is the "nothing seen yet" marker for the three offsets below.
        self._last_record_start = -1
        self._last_attempted_record_start = -1
        self._offset_of_last_split_point = -1
        self._lock = threading.Lock()

        self._split_points_seen = 0
        self._split_points_unclaimed_callback = None

    def start_position(self):
        """Returns the start offset of the range."""
        return self._start_offset

    def stop_position(self):
        """Returns the current stop offset (lowered by successful splits)."""
        return self._stop_offset

    @property
    def last_record_start(self):
        """Offset last stored by a successful ``try_claim()`` or by
        ``set_current_position()``; ``-1`` if neither has happened yet.
        """
        return self._last_record_start

    @property
    def last_attempted_record_start(self):
        """Return current value of last_attempted_record_start.

        last_attempted_record_start records a valid position that tried to be
        claimed by calling try_claim(). This value is only updated by
        `try_claim()` no matter `try_claim()` returns `True` or `False`.
        """
        return self._last_attempted_record_start

    def _validate_record_start(self, record_start, split_point):
        """Checks that ``record_start`` is a legal next record position.

        Args:
          record_start: offset of the record about to be recorded.
          split_point: whether the record is at a split point.

        Raises:
          ValueError: if the lock is not held, if ``record_start`` precedes
            the last recorded record, if a split point repeats the offset of
            the previous split point, or if the very first record is not at a
            split point.
        """
        # This function must only be called under the lock self.lock.
        if not self._lock.locked():
            raise ValueError(
                'This function must only be called under the lock self.lock.')

        if record_start < self._last_record_start:
            raise ValueError(
                'Trying to return a record [starting at %d] which is before '
                'the last-returned record [starting at %d]' %
                (record_start, self._last_record_start))

        if (split_point and self._offset_of_last_split_point != -1 and
                record_start == self._offset_of_last_split_point):
            raise ValueError(
                'Record at a split point has same offset as the previous '
                'split point: %d' % record_start)

        if not split_point and self._last_record_start == -1:
            raise ValueError(
                'The first record [starting at %d] must be at a split point' %
                record_start)

    def try_claim(self, record_start):
        """Attempts to claim the split point starting at ``record_start``.

        Args:
          record_start: offset of the split point being claimed.

        Returns:
          ``True`` if ``record_start`` is below the stop position and the
          claim was recorded, ``False`` otherwise.

        Raises:
          ValueError: if ``record_start`` is not strictly greater than the
            last attempted claim, or fails ``_validate_record_start``.
        """
        with self._lock:
            # Attempted claim should be monotonous.
            if record_start <= self._last_attempted_record_start:
                raise ValueError(
                    'Trying to return a record [starting at %d] which is not '
                    'greater than the last-attempted record [starting at %d]'
                    % (record_start, self._last_attempted_record_start))
            self._validate_record_start(record_start, True)
            # The attempt is remembered even when the claim is rejected below.
            self._last_attempted_record_start = record_start
            if record_start >= self.stop_position():
                return False
            self._offset_of_last_split_point = record_start
            self._last_record_start = record_start
            self._split_points_seen += 1
            return True

    def set_current_position(self, record_start):
        """Records ``record_start`` as the position of a non-split-point record.

        Raises:
          ValueError: if ``record_start`` fails ``_validate_record_start``.
        """
        with self._lock:
            self._validate_record_start(record_start, False)
            self._last_record_start = record_start

    def try_split(self, split_offset):
        """Attempts to lower the stop offset to ``split_offset``.

        The split is refused (returning ``None``) when the range is unbounded,
        when no record has been recorded yet, when ``split_offset`` is not
        past the last recorded record, or when it lies outside
        ``[start, stop)``.

        Args:
          split_offset: proposed new stop offset; must be an ``int``.

        Returns:
          ``(split_offset, fraction)`` on success, where ``fraction`` is the
          position of ``split_offset`` within the range as it was before the
          split; ``None`` if the split is refused.
        """
        assert isinstance(split_offset, int)
        with self._lock:
            if self._stop_offset == OffsetRangeTracker.OFFSET_INFINITY:
                _LOGGER.debug(
                    'refusing to split %r at %d: stop position unspecified',
                    self,
                    split_offset)
                return
            if self._last_record_start == -1:
                _LOGGER.debug(
                    'Refusing to split %r at %d: unstarted',
                    self,
                    split_offset)
                return

            if split_offset <= self._last_record_start:
                _LOGGER.debug(
                    'Refusing to split %r at %d: '
                    'already past proposed stop offset',
                    self,
                    split_offset)
                return
            if (split_offset < self.start_position() or
                    split_offset >= self.stop_position()):
                _LOGGER.debug(
                    'Refusing to split %r at %d: '
                    'proposed split position out of range',
                    self,
                    split_offset)
                return

            _LOGGER.debug('Agreeing to split %r at %d', self, split_offset)

            # The fraction is computed against the stop offset before it is
            # overwritten on the next line.
            split_fraction = (
                float(split_offset - self._start_offset) /
                (self._stop_offset - self._start_offset))
            self._stop_offset = split_offset

            return self._stop_offset, split_fraction

    def fraction_consumed(self):
        """Returns the consumed fraction of the range, clamped to [0, 1]."""
        with self._lock:
            # self.last_record_start may become larger than self.end_offset
            # when reading the records since any record that starts before the
            # first 'split point' at or after the defined 'stop offset' is
            # considered to be within the range of the OffsetRangeTracker.
            # Hence fraction could be > 1. self.last_record_start is
            # initialized to -1, hence fraction may be < 0.
            # Bounding the fraction to the range [0, 1].
            return self.position_to_fraction(
                self._last_record_start,
                self.start_position(),
                self.stop_position())

    def position_to_fraction(self, pos, start, stop):
        """Maps ``pos`` to its relative position within ``[start, stop]``.

        Returns:
          A float clamped to ``[0.0, 1.0]``; ``0.0`` when ``start == stop``.
        """
        fraction = (
            1.0 * (pos - start) / (stop - start) if start != stop else 0.0)
        return max(0.0, min(1.0, fraction))

    def position_at_fraction(self, fraction):
        """Returns the integer offset at ``fraction`` of the range, rounded up.

        Raises:
          Exception: if the stop position is ``OFFSET_INFINITY``.
        """
        if self.stop_position() == OffsetRangeTracker.OFFSET_INFINITY:
            raise Exception(
                'get_position_for_fraction_consumed is not applicable for an '
                'unbounded range')
        return int(
            math.ceil(
                self.start_position() + fraction *
                (self.stop_position() - self.start_position())))

    def split_points(self):
        """Returns ``(split_points_consumed, split_points_remaining)``.

        The consumed count is one less than the number of successfully claimed
        split points (zero if none were claimed). The remaining count is the
        registered callback's result for the current stop position plus one,
        or ``SPLIT_POINTS_UNKNOWN`` when no callback is set or the callback
        itself reports unknown.
        """
        with self._lock:
            split_points_consumed = (
                0 if self._split_points_seen == 0 else
                self._split_points_seen - 1)
            split_points_unclaimed = (
                self._split_points_unclaimed_callback(self.stop_position())
                if self._split_points_unclaimed_callback else
                iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)
            split_points_remaining = (
                iobase.RangeTracker.SPLIT_POINTS_UNKNOWN
                if split_points_unclaimed ==
                iobase.RangeTracker.SPLIT_POINTS_UNKNOWN else
                (split_points_unclaimed + 1))

            return (split_points_consumed, split_points_remaining)

    def set_split_points_unclaimed_callback(self, callback):
        """Registers ``callback``, which ``split_points()`` calls with the
        current stop position to obtain the number of unclaimed split points.
        """
        self._split_points_unclaimed_callback = callback
class OrderedPositionRangeTracker(iobase.RangeTracker):
    """
    An abstract base class for range trackers whose positions are comparable.

    Subclasses only need to implement the mapping from position ranges
    to and from the closed interval [0, 1].
    """

    # Sentinel for "no position has been claimed yet"; compared by identity.
    UNSTARTED = object()

    def __init__(self, start_position=None, stop_position=None):
        """Initializes the tracker.

        Args:
          start_position: first position of the range, or ``None`` for no
            lower bound.
          stop_position: position at which the range ends (claims at or
            beyond it are rejected), or ``None`` for no upper bound.
        """
        self._start_position = start_position
        self._stop_position = stop_position
        self._lock = threading.Lock()
        self._last_claim = self.UNSTARTED

    def start_position(self):
        """Returns the start position given at construction."""
        return self._start_position

    def stop_position(self):
        """Returns the current stop position, read under the lock."""
        with self._lock:
            return self._stop_position

    def try_claim(self, position):
        """Attempts to claim ``position``.

        Returns:
          ``True`` (and records the claim) if there is no stop position or
          ``position`` is below it; ``False`` otherwise.

        Raises:
          ValueError: if ``position`` is before the last claimed position or
            before the start position.
        """
        with self._lock:
            if (self._last_claim is not self.UNSTARTED and
                    position < self._last_claim):
                raise ValueError(
                    "Positions must be claimed in order: "
                    "claim '%s' attempted after claim '%s'" %
                    (position, self._last_claim))
            elif (self._start_position is not None and
                  position < self._start_position):
                raise ValueError(
                    "Claim '%s' is before start '%s'" %
                    (position, self._start_position))
            if self._stop_position is None or position < self._stop_position:
                self._last_claim = position
                return True
            else:
                return False

    def position_at_fraction(self, fraction):
        """Returns the position at ``fraction`` of the current range."""
        return self.fraction_to_position(
            fraction, self._start_position, self._stop_position)

    def try_split(self, position):
        """Attempts to move the stop position down to ``position``.

        Returns:
          ``(position, fraction)`` on success, where ``fraction`` is computed
          by ``position_to_fraction`` against the range before the split;
          ``None`` if ``position`` is outside the open interval
          (start, stop) or is not past the last claimed position.
        """
        with self._lock:
            if ((self._stop_position is not None and
                 position >= self._stop_position) or
                    (self._start_position is not None and
                     position <= self._start_position)):
                # %r rather than %d: positions are arbitrary comparable
                # values (e.g. bytes or str keys), not necessarily integers.
                _LOGGER.debug(
                    'Refusing to split %r at %r: '
                    'proposed split position out of range',
                    self,
                    position)
                return None

            if (self._last_claim is self.UNSTARTED or
                    self._last_claim < position):
                fraction = self.position_to_fraction(
                    position,
                    start=self._start_position,
                    end=self._stop_position)
                self._stop_position = position
                return position, fraction

            # The last claim is already at or past the proposed position.
            return None

    def fraction_consumed(self):
        """Returns the fraction of the range covered by the last claim.

        Returns ``0`` if nothing has been claimed yet.
        """
        if self._last_claim is self.UNSTARTED:
            return 0
        else:
            return self.position_to_fraction(
                self._last_claim, self._start_position, self._stop_position)

    def fraction_to_position(self, fraction, start, end):
        """
        Converts a fraction between 0 and 1 to a position between start and
        end.

        Must be implemented by subclasses.
        """
        raise NotImplementedError

    def position_to_fraction(self, position, start, end):
        """Returns the fraction of keys in the range [start, end) that
        are less than the given key.

        Must be implemented by subclasses.
        """
        raise NotImplementedError
class UnsplittableRangeTracker(iobase.RangeTracker):
    """Wrapper that makes another ``RangeTracker`` refuse every split.

    Wrap an existing :class:`~apache_beam.io.iobase.RangeTracker` in this
    class to make it unsplittable: :meth:`.try_split()` always returns
    ``None`` and :meth:`.split_points()` always reports ``(0, 1)``. Every
    other method is forwarded unchanged to the wrapped tracker.
    """

    def __init__(self, range_tracker):
        """Stores the tracker that all forwarded calls go to.

        Args:
          range_tracker (~apache_beam.io.iobase.RangeTracker): the tracker
            being wrapped; must be an instance of
            :class:`~apache_beam.io.iobase.RangeTracker`.
        """
        assert isinstance(range_tracker, iobase.RangeTracker)
        self._range_tracker = range_tracker

    def start_position(self):
        """Forwards to the wrapped tracker's ``start_position()``."""
        wrapped = self._range_tracker
        return wrapped.start_position()

    def stop_position(self):
        """Forwards to the wrapped tracker's ``stop_position()``."""
        wrapped = self._range_tracker
        return wrapped.stop_position()

    def position_at_fraction(self, fraction):
        """Forwards to the wrapped tracker's ``position_at_fraction()``."""
        wrapped = self._range_tracker
        return wrapped.position_at_fraction(fraction)

    def try_claim(self, position):
        """Forwards to the wrapped tracker's ``try_claim()``."""
        wrapped = self._range_tracker
        return wrapped.try_claim(position)

    def try_split(self, position):
        """Always refuses the split; the wrapped tracker is not consulted."""
        return None

    def set_current_position(self, position):
        """Forwards to the wrapped tracker's ``set_current_position()``."""
        wrapped = self._range_tracker
        wrapped.set_current_position(position)

    def fraction_consumed(self):
        """Forwards to the wrapped tracker's ``fraction_consumed()``."""
        wrapped = self._range_tracker
        return wrapped.fraction_consumed()

    def split_points(self):
        """Returns the fixed pair ``(0, 1)``."""
        # An unsplittable range only contains a single split point.
        consumed, remaining = 0, 1
        return (consumed, remaining)

    def set_split_points_unclaimed_callback(self, callback):
        """Forwards the callback registration to the wrapped tracker."""
        wrapped = self._range_tracker
        wrapped.set_split_points_unclaimed_callback(callback)
class LexicographicKeyRangeTracker(OrderedPositionRangeTracker):
    """A range tracker that tracks progress through a lexicographically
    ordered keyspace of strings.

    Keys are treated as base-256 fractions: ``str`` keys are UTF-8 encoded
    before being converted to integers.
    """

    @classmethod
    def fraction_to_position(
        cls,
        fraction: float,
        start: Union[bytes, str, None] = None,
        end: Union[bytes, str, None] = None,
    ) -> Union[bytes, str, None]:
        """Linearly interpolates a key that is lexicographically
        fraction of the way between start and end.

        Args:
          fraction: value in ``[0, 1]``.
          start: lower key; ``None`` is treated as ``b''``.
          end: upper key; an empty or ``None`` value means no upper bound.

        Returns:
          ``start`` when ``fraction`` is 0, ``end`` when ``fraction`` is 1,
          otherwise an interpolated key: ``bytes`` if ``start`` is ``bytes``,
          else a ``str`` decoded with the ``unicode_escape`` codec.

        Raises:
          AssertionError: if ``fraction`` is outside ``[0, 1]``.
        """
        assert 0 <= fraction <= 1, fraction
        if start is None:
            start = b''
        if fraction == 0:
            return start
        if fraction == 1:
            return end

        if not end:
            # With no end, the upper bound used below is 1 << (prec * 8), so
            # only the leading 0xFF bytes of start are shared with it. Count
            # them on the encoded bytes (the same encoding that
            # _bytestring_to_int applies); calling lstrip(b'\xFF') directly
            # on a str start would raise TypeError.
            start_bytes = start.encode() if isinstance(start, str) else start
            common_prefix_len = (
                len(start_bytes) - len(start_bytes.lstrip(b'\xFF')))
        else:
            for ix, (s, e) in enumerate(zip(start, end)):
                if s != e:
                    common_prefix_len = ix
                    break
            else:
                # No differing element: one key is a prefix of the other.
                common_prefix_len = min(len(start), len(end))

        # Convert the relative precision of fraction (~53 bits) to an absolute
        # precision needed to represent values between start and end
        # distinctly.
        prec = common_prefix_len + int(-math.log(fraction, 256)) + 7
        istart = cls._bytestring_to_int(start, prec)
        iend = cls._bytestring_to_int(end, prec) if end else 1 << (prec * 8)
        ikey = istart + int((iend - istart) * fraction)

        # Could be equal due to rounding.
        # Adjust to ensure we never return the actual start and end
        # unless fraction is exactly 0 or 1.
        if ikey == istart:
            ikey += 1
        elif ikey == iend:
            ikey -= 1

        position: bytes = cls._bytestring_from_int(ikey, prec).rstrip(b'\0')
        if isinstance(start, bytes):
            return position
        return position.decode(encoding='unicode_escape', errors='replace')

    @classmethod
    def position_to_fraction(
        cls,
        key: Union[bytes, str, None] = None,
        start: Union[bytes, str, None] = None,
        end: Union[bytes, str, None] = None,
    ) -> float:
        """Returns the fraction of keys in the range [start, end) that
        are less than the given key.

        Args:
          key: the key to locate; an empty or ``None`` key yields ``0``.
          start: lower key; ``None`` is treated as an empty key of the same
            type as ``key``.
          end: upper key; an empty or ``None`` value means no upper bound.
        """
        if not key:
            return 0
        if start is None:
            start = '' if isinstance(key, str) else b''

        prec = len(start) + 7
        if key.startswith(start):
            # Higher absolute precision needed for very small values of fixed
            # relative position.
            trailing_symbol = '\0' if isinstance(key, str) else b'\0'
            prec = max(
                prec,
                len(key) - len(key[len(start):].strip(trailing_symbol)) + 7)
        istart = cls._bytestring_to_int(start, prec)
        ikey = cls._bytestring_to_int(key, prec)
        iend = cls._bytestring_to_int(end, prec) if end else 1 << (prec * 8)
        return float(ikey - istart) / (iend - istart)

    @staticmethod
    def _bytestring_to_int(s: Union[bytes, str, None], prec: int) -> int:
        """Returns int(256**prec * f) where f is the fraction
        represented by interpreting '.' + s as a base-256
        floating point number.

        An empty or ``None`` ``s`` maps to 0. ``str`` input is encoded with
        ``str.encode()`` first; the bytes are then zero-padded or truncated
        to exactly ``prec`` bytes.
        """
        if not s:
            return 0
        if isinstance(s, str):
            s = s.encode()  # str -> bytes
        if len(s) < prec:
            s += b'\0' * (prec - len(s))
        else:
            s = s[:prec]
        h = codecs.encode(s, encoding='hex')
        return int(h, base=16)

    @staticmethod
    def _bytestring_from_int(i: int, prec: int) -> bytes:
        """Inverse of _bytestring_to_int.

        Formats ``i`` as hex, left-pads it with zeros to ``2 * prec`` digits
        and decodes that into ``prec`` bytes.
        """
        h = '%x' % i
        return codecs.decode('0' * (2 * prec - len(h)) + h, encoding='hex')