为什么要优化
如果一个算法的复杂度是 f ( n ) f(n) f ( n ) ,那么实际耗时往往是 T ( n ) = k × f ( n ) T(n)=k \times f(n) T ( n ) = k × f ( n ) , k k k 就是代码的常数,但常数过大的时候就有可能超时。有些出题人会有意无意地卡常数,这时候不会代码优化技巧,则有可能被卡常错失AC。另外有可能因为对底层机制的了解不够,而踩坑,导致非常数因素的大幅耗时,最终TLE。
有些出题人对算法是有要求的,那么有可能会刻意卡掉一些空间需求高的算法,那么需要学会计算内存空间和空间优化。
时间优化技巧
IO优化
从最基本的算法无关的优化点开始,IO优化。在输入输出的时候,底层都是会涉及IO,那么有IO操作的地方,必定有IO耗时,这个是操作系统的知识,这里不展开。
如果全代码都只用到了scanf, printf, gets, putchar这一系列的C语言函数,则不需要考虑IO优化。
如果使用到了C++的cin/cout,则有几个点需要注意。
在输入或输出上不可以将C和C++的IO库混用。例如:在输入上不可scanf, getchar 和 cin 混用;在输出上不可printf 和 cout 混用;但 scanf 和 cout 是可以混用的。
若使用到了C++的IO库,则需要在代码中添加以下3行代码。但由于添加这3行代码后,输出结果不会立刻出现在控制台,会导致本地调试的时候困难,所以推荐在本地编译的时候增加宏定义,来区分是否是本地环境。切记这个宏定义不可和OJ添加的宏定义重名。
1 2 3 ios_base::sync_with_stdio(false ); cin .tie(0 ); cout .tie(0 );
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 int main () {#ifdef ACM_LOCAL freopen("./data/std.in" , "r" , stdin ); #else ios_base::sync_with_stdio(false ); cin .tie(0 ); cout .tie(0 ); #endif #ifdef ACM_LOCAL auto start = clock(); #endif int t = 1 ; while (t--) solve(); #ifdef ACM_LOCAL auto end = clock(); cerr << "Run Time: " << double (end - start) / CLOCKS_PER_SEC << "s" << endl ; #endif return 0 ; }
由于C++的endl
包含了flush
,会导致输出流刷新。但由于内部机制,会使得输入流也发生同步行为,会导致IO耗时暴增,所以可以参考第二条,在非本地情况下将endl
通过宏定义的方式,替换成 '\n'
。
1 2 3 #ifndef ACM_LOCAL #define endl '\n' #endif
完整代码基础样板可以参考 https://github.com/happier233/ACM-Code/blob/master/00_头文件/00_Header.cpp
不建议在平时做题的时候经常使用快读模板,会养成坏习惯,在现场赛的时候往往没有时间去敲一份快读模板。正常情况下出题人不会硬卡IO的耗时,所以如果用快读模板过了题,但有可能去除快读还是会TLE,这说明你的算法没有达到正确的耗时要求。
在开启IO优化的情况下,可以完全保证C++的cin, cout耗时和C的scanf, printf的耗时持平。
内存访问优化
介绍
内存访问优化在于每次去读取特大的数组的时候,尽量保持连续的内存空间访问。
先举个代码例子,用来求一个二维矩阵上的一个问题的代码,为了防止逻辑过于简单导致编译器优化,所以采取了一些方法屏蔽了编译器优化。
测试代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 const int N = 10000 ;int a[N][N], b[N][N];int sum1 () { int s = 0 ; for (int i = 0 ; i < N; i++) { s += a[i][0 ]; for (int j = 1 ; j < N; j++) { a[i][j] += a[i][j - 1 ]; s += a[i][j]; } } return s; } int sum2 () { int s = 0 ; for (int i = 0 ; i < N; i++) { s += a[0 ][i]; for (int j = 1 ; j < N; j++) { a[j][i] += a[j - 1 ][i]; s += a[j][i]; } } return s; } void rnd () { srand(0 ); for (int i = 0 ; i < N; i++) { for (int j = i; j < N; j++) { b[i][j] = rand(); b[j][i] = b[i][j]; } } } void clone () { memcpy (b[0 ], a[0 ], N * N * sizeof (int )); } void solve () { rnd(); clock_t s, t; clone(); s = clock(); printf ("sum1: %d\n" , sum1()); t = clock(); printf ("sum1 clock delte: %lld\n" , ll(t - s)); clone(); s = clock(); printf ("sum2: %d\n" , sum2()); t = clock(); printf ("sum2 clock delte: %lld\n" , ll(t - s)); }
输出结果
1 2 3 4 sum1: -682505014 sum1 clock delte: 90491 sum2: -682505014 sum2 clock delte: 782007
分析
最终的输出结果可见产生了近10倍的耗时差距 ,这个耗时差距随着第二维增大而增大。接下来分析一下产生的原因。
由于同一个数组的内存是连续的,不论是一维还是二维还是三维。
为了便于计算,对int a[1024][1024]
,进行讨论那么如果a[0][0]
是在地址0x0000
,那么a[1][0]
的内存地址是0x1000
(1024 ∗ 4 1024*4 1 0 2 4 ∗ 4 ,一个int是4字节),a[2][0]
的内存地址是0x2000
。
因为CPU是有Cache机制的,为了降低访问内存的延时造成CPU空等,这个知识可以在计组和体系结构学到。所以在每次计算一次,需要访问4Kb的内存。CPU每次不是只读取一个int的内存,而是读取几百字节或者几k字节的,那么Cache命中率会大幅下降。但在 sum1
的写法中,CPU读取一次内存,可以进行几百次操作,可以大幅提高处理效率。
这个现象在写 dp 算法的时候尤其需要注意。不只是二维,在一维或者当维度变多的时候都适用。原则就是尽量访问靠近的内存。但不要刻意为了这个优化让代码变得过于复杂。
空间大小对时间的优化
先给出测试代码和输出结果,就是一个简单的对数组求和。在数组的计算过程中,由于可能会爆int,所以有人干脆会把数组开成long long,但这在复杂数据结构中会产生巨大的耗时常数影响。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 const int N = int (1e7 );const int MOD = int (1e9 + 7 );ll a[N]; int b[N];void init () { a[0 ] = 1 ; for (int i = 1 ; i < N; i++) b[i] = a[i] = a[i - 1 ] * i % MOD; } void solve () { init(); clock_t s, t; s = clock(); printf ("sum1: %d\n" , accumulate(a, a + N, 0 )); t = clock(); printf ("sum1 clock delte: %lld\n" , ll(t - s)); s = clock(); printf ("sum2: %d\n" , accumulate(b, b + N, 0 )); t = clock(); printf ("sum2 clock delte: %lld\n" , ll(t - s)); }
1 2 3 4 sum1: -1246626476 sum1 clock delte: 8125 sum2: -1246626477 sum2 clock delte: 3392
可见,对a的求和时间是b的两倍多 ,这就直接让常数产生了翻倍,虽然实际算法中,对单个数组的访问不会占到大比例,但确实会产生影响,如果数据的结果是不会超出int的,并且对自己代码的耗时不确定能否通过的,建议修改成int进行存储。
小技巧
这些小技巧优化有限,虽然不知道具体导致这些代码编译后的差异何在,但实测的确有效果,主要受到编译器优化影响。
取模连写
1 2 3 4 5 6 7 int a, b, c, mod; a = a * b % mod; (a *= b) %= mod; a = ((a * b) % mod) + c) % mod; ((a *= b) %= mod) += c) %= mod;
减少取模次数
假设 m o d = 1 0 9 + 7 mod=10^9+7 m o d = 1 0 9 + 7 , 有2个二维非负整数矩阵 a , b a, b a , b ,都是 n × m n \times m n × m 大小, n , m ∈ [ 1 , 1 0 4 ] n, m \in [1, 10^4] n , m ∈ [ 1 , 1 0 4 ] ,其中所有值都有 a i j , b i j ∈ [ 0 , 1 0 4 ] a_{ij}, b_{ij} \in [0, 10^4] a i j , b i j ∈ [ 0 , 1 0 4 ] ,求 $ \sum_{i=1}^n ((\sum_{j=1}^m {a_{ij}}) \times (\sum_{j=1}^m {b_{ij}})) \bmod MOD $
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 typedef long long ll;const int N = 10005 ;const int MOD = 1000000007 ;const ll MOD2 = 1l l * MOD * MOD;int a[N][N], b[N][N];int calc1 (int n, int m) { int s = 0 ; for (int i = 0 ; i < n; i++) { int sa = 0 , sb = 0 ; for (int j = 0 ; j < m; j++) { (sa += a[i][j]) %= MOD; (sb += a[i][j]) %= MOD; } int si = (1l l * sa * sb) % MOD; (s += si) %= MOD; } return s; } int calc2 (int n, int m) { ll s = 0 ; for (int i = 0 ; i < n; i++) { ll sa = 0 , sb = 0 ; for (int j = 0 ; j < m; j++) { sa += a[i][j]; sb += b[i][j]; } ll si = sa * sb; (s += si) %= MOD; } return s; } int calc3 (int n, int m) { ll s = 0 ; for (int i = 0 ; i < n; i++) { ll sa = 0 , sb = 0 ; for (int j = 0 ; j < m; j++) { sa += a[i][j]; sb += b[i][j]; } ll si = sa * sb; s += si; if (s >= MOD2) s -= MOD2; } return s % MOD; } int calc4 (int n, int m) { ll s = 0 ; for (int i = 0 ; i < n; i++) { ll sa = accumulate(a[i], a[i] + m, 0l l); ll sb = accumulate(b[i], b[i] + m, 0l l); ll si = sa * sb; s += si; if (s >= MOD2) s -= MOD2; } return s % MOD; }
可以看得出,在每层for里,calc1一共有4次取模,calc2只有1次取模。通过在线编译查看编译后的汇编代码,https://gcc.godbolt.org/ ,对比这2个代码的汇编,也可以确认calc1执行了4次取模。由于CPU对于取模的运算耗时是巨大的,所以降低取模次数,可以很好的降低常数。
一次优化原理如下,由于 $\sum_{j=1}^m {a_{ij}} $ 和 $ \sum_{j=1}^m {b_{ij}}$ 都在 [ 0 , 1 0 8 ] [0, 10^{8}] [ 0 , 1 0 8 ] 范围内,所以可以用 long long
存下,所以 s i = ( ∑ j = 1 m a i j ) × ( ∑ j = 1 m b i j ) ∈ [ 0 , 1 0 16 ] s_i = (\sum_{j=1}^m {a_{ij}}) \times (\sum_{j=1}^m {b_{ij}}) \in [0, 10^{16}] s i = ( ∑ j = 1 m a i j ) × ( ∑ j = 1 m b i j ) ∈ [ 0 , 1 0 1 6 ] 也可以用 long long
存下,但由于最终的 s s s 是在 [ 0 , 1 0 20 ] [0, 10^{20}] [ 0 , 1 0 2 0 ] 超出了 long long
的表达范围,所以需要对 s s s 的加法结果进行取模。
二次优化原理如下,由于实际long long
可以表达到 [ − 2 63 , 2 63 − 1 ] [-2^{63}, 2^{63}-1] [ − 2 6 3 , 2 6 3 − 1 ] 范围,2 63 2^{63} 2 6 3 约等于 9 ∗ 1 0 18 9*10^{18} 9 ∗ 1 0 1 8 。我们可以使用 M O D 2 MOD^2 M O D 2 作为 s s s 的模数,最终返回的时候再对 MOD 取模一次,正确性可以通过数论证明,这里不做展开,那么 s s s 只有大约 1 100 \frac{1}{100} 1 0 0 1 的概率实际需要进行取模,在平时题目中,这个概率是不确定的但往往是比较小的一个概率,那么使用 if 来判定然后做减法,往往会比取模效率提高非常多。
二次优化+偷懒原理如下,由于这种求和操作非常普通,但往往写起来需要时间,赛场上的任何一秒钟都是重要的,而且为了防止自己写错,可以使用C++ algorithm 中的 accumulate 函数来进行求和操作。accumulate 函数具体用法参考以下手册链接。
https://zh.cppreference.com/w/cpp/algorithm/accumulate
gcc内置函数
使用gcc编译器作为本地编译器的时候,可以考虑使用gcc内置函数解决一些小问题。https://www.cnblogs.com/liuzhanshan/p/6861596.html
另外gcc还有一个内置的扩展stl库,叫pb_ds
,里面的rbtree可以解决stl中set无法解决的一些问题,在不需要刻意降低常数的情况下使用,避免自己需要额外编写数据结构的时间。
空间优化技巧
2倍空间线段树
在普通的线段树写法中,线段树空间都是需要开4倍才能保证正确性的,这里我不做展开证明。
但如果4倍空间开不下怎么办,那么我有一个2倍空间的线段树写法,并且可以数学证明正确性。
这种写法会带来 %5 ~ %10的性能损耗,但却可以有一个全新的方法访问线段树中的任意一个[l, r]节点,前提该线段树中存在该区间的节点。例如可以方便的访问叶子节点,具体有什么用途未知,但可以发挥想象,万一哪天就用到了呢。
传统线段树中用 k 作为当前节点的下标,线段树的根节点为1。
在优化后的线段树中,对于每个覆盖了[l, r]的区间来说,他的下标是 (l + r) | (l != r)
。**所以这种线段树覆盖的区间最左值必须是非负整数。**最好的覆盖区间是 [0, n] 或者 [1, n]。
这种情况下,下标最大的节点必定是最右端的叶子节点,也就是n的2倍。那么接下来证明线段树中的任意一个节点不可能发生seg1 [l, r]算出来的下标和 seg2 [x, y]重复。
分几种情况进行讨论:
若 r < x r < x r < x ,说明 seg1在seg2的左侧。那么seg1的最大ID就是当 $ l=r $ 或 l + 1 = r l + 1 = r l + 1 = r 的时候,I D 1 = 2 ∗ r ID_1=2*r I D 1 = 2 ∗ r 。那么seg2的最小ID就是当x = y x=y x = y 的时候,I D 2 = 2 ∗ x ID_2=2*x I D 2 = 2 ∗ x ,由于r < x r<x r < x ,所以I D 1 < I D 2 ID_1<ID_2 I D 1 < I D 2 。不重复
若 y < l y < l y < l ,说明恰好与情况1相反,不再论证。
由于线段树的性质,2个线段树的节点覆盖的区间不可能相交,要么没有交集,要么是包含关系。这里就讨论seg1包含seg2的情况。若seg1包含seg2,则l ≤ x ≤ y ≤ r l \le x \le y \le r l ≤ x ≤ y ≤ r ,并且线段树具有2分的性质,要么seg2在seg1的左半区间,要么在seg1的右半区间。后面2种情况分别列举左右区间问题。
若seg2在seg1的左半区间,则m a x ( I D 2 ) = 2 ∗ y max(ID_2)=2*y m a x ( I D 2 ) = 2 ∗ y ,y ≤ ⌊ l + r 2 ⌋ y \le \lfloor \frac{l + r}{2} \rfloor y ≤ ⌊ 2 l + r ⌋ 。当 l + r l + r l + r 是偶数的时候,I D 1 = l + r + 1 , I D 2 = 2 ∗ y = l + r , I D 1 > I D 2 ID_1=l+r+1,ID_2=2*y=l+r,ID_1>ID_2 I D 1 = l + r + 1 , I D 2 = 2 ∗ y = l + r , I D 1 > I D 2 ;当 l + r l + r l + r 是奇数的时候,I D 1 = l + r , I D 2 = 2 ∗ y = l + r − 1 , I D 1 > I D 2 ID_1=l+r,ID_2=2*y=l+r-1,ID_1>ID_2 I D 1 = l + r , I D 2 = 2 ∗ y = l + r − 1 , I D 1 > I D 2 。
若seg2在seg1的右半区间,则m i n ( I D 2 ) = 2 ∗ x min(ID_2)=2*x m i n ( I D 2 ) = 2 ∗ x ,x > ⌊ l + r 2 ⌋ x > \lfloor \frac{l + r}{2} \rfloor x > ⌊ 2 l + r ⌋ 。当 l + r l + r l + r 是偶数的时候,I D 1 = l + r + 1 , I D 2 = 2 ∗ x = l + r + 2 , I D 1 < I D 2 ID_1=l+r+1,ID_2=2*x=l+r+2,ID_1<ID_2 I D 1 = l + r + 1 , I D 2 = 2 ∗ x = l + r + 2 , I D 1 < I D 2 ;当 l + r l + r l + r 是奇数的时候,I D 1 = l + r , I D 2 = 2 ∗ x = l + r + 1 , I D 1 < I D 2 ID_1=l+r,ID_2=2*x=l+r+1,ID_1<ID_2 I D 1 = l + r , I D 2 = 2 ∗ x = l + r + 1 , I D 1 < I D 2 。
所以不论什么情况下,都不可能出现 I D ID I D 重复,由于线段树的性质,当覆盖的区间长度为n,那么实际节点为2 ∗ n − 1 2*n-1 2 ∗ n − 1 个,所以这种写法的线段树可以充分使用数组空间。
普通线段树区间修改求和模板:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 const int N = 100005 ;int a[N];struct SegTree { ll sum[N << 1 ], laz[N << 1 ]; int get (int l, int r) { return (l + r) | (l != r); } void up (int l, int r) { int mid = (l + r) >> 1 ; sum[get(l, r)] = sum[get(l, mid)] + sum[get(mid + 1 , r)]; } void push (int l, int r) { if (laz[get(l, r)] == 0 ) return ; int mid = (l + r) >> 1 ; laz[get(l, mid)] += laz[get(l, r)]; sum[get(l, mid)] += laz[get(l, r)] * (mid - l + 1 ); laz[get(mid + 1 , r)] += laz[get(l, r)]; sum[get(mid + 1 , r)] += laz[get(l, r)] * (r - mid); laz[get(l, r)] = 0 ; } void build (int l, int r) { laz[get(l, r)] = 0 ; if (l == r) { a[get(l, r)] = a[l]; return ; } int mid = (l + r) >> 1 ; build(l, mid); build(mid + 1 , r); up(l, r); } void update (int l, int r, int x, int y, ll w) { if (l == x && r == y) { sum[get(l, r)] += (l - r + 1 ) * w; laz[get(l, r)] += w; return ; } push(l, r); int mid = (l + r) >> 1 ; if (y <= mid) { update(l, mid, x, y, w); } else if (x >= mid) { update(mid + 1 , r, x, y, w); } else { update(l, mid, x, mid, w); update(mid + 1 , r, mid + 1 , y, w); } up(l, r); } ll query (int l, int r, int x, int y) { if (l == x && r == y) { return sum[get(l, r)]; } push(l, r); int mid = (l + r) >> 1 ; if (y <= mid) { return query(l, mid, x, y); } else if (x >= mid) { return query(mid + 1 , r, x, y); } else { return query(l, mid, x, mid) + query(mid + 1 , r, mid + 1 , y); } } };
再给出一种偷懒写法,虽然我自己不太常用这种偷懒写法,但有时候的确省时间:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 const int N = 100005 ;int a[N];struct SegTree {#define smid ((l + r) >> 1) #define self get(l, r) #define lson get(l, smid) #define rson get(smid + 1, r) ll sum[N << 1 ], laz[N << 1 ]; int get (int l, int r) { return (l + r) | (l != r); } void up (int l, int r) { sum[self] = sum[lson] + sum[rson]; } void push (int l, int r) { if (laz[self] == 0 ) return ; int mid = (l + r) >> 1 ; laz[lson] += laz[self]; sum[lson] += laz[self] * (mid - l + 1 ); laz[rson] += laz[self]; sum[rson] += laz[self] * (r - mid); laz[self] = 0 ; } void build (int l, int r) { laz[self] = 0 ; if (l == r) { a[self] = a[l]; return ; } int mid = (l + r) >> 1 ; build(l, mid); build(mid + 1 , r); up(l, r); } void update (int l, int r, int x, int y, ll w) { if (l == x && r == y) { sum[self] += (l - r + 1 ) * w; laz[self] += w; return ; } push(l, r); int mid = (l + r) >> 1 ; if (y <= mid) { update(l, mid, x, y, w); } else if (x >= mid) { update(mid + 1 , r, x, y, w); } else { update(l, mid, x, mid, w); update(mid + 1 , r, mid + 1 , y, w); } up(l, r); } ll query (int l, int r, int x, int y) { if (l == x && r == y) { return sum[self]; } push(l, r); int mid = (l + r) >> 1 ; if (y <= mid) { return query(l, mid, x, y); } else if (x >= mid) { return query(mid + 1 , r, x, y); } else { return query(l, mid, x, mid) + query(mid + 1 , r, mid + 1 , y); } } };