import polars as pl
import numpy as np
from plotnine import ggplot, aes, geom_rect, theme_minimal, scale_fill_manual, theme, element_blank
from typing import List, Tuple
from enum import Enum
class MondrianColour(Enum):
    """Fill colours available for the rectangles, stored as hex strings."""

    BLACK = "#000000"
    YELLOW = "#FDDE06"
    BLUE = "#0300AD"
    RED = "#E70503"
    WHITE = "#ffffff"
class Node:
    """One rectangle in a binary space-partition tree.

    A node with no children is a leaf; once split, ``left`` and ``right``
    hold the two sub-rectangles and ``split_value`` the cut coordinate.
    """

    def __init__(self, depth: int, x_range: Tuple[float, float], y_range: Tuple[float, float]):
        """Create an unsplit node.

        Args:
            depth: Distance of this node from the root of the tree.
            x_range: ``(xmin, xmax)`` extent of the rectangle.
            y_range: ``(ymin, ymax)`` extent of the rectangle.
        """
        self.depth = depth
        self.x_range = x_range
        self.y_range = y_range
        self.left = None  # child node, set when this node is split
        self.right = None  # child node, set when this node is split
        self.split_value = None  # cut coordinate, set when this node is split
        # Orientation of a future split is picked at random from NumPy's
        # global RNG when the node is created.
        self.is_vertical = np.random.choice([True, False])
def generate_tree(node: Node, max_depth: int, min_size: float, force_split: bool = False) -> None:
    """Recursively split ``node`` in place into smaller rectangles.

    The node is cut along the axis chosen by ``node.is_vertical`` (a cut in
    x when true, in y otherwise) at a uniformly random position that leaves
    at least ``min_size`` on both sides, and the two children are then
    processed the same way. The node is left as a leaf when it cannot or
    should not be split.

    Args:
        node: Node to subdivide; ``left``, ``right`` and ``split_value`` are
            assigned when a split happens.
        max_depth: Nodes at this depth or deeper are never split (unless
            ``force_split`` is set for this call).
        min_size: Minimum extent each child must keep along the split axis.
        force_split: Skip the depth limit and the random early stop for this
            node only; it is not propagated to the recursive calls.
    """
    width = node.x_range[1] - node.x_range[0]
    height = node.y_range[1] - node.y_range[0]

    if not force_split:
        # Stop at the depth limit, or with 10% probability for any node
        # deeper than depth 1.
        if node.depth >= max_depth or (np.random.random() < 0.1 and node.depth > 1):
            return

    if width < min_size and height < min_size:
        return

    # The cut is drawn from [lo + min_size, hi - min_size], which is only a
    # valid interval when the extent is at least 2 * min_size. With a smaller
    # extent np.random.uniform would be called with low > high (documented
    # as undefined) and could create children thinner than min_size.
    if node.is_vertical and width >= 2 * min_size:
        node.split_value = np.random.uniform(node.x_range[0] + min_size, node.x_range[1] - min_size)
        node.left = Node(node.depth + 1, (node.x_range[0], node.split_value), node.y_range)
        node.right = Node(node.depth + 1, (node.split_value, node.x_range[1]), node.y_range)
    elif not node.is_vertical and height >= 2 * min_size:
        node.split_value = np.random.uniform(node.y_range[0] + min_size, node.y_range[1] - min_size)
        node.left = Node(node.depth + 1, node.x_range, (node.y_range[0], node.split_value))
        node.right = Node(node.depth + 1, node.x_range, (node.split_value, node.y_range[1]))
    else:
        # Too small along the chosen axis: keep this node as a leaf.
        return

    generate_tree(node.left, max_depth, min_size)
    generate_tree(node.right, max_depth, min_size)
def initial_splits(root: Node, min_size: float) -> None:
    """Force the first two levels of splits below ``root``, in place.

    The root is cut vertically (in x) at a random position, then each of the
    two resulting halves is cut horizontally (in y), so the tree always
    starts with four depth-2 rectangles.

    Args:
        root: Node to split; its ``is_vertical``, ``split_value``, ``left``
            and ``right`` attributes are overwritten.
        min_size: Margin kept between every cut and the edges of the
            rectangle being cut.
    """
    # Vertical split
    root.is_vertical = True
    root.split_value = np.random.uniform(root.x_range[0] + min_size, root.x_range[1] - min_size)
    root.left = Node(1, (root.x_range[0], root.split_value), root.y_range)
    root.right = Node(1, (root.split_value, root.x_range[1]), root.y_range)

    # Horizontal splits: left half first, then right half, so random numbers
    # are drawn in the same order as two written-out copies of this code.
    for half in (root.left, root.right):
        half.is_vertical = False
        half.split_value = np.random.uniform(half.y_range[0] + min_size, half.y_range[1] - min_size)
        half.left = Node(2, half.x_range, (half.y_range[0], half.split_value))
        half.right = Node(2, half.x_range, (half.split_value, half.y_range[1]))
def tree_to_rectangles(node: Node, rectangles: List[dict]) -> None:
    """Append one dict per leaf under ``node`` to ``rectangles``.

    Leaves are visited left to right. Each dict has the keys ``xmin``,
    ``xmax``, ``ymin``, ``ymax`` and ``depth``.

    Args:
        node: Root of the (sub)tree to flatten.
        rectangles: Output list, extended in place.
    """
    if node.left is None and node.right is None:
        rectangles.append({
            'xmin': node.x_range[0],
            'xmax': node.x_range[1],
            'ymin': node.y_range[0],
            'ymax': node.y_range[1],
            'depth': node.depth,
        })
    else:
        tree_to_rectangles(node.left, rectangles)
        tree_to_rectangles(node.right, rectangles)
def draw(seed: int):
    """Build a random Mondrian-style picture and return it as a plot.

    Seeds NumPy's global RNG, partitions the unit square into rectangles
    (two forced levels of splits, then random splitting of the four
    resulting quadrants), gives every leaf rectangle a random colour from
    ``MondrianColour`` and draws them with black outlines.

    Args:
        seed: Seed passed to ``np.random.seed``.

    Returns:
        The plotnine ``ggplot`` object; it is built but not rendered or
        saved here.
    """
    np.random.seed(seed)

    root = Node(0, (0, 1), (0, 1))
    min_size = 0.05
    max_depth = 12

    # Perform initial splits
    initial_splits(root, min_size)

    # Continue generating the tree
    generate_tree(root.left.left, max_depth, min_size)
    generate_tree(root.left.right, max_depth, min_size)
    generate_tree(root.right.left, max_depth, min_size)
    generate_tree(root.right.right, max_depth, min_size)

    rectangles = []
    tree_to_rectangles(root, rectangles)

    palette = [colour.value for colour in MondrianColour]
    colours = pl.Series(name="colour", values=np.random.choice(palette, size=len(rectangles)))
    df = pl.DataFrame(rectangles).with_columns(colours)

    plot = (
        ggplot(df, aes(xmin='xmin', xmax='xmax', ymin='ymin', ymax='ymax', fill='colour'))
        + geom_rect(color='black', size=2)
        + scale_fill_manual(values=palette)
        + theme_minimal()
        + theme(
            legend_position="none",
            aspect_ratio=1,
            axis_text=element_blank(),
            axis_ticks=element_blank(),
            panel_grid=element_blank(),
            figure_size=(10, 10),
        )
    )
    return plot
# Build the picture for a fixed seed so the result is reproducible.
draw(seed=42)
# Mondrian art contest