신은섭(Shin Eun Seop)

add Main.java

......@@ -39,6 +39,7 @@
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-shade-plugin</artifactId>
<executions>
<!-- Aggregation -->
<execution>
<id>aggregation</id>
<goals>
......@@ -64,6 +65,7 @@
</filters>
</configuration>
</execution>
<!-- Decision Tree -->
<execution>
<id>decisionTree</id>
<goals>
......@@ -89,6 +91,32 @@
</filters>
</configuration>
</execution>
<!-- Main -->
<execution>
<id>Main</id>
<goals>
<goal>shade</goal>
</goals>
<configuration>
<outputFile>target/assembly/${project.artifactId}-main.jar</outputFile>
<shadedArtifactAttached>true</shadedArtifactAttached>
<transformers>
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
<mainClass>detact.Main</mainClass>
</transformer>
</transformers>
<filters>
<filter>
<artifact>*:*</artifact>
<excludes>
<exclude>META-INF/*.SF</exclude>
<exclude>META-INF/*.DSA</exclude>
<exclude>META-INF/*.RSA</exclude>
</excludes>
</filter>
</filters>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
......
......@@ -43,7 +43,7 @@ public class Aggregation {
Utill.saveCSVDataSet(dataset, result_path);
}
private Dataset<Row> changeTimestempToLong(Dataset<Row> dataset){
public Dataset<Row> changeTimestempToLong(Dataset<Row> dataset){
// cast timestamp to long
Dataset<Row> newDF = dataset.withColumn("utc_click_time", dataset.col("click_time").cast("long"));
newDF = newDF.withColumn("utc_attributed_time", dataset.col("attributed_time").cast("long"));
......@@ -51,7 +51,7 @@ public class Aggregation {
return newDF;
}
private Dataset<Row> averageValidClickCount(Dataset<Row> dataset){
public Dataset<Row> averageValidClickCount(Dataset<Row> dataset){
// set Window partition by 'ip' and 'app' order by 'utc_click_time' select rows between 1st row to current row
WindowSpec w = Window.partitionBy("ip", "app")
.orderBy("utc_click_time")
......@@ -65,7 +65,7 @@ public class Aggregation {
return newDF;
}
private Dataset<Row> clickTimeDelta(Dataset<Row> dataset){
public Dataset<Row> clickTimeDelta(Dataset<Row> dataset){
WindowSpec w = Window.partitionBy ("ip")
.orderBy("utc_click_time");
......@@ -77,7 +77,7 @@ public class Aggregation {
return newDF;
}
private Dataset<Row> countClickInTenMinutes(Dataset<Row> dataset){
public Dataset<Row> countClickInTenMinutes(Dataset<Row> dataset){
WindowSpec w = Window.partitionBy("ip")
.orderBy("utc_click_time")
.rangeBetween(Window.currentRow(),Window.currentRow()+600);
......
package detact;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.feature.VectorIndexer;
import org.apache.spark.ml.feature.VectorIndexerModel;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.regression.DecisionTreeRegressor;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
public class Main {
public static void main(String[] args) throws Exception{
if (args.length != 1) {
System.out.println("Usage: java -jar aggregation.jar <data_path>");
System.exit(0);
}
String data_path = args[0];
//Create Session
SparkSession spark = SparkSession
.builder()
.appName("Detecting Fraud Clicks")
.master("local")
.getOrCreate();
// detact.Aggregation
Aggregation agg = new Aggregation();
Dataset<Row> dataset = Utill.loadCSVDataSet(data_path, spark);
dataset = agg.changeTimestempToLong(dataset);
dataset = agg.averageValidClickCount(dataset);
dataset = agg.clickTimeDelta(dataset);
dataset = agg.countClickInTenMinutes(dataset);
VectorAssembler assembler = new VectorAssembler()
.setInputCols(new String[]{
"ip",
"app",
"device",
"os",
"channel",
"utc_click_time",
"avg_valid_click_count",
"click_time_delta",
"count_click_in_ten_mins"
})
.setOutputCol("features");
Dataset<Row> output = assembler.transform(dataset);
VectorIndexerModel featureIndexer = new VectorIndexer()
.setInputCol("features")
.setOutputCol("indexedFeatures")
.setMaxCategories(2)
.fit(output);
// Split the result into training and test sets (30% held out for testing).
Dataset<Row>[] splits = output.randomSplit(new double[]{0.7, 0.3});
Dataset<Row> trainingData = splits[0];
Dataset<Row> testData = splits[1];
// Train a detact.DecisionTreeionTree model.
DecisionTreeRegressor dt = new DecisionTreeRegressor()
.setFeaturesCol("indexedFeatures")
.setLabelCol("is_attributed")
.setMaxDepth(10);
// Chain indexer and tree in a Pipeline.
Pipeline pipeline = new Pipeline()
.setStages(new PipelineStage[]{featureIndexer, dt});
// Train model. This also runs the indexer.
PipelineModel model = pipeline.fit(trainingData);
// Make predictions.
Dataset<Row> predictions = model.transform(testData);
// Select example rows to display.
predictions.select("is_attributed", "features").show(5);
// Select (prediction, true label) and compute test error.
RegressionEvaluator evaluator = new RegressionEvaluator()
.setLabelCol("is_attributed")
.setPredictionCol("prediction")
.setMetricName("rmse");
double rmse = evaluator.evaluate(predictions);
System.out.println("Root Mean Squared Error (RMSE) on test result = " + rmse);
DecisionTreeRegressionModel treeModel =
(DecisionTreeRegressionModel) (model.stages()[1]);
System.out.println("Learned regression tree model:\n" + treeModel.toDebugString());
// save model
model.save("./decisionTree.model");
}
}
......