I was playing with CodeLlama using 2 and 4 GPUs for the 7B and 30B models, respectively. I modified the official example instruct Python script to accept user instructions via input() inside a while loop, so I can keep giving instructions and getting results from the model.
But whenever I run the code on 2 or 4 GPUs, it just hangs after I type in an instruction. The same code works fine on a single GPU.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from typing import Optional
import logging
import sys

import fire

from llama import Llama


def main(
    ckpt_dir: str,
    tokenizer_path: str,
    temperature: float = 0.2,
    top_p: float = 0.95,
    max_seq_len: int = 512,
    max_batch_size: int = 8,
    max_gen_len: Optional[int] = None,
):
    generator = Llama.build(
        ckpt_dir=ckpt_dir,
        tokenizer_path=tokenizer_path,
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
    )
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    print("Please type in your coding questions. Press Ctrl+C to terminate.")
    try:
        while True:
            print("\n========= Please type in your question =========\n")
            user_content = input("\nQuestion: ")  # User question
            # One dialog per batch entry: a system prompt plus the user turn.
            instructions = [
                [
                    {"role": "system", "content": "You are a helpful, expert coding assistant."},
                    {"role": "user", "content": user_content},
                ]
            ]
            results = generator.chat_completion(
                instructions,  # type: ignore
                max_gen_len=max_gen_len,
                temperature=temperature,
                top_p=top_p,
            )
            for result in results:
                print(f"\nAnswer: {result['generation']['content']}")
    except KeyboardInterrupt:
        print("\nProgram terminated by user. Exiting...")


if __name__ == "__main__":
    fire.Fire(main)
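For reference, I launch the script with torchrun, one process per GPU, following the pattern in the CodeLlama README (the checkpoint and tokenizer paths below are just placeholders for my local ones):

torchrun --nproc_per_node 4 example_instructions.py --ckpt_dir <checkpoint_dir> --tokenizer_path <path/to/tokenizer.model>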
After spending hours debugging, I finally realized what was happening: torchrun starts one process per GPU, and every process runs the whole script, including the input() call. So a single prompt effectively waits for 2 or 4 lines of input, one per process, which is why the loop appears to hang.
This is totally unexpected and counter-intuitive.
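The usual workaround, as far as I understand it, is to read stdin only on rank 0 and broadcast the string to the other ranks. Here is a minimal sketch (not tested against my exact setup); it relies on torch.distributed already being initialized, which Llama.build does when the script is launched via torchrun, and on dist.broadcast_object_list, which ships picklable Python objects between ranks:

import torch.distributed as dist

def rank0_input(prompt: str) -> str:
    # Only rank 0 reads stdin; the line is then broadcast to every rank.
    if not dist.is_initialized():  # single-GPU fallback
        return input(prompt)
    payload = [input(prompt)] if dist.get_rank() == 0 else [None]
    dist.broadcast_object_list(payload, src=0)  # rank 0 -> all other ranks
    return payload[0]

Inside the while loop you would then call user_content = rank0_input("\nQuestion: "), and it also helps to guard the print() calls with if dist.get_rank() == 0 so banners and answers are not printed once per process.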
In the end, I switched to the Hugging Face API instead, hoping for better support of multiple GPUs with interactive user input.
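For completeness, this is the kind of setup I moved to, sketched from memory (the Hub id and generation parameters are my assumptions). With device_map="auto", accelerate shards the model across the available GPUs inside a single process, so a plain input() behaves normally:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "codellama/CodeLlama-7b-Instruct-hf"  # assumed Hub id
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",  # one process; model sharded across available GPUs
)

while True:
    question = input("\nQuestion: ")  # single process: one line of input
    inputs = tokenizer(question, return_tensors="pt").to(model.device)
    output = model.generate(**inputs, max_new_tokens=256)
    print(tokenizer.decode(output[0], skip_special_tokens=True))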