Mittwoch, November 15, 2017

Java 8 Collection Beispiel | Stream#collect()

Als Ergänzung zum letzten Blogeintrag heute der Kollege von Stream#reduce(): Stream#collect(): In den folgenden Überlegungen versuche ich herauszufinden, wie sich die beiden Methoden voneinander abgrenzen. D.h. wann sollte ich #reduce() verwenden, wann besser #collect()? Als erstes ein Blick in die Interface Definition der beiden Methoden.

interface Stream<T> {
    // Die verschiedenen reduce(...) Methoden
    Optional<T> reduce(BinaryOperator<T> accumulator);
              T reduce(T identity,
                       BinaryOperator<T> accumulator);
          <U> U reduce(U identity,
                       BiFunction<U, ? super T, U> accumulator,
                       BinaryOperator<U> combiner);

    // Die verschiedenen collect(...) Methoden
          <U> U collect(Supplier<U> supplier,
                        BiConsumer<U, ? super T> accumulator,
                        BiConsumer<U, U> combiner);
       <R, A> R collect(Collector<? super T, A, R> collector);
}

Gemeinsamer Parameter in allen Methodendefinition ist der accumulator. In der reduce Variante ist der accumulator eine BiFunction und liefert ein Ergebnis zurück. In der collect Variante wird für den accumulator ein BiConsumer übergeben. Dieser liefert kein Ergebnis zurück. Eine weitere Gemeinsamkeit findet sich im Parameter combiner, welcher im Falle der parallelen Abarbeitung eines Streams zur Anwendung kommt.

Bei mir drängt sich die Frage auf, wann benutze ich welche Methode. Gibt es dafür eine Regel? Einen Regelsatz oder zumindest eine Faustformel? Bei meinem Streifzug durch das Internet konnte ich keine definitiven Regeln finden. Dafür aber allerhand Beispiele für reduce und collect. Zu erst ein Beispiel für reduce:

List<Integer> numbers = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
assertThat(numbers.stream().reduce((n, m) -> n + m).get()).isEqualTo(55);

Falls identity = 0 angenommen werden kann, dann kann man auf die Optional#get() Abfrage verzichten.

assertThat(numbers.stream().reduce(0, (n, m) -> n + m)).isEqualTo(55);
assertThat(numbers.stream().reduce(0, Integer::sum)).isEqualTo(55);

Wie würde das mit collect aussehen?

assertThat(numbers.stream().collect(
    Collectors.summingInt(n -> n)).intValue()).isEqualTo(55);

Für die collect Methode bietet Java einige nützliche Utilities in java.util.stream.Collectors an (Siehe Java API Collectors). Eine dieser nützlichen Helferlein ist, wie oben gesehen, z.B. Collectors#summingXxx(). Im Vergleich zu dem reduce Beispiel wirkt die collect-Variante etwas umständlich. Ich würde hier die reduce Variante vorziehen. Das Ergebnis ist das Gleiche.

Falls man die Sache ausschreibt (also nicht die Helfer aus der Klasse Collectors verwendet) und nicht wie Collectors.summingInt() intern mit einem Array arbeitet, dann sieht die Sache für collect folgendermaßen aus (Das ist wirklich keine schöne Lösung, zeigt aber, wo die Schwäche bzw. dann auch die Stärke von collect liegt): Zunächst benötigt man eine Ablage für die Akkumulationsergebnisse, da der BiConsumer kein Ergebnis zurückliefert. Ich nenne die Zwischenablage hier IntHolder:

private class IntHolder {
    private int value;
    public IntHolder(int value) {
        this.value = value;
    }
    public void setValue(int value) {
        this.value = value;
    }
    public int getValue() {
        return value;
    }
}

Im Anschluss kann collect() mit den Parametern gefüttert werden.

IntHolder intHolder = numbers.stream().collect(() -> {
    return new IntHolder(0);
}, (l, v) -> {
    l.setValue(l.getValue() + v);
}, (l, m) -> {
    l.setValue(l.getValue() + m.getValue());
});
assertThat(intHolder.getValue()).isEqualTo(55);

