Source code for apache_beam.transforms.enrichment_handlers.bigquery

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from collections.abc import Callable
from collections.abc import Mapping
from typing import Any
from typing import Optional
from typing import Union

from google.api_core.exceptions import BadRequest
from google.cloud import bigquery

import apache_beam as beam
from apache_beam.pvalue import Row
from apache_beam.transforms.enrichment import EnrichmentSourceHandler

# Type of the `query_fn` argument: maps an input row to a complete BigQuery
# SQL query string.
QueryFn = Callable[[beam.Row], str]
# Type of the `condition_value_fn` argument: maps an input row to the values
# that fill the `{}` placeholders of `row_restriction_template`.
ConditionValueFn = Callable[[beam.Row], list[Any]]


def _validate_bigquery_metadata(
    table_name, row_restriction_template, fields, condition_value_fn, query_fn):
  if query_fn:
    if bool(table_name or row_restriction_template or fields or
            condition_value_fn):
      raise ValueError(
          "Please provide either `query_fn` or the parameters `table_name`, "
          "`row_restriction_template`, and `fields/condition_value_fn` "
          "together.")
  else:
    if not (table_name and row_restriction_template):
      raise ValueError(
          "Please provide either `query_fn` or the parameters "
          "`table_name`, `row_restriction_template` together.")
    if ((fields and condition_value_fn) or
        (not fields and not condition_value_fn)):
      raise ValueError(
          "Please provide exactly one of `fields` or "
          "`condition_value_fn`")


