golang

Применения Continuation-passing style в Go

  • воскресенье, 25 августа 2024 г. в 00:00:06
https://habr.com/ru/articles/836702/

В этой статье мы рассмотрим концепцию программирования в стиле передачи продолжений и примеры его применения, исследуем, как этот стиль может улучшить читаемость и поддержку кода в приложениях на Go. Также обсудим потенциальные подводные камни и ограничения, чтобы дать полное представление о том, как и когда использовать его в практике разработки.

Введение

При обычном (Direct style) вызове мы подаём на вход функции параметры и на выходе ожидаем какое-то значение. Например, функция сложения:

func add(x, y int) int {
	return x + y
}

// Использование функции add
res := add(1, 2)
fmt.Println(res)

При использовании Continuation-passing style (CPS) к списку параметров добавляется функция-продолжение k

func addCps(x, y int, k func(res int)) {
	k(x + y)
}

// Использование функции addCps
addCps(1, 2, func(res int) { fmt.Println(res) })

При использовании CPS мы получаем возможности, которые недоступны в обычном Direct style программировании - теперь функция контролирует поток исполнения. То есть следуя своей внутренней логике функция может запустить продолжение дважды, сохранить его и исполнить позже или не вызывать продолжение вовсе.

func addCps(x, y int, k func(res int)) {
    if x == 0 {
		// Выходим без вызова продолжения
		return
    }
	if x >= y {
		// Вызываем продолжение дважды - сначала с суммой, потом с нулем
		k(x + y)
		k(0)
	} else {
		// Однократное исполнение продолжения
		k(x + y)
	}
}

Разделение бизнес-логики и служебного кода

Допустим у нас есть структура данных LinkedList

type LinkedList struct {
	Head int
	Tail *LinkedList
}

Добавим метод, чтобы вывести его содержимое на консоль

func (list *LinkedList) Print() {
	// Служебный код
	for cur := list; cur != nil; cur = cur.Tail {
		// Бизнес-логика
		fmt.Printf("%d ", cur.Head)
	} 
}

А теперь нам потребовалось подсчитать, например, сумму элементов списка

func (list *LinkedList) Sum() int {
	sum := 0
	// Служебный код
	for cur := list; cur != nil; cur = cur.Tail {
		// Бизнес-логика
		sum += cur.Head
	} 
	return sum
}

Кажется запахло копипастой. А если это библиотечная структура данных, то пользователи так же будут обязаны копировать весь этот бойлерплейт. Служебный код лучше инкапсулировать, а бизнес-логику поместить за абстракцией продолжения:

func (list *LinkedList) Traverse(k func(val int)) {
	for cur := list; cur != nil; cur = cur.Tail {
		k(cur.Head)
	}
}

Отрефакторим методы Print и Sum

func (list *LinkedList) Print() {
	list.Traverse(func(val int) { fmt.Printf("%d ", val) })
}

func (list *LinkedList) Sum() int {
	sum := 0
	list.Traverse(func(val int) { sum += val })
	return sum
}

Метод Traverse всегда будет проходить по списку от начала и до конца. Это не всегда нужно, иногда требуется досрочно завершить итерацию при наступлении какого-либо условия. Пусть пользователь сам сообщит, когда надо прервать алгоритм, для этого изменим у продолжения возвращаемый тип с void на bool

Значение true пусть кодирует продление итераций, а false завершение.

func (list *LinkedList) Traverse2(k func(val int) bool) {
	for cur := list; cur != nil; cur = cur.Tail {
		// Теперь проверяем, что вернуло продолжение
		keepGoing := k(cur.Head)

		if !keepGoing {
			// keepGoing == false - это сигнализирует о завершении
			break
		}
	}
}

На основе Traverse2 можно написать метод поиска Contains(x), который завершит перебор элементов, когда найдено первое вхождение:

func (list *LinkedList) Contains(x int) bool {
	found := false
	list.Traverse2(func(val int) bool {
		if val == x {
			found = true
			return false
		}
		return true
	})
	return found
}

Продолжения в Go stdlib

Подобный подход применяется в недавно представленной конструкции Range-over-func [1].

Итератор для range имеет минимальные отличия от нашего рукописного:

func (list *LinkedList) Iter() func(func(int) bool) {
	iterator := func(yield func(int) bool) {
		for cur := list; cur != nil; cur = cur.Tail {
			if !yield(cur.Head) {
				return
			}
		}
	}
	return iterator
}

Главное отличие - метод Iter() не принимает продолжение как аргумент, вместо этого он написан в виде каррированной функции. Iter() можно воспринимать как конструктор или "билдер" для итератора, в него можно поместить код инициализации и прочие настройки.

