NMF code
An NMF implementation based on Spark GraphX is shown below. I have already tested it successfully, and I will describe the structure and design of the NMF algorithm another day.
/**
 * Graph NMF (non-negative matrix factorisation) implemented on GraphX.
 *
 * The implementation follows the structure of Pregel, but alternates the
 * direction in which messages are propagated:
 *
 *  - forward half-step: every edge sends a message to its source vertex,
 *    which uses the message sum to update its vector W;
 *  - back half-step: every edge sends a message to its destination vertex,
 *    which uses the message sum to update its vector H.
 *
 * One iteration consists of a forward half-step followed by a back half-step.
 *
 * The vertex attribute is `(Vector, Vector)`: the first Vector is W, the
 * second is H. The edge attribute is the matrix entry d_ij.
 *
 * Update rule (projected gradient step; negative components are set to 0):
 * {{{
 * W_i <- (1 - theta * lambda) * W_i + theta * sum_j{ (d_ij - W_i . H_j) * H_j }
 * H_j <- (1 - theta * lambda) * H_j + theta * sum_i{ (d_ij - W_i . H_j) * W_i }
 * }}}
 *
 * `theta` is the step size and `lambda` is the regularisation weight.
 *
 * Only vertices that receive at least one message are updated in a half-step;
 * all other vertices keep their previous attribute.
 */
object ICTGraphNMF extends Logging with Serializable {

  /** Per-vertex state: `_1` is the W vector, `_2` is the H vector. */
  private type Factors = (Vector, Vector)

  /** (vertex id, current state, combined message) => new state. */
  private type VertexUpdate = (VertexId, Factors, Vector) => Factors

  /** Turns an edge triplet into the message(s) it emits. */
  private type Send = EdgeTriplet[Factors, Double] => Iterator[(VertexId, Vector)]

  /**
   * Run GraphNMF using only the edges that are present in the graph, i.e.
   * missing matrix entries are ignored.
   *
   * NOTE: every edge emits a message in every half-step, so on a graph with
   * edges the loop only stops when `maxIterations` is reached. With the default
   * `Int.MaxValue` the call effectively never returns; pass an explicit limit.
   *
   * NOTE: the initial vectors are drawn from the global `Random` without a seed,
   * so results are not reproducible between runs.
   *
   * @tparam VD the original vertex attribute (not used)
   *
   * @param graph the graph on which to run NMF, the edge attribute must be Double
   * @param maxIterations the maximum number of iterations (forward + back half-step)
   * @param theta the step size
   * @param lambda the regularisation weight
   * @param reducedDim the reduced dimension of the factor vectors, must be positive
   *
   * @return the graph whose vertex attributes are the two Vectors W and H
   *         and whose edge attributes are the original edge weights.
   */
  def run[VD: ClassTag](graph: Graph[VD, Double],
      maxIterations: Int = Int.MaxValue,
      theta: Double = 0.01,
      lambda: Double = 0.1,
      reducedDim: Int = 2): Graph[(Vector, Vector), Double] = {
    // Message to the source: (d_ij - W_i . H_j) * H_j
    val sendToSrc: Send = edge => {
      val residual = edge.attr - edge.srcAttr._1.dot(edge.dstAttr._2)
      Iterator((edge.srcId, edge.dstAttr._2 * residual))
    }
    // Message to the destination: (d_ij - W_i . H_j) * W_i
    val sendToDst: Send = edge => {
      val residual = edge.attr - edge.srcAttr._1.dot(edge.dstAttr._2)
      Iterator((edge.dstId, edge.srcAttr._1 * residual))
    }
    val updateW: VertexUpdate =
      (id, attr, msgSum) => (descend(attr._1, msgSum, theta, lambda), attr._2)
    val updateH: VertexUpdate =
      (id, attr, msgSum) => (attr._1, descend(attr._2, msgSum, theta, lambda))

    alternate(initialise(graph, reducedDim), maxIterations,
      sendToSrc, sendToDst, _ => updateW, _ => updateH)
  }

