はじめに

Go 言語初心者です。いつか Go で CLI を作ってみたいと思っていたのですが、アイディアも形になってきたのでようやく手を動かしはじめました。

今回は AWS SDK for Go v2 を使って、複数のリージョンから EC2 インスタンスの情報を取ってくるだけのごく簡単なサンプルを作ってみます。goroutine や channel を使った並行処理の実装もしてみました。リポジトリはこちらです。

構成

以下のような簡単な構成です。名前を go-test とします。

go-test
├── aws.go
├── go-test
├── go.mod
├── go.sum
├── main.go
└── util.go

リージョンの取得

複数リージョンに対して API を叩きますが、今回は DescribeRegions API を素で叩いたときの戻り値を for 文で回します。以下ですね。

$ aws ec2 describe-regions --query 'Regions[].RegionName'
[
    "ap-south-1",
    "eu-north-1",
    "eu-west-3",
    "eu-west-2",
    "eu-west-1",
    "ap-northeast-3",
    "ap-northeast-2",
    "ap-northeast-1",
    "ca-central-1",
    "sa-east-1",
    "ap-southeast-1",
    "ap-southeast-2",
    "eu-central-1",
    "us-east-1",
    "us-east-2",
    "us-west-1",
    "us-west-2"
]

Go で書く場合はこんな感じだと思います。

// aws.go

package main

import (
    "context"
    "fmt"

    "github.com/aws/aws-sdk-go-v2/aws"
    "github.com/aws/aws-sdk-go-v2/config"
    "github.com/aws/aws-sdk-go-v2/service/ec2"
)
...
func getAwsRegion(ctx context.Context, cfg *aws.Config) ([]string, error) {
    client := ec2.NewFromConfig(*cfg)

    obj, err := client.DescribeRegions(ctx, &ec2.DescribeRegionsInput{}, retryOpt)
    if err != nil {
        return nil, err
    }

    var res []string
    for _, r := range obj.Regions {
        fmt.Println(aws.ToString(r.RegionName))
        res = append(res, aws.ToString(r.RegionName))
    }

    return res, nil
}
...

動作を確認します。API の向き先となるリージョンなどの情報を config.LoadDefaultConfig() に渡してロードし、getAwsRegion() に渡します。

// main.go

package main

import (
    "context"

    "github.com/aws/aws-sdk-go-v2/config"
)

const (
    name   = "go-test"
    region = "ap-northeast-1"
)

func main() {
    ctx := context.Background()

    cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region))
    if err != nil {
        log.Fatal(err)
    }

    getAwsRegion(ctx, &cfg)
}

正しく実行されました。

$ go run .
ap-south-1
eu-north-1
eu-west-3
eu-west-2
eu-west-1
ap-northeast-3
ap-northeast-2
ap-northeast-1
ca-central-1
sa-east-1
ap-southeast-1
ap-southeast-2
eu-central-1
us-east-1
us-east-2
us-west-1
us-west-2

同期処理

次に、DescribeInstances API を実行して複数リージョンの EC2 インスタンス情報を同期的に取ってくる処理です。

// aws.go

type instanceInfo struct {
    Name             string
    InstanceId       string
    PrivateIpAddress string
    PublicIpAddress  string
    AvailabilityZone string
    State            types.InstanceStateName
}

var state = []string{
    "pending",
    "running",
    "stopping",
    "stopped",
}

func retryOpt(opt *ec2.Options) {
    opt.RetryMaxAttempts = 3
    opt.RetryMode = aws.RetryModeStandard
}

...

func getNameTagValue(tags []types.Tag) string {
    for _, t := range tags {
        if *t.Key == "Name" {
            return *t.Value
        }
    }
    return ""
}

func getAwsInstanceSync(ctx context.Context, cfg *aws.Config, filter ...types.Filter) ([]instanceInfo, error) {
    regions, err := getAwsRegion(ctx, cfg)
    if err != nil {
        return nil, err
    }

    f := []types.Filter{
        {
            Name:   aws.String("instance-state-name"),
            Values: state,
        },
    }

    if filter != nil {
        f = append(f, filter...)
    }

    var res []instanceInfo

    for _, region := range regions {
        ccfg := cfg.Copy()
        ccfg.Region = region

        client := ec2.NewFromConfig(ccfg)

        obj, err := client.DescribeInstances(ctx, &ec2.DescribeInstancesInput{Filters: f}, retryOpt)
        if err != nil {
            return nil, err
        }

        for _, r := range obj.Reservations {
            for _, i := range r.Instances {
                out := instanceInfo{
                    getNameTagValue(i.Tags),
                    aws.ToString(i.InstanceId),
                    aws.ToString(i.PrivateIpAddress),
                    aws.ToString(i.PublicIpAddress),
                    aws.ToString(i.Placement.AvailabilityZone),
                    i.State.Name,
                }
                res = append(res, out)
            }
        }
    }

    return res, nil
}

