신은섭(Shin Eun Seop)

fix model load

......@@ -52,6 +52,7 @@
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
<mainClass>detact.Aggregation</mainClass>
</transformer>
<transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
</transformers>
<filters>
<filter>
......@@ -78,6 +79,7 @@
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
<mainClass>detact.ML.DecisionTree</mainClass>
</transformer>
<transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
</transformers>
<filters>
<filter>
......@@ -104,6 +106,7 @@
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
<mainClass>detact.Main</mainClass>
</transformer>
<transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
</transformers>
<filters>
<filter>
......
......@@ -107,7 +107,13 @@ public class DecisionTree {
System.out.println("Learned regression tree model:\n" + treeModel.toDebugString());
// save model
model.save("./decisionTree.model");
model.save("./decisionTree");
// load model
PipelineModel load_mode = PipelineModel.load("./decisionTree");
// Make predictions.
Dataset<Row> load_pred = model.transform(testData);
}
......
......@@ -78,8 +78,13 @@ public class Main {
// Train model. This also runs the indexer.
PipelineModel model = pipeline.fit(trainingData);
// save model
model.save("./decisionTree");
PipelineModel p_model = PipelineModel.load("./decisionTree");
// Make predictions.
Dataset<Row> predictions = model.transform(testData);
Dataset<Row> predictions = p_model.transform(testData);
// Select example rows to display.
predictions.select("is_attributed", "features").show(5);
......@@ -93,10 +98,8 @@ public class Main {
System.out.println("Root Mean Squared Error (RMSE) on test result = " + rmse);
DecisionTreeRegressionModel treeModel =
(DecisionTreeRegressionModel) (model.stages()[1]);
(DecisionTreeRegressionModel) (p_model.stages()[1]);
System.out.println("Learned regression tree model:\n" + treeModel.toDebugString());
// save model
model.save("./decisionTree.model");
}
}
......