Како користити функцију НумПи аргмак() у Питхон-у

U ovom uputstvu, naučićete kako da koristite funkciju NumPy argmax() da pronađete indeks maksimalnog elementa u nizovima.

NumPy je moćna biblioteka za naučno računanje u Pajtonu; pruža N-dimenzionalne nizove koji su efikasniji od standardnih Pajtonovih listi. Jedna od uobičajenih operacija koje ćete izvoditi kada radite sa NumPy nizovima jeste pronalaženje maksimalne vrednosti u nizu. Međutim, ponekad ćete možda želeti da pronađete indeks na kojem se javlja maksimalna vrednost.

Funkcija argmax() vam pomaže da pronađete indeks maksimuma u jednodimenzionalnim i višedimenzionalnim nizovima. Nastavimo da naučimo kako to funkcioniše.

Kako pronaći indeks maksimalnog elementa u NumPy nizu

Da biste pratili ovo uputstvo, potrebno je da imate instalirane Pajton i NumPy. Možete kodirati tako što ćete pokrenuti Pajton REPL ili koristiti Jupyter notebook.

Prvo, hajde da uvezemo NumPy pod uobičajenim aliasom np.

import numpy as np

Možete koristiti funkciju NumPy max() da biste dobili maksimalnu vrednost u nizu (opciono duž određene ose).

array_1 = np.array([1,5,7,2,10,9,8,4])
print(np.max(array_1))

# Output
10

U ovom slučaju, np.max(array_1) vraća 10, što je ispravno.

