Showing 25 changed files with 124 additions and 360 deletions
decisionTree/metadata/._SUCCESS.crc    0 → 100644   (binary, no preview)
decisionTree/metadata/.part-00000.crc  0 → 100644   (binary, no preview)
decisionTree/metadata/_SUCCESS         0 → 100644
decisionTree/metadata/part-00000       0 → 100644
1 | +{"class":"org.apache.spark.ml.PipelineModel","timestamp":1528805498147,"sparkVersion":"2.3.0","uid":"pipeline_70a068225fba","paramMap":{"stageUids":["vecIdx_c20b02d06e4a","dtr_20be5d6af4d6"]}} |
(further new files for the saved pipeline stages: marker and checksum files, binary, no preview)

+{"class":"org.apache.spark.ml.feature.VectorIndexerModel","timestamp":1528805498480,"sparkVersion":"2.3.0","uid":"vecIdx_c20b02d06e4a","paramMap":{"handleInvalid":"error","maxCategories":2,"outputCol":"indexedFeatures","inputCol":"features"}}
(further new files, binary, no preview)

+{"class":"org.apache.spark.ml.regression.DecisionTreeRegressionModel","timestamp":1528805500043,"sparkVersion":"2.3.0","uid":"dtr_20be5d6af4d6","paramMap":{"seed":926680331,"featuresCol":"indexedFeatures","checkpointInterval":10,"maxMemoryInMB":256,"minInfoGain":0.0,"cacheNodeIds":false,"maxDepth":10,"impurity":"variance","maxBins":32,"labelCol":"is_attributed","predictionCol":"prediction","minInstancesPerNode":1},"numFeatures":9}
+import org.apache.spark.ml.Pipeline;
+import org.apache.spark.ml.PipelineModel;
+import org.apache.spark.ml.PipelineStage;
+import org.apache.spark.ml.feature.VectorAssembler;
+import org.apache.spark.ml.feature.VectorIndexer;
+import org.apache.spark.ml.feature.VectorIndexerModel;
+import org.apache.spark.ml.regression.DecisionTreeRegressor;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;

@@ -5,6 +12,7 @@ import javax.swing.*;
 import java.awt.*;
 import java.io.BufferedReader;
 import java.io.File;
+import java.io.IOException;
 import java.io.StringReader;
 import java.util.List;

@@ -61,7 +69,9 @@ class PngPane extends JPanel {
         add(label, BorderLayout.CENTER);
     }
 }
-
+class SharedArea{
+    Dataset<Row> data;
+}
 class CreateTable_tab extends JPanel{
     public JPanel centre_pane = new JPanel();
     public JPanel south_pane = new JPanel();
@@ -82,7 +92,7 @@ class CreateTable_tab extends JPanel{
     private DefaultTableModel tableModel3 = new DefaultTableModel(new Object[]{"unknown"},1);

     public CsvFile_chooser temp = new CsvFile_chooser();
-
+    private String current_state="100";

     public CreateTable_tab(){
         super();
@@ -103,13 +113,16 @@ class CreateTable_tab extends JPanel{
         // sub Panel 3
         pan3.setViewportView(table3);
         centre_pane.add(pan3);
-
+        add(centre_pane, BorderLayout.CENTER);
         //sub Panel 4
         south_pane.setLayout(new FlowLayout());
         south_pane.add(btn1);
+
         btn1.addActionListener(new ActionListener() {
-            @Override
+
             public void actionPerformed(ActionEvent e) {
+
+
                 if(temp.is_selected) {
                     String path = temp.selected_file.getAbsolutePath();
                     // 1st Column Raw Data
@@ -126,20 +139,105 @@ class CreateTable_tab extends JPanel{
                     TableCreator table_maker = new TableCreator();

                     Dataset<Row> dataset = agg.loadCSVDataSet(path, spark);
-                    List<String> stringDataset_Raw = dataset.toJSON().collectAsList();
-                    String[] header_r = {"ip", "app", "device", "os", "channel", "click_time", "is_attributed"};
-                    table1.setModel(table_maker.getTableModel(stringDataset_Raw, header_r));
-
-                    // 2nd Column Data with features
-                    // Adding features
-                    dataset = agg.changeTimestempToLong(dataset);
-                    dataset = agg.averageValidClickCount(dataset);
-                    dataset = agg.clickTimeDelta(dataset);
-                    dataset = agg.countClickInTenMinutes(dataset);
-                    List<String> stringDataset_feat = dataset.toJSON().collectAsList();
-                    String[] header_f = {"ip", "app", "device", "os", "channel", "is_attributed", "click_time",
-                            "avg_valid_click_count", "click_time_delta", "count_click_in_ten_mins"};
-                    table2.setModel(table_maker.getTableModel(stringDataset_feat, header_f));
+                    if(current_state.equals("100")){
+                        List<String> stringDataset_Raw = dataset.toJSON().collectAsList();
+                        String[] header_r = {"ip", "app", "device", "os", "channel", "click_time", "is_attributed"};
+                        table1.setModel(table_maker.getTableModel(stringDataset_Raw, header_r));
+                        current_state="200";
+                    }else if(current_state.equals("200")){
+                        // 2nd Column Data with features
+                        // Adding features
+                        dataset = agg.changeTimestempToLong(dataset);
+                        dataset = agg.averageValidClickCount(dataset);
+                        dataset = agg.clickTimeDelta(dataset);
+                        dataset = agg.countClickInTenMinutes(dataset);
+                        List<String> stringDataset_feat = dataset.toJSON().collectAsList();
+                        String[] header_f = {"ip", "app", "device", "os", "channel", "is_attributed", "click_time",
+                                "avg_valid_click_count", "click_time_delta", "count_click_in_ten_mins"};
+                        table2.setModel(table_maker.getTableModel(stringDataset_feat, header_f));
+                        current_state="300";
+                    }else if(current_state.equals("300")){
+                        dataset = agg.changeTimestempToLong(dataset);
+                        dataset = agg.averageValidClickCount(dataset);
+                        dataset = agg.clickTimeDelta(dataset);
+                        dataset = agg.countClickInTenMinutes(dataset);
+
+                        VectorAssembler assembler = new VectorAssembler()
+                                .setInputCols(new String[]{
+                                        "ip",
+                                        "app",
+                                        "device",
+                                        "os",
+                                        "channel",
+                                        "utc_click_time",
+                                        "avg_valid_click_count",
+                                        "click_time_delta",
+                                        "count_click_in_ten_mins"
+                                })
+                                .setOutputCol("features");
+
+                        Dataset<Row> output = assembler.transform(dataset);
+
+                        VectorIndexerModel featureIndexer = new VectorIndexer()
+                                .setInputCol("features")
+                                .setOutputCol("indexedFeatures")
+                                .setMaxCategories(2)
+                                .fit(output);
+
+                        // Split the result into training and test sets (30% held out for testing).
+//                        Dataset<Row>[] splits = output.randomSplit(new double[]{0.7, 0.3});
+//                        Dataset<Row> trainingData = splits[0];
+//                        Dataset<Row> testData = splits[1];
+
+
+                        // Train a DecisionTree regression model.
+                        DecisionTreeRegressor dt = new DecisionTreeRegressor()
+                                .setFeaturesCol("indexedFeatures")
+                                .setLabelCol("is_attributed")
+                                .setMaxDepth(10);
+
+                        // Chain indexer and tree in a Pipeline.
+                        Pipeline pipeline = new Pipeline()
+                                .setStages(new PipelineStage[]{featureIndexer, dt});
+
+                        // Train model. This also runs the indexer.
+                        PipelineModel model = pipeline.fit(output);
+
+                        // Save the trained model, then reload it before predicting.
+                        try {
+                            model.save("./decisionTree");
+                        } catch (IOException e1) {
+                            e1.printStackTrace();
+                        }
+
+                        PipelineModel p_model = PipelineModel.load("./decisionTree");
+
+                        // Make predictions.
+                        Dataset<Row> predictions = p_model.transform(assembler.transform(dataset));
+                        predictions = predictions.drop("app")
+                                .drop("device")
+                                .drop("os")
+                                .drop("channel")
+                                .drop("utc_click_time")
+                                .drop("utc_attributed_time")
+                                .drop("avg_valid_click_count")
+                                .drop("click_time_delta")
+                                .drop("count_click_in_ten_mins")
+                                .drop("features")
+                                .drop("indexedFeatures");
+                        predictions.printSchema();
+                        List<String> stringDataset_feat = predictions.toJSON().collectAsList();
+                        String[] header_f = {"ip", "is_attributed", "prediction"};
+                        table3.setModel(table_maker.getTableModel(stringDataset_feat, header_f));
+                        current_state="400";
+                    }
+


                     // 3rd Column Final results
@@ -148,7 +246,7 @@ class CreateTable_tab extends JPanel{
                 }
             }
         });
-        add(centre_pane, BorderLayout.CENTER);
+
         add(south_pane, BorderLayout.SOUTH);

@@ -181,7 +279,7 @@ class CsvFile_chooser extends JPanel{
         add(path_field);
         add(browser);
         browser.addActionListener(new ActionListener(){
-            @Override
+
             public void actionPerformed(ActionEvent e) {
                 Object obj = e.getSource();
                 if((JButton)obj == browser){
@@ -297,4 +395,6 @@ class Aggregation {
                 (count("utc_click_time").over(w)).minus(1)); // TODO: decide whether to count the current click itself.
         return newDF;
     }
-}
\ No newline at end of file
+}
+
+
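Note: the in-GUI training above fits the pipeline on the whole dataset and then predicts on the same rows (the randomSplit lines are left commented out), so the change drops the held-out error estimate that the removed DecisionTree.java below computed. A minimal sketch of how that RMSE check could be reattached inside the "300" branch; output and p_model are the variables from the added code, RegressionEvaluator needs an extra import, and the split seed is an arbitrary assumption:

// Hold out 30% of the assembled rows and score the reloaded pipeline on them.
// Requires: import org.apache.spark.ml.evaluation.RegressionEvaluator;
Dataset<Row>[] splits = output.randomSplit(new double[]{0.7, 0.3}, 42L); // seed chosen arbitrarily
Dataset<Row> holdOut = splits[1];

Dataset<Row> holdOutPred = p_model.transform(holdOut);

RegressionEvaluator evaluator = new RegressionEvaluator()
        .setLabelCol("is_attributed")
        .setPredictionCol("prediction")
        .setMetricName("rmse");
System.out.println("Hold-out RMSE = " + evaluator.evaluate(holdOutPred));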
src/main/java/detact/Aggregation.java   100644 → 0   (deleted)

-package detact;
-
-import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
-import org.apache.spark.sql.expressions.Window;
-import org.apache.spark.sql.expressions.WindowSpec;
-
-import static org.apache.spark.sql.functions.*;
-
-public class Aggregation {
-
-    public static void main(String[] args) {
-
-        if (args.length != 2) {
-            System.out.println("Usage: java -jar aggregation.jar <data_path> <result_path>");
-            System.exit(0);
-        }
-
-        String data_path = args[0];
-        String result_path = args[1];
-
-        // Create Session
-        SparkSession spark = SparkSession
-                .builder()
-                .appName("Detecting Fraud Clicks")
-                .master("local")
-                .getOrCreate();
-
-        // detact.Aggregation
-        Aggregation agg = new Aggregation();
-
-        Dataset<Row> dataset = Utill.loadCSVDataSet(data_path, spark);
-        dataset = agg.changeTimestempToLong(dataset);
-        dataset = agg.averageValidClickCount(dataset);
-        dataset = agg.clickTimeDelta(dataset);
-        dataset = agg.countClickInTenMinutes(dataset);
-
-        // test
-//        dataset.where("ip == '5348' and app == '19'").show(10);
-
-        // Save to CSV
-        Utill.saveCSVDataSet(dataset, result_path);
-    }
-
-    public Dataset<Row> changeTimestempToLong(Dataset<Row> dataset){
-        // cast timestamp to long
-        Dataset<Row> newDF = dataset.withColumn("utc_click_time", dataset.col("click_time").cast("long"));
-        newDF = newDF.withColumn("utc_attributed_time", dataset.col("attributed_time").cast("long"));
-        newDF = newDF.drop("click_time").drop("attributed_time");
-        return newDF;
-    }
-
-    public Dataset<Row> averageValidClickCount(Dataset<Row> dataset){
-        // Window partitioned by 'ip' and 'app', ordered by 'utc_click_time', spanning the first row to the current row
-        WindowSpec w = Window.partitionBy("ip", "app")
-                .orderBy("utc_click_time")
-                .rowsBetween(Window.unboundedPreceding(), Window.currentRow());
-
-        // aggregation
-        Dataset<Row> newDF = dataset.withColumn("cum_count_click", count("utc_click_time").over(w));
-        newDF = newDF.withColumn("cum_sum_attributed", sum("is_attributed").over(w));
-        newDF = newDF.withColumn("avg_valid_click_count", col("cum_sum_attributed").divide(col("cum_count_click")));
-        newDF = newDF.drop("cum_count_click", "cum_sum_attributed");
-        return newDF;
-    }
-
-    public Dataset<Row> clickTimeDelta(Dataset<Row> dataset){
-        WindowSpec w = Window.partitionBy("ip")
-                .orderBy("utc_click_time");
-
-        Dataset<Row> newDF = dataset.withColumn("lag(utc_click_time)", lag("utc_click_time", 1).over(w));
-        newDF = newDF.withColumn("click_time_delta", when(col("lag(utc_click_time)").isNull(),
-                lit(0)).otherwise(col("utc_click_time")).minus(when(col("lag(utc_click_time)").isNull(),
-                lit(0)).otherwise(col("lag(utc_click_time)"))));
-        newDF = newDF.drop("lag(utc_click_time)");
-        return newDF;
-    }
-
-    public Dataset<Row> countClickInTenMinutes(Dataset<Row> dataset){
-        WindowSpec w = Window.partitionBy("ip")
-                .orderBy("utc_click_time")
-                .rangeBetween(Window.currentRow(), Window.currentRow() + 600);
-
-        Dataset<Row> newDF = dataset.withColumn("count_click_in_ten_mins",
-                (count("utc_click_time").over(w)).minus(1));
-        return newDF;
-    }
-
-}
src/main/java/detact/ML/DecisionTree.java   100644 → 0   (deleted)

-package detact.ML;
-
-import detact.Aggregation;
-import detact.Utill;
-import org.apache.spark.ml.Pipeline;
-import org.apache.spark.ml.PipelineModel;
-import org.apache.spark.ml.PipelineStage;
-import org.apache.spark.ml.evaluation.RegressionEvaluator;
-import org.apache.spark.ml.feature.VectorAssembler;
-import org.apache.spark.ml.feature.VectorIndexer;
-import org.apache.spark.ml.feature.VectorIndexerModel;
-import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
-import org.apache.spark.ml.regression.DecisionTreeRegressor;
-import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
-
-
-// DecisionTree Model
-
-public class DecisionTree {
-
-    public static void main(String[] args) throws Exception {
-
-        if (args.length != 1) {
-            System.out.println("Usage: java -jar decisionTree.jar <agg_path>");
-            System.exit(0);
-        }
-
-        String agg_path = args[0];
-
-        // Create Session
-        SparkSession spark = SparkSession
-                .builder()
-                .appName("Detecting Fraud Clicks")
-                .master("local")
-                .getOrCreate();
-
-        // load aggregated dataset
-        Dataset<Row> resultds = Utill.loadCSVDataSet(agg_path, spark);
-
-        // show Dataset schema
-//        System.out.println("schema start");
-//        resultds.printSchema();
-//        String[] cols = resultds.columns();
-//        for (String col : cols) {
-//            System.out.println(col);
-//        }
-//        System.out.println("schema end");
-
-        VectorAssembler assembler = new VectorAssembler()
-                .setInputCols(new String[]{
-                        "ip",
-                        "app",
-                        "device",
-                        "os",
-                        "channel",
-                        "utc_click_time",
-                        "avg_valid_click_count",
-                        "click_time_delta",
-                        "count_click_in_ten_mins"
-                })
-                .setOutputCol("features");
-
-        Dataset<Row> output = assembler.transform(resultds);
-
-        VectorIndexerModel featureIndexer = new VectorIndexer()
-                .setInputCol("features")
-                .setOutputCol("indexedFeatures")
-                .setMaxCategories(2)
-                .fit(output);
-
-        // Split the result into training and test sets (30% held out for testing).
-        Dataset<Row>[] splits = output.randomSplit(new double[]{0.7, 0.3});
-        Dataset<Row> trainingData = splits[0];
-        Dataset<Row> testData = splits[1];
-
-        // Train a DecisionTree regression model.
-        DecisionTreeRegressor dt = new DecisionTreeRegressor()
-                .setFeaturesCol("indexedFeatures")
-                .setLabelCol("is_attributed")
-                .setMaxDepth(10);
-
-        // Chain indexer and tree in a Pipeline.
-        Pipeline pipeline = new Pipeline()
-                .setStages(new PipelineStage[]{featureIndexer, dt});
-
-        // Train model. This also runs the indexer.
-        PipelineModel model = pipeline.fit(trainingData);
-
-        // Make predictions.
-        Dataset<Row> predictions = model.transform(testData);
-
-        // Select example rows to display.
-        predictions.select("is_attributed", "features").show(5);
-
-        // Select (prediction, true label) and compute test error.
-        RegressionEvaluator evaluator = new RegressionEvaluator()
-                .setLabelCol("is_attributed")
-                .setPredictionCol("prediction")
-                .setMetricName("rmse");
-        double rmse = evaluator.evaluate(predictions);
-        System.out.println("Root Mean Squared Error (RMSE) on test result = " + rmse);
-
-        DecisionTreeRegressionModel treeModel =
-                (DecisionTreeRegressionModel) (model.stages()[1]);
-        System.out.println("Learned regression tree model:\n" + treeModel.toDebugString());
-
-        // save model
-        model.save("./decisionTree");
-
-        // load model
-        PipelineModel load_mode = PipelineModel.load("./decisionTree");
-
-        // Make predictions.
-        Dataset<Row> load_pred = model.transform(testData);
-
-    }
-
-}
src/main/java/detact/Main.java   100644 → 0   (deleted)
-package detact;
-
-import org.apache.spark.ml.Pipeline;
-import org.apache.spark.ml.PipelineModel;
-import org.apache.spark.ml.PipelineStage;
-import org.apache.spark.ml.evaluation.RegressionEvaluator;
-import org.apache.spark.ml.feature.VectorAssembler;
-import org.apache.spark.ml.feature.VectorIndexer;
-import org.apache.spark.ml.feature.VectorIndexerModel;
-import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
-import org.apache.spark.ml.regression.DecisionTreeRegressor;
-import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
-
-public class Main {
-    public static void main(String[] args) throws Exception{
-        if (args.length != 1) {
-            System.out.println("Usage: java -jar aggregation.jar <data_path>");
-            System.exit(0);
-        }
-
-        String data_path = args[0];
-
-        // Create Session
-        SparkSession spark = SparkSession
-                .builder()
-                .appName("Detecting Fraud Clicks")
-                .master("local")
-                .getOrCreate();
-
-        // detact.Aggregation
-        Aggregation agg = new Aggregation();
-
-        Dataset<Row> dataset = Utill.loadCSVDataSet(data_path, spark);
-        dataset = agg.changeTimestempToLong(dataset);
-        dataset = agg.averageValidClickCount(dataset);
-        dataset = agg.clickTimeDelta(dataset);
-        dataset = agg.countClickInTenMinutes(dataset);
-
-        VectorAssembler assembler = new VectorAssembler()
-                .setInputCols(new String[]{
-                        "ip",
-                        "app",
-                        "device",
-                        "os",
-                        "channel",
-                        "utc_click_time",
-                        "avg_valid_click_count",
-                        "click_time_delta",
-                        "count_click_in_ten_mins"
-                })
-                .setOutputCol("features");
-
-        Dataset<Row> output = assembler.transform(dataset);
-
-        VectorIndexerModel featureIndexer = new VectorIndexer()
-                .setInputCol("features")
-                .setOutputCol("indexedFeatures")
-                .setMaxCategories(2)
-                .fit(output);
-
62 | - | ||
63 | - // Split the result into training and test sets (30% held out for testing). | ||
64 | - Dataset<Row>[] splits = output.randomSplit(new double[]{0.7, 0.3}); | ||
65 | - Dataset<Row> trainingData = splits[0]; | ||
66 | - Dataset<Row> testData = splits[1]; | ||
67 | - | ||
68 | - // Train a detact.DecisionTreeionTree model. | ||
69 | - DecisionTreeRegressor dt = new DecisionTreeRegressor() | ||
70 | - .setFeaturesCol("indexedFeatures") | ||
71 | - .setLabelCol("is_attributed") | ||
72 | - .setMaxDepth(10); | ||
73 | - | ||
74 | - // Chain indexer and tree in a Pipeline. | ||
75 | - Pipeline pipeline = new Pipeline() | ||
76 | - .setStages(new PipelineStage[]{featureIndexer, dt}); | ||
77 | - | ||
78 | - // Train model. This also runs the indexer. | ||
79 | - PipelineModel model = pipeline.fit(trainingData); | ||
80 | - | ||
81 | - // save model | ||
82 | - model.save("./decisionTree"); | ||
83 | - | ||
84 | - PipelineModel p_model = PipelineModel.load("./decisionTree"); | ||
85 | - | ||
86 | - // Make predictions. | ||
87 | - Dataset<Row> predictions = p_model.transform(testData); | ||
88 | - | ||
89 | - // Select example rows to display. | ||
90 | - predictions.select("is_attributed", "features").show(5); | ||
91 | - | ||
92 | - // Select (prediction, true label) and compute test error. | ||
93 | - RegressionEvaluator evaluator = new RegressionEvaluator() | ||
94 | - .setLabelCol("is_attributed") | ||
95 | - .setPredictionCol("prediction") | ||
96 | - .setMetricName("rmse"); | ||
97 | - double rmse = evaluator.evaluate(predictions); | ||
98 | - System.out.println("Root Mean Squared Error (RMSE) on test result = " + rmse); | ||
99 | - | ||
100 | - DecisionTreeRegressionModel treeModel = | ||
101 | - (DecisionTreeRegressionModel) (p_model.stages()[1]); | ||
102 | - System.out.println("Learned regression tree model:\n" + treeModel.toDebugString()); | ||
103 | - | ||
104 | - } | ||
105 | -} |
src/main/java/detact/Utill.java   100644 → 0   (deleted)
-package detact;
-
-import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
-
-public class Utill {
-
-    public static Dataset<Row> loadCSVDataSet(String path, SparkSession spark){
-        // Read a CSV file into a Dataset
-        return spark.read().format("com.databricks.spark.csv")
-                .option("inferSchema", "true")
-                .option("header", "true")
-                .load(path);
-    }
-
-    public static void saveCSVDataSet(Dataset<Row> dataset, String path){
-        // Write a Dataset out as CSV
-        dataset.write().format("com.databricks.spark.csv")
-                .option("inferSchema", "true")
-                .option("header", "true")
-                .save(path);
-    }
-}
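For reference, a minimal usage sketch of the removed Utill helpers, mirroring how the removed Main.java called them; the input and output paths here are placeholders, not files from this repository:

SparkSession spark = SparkSession.builder()
        .appName("Detecting Fraud Clicks")
        .master("local")
        .getOrCreate();

// Placeholder paths: any header-bearing CSV works as input.
Dataset<Row> ds = Utill.loadCSVDataSet("data/train_sample.csv", spark);
Utill.saveCSVDataSet(ds, "./agg_result");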