#! /usr/bin/python3
########################################################################
# This file contains several utility functions for visualising,
# creating and managing datasets for Machine Learning applications.
# All datasets are in the form of numpy arrays with each row containing
# one data sample. If applicable, the last column contains class labels.
########################################################################
import numpy as np
import os, gzip, pickle
import cv2
from matplotlib import pyplot as plt
from PIL import Image

########################################################################
# FUNCTION:     pickle_read_data (fname, tvx=True, gz=True, verify=True)
# PARAMETERS:   fname (string) is the name of the file containing
#                       a pickled dataset.
#               tvx (boolean; default:True) indicates if the data in
#                       the file is stored as three separate sets -
#                       training, validation and test sets.
#               gz (boolean; default: True) indicates if the file
#                       is in gzipped format.
#               verify (boolean; default: True) indicates if the shape
#                       of the data arrays after reading the file
#                       should be printed.
# RETURN VALUE: Training, validation and test sets in a tuple of three
#                       tuples. In each tuple, the first element is
#                       a numpy array containing one data sample per row.
#                       The second element is an array of class labels.
#               If tvx=False, then a single tuple with the first element
#                       containing data samples, one per row. The second
#                       element containing class labels.
#               Returns False if the file is not readable. Errors raised
#                       while decompressing or unpickling are not caught
#                       and propagate to the caller.
# SECURITY:     Unpickling can execute arbitrary code. Only read files
#                       from trusted sources.
########################################################################
def pickle_read_data(fname, tvx=True, gz=True, verify=True) :
    if not os.access(fname, os.R_OK) :
        print('Error in Reading ', fname)
        return False

    opener = gzip.open if gz else open
    # 'with' guarantees the file is closed even if unpickling fails.
    # encoding='latin1' is what the pickle docs prescribe for loading
    # numpy arrays that were pickled under Python 2.
    with opener(fname, 'rb') as df :
        data = pickle.load(df, encoding='latin1')

    if tvx :
        td, vd, xd = data
    else :
        td = data

    if verify :
        print('Training Set Shape: ', td[0].shape)
        if tvx :
            print('Validation Set Shape: ', vd[0].shape)
            print('Testing Set Shape: ', xd[0].shape)

    if tvx :
        return td, vd, xd
    return td
#-----------------------------------------------------------------------
########################################################################
# FUNCTION:     tile_images (X, tile_shape, tile_spacing=(0,0),
#                               display=True, save=False,
#                               imgname='tiled_output.png')
# PARAMETERS:   X (numpy array) is the dataset to tile, one data sample
#                       per leading index, i.e. (N, rows, cols) or the
#                       Keras-style (N, rows, cols, 1). Only single-channel
#                       (greyscale) samples are supported. X is not
#                       modified.
#               tile_shape (tuple) is the specification of the tiling
#                       grid. The first is the number of rows and the
#                       second is the number of samples per row for
#                       display. Eg. tile_shape=(10,5) displays 50
#                       samples in 10 rows with 5 samples per row.
#               tile_spacing (tuple; default: (0,0)) is the extra
#                       spacing in pixels between the data samples
#                       being displayed (rows, cols).
#               display (boolean; default: True) indicates if the
#                       tiled output should be displayed with matplotlib.
#               save (boolean; default: False) indicates if the
#                       tiled output should be saved in a file
#                       (via cv2.imwrite).
#               imgname (String; default: tiled_output.png) is the name
#                       of the file in which the tiled output will
#                       be saved.
# NOTES:        A sample is multiplied by 255 before conversion to uint8
#                       if its maximum is below 1, or if X has a floating
#                       dtype and the sample maximum is at most 1 (data
#                       normalised to [0, 1]). The conversion uses
#                       astype('uint8'), so out-of-range values are not
#                       clipped.
# RETURN VALUE: out_img (2-D uint8 numpy array) is the tiled image
#                       created from the specified data samples.
#               False if X holds more samples than the grid has cells.
########################################################################
def tile_images (X, tile_shape, tile_spacing=(0,0),
                 display=True, save=False, imgname='tiled_output.png') :
    n_rows, n_cols = tile_shape[0], tile_shape[1]
    sp_r, sp_c = tile_spacing[0], tile_spacing[1]
    img_shape = (X.shape[1], X.shape[2])
    img_h, img_w = img_shape

    # Not enough grid cells for every sample: fail before doing any work.
    if X.shape[0] > n_rows * n_cols :
        return False

    out_shape = (img_h * n_rows + (n_rows + 1) * sp_r,
                 img_w * n_cols + (n_cols + 1) * sp_c)
    out_img = np.zeros(out_shape, dtype='uint8')
    is_float = np.issubdtype(X.dtype, np.floating)

    for idx in range(X.shape[0]) :
        pos_r, pos_c = divmod(idx, n_cols)
        sample = X[idx]
        peak = sample.max()
        # Rescale [0, 1] data to [0, 255]. Scaling produces a new array,
        # so the caller's X is never modified; keep X's dtype for the
        # scaled values before the final uint8 conversion.
        if peak < 1 or (is_float and peak <= 1) :
            sample = (sample * 255.).astype(X.dtype)
        top = sp_r * (pos_r + 1) + img_h * pos_r
        left = sp_c * (pos_c + 1) + img_w * pos_c
        out_img[top:top + img_h, left:left + img_w] = \
            sample.reshape(img_shape).astype('uint8')

    if display :
        plt.imshow(Image.fromarray(out_img))
        plt.show()
    if save :
        cv2.imwrite(imgname, out_img)

    return out_img
#-----------------------------------------------------------------------
########################################################################
# FUNCTION:     compare_predictions (plab, grd_truth, cat=True)
# PARAMETERS:   plab (numpy array or sequence) holds the predictions,
#                       one sample per row.
#               grd_truth (numpy array or sequence) holds the ground
#                       truth in the same layout as plab.
#               cat (boolean; default: True) indicates categorical rows
#                       (e.g. one-hot labels / per-class scores). The
#                       class of each row is the index of its largest
#                       entry (over all of the row's entries). If False,
#                       plab and grd_truth are compared element-wise and
#                       are expected to be 1-D sequences of class labels.
# RETURN VALUE: numpy array with the indices of the samples whose
#                       predicted class differs from the ground truth.
# RAISES:       ValueError if plab and grd_truth contain a different
#                       number of samples.
########################################################################
def compare_predictions(plab, grd_truth, cat=True) :
    plab = np.asarray(plab)
    grd_truth = np.asarray(grd_truth)
    if len(plab) != len(grd_truth) :
        raise ValueError('plab has %d samples but grd_truth has %d'
                         % (len(plab), len(grd_truth)))

    def _row_argmax(a) :
        # Flatten everything after the sample axis so each row's argmax
        # is taken over all of its entries; np.prod(()) == 1 covers 1-D.
        width = int(np.prod(a.shape[1:]))
        return a.reshape(a.shape[0], width).argmax(axis=1)

    if cat :
        pt = _row_argmax(plab)
        gt = _row_argmax(grd_truth)
    else :
        pt = plab
        gt = grd_truth

    return np.flatnonzero(pt != gt)
