# 3D projection
# Trend line
# Implement stacked marginal distributions instead of blended gradients.
# That means for each position, show the proportion of red vs blue density stacked (like a bar split).
# This way, the y-axis in marginal strips shows portion of each color, not just blended color.
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
from scipy.stats import gaussian_kde
# Fresh start: close any open figures and fix the RNG seed so the sampled
# clusters are reproducible.
plt.close('all')
np.random.seed(4)

# Data: a small red cluster and a larger blue one, both correlated 2-D Gaussians.
mean1, cov1 = [0.25, 0.25], [[0.01, 0.006], [0.006, 0.01]]
mean2, cov2 = [0.0, -0.1], [[0.02, 0.012], [0.012, 0.02]]
n1, n2 = 5000, 20000
data1 = np.random.multivariate_normal(mean1, cov1, n1)
data2 = np.random.multivariate_normal(mean2, cov2, n2)
x1, y1 = data1[:, 0], data1[:, 1]
x2, y2 = data2[:, 0], data2[:, 1]

# Layout in figure fractions: a square main panel with thin marginal strips to
# its left and below; each strip shares the matching axis with the main panel.
fig = plt.figure(figsize=(8, 8), dpi=150)
left, bottom, width, height = 0.12, 0.12, 0.72, 0.72
main_ax = fig.add_axes([left, bottom, width, height])
left_ax = fig.add_axes([0.02, bottom, 0.08, height], sharey=main_ax)
bottom_ax = fig.add_axes([left, 0.02, width, 0.08], sharex=main_ax)

# Scatter: the large blue cluster first so the red points land on top of it.
main_ax.scatter(x2, y2, color='#2b6fd6', alpha=0.25, s=2)
main_ax.scatter(x1, y1, color='#e45756', alpha=0.45, s=2)
# Ellipse helpers
def ellipse_geometry(cov, n_std=2.0):
    """Return ``(width, height, angle_deg)`` of the covariance ellipse of a 2x2 ``cov``.

    ``width``/``height`` are full axis lengths ``2 * n_std * sqrt(eigenvalue)``
    along the major and minor principal axes. ``angle_deg`` is the direction of
    the major axis in degrees, counter-clockwise from +x (only meaningful
    modulo 180, since an eigenvector's sign is arbitrary).
    """
    vals, vecs = np.linalg.eigh(cov)
    order = vals.argsort()[::-1]  # largest eigenvalue first -> major axis
    vals, vecs = vals[order], vecs[:, order]
    major = vecs[:, 0]
    angle = np.degrees(np.arctan2(major[1], major[0]))
    ell_width, ell_height = 2 * n_std * np.sqrt(vals)
    return ell_width, ell_height, angle


def draw_ellipse(mean, cov, ax, color, face_alpha=0.12, edge_alpha=0.45, n_std=2.0):
    """Draw the ``n_std``-sigma covariance ellipse of (mean, cov) on ``ax``.

    Two patches are added: a translucent fill (``face_alpha``) with no edge, and
    an unfilled outline (``edge_alpha``). Keeping the edge off the fill patch
    makes the visible outline opacity exactly ``edge_alpha`` instead of a
    composite of two overlapping strokes.

    Returns:
        ``(fill_patch, outline_patch)``.
    """
    ell_width, ell_height, theta = ellipse_geometry(cov, n_std)
    fill = Ellipse(xy=mean, width=ell_width, height=ell_height, angle=theta,
                   edgecolor='none', facecolor=color, alpha=face_alpha)
    outline = Ellipse(xy=mean, width=ell_width, height=ell_height, angle=theta,
                      edgecolor=color, facecolor='none', lw=1.2, alpha=edge_alpha)
    ax.add_patch(fill)
    ax.add_patch(outline)
    return fill, outline
# Covariance ellipses: blue first, then red, so the red one is drawn over it.
for ell_mean, ell_cov, ell_color in ((mean2, cov2, '#2b6fd6'),
                                     (mean1, cov1, '#e45756')):
    draw_ellipse(ell_mean, ell_cov, main_ax, ell_color)
# KDE densities: one 1-D Gaussian KDE per cluster and per axis, all evaluated
# on the same grid, which spans the main panel's limits.
grid = np.linspace(-0.6, 0.6, 400)
kde_x_red = gaussian_kde(x1)
kde_x_blue = gaussian_kde(x2)
kde_y_red = gaussian_kde(y1)
kde_y_blue = gaussian_kde(y2)
dens_x_r = kde_x_red(grid)
dens_x_b = kde_x_blue(grid)
dens_y_r = kde_y_red(grid)
dens_y_b = kde_y_blue(grid)


