#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Helper to render pipeline graph in IPython when running interactively.
This module is experimental. No backwards-compatibility guarantees.
"""
# pytype: skip-file
import re
from apache_beam.runners.interactive.display import pipeline_graph
[docs]
def nice_str(o):
  """Returns a short, printable-ASCII rendering of ``repr(o)``.

  The result never contains a double quote or a backslash, consists only of
  printable ASCII characters, and is at most 38 characters long (35
  characters of content plus a trailing '...').

  Args:
    o: any object; only its ``repr`` is used.

  Returns:
    (str) the sanitized and possibly truncated representation.
  """
  s = repr(o)
  # Double quotes become single quotes and backslashes become pipes --
  # presumably so the text can be embedded in a double-quoted DOT attribute
  # without needing any escaping (confirm against the renderer).
  s = s.replace('"', "'")
  s = s.replace('\\', '|')
  # Blank out everything that is not printable ASCII. The upper bound is
  # \x7E ('~'): \x7F is the DEL control character and must be replaced too.
  s = re.sub(r'[^\x20-\x7E]', ' ', s)
  # Sanity check of the invariant established by the replacements above.
  assert '"' not in s
  if len(s) > 35:
    s = s[:35] + '...'
  return s
[docs]
class InteractivePipelineGraph(pipeline_graph.PipelineGraph):
  """DOT representation of a pipeline that was run interactively.

  The base graph is built with gray defaults; afterwards top level transforms
  and their output PCollections are recolored according to whether they were
  required, referenced or served from cache.

  NOTE(review): the original author describes this class as thread-safe. No
  synchronization is done in this class itself, so that guarantee would have
  to come from PipelineGraph -- confirm there.
  """
  def __init__(
      self,
      pipeline,
      required_transforms=None,
      referenced_pcollections=None,
      cached_pcollections=None):
    """Constructor of InteractivePipelineGraph.

    Args:
      pipeline: (Pipeline proto) or (Pipeline) pipeline to be rendered.
      required_transforms: (list/set of str) IDs of the top level PTransforms
        that lead to visible results. Empty when omitted.
      referenced_pcollections: (list/set of str) IDs of the PCollections
        referenced by the executed top level PTransforms (i.e.
        required_transforms). Empty when omitted.
      cached_pcollections: (set of str) IDs of the PCollections whose cached
        results are used in the execution. Empty when omitted.
    """
    # These are assigned before the base constructor is invoked.
    self._required_transforms = (
        required_transforms if required_transforms else set())
    self._referenced_pcollections = (
        referenced_pcollections if referenced_pcollections else set())
    self._cached_pcollections = (
        cached_pcollections if cached_pcollections else set())

    # Everything starts out gray; the update below recolors what matters.
    super().__init__(
        pipeline=pipeline,
        default_vertex_attrs={
            'color': 'gray', 'fontcolor': 'gray'
        },
        default_edge_attrs={'color': 'gray'})

    vertex_updates, edge_updates = self._generate_graph_update_dicts()
    self._update_graph(vertex_updates, edge_updates)

  def update_pcollection_stats(self, pcollection_stats):
    """Relabels PCollection edges from sampled contents.

    For every PCollection whose sample is non-empty, the edge gets a 'label'
    built by format_sample(sample, 1) and a 'labeltooltip' built by
    format_sample(sample, 10). An empty sample yields the label '?' and no
    tooltip entry.

    Args:
      pcollection_stats: (dict of dict) maps PCollection IDs to a dict of
        information about that PCollection. Only the 'sample' entry is read,
        and it is read with a plain subscript, so it must be present. It is
        expected to hold the PCollection result as a list.
    """
    # format_sample is not defined in the visible part of this module; it is
    # expected to be provided elsewhere in the file.
    label_updates = {}
    for pcoll_id, stats in pcollection_stats.items():
      sample = stats['sample']
      if not sample:
        label_updates[pcoll_id] = {'label': '?'}
        continue
      label_updates[pcoll_id] = {
          'label': format_sample(sample, 1),
          'labeltooltip': format_sample(sample, 10),
      }
    self._update_graph(edge_dict=label_updates)

  def _generate_graph_update_dicts(self):
    """Generates the graph updates specific to an interactive pipeline.

    Walks the top level transforms and records, per transform, whether it is
    required and, per output PCollection, whether it is cached and whether it
    is referenced. Those properties are then turned into DOT attributes.

    Returns:
      vertex_dict: (Dict[str, Dict[str, str]]) maps vertex name (the
        transform's unique_name) to DOT attributes.
      edge_dict: (Dict[str, Dict[str, str]]) maps edge name (the PCollection
        ID) to DOT attributes.
    """
    def to_vertex_attrs(props):
      """Converts PTransform properties to DOT vertex attributes."""
      # A 'leaf' key (any value) hides the vertex. The property dicts built
      # in this method never contain it.
      if 'leaf' in props:
        return {'style': 'invis'}
      if props.get('required'):
        return {'color': 'blue', 'fontcolor': 'blue'}
      return {'color': 'grey'}

    def to_edge_attrs(props):
      """Converts PCollection properties to DOT edge attributes."""
      # 'cached' takes precedence over 'referenced'.
      if props.get('cached'):
        color = 'red'
      elif props.get('referenced'):
        color = 'black'
      else:
        color = 'grey'
      return {'color': color}

    transform_props = {}  # PTransform unique names -> properties
    pcoll_props = {}  # PCollection IDs -> properties
    for transform_id, transform_proto in self._top_level_transforms():
      is_required = transform_id in self._required_transforms
      transform_props[transform_proto.unique_name] = {'required': is_required}
      for pcoll_id in transform_proto.outputs.values():
        pcoll_props[pcoll_id] = {
            'cached': pcoll_id in self._cached_pcollections,
            'referenced': pcoll_id in self._referenced_pcollections
        }

    vertex_dict = {
        name: to_vertex_attrs(props)
        for name, props in transform_props.items()
    }
    edge_dict = {
        pcoll_id: to_edge_attrs(props)
        for pcoll_id, props in pcoll_props.items()
    }
    return vertex_dict, edge_dict