Mise à jour de Monitor.py et autres scripts

This commit is contained in:
Debian
2025-07-23 10:46:27 +02:00
parent 7081418ce0
commit 7de3e0fb50
8604 changed files with 2789953 additions and 295 deletions

View File

@@ -0,0 +1,42 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from libcpp.memory cimport shared_ptr
from pyarrow.includes.libarrow cimport (CArray, CBuffer, CDataType,
CField, CRecordBatch, CSchema,
CTable, CTensor, CSparseCOOTensor,
CSparseCSRMatrix, CSparseCSCMatrix,
CSparseCSFTensor)
# Declarations taken from the Arrow Python bridge header (namespace arrow::py).
# Each wrap_* function receives a shared_ptr to an Arrow C++ object by const
# reference and hands back a Python object.
cdef extern from "arrow/python/pyarrow.h" namespace "arrow::py":
    # A return value of -1 is reported to the caller as a Python exception.
    int import_pyarrow() except -1

    # Buffers, types, fields and schemas
    object wrap_buffer(const shared_ptr[CBuffer]& buffer)
    object wrap_data_type(const shared_ptr[CDataType]& type)
    object wrap_field(const shared_ptr[CField]& field)
    object wrap_schema(const shared_ptr[CSchema]& schema)

    # Arrays and dense tensors
    object wrap_array(const shared_ptr[CArray]& sp_array)
    object wrap_tensor(const shared_ptr[CTensor]& sp_tensor)

    # Sparse tensors, one wrapper per storage layout
    object wrap_sparse_tensor_coo(
        const shared_ptr[CSparseCOOTensor]& sp_sparse_tensor)
    object wrap_sparse_tensor_csr(
        const shared_ptr[CSparseCSRMatrix]& sp_sparse_tensor)
    object wrap_sparse_tensor_csc(
        const shared_ptr[CSparseCSCMatrix]& sp_sparse_tensor)
    object wrap_sparse_tensor_csf(
        const shared_ptr[CSparseCSFTensor]& sp_sparse_tensor)

    # Tabular containers
    object wrap_table(const shared_ptr[CTable]& ctable)
    object wrap_batch(const shared_ptr[CRecordBatch]& cbatch)

View File

@@ -0,0 +1,439 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# flake8: noqa
"""
PyArrow is the python implementation of Apache Arrow.
Apache Arrow is a cross-language development platform for in-memory data.
It specifies a standardized language-independent columnar memory format for
flat and hierarchical data, organized for efficient analytic operations on
modern hardware. It also provides computational libraries and zero-copy
streaming messaging and interprocess communication.
For more information see the official page at https://arrow.apache.org
"""
import gc as _gc
import importlib as _importlib
import os as _os
import platform as _platform
import sys as _sys
import warnings as _warnings
try:
    from ._generated_version import version as __version__
except ImportError:
    # No generated version module is available (package not installed), so
    # the version is derived from the git tags at import time instead.
    try:
        import setuptools_scm

        # Same helper as in setup.py; it is duplicated here so that neither
        # file has to depend on the other.
        def parse_git(root, **kwargs):
            """
            setuptools_scm parse hook that only considers Arrow C++ tags.

            The ``--match`` pattern skips tags of non-C++ subprojects,
            e.g. apache-arrow-js-XXX tags.
            """
            from setuptools_scm.git import parse

            describe_command = (
                "git describe --dirty --tags --long "
                "--match 'apache-arrow-[0-9]*.*'"
            )
            kwargs['describe_command'] = describe_command
            return parse(root, **kwargs)

        __version__ = setuptools_scm.get_version('../', parse=parse_git)
    except ImportError:
        # setuptools_scm is not importable either: the version is unknown.
        __version__ = None
import pyarrow.lib as _lib
from pyarrow.lib import (BuildInfo, RuntimeInfo, set_timezone_db_path,
MonthDayNano, VersionInfo, cpp_build_info,
cpp_version, cpp_version_info, runtime_info,
cpu_count, set_cpu_count, enable_signal_handlers,
io_thread_count, set_io_thread_count)
def show_versions():
    """
    Print Arrow C++ build and version details, to help with error reporting.

    A header is written to stdout followed by one aligned
    ``label: value`` line per entry of ``cpp_build_info``.
    """
    def print_entry(label, value):
        print(f"{label: <26}: {value: <8}")

    print("pyarrow version info\n--------------------")

    # Use a placeholder when the build did not record a package kind.
    if len(cpp_build_info.package_kind) > 0:
        package_kind = cpp_build_info.package_kind
    else:
        package_kind = "not indicated"
    compiler = f"{cpp_build_info.compiler_id} {cpp_build_info.compiler_version}"

    entries = (
        ("Package kind", package_kind),
        ("Arrow C++ library version", cpp_build_info.version),
        ("Arrow C++ compiler", compiler),
        ("Arrow C++ compiler flags", cpp_build_info.compiler_flags),
        ("Arrow C++ git revision", cpp_build_info.git_id),
        ("Arrow C++ git description", cpp_build_info.git_description),
        ("Arrow C++ build type", cpp_build_info.build_type),
    )
    for label, value in entries:
        print_entry(label, value)
def _module_is_available(module):
    """
    Tell whether the ``pyarrow.<module>`` submodule can be imported.

    Parameters
    ----------
    module : str
        Submodule name relative to the ``pyarrow`` package.

    Returns
    -------
    bool
        False when importing the submodule raises ImportError, else True.
    """
    qualified_name = f'pyarrow.{module}'
    try:
        _importlib.import_module(qualified_name)
    except ImportError:
        return False
    return True
def _filesystem_is_available(fs):
    """
    Tell whether ``pyarrow.fs`` exposes the filesystem class named `fs`.

    Parameters
    ----------
    fs : str
        Attribute name to look up on ``pyarrow.fs``.

    Returns
    -------
    bool
        False when ``pyarrow.fs`` cannot be imported or when the attribute
        lookup raises ImportError or AttributeError, else True.
    """
    try:
        import pyarrow.fs
    except ImportError:
        return False

    fs_module = pyarrow.fs
    try:
        # The lookup result itself is irrelevant, only whether it succeeds.
        getattr(fs_module, fs)
    except (ImportError, AttributeError):
        return False
    return True
def show_info():
    """
    Print detailed version and platform information, for error reporting.

    The output starts with everything `show_versions` prints, followed by
    the platform, memory pool, optional module, filesystem and compression
    codec sections.
    """
    show_versions()

    def print_entry(label, value):
        print(f" {label: <20}: {value: <8}")

    def print_availability(names, is_available):
        # One line per name: "Enabled" when the probe succeeds, "-" otherwise.
        for name in names:
            print_entry(name, "Enabled" if is_available(name) else "-")

    print("\nPlatform:")
    print_entry("OS / Arch", f"{_platform.system()} {_platform.machine()}")
    print_entry("SIMD Level", runtime_info().simd_level)
    print_entry("Detected SIMD Level", runtime_info().detected_simd_level)

    pool = default_memory_pool()
    print("\nMemory:")
    print_entry("Default backend", pool.backend_name)
    print_entry("Bytes allocated", f"{pool.bytes_allocated()} bytes")
    print_entry("Max memory", f"{pool.max_memory()} bytes")
    print_entry("Supported Backends", ', '.join(supported_memory_backends()))

    print("\nOptional modules:")
    print_availability(
        ["csv", "cuda", "dataset", "feather", "flight", "fs", "gandiva",
         "json", "orc", "parquet"],
        _module_is_available,
    )

    print("\nFilesystems:")
    print_availability(
        ["AzureFileSystem", "GcsFileSystem",
         "HadoopFileSystem", "S3FileSystem"],
        _filesystem_is_available,
    )

    print("\nCompression Codecs:")
    print_availability(
        ["brotli", "bz2", "gzip", "lz4_frame", "lz4", "snappy", "zstd"],
        Codec.is_available,
    )
from pyarrow.lib import (null, bool_,
int8, int16, int32, int64,
uint8, uint16, uint32, uint64,
time32, time64, timestamp, date32, date64, duration,
month_day_nano_interval,
float16, float32, float64,
binary, string, utf8, binary_view, string_view,
large_binary, large_string, large_utf8,
decimal32, decimal64, decimal128, decimal256,
list_, large_list, list_view, large_list_view,
map_, struct,
union, sparse_union, dense_union,
dictionary,
run_end_encoded,
bool8, fixed_shape_tensor, json_, opaque, uuid,
field,
type_for_alias,
DataType, DictionaryType, StructType,
ListType, LargeListType, FixedSizeListType,
ListViewType, LargeListViewType,
MapType, UnionType, SparseUnionType, DenseUnionType,
TimestampType, Time32Type, Time64Type, DurationType,
FixedSizeBinaryType,
Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type,
BaseExtensionType, ExtensionType,
RunEndEncodedType, Bool8Type, FixedShapeTensorType,
JsonType, OpaqueType, UuidType,
UnknownExtensionType,
register_extension_type, unregister_extension_type,
DictionaryMemo,
KeyValueMetadata,
Field,
Schema,
schema,
unify_schemas,
Array, Tensor,
array, chunked_array, record_batch, nulls, repeat,
SparseCOOTensor, SparseCSRMatrix, SparseCSCMatrix,
SparseCSFTensor,
infer_type, from_numpy_dtype,
arange,
NullArray,
NumericArray, IntegerArray, FloatingPointArray,
BooleanArray,
Int8Array, UInt8Array,
Int16Array, UInt16Array,
Int32Array, UInt32Array,
Int64Array, UInt64Array,
HalfFloatArray, FloatArray, DoubleArray,
ListArray, LargeListArray, FixedSizeListArray,
ListViewArray, LargeListViewArray,
MapArray, UnionArray,
BinaryArray, StringArray,
LargeBinaryArray, LargeStringArray,
BinaryViewArray, StringViewArray,
FixedSizeBinaryArray,
DictionaryArray,
Date32Array, Date64Array, TimestampArray,
Time32Array, Time64Array, DurationArray,
MonthDayNanoIntervalArray,
Decimal32Array, Decimal64Array, Decimal128Array, Decimal256Array,
StructArray, ExtensionArray,
RunEndEncodedArray, Bool8Array, FixedShapeTensorArray,
JsonArray, OpaqueArray, UuidArray,
scalar, NA, _NULL as NULL, Scalar,
NullScalar, BooleanScalar,
Int8Scalar, Int16Scalar, Int32Scalar, Int64Scalar,
UInt8Scalar, UInt16Scalar, UInt32Scalar, UInt64Scalar,
HalfFloatScalar, FloatScalar, DoubleScalar,
Decimal32Scalar, Decimal64Scalar, Decimal128Scalar, Decimal256Scalar,
ListScalar, LargeListScalar, FixedSizeListScalar,
ListViewScalar, LargeListViewScalar,
Date32Scalar, Date64Scalar,
Time32Scalar, Time64Scalar,
TimestampScalar, DurationScalar,
MonthDayNanoIntervalScalar,
BinaryScalar, LargeBinaryScalar, BinaryViewScalar,
StringScalar, LargeStringScalar, StringViewScalar,
FixedSizeBinaryScalar, DictionaryScalar,
MapScalar, StructScalar, UnionScalar,
RunEndEncodedScalar, Bool8Scalar, ExtensionScalar,
FixedShapeTensorScalar, JsonScalar, OpaqueScalar, UuidScalar)
# Buffers, allocation
from pyarrow.lib import (DeviceAllocationType, Device, MemoryManager,
default_cpu_memory_manager)
from pyarrow.lib import (Buffer, ResizableBuffer, foreign_buffer, py_buffer,
Codec, compress, decompress, allocate_buffer)
from pyarrow.lib import (MemoryPool, LoggingMemoryPool, ProxyMemoryPool,
total_allocated_bytes, set_memory_pool,
default_memory_pool, system_memory_pool,
jemalloc_memory_pool, mimalloc_memory_pool,
logging_memory_pool, proxy_memory_pool,
log_memory_allocations, jemalloc_set_decay_ms,
supported_memory_backends)
# I/O
from pyarrow.lib import (NativeFile, PythonFile,
BufferedInputStream, BufferedOutputStream, CacheOptions,
CompressedInputStream, CompressedOutputStream,
TransformInputStream, transcoding_input_stream,
FixedSizeBufferWriter,
BufferReader, BufferOutputStream,
OSFile, MemoryMappedFile, memory_map,
create_memory_map, MockOutputStream,
input_stream, output_stream,
have_libhdfs)
from pyarrow.lib import (ChunkedArray, RecordBatch, Table, table,
concat_arrays, concat_tables, TableGroupBy,
RecordBatchReader, concat_batches)
# Exceptions
from pyarrow.lib import (ArrowCancelled,
ArrowCapacityError,
ArrowException,
ArrowKeyError,
ArrowIndexError,
ArrowInvalid,
ArrowIOError,
ArrowMemoryError,
ArrowNotImplementedError,
ArrowTypeError,
ArrowSerializationError)
from pyarrow.ipc import serialize_pandas, deserialize_pandas
import pyarrow.ipc as ipc
import pyarrow.types as types
# ----------------------------------------------------------------------
# Deprecations
from pyarrow.util import _deprecate_api, _deprecate_class
# TODO: Deprecate these somehow in the pyarrow namespace
from pyarrow.ipc import (Message, MessageReader, MetadataVersion,
RecordBatchFileReader, RecordBatchFileWriter,
RecordBatchStreamReader, RecordBatchStreamWriter)
# ----------------------------------------------------------------------
# Returning absolute path to the pyarrow include directory (if bundled, e.g. in
# wheels)
def get_include():
    """
    Return the path of the directory holding the Arrow C++ include headers.

    The result is the ``include`` directory located next to this module,
    in the spirit of ``numpy.get_include``.
    """
    package_dir = _os.path.dirname(__file__)
    return _os.path.join(package_dir, 'include')
def _get_pkg_config_executable():
    """
    Return the pkg-config executable to invoke.

    The ``PKG_CONFIG`` environment variable takes precedence. When it is
    unset or set to an empty string, the default ``'pkg-config'`` is used,
    so that an empty value is never passed on as a program name.
    """
    return _os.environ.get('PKG_CONFIG') or 'pkg-config'
def _has_pkg_config(pkgname):
    """
    Tell whether pkg-config knows the package `pkgname`.

    Parameters
    ----------
    pkgname : str
        Package name handed to ``pkg-config --exists``.

    Returns
    -------
    bool
        True when the command exits with status 0. False for any other
        exit status, or when launching the executable raises
        FileNotFoundError.
    """
    import subprocess

    command = [_get_pkg_config_executable(), '--exists', pkgname]
    try:
        exit_status = subprocess.call(command)
    except FileNotFoundError:
        return False
    return exit_status == 0
def _read_pkg_config_variable(pkgname, cli_args):
    """
    Run pkg-config for `pkgname` and return what it wrote to stdout.

    Parameters
    ----------
    pkgname : str
        Package name handed to pkg-config.
    cli_args : list of str
        Extra command line arguments appended after the package name.

    Returns
    -------
    str
        Standard output, decoded as UTF-8, with trailing whitespace removed.

    Raises
    ------
    RuntimeError
        If pkg-config exits with a non-zero status; the message carries
        its standard error output.
    """
    import subprocess

    command = [_get_pkg_config_executable(), pkgname] + cli_args
    completed = subprocess.run(command, stdout=subprocess.PIPE,
                               stderr=subprocess.PIPE)
    if completed.returncode != 0:
        raise RuntimeError(
            "pkg-config failed: " + completed.stderr.decode('utf8'))
    return completed.stdout.rstrip().decode('utf8')
def get_libraries():
    """
    Return the library names to pass as the `libraries` argument when
    building C or Cython extensions against pyarrow.

    Returns
    -------
    list of str
        A new list on every call.
    """
    libraries = ['arrow_python', 'arrow']
    return libraries
def create_library_symlinks():
    """
    Create unversioned symlinks for the shared libraries bundled in wheels.

    With Linux and macOS wheels, the bundled shared libraries have an
    embedded ABI version like libarrow.so.17 or libarrow.17.dylib, so
    linking to them with -larrow won't work unless symlinks exist at
    locations like site-packages/pyarrow/libarrow.so. This workaround
    avoids shipping two copies of the shared libraries while still letting
    third party projects like turbodbc build their C++ extensions against
    the pyarrow wheels.

    This function must only be invoked once and only when the shared
    libraries are bundled with the Python package, which should only apply
    to wheel-based installs. It requires write access to the
    site-packages/pyarrow directory and so, depending on your system, may
    need to be run as root.

    Nothing is done on Windows. Symlinks that already exist are left alone.
    When a symlink cannot be created because of missing permissions, a
    message naming that symlink is printed and processing continues.
    """
    import glob

    if _sys.platform == 'win32':
        return
    package_cwd = _os.path.dirname(__file__)

    if _sys.platform == 'linux':
        bundled_libs = glob.glob(_os.path.join(package_cwd, '*.so.*'))

        def get_symlink_path(hard_path):
            # libarrow.so.17 -> libarrow.so
            return hard_path.rsplit('.', 1)[0]
    else:
        bundled_libs = glob.glob(_os.path.join(package_cwd, '*.*.dylib'))

        def get_symlink_path(hard_path):
            # libarrow.17.dylib -> libarrow.dylib
            return '.'.join((hard_path.rsplit('.', 2)[0], 'dylib'))

    for lib_hard_path in bundled_libs:
        symlink_path = get_symlink_path(lib_hard_path)
        if _os.path.exists(symlink_path):
            continue
        try:
            _os.symlink(lib_hard_path, symlink_path)
        except PermissionError:
            # Report which symlink failed; the placeholder used to be
            # printed verbatim because it was never formatted.
            print(f"Tried creating symlink {symlink_path}. If you need to "
                  "link to bundled shared libraries, run "
                  "pyarrow.create_library_symlinks() as root")
def get_library_dirs():
    """
    Return the directories likely to contain the Arrow C++ libraries, for
    linking C or Cython extensions using pyarrow.

    Returns
    -------
    list of str
        Directories without duplicates, starting with the directory of
        this package.

    Raises
    ------
    ValueError
        If ``pkg-config --libs-only-L`` returns a non-empty value that does
        not start with ``-L``.
    """
    package_cwd = _os.path.dirname(__file__)
    library_dirs = [package_cwd]

    def append_library_dir(library_dir):
        if library_dir not in library_dirs:
            library_dirs.append(library_dir)

    # Search library paths via pkg-config. This is necessary if the user
    # installed libarrow and the other shared libraries manually and they
    # are not shipped inside the pyarrow package (see also ARROW-2976).
    for pkgname in ["arrow", "arrow_python"]:
        if _has_pkg_config(pkgname):
            library_dir = _read_pkg_config_variable(pkgname,
                                                    ["--libs-only-L"])
            # pkg-config output could be empty if Arrow is installed
            # as a system package.
            if library_dir:
                if not library_dir.startswith("-L"):
                    raise ValueError(
                        "pkg-config --libs-only-L returned unexpected "
                        f"value {library_dir!r}")
                # Drop the leading "-L" to keep only the directory.
                append_library_dir(library_dir[2:])

    if _sys.platform == 'win32':
        # TODO(wesm): Is this necessary, or does setuptools within a conda
        # installation add Library\lib to the linker path for MSVC?
        python_base_install = _os.path.dirname(_sys.executable)
        library_dir = _os.path.join(python_base_install, 'Library', 'lib')

        if _os.path.exists(_os.path.join(library_dir, 'arrow.lib')):
            append_library_dir(library_dir)

        # GH-45530: Add pyarrow.libs dir containing delvewheel-mangled
        # msvcp140.dll
        pyarrow_libs_dir = _os.path.abspath(
            _os.path.join(_os.path.dirname(__file__), _os.pardir, "pyarrow.libs")
        )
        if _os.path.exists(pyarrow_libs_dir):
            append_library_dir(pyarrow_libs_dir)

    # ARROW-4074: Allow for ARROW_HOME to be set to some other directory
    if _os.environ.get('ARROW_HOME'):
        append_library_dir(_os.path.join(_os.environ['ARROW_HOME'], 'lib'))
    else:
        # Python wheels bundle the Arrow libraries in the pyarrow directory.
        append_library_dir(_os.path.dirname(_os.path.abspath(__file__)))

    return library_dirs

View File

@@ -0,0 +1,44 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# cython: language_level = 3
from pyarrow.lib cimport *
from pyarrow.includes.common cimport *
from pyarrow.includes.libarrow cimport *
from pyarrow.includes.libarrow_acero cimport *
# Declaration of the base wrapper around a shared_ptr to the C++ node options.
cdef class ExecNodeOptions(_Weakrefable):
    # The wrapped C++ options object
    cdef shared_ptr[CExecNodeOptions] wrapped

    # Store `sp` as the wrapped options
    cdef void init(self, const shared_ptr[CExecNodeOptions]& sp)

    # Hand back the wrapped options; declared callable without the GIL
    cdef inline shared_ptr[CExecNodeOptions] unwrap(self) nogil
# Declaration of the wrapper holding a C++ CDeclaration by value.
cdef class Declaration(_Weakrefable):
    # The wrapped C++ declaration
    cdef CDeclaration decl

    # Store `c_decl` as the wrapped declaration
    cdef void init(self, const CDeclaration& c_decl)

    # Build a Python-level Declaration from a C++ one
    @staticmethod
    cdef wrap(const CDeclaration& c_decl)

    # Hand back the wrapped declaration; declared callable without the GIL
    cdef inline CDeclaration unwrap(self) nogil

View File

@@ -0,0 +1,609 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# ---------------------------------------------------------------------
# Low-level Acero bindings
# cython: profile=False
# distutils: language = c++
# cython: language_level = 3
from pyarrow.includes.common cimport *
from pyarrow.includes.libarrow cimport *
from pyarrow.includes.libarrow_acero cimport *
from pyarrow.lib cimport (Table, pyarrow_unwrap_table, pyarrow_wrap_table,
RecordBatchReader)
from pyarrow.lib import frombytes, tobytes
from pyarrow._compute cimport (
Expression, FunctionOptions, _ensure_field_ref, _true,
unwrap_null_placement, unwrap_sort_keys
)
cdef class ExecNodeOptions(_Weakrefable):
    """
    Base class for the node options.

    Do not instantiate it directly: construct an options object through
    one of the subclasses.
    """
    __slots__ = ()  # avoid mistakenly creating attributes

    cdef void init(self, const shared_ptr[CExecNodeOptions]& sp):
        # Keep a reference to the C++ options object.
        self.wrapped = sp

    cdef inline shared_ptr[CExecNodeOptions] unwrap(self) nogil:
        # Give access to the C++ options object.
        return self.wrapped
cdef class _TableSourceNodeOptions(ExecNodeOptions):

    def _set_options(self, Table table):
        # Unwrap the Python table and store new C++ table source options
        # built from it.
        cdef shared_ptr[CTable] c_table = pyarrow_unwrap_table(table)

        self.wrapped.reset(new CTableSourceNodeOptions(c_table))
class TableSourceNodeOptions(_TableSourceNodeOptions):
    """
    Options for a source node that reads its data from a table.

    This is the option class for the "table_source" node factory.

    Parameters
    ----------
    table : pyarrow.Table
        The table which acts as the data source.
    """

    def __init__(self, Table table):
        self._set_options(table)
cdef class _FilterNodeOptions(ExecNodeOptions):

    def _set_options(self, Expression filter_expression not None):
        # Unwrap the Python expression and store new C++ filter options
        # built from it.
        cdef CExpression c_filter

        c_filter = <CExpression>filter_expression.unwrap()
        self.wrapped.reset(new CFilterNodeOptions(c_filter))
class FilterNodeOptions(_FilterNodeOptions):
    """
    Options for a node that drops some rows from the batches it receives.

    This is the option class for the "filter" node factory.

    The "filter" operation lets you define data filtering criteria: only
    the rows for which the given expression evaluates to true are kept.
    Filters are written with pyarrow.compute.Expression, and the
    expression must have a boolean return type.

    Parameters
    ----------
    filter_expression : pyarrow.compute.Expression
        Boolean expression deciding which rows are selected.
    """

    def __init__(self, Expression filter_expression):
        self._set_options(filter_expression)
cdef class _ProjectNodeOptions(ExecNodeOptions):

    def _set_options(self, expressions, names=None):
        cdef:
            Expression expr
            vector[CExpression] c_expressions
            vector[c_string] c_names

        for expr in expressions:
            c_expressions.push_back(expr.unwrap())

        # Without explicit names only the expressions are handed over.
        if names is None:
            self.wrapped.reset(new CProjectNodeOptions(c_expressions))
            return

        # Every expression needs exactly one output name.
        if len(names) != len(expressions):
            raise ValueError(
                "The number of names should be equal to the number of expressions"
            )

        for name in names:
            c_names.push_back(<c_string>tobytes(name))

        self.wrapped.reset(new CProjectNodeOptions(c_expressions, c_names))
class ProjectNodeOptions(_ProjectNodeOptions):
    """
    Options for a node that evaluates expressions on its input batches and
    emits batches of the same length made of the resulting columns.

    This is the option class for the "project" node factory.

    The "project" operation rearranges, deletes, transforms, and creates
    columns. Each output column is computed by evaluating an expression
    against the source record batch. The expressions must be scalar
    expressions, i.e. built from scalar literals, field references and
    scalar functions (elementwise functions that return one value for each
    input row, independent of the value of all other rows).

    Parameters
    ----------
    expressions : list of pyarrow.compute.Expression
        Scalar expressions to evaluate against the source batch.
    names : list of str, optional
        Names of the output columns, one per entry of `expressions`.
        When omitted, the string representations of the expressions are
        used.

    Raises
    ------
    ValueError
        If `names` is given and its length differs from that of
        `expressions`.
    """

    def __init__(self, expressions, names=None):
        self._set_options(expressions, names)
cdef class _AggregateNodeOptions(ExecNodeOptions):

    def _set_options(self, aggregates, keys=None):
        cdef:
            CAggregate c_aggr
            vector[CAggregate] c_aggregations
            vector[CFieldRef] c_keys

        # Each spec is (target column(s), function name, options, output name)
        for targets, function_name, options, output_name in aggregates:
            c_aggr.function = tobytes(function_name)

            if options is None:
                c_aggr.options = <shared_ptr[CFunctionOptions]>nullptr
            else:
                # Checked cast: only FunctionOptions instances are accepted
                c_aggr.options = (<FunctionOptions?>options).wrapped

            # A single target is handled like a one-element list
            if not isinstance(targets, (list, tuple)):
                targets = [targets]
            for target in targets:
                c_aggr.target.push_back(_ensure_field_ref(target))

            c_aggr.name = tobytes(output_name)
            c_aggregations.push_back(move(c_aggr))

        if keys is not None:
            for key in keys:
                c_keys.push_back(_ensure_field_ref(key))

        self.wrapped.reset(
            new CAggregateNodeOptions(c_aggregations, c_keys)
        )
class AggregateNodeOptions(_AggregateNodeOptions):
    """
    Options for a node that aggregates its input batches, optionally
    grouped by keys.

    This is the option class for the "aggregate" node factory.

    Acero supports two types of aggregates: "scalar" aggregates and "hash"
    aggregates. Scalar aggregates reduce an array or scalar input to a
    single scalar output (e.g. computing the mean of a column). Hash
    aggregates act like GROUP BY in SQL: the data is first partitioned on
    one or more key columns, then reduced within each partition. The
    aggregate node supports both types of computation and can compute any
    number of aggregations at once.

    Parameters
    ----------
    aggregates : list of tuples
        Aggregations to apply to the targeted fields. Each tuple is one
        aggregation specification made of, in order: the aggregation
        target column(s), the function name, the aggregation function
        options object and the output field name.
        The target column(s) can be given as a single field reference
        (unary functions), an empty list (nullary functions) or a list of
        fields (n-ary functions). Each field reference can be a string
        column name or an expression.
    keys : list of field references, optional
        Keys by which the aggregations are grouped. Each key can reference
        a field using a string name or an expression.
    """

    def __init__(self, aggregates, keys=None):
        self._set_options(aggregates, keys)
cdef class _OrderByNodeOptions(ExecNodeOptions):

    def _set_options(self, sort_keys, null_placement):
        # Sort keys are unwrapped with allow_str=False; together with the
        # unwrapped null placement they form the ordering handed to the
        # C++ options.
        self.wrapped.reset(new COrderByNodeOptions(
            COrdering(
                unwrap_sort_keys(sort_keys, allow_str=False),
                unwrap_null_placement(null_placement)
            )
        ))
class OrderByNodeOptions(_OrderByNodeOptions):
    """
    Options for a node that applies a new ordering to the data.

    Currently this node works by accumulating all data, sorting, and then
    emitting the new data with an updated batch index.
    Larger-than-memory sort is not currently supported.

    This is the option class for the "order_by" node factory.

    Parameters
    ----------
    sort_keys : sequence of (name, order) tuples
        Field/column keys to sort the input on, each paired with the order
        it is sorted in. Accepted values for `order` are "ascending" and
        "descending". Each field reference can be a string column name or
        an expression.
    null_placement : str, default "at_end"
        Where nulls in the input are sorted; only applies to the
        columns/fields mentioned in `sort_keys`.
        Accepted values are "at_start", "at_end".
    """

    def __init__(self, sort_keys=(), *, null_placement="at_end"):
        self._set_options(sort_keys, null_placement)
cdef class _HashJoinNodeOptions(ExecNodeOptions):

    def _set_options(
        self, join_type, left_keys, right_keys, left_output=None, right_output=None,
        output_suffix_for_left="", output_suffix_for_right="", Expression filter_expression=None,
    ):
        cdef:
            CJoinType c_join_type
            vector[CFieldRef] c_left_keys
            vector[CFieldRef] c_right_keys
            vector[CFieldRef] c_left_output
            vector[CFieldRef] c_right_output
            CExpression c_filter_expression

        # Translate the join type name into the C++ enum value
        if join_type == "inner":
            c_join_type = CJoinType_INNER
        elif join_type == "left outer":
            c_join_type = CJoinType_LEFT_OUTER
        elif join_type == "right outer":
            c_join_type = CJoinType_RIGHT_OUTER
        elif join_type == "full outer":
            c_join_type = CJoinType_FULL_OUTER
        elif join_type == "left semi":
            c_join_type = CJoinType_LEFT_SEMI
        elif join_type == "right semi":
            c_join_type = CJoinType_RIGHT_SEMI
        elif join_type == "left anti":
            c_join_type = CJoinType_LEFT_ANTI
        elif join_type == "right anti":
            c_join_type = CJoinType_RIGHT_ANTI
        else:
            raise ValueError("Unsupported join type")

        # Join keys: a single key is handled like a one-element list
        if not isinstance(left_keys, (list, tuple)):
            left_keys = [left_keys]
        for left_key in left_keys:
            c_left_keys.push_back(_ensure_field_ref(left_key))

        if not isinstance(right_keys, (list, tuple)):
            right_keys = [right_keys]
        for right_key in right_keys:
            c_right_keys.push_back(_ensure_field_ref(right_key))

        # Residual filter: the `_true` expression is used when none is given
        if filter_expression is not None:
            c_filter_expression = filter_expression.unwrap()
        else:
            c_filter_expression = _true

        # The explicit output fields are only used when both sides are given
        if left_output is None or right_output is None:
            self.wrapped.reset(
                new CHashJoinNodeOptions(
                    c_join_type, c_left_keys, c_right_keys,
                    c_filter_expression,
                    <c_string>tobytes(output_suffix_for_left),
                    <c_string>tobytes(output_suffix_for_right)
                )
            )
        else:
            for left_field in left_output:
                c_left_output.push_back(_ensure_field_ref(left_field))
            for right_field in right_output:
                c_right_output.push_back(_ensure_field_ref(right_field))

            self.wrapped.reset(
                new CHashJoinNodeOptions(
                    c_join_type, c_left_keys, c_right_keys,
                    c_left_output, c_right_output,
                    c_filter_expression,
                    <c_string>tobytes(output_suffix_for_left),
                    <c_string>tobytes(output_suffix_for_right)
                )
            )
class HashJoinNodeOptions(_HashJoinNodeOptions):
    """
    Options for a node that implements the join operation using a hash
    join strategy.

    This is the option class for the "hashjoin" node factory.

    Parameters
    ----------
    join_type : str
        Type of join. One of "left semi", "right semi", "left anti",
        "right anti", "inner", "left outer", "right outer", "full outer".
    left_keys : str, Expression or list
        Key fields from the left input. Each key can be a string column
        name or a field expression; a list of such field references is
        accepted too.
    right_keys : str, Expression or list
        Key fields from the right input. See `left_keys` for details.
    left_output : list, optional
        Output fields passed from the left input. If left and right output
        fields are not specified, all valid fields from both left and
        right input will be output. Each field can be a string column name
        or a field expression.
    right_output : list, optional
        Output fields passed from the right input. If left and right
        output fields are not specified, all valid fields from both left
        and right input will be output. Each field can be a string column
        name or a field expression.
    output_suffix_for_left : str
        Suffix added to the names of output fields coming from the left
        input. It serves to distinguish, if necessary, fields of the same
        name in the left and right input, and can be left empty if there
        are no name collisions.
    output_suffix_for_right : str
        Suffix added to the names of output fields coming from the right
        input, see `output_suffix_for_left` for details.
    filter_expression : pyarrow.compute.Expression
        Residual filter which is applied to matching rows.

    Raises
    ------
    ValueError
        If `join_type` is not one of the supported values.
    """

    def __init__(
        self, join_type, left_keys, right_keys, left_output=None, right_output=None,
        output_suffix_for_left="", output_suffix_for_right="", filter_expression=None,
    ):
        self._set_options(
            join_type, left_keys, right_keys, left_output, right_output,
            output_suffix_for_left, output_suffix_for_right, filter_expression
        )
cdef class _AsofJoinNodeOptions(ExecNodeOptions):

    def _set_options(self, left_on, left_by, right_on, right_by, tolerance):
        cdef:
            CAsofJoinKeys c_left_keys
            CAsofJoinKeys c_right_keys
            vector[CAsofJoinKeys] c_input_keys

        # Keys of the left input: "by" keys first, then the "on" key.
        # A single "by" key is handled like a one-element list.
        if not isinstance(left_by, (list, tuple)):
            left_by = [left_by]
        for by_key in left_by:
            c_left_keys.by_key.push_back(_ensure_field_ref(by_key))
        c_left_keys.on_key = _ensure_field_ref(left_on)
        c_input_keys.push_back(c_left_keys)

        # Keys of the right input, built the same way
        if not isinstance(right_by, (list, tuple)):
            right_by = [right_by]
        for by_key in right_by:
            c_right_keys.by_key.push_back(_ensure_field_ref(by_key))
        c_right_keys.on_key = _ensure_field_ref(right_on)
        c_input_keys.push_back(c_right_keys)

        self.wrapped.reset(
            new CAsofJoinNodeOptions(c_input_keys, tolerance)
        )
class AsofJoinNodeOptions(_AsofJoinNodeOptions):
    """
    Options for a node that implements the 'as of join' operation.

    This is the option class for the "asofjoin" node factory.

    Parameters
    ----------
    left_on : str, Expression
        The left key on which the join operation should be performed.
        Can be a string column name or a field expression.

        An inexact match is used on the "on" key, i.e. a row is considered
        a match if and only if
        left_on - tolerance <= right_on <= left_on.

        The input dataset must be sorted by the "on" key. Must be a single
        field of a common type.

        Currently, the "on" key must be an integer, date, or timestamp
        type.
    left_by : str, Expression or list
        The left keys on which the join operation should be performed.
        Exact equality is used for each field of the "by" keys.
        Each key can be a string column name or a field expression; a list
        of such field references is accepted too.
    right_on : str, Expression
        The right key on which the join operation should be performed.
        See `left_on` for details.
    right_by : str, Expression or list
        The right keys on which the join operation should be performed.
        See `left_by` for details.
    tolerance : int
        The tolerance to use for the asof join. The tolerance is
        interpreted in the same units as the "on" key.
    """

    def __init__(self, left_on, left_by, right_on, right_by, tolerance):
        self._set_options(left_on, left_by, right_on, right_by, tolerance)
cdef class Declaration(_Weakrefable):
    """
    Helper class for declaring the nodes of an ExecPlan.

    A Declaration represents an unconstructed ExecNode, and potentially
    more since its inputs may also be Declarations or when constructed
    with ``from_sequence``.

    The possible ExecNodes to use are registered with a name,
    the "factory name", and need to be specified using this name, together
    with its corresponding ExecNodeOptions subclass.

    Parameters
    ----------
    factory_name : str
        The ExecNode factory name, such as "table_source", "filter",
        "project" etc. See the ExecNodeOptions subclasses for the exact
        factory names to use.
    options : ExecNodeOptions
        Corresponding ExecNodeOptions subclass (matching the factory name).
        Must not be None.
    inputs : list of Declaration, optional
        Input nodes for this declaration. Optional if the node is a source
        node, or when the declaration gets combined later with
        ``from_sequence``.

    Returns
    -------
    Declaration

    Raises
    ------
    TypeError
        If ``options`` is None or an element of ``inputs`` is not a
        Declaration.
    """
    cdef void init(self, const CDeclaration& c_decl):
        # Store a copy of the C++ declaration on this wrapper.
        self.decl = c_decl

    @staticmethod
    cdef wrap(const CDeclaration& c_decl):
        # Build the Python wrapper without going through __init__.
        cdef Declaration self = Declaration.__new__(Declaration)
        self.init(c_decl)
        return self

    cdef inline CDeclaration unwrap(self) nogil:
        return self.decl

    def __init__(self, factory_name, ExecNodeOptions options not None,
                 inputs=None):
        cdef:
            c_string c_factory_name
            CDeclaration c_decl
            vector[CDeclaration.Input] c_inputs

        c_factory_name = tobytes(factory_name)

        if inputs is not None:
            for ipt in inputs:
                # Checked cast (``<T?>``): a foreign object raises TypeError
                # instead of being reinterpreted as a Declaration.
                c_inputs.push_back(
                    CDeclaration.Input((<Declaration?>ipt).unwrap())
                )

        c_decl = CDeclaration(c_factory_name, c_inputs, options.unwrap())
        self.init(c_decl)

    @staticmethod
    def from_sequence(decls):
        """
        Convenience factory for the common case of a simple sequence of nodes.

        Each of the declarations will be appended to the inputs of the
        subsequent declaration, and the final modified declaration will
        be returned.

        Parameters
        ----------
        decls : list of Declaration

        Returns
        -------
        Declaration

        Raises
        ------
        TypeError
            If an element of ``decls`` is not a Declaration.
        """
        cdef:
            vector[CDeclaration] c_decls
            CDeclaration c_decl

        for decl in decls:
            # Checked cast, see __init__.
            c_decls.push_back((<Declaration?>decl).unwrap())

        c_decl = CDeclaration.Sequence(c_decls)
        return Declaration.wrap(c_decl)

    def __str__(self):
        return frombytes(GetResultValue(DeclarationToString(self.decl)))

    def __repr__(self):
        return f"<pyarrow.acero.Declaration>\n{self}"

    def to_table(self, bint use_threads=True):
        """
        Run the declaration and collect the results into a table.

        This method will implicitly add a sink node to the declaration
        to collect results into a table. It will then create an ExecPlan
        from the declaration, start the exec plan, block until the plan
        has finished, and return the created table.

        Parameters
        ----------
        use_threads : bool, default True
            If set to False, then all CPU work will be done on the calling
            thread. I/O tasks will still happen on the I/O executor
            and may be multi-threaded (but should not use significant CPU
            resources).

        Returns
        -------
        pyarrow.Table
        """
        cdef:
            shared_ptr[CTable] c_table

        # The GIL is released while the plan runs.
        with nogil:
            c_table = GetResultValue(DeclarationToTable(self.unwrap(), use_threads))
        return pyarrow_wrap_table(c_table)

    def to_reader(self, bint use_threads=True):
        """Run the declaration and return results as a RecordBatchReader.

        For details about the parameters, see `to_table`.

        Returns
        -------
        pyarrow.RecordBatchReader
        """
        cdef:
            RecordBatchReader reader

        reader = RecordBatchReader.__new__(RecordBatchReader)
        # Ownership of the C++ reader is handed over to the Python wrapper.
        reader.reader.reset(
            GetResultValue(DeclarationToReader(self.unwrap(), use_threads)).release()
        )
        return reader

View File

@@ -0,0 +1,188 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# cython: language_level = 3
from pyarrow.lib import frombytes, tobytes
from pyarrow.includes.libarrow_fs cimport *
from pyarrow._fs cimport FileSystem
cdef class AzureFileSystem(FileSystem):
    """
    Azure Blob Storage backed FileSystem implementation

    This implementation supports flat namespace and hierarchical namespace (HNS) a.k.a.
    Data Lake Gen2 storage accounts. HNS will be automatically detected and HNS specific
    features will be used when they provide a performance advantage. Azurite emulator is
    also supported. Note: `/` is the only supported delimiter.

    The storage account is considered the root of the filesystem. When enabled, containers
    will be created or deleted during relevant directory operations. Obviously, this also
    requires authentication with the additional permissions.

    By default `DefaultAzureCredential <https://github.com/Azure/azure-sdk-for-cpp/blob/main/sdk/identity/azure-identity/README.md#defaultazurecredential>`__
    is used for authentication. This means it will try several types of authentication
    and go with the first one that works. If any authentication parameters are provided when
    initialising the FileSystem, they will be used instead of the default credential.

    Parameters
    ----------
    account_name : str
        Azure Blob Storage account name. This is the globally unique identifier for the
        storage account.
    account_key : str, default None
        Account key of the storage account. If sas_token and account_key are None the
        default credential will be used. The parameters account_key and sas_token are
        mutually exclusive.
    blob_storage_authority : str, default None
        hostname[:port] of the Blob Service. Defaults to `.blob.core.windows.net`. Useful
        for connecting to a local emulator, like Azurite.
    blob_storage_scheme : str, default None
        Either `http` or `https`. Defaults to `https`. Useful for connecting to a local
        emulator, like Azurite.
    client_id : str, default None
        The client ID (Application ID) for Azure Active Directory authentication.
        Its interpretation depends on the credential type being used:
        - For `ClientSecretCredential`: It is the Application (client) ID of your
          registered Azure AD application (Service Principal). It must be provided
          together with `tenant_id` and `client_secret` to use ClientSecretCredential.
        - For `ManagedIdentityCredential`: It is the client ID of a specific
          user-assigned managed identity. This is only necessary if you are using a
          user-assigned managed identity and need to explicitly specify which one
          (e.g., if the resource has multiple user-assigned identities). For
          system-assigned managed identities, this parameter is typically not required.
    client_secret : str, default None
        Client secret for Azure Active Directory authentication. Must be provided together
        with `tenant_id` and `client_id` to use ClientSecretCredential.
    dfs_storage_authority : str, default None
        hostname[:port] of the Data Lake Gen 2 Service. Defaults to
        `.dfs.core.windows.net`. Useful for connecting to a local emulator, like Azurite.
    dfs_storage_scheme : str, default None
        Either `http` or `https`. Defaults to `https`. Useful for connecting to a local
        emulator, like Azurite.
    sas_token : str, default None
        SAS token for the storage account, used as an alternative to account_key. If sas_token
        and account_key are None the default credential will be used. The parameters
        account_key and sas_token are mutually exclusive.
    tenant_id : str, default None
        Tenant ID for Azure Active Directory authentication. Must be provided together with
        `client_id` and `client_secret` to use ClientSecretCredential.

    Examples
    --------
    >>> from pyarrow import fs
    >>> azure_fs = fs.AzureFileSystem(account_name='myaccount')
    >>> azurite_fs = fs.AzureFileSystem(
    ...     account_name='devstoreaccount1',
    ...     account_key='Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==',
    ...     blob_storage_authority='127.0.0.1:10000',
    ...     dfs_storage_authority='127.0.0.1:10000',
    ...     blob_storage_scheme='http',
    ...     dfs_storage_scheme='http',
    ... )

    For usage of the methods see examples for :func:`~pyarrow.fs.LocalFileSystem`.
    """
    cdef:
        CAzureFileSystem* azurefs
        # Credentials are kept on the Python side so that __reduce__ can
        # hand them back to the constructor when unpickling.
        c_string account_key
        c_string sas_token
        c_string tenant_id
        c_string client_id
        c_string client_secret

    def __init__(self, account_name, *, account_key=None, blob_storage_authority=None,
                 blob_storage_scheme=None, client_id=None, client_secret=None,
                 dfs_storage_authority=None, dfs_storage_scheme=None,
                 sas_token=None, tenant_id=None):
        cdef:
            CAzureOptions options
            shared_ptr[CAzureFileSystem] wrapped

        options.account_name = tobytes(account_name)

        # Endpoint overrides: only touch the option when a value was given.
        if blob_storage_authority:
            options.blob_storage_authority = tobytes(blob_storage_authority)
        if dfs_storage_authority:
            options.dfs_storage_authority = tobytes(dfs_storage_authority)
        if blob_storage_scheme:
            options.blob_storage_scheme = tobytes(blob_storage_scheme)
        if dfs_storage_scheme:
            options.dfs_storage_scheme = tobytes(dfs_storage_scheme)

        if account_key and sas_token:
            raise ValueError("Cannot specify both account_key and sas_token.")

        if tenant_id or client_id or client_secret:
            # Azure AD style credentials always need a client_id.
            if not client_id:
                raise ValueError("client_id must be specified")

            if tenant_id and client_secret:
                # Full triple -> ClientSecretCredential.
                tenant_bytes = tobytes(tenant_id)
                client_bytes = tobytes(client_id)
                secret_bytes = tobytes(client_secret)
                options.ConfigureClientSecretCredential(
                    tenant_bytes, client_bytes, secret_bytes
                )
                self.tenant_id = tenant_bytes
                self.client_id = client_bytes
                self.client_secret = secret_bytes
            elif tenant_id or client_secret:
                # Exactly one of tenant_id / client_secret: ambiguous.
                raise ValueError(
                    "Invalid Azure credential configuration: "
                    "For ManagedIdentityCredential, provide only client_id. "
                    "For ClientSecretCredential, provide tenant_id, client_id, and client_secret."
                )
            else:
                # client_id alone -> ManagedIdentityCredential.
                client_bytes = tobytes(client_id)
                options.ConfigureManagedIdentityCredential(client_bytes)
                self.client_id = client_bytes
        elif account_key:
            key_bytes = tobytes(account_key)
            options.ConfigureAccountKeyCredential(key_bytes)
            self.account_key = key_bytes
        elif sas_token:
            token_bytes = tobytes(sas_token)
            options.ConfigureSASCredential(token_bytes)
            self.sas_token = token_bytes
        else:
            options.ConfigureDefaultCredential()

        with nogil:
            wrapped = GetResultValue(CAzureFileSystem.Make(options))

        self.init(<shared_ptr[CFileSystem]> wrapped)

    cdef init(self, const shared_ptr[CFileSystem]& wrapped):
        FileSystem.init(self, wrapped)
        # Keep a typed raw pointer next to the base-class handle.
        self.azurefs = <CAzureFileSystem*> wrapped.get()

    @staticmethod
    def _reconstruct(kwargs):
        # __reduce__ doesn't allow passing named arguments directly to the
        # reconstructor, hence this wrapper.
        return AzureFileSystem(**kwargs)

    def __reduce__(self):
        cdef CAzureOptions opts = self.azurefs.options()
        # Endpoint settings come from the live options object, credentials
        # from the copies stored on this instance.
        kwargs = {
            "account_name": frombytes(opts.account_name),
            "account_key": frombytes(self.account_key),
            "blob_storage_authority": frombytes(opts.blob_storage_authority),
            "blob_storage_scheme": frombytes(opts.blob_storage_scheme),
            "client_id": frombytes(self.client_id),
            "client_secret": frombytes(self.client_secret),
            "dfs_storage_authority": frombytes(opts.dfs_storage_authority),
            "dfs_storage_scheme": frombytes(opts.dfs_storage_scheme),
            "sas_token": frombytes(self.sas_token),
            "tenant_id": frombytes(self.tenant_id),
        }
        return AzureFileSystem._reconstruct, (kwargs,)

View File

@@ -0,0 +1,72 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# cython: language_level = 3
from pyarrow.lib cimport *
from pyarrow.includes.common cimport *
from pyarrow.includes.libarrow cimport *
# Extension-type declaration: holds a ``CUdfContext`` by value in
# ``c_context``.
cdef class UdfContext(_Weakrefable):
    cdef:
        CUdfContext c_context

    # Takes the C-level context by const reference; presumably copies it into
    # ``c_context`` (implementation not visible here -- confirm).
    cdef void init(self, const CUdfContext& c_context)
# Extension-type declaration: owns a ``shared_ptr`` to a C++
# ``CFunctionOptions`` in ``wrapped``.
cdef class FunctionOptions(_Weakrefable):
    cdef:
        shared_ptr[CFunctionOptions] wrapped

    # Declared ``except NULL``: a NULL return tells the caller that a Python
    # exception has been raised.
    cdef const CFunctionOptions* get_options(self) except NULL
    cdef void init(self, const shared_ptr[CFunctionOptions]& sp)
    cdef inline shared_ptr[CFunctionOptions] unwrap(self)
# FunctionOptions subclass declared here with no additional C-level members.
cdef class _SortOptions(FunctionOptions):
    pass
# Takes an Expression and a Schema and returns a C-level ``CExpression``.
# Declared ``except *``: callers check for a Python exception after each call.
cdef CExpression _bind(Expression filter, Schema schema) except *
# Extension-type declaration: holds a ``CExpression`` by value in ``expr``.
cdef class Expression(_Weakrefable):
    cdef:
        CExpression expr

    cdef void init(self, const CExpression& sp)

    # Static constructor taking a C-level expression.
    @staticmethod
    cdef wrap(const CExpression& sp)

    # Returns a ``CExpression`` by value.
    cdef inline CExpression unwrap(self)

    # Static helper returning an Expression for an arbitrary Python object;
    # judging by the name it accepts either an Expression or a scalar value
    # (implementation not visible here -- confirm).
    @staticmethod
    cdef Expression _expr_or_scalar(object expr)
# Module-level C expression shared with cimporting modules; presumably the
# literal ``true`` expression (name-based guess -- confirm).
cdef CExpression _true

# Conversion helpers from Python objects to C-level values. All are declared
# ``except *``, so callers check for a Python exception after each call.
cdef CFieldRef _ensure_field_ref(value) except *

# ``allow_str=*`` marks an optional argument whose default is defined with
# the implementation.
cdef vector[CSortKey] unwrap_sort_keys(sort_keys, allow_str=*) except *

cdef CSortOrder unwrap_sort_order(order) except *

cdef CNullPlacement unwrap_null_placement(null_placement) except *

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,56 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Custom documentation additions for compute functions.
"""
# Extra docstring text for compute functions, keyed by function name.
function_doc_additions = {
    "filter": """
Examples
--------
>>> import pyarrow as pa
>>> arr = pa.array(["a", "b", "c", None, "e"])
>>> mask = pa.array([True, False, None, False, True])
>>> arr.filter(mask)
<pyarrow.lib.StringArray object at ...>
[
"a",
"e"
]
>>> arr.filter(mask, null_selection_behavior='emit_null')
<pyarrow.lib.StringArray object at ...>
[
"a",
null,
"e"
]
""",
    "mode": """
Examples
--------
>>> import pyarrow as pa
>>> import pyarrow.compute as pc
>>> arr = pa.array([1, 1, 2, 2, 3, 2, 2, 2])
>>> modes = pc.mode(arr, 2)
>>> modes[0]
<pyarrow.StructScalar: [('mode', 2), ('count', 5)]>
>>> modes[1]
<pyarrow.StructScalar: [('mode', 1), ('count', 2)]>
""",
}

View File

@@ -0,0 +1,55 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# cython: language_level = 3
from pyarrow.includes.libarrow cimport *
from pyarrow.lib cimport _Weakrefable
# Extension-type declaration: exclusively owns a ``CCSVConvertOptions``
# through a ``unique_ptr``.
cdef class ConvertOptions(_Weakrefable):
    cdef:
        unique_ptr[CCSVConvertOptions] options

    # Static constructor; takes the C++ options struct by value.
    @staticmethod
    cdef ConvertOptions wrap(CCSVConvertOptions options)
# Extension-type declaration: exclusively owns a ``CCSVParseOptions``
# through a ``unique_ptr`` and keeps one Python object alongside it.
cdef class ParseOptions(_Weakrefable):
    cdef:
        unique_ptr[CCSVParseOptions] options
        # Python-level object; presumably the user callback for invalid rows
        # (name-based guess -- confirm against the implementation).
        object _invalid_row_handler

    # Static constructor; takes the C++ options struct by value.
    @staticmethod
    cdef ParseOptions wrap(CCSVParseOptions options)
# Extension-type declaration: exclusively owns a ``CCSVReadOptions``
# through a ``unique_ptr``.
cdef class ReadOptions(_Weakrefable):
    cdef:
        unique_ptr[CCSVReadOptions] options
        # ``public``: readable and writable as a plain attribute from Python.
        public object encoding

    # Static constructor; takes the C++ options struct by value.
    @staticmethod
    cdef ReadOptions wrap(CCSVReadOptions options)
# Extension-type declaration: exclusively owns a ``CCSVWriteOptions``
# through a ``unique_ptr``.
cdef class WriteOptions(_Weakrefable):
    cdef:
        unique_ptr[CCSVWriteOptions] options

    # Static constructor; takes the C++ options struct by value.
    @staticmethod
    cdef WriteOptions wrap(CCSVWriteOptions options)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,67 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# cython: language_level = 3
from pyarrow.lib cimport *
from pyarrow.includes.common cimport *
from pyarrow.includes.libarrow cimport *
from pyarrow.includes.libarrow_cuda cimport *
# Extension-type declaration: shares ownership of a ``CCudaContext`` and
# stores an integer device number next to it.
cdef class Context(_Weakrefable):
    cdef:
        shared_ptr[CCudaContext] context
        int device_number

    cdef void init(self, const shared_ptr[CCudaContext]& ctx)
# Extension-type declaration: shares ownership of a ``CCudaIpcMemHandle``.
cdef class IpcMemHandle(_Weakrefable):
    cdef:
        shared_ptr[CCudaIpcMemHandle] handle

    # Note: takes the shared_ptr by non-const reference.
    cdef void init(self, shared_ptr[CCudaIpcMemHandle]& h)
# Buffer subclass declaration: shares ownership of a ``CCudaBuffer`` and
# holds a reference to a Python ``base`` object.
cdef class CudaBuffer(Buffer):
    cdef:
        shared_ptr[CCudaBuffer] cuda_buffer
        # Python object referenced by this buffer; presumably the owner of
        # the underlying memory, kept alive here (confirm in implementation).
        object base

    cdef void init_cuda(self,
                        const shared_ptr[CCudaBuffer]& buffer,
                        object base)
# Buffer subclass declaration: shares ownership of a ``CCudaHostBuffer``.
cdef class HostBuffer(Buffer):
    cdef:
        shared_ptr[CCudaHostBuffer] host_buffer

    cdef void init_host(self, const shared_ptr[CCudaHostBuffer]& buffer)
# NativeFile subclass declaration: stores a raw ``CCudaBufferReader``
# pointer together with a reference to a ``CudaBuffer``.
cdef class BufferReader(NativeFile):
    cdef:
        # Raw pointer; ownership is not expressed in this declaration.
        CCudaBufferReader* reader
        CudaBuffer buffer
# NativeFile subclass declaration: stores a raw ``CCudaBufferWriter``
# pointer together with a reference to a ``CudaBuffer``.
cdef class BufferWriter(NativeFile):
    cdef:
        # Raw pointer; ownership is not expressed in this declaration.
        CCudaBufferWriter* writer
        CudaBuffer buffer

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,183 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# cython: language_level = 3
"""Dataset is currently unstable. APIs subject to change without notice."""
from pyarrow.includes.common cimport *
from pyarrow.includes.libarrow_dataset cimport *
from pyarrow.lib cimport *
from pyarrow._fs cimport FileSystem, FileInfo
# Returns a C-level ``CFileSource`` for a Python ``file`` object. The ``=*``
# markers denote optional arguments whose defaults are defined with the
# implementation.
cdef CFileSource _make_file_source(object file, FileSystem filesystem=*, object file_size=*)
# Extension-type declaration: holds a ``CDatasetFactory`` both through a
# ``SharedPtrNoGIL`` handle and as a raw pointer.
cdef class DatasetFactory(_Weakrefable):
    cdef:
        SharedPtrNoGIL[CDatasetFactory] wrapped
        # Raw pointer; presumably aliases the object held by ``wrapped``
        # (confirm in the implementation).
        CDatasetFactory* factory

    cdef init(self, const shared_ptr[CDatasetFactory]& sp)

    # Static constructor from a C++ shared_ptr.
    @staticmethod
    cdef wrap(const shared_ptr[CDatasetFactory]& sp)

    # Declared ``nogil``: may be called without holding the GIL.
    cdef inline shared_ptr[CDatasetFactory] unwrap(self) nogil
# Extension-type declaration: holds a ``CDataset`` both through a
# ``SharedPtrNoGIL`` handle and as a raw pointer.
cdef class Dataset(_Weakrefable):
    cdef:
        SharedPtrNoGIL[CDataset] wrapped
        # Raw pointer; presumably aliases the object held by ``wrapped``
        # (confirm in the implementation).
        CDataset* dataset
        # ``public``: accessible as a plain dict attribute from Python.
        public dict _scan_options

    cdef void init(self, const shared_ptr[CDataset]& sp)

    # Static constructor from a C++ shared_ptr.
    @staticmethod
    cdef wrap(const shared_ptr[CDataset]& sp)

    # Declared ``nogil``: may be called without holding the GIL.
    cdef shared_ptr[CDataset] unwrap(self) nogil
# Extension-type declaration: holds a ``CScanner`` both through a
# ``SharedPtrNoGIL`` handle and as a raw pointer.
cdef class Scanner(_Weakrefable):
    cdef:
        SharedPtrNoGIL[CScanner] wrapped
        # Raw pointer; presumably aliases the object held by ``wrapped``
        # (confirm in the implementation).
        CScanner* scanner

    cdef void init(self, const shared_ptr[CScanner]& sp)

    # Static constructor from a C++ shared_ptr.
    @staticmethod
    cdef wrap(const shared_ptr[CScanner]& sp)

    cdef shared_ptr[CScanner] unwrap(self)

    # Builds C++ scan options from a Dataset and a Python dict. Declared
    # ``except *``: callers check for a Python exception after the call.
    @staticmethod
    cdef shared_ptr[CScanOptions] _make_scan_options(Dataset dataset, dict py_scanoptions) except *
# Extension-type declaration: shares ownership of a
# ``CFragmentScanOptions``.
cdef class FragmentScanOptions(_Weakrefable):
    cdef:
        shared_ptr[CFragmentScanOptions] wrapped

    cdef void init(self, const shared_ptr[CFragmentScanOptions]& sp)

    # Static constructor from a C++ shared_ptr.
    @staticmethod
    cdef wrap(const shared_ptr[CFragmentScanOptions]& sp)
# Extension-type declaration: holds a ``CFileFormat`` both through a
# ``shared_ptr`` and as a raw pointer.
cdef class FileFormat(_Weakrefable):
    cdef:
        shared_ptr[CFileFormat] wrapped
        # Raw pointer; presumably aliases the object held by ``wrapped``
        # (confirm in the implementation).
        CFileFormat* format

    cdef void init(self, const shared_ptr[CFileFormat]& sp)

    # Static constructor from a C++ shared_ptr.
    @staticmethod
    cdef wrap(const shared_ptr[CFileFormat]& sp)

    cdef inline shared_ptr[CFileFormat] unwrap(self)

    cdef _set_default_fragment_scan_options(self, FragmentScanOptions options)

    # Return a WrittenFile after a file was written.
    # May be overridden by subclasses, e.g. to add metadata.
    cdef WrittenFile _finish_write(self, path, base_dir,
                                   CFileWriter* file_writer)
# Extension-type declaration: holds a ``CFileWriteOptions`` both through a
# ``shared_ptr`` and as a raw pointer.
cdef class FileWriteOptions(_Weakrefable):
    cdef:
        shared_ptr[CFileWriteOptions] wrapped
        # Raw pointer; presumably aliases the object held by ``wrapped``
        # (confirm in the implementation).
        CFileWriteOptions* c_options

    cdef void init(self, const shared_ptr[CFileWriteOptions]& sp)

    # Static constructor from a C++ shared_ptr.
    @staticmethod
    cdef wrap(const shared_ptr[CFileWriteOptions]& sp)

    cdef inline shared_ptr[CFileWriteOptions] unwrap(self)
# Extension-type declaration: holds a ``CFragment`` both through a
# ``SharedPtrNoGIL`` handle and as a raw pointer.
cdef class Fragment(_Weakrefable):
    cdef:
        SharedPtrNoGIL[CFragment] wrapped
        # Raw pointer; presumably aliases the object held by ``wrapped``
        # (confirm in the implementation).
        CFragment* fragment

    cdef void init(self, const shared_ptr[CFragment]& sp)

    # Static constructor from a C++ shared_ptr.
    @staticmethod
    cdef wrap(const shared_ptr[CFragment]& sp)

    cdef inline shared_ptr[CFragment] unwrap(self)
# Fragment subclass declaration: adds a raw ``CFileFragment`` pointer and
# re-declares ``init`` with the base-class signature.
cdef class FileFragment(Fragment):
    cdef:
        CFileFragment* file_fragment

    cdef void init(self, const shared_ptr[CFragment]& sp)
# Extension-type declaration: holds a ``CPartitioning`` both through a
# ``shared_ptr`` and as a raw pointer.
cdef class Partitioning(_Weakrefable):
    cdef:
        shared_ptr[CPartitioning] wrapped
        # Raw pointer; presumably aliases the object held by ``wrapped``
        # (confirm in the implementation).
        CPartitioning* partitioning

    cdef init(self, const shared_ptr[CPartitioning]& sp)

    # Static constructor from a C++ shared_ptr.
    @staticmethod
    cdef wrap(const shared_ptr[CPartitioning]& sp)

    cdef inline shared_ptr[CPartitioning] unwrap(self)
# Extension-type declaration: holds a ``CPartitioningFactory`` both through
# a ``shared_ptr`` and as a raw pointer, plus two Python objects.
cdef class PartitioningFactory(_Weakrefable):
    cdef:
        shared_ptr[CPartitioningFactory] wrapped
        # Raw pointer; presumably aliases the object held by ``wrapped``
        # (confirm in the implementation).
        CPartitioningFactory* factory
        # Python objects supplied through ``wrap``; presumably what is needed
        # to rebuild the factory, e.g. for pickling (confirm).
        object constructor
        object options

    cdef init(self, const shared_ptr[CPartitioningFactory]& sp)

    # Static constructor from a C++ shared_ptr and the two Python objects.
    @staticmethod
    cdef wrap(const shared_ptr[CPartitioningFactory]& sp,
              object constructor, object options)

    cdef inline shared_ptr[CPartitioningFactory] unwrap(self)
# Extension-type declaration describing a written file. All three attributes
# are ``public``, i.e. readable and writable from Python.
cdef class WrittenFile(_Weakrefable):

    # The full path to the created file
    cdef public str path

    # Optional Parquet metadata
    # This metadata will have the file path attribute set to the path of
    # the written file.
    cdef public object metadata

    # The size of the file in bytes
    cdef public int64_t size

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,51 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# cython: language_level = 3
"""Dataset support for ORC file format."""
from pyarrow.lib cimport *
from pyarrow.includes.libarrow cimport *
from pyarrow.includes.libarrow_dataset cimport *
from pyarrow._dataset cimport FileFormat
cdef class OrcFileFormat(FileFormat):
    """FileFormat backed by a C++ ``COrcFileFormat``; takes no options."""

    def __init__(self):
        cdef shared_ptr[CFileFormat] sp_format = shared_ptr[CFileFormat](
            new COrcFileFormat())
        self.init(sp_format)

    def equals(self, OrcFileFormat other):
        """
        Compare with another OrcFileFormat.

        Parameters
        ----------
        other : pyarrow.dataset.OrcFileFormat

        Returns
        -------
        True
        """
        # Always reports equality.
        return True

    @property
    def default_extname(self):
        return "orc"

    def __reduce__(self):
        # Pickle support: rebuild by calling the class with no arguments.
        return (OrcFileFormat, ())

View File

@@ -0,0 +1,43 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# cython: language_level = 3
"""Dataset support for Parquet file format."""
from pyarrow.includes.libarrow_dataset cimport *
from pyarrow.includes.libarrow_dataset_parquet cimport *
from pyarrow._dataset cimport FragmentScanOptions, FileWriteOptions
# FragmentScanOptions subclass declaration for Parquet: adds a raw
# ``CParquetFragmentScanOptions`` pointer and two Python-object slots.
cdef class ParquetFragmentScanOptions(FragmentScanOptions):
    cdef:
        CParquetFragmentScanOptions* parquet_options
        # Python-level objects; presumably kept so the decryption settings
        # stay referenced from this instance (confirm in implementation).
        object _parquet_decryption_config
        object _decryption_properties

    cdef void init(self, const shared_ptr[CFragmentScanOptions]& sp)

    # Accessors returning raw pointers to the C++ reader property objects.
    cdef CReaderProperties* reader_properties(self)
    cdef ArrowReaderProperties* arrow_reader_properties(self)
# FileWriteOptions subclass declaration for Parquet: adds a raw
# ``CParquetFileWriteOptions`` pointer and one Python-object slot.
cdef class ParquetFileWriteOptions(FileWriteOptions):
    cdef:
        CParquetFileWriteOptions* parquet_options
        # Python-level object; its contents are defined by the implementation.
        object _properties

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,178 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# cython: language_level = 3
"""Dataset support for Parquet encryption."""
from pyarrow.includes.libarrow_dataset_parquet cimport *
from pyarrow._parquet_encryption cimport *
from pyarrow._dataset_parquet cimport ParquetFragmentScanOptions, ParquetFileWriteOptions
cdef class ParquetEncryptionConfig(_Weakrefable):
    """
    Core configuration class encapsulating parameters for high-level encryption
    within the Parquet framework.

    The ParquetEncryptionConfig class serves as a bridge for passing encryption-related
    parameters to the appropriate components within the Parquet library. It maintains references
    to objects that define the encryption strategy, Key Management Service (KMS) configuration,
    and specific encryption configurations for Parquet data.

    Parameters
    ----------
    crypto_factory : pyarrow.parquet.encryption.CryptoFactory
        Shared pointer to a `CryptoFactory` object. The `CryptoFactory` is responsible for
        creating cryptographic components, such as encryptors and decryptors.
    kms_connection_config : pyarrow.parquet.encryption.KmsConnectionConfig
        Shared pointer to a `KmsConnectionConfig` object. This object holds the configuration
        parameters necessary for connecting to a Key Management Service (KMS).
    encryption_config : pyarrow.parquet.encryption.EncryptionConfiguration
        Shared pointer to an `EncryptionConfiguration` object. This object defines specific
        encryption settings for Parquet data, including the keys assigned to different columns.

    Raises
    ------
    ValueError
        Raised if `crypto_factory`, `kms_connection_config` or
        `encryption_config` is None.
    """
    cdef:
        shared_ptr[CParquetEncryptionConfig] c_config

    # Avoid mistakenly creating attributes
    __slots__ = ()

    def __cinit__(self, CryptoFactory crypto_factory, KmsConnectionConfig kms_connection_config,
                  EncryptionConfiguration encryption_config):
        cdef:
            shared_ptr[CEncryptionConfiguration] c_encryption_config
            CParquetEncryptionConfig* raw_config

        # Reject missing arguments up front, in declaration order.
        if crypto_factory is None:
            raise ValueError("crypto_factory cannot be None")
        if kms_connection_config is None:
            raise ValueError("kms_connection_config cannot be None")
        if encryption_config is None:
            raise ValueError("encryption_config cannot be None")

        self.c_config.reset(new CParquetEncryptionConfig())
        raw_config = self.c_config.get()

        # Fill the freshly allocated C++ struct field by field.
        c_encryption_config = pyarrow_unwrap_encryptionconfig(
            encryption_config)
        raw_config.crypto_factory = pyarrow_unwrap_cryptofactory(crypto_factory)
        raw_config.kms_connection_config = pyarrow_unwrap_kmsconnectionconfig(
            kms_connection_config)
        raw_config.encryption_config = c_encryption_config

    @staticmethod
    cdef wrap(shared_ptr[CParquetEncryptionConfig] c_config):
        # NOTE(review): ``__new__`` is called without the three arguments
        # that ``__cinit__`` declares -- confirm this path works as intended.
        cdef ParquetEncryptionConfig wrapper = ParquetEncryptionConfig.__new__(ParquetEncryptionConfig)
        wrapper.c_config = c_config
        return wrapper

    cdef shared_ptr[CParquetEncryptionConfig] unwrap(self):
        return self.c_config
cdef class ParquetDecryptionConfig(_Weakrefable):
    """
    Core configuration class encapsulating parameters for high-level decryption
    within the Parquet framework.

    ParquetDecryptionConfig is designed to pass decryption-related parameters to
    the appropriate decryption components within the Parquet library. It holds references to
    objects that define the decryption strategy, Key Management Service (KMS) configuration,
    and specific decryption configurations for reading encrypted Parquet data.

    Parameters
    ----------
    crypto_factory : pyarrow.parquet.encryption.CryptoFactory
        Shared pointer to a `CryptoFactory` object, pivotal in creating cryptographic
        components for the decryption process.
    kms_connection_config : pyarrow.parquet.encryption.KmsConnectionConfig
        Shared pointer to a `KmsConnectionConfig` object, containing parameters necessary
        for connecting to a Key Management Service (KMS) during decryption.
    decryption_config : pyarrow.parquet.encryption.DecryptionConfiguration
        Shared pointer to a `DecryptionConfiguration` object, specifying decryption settings
        for reading encrypted Parquet data.

    Raises
    ------
    ValueError
        Raised if `decryption_config` is None.
    """
    cdef:
        shared_ptr[CParquetDecryptionConfig] c_config

    # Avoid mistakenly creating attributes
    __slots__ = ()

    def __cinit__(self, CryptoFactory crypto_factory, KmsConnectionConfig kms_connection_config,
                  DecryptionConfiguration decryption_config):
        cdef:
            shared_ptr[CDecryptionConfiguration] c_decryption_config
            CParquetDecryptionConfig* raw_config

        # NOTE(review): unlike ParquetEncryptionConfig, crypto_factory and
        # kms_connection_config are not checked for None here; presumably the
        # unwrap helpers reject them -- confirm.
        if decryption_config is None:
            raise ValueError(
                "decryption_config cannot be None")

        self.c_config.reset(new CParquetDecryptionConfig())
        raw_config = self.c_config.get()

        # Fill the freshly allocated C++ struct field by field.
        c_decryption_config = pyarrow_unwrap_decryptionconfig(
            decryption_config)
        raw_config.crypto_factory = pyarrow_unwrap_cryptofactory(crypto_factory)
        raw_config.kms_connection_config = pyarrow_unwrap_kmsconnectionconfig(
            kms_connection_config)
        raw_config.decryption_config = c_decryption_config

    @staticmethod
    cdef wrap(shared_ptr[CParquetDecryptionConfig] c_config):
        # NOTE(review): ``__new__`` is called without the three arguments
        # that ``__cinit__`` declares -- confirm this path works as intended.
        cdef ParquetDecryptionConfig wrapper = ParquetDecryptionConfig.__new__(ParquetDecryptionConfig)
        wrapper.c_config = c_config
        return wrapper

    cdef shared_ptr[CParquetDecryptionConfig] unwrap(self):
        return self.c_config
def set_encryption_config(
    ParquetFileWriteOptions opts not None,
    ParquetEncryptionConfig config not None
):
    """
    Store the C++ encryption config held by ``config`` on the Parquet
    write options wrapped by ``opts``.

    Parameters
    ----------
    opts : ParquetFileWriteOptions
        Write options to modify in place. Must not be None.
    config : ParquetEncryptionConfig
        Encryption configuration to attach. Must not be None.
    """
    opts.parquet_options.parquet_encryption_config = config.unwrap()
def set_decryption_properties(
    ParquetFragmentScanOptions opts not None,
    FileDecryptionProperties config not None
):
    """
    Pass file decryption properties to the reader properties of scan options.

    Parameters
    ----------
    opts : ParquetFragmentScanOptions
        Scan options whose reader properties are updated. Must not be None.
    config : FileDecryptionProperties
        Decryption properties to unwrap and hand over. Must not be None.
    """
    opts.reader_properties().file_decryption_properties(config.unwrap())
def set_decryption_config(
    ParquetFragmentScanOptions opts not None,
    ParquetDecryptionConfig config not None
):
    """
    Store a decryption configuration on Parquet fragment scan options.

    Parameters
    ----------
    opts : ParquetFragmentScanOptions
        Options whose ``parquet_options.parquet_decryption_config`` field
        is overwritten. Must not be None.
    config : ParquetDecryptionConfig
        Configuration to unwrap and store. Must not be None.
    """
    opts.parquet_options.parquet_decryption_config = config.unwrap()

View File

@@ -0,0 +1,46 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
cimport cpython
from cpython.pycapsule cimport PyCapsule_New
cdef void dlpack_pycapsule_deleter(object dltensor) noexcept:
    # Release the DLManagedTensor stored in a 'dltensor' capsule.
    # NOTE(review): presumably installed as the capsule destructor at the
    # PyCapsule_New call site -- confirm there.
    cdef:
        DLManagedTensor* managed
        PyObject* saved_type
        PyObject* saved_value
        PyObject* saved_traceback

    # Do nothing if the capsule has been consumed
    if cpython.PyCapsule_IsValid(dltensor, "used_dltensor"):
        return

    # An exception may be in-flight, we must save it in case
    # we create another one
    cpython.PyErr_Fetch(&saved_type, &saved_value, &saved_traceback)

    managed = <DLManagedTensor*>cpython.PyCapsule_GetPointer(dltensor, 'dltensor')
    if managed == NULL:
        # The pointer lookup failed; report that error as unraisable
        # since this function cannot propagate exceptions (noexcept).
        cpython.PyErr_WriteUnraisable(dltensor)
    # The deleter can be NULL if there is no way for the caller
    # to provide a reasonable destructor
    elif managed.deleter:
        managed.deleter(managed)
        assert (not cpython.PyErr_Occurred())

    # Set the error indicator from the values saved above
    cpython.PyErr_Restore(saved_type, saved_value, saved_traceback)

View File

@@ -0,0 +1,117 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# ---------------------------------------------------------------------
# Implement Feather file format
# cython: profile=False
# distutils: language = c++
# cython: language_level=3
from cython.operator cimport dereference as deref
from pyarrow.includes.common cimport *
from pyarrow.includes.libarrow cimport *
from pyarrow.includes.libarrow_feather cimport *
from pyarrow.lib cimport (check_status, Table, _Weakrefable,
get_writer, get_reader, pyarrow_wrap_table)
from pyarrow.lib import tobytes
class FeatherError(Exception):
    """Exception type defined by the Feather file format module."""
def write_feather(Table table, object dest, compression=None,
                  compression_level=None, chunksize=None, version=2):
    """
    Write a Table to ``dest`` in the Feather format.

    Parameters
    ----------
    table : Table
        The table to write.
    dest : object
        Destination handed to ``get_writer`` to obtain the output stream.
    compression : str, default None
        ``'zstd'`` or ``'lz4'`` select the matching codec; any other value
        (including None) writes uncompressed data.
    compression_level : int, default None
        Forwarded to the writer properties when not None.
    chunksize : int, default None
        Forwarded to the writer properties when not None.
    version : int, default 2
        ``2`` selects Feather V2; any other value selects Feather V1.
    """
    cdef:
        shared_ptr[COutputStream] sink
        CFeatherProperties properties

    # Open the destination first, exactly as before, so that failures to
    # open it surface ahead of any option conversion errors.
    get_writer(dest, &sink)

    if version == 2:
        properties.version = kFeatherV2Version
    else:
        properties.version = kFeatherV1Version

    if compression == 'zstd':
        properties.compression = CCompressionType_ZSTD
    elif compression == 'lz4':
        properties.compression = CCompressionType_LZ4_FRAME
    else:
        properties.compression = CCompressionType_UNCOMPRESSED

    # Optional tuning knobs: only override the defaults when provided.
    if chunksize is not None:
        properties.chunksize = chunksize
    if compression_level is not None:
        properties.compression_level = compression_level

    # The actual write runs without the GIL.
    with nogil:
        check_status(WriteFeather(deref(table.table), sink.get(),
                                  properties))
cdef class FeatherReader(_Weakrefable):
    """
    Reader for Feather files backed by the C++ ``CFeatherReader``.

    Parameters
    ----------
    source : object
        Input handed to ``get_reader`` to obtain a random access file.
    use_memory_map : bool
        Forwarded to ``get_reader``.
    use_threads : bool
        Stored in the IPC read options used to open the file.
    """
    cdef shared_ptr[CFeatherReader] reader

    def __cinit__(self, source, c_bool use_memory_map, c_bool use_threads):
        cdef:
            shared_ptr[CRandomAccessFile] c_file
            CIpcReadOptions c_options = CIpcReadOptions.Defaults()

        c_options.use_threads = use_threads
        get_reader(source, use_memory_map, &c_file)

        # Opening the file runs without the GIL.
        with nogil:
            self.reader = GetResultValue(CFeatherReader.Open(c_file, c_options))

    @property
    def version(self):
        """The version reported by the underlying reader."""
        return self.reader.get().version()

    def read(self):
        """Read the whole file and return it as a Table."""
        cdef shared_ptr[CTable] c_table

        with nogil:
            check_status(self.reader.get()
                         .Read(&c_table))

        return pyarrow_wrap_table(c_table)

    def read_indices(self, indices):
        """
        Read a selection given by integer indices and return it as a Table.

        Parameters
        ----------
        indices : iterable of int
            Indices forwarded to the underlying reader.
        """
        cdef:
            shared_ptr[CTable] c_table
            vector[int] c_selection

        for position in indices:
            c_selection.push_back(position)

        with nogil:
            check_status(self.reader.get()
                         .Read(c_selection, &c_table))

        return pyarrow_wrap_table(c_table)

    def read_names(self, names):
        """
        Read a selection given by names and return it as a Table.

        Parameters
        ----------
        names : iterable of str or bytes
            Names, each converted with ``tobytes`` before being forwarded
            to the underlying reader.
        """
        cdef:
            shared_ptr[CTable] c_table
            vector[c_string] c_selection

        for label in names:
            c_selection.push_back(tobytes(label))

        with nogil:
            check_status(self.reader.get()
                         .Read(c_selection, &c_table))

        return pyarrow_wrap_table(c_table)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,91 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# cython: language_level = 3
from pyarrow.includes.common cimport *
from pyarrow.includes.libarrow_fs cimport *
from pyarrow.lib import _detect_compression, frombytes, tobytes
from pyarrow.lib cimport *
# Python-visible enum whose members mirror the C++ CFileType_* constants,
# each cast to int8_t.
cpdef enum FileType:
    NotFound = <int8_t> CFileType_NotFound
    Unknown = <int8_t> CFileType_Unknown
    File = <int8_t> CFileType_File
    Directory = <int8_t> CFileType_Directory
cdef class FileInfo(_Weakrefable):
    # Declaration only: the wrapper stores a CFileInfo by value.
    cdef CFileInfo info

    # Create a Python wrapper from a C++ CFileInfo.
    @staticmethod
    cdef wrap(CFileInfo info)

    # Return the wrapped CFileInfo; declared callable without the GIL.
    cdef inline CFileInfo unwrap(self) nogil

    # Obtain a CFileInfo from an arbitrary Python object.
    # NOTE(review): presumably validates the type of `obj` -- see the
    # implementation file to confirm.
    @staticmethod
    cdef CFileInfo unwrap_safe(obj)
cdef class FileSelector(_Weakrefable):
    # Declaration only: the wrapper stores a CFileSelector by value.
    cdef CFileSelector selector

    # Create a Python wrapper from a C++ CFileSelector.
    @staticmethod
    cdef FileSelector wrap(CFileSelector selector)

    # Return the wrapped CFileSelector; declared callable without the GIL.
    cdef inline CFileSelector unwrap(self) nogil
cdef class FileSystem(_Weakrefable):
    # Declaration only: a shared pointer to the C++ filesystem plus a raw
    # CFileSystem pointer.
    cdef shared_ptr[CFileSystem] wrapped
    cdef CFileSystem* fs

    # Initialise the wrapper from a shared pointer; subclasses declare
    # their own init with the same signature.
    cdef init(self, const shared_ptr[CFileSystem]& wrapped)

    # Create a Python wrapper from a shared pointer.
    @staticmethod
    cdef wrap(const shared_ptr[CFileSystem]& sp)

    # Return the shared pointer; declared callable without the GIL.
    cdef inline shared_ptr[CFileSystem] unwrap(self) nogil
cdef class LocalFileSystem(FileSystem):
    # Declaration only: FileSystem subclass that declares its own init;
    # no additional fields.
    cdef init(self, const shared_ptr[CFileSystem]& wrapped)
cdef class SubTreeFileSystem(FileSystem):
    # Declaration only: adds a raw pointer typed as CSubTreeFileSystem.
    cdef CSubTreeFileSystem* subtreefs

    cdef init(self, const shared_ptr[CFileSystem]& wrapped)
cdef class _MockFileSystem(FileSystem):
    # Declaration only: adds a raw pointer typed as CMockFileSystem.
    cdef CMockFileSystem* mockfs

    cdef init(self, const shared_ptr[CFileSystem]& wrapped)
cdef class PyFileSystem(FileSystem):
    # Declaration only: adds a raw pointer typed as CPyFileSystem.
    cdef CPyFileSystem* pyfs

    cdef init(self, const shared_ptr[CFileSystem]& wrapped)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,209 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# cython: language_level = 3
from pyarrow.lib cimport (pyarrow_wrap_metadata,
pyarrow_unwrap_metadata)
from pyarrow.lib import frombytes, tobytes, ensure_metadata
from pyarrow.includes.common cimport *
from pyarrow.includes.libarrow cimport *
from pyarrow.includes.libarrow_fs cimport *
from pyarrow._fs cimport FileSystem, TimePoint_to_ns, PyDateTime_to_TimePoint
from datetime import datetime, timedelta, timezone
cdef class GcsFileSystem(FileSystem):
    """
    Google Cloud Storage (GCS) backed FileSystem implementation

    By default uses the process described in https://google.aip.dev/auth/4110
    to resolve credentials. If not running on Google Cloud Platform (GCP),
    this generally requires the environment variable
    GOOGLE_APPLICATION_CREDENTIALS to point to a JSON file
    containing credentials.

    Note: GCS buckets are special and the operations available on them may be
    limited or more expensive than expected compared to local file systems.

    Note: When pickling a GcsFileSystem that uses default credentials, resolution
    credentials are not stored in the serialized data. Therefore, when unpickling
    it is assumed that the necessary credentials are in place for the target
    process.

    Parameters
    ----------
    anonymous : boolean, default False
        Whether to connect anonymously.
        If true, will not attempt to look up credentials using standard GCP
        configuration methods.
    access_token : str, default None
        GCP access token. If provided, temporary credentials will be fetched by
        assuming this role; also, a `credential_token_expiration` must be
        specified as well.
    target_service_account : str, default None
        An optional service account to try to impersonate when accessing GCS. This
        requires the specified credential user or service account to have the necessary
        permissions.
    credential_token_expiration : datetime, default None
        Expiration for credential generated with an access token. Must be specified
        if `access_token` is specified.
    default_bucket_location : str, default 'US'
        GCP region to create buckets in.
    scheme : str, default 'https'
        GCS connection transport scheme.
    endpoint_override : str, default None
        Override endpoint with a connect string such as "localhost:9000"
    default_metadata : mapping or pyarrow.KeyValueMetadata, default None
        Default metadata for `open_output_stream`. This will be ignored if
        non-empty metadata is passed to `open_output_stream`.
    retry_time_limit : timedelta, default None
        Set the maximum amount of time the GCS client will attempt to retry
        transient errors. Subsecond granularity is ignored.
    project_id : str, default None
        The GCP project identifier to use for creating buckets.
        If not set, the library uses the GOOGLE_CLOUD_PROJECT environment
        variable. Most I/O operations do not need a project id, only applications
        that create new buckets need a project id.
    """

    cdef CGcsFileSystem* gcsfs

    def __init__(self, *, bint anonymous=False, access_token=None,
                 target_service_account=None, credential_token_expiration=None,
                 default_bucket_location='US',
                 scheme=None,
                 endpoint_override=None,
                 default_metadata=None,
                 retry_time_limit=None,
                 project_id=None):
        cdef:
            CGcsOptions options
            shared_ptr[CGcsFileSystem] wrapped
            double time_limit_seconds

        # Intentional use of truthiness because empty strings aren't valid and
        # for reconstruction from pickling will give empty strings.
        if anonymous and (target_service_account or access_token):
            raise ValueError(
                'anonymous option is not compatible with target_service_account and '
                'access_token'
            )
        if bool(access_token) != bool(credential_token_expiration):
            raise ValueError(
                'access_token and credential_token_expiration must be '
                'specified together'
            )

        # Pick the base credentials.
        if anonymous:
            options = CGcsOptions.Anonymous()
        elif access_token:
            if not isinstance(credential_token_expiration, datetime):
                raise ValueError(
                    "credential_token_expiration must be a datetime")
            options = CGcsOptions.FromAccessToken(
                tobytes(access_token),
                PyDateTime_to_TimePoint(<PyDateTime_DateTime*>credential_token_expiration))
        else:
            options = CGcsOptions.Defaults()

        # Target service account requires base credentials so
        # it is not part of the if/else chain above which only
        # handles base credentials.
        if target_service_account:
            options = CGcsOptions.FromImpersonatedServiceAccount(
                options.credentials, tobytes(target_service_account))

        options.default_bucket_location = tobytes(default_bucket_location)

        # Remaining settings only override the defaults when provided.
        if scheme is not None:
            options.scheme = tobytes(scheme)
        if endpoint_override is not None:
            options.endpoint_override = tobytes(endpoint_override)
        if default_metadata is not None:
            options.default_metadata = pyarrow_unwrap_metadata(
                ensure_metadata(default_metadata))
        if retry_time_limit is not None:
            time_limit_seconds = retry_time_limit.total_seconds()
            options.retry_limit_seconds = time_limit_seconds
        if project_id is not None:
            options.project_id = <c_string>tobytes(project_id)

        # Creating the filesystem runs without the GIL.
        with nogil:
            wrapped = GetResultValue(CGcsFileSystem.Make(options))

        self.init(<shared_ptr[CFileSystem]> wrapped)

    cdef init(self, const shared_ptr[CFileSystem]& wrapped):
        # Let the base class store the shared pointer, then keep a raw
        # pointer cast to the GCS-specific type.
        FileSystem.init(self, wrapped)
        self.gcsfs = <CGcsFileSystem*> wrapped.get()

    def _expiration_datetime_from_options(self):
        """
        Return the credential expiration as a UTC datetime.

        Returns None when the expiration stored in the options converts
        to 0 nanoseconds.
        """
        nanos = TimePoint_to_ns(
            self.gcsfs.options().credentials.expiration())
        if nanos == 0:
            return None
        return datetime.fromtimestamp(nanos / 1.0e9, timezone.utc)

    @staticmethod
    def _reconstruct(kwargs):
        # __reduce__ doesn't allow passing named arguments directly to the
        # reconstructor, hence this wrapper.
        return GcsFileSystem(**kwargs)

    def __reduce__(self):
        cdef CGcsOptions opts = self.gcsfs.options()

        service_account = frombytes(opts.credentials.target_service_account())
        expiration_dt = self._expiration_datetime_from_options()

        # Optional values are only converted when they are set.
        retry_time_limit = None
        if opts.retry_limit_seconds.has_value():
            retry_time_limit = timedelta(
                seconds=opts.retry_limit_seconds.value())

        project_id = None
        if opts.project_id.has_value():
            project_id = frombytes(opts.project_id.value())

        # Keyword arguments for __init__, rebuilt from the stored options.
        state = {
            'access_token': frombytes(opts.credentials.access_token()),
            'anonymous': opts.credentials.anonymous(),
            'credential_token_expiration': expiration_dt,
            'target_service_account': service_account,
            'scheme': frombytes(opts.scheme),
            'endpoint_override': frombytes(opts.endpoint_override),
            'default_bucket_location': frombytes(
                opts.default_bucket_location),
            'default_metadata': pyarrow_wrap_metadata(opts.default_metadata),
            'retry_time_limit': retry_time_limit,
            'project_id': project_id,
        }
        return (GcsFileSystem._reconstruct, (state,))

    @property
    def default_bucket_location(self):
        """
        The GCP location this filesystem will write to.
        """
        return frombytes(self.gcsfs.options().default_bucket_location)

    @property
    def project_id(self):
        """
        The GCP project id this filesystem will use.

        None when no project id is set in the options.
        """
        if not self.gcsfs.options().project_id.has_value():
            return None
        return frombytes(self.gcsfs.options().project_id.value())

View File

@@ -0,0 +1,21 @@
# file generated by setuptools-scm
# don't change, don't track in version control

__all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]

# Kept as a plain module-level flag so that `typing` is never imported
# at runtime; only the else-branch below executes.
TYPE_CHECKING = False
if TYPE_CHECKING:
    from typing import Tuple
    from typing import Union

    VERSION_TUPLE = Tuple[Union[int, str], ...]
else:
    VERSION_TUPLE = object

# Annotations for the public names assigned below.
version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE

# The version as a string and as a tuple, each under two aliases.
__version__ = version = '21.0.0'
__version_tuple__ = version_tuple = (21, 0, 0)

View File

@@ -0,0 +1,157 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# cython: language_level = 3
from pyarrow.includes.common cimport *
from pyarrow.includes.libarrow cimport *
from pyarrow.includes.libarrow_fs cimport *
from pyarrow._fs cimport FileSystem
from pyarrow.lib import frombytes, tobytes
from pyarrow.util import _stringify_path
cdef class HadoopFileSystem(FileSystem):
    """
    HDFS backed FileSystem implementation

    Parameters
    ----------
    host : str
        HDFS host to connect to. Set to "default" for fs.defaultFS from
        core-site.xml.
    port : int, default 8020
        HDFS port to connect to. Set to 0 for default or logical (HA) nodes.
    user : str, default None
        Username when connecting to HDFS; None implies login user.
    replication : int, default 3
        Number of copies each block will have.
    buffer_size : int, default 0
        If 0, no buffering will happen otherwise the size of the temporary read
        and write buffer.
    default_block_size : int, default None
        None means the default configuration for HDFS, a typical block size is
        128 MB.
    kerb_ticket : string or path, default None
        If not None, the path to the Kerberos ticket cache.
    extra_conf : dict, default None
        Extra key/value pairs for configuration; will override any
        hdfs-site.xml properties.

    Examples
    --------
    >>> from pyarrow import fs
    >>> hdfs = fs.HadoopFileSystem(host, port, user=user, kerb_ticket=ticket_cache_path) # doctest: +SKIP

    For usage of the methods see examples for :func:`~pyarrow.fs.LocalFileSystem`.
    """

    cdef CHadoopFileSystem* hdfs

    def __init__(self, str host, int port=8020, *, str user=None,
                 int replication=3, int buffer_size=0,
                 default_block_size=None, kerb_ticket=None,
                 extra_conf=None):
        cdef:
            CHdfsOptions options
            shared_ptr[CHadoopFileSystem] wrapped

        # Hosts without a recognised scheme get 'hdfs://' prepended, except
        # for the special value "default".
        if host != "default" and not host.startswith(('hdfs://', 'viewfs://')):
            # TODO(kszucs): do more sanitization
            host = f'hdfs://{host}'

        options.ConfigureEndPoint(tobytes(host), port)
        options.ConfigureReplication(replication)
        options.ConfigureBufferSize(buffer_size)

        # Optional settings are only applied when provided.
        if user is not None:
            options.ConfigureUser(tobytes(user))
        if default_block_size is not None:
            options.ConfigureBlockSize(default_block_size)
        if kerb_ticket is not None:
            options.ConfigureKerberosTicketCachePath(
                tobytes(_stringify_path(kerb_ticket)))
        if extra_conf is not None:
            for key, value in extra_conf.items():
                options.ConfigureExtraConf(tobytes(key), tobytes(value))

        # Creating the filesystem runs without the GIL.
        with nogil:
            wrapped = GetResultValue(CHadoopFileSystem.Make(options))
        self.init(<shared_ptr[CFileSystem]> wrapped)

    cdef init(self, const shared_ptr[CFileSystem]& wrapped):
        # Let the base class store the shared pointer, then keep a raw
        # pointer cast to the HDFS-specific type.
        FileSystem.init(self, wrapped)
        self.hdfs = <CHadoopFileSystem*> wrapped.get()

    @staticmethod
    def from_uri(uri):
        """
        Instantiate HadoopFileSystem object from an URI string.

        The following two calls are equivalent

        * ``HadoopFileSystem.from_uri('hdfs://localhost:8020/?user=test&replication=1')``
        * ``HadoopFileSystem('localhost', port=8020, user='test', replication=1)``

        Parameters
        ----------
        uri : str
            A string URI describing the connection to HDFS.
            In order to change the user, replication, buffer_size or
            default_block_size pass the values as query parts.

        Returns
        -------
        HadoopFileSystem
        """
        cdef:
            HadoopFileSystem self = HadoopFileSystem.__new__(HadoopFileSystem)
            shared_ptr[CHadoopFileSystem] wrapped
            CHdfsOptions options

        # The options are parsed from the URI; __init__ is bypassed.
        options = GetResultValue(CHdfsOptions.FromUriString(tobytes(uri)))
        with nogil:
            wrapped = GetResultValue(CHadoopFileSystem.Make(options))

        self.init(<shared_ptr[CFileSystem]> wrapped)
        return self

    @staticmethod
    def _reconstruct(kwargs):
        # __reduce__ doesn't allow passing named arguments directly to the
        # reconstructor, hence this wrapper.
        return HadoopFileSystem(**kwargs)

    def __reduce__(self):
        cdef CHdfsOptions opts = self.hdfs.options()

        # Keyword arguments for __init__, rebuilt from the stored options.
        state = {
            'host': frombytes(opts.connection_config.host),
            'port': opts.connection_config.port,
            'user': frombytes(opts.connection_config.user),
            'replication': opts.replication,
            'buffer_size': opts.buffer_size,
            'default_block_size': opts.default_block_size,
            'kerb_ticket': frombytes(opts.connection_config.kerb_ticket),
            'extra_conf': {frombytes(k): frombytes(v)
                           for k, v in opts.connection_config.extra_conf},
        }
        return (HadoopFileSystem._reconstruct, (state,))

View File

@@ -0,0 +1,36 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# cython: language_level = 3
from pyarrow.includes.libarrow cimport *
from pyarrow.lib cimport _Weakrefable
cdef class ParseOptions(_Weakrefable):
    # Declaration only: the wrapper stores a CJSONParseOptions by value.
    cdef CJSONParseOptions options

    # Create a Python wrapper from a C++ CJSONParseOptions.
    @staticmethod
    cdef ParseOptions wrap(CJSONParseOptions options)
cdef class ReadOptions(_Weakrefable):
    # Declaration only: the wrapper stores a CJSONReadOptions by value.
    cdef CJSONReadOptions options

    # Create a Python wrapper from a C++ CJSONReadOptions.
    @staticmethod
    cdef ReadOptions wrap(CJSONReadOptions options)

View File

@@ -0,0 +1,386 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# cython: profile=False
# distutils: language = c++
# cython: language_level = 3
from pyarrow.includes.common cimport *
from pyarrow.includes.libarrow cimport *
from pyarrow.lib cimport (_Weakrefable, Schema,
RecordBatchReader, MemoryPool,
maybe_unbox_memory_pool,
get_input_stream, pyarrow_wrap_table,
pyarrow_wrap_schema, pyarrow_unwrap_schema)
cdef class ReadOptions(_Weakrefable):
    """
    Options for reading JSON files.

    Parameters
    ----------
    use_threads : bool, optional (default True)
        Whether to use multiple threads to accelerate reading
    block_size : int, optional
        How many bytes to process at a time from the input stream.
        This will determine multi-threading granularity as well as
        the size of individual chunks in the Table.
    """

    # Avoid mistakenly creating attributes
    __slots__ = ()

    def __init__(self, use_threads=None, block_size=None):
        # Start from the C++ defaults and override only what was given.
        self.options = CJSONReadOptions.Defaults()
        if use_threads is not None:
            self.use_threads = use_threads
        if block_size is not None:
            self.block_size = block_size

    @property
    def use_threads(self):
        """
        Whether to use multiple threads to accelerate reading.
        """
        return self.options.use_threads

    @use_threads.setter
    def use_threads(self, value):
        self.options.use_threads = value

    @property
    def block_size(self):
        """
        How many bytes to process at a time from the input stream.

        This will determine multi-threading granularity as well as the size of
        individual chunks in the Table.
        """
        return self.options.block_size

    @block_size.setter
    def block_size(self, value):
        self.options.block_size = value

    def __reduce__(self):
        # Pickle as a constructor call with the two current settings.
        ctor_args = (self.use_threads, self.block_size)
        return ReadOptions, ctor_args

    def equals(self, ReadOptions other):
        """
        Compare both settings with those of another ReadOptions.

        Parameters
        ----------
        other : pyarrow.json.ReadOptions

        Returns
        -------
        bool
        """
        if not (self.use_threads == other.use_threads):
            return False
        return self.block_size == other.block_size

    def __eq__(self, other):
        # equals() raises TypeError for arguments of another type; treat
        # that as "not equal".
        try:
            return self.equals(other)
        except TypeError:
            return False

    @staticmethod
    cdef ReadOptions wrap(CJSONReadOptions options):
        cdef ReadOptions wrapped = ReadOptions()
        wrapped.options = options  # shallow copy
        return wrapped
cdef class ParseOptions(_Weakrefable):
    """
    Options for parsing JSON files.

    Parameters
    ----------
    explicit_schema : Schema, optional (default None)
        Optional explicit schema (no type inference, ignores other fields).
    newlines_in_values : bool, optional (default False)
        Whether objects may be printed across multiple lines (for example
        pretty printed). If false, input must end with an empty line.
    unexpected_field_behavior : str, default "infer"
        How JSON fields outside of explicit_schema (if given) are treated.

        Possible behaviors:

        - "ignore": unexpected JSON fields are ignored
        - "error": error out on unexpected JSON fields
        - "infer": unexpected JSON fields are type-inferred and included in
          the output
    """

    # Avoid mistakenly creating attributes
    __slots__ = ()

    def __init__(self, explicit_schema=None, newlines_in_values=None,
                 unexpected_field_behavior=None):
        # Start from the C++ defaults and override only what was given.
        self.options = CJSONParseOptions.Defaults()
        if explicit_schema is not None:
            self.explicit_schema = explicit_schema
        if newlines_in_values is not None:
            self.newlines_in_values = newlines_in_values
        if unexpected_field_behavior is not None:
            self.unexpected_field_behavior = unexpected_field_behavior

    def __reduce__(self):
        # Pickle as a constructor call with the three current settings.
        ctor_args = (
            self.explicit_schema,
            self.newlines_in_values,
            self.unexpected_field_behavior
        )
        return ParseOptions, ctor_args

    @property
    def explicit_schema(self):
        """
        Optional explicit schema (no type inference, ignores other fields)
        """
        # A null shared pointer means no schema was set.
        if self.options.explicit_schema.get() == NULL:
            return None
        return pyarrow_wrap_schema(self.options.explicit_schema)

    @explicit_schema.setter
    def explicit_schema(self, value):
        self.options.explicit_schema = pyarrow_unwrap_schema(value)

    @property
    def newlines_in_values(self):
        """
        Whether newline characters are allowed in JSON values.
        Setting this to True reduces the performance of multi-threaded
        JSON reading.
        """
        return self.options.newlines_in_values

    @newlines_in_values.setter
    def newlines_in_values(self, value):
        self.options.newlines_in_values = value

    @property
    def unexpected_field_behavior(self):
        """
        How JSON fields outside of explicit_schema (if given) are treated.

        Possible behaviors:

        - "ignore": unexpected JSON fields are ignored
        - "error": error out on unexpected JSON fields
        - "infer": unexpected JSON fields are type-inferred and included in
          the output

        Set to "infer" by default.
        """
        # Translate the stored C++ enum value into its string name.
        current = self.options.unexpected_field_behavior
        if current == CUnexpectedFieldBehavior_Ignore:
            return "ignore"
        if current == CUnexpectedFieldBehavior_Error:
            return "error"
        if current == CUnexpectedFieldBehavior_InferType:
            return "infer"
        raise ValueError('Unexpected value for unexpected_field_behavior')

    @unexpected_field_behavior.setter
    def unexpected_field_behavior(self, value):
        cdef CUnexpectedFieldBehavior c_behavior

        # Translate the string name into the C++ enum value; anything
        # else is rejected before the options are touched.
        if value == "ignore":
            c_behavior = CUnexpectedFieldBehavior_Ignore
        elif value == "error":
            c_behavior = CUnexpectedFieldBehavior_Error
        elif value == "infer":
            c_behavior = CUnexpectedFieldBehavior_InferType
        else:
            raise ValueError(
                f"Unexpected value `{value}` for `unexpected_field_behavior`, pass "
                f"either `ignore`, `error` or `infer`."
            )

        self.options.unexpected_field_behavior = c_behavior

    def equals(self, ParseOptions other):
        """
        Compare all three settings with those of another ParseOptions.

        Parameters
        ----------
        other : pyarrow.json.ParseOptions

        Returns
        -------
        bool
        """
        return (
            self.explicit_schema == other.explicit_schema and
            self.newlines_in_values == other.newlines_in_values and
            self.unexpected_field_behavior == other.unexpected_field_behavior
        )

    def __eq__(self, other):
        # equals() raises TypeError for arguments of another type; treat
        # that as "not equal".
        try:
            return self.equals(other)
        except TypeError:
            return False

    @staticmethod
    cdef ParseOptions wrap(CJSONParseOptions options):
        cdef ParseOptions wrapped = ParseOptions()
        wrapped.options = options  # shallow copy
        return wrapped
cdef _get_reader(input_file, shared_ptr[CInputStream]* out):
    # Open `input_file` as an input stream into `out`; memory mapping is
    # always disabled here.
    get_input_stream(input_file, False, out)
cdef _get_read_options(ReadOptions read_options, CJSONReadOptions* out):
    # Copy the caller's read options into `out`, falling back to the
    # C++ defaults when none were given.
    if read_options is not None:
        out[0] = read_options.options
    else:
        out[0] = CJSONReadOptions.Defaults()
cdef _get_parse_options(ParseOptions parse_options, CJSONParseOptions* out):
    # Copy the caller's parse options into `out`, falling back to the
    # C++ defaults when none were given.
    if parse_options is not None:
        out[0] = parse_options.options
    else:
        out[0] = CJSONParseOptions.Defaults()
cdef class JSONStreamingReader(RecordBatchReader):
    """An object that reads record batches incrementally from a JSON file.

    Should not be instantiated directly by user code.
    """

    # Schema of the stream, set by _open(); readable from Python.
    cdef readonly Schema schema

    def __init__(self):
        # Direct construction is rejected; open_json() creates instances
        # through __new__ and _open() instead.
        raise TypeError(f"Do not call {self.__class__.__name__}'s "
                        "constructor directly, "
                        "use pyarrow.json.open_json() instead.")

    cdef _open(self, shared_ptr[CInputStream] stream,
               CJSONReadOptions c_read_options,
               CJSONParseOptions c_parse_options,
               MemoryPool memory_pool):
        # Create the C++ streaming reader on `stream` and record its schema.
        cdef:
            shared_ptr[CSchema] c_schema
            CIOContext io_context = CIOContext(
                maybe_unbox_memory_pool(memory_pool))

        # Reader construction and the schema lookup run without the GIL.
        with nogil:
            self.reader = <shared_ptr[CRecordBatchReader]> GetResultValue(
                CJSONStreamingReader.Make(stream, move(c_read_options),
                                          move(c_parse_options), io_context))
            c_schema = self.reader.get().schema()

        self.schema = pyarrow_wrap_schema(c_schema)
def read_json(input_file, read_options=None, parse_options=None,
              MemoryPool memory_pool=None):
    """
    Read a Table from a stream of JSON data.

    Parameters
    ----------
    input_file : str, path or file-like object
        The location of JSON data. Currently only the line-delimited JSON
        format is supported.
    read_options : pyarrow.json.ReadOptions, optional
        Options for the JSON reader (see ReadOptions constructor for defaults).
    parse_options : pyarrow.json.ParseOptions, optional
        Options for the JSON parser
        (see ParseOptions constructor for defaults).
    memory_pool : MemoryPool, optional
        Pool to allocate Table memory from.

    Returns
    -------
    :class:`pyarrow.Table`
        Contents of the JSON file as an in-memory table.
    """
    cdef:
        shared_ptr[CInputStream] c_stream
        CJSONReadOptions c_read_options
        CJSONParseOptions c_parse_options
        shared_ptr[CJSONReader] c_reader
        shared_ptr[CTable] c_table

    # Resolve the input and both option sets into their C++ forms.
    _get_reader(input_file, &c_stream)
    _get_read_options(read_options, &c_read_options)
    _get_parse_options(parse_options, &c_parse_options)

    c_reader = GetResultValue(
        CJSONReader.Make(maybe_unbox_memory_pool(memory_pool),
                         c_stream, c_read_options, c_parse_options))

    # The actual read runs without the GIL.
    with nogil:
        c_table = GetResultValue(c_reader.get().Read())

    return pyarrow_wrap_table(c_table)
def open_json(input_file, read_options=None, parse_options=None,
              MemoryPool memory_pool=None):
    """
    Open a streaming reader of JSON data.

    Reading using this function is always single-threaded.

    Parameters
    ----------
    input_file : string, path or file-like object
        The location of JSON data. If a string or path, and if it ends
        with a recognized compressed file extension (e.g. ".gz" or ".bz2"),
        the data is automatically decompressed when reading.
    read_options : pyarrow.json.ReadOptions, optional
        Options for the JSON reader (see pyarrow.json.ReadOptions constructor
        for defaults)
    parse_options : pyarrow.json.ParseOptions, optional
        Options for the JSON parser
        (see pyarrow.json.ParseOptions constructor for defaults)
    memory_pool : MemoryPool, optional
        Pool to allocate RecordBatch memory from

    Returns
    -------
    :class:`pyarrow.json.JSONStreamingReader`
    """
    cdef:
        shared_ptr[CInputStream] c_stream
        CJSONReadOptions c_read_options
        CJSONParseOptions c_parse_options
        JSONStreamingReader streaming_reader

    # Resolve the input and both option sets into their C++ forms.
    _get_reader(input_file, &c_stream)
    _get_read_options(read_options, &c_read_options)
    _get_parse_options(parse_options, &c_parse_options)

    # __new__ bypasses __init__, which always raises for this class.
    streaming_reader = JSONStreamingReader.__new__(JSONStreamingReader)
    streaming_reader._open(c_stream, move(c_read_options),
                           move(c_parse_options), memory_pool)
    return streaming_reader

View File

@@ -0,0 +1,134 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# distutils: language = c++
# cython: language_level = 3
from libcpp cimport bool as c_bool
from libc.string cimport const_char
from libcpp.vector cimport vector as std_vector
from pyarrow.includes.common cimport *
from pyarrow.includes.libarrow cimport (CArray, CSchema, CStatus,
CResult, CTable, CMemoryPool,
CKeyValueMetadata,
CRecordBatch,
CTable, CCompressionType,
CRandomAccessFile, COutputStream,
TimeUnit)
# Cython-visible declarations for the C++ header
# "arrow/adapters/orc/options.h" (namespace arrow::adapters::orc).
# Everything declared here is marked `nogil`.  The quoted strings are the
# C++ names emitted in generated code; the identifiers on the left are the
# names used from Cython.
cdef extern from "arrow/adapters/orc/options.h" \
        namespace "arrow::adapters::orc" nogil:
    # Compression strategy: kSpeed or kCompression.
    cdef enum CompressionStrategy \
            " arrow::adapters::orc::CompressionStrategy":
        _CompressionStrategy_SPEED \
            " arrow::adapters::orc::CompressionStrategy::kSpeed"
        _CompressionStrategy_COMPRESSION \
            " arrow::adapters::orc::CompressionStrategy::kCompression"
    # Identifier of the software that wrote an ORC file.
    cdef enum WriterId" arrow::adapters::orc::WriterId":
        _WriterId_ORC_JAVA_WRITER" arrow::adapters::orc::WriterId::kOrcJava"
        _WriterId_ORC_CPP_WRITER" arrow::adapters::orc::WriterId::kOrcCpp"
        _WriterId_PRESTO_WRITER" arrow::adapters::orc::WriterId::kPresto"
        _WriterId_SCRITCHLEY_GO \
            " arrow::adapters::orc::WriterId::kScritchleyGo"
        _WriterId_TRINO_WRITER" arrow::adapters::orc::WriterId::kTrino"
        _WriterId_UNKNOWN_WRITER" arrow::adapters::orc::WriterId::kUnknown"
    # Writer version tags; kMax is exposed as _WriterVersion_MAX.
    cdef enum WriterVersion" arrow::adapters::orc::WriterVersion":
        _WriterVersion_ORIGINAL \
            " arrow::adapters::orc::WriterVersion::kOriginal"
        _WriterVersion_HIVE_8732 \
            " arrow::adapters::orc::WriterVersion::kHive8732"
        _WriterVersion_HIVE_4243 \
            " arrow::adapters::orc::WriterVersion::kHive4243"
        _WriterVersion_HIVE_12055 \
            " arrow::adapters::orc::WriterVersion::kHive12055"
        _WriterVersion_HIVE_13083 \
            " arrow::adapters::orc::WriterVersion::kHive13083"
        _WriterVersion_ORC_101" arrow::adapters::orc::WriterVersion::kOrc101"
        _WriterVersion_ORC_135" arrow::adapters::orc::WriterVersion::kOrc135"
        _WriterVersion_ORC_517" arrow::adapters::orc::WriterVersion::kOrc517"
        _WriterVersion_ORC_203" arrow::adapters::orc::WriterVersion::kOrc203"
        _WriterVersion_ORC_14" arrow::adapters::orc::WriterVersion::kOrc14"
        _WriterVersion_MAX" arrow::adapters::orc::WriterVersion::kMax"
    # File format version, built from a (major, minor) pair and
    # printable through ToString().
    cdef cppclass FileVersion" arrow::adapters::orc::FileVersion":
        FileVersion(uint32_t major_version, uint32_t minor_version)
        uint32_t major_version()
        uint32_t minor_version()
        c_string ToString()
    # Writer settings struct handed to ORCFileWriter.Open.
    cdef struct WriteOptions" arrow::adapters::orc::WriteOptions":
        int64_t batch_size
        FileVersion file_version
        int64_t stripe_size
        CCompressionType compression
        int64_t compression_block_size
        CompressionStrategy compression_strategy
        int64_t row_index_stride
        double padding_tolerance
        double dictionary_key_size_threshold
        std_vector[int64_t] bloom_filter_columns
        double bloom_filter_fpp
# Cython-visible declarations for the C++ header
# "arrow/adapters/orc/adapter.h" (namespace arrow::adapters::orc).
# Everything declared here is marked `nogil`.
cdef extern from "arrow/adapters/orc/adapter.h" \
        namespace "arrow::adapters::orc" nogil:
    # Reader over a random-access file.  Instances are obtained from the
    # static Open() factory, which returns a CResult holding a unique_ptr.
    cdef cppclass ORCFileReader:
        @staticmethod
        CResult[unique_ptr[ORCFileReader]] Open(
            const shared_ptr[CRandomAccessFile]& file,
            CMemoryPool* pool)
        CResult[shared_ptr[const CKeyValueMetadata]] ReadMetadata()
        CResult[shared_ptr[CSchema]] ReadSchema()
        # ReadStripe / Read each come in two overloads: without arguments
        # beyond the stripe index, or with a vector of strings.
        CResult[shared_ptr[CRecordBatch]] ReadStripe(int64_t stripe)
        CResult[shared_ptr[CRecordBatch]] ReadStripe(
            int64_t stripe, std_vector[c_string])
        CResult[shared_ptr[CTable]] Read()
        CResult[shared_ptr[CTable]] Read(std_vector[c_string])
        int64_t NumberOfStripes()
        int64_t NumberOfRows()
        FileVersion GetFileVersion()
        c_string GetSoftwareVersion()
        CResult[CCompressionType] GetCompression()
        int64_t GetCompressionSize()
        int64_t GetRowIndexStride()
        WriterId GetWriterId()
        int32_t GetWriterIdValue()
        WriterVersion GetWriterVersion()
        int64_t GetNumberOfStripeStatistics()
        int64_t GetContentLength()
        int64_t GetStripeStatisticsLength()
        int64_t GetFileFooterLength()
        int64_t GetFilePostscriptLength()
        int64_t GetFileLength()
        c_string GetSerializedFileTail()
    # Writer onto an output stream.  Open() takes a raw COutputStream
    # pointer plus the WriteOptions; Write and Close return a CStatus.
    cdef cppclass ORCFileWriter:
        @staticmethod
        CResult[unique_ptr[ORCFileWriter]] Open(
            COutputStream* output_stream, const WriteOptions& writer_options)
        CStatus Write(const CTable& table)
        CStatus Close()

View File

@@ -0,0 +1,445 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# cython: profile=False
# distutils: language = c++
from cython.operator cimport dereference as deref
from libcpp.vector cimport vector as std_vector
from libcpp.utility cimport move
from pyarrow.includes.common cimport *
from pyarrow.includes.libarrow cimport *
from pyarrow.lib cimport (check_status, _Weakrefable,
MemoryPool, maybe_unbox_memory_pool,
pyarrow_wrap_schema,
pyarrow_wrap_batch,
Table,
pyarrow_wrap_table,
pyarrow_wrap_metadata,
pyarrow_unwrap_table,
get_reader,
get_writer)
from pyarrow.lib import frombytes, tobytes
from pyarrow.util import _stringify_path
cdef compression_type_from_enum(CCompressionType compression_type):
    """Return the ORC compression name for a compression enum value.

    Raises ValueError for any value that has no entry in the table.
    """
    # Note that the GZIP enum value is reported under the name 'ZLIB'.
    names_by_codec = {
        CCompressionType_UNCOMPRESSED: 'UNCOMPRESSED',
        CCompressionType_GZIP: 'ZLIB',
        CCompressionType_SNAPPY: 'SNAPPY',
        CCompressionType_LZ4: 'LZ4',
        CCompressionType_ZSTD: 'ZSTD',
    }
    codec_name = names_by_codec.get(compression_type)
    if codec_name is None:
        raise ValueError('Unsupported compression')
    return codec_name
cdef CCompressionType compression_type_from_name(name) except *:
    """Translate a compression name (case-insensitive) to its enum value.

    Raises TypeError when `name` is not a str and ValueError when the
    upper-cased name is not one of the recognized names.
    """
    if not isinstance(name, str):
        raise TypeError('compression must be a string')

    name = name.upper()
    # The names are mutually exclusive, so each test can return directly.
    if name == 'UNCOMPRESSED':
        return CCompressionType_UNCOMPRESSED
    if name == 'ZLIB':
        return CCompressionType_GZIP
    if name == 'SNAPPY':
        return CCompressionType_SNAPPY
    if name == 'LZ4':
        return CCompressionType_LZ4
    if name == 'ZSTD':
        return CCompressionType_ZSTD
    raise ValueError(f'Unknown CompressionKind: {name}')
cdef compression_strategy_from_enum(
    CompressionStrategy compression_strategy
):
    """Return 'SPEED' or 'COMPRESSION' for a compression strategy value.

    Raises ValueError for any value that has no entry in the table.
    """
    names_by_strategy = {
        _CompressionStrategy_SPEED: 'SPEED',
        _CompressionStrategy_COMPRESSION: 'COMPRESSION',
    }
    strategy_name = names_by_strategy.get(compression_strategy)
    if strategy_name is None:
        raise ValueError('Unsupported compression strategy')
    return strategy_name
cdef CompressionStrategy compression_strategy_from_name(name) except *:
    """Translate a strategy name (case-insensitive) to its enum value.

    Raises TypeError when `name` is not a str and ValueError when the
    upper-cased name is neither 'SPEED' nor 'COMPRESSION'.
    """
    if not isinstance(name, str):
        raise TypeError('compression strategy must be a string')

    name = name.upper()
    if name == 'SPEED':
        return _CompressionStrategy_SPEED
    if name == 'COMPRESSION':
        return _CompressionStrategy_COMPRESSION
    raise ValueError(f'Unknown CompressionStrategy: {name}')
cdef file_version_from_class(FileVersion file_version):
    """Return the ToString() form of `file_version`, passed through
    frombytes."""
    cdef c_string version_text = file_version.ToString()
    return frombytes(version_text)
cdef writer_id_from_enum(WriterId writer_id):
    """Return the symbolic name for an ORC writer id.

    The declared "unknown" writer id maps to the string 'UNKNOWN' so that
    callers can detect it and fall back to the raw numeric id.

    Raises ValueError for a value that is not one of the declared
    WriterId members.
    """
    writer_id_map = {
        _WriterId_ORC_JAVA_WRITER: 'ORC_JAVA',
        _WriterId_ORC_CPP_WRITER: 'ORC_CPP',
        _WriterId_PRESTO_WRITER: 'PRESTO',
        _WriterId_SCRITCHLEY_GO: 'SCRITCHLEY_GO',
        _WriterId_TRINO_WRITER: 'TRINO',
        # Without this entry an unknown writer raised ValueError, and the
        # 'UNKNOWN' check in ORCReader.writer could never succeed.
        _WriterId_UNKNOWN_WRITER: 'UNKNOWN',
    }
    if writer_id in writer_id_map:
        return writer_id_map[writer_id]
    raise ValueError('Unsupported writer ID')
cdef writer_version_from_enum(WriterVersion writer_version):
    """Return the symbolic name for an ORC writer version value.

    Raises ValueError for any value that has no entry in the table.
    """
    names_by_version = {
        _WriterVersion_ORIGINAL: 'ORIGINAL',
        _WriterVersion_HIVE_8732: 'HIVE_8732',
        _WriterVersion_HIVE_4243: 'HIVE_4243',
        _WriterVersion_HIVE_12055: 'HIVE_12055',
        _WriterVersion_HIVE_13083: 'HIVE_13083',
        _WriterVersion_ORC_101: 'ORC_101',
        _WriterVersion_ORC_135: 'ORC_135',
        _WriterVersion_ORC_517: 'ORC_517',
        _WriterVersion_ORC_203: 'ORC_203',
        _WriterVersion_ORC_14: 'ORC_14',
    }
    version_name = names_by_version.get(writer_version)
    if version_name is None:
        raise ValueError('Unsupported writer version')
    return version_name
cdef shared_ptr[WriteOptions] _create_write_options(
    file_version=None,
    batch_size=None,
    stripe_size=None,
    compression=None,
    compression_block_size=None,
    compression_strategy=None,
    row_index_stride=None,
    padding_tolerance=None,
    dictionary_key_size_threshold=None,
    bloom_filter_columns=None,
    bloom_filter_fpp=None
) except *:
    """Build a WriteOptions struct from Python-level writer settings.

    Every argument is optional. An argument left as None keeps whatever
    value a freshly constructed WriteOptions already has. Each supplied
    argument is validated before it is stored.

    Parameters
    ----------
    file_version : str, optional
        Either "0.12" or "0.11".
    batch_size, stripe_size, compression_block_size, row_index_stride : int
        Must be positive integers.
    compression : str, optional
        Name accepted by compression_type_from_name.
    compression_strategy : str, optional
        Name accepted by compression_strategy_from_name.
    padding_tolerance : float-convertible, optional
    dictionary_key_size_threshold : float-convertible, optional
        Must lie in [0, 1].
    bloom_filter_columns : iterable of int, optional
        Non-negative column indices.
    bloom_filter_fpp : float-convertible, optional
        Must lie in [0, 1].

    Returns
    -------
    shared_ptr[WriteOptions]

    Raises
    ------
    ValueError
        For an out-of-range or unconvertible value.
    TypeError
        When `compression` or `compression_strategy` is not a string.
    """
    cdef:
        shared_ptr[WriteOptions] options
    options = make_shared[WriteOptions]()
    # batch_size
    if batch_size is not None:
        if isinstance(batch_size, int) and batch_size > 0:
            deref(options).batch_size = batch_size
        else:
            raise ValueError(f"Invalid ORC writer batch size: {batch_size}")
    # file_version
    if file_version is not None:
        if file_version == "0.12":
            deref(options).file_version = FileVersion(0, 12)
        elif file_version == "0.11":
            deref(options).file_version = FileVersion(0, 11)
        else:
            raise ValueError(f"Unsupported ORC file version: {file_version}")
    # stripe_size
    if stripe_size is not None:
        if isinstance(stripe_size, int) and stripe_size > 0:
            deref(options).stripe_size = stripe_size
        else:
            raise ValueError(f"Invalid ORC stripe size: {stripe_size}")
    # compression
    if compression is not None:
        if isinstance(compression, str):
            deref(options).compression = compression_type_from_name(
                compression)
        else:
            raise TypeError("Unsupported ORC compression type: "
                            f"{compression}")
    # compression_block_size
    if compression_block_size is not None:
        if (isinstance(compression_block_size, int) and
                compression_block_size > 0):
            deref(options).compression_block_size = compression_block_size
        else:
            raise ValueError("Invalid ORC compression block size: "
                             f"{compression_block_size}")
    # compression_strategy
    if compression_strategy is not None:
        # Check the strategy itself. The previous code tested `compression`
        # here, so a valid strategy was rejected unless a compression name
        # was also supplied.
        if isinstance(compression_strategy, str):
            deref(options).compression_strategy = \
                compression_strategy_from_name(compression_strategy)
        else:
            raise TypeError("Unsupported ORC compression strategy: "
                            f"{compression_strategy}")
    # row_index_stride
    if row_index_stride is not None:
        if isinstance(row_index_stride, int) and row_index_stride > 0:
            deref(options).row_index_stride = row_index_stride
        else:
            raise ValueError("Invalid ORC row index stride: "
                             f"{row_index_stride}")
    # padding_tolerance
    if padding_tolerance is not None:
        try:
            padding_tolerance = float(padding_tolerance)
            deref(options).padding_tolerance = padding_tolerance
        except Exception:
            raise ValueError("Invalid ORC padding tolerance: "
                             f"{padding_tolerance}")
    # dictionary_key_size_threshold
    if dictionary_key_size_threshold is not None:
        try:
            dictionary_key_size_threshold = float(
                dictionary_key_size_threshold)
            # Explicit check rather than `assert`, which can be disabled
            # and would then let out-of-range values through.
            if not (0 <= dictionary_key_size_threshold <= 1):
                raise ValueError("threshold must be within [0, 1]")
            deref(options).dictionary_key_size_threshold = \
                dictionary_key_size_threshold
        except Exception:
            raise ValueError("Invalid ORC dictionary key size threshold: "
                             f"{dictionary_key_size_threshold}")
    # bloom_filter_columns
    if bloom_filter_columns is not None:
        try:
            bloom_filter_columns = list(bloom_filter_columns)
            for col in bloom_filter_columns:
                # Explicit check rather than `assert` (see above).
                if not (isinstance(col, int) and col >= 0):
                    raise ValueError("column indices must be "
                                     "non-negative integers")
            deref(options).bloom_filter_columns = bloom_filter_columns
        except Exception:
            raise ValueError("Invalid ORC BloomFilter columns: "
                             f"{bloom_filter_columns}")
    # Max false positive rate of the Bloom Filter
    if bloom_filter_fpp is not None:
        try:
            bloom_filter_fpp = float(bloom_filter_fpp)
            # Explicit check rather than `assert` (see above).
            if not (0 <= bloom_filter_fpp <= 1):
                raise ValueError("rate must be within [0, 1]")
            deref(options).bloom_filter_fpp = bloom_filter_fpp
        except Exception:
            raise ValueError("Invalid ORC BloomFilter false positive rate: "
                             f"{bloom_filter_fpp}")
    return options
cdef class ORCReader(_Weakrefable):
    """Python-facing wrapper around the C++ ORCFileReader.

    The C++ reader is only created by `open`. Every other method
    dereferences `self.reader` without checking it, so `open` has to be
    called first.
    """
    cdef:
        # The object passed to open(); stored on the instance.
        object source
        # Memory pool pointer resolved once in __cinit__.
        CMemoryPool* allocator
        # Owned C++ reader; assigned in open().
        unique_ptr[ORCFileReader] reader

    def __cinit__(self, MemoryPool memory_pool=None):
        self.allocator = maybe_unbox_memory_pool(memory_pool)

    def open(self, object source, c_bool use_memory_map=True):
        """Open `source` for reading and create the underlying C++ reader.

        `source` and `use_memory_map` are forwarded to get_reader, which
        fills in the random-access file handle used by ORCFileReader.Open.
        """
        cdef:
            shared_ptr[CRandomAccessFile] rd_handle
        self.source = source
        get_reader(source, use_memory_map, &rd_handle)
        # The C++ open call runs with the GIL released.
        with nogil:
            self.reader = move(GetResultValue(
                ORCFileReader.Open(rd_handle, self.allocator)
            ))

    def metadata(self):
        """
        The arrow metadata for this file.

        Returns
        -------
        metadata : pyarrow.KeyValueMetadata
        """
        cdef:
            shared_ptr[const CKeyValueMetadata] sp_arrow_metadata
        with nogil:
            sp_arrow_metadata = GetResultValue(
                deref(self.reader).ReadMetadata()
            )
        return pyarrow_wrap_metadata(sp_arrow_metadata)

    def schema(self):
        """
        The arrow schema for this file.

        Returns
        -------
        schema : pyarrow.Schema
        """
        cdef:
            shared_ptr[CSchema] sp_arrow_schema
        with nogil:
            sp_arrow_schema = GetResultValue(deref(self.reader).ReadSchema())
        return pyarrow_wrap_schema(sp_arrow_schema)

    # The accessors below each forward to the same-purpose C++ getter.

    def nrows(self):
        """Return NumberOfRows() of the underlying reader."""
        return deref(self.reader).NumberOfRows()

    def nstripes(self):
        """Return NumberOfStripes() of the underlying reader."""
        return deref(self.reader).NumberOfStripes()

    def file_version(self):
        """Return the file version as text (see file_version_from_class)."""
        return file_version_from_class(deref(self.reader).GetFileVersion())

    def software_version(self):
        """Return GetSoftwareVersion() passed through frombytes."""
        return frombytes(deref(self.reader).GetSoftwareVersion())

    def compression(self):
        """Return the compression name (see compression_type_from_enum)."""
        return compression_type_from_enum(
            GetResultValue(deref(self.reader).GetCompression()))

    def compression_size(self):
        """Return GetCompressionSize() of the underlying reader."""
        return deref(self.reader).GetCompressionSize()

    def row_index_stride(self):
        """Return GetRowIndexStride() of the underlying reader."""
        return deref(self.reader).GetRowIndexStride()

    def writer(self):
        """Return the writer name, or the raw numeric writer id when the
        name lookup yields 'UNKNOWN'."""
        writer_name = writer_id_from_enum(deref(self.reader).GetWriterId())
        if writer_name == 'UNKNOWN':
            return deref(self.reader).GetWriterIdValue()
        else:
            return writer_name

    def writer_version(self):
        """Return the writer version name (see writer_version_from_enum)."""
        return writer_version_from_enum(deref(self.reader).GetWriterVersion())

    def nstripe_statistics(self):
        """Return GetNumberOfStripeStatistics() of the underlying reader."""
        return deref(self.reader).GetNumberOfStripeStatistics()

    def content_length(self):
        """Return GetContentLength() of the underlying reader."""
        return deref(self.reader).GetContentLength()

    def stripe_statistics_length(self):
        """Return GetStripeStatisticsLength() of the underlying reader."""
        return deref(self.reader).GetStripeStatisticsLength()

    def file_footer_length(self):
        """Return GetFileFooterLength() of the underlying reader."""
        return deref(self.reader).GetFileFooterLength()

    def file_postscript_length(self):
        """Return GetFilePostscriptLength() of the underlying reader."""
        return deref(self.reader).GetFilePostscriptLength()

    def file_length(self):
        """Return GetFileLength() of the underlying reader."""
        return deref(self.reader).GetFileLength()

    def serialized_file_tail(self):
        """Return GetSerializedFileTail() as returned by the C++ reader
        (a C++ string; no frombytes conversion is applied here)."""
        return deref(self.reader).GetSerializedFileTail()

    def read_stripe(self, n, columns=None):
        """Read stripe number `n` and return it wrapped by
        pyarrow_wrap_batch.

        When `columns` is given, each entry is converted with tobytes and
        the list is passed to the two-argument ReadStripe overload.
        """
        cdef:
            shared_ptr[CRecordBatch] sp_record_batch
            int64_t stripe
            std_vector[c_string] c_names
        # Convert to a C integer up front so it can be used without the GIL.
        stripe = n
        if columns is None:
            with nogil:
                sp_record_batch = GetResultValue(
                    deref(self.reader).ReadStripe(stripe)
                )
        else:
            c_names = [tobytes(name) for name in columns]
            with nogil:
                sp_record_batch = GetResultValue(
                    deref(self.reader).ReadStripe(stripe, c_names)
                )
        return pyarrow_wrap_batch(sp_record_batch)

    def read(self, columns=None):
        """Read the file and return it wrapped by pyarrow_wrap_table.

        When `columns` is given, each entry is converted with tobytes and
        the list is passed to the one-argument Read overload.
        """
        cdef:
            shared_ptr[CTable] sp_table
            std_vector[c_string] c_names
        if columns is None:
            with nogil:
                sp_table = GetResultValue(deref(self.reader).Read())
        else:
            c_names = [tobytes(name) for name in columns]
            with nogil:
                sp_table = GetResultValue(deref(self.reader).Read(c_names))
        return pyarrow_wrap_table(sp_table)
cdef class ORCWriter(_Weakrefable):
    """Python-facing wrapper around the C++ ORCFileWriter.

    The C++ writer is only created by `open`. `write` and `close`
    dereference `self.writer` without checking it, so `open` has to be
    called first.
    """
    cdef:
        # Owned C++ writer; assigned in open().
        unique_ptr[ORCFileWriter] writer
        # Output stream the writer writes to.
        shared_ptr[COutputStream] sink
        # True when open() created the sink itself from a path, in which
        # case close() also closes the sink.
        c_bool own_sink

    def open(self, object where, *,
             file_version=None,
             batch_size=None,
             stripe_size=None,
             compression=None,
             compression_block_size=None,
             compression_strategy=None,
             row_index_stride=None,
             padding_tolerance=None,
             dictionary_key_size_threshold=None,
             bloom_filter_columns=None,
             bloom_filter_fpp=None):
        """Open `where` for writing and create the underlying C++ writer.

        The keyword arguments are forwarded unchanged to
        _create_write_options, which validates them.
        """
        cdef:
            shared_ptr[WriteOptions] write_options
            c_string c_where
        try:
            where = _stringify_path(where)
        except TypeError:
            # _stringify_path rejected it: hand the object to get_writer
            # and do not take ownership of the resulting sink.
            get_writer(where, &self.sink)
            self.own_sink = False
        else:
            # A path: open the file ourselves and remember that we own it.
            c_where = tobytes(where)
            with nogil:
                self.sink = GetResultValue(FileOutputStream.Open(c_where))
            self.own_sink = True
        write_options = _create_write_options(
            file_version=file_version,
            batch_size=batch_size,
            stripe_size=stripe_size,
            compression=compression,
            compression_block_size=compression_block_size,
            compression_strategy=compression_strategy,
            row_index_stride=row_index_stride,
            padding_tolerance=padding_tolerance,
            dictionary_key_size_threshold=dictionary_key_size_threshold,
            bloom_filter_columns=bloom_filter_columns,
            bloom_filter_fpp=bloom_filter_fpp
        )
        with nogil:
            self.writer = move(GetResultValue(
                ORCFileWriter.Open(self.sink.get(),
                                   deref(write_options))))

    def write(self, Table table):
        """Write `table` through the C++ writer; errors are raised via
        check_status."""
        cdef:
            shared_ptr[CTable] sp_table
        sp_table = pyarrow_unwrap_table(table)
        with nogil:
            check_status(deref(self.writer).Write(deref(sp_table)))

    def close(self):
        """Close the C++ writer and, when the sink is owned, the sink."""
        with nogil:
            check_status(deref(self.writer).Close())
            if self.own_sink:
                check_status(deref(self.sink).Close())

View File

@@ -0,0 +1,152 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# distutils: language = c++
# cython: language_level = 3
from pyarrow.includes.libparquet cimport *
from pyarrow.lib cimport _Weakrefable
cdef class FileEncryptionProperties:
    """File-level encryption properties for the low-level API"""
    cdef:
        # Shared pointer to the wrapped C++ properties object.
        shared_ptr[CFileEncryptionProperties] properties

    # Build a Python wrapper holding the given C++ shared pointer.
    @staticmethod
    cdef inline FileEncryptionProperties wrap(
            shared_ptr[CFileEncryptionProperties] properties):
        result = FileEncryptionProperties()
        result.properties = properties
        return result

    # Return the held C++ shared pointer.
    cdef inline shared_ptr[CFileEncryptionProperties] unwrap(self):
        return self.properties
# Signature-only declarations. Each `=*` marks an argument that has a
# default value supplied by the implementation, and `except *` declares
# that the function may raise a Python exception.

# Build the Parquet WriterProperties from Python-level writer settings.
cdef shared_ptr[WriterProperties] _create_writer_properties(
    use_dictionary=*,
    compression=*,
    version=*,
    write_statistics=*,
    data_page_size=*,
    compression_level=*,
    use_byte_stream_split=*,
    column_encoding=*,
    data_page_version=*,
    FileEncryptionProperties encryption_properties=*,
    write_batch_size=*,
    dictionary_pagesize_limit=*,
    write_page_index=*,
    write_page_checksum=*,
    sorting_columns=*,
    store_decimal_as_integer=*,
    use_content_defined_chunking=*
) except *

# Build the ArrowWriterProperties from Python-level writer settings.
cdef shared_ptr[ArrowWriterProperties] _create_arrow_writer_properties(
    use_deprecated_int96_timestamps=*,
    coerce_timestamps=*,
    allow_truncated_timestamps=*,
    writer_engine_version=*,
    use_compliant_nested_type=*,
    store_schema=*,
) except *

# Unwrap the "list_type" argument for ArrowReaderProperties
cdef Type _unwrap_list_type(obj) except *
# Attribute layout of ParquetSchema. `schema` is a raw, non-owning
# pointer; `parent` references the FileMetaData that owns the descriptor.
cdef class ParquetSchema(_Weakrefable):
    cdef:
        FileMetaData parent  # the FileMetaData owning the SchemaDescriptor
        const SchemaDescriptor* schema
# Attribute layout and initializer of FileMetaData.
cdef class FileMetaData(_Weakrefable):
    cdef:
        # Owning shared pointer to the C++ metadata object.
        shared_ptr[CFileMetaData] sp_metadata
        # Raw pointer taken from the same shared pointer in init().
        CFileMetaData* _metadata
        ParquetSchema _schema

    # Store the shared pointer and cache its raw pointer.
    cdef inline init(self, const shared_ptr[CFileMetaData]& metadata):
        self.sp_metadata = metadata
        self._metadata = metadata.get()
# Attribute layout and initializer of RowGroupMetaData.
cdef class RowGroupMetaData(_Weakrefable):
    cdef:
        int index  # for pickling support
        # Owning pointer returned by the parent's RowGroup(index) call.
        unique_ptr[CRowGroupMetaData] up_metadata
        # Raw pointer taken from up_metadata in init().
        CRowGroupMetaData* metadata
        # Reference to the FileMetaData this row group was taken from.
        FileMetaData parent

    # Bind this object to row group `index` of `parent`.
    # Raises IndexError when `index` is negative or not below
    # parent.num_row_groups.
    cdef inline init(self, FileMetaData parent, int index):
        if index < 0 or index >= parent.num_row_groups:
            raise IndexError('{0} out of bounds'.format(index))
        self.up_metadata = parent._metadata.RowGroup(index)
        self.metadata = self.up_metadata.get()
        self.parent = parent
        self.index = index
# Attribute layout and initializer of ColumnChunkMetaData.
cdef class ColumnChunkMetaData(_Weakrefable):
    cdef:
        # Owning pointer returned by the parent's ColumnChunk(i) call.
        unique_ptr[CColumnChunkMetaData] up_metadata
        # Raw pointer taken from up_metadata in init().
        CColumnChunkMetaData* metadata
        # Reference to the RowGroupMetaData this chunk was taken from.
        RowGroupMetaData parent

    # Bind this object to column chunk `i` of `parent`.
    # NOTE(review): unlike RowGroupMetaData.init, `i` is not range-checked
    # here; presumably callers validate it, confirm.
    cdef inline init(self, RowGroupMetaData parent, int i):
        self.up_metadata = parent.metadata.ColumnChunk(i)
        self.metadata = self.up_metadata.get()
        self.parent = parent
# Attribute layout and initializer of Statistics.
cdef class Statistics(_Weakrefable):
    cdef:
        # Shared pointer to the wrapped C++ statistics object.
        shared_ptr[CStatistics] statistics
        # Reference to the ColumnChunkMetaData passed to init().
        ColumnChunkMetaData parent

    # Store the statistics pointer and the parent reference.
    cdef inline init(self, const shared_ptr[CStatistics]& statistics,
                     ColumnChunkMetaData parent):
        self.statistics = statistics
        self.parent = parent
# Attribute layout and initializer of GeoStatistics.
cdef class GeoStatistics(_Weakrefable):
    cdef:
        # Shared pointer to the wrapped C++ geo statistics object.
        shared_ptr[CParquetGeoStatistics] statistics
        # Reference to the ColumnChunkMetaData passed to init().
        ColumnChunkMetaData parent

    # Store the statistics pointer and the parent reference.
    cdef inline init(self, const shared_ptr[CParquetGeoStatistics]& statistics,
                     ColumnChunkMetaData parent):
        self.statistics = statistics
        self.parent = parent
cdef class FileDecryptionProperties:
    """File-level decryption properties for the low-level API"""
    cdef:
        # Shared pointer to the wrapped C++ properties object.
        shared_ptr[CFileDecryptionProperties] properties

    # Build a Python wrapper holding the given C++ shared pointer.
    @staticmethod
    cdef inline FileDecryptionProperties wrap(
            shared_ptr[CFileDecryptionProperties] properties):
        result = FileDecryptionProperties()
        result.properties = properties
        return result

    # Return the held C++ shared pointer.
    cdef inline shared_ptr[CFileDecryptionProperties] unwrap(self):
        return self.properties

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,56 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# distutils: language = c++
# cython: language_level = 3
from pyarrow.includes.common cimport *
from pyarrow.includes.libparquet_encryption cimport *
from pyarrow._parquet cimport (ParquetCipher,
CFileEncryptionProperties,
CFileDecryptionProperties,
FileEncryptionProperties,
FileDecryptionProperties,
ParquetCipher_AES_GCM_V1,
ParquetCipher_AES_GCM_CTR_V1)
from pyarrow.lib cimport _Weakrefable
# Declaration of CryptoFactory: holds a shared pointer to the C++
# CPyCryptoFactory. The method bodies are not given in this declaration.
cdef class CryptoFactory(_Weakrefable):
    cdef shared_ptr[CPyCryptoFactory] factory
    cdef init(self, callable_client_factory)
    cdef inline shared_ptr[CPyCryptoFactory] unwrap(self)
# Declaration of EncryptionConfiguration: holds a shared pointer to the
# C++ configuration object; unwrap() is declared callable without the GIL.
cdef class EncryptionConfiguration(_Weakrefable):
    cdef shared_ptr[CEncryptionConfiguration] configuration
    cdef inline shared_ptr[CEncryptionConfiguration] unwrap(self) nogil
# Declaration of DecryptionConfiguration: holds a shared pointer to the
# C++ configuration object; unwrap() is declared callable without the GIL.
cdef class DecryptionConfiguration(_Weakrefable):
    cdef shared_ptr[CDecryptionConfiguration] configuration
    cdef inline shared_ptr[CDecryptionConfiguration] unwrap(self) nogil
# Declaration of KmsConnectionConfig: holds a shared pointer to the C++
# configuration object. unwrap() is declared callable without the GIL;
# wrap() is a static factory taking a C++ config by const reference.
cdef class KmsConnectionConfig(_Weakrefable):
    cdef shared_ptr[CKmsConnectionConfig] configuration
    cdef inline shared_ptr[CKmsConnectionConfig] unwrap(self) nogil

    @staticmethod
    cdef wrap(const CKmsConnectionConfig& config)
# Signature-only declarations of helpers that take a Python object and
# return the corresponding C++ shared pointer. `except *` declares that
# each may raise a Python exception.
cdef shared_ptr[CCryptoFactory] pyarrow_unwrap_cryptofactory(object crypto_factory) except *
cdef shared_ptr[CKmsConnectionConfig] pyarrow_unwrap_kmsconnectionconfig(object kmsconnectionconfig) except *
cdef shared_ptr[CEncryptionConfiguration] pyarrow_unwrap_encryptionconfig(object encryptionconfig) except *
cdef shared_ptr[CDecryptionConfiguration] pyarrow_unwrap_decryptionconfig(object decryptionconfig) except *

View File

@@ -0,0 +1,498 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# cython: profile=False
# distutils: language = c++
from datetime import timedelta
from cython.operator cimport dereference as deref
from pyarrow.includes.common cimport *
from pyarrow.includes.libarrow cimport *
from pyarrow.lib cimport _Weakrefable
from pyarrow.lib import tobytes, frombytes
cdef ParquetCipher cipher_from_name(name) except *:
    """Translate a cipher name (case-insensitive) to its ParquetCipher value.

    Accepts 'AES_GCM_V1' and 'AES_GCM_CTR_V1'.

    Raises ValueError for any other name. The explicit `except *` makes
    sure that error reaches the caller: the function returns a C enum, so
    without an exception clause propagation depends on the Cython language
    level in effect. This also matches the ORC name-lookup helpers, which
    declare `except *`.
    """
    name = name.upper()
    if name == 'AES_GCM_V1':
        return ParquetCipher_AES_GCM_V1
    elif name == 'AES_GCM_CTR_V1':
        return ParquetCipher_AES_GCM_CTR_V1
    else:
        raise ValueError(f'Invalid cipher name: {name!r}')
cdef cipher_to_name(ParquetCipher cipher):
    """Return the textual name of a ParquetCipher value.

    Raises ValueError for a value that is neither of the two known ciphers.
    """
    if cipher == ParquetCipher_AES_GCM_V1:
        return 'AES_GCM_V1'
    if cipher == ParquetCipher_AES_GCM_CTR_V1:
        return 'AES_GCM_CTR_V1'
    raise ValueError(f'Invalid cipher value: {cipher}')
cdef class EncryptionConfiguration(_Weakrefable):
    """Configuration of the encryption, such as which columns to encrypt"""
    # Avoid mistakenly creating attributes
    __slots__ = ()

    def __init__(self, footer_key, *, column_keys=None,
                 uniform_encryption=None,
                 encryption_algorithm=None,
                 plaintext_footer=None, double_wrapping=None,
                 cache_lifetime=None, internal_key_material=None,
                 data_key_length_bits=None):
        """Create the C++ configuration for `footer_key`, then apply every
        keyword argument that is not None through its property setter."""
        self.configuration.reset(
            new CEncryptionConfiguration(tobytes(footer_key)))
        if column_keys is not None:
            self.column_keys = column_keys
        if uniform_encryption is not None:
            self.uniform_encryption = uniform_encryption
        if encryption_algorithm is not None:
            self.encryption_algorithm = encryption_algorithm
        if plaintext_footer is not None:
            self.plaintext_footer = plaintext_footer
        if double_wrapping is not None:
            self.double_wrapping = double_wrapping
        if cache_lifetime is not None:
            self.cache_lifetime = cache_lifetime
        if internal_key_material is not None:
            self.internal_key_material = internal_key_material
        if data_key_length_bits is not None:
            self.data_key_length_bits = data_key_length_bits

    @property
    def footer_key(self):
        """ID of the master key for footer encryption/signing"""
        return frombytes(self.configuration.get().footer_key)

    @property
    def column_keys(self):
        """
        List of columns to encrypt, with master key IDs.

        Returned as a dictionary mapping each master key ID to its list of
        column names; empty when no column keys are configured.
        """
        column_keys_str = frombytes(self.configuration.get().column_keys)
        # Convert from "masterKeyID:colName,colName;masterKeyID:colName..."
        # (see HIVE-21848) to dictionary of master key ID to column name lists
        column_keys_dict = {}
        for sub_string in column_keys_str.split(";"):
            entry = sub_string.replace(" ", "")
            if not entry:
                # An empty stored string (nothing configured) splits into
                # one empty segment. Skip it rather than failing on the
                # missing ":" separator.
                continue
            # A malformed entry still raises ValueError here.
            key_id, column_list = entry.split(":")
            column_keys_dict[key_id] = column_list.split(",")
        return column_keys_dict

    @column_keys.setter
    def column_keys(self, dict value):
        if value is not None:
            # convert a dictionary such as
            # '{"key1": ["col1 ", "col2"], "key2": ["col3 ", "col4"]}'
            # to the string defined by the spec
            # 'key1: col1 , col2; key2: col3 , col4'
            column_keys = "; ".join(
                [f"{k}: {', '.join(v)}" for k, v in value.items()])
            self.configuration.get().column_keys = tobytes(column_keys)

    @property
    def uniform_encryption(self):
        """Whether to encrypt footer and all columns with the same encryption key.

        This cannot be used together with column_keys.
        """
        return self.configuration.get().uniform_encryption

    @uniform_encryption.setter
    def uniform_encryption(self, value):
        self.configuration.get().uniform_encryption = value

    @property
    def encryption_algorithm(self):
        """Parquet encryption algorithm.
        Can be "AES_GCM_V1" (default), or "AES_GCM_CTR_V1"."""
        return cipher_to_name(self.configuration.get().encryption_algorithm)

    @encryption_algorithm.setter
    def encryption_algorithm(self, value):
        cipher = cipher_from_name(value)
        self.configuration.get().encryption_algorithm = cipher

    @property
    def plaintext_footer(self):
        """Write files with plaintext footer."""
        return self.configuration.get().plaintext_footer

    @plaintext_footer.setter
    def plaintext_footer(self, value):
        self.configuration.get().plaintext_footer = value

    @property
    def double_wrapping(self):
        """Use double wrapping - where data encryption keys (DEKs) are
        encrypted with key encryption keys (KEKs), which in turn are
        encrypted with master keys.
        If set to false, use single wrapping - where DEKs are
        encrypted directly with master keys."""
        return self.configuration.get().double_wrapping

    @double_wrapping.setter
    def double_wrapping(self, value):
        self.configuration.get().double_wrapping = value

    @property
    def cache_lifetime(self):
        """Lifetime of cached entities (key encryption keys,
        local wrapping keys, KMS client objects)."""
        return timedelta(
            seconds=self.configuration.get().cache_lifetime_seconds)

    @cache_lifetime.setter
    def cache_lifetime(self, value):
        if not isinstance(value, timedelta):
            raise TypeError("cache_lifetime should be a timedelta")
        self.configuration.get().cache_lifetime_seconds = value.total_seconds()

    @property
    def internal_key_material(self):
        """Store key material inside Parquet file footers; this mode doesn't
        produce additional files. If set to false, key material is stored in
        separate files in the same folder, which enables key rotation for
        immutable Parquet files."""
        return self.configuration.get().internal_key_material

    @internal_key_material.setter
    def internal_key_material(self, value):
        self.configuration.get().internal_key_material = value

    @property
    def data_key_length_bits(self):
        """Length of data encryption keys (DEKs), randomly generated by parquet key
        management tools. Can be 128, 192 or 256 bits."""
        return self.configuration.get().data_key_length_bits

    @data_key_length_bits.setter
    def data_key_length_bits(self, value):
        self.configuration.get().data_key_length_bits = value

    cdef inline shared_ptr[CEncryptionConfiguration] unwrap(self) nogil:
        return self.configuration
cdef class DecryptionConfiguration(_Weakrefable):
    """Configuration of the decryption, such as cache timeout."""
    # Avoid mistakenly creating attributes
    __slots__ = ()

    def __init__(self, *, cache_lifetime=None):
        """Create the C++ configuration and apply `cache_lifetime` when it
        is given.

        Parameters
        ----------
        cache_lifetime : datetime.timedelta, optional
            Lifetime of cached entities. When None, the value of a freshly
            constructed C++ configuration is kept.
        """
        self.configuration.reset(new CDecryptionConfiguration())
        # The argument used to be accepted but silently ignored.
        if cache_lifetime is not None:
            self.cache_lifetime = cache_lifetime

    @property
    def cache_lifetime(self):
        """Lifetime of cached entities (key encryption keys,
        local wrapping keys, KMS client objects)."""
        return timedelta(
            seconds=self.configuration.get().cache_lifetime_seconds)

    @cache_lifetime.setter
    def cache_lifetime(self, value):
        # Same validation as EncryptionConfiguration.cache_lifetime, so a
        # wrong type gives a clear TypeError instead of an AttributeError
        # on `total_seconds`.
        if not isinstance(value, timedelta):
            raise TypeError("cache_lifetime should be a timedelta")
        self.configuration.get().cache_lifetime_seconds = value.total_seconds()

    cdef inline shared_ptr[CDecryptionConfiguration] unwrap(self) nogil:
        return self.configuration
cdef class KmsConnectionConfig(_Weakrefable):
    """Configuration of the connection to the Key Management Service (KMS)"""
    # Avoid mistakenly creating attributes
    __slots__ = ()

    def __init__(self, *, kms_instance_id=None, kms_instance_url=None,
                 key_access_token=None, custom_kms_conf=None):
        self.configuration.reset(new CKmsConnectionConfig())
        if kms_instance_id is not None:
            self.kms_instance_id = kms_instance_id
        if kms_instance_url is not None:
            self.kms_instance_url = kms_instance_url
        # The token is always set; b'DEFAULT' stands in when none is given.
        self.key_access_token = (
            b'DEFAULT' if key_access_token is None else key_access_token)
        if custom_kms_conf is not None:
            self.custom_kms_conf = custom_kms_conf

    @property
    def kms_instance_id(self):
        """ID of the KMS instance that will be used for encryption
        (if multiple KMS instances are available)."""
        return frombytes(self.configuration.get().kms_instance_id)

    @kms_instance_id.setter
    def kms_instance_id(self, value):
        self.configuration.get().kms_instance_id = tobytes(value)

    @property
    def kms_instance_url(self):
        """URL of the KMS instance."""
        return frombytes(self.configuration.get().kms_instance_url)

    @kms_instance_url.setter
    def kms_instance_url(self, value):
        self.configuration.get().kms_instance_url = tobytes(value)

    @property
    def key_access_token(self):
        """Authorization token that will be passed to KMS."""
        return frombytes(
            self.configuration.get().refreshable_key_access_token.get().value()
        )

    @key_access_token.setter
    def key_access_token(self, value):
        # Setting the token is the same operation as refreshing it.
        self.refresh_key_access_token(value)

    @property
    def custom_kms_conf(self):
        """A dictionary with KMS-type-specific configuration"""
        decoded_conf = {}
        for raw_key, raw_value in self.configuration.get().custom_kms_conf:
            decoded_conf[frombytes(raw_key)] = frombytes(raw_value)
        return decoded_conf

    @custom_kms_conf.setter
    def custom_kms_conf(self, dict value):
        if value is None:
            return
        for conf_key, conf_value in value.items():
            # Entries are stored one by one; the first non-string pair
            # stops the loop with a TypeError.
            if not (isinstance(conf_key, str) and
                    isinstance(conf_value, str)):
                raise TypeError("Expected custom_kms_conf to be " +
                                "a dictionary of strings")
            self.configuration.get().custom_kms_conf[tobytes(conf_key)] = \
                tobytes(conf_value)

    def refresh_key_access_token(self, value):
        """Replace the access token held by the configuration with `value`."""
        cdef shared_ptr[CKeyAccessToken] token_handle
        token_handle = self.configuration.get().refreshable_key_access_token
        token_handle.get().Refresh(tobytes(value))

    cdef inline shared_ptr[CKmsConnectionConfig] unwrap(self) nogil:
        return self.configuration

    @staticmethod
    cdef wrap(const CKmsConnectionConfig& config):
        # Wrap a new shared C++ config built from `config` in a fresh
        # Python object.
        cdef KmsConnectionConfig wrapped = KmsConnectionConfig()
        wrapped.configuration = make_shared[CKmsConnectionConfig](move(config))
        return wrapped
# Callback definitions for CPyKmsClientVtable
cdef void _cb_wrap_key(
        handler, const c_string& key_bytes,
        const c_string& master_key_identifier, c_string* out) except *:
    # Forward the request to handler.wrap_key(); the master key identifier
    # is decoded with frombytes() while key_bytes is handed over as-is.
    key_id = frombytes(master_key_identifier)
    out[0] = tobytes(handler.wrap_key(key_bytes, key_id))
cdef void _cb_unwrap_key(
        handler, const c_string& wrapped_key,
        const c_string& master_key_identifier, c_string* out) except *:
    # Forward the request to handler.unwrap_key(); both C++ strings are
    # decoded with frombytes() before the Python call.
    key_id = frombytes(master_key_identifier)
    wrapped = frombytes(wrapped_key)
    out[0] = tobytes(handler.unwrap_key(wrapped, key_id))
cdef class KmsClient(_Weakrefable):
    """Abstract base class that KmsClient implementations derive from."""
    cdef shared_ptr[CKmsClient] client

    def __init__(self):
        self.init()

    cdef init(self):
        # Route the vtable's wrap/unwrap entries to the module-level
        # callbacks, then create the C++ client bound to this object.
        cdef CPyKmsClientVtable vtable = CPyKmsClientVtable()

        vtable.wrap_key = _cb_wrap_key
        vtable.unwrap_key = _cb_unwrap_key
        self.client.reset(new CPyKmsClient(self, vtable))

    def wrap_key(self, key_bytes, master_key_identifier):
        """Wrap a key, i.e. encrypt it with the master key.

        This base implementation always raises NotImplementedError.
        """
        raise NotImplementedError()

    def unwrap_key(self, wrapped_key, master_key_identifier):
        """Unwrap a key, i.e. decrypt it with the master key.

        This base implementation always raises NotImplementedError.
        """
        raise NotImplementedError()

    cdef inline shared_ptr[CKmsClient] unwrap(self) nogil:
        # Hand out the shared pointer to the underlying C++ client.
        return self.client
# Callback definition for CPyKmsClientFactoryVtable
cdef void _cb_create_kms_client(
        handler,
        const CKmsConnectionConfig& kms_connection_config,
        shared_ptr[CKmsClient]* out) except *:
    # Call the user-supplied factory with a wrapped connection config and
    # store the produced client's C++ pointer in `out`.
    client = handler(KmsConnectionConfig.wrap(kms_connection_config))
    if isinstance(client, KmsClient):
        out[0] = (<KmsClient> client).unwrap()
        return
    raise TypeError(
        f"callable must return KmsClient instances, but got {type(client)}")
cdef class CryptoFactory(_Weakrefable):
    """Factory producing the low-level FileEncryptionProperties and
    FileDecryptionProperties objects from the high-level parameters."""
    # Avoid mistakenly creating attributes
    __slots__ = ()

    def __init__(self, kms_client_factory):
        """Create CryptoFactory.

        Parameters
        ----------
        kms_client_factory : a callable that accepts KmsConnectionConfig
            and returns a KmsClient
        """
        self.factory.reset(new CPyCryptoFactory())
        if not callable(kms_client_factory):
            raise TypeError("Parameter kms_client_factory must be a callable")
        self.init(kms_client_factory)

    cdef init(self, callable_client_factory):
        cdef:
            CPyKmsClientFactoryVtable c_vtable
            shared_ptr[CPyKmsClientFactory] c_client_factory

        c_vtable.create_kms_client = _cb_create_kms_client
        c_client_factory.reset(
            new CPyKmsClientFactory(callable_client_factory, c_vtable))
        # A KmsClientFactory object must be registered
        # via this method before calling any of
        # file_encryption_properties()/file_decryption_properties() methods.
        self.factory.get().RegisterKmsClientFactory(
            static_pointer_cast[CKmsClientFactory, CPyKmsClientFactory](
                c_client_factory))

    def file_encryption_properties(self,
                                   KmsConnectionConfig kms_connection_config,
                                   EncryptionConfiguration encryption_config):
        """Create file encryption properties.

        Parameters
        ----------
        kms_connection_config : KmsConnectionConfig
            Configuration of connection to KMS
        encryption_config : EncryptionConfiguration
            Configuration of the encryption, such as which columns to encrypt

        Returns
        -------
        file_encryption_properties : FileEncryptionProperties
            File encryption properties.
        """
        cdef:
            CResult[shared_ptr[CFileEncryptionProperties]] c_result

        with nogil:
            c_result = \
                self.factory.get().SafeGetFileEncryptionProperties(
                    deref(kms_connection_config.unwrap().get()),
                    deref(encryption_config.unwrap().get()))
        return FileEncryptionProperties.wrap(GetResultValue(c_result))

    def file_decryption_properties(
            self,
            KmsConnectionConfig kms_connection_config,
            DecryptionConfiguration decryption_config=None):
        """Create file decryption properties.

        Parameters
        ----------
        kms_connection_config : KmsConnectionConfig
            Configuration of connection to KMS
        decryption_config : DecryptionConfiguration, default None
            Configuration of the decryption, such as cache timeout.
            Can be None.

        Returns
        -------
        file_decryption_properties : FileDecryptionProperties
            File decryption properties.
        """
        cdef:
            CDecryptionConfiguration c_decryption_config
            CResult[shared_ptr[CFileDecryptionProperties]] c_result

        # Without an explicit configuration a default-constructed one is used.
        if decryption_config is not None:
            c_decryption_config = deref(decryption_config.unwrap().get())
        else:
            c_decryption_config = CDecryptionConfiguration()
        with nogil:
            c_result = \
                self.factory.get().SafeGetFileDecryptionProperties(
                    deref(kms_connection_config.unwrap().get()),
                    c_decryption_config)
        return FileDecryptionProperties.wrap(GetResultValue(c_result))

    def remove_cache_entries_for_token(self, access_token):
        """Call RemoveCacheEntriesForToken on the wrapped factory
        with the encoded `access_token`."""
        self.factory.get().RemoveCacheEntriesForToken(tobytes(access_token))

    def remove_cache_entries_for_all_tokens(self):
        """Call RemoveCacheEntriesForAllTokens on the wrapped factory."""
        self.factory.get().RemoveCacheEntriesForAllTokens()

    cdef inline shared_ptr[CPyCryptoFactory] unwrap(self):
        # Hand out the shared pointer to the underlying C++ factory.
        return self.factory
cdef shared_ptr[CCryptoFactory] pyarrow_unwrap_cryptofactory(object crypto_factory) except *:
    # Return the C++ factory held by a CryptoFactory instance, upcast to
    # CCryptoFactory; any other object raises TypeError.
    if not isinstance(crypto_factory, CryptoFactory):
        raise TypeError("Expected CryptoFactory, got %s" % type(crypto_factory))
    return static_pointer_cast[CCryptoFactory, CPyCryptoFactory](
        (<CryptoFactory> crypto_factory).unwrap())
cdef shared_ptr[CKmsConnectionConfig] pyarrow_unwrap_kmsconnectionconfig(object kmsconnectionconfig) except *:
    # Return the C++ config held by a KmsConnectionConfig instance;
    # any other object raises TypeError.
    if not isinstance(kmsconnectionconfig, KmsConnectionConfig):
        raise TypeError("Expected KmsConnectionConfig, got %s" % type(kmsconnectionconfig))
    return (<KmsConnectionConfig> kmsconnectionconfig).unwrap()
cdef shared_ptr[CEncryptionConfiguration] pyarrow_unwrap_encryptionconfig(object encryptionconfig) except *:
    # Return the C++ config held by an EncryptionConfiguration instance;
    # any other object raises TypeError.
    if not isinstance(encryptionconfig, EncryptionConfiguration):
        raise TypeError("Expected EncryptionConfiguration, got %s" % type(encryptionconfig))
    return (<EncryptionConfiguration> encryptionconfig).unwrap()
cdef shared_ptr[CDecryptionConfiguration] pyarrow_unwrap_decryptionconfig(object decryptionconfig) except *:
    # Return the C++ config held by a DecryptionConfiguration instance;
    # any other object raises TypeError.
    if not isinstance(decryptionconfig, DecryptionConfiguration):
        raise TypeError("Expected DecryptionConfiguration, got %s" % type(decryptionconfig))
    return (<DecryptionConfiguration> decryptionconfig).unwrap()

View File

@@ -0,0 +1,33 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# distutils: language = c++
# cython: language_level = 3
from pyarrow.includes.common cimport *
from pyarrow.includes.libarrow cimport CStatus
# Function type of a C++ test entry point: takes no arguments and
# returns a CStatus.
ctypedef CStatus cb_test_func()
# Declarations taken from arrow/python/python_test.h; the whole extern
# block is declared nogil.
cdef extern from "arrow/python/python_test.h" namespace "arrow::py::testing" nogil:
    # One C++ test case: a name plus the function to run.
    cdef cppclass CTestCase "arrow::py::testing::TestCase":
        c_string name
        cb_test_func func
    # Returns the C++ test cases as a vector of CTestCase.
    vector[CTestCase] GetCppTestCases()

View File

@@ -0,0 +1,62 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# cython: profile=False, binding=True
# distutils: language = c++
from pyarrow.includes.common cimport *
from pyarrow.includes.libarrow cimport *
from pyarrow.lib cimport check_status
from pyarrow.lib import frombytes
cdef class CppTestCase:
    """
    Thin Python-callable handle around a single C++ test case.
    """
    cdef CTestCase c_case

    @staticmethod
    cdef wrap(CTestCase c_case):
        # Allocate through __new__ and attach a copy of the C++ test case.
        cdef CppTestCase wrapper = CppTestCase.__new__(CppTestCase)
        wrapper.c_case = c_case
        return wrapper

    @property
    def name(self):
        """Name of the wrapped test case, decoded with frombytes()."""
        return frombytes(self.c_case.name)

    def __repr__(self):
        return f"<{self.__class__.__name__} {self.name!r}>"

    def __call__(self):
        # Run the C++ test function and pass its status to check_status().
        check_status(self.c_case.func())
def get_cpp_tests():
    """
    Return the C++ test cases as a list of CppTestCase wrappers.
    """
    cdef:
        vector[CTestCase] c_cases = GetCppTestCases()
        CTestCase c_case

    wrapped_cases = []
    for c_case in c_cases:
        wrapped_cases.append(CppTestCase.wrap(c_case))
    return wrapped_cases

View File

@@ -0,0 +1,491 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# cython: language_level = 3
from pyarrow.lib cimport (check_status, pyarrow_wrap_metadata,
pyarrow_unwrap_metadata)
from pyarrow.lib import frombytes, tobytes, KeyValueMetadata
from pyarrow.includes.common cimport *
from pyarrow.includes.libarrow cimport *
from pyarrow.includes.libarrow_fs cimport *
from pyarrow._fs cimport FileSystem
# Log levels exposed to Python; each member takes the value of the
# matching CS3LogLevel_* constant cast to int8_t.
cpdef enum S3LogLevel:
    Off = <int8_t> CS3LogLevel_Off
    Fatal = <int8_t> CS3LogLevel_Fatal
    Error = <int8_t> CS3LogLevel_Error
    Warn = <int8_t> CS3LogLevel_Warn
    Info = <int8_t> CS3LogLevel_Info
    Debug = <int8_t> CS3LogLevel_Debug
    Trace = <int8_t> CS3LogLevel_Trace
def initialize_s3(S3LogLevel log_level=S3LogLevel.Fatal, int num_event_loop_threads=1):
    """
    Initialize S3 support with explicit global options.

    Parameters
    ----------
    log_level : S3LogLevel
        level of logging
    num_event_loop_threads : int, default 1
        how many threads to use for the AWS SDK's I/O event loop

    Examples
    --------
    >>> fs.initialize_s3(fs.S3LogLevel.Error)  # doctest: +SKIP
    """
    cdef CS3GlobalOptions c_options

    c_options.num_event_loop_threads = num_event_loop_threads
    c_options.log_level = <CS3LogLevel> log_level
    check_status(CInitializeS3(c_options))
def ensure_s3_initialized():
    """
    Make sure S3 is initialized, using default options when it is not yet.
    """
    check_status(CEnsureS3Initialized())
def finalize_s3():
    """
    Finalize S3 support; the status returned by the C++ call is checked.
    """
    check_status(CFinalizeS3())
def ensure_s3_finalized():
    """
    Finalize S3 when it has been initialized.
    """
    check_status(CEnsureS3Finalized())
def resolve_s3_region(bucket):
    """
    Look up the S3 region a bucket lives in.

    Parameters
    ----------
    bucket : str
        A S3 bucket name

    Returns
    -------
    region : str
        A S3 region name

    Examples
    --------
    >>> fs.resolve_s3_region('voltrondata-labs-datasets')
    'us-east-2'
    """
    cdef:
        c_string c_bucket_name
        c_string c_region_name

    # S3 is initialized (if needed) before the bucket name is converted
    # and the lookup is issued.
    ensure_s3_initialized()

    c_bucket_name = tobytes(bucket)
    with nogil:
        c_region_name = GetResultValue(ResolveS3BucketRegion(c_bucket_name))
    return frombytes(c_region_name)
class S3RetryStrategy:
    """
    Common base of the AWS retry strategies usable with S3.

    Parameters
    ----------
    max_attempts : int, default 3
        The maximum number of retry attempts to attempt before failing.
    """

    def __init__(self, max_attempts=3):
        # Stored as-is; no validation is performed here.
        self.max_attempts = max_attempts
class AwsStandardS3RetryStrategy(S3RetryStrategy):
    """
    The AWS Standard retry strategy, for use with S3.

    Parameters
    ----------
    max_attempts : int, default 3
        The maximum number of retry attempts to attempt before failing.
    """
class AwsDefaultS3RetryStrategy(S3RetryStrategy):
    """
    The AWS Default retry strategy, for use with S3.

    Parameters
    ----------
    max_attempts : int, default 3
        The maximum number of retry attempts to attempt before failing.
    """
cdef class S3FileSystem(FileSystem):
    """
    S3-backed FileSystem implementation

    AWS access_key and secret_key can be provided explicitly.

    If role_arn is provided instead of access_key and secret_key, temporary
    credentials will be fetched by issuing a request to STS to assume the
    specified role.

    If neither access_key nor secret_key are provided, and role_arn is also not
    provided, then attempts to establish the credentials automatically.
    S3FileSystem will try the following methods, in order:

    * ``AWS_ACCESS_KEY_ID``, ``AWS_SECRET_ACCESS_KEY``, and ``AWS_SESSION_TOKEN`` environment variables
    * configuration files such as ``~/.aws/credentials`` and ``~/.aws/config``
    * for nodes on Amazon EC2, the EC2 Instance Metadata Service

    Note: S3 buckets are special and the operations available on them may be
    limited or more expensive than desired.

    When S3FileSystem creates new buckets (assuming allow_bucket_creation is
    True), it does not pass any non-default settings. In AWS S3, the bucket and
    all objects will be not publicly visible, and will have no bucket policies
    and no resource tags. To have more control over how buckets are created,
    use a different API to create them.

    Parameters
    ----------
    access_key : str, default None
        AWS Access Key ID. Pass None to use the standard AWS environment
        variables and/or configuration file.
    secret_key : str, default None
        AWS Secret Access key. Pass None to use the standard AWS environment
        variables and/or configuration file.
    session_token : str, default None
        AWS Session Token.  An optional session token, required if access_key
        and secret_key are temporary credentials from STS.
    anonymous : bool, default False
        Whether to connect anonymously if access_key and secret_key are None.
        If true, will not attempt to look up credentials using standard AWS
        configuration methods.
    role_arn : str, default None
        AWS Role ARN.  If provided instead of access_key and secret_key,
        temporary credentials will be fetched by assuming this role.
    session_name : str, default None
        An optional identifier for the assumed role session.
    external_id : str, default None
        An optional unique identifier that might be required when you assume
        a role in another account.
    load_frequency : int, default 900
        The frequency (in seconds) with which temporary credentials from an
        assumed role session will be refreshed.
    region : str, default None
        AWS region to connect to. If not set, the AWS SDK will attempt to
        determine the region using heuristics such as environment variables,
        configuration profile, EC2 metadata, or default to 'us-east-1' when SDK
        version <1.8. One can also use :func:`pyarrow.fs.resolve_s3_region` to
        automatically resolve the region from a bucket name.
    request_timeout : double, default None
        Socket read timeouts on Windows and macOS, in seconds.
        If omitted, the AWS SDK default value is used (typically 3 seconds).
        This option is ignored on non-Windows, non-macOS systems.
    connect_timeout : double, default None
        Socket connection timeout, in seconds.
        If omitted, the AWS SDK default value is used (typically 1 second).
    scheme : str, default 'https'
        S3 connection transport scheme.
    endpoint_override : str, default None
        Override region with a connect string such as "localhost:9000"
    background_writes : bool, default True
        Whether file writes will be issued in the background, without
        blocking.
    default_metadata : mapping or pyarrow.KeyValueMetadata, default None
        Default metadata for open_output_stream.  This will be ignored if
        non-empty metadata is passed to open_output_stream.
    proxy_options : dict or str, default None
        If a proxy is used, provide the options here. Supported options are:
        'scheme' (str: 'http' or 'https'; required), 'host' (str; required),
        'port' (int; required), 'username' (str; optional),
        'password' (str; optional).
        A proxy URI (str) can also be provided, in which case these options
        will be derived from the provided URI.
        The following are equivalent::

            S3FileSystem(proxy_options='http://username:password@localhost:8020')
            S3FileSystem(proxy_options={'scheme': 'http', 'host': 'localhost',
                                        'port': 8020, 'username': 'username',
                                        'password': 'password'})
    allow_delayed_open : bool, default False
        Whether to allow file-open methods to return before the actual open. This option
        may reduce latency as it decreases the number of round trips.
        The downside is failures such as opening a file in a non-existing bucket will
        only be reported when actual I/O is done (at worst, when attempting to close the
        file).
    allow_bucket_creation : bool, default False
        Whether to allow directory creation at the bucket-level. This option may also be
        passed in a URI query parameter.
    allow_bucket_deletion : bool, default False
        Whether to allow directory deletion at the bucket-level. This option may also be
        passed in a URI query parameter.
    check_directory_existence_before_creation : bool, default False
        Whether to check the directory existence before creating it.
        If false, when creating a directory the code will not check if it already
        exists or not. It's an optimization to try directory creation and catch the error,
        rather than issue two dependent I/O calls.
        If true, when creating a directory the code will only create the directory when necessary
        at the cost of extra I/O calls. This can be used for key/value cloud storage which has
        a hard rate limit to number of object mutation operations or scenarios such as
        the directories already exist and you do not have creation access.
    retry_strategy : S3RetryStrategy, default AwsStandardS3RetryStrategy(max_attempts=3)
        The retry strategy to use with S3; fail after max_attempts. Available
        strategies are AwsStandardS3RetryStrategy, AwsDefaultS3RetryStrategy.
    force_virtual_addressing : bool, default False
        Whether to use virtual addressing of buckets.
        If true, then virtual addressing is always enabled.
        If false, then virtual addressing is only enabled if `endpoint_override` is empty.
        This can be used for non-AWS backends that only support virtual hosted-style access.
    tls_ca_file_path : str, default None
        If set, this should be the path of a file containing TLS certificates
        in PEM format which will be used for TLS verification.

    Examples
    --------
    >>> from pyarrow import fs
    >>> s3 = fs.S3FileSystem(region='us-west-2')
    >>> s3.get_file_info(fs.FileSelector(
    ...    'power-analysis-ready-datastore/power_901_constants.zarr/FROCEAN', recursive=True
    ... )) # doctest: +SKIP
    [<FileInfo for 'power-analysis-ready-datastore/power_901_constants.zarr/FROCEAN/.zarray...

    For usage of the methods see examples for :func:`~pyarrow.fs.LocalFileSystem`.
    """
    cdef:
        # Typed raw pointer to the wrapped C++ filesystem; set in init().
        CS3FileSystem* s3fs
    def __init__(self, *, access_key=None, secret_key=None, session_token=None,
                 bint anonymous=False, region=None, request_timeout=None,
                 connect_timeout=None, scheme=None, endpoint_override=None,
                 bint background_writes=True, default_metadata=None,
                 role_arn=None, session_name=None, external_id=None,
                 load_frequency=900, proxy_options=None,
                 allow_delayed_open=False,
                 allow_bucket_creation=False, allow_bucket_deletion=False,
                 check_directory_existence_before_creation=False,
                 retry_strategy: S3RetryStrategy = AwsStandardS3RetryStrategy(
                     max_attempts=3),
                 force_virtual_addressing=False, tls_ca_file_path=None):
        cdef:
            optional[CS3Options] options
            shared_ptr[CS3FileSystem] wrapped
        # Need to do this before initializing `options` as the S3Options
        # constructor has a debug check against use after S3 finalization.
        ensure_s3_initialized()
        # Reject inconsistent credential arguments first, then build the
        # options from the first matching source: explicit keys, anonymous
        # access, an assumed role, or the defaults.
        if access_key is not None and secret_key is None:
            raise ValueError(
                'In order to initialize with explicit credentials both '
                'access_key and secret_key must be provided, '
                '`secret_key` is not set.'
            )
        elif access_key is None and secret_key is not None:
            raise ValueError(
                'In order to initialize with explicit credentials both '
                'access_key and secret_key must be provided, '
                '`access_key` is not set.'
            )
        elif session_token is not None and (access_key is None or
                                            secret_key is None):
            raise ValueError(
                'In order to initialize a session with temporary credentials, '
                'both secret_key and access_key must be provided in addition '
                'to session_token.'
            )
        elif (access_key is not None or secret_key is not None):
            # Both keys are set here: the two branches above already
            # rejected the cases where only one of them was given.
            if anonymous:
                raise ValueError(
                    'Cannot pass anonymous=True together with access_key '
                    'and secret_key.')
            if role_arn:
                raise ValueError(
                    'Cannot provide role_arn with access_key and secret_key')
            if session_token is None:
                session_token = ""
            options = CS3Options.FromAccessKey(
                tobytes(access_key),
                tobytes(secret_key),
                tobytes(session_token)
            )
        elif anonymous:
            if role_arn:
                raise ValueError(
                    'Cannot provide role_arn with anonymous=True')
            options = CS3Options.Anonymous()
        elif role_arn:
            # Missing optional role parameters are passed as empty strings.
            if session_name is None:
                session_name = ''
            if external_id is None:
                external_id = ''
            options = CS3Options.FromAssumeRole(
                tobytes(role_arn),
                tobytes(session_name),
                tobytes(external_id),
                load_frequency
            )
        else:
            options = CS3Options.Defaults()
        # Optional settings only overwrite the option value when given.
        if region is not None:
            options.value().region = tobytes(region)
        if request_timeout is not None:
            options.value().request_timeout = request_timeout
        if connect_timeout is not None:
            options.value().connect_timeout = connect_timeout
        if scheme is not None:
            options.value().scheme = tobytes(scheme)
        if endpoint_override is not None:
            options.value().endpoint_override = tobytes(endpoint_override)
        # NOTE(review): background_writes is declared `bint`, so this check
        # presumably always passes -- confirm.
        if background_writes is not None:
            options.value().background_writes = background_writes
        if default_metadata is not None:
            # Anything that is not already KeyValueMetadata is converted.
            if not isinstance(default_metadata, KeyValueMetadata):
                default_metadata = KeyValueMetadata(default_metadata)
            options.value().default_metadata = pyarrow_unwrap_metadata(
                default_metadata)
        if proxy_options is not None:
            if isinstance(proxy_options, dict):
                # 'scheme', 'host' and 'port' are read unconditionally;
                # 'username' and 'password' are only copied when truthy.
                options.value().proxy_options.scheme = tobytes(
                    proxy_options["scheme"])
                options.value().proxy_options.host = tobytes(
                    proxy_options["host"])
                options.value().proxy_options.port = proxy_options["port"]
                proxy_username = proxy_options.get("username", None)
                if proxy_username:
                    options.value().proxy_options.username = tobytes(
                        proxy_username)
                proxy_password = proxy_options.get("password", None)
                if proxy_password:
                    options.value().proxy_options.password = tobytes(
                        proxy_password)
            elif isinstance(proxy_options, str):
                # A URI string is parsed by the C++ side.
                options.value().proxy_options = GetResultValue(
                    CS3ProxyOptions.FromUriString(tobytes(proxy_options)))
            else:
                raise TypeError(
                    "'proxy_options': expected 'dict' or 'str', "
                    f"got {type(proxy_options)} instead.")
        options.value().allow_delayed_open = allow_delayed_open
        options.value().allow_bucket_creation = allow_bucket_creation
        options.value().allow_bucket_deletion = allow_bucket_deletion
        options.value().check_directory_existence_before_creation = check_directory_existence_before_creation
        options.value().force_virtual_addressing = force_virtual_addressing
        # Only the two strategy subclasses are accepted; a plain
        # S3RetryStrategy instance or any other value raises ValueError.
        if isinstance(retry_strategy, AwsStandardS3RetryStrategy):
            options.value().retry_strategy = CS3RetryStrategy.GetAwsStandardRetryStrategy(
                retry_strategy.max_attempts)
        elif isinstance(retry_strategy, AwsDefaultS3RetryStrategy):
            options.value().retry_strategy = CS3RetryStrategy.GetAwsDefaultRetryStrategy(
                retry_strategy.max_attempts)
        else:
            raise ValueError(f'Invalid retry_strategy {retry_strategy!r}')
        if tls_ca_file_path is not None:
            options.value().tls_ca_file_path = tobytes(tls_ca_file_path)
        # The C++ filesystem is created with the GIL released.
        with nogil:
            wrapped = GetResultValue(CS3FileSystem.Make(options.value()))
        self.init(<shared_ptr[CFileSystem]> wrapped)
    cdef init(self, const shared_ptr[CFileSystem]& wrapped):
        # Let the base class store the shared pointer, then keep a typed raw
        # pointer to the same object for the S3-specific calls below.
        FileSystem.init(self, wrapped)
        self.s3fs = <CS3FileSystem*> wrapped.get()
    @staticmethod
    def _reconstruct(kwargs):
        # __reduce__ doesn't allow passing named arguments directly to the
        # reconstructor, hence this wrapper.
        return S3FileSystem(**kwargs)
    def __reduce__(self):
        # Pickle support: rebuild the constructor keyword arguments from the
        # options of the wrapped C++ filesystem.
        cdef CS3Options opts = self.s3fs.options()
        # if creds were explicitly provided, then use them
        # else obtain them as they were last time.
        if opts.credentials_kind == CS3CredentialsKind_Explicit:
            access_key = frombytes(opts.GetAccessKey())
            secret_key = frombytes(opts.GetSecretKey())
            session_token = frombytes(opts.GetSessionToken())
        else:
            access_key = None
            secret_key = None
            session_token = None
        # NOTE: retry_strategy is not part of these kwargs, so the
        # reconstructed filesystem is built with the constructor's default.
        return (
            S3FileSystem._reconstruct, (dict(
                access_key=access_key,
                secret_key=secret_key,
                session_token=session_token,
                anonymous=(opts.credentials_kind ==
                           CS3CredentialsKind_Anonymous),
                region=frombytes(opts.region),
                scheme=frombytes(opts.scheme),
                connect_timeout=opts.connect_timeout,
                request_timeout=opts.request_timeout,
                endpoint_override=frombytes(opts.endpoint_override),
                role_arn=frombytes(opts.role_arn),
                session_name=frombytes(opts.session_name),
                external_id=frombytes(opts.external_id),
                load_frequency=opts.load_frequency,
                background_writes=opts.background_writes,
                allow_delayed_open=opts.allow_delayed_open,
                allow_bucket_creation=opts.allow_bucket_creation,
                allow_bucket_deletion=opts.allow_bucket_deletion,
                check_directory_existence_before_creation=opts.check_directory_existence_before_creation,
                default_metadata=pyarrow_wrap_metadata(opts.default_metadata),
                proxy_options={'scheme': frombytes(opts.proxy_options.scheme),
                               'host': frombytes(opts.proxy_options.host),
                               'port': opts.proxy_options.port,
                               'username': frombytes(
                                   opts.proxy_options.username),
                               'password': frombytes(
                                   opts.proxy_options.password)},
                force_virtual_addressing=opts.force_virtual_addressing,
                tls_ca_file_path=frombytes(opts.tls_ca_file_path),
            ),)
        )
    @property
    def region(self):
        """
        The AWS region this filesystem connects to.
        """
        return frombytes(self.s3fs.region())

View File

@@ -0,0 +1,481 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# cython: language_level = 3
from cython.operator cimport dereference as deref
from libcpp.vector cimport vector as std_vector
from pyarrow import Buffer, py_buffer
from pyarrow._compute cimport Expression
from pyarrow.lib import frombytes, tobytes
from pyarrow.lib cimport *
from pyarrow.includes.libarrow cimport *
from pyarrow.includes.libarrow_substrait cimport *
try:
import substrait as py_substrait
except ImportError:
py_substrait = None
else:
import substrait.proto # no-cython-lint
# TODO GH-37235: Fix exception handling
cdef CDeclaration _create_named_table_provider(
    dict named_args, const std_vector[c_string]& names, const CSchema& schema
) noexcept:
    # Resolve a NamedTable relation by calling named_args["provider"] with
    # the decoded table names and the wrapped schema, then expose the
    # returned table through a "table_source" declaration.
    cdef:
        c_string c_table_name
        shared_ptr[CTable] c_table
        shared_ptr[CTableSourceNodeOptions] c_source_opts
        shared_ptr[CExecNodeOptions] c_node_opts
        vector[CDeclaration.Input] c_no_inputs

    table_names = []
    for idx in range(names.size()):
        c_table_name = names[idx]
        table_names.append(frombytes(c_table_name))

    wrapped_schema = pyarrow_wrap_schema(make_shared[CSchema](schema))
    provided_table = named_args["provider"](table_names, wrapped_schema)

    c_table = pyarrow_unwrap_table(provided_table)
    c_source_opts = make_shared[CTableSourceNodeOptions](c_table)
    c_node_opts = static_pointer_cast[CExecNodeOptions, CTableSourceNodeOptions](
        c_source_opts)
    return CDeclaration(tobytes("table_source"),
                        c_no_inputs, c_node_opts)
def run_query(plan, *, table_provider=None, use_threads=True):
    """
    Execute a Substrait plan and read the results as a RecordBatchReader.

    Parameters
    ----------
    plan : Union[Buffer, bytes]
        The serialized Substrait plan to execute.
    table_provider : object (optional)
        A function to resolve any NamedTable relation to a table.
        The function will receive two arguments which will be a list
        of strings representing the table name and a pyarrow.Schema representing
        the expected schema and should return a pyarrow.Table.
    use_threads : bool, default True
        If True then multiple threads will be used to run the query.  If False then
        all CPU intensive work will be done on the calling thread.

    Returns
    -------
    RecordBatchReader
        A reader containing the result of the executed query

    Examples
    --------
    >>> import pyarrow as pa
    >>> from pyarrow.lib import tobytes
    >>> import pyarrow.substrait as substrait
    >>> test_table_1 = pa.Table.from_pydict({"x": [1, 2, 3]})
    >>> test_table_2 = pa.Table.from_pydict({"x": [4, 5, 6]})
    >>> def table_provider(names, schema):
    ...     if not names:
    ...        raise Exception("No names provided")
    ...     elif names[0] == "t1":
    ...        return test_table_1
    ...     elif names[0] == "t2":
    ...        return test_table_2
    ...     else:
    ...        raise Exception("Unrecognized table name")
    ...
    >>> substrait_query = '''
    ...         {
    ...             "relations": [
    ...             {"rel": {
    ...                 "read": {
    ...                 "base_schema": {
    ...                     "struct": {
    ...                     "types": [
    ...                                 {"i64": {}}
    ...                             ]
    ...                     },
    ...                     "names": [
    ...                             "x"
    ...                             ]
    ...                 },
    ...                 "namedTable": {
    ...                         "names": ["t1"]
    ...                 }
    ...                 }
    ...             }}
    ...             ]
    ...         }
    ... '''
    >>> buf = pa._substrait._parse_json_plan(tobytes(substrait_query))
    >>> reader = pa.substrait.run_query(buf, table_provider=table_provider)
    >>> reader.read_all()
    pyarrow.Table
    x: int64
    ----
    x: [[1,2,3]]
    """
    cdef:
        CResult[shared_ptr[CRecordBatchReader]] c_res_reader
        shared_ptr[CRecordBatchReader] c_reader
        RecordBatchReader reader
        shared_ptr[CBuffer] c_buf_plan
        CConversionOptions c_conversion_options
        c_bool c_use_threads = use_threads

    # Bring the plan into a C++ buffer; bytes and memoryview objects are
    # first wrapped with py_buffer().
    if isinstance(plan, Buffer):
        c_buf_plan = pyarrow_unwrap_buffer(plan)
    elif isinstance(plan, (bytes, memoryview)):
        c_buf_plan = pyarrow_unwrap_buffer(py_buffer(plan))
    else:
        raise TypeError(
            f"Expected 'pyarrow.Buffer' or bytes, got '{type(plan)}'")

    if table_provider is not None:
        # The callback reads the Python callable back from this dict
        # under the "provider" key.
        provider_args = {"provider": table_provider}
        c_conversion_options.named_table_provider = BindFunction[CNamedTableProvider](
            &_create_named_table_provider, provider_args)

    with nogil:
        c_res_reader = ExecuteSerializedPlan(
            deref(c_buf_plan), default_extension_id_registry(),
            GetFunctionRegistry(), c_conversion_options, c_use_threads)

    c_reader = GetResultValue(c_res_reader)

    reader = RecordBatchReader.__new__(RecordBatchReader)
    reader.reader = c_reader
    return reader
def _parse_json_plan(plan):
    """
    Convert a Substrait plan written as JSON into serialized Protobuf.

    Parameters
    ----------
    plan : bytes
        The Substrait plan, encoded as JSON.

    Returns
    -------
    Buffer
        Buffer holding the Protobuf-serialized plan.
    """
    cdef:
        c_string c_json_plan
        CResult[shared_ptr[CBuffer]] c_serialized
        shared_ptr[CBuffer] c_plan_buffer

    # Coerce the Python object to a C++ string before handing it to Arrow.
    c_json_plan = plan
    c_serialized = SerializeJsonPlan(c_json_plan)
    with nogil:
        c_plan_buffer = GetResultValue(c_serialized)
    return pyarrow_wrap_buffer(c_plan_buffer)
class SubstraitSchema:
    """Container for a schema encoded for use with Substrait.

    Two encodings of the same schema are kept side by side: a Substrait
    ``NamedStruct`` and a Substrait ``ExtendedExpression``.

    The ``ExtendedExpression`` form exists for schemas whose types need
    extensions in order to be decoded. In that case the schema is the
    ``base_schema`` of the ``ExtendedExpression``, which also provides
    all the extensions.
    """

    def __init__(self, schema, expression):
        self.expression = expression
        self.schema = schema

    def to_pysubstrait(self):
        """Return the schema as a substrait-python ExtendedExpression object."""
        if py_substrait is not None:
            return py_substrait.proto.ExtendedExpression.FromString(self.expression)
        raise ImportError("The 'substrait' package is required.")
def serialize_schema(schema):
    """
    Build a SubstraitSchema from an Arrow schema.

    Parameters
    ----------
    schema : Schema
        The schema to serialize

    Returns
    -------
    SubstraitSchema
        Object holding both the ``NamedStruct`` encoding and the
        ``ExtendedExpression`` encoding of the schema.
    """
    named_struct = _serialize_namedstruct_schema(schema)
    # An expression-less ExtendedExpression bound to the schema; Arrow
    # extensions are allowed for this encoding.
    extended_expression = serialize_expressions(
        [], [], schema, allow_arrow_extensions=True)
    return SubstraitSchema(schema=named_struct, expression=extended_expression)
def _serialize_namedstruct_schema(schema):
    """
    Serialize a schema with ``SerializeSchema`` and return it as a memoryview.

    Parameters
    ----------
    schema : Schema
        The schema to serialize. It is cast to ``Schema`` without a
        type check.

    Returns
    -------
    memoryview
        View over the buffer produced by ``SerializeSchema``.
    """
    cdef:
        CExtensionSet c_extension_set
        CConversionOptions c_options
        CResult[shared_ptr[CBuffer]] c_serialized
        shared_ptr[CBuffer] c_out

    with nogil:
        c_serialized = SerializeSchema(
            deref((<Schema> schema).sp_schema), &c_extension_set, c_options)
        c_out = GetResultValue(c_serialized)
    return memoryview(pyarrow_wrap_buffer(c_out))
def deserialize_schema(buf):
    """
    Turn a Substrait ``NamedStruct`` message, or a SubstraitSchema
    object, back into an Arrow Schema.

    Parameters
    ----------
    buf : Buffer or bytes or SubstraitSchema
        The message to deserialize

    Returns
    -------
    Schema
        The deserialized schema
    """
    cdef:
        shared_ptr[CBuffer] c_message
        CResult[shared_ptr[CSchema]] c_parsed
        shared_ptr[CSchema] c_result
        CConversionOptions c_options
        CExtensionSet c_extension_set

    # A SubstraitSchema carries an ExtendedExpression: decode that and
    # take the schema it is bound to.
    if isinstance(buf, SubstraitSchema):
        return deserialize_expressions(buf.expression).schema

    if isinstance(buf, Buffer):
        c_message = pyarrow_unwrap_buffer(buf)
    elif isinstance(buf, (bytes, memoryview)):
        c_message = pyarrow_unwrap_buffer(py_buffer(buf))
    else:
        raise TypeError(
            f"Expected 'pyarrow.Buffer' or bytes, got '{type(buf)}'")

    with nogil:
        c_parsed = DeserializeSchema(
            deref(c_message), c_extension_set, c_options)
        c_result = GetResultValue(c_parsed)
    return pyarrow_wrap_schema(c_result)
def serialize_expressions(exprs, names, schema, *, allow_arrow_extensions=False):
    """
    Serialize a collection of expressions into Substrait

    Substrait expressions must be bound to a schema. For example,
    the Substrait expression ``a:i32 + b:i32`` is different from the
    Substrait expression ``a:i64 + b:i64``. Pyarrow expressions are
    typically unbound. For example, both of the above expressions
    would be represented as ``a + b`` in pyarrow.

    This means a schema must be provided when serializing an expression.
    It also means that the serialization may fail if a matching function
    call cannot be found for the expression.

    Parameters
    ----------
    exprs : list of Expression
        The expressions to serialize
    names : list of str
        Names for the expressions
    schema : Schema
        The schema the expressions will be bound to
    allow_arrow_extensions : bool, default False
        If False then only functions that are part of the core Substrait function
        definitions will be allowed. Set this to True to allow pyarrow-specific functions
        and user defined functions but the result may not be accepted by other
        compute libraries.

    Returns
    -------
    Buffer
        An ExtendedExpression message containing the serialized expressions

    Raises
    ------
    ValueError
        If ``exprs`` and ``names`` have different lengths.
    TypeError
        If an element of ``exprs`` is not an Expression, an element of
        ``names`` is not a str, or ``schema`` is not a Schema.
    """
    cdef:
        CResult[shared_ptr[CBuffer]] c_res_buffer
        shared_ptr[CBuffer] c_buffer
        CNamedExpression c_named_expr
        CBoundExpressions c_bound_exprs
        CConversionOptions c_conversion_options

    if len(exprs) != len(names):
        raise ValueError("exprs and names need to have the same length")

    for expr, name in zip(exprs, names):
        if not isinstance(expr, Expression):
            raise TypeError(f"Expected Expression, got '{type(expr)}' in exprs")
        if not isinstance(name, str):
            raise TypeError(f"Expected str, got '{type(name)}' in names")
        c_named_expr.expression = (<Expression> expr).unwrap()
        c_named_expr.name = tobytes(<str> name)
        c_bound_exprs.named_expressions.push_back(c_named_expr)

    # The cast below is unchecked, so validate the argument the same way
    # exprs and names are validated instead of reading ``sp_schema`` off
    # an object of some other type.
    if not isinstance(schema, Schema):
        raise TypeError(f"Expected Schema, got '{type(schema)}' for schema")
    c_bound_exprs.schema = (<Schema> schema).sp_schema

    c_conversion_options.allow_arrow_extensions = allow_arrow_extensions

    with nogil:
        c_res_buffer = SerializeExpressions(c_bound_exprs, c_conversion_options)
        c_buffer = GetResultValue(c_res_buffer)
    return memoryview(pyarrow_wrap_buffer(c_buffer))
cdef class BoundExpressions(_Weakrefable):
    """
    A collection of named expressions and the schema they are bound to

    This is equivalent to the Substrait ExtendedExpression message
    """

    cdef:
        CBoundExpressions c_bound_exprs

    def __init__(self):
        # Instances are only created through the ``wrap`` factory below.
        msg = 'BoundExpressions is an abstract class thus cannot be initialized.'
        raise TypeError(msg)

    cdef void init(self, CBoundExpressions bound_expressions):
        self.c_bound_exprs = bound_expressions

    @property
    def schema(self):
        """
        The common schema that all expressions are bound to
        """
        return pyarrow_wrap_schema(self.c_bound_exprs.schema)

    @property
    def expressions(self):
        """
        A dict from expression name to expression

        If several expressions share a name, the last one wins.
        """
        expr_dict = {}
        for named_expr in self.c_bound_exprs.named_expressions:
            name = frombytes(named_expr.name)
            expr = Expression.wrap(named_expr.expression)
            expr_dict[name] = expr
        return expr_dict

    @staticmethod
    cdef wrap(const CBoundExpressions& bound_expressions):
        # Bypass __init__ (which always raises) and fill in the C++ state.
        cdef BoundExpressions self = BoundExpressions.__new__(BoundExpressions)
        self.init(bound_expressions)
        return self

    @classmethod
    def from_substrait(cls, message):
        """
        Convert a Substrait message into a BoundExpressions object

        Parameters
        ----------
        message : Buffer or bytes or protobuf Message
            The message to convert to a BoundExpressions object

        Returns
        -------
        BoundExpressions
            The converted expressions, their names, and the bound schema

        Raises
        ------
        TypeError
            If ``message`` is neither a Buffer, bytes or memoryview, nor
            an object exposing ``SerializeToString``.
        """
        if isinstance(message, (bytes, memoryview, Buffer)):
            return deserialize_expressions(message)

        # Guard only the attribute lookup: an AttributeError raised while
        # serializing or deserializing is a genuine failure and must not be
        # reported as "wrong argument type".
        try:
            serialize = message.SerializeToString
        except AttributeError:
            raise TypeError(
                f"Expected 'pyarrow.Buffer' or bytes or protobuf Message, got '{type(message)}'")
        return deserialize_expressions(serialize())
def deserialize_expressions(buf):
    """
    Decode a Substrait ExtendedExpression message into a BoundExpressions object

    Parameters
    ----------
    buf : Buffer or bytes
        The message to deserialize

    Returns
    -------
    BoundExpressions
        The deserialized expressions, their names, and the bound schema
    """
    cdef:
        shared_ptr[CBuffer] c_message
        CResult[CBoundExpressions] c_parsed
        CBoundExpressions c_result

    if isinstance(buf, Buffer):
        c_message = pyarrow_unwrap_buffer(buf)
    elif isinstance(buf, (bytes, memoryview)):
        # Wrap the Python buffer object in an Arrow Buffer first.
        c_message = pyarrow_unwrap_buffer(py_buffer(buf))
    else:
        raise TypeError(
            f"Expected 'pyarrow.Buffer' or bytes, got '{type(buf)}'")

    with nogil:
        c_parsed = DeserializeExpressions(deref(c_message))
        c_result = GetResultValue(c_parsed)
    return BoundExpressions.wrap(c_result)
def get_supported_functions():
    """
    List the Substrait functions supported by the underlying engine.

    Returns
    -------
    list[str]
        A list of function ids encoded as '{uri}#{name}'
    """
    cdef:
        ExtensionIdRegistry* c_registry
        std_vector[c_string] c_function_ids

    c_registry = default_extension_id_registry()
    c_function_ids = c_registry.GetSupportedSubstraitFunctions()
    # Decode every C++ string id into a Python str.
    return [frombytes(c_function_id) for c_function_id in c_function_ids]

View File

@@ -0,0 +1,418 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# ---------------------------------------------------------------------
# Implement Internal ExecPlan bindings
# cython: profile=False
# distutils: language = c++
# cython: language_level = 3
from pyarrow.lib import Table, RecordBatch, array
from pyarrow.compute import Expression, field
try:
from pyarrow._acero import ( # noqa
Declaration,
ExecNodeOptions,
TableSourceNodeOptions,
FilterNodeOptions,
ProjectNodeOptions,
AggregateNodeOptions,
OrderByNodeOptions,
HashJoinNodeOptions,
AsofJoinNodeOptions,
)
except ImportError as exc:
raise ImportError(
f"The pyarrow installation is not built with support for 'acero' ({str(exc)})"
) from None
try:
    import pyarrow.dataset as ds
    from pyarrow._dataset import ScanNodeOptions
except ImportError:
    # The dataset module could not be imported. Bind ``ds`` to a stand-in
    # that exposes the ``Dataset`` and ``InMemoryDataset`` names, so code in
    # this module that refers to ``ds.Dataset`` / ``ds.InMemoryDataset``
    # still resolves those attributes.
    # NOTE: ``ScanNodeOptions`` is left undefined on this path.
    class DatasetModuleStub:
        class Dataset:
            pass
        class InMemoryDataset:
            pass
    ds = DatasetModuleStub
def _dataset_to_decl(dataset, use_threads=True, implicit_ordering=False):
    """
    Build the Declaration that reads ``dataset``.

    The result is a "scan" node followed by a "project" node selecting the
    dataset's schema columns, followed by a "filter" node when the dataset's
    scan options carry a filter.
    """
    scan_options = ScanNodeOptions(
        dataset, use_threads=use_threads,
        implicit_ordering=implicit_ordering)
    scan_node = Declaration("scan", scan_options)

    # Get rid of special dataset columns
    # "__fragment_index", "__batch_index", "__last_in_fragment", "__filename"
    column_refs = [field(name) for name in dataset.schema.names]
    project_node = Declaration("project", ProjectNodeOptions(column_refs))
    plan = Declaration.from_sequence([scan_node, project_node])

    scan_filter = dataset._scan_options.get("filter")
    if scan_filter is None:
        return plan

    # Filters applied in CScanNodeOptions are "best effort" for the scan node itself
    # so we always need to inject an additional Filter node to apply them for real.
    filter_node = Declaration("filter", FilterNodeOptions(scan_filter))
    return Declaration.from_sequence([plan, filter_node])
def _perform_join(join_type, left_operand, left_keys,
                  right_operand, right_keys,
                  left_suffix=None, right_suffix=None,
                  use_threads=True, coalesce_keys=False,
                  output_type=Table, filter_expression=None):
    """
    Join two tables or datasets with a hash join.

    The join is executed and its result returned as a table (or wrapped
    in an in-memory dataset, depending on ``output_type``).

    Parameters
    ----------
    join_type : str
        One of supported join types.
    left_operand : Table or Dataset
        Left-hand input of the join.
    left_keys : str or list[str]
        Key (or keys) of the left operand to join on.
    right_operand : Table or Dataset
        Right-hand input of the join.
    right_keys : str or list[str]
        Key (or keys) of the right operand to join on.
    left_suffix : str, default None
        Suffix added to left column names, to disambiguate columns whose
        names collide between the two operands.
    right_suffix : str, default None
        Suffix added to right column names, to disambiguate columns whose
        names collide between the two operands.
    use_threads : bool, default True
        Whether to use multithreading or not.
    coalesce_keys : bool, default False
        If the duplicated keys should be omitted from one of the sides
        in the join result.
    output_type: Table or InMemoryDataset
        The output type for the exec plan result.
    filter_expression : pyarrow.compute.Expression
        Residual filter which is applied to matching row.

    Returns
    -------
    result_table : Table or InMemoryDataset
    """
    for operand in (left_operand, right_operand):
        if not isinstance(operand, (Table, ds.Dataset)):
            raise TypeError(f"Expected Table or Dataset, got {type(operand)}")

    # Accept a single key as well as a sequence, and remember the position
    # of every key so left and right keys can be paired up later.
    if not isinstance(left_keys, (tuple, list)):
        left_keys = [left_keys]
    if not isinstance(right_keys, (list, tuple)):
        right_keys = [right_keys]
    left_key_positions = {key: pos for pos, key in enumerate(left_keys)}
    right_key_positions = {key: pos for pos, key in enumerate(right_keys)}

    # By default expose all columns on both left and right table
    left_columns = left_operand.schema.names
    right_columns = right_operand.schema.names

    # Restrict the output columns according to the join type
    if join_type in ("left semi", "left anti"):
        right_columns = []
    elif join_type in ("right semi", "right anti"):
        left_columns = []
    elif join_type in ("inner", "left outer"):
        right_columns = [
            name for name in right_columns if name not in right_key_positions
        ]
    elif join_type == "right outer":
        left_columns = [
            name for name in left_columns if name not in left_key_positions
        ]

    # Position of each key column within the columns kept on its side.
    left_key_columns = {
        name: pos for pos, name in enumerate(left_columns)
        if name in left_keys
    }
    right_key_columns = {
        name: pos for pos, name in enumerate(right_columns)
        if name in right_keys
    }

    # Source nodes: datasets are scanned, tables are read directly.
    left_source = (
        _dataset_to_decl(left_operand, use_threads=use_threads)
        if isinstance(left_operand, ds.Dataset)
        else Declaration("table_source", TableSourceNodeOptions(left_operand))
    )
    right_source = (
        _dataset_to_decl(right_operand, use_threads=use_threads)
        if isinstance(right_operand, ds.Dataset)
        else Declaration("table_source", TableSourceNodeOptions(right_operand))
    )

    shared_join_kwargs = dict(
        output_suffix_for_left=left_suffix or "",
        output_suffix_for_right=right_suffix or "",
        filter_expression=filter_expression,
    )
    if coalesce_keys:
        # Only when coalescing are the output columns passed explicitly.
        join_opts = HashJoinNodeOptions(
            join_type, left_keys, right_keys, left_columns, right_columns,
            **shared_join_kwargs
        )
    else:
        join_opts = HashJoinNodeOptions(
            join_type, left_keys, right_keys, **shared_join_kwargs
        )
    decl = Declaration(
        "hashjoin", options=join_opts, inputs=[left_source, right_source]
    )

    if coalesce_keys and join_type == "full outer":
        # In case of full outer joins, the join operation will output all columns
        # so that we can coalesce the keys and exclude duplicates in a subsequent
        # projection.
        left_names = set(left_columns)
        right_names = set(right_columns)
        # Where the right table columns start.
        right_offset = len(left_columns)
        output_names = []
        output_exprs = []
        for pos, name in enumerate(left_columns + right_columns):
            from_left = pos < right_offset
            if from_left and name in left_key_columns:
                # Include keys only once and coalesce left+right table keys.
                # The right key paired with this left key is the one at the
                # same position in the provided keys; look up where that
                # right key sits among the right table columns.
                paired_right_key = right_keys[left_key_positions[name]]
                right_pos = right_offset + right_key_columns[paired_right_key]
                output_names.append(name)
                output_exprs.append(
                    Expression._call("coalesce", [
                        Expression._field(pos), Expression._field(right_pos)
                    ])
                )
                continue
            if not from_left and name in right_key_columns:
                # Do not include right table keys. As they would lead to duplicated keys
                continue
            # For all the other columns include them as they are.
            # Just recompute the suffixes that the join produced as the projection
            # would lose them otherwise.
            if left_suffix and from_left and name in right_names:
                name += left_suffix
            if right_suffix and not from_left and name in left_names:
                name += right_suffix
            output_names.append(name)
            output_exprs.append(Expression._field(pos))
        projection = Declaration(
            "project", ProjectNodeOptions(output_exprs, output_names)
        )
        decl = Declaration.from_sequence([decl, projection])

    result_table = decl.to_table(use_threads=use_threads)

    if output_type == Table:
        return result_table
    if output_type == ds.InMemoryDataset:
        return ds.InMemoryDataset(result_table)
    raise TypeError("Unsupported output type")
def _perform_join_asof(left_operand, left_on, left_by,
                       right_operand, right_on, right_by,
                       tolerance, use_threads=True,
                       output_type=Table):
    """
    Perform asof join of two tables or datasets.

    The result will be an output table with the result of the join operation

    Parameters
    ----------
    left_operand : Table or Dataset
        The left operand for the join operation.
    left_on : str
        The "on" key of the left operand.
    left_by : str or list[str]
        The "by" key (or keys) of the left operand.
    right_operand : Table or Dataset
        The right operand for the join operation.
    right_on : str
        The "on" key of the right operand.
    right_by : str or list[str]
        The "by" key (or keys) of the right operand.
    tolerance : int
        The tolerance to use for the asof join. The tolerance is interpreted in
        the same units as the "on" key.
    use_threads : bool, default True
        Whether to use multithreading or not.
    output_type: Table or InMemoryDataset
        The output type for the exec plan result.

    Returns
    -------
    result_table : Table or InMemoryDataset

    Raises
    ------
    TypeError
        If an operand is neither a Table nor a Dataset, or if
        ``output_type`` is not supported.
    ValueError
        If a non-key column of the right operand has the same name as a
        column of the left operand.
    """
    if not isinstance(left_operand, (Table, ds.Dataset)):
        raise TypeError(f"Expected Table or Dataset, got {type(left_operand)}")
    if not isinstance(right_operand, (Table, ds.Dataset)):
        raise TypeError(f"Expected Table or Dataset, got {type(right_operand)}")

    if not isinstance(left_by, (tuple, list)):
        left_by = [left_by]
    if not isinstance(right_by, (tuple, list)):
        right_by = [right_by]

    # AsofJoin does not return on or by columns for right_operand.
    # The key list is built once, by unpacking: ``[right_on] + right_by``
    # raises TypeError when ``right_by`` is a tuple, which is accepted above.
    right_key_columns = [right_on, *right_by]
    right_columns = [
        col for col in right_operand.schema.names
        if col not in right_key_columns
    ]
    columns_collisions = set(left_operand.schema.names) & set(right_columns)
    if columns_collisions:
        raise ValueError(
            f"Columns {columns_collisions} present in both tables. "
            "AsofJoin does not support column collisions."
        )

    # Add the join node to the execplan
    if isinstance(left_operand, ds.Dataset):
        left_source = _dataset_to_decl(
            left_operand,
            use_threads=use_threads,
            implicit_ordering=True)
    else:
        left_source = Declaration(
            "table_source", TableSourceNodeOptions(left_operand),
        )
    if isinstance(right_operand, ds.Dataset):
        right_source = _dataset_to_decl(
            right_operand, use_threads=use_threads,
            implicit_ordering=True)
    else:
        right_source = Declaration(
            "table_source", TableSourceNodeOptions(right_operand)
        )

    join_opts = AsofJoinNodeOptions(
        left_on, left_by, right_on, right_by, tolerance
    )
    decl = Declaration(
        "asofjoin", options=join_opts, inputs=[left_source, right_source]
    )

    result_table = decl.to_table(use_threads=use_threads)

    if output_type == Table:
        return result_table
    elif output_type == ds.InMemoryDataset:
        return ds.InMemoryDataset(result_table)
    else:
        raise TypeError("Unsupported output type")
def _filter_table(table, expression):
    """Keep only the rows of a table that match an expression.

    A RecordBatch input is wrapped into a one-batch table for the
    duration of the filtering and the result is converted back to a
    RecordBatch.

    Parameters
    ----------
    table : Table or RecordBatch
        Data to filter.
    expression : Expression
        The expression rows must satisfy to be kept.

    Returns
    -------
    Table or RecordBatch
    """
    input_is_batch = isinstance(table, RecordBatch)
    if input_is_batch:
        table = Table.from_batches([table])

    source_node = Declaration(
        "table_source", options=TableSourceNodeOptions(table))
    filter_node = Declaration("filter", options=FilterNodeOptions(expression))
    plan = Declaration.from_sequence([source_node, filter_node])
    filtered = plan.to_table(use_threads=True)

    if not input_is_batch:
        return filtered
    if filtered.num_rows > 0:
        return filtered.combine_chunks().to_batches()[0]
    # No rows left: build an empty batch with the result's schema.
    empty_columns = [
        array([], type=schema_field.type) for schema_field in filtered.schema
    ]
    return RecordBatch.from_arrays(empty_columns, schema=filtered.schema)
def _sort_source(table_or_dataset, sort_keys, output_type=Table, **kwargs):
    """
    Sort a table or dataset through an "order_by" node.

    ``sort_keys`` and any extra keyword arguments are forwarded to
    ``OrderByNodeOptions``. The result is returned as a Table, or wrapped
    in an InMemoryDataset, depending on ``output_type``; any other
    ``output_type`` raises TypeError.
    """
    if isinstance(table_or_dataset, ds.Dataset):
        source_node = _dataset_to_decl(table_or_dataset, use_threads=True)
    else:
        source_node = Declaration(
            "table_source", TableSourceNodeOptions(table_or_dataset)
        )

    sort_node = Declaration("order_by", OrderByNodeOptions(sort_keys, **kwargs))
    plan = Declaration.from_sequence([source_node, sort_node])
    sorted_table = plan.to_table(use_threads=True)

    if output_type == Table:
        return sorted_table
    if output_type == ds.InMemoryDataset:
        return ds.InMemoryDataset(sorted_table)
    raise TypeError("Unsupported output type")
def _group_by(table, aggregates, keys, use_threads=True):
    """
    Aggregate ``table`` through an "aggregate" node and return the result.

    ``aggregates`` and ``keys`` are forwarded to ``AggregateNodeOptions``.
    """
    source_node = Declaration("table_source", TableSourceNodeOptions(table))
    aggregate_node = Declaration(
        "aggregate", AggregateNodeOptions(aggregates, keys=keys))
    plan = Declaration.from_sequence([source_node, aggregate_node])
    return plan.to_table(use_threads=use_threads)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
def benchmark_PandasObjectIsNull(list obj):
    """
    Call ``Benchmark_PandasObjectIsNull`` on a list of objects.

    Parameters
    ----------
    obj : list
        The objects handed to the benchmark routine.
    """
    Benchmark_PandasObjectIsNull(obj)

View File

@@ -0,0 +1,21 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# flake8: noqa
from pyarrow.lib import benchmark_PandasObjectIsNull

View File

@@ -0,0 +1,150 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import math
cdef class StringBuilder(_Weakrefable):
    """
    Incremental builder for UTF8 string arrays.

    Values (strings or nulls) are added one at a time or in bulk, and
    ``finish`` turns what was accumulated into a pyarrow.Array
    (type='string').
    """
    cdef:
        unique_ptr[CStringBuilder] builder

    def __cinit__(self, MemoryPool memory_pool=None):
        cdef CMemoryPool* c_pool = maybe_unbox_memory_pool(memory_pool)
        self.builder.reset(new CStringBuilder(c_pool))

    def append(self, value):
        """
        Add one value to the builder.

        Strings and bytes are appended as values; None and NaN
        (e.g. np.nan) are appended as nulls.

        Parameters
        ----------
        value : string/bytes or np.nan/None
            The value to append to the string array builder.
        """
        # NOTE(review): whatever Append()/AppendNull() return is not
        # inspected here - confirm whether failures need to be surfaced.
        if value is None:
            self.builder.get().AppendNull()
        elif isinstance(value, (bytes, str)):
            self.builder.get().Append(tobytes(value))
        elif math.isnan(value):
            self.builder.get().AppendNull()
        else:
            raise TypeError('StringBuilder only accepts string objects')

    def append_values(self, values):
        """
        Add every value of an iterable, in order.

        Parameters
        ----------
        values : iterable of string/bytes or np.nan/None values
            The values to append to the string array builder.
        """
        for item in values:
            self.append(item)

    def finish(self):
        """
        Return result of builder as an Array object; also resets the builder.

        Returns
        -------
        array : pyarrow.Array
        """
        cdef shared_ptr[CArray] c_result
        # NOTE(review): the value returned by Finish() is discarded -
        # confirm whether a failure should raise instead.
        with nogil:
            self.builder.get().Finish(&c_result)
        return pyarrow_wrap_array(c_result)

    @property
    def null_count(self):
        """The null count reported by the underlying builder."""
        return self.builder.get().null_count()

    def __len__(self):
        """The length reported by the underlying builder."""
        return self.builder.get().length()
cdef class StringViewBuilder(_Weakrefable):
    """
    Incremental builder for UTF8 string view arrays.

    Values (strings or nulls) are added one at a time or in bulk, and
    ``finish`` turns what was accumulated into a pyarrow.Array
    (type='string_view').
    """
    cdef:
        unique_ptr[CStringViewBuilder] builder

    def __cinit__(self, MemoryPool memory_pool=None):
        cdef CMemoryPool* c_pool = maybe_unbox_memory_pool(memory_pool)
        self.builder.reset(new CStringViewBuilder(c_pool))

    def append(self, value):
        """
        Add one value to the builder.

        Strings and bytes are appended as values; None and NaN
        (e.g. np.nan) are appended as nulls.

        Parameters
        ----------
        value : string/bytes or np.nan/None
            The value to append to the string array builder.
        """
        # NOTE(review): whatever Append()/AppendNull() return is not
        # inspected here - confirm whether failures need to be surfaced.
        if value is None:
            self.builder.get().AppendNull()
        elif isinstance(value, (bytes, str)):
            self.builder.get().Append(tobytes(value))
        elif math.isnan(value):
            self.builder.get().AppendNull()
        else:
            raise TypeError('StringViewBuilder only accepts string objects')

    def append_values(self, values):
        """
        Add every value of an iterable, in order.

        Parameters
        ----------
        values : iterable of string/bytes or np.nan/None values
            The values to append to the string array builder.
        """
        for item in values:
            self.append(item)

    def finish(self):
        """
        Return result of builder as an Array object; also resets the builder.

        Returns
        -------
        array : pyarrow.Array
        """
        cdef shared_ptr[CArray] c_result
        # NOTE(review): the value returned by Finish() is discarded -
        # confirm whether a failure should raise instead.
        with nogil:
            self.builder.get().Finish(&c_result)
        return pyarrow_wrap_array(c_result)

    @property
    def null_count(self):
        """The null count reported by the underlying builder."""
        return self.builder.get().null_count()

    def __len__(self):
        """The length reported by the underlying builder."""
        return self.builder.get().length()

View File

@@ -0,0 +1,81 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import
import cffi
# C declarations of the ArrowSchema, ArrowArray, ArrowArrayStream and
# ArrowDeviceArray structs (plus the ArrowDeviceType typedef), handed to
# cffi below.
c_source = """
struct ArrowSchema {
// Array type description
const char* format;
const char* name;
const char* metadata;
int64_t flags;
int64_t n_children;
struct ArrowSchema** children;
struct ArrowSchema* dictionary;
// Release callback
void (*release)(struct ArrowSchema*);
// Opaque producer-specific data
void* private_data;
};
struct ArrowArray {
// Array data description
int64_t length;
int64_t null_count;
int64_t offset;
int64_t n_buffers;
int64_t n_children;
const void** buffers;
struct ArrowArray** children;
struct ArrowArray* dictionary;
// Release callback
void (*release)(struct ArrowArray*);
// Opaque producer-specific data
void* private_data;
};
struct ArrowArrayStream {
int (*get_schema)(struct ArrowArrayStream*, struct ArrowSchema* out);
int (*get_next)(struct ArrowArrayStream*, struct ArrowArray* out);
const char* (*get_last_error)(struct ArrowArrayStream*);
// Release callback
void (*release)(struct ArrowArrayStream*);
// Opaque producer-specific data
void* private_data;
};
typedef int32_t ArrowDeviceType;
struct ArrowDeviceArray {
struct ArrowArray array;
int64_t device_id;
ArrowDeviceType device_type;
void* sync_event;
int64_t reserved[3];
};
"""
# TODO use out-of-line mode for faster import and avoid C parsing
# Module-level FFI instance; the declarations above are parsed at import time.
ffi = cffi.FFI()
ffi.cdef(c_source)

View File

@@ -0,0 +1,71 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
def encode_file_path(path):
    """
    Return ``path`` as bytes, UTF-8 encoding it when it is a ``str``.

    Any non-``str`` input is returned unchanged.
    """
    # POSIX systems can handle utf-8. Windows file systems require utf-16le
    # for file names; the Arrow C++ libraries convert utf8 to utf16, so
    # utf-8 is produced here in both cases.
    return path.encode('utf-8') if isinstance(path, str) else path
# Starting with Python 3.7, dicts are guaranteed to be insertion-ordered.
ordered_dict = dict
try:
import cloudpickle as pickle
except ImportError:
import pickle
def tobytes(o):
    """
    Return a unicode or bytes string as bytes.

    Parameters
    ----------
    o : str or bytes
        Input string.

    Returns
    -------
    bytes
        The UTF-8 encoding of ``o`` when it is a ``str``; otherwise ``o``
        itself, unchanged.
    """
    if not isinstance(o, str):
        return o
    return o.encode('utf8')
def frombytes(o, *, safe=False):
    """
    Decode the given bytestring to unicode.

    Parameters
    ----------
    o : bytes-like
        Input object.
    safe : bool, default False
        If true, never raise on invalid UTF-8: undecodable bytes are
        replaced with the Unicode replacement character (U+FFFD).
        If false, invalid UTF-8 raises ``UnicodeDecodeError``.

    Returns
    -------
    str
        The decoded text.
    """
    if safe:
        # "Safe" means lossy-but-never-failing, not strict.
        return o.decode('utf8', errors='replace')
    else:
        return o.decode('utf8')

View File

@@ -0,0 +1,764 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from pyarrow._compute import ( # noqa
Function,
FunctionOptions,
FunctionRegistry,
HashAggregateFunction,
HashAggregateKernel,
Kernel,
ScalarAggregateFunction,
ScalarAggregateKernel,
ScalarFunction,
ScalarKernel,
VectorFunction,
VectorKernel,
# Option classes
ArraySortOptions,
AssumeTimezoneOptions,
CastOptions,
CountOptions,
CumulativeOptions,
CumulativeSumOptions,
DayOfWeekOptions,
DictionaryEncodeOptions,
RunEndEncodeOptions,
ElementWiseAggregateOptions,
ExtractRegexOptions,
ExtractRegexSpanOptions,
FilterOptions,
IndexOptions,
JoinOptions,
ListSliceOptions,
ListFlattenOptions,
MakeStructOptions,
MapLookupOptions,
MatchSubstringOptions,
ModeOptions,
NullOptions,
PadOptions,
PairwiseOptions,
PartitionNthOptions,
PivotWiderOptions,
QuantileOptions,
RandomOptions,
RankOptions,
RankQuantileOptions,
ReplaceSliceOptions,
ReplaceSubstringOptions,
RoundBinaryOptions,
RoundOptions,
RoundTemporalOptions,
RoundToMultipleOptions,
ScalarAggregateOptions,
SelectKOptions,
SetLookupOptions,
SkewOptions,
SliceOptions,
SortOptions,
SplitOptions,
SplitPatternOptions,
StrftimeOptions,
StrptimeOptions,
StructFieldOptions,
TakeOptions,
TDigestOptions,
TrimOptions,
Utf8NormalizeOptions,
VarianceOptions,
WeekOptions,
WinsorizeOptions,
ZeroFillOptions,
# Functions
call_function,
function_registry,
get_function,
list_functions,
# Udf
call_tabular_function,
register_scalar_function,
register_tabular_function,
register_aggregate_function,
register_vector_function,
UdfContext,
# Expressions
Expression,
)
from collections import namedtuple
import inspect
from textwrap import dedent
import warnings
import pyarrow as pa
from pyarrow import _compute_docstrings
from pyarrow.vendored import docscrape
def _get_arg_names(func):
return func._doc.arg_names
_OptionsClassDoc = namedtuple('_OptionsClassDoc', ('params',))
def _scrape_options_class_doc(options_class):
if not options_class.__doc__:
return None
doc = docscrape.NumpyDocString(options_class.__doc__)
return _OptionsClassDoc(doc['Parameters'])
def _decorate_compute_function(wrapper, exposed_name, func, options_class):
    """
    Decorate a compute function wrapper with metadata and documentation.

    Sets ``__arrow_compute_function__``, ``__name__``, ``__qualname__`` and a
    generated numpydoc-style ``__doc__`` on ``wrapper``.

    Parameters
    ----------
    wrapper : callable
        The Python wrapper to decorate (modified in place).
    exposed_name : str
        Name under which the wrapper is exposed.
    func : Function
        The compute function; its ``_doc``, ``name``, ``arity`` and ``kind``
        attributes are read.
    options_class : type or None
        Python options class for the function, if any.

    Returns
    -------
    callable
        The same ``wrapper`` object.
    """
    cpp_doc = func._doc

    wrapper.__arrow_compute_function__ = dict(
        name=func.name,
        arity=func.arity,
        options_class=cpp_doc.options_class,
        options_required=cpp_doc.options_required)
    wrapper.__name__ = exposed_name
    wrapper.__qualname__ = exposed_name

    doc_pieces = []

    # 1. One-line summary
    summary = cpp_doc.summary
    if not summary:
        # Variadic functions report their arity as Ellipsis, which cannot be
        # ordered against an int; treat them as taking several arguments.
        plural = func.arity is Ellipsis or func.arity > 1
        arg_str = "arguments" if plural else "argument"
        summary = f"Call compute function {func.name!r} with the given {arg_str}"

    doc_pieces.append(f"{summary}.\n\n")

    # 2. Multi-line description
    description = cpp_doc.description
    if description:
        doc_pieces.append(f"{description}\n\n")

    doc_addition = _compute_docstrings.function_doc_additions.get(func.name)

    # 3. Parameter description
    doc_pieces.append(dedent("""\
        Parameters
        ----------
        """))

    # 3a. Compute function parameters
    arg_names = _get_arg_names(func)
    for arg_name in arg_names:
        if func.kind in ('vector', 'scalar_aggregate'):
            arg_type = 'Array-like'
        else:
            arg_type = 'Array-like or scalar-like'
        doc_pieces.append(f"{arg_name} : {arg_type}\n")
        # numpydoc requires the description to be indented under its header
        doc_pieces.append("    Argument to compute function.\n")

    # 3b. Compute function option values
    if options_class is not None:
        options_class_doc = _scrape_options_class_doc(options_class)
        if options_class_doc:
            for p in options_class_doc.params:
                doc_pieces.append(f"{p.name} : {p.type}\n")
                for s in p.desc:
                    doc_pieces.append(f"    {s}\n")
        else:
            # No docstring to scrape: fall back to the constructor signature
            warnings.warn(f"Options class {options_class.__name__} "
                          f"does not have a docstring", RuntimeWarning)
            options_sig = inspect.signature(options_class)
            for p in options_sig.parameters.values():
                doc_pieces.append(dedent(f"""\
                {p.name} : optional
                    Parameter for {options_class.__name__} constructor. Either `options`
                    or `{p.name}` can be passed, but not both at the same time.
                """))
        doc_pieces.append(dedent(f"""\
            options : pyarrow.compute.{options_class.__name__}, optional
                Alternative way of passing options.
            """))

    doc_pieces.append(dedent("""\
        memory_pool : pyarrow.MemoryPool, optional
            If not passed, will allocate memory from the default memory pool.
        """))

    # 4. Custom addition (e.g. examples)
    if doc_addition is not None:
        stripped = dedent(doc_addition).strip('\n')
        doc_pieces.append(f"\n{stripped}\n")

    wrapper.__doc__ = "".join(doc_pieces)
    return wrapper
def _get_options_class(func):
class_name = func._doc.options_class
if not class_name:
return None
try:
return globals()[class_name]
except KeyError:
warnings.warn(f"Python binding for {class_name} not exposed",
RuntimeWarning)
return None
def _handle_options(name, options_class, options, args, kwargs):
if args or kwargs:
if options is not None:
raise TypeError(
f"Function {name!r} called with both an 'options' argument "
f"and additional arguments")
return options_class(*args, **kwargs)
if options is not None:
if isinstance(options, dict):
return options_class(**options)
elif isinstance(options, options_class):
return options
raise TypeError(
f"Function {name!r} expected a {options_class} parameter, "
f"got {type(options)}")
return None
def _make_generic_wrapper(func_name, func, options_class, arity):
if options_class is None:
def wrapper(*args, memory_pool=None):
if arity is not Ellipsis and len(args) != arity:
raise TypeError(
f"{func_name} takes {arity} positional argument(s), "
f"but {len(args)} were given"
)
if args and isinstance(args[0], Expression):
return Expression._call(func_name, list(args))
return func.call(args, None, memory_pool)
else:
def wrapper(*args, memory_pool=None, options=None, **kwargs):
if arity is not Ellipsis:
if len(args) < arity:
raise TypeError(
f"{func_name} takes {arity} positional argument(s), "
f"but {len(args)} were given"
)
option_args = args[arity:]
args = args[:arity]
else:
option_args = ()
options = _handle_options(func_name, options_class, options,
option_args, kwargs)
if args and isinstance(args[0], Expression):
return Expression._call(func_name, list(args), options)
return func.call(args, options, memory_pool)
return wrapper
def _make_signature(arg_names, var_arg_names, options_class):
from inspect import Parameter
params = []
for name in arg_names:
params.append(Parameter(name, Parameter.POSITIONAL_ONLY))
for name in var_arg_names:
params.append(Parameter(name, Parameter.VAR_POSITIONAL))
if options_class is not None:
options_sig = inspect.signature(options_class)
for p in options_sig.parameters.values():
assert p.kind in (Parameter.POSITIONAL_OR_KEYWORD,
Parameter.KEYWORD_ONLY)
if var_arg_names:
# Cannot have a positional argument after a *args
p = p.replace(kind=Parameter.KEYWORD_ONLY)
params.append(p)
params.append(Parameter("options", Parameter.KEYWORD_ONLY,
default=None))
params.append(Parameter("memory_pool", Parameter.KEYWORD_ONLY,
default=None))
return inspect.Signature(params)
def _wrap_function(name, func):
    """
    Create the fully decorated Python wrapper for a compute function.

    A trailing argument name starting with ``*`` is treated as the
    variadic argument and is removed from the list of fixed argument names.
    """
    options_class = _get_options_class(func)
    arg_names = _get_arg_names(func)

    var_arg_names = []
    if arg_names and arg_names[-1].startswith('*'):
        var_arg_names.append(arg_names.pop().lstrip('*'))

    wrapper = _make_generic_wrapper(
        name, func, options_class, arity=func.arity)
    wrapper.__signature__ = _make_signature(arg_names, var_arg_names,
                                            options_class)
    return _decorate_compute_function(wrapper, name, func, options_class)
def _make_global_functions():
    """
    Make global functions wrapping each compute function.

    Every callable function in the registry is wrapped and stored in the
    module globals under both its registry name and its Python-safe name.
    Note that some of the automatically-generated wrappers may be
    overridden by custom versions below.
    """
    module_ns = globals()
    registry = function_registry()

    # Avoid clashes with Python keywords
    keyword_renames = {'and': 'and_',
                       'or': 'or_'}

    for cpp_name in registry.list_functions():
        func = registry.get_function(cpp_name)
        kind = func.kind
        # Hash aggregate functions and nullary scalar aggregate functions
        # are not callable directly, so they are not exposed at module level.
        if kind == "hash_aggregate" or (
                kind == "scalar_aggregate" and func.arity == 0):
            continue
        name = keyword_renames.get(cpp_name, cpp_name)
        assert name not in module_ns, name
        module_ns[cpp_name] = module_ns[name] = _wrap_function(name, func)
# Populate the module namespace with the generated wrappers at import time.
_make_global_functions()

# Alias for consistency; globals() is needed to avoid Python lint errors
# (utf8_zero_fill is created dynamically by _make_global_functions above).
utf8_zfill = utf8_zero_fill = globals()["utf8_zero_fill"]
def cast(arr, target_type=None, safe=None, options=None, memory_pool=None):
    """
    Cast array values to another data type. Can also be invoked as an array
    instance method.

    Either ``target_type`` (optionally with ``safe``) or ``options`` may be
    given, but not both.

    Parameters
    ----------
    arr : Array-like
    target_type : DataType or str
        Type to cast to
    safe : bool, default True
        Check for overflows or other unsafe conversions
    options : CastOptions, default None
        Additional checks passed via CastOptions
    memory_pool : MemoryPool, optional
        memory pool to use for allocations during function execution.

    Examples
    --------
    >>> from datetime import datetime
    >>> import pyarrow as pa
    >>> arr = pa.array([datetime(2010, 1, 1), datetime(2015, 1, 1)])
    >>> arr.type
    TimestampType(timestamp[us])

    You can use ``pyarrow.DataType`` objects to specify the target type:

    >>> cast(arr, pa.timestamp('ms'))
    <pyarrow.lib.TimestampArray object at ...>
    [
      2010-01-01 00:00:00.000,
      2015-01-01 00:00:00.000
    ]

    >>> cast(arr, pa.timestamp('ms')).type
    TimestampType(timestamp[ms])

    Alternatively, it is also supported to use the string aliases for these
    types:

    >>> arr.cast('timestamp[ms]')
    <pyarrow.lib.TimestampArray object at ...>
    [
      2010-01-01 00:00:00.000,
      2015-01-01 00:00:00.000
    ]
    >>> arr.cast('timestamp[ms]').type
    TimestampType(timestamp[ms])

    Returns
    -------
    casted : Array
        The cast result as a new Array
    """
    explicit_args_given = (target_type is not None) or (safe is not None)

    if options is not None:
        if explicit_args_given:
            raise ValueError("Must either pass values for 'target_type' and 'safe'"
                             " or pass a value for 'options'")
    else:
        target_type = pa.types.lib.ensure_type(target_type)
        # Only an explicit ``safe=False`` selects the unsafe conversion
        make_options = CastOptions.unsafe if safe is False else CastOptions.safe
        options = make_options(target_type)

    return call_function("cast", [arr], options, memory_pool)
def index(data, value, start=None, end=None, *, memory_pool=None):
    """
    Find the index of the first occurrence of a given value.

    When ``start`` and/or ``end`` are given, only that slice of ``data`` is
    searched; a found position is reported relative to the full input.

    Parameters
    ----------
    data : Array-like
    value : Scalar-like object
        The value to search for.
    start : int, optional
    end : int, optional
    memory_pool : MemoryPool, optional
        If not passed, will allocate memory from the default memory pool.

    Returns
    -------
    index : int
        the index, or -1 if not found

    Examples
    --------
    >>> import pyarrow as pa
    >>> import pyarrow.compute as pc
    >>> arr = pa.array(["Lorem", "ipsum", "dolor", "sit", "Lorem", "ipsum"])
    >>> pc.index(arr, "ipsum")
    <pyarrow.Int64Scalar: 1>
    >>> pc.index(arr, "ipsum", start=2)
    <pyarrow.Int64Scalar: 5>
    >>> pc.index(arr, "amet")
    <pyarrow.Int64Scalar: -1>
    """
    # Restrict the search window
    if start is None:
        if end is not None:
            data = data.slice(0, end)
    elif end is None:
        data = data.slice(start)
    else:
        data = data.slice(start, end - start)

    # Make sure the needle is a scalar of the same type as the data
    if isinstance(value, pa.Scalar):
        if data.type != value.type:
            value = pa.scalar(value.as_py(), type=data.type)
    else:
        value = pa.scalar(value, type=data.type)

    found = call_function('index', [data], IndexOptions(value=value),
                          memory_pool)
    # The kernel saw the sliced data: shift a hit back by the slice offset
    if start is not None and found.as_py() >= 0:
        found = pa.scalar(found.as_py() + start, type=pa.int64())
    return found
def take(data, indices, *, boundscheck=True, memory_pool=None):
    """
    Select values (or records) from array- or table-like data given integer
    selection indices.

    The result will be of the same type(s) as the input, with elements taken
    from the input array (or record batch / table fields) at the given
    indices. If an index is null then the corresponding value in the output
    will be null.

    Parameters
    ----------
    data : Array, ChunkedArray, RecordBatch, or Table
    indices : Array, ChunkedArray
        Must be of integer type
    boundscheck : boolean, default True
        Whether to boundscheck the indices. If False and there is an out of
        bounds index, will likely cause the process to crash.
    memory_pool : MemoryPool, optional
        If not passed, will allocate memory from the default memory pool.

    Returns
    -------
    result : depends on inputs
        Selected values for the given indices

    Examples
    --------
    >>> import pyarrow as pa
    >>> arr = pa.array(["a", "b", "c", None, "e", "f"])
    >>> indices = pa.array([0, None, 4, 3])
    >>> arr.take(indices)
    <pyarrow.lib.StringArray object at ...>
    [
      "a",
      null,
      "e",
      null
    ]
    """
    return call_function('take', [data, indices],
                         TakeOptions(boundscheck=boundscheck), memory_pool)
def fill_null(values, fill_value):
    """Replace each null element in values with a corresponding
    element from fill_value.

    If fill_value is scalar-like, then every null element in values
    will be replaced with fill_value. If fill_value is array-like,
    then the i-th element in values will be replaced with the i-th
    element in fill_value.

    The fill_value's type must be the same as that of values, or it
    must be able to be implicitly casted to the array's type.

    This is an alias for :func:`coalesce`.

    Parameters
    ----------
    values : Array, ChunkedArray, or Scalar-like object
        Each null element is replaced with the corresponding value
        from fill_value.
    fill_value : Array, ChunkedArray, or Scalar-like object
        If not same type as values, will attempt to cast.

    Returns
    -------
    result : depends on inputs
        Values with all null elements replaced

    Examples
    --------
    >>> import pyarrow as pa
    >>> arr = pa.array([1, 2, None, 3], type=pa.int8())
    >>> fill_value = pa.scalar(5, type=pa.int8())
    >>> arr.fill_null(fill_value)
    <pyarrow.lib.Int8Array object at ...>
    [
      1,
      2,
      5,
      3
    ]
    >>> arr = pa.array([1, 2, None, 4, None])
    >>> arr.fill_null(pa.array([10, 20, 30, 40, 50]))
    <pyarrow.lib.Int64Array object at ...>
    [
      1,
      2,
      30,
      4,
      50
    ]
    """
    if not isinstance(fill_value, (pa.Array, pa.ChunkedArray, pa.Scalar)):
        # Plain Python value: build a scalar of the target type directly
        fill_value = pa.scalar(fill_value, type=values.type)
    elif values.type != fill_value.type:
        if isinstance(fill_value, pa.Scalar):
            fill_value = pa.scalar(fill_value.as_py(), type=values.type)
        else:
            # Arrays and chunked arrays do not provide ``as_py``; convert
            # them to the target type with a regular cast instead.
            fill_value = fill_value.cast(values.type)

    return call_function("coalesce", [values, fill_value])
def top_k_unstable(values, k, sort_keys=None, *, memory_pool=None):
    """
    Select the indices of the top-k ordered elements from array- or table-like
    data.

    This is a specialization for :func:`select_k_unstable`. Output is not
    guaranteed to be stable.

    Parameters
    ----------
    values : Array, ChunkedArray, RecordBatch, or Table
        Data to sort and get top indices from.
    k : int
        The number of `k` elements to keep.
    sort_keys : List-like
        Column key names to order by when input is table-like data.
        The passed object is never modified.
    memory_pool : MemoryPool, optional
        If not passed, will allocate memory from the default memory pool.

    Returns
    -------
    result : Array
        Indices of the top-k ordered elements

    Examples
    --------
    >>> import pyarrow as pa
    >>> import pyarrow.compute as pc
    >>> arr = pa.array(["a", "b", "c", None, "e", "f"])
    >>> pc.top_k_unstable(arr, k=3)
    <pyarrow.lib.UInt64Array object at ...>
    [
      5,
      4,
      2
    ]
    """
    # Work on a private copy so the caller's sequence is left untouched
    # (and so any iterable, not just a list, is accepted).
    keys = [] if sort_keys is None else list(sort_keys)
    if isinstance(values, (pa.Array, pa.ChunkedArray)):
        # Array-like input has no column names; a placeholder key is used
        keys.append(("dummy", "descending"))
    else:
        keys = [(key_name, "descending") for key_name in keys]
    options = SelectKOptions(k, keys)
    return call_function("select_k_unstable", [values], options, memory_pool)
def bottom_k_unstable(values, k, sort_keys=None, *, memory_pool=None):
    """
    Select the indices of the bottom-k ordered elements from
    array- or table-like data.

    This is a specialization for :func:`select_k_unstable`. Output is not
    guaranteed to be stable.

    Parameters
    ----------
    values : Array, ChunkedArray, RecordBatch, or Table
        Data to sort and get bottom indices from.
    k : int
        The number of `k` elements to keep.
    sort_keys : List-like
        Column key names to order by when input is table-like data.
        The passed object is never modified.
    memory_pool : MemoryPool, optional
        If not passed, will allocate memory from the default memory pool.

    Returns
    -------
    result : Array of indices
        Indices of the bottom-k ordered elements

    Examples
    --------
    >>> import pyarrow as pa
    >>> import pyarrow.compute as pc
    >>> arr = pa.array(["a", "b", "c", None, "e", "f"])
    >>> pc.bottom_k_unstable(arr, k=3)
    <pyarrow.lib.UInt64Array object at ...>
    [
      0,
      1,
      2
    ]
    """
    # Work on a private copy so the caller's sequence is left untouched
    # (and so any iterable, not just a list, is accepted).
    keys = [] if sort_keys is None else list(sort_keys)
    if isinstance(values, (pa.Array, pa.ChunkedArray)):
        # Array-like input has no column names; a placeholder key is used
        keys.append(("dummy", "ascending"))
    else:
        keys = [(key_name, "ascending") for key_name in keys]
    options = SelectKOptions(k, keys)
    return call_function("select_k_unstable", [values], options, memory_pool)
def random(n, *, initializer='system', options=None, memory_pool=None):
    """
    Generate numbers in the range [0, 1).

    Generated values are uniformly-distributed, double-precision
    in range [0, 1). Algorithm and seed can be changed via RandomOptions.

    Parameters
    ----------
    n : int
        Number of values to generate, must be greater than or equal to 0
    initializer : int or str
        How to initialize the underlying random generator.
        If an integer is given, it is used as a seed.
        If "system" is given, the random generator is initialized with
        a system-specific source of (hopefully true) randomness.
        Other values are invalid.
        Ignored when `options` is passed.
    options : pyarrow.compute.RandomOptions, optional
        Alternative way of passing options.
    memory_pool : pyarrow.MemoryPool, optional
        If not passed, will allocate memory from the default memory pool.

    Returns
    -------
    The result of the "random" compute function with length `n`.
    """
    # Honor explicitly passed options; only build them from `initializer`
    # when the caller did not supply any.
    if options is None:
        options = RandomOptions(initializer=initializer)
    return call_function("random", [], options, memory_pool, length=n)
def field(*name_or_index):
    """Reference a column of the dataset.

    Stores only the field's name. Type and other information is known only when
    the expression is bound to a dataset having an explicit scheme.

    Nested references are allowed by passing multiple names or a tuple of
    names. For example ``('foo', 'bar')`` references the field named "bar"
    inside the field named "foo".

    Parameters
    ----------
    *name_or_index : string, multiple strings, tuple or int
        The name or index of the (possibly nested) field the expression
        references to.

    Returns
    -------
    field_expr : Expression
        Reference to the given field

    Examples
    --------
    >>> import pyarrow.compute as pc
    >>> pc.field("a")
    <pyarrow.compute.Expression a>
    >>> pc.field(1)
    <pyarrow.compute.Expression FieldPath(1)>
    >>> pc.field(("a", "b"))
    <pyarrow.compute.Expression FieldRef.Nested(FieldRef.Name(a) ...
    >>> pc.field("a", "b")
    <pyarrow.compute.Expression FieldRef.Nested(FieldRef.Name(a) ...
    """
    # Zero or several arguments: treat them as one nested reference
    # (covers multiple strings not supplied in a tuple).
    if len(name_or_index) != 1:
        return Expression._nested_field(name_or_index)

    (ref,) = name_or_index
    if isinstance(ref, (str, int)):
        return Expression._field(ref)
    if isinstance(ref, tuple):
        return Expression._nested_field(ref)
    raise TypeError(
        "field reference should be str, multiple str, tuple or "
        f"integer, got {type(ref)}"
    )
def scalar(value):
    """Expression representing a scalar value.

    Creates an Expression object representing a scalar value that can be used
    in compute expressions and predicates.

    Parameters
    ----------
    value : bool, int, float or string
        Python value of the scalar. This function accepts any value that can be
        converted to a ``pyarrow.Scalar`` using ``pa.scalar()``.

    Notes
    -----
    This function differs from ``pyarrow.scalar()`` in the following way:

    * ``pyarrow.scalar()`` creates a ``pyarrow.Scalar`` object that represents
      a single value in Arrow's memory model.
    * ``pyarrow.compute.scalar()`` creates an ``Expression`` object representing
      a scalar value that can be used in compute expressions, predicates, and
      dataset filtering operations.

    Returns
    -------
    scalar_expr : Expression
        An Expression representing the scalar value
    """
    scalar_expr = Expression._scalar(value)
    return scalar_expr

View File

@@ -0,0 +1,95 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from pyarrow.includes.libarrow cimport GetBuildInfo
from collections import namedtuple
import os
# (major, minor, patch) components of a version number.
VersionInfo = namedtuple('VersionInfo', ('major', 'minor', 'patch'))

# Build-time information about the Arrow C++ library, as filled in by
# _build_info() below.
BuildInfo = namedtuple(
    'BuildInfo',
    ('version', 'version_info', 'so_version', 'full_so_version',
     'compiler_id', 'compiler_version', 'compiler_flags',
     'git_id', 'git_description', 'package_kind', 'build_type'))

# Runtime information returned by runtime_info().
RuntimeInfo = namedtuple('RuntimeInfo',
                         ('simd_level', 'detected_simd_level'))
# Build a BuildInfo namedtuple from the C++ structure returned by
# GetBuildInfo(): string fields are decoded with frombytes(), the version
# numbers are wrapped in a VersionInfo, and build_type is lower-cased.
cdef _build_info():
    cdef:
        const CBuildInfo* c_info

    # Keep a pointer to the structure returned by reference
    c_info = &GetBuildInfo()

    return BuildInfo(version=frombytes(c_info.version_string),
                     version_info=VersionInfo(c_info.version_major,
                                              c_info.version_minor,
                                              c_info.version_patch),
                     so_version=frombytes(c_info.so_version),
                     full_so_version=frombytes(c_info.full_so_version),
                     compiler_id=frombytes(c_info.compiler_id),
                     compiler_version=frombytes(c_info.compiler_version),
                     compiler_flags=frombytes(c_info.compiler_flags),
                     git_id=frombytes(c_info.git_id),
                     git_description=frombytes(c_info.git_description),
                     package_kind=frombytes(c_info.package_kind),
                     build_type=frombytes(c_info.build_type).lower(),
                     )
# Build information is computed once, when this module is initialized,
# and exposed as module-level values.
cpp_build_info = _build_info()
cpp_version = cpp_build_info.version
cpp_version_info = cpp_build_info.version_info
def runtime_info():
    """
    Get runtime information.

    Returns
    -------
    info : pyarrow.RuntimeInfo
        Named tuple with the fields ``simd_level`` and
        ``detected_simd_level``, both decoded to ``str``.
    """
    cdef:
        CRuntimeInfo c_info

    c_info = GetRuntimeInfo()

    return RuntimeInfo(
        simd_level=frombytes(c_info.simd_level),
        detected_simd_level=frombytes(c_info.detected_simd_level))
def set_timezone_db_path(path):
    """
    Configure the path to text timezone database on Windows.

    Parameters
    ----------
    path : str
        Path to text timezone database. If None, no path is set on the
        options, but the options are still applied.
    """
    cdef:
        CGlobalOptions options

    # Only override the database path when one was actually given
    if path is not None:
        options.timezone_db_path = <c_string>tobytes(path)

    # NOTE(review): check_status presumably raises if Initialize reports
    # a failure -- confirm against its definition.
    check_status(Initialize(options))

View File

@@ -0,0 +1,386 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import os
import pyarrow as pa
from pyarrow import Codec
from pyarrow import fs
from pyarrow.lib import is_threading_enabled
from pyarrow.tests.util import windows_has_tzdata
import sys
# Names of the optional test groups handled by this configuration.
groups = [
    'acero',
    'azure',
    'brotli',
    'bz2',
    'cython',
    'dataset',
    'hypothesis',
    'fastparquet',
    'flight',
    'gandiva',
    'gcs',
    'gdb',
    'gzip',
    'hdfs',
    'large_memory',
    'lz4',
    'memory_leak',
    'nopandas',
    'nonumpy',
    'numpy',
    'orc',
    'pandas',
    'parquet',
    'parquet_encryption',
    'processes',
    'requires_testing_data',
    's3',
    'slow',
    'snappy',
    'sockets',
    'substrait',
    'threading',
    'timezone_data',
    'zstd',
]

# Default enabled/disabled state of every group. Compression codecs and
# threading are probed right here; most other entries start out False and
# are switched on further below when the corresponding import succeeds.
defaults = {
    'acero': False,
    'azure': False,
    'brotli': Codec.is_available('brotli'),
    'bz2': Codec.is_available('bz2'),
    'cython': False,
    'dataset': False,
    'fastparquet': False,
    'flight': False,
    'gandiva': False,
    'gcs': False,
    'gdb': True,
    'gzip': Codec.is_available('gzip'),
    'hdfs': False,
    'hypothesis': False,
    'large_memory': False,
    'lz4': Codec.is_available('lz4'),
    'memory_leak': False,
    'nopandas': False,
    'nonumpy': False,
    'numpy': False,
    'orc': False,
    'pandas': False,
    'parquet': False,
    'parquet_encryption': False,
    'processes': True,
    'requires_testing_data': True,
    's3': False,
    'slow': False,
    'snappy': Codec.is_available('snappy'),
    'sockets': True,
    'substrait': False,
    'threading': is_threading_enabled(),
    'timezone_data': True,
    'zstd': Codec.is_available('zstd'),
}

if sys.platform == "emscripten":
    # Emscripten doesn't support subprocess,
    # multiprocessing, gdb or socket based
    # networking
    defaults['gdb'] = False
    defaults['processes'] = False
    defaults['sockets'] = False

# Timezone data is assumed present except on Windows and Emscripten,
# where its availability is checked explicitly.
if sys.platform == "win32":
    defaults['timezone_data'] = windows_has_tzdata()
elif sys.platform == "emscripten":
    defaults['timezone_data'] = os.path.exists("/usr/share/zoneinfo")
# Probe optional dependencies and optional pyarrow components: a group is
# enabled by default only if the corresponding import succeeds.
try:
    import cython  # noqa
    defaults['cython'] = True
except ImportError:
    pass

try:
    import fastparquet  # noqa
    defaults['fastparquet'] = True
except ImportError:
    pass

try:
    import pyarrow.gandiva  # noqa
    defaults['gandiva'] = True
except ImportError:
    pass

try:
    import pyarrow.acero  # noqa
    defaults['acero'] = True
except ImportError:
    pass

try:
    import pyarrow.dataset  # noqa
    defaults['dataset'] = True
except ImportError:
    pass

try:
    import pyarrow.orc  # noqa
    if sys.platform == "win32":
        defaults['orc'] = True
    else:
        # orc tests on non-Windows platforms only work
        # if timezone data exists, so skip them if
        # not.
        defaults['orc'] = defaults['timezone_data']
except ImportError:
    pass

# pandas and numpy each have a complementary "no..." group that is enabled
# when the package is missing.
try:
    import pandas  # noqa
    defaults['pandas'] = True
except ImportError:
    defaults['nopandas'] = True

try:
    import numpy  # noqa
    defaults['numpy'] = True
except ImportError:
    defaults['nonumpy'] = True

try:
    import pyarrow.parquet  # noqa
    defaults['parquet'] = True
except ImportError:
    pass

try:
    import pyarrow.parquet.encryption  # noqa
    defaults['parquet_encryption'] = True
except ImportError:
    pass

try:
    import pyarrow.flight  # noqa
    defaults['flight'] = True
except ImportError:
    pass

# Filesystem implementations are probed by importing their class
# from pyarrow.fs.
try:
    from pyarrow.fs import AzureFileSystem  # noqa
    defaults['azure'] = True
except ImportError:
    pass

try:
    from pyarrow.fs import GcsFileSystem  # noqa
    defaults['gcs'] = True
except ImportError:
    pass

try:
    from pyarrow.fs import S3FileSystem  # noqa
    defaults['s3'] = True
except ImportError:
    pass

try:
    from pyarrow.fs import HadoopFileSystem  # noqa
    defaults['hdfs'] = True
except ImportError:
    pass

try:
    import pyarrow.substrait  # noqa
    defaults['substrait'] = True
except ImportError:
    pass
# Doctest should ignore files for the modules that are not built
def pytest_ignore_collect(collection_path, config):
    """
    Decide whether a path is skipped during collection.

    Returns True to ignore ``collection_path`` and False otherwise. Paths are
    only ever ignored when doctest collection is active (the
    ``doctestmodules`` option or the optional ``doctest_cython`` option).
    """
    if config.option.doctestmodules:
        # don't try to run doctests on the /tests directory
        if "/pyarrow/tests/" in str(collection_path):
            return True

        doctest_groups = [
            'dataset',
            'orc',
            'parquet',
            'flight',
            'substrait',
        ]

        # handle cuda, flight, etc
        for group in doctest_groups:
            if f'pyarrow/{group}' in str(collection_path):
                if not defaults[group]:
                    return True

        if 'pyarrow/parquet/encryption' in str(collection_path):
            if not defaults['parquet_encryption']:
                return True

        # cuda has no entry in `defaults`, so it is probed by importing it
        if 'pyarrow/cuda' in str(collection_path):
            try:
                import pyarrow.cuda  # noqa
                return False
            except ImportError:
                return True

        # paths under pyarrow/fs are only collected when S3FileSystem
        # can be imported
        if 'pyarrow/fs' in str(collection_path):
            try:
                from pyarrow.fs import S3FileSystem  # noqa
                return False
            except ImportError:
                return True

    # The doctest_cython option may not exist on config.option, hence getattr
    if getattr(config.option, "doctest_cython", False):
        if "/pyarrow/tests/" in str(collection_path):
            return True
        if "/pyarrow/_parquet_encryption" in str(collection_path):
            return True

    return False
# Save output files from doctest examples into temp dir
@pytest.fixture(autouse=True)
def _docdir(request):
    """Run doctests with a temporary directory as the working directory."""
    option = request.config.option
    doctest_m = option.doctestmodules
    doctest_c = getattr(option, "doctest_cython", False)

    # Regular (non-doctest) tests run in the unchanged working directory
    if not (doctest_m or doctest_c):
        yield
        return

    # Get the fixture dynamically by its name, so that it is only
    # requested when doctests are actually being run.
    tmpdir = request.getfixturevalue('tmpdir')

    # Chdir only for the duration of the test.
    with tmpdir.as_cwd():
        yield
# Define doctest_namespace for fs module docstring import
@pytest.fixture(autouse=True)
def add_fs(doctest_namespace, request, tmp_path):
    """
    Provide ``fs``, ``local``, ``local_path`` and ``path`` to doctests.

    When doctests are being run, a small example file is written into
    ``tmp_path`` and the names are added to the doctest namespace.
    """
    option = request.config.option
    doctest_m = option.doctestmodules
    doctest_c = getattr(option, "doctest_cython", False)

    # Trigger ONLY for the doctests
    if doctest_m or doctest_c:
        # fs import
        doctest_namespace["fs"] = fs

        # Creation of an object and file with data
        local = fs.LocalFileSystem()
        path = tmp_path / 'pyarrow-fs-example.dat'
        with local.open_output_stream(str(path)) as stream:
            stream.write(b'data')

        for key, obj in (("local", local),
                         ("local_path", str(tmp_path)),
                         ("path", str(path))):
            doctest_namespace[key] = obj
    yield
# Define udf fixture for test_udf.py and test_substrait.py
@pytest.fixture(scope="session")
def unary_func_fixture():
    """
    Register a unary scalar function.

    The function is registered as "y=x+1" and adds 1 to its int64 input.
    Returns the Python callable and the registered name.
    """
    from pyarrow import compute as pc

    func_name = "y=x+1"

    def unary_function(ctx, x):
        return pc.call_function("add", [x, 1],
                                memory_pool=ctx.memory_pool)

    pc.register_scalar_function(
        unary_function,
        func_name,
        {"summary": "add function",
         "description": "test add function"},
        {"array": pa.int64()},
        pa.int64())
    return unary_function, func_name
@pytest.fixture(scope="session")
def unary_agg_func_fixture():
    """
    Register a unary aggregate function (mean)

    The function is registered as "mean_udf"; it takes one float64 input
    and returns its NaN-ignoring mean as a float64 scalar.
    Returns the Python callable and the registered name.
    """
    from pyarrow import compute as pc
    import numpy as np

    func_name = "mean_udf"

    def func(ctx, x):
        return pa.scalar(np.nanmean(x))

    pc.register_aggregate_function(
        func,
        func_name,
        {"summary": "y=avg(x)",
         "description": "find mean of x"},
        {"x": pa.float64()},
        pa.float64())
    return func, func_name
@pytest.fixture(scope="session")
def varargs_agg_func_fixture():
    """
    Register a varargs aggregate function.

    The function is registered as "sum_mean" with an int64 input "x" and a
    float64 input "y"; it returns the sum of the NaN-ignoring means of all
    its inputs as a float64 scalar.
    Returns the Python callable and the registered name.
    """
    from pyarrow import compute as pc
    import numpy as np

    def func(ctx, *args):
        # Named `total` so the builtin `sum` is not shadowed
        total = 0.0
        for arg in args:
            total += np.nanmean(arg)
        return pa.scalar(total)

    func_name = "sum_mean"
    func_doc = {"summary": "Varargs aggregate",
                "description": "Varargs aggregate"}
    pc.register_aggregate_function(func,
                                   func_name,
                                   func_doc,
                                   {
                                       "x": pa.int64(),
                                       "y": pa.float64()
                                   },
                                   pa.float64()
                                   )
    return func, func_name

View File

@@ -0,0 +1,22 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from pyarrow._csv import ( # noqa
ReadOptions, ParseOptions, ConvertOptions, ISO8601,
open_csv, read_csv, CSVStreamingReader, write_csv,
WriteOptions, CSVWriter, InvalidRow)

View File

@@ -0,0 +1,25 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# flake8: noqa
from pyarrow._cuda import (Context, IpcMemHandle, CudaBuffer,
HostBuffer, BufferReader, BufferWriter,
new_host_buffer,
serialize_record_batch, read_message,
read_record_batch)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,167 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# cython: profile=False
# distutils: language = c++
# cython: embedsignature = True
# Python-visible enumeration of device allocation types. Each member takes the
# value of the matching C-level CDeviceAllocationType_k* constant, cast to char.
cpdef enum DeviceAllocationType:
    CPU = <char> CDeviceAllocationType_kCPU
    CUDA = <char> CDeviceAllocationType_kCUDA
    CUDA_HOST = <char> CDeviceAllocationType_kCUDA_HOST
    OPENCL = <char> CDeviceAllocationType_kOPENCL
    VULKAN = <char> CDeviceAllocationType_kVULKAN
    METAL = <char> CDeviceAllocationType_kMETAL
    VPI = <char> CDeviceAllocationType_kVPI
    ROCM = <char> CDeviceAllocationType_kROCM
    ROCM_HOST = <char> CDeviceAllocationType_kROCM_HOST
    EXT_DEV = <char> CDeviceAllocationType_kEXT_DEV
    CUDA_MANAGED = <char> CDeviceAllocationType_kCUDA_MANAGED
    ONEAPI = <char> CDeviceAllocationType_kONEAPI
    WEBGPU = <char> CDeviceAllocationType_kWEBGPU
    HEXAGON = <char> CDeviceAllocationType_kHEXAGON
# Convert a C-level CDeviceAllocationType value into the matching
# DeviceAllocationType member by looking it up through its char value.
cdef object _wrap_device_allocation_type(CDeviceAllocationType device_type):
    return DeviceAllocationType(<char> device_type)
cdef class Device(_Weakrefable):
    """
    Abstract interface for hardware devices

    This object represents a device with access to some memory spaces.
    When handling a Buffer or raw memory address, it allows deciding in which
    context the raw memory address should be interpreted
    (e.g. CPU-accessible memory, or embedded memory on some particular GPU).
    """

    # Direct construction is rejected; instances are built through wrap().
    def __init__(self):
        raise TypeError("Do not call Device's constructor directly, "
                        "use the device attribute of the MemoryManager instead.")

    # Store the C++ device handle on this wrapper.
    cdef void init(self, const shared_ptr[CDevice]& device):
        self.device = device

    # Build a Device around an existing handle. __new__ is used so that the
    # raising __init__ above is not invoked.
    @staticmethod
    cdef wrap(const shared_ptr[CDevice]& device):
        cdef Device self = Device.__new__(Device)
        self.init(device)
        return self

    # Return the stored C++ handle; callable without the GIL.
    cdef inline shared_ptr[CDevice] unwrap(self) nogil:
        return self.device

    # Comparison with a non-Device object yields False (not NotImplemented);
    # otherwise equality is delegated to the C++ Equals() method.
    def __eq__(self, other):
        if not isinstance(other, Device):
            return False
        return self.device.get().Equals(deref((<Device>other).device.get()))

    def __repr__(self):
        return f"<pyarrow.Device: {frombytes(self.device.get().ToString())}>"

    @property
    def type_name(self):
        """
        A shorthand for this device's type.
        """
        return frombytes(self.device.get().type_name())

    @property
    def device_id(self):
        """
        A device ID to identify this device if there are multiple of this type.

        If there is no "device_id" equivalent (such as for the main CPU device on
        non-numa systems) returns -1.
        """
        return self.device.get().device_id()

    @property
    def is_cpu(self):
        """
        Whether this device is the main CPU device.

        This shorthand method is very useful when deciding whether a memory address
        is CPU-accessible.
        """
        return self.device.get().is_cpu()

    @property
    def device_type(self):
        """
        Return the DeviceAllocationType of this device.
        """
        return _wrap_device_allocation_type(self.device.get().device_type())
cdef class MemoryManager(_Weakrefable):
    """
    An object that provides memory management primitives.

    A MemoryManager is always tied to a particular Device instance.
    It can also have additional parameters (such as a MemoryPool to
    allocate CPU memory).

    """

    # Direct construction is rejected; instances are built through wrap().
    def __init__(self):
        raise TypeError("Do not call MemoryManager's constructor directly, "
                        "use pyarrow.default_cpu_memory_manager() instead.")

    # Store the C++ memory manager handle on this wrapper.
    cdef void init(self, const shared_ptr[CMemoryManager]& mm):
        self.memory_manager = mm

    # Build a MemoryManager around an existing handle. __new__ is used so
    # that the raising __init__ above is not invoked.
    @staticmethod
    cdef wrap(const shared_ptr[CMemoryManager]& mm):
        cdef MemoryManager self = MemoryManager.__new__(MemoryManager)
        self.init(mm)
        return self

    # Return the stored C++ handle; callable without the GIL.
    cdef inline shared_ptr[CMemoryManager] unwrap(self) nogil:
        return self.memory_manager

    def __repr__(self):
        # The repr shows the string form of the device this manager is tied to.
        device_str = frombytes(self.memory_manager.get().device().get().ToString())
        return f"<pyarrow.MemoryManager device: {device_str}>"

    @property
    def device(self):
        """
        The device this MemoryManager is tied to.
        """
        return Device.wrap(self.memory_manager.get().device())

    @property
    def is_cpu(self):
        """
        Whether this MemoryManager is tied to the main CPU device.

        This shorthand method is very useful when deciding whether a memory
        address is CPU-accessible.
        """
        return self.memory_manager.get().is_cpu()
def default_cpu_memory_manager():
    """
    Return the default CPU MemoryManager instance.

    The returned singleton instance uses the default MemoryPool.
    """
    # Wrap the handle returned by c_default_cpu_memory_manager() in a Python
    # MemoryManager object.
    return MemoryManager.wrap(c_default_cpu_memory_manager())

View File

@@ -0,0 +1,274 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from cpython.exc cimport PyErr_CheckSignals, PyErr_SetInterrupt
from pyarrow.includes.libarrow cimport CStatus
from pyarrow.includes.libarrow_python cimport IsPyError, RestorePyError
from pyarrow.includes.common cimport c_string
from contextlib import contextmanager
import os
import signal
import threading
from pyarrow.lib import is_threading_enabled
from pyarrow.util import _break_traceback_cycle_from_frame
# Common base class of the Arrow-specific exceptions defined below; also the
# fallback exception type returned by convert_status().
class ArrowException(Exception):
    pass
# Returned by convert_status() for statuses where IsInvalid() is true.
# Also a ValueError, so generic ValueError handlers catch it.
class ArrowInvalid(ValueError, ArrowException):
    pass
# Returned by convert_status() for statuses where IsOutOfMemory() is true.
# Also a MemoryError.
class ArrowMemoryError(MemoryError, ArrowException):
    pass
# Returned by convert_status() for statuses where IsKeyError() is true.
# Also a KeyError, but formatted like a regular ArrowException.
class ArrowKeyError(KeyError, ArrowException):
    def __str__(self):
        # Override KeyError.__str__, as it uses the repr() of the key
        return ArrowException.__str__(self)
# Returned by convert_status() for statuses where IsTypeError() is true.
# Also a TypeError.
class ArrowTypeError(TypeError, ArrowException):
    pass
# Returned by convert_status() for statuses where IsNotImplemented() is true.
# Also a NotImplementedError.
class ArrowNotImplementedError(NotImplementedError, ArrowException):
    pass
# Returned by convert_status() for statuses where IsCapacityError() is true.
class ArrowCapacityError(ArrowException):
    pass
# Returned by convert_status() for statuses where IsIndexError() is true.
# Also an IndexError.
class ArrowIndexError(IndexError, ArrowException):
    pass
# Returned by convert_status() for statuses where IsSerializationError()
# is true.
class ArrowSerializationError(ArrowException):
    pass
# Returned by convert_status() for statuses where IsCancelled() is true.
# `signum` holds the signal number attached to the cancellation, or None
# when no signal number was supplied.
class ArrowCancelled(ArrowException):
    def __init__(self, message, signum=None):
        super().__init__(message)
        self.signum = signum
# Compatibility alias: ArrowIOError is the built-in IOError itself, not a
# subclass of ArrowException.
ArrowIOError = IOError
# check_status() and convert_status() could be written directly in C++
# if we didn't define Arrow-specific subclasses (ArrowInvalid etc.)
#
# check_status() returns 0 for an OK status. For any other status it takes
# the GIL and either restores the Python error carried by the status
# (returning -1) or raises the exception built by convert_status().
cdef int check_status(const CStatus& status) except -1 nogil:
    if status.ok():
        return 0

    with gil:
        if IsPyError(status):
            # The status wraps a Python exception: re-install it and signal
            # the error through the -1 return value.
            RestorePyError(status)
            return -1

        raise convert_status(status)
# Build (but do not raise) the Python exception object that corresponds to a
# non-OK status. The exception class is chosen from the status predicates
# tested below, in order.
cdef object convert_status(const CStatus& status):
    if IsPyError(status):
        # RestorePyError() is expected to raise the Python exception carried
        # by the status; it is caught here and handed back as a value.
        try:
            RestorePyError(status)
        except BaseException as e:
            return e

    # We don't use Status::ToString() as it would redundantly include
    # the C++ class name.
    message = frombytes(status.message(), safe=True)
    detail = status.detail()
    if detail != nullptr:
        message += ". Detail: " + frombytes(detail.get().ToString(),
                                            safe=True)

    if status.IsInvalid():
        return ArrowInvalid(message)
    elif status.IsIOError():
        # Note: OSError constructor is
        #   OSError(message)
        # or
        #   OSError(errno, message, filename=None)
        # or (on Windows)
        #   OSError(errno, message, filename, winerror)
        errno = ErrnoFromStatus(status)
        winerror = WinErrorFromStatus(status)
        # Use the richest constructor form for which information is present.
        if winerror != 0:
            return IOError(errno, message, None, winerror)
        elif errno != 0:
            return IOError(errno, message)
        else:
            return IOError(message)
    elif status.IsOutOfMemory():
        return ArrowMemoryError(message)
    elif status.IsKeyError():
        return ArrowKeyError(message)
    elif status.IsNotImplemented():
        return ArrowNotImplementedError(message)
    elif status.IsTypeError():
        return ArrowTypeError(message)
    elif status.IsCapacityError():
        return ArrowCapacityError(message)
    elif status.IsIndexError():
        return ArrowIndexError(message)
    elif status.IsSerializationError():
        return ArrowSerializationError(message)
    elif status.IsCancelled():
        # Only a positive signal number is attached to the exception.
        signum = SignalFromStatus(status)
        if signum > 0:
            return ArrowCancelled(message, signum)
        else:
            return ArrowCancelled(message)
    else:
        # Unrecognized status kind: fall back to the full ToString() text.
        message = frombytes(status.ToString(), safe=True)
        return ArrowException(message)
# These are API functions for C++ PyArrow

# Exported (`cdef api`) entry point that forwards to check_status().
cdef api int pyarrow_internal_check_status(const CStatus& status) \
        except -1 nogil:
    return check_status(status)
# Exported (`cdef api`) entry point that forwards to convert_status().
cdef api object pyarrow_internal_convert_status(const CStatus& status):
    return convert_status(status)
# Python-level holder for a C++ CStopToken.
cdef class StopToken:
    # Take ownership of the given token by moving it into this object.
    cdef void init(self, CStopToken stop_token):
        self.stop_token = move(stop_token)
# Module-level switch read by SignalStopHandler._init_signals(); changed
# through enable_signal_handlers().
cdef c_bool signal_handlers_enabled = True


def enable_signal_handlers(c_bool enable):
    """
    Enable or disable interruption of long-running operations.

    By default, certain long running operations will detect user
    interruptions, such as by pressing Ctrl-C.  This detection relies
    on setting a signal handler for the duration of the long-running
    operation, and may therefore interfere with other frameworks or
    libraries (such as an event loop).

    Parameters
    ----------
    enable : bool
        Whether to enable user interruption by setting a temporary
        signal handler.
    """
    global signal_handlers_enabled
    signal_handlers_enabled = enable
# For internal use

# Whether we need a workaround for https://bugs.python.org/issue42248
# True for Python < 3.8.10, for 3.9.0 - 3.9.4, and for every 3.10 release.
have_signal_refcycle = (sys.version_info < (3, 8, 10) or
                        (3, 9) <= sys.version_info < (3, 9, 5) or
                        sys.version_info[:2] == (3, 10))
# Context manager that registers a cancelling signal handler for the duration
# of a `with` block and exposes the associated StopToken via `stop_token`.
cdef class SignalStopHandler:
    cdef:
        StopToken _stop_token
        vector[int] _signals
        c_bool _enabled

    def __cinit__(self):
        self._enabled = False

        self._init_signals()
        if have_signal_refcycle:
            _break_traceback_cycle_from_frame(sys._getframe(0))

        self._stop_token = StopToken()

        # Only set up a stop source when there is at least one signal to
        # listen for.
        if not self._signals.empty():
            maybe_source = SetSignalStopSource()
            if not maybe_source.ok():
                # See ARROW-11841 / ARROW-17173: in complex interaction
                # scenarios (such as R calling into Python), SetSignalStopSource()
                # may have already activated a signal-receiving StopSource.
                # Just warn instead of erroring out.
                maybe_source.status().Warn()
            else:
                self._stop_token.init(deref(maybe_source).token())
                # signals don't work on Emscripten without threads.
                # and possibly other single-thread environments.
                self._enabled = is_threading_enabled()

    def _init_signals(self):
        # Signals are only collected when handlers are enabled globally and
        # we are running on the main thread. SIGINT / SIGTERM are kept only
        # if a handler other than SIG_DFL / SIG_IGN / None is installed.
        if (signal_handlers_enabled and
                threading.current_thread() is threading.main_thread()):
            self._signals = [
                sig for sig in (signal.SIGINT, signal.SIGTERM)
                if signal.getsignal(sig) not in (signal.SIG_DFL,
                                                 signal.SIG_IGN, None)]

    def __enter__(self):
        if self._enabled:
            check_status(RegisterCancellingSignalHandler(self._signals))
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        if self._enabled:
            UnregisterCancellingSignalHandler()
        if exc_value is None:
            # Make sure we didn't lose a signal
            try:
                check_status(self._stop_token.stop_token.Poll())
            except ArrowCancelled as e:
                exc_value = e
        if isinstance(exc_value, ArrowCancelled):
            if exc_value.signum:
                # Re-emit the exact same signal. We restored the Python signal
                # handler above, so it should receive it.
                if os.name == 'nt':
                    SendSignal(exc_value.signum)
                else:
                    SendSignalToThread(exc_value.signum,
                                       threading.main_thread().ident)
            else:
                # Simulate Python receiving a SIGINT
                # (see https://bugs.python.org/issue43356 for why we can't
                # simulate the exact signal number)
                PyErr_SetInterrupt()
            # Maximize chances of the Python signal handler being executed now.
            # Otherwise a potential KeyboardInterrupt might be missed by an
            # immediately enclosing try/except block.
            PyErr_CheckSignals()
            # ArrowCancelled will be re-raised if PyErr_CheckSignals()
            # returned successfully.

    def __dealloc__(self):
        # The stop source is only reset if this instance enabled it.
        if self._enabled:
            ResetSignalStopSource()

    @property
    def stop_token(self):
        return self._stop_token

View File

@@ -0,0 +1,279 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from collections.abc import Sequence
import os
from pyarrow.pandas_compat import _pandas_api # noqa
from pyarrow.lib import (Codec, Table, # noqa
concat_tables, schema)
import pyarrow.lib as ext
from pyarrow import _feather
from pyarrow._feather import FeatherError # noqa: F401
class FeatherDataset:
    """
    Encapsulates details of reading a list of Feather files.

    Parameters
    ----------
    path_or_paths : str, path-like or List[str]
        A single file name or a list of file names. A single path is
        treated as a one-element list.
    validate_schema : bool, default True
        Check that individual file schemas are all the same / compatible
    """

    def __init__(self, path_or_paths, validate_schema=True):
        # A bare string is itself a sequence; without this normalization
        # ``paths[0]`` would be its first character rather than the path.
        if isinstance(path_or_paths, (str, os.PathLike)):
            path_or_paths = [path_or_paths]
        self.paths = path_or_paths
        self.validate_schema = validate_schema

    def read_table(self, columns=None):
        """
        Read multiple feather files as a single pyarrow.Table

        Parameters
        ----------
        columns : List[str]
            Names of columns to read from the file

        Returns
        -------
        pyarrow.Table
            Content of the file as a table (of columns)
        """
        # The first file defines the reference schema for the dataset.
        _fil = read_table(self.paths[0], columns=columns)
        self._tables = [_fil]
        self.schema = _fil.schema

        for path in self.paths[1:]:
            table = read_table(path, columns=columns)
            if self.validate_schema:
                self.validate_schemas(path, table)
            self._tables.append(table)
        return concat_tables(self._tables)

    def validate_schemas(self, piece, table):
        """
        Raise ValueError if the schema of ``table`` differs from the
        dataset schema; ``piece`` is only used in the error message.
        """
        if not self.schema.equals(table.schema):
            raise ValueError(f'Schema in {piece} was different. \n'
                             f'{self.schema}\n\nvs\n\n{table.schema}')

    def read_pandas(self, columns=None, use_threads=True):
        """
        Read multiple Feather files as a single pandas DataFrame

        Parameters
        ----------
        columns : List[str]
            Names of columns to read from the file
        use_threads : bool, default True
            Use multiple threads when converting to pandas

        Returns
        -------
        pandas.DataFrame
            Content of the file as a pandas DataFrame (of columns)
        """
        return self.read_table(columns=columns).to_pandas(
            use_threads=use_threads)
def check_chunked_overflow(name, col):
    """
    Raise ValueError when ``col`` does not consist of exactly one chunk.

    ``name`` is the column name used in the error message. Binary and
    string columns get a dedicated message about the 2GB capacity limit.
    """
    if col.num_chunks == 1:
        return

    if col.type not in (ext.binary(), ext.string()):
        # TODO(wesm): Not sure when else this might be reached
        raise ValueError(
            f"Column '{name}' of type {col.type} was chunked on conversion to Arrow "
            "and cannot be currently written to Feather format"
        )

    raise ValueError(f"Column '{name}' exceeds 2GB maximum capacity of "
                     "a Feather binary column. This restriction may be "
                     "lifted in the future")
_FEATHER_SUPPORTED_CODECS = {'lz4', 'zstd', 'uncompressed'}
def write_feather(df, dest, compression=None, compression_level=None,
chunksize=None, version=2):
"""
Write a pandas.DataFrame to Feather format.
Parameters
----------
df : pandas.DataFrame or pyarrow.Table
Data to write out as Feather format.
dest : str
Local destination path.
compression : string, default None
Can be one of {"zstd", "lz4", "uncompressed"}. The default of None uses
LZ4 for V2 files if it is available, otherwise uncompressed.
compression_level : int, default None
Use a compression level particular to the chosen compressor. If None
use the default compression level
chunksize : int, default None
For V2 files, the internal maximum size of Arrow RecordBatch chunks
when writing the Arrow IPC file format. None means use the default,
which is currently 64K
version : int, default 2
Feather file version. Version 2 is the current. Version 1 is the more
limited legacy format
"""
if _pandas_api.have_pandas:
if (_pandas_api.has_sparse and
isinstance(df, _pandas_api.pd.SparseDataFrame)):
df = df.to_dense()
if _pandas_api.is_data_frame(df):
# Feather v1 creates a new column in the resultant Table to
# store index information if index type is not RangeIndex
if version == 1:
preserve_index = False
elif version == 2:
preserve_index = None
else:
raise ValueError("Version value should either be 1 or 2")
table = Table.from_pandas(df, preserve_index=preserve_index)
if version == 1:
# Version 1 does not chunking
for i, name in enumerate(table.schema.names):
col = table[i]
check_chunked_overflow(name, col)
else:
table = df
if version == 1:
if len(table.column_names) > len(set(table.column_names)):
raise ValueError("cannot serialize duplicate column names")
if compression is not None:
raise ValueError("Feather V1 files do not support compression "
"option")
if chunksize is not None:
raise ValueError("Feather V1 files do not support chunksize "
"option")
else:
if compression is None and Codec.is_available('lz4_frame'):
compression = 'lz4'
elif (compression is not None and
compression not in _FEATHER_SUPPORTED_CODECS):
raise ValueError(f'compression="{compression}" not supported, must be '
f'one of {_FEATHER_SUPPORTED_CODECS}')
try:
_feather.write_feather(table, dest, compression=compression,
compression_level=compression_level,
chunksize=chunksize, version=version)
except Exception:
if isinstance(dest, str):
try:
os.remove(dest)
except os.error:
pass
raise
def read_feather(source, columns=None, use_threads=True,
                 memory_map=False, **kwargs):
    """
    Load a Feather file into a pandas.DataFrame.

    The file is first read with ``read_table`` and the resulting table is
    converted with ``to_pandas``. Use ``feather.read_table`` directly to
    obtain a pyarrow.Table instead.

    Parameters
    ----------
    source : str file path, or file-like object
        You can use MemoryMappedFile as source, for explicitly use memory map.
    columns : sequence, optional
        Only read a specific set of columns. If not provided, all columns are
        read.
    use_threads : bool, default True
        Whether to parallelize reading using multiple threads. The same
        value is passed to both the Feather read and the pandas conversion.
    memory_map : boolean, default False
        Use memory mapping when opening file on disk, when source is a str.
    **kwargs
        Additional keyword arguments passed on to `pyarrow.Table.to_pandas`.

    Returns
    -------
    df : pandas.DataFrame
        The contents of the Feather file as a pandas.DataFrame
    """
    arrow_table = read_table(
        source,
        columns=columns,
        memory_map=memory_map,
        use_threads=use_threads,
    )
    return arrow_table.to_pandas(use_threads=use_threads, **kwargs)
def read_table(source, columns=None, memory_map=False, use_threads=True):
    """
    Load a Feather file into a pyarrow.Table.

    Parameters
    ----------
    source : str file path, or file-like object
        You can use MemoryMappedFile as source, for explicitly use memory map.
    columns : sequence, optional
        Only read a specific set of columns, given either all as integer
        indices or all as string names. If not provided, all columns are
        read.
    memory_map : boolean, default False
        Use memory mapping when opening file on disk, when source is a str
    use_threads : bool, default True
        Whether to parallelize reading using multiple threads.

    Returns
    -------
    table : pyarrow.Table
        The contents of the Feather file as a pyarrow.Table

    Raises
    ------
    TypeError
        If ``columns`` is not a sequence, or mixes / uses element types
        other than ``int`` and ``str``.
    """
    reader = _feather.FeatherReader(
        source, use_memory_map=memory_map, use_threads=use_threads)

    if columns is None:
        return reader.read()

    if not isinstance(columns, Sequence):
        raise TypeError("Columns must be a sequence but, got {}"
                        .format(type(columns).__name__))

    column_types = [type(column) for column in columns]
    if all(t == int for t in column_types):
        table = reader.read_indices(columns)
    elif all(t == str for t in column_types):
        table = reader.read_names(columns)
    else:
        column_type_names = [t.__name__ for t in column_types]
        raise TypeError("Columns must be indices or names. "
                        f"Got columns {columns} of types {column_type_names}")

    # Feather v1 already respects the column selection, and a Feather v2
    # read (sorted / deduplicated selection) needs no fix-up when the request
    # was already in that form.
    if reader.version < 3 or sorted(set(columns)) == columns:
        return table

    # follow exact order / selection of names
    return table.select(columns)

View File

@@ -0,0 +1,69 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
try:
from pyarrow._flight import ( # noqa:F401
connect,
Action,
ActionType,
BasicAuth,
CallInfo,
CertKeyPair,
ClientAuthHandler,
ClientMiddleware,
ClientMiddlewareFactory,
DescriptorType,
FlightCallOptions,
FlightCancelledError,
FlightClient,
FlightDataStream,
FlightDescriptor,
FlightEndpoint,
FlightError,
FlightInfo,
FlightInternalError,
FlightMetadataReader,
FlightMetadataWriter,
FlightMethod,
FlightServerBase,
FlightServerError,
FlightStreamChunk,
FlightStreamReader,
FlightStreamWriter,
FlightTimedOutError,
FlightUnauthenticatedError,
FlightUnauthorizedError,
FlightUnavailableError,
FlightWriteSizeExceededError,
GeneratorStream,
Location,
MetadataRecordBatchReader,
MetadataRecordBatchWriter,
RecordBatchStream,
Result,
SchemaResult,
ServerAuthHandler,
ServerCallContext,
ServerMiddleware,
ServerMiddlewareFactory,
Ticket,
TracingServerMiddlewareFactory,
)
except ImportError as exc:
raise ImportError(
f"The pyarrow installation is not built with support for 'flight' ({str(exc)})"
) from None

View File

@@ -0,0 +1,428 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
FileSystem abstraction to interact with various local and remote filesystems.
"""
from pyarrow.util import _is_path_like, _stringify_path
from pyarrow._fs import ( # noqa
FileSelector,
FileType,
FileInfo,
FileSystem,
LocalFileSystem,
SubTreeFileSystem,
_MockFileSystem,
FileSystemHandler,
PyFileSystem,
_copy_files,
_copy_files_selector,
)
# For backward compatibility.
FileStats = FileInfo
_not_imported = []
try:
from pyarrow._azurefs import AzureFileSystem # noqa
except ImportError:
_not_imported.append("AzureFileSystem")
try:
from pyarrow._hdfs import HadoopFileSystem # noqa
except ImportError:
_not_imported.append("HadoopFileSystem")
try:
from pyarrow._gcsfs import GcsFileSystem # noqa
except ImportError:
_not_imported.append("GcsFileSystem")
try:
from pyarrow._s3fs import ( # noqa
AwsDefaultS3RetryStrategy, AwsStandardS3RetryStrategy,
S3FileSystem, S3LogLevel, S3RetryStrategy, ensure_s3_initialized,
finalize_s3, ensure_s3_finalized, initialize_s3, resolve_s3_region)
except ImportError:
_not_imported.append("S3FileSystem")
else:
# GH-38364: we don't initialize S3 eagerly as that could lead
# to crashes at shutdown even when S3 isn't used.
# Instead, S3 is initialized lazily using `ensure_s3_initialized`
# in assorted places.
import atexit
atexit.register(ensure_s3_finalized)
def __getattr__(name):
    """
    Module-level attribute fallback.

    Names recorded in ``_not_imported`` raise ImportError explaining that
    the build lacks support for them; any other name raises AttributeError.
    """
    if name not in _not_imported:
        raise AttributeError(
            f"module 'pyarrow.fs' has no attribute '{name}'"
        )

    raise ImportError(
        "The pyarrow installation is not built with support for "
        f"'{name}'"
    )
def _ensure_filesystem(filesystem, *, use_mmap=False):
    """
    Coerce ``filesystem`` into a pyarrow FileSystem instance.

    Accepts a FileSystem (returned unchanged), a URI string, or an
    fsspec-compatible filesystem object. Anything else raises TypeError.
    """
    if isinstance(filesystem, FileSystem):
        return filesystem

    if not isinstance(filesystem, str):
        # Not a URI: the only remaining supported kind is an fsspec
        # filesystem, which can only be recognized when fsspec is importable.
        try:
            import fsspec
        except ImportError:
            fsspec = None

        if fsspec is not None and isinstance(filesystem,
                                             fsspec.AbstractFileSystem):
            if type(filesystem).__name__ == 'LocalFileSystem':
                # A plain fsspec LocalFileSystem maps onto the native one
                return LocalFileSystem(use_mmap=use_mmap)
            return PyFileSystem(FSSpecHandler(filesystem))

        raise TypeError(
            f"Unrecognized filesystem: {type(filesystem)}. `filesystem` argument must "
            "be a FileSystem instance or a valid file system URI"
        )

    # From here on `filesystem` is a URI string. Its path component, when
    # present, is treated as a prefix and the result is wrapped in a
    # SubTreeFileSystem.
    if use_mmap:
        raise ValueError(
            "Specifying to use memory mapping not supported for "
            "filesystem specified as an URI string"
        )

    fs, path = FileSystem.from_uri(filesystem)
    prefix = fs.normalize_path(path)
    if not prefix:
        return fs

    # The prefix has to designate a directory.
    prefix_info = fs.get_file_info([prefix])[0]
    if prefix_info.type != FileType.Directory:
        raise ValueError(
            "The path component of the filesystem URI must point to a "
            f"directory but it has a type: `{prefix_info.type.name}`. The path "
            f"component is `{prefix_info.path}` and the given filesystem URI "
            f"is `{filesystem}`"
        )
    return SubTreeFileSystem(prefix, fs)
def _resolve_filesystem_and_path(path, filesystem=None, *, memory_map=False):
    """
    Work out the (filesystem, path) pair designated by ``path``.

    ``path`` may be a URI, a plain filesystem path, or a file-like object;
    ``filesystem`` may be given explicitly or left to be inferred.
    """
    # File-like objects are passed through untouched.
    if not _is_path_like(path):
        if filesystem is not None:
            raise ValueError(
                "'filesystem' passed but the specified path is file-like, so"
                " there is nothing to open with 'filesystem'."
            )
        return filesystem, path

    # Explicit filesystem: only normalize the path against it.
    if filesystem is not None:
        filesystem = _ensure_filesystem(filesystem, use_mmap=memory_map)
        if isinstance(filesystem, LocalFileSystem):
            path = _stringify_path(path)
        elif not isinstance(path, str):
            raise TypeError(
                "Expected string path; path-like objects are only allowed "
                "with a local filesystem"
            )
        return filesystem, filesystem.normalize_path(path)

    path = _stringify_path(path)

    # No filesystem given: first see whether the path exists on the local
    # filesystem, and only otherwise try to interpret it as a URI.
    filesystem = LocalFileSystem(use_mmap=memory_map)
    try:
        file_info = filesystem.get_file_info(path)
    except ValueError:  # ValueError means path is likely an URI
        file_info = None
        exists_locally = False
    else:
        exists_locally = (file_info.type != FileType.NotFound)

    if exists_locally:
        return filesystem, filesystem.normalize_path(path)

    try:
        filesystem, path = FileSystem.from_uri(path)
    except ValueError as e:
        msg = str(e)
        # When the string is neither a URI nor an existing local path, keep
        # the local filesystem so that a clearer "file not found" error
        # surfaces later instead of a scheme parsing error. Any other
        # ValueError is propagated.
        if "empty scheme" not in msg and "Cannot parse URI" not in msg:
            raise

    return filesystem, path
def copy_files(source, destination,
               source_filesystem=None, destination_filesystem=None,
               *, chunk_size=1024*1024, use_threads=True):
    """
    Copy files between FileSystems.

    This function allows you to recursively copy directories of files from
    one file system to another, such as from S3 to your local machine.

    Parameters
    ----------
    source : string
        Source file path or URI to a single file or directory.
        If a directory, files will be copied recursively from this path.
    destination : string
        Destination file path or URI. If `source` is a file, `destination`
        is also interpreted as the destination file (not directory).
        Directories will be created as necessary.
    source_filesystem : FileSystem, optional
        Source filesystem, needs to be specified if `source` is not a URI,
        otherwise inferred.
    destination_filesystem : FileSystem, optional
        Destination filesystem, needs to be specified if `destination` is not
        a URI, otherwise inferred.
    chunk_size : int, default 1MB
        The maximum size of block to read before flushing to the
        destination file. A larger chunk_size will use more memory while
        copying but may help accommodate high latency FileSystems.
    use_threads : bool, default True
        Whether to use multiple threads to accelerate copying.

    Examples
    --------
    Inspect an S3 bucket's files:

    >>> s3, path = fs.FileSystem.from_uri(
    ...            "s3://registry.opendata.aws/roda/ndjson/")
    >>> selector = fs.FileSelector(path)
    >>> s3.get_file_info(selector)
    [<FileInfo for 'registry.opendata.aws/roda/ndjson/index.ndjson':...]

    Copy one file from S3 bucket to a local directory:

    >>> fs.copy_files("s3://registry.opendata.aws/roda/ndjson/index.ndjson",
    ...               f"file:///{local_path}/index_copy.ndjson")

    >>> fs.LocalFileSystem().get_file_info(str(local_path)+
    ...                                    '/index_copy.ndjson')
    <FileInfo for '.../index_copy.ndjson': type=FileType.File, size=...>

    Copy file using a FileSystem object:

    >>> fs.copy_files("registry.opendata.aws/roda/ndjson/index.ndjson",
    ...               f"file:///{local_path}/index_copy.ndjson",
    ...               source_filesystem=fs.S3FileSystem())
    """
    source_fs, source_path = _resolve_filesystem_and_path(
        source, source_filesystem
    )
    destination_fs, destination_path = _resolve_filesystem_and_path(
        destination, destination_filesystem
    )

    source_type = source_fs.get_file_info(source_path).type
    if source_type != FileType.Directory:
        # Single file: copy it straight to the destination path.
        _copy_files(source_fs, source_path,
                    destination_fs, destination_path,
                    chunk_size, use_threads)
        return

    # Directory: copy everything below it, recursively.
    recursive_selector = FileSelector(source_path, recursive=True)
    _copy_files_selector(source_fs, recursive_selector,
                         destination_fs, destination_path,
                         chunk_size, use_threads)
class FSSpecHandler(FileSystemHandler):
"""
Handler for fsspec-based Python filesystems.
https://filesystem-spec.readthedocs.io/en/latest/index.html
Parameters
----------
fs : FSSpec-compliant filesystem instance
Examples
--------
>>> PyFileSystem(FSSpecHandler(fsspec_fs)) # doctest: +SKIP
"""
def __init__(self, fs):
self.fs = fs
def __eq__(self, other):
if isinstance(other, FSSpecHandler):
return self.fs == other.fs
return NotImplemented
def __ne__(self, other):
if isinstance(other, FSSpecHandler):
return self.fs != other.fs
return NotImplemented
def get_type_name(self):
protocol = self.fs.protocol
if isinstance(protocol, list):
protocol = protocol[0]
return f"fsspec+{protocol}"
def normalize_path(self, path):
return path
@staticmethod
def _create_file_info(path, info):
size = info["size"]
if info["type"] == "file":
ftype = FileType.File
elif info["type"] == "directory":
ftype = FileType.Directory
# some fsspec filesystems include a file size for directories
size = None
else:
ftype = FileType.Unknown
return FileInfo(path, ftype, size=size, mtime=info.get("mtime", None))
def get_file_info(self, paths):
infos = []
for path in paths:
try:
info = self.fs.info(path)
except FileNotFoundError:
infos.append(FileInfo(path, FileType.NotFound))
else:
infos.append(self._create_file_info(path, info))
return infos
def get_file_info_selector(self, selector):
if not self.fs.isdir(selector.base_dir):
if self.fs.exists(selector.base_dir):
raise NotADirectoryError(selector.base_dir)
else:
if selector.allow_not_found:
return []
else:
raise FileNotFoundError(selector.base_dir)
if selector.recursive:
maxdepth = None
else:
maxdepth = 1
infos = []
selected_files = self.fs.find(
selector.base_dir, maxdepth=maxdepth, withdirs=True, detail=True
)
for path, info in selected_files.items():
_path = path.strip("/")
base_dir = selector.base_dir.strip("/")
# Need to exclude base directory from selected files if present
# (fsspec filesystems, see GH-37555)
if _path != base_dir:
infos.append(self._create_file_info(path, info))
return infos
def create_dir(self, path, recursive):
    """Create directory ``path`` through ``fs.mkdir``.

    ``recursive`` is forwarded as ``create_parents``. A ``FileExistsError``
    raised by ``mkdir`` is ignored; other errors propagate.
    """
    make_directory = self.fs.mkdir
    # mkdir also raises FileNotFoundError when base directory is not found
    try:
        make_directory(path, create_parents=recursive)
    except FileExistsError:
        return
def delete_dir(self, path):
    """Remove ``path`` by calling ``fs.rm`` with ``recursive=True``."""
    remove = self.fs.rm
    remove(path, recursive=True)
def _delete_dir_contents(self, path, missing_dir_ok):
    """Remove every entry listed under ``path`` while keeping ``path``.

    Entries come from ``fs.listdir(path, detail=False)``. Directories are
    removed recursively, files are removed plainly, and entries that are
    neither are left alone. A ``FileNotFoundError`` from the listing is
    swallowed only when ``missing_dir_ok`` is true.
    """
    fs = self.fs
    try:
        children = fs.listdir(path, detail=False)
    except FileNotFoundError:
        if not missing_dir_ok:
            raise
        return
    for child in children:
        if fs.isdir(child):
            fs.rm(child, recursive=True)
        elif fs.isfile(child):
            fs.rm(child)
def delete_dir_contents(self, path, missing_dir_ok):
    """Delete the contents of directory ``path`` but not ``path`` itself.

    Parameters
    ----------
    path : str
        Directory whose entries are removed.
    missing_dir_ok : bool
        Forwarded to ``_delete_dir_contents``.

    Raises
    ------
    ValueError
        If ``path`` is empty or consists only of ``/`` characters, i.e. it
        designates the root directory.
    """
    if path.strip("/") == "":
        # Build one message string; passing the pieces as separate
        # arguments would make the exception render as a tuple.
        raise ValueError(
            f"delete_dir_contents called on path '{path}'")
    self._delete_dir_contents(path, missing_dir_ok)
def delete_root_dir_contents(self):
    """Delete everything listed under the root directory ``"/"``.

    ``_delete_dir_contents`` requires a ``missing_dir_ok`` argument; it is
    passed as ``False`` so that a listing failure on the root is reported
    rather than silently ignored.
    """
    self._delete_dir_contents("/", missing_dir_ok=False)
def delete_file(self, path):
    """Remove the file at ``path``.

    Raises ``FileNotFoundError`` when ``fs.exists(path)`` is false;
    otherwise the removal is delegated to ``fs.rm``.
    """
    fs = self.fs
    if fs.exists(path):
        # fs.rm correctly raises IsADirectoryError when `path` is a directory
        # instead of a file and `recursive` is not set to True
        fs.rm(path)
        return
    raise FileNotFoundError(path)
def move(self, src, dest):
    """Move ``src`` to ``dest`` by calling ``fs.mv`` with ``recursive=True``."""
    relocate = self.fs.mv
    relocate(src, dest, recursive=True)
def copy_file(self, src, dest):
    """Copy the file ``src`` to ``dest`` using ``fs.copy``."""
    duplicate = self.fs.copy
    # fs.copy correctly raises IsADirectoryError when `src` is a directory
    # instead of a file
    duplicate(src, dest)
# TODO can we read/pass metadata (e.g. Content-Type) in the methods below?
def open_input_stream(self, path):
    """Open ``path`` for reading and wrap the handle in a ``PythonFile``.

    Raises ``FileNotFoundError`` when ``fs.isfile(path)`` is false. The
    underlying handle is opened with mode ``"rb"``.
    """
    from pyarrow import PythonFile

    if self.fs.isfile(path):
        handle = self.fs.open(path, mode="rb")
        return PythonFile(handle, mode="r")
    raise FileNotFoundError(path)
def open_input_file(self, path):
    """Open ``path`` for reading and wrap the handle in a ``PythonFile``.

    Raises ``FileNotFoundError`` when ``fs.isfile(path)`` is false. The
    underlying handle is opened with mode ``"rb"``.
    """
    from pyarrow import PythonFile

    if self.fs.isfile(path):
        handle = self.fs.open(path, mode="rb")
        return PythonFile(handle, mode="r")
    raise FileNotFoundError(path)
def open_output_stream(self, path, metadata):
    """Open ``path`` for writing and wrap the handle in a ``PythonFile``.

    The underlying handle is opened with mode ``"wb"``. ``metadata`` is
    accepted but not used by this implementation.
    """
    from pyarrow import PythonFile

    handle = self.fs.open(path, mode="wb")
    return PythonFile(handle, mode="w")
def open_append_stream(self, path, metadata):
    """Open ``path`` for appending and wrap the handle in a ``PythonFile``.

    The underlying handle is opened with mode ``"ab"``. ``metadata`` is
    accepted but not used by this implementation.
    """
    from pyarrow import PythonFile

    handle = self.fs.open(path, mode="ab")
    return PythonFile(handle, mode="w")

View File

@@ -0,0 +1,756 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# cython: profile=False
# distutils: language = c++
# cython: language_level = 3
from libcpp.memory cimport shared_ptr
from libcpp.string cimport string as c_string
from libcpp.vector cimport vector as c_vector
from libcpp.unordered_set cimport unordered_set as c_unordered_set
from libc.stdint cimport int64_t, int32_t
from pyarrow.includes.libarrow cimport *
from pyarrow.lib cimport (DataType, Field, MemoryPool, RecordBatch,
Schema, check_status, pyarrow_wrap_array,
pyarrow_wrap_data_type, ensure_type, _Weakrefable,
pyarrow_wrap_field)
from pyarrow.includes.libgandiva cimport (
CCondition, CGandivaExpression,
CNode, CProjector, CFilter,
CSelectionVector,
_ensure_selection_mode,
CConfiguration,
CConfigurationBuilder,
TreeExprBuilder_MakeExpression,
TreeExprBuilder_MakeFunction,
TreeExprBuilder_MakeBoolLiteral,
TreeExprBuilder_MakeUInt8Literal,
TreeExprBuilder_MakeUInt16Literal,
TreeExprBuilder_MakeUInt32Literal,
TreeExprBuilder_MakeUInt64Literal,
TreeExprBuilder_MakeInt8Literal,
TreeExprBuilder_MakeInt16Literal,
TreeExprBuilder_MakeInt32Literal,
TreeExprBuilder_MakeInt64Literal,
TreeExprBuilder_MakeFloatLiteral,
TreeExprBuilder_MakeDoubleLiteral,
TreeExprBuilder_MakeStringLiteral,
TreeExprBuilder_MakeBinaryLiteral,
TreeExprBuilder_MakeField,
TreeExprBuilder_MakeIf,
TreeExprBuilder_MakeAnd,
TreeExprBuilder_MakeOr,
TreeExprBuilder_MakeCondition,
TreeExprBuilder_MakeInExpressionInt32,
TreeExprBuilder_MakeInExpressionInt64,
TreeExprBuilder_MakeInExpressionTime32,
TreeExprBuilder_MakeInExpressionTime64,
TreeExprBuilder_MakeInExpressionDate32,
TreeExprBuilder_MakeInExpressionDate64,
TreeExprBuilder_MakeInExpressionTimeStamp,
TreeExprBuilder_MakeInExpressionString,
SelectionVector_MakeInt16,
SelectionVector_MakeInt32,
SelectionVector_MakeInt64,
Projector_Make,
Filter_Make,
CFunctionSignature,
GetRegisteredFunctionSignatures)
cdef class Node(_Weakrefable):
    """
    Python wrapper holding a shared pointer to a Gandiva ``CNode``.

    Instances are created through ``Node.create``; calling the constructor
    from Python raises ``TypeError``.
    """
    cdef:
        shared_ptr[CNode] node

    def __init__(self):
        raise TypeError(f"Do not call {self.__class__.__name__}'s constructor directly, use the "
                        "TreeExprBuilder API instead")

    @staticmethod
    cdef create(shared_ptr[CNode] node):
        # __new__ is used because __init__ unconditionally raises.
        cdef Node self = Node.__new__(Node)
        self.node = node
        return self

    def __str__(self):
        return self.node.get().ToString().decode()

    def __repr__(self):
        # Default object repr on the first line, string form on the second.
        type_format = object.__repr__(self)
        return f"{type_format}\n{self}"

    def return_type(self):
        """
        Return the node's return type wrapped as a pyarrow data type.
        """
        return pyarrow_wrap_data_type(self.node.get().return_type())
cdef class Expression(_Weakrefable):
    """
    Python wrapper holding a shared pointer to a ``CGandivaExpression``.
    """
    cdef:
        shared_ptr[CGandivaExpression] expression

    cdef void init(self, shared_ptr[CGandivaExpression] expression):
        # Attach the C++ expression this wrapper exposes.
        self.expression = expression

    def __str__(self):
        return self.expression.get().ToString().decode()

    def __repr__(self):
        # Default object repr on the first line, string form on the second.
        header = object.__repr__(self)
        return f"{header}\n{self}"

    def root(self):
        """
        Return the expression's root as a ``Node``.
        """
        return Node.create(self.expression.get().root())

    def result(self):
        """
        Return the expression's result wrapped as a pyarrow field.
        """
        return pyarrow_wrap_field(self.expression.get().result())
cdef class Condition(_Weakrefable):
    """
    Python wrapper holding a shared pointer to a Gandiva ``CCondition``.

    Instances are created through ``Condition.create``; calling the
    constructor from Python raises ``TypeError``.
    """
    cdef:
        shared_ptr[CCondition] condition

    def __init__(self):
        raise TypeError(f"Do not call {self.__class__.__name__}'s constructor directly, use the "
                        "TreeExprBuilder API instead")

    @staticmethod
    cdef create(shared_ptr[CCondition] condition):
        # __new__ is used because __init__ unconditionally raises.
        cdef Condition wrapper = Condition.__new__(Condition)
        wrapper.condition = condition
        return wrapper

    def __str__(self):
        return self.condition.get().ToString().decode()

    def __repr__(self):
        # Default object repr on the first line, string form on the second.
        header = object.__repr__(self)
        return f"{header}\n{self}"

    def root(self):
        """
        Return the condition's root as a ``Node``.
        """
        return Node.create(self.condition.get().root())

    def result(self):
        """
        Return the condition's result wrapped as a pyarrow field.
        """
        return pyarrow_wrap_field(self.condition.get().result())
cdef class SelectionVector(_Weakrefable):
    """
    Python wrapper holding a shared pointer to a ``CSelectionVector``.

    Instances are created through ``SelectionVector.create``; calling the
    constructor from Python raises ``TypeError``.
    """
    cdef:
        shared_ptr[CSelectionVector] selection_vector

    def __init__(self):
        raise TypeError(
            f"Do not call {self.__class__.__name__}'s constructor directly.")

    @staticmethod
    cdef create(shared_ptr[CSelectionVector] selection_vector):
        # __new__ is used because __init__ unconditionally raises.
        cdef SelectionVector wrapper = SelectionVector.__new__(SelectionVector)
        wrapper.selection_vector = selection_vector
        return wrapper

    def to_array(self):
        """
        Return the selection vector's contents wrapped as a pyarrow array.
        """
        cdef shared_ptr[CArray] c_array = self.selection_vector.get().ToArray()
        return pyarrow_wrap_array(c_array)
cdef class Projector(_Weakrefable):
    """
    Python wrapper holding a shared pointer to a Gandiva ``CProjector``
    together with the memory pool handed to ``Evaluate``.

    Instances are created through ``Projector.create``; calling the
    constructor from Python raises ``TypeError``.
    """
    cdef:
        shared_ptr[CProjector] projector
        MemoryPool pool

    def __init__(self):
        raise TypeError(f"Do not call {self.__class__.__name__}'s constructor directly, use "
                        "make_projector instead")

    @staticmethod
    cdef create(shared_ptr[CProjector] projector, MemoryPool pool):
        # __new__ is used because __init__ unconditionally raises.
        cdef Projector self = Projector.__new__(Projector)
        self.projector = projector
        self.pool = pool
        return self

    @property
    def llvm_ir(self):
        """
        The string returned by the projector's ``DumpIR``.
        """
        return self.projector.get().DumpIR().decode()

    def evaluate(self, RecordBatch batch not None,
                 SelectionVector selection=None):
        """
        Evaluate the specified record batch and return the arrays at the
        filtered positions.

        Parameters
        ----------
        batch : pyarrow.RecordBatch
            Must not be None; its underlying C++ batch is dereferenced.
        selection : pyarrow.gandiva.SelectionVector, default None
            When given, it is passed to the projector's ``Evaluate``.

        Returns
        -------
        list[pyarrow.Array]
        """
        cdef vector[shared_ptr[CArray]] results
        if selection is None:
            check_status(self.projector.get().Evaluate(
                batch.sp_batch.get()[0], self.pool.pool, &results))
        else:
            check_status(
                self.projector.get().Evaluate(
                    batch.sp_batch.get()[0], selection.selection_vector.get(),
                    self.pool.pool, &results))
        cdef shared_ptr[CArray] result
        arrays = []
        for result in results:
            arrays.append(pyarrow_wrap_array(result))
        return arrays
cdef class Filter(_Weakrefable):
    """
    Python wrapper holding a shared pointer to a Gandiva ``CFilter``.

    Instances are created through ``Filter.create``; calling the
    constructor from Python raises ``TypeError``.
    """
    cdef:
        shared_ptr[CFilter] filter

    def __init__(self):
        raise TypeError(f"Do not call {self.__class__.__name__}'s constructor directly, use "
                        "make_filter instead")

    @staticmethod
    cdef create(shared_ptr[CFilter] filter):
        # __new__ is used because __init__ unconditionally raises.
        cdef Filter self = Filter.__new__(Filter)
        self.filter = filter
        return self

    @property
    def llvm_ir(self):
        """
        The string returned by the filter's ``DumpIR``.
        """
        return self.filter.get().DumpIR().decode()

    def evaluate(self, RecordBatch batch not None, MemoryPool pool not None,
                 dtype='int32'):
        """
        Evaluate the specified record batch and return a selection vector.

        Parameters
        ----------
        batch : pyarrow.RecordBatch
            Must not be None; its underlying C++ batch is dereferenced.
        pool : MemoryPool
            Must not be None; used to allocate the selection vector.
        dtype : DataType or str, default int32
            Index type of the selection vector: int16, int32 or int64.

        Returns
        -------
        pyarrow.gandiva.SelectionVector

        Raises
        ------
        ValueError
            If ``dtype`` is not one of int16, int32 and int64.
        """
        cdef:
            DataType type = ensure_type(dtype)
            shared_ptr[CSelectionVector] selection

        # The selection vector is sized to the number of rows in the batch.
        if type.id == _Type_INT16:
            check_status(SelectionVector_MakeInt16(
                batch.num_rows, pool.pool, &selection))
        elif type.id == _Type_INT32:
            check_status(SelectionVector_MakeInt32(
                batch.num_rows, pool.pool, &selection))
        elif type.id == _Type_INT64:
            check_status(SelectionVector_MakeInt64(
                batch.num_rows, pool.pool, &selection))
        else:
            raise ValueError("'dtype' of the selection vector should be "
                             "one of 'int16', 'int32' and 'int64'.")

        check_status(self.filter.get().Evaluate(
            batch.sp_batch.get()[0], selection))
        return SelectionVector.create(selection)
cdef class TreeExprBuilder(_Weakrefable):
    """
    Factory for Gandiva nodes, expressions and conditions.

    Each method forwards to the matching ``TreeExprBuilder_Make*`` function
    and wraps the result in the corresponding Python class.
    """

    def make_literal(self, value, dtype):
        """
        Create a node on a literal.

        Parameters
        ----------
        value : a literal value
        dtype : DataType
            One of bool, uint8/16/32/64, int8/16/32/64, float, double,
            string or binary. String values are encoded as UTF-8.

        Returns
        -------
        pyarrow.gandiva.Node

        Raises
        ------
        TypeError
            If ``dtype`` is not one of the supported types.
        """
        cdef:
            DataType type = ensure_type(dtype)
            shared_ptr[CNode] r

        if type.id == _Type_BOOL:
            r = TreeExprBuilder_MakeBoolLiteral(value)
        elif type.id == _Type_UINT8:
            r = TreeExprBuilder_MakeUInt8Literal(value)
        elif type.id == _Type_UINT16:
            r = TreeExprBuilder_MakeUInt16Literal(value)
        elif type.id == _Type_UINT32:
            r = TreeExprBuilder_MakeUInt32Literal(value)
        elif type.id == _Type_UINT64:
            r = TreeExprBuilder_MakeUInt64Literal(value)
        elif type.id == _Type_INT8:
            r = TreeExprBuilder_MakeInt8Literal(value)
        elif type.id == _Type_INT16:
            r = TreeExprBuilder_MakeInt16Literal(value)
        elif type.id == _Type_INT32:
            r = TreeExprBuilder_MakeInt32Literal(value)
        elif type.id == _Type_INT64:
            r = TreeExprBuilder_MakeInt64Literal(value)
        elif type.id == _Type_FLOAT:
            r = TreeExprBuilder_MakeFloatLiteral(value)
        elif type.id == _Type_DOUBLE:
            r = TreeExprBuilder_MakeDoubleLiteral(value)
        elif type.id == _Type_STRING:
            r = TreeExprBuilder_MakeStringLiteral(value.encode('UTF-8'))
        elif type.id == _Type_BINARY:
            r = TreeExprBuilder_MakeBinaryLiteral(value)
        else:
            raise TypeError("Didn't recognize dtype " + str(dtype))

        return Node.create(r)

    def make_expression(self, Node root_node not None,
                        Field return_field not None):
        """
        Create an expression with the specified root_node,
        and the result written to return_field.

        Parameters
        ----------
        root_node : pyarrow.gandiva.Node
        return_field : pyarrow.Field

        Returns
        -------
        pyarrow.gandiva.Expression
        """
        cdef shared_ptr[CGandivaExpression] r = TreeExprBuilder_MakeExpression(
            root_node.node, return_field.sp_field)
        cdef Expression expression = Expression()
        expression.init(r)
        return expression

    def make_function(self, name, children, DataType return_type not None):
        """
        Create a node with a function.

        Parameters
        ----------
        name : str
        children : list[pyarrow.gandiva.Node]
        return_type : DataType
            Must not be None; its underlying C++ type is dereferenced.

        Returns
        -------
        pyarrow.gandiva.Node

        Raises
        ------
        TypeError
            If any child is None.
        """
        cdef c_vector[shared_ptr[CNode]] c_children
        cdef Node child
        for child in children:
            if child is None:
                raise TypeError("Child nodes must not be None")
            c_children.push_back(child.node)
        cdef shared_ptr[CNode] r = TreeExprBuilder_MakeFunction(
            name.encode(), c_children, return_type.sp_type)
        return Node.create(r)

    def make_field(self, Field field not None):
        """
        Create a node with an Arrow field.

        Parameters
        ----------
        field : pyarrow.Field

        Returns
        -------
        pyarrow.gandiva.Node
        """
        cdef shared_ptr[CNode] r = TreeExprBuilder_MakeField(field.sp_field)
        return Node.create(r)

    def make_if(self, Node condition not None, Node this_node not None,
                Node else_node not None, DataType return_type not None):
        """
        Create a node with an if-else expression.

        Parameters
        ----------
        condition : pyarrow.gandiva.Node
        this_node : pyarrow.gandiva.Node
        else_node : pyarrow.gandiva.Node
        return_type : DataType

        Returns
        -------
        pyarrow.gandiva.Node
        """
        cdef shared_ptr[CNode] r = TreeExprBuilder_MakeIf(
            condition.node, this_node.node, else_node.node,
            return_type.sp_type)
        return Node.create(r)

    def make_and(self, children):
        """
        Create a Node with a boolean AND expression.

        Parameters
        ----------
        children : list[pyarrow.gandiva.Node]

        Returns
        -------
        pyarrow.gandiva.Node

        Raises
        ------
        TypeError
            If any child is None.
        """
        cdef c_vector[shared_ptr[CNode]] c_children
        cdef Node child
        for child in children:
            if child is None:
                raise TypeError("Child nodes must not be None")
            c_children.push_back(child.node)
        cdef shared_ptr[CNode] r = TreeExprBuilder_MakeAnd(c_children)
        return Node.create(r)

    def make_or(self, children):
        """
        Create a Node with a boolean OR expression.

        Parameters
        ----------
        children : list[pyarrow.gandiva.Node]

        Returns
        -------
        pyarrow.gandiva.Node

        Raises
        ------
        TypeError
            If any child is None.
        """
        cdef c_vector[shared_ptr[CNode]] c_children
        cdef Node child
        for child in children:
            if child is None:
                raise TypeError("Child nodes must not be None")
            c_children.push_back(child.node)
        cdef shared_ptr[CNode] r = TreeExprBuilder_MakeOr(c_children)
        return Node.create(r)

    # The _make_in_expression_* helpers below each collect ``values`` into a
    # C++ unordered_set of the matching element type before building the
    # IN-expression node for that type.

    def _make_in_expression_int32(self, Node node not None, values):
        cdef shared_ptr[CNode] r
        cdef c_unordered_set[int32_t] c_values
        cdef int32_t v
        for v in values:
            c_values.insert(v)
        r = TreeExprBuilder_MakeInExpressionInt32(node.node, c_values)
        return Node.create(r)

    def _make_in_expression_int64(self, Node node not None, values):
        cdef shared_ptr[CNode] r
        cdef c_unordered_set[int64_t] c_values
        cdef int64_t v
        for v in values:
            c_values.insert(v)
        r = TreeExprBuilder_MakeInExpressionInt64(node.node, c_values)
        return Node.create(r)

    def _make_in_expression_time32(self, Node node not None, values):
        cdef shared_ptr[CNode] r
        cdef c_unordered_set[int32_t] c_values
        cdef int32_t v
        for v in values:
            c_values.insert(v)
        r = TreeExprBuilder_MakeInExpressionTime32(node.node, c_values)
        return Node.create(r)

    def _make_in_expression_time64(self, Node node not None, values):
        cdef shared_ptr[CNode] r
        cdef c_unordered_set[int64_t] c_values
        cdef int64_t v
        for v in values:
            c_values.insert(v)
        r = TreeExprBuilder_MakeInExpressionTime64(node.node, c_values)
        return Node.create(r)

    def _make_in_expression_date32(self, Node node not None, values):
        cdef shared_ptr[CNode] r
        cdef c_unordered_set[int32_t] c_values
        cdef int32_t v
        for v in values:
            c_values.insert(v)
        r = TreeExprBuilder_MakeInExpressionDate32(node.node, c_values)
        return Node.create(r)

    def _make_in_expression_date64(self, Node node not None, values):
        cdef shared_ptr[CNode] r
        cdef c_unordered_set[int64_t] c_values
        cdef int64_t v
        for v in values:
            c_values.insert(v)
        r = TreeExprBuilder_MakeInExpressionDate64(node.node, c_values)
        return Node.create(r)

    def _make_in_expression_timestamp(self, Node node not None, values):
        cdef shared_ptr[CNode] r
        cdef c_unordered_set[int64_t] c_values
        cdef int64_t v
        for v in values:
            c_values.insert(v)
        r = TreeExprBuilder_MakeInExpressionTimeStamp(node.node, c_values)
        return Node.create(r)

    def _make_in_expression_binary(self, Node node not None, values):
        cdef shared_ptr[CNode] r
        cdef c_unordered_set[c_string] c_values
        cdef c_string v
        for v in values:
            c_values.insert(v)
        r = TreeExprBuilder_MakeInExpressionString(node.node, c_values)
        return Node.create(r)

    def _make_in_expression_string(self, Node node not None, values):
        cdef shared_ptr[CNode] r
        cdef c_unordered_set[c_string] c_values
        cdef c_string _v
        for v in values:
            # Unlike the binary variant, str values are encoded to UTF-8.
            _v = v.encode('UTF-8')
            c_values.insert(_v)
        r = TreeExprBuilder_MakeInExpressionString(node.node, c_values)
        return Node.create(r)

    def make_in_expression(self, Node node not None, values, dtype):
        """
        Create a Node with an IN expression.

        Parameters
        ----------
        node : pyarrow.gandiva.Node
        values : iterable
        dtype : DataType
            One of int32, int64, time32, time64, timestamp, date32,
            date64, binary or string.

        Returns
        -------
        pyarrow.gandiva.Node

        Raises
        ------
        TypeError
            If ``dtype`` is not one of the supported types.
        """
        cdef DataType type = ensure_type(dtype)

        if type.id == _Type_INT32:
            return self._make_in_expression_int32(node, values)
        elif type.id == _Type_INT64:
            return self._make_in_expression_int64(node, values)
        elif type.id == _Type_TIME32:
            return self._make_in_expression_time32(node, values)
        elif type.id == _Type_TIME64:
            return self._make_in_expression_time64(node, values)
        elif type.id == _Type_TIMESTAMP:
            return self._make_in_expression_timestamp(node, values)
        elif type.id == _Type_DATE32:
            return self._make_in_expression_date32(node, values)
        elif type.id == _Type_DATE64:
            return self._make_in_expression_date64(node, values)
        elif type.id == _Type_BINARY:
            return self._make_in_expression_binary(node, values)
        elif type.id == _Type_STRING:
            return self._make_in_expression_string(node, values)
        else:
            raise TypeError("Data type " + str(dtype) + " not supported.")

    def make_condition(self, Node condition not None):
        """
        Create a condition with the specified node.

        Parameters
        ----------
        condition : pyarrow.gandiva.Node

        Returns
        -------
        pyarrow.gandiva.Condition
        """
        cdef shared_ptr[CCondition] r = TreeExprBuilder_MakeCondition(
            condition.node)
        return Condition.create(r)
cdef class Configuration(_Weakrefable):
    """
    Python wrapper holding a shared pointer to a Gandiva ``CConfiguration``.
    """
    cdef:
        shared_ptr[CConfiguration] configuration

    def __cinit__(self, bint optimize=True, bint dump_ir=False):
        """
        Build a configuration and apply the given options to it.

        Parameters
        ----------
        optimize : bool, default True
            Whether to enable optimizations.
        dump_ir : bool, default False
            Whether to dump LLVM IR.
        """
        self.configuration = CConfigurationBuilder().build()
        # Both flags are set on the freshly built configuration object.
        self.configuration.get().set_optimize(optimize)
        self.configuration.get().set_dump_ir(dump_ir)

    @staticmethod
    cdef create(shared_ptr[CConfiguration] configuration):
        """
        Wrap an existing CConfiguration pointer in a Configuration instance.

        Parameters
        ----------
        configuration : shared_ptr[CConfiguration]
            Existing CConfiguration pointer.

        Returns
        -------
        Configuration instance
        """
        cdef Configuration wrapper = Configuration.__new__(Configuration)
        wrapper.configuration = configuration
        return wrapper
cpdef make_projector(Schema schema, children, MemoryPool pool,
                     str selection_mode="NONE",
                     Configuration configuration=None):
    """
    Construct a projection using expressions.

    A projector is built for a specific schema and vector of expressions.
    Once the projector is built, it can be used to evaluate many row batches.

    Parameters
    ----------
    schema : pyarrow.Schema
        Schema for the record batches, and the expressions.
    children : list[pyarrow.gandiva.Expression]
        List of projectable expression objects.
    pool : pyarrow.MemoryPool
        Memory pool used to allocate output arrays.
    selection_mode : str, default "NONE"
        Possible values are NONE, UINT16, UINT32, UINT64.
    configuration : pyarrow.gandiva.Configuration, default None
        Configuration for the projector. A default ``Configuration()`` is
        used when None.

    Returns
    -------
    Projector instance

    Raises
    ------
    TypeError
        If ``schema``, ``pool`` or any expression is None.
    """
    cdef:
        Expression child
        c_vector[shared_ptr[CGandivaExpression]] c_children
        shared_ptr[CProjector] result

    # The C++ members of schema and pool are dereferenced below (pool later,
    # by Projector.evaluate), so reject None before touching them.
    if schema is None:
        raise TypeError("Schema must not be None")
    if pool is None:
        raise TypeError("Memory pool must not be None")

    if configuration is None:
        configuration = Configuration()

    for child in children:
        if child is None:
            raise TypeError("Expressions must not be None")
        c_children.push_back(child.expression)

    check_status(
        Projector_Make(schema.sp_schema, c_children,
                       _ensure_selection_mode(selection_mode),
                       configuration.configuration,
                       &result))
    return Projector.create(result, pool)
cpdef make_filter(Schema schema, Condition condition,
                  Configuration configuration=None):
    """
    Construct a filter based on a condition.

    A filter is built for a specific schema and condition. Once the filter is
    built, it can be used to evaluate many row batches.

    Parameters
    ----------
    schema : pyarrow.Schema
        Schema for the record batches, and the condition.
    condition : pyarrow.gandiva.Condition
        Filter condition.
    configuration : pyarrow.gandiva.Configuration, default None
        Configuration for the filter. A default ``Configuration()`` is
        used when None.

    Returns
    -------
    Filter instance

    Raises
    ------
    TypeError
        If ``schema`` or ``condition`` is None.
    """
    cdef shared_ptr[CFilter] result

    # schema.sp_schema is dereferenced below, so reject None first.
    if schema is None:
        raise TypeError("Schema must not be None")
    if condition is None:
        raise TypeError("Condition must not be None")

    if configuration is None:
        configuration = Configuration()

    check_status(
        Filter_Make(schema.sp_schema, condition.condition,
                    configuration.configuration, &result))
    return Filter.create(result)
cdef class FunctionSignature(_Weakrefable):
    """
    Signature of a Gandiva function including name, parameter types
    and return type.
    """
    cdef:
        shared_ptr[CFunctionSignature] signature

    def __init__(self):
        raise TypeError(
            f"Do not call {self.__class__.__name__}'s constructor directly.")

    @staticmethod
    cdef create(shared_ptr[CFunctionSignature] signature):
        # __new__ is used because __init__ unconditionally raises.
        cdef FunctionSignature wrapper = FunctionSignature.__new__(
            FunctionSignature)
        wrapper.signature = signature
        return wrapper

    def return_type(self):
        """
        Return the signature's return type wrapped as a pyarrow data type.
        """
        return pyarrow_wrap_data_type(self.signature.get().ret_type())

    def param_types(self):
        """
        Return the signature's parameter types as a list of pyarrow
        data types.
        """
        wrapped = []
        cdef vector[shared_ptr[CDataType]] c_types = \
            self.signature.get().param_types()
        for c_type in c_types:
            wrapped.append(pyarrow_wrap_data_type(c_type))
        return wrapped

    def name(self):
        """
        Return the signature's base name as a str.
        """
        return self.signature.get().base_name().decode()

    def __repr__(self):
        text = self.signature.get().ToString().decode()
        return f"FunctionSignature({text})"
def get_registered_function_signatures():
    """
    Return the function signatures reported by
    GetRegisteredFunctionSignatures.

    Returns
    -------
    registry: a list of registered function signatures
    """
    wrapped = []

    cdef vector[shared_ptr[CFunctionSignature]] c_signatures = \
        GetRegisteredFunctionSignatures()

    for c_signature in c_signatures:
        wrapped.append(FunctionSignature.create(c_signature))

    return wrapped

View File

@@ -0,0 +1,162 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <cstdint>
#include <functional>
#include <optional>
#include <vector>
#include "arrow/acero/visibility.h"
#include "arrow/compute/exec.h"
#include "arrow/result.h"
namespace arrow {
namespace acero {
namespace util {
using arrow::compute::ExecBatch;
/// \brief A container that accumulates batches until they are ready to
/// be processed.
class ARROW_ACERO_EXPORT AccumulationQueue {
public:
AccumulationQueue() : row_count_(0) {}
~AccumulationQueue() = default;
// We should never be copying ExecBatch around
AccumulationQueue(const AccumulationQueue&) = delete;
AccumulationQueue& operator=(const AccumulationQueue&) = delete;
AccumulationQueue(AccumulationQueue&& that);
AccumulationQueue& operator=(AccumulationQueue&& that);
void Concatenate(AccumulationQueue&& that);
void InsertBatch(ExecBatch batch);
int64_t row_count() { return row_count_; }
size_t batch_count() { return batches_.size(); }
bool empty() const { return batches_.empty(); }
void Clear();
ExecBatch& operator[](size_t i);
private:
int64_t row_count_;
std::vector<ExecBatch> batches_;
};
/// A queue that sequences incoming batches
///
/// This can be used when a node needs to do some kind of ordered processing on
/// the stream.
///
/// Batches can be inserted in any order. The process_callback will be called on
/// the batches, in order, without reentrant calls. For this reason the callback
/// should be quick.
///
/// For example, in a top-n node, the process callback should determine how many
/// rows need to be delivered for the given batch, and then return a task to actually
/// deliver those rows.
class ARROW_ACERO_EXPORT SequencingQueue {
public:
/// A follow-up unit of work: a callable taking no arguments and returning a Status
using Task = std::function<Status()>;
/// Strategy that describes how to handle items
///
/// NOTE(review): unlike SerialSequencingQueue::Processor, this interface declares
/// no virtual destructor, so implementations should not be deleted through a
/// Processor pointer.
class Processor {
public:
/// Process the batch, potentially generating a task
///
/// This method will be called on each batch in order. Calls to this method
/// will be serialized and it will not be called reentrantly. This makes it
/// safe to do things that rely on order but minimal time should be spent here
/// to avoid becoming a bottleneck.
///
/// \return a follow-up task that will be scheduled. The follow-up task(s) are
/// not guaranteed to run in any particular order. If nullopt is
/// returned then nothing will be scheduled.
virtual Result<std::optional<Task>> Process(ExecBatch batch) = 0;
/// Schedule a task
virtual void Schedule(Task task) = 0;
};
virtual ~SequencingQueue() = default;
/// Insert a batch into the queue
///
/// This will insert the batch into the queue. If this batch was the next batch
/// to deliver then this will trigger 1+ calls to the process callback to generate
/// 1+ tasks.
///
/// The task generated by this call will be executed immediately. The remaining
/// tasks will be scheduled using the schedule callback.
///
/// From a data pipeline perspective the sequencing queue is a "sometimes" breaker. If
/// a task arrives in order then this call will usually execute the downstream pipeline.
/// If this task arrives early then this call will only queue the data.
virtual Status InsertBatch(ExecBatch batch) = 0;
/// Create a queue
/// \param processor describes how to process the batches, must outlive the queue
/// \return a newly created queue owned by the caller
static std::unique_ptr<SequencingQueue> Make(Processor* processor);
};
/// A queue that sequences incoming batches
///
/// Unlike SequencingQueue the Process method is not expected to schedule new tasks.
///
/// If a batch arrives and another thread is currently processing then the batch
/// will be queued and control will return. In other words, delivery of batches will
/// not block on the Process method.
///
/// It can be helpful to think of this as if a dedicated thread is running Process as
/// batches arrive.
class ARROW_ACERO_EXPORT SerialSequencingQueue {
public:
/// Strategy that describes how to handle items
class Processor {
public:
virtual ~Processor() = default;
/// Process the batch
///
/// This method will be called on each batch in order. Calls to this method
/// will be serialized and it will not be called reentrantly. This makes it
/// safe to do things that rely on order.
///
/// If this falls behind then data may accumulate.
///
/// TODO: Could add backpressure if needed but right now all uses of this should
/// be pretty fast and so are unlikely to block.
virtual Status Process(ExecBatch batch) = 0;
};
virtual ~SerialSequencingQueue() = default;
/// Insert a batch into the queue
///
/// This will insert the batch into the queue. If this batch was the next batch
/// to deliver then this may trigger calls to the processor which will be run
/// as part of this call.
virtual Status InsertBatch(ExecBatch batch) = 0;
/// Create a queue
/// \param processor describes how to process the batches, must outlive the queue
/// \return a newly created queue owned by the caller
static std::unique_ptr<SerialSequencingQueue> Make(Processor* processor);
};
} // namespace util
} // namespace acero
} // namespace arrow

View File

@@ -0,0 +1,58 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// This API is EXPERIMENTAL.
#pragma once
#include <memory>
#include <vector>
#include "arrow/acero/visibility.h"
#include "arrow/compute/api_aggregate.h"
#include "arrow/compute/test_util_internal.h"
#include "arrow/compute/type_fwd.h"
#include "arrow/result.h"
#include "arrow/type_fwd.h"
namespace arrow {
namespace acero {
namespace aggregate {
using compute::Aggregate;
using compute::default_exec_context;
using compute::ExecContext;
/// \brief Make the output schema of an aggregate node
///
/// The output schema is determined by the aggregation kernels, which may depend on the
/// ExecContext argument. To guarantee correct results, the same ExecContext argument
/// should be used in execution.
///
/// \param[in] input_schema the schema of the input to the node
/// \param[in] keys the grouping keys for the aggregation
/// \param[in] segment_keys the segmenting keys for the aggregation
/// \param[in] aggregates the aggregates for the aggregation
/// \param[in] exec_ctx the execution context for the aggregation; defaults to
/// default_exec_context()
/// \return the output schema wrapped in a Result
ARROW_ACERO_EXPORT Result<std::shared_ptr<Schema>> MakeOutputSchema(
const std::shared_ptr<Schema>& input_schema, const std::vector<FieldRef>& keys,
const std::vector<FieldRef>& segment_keys, const std::vector<Aggregate>& aggregates,
ExecContext* exec_ctx = default_exec_context());
} // namespace aggregate
} // namespace acero
} // namespace arrow

View File

@@ -0,0 +1,32 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// NOTE: API is EXPERIMENTAL and will change without going through a
// deprecation cycle
#pragma once
/// \defgroup acero-api Utilities for creating and executing execution plans
/// @{
/// @}
/// \defgroup acero-nodes Options classes for the various exec nodes
/// @{
/// @}
#include "arrow/acero/exec_plan.h"
#include "arrow/acero/options.h"

View File

@@ -0,0 +1,41 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <vector>
#include "arrow/acero/options.h"
#include "arrow/acero/visibility.h"
#include "arrow/compute/exec.h"
#include "arrow/type.h"
namespace arrow {
namespace acero {
namespace asofjoin {
/// Shorthand for the per-input key description nested in AsofJoinNodeOptions
using AsofJoinKeys = AsofJoinNodeOptions::Keys;
/// \brief Make the output schema of an as-of-join node
///
/// \param[in] input_schema the schema of each input to the node
/// \param[in] input_keys the key of each input to the node
/// \return the output schema wrapped in a Result
ARROW_ACERO_EXPORT Result<std::shared_ptr<Schema>> MakeOutputSchema(
const std::vector<std::shared_ptr<Schema>>& input_schema,
const std::vector<AsofJoinKeys>& input_keys);
} // namespace asofjoin
} // namespace acero
} // namespace arrow

View File

@@ -0,0 +1,74 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include "arrow/acero/exec_plan.h"
#include "arrow/acero/options.h"
#include <memory>
namespace arrow::acero {
/// \brief Pauses and resumes a producer as a monitored level crosses two thresholds
///
/// Instances are obtained through Make(), which validates the arguments; the
/// constructor is private. The handler owns the BackpressureControl it drives
/// and keeps a non-owning pointer to the input node.
class BackpressureHandler {
 public:
  /// \brief Validate the arguments and build a handler
  ///
  /// \param input node on which StopProducing() is called by ForceShutdown()
  /// \param low_threshold falling to or below this level triggers Resume()
  /// \param high_threshold rising to or above this level triggers Pause()
  /// \param backpressure_control receiver of the Pause()/Resume() calls
  /// \return the handler, or Status::Invalid when low_threshold is not strictly
  ///         below high_threshold or when backpressure_control is null
  static Result<BackpressureHandler> Make(
      ExecNode* input, size_t low_threshold, size_t high_threshold,
      std::unique_ptr<BackpressureControl> backpressure_control) {
    const bool thresholds_ordered = low_threshold < high_threshold;
    if (!thresholds_ordered) {
      return Status::Invalid("low threshold (", low_threshold,
                             ") must be less than high threshold (", high_threshold, ")");
    }
    if (backpressure_control == NULLPTR) {
      return Status::Invalid("null backpressure control parameter");
    }
    BackpressureHandler handler(input, low_threshold, high_threshold,
                                std::move(backpressure_control));
    return handler;
  }

  /// \brief React to the monitored level moving from start_level to end_level
  ///
  /// Only threshold crossings trigger an action: Pause() when the level goes
  /// from below the high threshold to at or above it, Resume() when it goes from
  /// above the low threshold to at or below it. Any other move does nothing.
  void Handle(size_t start_level, size_t end_level) {
    const bool rose_to_high =
        start_level < high_threshold_ && end_level >= high_threshold_;
    if (rose_to_high) {
      backpressure_control_->Pause();
      return;
    }
    const bool fell_to_low = start_level > low_threshold_ && end_level <= low_threshold_;
    if (fell_to_low) {
      backpressure_control_->Resume();
    }
  }

  /// \brief Resume the producer, then stop production on the input node
  ///
  /// \return the Status returned by the input's StopProducing()
  Status ForceShutdown() {
    // Resuming right before stopping looks odd but prevents a deadlock: acero's
    // executor will not terminate while any node is paused, so the node has to be
    // force-resumed first.
    backpressure_control_->Resume();
    return input_->StopProducing();
  }

 private:
  BackpressureHandler(ExecNode* input, size_t low_threshold, size_t high_threshold,
                      std::unique_ptr<BackpressureControl> backpressure_control)
      : input_(input),
        low_threshold_(low_threshold),
        high_threshold_(high_threshold),
        backpressure_control_(std::move(backpressure_control)) {}

  ExecNode* input_;
  size_t low_threshold_;
  size_t high_threshold_;
  std::unique_ptr<BackpressureControl> backpressure_control_;
};
} // namespace arrow::acero

View File

@@ -0,0 +1,48 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <cstdint>
#include <string>
#include <vector>
#include "benchmark/benchmark.h"
#include "arrow/acero/exec_plan.h"
#include "arrow/acero/test_util_internal.h"
#include "arrow/compute/exec.h"
namespace arrow {
namespace acero {
/// \brief Benchmark helper driven by a list of node declarations
///
/// NOTE(review): only the declaration is visible in this header. Judging by the
/// name and parameters this presumably runs `node_declarations` over `data` and
/// reports the per-node overhead through `state` -- confirm against the
/// implementation.
///
/// \param state the Google Benchmark state of the running benchmark
/// \param num_batches presumably the number of batches in `data` -- TODO confirm
/// \param batch_size presumably the number of rows per batch -- TODO confirm
/// \param data the batches (with their schema) used as input
/// \param node_declarations the declarations to benchmark; taken by non-const
///        reference, so the callee may modify the vector
/// \param pool memory pool to use; defaults to default_memory_pool()
/// \return a Status describing success or failure
Status BenchmarkNodeOverhead(benchmark::State& state, int32_t num_batches,
int32_t batch_size, arrow::acero::BatchesWithSchema data,
std::vector<arrow::acero::Declaration>& node_declarations,
arrow::MemoryPool* pool = default_memory_pool());
/// \brief Benchmark helper for a single node created from a named factory
///
/// NOTE(review): only the declaration is visible in this header. Judging by the
/// name this presumably instantiates the node registered as `factory_name` with
/// `options` and measures it in isolation -- confirm against the implementation.
///
/// \param state the Google Benchmark state of the running benchmark
/// \param expr a compute expression; its role is not visible here -- TODO confirm
/// \param num_batches presumably the number of batches in `data` -- TODO confirm
/// \param batch_size presumably the number of rows per batch -- TODO confirm
/// \param data the batches (with their schema) used as input
/// \param factory_name name of the exec node factory to use
/// \param options options for the node; taken by non-const reference
/// \param pool memory pool to use; defaults to default_memory_pool()
/// \return a Status describing success or failure
Status BenchmarkIsolatedNodeOverhead(benchmark::State& state,
arrow::compute::Expression expr, int32_t num_batches,
int32_t batch_size,
arrow::acero::BatchesWithSchema data,
std::string factory_name,
arrow::acero::ExecNodeOptions& options,
arrow::MemoryPool* pool = default_memory_pool());
} // namespace acero
} // namespace arrow

View File

@@ -0,0 +1,323 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <atomic>
#include <cstdint>
#include <memory>
#include "arrow/acero/partition_util.h"
#include "arrow/acero/util.h"
#include "arrow/memory_pool.h"
#include "arrow/result.h"
#include "arrow/status.h"
#include "arrow/util/simd.h"
namespace arrow {
namespace acero {
/// \brief Table of pre-generated bit masks extracted from 64-bit words
///
/// Selected bits of a hash are mapped to one of these masks, which is then used
/// in a Bloom filter. To keep the false positives rate of the filter low, the
/// masks have to look random and have a similar fraction of bits set.
struct ARROW_ACERO_EXPORT BloomFilterMasks {
  /// Generates all masks as one bit vector in which every bit offset is the
  /// start of one mask. Each run of kBitsPerMask consecutive bits must contain
  /// between kMinBitsSet and kMaxBitsSet set bits.
  BloomFilterMasks();

  /// \brief Return the kBitsPerMask-bit mask starting at the given bit offset
  inline uint64_t mask(int bit_offset) {
    const uint8_t* first_byte = masks_ + bit_offset / 8;
    const int bit_in_byte = bit_offset % 8;
#if ARROW_LITTLE_ENDIAN
    const uint64_t word = arrow::util::SafeLoadAs<uint64_t>(first_byte);
#else
    // Swap bytes so the loaded word has the value a little-endian host would see.
    const uint64_t word = BYTESWAP(arrow::util::SafeLoadAs<uint64_t>(first_byte));
#endif
    return (word >> bit_in_byte) & kFullMask;
  }

  // A mask is 57 bits long: together with up to 7 bits of in-byte offset it
  // still fits in a single unaligned 64-bit load at any bit offset.
  static constexpr int kBitsPerMask = 57;
  static constexpr uint64_t kFullMask = (1ULL << kBitsPerMask) - 1;

  // Bounds on the number of set bits per mask, enforced while the masks are
  // generated. The two values should be close to each other and chosen to
  // minimize the false positives rate of the Bloom filter.
  static constexpr int kMinBitsSet = 4;
  static constexpr int kMaxBitsSet = 5;

  // Number of generated masks. More masks improve the false positives rate but
  // cost memory, which may cause more CPU cache misses. The chosen value keeps
  // mask lookups within a few cache lines while still offering a good variety
  // of masks.
  static constexpr int kLogNumMasks = 10;
  static constexpr int kNumMasks = 1 << kLogNumMasks;

  // Storage for the masks: a single bit vector in which the Nth mask is the
  // kBitsPerMask bits starting at bit offset N.
  static constexpr int kTotalBytes = (kNumMasks + 64) / 8;
  uint8_t masks_[kTotalBytes];
};
// A variant of a blocked Bloom filter implementation.
// A Bloom filter is a data structure that provides approximate membership test
// functionality based only on the hash of the key. Membership test may return
// false positives but not false negatives. Approximation of the result allows
// in general case (for arbitrary data types of keys) to save on both memory and
// lookup cost compared to the accurate membership test.
// The accurate test may sometimes still be cheaper for a specific data types
// and inputs, e.g. integers from a small range.
//
// This blocked Bloom filter is optimized for use in hash joins, to achieve a
// good balance between the size of the filter, the cost of its building and
// querying and the rate of false positives.
//
class ARROW_ACERO_EXPORT BlockedBloomFilter {
friend class BloomFilterBuilder_SingleThreaded;
friend class BloomFilterBuilder_Parallel;
public:
BlockedBloomFilter() : log_num_blocks_(0), num_blocks_(0), blocks_(NULLPTR) {}
inline bool Find(uint64_t hash) const {
uint64_t m = mask(hash);
uint64_t b = blocks_[block_id(hash)];
return (b & m) == m;
}
// Uses SIMD if available for smaller Bloom filters.
// Uses memory prefetching for larger Bloom filters.
//
void Find(int64_t hardware_flags, int64_t num_rows, const uint32_t* hashes,
uint8_t* result_bit_vector, bool enable_prefetch = true) const;
void Find(int64_t hardware_flags, int64_t num_rows, const uint64_t* hashes,
uint8_t* result_bit_vector, bool enable_prefetch = true) const;
int log_num_blocks() const { return log_num_blocks_; }
int NumHashBitsUsed() const;
bool IsSameAs(const BlockedBloomFilter* other) const;
int64_t NumBitsSet() const;
// Folding of a block Bloom filter after the initial version
// has been built.
//
// One of the parameters for creation of Bloom filter is the number
// of bits allocated for it. The more bits allocated, the lower the
// probability of false positives. A good heuristic is to aim for
// half of the bits set in the constructed Bloom filter. This should
// result in a good trade off between size (and following cost of
// memory accesses) and false positives rate.
//
// There might have been many duplicate keys in the input provided
// to Bloom filter builder. In that case the resulting bit vector
// would be more sparse then originally intended. It is possible to
// easily correct that and cut in half the size of Bloom filter
// after it has already been constructed. The process to do that is
// approximately equal to OR-ing bits from upper and lower half (the
// way we address these bits when inserting or querying a hash makes
// such folding in half possible).
//
// We will keep folding as long as the fraction of bits set is less
// than 1/4. The resulting bit vector density should be in the [1/4,
// 1/2) range.
//
void Fold();
private:
Status CreateEmpty(int64_t num_rows_to_insert, MemoryPool* pool);
inline void Insert(uint64_t hash) {
uint64_t m = mask(hash);
uint64_t& b = blocks_[block_id(hash)];
b |= m;
}
void Insert(int64_t hardware_flags, int64_t num_rows, const uint32_t* hashes);
void Insert(int64_t hardware_flags, int64_t num_rows, const uint64_t* hashes);
inline uint64_t mask(uint64_t hash) const {
// The lowest bits of hash are used to pick mask index.
//
int mask_id = static_cast<int>(hash & (BloomFilterMasks::kNumMasks - 1));
uint64_t result = masks_.mask(mask_id);
// The next set of hash bits is used to pick the amount of bit
// rotation of the mask.
//
int rotation = (hash >> BloomFilterMasks::kLogNumMasks) & 63;
result = ROTL64(result, rotation);
return result;
}
inline int64_t block_id(uint64_t hash) const {
// The next set of hash bits following the bits used to select a
// mask is used to pick block id (index of 64-bit word in a bit
// vector).
//
return (hash >> (BloomFilterMasks::kLogNumMasks + 6)) & (num_blocks_ - 1);
}
template <typename T>
inline void InsertImp(int64_t num_rows, const T* hashes);
template <typename T>
inline void FindImp(int64_t num_rows, const T* hashes, uint8_t* result_bit_vector,
bool enable_prefetch) const;
void SingleFold(int num_folds);
#if defined(ARROW_HAVE_RUNTIME_AVX2)
inline __m256i mask_avx2(__m256i hash) const;
inline __m256i block_id_avx2(__m256i hash) const;
int64_t Insert_avx2(int64_t num_rows, const uint32_t* hashes);
int64_t Insert_avx2(int64_t num_rows, const uint64_t* hashes);
template <typename T>
int64_t InsertImp_avx2(int64_t num_rows, const T* hashes);
int64_t Find_avx2(int64_t num_rows, const uint32_t* hashes,
uint8_t* result_bit_vector) const;
int64_t Find_avx2(int64_t num_rows, const uint64_t* hashes,
uint8_t* result_bit_vector) const;
template <typename T>
int64_t FindImp_avx2(int64_t num_rows, const T* hashes,
uint8_t* result_bit_vector) const;
#endif
bool UsePrefetch() const {
return num_blocks_ * sizeof(uint64_t) > kPrefetchLimitBytes;
}
static constexpr int64_t kPrefetchLimitBytes = 256 * 1024;
static BloomFilterMasks masks_;
// Total number of bits used by block Bloom filter must be a power
// of 2.
//
int log_num_blocks_;
int64_t num_blocks_;
// Buffer allocated to store an array of power of 2 64-bit blocks.
//
std::shared_ptr<Buffer> buf_;
// Pointer to mutable data owned by Buffer
//
uint64_t* blocks_;
};
// Selects one of the two implementations of building a Bloom filter:
// single-threaded or multi-threaded (parallel).
//
// The single-threaded version is useful in two ways:
// a) It allows verifying the parallel implementation in tests (the single-threaded
// one is simpler and can be used as the source of truth).
// b) It is preferred for small and medium size Bloom filters, because it skips the
// extra synchronization-related steps of the parallel variant (partitioning and
// taking locks).
//
enum class BloomFilterBuildStrategy {
SINGLE_THREADED = 0,
PARALLEL = 1,
};
/// \brief Abstract interface for building a BlockedBloomFilter from hashes
///
/// Concrete builders are created with Make() from a BloomFilterBuildStrategy.
/// Hashes are supplied in batches, as either 32-bit or 64-bit values.
class ARROW_ACERO_EXPORT BloomFilterBuilder {
public:
virtual ~BloomFilterBuilder() = default;
/// \brief Start building into `build_target`
///
/// NOTE(review): the definitions are not visible in this header; `num_rows` and
/// `num_batches` presumably describe the whole build input -- confirm against
/// the implementations.
virtual Status Begin(size_t num_threads, int64_t hardware_flags, MemoryPool* pool,
int64_t num_rows, int64_t num_batches,
BlockedBloomFilter* build_target) = 0;
/// \brief Number of tasks of the builder; the base implementation returns 0
virtual int64_t num_tasks() const { return 0; }
/// \brief Feed a batch of 32-bit hashes on behalf of the given thread
virtual Status PushNextBatch(size_t thread_index, int64_t num_rows,
const uint32_t* hashes) = 0;
/// \brief Feed a batch of 64-bit hashes on behalf of the given thread
virtual Status PushNextBatch(size_t thread_index, int64_t num_rows,
const uint64_t* hashes) = 0;
/// \brief Optional cleanup hook; the base implementation does nothing
virtual void CleanUp() {}
/// \brief Create a builder implementing the requested strategy
static std::unique_ptr<BloomFilterBuilder> Make(BloomFilterBuildStrategy strategy);
};
/// \brief Single-threaded BloomFilterBuilder
///
/// The thread index passed to PushNextBatch() is unnamed in the declarations
/// below, i.e. it is not part of this builder's documented inputs.
class ARROW_ACERO_EXPORT BloomFilterBuilder_SingleThreaded : public BloomFilterBuilder {
public:
Status Begin(size_t num_threads, int64_t hardware_flags, MemoryPool* pool,
int64_t num_rows, int64_t num_batches,
BlockedBloomFilter* build_target) override;
Status PushNextBatch(size_t /*thread_index*/, int64_t num_rows,
const uint32_t* hashes) override;
Status PushNextBatch(size_t /*thread_index*/, int64_t num_rows,
const uint64_t* hashes) override;
private:
// Shared implementation of both PushNextBatch() overloads, templated on the
// hash width.
template <typename T>
void PushNextBatchImp(int64_t num_rows, const T* hashes);
// NOTE(review): presumably captured from the arguments of Begin() -- confirm in
// the .cc file. Neither member has a default initializer here.
int64_t hardware_flags_;
BlockedBloomFilter* build_target_;
};
/// \brief Multi-threaded BloomFilterBuilder
///
/// Keeps per-thread scratch state and a set of partition locks.
/// NOTE(review): the partitioning scheme itself is implemented outside this
/// header -- see the .cc file for details.
class ARROW_ACERO_EXPORT BloomFilterBuilder_Parallel : public BloomFilterBuilder {
public:
Status Begin(size_t num_threads, int64_t hardware_flags, MemoryPool* pool,
int64_t num_rows, int64_t num_batches,
BlockedBloomFilter* build_target) override;
Status PushNextBatch(size_t thread_id, int64_t num_rows,
const uint32_t* hashes) override;
Status PushNextBatch(size_t thread_id, int64_t num_rows,
const uint64_t* hashes) override;
void CleanUp() override;
private:
// Shared implementation of both PushNextBatch() overloads, templated on the
// hash width.
template <typename T>
void PushNextBatchImp(size_t thread_id, int64_t num_rows, const T* hashes);
// NOTE(review): presumably captured from the arguments of Begin() -- confirm in
// the .cc file. Neither member has a default initializer here.
int64_t hardware_flags_;
BlockedBloomFilter* build_target_;
// Base-2 logarithm of the number of partitions (as the name suggests -- TODO
// confirm).
int log_num_prtns_;
// Scratch buffers kept separately for each thread.
struct ThreadLocalState {
std::vector<uint32_t> partitioned_hashes_32;
std::vector<uint64_t> partitioned_hashes_64;
std::vector<uint16_t> partition_ranges;
std::vector<int> unprocessed_partition_ids;
};
// NOTE(review): presumably one entry per thread, indexed by thread_id -- confirm.
std::vector<ThreadLocalState> thread_local_states_;
PartitionLocks prtn_locks_;
};
} // namespace acero
} // namespace arrow

View File

@@ -0,0 +1,819 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
#include "arrow/acero/type_fwd.h"
#include "arrow/acero/visibility.h"
#include "arrow/compute/api_vector.h"
#include "arrow/compute/exec.h"
#include "arrow/compute/ordering.h"
#include "arrow/type_fwd.h"
#include "arrow/util/future.h"
#include "arrow/util/macros.h"
#include "arrow/util/tracing.h"
#include "arrow/util/type_fwd.h"
namespace arrow {
using compute::ExecBatch;
using compute::ExecContext;
using compute::FunctionRegistry;
using compute::GetFunctionRegistry;
using compute::Ordering;
using compute::threaded_exec_context;
namespace acero {
/// \addtogroup acero-internals
/// @{
/// \brief Holds the ExecNodes of a plan and controls their execution
class ARROW_ACERO_EXPORT ExecPlan : public std::enable_shared_from_this<ExecPlan> {
 public:
  // Capping batches at 2^15 rows lets operators rely on signed 16-bit indices.
  static const uint32_t kMaxBatchSize = 1 << 15;

  using NodeVector = std::vector<ExecNode*>;

  virtual ~ExecPlan() = default;

  QueryContext* query_context();

  /// \brief The nodes currently in the plan
  const NodeVector& nodes() const;

  /// \brief Make an empty exec plan
  ///
  /// Overloads differ in whether QueryOptions are supplied and in whether the
  /// ExecContext is passed by value or by pointer.
  static Result<std::shared_ptr<ExecPlan>> Make(
      QueryOptions options, ExecContext exec_context = *threaded_exec_context(),
      std::shared_ptr<const KeyValueMetadata> metadata = NULLPTR);

  static Result<std::shared_ptr<ExecPlan>> Make(
      ExecContext exec_context = *threaded_exec_context(),
      std::shared_ptr<const KeyValueMetadata> metadata = NULLPTR);

  static Result<std::shared_ptr<ExecPlan>> Make(
      QueryOptions options, ExecContext* exec_context,
      std::shared_ptr<const KeyValueMetadata> metadata = NULLPTR);

  static Result<std::shared_ptr<ExecPlan>> Make(
      ExecContext* exec_context,
      std::shared_ptr<const KeyValueMetadata> metadata = NULLPTR);

  /// \brief Add a node to the plan, transferring its ownership
  ExecNode* AddNode(std::unique_ptr<ExecNode> node);

  /// \brief Construct a node of type Node in place and add it to the plan
  ///
  /// The arguments are forwarded to a brace-initialization of Node.
  /// \return a typed pointer to the newly added node
  template <typename Node, typename... Args>
  Node* EmplaceNode(Args&&... args) {
    std::unique_ptr<Node> owned{new Node{std::forward<Args>(args)...}};
    // Grab the typed pointer before ownership moves into the plan.
    Node* typed_node = owned.get();
    AddNode(std::move(owned));
    return typed_node;
  }

  Status Validate();

  /// \brief Start producing on all nodes
  ///
  /// Nodes are started in reverse topological order, so that every node is
  /// started before all of its inputs.
  void StartProducing();

  /// \brief Stop producing on all nodes
  ///
  /// Triggers all sources to stop producing new data. To stop cleanly, the plan
  /// keeps running any tasks that are already in progress. The caller should
  /// still wait for `finished` to complete before destroying the plan.
  void StopProducing();

  /// \brief A future which will be marked finished when all tasks have finished.
  Future<> finished();

  /// \brief Return whether the plan has non-empty metadata
  bool HasMetadata() const;

  /// \brief Return the plan's attached metadata
  std::shared_ptr<const KeyValueMetadata> metadata() const;

  std::string ToString() const;
};
// Acero can be extended by providing custom implementations of ExecNode. The methods
// below are documented in detail and provide careful instruction on how to fulfill the
// ExecNode contract. It's suggested you familiarize yourself with the Acero
// documentation in the C++ user guide.
/// \brief Base class of all nodes of an execution plan
///
/// A node has zero or more inputs, at most one output, and an optional output
/// schema (a node without one is a sink).
class ARROW_ACERO_EXPORT ExecNode {
public:
using NodeVector = std::vector<ExecNode*>;
virtual ~ExecNode() = default;
/// A short name identifying the kind of node
virtual const char* kind_name() const = 0;
/// The number of inputs of this node (the size of inputs())
int num_inputs() const { return static_cast<int>(inputs_.size()); }
/// This node's predecessors in the exec plan
const NodeVector& inputs() const { return inputs_; }
/// True if this node has no output schema (i.e. it is a sink)
bool is_sink() const { return !output_schema_; }
/// \brief Labels identifying the function of each input.
const std::vector<std::string>& input_labels() const { return input_labels_; }
/// This node's successor in the exec plan
const ExecNode* output() const { return output_; }
/// The datatypes for batches produced by this node
const std::shared_ptr<Schema>& output_schema() const { return output_schema_; }
/// This node's exec plan
ExecPlan* plan() { return plan_; }
/// \brief An optional label, for display and debugging
///
/// There is no guarantee that this value is non-empty or unique.
const std::string& label() const { return label_; }
/// \brief Replace the node's label
void SetLabel(std::string label) { label_ = std::move(label); }
virtual Status Validate() const;
/// \brief the ordering of the output batches
///
/// This does not guarantee the batches will be emitted by this node
/// in order. Instead it guarantees that the batches will have their
/// ExecBatch::index property set in a way that respects this ordering.
///
/// In other words, given the ordering {{"x", SortOrder::Ascending}} we
/// know that all values of x in a batch with index N will be less than
/// or equal to all values of x in a batch with index N+k (assuming k > 0).
/// Furthermore, we also know that values will be sorted within a batch.
/// Any row N will have a value of x that is less than the value for
/// any row N+k.
///
/// Note that an ordering can be both Ordering::Unordered and Ordering::Implicit.
/// A node's output should be marked Ordering::Unordered if the order is
/// non-deterministic. For example, a hash-join has no predictable output order.
///
/// If the ordering is Ordering::Implicit then there is a meaningful order but that
/// ordering is not represented by any column in the data. The most common case for
/// this is when reading data from an in-memory table. The data has an implicit "row
/// order" which is not necessarily represented in the data set.
///
/// A filter or project node will not modify the ordering. Nothing needs to be done
/// other than ensure the index assigned to output batches is the same as the
/// input batch that was mapped.
///
/// Other nodes may introduce order. For example, an order-by node will emit
/// a brand new ordering independent of the input ordering.
///
/// Finally, as described above, some nodes such as a hash-join or aggregation may
/// destroy ordering (although these nodes could also choose to establish a
/// new ordering based on the hash keys).
///
/// Some nodes will require an ordering. For example, a fetch node or an
/// asof join node will only function if the input data is ordered (for fetch
/// it is enough to be implicitly ordered; for an asof join the ordering must
/// be explicit and compatible with the on key).
///
/// Nodes that maintain ordering should be careful to avoid introducing gaps
/// in the batch index. This may require emitting empty batches in order to
/// maintain continuity.
virtual const Ordering& ordering() const;
/// Upstream API:
/// These functions are called by input nodes that want to inform this node
/// about an updated condition (a new input batch or an impending
/// end of stream).
///
/// Implementation rules:
/// - these may be called anytime after StartProducing() has succeeded
/// (and even during or after StopProducing())
/// - these may be called concurrently
/// - these are allowed to call back into PauseProducing(), ResumeProducing()
/// and StopProducing()
/// Transfer input batch to ExecNode
///
/// A node will typically perform some kind of operation on the batch
/// and then call InputReceived on its outputs with the result.
///
/// Other nodes may need to accumulate some number of inputs before any
/// output can be produced. These nodes will add the batch to some kind
/// of in-memory accumulation queue and return.
virtual Status InputReceived(ExecNode* input, ExecBatch batch) = 0;
/// Mark the inputs finished after the given number of batches.
///
/// This may be called before all inputs are received. This simply fixes
/// the total number of incoming batches for an input, so that the ExecNode
/// knows when it has received all input, regardless of order.
virtual Status InputFinished(ExecNode* input, int total_batches) = 0;
/// \brief Perform any needed initialization
///
/// This hook performs any actions in between creation of ExecPlan and the call to
/// StartProducing. An example could be Bloom filter pushdown. The order of ExecNodes
/// that executes this method is undefined, but the calls are made synchronously.
///
/// At this point a node can rely on all inputs & outputs (and the input schemas)
/// being well defined.
virtual Status Init();
/// Lifecycle API:
/// - start / stop to initiate and terminate production
/// - pause / resume to apply backpressure
///
/// Implementation rules:
/// - StartProducing() should not recurse into the inputs, as it is
/// handled by ExecPlan::StartProducing()
/// - PauseProducing(), ResumeProducing(), StopProducing() may be called
/// concurrently, potentially even before the call to StartProducing
/// has finished.
/// - PauseProducing(), ResumeProducing(), StopProducing() may be called
/// by the downstream nodes' InputReceived(), InputFinished() methods
///
/// StopProducing may be called due to an error, by the user (e.g. cancel), or
/// because a node has all the data it needs (e.g. limit, top-k on sorted data).
/// This means the method may be called multiple times and we have the following
/// additional rules:
/// - StopProducing() must be idempotent
/// - StopProducing() must be forwarded to inputs (this is needed for the limit/top-k
/// case because we may not be stopping the entire plan)
// Right now, since synchronous calls happen in both directions (input to
// output and then output to input), a node must be careful to be reentrant
// against synchronous calls from its output, *and* also concurrent calls from
// other threads. The most reliable solution is to update the internal state
// first, and notify outputs only at the end.
//
// Concurrent calls to PauseProducing and ResumeProducing can be hard to sequence
// as they may travel at different speeds through the plan.
//
// For example, consider a resume that comes quickly after a pause. If the source
// receives the resume before the pause the source may think the destination is full
// and halt production which would lead to deadlock.
//
// To resolve this a counter is sent for all calls to pause/resume. Only the call with
// the highest counter value is valid. So if a call to PauseProducing(5) comes after
// a call to ResumeProducing(6) then the source should continue producing.
/// \brief Start producing
///
/// This must only be called once.
///
/// This is typically called automatically by ExecPlan::StartProducing().
virtual Status StartProducing() = 0;
/// \brief Pause producing temporarily
///
/// \param output Pointer to the output that is full
/// \param counter Counter used to sequence calls to pause/resume
///
/// This call is a hint that an output node is currently not willing
/// to receive data.
///
/// This may be called any number of times.
/// However, the node is still free to produce data (which may be difficult
/// to prevent anyway if data is produced using multiple threads).
virtual void PauseProducing(ExecNode* output, int32_t counter) = 0;
/// \brief Resume producing after a temporary pause
///
/// \param output Pointer to the output that is now free
/// \param counter Counter used to sequence calls to pause/resume
///
/// This call is a hint that an output node is willing to receive data again.
///
/// This may be called any number of times.
virtual void ResumeProducing(ExecNode* output, int32_t counter) = 0;
/// \brief Stop producing new data
///
/// If this node is a source then the source should stop generating data
/// as quickly as possible. If this node is not a source then there is typically
/// nothing that needs to be done although a node may choose to start ignoring incoming
/// data.
///
/// This method will be called when an error occurs in the plan.
/// This method may also be called by the user if they wish to end a plan early.
/// Finally, this method may be called if a node determines it no longer needs any more
/// input (for example, a limit node).
///
/// This method may be called multiple times.
///
/// This is not a pause. There will be no way to start the source again after this has
/// been called.
virtual Status StopProducing();
/// \brief A string representation of the node, for display and debugging
std::string ToString(int indent = 0) const;
protected:
ExecNode(ExecPlan* plan, NodeVector inputs, std::vector<std::string> input_labels,
std::shared_ptr<Schema> output_schema);
// Node-specific part of stopping production.
// NOTE(review): presumably invoked by StopProducing() -- confirm in exec_plan.cc.
virtual Status StopProducingImpl() = 0;
/// Provide extra info to include in the string representation.
virtual std::string ToStringExtra(int indent = 0) const;
// NOTE(review): presumably records that StopProducing() has already run, which
// would make repeated calls idempotent -- confirm in exec_plan.cc.
std::atomic<bool> stopped_;
ExecPlan* plan_;
std::string label_;
NodeVector inputs_;
std::vector<std::string> input_labels_;
std::shared_ptr<Schema> output_schema_;
ExecNode* output_ = NULLPTR;
};
/// \brief An extensible registry for factories of ExecNodes
///
/// Factories are looked up and registered by name.
class ARROW_ACERO_EXPORT ExecFactoryRegistry {
public:
/// \brief Signature of a node factory
///
/// A factory receives the plan, the input nodes and the node options, and
/// returns the created node or an error.
using Factory = std::function<Result<ExecNode*>(ExecPlan*, std::vector<ExecNode*>,
const ExecNodeOptions&)>;
virtual ~ExecFactoryRegistry() = default;
/// \brief Get the named factory from this registry
///
/// An error is returned if factory_name is not found.
virtual Result<Factory> GetFactory(const std::string& factory_name) = 0;
/// \brief Add a factory to this registry with the provided name
///
/// An error is returned if factory_name is already in the registry.
virtual Status AddFactory(std::string factory_name, Factory factory) = 0;
};
/// \brief The default registry, which includes built-in factories.
///
/// This is the registry used by MakeExecNode() when none is passed explicitly.
ARROW_ACERO_EXPORT
ExecFactoryRegistry* default_exec_factory_registry();
/// \brief Construct an ExecNode using the named factory
///
/// \param factory_name name under which the factory was registered
/// \param plan the plan handed to the factory
/// \param inputs the input nodes handed to the factory
/// \param options the node options handed to the factory
/// \param registry the registry in which the factory is looked up; defaults to
///        default_exec_factory_registry()
/// \return the result of the factory, or the lookup error if the factory could
///         not be obtained from the registry
inline Result<ExecNode*> MakeExecNode(
    const std::string& factory_name, ExecPlan* plan, std::vector<ExecNode*> inputs,
    const ExecNodeOptions& options,
    ExecFactoryRegistry* registry = default_exec_factory_registry()) {
  // Propagates the lookup error to the caller when the name is unknown.
  ARROW_ASSIGN_OR_RAISE(auto make_node, registry->GetFactory(factory_name));
  return make_node(plan, std::move(inputs), options);
}
/// @}
/// \addtogroup acero-api
/// @{
/// \brief Helper class for declaring execution nodes
///
/// A Declaration represents an unconstructed ExecNode (and potentially an entire graph
/// since its inputs may also be Declarations).
///
/// A Declaration can be converted to a plan and executed using one of the
/// DeclarationToXyz methods.
///
/// For more direct control, a Declaration can be added to an existing execution
/// plan with Declaration::AddToPlan, which will recursively construct any inputs as
/// necessary.
struct ARROW_ACERO_EXPORT Declaration {
/// \brief an input of a declaration: either a pointer to an ExecNode or another
/// Declaration
using Input = std::variant<ExecNode*, Declaration>;
/// \brief construct an empty declaration
///
/// The factory name and the label are empty strings, there are no inputs and the
/// options pointer is null.
Declaration() {}
/// \brief construct a declaration
/// \param factory_name the name of the exec node to construct. The node must have
///                     been added to the exec node registry with this name.
/// \param inputs the inputs to the node; each one is either another declaration or
///               a pointer to an ExecNode (see Input)
/// \param options options that control the behavior of the node. You must use
///                the appropriate subclass. For example, if `factory_name` is
///                "project" then `options` should be ProjectNodeOptions.
/// \param label a label to give the node. Can be used to distinguish it from other
///              nodes of the same type in the plan.
Declaration(std::string factory_name, std::vector<Input> inputs,
std::shared_ptr<ExecNodeOptions> options, std::string label)
: factory_name{std::move(factory_name)},
inputs{std::move(inputs)},
options{std::move(options)},
label{std::move(label)} {}
/// \brief construct a declaration from an options object passed by value
///
/// The options object is moved into a newly allocated std::shared_ptr<Options>, which
/// is converted to std::shared_ptr<ExecNodeOptions> before delegating to the
/// shared_ptr-taking constructor.
template <typename Options>
Declaration(std::string factory_name, std::vector<Input> inputs, Options options,
std::string label)
: Declaration{std::move(factory_name), std::move(inputs),
std::shared_ptr<ExecNodeOptions>(
std::make_shared<Options>(std::move(options))),
std::move(label)} {}
/// \brief construct a declaration with an empty label
template <typename Options>
Declaration(std::string factory_name, std::vector<Input> inputs, Options options)
: Declaration{std::move(factory_name), std::move(inputs), std::move(options),
/*label=*/""} {}
/// \brief construct a declaration with no inputs and an empty label
template <typename Options>
Declaration(std::string factory_name, Options options)
: Declaration{std::move(factory_name), {}, std::move(options), /*label=*/""} {}
/// \brief construct a declaration with no inputs
template <typename Options>
Declaration(std::string factory_name, Options options, std::string label)
: Declaration{std::move(factory_name), {}, std::move(options), std::move(label)} {}
/// \brief Convenience factory for the common case of a simple sequence of nodes.
///
/// Each of decls will be appended to the inputs of the subsequent declaration,
/// and the final modified declaration will be returned.
///
/// Without this convenience factory, constructing a sequence would require explicit,
/// difficult-to-read nesting:
///
///     Declaration{"n3",
///                   {
///                       Declaration{"n2",
///                                   {
///                                       Declaration{"n1",
///                                                   {
///                                                       Declaration{"n0", N0Opts{}},
///                                                   },
///                                                   N1Opts{}},
///                                   },
///                                   N2Opts{}},
///                   },
///                   N3Opts{}};
///
/// An equivalent Declaration can be constructed more tersely using Sequence:
///
///     Declaration::Sequence({
///         {"n0", N0Opts{}},
///         {"n1", N1Opts{}},
///         {"n2", N2Opts{}},
///         {"n3", N3Opts{}},
///     });
static Declaration Sequence(std::vector<Declaration> decls);
/// \brief add the declaration to an already created execution plan
/// \param plan the plan to add the node to
/// \param registry the registry to use to lookup the node factory; defaults to
///                 default_exec_factory_registry()
///
/// This method will recursively call AddToPlan on all of the declaration's inputs.
/// This method is only for advanced use when the DeclarationToXyz methods are not
/// sufficient.
///
/// \return the instantiated execution node
Result<ExecNode*> AddToPlan(ExecPlan* plan, ExecFactoryRegistry* registry =
default_exec_factory_registry()) const;
/// \brief Validate a declaration
/// \param registry the exec factory registry to validate against; defaults to
///                 default_exec_factory_registry()
/// \return true if the declaration is considered valid (the exact criteria are
///         defined by the implementation, which is not part of this header)
bool IsValid(ExecFactoryRegistry* registry = default_exec_factory_registry()) const;
/// \brief the name of the factory to use when creating a node
std::string factory_name;
/// \brief the declaration's inputs
std::vector<Input> inputs;
/// \brief options to control the behavior of the node
std::shared_ptr<ExecNodeOptions> options;
/// \brief a label to give the node in the plan
std::string label;
};
/// \brief How to handle unaligned buffers
///
/// - kWarn: check whether any buffers are unaligned and, if so, emit a warning
/// - kIgnore: do not check whether buffers are unaligned
/// - kReallocate: allocate a new, suitably aligned buffer and copy the contents of
///   the old buffer into it
/// - kError: gracefully abort the plan
enum class UnalignedBufferHandling { kWarn, kIgnore, kReallocate, kError };
/// \brief get the default behavior of unaligned buffer handling
///
/// This is configurable via the ACERO_ALIGNMENT_HANDLING environment variable which
/// can be set to "warn", "ignore", "reallocate", or "error". If the environment
/// variable is not set, or is set to an invalid value, this will return kWarn.
///
/// \return the handling selected by the environment variable, or kWarn as a fallback
UnalignedBufferHandling GetDefaultUnalignedBufferHandling();
/// \brief plan-wide options that can be specified when executing an execution plan
struct ARROW_ACERO_EXPORT QueryOptions {
/// \brief Should the plan use a legacy batching strategy
///
/// This is currently in place only to support the Scanner::ToTable
/// method. This method relies on batch indices from the scanner
/// remaining consistent. This is impractical in the ExecPlan which
/// might slice batches as needed (e.g. for a join)
///
/// However, it still works for simple plans and this is the only way
/// we have at the moment for maintaining implicit order.
bool use_legacy_batching = false;
/// \brief If the output has a meaningful order then sequence the output of the plan
///
/// The default behavior (std::nullopt) will sequence output batches if there
/// is a meaningful ordering in the final node and will emit batches immediately
/// otherwise.
///
/// If explicitly set to true then plan execution will fail if there is no
/// meaningful ordering. This can be useful to validate a query that should
/// be emitting ordered results.
///
/// If explicitly set to false then batches will be emitted immediately even if there
/// is a meaningful ordering. This could cause batches to be emitted out of order but
/// may offer a small decrease to latency.
std::optional<bool> sequence_output = std::nullopt;
/// \brief should the plan use multiple background threads for CPU-intensive work
///
/// If this is false then all CPU work will be done on the calling thread. I/O tasks
/// will still happen on the I/O executor and may be multi-threaded (but should not use
/// significant CPU resources).
///
/// Will be ignored if custom_cpu_executor is set
bool use_threads = true;
/// \brief custom executor to use for CPU-intensive work
///
/// Must be null or remain valid for the duration of the plan. If this is null then
/// a default thread pool will be chosen whose behavior will be controlled by
/// the `use_threads` option.
::arrow::internal::Executor* custom_cpu_executor = NULLPTR;
/// \brief custom executor to use for IO work
///
/// Must be null or remain valid for the duration of the plan. If this is null then
/// the global io thread pool will be chosen whose behavior will be controlled by
/// the "ARROW_IO_THREADS" environment variable.
::arrow::internal::Executor* custom_io_executor = NULLPTR;
/// \brief a memory pool to use for allocations
///
/// Must remain valid for the duration of the plan.
MemoryPool* memory_pool = default_memory_pool();
/// \brief a function registry to use for the plan
///
/// Must remain valid for the duration of the plan.
FunctionRegistry* function_registry = GetFunctionRegistry();
/// \brief the names of the output columns
///
/// If this is empty then names will be generated based on the input columns
///
/// If set then the number of names must equal the number of output columns
std::vector<std::string> field_names;
/// \brief Policy for unaligned buffers in source data
///
/// Various compute functions and acero internals will type pun array
/// buffers from uint8_t* to some kind of value type (e.g. we might
/// cast to int32_t* to add two int32 arrays)
///
/// If the buffer is poorly aligned (e.g. an int32 array is not aligned
/// on a 4-byte boundary) then this is technically undefined behavior in C++.
/// However, most modern compilers and CPUs are fairly tolerant of this
/// behavior and nothing bad (beyond a small hit to performance) is likely
/// to happen.
///
/// Note that this only applies to source buffers. All buffers allocated internally
/// by Acero will be suitably aligned.
///
/// If this field is set to kWarn then Acero will check if any buffers are unaligned
/// and, if they are, will emit a warning.
///
/// If this field is set to kReallocate then Acero will allocate a new, suitably aligned
/// buffer and copy the contents from the old buffer into this new buffer.
///
/// If this field is set to kError then Acero will gracefully abort the plan instead.
///
/// If this field is set to kIgnore then Acero will not even check if the buffers are
/// unaligned.
///
/// If this field is not set then it will be treated as kWarn unless overridden
/// by the ACERO_ALIGNMENT_HANDLING environment variable
std::optional<UnalignedBufferHandling> unaligned_buffer_handling;
};
/// \brief Calculate the output schema of a declaration
///
/// This does not actually execute the plan. This operation may fail if the
/// declaration represents an invalid plan (e.g. a project node with multiple inputs),
/// in which case the failure is reported through the returned Result.
///
/// \param declaration A declaration describing an execution plan
/// \param function_registry The function registry to use for function execution. If null
///                          then the default function registry will be used.
///
/// \return the schema that batches would have after going through the execution plan
ARROW_ACERO_EXPORT Result<std::shared_ptr<Schema>> DeclarationToSchema(
const Declaration& declaration, FunctionRegistry* function_registry = NULLPTR);
/// \brief Create a string representation of a plan
///
/// This representation is for debug purposes only.
///
/// Conversion to a string may fail if the declaration represents an
/// invalid plan.
///
/// Use Substrait for complete serialization of plans.
///
/// \param declaration A declaration describing an execution plan
/// \param function_registry The function registry to use for function execution. If null
///                          then the default function registry will be used.
///
/// \return a string representation of the plan suitable for debugging output
ARROW_ACERO_EXPORT Result<std::string> DeclarationToString(
const Declaration& declaration, FunctionRegistry* function_registry = NULLPTR);
/// \brief Utility method to run a declaration and collect the results into a table
///
/// \param declaration A declaration describing the plan to run
/// \param use_threads If `use_threads` is false then all CPU work will be done on the
///                    calling thread. I/O tasks will still happen on the I/O executor
///                    and may be multi-threaded (but should not use significant CPU
///                    resources).
/// \param memory_pool The memory pool to use for allocations made while running the plan.
/// \param function_registry The function registry to use for function execution. If null
///                          then the default function registry will be used.
///
/// This method will add a sink node to the declaration to collect results into a
/// table. It will then create an ExecPlan from the declaration, start the exec plan,
/// block until the plan has finished, and return the created table.
ARROW_ACERO_EXPORT Result<std::shared_ptr<Table>> DeclarationToTable(
Declaration declaration, bool use_threads = true,
MemoryPool* memory_pool = default_memory_pool(),
FunctionRegistry* function_registry = NULLPTR);
/// \brief Overload of DeclarationToTable configured through plan-wide QueryOptions
///
/// \param declaration A declaration describing the plan to run
/// \param query_options The plan-wide options to run the plan with
ARROW_ACERO_EXPORT Result<std::shared_ptr<Table>> DeclarationToTable(
Declaration declaration, QueryOptions query_options);
/// \brief Asynchronous version of \see DeclarationToTable
///
/// \param declaration A declaration describing the plan to run
/// \param use_threads The behavior of use_threads is slightly different than the
///                    synchronous version since we cannot run synchronously on the
///                    calling thread. Instead, if use_threads=false then a new thread
///                    pool will be created with a single thread and this will be used for
///                    all compute work.
/// \param memory_pool The memory pool to use for allocations made while running the plan.
/// \param function_registry The function registry to use for function execution. If null
///                          then the default function registry will be used.
///
/// \return a future that yields the collected table
ARROW_ACERO_EXPORT Future<std::shared_ptr<Table>> DeclarationToTableAsync(
Declaration declaration, bool use_threads = true,
MemoryPool* memory_pool = default_memory_pool(),
FunctionRegistry* function_registry = NULLPTR);
/// \brief Overload of \see DeclarationToTableAsync accepting a custom exec context
///
/// The executor must be specified (cannot be null) and must be kept alive until the
/// returned future finishes.
///
/// \param declaration A declaration describing the plan to run
/// \param custom_exec_context The exec context to run the plan with
ARROW_ACERO_EXPORT Future<std::shared_ptr<Table>> DeclarationToTableAsync(
Declaration declaration, ExecContext custom_exec_context);
/// \brief a collection of exec batches with a common schema
struct BatchesWithCommonSchema {
/// \brief the exec batches
std::vector<ExecBatch> batches;
/// \brief the schema shared by all of the batches
std::shared_ptr<Schema> schema;
};
/// \brief Utility method to run a declaration and collect the results into ExecBatch
/// vector
///
/// The batches are returned together with their common schema.
///
/// \see DeclarationToTable for details on threading & execution
ARROW_ACERO_EXPORT Result<BatchesWithCommonSchema> DeclarationToExecBatches(
Declaration declaration, bool use_threads = true,
MemoryPool* memory_pool = default_memory_pool(),
FunctionRegistry* function_registry = NULLPTR);
/// \brief Overload of DeclarationToExecBatches configured through plan-wide
/// QueryOptions
ARROW_ACERO_EXPORT Result<BatchesWithCommonSchema> DeclarationToExecBatches(
Declaration declaration, QueryOptions query_options);
/// \brief Asynchronous version of \see DeclarationToExecBatches
///
/// \see DeclarationToTableAsync for details on threading & execution
///
/// \return a future that yields the collected batches and their common schema
ARROW_ACERO_EXPORT Future<BatchesWithCommonSchema> DeclarationToExecBatchesAsync(
Declaration declaration, bool use_threads = true,
MemoryPool* memory_pool = default_memory_pool(),
FunctionRegistry* function_registry = NULLPTR);
/// \brief Overload of \see DeclarationToExecBatchesAsync accepting a custom exec context
///
/// \see DeclarationToTableAsync for details on threading & execution
ARROW_ACERO_EXPORT Future<BatchesWithCommonSchema> DeclarationToExecBatchesAsync(
Declaration declaration, ExecContext custom_exec_context);
/// \brief Utility method to run a declaration and collect the results into a vector
/// of record batches
///
/// \see DeclarationToTable for details on threading & execution
ARROW_ACERO_EXPORT Result<std::vector<std::shared_ptr<RecordBatch>>> DeclarationToBatches(
Declaration declaration, bool use_threads = true,
MemoryPool* memory_pool = default_memory_pool(),
FunctionRegistry* function_registry = NULLPTR);
/// \brief Overload of DeclarationToBatches configured through plan-wide QueryOptions
ARROW_ACERO_EXPORT Result<std::vector<std::shared_ptr<RecordBatch>>> DeclarationToBatches(
Declaration declaration, QueryOptions query_options);
/// \brief Asynchronous version of \see DeclarationToBatches
///
/// \see DeclarationToTableAsync for details on threading & execution
///
/// \return a future that yields the collected record batches
ARROW_ACERO_EXPORT Future<std::vector<std::shared_ptr<RecordBatch>>>
DeclarationToBatchesAsync(Declaration declaration, bool use_threads = true,
MemoryPool* memory_pool = default_memory_pool(),
FunctionRegistry* function_registry = NULLPTR);
/// \brief Overload of \see DeclarationToBatchesAsync accepting a custom exec context
///
/// \see DeclarationToTableAsync for details on threading & execution
ARROW_ACERO_EXPORT Future<std::vector<std::shared_ptr<RecordBatch>>>
DeclarationToBatchesAsync(Declaration declaration, ExecContext exec_context);
/// \brief Utility method to run a declaration and return results as a RecordBatchReader
///
/// If an exec context is not provided then a default exec context will be used based
/// on the value of `use_threads`. If `use_threads` is false then the CPU executor will
/// be a serial executor and all CPU work will be done on the calling thread. I/O tasks
/// will still happen on the I/O executor and may be multi-threaded.
///
/// If `use_threads` is false then all CPU work will happen during the calls to
/// RecordBatchReader::Next and no CPU work will happen in the background. If
/// `use_threads` is true then CPU work will happen on the CPU thread pool and tasks may
/// run in between calls to RecordBatchReader::Next. If the returned reader is not
/// consumed quickly enough then the plan will eventually pause as the backpressure queue
/// fills up.
///
/// If a custom exec context is provided then the value of `use_threads` will be ignored.
///
/// NOTE(review): neither overload declared below takes an exec context; the wording
/// above presumably corresponds to setting a custom CPU executor through the
/// QueryOptions overload -- confirm against the implementation.
///
/// The returned RecordBatchReader can be closed early to cancel the computation of record
/// batches. In this case, only errors encountered by the computation may be reported. In
/// particular, no cancellation error may be reported.
ARROW_ACERO_EXPORT Result<std::unique_ptr<RecordBatchReader>> DeclarationToReader(
Declaration declaration, bool use_threads = true,
MemoryPool* memory_pool = default_memory_pool(),
FunctionRegistry* function_registry = NULLPTR);
/// \brief Overload of DeclarationToReader configured through plan-wide QueryOptions
ARROW_ACERO_EXPORT Result<std::unique_ptr<RecordBatchReader>> DeclarationToReader(
Declaration declaration, QueryOptions query_options);
/// \brief Utility method to run a declaration and ignore results
///
/// This can be useful when the data are consumed as part of the plan itself, for
/// example, when the plan ends with a write node.
///
/// \see DeclarationToTable for details on threading & execution
ARROW_ACERO_EXPORT Status
DeclarationToStatus(Declaration declaration, bool use_threads = true,
MemoryPool* memory_pool = default_memory_pool(),
FunctionRegistry* function_registry = NULLPTR);
/// \brief Overload of DeclarationToStatus configured through plan-wide QueryOptions
ARROW_ACERO_EXPORT Status DeclarationToStatus(Declaration declaration,
QueryOptions query_options);
/// \brief Asynchronous version of \see DeclarationToStatus
///
/// This can be useful when the data are consumed as part of the plan itself, for
/// example, when the plan ends with a write node.
///
/// \see DeclarationToTableAsync for details on threading & execution
///
/// \return a future carrying no value (Future<>)
ARROW_ACERO_EXPORT Future<> DeclarationToStatusAsync(
Declaration declaration, bool use_threads = true,
MemoryPool* memory_pool = default_memory_pool(),
FunctionRegistry* function_registry = NULLPTR);
/// \brief Overload of \see DeclarationToStatusAsync accepting a custom exec context
///
/// \see DeclarationToTableAsync for details on threading & execution
ARROW_ACERO_EXPORT Future<> DeclarationToStatusAsync(Declaration declaration,
ExecContext exec_context);
/// @}
/// \brief Wrap an ExecBatch generator in a RecordBatchReader.
///
/// The RecordBatchReader does not impose any ordering on emitted batches.
///
/// The parameters are unnamed in this declaration. In order they are: a schema
/// (presumably the schema of the emitted record batches), the ExecBatch generator to
/// wrap, and a memory pool (presumably used when converting ExecBatches into
/// RecordBatches -- TODO confirm against the implementation).
ARROW_ACERO_EXPORT
std::shared_ptr<RecordBatchReader> MakeGeneratorReader(
std::shared_ptr<Schema>, std::function<Future<std::optional<ExecBatch>>()>,
MemoryPool*);
/// \brief default value of the `max_q` argument of MakeReaderGenerator
constexpr int kDefaultBackgroundMaxQ = 32;
/// \brief default value of the `q_restart` argument of MakeReaderGenerator
constexpr int kDefaultBackgroundQRestart = 16;
/// \brief Make an ExecBatch generator from a RecordBatchReader
///
/// Useful as a source node for an Exec plan
///
/// \param reader the reader to pull record batches from
/// \param io_executor the executor passed in for reading (presumably the reader is
///                    consumed on this executor -- TODO confirm)
/// \param max_q defaults to kDefaultBackgroundMaxQ (presumably the maximum number of
///              batches queued in the background -- TODO confirm)
/// \param q_restart defaults to kDefaultBackgroundQRestart (presumably the queue size
///                  at which background reading resumes -- TODO confirm)
/// \return a function that, on each call, returns a future of an optional ExecBatch
ARROW_ACERO_EXPORT
Result<std::function<Future<std::optional<ExecBatch>>()>> MakeReaderGenerator(
std::shared_ptr<RecordBatchReader> reader, arrow::internal::Executor* io_executor,
int max_q = kDefaultBackgroundMaxQ, int q_restart = kDefaultBackgroundQRestart);
} // namespace acero
} // namespace arrow

View File

@@ -0,0 +1,75 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <functional>
#include <memory>
#include <vector>
#include "arrow/acero/accumulation_queue.h"
#include "arrow/acero/bloom_filter.h"
#include "arrow/acero/options.h"
#include "arrow/acero/query_context.h"
#include "arrow/acero/schema_util.h"
#include "arrow/acero/task_util.h"
#include "arrow/result.h"
#include "arrow/status.h"
#include "arrow/type.h"
#include "arrow/util/tracing.h"
namespace arrow {
namespace acero {
using util::AccumulationQueue;
/// \brief Abstract interface of a hash join implementation
///
/// All operations are pure virtual. Instances are obtained through the static
/// MakeBasic and MakeSwiss factory functions declared below.
class ARROW_ACERO_EXPORT HashJoinImpl {
public:
// Callback signatures exchanged with the owner of the join through Init and
// BuildHashTable. This header only fixes their signatures.
// NOTE(review): the meaning of the integer arguments (thread index, task group id,
// counts, ...) is defined by the callers/implementations -- confirm there.
using OutputBatchCallback = std::function<Status(int64_t, ExecBatch)>;
using BuildFinishedCallback = std::function<Status(size_t)>;
using FinishedCallback = std::function<Status(int64_t)>;
// Takes two task-group functions and returns an int (presumably the id that is later
// passed to StartTaskGroupCallback -- TODO confirm).
using RegisterTaskGroupCallback = std::function<int(
std::function<Status(size_t, int64_t)>, std::function<Status(size_t)>)>;
using StartTaskGroupCallback = std::function<Status(int, int64_t)>;
// Note: Abort() below is declared with TaskScheduler::AbortContinuationImpl rather
// than with this alias.
using AbortContinuationImpl = std::function<void()>;
virtual ~HashJoinImpl() = default;
/// \brief Initialize the join
///
/// Receives the query context, the join type, the number of threads, the projection
/// maps of the left and right inputs, the key comparison kinds, the filter expression
/// and the callbacks described above.
virtual Status Init(QueryContext* ctx, JoinType join_type, size_t num_threads,
const HashJoinProjectionMaps* proj_map_left,
const HashJoinProjectionMaps* proj_map_right,
std::vector<JoinKeyCmp> key_cmp, Expression filter,
RegisterTaskGroupCallback register_task_group_callback,
StartTaskGroupCallback start_task_group_callback,
OutputBatchCallback output_batch_callback,
FinishedCallback finished_callback) = 0;
/// \brief Build the hash table from the given accumulated batches
///
/// `on_finished` is presumably invoked once the build has completed -- TODO confirm.
virtual Status BuildHashTable(size_t thread_index, AccumulationQueue batches,
BuildFinishedCallback on_finished) = 0;
/// \brief Probe with a single batch on behalf of the given thread
virtual Status ProbeSingleBatch(size_t thread_index, ExecBatch batch) = 0;
/// \brief Signal that probing has finished
virtual Status ProbingFinished(size_t thread_index) = 0;
/// \brief Abort the join
///
/// `pos_abort_callback` is presumably run after the abort has taken effect -- TODO
/// confirm against the implementations.
virtual void Abort(TaskScheduler::AbortContinuationImpl pos_abort_callback) = 0;
/// \brief Return a string describing this implementation
virtual std::string ToString() const = 0;
/// \brief Create the "basic" hash join implementation
static Result<std::unique_ptr<HashJoinImpl>> MakeBasic();
/// \brief Create the "swiss" hash join implementation
static Result<std::unique_ptr<HashJoinImpl>> MakeSwiss();
protected:
// Tracing span available to implementations.
arrow::util::tracing::Span span_;
};
} // namespace acero
} // namespace arrow

View File

@@ -0,0 +1,318 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
#include <unordered_map>
#include "arrow/acero/schema_util.h"
#include "arrow/compute/exec.h"
#include "arrow/compute/row/row_encoder_internal.h"
#include "arrow/result.h"
#include "arrow/status.h"
#include "arrow/type.h"
// This file contains hash join logic related to handling of dictionary encoded key
// columns.
//
// A key column from probe side of the join can be matched against a key column from build
// side of the join, as long as the underlying value types are equal. That means that:
// - both scalars and arrays can be used and even mixed in the same column
// - dictionary column can be matched against non-dictionary column if underlying value
// types are equal
// - dictionary column can be matched against dictionary column with a different index
// type, and potentially using a different dictionary, if underlying value types are equal
//
// We currently require in hash join that for all dictionary encoded columns, the same
// dictionary is used in all input exec batches.
//
// In order to allow matching columns with different dictionaries, different dictionary
// index types, and dictionary key against non-dictionary key, internally comparisons will
// be evaluated after remapping values on both sides of the join to a common
// representation (which will be called "unified representation"). This common
// representation is a column of int32() type (not a dictionary column). It represents an
// index in the unified dictionary computed for the (only) dictionary present on build
// side (an empty dictionary is still created for an empty build side). A null value is
// always represented in this common representation as a null int32 value; the unified
// dictionary will never contain a null value (so there is no ambiguity between
// representing nulls as an index to a null entry in the dictionary or as a null index).
//
// Unified dictionary represents values present on build side. There may be values on
// probe side that are not present in it. All such values, that are not null, are mapped
// in the common representation to a special constant kMissingValueId.
//
namespace arrow {
using compute::ExecBatch;
using compute::ExecContext;
using compute::internal::RowEncoder;
namespace acero {
/// Helper class with operations that are stateless and common to processing of dictionary
/// keys on both build and probe side.
class HashJoinDictUtil {
public:
// Null values in the unified representation are always represented as a null whose
// corresponding integer is set to this constant.
static constexpr int32_t kNullId = 0;
// Constant representing, in the unified representation, a non-null value that is
// missing on the build side.
static constexpr int32_t kMissingValueId = -1;
// Check if the data types of a corresponding pair of key columns on build and probe
// side are compatible.
static bool KeyDataTypesValid(const std::shared_ptr<DataType>& probe_data_type,
const std::shared_ptr<DataType>& build_data_type);
// Input must be a dictionary array or a dictionary scalar.
// The precomputed lookup table passed in as `map_array` (an int32() array) is
// used to remap the input indices to the unified representation.
//
static Result<std::shared_ptr<ArrayData>> IndexRemapUsingLUT(
ExecContext* ctx, const Datum& indices, int64_t batch_length,
const std::shared_ptr<ArrayData>& map_array,
const std::shared_ptr<DataType>& data_type);
// Return an int32() array that contains the indices of the input dictionary array or
// scalar after type casting.
static Result<std::shared_ptr<ArrayData>> ConvertToInt32(
const std::shared_ptr<DataType>& from_type, const Datum& input,
int64_t batch_length, ExecContext* ctx);
// Return an array that contains the elements of the input int32() array after casting
// to a given integer type. This is used for mapping the unified representation stored
// in the hash table on build side back to the original input data type of hash join,
// when outputting hash join results to the parent exec node.
//
static Result<std::shared_ptr<ArrayData>> ConvertFromInt32(
const std::shared_ptr<DataType>& to_type, const Datum& input, int64_t batch_length,
ExecContext* ctx);
// Return the dictionary referenced in either a dictionary array or a dictionary scalar
static std::shared_ptr<Array> ExtractDictionary(const Datum& data);
};
/// Implements processing of dictionary arrays/scalars in key columns on the build side of
/// a hash join.
/// Each instance of this class corresponds to a single column and stores and
/// processes only the information related to that column.
/// Const methods are thread-safe, non-const methods are not (the caller must make sure
/// that only one thread at any time will access them).
///
class HashJoinDictBuild {
public:
// Returns true if the key column (described in input by its data type) requires any
// pre- or post-processing related to handling dictionaries, i.e. if its type id is
// Type::DICTIONARY.
//
static bool KeyNeedsProcessing(const std::shared_ptr<DataType>& build_data_type) {
return (build_data_type->id() == Type::DICTIONARY);
}
// Data type of the unified representation (int32())
static std::shared_ptr<DataType> DataTypeAfterRemapping() { return int32(); }
// Should be called only once in hash join, before processing any build or probe
// batches.
//
// Takes a pointer to the dictionary for a corresponding key column on the build side as
// an input. If the build side is empty, it still needs to be called, but with
// dictionary pointer set to null.
//
// Currently it is required that all input batches on build side share the same
// dictionary. For each input batch during its pre-processing, the dictionary will be
// checked and an error will be returned if it is different than the one provided in
// the call to this method.
//
// Unifies the dictionary. The order of the values is still preserved.
// Null and duplicate entries are removed. If the dictionary is already unified, its
// copy will be produced and stored within this class.
//
// Prepares the mapping from ids within original dictionary to the ids in the resulting
// dictionary. This is used later on to pre-process (map to unified representation) key
// column on build side.
//
// Prepares the reverse mapping (in the form of hash table) from values to the ids in
// the resulting dictionary. This will be used later on to pre-process (map to unified
// representation) key column on probe side. Values on probe side that are not present
// in the original dictionary will be mapped to a special constant kMissingValueId. The
// exception is made for nulls, which always get mapped to nulls (both when null is
// represented as a dictionary id pointing to a null and a null dictionary id).
//
Status Init(ExecContext* ctx, std::shared_ptr<Array> dictionary,
std::shared_ptr<DataType> index_type, std::shared_ptr<DataType> value_type);
// Remap array or scalar values into unified representation (array of int32()).
// Outputs kMissingValueId if input value is not found in the unified dictionary.
// Outputs null for null input value (with corresponding data set to kNullId).
//
Result<std::shared_ptr<ArrayData>> RemapInputValues(ExecContext* ctx,
const Datum& values,
int64_t batch_length) const;
// Remap dictionary array or dictionary scalar on build side to unified representation.
// Dictionary referenced in the input must match the dictionary that was
// given during initialization.
// The output is a dictionary array that references unified dictionary.
// NOTE(review): the file-level comment describes the unified representation as a
// plain int32() column rather than a dictionary column -- confirm which of the two
// this method actually returns.
//
Result<std::shared_ptr<ArrayData>> RemapInput(
ExecContext* ctx, const Datum& indices, int64_t batch_length,
const std::shared_ptr<DataType>& data_type) const;
// Outputs dictionary array referencing unified dictionary, given an array with 32-bit
// ids.
// Used to post-process values looked up in a hash table on build side of the hash join
// before outputting to the parent exec node.
//
Result<std::shared_ptr<ArrayData>> RemapOutput(const ArrayData& indices32Bit,
ExecContext* ctx) const;
// Release shared pointers and memory
void CleanUp();
private:
// Data type of dictionary ids for the input dictionary on build side
std::shared_ptr<DataType> index_type_;
// Data type of values for the input dictionary on build side
std::shared_ptr<DataType> value_type_;
// Mapping from (encoded as string) values to the ids in unified dictionary
std::unordered_map<std::string, int32_t> hash_table_;
// Mapping from input dictionary ids to unified dictionary ids
std::shared_ptr<ArrayData> remapped_ids_;
// Input dictionary
std::shared_ptr<Array> dictionary_;
// Unified dictionary
std::shared_ptr<ArrayData> unified_dictionary_;
};
/// Implements processing of dictionary arrays/scalars in key columns on the probe side of
/// a hash join.
/// Each instance of this class corresponds to a single column and stores and
/// processes only the information related to that column.
/// It is not thread-safe - every participating thread should use its own instance of
/// this class.
///
class HashJoinDictProbe {
public:
// Returns whether a pair of corresponding key columns from probe and build side
// (described by their data types) requires remapping through RemapInput.
//
static bool KeyNeedsProcessing(const std::shared_ptr<DataType>& probe_data_type,
const std::shared_ptr<DataType>& build_data_type);
// Data type of the result of remapping input key column.
//
// The result of remapping is what is used in hash join for matching keys on build and
// probe side. The exact data types may be different, as described below, and therefore
// a common representation is needed for simplifying comparisons of pairs of keys on
// both sides.
//
// We support matching a key that is of non-dictionary type with a key that is of
// dictionary type, as long as the underlying value types are equal. We support matching
// when both keys are of dictionary type, regardless of whether the underlying
// dictionary index types are the same or not.
//
static std::shared_ptr<DataType> DataTypeAfterRemapping(
const std::shared_ptr<DataType>& build_data_type);
// Should only be called if KeyNeedsProcessing method returns true for a pair of
// corresponding key columns from build and probe side.
// Converts values in order to match the common representation for
// both build and probe side used in hash table comparison.
// Supports arrays and scalars as input.
// Argument opt_build_side should be null if dictionary key on probe side is matched
// with non-dictionary key on build side.
//
Result<std::shared_ptr<ArrayData>> RemapInput(
const HashJoinDictBuild* opt_build_side, const Datum& data, int64_t batch_length,
const std::shared_ptr<DataType>& probe_data_type,
const std::shared_ptr<DataType>& build_data_type, ExecContext* ctx);
// Presumably releases the state held by this instance (dictionary, remapped ids,
// encoder) -- TODO confirm against the implementation.
void CleanUp();
private:
// May be null if probe side key is non-dictionary. Otherwise it is used to verify that
// only a single dictionary is referenced in exec batch on probe side of hash join.
std::shared_ptr<Array> dictionary_;
// Mapping from dictionary on probe side of hash join (if it is used) to unified
// representation.
std::shared_ptr<ArrayData> remapped_ids_;
// Encoder of key columns that uses unified representation instead of original data type
// for key columns that need to use it (have dictionaries on either side of the join).
RowEncoder encoder_;
};
// Encapsulates dictionary handling logic for build side of hash join.
// Holds one HashJoinDictBuild per column (remap_imp_) together with a per-column flag
// telling whether remapping is needed (needs_remap_).
//
class HashJoinDictBuildMulti {
public:
// Initializes the per-column state from the projection maps.
// `opt_non_empty_batch` may presumably be null when the build side is empty -- TODO
// confirm against the implementation.
Status Init(const SchemaProjectionMaps<HashJoinProjection>& proj_map,
const ExecBatch* opt_non_empty_batch, ExecContext* ctx);
// Initializes `encoder` for the key columns described by `proj_map`.
static void InitEncoder(const SchemaProjectionMaps<HashJoinProjection>& proj_map,
RowEncoder* encoder, ExecContext* ctx);
// Encodes the key columns of `batch` into `encoder` (presumably after remapping
// dictionary columns to the unified representation -- TODO confirm).
Status EncodeBatch(size_t thread_index,
const SchemaProjectionMaps<HashJoinProjection>& proj_map,
const ExecBatch& batch, RowEncoder* encoder, ExecContext* ctx) const;
// Post-processes a decoded key batch in place (presumably mapping unified ids back to
// dictionary arrays -- TODO confirm).
Status PostDecode(const SchemaProjectionMaps<HashJoinProjection>& proj_map,
ExecBatch* decoded_key_batch, ExecContext* ctx);
// Returns the per-column build state for column `icol`. No bounds check is performed.
const HashJoinDictBuild& get_dict_build(int icol) const { return remap_imp_[icol]; }
private:
// Per-column flag: whether the column needs remapping.
std::vector<bool> needs_remap_;
// Per-column dictionary build state.
std::vector<HashJoinDictBuild> remap_imp_;
};
// Encapsulates dictionary handling logic for probe side of hash join
//
// State is kept in a vector of ThreadLocalState objects (local_states_); the
// public methods take a thread_index argument.
// NOTE(review): presumably thread_index selects the entry of local_states_ that
// the call works on, so concurrent calls must use distinct indices -- confirm.
//
class HashJoinDictProbeMulti {
public:
// Prepares this object for use by num_threads threads.
void Init(size_t num_threads);
// Reports whether batches need remapping before hash table lookups.
// Takes both probe-side and build-side projection maps.
bool BatchRemapNeeded(size_t thread_index,
const SchemaProjectionMaps<HashJoinProjection>& proj_map_probe,
const SchemaProjectionMaps<HashJoinProjection>& proj_map_build,
ExecContext* ctx);
// Encodes the key columns of `batch`.
// out_encoder is an output parameter (pointer to pointer): the callee chooses
// which RowEncoder the caller should use afterwards.
// opt_out_key_batch is passed by pointer; judging by the "opt_" prefix it may be
// null -- TODO confirm.
Status EncodeBatch(size_t thread_index,
const SchemaProjectionMaps<HashJoinProjection>& proj_map_probe,
const SchemaProjectionMaps<HashJoinProjection>& proj_map_build,
const HashJoinDictBuildMulti& dict_build, const ExecBatch& batch,
RowEncoder** out_encoder, ExecBatch* opt_out_key_batch,
ExecContext* ctx);
private:
// Lazily sets up the state associated with thread_index (going by the name and
// the ThreadLocalState::is_initialized flag -- TODO confirm).
void InitLocalStateIfNeeded(
size_t thread_index, const SchemaProjectionMaps<HashJoinProjection>& proj_map_probe,
const SchemaProjectionMaps<HashJoinProjection>& proj_map_build, ExecContext* ctx);
// Static helper operating on a caller-supplied encoder.
static void InitEncoder(const SchemaProjectionMaps<HashJoinProjection>& proj_map_probe,
const SchemaProjectionMaps<HashJoinProjection>& proj_map_build,
RowEncoder* encoder, ExecContext* ctx);
struct ThreadLocalState {
// NOTE: neither bool below has a default member initializer, so each must be
// assigned before it is first read.
bool is_initialized;
// Whether any key column needs remapping (because of dictionaries used) before doing
// join hash table lookups
bool any_needs_remap;
// Whether each key column needs remapping before doing join hash table lookups
std::vector<bool> needs_remap;
// Probe-side dictionary handling objects, kept alongside needs_remap.
std::vector<HashJoinDictProbe> remap_imp;
// Encoder of key columns that uses unified representation instead of original data
// type for key columns that need to use it (have dictionaries on either side of the
// join).
RowEncoder post_remap_encoder;
};
// Storage for the ThreadLocalState objects.
std::vector<ThreadLocalState> local_states_;
};
} // namespace acero
} // namespace arrow

View File

@@ -0,0 +1,103 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <cassert>
#include <vector>
#include "arrow/acero/options.h"
#include "arrow/acero/schema_util.h"
#include "arrow/result.h"
#include "arrow/status.h"
namespace arrow {
using compute::ExecContext;
namespace acero {
/// \brief Schema bookkeeping for a hash join
///
/// Holds one SchemaProjectionMaps<HashJoinProjection> per join input in
/// `proj_maps`. Index 0 is the left input and index 1 is the right input (see
/// LeftPayloadIsEmpty / RightPayloadIsEmpty).
class ARROW_ACERO_EXPORT HashJoinSchema {
public:
/// \brief Initialize from schemas and keys, without explicit output field lists
Status Init(JoinType join_type, const Schema& left_schema,
const std::vector<FieldRef>& left_keys, const Schema& right_schema,
const std::vector<FieldRef>& right_keys, const Expression& filter,
const std::string& left_field_name_prefix,
const std::string& right_field_name_prefix);
/// \brief Initialize from schemas, keys and explicit output field lists
Status Init(JoinType join_type, const Schema& left_schema,
const std::vector<FieldRef>& left_keys,
const std::vector<FieldRef>& left_output, const Schema& right_schema,
const std::vector<FieldRef>& right_keys,
const std::vector<FieldRef>& right_output, const Expression& filter,
const std::string& left_field_name_prefix,
const std::string& right_field_name_prefix);
/// \brief Validate a combination of schemas, keys and output fields
///
/// Static, so it can be called without a HashJoinSchema instance. Problems are
/// reported through the returned Status.
static Status ValidateSchemas(JoinType join_type, const Schema& left_schema,
const std::vector<FieldRef>& left_keys,
const std::vector<FieldRef>& left_output,
const Schema& right_schema,
const std::vector<FieldRef>& right_keys,
const std::vector<FieldRef>& right_output,
const std::string& left_field_name_prefix,
const std::string& right_field_name_prefix);
bool HasDictionaries() const;
bool HasLargeBinary() const;
/// \brief Bind a filter expression against the left and right schemas
///
/// Returns the bound expression or an error.
Result<Expression> BindFilter(Expression filter, const Schema& left_schema,
const Schema& right_schema, ExecContext* exec_context);
/// \brief Build the output schema using the given field name suffixes
std::shared_ptr<Schema> MakeOutputSchema(const std::string& left_field_name_suffix,
const std::string& right_field_name_suffix);
/// \brief True if the left input (side 0) has no PAYLOAD columns
bool LeftPayloadIsEmpty() const { return PayloadIsEmpty(0); }
/// \brief True if the right input (side 1) has no PAYLOAD columns
bool RightPayloadIsEmpty() const { return PayloadIsEmpty(1); }
/// \brief Forwards SchemaProjectionMaps<HashJoinProjection>::kMissingField
static int kMissingField() {
return SchemaProjectionMaps<HashJoinProjection>::kMissingField;
}
/// \brief Projection maps for each join input: [0] is left, [1] is right
SchemaProjectionMaps<HashJoinProjection> proj_maps[2];
private:
static bool IsTypeSupported(const DataType& type);
// left_filter and right_filter are passed by non-const reference, i.e. they are
// output parameters.
Status CollectFilterColumns(std::vector<FieldRef>& left_filter,
std::vector<FieldRef>& right_filter,
const Expression& filter, const Schema& left_schema,
const Schema& right_schema);
Expression RewriteFilterToUseFilterSchema(int right_filter_offset,
const SchemaProjectionMap& left_to_filter,
const SchemaProjectionMap& right_to_filter,
const Expression& filter);
// Returns true when the given side has zero PAYLOAD columns.
// `side` must be 0 (left) or 1 (right). This is only checked by assert(), which
// is compiled out when NDEBUG is defined, so an invalid side then indexes
// proj_maps out of bounds.
bool PayloadIsEmpty(int side) const {
assert(side == 0 || side == 1);
return proj_maps[side].num_cols(HashJoinProjection::PAYLOAD) == 0;
}
// Static helper that derives a list of payload fields from the output, filter
// and key field lists of one schema; returns the list or an error.
static Result<std::vector<FieldRef>> ComputePayload(const Schema& schema,
const std::vector<FieldRef>& output,
const std::vector<FieldRef>& filter,
const std::vector<FieldRef>& key);
};
} // namespace acero
} // namespace arrow

View File

@@ -0,0 +1,81 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <cstdint>
#include <functional>
#include <memory>
#include <vector>
#include "arrow/acero/exec_plan.h"
#include "arrow/acero/util.h"
#include "arrow/acero/visibility.h"
#include "arrow/compute/type_fwd.h"
#include "arrow/status.h"
#include "arrow/type_fwd.h"
#include "arrow/util/cancel.h"
#include "arrow/util/type_fwd.h"
namespace arrow {
namespace acero {
/// A utility base class for simple exec nodes with one input
///
/// Pause/Resume Producing are forwarded appropriately
/// There is nothing to do in StopProducingImpl
///
/// An AtomicCounter is used to keep track of when all data has arrived. When it
/// has the Finish() method will be invoked
///
/// Subclasses must implement ProcessBatch (it is pure virtual) and may override
/// Finish.
class ARROW_ACERO_EXPORT MapNode : public ExecNode, public TracedNode {
public:
/// Construct a node belonging to `plan` with the given inputs and output schema
///
/// NOTE(review): `inputs` is a vector although the class is documented as having
/// one input; presumably it must contain exactly one node -- confirm.
MapNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
std::shared_ptr<Schema> output_schema);
Status InputFinished(ExecNode* input, int total_batches) override;
Status StartProducing() override;
void PauseProducing(ExecNode* output, int32_t counter) override;
void ResumeProducing(ExecNode* output, int32_t counter) override;
Status InputReceived(ExecNode* input, ExecBatch batch) override;
const Ordering& ordering() const override;
protected:
Status StopProducingImpl() override;
/// Transform a batch
///
/// The output batch will have the same guarantee as the input batch
/// If this was the last batch this call may trigger Finish()
///
/// Returns the transformed batch or an error.
virtual Result<ExecBatch> ProcessBatch(ExecBatch batch) = 0;
/// Function called after all data has been received
///
/// By default this does nothing. Override this to provide a custom implementation.
virtual void Finish();
protected:
// Counter for the number of batches received
AtomicCounter input_counter_;
};
} // namespace acero
} // namespace arrow

View File

@@ -0,0 +1,869 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <vector>
#include "arrow/acero/type_fwd.h"
#include "arrow/acero/visibility.h"
#include "arrow/compute/api_aggregate.h"
#include "arrow/compute/api_vector.h"
#include "arrow/compute/exec.h"
#include "arrow/compute/expression.h"
#include "arrow/result.h"
#include "arrow/util/future.h"
namespace arrow {
using compute::Aggregate;
using compute::ExecBatch;
using compute::Expression;
using compute::literal;
using compute::Ordering;
using compute::SelectKOptions;
using compute::SortOptions;
namespace internal {
class Executor;
} // namespace internal
namespace acero {
/// \brief This must not be used in release-mode
///
/// Only forward-declared here; the definition lives elsewhere.
struct DebugOptions;
/// \brief A callable producing batches asynchronously
///
/// Each invocation returns a Future that resolves to an optional ExecBatch.
using AsyncExecBatchGenerator = std::function<Future<std::optional<ExecBatch>>()>;
/// \addtogroup acero-nodes
/// @{
/// \brief A base class for all options objects
///
/// The only time this is used directly is when a node has no configuration
///
/// The destructor is virtual so that derived options objects can be destroyed
/// through a pointer to this base class.
class ARROW_ACERO_EXPORT ExecNodeOptions {
public:
virtual ~ExecNodeOptions() = default;
/// \brief This must not be used in release-mode
///
/// Default-constructed, i.e. null, unless a caller assigns it.
std::shared_ptr<DebugOptions> debug_opts;
};
/// \brief Options for a node that acts as a generic source of data for Acero
///
/// During StartProducing the source node creates an initial task that starts
/// calling `generator`. The generator is never called reentrantly; if the source
/// can be read in parallel, that parallelism should be encapsulated within
/// `generator` itself.
///
/// A new task is created for every batch received, to push that batch downstream.
/// The task slices the parent batch into units of size `ExecPlan::kMaxBatchSize`
/// and calls InputReceived for each slice, so a single large batch yielded by
/// `generator` may result in several calls to InputReceived.
///
/// The SourceNode will, by default, assign an implicit ordering to outgoing
/// batches. This is valid as long as the generator generates batches in a
/// deterministic fashion. Currently, the only way to override this is to subclass
/// the SourceNode. (Note that the `ordering` constructor argument of this options
/// class defaults to Ordering::Unordered().)
///
/// This node is not generally used directly but can serve as the basis for
/// various specialized nodes.
class ARROW_ACERO_EXPORT SourceNodeOptions : public ExecNodeOptions {
 public:
  /// \brief Create an instance from values
  ///
  /// \param output_schema the schema of the batches `generator` will yield
  /// \param generator an asynchronous stream of batches ending with std::nullopt
  /// \param ordering the order of the data; Ordering::Unordered() if omitted
  SourceNodeOptions(std::shared_ptr<Schema> output_schema,
                    AsyncExecBatchGenerator generator,
                    Ordering ordering = Ordering::Unordered())
      : output_schema(std::move(output_schema)),
        generator(std::move(generator)),
        ordering(std::move(ordering)) {}

  /// \brief the schema for batches that will be generated by this source
  std::shared_ptr<Schema> output_schema;

  /// \brief an asynchronous stream of batches ending with std::nullopt
  AsyncExecBatchGenerator generator;

  /// \brief the order of the data, defaults to Ordering::Unordered
  Ordering ordering;
};
/// \brief a node that generates data from a table already loaded in memory
///
/// The table source node will slice off chunks, defined by `max_batch_size`
/// for parallel processing. The table source node extends source node and so these
/// chunks will be iteratively processed in small batches. \see SourceNodeOptions
/// for details.
class ARROW_ACERO_EXPORT TableSourceNodeOptions : public ExecNodeOptions {
public:
/// \brief the batch size used when none is given to the constructor (1 << 20)
static constexpr int64_t kDefaultMaxBatchSize = 1 << 20;
/// Create an instance from values
///
/// \param table the table which acts as the data source
/// \param max_batch_size size of the batches to emit; kDefaultMaxBatchSize if
///        omitted
TableSourceNodeOptions(std::shared_ptr<Table> table,
int64_t max_batch_size = kDefaultMaxBatchSize)
: table(std::move(table)), max_batch_size(max_batch_size) {}
/// \brief a table which acts as the data source
std::shared_ptr<Table> table;
/// \brief size of batches to emit from this node
/// If the table is larger the node will emit multiple batches from
/// the table to be processed in parallel.
int64_t max_batch_size;
};
/// \brief define a lazily resolved Arrow table.
///
/// The table uniquely identified by the names can typically be resolved at the time when
/// the plan is to be consumed.
///
/// This node is for serialization purposes only and can never be executed.
class ARROW_ACERO_EXPORT NamedTableNodeOptions : public ExecNodeOptions {
public:
/// Create an instance from values
///
/// \param names the names identifying the table
/// \param schema the output schema of the table
NamedTableNodeOptions(std::vector<std::string> names, std::shared_ptr<Schema> schema)
: names(std::move(names)), schema(std::move(schema)) {}
/// \brief the names to put in the serialized plan
std::vector<std::string> names;
/// \brief the output schema of the table
std::shared_ptr<Schema> schema;
};
/// \brief a source node which feeds data from a synchronous iterator of batches
///
/// ItMaker is a maker of an iterator of tabular data.
///
/// The node can be configured to use an I/O executor. If set then each time the
/// iterator is polled a new I/O thread task will be created to do the polling. This
/// allows a blocking iterator to stay off the CPU thread pool.
///
/// \tparam ItMaker the type of the callable that creates the iterator
template <typename ItMaker>
class ARROW_ACERO_EXPORT SchemaSourceNodeOptions : public ExecNodeOptions {
public:
/// Create an instance that will create a new task on io_executor for each iteration
///
/// This overload always sets requires_io to true.
SchemaSourceNodeOptions(std::shared_ptr<Schema> schema, ItMaker it_maker,
arrow::internal::Executor* io_executor)
: schema(std::move(schema)),
it_maker(std::move(it_maker)),
io_executor(io_executor),
requires_io(true) {}
/// Create an instance that will either iterate synchronously or use the default I/O
/// executor
///
/// This overload always sets io_executor to null; requires_io defaults to false.
SchemaSourceNodeOptions(std::shared_ptr<Schema> schema, ItMaker it_maker,
bool requires_io = false)
: schema(std::move(schema)),
it_maker(std::move(it_maker)),
io_executor(NULLPTR),
requires_io(requires_io) {}
/// \brief The schema of the record batches from the iterator
std::shared_ptr<Schema> schema;
/// \brief A maker of an iterator which acts as the data source
ItMaker it_maker;
/// \brief The executor to use for scanning the iterator
///
/// Defaults to the default I/O executor. Only used if requires_io is true.
/// If requires_io is false then this MUST be nullptr.
arrow::internal::Executor* io_executor;
/// \brief If true then items will be fetched from the iterator on a dedicated I/O
/// thread to keep I/O off the CPU thread
bool requires_io;
};
/// \brief a source node that reads from a RecordBatchReader
///
/// Each iteration of the RecordBatchReader will be run on a new thread task created
/// on the I/O thread pool.
class ARROW_ACERO_EXPORT RecordBatchReaderSourceNodeOptions : public ExecNodeOptions {
public:
/// Create an instance from values
///
/// \param reader the RecordBatchReader which acts as the data source
/// \param io_executor the executor to use for the reader; null if omitted
RecordBatchReaderSourceNodeOptions(std::shared_ptr<RecordBatchReader> reader,
arrow::internal::Executor* io_executor = NULLPTR)
: reader(std::move(reader)), io_executor(io_executor) {}
/// \brief The RecordBatchReader which acts as the data source
std::shared_ptr<RecordBatchReader> reader;
/// \brief The executor to use for the reader
///
/// Defaults to the default I/O executor.
arrow::internal::Executor* io_executor;
};
/// \brief a callable that creates an iterator of array vectors
using ArrayVectorIteratorMaker = std::function<Iterator<std::shared_ptr<ArrayVector>>()>;
/// \brief An extended Source node which accepts a schema and array-vectors
///
/// Adds no members of its own; it only inherits the constructors of
/// SchemaSourceNodeOptions<ArrayVectorIteratorMaker>.
class ARROW_ACERO_EXPORT ArrayVectorSourceNodeOptions
: public SchemaSourceNodeOptions<ArrayVectorIteratorMaker> {
// Inherited constructors keep the accessibility they have in the base class, so
// they remain public even though this using-declaration is in the private section.
using SchemaSourceNodeOptions::SchemaSourceNodeOptions;
};
/// \brief a callable that creates an iterator of ExecBatch
using ExecBatchIteratorMaker = std::function<Iterator<std::shared_ptr<ExecBatch>>()>;
/// \brief An extended Source node which accepts a schema and exec-batches
///
/// Besides the constructors inherited from
/// SchemaSourceNodeOptions<ExecBatchIteratorMaker>, two constructors taking a
/// vector of batches are declared here and defined out of line.
class ARROW_ACERO_EXPORT ExecBatchSourceNodeOptions
: public SchemaSourceNodeOptions<ExecBatchIteratorMaker> {
public:
using SchemaSourceNodeOptions::SchemaSourceNodeOptions;
/// \brief Create an instance from a vector of batches and an explicit I/O executor
ExecBatchSourceNodeOptions(std::shared_ptr<Schema> schema,
std::vector<ExecBatch> batches,
::arrow::internal::Executor* io_executor);
/// \brief Create an instance from a vector of batches
///
/// requires_io defaults to false.
ExecBatchSourceNodeOptions(std::shared_ptr<Schema> schema,
std::vector<ExecBatch> batches, bool requires_io = false);
};
/// \brief a callable that creates an iterator of RecordBatch
using RecordBatchIteratorMaker = std::function<Iterator<std::shared_ptr<RecordBatch>>()>;
/// \brief a source node that reads from an iterator of RecordBatch
///
/// Adds no members of its own; it only inherits the constructors of
/// SchemaSourceNodeOptions<RecordBatchIteratorMaker>.
class ARROW_ACERO_EXPORT RecordBatchSourceNodeOptions
: public SchemaSourceNodeOptions<RecordBatchIteratorMaker> {
// Inherited constructors keep the accessibility they have in the base class, so
// they remain public even though this using-declaration is in the private section.
using SchemaSourceNodeOptions::SchemaSourceNodeOptions;
};
/// \brief a node which excludes some rows from batches passed through it
///
/// filter_expression will be evaluated against each batch which is pushed to
/// this node. Any rows for which filter_expression does not evaluate to `true` will be
/// excluded in the batch emitted by this node.
///
/// This node will emit empty batches if all rows are excluded. This is done
/// to avoid gaps in the ordering.
class ARROW_ACERO_EXPORT FilterNodeOptions : public ExecNodeOptions {
public:
/// \brief create an instance from values
///
/// The constructor is explicit, so an Expression does not implicitly convert to
/// FilterNodeOptions.
///
/// \param filter_expression the expression used to filter batches
explicit FilterNodeOptions(Expression filter_expression)
: filter_expression(std::move(filter_expression)) {}
/// \brief the expression to filter batches
///
/// The return type of this expression must be boolean
Expression filter_expression;
};
/// \brief a node which selects a specified subset from the input
///
/// The subset is described by `offset` (rows to skip) and `count` (rows to keep).
class ARROW_ACERO_EXPORT FetchNodeOptions : public ExecNodeOptions {
public:
/// \brief the string "fetch"
///
/// NOTE(review): presumably the name under which this node type is registered --
/// confirm against the node factory registration.
static constexpr std::string_view kName = "fetch";
/// \brief create an instance from values
///
/// \param offset the number of rows to skip
/// \param count the number of rows to keep (not counting skipped rows)
FetchNodeOptions(int64_t offset, int64_t count) : offset(offset), count(count) {}
/// \brief the number of rows to skip
int64_t offset;
/// \brief the number of rows to keep (not counting skipped rows)
int64_t count;
};
/// \brief a node which executes expressions on input batches, producing batches
/// of the same length with new columns.
///
/// Each expression will be evaluated against each batch which is pushed to
/// this node to produce a corresponding output column.
///
/// If names are not provided, the string representations of exprs will be used.
class ARROW_ACERO_EXPORT ProjectNodeOptions : public ExecNodeOptions {
public:
/// \brief create an instance from values
///
/// \param expressions the expressions to run on the batches
/// \param names the names of the output columns; empty if omitted
explicit ProjectNodeOptions(std::vector<Expression> expressions,
std::vector<std::string> names = {})
: expressions(std::move(expressions)), names(std::move(names)) {}
/// \brief the expressions to run on the batches
///
/// The output will have one column for each expression. If you wish to keep any of
/// the columns from the input then you should create a simple field_ref expression
/// for that column.
std::vector<Expression> expressions;
/// \brief the names of the output columns
///
/// If this is not specified then the result of calling ToString on the expression will
/// be used instead
///
/// This list should either be empty or have the same length as `expressions`.
/// The constructor does not check this.
std::vector<std::string> names;
};
/// \brief a node which aggregates input batches and calculates summary statistics
///
/// The node can summarize the entire input or it can group the input with grouping keys
/// and segment keys.
///
/// By default, the aggregate node is a pipeline breaker. It must accumulate all input
/// before any output is produced. Segment keys are a performance optimization. If
/// you know your input is already partitioned by one or more columns then you can
/// specify these as segment keys. At each change in the segment keys the node will
/// emit values for all data seen so far.
///
/// Segment keys are currently limited to single-threaded mode.
///
/// Both keys and segment-keys determine the group. However segment-keys are also used
/// for determining grouping segments, which should be large, and allow streaming a
/// partial aggregation result after processing each segment. One common use-case for
/// segment-keys is ordered aggregation, in which the segment-key attribute specifies a
/// column with non-decreasing values or a lexicographically-ordered set of such columns.
///
/// If the keys attribute is a non-empty vector, then each aggregate in `aggregates` is
/// expected to be a HashAggregate function. If the keys attribute is an empty vector,
/// then each aggregate is assumed to be a ScalarAggregate function.
///
/// If the segment_keys attribute is a non-empty vector, then segmented aggregation, as
/// described above, applies.
///
/// The keys and segment_keys vectors must be disjoint.
///
/// If no measures are provided then you will simply get the list of unique keys.
///
/// This node outputs segment keys first, followed by regular keys, followed by one
/// column for each aggregate.
class ARROW_ACERO_EXPORT AggregateNodeOptions : public ExecNodeOptions {
public:
/// \brief create an instance from values
///
/// \param aggregates the aggregations to apply
/// \param keys the grouping keys; empty if omitted
/// \param segment_keys the segment keys; empty if omitted
explicit AggregateNodeOptions(std::vector<Aggregate> aggregates,
std::vector<FieldRef> keys = {},
std::vector<FieldRef> segment_keys = {})
: aggregates(std::move(aggregates)),
keys(std::move(keys)),
segment_keys(std::move(segment_keys)) {}
/// \brief aggregations which will be applied to the targeted fields
std::vector<Aggregate> aggregates;
/// \brief keys by which aggregations will be grouped (optional)
std::vector<FieldRef> keys;
/// \brief keys by which aggregations will be segmented (optional)
std::vector<FieldRef> segment_keys;
};
/// \brief a default value, in bytes, at which backpressure will be applied
///
/// Passed as `pause_if_above` by BackpressureOptions::DefaultBackpressure().
constexpr int32_t kDefaultBackpressureHighBytes = 1 << 30; // 1GiB
/// \brief a default value, in bytes, at which backpressure will be removed
///
/// Passed as `resume_if_below` by BackpressureOptions::DefaultBackpressure().
constexpr int32_t kDefaultBackpressureLowBytes = 1 << 28; // 256MiB
/// \brief an interface that can be queried for backpressure statistics
///
/// Both query methods are pure virtual and non-const.
class ARROW_ACERO_EXPORT BackpressureMonitor {
public:
virtual ~BackpressureMonitor() = default;
/// \brief fetches the number of bytes currently queued up
virtual uint64_t bytes_in_use() = 0;
/// \brief checks to see if backpressure is currently applied
virtual bool is_paused() = 0;
};
/// \brief Options to control backpressure behavior
///
/// Backpressure is described by two byte thresholds: the producer pauses once
/// more than `pause_if_above` bytes are queued and resumes once fewer than
/// `resume_if_below` bytes are queued. A `pause_if_above` of zero disables
/// backpressure.
struct ARROW_ACERO_EXPORT BackpressureOptions {
  /// \brief Create default options that perform no backpressure
  ///
  /// Both thresholds are set to zero.
  BackpressureOptions() : BackpressureOptions(0, 0) {}

  /// \brief Create options that will perform backpressure
  ///
  /// \param resume_if_below The producer should resume producing if the
  ///                        backpressure queue holds fewer than this many bytes
  /// \param pause_if_above The producer should pause producing if the
  ///                       backpressure queue holds more than this many bytes
  BackpressureOptions(uint64_t resume_if_below, uint64_t pause_if_above)
      : resume_if_below(resume_if_below), pause_if_above(pause_if_above) {}

  /// \brief create an instance using default values for backpressure limits
  ///
  /// Uses kDefaultBackpressureLowBytes as the resume threshold and
  /// kDefaultBackpressureHighBytes as the pause threshold.
  static BackpressureOptions DefaultBackpressure() {
    BackpressureOptions defaults(kDefaultBackpressureLowBytes,
                                 kDefaultBackpressureHighBytes);
    return defaults;
  }

  /// \brief helper method to determine if backpressure is enabled
  /// \return true if pause_if_above is greater than zero, false otherwise
  bool should_apply_backpressure() const {
    // pause_if_above is unsigned, so "greater than zero" means "non-zero".
    return pause_if_above != 0;
  }

  /// \brief the number of bytes at which the producer should resume producing
  uint64_t resume_if_below;

  /// \brief the number of bytes at which the producer should pause producing
  ///
  /// If this is zero then backpressure will be disabled. The field is unsigned,
  /// so it can never be negative.
  uint64_t pause_if_above;
};
/// \brief a sink node which collects results in a queue
///
/// Emitted batches will only be ordered if there is a meaningful ordering
/// and sequence_output is not set to false.
class ARROW_ACERO_EXPORT SinkNodeOptions : public ExecNodeOptions {
 public:
  /// \brief create an instance that also reports the schema of the output
  ///
  /// \param generator out-parameter receiving the generator used to consume data
  /// \param schema out-parameter receiving the schema of the generated batches
  /// \param backpressure when to apply backpressure; none if omitted
  /// \param backpressure_monitor optional out-parameter receiving a monitor
  /// \param sequence_output whether batches should be sequenced in order
  explicit SinkNodeOptions(AsyncExecBatchGenerator* generator,
                           std::shared_ptr<Schema>* schema,
                           BackpressureOptions backpressure = {},
                           BackpressureMonitor** backpressure_monitor = NULLPTR,
                           std::optional<bool> sequence_output = std::nullopt)
      : generator(generator),
        schema(schema),
        backpressure(backpressure),
        backpressure_monitor(backpressure_monitor),
        sequence_output(sequence_output) {}

  /// \brief create an instance that does not report the schema of the output
  ///
  /// Equivalent to the overload above with a null `schema` pointer.
  explicit SinkNodeOptions(AsyncExecBatchGenerator* generator,
                           BackpressureOptions backpressure = {},
                           BackpressureMonitor** backpressure_monitor = NULLPTR,
                           std::optional<bool> sequence_output = std::nullopt)
      : SinkNodeOptions(generator, /*schema=*/NULLPTR, std::move(backpressure),
                        backpressure_monitor, sequence_output) {}

  /// \brief A pointer to a generator of batches.
  ///
  /// This will be set when the node is added to the plan and should be used to
  /// consume data from the plan. If this function is not called frequently enough
  /// then the sink node will start to accumulate data and may apply backpressure.
  AsyncExecBatchGenerator* generator;

  /// \brief A pointer which will be set to the schema of the generated batches
  ///
  /// This is optional, if nullptr is passed in then it will be ignored.
  /// This will be set when the node is added to the plan, before StartProducing
  /// is called
  std::shared_ptr<Schema>* schema;

  /// \brief Options to control when to apply backpressure
  ///
  /// This is optional, the default is to never apply backpressure. If the plan is
  /// not consumed quickly enough the system may eventually run out of memory.
  BackpressureOptions backpressure;

  /// \brief A pointer to a backpressure monitor
  ///
  /// This will be set when the node is added to the plan. This can be used to
  /// inspect the amount of data currently queued in the sink node. This is an
  /// optional utility and backpressure can be applied even if this is not used.
  BackpressureMonitor** backpressure_monitor;

  /// \brief Controls whether batches should be emitted immediately or sequenced
  /// in order
  ///
  /// \see QueryOptions for more details
  std::optional<bool> sequence_output;
};
/// \brief Control used by a SinkNodeConsumer to pause & resume
///
/// Callers should ensure that they do not call Pause and Resume simultaneously and they
/// should sequence things so that a call to Pause() is always followed by an eventual
/// call to Resume()
///
/// A pointer to an object of this type is handed to SinkNodeConsumer::Init.
class ARROW_ACERO_EXPORT BackpressureControl {
public:
virtual ~BackpressureControl() = default;
/// \brief Ask the input to pause
///
/// This is best effort, batches may continue to arrive
/// Must eventually be followed by a call to Resume() or deadlock will occur
virtual void Pause() = 0;
/// \brief Ask the input to resume
virtual void Resume() = 0;
};
/// \brief an interface for consuming the data as part of the plan using callbacks
///
/// An implementation is handed to a plan through ConsumingSinkNodeOptions.
class ARROW_ACERO_EXPORT SinkNodeConsumer {
public:
virtual ~SinkNodeConsumer() = default;
/// \brief Prepare any consumer state
///
/// This will be run once the schema is finalized as the plan is starting and
/// before any calls to Consume. A common use is to save off the schema so that
/// batches can be interpreted.
///
/// \param schema the schema of the batches that will be consumed
/// \param backpressure_control control used to pause and resume the input
/// \param plan the plan this consumer is part of
virtual Status Init(const std::shared_ptr<Schema>& schema,
BackpressureControl* backpressure_control, ExecPlan* plan) = 0;
/// \brief Consume a batch of data
///
/// The batch is passed by value. Errors are reported through the returned Status.
virtual Status Consume(ExecBatch batch) = 0;
/// \brief Signal to the consumer that the last batch has been delivered
///
/// The returned future should only finish when all outstanding tasks have completed
///
/// If the plan is ended early or aborts due to an error then this will not be
/// called.
virtual Future<> Finish() = 0;
};
/// \brief Add a sink node which consumes data within the exec plan run
class ARROW_ACERO_EXPORT ConsumingSinkNodeOptions : public ExecNodeOptions {
public:
/// \brief create an instance from values
///
/// \param consumer the consumer that will receive the data
/// \param names names to rename the sink's schema fields to; empty if omitted
/// \param sequence_output whether batches should be sequenced in order; unset if
///        omitted
explicit ConsumingSinkNodeOptions(std::shared_ptr<SinkNodeConsumer> consumer,
std::vector<std::string> names = {},
std::optional<bool> sequence_output = std::nullopt)
: consumer(std::move(consumer)),
names(std::move(names)),
sequence_output(sequence_output) {}
/// \brief The consumer which receives the data
std::shared_ptr<SinkNodeConsumer> consumer;
/// \brief Names to rename the sink's schema fields to
///
/// If specified then names must be provided for all fields. Currently, only a flat
/// schema is supported (see GH-31875).
///
/// If not specified then names will be generated based on the source data.
std::vector<std::string> names;
/// \brief Controls whether batches should be emitted immediately or sequenced in order
///
/// \see QueryOptions for more details
std::optional<bool> sequence_output;
};
/// \brief Make a node which sorts rows passed through it
///
/// All batches pushed to this node will be accumulated, then sorted, by the given
/// fields. Then sorted batches will be forwarded to the generator in sorted order.
class ARROW_ACERO_EXPORT OrderBySinkNodeOptions : public SinkNodeOptions {
public:
/// \brief create an instance from values
///
/// Only `generator` is forwarded to the SinkNodeOptions base, so the remaining
/// base-class constructor arguments take their default values.
///
/// \param sort_options which columns and direction to sort
/// \param generator out-parameter receiving the generator used to consume data
explicit OrderBySinkNodeOptions(
SortOptions sort_options,
std::function<Future<std::optional<ExecBatch>>()>* generator)
: SinkNodeOptions(generator), sort_options(std::move(sort_options)) {}
/// \brief options describing which columns and direction to sort
SortOptions sort_options;
};
/// \brief Apply a new ordering to data
///
/// Currently this node works by accumulating all data, sorting, and then emitting
/// the new data with an updated batch index.
///
/// Larger-than-memory sort is not currently supported.
class ARROW_ACERO_EXPORT OrderByNodeOptions : public ExecNodeOptions {
public:
/// \brief the string "order_by"
///
/// NOTE(review): presumably the name under which this node type is registered --
/// confirm against the node factory registration.
static constexpr std::string_view kName = "order_by";
/// \brief create an instance from values
///
/// \param ordering the new ordering to apply to outgoing data
explicit OrderByNodeOptions(Ordering ordering) : ordering(std::move(ordering)) {}
/// \brief The new ordering to apply to outgoing data
Ordering ordering;
};
/// \brief The kinds of join understood by the hash join node
///
/// Enumerators are not given explicit values, so they are numbered from zero in
/// declaration order.
enum class JoinType {
LEFT_SEMI,
RIGHT_SEMI,
LEFT_ANTI,
RIGHT_ANTI,
INNER,
LEFT_OUTER,
RIGHT_OUTER,
FULL_OUTER
};
/// \brief Return a string for the given join type
std::string ToString(JoinType t);
/// \brief The comparison applied to a pair of join keys
///
/// NOTE(review): presumably EQ never matches null with null while IS does --
/// confirm against the hash join implementation.
enum class JoinKeyCmp { EQ, IS };
/// \brief a node which implements a join operation using a hash table
class ARROW_ACERO_EXPORT HashJoinNodeOptions : public ExecNodeOptions {
 public:
  static constexpr const char* default_output_suffix_for_left = "";
  static constexpr const char* default_output_suffix_for_right = "";

  /// \brief create an instance from values that outputs all columns
  ///
  /// Every key pair is compared with JoinKeyCmp::EQ and `output_all` is set, so the
  /// `left_output` / `right_output` lists stay empty.
  HashJoinNodeOptions(
      JoinType in_join_type, std::vector<FieldRef> in_left_keys,
      std::vector<FieldRef> in_right_keys, Expression filter = literal(true),
      std::string output_suffix_for_left = default_output_suffix_for_left,
      std::string output_suffix_for_right = default_output_suffix_for_right,
      bool disable_bloom_filter = false)
      : join_type(in_join_type),
        left_keys(std::move(in_left_keys)),
        right_keys(std::move(in_right_keys)),
        output_all(true),
        output_suffix_for_left(std::move(output_suffix_for_left)),
        output_suffix_for_right(std::move(output_suffix_for_right)),
        filter(std::move(filter)),
        disable_bloom_filter(disable_bloom_filter) {
    // One EQ comparison per left key.
    this->key_cmp.assign(this->left_keys.size(), JoinKeyCmp::EQ);
  }

  /// \brief create an instance from keys
  ///
  /// This will create an inner join that outputs all columns and has no post join filter
  ///
  /// `in_left_keys` should have the same length and types as `in_right_keys`
  /// \param in_left_keys the keys in the left input
  /// \param in_right_keys the keys in the right input
  HashJoinNodeOptions(std::vector<FieldRef> in_left_keys,
                      std::vector<FieldRef> in_right_keys)
      : join_type(JoinType::INNER),
        left_keys(std::move(in_left_keys)),
        right_keys(std::move(in_right_keys)),
        output_all(true),
        // `left_keys` is declared before `key_cmp`, so it is already initialized here.
        key_cmp(this->left_keys.size(), JoinKeyCmp::EQ),
        output_suffix_for_left(default_output_suffix_for_left),
        output_suffix_for_right(default_output_suffix_for_right),
        filter(literal(true)) {}

  /// \brief create an instance from values using JoinKeyCmp::EQ for all comparisons
  ///
  /// Only the fields listed in `left_output` / `right_output` are requested as output
  /// (`output_all` is cleared).
  HashJoinNodeOptions(
      JoinType join_type, std::vector<FieldRef> left_keys,
      std::vector<FieldRef> right_keys, std::vector<FieldRef> left_output,
      std::vector<FieldRef> right_output, Expression filter = literal(true),
      std::string output_suffix_for_left = default_output_suffix_for_left,
      std::string output_suffix_for_right = default_output_suffix_for_right,
      bool disable_bloom_filter = false)
      : join_type(join_type),
        left_keys(std::move(left_keys)),
        right_keys(std::move(right_keys)),
        output_all(false),
        left_output(std::move(left_output)),
        right_output(std::move(right_output)),
        output_suffix_for_left(std::move(output_suffix_for_left)),
        output_suffix_for_right(std::move(output_suffix_for_right)),
        filter(std::move(filter)),
        disable_bloom_filter(disable_bloom_filter) {
    // The parameter `left_keys` has been moved from; size the comparisons from the
    // member, hence the explicit `this->`.
    this->key_cmp.assign(this->left_keys.size(), JoinKeyCmp::EQ);
  }

  /// \brief create an instance from values
  ///
  /// `key_cmp` is stored as given; this constructor does not check its length against
  /// the key lists.
  HashJoinNodeOptions(
      JoinType join_type, std::vector<FieldRef> left_keys,
      std::vector<FieldRef> right_keys, std::vector<FieldRef> left_output,
      std::vector<FieldRef> right_output, std::vector<JoinKeyCmp> key_cmp,
      Expression filter = literal(true),
      std::string output_suffix_for_left = default_output_suffix_for_left,
      std::string output_suffix_for_right = default_output_suffix_for_right,
      bool disable_bloom_filter = false)
      : join_type(join_type),
        left_keys(std::move(left_keys)),
        right_keys(std::move(right_keys)),
        output_all(false),
        left_output(std::move(left_output)),
        right_output(std::move(right_output)),
        key_cmp(std::move(key_cmp)),
        output_suffix_for_left(std::move(output_suffix_for_left)),
        output_suffix_for_right(std::move(output_suffix_for_right)),
        filter(std::move(filter)),
        disable_bloom_filter(disable_bloom_filter) {}

  HashJoinNodeOptions() = default;

  /// The type of join (inner, left, semi...)
  JoinType join_type = JoinType::INNER;
  /// Key fields taken from the left input
  std::vector<FieldRef> left_keys;
  /// Key fields taken from the right input
  std::vector<FieldRef> right_keys;
  /// When set, all valid fields from both inputs are output and the
  /// `left_output` / `right_output` lists are ignored
  bool output_all = false;
  /// Output fields passed through from the left input
  std::vector<FieldRef> left_output;
  /// Output fields passed through from the right input
  std::vector<FieldRef> right_output;
  /// Comparison function per key (decides whether a null key is equal to another
  /// null key or not)
  std::vector<JoinKeyCmp> key_cmp;
  /// Suffix appended to the names of output fields coming from the left input. Used to
  /// tell apart, where necessary, fields with the same name in the left and right
  /// inputs; may be left empty when there are no name collisions
  std::string output_suffix_for_left;
  /// Suffix appended to the names of output fields coming from the right input
  std::string output_suffix_for_right;
  /// Residual filter applied to matching rows; rows that do not pass it are not
  /// included. It is evaluated against the concatenated input schema (left fields
  /// then right fields) and may reference fields that are not part of the output
  Expression filter = literal(true);
  /// Whether or not to disable Bloom filters in this join
  bool disable_bloom_filter = false;
};
/// \brief a node which implements the asof join operation
///
/// Note, this API is experimental and will change in the future
///
/// This node takes one left table and any number of right tables, and asof joins them
/// together. Batches produced by each input must be ordered by the "on" key.
/// This node will output one row for each row in the left table.
class ARROW_ACERO_EXPORT AsofJoinNodeOptions : public ExecNodeOptions {
public:
/// \brief Keys for one input table of the AsofJoin operation
///
/// The keys must be consistent across the input tables:
/// Each "on" key must refer to a field of the same type and units across the tables.
/// Each "by" key must refer to a list of fields of the same types across the tables.
struct Keys {
/// \brief "on" key for the join.
///
/// The input table must be sorted by the "on" key. Must be a single field of a common
/// type. Inexact match is used on the "on" key. i.e., a row is considered a match iff
/// left_on - tolerance <= right_on <= left_on.
/// Currently, the "on" key must be of an integer, date, or timestamp type.
FieldRef on_key;
/// \brief "by" key for the join.
///
/// Each input table must have each field of the "by" key. Exact equality is used for
/// each field of the "by" key.
/// Currently, each field of the "by" key must be of an integer, date, timestamp, or
/// base-binary type.
std::vector<FieldRef> by_key;
};
/// \brief create an instance from values
///
/// \param input_keys the keys of each input table, left table first (moved into the
/// `input_keys` member)
/// \param tolerance tolerance for inexact "on" key matching, in the units of the "on"
/// key
AsofJoinNodeOptions(std::vector<Keys> input_keys, int64_t tolerance)
: input_keys(std::move(input_keys)), tolerance(tolerance) {}
/// \brief AsofJoin keys per input table. At least two keys must be given. The first key
/// corresponds to a left table and all other keys correspond to right tables for the
/// as-of-join.
///
/// \see `Keys` for details.
std::vector<Keys> input_keys;
/// \brief Tolerance for inexact "on" key matching. A right row is considered a match
/// with the left row if `right.on - left.on <= tolerance`. The `tolerance` may be:
/// - negative, in which case a past-as-of-join occurs;
/// - or positive, in which case a future-as-of-join occurs;
/// - or zero, in which case an exact-as-of-join occurs.
///
/// The tolerance is interpreted in the same units as the "on" key.
// NOTE(review): the match condition stated here (signed tolerance,
// `right.on - left.on <= tolerance`) does not agree with the one stated on
// Keys::on_key (`left_on - tolerance <= right_on <= left_on`). Confirm which one the
// implementation follows and reconcile the two descriptions.
int64_t tolerance;
};
/// \brief a node which selects the top_k/bottom_k rows passed through it
///
/// All batches pushed to this node will be accumulated, then selected, by the given
/// fields. Then sorted batches will be forwarded to the generator in sorted order.
class ARROW_ACERO_EXPORT SelectKSinkNodeOptions : public SinkNodeOptions {
public:
/// \brief create an instance from values
///
/// \param select_k_options the SelectK options, moved into the `select_k_options`
/// member
/// \param generator forwarded unchanged to the SinkNodeOptions base class
explicit SelectKSinkNodeOptions(
SelectKOptions select_k_options,
std::function<Future<std::optional<ExecBatch>>()>* generator)
: SinkNodeOptions(generator), select_k_options(std::move(select_k_options)) {}
/// SelectK options
SelectKOptions select_k_options;
};
/// \brief a sink node which accumulates all output into a table
class ARROW_ACERO_EXPORT TableSinkNodeOptions : public ExecNodeOptions {
public:
/// \brief create an instance from values
///
/// \param output_table out parameter that will receive the result table, see the
/// `output_table` member
/// \param sequence_output see the `sequence_output` member; defaults to std::nullopt
///
/// `names` is not set by this constructor and starts out empty.
explicit TableSinkNodeOptions(std::shared_ptr<Table>* output_table,
std::optional<bool> sequence_output = std::nullopt)
: output_table(output_table), sequence_output(sequence_output) {}
/// \brief an "out parameter" specifying the table that will be created
///
/// Must not be null and remain valid for the entirety of the plan execution. After the
/// plan has completed this will be set to point to the result table
std::shared_ptr<Table>* output_table;
/// \brief Controls whether batches should be emitted immediately or sequenced in order
///
/// \see QueryOptions for more details
std::optional<bool> sequence_output;
/// \brief Custom names to use for the columns.
///
/// If specified then names must be provided for all fields. Currently, only a flat
/// schema is supported (see GH-31875).
///
/// If not specified then names will be generated based on the source data.
std::vector<std::string> names;
};
/// \brief a row template that describes one row that will be generated for each input row
struct ARROW_ACERO_EXPORT PivotLongerRowTemplate {
/// \brief create an instance from values
///
/// \param feature_values the values written to the feature columns for this template
/// \param measurement_values the input fields (or nullopt) that supply the measurement
/// columns for this template
///
/// Both arguments are moved into the members of the same name.
PivotLongerRowTemplate(std::vector<std::string> feature_values,
std::vector<std::optional<FieldRef>> measurement_values)
: feature_values(std::move(feature_values)),
measurement_values(std::move(measurement_values)) {}
/// A (typically unique) set of feature values for the template, usually derived from a
/// column name
///
/// These will be used to populate the feature columns
std::vector<std::string> feature_values;
/// The fields containing the measurements to use for this row
///
/// These will be used to populate the measurement columns. If nullopt then nulls
/// will be inserted for the given value.
std::vector<std::optional<FieldRef>> measurement_values;
};
/// \brief Reshape a table by turning some columns into additional rows
///
/// This operation is sometimes also referred to as UNPIVOT
///
/// This is typically done when there are multiple observations in each row in order to
/// transform to a table containing a single observation per row.
///
/// For example:
///
/// | time | left_temp | right_temp |
/// | ---- | --------- | ---------- |
/// | 1 | 10 | 20 |
/// | 2 | 15 | 18 |
///
/// The above table contains two observations per row. There is an implicit feature
/// "location" (left vs right) and a measurement "temp". What we really want is:
///
/// | time | location | temp |
/// | --- | --- | --- |
/// | 1 | left | 10 |
/// | 1 | right | 20 |
/// | 2 | left | 15 |
/// | 2 | right | 18 |
///
/// For a more complex example consider:
///
/// | time | ax1 | ay1 | bx1 | ay2 |
/// | ---- | --- | --- | --- | --- |
/// | 0 | 1 | 2 | 3 | 4 |
///
/// We can pretend a vs b and x vs y are features while 1 and 2 are two different
/// kinds of measurements. We thus want to pivot to
///
/// | time | a/b | x/y | f1 | f2 |
/// | ---- | --- | --- | ---- | ---- |
/// | 0 | a | x | 1 | null |
/// | 0 | a | y | 2 | 4 |
/// | 0 | b | x | 3 | null |
///
/// To do this we create a row template for each combination of features. One should
/// be able to do this purely by looking at the column names. For example, given the
/// above columns "ax1", "ay1", "bx1", and "ay2" we know we have three feature
/// combinations (a, x), (a, y), and (b, x). Similarly, we know we have two possible
/// measurements, "1" and "2".
///
/// For each combination of features we create a row template. In each row template we
/// describe the combination and then list which columns to use for the measurements.
/// If a measurement doesn't exist for a given combination then we use nullopt.
///
/// So, for our above example, we have (using the field names of
/// PivotLongerRowTemplate):
///
/// (a, x): feature_values={"a", "x"}, measurement_values={"ax1", nullopt}
/// (a, y): feature_values={"a", "y"}, measurement_values={"ay1", "ay2"}
/// (b, x): feature_values={"b", "x"}, measurement_values={"bx1", nullopt}
///
/// Finishing it off we name our new columns:
/// feature_field_names={"a/b","x/y"}
/// measurement_field_names={"f1", "f2"}
///
/// In this example every row template has one feature value per entry of
/// feature_field_names and one measurement value per entry of
/// measurement_field_names.
// NOTE(review): presumably that size agreement is required and validated by the node
// -- confirm against the implementation.
class ARROW_ACERO_EXPORT PivotLongerNodeOptions : public ExecNodeOptions {
public:
/// \brief Name associated with this kind of node ("pivot_longer")
// NOTE(review): presumably the factory name under which the node is registered -- confirm
static constexpr std::string_view kName = "pivot_longer";
/// One or more row templates to create new output rows
///
/// Normally there are at least two row templates. The output # of rows
/// will be the input # of rows * the number of row templates
std::vector<PivotLongerRowTemplate> row_templates;
/// The names of the columns which describe the new features
std::vector<std::string> feature_field_names;
/// The names of the columns which represent the measurements
std::vector<std::string> measurement_field_names;
};
/// @}
} // namespace acero
} // namespace arrow

View File

@@ -0,0 +1,56 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <functional>
#include <memory>
#include <vector>
#include "arrow/acero/options.h"
#include "arrow/record_batch.h"
#include "arrow/result.h"
#include "arrow/status.h"
#include "arrow/type.h"
namespace arrow {
using compute::ExecContext;
namespace acero {
/// \brief Abstract interface of an order-by implementation that is fed record batches
/// and produces a single result when finished
// NOTE(review): presumably implementations accumulate the batches passed to
// InputReceived and compute the sorted / selected result in DoFinish -- the
// implementations are not visible here, confirm.
class OrderByImpl {
public:
virtual ~OrderByImpl() = default;
/// \brief Hand one input batch to the implementation
virtual void InputReceived(const std::shared_ptr<RecordBatch>& batch) = 0;
/// \brief Finish processing and return the result as a Datum, or an error
virtual Result<Datum> DoFinish() = 0;
/// \brief Return a string description of this implementation
virtual std::string ToString() const = 0;
/// \brief Create an implementation configured by SortOptions
///
/// \param ctx the execution context to use
/// \param output_schema the schema of the data
/// \param options the sort options
/// \return the new implementation, or an error
static Result<std::unique_ptr<OrderByImpl>> MakeSort(
ExecContext* ctx, const std::shared_ptr<Schema>& output_schema,
const SortOptions& options);
/// \brief Create an implementation configured by SelectKOptions
///
/// \param ctx the execution context to use
/// \param output_schema the schema of the data
/// \param options the select-k options
/// \return the new implementation, or an error
static Result<std::unique_ptr<OrderByImpl>> MakeSelectK(
ExecContext* ctx, const std::shared_ptr<Schema>& output_schema,
const SelectKOptions& options);
};
} // namespace acero
} // namespace arrow

View File

@@ -0,0 +1,186 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <atomic>
#include <cassert>
#include <cstdint>
#include <functional>
#include <random>
#include "arrow/acero/util.h"
#include "arrow/buffer.h"
#include "arrow/util/pcg_random.h"
namespace arrow {
namespace acero {
class PartitionSort {
 public:
  /// \brief Bucket sort rows on partition ids in O(num_rows) time.
  ///
  /// Besides placing the rows, this writes the cumulative bucket sizes into
  /// prtn_ranges, which delimit the range of the sorted output that holds the row
  /// ids of each partition.
  ///
  /// prtn_ranges must have room for at least num_prtns + 1 elements; its previous
  /// contents are overwritten. When this method returns, prtn_ranges[0] is 0 and
  /// prtn_ranges[i] (for i >= 1) is the total number of rows in partitions
  /// 0 through i - 1. Partition p therefore occupies output positions
  /// [prtn_ranges[p], prtn_ranges[p + 1]).
  ///
  /// prtn_id_impl must be a function that takes in a row id (int) and returns
  /// a partition id (int). The returned partition id must be between 0 and
  /// num_prtns (exclusive). It is invoked twice for every row and must return the
  /// same partition both times.
  ///
  /// output_pos_impl is a function that takes in a row id (int) and a position (int)
  /// in the bucket sorted output. The function should insert the row in the
  /// output.
  ///
  /// For example:
  ///
  /// in_arr: [5, 7, 2, 3, 5, 4]
  /// num_prtns: 3
  /// prtn_id_impl: [&in_arr] (int row_id) { return in_arr[row_id] / 3; }
  /// output_pos_impl: [&sorted_row_ids] (int row_id, int pos) { sorted_row_ids[pos] =
  /// row_id; }
  ///
  /// After Execution
  /// sorted_row_ids: [2, 0, 3, 4, 5, 1]
  /// prtn_ranges: [0, 1, 5, 6]
  template <class INPUT_PRTN_ID_FN, class OUTPUT_POS_FN>
  static void Eval(int64_t num_rows, int num_prtns, uint16_t* prtn_ranges,
                   INPUT_PRTN_ID_FN prtn_id_impl, OUTPUT_POS_FN output_pos_impl) {
    ARROW_DCHECK(num_rows > 0 && num_rows <= (1 << 15));
    ARROW_DCHECK(num_prtns >= 1 && num_prtns <= (1 << 15));

    // Phase 1 - histogram. Slot 0 stays zero; slot p + 1 counts the rows of
    // partition p.
    memset(prtn_ranges, 0, (num_prtns + 1) * sizeof(uint16_t));
    for (int64_t row = 0; row < num_rows; ++row) {
      int prtn = static_cast<int>(prtn_id_impl(row));
      uint16_t& row_count = prtn_ranges[prtn + 1];
      ++row_count;
    }

    // Phase 2 - exclusive prefix sum. Afterwards slot p + 1 holds the first output
    // position that belongs to partition p.
    uint16_t running_total = 0;
    for (int slot = 1; slot <= num_prtns; ++slot) {
      uint16_t prtn_size = prtn_ranges[slot];
      prtn_ranges[slot] = running_total;
      running_total = static_cast<uint16_t>(running_total + prtn_size);
    }

    // Phase 3 - scatter. Slot p + 1 acts as the write cursor of partition p, so once
    // all rows are placed it has advanced to the end of that partition.
    for (int64_t row = 0; row < num_rows; ++row) {
      int prtn = static_cast<int>(prtn_id_impl(row));
      uint16_t& cursor = prtn_ranges[prtn + 1];
      int pos = cursor;
      ++cursor;
      output_pos_impl(row, pos);
    }
  }
};
/// \brief A control for synchronizing threads on a partitionable workload
class PartitionLocks {
public:
PartitionLocks();
~PartitionLocks();
/// \brief Initializes the control, must be called before use
///
/// \param num_threads Maximum number of threads that will access the partitions
/// \param num_prtns Number of partitions to synchronize
void Init(size_t num_threads, int num_prtns);
/// \brief Cleans up the control, it should not be used after this call
void CleanUp();
/// \brief Acquire a partition to work on
///
/// \param thread_id The index of the thread trying to acquire the partition lock
/// \param num_prtns Length of prtns_to_try, must be <= num_prtns used in Init
/// \param prtns_to_try An array of partitions that still have remaining work
/// \param limit_retries If false, this method will spinwait forever until success
/// \param max_retries Max times to attempt checking out work before returning false
/// \param[out] locked_prtn_id The id of the partition locked
/// \param[out] locked_prtn_id_pos The index of the partition locked in prtns_to_try
/// \return True if a partition was locked, false if max_retries was attempted
/// without successfully acquiring a lock
///
/// This method is thread safe
bool AcquirePartitionLock(size_t thread_id, int num_prtns, const int* prtns_to_try,
bool limit_retries, int max_retries, int* locked_prtn_id,
int* locked_prtn_id_pos);
/// \brief Release a partition so that other threads can work on it
void ReleasePartitionLock(int prtn_id);
// Executes (synchronously and using current thread) the same operation on a set of
// multiple partitions. Tries to minimize partition locking overhead by randomizing and
// adjusting order in which partitions are processed.
//
// PROCESS_PRTN_FN is a callback which will be executed for each partition after
// acquiring the lock for that partition. It gets partition id as an argument.
// IS_PRTN_EMPTY_FN is a callback which filters out (when returning true) partitions
// with specific ids from processing.
//
// Returns the first non-OK Status produced by PROCESS_PRTN_FN (the remaining
// partitions are then not processed), Status::OK() otherwise.
//
template <typename IS_PRTN_EMPTY_FN, typename PROCESS_PRTN_FN>
Status ForEachPartition(size_t thread_id,
/*scratch space buffer with space for one element per partition;
dirty in and dirty out*/
int* temp_unprocessed_prtns, IS_PRTN_EMPTY_FN is_prtn_empty_fn,
PROCESS_PRTN_FN process_prtn_fn) {
// Collect the ids of all partitions that are not filtered out into the scratch
// buffer; this is the work list.
int num_unprocessed_partitions = 0;
for (int i = 0; i < num_prtns_; ++i) {
bool is_prtn_empty = is_prtn_empty_fn(i);
if (!is_prtn_empty) {
temp_unprocessed_prtns[num_unprocessed_partitions++] = i;
}
}
while (num_unprocessed_partitions > 0) {
int locked_prtn_id;
int locked_prtn_id_pos;
// limit_retries is false, so per the contract documented on AcquirePartitionLock
// this call only returns once a partition has been locked; that is why the
// return value is not checked and both out parameters are set afterwards.
AcquirePartitionLock(thread_id, num_unprocessed_partitions, temp_unprocessed_prtns,
/*limit_retries=*/false, /*max_retries=*/-1, &locked_prtn_id,
&locked_prtn_id_pos);
{
// Scope guard: releases the partition lock when this scope is left, which also
// covers the early return taken by ARROW_RETURN_NOT_OK below on failure.
class AutoReleaseLock {
public:
AutoReleaseLock(PartitionLocks* locks, int prtn_id)
: locks(locks), prtn_id(prtn_id) {}
~AutoReleaseLock() { locks->ReleasePartitionLock(prtn_id); }
PartitionLocks* locks;
int prtn_id;
} auto_release_lock(this, locked_prtn_id);
ARROW_RETURN_NOT_OK(process_prtn_fn(locked_prtn_id));
}
// Remove the processed partition from the work list by moving the last entry
// into its slot (order of the remaining entries is not preserved).
if (locked_prtn_id_pos < num_unprocessed_partitions - 1) {
temp_unprocessed_prtns[locked_prtn_id_pos] =
temp_unprocessed_prtns[num_unprocessed_partitions - 1];
}
--num_unprocessed_partitions;
}
return Status::OK();
}
private:
// Returns a pointer to the lock flag of the given partition.
std::atomic<bool>* lock_ptr(int prtn_id);
// Returns an int derived from the given thread's random number generator.
// NOTE(review): presumably uniformly distributed in [0, num_values) -- confirm
// against the definition.
int random_int(size_t thread_id, int num_values);
// One lock flag followed by kCacheLineBytes bytes of padding.
// NOTE(review): the padding presumably keeps the flags of different partitions on
// separate cache lines to avoid false sharing -- confirm.
struct PartitionLock {
static constexpr int kCacheLineBytes = 64;
std::atomic<bool> lock;
uint8_t padding[kCacheLineBytes];
};
// Number of partitions iterated over by ForEachPartition.
// NOTE(review): presumably set by Init from its num_prtns argument -- confirm.
int num_prtns_;
// Array of per-partition locks.
std::unique_ptr<PartitionLock[]> locks_;
// Array of random number generators.
// NOTE(review): presumably one per thread, indexed by thread_id -- confirm.
std::unique_ptr<arrow::random::pcg32_fast[]> rngs_;
};
} // namespace acero
} // namespace arrow

View File

@@ -0,0 +1,151 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <string_view>
#include "arrow/acero/exec_plan.h"
#include "arrow/acero/task_util.h"
#include "arrow/acero/util.h"
#include "arrow/compute/exec.h"
#include "arrow/io/interfaces.h"
#include "arrow/util/async_util.h"
#include "arrow/util/type_fwd.h"
namespace arrow {
using compute::default_exec_context;
using io::IOContext;
namespace acero {
/// \brief Per-query state: the query options, the execution and IO contexts, the
/// task schedulers and a running count of in-flight temporary file IO
class ARROW_ACERO_EXPORT QueryContext {
public:
/// \brief create an instance from values
///
/// \param opts the query options, defaults to a default-constructed QueryOptions
/// \param exec_context the execution context, defaults to a copy of
/// *default_exec_context()
QueryContext(QueryOptions opts = {},
ExecContext exec_context = *default_exec_context());
/// \brief Initialize the context with an async task scheduler
// NOTE(review): presumably stores `scheduler` so that async_scheduler() returns it,
// and must be called before tasks are scheduled -- confirm against the definition.
Status Init(arrow::util::AsyncTaskScheduler* scheduler);
const ::arrow::internal::CpuInfo* cpu_info() const;
int64_t hardware_flags() const;
/// \brief The options this context was created with
const QueryOptions& options() const { return options_; }
/// \brief The memory pool of the execution context
MemoryPool* memory_pool() const { return exec_context_.memory_pool(); }
/// \brief The executor of the execution context
::arrow::internal::Executor* executor() const { return exec_context_.executor(); }
/// \brief The execution context owned by this object
ExecContext* exec_context() { return &exec_context_; }
/// \brief The IO context owned by this object
IOContext* io_context() { return &io_context_; }
/// \brief The task scheduler owned by this object
TaskScheduler* scheduler() { return task_scheduler_.get(); }
/// \brief The async task scheduler (not owned); null until one has been set
arrow::util::AsyncTaskScheduler* async_scheduler() { return async_scheduler_; }
// NOTE(review): presumably returns the index assigned to the calling thread by
// thread_indexer_ -- confirm against the definition.
size_t GetThreadIndex();
size_t max_concurrency() const;
/// \brief Start an external task
///
/// This should be avoided if possible. It is kept in for now for legacy
/// purposes. This should be called before the external task is started. If
/// a valid future is returned then it should be marked complete when the
/// external task has finished.
///
/// \param name A name to give the task for traceability and debugging
///
/// \return an invalid future if the plan has already ended, otherwise this
/// returns a future that must be completed when the external task
/// finishes.
Result<Future<>> BeginExternalTask(std::string_view name);
/// \brief Add a single function as a task to the query's task group
/// on the compute threadpool.
///
/// \param fn The task to run. Takes no arguments and returns a Status.
/// \param name A name to give the task for traceability and debugging
void ScheduleTask(std::function<Status()> fn, std::string_view name);
/// \brief Add a single function as a task to the query's task group
/// on the compute threadpool.
///
/// \param fn The task to run. Takes the thread index and returns a Status.
/// \param name A name to give the task for traceability and debugging
void ScheduleTask(std::function<Status(size_t)> fn, std::string_view name);
/// \brief Add a single function as a task to the query's task group on
/// the IO thread pool
///
/// \param fn The task to run. Returns a status.
/// \param name A name to give the task for traceability and debugging
void ScheduleIOTask(std::function<Status()> fn, std::string_view name);
// Register/Start TaskGroup is a way of performing a "Parallel For" pattern:
// - The task function takes the thread index and the index of the task
// - The on_finished function takes the thread index
// Returns an integer ID that will be used to reference the task group in
// StartTaskGroup. At runtime, call StartTaskGroup with the ID and the number of times
// you'd like the task to be executed. The need to register a task group before use will
// be removed after we rewrite the scheduler.
/// \brief Register a "parallel for" task group with the scheduler
///
/// \param task The function implementing the task. Takes the thread_index and
/// the task index.
/// \param on_finished The function that gets run once all tasks have been completed.
/// Takes the thread_index.
/// \return the ID to pass to StartTaskGroup
///
/// Must be called inside of ExecNode::Init.
int RegisterTaskGroup(std::function<Status(size_t, int64_t)> task,
std::function<Status(size_t)> on_finished);
/// \brief Start the task group with the specified ID. This can only
/// be called once per task_group_id.
///
/// \param task_group_id The ID of the task group to run
/// \param num_tasks The number of times to run the task
Status StartTaskGroup(int task_group_id, int64_t num_tasks);
// This is an RAII class for keeping track of in-flight file IO. Useful for getting
// an estimate of memory use, and how much memory we expect to be freed soon.
// Returned by ReportTempFileIO.
//
// Constructing a mark adds `bytes` to the owning context's in-flight counter and
// destroying it subtracts the same amount. The mark cannot be copied, and
// [[nodiscard]] flags call sites that drop it immediately.
struct [[nodiscard]] TempFileIOMark {
QueryContext* ctx_;
size_t bytes_;
TempFileIOMark(QueryContext* ctx, size_t bytes) : ctx_(ctx), bytes_(bytes) {
ctx_->in_flight_bytes_to_disk_.fetch_add(bytes_, std::memory_order_acquire);
}
ARROW_DISALLOW_COPY_AND_ASSIGN(TempFileIOMark);
~TempFileIOMark() {
ctx_->in_flight_bytes_to_disk_.fetch_sub(bytes_, std::memory_order_release);
}
};
/// \brief Record `bytes` of in-flight temporary file IO for the lifetime of the
/// returned mark
TempFileIOMark ReportTempFileIO(size_t bytes) { return {this, bytes}; }
/// \brief The number of bytes of temporary file IO currently reported as in flight
size_t GetCurrentTempFileIO() { return in_flight_bytes_to_disk_.load(); }
private:
QueryOptions options_;
// To be replaced with Acero-specific context once scheduler is done and
// we don't need ExecContext for kernels
ExecContext exec_context_;
IOContext io_context_;
// Not owned; null until a scheduler has been set.
arrow::util::AsyncTaskScheduler* async_scheduler_ = NULLPTR;
std::unique_ptr<TaskScheduler> task_scheduler_ = TaskScheduler::Make();
ThreadIndexer thread_indexer_;
// Sum of the byte counts of all live TempFileIOMark objects.
std::atomic<size_t> in_flight_bytes_to_disk_{0};
};
} // namespace acero
} // namespace arrow

View File

@@ -0,0 +1,226 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <cassert>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>
#include "arrow/type.h" // for DataType, FieldRef, Field and Schema
namespace arrow {
using internal::checked_cast;
namespace acero {
// Identifiers for all different row schemas that are used in a join
//
// The underlying integer values are fixed explicitly (0 through 4).
// NOTE(review): judging by the names, INPUT is presumably the full input schema and
// KEY / PAYLOAD / FILTER / OUTPUT are the projections holding the key columns, the
// non-key columns carried along, the columns referenced by the residual filter and
// the emitted columns -- confirm against the hash join implementation.
//
enum class HashJoinProjection : int {
INPUT = 0,
KEY = 1,
PAYLOAD = 2,
FILTER = 3,
OUTPUT = 4
};
/// \brief A non-owning view that translates column indices of a source schema into
/// column indices of a target schema by going through a common base schema
///
/// The two arrays are not owned; they must outlive the map.
struct SchemaProjectionMap {
  /// Marker stored in the arrays for a column that has no counterpart
  static constexpr int kMissingField = -1;
  /// Number of columns in the source schema (length of source_to_base)
  int num_cols;
  /// For each source column, the index of that column in the base schema
  const int* source_to_base;
  /// For each base column, the index of that column in the target schema
  const int* base_to_target;

  /// \brief Translate source column `i` into a target column index
  ///
  /// `i` must be in [0, num_cols) and must have a counterpart in the base schema;
  /// both conditions are only checked by assertions. The value returned is whatever
  /// base_to_target holds for that base column.
  inline int get(int i) const {
    assert(i >= 0 && i < num_cols);
    assert(source_to_base[i] != kMissingField);
    const int base_col = source_to_base[i];
    return base_to_target[base_col];
  }
};
/// Helper class for managing different projections of the same row schema.
/// Used to efficiently map any field in one projection to a corresponding field in
/// another projection.
/// Materialized mappings are generated lazily at the time of the first access.
/// Thread-safe apart from initialization.
template <typename ProjectionIdEnum>
class SchemaProjectionMaps {
public:
static constexpr int kMissingField = -1;
Status Init(ProjectionIdEnum full_schema_handle, const Schema& schema,
const std::vector<ProjectionIdEnum>& projection_handles,
const std::vector<const std::vector<FieldRef>*>& projections) {
assert(projection_handles.size() == projections.size());
ARROW_RETURN_NOT_OK(RegisterSchema(full_schema_handle, schema));
for (size_t i = 0; i < projections.size(); ++i) {
ARROW_RETURN_NOT_OK(
RegisterProjectedSchema(projection_handles[i], *(projections[i]), schema));
}
RegisterEnd();
return Status::OK();
}
int num_cols(ProjectionIdEnum schema_handle) const {
int id = schema_id(schema_handle);
return static_cast<int>(schemas_[id].second.data_types.size());
}
bool is_empty(ProjectionIdEnum schema_handle) const {
return num_cols(schema_handle) == 0;
}
const std::string& field_name(ProjectionIdEnum schema_handle, int field_id) const {
int id = schema_id(schema_handle);
return schemas_[id].second.field_names[field_id];
}
const std::shared_ptr<DataType>& data_type(ProjectionIdEnum schema_handle,
int field_id) const {
int id = schema_id(schema_handle);
return schemas_[id].second.data_types[field_id];
}
const std::vector<std::shared_ptr<DataType>>& data_types(
ProjectionIdEnum schema_handle) const {
int id = schema_id(schema_handle);
return schemas_[id].second.data_types;
}
SchemaProjectionMap map(ProjectionIdEnum from, ProjectionIdEnum to) const {
int id_from = schema_id(from);
int id_to = schema_id(to);
SchemaProjectionMap result;
result.num_cols = num_cols(from);
result.source_to_base = mappings_[id_from].data();
result.base_to_target = inverse_mappings_[id_to].data();
return result;
}
protected:
struct FieldInfos {
std::vector<int> field_paths;
std::vector<std::string> field_names;
std::vector<std::shared_ptr<DataType>> data_types;
};
Status RegisterSchema(ProjectionIdEnum handle, const Schema& schema) {
FieldInfos out_fields;
const FieldVector& in_fields = schema.fields();
out_fields.field_paths.resize(in_fields.size());
out_fields.field_names.resize(in_fields.size());
out_fields.data_types.resize(in_fields.size());
for (size_t i = 0; i < in_fields.size(); ++i) {
const std::string& name = in_fields[i]->name();
const std::shared_ptr<DataType>& type = in_fields[i]->type();
out_fields.field_paths[i] = static_cast<int>(i);
out_fields.field_names[i] = name;
out_fields.data_types[i] = type;
}
schemas_.push_back(std::make_pair(handle, out_fields));
return Status::OK();
}
Status RegisterProjectedSchema(ProjectionIdEnum handle,
const std::vector<FieldRef>& selected_fields,
const Schema& full_schema) {
FieldInfos out_fields;
const FieldVector& in_fields = full_schema.fields();
out_fields.field_paths.resize(selected_fields.size());
out_fields.field_names.resize(selected_fields.size());
out_fields.data_types.resize(selected_fields.size());
for (size_t i = 0; i < selected_fields.size(); ++i) {
// All fields must be found in schema without ambiguity
ARROW_ASSIGN_OR_RAISE(auto match, selected_fields[i].FindOne(full_schema));
const std::string& name = in_fields[match[0]]->name();
const std::shared_ptr<DataType>& type = in_fields[match[0]]->type();
out_fields.field_paths[i] = match[0];
out_fields.field_names[i] = name;
out_fields.data_types[i] = type;
}
schemas_.push_back(std::make_pair(handle, out_fields));
return Status::OK();
}
void RegisterEnd() {
size_t size = schemas_.size();
mappings_.resize(size);
inverse_mappings_.resize(size);
int id_base = 0;
for (size_t i = 0; i < size; ++i) {
GenerateMapForProjection(static_cast<int>(i), id_base);
}
}
int schema_id(ProjectionIdEnum schema_handle) const {
for (size_t i = 0; i < schemas_.size(); ++i) {
if (schemas_[i].first == schema_handle) {
return static_cast<int>(i);
}
}
// We should never get here
assert(false);
return -1;
}
void GenerateMapForProjection(int id_proj, int id_base) {
int num_cols_proj = static_cast<int>(schemas_[id_proj].second.data_types.size());
int num_cols_base = static_cast<int>(schemas_[id_base].second.data_types.size());
std::vector<int>& mapping = mappings_[id_proj];
std::vector<int>& inverse_mapping = inverse_mappings_[id_proj];
mapping.resize(num_cols_proj);
inverse_mapping.resize(num_cols_base);
if (id_proj == id_base) {
for (int i = 0; i < num_cols_base; ++i) {
mapping[i] = inverse_mapping[i] = i;
}
} else {
const FieldInfos& fields_proj = schemas_[id_proj].second;
const FieldInfos& fields_base = schemas_[id_base].second;
for (int i = 0; i < num_cols_base; ++i) {
inverse_mapping[i] = SchemaProjectionMap::kMissingField;
}
for (int i = 0; i < num_cols_proj; ++i) {
int field_id = SchemaProjectionMap::kMissingField;
for (int j = 0; j < num_cols_base; ++j) {
if (fields_proj.field_paths[i] == fields_base.field_paths[j]) {
field_id = j;
// If there are multiple matches for the same input field,
// it will be mapped to the first match.
break;
}
}
assert(field_id != SchemaProjectionMap::kMissingField);
mapping[i] = field_id;
inverse_mapping[field_id] = i;
}
}
}
// vector used as a mapping from ProjectionIdEnum to fields; the position of an
// entry in this vector is the id returned by schema_id()
std::vector<std::pair<ProjectionIdEnum, FieldInfos>> schemas_;
// mappings_[p][i]: for the schema at position p of schemas_, the column of the
// base schema that its column i was matched to (see GenerateMapForProjection)
std::vector<std::vector<int>> mappings_;
// inverse_mappings_[p][j]: for the schema at position p of schemas_, the column
// matched to column j of the base schema, or SchemaProjectionMap::kMissingField
// when no column was matched (see GenerateMapForProjection)
std::vector<std::vector<int>> inverse_mappings_;
};
using HashJoinProjectionMaps = SchemaProjectionMaps<HashJoinProjection>;
} // namespace acero
} // namespace arrow

View File

@@ -0,0 +1,102 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <atomic>
#include <cstdint>
#include <functional>
#include <vector>
#include "arrow/acero/visibility.h"
#include "arrow/status.h"
#include "arrow/util/config.h"
#include "arrow/util/logging.h"
namespace arrow {
namespace acero {
// A std::atomic<T> placed between two blocks of padding bytes, to avoid cache
// line invalidation whenever the value is modified by a concurrent thread on
// a different CPU core.
//
// The atomic itself is exposed directly as the public member `value`.
//
template <typename T>
class AtomicWithPadding {
  // Number of padding bytes placed on each side of `value`.
  static constexpr int kCacheLineSize = 64;
  // Padding laid out before `value`; never read or written.
  uint8_t leading_pad_[kCacheLineSize];

 public:
  std::atomic<T> value;

 private:
  // Padding laid out after `value`; never read or written.
  uint8_t trailing_pad_[kCacheLineSize];
};
// Used for asynchronous execution of operations that can be broken into
// a fixed number of symmetric tasks that can be executed concurrently.
//
// Implements priorities between multiple such operations, called task groups.
//
// Allows specifying the maximum number of in-flight tasks at any moment.
//
// Also allows executing the next pending tasks immediately using a caller thread.
//
// This class is an abstract interface (every operation is pure virtual);
// Make() returns an instance as a std::unique_ptr<TaskScheduler>.
class ARROW_ACERO_EXPORT TaskScheduler {
public:
// Callback that runs one task of a task group.
// NOTE(review): the two arguments are presumably (thread_id, task_id) - confirm
// against the implementation.
using TaskImpl = std::function<Status(size_t, int64_t)>;
// Callback registered together with a TaskImpl as the continuation of its group.
// NOTE(review): the argument is presumably the thread_id - confirm.
using TaskGroupContinuationImpl = std::function<Status(size_t)>;
// Callback given to StartScheduling(); it receives a TaskGroupContinuationImpl.
using ScheduleImpl = std::function<Status(TaskGroupContinuationImpl)>;
// Callback given to Abort().
using AbortContinuationImpl = std::function<void()>;
virtual ~TaskScheduler() = default;
// Order in which task groups are registered represents priorities of their tasks
// (the first group has the highest priority).
//
// Returns task group identifier that is used to request operations on the task group.
virtual int RegisterTaskGroup(TaskImpl task_impl,
TaskGroupContinuationImpl cont_impl) = 0;
// NOTE(review): presumably signals that no more task groups will be registered -
// confirm against the implementation.
virtual void RegisterEnd() = 0;
// Starts the task group identified by group_id (as returned by RegisterTaskGroup).
//
// total_num_tasks may be zero, in which case task group continuation will be executed
// immediately
virtual Status StartTaskGroup(size_t thread_id, int group_id,
int64_t total_num_tasks) = 0;
// Execute given number of tasks immediately using caller thread
// NOTE(review): execute_all presumably overrides num_tasks_to_execute - confirm.
virtual Status ExecuteMore(size_t thread_id, int num_tasks_to_execute,
bool execute_all) = 0;
// Begin scheduling tasks using provided callback and
// the limit on the number of in-flight tasks at any moment.
//
// Scheduling will continue as long as there are waiting tasks.
//
// It will automatically resume whenever new task group gets started.
virtual Status StartScheduling(size_t thread_id, ScheduleImpl schedule_impl,
int num_concurrent_tasks, bool use_sync_execution) = 0;
// Abort scheduling and execution.
// Used in case of being notified about unrecoverable error for the entire query.
virtual void Abort(AbortContinuationImpl impl) = 0;
// Creates a scheduler instance; the concrete type is not visible in this header.
static std::unique_ptr<TaskScheduler> Make();
};
} // namespace acero
} // namespace arrow

View File

@@ -0,0 +1,86 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <string>
#include "arrow/acero/options.h"
#include "arrow/acero/test_util_internal.h"
#include "arrow/testing/random.h"
namespace arrow {
namespace acero {
// \brief Make a delaying source that is optionally noisy (prints when it emits)
//
// This overload draws its batches from an Iterator.
// NOTE(review): delay_sec is presumably the delay in seconds applied to emitted
// items and label the text used when printing - confirm against the definition.
AsyncGenerator<std::optional<ExecBatch>> MakeDelayedGen(
Iterator<std::optional<ExecBatch>> src, std::string label, double delay_sec,
bool noisy = false);
// \brief Make a delaying source that is optionally noisy (prints when it emits)
//
// This overload draws its batches from another AsyncGenerator.
AsyncGenerator<std::optional<ExecBatch>> MakeDelayedGen(
AsyncGenerator<std::optional<ExecBatch>> src, std::string label, double delay_sec,
bool noisy = false);
// \brief Make a delaying source that is optionally noisy (prints when it emits)
//
// This overload draws its batches from a BatchesWithSchema.
AsyncGenerator<std::optional<ExecBatch>> MakeDelayedGen(BatchesWithSchema src,
std::string label,
double delay_sec,
bool noisy = false);
/// A node that slightly resequences the input at random
struct JitterNodeOptions : public ExecNodeOptions {
random::SeedType seed;
/// The max amount to add to a node's "cost".
int max_jitter_modifier;
explicit JitterNodeOptions(random::SeedType seed, int max_jitter_modifier = 5)
: seed(seed), max_jitter_modifier(max_jitter_modifier) {}
static constexpr std::string_view kName = "jitter";
};
class GateImpl;
// Object handed to a gated node through GatedNodeOptions. Its state lives
// behind impl_ (GateImpl is only forward-declared in this header).
// NOTE(review): the per-method descriptions below are inferred from the method
// names only - confirm against the implementation.
class Gate {
public:
// Returns a Gate held by a std::shared_ptr.
static std::shared_ptr<Gate> Make();
Gate();
virtual ~Gate();
// Presumably lets all held batches through - TODO confirm
void ReleaseAllBatches();
// Presumably lets a single held batch through - TODO confirm
void ReleaseOneBatch();
// Presumably returns a future that completes when the next batch is released -
// TODO confirm
Future<> WaitForNextReleasedBatch();
private:
ARROW_DISALLOW_COPY_AND_ASSIGN(Gate);
// Implementation object. NOTE(review): ownership and lifetime are not visible in
// this header - confirm in the constructor/destructor definitions.
GateImpl* impl_;
};
// A node that holds all input batches until a given gate is released
struct GatedNodeOptions : public ExecNodeOptions {
explicit GatedNodeOptions(Gate* gate) : gate(gate) {}
Gate* gate;
static constexpr std::string_view kName = "gated";
};
void RegisterTestNodes();
} // namespace acero
} // namespace arrow

View File

@@ -0,0 +1,31 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include "arrow/record_batch.h"
#include "arrow/type_traits.h"
namespace arrow::acero {
// normalize the value to unsigned 64-bits while preserving ordering of values
//
// Declared only for integral T (enforced by the enable_if_t parameter); the
// definition is not part of this header.
template <typename T, enable_if_t<std::is_integral<T>::value, bool> = true>
uint64_t NormalizeTime(T t);
// NOTE(review): presumably reads the value at (col, row) of `batch`, interprets
// it according to `time_type`, and returns it as an unsigned 64-bit value -
// confirm against the definition.
uint64_t GetTime(const RecordBatch* batch, Type::type time_type, int col, uint64_t row);
} // namespace arrow::acero

View File

@@ -0,0 +1,65 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
#include <optional>
#include <string>
#include <vector>
#include "arrow/acero/type_fwd.h"
#include "arrow/acero/visibility.h"
#include "arrow/result.h"
#include "arrow/status.h"
namespace arrow {
namespace acero {
namespace internal {
// Abstract factory for ExecNodes that generate TPC-H data. Instances are
// created with Make(); each table has its own pure virtual method.
class ARROW_ACERO_EXPORT TpchGen {
public:
virtual ~TpchGen() = default;
/**
* \brief Create a factory for nodes that generate TPC-H data
*
* Note: Individual tables will reference each other. It is important that you only
* create a single TpchGen instance for each plan and then you can create nodes for each
* table from that single TpchGen instance. Note: Every batch will be scheduled as a new
* task using the ExecPlan's scheduler.
*
* Defaults: scale_factor = 1.0, batch_size = 4096, no seed (std::nullopt).
*/
static Result<std::unique_ptr<TpchGen>> Make(
ExecPlan* plan, double scale_factor = 1.0, int64_t batch_size = 4096,
std::optional<int64_t> seed = std::nullopt);
// The below methods will create and add an ExecNode to the plan that generates
// data for the desired table. If columns is empty, all columns will be generated.
// The methods return the added ExecNode, which should be used for inputs.
// One method per table: Supplier, Part, PartSupp, Customer, Orders, Lineitem,
// Nation and Region.
virtual Result<ExecNode*> Supplier(std::vector<std::string> columns = {}) = 0;
virtual Result<ExecNode*> Part(std::vector<std::string> columns = {}) = 0;
virtual Result<ExecNode*> PartSupp(std::vector<std::string> columns = {}) = 0;
virtual Result<ExecNode*> Customer(std::vector<std::string> columns = {}) = 0;
virtual Result<ExecNode*> Orders(std::vector<std::string> columns = {}) = 0;
virtual Result<ExecNode*> Lineitem(std::vector<std::string> columns = {}) = 0;
virtual Result<ExecNode*> Nation(std::vector<std::string> columns = {}) = 0;
virtual Result<ExecNode*> Region(std::vector<std::string> columns = {}) = 0;
};
} // namespace internal
} // namespace acero
} // namespace arrow

View File

@@ -0,0 +1,36 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include "arrow/compute/type_fwd.h"
namespace arrow {
namespace acero {
class ExecNode;
class ExecPlan;
class ExecNodeOptions;
class ExecFactoryRegistry;
class QueryContext;
struct QueryOptions;
struct Declaration;
class SinkNodeConsumer;
} // namespace acero
} // namespace arrow

View File

@@ -0,0 +1,184 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <atomic>
#include <cstdint>
#include <optional>
#include <thread>
#include <unordered_map>
#include <vector>
#include "arrow/acero/options.h"
#include "arrow/acero/type_fwd.h"
#include "arrow/buffer.h"
#include "arrow/compute/expression.h"
#include "arrow/compute/util.h"
#include "arrow/memory_pool.h"
#include "arrow/result.h"
#include "arrow/status.h"
#include "arrow/util/bit_util.h"
#include "arrow/util/cpu_info.h"
#include "arrow/util/logging.h"
#include "arrow/util/mutex.h"
#include "arrow/util/thread_pool.h"
#include "arrow/util/type_fwd.h"
namespace arrow {
namespace acero {
// NOTE(review): presumably checks that `inputs` holds `expected_num_inputs`
// nodes usable with `plan`, with `kind_name` identifying the node kind in the
// returned error - confirm against the definition.
ARROW_ACERO_EXPORT
Status ValidateExecNodeInputs(ExecPlan* plan, const std::vector<ExecNode*>& inputs,
int expected_num_inputs, const char* kind_name);
// NOTE(review): presumably builds a Table with the given schema from
// `exec_batches` - confirm against the definition.
ARROW_ACERO_EXPORT
Result<std::shared_ptr<Table>> TableFromExecBatches(
const std::shared_ptr<Schema>& schema, const std::vector<ExecBatch>& exec_batches);
class ARROW_ACERO_EXPORT AtomicCounter {
public:
AtomicCounter() = default;
int count() const { return count_.load(); }
std::optional<int> total() const {
int total = total_.load();
if (total == -1) return {};
return total;
}
// return true if the counter is complete
bool Increment() {
ARROW_DCHECK_NE(count_.load(), total_.load());
int count = count_.fetch_add(1) + 1;
if (count != total_.load()) return false;
return DoneOnce();
}
// return true if the counter is complete
bool SetTotal(int total) {
total_.store(total);
if (count_.load() != total) return false;
return DoneOnce();
}
// return true if the counter has not already been completed
bool Cancel() { return DoneOnce(); }
// return true if the counter has finished or been cancelled
bool Completed() { return complete_.load(); }
private:
// ensure there is only one true return from Increment(), SetTotal(), or Cancel()
bool DoneOnce() {
bool expected = false;
return complete_.compare_exchange_strong(expected, true);
}
std::atomic<int> count_{0}, total_{-1};
std::atomic<bool> complete_{false};
};
// Callable returning a size_t index; holds a table from std::thread::id to
// index (id_to_index_) together with a mutex (mutex_).
// NOTE(review): presumably operator() returns the index assigned to the calling
// thread, assigning one on first use while holding mutex_ - confirm against the
// definition.
class ARROW_ACERO_EXPORT ThreadIndexer {
public:
size_t operator()();
// NOTE(review): presumably the upper bound on the indices handed out - confirm
static size_t Capacity();
private:
// NOTE(review): presumably validates thread_index against Capacity() - confirm
static size_t Check(size_t thread_index);
arrow::util::Mutex mutex_;
std::unordered_map<std::thread::id, size_t> id_to_index_;
};
/// \brief A consumer that collects results into an in-memory table
///
/// Init(), Consume() and Finish() override SinkNodeConsumer and are defined
/// outside this header.
struct ARROW_ACERO_EXPORT TableSinkNodeConsumer : public SinkNodeConsumer {
public:
/// Stores both pointers as given; nothing else happens at construction.
/// NOTE(review): `out` is presumably where the resulting table is written and
/// must outlive the consumer - confirm against Finish().
TableSinkNodeConsumer(std::shared_ptr<Table>* out, MemoryPool* pool)
: out_(out), pool_(pool) {}
Status Init(const std::shared_ptr<Schema>& schema,
BackpressureControl* backpressure_control, ExecPlan* plan) override;
Status Consume(ExecBatch batch) override;
Future<> Finish() override;
private:
// Pointers received by the constructor.
std::shared_ptr<Table>* out_;
MemoryPool* pool_;
// NOTE(review): presumably the schema received by Init() - confirm
std::shared_ptr<Schema> schema_;
// NOTE(review): presumably the batches accumulated by Consume(), with
// consume_mutex_ serializing concurrent calls - confirm
std::vector<std::shared_ptr<RecordBatch>> batches_;
arrow::util::Mutex consume_mutex_;
};
// A SinkNodeConsumer that ignores everything it is given: Init() and Consume()
// return Status::OK() without using their arguments, and Finish() returns a
// Future<> built from Status::OK().
class ARROW_ACERO_EXPORT NullSinkNodeConsumer : public SinkNodeConsumer {
 public:
  // Creates a new consumer held by a std::shared_ptr.
  static std::shared_ptr<NullSinkNodeConsumer> Make() {
    return std::make_shared<NullSinkNodeConsumer>();
  }

  Status Init(const std::shared_ptr<Schema>&, BackpressureControl*,
              ExecPlan* plan) override {
    return Status::OK();
  }

  Status Consume(ExecBatch exec_batch) override { return Status::OK(); }

  Future<> Finish() override { return Status::OK(); }
};
/// Helper that provides tracing functions for an ExecNode
///
/// The class is not a template; it keeps a pointer to the node it was
/// constructed with.
class ARROW_ACERO_EXPORT TracedNode {
public:
// Stores the node pointer as given.
//
// All nodes should call TraceStartProducing or NoteStartProducing exactly once
// Most nodes will be fine with a call to NoteStartProducing since the StartProducing
// call is usually fairly cheap and simply schedules tasks to fetch the actual data.
explicit TracedNode(ExecNode* node) : node_(node) {}
// Create a span to record the StartProducing work
// The returned Scope is [[nodiscard]].
[[nodiscard]] ::arrow::internal::tracing::Scope TraceStartProducing(
std::string extra_details) const;
// Record a call to StartProducing without creating a span
void NoteStartProducing(std::string extra_details) const;
// All nodes should call TraceInputReceived for each batch they receive. This call
// should track the time spent processing the batch. NoteInputReceived is available
// but usually won't be used unless a node is simply adding batches to a trivial queue.
// Create a span to record the InputReceived work
[[nodiscard]] ::arrow::internal::tracing::Scope TraceInputReceived(
const ExecBatch& batch) const;
// Record a call to InputReceived without creating a span
void NoteInputReceived(const ExecBatch& batch) const;
// Create a span to record any "finish" work. This should NOT be called as part of
// InputFinished and many nodes may not need to call this at all. This should be used
// when a node has some extra work that has to be done once it has received all of its
// data. For example, an aggregation node calculating aggregations. This will
// typically be called as a result of InputFinished OR InputReceived.
[[nodiscard]] ::arrow::internal::tracing::Scope TraceFinish() const;
private:
// The node given to the constructor.
ExecNode* node_;
};
} // namespace acero
} // namespace arrow

Some files were not shown because too many files have changed in this diff Show More