さきほどの getAwsRegion() でリージョンを列挙し、for 文で回すたびにクライアントを初期化して DescribeInstances API を叩きます。
また、可変長引数で複数の types.Filter を受け取れるようにしました。これにより、以下のような感じで 0 個以上の types.Filter を渡すことができます。

...
out, err := getAwsInstanceSync(ctx, &cfg,
    types.Filter{
        Name:   aws.String("tag:Name"),
        Values: []string{"*stg*", "*staging*"},
    },
    types.Filter{
        Name:   aws.String("instance-type"),
        Values: []string{"t2.micro"},
    },
)
...

デフォルトで以下の types.Filter を適用し、shutting-downterminated のインスタンスは無視するようにしています。

var state = []string{
    "pending",
    "running",
    "stopping",
    "stopped",
}

...

func getAwsInstanceSync(ctx context.Context, cfg *aws.Config, filter ...types.Filter) ([]instanceInfo, error) {
    ...
    f := []types.Filter{
        {
            Name:   aws.String("instance-state-name"),
            Values: state,
        },
    }

    if filter != nil {
        f = append(f, filter...)
    }
    ...
}

また、以下のユーティリティ関数を用意しています。

  • printError: エラー出力のラッパー
  • printJson: 戻り値のスライスを JSON に変換
  • benchmark: 処理時間の計測
// util.go

package main

import (
    "encoding/json"
    "fmt"
    "os"
    "runtime"
    "time"

    "github.com/fatih/color"
)

func printError(err error) {
    if name == "" {
        fmt.Println(color.RedString("%v\n", err))
        return
    }
    fmt.Println(color.RedString("[%v] %v\n", name, err))
}

func printJson(obj any) {
    b, err := json.MarshalIndent(obj, "", "  ")
    if err != nil {
        printError(err)
    }
    os.Stdout.Write(b)
    fmt.Println("")
}

func benchmark(f func()) {
    baseTime := time.Now()

    f()

    fmt.Println("")
    fmt.Println(color.GreenString("CpuCoreNumber: %v", runtime.NumCPU()))
    fmt.Println(color.GreenString("GoroutineNumber: %v", runtime.NumGoroutine()))
    fmt.Println(color.GreenString("ElapsedTime: %v", time.Since(baseTime)))
    fmt.Println("")
}

Filter を指定せずに実行してみます。

// main.go

package main

import (
    "context"
    "os"

    "github.com/aws/aws-sdk-go-v2/config"
)

const (
    name   = "go-lab"
    region = "ap-northeast-1"
)

func main() {
    ctx := context.Background()

    cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region))
    if err != nil {
        printError(err)
        os.Exit(1)
    }

    benchmark(func() {
        out, err := getAwsInstanceSync(ctx, &cfg)
        if err != nil {
            printError(err)
            os.Exit(1)
        }
        printJson(out)
    })
}

以下の通り 13 秒かかっています。環境に依存しますが、検証時は 2 リージョンで合計 17 台が存在している状態でした。

$ go run .
[
  {
    "Name": "dummy-01",
    "InstanceId": "i-xxxxxxxxxxxxxxxxx",
    "PrivateIpAddress": "172.X.X.X",
    "PublicIpAddress": "",
    "AvailabilityZone": "ap-northeast-1a",
    "State": "stopped"
  },
  {
    "Name": "dummy-02",
    "InstanceId": "i-yyyyyyyyyyyyyyyyy",
    "PrivateIpAddress": "10.X.X.X",
    "PublicIpAddress": "",
    "AvailabilityZone": "ap-northeast-1d",
    "State": "stopped"
  },
  ...
]

CpuCoreNumber: 4
GoroutineNumber: 37
ElapsedTime: 12.603867384s

並行処理

次に並行処理を実装してみます。基本はそのままに、goroutine と channel を使います。やってみて、エラーハンドリングが難しく感じました。

// aws.go

