#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from bigdl.dllib.utils.common import get_node_and_core_number
from bigdl.dllib.nncontext import init_nncontext
from bigdl.orca import OrcaContext
from bigdl.orca.data import SparkXShards
from bigdl.orca.data.utils import *
from bigdl.dllib.utils.log4Error import invalidInputError
from typing import TYPE_CHECKING, List, Optional
if TYPE_CHECKING:
from bigdl.orca.data.shard import SparkXShards
from pyspark.sql.types import StructType
def read_csv(file_path: str, **kwargs) -> "SparkXShards":
"""
Read csv files to SparkXShards of pandas DataFrames.
:param file_path: A csv file path, a list of multiple csv file paths, or a directory
containing csv files. Local file system, HDFS, and AWS S3 are supported.
:param kwargs: You can specify read_csv options supported by pandas.
:return: An instance of SparkXShards.
"""
return read_file_spark(file_path, "csv", **kwargs)
def read_json(file_path: str, **kwargs) -> "SparkXShards":
"""
Read json files to SparkXShards of pandas DataFrames.
:param file_path: A json file path, a list of multiple json file paths, or a directory
containing json files. Local file system, HDFS, and AWS S3 are supported.
:param kwargs: You can specify read_json options supported by pandas.
:return: An instance of SparkXShards.
"""
return read_file_spark(file_path, "json", **kwargs)
def read_file_spark(file_path: str, file_type: str, **kwargs) -> "SparkXShards":
sc = init_nncontext()
node_num, core_num = get_node_and_core_number()
backend = OrcaContext.pandas_read_backend
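    # With the "pandas" backend, the individual files are listed and read with pandas
    # inside each Spark partition; with the "spark" backend, the data is read through
    # spark.read and each partition is then converted to a pandas DataFrame.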
if backend == "pandas":
        # file_path may be a single path or a list of paths; take the URL scheme
        # (e.g. hdfs, s3) from the first one.
        first_path = file_path[0] if isinstance(file_path, list) else file_path
        prefix = first_path.split("://")[0]
        file_paths = []
        if isinstance(file_path, list):
            for path in file_path:
                file_paths.extend(extract_one_path(path, os.environ))
        else:
            file_paths = extract_one_path(file_path, os.environ)
if not file_paths:
invalidInputError(False,
"The file path is invalid or empty, please check your data")
num_files = len(file_paths)
total_cores = node_num * core_num
num_partitions = num_files if num_files < total_cores else total_cores
rdd = sc.parallelize(file_paths, num_partitions)
if prefix == "hdfs":
pd_rdd = rdd.mapPartitions(
lambda iter: read_pd_hdfs_file_list(iter, file_type, **kwargs))
elif prefix == "s3":
pd_rdd = rdd.mapPartitions(
lambda iter: read_pd_s3_file_list(iter, file_type, **kwargs))
else:
def loadFile(iterator):
dfs = []
for x in iterator:
df = read_pd_file(x, file_type, **kwargs)
dfs.append(df)
import pandas as pd
return [pd.concat(dfs)]
pd_rdd = rdd.mapPartitions(loadFile)
else: # Spark backend; spark.read.csv/json accepts a folder path as input
invalidInputError(file_type == "json" or file_type == "csv",
"Unsupported file type: %s. Only csv and json files are"
" supported for now" % file_type)
spark = OrcaContext.get_spark_session()
        # TODO: add S3 credentials
# The following implementation is adapted from
# https://github.com/databricks/koalas/blob/master/databricks/koalas/namespace.py
# with some modifications.
if "mangle_dupe_cols" in kwargs:
invalidInputError(kwargs["mangle_dupe_cols"],
"mangle_dupe_cols can only be True")
kwargs.pop("mangle_dupe_cols")
if "parse_dates" in kwargs:
invalidInputError(not kwargs["parse_dates"], "parse_dates can only be False")
kwargs.pop("parse_dates")
names = kwargs.get("names", None)
if "names" in kwargs:
kwargs.pop("names")
usecols = kwargs.get("usecols", None)
if "usecols" in kwargs:
kwargs.pop("usecols")
dtype = kwargs.get("dtype", None)
if "dtype" in kwargs:
kwargs.pop("dtype")
squeeze = kwargs.get("squeeze", False)
if "squeeze" in kwargs:
kwargs.pop("squeeze")
index_col = kwargs.get("index_col", None)
if "index_col" in kwargs:
kwargs.pop("index_col")
if file_type == "csv":
# Handle pandas-compatible keyword arguments
kwargs["inferSchema"] = True
header = kwargs.get("header", "infer")
if isinstance(names, str):
kwargs["schema"] = names
if header == "infer":
header = 0 if names is None else None
if header == 0:
kwargs["header"] = True
elif header is None:
kwargs["header"] = False
else:
invalidInputError(False,
"Unknown header argument {}".format(header))
if "quotechar" in kwargs:
quotechar = kwargs["quotechar"]
kwargs.pop("quotechar")
kwargs["quote"] = quotechar
if "escapechar" in kwargs:
escapechar = kwargs["escapechar"]
kwargs.pop("escapechar")
kwargs["escape"] = escapechar
# sep and comment are the same as pandas
if "comment" in kwargs:
comment = kwargs["comment"]
if not isinstance(comment, str) or len(comment) != 1:
invalidInputError(False,
"Only length-1 comment characters supported")
df = spark.read.csv(file_path, **kwargs)
if header is None:
df = df.selectExpr(
*["`%s` as `%s`" % (field.name, i) for i, field in enumerate(df.schema)])
else:
df = spark.read.json(file_path, **kwargs)
# Handle pandas-compatible postprocessing arguments
if usecols is not None and not callable(usecols):
usecols = list(usecols)
renamed = False
if isinstance(names, list):
if len(set(names)) != len(names):
invalidInputError(False,
"Found duplicate names, please check your names input")
if usecols is not None:
if not callable(usecols):
# usecols is list
if len(names) != len(usecols) and len(names) != len(df.schema):
invalidInputError(False,
"Passed names did not match usecols")
if len(names) == len(df.schema):
df = df.selectExpr(
*["`%s` as `%s`" % (field.name, name) for field, name
in zip(df.schema, names)]
)
renamed = True
else:
if len(names) != len(df.schema):
invalidInputError(False,
"The number of names [%s] does not match the number "
"of columns [%d]. Try names by a Spark SQL DDL-formatted "
"string." % (len(names), len(df.schema)))
df = df.selectExpr(
*["`%s` as `%s`" % (field.name, name) for field, name
in zip(df.schema, names)]
)
renamed = True
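        # Keep a mapping from column position to (possibly renamed) column name;
        # it is kept in sync with the usecols selection below and passed on for
        # the pandas conversion.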
index_map = dict([(i, field.name) for i, field in enumerate(df.schema)])
if usecols is not None:
if callable(usecols):
cols = [field.name for field in df.schema if usecols(field.name)]
missing = []
elif all(isinstance(col, int) for col in usecols):
cols = [field.name for i, field in enumerate(df.schema) if i in usecols]
missing = [
col
for col in usecols
if col >= len(df.schema) or df.schema[col].name not in cols
]
elif all(isinstance(col, str) for col in usecols):
cols = [field.name for field in df.schema if field.name in usecols]
if isinstance(names, list):
missing = [c for c in usecols if c not in names]
else:
missing = [col for col in usecols if col not in cols]
else:
invalidInputError(False,
"usecols must only be list-like of all strings, "
"all unicode, all integers or a callable.")
if len(missing) > 0:
invalidInputError(False,
"usecols do not match columns, columns expected but"
" not found: %s" % missing)
if len(cols) > 0:
df = df.select(cols)
if isinstance(names, list):
if not renamed:
df = df.selectExpr(
*["`%s` as `%s`" % (col, name) for col, name in zip(cols, names)]
)
# update index map after rename
for index, col in index_map.items():
if col in cols:
index_map[index] = names[cols.index(col)]
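        # Make sure there are at least as many partitions as nodes before converting
        # the Spark DataFrame to pandas, so the data can be spread across all nodes.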
if df.rdd.getNumPartitions() < node_num:
df = df.repartition(node_num)
from bigdl.orca.data.utils import spark_df_to_rdd_pd
pd_rdd = spark_df_to_rdd_pd(df, squeeze, index_col, dtype, index_map)
try:
data_shards = SparkXShards(pd_rdd, class_name="pandas.core.frame.DataFrame")
except Exception as e:
alternative_backend = "pandas" if backend == "spark" else "spark"
print("An error occurred when reading files with '%s' backend, you may switch to '%s' "
"backend for another try. You can set the backend using "
"OrcaContext.pandas_read_backend" % (backend, alternative_backend))
invalidInputError(False, str(e))
return data_shards
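# Note (illustrative): read_file_spark picks its code path from
# OrcaContext.pandas_read_backend, which is assumed to accept "spark" (the default)
# or "pandas". A minimal sketch of switching it before reading:
#
#     from bigdl.orca import OrcaContext
#     OrcaContext.pandas_read_backend = "pandas"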
def read_parquet(file_path: str,
columns: Optional[List[str]]=None,
schema: Optional["StructType"]=None, **options) -> "SparkXShards":
"""
Read parquet files to SparkXShards of pandas DataFrames.
:param file_path: Parquet file path, a list of multiple parquet file paths, or a directory
containing parquet files. Local file system, HDFS, and AWS S3 are supported.
    :param columns: A list of column names, default=None.
If not None, only these columns will be read from the file.
:param schema: pyspark.sql.types.StructType for the input schema or
           a DDL-formatted string (for example, "col0 INT, col1 DOUBLE").
:param options: other options for reading parquet.
:return: An instance of SparkXShards.
"""
sc = init_nncontext()
spark = OrcaContext.get_spark_session()
# df = spark.read.parquet(file_path)
df = spark.read.load(file_path, "parquet", schema=schema, **options)
if columns:
df = df.select(*columns)
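    # Convert each partition of Spark Row objects into a single pandas DataFrame.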
def to_pandas(columns):
def f(iter):
import pandas as pd
data = list(iter)
pd_df = pd.DataFrame(data, columns=columns)
return [pd_df]
return f
pd_rdd = df.rdd.mapPartitions(to_pandas(df.columns))
try:
data_shards = SparkXShards(pd_rdd)
except Exception as e:
print("An error occurred when reading parquet files")
invalidInputError(False, str(e))
return data_shards
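# Usage sketch (illustrative only, kept as a comment so nothing runs on import):
# reading selected columns from a folder of parquet files with an explicit
# DDL-formatted schema. The path and column names are hypothetical, and the import
# path is assumed to mirror the one documented for read_csv.
#
#     import bigdl.orca.data.pandas
#     shards = bigdl.orca.data.pandas.read_parquet("hdfs://path/to/parquet_folder",
#                                                  columns=["user", "item"],
#                                                  schema="user INT, item INT")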