Pretpostavimo da želite da pronađete indeks na kojem se javlja maksimalna vrednost u nizu. Možete koristiti sledeći pristup u dva koraka:

  • Pronađite maksimalni element.
  • Pronađite indeks maksimalnog elementa.
  • U nizu array_1, maksimalna vrednost od 10 se javlja na indeksu 4, nakon indeksiranja od nule. Prvi element je na indeksu 0; drugi element je na indeksu 1, i tako dalje.

    Da biste pronašli indeks na kome se javlja maksimum, možete koristiti funkciju NumPy where(). np.where(condition) vraća niz svih indeksa gde je uslov Tačan.

    Moraćete da pristupite elementu na prvom indeksu. Da bismo pronašli gde se javlja maksimalna vrednost, postavljamo uslov array_1==10; podsetimo se da je 10 maksimalna vrednost u nizu array_1.

    print(int(np.where(array_1==10)[0]))
    
    # Output
    4

    Koristili smo np.where() samo sa uslovom, ali ovo nije preporučeni metod za korišćenje ove funkcije.

    📑 Napomena: NumPy where() funkcija:
    np.where(uslov,x,y) vraća:

    – Elemente iz x kada je uslov Tačan, i
    – Elemente iz y kada je uslov Netačan.

    Stoga, kombinovanjem funkcija np.max() i np.where() možemo pronaći maksimalni element, praćen indeksom na kojem se pojavljuje.

    Umesto prethodnog procesa u dva koraka, možete koristiti funkciju NumPy argmax() da biste dobili indeks maksimalnog elementa u nizu.

    Sintaksa funkcije NumPy argmax()

    Opšta sintaksa za korišćenje funkcije NumPy argmax() je sledeća:

    np.argmax(array,axis,out)
    # uvezli smo numpy pod aliasom np

    U gornjoj sintaksi:

    • array je bilo koji važeći NumPy niz.
    • axis je opcioni parametar. Kada radite sa višedimenzionalnim nizovima, možete koristiti parametar ose da biste pronašli indeks maksimuma duž određene ose.
    • out je još jedan opcioni parametar. Parametar out možete postaviti na NumPy niz za skladištenje izlaza funkcije argmax().

    Napomena: Od NumPy verzije 1.22.0 postoji dodatni parametar keepdims. Kada navedemo parametar ose u pozivu funkcije argmax(), niz se smanjuje duž te ose. Ali postavljanje parametra keepdims na True osigurava da vraćeni izlaz bude istog oblika kao i ulazni niz.

    Korišćenje NumPy argmax() za pronalaženje indeksa maksimalnog elementa

    #1. Hajde da upotrebimo funkciju NumPy argmax() da pronađemo indeks maksimalnog elementa u array_1.

    array_1 = np.array([1,5,7,2,10,9,8,4])
    print(np.argmax(array_1))
    
    # Output
    4

    Funkcija argmax() vraća 4, što je ispravno! ✅

    #2. Ako redefinišemo array_1 tako da se 10 pojavljuje dva puta, funkcija argmax() vraća samo indeks prvog pojavljivanja.

    array_1 = np.array([1,5,7,2,10,10,8,4])
    print(np.argmax(array_1))
    
    # Output
    4

    Za ostale primere, koristićemo elemente niza array_1 koje smo definisali u primeru #1.

    Korišćenje NumPy argmax() za pronalaženje indeksa maksimalnog elementa u 2D nizu

    Hajde da preoblikujemo NumPy niz array_1 u dvodimenzionalni niz sa dva reda i četiri kolone.

    array_2 = array_1.reshape(2,4)
    print(array_2)
    
    # Output
    [[ 1  5  7  2]
     [10  9  8  4]]

    Za dvodimenzionalni niz, osa 0 označava redove, a osa 1 označava kolone. NumPy nizovi prate nulto indeksiranje. Dakle, indeksi redova i kolona za NumPy niz array_2 su sledeći:

    Sada, pozovimo funkciju argmax() na dvodimenzionalnom nizu, array_2.

    print(np.argmax(array_2))
    
    # Output
    4

    Iako smo pozvali argmax() za dvodimenzionalni niz, on i dalje vraća 4. Ovo je identično izlazu za jednodimenzionalni niz, array_1 iz prethodnog odeljka.

    Zašto se to dešava? 🤔

    To je zato što nismo naveli nijednu vrednost za parametar ose. Kada ovaj parametar ose nije podešen, funkcija argmax() podrazumevano vraća indeks maksimalnog elementa duž spljoštenog niza.

    Šta je spljošteni niz? Ako postoji N-dimenzionalni niz oblika d1 x d2 x … x dN, gde su d1, d2, do dN veličine niza duž N dimenzija, onda je spljošteni niz dugačak jednodimenzionalni niz veličine d1 * d2 * … * dN.

    Da biste proverili kako spljošteni niz izgleda za array_2, možete pozvati metod flatten(), kao što je prikazano u nastavku:

    array_2.flatten()
    
    # Output
    array([ 1,  5,  7,  2, 10,  9,  8,  4])

    Indeks maksimalnog elementa duž redova (osa = 0)

    Nastavimo da pronađemo indeks maksimalnog elementa duž redova (osa = 0).

    np.argmax(array_2,axis=0)
    
    # Output
    array([1, 1, 1, 1])

    Ovaj izlaz može biti malo teži za razumevanje, ali razumećemo kako funkcioniše.

    Postavili smo parametar axis na nulu (axis = 0), jer bismo želeli da pronađemo indeks maksimalnog elementa duž redova. Stoga, funkcija argmax() vraća broj reda u kojem se pojavljuje maksimalni element—za svaku od četiri kolone.

    Hajde da vizualizujemo ovo radi boljeg razumevanja.

    Iz gornjeg dijagrama i argmax() izlaza, imamo sledeće:

    • Za prvu kolonu sa indeksom 0, maksimalna vrednost 10 se javlja u drugom redu, na indeksu = 1.
    • Za drugu kolonu sa indeksom 1, maksimalna vrednost 9 se javlja u drugom redu, na indeksu = 1.
    • Za treću i četvrtu kolonu sa indeksima 2 i 3, maksimalne vrednosti 8 i 4 se javljaju u drugom redu, na indeksu = 1.

    Upravo zbog toga imamo izlazni niz ([1, 1, 1, 1]) jer se maksimalni element duž redova javlja u drugom redu (za sve kolone).

    Indeks maksimalnog elementa duž kolona (osa = 1)

    Zatim, upotrebimo funkciju argmax() da pronađemo indeks maksimalnog elementa duž kolona.

    Pokrenite sledeći isečak koda i posmatrajte izlaz.

    np.argmax(array_2,axis=1)
    array([2, 0])

    Možete li raščlaniti izlaz?

    Postavili smo axis = 1 da bismo izračunali indeks maksimalnog elementa duž kolona.

    Funkcija argmax() vraća, za svaki red, broj kolone u kojoj se javlja maksimalna vrednost.

    Evo vizuelnog objašnjenja:

    Iz gornjeg dijagrama i argmax() izlaza, imamo sledeće:

    • Za prvi red sa indeksom 0, maksimalna vrednost 7 se javlja u trećoj koloni, na indeksu = 2.
    • Za drugi red sa indeksom 1, maksimalna vrednost 10 se javlja u prvoj koloni, na indeksu = 0.

    Nadam se da sada razumete šta je izlaz, niz ([2, 0]) znači.

    Korišćenje opcionog izlaznog parametra u NumPy argmax()

    Možete koristiti opcioni parametar out u funkciji NumPy argmax() da biste sačuvali izlaz u NumPy nizu.

    Hajde da inicijalizujemo niz nula za čuvanje izlaza prethodnog poziva funkcije argmax() – da pronađemo indeks maksimuma duž kolona (osa=1).

    out_arr = np.zeros((2,))
    print(out_arr)
    [0. 0.]

    Sada, hajde da ponovo pogledamo primer pronalaženja indeksa maksimalnog elementa duž kolona (osa = 1) i postavimo out na out_arr koji smo definisali gore.

    np.argmax(array_2,axis=1,out=out_arr)

    Vidimo da Pajton interpreter izbacuje TypeError, pošto je out_arr podrazumevano inicijalizovan na niz float-a.

    TypeError                                 Traceback (most recent call last)
    /usr/local/lib/python3.7/dist-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds)
         56     try:
    ---> 57         return bound(*args, **kwds)
         58     except TypeError:
    
    TypeError: Cannot cast array data from dtype('float64') to dtype('int64') according to the rule 'safe'

    Stoga, kada postavljate izlazni parametar na izlazni niz, važno je osigurati da je izlazni niz ispravnog oblika i tipa podataka. Pošto su indeksi niza uvek celi brojevi, trebalo bi da postavimo parametar dtype na int kada definišemo izlazni niz.

    out_arr = np.zeros((2,),dtype=int)
    print(out_arr)
    
    # Output
    [0 0]

    Sada možemo da nastavimo i da pozovemo funkciju argmax() sa parametrima ose i izlaza, i ovaj put ona radi bez greške.

    np.argmax(array_2,axis=1,out=out_arr)

    Izlazu funkcije argmax() sada se može pristupiti u nizu out_arr.

    print(out_arr)
    # Output
    [2 0]

    Zaključak

    Nadam se da vam je ovo uputstvo pomoglo da razumete kako da koristite funkciju NumPy argmax(). Primere koda možete pokrenuti u Jupyter notebook-u.

    Hajde da pregledamo šta smo naučili.

    • Funkcija NumPy argmax() vraća indeks maksimalnog elementa u nizu. Ako se maksimalni element pojavljuje više puta u nizu a, onda np.argmax(a) vraća indeks prvog pojavljivanja elementa.
    • Kada radite sa višedimenzionalnim nizovima, možete koristiti opcioni parametar ose da biste dobili indeks maksimalnog elementa duž određene ose. Na primer, u dvodimenzionalnom nizu: postavljanjem axis = 0 i axis = 1, možete dobiti indeks maksimalnog elementa duž redova i kolona, respektivno.
    • Ako želite da sačuvate vraćenu vrednost u drugom nizu, možete da podesite opcioni izlazni parametar na izlazni niz. Međutim, izlazni niz treba da bude kompatibilnog oblika i tipa podataka.

    Zatim pogledajte detaljan vodič o Pajton setovima.