func getAwsInstanceAsync(ctx context.Context, cfg *aws.Config, filter ...types.Filter) ([]instanceInfo, error) {
    regions, err := getAwsRegion(ctx, cfg)
    if err != nil {
        return nil, err
    }

    f := []types.Filter{
        {
            Name:   aws.String("instance-state-name"),
            Values: state,
        },
    }

    if filter != nil {
        f = append(f, filter...)
    }

    ich := make(chan instanceInfo)
    ech := make(chan error)
    var wg sync.WaitGroup

    ctx, cancel := context.WithCancel(ctx)
    defer cancel()

    for _, region := range regions {
        wg.Add(1)

        go func(region string) {
            defer wg.Done()

            ccfg := cfg.Copy()
            ccfg.Region = region

            client := ec2.NewFromConfig(ccfg)

            obj, err := client.DescribeInstances(ctx, &ec2.DescribeInstancesInput{Filters: f}, retryOpt)
            if err != nil {
                ech <- err
                return
            }

            for _, r := range obj.Reservations {
                for _, i := range r.Instances {
                    out := instanceInfo{
                        getNameTagValue(i.Tags),
                        aws.ToString(i.InstanceId),
                        aws.ToString(i.PrivateIpAddress),
                        aws.ToString(i.PublicIpAddress),
                        aws.ToString(i.Placement.AvailabilityZone),
                        i.State.Name,
                    }
                    ich <- out
                }
            }
        }(region)
    }

    go func() {
        wg.Wait()
        close(ich)
        close(ech)
    }()

    var res []instanceInfo

    for {
        select {
        case i, ok := <-ich:
            if ok {
                res = append(res, i)
            } else {
                ich = nil
            }
        case err, ok := <-ech:
            if ok {
                return nil, err
            } else {
                ech = nil
            }
        }
        if ich == nil && ech == nil {
            break
        }
    }

    return res, nil
}

ゴルーチン間でインスタンス情報とエラーを伝達するための channel を作成します。

ich := make(chan instanceInfo)
ech := make(chan error)

ゴルーチンがすべて終了したことを検知するための WaitGroup を作成します。

var wg sync.WaitGroup

キャンセル可能な context を作成します。のちほど DescribeInstances する際に引数として渡します。

ctx, cancel := context.WithCancel(ctx)
defer cancel()

リージョンごとにゴルーチンを起動して以下の処理を行います。

  • リージョンを読み込んで EC2 クライアントを作成し、EC2 インスタンス情報を取得
  • エラーが発生したらエラーチャネルを通じてメインゴルーチンに送信して処理を終了
  • インスタンス情報を整形し、それ用のチャネルを通じてメインゴルーチンに送信

すべてのゴルーチンの終了を待つために sync.WaitGroup を使います。

  • Add()WaitGroup をインクリメント
  • Done() でデクリメント
  • Wait ですべてのゴルーチンが終了するまで (0 になるまで) ブロック
for _, region := range regions {
    wg.Add(1)

    go func(region string) {
        defer wg.Done()

        ccfg := cfg.Copy()
        ccfg.Region = region

        client := ec2.NewFromConfig(ccfg)

        obj, err := client.DescribeInstances(ctx, &ec2.DescribeInstancesInput{Filters: f}, retryOpt)
        if err != nil {
            ech <- err
            return
        }

        for _, r := range obj.Reservations {
            for _, i := range r.Instances {
                out := instanceInfo{
                    getNameTagValue(i.Tags),
                    aws.ToString(i.InstanceId),
                    aws.ToString(i.PrivateIpAddress),
                    aws.ToString(i.PublicIpAddress),
                    aws.ToString(i.Placement.AvailabilityZone),
                    i.State.Name,
                }
                ich <- out
            }
        }
    }(region)
}

すべてのゴルーチンが終了したときにチャネルを閉じるための別のゴルーチンを起動します。

go func() {
    wg.Wait()
    close(ich)
    close(ech)
}()

チャネルからインスタンス情報、またはエラーを受信します。select 文で複数のチャネルを同時に待ち受けます。

  • エラーを受信した場合: その時点で関数を終了
  • インスタンス情報を受信した場合: インスタンス情報を結果のスライスに追加
  • 両方のチャネルが閉じられた(nil になった)場合、ループは終了
var res []instanceInfo

for {
    select {
    case i, ok := <-ich:
        if ok {
            res = append(res, i)
        } else {
            ich = nil
        }
    case err, ok := <-ech:
        if ok {
            return nil, err
        } else {
            ech = nil
        }
    }
    if ich == nil && ech == nil {
        break
    }
}

実行します。結果がどう変わるか見てみましょう。

// main.go

package main

import (
    "context"
    "os"

    "github.com/aws/aws-sdk-go-v2/config"
)

const (
    name   = "go-lab"
    region = "ap-northeast-1"
)

func main() {
    ctx := context.Background()

    cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region))
    if err != nil {
        printError(err)
        os.Exit(1)
    }

    benchmark(func() {
        out, err := getAwsInstanceAsync(ctx, &cfg)
        if err != nil {
            printError(err)
            os.Exit(1)
        }
        printJson(out)
    })
}

