Commit

Some doc, and various fixes in scala style. Also adopted Spark's own scalastyle-config.xml whole.
fastier-li committed Feb 15, 2017
1 parent e17c09d commit b2cda08
Showing 10 changed files with 420 additions and 232 deletions.
8 changes: 8 additions & 0 deletions photon-ml/src/integTest/resources/README.md
@@ -0,0 +1,8 @@
## Integration tests resources

- DriverIntegTest contains input data to test the Photon driver.
- GameIntegTest contains data to test Game.
- The integration tests for Game use the Yahoo! Music User Ratings dataset at this link:
http://webscope.sandbox.yahoo.com/catalog.php?datatype=cC15
- GLMSuiteIntegTest contains data for I/O tests.
- IOUtilsTest is used in various I/O tests.
@@ -144,7 +144,7 @@ class CoordinateDescent(
val coordinateTimer = Timer.start()
logger.debug(s"Start to update coordinate with ID $coordinateId (${coordinate.getClass})")

// Update the model
// Update the model => call the optimizer
val modelUpdatingTimer = Timer.start()
val oldModel = updatedGAMEModel.getModel(coordinateId).get
val (updatedModel, optimizationTrackerOption) = if (updatedScoresContainer.keys.size > 1) {
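The hunk above only shows the timing and logging around one coordinate update. As a rough, generic sketch of the coordinate-wise scheme that CoordinateDescent implements (the types and helper names below are illustrative, not Photon's actual API):

```scala
object CoordinateDescentSketch {

  /**
   * Repeatedly re-optimizes one coordinate's model while all other coordinates are held fixed.
   * `updateCoordinate` stands in for the per-coordinate optimizer call mentioned in the comment above.
   */
  def descend[Model](
      coordinateIds: Seq[String],
      initialModels: Map[String, Model],
      updateCoordinate: (String, Map[String, Model]) => Model,
      numIterations: Int): Map[String, Model] = {

    var models = initialModels
    for (_ <- 0 until numIterations; coordinateId <- coordinateIds) {
      // Update this coordinate's model given the current state of all the others
      models = models.updated(coordinateId, updateCoordinate(coordinateId, models))
    }
    models
  }
}
```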
@@ -94,9 +94,9 @@ abstract class GAMEDriver(
/**
* Resolves paths for specified date ranges to physical file paths
*
* @param baseDir the base dirs to which date-specific relative paths will be appended
* @param baseDirs the base dirs to which date-specific relative paths will be appended
* @param dateRangeOpt optional date range
* @param daysAgo optional days-ago specification for date range
* @param daysAgoOpt optional days-ago specification for date range
* @return all resolved paths
*/
protected[game] def pathsForDateRange(
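Only the updated Scaladoc of pathsForDateRange is visible above. Below is a minimal sketch of what such a resolver does, assuming a daily yyyy/MM/dd layout under each base dir; the use of java.time and all names are assumptions, not Photon's actual implementation:

```scala
import java.time.LocalDate
import java.time.format.DateTimeFormatter

object DateRangePathsSketch {

  // Assumed daily directory layout: <baseDir>/yyyy/MM/dd
  private val DayFormat = DateTimeFormatter.ofPattern("yyyy/MM/dd")

  /** Expands each base dir into one path per day in [start, end], inclusive. */
  def pathsForDateRange(baseDirs: Seq[String], start: LocalDate, end: LocalDate): Seq[String] = {
    val days = Iterator.iterate(start)(_.plusDays(1)).takeWhile(!_.isAfter(end)).toSeq
    for (dir <- baseDirs; day <- days) yield s"$dir/${day.format(DayFormat)}"
  }

  /** A days-ago specification is just a date range ending today. */
  def pathsForDaysAgo(baseDirs: Seq[String], daysAgo: Int, today: LocalDate = LocalDate.now()): Seq[String] =
    pathsForDateRange(baseDirs, today.minusDays(daysAgo.toLong), today)
}
```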
@@ -251,7 +251,7 @@ object AvroDataReader {
*/
val INTERCEPT_NAME = "(INTERCEPT)"
val INTERCEPT_TERM = ""
val INTERCEPT_KEY = Utils.getFeatureKey(INTERCEPT_NAME, INTERCEPT_TERM)
val INTERCEPT_KEY: String = Utils.getFeatureKey(INTERCEPT_NAME, INTERCEPT_TERM)

/**
* Reads feature keys and values from the avro generic record.
@@ -354,47 +354,49 @@ object AvroDataReader {
* @param avroSchema the avro schema for the field
* @return spark sql schema for the field
*/
protected[data] def avroTypeToSql(name: String, avroSchema: Schema): Option[StructField] = avroSchema.getType match {
case INT => Some(StructField(name, IntegerType, nullable = false))
case STRING => Some(StructField(name, StringType, nullable = false))
case BOOLEAN => Some(StructField(name, BooleanType, nullable = false))
case DOUBLE => Some(StructField(name, DoubleType, nullable = false))
case FLOAT => Some(StructField(name, FloatType, nullable = false))
case LONG => Some(StructField(name, LongType, nullable = false))
case MAP =>
avroTypeToSql(name, avroSchema.getValueType).map { valueSchema =>
StructField(
name,
MapType(StringType, valueSchema.dataType, valueContainsNull = valueSchema.nullable),
nullable = false)
}

case UNION =>
if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) {
// In case of a union with null, take the first non-null type for the value type
val remainingUnionTypes = avroSchema.getTypes.asScala.filterNot(_.getType == NULL)
if (remainingUnionTypes.size == 1) {
avroTypeToSql(name, remainingUnionTypes.head).map(_.copy(nullable = true))
} else {
avroTypeToSql(name, Schema.createUnion(remainingUnionTypes.asJava)).map(_.copy(nullable = true))
protected[data] def avroTypeToSql(name: String, avroSchema: Schema): Option[StructField] =

avroSchema.getType match {
case INT => Some(StructField(name, IntegerType, nullable = false))
case STRING => Some(StructField(name, StringType, nullable = false))
case BOOLEAN => Some(StructField(name, BooleanType, nullable = false))
case DOUBLE => Some(StructField(name, DoubleType, nullable = false))
case FLOAT => Some(StructField(name, FloatType, nullable = false))
case LONG => Some(StructField(name, LongType, nullable = false))
case MAP =>
avroTypeToSql(name, avroSchema.getValueType).map { valueSchema =>
StructField(
name,
MapType(StringType, valueSchema.dataType, valueContainsNull = valueSchema.nullable),
nullable = false)
}

} else avroSchema.getTypes.asScala.map(_.getType) match {
case Seq(t1) =>
avroTypeToSql(name, avroSchema.getTypes.get(0))
case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) =>
Some(StructField(name, LongType, nullable = false))
case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) =>
Some(StructField(name, DoubleType, nullable = false))
case _ =>
// Unsupported union type. Drop this for now.
None
}
case UNION =>
if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) {
// In case of a union with null, take the first non-null type for the value type
val remainingUnionTypes = avroSchema.getTypes.asScala.filterNot(_.getType == NULL)
if (remainingUnionTypes.size == 1) {
avroTypeToSql(name, remainingUnionTypes.head).map(_.copy(nullable = true))
} else {
avroTypeToSql(name, Schema.createUnion(remainingUnionTypes.asJava)).map(_.copy(nullable = true))
}

} else avroSchema.getTypes.asScala.map(_.getType) match {
case Seq(t1) =>
avroTypeToSql(name, avroSchema.getTypes.get(0))
case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) =>
Some(StructField(name, LongType, nullable = false))
case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) =>
Some(StructField(name, DoubleType, nullable = false))
case _ =>
// Unsupported union type. Drop this for now.
None
}

case _ =>
// Unsupported avro field type. Drop this for now.
None
}
case _ =>
// Unsupported avro field type. Drop this for now.
None
}

/**
* Read the fields from the avro record into column values according to the supplied spark sql schema.
@@ -403,34 +405,37 @@
* @param schemaFields the spark sql schema to apply when reading the record
* @return column values
*/
protected[data] def readColumnValuesFromRecord(record: GenericRecord, schemaFields: Seq[StructField]) = schemaFields
.flatMap { field: StructField => field.dataType match {
case IntegerType => checkNull(record, field).orElse(Some(Utils.getIntAvro(record, field.name)))
case StringType => Some(Utils.getStringAvro(record, field.name, field.nullable))
case BooleanType => checkNull(record, field).orElse(Some(Utils.getBooleanAvro(record, field.name)))
case DoubleType => checkNull(record, field).orElse(Some(Utils.getDoubleAvro(record, field.name)))
case FloatType => checkNull(record, field).orElse(Some(Utils.getFloatAvro(record, field.name)))
case LongType => checkNull(record, field).orElse(Some(Utils.getLongAvro(record, field.name)))
case MapType(_, _, _) => Some(Utils.getMapAvro(record, field.name, field.nullable))
case _ =>
// Unsupported field type. Drop this for now.
None
}
}
protected[data] def readColumnValuesFromRecord(record: GenericRecord, schemaFields: Seq[StructField]): Seq[Any] =

schemaFields
.flatMap { field: StructField =>
field.dataType match {
case IntegerType => checkNull(record, field).orElse(Some(Utils.getIntAvro(record, field.name)))
case StringType => Some(Utils.getStringAvro(record, field.name, field.nullable))
case BooleanType => checkNull(record, field).orElse(Some(Utils.getBooleanAvro(record, field.name)))
case DoubleType => checkNull(record, field).orElse(Some(Utils.getDoubleAvro(record, field.name)))
case FloatType => checkNull(record, field).orElse(Some(Utils.getFloatAvro(record, field.name)))
case LongType => checkNull(record, field).orElse(Some(Utils.getLongAvro(record, field.name)))
case MapType(_, _, _) => Some(Utils.getMapAvro(record, field.name, field.nullable))
case _ =>
// Unsupported field type. Drop this for now.
None
}
}

/**
* Checks whether null values are allowed for the record, and if so, passes along the null value. Otherwise, returns
* None.
*
* @param record the avro GenericRecord
* @param field the schema field
* @return Some(null) if the field is null and nullable. None otherwise.
*/
protected[data] def checkNull(record: GenericRecord, field: StructField): Option[_] = {
* Checks whether null values are allowed for the record, and if so, passes along the null value. Otherwise, returns
* None.
*
* @param record the avro GenericRecord
* @param field the schema field
* @return Some(null) if the field is null and nullable. None otherwise.
*/
protected[data] def checkNull(record: GenericRecord, field: StructField): Option[_] =

if (record.get(field.name) == null && field.nullable) {
Some(null)
} else {
None
}
}
}
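To make the union handling in avroTypeToSql concrete, here is a small usage sketch. It would need to live under com.linkedin.photon.ml.data because the method is protected[data], and the object name is invented; the expected results follow from the match arms shown above:

```scala
package com.linkedin.photon.ml.data

import org.apache.avro.SchemaBuilder
import org.apache.spark.sql.types._

object AvroTypeToSqlExamples {

  def main(args: Array[String]): Unit = {
    // ["null", "int"]: the null branch is dropped and the field becomes nullable
    val nullableInt = SchemaBuilder.unionOf().nullType().and().intType().endUnion()
    assert(AvroDataReader.avroTypeToSql("age", nullableInt)
      .contains(StructField("age", IntegerType, nullable = true)))

    // ["int", "long"]: widened to LongType
    val intLong = SchemaBuilder.unionOf().intType().and().longType().endUnion()
    assert(AvroDataReader.avroTypeToSql("count", intLong)
      .contains(StructField("count", LongType, nullable = false)))

    // ["string", "int"]: unsupported union, the column is dropped (None)
    val stringInt = SchemaBuilder.unionOf().stringType().and().intType().endUnion()
    assert(AvroDataReader.avroTypeToSql("mixed", stringInt).isEmpty)
  }
}
```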
@@ -21,7 +21,6 @@ import org.apache.spark.sql.DataFrame
/**
* The DataReader interface. This interface should be implemented by readers for specific data formats.
*
* @param sc the Spark context
* @param defaultFeatureColumn the default column to use for features
*/
abstract class DataReader(protected val defaultFeatureColumn: String = "features") {
@@ -182,7 +181,7 @@ abstract class DataReader(protected val defaultFeatureColumn: String = "features
* different sources, and it can be more scalable to combine them into problem-specific feature vectors that can be
* independently distributed.
*
* @param path the path to the file or folder
* @param paths the path to the file or folder
* @param indexMapLoaders a map of index map loaders, containing one loader for each merged feature column
* @param featureColumnMap a map that specifies how the feature columns should be merged. The keys specify the name
* of the merged destination column, and the values are sets of source columns to merge, e.g.:
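The example in the doc comment above is cut off by the diff view. Purely as a hypothetical illustration of the kind of map being described (the column names and the Map[String, Set[String]] shape are guesses, not taken from the Photon source):

```scala
object FeatureColumnMapExample {
  // Hypothetical feature-column merge specification: each destination column (key)
  // is built by merging the source columns in the corresponding set (value).
  val featureColumnMap: Map[String, Set[String]] = Map(
    "userFeatures" -> Set("userProfileFeatures", "userActivityFeatures"),
    "itemFeatures" -> Set("itemMetadataFeatures", "itemStatsFeatures"))
}
```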
@@ -16,17 +16,17 @@ package com.linkedin.photon.ml.estimators

import scala.collection.Map

import org.slf4j.Logger
import org.apache.spark.SparkContext
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.slf4j.Logger

import com.linkedin.photon.ml.algorithm._
import com.linkedin.photon.ml.constants.StorageLevel
import com.linkedin.photon.ml.data._
import com.linkedin.photon.ml.evaluation._
import com.linkedin.photon.ml.evaluation.Evaluator.EvaluationResults
import com.linkedin.photon.ml.evaluation._
import com.linkedin.photon.ml.function.glm._
import com.linkedin.photon.ml.function.svm.{DistributedSmoothedHingeLossFunction, SingleNodeSmoothedHingeLossFunction}
import com.linkedin.photon.ml.model.GAMEModel
@@ -37,9 +37,8 @@ import com.linkedin.photon.ml.projector.IdentityProjection
import com.linkedin.photon.ml.sampler.{BinaryClassificationDownSampler, DefaultDownSampler}
import com.linkedin.photon.ml.supervised.classification.{LogisticRegressionModel, SmoothedHingeLossLinearSVMModel}
import com.linkedin.photon.ml.supervised.regression.{LinearRegressionModel, PoissonRegressionModel}
import com.linkedin.photon.ml.TaskType
import com.linkedin.photon.ml.util._
import com.linkedin.photon.ml.{BroadcastLike, RDDLike, SparkContextConfiguration}
import com.linkedin.photon.ml.{BroadcastLike, RDDLike, TaskType}

/**
* Estimator implementation for GAME models
@@ -61,7 +60,7 @@ class GameEstimator(val params: GameParams, val sparkContext: SparkContext, val
* Fits GAME models to the training dataset
*
* @param data the training set
* @param validationData optional validation set for per-iteration validation
* @param validatingData optional validation set for per-iteration validation
* @return a set of GAME models, one for each combination of fixed and random effect combination specified in the
* params
*/
@@ -230,8 +229,7 @@ class GameEstimator(val params: GameParams, val sparkContext: SparkContext, val
/**
* Creates the validation evaluator(s)
*
* @param validatingDirs The input path for validating data set
* @param featureShardIdToFeatureMapLoader A map of shard id to feature indices
* @param data The input data
* @return The validating game data sets and the companion evaluator
*/
protected[estimators] def prepareValidatingEvaluators(
@@ -22,8 +22,6 @@ import org.apache.spark.SparkContext
* To access an IndexMap within RDD operations, directly referring to an object inside Driver is inefficient.
* The driver will try to serialize the entire object onto RDDs. This trait provides a uniform way of loading
* feature index maps, regardless of their concrete implementation.
*
* TODO(fastier): simplify this loader hierarchy
*/
trait IndexMapLoader extends Serializable {

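The IndexMapLoader Scaladoc above explains why referring to a driver-side map inside RDD operations is inefficient. The generic sketch below (deliberately not the actual IndexMapLoader API) contrasts the two patterns:

```scala
import org.apache.spark.rdd.RDD

object LoaderPatternSketch {

  // Inefficient: the entire featureToIndex map is serialized into every task closure.
  def indexWithDriverMap(names: RDD[String], featureToIndex: Map[String, Int]): RDD[Int] =
    names.map(featureToIndex)

  // Preferred: ship only a small, serializable "loader" and materialize the map once per
  // partition (e.g. by reading index files cached on the executors).
  def indexWithLoader(names: RDD[String], loadMap: () => Map[String, Int]): RDD[Int] =
    names.mapPartitions { iter =>
      val featureToIndex = loadMap()
      iter.map(featureToIndex)
    }
}
```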
@@ -22,24 +22,24 @@ import java.io.{File => JFile}
import collection.JavaConverters._

/**
* An off heap index map implementation using [[PalDB]].
*
* The internal implementation assumed the following things:
* 1. One DB storage is partitioned into multiple pieces we call partitions. It should be generated and controlled by
* [[com.linkedin.photon.ml.FeatureIndexingJob]]. The partition strategy is via the hashcode of the feature names,
* following the rules defined in [[org.apache.spark.HashPartitioner]].
*
* 2. Each time when a user is querying the index of a certain feature, the index map will first compute the hashcode,
* and then compute the expected partition of the storageReader.
*
* 3. Because the way we are building each index partition (they are built in parallel, without sharing information
* with each other). Each partition's internal index always starts from 0. Thus, we are keeping an offset array to
* properly record how much offset we should provide for each index coming from a particular partition. In this way,
* we could safely ensure that each index is unique.
*
* 4. Each time when a user is querying for the feature name of a given index, we'll do a binary search for the proper
* storage according to offset ranges and then return null or the proper feature name.
*/
* An off heap index map implementation using [[PalDB]].
*
* The internal implementation assumes the following:
*
* 1. One DB storage is partitioned into multiple pieces we call partitions. It should be generated and controlled by
* [[com.linkedin.photon.ml.FeatureIndexingJob]]. The partition strategy is via the hashcode of the feature names,
* following the rules defined in [[org.apache.spark.HashPartitioner]].
*
* 2. Each time a user queries the index of a certain feature, the index map will first compute the feature hashcode,
* and then compute the expected partition for this hashcode.
*
* 3. Because the index partitions are built in parallel, without sharing information between them, each partition's
* internal index always starts from 0. We keep an array of offsets to properly compute the indexes (keep them
* unique across partitions).
*
* 4. Each time a user queries for the feature name of a given index, we do a binary search for the proper
* storage according to offset ranges and then return null or the proper feature name.
*/
class PalDBIndexMap extends IndexMap {
import PalDBIndexMap._

@@ -57,14 +57,14 @@ class PalDBIndexMap extends IndexMap {
private var _partitioner: HashPartitioner = _

/**
* Load a storage at a particular path
*
* @param storePath The directory where the storage is put
* @param partitionsNum The number of partitions, the storage contains
* @param isLocal default: false, if set false will use SparkFiles to access cached files; otherwise,
* it will directly read from local files
* @return
*/
* Load a storage at a particular path
*
* @param storePath The directory where the storage is put
* @param partitionsNum The number of partitions, the storage contains
* @param isLocal default: false, if set false will use SparkFiles to access cached files; otherwise,
* it will directly read from local files
* @return A PalDBIndexMap instance
*/
def load(
storePath: String,
partitionsNum: Int,
@@ -161,17 +161,18 @@
}

object PalDBIndexMap {
/* PalDB is not thread safe for parallel reads even for different storages, necessary to lock it.
/**
* PalDB is not thread safe for parallel reads even for different storages, necessary to lock it.
*/
private val PALDB_READER_LOCK = "READER_LOCK"

/**
* Returns the formatted filename for a partition file of PalDB IndexMap storing (name -> index) mapping
* This method should be used consistently as a protocol to handle naming conventions
*
* @param partitionId the partition Id
* @return the formatted filename
*/
* Returns the formatted filename for a partition file of PalDB IndexMap storing (name -> index) mapping
* This method should be used consistently as a protocol to handle naming conventions
*
* @param partitionId the partition Id
* @return the formatted filename
*/
def partitionFilename(partitionId: Int, namespace: String = IndexMap.GLOBAL_NS): String =
s"paldb-partition-$namespace-$partitionId.dat"

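The rewritten PalDBIndexMap Scaladoc above describes per-partition local indexes, an offset array, and a binary search for reverse lookups. Here is a minimal, self-contained sketch of that bookkeeping; the partition sizes are invented, and the real class reads its per-partition maps from PalDB storage:

```scala
object OffsetBookkeepingSketch {

  // Suppose three partitions holding 4, 2 and 5 features; each numbers its features locally from 0.
  val partitionSizes: Array[Int] = Array(4, 2, 5)

  // Running sums give each partition's starting offset: Array(0, 4, 6, 11)
  val offsets: Array[Int] = partitionSizes.scanLeft(0)(_ + _)

  /** Global, unique index = the partition's offset + the feature's local index within that partition. */
  def globalIndex(partitionId: Int, localIndex: Int): Int = offsets(partitionId) + localIndex

  /** Reverse lookup: binary-search the offset ranges to find which partition owns a global index. */
  def partitionOf(globalIdx: Int): Int = {
    val pos = java.util.Arrays.binarySearch(offsets, globalIdx)
    if (pos >= 0) pos else -pos - 2 // not found: insertion point minus one
  }

  def main(args: Array[String]): Unit = {
    assert(globalIndex(partitionId = 1, localIndex = 1) == 5)
    assert(partitionOf(5) == 1)
    assert(partitionOf(3) == 0)
  }
}
```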