  /**
   * Run GraphNMF treating every matrix entry without an edge as an explicit 0.
   *
   * The gradient over all (present and absent) entries is split into the part
   * carried by the edge messages, `sum_j d_ij * H_j`, and a dense correction
   * `W_i * (sum_j H_j' H_j)` that is obtained from a `reducedDim x reducedDim`
   * Gram matrix summed over all vertices (and symmetrically for H).
   * Each Gram matrix is computed on the cluster right before the half-step that
   * uses it, so it always reflects the current factors.
   *
   * NOTE: every edge emits a message in every half-step, so on a graph with
   * edges the loop only stops when `maxIterations` is reached. With the default
   * `Int.MaxValue` the call effectively never returns; pass an explicit limit.
   *
   * NOTE: the initial vectors are drawn from the global `Random` without a seed,
   * so results are not reproducible between runs.
   *
   * @tparam VD the original vertex attribute (not used)
   *
   * @param graph the graph on which to run NMF, the edge attribute must be Double
   * @param maxIterations the maximum number of iterations (forward + back half-step)
   * @param theta the step size
   * @param lambda the regularisation weight
   * @param reducedDim the reduced dimension of the factor vectors, must be positive
   *
   * @return the graph whose vertex attributes are the two Vectors W and H
   *         and whose edge attributes are the original edge weights.
   */
  def runWithZero[VD: ClassTag](graph: Graph[VD, Double],
      maxIterations: Int = Int.MaxValue,
      theta: Double = 0.01,
      lambda: Double = 0.1,
      reducedDim: Int = 2): Graph[(Vector, Vector), Double] = {
    // Message to the source: d_ij * H_j
    val sendToSrc: Send = edge => Iterator((edge.srcId, edge.dstAttr._2.multiply(edge.attr)))
    // Message to the destination: d_ij * W_i
    val sendToDst: Send = edge => Iterator((edge.dstId, edge.srcAttr._1.multiply(edge.attr)))

    // The Gram matrix is computed on the driver once per half-step and captured
    // as an immutable value, so the vertex program never depends on mutable state.
    def updateW(current: Graph[Factors, Double]): VertexUpdate = {
      val gramH = gramMatrix(current.vertices, reducedDim)(_._2)
      (id, attr, msgSum) => {
        val gradient = msgSum - timesGram(attr._1, gramH, reducedDim)
        (descend(attr._1, gradient, theta, lambda), attr._2)
      }
    }
    def updateH(current: Graph[Factors, Double]): VertexUpdate = {
      val gramW = gramMatrix(current.vertices, reducedDim)(_._1)
      (id, attr, msgSum) => {
        val gradient = msgSum - timesGram(attr._2, gramW, reducedDim)
        (attr._1, descend(attr._2, gradient, theta, lambda))
      }
    }

    alternate(initialise(graph, reducedDim), maxIterations,
      sendToSrc, sendToDst, updateW _, updateH _)
  }

  /** Replaces every vertex attribute by a pair of random vectors of length `dim` and caches the graph. */
  private def initialise[VD: ClassTag](graph: Graph[VD, Double], dim: Int): Graph[Factors, Double] = {
    require(dim > 0, s"reducedDim must be positive, got $dim")
    val seeded: Graph[Factors, Double] = graph.mapVertices((vid, vdata) =>
      (Vector(Array.fill(dim)(Random.nextDouble)),
        Vector(Array.fill(dim)(Random.nextDouble))))
    seeded.cache()
  }

  /**
   * One projected gradient step:
   * `(1 - theta * lambda) * current + theta * gradient`, with negative components set to 0.
   */
  private def descend(current: Vector, gradient: Vector, theta: Double, lambda: Double): Vector = {
    val stepped = current * (1 - theta * lambda) + gradient * theta
    if (stepped.elements.exists(_ < 0.0))
      Vector(stepped.elements.map(elem => if (elem > 0.0) elem else 0.0))
    else
      stepped
  }