def stacked_fractions(dens_a, dens_b):
    """Return ``(dens_a / total, dens_b / total)`` where ``total = dens_a + dens_b``.

    The two results sum to 1 wherever ``total > 0``. Elsewhere the divisor is
    replaced by 1, so the inputs come back unchanged and no division by zero
    occurs. New arrays are returned; the inputs are not modified.
    """
    total = dens_a + dens_b
    divisor = np.where(total > 0, total, 1.0)
    return dens_a / divisor, dens_b / divisor


# Normalise each marginal pair so the stacked portions sum to 1 at every grid
# position.
# NOTE(review): each KDE is a normalised density on its own, so these are
# shares of *density*, not of point counts -- blue has n2=20000 points vs
# n1=5000 red. Weight the densities by n1/n2 first if the strips should
# reflect how many points of each colour fall at a position.
dens_x_r, dens_x_b = stacked_fractions(dens_x_r, dens_x_b)
dens_y_r, dens_y_b = stacked_fractions(dens_y_r, dens_y_b)
# Bottom strip: stacked shares along x -- blue filled up from 0, red stacked
# above it, so the two together reach 1 wherever the density sum was positive.
bottom_ax.fill_between(grid, 0, dens_x_b, color='#2b6fd6', alpha=0.6)
bottom_ax.fill_between(grid, dens_x_b, dens_x_b + dens_x_r, color='#e45756', alpha=0.6)
bottom_ax.set_ylim(0, 1)

# Left strip: the same idea along y, filled horizontally (blue from 0, red beyond).
left_ax.fill_betweenx(grid, 0, dens_y_b, color='#2b6fd6', alpha=0.6)
left_ax.fill_betweenx(grid, dens_y_b, dens_y_b + dens_y_r, color='#e45756', alpha=0.6)
left_ax.set_xlim(0, 1)

# Hide the marginal axes entirely; fix the main panel's range and strip its labels.
for strip_ax in (bottom_ax, left_ax):
    strip_ax.axis("off")
main_ax.set_xlim(-0.6, 0.6)
main_ax.set_ylim(-0.6, 0.6)
main_ax.tick_params(labelsize=8)
main_ax.set_xlabel("")
main_ax.set_ylabel("")

out_path = "logit_weights_gradient_stacked.png"
plt.savefig(out_path, dpi=200, bbox_inches='tight', pad_inches=0.05)
# Adjust the GT bars so that:
# - They are lowered by exactly the same threshold offset applied to blue distributions (not 2x).
# - Ensure they render *in front* of the blue fills by plotting them after.
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from matplotlib.patches import Patch
# Close every open figure, including the scatter/marginal figure built above,
# so the 3-D ridge plot below starts from a clean pyplot state.
plt.close('all')
# Shared x grid ("Predicted Token" axis) for every ridge curve.
x = np.linspace(0, 4, 1200)


def gauss(x, mu, sigma, amp=1.0):
    """Unnormalised Gaussian bump.

    Returns ``amp * exp(-0.5 * ((x - mu) / sigma) ** 2)``, so the value at
    ``x == mu`` is exactly ``amp``. Works element-wise on NumPy arrays.
    """
    standardized = (x - mu) / sigma
    return amp * np.exp(-0.5 * standardized ** 2)


# One (mu, sigma, amp) triple per time step; each step's ground-truth
# position is that step's mu.
params = [
    (1.0, 0.5, 1.0),
    (3.3, 0.5, 1.0),
    (2.6, 0.5, 1.0),
]
Zs = [gauss(x, *triple) for triple in params]
steps = list(range(len(params)))
gt_x = [triple[0] for triple in params]
# Second figure: one 3-D axes, camera at 24 deg elevation and -58 deg azimuth.
fig = plt.figure(dpi=120, figsize=(11.6, 7.6))
ax = fig.add_subplot(111, projection='3d')
ax.view_init(azim=-58, elev=24)
# RGBA colours as 0-1 floats: royal blue (65, 105, 225) for predictions,
# crimson (220, 20, 60) for ground truth, and a faint green for the floor panel.
_ROYAL_BLUE_RGB = (65/255, 105/255, 225/255)
pred_color = _ROYAL_BLUE_RGB + (0.55,)
pred_line = _ROYAL_BLUE_RGB + (1.0,)
gt_face = (220/255, 20/255, 60/255, 0.8)
ground_green = (0.0, 0.5, 0.0, 0.18)
# Curves are lowered by this amount and drawn only where they exceed it.
threshold = 0.05
def contiguous_segments(mask):
    """Return the inclusive ``(start, end)`` index pairs of each run of truthy values.

    Example: ``[True, True, False, True]`` -> ``[(0, 1), (3, 3)]``.
    An empty or all-false mask yields ``[]``.
    """
    runs = []
    run_start = None
    for idx, flag in enumerate(mask):
        if flag:
            if run_start is None:
                run_start = idx
        elif run_start is not None:
            runs.append((run_start, idx - 1))
            run_start = None
    # A run that reaches the last element is still open after the loop.
    if run_start is not None:
        runs.append((run_start, len(mask) - 1))
    return runs
