python

Часть 4. Ищем матчи в Dota 2 по названиям роликов на YouTube с помощью BERT и OpenDota

  • суббота, 20 августа 2022 г. в 00:39:42
https://habr.com/ru/post/682480/
  • Поисковые технологии
  • Python
  • Data Mining
  • Машинное обучение


Представьте, что с одной стороны у вас есть видео на YouTube с интересными моментами из матча по Dota 2. А с другой стороны база данных всех матчей по Dota 2. Как для видео найти соответствующую запись в БД? Этой задачей мы сегодня и займемся.

В данном цикле статей мы реализовываем систему автоматического поиска хайлайтов в матчах Dota 2. Для ее создания нам требуется размеченный датасет с тайм-кодами. На YouTube есть множество каналов, где люди выкладывают нарезки с интересными моментами из профессиональных матчей по Dota 2.

В предыдущих частях:

  1. В первой части мы распарсили реплей одного матча по Dota 2 и нашли хайлайты с помощью кластеризации.

  2. Во второй части мы написали сервис для параллельного парсинга реплеев на Celery и Flask.

  3. В третьей части мы научились скачивать видео с нарезками хайлайтов с YouTube, семплировать кадры и распознавать на них время.

Под катом

  1. Получаем заголовок и время публикации видео

  2. Находим матчи с помощью API OpenDota

  3. Краткое отступление про BERT

  4. Строим эмбединги

  5. Поиск матча по заголовку видео

  6. Заключение

  7. Все ссылки на код и использованные материалы вы найдете в конце статьи.

Получаем заголовок и время публикации видео

Напишем простую функцию с помощью уже знакомой нам по предыдущей части библиотеки yt_dlp, которая на вход принимает URL видео, а на выходе возвращает метаданные.

...
import yt_dlp


@lru_cache
def get_video_metadata(url: str, save: bool = False) -> Tuple[str | int, Dict]:
    options = dict()
    with yt_dlp.YoutubeDL(options) as ydl:
        metadata = ydl.extract_info(url, download=False)
    video_id = metadata.get('id')

    if save:
        metadata_file = f'{video_id}.json'
        video_metadata_path = VIDEO_DIR / metadata_file
        with open(video_metadata_path, 'w') as fout:
            json.dump(metadata, fout, indent=4)
    return video_id, metadata

Словарь metadata содержит в себе заголовок видео, время загрузки, ссылку на канал и многое другое. Возьмем в качестве примера матч Team Spirit против Team Secret.

url = 'https://www.youtube.com/watch?v=cXA5Hw2boLA&ab_channel=DotADigest'
video_id, metadata = get_video_metadata(url)
print(metadata)

