diff --git a/src/util.zig b/src/util.zig index fc44082..1953e6e 100644 --- a/src/util.zig +++ b/src/util.zig @@ -1,2 +1,4 @@ pub const dtb = @import("util/dtb.zig"); pub const range = @import("util/range.zig"); +pub const btree = @import("util/btree.zig"); +pub const rangemap = @import("util/rangemap.zig"); diff --git a/src/util/btree.zig b/src/util/btree.zig new file mode 100644 index 0000000..99c9010 --- /dev/null +++ b/src/util/btree.zig @@ -0,0 +1,368 @@ +const std = @import("std"); + +const Allocator = std.mem.Allocator; +pub const Order = std.math.Order; + +pub fn CompareFn(comptime N: type) type { + return fn (*const N, *const N) Order; +} + +pub fn SearchFn(comptime N: type, comptime C: type) type { + return fn (*const N, C) Order; +} + +pub fn BTree(comptime N: type, comptime compare_fn: CompareFn(N), comptime deinit_fn: ?fn (*N) void) type { + return struct { + gpa: Allocator, + root: ?*Node = null, + + pub const Error = error{ already_exists, does_not_exist } || Allocator.Error; + + pub fn WalkFn(comptime C: type) type { + return fn (*const Node, C) void; + } + + pub const Iterator = struct { + current: ?*Node, + + pub fn next(self: *Iterator) ?*const Node { + while (self.current) |n| { + const v = n; + + if (n.right) |r| { + // Emit + self.current = Node.leftmost(r); + } else { + var nn = n; + while (nn.parent) |p| { + if (nn == p.right) { + nn = p; + } else { + break; + } + } + self.current = nn.parent; + } + + return v; + } + + return null; + } + }; + + pub const Node = struct { + key: N, + parent: ?*Node = null, + left: ?*Node = null, + right: ?*Node = null, + + fn new(a: Allocator, key: N) Error!*Node { + const node = try a.create(Node); + node.* = .{ + .key = key, + }; + return node; + } + + fn deinit(node: ?*Node, a: Allocator) void { + if (node) |n| { + if (comptime deinit_fn) |f| { + f(&n.key); + } + + Node.deinit(n.left, a); + Node.deinit(n.right, a); + + // Free node itself + a.destroy(n); + } + } + + fn insert(node: ?*Node, a: Allocator, key: N) Error!struct { *Node, *Node } { + if (node) |n| { + const ord = compare_fn(&n.key, &key); + var inserted: *Node = undefined; + switch (ord) { + .lt => { + const child, inserted = try Node.insert(n.right, a, key); + child.parent = n; + n.right = child; + }, + .gt => { + const child, inserted = try Node.insert(n.left, a, key); + child.parent = n; + n.left = child; + }, + .eq => return error.already_exists, + } + return .{ n, inserted }; + } else { + const n = try Node.new(a, key); + return .{ n, n }; + } + } + + fn remove_node(node: *Node, a: Allocator, destroy: bool) ?*Node { + if (node.left == null) { + // Only right/none + const tmp = node.right; + if (tmp) |t| { + t.parent = node.parent; + } + // Destroy the node + if (comptime deinit_fn) |f| { + if (destroy) { + f(&node.key); + } + } + a.destroy(node); + return tmp; + } + if (node.right == null) { + // Only left/none + const tmp = node.left; + if (tmp) |t| { + t.parent = node.parent; + } + // Destroy the node + if (comptime deinit_fn) |f| { + if (destroy) { + f(&node.key); + } + } + a.destroy(node); + return tmp; + } + + // Both + var successor = node.right; + while (successor) |succ| { + if (succ.left) |l| { + successor = l; + } else { + break; + } + } + if (successor) |succ| { + node.key = succ.key; + node.right = Node.remove(node.right, a, succ.key) catch unreachable; + } + return node; + } + + fn remove(node: ?*Node, a: Allocator, key: N) Error!?*Node { + if (node) |n| { + const ord = compare_fn(&n.key, &key); + switch (ord) { + .lt => n.right = try Node.remove(n.right, a, key), + .gt => n.left = try Node.remove(n.left, a, key), + .eq => return Node.remove_node(n, a, true), + } + return node; + } else { + return error.does_not_exist; + } + } + + fn walk(node: ?*Node, ctx: anytype, walk_fn: WalkFn(@TypeOf(ctx))) void { + if (node) |n| { + Node.walk(n.left, ctx, walk_fn); + walk_fn(n, ctx); + Node.walk(n.right, ctx, walk_fn); + } + } + + fn leftmost(node: ?*Node) ?*Node { + var n = node; + while (n) |nn| { + if (nn.left == null) { + break; + } + n = nn.left; + } + return n; + } + }; + + pub fn new(a: std.mem.Allocator) @This() { + return .{ .gpa = a }; + } + + pub fn deinit(self: *@This()) void { + Node.deinit(self.root, self.gpa); + } + + pub fn iterator(self: *@This()) Iterator { + return .{ .current = Node.leftmost(self.root) }; + } + + pub fn insert(self: *@This(), key: N) Error!*Node { + self.root, const inserted = try Node.insert(self.root, self.gpa, key); + return inserted; + } + + pub fn remove(self: *@This(), key: N) Error!void { + self.root = try Node.remove(self.root, self.gpa, key); + } + + pub fn remove_node(self: *@This(), node: *Node, destroy: bool) Error!void { + if (node.parent) |p| { + // Non-root node + const np = Node.remove_node(node, self.gpa, destroy); + if (np) |npp| { + npp.parent = p; + } + if (node == p.right) { + p.right = np; + } else { + p.left = np; + } + } else { + // Root node + const np = Node.remove_node(node, self.gpa, destroy); + if (np) |npp| { + npp.parent = null; + } + self.root = np; + } + } + + pub fn lookup(self: *const @This(), key: N) ?*Node { + const search_fn = struct { + fn call(n: *const N, cx: N) Order { + return compare_fn(n, &cx); + } + }.call; + + return self.search(key, search_fn); + } + + pub fn search( + self: *const @This(), + ctx: anytype, + search_fn: SearchFn(N, @TypeOf(ctx)), + ) ?*Node { + var node = self.root; + while (node) |n| { + const ord = search_fn(&n.key, ctx); + switch (ord) { + .gt => node = n.left, + .eq => return n, + .lt => node = n.right, + } + } + return null; + } + + pub fn walk(self: *@This(), ctx: anytype, walk_fn: WalkFn(@TypeOf(ctx))) void { + Node.walk(self.root, ctx, walk_fn); + } + }; +} + +test "BTree insertion/removal" { + const int_compare_fn = struct { + fn call(a: *const u32, b: *const u32) Order { + if (a.* > b.*) { + return .gt; + } else if (a.* == b.*) { + return .eq; + } else { + return .lt; + } + } + }.call; + const Tree = BTree(u32, int_compare_fn, null); + var tree = Tree.new(std.testing.allocator); + defer tree.deinit(); + + for (50..100) |i| { + _ = try tree.insert(@truncate(i)); + } + for (1..50) |i| { + _ = try tree.insert(@truncate(i)); + } + + for (1..100) |i| { + const k = @as(u32, @truncate(i)); + try std.testing.expectEqual(k, tree.lookup(k).?.key); + } + + for (1..100) |i| { + const k = 100 - @as(u32, @truncate(i)); + if (i % 2 == 0) { + try tree.remove(k); + } + } + + for (1..100) |i| { + const k = @as(u32, @truncate(i)); + if (i % 2 == 0) { + try std.testing.expectEqual(null, tree.lookup(k)); + } else { + try std.testing.expectEqual(k, tree.lookup(k).?.key); + } + } +} + +test "BTree removal by node" { + const int_compare_fn = struct { + fn call(a: *const u32, b: *const u32) Order { + if (a.* > b.*) { + return .gt; + } else if (a.* == b.*) { + return .eq; + } else { + return .lt; + } + } + }.call; + const Tree = BTree(u32, int_compare_fn, null); + var tree = Tree.new(std.testing.allocator); + defer tree.deinit(); + + _ = try tree.insert(10); + _ = try tree.insert(11); + _ = try tree.insert(12); + + { + const n = tree.lookup(10).?; + try tree.remove_node(n, true); + } + + try std.testing.expectEqual(null, tree.lookup(10)); + try std.testing.expectEqual(12, tree.lookup(12).?.key); + try std.testing.expectEqual(11, tree.lookup(11).?.key); +} + +test "BTree iterator" { + const int_compare_fn = struct { + fn call(a: *const u32, b: *const u32) Order { + if (a.* > b.*) { + return .gt; + } else if (a.* == b.*) { + return .eq; + } else { + return .lt; + } + } + }.call; + const Tree = BTree(u32, int_compare_fn, null); + var tree = Tree.new(std.testing.allocator); + defer tree.deinit(); + + for (50..100) |i| { + _ = try tree.insert(@truncate(i)); + } + for (1..50) |i| { + _ = try tree.insert(@truncate(i)); + } + + var it = tree.iterator(); + for (1..100) |i| { + const n = it.next().?; + try std.testing.expectEqual(i, n.key); + } + try std.testing.expectEqual(null, it.next()); +} diff --git a/src/util/range.zig b/src/util/range.zig index 9ec441a..34e3ffa 100644 --- a/src/util/range.zig +++ b/src/util/range.zig @@ -1,5 +1,7 @@ //! Utilities for manipulating ranges. +const std = @import("std"); + /// Non-inclusive range type over `T`. pub fn Range(comptime T: type) type { return struct { @@ -29,5 +31,19 @@ pub fn Range(comptime T: type) type { return null; } + + pub fn contains(self: *const @This(), scalar: T) bool { + return scalar >= self.start and scalar - self.start < self.len; + } + + pub fn compare_disjoint(a: *const @This(), b: *const @This()) std.math.Order { + if (a.start >= b.end()) { + return .gt; + } else if (b.start >= a.end()) { + return .lt; + } else { + return .eq; + } + } }; } diff --git a/src/util/rangemap.zig b/src/util/rangemap.zig new file mode 100644 index 0000000..3c958a9 --- /dev/null +++ b/src/util/rangemap.zig @@ -0,0 +1,240 @@ +const std = @import("std"); + +const btree = @import("btree.zig"); + +const Range = @import("range.zig").Range; +const Allocator = std.mem.Allocator; +const BTree = btree.BTree; +pub const Order = btree.Order; + +pub fn RangeMap(comptime K: type, comptime V: type, comptime deinit_fn: ?fn (*V) void) type { + return struct { + pub const Node = struct { + key: Range(K), + value: V, + }; + + pub const WalkFn = fn (*const Node) void; + + pub const Iterator = struct { + inner: Tree.Iterator, + + pub fn next(self: *Iterator) ?*const Node { + if (self.inner.next()) |n| { + return &n.key; + } else { + return null; + } + } + }; + + const Tree = BTree(Node, compare_fn, deinit_node_fn); + + const Error = error{ + scalar_out_of_range, + } || Tree.Error; + + fn compare_fn(a: *const Node, b: *const Node) Order { + return Range(K).compare_disjoint(&a.key, &b.key); + } + + fn deinit_node_fn(n: *Node) void { + if (comptime deinit_fn) |f| { + f(&n.value); + } + } + + btree: Tree, + + pub fn new(gpa: Allocator) @This() { + return .{ .btree = Tree.new(gpa) }; + } + + pub fn deinit(self: *@This()) void { + self.btree.deinit(); + } + + /// Returns the value at a given scalar point, along with the full range it belongs to. + pub fn get_scalar(self: *const @This(), scalar: K) ?*Node { + return if (self.get_scalar_node(scalar)) |n| &n.key else null; + } + + /// Same as `get_scalar()`, but returns the underlying BST node. + pub fn get_scalar_node(self: *const @This(), scalar: K) ?*Tree.Node { + return self.btree.search(scalar, struct { + fn call(n: *const Node, cx: K) Order { + if (n.key.contains(cx)) { + return .eq; + } else if (cx < n.key.start) { + return .gt; + } else { + return .lt; + } + } + }.call); + } + + /// Splits a given node at a scalar point inside its interval. + /// + /// The part of the interval before `at` is considered a "left" half, the remaining + /// part is considered a "right" half. + /// + /// # Note + /// + /// The "right" halve's value after the split is left uninitialized and it is up to the + /// caller to assign a proper value to it. + /// + /// # Errors + /// + /// * `scalar_out_of_range` if the given `at` value is not inside the node's interval. + pub fn split_node( + self: *@This(), + node: *Tree.Node, + at: K, + ) Error!?struct { *Tree.Node, *Tree.Node } { + if (!node.key.key.contains(at)) { + return error.scalar_out_of_range; + } + + const start = node.key.key.start; + const end = node.key.key.end(); + + if (at == start or at == end - 1) { + // Nothing to split here + return null; + } + + const value = node.key.value; + + // Remove the node, don't drop the key + try self.btree.remove_node(node, false); + + const lnode = try self.btree.insert( + .{ .key = .{ .start = start, .len = at - start }, .value = value }, + ); + const rnode = try self.btree.insert( + .{ .key = .{ .start = at, .len = end - at }, .value = undefined }, + ); + + return .{ lnode, rnode }; + } + + /// Maps some range to a value. Returns an error if the requested range crosses another + /// mapped range. + pub fn insert(self: *@This(), start: K, len: K, value: V) Error!*Tree.Node { + return self.btree.insert(.{ + .key = .{ .start = start, .len = len }, + .value = value, + }); + } + + pub fn iterator(self: *@This()) Iterator { + return .{ .inner = self.btree.iterator() }; + } + + pub fn node_iterator(self: *@This()) Tree.Iterator { + return self.btree.iterator(); + } + + pub fn walk(self: *@This(), walk_fn: WalkFn) void { + self.btree.walk(walk_fn, struct { + fn call(n: *const Tree.Node, cx: WalkFn) void { + cx(&n.key); + } + }.call); + } + }; +} + +test "Range map insertion" { + const Map = RangeMap(u32, []const u8, null); + var map = Map.new(std.testing.allocator); + defer map.deinit(); + + _ = try map.insert(10, 10, "Range 2"); + _ = try map.insert(0, 10, "Range 1"); + _ = try map.insert(20, 10, "Range 3"); + + try std.testing.expectError(error.already_exists, map.insert(5, 10, "Invalid range")); + + _ = try map.insert(1000, 10, "Range 4"); +} + +test "Range map get scalar" { + const Map = RangeMap(u32, []const u8, null); + var map = Map.new(std.testing.allocator); + defer map.deinit(); + + _ = try map.insert(10, 10, "Range [10..20)"); + _ = try map.insert(30, 10, "Range [30..40)"); + + { + const n = map.get_scalar(15).?; + try std.testing.expectEqual(10, n.key.start); + try std.testing.expectEqual(20, n.key.end()); + try std.testing.expectEqualStrings("Range [10..20)", n.value); + } + { + const n = map.get_scalar(35).?; + try std.testing.expectEqual(30, n.key.start); + try std.testing.expectEqual(40, n.key.end()); + try std.testing.expectEqualStrings("Range [30..40)", n.value); + } + { + const n = map.get_scalar(30).?; + try std.testing.expectEqual(30, n.key.start); + try std.testing.expectEqual(40, n.key.end()); + try std.testing.expectEqualStrings("Range [30..40)", n.value); + } + try std.testing.expectEqual(null, map.get_scalar(20)); + try std.testing.expectEqual(null, map.get_scalar(21)); + try std.testing.expectEqual(null, map.get_scalar(9)); + try std.testing.expectEqual(null, map.get_scalar(100)); + try std.testing.expectEqual(null, map.get_scalar(40)); + try std.testing.expectEqual(null, map.get_scalar(41)); +} + +test "Range map split" { + const Map = RangeMap(u32, []const u8, null); + var map = Map.new(std.testing.allocator); + defer map.deinit(); + + _ = try map.insert(0x1000, 0x1000, "Range [0x1000..0x2000)"); + + const node = map.get_scalar_node(0x1000).?; + const lnode, const rnode = (try map.split_node(node, 0x1200)).?; + + lnode.key.value = "Left"; + rnode.key.value = "Right"; + + { + const n = map.get_scalar(0x1100).?; + try std.testing.expectEqual(0x1000, n.key.start); + try std.testing.expectEqual(0x200, n.key.len); + try std.testing.expectEqualStrings("Left", n.value); + } + { + const n = map.get_scalar(0x1300).?; + try std.testing.expectEqual(0x1200, n.key.start); + try std.testing.expectEqual(0xE00, n.key.len); + try std.testing.expectEqualStrings("Right", n.value); + } +} + +test "Range map iterator" { + const Map = RangeMap(u32, []const u8, null); + var map = Map.new(std.testing.allocator); + defer map.deinit(); + + _ = try map.insert(0x1000, 0x1000, "Range [0x1000..0x2000)"); + _ = try map.insert(0x2000, 0x1000, "Range [0x2000..0x3000)"); + _ = try map.insert(0x4000, 0x1000, "Range [0x4000..0x5000)"); + _ = try map.insert(0x3000, 0x1000, "Range [0x3000..0x4000)"); + + var it = map.iterator(); + try std.testing.expectEqualStrings("Range [0x1000..0x2000)", it.next().?.value); + try std.testing.expectEqualStrings("Range [0x2000..0x3000)", it.next().?.value); + try std.testing.expectEqualStrings("Range [0x3000..0x4000)", it.next().?.value); + try std.testing.expectEqualStrings("Range [0x4000..0x5000)", it.next().?.value); + try std.testing.expectEqual(null, it.next()); +}