Source code for prody.utilities.laptools

"""This module defines functions for solving linear assignment problems."""

from numpy import arange, insert, delete, zeros, unique, vstack, ceil, argmin, tile
from scipy.optimize import linear_sum_assignment as lap

__all__ = ['multilap', 'SolutionDepletionException']

class SolutionDepletionException(Exception):
    pass

def expand_nodes(node):
    includes0, excludes0, solution = node

    pairs = []
    R, C = solution
    for r, c in zip(R, C):
        pairs.append((r, c))

    nodes = []; previous_pairs = []
    for i, pair in enumerate(pairs):
        if i == len(pairs) - 1:
            continue
        if pair in includes0:
            continue
        if pair in excludes0: # unlikely
            raise ValueError('%s should be excluded'%str(pair))
        
        includes = list(includes0)
        includes.extend(previous_pairs)

        excludes = list(excludes0)
        excludes.append(pair)

        nodes.append([includes, excludes, None])
        previous_pairs.append(pair)

    return nodes
    
[docs]def multilap(cost_matrix, nodes=[], BIG_NUMBER=1e6): """ Finds the (next) optimal solution to the linear assignment problem. The function can handle the cases where each row can be assigned to multiple columns. .. [KM68] Murty KG. Letter to the editor-An algorithm for ranking all the assignments in order of increasing cost. *Operations research* **1968** 16(3):682-687. """ m, n = cost_matrix.shape row_labels = arange(m) # make duplicates for alternative assignments n_copies = int(ceil(float(n) / m)) if n_copies > 1: cost_matrix = vstack((cost_matrix,)*n_copies) row_labels = tile(row_labels, n_copies) else: cost_matrix = cost_matrix.copy() if len(nodes): M, N = cost_matrix.shape costs = [] for node in nodes: includes, excludes, assignments = node if assignments is None: D = cost_matrix.copy() R = arange(M) C = arange(N) # mask excludes for r, c in excludes: r_ = row_labels[r] D[row_labels==r_, c] = BIG_NUMBER # remove includes delR = []; delC = [] for r, c in includes: delR.append(r) delC.append(c) if delC: D = delete(D, delC, axis=1) C = delete(C, delC) if delR: D = delete(D, delR, axis=0) R = delete(R, delR) # solve lap I, J = lap(D) cost = D[I, J].sum() if includes: cost += cost_matrix[delR, delC].sum() row_ind = R[I]; col_ind = C[J] # add back includes if includes: row_ind = insert(row_ind, 0, delR) col_ind = insert(col_ind, 0, delC) assignments = node[-1] = (row_ind, col_ind) else: R, C = assignments cost = cost_matrix[R, C].sum() costs.append(cost) # remove assignments with violations. -0.01 is to avoid precision problem for i in reversed(range(len(costs))): if costs[i] - BIG_NUMBER >= -0.01: costs.pop(i) nodes.pop(i) if len(costs) == 0: raise SolutionDepletionException('solution depleted') i = argmin(costs) node = nodes.pop(i) R, C = node[-1] else: R, C = lap(cost_matrix) node = ([], [], (R, C)) cost = cost_matrix[R, C].sum() new_nodes = expand_nodes(node) nodes.extend(new_nodes) R_ = row_labels[R] return (R_, C), gen_mappings((R_, C))
def gen_mappings(assignments): from itertools import product as iproduct I, J = assignments m = len(unique(I)) M = I.max() + 1 if m == len(I): mappings = [(I, J)] else: pool = [[] for _ in range(M)] for i, j in zip(I, J): pool[i].append((i, j)) mappings = [] for mapping in iproduct(*pool): r = zeros(len(mapping), dtype=int) c = zeros(len(mapping), dtype=int) for i, pair in enumerate(mapping): r[i] = pair[0] c[i] = pair[1] mappings.append((r, c)) return mappings