신은섭(Shin Eun Seop)

add Main.java

...@@ -39,6 +39,7 @@ ...@@ -39,6 +39,7 @@
39 <groupId>org.apache.maven.plugins</groupId> 39 <groupId>org.apache.maven.plugins</groupId>
40 <artifactId>maven-shade-plugin</artifactId> 40 <artifactId>maven-shade-plugin</artifactId>
41 <executions> 41 <executions>
42 + <!-- Aggregation -->
42 <execution> 43 <execution>
43 <id>aggregation</id> 44 <id>aggregation</id>
44 <goals> 45 <goals>
...@@ -64,6 +65,7 @@ ...@@ -64,6 +65,7 @@
64 </filters> 65 </filters>
65 </configuration> 66 </configuration>
66 </execution> 67 </execution>
68 + <!-- Decision Tree -->
67 <execution> 69 <execution>
68 <id>decisionTree</id> 70 <id>decisionTree</id>
69 <goals> 71 <goals>
...@@ -89,6 +91,32 @@ ...@@ -89,6 +91,32 @@
89 </filters> 91 </filters>
90 </configuration> 92 </configuration>
91 </execution> 93 </execution>
94 + <!-- Main -->
95 + <execution>
96 + <id>Main</id>
97 + <goals>
98 + <goal>shade</goal>
99 + </goals>
100 + <configuration>
101 + <outputFile>target/assembly/${project.artifactId}-main.jar</outputFile>
102 + <shadedArtifactAttached>true</shadedArtifactAttached>
103 + <transformers>
104 + <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
105 + <mainClass>detact.Main</mainClass>
106 + </transformer>
107 + </transformers>
108 + <filters>
109 + <filter>
110 + <artifact>*:*</artifact>
111 + <excludes>
112 + <exclude>META-INF/*.SF</exclude>
113 + <exclude>META-INF/*.DSA</exclude>
114 + <exclude>META-INF/*.RSA</exclude>
115 + </excludes>
116 + </filter>
117 + </filters>
118 + </configuration>
119 + </execution>
92 </executions> 120 </executions>
93 </plugin> 121 </plugin>
94 </plugins> 122 </plugins>
......
...@@ -43,7 +43,7 @@ public class Aggregation { ...@@ -43,7 +43,7 @@ public class Aggregation {
43 Utill.saveCSVDataSet(dataset, result_path); 43 Utill.saveCSVDataSet(dataset, result_path);
44 } 44 }
45 45
46 - private Dataset<Row> changeTimestempToLong(Dataset<Row> dataset){ 46 + public Dataset<Row> changeTimestempToLong(Dataset<Row> dataset){
47 // cast timestamp to long 47 // cast timestamp to long
48 Dataset<Row> newDF = dataset.withColumn("utc_click_time", dataset.col("click_time").cast("long")); 48 Dataset<Row> newDF = dataset.withColumn("utc_click_time", dataset.col("click_time").cast("long"));
49 newDF = newDF.withColumn("utc_attributed_time", dataset.col("attributed_time").cast("long")); 49 newDF = newDF.withColumn("utc_attributed_time", dataset.col("attributed_time").cast("long"));
...@@ -51,7 +51,7 @@ public class Aggregation { ...@@ -51,7 +51,7 @@ public class Aggregation {
51 return newDF; 51 return newDF;
52 } 52 }
53 53
54 - private Dataset<Row> averageValidClickCount(Dataset<Row> dataset){ 54 + public Dataset<Row> averageValidClickCount(Dataset<Row> dataset){
55 // set Window partition by 'ip' and 'app' order by 'utc_click_time' select rows between 1st row to current row 55 // set Window partition by 'ip' and 'app' order by 'utc_click_time' select rows between 1st row to current row
56 WindowSpec w = Window.partitionBy("ip", "app") 56 WindowSpec w = Window.partitionBy("ip", "app")
57 .orderBy("utc_click_time") 57 .orderBy("utc_click_time")
...@@ -65,7 +65,7 @@ public class Aggregation { ...@@ -65,7 +65,7 @@ public class Aggregation {
65 return newDF; 65 return newDF;
66 } 66 }
67 67
68 - private Dataset<Row> clickTimeDelta(Dataset<Row> dataset){ 68 + public Dataset<Row> clickTimeDelta(Dataset<Row> dataset){
69 WindowSpec w = Window.partitionBy ("ip") 69 WindowSpec w = Window.partitionBy ("ip")
70 .orderBy("utc_click_time"); 70 .orderBy("utc_click_time");
71 71
...@@ -77,7 +77,7 @@ public class Aggregation { ...@@ -77,7 +77,7 @@ public class Aggregation {
77 return newDF; 77 return newDF;
78 } 78 }
79 79
80 - private Dataset<Row> countClickInTenMinutes(Dataset<Row> dataset){ 80 + public Dataset<Row> countClickInTenMinutes(Dataset<Row> dataset){
81 WindowSpec w = Window.partitionBy("ip") 81 WindowSpec w = Window.partitionBy("ip")
82 .orderBy("utc_click_time") 82 .orderBy("utc_click_time")
83 .rangeBetween(Window.currentRow(),Window.currentRow()+600); 83 .rangeBetween(Window.currentRow(),Window.currentRow()+600);
......
1 package detact; 1 package detact;
2 2
3 +import org.apache.spark.ml.Pipeline;
4 +import org.apache.spark.ml.PipelineModel;
5 +import org.apache.spark.ml.PipelineStage;
6 +import org.apache.spark.ml.evaluation.RegressionEvaluator;
7 +import org.apache.spark.ml.feature.VectorAssembler;
8 +import org.apache.spark.ml.feature.VectorIndexer;
9 +import org.apache.spark.ml.feature.VectorIndexerModel;
10 +import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
11 +import org.apache.spark.ml.regression.DecisionTreeRegressor;
12 +import org.apache.spark.sql.Dataset;
13 +import org.apache.spark.sql.Row;
14 +import org.apache.spark.sql.SparkSession;
15 +
3 public class Main { 16 public class Main {
17 + public static void main(String[] args) throws Exception{
18 + if (args.length != 1) {
19 + System.out.println("Usage: java -jar aggregation.jar <data_path>");
20 + System.exit(0);
21 + }
22 +
23 + String data_path = args[0];
24 +
25 + //Create Session
26 + SparkSession spark = SparkSession
27 + .builder()
28 + .appName("Detecting Fraud Clicks")
29 + .master("local")
30 + .getOrCreate();
31 +
32 + // detact.Aggregation
33 + Aggregation agg = new Aggregation();
34 +
35 + Dataset<Row> dataset = Utill.loadCSVDataSet(data_path, spark);
36 + dataset = agg.changeTimestempToLong(dataset);
37 + dataset = agg.averageValidClickCount(dataset);
38 + dataset = agg.clickTimeDelta(dataset);
39 + dataset = agg.countClickInTenMinutes(dataset);
40 +
41 + VectorAssembler assembler = new VectorAssembler()
42 + .setInputCols(new String[]{
43 + "ip",
44 + "app",
45 + "device",
46 + "os",
47 + "channel",
48 + "utc_click_time",
49 + "avg_valid_click_count",
50 + "click_time_delta",
51 + "count_click_in_ten_mins"
52 + })
53 + .setOutputCol("features");
54 +
55 + Dataset<Row> output = assembler.transform(dataset);
56 +
57 + VectorIndexerModel featureIndexer = new VectorIndexer()
58 + .setInputCol("features")
59 + .setOutputCol("indexedFeatures")
60 + .setMaxCategories(2)
61 + .fit(output);
62 +
63 + // Split the result into training and test sets (30% held out for testing).
64 + Dataset<Row>[] splits = output.randomSplit(new double[]{0.7, 0.3});
65 + Dataset<Row> trainingData = splits[0];
66 + Dataset<Row> testData = splits[1];
67 +
68 + // Train a detact.DecisionTreeionTree model.
69 + DecisionTreeRegressor dt = new DecisionTreeRegressor()
70 + .setFeaturesCol("indexedFeatures")
71 + .setLabelCol("is_attributed")
72 + .setMaxDepth(10);
73 +
74 + // Chain indexer and tree in a Pipeline.
75 + Pipeline pipeline = new Pipeline()
76 + .setStages(new PipelineStage[]{featureIndexer, dt});
77 +
78 + // Train model. This also runs the indexer.
79 + PipelineModel model = pipeline.fit(trainingData);
80 +
81 + // Make predictions.
82 + Dataset<Row> predictions = model.transform(testData);
83 +
84 + // Select example rows to display.
85 + predictions.select("is_attributed", "features").show(5);
86 +
87 + // Select (prediction, true label) and compute test error.
88 + RegressionEvaluator evaluator = new RegressionEvaluator()
89 + .setLabelCol("is_attributed")
90 + .setPredictionCol("prediction")
91 + .setMetricName("rmse");
92 + double rmse = evaluator.evaluate(predictions);
93 + System.out.println("Root Mean Squared Error (RMSE) on test result = " + rmse);
94 +
95 + DecisionTreeRegressionModel treeModel =
96 + (DecisionTreeRegressionModel) (model.stages()[1]);
97 + System.out.println("Learned regression tree model:\n" + treeModel.toDebugString());
98 +
99 + // save model
100 + model.save("./decisionTree.model");
101 + }
4 } 102 }
......