Unverified Commit f348dd3d authored by Nano's avatar Nano Committed by GitHub
Browse files

Merge pull request #2491 from ShuYuMo2003/patch-4

字符串hash:修改评论中提出的问题
parents e64d6529 72938517
Loading
Loading
Loading
Loading
+59 −57
Original line number Diff line number Diff line
@@ -68,7 +68,7 @@ bool cmp(const string& s, const string& t) {

一般采取的方法是对整个字符串先预处理出每个前缀的哈希值,将哈希值看成一个 $b$ 进制的数对 $M$ 取模的结果,这样的话每次就能快速求出子串的哈希了:

令 $f_i(s)$ 表示 $f(s[1..i])$ ,那么 $f(s[l..r])=\frac{f_r(s)-f_{l-1}(s)}{b^{l-1}}$ ,其中 $\frac{1}{b^{l-1}}$ 可以预处理出来,用 [乘法逆元](../math/inverse.md) 或者是在比较哈希值时等式两边同时乘上 $b$ 的若干次方化为整式均可
令 $f_i(s)$ 表示 $f(s[1..i])$ ,那么 $f(s[l..r])=f_r(s)-f_{l-1}(s) \times b^{r-l+1}$ ,其中 $b^{r-l+1}$ 可以预处理出来。

这样的话,就可以在 $O(n)$ 的预处理后每次 $O(1)$ 地计算子串的哈希值了。

@@ -98,69 +98,71 @@ bool cmp(const string& s, const string& t) {
    
    ??? mdui-shadow-6 "参考代码"
        ```cpp
        #include <algorithm>
        #include <cstdio>
        #include <cstring>
        
        const int N = 1000010;
        const int m1 = 998244353;
        const int m2 = 1000001011;
        const int K = 233;
        
        typedef long long ll;
        
        int m, h1[N], h2[N], len, i1[N], i2[N];
        char s[N];
        
        void add(char x) {
          h1[len + 1] = ((ll)h1[len] * K + x) % m1;
          h2[len + 1] = ((ll)h2[len] * K + x) % m2;
          ++len;
        }
        
        int get1(int l, int r) { return (ll)(h1[r] - h1[l - 1] + m1) * i1[l - 1] % m1; }
        int get2(int l, int r) { return (ll)(h2[r] - h2[l - 1] + m2) * i2[l - 1] % m2; }
        
        bool cmp(int l1, int r1, int l2, int r2) {
          return get1(l1, r1) == get1(l2, r2) && get2(l1, r1) == get2(l2, r2);
        #include <iostream>
        #include <string>
        using namespace std;
        const int CN = 1e6 + 6;
        const int M1 = 11431471;
        const int B1 = 231;
        const int M2 = 37101101;
        const int B2 = 312;
        int read() {
          int s = 0, ne = 1;
          char c = getchar();
          while (c < '0' || c > '9') ne = c == '-' ? -1 : 1, c = getchar();
          while (c >= '0' && c <= '9') s = (s << 1) + (s << 3) + c - '0', c = getchar();
          return s * ne;
        }
        
        int qpow(int x, int y, int mod) {
          int out = 1;
          while (y) {
            if (y & 1) out = (ll)out * x % mod;
            x = (ll)x * x % mod;
            y >>= 1;
        int qp(int a, int b, int P) {
          int r = 1;
          while (b) {
            if (b & 1) r = 1ll * r * a % P;
            a = 1ll * a * a % P;
            b >>= 1;
          }
          return out;
          return r;
        }
        
        int main() {
          i1[0] = i2[0] = 1;
          int k1 = qpow(K, m1 - 2, m1);  // 求逆元
          int k2 = qpow(K, m2 - 2, m2);
          for (int i = 1; i < N; ++i) {
            i1[i] = (ll)i1[i - 1] * k1 % m1;
            i2[i] = (ll)i2[i - 1] * k2 % m2;
        int H1[CN], H2[CN], l1 = 0;
        void add1(int x) {
          H1[l1 + 1] = (1ll * H1[l1] * B1 % M1 + x) % M1,
                  H2[l1 + 1] = (1ll * H2[l1] * B2 % M2 + x) % M2;
          l1++;
        }
        
          scanf("%d", &m);
        
          while (m--) {
            scanf("%s", s + 1);
            int n = strlen(s + 1);
            for (int i = 1; i <= n; ++i) add(s[i]);
            // 先把当前串加到答案串的后面,可以方便地求哈希
            for (int i = std::min(n, len - n); i >= 0; --i) {
              if (cmp(len - n - i + 1, len - n, len - n + 1, len - n + i)) {
                len -= n;  // 确定了要加多长再真正地加进去
                for (int j = i + 1; j <= n; ++j) add(s[j]);
                printf("%s", s + i + 1);
                break;
        int h1[CN], h2[CN], l2 = 0;
        void add2(int x) {
          h1[l2 + 1] = (1ll * h1[l2] * B1 % M1 + x) % M1,
                  h2[l2 + 1] = (1ll * h2[l2] * B2 % M2 + x) % M2;
          l2++;
        }
        int get(int* h, int l, int r, int b, int m) {
          return 1ll * (h[r] - 1ll * h[l - 1] * qp(b, r - l + 1, m) % m + m) % m;
        }
        int n, len;
        char cur[CN], nxt[CN];
        int main() {
          n = read() - 1;
          cin >> cur;
          len = strlen(cur);
          for (int i = 0; i < len; i++) add1(cur[i] - '0');
          while (n--) {
            cin >> nxt;
            int l = strlen(nxt);
            for (int i = 0; i < l; i++) add2(nxt[i] - '0');
            int p = 0;
            for (int i = 0; i < l && i < len; i++) {
              int G1 = get(H1, len - i, len, B1, M1),
                  G2 = get(H2, len - i, len, B2, M2);
              int g1 = get(h1, 1, i + 1, B1, M1), g2 = get(h2, 1, i + 1, B2, M2);
              if (G1 == g1 && G2 == g2) p = i + 1;
            }
        
          return 0;
            for (int i = len; i < len - p + l; i++)
              cur[i] = nxt[i - len + p], add1(cur[i] - '0');
            len = len - p + l, cur[len] = '\0';
            l2 = 0;
          }
          cout << cur;
        }
        ```