kyegomez / BitNet
Implementation of "BitNet: Scaling 1-bit Transformers for Large Language Models" in PyTorch
PyTorch Implementation of the linear methods and model from the paper "BitNet: Scaling 1-bit Transformers for Large Language Models"
BitLinear = tensor -> layernorm -> Binarize -> abs max quantization -> dequant
"The implementation of the BitNet architecture is quite simple, requiring only the replacement of linear projections (i.e., nn.Linear in PyTorch) in the Transformer. " -- BitNet is really easy to implement just swap out the linears with the BitLinear modules!
train.py
A training script that trains on enwik8, a small 1 GB dataset of Wikipedia: HERE IS THE LINK

Installation
pip install bitnet
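As a rough idea of what such a training script looks like, here is an illustrative character-level training loop around BitNetTransformer. It is not the repo's actual train.py: the hyperparameters and random batches are placeholders, and it assumes the model returns per-token logits of shape (batch, seq, num_tokens).

import torch
from bitnet import BitNetTransformer

# Illustrative training loop (not the repo's train.py). Random byte-level
# token ids stand in for real enwik8 batches.
model = BitNetTransformer(
    num_tokens=256, dim=512, depth=6, dim_head=64, heads=8, ff_mult=4
)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

for step in range(100):
    # Replace this with real enwik8 byte sequences.
    batch = torch.randint(0, 256, (4, 512))
    logits = model(batch)  # assumed shape: (batch, seq, num_tokens)
    # Next-token prediction: shift targets by one position.
    loss = torch.nn.functional.cross_entropy(
        logits[:, :-1].reshape(-1, logits.size(-1)),
        batch[:, 1:].reshape(-1),
    )
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()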
import torch
from bitnet import BitLinear
# Input
x = torch.randn(10, 512)
# BitLinear layer
layer = BitLinear(512, 400)
# Output
y = layer(x)
print(y)
import torch
from bitnet import BitNetTransformer
bitnet = BitNetTransformer(
    num_tokens=20000,
    dim=512,
    depth=6,
    dim_head=64,
    heads=8,
    ff_mult=4,
)
tokens = torch.randint(0, 20000, (1, 512))
logits = bitnet(tokens)
print(logits.shape)
import torch
from bitnet.bitffn import BitFeedForward
# Random input
x = torch.randn(10, 512)
# FFN
ff = BitFeedForward(512)
# Apply FFN
y = ff(x)
print(y.shape)
# torch.Size([10, 512])
from bitnet import BitNetInference
bitnet = BitNetInference()
bitnet.load_model("../model_checkpoint.pth") # Download model
output_str = bitnet.generate("The dog jumped over the ", 512)
print(output_str)
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from bitnet import replace_linears_in_hf
# Load a model from Hugging Face's Transformers
model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
# Replace Linear layers with BitLinear
replace_linears_in_hf(model)
# Example text to classify
text = "Replace this with your text"
inputs = tokenizer(
    text, return_tensors="pt", padding=True, truncation=True, max_length=512
)
# Perform inference
model.eval() # Set the model to evaluation mode
with torch.no_grad():
    outputs = model(**inputs)
    predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
print(predictions)
# Process predictions
predicted_class_id = predictions.argmax().item()
print(f"Predicted class ID: {predicted_class_id}")
# Optionally, map the predicted class ID to a label, if you know the classification labels
# labels = ["Label 1", "Label 2", ...] # Define your labels corresponding to the model's classes
# print(f"Predicted label: {labels[predicted_class_id]}")
License: MIT
@misc{2310.11453,
Author = {Hongyu Wang and Shuming Ma and Li Dong and Shaohan Huang and Huaijie Wang and Lingxiao Ma and Fan Yang and Ruiping Wang and Yi Wu and Furu Wei},
Title = {BitNet: Scaling 1-bit Transformers for Large Language Models},
Year = {2023},
Eprint = {arXiv:2310.11453},
}