golang

Решение задачи с Route 256 на goalng

  • понедельник, 12 мая 2025 г. в 00:00:04
https://habr.com/ru/articles/908384/

В этой статье разбирается решение задачи «Гистограммы» с контеста Route 256 от Ozon.

Ссылочки:
Assembler в Go: техники ускорения и оптимизации / Хабр
Руководство по ассемблеру Go / Хабр
Часть 1. Почему Go-ассемблер и векторизация могут быть полезны: идея для ускорения / Хабр

Условие задачи

Гистограммой является массив, каждый элемент которого указывает высоту столбика на соответствующей позиции. Две гистограммы считаются совпадающими, если при совмещении одной гистограммы с другой гистограммой, повёрнутой на угол 180°, получается ровный прямоугольник без наложений и пропусков.

Иллюстрация

Пример теста

1 2 4
1 3 4
1 4 3

Ответ: 1 пара

Тривиальное решение:

func profile(a []uint32) string {
	n := len(a)
	res := make([]uint32, n-1)
	for i := 1; i < n; i++ {
		res[i-1] = a[i] - a[i-1]
	}

	b := unsafe.Slice((*byte)(unsafe.Pointer(&res[0])), (n-1)*4)
	return *(*string)(unsafe.Pointer(&b))
}

func reverseProfile(a []uint32) string {
	n := len(a)
	res := make([]uint32, n-1)
	for i := n - 1; i > 0; i-- {
		res[n-1-i] = a[i] - a[i-1]
	}

	b := unsafe.Slice((*byte)(unsafe.Pointer(&res[0])), (n-1)*4)
	return *(*string)(unsafe.Pointer(&b))
}

func processHistograms(histos [][]uint32) int {
	profCount := make(map[string]int)
	for i := 0; i < len(histos); i++ {
		p := profile(histos[i])
		profCount[p]++
	}

	res := 0
	for i := 0; i < len(histos); i++ {
		rp := reverseProfile(histos[i])
		res += profCount[rp]
		// если профиль и реверс совпадают, не считать пару (a,a)
		if profile(histos[i]) == rp {
			res--
		}
	}
	return res / 2
}

Мы вычисляем приросты гистограмм и находи соответствующие пары, приросты которых совпадают. Ничего сложного.

Это решение проходит тесты с контеста, но я решил попробовать ускорить вычисление с помощью SIMD.

Ускорить func profile довольно просто.

// func calcDifferences(src []int32, dst []int32)
// комменты от GPT-4.1
TEXT ·calcDifferences(SB), $0-48
    MOVQ src_base+0(FP),    SI  // Загружаем базовый адрес исходного массива
    MOVQ src_len+8(FP),     CX  // Загружаем длину исходного массива
    MOVQ dst_base+24(FP),   DI  // Загружаем базовый адрес массива назначения
    
    CMPQ CX, $1                 // Проверяем, если длина <= 1
    JLE  done                   // Если да, завершаем работу
    
    DECQ CX                     // Длина dst = длина src - 1
    XORQ AX, AX                 // Инициализируем счетчик индекса нулем

    // Проверяем, достаточно ли элементов для AVX обработки (минимум 8)
    CMPQ CX, $8
    JL   scalar_loop            // Если меньше 8, переходим к скалярной обработке

    // Вычисляем количество элементов для AVX обработки (кратно 8)
    MOVQ CX, DX
    ANDQ $-8, DX                // Округляем вниз до ближайшего кратного 8

avx_loop:
    VMOVDQU (SI)(AX*4), Y0      // Загружаем 8 элементов из src в YMM0
    VMOVDQU 4(SI)(AX*4), Y1     // Загружаем следующие 8 элементов (со сдвигом +1)
    VPSUBD Y0, Y1, Y2           // Вычисляем разницы (Y1 - Y0)
    VMOVDQU Y2, (DI)(AX*4)      // Сохраняем 8 результатов в dst
    
    ADDQ $8, AX                 // Увеличиваем индекс на 8
    CMPQ AX, DX                 // Сравниваем с количеством обработанных элементов
    JL   avx_loop               // Продолжаем цикл, если есть элементы для AVX обработки

    // Обрабатываем оставшиеся элементы скалярно
    CMPQ AX, CX
    JGE  done                   // Если все элементы обработаны, завершаем

scalar_loop:
    MOVL (SI)(AX*4), BX         // Загружаем src[i] в регистр BX
    MOVL 4(SI)(AX*4), DX        // Загружаем src[i+1] в регистр DX
    SUBL BX, DX                 // Вычисляем разницу: DX = src[i+1] - src[i]
    MOVL DX, (DI)(AX*4)         // Сохраняем результат в dst[i]
    
    INCQ AX                     // Увеличиваем индекс
    CMPQ AX, CX                 // Сравниваем с длиной массива
    JL   scalar_loop            // Продолжаем, если индекс < длины

done:
    RET                         // Возврат из функции
