#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pytype: skip-file
from apache_beam.coders import typecoders
from apache_beam.coders.coder_impl import LogicalTypeCoderImpl
from apache_beam.coders.coder_impl import RowCoderImpl
from apache_beam.coders.coders import BigEndianShortCoder
from apache_beam.coders.coders import BooleanCoder
from apache_beam.coders.coders import BytesCoder
from apache_beam.coders.coders import Coder
from apache_beam.coders.coders import DecimalCoder
from apache_beam.coders.coders import FastCoder
from apache_beam.coders.coders import FloatCoder
from apache_beam.coders.coders import IterableCoder
from apache_beam.coders.coders import MapCoder
from apache_beam.coders.coders import NullableCoder
from apache_beam.coders.coders import SinglePrecisionFloatCoder
from apache_beam.coders.coders import StrUtf8Coder
from apache_beam.coders.coders import TimestampCoder
from apache_beam.coders.coders import VarIntCoder
from apache_beam.portability import common_urns
from apache_beam.portability.api import schema_pb2
from apache_beam.typehints import row_type
from apache_beam.typehints.schemas import PYTHON_ANY_URN
from apache_beam.typehints.schemas import LogicalType
from apache_beam.typehints.schemas import named_tuple_from_schema
from apache_beam.typehints.schemas import schema_from_element_type
from apache_beam.utils import proto_utils
# Public API of this module: only RowCoder is exported by a wildcard import.
__all__ = ["RowCoder"]
class RowCoder(FastCoder):
  """ Coder for `typing.NamedTuple` instances.

  Implements the beam:coder:row:v1 standard coder spec.
  """
  def __init__(self, schema, force_deterministic=False):
    """Initializes a :class:`RowCoder`.

    Args:
      schema (apache_beam.portability.api.schema_pb2.Schema): The protobuf
        representation of the schema of the data that the RowCoder will be used
        to encode/decode.
      force_deterministic: If truthy, every field coder is replaced by the
        result of its ``as_deterministic_coder`` method, and this value is
        passed to those calls as the step label. Defaults to False.
    """
    self.schema = schema
    # Eagerly generate type hint to escalate any issues with the Schema proto
    self._type_hint = named_tuple_from_schema(self.schema)
    # Use non-null coders because null values are represented separately
    self.components = [
        _nonnull_coder_from_type(field.type) for field in self.schema.fields
    ]
    if force_deterministic:
      self.components = [
          c.as_deterministic_coder(force_deterministic) for c in self.components
      ]
    self.forced_deterministic = bool(force_deterministic)
    # Keep the original (possibly label-carrying) value so that __reduce__
    # can rebuild an identical coder after pickling.
    self._force_deterministic = force_deterministic

  def _create_impl(self):
    return RowCoderImpl(self.schema, self.components)

  def is_deterministic(self):
    """Returns True iff every field coder is deterministic."""
    return all(c.is_deterministic() for c in self.components)

  def as_deterministic_coder(self, step_label, error_message=None):
    """Returns a deterministic version of this coder.

    Returns ``self`` if it is already deterministic. Otherwise it builds a new
    RowCoder over the same schema. ``error_message`` (or, when empty,
    ``step_label``) is passed as that coder's ``force_deterministic`` value.
    """
    if self.is_deterministic():
      return self
    else:
      return RowCoder(self.schema, error_message or step_label)

  def to_type_hint(self):
    """Returns the type hint generated from the schema at construction."""
    return self._type_hint

  def __hash__(self):
    return hash(self.schema.SerializeToString())

  def __eq__(self, other):
    return (
        type(self) == type(other) and self.schema == other.schema and
        self.forced_deterministic == other.forced_deterministic)

  def to_runner_api_parameter(self, unused_context):
    # Only the schema is part of the portable representation; there are no
    # component coders in the payload.
    return (common_urns.coders.ROW.urn, self.schema, [])

  @staticmethod
  @Coder.register_urn(common_urns.coders.ROW.urn, schema_pb2.Schema)
  def from_runner_api_parameter(schema, components, unused_context):
    return RowCoder(schema)

  @classmethod
  def from_type_hint(cls, type_hint, registry):
    """Builds a RowCoder from a type hint via ``schema_from_element_type``.

    A string hint is looked up by name in ``__main__``. If no such attribute
    exists, the string itself is used.
    """
    # TODO(https://github.com/apache/beam/issues/21541): Remove once all
    # runners are portable.
    if isinstance(type_hint, str):
      import importlib
      main_module = importlib.import_module('__main__')
      type_hint = getattr(main_module, type_hint, type_hint)
    schema = schema_from_element_type(type_hint)
    return cls(schema)

  @staticmethod
  def from_payload(payload: bytes, force_deterministic=False) -> 'RowCoder':
    """Builds a RowCoder from a serialized ``schema_pb2.Schema``.

    Args:
      payload: The serialized schema proto.
      force_deterministic: Forwarded to :class:`RowCoder`. Defaults to False,
        so single-argument callers behave exactly as before.
    """
    return RowCoder(
        proto_utils.parse_Bytes(payload, schema_pb2.Schema),
        force_deterministic)

  def __reduce__(self):
    # When pickling, use the bytes representation of the schema, because
    # schema_pb2.Schema objects cannot be pickled. The force_deterministic
    # value travels with it. Without it, an unpickled coder would silently
    # revert to non-deterministic components and compare unequal to self.
    return (
        RowCoder.from_payload,
        (self.schema.SerializeToString(), self._force_deterministic))
