Last active
July 6, 2024 00:53
-
-
Save ralphbean/8ea941b0cf06b92191ac4b3074ede656 to your computer and use it in GitHub Desktop.
refresh-oci-copy-file.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python
""" Write oci-copy.yaml file based on the latest data on Hugging Face

In order to get the latest revision:

    $ python3 refresh-oci-copy-file.py prometheus-eval/prometheus-8x7b-v2.0

In order to get files and digests for a specific revision in the history:

    $ python3 refresh-oci-copy-file.py --revision e0bb4692356a1738acf25f15180e9f025725b0f2 prometheus-eval/prometheus-8x7b-v2.0
"""
import argparse | |
import hashlib | |
import logging | |
import mimetypes | |
import os | |
import yaml | |
import httpx | |
# Command line: one required positional argument (the repository) plus an
# optional --revision (defaults to "main") and a --debug switch.
parser = argparse.ArgumentParser()
parser.add_argument("repository")
parser.add_argument("--revision", default="main")
parser.add_argument("--debug", action="store_true", default=False)
args = parser.parse_args()

# --debug turns on DEBUG-level logging; otherwise the threshold is WARN.
logging.basicConfig(level=logging.DEBUG if args.debug else logging.WARN)
# File suffixes registered with ``mimetypes`` so that guess_type() reports
# "application/octet-stream" for them.
known_types = dict.fromkeys(
    (".safetensors", ".model", ".gguf", ".pt"),
    "application/octet-stream",
)
for suffix, mime_type in known_types.items():
    mimetypes.add_type(mime_type, suffix)
def determine_digest(url, info, token):
    """Return the hex SHA-256 digest of the file described by ``info``.

    Args:
        url: Download URL of the file. It is only fetched when ``info``
            carries no LFS metadata.
        info: Mapping describing the file. When it has a truthy ``"lfs"``
            entry, that entry's ``"sha256"`` value is returned as-is and no
            request is made.
        token: Optional bearer token; sent in the ``Authorization`` header
            when truthy.

    Returns:
        The SHA-256 digest as a hexadecimal string.

    Raises:
        httpx.HTTPStatusError: If the download does not end in a success
            response.
    """
    lfs = info.get("lfs")
    if lfs:
        return lfs["sha256"]

    headers = {"Authorization": f"Bearer {token}"} if token else {}
    digest = hashlib.sha256()
    # Stream the body so the whole file is never held in memory, and follow
    # redirects so a 3xx reply is downloaded instead of failing the status
    # check.
    with httpx.stream(
        "GET", url, headers=headers, follow_redirects=True
    ) as response:
        response.raise_for_status()
        for chunk in response.iter_bytes():
            digest.update(chunk)
    return digest.hexdigest()
# Optional API token, taken from the environment.
token = os.environ.get("HUGGINGFACE_TOKEN")

print(f"🤗 Querying huggingface.co for {args.repository}")
headers = {}
if token:
    # Only the first six characters are shown; the remainder is masked.
    print(f"🔑 Using key {token[:6]}{'*' * len(token[6:])}")
    headers["Authorization"] = f"Bearer {token}"
else:
    print(
        "🤷 No $HUGGINGFACE_TOKEN environment variable found. "
        "Proceeding unauthenticated."
    )

# Fetch the metadata for the requested revision of the repository.
url = f"https://huggingface.co/api/models/{args.repository}/revision/{args.revision}"
params = dict(blobs=True)
response = httpx.get(url=url, headers=headers, params=params)
response.raise_for_status()
data = response.json()

# Download URLs are built from the "sha" value in the response rather than
# from args.revision.
revision = data["sha"]

result = {"artifact_type": "application/x-mlmodel", "artifacts": []}
for sibling in data["siblings"]:
    filename = sibling["rfilename"]
    # Skip entries whose path starts with a dot.
    if filename.startswith("."):
        continue
    url = f"https://huggingface.co/{args.repository}/resolve/{revision}/{filename}"
    print(f"🔗 Considering {url}")
    artifact = {
        "source": url,
        "filename": filename,
        # guess_type() yields None here when the suffix is not recognised.
        "type": mimetypes.guess_type(url)[0],
        "sha256sum": determine_digest(url, sibling, token),
    }
    result["artifacts"].append(artifact)

print("💾 Writing oci-copy.yaml")
with open("oci-copy.yaml", "w", encoding="utf-8") as f:
    # Serialise straight into the file instead of building the whole
    # document as a string first.
    yaml.dump(result, f, sort_keys=True)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment