/*
 * Implementation of the Agrawal-Kayal-Saxena (AKS) primality algorithm,
 * by Chris Calabro (ccalabro@cs.ucsd.edu).
 *
 * usage: primality n
 *   where n is an integer in the range [0, 2^31-1] (the range of a 32-bit
 *   signed int).  Values <= 1 are reported as not prime.
 */
#include <errno.h>
#include <limits.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include <vector>

/* If the number is found to be composite, this string says why. */
static char reason[256];

/*
 * Mathematical modulus.  Unlike x % y, the result is never negative: it is
 * the representative of x mod y in [0, y).  Assumes y > 0.
 */
static inline int big_mod(int x, int y)
{
    int z = x % y;
    return z < 0 ? z + y : z;
}

/*
 * Base-2 logarithm.  Deliberately not named log2: <math.h> already declares
 * log2 (C99 / C++11), and redefining it there does not compile.
 */
static inline double log_base2(double x)
{
    return log(x) / log(2.0);
}

/*
 * x^y for x >= 0 and y >= 0, computed by repeated squaring.
 * Returns -1 if the result does not fit in an int.
 * (Not named pow, to avoid clashing with pow from <math.h>.)
 */
static int int_pow(int x, int y)
{
    long long acc = 1, z = x;
    while (y) {
        /* y still has a set bit at or above this position, so the final
           result is at least z; if z is already too big, so is the result */
        if (z > INT_MAX)
            return -1;
        if (y & 1) {
            acc *= z; /* both factors <= INT_MAX, so this fits in long long */
            if (acc > INT_MAX)
                return -1;
        }
        z *= z;
        y >>= 1;
    }
    return (int)acc;
}

/*
 * x^y mod m for x >= 0, y >= 0, m >= 2.
 * All intermediate values are < m <= INT_MAX, so products fit in long long.
 */
static int modpow(int x, int y, int m)
{
    long long acc = 1, z = x % m;
    while (y) {
        if (y & 1)
            acc = (acc * z) % m;
        z = (z * z) % m;
        y >>= 1;
    }
    return (int)acc;
}

/*
 * Is n of the form a^b with b >= 2?  Assumes n >= 2.
 * On success, writes the decomposition into `reason`.
 */
static bool is_power(int n)
{
    /* One iteration per bit of n >> 1, so b runs from 2 to floor(log2 n) + 1.
       That covers every candidate exponent, since a >= 2 forces b <= log2 n. */
    int high = n;
    for (int b = 2, n_bits = n >> 1; n_bits; b++, n_bits >>= 1) {
        /* Find the b-th root of n by binary search.  `high` stays a valid
           upper bound from one b to the next because n^(1/b) >= n^(1/(b+1)). */
        int low = 1;

        /* loop invariant: low <= n^(1/b) <= high */
        while (high - low >= 2) {
            /* low + (high - low) / 2 rather than (high + low) / 2: the sum
               overflows int when n is close to INT_MAX */
            int mid = low + (high - low) / 2;
            int p = int_pow(mid, b);
            /* p == -1 signals overflow, i.e. mid^b > INT_MAX >= n */
            if (p >= 0 && p < n)
                low = mid;
            else
                high = mid;
        }

        /* now low <= n^(1/b) <= high and 0 <= high - low <= 1 */
        if (int_pow(low, b) == n) {
            snprintf(reason, sizeof reason, "%d = %d ^ %d", n, low, b);
            return true;
        }
        if (int_pow(high, b) == n) {
            snprintf(reason, sizeof reason, "%d = %d ^ %d", n, high, b);
            return true;
        }
    }
    return false;
}

/* Greatest common divisor by Euclid's algorithm.
   Assumes x, y >= 0 and (x, y) != (0, 0). */
static int gcd(int x, int y)
{
    while (y != 0) {
        int t = x % y; /* (x, y) = (y, x mod y) */
        x = y;
        y = t;
    }
    return x;
}

/* Trial-division primality test, used only for the small values of r.
   Returns false for n < 2. */
static bool is_prime_naive(int n)
{
    if (n < 2)
        return false;
    /* q <= n / q is q * q <= n without the risk of overflowing */
    for (int q = 2; q <= n / q; q++)
        if (n % q == 0)
            return false;
    return true;
}

/* Largest prime factor of n.  Returns 1 if n has no prime factor (n <= 1). */
static int largest_prime_factor(int n)
{
    int max = 1;
    /* strip prime factors in increasing order, so the last one found is
       the largest among those <= sqrt of what remains */
    for (int q = 2; q <= n / q; q++) {
        while (n % q == 0) {
            max = q;
            n /= q;
        }
    }
    /* whatever is left has no divisor up to its square root: it is a prime
       larger than every factor removed above */
    if (n >= 2)
        max = n;
    return max;
}

/*
 * Reduce the polynomial in buf mod (x^r - 1, n), in place.
 * Assumes buf has size s = 2r-1 and its coefficients are already in [0, n).
 * Uses x^i == x^(i-r) (mod x^r - 1) rather than general polynomial division.
 */
static void poly_reduce(int* buf, int n, int r, int s)
{
    for (int i = s - 1; i >= r; i--) {
        /* both terms are < n <= INT_MAX, but their sum can exceed INT_MAX
           when n > 2^30, so add in long long */
        buf[i - r] = (int)(((long long)buf[i - r] + buf[i]) % n);
        buf[i] = 0;
    }
}

/*
 * acc = buf1 * buf2 mod n, keeping only the coefficients of x^0 .. x^(s-1).
 * The result is exact only when deg(buf1) + deg(buf2) < s; callers keep both
 * factors reduced mod x^r - 1 (degree < r) so that holds for s = 2r-1.
 * Assumes acc does not overlap buf1 or buf2 (buf1 and buf2 may be the same).
 * Naive O(s^2) method, no FFT.
 */
