Skip to content

Instantly share code, notes, and snippets.

View patil-suraj's full-sized avatar
🏠
Working from home

Suraj Patil patil-suraj

🏠
Working from home
View GitHub Profile
@patil-suraj
patil-suraj / pipeline_parallel.py
Created October 2, 2024 11:32 — forked from 3outeille/pipeline_parallel.py
Self contained example of how pipeline parallel works (AFAB and 1F1B) in 200 LOC
#VERBOSE=0 torchrun --nproc_per_node 3 self_contained_pp_LOC.py
import os, random, numpy as np, torch, torch.nn as nn, torch.distributed as dist, torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader, DistributedSampler
from datasets import load_dataset
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
STEP, local_rank, world_size, verbose = 0, int(os.environ["LOCAL_RANK"]), int(os.environ["WORLD_SIZE"]), os.environ.get("VERBOSE", "0") == "1"
def set_all_seed(seed):
@patil-suraj
patil-suraj / imagent-class-mapping.json
Created March 23, 2023 14:17
imagent-class-mapping
{"0": "tench, Tinca tinca",
"1": "goldfish, Carassius auratus",
"2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
"3": "tiger shark, Galeocerdo cuvieri",
"4": "hammerhead, hammerhead shark",
"5": "electric ray, crampfish, numbfish, torpedo",
"6": "stingray",
"7": "cock",
"8": "hen",
"9": "ostrich, Struthio camelus",
import torch
import torch.utils.benchmark as benchmark
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from diffusers.models.cross_attention import TorchAttentionProcessor
def benchmark_torch_function(f, *args, **kwargs):
t0 = benchmark.Timer(
stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
)
return round(t0.blocked_autorange(min_run_time=1).mean, 2)
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import logging
from pathlib import Path
import shutil
@patil-suraj
patil-suraj / onnx_t5.py
Last active April 13, 2023 10:44
Speeding up T5 with onnx 🚀
import inspect
import logging
import os
from pathlib import Path
import torch
from psutil import cpu_count
from transformers import T5Config, T5ForConditionalGeneration, T5Tokenizer
from transformers.generation_utils import GenerationMixin
from transformers.modeling_outputs import BaseModelOutputWithPast, Seq2SeqLMOutput
import logging
from finetune_trainer import DataTrainingArguments, Seq2SeqTrainingArguments
from seq2seq_trainer import Seq2SeqTrainer
from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer
from utils import Seq2SeqDataCollator, Seq2SeqDataset, build_compute_metrics_fn, freeze_embeds, freeze_params
# Setup logging
logging.basicConfig(level=logging.INFO)
@patil-suraj
patil-suraj / update_swift.sh
Created October 4, 2019 15:44 — forked from sgugger/update_swift.sh
Updating S4TF
#!/usr/bin/env bash
pushd ~/swift/
rm -rf usr
popd
pushd ~/download/
rm swift-tensorflow-DEVELOPMENT-cuda10.0-cudnn7-ubuntu18.04.tar.gz
wget https://storage.googleapis.com/s4tf-kokoro-artifact-testing/latest/swift-tensorflow-DEVELOPMENT-cuda10.0-cudnn7-ubuntu18.04.tar.gz
tar -xf swift-tensorflow-DEVELOPMENT-cuda10.0-cudnn7-ubuntu18.04.tar.gz
mv usr/ ~/swift/
mv ~/swift/usr/lib/python3.6 ~/swift/usr/lib/python3.7
@patil-suraj
patil-suraj / yahoo_finance.py
Created June 12, 2019 06:07 — forked from scrapehero/yahoo_finance.py
Python 2 code to extract stock market data from Yahoo Finance
from lxml import html
import requests
from exceptions import ValueError
from time import sleep
import json
import argparse
from collections import OrderedDict
from time import sleep
def parse(ticker):
@patil-suraj
patil-suraj / bottom_sheet_fix.dart
Created May 4, 2019 05:47 — forked from crimsonsuv/bottom_sheet_fix.dart
Flutter Modal bottom sheet whith input fix and full screen sheet
//Flutter Modal Bottom Sheet
//Modified by Suvadeep Das
//Based on https://gist.github.com/andrelsmoraes/9e4af0133bff8960c1feeb0ead7fd749
import 'dart:async';
import 'package:flutter/material.dart';
import 'package:meta/meta.dart';
const Duration _kBottomSheetDuration = const Duration(milliseconds: 200);