2 秒かからずに完了しました。並行処理の威力を身をもって体感することができました (もちろん環境差はあります)

$ go run .
[
  {
    "Name": "dummy-01",
    "InstanceId": "i-xxxxxxxxxxxxxxxxx",
    "PrivateIpAddress": "172.X.X.X",
    "PublicIpAddress": "",
    "AvailabilityZone": "ap-northeast-1a",
    "State": "stopped"
  },
  {
    "Name": "dummy-02",
    "InstanceId": "i-yyyyyyyyyyyyyyyyy",
    "PrivateIpAddress": "10.X.X.X",
    "PublicIpAddress": "",
    "AvailabilityZone": "ap-northeast-1d",
    "State": "stopped"
  },
  ...
]

CpuCoreNumber: 4
GoroutineNumber: 37
ElapsedTime: 1.372817432s

ウェイトやリトライについて

並行処理で外部の API をコールする場合などは、提供元の制限に対しては提供元に応じた適切な対処が必要です。つまり提供元の制限に応じて、ウェイトの挿入やリトライアルゴリズムの実装などを検討する必要があります。
AWS SDK の場合は Exponential Backoff and Jitter が実装されており、リトライのためのオプションを API に無名関数として渡すことができます。以下の部分ですね。

...
retryOpt := func(opt *ec2.Options) {
    opt.RetryMaxAttempts = 3
    opt.RetryMode = aws.RetryModeStandard
}
...
obj, err := client.DescribeInstances(ctx, &ec2.DescribeInstancesInput{Filters: f}, retryOpt)

なおリトライのためのオプションに関しては、自前のリトライアルゴリズムを渡すこともできます。このあたりのノウハウについてはこちらの記事が詳しいです (ありがとうございます)

errgroup の使用

並行処理のエラーハンドリングが難しいと感じましたが、errgroup というライブラリを使えば上述のような「エラー発生時にすべてのゴルーチンをキャンセルする」といった動作をもっと簡単に書けるようです。エラーチャネルをハンドリングする必要がなくなり、記述が簡単になりました。

// aws.go

func getAwsInstanceAsync2(ctx context.Context, cfg *aws.Config, filter ...types.Filter) ([]instanceInfo, error) {
    regions, err := getAwsRegion(ctx, cfg)
    if err != nil {
        return nil, err
    }

    f := []types.Filter{
        {
            Name:   aws.String("instance-state-name"),
            Values: state,
        },
    }

    if filter != nil {
        f = append(f, filter...)
    }

    var mu sync.Mutex
    var res []instanceInfo

    eg, ctx := errgroup.WithContext(ctx)

    for _, region := range regions {
        region := region

        eg.Go(func() error {
            ccfg := cfg.Copy()
            ccfg.Region = region

            client := ec2.NewFromConfig(ccfg)

            obj, err := client.DescribeInstances(ctx, &ec2.DescribeInstancesInput{Filters: f}, retryOpt)
            if err != nil {
                return err
            }

            for _, r := range obj.Reservations {
                for _, i := range r.Instances {
                    out := instanceInfo{
                        getNameTagValue(i.Tags),
                        aws.ToString(i.InstanceId),
                        aws.ToString(i.PrivateIpAddress),
                        aws.ToString(i.PublicIpAddress),
                        aws.ToString(i.Placement.AvailabilityZone),
                        i.State.Name,
                    }

                    mu.Lock()
                    res = append(res, out)
                    mu.Unlock()
                }
            }

            return nil
        })
    }

    if err := eg.Wait(); err != nil {
        return nil, err
    }

    return res, nil
}

まず errgroup を作ります。

eg, ctx := errgroup.WithContext(ctx)

このケースでは channel を使わずに sync.Mutex を使い、共有のスライスに対して排他制御を行っています。

var mu sync.Mutex

...

for _, r := range obj.Reservations {
    for _, i := range r.Instances {
        out := instanceInfo{
            getNameTagValue(i.Tags),
            aws.ToString(i.InstanceId),
            aws.ToString(i.PrivateIpAddress),
            aws.ToString(i.PublicIpAddress),
            aws.ToString(i.Placement.AvailabilityZone),
            i.State.Name,
        }

        mu.Lock()
        res = append(res, out)
        mu.Unlock()
    }
}

eg.Wait() ですべてのゴルーチンを待ちます。

if err := eg.Wait(); err != nil {
    return nil, err
}

おわりに

AWS SDK for Go v2 で Go に入門してみました。どんどん書いてもっと慣れていきたいと思います。