Support for each loop in workers (#83)

pull/84/head
Michael Barry 2022-02-23 20:32:41 -05:00 zatwierdzone przez GitHub
rodzic 0357f4ba8f
commit 209361eb7e
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: 4AEE18F83AFDEB23
9 zmienionych plików z 151 dodań i 35 usunięć

Wyświetl plik

@ -0,0 +1,32 @@
package com.onthegomap.planetiler.collection;
import java.util.Iterator;
import java.util.function.Supplier;
/**
 * A {@link Supplier} whose {@link #get()} signals "no more elements" by returning {@code null}, combined with an
 * {@link Iterable} view so that the elements can be consumed with a for-each loop.
 * <p>
 * Every iterator created by {@link #iterator()} pulls from this same {@link #get()}, so an element handed to one
 * iterator (or to a direct {@code get()} caller) is not seen again by another.
 *
 * @param <T> Type of element returned
 */
public interface IterableOnce<T> extends Iterable<T>, Supplier<T> {

  /**
   * Returns an iterator backed by {@link #get()}.
   * <p>
   * The iterator fetches one element as soon as it is created and one more on every {@code next()} call.
   * {@code hasNext()} reports whether the most recently fetched element is non-null. Calling {@code next()} once
   * the iterator is exhausted returns {@code null} instead of throwing.
   *
   * @return a look-ahead iterator over the remaining elements of this supplier
   */
  @Override
  default Iterator<T> iterator() {
    final class LookaheadIterator implements Iterator<T> {

      // Element that the following next() call will return; fetched eagerly, null once the supplier runs dry.
      private T lookahead = get();

      @Override
      public boolean hasNext() {
        return lookahead != null;
      }

      @Override
      public T next() {
        T current = lookahead;
        // refill the look-ahead slot before handing out the current element
        lookahead = get();
        return current;
      }
    }

    return new LookaheadIterator();
  }
}

Wyświetl plik

@ -31,7 +31,6 @@ import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.atomic.LongAccumulator;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.slf4j.Logger;
@ -219,21 +218,20 @@ public class MbtilesWriter {
}
}
private void tileEncoderSink(Supplier<TileBatch> prev) throws IOException {
private void tileEncoderSink(Iterable<TileBatch> prev) throws IOException {
tileEncoder(prev, batch -> {
// no next step
});
}
private void tileEncoder(Supplier<TileBatch> prev, Consumer<TileBatch> next) throws IOException {
TileBatch batch;
private void tileEncoder(Iterable<TileBatch> prev, Consumer<TileBatch> next) throws IOException {
/*
* To optimize emitting many identical consecutive tiles (like large ocean areas), memoize output to avoid
* recomputing if the input hasn't changed.
*/
byte[] lastBytes = null, lastEncoded = null;
while ((batch = prev.get()) != null) {
for (TileBatch batch : prev) {
Queue<Mbtiles.TileEntry> result = new ArrayDeque<>(batch.size());
FeatureGroup.TileFeatures last = null;
// each batch contains tile ordered by z asc, x asc, y desc
@ -268,7 +266,7 @@ public class MbtilesWriter {
}
}
private void tileWriter(Supplier<TileBatch> tileBatches) throws ExecutionException, InterruptedException {
private void tileWriter(Iterable<TileBatch> tileBatches) throws ExecutionException, InterruptedException {
db.createTables();
if (!config.deferIndexCreation()) {
db.addTileIndex();
@ -292,8 +290,7 @@ public class MbtilesWriter {
Timer time = null;
int currentZ = Integer.MIN_VALUE;
try (var batchedWriter = db.newBatchedTileWriter()) {
TileBatch batch;
while ((batch = tileBatches.get()) != null) {
for (TileBatch batch : tileBatches) {
Queue<Mbtiles.TileEntry> tiles = batch.out.get();
Mbtiles.TileEntry tile;
while ((tile = tiles.poll()) != null) {

Wyświetl plik

@ -57,10 +57,9 @@ public abstract class SimpleReader implements Closeable {
.fromGenerator("read", read())
.addBuffer("read_queue", 1000)
.<SortableFeature>addWorker("process", threads, (prev, next) -> {
SourceFeature sourceFeature;
var featureCollectors = new FeatureCollector.Factory(config, stats);
FeatureRenderer renderer = newFeatureRenderer(writer, config, next);
while ((sourceFeature = prev.get()) != null) {
for (SourceFeature sourceFeature : prev) {
featuresRead.incrementAndGet();
FeatureCollector features = featureCollectors.get(sourceFeature);
if (sourceFeature.latLonGeometry().getEnvelopeInternal().intersects(latLonBounds)) {

Wyświetl plik

@ -256,11 +256,10 @@ public class OsmReader implements Closeable, MemoryEstimator.HasEstimate {
Counter ways = waysProcessed.counterForThread();
Counter rels = relsProcessed.counterForThread();
ReaderElement readerElement;
var featureCollectors = new FeatureCollector.Factory(config, stats);
NodeLocationProvider nodeLocations = newNodeLocationProvider();
FeatureRenderer renderer = createFeatureRenderer(writer, config, next);
while ((readerElement = prev.get()) != null) {
for (ReaderElement readerElement : prev) {
SourceFeature feature = null;
if (readerElement instanceof ReaderNode node) {
nodes.inc();

Wyświetl plik

@ -42,7 +42,6 @@ import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
@ -131,8 +130,7 @@ public class Wikidata {
.addWorker("filter", processThreads, fetcher::filter)
.addBuffer("fetch_queue", 1_000_000, 100)
.sinkTo("fetch", 1, prev -> {
Long id;
while ((id = prev.get()) != null) {
for (Long id : prev) {
fetcher.fetch(id);
}
fetcher.flush();
@ -217,9 +215,8 @@ public class Wikidata {
}
/** Only pass elements that the profile cares about to next step in pipeline. */
private void filter(Supplier<ReaderElement> prev, Consumer<Long> next) {
ReaderElement elem;
while ((elem = prev.get()) != null) {
private void filter(Iterable<ReaderElement> prev, Consumer<Long> next) {
for (ReaderElement elem : prev) {
switch (elem.getType()) {
case ReaderElement.NODE -> nodes.inc();
case ReaderElement.WAY -> ways.inc();

Wyświetl plik

@ -1,5 +1,6 @@
package com.onthegomap.planetiler.worker;
import com.onthegomap.planetiler.collection.IterableOnce;
import com.onthegomap.planetiler.stats.Counter;
import com.onthegomap.planetiler.stats.Stats;
import java.util.ArrayDeque;
@ -12,7 +13,6 @@ import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Supplier;
/**
* A high-performance blocking queue to hand off work from producing threads to consuming threads.
@ -28,7 +28,7 @@ import java.util.function.Supplier;
*
* @param <T> the type of elements held in this queue
*/
public class WorkQueue<T> implements AutoCloseable, Supplier<T>, Consumer<T> {
public class WorkQueue<T> implements AutoCloseable, IterableOnce<T>, Consumer<T> {
private final BlockingQueue<Queue<T>> itemQueue;
private final int batchSize;
@ -87,7 +87,7 @@ public class WorkQueue<T> implements AutoCloseable, Supplier<T>, Consumer<T> {
}
/** Returns a reader optimized to produce items for a single thread. */
public Supplier<T> threadLocalReader() {
public IterableOnce<T> threadLocalReader() {
return readerProvider.get();
}
@ -174,7 +174,7 @@ public class WorkQueue<T> implements AutoCloseable, Supplier<T>, Consumer<T> {
}
/** Caches thread-local values so that a single thread can read new items without having to do thread-local lookups. */
private class ReaderForThread implements Supplier<T> {
private class ReaderForThread implements IterableOnce<T> {
Queue<T> readBatch = null;
final Counter dequeueBlockTimeNanos = dequeueBlockTimeNanosAll.counterForThread();

Wyświetl plik

@ -2,6 +2,7 @@ package com.onthegomap.planetiler.worker;
import static com.onthegomap.planetiler.worker.Worker.joinFutures;
import com.onthegomap.planetiler.collection.IterableOnce;
import com.onthegomap.planetiler.stats.ProgressLoggers;
import com.onthegomap.planetiler.stats.Stats;
import java.time.Duration;
@ -10,7 +11,6 @@ import java.util.Iterator;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.function.Consumer;
import java.util.function.Supplier;
/**
* A mini-framework for chaining sequential steps that run in dedicated threads with a queue between each.
@ -107,7 +107,7 @@ public record WorkerPipeline<T>(
* @param next call {@code next.accept} to pass items to the next step of the pipeline
* @throws Exception if an error occurs, will be rethrown by {@link #await()} as a {@link RuntimeException}
*/
void run(Supplier<I> prev, Consumer<O> next) throws Exception;
void run(IterableOnce<I> prev, Consumer<O> next) throws Exception;
}
/**
@ -125,7 +125,7 @@ public record WorkerPipeline<T>(
* elements to process
* @throws Exception if an error occurs, will be rethrown by {@link #await()} as a {@link RuntimeException}
*/
void run(Supplier<I> prev) throws Exception;
void run(IterableOnce<I> prev) throws Exception;
}
/**
@ -151,7 +151,7 @@ public record WorkerPipeline<T>(
}
/** Builder for a new topology that does not yet have any steps. */
public static record Empty(String prefix, Stats stats) {
public record Empty(String prefix, Stats stats) {
/**
* Adds an initial step that runs {@code producer} in {@code threads} worker threads to produce items for this
@ -213,7 +213,7 @@ public record WorkerPipeline<T>(
*
* @param <O> type of elements that the next step must process
*/
public static record Builder<O>(
public record Builder<O>(
String prefix,
String name,
// keep track of previous elements so that build can wire-up the computation graph
@ -276,8 +276,7 @@ public record WorkerPipeline<T>(
*/
public WorkerPipeline<O> sinkToConsumer(String name, int threads, Consumer<O> consumer) {
return sinkTo(name, threads, (prev) -> {
O item;
while ((item = prev.get()) != null) {
for (O item : prev) {
consumer.accept(item);
}
});

Wyświetl plik

@ -0,0 +1,96 @@
package com.onthegomap.planetiler.collection;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;
import java.util.Set;
import org.junit.jupiter.api.Test;
/** Tests for the {@link java.util.Iterator} and for-each views that {@link IterableOnce} layers over {@code get()}. */
public class IterableOnceTest {

  /** Creates an {@link IterableOnce} that polls a fresh queue pre-filled with {@code items}. */
  private static IterableOnce<Integer> pollingFrom(List<Integer> items) {
    Queue<Integer> backing = new LinkedList<>(items);
    return backing::poll;
  }

  @Test
  public void testIterableOnceEmpty() {
    IterableOnce<Integer> nothing = () -> null;
    var cursor = nothing.iterator();
    // an exhausted iterator stays exhausted and keeps answering null from next()
    assertFalse(cursor.hasNext());
    assertNull(cursor.next());
    assertFalse(cursor.hasNext());
    assertNull(cursor.next());
  }

  @Test
  public void testSingleItem() {
    var cursor = pollingFrom(List.of(1)).iterator();
    assertTrue(cursor.hasNext());
    assertEquals(1, cursor.next());
    assertFalse(cursor.hasNext());
    assertNull(cursor.next());
  }

  @Test
  public void testMultipleItems() {
    var cursor = pollingFrom(List.of(1, 2)).iterator();
    assertTrue(cursor.hasNext());
    assertEquals(1, cursor.next());
    assertTrue(cursor.hasNext());
    assertEquals(2, cursor.next());
    assertFalse(cursor.hasNext());
    assertNull(cursor.next());
  }

  @Test
  public void testMultipleIterators() {
    IterableOnce<Integer> source = pollingFrom(List.of(1, 2));
    // each iterator fetches one element when it is created: the first holds 1, the second holds 2
    var first = source.iterator();
    var second = source.iterator();
    assertTrue(first.hasNext());
    assertTrue(second.hasNext());
    assertEquals(1, first.next());
    // the queue is already empty, so the first iterator has nothing left to fetch
    assertFalse(first.hasNext());
    assertTrue(second.hasNext());
    assertEquals(2, second.next());
    assertFalse(first.hasNext());
    assertFalse(second.hasNext());
  }

  @Test
  public void testForeach() {
    IterableOnce<Integer> source = pollingFrom(List.of(1, 2, 3, 4));
    Set<Integer> seen = new HashSet<>();
    for (var value : source) {
      seen.add(value);
    }
    assertEquals(Set.of(1, 2, 3, 4), seen);
  }

  @Test
  public void testForeachWithSupplierAccess() {
    IterableOnce<Integer> source = pollingFrom(List.of(1, 2, 3, 4));
    List<Integer> seen = new ArrayList<>();
    int loopCount = 0;
    for (var fromLoop : source) {
      seen.add(fromLoop);
      // pull directly from the supplier as well, competing with the loop's iterator for elements
      Integer fromSupplier = source.get();
      if (fromSupplier != null) {
        seen.add(fromSupplier);
      }
      loopCount++;
    }
    // every element is delivered exactly once, split between the loop and the direct get() calls
    assertEquals(List.of(1, 2, 3, 4), seen.stream().sorted().toList());
    assertEquals(3, loopCount);
  }
}

Wyświetl plik

@ -28,8 +28,7 @@ public class WorkerPipelineTest {
next.accept(1);
}).addBuffer("reader_queue", 1)
.<Integer>addWorker("process", 1, (prev, next) -> {
Integer item;
while ((item = prev.get()) != null) {
for (Integer item : prev) {
next.accept(item * 2 + 1);
next.accept(item * 2 + 2);
}
@ -76,8 +75,7 @@ public class WorkerPipelineTest {
.readFrom("reader", List.of(0, 1))
.addBuffer("reader_queue", 1)
.<Integer>addWorker("process", 1, (prev, next) -> {
Integer item;
while ((item = prev.get()) != null) {
for (Integer item : prev) {
next.accept(item * 2 + 1);
next.accept(item * 2 + 2);
}
@ -106,8 +104,7 @@ public class WorkerPipelineTest {
if (failureStage == 2) {
throw new ExpectedException();
}
Integer item;
while ((item = prev.get()) != null) {
for (Integer item : prev) {
next.accept(item * 2 + 1);
next.accept(item * 2 + 2);
}