#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Module to augment interactive flavor into the given pipeline.
For internal use only; no backward-compatibility guarantees.
"""
# pytype: skip-file

import copy
from typing import Dict
from typing import Optional
from typing import Set

import apache_beam as beam
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners.interactive import background_caching_job
from apache_beam.runners.interactive import interactive_environment as ie
from apache_beam.runners.interactive.caching.cacheable import Cacheable
from apache_beam.runners.interactive.caching.read_cache import ReadCache
from apache_beam.runners.interactive.caching.write_cache import WriteCache


class AugmentedPipeline:
  """A pipeline with an augmented interactive flavor that caches intermediate
  PCollections defined by the user, reads already-computed PCollections from
  cache as sources, and prunes unnecessary pipeline parts for fast
  computation.
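
  A minimal usage sketch (illustrative only; assumes the code runs in an
  interactive environment, e.g. a notebook, and that Interactive Beam has been
  told to watch the user-defined variables)::

    import apache_beam as beam
    from apache_beam.runners.interactive import interactive_beam as ib

    p = beam.Pipeline()
    squares = p | beam.Create([1, 2, 3]) | beam.Map(lambda x: x * x)
    ib.watch(locals())  # makes `squares` discoverable as a cacheable pcoll
    augmented_proto = AugmentedPipeline(p).augmented_pipeline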
"""

  def __init__(
      self,
      user_pipeline: beam.Pipeline,
      pcolls: Optional[Set[beam.pvalue.PCollection]] = None):
"""
Initializes a pipelilne for augmenting interactive flavor.
Args:
user_pipeline: a beam.Pipeline instance defined by the user.
pcolls: cacheable pcolls to be computed/retrieved. If the set is
empty, all intermediate pcolls assigned to variables are applicable.
"""
    assert not pcolls or all(
        pcoll.pipeline is user_pipeline for pcoll in pcolls
    ), 'All %s need to belong to %s' % (pcolls, user_pipeline)
    self._user_pipeline = user_pipeline
    self._pcolls = pcolls
    self._cache_manager = ie.current_env().get_cache_manager(
        self._user_pipeline, create_if_absent=True)
    # Reuse the cache manager of an existing background caching job if the
    # pipeline has recordable sources being cached.
    if background_caching_job.has_source_to_cache(self._user_pipeline):
      self._cache_manager = ie.current_env().get_cache_manager(
          self._user_pipeline)
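    # Snapshot the user pipeline as a runner API proto and keep its context;
    # copying the component id map lets cache transforms added later assign
    # new component ids without mutating the user pipeline's own map.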
    _, self._context = self._user_pipeline.to_runner_api(return_context=True)
    self._context.component_id_map = copy.copy(
        self._user_pipeline.component_id_map)
    self._cacheables = self.cacheables()

  @property
  def augmented_pipeline(self) -> beam_runner_api_pb2.Pipeline:
    return self.augment()

  # TODO(https://github.com/apache/beam/issues/20526): Support generating a
  # background recording job that contains unbounded source recording
  # transforms only.
  @property
  def background_recording_pipeline(self) -> beam_runner_api_pb2.Pipeline:
    raise NotImplementedError

  def cacheables(self) -> Dict[beam.pvalue.PCollection, Cacheable]:
    """Finds all the cacheable intermediate PCollections in the pipeline with
    their metadata.
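
    For example (illustrative; assumes ``squares`` is a watched PCollection
    assigned to a user variable of the same name), the returned dict maps each
    PCollection to its metadata::

      {squares: Cacheable(
          var='squares',
          pcoll=squares,
          version=str(id(squares)),
          producer_version=str(id(squares.producer)))}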
"""
    c = {}
    for watching in ie.current_env().watching():
      for key, val in watching:
        if (isinstance(val, beam.pvalue.PCollection) and
            val.pipeline is self._user_pipeline and
            (not self._pcolls or val in self._pcolls)):
          c[val] = Cacheable(
              var=key,
              pcoll=val,
              version=str(id(val)),
              producer_version=str(id(val.producer)))
    return c

  def augment(self) -> beam_runner_api_pb2.Pipeline:
    """Augments the pipeline with cache. Always calculates a new result.

    For a cacheable PCollection, if its cache exists and the PCollection has
    already been computed, wire in a transform to read the cache; otherwise,
    wire in a transform to write the cache.
    """
    pipeline = self._user_pipeline.to_runner_api()

    # Find PCollections eligible for reading from or writing to cache.
    readcache_pcolls = set()
    for pcoll, cacheable in self._cacheables.items():
      key = repr(cacheable.to_key())
      if (self._cache_manager.exists('full', key) and
          pcoll in ie.current_env().computed_pcollections):
        readcache_pcolls.add(pcoll)
    writecache_pcolls = set(
        self._cacheables.keys()).difference(readcache_pcolls)

    # Wire in additional transforms to read from and write to the cache.
    for readcache_pcoll in readcache_pcolls:
      ReadCache(
          pipeline,
          self._context,
          self._cache_manager,
          self._cacheables[readcache_pcoll]).read_cache()
    for writecache_pcoll in writecache_pcolls:
      WriteCache(
          pipeline,
          self._context,
          self._cache_manager,
          self._cacheables[writecache_pcoll]).write_cache()
    # TODO(https://github.com/apache/beam/issues/20526): Support streaming, add
    # pruning logic, and integrate pipeline fragment logic.
    return pipeline