Use custom min-heap to improve k-way merge by 30-50% (#217)

pull/218/head
Michael Barry 2022-05-08 20:00:13 -04:00 zatwierdzone przez GitHub
rodzic 726e6d0107
commit 9062e6b79b
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: 4AEE18F83AFDEB23
7 zmienionych plików z 847 dodań i 34 usunięć

Wyświetl plik

@ -32,7 +32,8 @@ The `planetiler-core` module includes the following software:
- `Imposm3Parsers` from [imposm3](https://github.com/omniscale/imposm3) (Apache license)
- `PbfDecoder` from [osmosis](https://github.com/openstreetmap/osmosis) (Public Domain)
- `PbfFieldDecoder` from [osmosis](https://github.com/openstreetmap/osmosis) (Public Domain)
- `NativeUtil` from [uppend](https://github.com/upserve/uppend/) (MIT License)
- `Madvise` from [uppend](https://github.com/upserve/uppend/) (MIT License)
- `ArrayLongMinHeap` implementations from [graphhopper](https://github.com/graphhopper/graphhopper) (Apache license)
Additionally, the `planetiler-basemap` module is based on [OpenMapTiles](https://github.com/openmaptiles/openmaptiles):

Wyświetl plik

@ -25,30 +25,34 @@ import java.util.concurrent.ThreadLocalRandom;
public class BenchmarkExternalMergeSort {
private static final Format FORMAT = Format.defaultInstance();
private static final int ITEM_SIZE_BYTES = 76;
private static final byte[] TEST_DATA = new byte[ITEM_SIZE_BYTES - Long.BYTES - Integer.BYTES];
private static final int DISK_OVERHEAD_BYTES = Long.BYTES + Integer.BYTES;
private static final int ITEM_DATA_BYTES = ITEM_SIZE_BYTES - DISK_OVERHEAD_BYTES;
private static final int MEMORY_OVERHEAD_BYTES = 8 + 16 + 8 + 8 + 24;
private static final int ITEM_MEMORY_BYTES = MEMORY_OVERHEAD_BYTES + ITEM_DATA_BYTES;
private static final byte[] TEST_DATA = new byte[ITEM_DATA_BYTES];
static {
ThreadLocalRandom.current().nextBytes(TEST_DATA);
}
public static void main(String[] args) {
double gb = args.length == 0 ? 1 : Double.parseDouble(args[0]);
double gb = args.length < 1 ? 1 : Double.parseDouble(args[0]);
long number = (long) (gb * 1_000_000_000 / ITEM_SIZE_BYTES);
Path path = Path.of("./featuretest");
FileUtils.delete(path);
FileUtils.deleteOnExit(path);
var config = PlanetilerConfig.defaults();
try {
List<Results> results = new ArrayList<>();
int limit = 2_000_000_000;
for (int writers : List.of(1, 2, 4)) {
results.add(run(path, writers, number, limit, false, true, true, config));
results.add(run(path, writers, number, limit, true, true, true, config));
for (int i = 0; i < 3; i++) {
try {
List<Results> results = new ArrayList<>();
for (int chunks : List.of(1, 10, 100, 1_000, 10_000)) {
results.add(run(path, 1, number, chunks, true, true, true, config));
}
for (var result : results) {
System.err.println(result);
}
} finally {
FileUtils.delete(path);
}
for (var result : results) {
System.err.println(result);
}
} finally {
FileUtils.delete(path);
}
}
@ -60,15 +64,18 @@ public class BenchmarkExternalMergeSort {
boolean madvise
) {}
private static Results run(Path tmpDir, int writeWorkers, long items, int chunkSizeLimit, boolean mmap,
boolean parallelSort,
boolean madvise, PlanetilerConfig config) {
private static Results run(Path tmpDir, int writeWorkers, long items, int numChunks,
boolean mmap, boolean parallelSort, boolean madvise, PlanetilerConfig config) {
long chunkSizeLimit = items * ITEM_MEMORY_BYTES / numChunks;
if (chunkSizeLimit > Integer.MAX_VALUE) {
throw new IllegalStateException("Chunk size too big: " + chunkSizeLimit);
}
boolean gzip = false;
int sortWorkers = Runtime.getRuntime().availableProcessors();
int readWorkers = 1;
FileUtils.delete(tmpDir);
var sorter =
new ExternalMergeSort(tmpDir, sortWorkers, chunkSizeLimit, gzip, mmap, parallelSort, madvise, config,
new ExternalMergeSort(tmpDir, sortWorkers, (int) chunkSizeLimit, gzip, mmap, parallelSort, madvise, config,
Stats.inMemory());
var writeTimer = Timer.start();
@ -91,7 +98,7 @@ public class BenchmarkExternalMergeSort {
writeWorkers,
readWorkers,
items,
chunkSizeLimit,
(int) chunkSizeLimit,
gzip,
mmap,
parallelSort,

Wyświetl plik

@ -0,0 +1,122 @@
package com.onthegomap.planetiler.collection;
import java.time.Duration;
import java.util.PriorityQueue;
import java.util.Random;
import java.util.function.IntFunction;
import java.util.stream.IntStream;
/**
 * Micro-benchmark that compares {@link LongMinHeap} against {@link PriorityQueue} for k-way merging.
 * <p>
 * Every measurement merges {@code k} pre-sorted lists of random longs (roughly 10 million values in total) and prints
 * the elapsed time in milliseconds as one tab-separated row per heap implementation.
 */
public class BenchmarkKWayMerge {

  /** Number of sorted lists merged for each column of an output row, in print order. */
  private static final int[] LIST_COUNTS = {10, 100, 1_000, 10_000};

  /** Total number of values spread across all lists of one measurement. */
  private static final int TOTAL_VALUES = 10_000_000;

  private static final Random random = new Random();

  public static void main(String[] args) {
    // the whole comparison is repeated 4 times (presumably so later rounds run on warmed-up code)
    for (int round = 0; round < 4; round++) {
      System.err.println();
      testMinHeap("quaternary", LongMinHeap::newArrayHeap);
      System.err.println(formatRow("priorityqueue", BenchmarkKWayMerge::testPriorityQueue));
    }
  }

  /** Prints one row of timings for the {@link LongMinHeap} produced by {@code constructor}. */
  private static void testMinHeap(String name, IntFunction<LongMinHeap> constructor) {
    System.err.println(formatRow(name, listCount -> testUpdates(listCount, constructor)));
  }

  /** Runs {@code timer} once per entry of {@link #LIST_COUNTS} and joins {@code name} and the millis with tabs. */
  private static String formatRow(String name, IntFunction<Duration> timer) {
    StringBuilder row = new StringBuilder(name);
    for (int listCount : LIST_COUNTS) {
      row.append('\t').append(timer.apply(listCount).toMillis());
    }
    return row.toString();
  }

  /** Returns {@code size} sorted arrays of random values in {@code [0, 1_000_000_000)}. */
  private static long[][] getVals(int size) {
    int perList = TOTAL_VALUES / size;
    return IntStream.range(0, size)
      .mapToObj(i -> sortedRandomLongs(perList))
      .toArray(long[][]::new);
  }

  private static long[] sortedRandomLongs(int count) {
    return random
      .longs(0, 1_000_000_000)
      .limit(count)
      .sorted()
      .toArray();
  }

  /** Times a k-way merge of {@code size} lists driven by {@link LongMinHeap#updateHead(long)}. */
  private static Duration testUpdates(int size, IntFunction<LongMinHeap> heapFn) {
    int[] cursors = new int[size];
    long[][] lists = getVals(size);
    LongMinHeap heap = heapFn.apply(size);
    // seed the heap with the first value of each list; this part is not timed
    for (int list = 0; list < size; list++) {
      heap.push(list, lists[list][cursors[list]++]);
    }
    long startNanos = System.nanoTime();
    while (!heap.isEmpty()) {
      int list = heap.peekId();
      int cursor = cursors[list]++;
      long[] values = lists[list];
      if (cursor >= values.length) {
        // this list is exhausted, drop it from the merge
        heap.poll();
      } else {
        heap.updateHead(values[cursor]);
      }
    }
    return Duration.ofNanos(System.nanoTime() - startNanos);
  }

  /** Mutable heap entry used by the {@link PriorityQueue} baseline; ordered and compared by {@code value} only. */
  static class Item implements Comparable<Item> {

    long value;
    int id;

    @Override
    public int compareTo(Item o) {
      return Long.compare(value, o.value);
    }

    @Override
    public boolean equals(Object o) {
      if (this == o) {
        return true;
      }
      return o != null && getClass() == o.getClass() && value == ((Item) o).value;
    }

    @Override
    public int hashCode() {
      return Long.hashCode(value);
    }
  }

  /** Times the same k-way merge of {@code size} lists using poll/offer on a {@link PriorityQueue}. */
  private static Duration testPriorityQueue(int size) {
    long[][] lists = getVals(size);
    int[] cursors = new int[size];
    PriorityQueue<Item> heap = new PriorityQueue<>();
    // seed the queue with the first value of each list; this part is not timed
    for (int list = 0; list < size; list++) {
      Item first = new Item();
      first.id = list;
      first.value = lists[list][cursors[list]++];
      heap.offer(first);
    }
    long startNanos = System.nanoTime();
    while (!heap.isEmpty()) {
      Item head = heap.poll();
      int cursor = cursors[head.id]++;
      long[] values = lists[head.id];
      if (cursor < values.length) {
        // reuse the same Item instance for the list's next value
        head.value = values[cursor];
        heap.offer(head);
      }
    }
    return Duration.ofNanos(System.nanoTime() - startNanos);
  }
}

Wyświetl plik

@ -0,0 +1,222 @@
/*
* Licensed to GraphHopper GmbH under one or more contributor
* license agreements. See the NOTICE file distributed with this work for
* additional information regarding copyright ownership.
*
* GraphHopper GmbH licenses this file to you under the Apache License,
* Version 2.0 (the "License"); you may not use this file except in
* compliance with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.onthegomap.planetiler.collection;
import java.util.Arrays;
/**
 * A min-heap stored in an array where each element has 4 children.
 * <p>
 * This is about 5-10% faster than the standard binary min-heap for the case of merging sorted lists.
 * <p>
 * Ported from <a href=
 * "https://github.com/graphhopper/graphhopper/blob/master/core/src/main/java/com/graphhopper/coll/MinHeapWithUpdate.java">GraphHopper</a>
 * and:
 * <ul>
 * <li>modified to use {@code long} values instead of {@code float}</li>
 * <li>extracted a common interface for subclass implementations</li>
 * <li>modified so that each element has 4 children instead of 2 (improves performance by 5-10%)</li>
 * <li>performance improvements to minimize array lookups</li>
 * </ul>
 * <p>
 * Layout: heap slots are 1-based. The root is slot 1, the children of slot {@code i} are slots {@code 4i-2 .. 4i+1}
 * and the parent of slot {@code i} is {@code (i+2)/4}. Slot 0 holds a {@code Long.MIN_VALUE} sentinel so that
 * percolating up terminates without an explicit bounds check.
 *
 * @see <a href="https://en.wikipedia.org/wiki/D-ary_heap">d-ary heap (wikipedia)</a>
 */
class ArrayLongMinHeap implements LongMinHeap {

  /** Marker stored in {@link #positions} for ids that are currently not in the heap. */
  protected static final int NOT_PRESENT = -1;
  /** Element ids in heap order, indexed by slot (slot 0 is unused). */
  protected final int[] tree;
  /** Slot of each element in {@link #tree}, indexed by element id, or {@link #NOT_PRESENT}. */
  protected final int[] positions;
  /** Value of the element stored in the same slot of {@link #tree}; {@code vals[0]} is the sentinel. */
  protected final long[] vals;
  /** Capacity: the maximum number of elements, and the exclusive upper bound for ids. */
  protected final int max;
  /** Number of elements currently in the heap. */
  protected int size;

  /**
   * @param elements the number of elements that can be stored in this heap. Currently the heap cannot be resized or
   *                 shrunk/trimmed after initial creation. elements-1 is the maximum id that can be stored in this heap
   */
  ArrayLongMinHeap(int elements) {
    // we use an offset of one to make the arithmetic a bit simpler/more efficient, slot 0 of tree is not used and
    // slot 0 of vals only holds the sentinel
    tree = new int[elements + 1];
    positions = new int[elements + 1];
    Arrays.fill(positions, NOT_PRESENT);
    vals = new long[elements + 1];
    vals[0] = Long.MIN_VALUE;
    this.max = elements;
  }

  /** Returns the slot of the first of the 4 children of {@code index}. */
  private static int firstChild(int index) {
    return (index << 2) - 2;
  }

  /** Returns the slot of the parent of {@code index}; the parent of the root (slot 1) is the sentinel slot 0. */
  private static int parent(int index) {
    return (index + 2) >> 2;
  }

  @Override
  public int size() {
    return size;
  }

  @Override
  public boolean isEmpty() {
    return size == 0;
  }

  /**
   * Inserts {@code id} with the given {@code value}.
   *
   * @throws IllegalArgumentException if {@code id} is outside {@code [0, max)}
   * @throws IllegalStateException    if the heap is full or already contains {@code id}
   */
  @Override
  public void push(int id, long value) {
    checkIdInRange(id);
    if (size == max) {
      throw new IllegalStateException("Cannot push anymore, the heap is already full. size: " + size);
    }
    if (contains(id)) {
      throw new IllegalStateException("Element with id: " + id +
        " was pushed already, you need to use the update method if you want to change its value");
    }
    // append at the end of the array, then restore the heap property
    size++;
    tree[size] = id;
    positions[id] = size;
    vals[size] = value;
    percolateUp(size);
  }

  /**
   * @throws IllegalArgumentException if {@code id} is outside {@code [0, max)}
   */
  @Override
  public boolean contains(int id) {
    checkIdInRange(id);
    return positions[id] != NOT_PRESENT;
  }

  /**
   * Changes the value of {@code id} and moves it up or down to restore the heap property.
   *
   * @throws IllegalArgumentException if {@code id} is outside {@code [0, max)}
   * @throws IllegalStateException    if the heap does not contain {@code id}
   */
  @Override
  public void update(int id, long value) {
    checkIdInRange(id);
    int index = positions[id];
    if (index < 0) {
      throw new IllegalStateException(
        "The heap does not contain: " + id + ". Use the contains method to check this before calling update");
    }
    long prev = vals[index];
    vals[index] = value;
    if (value > prev) {
      percolateDown(index);
    } else if (value < prev) {
      percolateUp(index);
    }
  }

  /**
   * Replaces the value of the current minimum element and sinks it to its new position. Only percolates down, which
   * is sufficient because the head has no parent. This is the hot path for k-way merging, so no checks are done here:
   * callers must only invoke it on a non-empty heap.
   */
  @Override
  public void updateHead(long value) {
    vals[1] = value;
    percolateDown(1);
  }

  /** Returns the id in the root slot; only meaningful when the heap is not empty. */
  @Override
  public int peekId() {
    return tree[1];
  }

  /** Returns the value in the root slot; only meaningful when the heap is not empty. */
  @Override
  public long peekValue() {
    return vals[1];
  }

  /**
   * Removes and returns the id with the minimum value.
   *
   * @throws IllegalStateException if the heap is empty
   */
  @Override
  public int poll() {
    if (size == 0) {
      // without this guard size would drop to -1 (so isEmpty() turns false) and positions[] would be corrupted
      throw new IllegalStateException("Cannot poll, the heap is empty");
    }
    int id = peekId();
    // move the last element to the root and let it sink
    tree[1] = tree[size];
    vals[1] = vals[size];
    positions[tree[1]] = 1;
    // must come after the line above: when size == 1 both lines address the same id
    positions[id] = NOT_PRESENT;
    size--;
    percolateDown(1);
    return id;
  }

  @Override
  public void clear() {
    // only the ids that are currently present need to be reset
    for (int i = 1; i <= size; i++) {
      positions[tree[i]] = NOT_PRESENT;
    }
    size = 0;
  }

  /** Moves the element at slot {@code index} towards the root until its parent is not larger. */
  private void percolateUp(int index) {
    assert index != 0;
    if (index == 1) {
      return;
    }
    final int el = tree[index];
    final long val = vals[index];
    // the finish condition (index==0) is covered here automatically because we set vals[0]=-inf
    int parent;
    long parentValue;
    while (val < (parentValue = vals[parent = parent(index)])) {
      // shift the parent down into the hole instead of swapping
      vals[index] = parentValue;
      positions[tree[index] = tree[parent]] = index;
      index = parent;
    }
    tree[index] = el;
    vals[index] = val;
    positions[tree[index]] = index;
  }

  private void checkIdInRange(int id) {
    if (id < 0 || id >= max) {
      throw new IllegalArgumentException("Illegal id: " + id + ", legal range: [0, " + max + "[");
    }
  }

  /** Moves the element at slot {@code index} away from the root until none of its children is smaller. */
  private void percolateDown(int index) {
    if (size == 0) {
      return;
    }
    assert index > 0;
    assert index <= size;
    final int el = tree[index];
    final long val = vals[index];
    int child;
    while ((child = firstChild(index)) <= size) {
      // optimization: this is a very hot code path for performance of k-way merging,
      // so manually-unroll the loop over the 4 child elements to find the minimum value
      int minChild = child;
      long minValue = vals[child], value;
      if (++child <= size) {
        if ((value = vals[child]) < minValue) {
          minChild = child;
          minValue = value;
        }
        if (++child <= size) {
          if ((value = vals[child]) < minValue) {
            minChild = child;
            minValue = value;
          }
          if (++child <= size && (value = vals[child]) < minValue) {
            minChild = child;
            minValue = value;
          }
        }
      }
      if (minValue >= val) {
        break;
      }
      // shift the smallest child up into the hole instead of swapping
      vals[index] = minValue;
      positions[tree[index] = tree[minChild]] = index;
      index = minChild;
    }
    tree[index] = el;
    vals[index] = val;
    positions[el] = index;
  }
}

Wyświetl plik

@ -28,13 +28,13 @@ import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.PriorityQueue;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicBoolean;
@ -212,7 +212,7 @@ class ExternalMergeSort implements FeatureSort {
});
ProgressLoggers loggers = ProgressLoggers.create()
.addPercentCounter("chunks", chunks.size(), doneCounter)
.addPercentCounter("chunks", groups.size(), doneCounter)
.addFileSize(this)
.newLine()
.addProcessStats()
@ -243,26 +243,33 @@ class ExternalMergeSort implements FeatureSort {
}
// k-way merge to interleave all the sorted chunks
PriorityQueue<Reader<?>> queue = new PriorityQueue<>(chunks.size());
List<Reader> iterators = new ArrayList<>();
for (Chunk chunk : chunks) {
if (chunk.itemCount > 0) {
queue.add(chunk.newReader());
iterators.add(chunk.newReader());
}
}
LongMinHeap heap = LongMinHeap.newArrayHeap(iterators.size());
for (int i = 0; i < iterators.size(); i++) {
heap.push(i, iterators.get(i).nextKey());
}
return new Iterator<>() {
@Override
public boolean hasNext() {
return !queue.isEmpty();
return !heap.isEmpty();
}
@Override
public SortableFeature next() {
Reader<?> iterator = queue.poll();
int i = heap.peekId();
Reader iterator = iterators.get(i);
assert iterator != null;
SortableFeature next = iterator.next();
if (iterator.hasNext()) {
queue.add(iterator);
heap.updateHead(iterator.nextKey());
} else {
heap.poll();
}
return next;
}
@ -284,14 +291,16 @@ class ExternalMergeSort implements FeatureSort {
}
private interface Writer extends Closeable {
void write(SortableFeature feature) throws IOException;
}
private interface Reader<T extends Reader<?>>
extends Closeable, Iterator<SortableFeature>, Comparable<T> {
private interface Reader extends Closeable, Iterator<SortableFeature> {
@Override
void close();
long nextKey();
}
/** Compresses bytes with minimal impact on write performance. Equivalent to {@code gzip -1} */
@ -304,7 +313,7 @@ class ExternalMergeSort implements FeatureSort {
}
/** Read all features from a chunk file using a {@link BufferedInputStream}. */
private static class ReaderBuffered extends BaseReader<ReaderBuffered> {
private static class ReaderBuffered extends BaseReader {
private final int count;
private final DataInputStream input;
@ -382,7 +391,8 @@ class ExternalMergeSort implements FeatureSort {
}
/** Common functionality between {@link ReaderMmap} and {@link ReaderBuffered}. */
private abstract static class BaseReader<T extends BaseReader<?>> implements Reader<T> {
private abstract static class BaseReader implements Reader {
SortableFeature next;
@Override
@ -403,8 +413,8 @@ class ExternalMergeSort implements FeatureSort {
}
@Override
public final int compareTo(T o) {
return next.compareTo(o.next);
public final long nextKey() {
return next.key();
}
abstract SortableFeature readNextFeature();
@ -413,6 +423,7 @@ class ExternalMergeSort implements FeatureSort {
/** Writer that a single thread can use to write features independent of writers used in other threads. */
@NotThreadSafe
private class ThreadLocalWriter implements CloseableConusmer<SortableFeature> {
private Chunk currentChunk;
private ThreadLocalWriter() {
@ -456,6 +467,7 @@ class ExternalMergeSort implements FeatureSort {
/** Write features to the chunk file through a memory-mapped file. */
private class WriterMmap implements Writer {
private final FileChannel channel;
private final MappedByteBuffer buffer;
@ -555,7 +567,7 @@ class ExternalMergeSort implements FeatureSort {
return mmapIO ? new WriterMmap(path) : new WriterBuffered(path, gzip);
}
private Reader<?> newReader() {
private Reader newReader() {
return mmapIO ? new ReaderMmap(path, itemCount) : new ReaderBuffered(path, itemCount, gzip);
}
@ -613,7 +625,8 @@ class ExternalMergeSort implements FeatureSort {
}
/** Memory-map the chunk file, then iterate through all features in it. */
private class ReaderMmap extends BaseReader<ReaderMmap> {
private class ReaderMmap extends BaseReader {
private final int count;
private final FileChannel channel;
private final MappedByteBuffer buffer;

Wyświetl plik

@ -0,0 +1,79 @@
/*
* Licensed to GraphHopper GmbH under one or more contributor
* license agreements. See the NOTICE file distributed with this work for
* additional information regarding copyright ownership.
*
* GraphHopper GmbH licenses this file to you under the Apache License,
* Version 2.0 (the "License"); you may not use this file except in
* compliance with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.onthegomap.planetiler.collection;
/**
* API for min-heaps that keep track of {@code int} ids in the range {@code [0, capacity)} ordered by {@code long}
* values, where the capacity is fixed when the heap is created.
* <p>
* Ported from <a href=
* "https://github.com/graphhopper/graphhopper/blob/master/core/src/main/java/com/graphhopper/coll/MinHeapWithUpdate.java">GraphHopper</a>
* and modified to extract a common interface for subclass implementations.
*/
public interface LongMinHeap {
/**
* Returns a new min-heap where each element has 4 children backed by elements in an array.
* <p>
* This is slightly faster than a traditional binary min heap due to a shallower, more cache-friendly memory layout.
*
* @param elements the maximum number of elements the heap can hold; valid ids are {@code 0} to {@code elements - 1}
*/
static LongMinHeap newArrayHeap(int elements) {
return new ArrayLongMinHeap(elements);
}
/**
* @return the number of elements currently in the heap
*/
int size();
/**
* @return true if the heap contains no elements
*/
boolean isEmpty();
/**
* Adds an element to the heap. The given id must be non-negative and strictly less than the capacity the heap was
* created with. It is illegal to push the same id twice (unless it was polled/removed before). To update the value of
* an id contained in the heap use the {@link #update} method.
* <p>
* The array-backed implementation throws {@link IllegalArgumentException} for an id out of range and
* {@link IllegalStateException} when the heap is full or already contains the id.
*/
void push(int id, long value);
/**
* @return true if the heap contains an element with the given id
*/
boolean contains(int id);
/**
* Updates the element with the given id. The complexity of this method is O(log(N)), just like push/poll. It is
* illegal to update elements that are not contained in the heap. Use {@link #contains} to check the existence of an
* id.
*/
void update(int id, long value);
/**
* Updates the weight of the head element in the heap, pushing it down and bubbling up the new min element if
* necessary.
* <p>
* NOTE(review): callers in this codebase only invoke this on a non-empty heap; behavior on an empty heap is not
* specified here.
*/
void updateHead(long value);
/**
* @return the id of the next element to be polled, i.e. the same as calling poll() without removing the element
*/
int peekId();
/**
* @return the value of the next element to be polled, i.e. the current minimum value in the heap
*/
long peekValue();
/**
* Extracts the element with minimum value from the heap.
* <p>
* NOTE(review): check {@link #isEmpty()} first; behavior on an empty heap is not specified here.
*
* @return the id of the removed element
*/
int poll();
/**
* Removes all elements from the heap so that every id can be pushed again.
*/
void clear();
}

Wyświetl plik

@ -0,0 +1,369 @@
/*
* Licensed to GraphHopper GmbH under one or more contributor
* license agreements. See the NOTICE file distributed with this work for
* additional information regarding copyright ownership.
*
* GraphHopper GmbH licenses this file to you under the Apache License,
* Version 2.0 (the "License"); you may not use this file except in
* compliance with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.onthegomap.planetiler.collection;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import com.carrotsearch.hppc.IntArrayList;
import com.carrotsearch.hppc.IntHashSet;
import com.carrotsearch.hppc.IntSet;
import java.util.PriorityQueue;
import java.util.Random;
import org.junit.jupiter.api.Test;
/**
 * Unit tests for {@link LongMinHeap}, run against the heap returned by {@link LongMinHeap#newArrayHeap(int)}.
 * <p>
 * Ported from <a href=
 * "https://github.com/graphhopper/graphhopper/blob/master/core/src/test/java/com/graphhopper/coll/MinHeapWithUpdateTest.java">GraphHopper</a>
 * and modified to use long instead of float values, use stable random seed for reproducibility, and to use new
 * implementations.
 */
class LongMinHeapTest {

  protected LongMinHeap heap;

  void create(int capacity) {
    heap = LongMinHeap.newArrayHeap(capacity);
  }

  /** Polls until the heap is empty and returns the ids in the order they were removed. */
  private IntArrayList pollAll() {
    IntArrayList order = new IntArrayList();
    while (!heap.isEmpty()) {
      order.add(heap.poll());
    }
    return order;
  }

  /** Pushes ids 4 and 1 into the (empty) heap and checks that they are peeked and polled smallest value first. */
  private void pushTwoAndVerifyOrder() {
    heap.push(4, 63L);
    heap.push(1, 21L);
    assertEquals(2, heap.size());
    assertEquals(1, heap.peekId());
    assertEquals(21L, heap.peekValue());
    assertEquals(1, heap.poll());
    assertEquals(4, heap.poll());
    assertTrue(heap.isEmpty());
  }

  @Test
  void outOfRange() {
    create(4);
    // valid ids for capacity 4 are 0..3
    assertThrows(IllegalArgumentException.class, () -> heap.push(4, 12L));
    assertThrows(IllegalArgumentException.class, () -> heap.push(-1, 12L));
  }

  @Test
  void tooManyElements() {
    create(3);
    for (int id : new int[]{1, 2, 0}) {
      heap.push(id, 1L);
    }
    // the heap now holds 3 of 3 elements, so every further push is an error
    // (both ids used below are also duplicates, which would not be allowed either)
    assertThrows(IllegalStateException.class, () -> heap.push(1, 1L));
    assertThrows(IllegalStateException.class, () -> heap.push(2, 61L));
  }

  @Test
  void duplicateElements() {
    create(5);
    heap.push(1, 2L);
    heap.push(0, 4L);
    heap.push(2, 1L);
    assertEquals(2, heap.poll());
    // id 2 has been polled, so it may be pushed again...
    heap.push(2, 6L);
    // ...but not a second time while it is still in the heap
    assertThrows(IllegalStateException.class, () -> heap.push(2, 4L));
  }

  @Test
  void testContains() {
    create(4);
    heap.push(1, 1L);
    heap.push(2, 7L);
    heap.push(0, 5L);
    assertFalse(heap.contains(3));
    assertTrue(heap.contains(1));
    assertEquals(1, heap.poll());
    assertFalse(heap.contains(1));
  }

  @Test
  void containsAfterClear() {
    create(4);
    heap.push(1, 1L);
    heap.push(2, 1L);
    assertEquals(2, heap.size());
    heap.clear();
    for (int id = 0; id <= 2; id++) {
      assertFalse(heap.contains(id));
    }
  }

  @Test
  void testSize() {
    create(10);
    assertEquals(0, heap.size());
    assertTrue(heap.isEmpty());
    heap.push(9, 36L);
    heap.push(5, 23L);
    heap.push(3, 23L);
    assertEquals(3, heap.size());
    assertFalse(heap.isEmpty());
  }

  @Test
  void testClear() {
    create(5);
    assertTrue(heap.isEmpty());
    heap.push(3, 12L);
    heap.push(4, 3L);
    assertEquals(2, heap.size());
    heap.clear();
    assertTrue(heap.isEmpty());
    // a cleared heap must behave like a fresh one, including for id 4 which was present before
    pushTwoAndVerifyOrder();
  }

  @Test
  void testPush() {
    create(5);
    pushTwoAndVerifyOrder();
  }

  @Test
  void testPeek() {
    create(5);
    heap.push(4, -16L);
    heap.push(2, 13L);
    heap.push(1, -51L);
    heap.push(3, 4L);
    assertEquals(1, heap.peekId());
    assertEquals(-51L, heap.peekValue());
  }

  @Test
  void pushAndPoll() {
    create(10);
    heap.push(9, 36L);
    heap.push(5, 23L);
    heap.push(3, 23L);
    assertEquals(3, heap.size());
    heap.poll();
    assertEquals(2, heap.size());
    heap.poll();
    heap.poll();
    assertTrue(heap.isEmpty());
  }

  @Test
  void pollSorted() {
    create(10);
    heap.push(9, 36L);
    heap.push(5, 21L);
    heap.push(3, 23L);
    heap.push(8, 57L);
    heap.push(7, 22L);
    assertEquals(IntArrayList.from(5, 7, 3, 9, 8), pollAll());
  }

  @Test
  void poll() {
    create(10);
    assertTrue(heap.isEmpty());
    assertEquals(0, heap.size());

    // size and emptiness are checked after every single push...
    int[] ids = {9, 5, 3, 8};
    long[] values = {36L, 21L, 23L, 57L};
    for (int i = 0; i < ids.length; i++) {
      heap.push(ids[i], values[i]);
      assertFalse(heap.isEmpty());
      assertEquals(i + 1, heap.size());
    }

    // ...and after every single poll, which must return ids by ascending value
    int[] expectedOrder = {5, 3, 9, 8};
    for (int i = 0; i < expectedOrder.length; i++) {
      assertEquals(expectedOrder[i], heap.poll());
      int remaining = expectedOrder.length - 1 - i;
      if (remaining == 0) {
        assertTrue(heap.isEmpty());
      } else {
        assertFalse(heap.isEmpty());
      }
      assertEquals(remaining, heap.size());
    }
  }

  @Test
  void clear() {
    create(10);
    heap.push(9, 36L);
    heap.push(5, 21L);
    heap.push(3, 23L);
    heap.clear();
    assertTrue(heap.isEmpty());
    assertEquals(0, heap.size());
  }

  @Test
  void poll100Ascending() {
    create(100);
    for (int id = 1; id < 100; id++) {
      heap.push(id, id);
    }
    for (int expected = 1; expected < 100; expected++) {
      assertEquals(expected, heap.poll());
    }
  }

  @Test
  void poll100Descending() {
    create(100);
    for (int id = 99; id >= 1; id--) {
      heap.push(id, id);
    }
    for (int expected = 1; expected < 100; expected++) {
      assertEquals(expected, heap.poll());
    }
  }

  @Test
  void update() {
    create(10);
    heap.push(9, 36L);
    heap.push(5, 21L);
    heap.push(3, 23L);
    // decreasing a value can make the element the new head
    heap.update(3, 1L);
    assertEquals(3, heap.peekId());
    // increasing the head's value hands the head over to the next smallest
    heap.update(3, 100L);
    assertEquals(5, heap.peekId());
    heap.update(9, -13L);
    assertEquals(9, heap.peekId());
    assertEquals(-13L, heap.peekValue());
    assertEquals(IntArrayList.from(9, 5, 3), pollAll());
  }

  @Test
  void updateHead() {
    create(10);
    for (int id = 1; id <= 5; id++) {
      heap.push(id, id);
    }
    // heads 1, 2 and 3 get re-weighted to 6, 7 and 8 in turn, leaving 4 and 5 in front
    heap.updateHead(6);
    heap.updateHead(7);
    heap.updateHead(8);
    assertEquals(IntArrayList.from(4, 5, 1, 2, 3), pollAll());
  }

  @Test
  void randomPushsThenPolls() {
    Random rnd = new Random(0);
    int size = 1 + rnd.nextInt(100);
    PriorityQueue<Entry> reference = new PriorityQueue<>(size);
    create(size);
    IntSet usedIds = new IntHashSet();
    // fill both heaps completely with distinct ids and random values
    while (reference.size() < size) {
      int id = rnd.nextInt(size);
      if (usedIds.add(id)) {
        long val = (long) (Long.MAX_VALUE * rnd.nextFloat());
        reference.add(new Entry(id, val));
        heap.push(id, val);
      }
    }
    // both must now drain in exactly the same order
    while (!reference.isEmpty()) {
      Entry expected = reference.poll();
      assertEquals(expected.val, heap.peekValue());
      assertEquals(expected.id, heap.poll());
      assertEquals(reference.size(), heap.size());
    }
  }

  @Test
  void randomPushsAndPolls() {
    Random rnd = new Random(0);
    int size = 1 + rnd.nextInt(100);
    PriorityQueue<Entry> reference = new PriorityQueue<>(size);
    create(size);
    IntSet usedIds = new IntHashSet();
    int pushCount = 0;
    for (int step = 0; step < 1000; step++) {
      // always push when there is nothing to poll, otherwise flip a coin
      boolean push = reference.isEmpty() || (rnd.nextBoolean());
      if (push) {
        int id = rnd.nextInt(size);
        // ids that are still in the heap are skipped for this step
        if (usedIds.add(id)) {
          long val = (long) (Long.MAX_VALUE * rnd.nextFloat());
          reference.add(new Entry(id, val));
          heap.push(id, val);
          pushCount++;
        }
      } else {
        Entry expected = reference.poll();
        assert expected != null;
        assertEquals(expected.val, heap.peekValue());
        assertEquals(expected.id, heap.poll());
        assertEquals(reference.size(), heap.size());
        usedIds.removeAll(expected.id);
      }
    }
    assertTrue(pushCount > 0);
  }

  /** An (id, value) pair ordered by value, used to mirror the heap's contents in a {@link PriorityQueue}. */
  static class Entry implements Comparable<Entry> {

    int id;
    long val;

    public Entry(int id, long val) {
      this.id = id;
      this.val = val;
    }

    @Override
    public int compareTo(Entry o) {
      return Long.compare(val, o.val);
    }
  }
}