Showing 3 changed files with 130 additions and 4 deletions
pom.xml

@@ -39,6 +39,7 @@
     <groupId>org.apache.maven.plugins</groupId>
     <artifactId>maven-shade-plugin</artifactId>
     <executions>
+      <!-- Aggregation -->
       <execution>
         <id>aggregation</id>
         <goals>
@@ -64,6 +65,7 @@
           </filters>
         </configuration>
       </execution>
+      <!-- Decision Tree -->
       <execution>
         <id>decisionTree</id>
         <goals>
@@ -89,6 +91,32 @@
           </filters>
         </configuration>
       </execution>
+      <!-- Main -->
+      <execution>
+        <id>Main</id>
+        <goals>
+          <goal>shade</goal>
+        </goals>
+        <configuration>
+          <outputFile>target/assembly/${project.artifactId}-main.jar</outputFile>
+          <shadedArtifactAttached>true</shadedArtifactAttached>
+          <transformers>
+            <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
+              <mainClass>detact.Main</mainClass>
+            </transformer>
+          </transformers>
+          <filters>
+            <filter>
+              <artifact>*:*</artifact>
+              <excludes>
+                <exclude>META-INF/*.SF</exclude>
+                <exclude>META-INF/*.DSA</exclude>
+                <exclude>META-INF/*.RSA</exclude>
+              </excludes>
+            </filter>
+          </filters>
+        </configuration>
+      </execution>
     </executions>
   </plugin>
 </plugins>
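
The new Main execution follows the same pattern as the existing aggregation and decisionTree executions: it builds a separate shaded jar under target/assembly/, strips the META-INF signature files, and wires detact.Main in as the manifest main class. Presumably a plain "mvn package" now produces a third runnable artifact, started with something like "java -jar target/assembly/<artifactId>-main.jar <data_path>", where <artifactId> is whatever the project's artifactId resolves to and <data_path> is the single argument the new detact.Main (last file below) expects.
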
Aggregation.java

@@ -43,7 +43,7 @@ public class Aggregation {
         Utill.saveCSVDataSet(dataset, result_path);
     }
 
-    private Dataset<Row> changeTimestempToLong(Dataset<Row> dataset){
+    public Dataset<Row> changeTimestempToLong(Dataset<Row> dataset){
         // cast timestamp to long
         Dataset<Row> newDF = dataset.withColumn("utc_click_time", dataset.col("click_time").cast("long"));
         newDF = newDF.withColumn("utc_attributed_time", dataset.col("attributed_time").cast("long"));
@@ -51,7 +51,7 @@ public class Aggregation {
         return newDF;
     }
 
-    private Dataset<Row> averageValidClickCount(Dataset<Row> dataset){
+    public Dataset<Row> averageValidClickCount(Dataset<Row> dataset){
         // set Window partition by 'ip' and 'app' order by 'utc_click_time' select rows between 1st row to current row
         WindowSpec w = Window.partitionBy("ip", "app")
                 .orderBy("utc_click_time")
@@ -65,7 +65,7 @@ public class Aggregation {
         return newDF;
     }
 
-    private Dataset<Row> clickTimeDelta(Dataset<Row> dataset){
+    public Dataset<Row> clickTimeDelta(Dataset<Row> dataset){
         WindowSpec w = Window.partitionBy ("ip")
                 .orderBy("utc_click_time");
 
@@ -77,7 +77,7 @@ public class Aggregation {
         return newDF;
    }
 
-    private Dataset<Row> countClickInTenMinutes(Dataset<Row> dataset){
+    public Dataset<Row> countClickInTenMinutes(Dataset<Row> dataset){
         WindowSpec w = Window.partitionBy("ip")
                 .orderBy("utc_click_time")
                 .rangeBetween(Window.currentRow(),Window.currentRow()+600);
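
Switching these four feature builders from private to public is what lets the new detact.Main (next file) chain them on a freshly loaded CSV, and it also makes them reusable on their own. Below is a minimal reuse sketch, not part of this commit: it assumes a local SparkSession, the existing detact.Utill CSV loader, and the column names used by Main's VectorAssembler; the input path and the hypothetical FeatureOnly class are purely illustrative.

    package detact;

    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.SparkSession;

    // Hypothetical driver showing piecemeal reuse of the now-public helpers.
    public class FeatureOnly {
        public static void main(String[] args) {
            SparkSession spark = SparkSession
                    .builder()
                    .appName("Aggregation feature reuse")
                    .master("local")
                    .getOrCreate();

            Aggregation agg = new Aggregation();

            // Load a click log (illustrative path) and apply just two of the steps.
            Dataset<Row> clicks = Utill.loadCSVDataSet("data/train_sample.csv", spark);
            clicks = agg.changeTimestempToLong(clicks);   // adds utc_click_time / utc_attributed_time
            clicks = agg.averageValidClickCount(clicks);  // presumably adds avg_valid_click_count, as used by Main

            clicks.select("ip", "app", "utc_click_time", "avg_valid_click_count").show(5);

            spark.stop();
        }
    }
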
Main.java

@@ -1,4 +1,102 @@
 package detact;
 
+import org.apache.spark.ml.Pipeline;
+import org.apache.spark.ml.PipelineModel;
+import org.apache.spark.ml.PipelineStage;
+import org.apache.spark.ml.evaluation.RegressionEvaluator;
+import org.apache.spark.ml.feature.VectorAssembler;
+import org.apache.spark.ml.feature.VectorIndexer;
+import org.apache.spark.ml.feature.VectorIndexerModel;
+import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
+import org.apache.spark.ml.regression.DecisionTreeRegressor;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+
 public class Main {
+    public static void main(String[] args) throws Exception {
+        if (args.length != 1) {
+            System.out.println("Usage: java -jar aggregation.jar <data_path>");
+            System.exit(0);
+        }
+
+        String data_path = args[0];
+
+        // Create session
+        SparkSession spark = SparkSession
+                .builder()
+                .appName("Detecting Fraud Clicks")
+                .master("local")
+                .getOrCreate();
+
+        // detact.Aggregation
+        Aggregation agg = new Aggregation();
+
+        Dataset<Row> dataset = Utill.loadCSVDataSet(data_path, spark);
+        dataset = agg.changeTimestempToLong(dataset);
+        dataset = agg.averageValidClickCount(dataset);
+        dataset = agg.clickTimeDelta(dataset);
+        dataset = agg.countClickInTenMinutes(dataset);
+
+        VectorAssembler assembler = new VectorAssembler()
+                .setInputCols(new String[]{
+                        "ip",
+                        "app",
+                        "device",
+                        "os",
+                        "channel",
+                        "utc_click_time",
+                        "avg_valid_click_count",
+                        "click_time_delta",
+                        "count_click_in_ten_mins"
+                })
+                .setOutputCol("features");
+
+        Dataset<Row> output = assembler.transform(dataset);
+
+        VectorIndexerModel featureIndexer = new VectorIndexer()
+                .setInputCol("features")
+                .setOutputCol("indexedFeatures")
+                .setMaxCategories(2)
+                .fit(output);
+
+        // Split the result into training and test sets (30% held out for testing).
+        Dataset<Row>[] splits = output.randomSplit(new double[]{0.7, 0.3});
+        Dataset<Row> trainingData = splits[0];
+        Dataset<Row> testData = splits[1];
+
+        // Train a DecisionTree regression model.
+        DecisionTreeRegressor dt = new DecisionTreeRegressor()
+                .setFeaturesCol("indexedFeatures")
+                .setLabelCol("is_attributed")
+                .setMaxDepth(10);
+
+        // Chain indexer and tree in a Pipeline.
+        Pipeline pipeline = new Pipeline()
+                .setStages(new PipelineStage[]{featureIndexer, dt});
+
+        // Train model. This also runs the indexer.
+        PipelineModel model = pipeline.fit(trainingData);
+
+        // Make predictions.
+        Dataset<Row> predictions = model.transform(testData);
+
+        // Select example rows to display.
+        predictions.select("is_attributed", "features").show(5);
+
+        // Select (prediction, true label) and compute test error.
+        RegressionEvaluator evaluator = new RegressionEvaluator()
+                .setLabelCol("is_attributed")
+                .setPredictionCol("prediction")
+                .setMetricName("rmse");
+        double rmse = evaluator.evaluate(predictions);
+        System.out.println("Root Mean Squared Error (RMSE) on test result = " + rmse);
+
+        DecisionTreeRegressionModel treeModel =
+                (DecisionTreeRegressionModel) (model.stages()[1]);
+        System.out.println("Learned regression tree model:\n" + treeModel.toDebugString());
+
+        // save model
+        model.save("./decisionTree.model");
+    }
 }
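
Because the fitted pipeline is persisted with model.save("./decisionTree.model"), a later run can reload it with PipelineModel.load and score fresh clicks. Below is a minimal sketch of that follow-up step, not part of this commit: the hypothetical Predict class and its input path are illustrative, and the feature columns have to be rebuilt exactly as Main does, since the saved pipeline contains only the VectorIndexer and the decision tree, not the VectorAssembler.

    package detact;

    import org.apache.spark.ml.PipelineModel;
    import org.apache.spark.ml.feature.VectorAssembler;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.SparkSession;

    // Hypothetical scoring driver that reuses the model saved by detact.Main.
    public class Predict {
        public static void main(String[] args) {
            SparkSession spark = SparkSession
                    .builder()
                    .appName("Detecting Fraud Clicks - scoring")
                    .master("local")
                    .getOrCreate();

            // Rebuild the same feature columns the model was trained on.
            Aggregation agg = new Aggregation();
            Dataset<Row> clicks = Utill.loadCSVDataSet("data/new_clicks.csv", spark); // illustrative path
            clicks = agg.changeTimestempToLong(clicks);
            clicks = agg.averageValidClickCount(clicks);
            clicks = agg.clickTimeDelta(clicks);
            clicks = agg.countClickInTenMinutes(clicks);

            Dataset<Row> features = new VectorAssembler()
                    .setInputCols(new String[]{
                            "ip", "app", "device", "os", "channel",
                            "utc_click_time", "avg_valid_click_count",
                            "click_time_delta", "count_click_in_ten_mins"})
                    .setOutputCol("features")
                    .transform(clicks);

            // Reload the persisted indexer + decision tree pipeline and predict.
            PipelineModel model = PipelineModel.load("./decisionTree.model");
            model.transform(features)
                    .select("ip", "app", "prediction")
                    .show(10);

            spark.stop();
        }
    }

Keeping the assembler outside the pipeline works, but adding it as the first pipeline stage in Main would let the saved model consume the aggregated columns directly without repeating the assembly step.
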