import polars as pl
import numpy as np
from plotnine import ggplot, aes, geom_rect, theme_minimal, scale_fill_manual, theme, element_blank
from typing import List, Tuple
from enum import Enum
class MondrianColour(Enum):
    """Hex codes of the five fill colours available to the composition.

    Each member's value is the ``#RRGGBB`` string used as the fill colour.
    """

    BLACK = "#000000"
    YELLOW = "#FDDE06"
    BLUE = "#0300AD"
    RED = "#E70503"
    WHITE = "#ffffff"
class Node:
    """One rectangular cell in a binary partition of the canvas.

    Attributes:
        depth: Level of this node in the tree, as supplied by the caller.
        x_range: ``(min, max)`` horizontal extent of the cell.
        y_range: ``(min, max)`` vertical extent of the cell.
        left, right: Child nodes, or ``None`` while the node is a leaf.
        split_value: Coordinate at which the cell was cut, or ``None``.
        is_vertical: Randomly chosen preferred split axis; ``True`` means
            the cut is made along x (at a fixed x coordinate).
    """

    def __init__(self, depth: int, x_range: Tuple[float, float], y_range: Tuple[float, float]):
        # Geometry and tree level are taken verbatim from the caller.
        self.depth, self.x_range, self.y_range = depth, x_range, y_range
        # A fresh node is a leaf: no children and no recorded cut yet.
        self.left = self.right = None
        self.split_value = None
        # Coin flip drawn from NumPy's global RNG, so seeding it makes the
        # orientation reproducible.
        self.is_vertical = np.random.choice([True, False])
def generate_tree(node: "Node", max_depth: int, min_size: float, force_split: bool = False) -> None:
    """Recursively subdivide ``node`` in place along its preferred axis.

    Args:
        node: Cell to subdivide; its ``left``, ``right`` and ``split_value``
            attributes are filled in when a split happens.
        max_depth: Nodes whose ``depth`` is at or beyond this value are not
            split (unless ``force_split`` is set).
        min_size: Minimum extent, along the split axis, of each child.
        force_split: When true, skip the depth limit and the random early
            stop for this node only; the recursive calls do not forward it.

    The node stays a leaf when its extent along the preferred axis cannot
    hold two children of at least ``min_size`` each.
    """
    if not force_split:
        # The ``or`` short-circuits, so no random number is consumed once
        # the depth limit is reached. Below depth 2 the 10% early stop
        # never applies.
        if node.depth >= max_depth or (np.random.random() < 0.1 and node.depth > 1):
            return

    low, high = node.x_range if node.is_vertical else node.y_range

    # The cut is sampled from [low + min_size, high - min_size]. That
    # interval is only well-formed when the extent is at least 2 * min_size;
    # otherwise the bounds are reversed (which NumPy documents as undefined
    # behaviour for ``uniform``) and the children would be smaller than
    # ``min_size``.
    if high - low < 2 * min_size:
        return

    node.split_value = np.random.uniform(low + min_size, high - min_size)
    child_depth = node.depth + 1
    if node.is_vertical:
        node.left = Node(child_depth, (low, node.split_value), node.y_range)
        node.right = Node(child_depth, (node.split_value, high), node.y_range)
    else:
        node.left = Node(child_depth, node.x_range, (low, node.split_value))
        node.right = Node(child_depth, node.x_range, (node.split_value, high))

    generate_tree(node.left, max_depth, min_size)
    generate_tree(node.right, max_depth, min_size)
def _force_split(node: "Node", vertical: bool, min_size: float) -> None:
    """Cut ``node`` along the given axis, overriding its random orientation."""
    node.is_vertical = vertical
    low, high = node.x_range if vertical else node.y_range
    node.split_value = np.random.uniform(low + min_size, high - min_size)
    child_depth = node.depth + 1
    if vertical:
        node.left = Node(child_depth, (low, node.split_value), node.y_range)
        node.right = Node(child_depth, (node.split_value, high), node.y_range)
    else:
        node.left = Node(child_depth, node.x_range, (low, node.split_value))
        node.right = Node(child_depth, node.x_range, (node.split_value, high))


