Source code for bigdl.orca.data.pandas.preprocessing

#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
from bigdl.dllib.utils.common import get_node_and_core_number
from bigdl.dllib.nncontext import init_nncontext
from bigdl.orca import OrcaContext
from bigdl.orca.data import SparkXShards
from bigdl.orca.data.utils import *
from bigdl.dllib.utils.log4Error import invalidInputError
from typing import TYPE_CHECKING, List, Optional

if TYPE_CHECKING:
    from bigdl.orca.data.shard import SparkXShards
    from pyspark.sql.types import StructType


def read_csv(file_path: str, **kwargs) -> "SparkXShards":
    """
    Read csv files to SparkXShards of pandas DataFrames.

    :param file_path: A csv file path, a list of multiple csv file paths, or a directory
           containing csv files. Local file system, HDFS, and AWS S3 are supported.
    :param kwargs: You can specify read_csv options supported by pandas.
    :return: An instance of SparkXShards.
    """
    return read_file_spark(file_path, "csv", **kwargs)
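
# Illustrative usage sketch (not part of this module): a typical way read_csv might be
# called once an Orca context is running. The file path, read options and the dropna
# transformation below are hypothetical; transform_shard is assumed to be the standard
# SparkXShards API for per-shard pandas operations.
#
#     from bigdl.orca import init_orca_context
#     from bigdl.orca.data.pandas import read_csv
#
#     init_orca_context(cluster_mode="local", cores=4)
#     shards = read_csv("hdfs://path/to/csv_folder", sep=",", header=0)
#     shards = shards.transform_shard(lambda df: df.dropna())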


def read_json(file_path: str, **kwargs) -> "SparkXShards":
    """
    Read json files to SparkXShards of pandas DataFrames.

    :param file_path: A json file path, a list of multiple json file paths, or a directory
           containing json files. Local file system, HDFS, and AWS S3 are supported.
    :param kwargs: You can specify read_json options supported by pandas.
    :return: An instance of SparkXShards.
    """
    return read_file_spark(file_path, "json", **kwargs)
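
# Illustrative usage sketch (not part of this module): read_json mirrors read_csv and
# forwards pandas read_json options; the path and options below are hypothetical.
#
#     shards = read_json("s3://bucket/json_folder", orient="records", lines=True)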


def read_file_spark(file_path: str, file_type: str, **kwargs) -> "SparkXShards":
    """
    Read csv or json files into a SparkXShards of pandas DataFrames, using either the
    "pandas" or the "spark" read backend specified by OrcaContext.pandas_read_backend.
    """
    sc = init_nncontext()
    node_num, core_num = get_node_and_core_number()
    backend = OrcaContext.pandas_read_backend

    if backend == "pandas":
        file_url_splits = file_path.split("://")
        prefix = file_url_splits[0]

        file_paths = []
        if isinstance(file_path, list):
            [file_paths.extend(extract_one_path(path, os.environ)) for path in file_path]
        else:
            file_paths = extract_one_path(file_path, os.environ)

        if not file_paths:
            invalidInputError(False,
                              "The file path is invalid or empty, please check your data")

        num_files = len(file_paths)
        total_cores = node_num * core_num
        num_partitions = num_files if num_files < total_cores else total_cores
        rdd = sc.parallelize(file_paths, num_partitions)

        if prefix == "hdfs":
            pd_rdd = rdd.mapPartitions(
                lambda iter: read_pd_hdfs_file_list(iter, file_type, **kwargs))
        elif prefix == "s3":
            pd_rdd = rdd.mapPartitions(
                lambda iter: read_pd_s3_file_list(iter, file_type, **kwargs))
        else:
            def loadFile(iterator):
                dfs = []
                for x in iterator:
                    df = read_pd_file(x, file_type, **kwargs)
                    dfs.append(df)
                import pandas as pd
                return [pd.concat(dfs)]

            pd_rdd = rdd.mapPartitions(loadFile)
    else:  # Spark backend; spark.read.csv/json accepts a folder path as input
        invalidInputError(file_type == "json" or file_type == "csv",
                          "Unsupported file type: %s. Only csv and json files are"
                          " supported for now" % file_type)
        spark = OrcaContext.get_spark_session()
        # TODO: add S3 credentials
        # The following implementation is adapted from
        # https://github.com/databricks/koalas/blob/master/databricks/koalas/namespace.py
        # with some modifications.

        if "mangle_dupe_cols" in kwargs:
            invalidInputError(kwargs["mangle_dupe_cols"], "mangle_dupe_cols can only be True")
            kwargs.pop("mangle_dupe_cols")
        if "parse_dates" in kwargs:
            invalidInputError(not kwargs["parse_dates"], "parse_dates can only be False")
            kwargs.pop("parse_dates")

        names = kwargs.get("names", None)
        if "names" in kwargs:
            kwargs.pop("names")
        usecols = kwargs.get("usecols", None)
        if "usecols" in kwargs:
            kwargs.pop("usecols")
        dtype = kwargs.get("dtype", None)
        if "dtype" in kwargs:
            kwargs.pop("dtype")
        squeeze = kwargs.get("squeeze", False)
        if "squeeze" in kwargs:
            kwargs.pop("squeeze")
        index_col = kwargs.get("index_col", None)
        if "index_col" in kwargs:
            kwargs.pop("index_col")

        if file_type == "csv":
            # Handle pandas-compatible keyword arguments
            kwargs["inferSchema"] = True
            header = kwargs.get("header", "infer")
            if isinstance(names, str):
                kwargs["schema"] = names
            if header == "infer":
                header = 0 if names is None else None
            if header == 0:
                kwargs["header"] = True
            elif header is None:
                kwargs["header"] = False
            else:
                invalidInputError(False, "Unknown header argument {}".format(header))
            if "quotechar" in kwargs:
                quotechar = kwargs["quotechar"]
                kwargs.pop("quotechar")
                kwargs["quote"] = quotechar
            if "escapechar" in kwargs:
                escapechar = kwargs["escapechar"]
                kwargs.pop("escapechar")
                kwargs["escape"] = escapechar
            # sep and comment are the same as pandas
            if "comment" in kwargs:
                comment = kwargs["comment"]
                if not isinstance(comment, str) or len(comment) != 1:
                    invalidInputError(False, "Only length-1 comment characters supported")
            df = spark.read.csv(file_path, **kwargs)
            if header is None:
                df = df.selectExpr(
                    *["`%s` as `%s`" % (field.name, i) for i, field in enumerate(df.schema)])
        else:
            df = spark.read.json(file_path, **kwargs)

        # Handle pandas-compatible postprocessing arguments
        if usecols is not None and not callable(usecols):
            usecols = list(usecols)

        renamed = False
        if isinstance(names, list):
            if len(set(names)) != len(names):
                invalidInputError(False,
                                  "Found duplicate names, please check your names input")
            if usecols is not None:
                if not callable(usecols):
                    # usecols is list
                    if len(names) != len(usecols) and len(names) != len(df.schema):
                        invalidInputError(False, "Passed names did not match usecols")
                if len(names) == len(df.schema):
                    df = df.selectExpr(
                        *["`%s` as `%s`" % (field.name, name)
                          for field, name in zip(df.schema, names)])
                    renamed = True
            else:
                if len(names) != len(df.schema):
                    invalidInputError(False,
                                      "The number of names [%s] does not match the number "
                                      "of columns [%d]. Try names by a Spark SQL DDL-formatted "
                                      "string." % (len(names), len(df.schema)))
                df = df.selectExpr(
                    *["`%s` as `%s`" % (field.name, name)
                      for field, name in zip(df.schema, names)])
                renamed = True

        index_map = dict([(i, field.name) for i, field in enumerate(df.schema)])
        if usecols is not None:
            if callable(usecols):
                cols = [field.name for field in df.schema if usecols(field.name)]
                missing = []
            elif all(isinstance(col, int) for col in usecols):
                cols = [field.name for i, field in enumerate(df.schema) if i in usecols]
                missing = [
                    col for col in usecols
                    if col >= len(df.schema) or df.schema[col].name not in cols
                ]
            elif all(isinstance(col, str) for col in usecols):
                cols = [field.name for field in df.schema if field.name in usecols]
                if isinstance(names, list):
                    missing = [c for c in usecols if c not in names]
                else:
                    missing = [col for col in usecols if col not in cols]
            else:
                invalidInputError(False,
                                  "usecols must only be list-like of all strings, "
                                  "all unicode, all integers or a callable.")
            if len(missing) > 0:
                invalidInputError(False,
                                  "usecols do not match columns, columns expected but"
                                  " not found: %s" % missing)
            if len(cols) > 0:
                df = df.select(cols)
                if isinstance(names, list):
                    if not renamed:
                        df = df.selectExpr(
                            *["`%s` as `%s`" % (col, name) for col, name in zip(cols, names)])
                        # update index map after rename
                        for index, col in index_map.items():
                            if col in cols:
                                index_map[index] = names[cols.index(col)]

        if df.rdd.getNumPartitions() < node_num:
            df = df.repartition(node_num)

        from bigdl.orca.data.utils import spark_df_to_rdd_pd
        pd_rdd = spark_df_to_rdd_pd(df, squeeze, index_col, dtype, index_map)

    try:
        data_shards = SparkXShards(pd_rdd, class_name="pandas.core.frame.DataFrame")
    except Exception as e:
        alternative_backend = "pandas" if backend == "spark" else "spark"
        print("An error occurred when reading files with '%s' backend, you may switch to '%s' "
              "backend for another try. You can set the backend using "
              "OrcaContext.pandas_read_backend" % (backend, alternative_backend))
        invalidInputError(False, str(e))
    return data_shards
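
# Illustrative sketch (not part of this module): read_file_spark dispatches on
# OrcaContext.pandas_read_backend, so the backend can be selected before calling
# read_csv/read_json; any value other than "pandas" is treated as the Spark backend.
# The path below is hypothetical.
#
#     from bigdl.orca import OrcaContext
#
#     OrcaContext.pandas_read_backend = "pandas"  # read each file with pandas inside a partition
#     # OrcaContext.pandas_read_backend = "spark"  # read via spark.read.csv/json, then convert
#     shards = read_csv("/path/to/csv_folder")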


def read_parquet(file_path: str,
                 columns: Optional[List[str]] = None,
                 schema: Optional["StructType"] = None,
                 **options) -> "SparkXShards":
    """
    Read parquet files to SparkXShards of pandas DataFrames.

    :param file_path: A parquet file path, a list of multiple parquet file paths, or a
           directory containing parquet files. Local file system, HDFS, and AWS S3 are
           supported.
    :param columns: A list of column names, default=None. If not None, only these columns
           will be read from the file.
    :param schema: A pyspark.sql.types.StructType for the input schema or a DDL-formatted
           string (for example "col0 INT, col1 DOUBLE").
    :param options: Other options for reading parquet.
    :return: An instance of SparkXShards.
    """
    sc = init_nncontext()
    spark = OrcaContext.get_spark_session()
    # df = spark.read.parquet(file_path)
    df = spark.read.load(file_path, "parquet", schema=schema, **options)
    if columns:
        df = df.select(*columns)

    def to_pandas(columns):
        def f(iter):
            import pandas as pd
            data = list(iter)
            pd_df = pd.DataFrame(data, columns=columns)
            return [pd_df]

        return f

    pd_rdd = df.rdd.mapPartitions(to_pandas(df.columns))
    try:
        data_shards = SparkXShards(pd_rdd)
    except Exception as e:
        print("An error occurred when reading parquet files")
        invalidInputError(False, str(e))
    return data_shards
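
# Illustrative usage sketch (not part of this module): reading a subset of parquet columns
# into SparkXShards and collecting the resulting pandas DataFrames on the driver. The path
# and column names are hypothetical.
#
#     shards = read_parquet("hdfs://path/to/parquet_folder", columns=["col0", "col1"])
#     first_df = shards.collect()[0]  # a pandas DataFrame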