社区讨论
求一点点神奇调优
P13013 [GESP202506 五级] 奖品兑换 · 参与者 2 · 已保存回复 2
讨论操作
快速查看讨论及其快照的属性,并进行相关操作。
- 当前回复
- 2 条
- 当前快照
- 1 份
- 快照标识符
- @mhjofbvd
- 此快照首次捕获于
- 2025/11/04 05:53 4 个月前
- 此快照最后确认于
- 2025/11/04 05:53 4 个月前
写了一份 AVX512 优化暴力的代码,目前极限数据要跑 1.4s 左右,有没有大佬能够帮忙优化下?
CPP
#include <immintrin.h>
#include <iostream>
#include <algorithm>
#include <cstdint>
using namespace std;
/// Computes a multiply-and-shift replacement for unsigned division by `d`.
///
/// After the call, for a dividend `x` the quotient `x / d` is obtained as
/// `(x * magic) >> (32 + base_shift)`. The caller adds the fixed 32 to
/// `base_shift` itself.
///
/// @param d           Divisor. Must be at least 1; it must also be at most
///                    2^31 so that `32 + base_shift` stays below 64.
/// @param magic       Receives ceil(2^(32 + base_shift) / d).
/// @param base_shift  Receives ceil(log2(d)), i.e. the bit width of `d - 1`.
///
/// The result is exact for dividends below 2^32, provided `x * magic` does
/// not overflow 64 bits; `magic` is below 2^33, so any `x` below 2^31 is safe.
void compute_magic(uint64_t d, uint64_t &magic, uint64_t &base_shift) {
    // No special case for d == 1: the general path yields base_shift = 0 and
    // magic = 2^32, which is what a caller shifting by 32 + base_shift needs.
    // (Returning magic = 1 here would make every quotient come out as 0.)
    base_shift = 0;
    for (uint64_t tmp = d - 1; tmp != 0; tmp >>= 1) {
        ++base_shift;
    }

    const uint64_t total_shift = 32 + base_shift;
    const uint64_t power = 1ULL << total_shift;
    // Round the reciprocal up so that truncation after the shift is exact.
    magic = (power + d - 1) / d;
}
/// Reads n, m, a, b and prints the maximum of x + y over all x in
/// [0, min(n / a, m / b)], where
///   y = min((n - a * x) / b, (m - b * x) / a).
///
/// The search is a brute force over x, evaluated eight candidates at a time
/// with AVX-512 (requires AVX512F and AVX512DQ for the 64-bit multiply).
/// Divisions are replaced by the multiply-and-shift constants produced by
/// compute_magic().
int main() {
    uint64_t n, m, a, b;
    // Reject unreadable input and zero divisors instead of dividing by zero
    // or reading indeterminate values.
    if (!(std::cin >> n >> m >> a >> b) || a == 0 || b == 0) {
        return 1;
    }

    const uint64_t limit = std::min(n / a, m / b);

    uint64_t a_magic, a_base_shift, b_magic, b_base_shift;
    compute_magic(a, a_magic, a_base_shift);
    compute_magic(b, b_magic, b_base_shift);
    const uint64_t a_total_shift = a_base_shift + 32;
    const uint64_t b_total_shift = b_base_shift + 32;

    // Number of 8-wide groups [8k, 8k + 7] that lie entirely inside [0, limit].
    const uint64_t full_blocks = (limit + 1) / 8;

    // Loop-invariant broadcasts, built once.
    const __m512i a_magic_vec = _mm512_set1_epi64(a_magic);
    const __m512i b_magic_vec = _mm512_set1_epi64(b_magic);
    const __m512i a_shift_vec = _mm512_set1_epi64(a_total_shift);
    const __m512i b_shift_vec = _mm512_set1_epi64(b_total_shift);
    const __m512i x_step_vec = _mm512_set1_epi64(8);
    const __m512i a_step_vec = _mm512_set1_epi64(8 * a);
    const __m512i b_step_vec = _mm512_set1_epi64(8 * b);

    // Lane k starts at x = k. The remainders n - a*x and m - b*x are then
    // updated incrementally, so no multiply by x is needed inside the loop.
    // These vectors are only consumed when full_blocks > 0, in which case
    // every lane satisfies x <= limit and the subtractions do not wrap.
    __m512i x_vec = _mm512_set_epi64(7, 6, 5, 4, 3, 2, 1, 0);
    __m512i n_remain_vec = _mm512_sub_epi64(
        _mm512_set1_epi64(n), _mm512_mullo_epi64(x_vec, _mm512_set1_epi64(a)));
    __m512i m_remain_vec = _mm512_sub_epi64(
        _mm512_set1_epi64(m), _mm512_mullo_epi64(x_vec, _mm512_set1_epi64(b)));

    // Per-lane running maximum; reduced to a scalar once after the loop.
    __m512i best_vec = _mm512_setzero_si512();

    for (uint64_t block = 0; block < full_blocks; ++block) {
        // y1 = (n - a*x) / b and y2 = (m - b*x) / a via multiply-and-shift.
        const __m512i y1_vec = _mm512_srlv_epi64(
            _mm512_mullo_epi64(n_remain_vec, b_magic_vec), b_shift_vec);
        const __m512i y2_vec = _mm512_srlv_epi64(
            _mm512_mullo_epi64(m_remain_vec, a_magic_vec), a_shift_vec);

        const __m512i total_vec =
            _mm512_add_epi64(x_vec, _mm512_min_epu64(y1_vec, y2_vec));
        best_vec = _mm512_max_epu64(best_vec, total_vec);

        // Advance all lanes by 8 candidates.
        x_vec = _mm512_add_epi64(x_vec, x_step_vec);
        n_remain_vec = _mm512_sub_epi64(n_remain_vec, a_step_vec);
        m_remain_vec = _mm512_sub_epi64(m_remain_vec, b_step_vec);
    }

    uint64_t lane_best[8];
    _mm512_storeu_si512(lane_best, best_vec);
    uint64_t best = 0;
    for (int lane = 0; lane < 8; ++lane) {
        if (lane_best[lane] > best) best = lane_best[lane];
    }

    // Scalar tail: the at most 7 candidates that do not fill a whole vector.
    for (uint64_t j = full_blocks * 8; j <= limit; ++j) {
        const uint64_t n_remain = n - a * j;
        const uint64_t m_remain = m - b * j;
        const uint64_t y1 = (n_remain * b_magic) >> b_total_shift;
        const uint64_t y2 = (m_remain * a_magic) >> a_total_shift;
        const uint64_t total = j + std::min(y1, y2);
        if (total > best) best = total;
    }

    std::cout << best << '\n';
    return 0;
}
回复
共 2 条回复,欢迎继续交流。
正在加载回复...