AI Platform integration patterns

This page describes common patterns in pipelines with Google Cloud AI Platform transforms.

Analysing the structure and meaning of text

This section shows how to use the Google Cloud Natural Language API to perform text analysis.

Beam provides a PTransform called AnnotateText. The transform takes a PCollection of type Document. Each Document object contains information about a piece of text: its content, whether it is plain text or HTML, an optional language hint, and other settings. For each document, AnnotateText produces a response object of type AnnotateTextResponse, as returned by the API. AnnotateTextResponse is a protobuf message with many attributes, some of which are complex structures.

Here is an example of a pipeline that creates an in-memory PCollection of strings, converts each string to a Document object, and invokes the Natural Language API. Then, for each response object, a function is called to extract certain results of the analysis.

features = nlp.types.AnnotateTextRequest.Features(
    extract_entities=True,
    extract_document_sentiment=True,
    extract_entity_sentiment=True,
    extract_syntax=True,
)

with beam.Pipeline() as pipeline:
  responses = (
      pipeline
      | beam.Create([
          'My experience so far has been fantastic! '
          'I\'d really recommend this product.'
      ])
      | beam.Map(lambda x: nlp.Document(x, type='PLAIN_TEXT'))
      | nlp.AnnotateText(features))

  _ = (
      responses
      | beam.Map(extract_sentiments)
      | 'Parse sentiments to JSON' >> beam.Map(json.dumps)
      | 'Write sentiments' >> beam.io.WriteToText('sentiments.txt'))

  _ = (
      responses
      | beam.Map(extract_entities)
      | 'Parse entities to JSON' >> beam.Map(json.dumps)
      | 'Write entities' >> beam.io.WriteToText('entities.txt'))

  _ = (
      responses
      | beam.Map(analyze_dependency_tree)
      | 'Parse adjacency list to JSON' >> beam.Map(json.dumps)
      | 'Write adjacency list' >> beam.io.WriteToText('adjacency_list.txt'))

AnnotateTextRequest.Features features =
    AnnotateTextRequest.Features.newBuilder()
        .setExtractEntities(true)
        .setExtractDocumentSentiment(true)
        .setExtractEntitySentiment(true)
        .setExtractSyntax(true)
        .build();
AnnotateText annotateText = AnnotateText.newBuilder().setFeatures(features).build();

PCollection<AnnotateTextResponse> responses =
    p.apply(
            Create.of(
                "My experience so far has been fantastic! "
                    + "I\'d really recommend this product."))
        .apply(
            MapElements.into(TypeDescriptor.of(Document.class))
                .via(
                    (SerializableFunction<String, Document>)
                        input ->
                            Document.newBuilder()
                                .setContent(input)
                                .setType(Document.Type.PLAIN_TEXT)
                                .build()))
        .apply(annotateText);

responses
    .apply(MapElements.into(TypeDescriptor.of(TextSentiments.class)).via(extractSentiments))
    .apply(
        MapElements.into(TypeDescriptors.strings())
            .via((SerializableFunction<TextSentiments, String>) TextSentiments::toJson))
    .apply(TextIO.write().to("sentiments.txt"));

responses
    .apply(
        MapElements.into(
                TypeDescriptors.maps(TypeDescriptors.strings(), TypeDescriptors.strings()))
            .via(extractEntities))
    .apply(MapElements.into(TypeDescriptors.strings()).via(mapEntitiesToJson))
    .apply(TextIO.write().to("entities.txt"));

responses
    .apply(
        MapElements.into(
                TypeDescriptors.lists(
                    TypeDescriptors.maps(
                        TypeDescriptors.strings(),
                        TypeDescriptors.lists(TypeDescriptors.strings()))))
            .via(analyzeDependencyTree))
    .apply(MapElements.into(TypeDescriptors.strings()).via(mapDependencyTreesToJson))
    .apply(TextIO.write().to("adjacency_list.txt"));

Extracting sentiments

This is part of the response object returned from the API. Sentence-level sentiments can be found in the sentences attribute. sentences behaves like a standard Python sequence, so core language features such as iteration and slicing work on it. The overall sentiment can be found in the document_sentiment attribute.

sentences {
  text {
    content: "My experience so far has been fantastic!"
  }
  sentiment {
    magnitude: 0.8999999761581421
    score: 0.8999999761581421
  }
}
sentences {
  text {
    content: "I\'d really recommend this product."
    begin_offset: 41
  }
  sentiment {
    magnitude: 0.8999999761581421
    score: 0.8999999761581421
  }
}

...many lines omitted

document_sentiment {
  magnitude: 1.899999976158142
  score: 0.8999999761581421
}

The function for extracting information about sentence-level and document-level sentiments is shown in the next code snippet.

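def extract_sentiments(response):
  # Map each sentence to its score and add the document-level score.
  return {
      'sentences': [{
          sentence.text.content: sentence.sentiment.score
      } for sentence in response.sentences],
      'document_sentiment': response.document_sentiment.score,
  }
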
extractSentiments =
    (SerializableFunction<AnnotateTextResponse, TextSentiments>)
        annotateTextResponse -> {
          TextSentiments sentiments = new TextSentiments();
          // The score, not the magnitude, is the value shown in the output.
          sentiments.setDocumentSentiment(
              annotateTextResponse.getDocumentSentiment().getScore());
          Map<String, Float> sentenceSentimentsMap =
              annotateTextResponse.getSentencesList().stream()
                  .collect(
                      Collectors.toMap(
                          (Sentence s) -> s.getText().getContent(),
                          (Sentence s) -> s.getSentiment().getScore()));
          sentiments.setSentenceSentiments(sentenceSentimentsMap);
          return sentiments;
        };

The snippet loops over sentences and, for each sentence, extracts the sentiment score.

The output is:

{"sentences": [{"My experience so far has been fantastic!": 0.8999999761581421}, {"I'd really recommend this product.": 0.8999999761581421}], "document_sentiment": 0.8999999761581421}
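
The extraction logic can be tried without calling the API. The following is a minimal sketch that applies the same dictionary-building logic to a stand-in response. The SimpleNamespace objects and the make_sentence helper are hypothetical substitutes that only mimic the attributes used here, and the scores are rounded for brevity.

```python
from types import SimpleNamespace


def extract_sentiments(response):
  """Maps each sentence to its score and adds the document-level score."""
  return {
      'sentences': [{
          sentence.text.content: sentence.sentiment.score
      } for sentence in response.sentences],
      'document_sentiment': response.document_sentiment.score,
  }


def make_sentence(content, score):
  # Hypothetical helper: mimics the attribute layout of a sentence in the
  # response (text.content and sentiment.score).
  return SimpleNamespace(
      text=SimpleNamespace(content=content),
      sentiment=SimpleNamespace(score=score))


# Stand-in for AnnotateTextResponse; this is not a real API response.
fake_response = SimpleNamespace(
    sentences=[
        make_sentence('My experience so far has been fantastic!', 0.9),
        make_sentence("I'd really recommend this product.", 0.9),
    ],
    document_sentiment=SimpleNamespace(score=0.9))

result = extract_sentiments(fake_response)
```

Because the function only reads attributes, any object with the same layout works, which makes the logic easy to unit test before wiring it into a pipeline.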

Extracting entities

The next function inspects the response for entities and returns the names and the types of those entities.

def extract_entities(response):
  # Return the name and the human-readable type of every entity.
  return [{
      'name': entity.name,
      'type': nlp.enums.Entity.Type(entity.type).name,
  } for entity in response.entities]

extractEntities =
(SerializableFunction<AnnotateTextResponse, Map<String, String>>)
    annotateTextResponse ->
        annotateTextResponse.getEntitiesList().stream()
            .collect(
                Collectors.toMap(Entity::getName, (Entity e) -> e.getType().toString()));

Entities can be found in the entities attribute. As before, entities is a sequence, so a list comprehension is a natural choice. The trickiest part is interpreting the entity types. The Natural Language API defines entity types as an enum, and a response object returns them as integers. That is why the code instantiates naturallanguageml.enums.Entity.Type (written as nlp.enums.Entity.Type in the snippet) to access a human-readable name.

The output is:

[{"name": "experience", "type": "OTHER"}, {"name": "product", "type": "CONSUMER_GOOD"}]
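
The integer-to-name lookup can be illustrated with Python's standard enum module. In this sketch, EntityType is a stand-in and not the library's enum; its two numeric values are assumed to match the API's Entity.Type definition and should be checked against the API reference. The response object is likewise a hypothetical substitute.

```python
from enum import IntEnum
from types import SimpleNamespace


class EntityType(IntEnum):
  # Stand-in for nlp.enums.Entity.Type. The numeric values are assumptions.
  CONSUMER_GOOD = 6
  OTHER = 7


def extract_entities(response):
  """Returns the name and the human-readable type of every entity."""
  return [{
      'name': entity.name,
      # Instantiating the enum with the integer gives access to its name.
      'type': EntityType(entity.type).name,
  } for entity in response.entities]


# Stand-in response whose entities carry integer type codes.
fake_response = SimpleNamespace(entities=[
    SimpleNamespace(name='experience', type=7),
    SimpleNamespace(name='product', type=6),
])

entities = extract_entities(fake_response)
```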

Accessing sentence dependency tree

The following code loops over the sentences and, for each sentence, builds an adjacency list that represents a dependency tree. For more information on what a dependency tree is, see Morphology & Dependency Trees.

from collections import defaultdict


def analyze_dependency_tree(response):
  adjacency_lists = []

  # Tokens are listed in document order, so one running index is enough.
  index = 0
  for sentence in response.sentences:
    adjacency_list = defaultdict(list)
    sentence_begin = sentence.text.begin_offset
    sentence_end = sentence_begin + len(sentence.text.content) - 1

    # Consume every token that starts inside the current sentence.
    while index < len(response.tokens) and \
        response.tokens[index].text.begin_offset <= sentence_end:
      token = response.tokens[index]
      head_token_index = token.dependency_edge.head_token_index
      head_token_text = response.tokens[head_token_index].text.content
      adjacency_list[head_token_text].append(token.text.content)
      index += 1
    adjacency_lists.append(adjacency_list)
  return adjacency_lists

analyzeDependencyTree =
    (SerializableFunction<AnnotateTextResponse, List<Map<String, List<String>>>>)
        response -> {
          List<Map<String, List<String>>> adjacencyLists = new ArrayList<>();
          int index = 0;
          for (Sentence s : response.getSentencesList()) {
            Map<String, List<String>> adjacencyMap = new HashMap<>();
            int sentenceBegin = s.getText().getBeginOffset();
            int sentenceEnd = sentenceBegin + s.getText().getContent().length() - 1;
            while (index < response.getTokensCount()
                && response.getTokens(index).getText().getBeginOffset() <= sentenceEnd) {
              Token token = response.getTokensList().get(index);
              int headTokenIndex = token.getDependencyEdge().getHeadTokenIndex();
              String headTokenContent =
                  response.getTokens(headTokenIndex).getText().getContent();
              List<String> adjacencyList =
                  adjacencyMap.getOrDefault(headTokenContent, new ArrayList<>());
              adjacencyList.add(token.getText().getContent());
              adjacencyMap.put(headTokenContent, adjacencyList);
              index++;
            }
            adjacencyLists.add(adjacencyMap);
          }
          return adjacencyLists;
        };

The output is shown below. For readability, token indexes are replaced by the text they refer to:

[
  {
    "experience": [
      "My"
    ],
    "been": [
      "experience",
      "far",
      "has",
      "been",
      "fantastic",
      "!"
    ],
    "far": [
      "so"
    ]
  },
  {
    "recommend": [
      "I",
      "'d",
      "really",
      "recommend",
      "product",
      "."
    ],
    "product": [
      "this"
    ]
  }
]
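
To see how the running token index and the sentence offsets interact, here is a minimal sketch that runs the same algorithm on a hand-written stand-in response. The two short sentences, the token offsets, the head indexes, and the make_token helper are all made up for illustration and are not output of the API.

```python
from collections import defaultdict
from types import SimpleNamespace


def analyze_dependency_tree(response):
  """Builds one head-to-dependents adjacency list per sentence."""
  adjacency_lists = []
  index = 0
  for sentence in response.sentences:
    adjacency_list = defaultdict(list)
    sentence_begin = sentence.text.begin_offset
    sentence_end = sentence_begin + len(sentence.text.content) - 1
    # Consume every token that starts inside the current sentence.
    while (index < len(response.tokens) and
           response.tokens[index].text.begin_offset <= sentence_end):
      token = response.tokens[index]
      head_token_index = token.dependency_edge.head_token_index
      head_token_text = response.tokens[head_token_index].text.content
      adjacency_list[head_token_text].append(token.text.content)
      index += 1
    adjacency_lists.append(dict(adjacency_list))
  return adjacency_lists


def make_token(content, begin_offset, head_token_index):
  # Hypothetical helper that mimics the attributes of a token.
  return SimpleNamespace(
      text=SimpleNamespace(content=content, begin_offset=begin_offset),
      dependency_edge=SimpleNamespace(head_token_index=head_token_index))


def make_sentence(content, begin_offset):
  return SimpleNamespace(
      text=SimpleNamespace(content=content, begin_offset=begin_offset))


# Stand-in for the text "I like tea. Go home."; head indexes refer to
# positions in the document-wide token list, and a root points to itself.
fake_response = SimpleNamespace(
    sentences=[make_sentence('I like tea.', 0), make_sentence('Go home.', 12)],
    tokens=[
        make_token('I', 0, 1),
        make_token('like', 2, 1),
        make_token('tea', 7, 1),
        make_token('.', 10, 1),
        make_token('Go', 12, 4),
        make_token('home', 15, 4),
        make_token('.', 19, 4),
    ])

adjacency_lists = analyze_dependency_tree(fake_response)
```

As in the output above, the root of each sentence appears in its own list because its head index points to itself.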

Getting predictions

This section shows how to use Google Cloud AI Platform Prediction to make predictions about new data from a cloud-hosted machine learning model.

tfx_bsl is a library with a Beam PTransform called RunInference. RunInference can perform inference using an external service endpoint. When using a service endpoint, the transform takes a PCollection of type tf.train.Example and, for every batch of elements, sends a request to AI Platform Prediction. The batch size is computed automatically. For details on how Beam finds the best batch size, refer to the docstring for BatchElements. Currently, the transform does not support tf.train.SequenceExample as input, but work is in progress.

The transform produces a PCollection of type PredictionLog, which contains predictions.

Before getting started, deploy a TensorFlow model to AI Platform Prediction. The cloud service manages the infrastructure needed to handle prediction requests in an efficient and scalable way. Note that the transform supports only TensorFlow models. For more information, see Exporting a SavedModel for prediction.

Once a machine learning model is deployed, prepare a list of instances to get predictions for. To send binary data, make sure that the name of the input ends in _bytes. The data is then base64-encoded before the request is sent.
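
What this encoding amounts to can be sketched with the standard library alone. The helper name encode_instance and the input names image_bytes and label are hypothetical. The {"b64": ...} wrapper follows the JSON convention that AI Platform Prediction documents for binary values; how the transform implements it internally may differ.

```python
import base64
import json


def encode_instance(instance):
  """Wraps binary values of inputs ending in _bytes as base64 strings."""
  encoded = {}
  for name, value in instance.items():
    if name.endswith('_bytes') and isinstance(value, bytes):
      # Raw bytes are not valid JSON, so they travel as base64 text.
      encoded[name] = {'b64': base64.b64encode(value).decode('ascii')}
    else:
      encoded[name] = value
  return encoded


instance = encode_instance({'image_bytes': b'\x00\x01\x02', 'label': 'cat'})

# The request body is a JSON object with a list of instances.
payload = json.dumps({'instances': [instance]})
```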

Example

Here is an example of a pipeline that reads input instances from a file, converts JSON objects to tf.train.Example objects, and sends the data to AI Platform Prediction. The content of the file can look like this:

{"input": "the quick brown"}
{"input": "la bruja le"}

The example creates tf.train.BytesList instances, so it expects byte-like strings as input. However, the transform also supports other data types, such as tf.train.FloatList and tf.train.Int64List.
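
One way to choose among these list types is to dispatch on the Python type of each value. The sketch below deliberately avoids importing TensorFlow and only returns the name of the tf.train.Feature keyword argument that fits a value. The helper feature_kind and the sample field names weight and count are hypothetical.

```python
def feature_kind(value):
  """Returns the tf.train.Feature keyword argument that fits a value."""
  if isinstance(value, (bytes, str)):
    # Strings must be encoded to bytes before they go into a BytesList.
    return 'bytes_list'
  if isinstance(value, float):
    return 'float_list'
  if isinstance(value, int):
    return 'int64_list'
  raise TypeError('Unsupported feature type: %s' % type(value).__name__)


sample = {'input': 'the quick brown', 'weight': 0.5, 'count': 3}
kinds = {name: feature_kind(value) for name, value in sample.items()}
```

With TensorFlow installed, the returned name selects the constructor argument, for example tf.train.Feature(float_list=tf.train.FloatList(value=[0.5])).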

Here is the code:

import json

import apache_beam as beam

import tensorflow as tf
from tfx_bsl.beam.run_inference import RunInference
from tfx_bsl.proto import model_spec_pb2

def convert_json_to_tf_example(json_obj):
  samples = json.loads(json_obj)
  # Collect every field of the JSON object as a bytes feature.
  feature = {}
  for name, text in samples.items():
    feature[name] = tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[text.encode('utf-8')]))
  return tf.train.Example(features=tf.train.Features(feature=feature))


with beam.Pipeline() as p:
  _ = (
      p
      | beam.io.ReadFromText('gs://my-bucket/samples.json')
      | beam.Map(convert_json_to_tf_example)
      | RunInference(
          model_spec_pb2.InferenceEndpoint(
              model_endpoint_spec=model_spec_pb2.AIPlatformPredictionModelSpec(
                  project_id='my-project-id',
                  model_name='my-model-name',
                  version_name='my-model-version'))))

// Getting predictions is not yet available for Java. [https://github.com/apache/beam/issues/20001]