사냥꾼의 IT 노트

Python 병렬 라이브러리 개발 프로젝트 - demo 코드 분석 본문

python

Python 병렬 라이브러리 개발 프로젝트 - demo 코드 분석

가면 쓴 사냥꾼 2023. 1. 5. 22:10

※본 포스팅은 22년 9월~22년 11에 진행된 프로젝트의 연구노트입니다.

demo.py

연구 과제를 준 기업에서 보낸 demo 파일. 해당 파일을 파이썬에 맞게 튜닝하고 코드를 분할해야 한다.

라이브러리 import

import ctypes
from ctypes.wintypes import POINT
import sys
from enum import Enum
  • ctypes : ctypes 라이브러리를 이용하기 위함
  • ctypes.wintypes : 특정 데이터형을 제공하여 구조체를 정의할 수 있게 함
  • sys : 파이썬 인터프리터를 제어하기 위함
  • enum : 여러 개 상수의 집합을 정의할 수 있음 

so 파일 불러오기

path =f'./libconnx.so'
connx = ctypes.cdll.LoadLibrary(path)
args = sys.argv
  • path : so 파일을 불러오기 위한 경로 지정
  • ctypes.cdll.LoadLibrary() : 컴파일된 so 파일을 불러옴
  • sys.argv : 미리 컴파일된 파일을 실행할 때, 인자값을 전달받아서 처리하기 위해 args에 값을 담음

C 구조체 정의

class connx_Tensor(ctypes.Structure) :
    pass

class connx_Graph(ctypes.Structure):
    pass

class CONNX_OPERATOR (ctypes.Structure):
    pass

class connx_DataType(ctypes.Structure) :
    pass

connx_Tensor._fields_ = [
        ("dtype", ctypes.c_int8),
        ("ndim", ctypes.c_int32),
        ("shape", ctypes.POINTER(ctypes.c_int32)),
        ("buffer", ctypes.c_void_p),
        ("size", ctypes.c_uint32),
        ("parent", ctypes.POINTER(connx_Tensor)),
        ("ref_count", ctypes.c_int32),
    ]

class connx_Node(ctypes.Structure):
    _fields_ = [
        ('output_count', ctypes.c_uint32),
        ('outputs', ctypes.POINTER(ctypes.c_uint32)),

        ('input_count', ctypes.c_uint32),
        ('inputs', ctypes.POINTER(ctypes.c_uint32)),

        ('attribute_count', ctypes.c_uint32),
        ('attributes', ctypes.POINTER((ctypes.c_void_p))),
        ('attribute_type', ctypes.POINTER(ctypes.c_uint32)),

        ('op_type', ctypes.c_char_p),
        ('op', CONNX_OPERATOR)
    ]

class connx_Model(ctypes.Structure):
    _fields_ = [
        ('version', ctypes.c_int32),

        ('opset_count', ctypes.c_uint32),
        ('opset_names', ctypes.POINTER(ctypes.POINTER(ctypes.c_char))),
        ('opset_versions', ctypes.POINTER(ctypes.c_uint32)),

        ('graph_count', ctypes.c_uint32),
        ('graphs', ctypes.POINTER(ctypes.POINTER(connx_Graph))),
    ]

connx_Graph._fields_ = [
    ('model', ctypes.POINTER(connx_Model)),

    ('id', ctypes.c_uint32),

    ('initializer_count', ctypes.c_uint32),
    ('initializers', ctypes.POINTER(ctypes.POINTER(connx_Tensor))),

    ('input_count', ctypes.c_uint32),
    ('inputs', ctypes.POINTER(ctypes.c_uint32)),

    ('output_count', ctypes.c_uint32),
    ('outputs', ctypes.POINTER(ctypes.c_uint32)),

    ('value_info_count', ctypes.c_uint32),
    ('value_infos', ctypes.POINTER(ctypes.POINTER(connx_Tensor))),

    ('node_count', ctypes.c_uint32),
    ('nodes', ctypes.POINTER(ctypes.POINTER(connx_Node))),
    
]

C 함수 정의

connx_init = connx.connx_init
connx_init.argtypes = None
connx_init.restype = None

