为什么在一些状况下,Python 中的简略乘/除法比位运算要快?

首先秉持着捕风捉影的精力,咱们先来验证一下

In [33]: %timeit 1073741825*2                                                                                                                                                                                                                                                                           7.47 ns ± 0.0843 ns per loop (mean ± std. dev. of 7 runs, 100000000 loops each)In [34]: %timeit 1073741825<<1                                                                                                                                                                                                                                                                          7.43 ns ± 0.0451 ns per loop (mean ± std. dev. of 7 runs, 100000000 loops each)In [35]: %timeit 1073741823<<1                                                                                                                                                                                                                                                                          7.48 ns ± 0.0621 ns per loop (mean ± std. dev. of 7 runs, 100000000 loops each)In [37]: %timeit 1073741823*2                                                                                                                                                                                                                                                                           7.47 ns ± 0.0564 ns per loop (mean ± std. dev. of 7 runs, 100000000 loops each)

咱们发现几个很乏味的景象

  • 在值 x<=2^30 时,乘法比间接位运算要快
  • 在值 x>2^32 时,乘法显著慢于位运算

这个景象很乏味,那么这个景象的 root cause 是什么?实际上这和 Python 底层的实现无关

简略聊聊
PyLongObject 的实现
在 Python 2.x 期间,Python 中将整型分为两类,一类是 long, 一类是 int 。在 Python3 中这两者进行了合并。目前在 Python3 中这两者做了合并,仅剩一个 long

首先来看看 long 这样一个数据结构底层的实现

struct _longobject {    PyObject_VAR_HEAD    digit ob_digit[1];};

在这里不必关怀,PyObject_VAR_HEAD 的含意,咱们只须要关怀 ob_digit 即可。

在这里,ob_digit 是应用了 C99 中的“柔性数组”来实现任意长度的整数的存储。这里咱们能够看一下官网代码中的文档

Long integer representation.The absolute value of a number is equal to SUM(for i=0 through abs(ob_size)-1) ob_digit[i] 2(SHIFTi) Negative numbers are represented with ob_size < 0; zero is represented by ob_size == 0. In a normalized number, ob_digit[abs(ob_size)-1] (the most significant digit) is never zero. Also, in all cases, for all valid i,0 <= ob_digit[i] <= MASK. The allocation function takes care of allocating extra memory so that ob_digit[0] ... ob_digit[abs(ob_size)-1] are actually available. CAUTION: Generic code manipulating subtypes of PyVarObject has to aware that ints abuse ob_size's sign bit.

简而言之,Python 是将一个十进制数转为 2^(SHIFT) 进制数来进行存储。这里可能不太好了了解。我来举个例子,在我的电脑上,SHIFT 为 30 ,假如当初有整数 1152921506754330628 ,那么将其转为 2^30 进制示意则为: 4(2^30)^0+2(2^30)^1+1*(2^30)^2 。那么此时 ob_digit 是一个含有三个元素的数组,其值为 [4,2,1]

OK,在明确了这样一些基础知识后,咱们回过头去看看 Python 中的乘法运算

Python 中的乘法运算
Python 中的乘法运算,分为两局部,其中对于大数的乘法,Python 应用了 Karatsuba 算法1,具体实现如下

