EC2 Default User
...@@ -39,6 +39,7 @@ ...@@ -39,6 +39,7 @@
39 <groupId>org.apache.maven.plugins</groupId> 39 <groupId>org.apache.maven.plugins</groupId>
40 <artifactId>maven-shade-plugin</artifactId> 40 <artifactId>maven-shade-plugin</artifactId>
41 <executions> 41 <executions>
42 + <!-- Aggregation -->
42 <execution> 43 <execution>
43 <id>aggregation</id> 44 <id>aggregation</id>
44 <goals> 45 <goals>
...@@ -51,6 +52,7 @@ ...@@ -51,6 +52,7 @@
51 <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer"> 52 <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
52 <mainClass>detact.Aggregation</mainClass> 53 <mainClass>detact.Aggregation</mainClass>
53 </transformer> 54 </transformer>
55 + <transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
54 </transformers> 56 </transformers>
55 <filters> 57 <filters>
56 <filter> 58 <filter>
...@@ -64,6 +66,7 @@ ...@@ -64,6 +66,7 @@
64 </filters> 66 </filters>
65 </configuration> 67 </configuration>
66 </execution> 68 </execution>
69 + <!-- Decision Tree -->
67 <execution> 70 <execution>
68 <id>decisionTree</id> 71 <id>decisionTree</id>
69 <goals> 72 <goals>
...@@ -76,6 +79,34 @@ ...@@ -76,6 +79,34 @@
76 <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer"> 79 <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
77 <mainClass>detact.ML.DecisionTree</mainClass> 80 <mainClass>detact.ML.DecisionTree</mainClass>
78 </transformer> 81 </transformer>
82 + <transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
83 + </transformers>
84 + <filters>
85 + <filter>
86 + <artifact>*:*</artifact>
87 + <excludes>
88 + <exclude>META-INF/*.SF</exclude>
89 + <exclude>META-INF/*.DSA</exclude>
90 + <exclude>META-INF/*.RSA</exclude>
91 + </excludes>
92 + </filter>
93 + </filters>
94 + </configuration>
95 + </execution>
96 + <!-- Main -->
97 + <execution>
98 + <id>Main</id>
99 + <goals>
100 + <goal>shade</goal>
101 + </goals>
102 + <configuration>
103 + <outputFile>target/assembly/${project.artifactId}-main.jar</outputFile>
104 + <shadedArtifactAttached>true</shadedArtifactAttached>
105 + <transformers>
106 + <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
107 + <mainClass>detact.Main</mainClass>
108 + </transformer>
109 + <transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
79 </transformers> 110 </transformers>
80 <filters> 111 <filters>
81 <filter> 112 <filter>
......
...@@ -44,7 +44,7 @@ public class Aggregation { ...@@ -44,7 +44,7 @@ public class Aggregation {
44 Utill.saveCSVDataSet(dataset, result_path); 44 Utill.saveCSVDataSet(dataset, result_path);
45 } 45 }
46 46
47 - private Dataset<Row> changeTimestempToLong(Dataset<Row> dataset){ 47 + public Dataset<Row> changeTimestempToLong(Dataset<Row> dataset){
48 // cast timestamp to long 48 // cast timestamp to long
49 Dataset<Row> newDF = dataset.withColumn("utc_click_time", dataset.col("click_time").cast("long")); 49 Dataset<Row> newDF = dataset.withColumn("utc_click_time", dataset.col("click_time").cast("long"));
50 newDF = newDF.withColumn("utc_attributed_time", dataset.col("attributed_time").cast("long")); 50 newDF = newDF.withColumn("utc_attributed_time", dataset.col("attributed_time").cast("long"));
...@@ -52,7 +52,7 @@ public class Aggregation { ...@@ -52,7 +52,7 @@ public class Aggregation {
52 return newDF; 52 return newDF;
53 } 53 }
54 54
55 - private Dataset<Row> averageValidClickCount(Dataset<Row> dataset){ 55 + public Dataset<Row> averageValidClickCount(Dataset<Row> dataset){
56 // set Window partition by 'ip' and 'app' order by 'utc_click_time' select rows between 1st row to current row 56 // set Window partition by 'ip' and 'app' order by 'utc_click_time' select rows between 1st row to current row
57 WindowSpec w = Window.partitionBy("ip", "app") 57 WindowSpec w = Window.partitionBy("ip", "app")
58 .orderBy("utc_click_time") 58 .orderBy("utc_click_time")
...@@ -66,7 +66,7 @@ public class Aggregation { ...@@ -66,7 +66,7 @@ public class Aggregation {
66 return newDF; 66 return newDF;
67 } 67 }
68 68
69 - private Dataset<Row> clickTimeDelta(Dataset<Row> dataset){ 69 + public Dataset<Row> clickTimeDelta(Dataset<Row> dataset){
70 WindowSpec w = Window.partitionBy ("ip") 70 WindowSpec w = Window.partitionBy ("ip")
71 .orderBy("utc_click_time"); 71 .orderBy("utc_click_time");
72 72
...@@ -78,7 +78,7 @@ public class Aggregation { ...@@ -78,7 +78,7 @@ public class Aggregation {
78 return newDF; 78 return newDF;
79 } 79 }
80 80
81 - private Dataset<Row> countClickInTenMinutes(Dataset<Row> dataset){ 81 + public Dataset<Row> countClickInTenMinutes(Dataset<Row> dataset){
82 WindowSpec w = Window.partitionBy("ip") 82 WindowSpec w = Window.partitionBy("ip")
83 .orderBy("utc_click_time") 83 .orderBy("utc_click_time")
84 .rangeBetween(Window.currentRow(),Window.currentRow()+600); 84 .rangeBetween(Window.currentRow(),Window.currentRow()+600);
......
...@@ -107,6 +107,15 @@ public class DecisionTree { ...@@ -107,6 +107,15 @@ public class DecisionTree {
107 (DecisionTreeRegressionModel) (model.stages()[1]); 107 (DecisionTreeRegressionModel) (model.stages()[1]);
108 System.out.println("Learned regression tree model:\n" + treeModel.toDebugString()); 108 System.out.println("Learned regression tree model:\n" + treeModel.toDebugString());
109 109
110 + // save model
111 + model.save("./decisionTree");
112 +
113 + // load model
114 + PipelineModel load_mode = PipelineModel.load("./decisionTree");
115 +
116 + // Make predictions.
117 +        Dataset&lt;Row&gt; load_pred = load_mode.transform(testData);
118 +
110 } 119 }
111 120
112 } 121 }
......
1 +package detact;
2 +
3 +import org.apache.spark.ml.Pipeline;
4 +import org.apache.spark.ml.PipelineModel;
5 +import org.apache.spark.ml.PipelineStage;
6 +import org.apache.spark.ml.evaluation.RegressionEvaluator;
7 +import org.apache.spark.ml.feature.VectorAssembler;
8 +import org.apache.spark.ml.feature.VectorIndexer;
9 +import org.apache.spark.ml.feature.VectorIndexerModel;
10 +import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
11 +import org.apache.spark.ml.regression.DecisionTreeRegressor;
12 +import org.apache.spark.sql.Dataset;
13 +import org.apache.spark.sql.Row;
14 +import org.apache.spark.sql.SparkSession;
15 +
16 +public class Main {
17 + public static void main(String[] args) throws Exception{
18 + if (args.length != 1) {
19 + System.out.println("Usage: java -jar aggregation.jar <data_path>");
20 + System.exit(0);
21 + }
22 +
23 + String data_path = args[0];
24 +
25 + //Create Session
26 + SparkSession spark = SparkSession
27 + .builder()
28 + .appName("Detecting Fraud Clicks")
29 + .master("local")
30 + .getOrCreate();
31 +
32 + // detact.Aggregation
33 + Aggregation agg = new Aggregation();
34 +
35 + Dataset<Row> dataset = Utill.loadCSVDataSet(data_path, spark);
36 + dataset = agg.changeTimestempToLong(dataset);
37 + dataset = agg.averageValidClickCount(dataset);
38 + dataset = agg.clickTimeDelta(dataset);
39 + dataset = agg.countClickInTenMinutes(dataset);
40 +
41 + VectorAssembler assembler = new VectorAssembler()
42 + .setInputCols(new String[]{
43 + "ip",
44 + "app",
45 + "device",
46 + "os",
47 + "channel",
48 + "utc_click_time",
49 + "avg_valid_click_count",
50 + "click_time_delta",
51 + "count_click_in_ten_mins"
52 + })
53 + .setOutputCol("features");
54 +
55 + Dataset<Row> output = assembler.transform(dataset);
56 +
57 + VectorIndexerModel featureIndexer = new VectorIndexer()
58 + .setInputCol("features")
59 + .setOutputCol("indexedFeatures")
60 + .setMaxCategories(2)
61 + .fit(output);
62 +
63 + // Split the result into training and test sets (30% held out for testing).
64 + Dataset<Row>[] splits = output.randomSplit(new double[]{0.7, 0.3});
65 + Dataset<Row> trainingData = splits[0];
66 + Dataset<Row> testData = splits[1];
67 +
68 + // Train a detact.DecisionTreeionTree model.
69 + DecisionTreeRegressor dt = new DecisionTreeRegressor()
70 + .setFeaturesCol("indexedFeatures")
71 + .setLabelCol("is_attributed")
72 + .setMaxDepth(10);
73 +
74 + // Chain indexer and tree in a Pipeline.
75 + Pipeline pipeline = new Pipeline()
76 + .setStages(new PipelineStage[]{featureIndexer, dt});
77 +
78 + // Train model. This also runs the indexer.
79 + PipelineModel model = pipeline.fit(trainingData);
80 +
81 + // save model
82 + model.save("./decisionTree");
83 +
84 + PipelineModel p_model = PipelineModel.load("./decisionTree");
85 +
86 + // Make predictions.
87 + Dataset<Row> predictions = p_model.transform(testData);
88 +
89 + // Select example rows to display.
90 + predictions.select("is_attributed", "features").show(5);
91 +
92 + // Select (prediction, true label) and compute test error.
93 + RegressionEvaluator evaluator = new RegressionEvaluator()
94 + .setLabelCol("is_attributed")
95 + .setPredictionCol("prediction")
96 + .setMetricName("rmse");
97 + double rmse = evaluator.evaluate(predictions);
98 + System.out.println("Root Mean Squared Error (RMSE) on test result = " + rmse);
99 +
100 + DecisionTreeRegressionModel treeModel =
101 + (DecisionTreeRegressionModel) (p_model.stages()[1]);
102 + System.out.println("Learned regression tree model:\n" + treeModel.toDebugString());
103 +
104 + }
105 +}