import scipy.io
import numpy as np
import cv2
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
import matplotlib.patches as patches
from matplotlib.path import Path

def load_dino(mat_file='dino2.mat'):
    """Read the two dino images and their point correspondences.

    Parameters
    ----------
    mat_file : str
        Path to the MATLAB ``.mat`` file.  It must contain the variables
        ``dino01``, ``dino02``, ``cor1`` and ``cor2``; a missing one raises
        ``KeyError``.

    Returns
    -------
    tuple of numpy.ndarray
        ``(dino_01, dino_02, cor_01, cor_02)``: array copies of ``dino01``,
        ``dino02``, ``cor1`` and ``cor2``, in that order.
    """
    contents = scipy.io.loadmat(mat_file)
    wanted = ('dino01', 'dino02', 'cor1', 'cor2')
    return tuple(np.array(contents[name]) for name in wanted)

def draw_points(img, points):
    """Display ``img`` in grayscale and mark every entry of ``points``.

    Each point is drawn as a matplotlib ``Circle`` patch of radius 10 centred
    at ``(point[0], point[1])``: element 0 is used as the x (column)
    coordinate and element 1 as the y (row) coordinate.  The figure is then
    shown with ``plt.show()``.
    """
    axes = plt.figure(figsize=(20, 10)).add_subplot(111)
    axes.imshow(img, cmap='gray')
    radius = 10
    for pt in points:
        marker = Circle((pt[0], pt[1]), radius)
        axes.add_patch(marker)
    plt.show()

def draw_correspondence(img1, img2, cor1, cor2):
    """Show two images side by side with matching points joined by lines.

    The shorter image is padded on top, by replicating its first row, so
    both images have the same height.  They are then stacked horizontally
    with ``img1`` on the left.  Every point of ``cor1`` / ``cor2`` is drawn
    as a circle (radius 10) on its own half.  The i-th point of ``cor1`` is
    joined to the i-th point of ``cor2`` with a green line.

    Parameters
    ----------
    img1, img2 : numpy.ndarray
        Images to display.  After padding they must be stackable with
        ``np.hstack``.
    cor1, cor2 : sequence
        Points on ``img1`` and ``img2``.  Element 0 of each point is the
        x (column) coordinate and element 1 the y (row) coordinate.  Lines
        are drawn only for the first ``min(len(cor1), len(cor2))`` pairs,
        because ``zip`` stops at the shorter sequence.
    """
    # Pad the shorter image on top so both have equal height.  The padding
    # amount is also the vertical shift applied to that image's points.
    offset = img1.shape[0] - img2.shape[0]
    top1 = max(-offset, 0)
    top2 = max(offset, 0)
    # Points on img2 live in the right half of the stacked image.
    img_shift = img1.shape[1]

    # Signature: copyMakeBorder(src, top, bottom, left, right, borderType).
    # BORDER_REPLICATE (== 1) fills the padding by repeating the edge row,
    # so no fill ``value`` is needed.
    # TODO: without OpenCV, np.pad(img, ((top, 0),) + ((0, 0),) * (img.ndim - 1),
    # mode='edge') gives the same result.
    new_img1 = cv2.copyMakeBorder(img1, top1, 0, 0, 0, cv2.BORDER_REPLICATE)
    new_img2 = cv2.copyMakeBorder(img2, top2, 0, 0, 0, cv2.BORDER_REPLICATE)

    stack_img = np.hstack((new_img1, new_img2))

    fig = plt.figure(figsize=(20, 10))
    ax = fig.add_subplot(111)
    ax.imshow(stack_img, cmap='gray')

    # Mark the points on each half of the stacked image.
    for elem in cor1:
        ax.add_patch(Circle((elem[0], elem[1] + top1), 10))
    for elem in cor2:
        ax.add_patch(Circle((elem[0] + img_shift, elem[1] + top2), 10))

    # Correspondence lines only.  Epipolar lines are drawn with draw_lines.
    codes = [Path.MOVETO, Path.LINETO]
    for elem1, elem2 in zip(cor1, cor2):
        verts = [(elem1[0], elem1[1] + top1),
                 (elem2[0] + img_shift, elem2[1] + top2)]
        path = Path(verts, codes)
        ax.add_patch(patches.PathPatch(path, color='green', lw=2.0))
    plt.show()

def draw_lines(img1, points1, img2, lines):
    """Show points on ``img1`` next to their epipolar lines on ``img2``.

    Parameters
    ----------
    img1 : numpy.ndarray
        Left image.  Each entry of ``points1`` is drawn on it as a circle
        of radius 10.
    points1 : sequence
        Points on ``img1``.  Element 0 is the x (column) coordinate and
        element 1 the y (row) coordinate.
    img2 : numpy.ndarray
        Right image on which the lines are drawn in green.
    lines : sequence
        Lines ``(a, b, c)`` describing ``a*x + b*y + c = 0`` in ``img2``
        pixel coordinates.  Non-vertical lines (``b != 0``) span the image
        width.  Vertical lines (``b == 0``) span the image height.
        Degenerate entries with ``a == b == 0`` are skipped.
    """
    fig = plt.figure(figsize=(20, 10))
    ax1, ax2 = fig.add_subplot(1, 2, 1), fig.add_subplot(1, 2, 2)

    # Plot points on img1
    ax1.imshow(img1, cmap='gray')
    for elem in points1:
        ax1.add_patch(Circle((elem[0], elem[1]), 10))

    # Plot corresponding epipolar lines on img2
    ax2.imshow(img2, cmap='gray')
    height, width = img2.shape[0], img2.shape[1]
    codes = [Path.MOVETO, Path.LINETO]
    for line in lines:
        a, b, c = line[0], line[1], line[2]
        if b != 0:
            # y = -(a*x + c) / b, evaluated at the left and right image edges.
            verts = [(0, -c / b), (width, -(a * width + c) / b)]
        elif a != 0:
            # Vertical line x = -c / a: the slope form would divide by zero.
            x = -c / a
            verts = [(x, 0), (x, height)]
        else:
            # a == b == 0 does not describe a line; nothing to draw.
            continue
        path = Path(verts, codes)
        ax2.add_patch(patches.PathPatch(path, color='green', lw=2.0))
    plt.show()
