Source code for apache_beam.runners.interactive.sql.beam_sql_magics

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Module of the beam_sql line/cell magic that executes Beam SQL queries.

Only works within an IPython kernel.
"""

import argparse
import importlib
import keyword
import logging
import traceback
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import apache_beam as beam
from apache_beam.pvalue import PValue
from apache_beam.runners.interactive import interactive_environment as ie
from apache_beam.runners.interactive.background_caching_job import has_source_to_cache
from apache_beam.runners.interactive.caching.cacheable import CacheKey
from apache_beam.runners.interactive.caching.reify import reify_to_cache
from apache_beam.runners.interactive.caching.reify import unreify_from_cache
from apache_beam.runners.interactive.display.pcoll_visualization import visualize_computed_pcoll
from apache_beam.runners.interactive.sql.sql_chain import SqlChain
from apache_beam.runners.interactive.sql.sql_chain import SqlNode
from apache_beam.runners.interactive.sql.utils import DataflowOptionsForm
from apache_beam.runners.interactive.sql.utils import find_pcolls
from apache_beam.runners.interactive.sql.utils import pformat_namedtuple
from apache_beam.runners.interactive.sql.utils import register_coder_for_schema
from apache_beam.runners.interactive.sql.utils import replace_single_pcoll_token
from apache_beam.runners.interactive.utils import create_var_in_main
from apache_beam.runners.interactive.utils import obfuscate
from apache_beam.runners.interactive.utils import pcoll_by_name
from apache_beam.runners.interactive.utils import progress_indicated
from apache_beam.testing import test_stream
from apache_beam.testing.test_stream_service import TestStreamServiceController
from apache_beam.transforms.sql import SqlTransform
from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
from IPython.core.magic import Magics
from IPython.core.magic import line_cell_magic
from IPython.core.magic import magics_class

_LOGGER = logging.getLogger(__name__)

# Usage text passed as ArgumentParser(usage=...) by BeamSqlParser. argparse
# %-formats usage strings (for %(prog)s), so "%%" is displayed as "%" (line
# magic form) and "%%%%" as "%%" (cell magic form).
_EXAMPLE_USAGE = """beam_sql magic to execute Beam SQL in notebooks
---------------------------------------------------------
%%beam_sql [-o OUTPUT_NAME] [-v] [-r RUNNER] query
---------------------------------------------------------
Or
---------------------------------------------------------
%%%%beam_sql [-o OUTPUT_NAME] [-v] [-r RUNNER] query-line#1
query-line#2
...
query-line#N
---------------------------------------------------------
"""

# Warning logged by cache_output when the SQL pipeline was applied but failed
# to run. Its two %s placeholders take, in order, the formatted traceback and
# the runner of the pipeline.
_NOT_SUPPORTED_MSG = """The query was valid and successfully applied.
    But beam_sql failed to execute the query: %s

    Runner used by beam_sql was %s.
    Some Beam features might have not been supported by the Python SDK and runner combination.
    Please check the runner output for more details about the failed items.

    In the meantime, you may check:
    https://beam.apache.org/documentation/runners/capability-matrix/
    to choose a runner other than the InteractiveRunner and explicitly apply SqlTransform
    to build Beam pipelines in a non-interactive manner.
