記事の目的
SageMakerの分散トレーニング、デプロイオーケストレーションシステムを活用するための
自作pytorchプログラムの改修方法理解
対象者
基礎的な機械学習知識を所有しており、
SageMakerの分散トレーニング、デプロイオーケストレーションシステムを活用したい人
この記事を読み終わるまでの時間
10m
基本的な処理の流れ
事前セットアップ
import sagemaker sagemaker_session = sagemaker.Session() bucket = sagemaker_session.default_bucket() prefix = 'sagemaker/DEMO-pytorch-mnist' role = sagemaker.get_execution_role()
使用、訓練データ取得
from torchvision import datasets, transforms datasets.MNIST('data', download=True, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ]))
s3へのデータアップロード
inputs = sagemaker_session.upload_data(path='data', bucket=bucket, key_prefix=prefix) print('input spec (in this case, just an S3 path): {}'.format(inputs))
sageMakerのエコシステムを用いて訓練、デプロイするため、sagemakerに内包されたPytorchクラスを生成
from sagemaker.pytorch import PyTorch estimator = PyTorch(entry_point='mnist.py', role=role, framework_version='1.1.0', train_instance_count=2, train_instance_type='ml.c4.xlarge', hyperparameters={ 'epochs': 6, 'backend': 'gloo' })
訓練の実施
estimator.fit({'training': inputs})
本題:SageMakerエコシステムに適合するためのmnist.pyの書き方
ポイント
- input_fn、predict_fn、model_fn、output_fnの4兄弟を作成しない場合は以下のdefualt_**関数4種類が呼び出される。
- つまりmodel_fnに関しては、実装がないとエラーとなる。
- 他のものについては、default_**の処理で問題なければ実装の必要はない。
default関数①
# https://github.com/aws/sagemaker-scikit-learn-container/blob/master/ src/sagemaker_sklearn_container/handler_serving.py より抜粋 def default_input_fn(input_data, content_type): np_array = encoders.decode(input_data, content_type) return np_array.astype(np.float32) if content_type in content_types.UTF8_TYPES else np_array# https://github.com/aws/sagemaker-inference-toolkit/blob/master/ src/sagemaker_inference/decoder.pyより抜粋 def decode(obj, content_type): try: decoder = _decoder_map[content_type] return decoder(obj) except KeyError: raise errors.UnsupportedFormatError(content_type) _decoder_map = { content_types.NPY: _npy_to_numpy, content_types.CSV: _csv_to_numpy, content_types.JSON: _json_to_numpy, content_types.NPZ: _npz_to_sparse, } def _json_to_numpy(string_like, dtype=None): # type: (str) -> np.arrayapplication/json の場合 リクエストデータを 32bit floatの numpy arrayとして解釈する data = json.loads(string_like) return np.array(data, dtype=dtype)
default関数②
# https://github.com/aws/sagemaker-scikit-learn-container/blob/master/ src/sagemaker_sklearn_container/handler_serving.py より抜粋 def default_output_fn(prediction, accept): return encoders.encode(prediction, accept), accept# https://github.com/aws/sagemaker-inference-toolkit/blob/master/ src/sagemaker_inference/encoder.pyより抜粋 def encode(array_like, content_type): try: encoder = _encoder_map[content_type] return encoder(array_like) except KeyError: raise errors.UnsupportedFormatError(content_type) _encoder_map = { content_types.NPY: _array_to_npy, content_types.CSV: _array_to_csv, content_types.JSON: _array_to_json, } def _array_to_json(array_like): def default(_array_like): if hasattr(_array_like, "tolist"): return _array_like.tolist() return json.JSONEncoder().default(_array_like)
default関数③
def default_model_fn(model_dir): """Loads a model. For Scikit-learn, a default function to load a model is not provided. Users should provide customized model_fn() in script. Args: model_dir: a directory where model is saved.Returns: A Scikit-learn model. """ raise NotImplementedError(textwrap.dedent(""" Please provide a model_fn implementation. See documentation for model_fn at https://github.com/aws/sagemaker-python-sdk """))
default関数④
# https://github.com/aws/sagemaker-scikit-learn-container/blob/master/ src/sagemaker_sklearn_container/handler_service.py より抜粋 def default_predict_fn(input_data, model): """A default predict_fn for Scikit-learn. Calls a model on data deserialized in input_fn. Args: input_data: input data (Numpy array) for prediction deserialized by input_fn model: Scikit-learn model loaded in memory by model_fn Returns: a prediction """ output = model.predict(input_data) return output
boto3からのリクエスト例(invoke_endpointを実施)
response = smr_client.invoke_endpoint( EndpointName=endpoint_name, ContentType='text/csv', Accept='text/csv', Body='1,2,3,10000' ) predictions = response['Body'].read().decode('utf-8') print(predictions)
上記のポイント
- ContentTypeをinput_fnで使用
- Acceptをoutput_fnで使用
- 実際にmodelに渡されるのはBody
- 処理の流れは以下
- (エンドポイント立ち上げ時)model_fn
- (API実行時)input_fn→predict_fn→output_fn
fnメソッド4種類ののカスタム例(上記API呼び出しを行う想定での改修)
# モデル読み込み def model_fn(model_dir): with open(os.path.join(model_dir,'my_model.txt')) as f: hello = f.read()[:-1] return hello # 前処理 def input_fn(input_data, content_type): if content_type == 'text/csv': transformed_data = input_data.split(',') else: raise ValueError("Illegal content type") return transformed_data # 予測 def predict_fn(transformed_data, model): prediction_list = [] for data in transformed_data: if data[-1] == '1': ordinal = f'{data}st' elif data[-1] == '2': ordinal = f'{data}nd' elif data[-1] == '3': ordinal = f'{data}rd' else: ordinal = f'{data}th' prediction = f'{model} for the {ordinal} time' prediction_list.append(prediction) return prediction_list # 後処理 def output_fn(prediction_list, accept): if accept == 'text/csv': response = '' for prediction in prediction_list: response += prediction + '¥n' else: raise ValueError("Illegal accept type") return response, accept
上記処理の際の想定output:
Hello my great machine learning model for the 1st time Hello my great machine learning model for the 2nd time Hello my great machine learning model for the 3rd time Hello my great machine learning model for the 10000th time
参考記事
default_output_fn参考元:
メインロジック参考元:
プログラムコード参考元: