#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterable
from typing import List
from typing import Optional
import apache_beam as beam
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text # required to register TF ops. # pylint: disable=unused-import
from apache_beam.ml.inference import utils
from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor
from apache_beam.ml.inference.tensorflow_inference import default_tensor_inference_fn
from apache_beam.ml.transforms.base import EmbeddingsManager
from apache_beam.ml.transforms.base import _ImageEmbeddingHandler
from apache_beam.ml.transforms.base import _TextEmbeddingHandler
__all__ = ['TensorflowHubTextEmbeddings', 'TensorflowHubImageEmbeddings']
# TODO: https://github.com/apache/beam/issues/30288
# Replace with TFModelHandlerTensor when load_model() supports TFHUB models.
class _TensorflowHubModelHandler(TFModelHandlerTensor):
  """
  Model handler that loads TensorFlow Hub models via ``hub.KerasLayer``.

  When a ``preprocessing_url`` is supplied, every batch is first passed
  through that TF Hub preprocessing layer and the ``pooled_output`` of the
  main model is returned as the embedding. Without it, inference is delegated
  to ``default_tensor_inference_fn``.

  Note: Intended for internal use only. No backwards compatibility guarantees.
  """
  def __init__(self, preprocessing_url: Optional[str], *args, **kwargs):
    """
    Args:
      preprocessing_url: URL of the TF Hub preprocessing model. A falsy value
        (None or an empty string) disables preprocessing.
      *args: Forwarded unchanged to ``TFModelHandlerTensor.__init__``.
      **kwargs: Forwarded unchanged to ``TFModelHandlerTensor.__init__``.
    """
    self.preprocessing_url = preprocessing_url
    # Preprocessing layer cache; populated on first use by
    # _get_preprocessor() so it is not rebuilt for every batch.
    self._preprocessor_fn = None
    super().__init__(*args, **kwargs)

  def __getstate__(self):
    # Never serialize the cached preprocessing layer together with the
    # handler; it is recreated on demand from preprocessing_url.
    state = self.__dict__.copy()
    state['_preprocessor_fn'] = None
    return state

  def load_model(self):
    """Returns the model at ``self._model_uri`` wrapped in a hub.KerasLayer."""
    # unable to load the models with tf.keras.models.load_model so
    # using hub.KerasLayer instead
    # NOTE(review): _model_uri and _load_model_args are presumably set by
    # TFModelHandlerTensor.__init__ - confirm against the base class.
    return hub.KerasLayer(self._model_uri, **self._load_model_args)

  def _get_preprocessor(self):
    """Returns the preprocessing layer, creating it on first use."""
    # getattr() keeps this working for instances restored from a state that
    # predates the cache attribute.
    preprocessor_fn = getattr(self, '_preprocessor_fn', None)
    if preprocessor_fn is None:
      preprocessor_fn = hub.KerasLayer(self.preprocessing_url)
      self._preprocessor_fn = preprocessor_fn
    return preprocessor_fn

  def _convert_prediction_result_to_list(
      self, predictions: Iterable[PredictionResult]):
    """
    Converts each prediction's ``inference`` value to a plain Python list.

    Args:
      predictions: The PredictionResult objects to convert.

    Returns:
      A list holding ``inference.numpy().tolist()`` for every prediction, in
      the order the predictions were given.
    """
    return [
        prediction.inference.numpy().tolist() for prediction in predictions
    ]

  def run_inference(self, batch, model, inference_args, model_id=None):
    """
    Runs the model on a batch and returns the results as Python lists.

    Args:
      batch: The elements to run inference on.
      model: The loaded model; it is called directly on the preprocessed batch
        when a preprocessing URL is configured.
      inference_args: Extra arguments handed to ``default_tensor_inference_fn``
        when no preprocessing URL is configured. A falsy value is replaced by
        an empty dict. They are not forwarded to the model in the
        preprocessing branch.
      model_id: Optional model identifier attached to the results.

    Returns:
      A list with one entry per prediction, each converted with
      ``_convert_prediction_result_to_list``.
    """
    if not inference_args:
      inference_args = {}
    if not self.preprocessing_url:
      predictions = default_tensor_inference_fn(
          model=model,
          batch=batch,
          inference_args=inference_args,
          model_id=model_id)
      return self._convert_prediction_result_to_list(predictions)

    vectorized_batch = tf.stack(batch, axis=0)
    # Reuse the cached preprocessing layer instead of building a new
    # hub.KerasLayer for every batch.
    preprocessor_fn = self._get_preprocessor()
    vectorized_batch = preprocessor_fn(vectorized_batch)
    predictions = model(vectorized_batch)
    # https://www.tensorflow.org/text/tutorials/classify_text_with_bert#using_the_bert_model # pylint: disable=line-too-long
    # pooled_output -> represents the text as a whole. This is an embeddings
    # of the whole text. The shape is [batch_size, embedding_dimension]
    # sequence_output -> represents the text as a sequence of tokens. This is
    # an embeddings of each token in the text. The shape is
    # [batch_size, max_sequence_length, embedding_dimension]
    # pooled output is the embeddings as per the documentation. so let's use
    # that.
    embeddings = predictions['pooled_output']
    predictions = utils._convert_to_result(batch, embeddings, model_id)
    return self._convert_prediction_result_to_list(predictions)
