Jak používat funkci NumPy argmax() v Pythonu

V tomto tutoriálu se naučíte, jak používat funkci NumPy argmax() k nalezení indexu maximálního prvku v polích.

NumPy je výkonná knihovna pro vědecké výpočty v Pythonu; poskytuje N-rozměrná pole, která jsou výkonnější než seznamy Pythonu. Jednou z běžných operací, které budete provádět při práci s poli NumPy, je nalezení maximální hodnoty v poli. Někdy však můžete chtít najít index, na kterém se vyskytuje maximální hodnota.

Funkce argmax() vám pomůže najít index maxima v jednorozměrných i vícerozměrných polích. Pojďme se naučit, jak to funguje.

Jak najít index maximálního prvku v poli NumPy

Abyste mohli pokračovat v tomto tutoriálu, musíte mít nainstalované Python a NumPy. Kódovat můžete spuštěním Python REPL nebo spuštěním poznámkového bloku Jupyter.

Nejprve importujme NumPy pod obvyklým aliasem np.

import numpy as np

K získání maximální hodnoty v poli (volitelně podél určité osy) můžete použít funkci NumPy max().

array_1 = np.array([1,5,7,2,10,9,8,4])
print(np.max(array_1))

# Output
10

V tomto případě np.max(pole_1) vrátí 10, což je správné.

