Today’s problem was all about connected components where the nodes were 3D positions, so the first thing I did was to define a data structure for 3D positions:
/-- A point in 3D space with integer coordinates. -/
structure Position3D where
  /-- X coordinate. -/
  x : Int
  /-- Y coordinate. -/
  y : Int
  /-- Z coordinate. -/
  z : Int
To avoid leaving the beautiful, simple world of integers, I defined a method for computing the squared distance between positions, hoping that I wouldn’t need to find the actual distances:
/-- Squared Euclidean distance between two positions.
Staying in `Int` avoids square roots; because `sqrt` is monotone, comparing
squared distances orders pairs exactly as the true distances would. -/
def Position3D.squaredDistance (p1 p2 : Position3D) : Int :=
  let dx := p1.x - p2.x
  let dy := p1.y - p2.y
  let dz := p1.z - p2.z
  dx * dx + dy * dy + dz * dz
I then parsed each position and was on my merry way. The first thing I did was to build an array with each pair of positions and their squared distance, sorted by that distance. I then created a list of HashSets to represent each connected component, manually removing components as I inserted a merged one:
Today’s problem was all about connected components where the nodes were 3D positions, so the first thing I did was to define a data structure for 3D positions:
/-- A point in 3D space with integer coordinates. -/
structure Position3D where
  /-- X coordinate. -/
  x : Int
  /-- Y coordinate. -/
  y : Int
  /-- Z coordinate. -/
  z : Int
To avoid leaving the beautiful, simple world of integers, I defined a method for computing the squared distance between positions, hoping that I wouldn’t need to find the actual distances:
/-- Squared Euclidean distance between two positions.
Staying in `Int` avoids square roots; because `sqrt` is monotone, comparing
squared distances orders pairs exactly as the true distances would. -/
def Position3D.squaredDistance (p1 p2 : Position3D) : Int :=
  let dx := p1.x - p2.x
  let dy := p1.y - p2.y
  let dz := p1.z - p2.z
  dx * dx + dy * dy + dz * dz
