# -*- coding: utf-8 -*-
"""route_AI.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1Agb0s1mgyZWzOIDpshJeEYNjJuObuN_L
"""
# Import libraries
import numpy as np
from sklearn.preprocessing import MaxAbsScaler, OneHotEncoder
from tensorflow.keras.models import load_model

# Candidate path-node symbols (likely Slay the Spire map nodes:
# M = monster, E = elite, R = rest, $ = shop, ? = event, T = treasure).
# Sorted once here so OneHotEncoder always sees a stable category order.
ALL_ROUTE = sorted(['R', 'E', '$', '?', 'M', 'T'])

class Route_ai:
    """Scores candidate route choices with a pre-trained Keras model."""

    def __init__(self):
        self.model = load_model('AI_FP_ROUTE.h5')

    def encode_single(self, value, category):
        """One-hot encode `value` against the given list of categories."""
        np_array = np.array([[value]])
        # sklearn >= 1.2 renamed the `sparse` argument to `sparse_output`
        encoder = OneHotEncoder(categories=[category], sparse_output=False)
        onehot_encoded = encoder.fit_transform(np_array)
        # Collapse the (1, n_categories) matrix into a flat 1D vector
        collapsed = np.sum(onehot_encoded, axis=0)
        return collapsed
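
    # Illustration: with ALL_ROUTE sorted as ['$', '?', 'E', 'M', 'R', 'T'],
    # encode_single('E', ALL_ROUTE) returns array([0., 0., 1., 0., 0., 0.]).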

    def encode_route(self, route):
        """Encode the chosen route symbol into a one-hot vector over ALL_ROUTE."""
        # ALL_ROUTE is already sorted at definition, so no per-call sort is needed
        return self.encode_single(route, ALL_ROUTE)

    def encode_sample_with_loop(self, sample):
        """Encode a single sample dict into a 1D feature vector.

        Example sample:
        {"cards": 10, "relics": 1, "ascension": 20, "character": "IRONCLAD",
         "floor": 0, "potions": 0, "path": "M", "max_hp": 82, "current_hp": 82,
         "gold": 99, "value": 2884.076040777472, "upgrade_cards": 0,
         "curse_cards": 1}
        """
        route = self.encode_route(sample['path'])
        num_data = np.array([sample['cards'], sample['relics'], sample['ascension'],
                             sample['floor'], sample['potions'], sample['max_hp'],
                             sample['current_hp'], sample['gold'],
                             sample['upgrade_cards'], sample['curse_cards']])
        # One-hot route features followed by the 10 numeric features
        return np.concatenate((route, num_data))

    def preprocess_with_loop(self, data):
        """Build (X, Y) training arrays, skipping samples whose path is None or 'B'."""
        preprocess_list = []
        y = []
        for sample in data:
            if sample['path'] is not None and sample['path'] != 'B':
                preprocess_list.append(self.encode_sample_with_loop(sample))
                y.append(sample['value'])
        X = np.vstack(preprocess_list)
        Y = np.array(y, dtype='float64')
        return X, Y
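
    # Illustration (assumed shapes): for N kept samples this returns X with
    # shape (N, 16) -- 6 one-hot route features plus 10 numeric features --
    # and Y with shape (N,).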

    def scale_X(self, X_data):
        """Scale a one-hot-encoded feature vector with MaxAbsScaler.

        Note: fitting the scaler on a single sample divides each feature by
        its own absolute value; reusing the scales saved at training time
        (e.g. from input_scales.json) would match training more faithfully.
        """
        X_copy = np.copy(X_data).reshape(1, -1)
        max_abs_scaler = MaxAbsScaler()
        X_maxabs = max_abs_scaler.fit_transform(X_copy)
        return X_maxabs

    def scale_Y(self, Y_data):
        """Scale target values the same way they were scaled for training."""
        Y_copy = np.copy(Y_data)
        # Scale Y down by a factor of 10
        Y_copy /= 10
        # To allow healing (negative damage), enable the first line instead of the second:
        # Y_copy[Y_copy < -1] = -1  # Healing (negative damage)
        # Y_copy[Y_copy < 0] = 0    # No healing
        # Cap the scaled value at 800
        Y_copy[Y_copy > 800] = 800
        return Y_copy
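
    # Illustration: a raw value of 2884.076 scales to about 288.41, which is
    # below the 800 cap and so passes through unchanged.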

    def predict(self, data):
        """Encode, scale, and score a single sample with the loaded model."""
        return self.model.predict(self.scale_X(self.encode_sample_with_loop(data)))


if __name__ == '__main__':
    route_ai = Route_ai()
    sample = {
        "cards": 13,
        "upgrade_cards": 0,
        "curse_cards": 0,
        "relics": 1,
        "ascension": 0,
        "floor": 4,
        "potions": 1,
        "max_hp": 80,
        "current_hp": 71,
        "gold": 152,
        "path": "E"
    }
    print(route_ai.predict(sample))
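
    # A minimal sketch (not in the original script): assuming a lower predicted
    # value is better (the target appears to track damage taken), score every
    # candidate path for the current state and pick the best. `best_route` is
    # a hypothetical helper added here for illustration.
    def best_route(ai, state):
        scores = {}
        for path in ALL_ROUTE:
            candidate = dict(state, path=path)  # same state, different path
            scores[path] = float(np.squeeze(ai.predict(candidate)))
        return min(scores, key=scores.get), scores

    choice, scores = best_route(route_ai, sample)
    print('Suggested path:', choice, scores)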