Files
1ereNSI/knn/pokemons_knn.py

129 lines
3.9 KiB
Python

# Imports regroupés en haut du fichier (PEP 8)
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# Chargement des données
pokemons = pd.read_csv('pokemons.csv')
# Afficher les premières lignes pour vérification
print("=== Aperçu des données ===")
print(pokemons.head())
# Restructuration des données pour avoir une colonne "type" explicite
def restructurer_donnees(df):
"""
Transforme le DataFrame pour avoir des colonnes : nom, points_de_vie, attaque, type
"""
donnees = []
for _, row in df.iterrows():
if pd.notna(row['Pokemons de type Eau']):
donnees.append({
'nom': row['nom'],
'points_de_vie': row['points de vie'],
'attaque': row['Pokemons de type Eau'],
'type': 'Eau'
})
elif pd.notna(row['Pokemons de type Psy']):
donnees.append({
'nom': row['nom'],
'points_de_vie': row['points de vie'],
'attaque': row['Pokemons de type Psy'],
'type': 'Psy'
})
return pd.DataFrame(donnees)
pokemons_clean = restructurer_donnees(pokemons)
print("\n=== Données restructurées ===")
print(pokemons_clean.head(10))
# Visualisation des données
def afficher_graphique(df, pokemon_mystere=None):
"""
Affiche un nuage de points des Pokémon par type
"""
plt.figure(figsize=(10, 6))
# Pokémon de type Eau
eau = df[df['type'] == 'Eau']
plt.scatter(eau['points_de_vie'], eau['attaque'], color='blue', label='Type Eau', s=100)
# Pokémon de type Psy
psy = df[df['type'] == 'Psy']
plt.scatter(psy['points_de_vie'], psy['attaque'], color='purple', label='Type Psy', s=100)
# Pokémon mystère si fourni
if pokemon_mystere:
plt.scatter(pokemon_mystere[0], pokemon_mystere[1], color='red', marker='X', s=200, label='Mystère')
plt.title('Classification des Pokémon par type')
plt.xlabel('Points de vie')
plt.ylabel('Attaque')
plt.legend()
plt.grid(True)
plt.show()
# Fonction de calcul de distance euclidienne
def calculer_distance(pokemon1, pokemon2):
"""
Calcule la distance euclidienne entre deux Pokémon
pokemon1 et pokemon2 sont des listes [points_de_vie, attaque]
"""
return np.sqrt(np.sum(np.square(np.array(pokemon1) - np.array(pokemon2))))
# Algorithme KNN complet
def knn(df, pokemon_mystere, k=5):
"""
Algorithme des k plus proches voisins
Paramètres:
df : DataFrame contenant les Pokémon avec colonnes points_de_vie, attaque, type
pokemon_mystere : liste [points_de_vie, attaque] du Pokémon à classifier
k : nombre de voisins à considérer
Retourne:
Le type prédit pour le Pokémon mystère
"""
# Calcul des distances pour chaque Pokémon
df_copy = df.copy()
df_copy['distance'] = df_copy.apply(
lambda row: calculer_distance([row['points_de_vie'], row['attaque']], pokemon_mystere),
axis=1
)
# Tri par distance croissante et sélection des k premiers
k_voisins = df_copy.sort_values(by='distance').head(k)
print(f"\n=== Les {k} plus proches voisins ===")
print(k_voisins[['nom', 'type', 'distance']])
# Vote majoritaire
type_majoritaire = k_voisins['type'].value_counts().idxmax()
votes = k_voisins['type'].value_counts()
print(f"\n=== Votes ===")
print(votes)
return type_majoritaire
# Exemple d'utilisation
if __name__ == "__main__":
# Définir un Pokémon mystère (points_de_vie, attaque)
pokemon_mystere = [65, 40]
# Afficher le graphique avec le Pokémon mystère
afficher_graphique(pokemons_clean, pokemon_mystere)
# Appliquer l'algorithme KNN avec k=5
type_predit = knn(pokemons_clean, pokemon_mystere, k=5)
print(f"\n=== Résultat ===")
print(f"Le Pokémon mystère ({pokemon_mystere[0]} PV, {pokemon_mystere[1]} attaque) est probablement de type : {type_predit}")