// Стандартная реализация на Go для проверки
func CalcDifferencesGo(arr []uint32) []uint32 {
	if len(arr) <= 1 {
		return []uint32{}
	}

	result := make([]uint32, len(arr)-1)
	for i := 0; i < len(arr)-1; i++ {
		result[i] = arr[i+1] - arr[i]
	}
	return result
}

//go:noescape
func calcDifferences(src []uint32, dst []uint32)

// CalculateGrowth - ускоренная реализация на SIMD
func CalcDifferencesASM(arr []uint32) []uint32 {
	if len(arr) <= 1 {
		return []uint32{}
	}

	if len(arr) <= 4 {
		return CalcDifferencesGo(arr)
	}

	result := make([]uint32, len(arr)-1)
	calcDifferences(arr, result)
	return result
}

Как это работает

Мы сохраняем в векторный регистр Y0 восемь значений по 32 бита, а регистр Y1 со смешением на один элемент. Получив векторную разницу между регистрами, мы получим 8 элементов прироста гистограмм сразу.

«Хвост» данных, которые не войдут полностью в векторный регистр, обрабатываем скалярно.

Полученные результаты для длин гистограмм от 10 до 100000
Ускорение в ~2 раза (а хотелось бы в 4~8)

goos: windows
goarch: amd64
pkg: qwe
cpu: AMD Ryzen 5 8400F 6-Core Processor
Benchmark_profile/Go-10-12               67875990                19.03 ns/op
Benchmark_profile/SIMD-10-12             66618182                18.76 ns/op
Benchmark_profile/Go-100-12              13702243                92.77 ns/op
Benchmark_profile/SIMD-100-12            15909126                77.92 ns/op
Benchmark_profile/Go-1000-12              1294166                954.4 ns/op
Benchmark_profile/SIMD-1000-12            1835912                636.4 ns/op
Benchmark_profile/Go-10000-12              144406                 7867 ns/op
Benchmark_profile/SIMD-10000-12            277306                 4469 ns/op
Benchmark_profile/Go-100000-12              19545                63121 ns/op
Benchmark_profile/SIMD-100000-12            43426                28094 ns/op

Ускорение func reverseProfile чуть сложнее

Нам нужно не только делать векторное вычитание, но и перестановку элементов в векторе.

// func calcReverseDifferences(src []int32, dst []int32)
// комменты от GPT-4.1
TEXT ·calcReverseDifferences(SB), $0-48
    MOVQ src_base+0(FP), SI     // SI = &src[0]
    MOVQ src_len+8(FP), CX      // CX = len(src)
    MOVQ dst_base+24(FP), DI    // DI = &dst[0]
    
    CMPQ CX, $1                 // Проверяем минимальную длину
    JLE  done
    
    DECQ CX                     // CX = len(dst) = len(src)-1
    MOVQ CX, AX                 // AX будет индексом в dst (идем с конца)
    XORQ DX, DX                 // DX = индекс в src (идем с начала)

    // Проверяем, достаточно ли элементов для AVX
    CMPQ CX, $8
    JL   scalar_loop_start

    // AVX-обработка (блоками по 8 элементов)
    MOVQ CX, R8
    ANDQ $-8, R8                // R8 = количество элементов, кратных 8
    MOVQ R8, R9                 // R9 = счетчик обработанных элементов

avx_loop:
    // Загружаем текущие 8 элементов
    VMOVDQU (SI)(DX*4), Y0      // Y0 = src[DX..DX+7]
    // Загружаем следующие 8 элементов (сдвиг на 1)
    VMOVDQU 4(SI)(DX*4), Y1     // Y1 = src[DX+1..DX+8]
    
    // Вычисляем разницы
    VPSUBD Y0, Y1, Y2           // Y2 = разницы
    
    // Реверсируем порядок элементов
    VPERM2I128 $0x01, Y2, Y2, Y3  // Меняем местами 128-битные половины
    VPSHUFD $0x1B, Y3, Y3       // Реверсируем порядок в каждой половине
    
    // Сохраняем в обратном порядке
    MOVQ AX, R10
    SUBQ $8, R10                // R10 = AX-7 (начальная позиция для записи)
    VMOVDQU Y3, (DI)(R10*4)     // Сохраняем 8 элементов
    
    ADDQ $8, DX                 // Увеличиваем счетчик src
    SUBQ $8, AX                 // Уменьшаем счетчик dst
    CMPQ DX, R8                 // Проверяем завершение AVX-цикла
    JL   avx_loop

scalar_loop_start:
    DECQ AX
scalar_loop:
    CMPQ DX, CX                 // Проверяем завершение
    JGE  done
    
    MOVL (SI)(DX*4), BX         // Загружаем src[i]
    MOVL 4(SI)(DX*4), BP        // Загружаем src[i+1]
    SUBL BX, BP                 // Вычисляем разницу
    MOVL BP, (DI)(AX*4)         // Сохраняем в dst
    
    INCQ DX                     // Следующий элемент src
    DECQ AX                     // Предыдущая позиция dst
    JMP  scalar_loop