Předpokládejme, že byste chtěli najít index, při kterém se v poli vyskytuje maximální hodnota. Můžete použít následující dvoustupňový přístup:

  • Najděte maximální prvek.
  • Najděte index maximálního prvku.
  • V poli_1 se maximální hodnota 10 vyskytuje na indexu 4 po nulovém indexování. První prvek je na indexu 0; druhý prvek je na indexu 1 a tak dále.

    Chcete-li najít index, na kterém se vyskytuje maximum, můžete použít funkci NumPy where(). np.where(condition) vrací pole všech indexů, kde je podmínka True.

    Budete muset klepnout na pole a získat přístup k položce na prvním indexu. Abychom zjistili, kde se vyskytuje maximální hodnota, nastavíme podmínku na pole_1==10; připomeňme, že 10 je maximální hodnota v poli_1.

    print(int(np.where(array_1==10)[0]))
    
    # Output
    4

    Použili jsme np.where() pouze s podmínkou, ale toto není doporučená metoda pro použití této funkce.

      Budou vaše oblíbené aplikace pro iPhone fungovat na iPadu?

    📑 Poznámka: Funkce NumPy where():
    np.where(condition,x,y) vrací:

    – Prvky z x, když je podmínka True, a
    – Prvky z y, když je podmínka False.

    Zřetězením funkcí np.max() a np.where() tedy můžeme najít maximální prvek, za kterým následuje index, na kterém se vyskytuje.

    Místo výše uvedeného dvoufázového procesu můžete použít funkci NumPy argmax() k získání indexu maximálního prvku v poli.

    Syntaxe funkce NumPy argmax().

    Obecná syntaxe pro použití funkce NumPy argmax() je následující:

    np.argmax(array,axis,out)
    # we've imported numpy under the alias np

    Ve výše uvedené syntaxi:

    • pole je jakékoli platné pole NumPy.
    • osa je volitelný parametr. Při práci s vícerozměrnými poli můžete použít parametr axis k nalezení indexu maxima podél konkrétní osy.
    • out je další volitelný parametr. Parametr out můžete nastavit na pole NumPy pro uložení výstupu funkce argmax().

    Poznámka: Od verze NumPy 1.22.0 je zde další parametr keepdims. Když ve volání funkce argmax() zadáme parametr axis, pole se zmenší podél této osy. Ale nastavení parametru keepdims na hodnotu True zajistí, že vrácený výstup bude mít stejný tvar jako vstupní pole.

    Použití NumPy argmax() k nalezení indexu maximálního prvku

    #1. Použijme funkci NumPy argmax() k nalezení indexu maximálního prvku v poli_1.

    array_1 = np.array([1,5,7,2,10,9,8,4])
    print(np.argmax(array_1))
    
    # Output
    4

    Funkce argmax() vrací 4, což je správně! ✅

    #2. Pokud předefinujeme pole_1 tak, že 10 se vyskytuje dvakrát, funkce argmax() vrátí pouze index prvního výskytu.

    array_1 = np.array([1,5,7,2,10,10,8,4])
    print(np.argmax(array_1))
    
    # Output
    4

    Pro zbytek příkladů použijeme prvky pole_1, které jsme definovali v příkladu č. 1.

    Použití NumPy argmax() k nalezení indexu maximálního prvku ve 2D poli

    Pojďme přetvořit pole NumPy array_1 na dvourozměrné pole se dvěma řádky a čtyřmi sloupci.

    array_2 = array_1.reshape(2,4)
    print(array_2)
    
    # Output
    [[ 1  5  7  2]
     [10  9  8  4]]

    U dvourozměrného pole osa 0 označuje řádky a osa 1 označuje sloupce. Pole NumPy následují nulové indexování. Indexy řádků a sloupců pro pole NumPy array_2 jsou tedy následující:

    Nyní zavolejte funkci argmax() na dvourozměrném poli pole_2.

    print(np.argmax(array_2))
    
    # Output
    4

    I když jsme na dvourozměrném poli zavolali argmax(), stále vrací 4. Toto je identické s výstupem pro jednorozměrné pole, pole_1 z předchozí části.

      Jak zvýrazňovat a komentovat soubory PDF na iPadu

    Proč se to děje? 🤔

    Důvodem je, že jsme nezadali žádnou hodnotu pro parametr osy. Když tento parametr osy není nastaven, funkce argmax() ve výchozím nastavení vrací index maximálního prvku podél sloučeného pole.

    Co je zploštělé pole? Pokud existuje N-rozměrné pole tvaru d1 x d2 x … x dN, kde d1, d2, až dN jsou velikosti pole podél N rozměrů, pak zploštělé pole je dlouhé jednorozměrné pole velikosti d1 * d2 * … * dN.

    Chcete-li zkontrolovat, jak vypadá sloučené pole pro pole_2, můžete zavolat metodu flatten(), jak je uvedeno níže:

    array_2.flatten()
    
    # Output
    array([ 1,  5,  7,  2, 10,  9,  8,  4])

    Index maximálního prvku podél řádků (osa = 0)

    Pokračujme v hledání indexu maximálního prvku podél řádků (osa = 0).

    np.argmax(array_2,axis=0)
    
    # Output
    array([1, 1, 1, 1])

    Tento výstup může být trochu obtížné pochopit, ale pochopíme, jak to funguje.

    Parametr axis jsme nastavili na nulu (osa = 0), protože bychom chtěli najít index maximálního prvku podél řádků. Proto funkce argmax() vrací číslo řádku, ve kterém se vyskytuje prvek maxima – pro každý ze tří sloupců.

    Pojďme si to pro lepší pochopení představit.

    Z výše uvedeného diagramu a výstupu argmax() máme následující:

    • Pro první sloupec na indexu 0 se maximální hodnota 10 vyskytuje ve druhém řádku, na indexu = 1.
    • Pro druhý sloupec na indexu 1 se maximální hodnota 9 vyskytuje ve druhém řádku, na indexu = 1.
    • Pro třetí a čtvrtý sloupec u indexu 2 a 3 se maximální hodnoty 8 a 4 vyskytují ve druhém řádku, u indexu = 1.

    To je přesně důvod, proč máme výstupní pole ([1, 1, 1, 1]), protože maximální prvek podél řádků se vyskytuje ve druhém řádku (pro všechny sloupce).

    Index maximálního prvku podél sloupců (osa = 1)

    Dále pomocí funkce argmax() najdeme index maximálního prvku podél sloupců.

    Spusťte následující fragment kódu a sledujte výstup.

    np.argmax(array_2,axis=1)
    array([2, 0])

    Můžete analyzovat výstup?

    Nastavili jsme axis = 1 pro výpočet indexu maximálního prvku podél sloupců.

    Funkce argmax() vrací pro každý řádek číslo sloupce, ve kterém se vyskytuje maximální hodnota.

      Jak převést soubor CSV na VCF pro přenos kontaktů

    Zde je vizuální vysvětlení:

    Z výše uvedeného diagramu a výstupu argmax() máme následující:

    • Pro první řádek na indexu 0 se maximální hodnota 7 vyskytuje ve třetím sloupci, na indexu = 2.
    • Pro druhý řádek na indexu 1 se maximální hodnota 10 vyskytuje v prvním sloupci, na indexu = 0.

    Doufám, že nyní rozumíte výstupu, pole([2, 0]) znamená.

    Použití parametru volitelného výstupu v NumPy argmax()

    Můžete použít volitelný parametr out the ve funkci NumPy argmax() k uložení výstupu do pole NumPy.

    Inicializujeme pole nul pro uložení výstupu předchozího volání funkce argmax() – abychom našli index maxima podél sloupců (osa=1).

    out_arr = np.zeros((2,))
    print(out_arr)
    [0. 0.]

    Nyní se vraťme k příkladu nalezení indexu maximálního prvku podél sloupců (osa = 1) a nastavíme out na out_arr, který jsme definovali výše.

    np.argmax(array_2,axis=1,out=out_arr)

    Vidíme, že interpret Pythonu vyvolá TypeError, protože out_arr byl ve výchozím nastavení inicializován na pole floats.

    ---------------------------------------------------------------------------
    TypeError                                 Traceback (most recent call last)
    /usr/local/lib/python3.7/dist-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds)
         56     try:
    ---> 57         return bound(*args, **kwds)
         58     except TypeError:
    
    TypeError: Cannot cast array data from dtype('float64') to dtype('int64') according to the rule 'safe'

    Proto při nastavování parametru out pro výstupní pole je důležité zajistit, aby výstupní pole mělo správný tvar a datový typ. Protože indexy pole jsou vždy celá čísla, měli bychom při definování výstupního pole nastavit parametr dtype na int.

    out_arr = np.zeros((2,),dtype=int)
    print(out_arr)
    
    # Output
    [0 0]

    Nyní můžeme pokračovat a zavolat funkci argmax() s parametry axis i out a tentokrát běží bez chyby.

    np.argmax(array_2,axis=1,out=out_arr)

    Výstup funkce argmax() je nyní přístupný v poli out_arr.

    print(out_arr)
    # Output
    [2 0]

    Závěr

    Doufám, že vám tento tutoriál pomohl pochopit, jak používat funkci NumPy argmax(). Příklady kódu můžete spustit v poznámkovém bloku Jupyter.

    Zopakujme si, co jsme se naučili.

    • Funkce NumPy argmax() vrací index maximálního prvku v poli. Pokud se maximální prvek vyskytuje v poli a více než jednou, pak np.argmax(a) vrátí index prvního výskytu prvku.
    • Při práci s vícerozměrnými poli můžete použít volitelný parametr axis k získání indexu maximálního prvku podél konkrétní osy. Například ve dvourozměrném poli: nastavením osy = 0 a osy = 1 můžete získat index maximálního prvku podél řádků a sloupců.
    • Pokud byste chtěli uložit vrácenou hodnotu do jiného pole, můžete nastavit volitelný parametr out na výstupní pole. Výstupní pole by však mělo mít kompatibilní tvar a datový typ.

    Dále se podívejte na podrobnou příručku o sadách Pythonu.