class BigQueryEnrichmentHandler(EnrichmentSourceHandler[Union[Row, list[Row]],
                                                        Union[Row, list[Row]]]):
  """Enrichment handler for Google Cloud BigQuery.

  Use this handler with :class:`apache_beam.transforms.enrichment.Enrichment`
  transform.

  To use this handler you need either of the following combinations:
    * `table_name`, `row_restriction_template`, `fields`
    * `table_name`, `row_restriction_template`, `condition_value_fn`
    * `query_fn`

  By default, the handler pulls all columns from the BigQuery table.
  To override this, use the `column_names` parameter to specify a list of
  column names to fetch.

  This handler pulls data from BigQuery per element by default. To change this
  behavior, set the `min_batch_size` and `max_batch_size` parameters.
  These min and max values for batch size are sent to the
  :class:`apache_beam.transforms.util.BatchElements` transform.

  A request for which BigQuery returns no matching row is paired with an
  empty `beam.Row()` instead of being dropped.

  NOTE: Elements cannot be batched when using the `query_fn` parameter.
  """
  def __init__(
      self,
      project: str,
      *,
      table_name: str = "",
      row_restriction_template: str = "",
      fields: Optional[list[str]] = None,
      column_names: Optional[list[str]] = None,
      condition_value_fn: Optional[ConditionValueFn] = None,
      query_fn: Optional[QueryFn] = None,
      min_batch_size: int = 1,
      max_batch_size: int = 10000,
      **kwargs,
  ):
    """
    Example Usage:
      handler = BigQueryEnrichmentHandler(project=project_name,
                                          row_restriction_template="id='{}'",
                                          table_name='project.dataset.table',
                                          fields=fields,
                                          min_batch_size=2,
                                          max_batch_size=100)

    Args:
      project: Google Cloud project ID for the BigQuery table.
      table_name (str): Fully qualified BigQuery table name
        in the format `project.dataset.table`.
      row_restriction_template (str): A template string for the `WHERE` clause
        in the BigQuery query with placeholders (`{}`) to dynamically filter
        rows based on input data.
      fields: (Optional[list[str]]) List of field names present in the input
        `beam.Row`. These are used to construct the WHERE clause
        (if `condition_value_fn` is not provided).
      column_names: (Optional[list[str]]) Names of columns to select from the
        BigQuery table. If not provided, all columns (`*`) are selected.
      condition_value_fn: (Optional[Callable[[beam.Row], list[Any]]]) A
        function that takes a `beam.Row` and returns a list of values to
        populate in the placeholders `{}` of the `WHERE` clause in the query.
      query_fn: (Optional[Callable[[beam.Row], str]]) A function that takes a
        `beam.Row` and returns a complete BigQuery SQL query string.
      min_batch_size (int): Minimum number of rows to batch together when
        querying BigQuery. Defaults to 1 if `query_fn` is not specified.
      max_batch_size (int): Maximum number of rows to batch together.
        Defaults to 10,000 if `query_fn` is not specified.
      **kwargs: Additional keyword arguments to pass to `bigquery.Client`.

    Note:
      * `min_batch_size` and `max_batch_size` are ignored (no batching) when
        `query_fn` is provided.
      * Either `fields` or `condition_value_fn` must be provided for query
        construction if `query_fn` is not provided.
      * When batching, BigQuery result rows are matched back to requests with
        `create_row_key`, which is applied to the result rows too. Result rows
        must therefore contain the `fields` columns (or be accepted by
        `condition_value_fn`).
      * Placeholder values are interpolated into the SQL text with
        `str.format`, not passed as query parameters. Only use trusted input.
      * Ensure appropriate permissions are granted for BigQuery access.
    """
    _validate_bigquery_metadata(
        table_name,
        row_restriction_template,
        fields,
        condition_value_fn,
        query_fn)
    self.project = project
    self.column_names = column_names
    self.select_fields = ",".join(column_names) if column_names else '*'
    self.row_restriction_template = row_restriction_template
    self.table_name = table_name
    self.fields = fields if fields else []
    self.condition_value_fn = condition_value_fn
    self.query_fn = query_fn
    self.query_template = (
        "SELECT %s FROM %s WHERE %s" %
        (self.select_fields, self.table_name, self.row_restriction_template))
    self.kwargs = kwargs
    self._batching_kwargs: dict[str, int] = {}
    # Queries produced by `query_fn` cannot be OR-ed together, so batching is
    # only enabled for the template-based configurations.
    if not query_fn:
      self._batching_kwargs['min_batch_size'] = min_batch_size
      self._batching_kwargs['max_batch_size'] = max_batch_size

  def __enter__(self):
    """Create the BigQuery client used by subsequent requests."""
    self.client = bigquery.Client(project=self.project, **self.kwargs)
    return self

  def _execute_query(self, query: str) -> list[dict[str, Any]]:
    """Run `query` on BigQuery and return every result row as a dict.

    Raises:
      BadRequest: if BigQuery rejects the query (e.g. malformed SQL or a
        missing table).
      RuntimeError: if the query request could not be completed.
    """
    try:
      results = self.client.query(query=query).result()
      return [dict(row.items()) for row in results]
    except BadRequest as e:
      raise BadRequest(
          f'Could not execute the query: {query}. Please check if '
          f'the query is properly formatted and the BigQuery '
          f'table exists. {e}') from e
    except RuntimeError as e:
      raise RuntimeError(
          f"Could not complete the query request: {query}. {e}") from e

  def _get_condition_values(self, row: beam.Row) -> list[Any]:
    """Return the values that fill the `{}` placeholders for `row`.

    Uses `condition_value_fn` when set, otherwise looks up `fields` in `row`.

    Raises:
      KeyError: if one of `fields` is not a key of `row`.
    """
    if self.condition_value_fn:
      return list(self.condition_value_fn(row))
    row_dict = row._asdict()
    try:
      return [row_dict[field] for field in self.fields]
    except KeyError as e:
      raise KeyError(
          "Make sure the values passed in `fields` are the "
          "keys in the input `beam.Row`." + str(e)) from e

  def create_row_key(self, row: beam.Row):
    """Return a hashable tuple identifying `row` by its condition values.

    `__call__` applies this to BigQuery result rows to match them back to the
    request rows of a batch.

    Raises:
      ValueError: if neither `fields` nor `condition_value_fn` is configured.
    """
    if self.condition_value_fn:
      return tuple(self.condition_value_fn(row))
    if self.fields:
      row_dict = row._asdict()
      return tuple(row_dict[field] for field in self.fields)
    raise ValueError("Either fields or condition_value_fn must be specified")

  def __call__(self, request: Union[beam.Row, list[beam.Row]], *args, **kwargs):
    """Fetch enrichment data for a single request or a batch of requests.

    Args:
      request: a single `beam.Row`, or a list of them when batching.

    Returns:
      For a single `beam.Row`: a `(request, response)` tuple. The response is
      built from the first row BigQuery returns, or is an empty `beam.Row()`
      when there is none.
      For a list: a list of `(request, response)` tuples, one per matching
      (request, result row) pair, plus `(request, beam.Row())` for every
      request with no matching result row.
    """
    if isinstance(request, list):
      return self._call_batch(request)
    return self._call_single(request)

  def _call_batch(self, requests: list[beam.Row]):
    """Enrich a batch with one query that ORs each request's condition."""
    if not requests:
      return []
    if self.query_fn:
      raise ValueError(
          "Batched requests are not supported when `query_fn` is provided.")

    values: list[Any] = []
    # Several requests in one batch may share a key, so each key maps to a
    # list of requests; a plain dict of single requests would drop some.
    requests_by_key: dict[Any, list[beam.Row]] = {}
    for req in requests:
      current_values = self._get_condition_values(req)
      values.extend(current_values)
      requests_by_key.setdefault(tuple(current_values), []).append(req)

    raw_query = self.query_template
    if len(requests) > 1:
      batched_condition_template = ' or '.join(
          [f'({self.row_restriction_template})'] * len(requests))
      raw_query = self.query_template.replace(
          self.row_restriction_template, batched_condition_template)
    # NOTE(security): values are spliced into the SQL text, not bound as
    # query parameters; inputs must be trusted.
    query = raw_query.format(*values)

    responses = []
    matched_keys = set()
    for response in self._execute_query(query):
      response_row = beam.Row(**response)
      response_key = self.create_row_key(response_row)
      matched = requests_by_key.get(response_key)
      if matched:
        matched_keys.add(response_key)
        responses.extend((req, response_row) for req in matched)

    # Requests without a matching result row are kept, paired with an empty
    # row, rather than silently dropped from the output.
    for key, reqs in requests_by_key.items():
      if key not in matched_keys:
        responses.extend((req, beam.Row()) for req in reqs)
    return responses

  def _call_single(self, request: beam.Row):
    """Enrich one request using the first row BigQuery returns for it."""
    if self.query_fn:
      # `query_fn` returns the complete SQL query for this request.
      query = self.query_fn(request)
    else:
      # NOTE(security): values are spliced into the SQL text, not bound as
      # query parameters; inputs must be trusted.
      query = self.query_template.format(*self._get_condition_values(request))
    rows = self._execute_query(query)
    response = beam.Row(**rows[0]) if rows else beam.Row()
    return request, response

  def __exit__(self, exc_type, exc_val, exc_tb):
    """Close the BigQuery client created by `__enter__`, if there is one."""
    client = getattr(self, 'client', None)
    if client is not None:
      client.close()

  def get_cache_key(self, request: Union[beam.Row, list[beam.Row]]):
    """Return the cache key for a request, or a list of keys for a batch.

    A key is the request's condition values joined with `;`. With `query_fn`
    there are no condition values, so the generated query string is used
    instead. Otherwise every request would share the same empty key.

    Raises:
      KeyError: if one of `fields` is not a key of a request row.
    """
    if isinstance(request, list):
      return [self._cache_key_for(req) for req in request]
    return self._cache_key_for(request)

  def _cache_key_for(self, row: beam.Row) -> str:
    """Build the cache key for a single request row."""
    if self.query_fn:
      return self.query_fn(row)
    return ";".join(str(value) for value in self._get_condition_values(row))

  def batch_elements_kwargs(self) -> Mapping[str, Any]:
    """Returns a kwargs suitable for `beam.BatchElements`."""
    return self._batching_kwargs