Uczenie maszynowe i sztuczne sieci neuronowe/DrzewaDecyzyjne cw: Różnice pomiędzy wersjami
Z Brain-wiki
Linia 39: | Linia 39: | ||
graph = pydot.graph_from_dot_data(dot_data.getvalue()) | graph = pydot.graph_from_dot_data(dot_data.getvalue()) | ||
graph.write_pdf("iris.pdf") | graph.write_pdf("iris.pdf") | ||
+ | </source> | ||
+ | |||
+ | Można też podejrzeć wyniki w pythonie: | ||
+ | <source lang = python> | ||
+ | from IPython.display import Image | ||
+ | dot_data = StringIO() | ||
+ | tree.export_graphviz(clf, out_file=dot_data, feature_names=iris.feature_names, | ||
+ | class_names=iris.target_names, | ||
+ | filled=True, rounded=True, | ||
+ | special_characters=True) | ||
+ | graph = pydot.graph_from_dot_data(dot_data.getvalue()) | ||
+ | Image(graph.create_png()) | ||
</source> | </source> |
Wersja z 18:35, 2 sty 2016
Wstęp
W bibliotece scikit-learn drzewa decyzyjen iplementowane są przez klasę DecisionTreeClassifier
Aby nauczyć taki klasyfikator potrzebujemy tablicę X o rozmiarach [N_przykładów, N_cech] i wektor Y określający przynależność przykładów w X do klas.
Najprostszy przykład:
from sklearn import tree
X = [[0, 0], [1, 1]]
Y = [0, 1]
clf = tree.DecisionTreeClassifier()
clf = clf.fit(X, Y)
Po dopasowaniu można przewidywać przynależność nowych przykładów:
clf.predict([[2., 2.]])
Albo estymować prawdopodobieństwo przynależności do klas:
clf.predict_proba([[2., 2.]])
Przykład zbiór danych Iris
from sklearn.datasets import load_iris
from sklearn import tree
iris = load_iris()
clf = tree.DecisionTreeClassifier()
clf = clf.fit(iris.data, iris.target)
Po wytrenowaniu można zilustrować wynik za pomocą narzedzia Graphiz (wymaga to zainstalowania w systemie tego narzędzia), oraz doinstalowania do pythona biblioteki pydot:
from sklearn.externals.six import StringIO
import pydot
dot_data = StringIO()
tree.export_graphviz(clf, out_file=dot_data)
graph = pydot.graph_from_dot_data(dot_data.getvalue())
graph.write_pdf("iris.pdf")
Można też podejrzeć wyniki w pythonie:
from IPython.display import Image
dot_data = StringIO()
tree.export_graphviz(clf, out_file=dot_data, feature_names=iris.feature_names,
class_names=iris.target_names,
filled=True, rounded=True,
special_characters=True)
graph = pydot.graph_from_dot_data(dot_data.getvalue())
Image(graph.create_png())