# Source code for bigdl.nano.pytorch.patching.encryption_patching.encryption_patching

#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


_torch_encryption_patch = None
is_encryption_patched = False


def patch_encryption():
    """
    patch_encryption is used to patch torch.save and torch.load methods to
    replace the original ones.

    Patched details include:

    | 1. torch.save is now located at bigdl.nano.pytorch.encryption.save
    | 2. torch.load is now located at bigdl.nano.pytorch.encryption.load

    A key argument is added to torch.save and torch.load which is used to
    encrypt/decrypt the content before saving/loading it to/from disk.

    Before being replaced, the original ``torch.save`` and ``torch.load``
    are kept on the ``torch`` module as ``torch.old_save`` and
    ``torch.old_load``.

    Once the patches have been applied, later calls return immediately
    without re-patching.

    .. note::

       Please note that the key is only secured in Intel SGX mode.
    """
    global is_encryption_patched
    # Guard against double patching: a second pass would otherwise record
    # the already-patched functions as ``old_save``/``old_load``.
    if is_encryption_patched:
        return
    for target, attr_name, replacement in _get_encryption_patch_map():
        setattr(target, attr_name, replacement)
    # Only mark as patched after every attribute has been set successfully.
    is_encryption_patched = True
def _get_encryption_patch_map():
    """
    Build, cache and return the list of patches applied by patch_encryption.

    Each entry is a ``[target, attribute_name, value]`` triple intended for
    ``setattr``:

    | 1. ``torch.old_save`` <- the current ``torch.save``
    | 2. ``torch.old_load`` <- the current ``torch.load``
    | 3. ``torch.save``     <- ``bigdl.nano.pytorch.encryption.save``
    | 4. ``torch.load``     <- ``bigdl.nano.pytorch.encryption.load``

    The same list object is stored in the module-level
    ``_torch_encryption_patch`` and returned.
    """
    global _torch_encryption_patch
    # Imports are deferred to call time, so importing this module alone does
    # not import torch or the encryption helpers.
    import torch
    from bigdl.nano.pytorch.encryption import save, load

    # Capture the current save/load first so they remain reachable after
    # torch.save/torch.load are overwritten.
    backups = [
        [torch, "old_save", torch.save],
        [torch, "old_load", torch.load],
    ]
    overrides = [
        [torch, "save", save],
        [torch, "load", load],
    ]
    _torch_encryption_patch = backups + overrides
    return _torch_encryption_patch