#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pytype: skip-file
import collections
import itertools
import typing
import apache_beam as beam
from apache_beam import typehints
from apache_beam.internal.util import ArgumentPlaceholder
from apache_beam.transforms.combiners import _CurriedFn
from apache_beam.utils.windowed_value import WindowedValue
class LiftedCombinePerKey(beam.PTransform):
  """CombinePerKey replacement that pre-combines values before the shuffle.

  The expansion partially combines values per (key, window) inside each
  bundle, groups the resulting accumulators by key, and finally merges them
  and extracts the combined output.

  Args:
    combine_fn: the CombineFn to apply; may already be a ``_CurriedFn``.
    args: extra positional arguments to bind to ``combine_fn``.
    kwargs: extra keyword arguments to bind to ``combine_fn``.

  Raises:
    NotImplementedError: if any bound argument (including ones already
      curried into ``combine_fn``) is an ``ArgumentPlaceholder``.
  """
  def __init__(self, combine_fn, args, kwargs):
    sources = [args, kwargs.values()]
    if isinstance(combine_fn, _CurriedFn):
      # Arguments that were curried into the fn earlier must be inspected too.
      sources.extend([combine_fn.args, combine_fn.kwargs.values()])
    for arg in itertools.chain.from_iterable(sources):
      if isinstance(arg, ArgumentPlaceholder):
        # Deferred CombineFn side inputs are unsupported here; this isn't
        # implemented in Dataflow either.
        raise NotImplementedError('Deferred CombineFn side inputs.')
    self._combine_fn = beam.transforms.combiners.curry_combine_fn(
        combine_fn, args, kwargs)

  def expand(self, pcoll):
    """Expands to: partial combine -> GroupByKey -> final merge."""
    fn = self._combine_fn
    partial = pcoll | beam.ParDo(PartialGroupByKeyCombiningValues(fn))
    grouped = partial | beam.GroupByKey()
    return grouped | beam.ParDo(FinishCombine(fn))
class PartialGroupByKeyCombiningValues(beam.DoFn):
  """Aggregates values into a per-key-window cache.

  As bundles are in-memory-sized, we don't bother flushing until the very end:
  every ``(key, window)`` seen in a bundle gets one accumulator of the wrapped
  CombineFn, and all accumulators are compacted and emitted in
  ``finish_bundle``, timestamped at the end of their window.

  Args:
    combine_fn: the CombineFn whose accumulators are built per bundle.
  """
  def __init__(self, combine_fn):
    self._combine_fn = combine_fn

  def setup(self):
    self._combine_fn.setup()

  def start_bundle(self):
    # A fresh accumulator is created lazily for each unseen (key, window).
    self._cache = collections.defaultdict(self._combine_fn.create_accumulator)

  def process(self, element, window=beam.DoFn.WindowParam):
    """Folds the value of a ``(key, value)`` element into its accumulator."""
    k, vi = element
    self._cache[k, window] = self._combine_fn.add_input(
        self._cache[k, window], vi)

  def finish_bundle(self):
    """Emits one ``(key, compacted accumulator)`` per cached key/window."""
    for (k, w), va in self._cache.items():
      # We compact the accumulator since a GBK (which necessitates encoding)
      # will follow.
      yield WindowedValue((k, self._combine_fn.compact(va)), w.end, (w, ))

  def teardown(self):
    self._combine_fn.teardown()

  def default_type_hints(self):
    """Lifts the CombineFn's input hints to keyed inputs.

    The CombineFn's first positional input hint (or ``Any`` when none is
    declared) becomes ``Tuple[K, <hint>]``; any remaining positional and
    keyword input hints are preserved. Outputs are typed ``Tuple[K, Any]``
    since the emitted values are accumulators.
    """
    hints = self._combine_fn.get_type_hints()
    K = typehints.TypeVariable('K')
    # input_types, when set, is an (args, kwargs) pair. The positional part
    # may be empty even though the pair itself is truthy, so never index it
    # without checking.
    args, kwargs = hints.input_types or ((), {})
    element_type = args[0] if args else typing.Any
    hints = hints.with_input_types(
        typehints.Tuple[K, element_type], *args[1:], **kwargs)
    return hints.with_output_types(typehints.Tuple[K, typing.Any])
class FinishCombine(beam.DoFn):
  """Completes a lifted combine for each key.

  Consumes ``(key, accumulators)`` pairs produced by grouping partially
  combined values, merges the accumulators and emits ``(key, output)``.

  Args:
    combine_fn: the CombineFn used to merge accumulators and extract output.
  """
  def __init__(self, combine_fn):
    self._combine_fn = combine_fn

  def setup(self):
    self._combine_fn.setup()

  def process(self, element):
    key, accumulators = element
    merged = self._combine_fn.merge_accumulators(accumulators)
    return [(key, self._combine_fn.extract_output(merged))]

  def teardown(self):
    self._combine_fn.teardown()

  def default_type_hints(self):
    """Keys the CombineFn's hints: input ``Tuple[K, Any]``; output
    ``Tuple[K, <main output>]`` only when the CombineFn declares outputs."""
    hints = self._combine_fn.get_type_hints()
    key_type = typehints.TypeVariable('K')
    hints = hints.with_input_types(typehints.Tuple[key_type, typing.Any])
    if not hints.output_types:
      return hints
    output_type = hints.simple_output_type('')
    return hints.with_output_types(typehints.Tuple[key_type, output_type])