KimchiSoup(junu)

Merge branch 'ml' of https://github.com/Java-Cesco/Detecting_fraud_clicks into feauture/GUI_2

@@ -75,3 +75,8 @@ fabric.properties
75 75
76 76 # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml
77 77 hs_err_pid*
78 +
79 +
80 +# datafile
81 +train.zip
82 +train.csv
\ No newline at end of file
...
1 +Detecting_fraud_clicks
\ No newline at end of file
1 +<?xml version="1.0" encoding="UTF-8"?>
2 +<project version="4">
3 + <component name="MarkdownExportedFiles">
4 + <htmlFiles />
5 + <imageFiles />
6 + <otherFiles />
7 + </component>
8 +</project>
\ No newline at end of file
1 +<?xml version="1.0" encoding="UTF-8"?>
2 +<project version="4">
3 + <component name="MarkdownProjectSettings">
4 + <PreviewSettings splitEditorLayout="SPLIT" splitEditorPreview="PREVIEW" useGrayscaleRendering="false" zoomFactor="1.0" maxImageWidth="0" showGitHubPageIfSynced="false" allowBrowsingInPreview="false" synchronizePreviewPosition="true" highlightPreviewType="NONE" highlightFadeOut="5" highlightOnTyping="true" synchronizeSourcePosition="true" verticallyAlignSourceAndPreviewSyncPosition="true" showSearchHighlightsInPreview="false" showSelectionInPreview="true">
5 + <PanelProvider>
6 + <provider providerId="com.vladsch.idea.multimarkdown.editor.swing.html.panel" providerName="Default - Swing" />
7 + </PanelProvider>
8 + </PreviewSettings>
9 + <ParserSettings gitHubSyntaxChange="false">
10 + <PegdownExtensions>
11 + <option name="ABBREVIATIONS" value="false" />
12 + <option name="ANCHORLINKS" value="true" />
13 + <option name="ASIDE" value="false" />
14 + <option name="ATXHEADERSPACE" value="true" />
15 + <option name="AUTOLINKS" value="true" />
16 + <option name="DEFINITIONS" value="false" />
17 + <option name="DEFINITION_BREAK_DOUBLE_BLANK_LINE" value="false" />
18 + <option name="FENCED_CODE_BLOCKS" value="true" />
19 + <option name="FOOTNOTES" value="false" />
20 + <option name="HARDWRAPS" value="false" />
21 + <option name="HTML_DEEP_PARSER" value="false" />
22 + <option name="INSERTED" value="false" />
23 + <option name="QUOTES" value="false" />
24 + <option name="RELAXEDHRULES" value="true" />
25 + <option name="SMARTS" value="false" />
26 + <option name="STRIKETHROUGH" value="true" />
27 + <option name="SUBSCRIPT" value="false" />
28 + <option name="SUPERSCRIPT" value="false" />
29 + <option name="SUPPRESS_HTML_BLOCKS" value="false" />
30 + <option name="SUPPRESS_INLINE_HTML" value="false" />
31 + <option name="TABLES" value="true" />
32 + <option name="TASKLISTITEMS" value="true" />
33 + <option name="TOC" value="false" />
34 + <option name="WIKILINKS" value="true" />
35 + </PegdownExtensions>
36 + <ParserOptions>
37 + <option name="COMMONMARK_LISTS" value="true" />
38 + <option name="DUMMY" value="false" />
39 + <option name="EMOJI_SHORTCUTS" value="true" />
40 + <option name="FLEXMARK_FRONT_MATTER" value="false" />
41 + <option name="GFM_LOOSE_BLANK_LINE_AFTER_ITEM_PARA" value="false" />
42 + <option name="GFM_TABLE_RENDERING" value="true" />
43 + <option name="GITBOOK_URL_ENCODING" value="false" />
44 + <option name="GITHUB_EMOJI_URL" value="false" />
45 + <option name="GITHUB_LISTS" value="false" />
46 + <option name="GITHUB_WIKI_LINKS" value="true" />
47 + <option name="JEKYLL_FRONT_MATTER" value="false" />
48 + <option name="SIM_TOC_BLANK_LINE_SPACER" value="true" />
49 + </ParserOptions>
50 + </ParserSettings>
51 + <HtmlSettings headerTopEnabled="false" headerBottomEnabled="false" bodyTopEnabled="false" bodyBottomEnabled="false" embedUrlContent="false" addPageHeader="true">
52 + <GeneratorProvider>
53 + <provider providerId="com.vladsch.idea.multimarkdown.editor.swing.html.generator" providerName="Default Swing HTML Generator" />
54 + </GeneratorProvider>
55 + <headerTop />
56 + <headerBottom />
57 + <bodyTop />
58 + <bodyBottom />
59 + </HtmlSettings>
60 + <CssSettings previewScheme="UI_SCHEME" cssUri="" isCssUriEnabled="false" isCssTextEnabled="false" isDynamicPageWidth="true">
61 + <StylesheetProvider>
62 + <provider providerId="com.vladsch.idea.multimarkdown.editor.swing.html.css" providerName="Default Swing Stylesheet" />
63 + </StylesheetProvider>
64 + <ScriptProviders />
65 + <cssText />
66 + </CssSettings>
67 + <HtmlExportSettings updateOnSave="false" parentDir="$ProjectFileDir$" targetDir="$ProjectFileDir$" cssDir="" scriptDir="" plainHtml="false" imageDir="" copyLinkedImages="false" imageUniquifyType="0" targetExt="" useTargetExt="false" noCssNoScripts="false" linkToExportedHtml="true" exportOnSettingsChange="true" regenerateOnProjectOpen="false" />
68 + <LinkMapSettings>
69 + <textMaps />
70 + </LinkMapSettings>
71 + </component>
72 +</project>
\ No newline at end of file
1 +<component name="MarkdownNavigator.ProfileManager">
2 + <settings default="" pdf-export="" />
3 +</component>
\ No newline at end of file
@@ -11,4 +11,14 @@
11 <component name="ProjectRootManager" version="2" languageLevel="JDK_1_8" project-jdk-name="1.8" project-jdk-type="JavaSDK"> 11 <component name="ProjectRootManager" version="2" languageLevel="JDK_1_8" project-jdk-name="1.8" project-jdk-type="JavaSDK">
12 <output url="file://$PROJECT_DIR$/out" /> 12 <output url="file://$PROJECT_DIR$/out" />
13 </component> 13 </component>
14 + <component name="MavenProjectsManager">
15 + <option name="originalFiles">
16 + <list>
17 + <option value="$PROJECT_DIR$/pom.xml" />
18 + </list>
19 + </option>
20 + </component>
21 + <component name="ProjectRootManager" version="2" languageLevel="JDK_1_8" default="false" project-jdk-name="1.8" project-jdk-type="JavaSDK">
22 + <output url="file:///tmp" />
23 + </component>
14 24 </project>
\ No newline at end of file
...
1 1 # 2018-JAVA-Cesco
2 2 Detecting fraud clicks using machine learning
3 +
4 +## Execution script
5 +### Amazon Linux
6 +```bash
7 +# update
8 +sudo yum update -y
9 +
10 +# install git
11 +sudo yum install git -y
12 +
13 +# install maven and java 1.8
14 +sudo wget http://repos.fedorapeople.org/repos/dchen/apache-maven/epel-apache-maven.repo -O /etc/yum.repos.d/epel-apache-maven.repo
15 +sudo sed -i s/\$releasever/6/g /etc/yum.repos.d/epel-apache-maven.repo
16 +sudo yum install -y apache-maven java-1.8.0-openjdk-devel.x86_64
17 +
18 +mvn --version
19 +
20 +# clone repo
21 +git clone https://github.com/Java-Cesco/Detecting_fraud_clicks.git
22 +cd Detecting_fraud_clicks
23 +
24 +# maven build
25 +mvn package
26 +
27 +# run
28 +java8 -jar target/assembly/Detecting_fraud_clicks-aggregation.jar train_sample.csv agg_data
29 +java8 -jar target/assembly/Detecting_fraud_clicks-decisionTree.jar agg_data
30 +
31 +```
32 +> NOTE: if you run into a memory error, add the `-Xmx2g` heap option to the `java` command.
\ No newline at end of file
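If a run fails with an out-of-memory error as the note above mentions, the heap can be enlarged with `-Xmx2g`. A minimal sketch, reusing the jar name and arguments from the script above:

```bash
# rerun the aggregation step with a 2 GB heap, as suggested in the memory note above
java8 -Xmx2g -jar target/assembly/Detecting_fraud_clicks-aggregation.jar train_sample.csv agg_data
```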
...
@@ -16,13 +16,16 @@
16 16 <artifactId>spark-core_2.11</artifactId>
17 17 <version>2.3.0</version>
18 18 </dependency>
19 - <!-- https://mavnrepository.com/artifact/org.apache.spark/spark-sql -->
19 + <dependency>
20 + <groupId>org.apache.spark</groupId>
21 + <artifactId>spark-mllib_2.11</artifactId>
22 + <version>2.3.0</version>
23 + </dependency>
20 24 <dependency>
21 25 <groupId>org.apache.spark</groupId>
22 26 <artifactId>spark-sql_2.11</artifactId>
23 27 <version>2.3.0</version>
24 28 </dependency>
25 -
26 29 <dependency>
27 30 <groupId>com.databricks</groupId>
28 31 <artifactId>spark-csv_2.11</artifactId>
@@ -30,19 +33,96 @@
30 33 </dependency>
31 34 </dependencies>
32 35
33 -
34 - <!--maven-compiler-plugin-->
35 36 <build>
36 37 <plugins>
37 38 <plugin>
38 39 <groupId>org.apache.maven.plugins</groupId>
39 - <artifactId>maven-compiler-plugin</artifactId>
40 - <version>3.1</version>
40 + <artifactId>maven-shade-plugin</artifactId>
41 + <executions>
42 + <!-- Aggregation -->
43 + <execution>
44 + <id>aggregation</id>
45 + <goals>
46 + <goal>shade</goal>
47 + </goals>
41 48 <configuration>
42 - <source>1.8</source>
43 - <target>1.8</target>
49 + <outputFile>target/assembly/${project.artifactId}-aggregation.jar</outputFile>
50 + <shadedArtifactAttached>true</shadedArtifactAttached>
51 + <transformers>
52 + <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
53 + <mainClass>detact.Aggregation</mainClass>
54 + </transformer>
55 + <transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
56 + </transformers>
57 + <filters>
58 + <filter>
59 + <artifact>*:*</artifact>
60 + <excludes>
61 + <exclude>META-INF/*.SF</exclude>
62 + <exclude>META-INF/*.DSA</exclude>
63 + <exclude>META-INF/*.RSA</exclude>
64 + </excludes>
65 + </filter>
66 + </filters>
44 67 </configuration>
68 + </execution>
69 + <!-- Decision Tree -->
70 + <execution>
71 + <id>decisionTree</id>
72 + <goals>
73 + <goal>shade</goal>
74 + </goals>
75 + <configuration>
76 + <outputFile>target/assembly/${project.artifactId}-decisionTree.jar</outputFile>
77 + <shadedArtifactAttached>true</shadedArtifactAttached>
78 + <transformers>
79 + <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
80 + <mainClass>detact.ML.DecisionTree</mainClass>
81 + </transformer>
82 + <transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
83 + </transformers>
84 + <filters>
85 + <filter>
86 + <artifact>*:*</artifact>
87 + <excludes>
88 + <exclude>META-INF/*.SF</exclude>
89 + <exclude>META-INF/*.DSA</exclude>
90 + <exclude>META-INF/*.RSA</exclude>
91 + </excludes>
92 + </filter>
93 + </filters>
94 + </configuration>
95 + </execution>
96 + <!-- Main -->
97 + <execution>
98 + <id>Main</id>
99 + <goals>
100 + <goal>shade</goal>
101 + </goals>
102 + <configuration>
103 + <outputFile>target/assembly/${project.artifactId}-main.jar</outputFile>
104 + <shadedArtifactAttached>true</shadedArtifactAttached>
105 + <transformers>
106 + <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
107 + <mainClass>detact.Main</mainClass>
108 + </transformer>
109 + <transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
110 + </transformers>
111 + <filters>
112 + <filter>
113 + <artifact>*:*</artifact>
114 + <excludes>
115 + <exclude>META-INF/*.SF</exclude>
116 + <exclude>META-INF/*.DSA</exclude>
117 + <exclude>META-INF/*.RSA</exclude>
118 + </excludes>
119 + </filter>
120 + </filters>
121 + </configuration>
122 + </execution>
123 + </executions>
45 124 </plugin>
46 125 </plugins>
47 126 </build>
127 +
48 128 </project>
\ No newline at end of file
...
1 +package detact;
2 +
3 +import org.apache.spark.sql.Dataset;
4 +import org.apache.spark.sql.Row;
5 +import org.apache.spark.sql.SparkSession;
6 +import org.apache.spark.sql.expressions.Window;
7 +import org.apache.spark.sql.expressions.WindowSpec;
8 +
9 +import static org.apache.spark.sql.functions.*;
10 +
11 +public class Aggregation {
12 +
13 + public static void main(String[] args) {
14 +
15 + if (args.length != 2) {
16 + System.out.println("Usage: java -jar aggregation.jar <data_path> <result_path>");
17 + System.exit(0);
18 + }
19 +
20 + String data_path = args[0];
21 + String result_path = args[1];
22 +
23 + //Create Session
24 + SparkSession spark = SparkSession
25 + .builder()
26 + .appName("Detecting Fraud Clicks")
27 + .master("local")
28 + .getOrCreate();
29 +
30 + // detact.Aggregation
31 + Aggregation agg = new Aggregation();
32 +
33 + Dataset<Row> dataset = Utill.loadCSVDataSet(data_path, spark);
34 + dataset = agg.changeTimestempToLong(dataset);
35 + dataset = agg.averageValidClickCount(dataset);
36 + dataset = agg.clickTimeDelta(dataset);
37 + dataset = agg.countClickInTenMinutes(dataset);
38 +
39 + // test
40 +// dataset.where("ip == '5348' and app == '19'").show(10);
41 +
42 + // Save to csv
43 + Utill.saveCSVDataSet(dataset, result_path);
44 + }
45 +
46 + public Dataset<Row> changeTimestempToLong(Dataset<Row> dataset){
47 + // cast timestamp to long
48 + Dataset<Row> newDF = dataset.withColumn("utc_click_time", dataset.col("click_time").cast("long"));
49 + newDF = newDF.withColumn("utc_attributed_time", dataset.col("attributed_time").cast("long"));
50 + newDF = newDF.drop("click_time").drop("attributed_time");
51 + return newDF;
52 + }
53 +
54 + public Dataset<Row> averageValidClickCount(Dataset<Row> dataset){
55 + // window partitioned by 'ip' and 'app', ordered by 'utc_click_time', covering rows from the first row up to the current row
56 + WindowSpec w = Window.partitionBy("ip", "app")
57 + .orderBy("utc_click_time")
58 + .rowsBetween(Window.unboundedPreceding(), Window.currentRow());
59 +
60 + // aggregation
61 + Dataset<Row> newDF = dataset.withColumn("cum_count_click", count("utc_click_time").over(w));
62 + newDF = newDF.withColumn("cum_sum_attributed", sum("is_attributed").over(w));
63 + newDF = newDF.withColumn("avg_valid_click_count", col("cum_sum_attributed").divide(col("cum_count_click")));
64 + newDF = newDF.drop("cum_count_click", "cum_sum_attributed");
65 + return newDF;
66 + }
67 +
68 + public Dataset<Row> clickTimeDelta(Dataset<Row> dataset){
69 + WindowSpec w = Window.partitionBy("ip")
70 + .orderBy("utc_click_time");
71 +
72 + Dataset<Row> newDF = dataset.withColumn("lag(utc_click_time)", lag("utc_click_time",1).over(w));
73 + newDF = newDF.withColumn("click_time_delta", when(col("lag(utc_click_time)").isNull(),
74 + lit(0)).otherwise(col("utc_click_time")).minus(when(col("lag(utc_click_time)").isNull(),
75 + lit(0)).otherwise(col("lag(utc_click_time)"))));
76 + newDF = newDF.drop("lag(utc_click_time)");
77 + return newDF;
78 + }
79 +
80 + public Dataset<Row> countClickInTenMinutes(Dataset<Row> dataset){
81 + WindowSpec w = Window.partitionBy("ip")
82 + .orderBy("utc_click_time")
83 + .rangeBetween(Window.currentRow(),Window.currentRow()+600);
84 +
85 + Dataset<Row> newDF = dataset.withColumn("count_click_in_ten_mins",
86 + (count("utc_click_time").over(w)).minus(1));
87 + return newDF;
88 + }
89 +
90 +}
1 +package detact.ML;
2 +
3 +import detact.Aggregation;
4 +import detact.Utill;
5 +import org.apache.spark.ml.Pipeline;
6 +import org.apache.spark.ml.PipelineModel;
7 +import org.apache.spark.ml.PipelineStage;
8 +import org.apache.spark.ml.evaluation.RegressionEvaluator;
9 +import org.apache.spark.ml.feature.VectorAssembler;
10 +import org.apache.spark.ml.feature.VectorIndexer;
11 +import org.apache.spark.ml.feature.VectorIndexerModel;
12 +import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
13 +import org.apache.spark.ml.regression.DecisionTreeRegressor;
14 +import org.apache.spark.sql.Dataset;
15 +import org.apache.spark.sql.Row;
16 +import org.apache.spark.sql.SparkSession;
17 +
18 +
19 +// DecisionTree Model
20 +
21 +public class DecisionTree {
22 +
23 + public static void main(String[] args) throws Exception {
24 +
25 + if (args.length != 1) {
26 + System.out.println("Usage: java -jar decisionTree.jar <agg_path>");
27 + System.exit(0);
28 + }
29 +
30 + String agg_path = args[0];
31 +
32 + //Create Session
33 + SparkSession spark = SparkSession
34 + .builder()
35 + .appName("Detecting Fraud Clicks")
36 + .master("local")
37 + .getOrCreate();
38 +
39 + // load aggregated dataset
40 + Dataset<Row> resultds = Utill.loadCSVDataSet(agg_path, spark);
41 +
42 + // show Dataset schema
43 +// System.out.println("schema start");
44 +// resultds.printSchema();
45 +// String[] cols = resultds.columns();
46 +// for (String col : cols) {
47 +// System.out.println(col);
48 +// }
49 +// System.out.println("schema end");
50 +
51 + VectorAssembler assembler = new VectorAssembler()
52 + .setInputCols(new String[]{
53 + "ip",
54 + "app",
55 + "device",
56 + "os",
57 + "channel",
58 + "utc_click_time",
59 + "avg_valid_click_count",
60 + "click_time_delta",
61 + "count_click_in_ten_mins"
62 + })
63 + .setOutputCol("features");
64 +
65 + Dataset<Row> output = assembler.transform(resultds);
66 +
67 + VectorIndexerModel featureIndexer = new VectorIndexer()
68 + .setInputCol("features")
69 + .setOutputCol("indexedFeatures")
70 + .setMaxCategories(2)
71 + .fit(output);
72 +
73 + // Split the result into training and test sets (30% held out for testing).
74 + Dataset<Row>[] splits = output.randomSplit(new double[]{0.7, 0.3});
75 + Dataset<Row> trainingData = splits[0];
76 + Dataset<Row> testData = splits[1];
77 +
78 + // Train a DecisionTree regression model.
79 + DecisionTreeRegressor dt = new DecisionTreeRegressor()
80 + .setFeaturesCol("indexedFeatures")
81 + .setLabelCol("is_attributed")
82 + .setMaxDepth(10);
83 +
84 + // Chain indexer and tree in a Pipeline.
85 + Pipeline pipeline = new Pipeline()
86 + .setStages(new PipelineStage[]{featureIndexer, dt});
87 +
88 + // Train model. This also runs the indexer.
89 + PipelineModel model = pipeline.fit(trainingData);
90 +
91 + // Make predictions.
92 + Dataset<Row> predictions = model.transform(testData);
93 +
94 + // Select example rows to display.
95 + predictions.select("is_attributed", "features").show(5);
96 +
97 + // Select (prediction, true label) and compute test error.
98 + RegressionEvaluator evaluator = new RegressionEvaluator()
99 + .setLabelCol("is_attributed")
100 + .setPredictionCol("prediction")
101 + .setMetricName("rmse");
102 + double rmse = evaluator.evaluate(predictions);
103 + System.out.println("Root Mean Squared Error (RMSE) on test result = " + rmse);
104 +
105 + DecisionTreeRegressionModel treeModel =
106 + (DecisionTreeRegressionModel) (model.stages()[1]);
107 + System.out.println("Learned regression tree model:\n" + treeModel.toDebugString());
108 +
109 + // save model
110 + model.save("./decisionTree");
111 +
112 + // load model
113 + PipelineModel load_mode = PipelineModel.load("./decisionTree");
114 +
115 + // Make predictions with the reloaded model.
116 + Dataset<Row> load_pred = load_mode.transform(testData);
117 +
118 + }
119 +
120 +}
1 +package detact;
2 +
3 +import org.apache.spark.ml.Pipeline;
4 +import org.apache.spark.ml.PipelineModel;
5 +import org.apache.spark.ml.PipelineStage;
6 +import org.apache.spark.ml.evaluation.RegressionEvaluator;
7 +import org.apache.spark.ml.feature.VectorAssembler;
8 +import org.apache.spark.ml.feature.VectorIndexer;
9 +import org.apache.spark.ml.feature.VectorIndexerModel;
10 +import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
11 +import org.apache.spark.ml.regression.DecisionTreeRegressor;
12 +import org.apache.spark.sql.Dataset;
13 +import org.apache.spark.sql.Row;
14 +import org.apache.spark.sql.SparkSession;
15 +
16 +public class Main {
17 + public static void main(String[] args) throws Exception{
18 + if (args.length != 1) {
19 + System.out.println("Usage: java -jar aggregation.jar <data_path>");
20 + System.exit(0);
21 + }
22 +
23 + String data_path = args[0];
24 +
25 + //Create Session
26 + SparkSession spark = SparkSession
27 + .builder()
28 + .appName("Detecting Fraud Clicks")
29 + .master("local")
30 + .getOrCreate();
31 +
32 + // detact.Aggregation
33 + Aggregation agg = new Aggregation();
34 +
35 + Dataset<Row> dataset = Utill.loadCSVDataSet(data_path, spark);
36 + dataset = agg.changeTimestempToLong(dataset);
37 + dataset = agg.averageValidClickCount(dataset);
38 + dataset = agg.clickTimeDelta(dataset);
39 + dataset = agg.countClickInTenMinutes(dataset);
40 +
41 + VectorAssembler assembler = new VectorAssembler()
42 + .setInputCols(new String[]{
43 + "ip",
44 + "app",
45 + "device",
46 + "os",
47 + "channel",
48 + "utc_click_time",
49 + "avg_valid_click_count",
50 + "click_time_delta",
51 + "count_click_in_ten_mins"
52 + })
53 + .setOutputCol("features");
54 +
55 + Dataset<Row> output = assembler.transform(dataset);
56 +
57 + VectorIndexerModel featureIndexer = new VectorIndexer()
58 + .setInputCol("features")
59 + .setOutputCol("indexedFeatures")
60 + .setMaxCategories(2)
61 + .fit(output);
62 +
63 + // Split the result into training and test sets (30% held out for testing).
64 + Dataset<Row>[] splits = output.randomSplit(new double[]{0.7, 0.3});
65 + Dataset<Row> trainingData = splits[0];
66 + Dataset<Row> testData = splits[1];
67 +
68 + // Train a DecisionTree regression model.
69 + DecisionTreeRegressor dt = new DecisionTreeRegressor()
70 + .setFeaturesCol("indexedFeatures")
71 + .setLabelCol("is_attributed")
72 + .setMaxDepth(10);
73 +
74 + // Chain indexer and tree in a Pipeline.
75 + Pipeline pipeline = new Pipeline()
76 + .setStages(new PipelineStage[]{featureIndexer, dt});
77 +
78 + // Train model. This also runs the indexer.
79 + PipelineModel model = pipeline.fit(trainingData);
80 +
81 + // save model
82 + model.save("./decisionTree");
83 +
84 + PipelineModel p_model = PipelineModel.load("./decisionTree");
85 +
86 + // Make predictions.
87 + Dataset<Row> predictions = p_model.transform(testData);
88 +
89 + // Select example rows to display.
90 + predictions.select("is_attributed", "features").show(5);
91 +
92 + // Select (prediction, true label) and compute test error.
93 + RegressionEvaluator evaluator = new RegressionEvaluator()
94 + .setLabelCol("is_attributed")
95 + .setPredictionCol("prediction")
96 + .setMetricName("rmse");
97 + double rmse = evaluator.evaluate(predictions);
98 + System.out.println("Root Mean Squared Error (RMSE) on test result = " + rmse);
99 +
100 + DecisionTreeRegressionModel treeModel =
101 + (DecisionTreeRegressionModel) (p_model.stages()[1]);
102 + System.out.println("Learned regression tree model:\n" + treeModel.toDebugString());
103 +
104 + }
105 +}
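Besides the two jars used in the README, the shade plugin configuration above also produces `target/assembly/${project.artifactId}-main.jar` with `detact.Main` as its entry point, which chains the aggregation and the decision-tree training in a single pass. The README does not document it; presumably it is launched like the other jars, with the raw click CSV as its single argument:

```bash
# hypothetical one-shot run of the combined jar built by the "Main" shade execution
java8 -jar target/assembly/Detecting_fraud_clicks-main.jar train_sample.csv
```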
1 +package detact;
2 +
3 +import org.apache.spark.sql.Dataset;
4 +import org.apache.spark.sql.Row;
5 +import org.apache.spark.sql.SparkSession;
6 +
7 +public class Utill {
8 +
9 + public static Dataset<Row> loadCSVDataSet(String path, SparkSession spark){
10 + // Read CSV into a Dataset
11 + return spark.read().format("com.databricks.spark.csv")
12 + .option("inferSchema", "true")
13 + .option("header", "true")
14 + .load(path);
15 + }
16 +
17 + public static void saveCSVDataSet(Dataset<Row> dataset, String path){
18 + // Write Dataset to CSV
19 + dataset.write().format("com.databricks.spark.csv")
20 + .option("inferSchema", "true")
21 + .option("header", "true")
22 + .save(path);
23 + }
24 +}