やりたいこと

担当案件でRDS(MYSQL)への接続が必要となったが、用途や接続先により接続方法が異なったため、備忘録として残しておく。


前提

  • SQL Alchemy や Alembic に関する説明やコマンド等は割愛
  • 前者がORM、後者はマイグレーションツール
  • プロジェクトにはAlembicで使用可能な Model は用意されているものとする

環境

  • Python: 3.11
  • RDS: Aurora MySQL 3.03.1 (compatible with MySQL 8.0.26)
  • SQL Alchemy: 2.0.16
  • ALembic: 1.11.1
  • sshtunnel: 0.4.0

ケース1: ローカル環境からの接続

開発環境としてローカルからAWSのRDSへアクセスする。
APIの実行およびDBのマイグレーションのいずれも行う。

RDSはpublicからのアクセスを許容している状態とする。

↓ 接続用コード

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.ext.declarative import declarative_base

dialect = "mysql"
driver = "pymysql"
username = DBのユーザー名
password = DBのパスワード
host = DBのエンドポイント
port = DBのポート番号
database = データベース名
charset_type = "utf8mb4"

# basic認証
def get_connection():
    db_url = f"{dialect}+{driver}://{username}:{password}@{host}:{port}/{database}?charset={charset_type}"

    engine = create_engine(db_url, echo=True, pool_pre_ping=True)

    self.session_local = scoped_session(
        sessionmaker(autocommit=False, autoflush=False, bind=engine)
    )

     return (declarative_base(), None)

def get_session(self):
        return session_local()

↓ Alembic.ini

sqlalchemy.url = mysql+pymysql://DBのユーザー名:DBのパスワード@DBのエンドポイント:DBのポート番号/データベース名?charset=utf8mb4

ここは特に難しいこともなく一般的な接続でOK.


ケース2: AWS(Lambda)からの接続

AWS 上にデプロイされた Lambda から RDS へのアクセス方法。
認証情報は SecretsManager に保存しているため、ソース上からは取り除いている。

このケースはAPIの実行による接続のみで、マイグレーションは行わいので、Alembic.iniは不要。

↓ 接続用コード

import os
import json
import boto3

from app.core.config import get_app_settings
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import event

from typing import Dict
from botocore.exceptions import ClientError

# AWS認証
def get_connection():
    secret_name = Secrets Managerのキー名
    region_name = AWSのリージョン

    session = boto3.session.Session()
    client = session.client(service_name="secretsmanager", region_name=region_name)

    try:
        get_secret_value_response = client.get_secret_value(SecretId=secret_name)
    except ClientError as e:
        raise e

    secret = get_secret_value_response["SecretString"]
    values: Dict = json.loads(secret)

    # 認証情報は Secrets Manager から取得する
    dialect = "mysql"
    driver = "pymysql"
    username = values[Secrets Managerから取得した設定一覧にあるDBのユーザー名]
    host = values[Secrets Managerから取得した設定一覧にあるDBのエンドポイント]
    port = values[Secrets Managerから取得した設定一覧にあるDBのポート番号]
    database = DBのデータベース名
    charset_type = "utf8mb4"

    db_url = f"{dialect}+{driver}://{username}@{host}:{port}/{database}?charset={charset_type}"

    engine = create_engine(db_url, echo=True, pool_pre_ping=True)

    # TLS用にpemファイルを認証に使用
    @event.listens_for(engine, "do_connect")
    def provide_token(dialect, conn_rec, cargs, cparams):
        rds_client = boto3.client("rds")
        cparams["ssl"] = {"ca": os.getcwd() + "/ssh/AmazonRootCA1.pem"}
        cparams["password"] = rds_client.generate_db_auth_token(
            DBHostname=host, Port=port, DBUsername=username, Region=region_name
        )

    session_local = scoped_session(
        sessionmaker(autocommit=False, autoflush=False, bind=engine)
    )

     return (declarative_base(), None)

def get_session(self):
        return session_local()

ポイントは def provide_token() のあたり。

LambdaからRDSへのアクセスはSSL通信を行う必要がある。
provide_token を追加することで、証明書の検証フローが追加される。

pemファイルは下記のURLからダウンロード可能なので、プロジェクトに組み込んでおくとよい。

↓ 参考
https://stackoverflow.com/questions/69447368/connecting-to-mysql-using-a-token
https://docs.aws.amazon.com/ja_jp/dms/latest/userguide/CHAP_Security.SSL.html

↓ 中間 CA 証明書
https://s3.amazonaws.com/rds-downloads/rds-combined-ca-bundle.pem


ケース3: EC2を踏み台とした接続

