apache_beam.ml.inference.base module
An extensible run inference transform.
Users of this module can extend the ModelHandler class for any machine learning framework. A ModelHandler implementation is a required parameter of RunInference.
The transform handles standard inference functionality, like metric collection, sharing model between threads, and batching elements.
- class apache_beam.ml.inference.base.PredictionResult(example, inference, model_id=None)
Bases: PredictionResult
A NamedTuple containing both input and output from the inference.
- class apache_beam.ml.inference.base.ModelMetadata(model_id, model_name)
Bases: NamedTuple
Create new instance of ModelMetadata(model_id, model_name)
- class apache_beam.ml.inference.base.RunInferenceDLQ(failed_inferences, failed_preprocessing, failed_postprocessing)
Bases: NamedTuple
Create new instance of RunInferenceDLQ(failed_inferences, failed_preprocessing, failed_postprocessing)
- failed_inferences: PCollection
Alias for field number 0
- failed_preprocessing: Sequence[PCollection]
Alias for field number 1
- failed_postprocessing: Sequence[PCollection]
Alias for field number 2
- class apache_beam.ml.inference.base.KeyModelPathMapping(keys: List[KeyT], update_path: str, model_id: str = '')
Bases: Generic[KeyT]
Dataclass for mapping one or more keys to one model path. This is used in conjunction with a KeyedModelHandler that holds many model handlers, to update a set of keys’ model handlers with a new path. Given KeyModelPathMapping(keys=['key1', 'key2'], update_path='updated/path', model_id='id1'), all examples with keys key1 or key2 will have their corresponding model handler’s update_model_path method called with 'updated/path', and their metrics will correspond with 'id1'. For more information, see the KeyedModelHandler documentation https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.KeyedModelHandler and the website section on model updates https://beam.apache.org/documentation/ml/about-ml/#automatic-model-refresh
- class apache_beam.ml.inference.base.ModelHandler
Bases: Generic[ExampleT, PredictionT, ModelT]
Has the ability to load and apply an ML model.
Environment variables are set using a dict named ‘env_vars’ before loading the model. Child classes can accept this dict as a kwarg.
- run_inference(batch: Sequence[ExampleT], model: ModelT, inference_args: Dict[str, Any] | None = None) → Iterable[PredictionT]
Runs inferences on a batch of examples.
- Parameters:
batch – A sequence of examples or features.
model – The model used to make inferences.
inference_args – Extra arguments for models whose inference call requires extra parameters.
- Returns:
An Iterable of Predictions.
- get_num_bytes(batch: Sequence[ExampleT]) → int
- Returns:
The number of bytes of data for a batch.
- get_metrics_namespace() → str
- Returns:
A namespace for metrics collected by the RunInference transform.
- batch_elements_kwargs() → Mapping[str, Any]
- Returns:
kwargs suitable for beam.BatchElements.
- validate_inference_args(inference_args: Dict[str, Any] | None)
Validates inference_args passed in the inference call.
Because most frameworks do not need extra arguments in their predict() call, the default behavior is to error out if inference_args are present.
- update_model_path(model_path: str | None = None)
Update the model path produced by side inputs. update_model_path should be used when a ModelHandler represents a single model, not multiple models. This will be true in most cases. For more information see the website section on model updates https://beam.apache.org/documentation/ml/about-ml/#automatic-model-refresh
- update_model_paths(model: ModelT, model_paths: str | List[KeyModelPathMapping] | None = None)
Update the model paths produced by side inputs. update_model_paths should be used when updating multiple models at once (e.g. when using a KeyedModelHandler that holds multiple models). For more information, see the KeyedModelHandler documentation https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.KeyedModelHandler and the website section on model updates https://beam.apache.org/documentation/ml/about-ml/#automatic-model-refresh
- get_preprocess_fns() → Iterable[Callable[[Any], Any]]
Gets all preprocessing functions to be run before batching/inference. Functions are returned in the order in which they should be applied.
- get_postprocess_fns() → Iterable[Callable[[Any], Any]]
Gets all postprocessing functions to be run after inference. Functions are returned in the order in which they should be applied.
- should_skip_batching() → bool
Whether RunInference’s batching should be skipped. Can be flipped to True by using with_no_batching.
- set_environment_vars()
Sets environment variables using a dictionary provided via kwargs. Keys are the env variable name, and values are the env variable value. Child ModelHandler classes should set _env_vars via kwargs in __init__, or else call super().__init__().
- with_preprocess_fn(fn: Callable[[PreProcessT], ExampleT]) → ModelHandler[PreProcessT, PredictionT, ModelT, PreProcessT]
Returns a new ModelHandler with a preprocessing function associated with it. The preprocessing function will be run before batching/inference and should map your input PCollection to the base ModelHandler’s input type. If you apply multiple preprocessing functions, they will be run on your original PCollection in order from last applied to first applied.
- with_postprocess_fn(fn: Callable[[PredictionT], PostProcessT]) → ModelHandler[ExampleT, PostProcessT, ModelT, PostProcessT]
Returns a new ModelHandler with a postprocessing function associated with it. The postprocessing function will be run after inference and should map the base ModelHandler’s output type to your desired output type. If you apply multiple postprocessing functions, they will be run on your original inference result in order from first applied to last applied.
- with_no_batching() → ModelHandler[Union[ExampleT, Iterable[ExampleT]], PostProcessT, ModelT, PostProcessT]
Returns a new ModelHandler which does not require batching of inputs so that RunInference will skip this step. RunInference will expect the input to be pre-batched and passed in as an Iterable of records. If you skip batching, any preprocessing functions should accept a batch of data, not just a single record.
This option is only recommended if you want to do custom batching yourself. If you just want to pass in records without a batching dimension, it is recommended to (1) add max_batch_size=1 to batch_elements_kwargs and (2) remove the batching dimension as part of your inference call (for example, record = batch[0]).
- share_model_across_processes() → bool
Returns a boolean representing whether or not a model should be shared across multiple processes instead of being loaded per process. This is primarily useful for large models that can’t fit multiple copies in memory. Multi-process support may vary by runner, but this will fall back to loading per process as necessary. See https://beam.apache.org/releases/pydoc/current/apache_beam.utils.multi_process_shared.html
- model_copies() → int
Returns the maximum number of model copies that should be loaded at one time. This only impacts model handlers that are using share_model_across_processes to share their model across processes instead of being loaded per process.
- class apache_beam.ml.inference.base.KeyModelMapping(keys: List[KeyT], mh: ModelHandler[ExampleT, PredictionT, ModelT])
Bases: Generic[KeyT, ExampleT, PredictionT, ModelT]
Dataclass for mapping 1 or more keys to 1 model handler. Given KeyModelMapping(['key1', 'key2'], myMh), all examples with keys key1 or key2 will be run against the model defined by the myMh ModelHandler.
- class apache_beam.ml.inference.base.KeyedModelHandler(unkeyed: ModelHandler[ExampleT, PredictionT, ModelT] | List[KeyModelMapping[KeyT, ExampleT, PredictionT, ModelT]], max_models_per_worker_hint: int | None = None)
Bases: Generic[KeyT, ExampleT, PredictionT, ModelT], ModelHandler[Tuple[KeyT, ExampleT], Tuple[KeyT, PredictionT], ModelT | _ModelManager]
A ModelHandler that takes keyed examples and returns keyed predictions.
For example, if the original model is used with RunInference to take a PCollection[E] to a PCollection[P], this ModelHandler would take a PCollection[Tuple[K, E]] to a PCollection[Tuple[K, P]], making it possible to use the key to associate the outputs with the inputs. KeyedModelHandler is able to accept either a single unkeyed ModelHandler or many different model handlers corresponding to the keys for which that ModelHandler should be used. For example, the following configuration could be used to map keys k1-k3 to mh1 and keys k4-k5 to mh2:
    k1 = ['k1', 'k2', 'k3']
    k2 = ['k4', 'k5']
    KeyedModelHandler([KeyModelMapping(k1, mh1), KeyModelMapping(k2, mh2)])
Note that a single copy of each of these models may be held in memory at the same time; be careful not to load too many large models, or your pipeline may fail with out-of-memory (OOM) errors.
KeyedModelHandlers support Automatic Model Refresh to update your model to a newer version without stopping your streaming pipeline. For an overview of this feature, see https://beam.apache.org/documentation/ml/about-ml/#automatic-model-refresh
To use this feature with a KeyedModelHandler that holds many models (one per set of keys), you can pass in a list of KeyModelPathMapping objects to define your new model paths. For example, passing in the side input of
    [KeyModelPathMapping(keys=['k1', 'k2'], update_path='update/path/1'),
     KeyModelPathMapping(keys=['k3'], update_path='update/path/2')]
will update the model corresponding to keys ‘k1’ and ‘k2’ with path ‘update/path/1’ and the model corresponding to ‘k3’ with ‘update/path/2’. In order to do a side input update: (1) all restrictions mentioned in https://beam.apache.org/documentation/ml/about-ml/#automatic-model-refresh must be met, (2) all update_paths must be non-empty, even if they are not being updated from their original values, and (3) the set of keys originally defined cannot change. This means that if originally you have defined model handlers for ‘key1’, ‘key2’, and ‘key3’, all 3 of those keys must appear in your list of KeyModelPathMappings exactly once. No additional keys can be added.
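Rules (2) and (3) can be checked mechanically before an update is emitted. This is a stdlib-only sketch, not Beam's own validation: PathUpdate is a hypothetical stand-in for KeyModelPathMapping, check_update is a hypothetical helper, and rule (1), the general refresh restrictions, is not modeled.

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class PathUpdate:
    """Hypothetical stand-in for KeyModelPathMapping (keys + update_path)."""
    keys: List[str]
    update_path: str


def check_update(original_keys: Sequence[str], updates: Sequence[PathUpdate]) -> List[str]:
    """Returns the rule violations for a proposed side-input update."""
    problems = []
    seen: List[str] = []
    for u in updates:
        if not u.update_path:  # rule (2): every update_path must be non-empty
            problems.append(f'empty update_path for keys {u.keys}')
        seen.extend(u.keys)
    # Rule (3): every original key appears exactly once, and no keys are added.
    for key in sorted(set(seen)):
        if seen.count(key) > 1:
            problems.append(f'key {key!r} appears more than once')
    missing = set(original_keys) - set(seen)
    extra = set(seen) - set(original_keys)
    if missing:
        problems.append(f'missing keys: {sorted(missing)}')
    if extra:
        problems.append(f'new keys not allowed: {sorted(extra)}')
    return problems


original = ['key1', 'key2', 'key3']
good = check_update(original, [PathUpdate(['key1', 'key2'], 'update/path/1'),
                               PathUpdate(['key3'], 'update/path/2')])
bad = check_update(original, [PathUpdate(['key1'], ''),
                              PathUpdate(['key1', 'key4'], 'update/path/3')])
print(good)  # []
print(bad)
```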
When using many models defined per key, metrics about inference and model loading will be gathered on an aggregate basis for all keys. These will be reported with no prefix. Metrics will also be gathered on a per key basis. Since some keys can share the same model, only one set of metrics will be reported per key ‘cohort’. These will be reported in the form: <cohort_key>-<metric_name>, where <cohort_key> can be any key selected from the cohort. When model updates occur, the metrics will be reported in the form <cohort_key>-<model_id>-<metric_name>.
Loading multiple models at the same time can increase the risk of an out of memory (OOM) exception. To avoid this issue, use the parameter max_models_per_worker_hint to limit the number of models that are loaded at the same time. For more information about memory management, see Use a keyed ModelHandler https://beam.apache.org/documentation/ml/about-ml/#use-a-keyed-modelhandler-object
- Parameters:
unkeyed – Either (a) an implementation of ModelHandler that does not require keys or (b) a list of KeyModelMappings mapping lists of keys to unkeyed ModelHandlers.
max_models_per_worker_hint – A hint to the runner indicating how many models can be held in memory at one time per worker process. For example, if your worker has 8 GB of memory provisioned and your models take up 1 GB each, you should set this to 7 to allow all models to sit in memory with some buffer. For more information about memory management, see Use a keyed ModelHandler https://beam.apache.org/documentation/ml/about-ml/#use-a-keyed-modelhandler-object
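The sizing arithmetic in the 8 GB / 1 GB example can be written as a small helper. models_per_worker_hint is hypothetical, and the 1 GB default headroom is an assumption chosen to reproduce that example; size the buffer for your own worker.

```python
def models_per_worker_hint(worker_memory_gb: float, model_size_gb: float,
                           headroom_gb: float = 1.0) -> int:
    """Hypothetical helper: how many models fit in memory, leaving headroom."""
    usable = worker_memory_gb - headroom_gb
    if usable < model_size_gb:
        raise ValueError('Not even one model fits with the requested headroom.')
    # Whole models only, so round down.
    return int(usable // model_size_gb)


print(models_per_worker_hint(8, 1))   # 7
print(models_per_worker_hint(16, 3))  # 5
```

The result would then be passed to the constructor, for example KeyedModelHandler(mappings, max_models_per_worker_hint=7), where mappings is your list of KeyModelMapping objects.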
- run_inference(batch: Sequence[Tuple[KeyT, ExampleT]], model: ModelT | _ModelManager, inference_args: Dict[str, Any] | None = None) → Iterable[Tuple[KeyT, PredictionT]]
- update_model_paths(model: ModelT | _ModelManager, model_paths: List[KeyModelPathMapping[KeyT]] | None = None)
- class apache_beam.ml.inference.base.MaybeKeyedModelHandler(unkeyed: ModelHandler[ExampleT, PredictionT, ModelT])
Bases: Generic[KeyT, ExampleT, PredictionT, ModelT], ModelHandler[ExampleT | Tuple[KeyT, ExampleT], PredictionT | Tuple[KeyT, PredictionT], ModelT]
A ModelHandler that takes examples that might have keys and returns predictions that might have keys.
For example, if the original model is used with RunInference to take a PCollection[E] to a PCollection[P], this ModelHandler would take either PCollection[E] to a PCollection[P] or PCollection[Tuple[K, E]] to a PCollection[Tuple[K, P]], depending on whether the elements are tuples. This pattern makes it possible to associate the outputs with the inputs based on the key.
Note that you cannot use this ModelHandler if E is a tuple type. In addition, either all examples should be keyed, or none of them.
- Parameters:
unkeyed – An implementation of ModelHandler that does not require keys.
- run_inference(batch: Sequence[ExampleT | Tuple[KeyT, ExampleT]], model: ModelT, inference_args: Dict[str, Any] | None = None) → Iterable[PredictionT] | Iterable[Tuple[KeyT, PredictionT]]
- class apache_beam.ml.inference.base.RunInference(model_handler: ModelHandler[ExampleT, PredictionT, Any], clock=<module 'time' (built-in)>, inference_args: Dict[str, Any] | None = None, metrics_namespace: str | None = None, *, model_metadata_pcoll: PCollection[ModelMetadata] | None = None, watch_model_pattern: str | None = None, model_identifier: str | None = None, **kwargs)
Bases: PTransform[PCollection[ExampleT | Iterable[ExampleT]], PCollection[PredictionT]]
A transform that takes a PCollection of examples (or features) for use on an ML model. The transform then outputs inferences (or predictions) for those examples in a PCollection of PredictionResults that contains the input examples and the output inferences.
Models for supported frameworks can be loaded using a URI. Supported services can also be used.
This transform attempts to batch examples using the beam.BatchElements transform. Batching can be configured using the ModelHandler.
- Parameters:
model_handler – An implementation of ModelHandler.
clock – A clock implementing time_ns. Used for unit testing.
inference_args – Extra arguments for models whose inference call requires extra parameters.
metrics_namespace – Namespace under which the transform’s metrics are collected.
model_metadata_pcoll – PCollection that emits singleton ModelMetadata containing the model path and model name, which is used as a side input to the _RunInferenceDoFn.
watch_model_pattern – A glob pattern used to watch a directory for automatic model refresh.
model_identifier – A string used to identify the model being loaded. You can set this if you want to reuse the same model across multiple RunInference steps and don’t want to load it twice. Note that using the same identifier for different models will lead to non-deterministic results, so exercise caution when using this parameter. This only impacts models that are already being shared across processes.
- classmethod from_callable(model_handler_provider, **kwargs)
Multi-language friendly constructor.
Use this constructor with fully_qualified_named_transform to initialize the RunInference transform from PythonCallableSource provided by foreign SDKs.
- Parameters:
model_handler_provider – A callable object that returns ModelHandler.
kwargs – Keyword arguments for model_handler_provider.
- expand(pcoll: PCollection[ExampleT]) → PCollection[PredictionT]
- with_exception_handling(*, exc_class=<class 'Exception'>, use_subprocess=False, threshold=1, timeout: int | None = None)
Automatically provides a dead letter output for skipping bad records. This can allow a pipeline to continue successfully rather than fail or continuously throw errors on retry when bad elements are encountered.
This returns a tagged output with two PCollections, the first being the results of successfully processing the input PCollection, and the second being the set of bad batches of records (those which threw exceptions during processing) along with information about the errors raised.
For example, one would write:
    main, other = pcoll | RunInference(
        maybe_error_raising_model_handler
    ).with_exception_handling()
and main will be a PCollection of PredictionResults and other will contain a RunInferenceDLQ object with PCollections containing failed records for each failed inference, preprocess operation, or postprocess operation. To access each collection of failed records, one would write:
    failed_inferences = other.failed_inferences
    failed_preprocessing = other.failed_preprocessing
    failed_postprocessing = other.failed_postprocessing
failed_inferences is in the form PCollection[Tuple[failed batch, exception]].
failed_preprocessing is in the form List[PCollection[Tuple[failed record, exception]]], where each element of the list corresponds to a preprocess function. These PCollections are in the same order that the preprocess functions are applied.
failed_postprocessing is in the form List[PCollection[Tuple[failed record, exception]]], where each element of the list corresponds to a postprocess function. These PCollections are in the same order that the postprocess functions are applied.
- Parameters:
exc_class – An exception class, or tuple of exception classes, to catch. Optional, defaults to ‘Exception’.
use_subprocess – Whether to execute the DoFn logic in a subprocess. This allows one to recover from errors that can crash the calling process (e.g. from an underlying library causing a segfault), but is slower as elements and results must cross a process boundary. Note that this starts up a long-running process that is used to handle all the elements (until hard failure, which should be rare) rather than a new process per element, so the overhead should be minimal (and can be amortized if there’s any per-process or per-bundle initialization that needs to be done). Optional, defaults to False.
threshold – An upper bound on the ratio of records that can be bad before aborting the entire pipeline. Optional, defaults to 1.0 (meaning up to 100% of records can be bad and the pipeline will still succeed).
timeout – The maximum amount of time in seconds given to load a model, run inference on a batch of elements, and perform any pre/postprocessing operations. Since the timeout applies in multiple places, it should be equal to the maximum possible timeout for any of these operations. Note in particular that model loading and inference on a single batch count toward the same timeout value. When an inference fails, all related resources, including the model, will be deleted and reloaded. As a result, it is recommended to leave a significant buffer and set the timeout to at least 2 * (time to load model + time to run inference on a batch of data).
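The recommended lower bound for timeout is simple arithmetic; the helper below is hypothetical and rounds up because the parameter is typed as an int number of seconds. The input timings are illustrative, not measurements.

```python
import math


def recommended_timeout_s(model_load_s: float, batch_inference_s: float) -> int:
    """Hypothetical helper: at least 2 * (model load time + one batch's inference time)."""
    # Loading and inference share one timeout budget, so sum them, then double
    # the total for buffer and round up to whole seconds.
    return math.ceil(2 * (model_load_s + batch_inference_s))


print(recommended_timeout_s(30, 5))      # 70
print(recommended_timeout_s(12.5, 3.2))  # 32
```

The value could then be used as RunInference(handler).with_exception_handling(timeout=70), with handler being your ModelHandler.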