Merge branch 'ml' of https://github.com/Java-Cesco/Detecting_fraud_clicks into ml
Showing 4 changed files with 149 additions and 4 deletions
pom.xml
@@ -39,6 +39,7 @@
 <groupId>org.apache.maven.plugins</groupId>
 <artifactId>maven-shade-plugin</artifactId>
 <executions>
+  <!-- Aggregation -->
   <execution>
     <id>aggregation</id>
     <goals>
@@ -51,6 +52,7 @@
         <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
           <mainClass>detact.Aggregation</mainClass>
         </transformer>
+        <transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
       </transformers>
       <filters>
         <filter>
@@ -64,6 +66,7 @@
       </filters>
     </configuration>
   </execution>
+  <!-- Decision Tree -->
   <execution>
     <id>decisionTree</id>
     <goals>
@@ -76,6 +79,34 @@
         <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
           <mainClass>detact.ML.DecisionTree</mainClass>
         </transformer>
+        <transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
+      </transformers>
+      <filters>
+        <filter>
+          <artifact>*:*</artifact>
+          <excludes>
+            <exclude>META-INF/*.SF</exclude>
+            <exclude>META-INF/*.DSA</exclude>
+            <exclude>META-INF/*.RSA</exclude>
+          </excludes>
+        </filter>
+      </filters>
+    </configuration>
+  </execution>
+  <!-- Main -->
+  <execution>
+    <id>Main</id>
+    <goals>
+      <goal>shade</goal>
+    </goals>
+    <configuration>
+      <outputFile>target/assembly/${project.artifactId}-main.jar</outputFile>
+      <shadedArtifactAttached>true</shadedArtifactAttached>
+      <transformers>
+        <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
+          <mainClass>detact.Main</mainClass>
+        </transformer>
+        <transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
       </transformers>
       <filters>
         <filter>
src/main/java/detact/Aggregation.java
@@ -44,7 +44,7 @@ public class Aggregation {
         Utill.saveCSVDataSet(dataset, result_path);
     }
 
-    private Dataset<Row> changeTimestempToLong(Dataset<Row> dataset){
+    public Dataset<Row> changeTimestempToLong(Dataset<Row> dataset){
         // cast timestamp to long
         Dataset<Row> newDF = dataset.withColumn("utc_click_time", dataset.col("click_time").cast("long"));
         newDF = newDF.withColumn("utc_attributed_time", dataset.col("attributed_time").cast("long"));
@@ -52,7 +52,7 @@ public class Aggregation {
         return newDF;
     }
 
-    private Dataset<Row> averageValidClickCount(Dataset<Row> dataset){
+    public Dataset<Row> averageValidClickCount(Dataset<Row> dataset){
         // set Window partition by 'ip' and 'app' order by 'utc_click_time' select rows between 1st row and current row
         WindowSpec w = Window.partitionBy("ip", "app")
                 .orderBy("utc_click_time")
@@ -66,7 +66,7 @@ public class Aggregation {
         return newDF;
     }
 
-    private Dataset<Row> clickTimeDelta(Dataset<Row> dataset){
+    public Dataset<Row> clickTimeDelta(Dataset<Row> dataset){
         WindowSpec w = Window.partitionBy("ip")
                 .orderBy("utc_click_time");
 
@@ -78,7 +78,7 @@ public class Aggregation {
         return newDF;
     }
 
-    private Dataset<Row> countClickInTenMinutes(Dataset<Row> dataset){
+    public Dataset<Row> countClickInTenMinutes(Dataset<Row> dataset){
         WindowSpec w = Window.partitionBy("ip")
                 .orderBy("utc_click_time")
                 .rangeBetween(Window.currentRow(), Window.currentRow() + 600);
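Making these helpers public is what lets the new Main class chain them directly on a freshly loaded dataset. As a reference for the window logic, here is a minimal, self-contained sketch of the ten-minute click count; the class name is hypothetical, and the count(...).over(w) aggregation is an assumption, since the hunk above only shows the window specification:

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.Window;
import org.apache.spark.sql.expressions.WindowSpec;
import static org.apache.spark.sql.functions.count;

// Hypothetical stand-alone helper mirroring countClickInTenMinutes above.
public class TenMinuteWindowSketch {

    // For every click, count how many clicks the same ip makes within the
    // next 600 seconds. Assumes utc_click_time is an epoch value in seconds,
    // as produced by changeTimestempToLong.
    public static Dataset<Row> countClickInTenMinutes(Dataset<Row> dataset) {
        WindowSpec w = Window.partitionBy("ip")
                .orderBy("utc_click_time")
                .rangeBetween(Window.currentRow(), Window.currentRow() + 600);
        return dataset.withColumn("count_click_in_ten_mins",
                count("utc_click_time").over(w));
    }
}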
src/main/java/detact/ML/DecisionTree.java
@@ -107,6 +107,15 @@ public class DecisionTree {
                 (DecisionTreeRegressionModel) (model.stages()[1]);
         System.out.println("Learned regression tree model:\n" + treeModel.toDebugString());
 
+        // save the fitted pipeline model
+        model.save("./decisionTree");
+
+        // load the model back
+        PipelineModel load_model = PipelineModel.load("./decisionTree");
+
+        // make predictions with the loaded model
+        Dataset<Row> load_pred = load_model.transform(testData);
+
     }
 
 }
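One caveat with the save/load added above: the default MLWriter refuses to overwrite, so model.save("./decisionTree") fails when that directory is left over from an earlier run. A minimal sketch of a rerun-friendly round trip is below; the class and method names are hypothetical, and the fitted PipelineModel and test Dataset are passed in rather than taken from the surrounding method:

import java.io.IOException;

import org.apache.spark.ml.PipelineModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

// Hypothetical helper: persist a fitted pipeline, reload it, and predict with
// the loaded copy to verify the round trip.
public class ModelRoundTripSketch {

    public static Dataset<Row> saveLoadAndPredict(PipelineModel model,
                                                  Dataset<Row> testData,
                                                  String path) throws IOException {
        // write().overwrite() replaces an existing model directory; a plain
        // save(path) throws if the path already exists.
        model.write().overwrite().save(path);

        PipelineModel loaded = PipelineModel.load(path);
        return loaded.transform(testData);
    }
}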
src/main/java/detact/Main.java (new file, mode 100644)
package detact;

import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.feature.VectorIndexer;
import org.apache.spark.ml.feature.VectorIndexerModel;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.regression.DecisionTreeRegressor;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class Main {
    public static void main(String[] args) throws Exception {
        if (args.length != 1) {
            System.out.println("Usage: java -jar <main-jar> <data_path>");
            System.exit(0);
        }

        String data_path = args[0];

        // Create Spark session
        SparkSession spark = SparkSession
                .builder()
                .appName("Detecting Fraud Clicks")
                .master("local")
                .getOrCreate();

        // detact.Aggregation: build the aggregated feature columns
        Aggregation agg = new Aggregation();

        Dataset<Row> dataset = Utill.loadCSVDataSet(data_path, spark);
        dataset = agg.changeTimestempToLong(dataset);
        dataset = agg.averageValidClickCount(dataset);
        dataset = agg.clickTimeDelta(dataset);
        dataset = agg.countClickInTenMinutes(dataset);

        // Assemble the raw and aggregated columns into a single feature vector.
        VectorAssembler assembler = new VectorAssembler()
                .setInputCols(new String[]{
                        "ip",
                        "app",
                        "device",
                        "os",
                        "channel",
                        "utc_click_time",
                        "avg_valid_click_count",
                        "click_time_delta",
                        "count_click_in_ten_mins"
                })
                .setOutputCol("features");

        Dataset<Row> output = assembler.transform(dataset);

        // Index features with at most 2 distinct values as categorical.
        VectorIndexerModel featureIndexer = new VectorIndexer()
                .setInputCol("features")
                .setOutputCol("indexedFeatures")
                .setMaxCategories(2)
                .fit(output);

        // Split the result into training and test sets (30% held out for testing).
        Dataset<Row>[] splits = output.randomSplit(new double[]{0.7, 0.3});
        Dataset<Row> trainingData = splits[0];
        Dataset<Row> testData = splits[1];

        // Train a DecisionTree regression model.
        DecisionTreeRegressor dt = new DecisionTreeRegressor()
                .setFeaturesCol("indexedFeatures")
                .setLabelCol("is_attributed")
                .setMaxDepth(10);

        // Chain indexer and tree in a Pipeline.
        Pipeline pipeline = new Pipeline()
                .setStages(new PipelineStage[]{featureIndexer, dt});

        // Train model. This also runs the indexer.
        PipelineModel model = pipeline.fit(trainingData);

        // Save the fitted pipeline model.
        model.save("./decisionTree");

        // Reload it to verify the round trip.
        PipelineModel p_model = PipelineModel.load("./decisionTree");

        // Make predictions with the loaded model.
        Dataset<Row> predictions = p_model.transform(testData);

        // Select example rows to display.
        predictions.select("prediction", "is_attributed", "features").show(5);

        // Select (prediction, true label) and compute test error.
        RegressionEvaluator evaluator = new RegressionEvaluator()
                .setLabelCol("is_attributed")
                .setPredictionCol("prediction")
                .setMetricName("rmse");
        double rmse = evaluator.evaluate(predictions);
        System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse);

        DecisionTreeRegressionModel treeModel =
                (DecisionTreeRegressionModel) (p_model.stages()[1]);
        System.out.println("Learned regression tree model:\n" + treeModel.toDebugString());
    }
}
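The new Main prints the RMSE and the learned tree but does not persist the predictions. If they need to be kept, the Utill helper already used in Aggregation could be reused; the sketch below is hypothetical (the class and method names are made up, and it assumes Utill.saveCSVDataSet writes any Dataset<Row> to CSV, which the diff only shows for the aggregation result):

package detact;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

// Hypothetical helper for exporting test predictions next to the input data.
public class PredictionExportSketch {

    // Drop the vector-typed "features" column, which cannot be written to CSV,
    // and save the remaining scalar columns with the existing Utill helper.
    public static void savePredictions(Dataset<Row> predictions, String outPath) {
        Dataset<Row> toSave = predictions.select(
                "ip", "app", "device", "os", "channel",
                "is_attributed", "prediction");
        Utill.saveCSVDataSet(toSave, outPath);
    }
}

This could be called at the end of main, for example as PredictionExportSketch.savePredictions(predictions, data_path + "_predictions"), where the output path suffix is an arbitrary choice.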