Schnellste Möglichkeit zum Durchlaufen eines Numpy-Arrays

75904
embert

Ich habe eine Funktion geschrieben, um den Gamma-Koeffizienten eines Clusters zu berechnen. Der Engpass ist der Vergleich der Werte von dist_withingbis dist_between. Um dies zu beschleunigen, habe ich versucht, es mit Cython anzupassen und zu kompilieren (ich habe mich nur ein paar Mal mit C beschäftigt). Aber ich weiß nicht, wie ich schnell über numpy Arrays iterieren kann oder ob es überhaupt möglich ist, es schneller zu machen

for i in range(len(arr)):
    arr[i]

Ich dachte, ich einen Zeiger auf die Array - Daten verwenden können und in der Tat der Code ausgeführt wird nur in der Hälfte der Zeit, aber pointer1[i]und pointer2[j]in cdef unsigned int countlowernicht geben mir die erwarteten Werte aus den Arrays. Also, wie man richtig und schnell über ein Array iteriert? Und wo können noch Verbesserungen vorgenommen werden, auch wenn es in diesem Fall keinen Unterschied in Bezug auf die Laufzeitgeschwindigkeit gibt?

# cython: profile=True
import cython

import numpy as np
cimport numpy as np

from scipy.spatial.distance import squareform

DTYPE = np.float
DTYPEint = np.int

ctypedef np.float_t DTYPE_t
ctypedef np.int_t DTYPEint_t

@cython.profile(False)
cdef unsigned int countlower(np.ndarray[DTYPE_t, ndim=1] vec1,
                             np.ndarray[DTYPE_t, ndim=1] vec2,
                             int n1, int n2):
    # Function output corresponds to np.bincount(v1 < v2)[1]

    assert vec1.dtype == DTYPE and vec2.dtype == DTYPE

    cdef unsigned int i, j
    cdef unsigned int trues = 0
    cdef unsigned int* pointer1 = <unsigned int*> vec1.data
    cdef unsigned int* pointer2 = <unsigned int*> vec2.data

    for i in range(n1):
        for j in range(n2):
            if pointer1[i] < pointer2[j]:
                trues += 1

    return trues


def gamma(np.ndarray[DTYPE_t, ndim=2] Y, np.ndarray[DTYPEint_t, ndim=1] part):
    assert Y.dtype == DTYPE and part.dtype == DTYPEint

    if len(Y) != len(part):
        raise ValueError('Distance matrix and partition must have same shape')

    # defined locals
    cdef unsigned int K, c_label, n_, trues
    cdef unsigned int s_plus = 0
    cdef unsigned int s_minus = 0

    # assigned locals
    cdef np.ndarray n_in_ci = np.bincount(part)
    cdef int num_clust = len(n_in_ci) - 1
    cdef np.ndarray s = np.zeros(len(Y), dtype=DTYPE)

    # Partition should have at least two clusters
    K = len(set(part))
    if K < 2:
        return 0
    # Loop through clusters
    for c_label in range(1, K+1):
        dist_within = squareform(Y[part == c_label][:, part == c_label])
        dist_between = np.ravel(Y[part == c_label][:, part != c_label])
        n1 = len(dist_within)
        n2 = len(dist_between)

        trues = countlower(dist_within, dist_between, n1, n2)
        s_plus += trues
        s_minus += n1 * n2 - trues

    n_ =  s_plus + s_minus

    return (<double>s_plus - <double>s_minus) / <double>n_ if n_ != 0 else 0

Edit1: Wenn Sie nur die Zeiger übergeben, anstatt die Arrays an die zeitkritische Funktion zu übergeben (> 99% der Zeit wird dort verbracht), wird eine Beschleunigung von ~ 10% erzielt. Ich denke, manche Dinge können einfach nicht schneller gemacht werden

@cython.profile(False)
@cython.boundscheck(False)
@cython.nonecheck(False)
cdef unsigned int countlower(double* v1, double* v2, int n1, int n2):
    ''' Function output corresponds to np.bincount(v1 < v2)[1]'''
    ''' The upper is not correct. It rather corresponds to
    sum([np.bincount(v1[i] < v2)[1] for i in range(len(v1))])'''
    cdef unsigned int trues = 0

    cdef Py_ssize_t i, j
    with nogil, parallel():
        for i in prange(n1):
            for j in prange(n2):
                if v1[i] < v2[j]:
                    trues += 1
    return trues
Antworten
15

2 Antworten auf die Frage

24
Gareth Rees

1. Einleitung

Diese Frage ist schwierig, weil:

  1. Es ist nicht klar, was die Funktion countlowerbewirkt. Es ist immer eine gute Idee, einen Docstring für eine Funktion zu schreiben, in dem angegeben wird, was sie tut, welche Argumente erforderlich sind und was sie zurückgibt. (Und Testfälle werden immer geschätzt.)

  2. Es ist nicht klar, was die Rolle des Arguments n1und n2ist. Der Code in der Post gilt nur len(v1)für n1und len(v2)für n2. Ist das also eine Voraussetzung? Oder ist es manchmal möglich, andere Werte zu übergeben?

