2017-09-06 76 views
3

下面给出的设置数据:反向笛卡尔乘积

a | b | c | d 
1 | 3 | 7 | 11 
1 | 5 | 7 | 11 
1 | 3 | 8 | 11 
1 | 5 | 8 | 11 
1 | 6 | 8 | 11 

执行反向笛卡尔乘积得到:

a | b | c | d 
1 | 3,5 | 7,8 | 11 
1 | 6 | 8 | 11 

我目前使用Scala的工作,和我的输入/输出数据类型是目前:

ListBuffer[Array[Array[Int]]] 

我想出了一个解决方案(见下文),但觉得它可以优化。我愿意优化我的方法和全新的方法。在scala和c#中的解决方案是首选。

我也很好奇,如果这可以在MS SQL中完成。

我目前的解决方案:

def main(args: Array[String]): Unit = { 

    // Input 
    val data = ListBuffer(Array(Array(1), Array(3), Array(7), Array(11)), 
          Array(Array(1), Array(5), Array(7), Array(11)), 
          Array(Array(1), Array(3), Array(8), Array(11)), 
          Array(Array(1), Array(5), Array(8), Array(11)), 
          Array(Array(1), Array(6), Array(8), Array(11))) 

    reverseCartesianProduct(data) 
} 

def reverseCartesianProduct(input: ListBuffer[Array[Array[Int]]]): ListBuffer[Array[Array[Int]]] = { 
    val startIndex = input(0).size - 1 

    var results:ListBuffer[Array[Array[Int]]] = input 

    for (i <- startIndex to 0 by -1) { 
     results = groupForward(results, i, startIndex) 
    } 

    results 
} 

def groupForward(input: ListBuffer[Array[Array[Int]]], groupingIndex: Int, startIndex: Int): ListBuffer[Array[Array[Int]]] = { 

    if (startIndex < 0) { 
     val reduced = input.reduce((a, b) => { 
     mergeRows(a, b) 
     }) 

     return ListBuffer(reduced) 
    } 

    val grouped = if (startIndex == groupingIndex) { 
     Map(0 -> input) 
    } 
    else { 
     groupOnIndex(input, startIndex) 
    } 

    val results = grouped.flatMap{ 
     case (index, values: ListBuffer[Array[Array[Int]]]) => 
     groupForward(values, groupingIndex, startIndex - 1) 
    } 

    results.to[ListBuffer] 
    } 

    def groupOnIndex(list: ListBuffer[Array[Array[Int]]], index: Int): Map[Int, ListBuffer[Array[Array[Int]]]] = { 

    var results = Map[Int, ListBuffer[Array[Array[Int]]]]() 

    list.foreach(a => { 
     val key = a(index).toList.hashCode() 

     if (!results.contains(key)) { 
     results += (key -> ListBuffer[Array[Array[Int]]]()) 
     } 

     results(key) += a 
    }) 

    results 
    } 

    def mergeRows(a: Array[Array[Int]], b: Array[Array[Int]]): Array[Array[Int]] = { 

    val zipped = a.zip(b) 

    val merged = zipped.map{ case (array1: Array[Int], array2: Array[Int]) => 
     val m = array1 ++ array2 

     quickSort(m) 

     m.distinct 
     .array 
    } 

    merged 
    } 

其工作原理是:

  1. 遍历列,从右到左(该groupingIndex指定上运行其列此列是唯一一个为了合并行而不必具有彼此相等的值)。
  2. 对所有其他列(不是groupingIndex)上的数据进行递归分组。
  3. 将所有列分组后,假定每个组中的数据在除分组列以外的每列中都有相同的值。
  4. 合并具有匹配列的行。为每一列取不同的值并对每一列进行排序。

我很抱歉,如果这些没有意义,我的大脑今天就不能运作。

+0

答案必须是(1 | 3,5 | 7,8 | 11)联合(1 | 6 | 8 | 11)还是与其他答案同样好?ie(1 | 3,5 | 7 | 11)union(1 | 3,5,6 | 8 | 11),只要所有行都被覆盖一次?要找到最佳答案实际上是一项非常艰巨的任务,np-hard,请在这里查看答案:https://cs.stackexchange.com/questions/87247/reverse-cartesian-product-matching-all-given-rows – jbilander

回答

0

这是我对此的看法。代码使用Java,但可以轻松转换为Scala或C#。

