Source code for mstc.processing.encoder

"""Components for encoding."""
import logging
import xarray as xr
import tensorflow as tf
import tensorflow_hub as hub
from .core import Component


logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

stream_handler = logging.StreamHandler()
stream_handler.setFormatter(logging.Formatter(
    '%(asctime)s;%(levelname)s;%(message)s',
    '%H:%M:%S'
))
logger.addHandler(stream_handler)


class Encoder(Component):
    """Base class for components that encode data into a xr.DataArray."""

    def __init__(self, attributes=None):
        """
        Initialize the encoder.

        Args:
            attributes (dict): attributes to add to the resulting
                xr.DataArray. Defaults to an empty dict.
        """
        self.attributes = attributes if attributes is not None else {}

    def __call__(self, an_object):
        """
        Encode samples from an object and return the results in a
        xr.DataArray.

        Args:
            an_object (object): an object containing the data to be
                encoded.

        Returns:
            a xr.DataArray.
        """
        raise NotImplementedError
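
# Illustrative sketch (not part of the module): a concrete Encoder
# subclass implements __call__ and propagates self.attributes onto its
# output, for example:
#
#     class IdentityEncoder(Encoder):
#         def __call__(self, data_array):
#             return data_array.assign_attrs(self.attributes)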


class HubEncoder(Encoder):
    """Encode images with a TensorFlow Hub module."""

    def __init__(self, hub_module, batch_size=32, **kwargs):
        """
        Initialize the encoder and build the input pipeline.

        Args:
            hub_module (str): handle of a TensorFlow Hub image module.
            batch_size (int): number of images per batch, defaults
                to 32.
            kwargs (dict): arguments to pass to Encoder as attributes.
        """
        self.batch_size = batch_size
        self.graph = tf.Graph()
        with self.graph.as_default():
            logger.debug("load module")
            self.module = hub.Module(hub_module)
            logger.debug("load module done")
            self.height, self.width = hub.get_expected_image_size(
                self.module
            )
            kwargs.update({'encoded_image_size': (self.height, self.width)})
            super(HubEncoder, self).__init__(attributes=kwargs)
            logger.debug("tf.data pipeline definition")

            def _resize_images(dataset, height=self.height, width=self.width):
                # convert grayscale images to RGB so they fit the module
                dataset = tf.cond(
                    tf.equal(tf.shape(dataset)[-1], 3),
                    lambda: dataset,
                    lambda: tf.image.grayscale_to_rgb(dataset)
                )
                dataset = tf.image.resize_images(dataset, (height, width))
                return dataset

            self.data = tf.placeholder(
                tf.float32, shape=[None, None, None, None]
            )
            dataset = tf.data.Dataset.from_tensor_slices(self.data)
            dataset = dataset.map(_resize_images)
            dataset = dataset.batch(batch_size)  # single batch
            self.iterator = dataset.make_initializable_iterator()
            next_batch = self.iterator.get_next()
            self.features = self.module(next_batch)

    def __call__(self, data_array):
        """
        Encode images with a TensorFlow Hub module.

        The images are resized to fit the module.

        Args:
            data_array (xarray.DataArray): expected dims are sample,
                height, width, channel. Length of channel must be
                1 or 3.

        Returns:
            a xr.DataArray.
        """
        zeroth_dim = data_array.dims[0]
        # vessel to aggregate evaluated batches
        features_arrays = []
        with self.graph.as_default():
            with tf.Session(graph=self.graph) as sess:
                sess.run(tf.global_variables_initializer())
                logger.debug("running batches")
                length = data_array.sizes[zeroth_dim]
                for i in range(0, length, self.batch_size):
                    # handcrafted lazy feeding
                    sess.run(
                        self.iterator.initializer,
                        feed_dict={
                            self.data: data_array[i:i + self.batch_size]
                        }
                    )
                    features_batch_vec = sess.run([self.features])
                    batch_array = xr.DataArray(data=features_batch_vec[0])
                    features_arrays.append(batch_array)
        logger.info('Encoding by tensorflow finished')
        features_array = (
            xr.concat(
                features_arrays, dim='dim_0'
            ).rename(
                {'dim_0': zeroth_dim, 'dim_1': 'hub_feature'}
            ).assign_coords(
                **{zeroth_dim: data_array.coords[zeroth_dim]}
            ).assign_attrs(self.attributes)
        )
        return features_array
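
# Illustrative usage (the module handle below is an assumption; any
# TF Hub image feature-vector module applies):
#
#     encoder = HubEncoder(
#         'https://tfhub.dev/google/imagenet/'
#         'mobilenet_v2_100_224/feature_vector/2',
#         batch_size=16
#     )
#     # images: xr.DataArray with dims (sample, height, width, channel)
#     features = encoder(images)
#     # features: xr.DataArray with dims (sample, hub_feature)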


class Flatten(Encoder):
    """Flatten a xr.DataArray over all dimensions but one."""

    def __init__(self, dim='features', dim_to_keep='', **kwargs):
        """
        Initialize the flattening encoder.

        Args:
            dim (str): name of the dimension generated by flattening,
                defaults to 'features'.
            dim_to_keep (str): name of the dimension to keep, defaults
                to '', which flattens all dimensions but the first.
            kwargs (dict): arguments to pass to Encoder as attributes.
        """
        super(Flatten, self).__init__(attributes=kwargs)
        self.dim = dim
        self.dim_to_keep = dim_to_keep

    def __call__(self, data_array):
        """
        Encode a xr.DataArray by flattening all dimensions but one.

        The kept dimension becomes the first of the generated
        xr.DataArray.

        Args:
            data_array (xr.DataArray): a data array to be flattened.

        Returns:
            a xr.DataArray.
        """
        dimensions = data_array.dims
        # handle the case where the dimension to keep is not provided
        if len(self.dim_to_keep) < 1:
            self.dim_to_keep = dimensions[0]
        # preserve the order of the dimensions for consistency; the
        # same could be achieved with an OrderedSet, but that seemed
        # overkill
        to_flatten = [
            dimension for dimension in dimensions
            if dimension != self.dim_to_keep
        ]
        return data_array.stack(
            {self.dim: to_flatten}
        ).assign_attrs(self.attributes)
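
# Illustrative usage: flatten per-sample feature maps into a single
# feature vector per sample (the dimension names are examples):
#
#     flatten = Flatten(dim='features', dim_to_keep='sample')
#     # data_array: xr.DataArray with dims (sample, height, width)
#     flat = flatten(data_array)
#     # flat: xr.DataArray with dims (sample, features)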