Ich werde im Folgenden davon ausgehen:

  1. die Spezifikation der countlowerFunktion ist Return the number of pairs i, j such that v1[i] < v2[j];

  2. n1ist immer len(v1)und n2ist immer len(v2);

  3. Die Cython-Details sind für das Problem nicht wesentlich, und es ist in Ordnung, in Python zu arbeiten.

Hier ist meine Neufassung der countlowerFunktion. Beachten Sie den docstring, den doctest und die einfache Implementierung, die die Sequenzelemente und nicht deren Indizes durchläuft :

def countlower1(v, w):
    """Return the number of pairs i, j such that v[i] < w[j].

    >>> countlower1(list(range(0, 200, 2)), list(range(40, 140)))
    4500

    """
    return sum(x < y for x in v for y in w)

Und hier ist ein Testfall mit 1000 Elementen, den ich im Rest dieser Antwort verwenden werde, um die Leistung verschiedener Implementierungen dieser Funktion zu vergleichen:

>>> v = np.array(list(range(0, 2000, 2)))
>>> w = np.array(list(range(400, 1400)))
>>> from timeit import timeit
>>> timeit(lambda:countlower1(v, w), number=1)
8.449613849865273

2. Vektorisieren

Der Grund für die Verwendung von NumPy besteht darin, dass Sie Operationen mit Arrays von numerischen Datentypen fester Größe vektorisieren können. Wenn Sie eine Operation erfolgreich vektorisieren können, wird sie hauptsächlich in C ausgeführt, wodurch der erhebliche Aufwand des Python-Interpreters vermieden wird.

Immer wenn Sie die Elemente eines Arrays durchlaufen, profitieren Sie von NumPy nicht, und dies ist ein Zeichen dafür, dass es an der Zeit ist, Ihre Herangehensweise zu überdenken.

Lassen Sie uns also die countlowerFunktion vektorisieren . Dies ist einfach mit einem spärlichen numpy.meshgrid:

import numpy as np

def countlower2(v, w):
    """Return the number of pairs i, j such that v[i] < w[j].

    >>> countlower2(np.arange(0, 2000, 2), np.arange(400, 1400))
    450000

    """
    grid = np.meshgrid(v, w, sparse=True)
    return np.sum(grid[0] < grid[1])

Mal sehen, wie schnell das im 1000-Element-Testfall ist:

>>> timeit(lambda:countlower2(v, w), number=1)
0.005706002004444599

Das ist etwa 1500 mal schneller als countlower1.

3. Verbessern Sie den Algorithmus

Für vektorisierte countlower2Arrays der Länge \ $ O (n) \ $ benötigt das vektorisierte System immer noch \ $ O (n ^ 2) \ $ Zeit, da jedes Elementpaar verglichen werden muss. Kann man das besser machen?

Angenommen, ich beginne damit, das erste Array zu sortieren v. Dann ein Element betrachtet yaus dem zweiten Array w, und die Stelle an, yin denen sortierte ersten Array passen würde, das heißt, finden, iso dass v[i - 1] < y <= v[i]. Dann yist größer als iElemente aus v. Diese Position kann in time \ $ O (\ lognn) \ $ using gefunden werden bisect.bisect_left, so dass der Algorithmus als Ganzes eine Laufzeit von \ $ O (n \ logn) \ $ hat.

Hier ist eine einfache Implementierung:

from bisect import bisect_left

def countlower3(v, w):
    """Return the number of pairs i, j such that v[i] < w[j].

    >>> countlower3(list(range(0, 2000, 2)), list(range(400, 1400)))
    450000

    """
    v = sorted(v)
    return sum(bisect_left(v, y) for y in w)

Diese Implementierung ist etwa dreimal schneller als countlower3im Testfall mit 1000 Elementen:

>>> timeit(lambda:countlower3(v, w), number=1)
0.0021441911812871695

Dies zeigt, wie wichtig es ist, den besten Algorithmus zu finden, und nicht nur den Algorithmus, den Sie haben, zu beschleunigen. Hier schlägt ein \ $ O (n \ log n) \ $ -Algorithmus im einfachen Python einen vektorisierten \ $ O (n ^ 2) \ $ -Algorithmus in NumPy.

4. Vektorisieren Sie erneut

Jetzt können wir den verbesserten Algorithmus vektorisieren numpy.searchsorted:

import numpy as np

def countlower4(v, w):
    """Return the number of pairs i, j such that v[i] < w[j].

    >>> countlower4(np.arange(0, 20000, 2), np.arange(4000, 14000))
    45000000

    """
    return np.sum(np.searchsorted(np.sort(v), w))

Und das ist sechsmal schneller:

>>> timeit(lambda:countlower4(v, w), number=1)
0.0003434771206229925

5. Antworten auf Ihre Fragen

