# -*- coding: utf-8 -*-
# ENGG 2440A / ESTR 2004 Lecture 3 -- python code
# Requires: python scipy package

from scipy.misc import toimage

# plot a tiling for a 2**n by 2**n grid with missing square (mx, my)
def plot_tiling(n, mx, my):
    """Tile a 2**n by 2**n grid with one square left out and display it.

    n      -- exponent of the grid side length (the grid is 2**n by 2**n)
    mx, my -- coordinates of the square that stays uncovered; each must lie
              in the range 0 .. 2**n - 1

    Prints a message and returns without drawing anything when (mx, my)
    is not a square of the grid.
    """
    side = 2**n
    # Valid indices stop at side - 1, so a coordinate equal to side is
    # already outside the grid and must be rejected as well.
    if not (0 <= mx < side and 0 <= my < side):
        print("Empty square is outside the grid")
        return
    g = Grid(n)
    g.tile(mx, my)
    g.plot_tiling()

class Grid:
    """A 2**n by 2**n grid of squares together with a tiling of it.

    Every entry of ``tiling`` holds the number of the tile covering that
    square; 0 means the square is not covered by any tile.
    """

    def __init__(self, n):
        """Create an untiled grid with side length 2**n."""
        self.n = n
        self.N = 2**n
        # tiling[x][y] is the tile number at square (x, y), 0 if uncovered
        self.tiling = [[0] * self.N for i in range(self.N)]
        # pixel value written for every drawn line and for the removed square
        self.white = 256
        # number handed out to the next tile placed by add_tile
        self.next_tile = 1
        # removed square; stays None until tile() has been called
        self.mx = None
        self.my = None

    # compute a tiling of the grid with missing square mx, my
    def tile(self, mx, my):
        """Tile the whole grid, leaving only square (mx, my) uncovered."""
        # remember the missing square so plot_tiling can mark it
        self.mx = mx
        self.my = my

        # tile the grid recursively, starting from a subgrid spanning
        # the entire grid
        Subgrid(self, 0, 0, self.n).tile(mx, my)

    # add a tile with bottom left corner at (x, y) of a given shape
    # rx, ry indicates the removed corner position (0, 0), (0, 1), (1, 0) or (1, 1)
    def add_tile(self, x, y, rx, ry):
        """Place one L-shaped tile inside the 2 by 2 block at (x, y).

        The three squares of the block other than (x + rx, y + ry) are
        labelled with the current tile number, which is then incremented.
        """
        for i in range(0, 2):
            for j in range(0, 2):
                if (i != rx or j != ry):
                    self.tiling[x + i][y + j] = self.next_tile
        self.next_tile = self.next_tile + 1

    # display the computed tiling
    def plot_tiling(self):
        """Draw the tile boundaries and the removed square, then show the image.

        Grids with n > 9 are not drawn; a message is printed instead.
        """
        dim = 1024
        if self.n > 9:
            print("Image is too large to display")
            # nothing was drawn, so there is no image to show
            return

        # number of pixels per grid square; floor division keeps it an
        # integer so that it can be used in range() and as a list index
        ratio = dim // self.N

        # instantiate an array for the image
        image = [[0] * (dim + 1) for i in range(dim + 1)]

        # length of a boundary: one square side plus the closing pixel
        blength = ratio + 1

        # draw a frame
        for t in range(dim + 1):
            image[t][0] = self.white
            image[t][dim] = self.white
            image[0][t] = self.white
            image[dim][t] = self.white

        # draw the horizontal boundaries: a line between (x, y - 1) and
        # (x, y) whenever the two squares belong to different tiles
        for y in range(1, self.N):
            for x in range(0, self.N):
                if self.tiling[x][y - 1] != self.tiling[x][y]:
                    for xpos in range(x * ratio, x * ratio + blength):
                        image[xpos][y * ratio] = self.white

        # draw the vertical boundaries: same test between (x - 1, y) and (x, y)
        for x in range(1, self.N):
            for y in range(0, self.N):
                if self.tiling[x - 1][y] != self.tiling[x][y]:
                    for ypos in range(y * ratio, y * ratio + blength):
                        image[x * ratio][ypos] = self.white

        # mark the square that was removed (only known once tile() has run)
        if self.mx is not None and self.my is not None:
            for xpos in range(self.mx * ratio + 1, self.mx * ratio + ratio):
                for ypos in range(self.my * ratio + 1, self.my * ratio + ratio):
                    image[xpos][ypos] = self.white

        toimage(image).show()
        

class Subgrid(Grid):
    """A square window into a parent grid, used by the recursive tiling.

    Grid.__init__ is not called, so a Subgrid has no tiling array of its
    own: every tile it places is forwarded to its parent with the
    coordinates shifted by (x0, y0).
    """

    # instantiate a subgrid of grid with bottom left corner (x0, y0) and dimension 2**n by 2**n
    def __init__(self, grid, x0, y0, n):
        self.parent = grid
        self.x0 = x0
        self.y0 = y0
        self.n = n

    def tile(self, mx, my):
        """Tile this subgrid, leaving only the local square (mx, my) uncovered.

        mx, my are relative to the bottom left corner of this subgrid.
        """
        if self.n <= 0:
            # a 1 by 1 grid consists of the missing square only; without
            # this guard the recursion would never reach the n == 1 case
            return

        if self.n == 1:
            # 2 by 2 grid: a single tile with corner (mx, my) removed
            self.add_tile(0, 0, mx, my)
            return

        halfN = 2**(self.n - 1)

        # quadrant holding the missing square, which is also the removed
        # corner of the central tile; floor division keeps the result an
        # integer 0 or 1 even under true division
        rx = mx // halfN
        ry = my // halfN

        for i in range(0, 2):
            for j in range(0, 2):
                s = Subgrid(self, halfN * i, halfN * j, self.n - 1)
                if i == rx and j == ry:
                    # tile quadrant containing removed square
                    s.tile(mx - halfN * i, my - halfN * j)
                else:
                    # tile other quadrants, leaving out the corner that
                    # touches the centre of this subgrid
                    s.tile((1 - i) * (halfN - 1), (1 - j) * (halfN - 1))

        # cover central squares left out of the three other quadrants
        self.add_tile(halfN - 1, halfN - 1, rx, ry)

    def add_tile(self, x, y, rx, ry):
        """Forward a tile to the parent, translating to its coordinates."""
        self.parent.add_tile(x + self.x0, y + self.y0, rx, ry)