신은섭(Shin Eun Seop)

add decision tree ml model

Java-Cesco/Detecting_fraud_clicks/#10
...@@ -74,4 +74,9 @@ fabric.properties ...@@ -74,4 +74,9 @@ fabric.properties
74 *.rar 74 *.rar
75 75
76 # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml 76 # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml
77 -hs_err_pid*
...\ No newline at end of file ...\ No newline at end of file
77 +hs_err_pid*
78 +
79 +
80 +# datafile
81 +train.zip
82 +train.csv
...\ No newline at end of file ...\ No newline at end of file
......
1 +package detact;
2 +
1 import org.apache.spark.sql.Dataset; 3 import org.apache.spark.sql.Dataset;
2 import org.apache.spark.sql.Row; 4 import org.apache.spark.sql.Row;
3 import org.apache.spark.sql.SparkSession; 5 import org.apache.spark.sql.SparkSession;
...@@ -5,12 +7,13 @@ import org.apache.spark.sql.expressions.Window; ...@@ -5,12 +7,13 @@ import org.apache.spark.sql.expressions.Window;
5 import org.apache.spark.sql.expressions.WindowSpec; 7 import org.apache.spark.sql.expressions.WindowSpec;
6 8
7 import static org.apache.spark.sql.functions.*; 9 import static org.apache.spark.sql.functions.*;
8 -import static org.apache.spark.sql.functions.lit;
9 -import static org.apache.spark.sql.functions.when;
10 10
11 public class Aggregation { 11 public class Aggregation {
12 +
13 + public static String AGGREGATED_PATH = "agg_data";
14 + public static String ORIGINAL_DATA_PATH = "train_sample.csv";
12 15
13 - public static void main(String[] args) throws Exception { 16 + public static void main(String[] args) {
14 17
15 //Create Session 18 //Create Session
16 SparkSession spark = SparkSession 19 SparkSession spark = SparkSession
...@@ -19,10 +22,10 @@ public class Aggregation { ...@@ -19,10 +22,10 @@ public class Aggregation {
19 .master("local") 22 .master("local")
20 .getOrCreate(); 23 .getOrCreate();
21 24
22 - // Aggregation 25 + // detact.Aggregation
23 Aggregation agg = new Aggregation(); 26 Aggregation agg = new Aggregation();
24 27
25 - Dataset<Row> dataset = Utill.loadCSVDataSet("./train_sample.csv", spark); 28 + Dataset<Row> dataset = Utill.loadCSVDataSet(Aggregation.ORIGINAL_DATA_PATH, spark);
26 dataset = agg.changeTimestempToLong(dataset); 29 dataset = agg.changeTimestempToLong(dataset);
27 dataset = agg.averageValidClickCount(dataset); 30 dataset = agg.averageValidClickCount(dataset);
28 dataset = agg.clickTimeDelta(dataset); 31 dataset = agg.clickTimeDelta(dataset);
...@@ -32,7 +35,7 @@ public class Aggregation { ...@@ -32,7 +35,7 @@ public class Aggregation {
32 dataset.where("ip == '5348' and app == '19'").show(10); 35 dataset.where("ip == '5348' and app == '19'").show(10);
33 36
34 // Save to scv 37 // Save to scv
35 - Utill.saveCSVDataSet(dataset, "./agg_data"); 38 + Utill.saveCSVDataSet(dataset, Aggregation.AGGREGATED_PATH);
36 } 39 }
37 40
38 private Dataset<Row> changeTimestempToLong(Dataset<Row> dataset){ 41 private Dataset<Row> changeTimestempToLong(Dataset<Row> dataset){
...@@ -75,7 +78,7 @@ public class Aggregation { ...@@ -75,7 +78,7 @@ public class Aggregation {
75 .rangeBetween(Window.currentRow(),Window.currentRow()+600); 78 .rangeBetween(Window.currentRow(),Window.currentRow()+600);
76 79
77 Dataset<Row> newDF = dataset.withColumn("count_click_in_ten_mins", 80 Dataset<Row> newDF = dataset.withColumn("count_click_in_ten_mins",
78 - (count("utc_click_time").over(w)).minus(1)); //TODO 본인것 포함할 것인지 정해야함. 81 + (count("utc_click_time").over(w)).minus(1));
79 return newDF; 82 return newDF;
80 } 83 }
81 84
......
1 -import org.apache.spark.SparkConf; 1 +package detact.ML;
2 -import org.apache.spark.api.java.JavaRDD; 2 +
3 -import org.apache.spark.api.java.JavaSparkContext; 3 +import detact.Aggregation;
4 -import org.apache.spark.api.java.function.Function; 4 +import detact.Utill;
5 import org.apache.spark.ml.Pipeline; 5 import org.apache.spark.ml.Pipeline;
6 import org.apache.spark.ml.PipelineModel; 6 import org.apache.spark.ml.PipelineModel;
7 import org.apache.spark.ml.PipelineStage; 7 import org.apache.spark.ml.PipelineStage;
...@@ -12,35 +12,47 @@ import org.apache.spark.ml.feature.VectorIndexerModel; ...@@ -12,35 +12,47 @@ import org.apache.spark.ml.feature.VectorIndexerModel;
12 import org.apache.spark.ml.regression.DecisionTreeRegressionModel; 12 import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
13 import org.apache.spark.ml.regression.DecisionTreeRegressor; 13 import org.apache.spark.ml.regression.DecisionTreeRegressor;
14 import org.apache.spark.sql.Dataset; 14 import org.apache.spark.sql.Dataset;
15 -import org.apache.spark.sql.Encoders;
16 import org.apache.spark.sql.Row; 15 import org.apache.spark.sql.Row;
17 -import org.apache.spark.sql.SQLContext; 16 +import org.apache.spark.sql.SparkSession;
18 -import scala.Serializable;
19 -
20 -import java.util.*;
21 17
22 18
23 -// ml 19 +// DecisionTree Model
24 20
25 -public class MapExample { 21 +public class DecisionTree {
26 22
27 public static void main(String[] args) throws Exception { 23 public static void main(String[] args) throws Exception {
24 +
25 + //Create Session
26 + SparkSession spark = SparkSession
27 + .builder()
28 + .appName("Detecting Fraud Clicks")
29 + .master("local")
30 + .getOrCreate();
28 31
29 - // Automatically identify categorical features, and index them. 32 + // load aggregated dataset
30 - // Set maxCategories so features with > 4 distinct values are treated as continuous. 33 + Dataset<Row> resultds = Utill.loadCSVDataSet(Aggregation.AGGREGATED_PATH, spark);
31 -
32 - Aggregation agg = new Aggregation();
33 -
34 - agg.
35 -
36 - Dataset<Row> resultds = sqlContext.createDataFrame(result);
37 34
38 - System.out.println("schema start"); 35 + // show Dataset schema
39 - resultds.printSchema(); 36 +// System.out.println("schema start");
40 - System.out.println("schema end"); 37 +// resultds.printSchema();
38 +// String[] cols = resultds.columns();
39 +// for (String col : cols) {
40 +// System.out.println(col);
41 +// }
42 +// System.out.println("schema end");
41 43
42 VectorAssembler assembler = new VectorAssembler() 44 VectorAssembler assembler = new VectorAssembler()
43 - .setInputCols(new String[]{"ip", "app", "device", "os", "channel", "clickInTenMins"}) 45 + .setInputCols(new String[]{
46 + "ip",
47 + "app",
48 + "device",
49 + "os",
50 + "channel",
51 + "utc_click_time",
52 + "avg_valid_click_count",
53 + "click_time_delta",
54 + "count_click_in_ten_mins"
55 + })
44 .setOutputCol("features"); 56 .setOutputCol("features");
45 57
46 Dataset<Row> output = assembler.transform(resultds); 58 Dataset<Row> output = assembler.transform(resultds);
...@@ -56,9 +68,11 @@ public class MapExample { ...@@ -56,9 +68,11 @@ public class MapExample {
56 Dataset<Row> trainingData = splits[0]; 68 Dataset<Row> trainingData = splits[0];
57 Dataset<Row> testData = splits[1]; 69 Dataset<Row> testData = splits[1];
58 70
59 - // Train a DecisionTree model. 71 + // Train a detact.DecisionTreeionTree model.
60 DecisionTreeRegressor dt = new DecisionTreeRegressor() 72 DecisionTreeRegressor dt = new DecisionTreeRegressor()
61 - .setFeaturesCol("indexedFeatures").setLabelCol("attributed"); 73 + .setFeaturesCol("indexedFeatures")
74 + .setLabelCol("is_attributed")
75 + .setMaxDepth(10);
62 76
63 // Chain indexer and tree in a Pipeline. 77 // Chain indexer and tree in a Pipeline.
64 Pipeline pipeline = new Pipeline() 78 Pipeline pipeline = new Pipeline()
...@@ -71,19 +85,20 @@ public class MapExample { ...@@ -71,19 +85,20 @@ public class MapExample {
71 Dataset<Row> predictions = model.transform(testData); 85 Dataset<Row> predictions = model.transform(testData);
72 86
73 // Select example rows to display. 87 // Select example rows to display.
74 - predictions.select("attributed", "features").show(5); 88 + predictions.select("is_attributed", "features").show(5);
75 89
76 // Select (prediction, true label) and compute test error. 90 // Select (prediction, true label) and compute test error.
77 RegressionEvaluator evaluator = new RegressionEvaluator() 91 RegressionEvaluator evaluator = new RegressionEvaluator()
78 - .setLabelCol("attributed") 92 + .setLabelCol("is_attributed")
79 .setPredictionCol("prediction") 93 .setPredictionCol("prediction")
80 .setMetricName("rmse"); 94 .setMetricName("rmse");
81 double rmse = evaluator.evaluate(predictions); 95 double rmse = evaluator.evaluate(predictions);
82 System.out.println("Root Mean Squared Error (RMSE) on test result = " + rmse); 96 System.out.println("Root Mean Squared Error (RMSE) on test result = " + rmse);
83 - 97 +
84 DecisionTreeRegressionModel treeModel = 98 DecisionTreeRegressionModel treeModel =
85 (DecisionTreeRegressionModel) (model.stages()[1]); 99 (DecisionTreeRegressionModel) (model.stages()[1]);
86 System.out.println("Learned regression tree model:\n" + treeModel.toDebugString()); 100 System.out.println("Learned regression tree model:\n" + treeModel.toDebugString());
87 101
88 } 102 }
103 +
89 } 104 }
......
1 +package detact;
2 +
1 import org.apache.spark.sql.Dataset; 3 import org.apache.spark.sql.Dataset;
2 import org.apache.spark.sql.Row; 4 import org.apache.spark.sql.Row;
3 import org.apache.spark.sql.SparkSession; 5 import org.apache.spark.sql.SparkSession;
......