In Kommentaren haben Sie gefragt:

  1. "Was heißt vektorisieren?" Bitte lesen Sie den Abschnitt " Was ist NumPy? " Der NumPy-Dokumentation, insbesondere den Abschnitt, der beginnt:

    Vektorisierung beschreibt das Fehlen expliziter Schleifen, Indexierungen usw. im Code. Diese Dinge finden natürlich nur hinter den Szenen statt. (in optimiertem, vorkompiliertem C-Code).

  2. "Was ist Maschennetz?" Bitte lesen Sie die Dokumentation fürnumpy.meshgrid .

    Ich verwende meshgrid, um ein NumPy-Array zu erstellen, griddas alle Elementpaare enthält, von x, ydenen xein Element vund yein Element von ist w. Dann <wende ich die Funktion auf diese Paare an und erhalte ein Array von Booleans, die ich zusammenfasse. Probieren Sie es im interaktiven Dolmetscher aus und überzeugen Sie sich selbst:

    >>> import numpy as np
    >>> v = [2, 4, 6]
    >>> w = [1, 3, 5]
    >>> np.meshgrid(v, w)
    [array([[2, 4, 6],
           [2, 4, 6],
           [2, 4, 6]]), array([[1, 1, 1],
           [3, 3, 3],
           [5, 5, 5]])]
    >>> _[0] < _[1]
    array([[False, False, False],
           [ True, False, False],
           [ True,  True, False]], dtype=bool)
    >>> np.sum(_)
    3
    
fuu, da das Lesen von Code für mich ziemlich schwierig ist, werde ich in Zukunft versuchen, eine bessere Dokumentation zu erstellen. Ihre Annahmen sind korrekt. Das Sortierzeug erweitert mein Denken. Was bedeutet Vektorisieren im Allgemeinen und im Fall von NumPy? Was haben Sie in Abschnitt 2 gemacht (was ist hier ein Netzgitter?) Die Hauptfunktion wurde im Vergleich zu meiner edit1-Version etwa 30-mal schneller. Keine Notwendigkeit mehr, Cython zu verwenden. Jedenfalls, wenn ich daraus eine Cython-Datei mache, wird es doppelt so lange dauern. Irgendeine Idee, warum das so ist? embert vor 6 Jahren 0
Siehe überarbeitete Antwort. Gareth Rees vor 6 Jahren 1
@embert "Wenn ich daraus eine Cython-Datei mache, wird es doppelt so lange dauern. Irgendeine Idee, warum das so ist?" Der Pythonâ € ™ Cythonâ € ™ Python fügt lediglich eine Indirektion hinzu, bei der der Python-Typ in eine "spezielle Form" für Cython umgewandelt werden muss, aber diese "spezielle Form" wird nie verwendet. Außerdem landet es ohnehin in einer schnellen C-Routine (https://github.com/numpy/numpy/blob/d1987d11dfe5101d3c0b12fecaae05570f361d44/numpy/core/src/multiarray/item_selection.c#L1911). Ich habe herausgefunden, dass das Inlining dieser Routinen in Cython * helfen kann, indem Sie Overhead entfernen, aber nicht immer * viel *. Veedrac vor 5 Jahren 0
Schauen Sie auch [in this] (https://github.com/numpy/numpy/blob/db198d5a3d31374985a24d3c44c88c356d0b3a3e/numpy/core/src/npysort/binsearch.c.src#L39) nach, wenn Sie dies wünschen. Ich vermute nicht, auch wenn ich keine Monate zu spät gekommen bin. Veedrac vor 5 Jahren 0
3
Matt

When you deal with performance in cython, I would suggest using the --annotate flag (or use IPython with cython magic that allow you quick iteration with anotate flag too), it will tell you which part of your code may be slow. It generates an Html report with highlighted lines. The more yellow, potentially the slower. You can also click on the line o see the generated C, and generally you just call out into Python world from C when things get slow, like checking array bounds, negative indexing, catching exceptions... So, you might want to use the following decorators on your functions if you know you won't have out of bounds errors, or negative indexing from the end :

@cython.boundscheck(False)
@cython.wraparound(False)

Keep in mind that it you do have out of bounds, you will segfault.

This memview bench might give you ideas.

Yo might also want to look at numpy view, if you like to avoid copy and know things won't be muted (but I think it's the default now)

Die Memview Bank ist sehr hilfreich. Verstehe jetzt ein bisschen besser. Die Option -a ist auch hilfreich. Die Überprüfung auf Nullteilung erzeugt zwar einen gewissen C-Code, die Geschwindigkeit ist jedoch unwesentlich. Wenn ich mit numpy-Typen arbeite, sollte ich smth wie np.float_t statt double verwenden? Ich brauche einen Line-Profiler, der mit Cython arbeitet! embert vor 6 Jahren 0
Mein Wissen über Cython geht nicht so weit. Es scheint mir jedoch, dass die Unterteilung in C im Allgemeinen viel langsamer ist als die Multiplikation. Schnelle Suche gebe mir [@cython.cdivision (True) `Dekorateur] (http://docs.cython.org/src/reference/compilation.html). Wenn Sie sich den Def für `n_` ansehen, würde ich sagen, dass Sie die Logik vereinfachen können (` x + trues-trues = x`). Matt vor 6 Jahren 0