# Draw the blue prediction curves first. Each curve is lowered by `threshold`
# and only its contiguous runs above `threshold` are drawn, each run as a
# filled polygon closed down to z = 0 plus an outline.
# NOTE(review): Axes3D depth-sorts artists by default (computed_zorder=True),
# so the zorder values passed here are likely ignored -- confirm.
for y_step, z in zip(steps, Zs):
    mask = z > threshold
    segs = contiguous_segments(mask)
    for (i0, i1) in segs:
        x_seg = x[i0:i1 + 1]
        # Every sample in a run is strictly above `threshold`, so the lowered
        # values are already positive and need no clamping. The subtraction
        # creates a new array, so `z` itself is left untouched.
        z_seg = z[i0:i1 + 1] - threshold
        # Polygon outline: floor point, the curve itself, floor point again.
        verts = [(x_seg[0], y_step, 0.0)]
        verts += [(xi, y_step, zi) for xi, zi in zip(x_seg, z_seg)]
        verts.append((x_seg[-1], y_step, 0.0))
        poly = Poly3DCollection([verts], facecolor=pred_color, edgecolor='none', zorder=1)
        ax.add_collection3d(poly)
        ax.plot(x_seg, np.full_like(x_seg, y_step), z_seg, lw=2.2, color=pred_line, zorder=2)
# Red ground-truth bars: one thin vertical rectangle per step at that step's
# true position, added after the blue curves. The bar height is lowered by
# the same `threshold` the curves were lowered by (clamped at 0).
# NOTE(review): with mplot3d's default computed_zorder=True, artists are
# depth-sorted, so adding these last / zorder=3 may not guarantee they render
# in front of the blue fills -- confirm visually, or set
# ax.computed_zorder = False if manual ordering is really intended.
bar_wx = 0.06
original_bar_h = 1.02
bar_h = max(0.0, original_bar_h - threshold)
half_bar = bar_wx / 2
for y_step, x0 in zip(steps, gt_x):
    bar_lo, bar_hi = x0 - half_bar, x0 + half_bar
    verts = [
        [bar_lo, y_step, 0.0],
        [bar_hi, y_step, 0.0],
        [bar_hi, y_step, bar_h],
        [bar_lo, y_step, bar_h],
    ]
    poly = Poly3DCollection([verts], facecolor=gt_face, edgecolor='none', zorder=3)
    ax.add_collection3d(poly)
# Faint green rectangle on the floor (z = 0) covering x 3.05-4.05, y 0-2,
# with a rotated caption inside it.
floor_corners = [[3.05, 0.0, 0.0], [4.05, 0.0, 0.0], [4.05, 2.0, 0.0], [3.05, 2.0, 0.0]]
panel = Poly3DCollection([floor_corners], facecolor=ground_green, edgecolor='none')
ax.add_collection3d(panel)
ax.text(3.08, 0.2, 0.02, "Multi-token Numerical\nOptimization (over steps)",
        color=(0.0, 0.45, 0.0), fontsize=14, rotation=78)

# Per-step loss captions just above the floor. The loss values are fixed
# strings here, not computed from the curves.
step_captions = (
    (0.15, -0.04, "Step=0 EMD_loss=0.30"),
    (2.65, 0.96, "Step=1 EMD_loss=0.67"),
    (0.75, 1.96, "Step=2 EMD_loss=0.45"),
)
for cap_x, cap_y, caption in step_captions:
    ax.text(cap_x, cap_y, 0.02, caption, rotation=14)
# Axis ranges, axis titles and tick positions for the 3-D ridge plot.
for set_lim, lim in ((ax.set_xlim, (0, 4.1)),
                     (ax.set_ylim, (-0.15, 2.15)),
                     (ax.set_zlim, (0, 1.05))):
    set_lim(*lim)
for set_label, label_text in ((ax.set_xlabel, "Predicted Token"),
                              (ax.set_ylabel, "Time Step"),
                              (ax.set_zlabel, "Probability Density")):
    set_label(label_text, labelpad=18)
ax.set_xticks(list(range(5)))
ax.set_yticks(list(range(3)))
ax.set_zticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
# Legend proxies: plain colour patches standing in for the 3-D collections.
legend_handles = [
    Patch(label='Ground Truth', facecolor=gt_face, edgecolor='none'),
    Patch(label='Prediction', facecolor=pred_color, edgecolor='none'),
]
ax.legend(loc='lower left', bbox_to_anchor=(0.02, 0.02), handles=legend_handles)
plt.tight_layout()

out_path = "ridge_recreation_bar_aligned.png"
plt.savefig(out_path, bbox_inches='tight', dpi=180)
# Bare trailing expression: presumably so a notebook cell echoes the path;
# it has no effect when run as a plain script.
out_path