#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import collections
import itertools
import sys
import apache_beam as beam
from apache_beam import coders
from apache_beam.portability.api import beam_interactive_api_pb2
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners.interactive.cache_manager import CacheManager
from apache_beam.utils.timestamp import Duration
from apache_beam.utils.timestamp import Timestamp
class InMemoryCache(CacheManager):
  """A cache that stores all PCollections in an in-memory map.

  This is only used for checking the pipeline shape. It can't be used for
  running the pipeline because the cache isn't shared between the SDK and the
  Runner.

  Entries are addressed by a key built by joining the given labels with '/'.
  """
  def __init__(self):
    # Maps a '/'-joined label key to the list of cached values.
    self._cached = {}
    # Maps a '/'-joined label key to the coder saved for that entry.
    self._pcoders = {}

  def exists(self, *labels):
    """Returns True if an entry has been stored under the given labels."""
    return self._key(*labels) in self._cached

  def _latest_version(self, *labels):
    """Always returns True; this cache does not track versions."""
    return True

  def read(self, *labels, **args):
    """Returns an (iterator, marker) pair for the entry under the labels.

    For a missing entry the pair is an empty iterator and -1; for an existing
    entry it is an iterator over the cached values and None. Keyword arguments
    are accepted but ignored.
    """
    key = self._key(*labels)
    if key not in self._cached:
      return itertools.chain([]), -1
    return itertools.chain(self._cached[key]), None

  def write(self, value, *labels):
    """Appends every item of the iterable `value` to the entry under labels.

    The entry is created as an empty list first if it does not exist yet.
    """
    self._cached.setdefault(self._key(*labels), []).extend(value)

  def save_pcoder(self, pcoder, *labels):
    """Stores `pcoder` under the given labels, replacing any previous one."""
    self._pcoders[self._key(*labels)] = pcoder

  def load_pcoder(self, *labels):
    """Returns the coder saved under the labels.

    Raises:
      KeyError: if no coder has been saved under the labels.
    """
    return self._pcoders[self._key(*labels)]

  def cleanup(self):
    """Drops all cached values and all saved coders."""
    # NOTE(review): unlike __init__, this installs a defaultdict, so after a
    # cleanup source() on an unknown key creates an empty entry instead of
    # raising KeyError. Kept as-is because callers may rely on it; confirm
    # whether a plain dict was intended.
    self._cached = collections.defaultdict(list)
    self._pcoders = {}

  def clear(self, *label):
    """Does nothing and returns None."""
    # Noop because in-memory.
    pass

  def source(self, *labels):
    """Returns a beam.Create over the values cached under the labels."""
    return beam.Create(self._cached[self._key(*labels)])

  def sink(self, labels, is_capture=False):
    """Returns an identity beam.Map; both arguments are ignored."""
    return beam.Map(lambda _: _)

  def size(self, *labels):
    """Returns sys.getsizeof of the cached list, or 0 if there is no entry.

    sys.getsizeof is shallow: it measures the list object itself, not the
    elements it references.
    """
    key = self._key(*labels)
    if key in self._cached:
      return sys.getsizeof(self._cached[key])
    return 0

  def _key(self, *labels):
    """Joins the labels with '/' to form the dictionary key."""
    return '/'.join(labels)
class NoopSink(beam.PTransform):
  """A PTransform that passes every element through unchanged."""
  def expand(self, pcoll):
    """Applies an identity beam.Map to `pcoll` and returns the result."""
    passthrough = beam.Map(lambda x: x)
    return pcoll | passthrough
class FileRecordsBuilder(object):
  """Fluent builder for a TestStream file: one header followed by records.

  Each add_*/advance_* method appends one TestStreamFileRecord and returns
  the builder itself so that calls can be chained; build() returns the result.
  """
  def __init__(self, tag=None):
    """Creates the file header with `tag` and an empty record list."""
    self._header = beam_interactive_api_pb2.TestStreamFileHeader(tag=tag)
    self._records = []
    # Elements are encoded with this coder before being stored in a record.
    self._coder = coders.FastPrimitivesCoder()

  def add_element(self, element, event_time_secs):
    """Appends a record holding `element` with the given event time.

    The element is encoded with the builder's coder and the timestamp is
    stored as Timestamp.of(event_time_secs).micros.
    """
    payload_cls = beam_runner_api_pb2.TestStreamPayload
    timestamped = payload_cls.TimestampedElement(
        encoded_element=self._coder.encode(element),
        timestamp=Timestamp.of(event_time_secs).micros)
    event = payload_cls.Event(
        element_event=payload_cls.Event.AddElements(elements=[timestamped]))
    self._records.append(
        beam_interactive_api_pb2.TestStreamFileRecord(recorded_event=event))
    return self

  def advance_watermark(self, watermark_secs):
    """Appends a record that advances the watermark.

    The new watermark is stored as Timestamp.of(watermark_secs).micros.
    """
    event_cls = beam_runner_api_pb2.TestStreamPayload.Event
    new_watermark = Timestamp.of(watermark_secs).micros
    event = event_cls(
        watermark_event=event_cls.AdvanceWatermark(
            new_watermark=new_watermark))
    self._records.append(
        beam_interactive_api_pb2.TestStreamFileRecord(recorded_event=event))
    return self

  def advance_processing_time(self, delta_secs):
    """Appends a record that advances processing time.

    The advance is stored as Duration.of(delta_secs).micros.
    """
    event_cls = beam_runner_api_pb2.TestStreamPayload.Event
    advance_duration = Duration.of(delta_secs).micros
    event = event_cls(
        processing_time_event=event_cls.AdvanceProcessingTime(
            advance_duration=advance_duration))
    self._records.append(
        beam_interactive_api_pb2.TestStreamFileRecord(recorded_event=event))
    return self

  def build(self):
    """Returns a new list: the header followed by all records, in order."""
    built = [self._header]
    built.extend(self._records)
    return built