社区讨论

求一点点神奇调优

P13013[GESP202506 五级] 奖品兑换参与者 2已保存回复 2

讨论操作

快速查看讨论及其快照的属性,并进行相关操作。

当前回复
2 条
当前快照
1 份
快照标识符
@mhjofbvd
此快照首次捕获于
2025/11/04 05:53
4 个月前
此快照最后确认于
2025/11/04 05:53
4 个月前
查看原帖
写了一份 AVX512 优化暴力的代码,目前极限数据要跑 1.4s 左右,有没有大佬能够帮忙优化下?
CPP
#include <immintrin.h>
#include <iostream>
#include <algorithm>
#include <cstdint>

using namespace std;

void compute_magic(uint64_t d, uint64_t &magic, uint64_t &base_shift) {
    if (d == 1) {
        magic = 1;
        base_shift = 0;
        return;
    }
    base_shift = 0;
    uint64_t tmp = d - 1;
    while (tmp) {
        base_shift++;
        tmp >>= 1;
    }
    uint64_t total_shift = 32 + base_shift;
    uint64_t power = 1ULL << total_shift;
    magic = (power + d - 1) / d;
}

int main() {
    uint64_t n, m, a, b;
    cin >> n >> m >> a >> b;

    uint64_t limit1 = n / a;
    uint64_t limit2 = m / b;
    uint64_t limit = min(limit1, limit2);

    uint64_t a_magic, a_base_shift, b_magic, b_base_shift;
    compute_magic(a, a_magic, a_base_shift);
    uint64_t a_total_shift = a_base_shift + 32;
    compute_magic(b, b_magic, b_base_shift);
    uint64_t b_total_shift = b_base_shift + 32;

    uint64_t best = 0;

    for (uint64_t i = 0; i <= limit; i += 8) {
        if (i + 7 > limit) {
            for (uint64_t j = i; j <= limit; j++) {
                uint64_t n_remain = n - a * j;
                uint64_t y1 = (n_remain * b_magic) >> b_total_shift;
                uint64_t m_remain = m - b * j;
                uint64_t y2 = (m_remain * a_magic) >> a_total_shift;
                uint64_t y = min(y1, y2);
                uint64_t total = j + y;
                if (total > best) best = total;
            }
            break;
        }

        __m512i x_vec = _mm512_set_epi64(i+7, i+6, i+5, i+4, i+3, i+2, i+1, i);
        __m512i a_vec = _mm512_set1_epi64(a);
        __m512i ax_vec = _mm512_mullo_epi64(x_vec, a_vec);

        __m512i n_vec = _mm512_set1_epi64(n);
        __m512i n_remain_vec = _mm512_sub_epi64(n_vec, ax_vec);

        __m512i b_magic_vec = _mm512_set1_epi64(b_magic);
        __m512i product1 = _mm512_mullo_epi64(n_remain_vec, b_magic_vec);
        __m512i b_shift_vec = _mm512_set1_epi64(b_total_shift);
        __m512i y1_vec = _mm512_srlv_epi64(product1, b_shift_vec);

        __m512i b_vec = _mm512_set1_epi64(b);
        __m512i bx_vec = _mm512_mullo_epi64(x_vec, b_vec);

        __m512i m_vec = _mm512_set1_epi64(m);
        __m512i m_remain_vec = _mm512_sub_epi64(m_vec, bx_vec);

        __m512i a_magic_vec = _mm512_set1_epi64(a_magic);
        __m512i product2 = _mm512_mullo_epi64(m_remain_vec, a_magic_vec);
        __m512i a_shift_vec = _mm512_set1_epi64(a_total_shift);
        __m512i y2_vec = _mm512_srlv_epi64(product2, a_shift_vec);

        __m512i y_vec = _mm512_min_epu64(y1_vec, y2_vec);
        __m512i total_vec = _mm512_add_epi64(x_vec, y_vec);

        uint64_t total_arr[8];
        _mm512_storeu_si512((__m512i*)total_arr, total_vec);
        for (int j = 0; j < 8; j++) {
            if (total_arr[j] > best) best = total_arr[j];
        }
    }

    cout << best << endl;
    return 0;
}

回复

2 条回复,欢迎继续交流。

正在加载回复...