Jak používat funkci NumPy argmax() v Pythonu
V tomto návodu se naučíte, jak využít funkci NumPy argmax() k nalezení indexu největšího prvku v polích.
NumPy je klíčová knihovna pro vědecké výpočty v Pythonu. Nabízí N-rozměrná pole, která jsou efektivnější než standardní seznamy Pythonu. Běžnou operací při práci s poli NumPy je zjišťování maximální hodnoty. Někdy je však nutné zjistit i index, na kterém se tato maximální hodnota nachází.
Funkce argmax() vám pomůže nalézt index maxima jak v jednorozměrných, tak ve vícerozměrných polích. Pojďme se podívat, jak přesně to funguje.
Hledání indexu maximálního prvku v poli NumPy
Pro úspěšné absolvování tohoto návodu je třeba mít nainstalovaný Python a knihovnu NumPy. Kód můžete testovat buď v Python REPL, nebo v Jupyter notebooku.
Nejprve importujeme knihovnu NumPy pod běžně používaným aliasem np.
import numpy as np
K získání maximální hodnoty v poli, případně podél určité osy, můžete využít funkci NumPy max().
array_1 = np.array([1,5,7,2,10,9,8,4])
print(np.max(array_1))
# Výstup
10
V tomto případě np.max(array_1) vrátí hodnotu 10, což je správně.
Představme si, že potřebujete zjistit index, na kterém se maximální hodnota v poli vyskytuje. To lze provést pomocí následujících kroků:
- Zjištění maximálního prvku.
- Nalezení indexu tohoto maximálního prvku.
V poli array_1 se maximální hodnota 10 nachází na indexu 4 (při použití nulového indexování). První prvek má index 0, druhý index 1 a tak dále.
K nalezení indexu, kde se nachází maximum, lze použít funkci NumPy where(). Funkce np.where(podmínka) vrací pole všech indexů, kde je podmínka splněna.
Je třeba získat přístup k prvkům pole a extrahovat položku na prvním indexu. Pro zjištění, kde se maximální hodnota nachází, nastavíme podmínku na array_1==10. Nezapomeňte, že 10 je maximální hodnota v array_1.
print(int(np.where(array_1==10)[0]))
# Výstup
4
Funkci np.where() jsme použili pouze s podmínkou, ale toto není doporučený způsob jejího využití.
📝 Poznámka: Funkce NumPy where():np.where(podmínka, x, y) vrací:
- Prvky z
x, pokud je podmínka pravdivá (True) a - Prvky z
y, pokud je podmínka nepravdivá (False).
Zřetězením funkcí np.max() a np.where() můžeme tedy nalézt maximální prvek a následně i index, na kterém se nachází.
Namísto výše uvedeného dvoukrokového procesu je možné využít funkci NumPy argmax() k přímému získání indexu maximálního prvku v poli.
Syntaxe funkce NumPy argmax()
Obecná syntaxe pro použití funkce NumPy argmax() je následující:
np.argmax(array, axis, out)
# numpy jsme importovali pod aliasem np
Ve výše uvedené syntaxi:
arrayje libovolné platné NumPy pole.axisje volitelný parametr. Při práci s vícerozměrnými poli ho můžete použít pro nalezení indexu maxima podél konkrétní osy.outje další volitelný parametr. Můžete ho nastavit na NumPy pole pro uložení výstupu funkceargmax().
Poznámka: Od verze NumPy 1.22.0 existuje další parametr keepdims. Pokud ve volání funkce argmax() zadáme parametr axis, pole se zmenší podél této osy. Nastavením parametru keepdims na hodnotu True zajistíme, že vrácený výstup bude mít stejný tvar jako vstupní pole.
Použití NumPy argmax() k nalezení indexu maximálního prvku
#1. Použijme funkci NumPy argmax() k nalezení indexu maximálního prvku v array_1.
array_1 = np.array([1,5,7,2,10,9,8,4])
print(np.argmax(array_1))
# Výstup
4
Funkce argmax() vrátí 4, což je správně! ✅
#2. Pokud předefinujeme array_1 tak, že se hodnota 10 vyskytuje dvakrát, funkce argmax() vrátí pouze index prvního výskytu.
array_1 = np.array([1,5,7,2,10,10,8,4])
print(np.argmax(array_1))
# Výstup
4
Pro zbývající příklady použijeme prvky pole array_1, jak jsme ho definovali v příkladu č. 1.
Použití NumPy argmax() k nalezení indexu maximálního prvku ve 2D poli
Převeďme pole NumPy array_1 na dvourozměrné pole se dvěma řádky a čtyřmi sloupci.
array_2 = array_1.reshape(2,4)
print(array_2)
# Výstup
[[ 1 5 7 2]
[10 9 8 4]]
U dvourozměrného pole osa 0 označuje řádky a osa 1 označuje sloupce. Pole NumPy se indexují od nuly. Indexy řádků a sloupců pro pole array_2 jsou tedy následující:
Nyní zavolejme funkci argmax() na dvourozměrném poli array_2.
print(np.argmax(array_2))
# Výstup
4
I když jsme zavolali argmax() na dvourozměrném poli, stále se vrací 4. To je stejný výstup jako u jednorozměrného pole array_1 z předchozí části.
Proč se to děje? 🤔
Důvodem je, že jsme nezadali žádnou hodnotu pro parametr axis. Pokud tento parametr není nastaven, funkce argmax() ve výchozím nastavení vrací index maximálního prvku v rámci sloučeného pole.
Co je sloučené pole? Pokud existuje N-rozměrné pole o tvaru d1 x d2 x … x dN, kde d1, d2 až dN jsou velikosti pole podél N rozměrů, pak sloučené pole je jednorozměrné pole o velikosti d1 * d2 * … * dN.
Chcete-li zkontrolovat, jak vypadá sloučené pole pro array_2, můžete zavolat metodu flatten(), jak je uvedeno níže:
array_2.flatten()
# Výstup
array([ 1, 5, 7, 2, 10, 9, 8, 4])
Index maximálního prvku podél řádků (osa = 0)
Pokračujme hledáním indexu maximálního prvku podél řádků (axis = 0).
np.argmax(array_2, axis=0)
# Výstup
array([1, 1, 1, 1])
Tento výstup může být trochu obtížný na pochopení, ale vysvětlíme, jak funguje.
Parametr axis jsme nastavili na nulu (axis = 0), protože chceme nalézt index maximálního prvku podél řádků. Funkce argmax() proto vrací číslo řádku, ve kterém se vyskytuje prvek s maximální hodnotou – pro každý ze sloupců.
Pro lepší pochopení si to znázorníme.