def initial_splits(root: "Node", min_size: float) -> None:
    """Force the first two levels of the partition on ``root``.

    The root is cut once vertically (at a random x), then each half is cut
    once horizontally (at a random y), so ``root`` ends up with four
    grandchildren.

    Args:
        root: Node to subdivide in place.
        min_size: Minimum extent, along the split axis, of each child.

    Raises:
        ValueError: If either dimension of ``root`` is smaller than
            ``2 * min_size``, in which case the forced cuts cannot keep
            both children at least ``min_size`` wide/tall.
    """
    x_extent = root.x_range[1] - root.x_range[0]
    y_extent = root.y_range[1] - root.y_range[0]
    if x_extent < 2 * min_size or y_extent < 2 * min_size:
        # Without this check the sampling bounds would be reversed and the
        # children would silently come out smaller than min_size.
        raise ValueError(
            f"root of size {x_extent} x {y_extent} cannot be split into "
            f"cells of at least min_size={min_size}"
        )

    # Vertical split of the root, then a horizontal split of each half.
    # Left is processed before right, which fixes the order of RNG draws.
    _force_split(root, True, min_size)
    _force_split(root.left, False, min_size)
    _force_split(root.right, False, min_size)
def tree_to_rectangles(node: Node, rectangles: List[dict]) -> None:
    """Collect the leaf cells under ``node`` into ``rectangles``.

    Walks the tree left child first and appends one dict per leaf with the
    keys ``xmin``, ``xmax``, ``ymin``, ``ymax`` and ``depth``. The list is
    modified in place; nothing is returned.
    """
    is_leaf = node.left is None and node.right is None
    if not is_leaf:
        # Interior node: contributes nothing itself, only via its children.
        for child in (node.left, node.right):
            tree_to_rectangles(child, rectangles)
        return

    x_lo, x_hi = node.x_range[0], node.x_range[1]
    y_lo, y_hi = node.y_range[0], node.y_range[1]
    rectangles.append(
        dict(xmin=x_lo, xmax=x_hi, ymin=y_lo, ymax=y_hi, depth=node.depth)
    )
def draw(seed: int) -> "ggplot":
    """Build a random Mondrian-style composition on the unit square.

    Args:
        seed: Seed for NumPy's global RNG; the same seed yields the same
            partition and colour assignment.

    Returns:
        A plotnine ``ggplot`` object with one filled, black-bordered
        rectangle per leaf of the partition tree.
    """
    np.random.seed(seed)
    root = Node(0, (0, 1), (0, 1))
    min_size = 0.05
    max_depth = 12

    # Guarantee a first vertical cut and a horizontal cut in each half.
    initial_splits(root, min_size)

    # Continue subdividing each of the four resulting quadrants at random.
    for half in (root.left, root.right):
        for quadrant in (half.left, half.right):
            generate_tree(quadrant, max_depth, min_size)

    rectangles = []
    tree_to_rectangles(root, rectangles)

    palette = [colour.value for colour in MondrianColour]
    colours = pl.Series(
        name="colour",
        values=np.random.choice(palette, size=len(rectangles)),
    )
    df = pl.DataFrame(rectangles).with_columns(colours)

    # Map every hex code to itself. A plain list would be matched
    # positionally against the sorted fill levels, so a rectangle labelled
    # with one hex code would be painted with another, and the mapping
    # would shift whenever a colour happens not to be sampled.
    fill_mapping = {hex_code: hex_code for hex_code in palette}

    plot = (
        ggplot(df, aes(xmin='xmin', xmax='xmax', ymin='ymin', ymax='ymax', fill='colour'))
        + geom_rect(color='black', size=2)
        + scale_fill_manual(values=fill_mapping)
        + theme_minimal()
        + theme(
            legend_position="none",
            aspect_ratio=1,
            axis_text=element_blank(),
            axis_ticks=element_blank(),
            panel_grid=element_blank(),
            figure_size=(10, 10),
        )
    )
    return plot
# Build an example composition with a fixed seed. The returned plot object is
# not assigned or saved here.
# NOTE(review): presumably this relies on a notebook/REPL auto-displaying the
# last expression — confirm how this file is meant to be run.
draw(seed=42)