class TensorflowHubTextEmbeddings(EmbeddingsManager):
  def __init__(
      self,
      columns: List[str],
      hub_url: str,
      preprocessing_url: Optional[str] = None,
      **kwargs):
    """
    Text embedding configuration backed by a TensorFlow Hub model. Use it
    with MLTransform to embed text data; the model is loaded by the
    RunInference PTransform through a ModelHandler.

    Args:
      columns: The columns containing the text to be embedded.
      hub_url: The url of the tensorflow hub model.
      preprocessing_url: Optional url of a preprocessing model. When given,
        the text is run through the preprocessing model before it is fed to
        the main model.
      **kwargs: Forwarded to EmbeddingsManager together with ``columns``.
        Options include:
          min_batch_size: The minimum batch size to be used for inference.
          max_batch_size: The maximum batch size to be used for inference.
          large_model: Whether to share the model across processes.
    """
    super().__init__(columns=columns, **kwargs)
    self.model_uri, self.preprocessing_url = hub_url, preprocessing_url

  def get_model_handler(self) -> ModelHandler:
    """
    Returns a _TensorflowHubModelHandler configured from this object's model
    url, preprocessing url, batch sizes and large_model flag.
    """
    # Use the TF Hub specific handler so its inference logic replaces the
    # default inference function.
    handler_config = {
        'model_uri': self.model_uri,
        'preprocessing_url': self.preprocessing_url,
        'min_batch_size': self.min_batch_size,
        'max_batch_size': self.max_batch_size,
        'large_model': self.large_model,
    }
    return _TensorflowHubModelHandler(**handler_config)

  def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform:
    """
    Returns a RunInference object that is used to run inference on the text
    input using _TextEmbeddingHandler.
    """
    text_handler = _TextEmbeddingHandler(self)
    return RunInference(
        model_handler=text_handler, inference_args=self.inference_args)
class TensorflowHubImageEmbeddings(EmbeddingsManager):
  def __init__(self, columns: List[str], hub_url: str, **kwargs):
    """
    Image embedding configuration backed by a TensorFlow Hub model. Use it
    with MLTransform to embed image data; the model is loaded by the
    RunInference PTransform through a ModelHandler.

    Args:
      columns: The columns containing the images to be embedded.
      hub_url: The url of the tensorflow hub model.
      **kwargs: Forwarded to EmbeddingsManager together with ``columns``.
        Options include:
          min_batch_size: The minimum batch size to be used for inference.
          max_batch_size: The maximum batch size to be used for inference.
          large_model: Whether to share the model across processes.
    """
    super().__init__(columns=columns, **kwargs)
    self.model_uri = hub_url

  def get_model_handler(self) -> ModelHandler:
    """
    Returns a _TensorflowHubModelHandler configured from this object's model
    url, batch sizes and large_model flag, with no preprocessing model.
    """
    # Use the TF Hub specific handler so its inference logic replaces the
    # default inference function.
    handler_config = {
        'model_uri': self.model_uri,
        'preprocessing_url': None,
        'min_batch_size': self.min_batch_size,
        'max_batch_size': self.max_batch_size,
        'large_model': self.large_model,
    }
    return _TensorflowHubModelHandler(**handler_config)