# This file is part of the pyMOR project (http://www.pymor.org).
# Copyright 2013-2020 pyMOR developers and contributors. All rights reserved.
# License: BSD 2-Clause License (http://opensource.org/licenses/BSD-2-Clause)
import os
import shutil
import tempfile
from contextlib import contextmanager
from pathlib import Path

import numpy as np
from scipy.io import loadmat, mmread
from scipy.sparse import issparse

from pymor.core.logger import getLogger

def _loadmat(path, key=None):

        data = loadmat(path, mat_dtype=True)
    except Exception as e:
        raise IOError(e)

    if key:
            return data[key]
        except KeyError:
            raise IOError(f'"{key}" not found in MATLAB file {path}')

    data = [v for v in data.values() if isinstance(v, np.ndarray) or issparse(v)]

    if len(data) == 0:
        raise IOError(f'No matrix data contained in MATLAB file {path}')
    elif len(data) > 1:
        raise IOError(f'More than one matrix object stored in MATLAB file {path}')
        return data[0]

def _mmread(path, key=None):

    if key:
        raise IOError('Cannot specify "key" for Matrix Market file')
        matrix = mmread(path)
        if issparse(matrix):
            matrix = matrix.tocsc()
        return matrix
    except Exception as e:
        raise IOError(e)

def _load(path, key=None):
    data = np.load(path)
    if isinstance(data, dict):
        if key:
                matrix = data[key]
            except KeyError:
                raise IOError(f'"{key}" not found in NPY file {path}')
        elif len(data) == 0:
            raise IOError(f'No data contained in NPY file {path}')
        elif len(data) > 1:
            raise IOError(f'More than one object stored in NPY file {path} for key {key}')
            matrix = next(iter(data.values()))
        matrix = data
    if not isinstance(matrix, np.ndarray) and not issparse(matrix):
        raise IOError(f'Loaded data is not a matrix in NPY file {path}')
    return matrix

def _loadtxt(path, key=None):
    if key:
        raise IOError('Cannot specify "key" for TXT file')
        return np.loadtxt(path)
    except Exception as e:
        raise IOError(e)

[docs]def load_matrix(path, key=None): logger = getLogger('')'Loading matrix from file %s', path) # convert if path is str path = Path(path) if len(path.suffixes[-1]) == 3: extension = path.suffixes[-1].lower() elif path.suffixes[-1].lower() == 'gz' and len(path.suffixes) >= 2 and len(path.suffixes[-2]) == 3: extension = '.'.join(path.suffixes[-2:]).lower() else: extension = None file_format_map = {'mat': ('MATLAB', _loadmat), 'mtx': ('Matrix Market', _mmread), 'mtz.gz': ('Matrix Market', _mmread), 'npy': ('NPY/NPZ', _load), 'npz': ('NPY/NPZ', _load), 'txt': ('Text', _loadtxt)} if extension in file_format_map: file_type, loader = file_format_map[extension] + ' file detected.') return loader(path, key) logger.warning('Could not detect file format. Trying all loaders ...') loaders = [_loadmat, _mmread, _loadtxt, _load] for loader in loaders: try: return loader(path, key) except IOError: pass raise IOError(f'Could not load file {path} (key = {key})')
@contextmanager
def SafeTemporaryFileName(name=None, parent_dir=None):
    """Cross platform safe equivalent of re-opening a NamedTemporaryFile.

    Creates a new temporary directory and yields the path of a file named
    `name` inside it; the file itself is not created. The directory and
    everything written into it is removed when the context exits, also when
    the with-body raises.

    Parameters
    ----------
    name
        Filename component, defaults to ``'temp_file'``.
    parent_dir
        Parent directory of the new temporary directory, defaults to
        ``tempfile.gettempdir()``.
    """
    parent_dir = parent_dir or tempfile.gettempdir()
    name = name or 'temp_file'
    dirname = tempfile.mkdtemp(dir=parent_dir)
    try:
        yield os.path.join(dirname, name)
    finally:
        # without `finally`, an exception in the with-body would be thrown
        # into the generator at `yield` and the directory would leak
        shutil.rmtree(dirname)
def file_owned_by_current_user(filename):
    """Return whether `filename` is owned by the user running this process.

    Where ``os.getuid`` is available, the owner uid reported by ``os.stat``
    is compared with it. Otherwise (Windows) the owner account name is
    looked up via ``win32security`` and compared with ``getpass.getuser()``.
    """
    owner_uid = os.stat(filename).st_uid
    try:
        current_uid = os.getuid()
    except AttributeError:
        # this is actually less secure than above since getuser looks in env for username
        # a portable way to getuid might be in psutil
        from getpass import getuser
        import win32security
        security_info = win32security.GetFileSecurity(filename, win32security.OWNER_SECURITY_INFORMATION)
        owner_name, _, _ = win32security.LookupAccountSid(None, security_info.GetSecurityDescriptorOwner())
        return owner_name == getuser()
    return owner_uid == current_uid