lib: implement merge on insert in rangemap
This commit is contained in:
+218
-8
@@ -7,11 +7,22 @@ const Allocator = std.mem.Allocator;
|
||||
const BTree = btree.BTree;
|
||||
pub const Order = btree.Order;
|
||||
|
||||
pub fn RangeMap(comptime K: type, comptime V: type, comptime deinit_fn: ?fn (*V) void) type {
|
||||
pub fn RangeMap(
|
||||
comptime K: type,
|
||||
comptime V: type,
|
||||
comptime ops: struct {
|
||||
deinit_fn: ?fn(*V) void = null,
|
||||
merge_fn: ?fn (*const V, *const V) bool = null,
|
||||
},
|
||||
) type {
|
||||
return struct {
|
||||
pub const Node = struct {
|
||||
key: Range(K),
|
||||
value: V,
|
||||
|
||||
pub fn len(self: *const @This()) K {
|
||||
return self.key.len;
|
||||
}
|
||||
};
|
||||
|
||||
pub const WalkFn = fn (*const Node) void;
|
||||
@@ -28,10 +39,11 @@ pub fn RangeMap(comptime K: type, comptime V: type, comptime deinit_fn: ?fn (*V)
|
||||
}
|
||||
};
|
||||
|
||||
const Tree = BTree(Node, compare_fn, deinit_node_fn);
|
||||
pub const Tree = BTree(Node, compare_fn, deinit_node_fn);
|
||||
|
||||
const Error = error{
|
||||
pub const Error = error{
|
||||
scalar_out_of_range,
|
||||
range_out_of_bounds,
|
||||
} || Tree.Error;
|
||||
|
||||
fn compare_fn(a: *const Node, b: *const Node) Order {
|
||||
@@ -39,7 +51,7 @@ pub fn RangeMap(comptime K: type, comptime V: type, comptime deinit_fn: ?fn (*V)
|
||||
}
|
||||
|
||||
fn deinit_node_fn(n: *Node) void {
|
||||
if (comptime deinit_fn) |f| {
|
||||
if (comptime ops.deinit_fn) |f| {
|
||||
f(&n.value);
|
||||
}
|
||||
}
|
||||
@@ -122,6 +134,42 @@ pub fn RangeMap(comptime K: type, comptime V: type, comptime deinit_fn: ?fn (*V)
|
||||
/// Maps some range to a value. Returns an error if the requested range crosses another
|
||||
/// mapped range.
|
||||
pub fn insert(self: *@This(), start: K, len: K, value: V) Error!*Tree.Node {
|
||||
try validate_range(start, len);
|
||||
|
||||
if (comptime ops.merge_fn) |merge_fn| {
|
||||
const left: ?*Tree.Node = if (start > 0) self.get_scalar_node(start - 1) else null;
|
||||
const right = self.get_scalar_node(start + len);
|
||||
|
||||
if (left) |l| {
|
||||
const l_start = l.key.key.start;
|
||||
|
||||
if (merge_fn(&l.key.value, &value)) {
|
||||
if (right) |r| {
|
||||
if (merge_fn(&r.key.value, &value)) {
|
||||
l.key.key.len += len + r.key.key.len;
|
||||
try self.btree.remove_node(r, true);
|
||||
return self.get_scalar_node(l_start).?;
|
||||
}
|
||||
}
|
||||
|
||||
l.key.key.len += len;
|
||||
return l;
|
||||
}
|
||||
}
|
||||
|
||||
if (right) |r| {
|
||||
// Only right node to potentially merge with
|
||||
if (merge_fn(&r.key.value, &value)) {
|
||||
const r_len = r.key.key.len;
|
||||
try self.btree.remove_node(r, true);
|
||||
return self.btree.insert(.{
|
||||
.key = .{ .start = start, .len = len + r_len },
|
||||
.value = value,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return self.btree.insert(.{
|
||||
.key = .{ .start = start, .len = len },
|
||||
.value = value,
|
||||
@@ -143,11 +191,18 @@ pub fn RangeMap(comptime K: type, comptime V: type, comptime deinit_fn: ?fn (*V)
|
||||
}
|
||||
}.call);
|
||||
}
|
||||
|
||||
fn validate_range(start: K, end: K) Error!void {
|
||||
// Check for addition overflowing the K's bit size
|
||||
if (std.math.add(K, start, end) == error.Overflow) {
|
||||
return error.range_out_of_bounds;
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
test "Range map insertion" {
|
||||
const Map = RangeMap(u32, []const u8, null);
|
||||
const Map = RangeMap(u32, []const u8, .{});
|
||||
var map = Map.init(std.testing.allocator);
|
||||
defer map.deinit();
|
||||
|
||||
@@ -160,8 +215,155 @@ test "Range map insertion" {
|
||||
_ = try map.insert(1000, 10, "Range 4");
|
||||
}
|
||||
|
||||
test "Range map merging insertion" {
|
||||
const Map = RangeMap(u32, bool, .{
|
||||
.merge_fn = struct {
|
||||
fn call(lhs: *const bool, rhs: *const bool) bool {
|
||||
return !lhs.* and !rhs.*;
|
||||
}
|
||||
}.call,
|
||||
});
|
||||
var map = Map.init(std.testing.allocator);
|
||||
defer map.deinit();
|
||||
|
||||
// Should not merge
|
||||
_ = try map.insert(10, 10, false);
|
||||
_ = try map.insert(0, 10, true);
|
||||
|
||||
{
|
||||
var it = map.iterator();
|
||||
try std.testing.expectEqual(true, it.next().?.value);
|
||||
try std.testing.expectEqual(false, it.next().?.value);
|
||||
try std.testing.expectEqual(null, it.next());
|
||||
}
|
||||
|
||||
// Merge left + inserted + right
|
||||
_ = try map.insert(30, 10, false);
|
||||
_ = try map.insert(20, 10, false);
|
||||
|
||||
{
|
||||
var it = map.iterator();
|
||||
const n0 = it.next().?;
|
||||
try std.testing.expectEqual(Range(u32) { .start = 0, .len = 10 }, n0.key);
|
||||
try std.testing.expectEqual(true, n0.value);
|
||||
|
||||
const n1 = it.next().?;
|
||||
try std.testing.expectEqual(Range(u32) { .start = 10, .len = 30 }, n1.key);
|
||||
try std.testing.expectEqual(false, n1.value);
|
||||
|
||||
try std.testing.expectEqual(null, it.next());
|
||||
}
|
||||
|
||||
// Should not merge again
|
||||
_ = try map.insert(40, 10, true);
|
||||
_ = try map.insert(50, 10, false);
|
||||
|
||||
{
|
||||
var it = map.iterator();
|
||||
const n0 = it.next().?;
|
||||
try std.testing.expectEqual(Range(u32) { .start = 0, .len = 10 }, n0.key);
|
||||
try std.testing.expectEqual(true, n0.value);
|
||||
|
||||
const n1 = it.next().?;
|
||||
try std.testing.expectEqual(Range(u32) { .start = 10, .len = 30 }, n1.key);
|
||||
try std.testing.expectEqual(false, n1.value);
|
||||
|
||||
const n2 = it.next().?;
|
||||
try std.testing.expectEqual(Range(u32) { .start = 40, .len = 10 }, n2.key);
|
||||
try std.testing.expectEqual(true, n2.value);
|
||||
|
||||
const n3 = it.next().?;
|
||||
try std.testing.expectEqual(Range(u32) { .start = 50, .len = 10 }, n3.key);
|
||||
try std.testing.expectEqual(false, n3.value);
|
||||
|
||||
try std.testing.expectEqual(null, it.next());
|
||||
}
|
||||
|
||||
// Should merge left + shouldn't merge right non-contiguous
|
||||
_ = try map.insert(71, 9, false);
|
||||
_ = try map.insert(60, 10, false);
|
||||
|
||||
{
|
||||
var it = map.iterator();
|
||||
const n0 = it.next().?;
|
||||
try std.testing.expectEqual(Range(u32) { .start = 0, .len = 10 }, n0.key);
|
||||
try std.testing.expectEqual(true, n0.value);
|
||||
|
||||
const n1 = it.next().?;
|
||||
try std.testing.expectEqual(Range(u32) { .start = 10, .len = 30 }, n1.key);
|
||||
try std.testing.expectEqual(false, n1.value);
|
||||
|
||||
const n2 = it.next().?;
|
||||
try std.testing.expectEqual(Range(u32) { .start = 40, .len = 10 }, n2.key);
|
||||
try std.testing.expectEqual(true, n2.value);
|
||||
|
||||
const n3 = it.next().?;
|
||||
try std.testing.expectEqual(Range(u32) { .start = 50, .len = 20 }, n3.key);
|
||||
try std.testing.expectEqual(false, n3.value);
|
||||
|
||||
const n4 = it.next().?;
|
||||
try std.testing.expectEqual(Range(u32) { .start = 71, .len = 9 }, n4.key);
|
||||
try std.testing.expectEqual(false, n4.value);
|
||||
|
||||
try std.testing.expectEqual(null, it.next());
|
||||
}
|
||||
|
||||
// Should merge left and right
|
||||
_ = try map.insert(70, 1, false);
|
||||
|
||||
{
|
||||
var it = map.iterator();
|
||||
const n0 = it.next().?;
|
||||
try std.testing.expectEqual(Range(u32) { .start = 0, .len = 10 }, n0.key);
|
||||
try std.testing.expectEqual(true, n0.value);
|
||||
|
||||
const n1 = it.next().?;
|
||||
try std.testing.expectEqual(Range(u32) { .start = 10, .len = 30 }, n1.key);
|
||||
try std.testing.expectEqual(false, n1.value);
|
||||
|
||||
const n2 = it.next().?;
|
||||
try std.testing.expectEqual(Range(u32) { .start = 40, .len = 10 }, n2.key);
|
||||
try std.testing.expectEqual(true, n2.value);
|
||||
|
||||
const n3 = it.next().?;
|
||||
try std.testing.expectEqual(Range(u32) { .start = 50, .len = 30 }, n3.key);
|
||||
try std.testing.expectEqual(false, n3.value);
|
||||
|
||||
try std.testing.expectEqual(null, it.next());
|
||||
}
|
||||
|
||||
// Should merge right
|
||||
_ = try map.insert(110, 10, false);
|
||||
_ = try map.insert(100, 10, false);
|
||||
|
||||
{
|
||||
var it = map.iterator();
|
||||
const n0 = it.next().?;
|
||||
try std.testing.expectEqual(Range(u32) { .start = 0, .len = 10 }, n0.key);
|
||||
try std.testing.expectEqual(true, n0.value);
|
||||
|
||||
const n1 = it.next().?;
|
||||
try std.testing.expectEqual(Range(u32) { .start = 10, .len = 30 }, n1.key);
|
||||
try std.testing.expectEqual(false, n1.value);
|
||||
|
||||
const n2 = it.next().?;
|
||||
try std.testing.expectEqual(Range(u32) { .start = 40, .len = 10 }, n2.key);
|
||||
try std.testing.expectEqual(true, n2.value);
|
||||
|
||||
const n3 = it.next().?;
|
||||
try std.testing.expectEqual(Range(u32) { .start = 50, .len = 30 }, n3.key);
|
||||
try std.testing.expectEqual(false, n3.value);
|
||||
|
||||
const n4 = it.next().?;
|
||||
try std.testing.expectEqual(Range(u32) { .start = 100, .len = 20 }, n4.key);
|
||||
try std.testing.expectEqual(false, n4.value);
|
||||
|
||||
try std.testing.expectEqual(null, it.next());
|
||||
}
|
||||
}
|
||||
|
||||
test "Range map get scalar" {
|
||||
const Map = RangeMap(u32, []const u8, null);
|
||||
const Map = RangeMap(u32, []const u8, .{});
|
||||
var map = Map.init(std.testing.allocator);
|
||||
defer map.deinit();
|
||||
|
||||
@@ -195,7 +397,7 @@ test "Range map get scalar" {
|
||||
}
|
||||
|
||||
test "Range map split" {
|
||||
const Map = RangeMap(u32, []const u8, null);
|
||||
const Map = RangeMap(u32, []const u8, .{});
|
||||
var map = Map.init(std.testing.allocator);
|
||||
defer map.deinit();
|
||||
|
||||
@@ -222,7 +424,7 @@ test "Range map split" {
|
||||
}
|
||||
|
||||
test "Range map iterator" {
|
||||
const Map = RangeMap(u32, []const u8, null);
|
||||
const Map = RangeMap(u32, []const u8, .{});
|
||||
var map = Map.init(std.testing.allocator);
|
||||
defer map.deinit();
|
||||
|
||||
@@ -238,3 +440,11 @@ test "Range map iterator" {
|
||||
try std.testing.expectEqualStrings("Range [0x4000..0x5000)", it.next().?.value);
|
||||
try std.testing.expectEqual(null, it.next());
|
||||
}
|
||||
|
||||
test "Range map should disallow overflowing ranges" {
|
||||
const Map = RangeMap(u32, bool, .{});
|
||||
var map = Map.init(std.testing.allocator);
|
||||
defer map.deinit();
|
||||
|
||||
try std.testing.expectError(error.range_out_of_bounds, map.insert(0xF0000000, 0x20000000, false));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user