BlockAI 베타서비스가 오픈되었습니다.
logo
BlockAI

testing

whdgusdl48
whdgusdl48
Copied 0 · Updated on 2024.10.10
'''
---whdgusdl48/testing Auto Generate Code---
Author : whdgusdl48
Project Name: testing
Project Link: https://blockai.kr/whdgusdl48/testing (BlockAI)
Create Date : 2024-10-10

---Requirements---
# Match the library versions below to your environment (OS, CUDA, etc.).
# torchvision 0.13.1 / torchtext 0.13.1 / torchaudio 0.12.1 are paired with torch 1.12.1,
# so torch must be pinned to 1.12.1 (torch==1.12 resolves to 1.12.0 and conflicts).
pip install torch==1.12.1 torchvision==0.13.1 torchtext==0.13.1 torchaudio==0.12.1
pip install pytorch-lightning==2.0.4
pip install tqdm
pip install pandas
pip install scikit-learn
pip install transformers
pip install timm

---Folder Structure---
--📂 data
--📄 testing.py
--📄 testing.ipynb
--📄 requirements.txt
'''
import os
import argparse
import copy
from glob import glob

from tqdm import tqdm
import numpy as np
import pandas as pd
from sklearn import preprocessing
from sklearn.model_selection import train_test_split

import torch
import pytorch_lightning as pl

path_sep = os.sep


# https://pytorch.org/tutorials/beginner/basics/data_tutorial.html#creating-a-custom-dataset-for-your-files
class Dataset(torch.utils.data.Dataset):
    """Placeholder map-style dataset for the project's data files.

    TODO: implement ``__init__``, ``__len__`` and ``__getitem__``.
    """


# https://pytorch-lightning.readthedocs.io/en/stable/extensions/datamodules.html
class Dataloader(pl.LightningDataModule):
    """LightningDataModule that holds the data settings passed on the command line.

    TODO: implement ``setup`` and ``train_dataloader`` (plus the val/test/predict
    dataloaders as needed); ``trainer.fit`` cannot run without ``train_dataloader``.

    Args:
        data_folder: Directory containing the data files.
        batch_size: Mini-batch size intended for the DataLoaders.
        train_ratio: Fraction of the data intended for training
            (presumably the remainder is held out for validation -- TODO confirm).
        shuffle: Whether the training DataLoader should shuffle its samples.
    """

    def __init__(self, data_folder='./data', batch_size=0, train_ratio=1.0, shuffle=False):
        # LightningDataModule.__init__ takes no arguments, so the settings that
        # the driver passes positionally must be accepted and stored here.
        super().__init__()
        self.data_folder = data_folder
        self.batch_size = batch_size
        self.train_ratio = train_ratio
        self.shuffle = shuffle


# https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html
class Model(pl.LightningModule):
    """Placeholder LightningModule.

    TODO: define the layers and implement ``forward``, ``training_step`` and
    ``configure_optimizers`` before calling ``trainer.fit``.
    """

    def __init__(self):
        super().__init__()
        # Stores the constructor arguments in ``self.hparams`` (none yet; any
        # arguments added to __init__ later are captured automatically).
        self.save_hyperparameters()


def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``, ``yes``/``no``, ``1``/``0``.

    ``type=bool`` cannot be used with argparse because ``bool('False')`` is True.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(argv=None):
    """Parse hyper-parameters and other settings from the command line.

    Example: ``python3 testing.py --batch_size=64 --max_epoch=10 --shuffle=True``.
    Any argument that is omitted falls back to its default value.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with properly typed values.
    """
    # https://docs.python.org/3/library/argparse.html
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_folder', default='./data')
    # batch_size / max_epoch defaults of 0 are template placeholders: a torch
    # DataLoader rejects batch_size=0 and max_epochs=0 trains nothing.
    parser.add_argument('--batch_size', type=int, default=0)
    parser.add_argument('--max_epoch', type=int, default=0)
    parser.add_argument('--shuffle', type=_str2bool, default=False)
    parser.add_argument('--train_ratio', type=float, default=1.0)
    # 'auto' picks a GPU when one is available and falls back to CPU otherwise;
    # pass --accelerator=gpu to require a GPU as the original template did.
    parser.add_argument('--accelerator', default='auto')
    return parser.parse_args(argv)


def main(argv=None):
    """Build the datamodule, model and Trainer from CLI settings and run training."""
    args = parse_args(argv)

    dataloader = Dataloader(args.data_folder, args.batch_size, args.train_ratio, args.shuffle)
    model = Model()

    # https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html
    # Trainer used for training and inference.
    trainer = pl.Trainer(accelerator=args.accelerator, devices=1, max_epochs=args.max_epoch)
    trainer.fit(model=model, datamodule=dataloader)

    # Uncomment once test_step / predict_step and the matching dataloaders exist:
    # trainer.test(model=model, datamodule=dataloader)
    # predictions = trainer.predict(model=model, datamodule=dataloader)


if __name__ == '__main__':
    main()