実環境のRDSがAWSのPrivate Subnetにいるため、外部からは直接アクセス出来ないのはよくある話。
これではローカルからAlembicを使ったマイグレーションが出来ず、DBの構築・設定変更が出来ない。

回避策として EC2 を立ち上げ、踏み台サーバーとして利用する。

各種DBのクライアントツール(MYSQL Workbenchとか)にもある設定。
これをソースコード上でやってみる。

RDSはEC2からのアクセスをセキュリティグループのインバウンドにて許容する。
外部からはSSH経由でEC2にアクセスし、EC2を踏み台としてRDSを操作する。
(SSHはMUSTではないけど、ついでに)

↓ 接続用コード

import os
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.ext.declarative import declarative_base
from sshtunnel import SSHTunnelForwarder

SSH_PKEY = EC2にSSH接続する際の秘密鍵
SSH_ADDRESS = EC2のIPアドレス
SSH_PORT = 22
SSH_USERNAME = EC2のユーザー名
SSH_REMOTE_BIND_ADDRESS = DBのエンドポイント

LOCAL_HOST = "127.0.0.1"
LOCAL_PORT = ローカルPC側のポート(1~65535で任意。well knownポートがあるので実質1024~)

DB_USERNAME = DBのユーザー名
DB_PASSWORD = DBのパスワード
DB_PORT = DBのポート番号
DATABASE_NAME = データベース名


##############################################################
# 実行前にEC2に登録した鍵(id_rsa) を sshフォルダに置く
##############################################################

# SSH Tunnel認証
def get_connection():
    server = SSHTunnelForwarder(
        (SSH_ADDRESS, SSH_PORT),
        ssh_username=SSH_USERNAME,
        ssh_pkey=SSH_PKEY,
        local_bind_address=(LOCAL_HOST, LOCAL_PORT),
        remote_bind_address=(SSH_REMOTE_BIND_ADDRESS, DB_PORT),
        allow_agent=False,
    )

    server.start()

    db_url = f"mysql+pymysql://{DB_USERNAME}:{DB_PASSWORD}@{LOCAL_HOST}:{LOCAL_PORT}/{DATABASE_NAME}"

    engine = create_engine(db_url, echo=True)

    session_local = scoped_session(
        sessionmaker(autocommit=False, autoflush=False, bind=engine)
    )

    return (declarative_base(), server)

def get_session(self):
        return session_local()

↓ Alembic.ini

sqlalchemy.url = mysql+pymysql://DB_USERNAME:DB_PASSWORD@127.0.0.1:LOCAL_PORT/DATABASE_NAME?charset=utf8mb4

create_engine() の前に踏み台サーバーアクセスのためのひと手間加えるだけでよい。

SQLAlchemy のみであれば LOCAL_PORT は指定せず、 SSHTunnelForwarder に任せることも可能だが、
Alembicで指定する必要があるため、固定値にする必要がある。
あとは EC2 / RDS の設定値をのそのままセットすればよい。

EC2へはSSHでアクセスするため、事前に鍵ペアの生成と公開鍵のEC2への登録が必要となる。
SSH_PKEY は秘密鍵(id_rsa)のパスを指定する。

今回の案件では鍵の有効期間が1分なので、鍵の登録~実行がわちゃわちゃする。
(ここを自動化しておけばよかった)


おまけ: ケースごとの切り替え方法

ローカルか実環境か、などを実行時に意識しなくてよいよう、環境設定に連動するようにする。

最近すっかり聞かなくなったデザインパターンからFactory Methodをちょっとだけ流用。
抽象クラスを用意して、上記の3パターンでそれぞれ抽象クラスを継承したクラスを実装。

DBを操作したいレイヤー用に下記クラスを用意し、外側からは get_db() だけ呼べばあとはよろしくやってくれるようにしておく。

DB = None
SessionLocal = None
Server = None


# DB種別取得
TYPE = 環境設定.db_type

# 種別ごとにインスタンスを切り替える
if TYPE == ケース1(ローカル実行):
    DB = DbAws()
elif TYPE == ケース2(AWS実行):
    DB = DbBasic()
elif TYPE == ケース3(踏み台アクセス):
    DB = DbSshTunnel()
else:
    raise Exception

Base, Server = DB.get_connection()
SessionLocal = DB.get_session()


def get_db():
    try:
        yield SessionLocal
    finally:
        SessionLocal.close()
        if Server is not None:
            Server.close()

注意点としては、sshtunnelは2023/12時点では python 3.11 に対応していないので、コメントアウトしておく必要がある。。。


最後に

用途や接続先によってやり方が微妙に異なるのはよくある話。
こういうのを全て網羅したpip のモジュールがあってもよさそうだけど、ない。

RDSにアクセスする以上ついて回る話なので、今後に活かしていく。