記事の目的

SageMakerの分散トレーニング、デプロイオーケストレーションシステムを活用するための

自作pytorchプログラムの改修方法理解

対象者

基礎的な機械学習知識を所有しており、

SageMakerの分散トレーニング、デプロイオーケストレーションシステムを活用したい人

この記事を読み終わるまでの時間

10m

基本的な処理の流れ

事前セットアップ

import sagemaker

sagemaker_session = sagemaker.Session()

bucket = sagemaker_session.default_bucket()
prefix = 'sagemaker/DEMO-pytorch-mnist'

role = sagemaker.get_execution_role()

使用、訓練データ取得

from torchvision import datasets, transforms

datasets.MNIST('data', download=True, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
]))

s3へのデータアップロード

inputs = sagemaker_session.upload_data(path='data', bucket=bucket, key_prefix=prefix)
print('input spec (in this case, just an S3 path): {}'.format(inputs))

sageMakerのエコシステムを用いて訓練、デプロイするため、sagemakerに内包されたPytorchクラスを生成

from sagemaker.pytorch import PyTorch

estimator = PyTorch(entry_point='mnist.py',
                    role=role,
                    framework_version='1.1.0',
                    train_instance_count=2,
                    train_instance_type='ml.c4.xlarge',
                    hyperparameters={
                        'epochs': 6,
                        'backend': 'gloo'
                    })

訓練の実施

estimator.fit({'training': inputs})

本題:SageMakerエコシステムに適合するためのmnist.pyの書き方

ポイント

  • input_fn、predict_fn、model_fn、output_fnの4兄弟を作成しない場合は以下のdefualt_**関数4種類が呼び出される。
  • つまりmodel_fnに関しては、実装がないとエラーとなる。
  • 他のものについては、default_**の処理で問題なければ実装の必要はない。

default関数①

# https://github.com/aws/sagemaker-scikit-learn-container/blob/master/ src/sagemaker_sklearn_container/handler_serving.py より抜粋
def default_input_fn(input_data, content_type):
    np_array = encoders.decode(input_data, content_type)
    return np_array.astype(np.float32) if content_type in content_types.UTF8_TYPES else np_array# https://github.com/aws/sagemaker-inference-toolkit/blob/master/ src/sagemaker_inference/decoder.pyより抜粋

def decode(obj, content_type):
    try:
    decoder = _decoder_map[content_type] return decoder(obj)
        except KeyError:
            raise errors.UnsupportedFormatError(content_type)
    _decoder_map = {
    content_types.NPY: _npy_to_numpy, content_types.CSV: _csv_to_numpy, content_types.JSON: _json_to_numpy, content_types.NPZ: _npz_to_sparse,
    }
    def _json_to_numpy(string_like, dtype=None): # type: (str) -> np.arrayapplication/json の場合 リクエストデータを 32bit floatの numpy arrayとして解釈する
    data = json.loads(string_like)
    return np.array(data, dtype=dtype)

default関数②

# https://github.com/aws/sagemaker-scikit-learn-container/blob/master/ src/sagemaker_sklearn_container/handler_serving.py より抜粋
def default_output_fn(prediction, accept):
    return encoders.encode(prediction, accept), accept# https://github.com/aws/sagemaker-inference-toolkit/blob/master/ src/sagemaker_inference/encoder.pyより抜粋

def encode(array_like, content_type):
    try:
        encoder = _encoder_map[content_type]
        return encoder(array_like)
    except KeyError:
        raise errors.UnsupportedFormatError(content_type)
_encoder_map = {
content_types.NPY: _array_to_npy, content_types.CSV: _array_to_csv, content_types.JSON: _array_to_json,
        }

def _array_to_json(array_like):
    def default(_array_like):
        if hasattr(_array_like, "tolist"):
                    return _array_like.tolist()
        return json.JSONEncoder().default(_array_like)

default関数③

def default_model_fn(model_dir):
    """Loads a model. For Scikit-learn, a default function to load a model is not provided. Users should provide customized model_fn() in script.
    Args:
               model_dir: a directory where model is saved.Returns: A Scikit-learn model.
    """
    raise NotImplementedError(textwrap.dedent("""
    Please provide a model_fn implementation.
    See documentation for model_fn at https://github.com/aws/sagemaker-python-sdk """))

default関数④

# https://github.com/aws/sagemaker-scikit-learn-container/blob/master/ src/sagemaker_sklearn_container/handler_service.py より抜粋
def default_predict_fn(input_data, model):
"""A default predict_fn for Scikit-learn. Calls a model on data deserialized in input_fn. Args:
          input_data: input data (Numpy array) for prediction deserialized by input_fn
model: Scikit-learn model loaded in memory by model_fn Returns: a prediction
"""
    output = model.predict(input_data)
    return output

boto3からのリクエスト例(invoke_endpointを実施)

response = smr_client.invoke_endpoint( 
    EndpointName=endpoint_name, 
    ContentType='text/csv', 
    Accept='text/csv', 
    Body='1,2,3,10000'
)
predictions = response['Body'].read().decode('utf-8') print(predictions)

上記のポイント

  • ContentTypeをinput_fnで使用
  • Acceptをoutput_fnで使用
  • 実際にmodelに渡されるのはBody
  • 処理の流れは以下
    • (エンドポイント立ち上げ時)model_fn
    • (API実行時)input_fn→predict_fn→output_fn

fnメソッド4種類ののカスタム例(上記API呼び出しを行う想定での改修)

# モデル読み込み
def model_fn(model_dir):
  with open(os.path.join(model_dir,'my_model.txt')) as f:
      hello = f.read()[:-1]
  return hello

# 前処理
def input_fn(input_data, content_type): 
    if content_type == 'text/csv':
        transformed_data = input_data.split(',')
  else:
      raise ValueError("Illegal content type")
  return transformed_data

# 予測
def predict_fn(transformed_data, model): prediction_list = []
    for data in transformed_data:
        if data[-1] == '1':
          ordinal = f'{data}st'
        elif data[-1] == '2':
          ordinal = f'{data}nd'
                elif data[-1] == '3': 
                    ordinal = f'{data}rd'
        else:
            ordinal = f'{data}th'
        prediction = f'{model} for the {ordinal} time'
        prediction_list.append(prediction)
    return prediction_list

# 後処理
def output_fn(prediction_list, accept):
    if accept == 'text/csv':
        response = ''
        for prediction in prediction_list:
            response += prediction + '¥n'
    else:
            raise ValueError("Illegal accept type")
            return response, accept

上記処理の際の想定output:

Hello my great machine learning model for the 1st time 
Hello my great machine learning model for the 2nd time 
Hello my great machine learning model for the 3rd time 
Hello my great machine learning model for the 10000th time

参考記事

default_output_fn参考元:

https://github.com/aws/sagemaker-scikit-learn-container/blob/master/src/sagemaker_sklearn_container/handler_service.py

メインロジック参考元:

https://github.com/aws/amazon-sagemaker-examples/blob/b01821341caf3d6af351e852b0dd9955db0e4515/sagemaker-python-sdk/pytorch_mnist/pytorch_mnist.ipynb

プログラムコード参考元:

https://d1.awsstatic.com/webinars/jp/pdf/services/202208_AWS_Black_Belt_AWS_AIML_Dark_04_inference_part2.pdf