Durch die Verwendung eines IntHolder wird die Implementierung aufgebläht. Reduce kann in diesem Fall eindeutig mit einer kompakteren Schreibweise punkten. Hier noch drei Alternativen für das Summieren von Integers aus einem Stream:

assertThat(numbers.stream().mapToInt(i -> i.intValue()).sum()).isEqualTo(55);
assertThat(numbers.stream().mapToInt(Integer::intValue).sum()).isEqualTo(55);
assertThat(numbers.stream().mapToInt(i -> i).sum()).isEqualTo(55);

Statt int-Werte zu addieren, versuche ich es nun einmal mit String Werten.

final String STRING = "abcdefghi";
List<String> strings = Arrays.asList("a", "b", "c", "d", "e", "f", "g", "h", "i");
assertThat(strings.stream().reduce(
        "",
        (a, b) -> a + b)).isEqualTo(STRING); // (1)
assertThat(strings.stream().reduce(
        new StringBuilder(),
        StringBuilder::append,
        StringBuilder::append).toString()).isEqualTo(STRING); // (2)

Hier zeigt sich der Nachteil von reduce. In Beispiel (1) wird in jedem Akkumulationsschritt ein neuer String erzeugt. Das ist hinglänglich bekannt, dass dies in Java aus Performance-Sicht keine gute Idee ist. In Beispiel (2) wird der StringBuilder verwendet. Das Beispiel kann funktionieren, so lange der Stream nicht parallel abgearbeitet wird.

final String STRING = "abcdefghijklmnopqrstuvwxyz";
List<String> strings = Arrays.asList(
    "abc", "def", "ghi", "jkl", "mno", "pqr", "stu", "vwx", "yz");
assertThat(strings.parallelStream().reduce(
        new StringBuilder(),
        StringBuilder::append,
        StringBuilder::append).toString()).isNotEqualTo(STRING); (1)

Mit collect könnte man folgendes formulieren:

assertThat(strings.stream().collect(
        () -> new StringBuilder(),
        (a, b) -> a.append(b),
        (a, b) -> a.append(b)).toString()).isEqualTo(STRING); // (2)
assertThat(strings.stream().collect(
        () -> new StringBuilder(),
        StringBuilder::append,
        StringBuilder::append).toString()).isEqualTo(STRING); // (3)        
assertThat(strings.stream().collect(Collectors.joining())).isEqualTo(STRING); // (4)

Die Beispiele (1) und (3) sehen frappierend ähnlich aus. Man kann leicht übersehen, dass in dem einen Fall #reduce() verwendet wird, in dem anderen Fall #collect(). Die Variante mit collect funktioniert allerdings korrekt mit parallelen Streams. Am einfachsten ist Variante (4). Hier wird die Utility Klasse Collectors verwendet.

Ja und nun? Welche Schlussfolgerung kann man ziehen?

Der Unterschied zeigt sich im Akkumulator. #reduce() führt mit der BiFunction als Akkumulator eine funktionale Reduktion aus. #collect() ändert mit dem BiConsumer einen existierenden Wert und arbeitet somit nicht seiteneffektfrei. Das ist laut der Javadoc von BiConsumer auch nicht gefordert. Die Vereinbarung gilt für alle Varianten von java.util.function.Consumer. (Das kann bei der parallelen Verarbeitung eines Streams eventuell ein Problem sein). #reduce() ist dann im Vorteil, wenn die Zwischenergebnisse von accumulator ohne große Performance Verluste angelegt werden können. Zum Beispiel sind die einfachen numerischen Datentypen dafür sehr gut geeignet. D.h. habe ich einen Akkumulator, der einen numerischen Datentyp erzeugt, dann ist vermutlich #reduce() der geeignetere Kandidat. Sind (komplexe) Objekte das Zwischenergebnis (z.B. String wie oben), ist möglicherweise #collect() die bessere Wahl. Von der Definition her bietet #reduce() die größte Sicherheit bei der Verarbeitung eines parallelen Streams. Aber diese Sicherheit kann trügerisch sein:

