MapExample.java
3.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.feature.VectorIndexer;
import org.apache.spark.ml.feature.VectorIndexerModel;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.regression.DecisionTreeRegressor;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import scala.Serializable;
import java.util.*;
// ml
public class MapExample {
static SparkConf conf = new SparkConf().setMaster("local[*]").setAppName("Cesco");
static JavaSparkContext sc = new JavaSparkContext(conf);
static SQLContext sqlContext = new SQLContext(sc);
public static void main(String[] args) throws Exception {
// Automatically identify categorical features, and index them.
// Set maxCategories so features with > 4 distinct values are treated as continuous.
Dataset<Row> resultds = sqlContext.createDataFrame(result);
System.out.println("schema start");
resultds.printSchema();
System.out.println("schema end");
VectorAssembler assembler = new VectorAssembler()
.setInputCols(new String[]{"ip", "app", "device", "os", "channel", "clickInTenMins"})
.setOutputCol("features");
Dataset<Row> output = assembler.transform(resultds);
VectorIndexerModel featureIndexer = new VectorIndexer()
.setInputCol("features")
.setOutputCol("indexedFeatures")
.setMaxCategories(2)
.fit(output);
// Split the result into training and test sets (30% held out for testing).
Dataset<Row>[] splits = output.randomSplit(new double[]{0.7, 0.3});
Dataset<Row> trainingData = splits[0];
Dataset<Row> testData = splits[1];
// Train a DecisionTree model.
DecisionTreeRegressor dt = new DecisionTreeRegressor()
.setFeaturesCol("indexedFeatures").setLabelCol("attributed");
// Chain indexer and tree in a Pipeline.
Pipeline pipeline = new Pipeline()
.setStages(new PipelineStage[]{featureIndexer, dt});
// Train model. This also runs the indexer.
PipelineModel model = pipeline.fit(trainingData);
// Make predictions.
Dataset<Row> predictions = model.transform(testData);
// Select example rows to display.
predictions.select("attributed", "features").show(5);
// Select (prediction, true label) and compute test error.
RegressionEvaluator evaluator = new RegressionEvaluator()
.setLabelCol("attributed")
.setPredictionCol("prediction")
.setMetricName("rmse");
double rmse = evaluator.evaluate(predictions);
System.out.println("Root Mean Squared Error (RMSE) on test result = " + rmse);
DecisionTreeRegressionModel treeModel =
(DecisionTreeRegressionModel) (model.stages()[1]);
System.out.println("Learned regression tree model:\n" + treeModel.toDebugString());
}
}