# Source code for bigdl.orca.learn.mpi.mpi_estimator

# Copyright 2016 The BigDL Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
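# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.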
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import subprocess
import types

import cloudpickle
import numpy as np
from pyspark.rdd import RDD
from pyspark.sql import DataFrame
from torch.utils.data import Dataset, DataLoader

from bigdl.dllib.utils.log4Error import *
from bigdl.dllib.utils.utils import get_node_ip
from bigdl.orca.learn.mpi.mpi_runner import MPIRunner
from bigdl.orca.learn.mpi.utils import *

class MPIEstimator:
    """
    Orca estimator that launches training through MPI.

    The constructor pickles the user-supplied creator functions and ships them,
    together with the training script, to every host handled by the MPI runner.
    ``fit`` then pickles the per-run arguments and starts the script via MPI.
    """

    # NOTE(review): the script file name was lost when this source was extracted
    # (only "/../" and "{}/" survived); "mpi_train.py" is reconstructed -- confirm
    # against the package contents.
    _TRAIN_SCRIPT = "mpi_train.py"

    def __init__(self,
                 model_creator,
                 optimizer_creator,
                 loss_creator,
                 metrics=None,
                 scheduler_creator=None,
                 config=None,
                 init_func=None,  # Init the distributed environment for MPI if any
                 hosts=None,
                 workers_per_node=1,
                 env=None):
        """
        Create Orca MPI Estimator

        :param model_creator: A model creator function that takes the parameter "config" and
               returns a model
        :param optimizer_creator: An optimizer creator function that has two parameters "model"
               and "config" and returns an optimizer.
        :param loss_creator: A creator function to return a loss.
               Default: None if loss computation is not needed.
        :param metrics: One or a list of validation metrics. Function(s) that computes the
               metrics between the output and target tensors are also supported.
        :param scheduler_creator: A scheduler creator function that has two parameters
               "optimizer" and "config" and returns a learning rate scheduler wrapping the
               optimizer. By default a scheduler will take effect automatically every epoch.
               Default: None if no scheduler is needed.
        :param config: A parameter config dict, that plays a role of configuration to create
               model, loss, optimizer, scheduler and data.
               Default: None if no config is needed.
        :param init_func: A function takes the parameter "config" to init the distributed
               environment for MPI if any.
        :param hosts: host information to be run distributedly. It can be None, 'all' or list
               of hostname/ip. If hosts is None, means it runs on single(self) node.
               If hosts is 'all', it will get executor hosts from current Spark Context.
               Default: None.
        :param workers_per_node: The number of workers on each node.
        :param env: Special environment should be passed to MPI environment.
        """
        self.dir = os.getcwd()
        self.mpi_runner = MPIRunner(hosts=hosts, processes_per_node=workers_per_node, env=env)
        with open("saved_mpi_estimator.pkl", "wb") as f:
            cloudpickle.dump((model_creator, optimizer_creator, loss_creator, metrics,
                              scheduler_creator, config, init_func), f)
        self.mpi_runner.scp_file("saved_mpi_estimator.pkl", self.dir)
        # Need to put in the current directory so that the PYTHONPATH is the same.
        train_file = os.path.abspath(__file__ + "/../" + self._TRAIN_SCRIPT)
        # Blocks until the copy finishes. The exit status is deliberately ignored,
        # exactly as the original Popen + waitpid pair did (best-effort copy).
        subprocess.call(["cp", train_file, self.dir])
        self.mpi_runner.scp_file(train_file, self.dir)

    @staticmethod
    def _report_partition_sizes(plasma_meta, kind):
        """
        Print, per node, how many sub-partitions and records were put into plasma.

        Mainly for debugging and confirmation purpose.

        :param plasma_meta: iterable of
               (partition_id, subpartition_id, subpartition_size, object_id, node_ip) tuples.
        :param kind: word used in the message, e.g. "train" or "test".
        """
        per_node = {}  # node_ip -> [subpartition count, record count]
        for _partition_id, _subpartition_id, subpartition_size, _object_id, ip in plasma_meta:
            stats = per_node.setdefault(ip, [0, 0])
            stats[0] += 1
            stats[1] += subpartition_size
        for node, (count, size) in per_node.items():
            print("Node {} has {} subpartitions and {} {} records".format(
                node, count, size, kind))

    def _dataframe_to_creator(self, df, feature_cols, label_cols, object_store_address,
                              batch_size, kind):
        """
        Put the rows of a Spark DataFrame into plasma and return a data creator for them.

        :param df: Spark DataFrame holding the records.
        :param feature_cols: feature column names.
        :param label_cols: label column names.
        :param object_store_address: address returned by ``launch_plasma``.
        :param batch_size: batch size handed to ``plasma_data_creator``.
        :param kind: "train" or "test"; only used in the size report.
        :return: the function produced by ``plasma_data_creator``.
        """
        data_rdd = df.rdd.map(convert_row(feature_cols, label_cols))
        # partition_id, subpartition_id, subpartition_size, object_id, node_ip
        plasma_meta = data_rdd.mapPartitionsWithIndex(
            put_to_plasma(object_store_address)).collect()
        self._report_partition_sizes(plasma_meta, kind)
        data_creator = plasma_data_creator(plasma_meta, object_store_address,
                                           self.mpi_runner.processes_per_node, batch_size)
        data_rdd.unpersist()
        return data_creator

    # Specify feature_cols and label_cols for Spark DataFrame data.
    # Specify train_func or validate_func for customized training and validation logic.
    # Specify train_batches and validate_batches in case of unbalance data.
    # Specify validate_steps to validate periodically. Note that validation would always be
    # triggered at the end of an epoch.
    def fit(self,
            data,
            epochs=1,
            batch_size=32,
            validation_data=None,
            validate_batch_size=32,
            train_func=None,
            validate_func=None,
            train_batches=None,
            validate_batches=None,
            validate_steps=None,
            feature_cols=None,
            label_cols=None,
            mpi_options=None):
        """
        Run distributed training through MPI.

        :param data: An instance of a Spark DataFrame or a function that takes config as
               argument and returns a PyTorch DataLoader for training.
        :param epochs: The number of epochs to train the model. Default is 1.
        :param batch_size: Batch size on each workers used for training. Default is 32.
               If your training data is a function, you can set batch_size to be the input
               batch_size of the function for the PyTorch DataLoader.
        :param validation_data: validation data. Validation data type should be the same
               as train data.
        :param validate_batch_size: Each worker's batch size for validation. Default is 32.
               If your training data is a function, you can set batch_size to be the input
               batch_size of the function for the PyTorch DataLoader
        :param train_func: Specific training loop to take parameters "config", "epochs",
               "model", "train_ld", "train_batches", "optimizer", "loss", "scheduler",
               "validate_func", "valid_ld", "metrics", "validate_batches" and
               "validate_steps".
               Default: None to use our default training loop
        :param validate_func: Specific validate function.
               Default: None to use our default validation function.
        :param train_batches: Specify train_batches in case of unbalance data.
               Default: None to train the whole train data
        :param validate_batches: Specify validate_batches in case of unbalance data.
               Default: None to validate the whole validation data
        :param validate_steps: Specify validate_steps to validate periodically.
               Note that validation would always be triggered at the end of an epoch.
        :param feature_cols: Specify the feature column names if data is Spark Dataframe
        :param label_cols: Specify the label column names if data is Spark Dataframe
        :param mpi_options: Specify str of addition mpi options.
        :return:
        """
        # Only shut plasma down if this call actually launched it.
        plasma_started = False
        try:
            if isinstance(data, DataFrame):
                invalidInputError(feature_cols is not None and label_cols is not None,
                                  "feature_cols and label_cols must be provided if data is"
                                  " a Spark DataFrame")
                # TODO: make object store memory configurable?
                object_store_address = self.mpi_runner.launch_plasma(
                    object_store_memory="100g")
                plasma_started = True
                data_creator = self._dataframe_to_creator(
                    data, feature_cols, label_cols, object_store_address, batch_size, "train")
                if validation_data:
                    invalidInputError(isinstance(validation_data, DataFrame),
                                      "expect validation data to be DataFrame")
                    validation_data_creator = self._dataframe_to_creator(
                        validation_data, feature_cols, label_cols, object_store_address,
                        validate_batch_size, "test")
                else:
                    validation_data_creator = None
            else:
                invalidInputError(isinstance(data, types.FunctionType),
                                  "expect data is FunctionType")
                data_creator = data
                if validation_data:
                    invalidInputError(isinstance(validation_data, types.FunctionType),
                                      "expect validaton data is FunctionType")
                    validation_data_creator = validation_data
                else:
                    validation_data_creator = None

            # Fall back to the default loops defined in this module.
            if not train_func:
                train_func = train
            if validation_data_creator and not validate_func:
                validate_func = validate

            with open("mpi_train_data.pkl", "wb") as f:
                cloudpickle.dump((data_creator, epochs, batch_size, validation_data_creator,
                                  validate_batch_size, train_func, validate_func,
                                  train_batches, validate_batches, validate_steps), f)
            self.mpi_runner.scp_file("mpi_train_data.pkl", self.dir)
            self.mpi_runner.run("{}/{}".format(self.dir, self._TRAIN_SCRIPT),
                                mpi_options=mpi_options,
                                pkl_path=self.dir)
        finally:
            # Release the object store even when pickling, copying or the MPI run fails.
            if plasma_started:
                self.mpi_runner.shutdown_plasma()

    def shutdown(self):
        """Shut down the plasma object store managed by the MPI runner."""
        self.mpi_runner.shutdown_plasma()
def convert_row(feature_cols, label_cols):
    """
    Build a row -> (features, labels) transform for the given column names.

    :param feature_cols: names of the feature columns to read from each row.
    :param label_cols: names of the label columns; if empty or None the transform
           returns a 1-tuple ``(features,)``.
    :return: a function taking a row (anything indexable by column name).
    """
    def convert_for_cols(row, cols):
        values = [row[name] for name in cols]
        # A single column is returned bare rather than wrapped in a list.
        return values[0] if len(values) == 1 else values

    def transform(row):
        features = convert_for_cols(row, feature_cols)
        if label_cols:
            return features, convert_for_cols(row, label_cols)
        return features,

    return transform


def put_to_plasma(address, part_size=1000000):
    """
    Build a ``mapPartitionsWithIndex`` function that stores a partition in plasma.

    Records are grouped into sub-partitions of at most ``part_size`` records; each
    group is passed through ``process_records`` and put into the object store.

    :param address: plasma store address to connect to.
    :param part_size: maximum number of records per sub-partition.
    :return: a generator function ``f(index, iterator)`` yielding
             (partition_id, subpartition_id, subpartition_size, object_id, node_ip).
    """
    def f(index, iterator):
        import pyarrow.plasma as plasma
        client = plasma.connect(address)
        buffer = []
        sub_index = 0
        for record in iterator:
            if len(buffer) == part_size:
                # The buffer is full: flush it and start the next one with this record.
                object_id = client.put(process_records(buffer))
                buffer = [record]
                yield index, sub_index, part_size, object_id, get_node_ip()
                sub_index += 1
            else:
                buffer.append(record)
        remain_size = len(buffer)
        object_id = client.put(process_records(buffer)) if remain_size > 0 else None
        # Disconnect before the final yield, as the consumer may stop iterating there.
        client.disconnect()
        if remain_size > 0:
            yield index, sub_index, remain_size, object_id, get_node_ip()

    return f


class PlasmaNDArrayDataset(Dataset):
    """
    Dataset that serves whole batches out of plasma sub-partitions on this node.

    Each ``__getitem__`` call returns one batch (not one record).
    NOTE(review): the offset arithmetic in ``__getitem__`` appears to assume batches
    are requested in increasing order starting at 0 -- confirm against callers.
    """

    def __init__(self, meta_data, object_store_address, workers_per_node=1, batch_size=1):
        """
        :param meta_data: list of
               (partition_id, subpartition_id, subpartition_size, object_id, node_ip).
        :param object_store_address: plasma store address to connect to.
        :param workers_per_node: number of workers sharing this node's sub-partitions.
        :param batch_size: number of records per returned batch.
        """
        import pyarrow.plasma as plasma
        self.client = plasma.connect(object_store_address)
        print("Connected to plasma")
        # All the subpartitions on this node
        node_ip = get_node_ip()
        all_data = [subpartition for subpartition in meta_data if subpartition[4] == node_ip]
        rank = int(os.environ.get("PMI_RANK", 0))
        print("Global rank: ", rank)
        # rank = int(os.environ.get("PMIX_RANK", 0))  # For OpenMPI
        local_rank = rank % workers_per_node
        print("Local rank: ", local_rank)
        # Every worker takes an equal contiguous share; the leftover sub-partitions
        # (fewer than workers_per_node of them) go one each to the lowest local ranks.
        # Slicing explicitly keeps every leftover item, including when the leftover is
        # larger than one share.
        # NOTE(review): assumes the previously used `chunks` helper produced fixed-size
        # chunks; for that case the assignment is unchanged whenever nothing was dropped.
        share = len(all_data) // workers_per_node
        worker_data = all_data[local_rank * share:(local_rank + 1) * share]
        leftover = all_data[workers_per_node * share:]
        if local_rank < len(leftover):
            worker_data.append(leftover[local_rank])
        invalidInputError(len(worker_data) > 0,
                          "no plasma subpartition available for local rank {} on node {}"
                          .format(local_rank, node_ip))
        self.object_ids = [subpartition[3] for subpartition in worker_data]
        self.sizes = [subpartition[2] for subpartition in worker_data]
        print("Data size for worker: ", sum(self.sizes))
        self.batch_size = batch_size
        # offsets[k] is the total number of records in sub-partitions 0..k.
        offsets = []
        running_total = 0
        for size in self.sizes:
            running_total += size
            offsets.append(running_total)
        self.offsets = offsets
        self.current_index = 0  # Current index for object_id; data loaded
        self.load_from_plasma(self.current_index)

    def reset(self):
        """Go back to the first sub-partition and load it."""
        self.current_index = 0
        self.load_from_plasma(self.current_index)

    def load_from_plasma(self, index):
        """Fetch sub-partition ``index`` from plasma and make it the current one."""
        print("Loading {} of size {}".format(self.object_ids[index], self.sizes[index]))
        current_data = self.client.get(self.object_ids[index], timeout_ms=0)
        self.current_x = current_data["x"]
        self.current_y = current_data["y"]
        self.current_offset = self.offsets[index]

    def __len__(self):
        """Number of full batches available to this worker."""
        return sum(self.sizes) // self.batch_size

    def __getitem__(self, i):  # Directly get a batch
        """
        Return batch ``i`` as ``(x, y)``; ``x`` is a list when the stored features are a list.
        """
        if i == 0 and self.current_index != 0:
            self.reset()
        # Records of the current sub-partition that have not been served yet.
        current_available_size = self.current_offset - i * self.batch_size
        x_list = []
        y_list = []
        if current_available_size < self.batch_size:
            if current_available_size != 0:
                # Add all the remaining records into this batch
                x_list.append(index(self.current_x, start=-current_available_size))
                y_list.append(index(self.current_y, start=-current_available_size))
            # Load subsequent file(s) to complete the batch
            remain_size = self.batch_size - current_available_size
            while True:
                self.current_index += 1
                self.load_from_plasma(self.current_index)
                if self.sizes[self.current_index] >= remain_size:
                    x_list.append(index(self.current_x, end=remain_size))
                    y_list.append(index(self.current_y, end=remain_size))
                    break
                # The whole sub-partition is consumed and more records are still needed.
                x_list.append(self.current_x)
                y_list.append(self.current_y)
                remain_size -= self.sizes[self.current_index]
        elif current_available_size == self.batch_size:
            # The current file contains exactly one more batch
            x_list.append(index(self.current_x, start=-current_available_size))
            y_list.append(index(self.current_y, start=-current_available_size))
        else:
            batch_end = -current_available_size + self.batch_size
            x_list.append(index(self.current_x, start=-current_available_size, end=batch_end))
            y_list.append(index(self.current_y, start=-current_available_size, end=batch_end))

        if isinstance(self.current_x, list):
            # Multi-input features: concatenate each input separately.
            x_np = [np.concatenate([x[k] for x in x_list])
                    for k in range(len(self.current_x))]
        else:
            x_np = np.concatenate(x_list)
        y_np = np.concatenate(y_list)
        # Can put collate_fn into train_func if necessary.
        return x_np, y_np


def plasma_data_creator(meta_data, object_store_address, workers_per_node=1, batch_size=1):
    """
    Return a data creator function that builds a DataLoader over plasma data.

    :param meta_data: sub-partition metadata passed to ``PlasmaNDArrayDataset``.
    :param object_store_address: plasma store address.
    :param workers_per_node: number of workers on each node.
    :param batch_size: batch size produced by the dataset itself.
    :return: a function taking ``config`` and returning a DataLoader.
    """
    def create_plasma_dataloader(config):
        dataset = PlasmaNDArrayDataset(meta_data, object_store_address,
                                       workers_per_node, batch_size)
        # TODO: support more options
        # batch_size=None: the dataset already returns whole batches.
        loader = DataLoader(
            dataset,
            batch_size=None,
            shuffle=False,
            collate_fn=None,
        )
        return loader

    return create_plasma_dataloader


def train(config, epochs, model, train_ld, train_batches, optimizer, loss, scheduler,
          validate_func, valid_ld, metrics, validate_batches, validate_steps):
    """
    Default training loop.

    :param config: config dict or None; the optional keys "use_ipex", "bf16" and
           "print_freq" are read here.
    :param epochs: number of epochs to run.
    :param model: model called as ``model(x, y)``.
    :param train_ld: training data loader (must support ``len`` and ``iter``).
    :param train_batches: batches per epoch; None means ``len(train_ld)``.
    :param optimizer: optimizer to step after every batch.
    :param loss: loss called as ``loss(output, y)``.
    :param scheduler: optional scheduler, stepped after every batch.
    :param validate_func: called as
           ``validate_func(config, model, valid_ld, metrics, validate_batches)``.
    :param valid_ld: validation data loader or None to skip validation.
    :param metrics: passed through to ``validate_func``.
    :param validate_batches: passed through to ``validate_func``.
    :param validate_steps: validate every this many steps; None or 0 validates only
           at the end of each epoch.
    """
    import contextlib
    import time

    import torch

    # config may legitimately be None or lack the optional keys.
    cfg = config or {}
    use_ipex = cfg.get("use_ipex", False)
    use_bf16 = cfg.get("bf16", False)
    print_freq = cfg.get("print_freq")
    if train_batches is None:
        # Documented default: train on the whole train data.
        train_batches = len(train_ld)
    amp_context = torch.cpu.amp.autocast if use_bf16 else contextlib.nullcontext

    total_loss = 0
    total_samp = 0
    total_iter = 0
    total_time = 0
    previous_iteration_time = None
    step = 0
    for epoch in range(epochs):
        model.train()
        if use_ipex:
            import intel_extension_for_pytorch as ipex
            # NOTE(review): this re-optimizes the model at the start of every epoch,
            # as the original did -- confirm that is intended.
            if use_bf16:
                model, optimizer = ipex.optimize(model, optimizer=optimizer,
                                                 dtype=torch.bfloat16)
            else:
                model, optimizer = ipex.optimize(model, optimizer=optimizer)
        train_iter = iter(train_ld)
        for j in range(train_batches):
            # Iterate again from the beginning if running out of batches.
            if j > 0 and j % len(train_ld) == 0:
                train_iter = iter(train_ld)
            # Iteration time is measured from the start of one batch to the next.
            current_time = time.time()
            if previous_iteration_time:
                iteration_time = current_time - previous_iteration_time
            else:
                iteration_time = 0
            previous_iteration_time = current_time

            x, y = next(train_iter)
            with amp_context():
                o = model(x, y)
                l = loss(o, y)
                l_np = l.detach().cpu().numpy()
                y_np = y.detach().cpu().numpy()
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            if scheduler:
                scheduler.step()

            batch_samples = y_np.shape[0]
            total_time += iteration_time
            total_loss += l_np * batch_samples
            total_iter += 1
            total_samp += batch_samples
            step += 1

            last_batch = (j + 1 == train_batches)
            should_print = (bool(print_freq) and step % print_freq == 0) or last_batch
            if should_print:
                average_batch_time = 1000.0 * total_time / total_iter
                total_time = 0
                average_loss = total_loss / total_samp
                total_loss = 0
                print(
                    "Finished training it {}/{} of epoch {}, {:.2f} ms/it, ".format(
                        j + 1, train_batches, epoch, average_batch_time)
                    + "loss {:.6f}, ".format(average_loss)
                )
                total_iter = 0
                total_samp = 0

            # validate_steps may be None (the default), so guard before comparing.
            periodic = (bool(validate_steps) and validate_steps > 0
                        and step % validate_steps == 0)
            if valid_ld and (periodic or last_batch):
                validate_func(config, model, valid_ld, metrics, validate_batches)


# TODO: add loss
def validate(config, model, valid_ld, metrics, validate_batches):
    """
    Default validation loop: run the model over validation batches and print metrics.

    :param config: unused here; kept for a uniform validate_func signature.
    :param model: model called as ``model(x, y)``.
    :param valid_ld: validation data loader (must support ``len`` and ``iter``).
    :param metrics: metrics converted through ``Metric.convert_metrics_dict``.
    :param validate_batches: number of batches to run; None means ``len(valid_ld)``.
    :return: dict mapping metric name to its computed value.
    """
    import torch
    from bigdl.orca.learn.metrics import Metric

    model.eval()
    metrics = Metric.convert_metrics_dict(metrics, backend="pytorch")
    if validate_batches is None:
        # Documented default: validate on the whole validation data.
        validate_batches = len(valid_ld)
    valid_iter = iter(valid_ld)
    with torch.no_grad():
        for j in range(validate_batches):
            # Iterate again from the beginning if running out of batches.
            if j > 0 and j % len(valid_ld) == 0:
                valid_iter = iter(valid_ld)
            x, y = next(valid_iter)
            o = model(x, y)
            for metric in metrics.values():
                metric(o, y)
    result = {name: metric.compute() for name, metric in metrics.items()}
    output = "Validation results: " + "".join(
        "{}:{} ".format(name, value) for name, value in result.items())
    print(output)
    return result