Spark Machine Learning
- Add the Maven dependency:
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_2.11</artifactId>
<version>2.4.5</version>
</dependency>
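The snippets below omit their imports for brevity; the following set (with wildcards) covers everything used in this post:
import java.util.Arrays;
import java.util.List;
import org.apache.spark.ml.feature.*;
import org.apache.spark.ml.linalg.*;
import org.apache.spark.ml.stat.Correlation;
import org.apache.spark.sql.*;
import org.apache.spark.sql.types.*;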
Correlation matrix
The element at row i, column j of the correlation matrix is the correlation coefficient between column i and column j of the input matrix.
# Correlation matrix of height vs. weight
SparkSession spark = SparkSession.builder().appName("ml test")
.master("local[*]")
.getOrCreate();
List<Row> data = Arrays.asList(
RowFactory.create(Vectors.dense(172, 60)),
RowFactory.create(Vectors.dense(175, 65)),
RowFactory.create(Vectors.dense(180, 70)),
RowFactory.create(Vectors.dense(190, 80))
);
StructType schema = new StructType(new StructField[]{
new StructField("features", new VectorUDT(), false, Metadata.empty()),
});
Dataset<Row> df = spark.createDataFrame(data, schema);
df.show();
Row pearson = Correlation.corr(df, "features").head();
System.out.println("Pearson correlation matrix:\n" + pearson);
Row spearman = Correlation.corr(df, "features", "spearman").head();
System.out.println("Spearman correlation matrix:\n" + spearman);
+------------+
| features|
+------------+
|[172.0,60.0]|
|[175.0,65.0]|
|[180.0,70.0]|
|[190.0,80.0]|
+------------+
Pearson correlation matrix:
[1.0 0.9957069831288888
0.9957069831288888 1.0 ]
Spearman correlation matrix:
[1.0 0.9999999999999981
0.9999999999999981 1.0 ]
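To work with the result numerically rather than printing the Row, the matrix can be pulled out of it; a minimal sketch (head() returns a single Row whose only field is the Matrix):
// Extract the Matrix from the result Row.
Matrix m = pearson.getAs(0);
// Correlation between column 0 (height) and column 1 (weight).
System.out.println(m.apply(0, 1)); // ~0.9957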
Extracting, transforming and selecting features
Tokenizer
Tokenizer lowercases text and splits it into words on whitespace.
SparkSession spark = SparkSession.builder().appName("ml test")
.master("local[*]")
.getOrCreate();
List<Row> data = Arrays.asList(
RowFactory.create(0, "Hi I heard about Spark"),
RowFactory.create(1, "I wish Java could use case classes"),
RowFactory.create(2, "Logistic,regression,models,are,neat")
);
StructType schema = new StructType(new StructField[]{
new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
new StructField("sentence", DataTypes.StringType, false, Metadata.empty())
});
Dataset<Row> sentenceDataFrame = spark.createDataFrame(data, schema);
Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words");
Dataset<Row> tokenized = tokenizer.transform(sentenceDataFrame);
tokenized.show(false);
// RegexTokenizer splits on a regex pattern; \W (any non-word character) also breaks the comma-separated sentence apart, unlike the whitespace-only Tokenizer above. It lowercases by default too.
RegexTokenizer regexTokenizer = new RegexTokenizer()
.setInputCol("sentence")
.setOutputCol("words")
.setPattern("\\W"); // alternatively .setPattern("\\w+").setGaps(false);
Dataset<Row> regexTokenized = regexTokenizer.transform(sentenceDataFrame);
regexTokenized.show(false);
+---+-----------------------------------+------------------------------------------+
|id |sentence |words |
+---+-----------------------------------+------------------------------------------+
|0 |Hi I heard about Spark |[hi, i, heard, about, spark] |
|1 |I wish Java could use case classes |[i, wish, java, could, use, case, classes]|
|2 |Logistic,regression,models,are,neat|[logistic,regression,models,are,neat] |
+---+-----------------------------------+------------------------------------------+
+---+-----------------------------------+------------------------------------------+
|id |sentence |words |
+---+-----------------------------------+------------------------------------------+
|0 |Hi I heard about Spark |[hi, i, heard, about, spark] |
|1 |I wish Java could use case classes |[i, wish, java, could, use, case, classes]|
|2 |Logistic,regression,models,are,neat|[logistic, regression, models, are, neat] |
+---+-----------------------------------+------------------------------------------+
StopWordsRemover
StopWordsRemover filters common stop words out of a token array.
SparkSession spark = SparkSession.builder().appName("ml test")
.master("local[*]")
.getOrCreate();
StopWordsRemover remover = new StopWordsRemover()
.setInputCol("raw")
.setOutputCol("filtered");
List<Row> data = Arrays.asList(
RowFactory.create(Arrays.asList("I", "saw", "the", "red", "balloon")),
RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb"))
);
StructType schema = new StructType(new StructField[]{
new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty())
});
Dataset<Row> dataset = spark.createDataFrame(data, schema);
remover.transform(dataset).show(false);
+----------------------------+--------------------+
|raw |filtered |
+----------------------------+--------------------+
|[I, saw, the, red, balloon] |[saw, red, balloon] |
|[Mary, had, a, little, lamb]|[Mary, little, lamb]|
+----------------------------+--------------------+
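By default the remover uses the bundled English stop-word list; another language's list, or a custom String[], can be supplied through setStopWords. A sketch using the documented loadDefaultStopWords helper:
StopWordsRemover frenchRemover = new StopWordsRemover()
    .setStopWords(StopWordsRemover.loadDefaultStopWords("french")) // or pass a custom String[]
    .setInputCol("raw")
    .setOutputCol("filtered");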
$n$-gram
An $n$-gram is a sequence of $n$ consecutive tokens; NGram joins each window of $n$ input words with a space.
SparkSession spark = SparkSession.builder().appName("ml test")
.master("local[*]")
.getOrCreate();
List<Row> data = Arrays.asList(
RowFactory.create(0, Arrays.asList("Hi", "I", "heard", "about", "Spark")),
RowFactory.create(1, Arrays.asList("I", "wish", "Java", "could", "use", "case", "classes")),
RowFactory.create(2, Arrays.asList("Logistic", "regression", "models", "are", "neat"))
);
StructType schema = new StructType(new StructField[]{
new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
new StructField(
"words", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty())
});
Dataset<Row> wordDataFrame = spark.createDataFrame(data, schema);
NGram ngramTransformer = new NGram().setN(2).setInputCol("words").setOutputCol("ngrams");
Dataset<Row> ngramDataFrame = ngramTransformer.transform(wordDataFrame);
ngramDataFrame.select("ngrams").show(false);
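Expected output with setN(2): each row yields the bigrams of its consecutive words (an input shorter than n would produce an empty array).
+------------------------------------------------------------------+
|ngrams                                                            |
+------------------------------------------------------------------+
|[Hi I, I heard, heard about, about Spark]                         |
|[I wish, wish Java, Java could, could use, use case, case classes]|
|[Logistic regression, regression models, models are, are neat]    |
+------------------------------------------------------------------+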
Binarization
Binarizer maps feature values above the given threshold to 1.0 and all others to 0.0.
SparkSession spark = SparkSession.builder().appName("ml test")
.master("local[*]")
.getOrCreate();
List<Row> data = Arrays.asList(
RowFactory.create(0, 0.1),
RowFactory.create(1, 0.8),
RowFactory.create(2, 0.2)
);
StructType schema = new StructType(new StructField[]{
new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
new StructField("feature", DataTypes.DoubleType, false, Metadata.empty())
});
Dataset<Row> continuousDataFrame = spark.createDataFrame(data, schema);
Binarizer binarizer = new Binarizer()
.setInputCol("feature")
.setOutputCol("binarized_feature")
.setThreshold(0.5);
Dataset<Row> binarizedDataFrame = binarizer.transform(continuousDataFrame);
System.out.println("Binarizer output with Threshold = " + binarizer.getThreshold());
binarizedDataFrame.show();
+---+-------+-----------------+
| id|feature|binarized_feature|
+---+-------+-----------------+
| 0| 0.1| 0.0|
| 1| 0.8| 1.0|
| 2| 0.2| 0.0|
+---+-------+-----------------+
PCA (principal component analysis)
SparkSession spark = SparkSession.builder().appName("ml test")
.master("local[*]")
.getOrCreate();
List<Row> data = Arrays.asList(
RowFactory.create(Vectors.sparse(5, new int[]{1, 3}, new double[]{1.0, 7.0})),
RowFactory.create(Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)),
RowFactory.create(Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0))
);
StructType schema = new StructType(new StructField[]{
new StructField("features", new VectorUDT(), false, Metadata.empty()),
});
Dataset<Row> df = spark.createDataFrame(data, schema);
PCAModel pca = new PCA()
.setInputCol("features")
.setOutputCol("pcaFeatures")
.setK(3)
.fit(df);
Dataset<Row> result = pca.transform(df).select("pcaFeatures");
result.show(false);
+-----------------------------------------------------------+
|pcaFeatures |
+-----------------------------------------------------------+
|[1.6485728230883807,-4.013282700516296,-5.524543751369388] |
|[-4.645104331781534,-1.1167972663619026,-5.524543751369387]|
|[-6.428880535676489,-5.337951427775355,-5.524543751369389] |
+-----------------------------------------------------------+
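To judge how much information each component retains, the fitted model exposes the per-component explained variance; a quick addition (a standard PCAModel accessor, not part of the original listing):
// Proportion of variance explained by each of the k = 3 components.
System.out.println(pca.explainedVariance());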
PolynomialExpansion
SparkSession spark = SparkSession.builder().appName("ml test")
.master("local[*]")
.getOrCreate();
PolynomialExpansion polyExpansion = new PolynomialExpansion()
.setInputCol("features")
.setOutputCol("polyFeatures")
.setDegree(2);
List<Row> data = Arrays.asList(
RowFactory.create(Vectors.dense(2.0, 1.0)),
RowFactory.create(Vectors.dense(0.0, 0.0)),
RowFactory.create(Vectors.dense(3.0, -1.0))
);
StructType schema = new StructType(new StructField[]{
new StructField("features", new VectorUDT(), false, Metadata.empty()),
});
Dataset<Row> df = spark.createDataFrame(data, schema);
Dataset<Row> polyDF = polyExpansion.transform(df);
polyDF.show(false);
+----------+-----------------------+
|features |polyFeatures |
+----------+-----------------------+
|[2.0,1.0] |[2.0,4.0,1.0,2.0,1.0] |
|[0.0,0.0] |[0.0,0.0,0.0,0.0,0.0] |
|[3.0,-1.0]|[3.0,9.0,-1.0,-3.0,1.0]|
+----------+-----------------------+
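For a two-feature input $(x, y)$ at degree 2, the expansion is $(x, x^2, y, xy, y^2)$; the first row checks out: $(2.0, 1.0) \mapsto (2.0, 4.0, 1.0, 2.0, 1.0)$.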
TF-IDF
- TF (term frequency): how often a term occurs in a given document; the higher it is, the more important the term is to that document.
TF of term $i$: $TF_i = \frac{N_i}{N}$ ($N_i$: occurrences of term $i$; $N$: total number of terms in the document)
- IDF (inverse document frequency): how widespread a term is across the corpus; a term that occurs in many documents receives a low IDF and therefore little weight.
IDF of term $i$ (Spark MLlib uses the smoothed form): $IDF_i = \log\frac{D+1}{D_i+1}$ ($D$: total number of documents; $D_i$: number of documents containing term $i$)
TF-IDF of term $i$: $TFIDF_i = TF_i \cdot IDF_i$. Note that HashingTF produces raw counts, so in the pipeline below the TF factor is $N_i$ rather than $\frac{N_i}{N}$.
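A quick sanity check against the run below: "i" and "love" occur in all three documents, so $IDF = \log\frac{3+1}{3+1} = 0$ and their TF-IDF is 0, which is why 0.0 entries show up in the features column; "spark" occurs 3 times in only the first document, so $IDF = \log\frac{3+1}{1+1} \approx 0.693$ and its TF-IDF is $3 \times 0.693 \approx 2.08$.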
SparkSession spark = SparkSession.builder().appName("ml test")
.master("local[*]")
.getOrCreate();
List<Row> data = Arrays.asList(
RowFactory.create(0.0, "I love spark spark spark"),
RowFactory.create(1.0, "I love java java is my life"),
RowFactory.create(2.0, "I love C++ C++ is the best")
);
StructType schema = new StructType(new StructField[]{
new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
new StructField("sentence", DataTypes.StringType, false, Metadata.empty())
});
Dataset<Row> sentenceData = spark.createDataFrame(data, schema);
sentenceData.show();
Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words"); // tokenize into lowercase words
Dataset<Row> wordsData = tokenizer.transform(sentenceData);
wordsData.show();
int numFeatures = 20;
HashingTF hashingTF = new HashingTF().setInputCol("words").setOutputCol("rawFeatures").setNumFeatures(numFeatures);
Dataset<Row> featurizedData = hashingTF.transform(wordsData);
featurizedData.show();
IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features");
IDFModel idfModel = idf.fit(featurizedData);
Dataset<Row> rescaledData = idfModel.transform(featurizedData);
rescaledData.show();
+-----+--------------------+
|label| sentence|
+-----+--------------------+
| 0.0|I love spark spar...|
| 1.0|I love java java ...|
| 2.0|I love C++ C++ is...|
+-----+--------------------+
+-----+--------------------+--------------------+
|label| sentence| words|
+-----+--------------------+--------------------+
| 0.0|I love spark spar...|[i, love, spark, ...|
| 1.0|I love java java ...|[i, love, java, j...|
| 2.0|I love C++ C++ is...|[i, love, c++, c+...|
# rawFeatures: (number of hash buckets, term hash indices, term counts)
+-----+--------------------+--------------------+--------------------+
|label| sentence| words| rawFeatures|
+-----+--------------------+--------------------+--------------------+
| 0.0|I love spark spar...|[i, love, spark, ...|(20,[0,5,9],[1.0,...|
| 1.0|I love java java ...|[i, love, java, j...|(20,[0,1,7,9,16,1...|
| 2.0|I love C++ C++ is...|[i, love, c++, c+...|(20,[0,1,3,9,10,1...|
+-----+--------------------+--------------------+--------------------+
+-----+--------------------+--------------------+--------------------+--------------------+
|label| sentence| words| rawFeatures| features|
+-----+--------------------+--------------------+--------------------+--------------------+
| 0.0|I love spark spar...|[i, love, spark, ...|(20,[0,5,9],[1.0,...|(20,[0,5,9],[0.0,...|
| 1.0|I love java java ...|[i, love, java, j...|(20,[0,1,7,9,16,1...|(20,[0,1,7,9,16,1...|
| 2.0|I love C++ C++ is...|[i, love, c++, c+...|(20,[0,1,3,9,10,1...|(20,[0,1,3,9,10,1...|
+-----+--------------------+--------------------+--------------------+--------------------+
CountVectorizer
CountVectorizer learns a vocabulary from the corpus and encodes each document as a sparse vector of term counts; unlike HashingTF, the index-to-term mapping is recoverable from the learned vocabulary.
SparkSession spark = SparkSession.builder().appName("ml test")
.master("local[*]")
.getOrCreate();
List<Row> data = Arrays.asList(
RowFactory.create("I love spark spark spark"),
RowFactory.create("I love java java is my life"),
RowFactory.create("I love C++ C++ is the best")
);
StructType schema = new StructType(new StructField[]{
new StructField("sentence", DataTypes.StringType, false, Metadata.empty())
});
Dataset<Row> sentenceData = spark.createDataFrame(data, schema);
sentenceData.show();
Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words"); // tokenize into lowercase words
Dataset<Row> wordsData = tokenizer.transform(sentenceData);
wordsData.show();
CountVectorizerModel cvModel = new CountVectorizer()
.setInputCol("words")
.setOutputCol("feature")
//.setVocabSize(3) // cap the vocabulary at the 3 most frequent terms
//.setMinDF(2) // drop terms that appear in fewer than 2 documents
.fit(wordsData);
System.out.println(Arrays.asList(cvModel.vocabulary()));
cvModel.transform(wordsData).show(false);
+--------------------+
| sentence|
+--------------------+
|I love spark spar...|
|I love java java ...|
|I love C++ C++ is...|
+--------------------+
+--------------------+--------------------+
| sentence| words|
+--------------------+--------------------+
|I love spark spar...|[i, love, spark, ...|
|I love java java ...|[i, love, java, j...|
|I love C++ C++ is...|[i, love, c++, c+...|
+--------------------+--------------------+
[love, spark, i, c++, java, is, the, life, best, my]
# feature: (vocabulary size, term indices, term counts)
+---------------------------+-----------------------------------+--------------------------------------------+
|sentence |words |feature |
+---------------------------+-----------------------------------+--------------------------------------------+
|I love spark spark spark |[i, love, spark, spark, spark] |(10,[0,1,2],[1.0,3.0,1.0]) |
|I love java java is my life|[i, love, java, java, is, my, life]|(10,[0,2,4,5,7,9],[1.0,1.0,2.0,1.0,1.0,1.0])|
|I love C++ C++ is the best |[i, love, c++, c++, is, the, best] |(10,[0,2,3,5,6,8],[1.0,1.0,2.0,1.0,1.0,1.0])|
+---------------------------+-----------------------------------+--------------------------------------------+
FeatureHasher
FeatureHasher projects a set of columns into a feature vector with the hashing trick: numeric columns contribute their own value at the hashed index of the column name, while string and boolean columns are treated as categorical and contribute 1.0 at the hashed index of "column=value".
SparkSession spark = SparkSession.builder().appName("ml test")
.master("local[*]")
.getOrCreate();
List<Row> data = Arrays.asList(
RowFactory.create(2.2, true, "1", "foo"),
RowFactory.create(3.3, false, "2", "bar"),
RowFactory.create(4.4, false, "3", "baz"),
RowFactory.create(5.5, false, "4", "foo")
);
StructType schema = new StructType(new StructField[]{
new StructField("real", DataTypes.DoubleType, false, Metadata.empty()),
new StructField("bool", DataTypes.BooleanType, false, Metadata.empty()),
new StructField("stringNum", DataTypes.StringType, false, Metadata.empty()),
new StructField("string", DataTypes.StringType, false, Metadata.empty())
});
Dataset<Row> dataset = spark.createDataFrame(data, schema);
FeatureHasher hasher = new FeatureHasher()
.setInputCols(new String[]{"real", "bool", "stringNum", "string"})
.setOutputCol("features");
Dataset<Row> featurized = hasher.transform(dataset);
featurized.show(false);
+----+-----+---------+------+-------------------+
|real|bool |stringNum|string|features           |
+----+-----+---------+------+-------------------+
|2.2 |true |1        |foo   |(2,[0,1],[2.0,2.2])|
|3.3 |false|2        |bar   |(2,[0,1],[1.0,4.3])|
|4.4 |false|3        |baz   |(2,[0,1],[2.0,4.4])|
|5.5 |false|4        |foo   |(2,[0,1],[1.0,6.5])|
+----+-----+---------+------+-------------------+
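Note that FeatureHasher defaults to $2^{18} = 262144$ hash buckets, so the code as written would print much higher-dimensional sparse vectors; the 2-bucket output above presumably came from a run with a shrunken hash space. A sketch of that assumed setting:
FeatureHasher hasher = new FeatureHasher()
    .setNumFeatures(2) // assumed: matches the output shown; the default is 1 << 18
    .setInputCols(new String[]{"real", "bool", "stringNum", "string"})
    .setOutputCol("features");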