done:
    RET
// Стандартная реализация на Go для проверки
func CalcReverseDifferencesGo(arr []uint32) []uint32 {
	if len(arr) <= 1 {
		return []uint32{}
	}

	result := make([]uint32, len(arr)-1)
	for i := len(arr) - 1; i > 0; i-- {
		result[len(arr)-1-i] = arr[i] - arr[i-1]
	}
	return result
}

//go:noescape
func calcReverseDifferences(src []uint32, dst []uint32)

// CalculateGrowth - обертка, которая создает массив назначения
func CalReversecDifferencesASM(arr []uint32) []uint32 {
	if len(arr) <= 1 {
		return []uint32{}
	}

	if len(arr) <= 4 {
		return CalcReverseDifferencesGo(arr)
	}

	result := make([]uint32, len(arr)-1)
	calcReverseDifferences(arr, result)
	return result
}

Все также как и с calcDifferences но другой порядок работы с индексами

VPERM2I128 $0x01, Y2, Y2, Y3

Эта инструкция берет два 256-битных регистра (в данном случае оба — Y2) и переставляет их 128-битные половины.

VPSHUFD $0x1B, Y3, Y3

Эта инструкция переставляет 32-битные элементы внутри каждой 128-битной половины регистра Y3.

Иллюстрация

Полученные результаты ускорения идентичны calcDifferences

goos: windows
goarch: amd64
pkg: qwe
cpu: AMD Ryzen 5 8400F 6-Core Processor
Benchmark_reverse_profile/Go-10-12                57465208             20.66 ns/op
Benchmark_reverse_profile/ASM-10-12               61966186             19.28 ns/op
Benchmark_reverse_profile/Go-100-12               11149742             107.8 ns/op
Benchmark_reverse_profile/ASM-100-12              15162312             76.59 ns/op
Benchmark_reverse_profile/Go-1000-12               1000000              1022 ns/op
Benchmark_reverse_profile/ASM-1000-12              1756675             679.6 ns/op
Benchmark_reverse_profile/Go-10000-12               129410              9519 ns/op
Benchmark_reverse_profile/ASM-10000-12              208200              4928 ns/op
Benchmark_reverse_profile/Go-100000-12               15706             76804 ns/op
Benchmark_reverse_profile/ASM-100000-12              38287             30263 ns/op

Итого совокупный прирост для набора гистограмм с 200000 парами составил

2.80 раза (180.07% быстрее)

func TestHistogramProcessing(t *testing.T) {
	// Инициализация генератора случайных чисел
	seed := time.Now().UnixNano()
	rand.Seed(seed)
	t.Logf("Используем seed: %d", seed)

	const lengthHistograms = 10_000
	const numHistograms = 10_000
	const numPair = 20
	const numRandom = 10000

	var Histogram [][]uint32

	for i := 0; i < numHistograms; i++ {
		qwe, asd := generatePairedHistograms(lengthHistograms, numPair)
		Histogram = append(Histogram, asd...)
		Histogram = append(Histogram, qwe)
	}

	// Создаем случайных гистограмм
	for i := 0; i < numRandom; i++ {
		Histogram = append(Histogram, generateRandomHistogram(lengthHistograms))
	}

	// Замер времени для processHistograms
	start1 := time.Now()
	result1Paired := processHistograms(Histogram)
	duration1 := time.Since(start1)

	// Замер времени для processHistogramsAVX
	start2 := time.Now()
	result2Paired := processHistogramsAVX(Histogram)
	duration2 := time.Since(start2)

	if result1Paired != result2Paired {
		t.Errorf("Результаты для парных гистограмм не совпадают: processHistograms=%d, processHistogramsAVX=%d",
			result1Paired, result2Paired)
	} else {
		fmt.Printf("Результаты для парных гистограмм совпадают: %d \n", result1Paired)
		t.Logf("Результаты для парных гистограмм совпадают: %d", result1Paired)
	}

	// Выводим замеры времени
	t.Logf("Время выполнения processHistograms: %v", duration1)
	t.Logf("Время выполнения processHistogramsAVX: %v", duration2)

	// Вычисляем во сколько раз одна функция быстрее другой
	var faster, slower time.Duration
	var fasterName, slowerName string
	if duration1 < duration2 {
		faster, slower = duration1, duration2
		fasterName, slowerName = "processHistograms", "processHistogramsAVX"
	} else {
		faster, slower = duration2, duration1
		fasterName, slowerName = "processHistogramsAVX", "processHistograms"
	}

	speedup := float64(slower) / float64(faster)
	percentDiff := (speedup - 1) * 100

	t.Logf("%s быстрее %s примерно в %.2f раза (%.2f%% быстрее)",
		fasterName, slowerName, speedup, percentDiff)

	fmt.Printf("%s быстрее %s примерно в %.2f раза (%.2f%% быстрее)\n",
		fasterName, slowerName, speedup, percentDiff)
}