add Aggregation class to aggregate new features
Java-Cesco/Detecting_fraud_clicks#3
Showing
1 changed file
with
76 additions
and
0 deletions
src/main/java/Aggregation.java
0 → 100644
| 1 | +import org.apache.spark.sql.Dataset; | ||
| 2 | +import org.apache.spark.sql.Row; | ||
| 3 | +import org.apache.spark.sql.SparkSession; | ||
| 4 | +import org.apache.spark.sql.expressions.Window; | ||
| 5 | +import org.apache.spark.sql.expressions.WindowSpec; | ||
| 6 | + | ||
| 7 | +import static org.apache.spark.sql.functions.*; | ||
| 8 | +import static org.apache.spark.sql.functions.lit; | ||
| 9 | +import static org.apache.spark.sql.functions.when; | ||
| 10 | + | ||
| 11 | +public class Aggregation { | ||
| 12 | + | ||
| 13 | + public static void main(String[] args) throws Exception { | ||
| 14 | + | ||
| 15 | + //Create Session | ||
| 16 | + SparkSession spark = SparkSession | ||
| 17 | + .builder() | ||
| 18 | + .appName("Detecting Fraud Clicks") | ||
| 19 | + .master("local") | ||
| 20 | + .getOrCreate(); | ||
| 21 | + | ||
| 22 | + Aggregation agg = new Aggregation(); | ||
| 23 | + | ||
| 24 | + Dataset<Row> dataset = agg.loadCSVDataSet("./train_sample.csv", spark); | ||
| 25 | + dataset = agg.changeTimestempToLong(dataset); | ||
| 26 | + dataset = agg.averageValidClickCount(dataset); | ||
| 27 | + dataset = agg.clickTimeDelta(dataset); | ||
| 28 | + | ||
| 29 | + dataset.where("ip == '5348' and app == '19'").show(); | ||
| 30 | + | ||
| 31 | + } | ||
| 32 | + | ||
| 33 | + | ||
| 34 | + private Dataset<Row> loadCSVDataSet(String path, SparkSession spark){ | ||
| 35 | + // Read SCV to DataSet | ||
| 36 | + Dataset<Row> dataset = spark.read().format("csv") | ||
| 37 | + .option("inferSchema", "true") | ||
| 38 | + .option("header", "true") | ||
| 39 | + .load("train_sample.csv"); | ||
| 40 | + return dataset; | ||
| 41 | + } | ||
| 42 | + | ||
| 43 | + private Dataset<Row> changeTimestempToLong(Dataset<Row> dataset){ | ||
| 44 | + // cast timestamp to long | ||
| 45 | + Dataset<Row> newDF = dataset.withColumn("utc_click_time", dataset.col("click_time").cast("long")); | ||
| 46 | + newDF = newDF.withColumn("utc_attributed_time", dataset.col("attributed_time").cast("long")); | ||
| 47 | + newDF = newDF.drop("click_time").drop("attributed_time"); | ||
| 48 | + return newDF; | ||
| 49 | + } | ||
| 50 | + | ||
| 51 | + private Dataset<Row> averageValidClickCount(Dataset<Row> dataset){ | ||
| 52 | + // set Window partition by 'ip' and 'app' order by 'utc_click_time' select rows between 1st row to current row | ||
| 53 | + WindowSpec w = Window.partitionBy("ip", "app") | ||
| 54 | + .orderBy("utc_click_time") | ||
| 55 | + .rowsBetween(Window.unboundedPreceding(), Window.currentRow()); | ||
| 56 | + | ||
| 57 | + // aggregation | ||
| 58 | + Dataset<Row> newDF = dataset.withColumn("cum_count_click", count("utc_click_time").over(w)); | ||
| 59 | + newDF = newDF.withColumn("cum_sum_attributed", sum("is_attributed").over(w)); | ||
| 60 | + newDF = newDF.withColumn("avg_valid_click_count", col("cum_sum_attributed").divide(col("cum_count_click"))); | ||
| 61 | + newDF = newDF.drop("cum_count_click", "cum_sum_attributed"); | ||
| 62 | + return newDF; | ||
| 63 | + } | ||
| 64 | + | ||
| 65 | + private Dataset<Row> clickTimeDelta(Dataset<Row> dataset){ | ||
| 66 | + WindowSpec w = Window.partitionBy ("ip") | ||
| 67 | + .orderBy("utc_click_time"); | ||
| 68 | + | ||
| 69 | + Dataset<Row> newDF = dataset.withColumn("lag(utc_click_time)", lag("utc_click_time",1).over(w)); | ||
| 70 | + newDF = newDF.withColumn("click_time_delta", when(col("lag(utc_click_time)").isNull(), | ||
| 71 | + lit(0)).otherwise(col("utc_click_time")).minus(when(col("lag(utc_click_time)").isNull(), | ||
| 72 | + lit(0)).otherwise(col("lag(utc_click_time)")))); | ||
| 73 | + newDF = newDF.drop("lag(utc_click_time)"); | ||
| 74 | + return newDF; | ||
| 75 | + } | ||
| 76 | +} |
-
Please register or login to post a comment