Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding Batch leap frame and a sample batch tf transformer #600

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ object Bundle {

val pipeline = "pipeline"
val tensorflow = "tensorflow"
val batch_tensorflow = "batch_tensorflow"
}

def apply[Transformer <: AnyRef](name: String,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
package ml.combust.mleap.runtime.frame

import java.lang.Iterable

import ml.combust.mleap.core.types.{BasicType, StructField, StructType}
import ml.combust.mleap.runtime.frame.Row.RowSelector
import ml.combust.mleap.runtime.function.{Selector, UserDefinedFunction}

import scala.collection.JavaConverters._
import scala.util.{Failure, Try}

/** Class for storing a leap frame locally in batches of rows.
  *
  * Batch-oriented variant of a leap frame: UDFs applied through
  * [[withColumn]] / [[withColumns]] are expected to operate on the whole
  * batch at once (their underlying function is `Seq[Row] => Seq[Row]`),
  * rather than row-by-row.
  *
  * @param schema  schema of leap frame
  * @param dataset rows stored in this frame
  */
case class BatchLeapFrame(override val schema: StructType,
                          dataset: Seq[Row]) extends LeapFrame[BatchLeapFrame] {

  /** Java-friendly constructor accepting a `java.lang.Iterable` of rows. */
  def this(schema: StructType, rows: Iterable[Row]) = this(schema, rows.asScala.toSeq)

  /** Try to select fields to create a new LeapFrame.
    *
    * Returns a Failure if attempting to select any fields that don't exist.
    *
    * @param fieldNames field names to select
    * @return try new LeapFrame with selected fields
    */
  override def select(fieldNames: String*): Try[BatchLeapFrame] = {
    schema.indicesOf(fieldNames: _*).flatMap {
      indices =>
        schema.selectIndices(indices: _*).map {
          schema2 =>
            val dataset2 = dataset.map(_.selectIndices(indices: _*))
            BatchLeapFrame(schema2, dataset2)
        }
    }
  }

  /** Try to add a column to the LeapFrame.
    *
    * Returns a Failure if trying to add a field that already exists.
    *
    * @param name      name of column
    * @param selectors row selectors used to generate inputs to udf
    * @param udf       user defined function for calculating column value
    * @return LeapFrame with new column
    */
  override def withColumn(name: String, selectors: Selector*)
                         (udf: UserDefinedFunction): Try[BatchLeapFrame] = {
    // Identity row UDF carrying the batch UDF's input/output signature; used
    // only so RowUtil/udfValue can build and apply per-row input selectors.
    val rowUDF: UserDefinedFunction = UserDefinedFunction(
      (x: Row) => x,
      udf.output,
      udf.inputs
    )
    // NOTE(fix): pass rowUDF here for consistency with withColumns. The
    // original passed the batch-typed `udf`; both share the same
    // inputs/output, so selector creation is unchanged.
    RowUtil.createRowSelectors(schema, selectors: _*)(rowUDF).flatMap {
      rowSelectors =>
        val field = StructField(name, udf.outputTypes.head)
        schema.withField(field).map { schema2 =>
          // Build one input row per dataset row, run the batch UDF once over
          // the whole batch, then append the single output value to each row.
          val inputRows = dataset.map(r => udfValue(rowSelectors: _*)(rowUDF)(r))
          val results = udf.f.asInstanceOf[Seq[Row] => Seq[Row]](inputRows)
          val dataset2: Seq[Row] = dataset.zip(results).map {
            case (r1, r2) => Row(r1.toSeq :+ r2.head: _*)
          }
          BatchLeapFrame(schema2, dataset2)
        }
    }
  }

  /** Try to add multiple columns to the LeapFrame.
    *
    * Returns a Failure if trying to add a field that already exists.
    *
    * @param names     names of columns
    * @param selectors row selectors used to generate inputs to udf
    * @param udf       user defined function for calculating column values
    * @return LeapFrame with new columns
    */
  override def withColumns(names: Seq[String], selectors: Selector*)
                          (udf: UserDefinedFunction): Try[BatchLeapFrame] = {
    // Identity row UDF carrying the batch UDF's signature (see withColumn).
    val rowUDF: UserDefinedFunction = UserDefinedFunction(
      (x: Row) => x,
      udf.output,
      udf.inputs
    )
    RowUtil.createRowSelectors(schema, selectors: _*)(rowUDF).flatMap {
      rowSelectors =>
        val fields = names.zip(udf.outputTypes).map {
          case (name, dt) => StructField(name, dt)
        }

        schema.withFields(fields).map { schema2 =>
          // Run the batch UDF once, then append all output values to each row.
          val inputRows = dataset.map(r => udfValue(rowSelectors: _*)(rowUDF)(r))
          val results = udf.f.asInstanceOf[Seq[Row] => Seq[Row]](inputRows)
          val dataset2: Seq[Row] = dataset.zip(results).map {
            case (r1, r2) => Row(r1.toSeq ++ r2.toSeq: _*)
          }
          BatchLeapFrame(schema2, dataset2)
        }
    }
  }

  /** Build the UDF input row by applying each row selector to the given row.
    *
    * Generalized to any arity: the original matched only 0 to 5 inputs and
    * threw a MatchError for UDFs with more than five inputs.
    *
    * @param rowSelectors selectors extracting each input value from the row
    * @param udf          UDF whose input count bounds how many selectors apply
    * @param row          source row
    * @return row of selected input values
    */
  def udfValue(rowSelectors: RowSelector*)(udf: UserDefinedFunction)(row: Row): Row = {
    Row(rowSelectors.take(udf.inputs.length).map(selector => selector(row)): _*)
  }

  /** Try dropping column(s) from the LeapFrame.
    *
    * Returns a Failure if the column does not exist.
    *
    * @param names names of column to drop
    * @return LeapFrame with column(s) dropped
    */
  override def drop(names: String*): Try[BatchLeapFrame] = {
    for (indices <- schema.indicesOf(names: _*);
         schema2 <- schema.dropIndices(indices: _*)) yield {
      val dataset2 = dataset.map(_.dropIndices(indices: _*))
      BatchLeapFrame(schema = schema2, dataset = dataset2)
    }
  }

  /** Try filtering the leap frame using the UDF
    *
    * @param selectors row selectors used as inputs for the filter
    * @param udf       filter udf, must return a Boolean
    * @return LeapFrame with rows filtered
    */
  override def filter(selectors: Selector*)
                     (udf: UserDefinedFunction): Try[BatchLeapFrame] = {
    if (udf.outputTypes.length != 1 || udf.outputTypes.head.base != BasicType.Boolean) {
      return Failure(new IllegalArgumentException("must provide a UDF that outputs a boolean for filtering"))
    }

    RowUtil.createRowSelectors(schema, selectors: _*)(udf).map {
      rowSelectors =>
        val dataset2 = dataset.filter(_.shouldFilter(rowSelectors: _*)(udf))
        BatchLeapFrame(schema, dataset2)
    }
  }

  /** Collect all rows into a Seq
    *
    * @return all rows in the leap frame
    */
  override def collect(): Seq[Row] = dataset
}
3 changes: 2 additions & 1 deletion mleap-tensorflow/src/main/resources/reference.conf
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
ml.combust.mleap.tensorflow.ops = [
"ml.combust.mleap.tensorflow.TensorflowTransformerOp"
"ml.combust.mleap.tensorflow.TensorflowTransformerOp",
"ml.combust.mleap.tensorflow.BatchTensorflowTransformerOp"
]

ml.combust.mleap.registry.default.ops += "ml.combust.mleap.tensorflow.ops"
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package ml.combust.mleap.tensorflow

import ml.combust.mleap.core.Model
import ml.combust.mleap.core.types.{StructField, StructType, TensorType}
import ml.combust.mleap.tensor.Tensor
import ml.combust.mleap.tensorflow.converter.{BatchMleapConverter, BatchTensorflowConverter}
import org.tensorflow

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.Try

/** Model that executes a TensorFlow graph over a whole batch of inputs at once.
  *
  * The TensorFlow session is created lazily on first use and reused across
  * calls; close() releases both the session and the graph.
  *
  * @param graph   TensorFlow graph to execute
  * @param inputs  ordered (tensor name, MLeap tensor type) pairs fed to the graph
  * @param outputs ordered (tensor name, MLeap tensor type) pairs fetched from the graph
  * @param nodes   optional extra graph node names to add as run targets
  */
case class BatchTensorflowModel(graph: tensorflow.Graph,
inputs: Seq[(String, TensorType)],
outputs: Seq[(String, TensorType)],
nodes: Option[Seq[String]] = None) extends Model with AutoCloseable {
// Lazily-created session; @transient so it is not captured by serialization.
@transient
private var session: Option[tensorflow.Session] = None

/** Run the graph on a batch.
  *
  * Each vararg element is one row's input values (one Tensor per declared
  * input, in `inputs` order) — see the transformer call site, which passes
  * one Seq[Tensor] per row. The batch is transposed to per-input columns
  * before conversion.
  *
  * @param values one Seq of input tensors per row
  * @return per-row output values (transposed back from per-output columns)
  */
def apply(values: Seq[Tensor[_]] *): Seq[Seq[Any]] = {
// Collects every native TF tensor created or returned during the run so
// they can all be closed afterwards, whether the run succeeded or not.
val garbage: mutable.ArrayBuilder[tensorflow.Tensor[_]] = mutable.ArrayBuilder.make[tensorflow.Tensor[_]]()

// Transpose rows -> columns so each column lines up with one declared input.
val x = values.transpose
val result = Try {
val tensors: Seq[(String, tensorflow.Tensor[_])] = x.zip(inputs).map {
case (v: Seq[Tensor[_]], (name, dataType)) =>
val tensor = BatchMleapConverter.convert(v, dataType)
garbage += tensor
(name, tensor)
}

withSession {
session =>
val runner = session.runner()

tensors.foreach {
case (name, tensor) => runner.feed(name, tensor)
}

outputs.foreach {
case (name, _) => runner.fetch(name)
}

nodes.foreach {
_.foreach {
name => runner.addTarget(name)
}
}

runner.run().asScala.zip(outputs).map {
case (tensor, (_, dataType)) =>
// Fetched result tensors must also be closed after conversion.
garbage += tensor
BatchTensorflowConverter.convert(tensor, dataType)
}
}
}

// Always release native tensors; .get then rethrows any failure from the run.
garbage.result.foreach(_.close())
result.get.transpose
}

// Creates the session on first use and caches it for subsequent calls.
// NOTE(review): not thread-safe, and calling apply() after close() would
// create a session on a closed graph — confirm callers serialize access.
private def withSession[T](f: (tensorflow.Session) => T): T = {
val s = session.getOrElse {
session = Some(new tensorflow.Session(graph))
session.get
}

f(s)
}

/** Release the cached session (if any) and the graph. */
override def close(): Unit = {
session.foreach(_.close())
graph.close()
}

// Safety net so native resources are freed if close() was never called.
override def finalize(): Unit = {
close()
super.finalize()
}

/** Input schema derived from the declared input names/types. */
override def inputSchema: StructType = StructType(inputs.map {
case (name, dt) => StructField(name, dt)
}).get

/** Output schema derived from the declared output names/types. */
override def outputSchema: StructType = StructType(outputs.map {
case (name, dt) => StructField(name, dt)
}).get
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package ml.combust.mleap.tensorflow

import ml.combust.mleap.core.types.NodeShape
import ml.combust.mleap.runtime.frame.{FrameBuilder, Row, SimpleTransformer, Transformer}
import ml.combust.mleap.runtime.function.{FieldSelector, Selector, UserDefinedFunction}
import ml.combust.mleap.tensor.Tensor

import scala.util.Try

/** Transformer that evaluates a TensorFlow graph over an entire batch of rows
  * in a single model invocation.
  *
  * @param uid   unique transformer name
  * @param shape node shape describing the transformer's ports
  * @param model batch TensorFlow model to execute
  */
case class BatchTensorflowTransformer(override val uid: String = Transformer.uniqueName("batch_tensorflow"),
                                      override val shape: NodeShape,
                                      override val model: BatchTensorflowModel)
  extends SimpleTransformer {

  // Wraps the model as a batch function: every value of each row becomes a
  // scalar tensor, the model runs once over the whole batch, and each
  // resulting value sequence is turned back into an output row.
  private val batchFn = (rows: Seq[Row]) => {
    val tensorRows = rows.map(row => row.toSeq.map(value => Tensor.scalar(value)))
    model(tensorRows: _*).map(outputValues => Row(outputValues: _*))
  }

  override val exec: UserDefinedFunction =
    UserDefinedFunction(batchFn, outputSchema, inputSchema)

  /** Names of the columns this transformer produces. */
  val outputCols: Seq[String] = outputSchema.fields.map(field => field.name)

  /** Names of the columns this transformer consumes. */
  val inputCols: Seq[String] = inputSchema.fields.map(field => field.name)

  // One field selector per input column, in schema order.
  private val inputSelector: Seq[Selector] = inputCols.map(col => FieldSelector(col))

  /** Add all output columns to the builder in one batch UDF application. */
  override def transform[TB <: FrameBuilder[TB]](builder: TB): Try[TB] =
    builder.withColumns(outputCols, inputSelector: _*)(exec)

  /** Release the underlying TensorFlow model's native resources. */
  override def close(): Unit = model.close()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package ml.combust.mleap.tensorflow

import java.nio.file.Files

import ml.bundle.{BasicType, DataShape}
import ml.combust.bundle.BundleContext
import ml.combust.bundle.dsl._
import ml.combust.bundle.op.OpModel
import ml.combust.mleap.bundle.ops.MleapOp
import ml.combust.mleap.core
import ml.combust.mleap.core.types.TensorType
import ml.combust.mleap.runtime.MleapContext
import ml.combust.mleap.runtime.types.BundleTypeConverters._

/** Bundle serialization op for [[BatchTensorflowTransformer]].
  *
  * Persists the TensorFlow graph as a protobuf file plus the input/output
  * names, types, and shapes as bundle values; load reverses the process.
  */
class BatchTensorflowTransformerOp extends MleapOp[BatchTensorflowTransformer, BatchTensorflowModel] {
override val Model: OpModel[MleapContext, BatchTensorflowModel] = new OpModel[MleapContext, BatchTensorflowModel] {
override val klazz: Class[BatchTensorflowModel] = classOf[BatchTensorflowModel]

override def opName: String = Bundle.BuiltinOps.batch_tensorflow

/** Write the model to the bundle: graph bytes to "graph.pb", metadata as values.
  *
  * The `dt.base: BasicType` / `dt.shape: DataShape` ascriptions trigger the
  * implicit mleap-to-bundle type conversions imported from BundleTypeConverters.
  */
override def store(model: Model, obj: BatchTensorflowModel)
(implicit context: BundleContext[MleapContext]): Model = {
Files.write(context.file("graph.pb"), obj.graph.toGraphDef)
val (inputNames, inputMleapDataTypes) = obj.inputs.unzip
val (inputBasicTypes, inputShapes) = inputMleapDataTypes.map {
dt => (dt.base: BasicType, dt.shape: DataShape)
}.unzip

val (outputNames, outputMleapDataTypes) = obj.outputs.unzip
val (outputBasicTypes, outputShapes) = outputMleapDataTypes.map {
dt => (dt.base: BasicType, dt.shape: DataShape)
}.unzip

// "nodes" is optional: withValue on an Option only stores when defined.
model.withValue("input_names", Value.stringList(inputNames)).
withValue("input_types", Value.basicTypeList(inputBasicTypes)).
withValue("input_shapes", Value.dataShapeList(inputShapes)).
withValue("output_names", Value.stringList(outputNames)).
withValue("output_types", Value.basicTypeList(outputBasicTypes)).
withValue("output_shapes", Value.dataShapeList(outputShapes)).
withValue("nodes", obj.nodes.map(Value.stringList))
}

/** Rebuild the model: read graph bytes, convert bundle types back to mleap
  * types, and re-import the graph definition.
  */
override def load(model: Model)
(implicit context: BundleContext[MleapContext]): BatchTensorflowModel = {
val graphBytes = Files.readAllBytes(context.file("graph.pb"))

val inputNames = model.value("input_names").getStringList
val inputTypes = model.value("input_types").getBasicTypeList.map(v => v: core.types.BasicType)
val inputShapes = model.value("input_shapes").getDataShapeList.map(v => v: core.types.DataShape)

val outputNames = model.value("output_names").getStringList
val outputTypes = model.value("output_types").getBasicTypeList.map(v => v: core.types.BasicType)
val outputShapes = model.value("output_shapes").getDataShapeList.map(v => v: core.types.DataShape)

// Optional run-target node names (may be absent in older bundles).
val nodes = model.getValue("nodes").map(_.getStringList)

// NOTE(review): the asInstanceOf assumes every stored type is a TensorType,
// which store() guarantees for bundles written by this op.
val inputs = inputNames.zip(inputTypes.zip(inputShapes).map {
case (b, s) => core.types.DataType(b, s).asInstanceOf[TensorType]
})
val outputs = outputNames.zip(outputTypes.zip(outputShapes).map {
case (b, s) => core.types.DataType(b, s).asInstanceOf[TensorType]
})

val graph = new org.tensorflow.Graph()
graph.importGraphDef(graphBytes)
BatchTensorflowModel(graph,
inputs,
outputs,
nodes)
}
}

/** The core model wrapped by the transformer node. */
override def model(node: BatchTensorflowTransformer): BatchTensorflowModel = node.model
}
Loading