Created
April 16, 2023 23:44
-
-
Save burke/27a5c4ae1f9f3ffbd90f87b2994d4911 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
import os | |
# Bundled third-party packages (the `openai` import below comes from here)
# live in a "_vendor" directory next to this file. Prepending it to sys.path
# makes Python search it before any other location.
_here = os.path.dirname(os.path.abspath(__file__))
vendor_dir = os.path.join(_here, '_vendor')
del _here
sys.path.insert(0, vendor_dir)
# Now you can import the bundled library | |
import openai | |
from aqt import mw, gui_hooks | |
from aqt.qt import QAction | |
from aqt.utils import showInfo, showWarning | |
from aqt.progress import ProgressManager | |
from PyQt5.QtCore import QRunnable, QThreadPool, pyqtSignal, QObject, Qt, QSemaphore | |
class UpdateNoteTask(QRunnable):
    """Thread-pool job that fills the "Why" field of a single note.

    Whatever happens inside the update, the job always reports progress and
    releases its semaphore slot when it finishes. The code that submits jobs
    acquires that semaphore before each submission, so a failed request must
    never leave a slot permanently taken.
    """

    def __init__(self, note, update_signal, semaphore):
        super(UpdateNoteTask, self).__init__()
        self.note = note
        self.update_signal = update_signal
        self.canceled = False
        self.semaphore = semaphore
        # Exception raised while updating the note, or None if the job
        # succeeded (or was skipped because it had been canceled).
        self.error = None

    def run(self):
        """Update the note unless canceled; always signal progress and release."""
        try:
            if not self.canceled:
                update_why_field(self.note)
        except Exception as exc:
            # Do not let the exception escape QRunnable.run on a worker thread
            # (PyQt may treat unhandled exceptions there as fatal). Record it
            # so the submitting code can report it, and print the traceback.
            import traceback

            self.error = exc
            traceback.print_exc()
        finally:
            self.update_signal.increment_progress()
            self.semaphore.release()

    def on_canceled(self, canceled):
        """Callback for UpdateSignal.check_canceled: store the cancel flag."""
        self.canceled = canceled
class UpdateSignal(QObject):
    """QObject that carries the signals shared with UpdateNoteTask jobs.

    UpdateNoteTask derives from QRunnable, which is not a QObject, so the
    jobs emit through an instance of this class instead.
    """

    # Emitted once per finished job, via increment_progress().
    progress_increment = pyqtSignal()
    # Delivers a cancellation flag to callbacks registered with is_canceled().
    check_canceled = pyqtSignal(bool)

    def increment_progress(self):
        """Emit progress_increment; called by worker jobs as they finish."""
        signal = self.progress_increment
        signal.emit()

    def is_canceled(self, callback):
        """Subscribe *callback* to the check_canceled flag.

        Despite its name, this queries nothing. It only connects *callback*
        to the check_canceled signal.
        """
        signal = self.check_canceled
        signal.connect(callback)
def find_empty_why_notes():
    """Return all "Cloze-GPT" notes whose "Why" field is blank.

    A field is blank when it is empty or contains only whitespace. The search
    also applies a tag filter: notes tagged "suspended" are skipped unless
    they also carry a tag matching ``suspended-_*``.
    """
    search = (
        "note:Cloze-GPT "  # only notes of the "Cloze-GPT" note type
        "(-tag:suspended or tag:suspended-_*)"  # skip notes tagged "suspended"
    )
    # NOTE(review): this filters on a *tag* named "suspended". Anki's own
    # suspension is a card state (searchable as is:suspended); confirm which
    # one was intended.
    blank = []
    for nid in mw.col.find_notes(search):
        candidate = mw.col.get_note(nid)
        if not candidate["Why"].strip():
            blank.append(candidate)
    return blank
def _pump_events_until(is_done):
    """Process pending Qt events until ``is_done()`` returns True.

    ``is_done`` is expected to block briefly itself (a timed wait), so this
    loop does not spin. Between polls the GUI stays responsive, the progress
    dialog repaints, and queued progress signals from worker threads are
    delivered.
    """
    while not is_done():
        mw.app.processEvents()