static PyLongObject *k_mul(PyLongObject *a, PyLongObject *b){    Py_ssize_t asize = Py_ABS(Py_SIZE(a));    Py_ssize_t bsize = Py_ABS(Py_SIZE(b));    PyLongObject *ah = NULL;    PyLongObject *al = NULL;    PyLongObject *bh = NULL;    PyLongObject *bl = NULL;    PyLongObject *ret = NULL;    PyLongObject *t1, *t2, *t3;    Py_ssize_t shift;           /* the number of digits we split off */    Py_ssize_t i;    /* (ah*X+al)(bh*X+bl) = ah*bh*X*X + (ah*bl + al*bh)*X + al*bl     * Let k = (ah+al)*(bh+bl) = ah*bl + al*bh  + ah*bh + al*bl     * Then the original product is     *     ah*bh*X*X + (k - ah*bh - al*bl)*X + al*bl     * By picking X to be a power of 2, "*X" is just shifting, and it's     * been reduced to 3 multiplies on numbers half the size.     */    /* We want to split based on the larger number; fiddle so that b     * is largest.     */    if (asize > bsize) {        t1 = a;        a = b;        b = t1;        i = asize;        asize = bsize;        bsize = i;    }    /* Use gradeschool math when either number is too small. */    i = a == b ? KARATSUBA_SQUARE_CUTOFF : KARATSUBA_CUTOFF;    if (asize <= i) {        if (asize == 0)            return (PyLongObject *)PyLong_FromLong(0);        else            return x_mul(a, b);    }    /* If a is small compared to b, splitting on b gives a degenerate     * case with ah==0, and Karatsuba may be (even much) less efficient     * than "grade school" then.  However, we can still win, by viewing     * b as a string of "big digits", each of width a->ob_size.  That     * leads to a sequence of balanced calls to k_mul.     */    if (2 * asize <= bsize)        return k_lopsided_mul(a, b);    /* Split a & b into hi & lo pieces. */    shift = bsize >> 1;    if (kmul_split(a, shift, &ah, &al) < 0) goto fail;    assert(Py_SIZE(ah) > 0);            /* the split isn't degenerate */    if (a == b) {        bh = ah;        bl = al;        Py_INCREF(bh);        Py_INCREF(bl);    }    else if (kmul_split(b, shift, &bh, &bl) < 0) goto fail;    /* The plan:     * 1. Allocate result space (asize + bsize digits:  that's always     *    enough).     * 2. Compute ah*bh, and copy into result at 2*shift.     * 3. Compute al*bl, and copy into result at 0.  Note that this     *    can't overlap with #2.     * 4. Subtract al*bl from the result, starting at shift.  This may     *    underflow (borrow out of the high digit), but we don't care:     *    we're effectively doing unsigned arithmetic mod     *    BASE**(sizea + sizeb), and so long as the *final* result fits,     *    borrows and carries out of the high digit can be ignored.     * 5. Subtract ah*bh from the result, starting at shift.     * 6. Compute (ah+al)*(bh+bl), and add it into the result starting     *    at shift.     */    /* 1. Allocate result space. */    ret = _PyLong_New(asize + bsize);    if (ret == NULL) goto fail;#ifdef Py_DEBUG    /* Fill with trash, to catch reference to uninitialized digits. */    memset(ret->ob_digit, 0xDF, Py_SIZE(ret) * sizeof(digit));#endif    /* 2. t1 <- ah*bh, and copy into high digits of result. */    if ((t1 = k_mul(ah, bh)) == NULL) goto fail;    assert(Py_SIZE(t1) >= 0);    assert(2*shift + Py_SIZE(t1) <= Py_SIZE(ret));    memcpy(ret->ob_digit + 2*shift, t1->ob_digit,           Py_SIZE(t1) * sizeof(digit));    /* Zero-out the digits higher than the ah*bh copy. */    i = Py_SIZE(ret) - 2*shift - Py_SIZE(t1);    if (i)        memset(ret->ob_digit + 2*shift + Py_SIZE(t1), 0,               i * sizeof(digit));    /* 3. t2 <- al*bl, and copy into the low digits. */    if ((t2 = k_mul(al, bl)) == NULL) {        Py_DECREF(t1);        goto fail;    }    assert(Py_SIZE(t2) >= 0);    assert(Py_SIZE(t2) <= 2*shift); /* no overlap with high digits */    memcpy(ret->ob_digit, t2->ob_digit, Py_SIZE(t2) * sizeof(digit));    /* Zero out remaining digits. */    i = 2*shift - Py_SIZE(t2);          /* number of uninitialized digits */    if (i)        memset(ret->ob_digit + Py_SIZE(t2), 0, i * sizeof(digit));    /* 4 & 5. Subtract ah*bh (t1) and al*bl (t2).  We do al*bl first     * because it's fresher in cache.     */    i = Py_SIZE(ret) - shift;  /* # digits after shift */    (void)v_isub(ret->ob_digit + shift, i, t2->ob_digit, Py_SIZE(t2));    Py_DECREF(t2);    (void)v_isub(ret->ob_digit + shift, i, t1->ob_digit, Py_SIZE(t1));    Py_DECREF(t1);    /* 6. t3 <- (ah+al)(bh+bl), and add into result. */    if ((t1 = x_add(ah, al)) == NULL) goto fail;    Py_DECREF(ah);    Py_DECREF(al);    ah = al = NULL;    if (a == b) {        t2 = t1;        Py_INCREF(t2);    }    else if ((t2 = x_add(bh, bl)) == NULL) {        Py_DECREF(t1);        goto fail;    }    Py_DECREF(bh);    Py_DECREF(bl);    bh = bl = NULL;    t3 = k_mul(t1, t2);    Py_DECREF(t1);    Py_DECREF(t2);    if (t3 == NULL) goto fail;    assert(Py_SIZE(t3) >= 0);    /* Add t3.  It's not obvious why we can't run out of room here.     * See the (*) comment after this function.     */    (void)v_iadd(ret->ob_digit + shift, i, t3->ob_digit, Py_SIZE(t3));    Py_DECREF(t3);    return long_normalize(ret);  fail:    Py_XDECREF(ret);    Py_XDECREF(ah);    Py_XDECREF(al);    Py_XDECREF(bh);    Py_XDECREF(bl);    return NULL;}

这里不对 Karatsuba 算法1 的实现做独自解释,有趣味的敌人能够参考文末的 reference 去理解具体的详情。

在一般状况下,一般乘法的工夫复杂度为 n^2 (n 为位数),而 K 算法的工夫复杂度为 3n^(log3) ≈ 3n^1.585 ,看起来 K 算法的性能要优于一般乘法,那么为什么 Python 不全副应用 K 算法呢?

很简略,K 算法的劣势实际上要在当 n 足够大的时候,才会对一般乘法造成劣势。同时思考到内存拜访等因素,当 n 不够大时,实际上采纳 K 算法的性能将差于间接进行乘法。

所以咱们来看看 Python 中乘法的实现

static PyObject *long_mul(PyLongObject *a, PyLongObject *b){    PyLongObject *z;    CHECK_BINOP(a, b);    /* fast path for single-digit multiplication */    if (Py_ABS(Py_SIZE(a)) <= 1 && Py_ABS(Py_SIZE(b)) <= 1) {        stwodigits v = (stwodigits)(MEDIUM_VALUE(a)) * MEDIUM_VALUE(b);        return PyLong_FromLongLong((long long)v);    }    z = k_mul(a, b);    /* Negate if exactly one of the inputs is negative. */    if (((Py_SIZE(a) ^ Py_SIZE(b)) < 0) && z) {        _PyLong_Negate(&z);        if (z == NULL)            return NULL;    }    return (PyObject *)z;}

在这里咱们看到,当两个数皆小于 2^30-1 时,Python 将间接应用一般乘法并返回,否则将应用 K 算法进行计算

这个时候,咱们来看一下位运算的实现,以右移为例

static PyObject *long_rshift(PyObject *a, PyObject *b){    Py_ssize_t wordshift;    digit remshift;    CHECK_BINOP(a, b);    if (Py_SIZE(b) < 0) {        PyErr_SetString(PyExc_ValueError, "negative shift count");        return NULL;    }    if (Py_SIZE(a) == 0) {        return PyLong_FromLong(0);    }    if (divmod_shift(b, &wordshift, &remshift) < 0)        return NULL;    return long_rshift1((PyLongObject *)a, wordshift, remshift);}static PyObject *long_rshift1(PyLongObject *a, Py_ssize_t wordshift, digit remshift){    PyLongObject *z = NULL;    Py_ssize_t newsize, hishift, i, j;    digit lomask, himask;    if (Py_SIZE(a) < 0) {        /* Right shifting negative numbers is harder */        PyLongObject *a1, *a2;        a1 = (PyLongObject *) long_invert(a);        if (a1 == NULL)            return NULL;        a2 = (PyLongObject *) long_rshift1(a1, wordshift, remshift);        Py_DECREF(a1);        if (a2 == NULL)            return NULL;        z = (PyLongObject *) long_invert(a2);        Py_DECREF(a2);    }    else {        newsize = Py_SIZE(a) - wordshift;        if (newsize <= 0)            return PyLong_FromLong(0);        hishift = PyLong_SHIFT - remshift;        lomask = ((digit)1 << hishift) - 1;        himask = PyLong_MASK ^ lomask;        z = _PyLong_New(newsize);        if (z == NULL)            return NULL;        for (i = 0, j = wordshift; i < newsize; i++, j++) {            z->ob_digit[i] = (a->ob_digit[j] >> remshift) & lomask;            if (i+1 < newsize)                z->ob_digit[i] |= (a->ob_digit[j+1] << hishift) & himask;        }        z = maybe_small_long(long_normalize(z));    }    return (PyObject *)z;}

在这里咱们能看到,在两侧都是小数的状况下,位挪动算法将比一般乘法,存在更多的内存调配等操作。这样也会答复了咱们文初所提到的一个问题,“为什么一些时候乘法比位运算更快”。

总结
本文差不多就到这里了,实际上通过这次剖析咱们能失去一些很乏味然而也很冷门的常识。实际上咱们目前看到这样一个后果,是 Python 对于咱们常见且高频的操作所做的一个特定的设计。而这也揭示咱们,Python 实际上对于很多操作都存在本人内建的设计哲学,在日常应用的时候,其余语言的教训,可能无奈复用

以上就是本次分享的所有内容,想要理解更多 python 常识欢送返回公众号:Python 编程学习圈 ,发送 “J” 即可收费获取,每日干货分享