# Register RowCoder as the coder for both row type constraint classes.
for _row_constraint in (row_type.RowTypeConstraint,
                        row_type.GeneratedClassRowTypeConstraint):
  typecoders.registry.register_coder(_row_constraint, RowCoder)
del _row_constraint
def _coder_from_type(field_type):
  """Returns the coder for ``field_type``, honouring its nullability.

  The non-null coder is wrapped in a NullableCoder when
  ``field_type.nullable`` is set.
  """
  base_coder = _nonnull_coder_from_type(field_type)
  return NullableCoder(base_coder) if field_type.nullable else base_coder
def _nonnull_coder_from_type(field_type):
  """Returns a coder for ``field_type`` without regard to its nullability.

  Args:
    field_type (schema_pb2.FieldType): The schema field type to encode.

  Raises:
    ValueError: If the field type has no coder in the Python SDK.
  """
  kind = field_type.WhichOneof("type_info")

  if kind == "atomic_type":
    # Each supported atomic type maps to a coder class; a fresh instance is
    # constructed on every call.
    atomic_coder_classes = {
        schema_pb2.INT32: VarIntCoder,
        schema_pb2.INT64: VarIntCoder,
        schema_pb2.INT16: BigEndianShortCoder,
        schema_pb2.FLOAT: SinglePrecisionFloatCoder,
        schema_pb2.DOUBLE: FloatCoder,
        schema_pb2.STRING: StrUtf8Coder,
        schema_pb2.BOOLEAN: BooleanCoder,
        schema_pb2.BYTES: BytesCoder,
    }
    coder_class = atomic_coder_classes.get(field_type.atomic_type)
    if coder_class is not None:
      return coder_class()

  elif kind == "array_type":
    element_type = field_type.array_type.element_type
    return IterableCoder(_coder_from_type(element_type))

  elif kind == "map_type":
    map_type = field_type.map_type
    key_coder = _coder_from_type(map_type.key_type)
    value_coder = _coder_from_type(map_type.value_type)
    return MapCoder(key_coder, value_coder)

  elif kind == "logical_type":
    logical_proto = field_type.logical_type
    urn = logical_proto.urn
    if urn == PYTHON_ANY_URN:
      # The Any logical type: use the registry's default coder for an
      # arbitrary Python object.
      return typecoders.registry.get_coder(object)
    if urn == common_urns.millis_instant.urn:
      # Java's millis Instant is encoded as a fixed-length 8-byte big-endian
      # long, hence TimestampCoder rather than a VarInt coder.
      return TimestampCoder()
    if urn == 'beam:logical_type:decimal:v1':
      return DecimalCoder()
    logical_type = LogicalType.from_runner_api(logical_proto)
    representation_coder = _coder_from_type(logical_proto.representation)
    return LogicalTypeCoder(logical_type, representation_coder)

  elif kind == "row_type":
    return RowCoder(field_type.row_type.schema)

  # The Java SDK supports several more types, but the coders are not yet
  # standard, and are not implemented in Python.
  raise ValueError(
      "Encountered a type that is not currently supported by RowCoder: %s" %
      field_type)
class LogicalTypeCoder(FastCoder):
  """Coder for values of a schema logical type.

  Encoding is delegated to ``LogicalTypeCoderImpl``, which receives both the
  logical type and the coder for its representation type.
  """
  def __init__(self, logical_type, representation_coder):
    """Initializes a :class:`LogicalTypeCoder`.

    Args:
      logical_type (apache_beam.typehints.schemas.LogicalType): The logical
        type whose values this coder handles.
      representation_coder (Coder): Coder for the logical type's
        representation values.
    """
    self.logical_type = logical_type
    self.representation_coder = representation_coder

  def _create_impl(self):
    return LogicalTypeCoderImpl(self.logical_type, self.representation_coder)

  def is_deterministic(self):
    # Determinism is decided entirely by the representation coder.
    return self.representation_coder.is_deterministic()

  def as_deterministic_coder(self, step_label, error_message=None):
    """Returns a deterministic version of this coder, if possible.

    Determinism is already delegated to the representation coder, so this
    method converts that coder as well. Without this override, a RowCoder
    built with ``force_deterministic`` could not convert a logical-type field
    whose representation coder is convertible. Any error raised by the
    representation coder's ``as_deterministic_coder`` propagates.
    """
    if self.is_deterministic():
      return self
    return LogicalTypeCoder(
        self.logical_type,
        self.representation_coder.as_deterministic_coder(
            step_label, error_message))

  def to_type_hint(self):
    return self.logical_type.language_type()