"""

# Runner names accepted by the -r/--runner option. They are validated in
# BeamSqlMagics.beam_sql and listed in the option's help text.
_SUPPORTED_RUNNERS = ['DirectRunner', 'DataflowRunner']


class BeamSqlParser:
  """Argument parser for the text supplied to the beam_sql magic."""
  def __init__(self):
    parser = argparse.ArgumentParser(usage=_EXAMPLE_USAGE)
    parser.add_argument(
        '-o',
        '--output-name',
        dest='output_name',
        help=(
            'The output variable name of the magic, usually a PCollection. '
            'Auto-generated if omitted.'))
    parser.add_argument(
        '-v',
        '--verbose',
        action='store_true',
        help='Display more details about the magic execution.')
    runner_help = (
        'The runner to run the query. Supported runners are %s. If not '
        'provided, DirectRunner is used and results can be inspected '
        'locally.' % _SUPPORTED_RUNNERS)
    parser.add_argument('-r', '--runner', dest='runner', help=runner_help)
    parser.add_argument(
        'query',
        type=str,
        nargs='*',
        help=(
            'The Beam SQL query to execute. '
            'Syntax: https://beam.apache.org/documentation/dsls/sql/calcite/'
            'query-syntax/. '
            'Please make sure that there is no conflict between your variable '
            'names and the SQL keywords, such as "SELECT", "FROM", "WHERE" and '
            'etc.'))
    self._parser = parser

  def parse(self, args: List[str]) -> Optional[argparse.Namespace]:
    """Parses the tokens of a beam_sql invocation.

    The resulting namespace carries:
      output_name: Optional[str], the requested output variable name.
      verbose: bool, whether to display more details of the magic execution.
      runner: Optional[str], the runner requested via -r/--runner.
      query: the tokens of the Beam SQL query; falsy when no query was given.

    Args:
      args: the input of the magic, already split into tokens.

    Returns:
      The parsed namespace, or None when argparse bails out (invalid input, or
      -h/--help which exits with status 0).
    """
    try:
      namespace = self._parser.parse_args(args)
    except KeyboardInterrupt:
      raise
    except:  # pylint: disable=bare-except
      # argparse signals both bad input and -h/--help by raising SystemExit;
      # swallow it so the IPython kernel keeps running.
      return None
    return namespace

  def print_help(self) -> None:
    """Prints the usage and per-argument help of the beam_sql magic."""
    self._parser.print_help()
def on_error(error_msg, *args):
  """Reports a beam_sql usage error.

  Args:
    error_msg: a logging-style format string describing the problem.
    *args: the values interpolated into error_msg by the logger.
  """
  _LOGGER.error(error_msg, *args)
  usage_printer = BeamSqlParser()
  usage_printer.print_help()
@magics_class
class BeamSqlMagics(Magics):
  """IPython magics that execute Beam SQL queries (registers beam_sql)."""
  def __init__(self, shell):
    super().__init__(shell)
    # Eagerly initializes the environment.
    _ = ie.current_env()
    self._parser = BeamSqlParser()

  @line_cell_magic
  def beam_sql(self, line: str, cell: Optional[str] = None) -> Optional[PValue]:
    """The beam_sql line/cell magic that executes a Beam SQL query.

    Args:
      line: the string on the same line after the beam_sql magic.
      cell: everything else in the same notebook cell as a string. If None,
        beam_sql is used as line magic. Otherwise, cell magic.

    Returns:
      None if running into an error or waiting for user input (running on a
      selected runner remotely), otherwise a PValue as if a SqlTransform is
      applied.

    Raises:
      ValueError: if the runner passed validation but has no execution path.
    """
    input_str = line
    if cell:
      input_str += ' ' + cell
    # NOTE(review): splitting on whitespace and re-joining with single spaces
    # collapses runs of whitespace inside SQL string literals (e.g. 'a  b'
    # becomes 'a b'); consider preserving the raw query text instead.
    parsed = self._parser.parse(input_str.strip().split())
    if parsed is None:
      # Failed to parse inputs, let the parser handle the exit.
      return None

    output_name = parsed.output_name
    verbose = parsed.verbose
    query = parsed.query
    runner = parsed.runner

    # An explicit output name becomes a variable in __main__, so it must be a
    # valid identifier that is not a reserved Python keyword.
    if output_name and (not output_name.isidentifier() or
                        keyword.iskeyword(output_name)):
      on_error(
          'The output_name "%s" is not a valid identifier. Please supply a '
          'valid identifier that is not a Python keyword.',
          output_name)
      return None
    if not query:
      on_error('Please supply the SQL query to be executed.')
      return None
    if runner and runner not in _SUPPORTED_RUNNERS:
      on_error(
          'Runner "%s" is not supported. Supported runners are %s.',
          runner,
          _SUPPORTED_RUNNERS)
      return None

    query = ' '.join(query)
    found = find_pcolls(query, pcoll_by_name(), verbose=verbose)
    schemas = set()
    main_session = importlib.import_module('__main__')
    for pcoll in found.values():
      if not match_is_named_tuple(pcoll.element_type):
        on_error(
            'PCollection %s of type %s is not a NamedTuple. See '
            'https://beam.apache.org/documentation/programming-guide/#schemas '
            'for more details.',
            pcoll,
            pcoll.element_type)
        return None
      register_coder_for_schema(pcoll.element_type, verbose=verbose)
      # Only care about schemas defined by the user in the main module.
      if hasattr(main_session, pcoll.element_type.__name__):
        schemas.add(pcoll.element_type)

    if runner in ('DirectRunner', None):
      # Local run: compute inputs, apply the SQL, then cache and show output.
      collect_data_for_local_run(query, found)
      output_name, output, chain = apply_sql(query, output_name, found)
      chain.current.schemas = schemas
      cache_output(output_name, output)
      return output

    # Remote run: only record the query in the SQL chain without applying it.
    output_name, current_node, chain = apply_sql(
        query, output_name, found, False)
    current_node.schemas = schemas
    # TODO(BEAM-10708): Move the options setup and result handling to a
    # separate module when more runners are supported.
    if runner == 'DataflowRunner':
      _ = chain.to_pipeline()
      _ = DataflowOptionsForm(
          output_name, pcoll_by_name()[output_name],
          verbose).display_for_input()
      return None
    raise ValueError('Unsupported runner %s.' % runner)
@progress_indicated
def collect_data_for_local_run(query: str, found: Dict[str, beam.PCollection]):
  """Runs ib.collect on every PCollection referenced by the query.

  If collecting a PCollection fails (other than by KeyboardInterrupt or
  SystemExit), an error explaining the likely cause is logged and the original
  exception is re-raised.

  Args:
    query: the SQL query, used only in the error message.
    found: the PCollections, keyed by variable name, used in the query.
  """
  from apache_beam.runners.interactive import interactive_beam as ib
  for pcoll_name in found:
    target = found[pcoll_name]
    try:
      ib.collect(target)
    except (KeyboardInterrupt, SystemExit):
      raise
    except:  # pylint: disable=bare-except
      _LOGGER.error(
          'Cannot collect data for PCollection %s. Please make sure the '
          'PCollections queried in the sql "%s" are all from a single '
          'pipeline using an InteractiveRunner. Make sure there is no '
          'ambiguity, for example, same named PCollections from multiple '
          'pipelines or notebook re-executions.',
          pcoll_name,
          query)
      raise
@progress_indicated
def apply_sql(
    query: str,
    output_name: Optional[str],
    found: Dict[str, beam.PCollection],
    run: bool = True) -> Tuple[str, Union[PValue, SqlNode], SqlChain]:
  """Applies a SqlTransform with the given sql and queried PCollections.

  Args:
    query: The SQL query executed in the magic.
    output_name: (optional) The output variable name in __main__ module.
    found: The PCollections with variable names found to be used in the query.
    run: Whether to prepare the SQL pipeline for a local run or not.

  Returns:
    A 3-tuple of:
      * the output variable name in the __main__ module (auto-generated when
        not provided);
      * when run is True, the PValue produced by the SqlTransform; otherwise
        the SqlNode that records the query without applying or executing it;
      * the SqlChain of SqlNodes this query was appended to.

  Raises:
    Any error raised while applying the SqlTransform is reported through
    on_error (with its traceback) and then re-raised.
  """
  output_name = _generate_output_name(output_name, query, found)
  query, sql_source, chain = _build_query_components(
      query, found, output_name, run)
  if not run:
    return output_name, chain.current, chain

  try:
    output = sql_source | SqlTransform(query)
    # Bind the output to a variable named output_name in __main__ so the user
    # can keep working with it in the notebook.
    output_name, output = create_var_in_main(output_name, output)
    _LOGGER.info(
        "The output PCollection variable is %s with element_type %s",
        output_name,
        pformat_namedtuple(output.element_type))
  except (KeyboardInterrupt, SystemExit):
    raise
  except:  # pylint: disable=bare-except
    on_error('Error when applying the Beam SQL: %s', traceback.format_exc())
    raise
  return output_name, output, chain
def pcolls_from_streaming_cache(
    user_pipeline: beam.Pipeline,
    query_pipeline: beam.Pipeline,
    name_to_pcoll: Dict[str, beam.PCollection]) -> Dict[str, beam.PCollection]:
  """Reads the cached data of the queried PCollections through a TestStream.

  The caller uses this when the user pipeline has sources to cache (unbounded
  sources); every cache read then goes through the TestStream, even for
  bounded PCollections.

  Args:
    user_pipeline: The beam.Pipeline object defined by the user in the
      notebook.
    query_pipeline: The beam.Pipeline object built by the magic to execute the
      SQL query.
    name_to_pcoll: PCollections with variable names used in the SQL query.

  Returns:
    A dict from each PCollection variable name to the PCollection read from
    the cache inside query_pipeline.
  """
  def _log_and_continue(error):
    _LOGGER.error(str(error))
    return True

  cache_manager = ie.current_env().get_cache_manager(
      user_pipeline, create_if_absent=True)
  service = ie.current_env().get_test_stream_service_controller(user_pipeline)
  if not service:
    # No controller registered for this pipeline yet: start one and store it
    # on the environment so get_test_stream_service_controller finds it later.
    service = TestStreamServiceController(
        cache_manager, exception_handler=_log_and_continue)
    service.start()
    ie.current_env().set_test_stream_service_controller(user_pipeline, service)

  # Each TestStream output tag is the string form of a PCollection's CacheKey.
  tag_to_name = {
      CacheKey.from_pcoll(name, pcoll).to_str(): name
      for name, pcoll in name_to_pcoll.items()
  }
  tagged_outputs = query_pipeline | test_stream.TestStream(
      output_tags=set(tag_to_name),
      coder=cache_manager._default_pcoder,
      endpoint=service.endpoint)

  sql_source = {}
  for tag, pcoll in tagged_outputs.items():
    name = tag_to_name[tag]
    # Must mark the element_type to avoid introducing pickled Python coder
    # to the Java expansion service.
    pcoll.element_type = name_to_pcoll[name].element_type
    sql_source[name] = pcoll
  return sql_source
def _generate_output_name(
    output_name: Optional[str], query: str,
    found: Dict[str, beam.PCollection]) -> str:
  """Returns output_name, or a generated name when it is empty or None.

  A generated name has the form sql_output_{id}, where id is the first 12
  characters of obfuscate(query, found).
  """
  if output_name:
    return output_name
  return 'sql_output_' + obfuscate(query, found)[:12]


def _build_query_components(
    query: str,
    found: Dict[str, beam.PCollection],
    output_name: str,
    run: bool = True
) -> Tuple[str,
           Union[Dict[str, beam.PCollection], beam.PCollection, beam.Pipeline],
           SqlChain]:
  """Builds necessary components needed to apply the SqlTransform.

  Args:
    query: The SQL query to be executed by the magic.
    found: The PCollections with variable names found to be used by the query.
    output_name: The output variable name in __main__ module.
    run: Whether to prepare components for a local run or not.

  Returns:
    A 3-tuple of:
      * the query to execute; when exactly one PCollection is queried, its
        name is rewritten by replace_single_pcoll_token;
      * the source to apply the SqlTransform to: a dict of named
        PCollections, a single PCollection, or a fresh pipeline when the query
        reads no existing PCollection;
      * the chain of applied beam_sql magics this query was appended to.
  """
  if not found:
    # The query does not read any existing PCollection: it runs on its own
    # new pipeline, and its node is the root of a new chain.
    root_pipeline = beam.Pipeline()
    ie.current_env().add_user_pipeline(root_pipeline)
    root_node = SqlNode(
        output_name=output_name, source=root_pipeline, query=query)
    root_chain = ie.current_env().get_sql_chain(root_pipeline).append(root_node)
    return query, root_pipeline, root_chain

  user_pipeline = ie.current_env().user_pipeline(
      next(iter(found.values())).pipeline)
  sql_pipeline = beam.Pipeline(options=user_pipeline._options)
  ie.current_env().add_derived_pipeline(user_pipeline, sql_pipeline)

  if not run:
    sql_source = found
  elif has_source_to_cache(user_pipeline):
    sql_source = pcolls_from_streaming_cache(
        user_pipeline, sql_pipeline, found)
  else:
    cache_manager = ie.current_env().get_cache_manager(
        user_pipeline, create_if_absent=True)
    sql_source = {
        pcoll_name: unreify_from_cache(
            pipeline=sql_pipeline,
            cache_key=CacheKey.from_pcoll(pcoll_name, pcoll).to_str(),
            cache_manager=cache_manager,
            element_type=pcoll.element_type)
        for pcoll_name, pcoll in found.items()
    }

  if len(sql_source) == 1:
    # A single input is applied directly rather than as a tagged dict.
    [(only_name, only_source)] = sql_source.items()
    query = replace_single_pcoll_token(query, only_name)
    sql_source = only_source

  node = SqlNode(output_name=output_name, source=set(found), query=query)
  chain = ie.current_env().get_sql_chain(
      user_pipeline, set_user_pipeline=True).append(node)
  return query, sql_source, chain
@progress_indicated
def cache_output(output_name: str, output: PValue) -> None:
  """Runs the pipeline producing output, caches the result and shows it.

  When output's pipeline is not tracked as a user pipeline, a warning is
  logged and nothing is run. When running the pipeline fails (other than by
  KeyboardInterrupt or SystemExit), _NOT_SUPPORTED_MSG is logged as a warning
  and the output is neither marked computed nor visualized.

  Args:
    output_name: the variable name of the output; also used for its CacheKey.
    output: the PValue produced by the SqlTransform.
  """
  user_pipeline = ie.current_env().user_pipeline(output.pipeline)
  if not user_pipeline:
    _LOGGER.warning(
        'Something is wrong with %s. Cannot introspect its data.', output)
    return

  cache_manager = ie.current_env().get_cache_manager(
      user_pipeline, create_if_absent=True)
  cache_key = CacheKey.from_pcoll(output_name, output).to_str()
  reify_to_cache(pcoll=output, cache_key=cache_key, cache_manager=cache_manager)
  try:
    output.pipeline.run().wait_until_finish()
  except (KeyboardInterrupt, SystemExit):
    raise
  except:  # pylint: disable=bare-except
    _LOGGER.warning(
        _NOT_SUPPORTED_MSG, traceback.format_exc(), output.pipeline.runner)
  else:
    ie.current_env().mark_pcollection_computed([output])
    visualize_computed_pcoll(
        output_name,
        output,
        max_n=float('inf'),
        max_duration_secs=float('inf'))
def load_ipython_extension(ipython):
  """Entry point that lets IPython load this module as an extension.

  Usage inside an IPython session:
    %load_ext apache_beam.runners.interactive.sql.beam_sql_magics

  Args:
    ipython: the running IPython shell; BeamSqlMagics is registered on it.
  """
  magics_cls = BeamSqlMagics
  ipython.register_magics(magics_cls)