diff --git a/src/util/rangemap.zig b/src/util/rangemap.zig index 4a5e957..b661db8 100644 --- a/src/util/rangemap.zig +++ b/src/util/rangemap.zig @@ -7,11 +7,22 @@ const Allocator = std.mem.Allocator; const BTree = btree.BTree; pub const Order = btree.Order; -pub fn RangeMap(comptime K: type, comptime V: type, comptime deinit_fn: ?fn (*V) void) type { +pub fn RangeMap( + comptime K: type, + comptime V: type, + comptime ops: struct { + deinit_fn: ?fn(*V) void = null, + merge_fn: ?fn (*const V, *const V) bool = null, + }, +) type { return struct { pub const Node = struct { key: Range(K), value: V, + + pub fn len(self: *const @This()) K { + return self.key.len; + } }; pub const WalkFn = fn (*const Node) void; @@ -28,10 +39,11 @@ pub fn RangeMap(comptime K: type, comptime V: type, comptime deinit_fn: ?fn (*V) } }; - const Tree = BTree(Node, compare_fn, deinit_node_fn); + pub const Tree = BTree(Node, compare_fn, deinit_node_fn); - const Error = error{ + pub const Error = error{ scalar_out_of_range, + range_out_of_bounds, } || Tree.Error; fn compare_fn(a: *const Node, b: *const Node) Order { @@ -39,7 +51,7 @@ pub fn RangeMap(comptime K: type, comptime V: type, comptime deinit_fn: ?fn (*V) } fn deinit_node_fn(n: *Node) void { - if (comptime deinit_fn) |f| { + if (comptime ops.deinit_fn) |f| { f(&n.value); } } @@ -122,6 +134,42 @@ pub fn RangeMap(comptime K: type, comptime V: type, comptime deinit_fn: ?fn (*V) /// Maps some range to a value. Returns an error if the requested range crosses another /// mapped range. pub fn insert(self: *@This(), start: K, len: K, value: V) Error!*Tree.Node { + try validate_range(start, len); + + if (comptime ops.merge_fn) |merge_fn| { + const left: ?*Tree.Node = if (start > 0) self.get_scalar_node(start - 1) else null; + const right = self.get_scalar_node(start + len); + + if (left) |l| { + const l_start = l.key.key.start; + + if (merge_fn(&l.key.value, &value)) { + if (right) |r| { + if (merge_fn(&r.key.value, &value)) { + l.key.key.len += len + r.key.key.len; + try self.btree.remove_node(r, true); + return self.get_scalar_node(l_start).?; + } + } + + l.key.key.len += len; + return l; + } + } + + if (right) |r| { + // Only right node to potentially merge with + if (merge_fn(&r.key.value, &value)) { + const r_len = r.key.key.len; + try self.btree.remove_node(r, true); + return self.btree.insert(.{ + .key = .{ .start = start, .len = len + r_len }, + .value = value, + }); + } + } + } + return self.btree.insert(.{ .key = .{ .start = start, .len = len }, .value = value, @@ -143,11 +191,18 @@ pub fn RangeMap(comptime K: type, comptime V: type, comptime deinit_fn: ?fn (*V) } }.call); } + + fn validate_range(start: K, end: K) Error!void { + // Check for addition overflowing the K's bit size + if (std.math.add(K, start, end) == error.Overflow) { + return error.range_out_of_bounds; + } + } }; } test "Range map insertion" { - const Map = RangeMap(u32, []const u8, null); + const Map = RangeMap(u32, []const u8, .{}); var map = Map.init(std.testing.allocator); defer map.deinit(); @@ -160,8 +215,155 @@ test "Range map insertion" { _ = try map.insert(1000, 10, "Range 4"); } +test "Range map merging insertion" { + const Map = RangeMap(u32, bool, .{ + .merge_fn = struct { + fn call(lhs: *const bool, rhs: *const bool) bool { + return !lhs.* and !rhs.*; + } + }.call, + }); + var map = Map.init(std.testing.allocator); + defer map.deinit(); + + // Should not merge + _ = try map.insert(10, 10, false); + _ = try map.insert(0, 10, true); + + { + var it = map.iterator(); + try std.testing.expectEqual(true, it.next().?.value); + try std.testing.expectEqual(false, it.next().?.value); + try std.testing.expectEqual(null, it.next()); + } + + // Merge left + inserted + right + _ = try map.insert(30, 10, false); + _ = try map.insert(20, 10, false); + + { + var it = map.iterator(); + const n0 = it.next().?; + try std.testing.expectEqual(Range(u32) { .start = 0, .len = 10 }, n0.key); + try std.testing.expectEqual(true, n0.value); + + const n1 = it.next().?; + try std.testing.expectEqual(Range(u32) { .start = 10, .len = 30 }, n1.key); + try std.testing.expectEqual(false, n1.value); + + try std.testing.expectEqual(null, it.next()); + } + + // Should not merge again + _ = try map.insert(40, 10, true); + _ = try map.insert(50, 10, false); + + { + var it = map.iterator(); + const n0 = it.next().?; + try std.testing.expectEqual(Range(u32) { .start = 0, .len = 10 }, n0.key); + try std.testing.expectEqual(true, n0.value); + + const n1 = it.next().?; + try std.testing.expectEqual(Range(u32) { .start = 10, .len = 30 }, n1.key); + try std.testing.expectEqual(false, n1.value); + + const n2 = it.next().?; + try std.testing.expectEqual(Range(u32) { .start = 40, .len = 10 }, n2.key); + try std.testing.expectEqual(true, n2.value); + + const n3 = it.next().?; + try std.testing.expectEqual(Range(u32) { .start = 50, .len = 10 }, n3.key); + try std.testing.expectEqual(false, n3.value); + + try std.testing.expectEqual(null, it.next()); + } + + // Should merge left + shouldn't merge right non-contiguous + _ = try map.insert(71, 9, false); + _ = try map.insert(60, 10, false); + + { + var it = map.iterator(); + const n0 = it.next().?; + try std.testing.expectEqual(Range(u32) { .start = 0, .len = 10 }, n0.key); + try std.testing.expectEqual(true, n0.value); + + const n1 = it.next().?; + try std.testing.expectEqual(Range(u32) { .start = 10, .len = 30 }, n1.key); + try std.testing.expectEqual(false, n1.value); + + const n2 = it.next().?; + try std.testing.expectEqual(Range(u32) { .start = 40, .len = 10 }, n2.key); + try std.testing.expectEqual(true, n2.value); + + const n3 = it.next().?; + try std.testing.expectEqual(Range(u32) { .start = 50, .len = 20 }, n3.key); + try std.testing.expectEqual(false, n3.value); + + const n4 = it.next().?; + try std.testing.expectEqual(Range(u32) { .start = 71, .len = 9 }, n4.key); + try std.testing.expectEqual(false, n4.value); + + try std.testing.expectEqual(null, it.next()); + } + + // Should merge left and right + _ = try map.insert(70, 1, false); + + { + var it = map.iterator(); + const n0 = it.next().?; + try std.testing.expectEqual(Range(u32) { .start = 0, .len = 10 }, n0.key); + try std.testing.expectEqual(true, n0.value); + + const n1 = it.next().?; + try std.testing.expectEqual(Range(u32) { .start = 10, .len = 30 }, n1.key); + try std.testing.expectEqual(false, n1.value); + + const n2 = it.next().?; + try std.testing.expectEqual(Range(u32) { .start = 40, .len = 10 }, n2.key); + try std.testing.expectEqual(true, n2.value); + + const n3 = it.next().?; + try std.testing.expectEqual(Range(u32) { .start = 50, .len = 30 }, n3.key); + try std.testing.expectEqual(false, n3.value); + + try std.testing.expectEqual(null, it.next()); + } + + // Should merge right + _ = try map.insert(110, 10, false); + _ = try map.insert(100, 10, false); + + { + var it = map.iterator(); + const n0 = it.next().?; + try std.testing.expectEqual(Range(u32) { .start = 0, .len = 10 }, n0.key); + try std.testing.expectEqual(true, n0.value); + + const n1 = it.next().?; + try std.testing.expectEqual(Range(u32) { .start = 10, .len = 30 }, n1.key); + try std.testing.expectEqual(false, n1.value); + + const n2 = it.next().?; + try std.testing.expectEqual(Range(u32) { .start = 40, .len = 10 }, n2.key); + try std.testing.expectEqual(true, n2.value); + + const n3 = it.next().?; + try std.testing.expectEqual(Range(u32) { .start = 50, .len = 30 }, n3.key); + try std.testing.expectEqual(false, n3.value); + + const n4 = it.next().?; + try std.testing.expectEqual(Range(u32) { .start = 100, .len = 20 }, n4.key); + try std.testing.expectEqual(false, n4.value); + + try std.testing.expectEqual(null, it.next()); + } +} + test "Range map get scalar" { - const Map = RangeMap(u32, []const u8, null); + const Map = RangeMap(u32, []const u8, .{}); var map = Map.init(std.testing.allocator); defer map.deinit(); @@ -195,7 +397,7 @@ test "Range map get scalar" { } test "Range map split" { - const Map = RangeMap(u32, []const u8, null); + const Map = RangeMap(u32, []const u8, .{}); var map = Map.init(std.testing.allocator); defer map.deinit(); @@ -222,7 +424,7 @@ test "Range map split" { } test "Range map iterator" { - const Map = RangeMap(u32, []const u8, null); + const Map = RangeMap(u32, []const u8, .{}); var map = Map.init(std.testing.allocator); defer map.deinit(); @@ -238,3 +440,11 @@ test "Range map iterator" { try std.testing.expectEqualStrings("Range [0x4000..0x5000)", it.next().?.value); try std.testing.expectEqual(null, it.next()); } + +test "Range map should disallow overflowing ranges" { + const Map = RangeMap(u32, bool, .{}); + var map = Map.init(std.testing.allocator); + defer map.deinit(); + + try std.testing.expectError(error.range_out_of_bounds, map.insert(0xF0000000, 0x20000000, false)); +}