Uczenie maszynowe i sztuczne sieci neuronowe/DrzewaDecyzyjne cw: Różnice pomiędzy wersjami
(→Wstęp) |
|||
Linia 1: | Linia 1: | ||
=Wstęp= | =Wstęp= | ||
− | W bibliotece scikit-learn drzewa | + | W bibliotece scikit-learn drzewa decyzyjne implementowane są przez klasę [http://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html#sklearn.tree.DecisionTreeClassifier DecisionTreeClassifier] |
Aby nauczyć taki klasyfikator potrzebujemy tablicę X o rozmiarach [N_przykładów, N_cech] i wektor Y określający przynależność przykładów w X do klas. | Aby nauczyć taki klasyfikator potrzebujemy tablicę X o rozmiarach [N_przykładów, N_cech] i wektor Y określający przynależność przykładów w X do klas. |
Wersja z 12:17, 4 sty 2016
Wstęp
W bibliotece scikit-learn drzewa decyzyjne implementowane są przez klasę DecisionTreeClassifier
Aby nauczyć taki klasyfikator potrzebujemy tablicę X o rozmiarach [N_przykładów, N_cech] i wektor Y określający przynależność przykładów w X do klas.
Najprostszy przykład:
from sklearn import tree
X = [[0, 0], [1, 1]]
Y = [0, 1]
clf = tree.DecisionTreeClassifier()
clf = clf.fit(X, Y)
Po dopasowaniu można przewidywać przynależność nowych przykładów:
clf.predict([[2., 2.]])
Albo estymować prawdopodobieństwo przynależności do klas:
clf.predict_proba([[2., 2.]])
Przykład zbiór danych Iris
from sklearn.datasets import load_iris
from sklearn import tree
iris = load_iris()
clf = tree.DecisionTreeClassifier()
clf = clf.fit(iris.data, iris.target)
Po wytrenowaniu można zilustrować wynik za pomocą narzedzia Graphiz (wymaga to zainstalowania w systemie tego narzędzia), oraz doinstalowania do pythona biblioteki pydot:
from sklearn.externals.six import StringIO
import pydot
dot_data = StringIO()
tree.export_graphviz(clf, out_file=dot_data)
graph = pydot.graph_from_dot_data(dot_data.getvalue())
graph.write_pdf("iris.pdf")
Można też podejrzeć wyniki w pythonie:
from IPython.display import Image
dot_data = StringIO()
tree.export_graphviz(clf, out_file=dot_data, feature_names=iris.feature_names,
class_names=iris.target_names,
filled=True, rounded=True,
special_characters=True)
graph = pydot.graph_from_dot_data(dot_data.getvalue())
Image(graph.create_png())
Po dopasowaniu model ten może być zastosowany do przewidywania przynależności przykładów do klas:
clf.predict(iris.data[:1, :])
lub estymowania prawdopodobieństwa przynależności do klas:
clf.predict_proba(iris.data[:1, :])
l