lib: implement RangeMap/BTree

This commit is contained in:
2025-03-20 09:59:26 +02:00
parent 734cd7eb0e
commit a97d79d8ca
4 changed files with 626 additions and 0 deletions
+2
View File
@@ -1,2 +1,4 @@
pub const dtb = @import("util/dtb.zig"); pub const dtb = @import("util/dtb.zig");
pub const range = @import("util/range.zig"); pub const range = @import("util/range.zig");
pub const btree = @import("util/btree.zig");
pub const rangemap = @import("util/rangemap.zig");
+368
View File
@@ -0,0 +1,368 @@
const std = @import("std");
const Allocator = std.mem.Allocator;
pub const Order = std.math.Order;
pub fn CompareFn(comptime N: type) type {
return fn (*const N, *const N) Order;
}
pub fn SearchFn(comptime N: type, comptime C: type) type {
return fn (*const N, C) Order;
}
pub fn BTree(comptime N: type, comptime compare_fn: CompareFn(N), comptime deinit_fn: ?fn (*N) void) type {
return struct {
gpa: Allocator,
root: ?*Node = null,
pub const Error = error{ already_exists, does_not_exist } || Allocator.Error;
pub fn WalkFn(comptime C: type) type {
return fn (*const Node, C) void;
}
pub const Iterator = struct {
current: ?*Node,
pub fn next(self: *Iterator) ?*const Node {
while (self.current) |n| {
const v = n;
if (n.right) |r| {
// Emit
self.current = Node.leftmost(r);
} else {
var nn = n;
while (nn.parent) |p| {
if (nn == p.right) {
nn = p;
} else {
break;
}
}
self.current = nn.parent;
}
return v;
}
return null;
}
};
pub const Node = struct {
key: N,
parent: ?*Node = null,
left: ?*Node = null,
right: ?*Node = null,
fn new(a: Allocator, key: N) Error!*Node {
const node = try a.create(Node);
node.* = .{
.key = key,
};
return node;
}
fn deinit(node: ?*Node, a: Allocator) void {
if (node) |n| {
if (comptime deinit_fn) |f| {
f(&n.key);
}
Node.deinit(n.left, a);
Node.deinit(n.right, a);
// Free node itself
a.destroy(n);
}
}
fn insert(node: ?*Node, a: Allocator, key: N) Error!struct { *Node, *Node } {
if (node) |n| {
const ord = compare_fn(&n.key, &key);
var inserted: *Node = undefined;
switch (ord) {
.lt => {
const child, inserted = try Node.insert(n.right, a, key);
child.parent = n;
n.right = child;
},
.gt => {
const child, inserted = try Node.insert(n.left, a, key);
child.parent = n;
n.left = child;
},
.eq => return error.already_exists,
}
return .{ n, inserted };
} else {
const n = try Node.new(a, key);
return .{ n, n };
}
}
fn remove_node(node: *Node, a: Allocator, destroy: bool) ?*Node {
if (node.left == null) {
// Only right/none
const tmp = node.right;
if (tmp) |t| {
t.parent = node.parent;
}
// Destroy the node
if (comptime deinit_fn) |f| {
if (destroy) {
f(&node.key);
}
}
a.destroy(node);
return tmp;
}
if (node.right == null) {
// Only left/none
const tmp = node.left;
if (tmp) |t| {
t.parent = node.parent;
}
// Destroy the node
if (comptime deinit_fn) |f| {
if (destroy) {
f(&node.key);
}
}
a.destroy(node);
return tmp;
}
// Both
var successor = node.right;
while (successor) |succ| {
if (succ.left) |l| {
successor = l;
} else {
break;
}
}
if (successor) |succ| {
node.key = succ.key;
node.right = Node.remove(node.right, a, succ.key) catch unreachable;
}
return node;
}
fn remove(node: ?*Node, a: Allocator, key: N) Error!?*Node {
if (node) |n| {
const ord = compare_fn(&n.key, &key);
switch (ord) {
.lt => n.right = try Node.remove(n.right, a, key),
.gt => n.left = try Node.remove(n.left, a, key),
.eq => return Node.remove_node(n, a, true),
}
return node;
} else {
return error.does_not_exist;
}
}
fn walk(node: ?*Node, ctx: anytype, walk_fn: WalkFn(@TypeOf(ctx))) void {
if (node) |n| {
Node.walk(n.left, ctx, walk_fn);
walk_fn(n, ctx);
Node.walk(n.right, ctx, walk_fn);
}
}
fn leftmost(node: ?*Node) ?*Node {
var n = node;
while (n) |nn| {
if (nn.left == null) {
break;
}
n = nn.left;
}
return n;
}
};
pub fn new(a: std.mem.Allocator) @This() {
return .{ .gpa = a };
}
pub fn deinit(self: *@This()) void {
Node.deinit(self.root, self.gpa);
}
pub fn iterator(self: *@This()) Iterator {
return .{ .current = Node.leftmost(self.root) };
}
pub fn insert(self: *@This(), key: N) Error!*Node {
self.root, const inserted = try Node.insert(self.root, self.gpa, key);
return inserted;
}
pub fn remove(self: *@This(), key: N) Error!void {
self.root = try Node.remove(self.root, self.gpa, key);
}
pub fn remove_node(self: *@This(), node: *Node, destroy: bool) Error!void {
if (node.parent) |p| {
// Non-root node
const np = Node.remove_node(node, self.gpa, destroy);
if (np) |npp| {
npp.parent = p;
}
if (node == p.right) {
p.right = np;
} else {
p.left = np;
}
} else {
// Root node
const np = Node.remove_node(node, self.gpa, destroy);
if (np) |npp| {
npp.parent = null;
}
self.root = np;
}
}
pub fn lookup(self: *const @This(), key: N) ?*Node {
const search_fn = struct {
fn call(n: *const N, cx: N) Order {
return compare_fn(n, &cx);
}
}.call;
return self.search(key, search_fn);
}
pub fn search(
self: *const @This(),
ctx: anytype,
search_fn: SearchFn(N, @TypeOf(ctx)),
) ?*Node {
var node = self.root;
while (node) |n| {
const ord = search_fn(&n.key, ctx);
switch (ord) {
.gt => node = n.left,
.eq => return n,
.lt => node = n.right,
}
}
return null;
}
pub fn walk(self: *@This(), ctx: anytype, walk_fn: WalkFn(@TypeOf(ctx))) void {
Node.walk(self.root, ctx, walk_fn);
}
};
}
test "BTree insertion/removal" {
const int_compare_fn = struct {
fn call(a: *const u32, b: *const u32) Order {
if (a.* > b.*) {
return .gt;
} else if (a.* == b.*) {
return .eq;
} else {
return .lt;
}
}
}.call;
const Tree = BTree(u32, int_compare_fn, null);
var tree = Tree.new(std.testing.allocator);
defer tree.deinit();
for (50..100) |i| {
_ = try tree.insert(@truncate(i));
}
for (1..50) |i| {
_ = try tree.insert(@truncate(i));
}
for (1..100) |i| {
const k = @as(u32, @truncate(i));
try std.testing.expectEqual(k, tree.lookup(k).?.key);
}
for (1..100) |i| {
const k = 100 - @as(u32, @truncate(i));
if (i % 2 == 0) {
try tree.remove(k);
}
}
for (1..100) |i| {
const k = @as(u32, @truncate(i));
if (i % 2 == 0) {
try std.testing.expectEqual(null, tree.lookup(k));
} else {
try std.testing.expectEqual(k, tree.lookup(k).?.key);
}
}
}
test "BTree removal by node" {
const int_compare_fn = struct {
fn call(a: *const u32, b: *const u32) Order {
if (a.* > b.*) {
return .gt;
} else if (a.* == b.*) {
return .eq;
} else {
return .lt;
}
}
}.call;
const Tree = BTree(u32, int_compare_fn, null);
var tree = Tree.new(std.testing.allocator);
defer tree.deinit();
_ = try tree.insert(10);
_ = try tree.insert(11);
_ = try tree.insert(12);
{
const n = tree.lookup(10).?;
try tree.remove_node(n, true);
}
try std.testing.expectEqual(null, tree.lookup(10));
try std.testing.expectEqual(12, tree.lookup(12).?.key);
try std.testing.expectEqual(11, tree.lookup(11).?.key);
}
test "BTree iterator" {
const int_compare_fn = struct {
fn call(a: *const u32, b: *const u32) Order {
if (a.* > b.*) {
return .gt;
} else if (a.* == b.*) {
return .eq;
} else {
return .lt;
}
}
}.call;
const Tree = BTree(u32, int_compare_fn, null);
var tree = Tree.new(std.testing.allocator);
defer tree.deinit();
for (50..100) |i| {
_ = try tree.insert(@truncate(i));
}
for (1..50) |i| {
_ = try tree.insert(@truncate(i));
}
var it = tree.iterator();
for (1..100) |i| {
const n = it.next().?;
try std.testing.expectEqual(i, n.key);
}
try std.testing.expectEqual(null, it.next());
}
+16
View File
@@ -1,5 +1,7 @@
//! Utilities for manipulating ranges. //! Utilities for manipulating ranges.
const std = @import("std");
/// Non-inclusive range type over `T`. /// Non-inclusive range type over `T`.
pub fn Range(comptime T: type) type { pub fn Range(comptime T: type) type {
return struct { return struct {
@@ -29,5 +31,19 @@ pub fn Range(comptime T: type) type {
return null; return null;
} }
pub fn contains(self: *const @This(), scalar: T) bool {
return scalar >= self.start and scalar - self.start < self.len;
}
pub fn compare_disjoint(a: *const @This(), b: *const @This()) std.math.Order {
if (a.start >= b.end()) {
return .gt;
} else if (b.start >= a.end()) {
return .lt;
} else {
return .eq;
}
}
}; };
} }
+240
View File
@@ -0,0 +1,240 @@
const std = @import("std");
const btree = @import("btree.zig");
const Range = @import("range.zig").Range;
const Allocator = std.mem.Allocator;
const BTree = btree.BTree;
pub const Order = btree.Order;
pub fn RangeMap(comptime K: type, comptime V: type, comptime deinit_fn: ?fn (*V) void) type {
return struct {
pub const Node = struct {
key: Range(K),
value: V,
};
pub const WalkFn = fn (*const Node) void;
pub const Iterator = struct {
inner: Tree.Iterator,
pub fn next(self: *Iterator) ?*const Node {
if (self.inner.next()) |n| {
return &n.key;
} else {
return null;
}
}
};
const Tree = BTree(Node, compare_fn, deinit_node_fn);
const Error = error{
scalar_out_of_range,
} || Tree.Error;
fn compare_fn(a: *const Node, b: *const Node) Order {
return Range(K).compare_disjoint(&a.key, &b.key);
}
fn deinit_node_fn(n: *Node) void {
if (comptime deinit_fn) |f| {
f(&n.value);
}
}
btree: Tree,
pub fn new(gpa: Allocator) @This() {
return .{ .btree = Tree.new(gpa) };
}
pub fn deinit(self: *@This()) void {
self.btree.deinit();
}
/// Returns the value at a given scalar point, along with the full range it belongs to.
pub fn get_scalar(self: *const @This(), scalar: K) ?*Node {
return if (self.get_scalar_node(scalar)) |n| &n.key else null;
}
/// Same as `get_scalar()`, but returns the underlying BST node.
pub fn get_scalar_node(self: *const @This(), scalar: K) ?*Tree.Node {
return self.btree.search(scalar, struct {
fn call(n: *const Node, cx: K) Order {
if (n.key.contains(cx)) {
return .eq;
} else if (cx < n.key.start) {
return .gt;
} else {
return .lt;
}
}
}.call);
}
/// Splits a given node at a scalar point inside its interval.
///
/// The part of the interval before `at` is considered a "left" half, the remaining
/// part is considered a "right" half.
///
/// # Note
///
/// The "right" halve's value after the split is left uninitialized and it is up to the
/// caller to assign a proper value to it.
///
/// # Errors
///
/// * `scalar_out_of_range` if the given `at` value is not inside the node's interval.
pub fn split_node(
self: *@This(),
node: *Tree.Node,
at: K,
) Error!?struct { *Tree.Node, *Tree.Node } {
if (!node.key.key.contains(at)) {
return error.scalar_out_of_range;
}
const start = node.key.key.start;
const end = node.key.key.end();
if (at == start or at == end - 1) {
// Nothing to split here
return null;
}
const value = node.key.value;
// Remove the node, don't drop the key
try self.btree.remove_node(node, false);
const lnode = try self.btree.insert(
.{ .key = .{ .start = start, .len = at - start }, .value = value },
);
const rnode = try self.btree.insert(
.{ .key = .{ .start = at, .len = end - at }, .value = undefined },
);
return .{ lnode, rnode };
}
/// Maps some range to a value. Returns an error if the requested range crosses another
/// mapped range.
pub fn insert(self: *@This(), start: K, len: K, value: V) Error!*Tree.Node {
return self.btree.insert(.{
.key = .{ .start = start, .len = len },
.value = value,
});
}
pub fn iterator(self: *@This()) Iterator {
return .{ .inner = self.btree.iterator() };
}
pub fn node_iterator(self: *@This()) Tree.Iterator {
return self.btree.iterator();
}
pub fn walk(self: *@This(), walk_fn: WalkFn) void {
self.btree.walk(walk_fn, struct {
fn call(n: *const Tree.Node, cx: WalkFn) void {
cx(&n.key);
}
}.call);
}
};
}
test "Range map insertion" {
const Map = RangeMap(u32, []const u8, null);
var map = Map.new(std.testing.allocator);
defer map.deinit();
_ = try map.insert(10, 10, "Range 2");
_ = try map.insert(0, 10, "Range 1");
_ = try map.insert(20, 10, "Range 3");
try std.testing.expectError(error.already_exists, map.insert(5, 10, "Invalid range"));
_ = try map.insert(1000, 10, "Range 4");
}
test "Range map get scalar" {
const Map = RangeMap(u32, []const u8, null);
var map = Map.new(std.testing.allocator);
defer map.deinit();
_ = try map.insert(10, 10, "Range [10..20)");
_ = try map.insert(30, 10, "Range [30..40)");
{
const n = map.get_scalar(15).?;
try std.testing.expectEqual(10, n.key.start);
try std.testing.expectEqual(20, n.key.end());
try std.testing.expectEqualStrings("Range [10..20)", n.value);
}
{
const n = map.get_scalar(35).?;
try std.testing.expectEqual(30, n.key.start);
try std.testing.expectEqual(40, n.key.end());
try std.testing.expectEqualStrings("Range [30..40)", n.value);
}
{
const n = map.get_scalar(30).?;
try std.testing.expectEqual(30, n.key.start);
try std.testing.expectEqual(40, n.key.end());
try std.testing.expectEqualStrings("Range [30..40)", n.value);
}
try std.testing.expectEqual(null, map.get_scalar(20));
try std.testing.expectEqual(null, map.get_scalar(21));
try std.testing.expectEqual(null, map.get_scalar(9));
try std.testing.expectEqual(null, map.get_scalar(100));
try std.testing.expectEqual(null, map.get_scalar(40));
try std.testing.expectEqual(null, map.get_scalar(41));
}
test "Range map split" {
const Map = RangeMap(u32, []const u8, null);
var map = Map.new(std.testing.allocator);
defer map.deinit();
_ = try map.insert(0x1000, 0x1000, "Range [0x1000..0x2000)");
const node = map.get_scalar_node(0x1000).?;
const lnode, const rnode = (try map.split_node(node, 0x1200)).?;
lnode.key.value = "Left";
rnode.key.value = "Right";
{
const n = map.get_scalar(0x1100).?;
try std.testing.expectEqual(0x1000, n.key.start);
try std.testing.expectEqual(0x200, n.key.len);
try std.testing.expectEqualStrings("Left", n.value);
}
{
const n = map.get_scalar(0x1300).?;
try std.testing.expectEqual(0x1200, n.key.start);
try std.testing.expectEqual(0xE00, n.key.len);
try std.testing.expectEqualStrings("Right", n.value);
}
}
test "Range map iterator" {
const Map = RangeMap(u32, []const u8, null);
var map = Map.new(std.testing.allocator);
defer map.deinit();
_ = try map.insert(0x1000, 0x1000, "Range [0x1000..0x2000)");
_ = try map.insert(0x2000, 0x1000, "Range [0x2000..0x3000)");
_ = try map.insert(0x4000, 0x1000, "Range [0x4000..0x5000)");
_ = try map.insert(0x3000, 0x1000, "Range [0x3000..0x4000)");
var it = map.iterator();
try std.testing.expectEqualStrings("Range [0x1000..0x2000)", it.next().?.value);
try std.testing.expectEqualStrings("Range [0x2000..0x3000)", it.next().?.value);
try std.testing.expectEqualStrings("Range [0x3000..0x4000)", it.next().?.value);
try std.testing.expectEqualStrings("Range [0x4000..0x5000)", it.next().?.value);
try std.testing.expectEqual(null, it.next());
}