Skip to content

Commit

Permalink
feat: refactor code
Browse files Browse the repository at this point in the history
  • Loading branch information
Lan Le committed Apr 17, 2024
1 parent c9146d8 commit 12a157c
Showing 1 changed file with 47 additions and 40 deletions.
87 changes: 47 additions & 40 deletions chem_spectra/lib/composer/ni.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ def tf_img(self):
plt.plot([xL, xU], [itg_h, itg_h], color='#228B22')
plt.plot([xL, xL], [itg_h + h * 0.01, itg_h - h * 0.01], color='#228B22') # noqa: E501
plt.plot([xU, xU], [itg_h + h * 0.01, itg_h - h * 0.01], color='#228B22') # noqa: E501
plt.text((xL + xU) / 2, itg_h + h * 0.015, '{:0.2f}'.format(area), color='#228B22', ha='center', size=12) # noqa: E501
plt.text((xL + xU) / 2, itg_h + h * 0.015, '{:0.2f}'.format(area), color='#228B22', ha='center', size=10, rotation=-90.) # noqa: E501
# cys = (ks[iL:iU] - ref) * 1.5 + (y_max - h * 0.4)
cys = (ks[iL:iU] - ref) * 1.5 + itg_h + itg_h * 0.1
plt.plot(cxs, cys, color='#228B22')
Expand Down Expand Up @@ -495,7 +495,17 @@ def tf_img(self):
plt.locator_params(nbins=self.__plt_nbins())
plt.grid(False)


self.__draw_peaks(plt, x_peaks, y_peaks, h, w)

# Save
tf_img = tempfile.NamedTemporaryFile(suffix='.png')
plt.savefig(tf_img, format='png')
tf_img.seek(0)
plt.clf()
plt.cla()
return tf_img

def __draw_peaks(self, plt, x_peaks, y_peaks, h, w):
# TODO: Need to be refactor
ax = plt.gca()
differences = np.diff(x_peaks)
Expand Down Expand Up @@ -523,53 +533,50 @@ def tf_img(self):
max_values = [np.max(items) for items in groups_y]
diff_max_values = np.diff(max_values)

x_boundary_min, x_boundary_max = np.min(x_peaks), np.max(x_peaks)

for i in range(len(groups_x)):
mygroup_x = groups_x[i]
mygroup_y = groups_y[i]
max_current_group = max_values[i] + h * 0.2
if (i > 0):
prev_max_group = max_values[i-1] + h * 0.2
my_gap = abs(max_current_group - prev_max_group)
if my_gap < h*0.1:
max_current_group = max_current_group + h * 0.45
middle_idx = int(len(mygroup_x)/2)
gap_value = np.mean(mygroup_x)
x_text = 0
y_text = 20
for j in range(len(mygroup_x)):
x_pos = mygroup_x[j]
y_pos = mygroup_y[j] + h * 0.5
x_float = '{:.2f}'.format(x_pos)
peak_label = '{x}'.format(x=x_float)
if j > middle_idx:
x_text = (-w * 0.02) * (j-middle_idx)
elif j < middle_idx:
x_text = w * 0.02 * j

ax.add_patch(FancyArrowPatch((x_pos, max_current_group), (x_pos, max_current_group + h * 0.05)))
ax.add_patch(FancyArrowPatch((x_pos, max_current_group + h * 0.05), (gap_value + x_text, max_current_group + h * 0.2)))
# ax.add_patch(FancyArrowPatch((gap_value + x_text, max_current_group + h * 0.2), (gap_value + x_text, max_current_group + h * 0.25)))

ax.annotate(peak_label,
xy=(gap_value + x_text, max_current_group + h * 0.2), xycoords='data',
xytext=(0, 20), textcoords='offset points',
arrowprops=dict(arrowstyle="-"),
rotation=90, size=10)

# for i in range(len(x_peaks)):
# x_pos = x_peaks[i]
# y_pos = y_peaks[i] + h * 0.1
# x_float = '{:.2e}'.format(x_pos)
# y_float = '{:.2e}'.format(y_peaks[i])
# peak_label = '{x}'.format(x=x_float)
# ax.annotate(peak_label,
# xy=(x_pos, y_boundary_max - y_boundary_max * 0.05), xycoords='data',
# xytext=(0, 20), textcoords='offset points',
# arrowprops=dict(arrowstyle="-", connectionstyle="arc3"),
# rotation=90, size=10)

# Save
tf_img = tempfile.NamedTemporaryFile(suffix='.png')
plt.savefig(tf_img, format='png')
tf_img.seek(0)
plt.clf()
plt.cla()
return tf_img
x_pos = mygroup_x[j]
y_pos = mygroup_y[j] + h * 0.5
x_float = '{:.2f}'.format(x_pos)
peak_label = '{x}'.format(x=x_float)
if j >= middle_idx:
x_text = -(w * 0.005) * (middle_idx - j)
elif j < middle_idx:
x_text = w * 0.005 * (j-middle_idx)

ax.add_patch(FancyArrowPatch((x_pos, max_current_group), (x_pos, max_current_group + h * 0.05), linewidth=0.2))
ax.add_patch(FancyArrowPatch((x_pos, max_current_group + h * 0.05), (gap_value + x_text, max_current_group + h * 0.2), linewidth=0.2))

x_boundary_min = min(gap_value + x_text, x_boundary_min)
x_boundary_max = max(gap_value + x_text, x_boundary_max)

ax.annotate(peak_label,
xy=(gap_value + x_text, max_current_group + h * 0.2), xycoords='data',
xytext=(0, 20), textcoords='offset points',
arrowprops=dict(arrowstyle="-", linewidth=0.2),
rotation=90, size=10)

x_boundary_min = x_boundary_min - w * 0.02
x_boundary_max = x_boundary_max + w * 0.02
plt.xlim(
x_boundary_max,
x_boundary_min,
)
return plt

def __prepare_metadata_info_for_csv(self, csv_writer: csv.DictWriter):
csv_writer.writerow({
Expand Down

0 comments on commit 12a157c

Please sign in to comment.