
Commit 7ca7c47

Make quote style consistent (#891)
1 parent 9276edb commit 7ca7c47

24 files changed: +240, -82 lines changed
Lines changed: 158 additions & 0 deletions
@@ -0,0 +1,158 @@
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
# Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
# Code repository: https://github.com/rasbt/reasoning-from-scratch

# Verify that Python source files (and optionally notebooks) use double quotes for strings.

import argparse
import ast
import io
import json
import sys
import tokenize
from pathlib import Path

EXCLUDED_DIRS = {
    ".git",
    ".hg",
    ".mypy_cache",
    ".pytest_cache",
    ".ruff_cache",
    ".svn",
    ".tox",
    ".venv",
    "__pycache__",
    "build",
    "dist",
    "node_modules",
}

PREFIX_CHARS = {"r", "u", "f", "b"}
SINGLE_QUOTE = "'"
DOUBLE_QUOTE = "\""
TRIPLE_SINGLE = SINGLE_QUOTE * 3
TRIPLE_DOUBLE = DOUBLE_QUOTE * 3


def should_skip(path):
    parts = set(path.parts)
    return bool(EXCLUDED_DIRS & parts)


def collect_fstring_expr_string_positions(source):
    """
    Return set of (lineno, col_offset) for string literals that appear inside
    formatted expressions of f-strings. These should be exempt from the double
    quote check, since enforcing double quotes there is unnecessarily strict.
    """
    try:
        tree = ast.parse(source)
    except SyntaxError:
        return set()

    positions = set()

    class Collector(ast.NodeVisitor):
        def visit_JoinedStr(self, node):
            for value in node.values:
                if isinstance(value, ast.FormattedValue):
                    self._collect_from_expr(value.value)
            # Continue walking to catch nested f-strings within expressions
            self.generic_visit(node)

        def _collect_from_expr(self, node):
            if isinstance(node, ast.Constant) and isinstance(node.value, str):
                positions.add((node.lineno, node.col_offset))
            elif isinstance(node, ast.Str):  # Python <3.8 compatibility
                positions.add((node.lineno, node.col_offset))
            else:
                for child in ast.iter_child_nodes(node):
                    self._collect_from_expr(child)

    Collector().visit(tree)
    return positions


def check_quotes_in_source(source, path):
    violations = []
    ignored_positions = collect_fstring_expr_string_positions(source)
    tokens = tokenize.generate_tokens(io.StringIO(source).readline)
    for tok_type, tok_str, start, _, _ in tokens:
        if tok_type == tokenize.STRING:
            if start in ignored_positions:
                continue
            lowered = tok_str.lower()
            # ignore triple-quoted strings
            if lowered.startswith((TRIPLE_DOUBLE, TRIPLE_SINGLE)):
                continue

            # find the prefix and quote type
            # prefix = ""
            for c in PREFIX_CHARS:
                if lowered.startswith(c):
                    # prefix = c
                    lowered = lowered[1:]
                    break

            # report if not using double quotes
            if lowered.startswith(SINGLE_QUOTE):
                line, col = start
                violations.append(f"{path}:{line}:{col}: uses single quotes")
    return violations


def check_file(path):
    try:
        if path.suffix == ".ipynb":
            return check_notebook(path)
        else:
            text = path.read_text(encoding="utf-8")
            return check_quotes_in_source(text, path)
    except Exception as e:
        return [f"{path}: failed to check ({e})"]


def check_notebook(path):
    violations = []
    with open(path, encoding="utf-8") as f:
        nb = json.load(f)
    for cell in nb.get("cells", []):
        if cell.get("cell_type") == "code":
            src = "".join(cell.get("source", []))
            violations.extend(check_quotes_in_source(src, path))
    return violations


def parse_args():
    parser = argparse.ArgumentParser(description="Verify double-quoted string literals.")
    parser.add_argument(
        "--include-notebooks",
        action="store_true",
        help="Also scan Jupyter notebooks (.ipynb files) for single-quoted strings.",
    )
    return parser.parse_args()


def main():
    args = parse_args()
    project_root = Path(".").resolve()
    py_files = sorted(project_root.rglob("*.py"))
    notebook_files = sorted(project_root.rglob("*.ipynb")) if args.include_notebooks else []

    violations = []
    for path in py_files + notebook_files:
        if should_skip(path):
            continue
        violations.extend(check_file(path))

    if violations:
        print("\n".join(violations))
        print(f"\n{len(violations)} violations found.")
        return 1

    print("All files use double quotes correctly.")
    return 0


if __name__ == "__main__":
    sys.exit(main())
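As a rough illustration of the checker's behavior (a hypothetical input file, not part of this commit): ordinary single-quoted literals are reported, while double-quoted strings, triple-quoted strings, and string literals inside f-string replacement fields are left alone.

# sample.py -- hypothetical file used only to illustrate the checker
name = 'Ada'                    # flagged: single-quoted string literal
greeting = "Hello"              # passes: double-quoted
notes = '''legacy docstring'''  # passes: triple-quoted strings are ignored
config = {"key": "value"}
label = f"{config['key']}"      # passes: strings inside f-string expressions are exempt

The script's path is not shown in this excerpt; assuming it were saved as, say, check_quote_style.py at the repository root (a hypothetical name), it could be run from there as "python check_quote_style.py", or with --include-notebooks to also scan .ipynb files. It exits with status 1 if any violations are printed.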

appendix-D/01_main-chapter-code/previous_chapters.py

Lines changed: 1 addition & 1 deletion
@@ -73,7 +73,7 @@ def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=Fal
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
-        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
+        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape

appendix-E/01_main-chapter-code/previous_chapters.py

Lines changed: 4 additions & 4 deletions
@@ -80,7 +80,7 @@ def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=Fal
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
-        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
+        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape

@@ -257,8 +257,8 @@ def assign(left, right):


def load_weights_into_gpt(gpt, params):
-    gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params['wpe'])
-    gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params['wte'])
+    gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params["wpe"])
+    gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params["wte"])

    for b in range(len(params["blocks"])):
        q_w, k_w, v_w = np.split(

@@ -318,7 +318,7 @@ def load_weights_into_gpt(gpt, params):


def text_to_token_ids(text, tokenizer):
-    encoded = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
+    encoded = tokenizer.encode(text, allowed_special={"<|endoftext|>"})
    encoded_tensor = torch.tensor(encoded).unsqueeze(0)  # add batch dimension
    return encoded_tensor

ch02/02_bonus_bytepair-encoder/bpe_openai_gpt2.py

Lines changed: 14 additions & 14 deletions
@@ -70,7 +70,7 @@ def get_pairs(word):


class Encoder:
-    def __init__(self, encoder, bpe_merges, errors='replace'):
+    def __init__(self, encoder, bpe_merges, errors="replace"):
        self.encoder = encoder
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # how to handle errors in decoding

@@ -92,7 +92,7 @@ def bpe(self, token):
            return token

        while True:
-            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
+            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram

@@ -119,43 +119,43 @@ def bpe(self, token):
                break
            else:
                pairs = get_pairs(word)
-        word = ' '.join(word)
+        word = " ".join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        for token in re.findall(self.pat, text):
-            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
-            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
+            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
+            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
        return bpe_tokens

    def decode(self, tokens):
-        text = ''.join([self.decoder[token] for token in tokens])
-        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
+        text = "".join([self.decoder[token] for token in tokens])
+        text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
        return text


def get_encoder(model_name, models_dir):
-    with open(os.path.join(models_dir, model_name, 'encoder.json'), 'r') as f:
+    with open(os.path.join(models_dir, model_name, "encoder.json"), "r") as f:
        encoder = json.load(f)
-    with open(os.path.join(models_dir, model_name, 'vocab.bpe'), 'r', encoding="utf-8") as f:
+    with open(os.path.join(models_dir, model_name, "vocab.bpe"), "r", encoding="utf-8") as f:
        bpe_data = f.read()
-    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
+    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]]
    return Encoder(encoder=encoder, bpe_merges=bpe_merges)


def download_vocab():
    # Modified code from
-    subdir = 'gpt2_model'
+    subdir = "gpt2_model"
    if not os.path.exists(subdir):
        os.makedirs(subdir)
-    subdir = subdir.replace('\\', '/')  # needed for Windows
+    subdir = subdir.replace("\\", "/")  # needed for Windows

-    for filename in ['encoder.json', 'vocab.bpe']:
+    for filename in ["encoder.json", "vocab.bpe"]:
        r = requests.get("https://openaipublic.blob.core.windows.net/gpt-2/models/117M/" + filename, stream=True)

-        with open(os.path.join(subdir, filename), 'wb') as f:
+        with open(os.path.join(subdir, filename), "wb") as f:
            file_size = int(r.headers["content-length"])
            chunk_size = 1000
            with tqdm(ncols=100, desc="Fetching " + filename, total=file_size, unit_scale=True) as pbar:

ch04/01_main-chapter-code/previous_chapters.py

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=Fal
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
-        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
+        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape

ch04/01_main-chapter-code/tests.py

Lines changed: 2 additions & 2 deletions
@@ -33,8 +33,8 @@ def test_main(capsys):
    captured = capsys.readouterr()

    # Normalize line endings and strip trailing whitespace from each line
-    normalized_expected = '\n'.join(line.rstrip() for line in expected.splitlines())
-    normalized_output = '\n'.join(line.rstrip() for line in captured.out.splitlines())
+    normalized_expected = "\n".join(line.rstrip() for line in expected.splitlines())
+    normalized_output = "\n".join(line.rstrip() for line in captured.out.splitlines())

    # Compare normalized strings
    assert normalized_output == normalized_expected

ch05/01_main-chapter-code/previous_chapters.py

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=Fal
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
-        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
+        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape

ch05/03_bonus_pretraining_on_gutenberg/prepare_dataset.py

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ def combine_files(file_paths, target_dir, max_size_mb=500, separator="<|endoftex
        content = strip_headers(content)

        # Regular expression to replace multiple blank lines with a single blank line
-        content = re.sub(r'\n\s*\n', '\n\n', content)
+        content = re.sub(r"\n\s*\n", "\n\n", content)
        estimated_size = len(content.encode("utf-8"))

        if current_size + estimated_size > max_size_mb * 1024 * 1024:
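For intuition, a minimal sketch (an assumed example, not taken from the repository) of what this substitution does: any run of blank or whitespace-only lines collapses to a single blank line.

import re

text = "First paragraph.\n\n   \n\nSecond paragraph."
print(re.sub(r"\n\s*\n", "\n\n", text))
# First paragraph.
#
# Second paragraph.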

ch05/03_bonus_pretraining_on_gutenberg/pretraining_simple.py

Lines changed: 20 additions & 20 deletions
@@ -148,26 +148,26 @@ def train_model_simple(model, optimizer, device, n_epochs,

if __name__ == "__main__":

-    parser = argparse.ArgumentParser(description='GPT Model Training Configuration')
-
-    parser.add_argument('--data_dir', type=str, default='gutenberg/data',
-                        help='Directory containing the training data')
-    parser.add_argument('--output_dir', type=str, default='model_checkpoints',
-                        help='Directory where the model checkpoints will be saved')
-    parser.add_argument('--n_epochs', type=int, default=1,
-                        help='Number of epochs to train the model')
-    parser.add_argument('--print_sample_iter', type=int, default=1000,
-                        help='Iterations between printing sample outputs')
-    parser.add_argument('--eval_freq', type=int, default=100,
-                        help='Frequency of evaluations during training')
-    parser.add_argument('--save_ckpt_freq', type=int, default=100_000,
-                        help='Frequency of saving model checkpoints during training')
-    parser.add_argument('--lr', type=float, default=5e-4,
-                        help='Learning rate for the optimizer')
-    parser.add_argument('--batch_size', type=int, default=4,
-                        help='Batch size for training')
-    parser.add_argument('--debug', type=bool, default=False,
-                        help='Uses a very small model for debugging purposes')
+    parser = argparse.ArgumentParser(description="GPT Model Training Configuration")
+
+    parser.add_argument("--data_dir", type=str, default="gutenberg/data",
+                        help="Directory containing the training data")
+    parser.add_argument("--output_dir", type=str, default="model_checkpoints",
+                        help="Directory where the model checkpoints will be saved")
+    parser.add_argument("--n_epochs", type=int, default=1,
+                        help="Number of epochs to train the model")
+    parser.add_argument("--print_sample_iter", type=int, default=1000,
+                        help="Iterations between printing sample outputs")
+    parser.add_argument("--eval_freq", type=int, default=100,
+                        help="Frequency of evaluations during training")
+    parser.add_argument("--save_ckpt_freq", type=int, default=100_000,
+                        help="Frequency of saving model checkpoints during training")
+    parser.add_argument("--lr", type=float, default=5e-4,
+                        help="Learning rate for the optimizer")
+    parser.add_argument("--batch_size", type=int, default=4,
+                        help="Batch size for training")
+    parser.add_argument("--debug", type=bool, default=False,
+                        help="Uses a very small model for debugging purposes")

    args = parser.parse_args()
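For reference, a typical invocation of this script with the arguments above might look like: python pretraining_simple.py --data_dir gutenberg/data --output_dir model_checkpoints --n_epochs 1 --batch_size 4 --lr 5e-4 (an illustrative command built from the parser's own defaults, not something shown in this commit).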

ch05/05_bonus_hparam_tuning/hparam_search.py

Lines changed: 1 addition & 1 deletion
@@ -118,7 +118,7 @@ def train_model(model, train_loader, val_loader, optimizer, device,
    print(f"Total hyperparameter configurations: {total_combinations}")

    # Placeholder for the best loss and best hyperparameters
-    best_val_loss = float('inf')
+    best_val_loss = float("inf")
    best_hparams = {}

    script_path = os.path.abspath(__file__)

0 commit comments