def update_all_empty_why_notes(max_workers=4):
    """Fill the "Why" field of every Cloze-GPT note where it is blank.

    Each note is handled by an UpdateNoteTask on a private QThreadPool, with
    at most *max_workers* tasks in flight. While waiting for free slots and
    for completion, the GUI thread keeps processing events. That way the
    progress dialog advances instead of freezing until every request is done.

    Args:
        max_workers: maximum number of concurrent update tasks (at least 1).
    """
    max_workers = max(1, int(max_workers))

    # Validate the configuration once, on the GUI thread, rather than letting
    # every worker hit update_why_field's missing-key branch off-thread.
    config = get_config()
    if not config["openai_api_key"]:
        showWarning("Please set your OpenAI API key in the add-on configuration.")
        return

    empty_why_notes = find_empty_why_notes()
    total = len(empty_why_notes)

    progress = ProgressManager(mw)
    progress.start(immediate=True, min=0, max=total, label="Updating notes...")

    completed = 0

    def on_progress_increment():
        # Delivered on the GUI thread. Keep our own count instead of reading
        # ProgressManager's private _counter attribute.
        nonlocal completed
        completed += 1
        progress.update(value=completed)

    thread_pool = QThreadPool()
    semaphore = QSemaphore(max_workers)
    update_signal = UpdateSignal()
    update_signal.progress_increment.connect(on_progress_increment)
    # NOTE(review): nothing visible here emits check_canceled, so cancelling
    # is not wired up yet; each task still registers its on_canceled callback.

    tasks = []
    try:
        for note in empty_why_notes:
            # Wait for a free slot without freezing the event loop.
            _pump_events_until(lambda: semaphore.tryAcquire(1, 50))
            task = UpdateNoteTask(note, update_signal, semaphore)
            # Keep the task owned by Python (held in `tasks`) so its state can
            # be inspected after the pool has run it.
            task.setAutoDelete(False)
            tasks.append(task)
            update_signal.is_canceled(task.on_canceled)
            thread_pool.start(task)
        _pump_events_until(lambda: thread_pool.waitForDone(50))
        # Deliver progress signals queued by the last tasks to finish.
        mw.app.processEvents()
    finally:
        progress.finish()

    failed = sum(1 for task in tasks if getattr(task, "error", None) is not None)
    if failed:
        showWarning(f"Updated {total - failed} notes; {failed} could not be updated.")
    else:
        showInfo(f"Updated {total} notes.")
def fill_prompt_template(template, text, extra):
    """Substitute ``{text}`` and ``{extra}`` in *template*, leaving other braces alone.

    ``str.format`` would collapse literal ``{{...}}`` (such as the
    ``{{c<n>::answer}}`` cloze markup in the default prompt) to single braces,
    and would reject any other single-brace text as an unknown field. This
    helper avoids both. Substitution is a single pass with a callable
    replacement, so placeholder-like text or backslashes inside *text* and
    *extra* are inserted verbatim and never expanded again.

    >>> fill_prompt_template("{{c1::x}} {text}", text="T", extra="")
    '{{c1::x}} T'
    """
    import re

    values = {"text": text, "extra": extra}
    return re.sub(r"\{(text|extra)\}", lambda match: values[match.group(1)], template)


def update_why_field(note):
    """Generate an explanation for *note* and store it in its "Why" field.

    Reads ``openai_api_key`` and ``prompt_template`` from the add-on config.
    Fills the template with the note's "Text" and "Extra" fields and sends it
    to the OpenAI chat API (``gpt-3.5-turbo``). The stripped reply is stored
    in "Why" and saved with ``note.flush()``. If no API key is configured,
    returns early without calling the API.
    """
    config = get_config()
    if not config["openai_api_key"]:
        message = "Please set your OpenAI API key in the add-on configuration."
        if mw.inMainThread():
            showWarning(message)
        else:
            # Qt dialogs may only be created on the GUI thread. When running
            # inside a thread-pool worker, report the problem on stderr.
            print(message, file=sys.stderr)
        return

    # Build the prompt from the note's fields. Literal cloze markup in the
    # template must reach the model unchanged, so str.format is not used.
    prompt = fill_prompt_template(
        config["prompt_template"], text=note["Text"], extra=note["Extra"]
    )

    openai.api_key = config["openai_api_key"]
    response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ],
        n=1,
        temperature=0.7,
    )
    note["Why"] = response.choices[0].message.content.strip()
    note.flush()
def get_config():
    """Return this add-on's configuration as provided by Anki's add-on manager.

    Other code in this file reads the ``openai_api_key`` and
    ``prompt_template`` keys from the result.
    """
    addon_manager = mw.addonManager
    return addon_manager.getConfig(__name__)
def on_test_addon_action():
    """Tools-menu callback: run the bulk update of empty "Why" fields."""
    update_all_empty_why_notes()
def setup_menu():
    """Arrange for a "Test Cloze-GPT Enhancer" entry in the Tools menu.

    The entry is added when Anki fires ``gui_hooks.main_window_did_init``.
    """

    def on_main_window_did_init():
        # Passing mw as the parent ties the action's lifetime to the main
        # window rather than to this local variable.
        action = QAction("Test Cloze-GPT Enhancer", mw)
        action.triggered.connect(on_test_addon_action)
        tools_menu = mw.form.menuTools
        tools_menu.addAction(action)

    gui_hooks.main_window_did_init.append(on_main_window_did_init)


setup_menu()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"openai_api_key": "", | |
"prompt_template": "This is a cloze flash card, and clozes are marked with {{c<n>::answer}} or sometimes {{c<n>::answer::hint}}:\n\n{text}\n\nPlease provide some brief detail about why this is the case, or, if explaining why doesn't feel appropriate, just provide a little more detail:" | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment