#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""A ValueProvider abstracts the notion of fetching a value that may or
may not be currently available.
This can be used to parameterize transforms that only read values in at
runtime, for example.
"""
# pytype: skip-file
from functools import wraps
from typing import Set
from apache_beam import error
# Public API of this module: the provider classes plus the accessibility
# decorator. Only these names are exported by a star-import.
__all__ = [
    'ValueProvider',
    'StaticValueProvider',
    'RuntimeValueProvider',
    'NestedValueProvider',
    'check_accessible',
]
class ValueProvider:
    """Interface for objects that supply a value, possibly only at runtime.

    Concrete subclasses must override both ``is_accessible`` and ``get``;
    the implementations here only raise ``NotImplementedError``.
    """

    def is_accessible(self):
        """Report whether the value can be read at graph construction time.

        Raises:
          NotImplementedError: always; subclasses supply the real answer.
        """
        raise NotImplementedError(
            'ValueProvider.is_accessible implemented in derived classes')

    def get(self):
        """Return the value this provider wraps.

        Raises:
          NotImplementedError: always; subclasses supply the real value.
        """
        raise NotImplementedError(
            'ValueProvider.get implemented in derived classes')
class StaticValueProvider(ValueProvider):
    """StaticValueProvider is an implementation of ValueProvider that allows
    for a static value to be provided.

    The value is known up front, so it is always accessible.
    """

    def __init__(self, value_type, value):
        """
        Args:
          value_type: Type of the static value; it is also called on ``value``
            to coerce it (``self.value = value_type(value)``).
          value: Static value
        """
        self.value_type = value_type
        self.value = value_type(value)

    def is_accessible(self):
        return True

    def get(self):
        return self.value

    def __str__(self):
        return str(self.value)

    def __eq__(self, other):
        """Equal to a provider with the same value_type and value, or to a
        raw value equal to the wrapped value.
        """
        # Handle the provider case first. Evaluating ``self.value == other``
        # first would make Python fall back to the reflected
        # ``other.__eq__(self.value)``. That matches providers whose
        # value_types differ (e.g. int 1 vs float 1.0), silently skipping the
        # value_type check and disagreeing with __hash__ below.
        if isinstance(other, StaticValueProvider):
            return (self.value_type == other.value_type and
                    self.value == other.value)
        # Comparison against a raw value; coerce to bool as before.
        return bool(self.value == other)

    def __hash__(self):
        return hash((type(self), self.value_type, self.value))
class RuntimeValueProvider(ValueProvider):
    """ValueProvider whose value is supplied at execution time.

    Unlike a static provider, the value is not available while the pipeline
    graph is being constructed. It is looked up by ``option_name`` in the
    class-wide ``runtime_options`` once ``set_runtime_options`` has installed
    them.
    """

    # Shared by all instances; stays None until set_runtime_options() runs.
    runtime_options = None
    # Refreshed from the 'experiments' option by set_runtime_options().
    experiments: Set[str] = set()

    def __init__(self, option_name, value_type, default_value):
        self.option_name = option_name
        self.default_value = default_value
        self.value_type = value_type

    def is_accessible(self):
        return RuntimeValueProvider.runtime_options is not None

    @classmethod
    def get_value(cls, option_name, value_type, default_value):
        """Look up ``option_name`` in the runtime options.

        Returns ``default_value`` when no (or empty) runtime options are
        installed or the option's value is falsy; otherwise returns the
        option's value passed through ``value_type``.
        """
        # NOTE(review): the truthiness test treats a supplied 0, '' or False
        # as "missing" and substitutes default_value -- confirm intended.
        options = RuntimeValueProvider.runtime_options
        if not options:
            return default_value
        candidate = options.get(option_name)
        return value_type(candidate) if candidate else default_value

    def get(self):
        """Return this option's runtime value (or its default).

        Raises:
          ``RuntimeValueProviderError``: if runtime options have not been set,
          i.e. when called outside a runtime context.
        """
        if RuntimeValueProvider.runtime_options is not None:
            return RuntimeValueProvider.get_value(
                self.option_name, self.value_type, self.default_value)
        raise error.RuntimeValueProviderError(
            '%s.get() not called from a runtime context' % self)

    @classmethod
    def set_runtime_options(cls, pipeline_options):
        """Install the options seen by every RuntimeValueProvider.

        Also refreshes the class-wide ``experiments`` set from the
        'experiments' option (an empty set when it is absent or falsy).
        """
        RuntimeValueProvider.runtime_options = pipeline_options
        found_experiments = RuntimeValueProvider.get_value(
            'experiments', set, set())
        RuntimeValueProvider.experiments = found_experiments

    def __str__(self):
        return '%s(option: %s, type: %s, default_value: %s)' % (
            type(self).__name__,
            self.option_name,
            self.value_type.__name__,
            repr(self.default_value))
class NestedValueProvider(ValueProvider):
    """NestedValueProvider is an implementation of ValueProvider that allows
    for wrapping another ValueProvider object.

    The wrapped provider's value is passed through ``translator`` on the first
    successful ``get()``, and the result is cached for subsequent calls.
    """

    def __init__(self, value, translator):
        """Creates a NestedValueProvider that wraps the provided ValueProvider.

        Args:
          value: ValueProvider object to wrap
          translator: callable applied to the result of ``value.get()``; its
            return value is what this provider's ``get()`` returns.
        """
        self.value = value
        self.translator = translator

    def is_accessible(self):
        # Accessibility is delegated entirely to the wrapped provider.
        return self.value.is_accessible()

    def get(self):
        """Return ``translator(value.get())``, computing it at most once.

        If ``value.get()`` or ``translator`` raises, the exception propagates
        and nothing is cached, so a later call will try again.
        """
        try:
            return self.cached_value
        except AttributeError:
            self.cached_value = self.translator(self.value.get())
            return self.cached_value

    def __str__(self):
        # Not every callable has __name__ (e.g. functools.partial objects or
        # instances defining __call__). Fall back to repr() so that formatting
        # an error message about this provider cannot itself raise.
        translator_name = getattr(self.translator, '__name__', None)
        if translator_name is None:
            translator_name = repr(self.translator)
        return "%s(value: %s, translator: %s)" % (
            self.__class__.__name__,
            self.value,
            translator_name,
        )
def check_accessible(value_provider_list):
    """Decorator factory that guards a method on ValueProvider accessibility.

    The decorated method, before running its body, looks up each attribute
    named in ``value_provider_list`` on ``self`` and requires every one of
    those objects to report ``is_accessible()``.

    Args:
      value_provider_list: list of attribute names on ``self``, each naming a
        ValueProvider object.

    Raises:
      ``RuntimeValueProviderError``: when the decorated method is called and
        any of the named providers is not accessible.
    """
    # NOTE: a plain assert, so this argument check disappears under ``-O``.
    assert isinstance(value_provider_list, list)

    def decorator(method):
        @wraps(method)
        def wrapper(self, *args, **kwargs):
            # Resolve every named attribute up front, then check them in order.
            providers = [getattr(self, name) for name in value_provider_list]
            for provider in providers:
                if not provider.is_accessible():
                    raise error.RuntimeValueProviderError(
                        '%s not accessible' % provider)
            return method(self, *args, **kwargs)
        return wrapper
    return decorator