Сам код итератора соответствует нашему Traverse2 с точностью до переименования k на yield. Стоит обратить внимание, что итератор ожидает на вход продолжение, а тело оператора range не является функцией, но компилятор сам позаботится об этом и автоматически приведёт тело к виду func(...) bool

Пример использования Iter

// Функция для печати на консоль
printIt := func(x int) bool { println(x); return true }

// Итератор можно применить в конструкции range
for x := range list.Iter() {
    println(x)
}

// Можно вызвать напрямую
iterator := list.Iter()
// Итератор это функция, к которой нужно применить продолжение
iterator(printIt)

// Более краткая запись
list.Iter()(printIt)

Трамплины

Известно, что многие алгоритмы в программировании элегантно решаются с помощью рекурсии: обработка древовидных данных (json / xml / структура каталогов), графов и т.д. Рекурсия всем хороша, кроме одного - в большинстве языков программирования существует ограничение на глубину стека вызовов, что может привести к ошибке переполнения стека.

Существует несколько решений это проблемы: переписывание алгоритма без использования рекурсивного вызова, приведение к виду хвостовой рекурсии для задействования оптимизации хвостового вызова (Tail-call optimization - TCO)

В компиляторе Go не поддерживается TCO, но мы можем вручную применить технику для оптимизации вызовов - так называемый трамплин. Суть паттерна в следующем: трамплин принимает на вход функцию, при завершении работы функция возвращает не результат, а следующее продолжение, трамплин в цикле исполняет это продолжение и ожидает от него другое продолжение и так далее.

Рассмотрим этот паттерн на примере. Вернемся к обычному связному списку и напишем итератор в рекурсивном стиле:

func IterRec(list *LinkedList, k func(v int)) {
	if list == nil {
		// Дошли до конца списка - завершаем итерацию
		return
	}

	// Иначе вызываем продолжение на текущей голове списка
	k(list.Head)

	// Рекурсивный вызов для хвоста списка
	IterRec(list.Tail, k)
}

Разумеется, эта функция упадет с ошибкой переполнения на большом наборе данных. Но ведь каждый следующий шаг мы можем вычислять лениво (lazy). Для этого во-первых объявим тип, представляющий ленивые рекурсивные вычисления:

// Объявляем рекурсивный тип функции
type Thunk func() Thunk

Теперь завернем рекурсивный вызов в лениво вычисляемую обертку (thunk):

func IterRec(list *LinkedList, k func(v int)) Thunk {
	if list == nil {
		// Дошли до конца списка - вернём пустое продолжение
		return nil
	}

	// Иначе вызываем продолжение k на текущей голове списка
	k(list.Head)

	// Ленивое продолжение для хвоста списка
	return func() Thunk { return IterRec(list.Tail, k) }
}

Отлично, теперь ленивые вычисления лежат в хипе, а не стеке. Так как это всего лишь один call frame, то его объем небольшой.

Ленивые вычисления сами собой не исполнятся, нужен тот кто их запустит. В нашем случае запускать их будет трамплин:

// Запуск отложенных вычислений
func RunTrampoline(initial Thunk) {
	thunk := initial()
	for thunk != nil {
		thunk = thunk()
	}
}

Запустим какое-то вычисление на списке.

max := list.Head
findMax := func(v int) {
	if v > max {
		max = v
	}
}
RunTrampoline(IterRec(list, findMax))
println(max)

Управление ресурсами

Типичный алгоритм при работе с ресурсами это:

  • запросить ресурс (например открыть файл)

  • произвести действия с ресурсом (прочитать / записать)

  • освободить ресурс (закрыть файл)

На любом шаге может возникнуть ошибка.

func writeFile(path, content string) error {
    file, err := os.OpenFile(path, os.O_CREATE | os.O_WRONLY, 0600)
    if err != nil {
        return err
    }
    defer file.Close()

    _, err = file.Write([]byte(content))
    if err != nil {
        return err
    }

    return nil
}

Мало того, что код наполнен низкоуровневыми деталями, так и сам интерфейс для работы с файловыми ресурсами никак не ограждает нас от некорректного использования - мы вполне можем забыть про close() и ресурс утечет, но ни ошибки компиляции, ни warning вы не увидите. Здесь бы применить линейные типы, но увы в go их не завезли.

Мы уже решали задачу по разделению кода на системную и пользовательскую части, и теперь можем применить похожий паттерн для работы с файловыми ресурсами.

// Файловый ресурс
type FileResource = func(cont FileContinuation) error

// Функция-продолжение для инкапсуляции бизнес-логики
type FileContinuation = func(fd *os.File) error

func WorkWithFile(path string, flags int, perm os.FileMode) FileResource {
	// Каррированный инициализатор ресурса
	return func(cont FileContinuation) error {
		// Системные вызовы
		file, err := os.OpenFile(path, flags, perm)
		if err != nil {
			return err
		}
		defer file.Close()

		// Пользовательская бизнес-логика
		err = cont(file)
		return err
	}
}

