#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Module of beam_sql cell magic that executes a Beam SQL.
Only works within an IPython kernel.
"""
import argparse
import importlib
import keyword
import logging
import traceback
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import apache_beam as beam
from apache_beam.pvalue import PValue
from apache_beam.runners.interactive import interactive_environment as ie
from apache_beam.runners.interactive.background_caching_job import has_source_to_cache
from apache_beam.runners.interactive.caching.cacheable import CacheKey
from apache_beam.runners.interactive.caching.reify import reify_to_cache
from apache_beam.runners.interactive.caching.reify import unreify_from_cache
from apache_beam.runners.interactive.display.pcoll_visualization import visualize_computed_pcoll
from apache_beam.runners.interactive.sql.sql_chain import SqlChain
from apache_beam.runners.interactive.sql.sql_chain import SqlNode
from apache_beam.runners.interactive.sql.utils import DataflowOptionsForm
from apache_beam.runners.interactive.sql.utils import find_pcolls
from apache_beam.runners.interactive.sql.utils import pformat_namedtuple
from apache_beam.runners.interactive.sql.utils import register_coder_for_schema
from apache_beam.runners.interactive.sql.utils import replace_single_pcoll_token
from apache_beam.runners.interactive.utils import create_var_in_main
from apache_beam.runners.interactive.utils import obfuscate
from apache_beam.runners.interactive.utils import pcoll_by_name
from apache_beam.runners.interactive.utils import progress_indicated
from apache_beam.testing import test_stream
from apache_beam.testing.test_stream_service import TestStreamServiceController
from apache_beam.transforms.sql import SqlTransform
from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
from IPython.core.magic import Magics
from IPython.core.magic import line_cell_magic
from IPython.core.magic import magics_class
# Module-level logger, named after this module, used for every message the
# magic emits.
_LOGGER = logging.getLogger(__name__)
# Usage text handed to argparse as the parser's `usage` by BeamSqlParser.
# NOTE(review): the percent signs look doubled because argparse %-formats the
# usage string, so `%%beam_sql` should render as `%beam_sql` -- confirm.
_EXAMPLE_USAGE = """beam_sql magic to execute Beam SQL in notebooks
---------------------------------------------------------
%%beam_sql [-o OUTPUT_NAME] [-v] [-r RUNNER] query
---------------------------------------------------------
Or
---------------------------------------------------------
%%%%beam_sql [-o OUTPUT_NAME] [-v] [-r RUNNER] query-line#1
query-line#2
...
query-line#N
---------------------------------------------------------
"""
# Warning template logged by cache_output when running the SQL pipeline
# raises. The two %s placeholders receive the formatted traceback and the
# pipeline's runner, in that order.
_NOT_SUPPORTED_MSG = """The query was valid and successfully applied.
But beam_sql failed to execute the query: %s
Runner used by beam_sql was %s.
Some Beam features might have not been supported by the Python SDK and runner combination.
Please check the runner output for more details about the failed items.
In the meantime, you may check:
https://beam.apache.org/documentation/runners/capability-matrix/
to choose a runner other than the InteractiveRunner and explicitly apply SqlTransform
to build Beam pipelines in a non-interactive manner.
"""
# Runner names accepted by the -r/--runner option; the beam_sql magic rejects
# any other non-empty value.
_SUPPORTED_RUNNERS = ['DirectRunner', 'DataflowRunner']
[docs]
class BeamSqlParser:
  """Argument parser for the tokens supplied to the beam_sql magic.

  Wraps an argparse.ArgumentParser configured with the output name, verbose
  and runner options plus the positional query tokens.
  """
  def __init__(self):
    output_name_help = (
        'The output variable name of the magic, usually a PCollection. '
        'Auto-generated if omitted.')
    verbose_help = 'Display more details about the magic execution.'
    runner_help = (
        'The runner to run the query. Supported runners are %s. If not '
        'provided, DirectRunner is used and results can be inspected '
        'locally.' % _SUPPORTED_RUNNERS)
    query_help = (
        'The Beam SQL query to execute. '
        'Syntax: https://beam.apache.org/documentation/dsls/sql/calcite/'
        'query-syntax/. '
        'Please make sure that there is no conflict between your variable '
        'names and the SQL keywords, such as "SELECT", "FROM", "WHERE" and '
        'etc.')

    # Arguments are registered in the same order as they appear in the help.
    parser = argparse.ArgumentParser(usage=_EXAMPLE_USAGE)
    parser.add_argument(
        '-o', '--output-name', dest='output_name', help=output_name_help)
    parser.add_argument(
        '-v', '--verbose', action='store_true', help=verbose_help)
    parser.add_argument('-r', '--runner', dest='runner', help=runner_help)
    parser.add_argument('query', type=str, nargs='*', help=query_help)
    self._parser = parser

  def parse(self, args: List[str]) -> Optional[argparse.Namespace]:
    """Parses the whitespace-split tokens of a beam_sql invocation.

    The returned namespace carries these attributes:
      output_name: Optional[str], the output variable name.
      verbose: bool, whether to display more details of the magic execution.
      runner: Optional[str], the runner selected to run the query.
      query: List[str], the tokens of the Beam SQL query to execute.

    Args:
      args: The tokens to parse.

    Returns:
      The parsed namespace, or None when parsing did not succeed (including
      when the parser exits, e.g. for -h/--help).
    """
    try:
      namespace = self._parser.parse_args(args)
    except KeyboardInterrupt:
      # A user interrupt must never be swallowed.
      raise
    except BaseException:  # pylint: disable=broad-except
      # -h or --help results in SystemExit 0. Do not raise.
      return None
    return namespace

  def print_help(self) -> None:
    """Prints the help message of the underlying argparse parser."""
    self._parser.print_help()
[docs]
def on_error(error_msg, *args):
  """Logs an error and then prints the usage of the beam_sql magic.

  Args:
    error_msg: A logging-style format string describing the failure.
    *args: Values the logger interpolates into error_msg.
  """
  _LOGGER.error(error_msg, *args)
  # A fresh parser is built only to print its help text.
  help_printer = BeamSqlParser()
  help_printer.print_help()
[docs]
@magics_class
class BeamSqlMagics(Magics):
  """IPython magics class that provides the beam_sql line/cell magic."""
  def __init__(self, shell):
    super().__init__(shell)
    # Eagerly initializes the environment.
    _ = ie.current_env()
    self._parser = BeamSqlParser()

  @line_cell_magic
  def beam_sql(self, line: str, cell: Optional[str] = None) -> Optional[PValue]:
    """The beam_sql line/cell magic that executes a Beam SQL.

    Args:
      line: the string on the same line after the beam_sql magic.
      cell: everything else in the same notebook cell as a string. If None,
        beam_sql is used as line magic. Otherwise, cell magic.

    Returns:
      None if running into an error or waiting for user input (running on
      a selected runner remotely), otherwise a PValue as if a SqlTransform is
      applied.

    Raises:
      ValueError: if a runner passes validation but has no execution path
        implemented here.
    """
    input_str = line
    if cell:
      input_str += ' ' + cell
    parsed = self._parser.parse(input_str.strip().split())
    if not parsed:
      # Failed to parse inputs, let the parser handle the exit.
      return None
    output_name = parsed.output_name
    verbose = parsed.verbose
    query = parsed.query
    runner = parsed.runner

    # An omitted output name is fine (one is generated later); a supplied one
    # must be a usable Python variable name.
    if output_name and (not output_name.isidentifier() or
                        keyword.iskeyword(output_name)):
      on_error(
          'The output_name "%s" is not a valid identifier. Please supply a '
          'valid identifier that is not a Python keyword.',
          output_name)
      return None
    if not query:
      on_error('Please supply the SQL query to be executed.')
      return None
    if runner and runner not in _SUPPORTED_RUNNERS:
      on_error(
          'Runner "%s" is not supported. Supported runners are %s.',
          runner,
          _SUPPORTED_RUNNERS)
      return None

    # The parser split the query into tokens; stitch it back together.
    query = ' '.join(query)
    found = find_pcolls(query, pcoll_by_name(), verbose=verbose)
    schemas = set()
    main_session = importlib.import_module('__main__')
    for pcoll in found.values():
      if not match_is_named_tuple(pcoll.element_type):
        on_error(
            'PCollection %s of type %s is not a NamedTuple. See '
            'https://beam.apache.org/documentation/programming-guide/#schemas '
            'for more details.',
            pcoll,
            pcoll.element_type)
        return None
      register_coder_for_schema(pcoll.element_type, verbose=verbose)
      # Only care about schemas defined by the user in the main module.
      if hasattr(main_session, pcoll.element_type.__name__):
        schemas.add(pcoll.element_type)

    if runner in ('DirectRunner', None):
      # Local execution: collect inputs, apply the query, then run and cache
      # the output.
      collect_data_for_local_run(query, found)
      output_name, output, chain = apply_sql(query, output_name, found)
      chain.current.schemas = schemas
      cache_output(output_name, output)
      return output

    # Other runners: only track the SQL in the chain without running it here.
    output_name, current_node, chain = apply_sql(
        query, output_name, found, False)
    current_node.schemas = schemas
    # TODO(BEAM-10708): Move the options setup and result handling to a
    # separate module when more runners are supported.
    if runner == 'DataflowRunner':
      _ = chain.to_pipeline()
      _ = DataflowOptionsForm(
          output_name, pcoll_by_name()[output_name],
          verbose).display_for_input()
      return None
    else:
      # Defensive: only reachable if a runner is added to _SUPPORTED_RUNNERS
      # without a matching branch above. Exceptions do not %-format their
      # arguments, so the message is formatted explicitly.
      raise ValueError('Unsupported runner %s.' % runner)
[docs]
@progress_indicated
def collect_data_for_local_run(query: str, found: Dict[str, beam.PCollection]):
  """Collects every PCollection used by the query before a local run.

  Args:
    query: The SQL query; only used to compose the error message.
    found: The PCollections, keyed by variable name, used in the query.

  Raises:
    Whatever collecting a PCollection raises; the failure is logged first
    unless it is a KeyboardInterrupt or SystemExit.
  """
  # NOTE(review): imported lazily, presumably to avoid a circular import
  # with interactive_beam -- confirm.
  from apache_beam.runners.interactive import interactive_beam as ib
  for pcoll_name in found:
    try:
      # The collected result itself is not needed here.
      _ = ib.collect(found[pcoll_name])
    except BaseException as err:  # pylint: disable=broad-except
      if not isinstance(err, (KeyboardInterrupt, SystemExit)):
        _LOGGER.error(
            'Cannot collect data for PCollection %s. Please make sure the '
            'PCollections queried in the sql "%s" are all from a single '
            'pipeline using an InteractiveRunner. Make sure there is no '
            'ambiguity, for example, same named PCollections from multiple '
            'pipelines or notebook re-executions.',
            pcoll_name,
            query)
      raise
[docs]
@progress_indicated
def apply_sql(
    query: str,
    output_name: Optional[str],
    found: Dict[str, beam.PCollection],
    run: bool = True) -> Tuple[str, Union[PValue, SqlNode], SqlChain]:
  """Applies a SqlTransform built from the query to the queried PCollections.

  Args:
    query: The SQL query executed in the magic.
    output_name: (optional) The output variable name in __main__ module.
    found: The PCollections with variable names found to be used in the query.
    run: Whether to prepare the SQL pipeline for a local run or not.

  Returns:
    A 3-tuple. The first item is the output variable name in the __main__
    module, auto-generated if not provided. The second item is a PValue when
    run is True; otherwise it is the SqlNode that tracks the SQL without
    applying or executing it. The third item is the SqlChain of SqlNodes that
    have been applied.
  """
  output_name = _generate_output_name(output_name, query, found)
  query, sql_source, chain = _build_query_components(
      query, found, output_name, run)

  if not run:
    # Nothing is applied; hand back the node that tracks this query.
    return output_name, chain.current, chain

  try:
    applied = sql_source | SqlTransform(query)
    # Declare a variable with the output_name and output value in the
    # __main__ module so that the user can use the output smoothly.
    output_name, applied = create_var_in_main(output_name, applied)
    _LOGGER.info(
        "The output PCollection variable is %s with element_type %s",
        output_name,
        pformat_namedtuple(applied.element_type))
  except BaseException as err:  # pylint: disable=broad-except
    # Interrupts and exits propagate silently; anything else is reported
    # together with the usage before being re-raised.
    if not isinstance(err, (KeyboardInterrupt, SystemExit)):
      on_error('Error when applying the Beam SQL: %s', traceback.format_exc())
    raise
  return output_name, applied, chain
[docs]
def pcolls_from_streaming_cache(
    user_pipeline: beam.Pipeline,
    query_pipeline: beam.Pipeline,
    name_to_pcoll: Dict[str, beam.PCollection]) -> Dict[str, beam.PCollection]:
  """Reads cached PCollections into the query pipeline through a TestStream.

  When the user_pipeline has unbounded sources, all cache reads are forced to
  go through the TestStream even if they are bounded sources.

  Args:
    user_pipeline: The beam.Pipeline object defined by the user in the
      notebook.
    query_pipeline: The beam.Pipeline object built by the magic to execute the
      SQL query.
    name_to_pcoll: PCollections with variable names used in the SQL query.

  Returns:
    A Dict[str, beam.PCollection] mapping each PCollection variable name to
    the PCollection read from the cache for it.
  """
  def _log_exception(e):
    # NOTE(review): returning True presumably tells the controller that the
    # error has been handled -- confirm against TestStreamServiceController.
    _LOGGER.error(str(e))
    return True

  cache_manager = ie.current_env().get_cache_manager(
      user_pipeline, create_if_absent=True)
  controller = ie.current_env().get_test_stream_service_controller(
      user_pipeline)
  if not controller:
    # No service is registered for this user pipeline yet: start one and
    # register it with the environment.
    controller = TestStreamServiceController(
        cache_manager, exception_handler=_log_exception)
    controller.start()
    ie.current_env().set_test_stream_service_controller(
        user_pipeline, controller)

  # Cache-key tag -> PCollection variable name.
  tag_to_name = {
      CacheKey.from_pcoll(var_name, pcoll).to_str(): var_name
      for var_name,
      pcoll in name_to_pcoll.items()
  }
  tagged_outputs = query_pipeline | test_stream.TestStream(
      output_tags=set(tag_to_name),
      coder=cache_manager._default_pcoder,
      endpoint=controller.endpoint)

  sql_source = {}
  for tag, cached_pcoll in tagged_outputs.items():
    var_name = tag_to_name[tag]
    # Must mark the element_type to avoid introducing pickled Python coder
    # to the Java expansion service.
    cached_pcoll.element_type = name_to_pcoll[var_name].element_type
    sql_source[var_name] = cached_pcoll
  return sql_source
def _generate_output_name(
output_name: Optional[str], query: str,
found: Dict[str, beam.PCollection]) -> str:
"""Generates a unique output name if None is provided.
Otherwise, returns the given output name directly.
The generated output name is sql_output_{uuid} where uuid is an obfuscated
value from the query and PCollections found to be used in the query.
"""
if not output_name:
execution_id = obfuscate(query, found)[:12]
output_name = 'sql_output_' + execution_id
return output_name
def _build_query_components(
    query: str,
    found: Dict[str, beam.PCollection],
    output_name: str,
    run: bool = True
) -> Tuple[str,
           Union[Dict[str, beam.PCollection], beam.PCollection, beam.Pipeline],
           SqlChain]:
  """Builds the components needed to apply the SqlTransform.

  Args:
    query: The SQL query to be executed by the magic.
    found: The PCollections with variable names found to be used by the query.
    output_name: The output variable name in __main__ module.
    run: Whether to prepare components for a local run or not.

  Returns:
    A 3-tuple of: the processed query to be executed by the magic; the source
    to apply the SqlTransform to (a dictionary of tagged PCollections, a
    single PCollection, or the pipeline to execute the query); and the chain
    of applied beam_sql magics this one belongs to.
  """
  if not found:
    # The query does not read any existing PCollection: a new pipeline is both
    # the source and a new user pipeline.
    sql_source = beam.Pipeline()
    ie.current_env().add_user_pipeline(sql_source)
    # The node should be the root node of the chain created below.
    node = SqlNode(output_name=output_name, source=sql_source, query=query)
    chain = ie.current_env().get_sql_chain(sql_source).append(node)
    return query, sql_source, chain

  first_pcoll = next(iter(found.values()))
  user_pipeline = ie.current_env().user_pipeline(first_pcoll.pipeline)
  sql_pipeline = beam.Pipeline(options=user_pipeline._options)
  ie.current_env().add_derived_pipeline(user_pipeline, sql_pipeline)

  if not run:
    # Not running locally: the queried PCollections are used as they are.
    sql_source = found
  elif has_source_to_cache(user_pipeline):
    sql_source = pcolls_from_streaming_cache(
        user_pipeline, sql_pipeline, found)
  else:
    cache_manager = ie.current_env().get_cache_manager(
        user_pipeline, create_if_absent=True)
    # Read each queried PCollection back from the cache into sql_pipeline.
    sql_source = {
        pcoll_name: unreify_from_cache(
            pipeline=sql_pipeline,
            cache_key=CacheKey.from_pcoll(pcoll_name, pcoll).to_str(),
            cache_manager=cache_manager,
            element_type=pcoll.element_type)
        for pcoll_name,
        pcoll in found.items()
    }

  if len(sql_source) == 1:
    # A lone input is applied as a single PCollection rather than a dict, and
    # the query is rewritten for that single PCollection's name.
    only_name = next(iter(sql_source))
    query = replace_single_pcoll_token(query, only_name)
    sql_source = next(iter(sql_source.values()))

  node = SqlNode(
      output_name=output_name, source=set(found.keys()), query=query)
  chain = ie.current_env().get_sql_chain(
      user_pipeline, set_user_pipeline=True).append(node)
  return query, sql_source, chain
[docs]
@progress_indicated
def cache_output(output_name: str, output: PValue) -> None:
  """Caches the output, runs its pipeline and visualizes the result.

  Args:
    output_name: The variable name of the output.
    output: The output PValue of the applied SQL.
  """
  user_pipeline = ie.current_env().user_pipeline(output.pipeline)
  if not user_pipeline:
    # Without a user pipeline there is no cache manager to write to.
    _LOGGER.warning(
        'Something is wrong with %s. Cannot introspect its data.', output)
    return

  cache_manager = ie.current_env().get_cache_manager(
      user_pipeline, create_if_absent=True)
  cache_key = CacheKey.from_pcoll(output_name, output).to_str()
  _ = reify_to_cache(
      pcoll=output, cache_key=cache_key, cache_manager=cache_manager)

  try:
    output.pipeline.run().wait_until_finish()
  except BaseException as err:  # pylint: disable=broad-except
    if isinstance(err, (KeyboardInterrupt, SystemExit)):
      raise
    # Any other failure is reported as a warning instead of being raised.
    _LOGGER.warning(
        _NOT_SUPPORTED_MSG, traceback.format_exc(), output.pipeline.runner)
    return

  ie.current_env().mark_pcollection_computed([output])
  visualize_computed_pcoll(
      output_name, output, max_n=float('inf'), max_duration_secs=float('inf'))
[docs]
def load_ipython_extension(ipython):
  """Registers the beam_sql magics, making this module an IPython extension.

  To load this magic in an IPython environment, execute:
  %load_ext apache_beam.runners.interactive.sql.beam_sql_magics.

  Args:
    ipython: The object the BeamSqlMagics class is registered with.
  """
  ipython.register_magics(BeamSqlMagics)