from __future__ import division
import json
import os.path as osp
import numpy as np
from traits.api import HasTraits
try:
import h5py
except ImportError: # pragma: nocover
h5py = None
class OptionalDependencyMissingError(Exception):
    """Signals that an optional dependency (such as ``h5py``) is needed for
    the requested operation but is not installed.
    """
class _NumpyJsonEncoder(json.JSONEncoder):
def default(self, o):
# TODO: Figure out the right way to do this that maintains dtypes
if isinstance(o, np.recarray):
raise RuntimeError("Recarrays are not currently supported when "
"saving to json")
elif isinstance(o, np.ndarray):
return o.tolist()
else:
return json.JSONEncoder.default(self, o)
class Schema(HasTraits):
    """Extension to :class:`HasTraits` to add methods for automatically saving
    and loading typed data.

    Examples
    --------
    Create a new data class::

        import numpy as np
        from traits.api import Array
        from traitschema import Schema

        class Matrix(Schema):
            data = Array(dtype=np.float64)

        matrix = Matrix(data=np.random.random((8, 8)))

    Serialize to HDF5 using :mod:`h5py`::

        matrix.to_hdf("out.h5")

    Load from HDF5::

        matrix_copy = Matrix.from_hdf("out.h5")

    """
    def __init__(self, **kwargs):
        super(Schema, self).__init__(**kwargs)

        # Reject keyword arguments that do not correspond to declared traits
        # so typos fail loudly instead of silently adding attributes.
        traits = self.class_visible_traits()
        for key, value in kwargs.items():
            if key not in traits:
                raise RuntimeError("trait {} is not in {}".format(
                    key, self.__class__.__name__
                ))
            setattr(self, key, value)

    def __str__(self):  # pragma: nocover
        attr_strs = ["{}={}".format(attr, getattr(self, attr))
                     for attr in self.visible_traits()]
        return "<{}({})>".format(self.__class__.__name__,
                                 '\n '.join(attr_strs))

    def __repr__(self):  # pragma: nocover
        return self.__str__()

    def __eq__(self, other):
        """Compare all visible traits for equality, treating array-valued
        traits elementwise.
        """
        for attr in self.visible_traits():
            this = getattr(self, attr)
            that = getattr(other, attr)

            try:
                if this != that:
                    return False
            except AttributeError:
                # ``other`` lacks this trait entirely
                return False
            except ValueError:
                # Elementwise comparison of arrays has an ambiguous truth
                # value. np.array_equal also copes with multi-dimensional
                # arrays and mismatched shapes, which the previous
                # ``all(this == that)`` could not.
                if not np.array_equal(this, that):
                    return False

        return True

    @staticmethod
    def _coerce_string_fields(data, string_kind, string_dtype):
        """Return a copy of record array ``data`` with every field whose
        dtype kind is ``string_kind`` (``'U'`` for unicode, ``'S'`` for
        bytes) converted to ``string_dtype``; all other fields keep their
        original dtype. Used to swap between unicode and bytes fields when
        (de)serializing record arrays to/from HDF5.
        """
        final_dtypes = [
            (field,
             string_dtype if data[field].dtype.kind == string_kind
             else data[field].dtype.str)
            for field in data.dtype.names
        ]
        # astype coerces the string fields automatically
        return data.astype(final_dtypes)

    def to_dict(self):
        """Return all visible traits as a dictionary."""
        return {name: getattr(self, name) for name in self.visible_traits()}

    def save(self, filename):
        """Serialize using the type determined by the file extension.

        Parameters
        ----------
        filename : str
            Full output path.

        Raises
        ------
        KeyError
            If the file extension is not one of ``.npz``, ``.h5``, ``.json``.

        Notes
        -----
        Only default saving options are used, so this method is less flexible
        than using the ``to_xyz`` methods instead.

        """
        func = {
            '.npz': 'to_npz',
            '.h5': 'to_hdf',
            '.json': 'to_json',
        }[osp.splitext(filename)[1]]

        if func != 'to_json':
            getattr(self, func)(filename)
        else:
            # to_json returns a string rather than writing a file
            with open(filename, 'w') as jf:
                jf.write(getattr(self, func)())

    @classmethod
    def load(cls, filename):
        """Counterpart to :meth:`save`."""
        func = {
            '.npz': 'from_npz',
            '.h5': 'from_hdf',
            '.json': 'from_json',
        }[osp.splitext(filename)[1]]

        if func != 'from_json':
            return getattr(cls, func)(filename)
        else:
            # from_json accepts a file-like object directly
            with open(filename, 'r') as jf:
                return getattr(cls, func)(jf)

    def to_npz(self, filename, compress=False):
        """Save in numpy's npz archive format.

        Parameters
        ----------
        filename : str
        compress : bool
            Save as a compressed archive (default: False)

        Notes
        -----
        To ensure loading of scalar values works as expected, casting traits
        should be used (e.g., ``CStr`` instead of ``String`` or ``Str``). See
        the :mod:`traits` documentation for details.

        """
        save = np.savez_compressed if compress else np.savez
        attrs = self.to_dict()
        save(filename, **attrs)

    @classmethod
    def from_npz(cls, filename):
        """Load data from numpy's npz format.

        Parameters
        ----------
        filename : str

        Returns
        -------
        Deserialized instance

        """
        npz = np.load(filename)
        attrs = {key: value for key, value in npz.items()}
        return cls(**attrs)

    def to_hdf(self, filename, mode='w', compression=None,
               compression_opts=None, encode_string_arrays=True,
               encoding='utf8'):
        """Serialize to HDF5 using :mod:`h5py`.

        Parameters
        ----------
        filename : str
            Path to save HDF5 file to.
        mode : str
            Default: ``'w'``
        compression : str or None
            Compression to use with arrays (see :mod:`h5py` documentation for
            valid choices).
        compression_opts : int or None
            Compression options, generally a number specifying compression level
            (see :mod:`h5py` documentation for details).
        encode_string_arrays : bool
            When True, force encoding of arrays of unicode strings using the
            ``encoding`` keyword argument. Not setting this will result in
            errors if using arrays of unicode strings. Default: True.
        encoding : str
            Encoding to use when forcing encoding of unicode string arrays.
            Default: ``'utf8'``.

        Raises
        ------
        OptionalDependencyMissingError
            If :mod:`h5py` is not installed.

        Notes
        -----
        Each stored dataset will also have a ``desc`` attribute which uses the
        ``desc`` attribute of each trait.

        The root node also has attributes:

        * ``classname`` - the class name of the instance being serialized
        * ``python_module`` - the Python module in which the class is defined

        """
        if h5py is None:  # pragma: nocover
            raise OptionalDependencyMissingError("h5py not found")

        with h5py.File(filename, mode) as hfile:
            for name in self.class_visible_traits():
                trait = self.trait(name)

                data = getattr(self, name)
                if data is None:
                    # If a trait has not been populated, don't try to store it
                    continue

                data_is_recarray = isinstance(data, np.recarray)

                if trait.array is True and encode_string_arrays:
                    # h5py cannot store arrays of unicode directly, so each
                    # element is encoded before being saved.
                    # (The original used ``~data_is_recarray`` — bitwise NOT
                    # on a bool, which is always truthy; fixed to ``not``.)
                    if not data_is_recarray and data.dtype.char == 'U':
                        data = [s.encode(encoding) for s in data]
                    elif data_is_recarray:
                        # Coerce unicode fields to fixed-width bytes so the
                        # record array can be stored.
                        data = self._coerce_string_fields(data, 'U', '<S256')

                chunks = True if trait.array else False
                compression_kwargs = {}
                if chunks:
                    if compression is not None:
                        compression_kwargs['compression'] = compression
                    if compression_opts is not None:
                        compression_kwargs['compression_opts'] = \
                            compression_opts

                dset = hfile.create_dataset('/{}'.format(name),
                                            data=data,
                                            chunks=chunks,
                                            **compression_kwargs)
                # Store the data type as an attribute to make it easier to
                # reconstruct with correct data types
                dset.attrs['type'] = str(type(data))

                if trait.desc is not None:
                    dset.attrs['desc'] = trait.desc

            hfile.attrs['classname'] = self.__class__.__name__
            hfile.attrs['python_module'] = self.__class__.__module__

    @classmethod
    def from_hdf(cls, filename, decode_string_arrays=True, encoding='utf-8'):
        """Deserialize from HDF5 using :mod:`h5py`.

        Parameters
        ----------
        filename : str
        decode_string_arrays : bool
            Arrays of bytes should be decoded into strings
        encoding : str
            Encoding scheme to use for decoding

        Returns
        -------
        Deserialized instance

        Raises
        ------
        OptionalDependencyMissingError
            If :mod:`h5py` is not installed.

        """
        if h5py is None:  # pragma: nocover
            raise OptionalDependencyMissingError("h5py not found")

        self = cls()
        with h5py.File(filename, 'r') as hfile:
            for name in self.visible_traits():
                trait = self.trait(name)
                if name not in hfile:
                    continue
                dset = hfile['/{}'.format(name)]
                # ``dset.value`` was deprecated and removed in h5py 3.0;
                # ``dset[()]`` is the supported equivalent.
                data = dset[()]

                # Use type attribute to determine how to proceed
                data_is_recarray = dset.attrs['type'] == str(np.recarray)

                if trait.array is True and decode_string_arrays:
                    # Decode each element of an array containing bytes
                    # (``not`` fixes the original's bitwise ``~`` on a bool).
                    if not data_is_recarray and data.dtype.char == 'S':
                        data = [s.decode(encoding) for s in data]
                    elif data_is_recarray:
                        # Coerce bytes fields back to unicode.
                        data = self._coerce_string_fields(data, 'S', '<U256')

                setattr(self, name, data)

        return self

    # FIXME: this should optionally write to a file
    def to_json(self, json_kwargs=None):
        """Serialize to JSON.

        Parameters
        ----------
        json_kwargs : dict or None
            Keyword arguments to pass to :func:`json.dumps`.

        Returns
        -------
        JSON string.

        Notes
        -----
        This uses a custom JSON encoder to handle numpy arrays but could
        conceivably lose precision. If this is important, please consider
        serializing in HDF5 format instead. As a consequence of using a custom
        encoder, the ``cls`` keyword argument, if passed, will be ignored.

        """
        data = {name: getattr(self, name) for name in self.visible_traits()}
        # Copy before adding 'cls' so the caller's dict is never mutated.
        # (The original used a mutable default argument and wrote into it,
        # leaking the encoder class to the caller and across calls.)
        kwargs = dict(json_kwargs) if json_kwargs else {}
        kwargs['cls'] = _NumpyJsonEncoder
        return json.dumps(data, **kwargs)

    # FIXME allow filenames
    @classmethod
    def from_json(cls, data):
        """Deserialize from a JSON string or file.

        Parameters
        ----------
        data : str or file-like

        Returns
        -------
        Deserialized instance

        """
        if not isinstance(data, str):
            # file-like object
            loaded = json.load(data)
        else:
            loaded = json.loads(data)

        return cls(**{key: value for key, value in loaded.items()})