List<Integer> numbers = Arrays.asList(
    1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20);
assertThat(numbers.stream().reduce(0, (n, m) -> n - m).intValue()).isEqualTo(-210);
assertThat(numbers.parallelStream().reduce(0, (n, m) -> n - m).intValue()).isNotEqualTo(-210);

Interessant ist vielleicht eine kurze Betrachtung der Laufzeiten von reduce und collect für das Summieren von Integer-Werten.

package misc;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.joda.time.DateTime;
import org.junit.Test;

public class ReduceVsCollectRuntime {

    @Test
    public void joiningSomeStrings() {
        List<Integer> integers = integers(4_000_000);

        long runtimeWithReduce = 0;
        long runtimeWithCollect = 0;
        long runtimeWithCollectors = 0;

        runtimeWithReduce = runtime(
            integers.stream(), ReduceVsCollectRuntime::joinWithReduce);
        runtimeWithCollect = runtime(
            integers.stream(), ReduceVsCollectRuntime::joinWithCollect);
        runtimeWithCollectors = runtime(
            integers.stream(), ReduceVsCollectRuntime::joinWithCollectors);

        System.out.println("Reduce Millis: " + runtimeWithReduce);
        System.out.println("Collect Millis: " + runtimeWithCollect);
        System.out.println("Collectors Millis: " + runtimeWithCollectors);

        runtimeWithReduce = runtime(
            integers.parallelStream(), ReduceVsCollectRuntime::joinWithReduce);
        runtimeWithCollect = runtime(
            integers.parallelStream(), ReduceVsCollectRuntime::joinWithCollect);
        runtimeWithCollectors = runtime(
            integers.parallelStream(), ReduceVsCollectRuntime::joinWithCollectors);

        System.out.println("Parallel Reduce Millis: " + runtimeWithReduce);
        System.out.println("Parallel Collect Millis: " + runtimeWithCollect);
        System.out.println("Parallel CollectorsMillis: " + runtimeWithCollectors);
    }

    public static Long joinWithReduce(Stream<Integer> stream) {
        return stream.reduce(
                0L,
                (n, m) -> n.longValue() + m.longValue(),
                (n, m) -> n.longValue() + m.longValue());
    }

    public static Long joinWithCollect(Stream<Integer> stream) {
        return stream.collect(
                () -> new long[1],
                (n, m) -> n[0] += m.longValue(),
                (n, m) -> n[0] += m[0])[0];
    }

    public static Long joinWithCollectors(Stream<Integer> stream) {
        return stream.collect(Collectors.summingLong(i -> i));
    }

    private long runtime(Stream<Integer> stream, Function<Stream<Integer>, Long> streamFunction) {
        DateTime start = DateTime.now();
        Long joinReduce = streamFunction.apply(stream);
        System.out.println("Sum: " + joinReduce);
        DateTime end = DateTime.now();
        return end.getMillis() - start.getMillis();
    }

    private List<Integer> integers(int nums) {
        Random random = new Random();
        List<Integer> integers = new ArrayList<>();
        for (int i = 0; i < nums; i++) {
            integers.add(random.nextInt());
        }
        return integers;
    }

}

Und das ist das Ergebnis:

Sum: 4832884415034
Sum: 4832884415034
Sum: 4832884415034
Reduce Millis: 171
Collect Millis: 16
Collectors Millis: 31
Sum: 4832884415034
Sum: 4832884415034
Sum: 4832884415034
Parallel Reduce Millis: 62
Parallel Collect Millis: 0
Parallel Collectors Millis: 16

In den getesteten Fällen schneidet #collect() jeweils am besten ab (Gemessen auf einem i5-4570 CPU 3.2GHz). Also gewinnt immer #collect()?


Weitere Links:

Allgemeine Informationen zu Lambdas, Functions, Streams, etc.:

AssertJ und java.util.List

AssertJ hat eine praktische Möglichkeit, Listen in JUnit Tests abzuprüfen. Insbesondere, wenn in der Liste komplexe Objekte abgelegt sind, s...