> {
    "id": "cXA5Hw2boLA",
    "title": "SECRET vs TEAM SPIRIT - RAMPAGE! SEMI FINAL - RIYADH MASTERS 2022 Dota 2 Highlights",
    "thumbnail": "https://i.ytimg.com/vi/cXA5Hw2boLA/hqdefault.jpg",
    "description": "DOTA 2 TS SECRET vs TEAM SPIRIT - RAMPAGE! SEMI FINAL - RIYADH MASTERS 2022 by Gamers8 Dota 2 Highlights 2022 Tournament - Semi Final Playoff #dota2 #dpc  \nWatch Live Riyadh Masters Dota 2: https://www.twitch.tv/gamers8gg - Commentary by Gareth & Lacoste",
    "uploader": "DotA Digest",
    "upload_date": "20220724",
    "uploader_id": "RUDota2TV",
    "uploader_url": "http://www.youtube.com/user/RUDota2TV",
    "channel_id": "UCUqLL4VcEy4mXcQL0O_H_bg",
    "channel_url": "https://www.youtube.com/channel/UCUqLL4VcEy4mXcQL0O_H_bg",
    "duration": 1648,
    "view_count": 50716,
    "average_rating": null,
    "age_limit": 0,
    "webpage_url": "https://www.youtube.com/watch?v=cXA5Hw2boLA",
 ...

Находим матчи с помощью API OpenDota

Для простоты ограничимся анализом канала DotA Digest. Обычно автор указывает в заголовке название команд и турнир. Также мы знаем дату загрузки видео. Попробуем на основе этой информации найти Replay.

Идея простая. Видео с нарезками хайлайтов выходит после того, как матч уже сыгран. При этом авторы YouTube-каналов стремятся выпустить видео как можно раньше, чтобы собрать побольше просмотров. Значит мы можем взять все профессиональные матчи за несколько дней до даты публикации видео и скорее всего один из них окажется искомым (запечатлённым на видео).

Ссылку на реплей можно получить, зная match_id — уникальный идентификатор матча, который генерирует сама игра. Для этого воспользуемся API OpenDota.

Напишем простой wrapper для /explorer endpoint, который позволяет делать запросы к базе данных OpenDota.

import requests
from tenacity import retry, stop_after_attempt, wait_fixed


@retry(stop=stop_after_attempt(7), wait=wait_fixed(10))
def query_opendota(sql: str, **kwargs: dict) -> List[Dict]:
    query = sql.format(**kwargs)
    logger.debug(query)

		r = requests.get(
        'https://api.opendota.com/api/explorer', 
        params=dict(sql=query)
    )
    r.raise_for_status()
    result = r.json()
    rows = result['rows']
    return rows

Напишем функцию, которая находит профессиональные матчи за предыдущие два дня относительно выбранной даты.

from datetime import datetime, timedelta


def get_nearest_matches(date: datetime) -> List[Dict]:
    end_time = date + timedelta(days=1)
    start_time = date - timedelta(days=2)
    query = '''
    SELECT
        match_id,
        start_time,
        matches.leagueid,
        leagues.name as league,
        radiant_team_id,
        radiant_team.name as radiant_name,
        radiant_team.tag as radiant_tag,
        dire_team_id,
        dire_team.name as dire_name,
        dire_team.tag as dire_tag
    FROM
        matches
        join teams as dire_team on matches.dire_team_id = dire_team.team_id
        join teams as radiant_team on
             matches.radiant_team_id = radiant_team.team_id
        join leagues on matches.leagueid = leagues.leagueid
    WHERE
        start_time >= extract(epoch from timestamp '{start_time}')
        and start_time < extract(epoch from timestamp '{end_time}')
    ORDER BY
        start_time desc
    LIMIT
    	  500
    '''
    matches = query_opendota(
        query,
        start_time=datetime.strftime(start_time, '%m-%d-%Y'),
        end_time=datetime.strftime(end_time, '%m-%d-%Y')
    )
    return matches

Проверим на все том же матче Team Spirit против Team Secret. Для этого распарсим дату публикации видео из метаданных. Найдем все профессиональные матчи, которые были сыграны в день публикации и в предыдущие два дня.

url = 'https://www.youtube.com/watch?v=cXA5Hw2boLA&ab_channel=DotADigest'
video_id, metadata = get_video_metadata(url)
upload_date = metadata['upload_date']
upload_date = datetime.strptime(upload_date, '%Y%m%d')
matches = get_nearest_matches(upload_date)
print(len(matches))

> 288

Всего матчей 288. Среди них есть 12 матчей с участием команды Team Spirit, 3 за Radiant и 9 за Dire.

print([m for m in matches if m['radiant_name'] == 'Team Spirit'])

> [
    {
        "match_id": 6676488286,
        "start_time": 1658692615,
        "leagueid": 14391,
        "league": "Riyadh Masters by Gamers8",
        "radiant_team_id": 7119388,
        "radiant_name": "Team Spirit",
        "radiant_tag": "TSpirit",
        "dire_team_id": 15,
        "dire_name": "PSG.LGD",
        "dire_tag": "PSG.LGD"
    },
    {
        "match_id": 6676051545,
        "start_time": 1658673115,
        "leagueid": 14391,
        "league": "Riyadh Masters by Gamers8",
        "radiant_team_id": 7119388,
        "radiant_name": "Team Spirit",
        "radiant_tag": "TSpirit",
        "dire_team_id": 1838315,
        "dire_name": "Team Secret",
        "dire_tag": "Secret"
    },
    {
        "match_id": 6672690521,
        "start_time": 1658491214,
        "leagueid": 14391,
        "league": "Riyadh Masters by Gamers8",
        "radiant_team_id": 7119388,
        "radiant_name": "Team Spirit",
        "radiant_tag": "TSpirit",
        "dire_team_id": 8291895,
        "dire_name": "Tundra Esports",
        "dire_tag": "Tundra"
    }
]

Нюанс в том, что Spirit играли не только с Team Secret, но и с PSG.LGD, и с Tundra Esports. При этом из заголовка видео понятно, что оно относится только к матчу Secret - Spirit.

print(metadata['fulltitle'])
> "SECRET vs TEAM SPIRIT - RAMPAGE! SEMI FINAL - RIYADH MASTERS 2022 Dota 2 Highlights"

Как же отфильтровать матчи в БД, подходящие под заголовок видео?

Можно было бы воспользоваться регулярными выражениями. Но заголовок видео пишет человек, поэтому в названиях команд и турниров могут быть опечатки и сокращения.

Можно было бы пойти дальше и использовать расстояние Левинштейна, чтобы обойти проблему опечаток и сокращений. Но в названии видео также могут содержаться комментарии и кликбейтные словосочетания.

Но я решил стрельнуть из пушки по воробьям. Благо в наше время все сильнее набирает популярность новое, модное, молодежное явление — нейросети.

Краткое отступление про BERT

Для решения данной задачи я решил использовать нейросетьDistillBERT, которая по сути своей является оптимизированной по скорости и памяти версией BERT. Она позволяет отображать текст в "хорошие" числовые векторы, которые потом можно между собой сравнивать.

BERT — модель, обученная предсказывать пропущенные слова (на самом деле токены), а также наличие взаимосвязи между предложениями.

Сосредоточимся на пропущенных токенах. Сами по себе они являются словами или кусочками слов. Для их получения используется алгоритм BPE (Byte Pair Encoding). Всего в словаре модели 30 000 токенов.

В оригинальной статье в качестве датасета авторы использовали тексты из книг и википедии, суммарно более 3 млрд слов. Причем 15% токенов заменялись на служебный — [MASK]. Задача модели была предсказать, какое слово скрывается под маской. Для этого она могла использовать контекст некоторой длины как слева, так и справа. Рассмотрим пример.

[CLS] the man went to [MASK] store [SEP] he bought a gallon [MASK] milk [SEP]

Можно было бы предположить, что оригинальное предложение выглядит так

The man went to asdf store. He bought a gallon dinosaur's milk.

Но более правдоподобным выглядит вариант

The men went to the store. He bought a gallon of milk.

Примерно этим нейросеть и занималась на этапе обучения. Сначала могла сказать про галлон молока динозавра. Но у динозавров нет молока, у них здоровенные яйца. Поэтому модель получала от лосса по весам. И потихоньку сходилась к более правдопобному варианту. И так для миллионов текстов.

Схема модели:

  1. Берем текст некоторой длины

  2. Разбиваем на токены

  3. Каждому токену ставим в соответствие некоторый числовой emdedding — 768-мерный вектор

  4. Добавляем Positional Encoding, чтобы учесть порядок слов (подробнее можете почитать в статье)

  5. Прогоняем через Transformer-блоки

  6. Получаем обновленные embedding'и исходных токенов, которые учитывают слова в тексте и взаимосвязи между ними

  7. Делаем классификацию слов под маской

Основная ценность модели для нас в п.6 — хороших эмбедингах токенов.

Строим эмбединги

Давайте поставим в соответствие заголовку видео некоторый числовой вектор. Аналогичным способом получим числовые векторы для записей в базе данных. Измерим расстояние от вектора-заголовка до всех векторов матчей и выберем ближайший. Это и будет нашим алгоритмом поиска. А для создания числовых векторов по тексту будем использовать DistillBERT.

Напишем функцию для загрузки предобученной модели с HuggingFace.

import transformers
from transformers import DistilBertTokenizer, DistilBertModel


@lru_cache
def load_text_model() -> tuple:
    model_version = 'distilbert-base-uncased'
    logger.info(f'Loading model: {model_version}')
    transformers.logging.set_verbosity_error()
    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
    model = DistilBertModel.from_pretrained("distilbert-base-uncased")
    return tokenizer, model

И функцию для получения выхода с последнего слоя блока трансформера.

def get_distilbert_hidden_state(batch: List[str]) -> torch.Tensor:
    """
    Hidden State from the last TransformerBlock of DistilBert. 
    Returns (B, T, H) Tensor, where

    - B - Batch Size,
    - T - Number of tokens in the longest string
    - H - Hidden Vector Dim (768 for DistilBert)
    """
    tokenizer, model = load_text_model()
    inputs = tokenizer(batch, return_tensors='pt', padding=True)
    outputs = model(**inputs)
    hidden_state = outputs.last_hidden_state
    return hidden_state

Разберемся на примерах. Попробуем получить выход с последнего слоя для заголовка видео "SECRET vs TEAM SPIRIT - RAMPAGE! SEMI FINAL - RIYADH MASTERS 2022 Dota 2 Highlights".

hidden_states = get_distilbert_hidden_state([
    'SECRET vs TEAM SPIRIT - RAMPAGE! SEMI FINAL - RIYADH MASTERS 2022 Dota 2 Highlights',
])

type(hidden_state)
> torch.Tensor

hidden_states.shape
> torch.Size([1, 22, 768])

Имеем PyTorch Tensor с размерностью:

  • 1 — размер батча. В данном случае мы передали список из одной строки

  • 22 — число токенов, которое было получено с помощью BPE tokenizer

  • 768 — размерность скрытого слоя. Была выбрана на этапе обучения сети, в данный момент мы на неё не влияем.

Батч может содержать и несколько строк разной длины.

hidden_states = get_distilbert_hidden_state([
    'SECRET vs TEAM SPIRIT - RAMPAGE! SEMI FINAL - RIYADH MASTERS 2022 Dota 2 Highlights',
    'SECRET vs SPIRIT - RIYADH MASTERS 2022s',
])

hidden_states.shape
> torch.Size([2, 22, 768])

Теперь в нашем примере каждая текстовая строка закодирована матрицей чисел  \mathbb{R}^{22х768}. Но матрицы не так удобно между собой сравнивать. Поэтому без страха, без уважения усредним по первой компоненте тензора, т.е. по всем токенам.

def get_text_embeddings(batch: List[str]) -> torch.Tensor:
    hidden_state = get_distilbert_hidden_state(batch)
    embeddings = torch.mean(hidden_state, dim=1)
    return embeddings

Т.е. мы представляем текст, как нечто среднее от слов (токенов) в нем. Таким образом мы научились ставить в соответствие текстовой строке некоторый вектор из 768 чисел.

embedding = get_text_embeddings([
    'Hello, World!',
])

embedding.shape
> torch.Size([1, 768])

print(embedding)
> tensor([[ 1.5433e-03, -3.2035e-01,  4.5550e-01,  6.7697e-03, -4.9975e-03,
         -3.3267e-01,  5.5809e-01, -6.2899e-02,  4.8135e-03, -6.1892e-02,
         -5.6701e-02, -3.2326e-01, -1.0381e-02,  4.1199e-01,  7.4125e-02,
          7.1323e-02, -1.0036e-03,  9.4666e-03,  1.4607e-01,  3.1909e-01,
         ...

Поиск матча по заголовку видео

Как померить близость векторов? Например, можно взять косинус угла между ними. Из курса линейной алгебры вспомним про скалярное произведение.

(A,B) = ||A| |\cdot ||B|| \cdot \cos{\theta}\cos{\theta} = \frac{(A,B)}{||A|| \cdot ||B||} = \frac{\sum_{i}^{768}{a_ib_i}} {\sqrt{\sum{a^2_i}} \cdot \sqrt{\sum{b^2_i}}}

Напишем функцию поиска. На вход она принимает текст поискового запроса, а также корпус текстов (документы), внутри которого мы будем искать. Возвращает Top-N ближайших документов к поисковому запросу на основе косинусной меры близости.

def search(text: str, corpus: List[str], top: int = 3) -> List[Tuple[int, str]]:
    batch = [text] + corpus
    with torch.no_grad():
        embeddings = get_text_embeddings(batch)
        similarities = F.cosine_similarity(embeddings[0:1], embeddings[1:], dim=1)
        top_ids = torch.topk(similarities, k=top).indices
    result = [
        (int(idx), corpus[idx])
        for idx in top_ids
    ]
    return result

Проверим на практике.

search(
    'SECRET vs TEAM SPIRIT - RAMPAGE! SEMI FINAL - RIYADH MASTERS 2022 Dota 2 Highlights',
    [
        'NAVI vs Aliance - The International 2013',
        'LGD Pushat',
        'SECRET vs SPIRIT - RIYADH MASTERS 2022',  
        'Net drug, ya ne opravdivaus',
    ],
    top=4
)

> [
 (2, 'SECRET vs SPIRIT - RIYADH MASTERS 2022'),
 (0, 'NAVI vs Aliance - The International 2013'),
 (1, 'LGD Pushat'),
 (3, 'Net drug, ya ne opravdivaus')]

Работает. Также для удобства напишем функцию, которая возвращает самый близкий результат.

def search_top1(text: str, corpus: List[str]) -> str:
    _, result = search(text, corpus, top=1)[0]
    return result

Теперь вспомним, что в начале статьи мы получили список matches — всех профессиональных матчей, которые были сыграны за несколько дней до публикации видео. Одна запись содержит названия команд, теги команд и название турнира.

record = {
    "match_id": 6676488286,
    "start_time": 1658692615,
    "leagueid": 14391,
    "league": "Riyadh Masters by Gamers8",
    "radiant_team_id": 7119388,
    "radiant_name": "Team Spirit",
    "radiant_tag": "TSpirit",
    "dire_team_id": 15,
    "dire_name": "PSG.LGD",
    "dire_tag": "PSG.LGD"
}

Напишем функцию, которая создает текстовые описания для матчей. Будем использовать не только имена, но и вариации с тегами команд, потому что мы заранее не знаем, что будет лучше работать.

def generate_team_pairs(matches: List[Dict]) -> Dict:
    pairs = dict()
    for m in matches:
        match_id = m['match_id']
        radiant_tag = m['radiant_tag']
        dire_tag = m['dire_tag']
        radiant_name = m['radiant_name']
        dire_name = m['dire_name']

        name_pair_1 = f'{radiant_name} vs {dire_name}'
        name_pair_2 = f'{dire_name} vs {radiant_name}'
        tag_pair_1 = f'{radiant_tag} vs {dire_tag}'
        tag_pair_2 = f'{dire_tag} vs {radiant_tag}'
        for pair in (name_pair_1, name_pair_2, tag_pair_1, tag_pair_2):
            pairs[pair] = match_id
    return pairs

Применяя к записи из БД, получаем

generate_team_pairs([record])

> {
    "Team Spirit vs PSG.LGD": 6676488286,
    "PSG.LGD vs Team Spirit": 6676488286,
    "TSpirit vs PSG.LGD": 6676488286,
    "PSG.LGD vs TSpirit": 6676488286
}

Где 6676488286 — идентификатор матча в БД.

Не забудем и про названия турниров.

def generate_team_tournaments(matches: List[Dict]) -> Dict:
    team_tournaments = dict()
    pairs = generate_team_pairs(matches)
    for m in matches:
        match_id = m['match_id']
        league = m['league']
        for pair, pair_match_id in pairs.items():
            if match_id == pair_match_id:
                team_tournament = f'{pair} | {league}'
                team_tournaments[team_tournament] = match_id
    return team_tournaments
generate_team_tournaments([record])

> {
    "Team Spirit vs PSG.LGD | Riyadh Masters by Gamers8": 6676488286,
    "PSG.LGD vs Team Spirit | Riyadh Masters by Gamers8": 6676488286,
    "TSpirit vs PSG.LGD | Riyadh Masters by Gamers8": 6676488286,
    "PSG.LGD vs TSpirit | Riyadh Masters by Gamers8": 6676488286
}

В качестве вишенки на торте воспользуемся наблюдением, что на канале DotA Digest, рассмотрением которого мы ограничились в данной статье, заголовки видео формируются в формате

{Team1} vs {Team2} - {Clickbait comment} - {Tournament} Dota 2 Highlights

Поэтому к заголовкам видео можно применять простое преобразование.

title = 'SECRET vs TEAM SPIRIT - RAMPAGE! SEMI FINAL - RIYADH MASTERS 2022 Dota 2 Highlights'
teams_pair, _, tournament = [t.strip() for t in title.split('-')]
tournament = tournament.replace('Dota 2 Highlights', '')
title = f'{teams_pair} | {tournament}'
print(title)
> 'SECRET vs TEAM SPIRIT | RIYADH MASTERS 2022 '

И вот она — функция для поиска матча, соответствующего заголовку ролика на YouTube.

def search_team_tournament_pairs(video_title: str, matches: List[Dict]) -> int:
    teams_pair, _, tournament = [t.strip() for t in video_title.split('-')]
    tournament = tournament.replace('Dota 2 Highlights', '')
    team_tournament = f'{teams_pair} | {tournament}'
    team_tournaments = generate_team_tournaments(matches)
    candidate = search_top1(team_tournament, list(team_tournaments))
    match_id = team_tournaments[candidate]
    return match_id

Принимает на вход название видео и список матчей из БД. Приводит заголовок к "православному" виду. Генерирует строки с описанием матчей из БД. Строит эмбединги (векторные представления) заголовка и описаний. Сравнивает их между собой по косинусной мере и находит ближайшего кандидата. Возвращает match_id, по которому получить реплей.

Выведем результат поиска.

match_id = search_team_tournament_pairs(teams_pair, tournament, matches)
[m for m in matches if m['match_id'] == match_id]
> [
    {
        "match_id": 6673204572,
        "start_time": 1658513208,
        "leagueid": 14391,
        "league": "Riyadh Masters by Gamers8",
        "radiant_team_id": 1838315,
        "radiant_name": "Team Secret",
        "radiant_tag": "Secret",
        "dire_team_id": 7119388,
        "dire_name": "Team Spirit",
        "dire_tag": "TSpirit"
    }
]

Заключение

Мы научились искать записи в БД по заголовкам видео на YouTube. Поигрались с нейросетями. И вообще замечательно провели время.

В следующей части мы подготовим датасет и улучшим алгоритм поиска хайлайтов из первой части.

* Качество поиска можно уточнить, если разными способами генерировать текстовые описания для записей в БД. Реализация есть в репозитории на GitHub.

** Внимательный читатель заметит, что видео могут содержать не один матч, а серию, потому что команды зачастую играют 2 или 3 карты. Мне было лень возиться, но в перспективе это важно учесть.

Ссылки