신은섭(Shin Eun Seop)

apply ml

......@@ -19,7 +19,12 @@
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_2.11</artifactId>
<version>2.2.0</version>
<version>2.3.0</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_2.11</artifactId>
<version>2.3.0</version>
</dependency>
</dependencies>
......
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.feature.VectorIndexer;
import org.apache.spark.ml.feature.VectorIndexerModel;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.regression.DecisionTreeRegressor;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.StructType;
import scala.Serializable;
import scala.Tuple2;
import java.util.*;
// ml
//ip,app,device,os,channel,click_time,attributed_time,is_attributed
//87540,12,1,13,497,2017-11-07 09:30:38,,0
class Record implements Serializable {
Integer ip;
Integer app;
Integer device;
Integer os;
Integer channel;
Calendar clickTime;
Calendar attributedTime;
Boolean isAttributed;
Integer clickInTenMins;
// constructor , getters and setters
public Record(int pIp, int pApp, int pDevice, int pOs, int pChannel, Calendar pClickTime, Calendar pAttributedTime, boolean pIsAttributed) {
ip = new Integer(pIp);
app = new Integer(pApp);
device = new Integer(pDevice);
os = new Integer(pOs);
channel = new Integer(pChannel);
clickTime = pClickTime;
attributedTime = pAttributedTime;
isAttributed = new Boolean(pIsAttributed);
clickInTenMins = new Integer(0);
}
public Record(int pIp, int pApp, int pDevice, int pOs, int pChannel, Calendar pClickTime, Calendar pAttributedTime, boolean pIsAttributed, int pClickInTenMins) {
ip = new Integer(pIp);
app = new Integer(pApp);
device = new Integer(pDevice);
os = new Integer(pOs);
channel = new Integer(pChannel);
clickTime = pClickTime;
attributedTime = pAttributedTime;
isAttributed = new Boolean(pIsAttributed);
clickInTenMins = new Integer(pClickInTenMins);
}
}
class RecordComparator implements Comparator<Record> {
@Override
......@@ -72,14 +45,14 @@ public class MapExample {
static SQLContext sqlContext = new SQLContext(sc);
public static void main(String[] args) throws Exception {
JavaRDD<String> file = sc.textFile("/Users/hyeongyunmun/Dropbox/DetectFraudClick/data/train.csv", 1);
JavaRDD<String> file = sc.textFile("data/train.csv", 1);
final String header = file.first();
JavaRDD<String> data = file.filter(line -> !line.equalsIgnoreCase(header));
JavaRDD<Record> records = data.map(line -> {
String[] fields = line.split(",");
Record sd = new Record(Integer.parseInt(fields[0]), Integer.parseInt(fields[1]), Integer.parseInt(fields[2]), Integer.parseInt(fields[3]), Integer.parseInt(fields[4]), DateUtil.CalendarFromString(fields[5]), DateUtil.CalendarFromString(fields[6]), "1".equalsIgnoreCase(fields[7].trim()));
Record sd = new Record(Integer.parseInt(fields[0]), Integer.parseInt(fields[1]), Integer.parseInt(fields[2]), Integer.parseInt(fields[3]), Integer.parseInt(fields[4]), fields[5], fields[6], Integer.parseInt(fields[7].trim()));
return sd;
});
......@@ -89,9 +62,9 @@ public class MapExample {
// return new Tuple2(value._2(),value._3());
// }}).sortByKey(new TupleComparator()).values();
JavaRDD<Record> firstSorted = records.sortBy(new Function<Record, Calendar>() {
JavaRDD<Record> firstSorted = records.sortBy(new Function<Record, String>() {
@Override
public Calendar call(Record record) throws Exception {
public String call(Record record) throws Exception {
return record.clickTime;
}
}, true, 1);
......@@ -161,23 +134,83 @@ public class MapExample {
Record record = list.get(i);
Calendar recordI = DateUtil.CalendarFromString(record.clickTime);
Calendar addTen = Calendar.getInstance();
addTen.setTime(record.clickTime.getTime());
addTen.setTime(recordI.getTime());
addTen.add(Calendar.MINUTE, 10);
int count = 0;
for (int j = i+1; j < list.size() && list.get(j).ip.compareTo(record.ip) == 0
&& list.get(j).clickTime.compareTo(record.clickTime) > 0 &&list.get(j).clickTime.compareTo(addTen) < 0; j++)
count++;
for (int j = i+1; j < list.size() && list.get(j).ip.compareTo(record.ip) == 0; j++) {
Calendar recordJ = DateUtil.CalendarFromString(list.get(j).clickTime);
if (recordJ.compareTo(recordI) > 0 && recordJ.compareTo(addTen) < 0) {
count++;
} else {
break;
}
}
resultList.add(new Record(record.ip, record.app, record.device, record.os, record.channel, record.clickTime, record.attributedTime, record.isAttributed, count));
}
JavaRDD<Record> result = sc.parallelize(resultList);
result.foreach(record -> {System.out.println(record.ip + " " + record.clickTime.getTime() + " " + record.clickInTenMins);});
// result.foreach(record -> {System.out.println(record.ip + " " + record.clickTime.getTime() + " " + record.clickInTenMins);});
// Automatically identify categorical features, and index them.
// Set maxCategories so features with > 2 distinct values are treated as continuous.
Dataset<Row> resultds = sqlContext.createDataFrame(result, Record.class);
System.out.println("schema start");
resultds.printSchema();
System.out.println("schema end");
VectorAssembler assembler = new VectorAssembler()
.setInputCols(new String[]{"ip", "app", "device", "os", "channel", "clickInTenMins"})
.setOutputCol("features");
Dataset<Row> output = assembler.transform(resultds);
VectorIndexerModel featureIndexer = new VectorIndexer()
.setInputCol("features")
.setOutputCol("indexedFeatures")
.setMaxCategories(2)
.fit(output);
// Split the result into training and test sets (30% held out for testing).
Dataset<Row>[] splits = output.randomSplit(new double[]{0.7, 0.3});
Dataset<Row> trainingData = splits[0];
Dataset<Row> testData = splits[1];
// Train a DecisionTree model.
DecisionTreeRegressor dt = new DecisionTreeRegressor()
.setFeaturesCol("indexedFeatures").setLabelCol("attributed");
// Chain indexer and tree in a Pipeline.
Pipeline pipeline = new Pipeline()
.setStages(new PipelineStage[]{featureIndexer, dt});
// Train model. This also runs the indexer.
PipelineModel model = pipeline.fit(trainingData);
// Make predictions.
Dataset<Row> predictions = model.transform(testData);
// Select example rows to display.
predictions.select("attributed", "features").show(5);
// Select (prediction, true label) and compute test error.
RegressionEvaluator evaluator = new RegressionEvaluator()
.setLabelCol("attributed")
.setPredictionCol("prediction")
.setMetricName("rmse");
double rmse = evaluator.evaluate(predictions);
System.out.println("Root Mean Squared Error (RMSE) on test result = " + rmse);
DecisionTreeRegressionModel treeModel =
(DecisionTreeRegressionModel) (model.stages()[1]);
System.out.println("Learned regression tree model:\n" + treeModel.toDebugString());
}
}
......
import scala.Serializable;
public class Record implements Serializable {
Integer ip;
Integer app;
Integer device;
Integer os;
Integer channel;
String clickTime;
String attributedTime;
Integer isAttributed;
Integer clickInTenMins;
// constructor , getters and setters
public Record(int pIp, int pApp, int pDevice, int pOs, int pChannel, String pClickTime, String pAttributedTime, Integer pIsAttributed) {
ip = new Integer(pIp);
app = new Integer(pApp);
device = new Integer(pDevice);
os = new Integer(pOs);
channel = new Integer(pChannel);
clickTime = pClickTime;
attributedTime = pAttributedTime;
isAttributed = new Integer(pIsAttributed);
clickInTenMins = new Integer(0);
}
public Record(int pIp, int pApp, int pDevice, int pOs, int pChannel, String pClickTime, String pAttributedTime, Integer pIsAttributed, int pClickInTenMins) {
ip = new Integer(pIp);
app = new Integer(pApp);
device = new Integer(pDevice);
os = new Integer(pOs);
channel = new Integer(pChannel);
clickTime = pClickTime;
attributedTime = pAttributedTime;
isAttributed = new Integer(pIsAttributed);
clickInTenMins = new Integer(pClickInTenMins);
}
public Integer getIp() {
return ip;
}
public void setIp(Integer ip) {
this.ip = ip;
}
public Integer getApp() {
return app;
}
public void setApp(Integer app) {
this.app = app;
}
public Integer getDevice() {
return device;
}
public void setDevice(Integer device) {
this.device = device;
}
public Integer getOs() {
return os;
}
public void setOs(Integer os) {
this.os = os;
}
public Integer getChannel() {
return channel;
}
public void setChannel(Integer channel) {
this.channel = channel;
}
public String getClickTime() {
return clickTime;
}
public void setClickTime(String clickTime) {
this.clickTime = clickTime;
}
public String getAttributedTime() {
return attributedTime;
}
public void setAttributedTime(String attributedTime) {
this.attributedTime = attributedTime;
}
public Integer getAttributed() {
return isAttributed;
}
public void setAttributed(Integer attributed) {
isAttributed = attributed;
}
public Integer getClickInTenMins() {
return clickInTenMins;
}
public void setClickInTenMins(Integer clickInTenMins) {
this.clickInTenMins = clickInTenMins;
}
}
\ No newline at end of file