RFormula通过一个R model formula选择一个特定的列。
目前我们支持R算子的一个受限的子集,包括~,.,:,+,-。这些基本的算子是:
~分开target和terms+连接term,+ 0表示删除截距(intercept)-删除term,- 1表示删除截距:交集.除了target之外的所有列
假设a和b是double列,我们用下面简单的例子来证明RFormula的有效性。
y ~ a + b表示模型y ~ w0 + w1 * a + w2 * b,其中w0是截距,w1和w2是系数y ~ a + b + a:b - 1表示模型y ~ w1 * a + w2 * b + w3 * a * b,其中w1,w2,w3是系数
RFormula产生一个特征向量列和一个double或string类型的标签列。比如在线性回归中使用R中的公式时,
字符串输入列是one-hot编码,数值列强制转换为double类型。如果标签列是字符串类型,它将使用StringIndexer转换为double
类型。如果DataFrame中不存在标签列,输出的标签列将通过公式中指定的返回变量来创建。
假设我们有一个DataFrame,它的列名是id, country, hour和clicked。
id | country | hour | clicked
---|---------|------|---------
7 | "US" | 18 | 1.0
8 | "CA" | 12 | 0.0
9 | "NZ" | 15 | 0.0
如果我们用clicked ~ country + hour(基于country和hour来预测clicked)来作用于RFormula,将会得到下面的结果。
id | country | hour | clicked | features | label
---|---------|------|---------|------------------|-------
7 | "US" | 18 | 1.0 | [0.0, 0.0, 18.0] | 1.0
8 | "CA" | 12 | 0.0 | [0.0, 1.0, 12.0] | 0.0
9 | "NZ" | 15 | 0.0 | [1.0, 0.0, 15.0] | 0.0
下面是代码调用的例子。
import org.apache.spark.ml.feature.RFormula
val dataset = spark.createDataFrame(Seq(
(7, "US", 18, 1.0),
(8, "CA", 12, 0.0),
(9, "NZ", 15, 0.0)
)).toDF("id", "country", "hour", "clicked")
val formula = new RFormula()
.setFormula("clicked ~ country + hour")
.setFeaturesCol("features")
.setLabelCol("label")
val output = formula.fit(dataset).transform(dataset)
output.select("features", "label").show()