connx_init_model_name = connx.connx_init_model_name
connx_init_model_name.argtypes = [ctypes.POINTER(ctypes.c_char)]
connx_init_model_name.restype = ctypes.c_int32

connx_load_Model = connx.connx_load_Model
connx_load_Model.argtypes = [ctypes.POINTER(connx_Model), ctypes.c_uint32]
connx_load_Model.restype = ctypes.c_int32

temp_set_input_count = connx.temp_set_input
temp_set_input_count.argtypes = [ctypes.c_uint32]
temp_set_input_count.restype = None

connx_alloc = connx.connx_alloc
connx_alloc.argtypes = [ctypes.c_size_t]
connx_alloc.restype = ctypes.c_void_p

connx_convert_input_file = connx.connx_convert_input_file
connx_convert_input_file.argtypes = [ctypes.POINTER(ctypes.POINTER(connx_Tensor))]
connx_convert_input_file.restype = ctypes.c_int32

connx_convert_output_file = connx.connx_convert_output_file
connx_convert_output_file.argtypes = [ctypes.POINTER(ctypes.POINTER(connx_Tensor))]
connx_convert_output_file.restype = ctypes.c_int32

connx_run_Model = connx.connx_run_Model
connx_run_Model.argtypes = [ctypes.POINTER(connx_Model), ctypes.POINTER(ctypes.POINTER(connx_Tensor)), ctypes.POINTER(ctypes.POINTER(connx_Tensor))]
connx_run_Model.restype = ctypes.c_int32

compare = connx.compare
compare.argtypes = [ctypes.POINTER(ctypes.POINTER(connx_Tensor)), ctypes.POINTER(ctypes.POINTER(connx_Tensor))]
compare.restype = ctypes.c_int32

connx_Model_destroy = connx.connx_Model_destroy
connx_Model_destroy.argtypes = [ctypes.POINTER(connx_Model)]
connx_Model_destroy.restype = None

connx 모델을 구축하는 로직으로, 포인터로 C에서 정의된 값의 주소를 불러와 이용할 수 있다. 이는 반환타입의 설정이 없기 때문에 return할 경우 int로 강제 형변환을 적용하기 위함이다.

main code

ret = ctypes.c_int32()

connx_init()
model = connx_Model()
print(model)
ret = connx_init_model_name(b'./mnist')
if ret != 0:
    print('connx_init_model_name failed')
    exit(1)

ret = connx_load_Model(ctypes.byref(model), 1)
if ret != 0:
    print('connx_load failed')
    exit(1)

input_count = 0
graph = model.graphs[0].contents
if graph.input_count - graph.initializer_count < 0 :
    input_count = 1
else:
    input_count = graph.input_count - graph.initializer_count

temp_set_input_count(input_count)

ouput_count = graph.output_count

inputs = ctypes.POINTER(connx_Tensor)()

outputs = ctypes.POINTER(connx_Tensor)()

ret = connx_convert_input_file(ctypes.byref(inputs))

print(1)

if ret != 0:
    print("Error: input file is not correct")
    exit(1)

ret = connx_run_Model(ctypes.byref(model), ctypes.byref(inputs), ctypes.byref(outputs))

if ret != 0:
    print("Error: run model failed")
    exit(1)

onnx_result = ctypes.POINTER(connx_Tensor)()

ret = connx_convert_output_file(ctypes.byref(onnx_result))

if ret != 0:
    print("Error: output file is not correct")
    exit(1)

ret = compare(ctypes.byref(onnx_result), ctypes.byref(outputs))

connx model을 이용하기 위한 로직

  • connx_init_model_name() : ./mnist의 데이터셋을 이용해 모델명을 불러옴
  • connx_load_model() : connx_model()을 불러옴
  • input_count = 0 ~ print(1) : mnist의 모델을 불러와 그래프까지 그리기 위한 로직
    • ctypes.byref() : 매개 변수를 참조로 전달하기 위해 사용하는 부분. 이는 포인터와 거의 동일한 기능이나, 포인터는 실제 포인터 객체를 생성해 더 많은 메모리를 잡아먹으므로 포인터가 아닌 byref가 성능에 더 도움이 된다.
  •  onnx_result = ctypes.POINTER(connx_Tensor)() : 결과 값을 저장

실행 결과