top of page

Arbre de décision

Arbre de décision

Un arbre de décision est un algorithme de machine learning, qui permet d’obtenir une série de règles de décision. Cet algorithme peut être utilisé en classification supervisée, ou en régression. A la racine de l’arbre se trouve l’ensemble des observations, et les feuilles de l’arbre permettent d’obtenir les différentes prédictions possibles. Les arbres de décision sont utilisables même dans le cas de classes non-linéairement séparables.

 

Principe d’un arbre de décision


L’algorithme de création d’un arbre de décision cherche à créer des nœuds, qui permettent de diviser le jeu d’apprentissage en deux branches, à l’aide d’une règle de décision qui s’applique sur une seule caractéristique (colonne). L’objectif est de minimiser à chaque nœud la variance au sein des branches.


A la fin de l’exécution de l’algorithme, on atteint les feuilles de l’arbre.


Pour faire une prédiction, on descend l’arbre jusqu’à atteindre une feuille, et :

  • si on cherche à faire de la régression, on choisit comme prédiction la moyenne de la variable cible pour toutes les observations du jeu d’apprentissage classées dans cette feuille,

  • si on cherche à faire de la classification, on choisit comme prédiction la classe majoritaire parmi les observations du jeu d’apprentissage classées dans cette feuille.


Syntaxe Python


En Python, on peut facilement créer un modèle d'arbre de décision. Ci-dessous, le code permettant de créer l'arbre à deux étages représentés dans la suite de l'article :

Une fonction similaire existe pour la classification.


Représentation graphique


Grâce à la bibliothèque Scikit-learn de Python, on peut utiliser une fonction de visualisation des arbres de décision :

Visualisation du petit arbre de décision
Visualisation du petit arbre de décision

Dans cette visualisation, on voit, pour chaque noeud (représenté par un cadre) :

  • la règle de décision : pour le premier noeud, s'il est écrit x[0] <= 5.086, cela signifie que les observations pour lesquelles la colonne x[0] contient une valeur inférieure à 5.086, seront classés dans la branche True, et les autre dans la branche False,

  • la squared_error, qui représente la variance de la valeur immobilière médiane (variable cible) pour les observations qui passent par ce noeud, et qui est la quantité à minimiser,

  • le nombre d'observations de la base d'apprentissage qui passent par ce noeud,

  • et la prédiction que fera l'arbre de décision pour les observations nouvelles qui passeront par ce noeud (c'est la 'value').


Prétraitement des données


Comme les arbres de décision reposent sur le principe de seuils, la normalisation préalable des données est inutile.

 

Problème du sur-apprentissage 


Les arbres de décision présentent souvent du sur-apprentissage, à cause de leur structure très rigide et très dépendante des données d’apprentissage. On détecte cet éventuel sur-apprentissage en traçant la courbe d’apprentissage du modèle.

 

Pour éviter ce sur-apprentissage, deux solutions sont possibles : on peut régler les hyperparamètres max_depth et min_samples_leaf, ou utiliser un modèle en Random Forest.

  • Le paramètre max_depth règle le nombre d’étages maximal de l’arbre de décision, donc permet de limiter sa complexité. On cherchera à réduire la valeur de max_depth si on se trouve en sur-apprentissage.

  • Le paramètre min_samples_leaf indique le nombre d’observations minimal nécessaire pour créer une feuille. On cherchera à l’augmenter si on se trouve en sur-apprentissage.


Afin d’obtenir de meilleures performances et d’éviter les problèmes liés à la rigidité du modèle en arbre de décision, on peut utiliser un algorithme en Random Forest, qui crée et entraîne plusieurs arbres de décision en parallèle, sur des sous-ensembles aléatoirement choisis de la base d’apprentissage. La prédiction du modèle est ensuite, dans le cas d’une classification, la classe majoritairement prédite par les arbres, et, dans le cas d’une régression, la moyenne des valeurs prédites par les arbres. Le modèle Random Forest étant complexe, son exécution est cependant coûteuse en temps et en mémoire.



Voir aussi : all(), any(), append(), count(), enumerate(), extend(), filter(), float() format() input(), int(), isdigit(), isinstance(), items(), join(), endswith(), list(), map(), max(), mean(), min(), pop(), range(), len(), startswith(), zip(), type(), get(), symmetric_difference(), keys(), difference()


Numpy : arange(), array(), delete(), hsplit(), hstack(), linspace(), logical_and(), logical_or(), polyfit()


Pandas : concat(), concatenate(), describe(), dict(), drop_duplicates(), dropna(), fillna(), from_dict(), groupby(), head(), iloc, info(), insert(), isin(), melt(), merge(), pivot_table(), read_csv(), read_excel(), rename(), where()


Machine Learning : F1-Score, Précision, Rappel, Normalisation, Courbe d’apprentissage, Les résidus, Régression VS classification, Mean Absolute Error (MAE), Mean Squared Error (MSE), Root Mean Squared Error (RMSE), Accuracy, L’astuce du noyau, Bases d’apprentissage et de test, Classes linéairement séparables, Apprentissage supervisé VS non-supervisé, Coefficient de détermination R2, Validation croisée


N'hésitez pas à consulter nos formations sur cette page.

Glossaire pour apprendre à coder sur Python

Apprenez à coder sur Python

avec des experts

Notre organisme de formation spécialisé dans le langage Python et la Data Science forme les débutants et perfectionne les experts sur ce langage informatique. Pourquoi pas vous?

Des questions?

Contact Expert Python

Un formateur Python vous répond très vite

bottom of page