  /**
   * Sums the outer product `v' v` of the selected vector over all vertices.
   * The result is a `dim x dim` matrix stored in a flat array; entry (r, c) is at `c + dim * r`.
   * `aggregate` is an action, so the sum is actually computed on the cluster and
   * returned to the driver; an empty vertex set yields the zero matrix.
   */
  private def gramMatrix(vertices: VertexRDD[Factors], dim: Int)(pick: Factors => Vector): Array[Double] = {
    vertices.aggregate(new Array[Double](dim * dim))(
      (acc, entry) => {
        val v = pick(entry._2).elements
        var r = 0
        while (r < dim) {
          var c = 0
          while (c < dim) {
            acc(c + dim * r) += v(r) * v(c)
            c += 1
          }
          r += 1
        }
        acc
      },
      (left, right) => {
        var k = 0
        while (k < left.length) {
          left(k) += right(k)
          k += 1
        }
        left
      })
  }

  /** Multiplies the row vector `vec` with the flat `dim x dim` matrix `gram`. */
  private def timesGram(vec: Vector, gram: Array[Double], dim: Int): Vector = {
    val v = vec.elements
    Vector(Array.tabulate(dim) { i =>
      var sum = 0.0
      var j = 0
      while (j < dim) {
        sum += v(j) * gram(j + dim * i)
        j += 1
      }
      sum
    })
  }

  /**
   * Shared driver loop: alternates forward (update W) and back (update H) half-steps
   * until no messages are produced or `maxIterations` full iterations have run.
   *
   * @param initial cached graph holding the initial factors
   * @param maxIterations maximum number of full iterations
   * @param sendToSrc message function used to feed the forward half-step
   * @param sendToDst message function used to feed the back half-step
   * @param prepareW called on the driver before each forward half-step, returns the W vertex program
   * @param prepareH called on the driver before each back half-step, returns the H vertex program
   */
  private def alternate(
      initial: Graph[Factors, Double],
      maxIterations: Int,
      sendToSrc: Send,
      sendToDst: Send,
      prepareW: Graph[Factors, Double] => VertexUpdate,
      prepareH: Graph[Factors, Double] => VertexUpdate): Graph[Factors, Double] = {
    val combine: (Vector, Vector) => Vector = (a, b) => a + b

    var current: Graph[Factors, Double] = initial
    var messages: VertexRDD[Vector] = current.mapReduceTriplets(sendToSrc, combine).cache()
    var active = messages.count()
    // Number of half-steps done so far; even => next is forward, odd => next is back.
    var halfStep = 0

    while (active > 0 && halfStep / 2 < maxIterations) {
      val forward = halfStep % 2 == 0

      logDebug("Graph Information\n")
      logDebug(current.vertices.collect().mkString("\n"))
      if (forward) {
        logDebug(s"GraphNMF interation:${halfStep / 2 + 1}")
        logDebug("forward propagating\nprapagating messages:")
      } else {
        logDebug("back propagating\nprapagating messages:")
      }
      logDebug(messages.collect().mkString("\n"))

      val vertexProgram: VertexUpdate = if (forward) prepareW(current) else prepareH(current)
      val updated: VertexRDD[Factors] = current.vertices.innerJoin(messages)(vertexProgram).cache()

      val previous = current
      current = current.outerJoinVertices(updated) { (vid, old, newOpt) => newOpt.getOrElse(old) }
      current.cache()

      val staleMessages = messages
      // After updating W the next half-step updates H (messages go to dst), and vice versa.
      val nextSend: Send = if (forward) sendToDst else sendToSrc
      messages = current.mapReduceTriplets(nextSend, combine).cache()
      // Materialise the new messages (and with them the new graph) BEFORE dropping
      // the RDDs they were derived from; otherwise those would be recomputed.
      active = messages.count()

      staleMessages.unpersist(blocking = false)
      updated.unpersist(blocking = false)
      previous.unpersistVertices(blocking = false)
      previous.edges.unpersist(blocking = false)

      halfStep += 1
    }

    // The last batch of messages is never consumed.
    messages.unpersist(blocking = false)
    current
  }
}