LLaMA 모델이 너무 커서 inf2.24xlarge 인스턴스에서 시작하기 (허깅페이스 토큰도 만들어야 함)

AMI : Deep Learning AMI Neuron PyTorch 1.13 (Ubuntu 20.04) 20231003

# Environment setup for the Python snippets that follow.
# Activate the Neuron PyTorch virtual environment located under /opt.
source /opt/aws_neuron_venv_pytorch/bin/activate
# Install transformers-neuronx, adding the AWS Neuron pip repository as an extra index.
pip3 install transformers-neuronx --extra-index-url=https://pip.repos.neuron.amazonaws.com
# Install optimum.
pip3 install optimum
# Install/upgrade huggingface_hub together with its CLI extra.
pip3 install -U "huggingface_hub[cli]"
# Log in to Hugging Face from the CLI (presumably prompts for the access token).
huggingface-cli login


llama 사용권한요청1 : https://llama.meta.com/llama-downloads/

llama 사용권한요청2 : https://huggingface.co/meta-llama/Llama-2-7b-hf


from transformers import AutoTokenizer, LlamaForCausalLM
from huggingface_hub import login

# Authenticate with the Hugging Face Hub; replace the placeholder with your
# own access token (access to the Llama 2 repository has to be requested).
login("본인 허깅페이스 토큰~")

# Fetch the tokenizer and the causal-LM weights from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

# Encode a single prompt as PyTorch tensors.
prompt = "Hey, are you conscious? Can you talk to me?"
inputs = tokenizer(prompt, return_tensors="pt")

# Generate with max_length=30, then decode the first (only) sequence back
# to text without special tokens.
generate_ids = model.generate(inputs["input_ids"], max_length=30)
decoded_batch = tokenizer.batch_decode(
    generate_ids,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False,
)
output = decoded_batch[0]


import time
import torch
from transformers import AutoTokenizer
from transformers_neuronx import LlamaForSampling
from huggingface_hub import login

# Load meta-llama/Llama-2-7b-chat-hf for Neuron with 12-way tensor
# parallelism, batch size 1 and fp16 ('f16') precision.
# NOTE(review): confirm that tp_degree=12 is accepted for this model by the
# installed transformers-neuronx version.
neuron_model2 = LlamaForSampling.from_pretrained('meta-llama/Llama-2-7b-chat-hf', batch_size=1, tp_degree=12, amp='f16')

# Compile the model and load it onto the NeuronCores. The upstream
# transformers-neuronx example performs this step before calling sample().
neuron_model2.to_neuron()

# Construct a tokenizer and encode the prompt text as PyTorch tensors.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

prompt = ["Hello, I'm a language model,"]
encoded_input = tokenizer(prompt, return_tensors='pt')

# Run inference with top-k sampling and time only the sample() call.
with torch.inference_mode():
    start = time.perf_counter()
    generated_sequences = neuron_model2.sample(encoded_input.input_ids, sequence_length=128, top_k=50)
    elapsed = time.perf_counter() - start

# Decode every generated token sequence back into text and report timing.
generated_sequences = [tokenizer.decode(seq) for seq in generated_sequences]
print(f'generated sequences {generated_sequences} in {elapsed} seconds')

+ Recent posts