Z výše uvedeného diagramu a výstupu funkce argmax() vyplývá:
- Pro první sloupec na indexu 0 se maximální hodnota 10 vyskytuje ve druhém řádku, na indexu = 1.
- Pro druhý sloupec na indexu 1 se maximální hodnota 9 vyskytuje ve druhém řádku, na indexu = 1.
- Pro třetí a čtvrtý sloupec na indexech 2 a 3 se maximální hodnoty 8 a 4 vyskytují ve druhém řádku, na indexu = 1.
Proto máme výstupní pole ([1, 1, 1, 1]), protože maximální prvek podél řádků se vyskytuje ve druhém řádku (pro všechny sloupce).
Index maximálního prvku podél sloupců (osa = 1)
Dále pomocí funkce argmax() nalezneme index maximálního prvku podél sloupců.
Spusťte následující kód a sledujte výstup.
np.argmax(array_2, axis=1)
array([2, 0])
Dokážete analyzovat tento výstup?
Nastavili jsme axis = 1, abychom vypočítali index maximálního prvku podél sloupců.
Funkce argmax() vrací pro každý řádek číslo sloupce, ve kterém se nachází maximální hodnota.
Zde je vizuální vysvětlení:

Z výše uvedeného diagramu a výstupu funkce argmax() vyplývá:
- Pro první řádek na indexu 0 se maximální hodnota 7 vyskytuje ve třetím sloupci, na indexu = 2.
- Pro druhý řádek na indexu 1 se maximální hodnota 10 vyskytuje v prvním sloupci, na indexu = 0.
Doufám, že nyní rozumíte výstupu array([2, 0]).
Použití volitelného parametru out ve funkci NumPy argmax()
Můžete použít volitelný parametr out ve funkci NumPy argmax() k uložení výstupu do NumPy pole.
Inicializujme pole nul pro uložení výstupu předchozího volání funkce argmax() – tedy pro nalezení indexu maxima podél sloupců (axis=1).
out_arr = np.zeros((2,))
print(out_arr)
[0. 0.]
Vraťme se nyní k příkladu s hledáním indexu maximálního prvku podél sloupců (axis = 1) a nastavme out na out_arr, které jsme definovali výše.
np.argmax(array_2, axis=1, out=out_arr)
Vidíme, že interpret Pythonu vyvolá TypeError, protože out_arr byl ve výchozím nastavení inicializován jako pole s čísly s plovoucí desetinnou čárkou (floats).
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds)
56 try:
---> 57 return bound(*args, **kwds)
58 except TypeError:
TypeError: Cannot cast array data from dtype('float64') to dtype('int64') according to the rule 'safe'
Proto je při nastavování parametru out pro výstupní pole důležité zajistit, aby mělo výstupní pole správný tvar a datový typ. Protože indexy pole jsou vždy celá čísla, měli bychom při definování výstupního pole nastavit parametr dtype na int.
out_arr = np.zeros((2,), dtype=int)
print(out_arr)
# Výstup
[0 0]
Nyní můžeme pokračovat a zavolat funkci argmax() s parametry axis i out, tentokrát to poběží bez chyby.
np.argmax(array_2, axis=1, out=out_arr)
Výstup funkce argmax() je nyní uložen v poli out_arr.
print(out_arr)
# Výstup
[2 0]
Závěr
Doufám, že vám tento návod pomohl pochopit, jak používat funkci NumPy argmax(). Kód si můžete vyzkoušet v Jupyter notebooku.
Zopakujme si, co jsme se naučili:
- Funkce
NumPy argmax()vrací index maximálního prvku v poli. Pokud se maximální prvek v poli vyskytuje vícekrát,np.argmax(a)vrátí index jeho prvního výskytu. - Při práci s vícerozměrnými poli můžete použít volitelný parametr
axisk získání indexu maximálního prvku podél konkrétní osy. Například ve dvourozměrném poli: nastavenímaxis = 0aaxis = 1můžete získat index maximálního prvku podél řádků a sloupců. - Pokud chcete uložit vrácenou hodnotu do jiného pole, můžete nastavit volitelný parametr
outna výstupní pole. Výstupní pole by však mělo mít kompatibilní tvar a datový typ.
Pro další informace se můžete podívat na podrobný průvodce o množinách v Pythonu.