static void poly_mul(int* acc, const int* buf1, const int* buf2, int n, int s)
{
    for (int i = 0; i < s; i++) {
        /* acc[i] = sum_{j=0}^{i} buf1[j] * buf2[i-j]  (mod n) */
        long long a = 0; /* kept in [0, n) so it can never overflow */
        for (int j = 0; j <= i; j++) {
            a += ((long long)buf1[j] * buf2[i - j]) % n;
            if (a >= n)
                a -= n;
        }
        acc[i] = (int)a;
    }
}

/* Coefficient-wise equality of two polynomials of size s. */
static bool poly_equal(const int* x, const int* y, int s)
{
    for (int i = 0; i < s; i++)
        if (x[i] != y[i])
            return false;
    return true;
}

/* For debugging: print the s coefficients of buf, lowest degree first. */
void poly_print(const int* buf, int s)
{
    for (int i = 0; i < s; i++)
        printf("%d ", buf[i]);
    printf("\n");
}

/* dst = src for polynomials of size s. */
static void poly_copy(int* dst, const int* src, int s)
{
    for (int i = 0; i < s; i++)
        dst[i] = src[i];
}

/*
 * Test whether (x - a)^n == x^n - a  (mod x^r - 1, n).
 * acc, buf1, buf2 are scratch buffers, each of size s = 2r-1.
 * Assumes n, r >= 2.  On failure, writes a message into `reason`.
 */
static bool poly_identity(int a, int n, int r, int s, int* acc, int* buf1, int* buf2)
{
    int ma = big_mod(-a, n); /* -a mod n */
    int i;

    /* acc = 1 */
    acc[0] = 1;
    for (i = 1; i < s; i++)
        acc[i] = 0;

    /* buf1 = x - a */
    buf1[0] = ma;
    buf1[1] = 1;
    for (i = 2; i < s; i++)
        buf1[i] = 0;

    /* (x - a)^n by square-and-multiply: while processing bit k of n,
       buf1 holds (x - a)^(2^k) reduced mod (x^r - 1, n). */
    for (int m = n; m; m >>= 1) {
        if (m & 1) {
            /* acc *= buf1 */
            poly_copy(buf2, acc, s);
            poly_mul(acc, buf1, buf2, n, s);
            poly_reduce(acc, n, r, s);
        }
        if (m > 1) { /* the square is not needed after the top bit */
            /* buf1 *= buf1.  The reduction is essential: poly_mul keeps
               only s = 2r-1 coefficients, which is exact only for factors
               of degree < r; an unreduced buf1 would lose terms. */
            poly_copy(buf2, buf1, s);
            poly_mul(buf1, buf2, buf2, n, s);
            poly_reduce(buf1, n, r, s);
        }
    }

    /* Expected value: x^n - a == x^(n mod r) - a  (mod x^r - 1, n).
       When n mod r == 0 this is 1 - a, which the += below folds into the
       constant coefficient. */
    for (i = 0; i < s; i++)
        buf2[i] = 0;
    buf2[n % r] = 1;
    buf2[0] = (buf2[0] + ma) % n;

    if (!poly_equal(acc, buf2, s)) {
        snprintf(reason, sizeof reason, "failed poly identity with a: %d r: %d", a, r);
        return false;
    }
    return true;
}

/*
 * AKS primality test.  Returns true iff n is prime.  When it returns false,
 * `reason` explains why n is not prime.
 */
bool is_prime(int n)
{
    if (n <= 1) {
        snprintf(reason, sizeof reason, "%d <= 1", n);
        return false;
    }
    if (is_power(n)) /* is_power has already filled in `reason` */
        return false;

    const double lg = log_base2(n);
    int r, q;

    /* Find a prime r such that r - 1 has a large prime factor q with
       n^((r-1)/q) != 1 (mod r), checking gcd(n, r) along the way. */
    for (r = 2; r < n; r++) {
        q = gcd(n, r);
        if (q != 1) {
            snprintf(reason, sizeof reason, "%d = %d * %d", n, q, n / q);
            return false;
        }
        if (is_prime_naive(r)) {
            q = largest_prime_factor(r - 1);
            if (q >= ceil(4 * sqrt((double)r) * lg) && modpow(n, (r - 1) / q, r) != 1)
                break;
        }
    }

    /* The search ran to n without breaking, so gcd(n, r) == 1 for every
       2 <= r < n: n has no proper divisor and is prime.  The polynomial
       identities below always hold for primes, so skipping them changes
       only the running time. */
    if (r >= n)
        return true;

    q = (int)ceil(2 * sqrt((double)r) * lg);
    const int s = 2 * r - 1;
    /* vectors release their memory on every return path */
    std::vector<int> acc(s), buf1(s), buf2(s);
    for (int a = 1; a <= q; a++) {
        if (!poly_identity(a, n, r, s, &acc[0], &buf1[0], &buf2[0]))
            return false;
    }
    return true;
}

/* Print usage to stderr and exit with status 1. */
static void usage()
{
    fprintf(stderr, "usage: primality n where n is an integer in the range [0, 2^31-1]\n");
    exit(1);
}

int main(int argc, char** argv)
{
    if (argc < 2)
        usage();

    /* strtol rather than atoi: atoi has undefined behavior on overflow and
       silently turns non-numeric input into 0 */
    char* end;
    errno = 0;
    long v = strtol(argv[1], &end, 10);
    if (end == argv[1] || *end != '\0' || errno == ERANGE || v < INT_MIN || v > INT_MAX)
        usage();
    int n = (int)v;

    if (is_prime(n))
        printf("%d is prime\n", n);
    else
        printf("%d is not prime: %s\n", n, reason);
    return 0;
}