Теперь функцию WorkWithFile достаточно протестировать один раз и потом переиспользовать во всем проекте, уже не задумываясь о корректности работы с файлами.

func main() {
    // Файл не будет открыт здесь
    // FileResource ожидает на вход функцию-продолжение
    fileRes := WorkWithFile("./file.txt", os.O_CREATE | os.O_WRONLY, 0600)
    // ...

    // Файл откроется только здесь
    err := fileRes(myBusinessLogic)
    // А тут он уже закрыт

    // Передаем этот же ресурс в другую функцию
    // Файл повторно откроется
    err = fileRes(otherBusinessLogic)
}

func myBusinessLogic(fd *os.File) error {
    // Работаем с файловым дескриптором fd
}

Автоматический commit/rollback транзакций

Транзакции это еще один ресурс с которым необходимо аккуратно работать.

// Транзакционный ресурс
type TxResource = func(TxContinuation) error

// Функция-продолжение для инкапсуляции бизнес-логики
type TxContinuation = func(tx *sql.Tx) error

// Конструктор ресурса
func Transaction(db *sql.DB) TxResource {
	return func(cont func(tx *sql.Tx) error) error {
		// Стартуем транзакцию
		tx, err := db.Begin()
		if err != nil {
			return err
		}

		// Исполняем транзакционный код
		err = cont(tx)

		// Коммит или откат транзакции
		if err != nil {
			_ = tx.Rollback()
			return err
		} else {
			return tx.Commit()
		}
	}
}

Пример использования

func execInTransaction(db *sql.DB) (string, error) {
    var result string
    err := Transaction(db)(func(tx *sql.Tx) error {
        res, err := tx.Query("some query")
        if err != nil {
            return err
        }
        
        result = "some calculated result"
        return nil
    })
    return result, err
}

В транзакционном коде при ошибке достаточно просто вернуть err и rollback запустится автоматически.

Больше никаких дедлоков в sync.WaitGroup

Вы могли видеть следующий пример использования синхронизации:

func worker(id int, wg *sync.WaitGroup) {
    fmt.Printf("Worker %d starting\n", id)
    time.Sleep(time.Second)
    fmt.Printf("Worker %d done\n", id)
    wg.Done()
}

func main() {
    var wg sync.WaitGroup
    for i := 1; i <= 5; i++ {
        wg.Add(1)
        go worker(i, &wg)
    }
    wg.Wait()
}

Операции получения и освобождения ресурса (счетчика wg) хаотично разбросаны по коду. Можем случайно пропустить wg.Done или неправильно укажем значение в методе wg.Add. Зачем бизнес-логике знать о wg? И еще надо помнить, что wg нужно передавать по ссылке. Весь системный код лучше поместить в обертку SafeWaitGroup:

type Spawner interface {
	Run(task func())
}

type SafeWaitGroup interface {
	Spawner
	Wait()
}

type safeWaitGroup struct {
    wg *sync.WaitGroup
}

func NewSafeWaitGroup() SafeWaitGroup {
    return &safeWaitGroup{new(sync.WaitGroup)}
}

func (swg *safeWaitGroup) Run(task func ()) {
    swg.wg.Add(1)
    go func() {
        task()
        swg.wg.Add(-1)
    }()
}

func (swg *safeWaitGroup) Wait() {
    swg.wg.Wait()
}

func RunGroup(taskRunner func(Spawner)) {
	swg := NewSafeWaitGroup()
	taskRunner(swg)
	swg.Wait()
}

Теперь пример можно переписать в более безопасном виде.

func worker(id int) {
	fmt.Printf("Worker %d starting\n", id)
	time.Sleep(time.Second)
	fmt.Printf("Worker %d done\n", id)
}

func main() {
	RunGroup(func(spawner Spawner) {
		for i := 1; i <= 5; i++ {
			i := i // замыкание текущего значения нужно до версии 1.22
			spawner.Run(func() { worker(i) })
		}
	})
}

Метод Run() ожидает лениво запускаемую функцию и это тоже можно оформить в абстракцию:

func Suspended[A any](arg A, k func(arg A)) func() {
	return func() {
		k(arg)
	}
}

func main() {
	RunGroup(func(spawner Spawner) {
		for i := 1; i <= 5; i++ {
			i := i // замыкание текущего значения нужно до версии 1.22
			spawner.Run(Suspended(i, worker))
		}
	})
}

Заключение

Мы разобрались как с помощью CPS сделать инверсию control flow, как скрывать системные детали реализации и безопасно управлять ресурсами, таким образом получить надежный и читаемый код.

Источники

  1. Go Wiki: Rangefunc Experiment

  2. Functional programming in Golang