I then parsed each position and was on my merry way. The first thing I did was to build an array with each pair of positions and their squared distance, sorted by that distance. I then created a list of HashSets to represent each connected component, manually removing components as I inserted a merged one:
-- part 1
-- Fold over the pairs (closest first), carrying the current collection of
-- circuits (each a HashSet of boxes) and the number of cords still to place.
let (connected, _) := distances.foldl (fun acc dist =>
  match acc with
  -- no cords left: foldl cannot stop early, so just pass the accumulator
  -- through unchanged for every remaining pair
  | ⟨_,0⟩ => acc
  | ⟨circuits,cords⟩ =>
    -- the distance `d` itself is not needed here, only the pair
    let ⟨⟨b1,b2⟩, d⟩ := dist
    -- linear scans for the circuits containing each box;
    -- `get!` panics if a box is not in any circuit
    let b1Circuit := circuits.find? (b1 ∈ ·) |>.get!
    let b2Circuit := circuits.find? (b2 ∈ ·) |>.get!
    -- drop both circuits and append their union (if b1 and b2 already
    -- share a circuit, that same circuit is simply re-added)
    let circuits' := circuits.filter (λ c => b1 ∉ c && b2 ∉ c)
      |>.push (b1Circuit ∪ b2Circuit)
    (circuits', cords - 1)
  ) (circuits, 1000)
-- circuit sizes: sort, reverse, and keep the first three
-- (the three largest, assuming qsort's default ascending order)
let largestThree := connected.map (·.size) |>.qsort |>.reverse.toList.take 3
However, this approach is both ugly and inefficient. And this fold is begging for an early termination, so why not write imperative code with Lean’s imperative features? Behold! Mutation and a for loop:
-- Imperative version of the part 1 merge loop: same merging logic, but
-- `break` stops as soon as the cords run out.
let mut circuits' := circuits
let mut cords := 1000
for ⟨⟨b1, b2⟩, _⟩ in distances do
  -- still a linear scan over all circuits to locate each box;
  -- `get!` panics if a box is not in any circuit
  let b1Circuit := circuits'.find? (b1 ∈ ·) |>.get!
  let b2Circuit := circuits'.find? (b2 ∈ ·) |>.get!
  -- replace the circuits containing b1 and b2 with their union
  circuits' := circuits'.filter (λ c => b1 ∉ c && b2 ∉ c)
    |>.push (b1Circuit ∪ b2Circuit)
  cords := cords - 1
  if cords == 0 then
    break
This linear scan, repeated on every iteration to find the component each box belongs to, is still an eyesore, so let’s use the good old union-find. It turns out there is a union-find implementation in the Batteries library, so we don’t even have to write our own. I did, however, end up adding two helpers to my utility library:
/-- Sizes of all clusters in the union-find, as `(root, size)` pairs sorted
by size, largest first. Equal sizes are ordered by smaller root first, so the
result is deterministic.

Members are counted per root in a single pass with a hash map. This avoids
de-duplicating the roots (`List.eraseDups` is quadratic) and then re-scanning
every element once per root. -/
def Batteries.UnionFind.clusterSizes (self : Batteries.UnionFind) : Array (Nat × Nat) :=
  let counts : Std.HashMap Nat Nat :=
    (Array.range self.size).foldl
      (λ acc i =>
        let root := self.rootD i
        acc.insert root (acc.getD root 0 + 1))
      ∅
  counts.toArray.qsort (λ a b => a.snd > b.snd || (a.snd == b.snd && a.fst < b.fst))
/-- Number of disjoint sets (clusters) in the union-find: the number of
distinct roots among all elements. -/
def Batteries.UnionFind.numClusters (self : Batteries.UnionFind) : Nat := Id.run do
  let mut roots : Std.HashSet Nat := ∅
  for i in [0:self.size] do
    roots := roots.insert (self.rootD i)
  return roots.size
With that in place, our imperative solution for part 1 is much simpler and more efficient:
-- Part 1 with the union-find: each of the first 1000 pairs is merged with
-- `union!`; no per-iteration scan is needed to locate a box's circuit.
let mut circuits' := circuits
let mut cords := 1000
for ⟨⟨b1, b2⟩, _⟩ in distances do
  circuits' := circuits'.union! b1 b2
  cords := cords - 1
  if cords == 0 then
    break
-- `clusterSizes` is sorted largest first, so the first three sizes are the
-- three largest circuits
let largestThree := circuits'.clusterSizes.map (·.snd) |>.toList.take 3
Things I (re-)learned today
- Using mutable variables.
- Using imperative-style `for` loops.
Solution
import Batteries.Data.List
import Batteries.Data.UnionFind
import Aoc
open Aoc
/-- A junction-box position in 3D space with integer coordinates.
Derived instances let positions be printed, compared, hashed, and used with
panicking accessors such as `get!` (which need `Inhabited`). -/
structure Position3D where
  /-- X coordinate. -/
  x : Int
  /-- Y coordinate. -/
  y : Int
  /-- Z coordinate. -/
  z : Int
  deriving Repr, Inhabited, BEq, Hashable
/-- Parse a position written as `"x,y,z"`.

Each coordinate is trimmed, so stray whitespace such as a trailing `\r` from
CRLF input is tolerated. Each coordinate must then be a valid integer.
Any malformed input (wrong number of fields or a non-integer coordinate)
yields `Except.error`. Previously `String.toInt!` would panic and silently
substitute `0`. -/
def Position3D.parse (s : String) : Except String Position3D :=
  let fields := (s.splitOn ",").map (·.trim.toInt?)
  match fields with
  | [some x, some y, some z] => Except.ok { x := x, y := y, z := z }
  | _ => Except.error s!"Invalid Position3D: {s}"
/-- Squared Euclidean distance between two positions.
Staying in `Int` avoids square roots; because `sqrt` is monotone, sorting by
squared distance gives the same order as sorting by true distance. -/
def Position3D.squaredDistance (p1 p2 : Position3D) : Int :=
  let dx := p1.x - p2.x
  let dy := p1.y - p2.y
  let dz := p1.z - p2.z
  dx * dx + dy * dy + dz * dz
/-- Day 8 entry point.

Reads junction-box positions (one `x,y,z` per line) from `Day08/input.txt`
and prints:
* Part 1: the product of the three largest circuit sizes after connecting the
  1000 closest pairs of boxes;
* Part 2: the product of the X coordinates of the two boxes whose connection
  first joins every box into a single circuit. -/
def main : IO Unit := do
  let input <- readLines "Day08/input.txt"
  let boxes <- input |> Array.mapM Position3D.parse |> IO.ofExcept
  -- a union-find data structure representing the circuits (one singleton per box)
  let circuits := boxes.foldl (λ uf _ => uf.push) (Batteries.UnionFind.empty)
  -- every unordered pair of box indices with its squared distance.
  -- Built with nested loops straight into an array: appending each new batch
  -- to a growing list would re-copy all earlier pairs on every step.
  let mut pairs : Array ((Nat × Nat) × Int) := #[]
  for i in [0:boxes.size] do
    let b1 := boxes[i]!
    for j in [i+1:boxes.size] do
      pairs := pairs.push ((i, j), b1.squaredDistance boxes[j]!)
  -- the squared distance between pairs of boxes, from smallest to largest
  let distances := pairs.qsort (·.snd < ·.snd)
  -- part 1
  let mut circuits' := circuits
  let mut cords := 1000
  for ⟨⟨b1, b2⟩, _⟩ in distances do
    circuits' := circuits'.union! b1 b2
    cords := cords - 1
    if cords == 0 then
      break
  -- clusterSizes is sorted largest first
  let largestThree := circuits'.clusterSizes.map (·.snd) |>.toList.take 3
  IO.println s!"Part 1: {largestThree.prod}"
  -- part 2
  -- Track the number of circuits incrementally: a union of two different
  -- roots merges two circuits into one. This avoids recounting every root
  -- after each union.
  circuits' := circuits
  let mut clusters := boxes.size
  let mut lastPair : Option (Nat × Nat) := none
  for ⟨⟨b1, b2⟩, _⟩ in distances do
    if circuits'.rootD b1 != circuits'.rootD b2 then
      circuits' := circuits'.union! b1 b2
      clusters := clusters - 1
      if clusters == 1 then
        lastPair := some (b1, b2)
        break
  match lastPair with
  | some (b1i, b2i) =>
    let b1 := boxes[b1i]!
    let b2 := boxes[b2i]!
    IO.println s!"Part 2: {b1.x * b2.x}"
  | none =>
    -- fewer than two boxes: there is no pair whose connection unites them
    IO.println "Part 2: no connection needed (fewer than two junction boxes)"