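# Local "Cyber Security Researcher Assistant": wraps a QLoRA fine-tuned Mixtral
# checkpoint in a custom LangChain LLM class and runs it as a JSON chat agent with
# optional shell, Python, calculator, Wikipedia and DuckDuckGo search tools.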
### CONFIGURATION
TAKE_ACTIONS = False  # If False, the agent gets no terminal or code-execution access
USE_MEMORY = True
USE_PYTHON = True
USE_CALCULATOR = True
USE_SHELL = True
USE_WIKI_SEARCH = True
USE_SEARCH_ONLINE = True
ASK_HUMAN_BEFORE_EXE_CMDS = True

# Model selection: a local fine-tuned checkpoint or a Hugging Face Hub model ID.
USE_FINE_TUNED_MODEL = True  # if False, fall back to DEFAULT_MODEL
FINE_TUNED_MODEL_PATH = "../finetuning/mixtral/qlora-mixtral-out"
# DEFAULT_MODEL = "cognitivecomputations/dolphin-2.5-mixtral-8x7b"
DEFAULT_MODEL = FINE_TUNED_MODEL_PATH

#############################
#####                   #####
###  DO NOT EDIT BELOW   ###
#####                   #####
#############################
import platform
import warnings
from typing import Optional, List, Mapping, Any

import torch
import numexpr as ne
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
from transformers.models.llama.tokenization_llama_fast import LlamaTokenizerFast
from langchain.llms.base import LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.tools import BaseTool
from langchain.agents import Tool, create_json_chat_agent, AgentExecutor
from langchain.memory import ConversationBufferMemory
from langchain_community.utilities import WikipediaAPIWrapper
from langchain_community.tools import WikipediaQueryRun, ShellTool, DuckDuckGoSearchRun
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
### CODE
USE_PYTHON = USE_PYTHON and TAKE_ACTIONS
USE_SHELL = USE_SHELL and TAKE_ACTIONS
MODEL_NAME = FINE_TUNED_MODEL_PATH if USE_FINE_TUNED_MODEL else DEFAULT_MODEL
QUANTIZATION_CONFIG = BitsAndBytesConfig(load_in_4bit=True)
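# Load the checkpoint in 4-bit via bitsandbytes to reduce VRAM use; device_map="auto"
# lets accelerate place layers across the available GPU(s) and CPU automatically.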
MODEL = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    quantization_config=QUANTIZATION_CONFIG,
    device_map="auto",
)
TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)
# print(type(MODEL))
warnings.filterwarnings("ignore")


# Tool classes
class CustomLLMMistral(LLM):
    model: MixtralForCausalLM  # or MistralForCausalLM for the 7B variant
    tokenizer: LlamaTokenizerFast

    @property
    def _llm_type(self) -> str:
        return "custom"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        messages = [
            {"role": "user", "content": prompt},
        ]
        encodeds = self.tokenizer.apply_chat_template(messages, return_tensors="pt")
        model_inputs = encodeds.to(self.model.device)
        generated_ids = self.model.generate(
            model_inputs,
            max_new_tokens=1024,  # upper bound on the generated output length
            do_sample=True,
            pad_token_id=self.tokenizer.eos_token_id,
            top_k=50,  # wider top_k for more diverse sampling
            temperature=0.5,  # lower temperature for more conservative sampling
            num_return_sequences=1,  # generate a single sequence
        )
        decoded = self.tokenizer.batch_decode(generated_ids)
        try:
            output = decoded[0].split("[/INST]")[1].replace("</s>", "").strip()
        except Exception as e:
            print(e)
            exit()
        if stop is not None:
            for word in stop:
                output = output.split(word)[0].strip()
        # The model sometimes fails to properly close Markdown snippets.
        # If they are not correctly closed, LangChain will struggle to parse the output.
        while not output.endswith("```"):
            output += "`"
        return output

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {"model": self.model}
class Calculator(BaseTool):
    name: str = "calculator"
    description: str = (
        "Use this tool for math operations. It requires numexpr syntax. "
        "Always use it when you need to solve a math operation. "
        "You can give it the whole expression at once. Make sure the syntax is correct."
    )

    def _run(self, expression: str):
        try:
            return ne.evaluate(expression).item()
        except Exception:
            return "This is not valid numexpr syntax. Try a different syntax."

    def _arun(self, expression: str):
        raise NotImplementedError("This tool does not support async")
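# Quick check (illustrative): Calculator().run("2**8 + 37") evaluates with numexpr -> 293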
def custom_py_run(code: str):
    """Execute Python source with exec() and return whatever it printed to stdout."""
    from io import StringIO
    import contextlib

    with contextlib.redirect_stdout(StringIO()) as output:
        try:
            exec(code)
        except Exception as e:
            return str(e)
    return output.getvalue().strip()
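# e.g. custom_py_run("print(1 + 1)") captures stdout and returns "2"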
def _get_platform() -> str:
    system = platform.system()
    if system == "Darwin":
        return "MacOS"
    return system


llm = CustomLLMMistral(model=MODEL, tokenizer=TOKENIZER)
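# Sanity check (illustrative, uncomment to try): print(llm.invoke("Say hello"))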
search = DuckDuckGoSearchRun()
wikipedia = WikipediaQueryRun(
    api_wrapper=WikipediaAPIWrapper(top_k_results=2, doc_content_chars_max=2500)
)
shell_tool = ShellTool(
    description=f"""
    Run shell commands on this {_get_platform()} machine.
    It might return 'None' if the user did not allow the command to run; if that happens,
    conclude with 'User has not allowed to run the command'.
    If some package is not available, suggest installing the packages required to run the program.
    Some commands do not show output; you may use other tools to find out whether they ran correctly,
    or append 'echo $?' to check whether the command succeeded, then take the actions needed.
    """,
    ask_human_input=ASK_HUMAN_BEFORE_EXE_CMDS,
)
wikipedia_tool = Tool(
    name="wikipedia",
    description=(
        "Never search for more than one concept in a single step. "
        "If you need to compare two concepts, search for each one individually "
        "and use the metadata syntax when prompting for the next query. "
        "Syntax: a string with a single, simple concept"
    ),
    func=wikipedia.run,
)
search_tool = Tool(
    name="Search online",
    description="This tool searches the internet for results.",
    func=search.run,
)
calculator_tool = Calculator()
py_tool = Tool(
    name="Python",
    description=f"""This is a Python execution tool; use it only when code execution is required.
    It runs Python code on this {_get_platform()} machine, using exec() to evaluate the code,
    and collects and returns whatever the code prints.
    Provide valid Python code to be run in one go.
    """,
    func=custom_py_run,
)
tools = []
if USE_SEARCH_ONLINE:
    tools.append(search_tool)
if USE_SHELL:
    tools.append(shell_tool)
if USE_PYTHON:
    tools.append(py_tool)
if USE_CALCULATOR:
    tools.append(calculator_tool)
if USE_WIKI_SEARCH:
    tools.append(wikipedia_tool)
system = """ | |
As a Cyber Security Researcher Assistant, you're equipped to solve tasks. Try Minimize the steps to solve the task, but make sure it is correct. Each task requires a series of steps, represented by JSON blobs with specific keys called Meta Data: | |
thought -> your thoughts | |
action -> tool name or "Final Answer" to give a final answer | |
action_input -> tool parameters or the final solution | |
Available tools: {tool_names} | |
Tool descriptions: | |
{tools} | |
If you have enough information, use "Final Answer" with the solution. | |
If information is insufficient or incorrect, try once. If the issue persists, use "Final Answer". | |
""" | |
human = """ | |
Add "STOP" after each snippet. Follow this JSON schema: | |
```json | |
{{"thought": "<your thoughts>", | |
"action": "<tool name or "Final Answer" to give a final answer or "None" as a to conclude>", | |
"action_input": "<tool parameters or final output>"}} | |
```\n | |
STOP | |
Query: "{input}". Provide the next necessary step only. | |
Base your answer on previous steps, even if you believe you know the solution. | |
Add "STOP" after each snippet. | |
Previous steps and gathered information: | |
""" | |
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        MessagesPlaceholder("chat_history", optional=True),
        ("human", human),
        MessagesPlaceholder("agent_scratchpad"),
    ]
)
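# "chat_history" is filled by ConversationBufferMemory below (when USE_MEMORY is on);
# "agent_scratchpad" receives the agent's intermediate tool calls and observations.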
agent = create_json_chat_agent(
    tools=tools,
    llm=llm,
    prompt=prompt,
    stop_sequence=["STOP"],
    template_tool_response="{observation}",
)
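# stop_sequence=["STOP"] is handed to CustomLLMMistral._call as the stop list, which
# truncates generation at the delimiter the human prompt asks for after each JSON blob.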
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
agent_executor = AgentExecutor(
    agent=agent,
    tools=tools,
    verbose=True,
    handle_parsing_errors=True,
    memory=memory if USE_MEMORY else None,
)

# Display the currently enabled tools
print("Tools: [", ", ".join(tool.name for tool in tools), "]")

# Prompt input
x = input("input: ")
run = agent_executor.invoke(
    {"input": x}, {"recursion_limit": 4, "max_concurrency": 1}
)
print(run)
out = run["output"]
print(f"\n### FINAL OUTPUT ###\n{out}\n")