ACM-ICPC 2010国内予選 F. 古い記憶

AOJで通ってしまったので、その記録を晒してみます。

F. 古い記憶

基本的な方針は以下の通りです。

  • ピースをつなぎ合わせて(深さ優先で)文字列を生成していく
  • 改ざん文書に対し編集距離d以下のものが出来たら答えに追加


さすがにナイーブに書いたら遅すぎたので、以下の高速化を行いました。
編集距離関係はちょっと根拠が怪しいので間違い等あればご指摘いただけると嬉しいです。

  • 重複するピースを削除
  • 既に生成した文字列をsetで管理
  • ピースのつなぎ位置を前計算
    • 各ピースiについて、iの最後k文字と先頭k文字が一致するピースjを列挙
    • 以下のコードでは(i, k)の組に対して、利用可能なjの集合を管理している
  • ピースをつないだ後の文字列長が(改ざん文書の文字列長+d)を超えたらカット
    • 明らかに編集距離はdを超える
  • 編集距離DPにおける計算削減
    • 今回は編集距離がd以下かどうかさえ分かれば良いので、マジメに計算しなくて良い
    • 具体的には、DP(i, j)でabs(i-j)>dとなる個所は適当に大きな値を仮定する
      • abs(i-j)>dならDP(i, j)>dで、ここから編集距離をd以下の状態にはできない
    • こうするとテーブル1行あたり(2*d+1)回の計算で済むのでO(nd)で判定できる
    • 各行について計算した範囲で値がd以下の場所が無ければ計算を終了
      • abs(i-j)>dのカットと同じような理由で
    • 2つの文字列s, tについて,DP(len(s), len(t))の値はd以下なら(たぶん)正しい値
  • ピース接続の過程で1文字追加ごとにDPテーブルを更新
    • 編集距離がd以下にならない事が判明した時点でチェックを止める
  • 文字列生成に関わる処理はなるべく減らす方向に


そんなこんなで出来たコードが以下のものになります。
手元の環境でジャッジデータが15秒くらい。
まだピース長が全部同じ理由とか、最低13文字ある理由とかが分かっていないですし、
編集距離計算のチートを除くとTLEしてしまうので、もうちょっと頑張れるんだろうと思います。


(追記:7/14 01:43)

以下の高速化を追加し、手元でジャッジデータが11秒、AOJで2.2秒となりました。

  • 2個前に使ったピースとオーバーラップする候補は考えない
    • 考えると1個前に使ったピースがいらない子になるので
  • ピースと文書の編集距離を使った枝刈り
    • ピースの後ろk文字について文書のすべての部分文字列との編集距離の最小値をとる
      • 編集距離計算のDPをちょっといじるとO(nm)で計算できる
    • 生成中の文にピースの後ろk文字を足すと編集距離は少なくともその値分は増加
      • それでdを超えるようなら考えなくても良い
    • そもそもピース単体でこの値がdを超えるピースを削っておく事もできる


以下のコードは追記前のままなので、追記分は反映されていません。

#include <iostream>
#include <string>
#include <vector>
#include <set>
#include <algorithm>

using namespace std;

const int INF = 100000000;

int d;
string mes, cur;
vector<string> piece;
vector<int> connect[30][20];
vector<string> ans;
set<string> mem[43];
int dp[43][43];

bool editDist(int s, int e){
    for(int i=s+1;i<=e;i++){
        int m = INF;
        for(int j=max(1,i-d);j<=min((int)mes.size(),i+d);j++){
            if(cur[i-1]==mes[j-1]) dp[i][j] = dp[i-1][j-1];
            else {
                dp[i][j] = dp[i-1][j-1]+1;
                dp[i][j] = min(dp[i][j], dp[i-1][j]+1);
                dp[i][j] = min(dp[i][j], dp[i][j-1]+1);
            }
            m = min(m, dp[i][j]);
        }
        if(m > d) return false;
    }
    return true;
}

void dfs(int pIdx){
    int len = piece[0].size();
    int csize = cur.size();
    if(dp[csize][mes.size()] <= d) ans.push_back(cur);
    for(int i=len-1;i>=0;i--){
        int nsize = csize+len-i;
        if(nsize > mes.size() + d) break;
        cur += ' ';
        for(int j=0;j<connect[pIdx][i].size();j++){
            bool flag = true;
            for(int k=i;k<len;k++){
                cur[csize+k-i] = piece[connect[pIdx][i][j]][k];
                if(!editDist(csize+k-i, csize+k-i+1)){
                    flag = false;
                    break;
                }
            }
            if(!flag) continue;
            if(mem[nsize].count(cur)) continue;
            mem[nsize].insert(cur);
            dfs(connect[pIdx][i][j]);
        }
    }
    cur = cur.substr(0, csize);
}

int main(){
    int n, len;
    for(int i=0;i<=42;i++) dp[i][0] = dp[0][i] = i;
    while(cin >> d >> n, n){
        cin >> mes;
        piece.clear();
        for(int i=0;i<n;i++){
            string str; cin >> str;
            piece.push_back(str);
        }
        sort(piece.begin(), piece.end());
        piece.erase(unique(piece.begin(), piece.end()), piece.end());
        len = piece[0].size();
        for(int i=1;i<=mes.size()+d;i++)
            for(int j=1;j<=mes.size()+d;j++)
                if(abs(i-j) > d) dp[i][j] = INF;
        for(int i=0;i<piece.size();i++){
            for(int j=0;j<len;j++){
                connect[i][j].clear();
                for(int k=0;k<piece.size();k++){
                    if(piece[i].substr(len-j) == piece[k].substr(0,j))
                        connect[i][j].push_back(k);
                }
            }
        }
        for(int i=0;i<=mes.size()+d;i++)
            mem[i].clear();
        ans.clear();
        for(int i=0;i<piece.size();i++){
            cur = piece[i];
            if(!editDist(0, len)) continue;
            dfs(i);
        }
        sort(ans.begin(), ans.end());
        cout << ans.size() << endl;
        if(ans.size() <= 5){
            for(int i=0;i<ans.size();i++)
                cout << ans[i] << endl;
        }
    }
}