신은섭(Shin Eun Seop)

fix model load

...@@ -52,6 +52,7 @@ ...@@ -52,6 +52,7 @@
52 <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer"> 52 <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
53 <mainClass>detact.Aggregation</mainClass> 53 <mainClass>detact.Aggregation</mainClass>
54 </transformer> 54 </transformer>
55 + <transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
55 </transformers> 56 </transformers>
56 <filters> 57 <filters>
57 <filter> 58 <filter>
...@@ -78,6 +79,7 @@ ...@@ -78,6 +79,7 @@
78 <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer"> 79 <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
79 <mainClass>detact.ML.DecisionTree</mainClass> 80 <mainClass>detact.ML.DecisionTree</mainClass>
80 </transformer> 81 </transformer>
82 + <transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
81 </transformers> 83 </transformers>
82 <filters> 84 <filters>
83 <filter> 85 <filter>
...@@ -104,6 +106,7 @@ ...@@ -104,6 +106,7 @@
104 <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer"> 106 <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
105 <mainClass>detact.Main</mainClass> 107 <mainClass>detact.Main</mainClass>
106 </transformer> 108 </transformer>
109 + <transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
107 </transformers> 110 </transformers>
108 <filters> 111 <filters>
109 <filter> 112 <filter>
......
...@@ -107,7 +107,13 @@ public class DecisionTree { ...@@ -107,7 +107,13 @@ public class DecisionTree {
107 System.out.println("Learned regression tree model:\n" + treeModel.toDebugString()); 107 System.out.println("Learned regression tree model:\n" + treeModel.toDebugString());
108 108
109 // save model 109 // save model
110 - model.save("./decisionTree.model"); 110 + model.save("./decisionTree");
111 +
112 + // load model
113 + PipelineModel load_mode = PipelineModel.load("./decisionTree");
114 +
115 + // Make predictions.
116 + Dataset<Row> load_pred = model.transform(testData);
111 117
112 } 118 }
113 119
......
...@@ -78,8 +78,13 @@ public class Main { ...@@ -78,8 +78,13 @@ public class Main {
78 // Train model. This also runs the indexer. 78 // Train model. This also runs the indexer.
79 PipelineModel model = pipeline.fit(trainingData); 79 PipelineModel model = pipeline.fit(trainingData);
80 80
81 + // save model
82 + model.save("./decisionTree");
83 +
84 + PipelineModel p_model = PipelineModel.load("./decisionTree");
85 +
81 // Make predictions. 86 // Make predictions.
82 - Dataset<Row> predictions = model.transform(testData); 87 + Dataset<Row> predictions = p_model.transform(testData);
83 88
84 // Select example rows to display. 89 // Select example rows to display.
85 predictions.select("is_attributed", "features").show(5); 90 predictions.select("is_attributed", "features").show(5);
...@@ -93,10 +98,8 @@ public class Main { ...@@ -93,10 +98,8 @@ public class Main {
93 System.out.println("Root Mean Squared Error (RMSE) on test result = " + rmse); 98 System.out.println("Root Mean Squared Error (RMSE) on test result = " + rmse);
94 99
95 DecisionTreeRegressionModel treeModel = 100 DecisionTreeRegressionModel treeModel =
96 - (DecisionTreeRegressionModel) (model.stages()[1]); 101 + (DecisionTreeRegressionModel) (p_model.stages()[1]);
97 System.out.println("Learned regression tree model:\n" + treeModel.toDebugString()); 102 System.out.println("Learned regression tree model:\n" + treeModel.toDebugString());
98 103
99 - // save model
100 - model.save("./decisionTree.model");
101 } 104 }
102 } 105 }
......