#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pytype: skip-file
from apache_beam.io.aws.clients.s3 import messages
from apache_beam.options import pipeline_options
from apache_beam.utils import retry
try:
# pylint: disable=wrong-import-order, wrong-import-position
# pylint: disable=ungrouped-imports
import boto3
except ImportError:
boto3 = None
[docs]
def get_http_error_code(exc):
  """Extracts the HTTP status code carried by an exception, if any.

  The code is read from ``exc.response['ResponseMetadata']['HTTPStatusCode']``.
  This helper is called from ``except`` handlers, so it must never raise
  itself: an exception whose ``response`` attribute is missing, ``None`` or
  not a mapping simply yields ``None`` instead of masking the original error
  with an ``AttributeError``.

  Args:
    exc: the exception object to inspect.

  Returns:
    The value stored under ``HTTPStatusCode``, or None when the exception
    does not carry that information.
  """
  response = getattr(exc, 'response', None)
  try:
    # ``or {}`` covers a 'ResponseMetadata' key that is present but None.
    metadata = response.get('ResponseMetadata') or {}
    return metadata.get('HTTPStatusCode')
  except AttributeError:
    # ``response`` (or its metadata) does not behave like a mapping.
    return None
[docs]
class Client(object):
  """Wrapper for boto3 library.

  Every public method takes a request object, forwards it to the underlying
  boto3 S3 client and re-raises any failure as ``messages.S3ClientError``
  together with the HTTP status code found by ``get_http_error_code``.
  """
  def __init__(self, options):
    """Builds the boto3 S3 client from the given options.

    Args:
      options: either a ``pipeline_options.PipelineOptions`` instance (read
        through its ``S3Options`` view) or an object exposing ``get`` that
        holds the same ``s3_*`` keys.
    """
    assert boto3 is not None, 'Missing boto3 requirement'
    if isinstance(options, pipeline_options.PipelineOptions):
      s3_options = options.view_as(pipeline_options.S3Options)
      access_key_id = s3_options.s3_access_key_id
      secret_access_key = s3_options.s3_secret_access_key
      session_token = s3_options.s3_session_token
      endpoint_url = s3_options.s3_endpoint_url
      use_ssl = not s3_options.s3_disable_ssl
      region_name = s3_options.s3_region_name
      api_version = s3_options.s3_api_version
      verify = s3_options.s3_verify
    else:
      access_key_id = options.get('s3_access_key_id')
      secret_access_key = options.get('s3_secret_access_key')
      session_token = options.get('s3_session_token')
      endpoint_url = options.get('s3_endpoint_url')
      use_ssl = not options.get('s3_disable_ssl', False)
      region_name = options.get('s3_region_name')
      api_version = options.get('s3_api_version')
      verify = options.get('s3_verify')

    session = boto3.session.Session()
    self.client = session.client(
        service_name='s3',
        region_name=region_name,
        api_version=api_version,
        use_ssl=use_ssl,
        verify=verify,
        endpoint_url=endpoint_url,
        aws_access_key_id=access_key_id,
        aws_secret_access_key=secret_access_key,
        aws_session_token=session_token)

    # State of the download stream that get_stream/get_range reuse between
    # calls: the request it was opened for, the stream itself, and the offset
    # the next read will start at.
    self._download_request = None
    self._download_stream = None
    self._download_pos = 0

  def get_stream(self, request, start):
    """Opens a stream object starting at the given position.

    The previously opened stream is reused when it was opened for the same
    bucket and object and is positioned at ``start``; otherwise it is closed
    and a new ranged ``get_object`` call is issued.

    Args:
      request: (GetRequest) request
      start: (int) start offset

    Returns:
      (Stream) Boto3 stream object.

    Raises:
      messages.S3ClientError: if opening the stream fails.
    """
    if self._download_request and (
        start != self._download_pos or
        request.bucket != self._download_request.bucket or
        request.object != self._download_request.object):
      # The stream can already be None while the request is still cached:
      # that is the state left behind when a previous reopen attempt failed.
      # Guard the close so the next call reopens instead of crashing.
      if self._download_stream is not None:
        self._download_stream.close()
      self._download_stream = None

    # noinspection PyProtectedMember
    if not self._download_stream or self._download_stream._raw_stream.closed:
      try:
        self._download_stream = self.client.get_object(
            Bucket=request.bucket,
            Key=request.object,
            Range='bytes={}-'.format(start))['Body']
        self._download_request = request
        self._download_pos = start
      except Exception as e:
        raise messages.S3ClientError(str(e), get_http_error_code(e))

    return self._download_stream

  @retry.with_exponential_backoff()
  def get_range(self, request, start, end):
    """Retrieves an object's contents.

    Reads ``end - start`` bytes from the stream returned by ``get_stream``
    and advances the remembered download position by the number of bytes
    actually read. A first failure is retried once immediately with a fresh
    stream; a second failure is raised to the caller.

    Args:
      request: (GetRequest) request
      start: (int) start offset
      end: (int) end offset (exclusive)

    Returns:
      (bytes) The response message.

    Raises:
      messages.S3ClientError: if both attempts fail.
    """
    for i in range(2):
      try:
        stream = self.get_stream(request, start)
        data = stream.read(end - start)
        self._download_pos += len(data)
        return data
      except Exception as e:
        # Drop the cached stream so the next attempt opens a new one.
        self._download_stream = None
        self._download_request = None
        if i == 0:
          # Read errors are likely with long-lived connections, retry
          # immediately if a read fails once
          continue
        if isinstance(e, messages.S3ClientError):
          raise e
        raise messages.S3ClientError(str(e), get_http_error_code(e))

  def list(self, request):
    """Retrieves a list of objects matching the criteria.

    Args:
      request: (ListRequest) input message

    Returns:
      (ListResponse) The response message.

    Raises:
      messages.S3ClientError: if the listing call fails, or with code 404
        when the response reports a ``KeyCount`` of zero.
    """
    kwargs = {'Bucket': request.bucket, 'Prefix': request.prefix}
    if request.continuation_token is not None:
      kwargs['ContinuationToken'] = request.continuation_token

    try:
      boto_response = self.client.list_objects_v2(**kwargs)
    except Exception as e:
      raise messages.S3ClientError(str(e), get_http_error_code(e))

    if boto_response['KeyCount'] == 0:
      message = 'Tried to list nonexistent S3 path: s3://%s/%s' % (
          request.bucket, request.prefix)
      raise messages.S3ClientError(message, 404)

    items = [
        messages.Item(
            etag=content['ETag'],
            key=content['Key'],
            last_modified=content['LastModified'],
            size=content['Size']) for content in boto_response['Contents']
    ]

    # The token is absent from the response when there is no further page.
    next_token = boto_response.get('NextContinuationToken')
    return messages.ListResponse(items, next_token)

  def create_multipart_upload(self, request):
    """Initiates a multipart upload to S3 for a given object.

    Args:
      request: (UploadRequest) input message

    Returns:
      (UploadResponse) The response message.

    Raises:
      messages.S3ClientError: if the call fails.
    """
    try:
      boto_response = self.client.create_multipart_upload(
          Bucket=request.bucket,
          Key=request.object,
          ContentType=request.mime_type)
      response = messages.UploadResponse(boto_response['UploadId'])
    except Exception as e:
      raise messages.S3ClientError(str(e), get_http_error_code(e))
    return response

  def upload_part(self, request):
    """Uploads part of a file to S3 during a multipart upload.

    Args:
      request: (UploadPartRequest) input message

    Returns:
      (UploadPartResponse) The response message.

    Raises:
      messages.S3ClientError: if the call fails.
    """
    try:
      boto_response = self.client.upload_part(
          Body=request.bytes,
          Bucket=request.bucket,
          Key=request.object,
          PartNumber=request.part_number,
          UploadId=request.upload_id)
      response = messages.UploadPartResponse(
          boto_response['ETag'], request.part_number)
      return response
    except Exception as e:
      raise messages.S3ClientError(str(e), get_http_error_code(e))

  def complete_multipart_upload(self, request):
    """Completes a multipart upload to S3.

    Args:
      request: input message providing ``bucket``, ``object``, ``upload_id``
        and ``parts``.

    Returns:
      (Void) The response message.

    Raises:
      messages.S3ClientError: if the call fails.
    """
    parts = {'Parts': request.parts}
    try:
      self.client.complete_multipart_upload(
          Bucket=request.bucket,
          Key=request.object,
          UploadId=request.upload_id,
          MultipartUpload=parts)
    except Exception as e:
      raise messages.S3ClientError(str(e), get_http_error_code(e))

  def delete(self, request):
    """Deletes given object from bucket.

    Args:
      request: (DeleteRequest) input message

    Returns:
      (void) Void, otherwise will raise if an error occurs

    Raises:
      messages.S3ClientError: if the call fails.
    """
    try:
      self.client.delete_object(Bucket=request.bucket, Key=request.object)
    except Exception as e:
      raise messages.S3ClientError(str(e), get_http_error_code(e))

  def delete_batch(self, request):
    """Deletes several objects from one bucket in a single call.

    Args:
      request: input message providing ``bucket`` and ``objects``, the
        latter being an iterable of object keys.

    Returns:
      A ``messages.DeleteBatchResponse`` built from the keys reported as
      deleted, the keys reported as failed, and one
      ``messages.S3ClientError`` per reported failure.

    Raises:
      messages.S3ClientError: if the call itself fails.
    """
    aws_request = {
        'Bucket': request.bucket,
        'Delete': {
            'Objects': [{
                'Key': key
            } for key in request.objects]
        }
    }

    try:
      aws_response = self.client.delete_objects(**aws_request)
    except Exception as e:
      raise messages.S3ClientError(str(e), get_http_error_code(e))

    failures = aws_response.get('Errors', [])
    deleted = [obj['Key'] for obj in aws_response.get('Deleted', [])]
    failed = [obj['Key'] for obj in failures]
    errors = [
        messages.S3ClientError(obj['Message'], obj['Code'])
        for obj in failures
    ]
    return messages.DeleteBatchResponse(deleted, failed, errors)

  def copy(self, request):
    """Copies an object to a destination bucket and key.

    Args:
      request: input message providing ``src_bucket``, ``src_key``,
        ``dest_bucket`` and ``dest_key``.

    Raises:
      messages.S3ClientError: if the call fails.
    """
    try:
      copy_src = {'Bucket': request.src_bucket, 'Key': request.src_key}
      self.client.copy(copy_src, request.dest_bucket, request.dest_key)
    except Exception as e:
      raise messages.S3ClientError(str(e), get_http_error_code(e))