我在n-1的所有组合上运行groupingBy,并且选择计数最低的那一个,这意味着最大的合并深度,所以这是一种贪婪的方法。但不能保证你会找到最佳的解决方案,这意味着尽量减少k这是np-hard要做的,请参阅链接here的解释,但你会找到一个有效的解决方案,并做得相当快。

完整的示例在这里:https://github.com/jbilander/ReverseCartesianProduct/tree/master/src

Main.java

import java.util.*; 
    import java.util.stream.Collectors; 

    public class Main { 

     public static void main(String[] args) { 

      List<List<Integer>> data = List.of(List.of(1, 3, 7, 11), List.of(1, 5, 7, 11), List.of(1, 3, 8, 11), List.of(1, 5, 8, 11), List.of(1, 6, 8, 11)); 
      boolean done = false; 
      int rowLength = data.get(0).size(); //4 
      List<Table> tables = new ArrayList<>(); 

      // load data into table 
      for (List<Integer> integerList : data) { 

       Table table = new Table(rowLength); 
       tables.add(table); 

       for (int i = 0; i < integerList.size(); i++) { 
        table.getMap().get(i + 1).add(integerList.get(i)); 
       } 
      } 

      // keep track of count, needed so we know when to stop iterating 
      int numberOfRecords = tables.size(); 

      // start algorithm 
      while (!done) { 

       Collection<List<Table>> result = getMinimumGroupByResult(tables, rowLength); 

       if (result.size() < numberOfRecords) { 

        tables.clear(); 

        for (List<Table> tableList : result) { 

         Table t = new Table(rowLength); 
         tables.add(t); 

         for (Table table : tableList) { 
          for (int i = 1; i <= rowLength; i++) { 
           t.getMap().get(i).addAll(table.getMap().get(i)); 
          } 
         } 
        } 
        numberOfRecords = tables.size(); 
       } else { 
        done = true; 
       } 
      } 

      tables.forEach(System.out::println); 
     } 

     private static Collection<List<Table>> getMinimumGroupByResult(List<Table> tables, int rowLength) { 

      Collection<List<Table>> result = null; 
      int min = Integer.MAX_VALUE; 

      for (List<Integer> keyCombination : getKeyCombinations(rowLength)) { 

       switch (rowLength) { 

        case 4: { 
         Map<Tuple3<TreeSet<Integer>, TreeSet<Integer>, TreeSet<Integer>>, List<Table>> map = 
           tables.stream().collect(Collectors.groupingBy(t -> new Tuple3<>(
             t.getMap().get(keyCombination.get(0)), 
             t.getMap().get(keyCombination.get(1)), 
             t.getMap().get(keyCombination.get(2)) 
           ))); 
         if (map.size() < min) { 
          min = map.size(); 
          result = map.values(); 
         } 
        } 
        break; 
        case 5: { 
         //TODO: Handle n = 5 
        } 
        break; 
        case 6: { 
         //TODO: Handle n = 6 
        } 
        break; 
       } 
      } 

      return result; 
     } 

     private static List<List<Integer>> getKeyCombinations(int rowLength) { 

      switch (rowLength) { 
       case 4: 
        return List.of(List.of(1, 2, 3), List.of(1, 2, 4), List.of(2, 3, 4), List.of(1, 3, 4)); 

       //TODO: handle n = 5, n = 6, etc... 
      } 

      return List.of(List.of()); 
     } 
    } 

输出的tables.forEach(System.out::println)

Table{1=[1], 2=[3, 5, 6], 3=[8], 4=[11]} 
    Table{1=[1], 2=[3, 5], 3=[7], 4=[11]} 

或改写可读性:

 a | b | c | d 
    --|-------|---|--- 
    1 | 3,5,6 | 8 | 11 
    1 | 3,5 | 7 | 11 

如果你要在sql(mysql)中完成所有这些,你可以使用group_concat(),我认为MS SQL在这里有类似的内容:simulating-group-concatSTRING_AGG如果SQL Server 2017,但我认为你将不得不使用文本列在这种情况下有点讨厌:

eg

create table my_table (A varchar(50) not null, B varchar(50) not null, 
          C varchar(50) not null, D varchar(50) not null); 

    insert into my_table values ('1','3,5','4,15','11'), ('1','3,5','3,10','11'); 

    select A, B, group_concat(C order by C) as C, D from my_table group by A, B, D; 

会给结果的下方,所以你必须分析和排序,并更新逗号分隔的结果对任何未来合并迭代(按组)是正确的。

['1', '3,5', '3,10,4,15', '11']