Add tests + fix db.py SQLite commit pattern

Tests:
- tests/test_commands.py: parse_args, extract_args, format_bounty
- tests/test_db.py: full CRUD + tracking + reminders
- tests/conftest.py: temp DB fixture
- requirements-dev.txt: pytest + pytest-asyncio

db.py fixes:
- Explicit conn.commit() after every write (SQLite row_factory
  disables implicit transaction management)
- fetchone() before commit() (can't commit while cursor open)
- Functions return dict instead of sqlite3.Row
This commit is contained in:
shokollm
2026-04-01 08:41:44 +00:00
parent 9f0ad2d404
commit 7957947a04
6 changed files with 543 additions and 29 deletions

View File

@@ -10,11 +10,14 @@ DB_PATH = Path(__file__).parent / "jigaido.db"
def get_conn() -> sqlite3.Connection:
conn = sqlite3.connect(DB_PATH, detect_types=sqlite3.PARSE_DECLTYPES)
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA foreign_keys = ON")
return conn
def _row_to_dict(row: sqlite3.Row) -> dict:
return dict(row)
def init_db() -> None:
schema = (Path(__file__).parent / "schema.sql").read_text()
with get_conn() as conn:
@@ -32,15 +35,18 @@ def upsert_user(telegram_user_id: int, username: str | None) -> int:
RETURNING id""",
(telegram_user_id, username),
)
return cur.fetchone()["id"]
result = cur.fetchone()
conn.commit()
return result[0]
def get_user_by_telegram_id(telegram_user_id: int) -> Optional[sqlite3.Row]:
def get_user_by_telegram_id(telegram_user_id: int) -> Optional[dict]:
with get_conn() as conn:
return conn.execute(
row = conn.execute(
"SELECT * FROM users WHERE telegram_user_id = ?",
(telegram_user_id,),
).fetchone()
return _row_to_dict(row) if row else None
# ── Groups ─────────────────────────────────────────────────────────────────
@@ -56,15 +62,18 @@ def upsert_group(telegram_chat_id: int, creator_user_id: int) -> int:
RETURNING id""",
(telegram_chat_id, creator_user_id),
)
return cur.fetchone()["id"]
result = cur.fetchone()
conn.commit()
return result[0]
def get_group(telegram_chat_id: int) -> Optional[sqlite3.Row]:
def get_group(telegram_chat_id: int) -> Optional[dict]:
with get_conn() as conn:
return conn.execute(
row = conn.execute(
"SELECT * FROM groups WHERE telegram_chat_id = ?",
(telegram_chat_id,),
).fetchone()
return _row_to_dict(row) if row else None
def get_group_creator_user_id(group_id: int) -> Optional[int]:
@@ -73,7 +82,7 @@ def get_group_creator_user_id(group_id: int) -> Optional[int]:
"SELECT creator_user_id FROM groups WHERE id = ?",
(group_id,),
).fetchone()
return row["creator_user_id"] if row else None
return row[0] if row else None
# ── Group Admins ────────────────────────────────────────────────────────────
@@ -86,6 +95,7 @@ def add_group_admin(group_id: int, user_id: int) -> bool:
"INSERT INTO group_admins (group_id, user_id) VALUES (?, ?)",
(group_id, user_id),
)
conn.commit()
return True
except sqlite3.IntegrityError:
return False
@@ -98,6 +108,7 @@ def remove_group_admin(group_id: int, user_id: int) -> bool:
"DELETE FROM group_admins WHERE group_id = ? AND user_id = ?",
(group_id, user_id),
)
conn.commit()
return cur.rowcount > 0
@@ -114,13 +125,14 @@ def is_group_creator(group_id: int, user_id: int) -> bool:
return get_group_creator_user_id(group_id) == user_id
def get_user_by_username(username: str) -> Optional[sqlite3.Row]:
def get_user_by_username(username: str) -> Optional[dict]:
"""Look up user by username (without @)."""
with get_conn() as conn:
return conn.execute(
row = conn.execute(
"SELECT * FROM users WHERE username = ?",
(username,),
).fetchone()
return _row_to_dict(row) if row else None
# ── Bounties ────────────────────────────────────────────────────────────────
@@ -143,33 +155,36 @@ def add_bounty(
RETURNING id""",
(group_id, created_by_user_id, informed_by_username, text, link, due_date_ts),
)
return cur.fetchone()["id"]
result = cur.fetchone()
conn.commit()
return result[0]
except sqlite3.IntegrityError as e:
if "UNIQUE" in str(e) and "link" in str(e):
raise ValueError(f"Link already exists in this group: {link}")
raise
def get_bounty(bounty_id: int) -> Optional[sqlite3.Row]:
def get_bounty(bounty_id: int) -> Optional[dict]:
with get_conn() as conn:
return conn.execute("SELECT * FROM bounties WHERE id = ?", (bounty_id,)).fetchone()
row = conn.execute("SELECT * FROM bounties WHERE id = ?", (bounty_id,)).fetchone()
return _row_to_dict(row) if row else None
def get_group_bounties(group_id: int) -> list[sqlite3.Row]:
def get_group_bounties(group_id: int) -> list[dict]:
with get_conn() as conn:
return list(conn.execute(
return [_row_to_dict(r) for r in conn.execute(
"SELECT * FROM bounties WHERE group_id = ? ORDER BY created_at DESC",
(group_id,),
))
)]
def get_user_personal_bounties(user_id: int) -> list[sqlite3.Row]:
def get_user_personal_bounties(user_id: int) -> list[dict]:
"""Bounties created by user in DM (group_id IS NULL)."""
with get_conn() as conn:
return list(conn.execute(
return [_row_to_dict(r) for r in conn.execute(
"SELECT * FROM bounties WHERE group_id IS NULL AND created_by_user_id = ? ORDER BY created_at DESC",
(user_id,),
))
)]
def update_bounty(
@@ -189,6 +204,7 @@ def update_bounty(
WHERE id = ?""",
(text, link, due_date_ts, bounty_id),
)
conn.commit()
return cur.rowcount > 0
except sqlite3.IntegrityError as e:
if "UNIQUE" in str(e) and "link" in str(e):
@@ -199,6 +215,7 @@ def update_bounty(
def delete_bounty(bounty_id: int) -> bool:
with get_conn() as conn:
cur = conn.execute("DELETE FROM bounties WHERE id = ?", (bounty_id,))
conn.commit()
return cur.rowcount > 0
@@ -212,6 +229,7 @@ def track_bounty(user_id: int, bounty_id: int) -> bool:
"INSERT INTO user_bounty_tracking (user_id, bounty_id) VALUES (?, ?)",
(user_id, bounty_id),
)
conn.commit()
return True
except sqlite3.IntegrityError:
return False
@@ -223,6 +241,7 @@ def untrack_bounty(user_id: int, bounty_id: int) -> bool:
"DELETE FROM user_bounty_tracking WHERE user_id = ? AND bounty_id = ?",
(user_id, bounty_id),
)
conn.commit()
return cur.rowcount > 0
@@ -235,37 +254,37 @@ def is_tracking(user_id: int, bounty_id: int) -> bool:
return row is not None
def get_user_tracked_bounties_in_group(user_id: int, group_id: int) -> list[sqlite3.Row]:
def get_user_tracked_bounties_in_group(user_id: int, group_id: int) -> list[dict]:
with get_conn() as conn:
return list(conn.execute(
return [_row_to_dict(r) for r in conn.execute(
"""SELECT b.* FROM bounties b
JOIN user_bounty_tracking t ON t.bounty_id = b.id
WHERE t.user_id = ? AND b.group_id = ?
ORDER BY b.created_at DESC""",
(user_id, group_id),
))
)]
def get_user_tracked_bounties_personal(user_id: int) -> list[sqlite3.Row]:
def get_user_tracked_bounties_personal(user_id: int) -> list[dict]:
"""Tracked bounties where group_id IS NULL (personal)."""
with get_conn() as conn:
return list(conn.execute(
return [_row_to_dict(r) for r in conn.execute(
"""SELECT b.* FROM bounties b
JOIN user_bounty_tracking t ON t.bounty_id = b.id
WHERE t.user_id = ? AND b.group_id IS NULL
ORDER BY b.created_at DESC""",
(user_id,),
))
)]
# ── Reminders ───────────────────────────────────────────────────────────────
def get_bounties_due_soon(user_id: int, days: int = 7) -> list[sqlite3.Row]:
def get_bounties_due_soon(user_id: int, days: int = 7) -> list[dict]:
"""Get tracked bounties with due_date within `days` that haven't been reminded yet."""
now = int(time.time())
deadline = now + days * 86400
with get_conn() as conn:
return list(conn.execute(
return [_row_to_dict(r) for r in conn.execute(
"""SELECT b.*, u.username, u.telegram_user_id FROM bounties b
JOIN user_bounty_tracking t ON t.bounty_id = b.id
JOIN users u ON u.id = b.created_by_user_id
@@ -278,7 +297,7 @@ def get_bounties_due_soon(user_id: int, days: int = 7) -> list[sqlite3.Row]:
)
ORDER BY b.due_date_ts ASC""",
(user_id, deadline, now, user_id),
))
)]
def log_reminder(user_id: int, bounty_id: int) -> None:
@@ -287,8 +306,9 @@ def log_reminder(user_id: int, bounty_id: int) -> None:
"INSERT OR IGNORE INTO reminder_log (user_id, bounty_id) VALUES (?, ?)",
(user_id, bounty_id),
)
conn.commit()
def get_all_user_ids() -> list[int]:
with get_conn() as conn:
return [row["telegram_user_id"] for row in conn.execute("SELECT telegram_user_id FROM users")]
return [row[0] for row in conn.execute("SELECT telegram_user_id FROM users")]