Spark Machine Learning

  • Add the MLlib dependency (the _2.11 suffix is the Scala version and should match the Scala version of your Spark build):
<dependency>
  <groupId>org.apache.spark</groupId>
  <artifactId>spark-mllib_2.11</artifactId>
  <version>2.4.5</version>
</dependency>

Correlation matrix

The element at row i, column j of the correlation matrix is the correlation coefficient between column i and column j of the original data matrix.
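As a reminder, for two columns $X$ and $Y$ the Pearson coefficient is the covariance normalized by the two standard deviations, and Spearman is the same quantity computed on the ranks of the values:

$\rho_{X,Y} = \frac{cov(X,Y)}{\sigma_X \sigma_Y} = \frac{\sum_{k}(x_k-\bar{x})(y_k-\bar{y})}{\sqrt{\sum_{k}(x_k-\bar{x})^2}\sqrt{\sum_{k}(y_k-\bar{y})^2}}$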

# Correlation matrix of height and weight
SparkSession spark = SparkSession.builder().appName("ml test")
                 .master("local[*]")
                 .getOrCreate();
List<Row> data = Arrays.asList(
          RowFactory.create(Vectors.dense(172, 60)),
          RowFactory.create(Vectors.dense(175, 65)),
          RowFactory.create(Vectors.dense(180, 70)),
          RowFactory.create(Vectors.dense(190, 80))
);

StructType schema = new StructType(new StructField[]{
  new StructField("features", new VectorUDT(), false, Metadata.empty()),
});

Dataset<Row> df = spark.createDataFrame(data, schema);
df.show();
Row pearson = Correlation.corr(df, "features").head();
System.out.println("Pearson correlation matrix:\n" + pearson);

Row spearman = Correlation.corr(df, "features", "spearman").head();
System.out.println("Spearman correlation matrix:\n" + spearman);

+------------+
|    features|
+------------+
|[172.0,60.0]|
|[175.0,65.0]|
|[180.0,70.0]|
|[190.0,80.0]|
+------------+

Pearson correlation matrix:
[1.0                 0.9957069831288888  
0.9957069831288888  1.0                 ]

Spearman correlation matrix:
[1.0                 0.9999999999999981  
0.9999999999999981  1.0                 ]

Extracting, transforming and selecting features

Tokenizer

Tokenizer lowercases a sentence and splits it on whitespace; RegexTokenizer splits on a regular expression instead (here \W, any non-word character), which also handles the comma-separated third sentence.

SparkSession spark = SparkSession.builder().appName("ml test")
                 .master("local[*]")
                 .getOrCreate();
List<Row> data = Arrays.asList(
                RowFactory.create(0, "Hi I heard about Spark"),
                RowFactory.create(1, "I wish Java could use case classes"),
                RowFactory.create(2, "Logistic,regression,models,are,neat")
        );

StructType schema = new StructType(new StructField[]{
        new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
        new StructField("sentence", DataTypes.StringType, false, Metadata.empty())
});

Dataset<Row> sentenceDataFrame = spark.createDataFrame(data, schema);

Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words");
Dataset<Row> tokenized = tokenizer.transform(sentenceDataFrame);
tokenized.show(false);

// Regex-based tokenization
RegexTokenizer regexTokenizer = new RegexTokenizer()
        .setInputCol("sentence")
        .setOutputCol("words")
        .setPattern("\\W");  // alternatively .setPattern("\\w+").setGaps(false);
Dataset<Row> regexTokenized = regexTokenizer.transform(sentenceDataFrame);
regexTokenized.show(false);

+---+-----------------------------------+------------------------------------------+
|id |sentence                           |words                                     |
+---+-----------------------------------+------------------------------------------+
|0  |Hi I heard about Spark             |[hi, i, heard, about, spark]              |
|1  |I wish Java could use case classes |[i, wish, java, could, use, case, classes]|
|2  |Logistic,regression,models,are,neat|[logistic,regression,models,are,neat]     |
+---+-----------------------------------+------------------------------------------+

+---+-----------------------------------+------------------------------------------+
|id |sentence                           |words                                     |
+---+-----------------------------------+------------------------------------------+
|0  |Hi I heard about Spark             |[hi, i, heard, about, spark]              |
|1  |I wish Java could use case classes |[i, wish, java, could, use, case, classes]|
|2  |Logistic,regression,models,are,neat|[logistic, regression, models, are, neat] |
+---+-----------------------------------+------------------------------------------+

StopWordsRemover (stop-word removal)

StopWordsRemover drops common stop words from a token array; by default it uses the built-in English list.

SparkSession spark = SparkSession.builder().appName("ml test")
                 .master("local[*]")
                 .getOrCreate();
StopWordsRemover remover = new StopWordsRemover()
                .setInputCol("raw")
                .setOutputCol("filtered");

List<Row> data = Arrays.asList(
        RowFactory.create(Arrays.asList("I", "saw", "the", "red", "balloon")),
        RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb"))
);

StructType schema = new StructType(new StructField[]{
        new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty())
});

Dataset<Row> dataset = spark.createDataFrame(data, schema);
remover.transform(dataset).show(false);

+----------------------------+--------------------+
|raw                         |filtered            |
+----------------------------+--------------------+
|[I, saw, the, red, balloon] |[saw, red, balloon] |
|[Mary, had, a, little, lamb]|[Mary, little, lamb]|
+----------------------------+--------------------+
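As a rough sketch (the extra word added here is purely for illustration), a custom list can be supplied with setStopWords, optionally starting from one of the bundled language lists:

// Start from the bundled English list and append one extra domain-specific word
String[] english = StopWordsRemover.loadDefaultStopWords("english");
String[] custom = Arrays.copyOf(english, english.length + 1);
custom[english.length] = "balloon";  // hypothetical extra stop word

StopWordsRemover customRemover = new StopWordsRemover()
        .setInputCol("raw")
        .setOutputCol("filtered")
        .setStopWords(custom)          // replaces the default list
        .setCaseSensitive(false);      // matching is case-insensitive by default
customRemover.transform(dataset).show(false);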

$n$-gram

NGram converts an array of words into space-delimited strings of $n$ consecutive words; with setN(2) below, "Hi I heard about Spark" produces the bigrams "Hi I", "I heard", "heard about", "about Spark".

SparkSession spark = SparkSession.builder().appName("ml test")
                 .master("local[*]")
                 .getOrCreate();

List<Row> data = Arrays.asList(
        RowFactory.create(0, Arrays.asList("Hi", "I", "heard", "about", "Spark")),
        RowFactory.create(1, Arrays.asList("I", "wish", "Java", "could", "use", "case", "classes")),
        RowFactory.create(2, Arrays.asList("Logistic", "regression", "models", "are", "neat"))
);

StructType schema = new StructType(new StructField[]{
        new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
        new StructField(
                "words", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty())
});

Dataset<Row> wordDataFrame = spark.createDataFrame(data, schema);

NGram ngramTransformer = new NGram().setN(2).setInputCol("words").setOutputCol("ngrams");

Dataset<Row> ngramDataFrame = ngramTransformer.transform(wordDataFrame);
ngramDataFrame.select("ngrams").show(false);

Binarization
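Binarizer thresholds a numeric column into 0/1 values. With threshold $t$ each input value $x$ is mapped as

$binarize(x) = \begin{cases} 1.0, & x > t \\ 0.0, & x \le t \end{cases}$

so in the run below only $0.8 > 0.5$ becomes $1.0$.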

SparkSession spark = SparkSession.builder().appName("ml test")
                 .master("local[*]")
                 .getOrCreate();
List<Row> data = Arrays.asList(
        RowFactory.create(0, 0.1),
        RowFactory.create(1, 0.8),
        RowFactory.create(2, 0.2)
);
StructType schema = new StructType(new StructField[]{
        new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
        new StructField("feature", DataTypes.DoubleType, false, Metadata.empty())
});
Dataset<Row> continuousDataFrame = spark.createDataFrame(data, schema);

Binarizer binarizer = new Binarizer()
        .setInputCol("feature")
        .setOutputCol("binarized_feature")
        .setThreshold(0.5);

Dataset<Row> binarizedDataFrame = binarizer.transform(continuousDataFrame);

System.out.println("Binarizer output with Threshold = " + binarizer.getThreshold());
binarizedDataFrame.show();

+---+-------+-----------------+
| id|feature|binarized_feature|
+---+-------+-----------------+
|  0|    0.1|              0.0|
|  1|    0.8|              1.0|
|  2|    0.2|              0.0|
+---+-------+-----------------+

PCA (Principal Component Analysis)

PCA projects the 5-dimensional feature vectors below onto their top $k$ principal components ($k = 3$ here).

SparkSession spark = SparkSession.builder().appName("ml test")
                 .master("local[*]")
                 .getOrCreate();
List<Row> data = Arrays.asList(
        RowFactory.create(Vectors.sparse(5, new int[]{1, 3}, new double[]{1.0, 7.0})),
        RowFactory.create(Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)),
        RowFactory.create(Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0))
);

StructType schema = new StructType(new StructField[]{
        new StructField("features", new VectorUDT(), false, Metadata.empty()),
});

Dataset<Row> df = spark.createDataFrame(data, schema);

PCAModel pca = new PCA()
        .setInputCol("features")
        .setOutputCol("pcaFeatures")
        .setK(3)
        .fit(df);

Dataset<Row> result = pca.transform(df).select("pcaFeatures");
result.show(false);

+-----------------------------------------------------------+
|pcaFeatures                                                |
+-----------------------------------------------------------+
|[1.6485728230883807,-4.013282700516296,-5.524543751369388] |
|[-4.645104331781534,-1.1167972663619026,-5.524543751369387]|
|[-6.428880535676489,-5.337951427775355,-5.524543751369389] |
+-----------------------------------------------------------+
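As a small follow-up sketch, the fitted PCAModel also exposes the principal-component matrix and the fraction of variance each component explains:

// Principal components (one column per component) and the explained-variance ratios
System.out.println("Principal components:\n" + pca.pc());
System.out.println("Explained variance: " + pca.explainedVariance());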

PolynomialExpansion (polynomial expansion)

PolynomialExpansion expands a feature vector into the polynomial space of the given degree (degree 2 below); the exact mapping is spelled out after the output.

SparkSession spark = SparkSession.builder().appName("ml test")
                 .master("local[*]")
                 .getOrCreate();
PolynomialExpansion polyExpansion = new PolynomialExpansion()
                .setInputCol("features")
                .setOutputCol("polyFeatures")
                .setDegree(2);

List<Row> data = Arrays.asList(
        RowFactory.create(Vectors.dense(2.0, 1.0)),
        RowFactory.create(Vectors.dense(0.0, 0.0)),
        RowFactory.create(Vectors.dense(3.0, -1.0))
);
StructType schema = new StructType(new StructField[]{
        new StructField("features", new VectorUDT(), false, Metadata.empty()),
});
Dataset<Row> df = spark.createDataFrame(data, schema);

Dataset<Row> polyDF = polyExpansion.transform(df);
polyDF.show(false);

+----------+-----------------------+
|features  |polyFeatures           |
+----------+-----------------------+
|[2.0,1.0] |[2.0,4.0,1.0,2.0,1.0]  |
|[0.0,0.0] |[0.0,0.0,0.0,0.0,0.0]  |
|[3.0,-1.0]|[3.0,9.0,-1.0,-3.0,1.0]|
+----------+-----------------------+
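The degree-2 expansion of a two-dimensional input can be read off the table above: the output columns are

$(x, y) \mapsto (x,\ x^2,\ y,\ xy,\ y^2)$

e.g. $(2.0, 1.0) \mapsto (2.0, 4.0, 1.0, 2.0, 1.0)$.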

TF-IDF

  • TF (term frequency): how often a term occurs in a given document; the higher the value, the more important the term is within that document.

TF of term $i$: $TF_i = \frac{N_i}{N}$ ($N_i$: number of occurrences of term $i$ in the document, $N$: total number of terms in the document)

  • IDF (inverse document frequency): measures how rare a term is across the corpus; the more documents contain the term, the lower its IDF and the less discriminative it is.

IDF of term $i$: $IDF_i = \log\frac{D+1}{D_i+1}$ ($D$: total number of documents, $D_i$: number of documents containing term $i$; this is the smoothed form used by Spark's IDF)

TF-IDF of term $i$: $\mathrm{TF\text{-}IDF}_i = TF_i \times IDF_i$
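A quick worked example with the three sentences used below: "spark" occurs 3 times in the 5-token document "I love spark spark spark", so $TF_{spark} = \frac{3}{5} = 0.6$; it appears in $D_i = 1$ of $D = 3$ documents, so $IDF_{spark} = \log\frac{3+1}{1+1} = \log 2 \approx 0.693$ and $\mathrm{TF\text{-}IDF}_{spark} \approx 0.42$. Terms that occur in every document, such as "i" and "love", get $IDF = \log\frac{4}{4} = 0$, which is why some entries of the features column in the output further below are 0.0. (Note that Spark's HashingTF stores raw term counts rather than the normalized $\frac{N_i}{N}$, so the value the pipeline actually produces for "spark" is $3 \times 0.693 \approx 2.08$.)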

SparkSession spark = SparkSession.builder().appName("ml test")
                 .master("local[*]")
                 .getOrCreate();
List<Row> data = Arrays.asList(
                  RowFactory.create(0.0, "I love spark spark spark"),
                  RowFactory.create(1.0, "I love java java is my life"),
                  RowFactory.create(2.0, "I love C++ C++ is the best")
        );

StructType schema = new StructType(new StructField[]{
          new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
          new StructField("sentence", DataTypes.StringType, false, Metadata.empty())
});
Dataset<Row> sentenceData = spark.createDataFrame(data, schema);
sentenceData.show();

Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words"); // tokenize
Dataset<Row> wordsData = tokenizer.transform(sentenceData);
wordsData.show();

int numFeatures = 20;
HashingTF hashingTF = new HashingTF().setInputCol("words").setOutputCol("rawFeatures").setNumFeatures(numFeatures);
Dataset<Row> featurizedData = hashingTF.transform(wordsData);
featurizedData.show();

IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features");
IDFModel idfModel = idf.fit(featurizedData);


Dataset<Row> rescaledData = idfModel.transform(featurizedData);
rescaledData.show();


+-----+--------------------+
|label|            sentence|
+-----+--------------------+
|  0.0|I love spark spar...|
|  1.0|I love java java ...|
|  2.0|I love C++ C++ is...|
+-----+--------------------+

+-----+--------------------+--------------------+
|label|            sentence|               words|
+-----+--------------------+--------------------+
|  0.0|I love spark spar...|[i, love, spark, ...|
|  1.0|I love java java ...|[i, love, java, j...|
|  2.0|I love C++ C++ is...|[i, love, c++, c+...|
+-----+--------------------+--------------------+

# rawFeatures: number of hash buckets, hashed term indices, term counts
+-----+--------------------+--------------------+--------------------+
|label|            sentence|               words|         rawFeatures|
+-----+--------------------+--------------------+--------------------+
|  0.0|I love spark spar...|[i, love, spark, ...|(20,[0,5,9],[1.0,...|
|  1.0|I love java java ...|[i, love, java, j...|(20,[0,1,7,9,16,1...|
|  2.0|I love C++ C++ is...|[i, love, c++, c+...|(20,[0,1,3,9,10,1...|
+-----+--------------------+--------------------+--------------------+

+-----+--------------------+--------------------+--------------------+--------------------+
|label|            sentence|               words|         rawFeatures|            features|
+-----+--------------------+--------------------+--------------------+--------------------+
|  0.0|I love spark spar...|[i, love, spark, ...|(20,[0,5,9],[1.0,...|(20,[0,5,9],[0.0,...|
|  1.0|I love java java ...|[i, love, java, j...|(20,[0,1,7,9,16,1...|(20,[0,1,7,9,16,1...|
|  2.0|I love C++ C++ is...|[i, love, c++, c+...|(20,[0,1,3,9,10,1...|(20,[0,1,3,9,10,1...|
+-----+--------------------+--------------------+--------------------+--------------------+

CountVectorizer

CountVectorizer builds a vocabulary from the corpus and encodes each document as a vector of term counts over that vocabulary; unlike HashingTF there is no hashing, so every index maps back to an actual term.

SparkSession spark = SparkSession.builder().appName("ml test")
                 .master("local[*]")
                 .getOrCreate();
List<Row> data = Arrays.asList(
                  RowFactory.create("I love spark spark spark"),
                  RowFactory.create("I love java java is my life"),
                  RowFactory.create("I love C++ C++ is the best")
        );
StructType schema = new StructType(new StructField[]{
          new StructField("sentence", DataTypes.StringType, false, Metadata.empty())
});
Dataset<Row> sentenceData = spark.createDataFrame(data, schema);
sentenceData.show();

Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words"); // tokenize
Dataset<Row> wordsData = tokenizer.transform(sentenceData);
wordsData.show();

CountVectorizerModel cvModel = new CountVectorizer()
                              .setInputCol("words")
                              .setOutputCol("feature")
                              //.setVocabSize(3)
                              //.setMinDF(2)
                              .fit(wordsData);
System.out.println(Arrays.asList(cvModel.vocabulary()));
cvModel.transform(wordsData).show(false);

+--------------------+
|            sentence|
+--------------------+
|I love spark spar...|
|I love java java ...|
|I love C++ C++ is...|
+--------------------+
+--------------------+--------------------+
|            sentence|               words|
+--------------------+--------------------+
|I love spark spar...|[i, love, spark, ...|
|I love java java ...|[i, love, java, j...|
|I love C++ C++ is...|[i, love, c++, c+...|
+--------------------+--------------------+

[love, spark, i, c++, java, is, the, life, best, my]

# feature: vocabulary size, term indices, term counts — e.g. (10,[0,1,2],[1.0,3.0,1.0]) means love x1, spark x3, i x1 given the vocabulary printed above
+---------------------------+-----------------------------------+--------------------------------------------+
|sentence                   |words                              |feature                                     |
+---------------------------+-----------------------------------+--------------------------------------------+
|I love spark spark spark   |[i, love, spark, spark, spark]     |(10,[0,1,2],[1.0,3.0,1.0])                  |
|I love java java is my life|[i, love, java, java, is, my, life]|(10,[0,2,4,5,7,9],[1.0,1.0,2.0,1.0,1.0,1.0])|
|I love C++ C++ is the best |[i, love, c++, c++, is, the, best] |(10,[0,2,3,5,6,8],[1.0,1.0,2.0,1.0,1.0,1.0])|
+---------------------------+-----------------------------------+--------------------------------------------+
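The two parameters commented out above control vocabulary pruning: setVocabSize keeps at most that many of the most frequent terms, and setMinDF drops terms that appear in fewer than the given number (or fraction) of documents. A sketch of refitting the model above with both enabled:

// Keep at most 3 terms, each of which must appear in at least 2 of the documents
CountVectorizerModel prunedModel = new CountVectorizer()
        .setInputCol("words")
        .setOutputCol("feature")
        .setVocabSize(3)
        .setMinDF(2)
        .fit(wordsData);
System.out.println(Arrays.asList(prunedModel.vocabulary()));
prunedModel.transform(wordsData).show(false);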

FeatureHasher

FeatureHasher combines a set of columns (numeric, boolean, string) into a single feature vector using the hashing trick: numeric columns keep their value, boolean and string columns are treated as categorical and one-hot encoded, and each feature is assigned a vector index by hashing.

SparkSession spark = SparkSession.builder().appName("ml test")
                 .master("local[*]")
                 .getOrCreate();
List<Row> data = Arrays.asList(
                RowFactory.create(2.2, true, "1", "foo"),
                RowFactory.create(3.3, false, "2", "bar"),
                RowFactory.create(4.4, false, "3", "baz"),
                RowFactory.create(5.5, false, "4", "foo")
        );
StructType schema = new StructType(new StructField[]{
        new StructField("real", DataTypes.DoubleType, false, Metadata.empty()),
        new StructField("bool", DataTypes.BooleanType, false, Metadata.empty()),
        new StructField("stringNum", DataTypes.StringType, false, Metadata.empty()),
        new StructField("string", DataTypes.StringType, false, Metadata.empty())
});
Dataset<Row> dataset = spark.createDataFrame(data, schema);
FeatureHasher hasher = new FeatureHasher()
        .setInputCols(new String[]{"real", "bool", "stringNum", "string"})
        .setOutputCol("features");

Dataset<Row> featurized = hasher.transform(dataset);

featurized.show(false);

+----+-----+---------+------+-------------------+
|real|bool |stringNum|string|features           |
+----+-----+---------+------+-------------------+
|2.2 |true |1        |foo   |(2,[0,1],[2.0,2.2])|
|3.3 |false|2        |bar   |(2,[0,1],[1.0,4.3])|
|4.4 |false|3        |baz   |(2,[0,1],[2.0,4.4])|
|5.5 |false|4        |foo   |(2,[0,1],[1.0,6.5])|
+----+-----+---------+------+-------------------+
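The length of the hashed vector is controlled by numFeatures (by default $2^{18} = 262144$ buckets). A sketch of setting it explicitly, trading memory for a higher chance of hash collisions:

// Hash the same four columns into a much smaller fixed-size vector
FeatureHasher smallHasher = new FeatureHasher()
        .setInputCols(new String[]{"real", "bool", "stringNum", "string"})
        .setOutputCol("features")
        .setNumFeatures(256);
smallHasher.transform(dataset).show(false);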

