public abstract class Classifier<FeaturesType,E extends Classifier<FeaturesType,E,M>,M extends ClassificationModel<FeaturesType,M>> extends Predictor<FeaturesType,E,M>
Single-label binary or multiclass classification. Classes are indexed {0, 1, ..., numClasses - 1}.
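The recursive type parameter `E extends Classifier<FeaturesType,E,M>` in the class signature lets setters such as `setRawPredictionCol` return the concrete estimator type, so chained calls do not lose subclass methods. The following is a minimal sketch of that pattern only. `SketchClassifier`, `SketchTreeClassifier`, `setMaxDepth` and the default column name are hypothetical stand-ins, and the `FeaturesType` and `M` parameters are omitted for brevity.

```java
// Hypothetical, simplified stand-ins; these are NOT Spark's classes.
abstract class SketchClassifier<E extends SketchClassifier<E>> {
    // Placeholder default, chosen for illustration only.
    private String rawPredictionCol = "rawPrediction";

    // Returns E rather than the base type, so chained calls keep the concrete type.
    @SuppressWarnings("unchecked")
    E setRawPredictionCol(String value) {
        this.rawPredictionCol = value;
        return (E) this;
    }

    String getRawPredictionCol() {
        return rawPredictionCol;
    }
}

// A concrete estimator binds E to itself.
final class SketchTreeClassifier extends SketchClassifier<SketchTreeClassifier> {
    private int maxDepth = 5;

    SketchTreeClassifier setMaxDepth(int value) {
        this.maxDepth = value;
        return this;
    }

    int getMaxDepth() {
        return maxDepth;
    }

    public static void main(String[] args) {
        // setMaxDepth is reachable after setRawPredictionCol because the setter returns E.
        SketchTreeClassifier c =
            new SketchTreeClassifier().setRawPredictionCol("scores").setMaxDepth(3);
        System.out.println(c.getRawPredictionCol() + " " + c.getMaxDepth());
    }
}
```

If the setter returned the base type instead, the chained `setMaxDepth` call in `main` would not compile without a cast.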
| Constructor and Description |
|---|
| Classifier() |
| Modifier and Type | Method and Description |
|---|---|
| protected RDD<LabeledPoint> | extractLabeledPoints(Dataset<?> dataset, int numClasses) - Extract labelCol and featuresCol from the given dataset, and put it in an RDD with strong types. |
| Param<java.lang.String> | featuresCol() - Param for features column name. |
| java.lang.String | getFeaturesCol() |
| java.lang.String | getLabelCol() |
| protected int | getNumClasses(Dataset<?> dataset, int maxNumClasses) - Get the number of classes. |
| java.lang.String | getPredictionCol() |
| java.lang.String | getRawPredictionCol() |
| Param<java.lang.String> | labelCol() - Param for label column name. |
| Param<java.lang.String> | predictionCol() - Param for prediction column name. |
| Param<java.lang.String> | rawPredictionCol() - Param for raw prediction column name. |
| E | setRawPredictionCol(java.lang.String value) |
| StructType | validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType) - Validates and transforms the input schema with the provided param map. |
Methods inherited from class Predictor:
copy, extractLabeledPoints, fit, setFeaturesCol, setLabelCol, setPredictionCol, train, transformSchema

Methods inherited from class PipelineStage:
transformSchema

Methods inherited from class java.lang.Object:
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait

Methods inherited from interface Params:
clear, copy, copyValues, defaultCopy, defaultParamMap, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn, validateParams

Methods inherited from interface Identifiable:
toString, uid

public E setRawPredictionCol(java.lang.String value)
protected RDD<LabeledPoint> extractLabeledPoints(Dataset<?> dataset, int numClasses)
Extract labelCol and featuresCol from the given dataset, and put it in an RDD with strong types.
Parameters:
dataset - DataFrame with columns for labels (NumericType) and features (Vector). Labels are cast to DoubleType.
numClasses - Number of classes label can take. Labels must be integers in the range [0, numClasses).
Throws:
SparkException - if any label is not an integer >= 0

protected int getNumClasses(Dataset<?> dataset, int maxNumClasses)
Get the number of classes. Label validation (ensuring all labels are integers >= 0) needs to be handled elsewhere, such as in extractLabeledPoints().
Parameters:
dataset - Dataset which contains a column labelCol
maxNumClasses - Maximum number of classes allowed when inferred from data. If numClasses is specified in the metadata, then maxNumClasses is ignored.
Throws:
java.lang.IllegalArgumentException - if metadata does not specify numClasses, and the actual numClasses exceeds maxNumClasses
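The label contract described for extractLabeledPoints and getNumClasses can be illustrated in plain Java. This is a sketch under stated assumptions, not Spark's implementation: `LabelSketch`, `isValidLabel` and `numClasses` are hypothetical names, the labels are held in a `double[]` instead of a Dataset, and the inference rule (largest label plus one) is an assumption that follows from classes being indexed {0, 1, ..., numClasses - 1}.

```java
import java.util.OptionalInt;

// Hypothetical helper; it only mirrors the documented contract.
final class LabelSketch {
    // True when the label is a whole number in the range [0, numClasses).
    static boolean isValidLabel(double label, int numClasses) {
        return label >= 0 && label < numClasses && label == Math.floor(label);
    }

    // Prefers a class count taken from metadata; otherwise infers it from the data.
    // Labels are assumed to be validated elsewhere, as the documentation requires.
    static int numClasses(OptionalInt fromMetadata, double[] labels, int maxNumClasses) {
        if (fromMetadata.isPresent()) {
            // maxNumClasses is ignored when metadata specifies the count.
            return fromMetadata.getAsInt();
        }
        double maxLabel = -1.0;
        for (double label : labels) {
            maxLabel = Math.max(maxLabel, label);
        }
        // Assumption: classes are indexed from 0, so the count is max label + 1.
        int inferred = (int) maxLabel + 1;
        if (inferred > maxNumClasses) {
            throw new IllegalArgumentException(
                "Inferred " + inferred + " classes, but at most " + maxNumClasses + " are allowed");
        }
        return inferred;
    }

    public static void main(String[] args) {
        double[] labels = {0.0, 2.0, 1.0};
        System.out.println(numClasses(OptionalInt.empty(), labels, 100));
        System.out.println(isValidLabel(1.5, 3));
    }
}
```

Keeping validation and counting separate matches the note above that getNumClasses does not check labels itself.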
public Param<java.lang.String> rawPredictionCol()
public java.lang.String getRawPredictionCol()
public StructType validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType)
Validates and transforms the input schema with the provided param map.
Parameters:
schema - input schema
fitting - whether this is called during fitting
featuresDataType - SQL DataType for FeaturesType. E.g., VectorUDT for vector features.

public Param<java.lang.String> labelCol()
public java.lang.String getLabelCol()
public Param<java.lang.String> featuresCol()
public java.lang.String getFeaturesCol()
public Param<java.lang.String> predictionCol()
public java.lang.String getPredictionCol()
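As a rough illustration of what a validateAndTransformSchema-style check involves, the sketch below models a schema as an ordered map from column name to type name. Everything in it is an assumption made for illustration: `SchemaSketch`, the fixed column names, the string type names, and the exact checks (features type must match, the label column is required only when fitting, output columns are appended). It does not use or reproduce Spark's StructType API.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical stand-in: a schema is an ordered map of column name -> type name.
final class SchemaSketch {
    static Map<String, String> validateAndTransform(
            Map<String, String> schema, boolean fitting, String featuresType) {
        // Assumed column names; in the documented API these come from
        // featuresCol(), labelCol(), predictionCol() and rawPredictionCol().
        String featuresCol = "features";
        String labelCol = "label";
        String predictionCol = "prediction";
        String rawPredictionCol = "rawPrediction";

        // The features column must exist and carry the expected type.
        if (!featuresType.equals(schema.get(featuresCol))) {
            throw new IllegalArgumentException(
                "Column " + featuresCol + " must have type " + featuresType);
        }
        // Assumption: the label column is only required while fitting.
        if (fitting && !schema.containsKey(labelCol)) {
            throw new IllegalArgumentException("Missing label column " + labelCol);
        }
        // Copy the input so the caller's schema is left unchanged.
        Map<String, String> out = new LinkedHashMap<>(schema);
        for (String col : new String[] {predictionCol, rawPredictionCol}) {
            if (out.containsKey(col)) {
                throw new IllegalArgumentException("Column already exists: " + col);
            }
        }
        // Append the output columns in a fixed order.
        out.put(predictionCol, "double");
        out.put(rawPredictionCol, "vector");
        return out;
    }

    public static void main(String[] args) {
        Map<String, String> input = new LinkedHashMap<>();
        input.put("features", "vector");
        input.put("label", "double");
        System.out.println(validateAndTransform(input, true, "vector").keySet());
    }
}
```

Returning a new map instead of mutating the input keeps the sketch side-effect free, which makes the validation step safe to call before any data is touched.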