apache_beam.ml.transforms.embeddings.huggingface module

class apache_beam.ml.transforms.embeddings.huggingface.SentenceTransformerEmbeddings(model_name: str, columns: list[str], max_seq_length: int | None = None, image_model: bool = False, **kwargs)

Bases: EmbeddingsManager

Embedding config for sentence-transformers models. Use it with MLTransform to embed text data, or image data when image_model is set. Models are loaded and run by the RunInference PTransform through a ModelHandler.

Parameters:
  • model_name – Name of the model to use. The model should be hosted on the HuggingFace Hub or be compatible with sentence_transformers. For a list of available sentence_transformers image embedding models, see https://www.sbert.net/docs/sentence_transformer/pretrained_models.html#image-text-models.

  • columns – List of columns to be embedded.

  • max_seq_length – Maximum sequence length to use for the model, if applicable.

  • image_model – Whether the model generates image embeddings.

  • min_batch_size – The minimum batch size to be used for inference.

  • max_batch_size – The maximum batch size to be used for inference.

  • large_model – Whether to share the model across processes.

get_model_handler()
get_ptransform_for_processing(**kwargs) → PTransform
class apache_beam.ml.transforms.embeddings.huggingface.InferenceAPIEmbeddings(hf_token: str | None, columns: list[str], model_name: str | None = None, api_url: str | None = None, **kwargs)

Bases: EmbeddingsManager

Embedding config that performs feature extraction through HuggingFace’s Inference API. It supports the feature-extraction task only; for other tasks, refer to https://huggingface.co/inference-api.

Parameters:
  • hf_token – HuggingFace access token used to authenticate requests to the Inference API.

  • columns – List of columns to be embedded.

  • model_name – Model name used for feature extraction.

  • api_url – API URL for feature extraction. If specified, model_name is ignored. If None, the default feature-extraction URL is used.

get_token()
property api_url
property authorization_token
get_model_handler()
get_ptransform_for_processing(**kwargs) → PTransform